# MohitGupta41
# Final Commit
# 4390dd3
import os, io, base64, json, time
from typing import Optional, Tuple
import requests
from PIL import Image, ImageOps, ImageDraw
import streamlit as st
from streamlit_mic_recorder import mic_recorder, speech_to_text
from gtts import gTTS
# -----------------------
# Config
# -----------------------
# Base URL of the backend, read from BACKEND_URL (default http://localhost:8000)
# with any trailing "/" removed so paths can be appended directly.
BACKEND = os.getenv("BACKEND_URL", "http://localhost:8000").rstrip("/")
# Value passed as `timeout=` to the requests.post helpers in this file.
TIMEOUT = 180
# NOTE(review): Streamlit expects set_page_config to run before other st.* calls;
# keep it ahead of the sidebar code below — confirm for the installed version.
st.set_page_config(page_title="Realtime BI Assistant (Frontend)", layout="centered")
# -----------------------
# Helpers
# -----------------------
# Token seeded from the HF_TOKEN environment variable; a non-blank value typed
# into the sidebar field replaces it for this script run.
HF_TOKEN = os.getenv("HF_TOKEN")
with st.sidebar:
st.subheader("Auth / Tokens")
_tok = st.text_input("HF Token (optional)", value=HF_TOKEN or "", type="password", help="Used by backend to call Hugging Face APIs.")
# Whitespace-only input is ignored, so the environment value (if any) is kept.
if _tok.strip():
HF_TOKEN = _tok.strip()
def _headers():
    """Build the HTTP headers sent with every backend request.

    JSON is always requested via ``Accept``. When the module-level
    ``HF_TOKEN`` is truthy it is forwarded as a bearer token in
    ``Authorization``; otherwise that header is omitted.
    """
    bearer = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    return {"Accept": "application/json", **bearer}
def ping_backend() -> Tuple[bool, Optional[dict]]:
    """Probe the backend root URL and report whether it is reachable.

    Returns:
        ``(ok, info)`` where ``ok`` is True when the backend answered with a
        successful status. ``info`` is the decoded JSON body on success (or
        ``None`` when the body is not JSON), and ``{"error": ...}`` describing
        the failure otherwise, so callers always have something to display
        when ``ok`` is False.
    """
    try:
        r = requests.get(f"{BACKEND}/", timeout=10, headers=_headers())
    except Exception as e:
        # Health probe boundary: any transport failure means "offline".
        return False, {"error": str(e)}

    if not r.ok:
        # Surface the status code instead of returning no diagnostic at all.
        return False, {"error": f"HTTP {r.status_code}"}

    try:
        return True, r.json()
    except ValueError:
        # The server answered successfully but not with JSON; it is still up.
        return True, None
def post_json(path: str, payload: dict) -> requests.Response:
    """POST ``payload`` as a JSON body to ``BACKEND`` + ``path``.

    Uses the shared ``TIMEOUT`` and the headers from ``_headers()``; the raw
    ``requests.Response`` is returned without any status checking.
    """
    url = f"{BACKEND}{path}"
    return requests.post(
        url,
        json=payload,
        timeout=TIMEOUT,
        headers=_headers(),
    )
def post_multipart(path: str, files: dict, params: dict) -> requests.Response:
    """POST ``files`` as a multipart upload to ``BACKEND`` + ``path``.

    ``params`` are sent as query-string parameters. Uses the shared
    ``TIMEOUT`` and the headers from ``_headers()``; the raw
    ``requests.Response`` is returned without any status checking.
    """
    url = f"{BACKEND}{path}"
    return requests.post(
        url,
        files=files,
        params=params,
        timeout=TIMEOUT,
        headers=_headers(),
    )
def pil_from_upload(file) -> Optional[Image.Image]:
    """Open ``file`` with Pillow and return it converted to RGB.

    Best-effort: any exception raised while opening or converting is
    swallowed and ``None`` is returned instead.
    """
    try:
        opened = Image.open(file)
        rgb = opened.convert("RGB")
    except Exception:
        return None
    return rgb
