Hemanth-05 committed on
Commit
47056dc
·
1 Parent(s): 3fb7184

RAG: wire chat + prompt + services module

Browse files
Files changed (4) hide show
  1. .gitignore +2 -0
  2. artifacts/prompt.poml +14 -0
  3. pages/chat.py +3 -3
  4. services/rag_engine.py +218 -0
.gitignore CHANGED
@@ -1,3 +1,5 @@
 
 
1
  __pycache__/
2
  *.pyc
3
  .DS_Store
 
1
+ .venv/
2
+ venv/
3
  __pycache__/
4
  *.pyc
5
  .DS_Store
artifacts/prompt.poml CHANGED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are a grounded assistant for a NotebookLM-style app.
2
+ Rules:
3
+ 1) Answer ONLY from the provided context.
4
+ 2) If the answer is not in context, say you could not find it in the uploaded sources.
5
+ 3) Cite supporting sources inline using [S1], [S2], etc.
6
+ 4) Keep the answer concise and factual.
7
+
8
+ Question:
9
+ {{question}}
10
+
11
+ Context:
12
+ {{context}}
13
+
14
+ Answer:
pages/chat.py CHANGED
@@ -3,7 +3,7 @@
3
  import uuid
4
  from datetime import datetime
5
  from state import UserData, Message, get_active_notebook
6
- from mock_data import get_mock_response
7
 
8
 
