"""
Voice Chatbot — Hugging Face Space version (Mistral API + MLflow tracing).

Architecture:
  - LLM : Mistral Chat API (streaming SSE via httpx, traced manually)
  - STT : Mistral Transcriptions API (via the Mistral SDK when tracing is
          enabled — auto-traced by MLflow; raw httpx otherwise)
  - TTS : Mistral Speech API (via the Mistral SDK, not traced manually)

MLflow tracing captures every LLM and STT call with:
  - Input prompts / messages
  - Output completions / transcriptions
  - Token counts & latency
  - Model name & generation parameters

To activate, set these HF Space secrets:
  MISTRAL_API_KEY        — required
  MLFLOW_TRACKING_URI    — your cloud MLflow 3 server URL
  MLFLOW_EXPERIMENT_NAME — optional (default: mistral-chatbot)
  LLM_MODEL              — optional (default: mistral-small-latest)
  STT_MODEL              — optional (default: voxtral-mini-latest)
  TTS_MODEL              — optional (default: voxtral-mini-tts-2603)
  TTS_VOICE              — optional TTS voice ID
"""

import asyncio
import base64
import io
import json
import logging
import os
import re
import time

import gradio as gr
import httpx
import numpy as np
import soundfile as sf
from pathlib import Path
from mistralai.client import Mistral

# Root logging configuration for the whole app; every module-level helper
# below logs through the "app" logger.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger("app")

# ══════════════════════════════════════════════════════════════════════════
# CONFIGURATION — set via HF Space secrets (Settings → Repository secrets)
# ══════════════════════════════════════════════════════════════════════════

# Empty string when the secret is missing; requests are then sent with an
# empty bearer token.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY", "")

# ── LLM ─────────────────────────────────────────────────────────────────
LLM_MODEL = os.getenv("LLM_MODEL", "mistral-small-latest")
LLM_API_URL = os.getenv(
    "LLM_API_URL",
    "https://api.mistral.ai/v1/chat/completions",
)

# ── STT ─────────────────────────────────────────────────────────────────
STT_MODEL = os.getenv("STT_MODEL", "voxtral-mini-latest")
STT_API_URL = os.getenv(
    "STT_API_URL",
    "https://api.mistral.ai/v1/audio/transcriptions",
)
# ── TTS ─────────────────────────────────────────────────────────────────
TTS_API_URL = os.getenv(
    "TTS_API_URL",
    "https://api.mistral.ai/v1/audio/speech",
)
TTS_MODEL = os.getenv("TTS_MODEL", "voxtral-mini-tts-2603")
# Default voice is given as an ID.
# NOTE(review): a named voice ("marie_neutral") appears to have been used
# here previously — confirm which form the API expects.
TTS_VOICE = os.getenv("TTS_VOICE", "5a271406-039d-46fe-835b-fbbb00eaf08d")

# ── MLflow Tracing ──────────────────────────────────────────────────────
# 1. Deploy a cloud MLflow 3 Tracking Server (e.g. on a VM or managed).
# 2. Set MLFLOW_TRACKING_URI to its URL, e.g. "https://mlflow.example.com".
# 3. Optionally set MLFLOW_EXPERIMENT_NAME (default: "mistral-chatbot").
# 4. Both are read from environment variables (HF Space secrets).
#
# What gets traced automatically (via mlflow.mistral.autolog()):
#   - STT calls (go through the Mistral Python SDK)
#
# What gets traced manually (via mlflow.start_span):
#   - LLM streaming calls (Mistral autolog doesn't support streaming)
#
# What is NOT traced manually:
#   - TTS calls (made through the synchronous Mistral SDK client in
#     call_tts; low diagnostic value)
MLFLOW_TRACKING_URI = os.getenv("MLFLOW_TRACKING_URI", "")
MLFLOW_EXPERIMENT_NAME = os.getenv("MLFLOW_EXPERIMENT_NAME", "mistral-chatbot")

# ── Sentence boundary regex ────────────────────────────────────────────
# One or more sentence-ending punctuation marks followed by whitespace.
# The trailing whitespace is required, so a sentence is only considered
# complete once the start of the next token run has arrived.
SENTENCE_END = re.compile(r"[.?!…]+\s+")

