import json import os from datetime import datetime from pathlib import Path from typing import Any, Dict, Generator, List, Optional, Tuple import httpx import streamlit as st from services.backend_client import post_upload_text from services.file_convert import convert_uploaded_file_to_text def get_backend_base_url() -> str: """Prefer Streamlit secrets, then environment variable, then localhost.""" if "BACKEND_BASE_URL" in st.secrets: base_url = st.secrets["BACKEND_BASE_URL"] else: base_url = os.getenv("BACKEND_BASE_URL", "http://localhost:8000") return str(base_url).rstrip("/") def get_api_key() -> Optional[str]: """Read API key from Streamlit secrets or environment.""" if "API_KEY" in st.secrets: return str(st.secrets["API_KEY"]) return os.getenv("API_KEY") def ping_health(base_url: str, api_key: Optional[str]) -> Dict[str, Any]: url = f"{base_url}/health" headers: Dict[str, str] = {} if api_key: headers["X-API-Key"] = api_key resp = httpx.get(url, headers=headers, timeout=10.0) resp.raise_for_status() return resp.json() def call_chat( base_url: str, api_key: str, payload: Dict[str, Any], ) -> Dict[str, Any]: url = f"{base_url}/chat" headers: Dict[str, str] = {"Content-Type": "application/json", "X-API-Key": api_key} resp = httpx.post(url, json=payload, headers=headers, timeout=60.0) resp.raise_for_status() return resp.json() def iter_chat_stream( base_url: str, api_key: str, payload: Dict[str, Any], ) -> Generator[Tuple[str, Optional[Dict[str, Any]]], None, None]: """Stream tokens from /chat/stream and yield (partial_answer, final_payload). The final_payload is None for intermediate updates and populated once when the terminating SSE event is received. """ url = f"{base_url}/chat/stream" headers: Dict[str, str] = {"Content-Type": "application/json", "X-API-Key": api_key} full_answer = "" final_payload: Optional[Dict[str, Any]] = None current_event: Optional[str] = None with httpx.Client(timeout=60.0) as client: with client.stream("POST", url, json=payload, headers=headers) as resp: resp.raise_for_status() for line in resp.iter_lines(): if not line: continue if line.startswith("event:"): current_event = line.split(":", 1)[1].strip() continue if line.startswith("data:"): data = line.split(":", 1)[1].lstrip() if current_event == "end": # Final payload with full JSON response. try: final_payload = json.loads(data) except json.JSONDecodeError: final_payload = None else: if data: if full_answer: full_answer += " " full_answer += data # Yield intermediate answer text. yield full_answer, None # After stream ends, make sure we yield at least once with final payload. if final_payload is not None: # If the backend included the final answer in the JSON payload, prefer it. answer_text = str(final_payload.get("answer") or full_answer) yield answer_text, final_payload elif full_answer: yield full_answer, None def init_session_state() -> None: if "messages" not in st.session_state: st.session_state.messages: List[Dict[str, Any]] = [] if "show_sources" not in st.session_state: st.session_state.show_sources = True if "supports_stream" not in st.session_state: st.session_state.supports_stream = True # Namespace is fixed for now; default to "dev". if "namespace" not in st.session_state: st.session_state.namespace = "dev" if "recent_uploads" not in st.session_state: st.session_state.recent_uploads: List[Dict[str, Any]] = [] if "chat_prefill" not in st.session_state: st.session_state.chat_prefill = None def render_sidebar(backend_base_url: str, api_key: Optional[str]) -> Dict[str, Any]: with st.sidebar: st.header("Backend") st.markdown(f"**Backend URL:** `{backend_base_url}`") if api_key: st.markdown("**API key:** configured in Streamlit secrets or environment.") else: st.warning( "API_KEY is not configured. The backend is expected to be protected; " "chat will be disabled until an API key is set." ) if st.button("Ping /health"): try: health = ping_health(backend_base_url, api_key) st.success("Backend reachable.") st.json(health) except Exception as exc: # noqa: BLE001 st.error(f"Health check failed: {exc}") st.markdown("---") st.subheader("Chat settings") top_k = st.slider("Top K", min_value=1, max_value=20, value=5, step=1) min_score = st.slider( "Minimum relevance score", min_value=0.0, max_value=1.0, value=0.25, step=0.05, ) use_web_fallback = st.checkbox( "Use web fallback (Tavily)", value=True, help="When enabled, /chat may call Tavily if retrieval is weak.", ) st.session_state.show_sources = st.checkbox( "Show sources", value=st.session_state.show_sources ) if st.button("Clear chat"): st.session_state.messages = [] st.markdown("---") st.subheader("Recent uploads") recent = st.session_state.get("recent_uploads", []) if not recent: st.caption("No documents uploaded yet.") else: for idx, item in enumerate(recent): title = item.get("title") or "Untitled" ns = item.get("namespace") or st.session_state.get("namespace", "dev") ts = item.get("timestamp", "") st.markdown(f"- **{title}** \n Namespace: `{ns}` \n Uploaded: {ts}") if st.button("Search this document", key=f"search_upload_{idx}"): st.session_state.chat_prefill = f"Summarize: {title}" return { "top_k": top_k, "min_score": float(min_score), "use_web_fallback": bool(use_web_fallback), } def render_chat_history(show_sources: bool) -> None: for message in st.session_state.messages: role = message.get("role", "user") content = message.get("content", "") with st.chat_message("assistant" if role == "assistant" else "user"): st.markdown(content) if role == "assistant" and show_sources: sources = message.get("sources") or [] if sources: with st.expander("Sources", expanded=False): for idx, src in enumerate(sources, start=1): title = src.get("title") or f"Source {idx}" url = src.get("url") or "" score = src.get("score", 0.0) st.markdown(f"**[{idx}] {title}** (score={score:.3f})") if url: st.markdown(f"- URL: {url}") chunk_text = src.get("chunk_text") or "" if chunk_text: st.write(chunk_text[:1000] + ("..." if len(chunk_text) > 1000 else "")) @st.dialog("Upload document") def upload_dialog(backend_base_url: str, api_key: Optional[str]) -> None: """Modal dialog for uploading and ingesting a document via /documents/upload-text.""" st.write("Upload a document to ingest it into the RAG backend.") with st.form("upload_form"): uploaded_file = st.file_uploader( "Choose a file", type=["pdf", "md", "txt", "docx", "pptx", "xlsx", "html", "htm"], accept_multiple_files=False, ) default_title = "" if uploaded_file is not None: default_title = Path(uploaded_file.name).stem title = st.text_input("Title", value=default_title) namespace = st.text_input( "Namespace", value=st.session_state.get("namespace", "dev"), help="Target Pinecone namespace.", ) source = st.text_input("Source label", value="ui-upload") tags = st.text_input("Tags (comma separated)", value="") notes = st.text_area("Notes", value="", height=80) upload_anyway = st.checkbox( "Upload even if extracted text is very short", value=False, help="Enable to upload even when the extracted text is shorter than 200 characters.", ) submit = st.form_submit_button("Upload") if not submit: return if uploaded_file is None: st.error("Please select a file to upload.") return if not title.strip(): st.error("Please provide a title.") return if not api_key: st.error("API_KEY is not configured; cannot upload to a protected backend.") return with st.spinner("Converting and uploading document..."): try: uploaded_file.seek(0) text, conv_meta = convert_uploaded_file_to_text(uploaded_file) except Exception as exc: # noqa: BLE001 st.error(f"Error converting file: {exc}") return if len(text.strip()) < 200 and not upload_anyway: st.warning( "Extracted text is very short (< 200 characters). " "Check the file or enable the checkbox to upload anyway." ) return meta: Dict[str, Any] = { **conv_meta, "tags": [t.strip() for t in tags.split(",") if t.strip()], "notes": notes, } payload = { "title": title.strip(), "source": source.strip() or "ui-upload", "text": text, "namespace": namespace.strip() or st.session_state.get("namespace", "dev"), "metadata": meta, } try: response = post_upload_text(backend_base_url, api_key, payload) except httpx.HTTPStatusError as exc: if exc.response is not None: detail = exc.response.text status_code = exc.response.status_code else: detail = str(exc) status_code = "error" st.error(f"Upload failed ({status_code}): {detail}") return except Exception as exc: # noqa: BLE001 st.error(f"Upload failed: {exc}") return # Record recent upload and suggest a follow-up chat action. rec = { "title": title.strip(), "namespace": payload["namespace"], "timestamp": datetime.utcnow().isoformat() + "Z", "response": response, } recent = st.session_state.get("recent_uploads", []) recent.append(rec) st.session_state.recent_uploads = recent[-5:] st.success(f"Uploaded and indexed: {title.strip()}") st.rerun() def main() -> None: st.set_page_config(page_title="RAG Agent Workbench", layout="wide") st.title("RAG Agent Workbench – Chatbot") init_session_state() backend_base_url = get_backend_base_url() api_key = get_api_key() # Upload button near the top-level chat UI. if st.button("📄 Upload Document"): upload_dialog(backend_base_url, api_key) settings = render_sidebar(backend_base_url, api_key) render_chat_history(show_sources=st.session_state.show_sources) if not api_key: st.info( "Configure `API_KEY` in Streamlit secrets (and on the backend) to start chatting." ) return # Pre-fill chat input if a suggestion was set (e.g. from recent uploads). prefill = st.session_state.get("chat_prefill") if prefill and "chat_input" not in st.session_state: st.session_state.chat_input = prefill user_message = st.chat_input( "Ask a question about your documents...", key="chat_input" ) if not user_message: return # Clear any prefill once the user has sent a message. st.session_state.chat_prefill = None # Record and display user message st.session_state.messages.append({"role": "user", "content": user_message}) with st.chat_message("user"): st.markdown(user_message) # Prepare payload for backend chat_history = [ {"role": msg["role"], "content": msg["content"]} for msg in st.session_state.messages if msg.get("role") in ("user", "assistant") ] payload: Dict[str, Any] = { "query": user_message, "namespace": st.session_state.namespace, "top_k": int(settings["top_k"]), "use_web_fallback": settings["use_web_fallback"], "min_score": float(settings["min_score"]), "max_web_results": 5, "chat_history": chat_history, } # Call backend and stream / display assistant response with st.chat_message("assistant"): placeholder = st.empty() placeholder.markdown("_Thinking..._") response: Optional[Dict[str, Any]] = None try: if st.session_state.get("supports_stream", True): try: # Attempt to use streaming endpoint first. for partial_answer, final_payload in iter_chat_stream( backend_base_url, api_key, payload, ): if partial_answer: placeholder.markdown(partial_answer) if final_payload is not None: response = final_payload break except httpx.HTTPStatusError as exc: # If /chat/stream is not available, fall back to /chat. if exc.response is not None and exc.response.status_code == 404: st.session_state.supports_stream = False else: raise if response is None: # Fallback to non-streaming /chat. response = call_chat(backend_base_url, api_key, payload) answer_text = str(response.get("answer") or "") if answer_text: placeholder.markdown(answer_text) else: placeholder.markdown("_No answer returned._") except Exception as exc: # noqa: BLE001 placeholder.markdown("") st.error(f"Error calling backend: {exc}") return if not response: return answer = str(response.get("answer") or "") sources = response.get("sources") or [] timings = response.get("timings") or {} # Optionally render sources for this assistant turn. if st.session_state.show_sources and sources: with st.expander("Sources", expanded=False): for idx, src in enumerate(sources, start=1): title = src.get("title") or f"Source {idx}" url = src.get("url") or "" score = src.get("score", 0.0) st.markdown(f"**[{idx}] {title}** (score={score:.3f})") if url: st.markdown(f"- URL: {url}") chunk_text = src.get("chunk_text") or "" if chunk_text: st.write(chunk_text[:1000] + ("..." if len(chunk_text) > 1000 else "")) # Persist assistant message with metadata. st.session_state.messages.append( { "role": "assistant", "content": answer, "sources": sources, "timings": timings, } ) if __name__ == "__main__": main()