| import os, io, base64, json, time |
| from typing import Optional, Tuple |
| import requests |
| from PIL import Image, ImageOps, ImageDraw |
| import streamlit as st |
| from streamlit_mic_recorder import mic_recorder, speech_to_text |
| from gtts import gTTS |
|
|
| |
| |
| |
# Base URL of the backend API. Trailing slashes are removed so that request
# paths (which start with "/") can be appended directly.
BACKEND = os.environ.get("BACKEND_URL", "http://localhost:8000").rstrip("/")

# Timeout, in seconds, used for POST requests to the backend.
TIMEOUT = 180
|
|
# Page-level Streamlit settings: browser tab title and a centered layout.
st.set_page_config(
    page_title="Realtime BI Assistant (Frontend)",
    layout="centered",
)
|
|
| |
| |
| |
# Hugging Face token read from the server environment. A non-empty value typed
# into the sidebar replaces it for the current script run.
HF_TOKEN = os.getenv("HF_TOKEN")

with st.sidebar:
    st.subheader("Auth / Tokens")
    # The field starts empty on purpose: pre-filling it with the environment
    # token would send that secret to the browser as the widget's default
    # value (type="password" only masks what is displayed). Leaving the field
    # empty keeps the environment token in effect.
    _tok = st.text_input(
        "HF Token (optional)",
        value="",
        type="password",
        help="Used by backend to call Hugging Face APIs.",
    )
    if _tok.strip():
        HF_TOKEN = _tok.strip()
|
|
def _headers():
    """Return the HTTP headers used for backend requests.

    Always asks for JSON; adds a bearer ``Authorization`` header when
    ``HF_TOKEN`` is set.
    """
    headers = {"Accept": "application/json"}
    if not HF_TOKEN:
        return headers
    headers["Authorization"] = f"Bearer {HF_TOKEN}"
    return headers
|
|
def ping_backend() -> Tuple[bool, Optional[dict]]:
    """Probe the backend root endpoint.

    Returns:
        A ``(ok, info)`` pair. ``ok`` is True when the backend answered with a
        successful status. ``info`` is the decoded JSON body on success (or
        None when the body is not JSON), and ``{"error": ...}`` describing the
        failure otherwise.
    """
    try:
        r = requests.get(f"{BACKEND}/", timeout=10, headers=_headers())
    except Exception as e:
        # Health probe run on every page load: never let it take the page down.
        return False, {"error": str(e)}

    if not r.ok:
        return False, {"error": f"HTTP {r.status_code}"}

    try:
        return True, r.json()
    except ValueError:
        # The backend is reachable but its root page is not JSON; that is
        # still "online", there is just nothing to report.
        return True, None
|
|
def post_json(path: str, payload: dict) -> requests.Response:
    """POST ``payload`` as a JSON body to ``BACKEND`` + ``path``.

    Uses the shared headers and the module-wide ``TIMEOUT``; the response is
    returned as-is without status checking.
    """
    url = f"{BACKEND}{path}"
    return requests.post(url, json=payload, headers=_headers(), timeout=TIMEOUT)
|
|
def post_multipart(path: str, files: dict, params: dict) -> requests.Response:
    """POST ``files`` as multipart form data to ``BACKEND`` + ``path``.

    ``params`` are sent as query-string parameters. Uses the shared headers
    and the module-wide ``TIMEOUT``; the response is returned unchecked.
    """
    url = f"{BACKEND}{path}"
    return requests.post(
        url,
        files=files,
        params=params,
        headers=_headers(),
        timeout=TIMEOUT,
    )
|
|
def pil_from_upload(file) -> Optional[Image.Image]:
    """Open an uploaded file as an RGB PIL image.

    Returns None when the file cannot be opened or converted.
    """
    try:
        opened = Image.open(file)
        return opened.convert("RGB")
    except Exception:
        return None