# System message prepended to every chat request (French-only assistant).
SYSTEM_PROMPT = {
    "role": "system",
    "content": "Tu es un assistant d'agent de Bordeaux Metropole. Tu communique exclusisement en FRANCAIS. Tu gardes un ton courtois.",
}

if not MISTRAL_API_KEY:
    print("⚠️ MISTRAL_API_KEY not set — all API calls will fail.")

# ══════════════════════════════════════════════════════════════════════════
# MLflow setup (runs once at module import time)
# ══════════════════════════════════════════════════════════════════════════

# Guard flag — tracing is only active when a tracking URI is configured.
# It is flipped back to False below if initialisation fails.
MLFLOW_ENABLED = bool(MLFLOW_TRACKING_URI)

if MLFLOW_ENABLED:
    # The mlflow package is declared in requirements.txt; if it's somehow
    # missing the try/except degrades gracefully (tracing disabled).
    try:
        import mlflow

        # Point the MLflow client at your cloud server.
        # Example: mlflow.set_tracking_uri("https://mlflow.example.com")
        mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)

        # Create or reuse the experiment. Runs will be grouped under it.
        mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)

        # ── Auto-trace Mistral SDK calls ──────────────────────────────
        # This patches Mistral's Python SDK so every call to
        #   client.chat.complete()
        #   client.audio.transcriptions.complete()
        #   etc.
        # is automatically recorded as a trace in MLflow.
        #
        # LIMITATION: streaming chat is NOT auto-traced; we handle that
        # manually in the stream_reply() function further down.
        #
        # Doc: https://mlflow.org/docs/latest/genai/tracing/integrations/listing/mistral/
        mlflow.mistral.autolog()

        logger.info(
            "MLflow tracing ENABLED — tracking URI: %s, experiment: %s",
            MLFLOW_TRACKING_URI,
            MLFLOW_EXPERIMENT_NAME,
        )
    except Exception as exc:
        # Any failure (missing package, unreachable server, bad URI) turns
        # tracing off instead of preventing the app from starting.
        logger.warning(
            "MLflow init failed (%s) — tracing disabled. "
            "Set MLFLOW_TRACKING_URI to a valid MLflow 3 server URL.",
            exc,
        )
        MLFLOW_ENABLED = False

# ── Mistral SDK client (used for auto-traced calls) ─────────────────────
# We instantiate the async Mistral client, but only when MLflow tracing is
# enabled. Calls made through this client (e.g. audio transcriptions) are
# auto-traced by mlflow.mistral.autolog().
# The LLM streaming call uses raw httpx instead (see stream_llm) because
# the Mistral SDK + streaming is not supported by MLflow auto-tracing.
# `_mistral` stays None when tracing is off or the SDK client cannot be
# built; transcribe() then falls back to a raw httpx request.
# NOTE(review): the import path / class name below could not be checked
# against the installed mistralai version — confirm it resolves, otherwise
# the httpx fallback is always used.
if MLFLOW_ENABLED:
    try:
        from mistralai.async_client import MistralAsync as _MistralAsync

        _mistral = _MistralAsync(api_key=MISTRAL_API_KEY)
    except Exception as exc:
        logger.warning("Failed to create Mistral SDK client: %s", exc)
        _mistral = None
else:
    _mistral = None

# Sampling parameters of the chat request. Shared by stream_llm (request
# body) and the MLflow span inputs so the trace always reports exactly what
# was sent to the API.
_LLM_GENERATION_PARAMS = {
    "max_tokens": 512,
    "temperature": 0.7,
    "top_p": 0.9,
}


# ══════════════════════════════════════════════════════════════════════════
# IN-MEMORY METRICS (not related to MLflow — shown in the Gradio UI)
# ══════════════════════════════════════════════════════════════════════════