def compress_and_b64(img: Image.Image, max_side: int = 1280, quality: int = 85):
    """Downscale an image, JPEG-encode it and return it as base64 text.

    Args:
        img: Source image; its EXIF orientation is applied first.
        max_side: Longest allowed side in pixels. Larger images are shrunk
            proportionally; smaller ones are left at their size.
        quality: JPEG quality passed to ``Image.save``.

    Returns:
        ``(b64, img_proc, (w0, h0), img_proc.size)`` — the base64 string of
        the JPEG bytes, the image that was encoded, the size after EXIF
        transposition but before scaling, and the size that was encoded.
    """
    # Modes the JPEG writer accepts directly; anything else (RGBA, P, LA, ...)
    # would make ``save`` raise, so it is converted to RGB first.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    img = ImageOps.exif_transpose(img)
    w0, h0 = img.size
    longest = max(w0, h0)
    scale = longest / max_side if longest > max_side else 1.0

    if scale > 1.0:
        # Clamp to 1px: for extreme aspect ratios the short side would
        # otherwise truncate to 0, which ``resize`` rejects.
        new_size = (max(1, int(w0 / scale)), max(1, int(h0 / scale)))
        img_proc = img.resize(new_size)
    else:
        img_proc = img

    if img_proc.mode not in jpeg_modes:
        img_proc = img_proc.convert("RGB")

    buf = io.BytesIO()
    img_proc.save(buf, format="JPEG", quality=quality, optimize=True)
    b64 = base64.b64encode(buf.getvalue()).decode()
    return b64, img_proc, (w0, h0), img_proc.size
def draw_bbox(img: Image.Image, bbox: list[int], color=(0, 255, 0), width: int = 4) -> Image.Image:
    """Return a copy of ``img`` with one rectangle outlined on it.

    ``bbox`` is unpacked as ``x1, y1, x2, y2``; ``color`` and ``width`` are
    passed to the rectangle as outline color and line width. The input
    image is not modified.
    """
    annotated = img.copy()
    left, top, right, bottom = bbox
    ImageDraw.Draw(annotated).rectangle(
        [left, top, right, bottom],
        outline=color,
        width=width,
    )
    return annotated
def tts_gtts_bytes(text: str, lang: str = "en", tld: str = "com", slow: bool = False) -> bytes:
    """Synthesize ``text`` with gTTS and return the audio as bytes.

    ``lang``, ``tld`` and ``slow`` are forwarded unchanged to ``gTTS``; the
    audio is written to an in-memory buffer rather than to disk.
    """
    speech = gTTS(text=text, lang=lang, tld=tld, slow=slow)
    with io.BytesIO() as sink:
        speech.write_to_fp(sink)
        return sink.getvalue()
# --- Chat state helpers ---
if "chat" not in st.session_state:
    # Transcript: list of {"role": "user"|"assistant", "text": str}
    st.session_state["chat"] = []
def add_chat(role: str, text: str):
    """Append one turn (``role`` + ``text``) to ``st.session_state.chat``."""
    turn = {"role": role, "text": text}
    st.session_state.chat.append(turn)
def render_chat_transcript():
    """Render the most recent 100 chat turns as Streamlit chat messages.

    Turns whose role is not ``"user"`` are displayed as the assistant.
    """
    st.subheader("🗨️ Conversation")
    recent_turns = st.session_state.chat[-100:]
    for turn in recent_turns:
        speaker = "assistant" if turn["role"] != "user" else "user"
        with st.chat_message(speaker):
            st.markdown(turn["text"])
# -----------------------
# Small UI renderers (so we can reorder cleanly)
# -----------------------
def render_examples_buttons(key_prefix: str = "main"):
    """Render the example questions as buttons in two alternating columns.

    Clicking a button stores its question in ``st.session_state["q_text"]``
    and reruns the script. ``key_prefix`` namespaces the button keys so the
    function can be rendered more than once on a page.
    """
    left_col, right_col = st.columns(2)
    sample_questions = (
        "What is total sales (revenue) of Ramesh?",
        "Revenue for BLR on 2025-09-06",
        "Monthly revenue for Electronics in BLR for 2025-09",
        "Top 5 SKUs by revenue in HYD on 2025-09-06 (include category)",
        "Ramesh's total sales in NCR on 2025-09-06",
    )
    for idx, question in enumerate(sample_questions):
        # Even indexes go left, odd indexes go right.
        column = right_col if idx % 2 else left_col
        clicked = column.button(question, key=f"{key_prefix}_ex_{idx}")
        if clicked:
            st.session_state["q_text"] = question
            st.rerun()
