# Streamlit frontend for the Realtime BI Assistant (deployed as a Hugging Face Space).
import base64
import io
import json
import os
import time
from typing import Optional, Tuple

import requests
import streamlit as st
from gtts import gTTS
from PIL import Image, ImageDraw, ImageOps
from streamlit_mic_recorder import mic_recorder, speech_to_text
# -----------------------
# Config
# -----------------------
# Backend base URL; trailing slash stripped so f"{BACKEND}{path}" joins are unambiguous.
BACKEND = os.getenv("BACKEND_URL", "http://localhost:8000").rstrip("/")
TIMEOUT = 180  # seconds; generous because model-backed /query calls can be slow
st.set_page_config(page_title="Realtime BI Assistant (Frontend)", layout="centered")
# -----------------------
# Helpers
# -----------------------
# Token resolution order: sidebar input (if non-empty) overrides the HF_TOKEN env var.
HF_TOKEN = os.getenv("HF_TOKEN")
with st.sidebar:
    st.subheader("Auth / Tokens")
    _tok = st.text_input("HF Token (optional)", value=HF_TOKEN or "", type="password",
                         help="Used by backend to call Hugging Face APIs.")
    if _tok.strip():
        HF_TOKEN = _tok.strip()
def _headers() -> dict:
    """Build the common HTTP headers for backend calls.

    Adds a Bearer Authorization header only when an HF token is configured
    (env var or sidebar input).
    """
    h = {"Accept": "application/json"}
    if HF_TOKEN:
        h["Authorization"] = f"Bearer {HF_TOKEN}"
    return h
def ping_backend() -> Tuple[bool, Optional[dict]]:
    """Probe the backend root endpoint.

    Returns:
        (ok, payload): ``payload`` is the JSON body on success, ``None`` on a
        non-OK HTTP status, or an ``{"error": ...}`` dict when the request
        itself fails (connection error, timeout, bad JSON).
    """
    try:
        r = requests.get(f"{BACKEND}/", timeout=10, headers=_headers())
        return r.ok, (r.json() if r.ok else None)
    except Exception as e:
        return False, {"error": str(e)}
def post_json(path: str, payload: dict) -> requests.Response:
    """POST *payload* as JSON to the backend at *path* (e.g. "/query")."""
    return requests.post(f"{BACKEND}{path}", json=payload, timeout=TIMEOUT, headers=_headers())
def post_multipart(path: str, files: dict, params: dict) -> requests.Response:
    """POST a multipart upload (*files*) with query *params* to the backend."""
    return requests.post(f"{BACKEND}{path}", files=files, params=params, timeout=TIMEOUT, headers=_headers())
def pil_from_upload(file) -> Optional[Image.Image]:
    """Open a Streamlit upload/camera file as an RGB PIL image.

    Returns None when the payload cannot be decoded as an image.
    """
    try:
        return Image.open(file).convert("RGB")
    except Exception:
        return None
def compress_and_b64(img: Image.Image, max_side: int = 1280, quality: int = 85):
    """Downscale *img* (if its longest side exceeds *max_side*), JPEG-encode it,
    and return the result as base64.

    Returns:
        (b64, img_proc, orig_size, sent_size): the base64 JPEG string, the
        processed PIL image, and the (w, h) sizes before and after resizing.
    """
    # Respect EXIF orientation so portrait photos aren't sent sideways.
    img = ImageOps.exif_transpose(img)
    w0, h0 = img.size
    scale = max(w0, h0) / max_side if max(w0, h0) > max_side else 1.0
    img_proc = img.resize((int(w0 / scale), int(h0 / scale))) if scale > 1.0 else img
    buf = io.BytesIO()
    img_proc.save(buf, format="JPEG", quality=quality, optimize=True)
    b64 = base64.b64encode(buf.getvalue()).decode()
    return b64, img_proc, (w0, h0), img_proc.size
