# /memory/rag/retriever.py
| """ | |
| Minimal RAG retriever that sits on top of the TF-IDF indexer. | |
| Features | |
| - Top-k document retrieval via indexer.search() | |
| - Optional filters (tags, title substring) | |
| - Passage extraction around query terms with overlap | |
| - Lightweight proximity-based reranking of passages | |
| No third-party dependencies; pairs with memory/rag/indexer.py. | |
| """ | |
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple
from pathlib import Path
import re

from .indexer import (
    load_index,
    search as index_search,
    DEFAULT_INDEX_PATH,
    tokenize,
    TfidfIndex,
    DocMeta,
)

# -----------------------------
# Public types
# -----------------------------

@dataclass
class Passage:
    doc_id: str
    source: str
    title: Optional[str]
    tags: Optional[List[str]]
    score: float    # combined score (index score +/- rerank)
    start: int      # char start in original text
    end: int        # char end in original text
    text: str       # extracted passage
    snippet: str    # human-friendly short snippet (may equal text if short)

@dataclass
class Filters:
    title_contains: Optional[str] = None          # case-insensitive containment
    require_tags: Optional[Iterable[str]] = None  # all tags must be present (AND)

# -----------------------------
# Retrieval API
# -----------------------------

def retrieve(
    query: str,
    k: int = 5,
    index_path: str | Path = DEFAULT_INDEX_PATH,
    filters: Optional[Filters] = None,
    passage_chars: int = 350,
    passage_overlap: int = 60,
    enable_rerank: bool = True,
) -> List[Passage]:
    """
    Retrieve top-k passages for a query.

    Steps:
      1. Run TF-IDF doc search
      2. Apply optional filters
      3. Extract a focused passage per doc
      4. (Optional) Rerank by term proximity within the passage
    """
    idx = load_index(index_path)
    if idx.n_docs == 0 or not query.strip():
        return []

    hits = index_search(query, k=max(k * 3, k), path=index_path)  # overshoot; filter+rerank will trim
    if filters:
        hits = _apply_filters(hits, idx, filters)

    q_tokens = tokenize(query)
    passages: List[Passage] = []
    for h in hits:
        doc = idx.docs.get(h.doc_id)
        if not doc:
            continue
        meta: DocMeta = doc["meta"]
        full_text: str = doc.get("text", "") or ""
        start, end, passage_text = _extract_passage(
            full_text, q_tokens, window=passage_chars, overlap=passage_overlap
        )
        snippet = passage_text if len(passage_text) <= 220 else passage_text[:220].rstrip() + "…"
        passages.append(Passage(
            doc_id=h.doc_id,
            source=meta.source,
            title=meta.title,
            tags=meta.tags,
            score=float(h.score),
            start=start,
            end=end,
            text=passage_text,
            snippet=snippet,
        ))

    if not passages:
        return []

    if enable_rerank:
        passages = _rerank_by_proximity(passages, q_tokens)

    passages.sort(key=lambda p: p.score, reverse=True)
    return passages[:k]

def retrieve_texts(query: str, k: int = 5, **kwargs) -> List[str]:
    """Convenience: return only the passage texts for a query."""
    return [p.text for p in retrieve(query, k=k, **kwargs)]

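# Example usage (a sketch, not part of the module itself): this assumes an index
# has already been built and saved at DEFAULT_INDEX_PATH by memory/rag/indexer.py,
# that the package is importable as memory.rag, and the query, title, and tag
# values below are made up purely for illustration.
#
#   from memory.rag.retriever import retrieve, retrieve_texts, Filters
#
#   passages = retrieve(
#       "anonymous chatbot rules",
#       k=3,
#       filters=Filters(title_contains="policy", require_tags=["chatbot"]),
#   )
#   for p in passages:
#       print(f"{p.score:.3f}", p.title, "->", p.snippet)
#
#   texts_only = retrieve_texts("anonymous chatbot rules", k=3)
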
# -----------------------------
# Internals
# -----------------------------

def _apply_filters(hits, idx: TfidfIndex, filters: Filters):
    out = []
    want_title = (filters.title_contains or "").strip().lower() or None
    want_tags = set(t.strip().lower() for t in (filters.require_tags or []) if str(t).strip())
    for h in hits:
        d = idx.docs.get(h.doc_id)
        if not d:
            continue
        meta: DocMeta = d["meta"]
        if want_title:
            t = (meta.title or "").lower()
            if want_title not in t:
                continue
        if want_tags:
            tags = set((meta.tags or []))
            tags = set(x.lower() for x in tags)
            if not want_tags.issubset(tags):
                continue
        out.append(h)
    return out

_WORD_RE = re.compile(r"[A-Za-z0-9']+")

def _find_all(term: str, text: str) -> List[int]:
    """Return starting indices of all case-insensitive matches of term in text."""
    if not term or not text:
        return []
    term_l = term.lower()
    low = text.lower()
    out: List[int] = []
    i = low.find(term_l)
    while i >= 0:
        out.append(i)
        i = low.find(term_l, i + 1)
    return out

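# For example (illustrative only): _find_all("cat", "Cat catalog") returns [0, 4],
# and because the scan advances by one character at a time, overlapping matches
# are kept, e.g. _find_all("aa", "aaa") returns [0, 1].
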
def _extract_passage(text: str, q_tokens: List[str], window: int = 350, overlap: int = 60) -> Tuple[int, int, str]:
    """
    Pick a passage around the earliest match of any query token.
    If no match is found, return the first window.
    """
    if not text:
        return 0, 0, ""
    hit_positions: List[int] = []
    for qt in q_tokens:
        hit_positions.extend(_find_all(qt, text))
    if hit_positions:
        start = max(0, min(hit_positions) - overlap)
        end = min(len(text), start + window)
    else:
        start = 0
        end = min(len(text), window)
    return start, end, text[start:end].strip()

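# Worked example of the windowing above (illustrative numbers, not taken from any
# real document): with window=350 and overlap=60, if the earliest query-token hit
# is at character 500, then start = max(0, 500 - 60) = 440 and
# end = min(len(text), 440 + 350) = 790, so the passage begins 60 characters
# before the hit and spans at most 350 characters.
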
def _rerank_by_proximity(passages: List[Passage], q_tokens: List[str]) -> List[Passage]:
    """
    Adjust scores based on how tightly query tokens cluster inside the passage.
    Heuristic: shorter span between matched terms → slightly higher score (≤ +0.25).
    """
    q_unique = [t for t in dict.fromkeys(q_tokens)]  # dedupe, preserve order
    if not q_unique:
        return passages

    def word_positions(text: str, term: str) -> List[int]:
        words = [w.group(0).lower() for w in _WORD_RE.finditer(text)]
        return [i for i, w in enumerate(words) if w == term]

    def proximity_bonus(p: Passage) -> float:
        pos_lists = [word_positions(p.text, t) for t in q_unique]
        if all(not pl for pl in pos_lists):
            return 0.0
        reps = [(pl[0] if pl else None) for pl in pos_lists]
        core = [x for x in reps if x is not None]
        if not core:
            return 0.0
        core.sort()
        mid = core[len(core) // 2]
        avg_dist = sum(abs((x if x is not None else mid) - mid) for x in reps) / max(1, len(reps))
        bonus = max(0.0, 0.25 * (1.0 - min(avg_dist, 10.0) / 10.0))
        return float(bonus)

    out: List[Passage] = []
    for p in passages:
        b = proximity_bonus(p)
        out.append(Passage(**{**p.__dict__, "score": p.score + b}))
    return out

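# Worked example of the bonus above (illustrative): if the first occurrences of
# three query terms land at word positions [10, 12, 15], the median is 12, the
# average distance is (2 + 0 + 3) / 3 ≈ 1.67, and the bonus is roughly
# 0.25 * (1 - 1.67 / 10) ≈ 0.21; terms spread about 10 or more words apart earn
# no bonus at all.
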
if __name__ == "__main__":
    import sys

    q = " ".join(sys.argv[1:]) or "anonymous chatbot rules"
    out = retrieve(q, k=3)
    for i, p in enumerate(out, 1):
        print(f"[{i}] {p.score:.4f} {p.title or '(untitled)'} — {p.source}")
        print("    ", (p.snippet.replace("\n", " ") if p.snippet else "")[:200])