Document-Audit-RAG / streamlit_app.py
Mayank Chugh
Deploy DocuAudit AI to Hugging Face Space (no binaries)
d44b33d
"""Streamlit UI for doc-audi-ai — talks to the FastAPI backend only."""
from __future__ import annotations

import contextlib
import math
import os
import time
from collections.abc import Iterator
from typing import Any
from urllib.parse import quote

import httpx
import streamlit as st
# Backend base URL used when the sidebar field is empty; override with the
# DOC_AUDI_API_BASE environment variable.
DEFAULT_API_BASE = os.environ.get("DOC_AUDI_API_BASE", "http://127.0.0.1:8000")
# httpx read timeout, sized for Ask/Summarise: embeddings + LLM on CPU or cold Ollama often exceeds 10 minutes.
# NOTE: the same timeout is applied to every request made through `_client()`, not only Ask/Summarise.
# Default read timeout in seconds when DOC_AUDI_HTTP_READ_TIMEOUT is unset or unparsable.
_HTTP_READ_TIMEOUT_DEFAULT_S = 3600.0
# Lower clamp (seconds) applied to the configured read timeout.
_HTTP_READ_TIMEOUT_MIN_S = 60.0
# Upper clamp (seconds) applied to the configured read timeout.
_HTTP_READ_TIMEOUT_MAX_S = 7200.0
def _http_read_timeout_seconds() -> float:
    """Return the httpx read timeout in seconds, clamped to the allowed range.

    The value comes from the ``DOC_AUDI_HTTP_READ_TIMEOUT`` environment
    variable. A missing, non-numeric, or NaN value falls back to
    ``_HTTP_READ_TIMEOUT_DEFAULT_S``. The result is then clamped to
    ``[_HTTP_READ_TIMEOUT_MIN_S, _HTTP_READ_TIMEOUT_MAX_S]`` (so ``inf`` maps to
    the maximum and ``-inf`` to the minimum).
    """
    raw = os.environ.get("DOC_AUDI_HTTP_READ_TIMEOUT")
    if raw is None:
        read_s = _HTTP_READ_TIMEOUT_DEFAULT_S
    else:
        try:
            read_s = float(raw)
        except ValueError:
            read_s = _HTTP_READ_TIMEOUT_DEFAULT_S
    # float("nan") parses successfully but compares false with everything, so
    # min()/max() would silently degrade it to the minimum; treat it as invalid.
    if math.isnan(read_s):
        read_s = _HTTP_READ_TIMEOUT_DEFAULT_S
    return max(_HTTP_READ_TIMEOUT_MIN_S, min(read_s, _HTTP_READ_TIMEOUT_MAX_S))
def _http_timeout() -> httpx.Timeout:
    """Build the httpx timeout configuration used by the UI's HTTP client.

    Only the read timeout is configurable (see ``_http_read_timeout_seconds``);
    connect, write, and pool timeouts are fixed. LLM + embeddings can exceed a
    few minutes on CPU or cold Ollama; Streamlit uses this, not Uvicorn.
    """
    return httpx.Timeout(
        connect=20.0,
        read=_http_read_timeout_seconds(),
        write=120.0,
        pool=30.0,
    )
def _fmt_timeout_hint() -> str:
    """Return a Markdown hint describing the current read timeout and how to change it.

    The hint shows the effective timeout in whole seconds, the environment
    variable that controls it, and the allowed range.
    """
    cap = int(_http_read_timeout_seconds())
    lo, hi = int(_HTTP_READ_TIMEOUT_MIN_S), int(_HTTP_READ_TIMEOUT_MAX_S)
    return (
        f"The UI stops waiting after **{cap}s** per request (set **DOC_AUDI_HTTP_READ_TIMEOUT**, "
        # The en dash separates the bounds; without it they render fused (e.g. "607200").
        f"allowed **{lo}–{hi}** s). "
        "Ensure `ollama serve` is running; cold models or CPU inference can exceed a few minutes."
    )