|
|
def compress_and_b64(img: Image.Image, max_side: int = 1280, quality: int = 85):
    """Downscale an image, JPEG-encode it and return it base64-encoded.

    The EXIF orientation is applied first. If the longer side exceeds
    ``max_side`` the image is scaled down proportionally so that side equals
    ``max_side``; smaller images are left at their size.

    Args:
        img: Source image.
        max_side: Maximum length, in pixels, of the longer side.
        quality: JPEG quality passed to the encoder.

    Returns:
        A tuple ``(b64, img_proc, original_size, processed_size)`` where
        ``b64`` is the base64 text of the JPEG bytes, ``img_proc`` is the
        image that was encoded, and both sizes are ``(width, height)``.
    """
    img = ImageOps.exif_transpose(img)
    w0, h0 = img.size
    longest = max(w0, h0)
    scale = longest / max_side if longest > max_side else 1.0

    if scale > 1.0:
        # Clamp to 1 px: on very elongated images the short side would
        # otherwise truncate to 0 and make resize() fail.
        new_size = (max(1, int(w0 / scale)), max(1, int(h0 / scale)))
        img_proc = img.resize(new_size)
    else:
        img_proc = img

    buf = io.BytesIO()
    try:
        img_proc.save(buf, format="JPEG", quality=quality, optimize=True)
    except OSError:
        # Modes such as RGBA or P cannot be written as JPEG; retry as RGB
        # into a fresh buffer.
        img_proc = img_proc.convert("RGB")
        buf = io.BytesIO()
        img_proc.save(buf, format="JPEG", quality=quality, optimize=True)

    b64 = base64.b64encode(buf.getvalue()).decode()
    return b64, img_proc, (w0, h0), img_proc.size
|
|
def draw_bbox(img: Image.Image, bbox: list[int], color=(0, 255, 0), width: int = 4) -> Image.Image:
    """Return a copy of ``img`` with one rectangle outlined on it.

    ``bbox`` is unpacked as ``x1, y1, x2, y2``. The input image is not
    modified.
    """
    annotated = img.copy()
    canvas = ImageDraw.Draw(annotated)
    left, top, right, bottom = bbox
    canvas.rectangle([left, top, right, bottom], outline=color, width=width)
    return annotated
|
|
def tts_gtts_bytes(text: str, lang: str = "en", tld: str = "com", slow: bool = False) -> bytes:
    """Synthesize ``text`` with gTTS and return the audio as bytes.

    The audio is written to an in-memory buffer rather than to disk.
    """
    speech = gTTS(text=text, lang=lang, tld=tld, slow=slow)
    audio = io.BytesIO()
    speech.write_to_fp(audio)
    return audio.getvalue()
|
|
| |
# Conversation history for this session: a list of {"role", "text"} dicts.
if "chat" not in st.session_state:
    st.session_state["chat"] = []
|
|
def add_chat(role: str, text: str):
    """Append one message to the session's chat history."""
    message = {"role": role, "text": text}
    st.session_state.chat.append(message)
|
|
def render_chat_transcript():
    """Render the most recent 100 chat messages under a heading.

    Messages whose role is not "user" are shown as assistant messages.
    """
    st.subheader("🗨️ Conversation")
    recent = st.session_state.chat[-100:]
    for message in recent:
        speaker = "assistant" if message["role"] != "user" else "user"
        with st.chat_message(speaker):
            st.markdown(message["text"])
|
|
| |
| |
| |
def render_examples_buttons(key_prefix: str = "main"):
    """Render example questions as buttons laid out in two columns.

    Clicking a button stores its question in ``st.session_state["q_text"]``
    and reruns the script. ``key_prefix`` namespaces the button keys.
    """
    examples = (
        "What is total sales (revenue) of Ramesh?",
        "Revenue for BLR on 2025-09-06",
        "Monthly revenue for Electronics in BLR for 2025-09",
        "Top 5 SKUs by revenue in HYD on 2025-09-06 (include category)",
        "Ramesh's total sales in NCR on 2025-09-06",
    )
    columns = st.columns(2)
    for idx, question in enumerate(examples):
        # Alternate between the left and right column.
        target = columns[idx % 2]
        clicked = target.button(question, key=f"{key_prefix}_ex_{idx}")
        if clicked:
            st.session_state["q_text"] = question
            st.rerun()
