Spaces:
Sleeping
Sleeping
"""
Voice Chatbot — Hugging Face Space version (Mistral API + MLflow tracing).

Architecture:
- LLM : Mistral Chat API (streaming SSE via httpx, traced manually)
- STT : Mistral Transcriptions API (via the Mistral SDK client when it is
        available — auto-traced by MLflow; otherwise raw httpx, untraced)
- TTS : Mistral Speech API (via the Mistral SDK, not traced)

When tracing is enabled, MLflow records the LLM calls (and STT calls made
through the SDK client) with:
- Input prompts / messages
- Output completions / transcriptions
- Approximate token counts & latency
- Model name & generation parameters

Configuration (HF Space secrets):
  MISTRAL_API_KEY — required
  MLFLOW_TRACKING_URI — URL of your MLflow 3 server (tracing is off when unset)
  MLFLOW_EXPERIMENT_NAME — optional (default: mistral-chatbot)
  LLM_MODEL — optional (default: mistral-small-latest)
  STT_MODEL — optional (default: voxtral-mini-latest)
  TTS_MODEL — optional (default: voxtral-mini-tts-2603)
  TTS_VOICE — optional TTS voice ID (a default voice ID is built in)
  LLM_API_URL / STT_API_URL — optional endpoint overrides for the raw HTTP calls
"""
import asyncio
import base64
import contextlib
import io
import json
import logging
import mimetypes
import os
import re
import subprocess
import tempfile
import time
from pathlib import Path

import gradio as gr
import httpx
import numpy as np
import soundfile as sf
from mistralai.client import Mistral
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger("app")

# ══════════════════════════════════════════════════════════════════════════
# CONFIGURATION — set via HF Space secrets (Settings → Repository secrets)
# ══════════════════════════════════════════════════════════════════════════
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY", "")

# ── LLM ─────────────────────────────────────────────────────────────────
LLM_MODEL = os.getenv("LLM_MODEL", "mistral-small-latest")
LLM_API_URL = os.getenv(
    "LLM_API_URL",
    "https://api.mistral.ai/v1/chat/completions",
)

# ── STT ─────────────────────────────────────────────────────────────────
STT_MODEL = os.getenv("STT_MODEL", "voxtral-mini-latest")
STT_API_URL = os.getenv(
    "STT_API_URL",
    "https://api.mistral.ai/v1/audio/transcriptions",
)

# ── TTS ─────────────────────────────────────────────────────────────────
# NOTE(review): TTS_API_URL is read here, but call_tts() goes through the
# Mistral SDK and does not reference it.
TTS_API_URL = os.getenv(
    "TTS_API_URL",
    "https://api.mistral.ai/v1/audio/speech",
)
TTS_MODEL = os.getenv("TTS_MODEL", "voxtral-mini-tts-2603")
# Default voice ID. The author's note labelled it "marie_neutral" — TODO confirm.
# An explicitly empty TTS_VOICE makes call_tts() omit the voice entirely.
TTS_VOICE = os.getenv("TTS_VOICE", "5a271406-039d-46fe-835b-fbbb00eaf08d")

# ── MLflow Tracing ──────────────────────────────────────────────────────
# 1. Deploy an MLflow 3 tracking server (e.g. on a VM or a managed service).
# 2. Set MLFLOW_TRACKING_URI to its URL, e.g. "https://mlflow.example.com".
# 3. Optionally set MLFLOW_EXPERIMENT_NAME (default: "mistral-chatbot").
# Both are read from environment variables (HF Space secrets). Tracing is
# disabled when MLFLOW_TRACKING_URI is empty.
#
# Auto-traced (mlflow.mistral.autolog()): STT calls made through the Mistral
#   SDK client, when that client could be created.
# Traced manually (an MLflow span): LLM streaming calls, which Mistral
#   autolog does not cover.
# Not traced: TTS calls (low diagnostic value).
MLFLOW_TRACKING_URI = os.getenv("MLFLOW_TRACKING_URI", "")
MLFLOW_EXPERIMENT_NAME = os.getenv("MLFLOW_EXPERIMENT_NAME", "mistral-chatbot")