def _backend_error_detail(r):
    """Return the ``detail`` of a failed backend response, or its raw body."""
    try:
        err = r.json()
    except Exception:
        return r.text
    # Error bodies are not guaranteed to be JSON objects.
    return err.get("detail") if isinstance(err, dict) else err


def _render_query_response(resp: dict) -> None:
    """Show a successful ``/query`` response and remember its answer text."""
    answer = resp.get("answer_text") or ""
    st.success(answer)
    st.session_state["last_answer_text"] = answer

    # Citations prefixed with "sql:" carry the SQL text after the prefix.
    sqls = [
        c[4:]
        for c in (resp.get("citations") or [])
        if isinstance(c, str) and c.startswith("sql:")
    ]
    if sqls:
        with st.expander("SQL used", expanded=True):
            for sql in sqls:
                st.code(sql, language="sql")

    if resp.get("metrics"):
        with st.expander("Metrics", expanded=False):
            st.json(resp["metrics"])
    if resp.get("chart_refs"):
        with st.expander("Charts", expanded=False):
            st.json(resp["chart_refs"])

    uncertainty = resp.get("uncertainty")
    # Only format real numbers; a null/non-numeric value would break ":.2f".
    if isinstance(uncertainty, (int, float)):
        st.caption(f"Uncertainty: {uncertainty:.2f}")


def render_bi_question_section(section_heading=True, key_prefix: str = "main"):
    """Render the BI question form and, on "Ask", query the backend.

    Args:
        section_heading: When true, a "3) Ask a BI question" subheader is
            shown above the form.
        key_prefix: Namespace for all widget keys so the section can be
            rendered more than once on the same page.

    The question box defaults to ``st.session_state["q_text"]`` when set.
    On success the answer is displayed and stored in
    ``st.session_state["last_answer_text"]``; failures are shown with
    ``st.error`` instead of being raised.
    """
    if section_heading:
        st.subheader("3) Ask a BI question")

    with st.expander("Examples", expanded=False):
        render_examples_buttons(key_prefix=key_prefix)

    # Use a unique key for the textarea.
    default_q = st.session_state.get("q_text", "What is total sales (revenue) of Ramesh?")
    q_text = st.text_area("Your question", value=default_q, height=100,
                          key=f"{key_prefix}_q_textarea")

    with st.expander("Optional: visual context (JSON)", expanded=False):
        vis_str = st.text_area("visual_ctx", value="{}", height=80,
                               key=f"{key_prefix}_vis_text")
        try:
            visual_ctx = json.loads(vis_str) if vis_str.strip() else {}
        except Exception:
            visual_ctx = {}
            st.warning("`visual_ctx` is not valid JSON; ignored.")

    if not st.button("Ask", key=f"{key_prefix}_ask"):
        return

    question = q_text.strip()
    if not question:
        # Do not spend a backend round-trip on an empty question.
        st.warning("Please enter a question first.")
        return

    payload = {
        "user_id": st.session_state.get("user_name") or None,
        "text": question,
        "visual_ctx": visual_ctx,
    }
    try:
        with st.spinner("Querying…"):
            r = post_json("/query", payload)
        if r.ok:
            _render_query_response(r.json())
        else:
            detail = _backend_error_detail(r)
            st.error(f"Backend error {r.status_code}: {detail}")
            if "SQLGenTool disabled" in str(detail):
                st.info("Add your Hugging Face token in the sidebar (or set the HF_TOKEN env var).")
    except Exception as e:
        st.error(f"Request error: {e}")
