"""
Voice Chatbot — Hugging Face Space version (Mistral API + MLflow tracing).

Architecture:
  - LLM : Mistral Chat API (streaming SSE via httpx, traced manually)
  - STT : Mistral Transcriptions API (via Mistral SDK — auto-traced by MLflow)
  - TTS : Mistral Speech API (via httpx, not traced)

MLflow tracing captures every LLM and STT call with:
  - Input prompts / messages
  - Output completions / transcriptions
  - Token counts & latency
  - Model name & generation parameters

To activate, set these HF Space secrets:
  MISTRAL_API_KEY        — required
  MLFLOW_TRACKING_URI    — your cloud MLflow 3 server URL
  MLFLOW_EXPERIMENT_NAME — optional (default: mistral-chatbot)
  LLM_MODEL              — optional (default: mistral-small-latest)
  STT_MODEL              — optional (default: voxtral-mini-latest)
  TTS_MODEL              — optional (default: voxtral-mini-tts-2603)
  TTS_VOICE              — optional TTS voice ID
"""

import asyncio
import base64
import io
import json
import logging
import mimetypes
import os
import re
import subprocess
import tempfile
import time
from collections.abc import AsyncIterator

from gradio import Server
import httpx
import soundfile as sf

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger("app")

# ══════════════════════════════════════════════════════════════════════════
# CONFIGURATION — set via HF Space secrets (Settings → Repository secrets)
# ══════════════════════════════════════════════════════════════════════════

MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY", "")

# ── LLM ─────────────────────────────────────────────────────────────────
LLM_MODEL = os.getenv("LLM_MODEL", "mistral-small-latest")
LLM_API_URL = os.getenv(
    "LLM_API_URL",
    "https://api.mistral.ai/v1/chat/completions",
)
# Generation parameters: shared by the request body and the MLflow trace so
# the two can never disagree.
LLM_MAX_TOKENS = 512
LLM_TEMPERATURE = 0.7
LLM_TOP_P = 0.9

# ── STT ─────────────────────────────────────────────────────────────────
STT_MODEL = os.getenv("STT_MODEL", "voxtral-mini-latest")
STT_API_URL = os.getenv(
    "STT_API_URL",
    "https://api.mistral.ai/v1/audio/transcriptions",
)

# ── TTS ─────────────────────────────────────────────────────────────────
TTS_API_URL = os.getenv(
    "TTS_API_URL",
    "https://api.mistral.ai/v1/audio/speech",
)
TTS_MODEL = os.getenv("TTS_MODEL", "voxtral-mini-tts-2603")
TTS_VOICE = os.getenv("TTS_VOICE", "")

# ── MLflow Tracing ──────────────────────────────────────────────────────
MLFLOW_TRACKING_URI = os.getenv("MLFLOW_TRACKING_URI", "")
MLFLOW_EXPERIMENT_NAME = os.getenv("MLFLOW_EXPERIMENT_NAME", "mistral-chatbot")

# ── Sentence boundary regex ────────────────────────────────────────────
# One or more terminators followed by whitespace; the trailing whitespace is
# required so that a sentence is only cut once the next token has started.
SENTENCE_END = re.compile(r"[.?!…]+\s+")

SYSTEM_PROMPT = {
    "role": "system",
    "content": "You are a helpful and concise voice assistant.",
}

# Prefix used by stream_llm() to signal a failure in-band.
_ERROR_PREFIX = "[ERROR]"

# Number of recent error messages retained by Metrics.
_MAX_STORED_ERRORS = 20

if not MISTRAL_API_KEY:
    print("WARNING: MISTRAL_API_KEY not set — all API calls will fail.")

# ══════════════════════════════════════════════════════════════════════════
# MLflow setup (runs once at module import time)
# ══════════════════════════════════════════════════════════════════════════

MLFLOW_ENABLED = bool(MLFLOW_TRACKING_URI)

if MLFLOW_ENABLED:
    try:
        import mlflow

        mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
        mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)
        mlflow.mistral.autolog()
        logger.info(
            "MLflow tracing ENABLED — tracking URI: %s, experiment: %s",
            MLFLOW_TRACKING_URI,
            MLFLOW_EXPERIMENT_NAME,
        )
    except Exception as exc:
        # Tracing is optional: any failure here must not stop the app.
        logger.warning(
            "MLflow init failed (%s) — tracing disabled. "
            "Set MLFLOW_TRACKING_URI to a valid MLflow 3 server URL.",
            exc,
        )
        MLFLOW_ENABLED = False