# ── Sentence boundary regex ─────────────────────────────────────────────
# One or more of . ? ! … FOLLOWED by whitespace. A decimal such as "3.5" does
# not split, and the final sentence of a reply has no trailing whitespace:
# it is flushed separately once the stream ends.
SENTENCE_END = re.compile(r"[.?!…]+\s+")

# French system prompt, prepended to every request (typos fixed:
# "communiques", "exclusivement", "FRANÇAIS", "Métropole").
SYSTEM_PROMPT = {
    "role": "system",
    "content": (
        "Tu es un assistant d'agent de Bordeaux Métropole. "
        "Tu communiques exclusivement en FRANÇAIS. "
        "Tu gardes un ton courtois."
    ),
}

if not MISTRAL_API_KEY:
    # Use the configured logger (not print) so the warning carries a
    # timestamp and level in the Space logs like every other message.
    logger.warning("MISTRAL_API_KEY not set — all API calls will fail.")
# ══════════════════════════════════════════════════════════════════════════
# MLflow setup (runs once at module import time)
# ══════════════════════════════════════════════════════════════════════════
# Guard flag — tracing is only attempted when a tracking URI is configured.
# It is flipped back to False below if initialisation fails.
MLFLOW_ENABLED = bool(MLFLOW_TRACKING_URI)
if MLFLOW_ENABLED:
    # mlflow is imported here rather than at the top of the file so the app
    # still starts without it. Any failure in this block (package missing,
    # server unreachable, experiment error, ...) only disables tracing.
    try:
        import mlflow

        # Point the MLflow client at the configured tracking server,
        # e.g. "https://mlflow.example.com".
        mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
        # Make MLFLOW_EXPERIMENT_NAME the active experiment (created if it
        # does not exist yet). Traces are grouped under it.
        mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)
        # Enable MLflow's Mistral autologging, which patches the Mistral
        # Python SDK so that its calls are recorded as traces.
        # LIMITATION: streaming chat is not auto-traced; stream_reply()
        # records a span for it manually.
        # NOTE(review): confirm in the MLflow docs which SDK calls are
        # patched. Chat completions are the documented case; whether audio
        # transcriptions are covered cannot be verified from this file.
        # Doc: https://mlflow.org/docs/latest/genai/tracing/integrations/listing/mistral/
        mlflow.mistral.autolog()
        logger.info(
            "MLflow tracing ENABLED — tracking URI: %s, experiment: %s",
            MLFLOW_TRACKING_URI,
            MLFLOW_EXPERIMENT_NAME,
        )
    except Exception as exc:
        logger.warning(
            "MLflow init failed (%s) — tracing disabled. "
            "Set MLFLOW_TRACKING_URI to a valid MLflow 3 server URL.",
            exc,
        )
        MLFLOW_ENABLED = False
# ── Mistral SDK client (used for STT when tracing is enabled) ───────────
# transcribe() sends audio through this client when it is not None, so that
# mlflow.mistral.autolog() can trace the call. Otherwise it uses raw httpx
# (untraced). The LLM stream always uses raw httpx (see stream_llm), because
# MLflow's Mistral autolog does not trace streaming completions.
# NOTE(review): the rest of the file imports `Mistral` from
# `mistralai.client`. Confirm that `mistralai.async_client.MistralAsync`
# exists in the installed SDK version. If this import fails, a warning is
# logged, `_mistral` stays None, and STT silently runs untraced.
if MLFLOW_ENABLED:
    try:
        from mistralai.async_client import MistralAsync as _MistralAsync
        _mistral = _MistralAsync(api_key=MISTRAL_API_KEY)
    except Exception as exc:
        logger.warning("Failed to create Mistral SDK client: %s", exc)
        _mistral = None
else:
    _mistral = None