def _api_base() -> str:
    """Resolve the API base URL, without a trailing slash.

    The sidebar value stored in ``st.session_state["api_base"]`` wins when it
    is usable. A missing value, a whitespace-only value, or a value that
    consists only of slashes must not win over the default: each of those would
    otherwise produce an empty base URL, which breaks httpx.
    """
    fallback = DEFAULT_API_BASE.rstrip("/")
    raw = st.session_state.get("api_base")
    if raw is None:
        return fallback
    # Strip slashes *before* the emptiness check so that "/" falls back too.
    cleaned = str(raw).strip().rstrip("/")
    return cleaned or fallback
def _client() -> httpx.Client:
    """Create a new httpx client bound to the resolved API base URL and UI timeouts.

    Callers are expected to use the returned client as a context manager so it
    is closed after the request.
    """
    base_url = _api_base()
    timeout = _http_timeout()
    return httpx.Client(base_url=base_url, timeout=timeout)
def _fmt_api_error(exc: httpx.HTTPStatusError) -> str:
    """Turn an HTTP error response into a short, human-readable message.

    The message always starts with ``HTTP <status>``. What follows depends on
    the response body:

    * body is not JSON, or is JSON but not an object: the first 500 characters
      of the raw response text;
    * ``detail`` is a list: each entry rendered as ``loc/path: msg`` (dict
      entries) or ``str(entry)``, joined with ``"; "``;
    * ``detail`` is any other non-null value: that value;
    * no ``detail``: nothing more.
    """
    response = exc.response
    prefix = f"HTTP {response.status_code}"
    try:
        body = response.json()
    except Exception:
        return f"{prefix}: {response.text[:500]}"
    # A JSON array/string/null has no .get(); show the raw text instead of
    # raising AttributeError from inside an error handler.
    if not isinstance(body, dict):
        return f"{prefix}: {response.text[:500]}"

    detail = body.get("detail")
    if isinstance(detail, list):
        parts = []
        for item in detail:
            if not isinstance(item, dict):
                parts.append(str(item))
                continue
            # "loc" may be absent or null; either way treat it as empty.
            loc = item.get("loc") or ()
            msg = item.get("msg", "")
            where = "/".join(str(x) for x in loc)
            parts.append(f"{where}: {msg}" if where else str(msg))
        return f"{prefix}: " + "; ".join(parts)
    if detail is not None:
        return f"{prefix}: {detail}"
    return prefix
def _fmt_request_error(exc: httpx.RequestError) -> str:
    """Describe a transport-level failure (connection, timeout, TLS, ...) in Markdown.

    The most specific exception types are checked first; the generic timeout
    and the catch-all request error come last.
    """
    backend = _api_base()
    if isinstance(exc, httpx.ReadTimeout):
        message = (
            f"**Read timeout** — `{backend}` did not send a full response in time (embeddings/LLM can be slow). "
            f"{_fmt_timeout_hint()}"
        )
    elif isinstance(exc, httpx.ConnectTimeout):
        message = (
            f"**Connect timeout** — could not open TCP to `{backend}` in time. "
            "Confirm the FastAPI process is listening (`uv run uvicorn api.main:app --host 0.0.0.0 --port 8000`)."
        )
    elif isinstance(exc, httpx.ConnectError):
        message = (
            f"**Connection failed** — nothing is accepting HTTP at `{backend}`: {exc}. "
            "Start the API, or fix **API base URL** / **`DOC_AUDI_API_BASE`** (use `http://127.0.0.1:8000` from the same machine, not `0.0.0.0`)."
        )
    elif isinstance(exc, httpx.TimeoutException):
        # Any other timeout kind (write, pool, ...).
        message = f"**Timeout** ({type(exc).__name__}): {exc}. {_fmt_timeout_hint()}"
    else:
        message = f"**Request error** ({type(exc).__name__}): {exc}. Backend: `{backend}`."
    return message
