""" FR-03 + FR-03b: Embedding Generation & FAISS Index Construction =============================================================== Model : dmis-lab/biobert-v1.1 (768-dim dense vectors, SentenceTransformer) Index : FAISS IndexFlatIP with L2-normalized vectors (= cosine similarity) Metadata: Parallel dict[int → dict] saved as pickle alongside index Usage: python src/pipeline/embedder.py """ from __future__ import annotations import sys import os from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent)) import json import logging import pickle import faiss import numpy as np import yaml import src # noqa: F401 — logging setup logger = logging.getLogger(__name__) def _load_config() -> dict: with open("config.yaml", "r", encoding="utf-8") as f: return yaml.safe_load(f) def load_chunks(chunks_path: str = "data/processed/chunks.jsonl") -> list[dict]: """Load chunks from JSONL produced by ingest.py.""" path = Path(chunks_path) if not path.exists(): raise FileNotFoundError( f"Chunks file not found: '{chunks_path}'. " "Run python src/pipeline/ingest.py first." ) chunks = [] with open(path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if line: chunks.append(json.loads(line)) logger.info("Loaded %d chunks from %s", len(chunks), chunks_path) return chunks def encode_texts( texts: list[str], model_name: str, batch_size: int = 32, ) -> np.ndarray: """ Encode texts using BioBERT via SentenceTransformer. Returns L2-normalized float32 array of shape (N, 768). """ from sentence_transformers import SentenceTransformer logger.info("Loading embedding model: %s", model_name) model = SentenceTransformer(model_name) logger.info("Encoding %d texts (batch_size=%d)...", len(texts), batch_size) embeddings: np.ndarray = model.encode( texts, batch_size=batch_size, show_progress_bar=True, normalize_embeddings=True, # L2-normalise → cosine via IndexFlatIP convert_to_numpy=True, ) logger.info("Encoded shape: %s", embeddings.shape) return embeddings.astype(np.float32) def build_faiss_index(embeddings: np.ndarray) -> faiss.IndexFlatIP: """ Build FAISS IndexFlatIP. Because vectors are L2-normalised, inner product == cosine similarity. """ dim = embeddings.shape[1] # 768 for BioBERT index = faiss.IndexFlatIP(dim) index.add(embeddings) logger.info( "FAISS IndexFlatIP built: %d vectors, dim=%d", index.ntotal, dim ) return index def build_metadata_store(chunks: list[dict]) -> dict[int, dict]: """ Build parallel metadata dict → key = FAISS integer index (0-based). Stores the full FR-03b schema plus chunk_text for retrieval. """ store: dict[int, dict] = {} for i, chunk in enumerate(chunks): store[i] = { "chunk_id": chunk["chunk_id"], "doc_id": chunk["doc_id"], "source": chunk["source"], "title": chunk["title"], "pub_type": chunk["pub_type"], "pub_year": chunk["pub_year"], "journal": chunk["journal"], "chunk_index": chunk["chunk_index"], "total_chunks": chunk["total_chunks"], "chunk_text": chunk["chunk_text"], # kept for retrieval } return store def save_artifacts( index: faiss.IndexFlatIP, metadata_store: dict, config: dict, ) -> None: """Persist FAISS index and metadata pickle to disk.""" index_path = Path(config["retrieval"]["index_path"]) meta_path = Path(config["retrieval"]["metadata_path"]) index_path.parent.mkdir(parents=True, exist_ok=True) meta_path.parent.mkdir(parents=True, exist_ok=True) faiss.write_index(index, str(index_path)) logger.info("FAISS index written to %s", index_path) with open(meta_path, "wb") as f: pickle.dump(metadata_store, f, protocol=pickle.HIGHEST_PROTOCOL) logger.info( "Metadata store written to %s (%d entries)", meta_path, len(metadata_store) ) def main() -> None: config = _load_config() chunks = load_chunks("data/processed/chunks.jsonl") if not chunks: logger.error("No chunks to embed. Run python src/pipeline/ingest.py first.") sys.exit(1) texts = [c["chunk_text"] for c in chunks] model_name = config["retrieval"]["embedding_model"] embeddings = encode_texts(texts, model_name, batch_size=32) index = build_faiss_index(embeddings) metadata_store = build_metadata_store(chunks) save_artifacts(index, metadata_store, config) logger.info( "Embedding complete. Index has %d vectors. " "Next: python scripts/warmup.py && streamlit run src/dashboard/app.py", index.ntotal, ) if __name__ == "__main__": main()