RemiProAtos's picture
refactor as gradio server
8fc59e8 verified
Raw
History Blame Contribute Delete
20.9 kB
"""
Voice Chatbot — Hugging Face Space version (Mistral API + MLflow tracing).
Architecture:
- LLM : Mistral Chat API (streaming SSE via httpx, traced manually)
- STT : Mistral Transcriptions API (via the Mistral SDK, auto-traced by MLflow, when tracing is enabled and the SDK client loads; plain httpx otherwise)
- TTS : Mistral Speech API (via httpx, not traced)
MLflow tracing captures every LLM and STT call with:
- Input prompts / messages
- Output completions / transcriptions
- Token counts & latency
- Model name & generation parameters
To activate, set these HF Space secrets:
MISTRAL_API_KEY — required
MLFLOW_TRACKING_URI — your cloud MLflow 3 server URL
MLFLOW_EXPERIMENT_NAME — optional (default: mistral-chatbot)
LLM_MODEL — optional (default: mistral-small-latest)
STT_MODEL — optional (default: voxtral-mini-latest)
TTS_MODEL — optional (default: voxtral-mini-tts-2603)
TTS_VOICE — optional TTS voice ID
"""
import asyncio
import base64
import io
import json
import logging
import os
import re
import time
from gradio import Server
import httpx
import soundfile as sf
# Root logging configuration: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger("app")

# ══════════════════════════════════════════════════════════════════════════
# CONFIGURATION — set via HF Space secrets (Settings → Repository secrets)
# ══════════════════════════════════════════════════════════════════════════
# Bearer token sent with every Mistral request; empty string when unset.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY", "")

# ── LLM ─────────────────────────────────────────────────────────────────
LLM_MODEL = os.getenv("LLM_MODEL", "mistral-small-latest")
LLM_API_URL = os.getenv(
    "LLM_API_URL",
    "https://api.mistral.ai/v1/chat/completions",
)

# ── STT ─────────────────────────────────────────────────────────────────
STT_MODEL = os.getenv("STT_MODEL", "voxtral-mini-latest")
STT_API_URL = os.getenv(
    "STT_API_URL",
    "https://api.mistral.ai/v1/audio/transcriptions",
)

# ── TTS ─────────────────────────────────────────────────────────────────
TTS_API_URL = os.getenv(
    "TTS_API_URL",
    "https://api.mistral.ai/v1/audio/speech",
)
TTS_MODEL = os.getenv("TTS_MODEL", "voxtral-mini-tts-2603")
# Optional voice identifier; only sent to the TTS API when non-empty.
TTS_VOICE = os.getenv("TTS_VOICE", "")

# ── MLflow Tracing ──────────────────────────────────────────────────────
# Tracing is switched on only when a tracking URI is provided.
MLFLOW_TRACKING_URI = os.getenv("MLFLOW_TRACKING_URI", "")
MLFLOW_EXPERIMENT_NAME = os.getenv("MLFLOW_EXPERIMENT_NAME", "mistral-chatbot")

# ── Sentence boundary regex ────────────────────────────────────────────
# One or more of . ? ! … followed by at least one whitespace character.
SENTENCE_END = re.compile(r"[.?!…]+\s+")

# System message placed first in every chat request.
SYSTEM_PROMPT = {
    "role": "system",
    "content": "You are a helpful and concise voice assistant.",
}

if not MISTRAL_API_KEY:
    # Goes through the configured logger (not print) so the warning carries
    # the same timestamp/level formatting as the rest of the app's output.
    logger.warning("MISTRAL_API_KEY not set — all API calls will fail.")
# ══════════════════════════════════════════════════════════════════════════
# MLflow setup (runs once at module import time)
# ══════════════════════════════════════════════════════════════════════════
# Tracing is attempted only when a tracking URI was configured; any failure
# during initialisation turns it back off instead of aborting start-up.
MLFLOW_ENABLED = bool(MLFLOW_TRACKING_URI)
if MLFLOW_ENABLED:
    try:
        import mlflow

        mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
        mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)
        mlflow.mistral.autolog()
        logger.info(
            "MLflow tracing ENABLED — tracking URI: %s, experiment: %s",
            MLFLOW_TRACKING_URI,
            MLFLOW_EXPERIMENT_NAME,
        )
    except Exception as init_err:
        logger.warning(
            "MLflow init failed (%s) — tracing disabled. "
            "Set MLFLOW_TRACKING_URI to a valid MLflow 3 server URL.",
            init_err,
        )
        MLFLOW_ENABLED = False

