| """Streamlit UI for doc-audi-ai — talks to the FastAPI backend only.""" |
|
|
from __future__ import annotations

import math
import os
import time
from typing import Any
from urllib.parse import quote

import httpx
import streamlit as st
|
|
# Backend base URL used until the user types a different one in the sidebar.
# Overridable through the DOC_AUDI_API_BASE environment variable.
DEFAULT_API_BASE = os.getenv("DOC_AUDI_API_BASE", "http://127.0.0.1:8000")
|
|
| |
# Per-request HTTP read timeout, in seconds: the value used when the
# environment does not supply a usable one, and the range it is clamped to.
_HTTP_READ_TIMEOUT_DEFAULT_S = 3600.0
_HTTP_READ_TIMEOUT_MIN_S = 60.0
_HTTP_READ_TIMEOUT_MAX_S = 7200.0


def _http_read_timeout_seconds() -> float:
    """Return the effective HTTP read timeout in seconds.

    The value is read from the ``DOC_AUDI_HTTP_READ_TIMEOUT`` environment
    variable on every call. A missing, unparsable, or NaN value falls back to
    ``_HTTP_READ_TIMEOUT_DEFAULT_S``; the result is then clamped to the
    inclusive range ``[_HTTP_READ_TIMEOUT_MIN_S, _HTTP_READ_TIMEOUT_MAX_S]``
    (so ``inf`` becomes the maximum and negative values the minimum).
    """
    raw = os.environ.get("DOC_AUDI_HTTP_READ_TIMEOUT")
    read_s = _HTTP_READ_TIMEOUT_DEFAULT_S
    if raw is not None:
        try:
            parsed = float(raw)
        except ValueError:
            parsed = math.nan
        # float("nan") parses successfully but cannot be ordered; without this
        # check the result would depend on min()/max() argument order.
        if not math.isnan(parsed):
            read_s = parsed
    return max(_HTTP_READ_TIMEOUT_MIN_S, min(read_s, _HTTP_READ_TIMEOUT_MAX_S))
|
|
|
|
def _http_timeout() -> httpx.Timeout:
    """Build the timeout settings applied to every backend request.

    Only the read phase is configurable: LLM and embedding calls can run for
    several minutes on CPU or against a cold Ollama. This is the Streamlit
    side's limit, not Uvicorn's.
    """
    return httpx.Timeout(
        connect=20.0,
        read=_http_read_timeout_seconds(),
        write=120.0,
        pool=30.0,
    )
|
|
|
|
def _fmt_timeout_hint() -> str:
    """Return a Markdown sentence explaining the read timeout and how to change it."""
    cap = int(_http_read_timeout_seconds())
    lo = int(_HTTP_READ_TIMEOUT_MIN_S)
    hi = int(_HTTP_READ_TIMEOUT_MAX_S)
    # The pieces are kept separate so each stays on one readable line.
    limit_part = f"The UI stops waiting after **{cap}s** per request (set **DOC_AUDI_HTTP_READ_TIMEOUT**, "
    range_part = f"allowed **{lo}–{hi}** s). "
    advice_part = "Ensure `ollama serve` is running; cold models or CPU inference can exceed a few minutes."
    return limit_part + range_part + advice_part
|
|
|
|
def _api_base() -> str:
    """Resolve the backend base URL without a trailing slash.

    Reads ``api_base`` from the Streamlit session state (the sidebar input).
    If the value is missing, or is empty once surrounding whitespace and
    trailing slashes are removed (e.g. ``""``, ``"   "`` or ``"/"``), the
    module default ``DEFAULT_API_BASE`` is used instead, because an empty
    base URL breaks httpx.
    """
    fallback = DEFAULT_API_BASE.rstrip("/")
    raw = st.session_state.get("api_base")
    if raw is None:
        return fallback
    # Strip slashes *before* deciding the input is usable: "/" alone would
    # otherwise pass the blank check and then collapse to an empty string.
    cleaned = str(raw).strip().rstrip("/")
    return cleaned or fallback
