Spaces:
Sleeping
Sleeping
"""
Voice Chatbot — Hugging Face Space version (Mistral API + MLflow tracing).

Architecture:
  - LLM : Mistral Chat API (streaming SSE via httpx, traced manually)
  - STT : Mistral Transcriptions API — via the Mistral SDK client when it can
          be created (auto-traced by MLflow), otherwise via plain httpx
          (not traced)
  - TTS : Mistral Speech API (via httpx, not traced)

When MLflow tracing is enabled, LLM replies are recorded as spans carrying:
  - Input messages and generation parameters (model, temperature, max_tokens)
  - Output completion text
  - Streamed-chunk count, latency and chunks/second

To activate, set these HF Space secrets:
  MISTRAL_API_KEY         — required
  MLFLOW_TRACKING_URI     — your cloud MLflow 3 server URL (tracing is off when unset)
  MLFLOW_EXPERIMENT_NAME  — optional (default: mistral-chatbot)
  LLM_MODEL               — optional (default: mistral-small-latest)
  STT_MODEL               — optional (default: voxtral-mini-latest)
  TTS_MODEL               — optional (default: voxtral-mini-tts-2603)
  TTS_VOICE               — optional TTS voice ID (sent as `voice_id`)
  LLM_API_URL / STT_API_URL / TTS_API_URL
                          — optional endpoint overrides (default: api.mistral.ai)
"""
| import asyncio | |
| import base64 | |
| import io | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import time | |
| from gradio import Server | |
| import httpx | |
| import soundfile as sf | |
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger("app")

# ══════════════════════════════════════════════════════════════════════════
# CONFIGURATION — set via HF Space secrets (Settings → Repository secrets)
# ══════════════════════════════════════════════════════════════════════════
# Strip surrounding whitespace: a trailing newline pasted into the secret
# would otherwise end up inside the "Authorization: Bearer ..." header.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY", "").strip()

# ── LLM ─────────────────────────────────────────────────────────────────
LLM_MODEL = os.getenv("LLM_MODEL", "mistral-small-latest")
LLM_API_URL = os.getenv(
    "LLM_API_URL",
    "https://api.mistral.ai/v1/chat/completions",
)

# ── STT ─────────────────────────────────────────────────────────────────
STT_MODEL = os.getenv("STT_MODEL", "voxtral-mini-latest")
STT_API_URL = os.getenv(
    "STT_API_URL",
    "https://api.mistral.ai/v1/audio/transcriptions",
)

# ── TTS ─────────────────────────────────────────────────────────────────
TTS_API_URL = os.getenv(
    "TTS_API_URL",
    "https://api.mistral.ai/v1/audio/speech",
)
TTS_MODEL = os.getenv("TTS_MODEL", "voxtral-mini-tts-2603")
TTS_VOICE = os.getenv("TTS_VOICE", "")  # empty → let the API pick its default voice

# ── MLflow Tracing ──────────────────────────────────────────────────────
MLFLOW_TRACKING_URI = os.getenv("MLFLOW_TRACKING_URI", "")
MLFLOW_EXPERIMENT_NAME = os.getenv("MLFLOW_EXPERIMENT_NAME", "mistral-chatbot")

# ── Sentence boundary regex ────────────────────────────────────────────
# One or more terminal punctuation marks followed by whitespace. A boundary
# at the very end of the buffer is not matched; the caller flushes the tail.
SENTENCE_END = re.compile(r"[.?!…]+\s+")

SYSTEM_PROMPT = {
    "role": "system",
    "content": "You are a helpful and concise voice assistant.",
}

if not MISTRAL_API_KEY:
    # Use the module logger (not print) so the warning lands in the same
    # timestamped stream as every other diagnostic in this app.
    logger.warning("MISTRAL_API_KEY not set — all API calls will fail.")
| # ══════════════════════════════════════════════════════════════════════════ | |
| # MLflow setup (runs once at module import time) | |
| # ══════════════════════════════════════════════════════════════════════════ | |
# Tracing is opt-in: it is only attempted when a tracking URI is configured.
MLFLOW_ENABLED = bool(MLFLOW_TRACKING_URI)
if MLFLOW_ENABLED:
    try:
        import mlflow

        mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
        mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)
    except Exception as exc:
        logger.warning(
            "MLflow init failed (%s) — tracing disabled. "
            "Set MLFLOW_TRACKING_URI to a valid MLflow 3 server URL.",
            exc,
        )
        MLFLOW_ENABLED = False

