LeonardoMdSA's picture
another push
a2f233f
"""
multi_doc_chat/rag_service.py
RAG service using ModelLoader, FAISS, and local embeddings.
"""
from pathlib import Path
from typing import List, Optional
import yaml
from .model_loader import ModelLoader
try:
import faiss
except Exception:
faiss = None
# --- ROOT paths for FAISS and models ---
# ROOT_DIR is the parent of the directory that contains this file.
ROOT_DIR = Path(__file__).resolve().parent.parent
DEFAULT_FAISS_DIR = ROOT_DIR / "faiss_index"
MODELS_DIR = ROOT_DIR / "models"

# --- optional YAML config ---
CFG_PATH = ROOT_DIR / "configs" / "default.yaml"


def _load_config(path: Path):
    """Return the parsed YAML config at ``path``.

    When the file does not exist, returns ``{"faiss_dir": str(DEFAULT_FAISS_DIR)}``.
    An empty file makes ``yaml.safe_load`` return ``None``; that is mapped to
    ``{}`` so callers can always use ``.get(...)`` on the result. Any other
    parsed value is returned unchanged.
    """
    if not path.exists():
        return {"faiss_dir": str(DEFAULT_FAISS_DIR)}
    with open(path, "r", encoding="utf-8") as fh:
        loaded = yaml.safe_load(fh)
    return {} if loaded is None else loaded


