""" Cortex RAG β Streamlit UI (Phase 1) Tabs: π Ask β streaming Q&A with inline citations and source cards π₯ Ingest β upload documents or provide a directory path π©Ί System β health check and collection statistics """ from __future__ import annotations import json import time from pathlib import Path from typing import Optional import sys sys.path.append(str(Path(__file__).resolve().parent.parent)) from config import get_settings import requests import streamlit as st # ββ Config ββββββββββββββββββββββββββββββββββββββββββββββββββββ cfg = get_settings() API_BASE = f"http://{cfg.api_host}:{cfg.api_port}" REDIS_URL = cfg.redis_url st.set_page_config( page_title="Cortex RAG", page_icon="π§ ", layout="wide", initial_sidebar_state="expanded", ) # ββ Styles ββββββββββββββββββββββββββββββββββββββββββββββββββββ st.markdown(""" """, unsafe_allow_html=True) # ββ Session state defaults ββββββββββββββββββββββββββββββββββββ def _init_state(): defaults = { "messages": [], # list of {role, content, chunks} "ingest_log": [], "health": None, } for k, v in defaults.items(): if k not in st.session_state: st.session_state[k] = v _init_state() def _render_source_cards_raw(chunks: list[dict]): st.caption("**Retrieved passages**") cols = st.columns(min(len(chunks), 3)) for i, chunk in enumerate(chunks): with cols[i % len(cols)]: score_pct = int(chunk.get("score", 0) * 100) title = chunk.get("title", "Unknown") source = Path(chunk.get("source", "")).name snippet = chunk.get("text_snippet", "")[:160] retriever = chunk.get("retriever", "dense") retriever_colors = { "dense": "#dbeafe:#1e40af", "bm25": "#dcfce7:#166534", "dense+bm25": "#f3e8ff:#6b21a8", "bm25+dense": "#f3e8ff:#6b21a8", "graph": "#fef9c3:#854d0e", "web_search": "#fee2e2:#991b1b", } ret_style = retriever_colors.get(retriever, "#f3f4f6:#374151") ret_bg, ret_fg = ret_style.split(":") st.markdown(f"""
Cortex RAG
', unsafe_allow_html=True) st.markdown( 'Production-grade RAG Β· Phase 1: Dense retrieval + streaming generation
', unsafe_allow_html=True ) st.divider() tab_ask, tab_ingest, tab_eval, tab_system = st.tabs(["π Ask", "π₯ Ingest", "π Evaluation", "π©Ί System"]) # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ # TAB 1 β ASK # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ with tab_ask: # Replay conversation history for msg in st.session_state.messages: with st.chat_message(msg["role"]): st.markdown(msg["content"]) if msg.get("chunks"): _render_source_cards(msg["chunks"]) query = st.chat_input("Ask anything about your documentsβ¦") if query: # Append and display user message st.session_state.messages.append({"role": "user", "content": query}) with st.chat_message("user"): st.markdown(query) # Fetch streamed response with st.chat_message("assistant"): answer_placeholder = st.empty() sources_placeholder = st.empty() status_placeholder = st.empty() answer_text = "" retrieved_chunks = [] try: with requests.post( f"{API_BASE}/query/stream", json={"query": query, "top_k": top_k, "stream": True}, stream=True, timeout=300, ) as resp: resp.raise_for_status() for raw_line in resp.iter_lines(): if not raw_line: continue line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line if not line.startswith("data: "): continue payload = json.loads(line[6:]) event_type = payload.get("type") if event_type == "chunk_meta": retrieved_chunks = payload.get("chunks", []) routing = payload.get("routing", {}) intent = routing.get("intent", "") strategies = routing.get("strategies", []) hits = routing.get("retriever_hits", {}) hits_str = " Β· ".join(f"{k}: {v}" for k, v in hits.items()) strategy_str = " + ".join(s.upper() for s in strategies) status_placeholder.caption( f"π§ **{intent}** β {strategy_str} | π {len(retrieved_chunks)} passages | {hits_str}" ) elif event_type == "token": answer_text += payload.get("text", "") answer_placeholder.markdown(answer_text + "β") elif event_type == "sources": # Replace cursor and append sources answer_placeholder.markdown(answer_text) sources_placeholder.markdown(payload.get("text", "")) status_placeholder.empty() elif event_type == "crag_update": grade = payload.get("grade", "") rewritten = payload.get("rewritten_query") web_used = payload.get("web_search_used", False) reasoning = payload.get("reasoning", "") icon = {"POOR": "π", "ABSENT": "π"}.get(grade, "βΉοΈ") msg = f"{icon} **CRAG {grade}**: {reasoning[:100]}" if rewritten: msg += " \n\u21a9 Rewritten: *" + rewritten + "*" if web_used: msg += " \n\U0001f310 Web search fallback used" status_placeholder.info(msg) elif event_type == "done": answer_placeholder.markdown(answer_text) status_placeholder.empty() break elif event_type == "error": st.error(f"API error: {payload.get('message')}") break except requests.exceptions.ConnectionError: st.error( "β οΈ Cannot reach the Cortex API. " "Make sure `uvicorn api.main:app` is running on port 8000." ) answer_text = "_Connection error β see above._" except Exception as exc: st.error(f"Unexpected error: {exc}") answer_text = "_Error β see above._" # Render source cards inline if retrieved_chunks: _render_source_cards_raw(retrieved_chunks) # Save to conversation history st.session_state.messages.append({ "role": "assistant", "content": answer_text, "chunks": retrieved_chunks, }) # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ # TAB 2 β INGEST # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ with tab_ingest: st.subheader("Ingest documents into the knowledge base") st.caption( "Supported formats: **PDF**, **HTML**, **TXT**, **Markdown**. " "Files are deduplicated automatically." ) col_left, col_right = st.columns([1, 1], gap="large") with col_left: st.markdown("#### Option A β Provide a server path") ingest_path = st.text_input( "Path on server", placeholder="data/documents or /abs/path/to/file.pdf", help="Relative or absolute path accessible by the API process.", ) recursive = st.checkbox("Recursive (include subdirectories)", value=True) if st.button("π Start ingestion", type="primary", disabled=not ingest_path): with st.spinner("Ingestingβ¦"): try: resp = requests.post( f"{API_BASE}/ingest", json={"path": ingest_path, "recursive": recursive}, timeout=300, ) resp.raise_for_status() result = resp.json() st.success( f"β {result['documents_processed']} documents processed, " f"{result['chunks_stored']} chunks stored." ) if result.get("documents_skipped"): st.info(f"βΉοΈ {result['documents_skipped']} documents already existed β skipped.") if result.get("errors"): st.warning(f"β οΈ {len(result['errors'])} errors:") for err in result["errors"]: st.code(json.dumps(err, indent=2)) st.session_state.ingest_log.append(result) except requests.exceptions.ConnectionError: st.error("Cannot reach the API. Is uvicorn running?") except Exception as exc: st.error(f"Ingestion failed: {exc}") with col_right: st.markdown("#### Ingestion log") if st.session_state.ingest_log: for i, entry in enumerate(reversed(st.session_state.ingest_log[-5:])): with st.expander(f"Run {len(st.session_state.ingest_log) - i}", expanded=(i==0)): st.json(entry) else: st.caption("No ingestion runs yet.") # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ # TAB 3 β EVALUATION DASHBOARD # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ with tab_eval: st.subheader("RAG evaluation dashboard") st.caption("Metrics update automatically after each query. RAGAS scores compute in the background (~5s after response).") if st.button("π Refresh metrics"): st.session_state.pop("metrics_data", None) if "metrics_data" not in st.session_state: try: resp = requests.get(f"{API_BASE}/metrics?limit=200&days=14", timeout=5) resp.raise_for_status() st.session_state.metrics_data = resp.json() except Exception as exc: st.session_state.metrics_data = {"error": str(exc)} mdata = st.session_state.get("metrics_data", {}) if "error" in mdata: st.error(f"Cannot reach API: {mdata['error']}") else: summary = mdata.get("summary", {}) cache = mdata.get("cache", {}) # ββ Header KPI row βββββββββββββββββββββββββββββββββββββ k1, k2, k3, k4, k5, k6 = st.columns(6) k1.metric("Total queries", summary.get("total_queries", 0)) k2.metric("Faithfulness", f"{summary.get('avg_faithfulness', 0):.2f}") k3.metric("Answer relevancy", f"{summary.get('avg_answer_relevancy', 0):.2f}") k4.metric("Context precision",f"{summary.get('avg_context_precision', 0):.2f}") k5.metric("Avg latency", f"{summary.get('avg_latency_ms', 0):.0f} ms") k6.metric("Cache hit rate", f"{cache.get('hit_rate', 0):.0%}" if cache.get('enabled') else "off") st.divider() # ββ Metric timeseries ββββββββββββββββββββββββββββββββββ ts = mdata.get("timeseries", []) if ts: import pandas as pd df_ts = pd.DataFrame(ts) df_ts["hour"] = df_ts["hour_bucket"] st.markdown("#### RAGAS metrics over time") st.line_chart( df_ts.set_index("hour")[["faithfulness", "answer_relevancy", "context_precision"]], height=220, ) else: st.info("No evaluation data yet. Run some queries to populate the dashboard.") st.divider() col_left, col_right = st.columns(2, gap="large") with col_left: # ββ CRAG grade distribution ββββββββββββββββββββββββ grade_dist = summary.get("crag_grade_dist", {}) if grade_dist: import pandas as pd st.markdown("#### CRAG grade distribution") df_grades = pd.DataFrame( list(grade_dist.items()), columns=["Grade", "Count"] ) st.bar_chart(df_grades.set_index("Grade"), height=180) # ββ Strategy distribution ββββββββββββββββββββββββββ strat_dist = summary.get("strategy_dist", {}) if strat_dist: import pandas as pd st.markdown("#### Retrieval strategy mix") rows = [] for strat_json, cnt in strat_dist.items(): try: import json as _json label = "+".join(_json.loads(strat_json)).upper() except Exception: label = strat_json rows.append({"Strategy": label, "Count": cnt}) df_strat = pd.DataFrame(rows) st.bar_chart(df_strat.set_index("Strategy"), height=180) with col_right: # ββ Cache stats ββββββββββββββββββββββββββββββββββββ st.markdown("#### Cache") if cache.get("enabled"): c1, c2 = st.columns(2) c1.metric("Hits", cache.get("hits", 0)) c2.metric("Misses", cache.get("misses", 0)) st.caption(f"TTL: {cache.get('ttl_s', 0)//60} min") if st.button("ποΈ Flush cache"): try: r = requests.post(f"{REDIS_URL}/cache/flush", timeout=5) st.success(f"Flushed {r.json().get('deleted', 0)} entries.") st.session_state.pop("metrics_data", None) except Exception as e: st.error(str(e)) else: st.caption("Redis not connected. Start Redis to enable caching.") st.code("docker run -d -p 6379:6379 redis:7-alpine", language="bash") st.divider() # ββ Recent query log table βββββββββββββββββββββββββββββ recent = mdata.get("recent", []) if recent: import pandas as pd st.markdown("#### Recent queries") rows = [] for r in recent[:50]: rows.append({ "Query": r.get("query", "")[:60], "Intent": r.get("intent", ""), "CRAG": r.get("crag_grade", ""), "Faithful": f"{r['faithfulness']:.2f}" if r.get("faithfulness") else "β", "Relevancy": f"{r['answer_relevancy']:.2f}" if r.get("answer_relevancy") else "β", "Precision": f"{r['context_precision']:.2f}" if r.get("context_precision") else "β", "Latency ms": f"{r.get('latency_ms', 0):.0f}", }) st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True) # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ # TAB 4 β SYSTEM HEALTH # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ with tab_system: st.subheader("System health") if st.button("π Refresh health"): st.session_state.health = None if st.session_state.health is None: try: resp = requests.get(f"{API_BASE}/health", timeout=5) resp.raise_for_status() st.session_state.health = resp.json() except Exception as exc: st.session_state.health = {"error": str(exc)} health = st.session_state.health if health: if "error" in health: st.error(f"Cannot reach API: {health['error']}") else: status = health.get("status", "unknown") icon = "β " if status == "ok" else "β οΈ" st.markdown(f"**Overall status**: {icon} `{status}`") col_a, col_b, col_c = st.columns(3) with col_a: milvus = health.get("milvus", "unknown") st.metric("Milvus", "β ok" if milvus == "ok" else f"β {milvus}") with col_b: embedder = health.get("embedder", "unknown") st.metric("Embedder", "β loaded" if embedder == "loaded" else "β³ not loaded") with col_c: stats = health.get("collection_stats", {}) st.metric("Chunks indexed", stats.get("entity_count", "β")) st.divider() graph_stats = health.get("graph_stats", {}) if graph_stats: col_d, col_e = st.columns(2) with col_d: st.metric("Graph nodes", graph_stats.get("nodes", "β")) with col_e: st.metric("Graph edges", graph_stats.get("edges", "β")) st.divider() st.markdown("**Raw health response**") st.json(health)