|
|
|
|
def _client() -> httpx.Client:
    """Create an httpx client bound to the current API base URL and timeouts."""
    base_url = _api_base()
    timeout = _http_timeout()
    return httpx.Client(base_url=base_url, timeout=timeout)
|
|
|
|
def _fmt_api_error(exc: httpx.HTTPStatusError) -> str:
    """Turn an HTTP error response into a one-line, human-readable message.

    The message always starts with ``HTTP <status>``. What follows depends on
    the response body:

    * not JSON, or JSON that is not an object: the first 500 characters of
      the raw body text;
    * an object whose ``detail`` is a list: each entry rendered as
      ``loc/path: msg`` (dict entries) or ``str(entry)``, joined with ``"; "``
      (presumably the FastAPI validation-error shape; confirm against the API);
    * an object with any other non-null ``detail``: that value;
    * otherwise nothing more.

    This function is called from ``except`` handlers, so it must not raise.
    """
    response = exc.response
    prefix = f"HTTP {response.status_code}"
    try:
        body = response.json()
    except Exception:
        # Deliberately broad: any decoding failure means "show the raw text".
        body = None
    if not isinstance(body, dict):
        # Covers both undecodable bodies and valid JSON that is a list,
        # string, number or null, none of which has a ``.get`` for "detail".
        return f"{prefix}: {response.text[:500]}"

    detail = body.get("detail")
    if isinstance(detail, list):
        parts = []
        for item in detail:
            if not isinstance(item, dict):
                parts.append(str(item))
                continue
            loc = item.get("loc") or ()
            if not isinstance(loc, (list, tuple)):
                # A scalar location would otherwise be split per character
                # (str) or fail to iterate at all (int).
                loc = (loc,)
            msg = item.get("msg", "")
            parts.append(f"{'/'.join(str(x) for x in loc)}: {msg}")
        return f"{prefix}: " + "; ".join(parts)
    if detail is not None:
        return f"{prefix}: {detail}"
    return prefix
|
|
|
|
def _fmt_request_error(exc: httpx.RequestError) -> str:
    """Describe a transport-level failure (connection, timeout, TLS, ...) in Markdown.

    The specific timeout and connection classes are tested before the general
    ``httpx.TimeoutException``, and anything unrecognised gets a generic
    message that names the exception type and the backend URL.
    """
    base = _api_base()
    if isinstance(exc, httpx.ReadTimeout):
        message = (
            f"**Read timeout** — `{base}` did not send a full response in time (embeddings/LLM can be slow). "
            f"{_fmt_timeout_hint()}"
        )
    elif isinstance(exc, httpx.ConnectTimeout):
        message = (
            f"**Connect timeout** — could not open TCP to `{base}` in time. "
            "Confirm the FastAPI process is listening (`uv run uvicorn api.main:app --host 0.0.0.0 --port 8000`)."
        )
    elif isinstance(exc, httpx.ConnectError):
        message = (
            f"**Connection failed** — nothing is accepting HTTP at `{base}`: {exc}. "
            "Start the API, or fix **API base URL** / **`DOC_AUDI_API_BASE`** (use `http://127.0.0.1:8000` from the same machine, not `0.0.0.0`)."
        )
    elif isinstance(exc, httpx.TimeoutException):
        # Remaining timeout kinds that have no dedicated message above.
        message = f"**Timeout** ({type(exc).__name__}): {exc}. {_fmt_timeout_hint()}"
    else:
        message = f"**Request error** ({type(exc).__name__}): {exc}. Backend: `{base}`."
    return message
|
|
|
|
def _post_query_ask(
    client: httpx.Client,
    *,
    question: str,
    collection_name: str,
    top_k: int = 5,
    user_id: str = "anonymous",
) -> httpx.Response:
    """Send a question to ``POST /query/ask``, retrying on ``POST /query`` after a 404.

    The question is stripped of surrounding whitespace before sending. The
    returned response is whichever request was made last; its status is not
    checked here.
    """
    payload: dict[str, object] = dict(
        question=question.strip(),
        collection_name=collection_name,
        top_k=top_k,
        user_id=user_id,
    )
    response = client.post("/query/ask", json=payload)
    if response.status_code != 404:
        return response
    # Older servers expose the same operation under the shorter path.
    return client.post("/query", json=payload)