_CFG = _load_config(CFG_PATH)
class RAGService:
    """Retrieval-augmented question answering over a persisted FAISS index.

    ``self.documents`` is kept in the same order as the vectors added to
    ``self.index``, so a FAISS result id is a position in that list. Both are
    persisted under ``faiss_dir`` as ``index.bin`` and ``docs.txt`` (one
    document per line).
    """

    # File names used for persistence inside ``faiss_dir``.
    _INDEX_FILE = "index.bin"
    _DOCS_FILE = "docs.txt"
    # docs.txt is read back in text mode, where universal newlines treat "\r",
    # "\n" and "\r\n" all as line breaks. Every one of them must be flattened,
    # otherwise a document would reload as several entries and shift the
    # position of every later document relative to its vector.
    _LINE_BREAKS = str.maketrans({"\r": " ", "\n": " "})

    def __init__(
        self,
        model_loader: Optional["ModelLoader"] = None,
        faiss_dir: Optional[str] = None,
    ):
        """Create the service and load any index already persisted on disk.

        Args:
            model_loader: Loader used for ``embed``, ``embedder``, ``documents``
                and ``answer_from_rag``. When omitted, a new ``ModelLoader`` is
                created for the resolved FAISS directory.
            faiss_dir: Directory for ``index.bin`` / ``docs.txt``. Falls back to
                the config's ``faiss_dir`` key, then ``DEFAULT_FAISS_DIR``. The
                directory is created if it does not exist.
        """
        cfg_faiss = faiss_dir or _CFG.get("faiss_dir", str(DEFAULT_FAISS_DIR))
        self.faiss_dir = Path(cfg_faiss)
        self.faiss_dir.mkdir(parents=True, exist_ok=True)
        self.loader = model_loader or ModelLoader(faiss_dir=str(self.faiss_dir))
        # NOTE(review): when loader.documents is non-empty this aliases the
        # loader's list, so ingest_documents also extends loader.documents
        # until _try_load_index replaces the list -- confirm that is intended.
        self.documents: List[str] = self.loader.documents or []
        self.index = None
        if faiss:
            self._try_load_index()

    def _try_load_index(self):
        """Load ``index.bin`` and, alongside it, ``docs.txt`` from ``faiss_dir``.

        Does nothing if faiss is unavailable or there is no index file. The
        document list is only replaced when an index was loaded. Loading the
        documents without their vectors would leave the index unset, and the
        next ``ingest_documents`` would then build a fresh index containing only
        the new vectors while ``self.documents`` still held the old entries.
        """
        idx_path = self.faiss_dir / self._INDEX_FILE
        if faiss is None or not idx_path.exists():
            return
        self.index = faiss.read_index(str(idx_path))
        docs_path = self.faiss_dir / self._DOCS_FILE
        if docs_path.exists():
            with open(docs_path, "r", encoding="utf-8") as fh:
                self.documents = [line.rstrip("\n") for line in fh]

    def _persist(self):
        """Write the index to ``index.bin`` and the documents to ``docs.txt``, one per line."""
        faiss.write_index(self.index, str(self.faiss_dir / self._INDEX_FILE))
        with open(self.faiss_dir / self._DOCS_FILE, "w", encoding="utf-8") as fh:
            fh.writelines(doc.translate(self._LINE_BREAKS) + "\n" for doc in self.documents)

    def ingest_documents(self, texts: List[str]):
        """Embed ``texts``, add them to the index, and persist index and documents.

        Args:
            texts: Documents to add. ``None`` or an empty collection is a no-op.

        Raises:
            RuntimeError: if the embedding model or faiss is unavailable, or if
                the embedding dimension differs from an existing index. All of
                these are raised before ``self.documents`` is modified.
        """
        texts = list(texts) if texts is not None else []
        if not texts:
            return
        # Check preconditions before mutating any state. Otherwise a failure
        # could leave self.documents holding entries with no vectors, which
        # would shift every later search result.
        if self.loader.embedder is None:
            raise RuntimeError("Embedding model not loaded.")
        if faiss is None:
            raise RuntimeError("faiss not available (install faiss-cpu)")
        embeddings = self.loader.embed(texts).astype("float32")
        dim = embeddings.shape[1]
        if self.index is None:
            # NOTE(review): a fresh index only receives the new vectors. If
            # self.documents is already non-empty here (e.g. seeded from
            # loader.documents), positions will not line up -- confirm.
            self.index = faiss.IndexFlatL2(dim)
        elif self.index.d != dim:
            raise RuntimeError(
                f"Embedding dimension {dim} does not match index dimension {self.index.d}."
            )
        self.index.add(embeddings)
        self.documents.extend(texts)
        self._persist()

    def search(self, query: str, top_k: int = 3):
        """Return up to ``top_k`` stored documents nearest to ``query``.

        Returns ``[]`` when there is no index, the index is empty, or ``top_k``
        is not positive. Result ids outside ``self.documents`` (including the
        -1 FAISS uses for missing neighbours) are skipped.
        """
        if self.index is None or top_k <= 0 or self.index.ntotal == 0:
            return []
        # Requesting more neighbours than stored vectors only yields -1 padding.
        k = min(top_k, self.index.ntotal)
        q_vec = self.loader.embed([query]).astype("float32")
        _, indices = self.index.search(q_vec, k)
        n_docs = len(self.documents)
        return [self.documents[i] for i in indices[0] if 0 <= i < n_docs]

    def query(self, question: str, top_k: int = 3, max_tokens: int = 256) -> str:
        """Answer ``question`` using the ``top_k`` retrieved documents as context.

        The retrieved documents are joined with blank lines into a prompt. That
        prompt is passed to ``self.loader.answer_from_rag`` with ``max_tokens``,
        and its result is returned unchanged.
        """
        chunks = self.search(question, top_k=top_k)
        context = "\n\n".join(chunks)
        prompt = f"You are an assistant. Use the context to answer the question.\n\nCONTEXT:\n{context}\n\nQUESTION: {question}\n\nANSWER:"
        return self.loader.answer_from_rag(prompt, max_tokens=max_tokens)
def create_rag_service(faiss_dir: str = str(DEFAULT_FAISS_DIR)) -> RAGService:
    """Build a ``RAGService`` backed by a freshly constructed ``ModelLoader``.

    Args:
        faiss_dir: Directory holding ``index.bin`` / ``docs.txt``.

    Returns:
        The constructed ``RAGService``.

    NOTE(review): with the default argument, the ``faiss_dir`` key from
    configs/default.yaml is never consulted on this path (unlike calling
    ``RAGService()`` directly). The loader is also created without a
    ``faiss_dir`` argument. Confirm both are intended.
    """
    return RAGService(model_loader=ModelLoader(), faiss_dir=faiss_dir)