# BioRAG / src/bio_rag/knowledge_base.py
# Committed by aseelflihan
# Commit 2a2c039: "Deploy Bio-RAG"
from __future__ import annotations
from pathlib import Path
import os
import warnings
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from .config import BioRAGConfig
from .data_loader import PubMedQASample
class KnowledgeBaseBuilder:
    """Build, persist, and reload the FAISS vector store used by Bio-RAG.

    Each ``PubMedQASample`` becomes one LangChain ``Document`` whose text is the
    sample's ``context`` and whose metadata carries the sample's id, question,
    answer and bibliographic fields. Documents are embedded with the HuggingFace
    model named by ``config.embedding_model`` and indexed with FAISS; the index
    is stored under ``config.index_path``.
    """

    def __init__(
        self,
        config: BioRAGConfig,
        *,
        batch_size: int = 32,
        show_progress: bool = True,
    ) -> None:
        """Create the builder and its embedding model.

        Args:
            config: Project configuration; ``embedding_model`` and
                ``index_path`` are read from it.
            batch_size: Batch size forwarded to the embedding model via
                ``encode_kwargs["batch_size"]``.
            show_progress: Forwarded to ``HuggingFaceEmbeddings`` to toggle its
                progress display while encoding.
        """
        self.config = config
        self.embeddings = HuggingFaceEmbeddings(
            model_name=config.embedding_model,
            show_progress=show_progress,
            encode_kwargs={"batch_size": batch_size},
        )

    @staticmethod
    def _to_document(sample: PubMedQASample) -> Document:
        """Wrap one sample as a Document: context as text, other fields as metadata."""
        return Document(
            page_content=sample.context,
            metadata={
                "qid": sample.qid,
                "question": sample.question,
                "answer": sample.answer,
                "authors": sample.authors,
                "year": sample.year,
                "journal": sample.journal,
                "title": sample.title,
            },
        )

    def build(self, samples: list[PubMedQASample]) -> FAISS:
        """Embed *samples* and return a new in-memory FAISS index.

        Args:
            samples: Samples to index, one document per sample.

        Returns:
            A FAISS vector store containing one entry per sample.

        Raises:
            ValueError: If *samples* is empty. A FAISS index cannot be created
                from zero documents, so fail early with a clear message.
        """
        documents = [self._to_document(sample) for sample in samples]
        # Check the materialized list rather than `samples` so that generators
        # and other iterables are handled correctly too.
        if not documents:
            raise ValueError("Cannot build a FAISS knowledge base from zero samples.")
        return FAISS.from_documents(documents, self.embeddings)

    def save(self, vectorstore: FAISS) -> None:
        """Write *vectorstore* to ``config.index_path``, creating the directory if needed."""
        path = self.config.index_path
        path.mkdir(parents=True, exist_ok=True)
        vectorstore.save_local(str(path))

    def load_or_build(self, samples: list[PubMedQASample]) -> FAISS:
        """Return the index stored at ``config.index_path``, building it if absent.

        If the index files are present they are loaded and *samples* is
        ignored. Otherwise an index is built from *samples*, saved, and
        returned.

        NOTE(review): the loaded index is not checked against the current
        ``config.embedding_model``. If the model changes, delete the stored
        index so it is rebuilt with matching embeddings.
        """
        path = self.config.index_path
        if _looks_like_faiss_index(path):
            # SECURITY: load_local unpickles index.pkl, which can execute
            # arbitrary code. This is acceptable only because the directory is
            # expected to have been written by `save` above. Never point
            # `index_path` at an index obtained from an untrusted source.
            return FAISS.load_local(
                str(path),
                self.embeddings,
                allow_dangerous_deserialization=True,
            )
        vectorstore = self.build(samples)
        self.save(vectorstore)
        return vectorstore
def _looks_like_faiss_index(path: Path) -> bool:
return path.exists() and (path / "index.faiss").exists() and (path / "index.pkl").exists()