def draw_bbox(img: Image.Image, bbox: list[int], color=(0, 255, 0), width: int = 4) -> Image.Image:
    """Return a copy of *img* with one rectangle drawn from bbox = [x1, y1, x2, y2]."""
    out = img.copy()
    draw = ImageDraw.Draw(out)
    x1, y1, x2, y2 = bbox
    draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
    return out
def tts_gtts_bytes(text: str, lang: str = "en", tld: str = "com", slow: bool = False) -> bytes:
    """Synthesize *text* with gTTS and return the MP3 audio as raw bytes.

    Raises whatever gTTS raises on network/quota errors; callers wrap in try/except.
    """
    buf = io.BytesIO()
    gTTS(text=text, lang=lang, tld=tld, slow=slow).write_to_fp(buf)
    return buf.getvalue()
# --- Chat state helpers ---
if "chat" not in st.session_state:
    st.session_state.chat = []  # list of {"role": "user"|"assistant", "text": str}


def add_chat(role: str, text: str):
    """Append one turn ({"role", "text"}) to the session chat transcript."""
    st.session_state.chat.append({"role": role, "text": text})
def render_chat_transcript():
    """Render the running conversation as chat bubbles (capped at the last 100 turns)."""
    st.subheader("🗨️ Conversation")
    for m in st.session_state.chat[-100:]:  # show last 100 turns
        with st.chat_message("user" if m["role"] == "user" else "assistant"):
            st.markdown(m["text"])
# -----------------------
# Small UI renderers (so we can reorder cleanly)
# -----------------------
def render_examples_buttons(key_prefix: str = "main"):
    """Render the canned example questions as a two-column button grid.

    Clicking a button copies the example into the shared "q_text" session key
    and reruns so the question textarea picks it up. *key_prefix* keeps widget
    keys unique when this is rendered in more than one section.
    """
    cols = st.columns(2)
    examples = [
        "What is total sales (revenue) of Ramesh?",
        "Revenue for BLR on 2025-09-06",
        "Monthly revenue for Electronics in BLR for 2025-09",
        "Top 5 SKUs by revenue in HYD on 2025-09-06 (include category)",
        "Ramesh's total sales in NCR on 2025-09-06",
    ]
    for i, ex in enumerate(examples):
        if cols[i % 2].button(ex, key=f"{key_prefix}_ex_{i}"):
            st.session_state["q_text"] = ex
            st.rerun()
def render_bi_question_section(section_heading=True, key_prefix: str = "main"):
    """Render the BI question form and handle submission to the backend /query route.

    Shows example buttons, the question textarea, an optional visual-context
    JSON input, and — on "Ask" — the answer plus any SQL citations, metrics,
    chart refs, and uncertainty returned by the backend. *key_prefix* keeps
    widget keys unique across the main and voice sections.
    """
    if section_heading:
        st.subheader("3) Ask a BI question")
    with st.expander("Examples", expanded=False):
        render_examples_buttons(key_prefix=key_prefix)
    # Use a unique key for the textarea.
    default_q = st.session_state.get("q_text", "What is total sales (revenue) of Ramesh?")
    q_text = st.text_area("Your question", value=default_q, height=100,
                          key=f"{key_prefix}_q_textarea")
    with st.expander("Optional: visual context (JSON)", expanded=False):
        vis_str = st.text_area("visual_ctx", value="{}", height=80,
                               key=f"{key_prefix}_vis_text")
        try:
            visual_ctx = json.loads(vis_str) if vis_str.strip() else {}
        except Exception:
            # Bad JSON is non-fatal: warn and proceed with an empty context.
            visual_ctx = {}
            st.warning("`visual_ctx` is not valid JSON; ignored.")
    if st.button("Ask", key=f"{key_prefix}_ask"):
        payload = {
            "user_id": st.session_state.user_name or None,
            "text": q_text.strip(),
            "visual_ctx": visual_ctx,
        }
        try:
            with st.spinner("Querying…"):
                r = post_json("/query", payload)
            if r.ok:
                resp = r.json()
                answer = resp.get("answer_text", "")
                st.success(answer)
                st.session_state["last_answer_text"] = answer
                # Citations of the form "sql:<statement>" carry the generated SQL.
                sqls = [c[4:] for c in resp.get("citations", [])
                        if isinstance(c, str) and c.startswith("sql:")]
                if sqls:
                    with st.expander("SQL used", expanded=True):
                        st.code(sqls[0], language="sql")
                        for s in sqls[1:]:
                            st.code(s, language="sql")
                if resp.get("metrics"):
                    with st.expander("Metrics", expanded=False):
                        st.json(resp["metrics"])
                if resp.get("chart_refs"):
                    with st.expander("Charts", expanded=False):
                        st.json(resp["chart_refs"])
                if "uncertainty" in resp:
                    st.caption(f"Uncertainty: {resp['uncertainty']:.2f}")
            else:
                try:
                    err = r.json()
                except Exception:
                    err = {"detail": r.text}
                st.error(f"Backend error {r.status_code}: {err.get('detail')}")
                if "SQLGenTool disabled" in str(err.get("detail", "")):
                    st.info("Add your Hugging Face token in the sidebar (or set the HF_TOKEN env var).")
        except Exception as e:
            st.error(f"Request error: {e}")
