Updating package file structure
- app.py +342 -305
- src/utils/__init__.py +0 -0
- src/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- src/utils/__pycache__/env.cpython-310.pyc +0 -0
- src/utils/__pycache__/formatting.cpython-310.pyc +0 -0
- src/utils/__pycache__/metrics.cpython-310.pyc +0 -0
- src/utils/__pycache__/rag_runtime.cpython-310.pyc +0 -0
- src/utils/env.py +10 -0
- src/utils/formatting.py +14 -0
- src/utils/metrics.py +44 -0
- src/utils/rag_runtime.py +95 -0
app.py
CHANGED
@@ -1,351 +1,388 @@

 import os
-import sys
-import subprocess
-import re

 os.environ.setdefault("LANGCHAIN_TELEMETRY_ENABLED", "false")
 os.environ.setdefault("LANGCHAIN_DISABLE_TELEMETRY", "true")
 os.environ.setdefault("CHROMA_TELEMETRY_ENABLED", "false")

 import streamlit as st
-from src.vectorstore import get_retriever
-from src.qa_chain import make_conversational_chain
-
-st.set_page_config(page_title="Abalone RAG Chatbot", page_icon="🐚")
-
-st.title("Abalone RAG Chatbot")
-st.write(
-    "Ask natural-language questions about abalone studies and data. "
-    "The app uses a local Chroma vectorstore and OpenAI to retrieve and answer."
-)

-# ---------------- Sidebar ----------------
-
-st.sidebar.header("Model Settings")

-model_name = st.sidebar.selectbox(
-    "Model",
-    options=["gpt-3.5-turbo", "gpt-4"],
-    index=0,
-)

-st.sidebar.markdown("---")

-st.sidebar.header("Retrieval Configuration")

-top_k = st.sidebar.slider(
-    "Number of retrieved chunks (k)",
-    min_value=2,
-    max_value=10,
-    value=4,
-)

-retrieval_mode_label = st.sidebar.selectbox(
-    "Retrieval mode",
-    ["MMR (diverse)", "Similarity", "Hybrid (dense + MMR)"],
-    index=0,
-)

-retrieval_mode_map = {
-    "MMR (diverse)": "mmr",
-    "Similarity": "similarity",
-    "Hybrid (dense + MMR)": "hybrid",
-}
-retrieval_mode = retrieval_mode_map[retrieval_mode_label]

-st.sidebar.markdown("---")

-st.sidebar.header("Answer Style")

-temperature = st.sidebar.slider(
-    "Temperature",
-    min_value=0.0,
-    max_value=1.0,
-    value=0.2,
-    step=0.05,
-)

-answer_length = st.sidebar.selectbox(
-    "Answer length",
-    ["Short", "Medium", "Long"],
-    index=1,
-)

-st.sidebar.markdown("---")

-st.sidebar.header("Vectorstore Controls")

-rebuild_clicked = st.sidebar.button("Rebuild vectorstore", use_container_width=True)

-st.sidebar.markdown(
-    "<small>Use this when you add or modify files in <code>./data</code>.</small>",
-    unsafe_allow_html=True,
-)

-# Build style instruction for the LLM
-length_instruction_map = {
-    "Short": "Answer in 1–3 sentences.",
-    "Medium": "Answer in 1–2 paragraphs.",
-    "Long": "Provide a detailed, multi-paragraph explanation.",
-}
-length_instruction = length_instruction_map[answer_length]
-style_instruction = (
-    length_instruction
-    + f" Use a response style appropriate for a temperature of {temperature:.2f}, "
-    "where lower values are more factual and higher values are more exploratory."
-)

-data_dir = "./data"
-persist_dir = "./vectorstore"
-
-if "chat_history" not in st.session_state:
-    st.session_state["chat_history"] = []
-
-if "rebuild_pending" not in st.session_state:
-    st.session_state["rebuild_pending"] = False
-
-# -------------- Helpers ----------------
-
-def ensure_openai_key() -> bool:
-    if not os.environ.get("OPENAI_API_KEY"):
-        st.error("OPENAI_API_KEY is not set.")
-        return False
-    return True
-
-def run_ingest_cli(data_dir: str, persist_dir: str):
-    cmd = [
-        sys.executable,
-        "-m",
-        "src.ingest",
-        "--data-dir",
-        data_dir,
-        "--persist-dir",
-        persist_dir,
-    ]
-    subprocess.run(cmd, check=True)
-
-@st.cache_resource(show_spinner=False)
-def build_or_load_retriever_cached(
-    data_dir: str,
-    persist_dir: str,
-    top_k: int,
-    retrieval_mode: str,
-):
-    try:
-        return get_retriever(
-            persist_dir=persist_dir,
-            top_k=top_k,
-            retrieval_mode=retrieval_mode,
-        )
-    except Exception:
-        run_ingest_cli(data_dir=data_dir, persist_dir=persist_dir)
-        return get_retriever(
-            persist_dir=persist_dir,
-            top_k=top_k,
-            retrieval_mode=retrieval_mode,
-        )
-
-@st.cache_resource(show_spinner=False)
-def get_chain(model_name: str, top_k: int, retrieval_mode: str):
-    retriever = build_or_load_retriever_cached(
-        data_dir=data_dir,
-        persist_dir=persist_dir,
-        top_k=top_k,
-        retrieval_mode=retrieval_mode,
-    )
-    return make_conversational_chain(retriever, model_name=model_name)
-
-def format_source_label(meta: dict, index: int) -> str:
-    source = (
-        meta.get("source")
-        or meta.get("file_path")
-        or meta.get("path")
-        or meta.get("document_id")
-        or "Unknown source"
-    )
-    return f"[{index}] {source}"
-
-def tokenize(text: str):
-    return [w.lower() for w in re.findall(r"\w+", text) if len(w) > 3]
-
-def compute_quality_scores(question: str, answer: str, sources: list):
-    all_chunk_text = " ".join(s.get("content", "") for s in sources)
-    q_tokens = tokenize(question)
-    a_tokens = tokenize(answer)
-    c_tokens = set(tokenize(all_chunk_text))
-    if not c_tokens:
-        return 0.0, 0.0
-    if not q_tokens:
-        coverage = 0.0
-    else:
-        coverage = sum(1 for t in q_tokens if t in c_tokens) / len(q_tokens)
-    if not a_tokens:
-        grounding = 0.0
-    else:
-        grounding = sum(1 for t in a_tokens if t in c_tokens) / len(a_tokens)
-    return coverage, grounding
-
-if not ensure_openai_key():
-    st.stop()
-
-# -------------- Rebuild confirmation + chain init ----------------
-
-if rebuild_clicked:
-    st.session_state["rebuild_pending"] = True
-
-chain = None
-
-if st.session_state["rebuild_pending"]:
-    st.warning(
-        "Rebuild the vectorstore from the current contents of ./data? "
-        "This will overwrite existing embeddings."
-    )
-
-    col_left, col_center, col_right = st.columns([1, 2, 1])
-
-    with col_center:
-        confirm_rebuild = st.button(
-            "Yes, rebuild",
-            key="confirm_rebuild",
-            use_container_width=True,
-        )
-        cancel_rebuild = st.button(
-            "Cancel",
-            key="cancel_rebuild",
-            use_container_width=True,
-        )
-
-    st.markdown(
-        """
-        <style>
-        div[data-testid="column"] div:has(> button[aria-label="Yes, rebuild"]) button {
-            background-color: #27ae60 !important;
-            color: white !important;
-        }
-        div[data-testid="column"] div:has(> button[aria-label="Cancel"]) button {
-            background-color: #c0392b !important;
-            color: white !important;
-        }
-        </style>
-        """,
-        unsafe_allow_html=True,
-    )
-
-    if confirm_rebuild:
-        with st.spinner("Rebuilding vectorstore..."):
-            run_ingest_cli(data_dir=data_dir, persist_dir=persist_dir)
-            build_or_load_retriever_cached.clear()
-            get_chain.clear()
-            chain = get_chain(
-                model_name=model_name,
-                top_k=top_k,
-                retrieval_mode=retrieval_mode,
-            )
-        st.session_state["rebuild_pending"] = False
-        st.success("Vectorstore rebuilt successfully.")
-
-    elif cancel_rebuild:
-        st.session_state["rebuild_pending"] = False
-        st.info("Rebuild canceled.")
-
-if chain is None and not st.session_state["rebuild_pending"]:
-    with st.spinner("Initializing knowledge base and chat model..."):
-        chain = get_chain(
-            model_name=model_name,
-            top_k=top_k,
-            retrieval_mode=retrieval_mode,
-        )
-    st.success("Knowledge base and model are ready.")
-elif chain is not None and not st.session_state["rebuild_pending"]:
-    st.success("Knowledge base and model are ready.")

-# -------------- Chat UI ----------------
-
-for turn in st.session_state["chat_history"]:
-    with st.chat_message("user"):
-        st.markdown(turn["question"])
-    with st.chat_message("assistant"):
-        st.markdown(turn["answer"])

-# -------------- Handle new input ----------------
-
-user_input = st.chat_input("Ask a question about abalone (biology, data, methodology, etc.)")

-if user_input and chain is not None:
-    with st.chat_message("user"):
-        st.markdown(user_input)
-
-    prior_history = [
-        (h.get("question"), h.get("answer", ""))
-        for h in st.session_state["chat_history"]
-    ]
-
-    styled_question = style_instruction + "\n\nQuestion: " + user_input
-
-    with st.spinner("Thinking..."):
-        result = chain(
-            {"question": styled_question, "chat_history": prior_history}
-        )
-
-        answer = (
-            result.get("answer")
-            or result.get("result")
-            or result.get("output_text")
-            or ""
-        )
-    source_docs = result.get("source_documents") or []
-
-    sources_for_ui = []
-    for idx, sd in enumerate(source_docs, start=1):
-        if isinstance(sd, dict):
-            meta = sd.get("metadata", {}) or {}
-            content_full = sd.get("page_content") or sd.get("content") or sd.get("text", "")
-        else:
-            meta = getattr(sd, "metadata", {}) or {}
-            content_full = getattr(sd, "page_content", None)
-            if content_full is None:
-                content_full = getattr(sd, "content", "")
-        if content_full is None:
-            content_full = ""
-        sources_for_ui.append(
-            {
-                "index": idx,
-                "metadata": meta,
-                "content": str(content_full),
-            }
-        )
-
-    coverage, grounding = compute_quality_scores(user_input, answer, sources_for_ui)
-    coverage_pct = int(round(coverage * 100))
-    grounding_pct = int(round(grounding * 100))
-
-    with st.chat_message("assistant"):
-        st.markdown(answer)
-
-        with st.expander("Retrieval Metrics and Sources"):
-            st.markdown(f"- Retrieval mode: `{retrieval_mode}`")
-            st.markdown(f"- k: `{top_k}`")
-            st.markdown(f"- Coverage score (question vs sources): **{coverage_pct}%**")
-            st.markdown(f"- Grounding score (answer vs sources): **{grounding_pct}%**")
-
-            for src in sources_for_ui:
-                idx = src.get("index", 0)
-                meta = src.get("metadata", {}) or {}
-                label = format_source_label(meta, idx)
-                chunk_text = src.get("content", "")
-                snippet = chunk_text[:200].replace("\n", " ")
-                st.markdown(f"**[{idx}] {label}**")
-                st.code(snippet + "...")
-
-    st.session_state["chat_history"].append(
-        {
-            "question": user_input,
-            "answer": answer,
-            "sources": sources_for_ui,
-        }
-    )
 import os
+from typing import List, Dict, Tuple, Optional

+# Disable telemetry for LangChain and Chroma by default
 os.environ.setdefault("LANGCHAIN_TELEMETRY_ENABLED", "false")
 os.environ.setdefault("LANGCHAIN_DISABLE_TELEMETRY", "true")
 os.environ.setdefault("CHROMA_TELEMETRY_ENABLED", "false")

 import streamlit as st

+from src.utils.rag_runtime import (
+    run_ingest_cli,
+    build_or_load_retriever_cached,
+    get_chain_cached,
+)
+from src.utils.metrics import compute_quality_scores
+from src.utils.formatting import format_source_label
+from src.utils.env import ensure_openai_key


+class AbaloneRAGApp:
+    """Main application class for the Abalone RAG Chatbot."""

+    def __init__(self) -> None:
+        """Initialize the Streamlit page and application state."""
+        st.set_page_config(page_title="Abalone RAG Chatbot", page_icon="🐚")

+        st.title("Abalone RAG Chatbot")
+        st.write(
+            "Ask natural-language questions about abalone biology, ecology, "
+            "and research datasets. The app uses a local Chroma vectorstore "
+            "and OpenAI to retrieve and answer questions accurately."
+        )

+        # Data and vectorstore locations
+        self.data_dir = "./data"
+        self.persist_dir = "./vectorstore"
+
+        # Initialize session state
+        st.session_state.setdefault("chat_history", [])
+        st.session_state.setdefault("rebuild_pending", False)
+        self.chat_history: List[Dict] = st.session_state["chat_history"]
+
+        # Sidebar configuration
+        (
+            self.model_name,
+            self.top_k,
+            self.retrieval_mode,
+            self.temperature,
+            self.answer_length,
+            self.style_instruction,
+            self.rebuild_clicked,
+        ) = self._build_sidebar()
+
+        # QA chain instance (loaded lazily)
+        self.chain: Optional[object] = None
+
+    # ------------------------------------------------------------------
+    # Sidebar configuration
+    # ------------------------------------------------------------------
+
+    def _build_sidebar(self) -> Tuple[str, int, str, float, str, str, bool]:
+        """Render all sidebar controls and return model configuration.
+
+        Returns:
+            Tuple containing:
+                - model_name: Which LLM to use.
+                - top_k: Number of chunks to retrieve.
+                - retrieval_mode: Strategy (mmr, similarity, hybrid).
+                - temperature: LLM temperature.
+                - answer_length: Short/Medium/Long preference.
+                - style_instruction: Natural-language style directive.
+                - rebuild_clicked: Whether "Rebuild vectorstore" was pressed.
+        """
+        st.sidebar.header("Model Settings")

+        model_name = st.sidebar.selectbox(
+            "Model",
+            options=["gpt-3.5-turbo", "gpt-4"],
+            index=0,
+        )

+        st.sidebar.markdown("---")

+        # Retrieval configuration
+        st.sidebar.header("Retrieval Configuration")

+        top_k = st.sidebar.slider(
+            "Number of retrieved chunks (k)",
+            min_value=2,
+            max_value=10,
+            value=4,
+        )

+        retrieval_mode_label = st.sidebar.selectbox(
+            "Retrieval mode",
+            ["MMR (diverse)", "Similarity", "Hybrid (dense + MMR)"],
+            index=0,
+        )
+        retrieval_mode_map = {
+            "MMR (diverse)": "mmr",
+            "Similarity": "similarity",
+            "Hybrid (dense + MMR)": "hybrid",
+        }
+        retrieval_mode = retrieval_mode_map[retrieval_mode_label]

+        st.sidebar.markdown("---")

+        # Answer style
+        st.sidebar.header("Answer Style")

+        temperature = st.sidebar.slider(
+            "Temperature",
+            min_value=0.0,
+            max_value=1.0,
+            value=0.2,
+            step=0.05,
+        )

+        answer_length = st.sidebar.selectbox(
+            "Answer length",
+            ["Short", "Medium", "Long"],
+            index=1,
+        )

+        st.sidebar.markdown("---")

+        # Vectorstore controls
+        st.sidebar.header("Vectorstore Controls")

+        rebuild_clicked = st.sidebar.button(
+            "Rebuild vectorstore",
+            use_container_width=True,
+        )
+
+        st.sidebar.markdown(
+            "<small>Use this when you add or modify files in <code>./data</code>.</small>",
+            unsafe_allow_html=True,
+        )

+        # Build style instruction for the LLM
+        length_instruction_map = {
+            "Short": "Answer in 1–3 sentences.",
+            "Medium": "Answer in 1–2 paragraphs.",
+            "Long": "Provide a detailed, multi-paragraph explanation.",
+        }
+        length_instruction = length_instruction_map[answer_length]
+        style_instruction = (
+            length_instruction
+            + f" Use a response style appropriate for a temperature of {temperature:.2f}, "
+            "where lower values are more factual and higher values are more exploratory."
+        )
+
+        return (
+            model_name,
+            top_k,
+            retrieval_mode,
+            temperature,
+            answer_length,
+            style_instruction,
+            rebuild_clicked,
+        )

+    # ------------------------------------------------------------------
+    # Vectorstore rebuild workflow
+    # ------------------------------------------------------------------
+
+    def handle_rebuild(self) -> None:
+        """Render rebuild confirmation dialog and rebuild if confirmed.
+
+        This manages the 2-step rebuild process:
+
+        1. User clicks "Rebuild vectorstore".
+        2. A confirmation dialog appears with "Yes, rebuild" and "Cancel".
+
+        If confirmed, the vectorstore is regenerated and caches are cleared.
+        """
+        if self.rebuild_clicked:
+            st.session_state["rebuild_pending"] = True
+
+        if not st.session_state["rebuild_pending"]:
+            return
+
+        st.warning(
+            "Rebuild the vectorstore from the current contents of ./data? "
+            "This will overwrite existing embeddings."
+        )

+        col_left, col_center, col_right = st.columns([1, 2, 1])

+        with col_center:
+            confirm = st.button(
+                "Yes, rebuild",
+                key="confirm_rebuild",
+                use_container_width=True,
+            )
+            cancel = st.button(
+                "Cancel",
+                key="cancel_rebuild",
+                use_container_width=True,
+            )

+        # Centered green (confirm) and red (cancel) buttons
+        st.markdown(
+            """
+            <style>
+            div[data-testid="column"] div:has(> button[aria-label="Yes, rebuild"]) button {
+                background-color: #27ae60 !important;
+                color: white !important;
+            }
+            div[data-testid="column"] div:has(> button[aria-label="Cancel"]) button {
+                background-color: #c0392b !important;
+                color: white !important;
+            }
+            </style>
+            """,
+            unsafe_allow_html=True,
+        )

+        if confirm:
+            with st.spinner("Rebuilding vectorstore..."):
+                run_ingest_cli(data_dir=self.data_dir, persist_dir=self.persist_dir)
+                build_or_load_retriever_cached.clear()
+                get_chain_cached.clear()
+
+                self.chain = get_chain_cached(
+                    model_name=self.model_name,
+                    top_k=self.top_k,
+                    retrieval_mode=self.retrieval_mode,
+                    data_dir=self.data_dir,
+                    persist_dir=self.persist_dir,
+                )
+
+            st.session_state["rebuild_pending"] = False
+            st.success("Vectorstore rebuilt successfully.")
+
+        elif cancel:
+            st.session_state["rebuild_pending"] = False
+            st.info("Rebuild canceled.")
+
+    # ------------------------------------------------------------------
+    # Chain loading
+    # ------------------------------------------------------------------
+
+    def ensure_chain_ready(self) -> None:
+        """Load or create the QA chain unless a rebuild is still pending."""
+        if st.session_state["rebuild_pending"]:
+            return
+
+        if self.chain is None:
+            with st.spinner("Initializing knowledge base and chat model..."):
+                self.chain = get_chain_cached(
+                    model_name=self.model_name,
+                    top_k=self.top_k,
+                    retrieval_mode=self.retrieval_mode,
+                    data_dir=self.data_dir,
+                    persist_dir=self.persist_dir,
+                )
+        st.success("Knowledge base and model are ready.")
+
+    # ------------------------------------------------------------------
+    # Chat UI
+    # ------------------------------------------------------------------
+
+    def render_chat_history(self) -> None:
+        """Render previous user and assistant messages."""
+        for turn in self.chat_history:
+            with st.chat_message("user"):
+                st.markdown(turn["question"])
+            with st.chat_message("assistant"):
+                st.markdown(turn["answer"])
+
+    def handle_user_input(self) -> None:
+        """Process new user queries, run RAG, compute metrics, and display results."""
+        if st.session_state["rebuild_pending"] or self.chain is None:
+            return
+
+        user_input = st.chat_input(
+            "Ask a question about abalone (biology, data, methodology, etc.)"
+        )
+        if not user_input:
+            return

+        # Render user message
+        with st.chat_message("user"):
+            st.markdown(user_input)

+        # Run inference
+        with st.spinner("Thinking..."):
+            prior_history: List[Tuple[str, str]] = [
+                (h.get("question"), h.get("answer", ""))
+                for h in self.chat_history
+            ]

+            styled_question = self.style_instruction + "\n\nQuestion: " + user_input

+            result = self.chain(
+                {"question": styled_question, "chat_history": prior_history}
+            )
+
+            answer = (
+                result.get("answer")
+                or result.get("result")
+                or result.get("output_text")
+                or ""
+            )
+            source_docs = result.get("source_documents") or []
+
+            # Normalize retrieved docs for UI and metrics
+            formatted_sources: List[Dict] = []
+            for idx, sd in enumerate(source_docs, start=1):
+                if isinstance(sd, dict):
+                    meta = sd.get("metadata", {}) or {}
+                    text = (
+                        sd.get("page_content")
+                        or sd.get("content")
+                        or sd.get("text", "")
+                        or ""
+                    )
+                else:
+                    meta = getattr(sd, "metadata", {}) or {}
+                    text = (
+                        getattr(sd, "page_content", None)
+                        or getattr(sd, "content", "")
+                        or ""
+                    )
+
+                formatted_sources.append(
+                    {"index": idx, "metadata": meta, "content": str(text)}
+                )
+
+            # Compute simple retrieval quality metrics
+            coverage, grounding = compute_quality_scores(
+                user_input, answer, formatted_sources
+            )
+            coverage_pct = int(round(coverage * 100))
+            grounding_pct = int(round(grounding * 100))

+        # Render assistant message + debug block
+        with st.chat_message("assistant"):
+            st.markdown(answer)
+
+            with st.expander("Retrieval Metrics and Sources"):
+                st.markdown(f"- Retrieval mode: `{self.retrieval_mode}`")
+                st.markdown(f"- k: `{self.top_k}`")
+                st.markdown(
+                    f"- Coverage score (question vs sources): **{coverage_pct}%**"
+                )
+                st.markdown(
+                    f"- Grounding score (answer vs sources): **{grounding_pct}%**"
+                )
+
+                if formatted_sources:
+                    st.markdown("**Retrieved chunks:**")
+                    for src in formatted_sources:
+                        label = format_source_label(src["metadata"], src["index"])
+                        snippet = src["content"][:200].replace("\n", " ")
+                        st.markdown(f"**{label}**")
+                        st.code(snippet + "...")
+
+        # Persist turn in chat history
+        self.chat_history.append(
+            {
+                "question": user_input,
+                "answer": answer,
+                "sources": formatted_sources,
+            }
+        )
+        st.session_state["chat_history"] = self.chat_history

+
+def main() -> None:
+    """Main entry point for running the Abalone RAG Chatbot app."""
+    app = AbaloneRAGApp()
+
+    if not ensure_openai_key():
+        st.stop()
+
+    app.handle_rebuild()
+    app.ensure_chain_ready()
+    app.render_chat_history()
+    app.handle_user_input()
+
+
+if __name__ == "__main__":
+    main()
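The refactored app treats the chain returned by `get_chain_cached` as a plain callable: it passes a dict with `question` and `chat_history` keys and reads `answer` (falling back to `result` or `output_text`), plus an optional `source_documents` list, from the return value. A minimal stand-in that satisfies this contract can be handy for exercising the UI without an OpenAI key; the `EchoChain` class below is a hypothetical sketch, not part of this commit:

    from typing import Dict, List, Tuple


    class EchoChain:
        """Hypothetical stub matching the call contract AbaloneRAGApp expects."""

        def __call__(self, inputs: Dict) -> Dict:
            question: str = inputs["question"]
            history: List[Tuple[str, str]] = inputs["chat_history"]
            return {
                "answer": f"(stub) Echoing: {question} (after {len(history)} prior turns)",
                "source_documents": [],  # the app tolerates an empty list
            }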
src/utils/__init__.py
ADDED
File without changes

src/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (160 Bytes)

src/utils/__pycache__/env.cpython-310.pyc
ADDED
Binary file (522 Bytes)

src/utils/__pycache__/formatting.cpython-310.pyc
ADDED
Binary file (552 Bytes)

src/utils/__pycache__/metrics.cpython-310.pyc
ADDED
Binary file (1.76 kB)

src/utils/__pycache__/rag_runtime.cpython-310.pyc
ADDED
Binary file (2.55 kB)
src/utils/env.py
ADDED
@@ -0,0 +1,10 @@
+import os
+import streamlit as st
+
+
+def ensure_openai_key(env_var: str = "OPENAI_API_KEY") -> bool:
+    """Ensure the specified OpenAI API key environment variable is present."""
+    if not os.environ.get(env_var):
+        st.error(f"{env_var} is not set.")
+        return False
+    return True
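A usage sketch, assuming the key has been exported in the shell that launches Streamlit; this mirrors how `main()` in app.py gates startup on the check:

    import streamlit as st
    from src.utils.env import ensure_openai_key

    # Stop rendering early if OPENAI_API_KEY is missing from the environment.
    if not ensure_openai_key():
        st.stop()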
src/utils/formatting.py
ADDED
@@ -0,0 +1,14 @@
+from typing import Dict
+
+
+def format_source_label(meta: Dict, index: int) -> str:
+    """Create a readable label for a retrieved chunk."""
+    source = (
+        meta.get("source")
+        or meta.get("file_path")
+        or meta.get("path")
+        or meta.get("document_id")
+        or "Unknown source"
+    )
+    return f"[{index}] {source}"
+
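The returned label already embeds the index, which is why the app prints the label on its own rather than prefixing the index a second time. A quick sketch with a hypothetical metadata dict (the `source` key is what typical document loaders set):

    from src.utils.formatting import format_source_label

    meta = {"source": "data/abalone_growth.txt"}  # hypothetical loader metadata
    print(format_source_label(meta, 1))  # -> "[1] data/abalone_growth.txt"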
src/utils/metrics.py
ADDED
@@ -0,0 +1,44 @@
+import re
+from typing import List, Dict, Tuple
+
+
+def tokenize(text: str) -> List[str]:
+    """Tokenize a string into lowercase words >3 chars."""
+    return [w.lower() for w in re.findall(r"\w+", text) if len(w) > 3]
+
+
+def compute_quality_scores(
+    question: str,
+    answer: str,
+    sources: List[Dict],
+) -> Tuple[float, float]:
+    """Compute retrieval quality metrics (coverage & grounding).
+
+    Args:
+        question: User's question text.
+        answer: Model-generated answer text.
+        sources: Retrieved documents/chunks, each with a 'content' field.
+
+    Returns:
+        (coverage, grounding) as floats in [0.0, 1.0].
+    """
+    all_chunk_text = " ".join(s.get("content", "") for s in sources)
+    q_tokens = tokenize(question)
+    a_tokens = tokenize(answer)
+    c_tokens = set(tokenize(all_chunk_text))
+
+    if not c_tokens:
+        return 0.0, 0.0
+
+    coverage = (
+        sum(1 for t in q_tokens if t in c_tokens) / len(q_tokens)
+        if q_tokens
+        else 0.0
+    )
+    grounding = (
+        sum(1 for t in a_tokens if t in c_tokens) / len(a_tokens)
+        if a_tokens
+        else 0.0
+    )
+
+    return coverage, grounding
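Both scores are plain lexical-overlap ratios over words longer than three characters, so they are cheap but rough. A small sketch with made-up strings:

    from src.utils.metrics import compute_quality_scores

    sources = [{"content": "Abalone shell growth rings indicate age."}]  # toy chunk
    coverage, grounding = compute_quality_scores(
        question="How do growth rings relate to abalone age?",
        answer="Growth rings on the shell indicate age.",
        sources=sources,
    )
    print(f"coverage={coverage:.2f}, grounding={grounding:.2f}")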
src/utils/rag_runtime.py
ADDED
@@ -0,0 +1,95 @@
+import sys
+import subprocess
+from typing import Any
+
+import streamlit as st
+
+from src.vectorstore import get_retriever
+from src.qa_chain import make_conversational_chain
+
+
+def run_ingest_cli(data_dir: str, persist_dir: str) -> None:
+    """Run the ingestion module to rebuild the vectorstore.
+
+    Args:
+        data_dir: Directory containing the raw text files.
+        persist_dir: Directory where embeddings and Chroma DB should be stored.
+
+    Raises:
+        CalledProcessError: If the underlying subprocess fails.
+    """
+    cmd = [
+        sys.executable,
+        "-m",
+        "src.ingest",
+        "--data-dir",
+        data_dir,
+        "--persist-dir",
+        persist_dir,
+    ]
+    subprocess.run(cmd, check=True)
+
+
+@st.cache_resource(show_spinner=False)
+def build_or_load_retriever_cached(
+    data_dir: str,
+    persist_dir: str,
+    top_k: int,
+    retrieval_mode: str,
+) -> Any:
+    """Load a retriever from the persisted vectorstore or build a new one.
+
+    If loading fails (usually because the vectorstore doesn't exist), this
+    function triggers ingestion and retries loading.
+
+    Args:
+        data_dir: Directory containing input documents.
+        persist_dir: Directory where the Chroma vectorstore is stored.
+        top_k: Number of chunks to retrieve for queries.
+        retrieval_mode: Retrieval strategy (mmr, similarity, hybrid).
+
+    Returns:
+        An initialized retriever instance.
+    """
+    try:
+        return get_retriever(
+            persist_dir=persist_dir,
+            top_k=top_k,
+            retrieval_mode=retrieval_mode,
+        )
+    except Exception:
+        run_ingest_cli(data_dir=data_dir, persist_dir=persist_dir)
+        return get_retriever(
+            persist_dir=persist_dir,
+            top_k=top_k,
+            retrieval_mode=retrieval_mode,
+        )
+
+
+@st.cache_resource(show_spinner=False)
+def get_chain_cached(
+    model_name: str,
+    top_k: int,
+    retrieval_mode: str,
+    data_dir: str,
+    persist_dir: str,
+) -> Any:
+    """Create or load a cached conversational QA chain.
+
+    Args:
+        model_name: The OpenAI model to use (gpt-3.5-turbo, gpt-4).
+        top_k: Number of chunks to retrieve.
+        retrieval_mode: Retrieval mode for the retriever.
+        data_dir: Path to data directory.
+        persist_dir: Path to vectorstore directory.
+
+    Returns:
+        A fully configured conversational QA chain.
+    """
+    retriever = build_or_load_retriever_cached(
+        data_dir=data_dir,
+        persist_dir=persist_dir,
+        top_k=top_k,
+        retrieval_mode=retrieval_mode,
+    )
+    return make_conversational_chain(retriever, model_name=model_name)
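Because `build_or_load_retriever_cached` and `get_chain_cached` are wrapped in `st.cache_resource`, the app must call `.clear()` on both after re-ingesting; otherwise a stale retriever would survive a rebuild. Outside the UI, the same ingestion can be triggered directly; a one-off sketch:

    from src.utils.rag_runtime import run_ingest_cli

    # Equivalent to the sidebar's "Rebuild vectorstore" button; shells out to
    # `python -m src.ingest --data-dir ./data --persist-dir ./vectorstore`.
    run_ingest_cli(data_dir="./data", persist_dir="./vectorstore")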