if MLFLOW_ENABLED:
    # Mistral autolog patches the Mistral SDK, which this file only uses for
    # STT. It is optional: if it cannot be enabled (e.g. SDK missing), the
    # manually recorded LLM spans still work, so keep tracing switched on.
    try:
        mlflow.mistral.autolog()
    except Exception as exc:
        logger.warning(
            "MLflow Mistral autolog unavailable (%s) — SDK calls will not be auto-traced.",
            exc,
        )
    logger.info(
        "MLflow tracing ENABLED — tracking URI: %s, experiment: %s",
        MLFLOW_TRACKING_URI,
        MLFLOW_EXPERIMENT_NAME,
    )
# ── Mistral SDK client (used for auto-traced STT calls) ─────────────────
# Built only when MLflow tracing is on; with _mistral left as None the STT
# path falls back to a plain httpx request.
_mistral = None
if MLFLOW_ENABLED:
    try:
        # NOTE(review): recent mistralai SDKs (v1+) expose `from mistralai import
        # Mistral` with `*_async` methods; confirm this import path exists in the
        # installed SDK — if it does not, the warning below fires and STT
        # silently uses the untraced httpx fallback.
        from mistralai.async_client import MistralAsync as _MistralAsync

        _mistral = _MistralAsync(api_key=MISTRAL_API_KEY)
    except Exception as exc:
        # _mistral is still None here: the assignment above is the last step.
        logger.warning("Failed to create Mistral SDK client: %s", exc)
| # ══════════════════════════════════════════════════════════════════════════ | |
| # IN-MEMORY METRICS | |
| # ══════════════════════════════════════════════════════════════════════════ | |
class Metrics:
    """In-memory call counters and latency accumulators for STT, LLM and TTS.

    The ``record_*`` coroutines update state while holding ``self.lock``;
    ``reset`` and ``snapshot`` are plain synchronous methods and take no lock.
    """

    def __init__(self):
        self.lock = asyncio.Lock()
        self.reset()

    def reset(self):
        """Zero every counter and forget all recorded error messages."""
        self.stt_count, self.stt_total_s = 0, 0.0
        self.llm_count, self.llm_total_s = 0, 0.0
        self.tts_count, self.tts_total_s = 0, 0.0
        self.total_tokens = 0
        self.error_count = 0
        self.last_errors: list[str] = []

    async def record_stt(self, elapsed: float):
        """Count one STT call that took *elapsed* seconds."""
        async with self.lock:
            self.stt_count += 1
            self.stt_total_s += elapsed

    async def record_llm(self, elapsed: float, tokens: int):
        """Count one LLM call of *elapsed* seconds that produced *tokens*."""
        async with self.lock:
            self.llm_count += 1
            self.llm_total_s += elapsed
            self.total_tokens += tokens

    async def record_tts(self, elapsed: float):
        """Count one TTS call that took *elapsed* seconds."""
        async with self.lock:
            self.tts_count += 1
            self.tts_total_s += elapsed

    async def record_error(self, msg: str):
        """Count an error and remember *msg* (only the 20 newest are kept)."""
        async with self.lock:
            self.error_count += 1
            history = self.last_errors
            history.append(msg)
            if len(history) > 20:
                del history[0]

    def snapshot(self) -> dict:
        """Return a JSON-friendly summary; averages are None with no calls."""

        def stage(calls, total):
            return {
                "calls": calls,
                "avg_latency_s": round(total / calls, 3) if calls else None,
                "total_s": round(total, 2),
            }

        return {
            "stt": stage(self.stt_count, self.stt_total_s),
            "llm": stage(self.llm_count, self.llm_total_s),
            "tts": stage(self.tts_count, self.tts_total_s),
            "total_tokens": self.total_tokens,
            "errors": self.error_count,
            "last_errors": self.last_errors[-5:],
        }


