AdaptiveRAG: Agentic + Self-RAG + Modular RAG pipeline with visual UI
- Semantic chunking + dual embedding (dense MiniLM + sparse BM25)
- Hybrid retrieval: Chroma cosine + BM25 fused via RRF + BGE cross-encoder rerank
- Self-RAG router (RETRIEVE / ANSWER_DIRECTLY / CLARIFY)
- Agentic loop: plan -> retrieve -> answer -> self-critique -> refine
- Streamlit UI exposing every pipeline stage with 2D vector space scatter plot
- Pre-built index: 1934 chunks across 14 foundational AI papers (git-lfs)
- Auto-switches backend: Ollama locally, Groq API when GROQ_API_KEY is set
- .gitattributes +2 -0
- .gitignore +15 -0
- .streamlit/config.toml +5 -0
- README.md +38 -6
- agent/__init__.py +0 -0
- agent/critic.py +65 -0
- agent/loop.py +198 -0
- agent/planner.py +52 -0
- agent/router.py +39 -0
- agent/tools.py +85 -0
- app.py +554 -0
- ask.py +32 -0
- config.py +57 -0
- download_papers.sh +28 -0
- ingest.py +43 -0
- ingestion/__init__.py +0 -0
- ingestion/chunker.py +104 -0
- ingestion/embedder.py +32 -0
- ingestion/indexer.py +134 -0
- ingestion/loader.py +87 -0
- llm/__init__.py +0 -0
- llm/client_factory.py +16 -0
- llm/groq_client.py +133 -0
- llm/ollama_client.py +116 -0
- requirements.txt +9 -0
- retrieval/__init__.py +0 -0
- retrieval/dense.py +39 -0
- retrieval/hybrid.py +47 -0
- retrieval/pipeline.py +12 -0
- retrieval/reranker.py +36 -0
- retrieval/sparse.py +44 -0
- storage/bm25.pkl +3 -0
- storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/data_level0.bin +3 -0
- storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/header.bin +3 -0
- storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/index_metadata.pickle +3 -0
- storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/length.bin +3 -0
- storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/link_lists.bin +3 -0
- storage/chroma/chroma.sqlite3 +3 -0
- storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/data_level0.bin +3 -0
- storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/header.bin +3 -0
- storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/index_metadata.pickle +3 -0
- storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/length.bin +3 -0
- storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/link_lists.bin +3 -0
- storage/manifest.json +19 -0
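
The commit message above says the backend switches automatically: Ollama locally, Groq when `GROQ_API_KEY` is set. `llm/client_factory.py` is listed in the file summary but its contents are not shown here, so the following is only a sketch of how such a switch could look. The helper name `choose_backend` and the returned labels are hypothetical and are not taken from the commit.

```python
from __future__ import annotations

import os
from typing import Mapping


def choose_backend(env: Mapping[str, str] | None = None) -> str:
    """Return "groq" when a non-empty GROQ_API_KEY is present, else "ollama".

    Hypothetical helper: the real factory in llm/client_factory.py is not
    shown in this diff and may decide differently.
    """
    env = os.environ if env is None else env
    key = env.get("GROQ_API_KEY", "").strip()
    return "groq" if key else "ollama"


# Passing an explicit mapping keeps the example independent of the real environment.
local_choice = choose_backend({})
hosted_choice = choose_backend({"GROQ_API_KEY": "example-key"})
```

Treating a whitespace-only key as unset is a design choice of this sketch; it avoids selecting the hosted backend when the variable is exported but empty.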
.gitattributes
CHANGED
@@ -1,3 +1,4 @@
+# HF standard LFS patterns
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+storage/chroma/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,15 @@
+# Python
+.venv/
+__pycache__/
+*.py[cod]
+*.egg-info/
+.env
+
+# macOS
+.DS_Store
+
+# Papers — excluded (public ArXiv PDFs; re-download with download_papers.sh)
+papers/
+
+# Streamlit
+.streamlit/secrets.toml
.streamlit/config.toml
ADDED
@@ -0,0 +1,5 @@
+[browser]
+gatherUsageStats = false
+
+[server]
+headless = true
README.md
CHANGED
@@ -1,12 +1,44 @@
 ---
 title: AdaptiveRAG
-emoji:
-colorFrom:
-colorTo:
+emoji: 📚
+colorFrom: blue
+colorTo: purple
 sdk: docker
-pinned:
+pinned: true
 license: mit
-short_description:
+short_description: Agentic + Self-RAG + Modular RAG with visual pipeline UI
 ---

-
+# AdaptiveRAG — Agentic + Self-RAG + Modular RAG
+
+Live demo of a production-grade RAG pipeline — every stage is visible in the UI.
+
+**Tech stack:** ChromaDB · sentence-transformers · BM25 · Reciprocal Rank Fusion · BGE cross-encoder · LLaMA 3.1 via Groq
+
+**Knowledge base:** 14 foundational AI papers (Transformers, BERT, GPT-3, DDPM, RAG, Self-RAG, HyDE, ViT, CLIP, ReAct, Chain-of-Thought, LLM Survey)
+
+## What it shows
+
+| Stage | What you see |
+|---|---|
+| Question encoding | 384-dim embedding vector + bar chart of first 32 dims |
+| Self-RAG router | RETRIEVE / ANSWER_DIRECTLY / CLARIFY decision + reason |
+| Planner | Sub-query decomposition with rationales |
+| Dense retrieval | Cosine similarity scores vs ChromaDB |
+| Sparse retrieval | BM25 keyword match scores |
+| RRF fusion | Combined ranking chart |
+| Cross-encoder rerank | BGE relevance scores |
+| Vector space | 2D PCA projection of query + hits |
+| Self-critique | Grounded / Complete / Confidence score |
+
+## Run locally
+
+```bash
+git clone https://github.com/Gh-Novel/AdaptiveRAG
+cd AdaptiveRAG
+python -m venv .venv && source .venv/bin/activate
+pip install -r requirements.txt
+
+# needs Ollama running with qwen3-vl:8b-instruct-q8_0-optimized
+streamlit run app.py
+```
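
The README table lists an "RRF fusion" stage that combines the dense and sparse rankings. `retrieval/hybrid.py` is not shown in this part of the diff, so the block below is a minimal sketch of textbook Reciprocal Rank Fusion, where each document scores the sum of `1 / (k + rank)` over the lists it appears in. The function name `rrf_fuse`, the chunk ids, and the constant `k = 60` (a value commonly used for RRF) are assumptions for illustration; the repository's `reciprocal_rank_fusion` may use a different signature or constant.

```python
from __future__ import annotations

from collections import defaultdict


def rrf_fuse(rankings: list[list[str]], k: int = 60) -> list[tuple[str, float]]:
    """Fuse several ranked id lists with Reciprocal Rank Fusion.

    Each ranking is ordered best-first. An id that appears in several
    rankings accumulates 1 / (k + rank) from each of them.
    """
    scores: dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, chunk_id in enumerate(ranking, start=1):
            scores[chunk_id] += 1.0 / (k + rank)
    # Highest fused score first; ties broken by id so the order is deterministic.
    return sorted(scores.items(), key=lambda item: (-item[1], item[0]))


# Made-up chunk ids standing in for a dense ranking and a BM25 ranking.
dense_ids = ["c1", "c2", "c3"]
sparse_ids = ["c3", "c1", "c4"]
fused = rrf_fuse([dense_ids, sparse_ids])
fused_order = [chunk_id for chunk_id, _ in fused]
```

Because RRF only uses ranks, it does not need the cosine and BM25 scores to be on a comparable scale, which is the usual reason for choosing it over a weighted score sum.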
agent/__init__.py
ADDED
File without changes
agent/critic.py
ADDED
@@ -0,0 +1,65 @@
+"""Self-critique: judge answer for grounding, completeness, confidence."""
+from __future__ import annotations
+
+from llm.ollama_client import OllamaClient
+
+CRITIC_SYSTEM = (
+    "You are a strict reviewer. Judge whether an AI-generated answer is grounded "
+    "in the provided context and whether it fully answers the user's question."
+)
+
+CRITIC_PROMPT = """Question: {question}
+
+Answer to review:
+{answer}
+
+Context that was provided to the answerer:
+{context}
+
+Score the answer. Return strict JSON only:
+{{
+  "grounded": true | false, // Is every factual claim supported by the context?
+  "complete": true | false, // Does it fully address the question?
+  "confidence": 0.0-1.0, // Overall confidence in the answer
+  "missing": "<what info is missing or weakly supported, or empty string>"
+}}
+"""
+
+
+def critique(question: str, answer: str, context: str, llm: OllamaClient | None = None) -> dict:
+    llm = llm or OllamaClient()
+    out = llm.generate_json(
+        prompt=CRITIC_PROMPT.format(question=question, answer=answer, context=context),
+        system=CRITIC_SYSTEM,
+        temperature=0.0,
+    )
+    return {
+        "grounded": bool(out.get("grounded", False)),
+        "complete": bool(out.get("complete", False)),
+        "confidence": float(out.get("confidence", 0.0) or 0.0),
+        "missing": str(out.get("missing", "") or ""),
+    }
+
+
+REFINE_SYSTEM = (
+    "You rewrite a search query so it retrieves the missing information."
+)
+
+REFINE_PROMPT = """Original question: {question}
+
+A previous attempt was missing the following information:
+{missing}
+
+Rewrite the query to specifically target the missing information. Output the
+rewritten search query as a single line of text, no quotes, no explanation.
+"""
+
+
+def refine_query(question: str, missing: str, llm: OllamaClient | None = None) -> str:
+    llm = llm or OllamaClient()
+    out = llm.generate(
+        prompt=REFINE_PROMPT.format(question=question, missing=missing or "more detail"),
+        system=REFINE_SYSTEM,
+        temperature=0.1,
+    )
+    return out.strip().splitlines()[0] if out.strip() else question
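
`critique()` coerces the model's JSON with `bool(...)` and `float(...)`. One thing worth knowing when reading it: in Python `bool("false")` is `True`, so a model that returns the string `"false"` instead of a JSON boolean would be read as grounded, and a non-numeric confidence such as `"high"` would raise in `float(...)`. The block below is a sketch of a stricter normalisation step. The helper names `to_bool`, `to_confidence` and `normalize_critique` are hypothetical and not part of the commit, and whether `generate_json` can actually return such values depends on `llm/ollama_client.py`, which is not shown here.

```python
from __future__ import annotations

from typing import Any


def to_bool(value: Any) -> bool:
    """Interpret JSON booleans and common string spellings of them."""
    if isinstance(value, str):
        return value.strip().lower() in {"true", "yes", "1"}
    return bool(value)


def to_confidence(value: Any) -> float:
    """Parse a confidence and clamp it to [0.0, 1.0]; unparsable input gives 0.0."""
    try:
        score = float(value)
    except (TypeError, ValueError):
        return 0.0
    if score != score:  # NaN is the only float that is not equal to itself
        return 0.0
    return min(1.0, max(0.0, score))


def normalize_critique(out: Any) -> dict:
    """Return the same four keys critique() produces, with defensive parsing."""
    out = out if isinstance(out, dict) else {}
    return {
        "grounded": to_bool(out.get("grounded", False)),
        "complete": to_bool(out.get("complete", False)),
        "confidence": to_confidence(out.get("confidence", 0.0)),
        "missing": str(out.get("missing") or ""),
    }


# A made-up model reply with stringly-typed fields.
example = normalize_critique(
    {"grounded": "false", "complete": True, "confidence": "0.9", "missing": None}
)
```

Failing closed (ungrounded, confidence 0.0) on malformed output means the agent loop retries rather than accepting an answer it could not verify.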
agent/loop.py
ADDED
@@ -0,0 +1,198 @@
+"""The agentic RAG loop.
+
+Pipeline:
+  1. Self-RAG router: RETRIEVE / ANSWER_DIRECTLY / CLARIFY
+  2. (RETRIEVE branch) plan -> tools -> answer -> self-critique
+  3. If confidence < threshold and budget left: refine and retry
+"""
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any
+
+from config import AGENT_CONFIG
+from agent.critic import critique, refine_query
+from agent.planner import plan
+from agent.router import route
+from agent.tools import ToolResult, vector_search
+from llm.ollama_client import OllamaClient
+from retrieval.dense import Hit
+
+
+@dataclass
+class TraceStep:
+    kind: str
+    detail: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class AgentResult:
+    answer: str
+    citations: list[dict]
+    confidence: float
+    trace: list[TraceStep]
+    iterations: int
+    route: str
+
+
+ANSWER_SYSTEM = (
+    "You are a careful research assistant. Use ONLY the provided passages to "
+    "answer the question. Cite sources inline with [N] where N is the passage "
+    "number. If the passages are insufficient, say so explicitly."
+)
+
+ANSWER_PROMPT = """Question: {question}
+
+Passages:
+{context}
+
+Write a concise, well-grounded answer. Use inline citations like [1], [2] that
+match the passage numbers above. If multiple passages support a claim, cite
+them all. If the passages do not contain enough information, say so plainly.
+"""
+
+
+def _format_context_block(hits: list[Hit]) -> tuple[str, list[dict]]:
+    lines = []
+    citations = []
+    for i, h in enumerate(hits, start=1):
+        meta = h.metadata
+        title = meta.get("title") or meta.get("source_path", "?")
+        pages = f"p.{meta.get('page_start')}-{meta.get('page_end')}"
+        head = f"[{i}] {title} ({pages})"
+        lines.append(f"{head}\n{h.text}")
+        citations.append(
+            {
+                "n": i,
+                "chunk_id": h.chunk_id,
+                "title": title,
+                "source_path": meta.get("source_path"),
+                "page_start": meta.get("page_start"),
+                "page_end": meta.get("page_end"),
+                "score": float(h.score),
+            }
+        )
+    return "\n\n".join(lines), citations
+
+
+def _dedupe_hits(hits: list[Hit], limit: int) -> list[Hit]:
+    seen: set[str] = set()
+    out: list[Hit] = []
+    for h in hits:
+        if h.chunk_id in seen:
+            continue
+        seen.add(h.chunk_id)
+        out.append(h)
+        if len(out) >= limit:
+            break
+    return out
+
+
+def run_agent(query: str, llm: OllamaClient | None = None) -> AgentResult:
+    llm = llm or OllamaClient()
+    trace: list[TraceStep] = []
+
+    # 1. Router
+    decision = route(query, llm=llm)
+    trace.append(TraceStep("router", decision))
+
+    if decision["action"] == "ANSWER_DIRECTLY":
+        ans = llm.generate(
+            prompt=query,
+            system="You are a helpful research assistant. Be concise.",
+            temperature=0.2,
+        )
+        return AgentResult(
+            answer=ans,
+            citations=[],
+            confidence=1.0,
+            trace=trace,
+            iterations=0,
+            route="ANSWER_DIRECTLY",
+        )
+
+    if decision["action"] == "CLARIFY":
+        ans = llm.generate(
+            prompt=(
+                "The user asked: " + query + "\n\n"
+                "It is too ambiguous to answer well. Ask one short clarifying "
+                "question to narrow it down."
+            ),
+            system="You are a helpful research assistant.",
+            temperature=0.2,
+        )
+        return AgentResult(
+            answer=ans,
+            citations=[],
+            confidence=0.0,
+            trace=trace,
+            iterations=0,
+            route="CLARIFY",
+        )
+
+    # 2. RETRIEVE branch — agentic loop
+    current_query = query
+    last_critique: dict[str, Any] = {}
+    accumulated: list[Hit] = []
+
+    for iteration in range(AGENT_CONFIG["max_iterations"]):
+        prior_summary = ""
+        if accumulated:
+            titles = sorted({h.metadata.get("title", "?") for h in accumulated})
+            prior_summary = "Already gathered passages from: " + ", ".join(titles)
+
+        steps = plan(current_query, prior_summary=prior_summary, llm=llm)
+        trace.append(TraceStep("plan", {"iteration": iteration, "steps": steps}))
+
+        for step in steps:
+            tool_res: ToolResult = vector_search(step["query"])
+            accumulated.extend(tool_res.hits)
+            trace.append(
+                TraceStep(
+                    "tool",
+                    {
+                        "tool": "vector_search",
+                        "query": step["query"],
+                        "n_hits": len(tool_res.hits),
+                        "top_titles": [
+                            h.metadata.get("title") for h in tool_res.hits[:3]
+                        ],
+                    },
+                )
+            )
+
+        unique_hits = _dedupe_hits(accumulated, limit=8)
+        context_block, citations = _format_context_block(unique_hits)
+
+        answer = llm.generate(
+            prompt=ANSWER_PROMPT.format(question=query, context=context_block),
+            system=ANSWER_SYSTEM,
+            temperature=0.1,
+        )
+        trace.append(TraceStep("answer", {"iteration": iteration, "n_passages": len(unique_hits)}))
+
+        crit = critique(query, answer, context_block, llm=llm)
+        last_critique = crit
+        trace.append(TraceStep("critique", {"iteration": iteration, **crit}))
+
+        if crit["confidence"] >= AGENT_CONFIG["confidence_threshold"] and crit["grounded"]:
+            return AgentResult(
+                answer=answer,
+                citations=citations,
+                confidence=crit["confidence"],
+                trace=trace,
+                iterations=iteration + 1,
+                route="RETRIEVE",
+            )
+
+        current_query = refine_query(query, crit.get("missing", ""), llm=llm)
+        trace.append(TraceStep("refine", {"new_query": current_query}))
+
+    return AgentResult(
+        answer=answer,
+        citations=citations,
+        confidence=last_critique.get("confidence", 0.0),
+        trace=trace,
+        iterations=AGENT_CONFIG["max_iterations"],
+        route="RETRIEVE",
+    )
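
The RETRIEVE branch of `run_agent` is easier to follow with the LLM and retriever swapped for stubs. The block below is a sketch of the same control flow (retrieve, answer, critique, then either stop or refine and retry). Everything in it is illustrative: `run_loop`, the `fake_*` stubs, the tiny corpus, and the defaults `max_iterations=3` and `confidence_threshold=0.7` are assumptions, since `config.py` is not shown in this part of the diff.

```python
from __future__ import annotations

from typing import Callable


def run_loop(
    question: str,
    retrieve: Callable[[str], list[str]],
    answer: Callable[[str, list[str]], str],
    critique: Callable[[str, str], dict],
    refine: Callable[[str, str], str],
    max_iterations: int = 3,
    confidence_threshold: float = 0.7,
) -> dict:
    """Retrieve, answer, self-critique; refine the query and retry while budget remains."""
    passages: list[str] = []
    query = question
    # Bound before the loop so a zero-iteration budget still returns a result.
    draft = ""
    verdict = {"grounded": False, "confidence": 0.0, "missing": ""}
    for iteration in range(max_iterations):
        for passage in retrieve(query):
            if passage not in passages:  # accumulate without duplicates
                passages.append(passage)
        draft = answer(question, passages)
        verdict = critique(question, draft)
        if verdict["grounded"] and verdict["confidence"] >= confidence_threshold:
            return {"answer": draft, "iterations": iteration + 1, "confidence": verdict["confidence"]}
        query = refine(question, verdict.get("missing", ""))
    return {"answer": draft, "iterations": max_iterations, "confidence": verdict["confidence"]}


# Stubs with made-up data: the first search misses a detail, the refined one finds it.
corpus = {
    "what is rrf": ["RRF sums reciprocal ranks."],
    "rrf constant k": ["k damps top ranks."],
}


def fake_retrieve(query: str) -> list[str]:
    return corpus.get(query, [])


def fake_answer(question: str, passages: list[str]) -> str:
    return " ".join(passages)


def fake_critique(question: str, draft: str) -> dict:
    ok = "k damps" in draft
    return {"grounded": ok, "confidence": 0.9 if ok else 0.3, "missing": "" if ok else "rrf constant k"}


def fake_refine(question: str, missing: str) -> str:
    return missing or question


result = run_loop("what is rrf", fake_retrieve, fake_answer, fake_critique, fake_refine)
```

Two details of the committed loop are worth keeping in mind when comparing it with this sketch. `answer` and `citations` are first assigned inside the `for` loop, so the final `return` relies on `AGENT_CONFIG["max_iterations"]` being at least 1. And `_dedupe_hits(accumulated, limit=8)` keeps the first eight unique chunks in accumulation order, so if the first iteration already yields eight or more, hits from a refined query do not reach the context.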
agent/planner.py
ADDED
@@ -0,0 +1,52 @@
+"""Multi-step query planner. Break a question into focused sub-queries."""
+from __future__ import annotations
+
+from config import AGENT_CONFIG
+from llm.ollama_client import OllamaClient
+
+PLANNER_SYSTEM = (
+    "You are a research planner. Given a user question, decompose it into a small "
+    "number of focused sub-queries. Each sub-query targets one piece of information "
+    "needed to answer the original question. Avoid redundant or overly broad steps."
+)
+
+PLANNER_PROMPT = """Decompose the user question into 1-{max_steps} focused retrieval sub-queries.
+Use fewer steps when the question is simple; only use multiple steps for genuinely
+multi-part or comparative questions.
+
+Each sub-query should be a self-contained search query (10-20 words) phrased to
+match passages in academic papers.
+
+Respond with strict JSON only:
+{{"steps": [
+  {{"query": "<search query>", "rationale": "<what this sub-query is looking for>"}}
+]}}
+
+User question: {query}
+
+Context already gathered (may be empty):
+{context_summary}
+"""
+
+
+def plan(query: str, prior_summary: str = "", llm: OllamaClient | None = None) -> list[dict]:
+    llm = llm or OllamaClient()
+    out = llm.generate_json(
+        prompt=PLANNER_PROMPT.format(
+            query=query,
+            max_steps=AGENT_CONFIG["max_plan_steps"],
+            context_summary=prior_summary or "(none)",
+        ),
+        system=PLANNER_SYSTEM,
+        temperature=0.1,
+    )
+    steps = out.get("steps") if isinstance(out, dict) else None
+    if not steps or not isinstance(steps, list):
+        return [{"query": query, "rationale": "fallback: use the original question"}]
+    cleaned: list[dict] = []
+    for s in steps[: AGENT_CONFIG["max_plan_steps"]]:
+        if isinstance(s, dict) and s.get("query"):
+            cleaned.append(
+                {"query": str(s["query"]).strip(), "rationale": str(s.get("rationale", "")).strip()}
+            )
+    return cleaned or [{"query": query, "rationale": "fallback"}]
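
Most of `plan()` is validation of whatever JSON the model returns. The sketch below isolates that cleaning step so it can be exercised without an LLM. `clean_steps` is a hypothetical name, and `max_steps=3` is an assumed stand-in for `AGENT_CONFIG["max_plan_steps"]`, whose value is not shown in this diff. One deliberate difference from the committed code: a whitespace-only query is dropped here, whereas `plan()` would keep it as an empty string after `.strip()`.

```python
from __future__ import annotations

from typing import Any


def clean_steps(out: Any, query: str, max_steps: int = 3) -> list[dict]:
    """Validate planner output; fall back to the original question when unusable."""
    fallback = [{"query": query, "rationale": "fallback: use the original question"}]
    steps = out.get("steps") if isinstance(out, dict) else None
    if not isinstance(steps, list) or not steps:
        return fallback
    cleaned: list[dict] = []
    # Truncate first, then filter, mirroring the order used in plan().
    for step in steps[:max_steps]:
        if isinstance(step, dict) and str(step.get("query") or "").strip():
            cleaned.append(
                {
                    "query": str(step["query"]).strip(),
                    "rationale": str(step.get("rationale", "")).strip(),
                }
            )
    return cleaned or fallback


# Made-up planner replies: one unusable, one partially malformed.
from_garbage = clean_steps(None, "How does RRF combine rankings?")
from_mixed = clean_steps(
    {"steps": [{"query": " dense retrieval cosine "}, "junk", {"query": ""}, {"query": "late step"}]},
    "How does RRF combine rankings?",
)
```

Because truncation happens before filtering, malformed entries still consume plan slots; in the second example the valid fourth step is cut off even though only one of the first three survives.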
agent/router.py
ADDED
@@ -0,0 +1,39 @@
+"""Self-RAG router. Decide whether to retrieve, answer directly, or clarify."""
+from __future__ import annotations
+
+from llm.ollama_client import OllamaClient
+
+ROUTER_SYSTEM = (
+    "You are a routing classifier for an AI research assistant whose knowledge base "
+    "contains papers on Transformers, BERT, GPT-3, diffusion (DDPM/DDIM), RAG, "
+    "Self-RAG, HyDE, ViT, CLIP, ReAct, Chain-of-Thought, and an LLM survey. "
+    "Decide how to handle a user query."
+)
+
+ROUTER_PROMPT = """Classify the query into one of three actions:
+
+- "RETRIEVE": the user is asking about substantive content (concepts, methods, comparisons,
+  details from papers). The knowledge base is likely needed. Default to this when unsure.
+- "ANSWER_DIRECTLY": pure conversational/meta queries (greetings, "what can you do",
+  "thanks") that need NO knowledge lookup.
+- "CLARIFY": the query is too ambiguous or under-specified to act on (e.g. "tell me more"
+  with no prior context, "what about that paper" with no referent).
+
+Respond with strict JSON only:
+{{"action": "RETRIEVE" | "ANSWER_DIRECTLY" | "CLARIFY", "reason": "<one short sentence>"}}
+
+Query: {query}
+"""
+
+
+def route(query: str, llm: OllamaClient | None = None) -> dict:
+    llm = llm or OllamaClient()
+    out = llm.generate_json(
+        prompt=ROUTER_PROMPT.format(query=query),
+        system=ROUTER_SYSTEM,
+        temperature=0.0,
+    )
+    action = str(out.get("action", "RETRIEVE")).upper()
+    if action not in {"RETRIEVE", "ANSWER_DIRECTLY", "CLARIFY"}:
+        action = "RETRIEVE"
+    return {"action": action, "reason": out.get("reason", "")}
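
The router's last three lines implement a "default to RETRIEVE" policy: any action outside the three known labels is treated as a retrieval request. The sketch below shows that normalisation on its own, with one addition that is an assumption rather than something in the commit: it also tolerates a non-dict reply, since `route()` calls `out.get(...)` directly and the guarantees of `generate_json` are not visible in this diff. `normalize_route` and `VALID_ACTIONS` are hypothetical names.

```python
from __future__ import annotations

from typing import Any

VALID_ACTIONS = {"RETRIEVE", "ANSWER_DIRECTLY", "CLARIFY"}


def normalize_route(out: Any) -> dict:
    """Map a raw router reply to a known action, defaulting to RETRIEVE."""
    out = out if isinstance(out, dict) else {}
    action = str(out.get("action", "RETRIEVE")).strip().upper()
    if action not in VALID_ACTIONS:
        action = "RETRIEVE"
    return {"action": action, "reason": str(out.get("reason", ""))}


# Made-up replies: lowercase label, unknown label, and a reply that is not an object.
lowercase = normalize_route({"action": "clarify", "reason": "no referent"})
unknown = normalize_route({"action": "SEARCH"})
not_a_dict = normalize_route([])
```

Defaulting to RETRIEVE is the conservative choice for a grounded assistant: a spurious retrieval costs latency, while a spurious ANSWER_DIRECTLY returns an uncited answer with confidence 1.0 in `run_agent`.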
agent/tools.py
ADDED
@@ -0,0 +1,85 @@
+"""Agent tools.
+
+The vector_search tool drives the hybrid retriever. image_reason performs
+multimodal RAG: caption the image, retrieve text by caption+query, then ask
+Qwen3-VL to ground its answer in both image and text.
+"""
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from llm.ollama_client import OllamaClient
+from retrieval.dense import Hit
+from retrieval.pipeline import hybrid_retrieve
+
+
+@dataclass
+class ToolResult:
+    tool: str
+    query: str
+    hits: list[Hit]
+    notes: str = ""
+
+
+def vector_search(query: str, top_n: int | None = None) -> ToolResult:
+    hits = hybrid_retrieve(query, top_n=top_n)
+    return ToolResult(tool="vector_search", query=query, hits=hits)
+
+
+CAPTION_SYSTEM = "You describe images in concise, factual language."
+CAPTION_PROMPT = (
+    "Describe this image in 1-3 sentences. Mention the type of figure (chart, diagram, "
+    "screenshot, photo, equation, etc.), key labels, and the main visual content."
+)
+
+
+def caption_image(image_path: str, llm: OllamaClient | None = None) -> str:
+    llm = llm or OllamaClient()
+    return llm.generate(
+        prompt=CAPTION_PROMPT,
+        system=CAPTION_SYSTEM,
+        images=[image_path],
+        temperature=0.0,
+    )
+
+
+MM_SYSTEM = (
+    "You are a careful research assistant. Answer using ONLY the provided image and "
+    "the cited text passages. If the answer is not supported, say so."
+)
+MM_PROMPT = """Image (provided separately) + question.
+
+Question: {question}
+
+Relevant passages:
+{context}
+
+Answer concisely. When citing a passage, use [N] where N is the passage number.
+"""
+
+
+def image_retrieve_and_reason(
+    image_path: str, query: str, llm: OllamaClient | None = None
+) -> dict:
+    llm = llm or OllamaClient()
+    caption = caption_image(image_path, llm=llm)
+    fused_query = f"{caption} {query}".strip()
+    hits = hybrid_retrieve(fused_query)
+    context_block = _format_context(hits)
+    answer = llm.generate(
+        prompt=MM_PROMPT.format(question=query, context=context_block),
+        system=MM_SYSTEM,
+        images=[image_path],
+        temperature=0.1,
+    )
+    return {"caption": caption, "answer": answer, "hits": hits}
+
+
+def _format_context(hits: list[Hit]) -> str:
+    lines = []
+    for i, h in enumerate(hits, start=1):
+        meta = h.metadata
+        head = f"[{i}] {meta.get('title', meta.get('source_path', '?'))} "
+        head += f"(p.{meta.get('page_start')}-{meta.get('page_end')})"
+        lines.append(f"{head}\n{h.text}")
+    return "\n\n".join(lines)
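
Both `_format_context` here and `_format_context_block` in `agent/loop.py` number the passages from 1 so the model can cite them as `[N]`. The sketch below reproduces that layout with a stand-in `Hit` dataclass, because `retrieval/dense.py` is not shown in this part of the diff; the stand-in only carries the four attributes the diff actually reads (`chunk_id`, `text`, `metadata`, `score`). The titles, page numbers and passage text are made-up sample data.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Hit:
    """Stand-in for retrieval.dense.Hit (assumed shape, not the real class)."""

    chunk_id: str
    text: str
    metadata: dict = field(default_factory=dict)
    score: float = 0.0


def format_context(hits: list[Hit]) -> str:
    """Render hits as numbered passages: '[N] title (p.start-end)' then the text."""
    blocks = []
    for i, h in enumerate(hits, start=1):
        meta = h.metadata
        # Fall back to the file path when a chunk has no title.
        title = meta.get("title") or meta.get("source_path", "?")
        head = f"[{i}] {title} (p.{meta.get('page_start')}-{meta.get('page_end')})"
        blocks.append(f"{head}\n{h.text}")
    return "\n\n".join(blocks)


hits = [
    Hit("a::0", "Attention maps queries to outputs.", {"title": "Paper A", "page_start": 3, "page_end": 4}),
    Hit("b::7", "BM25 ranks by term frequency.", {"source_path": "papers/b.pdf", "page_start": 1, "page_end": 1}),
]
block = format_context(hits)
```

The sketch uses `meta.get("title") or ...`, as `agent/loop.py` does, so an empty or `None` title also falls back to the path; `_format_context` above uses the two-argument `get`, which falls back only when the key is absent.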
app.py
ADDED
|
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""AdaptiveRAG — under-the-hood pipeline visualizer.
|
| 2 |
+
|
| 3 |
+
Run: streamlit run app.py
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
import tempfile
|
| 11 |
+
import time
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
os.environ.setdefault("ANONYMIZED_TELEMETRY", "False")
|
| 15 |
+
|
| 16 |
+
# suppress harmless noise from Streamlit's torch inspector + ChromaDB posthog client
|
| 17 |
+
logging.getLogger("streamlit.watcher.local_sources_watcher").setLevel(logging.ERROR)
|
| 18 |
+
logging.getLogger("chromadb.telemetry.product.posthog").setLevel(logging.ERROR)
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import pandas as pd
|
| 22 |
+
import streamlit as st
|
| 23 |
+
|
| 24 |
+
from agent.critic import critique, refine_query
|
| 25 |
+
from agent.planner import plan
|
| 26 |
+
from agent.router import route
|
| 27 |
+
from agent.tools import image_retrieve_and_reason
|
| 28 |
+
from config import AGENT_CONFIG, EMBEDDING_CONFIG, HOSTED, LLM_CONFIG, PATHS, RETRIEVAL_CONFIG
|
| 29 |
+
from ingestion.embedder import embed_query
|
| 30 |
+
from ingestion.indexer import fetch_embeddings
|
| 31 |
+
from llm.client_factory import get_llm
|
| 32 |
+
from retrieval.dense import Hit, dense_search
|
| 33 |
+
from retrieval.hybrid import reciprocal_rank_fusion
|
| 34 |
+
from retrieval.reranker import rerank
|
| 35 |
+
from retrieval.sparse import sparse_search
|
| 36 |
+
|
| 37 |
+
st.set_page_config(page_title="AdaptiveRAG — Underhood", page_icon="🔬", layout="wide")

# ───────────────────────────── styling ──────────────────────────────
st.markdown(
    """
<style>
.phase-card {
  border-left: 4px solid var(--accent, #4f8cff);
  padding: .6rem 1rem;
  margin: .25rem 0 .5rem 0;
  background: rgba(79,140,255,0.06);
  border-radius: 6px;
}
.phase-num { color: #4f8cff; font-weight: 700; margin-right: .4rem; }
.pill { display: inline-block; padding: .15rem .55rem; border-radius: 999px;
        font-size: .78rem; font-weight: 600; margin-right: .4rem; }
.pill-blue { background: #1e3a5f; color: #9ec5ff; }
.pill-green { background: #1e4f30; color: #a3e6b5; }
.pill-purple { background: #3d2a5e; color: #c8a8f5; }
.pill-amber { background: #5e3f0e; color: #f3c97a; }
.pill-red { background: #5a1f1f; color: #f3a3a3; }
.pill-grey { background: #2c2c33; color: #b8b8c0; }
.chunk-card {
  background: rgba(255,255,255,0.03);
  border: 1px solid rgba(255,255,255,0.08);
  border-radius: 6px; padding: .55rem .7rem; margin-bottom: .4rem;
  font-size: .82rem;
}
.chunk-meta { color: #9aa3b2; font-size: .73rem; margin-bottom: .25rem; }
.mini-vec {
  font-family: 'SF Mono', Menlo, monospace; font-size: .68rem;
  color: #8b949e; word-break: break-all;
}
</style>
""",
    unsafe_allow_html=True,
)


# ───────────────────────────── helpers ──────────────────────────────
@st.cache_resource
def _llm():
    return get_llm()


def _load_manifest() -> dict:
    p = PATHS["manifest_path"]
    return json.loads(p.read_text()) if p.exists() else {}


def phase_header(num: int, title: str, subtitle: str = "") -> None:
    st.markdown(
        f"<div class='phase-card'><span class='phase-num'>STEP {num}</span>"
        f"<b>{title}</b><br><span style='color:#9aa3b2;font-size:.85rem;'>{subtitle}</span>"
        f"</div>",
        unsafe_allow_html=True,
    )


def hits_to_df(hits: list[Hit], score_label: str = "score") -> pd.DataFrame:
    rows = []
    for h in hits:
        title = h.metadata.get("title") or h.metadata.get("source_path", "?")
        short = title.split(" (")[0]
        if len(short) > 38:
            short = short[:35] + "…"
        label = f"{short} · p{h.metadata.get('page_start')} · {h.chunk_id.split('::')[-1]}"
        rows.append({"chunk": label, score_label: float(h.score), "chunk_id": h.chunk_id})
    return pd.DataFrame(rows)


def render_hits(hits: list[Hit], badge_class: str, label: str, max_chars: int = 220) -> None:
    if not hits:
        st.caption(f"_(no {label.lower()} hits)_")
        return
    for i, h in enumerate(hits, start=1):
        meta = h.metadata
        snippet = h.text[:max_chars].replace("\n", " ")
        if len(h.text) > max_chars:
            snippet += "…"
        st.markdown(
            f"<div class='chunk-card'>"
            f"<div class='chunk-meta'>"
            f"<span class='pill {badge_class}'>{label} #{i}</span>"
            f"score <b>{h.score:.3f}</b> · "
            f"{meta.get('title','?')} · p.{meta.get('page_start')}–{meta.get('page_end')} · "
            f"<code>{h.chunk_id}</code>"
            f"</div>{snippet}</div>",
            unsafe_allow_html=True,
        )


def pca_2d(matrix: np.ndarray) -> np.ndarray:
    centered = matrix - matrix.mean(axis=0, keepdims=True)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


def vector_space_plot(query_vec: list[float], fused_hits: list[Hit],
                      dense_ids: set[str], sparse_ids: set[str],
                      kept_ids: set[str]) -> None:
    if not fused_hits:
        st.caption("_(nothing to plot)_")
        return
    embs = fetch_embeddings([h.chunk_id for h in fused_hits])
    rows = []
    vecs = [np.array(query_vec, dtype=np.float32)]
    for h in fused_hits:
        v = embs.get(h.chunk_id)
        if v is None:
            continue
        vecs.append(np.array(v, dtype=np.float32))
        in_d, in_s = h.chunk_id in dense_ids, h.chunk_id in sparse_ids
        in_keep = h.chunk_id in kept_ids
        if in_d and in_s:
            color = "fused (both)"
        elif in_d:
            color = "dense only"
        elif in_s:
            color = "sparse only"
        else:
            color = "other"
        title = (h.metadata.get("title") or "?").split(" (")[0][:40]
        label = f"{title} · p{h.metadata.get('page_start')}"
        rows.append({"label": label, "color": color, "size": 90 if in_keep else 50})
    if len(vecs) < 3:
        st.caption("_(need at least 2 hits for a 2D projection)_")
        return
    proj = pca_2d(np.vstack(vecs))
    df = pd.DataFrame(
        [{"x": proj[0, 0], "y": proj[0, 1], "label": "🔎 your question",
          "color": "QUERY", "size": 220}]
        + [{"x": proj[i + 1, 0], "y": proj[i + 1, 1], **rows[i]}
           for i in range(len(rows))]
    )
    st.scatter_chart(
        df, x="x", y="y", color="color", size="size",
        height=380, use_container_width=True,
    )
    st.caption(
        "PCA projection of the query embedding + fused hit embeddings. "
        "Larger points survived cross-encoder reranking."
    )


def render_embedding_card(query: str, qv: list[float], dt: float) -> None:
    arr = np.array(qv, dtype=np.float32)
    cols = st.columns([1, 1, 1, 3])
    cols[0].metric("Model", EMBEDDING_CONFIG["model"].split("/")[-1])
    cols[1].metric("Dimensions", len(qv))
    cols[2].metric("L2 norm", f"{float(np.linalg.norm(arr)):.3f}")
    cols[3].metric("Embed time", f"{dt*1000:.0f} ms")
    st.caption(f"Question ({len(query)} chars, ~{len(query.split())} words):")
    st.code(query, language="text")
    st.caption("First 32 dimensions of the embedding vector:")
    st.bar_chart(pd.DataFrame({"value": arr[:32]}), height=140, use_container_width=True)
    preview = ", ".join(f"{x:+.3f}" for x in arr[:8]) + ", …"
    st.markdown(f"<span class='mini-vec'>vector[0:8] = [{preview}]</span>",
                unsafe_allow_html=True)


# ───────────────────────────── pipeline view ──────────────────────────────
def visual_pipeline(query: str) -> None:
    llm = _llm()

    # ── Step 1: embed the question ────────────────────────────────
    phase_header(1, "Question encoding",
                 "Convert text → 384-dim dense vector via sentence-transformers (MiniLM-L6).")
    t0 = time.time()
    qv = embed_query(query)
    render_embedding_card(query, qv, time.time() - t0)

    # ── Step 2: Self-RAG router ────────────────────────────────
    phase_header(2, "Self-RAG router",
                 "Decide whether to RETRIEVE, ANSWER_DIRECTLY, or CLARIFY before touching the index.")
    t0 = time.time()
    decision = route(query, llm=llm)
    dt = time.time() - t0
    pill_map = {"RETRIEVE": "pill-blue", "ANSWER_DIRECTLY": "pill-green", "CLARIFY": "pill-amber"}
    pill = pill_map.get(decision["action"], "pill-grey")
    st.markdown(
        f"<span class='pill {pill}'>{decision['action']}</span>"
        f"<span style='color:#9aa3b2;'>{decision.get('reason','')}</span>"
        f"<span style='float:right;color:#9aa3b2;font-size:.78rem;'>"
        f"router latency: {dt*1000:.0f} ms</span>",
        unsafe_allow_html=True,
    )

    if decision["action"] == "ANSWER_DIRECTLY":
        st.markdown("### Direct answer (no retrieval)")
        ans = llm.generate(prompt=query,
                           system="You are a helpful research assistant. Be concise.",
                           temperature=0.2)
        st.markdown(ans)
        return
    if decision["action"] == "CLARIFY":
        st.markdown("### Clarifying question")
        ans = llm.generate(
            prompt=("The user asked: " + query +
                    "\n\nIt is too ambiguous to answer well. Ask one short clarifying question."),
            system="You are a helpful research assistant.",
            temperature=0.2,
        )
        st.markdown(ans)
        return

    # ── Iterations of plan → retrieve → answer → critique ────────
    accumulated: list[Hit] = []
    current_query = query

    for it in range(AGENT_CONFIG["max_iterations"]):
        st.markdown(f"---\n## 🔁 Iteration {it + 1}")
        if current_query != query:
            st.info(f"Refined query → **{current_query}**")

        # ── Step 3: plan ─────────────────────────────────────
        phase_header(3, "Planner", "LLM decomposes the question into focused sub-queries.")
        prior = ""
        if accumulated:
            titles = sorted({h.metadata.get("title", "?") for h in accumulated})
            prior = "Already gathered passages from: " + ", ".join(titles)
        t0 = time.time()
        steps = plan(current_query, prior_summary=prior, llm=llm)
        dt = time.time() - t0
        st.caption(f"Generated {len(steps)} sub-quer{'y' if len(steps)==1 else 'ies'} in {dt*1000:.0f} ms")
        for i, s in enumerate(steps, start=1):
            st.markdown(
                f"<div class='chunk-card'>"
                f"<span class='pill pill-purple'>sub-query {i}</span>"
                f"<b>{s['query']}</b>"
                f"<div class='chunk-meta' style='margin-top:.3rem;'>"
                f"rationale: {s.get('rationale','—')}</div></div>",
                unsafe_allow_html=True,
            )

        # ── Step 4: retrieval per sub-query ──────────────────
        phase_header(
            4,
            "Hybrid retrieval per sub-query",
            f"Dense (Chroma cosine, k={RETRIEVAL_CONFIG['dense_k']}) ∥ "
            f"Sparse (BM25, k={RETRIEVAL_CONFIG['sparse_k']}) → "
            f"Reciprocal Rank Fusion → Cross-encoder rerank "
            f"(BGE, top {RETRIEVAL_CONFIG['rerank_top_n']}).",
        )

        for si, step in enumerate(steps, start=1):
            with st.expander(f"Sub-query {si}: {step['query']}", expanded=(si == 1)):
                t0 = time.time()
                dense_hits = dense_search(step["query"])
                t_dense = time.time() - t0
                t0 = time.time()
                sparse_hits = sparse_search(step["query"])
                t_sparse = time.time() - t0
                t0 = time.time()
                fused = reciprocal_rank_fusion([dense_hits, sparse_hits],
                                               top_k=max(RETRIEVAL_CONFIG["dense_k"],
                                                         RETRIEVAL_CONFIG["sparse_k"]))
                t_fuse = time.time() - t0
                t0 = time.time()
                reranked = rerank(step["query"], fused)
                t_rerank = time.time() - t0

                m1, m2, m3, m4 = st.columns(4)
                m1.metric("Dense hits", len(dense_hits), f"{t_dense*1000:.0f} ms")
                m2.metric("Sparse hits", len(sparse_hits), f"{t_sparse*1000:.0f} ms")
                m3.metric("After RRF", len(fused), f"{t_fuse*1000:.0f} ms")
                m4.metric("After rerank", len(reranked), f"{t_rerank*1000:.0f} ms")

                tabs = st.tabs([
                    "🔵 Dense (vectors)",
                    "🟢 Sparse (BM25)",
                    "🟣 RRF fusion",
                    "🟡 Cross-encoder rerank",
                    "🗺️ Vector space",
                ])
                with tabs[0]:
                    st.caption("Top-K nearest neighbors by cosine similarity.")
                    if dense_hits:
                        st.bar_chart(hits_to_df(dense_hits, "cosine_sim"),
                                     x="chunk", y="cosine_sim",
                                     height=260, use_container_width=True)
                    render_hits(dense_hits[:5], "pill-blue", "DENSE")

                with tabs[1]:
                    st.caption("Top-K BM25 keyword matches (normalized).")
                    if sparse_hits:
                        st.bar_chart(hits_to_df(sparse_hits, "bm25_norm"),
                                     x="chunk", y="bm25_norm",
                                     height=260, use_container_width=True)
                    render_hits(sparse_hits[:5], "pill-green", "BM25")

                with tabs[2]:
                    st.caption(
                        "Reciprocal Rank Fusion: score(d) = Σ 1/(k + rank). "
                        "Combines dense + sparse rankings into one merged list."
                    )
                    if fused:
                        st.bar_chart(hits_to_df(fused[:12], "rrf_score"),
                                     x="chunk", y="rrf_score",
                                     height=280, use_container_width=True)
                    render_hits(fused[:5], "pill-purple", "FUSED")

                with tabs[3]:
                    st.caption(
                        "Cross-encoder scores (query, chunk) jointly — much more "
                        "accurate than bi-encoder cosine, but slower → only run on "
                        "the fused candidate set."
                    )
                    if reranked:
                        st.bar_chart(hits_to_df(reranked, "ce_score"),
                                     x="chunk", y="ce_score",
                                     height=240, use_container_width=True)
                    render_hits(reranked, "pill-amber", "RERANKED")

                with tabs[4]:
                    dense_ids = {h.chunk_id for h in dense_hits}
                    sparse_ids = {h.chunk_id for h in sparse_hits}
                    kept_ids = {h.chunk_id for h in reranked}
                    vector_space_plot(qv, fused[:20], dense_ids, sparse_ids, kept_ids)

            accumulated.extend(reranked)

        # ── Step 5: answer ─────────────────────────────────────
        # Dedupe + cap to 8 passages for the final prompt
        seen: set[str] = set()
        unique: list[Hit] = []
        for h in accumulated:
            if h.chunk_id in seen:
                continue
            seen.add(h.chunk_id)
            unique.append(h)
            if len(unique) >= 8:
                break
        context_lines, citations = [], []
        for i, h in enumerate(unique, start=1):
            meta = h.metadata
            head = (f"[{i}] {meta.get('title','?')} "
                    f"(p.{meta.get('page_start')}-{meta.get('page_end')})")
            context_lines.append(f"{head}\n{h.text}")
            citations.append({
                "n": i, "chunk_id": h.chunk_id,
                "title": meta.get("title"),
                "source_path": meta.get("source_path"),
                "page_start": meta.get("page_start"),
                "page_end": meta.get("page_end"),
                "score": float(h.score),
            })
        context_block = "\n\n".join(context_lines)

        phase_header(5, "Context assembly + answer generation",
                     f"Top {len(unique)} unique passages → {LLM_CONFIG['model']} via {LLM_CONFIG['provider']}.")
        with st.expander("📦 Context handed to the LLM", expanded=False):
            for c in citations:
                st.markdown(
                    f"**[{c['n']}]** {c['title']} · pages {c['page_start']}–{c['page_end']} · "
                    f"score `{c['score']:.3f}`"
                )
            st.code(context_block[:3000] + ("…" if len(context_block) > 3000 else ""),
                    language="text")

        t0 = time.time()
        ANSWER_SYSTEM = (
            "You are a careful research assistant. Use ONLY the provided passages to "
            "answer the question. Cite sources inline with [N] where N is the passage "
            "number. If the passages are insufficient, say so explicitly."
        )
        ANSWER_PROMPT = (
            f"Question: {query}\n\nPassages:\n{context_block}\n\n"
            "Write a concise, well-grounded answer. Use inline citations like [1], [2] "
            "that match the passage numbers above."
        )
        answer = llm.generate(prompt=ANSWER_PROMPT, system=ANSWER_SYSTEM, temperature=0.1)
        st.caption(f"LLM generation: {time.time()-t0:.1f} s")
        st.markdown("### Answer")
        st.markdown(answer)

        st.markdown("### Citations")
        for c in citations:
            st.markdown(
                f"**[{c['n']}]** {c['title']} — pages {c['page_start']}–{c['page_end']} "
                f"· score `{c['score']:.3f}` · `{Path(c['source_path']).name}`"
            )

        # ── Step 6: critic ─────────────────────────────────────
        phase_header(6, "Self-critique",
                     "LLM scores its own answer for grounding + completeness.")
        t0 = time.time()
        crit = critique(query, answer, context_block, llm=llm)
        c1, c2, c3 = st.columns(3)
        c1.metric("Grounded", "✅ yes" if crit["grounded"] else "⚠️ no")
        c2.metric("Complete", "✅ yes" if crit["complete"] else "⚠️ no")
        c3.metric("Confidence", f"{crit['confidence']:.2f}",
                  delta=f"threshold {AGENT_CONFIG['confidence_threshold']:.2f}")
        if crit.get("missing"):
            st.warning(f"Missing: {crit['missing']}")
        st.caption(f"Critique latency: {time.time()-t0:.1f} s")

        if crit["confidence"] >= AGENT_CONFIG["confidence_threshold"] and crit["grounded"]:
            st.success(f"✓ Confidence {crit['confidence']:.2f} ≥ threshold — answer accepted.")
            return

        if it < AGENT_CONFIG["max_iterations"] - 1:
            st.warning("Confidence below threshold — refining query and retrying.")
            current_query = refine_query(query, crit.get("missing", ""), llm=llm)
        else:
            st.error("Max iterations reached. Returning best-effort answer.")


# ───────────────────────────── sidebar + tabs ──────────────────────────────
def _sidebar() -> None:
    st.sidebar.title("AdaptiveRAG")
    st.sidebar.caption("Agentic + Self-RAG + Modular RAG")
    llm = _llm()
    ok = llm.health()
    backend = "Groq API" if HOSTED else "Ollama (local)"
    st.sidebar.markdown(f"**LLM backend**: {'🟢' if ok else '🔴'} {backend}")
    st.sidebar.markdown(f"**Model**: `{LLM_CONFIG['model']}`")
    st.sidebar.markdown(f"**Embedder**: `{EMBEDDING_CONFIG['model'].split('/')[-1]}`")
    st.sidebar.markdown(f"**Reranker**: `bge-reranker-base`")
    manifest = _load_manifest()
    if manifest:
        st.sidebar.markdown(f"**Index**: {manifest.get('n_chunks','?')} chunks across "
                            f"{len(manifest.get('chunks_per_doc',{}))} docs")
        with st.sidebar.expander("Documents"):
            for doc, n in sorted(manifest.get("chunks_per_doc", {}).items()):
                st.markdown(f"- `{doc}` — {n}")
    else:
        st.sidebar.warning("No index found. Run `python ingest.py --reset`.")
    st.sidebar.divider()
    st.sidebar.markdown("### Pipeline")
    st.sidebar.code(
        "question\n ↓ embed (MiniLM)\n ↓ Self-RAG router\n ↓ planner → sub-queries\n"
        " ↓ dense ∥ sparse\n ↓ RRF fusion\n ↓ cross-encoder rerank\n ↓ Qwen3-VL answer\n"
        " ↓ self-critique → retry?\n → answer + citations",
        language="text",
    )


def pipeline_tab() -> None:
    st.subheader("🔬 Underhood: watch every stage of the agentic RAG pipeline")
    st.caption(
        "Each step renders its inputs and outputs as it runs — embedding vector, "
        "router decision, planner sub-queries, dense vs sparse hits side-by-side, "
        "RRF fusion, cross-encoder rerank, vector-space projection, answer, self-critique."
    )
    samples = [
        "How does Self-RAG decide when to retrieve, and what reflection tokens does it use?",
        "Compare DDPM and DDIM sampling — what does DDIM gain by being non-Markovian?",
        "What is multi-head self-attention and why does parallelism matter?",
        "How does HyDE improve dense retrieval without relevance labels?",
        "How does ReAct combine reasoning and acting, vs chain-of-thought?",
        "hello, what can you do?",
    ]
    if "vq" not in st.session_state:
        st.session_state.vq = samples[0]
    cols = st.columns(3)
    for i, s in enumerate(samples):
        if cols[i % 3].button(s, key=f"vs{i}", use_container_width=True):
            st.session_state.vq = s
    q = st.text_area("Question", value=st.session_state.vq, height=80, key="vq_input")
    if st.button("▶ Run pipeline", type="primary"):
        if q.strip():
            visual_pipeline(q.strip())


def image_tab() -> None:
    st.subheader("🖼️ Multimodal RAG (Qwen3-VL)")
    st.caption(
        "Upload an image (e.g. a figure from a paper). Qwen3-VL captions it, the "
        "caption + question drives hybrid retrieval, then the model reasons over "
        "image + retrieved passages together."
    )
    uploaded = st.file_uploader("Image", type=["png", "jpg", "jpeg", "webp"])
    q = st.text_input("Question about the image", "Explain what this figure shows.")
    go = st.button("Reason", type="primary", key="img_go")
    if uploaded:
        st.image(uploaded, width=400)
    if not (go and uploaded):
        return
    with tempfile.NamedTemporaryFile(suffix=Path(uploaded.name).suffix, delete=False) as f:
        f.write(uploaded.getbuffer())
        tmp_path = f.name
    try:
        with st.spinner("Captioning → retrieving → multimodal reasoning..."):
            out = image_retrieve_and_reason(tmp_path, q, llm=_llm())
        st.markdown("### Caption")
        st.write(out["caption"])
        st.markdown("### Answer")
        st.markdown(out["answer"])
        st.markdown("### Retrieved passages")
        for i, h in enumerate(out["hits"], start=1):
            st.markdown(
                f"**[{i}]** {h.metadata.get('title')} "
                f"(p.{h.metadata.get('page_start')}–{h.metadata.get('page_end')}) "
                f"· score `{h.score:.3f}`"
            )
            st.caption(h.text[:300] + ("…" if len(h.text) > 300 else ""))
    finally:
        os.unlink(tmp_path)


def main() -> None:
    _sidebar()
    st.title("AdaptiveRAG 📚🔬")
    st.caption(
        "Agentic + Self-RAG + Modular RAG over your local paper library — "
        f"powered by `{LLM_CONFIG['model']}` via **{LLM_CONFIG['provider']}**. "
        "Every pipeline stage is exposed below."
    )
    pipe, img = st.tabs(["🔬 Underhood pipeline", "🖼️ Image Q&A (multimodal)"])
    with pipe:
        pipeline_tab()
    with img:
        image_tab()


if __name__ == "__main__":
    main()
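
The RRF tab caption in app.py summarizes Reciprocal Rank Fusion as score(d) = Σ 1/(k + rank). The sketch below applies that formula to plain lists of chunk ids. It is an illustration only: `rrf_fuse` and its arguments are hypothetical names, and the repository's `reciprocal_rank_fusion` in retrieval/hybrid.py (not shown here) may differ, for instance by operating on `Hit` objects. The default `k=60` is assumed to correspond to `rrf_k` in config.py.

```python
from collections import defaultdict


def rrf_fuse(rankings, k=60, top_k=None):
    """Fuse ranked id lists: score(d) = sum over lists of 1 / (k + rank).

    Hypothetical helper for illustration; ranks are 1-based.
    """
    scores = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # Highest fused score first; ties broken by id so the order is deterministic.
    fused = sorted(scores.items(), key=lambda kv: (-kv[1], kv[0]))
    return fused if top_k is None else fused[:top_k]


# Toy rankings standing in for dense and sparse results.
dense = ["a", "b", "c"]
sparse = ["b", "d", "a"]
fused = rrf_fuse([dense, sparse], k=60)
for doc_id, score in fused:
    print(f"{doc_id}: {score:.5f}")
```

Ids that appear in both lists collect two reciprocal terms, which is why a chunk ranked moderately well by both retrievers can outrank one that only a single retriever found.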
ask.py
ADDED
@@ -0,0 +1,32 @@
"""CLI: python ask.py 'your question here'"""
from __future__ import annotations

import argparse
import json

from agent.loop import run_agent


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("question", nargs="+")
    ap.add_argument("--trace", action="store_true")
    args = ap.parse_args()
    q = " ".join(args.question)
    res = run_agent(q)
    print("\n=== ROUTE ===")
    print(res.route)
    print("\n=== ANSWER ===")
    print(res.answer)
    print("\n=== CITATIONS ===")
    for c in res.citations:
        print(f" [{c['n']}] {c['title']} (p.{c['page_start']}-{c['page_end']}) score={c['score']:.3f}")
    print(f"\nConfidence: {res.confidence:.2f} iterations: {res.iterations}")
    if args.trace:
        print("\n=== TRACE ===")
        for s in res.trace:
            print(f" • {s.kind}: {json.dumps(s.detail, default=str)[:200]}")


if __name__ == "__main__":
    main()
config.py
ADDED
@@ -0,0 +1,57 @@
"""Central configuration for AdaptiveRAG."""
import os
from pathlib import Path

ROOT = Path(__file__).parent.resolve()

# Detect hosting environment
HOSTED = bool(os.environ.get("GROQ_API_KEY"))

LLM_CONFIG = {
    "provider": "groq" if HOSTED else "ollama",
    "model": "llama-3.1-8b-instant" if HOSTED else "qwen3-vl:8b-instruct-q8_0-optimized",
    "base_url": "https://api.groq.com/openai/v1" if HOSTED else "http://localhost:11434",
    "temperature": 0.1,
    "timeout": 60 if HOSTED else 180,
    "num_ctx": 8192,
}

EMBEDDING_CONFIG = {
    "model": "sentence-transformers/all-MiniLM-L6-v2",
    "device": "cpu",
    "batch_size": 32,
}

RERANKER_CONFIG = {
    "model": "BAAI/bge-reranker-base",
    "device": "cpu",
}

CHUNKING_CONFIG = {
    "target_chunk_chars": 1400,
    "max_chunk_chars": 2200,
    "min_chunk_chars": 350,
    "overlap_chars": 200,
}

RETRIEVAL_CONFIG = {
    "dense_k": 12,
    "sparse_k": 12,
    "rrf_k": 60,
    "rerank_top_n": 5,
}

AGENT_CONFIG = {
    "max_iterations": 3,
    "confidence_threshold": 0.85,
    "max_plan_steps": 3,
}

PATHS = {
    "papers_dir": ROOT / "papers",
    "chroma_dir": ROOT / "storage" / "chroma",
    "bm25_path": ROOT / "storage" / "bm25.pkl",
    "manifest_path": ROOT / "storage" / "manifest.json",
}

CHROMA_COLLECTION = "adaptive_rag"
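
In app.py, `AGENT_CONFIG["confidence_threshold"]` and `AGENT_CONFIG["max_iterations"]` gate the critique loop: an answer is accepted only when the critic reports it as grounded and its confidence reaches the threshold; otherwise the query is refined until the iteration budget runs out. A minimal sketch of that gate follows. `should_accept` and `run_loop` are hypothetical helpers, and pre-recorded critic outputs stand in for the LLM calls.

```python
def should_accept(crit, threshold=0.85):
    """Accept only answers that are grounded AND confident enough."""
    return bool(crit.get("grounded")) and float(crit.get("confidence", 0.0)) >= threshold


def run_loop(critiques, max_iterations=3, threshold=0.85):
    """Replay recorded critic outputs; return (accepted, iterations_used)."""
    used = 0
    for crit in critiques[:max_iterations]:
        used += 1
        if should_accept(crit, threshold):
            return True, used
        # A real loop would refine the query here before the next pass.
    return False, used


recorded = [
    {"grounded": True, "confidence": 0.60},  # below threshold, so retry
    {"grounded": True, "confidence": 0.90},  # accepted on the second pass
]
print(run_loop(recorded))
```

Requiring both conditions means a confident but ungrounded answer still triggers a retry, which matches the acceptance check in `visual_pipeline`.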
download_papers.sh
ADDED
@@ -0,0 +1,28 @@
#!/bin/bash
mkdir -p papers && cd papers

# Transformers
curl -L "https://arxiv.org/pdf/1706.03762" -o "01_attention_is_all_you_need.pdf"
curl -L "https://arxiv.org/pdf/1810.04805" -o "02_bert.pdf"
curl -L "https://arxiv.org/pdf/2005.14165" -o "03_gpt3.pdf"

# Diffusion
curl -L "https://arxiv.org/pdf/2006.11239" -o "04_ddpm.pdf"
curl -L "https://arxiv.org/pdf/2010.02502" -o "05_ddim.pdf"

# RAG
curl -L "https://arxiv.org/pdf/2005.11401" -o "06_rag_original.pdf"
curl -L "https://arxiv.org/pdf/2312.10997" -o "07_rag_survey.pdf"
curl -L "https://arxiv.org/pdf/2310.11511" -o "08_self_rag.pdf"
curl -L "https://arxiv.org/pdf/2212.10496" -o "09_hyde.pdf"

# Vision
curl -L "https://arxiv.org/pdf/2010.11929" -o "10_vit.pdf"
curl -L "https://arxiv.org/pdf/2103.00020" -o "11_clip.pdf"

# Agents
curl -L "https://arxiv.org/pdf/2210.03629" -o "12_react.pdf"
curl -L "https://arxiv.org/pdf/2201.11903" -o "13_chain_of_thought.pdf"
curl -L "https://arxiv.org/pdf/2303.18223" -o "14_llm_survey.pdf"

echo "Downloaded $(ls *.pdf | wc -l) papers"
ingest.py
ADDED
@@ -0,0 +1,43 @@
"""Ingest all PDFs from the papers/ directory into ChromaDB + BM25."""
from __future__ import annotations

import argparse
import time

from config import PATHS
from ingestion.chunker import chunk_document
from ingestion.indexer import index_chunks
from ingestion.loader import discover_pdfs, load_pdf


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--reset", action="store_true", help="Reset the index first")
    ap.add_argument("--papers-dir", default=str(PATHS["papers_dir"]))
    args = ap.parse_args()

    pdfs = discover_pdfs(args.papers_dir)
    if not pdfs:
        print(f"No PDFs in {args.papers_dir}")
        return

    print(f"Found {len(pdfs)} PDFs in {args.papers_dir}")
    all_chunks = []
    t0 = time.time()
    for path in pdfs:
        doc_id = path.stem
        print(f"\n[{doc_id}]")
        doc = load_pdf(path)
        print(f" Loaded: {len(doc.pages)} pages, title={doc.title!r}")
        chunks = chunk_document(doc, doc_id=doc_id)
        print(f" Chunked: {len(chunks)} chunks (avg {sum(len(c.text) for c in chunks)//max(len(chunks),1)} chars)")
        all_chunks.extend(chunks)

    print(f"\nIndexing {len(all_chunks)} chunks total...")
    manifest = index_chunks(all_chunks, reset=args.reset)
    dt = time.time() - t0
    print(f"\nDone in {dt:.1f}s. Manifest: {manifest}")


if __name__ == "__main__":
    main()
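
The ingest loop above relies on `chunk_document` to turn each document's sentences into overlapping chunks sized by `CHUNKING_CONFIG`. A simplified sketch of greedy sentence grouping with a character-based overlap carry follows. `group_sentences` is a hypothetical name, and the sketch deliberately leaves out page tracking, heading hints, and the short-tail merge that the repository's chunker handles.

```python
def group_sentences(sentences, target=1400, overlap=200):
    """Greedily pack sentences into chunks of roughly `target` characters.

    After each flush, trailing sentences that fit within `overlap`
    characters are carried into the next chunk to preserve context.
    """
    chunks, buf, buf_len = [], [], 0
    for sent in sentences:
        if buf and buf_len + len(sent) + 1 > target:
            chunks.append(" ".join(buf))
            carry, carry_len = [], 0
            for prev in reversed(buf):
                if carry_len + len(prev) + 1 > overlap:
                    break
                carry.insert(0, prev)
                carry_len += len(prev) + 1
            buf, buf_len = carry, carry_len
        buf.append(sent)
        buf_len += len(sent) + 1
    if buf:
        chunks.append(" ".join(buf))
    return chunks


# Tiny sizes so the overlap is easy to see.
demo = group_sentences(["aaaa.", "bbbb.", "cccc.", "dddd."], target=12, overlap=6)
print(demo)
```

With these toy sizes each chunk repeats the last sentence of the previous one, which is the effect the overlap setting is meant to produce at full scale.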
ingestion/__init__.py
ADDED
File without changes
ingestion/chunker.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Semantic-aware chunking.

Strategy: split each page into sentences, then greedily group sentences into
chunks targeting CHUNKING_CONFIG['target_chunk_chars']. Carry an overlap of
the last few sentences (~overlap_chars) to the next chunk so context isn't
sliced mid-thought. Headings hint chunk boundaries.
"""
from __future__ import annotations

import re
from dataclasses import dataclass

from config import CHUNKING_CONFIG
from ingestion.loader import LoadedDoc


@dataclass
class Chunk:
    chunk_id: str
    doc_id: str
    source_path: str
    title: str
    page_start: int
    page_end: int
    text: str


_SENTENCE_SPLIT = re.compile(r"(?<=[.!?])\s+(?=[A-Z0-9(])")
_HEADING_HINT = re.compile(r"^(?:[0-9]+(?:\.[0-9]+)*\s+|abstract|introduction|conclusion|references|method|results|discussion|background)\b", re.IGNORECASE)


def _split_sentences(text: str) -> list[str]:
    parts: list[str] = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        # Heading-like lines are kept whole so they can mark chunk boundaries.
        if _HEADING_HINT.match(line):
            parts.append(line)
            continue
        for sent in _SENTENCE_SPLIT.split(line):
            sent = sent.strip()
            if sent:
                parts.append(sent)
    return parts


def chunk_document(doc: LoadedDoc, doc_id: str) -> list[Chunk]:
    target = CHUNKING_CONFIG["target_chunk_chars"]
    max_chars = CHUNKING_CONFIG["max_chunk_chars"]
    min_chars = CHUNKING_CONFIG["min_chunk_chars"]
    overlap = CHUNKING_CONFIG["overlap_chars"]

    units: list[tuple[int, str]] = []
    for page in doc.pages:
        for sent in _split_sentences(page.text):
            units.append((page.page_number, sent))

    chunks: list[Chunk] = []
    buf: list[tuple[int, str]] = []
    buf_len = 0
    carried = 0  # leading units in buf that repeat the previous chunk's tail

    def flush() -> None:
        nonlocal buf, buf_len, carried
        fresh = buf[carried:]
        if not fresh:
            # Nothing new since the last flush; the overlap is already stored.
            buf, buf_len, carried = [], 0, 0
            return
        text = " ".join(s for _, s in buf).strip()
        if len(text) < min_chars and chunks:
            # Too short to stand alone: append only the new sentences so the
            # overlap is not duplicated inside the previous chunk.
            tail = " ".join(s for _, s in fresh).strip()
            chunks[-1].text = (chunks[-1].text + " " + tail).strip()
            chunks[-1].page_end = buf[-1][0]
            buf, buf_len, carried = [], 0, 0
            return
        chunk = Chunk(
            chunk_id=f"{doc_id}::c{len(chunks):04d}",
            doc_id=doc_id,
            source_path=doc.source_path,
            title=doc.title,
            page_start=buf[0][0],
            page_end=buf[-1][0],
            text=text,
        )
        chunks.append(chunk)
        # Carry trailing sentences (up to overlap chars) into the next chunk.
        carry: list[tuple[int, str]] = []
        carry_len = 0
        for pn, s in reversed(buf):
            if carry_len + len(s) + 1 > overlap:
                break
            carry.insert(0, (pn, s))
            carry_len += len(s) + 1
        buf = carry
        buf_len = sum(len(s) + 1 for _, s in buf)
        carried = len(carry)

    for pn, sent in units:
        is_heading = bool(_HEADING_HINT.match(sent))
        if is_heading and buf_len >= min_chars:
            flush()
        buf.append((pn, sent))
        buf_len += len(sent) + 1
        if buf_len >= target:
            flush()
        elif buf_len >= max_chars:
            # Only reachable if target_chunk_chars is set above max_chunk_chars.
            flush()
    flush()
    return chunks
|
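The greedy grouping with sentence overlap described in the chunker docstring can be illustrated in isolation. This is a minimal sketch, not the repository's implementation: `group_with_overlap`, its `target` and `overlap` defaults, and the plain list-of-strings input are all assumptions made for the example.

```python
def group_with_overlap(sentences, target=40, overlap=25):
    """Pack sentences into chunks of roughly `target` characters.

    The trailing sentences of each chunk (up to `overlap` characters) are
    repeated at the start of the next chunk. Hypothetical helper for illustration.
    """
    chunks, buf, fresh = [], [], 0
    for sent in sentences:
        buf.append(sent)
        fresh += 1
        if sum(len(s) + 1 for s in buf) >= target:
            chunks.append(" ".join(buf))
            # Walk backwards to collect the overlap carried into the next chunk.
            carry, used = [], 0
            for s in reversed(buf):
                if used + len(s) + 1 > overlap:
                    break
                carry.insert(0, s)
                used += len(s) + 1
            buf, fresh = carry, 0
    if fresh:  # emit the tail only if it holds sentences not yet emitted
        chunks.append(" ".join(buf))
    return chunks


sentences = [
    "Alpha one is here.",
    "Beta two follows.",
    "Gamma three ends it.",
    "Delta four closes.",
]
chunks = group_with_overlap(sentences)
```

With these inputs the third sentence closes the first chunk and also opens the second one, which is the overlap behaviour the docstring describes.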
ingestion/embedder.py
ADDED
|
@@ -0,0 +1,32 @@
"""Dense embeddings via sentence-transformers."""
from __future__ import annotations

from functools import lru_cache

from sentence_transformers import SentenceTransformer

from config import EMBEDDING_CONFIG


@lru_cache(maxsize=1)
def get_embedder() -> SentenceTransformer:
    return SentenceTransformer(
        EMBEDDING_CONFIG["model"],
        device=EMBEDDING_CONFIG["device"],
    )


def embed_texts(texts: list[str]) -> list[list[float]]:
    model = get_embedder()
    vecs = model.encode(
        texts,
        batch_size=EMBEDDING_CONFIG["batch_size"],
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False,
    )
    return vecs.tolist()


def embed_query(text: str) -> list[float]:
    return embed_texts([text])[0]
|
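Because `embed_texts` requests `normalize_embeddings=True`, every stored vector has unit length, so a plain dot product between two embeddings equals their cosine similarity. The sketch below checks that identity on toy vectors in pure Python; `l2_normalize`, `dot`, and `cosine` are hypothetical helpers, not part of the repository.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length (left unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm else list(vec)


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine(a, b):
    """Cosine similarity computed from unnormalized vectors."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


a, b = [3.0, 4.0], [4.0, 3.0]
via_cosine = cosine(a, b)
via_dot = dot(l2_normalize(a), l2_normalize(b))
```

Both values agree up to floating-point error, which is why a cosine-space index and unit-length embeddings pair naturally.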
ingestion/indexer.py
ADDED
|
@@ -0,0 +1,134 @@
"""ChromaDB management + BM25 corpus persistence."""
from __future__ import annotations

import json
import os
import pickle
import re

os.environ.setdefault("ANONYMIZED_TELEMETRY", "False")
import chromadb
from chromadb.config import Settings  # noqa: E402

from config import CHROMA_COLLECTION, PATHS
from ingestion.chunker import Chunk
from ingestion.embedder import embed_texts


def _ensure_dirs() -> None:
    PATHS["chroma_dir"].mkdir(parents=True, exist_ok=True)
    PATHS["bm25_path"].parent.mkdir(parents=True, exist_ok=True)


def _client() -> chromadb.PersistentClient:
    return chromadb.PersistentClient(
        path=str(PATHS["chroma_dir"]),
        settings=Settings(anonymized_telemetry=False),
    )


def get_chroma_collection():
    _ensure_dirs()
    client = _client()
    return client.get_or_create_collection(
        name=CHROMA_COLLECTION,
        metadata={"hnsw:space": "cosine"},
    )


def reset_index() -> None:
    _ensure_dirs()
    client = _client()
    try:
        client.delete_collection(CHROMA_COLLECTION)
    except Exception:
        pass
    if PATHS["bm25_path"].exists():
        PATHS["bm25_path"].unlink()
    if PATHS["manifest_path"].exists():
        PATHS["manifest_path"].unlink()


_TOKEN_RE = re.compile(r"[a-zA-Z0-9]+")


def tokenize(text: str) -> list[str]:
    return [t.lower() for t in _TOKEN_RE.findall(text)]


def index_chunks(chunks: list[Chunk], reset: bool = False) -> dict:
    _ensure_dirs()
    if reset:
        reset_index()

    coll = get_chroma_collection()

    ids = [c.chunk_id for c in chunks]
    docs = [c.text for c in chunks]
    metas = [
        {
            "doc_id": c.doc_id,
            "source_path": c.source_path,
            "title": c.title,
            "page_start": c.page_start,
            "page_end": c.page_end,
        }
        for c in chunks
    ]

    print(f"  Embedding {len(docs)} chunks...")
    embeddings = embed_texts(docs)

    print(f"  Writing to ChromaDB ({CHROMA_COLLECTION})...")
    BATCH = 256
    for i in range(0, len(ids), BATCH):
        coll.upsert(
            ids=ids[i : i + BATCH],
            documents=docs[i : i + BATCH],
            metadatas=metas[i : i + BATCH],
            embeddings=embeddings[i : i + BATCH],
        )

    print("  Building BM25 corpus...")
    tokenized = [tokenize(d) for d in docs]
    with open(PATHS["bm25_path"], "wb") as f:
        pickle.dump(
            {"ids": ids, "tokenized": tokenized, "metas": metas, "docs": docs},
            f,
        )

    manifest = {
        "n_chunks": len(ids),
        "chunks_per_doc": _group_count([c.doc_id for c in chunks]),
    }
    with open(PATHS["manifest_path"], "w") as f:
        json.dump(manifest, f, indent=2)
    return manifest


def _group_count(items: list[str]) -> dict:
    out: dict = {}
    for x in items:
        out[x] = out.get(x, 0) + 1
    return out


def fetch_embeddings(chunk_ids: list[str]) -> dict[str, list[float]]:
    """Pull stored embeddings for a list of chunk ids (used for visualization)."""
    if not chunk_ids:
        return {}
    coll = get_chroma_collection()
    res = coll.get(ids=list(chunk_ids), include=["embeddings"])
    out: dict[str, list[float]] = {}
    for cid, vec in zip(res["ids"], res["embeddings"]):
        out[cid] = list(vec)
    return out


def load_bm25_corpus() -> dict:
    if not PATHS["bm25_path"].exists():
        raise FileNotFoundError(
            f"BM25 corpus not found at {PATHS['bm25_path']}. Run ingestion first."
        )
    with open(PATHS["bm25_path"], "rb") as f:
        return pickle.load(f)
|
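Two small mechanics in the indexer are easy to check on their own: the alphanumeric tokenizer that feeds BM25, and the fixed-size batching used for upserts. The sketch below mirrors the tokenizer regex from the file and adds a `batched` generator, which is a hypothetical helper standing in for the inline `range(0, len(ids), BATCH)` loop.

```python
import re

_TOKEN_RE = re.compile(r"[a-zA-Z0-9]+")


def tokenize(text):
    """Lowercase runs of ASCII letters/digits; punctuation acts as a separator."""
    return [t.lower() for t in _TOKEN_RE.findall(text)]


def batched(items, size):
    """Yield consecutive slices of at most `size` items (illustrative helper)."""
    for i in range(0, len(items), size):
        yield items[i : i + size]


tokens = tokenize("Self-RAG: top-k = 5!")
batch_sizes = [len(b) for b in batched(list(range(600)), 256)]
```

Note that hyphenated terms such as "Self-RAG" are split into separate tokens, so queries and documents must go through the same `tokenize` function for BM25 matching to line up.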
ingestion/loader.py
ADDED
|
@@ -0,0 +1,87 @@
"""PDF loader. Returns per-page text + structural metadata."""
from __future__ import annotations

import re
from dataclasses import dataclass, field
from pathlib import Path

import pymupdf


@dataclass
class PageText:
    page_number: int
    text: str


@dataclass
class LoadedDoc:
    source_path: str
    title: str
    pages: list[PageText] = field(default_factory=list)

    @property
    def full_text(self) -> str:
        return "\n\n".join(p.text for p in self.pages)


# Unicode typographic ligatures (U+FB00..U+FB04) mapped to plain ASCII.
_LIGATURES = {
    "\ufb00": "ff", "\ufb01": "fi", "\ufb02": "fl",
    "\ufb03": "ffi", "\ufb04": "ffl",
}


def _clean(text: str) -> str:
    for k, v in _LIGATURES.items():
        text = text.replace(k, v)
    text = re.sub(r"-\n(?=\w)", "", text)
    text = re.sub(r"[ \t]+", " ", text)
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()


_TITLE_OVERRIDES = {
    "01_attention_is_all_you_need": "Attention Is All You Need (Vaswani et al., 2017)",
    "02_bert": "BERT: Pre-training of Deep Bidirectional Transformers (Devlin et al., 2018)",
    "03_gpt3": "Language Models are Few-Shot Learners (GPT-3, Brown et al., 2020)",
    "04_ddpm": "Denoising Diffusion Probabilistic Models (Ho et al., 2020)",
    "05_ddim": "Denoising Diffusion Implicit Models (Song et al., 2020)",
    "06_rag_original": "Retrieval-Augmented Generation for Knowledge-Intensive NLP (Lewis et al., 2020)",
    "07_rag_survey": "Retrieval-Augmented Generation for LLMs: A Survey (Gao et al., 2023)",
    "08_self_rag": "Self-RAG: Learning to Retrieve, Generate, and Critique (Asai et al., 2023)",
    "09_hyde": "Precise Zero-Shot Dense Retrieval with HyDE (Gao et al., 2022)",
    "10_vit": "An Image is Worth 16x16 Words (Vision Transformer, Dosovitskiy et al., 2020)",
    "11_clip": "Learning Transferable Visual Models from Natural Language Supervision (CLIP, Radford et al., 2021)",
    "12_react": "ReAct: Synergizing Reasoning and Acting in Language Models (Yao et al., 2022)",
    "13_chain_of_thought": "Chain-of-Thought Prompting Elicits Reasoning (Wei et al., 2022)",
    "14_llm_survey": "A Survey of Large Language Models (Zhao et al., 2023)",
}


def load_pdf(path: str | Path) -> LoadedDoc:
    path = Path(path)
    doc = pymupdf.open(path)
    pages: list[PageText] = []
    for i, page in enumerate(doc, start=1):
        raw = page.get_text("text")
        cleaned = _clean(raw)
        if cleaned:
            pages.append(PageText(page_number=i, text=cleaned))
    title = _TITLE_OVERRIDES.get(path.stem) or _guess_title(pages, fallback=path.stem)
    doc.close()
    return LoadedDoc(source_path=str(path), title=title, pages=pages)


def _guess_title(pages: list[PageText], fallback: str) -> str:
    if not pages:
        return fallback
    first = pages[0].text
    for line in first.splitlines():
        line = line.strip()
        if 10 < len(line) < 180 and not line.lower().startswith(("abstract", "arxiv:")):
            return line
    return fallback


def discover_pdfs(papers_dir: str | Path) -> list[Path]:
    return sorted(Path(papers_dir).glob("*.pdf"))
|
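The text clean-up step in the loader (ligature replacement, de-hyphenation across line breaks, whitespace collapsing) can be exercised without a PDF. The sketch below re-creates that sequence as a standalone `clean` function; the name and the sample string are assumptions for illustration, and the ligature keys are written as Unicode escapes so they survive copy-and-paste.

```python
import re

# Typographic ligatures U+FB00..U+FB04 and their ASCII expansions.
LIGATURES = {
    "\ufb00": "ff", "\ufb01": "fi", "\ufb02": "fl",
    "\ufb03": "ffi", "\ufb04": "ffl",
}


def clean(text):
    """Normalize extracted PDF text in four passes."""
    for lig, plain in LIGATURES.items():
        text = text.replace(lig, plain)
    text = re.sub(r"-\n(?=\w)", "", text)    # join words hyphenated at a line end
    text = re.sub(r"[ \t]+", " ", text)      # collapse runs of spaces and tabs
    text = re.sub(r"\n{3,}", "\n\n", text)   # keep at most one blank line
    return text.strip()


cleaned = clean("e\ufb03cient retrie-\nval   works\n\n\n\nwell  ")
```

One trade-off worth knowing: the de-hyphenation pattern cannot tell a line-break hyphen from a real one, so a compound such as "self-" followed by "attention" on the next line would also be joined without its hyphen.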
llm/__init__.py
ADDED
|
File without changes
|
llm/client_factory.py
ADDED
|
@@ -0,0 +1,16 @@
"""Return the right LLM client based on environment.

Local  (Ollama running)     → OllamaClient
Hosted (GROQ_API_KEY set)   → GroqClient
"""
from __future__ import annotations

import os


def get_llm():
    if os.environ.get("GROQ_API_KEY"):
        from llm.groq_client import GroqClient
        return GroqClient()
    from llm.ollama_client import OllamaClient
    return OllamaClient()
|
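The selection rule in `get_llm` depends only on whether `GROQ_API_KEY` holds a non-empty value. A minimal sketch of that decision, written against a plain mapping so it can be tested without touching the real environment; `choose_backend` and its string return values are hypothetical and not part of the repository.

```python
def choose_backend(env):
    """Return "groq" when a non-empty GROQ_API_KEY is present, else "ollama"."""
    return "groq" if env.get("GROQ_API_KEY") else "ollama"


hosted = choose_backend({"GROQ_API_KEY": "example-key"})
local = choose_backend({})
blank = choose_backend({"GROQ_API_KEY": ""})
```

An empty string counts as unset, which matches the truthiness check on `os.environ.get("GROQ_API_KEY")` in the factory.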
llm/groq_client.py
ADDED
|
@@ -0,0 +1,133 @@
"""Groq API client: drop-in replacement for OllamaClient when running hosted.

Free tier: 14,400 requests/day, ~500 req/min.
Text model  : llama-3.1-8b-instant  (fast, cheap)
Vision model: llama-3.2-11b-vision-preview  (images)
"""
from __future__ import annotations

import base64
import json
import os
from pathlib import Path
from typing import Any

import requests

_TEXT_MODEL = "llama-3.1-8b-instant"
_VISION_MODEL = "llama-3.2-11b-vision-preview"
_BASE_URL = "https://api.groq.com/openai/v1"


class GroqClient:
    def __init__(self) -> None:
        self.api_key = os.environ.get("GROQ_API_KEY", "")
        self.temperature = 0.1
        self.timeout = 60

    def _headers(self) -> dict[str, str]:
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _chat(
        self,
        messages: list[dict],
        model: str,
        temperature: float,
        json_mode: bool = False,
    ) -> str:
        payload: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
        }
        if json_mode:
            payload["response_format"] = {"type": "json_object"}
        r = requests.post(
            f"{_BASE_URL}/chat/completions",
            headers=self._headers(),
            json=payload,
            timeout=self.timeout,
        )
        r.raise_for_status()
        return r.json()["choices"][0]["message"]["content"].strip()

    def _build_messages(
        self,
        prompt: str,
        system: str | None,
        images: list[str] | None,
    ) -> tuple[list[dict], str]:
        msgs: list[dict] = []
        if system:
            msgs.append({"role": "system", "content": system})
        if images:
            content: list[dict] = [{"type": "text", "text": prompt}]
            for img_path in images:
                b64 = self._encode_image(img_path)
                suffix = Path(img_path).suffix.lower().lstrip(".")
                mime = {"jpg": "image/jpeg", "jpeg": "image/jpeg",
                        "png": "image/png", "webp": "image/webp"}.get(suffix, "image/png")
                content.append({"type": "image_url",
                                "image_url": {"url": f"data:{mime};base64,{b64}"}})
            msgs.append({"role": "user", "content": content})
            return msgs, _VISION_MODEL
        msgs.append({"role": "user", "content": prompt})
        return msgs, _TEXT_MODEL

    def generate(
        self,
        prompt: str,
        system: str | None = None,
        temperature: float | None = None,
        images: list[str] | None = None,
        format: str | None = None,
    ) -> str:
        msgs, model = self._build_messages(prompt, system, images)
        return self._chat(
            msgs, model,
            temperature if temperature is not None else self.temperature,
            json_mode=(format == "json"),
        )

    def generate_json(
        self,
        prompt: str,
        system: str | None = None,
        temperature: float | None = None,
        images: list[str] | None = None,
    ) -> dict[str, Any]:
        text = self.generate(prompt=prompt, system=system,
                             temperature=temperature, images=images, format="json")
        return _safe_json(text)

    @staticmethod
    def _encode_image(path: str | Path) -> str:
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    def health(self) -> bool:
        if not self.api_key:
            return False
        try:
            r = requests.get(f"{_BASE_URL}/models", headers=self._headers(), timeout=5)
            return r.status_code == 200
        except requests.RequestException:
            return False


def _safe_json(text: str) -> dict[str, Any]:
    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        try:
            return json.loads(text[start:end + 1])
        except json.JSONDecodeError:
            pass
    return {"_raw": text}
|
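`_build_messages` sends images as base64 `data:` URLs inside an OpenAI-style content list. The sketch below reproduces that message shape from raw bytes so it can be inspected without files or network access; `image_part` and `user_message` are hypothetical names, and taking `(bytes, suffix)` pairs instead of file paths is an assumption made to keep the example self-contained.

```python
import base64

_MIME = {"jpg": "image/jpeg", "jpeg": "image/jpeg",
         "png": "image/png", "webp": "image/webp"}


def image_part(data, suffix):
    """Build one image_url content part from raw bytes and a file suffix."""
    mime = _MIME.get(suffix.lower().lstrip("."), "image/png")
    b64 = base64.b64encode(data).decode("utf-8")
    return {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}}


def user_message(prompt, images=()):
    """Plain string content for text-only prompts, a parts list when images exist."""
    if not images:
        return {"role": "user", "content": prompt}
    content = [{"type": "text", "text": prompt}]
    content += [image_part(data, suffix) for data, suffix in images]
    return {"role": "user", "content": content}


text_only = user_message("hi")
with_image = user_message("describe this", [(b"abc", ".JPG")])
```

Unknown suffixes fall back to `image/png`, mirroring the `.get(suffix, "image/png")` default in the client.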
llm/ollama_client.py
ADDED
|
@@ -0,0 +1,116 @@
"""Thin Ollama HTTP client. Handles text + multimodal (image) requests."""
from __future__ import annotations

import base64
import json
from pathlib import Path
from typing import Any

import requests

from config import LLM_CONFIG


class OllamaClient:
    def __init__(
        self,
        model: str | None = None,
        base_url: str | None = None,
        temperature: float | None = None,
        timeout: int | None = None,
        num_ctx: int | None = None,
    ) -> None:
        self.model = model or LLM_CONFIG["model"]
        self.base_url = (base_url or LLM_CONFIG["base_url"]).rstrip("/")
        self.temperature = temperature if temperature is not None else LLM_CONFIG["temperature"]
        self.timeout = timeout or LLM_CONFIG["timeout"]
        self.num_ctx = num_ctx or LLM_CONFIG["num_ctx"]

    def _options(self, **overrides: Any) -> dict[str, Any]:
        opts = {"temperature": self.temperature, "num_ctx": self.num_ctx}
        opts.update({k: v for k, v in overrides.items() if v is not None})
        return opts

    def generate(
        self,
        prompt: str,
        system: str | None = None,
        temperature: float | None = None,
        images: list[str] | None = None,
        format: str | None = None,
    ) -> str:
        payload: dict[str, Any] = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            "options": self._options(temperature=temperature),
        }
        if system:
            payload["system"] = system
        if images:
            payload["images"] = [self._encode_image(p) for p in images]
        if format:
            payload["format"] = format
        r = requests.post(
            f"{self.base_url}/api/generate",
            json=payload,
            timeout=self.timeout,
        )
        r.raise_for_status()
        return r.json().get("response", "").strip()

    def generate_json(
        self,
        prompt: str,
        system: str | None = None,
        temperature: float | None = None,
        images: list[str] | None = None,
    ) -> dict[str, Any]:
        """Ask the model for a JSON object and parse it. Falls back to extraction
        if the model wraps JSON in prose or fences."""
        text = self.generate(
            prompt=prompt,
            system=system,
            temperature=temperature,
            images=images,
            format="json",
        )
        return _safe_json_loads(text)

    @staticmethod
    def _encode_image(path: str | Path) -> str:
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    def health(self) -> bool:
        try:
            r = requests.get(f"{self.base_url}/api/tags", timeout=5)
            return r.status_code == 200
        except requests.RequestException:
            return False


def _safe_json_loads(text: str) -> dict[str, Any]:
    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    if "```" in text:
        parts = text.split("```")
        for part in parts:
            stripped = part.strip()
            if stripped.startswith("json"):
                stripped = stripped[4:].strip()
            try:
                return json.loads(stripped)
            except json.JSONDecodeError:
                continue
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1 and end > start:
        try:
            return json.loads(text[start : end + 1])
        except json.JSONDecodeError:
            pass
    return {"_raw": text}
|
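The fallback order in `_safe_json_loads` (direct parse, then Markdown code fences, then the outermost brace span, then a `_raw` wrapper) is easiest to see on concrete model replies. The sketch below is a standalone re-creation for illustration; `extract_json` is a hypothetical name, and the fence marker is built with string multiplication only so the example can sit inside a fenced block itself.

```python
import json

FENCE = "`" * 3  # Markdown code-fence marker


def extract_json(text):
    """Parse a model reply as JSON, tolerating fences and surrounding prose."""
    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    if FENCE in text:
        for part in text.split(FENCE):
            candidate = part.strip()
            if candidate.startswith("json"):
                candidate = candidate[4:].strip()  # drop the language tag
            try:
                return json.loads(candidate)
            except json.JSONDecodeError:
                continue
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        try:
            return json.loads(text[start : end + 1])
        except json.JSONDecodeError:
            pass
    return {"_raw": text}


fenced = extract_json("Sure! " + FENCE + 'json\n{"route": "RETRIEVE"}\n' + FENCE)
wrapped = extract_json('answer: {"a": 1} done')
failed = extract_json("no json here")
```

Returning `{"_raw": text}` instead of raising lets callers decide how to handle an unparseable reply, at the cost of having to check for that key.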
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
chromadb==0.5.23
sentence-transformers==3.3.1
rank-bm25==0.2.2
pymupdf==1.25.1
requests==2.32.3
numpy==1.26.4
streamlit==1.41.1
tqdm==4.67.1
pillow==11.0.0
|
retrieval/__init__.py
ADDED
|
File without changes
|
retrieval/dense.py
ADDED
|
@@ -0,0 +1,39 @@
"""Dense (vector) retrieval over ChromaDB."""
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from config import RETRIEVAL_CONFIG
from ingestion.embedder import embed_query
from ingestion.indexer import get_chroma_collection


@dataclass
class Hit:
    chunk_id: str
    text: str
    metadata: dict
    score: float
    rank: int


def dense_search(query: str, k: int | None = None) -> list[Hit]:
    k = k or RETRIEVAL_CONFIG["dense_k"]
    coll = get_chroma_collection()
    qv = embed_query(query)
    res = coll.query(
        query_embeddings=[qv],
        n_results=k,
        include=["documents", "metadatas", "distances"],
    )
    ids = res["ids"][0]
    docs = res["documents"][0]
    metas = res["metadatas"][0]
    dists = res["distances"][0]
    hits: list[Hit] = []
    for r, (cid, doc, meta, dist) in enumerate(zip(ids, docs, metas, dists)):
        # In cosine space Chroma returns distance = 1 - cosine_sim (range [0, 2]);
        # convert back to a similarity and clamp negative values to 0.
        score = max(0.0, 1.0 - float(dist))
        hits.append(Hit(chunk_id=cid, text=doc, metadata=dict(meta), score=score, rank=r))
    return hits
|
retrieval/hybrid.py
ADDED
|
@@ -0,0 +1,47 @@
"""Reciprocal Rank Fusion of dense + sparse hits."""
from __future__ import annotations

from config import RETRIEVAL_CONFIG
from retrieval.dense import Hit, dense_search
from retrieval.sparse import sparse_search


def reciprocal_rank_fusion(
    rankings: list[list[Hit]],
    k_const: int | None = None,
    top_k: int | None = None,
) -> list[Hit]:
    k_const = k_const or RETRIEVAL_CONFIG["rrf_k"]
    fused: dict[str, dict] = {}
    for ranking in rankings:
        for r, hit in enumerate(ranking):
            entry = fused.setdefault(
                hit.chunk_id,
                {"hit": hit, "score": 0.0},
            )
            entry["score"] += 1.0 / (k_const + r + 1)
            # prefer higher-ranked instance for the canonical hit object
            if hit.rank < entry["hit"].rank:
                entry["hit"] = hit
    merged = sorted(fused.values(), key=lambda x: x["score"], reverse=True)
    out: list[Hit] = []
    for r, entry in enumerate(merged):
        h = entry["hit"]
        out.append(
            Hit(
                chunk_id=h.chunk_id,
                text=h.text,
                metadata=h.metadata,
                score=entry["score"],
                rank=r,
            )
        )
    if top_k:
        out = out[:top_k]
    return out


def hybrid_search(query: str, top_k: int | None = None) -> list[Hit]:
    dense = dense_search(query)
    sparse = sparse_search(query)
    return reciprocal_rank_fusion([dense, sparse], top_k=top_k)
|
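Reciprocal Rank Fusion scores each chunk as the sum of 1 / (k + rank) over every ranking it appears in, with ranks starting at 1. A worked sketch on two toy rankings follows; it operates on bare ids rather than `Hit` objects, and `k_const=60` is an assumed value for illustration since the repository reads it from `RETRIEVAL_CONFIG["rrf_k"]`.

```python
def rrf(rankings, k_const=60):
    """Fuse several ranked id lists; returns ids ordered by fused score."""
    scores = {}
    for ranking in rankings:
        for r, cid in enumerate(ranking):
            # r is 0-based, so r + 1 is the 1-based rank used by the formula.
            scores[cid] = scores.get(cid, 0.0) + 1.0 / (k_const + r + 1)
    return sorted(scores, key=scores.get, reverse=True), scores


dense = ["a", "b", "c"]
sparse = ["c", "a", "d"]
order, scores = rrf([dense, sparse])
```

Here "a" (ranks 1 and 2) edges out "c" (ranks 3 and 1) because 1/61 + 1/62 is slightly larger than 1/63 + 1/61, while "b" and "d" appear in only one list and trail behind. Only ranks matter, so the dense and sparse scores never need to share a scale.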
retrieval/pipeline.py
ADDED
|
@@ -0,0 +1,12 @@
"""End-to-end retrieval: hybrid search + cross-encoder reranking."""
from __future__ import annotations

from config import RETRIEVAL_CONFIG
from retrieval.dense import Hit
from retrieval.hybrid import hybrid_search
from retrieval.reranker import rerank


def hybrid_retrieve(query: str, top_n: int | None = None) -> list[Hit]:
    fused = hybrid_search(query, top_k=max(RETRIEVAL_CONFIG["dense_k"], RETRIEVAL_CONFIG["sparse_k"]))
    return rerank(query, fused, top_n=top_n)
|
retrieval/reranker.py
ADDED
|
@@ -0,0 +1,36 @@
"""Cross-encoder reranker: deep relevance scoring on top of fused hits."""
from __future__ import annotations

from functools import lru_cache

from sentence_transformers import CrossEncoder

from config import RERANKER_CONFIG, RETRIEVAL_CONFIG
from retrieval.dense import Hit


@lru_cache(maxsize=1)
def get_reranker() -> CrossEncoder:
    return CrossEncoder(RERANKER_CONFIG["model"], device=RERANKER_CONFIG["device"])


def rerank(query: str, hits: list[Hit], top_n: int | None = None) -> list[Hit]:
    if not hits:
        return []
    top_n = top_n or RETRIEVAL_CONFIG["rerank_top_n"]
    model = get_reranker()
    pairs = [(query, h.text) for h in hits]
    scores = model.predict(pairs, show_progress_bar=False)
    ranked = sorted(zip(hits, scores), key=lambda x: float(x[1]), reverse=True)[:top_n]
    out: list[Hit] = []
    for r, (h, s) in enumerate(ranked):
        out.append(
            Hit(
                chunk_id=h.chunk_id,
                text=h.text,
                metadata=h.metadata,
                score=float(s),
                rank=r,
            )
        )
    return out
|
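The reranking step is a score-then-sort over (query, text) pairs, independent of which model produces the scores. The sketch below shows that flow with a stand-in scorer so it runs without downloading a cross-encoder; `rerank_with` and the word-overlap `overlap_score` are hypothetical and far weaker than the BGE model the repository uses.

```python
def overlap_score(query, text):
    """Toy relevance score: number of query words that also occur in the text."""
    return len(set(query.lower().split()) & set(text.lower().split()))


def rerank_with(score_fn, query, texts, top_n=2):
    """Score every candidate against the query and keep the best top_n."""
    scored = [(t, float(score_fn(query, t))) for t in texts]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_n]


candidates = [
    "diffusion models denoise images",
    "attention is all you need",
    "self attention layers",
]
top = rerank_with(overlap_score, "self attention", candidates)
```

Swapping `overlap_score` for a function that calls a cross-encoder would leave the rest of the flow unchanged, which is the separation `rerank` relies on.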
retrieval/sparse.py
ADDED
|
@@ -0,0 +1,44 @@
"""Sparse retrieval via BM25Okapi over the persisted token corpus."""
from __future__ import annotations

from functools import lru_cache

from rank_bm25 import BM25Okapi

from config import RETRIEVAL_CONFIG
from ingestion.indexer import load_bm25_corpus, tokenize
from retrieval.dense import Hit


@lru_cache(maxsize=1)
def _bm25_state():
    corpus = load_bm25_corpus()
    bm25 = BM25Okapi(corpus["tokenized"])
    return bm25, corpus


def sparse_search(query: str, k: int | None = None) -> list[Hit]:
    k = k or RETRIEVAL_CONFIG["sparse_k"]
    bm25, corpus = _bm25_state()
    tokens = tokenize(query)
    if not tokens:
        return []
    scores = bm25.get_scores(tokens)
    idx_sorted = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    max_score = float(scores[idx_sorted[0]]) if idx_sorted else 0.0
    hits: list[Hit] = []
    for r, i in enumerate(idx_sorted):
        s = float(scores[i])
        if s <= 0:
            continue
        norm = s / max_score if max_score > 0 else 0.0
        hits.append(
            Hit(
                chunk_id=corpus["ids"][i],
                text=corpus["docs"][i],
                metadata=dict(corpus["metas"][i]),
                score=norm,
                rank=r,
            )
        )
    return hits
|
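The top-k selection and max-normalization in `sparse_search` can be isolated from BM25 itself. The sketch below assumes a plain list of raw scores in place of `bm25.get_scores(...)`; the helper name `top_k_normalized` is hypothetical and does not appear in the repository.

```python
def top_k_normalized(scores: list, k: int) -> list:
    """Return (index, normalized_score) for the k best positive scores.

    Scores are divided by the best score, so the top entry is 1.0.
    Entries with a score of zero or below are dropped.
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    max_score = float(scores[order[0]]) if order else 0.0
    out = []
    for i in order:
        s = float(scores[i])
        if s <= 0:
            continue
        out.append((i, s / max_score))
    return out


print(top_k_normalized([0.0, 2.0, 4.0, 1.0], k=3))
# → [(2, 1.0), (1, 0.5), (3, 0.25)]
```

Because the indices are already sorted in descending order, every entry after the first non-positive score is also non-positive, so the ranks of the surviving hits stay contiguous.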
storage/bm25.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:01eb71c2d4f2c3f021e925e3c96a3131fd8406592c017daf8a405ecac1578ec6
size 5864776
storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/data_level0.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:657eee8d3193abf7c307965c3cb694d08ac955dcc9b6ed3139d1d60e746449b8
size 1676000
storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/header.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c2ca0f721e87b23fa44989e0cb59d71339f5a16cc05a8eb3c7777e658757e2e5
size 100
storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/index_metadata.pickle
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0a9b9b398fb0167ff163c08c65e2728ff53af2a5e421a2355c4ee1ecdc746e2c
size 36020
storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/length.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d51e5510e0fb53d9746f2be0e91a1f7cc74c25bdf21a7ca8096fccf6074521a3
size 4000
storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/link_lists.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:358ea02768d90865ee8a426535f44933635bc209dda6b9481c2b9f221f08b18b
size 8148
storage/chroma/chroma.sqlite3
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9526de2262c6da38b5d8ca3bb71f7902e912713b831020b7fb9dc22ba9d4d7fc
size 51924992
storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/data_level0.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:be8ace3c8f7cfb3c3d7fca6e733651bff7d18d34a51fbeac7096129dd7bb883d
size 1676000
storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/header.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c2ca0f721e87b23fa44989e0cb59d71339f5a16cc05a8eb3c7777e658757e2e5
size 100
storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/index_metadata.pickle
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:56cdf88e893d7b0e61a770d05eb551f743b72862a8666e733304527383803612
size 36020
storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/length.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:49b0ed9d141de28649f8bb3bcf395a50c5a960b5f003bf4b4964e5ac62fba885
size 4000
storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/link_lists.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fa7615217e24a544a1ecbaaa207daea4e33fe9dca8b5967ca8c5a90b14a90782
size 8148
storage/manifest.json
ADDED
@@ -0,0 +1,19 @@
{
  "n_chunks": 1934,
  "chunks_per_doc": {
    "01_attention_is_all_you_need": 34,
    "02_bert": 60,
    "03_gpt3": 237,
    "04_ddpm": 48,
    "05_ddim": 45,
    "06_rag_original": 59,
    "07_rag_survey": 95,
    "08_self_rag": 88,
    "09_hyde": 35,
    "10_vit": 61,
    "11_clip": 217,
    "12_react": 95,
    "13_chain_of_thought": 135,
    "14_llm_survey": 725
  }
}
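A quick consistency check on the manifest above: the per-document counts should add up to `n_chunks`. The sketch below inlines the manifest as a string so it runs on its own; in a checkout one would presumably read `storage/manifest.json` instead. The variable names are illustrative and not taken from the repository.

```python
import json

# Manifest contents copied from the diff above.
MANIFEST = """
{
  "n_chunks": 1934,
  "chunks_per_doc": {
    "01_attention_is_all_you_need": 34,
    "02_bert": 60,
    "03_gpt3": 237,
    "04_ddpm": 48,
    "05_ddim": 45,
    "06_rag_original": 59,
    "07_rag_survey": 95,
    "08_self_rag": 88,
    "09_hyde": 35,
    "10_vit": 61,
    "11_clip": 217,
    "12_react": 95,
    "13_chain_of_thought": 135,
    "14_llm_survey": 725
  }
}
"""

manifest = json.loads(MANIFEST)
per_doc = manifest["chunks_per_doc"]
total = sum(per_doc.values())

# The per-document counts should account for every indexed chunk.
print(len(per_doc), total, total == manifest["n_chunks"])
# → 14 1934 True

# Share of the index taken by the largest document.
largest = max(per_doc, key=per_doc.get)
print(largest, round(per_doc[largest] / total, 3))
# → 14_llm_survey 0.375
```

The second print shows that a single survey paper contributes more than a third of all chunks, which may be worth keeping in mind when interpreting retrieval results.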