from pathlib import Path
import pandas as pd

from backend.classes.chunker.text_chunker import RecursiveCharacterTextChunkerConfig, RecursiveCharacterTextChunker
from backend.classes.embedding_model import EmbeddingModelConfig, EmbeddingModel
from pydantic import BaseModel
import json
import dotenv

from backend.classes.vector_database.milvus_vector_database import MilvusVectorDatabaseConfig, MilvusVectorDatabase
from backend.utils.utils import get_embedding_model, read_config, set_env_variables, create_vector_database, \
    create_text_chunker, initialize_logger

dotenv.load_dotenv()

# Module-level logger so the helper functions below can log even when
# this module is imported rather than run as a script
logger = initialize_logger()

def get_files(folder_path: str, extension: str = "jsonl") -> list:
    # Recursively collect all files with the given extension using pathlib.Path
    return list(Path(folder_path).rglob(f"*.{extension}"))


class ChunkerVectorDbConfig(BaseModel):
    folder_path: str
    chunker: RecursiveCharacterTextChunker
    vector_database: MilvusVectorDatabase
    embedding_model: EmbeddingModel

    class Config:
        # Allow non-pydantic field types (chunker, vector DB, embedding model)
        arbitrary_types_allowed = True


def get_file_data(file_path: str) -> pd.DataFrame:
    """Read a JSON Lines file into a DataFrame."""
    try:
        return pd.read_json(file_path, lines=True)
    except Exception:
        logger.exception(f"Failed to read {file_path}")
        raise

def chunk_and_save_to_vector_db(config: ChunkerVectorDbConfig):
    # Read files from folder
    file_paths = get_files(config.folder_path)
    logger.info(f"There are {len(file_paths)} files to process")

    # Chunk and embed the extracted markdown from each file
    for file_path in file_paths:
        logger.info(f"Processing {file_path}")
        data_df = get_file_data(str(file_path))

        # There are a few rows that are empty due to images not being extracted
        # Remove them
        data_df = data_df[data_df["markdown_text"] != ""]

        data_df["text_chunks"] = data_df["markdown_text"].apply(config.chunker.chunk_text)
        data_df = data_df.explode("text_chunks").rename(columns={"text_chunks": "text"})
        # Number chunks sequentially within each source document
        data_df["chunk_id"] = data_df.groupby("id").cumcount() + 1

        # Serialize the metadata dict so it can be stored as a string column
        data_df["metadata_json"] = data_df["metadata"].apply(json.dumps)
        data_df = (data_df
                   .drop(columns=["metadata", "id", "markdown_text", "chunk_id"])
                   .rename(columns={"metadata_json": "metadata"}))

        embeddings = config.embedding_model.encode(data_df.text.tolist())
        config.vector_database.add_texts(data_df, embeddings)


def run(config: dict):
    # Create embedding model object
    embedding_model_config = EmbeddingModelConfig(model_name=config["embedding_model"]["model_name"],
                                                  batch_size=config["embedding_model"]["batch_size"])
    embedding_model = get_embedding_model(EmbeddingModel, embedding_model_config)

    # Create vector db model object
    vector_db_config = MilvusVectorDatabaseConfig(db_path=config["vector_database"]["db_path"],
                                                  collection_name=config["vector_database"]["collection_name"],
                                                  vector_dimensions=config["vector_database"]["dimensions"])
    vector_db = create_vector_database(MilvusVectorDatabase, vector_db_config)

    text_chunker_config = RecursiveCharacterTextChunkerConfig(chunk_size=config["chunker"]["chunk_size"],
                                                              chunk_overlap=config["chunker"]["chunk_overlap"])
    text_chunker = create_text_chunker(RecursiveCharacterTextChunker, text_chunker_config)

    chunker_vector_db_config = ChunkerVectorDbConfig(folder_path=config["data"]["output_data_path"],
                                                     chunker=text_chunker,
                                                     vector_database=vector_db,
                                                     embedding_model=embedding_model)

    chunk_and_save_to_vector_db(chunker_vector_db_config)


if __name__ == "__main__":
    logger = initialize_logger()

    # Resolve the config file relative to this script's location
    config = read_config(str(Path(__file__).parent / "../conf/config.yaml"))

    # Populate environment variables declared in the config and return them
    env_variables = set_env_variables(config["env_variables"])

    app_config = config[env_variables["APP_ENV"]]
    app_config["env_vars"] = env_variables

    run(app_config)