def render_voice_to_text():
    """Render the voice-input UI.

    Browser STT fills the shared "q_text" session key (picked up by the
    question textarea on rerun); a second widget records and plays back raw
    audio without transcription.
    """
    st.caption("Voice → Text (browser STT)")
    c1, c2 = st.columns([2, 1])
    with c1:
        st.write("Click to speak; recognized text will fill the question box.")
    with c2:
        stt_lang = st.selectbox("STT language", ["en", "hi"], index=0, key="stt_lang_dd")
    # Force a rerun when the language changes so the STT widget re-initializes
    # with the new language.
    if "prev_stt_lang" not in st.session_state:
        st.session_state["prev_stt_lang"] = stt_lang
    elif st.session_state["prev_stt_lang"] != stt_lang:
        st.session_state["prev_stt_lang"] = stt_lang
        st.rerun()
    stt_text = speech_to_text(
        language=st.session_state.get("stt_lang_dd", "en"),
        use_container_width=True,
        just_once=True,
        start_prompt="🎙️ Start recording",
        stop_prompt="⏹️ Stop recording",
        key="stt_main_btn",
    )
    if stt_text:
        st.session_state["q_text"] = stt_text
        st.success(f"Recognized: {stt_text}")
        st.rerun()
    st.markdown("---")
    st.caption("Optional: record & play raw audio (no transcription)")
    rec = mic_recorder(
        start_prompt="🎙️ Start",
        stop_prompt="⏹️ Stop",
        just_once=True,
        key="mic_raw_btn",
    )
    if rec and rec.get("bytes"):
        st.audio(rec["bytes"], format="audio/wav")
def render_tts_controls():
    """Render the Text→Voice controls: language, accent (Google TLD), and a
    button that speaks the most recent answer via gTTS."""
    st.markdown("---")
    st.caption("Text → Voice (gTTS)")
    tts_lang = st.selectbox("TTS language", ["en", "hi"], index=0, key="tts_lang_dd")
    tld_label = st.selectbox(
        "Accent / region (tld)",
        ["Default (.com)", "India (.co.in)", "US (.us)", "UK (.co.uk)"],
        index=1,
        key="tts_tld_dd"
    )
    # Map the human-readable label to the Google domain suffix gTTS expects.
    tld_map = {
        "Default (.com)": "com",
        "India (.co.in)": "co.in",
        "US (.us)": "us",
        "UK (.co.uk)": "co.uk",
    }
    if st.button("🔊 Speak last answer", key="tts_speak_btn"):
        ans = st.session_state.get("last_answer_text", "")
        if not ans.strip():
            st.warning("Ask a question first to generate an answer.")
        else:
            try:
                mp3 = tts_gtts_bytes(ans, lang=tts_lang, tld=tld_map[tld_label], slow=False)
                st.audio(mp3, format="audio/mp3")
            except Exception as e:
                st.error(f"TTS error: {e}")
