"""
kerdos_rag/core.py
High-level KerdosRAG façade — the primary interface for library consumers.
Usage:
from kerdos_rag import KerdosRAG
engine = KerdosRAG(hf_token="hf_...")
engine.index(["policy.pdf", "manual.docx"])
for token in engine.chat("What is the refund policy?"):
print(token, end="", flush=True)
"""
from __future__ import annotations
import json
import os
import pickle
from pathlib import Path
from typing import Generator
from rag.document_loader import load_documents
from rag.embedder import VectorIndex, build_index, add_to_index
from rag.retriever import retrieve
from rag.chain import answer_stream
_DEFAULT_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
_DEFAULT_TOP_K = 5
_DEFAULT_MIN_SCORE = 0.30


class KerdosRAG:
    """
    Batteries-included RAG engine.

    Args:
        hf_token: Hugging Face API token. Falls back to HF_TOKEN env var.
        model: HF model ID (e.g. 'mistralai/Mistral-7B-Instruct-v0.3').
            Falls back to LLM_MODEL env var, then Llama 3.1 8B.
        top_k: Number of chunks to retrieve per query.
        min_score: Minimum cosine similarity threshold (chunks below this
            are dropped before being sent to the LLM).
    """

    def __init__(
        self,
        hf_token: str = "",
        model: str | None = None,
        top_k: int = _DEFAULT_TOP_K,
        min_score: float = _DEFAULT_MIN_SCORE,
    ) -> None:
        self.hf_token: str = hf_token.strip() or os.environ.get("HF_TOKEN", "")
        self.model: str = model or os.environ.get("LLM_MODEL", _DEFAULT_MODEL)
        self.top_k: int = top_k
        self.min_score: float = min_score
        # Lazily populated by index(); stays None until a successful call.
        self._index: VectorIndex | None = None
        # Basenames of every file already indexed — used to skip duplicates.
        self._indexed_sources: set[str] = set()

    # ── Properties ────────────────────────────────────────────────────────────

    @property
    def indexed_sources(self) -> set[str]:
        """File names currently in the knowledge base (defensive copy)."""
        return set(self._indexed_sources)

    @property
    def chunk_count(self) -> int:
        """Total number of vector chunks in the index."""
        return self._index.index.ntotal if self._index else 0

    @property
    def is_ready(self) -> bool:
        """True when at least one document has been indexed."""
        return self._index is not None and self.chunk_count > 0

    # ── Core operations ───────────────────────────────────────────────────────

    def index(self, file_paths: list[str]) -> dict:
        """
        Parse and index documents into the knowledge base.

        Duplicate filenames are automatically skipped — both names already
        in the index and duplicates within ``file_paths`` itself.

        Args:
            file_paths: Absolute or relative paths to PDF, DOCX, TXT, MD, or CSV files.

        Returns:
            {
                "indexed": ["file1.pdf", ...],  # newly indexed
                "skipped": ["dup.pdf", ...],    # already in index
                "chunk_count": 142              # total chunks
            }

        Raises:
            ValueError: If no text could be extracted from any new file.
        """
        paths = [str(p) for p in file_paths]
        new_paths: list[str] = []
        skipped: list[str] = []
        # Track every name seen so far, so the same file passed twice in ONE
        # call is skipped too (previously only names already in the index
        # were checked, allowing intra-call duplicates to be indexed twice).
        seen = set(self._indexed_sources)
        for p in paths:
            name = Path(p).name
            if name in seen:
                skipped.append(name)
            else:
                seen.add(name)
                new_paths.append(p)
        if not new_paths:
            return {"indexed": [], "skipped": skipped, "chunk_count": self.chunk_count}
        docs = load_documents(new_paths)
        if not docs:
            raise ValueError("Could not extract text from any of the provided files.")
        if self._index is None:
            self._index = build_index(docs)
        else:
            self._index = add_to_index(self._index, docs)
        # dict.fromkeys dedupes while preserving document order; a plain set
        # made the "indexed" list order nondeterministic across runs.
        newly_indexed = list(dict.fromkeys(d["source"] for d in docs))
        self._indexed_sources.update(newly_indexed)
        return {
            "indexed": newly_indexed,
            "skipped": skipped,
            "chunk_count": self.chunk_count,
        }

    def chat(
        self,
        query: str,
        history: list[dict] | None = None,
    ) -> Generator[str, None, None]:
        """
        Ask a question and stream the answer token-by-token.

        Args:
            query: The user's question.
            history: Optional list of prior messages in
                [{"role": "user"|"assistant", "content": "..."}] format.

        Yields:
            Progressively-growing answer strings (suitable for real-time display).

        Raises:
            RuntimeError: If no documents have been indexed yet.
            ValueError: If no HF token is available.
        """
        if not self.is_ready:
            raise RuntimeError("No documents indexed. Call engine.index(file_paths) first.")
        if not self.hf_token:
            raise ValueError(
                "No Hugging Face token. Pass hf_token= to KerdosRAG() or set HF_TOKEN env var."
            )
        # Temporarily patch the retriever's module-level MIN_SCORE with the
        # instance setting. The patch window is kept as narrow as possible —
        # only around retrieve(), not the (potentially slow) LLM streaming —
        # so the global is restored as soon as retrieval finishes instead of
        # staying patched for the whole answer stream.
        # NOTE(review): still not thread-safe; concurrent chats with
        # different min_score values could race on this module global.
        import rag.retriever as _r

        original_min = _r.MIN_SCORE
        _r.MIN_SCORE = self.min_score
        try:
            chunks = retrieve(query, self._index, top_k=self.top_k)
        finally:
            _r.MIN_SCORE = original_min
        yield from answer_stream(query, chunks, self.hf_token, chat_history=history)

    def reset(self) -> None:
        """Clear the knowledge base (index and source tracking)."""
        self._index = None
        self._indexed_sources = set()

    # ── Persistence ───────────────────────────────────────────────────────────

    def save(self, directory: str | Path) -> None:
        """
        Persist the index to disk so it can be reloaded across sessions.

        Creates two files in `directory`:
          - ``kerdos_index.faiss`` — the raw FAISS vectors
          - ``kerdos_meta.pkl`` — chunks + source tracking

        Args:
            directory: Path to a folder (will be created if needed).

        Raises:
            RuntimeError: If the index is empty.
        """
        # Validate before importing faiss so an empty engine fails fast with
        # the intended RuntimeError even when faiss is not installed.
        if not self.is_ready:
            raise RuntimeError("Nothing to save — index is empty.")
        import faiss

        out = Path(directory)
        out.mkdir(parents=True, exist_ok=True)
        faiss.write_index(self._index.index, str(out / "kerdos_index.faiss"))
        meta = {
            "chunks": self._index.chunks,
            "indexed_sources": list(self._indexed_sources),
            "model": self.model,
            "top_k": self.top_k,
            "min_score": self.min_score,
        }
        with open(out / "kerdos_meta.pkl", "wb") as f:
            pickle.dump(meta, f)

    @classmethod
    def load(cls, directory: str | Path, hf_token: str = "") -> "KerdosRAG":
        """
        Restore an engine from a directory previously written by :meth:`save`.

        Args:
            directory: Folder containing ``kerdos_index.faiss`` and ``kerdos_meta.pkl``.
            hf_token: HF token for chat (can also be set via HF_TOKEN env var).

        Returns:
            A fully initialised :class:`KerdosRAG` instance.
        """
        import faiss
        from rag.embedder import _get_model

        d = Path(directory)
        # SECURITY: pickle.load executes arbitrary code when fed malicious
        # data — only load directories written by save() from trusted storage.
        with open(d / "kerdos_meta.pkl", "rb") as f:
            meta = pickle.load(f)
        engine = cls(
            hf_token=hf_token,
            model=meta["model"],
            top_k=meta["top_k"],
            min_score=meta["min_score"],
        )
        embedder = _get_model()
        idx = faiss.read_index(str(d / "kerdos_index.faiss"))
        engine._index = VectorIndex(chunks=meta["chunks"], index=idx, embedder=embedder)
        engine._indexed_sources = set(meta["indexed_sources"])
        return engine