# ── Mistral SDK client (used for auto-traced STT calls) ─────────────────
# Stays None unless tracing is active AND the SDK client could be built;
# callers treat None as "use the plain HTTP path".
_mistral = None
if MLFLOW_ENABLED:
    try:
        from mistralai.async_client import MistralAsync as _MistralAsync

        _mistral = _MistralAsync(api_key=MISTRAL_API_KEY)
    except Exception as sdk_err:
        logger.warning("Failed to create Mistral SDK client: %s", sdk_err)
# ══════════════════════════════════════════════════════════════════════════
# IN-MEMORY METRICS
# ══════════════════════════════════════════════════════════════════════════
class Metrics:
    """In-memory call counters, latency totals and a short rolling error log.

    The ``record_*`` coroutines update state while holding ``self.lock``;
    ``reset`` and ``snapshot`` are plain synchronous methods.
    """

    def __init__(self):
        self.lock = asyncio.Lock()
        self.reset()

    def reset(self):
        """Zero every counter and drop all stored error messages."""
        self.stt_count = 0
        self.llm_count = 0
        self.tts_count = 0
        self.stt_total_s = 0.0
        self.llm_total_s = 0.0
        self.tts_total_s = 0.0
        self.total_tokens = 0
        self.error_count = 0
        self.last_errors: list[str] = []

    async def record_stt(self, elapsed: float):
        """Count one STT call that took ``elapsed`` seconds."""
        async with self.lock:
            self.stt_total_s += elapsed
            self.stt_count += 1

    async def record_llm(self, elapsed: float, tokens: int):
        """Count one LLM call, its duration and the tokens it produced."""
        async with self.lock:
            self.llm_total_s += elapsed
            self.llm_count += 1
            self.total_tokens += tokens

    async def record_tts(self, elapsed: float):
        """Count one TTS call that took ``elapsed`` seconds."""
        async with self.lock:
            self.tts_total_s += elapsed
            self.tts_count += 1

    async def record_error(self, msg: str):
        """Count an error and remember its message (oldest dropped past 20)."""
        async with self.lock:
            self.error_count += 1
            self.last_errors.append(msg)
            if len(self.last_errors) > 20:
                del self.last_errors[0]

    def snapshot(self) -> dict:
        """Return the current counters as a plain dict.

        Each of ``stt`` / ``llm`` / ``tts`` maps to ``calls``,
        ``avg_latency_s`` (``None`` when there were no calls) and
        ``total_s``; ``last_errors`` holds at most the five newest messages.
        """

        def stage(calls, total_s):
            mean = round(total_s / calls, 3) if calls else None
            return {
                "calls": calls,
                "avg_latency_s": mean,
                "total_s": round(total_s, 2),
            }

        return {
            "stt": stage(self.stt_count, self.stt_total_s),
            "llm": stage(self.llm_count, self.llm_total_s),
            "tts": stage(self.tts_count, self.tts_total_s),
            "total_tokens": self.total_tokens,
            "errors": self.error_count,
            "last_errors": self.last_errors[-5:],
        }


# Single process-wide metrics instance shared by all handlers.
_metrics = Metrics()
# ══════════════════════════════════════════════════════════════════════════
# HELPERS
# ══════════════════════════════════════════════════════════════════════════
def build_messages(history: list[dict], user_text: str) -> list[dict]:
    """Assemble the chat payload: system prompt, prior turns, new user turn.

    A fresh list is returned; ``history`` itself is not modified.
    """
    return [SYSTEM_PROMPT, *history, {"role": "user", "content": user_text}]
# ══════════════════════════════════════════════════════════════════════════
# STT — speech-to-text
# ══════════════════════════════════════════════════════════════════════════
def _audio_to_wav_bytes(audio_path: str) -> bytes:
    """Blocking helper: return the audio as 16 kHz mono WAV bytes.

    Runs ffmpeg into a temporary file and reads the result. If anything in
    that step fails, a warning is logged and the original file's bytes are
    returned unconverted. The temporary file is always removed.
    """
    import subprocess
    import tempfile as tf

    wav_path = None
    try:
        # Reserve a temp filename, then let ffmpeg overwrite it (-y).
        with tf.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            wav_path = tmp.name
        subprocess.run(
            ["ffmpeg", "-y", "-i", audio_path, "-ar", "16000", "-ac", "1", wav_path],
            capture_output=True, check=True,
        )
        with open(wav_path, "rb") as f:
            return f.read()
    except Exception as e:
        logger.warning("ffmpeg conversion failed, sending raw audio: %s", e)
        with open(audio_path, "rb") as f:
            return f.read()
    finally:
        if wav_path and os.path.exists(wav_path):
            os.unlink(wav_path)