# ══════════════════════════════════════════════════════════════════════════
# IN-MEMORY METRICS (not related to MLflow — shown in the Gradio UI)
# ══════════════════════════════════════════════════════════════════════════
class Metrics:
    """Process-local counters backing the "Stats & Observability" panel.

    Tracks, per service (STT / LLM / TTS), the number of calls and their
    cumulative latency. Also keeps a running token total, an error counter,
    and the 20 most recent error messages.

    The ``record_*`` coroutines update state while holding ``self.lock``.
    ``reset`` and ``snapshot_md`` do not take the lock.
    """

    def __init__(self):
        self.lock = asyncio.Lock()
        self.reset()

    def reset(self):
        """Zero every counter and forget all recorded errors."""
        self.stt_count = self.llm_count = self.tts_count = 0
        self.stt_total_s = self.llm_total_s = self.tts_total_s = 0.0
        self.total_tokens = self.error_count = 0
        self.last_errors: list[str] = []

    async def record_stt(self, elapsed: float):
        """Count one STT call that took *elapsed* seconds."""
        async with self.lock:
            self.stt_total_s += elapsed
            self.stt_count += 1

    async def record_llm(self, elapsed: float, tokens: int):
        """Count one LLM call of *elapsed* seconds that produced *tokens*."""
        async with self.lock:
            self.total_tokens += tokens
            self.llm_total_s += elapsed
            self.llm_count += 1

    async def record_tts(self, elapsed: float):
        """Count one TTS call that took *elapsed* seconds."""
        async with self.lock:
            self.tts_total_s += elapsed
            self.tts_count += 1

    async def record_error(self, msg: str):
        """Count one error and remember *msg*, keeping only the newest 20."""
        async with self.lock:
            self.error_count += 1
            self.last_errors.append(msg)
            # Drop everything but the 20 most recent entries (no-op when <= 20).
            del self.last_errors[:-20]

    def snapshot_md(self) -> str:
        """Render the current counters as Markdown for the stats panel."""

        def fmt_avg(total: float, count: int) -> str:
            if not count:
                return "—"
            return f"{total / count:.2f}s"

        services = (
            ("STT", self.stt_count, self.stt_total_s),
            ("LLM", self.llm_count, self.llm_total_s),
            ("TTS", self.tts_count, self.tts_total_s),
        )
        out = [
            "### 📊 Metrics",
            "| Service | Calls | Avg Latency | Total Time |",
            "|---------|-------|-------------|------------|",
        ]
        out.extend(
            f"| **{label}** | {count} | {fmt_avg(total, count)} | {total:.1f}s |"
            for label, count, total in services
        )
        out += [
            "",
            f"**Total tokens generated:** {self.total_tokens}",
            f"**Errors:** {self.error_count}",
        ]
        if self.last_errors:
            out.append("\n**Recent errors:**")
            # Show only the five newest errors, each truncated to 120 chars.
            out.extend(f"- `{err[:120]}`" for err in self.last_errors[-5:])
        return "\n".join(out)


# Single shared instance used by every handler and by the stats panel.
_metrics = Metrics()


async def reset_metrics():
    """Reset button callback: clear all metrics and return the fresh panel Markdown."""
    _metrics.reset()
    return _metrics.snapshot_md()
# ══════════════════════════════════════════════════════════════════════════
# HELPERS
# ══════════════════════════════════════════════════════════════════════════
def to_chatbot(history: list) -> list[dict]:
    """Flatten ``[(user, bot), ...]`` pairs into Gradio "messages" dicts.

    Each side becomes ``{"role": ..., "content": ...}`` in order, user first.
    Falsy sides (``""`` / ``None``) are left out.
    """
    out: list[dict] = []
    for human, bot in history:
        out.extend(
            {"role": role, "content": text}
            for role, text in (("user", human), ("assistant", bot))
            if text
        )
    return out
def build_messages(history: list, user_text: str) -> list[dict]:
    """Build the Mistral Chat API ``messages`` array for a new user turn.

    Starts with SYSTEM_PROMPT, then replays the completed exchanges in
    *history* (a list of ``(user_text, assistant_text)`` tuples), then
    appends *user_text* as the final user message.

    An exchange is skipped when either side is empty, or when the assistant
    side is an error notice produced by this app (prefixed with "⚠️").
    Replaying our own error banner as if the model had said it pollutes the
    context, and an empty ``content`` may be rejected by the API.
    """
    messages = [SYSTEM_PROMPT]
    for human, assistant in history:
        if not human or not assistant or assistant.startswith("⚠️"):
            continue
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_text})
    return messages
