File size: 7,922 Bytes
634117a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
"""
kerdos_rag/core.py
High-level KerdosRAG façade — the primary interface for library consumers.

Usage:
    from kerdos_rag import KerdosRAG

    engine = KerdosRAG(hf_token="hf_...")
    engine.index(["policy.pdf", "manual.docx"])
    for token in engine.chat("What is the refund policy?"):
        print(token, end="", flush=True)
"""

from __future__ import annotations

import json
import os
import pickle
from pathlib import Path
from typing import Generator

from rag.document_loader import load_documents
from rag.embedder import VectorIndex, build_index, add_to_index
from rag.retriever import retrieve
from rag.chain import answer_stream

# Fallback Hugging Face model ID, used when neither `model=` nor the
# LLM_MODEL environment variable supplies one.
_DEFAULT_MODEL: str = "meta-llama/Llama-3.1-8B-Instruct"
# Default number of chunks retrieved per query.
_DEFAULT_TOP_K: int = 5
# Default minimum similarity score a retrieved chunk must reach to be kept.
_DEFAULT_MIN_SCORE: float = 0.3


class KerdosRAG:
    """
    Batteries-included RAG engine.

    Args:
        hf_token:   Hugging Face API token. Falls back to the HF_TOKEN env var
                    when empty or None. Surrounding whitespace is stripped.
        model:      HF model ID (e.g. 'mistralai/Mistral-7B-Instruct-v0.3').
                    Falls back to the LLM_MODEL env var (when non-empty), then
                    Llama 3.1 8B.
        top_k:      Number of chunks to retrieve per query.
        min_score:  Minimum cosine similarity threshold (chunks below this
                    are dropped before being sent to the LLM).
    """

    def __init__(
        self,
        hf_token: str = "",
        model: str | None = None,
        top_k: int = _DEFAULT_TOP_K,
        min_score: float = _DEFAULT_MIN_SCORE,
    ) -> None:
        # `or ""` tolerates hf_token=None instead of crashing on .strip().
        self.hf_token: str = (hf_token or "").strip() or os.environ.get("HF_TOKEN", "").strip()
        # Chained `or` so an empty LLM_MODEL env var does not yield an empty model id.
        self.model: str = model or os.environ.get("LLM_MODEL") or _DEFAULT_MODEL
        self.top_k: int = top_k
        self.min_score: float = min_score

        self._index: VectorIndex | None = None
        # File names (not full paths) already indexed; used for duplicate detection.
        self._indexed_sources: set[str] = set()

    # ── Properties ────────────────────────────────────────────────────────────

    @property
    def indexed_sources(self) -> set[str]:
        """File names currently in the knowledge base (returned as a copy)."""
        return set(self._indexed_sources)

    @property
    def chunk_count(self) -> int:
        """Total number of vector chunks in the index (0 when nothing is indexed)."""
        # Explicit None check rather than relying on VectorIndex truthiness.
        return self._index.index.ntotal if self._index is not None else 0

    @property
    def is_ready(self) -> bool:
        """True when at least one document has been indexed."""
        return self._index is not None and self.chunk_count > 0

    # ── Core operations ───────────────────────────────────────────────────────

    def index(self, file_paths: list[str]) -> dict:
        """
        Parse and index documents into the knowledge base.

        Files whose name (final path component) is already indexed are skipped,
        as are repeated names later in the same call.

        Args:
            file_paths: Absolute or relative paths to PDF, DOCX, TXT, MD, or CSV
                        files. A single path (str or os.PathLike) is also accepted.

        Returns:
            {
              "indexed": ["file1.pdf", ...],   # newly indexed, first-seen order
              "skipped": ["dup.pdf", ...],      # already in index / repeated
              "chunk_count": 142               # total chunks
            }

        Raises:
            ValueError: If no text could be extracted from any of the new files.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(file_paths, (str, os.PathLike)):
            file_paths = [file_paths]

        new_paths: list[str] = []
        skipped: list[str] = []
        seen = set(self._indexed_sources)
        for raw in file_paths:
            p = str(raw)
            name = Path(p).name
            if name in seen:
                skipped.append(name)
            else:
                # Record immediately so a repeat later in this batch is skipped too.
                seen.add(name)
                new_paths.append(p)

        if not new_paths:
            return {"indexed": [], "skipped": skipped, "chunk_count": self.chunk_count}

        docs = load_documents(new_paths)
        if not docs:
            raise ValueError("Could not extract text from any of the provided files.")

        if self._index is None:
            self._index = build_index(docs)
        else:
            self._index = add_to_index(self._index, docs)

        # dict.fromkeys de-duplicates while keeping first-seen order (a set would
        # make the returned list order arbitrary).
        # NOTE(review): assumes d["source"] is the bare file name, matching the
        # Path(p).name key used for duplicate detection above — confirm in
        # rag.document_loader.
        newly_indexed = list(dict.fromkeys(d["source"] for d in docs))
        self._indexed_sources.update(newly_indexed)

        return {
            "indexed": newly_indexed,
            "skipped": skipped,
            "chunk_count": self.chunk_count,
        }

    def chat(
        self,
        query: str,
        history: list[dict] | None = None,
    ) -> Generator[str, None, None]:
        """
        Ask a question and stream the answer.

        Args:
            query:   The user's question.
            history: Optional list of prior messages in
                     [{"role": "user"|"assistant", "content": "..."}] format.

        Returns:
            A generator yielding the string pieces produced by
            ``rag.chain.answer_stream``. Retrieval runs when iteration starts.
            NOTE(review): older docs described these both as individual tokens
            and as progressively-growing strings — confirm against answer_stream.

        Raises:
            RuntimeError: If no documents have been indexed yet (raised when
                          ``chat()`` is called, not on first iteration).
            ValueError:   If no HF token is available (likewise raised eagerly).
        """
        # Validate outside any generator body so misconfiguration surfaces at
        # call time instead of being deferred to the first next().
        if not self.is_ready:
            raise RuntimeError("No documents indexed. Call engine.index(file_paths) first.")
        if not self.hf_token:
            raise ValueError(
                "No Hugging Face token. Pass hf_token= to KerdosRAG() or set HF_TOKEN env var."
            )
        return self._stream_answer(query, history)

    def _stream_answer(
        self,
        query: str,
        history: list[dict] | None,
    ) -> Generator[str, None, None]:
        """Retrieve context for *query*, then stream the LLM answer."""
        chunks = self._retrieve(query)
        # NOTE(review): self.model is not forwarded to answer_stream, so the
        # `model` constructor argument has no visible effect here. Confirm
        # whether answer_stream accepts a model argument or reads LLM_MODEL.
        yield from answer_stream(query, chunks, self.hf_token, chat_history=history)

    def _retrieve(self, query: str):
        """Run ``retrieve`` with this instance's ``min_score`` temporarily applied."""
        # Patch rag.retriever.MIN_SCORE (presumably read by retrieve()) only
        # around the retrieve() call, restoring it before any answer is
        # streamed. The patch is still process-global, so concurrent chats on
        # other threads may observe this instance's threshold meanwhile.
        import rag.retriever as _r

        original_min = _r.MIN_SCORE
        _r.MIN_SCORE = self.min_score
        try:
            return retrieve(query, self._index, top_k=self.top_k)
        finally:
            _r.MIN_SCORE = original_min

    def reset(self) -> None:
        """Clear the knowledge base."""
        self._index = None
        self._indexed_sources = set()

    # ── Persistence ───────────────────────────────────────────────────────────

    def save(self, directory: str | Path) -> None:
        """
        Persist the index to disk so it can be reloaded across sessions.

        Creates two files in `directory`:
          - ``kerdos_index.faiss``  — the raw FAISS vectors
          - ``kerdos_meta.pkl``     — chunks + source tracking + settings

        Args:
            directory: Path to a folder (will be created if needed).

        Raises:
            RuntimeError: If the index is empty.
        """
        # Check first so an empty engine fails clearly even without faiss installed.
        if not self.is_ready:
            raise RuntimeError("Nothing to save — index is empty.")

        import faiss

        out = Path(directory)
        out.mkdir(parents=True, exist_ok=True)

        faiss.write_index(self._index.index, str(out / "kerdos_index.faiss"))
        meta = {
            "chunks": self._index.chunks,
            # Sorted so repeated saves of the same index produce identical metadata.
            "indexed_sources": sorted(self._indexed_sources),
            "model": self.model,
            "top_k": self.top_k,
            "min_score": self.min_score,
        }
        with open(out / "kerdos_meta.pkl", "wb") as f:
            pickle.dump(meta, f)

    @classmethod
    def load(cls, directory: str | Path, hf_token: str = "") -> "KerdosRAG":
        """
        Restore an engine from a directory previously written by :meth:`save`.

        Security: ``kerdos_meta.pkl`` is read with :func:`pickle.load`, which
        can execute arbitrary code. Only load directories you created or trust.

        Args:
            directory: Folder containing ``kerdos_index.faiss`` and ``kerdos_meta.pkl``.
            hf_token:  HF token for chat (can also be set via HF_TOKEN env var).

        Returns:
            A fully initialised :class:`KerdosRAG` instance.
        """
        import faiss
        from rag.embedder import _get_model

        d = Path(directory)
        # WARNING: unpickling untrusted data is unsafe (see docstring).
        with open(d / "kerdos_meta.pkl", "rb") as f:
            meta = pickle.load(f)

        engine = cls(
            hf_token=hf_token,
            model=meta["model"],
            top_k=meta["top_k"],
            min_score=meta["min_score"],
        )
        model = _get_model()
        idx = faiss.read_index(str(d / "kerdos_index.faiss"))
        engine._index = VectorIndex(chunks=meta["chunks"], index=idx, embedder=model)
        engine._indexed_sources = set(meta["indexed_sources"])
        return engine