# Render the voice-input controls: an STT language picker, a speech-to-text
# button whose result is written to st.session_state["q_text"], and an
# optional raw-audio recorder with playback.
def render_voice_to_text():
st.caption("Voice → Text (browser STT)")
c1, c2 = st.columns([2, 1])
with c1:
st.write("Click to speak; recognized text will fill the question box.")
with c2:
stt_lang = st.selectbox("STT language", ["en", "hi"], index=0, key="stt_lang_dd")
# Remember the last selected language and rerun the script once when it changes.
# NOTE(review): presumably so the STT widget below is rebuilt with the new
# language — confirm this rerun is still needed.
if "prev_stt_lang" not in st.session_state:
st.session_state["prev_stt_lang"] = stt_lang
elif st.session_state["prev_stt_lang"] != stt_lang:
st.session_state["prev_stt_lang"] = stt_lang
st.rerun()
# The language is read back from session state under the selectbox's key.
stt_text = speech_to_text(
language=st.session_state.get("stt_lang_dd", "en"),
use_container_width=True,
just_once=True,
start_prompt="🎙️ Start recording",
stop_prompt="⏹️ Stop recording",
key="stt_main_btn",
)
# A truthy result becomes the pending question; rerunning lets code that reads
# st.session_state["q_text"] pick it up.
if stt_text:
st.session_state["q_text"] = stt_text
st.success(f"Recognized: {stt_text}")
st.rerun()
st.markdown("---")
st.caption("Optional: record & play raw audio (no transcription)")
rec = mic_recorder(
start_prompt="🎙️ Start",
stop_prompt="⏹️ Stop",
just_once=True,
key="mic_raw_btn",
)
# Play back the recording only when the recorder returned audio bytes.
if rec and rec.get("bytes"):
st.audio(rec["bytes"], format="audio/wav")
def render_tts_controls():
    """Render the text-to-speech controls and speak the last answer on demand.

    Shows language and accent pickers plus a button. When pressed, the text
    in ``st.session_state["last_answer_text"]`` is synthesized with
    ``tts_gtts_bytes`` and played; a blank answer produces a warning and a
    synthesis failure is reported with ``st.error``.
    """
    st.markdown("---")
    st.caption("Text → Voice (gTTS)")
    tts_lang = st.selectbox("TTS language", ["en", "hi"], index=0, key="tts_lang_dd")

    # Display label -> gTTS ``tld`` value; the dict order is the option order.
    tld_by_label = {
        "Default (.com)": "com",
        "India (.co.in)": "co.in",
        "US (.us)": "us",
        "UK (.co.uk)": "co.uk",
    }
    tld_label = st.selectbox(
        "Accent / region (tld)",
        list(tld_by_label),
        index=1,
        key="tts_tld_dd",
    )

    if not st.button("🔊 Speak last answer", key="tts_speak_btn"):
        return

    last_answer = st.session_state.get("last_answer_text", "")
    if not last_answer.strip():
        st.warning("Ask a question first to generate an answer.")
        return

    try:
        audio = tts_gtts_bytes(last_answer, lang=tts_lang, tld=tld_by_label[tld_label], slow=False)
        st.audio(audio, format="audio/mp3")
    except Exception as e:
        st.error(f"TTS error: {e}")