|
|
|
|
def _get_audit_logs(
    client: httpx.Client,
    *,
    limit: int,
    offset: int,
    user_id: str | None = None,
    from_date: str | None = None,
    to_date: str | None = None,
) -> httpx.Response:
    """Fetch a page of audit events from ``GET /audit/logs``, retrying on ``GET /audit`` after a 404.

    ``limit`` and ``offset`` are always sent; the three optional filters are
    sent only when truthy. The response status is not checked here.
    """
    params: dict[str, object] = {"limit": limit, "offset": offset}
    optional_filters = (
        ("user_id", user_id),
        ("from_date", from_date),
        ("to_date", to_date),
    )
    params.update({name: value for name, value in optional_filters if value})

    response = client.get("/audit/logs", params=params)
    if response.status_code != 404:
        return response
    return client.get("/audit", params=params)
|
|
|
|
def _get_audit_event_detail(client: httpx.Client, event_id: str) -> httpx.Response:
    """Fetch one audit event from ``GET /audit/logs/{id}``, retrying on ``GET /audit/{id}`` after a 404.

    ``event_id`` may be typed by the user, so it is percent-encoded as a
    single path segment: characters such as ``/``, ``?`` and ``#`` would
    otherwise change which URL is requested. The response status is not
    checked here.
    """
    segment = quote(str(event_id), safe="")
    response = client.get(f"/audit/logs/{segment}")
    if response.status_code == 404:
        response = client.get(f"/audit/{segment}")
    return response
|
|
|
|
def _health_check() -> tuple[bool, str]:
    """Call ``GET /health`` and report the outcome.

    Returns ``(True, <stringified JSON body>)`` on success, or
    ``(False, <message>)`` for an HTTP error status, a transport failure, or
    any other exception. It never raises.
    """
    try:
        with _client() as client:
            response = client.get("/health")
            response.raise_for_status()
            payload = response.json()
    except httpx.HTTPStatusError as err:
        return False, _fmt_api_error(err)
    except httpx.RequestError as err:
        return False, _fmt_request_error(err)
    except Exception as err:
        # Last resort (e.g. a body that is not JSON): report, do not crash the UI.
        return False, str(err)
    return True, str(payload)
|
|
|
|
def _show_error(exc: Exception) -> None:
    """Render an exception from a backend call in the current Streamlit container.

    HTTP status errors and transport errors get a formatted ``st.error``
    message; anything else is shown with its traceback via ``st.exception``.
    """
    if isinstance(exc, httpx.HTTPStatusError):
        st.error(_fmt_api_error(exc))
    elif isinstance(exc, httpx.RequestError):
        st.error(_fmt_request_error(exc))
    else:
        st.exception(exc)


def _path_segment(value: str) -> str:
    """Percent-encode user-supplied text so it is exactly one URL path segment."""
    return quote(value, safe="")


def _show_ingest_result(out: dict[str, Any]) -> None:
    """Display an ingest response and remember its job id for the Jobs tab."""
    st.success(out.get("message", "Queued"))
    st.json(out)
    if out.get("job_id"):
        st.session_state["last_job_id"] = out["job_id"]


def _render_sidebar() -> None:
    """Sidebar: backend URL input, timeout note, and a connection test button."""
    with st.sidebar:
        st.subheader("Backend")
        st.text_input(
            "API base URL",
            key="api_base",
            placeholder=DEFAULT_API_BASE,
            help=f"Default: {DEFAULT_API_BASE}. Clear the field to use the default.",
        )
        st.caption(
            f"Ask/Summarise wait up to **{int(_http_read_timeout_seconds())}s** per request "
            f"(env `DOC_AUDI_HTTP_READ_TIMEOUT`, range {int(_HTTP_READ_TIMEOUT_MIN_S)}–{int(_HTTP_READ_TIMEOUT_MAX_S)})."
        )
        if st.button("Test connection"):
            ok, msg = _health_check()
            if ok:
                st.success(msg)
            else:
                st.error(msg)


