File size: 8,839 Bytes
9fa16f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
#!/usr/bin/env python3
"""
Pre-processes PDF files in the data directory, creating and saving:
1. Document chunks with metadata
2. Vector embeddings

This allows the app to load pre-processed data directly instead of processing
PDFs at runtime, making the app start faster and eliminating the need to
upload PDFs to Hugging Face.
"""

import os
import pickle
import tiktoken
from pathlib import Path
from collections import defaultdict
import json

# Import required LangChain modules
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_qdrant import Qdrant
from langchain_core.documents import Document
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct

# Load environment variables from a local .env file (supplies OPENAI_API_KEY).
from dotenv import load_dotenv
load_dotenv()

# Fail fast at import time if the key is missing — both the embedding model
# and the downstream app require it.
if not os.environ.get("OPENAI_API_KEY"):
    raise EnvironmentError("OPENAI_API_KEY environment variable not found. Please set it in your .env file.")

# Output location for all pre-processed artifacts (created if absent).
PROCESSED_DATA_DIR = Path("processed_data")
PROCESSED_DATA_DIR.mkdir(exist_ok=True)

# Pickled chunk list consumed by the app at startup.
CHUNKS_FILE = PROCESSED_DATA_DIR / "document_chunks.pkl"
# On-disk Qdrant collection directory (local, serverless mode).
QDRANT_DIR = PROCESSED_DATA_DIR / "qdrant_vectorstore"

def load_and_process_documents():
    """
    Load PDF documents from ``data/``, merge pages by source file, and split
    the merged text into token-bounded chunks.

    Returns:
        list[Document]: chunks whose metadata carries ``source``, ``title``,
        ``page_count``, ``page_ranges`` (character-range -> original page
        mapping) and, where a chunk overlaps a page range, a 1-indexed
        ``page`` number.
    """
    print("Loading PDF documents...")
    path = "data/"
    loader = DirectoryLoader(path, glob="*.pdf", loader_cls=PyPDFLoader)
    all_docs = loader.load()
    print(f"Loaded {len(all_docs)} pages from PDF files")

    # Group pages by their source PDF so each file can be merged in order.
    docs_by_source = defaultdict(list)
    for doc in all_docs:
        source = doc.metadata.get("source", "")
        docs_by_source[source].append(doc)

    # Merge the pages of each PDF into one document while recording the
    # character range each page occupies, so chunks can later be mapped
    # back to their original page numbers.
    merged_docs = []
    for source, source_docs in docs_by_source.items():
        # Sort by page number if available (PyPDFLoader pages are 0-based).
        source_docs.sort(key=lambda x: x.metadata.get("page", 0))

        merged_content = ""
        page_ranges = []

        for doc in source_docs:
            # 1-indexed page number for human readability.
            page_num = doc.metadata.get("page", 0) + 1

            # Blank-line separator between pages for clarity.
            if merged_content:
                merged_content += "\n\n"

            # Record where this page's content sits in the merged document.
            start_pos = len(merged_content)
            merged_content += doc.page_content
            end_pos = len(merged_content)

            page_ranges.append({
                "start": start_pos,
                "end": end_pos,
                "page": page_num,
                "source": source
            })

        # Merged metadata includes the page mapping for later reference.
        merged_metadata = {
            "source": source,
            "title": source.split("/")[-1],
            "page_count": len(source_docs),
            "merged": True,
            "page_ranges": page_ranges
        }

        merged_docs.append(Document(page_content=merged_content, metadata=merged_metadata))

    print(f"Created {len(merged_docs)} merged documents")

    # Resolve the tokenizer once: encoding_for_model is comparatively
    # expensive and the splitter calls the length function many times.
    encoding = tiktoken.encoding_for_model("gpt-4o-mini")

    def tiktoken_len(text):
        """Token count of *text* (not characters) via the gpt-4o-mini tokenizer."""
        return len(encoding.encode(text))

    def add_page_info_to_splits(splits):
        """Attach a 'page' number to each split based on its character range."""
        for split in splits:
            # Character span of this chunk within its merged source document.
            start_pos = split.metadata.get("start_index", 0)
            end_pos = start_pos + len(split.page_content)

            # Assign the first page range this chunk overlaps.
            for page_range in split.metadata.get("page_ranges", []):
                if start_pos <= page_range["end"] and end_pos >= page_range["start"]:
                    split.metadata["page"] = page_range["page"]
                    break
        return splits

    # Split with start-index tracking so chunks can be located in the merge.
    print("Splitting documents into chunks...")
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=300,
        chunk_overlap=50,
        length_function=tiktoken_len,
        add_start_index=True  # records each chunk's offset for page mapping
    )

    raw_splits = text_splitter.split_documents(merged_docs)
    split_chunks = add_page_info_to_splits(raw_splits)
    print(f"Created {len(split_chunks)} document chunks")

    return split_chunks

