cmd0160 committed on
Commit 18ef2cd · 1 Parent(s): 8755a87

Updating package file structure

app.py CHANGED
@@ -1,351 +1,388 @@
import os
- import sys
- import subprocess
- import re

os.environ.setdefault("LANGCHAIN_TELEMETRY_ENABLED", "false")
os.environ.setdefault("LANGCHAIN_DISABLE_TELEMETRY", "true")
os.environ.setdefault("CHROMA_TELEMETRY_ENABLED", "false")

import streamlit as st
- from src.vectorstore import get_retriever
- from src.qa_chain import make_conversational_chain

- st.set_page_config(page_title="Abalone RAG Chatbot", page_icon="🐚")
-
- st.title("Abalone RAG Chatbot")
- st.write(
-     "Ask natural-language questions about abalone studies and data. "
-     "The app uses a local Chroma vectorstore and OpenAI to retrieve and answer."
)

- # ---------------- Sidebar ----------------
-
- st.sidebar.header("Model Settings")

- model_name = st.sidebar.selectbox(
-     "Model",
-     options=["gpt-3.5-turbo", "gpt-4"],
-     index=0,
- )

- st.sidebar.markdown("---")

- st.sidebar.header("Retrieval Configuration")

- top_k = st.sidebar.slider(
-     "Number of retrieved chunks (k)",
-     min_value=2,
-     max_value=10,
-     value=4,
- )

- retrieval_mode_label = st.sidebar.selectbox(
-     "Retrieval mode",
-     ["MMR (diverse)", "Similarity", "Hybrid (dense + MMR)"],
-     index=0,
- )

- retrieval_mode_map = {
-     "MMR (diverse)": "mmr",
-     "Similarity": "similarity",
-     "Hybrid (dense + MMR)": "hybrid",
- }
- retrieval_mode = retrieval_mode_map[retrieval_mode_label]

- st.sidebar.markdown("---")

- st.sidebar.header("Answer Style")

- temperature = st.sidebar.slider(
-     "Temperature",
-     min_value=0.0,
-     max_value=1.0,
-     value=0.2,
-     step=0.05,
- )

- answer_length = st.sidebar.selectbox(
-     "Answer length",
-     ["Short", "Medium", "Long"],
-     index=1,
- )

- st.sidebar.markdown("---")

- st.sidebar.header("Vectorstore Controls")

- rebuild_clicked = st.sidebar.button("Rebuild vectorstore", use_container_width=True)

- st.sidebar.markdown(
-     "<small>Use this when you add or modify files in <code>./data</code>.</small>",
-     unsafe_allow_html=True,
- )

- # -------------- Core config ----------------
-
- length_instruction_map = {
-     "Short": "Answer in 1–3 sentences.",
-     "Medium": "Answer in 1–2 paragraphs.",
-     "Long": "Provide a detailed, multi-paragraph explanation.",
- }
- length_instruction = length_instruction_map[answer_length]
- style_instruction = (
-     length_instruction
-     + f" Use a response style appropriate for a temperature of {temperature:.2f}, "
-     "where lower values are more factual and higher values are more exploratory."
- )

- data_dir = "./data"
- persist_dir = "./vectorstore"
-
- if "chat_history" not in st.session_state:
-     st.session_state["chat_history"] = []
-
- if "rebuild_pending" not in st.session_state:
-     st.session_state["rebuild_pending"] = False
-
- # -------------- Helpers ----------------
-
- def ensure_openai_key() -> bool:
-     if not os.environ.get("OPENAI_API_KEY"):
-         st.error("OPENAI_API_KEY is not set.")
-         return False
-     return True
-
- def run_ingest_cli(data_dir: str, persist_dir: str):
-     cmd = [
-         sys.executable,
-         "-m",
-         "src.ingest",
-         "--data-dir",
-         data_dir,
-         "--persist-dir",
-         persist_dir,
-     ]
-     subprocess.run(cmd, check=True)
-
- @st.cache_resource(show_spinner=False)
- def build_or_load_retriever_cached(
-     data_dir: str,
-     persist_dir: str,
-     top_k: int,
-     retrieval_mode: str,
- ):
-     try:
-         return get_retriever(
-             persist_dir=persist_dir,
-             top_k=top_k,
-             retrieval_mode=retrieval_mode,
        )
-     except Exception:
-         run_ingest_cli(data_dir=data_dir, persist_dir=persist_dir)
-         return get_retriever(
-             persist_dir=persist_dir,
-             top_k=top_k,
-             retrieval_mode=retrieval_mode,
        )

- @st.cache_resource(show_spinner=False)
- def get_chain(model_name: str, top_k: int, retrieval_mode: str):
-     retriever = build_or_load_retriever_cached(
-         data_dir=data_dir,
-         persist_dir=persist_dir,
-         top_k=top_k,
-         retrieval_mode=retrieval_mode,
-     )
-     return make_conversational_chain(retriever, model_name=model_name)
-
- def format_source_label(meta: dict, index: int) -> str:
-     source = (
-         meta.get("source")
-         or meta.get("file_path")
-         or meta.get("path")
-         or meta.get("document_id")
-         or "Unknown source"
-     )
-     return f"[{index}] {source}"
-
- def tokenize(text: str):
-     return [w.lower() for w in re.findall(r"\w+", text) if len(w) > 3]
-
- def compute_quality_scores(question: str, answer: str, sources: list):
-     all_chunk_text = " ".join(s.get("content", "") for s in sources)
-     q_tokens = tokenize(question)
-     a_tokens = tokenize(answer)
-     c_tokens = set(tokenize(all_chunk_text))
-     if not c_tokens:
-         return 0.0, 0.0
-     if not q_tokens:
-         coverage = 0.0
-     else:
-         coverage = sum(1 for t in q_tokens if t in c_tokens) / len(q_tokens)
-     if not a_tokens:
-         grounding = 0.0
-     else:
-         grounding = sum(1 for t in a_tokens if t in c_tokens) / len(a_tokens)
-     return coverage, grounding
-
- if not ensure_openai_key():
-     st.stop()
-
- # -------------- Rebuild confirmation + chain init ----------------
-
- if rebuild_clicked:
-     st.session_state["rebuild_pending"] = True
-
- chain = None
-
- if st.session_state["rebuild_pending"]:
-     st.warning(
-         "Rebuild the vectorstore from the current contents of ./data? "
-         "This will overwrite existing embeddings."
-     )
-
-     col_left, col_center, col_right = st.columns([1, 2, 1])
-
-     with col_center:
-         confirm_rebuild = st.button(
-             "Yes, rebuild",
-             key="confirm_rebuild",
-             use_container_width=True,
        )
-         cancel_rebuild = st.button(
-             "Cancel",
-             key="cancel_rebuild",
-             use_container_width=True,
        )

-     st.markdown(
        """
-         <style>
-         div[data-testid="column"] div:has(> button[aria-label="Yes, rebuild"]) button {
-             background-color: #27ae60 !important;
-             color: white !important;
-         }
-         div[data-testid="column"] div:has(> button[aria-label="Cancel"]) button {
-             background-color: #c0392b !important;
-             color: white !important;
-         }
-         </style>
-         """,
-         unsafe_allow_html=True,
-     )
-
-     if confirm_rebuild:
-         with st.spinner("Rebuilding vectorstore..."):
-             run_ingest_cli(data_dir=data_dir, persist_dir=persist_dir)
-             build_or_load_retriever_cached.clear()
-             get_chain.clear()
-             chain = get_chain(
-                 model_name=model_name,
-                 top_k=top_k,
-                 retrieval_mode=retrieval_mode,
-             )
-             st.session_state["rebuild_pending"] = False
-             st.success("Vectorstore rebuilt successfully.")
-
-     elif cancel_rebuild:
-         st.session_state["rebuild_pending"] = False
-         st.info("Rebuild canceled.")
-
- if chain is None and not st.session_state["rebuild_pending"]:
-     with st.spinner("Initializing knowledge base and chat model..."):
-         chain = get_chain(
-             model_name=model_name,
-             top_k=top_k,
-             retrieval_mode=retrieval_mode,
        )
-     st.success("Knowledge base and model are ready.")
- elif chain is not None and not st.session_state["rebuild_pending"]:
-     st.success("Knowledge base and model are ready.")

- # -------------- Render chat history ----------------

- if st.session_state["chat_history"]:
-     for turn in st.session_state["chat_history"]:
-         with st.chat_message("user"):
-             st.markdown(turn["question"])
-         answer_text = turn["answer"]
-         with st.chat_message("assistant"):
-             st.markdown(answer_text)

- # -------------- New user input ----------------

- user_input = st.chat_input("Ask a question about abalone (biology, data, methodology, etc.)")

- if user_input and chain is not None and not st.session_state["rebuild_pending"]:
-     with st.chat_message("user"):
-         st.markdown(user_input)

-     with st.spinner("Thinking..."):
-         prior_history = [
-             (h.get("question"), h.get("answer", "")) for h in st.session_state["chat_history"]
-         ]

-         styled_question = style_instruction + "\n\nQuestion: " + user_input

-         result = chain(
-             {"question": styled_question, "chat_history": prior_history}
        )

-         answer = (
-             result.get("answer")
-             or result.get("result")
-             or result.get("output_text")
-             or ""
        )
-         source_docs = result.get("source_documents") or []
-
-         sources_for_ui = []
-         for idx, sd in enumerate(source_docs, start=1):
-             if isinstance(sd, dict):
-                 meta = sd.get("metadata", {}) or {}
-                 content_full = sd.get("page_content") or sd.get("content") or sd.get("text", "")
-             else:
-                 meta = getattr(sd, "metadata", {}) or {}
-                 content_full = getattr(sd, "page_content", None)
-                 if content_full is None:
-                     content_full = getattr(sd, "content", "")
-                 if content_full is None:
-                     content_full = ""
-             sources_for_ui.append(
-                 {
-                     "index": idx,
-                     "metadata": meta,
-                     "content": str(content_full),
-                 }
-             )

-         coverage, grounding = compute_quality_scores(user_input, answer, sources_for_ui)
-         coverage_pct = int(round(coverage * 100))
-         grounding_pct = int(round(grounding * 100))
-         answer_text = answer
-
-     with st.chat_message("assistant"):
-         st.markdown(answer_text)
-
-     with st.expander("Retrieval Metrics and Sources"):
-         st.markdown(f"- Retrieval mode: `{retrieval_mode}`")
-         st.markdown(f"- k: `{top_k}`")
-         st.markdown(f"- Coverage score (question vs sources): **{coverage_pct}%**")
-         st.markdown(f"- Grounding score (answer vs sources): **{grounding_pct}%**")
-
-         if sources_for_ui:
-             st.markdown("**Retrieved chunks:**")
-             for src in sources_for_ui:
-                 idx = src.get("index", 0)
-                 meta = src.get("metadata", {}) or {}
-                 label = format_source_label(meta, idx)
-                 chunk_text = src.get("content", "")
-                 snippet = chunk_text[:200].replace("\n", " ")
-                 st.markdown(f"**[{idx}] {label}**")
-                 st.code(snippet + "...")
-
-     st.session_state["chat_history"].append(
-         {
-             "question": user_input,
-             "answer": answer,
-             "sources": sources_for_ui,
-         }
-     )

import os
+ from typing import List, Dict, Tuple, Optional

+ # Disable telemetry for LangChain and Chroma by default
os.environ.setdefault("LANGCHAIN_TELEMETRY_ENABLED", "false")
os.environ.setdefault("LANGCHAIN_DISABLE_TELEMETRY", "true")
os.environ.setdefault("CHROMA_TELEMETRY_ENABLED", "false")

import streamlit as st

+ from src.utils.rag_runtime import (
+     run_ingest_cli,
+     build_or_load_retriever_cached,
+     get_chain_cached,
)
+ from src.utils.metrics import compute_quality_scores
+ from src.utils.formatting import format_source_label
+ from src.utils.env import ensure_openai_key


+ class AbaloneRAGApp:
+     """Main application class for the Abalone RAG Chatbot."""

+     def __init__(self) -> None:
+         """Initialize the Streamlit page and application state."""
+         st.set_page_config(page_title="Abalone RAG Chatbot", page_icon="🐚")

+         st.title("Abalone RAG Chatbot")
+         st.write(
+             "Ask natural-language questions about abalone biology, ecology, "
+             "and research datasets. The app uses a local Chroma vectorstore "
+             "and OpenAI to retrieve and answer questions accurately."
+         )

+         # Data and vectorstore locations
+         self.data_dir = "./data"
+         self.persist_dir = "./vectorstore"
+
+         # Initialize session state
+         st.session_state.setdefault("chat_history", [])
+         st.session_state.setdefault("rebuild_pending", False)
+         self.chat_history: List[Dict] = st.session_state["chat_history"]
+
+         # Sidebar configuration
+         (
+             self.model_name,
+             self.top_k,
+             self.retrieval_mode,
+             self.temperature,
+             self.answer_length,
+             self.style_instruction,
+             self.rebuild_clicked,
+         ) = self._build_sidebar()
+
+         # QA chain instance (loaded lazily)
+         self.chain: Optional[object] = None
+
+     # ------------------------------------------------------------------
+     # Sidebar configuration
+     # ------------------------------------------------------------------
+
+     def _build_sidebar(self) -> Tuple[str, int, str, float, str, str, bool]:
+         """Render all sidebar controls and return model configuration.
+
+         Returns:
+             Tuple containing:
+                 - model_name: Which LLM to use.
+                 - top_k: Number of chunks to retrieve.
+                 - retrieval_mode: Strategy (mmr, similarity, hybrid).
+                 - temperature: LLM temperature.
+                 - answer_length: Short/Medium/Long preference.
+                 - style_instruction: Natural-language style directive.
+                 - rebuild_clicked: Whether "Rebuild vectorstore" was pressed.
+         """
+         st.sidebar.header("Model Settings")

+         model_name = st.sidebar.selectbox(
+             "Model",
+             options=["gpt-3.5-turbo", "gpt-4"],
+             index=0,
+         )

+         st.sidebar.markdown("---")

+         # Retrieval configuration
+         st.sidebar.header("Retrieval Configuration")

+         top_k = st.sidebar.slider(
+             "Number of retrieved chunks (k)",
+             min_value=2,
+             max_value=10,
+             value=4,
+         )

+         retrieval_mode_label = st.sidebar.selectbox(
+             "Retrieval mode",
+             ["MMR (diverse)", "Similarity", "Hybrid (dense + MMR)"],
+             index=0,
+         )
+         retrieval_mode_map = {
+             "MMR (diverse)": "mmr",
+             "Similarity": "similarity",
+             "Hybrid (dense + MMR)": "hybrid",
+         }
+         retrieval_mode = retrieval_mode_map[retrieval_mode_label]

+         st.sidebar.markdown("---")

+         # Answer style
+         st.sidebar.header("Answer Style")

+         temperature = st.sidebar.slider(
+             "Temperature",
+             min_value=0.0,
+             max_value=1.0,
+             value=0.2,
+             step=0.05,
+         )

+         answer_length = st.sidebar.selectbox(
+             "Answer length",
+             ["Short", "Medium", "Long"],
+             index=1,
+         )

+         st.sidebar.markdown("---")

+         # Vectorstore controls
+         st.sidebar.header("Vectorstore Controls")

+         rebuild_clicked = st.sidebar.button(
+             "Rebuild vectorstore",
+             use_container_width=True,
        )
+
+         st.sidebar.markdown(
+             "<small>Use this when you add or modify files in <code>./data</code>.</small>",
+             unsafe_allow_html=True,
        )

+         # Build style instruction for the LLM
+         length_instruction_map = {
+             "Short": "Answer in 1–3 sentences.",
+             "Medium": "Answer in 1–2 paragraphs.",
+             "Long": "Provide a detailed, multi-paragraph explanation.",
+         }
+         length_instruction = length_instruction_map[answer_length]
+         style_instruction = (
+             length_instruction
+             + f" Use a response style appropriate for a temperature of {temperature:.2f}, "
+             "where lower values are more factual and higher values are more exploratory."
        )
+
+         return (
+             model_name,
+             top_k,
+             retrieval_mode,
+             temperature,
+             answer_length,
+             style_instruction,
+             rebuild_clicked,
        )

+     # ------------------------------------------------------------------
+     # Vectorstore rebuild workflow
+     # ------------------------------------------------------------------
+
+     def handle_rebuild(self) -> None:
+         """Render rebuild confirmation dialog and rebuild if confirmed.
+
+         This manages the 2-step rebuild process:
+
+         1. User clicks "Rebuild vectorstore".
+         2. A confirmation dialog appears with "Yes, rebuild" and "Cancel".
+
+         If confirmed, the vectorstore is regenerated and caches are cleared.
        """
+         if self.rebuild_clicked:
+             st.session_state["rebuild_pending"] = True
+
+         if not st.session_state["rebuild_pending"]:
+             return
+
+         st.warning(
+             "Rebuild the vectorstore from the current contents of ./data? "
+             "This will overwrite existing embeddings."
        )

+         col_left, col_center, col_right = st.columns([1, 2, 1])

+         with col_center:
+             confirm = st.button(
+                 "Yes, rebuild",
+                 key="confirm_rebuild",
+                 use_container_width=True,
+             )
+             cancel = st.button(
+                 "Cancel",
+                 key="cancel_rebuild",
+                 use_container_width=True,
+             )

+         # Centered green (confirm) and red (cancel) buttons
+         st.markdown(
+             """
+             <style>
+             div[data-testid="column"] div:has(> button[aria-label="Yes, rebuild"]) button {
+                 background-color: #27ae60 !important;
+                 color: white !important;
+             }
+             div[data-testid="column"] div:has(> button[aria-label="Cancel"]) button {
+                 background-color: #c0392b !important;
+                 color: white !important;
+             }
+             </style>
+             """,
+             unsafe_allow_html=True,
+         )

+         if confirm:
+             with st.spinner("Rebuilding vectorstore..."):
+                 run_ingest_cli(data_dir=self.data_dir, persist_dir=self.persist_dir)
+                 build_or_load_retriever_cached.clear()
+                 get_chain_cached.clear()
+
+                 self.chain = get_chain_cached(
+                     model_name=self.model_name,
+                     top_k=self.top_k,
+                     retrieval_mode=self.retrieval_mode,
+                     data_dir=self.data_dir,
+                     persist_dir=self.persist_dir,
+                 )
+
+             st.session_state["rebuild_pending"] = False
+             st.success("Vectorstore rebuilt successfully.")
+
+         elif cancel:
+             st.session_state["rebuild_pending"] = False
+             st.info("Rebuild canceled.")
+
+     # ------------------------------------------------------------------
+     # Chain loading
+     # ------------------------------------------------------------------
+
+     def ensure_chain_ready(self) -> None:
+         """Load or create the QA chain unless a rebuild is still pending."""
+         if st.session_state["rebuild_pending"]:
+             return
+
+         if self.chain is None:
+             with st.spinner("Initializing knowledge base and chat model..."):
+                 self.chain = get_chain_cached(
+                     model_name=self.model_name,
+                     top_k=self.top_k,
+                     retrieval_mode=self.retrieval_mode,
+                     data_dir=self.data_dir,
+                     persist_dir=self.persist_dir,
+                 )
+             st.success("Knowledge base and model are ready.")
+         else:
+             st.success("Knowledge base and model are ready.")
+
+     # ------------------------------------------------------------------
+     # Chat UI
+     # ------------------------------------------------------------------
+
+     def render_chat_history(self) -> None:
+         """Render previous user and assistant messages."""
+         for turn in self.chat_history:
+             with st.chat_message("user"):
+                 st.markdown(turn["question"])
+             with st.chat_message("assistant"):
+                 st.markdown(turn["answer"])
+
+     def handle_user_input(self) -> None:
+         """Process new user queries, run RAG, compute metrics, and display results."""
+         if st.session_state["rebuild_pending"] or self.chain is None:
+             return
+
+         user_input = st.chat_input(
+             "Ask a question about abalone (biology, data, methodology, etc.)"
+         )
+         if not user_input:
+             return

+         # Render user message
+         with st.chat_message("user"):
+             st.markdown(user_input)

+         # Run inference
+         with st.spinner("Thinking..."):
+             prior_history: List[Tuple[str, str]] = [
+                 (h.get("question"), h.get("answer", ""))
+                 for h in self.chat_history
+             ]

+             styled_question = self.style_instruction + "\n\nQuestion: " + user_input

+             result = self.chain(
+                 {"question": styled_question, "chat_history": prior_history}
+             )
+
+             answer = (
+                 result.get("answer")
+                 or result.get("result")
+                 or result.get("output_text")
+                 or ""
+             )
+             source_docs = result.get("source_documents") or []
+
+             # Normalize retrieved docs for UI and metrics
+             formatted_sources: List[Dict] = []
+             for idx, sd in enumerate(source_docs, start=1):
+                 if isinstance(sd, dict):
+                     meta = sd.get("metadata", {}) or {}
+                     text = (
+                         sd.get("page_content")
+                         or sd.get("content")
+                         or sd.get("text", "")
+                         or ""
+                     )
+                 else:
+                     meta = getattr(sd, "metadata", {}) or {}
+                     text = (
+                         getattr(sd, "page_content", None)
+                         or getattr(sd, "content", "")
+                         or ""
+                     )
+
+                 formatted_sources.append(
+                     {"index": idx, "metadata": meta, "content": str(text)}
+                 )
+
+             # Compute simple retrieval quality metrics
+             coverage, grounding = compute_quality_scores(
+                 user_input, answer, formatted_sources
            )
+             coverage_pct = int(round(coverage * 100))
+             grounding_pct = int(round(grounding * 100))

+         # Render assistant message + debug block
+         with st.chat_message("assistant"):
+             st.markdown(answer)
+
+         with st.expander("Retrieval Metrics and Sources"):
+             st.markdown(f"- Retrieval mode: `{self.retrieval_mode}`")
+             st.markdown(f"- k: `{self.top_k}`")
+             st.markdown(
+                 f"- Coverage score (question vs sources): **{coverage_pct}%**"
+             )
+             st.markdown(
+                 f"- Grounding score (answer vs sources): **{grounding_pct}%**"
+             )
+
+             if formatted_sources:
+                 st.markdown("**Retrieved chunks:**")
+                 for src in formatted_sources:
+                     label = format_source_label(src["metadata"], src["index"])
+                     snippet = src["content"][:200].replace("\n", " ")
+                     st.markdown(f"**[{src['index']}] {label}**")
+                     st.code(snippet + "...")
+
+         # Persist turn in chat history
+         self.chat_history.append(
+             {
+                 "question": user_input,
+                 "answer": answer,
+                 "sources": formatted_sources,
+             }
        )
+         st.session_state["chat_history"] = self.chat_history


+ def main() -> None:
+     """Main entry point for running the Abalone RAG Chatbot app."""
+     app = AbaloneRAGApp()
+
+     if not ensure_openai_key():
+         st.stop()
+
+     app.handle_rebuild()
+     app.ensure_chain_ready()
+     app.render_chat_history()
+     app.handle_user_input()


+ if __name__ == "__main__":
+     main()
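
The refactor above moves the helper logic out of app.py into a src/utils package. The resulting layout, as implied by this diff (paths outside src/utils/ are inferred from the imports and may differ), is roughly:

app.py                     # Streamlit entry point (AbaloneRAGApp)
data/                      # input documents
vectorstore/               # persisted Chroma DB
src/
    ingest.py              # invoked as `python -m src.ingest`
    vectorstore.py         # get_retriever
    qa_chain.py            # make_conversational_chain
    utils/
        __init__.py
        env.py             # ensure_openai_key
        formatting.py      # format_source_label
        metrics.py         # tokenize, compute_quality_scores
        rag_runtime.py     # run_ingest_cli, cached retriever/chain factories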
src/utils/__init__.py ADDED
File without changes
src/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (160 Bytes)
src/utils/__pycache__/env.cpython-310.pyc ADDED
Binary file (522 Bytes)
src/utils/__pycache__/formatting.cpython-310.pyc ADDED
Binary file (552 Bytes)
src/utils/__pycache__/metrics.cpython-310.pyc ADDED
Binary file (1.76 kB)
src/utils/__pycache__/rag_runtime.cpython-310.pyc ADDED
Binary file (2.55 kB)
src/utils/env.py ADDED
@@ -0,0 +1,10 @@
+ import os
+ import streamlit as st
+
+
+ def ensure_openai_key(env_var: str = "OPENAI_API_KEY") -> bool:
+     """Ensure the specified OpenAI API key environment variable is present."""
+     if not os.environ.get(env_var):
+         st.error(f"{env_var} is not set.")
+         return False
+     return True
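
A quick usage sketch: the new env_var parameter generalizes the old hard-coded OPENAI_API_KEY check, so callers can validate any key variable (the Azure name below is purely illustrative):

if not ensure_openai_key():                         # checks OPENAI_API_KEY
    st.stop()
if not ensure_openai_key("AZURE_OPENAI_API_KEY"):   # hypothetical alternate variable
    st.stop()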
src/utils/formatting.py ADDED
@@ -0,0 +1,14 @@
+ from typing import Dict
+
+
+ def format_source_label(meta: Dict, index: int) -> str:
+     """Create a readable label for a retrieved chunk."""
+     source = (
+         meta.get("source")
+         or meta.get("file_path")
+         or meta.get("path")
+         or meta.get("document_id")
+         or "Unknown source"
+     )
+     return f"[{index}] {source}"
+
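
For reference, a small usage sketch (the metadata values are invented; any of the listed keys can supply the label, in the order shown):

meta = {"source": "data/abalone_growth.pdf"}  # hypothetical chunk metadata
format_source_label(meta, 2)   # -> "[2] data/abalone_growth.pdf"
format_source_label({}, 5)     # -> "[5] Unknown source"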
src/utils/metrics.py ADDED
@@ -0,0 +1,44 @@
+ import re
+ from typing import List, Dict, Tuple
+
+
+ def tokenize(text: str) -> List[str]:
+     """Tokenize a string into lowercase words >3 chars."""
+     return [w.lower() for w in re.findall(r"\w+", text) if len(w) > 3]
+
+
+ def compute_quality_scores(
+     question: str,
+     answer: str,
+     sources: List[Dict],
+ ) -> Tuple[float, float]:
+     """Compute retrieval quality metrics (coverage & grounding).
+
+     Args:
+         question: User's question text.
+         answer: Model-generated answer text.
+         sources: Retrieved documents/chunks, each with a 'content' field.
+
+     Returns:
+         (coverage, grounding) as floats in [0.0, 1.0].
+     """
+     all_chunk_text = " ".join(s.get("content", "") for s in sources)
+     q_tokens = tokenize(question)
+     a_tokens = tokenize(answer)
+     c_tokens = set(tokenize(all_chunk_text))
+
+     if not c_tokens:
+         return 0.0, 0.0
+
+     coverage = (
+         sum(1 for t in q_tokens if t in c_tokens) / len(q_tokens)
+         if q_tokens
+         else 0.0
+     )
+     grounding = (
+         sum(1 for t in a_tokens if t in c_tokens) / len(a_tokens)
+         if a_tokens
+         else 0.0
+     )
+
+     return coverage, grounding
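
A worked example with toy strings (not real retrieval output) shows how the two scores behave. tokenize() keeps only lowercase words longer than three characters, so short words like "How" and "by" are ignored:

from src.utils.metrics import compute_quality_scores

sources = [{"content": "Abalone shells grow by accretion, with growth rings forming annually."}]
coverage, grounding = compute_quality_scores(
    question="How fast do abalone shells grow?",
    answer="Abalone shells grow slowly by accretion.",
    sources=sources,
)
# coverage  = 3/4 = 0.75  ("fast" never appears in the sources)
# grounding = 4/5 = 0.80  ("slowly" is unsupported by the sources)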
src/utils/rag_runtime.py ADDED
@@ -0,0 +1,95 @@
+ import sys
+ import subprocess
+ from typing import Any
+
+ import streamlit as st
+
+ from src.vectorstore import get_retriever
+ from src.qa_chain import make_conversational_chain
+
+
+ def run_ingest_cli(data_dir: str, persist_dir: str) -> None:
+     """Run the ingestion module to rebuild the vectorstore.
+
+     Args:
+         data_dir: Directory containing the raw text files.
+         persist_dir: Directory where embeddings and Chroma DB should be stored.
+
+     Raises:
+         CalledProcessError: If the underlying subprocess fails.
+     """
+     cmd = [
+         sys.executable,
+         "-m",
+         "src.ingest",
+         "--data-dir",
+         data_dir,
+         "--persist-dir",
+         persist_dir,
+     ]
+     subprocess.run(cmd, check=True)
+
+
+ @st.cache_resource(show_spinner=False)
+ def build_or_load_retriever_cached(
+     data_dir: str,
+     persist_dir: str,
+     top_k: int,
+     retrieval_mode: str,
+ ) -> Any:
+     """Load a retriever from the persisted vectorstore or build a new one.
+
+     If loading fails—usually because the vectorstore doesn't exist—this
+     function triggers ingestion and retries loading.
+
+     Args:
+         data_dir: Directory containing input documents.
+         persist_dir: Directory where the Chroma vectorstore is stored.
+         top_k: Number of chunks to retrieve for queries.
+         retrieval_mode: Retrieval strategy (mmr, similarity, hybrid).
+
+     Returns:
+         An initialized retriever instance.
+     """
+     try:
+         return get_retriever(
+             persist_dir=persist_dir,
+             top_k=top_k,
+             retrieval_mode=retrieval_mode,
+         )
+     except Exception:
+         run_ingest_cli(data_dir=data_dir, persist_dir=persist_dir)
+         return get_retriever(
+             persist_dir=persist_dir,
+             top_k=top_k,
+             retrieval_mode=retrieval_mode,
+         )
+
+
+ @st.cache_resource(show_spinner=False)
+ def get_chain_cached(
+     model_name: str,
+     top_k: int,
+     retrieval_mode: str,
+     data_dir: str,
+     persist_dir: str,
+ ) -> Any:
+     """Create or load a cached conversational QA chain.
+
+     Args:
+         model_name: The OpenAI model to use (gpt-3.5-turbo, gpt-4).
+         top_k: Number of chunks to retrieve.
+         retrieval_mode: Retrieval mode for the retriever.
+         data_dir: Path to data directory.
+         persist_dir: Path to vectorstore directory.
+
+     Returns:
+         A fully configured conversational QA chain.
+     """
+     retriever = build_or_load_retriever_cached(
+         data_dir=data_dir,
+         persist_dir=persist_dir,
+         top_k=top_k,
+         retrieval_mode=retrieval_mode,
+     )
+     return make_conversational_chain(retriever, model_name=model_name)
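
A minimal sketch of the caching behavior these st.cache_resource decorators provide, assuming the vectorstore already exists under ./vectorstore and the code runs inside the Streamlit app:

from src.utils.rag_runtime import get_chain_cached

chain_a = get_chain_cached(
    model_name="gpt-3.5-turbo", top_k=4, retrieval_mode="mmr",
    data_dir="./data", persist_dir="./vectorstore",
)
chain_b = get_chain_cached(
    model_name="gpt-3.5-turbo", top_k=4, retrieval_mode="mmr",
    data_dir="./data", persist_dir="./vectorstore",
)
assert chain_a is chain_b  # identical arguments hit the same cache entry

Changing any argument (say top_k=8) builds and caches a separate chain, and get_chain_cached.clear() drops every cached entry, which is what the rebuild workflow in app.py relies on.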