File size: 10,868 Bytes
36b893f
de38977
ace5cd4
de38977
90e2962
a555050
de38977
a555050
de38977
 
 
 
 
36b893f
de38977
90e2962
 
de38977
36b893f
de38977
 
 
 
 
 
 
 
 
 
 
 
82b35ca
ace5cd4
a555050
90e2962
 
 
a555050
 
de38977
 
 
 
 
 
 
 
 
 
 
 
 
a555050
 
de38977
 
9513c18
ace5cd4
de38977
 
 
25d31a4
 
de38977
 
 
25d31a4
de38977
25d31a4
de38977
25d31a4
de38977
 
 
25d31a4
de38977
25d31a4
 
de38977
25d31a4
de38977
25d31a4
 
de38977
25d31a4
 
 
de38977
25d31a4
de38977
25d31a4
de38977
 
25d31a4
 
 
244f753
de38977
 
 
244f753
 
de38977
244f753
 
de38977
 
244f753
ace5cd4
de38977
244f753
 
 
 
 
 
ace5cd4
244f753
 
de38977
244f753
 
 
 
de38977
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ace5cd4
de38977
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ace5cd4
 
de38977
244f753
 
 
de38977
244f753
 
 
 
de38977
 
 
 
 
ace5cd4
de38977
cf1fb02
de38977
cf1fb02
 
de38977
 
 
 
 
cf1fb02
ace5cd4
de38977
 
cf1fb02
 
de38977
 
cf1fb02
de38977
 
cf1fb02
 
ace5cd4
de38977
ace5cd4
 
 
 
de38977
ace5cd4
 
de38977
 
 
ace5cd4
 
de38977
ace5cd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de38977
 
ace5cd4
 
 
de38977
 
ace5cd4
 
de38977
ace5cd4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
import os
import re
import time
from pathlib import Path

from dotenv import load_dotenv
from langchain_nebius import NebiusEmbeddings
from langchain_unstructured import UnstructuredLoader
from pydantic import SecretStr
from pymilvus import MilvusClient, DataType
from unstructured.cleaners.core import (
    clean_extra_whitespace,
    replace_unicode_quotes,
)

# Load environment variables from a local .env file, if present
load_dotenv()

# Configuration constants
MILVUS_URI = os.getenv("MILVUS_URI", "http://localhost:19530")
COLLECTION_NAME = "my_rag_collection"
DOCUMENT_DIR = "data/"          # directory scanned for source documents
EMBEDDING_DIMENSION = 4096      # Qwen3-Embedding-8B vector size (see schema below)
TEXT_MAX_LENGTH = 65000         # kept below the Milvus VARCHAR max_length (65535)
CHUNK_SIZE = 100                # documents per Milvus insertion chunk
BATCH_SIZE = 5                  # texts per embedding API call

# Chunking configuration
MAX_CHARACTERS = 1500             # upper bound for a single text chunk
COMBINE_TEXT_UNDER_N_CHARS = 200  # minimum size a chunk must reach before flushing

# Initialize clients
milvus_client = MilvusClient(uri=MILVUS_URI, token=os.getenv("MILVUS_API_KEY"))

# Fail fast with an actionable message instead of passing None into
# SecretStr when neither NEBIUS_API_KEY nor OPENAI_API_KEY is set.
_embedding_api_key = os.getenv("NEBIUS_API_KEY") or os.getenv("OPENAI_API_KEY")
if not _embedding_api_key:
    raise RuntimeError(
        "Missing embedding API key: set NEBIUS_API_KEY (or OPENAI_API_KEY) "
        "in the environment or .env file."
    )

embedding_model = NebiusEmbeddings(
    api_key=SecretStr(_embedding_api_key),
    model="Qwen/Qwen3-Embedding-8B",
    base_url="https://api.studio.nebius.ai/v1"
)

def clean_text(text: str) -> str:
    """Normalize whitespace and quote characters in extracted document text.

    Args:
        text: Raw text extracted from a document element.

    Returns:
        Single-line text with collapsed whitespace and ASCII quotes,
        stripped of leading/trailing spaces.
    """
    # Library-provided cleaners from unstructured
    text = clean_extra_whitespace(text)
    text = replace_unicode_quotes(text)

    # \s matches \r and \n as well, so one pass both converts newlines to
    # spaces and collapses runs of whitespace (the previous separate
    # [\r\n]+ substitution was redundant).
    text = re.sub(r'\s+', ' ', text)

    return text.strip()


def generate_embedding(text):
    """Return the embedding vector for a single query string."""
    vector = embedding_model.embed_query(text)
    return vector


def generate_embeddings_batch(texts):
    """Embed a list of texts in one API call; returns one vector per text."""
    vectors = embedding_model.embed_documents(texts)
    return vectors