_metrics = Metrics()
| # ══════════════════════════════════════════════════════════════════════════ | |
| # HELPERS | |
| # ══════════════════════════════════════════════════════════════════════════ | |
| def build_messages(history: list[dict], user_text: str) -> list[dict]: | |
| """Build the messages array for the Mistral Chat API.""" | |
| messages = [SYSTEM_PROMPT] + list(history) | |
| messages.append({"role": "user", "content": user_text}) | |
| return messages | |
| # ══════════════════════════════════════════════════════════════════════════ | |
| # STT — speech-to-text | |
| # ══════════════════════════════════════════════════════════════════════════ | |
async def _load_audio_for_stt(audio_path: str) -> tuple[bytes, str, str]:
    """Return ``(audio_bytes, upload_file_name, mime_type)`` for the STT upload.

    Converts the input to 16 kHz mono WAV with ffmpeg. ffmpeg runs in a
    worker thread so the event loop keeps serving other requests meanwhile.
    If conversion fails, the original bytes are returned under the original
    file name instead of being mislabelled as WAV.
    """
    import mimetypes
    import subprocess
    import tempfile

    wav_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            wav_path = tmp.name
        await asyncio.to_thread(
            subprocess.run,
            ["ffmpeg", "-y", "-i", audio_path, "-ar", "16000", "-ac", "1", wav_path],
            capture_output=True,
            check=True,
        )
        with open(wav_path, "rb") as f:
            return f.read(), "audio.wav", "audio/wav"
    except Exception as e:
        logger.warning("ffmpeg conversion failed, sending raw audio: %s", e)
    finally:
        # Runs after the read above as well, so the temp file never leaks.
        if wav_path and os.path.exists(wav_path):
            os.unlink(wav_path)

    with open(audio_path, "rb") as f:
        raw = f.read()
    file_name = os.path.basename(audio_path) or "audio"
    mime_type = mimetypes.guess_type(audio_path)[0] or "application/octet-stream"
    return raw, file_name, mime_type


async def transcribe(audio_path: str) -> str:
    """Convert audio to 16 kHz mono WAV and transcribe via Mistral STT API.

    Uses the Mistral SDK client when one was created (``_mistral``), otherwise
    a direct multipart POST to ``STT_API_URL``.

    Returns:
        The transcription text from the API response.

    Raises:
        Any error from reading the audio or calling the API; it is recorded in
        ``_metrics`` before being re-raised.
    """
    t0 = time.perf_counter()
    try:
        audio_bytes, file_name, mime_type = await _load_audio_for_stt(audio_path)
        if _mistral is not None:
            result = await _mistral.audio.transcriptions.complete(
                model=STT_MODEL,
                file={"content": audio_bytes, "file_name": file_name},
            )
            text = result.text
        else:
            headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}"}
            async with httpx.AsyncClient(timeout=120.0) as client:
                resp = await client.post(
                    STT_API_URL,
                    headers=headers,
                    files={
                        "file": (file_name, audio_bytes, mime_type),
                        "model": (None, STT_MODEL),
                    },
                )
                resp.raise_for_status()
                text = resp.json()["text"]
    except Exception as e:
        await _metrics.record_error(f"STT: {e}")
        raise

    # Measure once so the metric and the log line report the same latency.
    elapsed = time.perf_counter() - t0
    await _metrics.record_stt(elapsed)
    logger.info("STT ok %.2fs %d bytes model=%s", elapsed, len(audio_bytes), STT_MODEL)
    return text
| # ══════════════════════════════════════════════════════════════════════════ | |
| # TTS — text-to-speech | |
| # ══════════════════════════════════════════════════════════════════════════ | |
async def call_tts(client: httpx.AsyncClient, text: str) -> str | None:
    """Synthesise *text* through the Mistral Speech API using *client*.

    Returns the response's ``audio_data`` field (base64-encoded WAV), or
    ``None`` on any failure. Failures are counted in ``_metrics`` and logged,
    never raised.
    """
    started = time.perf_counter()
    try:
        payload: dict = {
            "model": TTS_MODEL,
            "input": text,
            "response_format": "wav",
        }
        if TTS_VOICE:
            payload["voice_id"] = TTS_VOICE
        response = await client.post(
            TTS_API_URL,
            headers={"Authorization": f"Bearer {MISTRAL_API_KEY}"},
            json=payload,
            timeout=60.0,
        )
        response.raise_for_status()
        decoded = response.json()
        took = time.perf_counter() - started
        await _metrics.record_tts(took)
        logger.info("TTS ok %.2fs %d chars model=%s", took, len(text), TTS_MODEL)
        # NOTE: a 2xx reply lacking "audio_data" is counted as a TTS call above
        # and then again as an error by the handler below.
        return decoded["audio_data"]  # base64-encoded WAV
    except Exception as exc:
        await _metrics.record_error(f"TTS: {exc}")
        logger.warning("TTS failed (%.1fs): %s", time.perf_counter() - started, exc)
        return None
