Spaces:
Running
Running
| """ | |
| kerdos_rag/core.py | |
| High-level KerdosRAG façade — the primary interface for library consumers. | |
| Usage: | |
| from kerdos_rag import KerdosRAG | |
| engine = KerdosRAG(hf_token="hf_...") | |
| engine.index(["policy.pdf", "manual.docx"]) | |
| for token in engine.chat("What is the refund policy?"): | |
| print(token, end="", flush=True) | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import pickle | |
| from pathlib import Path | |
| from typing import Generator | |
| from rag.document_loader import load_documents | |
| from rag.embedder import VectorIndex, build_index, add_to_index | |
| from rag.retriever import retrieve | |
| from rag.chain import answer_stream | |
| _DEFAULT_MODEL = "meta-llama/Llama-3.1-8B-Instruct" | |
| _DEFAULT_TOP_K = 5 | |
| _DEFAULT_MIN_SCORE = 0.30 | |
class KerdosRAG:
    """
    Batteries-included RAG engine.

    Args:
        hf_token: Hugging Face API token. Falls back to HF_TOKEN env var.
        model: HF model ID (e.g. 'mistralai/Mistral-7B-Instruct-v0.3').
            Falls back to LLM_MODEL env var, then Llama 3.1 8B.
        top_k: Number of chunks to retrieve per query.
        min_score: Minimum cosine similarity threshold (chunks below this
            are dropped before being sent to the LLM).
    """

    def __init__(
        self,
        hf_token: str = "",
        model: str | None = None,
        top_k: int = _DEFAULT_TOP_K,
        min_score: float = _DEFAULT_MIN_SCORE,
    ) -> None:
        # Explicit argument wins; otherwise fall back to environment.
        self.hf_token: str = hf_token.strip() or os.environ.get("HF_TOKEN", "")
        self.model: str = model or os.environ.get("LLM_MODEL", _DEFAULT_MODEL)
        self.top_k: int = top_k
        self.min_score: float = min_score
        self._index: VectorIndex | None = None
        self._indexed_sources: set[str] = set()

    # ── Properties ────────────────────────────────────────────────────────────
    # NOTE: these MUST be properties — the rest of the class reads them as
    # attributes (e.g. `self.chunk_count`, `self.is_ready`). Without @property
    # those reads would yield bound methods (always truthy), silently disabling
    # the is_ready guard in chat()/save() and corrupting index() return values.

    @property
    def indexed_sources(self) -> set[str]:
        """File names currently in the knowledge base."""
        # Defensive copy so callers cannot mutate internal bookkeeping.
        return set(self._indexed_sources)

    @property
    def chunk_count(self) -> int:
        """Total number of vector chunks in the index."""
        return self._index.index.ntotal if self._index else 0

    @property
    def is_ready(self) -> bool:
        """True when at least one document has been indexed."""
        return self._index is not None and self.chunk_count > 0

    # ── Core operations ───────────────────────────────────────────────────────

    def index(self, file_paths: list[str]) -> dict:
        """
        Parse and index documents into the knowledge base.
        Duplicate filenames are automatically skipped.

        Args:
            file_paths: Absolute or relative paths to PDF, DOCX, TXT, MD, or CSV files.

        Returns:
            {
                "indexed": ["file1.pdf", ...],   # newly indexed
                "skipped": ["dup.pdf", ...],     # already in index
                "chunk_count": 142               # total chunks
            }

        Raises:
            ValueError: If none of the new files yield any extractable text.
        """
        paths = [str(p) for p in file_paths]
        new_paths, skipped = [], []
        for p in paths:
            # Dedup is by file *name*, not full path — two paths with the
            # same basename are treated as the same document.
            name = Path(p).name
            if name in self._indexed_sources:
                skipped.append(name)
            else:
                new_paths.append(p)
        if not new_paths:
            return {"indexed": [], "skipped": skipped, "chunk_count": self.chunk_count}
        docs = load_documents(new_paths)
        if not docs:
            raise ValueError("Could not extract text from any of the provided files.")
        # First call builds a fresh index; later calls append to it.
        if self._index is None:
            self._index = build_index(docs)
        else:
            self._index = add_to_index(self._index, docs)
        newly_indexed = list({d["source"] for d in docs})
        self._indexed_sources.update(newly_indexed)
        return {
            "indexed": newly_indexed,
            "skipped": skipped,
            "chunk_count": self.chunk_count,
        }

    def chat(
        self,
        query: str,
        history: list[dict] | None = None,
    ) -> Generator[str, None, None]:
        """
        Ask a question and stream the answer token-by-token.

        Args:
            query: The user's question.
            history: Optional list of prior messages in
                [{"role": "user"|"assistant", "content": "..."}] format.

        Yields:
            Progressively-growing answer strings (suitable for real-time display).

        Raises:
            RuntimeError: If no documents have been indexed yet.
            ValueError: If no HF token is available.
        """
        if not self.is_ready:
            raise RuntimeError("No documents indexed. Call engine.index(file_paths) first.")
        if not self.hf_token:
            raise ValueError(
                "No Hugging Face token. Pass hf_token= to KerdosRAG() or set HF_TOKEN env var."
            )
        # Temporarily patch retriever's MIN_SCORE with instance setting.
        # NOTE(review): this mutates a module-level global, so concurrent
        # chat() calls from different threads could see each other's
        # threshold — acceptable for single-threaded use, but retrieve()
        # would need a min_score parameter to make this safe under threads.
        import rag.retriever as _r
        original_min = _r.MIN_SCORE
        _r.MIN_SCORE = self.min_score
        try:
            chunks = retrieve(query, self._index, top_k=self.top_k)
            yield from answer_stream(query, chunks, self.hf_token, chat_history=history)
        finally:
            # Always restore the global, even if retrieval/streaming raises.
            _r.MIN_SCORE = original_min

    def reset(self) -> None:
        """Clear the knowledge base."""
        self._index = None
        self._indexed_sources = set()

    # ── Persistence ───────────────────────────────────────────────────────────

    def save(self, directory: str | Path) -> None:
        """
        Persist the index to disk so it can be reloaded across sessions.

        Creates two files in `directory`:
            - ``kerdos_index.faiss`` — the raw FAISS vectors
            - ``kerdos_meta.pkl`` — chunks + source tracking

        Args:
            directory: Path to a folder (will be created if needed).

        Raises:
            RuntimeError: If the index is empty.
        """
        # Local import keeps faiss optional until persistence is actually used.
        import faiss
        if not self.is_ready:
            raise RuntimeError("Nothing to save — index is empty.")
        out = Path(directory)
        out.mkdir(parents=True, exist_ok=True)
        faiss.write_index(self._index.index, str(out / "kerdos_index.faiss"))
        # Everything except the raw vectors (and the HF token, which is
        # deliberately NOT persisted) goes into the metadata pickle.
        meta = {
            "chunks": self._index.chunks,
            "indexed_sources": list(self._indexed_sources),
            "model": self.model,
            "top_k": self.top_k,
            "min_score": self.min_score,
        }
        with open(out / "kerdos_meta.pkl", "wb") as f:
            pickle.dump(meta, f)

    @classmethod
    def load(cls, directory: str | Path, hf_token: str = "") -> "KerdosRAG":
        """
        Restore an engine from a directory previously written by :meth:`save`.

        Args:
            directory: Folder containing ``kerdos_index.faiss`` and ``kerdos_meta.pkl``.
            hf_token: HF token for chat (can also be set via HF_TOKEN env var).

        Returns:
            A fully initialised :class:`KerdosRAG` instance.
        """
        import faiss
        from rag.embedder import _get_model
        d = Path(directory)
        # SECURITY: pickle.load executes arbitrary code on malicious input —
        # only load directories this process (or a trusted peer) wrote.
        with open(d / "kerdos_meta.pkl", "rb") as f:
            meta = pickle.load(f)
        engine = cls(
            hf_token=hf_token,
            model=meta["model"],
            top_k=meta["top_k"],
            min_score=meta["min_score"],
        )
        model = _get_model()
        idx = faiss.read_index(str(d / "kerdos_index.faiss"))
        engine._index = VectorIndex(chunks=meta["chunks"], index=idx, embedder=model)
        engine._indexed_sources = set(meta["indexed_sources"])
        return engine