class Metrics:
    """Simple in-memory counters and accumulators for the UI stats panel.

    All ``record_*`` coroutines update the counters while holding
    ``self.lock``. ``snapshot_md`` only reads and takes no lock.
    """

    def __init__(self):
        self.lock = asyncio.Lock()
        self.reset()

    def reset(self):
        """Zero every counter and clear the error log (does not lock)."""
        self.stt_count = 0
        self.stt_total_s = 0.0
        self.llm_count = 0
        self.llm_total_s = 0.0
        self.tts_count = 0
        self.tts_total_s = 0.0
        self.total_tokens = 0
        self.error_count = 0
        self.last_errors: list[str] = []

    async def record_stt(self, elapsed: float):
        """Count one successful STT call that took *elapsed* seconds."""
        async with self.lock:
            self.stt_count += 1
            self.stt_total_s += elapsed

    async def record_llm(self, elapsed: float, tokens: int):
        """Count one completed LLM stream and the chunks it produced."""
        async with self.lock:
            self.llm_count += 1
            self.llm_total_s += elapsed
            self.total_tokens += tokens

    async def record_tts(self, elapsed: float):
        """Count one successful TTS call that took *elapsed* seconds."""
        async with self.lock:
            self.tts_count += 1
            self.tts_total_s += elapsed

    async def record_error(self, msg: str):
        """Count an error and keep only the 20 most recent messages."""
        async with self.lock:
            self.error_count += 1
            self.last_errors.append(msg)
            if len(self.last_errors) > 20:
                self.last_errors.pop(0)

    def snapshot_md(self) -> str:
        """Render the current counters as a Markdown table for the UI."""

        def avg(total, count):
            return f"{total/count:.2f}s" if count else "—"

        m = self
        lines = [
            "### 📊 Metrics",
            "| Service | Calls | Avg Latency | Total Time |",
            "|---------|-------|-------------|------------|",
            f"| **STT** | {m.stt_count} | {avg(m.stt_total_s, m.stt_count)} | {m.stt_total_s:.1f}s |",
            f"| **LLM** | {m.llm_count} | {avg(m.llm_total_s, m.llm_count)} | {m.llm_total_s:.1f}s |",
            f"| **TTS** | {m.tts_count} | {avg(m.tts_total_s, m.tts_count)} | {m.tts_total_s:.1f}s |",
            "",
            f"**Total tokens generated:** {m.total_tokens}",
            f"**Errors:** {m.error_count}",
        ]
        if m.last_errors:
            lines.append("\n**Recent errors:**")
            # Show the five latest errors, each truncated to 120 characters.
            for e in m.last_errors[-5:]:
                lines.append(f"- `{e[:120]}`")
        return "\n".join(lines)


_metrics = Metrics()


async def reset_metrics():
    """Reset all counters and return the refreshed Markdown snapshot.

    Takes the metrics lock, like every ``record_*`` method, so a reset
    cannot interleave with an update that is waiting on the lock.
    """
    async with _metrics.lock:
        _metrics.reset()
    return _metrics.snapshot_md()


# ══════════════════════════════════════════════════════════════════════════
# HELPERS
# ══════════════════════════════════════════════════════════════════════════

def to_chatbot(history: list) -> list[dict]:
    """Convert internal tuple history [(user, bot), …] → Gradio messages format.

    Empty user or assistant entries are skipped.
    """
    msgs = []
    for user_text, assistant_text in history:
        if user_text:
            msgs.append({"role": "user", "content": user_text})
        if assistant_text:
            msgs.append({"role": "assistant", "content": assistant_text})
    return msgs


def build_messages(history: list, user_text: str) -> list[dict]:
    """Build the messages array for the Mistral Chat API from conversation history.

    The result is: system prompt, every past (user, assistant) pair in
    order, then the new user message.
    """
    messages = [SYSTEM_PROMPT]
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_text})
    return messages


# ══════════════════════════════════════════════════════════════════════════
# STT — speech-to-text (Mistral Transcriptions API, auto-traced via SDK)
# ══════════════════════════════════════════════════════════════════════════