# ══════════════════════════════════════════════════════════════════════════
# STT — speech-to-text (Mistral Transcriptions API, auto-traced via SDK)
# ══════════════════════════════════════════════════════════════════════════
def _prepare_stt_audio(audio_path: str) -> tuple[bytes, str, str]:
    """Read a recording for upload; return ``(audio_bytes, file_name, mime_type)``.

    The recording is first converted with ffmpeg to 16 kHz mono WAV. If the
    conversion fails (ffmpeg missing, unreadable input, ...), the original
    file is sent as-is under its own file name and guessed MIME type, rather
    than being mislabelled as "audio.wav".

    Blocking (subprocess + file I/O): call it from a worker thread.
    """
    wav_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            wav_path = tmp.name
        subprocess.run(
            ["ffmpeg", "-y", "-i", audio_path, "-ar", "16000", "-ac", "1", wav_path],
            capture_output=True,
            check=True,
        )
        with open(wav_path, "rb") as f:
            return f.read(), "audio.wav", "audio/wav"
    except Exception as e:
        logger.warning("ffmpeg conversion failed, sending raw audio: %s", e)
    finally:
        # The temp file is always removed, on success and on failure.
        if wav_path and os.path.exists(wav_path):
            os.unlink(wav_path)

    # Fallback: upload the untouched recording under its real name and type.
    file_name = os.path.basename(audio_path) or "audio"
    mime_type = mimetypes.guess_type(file_name)[0] or "application/octet-stream"
    with open(audio_path, "rb") as f:
        return f.read(), file_name, mime_type


async def transcribe(audio_path: str) -> str:
    """Transcribe the recording at *audio_path* with the Mistral STT API.

    Preprocessing (ffmpeg → 16 kHz mono WAV, file reads) runs in a worker
    thread, so the event loop, and every other user's stream, is not
    blocked while ffmpeg runs.

    When the module-level ``_mistral`` SDK client exists (it is only created
    when MLflow tracing is enabled), the request goes through it so that
    mlflow.mistral.autolog() can capture it. Otherwise the request is a raw,
    untraced httpx multipart POST to STT_API_URL.

    Returns the transcribed text. Any failure, including preprocessing, is
    recorded in the metrics and re-raised.
    """
    t0 = time.perf_counter()
    try:
        audio_bytes, file_name, mime_type = await asyncio.to_thread(
            _prepare_stt_audio, audio_path
        )
        if _mistral is not None:
            # Mistral SDK call, intended to be auto-traced by MLflow autolog().
            result = await _mistral.audio.transcriptions.complete(
                model=STT_MODEL,
                file={"content": audio_bytes, "file_name": file_name},
            )
            text = result.text
        else:
            # Fallback: raw httpx (no tracing).
            headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}"}
            async with httpx.AsyncClient(timeout=120.0) as client:
                resp = await client.post(
                    STT_API_URL,
                    headers=headers,
                    files={
                        "file": (file_name, audio_bytes, mime_type),
                        "model": (None, STT_MODEL),
                    },
                )
                resp.raise_for_status()
                text = resp.json()["text"]
        elapsed = time.perf_counter() - t0
        await _metrics.record_stt(elapsed)
        logger.info(
            "STT ok %.2fs %d bytes model=%s",
            elapsed,
            len(audio_bytes),
            STT_MODEL,
        )
        return text
    except Exception as e:
        await _metrics.record_error(f"STT: {e}")
        raise