# ── Mistral SDK client (used for auto-traced STT calls) ─────────────────
# Only created when tracing is on; otherwise STT goes through plain httpx.
_mistral = None
if MLFLOW_ENABLED:
    try:
        # NOTE(review): import path and the awaitable `.complete()` used in
        # transcribe() depend on the installed mistralai SDK version — confirm.
        from mistralai.async_client import MistralAsync as _MistralAsync

        _mistral = _MistralAsync(api_key=MISTRAL_API_KEY)
    except Exception as exc:
        logger.warning("Failed to create Mistral SDK client: %s", exc)
        _mistral = None


# ══════════════════════════════════════════════════════════════════════════
# IN-MEMORY METRICS
# ══════════════════════════════════════════════════════════════════════════

class Metrics:
    """Simple in-memory counters and accumulators.

    The ``record_*`` coroutines update the counters while holding
    ``self.lock``; ``reset`` and ``snapshot`` are plain synchronous methods.
    """

    def __init__(self):
        self.lock = asyncio.Lock()
        self.reset()

    def reset(self):
        """Zero every counter and clear the stored error messages."""
        self.stt_count = 0
        self.stt_total_s = 0.0
        self.llm_count = 0
        self.llm_total_s = 0.0
        self.tts_count = 0
        self.tts_total_s = 0.0
        self.total_tokens = 0
        self.error_count = 0
        self.last_errors: list[str] = []

    async def record_stt(self, elapsed: float):
        """Count one STT call that took *elapsed* seconds."""
        async with self.lock:
            self.stt_count += 1
            self.stt_total_s += elapsed

    async def record_llm(self, elapsed: float, tokens: int):
        """Count one LLM call of *elapsed* seconds that produced *tokens*."""
        async with self.lock:
            self.llm_count += 1
            self.llm_total_s += elapsed
            self.total_tokens += tokens

    async def record_tts(self, elapsed: float):
        """Count one TTS call that took *elapsed* seconds."""
        async with self.lock:
            self.tts_count += 1
            self.tts_total_s += elapsed

    async def record_error(self, msg: str):
        """Count an error and keep *msg* among the most recent ones."""
        async with self.lock:
            self.error_count += 1
            self.last_errors.append(msg)
            # Bounded history: drop the oldest entry beyond the cap.
            if len(self.last_errors) > _MAX_STORED_ERRORS:
                self.last_errors.pop(0)

    def snapshot(self) -> dict:
        """Return the current counters as a JSON-friendly dict.

        Average latencies are ``None`` for a stage with no recorded calls;
        only the five most recent error messages are included.
        """

        def avg(total, count):
            return round(total / count, 3) if count else None

        def stage(calls, total_s):
            return {
                "calls": calls,
                "avg_latency_s": avg(total_s, calls),
                "total_s": round(total_s, 2),
            }

        return {
            "stt": stage(self.stt_count, self.stt_total_s),
            "llm": stage(self.llm_count, self.llm_total_s),
            "tts": stage(self.tts_count, self.tts_total_s),
            "total_tokens": self.total_tokens,
            "errors": self.error_count,
            "last_errors": self.last_errors[-5:],
        }


_metrics = Metrics()


# ══════════════════════════════════════════════════════════════════════════
# HELPERS
# ══════════════════════════════════════════════════════════════════════════

def build_messages(history: list[dict], user_text: str) -> list[dict]:
    """Build the messages array for the Mistral Chat API.

    The result is the system prompt, then a copy of *history* (``None`` is
    treated as empty), then *user_text* as the final user message.
    """
    messages = [SYSTEM_PROMPT, *(history or [])]
    messages.append({"role": "user", "content": user_text})
    return messages


def _read_bytes(path: str) -> bytes:
    """Return the full contents of *path* (blocking)."""
    with open(path, "rb") as f:
        return f.read()