def _post_query_ask(
    client: httpx.Client,
    *,
    question: str,
    collection_name: str,
    top_k: int = 5,
    user_id: str = "anonymous",
) -> httpx.Response:
    """Send a question to ``POST /query/ask``, retrying on ``POST /query`` after a 404.

    The question is stripped of surrounding whitespace before sending. The
    retry exists for older servers that only expose ``/query``. The response is
    returned as-is; the caller decides whether to raise on error statuses.
    """
    payload: dict[str, object] = dict(
        question=question.strip(),
        collection_name=collection_name,
        top_k=top_k,
        user_id=user_id,
    )
    response = client.post("/query/ask", json=payload)
    if response.status_code != 404:
        return response
    return client.post("/query", json=payload)
def _get_audit_logs(
    client: httpx.Client,
    *,
    limit: int,
    offset: int,
    user_id: str | None = None,
    from_date: str | None = None,
    to_date: str | None = None,
) -> httpx.Response:
    """Fetch a page of audit events from ``GET /audit/logs``, retrying on ``GET /audit`` after a 404.

    ``limit`` and ``offset`` are always sent. The optional filters are only
    included when they are truthy (``None`` and ``""`` are omitted). The
    response is returned without raising on error statuses.
    """
    optional_filters = (
        ("user_id", user_id),
        ("from_date", from_date),
        ("to_date", to_date),
    )
    query: dict[str, object] = {"limit": limit, "offset": offset}
    query.update({name: value for name, value in optional_filters if value})

    response = client.get("/audit/logs", params=query)
    if response.status_code != 404:
        return response
    return client.get("/audit", params=query)
def _get_audit_event_detail(client: httpx.Client, event_id: str) -> httpx.Response:
    """Fetch one audit event from ``GET /audit/logs/{id}``, retrying on ``GET /audit/{id}`` after a 404.

    ``event_id`` can be typed by the user, so it is percent-encoded as a single
    path segment. Without that, characters such as ``/``, ``?``, ``#`` or a
    ``..`` sequence would change which endpoint is requested. The response is
    returned without raising on error statuses.
    """
    segment = quote(event_id, safe="")
    response = client.get(f"/audit/logs/{segment}")
    if response.status_code == 404:
        response = client.get(f"/audit/{segment}")
    return response
def _health_check() -> tuple[bool, str]:
    """Probe ``GET /health`` on the backend.

    Returns ``(True, <stringified JSON body>)`` on success, or
    ``(False, <error message>)`` for any failure. This function never raises.
    """
    try:
        with _client() as client:
            response = client.get("/health")
            response.raise_for_status()
            message = str(response.json())
    except httpx.HTTPStatusError as exc:
        return False, _fmt_api_error(exc)
    except httpx.RequestError as exc:
        return False, _fmt_request_error(exc)
    except Exception as exc:
        # E.g. a 2xx response whose body is not valid JSON.
        return False, str(exc)
    return True, message
@contextlib.contextmanager
def _report_errors() -> Iterator[None]:
    """Show any exception raised inside the ``with`` body on the page instead of propagating it.

    HTTP status errors and transport errors are rendered with the dedicated
    formatters; any other ``Exception`` is shown with its traceback via
    ``st.exception``. The exception is suppressed so the rest of the page keeps
    rendering.
    """
    try:
        yield
    except httpx.HTTPStatusError as exc:
        st.error(_fmt_api_error(exc))
    except httpx.RequestError as exc:
        st.error(_fmt_request_error(exc))
    except Exception as exc:  # UI boundary: surface everything else to the user.
        st.exception(exc)


def _path_segment(value: str) -> str:
    """Strip and percent-encode user input so it stays a single URL path segment."""
    return quote(value.strip(), safe="")


def _show_queued(out: dict[str, Any]) -> None:
    """Display an ingest response and remember its job id for the Jobs tab."""
    st.success(out.get("message", "Queued"))
    st.json(out)
    if out.get("job_id"):
        st.session_state["last_job_id"] = out["job_id"]