def process_embeddings_in_batches(texts, batch_size=BATCH_SIZE):
    """Embed texts batch-by-batch, falling back to per-item calls on failure.

    Documents that still fail individually get a zero vector so the result
    stays aligned one-to-one with ``texts``.
    """
    embeddings = []
    n_batches = (len(texts) + batch_size - 1) // batch_size

    print(f"Generating embeddings in {n_batches} batches of {batch_size}...")

    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        batch_no = start // batch_size + 1

        print(f"Processing batch {batch_no}/{n_batches}")

        try:
            embeddings.extend(generate_embeddings_batch(batch))
            time.sleep(1.5)  # stay under the API rate limit
        except Exception as batch_err:
            print(f"Batch {batch_no} failed: {batch_err}. Processing individually...")

            # Retry each text on its own so one bad item doesn't sink the batch.
            for offset, item in enumerate(batch):
                try:
                    embeddings.append(generate_embedding(item))
                    time.sleep(2)
                except Exception as item_err:
                    print(f"Failed to process document {start + offset + 1}: {item_err}")
                    # Zero-vector placeholder keeps ids aligned with texts.
                    embeddings.append([0.0] * EMBEDDING_DIMENSION)

    return embeddings

def create_collection():
    """Ensure the target Milvus collection exists and is loaded into memory."""
    if milvus_client.has_collection(COLLECTION_NAME):
        # Collection already exists — just make sure it is loaded.
        milvus_client.load_collection(collection_name=COLLECTION_NAME)
        return

    # Schema: explicit integer ids, embedding vector, raw text, JSON metadata.
    schema = milvus_client.create_schema(auto_id=False, enable_dynamic_field=False)
    field_specs = (
        dict(field_name="id", datatype=DataType.INT64, is_primary=True),
        dict(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIMENSION),
        dict(field_name="text", datatype=DataType.VARCHAR, max_length=65535),
        dict(field_name="metadata", datatype=DataType.JSON),
    )
    for spec in field_specs:
        schema.add_field(**spec)

    # Cosine-similarity auto index on the embedding field.
    index_params = MilvusClient.prepare_index_params()
    index_params.add_index(
        field_name="vector",
        metric_type="COSINE",
        index_type="AUTOINDEX",
    )

    milvus_client.create_collection(
        collection_name=COLLECTION_NAME,
        schema=schema,
        index_params=index_params,
        consistency_level="Strong",
    )
    milvus_client.load_collection(collection_name=COLLECTION_NAME)

def load_documents():
    """Load, clean, and chunk every supported document under DOCUMENT_DIR.

    Returns:
        List of cleaned document chunks, each at most MAX_CHARACTERS long
        and at least 50 characters.
    """
    base_dir = Path(DOCUMENT_DIR)
    file_paths = [
        str(path)
        for pattern in ("*.pdf", "*.docx", "*.html")
        for path in base_dir.glob(pattern)
    ]

    loader = UnstructuredLoader(
        file_paths,
        chunking_strategy="by_title",
        include_orig_elements=False
    )
    docs = loader.load()
    print(f"Loaded {len(docs)} initial documents")

    # Clean each chunk, dropping tiny fragments and splitting oversized ones.
    processed = []
    for doc in docs:
        text = clean_text(doc.page_content)

        if len(text) < 50:
            continue  # too short to carry useful content

        if len(text) > MAX_CHARACTERS:
            # Oversized: break on sentence boundaries.
            processed.extend(_split_large_chunk(text, doc.metadata))
        else:
            doc.page_content = text
            processed.append(doc)

    print(f"Final processed chunks: {len(processed)}")
    if processed:
        mean_len = sum(len(d.page_content) for d in processed) / len(processed)
        print(f"Average chunk length: {mean_len:.0f} characters")

    return processed


def _split_large_chunk(text, metadata):
    """Split oversized text into chunks on sentence boundaries.

    Args:
        text: Cleaned text longer than MAX_CHARACTERS.
        metadata: Metadata dict; a copy is attached to every chunk.

    Returns:
        List of langchain Documents, each roughly at most MAX_CHARACTERS long.
    """
    from langchain.schema import Document

    # split('. ') strips the separator, so re-attach it to every sentence
    # except the last, which keeps its original ending. (The previous
    # version appended '. ' to the final sentence too, adding a spurious
    # trailing period to the last chunk.)
    sentences = text.split('. ')
    pieces = [s + '. ' for s in sentences[:-1]] + [sentences[-1]]

    chunks = []
    current_chunk = ""

    for piece in pieces:
        candidate = current_chunk + piece

        # Flush once the chunk would overflow AND is already large enough
        # to stand alone (avoids emitting tiny fragments).
        if len(candidate) > MAX_CHARACTERS and len(current_chunk) > COMBINE_TEXT_UNDER_N_CHARS:
            if current_chunk.strip():
                chunks.append(Document(
                    page_content=current_chunk.strip(),
                    metadata=metadata.copy()
                ))
            current_chunk = piece
        else:
            current_chunk = candidate

    # Add remaining content
    if current_chunk.strip():
        chunks.append(Document(
            page_content=current_chunk.strip(),
            metadata=metadata.copy()
        ))

    return chunks


