File size: 4,727 Bytes
f373e2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#!/usr/bin/env python3
"""
Build the medical knowledge base from downloaded datasets.
Optimized for handling large datasets with progress tracking.
"""
import sys
import gc
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))

from tqdm import tqdm
from src.data_pipeline.loaders.dataset_loader import MedicalDatasetLoader
from src.data_pipeline.preprocessors.text_cleaner import MedicalTextCleaner
from src.data_pipeline.preprocessors.chunker import MedicalTextChunker
from src.embeddings.embedding_models import MedicalEmbedder
from src.embeddings.vector_store import VectorStore


def _collect_chunks(loader, cleaner, chunker, min_length=50,
                    report_every=50000, max_reported_errors=5):
    """Clean and chunk every document yielded by the loader.

    Documents whose cleaned content is shorter than ``min_length`` characters
    are ignored. Documents that raise during cleaning/chunking are skipped
    (best-effort), but they are counted and the first few errors are printed
    so that data problems do not go unnoticed.

    Args:
        loader: object providing ``get_documents_for_knowledge_base()``; each
            yielded item is indexed with ``"content"`` and ``"source"`` and
            queried with ``.get("metadata", {})``.
        cleaner: object providing ``clean(text)``.
        chunker: object providing ``chunk_document(dict)`` returning an
            iterable of chunks.
        min_length: minimum stripped length of cleaned content to keep.
        report_every: print a progress line every this many kept documents.
        max_reported_errors: number of per-document errors printed in full.

    Returns:
        Tuple ``(chunks, doc_count, failed_count)`` where ``chunks`` is the
        list of all chunks, ``doc_count`` the number of documents chunked and
        ``failed_count`` the number of documents skipped because of an error.
    """
    # NOTE: documents are consumed lazily from the loader, but every chunk is
    # kept in this list until indexing, so memory grows with the corpus size.
    chunks = []
    doc_count = 0
    failed_count = 0

    for doc in tqdm(loader.get_documents_for_knowledge_base(), desc="Processing documents"):
        try:
            cleaned_content = cleaner.clean(doc["content"])
            if len(cleaned_content.strip()) < min_length:  # Skip very short content
                continue

            doc_chunks = chunker.chunk_document({
                "content": cleaned_content,
                "source": doc["source"],
                "metadata": doc.get("metadata", {})
            })
            chunks.extend(doc_chunks)
            doc_count += 1

            # Periodic garbage collection for large datasets
            if doc_count % report_every == 0:
                gc.collect()
                print(f"    Processed {doc_count:,} documents, {len(chunks):,} chunks so far...")

        except Exception as exc:
            # Skip problematic documents, but keep a tally and show the first
            # few failures instead of swallowing them silently.
            failed_count += 1
            if failed_count <= max_reported_errors:
                print(f"\n  Warning: Skipping document due to error: {exc}")
            elif failed_count == max_reported_errors + 1:
                print("\n  Warning: Further document errors will not be printed individually.")
            continue

    return chunks, doc_count, failed_count