|
|
|
|
def render_bi_question_section(section_heading=True, key_prefix: str = "main"):
    """Render the BI question form and, on "Ask", show the backend's answer.

    The form has example buttons, a question box (pre-filled from
    ``st.session_state["q_text"]`` when present) and an optional JSON
    "visual context" box. On submit the question is posted to ``/query``;
    the answer text is shown and stored in
    ``st.session_state["last_answer_text"]``, followed by any SQL citations,
    metrics, chart references and uncertainty found in the response.

    Args:
        section_heading: Whether to render the section's subheader.
        key_prefix: Prefix for widget keys, so the section can be rendered
            more than once on a page.
    """
    if section_heading:
        st.subheader("3) Ask a BI question")

    with st.expander("Examples", expanded=False):
        render_examples_buttons(key_prefix=key_prefix)

    default_q = st.session_state.get("q_text", "What is total sales (revenue) of Ramesh?")
    q_text = st.text_area("Your question", value=default_q, height=100,
                          key=f"{key_prefix}_q_textarea")

    with st.expander("Optional: visual context (JSON)", expanded=False):
        vis_str = st.text_area("visual_ctx", value="{}", height=80,
                               key=f"{key_prefix}_vis_text")

    try:
        visual_ctx = json.loads(vis_str) if vis_str.strip() else {}
    except (ValueError, RecursionError):
        visual_ctx = {}
        st.warning("`visual_ctx` is not valid JSON; ignored.")

    if not st.button("Ask", key=f"{key_prefix}_ask"):
        return

    question = q_text.strip()
    if not question:
        # Nothing to ask; avoid a pointless round trip to the backend.
        st.warning("Please enter a question first.")
        return

    payload = {
        "user_id": st.session_state.get("user_name") or None,
        "text": question,
        "visual_ctx": visual_ctx,
    }

    # UI boundary: any failure below is reported in the page, not raised.
    try:
        with st.spinner("Querying…"):
            r = post_json("/query", payload)

        if not r.ok:
            try:
                err = r.json()
            except Exception:
                err = {"detail": r.text}
            # Error bodies are not guaranteed to be JSON objects.
            detail = err.get("detail") if isinstance(err, dict) else err
            st.error(f"Backend error {r.status_code}: {detail}")
            if "SQLGenTool disabled" in str(detail):
                st.info("Add your Hugging Face token in the sidebar (or set the HF_TOKEN env var).")
            return

        resp = r.json()
        # A null answer is stored as "" so later readers can call str methods.
        answer = str(resp.get("answer_text") or "")
        st.success(answer)
        st.session_state["last_answer_text"] = answer

        # Citations of the form "sql:<statement>" carry the SQL that was run.
        sqls = [c[4:] for c in (resp.get("citations") or [])
                if isinstance(c, str) and c.startswith("sql:")]
        if sqls:
            with st.expander("SQL used", expanded=True):
                for statement in sqls:
                    st.code(statement, language="sql")

        if resp.get("metrics"):
            with st.expander("Metrics", expanded=False):
                st.json(resp["metrics"])
        if resp.get("chart_refs"):
            with st.expander("Charts", expanded=False):
                st.json(resp["chart_refs"])

        uncertainty = resp.get("uncertainty")
        # Only numbers can take the ".2f" format; skip null or textual values.
        if isinstance(uncertainty, (int, float)):
            st.caption(f"Uncertainty: {uncertainty:.2f}")
    except Exception as e:
        st.error(f"Request error: {e}")