def _prepare_stt_audio(audio_path: str) -> tuple[bytes, str, str]:
    """Blocking helper: load *audio_path* as upload-ready bytes.

    Tries to convert the recording to 16 kHz mono WAV with ffmpeg. If that
    fails for any reason, the raw file is returned instead, labelled with
    its own file name and guessed MIME type rather than as WAV.

    Returns:
        (audio_bytes, file_name, content_type)

    Raises:
        OSError: if the fallback read of *audio_path* itself fails.
    """
    import mimetypes
    import subprocess
    import tempfile

    wav_path = None
    try:
        # Reserve a temp file name; ffmpeg overwrites it (-y).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            wav_path = tmp.name
        subprocess.run(
            ["ffmpeg", "-y", "-i", audio_path, "-ar", "16000", "-ac", "1", wav_path],
            capture_output=True,
            check=True,
        )
        with open(wav_path, "rb") as f:
            return f.read(), "audio.wav", "audio/wav"
    except Exception as e:
        logger.warning("ffmpeg conversion failed, sending raw audio: %s", e)
        file_name = os.path.basename(audio_path) or "audio.wav"
        content_type = mimetypes.guess_type(file_name)[0] or "application/octet-stream"
        with open(audio_path, "rb") as f:
            return f.read(), file_name, content_type
    finally:
        if wav_path and os.path.exists(wav_path):
            os.unlink(wav_path)


async def transcribe(audio_path: str) -> str:
    """Convert audio to 16 kHz mono WAV and transcribe via Mistral STT API.

    When the Mistral SDK client is available (MLflow tracing enabled), the
    call goes through the SDK and is captured by mlflow.mistral.autolog().
    Otherwise a raw httpx multipart request is sent to STT_API_URL.

    Args:
        audio_path: path of the recorded audio file (browser mic output).

    Returns:
        The transcribed text.

    Raises:
        Exception: any API/transport error is recorded in the metrics and
            re-raised to the caller.
    """
    t0 = time.perf_counter()

    # ── Audio preprocessing ───────────────────────────────────────────
    # ffmpeg and file reads are blocking; run them in a worker thread so
    # other sessions keep streaming while this one converts.
    audio_bytes, file_name, content_type = await asyncio.to_thread(
        _prepare_stt_audio, audio_path
    )

    # ── Transcribe ────────────────────────────────────────────────────
    try:
        if _mistral is not None:
            # Mistral SDK call — auto-traced by MLflow via autolog().
            result = await _mistral.audio.transcriptions.complete(
                model=STT_MODEL,
                file={"content": audio_bytes, "file_name": file_name},
            )
            text = result.text
        else:
            # Fallback: raw httpx (no tracing).
            headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}"}
            async with httpx.AsyncClient(timeout=120.0) as client:
                resp = await client.post(
                    STT_API_URL,
                    headers=headers,
                    files={
                        "file": (file_name, audio_bytes, content_type),
                        "model": (None, STT_MODEL),
                    },
                )
                resp.raise_for_status()
                text = resp.json()["text"]

        # Latency includes the ffmpeg conversion above.
        await _metrics.record_stt(time.perf_counter() - t0)
        logger.info(
            "STT ok %.2fs %.0f bytes model=%s",
            time.perf_counter() - t0,
            len(audio_bytes),
            STT_MODEL,
        )
        return text
    except Exception as e:
        await _metrics.record_error(f"STT: {e}")
        raise


# ══════════════════════════════════════════════════════════════════════════
# TTS — text-to-speech (Mistral Speech API via the Mistral SDK, not traced)
# ══════════════════════════════════════════════════════════════════════════

async def call_tts(client: Mistral, text: str) -> tuple | None:
    """Synthesise speech via Mistral TTS API, return (sample_rate, numpy_array).

    Uses the official Mistral SDK (audio.speech.complete) and decodes the
    base64 audio payload with soundfile.

    Args:
        client: an open Mistral SDK client.
        text: the text to speak.

    Returns:
        ``(sample_rate, float32 samples)`` on success, ``None`` on any
        failure (the error is logged and counted, never raised).
    """
    t0 = time.perf_counter()
    try:
        kwargs: dict = {
            "model": TTS_MODEL,
            "input": text,
            "response_format": "mp3",
        }
        if TTS_VOICE:
            kwargs["voice_id"] = TTS_VOICE

        # SDK call is synchronous — run in a thread to avoid blocking the
        # event loop.
        response = await asyncio.to_thread(client.audio.speech.complete, **kwargs)

        audio_bytes = base64.b64decode(response.audio_data)
        audio_np, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32")

        elapsed = time.perf_counter() - t0
        await _metrics.record_tts(elapsed)
        logger.info("TTS ok %.2fs %d chars model=%s", elapsed, len(text), TTS_MODEL)
        return sr, audio_np
    except Exception as e:
        # Best effort: a failed sentence just produces no audio.
        await _metrics.record_error(f"TTS: {e}")
        logger.warning("TTS failed (%.1fs): %s", time.perf_counter() - t0, e)
        return None


