#!/usr/bin/env python3
"""
BiomedBERT Embeddings → FAISS Vector Database
Creates searchable vector database from pathology reports
"""
import json
import pickle
from datetime import datetime
from pathlib import Path
from typing import Dict, List

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
class BiomedBERTToVectorDB:
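    """Builds a searchable FAISS vector database from pathology reports.

    Splits each report into ~512-character sentence chunks, re-encodes the
    chunks with BiomedBERT (via sentence-transformers), and stores the
    vectors in FAISS alongside pickled per-chunk entity metadata.
    """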
def __init__(
self,
faiss_index_type: str = "hnsw",
output_dir: str = "biomedbert_vector_db",
embedding_model: str = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
):
        self.output_dir = Path(output_dir)
        # parents=True so nested output paths like "output/vector_db" work.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.faiss_index_type = faiss_index_type
        print("Loading BiomedBERT model...")
        self.model = SentenceTransformer(embedding_model)
        # Read the embedding width off the loaded model (768 for
        # BiomedBERT-base) instead of hardcoding it.
        self.embedding_dim = self.model.get_sentence_embedding_dimension() or 768
        self.index = self._create_faiss_index()
self.chunks = []
self.chunk_id_to_idx = {}
self.stats = {
"total_files": 0,
"total_chunks": 0,
"total_entities": 0,
"entity_types": {},
"files_processed": []
}
def _create_faiss_index(self):
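        """Create the FAISS index: "flat" (exact), "ivf" (clustered), or "hnsw" (graph ANN)."""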
        if self.faiss_index_type == "flat":
            # Exact brute-force search; inner product on normalized vectors
            # is equivalent to cosine similarity.
            index = faiss.IndexFlatIP(self.embedding_dim)
        elif self.faiss_index_type == "ivf":
            # Inverted-file index over 100 clusters; must be trained before
            # use. Pass the metric explicitly so it matches the IP quantizer
            # (IndexIVFFlat otherwise defaults to L2).
            quantizer = faiss.IndexFlatIP(self.embedding_dim)
            index = faiss.IndexIVFFlat(
                quantizer, self.embedding_dim, 100, faiss.METRIC_INNER_PRODUCT
            )
        elif self.faiss_index_type == "hnsw":
            # Graph-based approximate search with 32 links per node.
            index = faiss.IndexHNSWFlat(
                self.embedding_dim, 32, faiss.METRIC_INNER_PRODUCT
            )
        else:
            raise ValueError(f"Unknown index type: {self.faiss_index_type!r}")
return index
def load_embedding_and_metadata(self, base_path: Path):
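        """Load <stem>_embedding.npy and its <stem>_nlp.json metadata, if present."""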
emb_file = base_path.parent / f"{base_path.stem}_embedding.npy"
if not emb_file.exists():
return None, None
embedding = np.load(emb_file)
json_file = base_path.parent / f"{base_path.stem}_nlp.json"
if not json_file.exists():
return embedding, {}
with open(json_file) as f:
metadata = json.load(f)
return embedding, metadata
def create_chunks_from_text(
self,
text: str,
filename: str,
entities: List[Dict],
chunk_size: int = 512
):
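        """Greedily pack sentences into chunks of roughly chunk_size characters.

        Entities are attached to a chunk via case-insensitive substring match
        per sentence, so an entity mentioned in several sentences is counted
        once per mention.
        """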
chunks = []
sentences = text.split(". ")
current_chunk_text = []
current_chunk_entities = []
chunk_id = 0
current_length = 0
for sentence in sentences:
sentence = sentence.strip()
if not sentence:
continue
sentence = sentence + ". "
length = len(sentence)
sentence_entities = [
e for e in entities
if e["text"].lower() in sentence.lower()
]
if current_length + length > chunk_size and current_chunk_text:
chunk_text = "".join(current_chunk_text)
chunks.append({
"chunk_id": chunk_id,
"text": chunk_text,
"filename": filename,
"entities": current_chunk_entities,
"entity_count": len(current_chunk_entities),
"entity_types": list(
set([e["type"] for e in current_chunk_entities])
)
})
chunk_id += 1
current_chunk_text = [sentence]
current_chunk_entities = sentence_entities.copy()
current_length = length
else:
current_chunk_text.append(sentence)
current_chunk_entities.extend(sentence_entities)
current_length += length
if current_chunk_text:
chunk_text = "".join(current_chunk_text)
chunks.append({
"chunk_id": chunk_id,
"text": chunk_text,
"filename": filename,
"entities": current_chunk_entities,
"entity_count": len(current_chunk_entities),
"entity_types": list(
set([e["type"] for e in current_chunk_entities])
)
})
return chunks
def process_file(self, embedding_file: Path, original_text_dir: Path):
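        """Chunk one report's text, attach its entities, and encode the chunks.

        Returns a list of (embedding, chunk) pairs; empty if the original
        .txt file is missing.
        """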
base_name = embedding_file.stem.replace("_embedding", "")
base_path = embedding_file.parent / base_name
        # The stored document-level embedding is not reused; chunks are
        # re-encoded below so each chunk gets its own vector.
        _, metadata = self.load_embedding_and_metadata(base_path)
        metadata = metadata or {}
        filename = metadata.get("filename", base_name)
        entities = metadata.get("entities", [])
self.stats["total_entities"] += len(entities)
for e in entities:
et = e.get("type", "UNKNOWN")
self.stats["entity_types"][et] = (
self.stats["entity_types"].get(et, 0) + 1
)
txt_file = original_text_dir / f"{base_name}.txt"
if not txt_file.exists():
return []
text = txt_file.read_text(encoding="utf-8")
if text.startswith("# GDC Pathology Report"):
lines = text.split("\n")
text = "\n".join([l for l in lines if not l.startswith("#")])
chunks = self.create_chunks_from_text(text, filename, entities)
texts = [c["text"] for c in chunks]
embeddings = self.model.encode(
texts,
batch_size=32,
normalize_embeddings=True,
show_progress_bar=False
)
        return list(zip(embeddings, chunks))
def add_to_faiss(self, embeddings, chunks):
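        """Add chunk embeddings to the FAISS index and record chunk-id → row mappings."""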
        embeddings = np.asarray(embeddings, dtype="float32")
        # IVF indexes must be trained before vectors can be added; training on
        # the full batch here assumes it holds at least ~nlist (100) vectors.
        if self.faiss_index_type == "ivf" and not self.index.is_trained:
            self.index.train(embeddings)
        start_idx = self.index.ntotal
        self.index.add(embeddings)
for i, chunk in enumerate(chunks):
cid = f"{chunk['filename']}_{chunk['chunk_id']}"
self.chunk_id_to_idx[cid] = start_idx + i
self.chunks.append(chunk)
def process_directory(self, biomedbert_output_dir, original_text_dir):
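        """Process every *_embedding.npy file, then build, save, and summarize the index."""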
emb_files = sorted(Path(biomedbert_output_dir).glob("*_embedding.npy"))
if not emb_files:
print("No embeddings found")
return
original_text_dir = Path(original_text_dir)
all_embeddings = []
all_chunks = []
        for emb_file in tqdm(emb_files, desc="Building vector DB"):
            self.stats["total_files"] += 1
            try:
                results = self.process_file(emb_file, original_text_dir)
                for emb, chunk in results:
                    all_embeddings.append(emb)
                    all_chunks.append(chunk)
                self.stats["files_processed"].append(
                    emb_file.stem.replace("_embedding", "")
                )
            except Exception as e:
                print(f"Error processing {emb_file.name}: {e}")
if not all_embeddings:
print("No embeddings processed")
return
self.stats["total_chunks"] = len(all_chunks)
embeddings_matrix = np.vstack(all_embeddings)
self.add_to_faiss(embeddings_matrix, all_chunks)
self.save()
self.print_summary()
def save(self):
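        """Write faiss.index, metadata.pkl, and stats.json to the output directory."""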
index_file = self.output_dir / "faiss.index"
faiss.write_index(self.index, str(index_file))
metadata_file = self.output_dir / "metadata.pkl"
with open(metadata_file, "wb") as f:
pickle.dump(
{
"chunks": self.chunks,
"chunk_id_to_idx": self.chunk_id_to_idx,
"embedding_dim": self.embedding_dim,
"index_type": self.faiss_index_type,
"model": "BiomedBERT"
},
f,
)
stats_file = self.output_dir / "stats.json"
with open(stats_file, "w") as f:
json.dump(
{
**self.stats,
"timestamp": datetime.now().isoformat(),
"embedding_dim": self.embedding_dim,
"index_type": self.faiss_index_type
},
f,
indent=2,
)
def print_summary(self):
print("\n==============================")
print("VECTOR DATABASE SUMMARY")
print("==============================")
print("Files processed:", self.stats["total_files"])
print("Chunks:", self.stats["total_chunks"])
print("Entities:", self.stats["total_entities"])
print("Vectors:", self.index.ntotal)
print("Index type:", self.faiss_index_type)
def main():
print("BiomedBERT → FAISS Vector DB")
biomedbert_output = "output/biomedbert_output"
original_texts = "output/extracted_text"
vector_db_output = "output/vector_db"
pipeline = BiomedBERTToVectorDB(
faiss_index_type="hnsw",
output_dir=vector_db_output
)
pipeline.process_directory(
biomedbert_output_dir=biomedbert_output,
original_text_dir=original_texts
)
print("Vector database created")
if __name__ == "__main__":
main()