9
  FILE_TYPE_ICONS = {
@@ -67,8 +67,8 @@ def handle_chat_submit(message: str, state: UserData) -> tuple[UserData, list[di
67
  )
68
  nb.messages.append(user_msg)
69
 
70
- # Get mock response
71
- response = get_mock_response(message)
72
 
73
  # Add assistant message
74
  assistant_msg = Message(
 
3
  import uuid
4
  from datetime import datetime
5
  from state import UserData, Message, get_active_notebook
6
+ from services.rag_engine import rag_answer
7
 
8
 
9
  FILE_TYPE_ICONS = {
 
67
  )
68
  nb.messages.append(user_msg)
69
 
70
+ # Get actual response
71
+ response = rag_answer(message.strip(), nb.id)
72
 
73
  # Add assistant message
74
  assistant_msg = Message(
services/rag_engine.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Retrieval-only RAG engine for chat responses."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import os
7
+ import re
8
+ from pathlib import Path
9
+
10
+ from ingestion_engine.embedding_generator import generate_query
11
+ from persistence.vector_store import VectorStore
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ K_RETRIEVE = 40
16
+ K_FINAL = 8
17
+ ALPHA = 0.05
18
+ MAX_SNIPPET_CHARS = 280
19
+ GEN_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
20
+ MAX_NEW_TOKENS = 400
21
+ TEMPERATURE = 0.2
22
+ TIMEOUT_SEC = 45
23
+
24
+ PROMPT_FILE = Path(__file__).resolve().parent.parent / "artifacts" / "prompt.poml"
25
+ DEFAULT_PROMPT_TEMPLATE = (
26
+ "You are a grounded assistant for a NotebookLM-style app.\n"
27
+ "Rules:\n"
28
+ "1) Answer ONLY from the provided context.\n"
29
+ "2) If the answer is not in context, say you could not find it in the uploaded sources.\n"
30
+ "3) Cite supporting sources inline using [S1], [S2], etc.\n"
31
+ "4) Keep the answer concise and factual.\n\n"
32
+ "Question:\n{{question}}\n\n"
33
+ "Context:\n{{context}}\n\n"
34
+ "Answer:"
35
+ )
36
+
37
+
38
+ def _clean_text(text: str) -> str:
39
+ return " ".join((text or "").split())
40
+
41
+
42
+ def _tokenize_keywords(text: str) -> set[str]:
43
+ tokens = re.split(r"[^a-z0-9]+", (text or "").lower())
44
+ return {t for t in tokens if len(t) >= 3}
45
+
46
+
47
+ def _keyword_hit_count(query_keywords: set[str], chunk_text: str) -> int:
48
+ if not query_keywords:
49
+ return 0
50
+ chunk_tokens = _tokenize_keywords(chunk_text)
51
+ return len(query_keywords.intersection(chunk_tokens))
52
+
53
+
54
def _rerank_matches(query: str, matches: list[dict]) -> list[dict]:
    """Stage 2 rerank: pinecone score + ALPHA * lexical keyword hits."""
    query_keywords = _tokenize_keywords(query)

    def _score(match: dict) -> dict:
        # Base similarity from the vector store, defaulting missing/None to 0.
        base = float(match.get("score", 0.0) or 0.0)
        hits = _keyword_hit_count(query_keywords, match.get("text", ""))
        return {
            **match,
            "keyword_hit_count": hits,
            "combined_score": base + ALPHA * hits,
        }

    rescored = [_score(m) for m in matches]
    # Stable sort: ties keep the vector store's original ordering.
    ordered = sorted(rescored, key=lambda r: r.get("combined_score", 0.0), reverse=True)
    return ordered[:K_FINAL]
71
+
72
+
73
def _build_citations(matches: list[dict]) -> list[dict]:
    """Convert vector matches into the citation format used by pages/chat.py."""
    seen: set[tuple[str, int]] = set()
    citations: list[dict] = []

    for m in matches:
        src = m.get("source_filename", "Unknown source")
        chunk = int(m.get("chunk_index", 0) or 0)
        # Deduplicate on (source, chunk) so each citation appears once.
        if (src, chunk) in seen:
            continue
        seen.add((src, chunk))

        snippet = _clean_text(m.get("text", ""))
        if len(snippet) > MAX_SNIPPET_CHARS:
            snippet = snippet[:MAX_SNIPPET_CHARS].rstrip() + "..."

        citations.append({"source": src, "page": chunk, "text": snippet})

    return citations
99
+
100
+
101
def _build_content(matches: list[dict]) -> str:
    """Render reranked matches as a readable retrieval-only answer string."""
    if not matches:
        return (
            "I couldn't find relevant information in your uploaded sources for that question. "
            "Try rephrasing the question or adding more sources."
        )

    parts = ["Based on your uploaded sources, here are the most relevant passages:", ""]
    for rank, m in enumerate(matches, start=1):
        src = m.get("source_filename", "Unknown source")
        chunk = int(m.get("chunk_index", 0) or 0)
        base = float(m.get("score", 0.0) or 0.0)
        # combined_score falls back to the raw score for unreranked matches.
        merged = float(m.get("combined_score", base) or base)
        hit_n = int(m.get("keyword_hit_count", 0) or 0)

        body = _clean_text(m.get("text", ""))
        if len(body) > MAX_SNIPPET_CHARS:
            body = body[:MAX_SNIPPET_CHARS].rstrip() + "..."

        parts.append(
            f"{rank}. **{src}** (chunk {chunk}, pinecone: {base:.3f}, hits: {hit_n}, combined: {merged:.3f})"
        )
        parts.append(f" {body}")
        parts.append("")

    parts.append("This is a two-stage retrieval-only response (no LLM synthesis yet).")
    return "\n".join(parts)
127
+
128
+
129
def _build_prompt(question: str, reranked_matches: list[dict]) -> str:
    """Build a grounded prompt from top reranked chunks."""

    def _block(tag: int, m: dict) -> str:
        # One [S#]-labelled context block per chunk, matching the citation
        # markers the prompt template asks the model to use.
        src = m.get("source_filename", "Unknown source")
        chunk = int(m.get("chunk_index", 0) or 0)
        body = _clean_text(m.get("text", ""))
        return f"[S{tag}] source={src} chunk={chunk}\n{body}"

    context_text = "\n\n".join(
        _block(i, m) for i, m in enumerate(reranked_matches, start=1)
    )
    filled = _load_prompt_template().replace("{{question}}", question)
    return filled.replace("{{context}}", context_text)
144
+
145
+
146
def _load_prompt_template() -> str:
    """Load prompt template from artifacts/prompt.poml; fallback to default.

    Returns the file's text when it exists and is non-blank, otherwise
    DEFAULT_PROMPT_TEMPLATE. Only expected I/O/decoding failures are treated
    as "file unavailable" — a bare ``except Exception`` here would also hide
    real programming errors.
    """
    try:
        text = PROMPT_FILE.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        # Missing/unreadable template is normal in fresh checkouts.
        return DEFAULT_PROMPT_TEMPLATE
    return text if text.strip() else DEFAULT_PROMPT_TEMPLATE
155
+
156
+
157
def _generate_answer(question: str, context_chunks: list[dict]) -> str:
    """Generate a grounded response using Hugging Face Inference API."""
    # Imported lazily so the module stays importable until generation
    # is actually attempted.
    from huggingface_hub import InferenceClient

    hf_token = os.environ.get("HF_TOKEN")
    prompt_text = _build_prompt(question, context_chunks)
    client = InferenceClient(token=hf_token, timeout=TIMEOUT_SEC)

    raw = client.text_generation(
        prompt=prompt_text,
        model=GEN_MODEL,
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
        do_sample=True,
        return_full_text=False,
    )
    return (raw or "").strip()
174
+
175
+
176
def rag_answer(question: str, notebook_id: str) -> dict:
    """Return a retrieval-only answer object: {"content": str, "citations": list}."""
    query = (question or "").strip()
    if not query:
        return {"content": "Please enter a question.", "citations": []}

    try:
        vector = generate_query(query)

        # Stage 1: retrieve candidate pool
        raw_matches = VectorStore().query(
            query_vector=vector, namespace=notebook_id, top_k=K_RETRIEVE
        )
        candidates = [m for m in raw_matches if m.get("text")]
        if not candidates:
            return {
                "content": (
                    "I couldn't find relevant information in your uploaded sources for that question. "
                    "Try rephrasing the question or adding more sources."
                ),
                "citations": [],
            }

        # Stage 2: rerank and keep top K_FINAL
        top_matches = _rerank_matches(query, candidates)
        cites = _build_citations(top_matches)
        fallback = _build_content(top_matches)

        # LLM synthesis is best-effort: any failure (or an empty completion)
        # degrades to the retrieval-only summary instead of erroring out.
        try:
            llm_text = _generate_answer(query, top_matches)
        except Exception as e:
            logger.warning("Generation failed, falling back to retrieval-only content: %s", e)
            llm_text = ""

        return {
            "content": llm_text or fallback,
            "citations": cites,
        }

    except Exception as e:
        logger.error("RAG retrieval failed: %s", e)
        return {
            "content": f"I ran into an error while retrieving from sources: {e}",
            "citations": [],
        }