# -----------------------
# UI
# -----------------------
# Page header plus a backend status row: the result of ping_backend() is shown
# as a metric next to the configured backend URL.
st.title("Realtime BI Assistant")
st.caption("Face upsert/identify + BI Q&A (text) via your FastAPI backend")
ok, info = ping_backend()
status_col, url_col = st.columns([1,3])
with status_col:
st.metric("Backend", "Online ✅" if ok else "Offline ❌")
with url_col:
st.code(BACKEND, language="text")
# The warning is shown only when the ping failed AND returned a non-empty info value.
if not ok and info:
st.warning(f"Backend unreachable: {info}")
# Persist chosen user
# Defaults to "mohit" the first time the session runs; this value is sent as
# "user_id" in the /query payloads built elsewhere in this file.
if "user_name" not in st.session_state:
st.session_state.user_name = "mohit"
# -----------------------
# 1) Bulk enroll via ZIP (Images/<UserName>/*)
# -----------------------
# Collapsed section that uploads a ZIP archive to the backend's /enroll_zip
# endpoint and shows the JSON result (or the status code and raw body on failure).
with st.expander("1) Bulk enroll via ZIP (Images/<UserName>/*)", expanded=False):
zip_up = st.file_uploader("Upload ZIP", type=["zip"], key="zip_enroll")
if st.button("Enroll ZIP"):
if not zip_up:
st.error("Please upload a ZIP.")
else:
try:
with st.spinner("Uploading & enrolling…"):
# Multipart field "zipfile_upload" carries (filename, bytes, content type);
# the upload is always sent under the fixed name "enroll.zip".
files = {"zipfile_upload": ("enroll.zip", zip_up.read(), "application/zip")}
r = post_multipart("/enroll_zip", files=files, params={})
if r.ok:
st.success("Enrollment complete ✅")
st.json(r.json())
else:
st.error(f"Enrollment failed: {r.status_code}")
st.text(r.text)
# Any exception from the request or from decoding the response is shown, not raised.
except Exception as e:
st.error(f"Request error: {e}")
# -----------------------
# 2) Identify from image
# -----------------------
# Collapsed section offering two image sources side by side: a file upload and
# a camera capture.
with st.expander("2) Identify from image", expanded=False):
col_u, col_c = st.columns(2)
with col_u:
test_upload = st.file_uploader("Upload test image", type=["jpg","jpeg","png"], key="test_upload")
with col_c:
test_cam = st.camera_input("Or capture from camera", key="test_cam")
# The camera capture takes precedence over the uploaded file when both are present.
# pil_from_upload returns None on failure, so test_img may still be None afterwards.
test_img = None
if test_cam is not None:
test_img = pil_from_upload(test_cam)
elif test_upload is not None:
test_img = pil_from_upload(test_upload)
def encode_for_backend(img: Image.Image):
    """Encode ``img`` for the backend and return ``(b64_str, sent_img)``.

    ``b64_str`` is the base64 text produced by ``compress_and_b64`` with any
    ``data:`` URL prefix stripped. ``sent_img`` is that payload decoded back
    into an RGB image, i.e. the pixels the backend will actually receive.
    """
    encoded = compress_and_b64(img)
    # Accept either a (b64, ...) sequence or a bare value from the encoder.
    b64_str = encoded[0] if isinstance(encoded, (tuple, list)) else encoded

    if isinstance(b64_str, (bytes, bytearray)):
        b64_str = b64_str.decode("utf-8")

    is_data_url = isinstance(b64_str, str) and b64_str.startswith("data:")
    if is_data_url:
        # Keep only the part after the first comma of "data:...;base64,<payload>".
        b64_str = b64_str.split(",", 1)[1]

    payload_bytes = base64.b64decode(b64_str.encode("utf-8"))
    sent_img = Image.open(io.BytesIO(payload_bytes)).convert("RGB")
    return b64_str, sent_img
def draw_many(img: Image.Image, dets: list[dict]) -> Image.Image:
    """Return a copy of ``img`` with every detection boxed and labelled.

    Each item in ``dets`` is read with ``.get``: ``bbox`` (four numbers,
    ``x1, y1, x2, y2``), ``decision`` (label text, default ``"Unknown"``)
    and ``best_score`` (default ``0.0``). Missing or null fields fall back
    to their defaults. The input image is not modified.
    """
    out = img.copy()
    draw = ImageDraw.Draw(out)
    for d in dets:
        xa, ya, xb, yb = [int(v) for v in (d.get("bbox") or [0, 0, 0, 0])]
        # Normalise corner order: ImageDraw can reject rectangles whose second
        # corner lies above/left of the first.
        x1, x2 = sorted((xa, xb))
        y1, y2 = sorted((ya, yb))
        name = d.get("decision") or "Unknown"
        score = float(d.get("best_score") or 0.0)
        label = f"{name} ({score:.3f})"
        draw.rectangle([x1, y1, x2, y2], outline=(0, 255, 0), width=3)
        try:
            tb = draw.textbbox((x1, y1), label)
            tw, th = tb[2] - tb[0], tb[3] - tb[1]
        except Exception:
            # Fallback size estimate when textbbox is unavailable.
            tw, th = max(60, len(label) * 7), 14
        # Put the label strip just above the box; when the box touches (or
        # crosses) the top edge, slide the strip down so it keeps its full
        # height and never has an inverted extent.
        label_bottom = max(y1, th + 6)
        label_top = label_bottom - th - 6
        draw.rectangle([x1, label_top, x1 + tw + 6, label_bottom], fill=(0, 0, 0))
        draw.text((x1 + 3, label_top + 2), label, fill=(0, 255, 0))
    return out