| # ══════════════════════════════════════════════════════════════════════════ | |
| # TTS — text-to-speech (Mistral Speech API, raw httpx, not traced) | |
| # ══════════════════════════════════════════════════════════════════════════ | |
| async def call_tts(client: Mistral, text: str) -> tuple | None: | |
| """Synthesise speech via Mistral TTS API, return (sample_rate, numpy_array). | |
| Uses the official Mistral SDK (audio.speech.complete). | |
| Not MLflow-traced (low diagnostic value for TTS). | |
| """ | |
| t0 = time.perf_counter() | |
| try: | |
| kwargs: dict = { | |
| "model": TTS_MODEL, | |
| "input": text, | |
| "response_format": "mp3", | |
| } | |
| if TTS_VOICE: | |
| kwargs["voice_id"] = TTS_VOICE | |
| # SDK is synchronous — run in a thread to avoid blocking the event loop | |
| loop = asyncio.get_running_loop() | |
| response = await loop.run_in_executor( | |
| None, | |
| lambda: client.audio.speech.complete(**kwargs), | |
| ) | |
| audio_bytes = base64.b64decode(response.audio_data) | |
| audio_np, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32") | |
| elapsed = time.perf_counter() - t0 | |
| await _metrics.record_tts(elapsed) | |
| logger.info("TTS ok %.2fs %d chars model=%s", elapsed, len(text), TTS_MODEL) | |
| return sr, audio_np | |
| except Exception as e: | |
| await _metrics.record_error(f"TTS: {e}") | |
| logger.warning("TTS failed (%.1fs): %s", time.perf_counter() - t0, e) | |
| return None | |
# ══════════════════════════════════════════════════════════════════════════
# LLM — streaming text generation (Mistral Chat API via httpx SSE)
# ══════════════════════════════════════════════════════════════════════════
async def stream_llm(messages: list[dict]):
    """Stream reply text from the Mistral Chat API over Server-Sent Events.

    Yields ``(text_chunk, chunk_count)``. ``chunk_count`` is the number of
    non-empty content deltas received so far, starting at 1. A delta often
    holds about one token, but that is only an approximation.

    Never raises. On failure it yields a single ``(error_message, 0)``
    tuple, where ``error_message`` starts with "⚠️", so a count of 0 is the
    error signal for consumers.

    Raw httpx is used instead of the Mistral SDK because MLflow's Mistral
    autolog does not trace streaming chat completions. The caller
    (stream_reply) records an MLflow span manually.

    Metrics: one ``record_llm`` call per successful stream. It uses the
    server-reported ``usage.completion_tokens`` when a chunk carries it, and
    the delta count otherwise. The measured latency includes time the
    consumer spends between chunks (e.g. inline TTS), because this generator
    is suspended while the consumer works.
    """
    headers = {
        "Authorization": f"Bearer {MISTRAL_API_KEY}",
        "Content-Type": "application/json",
    }
    body = {
        "model": LLM_MODEL,
        "messages": messages,
        "stream": True,
        "max_tokens": 512,
        "temperature": 0.7,
        "top_p": 0.9,
    }
    t0 = time.perf_counter()
    chunk_count = 0
    usage_tokens = None
    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(30.0, read=120.0)) as client:
            async with client.stream(
                "POST", LLM_API_URL, json=body, headers=headers,
            ) as resp:
                if resp.is_error:
                    # Load the error body while the stream is still open. Once
                    # these ``async with`` blocks exit, the response is closed and
                    # reading it would raise httpx.StreamClosed from inside the
                    # except handler below.
                    await resp.aread()
                    resp.raise_for_status()
                async for line in resp.aiter_lines():
                    # SSE data lines look like "data:<payload>"; the space is optional.
                    if not line.startswith("data:"):
                        continue
                    payload = line[5:].strip()
                    if payload == "[DONE]":
                        break
                    if not payload:
                        continue
                    try:
                        data = json.loads(payload)
                    except json.JSONDecodeError:
                        continue
                    if not isinstance(data, dict):
                        continue
                    usage = data.get("usage")
                    if isinstance(usage, dict) and usage.get("completion_tokens"):
                        usage_tokens = usage["completion_tokens"]
                    choices = data.get("choices") or []
                    if not choices:
                        continue
                    content = (choices[0].get("delta") or {}).get("content")
                    if content:
                        chunk_count += 1
                        yield content, chunk_count
        elapsed = time.perf_counter() - t0
        tokens = usage_tokens if isinstance(usage_tokens, int) else chunk_count
        await _metrics.record_llm(elapsed, tokens)
        logger.info(
            "LLM ok %.2fs %d tokens %.1f tok/s model=%s",
            elapsed, tokens,
            tokens / elapsed if elapsed else 0,
            LLM_MODEL,
        )
    except httpx.HTTPStatusError as e:
        status = e.response.status_code
        try:
            # Already loaded above; .text decodes with replacement characters.
            detail = e.response.text[:300]
        except httpx.ResponseNotRead:
            detail = "<response body not available>"
        elapsed = time.perf_counter() - t0
        await _metrics.record_error(f"LLM HTTP {status}: {detail[:80]}")
        logger.error("LLM HTTP %s %.1fs %s", status, elapsed, detail)
        yield f"⚠️ LLM API error ({status}): {detail}", 0
    except Exception as e:
        elapsed = time.perf_counter() - t0
        await _metrics.record_error(f"LLM: {type(e).__name__}: {e}")
        logger.error("LLM error %.1fs %s", elapsed, e)
        yield f"⚠️ LLM error: {type(e).__name__}: {e}", 0