async def _synthesize(text: str) -> tuple | None:
    """Open a short-lived Mistral SDK client and speak *text* via call_tts."""
    async with Mistral(api_key=MISTRAL_API_KEY) as client:
        return await call_tts(client, text)


# ══════════════════════════════════════════════════════════════════════════
# LLM — streaming text generation (Mistral Chat API via httpx SSE)
# ══════════════════════════════════════════════════════════════════════════

async def stream_llm(messages: list[dict]):
    """Stream tokens from Mistral Chat API via Server-Sent Events.

    Yields (token, cumulative_token_count). The count is the number of
    non-empty content chunks received, not a tokenizer count.

    Errors are reported in-band: a single item whose text starts with "⚠️"
    and whose count is 0 is yielded, then the generator ends.

    Uses raw httpx because MLflow's Mistral autolog does NOT support
    streaming chat completions. The caller (stream_reply) records the
    MLflow span manually.
    """
    headers = {
        "Authorization": f"Bearer {MISTRAL_API_KEY}",
        "Content-Type": "application/json",
    }
    body = {
        "model": LLM_MODEL,
        "messages": messages,
        "stream": True,
        **_LLM_GENERATION_PARAMS,
    }

    t0 = time.perf_counter()
    token_count = 0

    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(30.0, read=120.0)) as client:
            async with client.stream(
                "POST", LLM_API_URL, json=body, headers=headers,
            ) as resp:
                if not resp.is_success:
                    # Buffer the error body while the stream is still open:
                    # once raise_for_status() unwinds this context the
                    # response is closed and can no longer be read.
                    await resp.aread()
                resp.raise_for_status()

                async for line in resp.aiter_lines():
                    # SSE data lines; the space after the colon is optional.
                    if not line.startswith("data:"):
                        continue
                    payload = line[5:].strip()
                    if payload == "[DONE]":
                        break
                    if not payload:
                        continue
                    try:
                        data = json.loads(payload)
                    except json.JSONDecodeError:
                        # Skip malformed chunks rather than abort the reply.
                        continue
                    choices = data.get("choices", [])
                    if not choices:
                        continue
                    content = choices[0].get("delta", {}).get("content")
                    if content:
                        token_count += 1
                        yield content, token_count

        elapsed = time.perf_counter() - t0
        await _metrics.record_llm(elapsed, token_count)
        logger.info(
            "LLM ok %.2fs %d tokens %.1f tok/s model=%s",
            elapsed,
            token_count,
            token_count / elapsed if elapsed else 0,
            LLM_MODEL,
        )

    except httpx.HTTPStatusError as e:
        # Body was buffered before raise_for_status(), so .content is safe.
        detail = e.response.content.decode(errors="replace")[:300]
        elapsed = time.perf_counter() - t0
        await _metrics.record_error(
            f"LLM HTTP {e.response.status_code}: {detail[:80]}"
        )
        logger.error("LLM HTTP %s %.1fs %s", e.response.status_code, elapsed, detail)
        yield f"⚠️ LLM API error ({e.response.status_code}): {detail}", 0

    except Exception as e:
        elapsed = time.perf_counter() - t0
        await _metrics.record_error(f"LLM: {type(e).__name__}: {e}")
        logger.error("LLM error %.1fs %s", elapsed, e)
        yield f"⚠️ LLM error: {type(e).__name__}: {e}", 0


# ══════════════════════════════════════════════════════════════════════════
# STREAM ORCHESTRATOR — consumes LLM tokens, records MLflow trace, calls TTS
# ══════════════════════════════════════════════════════════════════════════