|
|
def render_voice_to_text():
    """Render the speech-to-text controls and a raw audio recorder.

    Recognized speech is stored in ``st.session_state["q_text"]`` and the
    script is rerun. The script is also rerun when the selected STT language
    differs from the one remembered in ``st.session_state["prev_stt_lang"]``.
    """
    state = st.session_state

    st.caption("Voice → Text (browser STT)")
    hint_col, lang_col = st.columns([2, 1])
    with hint_col:
        st.write("Click to speak; recognized text will fill the question box.")
    with lang_col:
        stt_lang = st.selectbox("STT language", ["en", "hi"], index=0, key="stt_lang_dd")

    # Remember the language; rerun only when it differs from a previously
    # remembered one (never on the first render).
    language_changed = "prev_stt_lang" in state and state["prev_stt_lang"] != stt_lang
    state["prev_stt_lang"] = stt_lang
    if language_changed:
        st.rerun()

    recognized = speech_to_text(
        language=state.get("stt_lang_dd", "en"),
        use_container_width=True,
        just_once=True,
        start_prompt="🎙️ Start recording",
        stop_prompt="⏹️ Stop recording",
        key="stt_main_btn",
    )
    if recognized:
        state["q_text"] = recognized
        st.success(f"Recognized: {recognized}")
        st.rerun()

    st.markdown("---")
    st.caption("Optional: record & play raw audio (no transcription)")
    recording = mic_recorder(
        start_prompt="🎙️ Start",
        stop_prompt="⏹️ Stop",
        just_once=True,
        key="mic_raw_btn",
    )
    audio_bytes = recording.get("bytes") if recording else None
    if audio_bytes:
        st.audio(audio_bytes, format="audio/wav")
|
|
def render_tts_controls():
    """Render text-to-speech settings and a button that speaks the last answer.

    The language and accent pickers use the widget keys ``tts_lang_dd`` and
    ``tts_tld_dd``. The text spoken is ``st.session_state["last_answer_text"]``;
    a warning is shown when there is no answer yet.
    """
    st.markdown("---")
    st.caption("Text → Voice (gTTS)")
    tts_lang = st.selectbox("TTS language", ["en", "hi"], index=0, key="tts_lang_dd")
    tld_label = st.selectbox(
        "Accent / region (tld)",
        ["Default (.com)", "India (.co.in)", "US (.us)", "UK (.co.uk)"],
        index=1,
        key="tts_tld_dd",
    )
    # Maps the labels shown in the picker to the tld value passed to gTTS.
    tld_map = {
        "Default (.com)": "com",
        "India (.co.in)": "co.in",
        "US (.us)": "us",
        "UK (.co.uk)": "co.uk",
    }

    if st.button("🔊 Speak last answer", key="tts_speak_btn"):
        # The stored answer may be None or a non-string; normalize before
        # calling str methods so the button never raises.
        ans = str(st.session_state.get("last_answer_text") or "")
        if not ans.strip():
            st.warning("Ask a question first to generate an answer.")
        else:
            try:
                mp3 = tts_gtts_bytes(ans, lang=tts_lang, tld=tld_map[tld_label], slow=False)
                st.audio(mp3, format="audio/mp3")
            except Exception as e:
                st.error(f"TTS error: {e}")
|
|
| |
| |
| |
# Page header followed by a live backend status row.
st.title("Realtime BI Assistant")
st.caption("Face upsert/identify + BI Q&A (text) via your FastAPI backend")

ok, info = ping_backend()
status_col, url_col = st.columns([1, 3])
with status_col:
    backend_state = "Online ✅" if ok else "Offline ❌"
    st.metric("Backend", backend_state)
with url_col:
    st.code(BACKEND, language="text")

# Show the probe's details only when it failed and reported something.
if info and not ok:
    st.warning(f"Backend unreachable: {info}")
|
|
| |
# Default user id sent with BI queries until something else sets it.
if "user_name" not in st.session_state:
    st.session_state["user_name"] = "mohit"