# ══════════════════════════════════════════════════════════════════════════
# STREAM ORCHESTRATOR — consumes LLM tokens, records MLflow trace, calls TTS
# ══════════════════════════════════════════════════════════════════════════
def _record_llm_stream_trace(messages, reply, token_count, elapsed, error):
    """Best effort: record one MLflow span summarising a finished LLM stream.

    Mistral autolog does not cover streaming chat, so this span is created
    after the stream has ended. Its own start/end timestamps therefore
    reflect when it was recorded, not the LLM call; the measured latency is
    stored in the outputs instead. Any MLflow failure (e.g. an unreachable
    server) is logged and swallowed so the chat keeps working.
    """
    try:
        import mlflow

        # Outside any active trace, start_span opens a new root span (a new trace).
        with mlflow.start_span(name="llm_chat_stream", span_type="LLM") as span:
            # Keep in sync with the request body built in stream_llm().
            span.set_inputs({
                "model": LLM_MODEL,
                "messages": messages,
                "temperature": 0.7,
                "top_p": 0.9,
                "max_tokens": 512,
            })
            span.set_outputs({
                "response": reply,
                "token_count": token_count,
                "latency_seconds": round(elapsed, 3),
                "tokens_per_second": round(token_count / elapsed, 1) if elapsed > 0 else 0,
            })
            if error:
                span.set_attribute("error", str(error))
                span.set_status("ERROR")
    except Exception as trace_err:
        # Don't crash the app if the MLflow server is unreachable.
        logger.warning("MLflow trace recording failed: %s", trace_err)


