"""Streamlit UI for doc-audi-ai — talks to the FastAPI backend only.""" from __future__ import annotations import os import time from typing import Any import httpx import streamlit as st DEFAULT_API_BASE = os.environ.get("DOC_AUDI_API_BASE", "http://127.0.0.1:8000") # httpx read timeout for Ask/Summarise: embeddings + LLM on CPU or cold Ollama often exceeds 10 minutes. _HTTP_READ_TIMEOUT_DEFAULT_S = 3600.0 _HTTP_READ_TIMEOUT_MIN_S = 60.0 _HTTP_READ_TIMEOUT_MAX_S = 7200.0 def _http_read_timeout_seconds() -> float: raw = os.environ.get( "DOC_AUDI_HTTP_READ_TIMEOUT", str(int(_HTTP_READ_TIMEOUT_DEFAULT_S)), ) try: read_s = float(raw) except ValueError: read_s = _HTTP_READ_TIMEOUT_DEFAULT_S return max(_HTTP_READ_TIMEOUT_MIN_S, min(read_s, _HTTP_READ_TIMEOUT_MAX_S)) def _http_timeout() -> httpx.Timeout: """LLM + embeddings can exceed a few minutes on CPU or cold Ollama; Streamlit uses this, not Uvicorn.""" read_s = _http_read_timeout_seconds() return httpx.Timeout(connect=20.0, read=read_s, write=120.0, pool=30.0) def _fmt_timeout_hint() -> str: cap = int(_http_read_timeout_seconds()) lo, hi = int(_HTTP_READ_TIMEOUT_MIN_S), int(_HTTP_READ_TIMEOUT_MAX_S) return ( f"The UI stops waiting after **{cap}s** per request (set **DOC_AUDI_HTTP_READ_TIMEOUT**, " f"allowed **{lo}–{hi}** s). " "Ensure `ollama serve` is running; cold models or CPU inference can exceed a few minutes." ) def _api_base() -> str: """Resolve API base URL. 
Whitespace-only sidebar input must not win over default (breaks httpx).""" raw = st.session_state.get("api_base") if raw is None: return DEFAULT_API_BASE.rstrip("/") s = str(raw).strip() if not s: return DEFAULT_API_BASE.rstrip("/") return s.rstrip("/") def _client() -> httpx.Client: return httpx.Client(base_url=_api_base(), timeout=_http_timeout()) def _fmt_api_error(exc: httpx.HTTPStatusError) -> str: try: body = exc.response.json() except Exception: return f"HTTP {exc.response.status_code}: {exc.response.text[:500]}" detail = body.get("detail") if isinstance(detail, list): parts = [] for item in detail: if isinstance(item, dict): loc = item.get("loc", ()) msg = item.get("msg", "") parts.append(f"{'/'.join(str(x) for x in loc)}: {msg}") else: parts.append(str(item)) return f"HTTP {exc.response.status_code}: " + "; ".join(parts) if detail is not None: return f"HTTP {exc.response.status_code}: {detail}" return f"HTTP {exc.response.status_code}" def _fmt_request_error(exc: httpx.RequestError) -> str: """Human-readable transport errors (connection, timeouts, TLS, etc.).""" base = _api_base() if isinstance(exc, httpx.ReadTimeout): return ( f"**Read timeout** — `{base}` did not send a full response in time (embeddings/LLM can be slow). " f"{_fmt_timeout_hint()}" ) if isinstance(exc, httpx.ConnectTimeout): return ( f"**Connect timeout** — could not open TCP to `{base}` in time. " "Confirm the FastAPI process is listening (`uv run uvicorn api.main:app --host 0.0.0.0 --port 8000`)." ) if isinstance(exc, httpx.ConnectError): return ( f"**Connection failed** — nothing is accepting HTTP at `{base}`: {exc}. " "Start the API, or fix **API base URL** / **`DOC_AUDI_API_BASE`** (use `http://127.0.0.1:8000` from the same machine, not `0.0.0.0`)." ) if isinstance(exc, httpx.TimeoutException): return f"**Timeout** ({type(exc).__name__}): {exc}. {_fmt_timeout_hint()}" return f"**Request error** ({type(exc).__name__}): {exc}. Backend: `{base}`." 
def _path_segment(value: str) -> str:
    """Percent-encode *value* so it can be used as exactly one URL path segment.

    Free-text IDs and names typed into the UI may contain ``/``, ``?``, ``#``
    or ``..``; left raw they would change which route the request hits.
    """
    from urllib.parse import quote  # local import: keeps this block self-contained

    return quote(value, safe="")


def _show_error(exc: Exception) -> None:
    """Render *exc* in the page: formatted message for httpx errors, traceback otherwise."""
    if isinstance(exc, httpx.HTTPStatusError):
        st.error(_fmt_api_error(exc))
    elif isinstance(exc, httpx.RequestError):
        st.error(_fmt_request_error(exc))
    else:
        st.exception(exc)


def _post_query_ask(
    client: httpx.Client,
    *,
    question: str,
    collection_name: str,
    top_k: int = 5,
    user_id: str = "anonymous",
) -> httpx.Response:
    """POST /query/ask (falls back to POST /query on older servers)."""
    body: dict[str, object] = {
        "question": question.strip(),
        "collection_name": collection_name,
        "top_k": top_k,
        "user_id": user_id,
    }
    r = client.post("/query/ask", json=body)
    if r.status_code == 404:
        r = client.post("/query", json=body)
    return r


def _get_audit_logs(
    client: httpx.Client,
    *,
    limit: int,
    offset: int,
    user_id: str | None = None,
    from_date: str | None = None,
    to_date: str | None = None,
) -> httpx.Response:
    """GET /audit/logs with paging and optional filters; retries GET /audit on 404.

    Filters that are ``None`` or empty are left out of the query string.
    """
    params: dict[str, object] = {"limit": limit, "offset": offset}
    if user_id:
        params["user_id"] = user_id
    if from_date:
        params["from_date"] = from_date
    if to_date:
        params["to_date"] = to_date
    r = client.get("/audit/logs", params=params)
    if r.status_code == 404:
        r = client.get("/audit", params=params)
    return r


def _get_audit_event_detail(client: httpx.Client, event_id: str) -> httpx.Response:
    """GET /audit/logs/{event_id}; retries GET /audit/{event_id} on 404."""
    segment = _path_segment(event_id)
    r = client.get(f"/audit/logs/{segment}")
    if r.status_code == 404:
        r = client.get(f"/audit/{segment}")
    return r


def _health_check() -> tuple[bool, str]:
    """Call GET /health and return ``(ok, message)``; never raises."""
    try:
        with _client() as c:
            r = c.get("/health")
        r.raise_for_status()
        data = r.json()
        return True, str(data)
    except httpx.HTTPStatusError as e:
        return False, _fmt_api_error(e)
    except httpx.RequestError as e:
        return False, _fmt_request_error(e)
    except Exception as e:
        return False, str(e)


def _render_sidebar() -> None:
    """Backend settings: API base URL input, timeout caption, connection test."""
    st.subheader("Backend")
    st.text_input(
        "API base URL",
        key="api_base",
        placeholder=DEFAULT_API_BASE,
        help=f"Default: {DEFAULT_API_BASE}. Clear the field to use the default.",
    )
    st.caption(
        f"Ask/Summarise wait up to **{int(_http_read_timeout_seconds())}s** per request "
        f"(env `DOC_AUDI_HTTP_READ_TIMEOUT`, range {int(_HTTP_READ_TIMEOUT_MIN_S)}–{int(_HTTP_READ_TIMEOUT_MAX_S)})."
    )
    if st.button("Test connection"):
        ok, msg = _health_check()
        if ok:
            st.success(msg)
        else:
            st.error(msg)


def _render_upload_tab() -> None:
    """Upload tab: file upload, URL ingest, collection listing and deletion."""
    st.subheader("Upload document")
    col_u1, col_u2 = st.columns(2)
    with col_u1:
        up_collection = st.text_input("Collection", value="default", key="up_col")
        uploaded = st.file_uploader("PDF, TXT, or Markdown", type=["pdf", "txt", "md"], key="up_file")
    with col_u2:
        if st.button("Submit upload", key="btn_upload", disabled=uploaded is None):
            if uploaded is None:
                st.warning("Choose a file first.")
            else:
                try:
                    files = {
                        "files": (
                            uploaded.name,
                            uploaded.getvalue(),
                            uploaded.type or "application/octet-stream",
                        )
                    }
                    data = {"collection_name": up_collection}
                    with _client() as c:
                        r = c.post("/ingest/upload", files=files, data=data)
                    r.raise_for_status()
                    out = r.json()
                    st.success(out.get("message", "Queued"))
                    st.json(out)
                    if out.get("job_id"):
                        # Remembered so the Jobs tab can pre-fill its Job ID field.
                        st.session_state["last_job_id"] = out["job_id"]
                except Exception as e:
                    _show_error(e)

    st.subheader("Ingest from URL")
    url_col = st.columns([3, 1])
    with url_col[0]:
        ingest_url = st.text_input("Document URL (http/https)", key="ingest_url")
    with url_col[1]:
        url_collection = st.text_input("Collection", value="default", key="url_col")
    if st.button("Queue URL ingest", key="btn_url"):
        if not ingest_url.strip():
            st.warning("Enter a URL.")
        else:
            try:
                with _client() as c:
                    r = c.post(
                        "/ingest/url",
                        json={"urls": [ingest_url.strip()], "collection_name": url_collection},
                    )
                r.raise_for_status()
                out = r.json()
                st.success(out.get("message", "Queued"))
                st.json(out)
                if out.get("job_id"):
                    st.session_state["last_job_id"] = out["job_id"]
            except Exception as e:
                _show_error(e)

    st.subheader("Collections")
    if st.button("Refresh collections", key="btn_collections"):
        try:
            with _client() as c:
                r = c.get("/ingest/collections")
            r.raise_for_status()
            cols = r.json()
            rows = cols.get("collections", [])
            st.write(f"{cols.get('total', len(rows))} collection(s).")
            if rows:
                st.dataframe(rows, hide_index=True, use_container_width=True)
            else:
                st.info("No collections yet.")
        except Exception as e:
            _show_error(e)

    del_name = st.text_input("Delete collection name (optional)", key="del_col")
    if st.button("Delete collection", key="btn_del_col"):
        if not del_name.strip():
            st.warning("Enter a collection name.")
        else:
            try:
                with _client() as c:
                    # Encode the name so characters like "/" or "?" cannot redirect the DELETE.
                    r = c.delete(f"/ingest/collection/{_path_segment(del_name.strip())}")
                r.raise_for_status()
                del_body = r.json()
                st.success(del_body.get("message", "Deleted"))
                if "documents_removed" in del_body:
                    st.caption(f"Documents removed: **{del_body['documents_removed']}**")
            except Exception as e:
                _show_error(e)


def _render_jobs_tab() -> None:
    """Jobs tab: paged job list, single-job fetch, and polling until a terminal status."""
    st.subheader("Job list")
    j1, j2 = st.columns(2)
    with j1:
        j_limit = st.number_input("Limit", min_value=1, max_value=100, value=20, key="j_lim")
    with j2:
        j_offset = st.number_input("Offset", min_value=0, value=0, key="j_off")
    if st.button("List jobs", key="btn_jobs"):
        try:
            with _client() as c:
                r = c.get("/jobs", params={"limit": int(j_limit), "offset": int(j_offset)})
            r.raise_for_status()
            payload = r.json()
            jobs: list[dict[str, Any]] = payload.get("jobs", [])
            st.caption(f"Total jobs (matching filters): **{payload.get('total', len(jobs))}**")
            if jobs:
                st.dataframe(jobs, hide_index=True, use_container_width=True)
            else:
                st.info("No jobs in this window.")
        except Exception as e:
            _show_error(e)

    st.subheader("Job detail")
    default_job = st.session_state.get("last_job_id", "")
    job_id = st.text_input("Job ID", value=default_job, key="job_id_in")
    c1, c2 = st.columns(2)
    with c1:
        fetch_job = st.button("Fetch job", key="btn_job_one")
    with c2:
        poll_job = st.button("Poll until completed/failed", key="btn_job_poll")

    if fetch_job and job_id.strip():
        try:
            with _client() as c:
                r = c.get(f"/jobs/{_path_segment(job_id.strip())}")
            r.raise_for_status()
            detail = r.json()
            st.json(detail)
        except Exception as e:
            _show_error(e)

    if poll_job and job_id.strip():
        status_ph = st.empty()
        job_path = f"/jobs/{_path_segment(job_id.strip())}"
        try:
            with _client() as c:
                for i in range(120):
                    r = c.get(job_path)
                    r.raise_for_status()
                    body = r.json()
                    st_ = body.get("status", "")
                    status_ph.write(f"Poll {i + 1}: **{st_}** — {body.get('progress_percent', 0)}%")
                    if st_ in ("completed", "failed"):
                        st.json(body)
                        break
                    time.sleep(1)
                else:
                    # Loop ran out without reaching a terminal status: show the last payload.
                    status_ph.write("Stopped after 120 attempts (~2 min).")
                    st.json(body)
        except Exception as e:
            _show_error(e)


def _render_ask_tab() -> None:
    """Ask tab: send a question for a collection and show answer, sources, raw JSON."""
    st.subheader("Ask a question")
    q_col = st.text_input("Collection", value="default", key="ask_col")
    question = st.text_area("Question", height=120, key="ask_q")
    if st.button("Ask", key="btn_ask"):
        if not question.strip():
            st.warning("Enter a question.")
        else:
            try:
                with st.spinner(
                    "Calling the API (embeddings + LLM can take several minutes on a slow machine; "
                    "ensure Ollama is running). Timeout is controlled by DOC_AUDI_HTTP_READ_TIMEOUT…"
                ):
                    with _client() as c:
                        r = _post_query_ask(
                            c,
                            question=question,
                            collection_name=q_col,
                        )
                r.raise_for_status()
                ans = r.json()
                st.success(f"Query id: `{ans.get('query_id', '')}`")
                if ans.get("answer"):
                    st.markdown("### Answer")
                    st.markdown(ans["answer"])
                else:
                    st.warning(
                        "The API returned no **answer** text. "
                        "Check the collection has ingested chunks, LLM env, and expand **Raw response** below."
                    )
                src = ans.get("sources") or []
                if src:
                    with st.expander(f"Sources ({len(src)})"):
                        st.json(src)
                else:
                    st.caption("No sources in this response (empty retrieval or model returned nothing).")
                with st.expander("Raw response (debug)"):
                    st.json(ans)
            except Exception as e:
                _show_error(e)


def _render_summarise_tab() -> None:
    """Summarise tab: request a collection summary, optionally with a focus string."""
    st.subheader("Summarise collection")
    s_col = st.text_input("Collection", value="default", key="sum_col")
    focus = st.text_input("Optional focus / angle", value="", key="sum_focus")
    if st.button("Summarise", key="btn_sum"):
        try:
            body: dict[str, Any] = {"collection_name": s_col}
            if focus.strip():
                body["focus"] = focus.strip()
            with st.spinner("Calling summarise (can take 1–2 minutes on a cold model)…"):
                with _client() as c:
                    r = c.post("/query/summarise", json=body)
            r.raise_for_status()
            ans = r.json()
            st.success(f"Query id: `{ans.get('query_id', '')}` · documents: **{ans.get('document_count', '')}**")
            # Accept either key for the summary text.
            summary_text = ans.get("summary") or ans.get("answer")
            if summary_text:
                st.markdown("### Summary")
                st.markdown(summary_text)
            else:
                st.warning("No summary text in the response; see **Raw response** below.")
            src = ans.get("sources") or []
            if src:
                with st.expander(f"Sources ({len(src)})"):
                    st.json(src)
            with st.expander("Raw response (debug)"):
                st.json(ans)
        except Exception as e:
            _show_error(e)


def _render_audit_tab() -> None:
    """Audit tab: paged audit event list and per-event detail lookup."""
    st.subheader("Audit log")
    a1, a2 = st.columns(2)
    with a1:
        a_limit = st.number_input("Limit", min_value=1, max_value=100, value=20, key="a_lim")
    with a2:
        a_offset = st.number_input("Offset", min_value=0, value=0, key="a_off")
    if st.button("List audit events", key="btn_audit_list"):
        try:
            with _client() as c:
                r = _get_audit_logs(
                    c,
                    limit=int(a_limit),
                    offset=int(a_offset),
                )
            r.raise_for_status()
            payload = r.json()
            # The list is read from "logs" first, then "events".
            events = payload.get("logs", payload.get("events", []))
            st.caption(f"Total matching: **{payload.get('total', len(events))}**")
            if events:
                st.dataframe(events, hide_index=True, use_container_width=True)
                ids = [
                    e.get("query_id") or e.get("event_id")
                    for e in events
                    if isinstance(e, dict) and (e.get("query_id") or e.get("event_id"))
                ]
                if ids:
                    # Kept across reruns to populate the Event ID select box below.
                    st.session_state["_audit_ids"] = ids
            else:
                st.info("No audit events.")
        except Exception as e:
            _show_error(e)

    st.subheader("Audit event detail")
    ids_for_select = st.session_state.get("_audit_ids", [])
    pick = ""
    if ids_for_select:
        pick = st.selectbox("Event ID", options=[""] + list(ids_for_select), key="audit_pick")
    manual_id = st.text_input("Or enter query / event ID", key="audit_manual")
    # A manually typed ID takes precedence over the select box.
    ev_id = (manual_id.strip() or (pick or "").strip()).strip()
    if st.button("Load detail", key="btn_audit_detail") and ev_id:
        try:
            with _client() as c:
                r = _get_audit_event_detail(c, ev_id)
            r.raise_for_status()
            st.json(r.json())
        except Exception as e:
            _show_error(e)


def main() -> None:
    """Build the Streamlit page: header, backend sidebar, and the five feature tabs."""
    st.set_page_config(page_title="doc-audi-ai", layout="wide")
    # Seed the sidebar field before the widget that owns this key is created.
    if "api_base" not in st.session_state:
        st.session_state.api_base = DEFAULT_API_BASE

    st.title("doc-audi-ai")
    st.caption("Ingest, query, and audit via the FastAPI backend.")
    st.caption(f"Requests go to: `{_api_base()}`")

    with st.sidebar:
        _render_sidebar()

    tab_upload, tab_jobs, tab_ask, tab_sum, tab_audit = st.tabs(
        ["Upload", "Jobs", "Ask", "Summarise", "Audit"]
    )
    with tab_upload:
        _render_upload_tab()
    with tab_jobs:
        _render_jobs_tab()
    with tab_ask:
        _render_ask_tab()
    with tab_sum:
        _render_summarise_tab()
    with tab_audit:
        _render_audit_tab()


if __name__ == "__main__":
    main()