# -----------------------
# UI
# -----------------------
st.title("Realtime BI Assistant")
st.caption("Face upsert/identify + BI Q&A (text) via your FastAPI backend")
ok, info = ping_backend()
status_col, url_col = st.columns([1, 3])
with status_col:
    st.metric("Backend", "Online ✅" if ok else "Offline ❌")
with url_col:
    st.code(BACKEND, language="text")
if not ok and info:
    st.warning(f"Backend unreachable: {info}")
# Persist chosen user
if "user_name" not in st.session_state:
    st.session_state.user_name = "mohit"
# -----------------------
# 1) Bulk enroll via ZIP (Images/<UserName>/*)
# -----------------------
with st.expander("1) Bulk enroll via ZIP (Images/<UserName>/*)", expanded=False):
    zip_up = st.file_uploader("Upload ZIP", type=["zip"], key="zip_enroll")
    if st.button("Enroll ZIP"):
        if not zip_up:
            st.error("Please upload a ZIP.")
        else:
            try:
                with st.spinner("Uploading & enrolling…"):
                    files = {"zipfile_upload": ("enroll.zip", zip_up.read(), "application/zip")}
                    r = post_multipart("/enroll_zip", files=files, params={})
                if r.ok:
                    st.success("Enrollment complete ✅")
                    st.json(r.json())
                else:
                    st.error(f"Enrollment failed: {r.status_code}")
                    st.text(r.text)
            except Exception as e:
                st.error(f"Request error: {e}")
# -----------------------
# 2) Identify from image
# -----------------------
with st.expander("2) Identify from image", expanded=False):
    col_u, col_c = st.columns(2)
    with col_u:
        test_upload = st.file_uploader("Upload test image", type=["jpg", "jpeg", "png"], key="test_upload")
    with col_c:
        test_cam = st.camera_input("Or capture from camera", key="test_cam")
    # A camera capture takes precedence over an uploaded file when both exist.
    test_img = None
    if test_cam is not None:
        test_img = pil_from_upload(test_cam)
    elif test_upload is not None:
        test_img = pil_from_upload(test_upload)

    def encode_for_backend(img: Image.Image):
        """Compress *img* and return (base64_jpeg_str, decoded_image).

        The returned image is re-decoded from the exact bytes sent, so bounding
        boxes drawn on it line up with the coordinates the backend reports.
        """
        b64_out = compress_and_b64(img)
        # compress_and_b64 returns a tuple whose first element is the b64
        # string; tolerate a bare string/bytes return defensively.
        if isinstance(b64_out, (tuple, list)):
            b64_str = b64_out[0]
        else:
            b64_str = b64_out
        if isinstance(b64_str, (bytes, bytearray)):
            b64_str = b64_str.decode("utf-8")
        # Strip a data-URL prefix if one is ever present.
        if isinstance(b64_str, str) and b64_str.startswith("data:"):
            b64_str = b64_str.split(",", 1)[1]
        raw = base64.b64decode(b64_str.encode("utf-8"))
        sent_img = Image.open(io.BytesIO(raw)).convert("RGB")
        return b64_str, sent_img

    def draw_many(img: Image.Image, dets: list[dict]) -> Image.Image:
        """Return a copy of *img* with each detection's bbox and a
        "name (score)" label drawn on it."""
        out = img.copy()
        draw = ImageDraw.Draw(out)
        for d in dets:
            x1, y1, x2, y2 = [int(v) for v in d.get("bbox", [0, 0, 0, 0])]
            name = d.get("decision", "Unknown")
            score = float(d.get("best_score", 0.0))
            label = f"{name} ({score:.3f})"
            draw.rectangle([x1, y1, x2, y2], outline=(0, 255, 0), width=3)
            try:
                tb = draw.textbbox((x1, y1), label)
                tw, th = tb[2] - tb[0], tb[3] - tb[1]
            except Exception:
                # Older Pillow lacks textbbox; fall back to a rough estimate.
                tw, th = max(60, len(label) * 7), 14
            # Draw a filled label background just above the box (clamped to the top edge).
            by1 = max(0, y1 - th - 6)
            draw.rectangle([x1, by1, x1 + tw + 6, y1], fill=(0, 0, 0))
            draw.text((x1 + 3, by1 + 2), label, fill=(0, 255, 0))
        return out

    if st.button("Identify"):
        if test_img is None:
            st.warning("Please provide an image first.")
        else:
            try:
                b64, sent_img = encode_for_backend(test_img)
                with st.spinner("Identifying…"):
                    r = post_json("/identify_many", {"image_b64": b64, "top_k": 3})
                if not r.ok:
                    st.error(f"Identify failed: {r.status_code}")
                    st.text(r.text)
                else:
                    data = r.json()
                    dets = data.get("detections", [])
                    st.caption(f"Faces found: {len(dets)}")
                    st.image(draw_many(sent_img, dets), use_container_width=True)
                    if dets:
                        with st.expander("Details"):
                            st.json(dets)
            except Exception as e:
                st.error(f"Request error: {e}")