| # ══════════════════════════════════════════════════════════════════════════ | |
| # LLM — streaming text generation | |
| # ══════════════════════════════════════════════════════════════════════════ | |
async def stream_llm(messages: list[dict]):
    """Stream a chat completion from the Mistral Chat API over SSE.

    Yields ``(content_piece, cumulative_count)`` tuples. ``cumulative_count``
    counts the content-bearing SSE chunks received so far; it is used as an
    approximation of generated tokens.

    Errors are never raised. On failure the error is recorded in ``_metrics``
    and logged, and a single ``("[ERROR] ...", 0)`` tuple is yielded instead.
    """
    headers = {
        "Authorization": f"Bearer {MISTRAL_API_KEY}",
        "Content-Type": "application/json",
    }
    body = {
        "model": LLM_MODEL,
        "messages": messages,
        "stream": True,
        "max_tokens": 512,
        "temperature": 0.7,
        "top_p": 0.9,
    }
    t0 = time.perf_counter()
    token_count = 0
    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(30.0, read=120.0)) as client:
            async with client.stream("POST", LLM_API_URL, json=body, headers=headers) as resp:
                if not resp.is_success:
                    # Buffer the error body while the stream is still open. Once
                    # these contexts exit the response is closed, and reading it
                    # from the except-handler would raise httpx.StreamClosed
                    # instead of reporting the API's error message.
                    await resp.aread()
                    resp.raise_for_status()
                async for line in resp.aiter_lines():
                    if not line.startswith("data: "):
                        continue  # SSE comments / keep-alives / blank separators
                    payload = line[6:].strip()
                    if payload == "[DONE]":
                        break
                    if not payload:
                        continue
                    try:
                        data = json.loads(payload)
                    except json.JSONDecodeError:
                        continue
                    choices = data.get("choices", [])
                    if not choices:
                        continue
                    content = (choices[0].get("delta") or {}).get("content")
                    if content:
                        token_count += 1
                        yield content, token_count
        elapsed = time.perf_counter() - t0
        await _metrics.record_llm(elapsed, token_count)
        logger.info(
            "LLM ok %.2fs %d tokens %.1f tok/s model=%s",
            elapsed, token_count, token_count / elapsed if elapsed else 0, LLM_MODEL,
        )
    except httpx.HTTPStatusError as e:
        try:
            detail = e.response.content.decode(errors="replace")[:300]
        except httpx.ResponseNotRead:
            detail = ""
        elapsed = time.perf_counter() - t0
        await _metrics.record_error(f"LLM HTTP {e.response.status_code}: {detail[:80]}")
        logger.error("LLM HTTP %s %.1fs %s", e.response.status_code, elapsed, detail)
        yield f"[ERROR] LLM API error ({e.response.status_code}): {detail}", 0
    except Exception as e:
        elapsed = time.perf_counter() - t0
        await _metrics.record_error(f"LLM: {type(e).__name__}: {e}")
        logger.error("LLM error %.1fs %s", elapsed, e)
        yield f"[ERROR] LLM error: {type(e).__name__}: {e}", 0
| # ══════════════════════════════════════════════════════════════════════════ | |
| # STREAM ORCHESTRATOR | |
| # ══════════════════════════════════════════════════════════════════════════ | |
def _record_llm_trace(
    messages: list[dict],
    response: str,
    token_count: int,
    elapsed: float,
    error: str | None,
) -> None:
    """Best-effort: record one MLflow span describing a finished LLM stream.

    The span is opened after the stream has ended, so its own timestamps do
    not cover the LLM call; the measured latency is stored in the outputs.
    Any failure is logged and swallowed so tracing can never break a reply.
    """
    try:
        import mlflow
        from mlflow.entities import SpanStatus, SpanStatusCode

        with mlflow.start_span(name="llm_chat_stream") as span:
            span.set_inputs({
                "model": LLM_MODEL,
                "messages": messages,
                # Mirrors the request body built in stream_llm.
                "temperature": 0.7,
                "max_tokens": 512,
            })
            span.set_outputs({
                "response": response,
                "token_count": token_count,
                "latency_seconds": round(elapsed, 3),
                "tokens_per_second": round(token_count / elapsed, 1) if elapsed > 0 else 0,
            })
            if error:
                span.set_status(SpanStatus(SpanStatusCode.ERROR, description=str(error)))
    except Exception as trace_err:
        logger.warning("MLflow trace recording failed: %s", trace_err)