def _convert_to_wav(audio_path: str) -> bytes:
    """Convert *audio_path* to 16 kHz mono WAV with ffmpeg (blocking).

    Returns the WAV bytes. The temporary output file is always removed.
    Raises whatever ``subprocess.run``/file access raises on failure.
    """
    wav_path = None
    try:
        # delete=False: the file must outlive the handle so ffmpeg can write it.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            wav_path = tmp.name
        subprocess.run(
            ["ffmpeg", "-y", "-i", audio_path, "-ar", "16000", "-ac", "1", wav_path],
            stdin=subprocess.DEVNULL,  # never let ffmpeg wait on our stdin
            capture_output=True,
            check=True,
        )
        return _read_bytes(wav_path)
    finally:
        if wav_path and os.path.exists(wav_path):
            os.unlink(wav_path)


# ══════════════════════════════════════════════════════════════════════════
# STT — speech-to-text
# ══════════════════════════════════════════════════════════════════════════

async def transcribe(audio_path: str) -> str:
    """Convert audio to 16 kHz mono WAV and transcribe via Mistral STT API.

    If the ffmpeg conversion fails the original file is sent unchanged,
    labelled with its own file name and guessed MIME type. The SDK client is
    used when available, otherwise a multipart POST via httpx.

    Returns the transcribed text. Any failure is counted in the metrics and
    re-raised to the caller.
    """
    t0 = time.perf_counter()
    try:
        # Conversion and file reads are blocking, so keep them off the loop.
        try:
            audio_bytes = await asyncio.to_thread(_convert_to_wav, audio_path)
            file_name, mime_type = "audio.wav", "audio/wav"
        except Exception as e:  # deliberate best-effort fallback
            logger.warning("ffmpeg conversion failed, sending raw audio: %s", e)
            audio_bytes = await asyncio.to_thread(_read_bytes, audio_path)
            file_name = os.path.basename(audio_path) or "audio.wav"
            mime_type = mimetypes.guess_type(file_name)[0] or "application/octet-stream"

        if _mistral is not None:
            result = await _mistral.audio.transcriptions.complete(
                model=STT_MODEL,
                file={"content": audio_bytes, "file_name": file_name},
            )
            text = result.text
        else:
            headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}"}
            async with httpx.AsyncClient(timeout=120.0) as client:
                resp = await client.post(
                    STT_API_URL,
                    headers=headers,
                    files={
                        "file": (file_name, audio_bytes, mime_type),
                        "model": (None, STT_MODEL),
                    },
                )
                resp.raise_for_status()
                text = resp.json()["text"]

        elapsed = time.perf_counter() - t0
        await _metrics.record_stt(elapsed)
        logger.info("STT ok %.2fs %d bytes model=%s", elapsed, len(audio_bytes), STT_MODEL)
        return text
    except Exception as e:
        await _metrics.record_error(f"STT: {e}")
        raise


# ══════════════════════════════════════════════════════════════════════════
# TTS — text-to-speech
# ══════════════════════════════════════════════════════════════════════════

async def call_tts(client: httpx.AsyncClient, text: str) -> str | None:
    """Synthesise speech via Mistral TTS API. Returns base64-encoded WAV string.

    Never raises: on any failure the error is counted and logged and ``None``
    is returned, so a TTS outage degrades to a text-only reply.
    """
    t0 = time.perf_counter()
    try:
        headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}"}
        body: dict = {
            "model": TTS_MODEL,
            "input": text,
            "response_format": "wav",
        }
        if TTS_VOICE:
            body["voice_id"] = TTS_VOICE

        resp = await client.post(TTS_API_URL, headers=headers, json=body, timeout=60.0)
        resp.raise_for_status()
        data = resp.json()

        elapsed = time.perf_counter() - t0
        await _metrics.record_tts(elapsed)
        logger.info("TTS ok %.2fs %d chars model=%s", elapsed, len(text), TTS_MODEL)
        return data["audio_data"]  # base64-encoded WAV
    except Exception as e:
        await _metrics.record_error(f"TTS: {e}")
        logger.warning("TTS failed (%.1fs): %s", time.perf_counter() - t0, e)
        return None


# ══════════════════════════════════════════════════════════════════════════
# LLM — streaming text generation
# ══════════════════════════════════════════════════════════════════════════