def _index_chunks(chunks, embedder, vector_store, batch_size=500):
    """Embed chunks in batches and add them to the vector store.

    A batch that raises is reported and skipped so that one bad batch does not
    abort the whole build.

    Args:
        chunks: sequence of chunk objects; ``content``, ``source``,
            ``chunk_id``, ``total_chunks`` and ``metadata`` are read from each.
        embedder: object providing ``embed_documents(texts, batch_size=...)``
            whose result supports ``.tolist()``.
        vector_store: object providing
            ``add_documents(documents=..., embeddings=..., metadatas=...)``.
        batch_size: number of chunks embedded and stored per batch.

    Returns:
        Tuple ``(indexed_count, failed_batches)``.
    """
    indexed_count = 0
    failed_batches = 0

    for start in tqdm(range(0, len(chunks), batch_size), desc="Indexing batches"):
        batch = chunks[start : start + batch_size]
        texts = [chunk.content for chunk in batch]

        try:
            # Generate embeddings for batch
            embeddings = embedder.embed_documents(texts, batch_size=32)

            # Prepare metadata (keys in chunk.metadata override the fixed ones)
            metadatas = [
                {
                    "source": chunk.source,
                    "chunk_id": chunk.chunk_id,
                    "total_chunks": chunk.total_chunks,
                    **chunk.metadata
                }
                for chunk in batch
            ]

            # Add batch to vector store
            vector_store.add_documents(
                documents=texts,
                embeddings=embeddings.tolist(),
                metadatas=metadatas
            )
        except Exception as exc:
            failed_batches += 1
            print(f"\n  Warning: Failed to process batch at {start}: {exc}")
            continue

        indexed_count += len(batch)

        # Periodic garbage collection
        if (start // batch_size) % 100 == 0:
            gc.collect()

    return indexed_count, failed_batches


def main():
    """Build the medical knowledge base end to end.

    Steps: initialise the pipeline components, print dataset statistics,
    clean and chunk all documents, embed and index the chunks in batches,
    then print a summary. Per-document and per-batch failures are tolerated
    but are counted and reported in the summary.
    """
    persist_dir = "data/knowledge_base"

    print("\n" + "=" * 60)
    print("  BUILDING MEDICAL KNOWLEDGE BASE")
    print("=" * 60)

    # Initialize components
    print("\n[1/5] Initializing components...")
    loader = MedicalDatasetLoader()
    cleaner = MedicalTextCleaner()
    chunker = MedicalTextChunker(chunk_size=512, chunk_overlap=50)
    embedder = MedicalEmbedder(model_name="all-minilm")
    vector_store = VectorStore(
        collection_name="medical_knowledge",
        persist_directory=persist_dir
    )

    # Show dataset statistics (informational only; failure is not fatal)
    print("\n[2/5] Checking available datasets...")
    try:
        stats = loader.get_stats()
        print("  Available datasets:")
        for name, count in stats.items():
            if name != "total" and count > 0:
                print(f"    - {name}: {count:,} entries")
        print(f"  Total raw entries: {stats.get('total', 0):,}")
    except Exception as e:
        print(f"  Could not get stats: {e}")

    # Load, clean and chunk documents
    print("\n[3/5] Loading and processing documents...")
    all_chunks, doc_count, failed_docs = _collect_chunks(loader, cleaner, chunker)

    print(f"  Processed {doc_count:,} documents")
    print(f"  Created {len(all_chunks):,} text chunks")
    if failed_docs:
        print(f"  Skipped {failed_docs:,} documents due to errors")

    # Generate embeddings and add to vector store
    print("\n[4/5] Generating embeddings and indexing...")

    # Smaller batches for better memory management
    indexed_count, failed_batches = _index_chunks(
        all_chunks, embedder, vector_store, batch_size=500
    )

    # Verify and summarize
    print("\n[5/5] Finalizing...")

    try:
        final_stats = vector_store.get_stats()
    except Exception as exc:
        # Stats are informational only; report why they are unavailable
        # without trapping KeyboardInterrupt/SystemExit as a bare except would.
        print(f"  Could not get vector store stats: {exc}")
        final_stats = {"count": "unknown"}

    print("\n" + "=" * 60)
    print("  KNOWLEDGE BASE BUILD COMPLETE")
    print("=" * 60)
    print(f"  Documents processed: {doc_count:,}")
    print(f"  Chunks created: {len(all_chunks):,}")
    print(f"  Chunks indexed: {indexed_count:,}")
    if failed_docs:
        print(f"  Documents skipped (errors): {failed_docs:,}")
    if failed_batches:
        print(f"  Batches failed: {failed_batches:,}")
    print(f"  Vector store stats: {final_stats}")
    print(f"  Location: {persist_dir}")

    if indexed_count:
        print("\nThe knowledge base is ready for use!")
    else:
        print("\nWarning: no chunks were indexed; the knowledge base was not populated by this run.")


# Script entry point: run the build only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()