def _render_sidebar() -> None:
    """Render the backend settings: API base URL, timeout info, and a connection test."""
    with st.sidebar:
        st.subheader("Backend")
        st.text_input(
            "API base URL",
            key="api_base",
            placeholder=DEFAULT_API_BASE,
            help=f"Default: {DEFAULT_API_BASE}. Clear the field to use the default.",
        )
        st.caption(
            f"Ask/Summarise wait up to **{int(_http_read_timeout_seconds())}s** per request "
            # The en dash separates the bounds; without it they render fused (e.g. "607200").
            f"(env `DOC_AUDI_HTTP_READ_TIMEOUT`, range "
            f"{int(_HTTP_READ_TIMEOUT_MIN_S)}–{int(_HTTP_READ_TIMEOUT_MAX_S)} s)."
        )
        if st.button("Test connection"):
            ok, msg = _health_check()
            if ok:
                st.success(msg)
            else:
                st.error(msg)


def _render_upload_tab() -> None:
    """Render the Upload tab: file upload, URL ingest, and collection listing/deletion."""
    st.subheader("Upload document")
    col_u1, col_u2 = st.columns(2)
    with col_u1:
        up_collection = st.text_input("Collection", value="default", key="up_col")
        uploaded = st.file_uploader("PDF, TXT, or Markdown", type=["pdf", "txt", "md"], key="up_file")
    with col_u2:
        if st.button("Submit upload", key="btn_upload", disabled=uploaded is None):
            if uploaded is None:
                st.warning("Choose a file first.")
            else:
                with _report_errors():
                    files = {
                        "files": (
                            uploaded.name,
                            uploaded.getvalue(),
                            uploaded.type or "application/octet-stream",
                        )
                    }
                    data = {"collection_name": up_collection}
                    with _client() as c:
                        r = c.post("/ingest/upload", files=files, data=data)
                        r.raise_for_status()
                        out = r.json()
                    _show_queued(out)

    st.subheader("Ingest from URL")
    url_col = st.columns([3, 1])
    with url_col[0]:
        ingest_url = st.text_input("Document URL (http/https)", key="ingest_url")
    with url_col[1]:
        url_collection = st.text_input("Collection", value="default", key="url_col")
    if st.button("Queue URL ingest", key="btn_url"):
        if not ingest_url.strip():
            st.warning("Enter a URL.")
        else:
            with _report_errors():
                with _client() as c:
                    r = c.post(
                        "/ingest/url",
                        json={"urls": [ingest_url.strip()], "collection_name": url_collection},
                    )
                    r.raise_for_status()
                    out = r.json()
                _show_queued(out)

    st.subheader("Collections")
    if st.button("Refresh collections", key="btn_collections"):
        with _report_errors():
            with _client() as c:
                r = c.get("/ingest/collections")
                r.raise_for_status()
                cols = r.json()
            rows = cols.get("collections", [])
            st.write(f"{cols.get('total', len(rows))} collection(s).")
            if rows:
                st.dataframe(rows, hide_index=True, use_container_width=True)
            else:
                st.info("No collections yet.")

    del_name = st.text_input("Delete collection name (optional)", key="del_col")
    if st.button("Delete collection", key="btn_del_col"):
        if not del_name.strip():
            st.warning("Enter a collection name.")
        else:
            with _report_errors():
                with _client() as c:
                    # Encode the name so "/", "?" or ".." cannot redirect the DELETE.
                    r = c.delete(f"/ingest/collection/{_path_segment(del_name)}")
                    r.raise_for_status()
                    del_body = r.json()
                st.success(del_body.get("message", "Deleted"))
                if "documents_removed" in del_body:
                    st.caption(f"Documents removed: **{del_body['documents_removed']}**")