async def stream_llm(messages: list[dict]) -> AsyncIterator[tuple[str, int]]:
    """Stream tokens from Mistral Chat API via SSE. Yields (token, cumulative_count).

    ``cumulative_count`` counts content chunks received so far (starting at
    1), not tokenizer tokens. Failures are reported in-band: a single
    ``("[ERROR] ...", 0)`` tuple is yielded and the generator ends.
    """
    headers = {
        "Authorization": f"Bearer {MISTRAL_API_KEY}",
        "Content-Type": "application/json",
    }
    body = {
        "model": LLM_MODEL,
        "messages": messages,
        "stream": True,
        "max_tokens": LLM_MAX_TOKENS,
        "temperature": LLM_TEMPERATURE,
        "top_p": LLM_TOP_P,
    }

    t0 = time.perf_counter()
    token_count = 0
    try:
        timeout = httpx.Timeout(30.0, read=120.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            async with client.stream("POST", LLM_API_URL, json=body, headers=headers) as resp:
                if resp.status_code >= 300:
                    # A streamed body can only be read while the stream is
                    # open; load it now so the error handler below can use it.
                    await resp.aread()
                resp.raise_for_status()

                async for line in resp.aiter_lines():
                    # SSE allows "data:" with or without a following space.
                    if not line.startswith("data:"):
                        continue
                    payload = line[5:].strip()
                    if payload == "[DONE]":
                        break
                    if not payload:
                        continue
                    try:
                        data = json.loads(payload)
                    except json.JSONDecodeError:
                        continue  # skip malformed events, keep streaming
                    choices = data.get("choices", [])
                    if not choices:
                        continue
                    content = choices[0].get("delta", {}).get("content")
                    # Only plain text deltas are forwarded to the caller.
                    if content and isinstance(content, str):
                        token_count += 1
                        yield content, token_count

        elapsed = time.perf_counter() - t0
        await _metrics.record_llm(elapsed, token_count)
        logger.info(
            "LLM ok %.2fs %d tokens %.1f tok/s model=%s",
            elapsed,
            token_count,
            token_count / elapsed if elapsed else 0,
            LLM_MODEL,
        )
    except httpx.HTTPStatusError as e:
        # Body was loaded before raise_for_status(), so .content is available.
        detail = e.response.content.decode(errors="replace")[:300]
        elapsed = time.perf_counter() - t0
        await _metrics.record_error(f"LLM HTTP {e.response.status_code}: {detail[:80]}")
        logger.error("LLM HTTP %s %.1fs %s", e.response.status_code, elapsed, detail)
        yield f"[ERROR] LLM API error ({e.response.status_code}): {detail}", 0
    except Exception as e:
        elapsed = time.perf_counter() - t0
        await _metrics.record_error(f"LLM: {type(e).__name__}: {e}")
        logger.error("LLM error %.1fs %s", elapsed, e)
        yield f"[ERROR] LLM error: {type(e).__name__}: {e}", 0


# ══════════════════════════════════════════════════════════════════════════
# STREAM ORCHESTRATOR
# ══════════════════════════════════════════════════════════════════════════

def _record_llm_trace(
    messages: list[dict],
    reply: str,
    token_count: int,
    elapsed: float,
    error: str | None,
) -> None:
    """Record one finished LLM turn as an MLflow span; never raises."""
    try:
        import mlflow

        with mlflow.start_span("llm_chat_stream") as span:
            span.set_inputs({
                "model": LLM_MODEL,
                "messages": messages,
                "temperature": LLM_TEMPERATURE,
                "max_tokens": LLM_MAX_TOKENS,
            })
            span.set_outputs({
                "response": reply,
                "token_count": token_count,
                "latency_seconds": round(elapsed, 3),
                "tokens_per_second": round(token_count / elapsed, 1) if elapsed > 0 else 0,
            })
            if error:
                span.set_attribute("error", str(error))
                span.set_status("ERROR")
    except Exception as trace_err:
        # Tracing is best-effort and must not break the reply stream.
        logger.warning("MLflow trace recording failed: %s", trace_err)


async def stream_reply(messages: list[dict]):
    """Consume LLM stream, record MLflow trace, and synthesize audio.

    Each completed sentence is sent to TTS as soon as it is detected; text
    left over at the end of the stream is synthesised in the final chunk.

    Yields dicts: {"reply": str, "audio_b64": str|None, "done": bool}
    """
    token_buffer = ""   # text not yet sent to TTS
    full_reply = ""     # everything received so far
    trace_token_count = 0
    trace_start = time.perf_counter()
    trace_error = None

    try:
        # One client for every TTS call of this reply (connection reuse).
        async with httpx.AsyncClient() as tts_client:
            async for token_or_error, token_count in stream_llm(messages):
                # Error chunks always carry count 0; real tokens start at 1,
                # so model text that happens to begin with "[ERROR]" is safe.
                if token_count == 0 and token_or_error.startswith(_ERROR_PREFIX):
                    trace_error = token_or_error
                    yield {"reply": token_or_error, "audio_b64": None, "done": True}
                    return

                trace_token_count = token_count
                token_buffer += token_or_error
                full_reply += token_or_error

                match = SENTENCE_END.search(token_buffer)
                if match:
                    sentence = token_buffer[: match.end()].strip()
                    token_buffer = token_buffer[match.end():]
                    if sentence:
                        audio_b64 = await call_tts(tts_client, sentence)
                        yield {"reply": full_reply, "audio_b64": audio_b64, "done": False}
                        continue

                yield {"reply": full_reply, "audio_b64": None, "done": False}

            # Flush remaining text
            remainder = token_buffer.strip()
            audio_b64 = await call_tts(tts_client, remainder) if remainder else None
            yield {"reply": full_reply, "audio_b64": audio_b64, "done": True}
    finally:
        # Runs on normal completion, on error, and when the consumer stops early.
        if MLFLOW_ENABLED and (full_reply or trace_error):
            _record_llm_trace(
                messages,
                full_reply,
                trace_token_count,
                time.perf_counter() - trace_start,
                trace_error,
            )


# ══════════════════════════════════════════════════════════════════════════
# GRADIO SERVER
# ══════════════════════════════════════════════════════════════════════════

app = Server(
    title="Voice Chatbot",
    description=f"API LLM: `{LLM_MODEL}` · STT: `{STT_MODEL}` · TTS: `{TTS_MODEL}`",
)


# NOTE(review): the streaming endpoints are async generators although they are
# annotated `-> dict` (the shape of each yielded chunk); the annotation is left
# untouched in case the server inspects it.
@app.api(name="text_turn")
async def text_turn(user_text: str, history: list[dict]) -> dict:
    """Send a text message and stream back reply chunks with optional audio.

    history: list of {"role": "user"|"assistant", "content": str}
    Yields: {"reply": str, "audio_b64": str|None, "done": bool}
    """
    if not (user_text or "").strip():
        yield {"reply": "", "audio_b64": None, "done": True}
        return

    messages = build_messages(history, user_text)
    async for chunk in stream_reply(messages):
        yield chunk


@app.api(name="voice_turn")
async def voice_turn(audio_path: str, history: list[dict]) -> dict:
    """Transcribe audio then stream back a reply.

    audio_path: local file path to recorded audio.
    history: list of {"role": "user"|"assistant", "content": str}
    Yields: {"reply": str, "audio_b64": str|None, "done": bool, "user_text": str}
    """
    if not audio_path:
        yield {"reply": "", "audio_b64": None, "done": True, "user_text": ""}
        return

    try:
        user_text = await transcribe(audio_path)
    except Exception as e:
        err = f"[ERROR] STT error: {type(e).__name__}: {e}"
        logger.exception("STT error")
        yield {"reply": err, "audio_b64": None, "done": True, "user_text": ""}
        return

    # Same guard as text_turn: nothing was said, so do not query the LLM.
    if not (user_text or "").strip():
        yield {"reply": "", "audio_b64": None, "done": True, "user_text": user_text or ""}
        return

    messages = build_messages(history, user_text)
    async for chunk in stream_reply(messages):
        yield {**chunk, "user_text": user_text}


@app.api(name="get_metrics")
def get_metrics() -> dict:
    """Return current in-memory metrics snapshot."""
    return _metrics.snapshot()


@app.api(name="reset_metrics")
async def reset_metrics_api() -> dict:
    """Reset all in-memory metrics counters."""
    _metrics.reset()
    return _metrics.snapshot()


if __name__ == "__main__":
    app.launch(server_name="0.0.0.0")