|
|
import os |
|
|
from typing import List, Dict, Tuple, Optional, Any |
|
|
import streamlit as st |
|
|
import logging |
|
|
from datetime import datetime, timezone
|
|
|
|
|
|
|
|
os.environ.setdefault("LANGCHAIN_TELEMETRY_ENABLED", "false") |
|
|
os.environ.setdefault("LANGCHAIN_DISABLE_TELEMETRY", "true") |
|
|
os.environ.setdefault("CHROMA_TELEMETRY_ENABLED", "false") |
|
|
|
|
|
from src.utils.rag_runtime import ( |
|
|
run_ingest_cli, |
|
|
build_or_load_retriever_cached, |
|
|
get_chain_cached, |
|
|
answer_with_kg, |
|
|
) |
|
|
from src.utils.metrics import compute_quality_scores |
|
|
from src.utils.formatting import format_source_label |
|
|
from src.utils.env import ensure_openai_key |
|
|
|
|
|
|
|
|
class AbaloneRAGApp:
    """Main application class for the Abalone RAG Chatbot.

    Renders the Streamlit UI (header, sidebar, chat), manages the
    two-step vectorstore rebuild workflow, and drives retrieval-augmented
    question answering over the local Chroma vectorstore.
    """

    def __init__(self) -> None:
        """Initialize the Streamlit page and application state."""
        st.set_page_config(page_title="Abalone RAG Chatbot", page_icon="π")

        # Header row: title/description on the left, rebuild button on the right.
        header_col, action_col = st.columns([5, 1])
        with header_col:
            st.title("Abalone RAG Chatbot")
            st.write(
                "Ask natural-language questions about abalone biology, ecology, "
                "and research datasets. The app uses a local Chroma vectorstore "
                "and OpenAI to retrieve and answer questions accurately."
            )
        with action_col:
            # Button state is captured here; handle_rebuild() consumes it later.
            self._top_rebuild_clicked = st.button(
                "Rebuild vectorstore",
                key="top_rebuild",
                use_container_width=True,
            )

        # Paths are relative to the working directory the app is launched from.
        self.data_dir = "./data"
        self.persist_dir = "./vectorstore"

        # Chat history and the rebuild-confirmation flag must survive
        # Streamlit reruns, so they live in session state.
        st.session_state.setdefault("chat_history", [])
        st.session_state.setdefault("rebuild_pending", False)
        self.chat_history: List[Dict] = st.session_state["chat_history"]

        (
            self.model_name,
            self.top_k,
            self.retrieval_mode,
            self.temperature,
            self.answer_length,
            self.style_instruction,
            self.use_kg,
            self.kg_hops,
        ) = self._build_sidebar()

        self.rebuild_clicked = bool(getattr(self, "_top_rebuild_clicked", False))

        # The QA chain is created lazily by ensure_chain_ready() or after a
        # successful rebuild in handle_rebuild().
        self.chain: Optional[Any] = None

    def _build_sidebar(self) -> Tuple[str, int, str, float, str, str, bool, int]:
        """Render all sidebar controls and return model configuration.

        Returns:
            Tuple containing:
                - model_name: Which LLM to use.
                - top_k: Number of chunks to retrieve.
                - retrieval_mode: Strategy (mmr, similarity, hybrid).
                - temperature: LLM temperature.
                - answer_length: Short/Medium/Long preference.
                - style_instruction: Natural-language style directive.
                - use_kg: Whether to augment retrieval with the knowledge graph.
                - kg_hops: Number of knowledge-graph hops to traverse.
        """
        st.sidebar.header("Model Settings")

        model_name = st.sidebar.selectbox(
            "Model",
            options=["gpt-3.5-turbo", "gpt-4"],
            index=0,
        )

        st.sidebar.markdown("---")

        st.sidebar.header("Retrieval Configuration")

        top_k = st.sidebar.slider(
            "Number of retrieved chunks (k)",
            min_value=2,
            max_value=10,
            value=4,
        )

        retrieval_mode_label = st.sidebar.selectbox(
            "Retrieval mode",
            ["MMR (diverse)", "Similarity", "Hybrid (dense + MMR)"],
            index=2,
        )
        # Map the human-readable label to the internal mode identifier
        # expected by the retrieval backend.
        retrieval_mode_map = {
            "MMR (diverse)": "mmr",
            "Similarity": "similarity",
            "Hybrid (dense + MMR)": "hybrid",
        }
        retrieval_mode = retrieval_mode_map[retrieval_mode_label]

        st.sidebar.markdown("---")
        st.sidebar.header("Knowledge Graph")
        use_kg = st.sidebar.checkbox("Use knowledge graph for retrieval", value=False)
        kg_hops = st.sidebar.slider("KG hops", min_value=1, max_value=3, value=1)

        st.sidebar.markdown("---")

        st.sidebar.header("Answer Style")

        temperature = st.sidebar.slider(
            "Temperature",
            min_value=0.0,
            max_value=1.0,
            value=0.2,
            step=0.05,
        )

        answer_length = st.sidebar.selectbox(
            "Answer length",
            ["Short", "Medium", "Long"],
            index=1,
        )

        st.sidebar.markdown("---")
        st.sidebar.markdown("<small>To rebuild the vectorstore use the top-right \"Rebuild vectorstore\" button.</small>", unsafe_allow_html=True)

        # Translate the UI choices into a natural-language directive that is
        # prepended to every question sent to the LLM.
        length_instruction_map = {
            "Short": "Answer in 1β3 sentences.",
            "Medium": "Answer in 1β2 paragraphs.",
            "Long": "Provide a detailed, multi-paragraph explanation.",
        }
        length_instruction = length_instruction_map[answer_length]
        style_instruction = (
            length_instruction
            + f" Use a response style appropriate for a temperature of {temperature:.2f}, "
            "where lower values are more factual and higher values are more exploratory."
        )

        return (
            model_name,
            top_k,
            retrieval_mode,
            temperature,
            answer_length,
            style_instruction,
            use_kg,
            kg_hops,
        )

    def handle_rebuild(self) -> None:
        """Render rebuild confirmation dialog and rebuild if confirmed.

        This manages the 2-step rebuild process:

        1. User clicks "Rebuild vectorstore".
        2. A confirmation dialog appears with "Yes, rebuild" and "Cancel".

        If confirmed, the vectorstore is regenerated and caches are cleared.
        """
        # The top-bar click only arms the confirmation dialog; the flag lives
        # in session state so the dialog survives the Streamlit rerun.
        if self.rebuild_clicked:
            st.session_state["rebuild_pending"] = True

        if not st.session_state["rebuild_pending"]:
            return

        st.warning(
            "Rebuild the vectorstore from the current contents of ./data? "
            "This will overwrite existing embeddings."
        )

        col_left, col_center, col_right = st.columns([1, 2, 1])

        with col_center:
            confirm = st.button(
                "Yes, rebuild",
                key="confirm_rebuild",
                use_container_width=True,
            )
            cancel = st.button(
                "Cancel",
                key="cancel_rebuild",
                use_container_width=True,
            )

        # Color the confirmation buttons (green = confirm, red = cancel).
        st.markdown(
            """
            <style>
            div[data-testid="column"] div:has(> button[aria-label="Yes, rebuild"]) button {
                background-color: #27ae60 !important;
                color: white !important;
            }
            div[data-testid="column"] div:has(> button[aria-label="Cancel"]) button {
                background-color: #c0392b !important;
                color: white !important;
            }
            </style>
            """,
            unsafe_allow_html=True,
        )

        def _ui_log(msg: str) -> None:
            """Best-effort append to the UI rebuild log; never raises."""
            try:
                os.makedirs(self.persist_dir, exist_ok=True)
                with open(os.path.join(self.persist_dir, "ui_rebuild.log"), "a", encoding="utf-8") as fh:
                    fh.write(f"{msg}\n")
            except Exception:
                # Logging must never break the UI flow.
                pass

        if confirm:
            _ui_log(f"{datetime.now(timezone.utc).isoformat()} - Confirm rebuild clicked by user")
            with st.spinner("Rebuilding vectorstore..."):
                try:
                    run_ingest_cli(data_dir=self.data_dir, persist_dir=self.persist_dir)
                    _ui_log(f"{datetime.now(timezone.utc).isoformat()} - Rebuild succeeded")
                except Exception as e:
                    import subprocess as _sp
                    _ui_log(f"{datetime.now(timezone.utc).isoformat()} - Rebuild failed: {e}")
                    # If the ingest CLI itself failed, surface its captured
                    # stdout/stderr to help the user diagnose the problem.
                    if isinstance(e, _sp.CalledProcessError):
                        stderr = getattr(e, 'stderr', None)
                        stdout = getattr(e, 'output', None) or getattr(e, 'stdout', None)
                        st.error("Rebuild failed. See logs below.")
                        if stdout:
                            st.markdown("**ingest stdout:**")
                            st.code(stdout)
                        if stderr:
                            st.markdown("**ingest stderr:**")
                            st.code(stderr)
                    else:
                        st.error(f"Rebuild failed: {e}")
                    st.session_state["rebuild_pending"] = False
                    return

            # Drop cached retriever/chain so the fresh vectorstore is used.
            try:
                build_or_load_retriever_cached.clear()
                get_chain_cached.clear()
            except Exception:
                _ui_log(f"{datetime.now(timezone.utc).isoformat()} - Warning: failed to clear cached functions")

            self.chain = get_chain_cached(
                model_name=self.model_name,
                top_k=self.top_k,
                retrieval_mode=self.retrieval_mode,
                data_dir=self.data_dir,
                persist_dir=self.persist_dir,
            )

            st.session_state["rebuild_pending"] = False
            st.success("Vectorstore rebuilt successfully.")

        elif cancel:
            st.session_state["rebuild_pending"] = False
            st.info("Rebuild canceled.")

    def ensure_chain_ready(self) -> None:
        """Load or create the QA chain unless a rebuild is still pending."""
        if st.session_state["rebuild_pending"]:
            # Don't initialize while the confirmation dialog is showing.
            return

        if self.chain is None:
            with st.spinner("Initializing knowledge base and chat model..."):
                self.chain = get_chain_cached(
                    model_name=self.model_name,
                    top_k=self.top_k,
                    retrieval_mode=self.retrieval_mode,
                    data_dir=self.data_dir,
                    persist_dir=self.persist_dir,
                )
        st.success("Knowledge base and model are ready.")

    def render_chat_history(self) -> None:
        """Render previous user and assistant messages."""
        for turn in self.chat_history:
            with st.chat_message("user"):
                st.markdown(turn["question"])
            with st.chat_message("assistant"):
                st.markdown(turn["answer"])

    @staticmethod
    def _format_source_documents(source_docs: List[Any]) -> List[Dict]:
        """Normalize retrieved documents into plain display dicts.

        Accepts either dict-shaped documents or objects with
        ``page_content``/``content`` attributes, and returns
        ``{"index", "metadata", "content"}`` dicts (1-based index).
        """
        formatted: List[Dict] = []
        for idx, sd in enumerate(source_docs, start=1):
            if isinstance(sd, dict):
                meta = sd.get("metadata", {}) or {}
                text = (
                    sd.get("page_content")
                    or sd.get("content")
                    or sd.get("text", "")
                    or ""
                )
            else:
                meta = getattr(sd, "metadata", {}) or {}
                text = (
                    getattr(sd, "page_content", None)
                    or getattr(sd, "content", "")
                    or ""
                )
            formatted.append(
                {"index": idx, "metadata": meta, "content": str(text)}
            )
        return formatted

    def handle_user_input(self) -> None:
        """Process new user queries, run RAG, compute metrics, and display results."""
        if st.session_state["rebuild_pending"] or self.chain is None:
            return

        user_input = st.chat_input(
            "Ask a question about abalone (biology, data, methodology, etc.)"
        )
        if not user_input:
            return

        with st.chat_message("user"):
            st.markdown(user_input)

        with st.spinner("Thinking..."):
            # Conversation memory passed to the chain as (question, answer) pairs.
            prior_history: List[Tuple[str, str]] = [
                (h.get("question"), h.get("answer", ""))
                for h in self.chat_history
            ]

            styled_question = self.style_instruction + "\n\nQuestion: " + user_input

            # Query the chain; on a retrieval failure, attempt ONE automatic
            # vectorstore rebuild + chain reload, then retry the query.
            attempted_rebuild = False
            last_exception = None
            while True:
                try:
                    if getattr(self, 'use_kg', False):
                        result = answer_with_kg(
                            self.chain,
                            styled_question,
                            prior_history,
                            persist_dir=self.persist_dir,
                            kg_hops=self.kg_hops,
                        )
                    else:
                        result = self.chain({"question": styled_question, "chat_history": prior_history})
                    break
                except Exception as e:
                    last_exception = e
                    if attempted_rebuild:
                        # Already rebuilt once; give up and show the error.
                        st.error("Retrieval error: failed to query the knowledge base. Try rebuilding the vectorstore manually.")
                        st.exception(e)
                        return

                    attempted_rebuild = True
                    st.warning("Detected retrieval backend issue β attempting to rebuild the vectorstore and retry...")
                    try:
                        run_ingest_cli(data_dir=self.data_dir, persist_dir=self.persist_dir)
                    except Exception as rebuild_err:
                        st.error("Automatic rebuild failed; please rebuild manually from the sidebar or CLI.")
                        st.exception(rebuild_err)
                        return

                    try:
                        build_or_load_retriever_cached.clear()
                        get_chain_cached.clear()
                        self.chain = get_chain_cached(
                            model_name=self.model_name,
                            top_k=self.top_k,
                            retrieval_mode=self.retrieval_mode,
                            data_dir=self.data_dir,
                            persist_dir=self.persist_dir,
                        )
                    except Exception as reload_err:
                        st.error("Failed to reload the QA chain after rebuilding the vectorstore.")
                        st.exception(reload_err)
                        return

            # Different chain types put the answer under different keys.
            answer = (
                result.get("answer")
                or result.get("result")
                or result.get("output_text")
                or ""
            )
            source_docs = result.get("source_documents") or []

            formatted_sources: List[Dict] = self._format_source_documents(source_docs)

            coverage, grounding = compute_quality_scores(
                user_input, answer, formatted_sources
            )
            coverage_pct = int(round(coverage * 100))
            grounding_pct = int(round(grounding * 100))

            with st.chat_message("assistant"):
                st.markdown(answer)

                with st.expander("Retrieval Metrics and Sources"):
                    st.markdown(f"- Retrieval mode: `{self.retrieval_mode}`")
                    st.markdown(f"- k: `{self.top_k}`")
                    st.markdown(
                        f"- Coverage score (question vs sources): **{coverage_pct}%**"
                    )
                    st.markdown(
                        f"- Grounding score (answer vs sources): **{grounding_pct}%**"
                    )

                    if formatted_sources:
                        st.markdown("**Retrieved chunks:**")
                        for src in formatted_sources:
                            label = format_source_label(src["metadata"], src["index"])
                            snippet = src["content"][:200].replace("\n", " ")
                            st.markdown(f"**[{src['index']}] {label}**")
                            st.code(snippet + "...")

            # Persist the turn; chat_history aliases the session-state list,
            # but reassign explicitly for clarity.
            self.chat_history.append(
                {
                    "question": user_input,
                    "answer": answer,
                    "sources": formatted_sources,
                }
            )
            st.session_state["chat_history"] = self.chat_history
|
|
|
|
|
|
|
|
def main() -> None:
    """Main entry point for running the Abalone RAG Chatbot app.

    The call order is deliberate: the rebuild dialog is handled before the
    OpenAI-key gate, and the chain/chat steps run only once a key is present.
    """
    app = AbaloneRAGApp()

    # Rebuild flow first: it may consume this script run entirely
    # (confirmation dialog, rebuild progress, success/cancel messages).
    app.handle_rebuild()

    # Abort this run if no OpenAI API key is configured.
    if not ensure_openai_key():
        st.stop()

    app.ensure_chain_ready()
    app.render_chat_history()
    app.handle_user_input()
|
|
|
|
|
|
|
|
# Script entry point (also works when run via `streamlit run`).
if __name__ == "__main__":
    main()
|
|
|