def _render_file_upload() -> None:
    """Upload section: send one local file to ``POST /ingest/upload``."""
    st.subheader("Upload document")
    col_u1, col_u2 = st.columns(2)
    with col_u1:
        up_collection = st.text_input("Collection", value="default", key="up_col")
        uploaded = st.file_uploader("PDF, TXT, or Markdown", type=["pdf", "txt", "md"], key="up_file")
    with col_u2:
        if st.button("Submit upload", key="btn_upload", disabled=uploaded is None):
            if uploaded is None:
                # Defensive: the button is disabled in this state.
                st.warning("Choose a file first.")
            else:
                try:
                    files = {
                        "files": (
                            uploaded.name,
                            uploaded.getvalue(),
                            uploaded.type or "application/octet-stream",
                        )
                    }
                    data = {"collection_name": up_collection}
                    with _client() as c:
                        r = c.post("/ingest/upload", files=files, data=data)
                        r.raise_for_status()
                    _show_ingest_result(r.json())
                except Exception as exc:
                    _show_error(exc)


def _render_url_ingest() -> None:
    """URL section: queue a remote document through ``POST /ingest/url``."""
    st.subheader("Ingest from URL")
    url_col = st.columns([3, 1])
    with url_col[0]:
        ingest_url = st.text_input("Document URL (http/https)", key="ingest_url")
    with url_col[1]:
        url_collection = st.text_input("Collection", value="default", key="url_col")
    if not st.button("Queue URL ingest", key="btn_url"):
        return
    if not ingest_url.strip():
        st.warning("Enter a URL.")
        return
    try:
        with _client() as c:
            r = c.post(
                "/ingest/url",
                json={"urls": [ingest_url.strip()], "collection_name": url_collection},
            )
            r.raise_for_status()
        _show_ingest_result(r.json())
    except Exception as exc:
        _show_error(exc)


def _render_collections() -> None:
    """Collections section: list collections and delete one by name."""
    st.subheader("Collections")
    if st.button("Refresh collections", key="btn_collections"):
        try:
            with _client() as c:
                r = c.get("/ingest/collections")
                r.raise_for_status()
            cols = r.json()
            rows = cols.get("collections", [])
            st.write(f"{cols.get('total', len(rows))} collection(s).")
            if rows:
                st.dataframe(rows, hide_index=True, use_container_width=True)
            else:
                st.info("No collections yet.")
        except Exception as exc:
            _show_error(exc)

    del_name = st.text_input("Delete collection name (optional)", key="del_col")
    if st.button("Delete collection", key="btn_del_col"):
        if not del_name.strip():
            st.warning("Enter a collection name.")
        else:
            try:
                with _client() as c:
                    # Encoded so a name containing "/", "?" or "#" cannot
                    # redirect this DELETE to a different path.
                    r = c.delete(f"/ingest/collection/{_path_segment(del_name.strip())}")
                    r.raise_for_status()
                del_body = r.json()
                st.success(del_body.get("message", "Deleted"))
                if "documents_removed" in del_body:
                    st.caption(f"Documents removed: **{del_body['documents_removed']}**")
            except Exception as exc:
                _show_error(exc)


def _render_job_list() -> None:
    """Jobs section: show one page of ``GET /jobs``."""
    st.subheader("Job list")
    j1, j2 = st.columns(2)
    with j1:
        j_limit = st.number_input("Limit", min_value=1, max_value=100, value=20, key="j_lim")
    with j2:
        j_offset = st.number_input("Offset", min_value=0, value=0, key="j_off")
    if not st.button("List jobs", key="btn_jobs"):
        return
    try:
        with _client() as c:
            r = c.get("/jobs", params={"limit": int(j_limit), "offset": int(j_offset)})
            r.raise_for_status()
        payload = r.json()
        jobs: list[dict[str, Any]] = payload.get("jobs", [])
        st.caption(f"Total jobs (matching filters): **{payload.get('total', len(jobs))}**")
        if jobs:
            st.dataframe(jobs, hide_index=True, use_container_width=True)
        else:
            st.info("No jobs in this window.")
    except Exception as exc:
        _show_error(exc)