# On "Identify": encode the chosen image, POST it to /identify_many and draw
# the returned detections on the image that was actually sent.
if st.button("Identify"):
if test_img is None:
st.warning("Please provide an image first.")
else:
try:
# sent_img is the decoded copy of the encoded payload, so boxes are drawn on
# the same pixels the backend received.
b64, sent_img = encode_for_backend(test_img)
with st.spinner("Identifying…"):
# NOTE(review): "top_k" presumably limits candidate matches per face — confirm
# against the backend API.
r = post_json("/identify_many", {"image_b64": b64, "top_k": 3})
if not r.ok:
st.error(f"Identify failed: {r.status_code}")
st.text(r.text)
else:
data = r.json()
# A response without a "detections" key is treated as zero faces.
dets = data.get("detections", [])
st.caption(f"Faces found: {len(dets)}")
st.image(draw_many(sent_img, dets), use_container_width=True)
if dets:
with st.expander("Details"):
st.json(dets)
# Any exception from encoding, the request, decoding or drawing is shown, not raised.
except Exception as e:
st.error(f"Request error: {e}")
# -----------------------
# 2.5) Voice mode (frontend-only) with requested order
# -----------------------
st.subheader("🎙️ Voice mode (optional)")
with st.expander("Speak your question / hear the answer", expanded=True):
    # Step 1: capture speech and turn it into the pending question text.
    render_voice_to_text()

    # Step 2: the regular BI question form, without its heading and with
    # "voice"-prefixed widget keys.
    render_bi_question_section(section_heading=False, key_prefix="voice")

    # Step 3: controls for listening to the last answer.
    render_tts_controls()
# -----------------------
# 2.6) Talk → Ask → Speak (voice chat with transcript)
# -----------------------
st.subheader("🗣️ Talk → Ask → Speak")
c_left, c_right = st.columns([2, 3])
with c_left:
    st.caption("Press to speak; we'll answer, speak back, and log the chat below.")
with c_right:
    # Voice settings reuse the TTS controls' widget state when present,
    # otherwise fall back to defaults.
    tts_lang = st.session_state.get("tts_lang_dd", "en")
    tld_map = {"Default (.com)": "com", "India (.co.in)": "co.in", "US (.us)": "us", "UK (.co.uk)": "co.uk"}
    tld_label = st.session_state.get("tts_tld_dd", "India (.co.in)")

# Mic widget (one utterance per click)
spoken = speech_to_text(
    language=st.session_state.get("stt_lang_dd", "en"),
    use_container_width=True,
    just_once=True,
    start_prompt="🎙️ Speak",
    stop_prompt="⏹️ Stop",
    key="stt_conv_btn",
)

if spoken:
    user_text = spoken.strip()
    if user_text:
        add_chat("user", user_text)
        payload = {"user_id": st.session_state.user_name or None, "text": user_text, "visual_ctx": {}}
        # Guard the whole round-trip like the other request sections do, so a
        # connection failure or an undecodable body is reported in the page
        # and the transcript instead of aborting the script run.
        try:
            with st.spinner("Thinking…"):
                r = post_json("/query", payload)
            if r.ok:
                resp = r.json()
                # ``answer_text`` may be missing or null; treat both as empty.
                answer = (resp.get("answer_text") or "").strip()
                add_chat("assistant", answer or "_(no rows)_")
                st.session_state["last_answer_text"] = answer
                # Speak the answer; a TTS failure must not hide the text reply.
                try:
                    mp3 = tts_gtts_bytes(answer or "I have no rows to report.",
                                         lang=tts_lang,
                                         tld=tld_map.get(tld_label, "co.in"),
                                         slow=False)
                    st.audio(mp3, format="audio/mp3")
                except Exception as e:
                    st.error(f"TTS error: {e}")
            else:
                try:
                    err = r.json()
                except Exception:
                    err = {"detail": r.text}
                # Error bodies are not guaranteed to be JSON objects.
                detail = err.get("detail") if isinstance(err, dict) else err
                add_chat("assistant", f"Backend error {r.status_code}: {detail}")
                st.error(f"Backend error {r.status_code}: {detail}")
        except Exception as e:
            add_chat("assistant", f"Request error: {e}")
            st.error(f"Request error: {e}")

# Show running transcript
render_chat_transcript()
# -----------------------
# 3) Ask a BI question (also kept as a main section for non-voice users)
# -----------------------
# render_bi_question_section(section_heading=True, key_prefix="main")