def prepare_document_data(docs, start_idx=0, max_length=None):
    """Prepare document data for insertion into Milvus.

    Args:
        docs: Documents exposing ``page_content`` and ``metadata`` attributes.
        start_idx: Offset added to each document's position to form its
            unique integer id.
        max_length: Maximum characters kept per text; longer texts are
            truncated. Defaults to TEXT_MAX_LENGTH (generalized from the
            previously hard-coded constant; default behavior unchanged).

    Returns:
        Tuple ``(texts_to_embed, doc_data)`` where ``doc_data`` rows carry
        the id, (possibly truncated) text, and metadata of each document.
    """
    if max_length is None:
        max_length = TEXT_MAX_LENGTH

    texts_to_embed = []
    doc_data = []

    for i, doc in enumerate(docs):
        text_content = doc.page_content
        if len(text_content) > max_length:
            # Keep under the Milvus VARCHAR limit for the text field.
            text_content = text_content[:max_length]
            print(f"Document {start_idx + i + 1} truncated to {max_length} characters")

        texts_to_embed.append(text_content)
        doc_data.append({
            "id": start_idx + i,
            "text": text_content,
            # Milvus JSON field cannot take None; substitute an empty dict.
            "metadata": doc.metadata or {}
        })

    return texts_to_embed, doc_data


def process_document_chunk(docs, chunk_idx, chunk_num, total_chunks):
    """Embed one chunk of documents and insert the rows into Milvus.

    Returns:
        Number of rows Milvus reports as inserted.
    """
    print(f"\nProcessing chunk {chunk_num}/{total_chunks}")

    # Truncate texts and build per-document rows (ids offset by chunk_idx).
    texts_to_embed, doc_data = prepare_document_data(docs, chunk_idx)

    print(f"Generating embeddings for {len(texts_to_embed)} documents...")
    embeddings = process_embeddings_in_batches(texts_to_embed)

    # Pair each document row with its embedding vector.
    rows = [
        {
            "id": info["id"],
            "vector": vector,
            "text": info["text"],
            "metadata": info["metadata"],
        }
        for info, vector in zip(doc_data, embeddings)
    ]

    result = milvus_client.insert(collection_name=COLLECTION_NAME, data=rows)
    return result['insert_count']

def main():
    """Orchestrate collection setup, document loading, and chunked insertion."""
    create_collection()

    # Idempotence guard: don't re-insert into a populated collection.
    stats = milvus_client.get_collection_stats(COLLECTION_NAME)
    if stats['row_count'] > 0:
        print(f"Collection already contains {stats['row_count']} documents. Skipping insertion.")
        return

    docs = load_documents()
    if not docs:
        print("No documents found to process.")
        return

    total_docs = len(docs)
    total_chunks = (total_docs + CHUNK_SIZE - 1) // CHUNK_SIZE
    inserted_so_far = 0

    print(f"Processing {total_docs} documents in {total_chunks} chunks of {CHUNK_SIZE}")

    for start in range(0, total_docs, CHUNK_SIZE):
        end = min(start + CHUNK_SIZE, total_docs)
        chunk_no = start // CHUNK_SIZE + 1
        batch = docs[start:end]

        count = process_document_chunk(batch, start, chunk_no, total_chunks)
        inserted_so_far += count

        print(f"Chunk {chunk_no} complete: {count} docs inserted")
        print(f"Progress: {inserted_so_far}/{total_docs} ({(inserted_so_far/total_docs)*100:.1f}%)")

        # Drop references to the processed chunk before the next allocation.
        del batch
        if chunk_no < total_chunks:
            time.sleep(2)

    print(f"\nSuccessfully processed {inserted_so_far} documents!")


def verify_insertion():
    """Sanity-check the collection: print stats and run one sample search."""
    stats = milvus_client.get_collection_stats(COLLECTION_NAME)
    print(f"Collection stats: {stats}")

    # Embed a fixed query and fetch the top matches end-to-end.
    test_query = "Why should reasonable adjustments be made?"
    query_vector = generate_embedding(test_query)

    hits = milvus_client.search(
        collection_name=COLLECTION_NAME,
        data=[query_vector],
        limit=3,
        output_fields=["text", "metadata"]
    )

    print(f"\nTest search results for '{test_query}':")
    for rank, hit in enumerate(hits[0]):
        print(f"Result {rank+1}:")
        print(f"  Score: {hit['distance']:.4f}")
        print(f"  Text preview: {hit['entity']['text'][:200]}...")
        print(f"  Metadata: {hit['entity']['metadata']}")
        print("-" * 50)


if __name__ == "__main__":
    # Time the full ingest-and-verify run.
    started = time.time()

    print("Starting document processing and Milvus insertion")
    print("=" * 60)

    main()

    print("\nVerifying data insertion")
    print("=" * 30)
    verify_insertion()

    print(f"\nTotal execution time: {time.time() - started:.2f} seconds")