def _record_llm_trace(messages, reply, token_count, elapsed, error):
    """Record one manual MLflow span for a finished (or failed) LLM stream.

    Never raises: if MLflow is unreachable or misbehaves, a warning is
    logged and the app carries on.
    """
    try:
        import mlflow

        # start_span outside of a trace context creates a new root span
        # (i.e. a new trace).
        with mlflow.start_span("llm_chat_stream") as span:
            span.set_inputs({
                "model": LLM_MODEL,
                "messages": messages,
                **_LLM_GENERATION_PARAMS,
            })
            span.set_outputs({
                "response": reply,
                "token_count": token_count,
                "latency_seconds": round(elapsed, 3),
                "tokens_per_second": round(token_count / elapsed, 1)
                if elapsed > 0 else 0,
            })
            if error:
                span.set_attribute("error_message", str(error))
                span.set_status("ERROR")
    except Exception as trace_err:
        # Don't crash the app if the MLflow server is unreachable.
        logger.warning("MLflow trace recording failed: %s", trace_err)


async def stream_reply(messages: list[dict]):
    """Consume LLM stream, record an MLflow trace, and synthesize audio.

    Yields (partial_reply, audio_tuple_or_None). When stream_llm reports an
    error, the error text is yielded as the reply and the generator ends.

    MLflow tracing:
      - The Mistral auto-log does not cover streaming completions.
      - A span holding the input messages and the aggregated output is
        recorded in the ``finally`` block, i.e. when the stream ends, fails,
        or the consumer stops iterating.
      - A stream that fails before producing any text is traced too, with
        error status.
      - If MLflow is not enabled, no trace is recorded.
    """
    token_buffer = ""
    full_reply = ""
    _trace_token_count = 0
    _trace_start = time.perf_counter()
    _trace_error = None

    try:
        async for token_or_error, token_count in stream_llm(messages):
            # ── Error in stream → stop & propagate ───────────────────
            if token_or_error.startswith("⚠️"):
                _trace_error = token_or_error
                yield token_or_error, None
                return

            _trace_token_count = token_count
            token_buffer += token_or_error
            full_reply += token_or_error

            # ── Sentence-boundary TTS ─────────────────────────────────
            # When a complete sentence has arrived, dispatch a TTS call
            # so audio is produced before the full reply finishes.
            match = SENTENCE_END.search(token_buffer)
            if match:
                sentence = token_buffer[: match.end()].strip()
                token_buffer = token_buffer[match.end():]
                if sentence:
                    audio = await _synthesize(sentence)
                    if audio is not None:
                        yield full_reply, audio
                        continue

            yield full_reply, None

        # ── Remaining text after last sentence boundary ───────────────
        if token_buffer.strip():
            audio = await _synthesize(token_buffer.strip())
            if audio is not None:
                yield full_reply, audio
                return

        yield full_reply, None

    finally:
        # ── MLflow manual trace ───────────────────────────────────────
        if MLFLOW_ENABLED and (full_reply or _trace_error):
            _record_llm_trace(
                messages,
                full_reply,
                _trace_token_count,
                time.perf_counter() - _trace_start,
                _trace_error,
            )


# ══════════════════════════════════════════════════════════════════════════
# GRADIO TURN HANDLERS
# ══════════════════════════════════════════════════════════════════════════

async def run_turn(user_text: str, history: list):
    """Shared generator: streams chatbot text updates and per-sentence audio.

    Yields 5-tuples ``(chatbot_messages, history, audio_out, last_audio,
    stats_update)``. While streaming, both audio slots are ``None``; the
    final item carries the concatenated audio in the ``last_audio`` slot.
    """
    history = history or []
    messages = build_messages(history, user_text)
    new_history = history + [(user_text, "▌")]

    audio_chunks: list[np.ndarray] = []
    audio_sr = None

    async for partial_reply, audio in stream_reply(messages):
        # "▌" is a typing cursor shown while the reply is still streaming.
        new_history[-1] = (user_text, partial_reply + "▌")
        if audio is not None:
            sr, arr = audio
            audio_sr = sr
            audio_chunks.append(arr)
        yield to_chatbot(new_history), new_history, None, None, gr.update()

    # Strip the streaming cursor from the final text.
    final_text = new_history[-1][1].rstrip("▌").rstrip()
    new_history[-1] = (user_text, final_text)

    # Join the per-sentence clips into one clip (uses the last sample rate
    # seen).
    full_audio = (
        (audio_sr, np.concatenate(audio_chunks)) if audio_chunks else None
    )
    yield to_chatbot(new_history), new_history, None, full_audio, gr.update()