def _poll_job(job_id: str) -> None:
    """Poll ``GET /jobs/{id}`` once per second, up to 120 times, until the job completes or fails.

    The latest status is written to a single placeholder; the final job body
    is rendered as JSON either way.
    """
    attempts = 120
    job_path = f"/jobs/{_path_segment(job_id)}"
    status_ph = st.empty()
    with _report_errors():
        with _client() as c:
            for i in range(attempts):
                r = c.get(job_path)
                r.raise_for_status()
                body = r.json()
                status = body.get("status", "")
                status_ph.write(f"Poll {i + 1}: **{status}** — {body.get('progress_percent', 0)}%")
                if status in ("completed", "failed"):
                    st.json(body)
                    break
                time.sleep(1)
            else:
                # Loop ran out without reaching a terminal status.
                status_ph.write(f"Stopped after {attempts} attempts (~2 min).")
                st.json(body)


def _render_jobs_tab() -> None:
    """Render the Jobs tab: paged job list plus single-job fetch and polling."""
    st.subheader("Job list")
    j1, j2 = st.columns(2)
    with j1:
        j_limit = st.number_input("Limit", min_value=1, max_value=100, value=20, key="j_lim")
    with j2:
        j_offset = st.number_input("Offset", min_value=0, value=0, key="j_off")
    if st.button("List jobs", key="btn_jobs"):
        with _report_errors():
            with _client() as c:
                r = c.get("/jobs", params={"limit": int(j_limit), "offset": int(j_offset)})
                r.raise_for_status()
                payload = r.json()
            jobs: list[dict[str, Any]] = payload.get("jobs", [])
            st.caption(f"Total jobs (matching filters): **{payload.get('total', len(jobs))}**")
            if jobs:
                st.dataframe(jobs, hide_index=True, use_container_width=True)
            else:
                st.info("No jobs in this window.")

    st.subheader("Job detail")
    # Pre-fill with the job id of the most recent ingest from this session, if any.
    default_job = st.session_state.get("last_job_id", "")
    job_id = st.text_input("Job ID", value=default_job, key="job_id_in")
    c1, c2 = st.columns(2)
    with c1:
        fetch_job = st.button("Fetch job", key="btn_job_one")
    with c2:
        poll_job = st.button("Poll until completed/failed", key="btn_job_poll")

    if fetch_job and job_id.strip():
        with _report_errors():
            with _client() as c:
                r = c.get(f"/jobs/{_path_segment(job_id)}")
                r.raise_for_status()
                detail = r.json()
            st.json(detail)
    if poll_job and job_id.strip():
        _poll_job(job_id)


def _render_ask_tab() -> None:
    """Render the Ask tab: send a question for a collection and show answer, sources, raw JSON."""
    st.subheader("Ask a question")
    q_col = st.text_input("Collection", value="default", key="ask_col")
    question = st.text_area("Question", height=120, key="ask_q")
    if not st.button("Ask", key="btn_ask"):
        return
    if not question.strip():
        st.warning("Enter a question.")
        return
    with _report_errors():
        with st.spinner(
            "Calling the API (embeddings + LLM can take several minutes on a slow machine; "
            "ensure Ollama is running). Timeout is controlled by DOC_AUDI_HTTP_READ_TIMEOUT…"
        ):
            with _client() as c:
                r = _post_query_ask(
                    c,
                    question=question,
                    collection_name=q_col,
                )
                r.raise_for_status()
                ans = r.json()
        st.success(f"Query id: `{ans.get('query_id', '')}`")
        if ans.get("answer"):
            st.markdown("### Answer")
            st.markdown(ans["answer"])
        else:
            st.warning(
                "The API returned no **answer** text. "
                "Check the collection has ingested chunks, LLM env, and expand **Raw response** below."
            )
        src = ans.get("sources") or []
        if src:
            with st.expander(f"Sources ({len(src)})"):
                st.json(src)
        else:
            st.caption("No sources in this response (empty retrieval or model returned nothing).")
        with st.expander("Raw response (debug)"):
            st.json(ans)