async def transcribe(audio_path: str) -> str:
    """Convert audio to 16 kHz mono WAV and transcribe via Mistral STT API.

    The ffmpeg conversion and file reads run in a worker thread so the event
    loop keeps serving other requests while they block. The Mistral SDK
    client is used when one was created at start-up; otherwise the request
    is posted with httpx.

    Args:
        audio_path: Path of the recorded audio file.
            NOTE(review): the path is handed to ffmpeg and ``open()`` as-is;
            confirm callers only pass trusted local paths.

    Returns:
        The transcription text.

    Raises:
        Exception: any failure while loading the audio or calling the API is
            counted in the error metrics and re-raised unchanged.
    """
    t0 = time.perf_counter()
    try:
        audio_bytes = await asyncio.to_thread(_audio_to_wav_bytes, audio_path)
        if _mistral is not None:
            result = await _mistral.audio.transcriptions.complete(
                model=STT_MODEL,
                file={"content": audio_bytes, "file_name": "audio.wav"},
            )
            text = result.text
        else:
            headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}"}
            async with httpx.AsyncClient(timeout=120.0) as client:
                resp = await client.post(
                    STT_API_URL,
                    headers=headers,
                    files={
                        "file": ("audio.wav", audio_bytes, "audio/wav"),
                        "model": (None, STT_MODEL),
                    },
                )
                resp.raise_for_status()
                text = resp.json()["text"]
        elapsed = time.perf_counter() - t0
        await _metrics.record_stt(elapsed)
        logger.info("STT ok %.2fs %d bytes model=%s", elapsed, len(audio_bytes), STT_MODEL)
        return text
    except Exception as e:
        await _metrics.record_error(f"STT: {e}")
        raise
# ══════════════════════════════════════════════════════════════════════════
# TTS — text-to-speech
# ══════════════════════════════════════════════════════════════════════════
async def call_tts(client: httpx.AsyncClient, text: str) -> str | None:
    """Request speech for ``text`` from the Mistral TTS endpoint.

    Args:
        client: HTTP client used to send the request.
        text: Text to synthesise.

    Returns:
        The ``audio_data`` field of the JSON response (base64-encoded WAV),
        or ``None`` when anything fails; failures are logged and counted in
        the error metrics rather than raised.
    """
    started = time.perf_counter()
    try:
        payload: dict = {
            "model": TTS_MODEL,
            "input": text,
            "response_format": "wav",
        }
        # The voice is optional; omit the field entirely when unset.
        if TTS_VOICE:
            payload["voice_id"] = TTS_VOICE
        response = await client.post(
            TTS_API_URL,
            headers={"Authorization": f"Bearer {MISTRAL_API_KEY}"},
            json=payload,
            timeout=60.0,
        )
        response.raise_for_status()
        decoded = response.json()
        duration = time.perf_counter() - started
        await _metrics.record_tts(duration)
        logger.info("TTS ok %.2fs %d chars model=%s", duration, len(text), TTS_MODEL)
        return decoded["audio_data"]  # base64-encoded WAV
    except Exception as exc:
        await _metrics.record_error(f"TTS: {exc}")
        logger.warning("TTS failed (%.1fs): %s", time.perf_counter() - started, exc)
        return None