async def stream_reply(messages: list[dict]):
    """Consume the LLM stream, synthesise per-sentence audio, and trace the turn.

    Yields ``(partial_reply, audio_or_None)``:
    - ``partial_reply`` is the reply text accumulated so far. If stream_llm
      reported a failure, it is instead that "⚠️" error string, which ends
      the stream.
    - ``audio_or_None`` is the ``(sample_rate, samples)`` tuple from
      call_tts for a sentence that just completed, or None.

    TTS runs inline: each completed sentence is synthesised before more LLM
    output is consumed. One Mistral SDK client is shared by all TTS calls of
    this reply and closed when the reply ends.

    MLflow: when MLFLOW_ENABLED and the turn produced text or an error, a
    single "llm_chat_stream" span is recorded in the ``finally`` block. Its
    latency figure covers the whole turn, including the inline TTS calls.
    """
    token_buffer = ""
    full_reply = ""
    trace_token_count = 0
    trace_error = None
    trace_start = time.perf_counter()
    try:
        async with contextlib.AsyncExitStack() as stack:
            tts_client = None

            async def speak(text: str):
                # Create the shared SDK client on first use; the exit stack
                # closes it when the reply ends.
                nonlocal tts_client
                if tts_client is None:
                    tts_client = await stack.enter_async_context(
                        Mistral(api_key=MISTRAL_API_KEY)
                    )
                return await call_tts(tts_client, text)

            async for token_or_error, token_count in stream_llm(messages):
                # stream_llm reports failures as ("⚠️ ...", 0), while real chunks
                # always carry a count >= 1. A reply that merely begins with
                # that emoji is therefore not mistaken for an error.
                if token_count == 0 and token_or_error.startswith("⚠️"):
                    trace_error = token_or_error
                    yield token_or_error, None
                    return
                trace_token_count = token_count
                token_buffer += token_or_error
                full_reply += token_or_error
                # ── Sentence-boundary TTS ─────────────────────────────
                # Once a complete sentence has arrived, synthesise it so the
                # user hears audio before the full reply finishes.
                match = SENTENCE_END.search(token_buffer)
                if match:
                    sentence = token_buffer[: match.end()].strip()
                    token_buffer = token_buffer[match.end():]
                    if sentence:
                        audio = await speak(sentence)
                        if audio is not None:
                            yield full_reply, audio
                            continue
                yield full_reply, None

            # ── Remaining text after the last sentence boundary ───────
            tail = token_buffer.strip()
            if tail:
                audio = await speak(tail)
                if audio is not None:
                    yield full_reply, audio
                    return
            yield full_reply, None
    finally:
        # Runs when the stream ends, on the error path, or when an exception
        # propagates. Errors are traced even if no text was produced.
        if MLFLOW_ENABLED and (full_reply or trace_error):
            _record_llm_stream_trace(
                messages,
                full_reply,
                trace_token_count,
                time.perf_counter() - trace_start,
                trace_error,
            )
# ══════════════════════════════════════════════════════════════════════════
# GRADIO TURN HANDLERS
# ══════════════════════════════════════════════════════════════════════════
async def run_turn(user_text: str, history: list):
    """Run one chat exchange and stream the resulting UI updates.

    Yields 5-tuples in the order ``(chatbot_messages, new_history,
    audio_out, last_audio, extra)``:
    - while streaming, the last history entry shows the partial reply with a
      "▌" cursor; ``audio_out`` and ``last_audio`` are None;
    - the final update drops the cursor and puts the concatenation of every
      synthesised sentence into ``last_audio`` (``audio_out`` stays None).
    ``extra`` is always ``gr.update()`` (no change).

    NOTE(review): the sample rate of the last audio chunk is used for the
    whole concatenation. This assumes all chunks share one rate and layout.
    """
    past = history or []
    messages = build_messages(past, user_text)
    turns = past + [(user_text, "▌")]
    chunks: list[np.ndarray] = []
    sample_rate = None

    async for reply_so_far, audio in stream_reply(messages):
        turns[-1] = (user_text, f"{reply_so_far}▌")
        if audio is not None:
            sample_rate, samples = audio
            chunks.append(samples)
        # Partial audio is not streamed to the player: audio_out stays None.
        yield to_chatbot(turns), turns, None, None, gr.update()

    finished = turns[-1][1].rstrip("▌").rstrip()
    turns[-1] = (user_text, finished)
    merged_audio = None
    if chunks:
        merged_audio = (sample_rate, np.concatenate(chunks))
    yield to_chatbot(turns), turns, None, merged_audio, gr.update()
async def text_turn(user_text: str, history: list):
    """Gradio handler for the text box and the Send button.

    Yields 6-tuples matching the outputs
    ``[chatbot, state, audio_out, last_audio, text_box, stats_md]``.
    Blank input re-renders the current conversation and leaves the text box
    as typed. Otherwise the run_turn updates are forwarded, and the text box
    is cleared ("") on the first of them and left unchanged afterwards.
    """
    if user_text.strip():
        box_value = ""
        async for chat, saved, audio, last, _ in run_turn(user_text, history):
            yield chat, saved, audio, last, box_value, gr.update()
            box_value = gr.update()
        return
    current = history or []
    yield to_chatbot(current), current, None, None, user_text, gr.update()