def _poll_job(job_id: str) -> None:
    """Poll ``GET /jobs/{job_id}`` once per second until it completes, fails, or 120 polls pass.

    Errors propagate to the caller.
    """
    status_ph = st.empty()
    path = f"/jobs/{_path_segment(job_id)}"
    with _client() as c:
        for i in range(120):
            r = c.get(path)
            r.raise_for_status()
            body = r.json()
            status = body.get("status", "")
            status_ph.write(f"Poll {i + 1}: **{status}** — {body.get('progress_percent', 0)}%")
            if status in ("completed", "failed"):
                st.json(body)
                break
            time.sleep(1)
        else:
            # Loop ran out without a terminal status: show the last snapshot.
            status_ph.write("Stopped after 120 attempts (~2 min).")
            st.json(body)


def _render_job_detail() -> None:
    """Jobs section: fetch a single job once, or poll it until it finishes."""
    st.subheader("Job detail")
    default_job = st.session_state.get("last_job_id", "")
    job_id = st.text_input("Job ID", value=default_job, key="job_id_in")
    c1, c2 = st.columns(2)
    with c1:
        fetch_job = st.button("Fetch job", key="btn_job_one")
    with c2:
        poll_job = st.button("Poll until completed/failed", key="btn_job_poll")

    if not (fetch_job or poll_job):
        return
    job_key = job_id.strip()
    if not job_key:
        # Without this a click with an empty field does nothing visible.
        st.warning("Enter a job ID.")
        return
    try:
        if fetch_job:
            with _client() as c:
                r = c.get(f"/jobs/{_path_segment(job_key)}")
                r.raise_for_status()
            st.json(r.json())
        else:
            _poll_job(job_key)
    except Exception as exc:
        _show_error(exc)


def _render_ask_tab() -> None:
    """Ask tab: send a question and show the answer, its sources, and the raw response."""
    st.subheader("Ask a question")
    q_col = st.text_input("Collection", value="default", key="ask_col")
    question = st.text_area("Question", height=120, key="ask_q")
    if not st.button("Ask", key="btn_ask"):
        return
    if not question.strip():
        st.warning("Enter a question.")
        return
    try:
        with st.spinner(
            "Calling the API (embeddings + LLM can take several minutes on a slow machine; "
            "ensure Ollama is running). Timeout is controlled by DOC_AUDI_HTTP_READ_TIMEOUT…"
        ):
            with _client() as c:
                r = _post_query_ask(c, question=question, collection_name=q_col)
                r.raise_for_status()
            ans = r.json()
        st.success(f"Query id: `{ans.get('query_id', '')}`")
        if ans.get("answer"):
            st.markdown("### Answer")
            st.markdown(ans["answer"])
        else:
            st.warning(
                "The API returned no **answer** text. "
                "Check the collection has ingested chunks, LLM env, and expand **Raw response** below."
            )
        src = ans.get("sources") or []
        if src:
            with st.expander(f"Sources ({len(src)})"):
                st.json(src)
        else:
            st.caption("No sources in this response (empty retrieval or model returned nothing).")
        with st.expander("Raw response (debug)"):
            st.json(ans)
    except Exception as exc:
        _show_error(exc)