# -----------------------
# 2.5) Voice mode (frontend-only) with requested order
# -----------------------
st.subheader("🎙️ Voice mode (optional)")
with st.expander("Speak your question / hear the answer", expanded=True):
    # (1) Voice → Text first
    render_voice_to_text()
    # (2) Ask a BI question (same logic as main section)
    render_bi_question_section(section_heading=False, key_prefix="voice")
    # (3) Listen response (TTS button)
    render_tts_controls()
# -----------------------
# 2.6) Talk → Ask → Speak (voice chat with transcript)
# -----------------------
st.subheader("🗣️ Talk → Ask → Speak")
c_left, c_right = st.columns([2, 3])
with c_left:
    st.caption("Press to speak; we'll answer, speak back, and log the chat below.")
with c_right:
    # voice settings reuse your TTS controls' state if present; else defaults
    tts_lang = st.session_state.get("tts_lang_dd", "en")
    tld_map = {"Default (.com)": "com", "India (.co.in)": "co.in", "US (.us)": "us", "UK (.co.uk)": "co.uk"}
    tld_label = st.session_state.get("tts_tld_dd", "India (.co.in)")
# Mic widget (one utterance per click)
spoken = speech_to_text(
    language=st.session_state.get("stt_lang_dd", "en"),
    use_container_width=True,
    just_once=True,
    start_prompt="🎙️ Speak",
    stop_prompt="⏹️ Stop",
    key="stt_conv_btn",
)
if spoken:
    user_text = spoken.strip()
    if user_text:
        add_chat("user", user_text)
        payload = {"user_id": st.session_state.user_name or None, "text": user_text, "visual_ctx": {}}
        with st.spinner("Thinking…"):
            r = post_json("/query", payload)
        if r.ok:
            resp = r.json()
            answer = resp.get("answer_text", "").strip()
            add_chat("assistant", answer or "_(no rows)_")
            st.session_state["last_answer_text"] = answer
            # speak the answer
            try:
                mp3 = tts_gtts_bytes(answer or "I have no rows to report.",
                                     lang=tts_lang,
                                     tld=tld_map.get(tld_label, "co.in"),
                                     slow=False)
                st.audio(mp3, format="audio/mp3")
            except Exception as e:
                st.error(f"TTS error: {e}")
        else:
            try:
                err = r.json()
            except Exception:
                err = {"detail": r.text}
            add_chat("assistant", f"Backend error {r.status_code}: {err.get('detail')}")
            st.error(f"Backend error {r.status_code}: {err.get('detail')}")
# Show running transcript
render_chat_transcript()
# -----------------------
# 3) Ask a BI question (also kept as a main section for non-voice users)
# -----------------------
# render_bi_question_section(section_heading=True, key_prefix="main")