def _render_summarise_tab() -> None:
    """Render the Summarise tab: request a collection summary with an optional focus."""
    st.subheader("Summarise collection")
    s_col = st.text_input("Collection", value="default", key="sum_col")
    focus = st.text_input("Optional focus / angle", value="", key="sum_focus")
    if not st.button("Summarise", key="btn_sum"):
        return
    with _report_errors():
        body: dict[str, Any] = {"collection_name": s_col}
        if focus.strip():
            body["focus"] = focus.strip()
        with st.spinner("Calling summarise (can take 1–2 minutes on a cold model)…"):
            with _client() as c:
                r = c.post("/query/summarise", json=body)
                r.raise_for_status()
                ans = r.json()
        st.success(f"Query id: `{ans.get('query_id', '')}` · documents: **{ans.get('document_count', '')}**")
        # Accept either key for the summary text.
        summary_text = ans.get("summary") or ans.get("answer")
        if summary_text:
            st.markdown("### Summary")
            st.markdown(summary_text)
        else:
            st.warning("No summary text in the response; see **Raw response** below.")
        src = ans.get("sources") or []
        if src:
            with st.expander(f"Sources ({len(src)})"):
                st.json(src)
        with st.expander("Raw response (debug)"):
            st.json(ans)


def _render_audit_tab() -> None:
    """Render the Audit tab: paged event list and single-event detail lookup."""
    st.subheader("Audit log")
    a1, a2 = st.columns(2)
    with a1:
        a_limit = st.number_input("Limit", min_value=1, max_value=100, value=20, key="a_lim")
    with a2:
        a_offset = st.number_input("Offset", min_value=0, value=0, key="a_off")
    if st.button("List audit events", key="btn_audit_list"):
        with _report_errors():
            with _client() as c:
                r = _get_audit_logs(
                    c,
                    limit=int(a_limit),
                    offset=int(a_offset),
                )
                r.raise_for_status()
                payload = r.json()
            # The list may arrive under "logs" or "events".
            events = payload.get("logs", payload.get("events", []))
            st.caption(f"Total matching: **{payload.get('total', len(events))}**")
            if events:
                st.dataframe(events, hide_index=True, use_container_width=True)
                ids = [
                    e.get("query_id") or e.get("event_id")
                    for e in events
                    if isinstance(e, dict) and (e.get("query_id") or e.get("event_id"))
                ]
                if ids:
                    # Kept across reruns so the selectbox below can offer them.
                    st.session_state["_audit_ids"] = ids
            else:
                st.info("No audit events.")

    st.subheader("Audit event detail")
    ids_for_select = st.session_state.get("_audit_ids", [])
    pick = ""
    if ids_for_select:
        pick = st.selectbox("Event ID", options=[""] + list(ids_for_select), key="audit_pick")
    manual_id = st.text_input("Or enter query / event ID", key="audit_manual")
    # A manually typed id takes precedence over the selectbox choice.
    ev_id = manual_id.strip() or (pick or "").strip()
    if st.button("Load detail", key="btn_audit_detail") and ev_id:
        with _report_errors():
            with _client() as c:
                r = _get_audit_event_detail(c, ev_id)
                r.raise_for_status()
                detail = r.json()
            st.json(detail)


def main() -> None:
    """Build the Streamlit page: header, backend sidebar, and the five feature tabs.

    All backend access goes through the HTTP helpers in this module; request
    failures are shown in the page rather than raised.
    """
    st.set_page_config(page_title="doc-audi-ai", layout="wide")
    # Seed the sidebar field once per session; later reruns keep the user's value.
    if "api_base" not in st.session_state:
        st.session_state.api_base = DEFAULT_API_BASE
    st.title("doc-audi-ai")
    st.caption("Ingest, query, and audit via the FastAPI backend.")
    st.caption(f"Requests go to: `{_api_base()}`")

    _render_sidebar()

    tab_upload, tab_jobs, tab_ask, tab_sum, tab_audit = st.tabs(
        ["Upload", "Jobs", "Ask", "Summarise", "Audit"]
    )
    with tab_upload:
        _render_upload_tab()
    with tab_jobs:
        _render_jobs_tab()
    with tab_ask:
        _render_ask_tab()
    with tab_sum:
        _render_summarise_tab()
    with tab_audit:
        _render_audit_tab()
# Build the UI only when this file is executed as the main script, not when it is imported.
if __name__ == "__main__":
    main()