File size: 2,031 Bytes
2a2c039
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from __future__ import annotations

from pathlib import Path
import os
import warnings

# Set before the Hugging Face-backed imports below so the variable is already
# in the environment when those packages load. Presumably this silences
# huggingface_hub's cache symlink warning -- confirm against the hub docs.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

from .config import BioRAGConfig
from .data_loader import PubMedQASample


class KnowledgeBaseBuilder:
    """Build, persist, and reload a FAISS vector store of sample contexts.

    Each sample's ``context`` is embedded with a HuggingFace embedding model
    and stored in FAISS; the remaining sample fields travel along as document
    metadata. The on-disk location is ``config.index_path``.
    """

    def __init__(self, config: BioRAGConfig) -> None:
        """Keep *config* and create the embedding model it names.

        Args:
            config: Provides ``embedding_model`` (used here) and
                ``index_path`` (used by ``save`` and ``load_or_build``).
        """
        self.config = config
        self.embeddings = HuggingFaceEmbeddings(
            model_name=config.embedding_model,
            show_progress=True,
            encode_kwargs={"batch_size": 32},
        )

    def build(self, samples: list[PubMedQASample]) -> FAISS:
        """Embed *samples* and return a new in-memory FAISS vector store.

        Args:
            samples: Samples to index. ``sample.context`` becomes the document
                text; ``qid``, ``question``, ``answer``, ``authors``, ``year``,
                ``journal`` and ``title`` are copied into the metadata.

        Returns:
            The vector store produced by ``FAISS.from_documents``.

        Raises:
            ValueError: If *samples* is empty.
        """
        # An empty corpus gives the vector store no embedding to derive its
        # dimensionality from; reject it here with a clear message instead of
        # letting the failure surface from inside the library.
        if not samples:
            raise ValueError(
                "Cannot build a FAISS index from an empty list of samples."
            )

        documents = [
            Document(
                page_content=sample.context,
                metadata={
                    "qid": sample.qid,
                    "question": sample.question,
                    "answer": sample.answer,
                    "authors": sample.authors,
                    "year": sample.year,
                    "journal": sample.journal,
                    "title": sample.title,
                },
            )
            for sample in samples
        ]
        return FAISS.from_documents(documents, self.embeddings)

    def save(self, vectorstore: FAISS) -> None:
        """Write *vectorstore* to ``config.index_path``.

        The directory (and any missing parents) is created first.
        """
        index_dir = self.config.index_path
        index_dir.mkdir(parents=True, exist_ok=True)
        vectorstore.save_local(str(index_dir))

    def load_or_build(self, samples: list[PubMedQASample]) -> FAISS:
        """Return the saved index if one is present, otherwise build and save.

        When ``config.index_path`` already holds both index files, the stored
        index is loaded and *samples* is not consulted at all -- a stale index
        is reused as-is. Otherwise a new index is built from *samples*,
        written to ``config.index_path``, and returned.

        Raises:
            ValueError: If no saved index exists and *samples* is empty.
        """
        index_dir = self.config.index_path
        if _looks_like_faiss_index(index_dir):
            # SECURITY: loading the saved index unpickles ``index.pkl``, which
            # can execute arbitrary code. The opt-in flag below is acceptable
            # only because the index is expected to be one this class wrote;
            # never point ``index_path`` at files from an untrusted source.
            return FAISS.load_local(
                str(index_dir),
                self.embeddings,
                allow_dangerous_deserialization=True,
            )

        vectorstore = self.build(samples)
        self.save(vectorstore)
        return vectorstore


def _looks_like_faiss_index(path: Path) -> bool:
    """Report whether *path* holds both files of a saved FAISS index.

    True only when *path* itself exists and contains entries named
    ``index.faiss`` and ``index.pkl``; False otherwise.
    """
    if not path.exists():
        return False
    required_names = ("index.faiss", "index.pkl")
    return all((path / name).exists() for name in required_names)