# ══════════════════════════════════════════════════════════════════════════
# LLM — streaming text generation
# ══════════════════════════════════════════════════════════════════════════
async def stream_llm(messages: list[dict]):
    """Stream tokens from Mistral Chat API via SSE.

    Yields ``(text, count)`` tuples, where ``count`` is the running number
    of non-empty content deltas received so far (one per SSE chunk, not a
    tokenizer count).

    Failures are reported in-band rather than raised: a single
    ``("[ERROR] ...", 0)`` tuple is yielded and the generator ends. Each
    failure is also logged and counted in the error metrics.

    Args:
        messages: Chat messages sent as the request's ``messages`` field.
    """
    headers = {
        "Authorization": f"Bearer {MISTRAL_API_KEY}",
        "Content-Type": "application/json",
    }
    body = {
        "model": LLM_MODEL,
        "messages": messages,
        "stream": True,
        "max_tokens": 512,
        "temperature": 0.7,
        "top_p": 0.9,
    }
    t0 = time.perf_counter()
    token_count = 0
    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(30.0, read=120.0)) as client:
            async with client.stream("POST", LLM_API_URL, json=body, headers=headers) as resp:
                if not 200 <= resp.status_code < 300:
                    # Load the error body while the stream is still open:
                    # once the context exits the response is closed and the
                    # body can no longer be read by the handler below.
                    await resp.aread()
                resp.raise_for_status()
                async for line in resp.aiter_lines():
                    # SSE allows "data:" with or without a following space.
                    if not line.startswith("data:"):
                        continue
                    payload = line[5:].strip()
                    if payload == "[DONE]":
                        break
                    if not payload:
                        continue
                    try:
                        data = json.loads(payload)
                    except json.JSONDecodeError:
                        # Skip malformed chunks instead of aborting the reply.
                        continue
                    choices = data.get("choices", [])
                    if not choices:
                        continue
                    delta = choices[0].get("delta") or {}
                    content = delta.get("content")
                    if content:
                        token_count += 1
                        yield content, token_count
        elapsed = time.perf_counter() - t0
        await _metrics.record_llm(elapsed, token_count)
        logger.info("LLM ok %.2fs %d tokens %.1f tok/s model=%s", elapsed, token_count, token_count / elapsed if elapsed else 0, LLM_MODEL)
    except httpx.HTTPStatusError as e:
        # The body was read before raise_for_status(), so it is available here.
        detail = e.response.content.decode(errors="replace")[:300]
        elapsed = time.perf_counter() - t0
        await _metrics.record_error(f"LLM HTTP {e.response.status_code}: {detail[:80]}")
        logger.error("LLM HTTP %s %.1fs %s", e.response.status_code, elapsed, detail)
        yield f"[ERROR] LLM API error ({e.response.status_code}): {detail}", 0
    except Exception as e:
        elapsed = time.perf_counter() - t0
        await _metrics.record_error(f"LLM: {type(e).__name__}: {e}")
        logger.error("LLM error %.1fs %s", elapsed, e)
        yield f"[ERROR] LLM error: {type(e).__name__}: {e}", 0
# ══════════════════════════════════════════════════════════════════════════
# STREAM ORCHESTRATOR
# ══════════════════════════════════════════════════════════════════════════
async def stream_reply(messages: list[dict]):
    """Consume LLM stream, record MLflow trace, and synthesize audio.

    Text is buffered until ``SENTENCE_END`` matches; the completed sentence
    is sent to TTS and its audio attached to the chunk yielded for that
    token. Any text left when the stream ends is synthesised for the final
    chunk. A single HTTP client is shared by all TTS calls of one reply so
    connections are reused instead of re-opened per sentence.

    Yields dicts: {"reply": str, "audio_b64": str|None, "done": bool}
    ``reply`` is the cumulative text so far. If the LLM stream reports an
    error, one chunk carrying the ``[ERROR] ...`` text with ``done=True`` is
    yielded and the generator ends.

    Args:
        messages: Chat messages forwarded to ``stream_llm``.
    """
    token_buffer = ""
    full_reply = ""
    _trace_token_count = 0
    _trace_start = time.perf_counter()
    _trace_error = None
    try:
        async with httpx.AsyncClient() as tts_client:
            async for token_or_error, token_count in stream_llm(messages):
                if token_or_error.startswith("[ERROR]"):
                    _trace_error = token_or_error
                    yield {"reply": token_or_error, "audio_b64": None, "done": True}
                    return
                _trace_token_count = token_count
                token_buffer += token_or_error
                full_reply += token_or_error

                # Speak at most one completed sentence per incoming token;
                # whatever follows it stays buffered for later tokens.
                audio_b64 = None
                match = SENTENCE_END.search(token_buffer)
                if match:
                    sentence = token_buffer[: match.end()].strip()
                    token_buffer = token_buffer[match.end():]
                    if sentence:
                        audio_b64 = await call_tts(tts_client, sentence)
                yield {"reply": full_reply, "audio_b64": audio_b64, "done": False}

            # Flush remaining text
            remainder = token_buffer.strip()
            audio_b64 = await call_tts(tts_client, remainder) if remainder else None
            yield {"reply": full_reply, "audio_b64": audio_b64, "done": True}
    finally:
        # Runs on normal completion, on error, and when the consumer stops
        # early. Nothing is traced when no reply text was produced, which
        # includes an error arriving before the first token.
        if MLFLOW_ENABLED and full_reply:
            _elapsed = time.perf_counter() - _trace_start
            try:
                import mlflow
                from mlflow.tracing.fluent import start_span

                with start_span("llm_chat_stream") as span:
                    span.set_inputs({
                        "model": LLM_MODEL,
                        "messages": messages,
                        "temperature": 0.7,
                        "max_tokens": 512,
                    })
                    span.set_outputs({
                        "response": full_reply,
                        "token_count": _trace_token_count,
                        "latency_seconds": round(_elapsed, 3),
                        "tokens_per_second": round(_trace_token_count / _elapsed, 1) if _elapsed > 0 else 0,
                    })
                    if _trace_error:
                        # NOTE(review): confirm mlflow.tracing.Status exists
                        # in the deployed MLflow version; a failure here is
                        # only logged by the handler below.
                        span.set_status(mlflow.tracing.Status.ERROR, str(_trace_error))
            except Exception as trace_err:
                # Tracing is best-effort and must never break the reply.
                logger.warning("MLflow trace recording failed: %s", trace_err)