|
|
| |
| |
| |
# Section 1: upload a ZIP archive and send it to the backend's /enroll_zip.
with st.expander("1) Bulk enroll via ZIP (Images/<UserName>/*)", expanded=False):
    zip_up = st.file_uploader("Upload ZIP", type=["zip"], key="zip_enroll")
    enroll_clicked = st.button("Enroll ZIP")

    if enroll_clicked and not zip_up:
        st.error("Please upload a ZIP.")
    elif enroll_clicked:
        try:
            with st.spinner("Uploading & enrolling…"):
                # Multipart field: (filename, content bytes, content type).
                files = {
                    "zipfile_upload": ("enroll.zip", zip_up.read(), "application/zip"),
                }
                r = post_multipart("/enroll_zip", files=files, params={})
            if not r.ok:
                st.error(f"Enrollment failed: {r.status_code}")
                st.text(r.text)
            else:
                st.success("Enrollment complete ✅")
                st.json(r.json())
        except Exception as e:
            st.error(f"Request error: {e}")
|
|
| |
| |
| |
# Section 2: pick a test image, either uploaded or captured from the camera.
with st.expander("2) Identify from image", expanded=False):
    col_u, col_c = st.columns(2)
    with col_u:
        test_upload = st.file_uploader(
            "Upload test image",
            type=["jpg", "jpeg", "png"],
            key="test_upload",
        )
    with col_c:
        test_cam = st.camera_input("Or capture from camera", key="test_cam")

# A camera capture takes precedence over an uploaded file.
if test_cam is not None:
    test_img = pil_from_upload(test_cam)
elif test_upload is not None:
    test_img = pil_from_upload(test_upload)
else:
    test_img = None
|
|
def encode_for_backend(img: Image.Image):
    """Compress ``img`` and return what will be sent plus its decoded image.

    Returns:
        ``(b64_str, sent_img)``: the base64 text of the compressed image and
        an RGB image decoded back from that text, i.e. the pixels the backend
        will receive.
    """
    encoded = compress_and_b64(img)

    # Accept either a bare value or a sequence whose first item is the data.
    payload = encoded[0] if isinstance(encoded, (tuple, list)) else encoded

    if isinstance(payload, (bytes, bytearray)):
        payload = payload.decode("utf-8")

    # Drop a "data:...," URI prefix if one is present.
    if isinstance(payload, str) and payload.startswith("data:"):
        payload = payload.split(",", 1)[1]

    image_bytes = base64.b64decode(payload.encode("utf-8"))
    sent_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    return payload, sent_img
|
|
def draw_many(img: Image.Image, dets: list[dict]) -> Image.Image:
    """Return a copy of ``img`` with every detection boxed and labelled.

    Each detection is read for ``bbox`` (four numbers), ``decision`` (label
    text, "Unknown" when absent) and ``best_score`` (0.0 when absent or not
    numeric). Detections that are not dicts or whose ``bbox`` cannot be read
    are skipped so that one bad entry does not lose the whole overlay.
    """
    out = img.copy()
    draw = ImageDraw.Draw(out)
    for d in dets or []:
        if not isinstance(d, dict):
            continue
        try:
            xa, ya, xb, yb = (int(v) for v in (d.get("bbox") or (0, 0, 0, 0)))
        except (TypeError, ValueError):
            continue  # wrong length or non-numeric coordinates

        # Order the corners: recent Pillow rejects boxes with x1 < x0 or y1 < y0.
        x1, x2 = sorted((xa, xb))
        y1, y2 = sorted((ya, yb))

        name = d.get("decision") or "Unknown"
        try:
            score = float(d.get("best_score") or 0.0)
        except (TypeError, ValueError):
            score = 0.0
        label = f"{name} ({score:.3f})"

        draw.rectangle([x1, y1, x2, y2], outline=(0, 255, 0), width=3)

        try:
            tb = draw.textbbox((x1, y1), label)
            tw, th = tb[2] - tb[0], tb[3] - tb[1]
        except Exception:
            # Best-effort size estimate when the text cannot be measured.
            tw, th = max(60, len(label) * 7), 14

        # Label strip sits just above the box, clamped to the image top. Its
        # bottom is derived from its top so the strip keeps full height (and
        # valid coordinates) even when the box touches the top edge.
        by1 = max(0, y1 - th - 6)
        by2 = by1 + th + 6
        draw.rectangle([x1, by1, x1 + tw + 6, by2], fill=(0, 0, 0))
        draw.text((x1 + 3, by1 + 2), label, fill=(0, 255, 0))
    return out
