| |
| """ |
| Colab-compatible script to build the medical knowledge base. |
| Run this in Google Colab for a stable environment. |
| |
| Usage: |
| 1. Upload your final_project folder to Colab or mount Google Drive |
| 2. Run: !pip install chromadb sentence-transformers pandas pyarrow tqdm |
| 3. Run this script |
| """ |
| import sys |
| import gc |
| from pathlib import Path |
|
|
| |
# Folder the project is expected to be uploaded to (see the usage notes in the
# module docstring). It is placed first on sys.path so it is searched first
# when modules are imported.
PROJECT_ROOT = Path("/content", "final_project")
sys.path.insert(0, str(PROJECT_ROOT))
|
|
| import pandas as pd |
| from tqdm import tqdm |
| import numpy as np |
| from typing import List, Dict, Generator |
|
|
|
|
class SimpleEmbedder:
    """Wrap a sentence-transformers model behind a small embedding API.

    ``__init__`` sets two attributes: ``model`` (the loaded
    ``SentenceTransformer``) and ``dimension`` (the value the model returns
    from ``get_sentence_embedding_dimension()``).
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """Load ``model_name`` and print the embedding dimension it reports."""
        # The import lives inside the constructor rather than at module level.
        from sentence_transformers import SentenceTransformer

        loaded = SentenceTransformer(model_name)
        self.model = loaded
        self.dimension = loaded.get_sentence_embedding_dimension()
        print(f"Loaded embedding model. Dimension: {self.dimension}")

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode ``texts`` with the wrapped model and return a NumPy array.

        Args:
            texts: Strings to embed.
            batch_size: Forwarded to the model's ``encode`` call.

        Returns:
            ``np.array`` built from whatever ``encode`` returns. The model is
            asked for normalized embeddings and no progress bar.
        """
        encode_kwargs = {
            "batch_size": batch_size,
            "show_progress_bar": False,
            "normalize_embeddings": True,
        }
        vectors = self.model.encode(texts, **encode_kwargs)
        return np.array(vectors)
|
|
|
|
class SimpleVectorStore:
    """Minimal wrapper around one persistent ChromaDB collection.

    ``__init__`` sets ``persist_directory`` (a ``Path``), ``client`` (a
    ``chromadb.PersistentClient`` opened at that path) and ``collection``
    (fetched or created with cosine distance configured via the
    ``hnsw:space`` metadata key).
    """

    def __init__(self, collection_name: str, persist_directory: str):
        """Open (or create) ``collection_name`` under ``persist_directory``.

        The directory is created if it does not exist, and the current
        document count of the collection is printed.
        """
        import chromadb

        self.persist_directory = Path(persist_directory)
        self.persist_directory.mkdir(parents=True, exist_ok=True)

        self.client = chromadb.PersistentClient(path=str(self.persist_directory))
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"},
        )
        print(f"Vector store initialized. Documents: {self.collection.count()}")

    @staticmethod
    def _clean_metadata(meta: Dict) -> Dict:
        """Return a copy of ``meta`` holding only str/int/float/bool values.

        ``None`` becomes the empty string; any other type is converted with
        ``str()``.
        """
        cleaned = {}
        for key, value in meta.items():
            if isinstance(value, (str, int, float, bool)):
                cleaned[key] = value
            elif value is None:
                cleaned[key] = ""
            else:
                cleaned[key] = str(value)
        return cleaned

    def add_documents(self, documents: List[str], embeddings: List[List[float]],
                      metadatas: List[Dict], ids: List[str]):
        """Add one batch of documents to the collection.

        Args:
            documents: Document texts.
            embeddings: One embedding vector per document.
            metadatas: One metadata dict per document; values are sanitised
                with ``_clean_metadata`` before being stored.
            ids: One identifier per document.

        An empty batch is a no-op: there is nothing to add, so the collection
        is not called at all.
        """
        if not ids:
            return

        clean_metadatas = [self._clean_metadata(meta) for meta in metadatas]

        self.collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=clean_metadatas,
        )

    def count(self):
        """Return the number of entries reported by the collection."""
        return self.collection.count()