async def text_turn(user_text: str, history: list):
    """Handle a typed message; yields 6-tuples matching the text event outputs.

    Blank input leaves everything unchanged. Otherwise the text box is
    cleared on the first streamed update.
    """
    if not user_text.strip():
        yield to_chatbot(history or []), history or [], None, None, user_text, gr.update()
        return
    first = True
    async for h, s, a, la, _ in run_turn(user_text, history):
        yield h, s, a, la, ("" if first else gr.update()), gr.update()
        first = False


async def voice_turn(audio_path: str, history: list):
    """Handle a recorded message: transcribe it, then run a normal turn.

    Yields 5-tuples matching the voice event outputs. An STT failure is
    shown in the chat but not stored in the conversation state.
    """
    if audio_path is None:
        yield to_chatbot(history or []), history or [], None, None, gr.update()
        return
    try:
        user_text = await transcribe(audio_path)
    except Exception as e:
        logger.exception("STT error")
        err = f"⚠️ STT error: {type(e).__name__}: {e}"
        yield to_chatbot((history or []) + [("[voice]", err)]), history or [], None, None, gr.update()
        return
    async for h, s, a, la, _ in run_turn(user_text, history):
        yield h, s, a, la, gr.update()


# ══════════════════════════════════════════════════════════════════════════
# GRADIO UI
# ══════════════════════════════════════════════════════════════════════════

DESCRIPTION = f"""## Voice Chatbot
_API LLM: `{LLM_MODEL}` · STT: `{STT_MODEL}` · TTS: `{TTS_MODEL}`_
"""

with gr.Blocks(
    title="Voice Chatbot",
    theme=gr.themes.Soft(),
    css="footer {visibility: hidden}",
) as demo:
    gr.Markdown(DESCRIPTION)

    chatbot = gr.Chatbot(label="Conversation", height=380)
    state = gr.State([])        # conversation as [(user, assistant), …]
    last_audio = gr.State(None)  # full audio of the latest reply, for Replay

    with gr.Row():
        text_box = gr.Textbox(
            placeholder="Type a message and press Enter …",
            show_label=False,
            scale=5,
        )
        send_btn = gr.Button("Send", scale=1, variant="primary")

    with gr.Row():
        mic_input = gr.Audio(
            sources=["microphone"],
            type="filepath",
            label="Voice input",
            scale=5,
        )
        voice_btn = gr.Button("Submit voice", scale=1)

    # Known issue: streaming audio output on Hugging Face Spaces.
    # TODO: implement along the lines of the Gradio reference Space
    # https://huggingface.co/spaces/gradio/stream_audio_out/blob/main/app.py
    # (the original note also mentions `autoplay=True`).
    audio_out = gr.Audio(label="Bot response (audio)")
    replay_btn = gr.Button("🔁 Replay", size="sm")

    # ── Observability panel ───────────────────────────────────────────
    with gr.Accordion("📊 Stats & Observability", open=False):
        stats_md = gr.Markdown(_metrics.snapshot_md())
        with gr.Row():
            refresh_btn = gr.Button("🔄 Refresh", size="sm")
            reset_btn = gr.Button("🗑 Reset", size="sm")
        refresh_btn.click(fn=_metrics.snapshot_md, outputs=[stats_md])
        reset_btn.click(fn=reset_metrics, outputs=[stats_md])

    # ── Event wiring ──────────────────────────────────────────────────
    send_btn.click(
        text_turn,
        inputs=[text_box, state],
        outputs=[chatbot, state, audio_out, last_audio, text_box, stats_md],
    )
    text_box.submit(
        text_turn,
        inputs=[text_box, state],
        outputs=[chatbot, state, audio_out, last_audio, text_box, stats_md],
    )
    voice_btn.click(
        voice_turn,
        inputs=[mic_input, state],
        outputs=[chatbot, state, audio_out, last_audio, stats_md],
    )
    # Replay pushes the stored full reply audio back into the player.
    replay_btn.click(
        fn=lambda a: a,
        inputs=[last_audio],
        outputs=[audio_out],
    )

demo.queue()

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")