MedSpace / scripts /build_knowledge_base_colab.py
kbsss's picture
Upload folder using huggingface_hub
f373e2b verified
Raw
History Blame Contribute Delete
11.2 kB
#!/usr/bin/env python3
"""
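Colab-compatible script to build the medical knowledge base.
Intended to run in Google Colab, which provides a stable environment.
Usage:
1. Upload your final_project folder to Colab or mount Google Drive, then set
   PROJECT_ROOT below to that folder's location.
2. Run: !pip install chromadb sentence-transformers pandas pyarrow tqdm
3. Run this script. The index is written to data/knowledge_base_new under
   PROJECT_ROOT.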
"""
import sys
import gc
from pathlib import Path
# Add the project folder to sys.path so its modules are importable
# (NOTE(review): no project-module imports are visible in this script).
# main() also resolves its input directory (data/raw) and its output
# directory (data/knowledge_base_new) relative to PROJECT_ROOT.
PROJECT_ROOT = Path("/content/final_project")  # Change to your path
sys.path.insert(0, str(PROJECT_ROOT))
import pandas as pd
from tqdm import tqdm
import numpy as np
from typing import List, Dict, Generator
class SimpleEmbedder:
    """Thin wrapper around a sentence-transformers model.

    Produces unit-length (normalized) embeddings as a NumPy array, one row
    per input text.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        # Imported here rather than at module level, so the dependency is only
        # needed once an embedder is actually constructed.
        from sentence_transformers import SentenceTransformer

        model = SentenceTransformer(model_name)
        self.model = model
        self.dimension = model.get_sentence_embedding_dimension()
        print(f"Loaded embedding model. Dimension: {self.dimension}")

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode *texts* and return their normalized embeddings.

        Args:
            texts: Strings to embed.
            batch_size: Number of texts the model encodes per forward pass.

        Returns:
            Array with one embedding row per entry in *texts*.
        """
        encode_options = dict(
            batch_size=batch_size,
            show_progress_bar=False,
            normalize_embeddings=True,
        )
        vectors = self.model.encode(texts, **encode_options)
        return np.array(vectors)
class SimpleVectorStore:
    """Minimal wrapper around a persistent ChromaDB collection (cosine space)."""

    def __init__(self, collection_name: str, persist_directory: str):
        """Open (or create) *collection_name* stored under *persist_directory*.

        The directory is created if needed. Because the collection is fetched
        with get_or_create, an existing collection and its documents are
        reused rather than replaced.
        """
        import chromadb

        self.persist_directory = Path(persist_directory)
        self.persist_directory.mkdir(parents=True, exist_ok=True)
        self.client = chromadb.PersistentClient(path=str(self.persist_directory))
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"},
        )
        print(f"Vector store initialized. Documents: {self.collection.count()}")

    @staticmethod
    def _clean_metadata(meta: Dict) -> Dict:
        """Coerce metadata values to the scalar types Chroma accepts.

        str/int/float/bool values are kept as-is. NumPy scalars (e.g.
        ``np.int64``, ``np.bool_``) are first unwrapped to their Python
        equivalents, so they keep their type instead of being stringified.
        ``None`` becomes ``""`` and any other value is converted with ``str``.
        """
        clean = {}
        for key, value in meta.items():
            if isinstance(value, np.generic):
                value = value.item()
            if isinstance(value, (str, int, float, bool)):
                clean[key] = value
            elif value is None:
                clean[key] = ""
            else:
                clean[key] = str(value)
        return clean

    def add_documents(self, documents: List[str], embeddings: List[List[float]],
                      metadatas: List[Dict], ids: List[str]):
        """Add pre-embedded documents to the collection.

        The four lists are parallel, with one entry per document. Metadata is
        passed through _clean_metadata before insertion.
        """
        self.collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=[self._clean_metadata(meta) for meta in metadatas],
        )

    def count(self):
        """Return the number of documents currently in the collection."""
        return self.collection.count()
class TextChunk:
    """One chunk of a source document, plus the metadata stored alongside it."""

    def __init__(self, content, source, chunk_id, total_chunks, metadata):
        # chunk_id is this chunk's position within its document and
        # total_chunks is how many chunks that document was split into
        # (main() numbers chunks starting from 1).
        self.content, self.source = content, source
        self.chunk_id, self.total_chunks = chunk_id, total_chunks
        self.metadata = metadata
def chunk_text(text: str, chunk_size: int = 512, overlap: int = 50) -> List[str]:
    """Split *text* into overlapping windows of whitespace-separated words.

    Args:
        text: Text to split.
        chunk_size: Maximum number of words per chunk.
        overlap: Number of words each chunk shares with the previous one.

    Returns:
        The chunks, in order. If *text* contains no words, ``[text]`` is
        returned unchanged so callers always get at least one chunk.

    Raises:
        ValueError: If ``chunk_size < 1``, or if ``overlap`` is outside
            ``[0, chunk_size)``. With such values the window would either stop
            advancing (an infinite loop) or skip words between chunks.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must be in [0, chunk_size={chunk_size}), got {overlap}"
        )

    words = text.split()
    if not words:
        return [text]

    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        end = start + chunk_size
        chunks.append(" ".join(words[start:end]))
        # The window that reaches the last word is final; continuing would
        # only emit chunks fully contained in the overlap of this one.
        if end >= len(words):
            break
    return chunks
def _as_text(value) -> str:
    """Return *value* as a string, mapping missing values (None/NaN/NA) to ''."""
    if value is None:
        return ""
    # pd.isna is element-wise on lists/arrays, so only apply it to scalars.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ""
    return str(value)


def _raw(row, *columns, default=None):
    """Return the value of the first of *columns* present in *row*, else *default*.

    Like nested ``row.get(a, row.get(b, ...))``: the first column that exists
    wins, even when its value is empty or missing.
    """
    for column in columns:
        if column in row:
            return row[column]
    return default


def _field(row, *columns) -> str:
    """Return the text of the first of *columns* present in *row* ('' if missing/NaN)."""
    return _as_text(_raw(row, *columns))


def _medqa_qa(row):
    """Extract ``(question, answer)`` from a MedQA-USMLE row.

    When ``options`` is a list/array and the answer index is an in-range
    integer, the selected option's text is the answer. Otherwise the
    ``answer`` column is used.
    """
    question = _field(row, "question", "sent1")
    answer = _field(row, "answer")
    options = _raw(row, "options")
    answer_idx = _raw(row, "answer_idx", "label", default=-1)
    # Parquet list columns come back as NumPy arrays, whose truth value is
    # ambiguous, so check the container type explicitly. NumPy integers are
    # accepted as indices (they are not ``int`` instances).
    is_index = isinstance(answer_idx, (int, np.integer)) and not isinstance(answer_idx, bool)
    if (
        isinstance(options, (list, tuple, np.ndarray))
        and is_index
        and 0 <= answer_idx < len(options)
    ):
        answer = _as_text(options[int(answer_idx)])
    return question, answer


def _meadow_qa(row):
    """Extract ``(question, answer)`` from a Medical Meadow instruction row.

    The question is ``instruction`` and ``input`` joined by a blank line, or
    whichever of the two is non-empty. The answer is ``output``.
    """
    instruction = _field(row, "instruction")
    input_text = _field(row, "input")
    if instruction and input_text:
        question = f"{instruction}\n\n{input_text}"
    else:
        question = instruction or input_text
    return question, _field(row, "output")


def _load_source(path: Path, label: str, source: str, extract) -> Generator[Dict, None, None]:
    """Yield QA dicts from one parquet file, skipping rows lacking a question or answer.

    *extract* maps a row to a ``(question, answer)`` pair of strings. A
    missing file yields nothing. After the file is exhausted, prints how many
    rows were kept.
    """
    if not path.exists():
        return
    df = pd.read_parquet(path)
    kept = 0
    for _, row in df.iterrows():
        question, answer = extract(row)
        if question and answer:
            kept += 1
            yield {"question": question, "answer": answer, "source": source}
    print(f" Loaded {label}: {kept:,} of {len(df):,} rows")


def load_all_qa_pairs(data_dir: Path) -> Generator[Dict, None, None]:
    """Load all QA pairs from the parquet files under *data_dir*.

    Yields dicts with string ``question``, ``answer`` and ``source`` keys.
    Missing values are treated as empty, so rows without a question or an
    answer are skipped instead of being emitted with the text "nan". Files
    that do not exist are skipped. Medical Meadow files are read in sorted
    order so that the output order is reproducible.
    """
    mediqa = data_dir / "mediqa"
    yield from _load_source(
        mediqa / "medquad.parquet", "MedQuAD", "MedQuAD",
        lambda row: (_field(row, "Question", "question"), _field(row, "Answer", "answer")),
    )
    yield from _load_source(
        data_dir / "pubmed" / "pubmedqa_labeled.parquet", "PubMedQA", "PubMedQA",
        lambda row: (_field(row, "question"), _field(row, "long_answer")),
    )
    yield from _load_source(
        mediqa / "medmcqa_train.parquet", "MedMCQA", "MedMCQA",
        lambda row: (_field(row, "question"), _field(row, "exp")),
    )
    yield from _load_source(
        mediqa / "healthcare_magic.parquet", "HealthCareMagic", "HealthCareMagic",
        lambda row: (_field(row, "input", "instruction"), _field(row, "output")),
    )
    for filename in ("medqa_usmle_train.parquet", "medqa_usmle_test.parquet"):
        yield from _load_source(
            data_dir / "medqa" / filename, filename, "MedQA-USMLE", _medqa_qa,
        )
    for filename in ("chatdoctor_icliniq.parquet", "chatdoctor_healthcaremagic.parquet"):
        yield from _load_source(
            data_dir / "chatdoctor" / filename, filename, "ChatDoctor",
            lambda row: (
                _field(row, "input", "instruction", "question"),
                _field(row, "output", "answer"),
            ),
        )
    meadow_dir = data_dir / "medical_meadow"
    if meadow_dir.exists():
        for parquet_file in sorted(meadow_dir.glob("*.parquet")):
            yield from _load_source(parquet_file, parquet_file.name, "MedicalMeadow", _meadow_qa)
def main():
    """Build the medical knowledge base under PROJECT_ROOT/data.

    Loads every QA pair from ``data/raw``, chunks each "Question/Answer" text,
    embeds the chunks in batches and adds them to the ``medical_knowledge``
    Chroma collection in ``data/knowledge_base_new``. A batch that fails to
    embed or index is skipped, and the skip is reported in the final summary
    instead of aborting the whole build.
    """
    print("\n" + "=" * 60)
    print(" BUILDING MEDICAL KNOWLEDGE BASE (Colab Version)")
    print("=" * 60)
    DATA_DIR = PROJECT_ROOT / "data" / "raw"
    KB_DIR = PROJECT_ROOT / "data" / "knowledge_base_new"

    # Initialize components
    print("\n[1/4] Initializing components...")
    embedder = SimpleEmbedder("all-MiniLM-L6-v2")
    vector_store = SimpleVectorStore(
        collection_name="medical_knowledge",
        persist_directory=str(KB_DIR),
    )
    existing = vector_store.count()
    if existing:
        # Chunk IDs below are positional ("chunk_0", "chunk_1", ...), so a
        # second run into the same directory reuses IDs that are already stored.
        print(
            f"\n WARNING: collection already contains {existing:,} documents; "
            f"chunk IDs from this run will collide with them. "
            f"Delete {KB_DIR} first for a clean rebuild."
        )

    # Process documents
    print("\n[2/4] Loading and processing documents...")
    all_chunks = []
    doc_count = 0
    for qa in tqdm(load_all_qa_pairs(DATA_DIR), desc="Processing"):
        content = f"Question: {qa['question']}\n\nAnswer: {qa['answer']}"
        # Skip very short content
        if len(content.strip()) < 50:
            continue
        # NOTE(review): chunk_size counts whitespace-separated words, but the
        # embedding model truncates its input to model.max_seq_length tokens
        # (256 per the all-MiniLM-L6-v2 model card; confirm). Text beyond that
        # limit in a 512-word chunk is stored but not reflected in its embedding.
        chunks = chunk_text(content, chunk_size=512, overlap=50)
        for chunk_number, chunk in enumerate(chunks, start=1):
            all_chunks.append(TextChunk(
                content=chunk,
                source=qa["source"],
                chunk_id=chunk_number,
                total_chunks=len(chunks),
                metadata={"type": "qa_pair"},
            ))
        doc_count += 1
        # Periodic garbage collection
        if doc_count % 50000 == 0:
            gc.collect()
            print(f" Processed {doc_count:,} documents, {len(all_chunks):,} chunks...")
    print(f"\n Total documents: {doc_count:,}")
    print(f" Total chunks: {len(all_chunks):,}")

    # Generate embeddings and index
    print("\n[3/4] Generating embeddings and indexing...")
    batch_size = 500
    total_chunks = len(all_chunks)
    failed_batches = []  # (batch_start, batch_length) of every batch that raised
    for batch_start in tqdm(range(0, total_chunks, batch_size), desc="Indexing"):
        batch = all_chunks[batch_start : batch_start + batch_size]
        texts = [chunk.content for chunk in batch]
        metadatas = [
            {
                "source": chunk.source,
                "chunk_id": chunk.chunk_id,
                "total_chunks": chunk.total_chunks,
                **chunk.metadata,
            }
            for chunk in batch
        ]
        # IDs are positions in all_chunks, so they are unique within one run.
        ids = [f"chunk_{batch_start + offset}" for offset in range(len(batch))]
        try:
            embeddings = embedder.embed_documents(texts, batch_size=32)
            vector_store.add_documents(
                documents=texts,
                embeddings=embeddings.tolist(),
                metadatas=metadatas,
                ids=ids,
            )
        except Exception as e:
            # Best-effort: keep indexing the remaining batches, but record
            # this one so the summary reports the gap.
            failed_batches.append((batch_start, len(batch)))
            print(f"\n Error at batch {batch_start}: {e}")
            continue
        if (batch_start // batch_size) % 100 == 0:
            gc.collect()

    # Done
    print("\n[4/4] Finalizing...")
    final_count = vector_store.count()
    print("\n" + "=" * 60)
    print(" BUILD FINISHED WITH ERRORS" if failed_batches else " BUILD COMPLETE!")
    print("=" * 60)
    print(f" Documents processed: {doc_count:,}")
    print(f" Chunks indexed: {final_count:,}")
    if failed_batches:
        missed = sum(length for _, length in failed_batches)
        print(f" Failed batches: {len(failed_batches):,} ({missed:,} chunks not indexed)")
    print(f" Location: {KB_DIR}")
    print("\nDownload the knowledge_base_new folder and replace your local one!")
# Run the build only when executed as a script, not when imported.
if __name__ == "__main__":
    main()