def create_and_save_vectorstore(chunks):
    """
    Embed document chunks with OpenAI and persist them in a local on-disk
    Qdrant collection at QDRANT_DIR.

    Args:
        chunks: list of Document objects from load_and_process_documents().

    Returns:
        True on success (exceptions from the OpenAI or Qdrant clients
        propagate to the caller).
    """
    print("Creating embeddings and vector store...")
    embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")

    # Separate text and metadata so they can be embedded / stored in batches.
    texts = [doc.page_content for doc in chunks]
    metadatas = [doc.metadata for doc in chunks]

    # Ensure the on-disk location exists before opening the local client.
    QDRANT_DIR.mkdir(exist_ok=True, parents=True)

    # A path-based client stores the collection directly on disk (no server).
    client = QdrantClient(path=str(QDRANT_DIR))

    # Probe one embedding to learn the vector dimension for the collection.
    sample_embedding = embedding_model.embed_query("Sample text")

    collection_name = "kohavi_ab_testing_pdf_collection"
    try:
        client.get_collection(collection_name)
        print(f"Collection {collection_name} already exists")
    except Exception:
        # get_collection raises when the collection is missing — create it.
        print(f"Creating collection {collection_name}")
        client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(
                size=len(sample_embedding),
                distance=Distance.COSINE
            )
        )

    # Embed and upsert in batches to bound memory use and API payload size.
    batch_size = 100
    total_batches = (len(texts) + batch_size - 1) // batch_size
    print(f"Processing {len(texts)} documents in batches of {batch_size}")

    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]
        batch_metadatas = metadatas[i:i + batch_size]

        print(f"Processing batch {i // batch_size + 1}/{total_batches}")

        # Embed the whole batch in one API call.
        embeddings = embedding_model.embed_documents(batch_texts)

        # Point ids are the global chunk index, so re-runs overwrite in place.
        points = [
            PointStruct(
                id=i + j,
                vector=embedding,
                payload={
                    "text": text,
                    "metadata": metadata
                }
            )
            for j, (text, embedding, metadata)
            in enumerate(zip(batch_texts, embeddings, batch_metadatas))
        ]

        client.upsert(
            collection_name=collection_name,
            points=points
        )

    print(f"Vector store created and saved to {QDRANT_DIR}")
    return True

def main():
    """Drive the full pipeline: chunk PDFs, persist chunks, build the vector store."""
    print("Starting pre-processing of PDF files...")
    chunks = load_and_process_documents()

    # Persist the chunk list so the app can load it without re-parsing PDFs.
    print(f"Saving {len(chunks)} document chunks to {CHUNKS_FILE}...")
    with CHUNKS_FILE.open("wb") as fh:
        pickle.dump(chunks, fh)
    print(f"Chunks saved to {CHUNKS_FILE}")

    # Build the embeddings store and report the final artifact locations.
    if create_and_save_vectorstore(chunks):
        print("Pre-processing complete! The application can now use these pre-processed files.")
        print(f"- Document chunks: {CHUNKS_FILE}")
        print(f"- Vector store: {QDRANT_DIR}")
    else:
        print("Error creating vector store. Please check the logs.")

# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()