|
|
|
|
class TextChunk:
    """One chunk of a source document together with its bookkeeping fields."""

    def __init__(self, content, source, chunk_id, total_chunks, metadata):
        """Store each argument unchanged on the attribute of the same name."""
        self.content, self.source = content, source
        self.chunk_id, self.total_chunks = chunk_id, total_chunks
        self.metadata = metadata
|
|
|
|
def chunk_text(text: str, chunk_size: int = 512, overlap: int = 50) -> List[str]:
    """Split ``text`` into overlapping chunks of whitespace-separated words.

    Args:
        text: Text to split; words are whatever ``str.split()`` produces.
        chunk_size: Maximum number of words per chunk. Must be at least 1.
        overlap: Number of words shared between consecutive chunks. Must
            satisfy ``0 <= overlap < chunk_size``.

    Returns:
        The chunks, each re-joined with single spaces. If ``text`` contains no
        words, the result is ``[text]``.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range. Without
            this check an overlap that is not smaller than the chunk size
            keeps the window from ever advancing on long inputs.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, got {overlap}"
        )

    words = text.split()
    if not words:
        return [text]

    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        end = start + chunk_size
        chunks.append(" ".join(words[start:end]))
        # Stop once the window has reached the last word, so no trailing
        # chunk made purely of overlap is emitted.
        if end >= len(words):
            break
    return chunks
|
|
|
|
def _is_missing(value) -> bool:
    """Return True when ``value`` is None or a scalar NaN/NA marker."""
    if value is None:
        return True
    result = pd.isna(value)
    # pd.isna returns an array for list-like input; a container is not a
    # missing cell, so only a scalar boolean answer counts.
    return bool(result) if isinstance(result, (bool, np.bool_)) else False


def _as_text(value) -> str:
    """Return ``value`` as a string, mapping None/NaN to the empty string."""
    return "" if _is_missing(value) else str(value)


def _first_column(row, *names: str) -> str:
    """Return the text of the first of ``names`` that is a column of ``row``.

    Only column presence is checked: a present but empty cell yields ``""``
    and later names are not consulted. ``""`` is also returned when none of
    the names is present.
    """
    for name in names:
        if name in row:
            return _as_text(row[name])
    return ""


def load_all_qa_pairs(data_dir: Path) -> Generator[Dict, None, None]:
    """Yield question/answer dicts from every parquet file found in ``data_dir``.

    Each yielded dict has the keys ``"question"``, ``"answer"`` and
    ``"source"``. Files that do not exist are skipped. After a file has been
    fully consumed, a "Loaded ..." line is printed for it.

    Sources, in the order they are read:
        * ``mediqa/medquad.parquet`` -> "MedQuAD"
        * ``pubmed/pubmedqa_labeled.parquet`` -> "PubMedQA"
        * ``mediqa/medmcqa_train.parquet`` -> "MedMCQA" (rows with an
          explanation in ``exp`` only)
        * ``mediqa/healthcare_magic.parquet`` -> "HealthCareMagic"
        * ``medqa/medqa_usmle_{train,test}.parquet`` -> "MedQA-USMLE"
        * ``chatdoctor/chatdoctor_{icliniq,healthcaremagic}.parquet`` -> "ChatDoctor"
        * every ``*.parquet`` in ``medical_meadow/`` (sorted by path so the
          yield order is reproducible) -> "MedicalMeadow"

    Missing cells (None/NaN) are treated as empty strings rather than being
    rendered as the text "nan".
    """
    # MedQuAD: every row is yielded.
    path = data_dir / "mediqa" / "medquad.parquet"
    if path.exists():
        df = pd.read_parquet(path)
        for _, row in df.iterrows():
            yield {
                "question": _first_column(row, "Question", "question"),
                "answer": _first_column(row, "Answer", "answer"),
                "source": "MedQuAD",
            }
        print(f" Loaded MedQuAD: {len(df):,}")

    # PubMedQA: the long answer is used as the answer text.
    path = data_dir / "pubmed" / "pubmedqa_labeled.parquet"
    if path.exists():
        df = pd.read_parquet(path)
        for _, row in df.iterrows():
            yield {
                "question": _first_column(row, "question"),
                "answer": _first_column(row, "long_answer"),
                "source": "PubMedQA",
            }
        print(f" Loaded PubMedQA: {len(df):,}")

    # MedMCQA: only rows that carry an explanation are kept.
    path = data_dir / "mediqa" / "medmcqa_train.parquet"
    if path.exists():
        df = pd.read_parquet(path)
        count = 0
        for _, row in df.iterrows():
            explanation = _first_column(row, "exp")
            if explanation:
                yield {
                    "question": _first_column(row, "question"),
                    "answer": explanation,
                    "source": "MedMCQA",
                }
                count += 1
        print(f" Loaded MedMCQA: {count:,}")

    # HealthCareMagic: every row is yielded.
    path = data_dir / "mediqa" / "healthcare_magic.parquet"
    if path.exists():
        df = pd.read_parquet(path)
        for _, row in df.iterrows():
            yield {
                "question": _first_column(row, "input", "instruction"),
                "answer": _first_column(row, "output"),
                "source": "HealthCareMagic",
            }
        print(f" Loaded HealthCareMagic: {len(df):,}")

    # MedQA-USMLE: prefer the option selected by an integer index when the
    # options are stored as a sequence.
    for filename in ["medqa_usmle_train.parquet", "medqa_usmle_test.parquet"]:
        path = data_dir / "medqa" / filename
        if path.exists():
            df = pd.read_parquet(path)
            for _, row in df.iterrows():
                question = _first_column(row, "question", "sent1")
                answer = _first_column(row, "answer")

                options = row.get("options")
                answer_idx = row.get("answer_idx", row.get("label", -1))
                # A list column may arrive as a NumPy array, whose truth value
                # is ambiguous, so test type and length explicitly; the index
                # may likewise be a NumPy integer.
                if (
                    isinstance(options, (list, tuple, np.ndarray))
                    and isinstance(answer_idx, (int, np.integer))
                    and 0 <= answer_idx < len(options)
                ):
                    answer = _as_text(options[answer_idx])

                if question and answer:
                    yield {
                        "question": question,
                        "answer": answer,
                        "source": "MedQA-USMLE",
                    }
            print(f" Loaded {filename}: {len(df):,}")

    # ChatDoctor: rows missing either side are dropped.
    for filename in ["chatdoctor_icliniq.parquet", "chatdoctor_healthcaremagic.parquet"]:
        path = data_dir / "chatdoctor" / filename
        if path.exists():
            df = pd.read_parquet(path)
            for _, row in df.iterrows():
                question = _first_column(row, "input", "instruction", "question")
                answer = _first_column(row, "output", "answer")
                if question and answer:
                    yield {
                        "question": question,
                        "answer": answer,
                        "source": "ChatDoctor",
                    }
            print(f" Loaded {filename}: {len(df):,}")

    # Medical Meadow: instruction and input are combined into the question.
    meadow_dir = data_dir / "medical_meadow"
    if meadow_dir.exists():
        for parquet_file in sorted(meadow_dir.glob("*.parquet")):
            df = pd.read_parquet(parquet_file)
            for _, row in df.iterrows():
                instruction = _first_column(row, "instruction")
                input_text = _first_column(row, "input")
                output_text = _first_column(row, "output")

                question = instruction
                if input_text:
                    question = f"{instruction}\n\n{input_text}" if instruction else input_text

                if question and output_text:
                    yield {
                        "question": question,
                        "answer": output_text,
                        "source": "MedicalMeadow",
                    }
            print(f" Loaded {parquet_file.name}: {len(df):,}")
|
|
|
|
def main():
    """Build the knowledge base: load QA pairs, chunk, embed and index them.

    Steps:
        1. Create the embedder and the vector store under
           ``PROJECT_ROOT / "data" / "knowledge_base_new"``.
        2. Read every QA pair from ``PROJECT_ROOT / "data" / "raw"``, format
           it as "Question: ... Answer: ..." text, drop texts shorter than
           50 characters, and split the rest into chunks held in memory.
        3. Embed and index the chunks in batches of 500. A failing batch is
           reported and skipped so one bad batch does not abort the build.
        4. Print a summary, including how many batches failed.

    Chunk IDs are ``chunk_<position>`` where the position is the chunk's
    index in the in-memory list.
    """
    print("\n" + "=" * 60)
    print(" BUILDING MEDICAL KNOWLEDGE BASE (Colab Version)")
    print("=" * 60)

    DATA_DIR = PROJECT_ROOT / "data" / "raw"
    KB_DIR = PROJECT_ROOT / "data" / "knowledge_base_new"

    print("\n[1/4] Initializing components...")
    embedder = SimpleEmbedder("all-MiniLM-L6-v2")
    vector_store = SimpleVectorStore(
        collection_name="medical_knowledge",
        persist_directory=str(KB_DIR)
    )

    print("\n[2/4] Loading and processing documents...")
    all_chunks = []
    doc_count = 0

    for qa in tqdm(load_all_qa_pairs(DATA_DIR), desc="Processing"):
        content = f"Question: {qa['question']}\n\nAnswer: {qa['answer']}"

        # Skip entries too short to carry useful content.
        if len(content.strip()) < 50:
            continue

        chunks = chunk_text(content, chunk_size=512, overlap=50)

        for i, chunk in enumerate(chunks):
            all_chunks.append(TextChunk(
                content=chunk,
                source=qa['source'],
                chunk_id=i + 1,
                total_chunks=len(chunks),
                metadata={"type": "qa_pair"}
            ))

        doc_count += 1

        # Periodic garbage collection and progress report.
        if doc_count % 50000 == 0:
            gc.collect()
            print(f" Processed {doc_count:,} documents, {len(all_chunks):,} chunks...")

    print(f"\n Total documents: {doc_count:,}")
    print(f" Total chunks: {len(all_chunks):,}")

    print("\n[3/4] Generating embeddings and indexing...")

    batch_size = 500
    total_chunks = len(all_chunks)
    failed_batches = 0
    failed_chunks = 0

    for i in tqdm(range(0, total_chunks, batch_size), desc="Indexing"):
        batch = all_chunks[i : i + batch_size]
        texts = [chunk.content for chunk in batch]

        try:
            embeddings = embedder.embed_documents(texts, batch_size=32)

            metadatas = [
                {
                    "source": chunk.source,
                    "chunk_id": chunk.chunk_id,
                    "total_chunks": chunk.total_chunks,
                    **chunk.metadata
                }
                for chunk in batch
            ]

            ids = [f"chunk_{i + j}" for j in range(len(batch))]

            vector_store.add_documents(
                documents=texts,
                embeddings=embeddings.tolist(),
                metadatas=metadatas,
                ids=ids
            )
        except Exception as e:
            # Best-effort indexing: keep going, but remember the failure so
            # the final summary does not claim a clean build.
            failed_batches += 1
            failed_chunks += len(batch)
            print(f"\n Error at batch {i}: {e}")
            continue

        if (i // batch_size) % 100 == 0:
            gc.collect()

    print("\n[4/4] Finalizing...")
    # Count reported by the collection itself.
    final_count = vector_store.count()

    print("\n" + "=" * 60)
    print(" BUILD COMPLETE!")
    print("=" * 60)
    print(f" Documents processed: {doc_count:,}")
    print(f" Chunks indexed: {final_count:,}")
    if failed_batches:
        print(f" Failed batches: {failed_batches:,} ({failed_chunks:,} chunks not indexed)")
    print(f" Location: {KB_DIR}")
    print("\nDownload the knowledge_base_new folder and replace your local one!")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|