def _render_summarise_tab() -> None:
    """Summarise tab: request a collection summary, optionally with a focus."""
    st.subheader("Summarise collection")
    s_col = st.text_input("Collection", value="default", key="sum_col")
    focus = st.text_input("Optional focus / angle", value="", key="sum_focus")
    if not st.button("Summarise", key="btn_sum"):
        return
    try:
        request_body: dict[str, Any] = {"collection_name": s_col}
        if focus.strip():
            request_body["focus"] = focus.strip()
        with st.spinner("Calling summarise (can take 1–2 minutes on a cold model)…"):
            with _client() as c:
                r = c.post("/query/summarise", json=request_body)
                r.raise_for_status()
            ans = r.json()
        st.success(f"Query id: `{ans.get('query_id', '')}` · documents: **{ans.get('document_count', '')}**")
        summary_text = ans.get("summary") or ans.get("answer")
        if summary_text:
            st.markdown("### Summary")
            st.markdown(summary_text)
        else:
            st.warning("No summary text in the response; see **Raw response** below.")
        src = ans.get("sources") or []
        if src:
            with st.expander(f"Sources ({len(src)})"):
                st.json(src)
        with st.expander("Raw response (debug)"):
            st.json(ans)
    except Exception as exc:
        _show_error(exc)


def _render_audit_list() -> None:
    """Audit section: list events and remember their ids for the detail picker."""
    st.subheader("Audit log")
    a1, a2 = st.columns(2)
    with a1:
        a_limit = st.number_input("Limit", min_value=1, max_value=100, value=20, key="a_lim")
    with a2:
        a_offset = st.number_input("Offset", min_value=0, value=0, key="a_off")
    if not st.button("List audit events", key="btn_audit_list"):
        return
    try:
        with _client() as c:
            r = _get_audit_logs(c, limit=int(a_limit), offset=int(a_offset))
            r.raise_for_status()
        payload = r.json()
        # Two response shapes are accepted: a "logs" key or an "events" key.
        events = payload.get("logs", payload.get("events", []))
        st.caption(f"Total matching: **{payload.get('total', len(events))}**")
        if not events:
            st.info("No audit events.")
            return
        st.dataframe(events, hide_index=True, use_container_width=True)
        ids = [
            event_id
            for e in events
            if isinstance(e, dict) and (event_id := e.get("query_id") or e.get("event_id"))
        ]
        if ids:
            st.session_state["_audit_ids"] = ids
    except Exception as exc:
        _show_error(exc)


def _render_audit_detail() -> None:
    """Audit section: load one event chosen from the last listing or typed by hand."""
    st.subheader("Audit event detail")
    ids_for_select = st.session_state.get("_audit_ids", [])
    pick = ""
    if ids_for_select:
        pick = st.selectbox("Event ID", options=[""] + list(ids_for_select), key="audit_pick")
    manual_id = st.text_input("Or enter query / event ID", key="audit_manual")
    # A typed id wins over the picked one; str() tolerates non-string ids.
    ev_id = manual_id.strip() or str(pick or "").strip()
    if not st.button("Load detail", key="btn_audit_detail"):
        return
    if not ev_id:
        st.warning("Enter or select an event ID.")
        return
    try:
        with _client() as c:
            r = _get_audit_event_detail(c, ev_id)
            r.raise_for_status()
        st.json(r.json())
    except Exception as exc:
        _show_error(exc)


def main() -> None:
    """Render the Streamlit app: page setup, sidebar, and the five feature tabs.

    Every backend call is wrapped so that failures are shown in the page
    instead of aborting the script run.
    """
    st.set_page_config(page_title="doc-audi-ai", layout="wide")
    # Seed the value before the sidebar input that shares this key is created.
    if "api_base" not in st.session_state:
        st.session_state.api_base = DEFAULT_API_BASE

    st.title("doc-audi-ai")
    st.caption("Ingest, query, and audit via the FastAPI backend.")
    st.caption(f"Requests go to: `{_api_base()}`")

    _render_sidebar()

    tab_upload, tab_jobs, tab_ask, tab_sum, tab_audit = st.tabs(
        ["Upload", "Jobs", "Ask", "Summarise", "Audit"]
    )

    with tab_upload:
        _render_file_upload()
        _render_url_ingest()
        _render_collections()

    with tab_jobs:
        _render_job_list()
        _render_job_detail()

    with tab_ask:
        _render_ask_tab()

    with tab_sum:
        _render_summarise_tab()

    with tab_audit:
        _render_audit_list()
        _render_audit_detail()
|
|
|
|
# Run the app when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|