|
|
# Send the selected image to /identify_many and show the annotated result.
identify_clicked = st.button("Identify")

if identify_clicked and test_img is None:
    st.warning("Please provide an image first.")
elif identify_clicked:
    try:
        b64, sent_img = encode_for_backend(test_img)
        with st.spinner("Identifying…"):
            r = post_json("/identify_many", {"image_b64": b64, "top_k": 3})
        if r.ok:
            data = r.json()
            dets = data.get("detections", [])
            st.caption(f"Faces found: {len(dets)}")
            # Boxes are drawn on the image as sent, so coordinates line up.
            annotated = draw_many(sent_img, dets)
            st.image(annotated, use_container_width=True)
            if dets:
                with st.expander("Details"):
                    st.json(dets)
        else:
            st.error(f"Identify failed: {r.status_code}")
            st.text(r.text)
    except Exception as e:
        st.error(f"Request error: {e}")
|
|
| |
| |
| |
# Voice mode: speech-to-text, then the question form, then text-to-speech.
st.subheader("🎙️ Voice mode (optional)")
with st.expander("Speak your question / hear the answer", expanded=True):
    render_voice_to_text()

# Rendered outside the expander above: the question form opens expanders of
# its own, and Streamlit has historically rejected nested expanders.
render_bi_question_section(section_heading=False, key_prefix="voice")

render_tts_controls()
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
# Conversational loop: one spoken question -> backend answer -> spoken reply,
# with both sides appended to the chat transcript rendered at the end.
st.subheader("🗣️ Talk → Ask → Speak")
c_left, c_right = st.columns([2, 3])
with c_left:
    st.caption("Press to speak; we'll answer, speak back, and log the chat below.")
with c_right:
    # Reuse the TTS choices stored under the text-to-speech widget keys.
    tts_lang = st.session_state.get("tts_lang_dd", "en")
    tld_map = {"Default (.com)": "com", "India (.co.in)": "co.in", "US (.us)": "us", "UK (.co.uk)": "co.uk"}
    tld_label = st.session_state.get("tts_tld_dd", "India (.co.in)")

spoken = speech_to_text(
    language=st.session_state.get("stt_lang_dd", "en"),
    use_container_width=True,
    just_once=True,
    start_prompt="🎙️ Speak",
    stop_prompt="⏹️ Stop",
    key="stt_conv_btn",
)

user_text = spoken.strip() if spoken else ""
if user_text:
    add_chat("user", user_text)
    payload = {
        "user_id": st.session_state.get("user_name") or None,
        "text": user_text,
        "visual_ctx": {},
    }
    try:
        with st.spinner("Thinking…"):
            r = post_json("/query", payload)
        resp = r.json() if r.ok else None
    except Exception as e:
        # Network failures and undecodable replies are reported in the page
        # and the transcript, as the other sections do, instead of raising.
        failure = f"Request error: {e}"
        add_chat("assistant", failure)
        st.error(failure)
    else:
        if r.ok:
            # The answer may be missing or null; treat both as empty text.
            raw_answer = resp.get("answer_text") if isinstance(resp, dict) else None
            answer = str(raw_answer or "").strip()
            add_chat("assistant", answer or "_(no rows)_")
            st.session_state["last_answer_text"] = answer

            try:
                mp3 = tts_gtts_bytes(answer or "I have no rows to report.",
                                     lang=tts_lang,
                                     tld=tld_map.get(tld_label, "co.in"),
                                     slow=False)
                st.audio(mp3, format="audio/mp3")
            except Exception as e:
                st.error(f"TTS error: {e}")
        else:
            try:
                err = r.json()
            except Exception:
                err = {"detail": r.text}
            # Error bodies are not guaranteed to be JSON objects.
            detail = err.get("detail") if isinstance(err, dict) else err
            failure = f"Backend error {r.status_code}: {detail}"
            add_chat("assistant", failure)
            st.error(failure)

render_chat_transcript()
|
|
| |
| |
| |
| |
|
|