async def stream_reply(messages: list[dict]):
    """Consume the LLM stream, speak it sentence by sentence, and trace it.

    Yields dicts ``{"reply": str, "audio_b64": str | None, "done": bool}``:
      * ``reply`` is the cumulative text so far;
      * ``audio_b64`` carries TTS audio on the chunk that completed a
        sentence (and on the final flush), otherwise ``None``;
      * an ``[ERROR] ...`` reply ends the stream immediately with ``done=True``.

    TTS is awaited inline, so LLM consumption pauses while each sentence is
    synthesised. When MLflow is enabled, the reply — or the error, even if no
    text was produced — is recorded as a span once the stream finishes.
    """
    token_buffer = ""
    full_reply = ""
    trace_token_count = 0
    trace_error = None
    trace_start = time.perf_counter()
    try:
        # One client for every TTS request in this reply so connections are
        # pooled instead of re-established for each sentence.
        async with httpx.AsyncClient() as tts_client:
            async for piece, token_count in stream_llm(messages):
                if piece.startswith("[ERROR]"):
                    trace_error = piece
                    yield {"reply": piece, "audio_b64": None, "done": True}
                    return
                trace_token_count = token_count
                token_buffer += piece
                full_reply += piece

                audio_b64 = None
                match = SENTENCE_END.search(token_buffer)
                if match:
                    sentence = token_buffer[: match.end()].strip()
                    token_buffer = token_buffer[match.end():]
                    if sentence:
                        audio_b64 = await call_tts(tts_client, sentence)
                yield {"reply": full_reply, "audio_b64": audio_b64, "done": False}

            # Flush whatever trailing text never reached a sentence boundary.
            tail = token_buffer.strip()
            audio_b64 = await call_tts(tts_client, tail) if tail else None
            yield {"reply": full_reply, "audio_b64": audio_b64, "done": True}
    finally:
        if MLFLOW_ENABLED and (full_reply or trace_error):
            _record_llm_trace(
                messages,
                full_reply,
                trace_token_count,
                time.perf_counter() - trace_start,
                trace_error,
            )
| # ══════════════════════════════════════════════════════════════════════════ | |
| # GRADIO SERVER | |
| # ══════════════════════════════════════════════════════════════════════════ | |
_SERVER_DESCRIPTION = f"API LLM: `{LLM_MODEL}` · STT: `{STT_MODEL}` · TTS: `{TTS_MODEL}`"

# NOTE(review): none of the handlers below (text_turn, voice_turn, get_metrics,
# reset_metrics_api) are visibly registered with this Server in this file —
# confirm how gradio.Server discovers them.
app = Server(title="Voice Chatbot", description=_SERVER_DESCRIPTION)
async def text_turn(user_text: str, history: list[dict]) -> dict:
    """Send a text message and stream back reply chunks with optional audio.

    history: list of {"role": "user"|"assistant", "content": str}
    Yields: {"reply": str, "audio_b64": str|None, "done": bool}

    A missing (None) or whitespace-only message yields a single empty, final
    chunk without calling the LLM.
    """
    # NOTE(review): this is an async generator, so `-> dict` describes each
    # yielded item rather than the return value; kept in case the server
    # framework reads the annotation.
    if not user_text or not user_text.strip():
        yield {"reply": "", "audio_b64": None, "done": True}
        return
    messages = build_messages(history, user_text)
    async for chunk in stream_reply(messages):
        yield chunk
async def voice_turn(audio_path: str, history: list[dict]) -> dict:
    """Transcribe audio then stream back a reply.

    audio_path: local file path to recorded audio.
    history: list of {"role": "user"|"assistant", "content": str}
    Yields: {"reply": str, "audio_b64": str|None, "done": bool, "user_text": str}

    A missing path or a blank transcription yields a single empty, final chunk
    without calling the LLM (matching text_turn's empty-input handling). An STT
    failure yields one final chunk whose reply starts with "[ERROR]".
    """
    if not audio_path:
        yield {"reply": "", "audio_b64": None, "done": True, "user_text": ""}
        return
    try:
        user_text = await transcribe(audio_path)
    except Exception as e:
        err = f"[ERROR] STT error: {type(e).__name__}: {e}"
        logger.exception("STT error")
        yield {"reply": err, "audio_b64": None, "done": True, "user_text": ""}
        return
    if not user_text or not user_text.strip():
        # Nothing intelligible was said: don't send an empty user turn to the LLM.
        yield {"reply": "", "audio_b64": None, "done": True, "user_text": user_text or ""}
        return
    messages = build_messages(history, user_text)
    async for chunk in stream_reply(messages):
        yield {**chunk, "user_text": user_text}
def get_metrics() -> dict:
    """Expose the live in-memory metrics as a plain dict (counters untouched)."""
    current = _metrics.snapshot()
    return current
async def reset_metrics_api() -> dict:
    """Zero every in-memory counter and return the freshly cleared snapshot."""
    store = _metrics
    store.reset()
    return store.snapshot()
if __name__ == "__main__":
    # Bind to all network interfaces rather than only localhost.
    launch_options = {"server_name": "0.0.0.0"}
    app.launch(**launch_options)