# ══════════════════════════════════════════════════════════════════════════
# GRADIO SERVER
# ══════════════════════════════════════════════════════════════════════════
# Server instance that the @app.api handlers below register against. The
# description interpolates the model names resolved from the environment.
# NOTE(review): assumes gradio's Server accepts title/description — confirm
# against the installed gradio version.
app = Server(
    title="Voice Chatbot",
    description=f"API LLM: `{LLM_MODEL}` · STT: `{STT_MODEL}` · TTS: `{TTS_MODEL}`",
)
@app.api(name="text_turn")
async def text_turn(user_text: str, history: list[dict]) -> dict:
    """Send a text message and stream back reply chunks with optional audio.

    This is an async generator: it yields one dict per chunk.

    Args:
        user_text: The user's message. A missing (``None``), empty or
            whitespace-only message produces a single empty terminal chunk
            instead of calling the LLM.
        history: list of {"role": "user"|"assistant", "content": str}

    Yields: {"reply": str, "audio_b64": str|None, "done": bool}
    """
    # Guard against None as well as blank text, so a missing message is
    # tolerated the same way voice_turn tolerates a missing audio path.
    if not user_text or not user_text.strip():
        yield {"reply": "", "audio_b64": None, "done": True}
        return
    messages = build_messages(history, user_text)
    async for chunk in stream_reply(messages):
        yield chunk
@app.api(name="voice_turn")
async def voice_turn(audio_path: str, history: list[dict]) -> dict:
    """Transcribe a recording, then stream the assistant's reply.

    This is an async generator. Every chunk is the reply chunk extended with
    the transcribed ``user_text``. A missing ``audio_path`` or a failed
    transcription yields exactly one terminal chunk with an empty
    ``user_text``.

    audio_path: local file path to recorded audio.
    history: list of {"role": "user"|"assistant", "content": str}
    Yields: {"reply": str, "audio_b64": str|None, "done": bool, "user_text": str}
    """

    def terminal(reply_text: str) -> dict:
        # Final chunk used when no reply stream is started.
        return {"reply": reply_text, "audio_b64": None, "done": True, "user_text": ""}

    if not audio_path:
        yield terminal("")
        return

    try:
        user_text = await transcribe(audio_path)
    except Exception as stt_exc:
        failure = f"[ERROR] STT error: {type(stt_exc).__name__}: {stt_exc}"
        logger.exception("STT error")
        yield terminal(failure)
        return

    async for chunk in stream_reply(build_messages(history, user_text)):
        yield chunk | {"user_text": user_text}
@app.api(name="get_metrics")
def get_metrics() -> dict:
    """Expose the shared metrics object's current snapshot as a dict."""
    current = _metrics.snapshot()
    return current
@app.api(name="reset_metrics")
async def reset_metrics_api() -> dict:
    """Clear the shared metrics and return the freshly reset snapshot."""
    shared = _metrics
    shared.reset()
    return shared.snapshot()
# Start the server only when the file is executed as a script, not on import.
# NOTE(review): server_name="0.0.0.0" presumably binds all network interfaces
# (as needed inside a container) — confirm against gradio's launch() docs.
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0")