"""
BiomedBERT Embeddings → FAISS Vector Database
Creates searchable vector database from pathology reports
"""
|
|
import json
import pickle
from datetime import datetime
from pathlib import Path
from typing import List, Dict

import numpy as np
from tqdm import tqdm

import faiss
from sentence_transformers import SentenceTransformer
|
|
|
|
class BiomedBERTToVectorDB:
    """Build a searchable FAISS vector database from pathology reports.

    For every upstream ``<name>_embedding.npy`` file, the matching original
    ``<name>.txt`` report is re-read, split into sentence-based chunks,
    re-embedded with BiomedBERT, and added to a FAISS index.  The index,
    per-chunk metadata (pickle) and run statistics (JSON) are persisted
    to ``output_dir``.
    """

    def __init__(
        self,
        faiss_index_type: str = "hnsw",
        output_dir: str = "biomedbert_vector_db",
        embedding_model: str = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
    ):
        """
        Args:
            faiss_index_type: FAISS index flavor: "flat", "ivf" or "hnsw".
            output_dir: Directory for the index, metadata and stats files.
            embedding_model: SentenceTransformer model name or local path.
        """
        self.output_dir = Path(output_dir)
        # parents=True: nested paths such as "output/vector_db" must not
        # fail when the parent directory does not exist yet (bug fix).
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.embedding_dim = 768  # hidden size of BiomedBERT-base
        self.faiss_index_type = faiss_index_type

        print("Loading BiomedBERT model...")
        self.model = SentenceTransformer(embedding_model)

        self.index = self._create_faiss_index()

        # chunks[i] is the metadata for FAISS row i; chunk_id_to_idx maps
        # "<filename>_<chunk_id>" back to that row.
        self.chunks = []
        self.chunk_id_to_idx = {}

        self.stats = {
            "total_files": 0,
            "total_chunks": 0,
            "total_entities": 0,
            "entity_types": {},
            "files_processed": []
        }

    def _create_faiss_index(self):
        """Create the FAISS index selected by ``self.faiss_index_type``.

        All variants use inner-product similarity, which equals cosine
        similarity because chunk embeddings are normalized at encode time.

        Returns:
            A FAISS index ready for ``add()`` (after training, for "ivf").

        Raises:
            ValueError: If the configured index type is unknown.
        """
        if self.faiss_index_type == "flat":
            return faiss.IndexFlatIP(self.embedding_dim)

        if self.faiss_index_type == "ivf":
            quantizer = faiss.IndexFlatIP(self.embedding_dim)
            # 100 inverted lists; the index must be trained before adding.
            return faiss.IndexIVFFlat(quantizer, self.embedding_dim, 100)

        if self.faiss_index_type == "hnsw":
            # 32 is the HNSW graph connectivity parameter (M).
            return faiss.IndexHNSWFlat(
                self.embedding_dim,
                32,
                faiss.METRIC_INNER_PRODUCT
            )

        # Include the offending value so misconfiguration is easy to spot.
        raise ValueError(f"Unknown index type: {self.faiss_index_type!r}")

    def load_embedding_and_metadata(self, base_path: Path):
        """Load ``<base>_embedding.npy`` and ``<base>_nlp.json`` if present.

        Returns:
            (embedding, metadata): ``(None, None)`` when the embedding file
            is missing; ``(embedding, {})`` when only the JSON is missing.
        """
        emb_file = base_path.parent / f"{base_path.stem}_embedding.npy"
        if not emb_file.exists():
            return None, None

        embedding = np.load(emb_file)

        json_file = base_path.parent / f"{base_path.stem}_nlp.json"
        if not json_file.exists():
            return embedding, {}

        # Explicit encoding: NLP output may contain non-ASCII characters.
        with open(json_file, encoding="utf-8") as f:
            metadata = json.load(f)

        return embedding, metadata

    @staticmethod
    def _build_chunk(chunk_id: int, parts: List[str], entities: List[Dict],
                     filename: str) -> Dict:
        """Assemble one chunk record from accumulated sentence fragments."""
        return {
            "chunk_id": chunk_id,
            "text": "".join(parts),
            "filename": filename,
            "entities": entities,
            "entity_count": len(entities),
            # sorted() instead of list(set(...)): deterministic output.
            "entity_types": sorted({e["type"] for e in entities})
        }

    def create_chunks_from_text(
        self,
        text: str,
        filename: str,
        entities: List[Dict],
        chunk_size: int = 512
    ):
        """Split ``text`` into ~``chunk_size``-character sentence chunks.

        Sentences are greedily packed until adding the next one would exceed
        ``chunk_size``; each NER entity is attached to every chunk sentence
        that contains it (case-insensitive substring match).

        Returns:
            List of chunk dicts (see ``_build_chunk``).
        """
        chunks = []
        parts = []            # sentence fragments of the chunk in progress
        chunk_entities = []
        chunk_id = 0
        length = 0

        for sentence in text.split(". "):
            sentence = sentence.strip()
            if not sentence:
                continue

            # Restore the delimiter consumed by split(); avoid doubling the
            # final sentence's existing trailing period (bug fix).
            if not sentence.endswith("."):
                sentence += "."
            sentence += " "
            sent_len = len(sentence)

            sent_entities = [
                e for e in entities
                if e["text"].lower() in sentence.lower()
            ]

            if length + sent_len > chunk_size and parts:
                # Current chunk is full: flush it and start a new one.
                chunks.append(
                    self._build_chunk(chunk_id, parts, chunk_entities, filename)
                )
                chunk_id += 1
                parts = [sentence]
                chunk_entities = sent_entities.copy()
                length = sent_len
            else:
                parts.append(sentence)
                chunk_entities.extend(sent_entities)
                length += sent_len

        if parts:  # flush the trailing partial chunk
            chunks.append(
                self._build_chunk(chunk_id, parts, chunk_entities, filename)
            )

        return chunks

    def process_file(self, embedding_file: Path, original_text_dir: Path):
        """Chunk and re-embed the report behind one ``*_embedding.npy`` file.

        Also accumulates entity statistics into ``self.stats``.

        Returns:
            List of ``(embedding, chunk_dict)`` pairs; empty when the
            original text file is missing or yields no chunks.
        """
        base_name = embedding_file.stem.replace("_embedding", "")
        base_path = embedding_file.parent / base_name
        # The document-level embedding is unused here; only the metadata is.
        _, metadata = self.load_embedding_and_metadata(base_path)

        filename = metadata.get("filename", base_name)
        entities = metadata.get("entities", [])

        self.stats["total_entities"] += len(entities)
        for e in entities:
            et = e.get("type", "UNKNOWN")
            self.stats["entity_types"][et] = (
                self.stats["entity_types"].get(et, 0) + 1
            )

        txt_file = original_text_dir / f"{base_name}.txt"
        if not txt_file.exists():
            return []

        text = txt_file.read_text(encoding="utf-8")

        # Strip GDC header comment lines before chunking.
        if text.startswith("# GDC Pathology Report"):
            lines = text.split("\n")
            text = "\n".join([l for l in lines if not l.startswith("#")])

        chunks = self.create_chunks_from_text(text, filename, entities)
        if not chunks:
            # Guard: do not call model.encode([]) on an empty report (bug fix).
            return []

        embeddings = self.model.encode(
            [c["text"] for c in chunks],
            batch_size=32,
            normalize_embeddings=True,  # makes inner product == cosine
            show_progress_bar=False
        )

        return list(zip(embeddings, chunks))

    def add_to_faiss(self, embeddings, chunks):
        """Add chunk embeddings to the index and record their metadata."""
        embeddings = embeddings.astype("float32")  # FAISS requires float32

        # IVF indexes must be trained once before vectors can be added.
        if self.faiss_index_type == "ivf" and not self.index.is_trained:
            self.index.train(embeddings)

        start_idx = self.index.ntotal
        self.index.add(embeddings)

        for i, chunk in enumerate(chunks):
            cid = f"{chunk['filename']}_{chunk['chunk_id']}"
            self.chunk_id_to_idx[cid] = start_idx + i
            self.chunks.append(chunk)

    def process_directory(self, biomedbert_output_dir, original_text_dir):
        """Process every ``*_embedding.npy`` under ``biomedbert_output_dir``.

        Builds the index, persists everything to ``self.output_dir`` and
        prints a summary.  Per-file failures are reported and skipped so one
        bad report cannot abort the whole batch.
        """
        emb_files = sorted(Path(biomedbert_output_dir).glob("*_embedding.npy"))
        if not emb_files:
            print("No embeddings found")
            return

        original_text_dir = Path(original_text_dir)

        all_embeddings = []
        all_chunks = []

        for emb_file in tqdm(emb_files):
            self.stats["total_files"] += 1
            try:
                for emb, chunk in self.process_file(emb_file, original_text_dir):
                    all_embeddings.append(emb)
                    all_chunks.append(chunk)

                self.stats["files_processed"].append(
                    emb_file.stem.replace("_embedding", "")
                )
            except Exception as e:
                # Best-effort batch job: say which file failed, keep going.
                print(f"Error processing {emb_file.name}: {e}")

        if not all_embeddings:
            print("No embeddings processed")
            return

        self.stats["total_chunks"] = len(all_chunks)
        self.add_to_faiss(np.vstack(all_embeddings), all_chunks)
        self.save()
        self.print_summary()

    def save(self):
        """Persist the FAISS index, chunk metadata (pickle) and stats (JSON)."""
        faiss.write_index(self.index, str(self.output_dir / "faiss.index"))

        with open(self.output_dir / "metadata.pkl", "wb") as f:
            pickle.dump(
                {
                    "chunks": self.chunks,
                    "chunk_id_to_idx": self.chunk_id_to_idx,
                    "embedding_dim": self.embedding_dim,
                    "index_type": self.faiss_index_type,
                    "model": "BiomedBERT"
                },
                f,
            )

        with open(self.output_dir / "stats.json", "w", encoding="utf-8") as f:
            json.dump(
                {
                    **self.stats,
                    "timestamp": datetime.now().isoformat(),
                    "embedding_dim": self.embedding_dim,
                    "index_type": self.faiss_index_type
                },
                f,
                indent=2,
            )

    def print_summary(self):
        """Print index and corpus statistics to stdout."""
        print("\n==============================")
        print("VECTOR DATABASE SUMMARY")
        print("==============================")
        print("Files processed:", self.stats["total_files"])
        print("Chunks:", self.stats["total_chunks"])
        print("Entities:", self.stats["total_entities"])
        print("Vectors:", self.index.ntotal)
        print("Index type:", self.faiss_index_type)
|
|
|
def main():
    """Entry point: build the vector DB from the default pipeline layout."""
    print("BiomedBERT → FAISS Vector DB")

    # Default directory layout produced by the upstream pipeline stages.
    source_dirs = {
        "biomedbert_output_dir": "output/biomedbert_output",
        "original_text_dir": "output/extracted_text",
    }

    builder = BiomedBERTToVectorDB(
        faiss_index_type="hnsw",
        output_dir="output/vector_db",
    )
    builder.process_directory(**source_dirs)

    print("Vector database created")


if __name__ == "__main__":
    main()