File size: 3,969 Bytes
3fd0b05
225134f
3fd0b05
 
225134f
3fd0b05
 
225134f
3fd0b05
225134f
 
 
3fd0b05
 
 
225134f
3fd0b05
 
 
 
 
 
 
 
 
 
225134f
 
 
 
 
 
 
 
 
 
 
 
 
 
3fd0b05
 
225134f
 
 
 
 
3fd0b05
 
 
 
225134f
 
 
 
 
 
 
 
3fd0b05
225134f
 
3fd0b05
 
 
 
 
 
225134f
3fd0b05
 
 
 
 
 
 
 
 
 
 
 
225134f
 
3fd0b05
225134f
3fd0b05
225134f
 
3fd0b05
 
 
 
 
 
 
 
 
 
225134f
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
from typing import List, Optional
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma

class DirectoryEmbeddingManager:
    """
    Extracts text from PDFs in a directory (or a single PDF) and builds/reuses
    a persisted Chroma vector store.

    Persistence path: ./embeddings/<PDF_STEM or DIR_NAME>
    A combined plain-text dump of all PDFs is cached alongside it as
    ./embeddings/<PDF_STEM or DIR_NAME>.txt so extraction runs only once.
    """

    def __init__(
        self,
        pdf_dir_or_file: str,
        base_dir: str = "./embeddings",
        chunk_size: int = 2048,
        chunk_overlap: int = 128,
        embedding_model: str = "text-embedding-ada-002",
        openai_api_key_env: str = "OPENAI_API_KEY",
    ):
        """
        Args:
            pdf_dir_or_file: Path to a single PDF file or a directory of PDFs.
            base_dir: Root directory for the persisted store and text cache.
            chunk_size: Character length of each text chunk.
            chunk_overlap: Character overlap between consecutive chunks.
            embedding_model: OpenAI embedding model name.
            openai_api_key_env: Environment variable holding the API key.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.embedding_model = embedding_model
        self.openai_api_key_env = openai_api_key_env

        if os.path.isfile(pdf_dir_or_file):
            self.pdf_files = [pdf_dir_or_file]
            folder_name = os.path.splitext(os.path.basename(pdf_dir_or_file))[0]
        else:
            # Sort for a deterministic extraction/chunk order across runs;
            # os.listdir order is filesystem-dependent.
            self.pdf_files = sorted(
                os.path.join(pdf_dir_or_file, f)
                for f in os.listdir(pdf_dir_or_file)
                if f.lower().endswith(".pdf")
            )
            # normpath strips trailing separators (both "/" and os.sep), so
            # basename doesn't come back empty for paths like "some/dir/".
            folder_name = os.path.basename(os.path.normpath(pdf_dir_or_file))

        self.base_dir = base_dir
        self.persist_dir = os.path.join(base_dir, folder_name)
        os.makedirs(base_dir, exist_ok=True)

        # Cache path for combined text from all PDFs
        self.txt_path = os.path.join(base_dir, f"{folder_name}.txt")

    def pdfs_to_txt(self) -> str:
        """Dump text from all PDFs to a single .txt (idempotent).

        Returns:
            Path to the cached text file (reused as-is if it already exists).
        """
        if os.path.exists(self.txt_path):
            print(f"[INFO] Using existing text at {self.txt_path}")
            return self.txt_path

        all_texts: List[str] = []
        for pdf_path in self.pdf_files:
            reader = PdfReader(pdf_path)
            # Extract each page exactly once: the previous version called
            # extract_text() twice per page (filter + collect), doubling
            # the most expensive step of the whole pipeline.
            page_texts = (page.extract_text() for page in reader.pages)
            file_text = [text for text in page_texts if text]
            if file_text:
                all_texts.append("\n".join(file_text))

        combined_text = "\n".join(all_texts)
        with open(self.txt_path, "w", encoding="utf-8") as f:
            f.write(combined_text)

        print(f"[INFO] Extracted text to {self.txt_path}")
        return self.txt_path

    def _load_embeddings(self) -> "OpenAIEmbeddings":
        """Build the OpenAI embeddings client, failing fast on a missing key.

        Raises:
            RuntimeError: If the configured environment variable is unset/empty.
        """
        key = os.environ.get(self.openai_api_key_env)
        if not key:
            raise RuntimeError(f"Missing {self.openai_api_key_env} in environment.")
        return OpenAIEmbeddings(api_key=key, model=self.embedding_model)

    def get_or_create_embeddings(self) -> "Chroma":
        """
        Returns a Chroma vector store, creating & persisting if needed.

        An existing non-empty persist directory is loaded as-is; otherwise the
        PDFs are extracted, chunked, embedded, and persisted.

        Raises:
            RuntimeError: If the API key is missing or no text was extracted.
        """
        embeddings = self._load_embeddings()

        # isdir (not exists): a stray plain file at this path must not be
        # mistaken for a usable store — os.listdir would fail on it anyway.
        if os.path.isdir(self.persist_dir) and os.listdir(self.persist_dir):
            print(f"[INFO] Loading embeddings from {self.persist_dir}")
            return Chroma(persist_directory=self.persist_dir, embedding_function=embeddings)

        txt_file = self.pdfs_to_txt()
        with open(txt_file, "r", encoding="utf-8") as f:
            text = f.read()

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap
        )
        chunks: List[str] = splitter.split_text(text)
        if not chunks:
            # Chroma.from_texts fails opaquely on an empty batch; raise a
            # clear, actionable error instead.
            raise RuntimeError("No text could be extracted from the given PDF(s).")

        vectordb = Chroma.from_texts(
            chunks,
            embedding=embeddings,
            persist_directory=self.persist_dir
        )
        vectordb.persist()
        print(f"[INFO] Created embeddings in {self.persist_dir}")
        return vectordb

    def query(self, query_text: str, top_k: int = 5) -> str:
        """Return the concatenated contents of the top_k most similar chunks.

        Args:
            query_text: Free-text query to embed and search with.
            top_k: Number of nearest chunks to retrieve.
        """
        vectordb = self.get_or_create_embeddings()
        results = vectordb.similarity_search(query_text, k=top_k)
        return "\n".join(doc.page_content for doc in results)