async def voice_turn(audio_path: str, history: list):
    """Gradio handler for the "Submit voice" button.

    Yields 5-tuples matching the outputs
    ``[chatbot, state, audio_out, last_audio, stats_md]``.
    - No recording: re-render the current conversation.
    - STT failure: show the error as a "[voice]" exchange without saving it
      to the state.
    - Empty transcription: show a notice (also not saved) instead of sending
      an empty user message to the LLM.
    - Otherwise: forward the run_turn updates for the transcribed text.
    """
    history = history or []
    if audio_path is None:
        yield to_chatbot(history), history, None, None, gr.update()
        return
    try:
        user_text = await transcribe(audio_path)
    except Exception as e:
        logger.exception("STT error")
        err = f"⚠️ STT error: {type(e).__name__}: {e}"
        yield to_chatbot(history + [("[voice]", err)]), history, None, None, gr.update()
        return
    if not (user_text or "").strip():
        notice = "⚠️ No speech detected."
        yield to_chatbot(history + [("[voice]", notice)]), history, None, None, gr.update()
        return
    async for h, s, a, la, _ in run_turn(user_text, history):
        yield h, s, a, la, gr.update()
# ══════════════════════════════════════════════════════════════════════════
# GRADIO UI
# ══════════════════════════════════════════════════════════════════════════
DESCRIPTION = (
    "## Voice Chatbot\n"
    f"_API LLM: `{LLM_MODEL}` · STT: `{STT_MODEL}` · TTS: `{TTS_MODEL}`_\n"
)

with gr.Blocks(
    title="Voice Chatbot",
    theme=gr.themes.Soft(),
    css="footer {visibility: hidden}",
) as demo:
    gr.Markdown(DESCRIPTION)
    chatbot = gr.Chatbot(label="Conversation", height=380)
    state = gr.State([])  # internal history: list of (user, assistant) tuples
    last_audio = gr.State(None)  # full audio of the latest reply, for Replay

    with gr.Row():
        text_box = gr.Textbox(
            placeholder="Type a message and press Enter …",
            show_label=False,
            scale=5,
        )
        send_btn = gr.Button("Send", scale=1, variant="primary")

    with gr.Row():
        mic_input = gr.Audio(
            sources=["microphone"],
            type="filepath",
            label="Voice input",
            scale=5,
        )
        voice_btn = gr.Button("Submit voice", scale=1)

    # The turn handlers always send None here: the author notes a known issue
    # with streaming audio output on Hugging Face Spaces. A possible approach:
    # https://huggingface.co/spaces/gradio/stream_audio_out/blob/main/app.py
    # (autoplay=True was left out). The reply audio is kept in `last_audio`
    # and played via the Replay button.
    audio_out = gr.Audio(label="Bot response (audio)")
    replay_btn = gr.Button("🔁 Replay", size="sm")

    # ── Observability panel ───────────────────────────────────────────
    with gr.Accordion("📊 Stats & Observability", open=False):
        stats_md = gr.Markdown(_metrics.snapshot_md())
        with gr.Row():
            refresh_btn = gr.Button("🔄 Refresh", size="sm")
            reset_btn = gr.Button("🗑 Reset", size="sm")
        refresh_btn.click(fn=_metrics.snapshot_md, outputs=[stats_md])
        reset_btn.click(fn=reset_metrics, outputs=[stats_md])

    # ── Event wiring ──────────────────────────────────────────────────
    # The Send button and Enter in the text box run the same handler.
    for _text_trigger in (send_btn.click, text_box.submit):
        _text_trigger(
            text_turn,
            inputs=[text_box, state],
            outputs=[chatbot, state, audio_out, last_audio, text_box, stats_md],
        )
    voice_btn.click(
        voice_turn,
        inputs=[mic_input, state],
        outputs=[chatbot, state, audio_out, last_audio, stats_md],
    )
    replay_btn.click(
        fn=lambda a: a,
        inputs=[last_audio],
        outputs=[audio_out],
    )

demo.queue()

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")