# Demos/backend/main/chunk_and_save_to_vector_db.py
# (Hugging Face page metadata: nikhile-galileo — "Adding finance protect demo",
#  commit e68d535, 4.45 kB; kept as a comment so the file stays valid Python.)
from pathlib import Path
import pandas as pd
from backend.classes.chunker.text_chunker import RecursiveCharacterTextChunkerConfig, RecursiveCharacterTextChunker
from backend.classes.embedding_model import EmbeddingModelConfig, EmbeddingModel
from pydantic import BaseModel
import json
import dotenv
from backend.classes.vector_database.milvus_vector_database import MilvusVectorDatabaseConfig, MilvusVectorDatabase
from backend.utils.utils import get_embedding_model, read_config, set_env_variables, create_vector_database, \
create_text_chunker, initialize_logger
dotenv.load_dotenv()
def get_files(folder_path: str, extension: str = "jsonl") -> list:
    """Recursively collect all files with the given extension under *folder_path*.

    Args:
        folder_path: Root directory to search.
        extension: File extension without the leading dot (default ``"jsonl"``).

    Returns:
        List of ``pathlib.Path`` objects in ``Path.rglob`` traversal order.
    """
    # The search is driven entirely by `extension` (the original comment said
    # "pdf files", but the default — and the only caller here — uses jsonl).
    return list(Path(folder_path).rglob(f"*.{extension}"))
class ChunkerVectorDbConfig(BaseModel):
    """Bundle of everything ``chunk_and_save_to_vector_db`` needs to run."""
    # Root folder containing the *.jsonl files to index.
    folder_path: str
    # Splits each document's markdown text into chunks.
    chunker: RecursiveCharacterTextChunker
    # Destination Milvus store for the chunks and their embeddings.
    vector_database: MilvusVectorDatabase
    # Encodes chunk text into embedding vectors.
    embedding_model: EmbeddingModel

    class Config:
        # The three fields above are plain project classes, not pydantic
        # models, so pydantic must be told to accept arbitrary types.
        arbitrary_types_allowed = True
def get_file_data(file_path: str) -> pd.DataFrame:
    """Read a JSON-lines file into a DataFrame.

    Args:
        file_path: Path to a ``.jsonl`` file (one JSON object per line).

    Raises:
        Whatever ``pandas.read_json`` raises on malformed input (e.g.
        ``ValueError``), after logging the full traceback.
    """
    try:
        return pd.read_json(file_path, lines=True)
    except Exception as e:
        logger.exception(e)
        # Bare `raise` re-raises with the original traceback intact;
        # `raise e` would reset the re-raise point to this line.
        raise
def chunk_and_save_to_vector_db(config: ChunkerVectorDbConfig):
    """Chunk every JSONL file under ``config.folder_path`` and index it.

    For each file: load it into a DataFrame, drop rows with empty
    ``markdown_text``, split the text into chunks, embed each chunk and
    write the chunks plus embeddings to the configured vector database.

    Assumes each JSONL row carries ``id``, ``markdown_text`` and ``metadata``
    columns — TODO confirm against the upstream extraction step's output.
    """
    # Read files from folder (get_files defaults to the "jsonl" extension).
    file_paths = get_files(config.folder_path)
    logger.info(f"There are {len(file_paths)} to process")
    # Extract text from pdf files
    for file_path in file_paths:
        # Extract pdf data in markdown
        logger.info(f"Processing {file_path}")
        data_df = get_file_data(str(file_path))
        # There are a few rows that are empty due to images not being extracted
        # Remove them
        data_df = data_df[data_df["markdown_text"] != ""]
        # One list of chunks per row; explode so each chunk gets its own row.
        data_df["text_chunks"] = data_df["markdown_text"].apply(config.chunker.chunk_text)
        data_df = data_df.explode("text_chunks").rename(columns={"text_chunks": "text"})
        # 1-based chunk index within each source document.
        data_df["chunk_id"] = data_df.groupby("id").cumcount() + 1
        # NOTE(review): row_chunk_id is built here but dropped two lines below,
        # so it is never persisted — confirm whether that is intentional.
        data_df["row_chunk_id"] = data_df["id"] + data_df["chunk_id"].astype(str)
        # Serialize the metadata dicts to JSON strings for storage.
        data_df["metadata_json"] = data_df["metadata"].apply(lambda d: json.dumps(d))
        # Keep only the columns the vector DB ingests: "text" + "metadata".
        data_df = data_df.drop(columns=["metadata", "id", "row_chunk_id", "markdown_text", "chunk_id"]).rename(columns={"metadata_json": "metadata"})
        embeddings = config.embedding_model.encode(data_df.text.tolist())
        config.vector_database.add_texts(data_df, embeddings)
def run(config: dict):
    """Build the embedding model, vector DB and chunker from *config*, then index.

    Expects the environment-specific config section with ``embedding_model``,
    ``vector_database``, ``chunker`` and ``data`` sub-dicts.
    """
    emb_cfg = config["embedding_model"]
    db_cfg = config["vector_database"]
    chunk_cfg = config["chunker"]

    # Embedding model.
    embedding_model = get_embedding_model(
        EmbeddingModel,
        EmbeddingModelConfig(
            model_name=emb_cfg["model_name"],
            batch_size=emb_cfg["batch_size"],
        ),
    )
    # Milvus vector database.
    vector_db = create_vector_database(
        MilvusVectorDatabase,
        MilvusVectorDatabaseConfig(
            db_path=db_cfg["db_path"],
            collection_name=db_cfg["collection_name"],
            vector_dimensions=db_cfg["dimensions"],
        ),
    )
    # Text chunker.
    text_chunker = create_text_chunker(
        RecursiveCharacterTextChunker,
        RecursiveCharacterTextChunkerConfig(
            chunk_size=chunk_cfg["chunk_size"],
            chunk_overlap=chunk_cfg["chunk_overlap"],
        ),
    )
    # Run the end-to-end chunk → embed → store pipeline.
    chunk_and_save_to_vector_db(
        ChunkerVectorDbConfig(
            folder_path=config["data"]["output_data_path"],
            chunker=text_chunker,
            vector_database=vector_db,
            embedding_model=embedding_model,
        )
    )
if __name__ == "__main__":
    logger = initialize_logger()
    # Resolve the YAML config relative to this script's location.
    config_path = Path(__file__).parent / "../conf/config.yaml"
    config = read_config(str(config_path))
    # Validate/collect required environment variables, then select the
    # environment-specific section of the loaded config.
    env_variables = set_env_variables(config["env_variables"])
    app_config = config[env_variables["APP_ENV"]]
    app_config["env_vars"] = env_variables
    run(app_config)