# Risk-Adjustment-Version1 / embedding_manager.py
# Author: sujataprakashdatycs
# Last commit: "Update embedding_manager.py" (225134f, verified)
import os
from typing import List, Optional
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
class DirectoryEmbeddingManager:
    """
    Extracts text from PDFs in a directory (or a single PDF) and builds/reuses
    a persisted Chroma vector store.

    Persistence path: ``<base_dir>/<PDF_STEM or DIR_NAME>``; the combined
    extracted text is cached at ``<base_dir>/<PDF_STEM or DIR_NAME>.txt``.
    """

    def __init__(
        self,
        pdf_dir_or_file: str,
        base_dir: str = "./embeddings",
        chunk_size: int = 2048,
        chunk_overlap: int = 128,
        embedding_model: str = "text-embedding-ada-002",
        openai_api_key_env: str = "OPENAI_API_KEY",
    ):
        """
        Args:
            pdf_dir_or_file: Path to a single PDF file or a directory of PDFs.
            base_dir: Root directory for the persisted store and text cache.
            chunk_size: Character length of each text chunk.
            chunk_overlap: Character overlap between consecutive chunks.
            embedding_model: OpenAI embedding model name.
            openai_api_key_env: Name of the env var holding the OpenAI API key.

        Raises:
            ValueError: If a directory is given but contains no ``.pdf`` files
                (previously this failed much later, inside Chroma, with an
                unhelpful error).
            FileNotFoundError: If ``pdf_dir_or_file`` does not exist at all
                (raised by ``os.listdir``).
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.embedding_model = embedding_model
        self.openai_api_key_env = openai_api_key_env
        if os.path.isfile(pdf_dir_or_file):
            self.pdf_files = [pdf_dir_or_file]
            folder_name = os.path.splitext(os.path.basename(pdf_dir_or_file))[0]
        else:
            # Sort for a deterministic corpus order across runs, so the
            # cached text file (and thus chunking) is reproducible.
            self.pdf_files = sorted(
                os.path.join(pdf_dir_or_file, f)
                for f in os.listdir(pdf_dir_or_file)
                if f.lower().endswith(".pdf")
            )
            if not self.pdf_files:
                raise ValueError(f"No PDF files found in {pdf_dir_or_file!r}")
            folder_name = os.path.basename(pdf_dir_or_file.rstrip("/"))
        self.base_dir = base_dir
        self.persist_dir = os.path.join(base_dir, folder_name)
        os.makedirs(base_dir, exist_ok=True)
        # Cache path for combined text from all PDFs
        self.txt_path = os.path.join(base_dir, f"{folder_name}.txt")
        # Vector store cached on the instance so repeated query() calls do
        # not reload (or rebuild) the Chroma store every time.
        self._vectordb = None

    def pdfs_to_txt(self) -> str:
        """Dump text from all PDFs to a single .txt (idempotent).

        Returns the path of the cached text file; if it already exists it is
        reused without re-parsing the PDFs.
        """
        if os.path.exists(self.txt_path):
            print(f"[INFO] Using existing text at {self.txt_path}")
            return self.txt_path
        all_texts: List[str] = []
        for pdf_path in self.pdf_files:
            reader = PdfReader(pdf_path)
            # Call extract_text() exactly once per page -- it is expensive,
            # and the original code invoked it twice (filter + collect).
            page_texts = (page.extract_text() for page in reader.pages)
            file_text = [text for text in page_texts if text]
            if file_text:
                all_texts.append("\n".join(file_text))
        combined_text = "\n".join(all_texts)
        with open(self.txt_path, "w", encoding="utf-8") as f:
            f.write(combined_text)
        print(f"[INFO] Extracted text to {self.txt_path}")
        return self.txt_path

    def _load_embeddings(self) -> "OpenAIEmbeddings":
        """Build the OpenAI embeddings client from the configured env var.

        Raises:
            RuntimeError: If the API key env var is missing or empty.
        """
        key = os.environ.get(self.openai_api_key_env)
        if not key:
            raise RuntimeError(f"Missing {self.openai_api_key_env} in environment.")
        return OpenAIEmbeddings(api_key=key, model=self.embedding_model)

    def get_or_create_embeddings(self) -> "Chroma":
        """
        Returns a Chroma vector store, creating & persisting if needed.

        The store is memoized on the instance, so repeated calls (e.g. one
        per query()) reuse the already-loaded store instead of reopening or
        rebuilding it.
        """
        if self._vectordb is not None:
            return self._vectordb
        embeddings = self._load_embeddings()
        if os.path.exists(self.persist_dir) and os.listdir(self.persist_dir):
            print(f"[INFO] Loading embeddings from {self.persist_dir}")
            self._vectordb = Chroma(
                persist_directory=self.persist_dir,
                embedding_function=embeddings,
            )
            return self._vectordb
        txt_file = self.pdfs_to_txt()
        with open(txt_file, "r", encoding="utf-8") as f:
            text = f.read()
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap
        )
        chunks: List[str] = splitter.split_text(text)
        vectordb = Chroma.from_texts(
            chunks,
            embedding=embeddings,
            persist_directory=self.persist_dir
        )
        vectordb.persist()
        print(f"[INFO] Created embeddings in {self.persist_dir}")
        self._vectordb = vectordb
        return vectordb

    def query(self, query_text: str, top_k: int = 5) -> str:
        """Return the top_k most similar chunks, joined by newlines.

        Args:
            query_text: Natural-language query to embed and search with.
            top_k: Number of nearest chunks to return.
        """
        vectordb = self.get_or_create_embeddings()
        results = vectordb.similarity_search(query_text, k=top_k)
        return "\n".join(doc.page_content for doc in results)