Spaces:
Running
Running
File size: 2,031 Bytes
2a2c039 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 | from __future__ import annotations
from pathlib import Path
import os
import warnings
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from .config import BioRAGConfig
from .data_loader import PubMedQASample
class KnowledgeBaseBuilder:
    """Build, persist, and reload a FAISS vector store of PubMedQA samples.

    One ``HuggingFaceEmbeddings`` instance is created from
    ``config.embedding_model`` and reused both for embedding documents during
    a build and for attaching to an index loaded from disk.
    """

    def __init__(self, config: BioRAGConfig) -> None:
        """Keep *config* and instantiate the embedding model it names.

        Args:
            config: Supplies ``embedding_model`` (read here) and
                ``index_path`` (read by ``save`` and ``load_or_build``).
        """
        self.config = config
        self.embeddings = HuggingFaceEmbeddings(
            model_name=config.embedding_model,
            show_progress=True,
            encode_kwargs={"batch_size": 32},
        )

    def build(self, samples: list[PubMedQASample]) -> FAISS:
        """Embed every sample's context and return an in-memory FAISS store.

        Each sample becomes one ``Document`` whose text is ``sample.context``
        and whose metadata carries the sample's qid, question, answer,
        authors, year, journal, and title.

        Args:
            samples: Samples to index; must contain at least one item.

        Returns:
            A FAISS vector store built from the samples.

        Raises:
            ValueError: If *samples* is empty.
        """
        # Fail early with a clear message: an empty document list gives the
        # vector store no embedding to size the index from, and the failure
        # would otherwise surface as an opaque error inside the library.
        if not samples:
            raise ValueError(
                "Cannot build a FAISS index from an empty list of samples."
            )

        documents = [
            Document(
                page_content=sample.context,
                metadata={
                    "qid": sample.qid,
                    "question": sample.question,
                    "answer": sample.answer,
                    "authors": sample.authors,
                    "year": sample.year,
                    "journal": sample.journal,
                    "title": sample.title,
                },
            )
            for sample in samples
        ]
        return FAISS.from_documents(documents, self.embeddings)

    def save(self, vectorstore: FAISS) -> None:
        """Write *vectorstore* to ``config.index_path``, creating the directory.

        Args:
            vectorstore: The FAISS store to persist.
        """
        index_dir = self.config.index_path
        index_dir.mkdir(parents=True, exist_ok=True)
        vectorstore.save_local(str(index_dir))

    def load_or_build(self, samples: list[PubMedQASample]) -> FAISS:
        """Return the saved index if one exists, otherwise build and save it.

        When ``config.index_path`` already holds both index files, the store
        is loaded from disk and *samples* is not used. Otherwise the store is
        built from *samples*, saved to ``config.index_path``, and returned.

        Args:
            samples: Samples to index when no saved index is found.

        Returns:
            The loaded or freshly built FAISS vector store.

        Raises:
            ValueError: If no saved index exists and *samples* is empty.
        """
        path = self.config.index_path
        if _looks_like_faiss_index(path):
            # NOTE(security): loading unpickles index.pkl, which can execute
            # arbitrary code. Only point index_path at a directory whose
            # contents are trusted (e.g. one written by ``save``).
            return FAISS.load_local(
                str(path),
                self.embeddings,
                allow_dangerous_deserialization=True,
            )

        vectorstore = self.build(samples)
        self.save(vectorstore)
        return vectorstore
def _looks_like_faiss_index(path: Path) -> bool:
return path.exists() and (path / "index.faiss").exists() and (path / "index.pkl").exists()
|