NoobNovel committed on
Commit
e0670a4
·
1 Parent(s): ddcb7f3

AdaptiveRAG: Agentic + Self-RAG + Modular RAG pipeline with visual UI


- Semantic chunking + dual embedding (dense MiniLM + sparse BM25)
- Hybrid retrieval: Chroma cosine + BM25 fused via RRF + BGE cross-encoder rerank
- Self-RAG router (RETRIEVE / ANSWER_DIRECTLY / CLARIFY)
- Agentic loop: plan -> retrieve -> answer -> self-critique -> refine
- Streamlit UI exposing every pipeline stage with 2D vector space scatter plot
- Pre-built index: 1934 chunks across 14 foundational AI papers (git-lfs)
- Auto-switches backend: Ollama locally, Groq API when GROQ_API_KEY is set
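The RRF step named in the list above can be sketched as follows. This is only an illustration, not the body of retrieval/hybrid.py, which is not shown here. The function name `rrf_fuse`, the constant `k=60` (a commonly used default for Reciprocal Rank Fusion), and the chunk ids are assumptions.

```python
from __future__ import annotations

from collections import defaultdict


def rrf_fuse(
    rankings: list[list[str]], k: int = 60, top_k: int | None = None
) -> list[tuple[str, float]]:
    """Fuse ranked id lists: score(d) = sum over lists of 1 / (k + rank of d)."""
    scores: dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # Highest fused score first; ties broken by id so the order is deterministic.
    fused = sorted(scores.items(), key=lambda item: (-item[1], item[0]))
    return fused[:top_k] if top_k is not None else fused


# Hypothetical chunk ids, as they might come back from the two retrievers.
dense = ["c3", "c1", "c7"]
sparse = ["c1", "c9", "c3"]
fused = rrf_fuse([dense, sparse])
print([doc_id for doc_id, _ in fused])
```

Because only ranks enter the formula, cosine similarities and BM25 scores never need to be put on a common scale before fusion.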

Files changed (44)
  1. .gitattributes +2 -0
  2. .gitignore +15 -0
  3. .streamlit/config.toml +5 -0
  4. README.md +38 -6
  5. agent/__init__.py +0 -0
  6. agent/critic.py +65 -0
  7. agent/loop.py +198 -0
  8. agent/planner.py +52 -0
  9. agent/router.py +39 -0
  10. agent/tools.py +85 -0
  11. app.py +554 -0
  12. ask.py +32 -0
  13. config.py +57 -0
  14. download_papers.sh +28 -0
  15. ingest.py +43 -0
  16. ingestion/__init__.py +0 -0
  17. ingestion/chunker.py +104 -0
  18. ingestion/embedder.py +32 -0
  19. ingestion/indexer.py +134 -0
  20. ingestion/loader.py +87 -0
  21. llm/__init__.py +0 -0
  22. llm/client_factory.py +16 -0
  23. llm/groq_client.py +133 -0
  24. llm/ollama_client.py +116 -0
  25. requirements.txt +9 -0
  26. retrieval/__init__.py +0 -0
  27. retrieval/dense.py +39 -0
  28. retrieval/hybrid.py +47 -0
  29. retrieval/pipeline.py +12 -0
  30. retrieval/reranker.py +36 -0
  31. retrieval/sparse.py +44 -0
  32. storage/bm25.pkl +3 -0
  33. storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/data_level0.bin +3 -0
  34. storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/header.bin +3 -0
  35. storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/index_metadata.pickle +3 -0
  36. storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/length.bin +3 -0
  37. storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/link_lists.bin +3 -0
  38. storage/chroma/chroma.sqlite3 +3 -0
  39. storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/data_level0.bin +3 -0
  40. storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/header.bin +3 -0
  41. storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/index_metadata.pickle +3 -0
  42. storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/length.bin +3 -0
  43. storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/link_lists.bin +3 -0
  44. storage/manifest.json +19 -0
.gitattributes CHANGED
@@ -1,3 +1,4 @@
+# HF standard LFS patterns
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+storage/chroma/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,15 @@
+# Python
+.venv/
+__pycache__/
+*.py[cod]
+*.egg-info/
+.env
+
+# macOS
+.DS_Store
+
+# Papers — excluded (public ArXiv PDFs; re-download with download_papers.sh)
+papers/
+
+# Streamlit
+.streamlit/secrets.toml
.streamlit/config.toml ADDED
@@ -0,0 +1,5 @@
+[browser]
+gatherUsageStats = false
+
+[server]
+headless = true
README.md CHANGED
@@ -1,12 +1,44 @@
 ---
 title: AdaptiveRAG
-emoji: 📊
-colorFrom: indigo
-colorTo: gray
+emoji: 📚
+colorFrom: blue
+colorTo: purple
 sdk: docker
-pinned: false
+pinned: true
 license: mit
-short_description: production-grade RAG using Modular, Self-RAG, Agentic
+short_description: Agentic + Self-RAG + Modular RAG with visual pipeline UI
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# AdaptiveRAG Agentic + Self-RAG + Modular RAG
+
+Live demo of a production-grade RAG pipeline — every stage is visible in the UI.
+
+**Tech stack:** ChromaDB · sentence-transformers · BM25 · Reciprocal Rank Fusion · BGE cross-encoder · LLaMA 3.1 via Groq
+
+**Knowledge base:** 14 foundational AI papers (Transformers, BERT, GPT-3, DDPM, RAG, Self-RAG, HyDE, ViT, CLIP, ReAct, Chain-of-Thought, LLM Survey)
+
+## What it shows
+
+| Stage | What you see |
+|---|---|
+| Question encoding | 384-dim embedding vector + bar chart of first 32 dims |
+| Self-RAG router | RETRIEVE / ANSWER_DIRECTLY / CLARIFY decision + reason |
+| Planner | Sub-query decomposition with rationales |
+| Dense retrieval | Cosine similarity scores vs ChromaDB |
+| Sparse retrieval | BM25 keyword match scores |
+| RRF fusion | Combined ranking chart |
+| Cross-encoder rerank | BGE relevance scores |
+| Vector space | 2D PCA projection of query + hits |
+| Self-critique | Grounded / Complete / Confidence score |
+
+## Run locally
+
+```bash
+git clone https://github.com/Gh-Novel/AdaptiveRAG
+cd AdaptiveRAG
+python -m venv .venv && source .venv/bin/activate
+pip install -r requirements.txt
+
+# needs Ollama running with qwen3-vl:8b-instruct-q8_0-optimized
+streamlit run app.py
+```
agent/__init__.py ADDED
File without changes
agent/critic.py ADDED
@@ -0,0 +1,65 @@
+"""Self-critique: judge answer for grounding, completeness, confidence."""
+from __future__ import annotations
+
+from llm.ollama_client import OllamaClient
+
+CRITIC_SYSTEM = (
+    "You are a strict reviewer. Judge whether an AI-generated answer is grounded "
+    "in the provided context and whether it fully answers the user's question."
+)
+
+CRITIC_PROMPT = """Question: {question}
+
+Answer to review:
+{answer}
+
+Context that was provided to the answerer:
+{context}
+
+Score the answer. Return strict JSON only:
+{{
+  "grounded": true | false, // Is every factual claim supported by the context?
+  "complete": true | false, // Does it fully address the question?
+  "confidence": 0.0-1.0, // Overall confidence in the answer
+  "missing": "<what info is missing or weakly supported, or empty string>"
+}}
+"""
+
+
+def critique(question: str, answer: str, context: str, llm: OllamaClient | None = None) -> dict:
+    llm = llm or OllamaClient()
+    out = llm.generate_json(
+        prompt=CRITIC_PROMPT.format(question=question, answer=answer, context=context),
+        system=CRITIC_SYSTEM,
+        temperature=0.0,
+    )
+    return {
+        "grounded": bool(out.get("grounded", False)),
+        "complete": bool(out.get("complete", False)),
+        "confidence": float(out.get("confidence", 0.0) or 0.0),
+        "missing": str(out.get("missing", "") or ""),
+    }
+
+
+REFINE_SYSTEM = (
+    "You rewrite a search query so it retrieves the missing information."
+)
+
+REFINE_PROMPT = """Original question: {question}
+
+A previous attempt was missing the following information:
+{missing}
+
+Rewrite the query to specifically target the missing information. Output the
+rewritten search query as a single line of text, no quotes, no explanation.
+"""
+
+
+def refine_query(question: str, missing: str, llm: OllamaClient | None = None) -> str:
+    llm = llm or OllamaClient()
+    out = llm.generate(
+        prompt=REFINE_PROMPT.format(question=question, missing=missing or "more detail"),
+        system=REFINE_SYSTEM,
+        temperature=0.1,
+    )
+    return out.strip().splitlines()[0] if out.strip() else question
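One possible hardening of the critic's output normalization, shown as a sketch rather than as part of this commit. In Python, `bool("false")` is `True`, so a model that returns the string "false" instead of a JSON boolean would be read as grounded. The helper names `as_bool` and `normalize_critique`, and the clamping of confidence to [0, 1], are assumptions added for illustration.

```python
from __future__ import annotations

from typing import Any


def as_bool(value: Any) -> bool:
    """Interpret JSON booleans and common string spellings of them."""
    if isinstance(value, str):
        return value.strip().lower() in {"true", "yes", "1"}
    return bool(value)


def normalize_critique(raw: dict[str, Any]) -> dict[str, Any]:
    """Coerce a loosely typed critic response into a fixed, safe shape."""
    confidence = float(raw.get("confidence", 0.0) or 0.0)
    return {
        "grounded": as_bool(raw.get("grounded", False)),
        "complete": as_bool(raw.get("complete", False)),
        # Clamp so a stray value such as 1.7 cannot pass any threshold unchecked.
        "confidence": min(1.0, max(0.0, confidence)),
        "missing": str(raw.get("missing", "") or ""),
    }


print(normalize_critique({"grounded": "false", "confidence": 1.7, "missing": None}))
```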
agent/loop.py ADDED
@@ -0,0 +1,198 @@
+"""The agentic RAG loop.
+
+Pipeline:
+  1. Self-RAG router: RETRIEVE / ANSWER_DIRECTLY / CLARIFY
+  2. (RETRIEVE branch) plan -> tools -> answer -> self-critique
+  3. If confidence < threshold and budget left: refine and retry
+"""
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any
+
+from config import AGENT_CONFIG
+from agent.critic import critique, refine_query
+from agent.planner import plan
+from agent.router import route
+from agent.tools import ToolResult, vector_search
+from llm.ollama_client import OllamaClient
+from retrieval.dense import Hit
+
+
+@dataclass
+class TraceStep:
+    kind: str
+    detail: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class AgentResult:
+    answer: str
+    citations: list[dict]
+    confidence: float
+    trace: list[TraceStep]
+    iterations: int
+    route: str
+
+
+ANSWER_SYSTEM = (
+    "You are a careful research assistant. Use ONLY the provided passages to "
+    "answer the question. Cite sources inline with [N] where N is the passage "
+    "number. If the passages are insufficient, say so explicitly."
+)
+
+ANSWER_PROMPT = """Question: {question}
+
+Passages:
+{context}
+
+Write a concise, well-grounded answer. Use inline citations like [1], [2] that
+match the passage numbers above. If multiple passages support a claim, cite
+them all. If the passages do not contain enough information, say so plainly.
+"""
+
+
+def _format_context_block(hits: list[Hit]) -> tuple[str, list[dict]]:
+    lines = []
+    citations = []
+    for i, h in enumerate(hits, start=1):
+        meta = h.metadata
+        title = meta.get("title") or meta.get("source_path", "?")
+        pages = f"p.{meta.get('page_start')}-{meta.get('page_end')}"
+        head = f"[{i}] {title} ({pages})"
+        lines.append(f"{head}\n{h.text}")
+        citations.append(
+            {
+                "n": i,
+                "chunk_id": h.chunk_id,
+                "title": title,
+                "source_path": meta.get("source_path"),
+                "page_start": meta.get("page_start"),
+                "page_end": meta.get("page_end"),
+                "score": float(h.score),
+            }
+        )
+    return "\n\n".join(lines), citations
+
+
+def _dedupe_hits(hits: list[Hit], limit: int) -> list[Hit]:
+    seen: set[str] = set()
+    out: list[Hit] = []
+    for h in hits:
+        if h.chunk_id in seen:
+            continue
+        seen.add(h.chunk_id)
+        out.append(h)
+        if len(out) >= limit:
+            break
+    return out
+
+
+def run_agent(query: str, llm: OllamaClient | None = None) -> AgentResult:
+    llm = llm or OllamaClient()
+    trace: list[TraceStep] = []
+
+    # 1. Router
+    decision = route(query, llm=llm)
+    trace.append(TraceStep("router", decision))
+
+    if decision["action"] == "ANSWER_DIRECTLY":
+        ans = llm.generate(
+            prompt=query,
+            system="You are a helpful research assistant. Be concise.",
+            temperature=0.2,
+        )
+        return AgentResult(
+            answer=ans,
+            citations=[],
+            confidence=1.0,
+            trace=trace,
+            iterations=0,
+            route="ANSWER_DIRECTLY",
+        )
+
+    if decision["action"] == "CLARIFY":
+        ans = llm.generate(
+            prompt=(
+                "The user asked: " + query + "\n\n"
+                "It is too ambiguous to answer well. Ask one short clarifying "
+                "question to narrow it down."
+            ),
+            system="You are a helpful research assistant.",
+            temperature=0.2,
+        )
+        return AgentResult(
+            answer=ans,
+            citations=[],
+            confidence=0.0,
+            trace=trace,
+            iterations=0,
+            route="CLARIFY",
+        )
+
+    # 2. RETRIEVE branch — agentic loop
+    current_query = query
+    last_critique: dict[str, Any] = {}
+    accumulated: list[Hit] = []
+
+    for iteration in range(AGENT_CONFIG["max_iterations"]):
+        prior_summary = ""
+        if accumulated:
+            titles = sorted({h.metadata.get("title", "?") for h in accumulated})
+            prior_summary = "Already gathered passages from: " + ", ".join(titles)
+
+        steps = plan(current_query, prior_summary=prior_summary, llm=llm)
+        trace.append(TraceStep("plan", {"iteration": iteration, "steps": steps}))
+
+        for step in steps:
+            tool_res: ToolResult = vector_search(step["query"])
+            accumulated.extend(tool_res.hits)
+            trace.append(
+                TraceStep(
+                    "tool",
+                    {
+                        "tool": "vector_search",
+                        "query": step["query"],
+                        "n_hits": len(tool_res.hits),
+                        "top_titles": [
+                            h.metadata.get("title") for h in tool_res.hits[:3]
+                        ],
+                    },
+                )
+            )
+
+        unique_hits = _dedupe_hits(accumulated, limit=8)
+        context_block, citations = _format_context_block(unique_hits)
+
+        answer = llm.generate(
+            prompt=ANSWER_PROMPT.format(question=query, context=context_block),
+            system=ANSWER_SYSTEM,
+            temperature=0.1,
+        )
+        trace.append(TraceStep("answer", {"iteration": iteration, "n_passages": len(unique_hits)}))
+
+        crit = critique(query, answer, context_block, llm=llm)
+        last_critique = crit
+        trace.append(TraceStep("critique", {"iteration": iteration, **crit}))
+
+        if crit["confidence"] >= AGENT_CONFIG["confidence_threshold"] and crit["grounded"]:
+            return AgentResult(
+                answer=answer,
+                citations=citations,
+                confidence=crit["confidence"],
+                trace=trace,
+                iterations=iteration + 1,
+                route="RETRIEVE",
+            )
+
+        current_query = refine_query(query, crit.get("missing", ""), llm=llm)
+        trace.append(TraceStep("refine", {"new_query": current_query}))
+
+    return AgentResult(
+        answer=answer,
+        citations=citations,
+        confidence=last_critique.get("confidence", 0.0),
+        trace=trace,
+        iterations=AGENT_CONFIG["max_iterations"],
+        route="RETRIEVE",
+    )
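The control flow of the RETRIEVE branch (retrieve, answer, critique, then refine and retry while budget remains) can be exercised without any model by passing stub callables. The sketch below is a simplified stand-in, not the code above. The name `answer_with_retries`, the defaults `max_iterations=3` and `threshold=0.7`, and all stub data are assumptions, since the real values live in config.py, which is not shown here.

```python
from __future__ import annotations

from typing import Callable


def answer_with_retries(
    question: str,
    retrieve: Callable[[str], list[str]],
    answer: Callable[[str, list[str]], str],
    critique: Callable[[str, str, list[str]], dict],
    refine: Callable[[str, str], str],
    max_iterations: int = 3,
    threshold: float = 0.7,
) -> dict:
    """Retry retrieval with a refined query until the critic accepts the answer."""
    query = question
    gathered: list[str] = []
    result = {"answer": "", "confidence": 0.0, "iterations": 0}
    for i in range(max_iterations):
        for passage in retrieve(query):
            if passage not in gathered:  # keep each passage only once
                gathered.append(passage)
        text = answer(question, gathered)
        crit = critique(question, text, gathered)
        result = {"answer": text, "confidence": crit["confidence"], "iterations": i + 1}
        if crit["grounded"] and crit["confidence"] >= threshold:
            break
        query = refine(question, crit.get("missing", ""))
    return result


# Stubs standing in for the retriever and the three LLM calls (hypothetical data).
CORPUS = {
    "what is RRF?": ["RRF fuses ranked lists."],
    "RRF formula": ["score = sum 1/(k + rank)."],
}


def stub_retrieve(query: str) -> list[str]:
    return CORPUS.get(query, [])


def stub_answer(question: str, passages: list[str]) -> str:
    return " ".join(passages)


def stub_critique(question: str, text: str, passages: list[str]) -> dict:
    confident = len(passages) > 1
    return {"grounded": True, "confidence": 0.9 if confident else 0.4, "missing": "the formula"}


def stub_refine(question: str, missing: str) -> str:
    return "RRF formula"


result = answer_with_retries("what is RRF?", stub_retrieve, stub_answer, stub_critique, stub_refine)
print(result)
```

With these stubs the first pass is rejected at confidence 0.4, the refined query adds a second passage, and the second pass is accepted.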
agent/planner.py ADDED
@@ -0,0 +1,52 @@
+"""Multi-step query planner. Break a question into focused sub-queries."""
+from __future__ import annotations
+
+from config import AGENT_CONFIG
+from llm.ollama_client import OllamaClient
+
+PLANNER_SYSTEM = (
+    "You are a research planner. Given a user question, decompose it into a small "
+    "number of focused sub-queries. Each sub-query targets one piece of information "
+    "needed to answer the original question. Avoid redundant or overly broad steps."
+)
+
+PLANNER_PROMPT = """Decompose the user question into 1-{max_steps} focused retrieval sub-queries.
+Use fewer steps when the question is simple; only use multiple steps for genuinely
+multi-part or comparative questions.
+
+Each sub-query should be a self-contained search query (10-20 words) phrased to
+match passages in academic papers.
+
+Respond with strict JSON only:
+{{"steps": [
+  {{"query": "<search query>", "rationale": "<what this sub-query is looking for>"}}
+]}}
+
+User question: {query}
+
+Context already gathered (may be empty):
+{context_summary}
+"""
+
+
+def plan(query: str, prior_summary: str = "", llm: OllamaClient | None = None) -> list[dict]:
+    llm = llm or OllamaClient()
+    out = llm.generate_json(
+        prompt=PLANNER_PROMPT.format(
+            query=query,
+            max_steps=AGENT_CONFIG["max_plan_steps"],
+            context_summary=prior_summary or "(none)",
+        ),
+        system=PLANNER_SYSTEM,
+        temperature=0.1,
+    )
+    steps = out.get("steps") if isinstance(out, dict) else None
+    if not steps or not isinstance(steps, list):
+        return [{"query": query, "rationale": "fallback: use the original question"}]
+    cleaned: list[dict] = []
+    for s in steps[: AGENT_CONFIG["max_plan_steps"]]:
+        if isinstance(s, dict) and s.get("query"):
+            cleaned.append(
+                {"query": str(s["query"]).strip(), "rationale": str(s.get("rationale", "")).strip()}
+            )
+    return cleaned or [{"query": query, "rationale": "fallback"}]
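To see how the planner's fallback behaves on malformed model output, the cleaning step can be isolated from the LLM call. The sketch below mirrors the logic above under assumptions: `clean_steps` is a hypothetical helper, `max_steps=3` stands in for `AGENT_CONFIG["max_plan_steps"]` (whose value is not shown here), and the raw payloads are invented.

```python
from __future__ import annotations

from typing import Any


def clean_steps(raw: Any, question: str, max_steps: int = 3) -> list[dict]:
    """Keep well-formed sub-queries; fall back to the original question otherwise."""
    fallback = [{"query": question, "rationale": "fallback: use the original question"}]
    steps = raw.get("steps") if isinstance(raw, dict) else None
    if not isinstance(steps, list) or not steps:
        return fallback
    cleaned = [
        {"query": str(s["query"]).strip(), "rationale": str(s.get("rationale", "")).strip()}
        for s in steps[:max_steps]  # truncate first, then drop malformed entries
        if isinstance(s, dict) and s.get("query")
    ]
    return cleaned or fallback


question = "How does Self-RAG differ from RAG?"
messy = {"steps": [{"query": " Self-RAG reflection tokens ", "rationale": "core idea"},
                   "junk",
                   {"query": "RAG retriever generator"},
                   {"query": "dropped by the step limit"}]}
print(clean_steps(messy, question))
print(clean_steps("not a dict", question))
```

Truncating before filtering means a malformed entry still consumes one of the allowed slots, which is also how the loop in `plan` behaves.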
agent/router.py ADDED
@@ -0,0 +1,39 @@
+"""Self-RAG router. Decide whether to retrieve, answer directly, or clarify."""
+from __future__ import annotations
+
+from llm.ollama_client import OllamaClient
+
+ROUTER_SYSTEM = (
+    "You are a routing classifier for an AI research assistant whose knowledge base "
+    "contains papers on Transformers, BERT, GPT-3, diffusion (DDPM/DDIM), RAG, "
+    "Self-RAG, HyDE, ViT, CLIP, ReAct, Chain-of-Thought, and an LLM survey. "
+    "Decide how to handle a user query."
+)
+
+ROUTER_PROMPT = """Classify the query into one of three actions:
+
+- "RETRIEVE": the user is asking about substantive content (concepts, methods, comparisons,
+  details from papers). The knowledge base is likely needed. Default to this when unsure.
+- "ANSWER_DIRECTLY": pure conversational/meta queries (greetings, "what can you do",
+  "thanks") that need NO knowledge lookup.
+- "CLARIFY": the query is too ambiguous or under-specified to act on (e.g. "tell me more"
+  with no prior context, "what about that paper" with no referent).
+
+Respond with strict JSON only:
+{{"action": "RETRIEVE" | "ANSWER_DIRECTLY" | "CLARIFY", "reason": "<one short sentence>"}}
+
+Query: {query}
+"""
+
+
+def route(query: str, llm: OllamaClient | None = None) -> dict:
+    llm = llm or OllamaClient()
+    out = llm.generate_json(
+        prompt=ROUTER_PROMPT.format(query=query),
+        system=ROUTER_SYSTEM,
+        temperature=0.0,
+    )
+    action = str(out.get("action", "RETRIEVE")).upper()
+    if action not in {"RETRIEVE", "ANSWER_DIRECTLY", "CLARIFY"}:
+        action = "RETRIEVE"
+    return {"action": action, "reason": out.get("reason", "")}
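The router relies on `generate_json` to return a dict, but how that method copes with extra text around the JSON is not visible in this section. As a sketch of one way to make the routing decision robust to chatty model output, the hypothetical helper `parse_route` below extracts the outermost braces, parses them, and falls back to RETRIEVE on any problem. The sample strings are invented.

```python
from __future__ import annotations

import json

VALID_ACTIONS = {"RETRIEVE", "ANSWER_DIRECTLY", "CLARIFY"}


def parse_route(raw_text: str) -> dict:
    """Extract a routing decision from raw model text, defaulting to RETRIEVE."""
    start, end = raw_text.find("{"), raw_text.rfind("}")
    data: object = {}
    if start != -1 and end > start:
        try:
            data = json.loads(raw_text[start : end + 1])
        except json.JSONDecodeError:
            data = {}
    if not isinstance(data, dict):
        data = {}
    action = str(data.get("action", "RETRIEVE")).upper()
    if action not in VALID_ACTIONS:
        action = "RETRIEVE"  # same safe default as route()
    return {"action": action, "reason": str(data.get("reason", ""))}


print(parse_route('Sure! {"action": "clarify", "reason": "no referent"}'))
print(parse_route("I think we should search the papers."))
```

Defaulting to RETRIEVE matches the prompt's "Default to this when unsure": an unnecessary retrieval costs latency, while a wrongly skipped one costs grounding.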
agent/tools.py ADDED
@@ -0,0 +1,85 @@
+"""Agent tools.
+
+The vector_search tool drives the hybrid retriever. image_retrieve_and_reason
+performs multimodal RAG: caption the image, retrieve text by caption+query, then
+ask Qwen3-VL to ground its answer in both image and text.
+"""
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from llm.ollama_client import OllamaClient
+from retrieval.dense import Hit
+from retrieval.pipeline import hybrid_retrieve
+
+
+@dataclass
+class ToolResult:
+    tool: str
+    query: str
+    hits: list[Hit]
+    notes: str = ""
+
+
+def vector_search(query: str, top_n: int | None = None) -> ToolResult:
+    hits = hybrid_retrieve(query, top_n=top_n)
+    return ToolResult(tool="vector_search", query=query, hits=hits)
+
+
+CAPTION_SYSTEM = "You describe images in concise, factual language."
+CAPTION_PROMPT = (
+    "Describe this image in 1-3 sentences. Mention the type of figure (chart, diagram, "
+    "screenshot, photo, equation, etc.), key labels, and the main visual content."
+)
+
+
+def caption_image(image_path: str, llm: OllamaClient | None = None) -> str:
+    llm = llm or OllamaClient()
+    return llm.generate(
+        prompt=CAPTION_PROMPT,
+        system=CAPTION_SYSTEM,
+        images=[image_path],
+        temperature=0.0,
+    )
+
+
+MM_SYSTEM = (
+    "You are a careful research assistant. Answer using ONLY the provided image and "
+    "the cited text passages. If the answer is not supported, say so."
+)
+MM_PROMPT = """Image (provided separately) + question.
+
+Question: {question}
+
+Relevant passages:
+{context}
+
+Answer concisely. When citing a passage, use [N] where N is the passage number.
+"""
+
+
+def image_retrieve_and_reason(
+    image_path: str, query: str, llm: OllamaClient | None = None
+) -> dict:
+    llm = llm or OllamaClient()
+    caption = caption_image(image_path, llm=llm)
+    fused_query = f"{caption} {query}".strip()
+    hits = hybrid_retrieve(fused_query)
+    context_block = _format_context(hits)
+    answer = llm.generate(
+        prompt=MM_PROMPT.format(question=query, context=context_block),
+        system=MM_SYSTEM,
+        images=[image_path],
+        temperature=0.1,
+    )
+    return {"caption": caption, "answer": answer, "hits": hits}
+
+
+def _format_context(hits: list[Hit]) -> str:
+    lines = []
+    for i, h in enumerate(hits, start=1):
+        meta = h.metadata
+        head = f"[{i}] {meta.get('title', meta.get('source_path', '?'))} "
+        head += f"(p.{meta.get('page_start')}-{meta.get('page_end')})"
+        lines.append(f"{head}\n{h.text}")
+    return "\n\n".join(lines)
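The numbered passage block that the answer prompts cite with [N] can be previewed in isolation. In the sketch below the `Hit` dataclass is a stand-in whose fields are inferred from how hits are used above (the real class lives in retrieval/dense.py, which is not shown here), and the chunk ids, texts, and scores are made up.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Hit:
    """Stand-in for retrieval.dense.Hit; the field set is an assumption."""
    chunk_id: str
    text: str
    score: float
    metadata: dict = field(default_factory=dict)


def format_context(hits: list[Hit]) -> str:
    """Render hits as '[N] title (p.start-end)' headers followed by the passage text."""
    blocks = []
    for n, h in enumerate(hits, start=1):
        title = h.metadata.get("title") or h.metadata.get("source_path", "?")
        pages = f"p.{h.metadata.get('page_start')}-{h.metadata.get('page_end')}"
        blocks.append(f"[{n}] {title} ({pages})\n{h.text}")
    return "\n\n".join(blocks)


hits = [
    Hit("attention::0007", "Self-attention relates positions of a sequence.", 0.91,
        {"title": "Attention Is All You Need", "page_start": 2, "page_end": 3}),
    Hit("rag::0012", "RAG combines a retriever with a generator.", 0.84,
        {"source_path": "papers/rag.pdf", "page_start": 1, "page_end": 1}),
]
block = format_context(hits)
print(block)
```

The passage number is the only handle the model gets, so the citation list returned to the UI has to be built in the same order as this block.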
app.py ADDED
@@ -0,0 +1,554 @@
+"""AdaptiveRAG — under-the-hood pipeline visualizer.
+
+Run: streamlit run app.py
+"""
+from __future__ import annotations
+
+import json
+import logging
+import os
+import tempfile
+import time
+from pathlib import Path
+
+os.environ.setdefault("ANONYMIZED_TELEMETRY", "False")
+
+# suppress harmless noise from Streamlit's torch inspector + ChromaDB posthog client
+logging.getLogger("streamlit.watcher.local_sources_watcher").setLevel(logging.ERROR)
+logging.getLogger("chromadb.telemetry.product.posthog").setLevel(logging.ERROR)
+
+import numpy as np
+import pandas as pd
+import streamlit as st
+
+from agent.critic import critique, refine_query
+from agent.planner import plan
+from agent.router import route
+from agent.tools import image_retrieve_and_reason
+from config import AGENT_CONFIG, EMBEDDING_CONFIG, HOSTED, LLM_CONFIG, PATHS, RETRIEVAL_CONFIG
+from ingestion.embedder import embed_query
+from ingestion.indexer import fetch_embeddings
+from llm.client_factory import get_llm
+from retrieval.dense import Hit, dense_search
+from retrieval.hybrid import reciprocal_rank_fusion
+from retrieval.reranker import rerank
+from retrieval.sparse import sparse_search
+
+st.set_page_config(page_title="AdaptiveRAG — Underhood", page_icon="🔬", layout="wide")
+
+# ───────────────────────────── styling ──────────────────────────────
+st.markdown(
+    """
+    <style>
+    .phase-card {
+        border-left: 4px solid var(--accent, #4f8cff);
+        padding: .6rem 1rem;
+        margin: .25rem 0 .5rem 0;
+        background: rgba(79,140,255,0.06);
+        border-radius: 6px;
+    }
+    .phase-num { color: #4f8cff; font-weight: 700; margin-right: .4rem; }
+    .pill { display: inline-block; padding: .15rem .55rem; border-radius: 999px;
+            font-size: .78rem; font-weight: 600; margin-right: .4rem; }
+    .pill-blue { background: #1e3a5f; color: #9ec5ff; }
+    .pill-green { background: #1e4f30; color: #a3e6b5; }
+    .pill-purple { background: #3d2a5e; color: #c8a8f5; }
+    .pill-amber { background: #5e3f0e; color: #f3c97a; }
+    .pill-red { background: #5a1f1f; color: #f3a3a3; }
+    .pill-grey { background: #2c2c33; color: #b8b8c0; }
+    .chunk-card {
+        background: rgba(255,255,255,0.03);
+        border: 1px solid rgba(255,255,255,0.08);
+        border-radius: 6px; padding: .55rem .7rem; margin-bottom: .4rem;
+        font-size: .82rem;
+    }
+    .chunk-meta { color: #9aa3b2; font-size: .73rem; margin-bottom: .25rem; }
+    .mini-vec {
+        font-family: 'SF Mono', Menlo, monospace; font-size: .68rem;
+        color: #8b949e; word-break: break-all;
+    }
+    </style>
+    """,
+    unsafe_allow_html=True,
+)
+
+
+# ───────────────────────────── helpers ──────────────────────────────
+@st.cache_resource
+def _llm():
+    return get_llm()
+
+
+def _load_manifest() -> dict:
+    p = PATHS["manifest_path"]
+    return json.loads(p.read_text()) if p.exists() else {}
+
+
+def phase_header(num: int, title: str, subtitle: str = "") -> None:
+    st.markdown(
+        f"<div class='phase-card'><span class='phase-num'>STEP {num}</span>"
+        f"<b>{title}</b><br><span style='color:#9aa3b2;font-size:.85rem;'>{subtitle}</span>"
+        f"</div>",
+        unsafe_allow_html=True,
+    )
+
+
+def hits_to_df(hits: list[Hit], score_label: str = "score") -> pd.DataFrame:
+    rows = []
+    for h in hits:
+        title = h.metadata.get("title") or h.metadata.get("source_path", "?")
+        short = title.split(" (")[0]
+        if len(short) > 38:
+            short = short[:35] + "…"
+        label = f"{short} · p{h.metadata.get('page_start')} · {h.chunk_id.split('::')[-1]}"
+        rows.append({"chunk": label, score_label: float(h.score), "chunk_id": h.chunk_id})
+    return pd.DataFrame(rows)
+
+
+def render_hits(hits: list[Hit], badge_class: str, label: str, max_chars: int = 220) -> None:
+    if not hits:
+        st.caption(f"_(no {label.lower()} hits)_")
+        return
+    for i, h in enumerate(hits, start=1):
+        meta = h.metadata
+        snippet = h.text[:max_chars].replace("\n", " ")
+        if len(h.text) > max_chars:
+            snippet += "…"
+        st.markdown(
+            f"<div class='chunk-card'>"
+            f"<div class='chunk-meta'>"
+            f"<span class='pill {badge_class}'>{label} #{i}</span>"
+            f"score <b>{h.score:.3f}</b> · "
+            f"{meta.get('title','?')} · p.{meta.get('page_start')}–{meta.get('page_end')} · "
+            f"<code>{h.chunk_id}</code>"
+            f"</div>{snippet}</div>",
+            unsafe_allow_html=True,
+        )
+
+
+def pca_2d(matrix: np.ndarray) -> np.ndarray:
+    centered = matrix - matrix.mean(axis=0, keepdims=True)
+    _, _, vt = np.linalg.svd(centered, full_matrices=False)
+    return centered @ vt[:2].T
+
+
+def vector_space_plot(query_vec: list[float], fused_hits: list[Hit],
+                      dense_ids: set[str], sparse_ids: set[str],
+                      kept_ids: set[str]) -> None:
+    if not fused_hits:
+        st.caption("_(nothing to plot)_")
+        return
+    embs = fetch_embeddings([h.chunk_id for h in fused_hits])
+    rows = []
+    vecs = [np.array(query_vec, dtype=np.float32)]
+    for h in fused_hits:
+        v = embs.get(h.chunk_id)
+        if v is None:
+            continue
+        vecs.append(np.array(v, dtype=np.float32))
+        in_d, in_s = h.chunk_id in dense_ids, h.chunk_id in sparse_ids
+        in_keep = h.chunk_id in kept_ids
+        if in_d and in_s:
+            color = "fused (both)"
+        elif in_d:
+            color = "dense only"
+        elif in_s:
+            color = "sparse only"
+        else:
+            color = "other"
+        title = (h.metadata.get("title") or "?").split(" (")[0][:40]
+        label = f"{title} · p{h.metadata.get('page_start')}"
+        rows.append({"label": label, "color": color, "size": 90 if in_keep else 50})
+    if len(vecs) < 3:
+        st.caption("_(need at least 2 hits for a 2D projection)_")
+        return
+    proj = pca_2d(np.vstack(vecs))
+    df = pd.DataFrame(
+        [{"x": proj[0, 0], "y": proj[0, 1], "label": "🔎 your question",
+          "color": "QUERY", "size": 220}]
+        + [{"x": proj[i + 1, 0], "y": proj[i + 1, 1], **rows[i]}
+           for i in range(len(rows))]
+    )
+    st.scatter_chart(
+        df, x="x", y="y", color="color", size="size",
+        height=380, use_container_width=True,
+    )
+    st.caption(
+        "PCA projection of the query embedding + fused hit embeddings. "
+        "Larger points survived cross-encoder reranking."
+    )
+
+
+def render_embedding_card(query: str, qv: list[float], dt: float) -> None:
+    arr = np.array(qv, dtype=np.float32)
+    cols = st.columns([1, 1, 1, 3])
+    cols[0].metric("Model", EMBEDDING_CONFIG["model"].split("/")[-1])
+    cols[1].metric("Dimensions", len(qv))
+    cols[2].metric("L2 norm", f"{float(np.linalg.norm(arr)):.3f}")
+    cols[3].metric("Embed time", f"{dt*1000:.0f} ms")
+    st.caption(f"Question ({len(query)} chars, ~{len(query.split())} words):")
+    st.code(query, language="text")
+    st.caption("First 32 dimensions of the embedding vector:")
+    st.bar_chart(pd.DataFrame({"value": arr[:32]}), height=140, use_container_width=True)
+    preview = ", ".join(f"{x:+.3f}" for x in arr[:8]) + ", …"
+    st.markdown(f"<span class='mini-vec'>vector[0:8] = [{preview}]</span>",
+                unsafe_allow_html=True)
+
+
198
+ # ───────────────────────────── pipeline view ──────────────────────────────
199
+ def visual_pipeline(query: str) -> None:
200
+ llm = _llm()
201
+
202
+ # ── Step 1: embed the question ────────────────────────────────
203
+ phase_header(1, "Question encoding",
204
+ "Convert text → 384-dim dense vector via sentence-transformers (MiniLM-L6).")
205
+ t0 = time.time()
206
+ qv = embed_query(query)
207
+ render_embedding_card(query, qv, time.time() - t0)
208
+
209
+ # ── Step 2: Self-RAG router ────────────────────────────────
210
+ phase_header(2, "Self-RAG router",
211
+ "Decide whether to RETRIEVE, ANSWER_DIRECTLY, or CLARIFY before touching the index.")
212
+ t0 = time.time()
213
+ decision = route(query, llm=llm)
214
+ dt = time.time() - t0
215
+ pill_map = {"RETRIEVE": "pill-blue", "ANSWER_DIRECTLY": "pill-green", "CLARIFY": "pill-amber"}
216
+ pill = pill_map.get(decision["action"], "pill-grey")
217
+ st.markdown(
218
+ f"<span class='pill {pill}'>{decision['action']}</span>"
219
+ f"<span style='color:#9aa3b2;'>{decision.get('reason','')}</span>"
220
+ f"<span style='float:right;color:#9aa3b2;font-size:.78rem;'>"
221
+ f"router latency: {dt*1000:.0f} ms</span>",
222
+ unsafe_allow_html=True,
223
+ )
224
+
225
+ if decision["action"] == "ANSWER_DIRECTLY":
226
+ st.markdown("### Direct answer (no retrieval)")
227
+ ans = llm.generate(prompt=query,
228
+ system="You are a helpful research assistant. Be concise.",
229
+ temperature=0.2)
230
+ st.markdown(ans)
231
+ return
232
+ if decision["action"] == "CLARIFY":
233
+ st.markdown("### Clarifying question")
234
+ ans = llm.generate(
235
+ prompt=("The user asked: " + query +
236
+ "\n\nIt is too ambiguous to answer well. Ask one short clarifying question."),
237
+ system="You are a helpful research assistant.",
238
+ temperature=0.2,
239
+ )
240
+ st.markdown(ans)
241
+ return
242
+
+    # ── Iterations of plan → retrieve → answer → critique ────────
+    accumulated: list[Hit] = []
+    current_query = query
+
+    for it in range(AGENT_CONFIG["max_iterations"]):
+        st.markdown(f"---\n## 🔁 Iteration {it + 1}")
+        if current_query != query:
+            st.info(f"Refined query → **{current_query}**")
+
+        # ── Step 3: plan ─────────────────────────────────────
+        phase_header(3, "Planner", "LLM decomposes the question into focused sub-queries.")
+        prior = ""
+        if accumulated:
+            titles = sorted({h.metadata.get("title", "?") for h in accumulated})
+            prior = "Already gathered passages from: " + ", ".join(titles)
+        t0 = time.time()
+        steps = plan(current_query, prior_summary=prior, llm=llm)
+        dt = time.time() - t0
+        st.caption(f"Generated {len(steps)} sub-quer{'y' if len(steps)==1 else 'ies'} in {dt*1000:.0f} ms")
+        for i, s in enumerate(steps, start=1):
+            st.markdown(
+                f"<div class='chunk-card'>"
+                f"<span class='pill pill-purple'>sub-query {i}</span>"
+                f"<b>{s['query']}</b>"
+                f"<div class='chunk-meta' style='margin-top:.3rem;'>"
+                f"rationale: {s.get('rationale','—')}</div></div>",
+                unsafe_allow_html=True,
+            )
+
+        # ── Step 4: retrieval per sub-query ──────────────────
+        phase_header(
+            4,
+            "Hybrid retrieval per sub-query",
+            f"Dense (Chroma cosine, k={RETRIEVAL_CONFIG['dense_k']}) ∥ "
+            f"Sparse (BM25, k={RETRIEVAL_CONFIG['sparse_k']}) → "
+            f"Reciprocal Rank Fusion → Cross-encoder rerank "
+            f"(BGE, top {RETRIEVAL_CONFIG['rerank_top_n']}).",
+        )
+
+        for si, step in enumerate(steps, start=1):
+            with st.expander(f"Sub-query {si}: {step['query']}", expanded=(si == 1)):
+                t0 = time.time()
+                dense_hits = dense_search(step["query"])
+                t_dense = time.time() - t0
+                t0 = time.time()
+                sparse_hits = sparse_search(step["query"])
+                t_sparse = time.time() - t0
+                t0 = time.time()
+                fused = reciprocal_rank_fusion([dense_hits, sparse_hits],
+                                               top_k=max(RETRIEVAL_CONFIG["dense_k"],
+                                                         RETRIEVAL_CONFIG["sparse_k"]))
+                t_fuse = time.time() - t0
+                t0 = time.time()
+                reranked = rerank(step["query"], fused)
+                t_rerank = time.time() - t0
+
+                m1, m2, m3, m4 = st.columns(4)
+                m1.metric("Dense hits", len(dense_hits), f"{t_dense*1000:.0f} ms")
+                m2.metric("Sparse hits", len(sparse_hits), f"{t_sparse*1000:.0f} ms")
+                m3.metric("After RRF", len(fused), f"{t_fuse*1000:.0f} ms")
+                m4.metric("After rerank", len(reranked), f"{t_rerank*1000:.0f} ms")
+
+                tabs = st.tabs([
+                    "🔵 Dense (vectors)",
+                    "🟢 Sparse (BM25)",
+                    "🟣 RRF fusion",
+                    "🟡 Cross-encoder rerank",
+                    "🗺️ Vector space",
+                ])
+                with tabs[0]:
+                    st.caption("Top-K nearest neighbors by cosine similarity.")
+                    if dense_hits:
+                        st.bar_chart(hits_to_df(dense_hits, "cosine_sim"),
+                                     x="chunk", y="cosine_sim",
+                                     height=260, use_container_width=True)
+                    render_hits(dense_hits[:5], "pill-blue", "DENSE")
+
+                with tabs[1]:
+                    st.caption("Top-K BM25 keyword matches (normalized).")
+                    if sparse_hits:
+                        st.bar_chart(hits_to_df(sparse_hits, "bm25_norm"),
+                                     x="chunk", y="bm25_norm",
+                                     height=260, use_container_width=True)
+                    render_hits(sparse_hits[:5], "pill-green", "BM25")
+
+                with tabs[2]:
+                    st.caption(
+                        "Reciprocal Rank Fusion: score(d) = Σ 1/(k + rank). "
+                        "Combines dense + sparse rankings into one merged list."
+                    )
+                    if fused:
+                        st.bar_chart(hits_to_df(fused[:12], "rrf_score"),
+                                     x="chunk", y="rrf_score",
+                                     height=280, use_container_width=True)
+                    render_hits(fused[:5], "pill-purple", "FUSED")
+
+                with tabs[3]:
+                    st.caption(
+                        "Cross-encoder scores (query, chunk) jointly — much more "
+                        "accurate than bi-encoder cosine, but slower → only run on "
+                        "the fused candidate set."
+                    )
+                    if reranked:
+                        st.bar_chart(hits_to_df(reranked, "ce_score"),
+                                     x="chunk", y="ce_score",
+                                     height=240, use_container_width=True)
+                    render_hits(reranked, "pill-amber", "RERANKED")
+
+                with tabs[4]:
+                    dense_ids = {h.chunk_id for h in dense_hits}
+                    sparse_ids = {h.chunk_id for h in sparse_hits}
+                    kept_ids = {h.chunk_id for h in reranked}
+                    vector_space_plot(qv, fused[:20], dense_ids, sparse_ids, kept_ids)
+
+                accumulated.extend(reranked)
+
+        # ── Step 5: answer ─────────────────────────────────────
+        # Dedupe + cap to 8 passages for the final prompt
+        seen: set[str] = set()
+        unique: list[Hit] = []
+        for h in accumulated:
+            if h.chunk_id in seen:
+                continue
+            seen.add(h.chunk_id)
+            unique.append(h)
+            if len(unique) >= 8:
+                break
+        context_lines, citations = [], []
+        for i, h in enumerate(unique, start=1):
+            meta = h.metadata
+            head = (f"[{i}] {meta.get('title','?')} "
+                    f"(p.{meta.get('page_start')}-{meta.get('page_end')})")
+            context_lines.append(f"{head}\n{h.text}")
+            citations.append({
+                "n": i, "chunk_id": h.chunk_id,
+                "title": meta.get("title"),
+                "source_path": meta.get("source_path"),
+                "page_start": meta.get("page_start"),
+                "page_end": meta.get("page_end"),
+                "score": float(h.score),
+            })
+        context_block = "\n\n".join(context_lines)
+
+        phase_header(5, "Context assembly + answer generation",
+                     f"Top {len(unique)} unique passages → {LLM_CONFIG['model']} via {LLM_CONFIG['provider']}.")
+        with st.expander("📦 Context handed to the LLM", expanded=False):
+            for c in citations:
+                st.markdown(
+                    f"**[{c['n']}]** {c['title']} · pages {c['page_start']}–{c['page_end']} · "
+                    f"score `{c['score']:.3f}`"
+                )
+            st.code(context_block[:3000] + ("…" if len(context_block) > 3000 else ""),
+                    language="text")
+
+        t0 = time.time()
+        ANSWER_SYSTEM = (
+            "You are a careful research assistant. Use ONLY the provided passages to "
+            "answer the question. Cite sources inline with [N] where N is the passage "
+            "number. If the passages are insufficient, say so explicitly."
+        )
+        ANSWER_PROMPT = (
+            f"Question: {query}\n\nPassages:\n{context_block}\n\n"
+            "Write a concise, well-grounded answer. Use inline citations like [1], [2] "
+            "that match the passage numbers above."
+        )
+        answer = llm.generate(prompt=ANSWER_PROMPT, system=ANSWER_SYSTEM, temperature=0.1)
+        st.caption(f"LLM generation: {time.time()-t0:.1f} s")
+        st.markdown("### Answer")
+        st.markdown(answer)
+
+        st.markdown("### Citations")
+        for c in citations:
+            st.markdown(
+                f"**[{c['n']}]** {c['title']} — pages {c['page_start']}–{c['page_end']} "
+                f"· score `{c['score']:.3f}` · `{Path(c['source_path']).name}`"
+            )
+
+        # ── Step 6: critic ─────────────────────────────────────
+        phase_header(6, "Self-critique",
+                     "LLM scores its own answer for grounding + completeness.")
+        t0 = time.time()
+        crit = critique(query, answer, context_block, llm=llm)
+        c1, c2, c3 = st.columns(3)
+        c1.metric("Grounded", "✅ yes" if crit["grounded"] else "⚠️ no")
+        c2.metric("Complete", "✅ yes" if crit["complete"] else "⚠️ no")
+        c3.metric("Confidence", f"{crit['confidence']:.2f}",
+                  delta=f"threshold {AGENT_CONFIG['confidence_threshold']:.2f}")
+        if crit.get("missing"):
+            st.warning(f"Missing: {crit['missing']}")
+        st.caption(f"Critique latency: {time.time()-t0:.1f} s")
+
+        if crit["confidence"] >= AGENT_CONFIG["confidence_threshold"] and crit["grounded"]:
+            st.success(f"✓ Confidence {crit['confidence']:.2f} ≥ threshold — answer accepted.")
+            return
+
+        if it < AGENT_CONFIG["max_iterations"] - 1:
+            st.warning("Confidence below threshold — refining query and retrying.")
+            current_query = refine_query(query, crit.get("missing", ""), llm=llm)
+        else:
+            st.error("Max iterations reached. Returning best-effort answer.")
+
+
+# ───────────────────────────── sidebar + tabs ──────────────────────────────
+def _sidebar() -> None:
+    st.sidebar.title("AdaptiveRAG")
+    st.sidebar.caption("Agentic + Self-RAG + Modular RAG")
+    llm = _llm()
+    ok = llm.health()
+    backend = "Groq API" if HOSTED else "Ollama (local)"
+    st.sidebar.markdown(f"**LLM backend**: {'🟢' if ok else '🔴'} {backend}")
+    st.sidebar.markdown(f"**Model**: `{LLM_CONFIG['model']}`")
+    st.sidebar.markdown(f"**Embedder**: `{EMBEDDING_CONFIG['model'].split('/')[-1]}`")
+    st.sidebar.markdown(f"**Reranker**: `bge-reranker-base`")
+    manifest = _load_manifest()
+    if manifest:
+        st.sidebar.markdown(f"**Index**: {manifest.get('n_chunks','?')} chunks across "
+                            f"{len(manifest.get('chunks_per_doc',{}))} docs")
+        with st.sidebar.expander("Documents"):
+            for doc, n in sorted(manifest.get("chunks_per_doc", {}).items()):
+                st.markdown(f"- `{doc}` — {n}")
+    else:
+        st.sidebar.warning("No index found. Run `python ingest.py --reset`.")
+    st.sidebar.divider()
+    st.sidebar.markdown("### Pipeline")
+    st.sidebar.code(
+        "question\n ↓ embed (MiniLM)\n ↓ Self-RAG router\n ↓ planner → sub-queries\n"
+        " ↓ dense ∥ sparse\n ↓ RRF fusion\n ↓ cross-encoder rerank\n ↓ Qwen3-VL answer\n"
+        " ↓ self-critique → retry?\n → answer + citations",
+        language="text",
+    )
+
+
+def pipeline_tab() -> None:
+    st.subheader("🔬 Under the hood: watch every stage of the agentic RAG pipeline")
+    st.caption(
+        "Each step renders its inputs and outputs as it runs — embedding vector, "
+        "router decision, planner sub-queries, dense vs sparse hits side-by-side, "
+        "RRF fusion, cross-encoder rerank, vector-space projection, answer, self-critique."
+    )
+    samples = [
+        "How does Self-RAG decide when to retrieve, and what reflection tokens does it use?",
+        "Compare DDPM and DDIM sampling — what does DDIM gain by being non-Markovian?",
+        "What is multi-head self-attention and why does parallelism matter?",
+        "How does HyDE improve dense retrieval without relevance labels?",
+        "How does ReAct combine reasoning and acting, vs chain-of-thought?",
+        "hello, what can you do?",
+    ]
+    if "vq" not in st.session_state:
+        st.session_state.vq = samples[0]
+    cols = st.columns(3)
+    for i, s in enumerate(samples):
+        if cols[i % 3].button(s, key=f"vs{i}", use_container_width=True):
+            st.session_state.vq = s
+    q = st.text_area("Question", value=st.session_state.vq, height=80, key="vq_input")
+    if st.button("▶ Run pipeline", type="primary"):
+        if q.strip():
+            visual_pipeline(q.strip())
+
+
+def image_tab() -> None:
+    st.subheader("🖼️ Multimodal RAG (Qwen3-VL)")
+    st.caption(
+        "Upload an image (e.g. a figure from a paper). Qwen3-VL captions it, the "
+        "caption + question drives hybrid retrieval, then the model reasons over "
+        "image + retrieved passages together."
+    )
+    uploaded = st.file_uploader("Image", type=["png", "jpg", "jpeg", "webp"])
+    q = st.text_input("Question about the image", "Explain what this figure shows.")
+    go = st.button("Reason", type="primary", key="img_go")
+    if uploaded:
+        st.image(uploaded, width=400)
+    if not (go and uploaded):
+        return
+    with tempfile.NamedTemporaryFile(suffix=Path(uploaded.name).suffix, delete=False) as f:
+        f.write(uploaded.getbuffer())
+        tmp_path = f.name
+    try:
+        with st.spinner("Captioning → retrieving → multimodal reasoning..."):
+            out = image_retrieve_and_reason(tmp_path, q, llm=_llm())
+        st.markdown("### Caption")
+        st.write(out["caption"])
+        st.markdown("### Answer")
+        st.markdown(out["answer"])
+        st.markdown("### Retrieved passages")
+        for i, h in enumerate(out["hits"], start=1):
+            st.markdown(
+                f"**[{i}]** {h.metadata.get('title')} "
+                f"(p.{h.metadata.get('page_start')}–{h.metadata.get('page_end')}) "
+                f"· score `{h.score:.3f}`"
+            )
+            st.caption(h.text[:300] + ("…" if len(h.text) > 300 else ""))
+    finally:
+        os.unlink(tmp_path)
+
+
+def main() -> None:
+    _sidebar()
+    st.title("AdaptiveRAG 📚🔬")
+    st.caption(
+        "Agentic + Self-RAG + Modular RAG over your local paper library — "
+        f"powered by `{LLM_CONFIG['model']}` via **{LLM_CONFIG['provider']}**. "
+        "Every pipeline stage is exposed below."
+    )
+    pipe, img = st.tabs(["🔬 Under-the-hood pipeline", "🖼️ Image Q&A (multimodal)"])
+    with pipe:
+        pipeline_tab()
+    with img:
+        image_tab()
+
+
+if __name__ == "__main__":
+    main()
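Note on the fusion step: the RRF tab caption in `app.py` gives the formula score(d) = Σ 1/(k + rank). The block below is a minimal sketch of that formula over plain lists of chunk ids. It is not the repository's `reciprocal_rank_fusion`, whose body is not part of this section; the name `rrf_fuse`, the id-only inputs, and the default `k=60` are assumptions made for illustration.

```python
from __future__ import annotations

from collections import defaultdict


def rrf_fuse(rankings: list[list[str]], k: int = 60,
             top_k: int | None = None) -> list[tuple[str, float]]:
    """Fuse ranked id lists with Reciprocal Rank Fusion (ranks start at 1)."""
    scores: dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            # Each list contributes 1 / (k + rank) for every id it contains.
            scores[doc_id] += 1.0 / (k + rank)
    # Highest fused score first; ties broken by id so the order is deterministic.
    fused = sorted(scores.items(), key=lambda kv: (-kv[1], kv[0]))
    return fused if top_k is None else fused[:top_k]


dense = ["c1", "c2", "c3"]   # hypothetical dense ranking
sparse = ["c3", "c1", "c4"]  # hypothetical BM25 ranking
fused = rrf_fuse([dense, sparse])
```

Ids that appear in both lists (`c1`, `c3`) collect two reciprocal-rank terms and so rise above ids seen by only one retriever, which is the point of fusing by rank rather than by raw score.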
ask.py ADDED
@@ -0,0 +1,32 @@
+"""CLI: python ask.py 'your question here'"""
+from __future__ import annotations
+
+import argparse
+import json
+
+from agent.loop import run_agent
+
+
+def main() -> None:
+    ap = argparse.ArgumentParser()
+    ap.add_argument("question", nargs="+")
+    ap.add_argument("--trace", action="store_true")
+    args = ap.parse_args()
+    q = " ".join(args.question)
+    res = run_agent(q)
+    print("\n=== ROUTE ===")
+    print(res.route)
+    print("\n=== ANSWER ===")
+    print(res.answer)
+    print("\n=== CITATIONS ===")
+    for c in res.citations:
+        print(f" [{c['n']}] {c['title']} (p.{c['page_start']}-{c['page_end']}) score={c['score']:.3f}")
+    print(f"\nConfidence: {res.confidence:.2f} iterations: {res.iterations}")
+    if args.trace:
+        print("\n=== TRACE ===")
+        for s in res.trace:
+            print(f" • {s.kind}: {json.dumps(s.detail, default=str)[:200]}")
+
+
+if __name__ == "__main__":
+    main()
config.py ADDED
@@ -0,0 +1,57 @@
+"""Central configuration for AdaptiveRAG."""
+import os
+from pathlib import Path
+
+ROOT = Path(__file__).parent.resolve()
+
+# Detect hosting environment
+HOSTED = bool(os.environ.get("GROQ_API_KEY"))
+
+LLM_CONFIG = {
+    "provider": "groq" if HOSTED else "ollama",
+    "model": "llama-3.1-8b-instant" if HOSTED else "qwen3-vl:8b-instruct-q8_0-optimized",
+    "base_url": "https://api.groq.com/openai/v1" if HOSTED else "http://localhost:11434",
+    "temperature": 0.1,
+    "timeout": 60 if HOSTED else 180,
+    "num_ctx": 8192,
+}
+
+EMBEDDING_CONFIG = {
+    "model": "sentence-transformers/all-MiniLM-L6-v2",
+    "device": "cpu",
+    "batch_size": 32,
+}
+
+RERANKER_CONFIG = {
+    "model": "BAAI/bge-reranker-base",
+    "device": "cpu",
+}
+
+CHUNKING_CONFIG = {
+    "target_chunk_chars": 1400,
+    "max_chunk_chars": 2200,
+    "min_chunk_chars": 350,
+    "overlap_chars": 200,
+}
+
+RETRIEVAL_CONFIG = {
+    "dense_k": 12,
+    "sparse_k": 12,
+    "rrf_k": 60,
+    "rerank_top_n": 5,
+}
+
+AGENT_CONFIG = {
+    "max_iterations": 3,
+    "confidence_threshold": 0.85,
+    "max_plan_steps": 3,
+}
+
+PATHS = {
+    "papers_dir": ROOT / "papers",
+    "chroma_dir": ROOT / "storage" / "chroma",
+    "bm25_path": ROOT / "storage" / "bm25.pkl",
+    "manifest_path": ROOT / "storage" / "manifest.json",
+}
+
+CHROMA_COLLECTION = "adaptive_rag"
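Note on `AGENT_CONFIG`: `max_iterations` and `confidence_threshold` together define the stopping rule that `visual_pipeline` applies, which is to accept as soon as the critic reports a grounded answer at or above the threshold, and otherwise stop after the last iteration. The sketch below isolates that rule. `run_loop` and the stubbed verdicts are hypothetical stand-ins, and the real critic is an LLM call.

```python
from typing import Callable

# Values copied from config.py above.
AGENT_CONFIG = {"max_iterations": 3, "confidence_threshold": 0.85}


def run_loop(critic: Callable[[int], dict]) -> tuple[int, bool]:
    """Return (iterations used, accepted?) for a critic called once per iteration."""
    for it in range(AGENT_CONFIG["max_iterations"]):
        crit = critic(it)
        confident = crit["confidence"] >= AGENT_CONFIG["confidence_threshold"]
        if confident and crit["grounded"]:
            return it + 1, True
    # Budget exhausted: the caller falls back to the best-effort answer.
    return AGENT_CONFIG["max_iterations"], False


# Hypothetical verdicts: confident but ungrounded first, then acceptable.
verdicts = [
    {"confidence": 0.90, "grounded": False},
    {"confidence": 0.88, "grounded": True},
]
result = run_loop(lambda it: verdicts[min(it, len(verdicts) - 1)])
```

The first verdict shows why both conditions matter: a high confidence score alone does not end the loop when the answer is flagged as ungrounded.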
download_papers.sh ADDED
@@ -0,0 +1,28 @@
+#!/bin/bash
+mkdir -p papers && cd papers || exit 1
+
+# Transformers
+curl -L "https://arxiv.org/pdf/1706.03762" -o "01_attention_is_all_you_need.pdf"
+curl -L "https://arxiv.org/pdf/1810.04805" -o "02_bert.pdf"
+curl -L "https://arxiv.org/pdf/2005.14165" -o "03_gpt3.pdf"
+
+# Diffusion
+curl -L "https://arxiv.org/pdf/2006.11239" -o "04_ddpm.pdf"
+curl -L "https://arxiv.org/pdf/2010.02502" -o "05_ddim.pdf"
+
+# RAG
+curl -L "https://arxiv.org/pdf/2005.11401" -o "06_rag_original.pdf"
+curl -L "https://arxiv.org/pdf/2312.10997" -o "07_rag_survey.pdf"
+curl -L "https://arxiv.org/pdf/2310.11511" -o "08_self_rag.pdf"
+curl -L "https://arxiv.org/pdf/2212.10496" -o "09_hyde.pdf"
+
+# Vision
+curl -L "https://arxiv.org/pdf/2010.11929" -o "10_vit.pdf"
+curl -L "https://arxiv.org/pdf/2103.00020" -o "11_clip.pdf"
+
+# Agents
+curl -L "https://arxiv.org/pdf/2210.03629" -o "12_react.pdf"
+curl -L "https://arxiv.org/pdf/2201.11903" -o "13_chain_of_thought.pdf"
+curl -L "https://arxiv.org/pdf/2303.18223" -o "14_llm_survey.pdf"
+
+echo "Downloaded $(ls *.pdf | wc -l) papers"
ingest.py ADDED
@@ -0,0 +1,43 @@
+"""Ingest all PDFs from the papers/ directory into ChromaDB + BM25."""
+from __future__ import annotations
+
+import argparse
+import time
+
+from config import PATHS
+from ingestion.chunker import chunk_document
+from ingestion.indexer import index_chunks
+from ingestion.loader import discover_pdfs, load_pdf
+
+
+def main() -> None:
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--reset", action="store_true", help="Reset the index first")
+    ap.add_argument("--papers-dir", default=str(PATHS["papers_dir"]))
+    args = ap.parse_args()
+
+    pdfs = discover_pdfs(args.papers_dir)
+    if not pdfs:
+        print(f"No PDFs in {args.papers_dir}")
+        return
+
+    print(f"Found {len(pdfs)} PDFs in {args.papers_dir}")
+    all_chunks = []
+    t0 = time.time()
+    for path in pdfs:
+        doc_id = path.stem
+        print(f"\n[{doc_id}]")
+        doc = load_pdf(path)
+        print(f" Loaded: {len(doc.pages)} pages, title={doc.title!r}")
+        chunks = chunk_document(doc, doc_id=doc_id)
+        print(f" Chunked: {len(chunks)} chunks (avg {sum(len(c.text) for c in chunks)//max(len(chunks),1)} chars)")
+        all_chunks.extend(chunks)
+
+    print(f"\nIndexing {len(all_chunks)} chunks total...")
+    manifest = index_chunks(all_chunks, reset=args.reset)
+    dt = time.time() - t0
+    print(f"\nDone in {dt:.1f}s. Manifest: {manifest}")
+
+
+if __name__ == "__main__":
+    main()
ingestion/__init__.py ADDED
File without changes
ingestion/chunker.py ADDED
@@ -0,0 +1,104 @@
+"""Semantic-aware chunking.
+
+Strategy: split each page into sentences, then greedily group sentences into
+chunks targeting CHUNKING_CONFIG['target_chunk_chars']. Carry an overlap of
+the last few sentences (~overlap_chars) to the next chunk so context isn't
+sliced mid-thought. Headings hint chunk boundaries.
+"""
+from __future__ import annotations
+
+import re
+from dataclasses import dataclass
+
+from config import CHUNKING_CONFIG
+from ingestion.loader import LoadedDoc
+
+
+@dataclass
+class Chunk:
+    chunk_id: str
+    doc_id: str
+    source_path: str
+    title: str
+    page_start: int
+    page_end: int
+    text: str
+
+
+_SENTENCE_SPLIT = re.compile(r"(?<=[.!?])\s+(?=[A-Z0-9(])")
+_HEADING_HINT = re.compile(r"^(?:[0-9]+(?:\.[0-9]+)*\s+|abstract|introduction|conclusion|references|method|results|discussion|background)\b", re.IGNORECASE)
+
+
+def _split_sentences(text: str) -> list[str]:
+    parts: list[str] = []
+    for line in text.splitlines():
+        line = line.strip()
+        if not line:
+            continue
+        if _HEADING_HINT.match(line):
+            parts.append(line)
+            continue
+        for sent in _SENTENCE_SPLIT.split(line):
+            sent = sent.strip()
+            if sent:
+                parts.append(sent)
+    return parts
+
+
+def chunk_document(doc: LoadedDoc, doc_id: str) -> list[Chunk]:
+    target = CHUNKING_CONFIG["target_chunk_chars"]
+    max_chars = CHUNKING_CONFIG["max_chunk_chars"]
+    min_chars = CHUNKING_CONFIG["min_chunk_chars"]
+    overlap = CHUNKING_CONFIG["overlap_chars"]
+
+    units: list[tuple[int, str]] = []
+    for page in doc.pages:
+        for sent in _split_sentences(page.text):
+            units.append((page.page_number, sent))
+
+    chunks: list[Chunk] = []
+    buf: list[tuple[int, str]] = []
+    buf_len = 0
+
+    def flush() -> None:
+        nonlocal buf, buf_len
+        if not buf:
+            return
+        text = " ".join(s for _, s in buf).strip()
+        if len(text) < min_chars and chunks:
+            chunks[-1].text = (chunks[-1].text + " " + text).strip()
+            chunks[-1].page_end = buf[-1][0]
+            buf, buf_len = [], 0
+            return
+        chunk = Chunk(
+            chunk_id=f"{doc_id}::c{len(chunks):04d}",
+            doc_id=doc_id,
+            source_path=doc.source_path,
+            title=doc.title,
+            page_start=buf[0][0],
+            page_end=buf[-1][0],
+            text=text,
+        )
+        chunks.append(chunk)
+        carry: list[tuple[int, str]] = []
+        carry_len = 0
+        for pn, s in reversed(buf):
+            if carry_len + len(s) + 1 > overlap:
+                break
+            carry.insert(0, (pn, s))
+            carry_len += len(s) + 1
+        buf = carry
+        buf_len = sum(len(s) + 1 for _, s in buf)
+
+    for pn, sent in units:
+        is_heading = bool(_HEADING_HINT.match(sent))
+        if is_heading and buf_len >= min_chars:
+            flush()
+        buf.append((pn, sent))
+        buf_len += len(sent) + 1
+        if buf_len >= target:
+            flush()
+        elif buf_len >= max_chars:  # unreachable while target_chunk_chars < max_chunk_chars
+            flush()
+    flush()
+    return chunks
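Note on the chunking strategy: the module docstring describes greedy grouping of sentences up to a target size with a sentence-level overlap carried into the next chunk. The sketch below shows only that core idea on plain strings. It deliberately leaves out page tracking, heading hints, and the `min_chunk_chars` merge, and the name `greedy_chunks` and the small `target`/`overlap` values are assumptions chosen so the behaviour is visible on a short input.

```python
from __future__ import annotations


def greedy_chunks(sentences: list[str], target: int = 45, overlap: int = 25) -> list[str]:
    """Group sentences until ~target chars, carrying trailing sentences forward."""
    chunks: list[str] = []
    buf: list[str] = []
    fresh = 0  # sentences in buf that have not been emitted in any chunk yet
    for sent in sentences:
        buf.append(sent)
        fresh += 1
        if sum(len(s) + 1 for s in buf) >= target:
            chunks.append(" ".join(buf))
            # Keep the longest suffix of whole sentences that fits in `overlap`.
            carry: list[str] = []
            size = 0
            for s in reversed(buf):
                if size + len(s) + 1 > overlap:
                    break
                carry.insert(0, s)
                size += len(s) + 1
            buf, fresh = carry, 0
    if fresh:  # emit the tail only if it holds sentences not yet seen
        chunks.append(" ".join(buf))
    return chunks


sents = [
    "Alpha beta gamma delta.",
    "Epsilon zeta eta theta.",
    "Iota kappa lambda mu.",
    "Nu xi omicron pi.",
]
chunks = greedy_chunks(sents)
```

Each chunk after the first starts with the last sentence of the previous one, so a retrieved chunk still carries the sentence that led into it.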
ingestion/embedder.py ADDED
@@ -0,0 +1,32 @@
+"""Dense embeddings via sentence-transformers."""
+from __future__ import annotations
+
+from functools import lru_cache
+
+from sentence_transformers import SentenceTransformer
+
+from config import EMBEDDING_CONFIG
+
+
+@lru_cache(maxsize=1)
+def get_embedder() -> SentenceTransformer:
+    return SentenceTransformer(
+        EMBEDDING_CONFIG["model"],
+        device=EMBEDDING_CONFIG["device"],
+    )
+
+
+def embed_texts(texts: list[str]) -> list[list[float]]:
+    model = get_embedder()
+    vecs = model.encode(
+        texts,
+        batch_size=EMBEDDING_CONFIG["batch_size"],
+        convert_to_numpy=True,
+        normalize_embeddings=True,
+        show_progress_bar=False,
+    )
+    return vecs.tolist()
+
+
+def embed_query(text: str) -> list[float]:
+    return embed_texts([text])[0]
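Note on `normalize_embeddings=True`: because `embed_texts` returns unit-length vectors, the cosine similarity between two embeddings equals their plain dot product. The sketch below checks that identity on toy 2-d vectors in pure Python, so no model download is needed. The helper names `l2_normalize` and `cosine` are illustrative and are not part of the repository.

```python
import math


def l2_normalize(vec: list[float]) -> list[float]:
    """Scale a vector to unit L2 norm (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm else list(vec)


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity computed from the raw, unnormalized vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)


a, b = [3.0, 4.0], [4.0, 3.0]
ua, ub = l2_normalize(a), l2_normalize(b)
# For unit vectors the dot product alone is the cosine similarity.
dot_unit = sum(x * y for x, y in zip(ua, ub))
```

Normalizing once at index time moves the norm computation out of the query path, which is one reason to store unit vectors when the collection uses a cosine metric.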
ingestion/indexer.py ADDED
@@ -0,0 +1,134 @@
+"""ChromaDB management + BM25 corpus persistence."""
+from __future__ import annotations
+
+import json
+import os
+import pickle
+import re
+
+os.environ.setdefault("ANONYMIZED_TELEMETRY", "False")
+import chromadb
+from chromadb.config import Settings  # noqa: E402
+
+from config import CHROMA_COLLECTION, PATHS
+from ingestion.chunker import Chunk
+from ingestion.embedder import embed_texts
+
+
+def _ensure_dirs() -> None:
+    PATHS["chroma_dir"].mkdir(parents=True, exist_ok=True)
+    PATHS["bm25_path"].parent.mkdir(parents=True, exist_ok=True)
+
+
+def _client() -> chromadb.PersistentClient:
+    return chromadb.PersistentClient(
+        path=str(PATHS["chroma_dir"]),
+        settings=Settings(anonymized_telemetry=False),
+    )
+
+
+def get_chroma_collection():
+    _ensure_dirs()
+    client = _client()
+    return client.get_or_create_collection(
+        name=CHROMA_COLLECTION,
+        metadata={"hnsw:space": "cosine"},
+    )
+
+
+def reset_index() -> None:
+    _ensure_dirs()
+    client = _client()
+    try:
+        client.delete_collection(CHROMA_COLLECTION)
+    except Exception:
+        pass
+    if PATHS["bm25_path"].exists():
+        PATHS["bm25_path"].unlink()
+    if PATHS["manifest_path"].exists():
+        PATHS["manifest_path"].unlink()
+
+
+_TOKEN_RE = re.compile(r"[a-zA-Z0-9]+")
+
+
+def tokenize(text: str) -> list[str]:
+    return [t.lower() for t in _TOKEN_RE.findall(text)]
+
+
+def index_chunks(chunks: list[Chunk], reset: bool = False) -> dict:
+    _ensure_dirs()
+    if reset:
+        reset_index()
+
+    coll = get_chroma_collection()
+
+    ids = [c.chunk_id for c in chunks]
+    docs = [c.text for c in chunks]
+    metas = [
+        {
+            "doc_id": c.doc_id,
+            "source_path": c.source_path,
+            "title": c.title,
+            "page_start": c.page_start,
+            "page_end": c.page_end,
+        }
+        for c in chunks
+    ]
+
+    print(f" Embedding {len(docs)} chunks...")
+    embeddings = embed_texts(docs)
+
+    print(f" Writing to ChromaDB ({CHROMA_COLLECTION})...")
+    BATCH = 256
+    for i in range(0, len(ids), BATCH):
+        coll.upsert(
+            ids=ids[i : i + BATCH],
+            documents=docs[i : i + BATCH],
+            metadatas=metas[i : i + BATCH],
+            embeddings=embeddings[i : i + BATCH],
+        )
+
+    print(" Building BM25 corpus...")
+    tokenized = [tokenize(d) for d in docs]
+    with open(PATHS["bm25_path"], "wb") as f:
+        pickle.dump(
+            {"ids": ids, "tokenized": tokenized, "metas": metas, "docs": docs},
+            f,
+        )
+
+    manifest = {
+        "n_chunks": len(ids),
+        "chunks_per_doc": _group_count([c.doc_id for c in chunks]),
+    }
+    with open(PATHS["manifest_path"], "w") as f:
+        json.dump(manifest, f, indent=2)
+    return manifest
+
+
+def _group_count(items: list[str]) -> dict:
+    out: dict = {}
+    for x in items:
+        out[x] = out.get(x, 0) + 1
+    return out
+
+
+def fetch_embeddings(chunk_ids: list[str]) -> dict[str, list[float]]:
+    """Pull stored embeddings for a list of chunk ids (used for visualization)."""
+    if not chunk_ids:
+        return {}
+    coll = get_chroma_collection()
+    res = coll.get(ids=list(chunk_ids), include=["embeddings"])
+    out: dict[str, list[float]] = {}
+    for cid, vec in zip(res["ids"], res["embeddings"]):
+        out[cid] = list(vec)
+    return out
+
+
+def load_bm25_corpus() -> dict:
+    if not PATHS["bm25_path"].exists():
+        raise FileNotFoundError(
+            f"BM25 corpus not found at {PATHS['bm25_path']}. Run ingestion first."
+        )
+    with open(PATHS["bm25_path"], "rb") as f:
+        return pickle.load(f)
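Note on the BM25 corpus: the indexer only persists tokenized chunks, and the scoring itself happens elsewhere in the repository and is not shown in this section. As a rough illustration of what a scorer does with that token format, here is a minimal Okapi BM25 sketch in pure Python. The function name `bm25_scores` and the parameters `k1=1.5` and `b=0.75` are assumptions (common textbook defaults), not values taken from the project, and `tokenize` is repeated from the indexer so the block runs on its own.

```python
import math
import re
from collections import Counter

_TOKEN_RE = re.compile(r"[a-zA-Z0-9]+")


def tokenize(text: str) -> list[str]:
    """Same rule as the indexer: lowercase alphanumeric runs."""
    return [t.lower() for t in _TOKEN_RE.findall(text)]


def bm25_scores(query: str, corpus: list[list[str]],
                k1: float = 1.5, b: float = 0.75) -> list[float]:
    """Score every tokenized document against the query with Okapi BM25."""
    n_docs = len(corpus)
    avgdl = sum(len(d) for d in corpus) / max(n_docs, 1)
    # Document frequency: how many documents contain each term.
    df = Counter(term for doc in corpus for term in set(doc))
    scores: list[float] = []
    for doc in corpus:
        tf = Counter(doc)
        score = 0.0
        for term in tokenize(query):
            if term not in tf:
                continue
            idf = math.log((n_docs - df[term] + 0.5) / (df[term] + 0.5) + 1.0)
            denom = tf[term] + k1 * (1 - b + b * len(doc) / avgdl)
            score += idf * tf[term] * (k1 + 1) / denom
        scores.append(score)
    return scores


docs = [
    "Self-RAG uses reflection tokens.",
    "DDIM sampling is non-Markovian.",
    "BM25 ranks keyword matches.",
]
corpus = [tokenize(d) for d in docs]
scores = bm25_scores("reflection tokens", corpus)
```

Only the first document contains either query term, so it is the only one with a non-zero score. Exact-term matching like this is what the sparse side contributes alongside dense retrieval.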
ingestion/loader.py ADDED
@@ -0,0 +1,87 @@
+"""PDF loader. Returns per-page text + structural metadata."""
+from __future__ import annotations
+
+import re
+from dataclasses import dataclass, field
+from pathlib import Path
+
+import pymupdf
+
+
+@dataclass
+class PageText:
+    page_number: int
+    text: str
+
+
+@dataclass
+class LoadedDoc:
+    source_path: str
+    title: str
+    pages: list[PageText] = field(default_factory=list)
+
+    @property
+    def full_text(self) -> str:
+        return "\n\n".join(p.text for p in self.pages)
+
+
+_LIGATURES = {
+    "\ufb00": "ff", "\ufb01": "fi", "\ufb02": "fl",
+    "\ufb03": "ffi", "\ufb04": "ffl",
+}
+
+
+def _clean(text: str) -> str:
+    for k, v in _LIGATURES.items():
+        text = text.replace(k, v)
+    text = re.sub(r"-\n(?=\w)", "", text)
+    text = re.sub(r"[ \t]+", " ", text)
+    text = re.sub(r"\n{3,}", "\n\n", text)
+    return text.strip()
+
+
+_TITLE_OVERRIDES = {
+    "01_attention_is_all_you_need": "Attention Is All You Need (Vaswani et al., 2017)",
+    "02_bert": "BERT: Pre-training of Deep Bidirectional Transformers (Devlin et al., 2018)",
+    "03_gpt3": "Language Models are Few-Shot Learners (GPT-3, Brown et al., 2020)",
+    "04_ddpm": "Denoising Diffusion Probabilistic Models (Ho et al., 2020)",
+    "05_ddim": "Denoising Diffusion Implicit Models (Song et al., 2020)",
+    "06_rag_original": "Retrieval-Augmented Generation for Knowledge-Intensive NLP (Lewis et al., 2020)",
+    "07_rag_survey": "Retrieval-Augmented Generation for LLMs: A Survey (Gao et al., 2023)",
+    "08_self_rag": "Self-RAG: Learning to Retrieve, Generate, and Critique (Asai et al., 2023)",
+    "09_hyde": "Precise Zero-Shot Dense Retrieval with HyDE (Gao et al., 2022)",
+    "10_vit": "An Image is Worth 16x16 Words (Vision Transformer, Dosovitskiy et al., 2020)",
+    "11_clip": "Learning Transferable Visual Models from Natural Language Supervision (CLIP, Radford et al., 2021)",
+    "12_react": "ReAct: Synergizing Reasoning and Acting in Language Models (Yao et al., 2022)",
+    "13_chain_of_thought": "Chain-of-Thought Prompting Elicits Reasoning (Wei et al., 2022)",
+    "14_llm_survey": "A Survey of Large Language Models (Zhao et al., 2023)",
+}
+
+
+def load_pdf(path: str | Path) -> LoadedDoc:
+    path = Path(path)
+    doc = pymupdf.open(path)
+    pages: list[PageText] = []
+    for i, page in enumerate(doc, start=1):
+        raw = page.get_text("text")
+        cleaned = _clean(raw)
+        if cleaned:
+            pages.append(PageText(page_number=i, text=cleaned))
+    title = _TITLE_OVERRIDES.get(path.stem) or _guess_title(pages, fallback=path.stem)
+    doc.close()
+    return LoadedDoc(source_path=str(path), title=title, pages=pages)
+
+
+def _guess_title(pages: list[PageText], fallback: str) -> str:
+    if not pages:
+        return fallback
+    first = pages[0].text
+    for line in first.splitlines():
+        line = line.strip()
+        if 10 < len(line) < 180 and not line.lower().startswith(("abstract", "arxiv:")):
+            return line
+    return fallback
+
+
+def discover_pdfs(papers_dir: str | Path) -> list[Path]:
+    return sorted(Path(papers_dir).glob("*.pdf"))
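Note on the text cleanup: the sketch below replays the steps of `_clean` on a small sample so their effect is visible. It assumes the `_LIGATURES` keys are meant to be the Unicode ligature code points U+FB00 to U+FB04, which PDF extraction often emits in place of "ff", "fi", and so on. The names `LIGATURES` and `clean` and the sample string are illustrative only.

```python
import re

# Unicode ligature code points mapped to their ASCII expansions (assumed keys).
LIGATURES = {
    "\ufb00": "ff", "\ufb01": "fi", "\ufb02": "fl",
    "\ufb03": "ffi", "\ufb04": "ffl",
}


def clean(text: str) -> str:
    for lig, plain in LIGATURES.items():
        text = text.replace(lig, plain)
    text = re.sub(r"-\n(?=\w)", "", text)   # rejoin words hyphenated across lines
    text = re.sub(r"[ \t]+", " ", text)     # collapse runs of spaces and tabs
    text = re.sub(r"\n{3,}", "\n\n", text)  # cap consecutive newlines at two
    return text.strip()


sample = "E\ufb03cient   retrie-\nval of \ufb01ne-tuned models\n\n\n\nNext."
cleaned = clean(sample)
```

The hyphen in "fine-tuned" survives because only a hyphen immediately followed by a newline is treated as a line-break artifact.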
llm/__init__.py ADDED
File without changes
llm/client_factory.py ADDED
@@ -0,0 +1,16 @@
+"""Return the right LLM client based on environment.
+
+Local (Ollama running) → OllamaClient
+Hosted (GROQ_API_KEY set) → GroqClient
+"""
+from __future__ import annotations
+
+import os
+
+
+def get_llm():
+    if os.environ.get("GROQ_API_KEY"):
+        from llm.groq_client import GroqClient
+        return GroqClient()
+    from llm.ollama_client import OllamaClient
+    return OllamaClient()
llm/groq_client.py ADDED
@@ -0,0 +1,133 @@
+"""Groq API client — drop-in replacement for OllamaClient when running hosted.
+
+Free tier: 14,400 requests/day, ~500 req/min.
+Text model : llama-3.1-8b-instant (fast, cheap)
+Vision model: llama-3.2-11b-vision-preview (images)
+"""
+from __future__ import annotations
+
+import base64
+import json
+import os
+from pathlib import Path
+from typing import Any
+
+import requests
+
+_TEXT_MODEL = "llama-3.1-8b-instant"
+_VISION_MODEL = "llama-3.2-11b-vision-preview"
+_BASE_URL = "https://api.groq.com/openai/v1"
+
+
+class GroqClient:
+    def __init__(self) -> None:
+        self.api_key = os.environ.get("GROQ_API_KEY", "")
+        self.temperature = 0.1
+        self.timeout = 60
+
+    def _headers(self) -> dict[str, str]:
+        return {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+    def _chat(
+        self,
+        messages: list[dict],
+        model: str,
+        temperature: float,
+        json_mode: bool = False,
+    ) -> str:
+        payload: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "temperature": temperature,
+        }
+        if json_mode:
+            payload["response_format"] = {"type": "json_object"}
+        r = requests.post(
+            f"{_BASE_URL}/chat/completions",
+            headers=self._headers(),
+            json=payload,
+            timeout=self.timeout,
+        )
+        r.raise_for_status()
+        return r.json()["choices"][0]["message"]["content"].strip()
+
+    def _build_messages(
+        self,
+        prompt: str,
+        system: str | None,
+        images: list[str] | None,
+    ) -> tuple[list[dict], str]:
+        msgs: list[dict] = []
+        if system:
+            msgs.append({"role": "system", "content": system})
+        if images:
+            content: list[dict] = [{"type": "text", "text": prompt}]
+            for img_path in images:
+                b64 = self._encode_image(img_path)
+                suffix = Path(img_path).suffix.lower().lstrip(".")
+                mime = {"jpg": "image/jpeg", "jpeg": "image/jpeg",
+                        "png": "image/png", "webp": "image/webp"}.get(suffix, "image/png")
+                content.append({"type": "image_url",
+                                "image_url": {"url": f"data:{mime};base64,{b64}"}})
+            msgs.append({"role": "user", "content": content})
+            return msgs, _VISION_MODEL
+        msgs.append({"role": "user", "content": prompt})
+        return msgs, _TEXT_MODEL
+
+    def generate(
+        self,
+        prompt: str,
+        system: str | None = None,
+        temperature: float | None = None,
+        images: list[str] | None = None,
+        format: str | None = None,
+    ) -> str:
+        msgs, model = self._build_messages(prompt, system, images)
+        return self._chat(
+            msgs, model,
+            temperature if temperature is not None else self.temperature,
+            json_mode=(format == "json"),
+        )
+
+    def generate_json(
+        self,
+        prompt: str,
+        system: str | None = None,
+        temperature: float | None = None,
+        images: list[str] | None = None,
+    ) -> dict[str, Any]:
+        text = self.generate(prompt=prompt, system=system,
+                             temperature=temperature, images=images, format="json")
+        return _safe_json(text)
+
+    @staticmethod
+    def _encode_image(path: str | Path) -> str:
+        with open(path, "rb") as f:
+            return base64.b64encode(f.read()).decode("utf-8")
110
+
111
+ def health(self) -> bool:
112
+ if not self.api_key:
113
+ return False
114
+ try:
115
+ r = requests.get(f"{_BASE_URL}/models", headers=self._headers(), timeout=5)
116
+ return r.status_code == 200
117
+ except requests.RequestException:
118
+ return False
119
+
120
+
121
+ def _safe_json(text: str) -> dict[str, Any]:
122
+ text = text.strip()
123
+ try:
124
+ return json.loads(text)
125
+ except json.JSONDecodeError:
126
+ pass
127
+ start, end = text.find("{"), text.rfind("}")
128
+ if start != -1 and end > start:
129
+ try:
130
+ return json.loads(text[start:end + 1])
131
+ except json.JSONDecodeError:
132
+ pass
133
+ return {"_raw": text}
llm/ollama_client.py ADDED
@@ -0,0 +1,116 @@
+ """Thin Ollama HTTP client. Handles text + multimodal (image) requests."""
+ from __future__ import annotations
+
+ import base64
+ import json
+ from pathlib import Path
+ from typing import Any
+
+ import requests
+
+ from config import LLM_CONFIG
+
+
+ class OllamaClient:
+     def __init__(
+         self,
+         model: str | None = None,
+         base_url: str | None = None,
+         temperature: float | None = None,
+         timeout: int | None = None,
+         num_ctx: int | None = None,
+     ) -> None:
+         self.model = model or LLM_CONFIG["model"]
+         self.base_url = (base_url or LLM_CONFIG["base_url"]).rstrip("/")
+         self.temperature = temperature if temperature is not None else LLM_CONFIG["temperature"]
+         self.timeout = timeout or LLM_CONFIG["timeout"]
+         self.num_ctx = num_ctx or LLM_CONFIG["num_ctx"]
+
+     def _options(self, **overrides: Any) -> dict[str, Any]:
+         opts = {"temperature": self.temperature, "num_ctx": self.num_ctx}
+         opts.update({k: v for k, v in overrides.items() if v is not None})
+         return opts
+
+     def generate(
+         self,
+         prompt: str,
+         system: str | None = None,
+         temperature: float | None = None,
+         images: list[str] | None = None,
+         format: str | None = None,
+     ) -> str:
+         payload: dict[str, Any] = {
+             "model": self.model,
+             "prompt": prompt,
+             "stream": False,
+             "options": self._options(temperature=temperature),
+         }
+         if system:
+             payload["system"] = system
+         if images:
+             payload["images"] = [self._encode_image(p) for p in images]
+         if format:
+             payload["format"] = format
+         r = requests.post(
+             f"{self.base_url}/api/generate",
+             json=payload,
+             timeout=self.timeout,
+         )
+         r.raise_for_status()
+         return r.json().get("response", "").strip()
+
+     def generate_json(
+         self,
+         prompt: str,
+         system: str | None = None,
+         temperature: float | None = None,
+         images: list[str] | None = None,
+     ) -> dict[str, Any]:
+         """Ask the model for a JSON object and parse it. Falls back to extraction
+         if the model wraps JSON in prose or fences."""
+         text = self.generate(
+             prompt=prompt,
+             system=system,
+             temperature=temperature,
+             images=images,
+             format="json",
+         )
+         return _safe_json_loads(text)
+
+     @staticmethod
+     def _encode_image(path: str | Path) -> str:
+         with open(path, "rb") as f:
+             return base64.b64encode(f.read()).decode("utf-8")
+
+     def health(self) -> bool:
+         try:
+             r = requests.get(f"{self.base_url}/api/tags", timeout=5)
+             return r.status_code == 200
+         except requests.RequestException:
+             return False
+
+
+ def _safe_json_loads(text: str) -> dict[str, Any]:
+     text = text.strip()
+     try:
+         return json.loads(text)
+     except json.JSONDecodeError:
+         pass
+     if "```" in text:
+         parts = text.split("```")
+         for part in parts:
+             stripped = part.strip()
+             if stripped.startswith("json"):
+                 stripped = stripped[4:].strip()
+             try:
+                 return json.loads(stripped)
+             except json.JSONDecodeError:
+                 continue
+     start = text.find("{")
+     end = text.rfind("}")
+     if start != -1 and end != -1 and end > start:
+         try:
+             return json.loads(text[start : end + 1])
+         except json.JSONDecodeError:
+             pass
+     return {"_raw": text}
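The fallback chain described in the `generate_json` docstring (strict parse, then fenced blocks, then the outermost brace slice) can be exercised in isolation. What follows is a simplified sketch rather than the committed helper: `parse_loose_json` is a hypothetical name, and unlike `_safe_json_loads` it only tries the odd-indexed pieces of the split, which are the fenced bodies.

```python
from __future__ import annotations

import json
from typing import Any

FENCE = "`" * 3  # the three-backtick code-fence marker, built indirectly


def parse_loose_json(text: str) -> Any:
    """Try strict parsing, then fenced bodies, then the outermost {...} slice."""
    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # Splitting on the fence marker puts fenced bodies at odd indices.
    for part in text.split(FENCE)[1::2]:
        body = part.strip()
        if body.startswith("json"):  # drop a leading language tag
            body = body[4:].strip()
        try:
            return json.loads(body)
        except json.JSONDecodeError:
            continue
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        try:
            return json.loads(text[start : end + 1])
        except json.JSONDecodeError:
            pass
    # Nothing parsed: hand the raw text back under a sentinel key.
    return {"_raw": text}


if __name__ == "__main__":
    fenced = "Sure:\n" + FENCE + 'json\n{"a": 1}\n' + FENCE + "\nDone"
    print(parse_loose_json(fenced))
    print(parse_loose_json('The answer is {"a": 2} as requested.'))
    print(parse_loose_json("no json here"))
```

Returning `{"_raw": text}` instead of raising lets callers keep going and inspect what the model actually said, at the cost of having to check for the `_raw` key.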
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ chromadb==0.5.23
+ sentence-transformers==3.3.1
+ rank-bm25==0.2.2
+ pymupdf==1.25.1
+ requests==2.32.3
+ numpy==1.26.4
+ streamlit==1.41.1
+ tqdm==4.67.1
+ pillow==11.0.0
retrieval/__init__.py ADDED
File without changes
retrieval/dense.py ADDED
@@ -0,0 +1,39 @@
+ """Dense (vector) retrieval over ChromaDB."""
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Any
+
+ from config import RETRIEVAL_CONFIG
+ from ingestion.embedder import embed_query
+ from ingestion.indexer import get_chroma_collection
+
+
+ @dataclass
+ class Hit:
+     chunk_id: str
+     text: str
+     metadata: dict
+     score: float
+     rank: int
+
+
+ def dense_search(query: str, k: int | None = None) -> list[Hit]:
+     k = k or RETRIEVAL_CONFIG["dense_k"]
+     coll = get_chroma_collection()
+     qv = embed_query(query)
+     res = coll.query(
+         query_embeddings=[qv],
+         n_results=k,
+         include=["documents", "metadatas", "distances"],
+     )
+     ids = res["ids"][0]
+     docs = res["documents"][0]
+     metas = res["metadatas"][0]
+     dists = res["distances"][0]
+     hits: list[Hit] = []
+     for r, (cid, doc, meta, dist) in enumerate(zip(ids, docs, metas, dists)):
+         # cosine distance → similarity in [0,1]; chroma returns 1 - cosine_sim for cosine space
+         score = max(0.0, 1.0 - float(dist))
+         hits.append(Hit(chunk_id=cid, text=doc, metadata=dict(meta), score=score, rank=r))
+     return hits
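The score conversion in `dense_search` relies on the collection using cosine space, where the reported distance is `1 - cosine_similarity` and so ranges over [0, 2]. The sketch below illustrates why the `max(0.0, ...)` clamp matters; it is pure Python and does not call Chroma, and `cosine_similarity` and `score_from_distance` are hypothetical helper names.

```python
from __future__ import annotations

import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Plain cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def score_from_distance(dist: float) -> float:
    """Map a cosine distance in [0, 2] to a score in [0, 1], as dense_search does."""
    # Distances above 1 mean negative similarity; the clamp pins those to 0.
    return max(0.0, 1.0 - dist)


if __name__ == "__main__":
    query = [1.0, 0.0]
    for doc in ([1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]):
        dist = 1.0 - cosine_similarity(query, doc)
        print(doc, dist, score_from_distance(dist))
```

With the clamp, an orthogonal vector and an opposite vector both score 0, so the score reads as "how similar" but no longer distinguishes unrelated from opposed.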
retrieval/hybrid.py ADDED
@@ -0,0 +1,47 @@
+ """Reciprocal Rank Fusion of dense + sparse hits."""
+ from __future__ import annotations
+
+ from config import RETRIEVAL_CONFIG
+ from retrieval.dense import Hit, dense_search
+ from retrieval.sparse import sparse_search
+
+
+ def reciprocal_rank_fusion(
+     rankings: list[list[Hit]],
+     k_const: int | None = None,
+     top_k: int | None = None,
+ ) -> list[Hit]:
+     k_const = k_const or RETRIEVAL_CONFIG["rrf_k"]
+     fused: dict[str, dict] = {}
+     for ranking in rankings:
+         for r, hit in enumerate(ranking):
+             entry = fused.setdefault(
+                 hit.chunk_id,
+                 {"hit": hit, "score": 0.0},
+             )
+             entry["score"] += 1.0 / (k_const + r + 1)
+             # prefer higher-ranked instance for the canonical hit object
+             if hit.rank < entry["hit"].rank:
+                 entry["hit"] = hit
+     merged = sorted(fused.values(), key=lambda x: x["score"], reverse=True)
+     out: list[Hit] = []
+     for r, entry in enumerate(merged):
+         h = entry["hit"]
+         out.append(
+             Hit(
+                 chunk_id=h.chunk_id,
+                 text=h.text,
+                 metadata=h.metadata,
+                 score=entry["score"],
+                 rank=r,
+             )
+         )
+     if top_k:
+         out = out[:top_k]
+     return out
+
+
+ def hybrid_search(query: str, top_k: int | None = None) -> list[Hit]:
+     dense = dense_search(query)
+     sparse = sparse_search(query)
+     return reciprocal_rank_fusion([dense, sparse], top_k=top_k)
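A small worked example may make the fusion formula concrete. The sketch below applies the same per-list contribution, `1 / (k_const + rank + 1)` with zero-based ranks, to bare chunk ids instead of `Hit` objects. The function name `rrf` and the constant `k_const=60` are assumptions for illustration; the committed code reads its constant from `RETRIEVAL_CONFIG["rrf_k"]`, whose value is not shown in this diff.

```python
from __future__ import annotations

from collections import defaultdict


def rrf(rankings: list[list[str]], k_const: int = 60) -> list[tuple[str, float]]:
    """Fuse ranked id lists; each list adds 1 / (k_const + rank + 1) per id."""
    scores: dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking):  # rank is zero-based
            scores[doc_id] += 1.0 / (k_const + rank + 1)
    # Highest fused score first.
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)


if __name__ == "__main__":
    dense = ["a", "b", "c"]   # hypothetical dense ranking
    sparse = ["b", "d", "a"]  # hypothetical BM25 ranking
    for doc_id, score in rrf([dense, sparse]):
        print(doc_id, round(score, 6))
```

Here `b` (ranks 2 and 1) edges out `a` (ranks 1 and 3), and both beat `d` and `c`, which appear in only one list. Because only ranks enter the formula, the incompatible score scales of cosine similarity and BM25 never need to be reconciled.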
retrieval/pipeline.py ADDED
@@ -0,0 +1,12 @@
+ """End-to-end retrieval: hybrid search + cross-encoder reranking."""
+ from __future__ import annotations
+
+ from config import RETRIEVAL_CONFIG
+ from retrieval.dense import Hit
+ from retrieval.hybrid import hybrid_search
+ from retrieval.reranker import rerank
+
+
+ def hybrid_retrieve(query: str, top_n: int | None = None) -> list[Hit]:
+     fused = hybrid_search(query, top_k=max(RETRIEVAL_CONFIG["dense_k"], RETRIEVAL_CONFIG["sparse_k"]))
+     return rerank(query, fused, top_n=top_n)
retrieval/reranker.py ADDED
@@ -0,0 +1,36 @@
+ """Cross-encoder reranker: deep relevance scoring on top of fused hits."""
+ from __future__ import annotations
+
+ from functools import lru_cache
+
+ from sentence_transformers import CrossEncoder
+
+ from config import RERANKER_CONFIG, RETRIEVAL_CONFIG
+ from retrieval.dense import Hit
+
+
+ @lru_cache(maxsize=1)
+ def get_reranker() -> CrossEncoder:
+     return CrossEncoder(RERANKER_CONFIG["model"], device=RERANKER_CONFIG["device"])
+
+
+ def rerank(query: str, hits: list[Hit], top_n: int | None = None) -> list[Hit]:
+     if not hits:
+         return []
+     top_n = top_n or RETRIEVAL_CONFIG["rerank_top_n"]
+     model = get_reranker()
+     pairs = [(query, h.text) for h in hits]
+     scores = model.predict(pairs, show_progress_bar=False)
+     ranked = sorted(zip(hits, scores), key=lambda x: float(x[1]), reverse=True)[:top_n]
+     out: list[Hit] = []
+     for r, (h, s) in enumerate(ranked):
+         out.append(
+             Hit(
+                 chunk_id=h.chunk_id,
+                 text=h.text,
+                 metadata=h.metadata,
+                 score=float(s),
+                 rank=r,
+             )
+         )
+     return out
retrieval/sparse.py ADDED
@@ -0,0 +1,44 @@
+ """Sparse retrieval via BM25Okapi over the persisted token corpus."""
+ from __future__ import annotations
+
+ from functools import lru_cache
+
+ from rank_bm25 import BM25Okapi
+
+ from config import RETRIEVAL_CONFIG
+ from ingestion.indexer import load_bm25_corpus, tokenize
+ from retrieval.dense import Hit
+
+
+ @lru_cache(maxsize=1)
+ def _bm25_state():
+     corpus = load_bm25_corpus()
+     bm25 = BM25Okapi(corpus["tokenized"])
+     return bm25, corpus
+
+
+ def sparse_search(query: str, k: int | None = None) -> list[Hit]:
+     k = k or RETRIEVAL_CONFIG["sparse_k"]
+     bm25, corpus = _bm25_state()
+     tokens = tokenize(query)
+     if not tokens:
+         return []
+     scores = bm25.get_scores(tokens)
+     idx_sorted = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
+     max_score = float(scores[idx_sorted[0]]) if idx_sorted else 0.0
+     hits: list[Hit] = []
+     for r, i in enumerate(idx_sorted):
+         s = float(scores[i])
+         if s <= 0:
+             continue
+         norm = s / max_score if max_score > 0 else 0.0
+         hits.append(
+             Hit(
+                 chunk_id=corpus["ids"][i],
+                 text=corpus["docs"][i],
+                 metadata=dict(corpus["metas"][i]),
+                 score=norm,
+                 rank=r,
+             )
+         )
+     return hits
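The post-processing that `sparse_search` applies to raw BM25 scores (take the top k, drop non-positive scores, divide by the best score) can be looked at on its own. The sketch below uses made-up score values and a hypothetical helper name, `top_k_normalized`; it does not depend on `rank_bm25`.

```python
from __future__ import annotations


def top_k_normalized(scores: list[float], k: int) -> list[tuple[int, float]]:
    """Return (index, score / best_score) for the top-k strictly positive scores."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    if not order or scores[order[0]] <= 0:
        return []  # nothing matched the query at all
    best = scores[order[0]]
    # Non-positive scores sort last, so filtering them keeps ranks contiguous.
    return [(i, scores[i] / best) for i in order if scores[i] > 0]


if __name__ == "__main__":
    raw = [0.0, 4.0, 2.0, 1.0]  # hypothetical BM25 scores for four chunks
    print(top_k_normalized(raw, k=3))
    print(top_k_normalized([0.0, 0.0], k=2))
```

Normalizing by the per-query maximum puts the best sparse hit at 1.0 for every query, which is convenient for display but means the scores are not comparable across queries. Fusion downstream uses ranks, so this does not affect the merged ordering.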
storage/bm25.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:01eb71c2d4f2c3f021e925e3c96a3131fd8406592c017daf8a405ecac1578ec6
+ size 5864776
storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/data_level0.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:657eee8d3193abf7c307965c3cb694d08ac955dcc9b6ed3139d1d60e746449b8
+ size 1676000
storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/header.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c2ca0f721e87b23fa44989e0cb59d71339f5a16cc05a8eb3c7777e658757e2e5
+ size 100
storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0a9b9b398fb0167ff163c08c65e2728ff53af2a5e421a2355c4ee1ecdc746e2c
+ size 36020
storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/length.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d51e5510e0fb53d9746f2be0e91a1f7cc74c25bdf21a7ca8096fccf6074521a3
+ size 4000
storage/chroma/bccd7ca5-4f87-4c9e-a569-6cf0dcdced21/link_lists.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:358ea02768d90865ee8a426535f44933635bc209dda6b9481c2b9f221f08b18b
+ size 8148
storage/chroma/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9526de2262c6da38b5d8ca3bb71f7902e912713b831020b7fb9dc22ba9d4d7fc
+ size 51924992
storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/data_level0.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:be8ace3c8f7cfb3c3d7fca6e733651bff7d18d34a51fbeac7096129dd7bb883d
+ size 1676000
storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/header.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c2ca0f721e87b23fa44989e0cb59d71339f5a16cc05a8eb3c7777e658757e2e5
+ size 100
storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:56cdf88e893d7b0e61a770d05eb551f743b72862a8666e733304527383803612
+ size 36020
storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/length.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:49b0ed9d141de28649f8bb3bcf395a50c5a960b5f003bf4b4964e5ac62fba885
+ size 4000
storage/chroma/d7228068-4c70-4b64-a819-d7dbd7d28b63/link_lists.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fa7615217e24a544a1ecbaaa207daea4e33fe9dca8b5967ca8c5a90b14a90782
+ size 8148
storage/manifest.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "n_chunks": 1934,
+   "chunks_per_doc": {
+     "01_attention_is_all_you_need": 34,
+     "02_bert": 60,
+     "03_gpt3": 237,
+     "04_ddpm": 48,
+     "05_ddim": 45,
+     "06_rag_original": 59,
+     "07_rag_survey": 95,
+     "08_self_rag": 88,
+     "09_hyde": 35,
+     "10_vit": 61,
+     "11_clip": 217,
+     "12_react": 95,
+     "13_chain_of_thought": 135,
+     "14_llm_survey": 725
+   }
+ }
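The manifest is small enough to sanity-check by hand. The sketch below embeds the manifest contents shown above as a string (in the repository it would be read from `storage/manifest.json`) and confirms that the per-document counts add up to `n_chunks`. The `summarize` helper is a hypothetical name introduced for illustration.

```python
from __future__ import annotations

import json

MANIFEST = """
{
  "n_chunks": 1934,
  "chunks_per_doc": {
    "01_attention_is_all_you_need": 34, "02_bert": 60, "03_gpt3": 237,
    "04_ddpm": 48, "05_ddim": 45, "06_rag_original": 59, "07_rag_survey": 95,
    "08_self_rag": 88, "09_hyde": 35, "10_vit": 61, "11_clip": 217,
    "12_react": 95, "13_chain_of_thought": 135, "14_llm_survey": 725
  }
}
"""


def summarize(manifest: dict) -> dict:
    """Return document count, summed chunk count, and the largest document."""
    per_doc = manifest["chunks_per_doc"]
    largest = max(per_doc, key=per_doc.get)
    return {
        "n_docs": len(per_doc),
        "total": sum(per_doc.values()),
        "largest": largest,
        "largest_share": per_doc[largest] / manifest["n_chunks"],
    }


if __name__ == "__main__":
    manifest = json.loads(MANIFEST)
    summary = summarize(manifest)
    # The summed per-document counts should match the declared total.
    assert summary["total"] == manifest["n_chunks"]
    print(summary)
```

One thing the summary makes visible is how uneven the index is: the LLM survey alone contributes 725 of the 1934 chunks, so retrieval results are likely to skew toward it unless the query is specific.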