# GradioChat / app.py
# Hugging Face Space file view (30.2 kB) — last commit f3e7ec5 (verified) by RemiProAtos:
#   "forced French Language & decrease temperature to 0"
"""
Voice Chatbot — Hugging Face Space version (Mistral API + MLflow tracing).
Architecture:
- LLM : Mistral Chat API (streaming SSE via httpx, traced manually)
- STT : Mistral Transcriptions API (Mistral SDK, auto-traced, when MLflow tracing is enabled; raw httpx otherwise)
- TTS : Mistral Speech API (via the synchronous Mistral SDK run in a worker thread; no manual trace span)
MLflow tracing captures every LLM and STT call with:
- Input prompts / messages
- Output completions / transcriptions
- Token counts & latency
- Model name & generation parameters
To activate, set these HF Space secrets:
MISTRAL_API_KEY — required
MLFLOW_TRACKING_URI — your cloud MLflow 3 server URL
MLFLOW_EXPERIMENT_NAME — optional (default: mistral-chatbot)
LLM_MODEL — optional (default: mistral-small-latest)
STT_MODEL — optional (default: voxtral-mini-latest)
TTS_MODEL — optional (default: voxtral-mini-tts-2603)
TTS_VOICE — optional TTS voice ID
"""
import asyncio
import base64
import io
import json
import logging
import os
import re
import time
import gradio as gr
import httpx
import numpy as np
import soundfile as sf
from pathlib import Path
from mistralai.client import Mistral
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger("app")

# ══════════════════════════════════════════════════════════════════════════
# CONFIGURATION — set via HF Space secrets (Settings → Repository secrets)
# ══════════════════════════════════════════════════════════════════════════
# Every value below is read once from the environment at import time.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY", "")

# ── LLM ─────────────────────────────────────────────────────────────────
LLM_MODEL = os.getenv("LLM_MODEL", "mistral-small-latest")
LLM_API_URL = os.getenv(
    "LLM_API_URL",
    "https://api.mistral.ai/v1/chat/completions",
)

# ── STT ─────────────────────────────────────────────────────────────────
STT_MODEL = os.getenv("STT_MODEL", "voxtral-mini-latest")
STT_API_URL = os.getenv(
    "STT_API_URL",
    "https://api.mistral.ai/v1/audio/transcriptions",
)

# ── TTS ─────────────────────────────────────────────────────────────────
TTS_API_URL = os.getenv(
    "TTS_API_URL",
    "https://api.mistral.ai/v1/audio/speech",
)
TTS_MODEL = os.getenv("TTS_MODEL", "voxtral-mini-tts-2603")
# Default voice ID. "marie_neutral" appears to be an earlier value — TODO confirm.
TTS_VOICE = os.getenv("TTS_VOICE", "5a271406-039d-46fe-835b-fbbb00eaf08d")

# ── MLflow Tracing ──────────────────────────────────────────────────────
# 1. Deploy a cloud MLflow 3 Tracking Server (e.g. on a VM or managed).
# 2. Set MLFLOW_TRACKING_URI to its URL, e.g. "https://mlflow.example.com".
# 3. Optionally set MLFLOW_EXPERIMENT_NAME (default: "mistral-chatbot").
# 4. Both are read from environment variables (HF Space secrets).
#
# What gets traced automatically (via mlflow.mistral.autolog()):
#   - STT calls (when they go through the Mistral Python SDK)
#
# What gets traced manually (one span recorded per streamed reply):
#   - LLM streaming calls (Mistral autolog doesn't support streaming)
#
# What gets no manual span:
#   - TTS calls (made through the synchronous Mistral SDK; low diagnostic value)
MLFLOW_TRACKING_URI = os.getenv("MLFLOW_TRACKING_URI", "")
MLFLOW_EXPERIMENT_NAME = os.getenv("MLFLOW_EXPERIMENT_NAME", "mistral-chatbot")

# ── Sentence boundary regex ────────────────────────────────────────────
# One or more terminators (. ? ! …) followed by whitespace; the trailing
# whitespace requirement means a sentence is only cut once the next word
# has started to arrive.
SENTENCE_END = re.compile(r"[.?!…]+\s+")

# System message prepended to every chat request (French-only assistant).
SYSTEM_PROMPT = {
    "role": "system",
    "content": "Tu es un assistant d'agent de Bordeaux Metropole. Tu communiques exclusivement en FRANCAIS. Tu gardes un ton courtois.",
}

if not MISTRAL_API_KEY:
    # Use the configured logger (not print) so the warning carries a
    # timestamp/level like every other message emitted by this module.
    logger.warning("MISTRAL_API_KEY not set — all API calls will fail.")
# ══════════════════════════════════════════════════════════════════════════
# MLflow setup (runs once at module import time)
# ══════════════════════════════════════════════════════════════════════════
# Guard flag — tracing is only active when a tracking URI is configured.
# It is reset to False below if any step of the initialisation raises.
MLFLOW_ENABLED = bool(MLFLOW_TRACKING_URI)
if MLFLOW_ENABLED:
    # The mlflow package is declared in requirements.txt; if it's somehow
    # missing the try/except degrades gracefully (tracing disabled).
    try:
        import mlflow
        # Point the MLflow client at your cloud server.
        # Example: mlflow.set_tracking_uri("https://mlflow.example.com")
        mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
        # Create or reuse the experiment. Runs will be grouped under it.
        mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)
        # ── Auto-trace Mistral SDK calls ──────────────────────────────
        # This patches Mistral's Python SDK so every call to
        #   client.chat.complete()
        #   client.audio.transcriptions.complete()
        #   etc.
        # is automatically recorded as a trace in MLflow.
        #
        # LIMITATION: streaming chat is NOT auto-traced; we handle that
        # manually in the stream_reply() function further down.
        #
        # Doc: https://mlflow.org/docs/latest/genai/tracing/integrations/listing/mistral/
        mlflow.mistral.autolog()
        logger.info(
            "MLflow tracing ENABLED — tracking URI: %s, experiment: %s",
            MLFLOW_TRACKING_URI,
            MLFLOW_EXPERIMENT_NAME,
        )
    except Exception as exc:
        # Broad catch on purpose: an import error, a bad URI or an
        # unreachable server must not prevent the app from starting.
        logger.warning(
            "MLflow init failed (%s) — tracing disabled. "
            "Set MLFLOW_TRACKING_URI to a valid MLflow 3 server URL.",
            exc,
        )
        MLFLOW_ENABLED = False
# ── Mistral SDK client (used for auto-traced calls) ─────────────────────
# We instantiate the async Mistral client only when tracing is enabled.
# transcribe() awaits calls on this client (audio transcriptions) so that
# they can be auto-traced by mlflow.mistral.autolog(); when `_mistral` is
# None, transcribe() falls back to a raw httpx request instead.
# The LLM streaming call uses raw httpx (see stream_llm) because
# the Mistral SDK + streaming is not supported by MLflow auto-tracing.
#
# NOTE(review): the file also imports `Mistral` from `mistralai.client`;
# it is not obvious that the same SDK release exports
# `mistralai.async_client.MistralAsync`. If this import fails, the except
# branch logs a warning, `_mistral` stays None and STT silently uses the
# untraced httpx path — confirm against the installed SDK version.
if MLFLOW_ENABLED:
    try:
        from mistralai.async_client import MistralAsync as _MistralAsync
        _mistral = _MistralAsync(api_key=MISTRAL_API_KEY)
    except Exception as exc:
        logger.warning("Failed to create Mistral SDK client: %s", exc)
        _mistral = None
else:
    _mistral = None
# ══════════════════════════════════════════════════════════════════════════
# IN-MEMORY METRICS (not related to MLflow — shown in the Gradio UI)
# ══════════════════════════════════════════════════════════════════════════
class Metrics:
    """Simple in-memory counters and accumulators for the UI stats panel.

    The ``record_*`` coroutines update the counters under ``self.lock``;
    ``reset`` and ``snapshot_md`` are plain synchronous methods.
    Nothing is persisted: all values are lost when the process restarts.
    """

    def __init__(self):
        self.lock = asyncio.Lock()
        self.reset()

    def reset(self):
        """Zero every counter and forget the stored error messages."""
        self.stt_count = 0
        self.stt_total_s = 0.0
        self.llm_count = 0
        self.llm_total_s = 0.0
        self.tts_count = 0
        self.tts_total_s = 0.0
        self.total_tokens = 0
        self.error_count = 0
        self.last_errors: list[str] = []

    async def record_stt(self, elapsed: float):
        """Count one STT call that took ``elapsed`` seconds."""
        async with self.lock:
            self.stt_count += 1
            self.stt_total_s += elapsed

    async def record_llm(self, elapsed: float, tokens: int):
        """Count one LLM call of ``elapsed`` seconds that produced ``tokens``."""
        async with self.lock:
            self.llm_count += 1
            self.llm_total_s += elapsed
            self.total_tokens += tokens

    async def record_tts(self, elapsed: float):
        """Count one TTS call that took ``elapsed`` seconds."""
        async with self.lock:
            self.tts_count += 1
            self.tts_total_s += elapsed

    async def record_error(self, msg: str):
        """Count one error and keep ``msg`` among the 20 most recent ones."""
        async with self.lock:
            self.error_count += 1
            self.last_errors.append(msg)
            if len(self.last_errors) > 20:
                self.last_errors.pop(0)

    def snapshot_md(self) -> str:
        """Render the current counters as a Markdown table plus error list.

        Only the five most recent errors are listed, each cut to 120
        characters.
        """

        def avg(total, count):
            return f"{total/count:.2f}s" if count else "—"

        def as_inline_code(msg):
            # Error texts can contain newlines (e.g. a JSON error body) or
            # backticks; either would break out of the `inline code` list
            # item, so collapse whitespace runs and swap backticks before
            # truncating.
            flat = " ".join(str(msg).split())
            return flat.replace("`", "'")[:120]

        lines = [
            "### 📊 Metrics",
            "| Service | Calls | Avg Latency | Total Time |",
            "|---------|-------|-------------|------------|",
            f"| **STT** | {self.stt_count} | {avg(self.stt_total_s, self.stt_count)} | {self.stt_total_s:.1f}s |",
            f"| **LLM** | {self.llm_count} | {avg(self.llm_total_s, self.llm_count)} | {self.llm_total_s:.1f}s |",
            f"| **TTS** | {self.tts_count} | {avg(self.tts_total_s, self.tts_count)} | {self.tts_total_s:.1f}s |",
            "",
            f"**Total tokens generated:** {self.total_tokens}",
            f"**Errors:** {self.error_count}",
        ]
        if self.last_errors:
            lines.append("\n**Recent errors:**")
            lines.extend(
                f"- `{as_inline_code(e)}`" for e in self.last_errors[-5:]
            )
        return "\n".join(lines)
# Single shared metrics instance used by every handler in this module.
_metrics = Metrics()


async def reset_metrics():
    """Clear the in-memory counters and return the re-rendered stats markdown."""
    _metrics.reset()
    refreshed_panel = _metrics.snapshot_md()
    return refreshed_panel
# ══════════════════════════════════════════════════════════════════════════
# HELPERS
# ══════════════════════════════════════════════════════════════════════════
def to_chatbot(history: list) -> list[dict]:
    """Flatten ``[(user, bot), …]`` pairs into Gradio ``messages`` dicts.

    Each pair contributes a ``user`` entry followed by an ``assistant``
    entry; a side whose text is falsy (empty or ``None``) is left out.
    """
    return [
        {"role": role, "content": text}
        for user_text, assistant_text in history
        for role, text in (("user", user_text), ("assistant", assistant_text))
        if text
    ]
def build_messages(history: list, user_text: str) -> list[dict]:
    """Build the messages array for the Mistral Chat API from conversation history.

    Args:
        history: Past turns as ``(user_text, assistant_text)`` pairs; ``None``
            is treated as an empty history.
        user_text: The new user message, appended last.

    Returns:
        ``[SYSTEM_PROMPT, user, assistant, …, user]`` as role/content dicts.

    Past turns without a usable assistant reply are dropped as a whole pair
    (so user/assistant alternation is preserved):
      - an empty reply, which is what gets stored when the stream produced
        no text;
      - a reply starting with "⚠️", the prefix this module uses for error
        messages shown in place of a model answer. Replaying those would
        present our own error text to the model as something it said.
    """
    messages = [SYSTEM_PROMPT]
    for human, assistant in history or []:
        reply = (assistant or "").strip()
        if not human or not reply or reply.startswith("⚠️"):
            continue
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_text})
    return messages
# ══════════════════════════════════════════════════════════════════════════
# STT — speech-to-text (Mistral Transcriptions API, auto-traced via SDK)
# ══════════════════════════════════════════════════════════════════════════
def _audio_to_wav_bytes(audio_path: str) -> bytes:
    """Blocking helper: return ``audio_path`` re-encoded as 16 kHz mono WAV bytes.

    ffmpeg writes into a temporary file that is always removed afterwards.
    If the conversion fails for any reason, the untouched bytes of
    ``audio_path`` are returned instead (best effort). An error while
    reading ``audio_path`` itself propagates to the caller.
    """
    import subprocess
    import tempfile as tf

    wav_path = None
    try:
        # Reserve a path only; ffmpeg (-y) overwrites the empty file.
        with tf.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            wav_path = tmp.name
        subprocess.run(
            ["ffmpeg", "-y", "-i", audio_path, "-ar", "16000", "-ac", "1", wav_path],
            capture_output=True, check=True,
        )
        with open(wav_path, "rb") as f:
            return f.read()
    except Exception as e:
        logger.warning("ffmpeg conversion failed, sending raw audio: %s", e)
        with open(audio_path, "rb") as f:
            return f.read()
    finally:
        if wav_path and os.path.exists(wav_path):
            os.unlink(wav_path)


async def transcribe(audio_path: str) -> str:
    """Convert audio to 16 kHz mono WAV and transcribe via Mistral STT API.

    When the module-level SDK client ``_mistral`` is available, the call goes
    through the Mistral Python SDK (and can be captured by
    mlflow.mistral.autolog()); otherwise a raw httpx multipart request is
    sent to ``STT_API_URL``.

    Args:
        audio_path: Path of the recorded audio file.

    Returns:
        The transcribed text.

    Raises:
        Exception: Any preprocessing, network or API failure is recorded in
            the metrics as an ``STT: …`` error and then re-raised.
    """
    t0 = time.perf_counter()
    try:
        # ── Audio preprocessing ───────────────────────────────────────
        # ffmpeg and the file reads are blocking; run them in a worker
        # thread so other requests keep being served by the event loop.
        # NOTE: on ffmpeg failure the raw bytes are still uploaded under
        # the name "audio.wav" below, as before.
        audio_bytes = await asyncio.to_thread(_audio_to_wav_bytes, audio_path)

        # ── Transcribe ────────────────────────────────────────────────
        if _mistral is not None:
            # Mistral SDK call — the path intended to be auto-traced by
            # MLflow via autolog().
            result = await _mistral.audio.transcriptions.complete(
                model=STT_MODEL,
                file={"content": audio_bytes, "file_name": "audio.wav"},
            )
            text = result.text
        else:
            # Fallback: raw httpx (no tracing).
            headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}"}
            async with httpx.AsyncClient(timeout=120.0) as client:
                resp = await client.post(
                    STT_API_URL,
                    headers=headers,
                    files={
                        "file": ("audio.wav", audio_bytes, "audio/wav"),
                        "model": (None, STT_MODEL),
                    },
                )
                resp.raise_for_status()
                text = resp.json()["text"]
        elapsed = time.perf_counter() - t0
        await _metrics.record_stt(elapsed)
        logger.info(
            "STT ok %.2fs %.0f bytes model=%s",
            elapsed,
            len(audio_bytes),
            STT_MODEL,
        )
        return text
    except Exception as e:
        await _metrics.record_error(f"STT: {e}")
        raise
# ══════════════════════════════════════════════════════════════════════════
# TTS — text-to-speech (Mistral Speech API via the Mistral SDK, no manual trace span)
# ══════════════════════════════════════════════════════════════════════════
async def call_tts(client: Mistral, text: str) -> tuple | None:
    """Turn ``text`` into speech with the Mistral TTS API.

    The SDK's ``audio.speech.complete`` is called with an MP3 response
    format; the base64 payload is decoded and read into float32 samples.
    No manual MLflow span is recorded for this call.

    Args:
        client: Mistral SDK client used for the request.
        text: Text to synthesise.

    Returns:
        ``(sample_rate, samples)`` on success, ``None`` on any failure
        (the failure is logged and counted in the metrics, not raised).
    """
    started = time.perf_counter()
    try:
        request: dict = {
            "model": TTS_MODEL,
            "input": text,
            "response_format": "mp3",
        }
        if TTS_VOICE:
            request.update(voice_id=TTS_VOICE)

        def _blocking_request():
            return client.audio.speech.complete(**request)

        # The SDK call is synchronous, so it is handed to the default
        # executor to keep the event loop responsive.
        running_loop = asyncio.get_running_loop()
        response = await running_loop.run_in_executor(None, _blocking_request)

        mp3_stream = io.BytesIO(base64.b64decode(response.audio_data))
        samples, sample_rate = sf.read(mp3_stream, dtype="float32")

        duration = time.perf_counter() - started
        await _metrics.record_tts(duration)
        logger.info("TTS ok %.2fs %d chars model=%s", duration, len(text), TTS_MODEL)
        return sample_rate, samples
    except Exception as err:
        await _metrics.record_error(f"TTS: {err}")
        logger.warning("TTS failed (%.1fs): %s", time.perf_counter() - started, err)
        return None
# ══════════════════════════════════════════════════════════════════════════
# LLM — streaming text generation (Mistral Chat API via httpx SSE)
# ══════════════════════════════════════════════════════════════════════════
async def stream_llm(messages: list[dict]):
    """Stream tokens from Mistral Chat API via Server-Sent Events.

    Yields ``(text, cumulative_count)`` for every SSE chunk that carries
    content. ``cumulative_count`` counts content chunks received so far; it
    is what this module reports as "tokens".

    Errors are not raised: they are yielded once as
    ``("⚠️ …message…", 0)`` and the generator then ends. Callers detect
    them through the "⚠️" prefix.

    Uses raw httpx because MLflow's Mistral autolog does NOT support
    streaming chat completions (the auto-trace would miss the response).
    Instead, the caller (stream_reply) manually records an MLflow span.

    NOTE(review): the page header of this file mentions a commit that
    "decrease[s] temperature to 0", but the request below sends 0.7 —
    confirm which value is intended.
    """
    headers = {
        "Authorization": f"Bearer {MISTRAL_API_KEY}",
        "Content-Type": "application/json",
    }
    body = {
        "model": LLM_MODEL,
        "messages": messages,
        "stream": True,
        "max_tokens": 512,
        "temperature": 0.7,
        "top_p": 0.9,
    }
    t0 = time.perf_counter()
    token_count = 0
    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(30.0, read=120.0)) as client:
            async with client.stream(
                "POST", LLM_API_URL, json=body, headers=headers,
            ) as resp:
                try:
                    resp.raise_for_status()
                except httpx.HTTPStatusError:
                    # The error body has to be read while the stream is
                    # still open: leaving this `async with` closes the
                    # response, after which it can no longer be read.
                    await resp.aread()
                    raise
                async for line in resp.aiter_lines():
                    # SSE data lines; the space after the colon is optional.
                    if not line.startswith("data:"):
                        continue
                    payload = line[5:].strip()
                    if payload == "[DONE]":
                        break
                    if not payload:
                        continue
                    try:
                        data = json.loads(payload)
                    except json.JSONDecodeError:
                        # Skip malformed chunks rather than abort the reply.
                        continue
                    choices = data.get("choices", [])
                    if not choices:
                        continue
                    delta = choices[0].get("delta", {})
                    content = delta.get("content")
                    if content:
                        token_count += 1
                        yield content, token_count
        elapsed = time.perf_counter() - t0
        await _metrics.record_llm(elapsed, token_count)
        logger.info(
            "LLM ok %.2fs %d tokens %.1f tok/s model=%s",
            elapsed, token_count,
            token_count / elapsed if elapsed else 0,
            LLM_MODEL,
        )
    except httpx.HTTPStatusError as e:
        # Body was already loaded inside the stream context above.
        detail = e.response.content.decode(errors="replace")[:300]
        elapsed = time.perf_counter() - t0
        await _metrics.record_error(
            f"LLM HTTP {e.response.status_code}: {detail[:80]}"
        )
        logger.error("LLM HTTP %s %.1fs %s", e.response.status_code, elapsed, detail)
        yield f"⚠️ LLM API error ({e.response.status_code}): {detail}", 0
    except Exception as e:
        elapsed = time.perf_counter() - t0
        await _metrics.record_error(f"LLM: {type(e).__name__}: {e}")
        logger.error("LLM error %.1fs %s", elapsed, e)
        yield f"⚠️ LLM error: {type(e).__name__}: {e}", 0
# ══════════════════════════════════════════════════════════════════════════
# STREAM ORCHESTRATOR — consumes LLM tokens, records MLflow trace, calls TTS
# ══════════════════════════════════════════════════════════════════════════
async def stream_reply(messages: list[dict]):
    """Consume LLM stream, record an MLflow trace, and synthesize audio.

    Yields ``(partial_reply, audio_tuple_or_None)``:
      - after each received token, the reply text accumulated so far;
      - ``audio_tuple`` is the ``(sample_rate, samples)`` result of
        call_tts() for a sentence that has just been completed, else None.
    If the LLM stream reports an error (text starting with "⚠️"), that
    error text is yielded alone and the generator stops.

    MLflow tracing:
      - The Mistral auto-log does not cover streaming completions.
      - A single span ("llm_chat_stream") is recorded in the ``finally``
        block with the input messages, the generation parameters and the
        aggregated output; an error reported by the stream marks the span
        as ERROR, even when no text was produced.
      - If tracing is disabled (MLFLOW_ENABLED is false), nothing is
        recorded.
      - The recorded latency spans the whole loop, so it includes the time
        spent in TTS calls interleaved with the token stream.
    """
    token_buffer = ""
    full_reply = ""
    _trace_token_count = 0
    _trace_start = time.perf_counter()
    _trace_error = None

    async def _speak(text: str):
        # One short-lived SDK client per TTS request, closed on exit.
        async with Mistral(api_key=MISTRAL_API_KEY) as client:
            return await call_tts(client, text)

    try:
        async for token_or_error, token_count in stream_llm(messages):
            # ── Error in stream → stop & propagate ───────────────────
            if token_or_error.startswith("⚠️"):
                _trace_error = token_or_error
                yield token_or_error, None
                return
            _trace_token_count = token_count
            token_buffer += token_or_error
            full_reply += token_or_error
            # ── Sentence-boundary TTS ─────────────────────────────────
            # When a complete sentence has arrived, dispatch a TTS call
            # so audio is available before the full reply finishes.
            match = SENTENCE_END.search(token_buffer)
            if match:
                sentence = token_buffer[: match.end()].strip()
                token_buffer = token_buffer[match.end():]
                if sentence:
                    audio = await _speak(sentence)
                    if audio is not None:
                        yield full_reply, audio
                        continue
            yield full_reply, None
        # ── Remaining text after last sentence boundary ───────────────
        if token_buffer.strip():
            audio = await _speak(token_buffer.strip())
            if audio is not None:
                yield full_reply, audio
                return
        yield full_reply, None
    finally:
        # ── MLflow manual trace ───────────────────────────────────────
        # Runs when the generator finishes, is closed early, or raises.
        # Error-only streams (no text at all) are traced as well so that
        # failed requests show up in MLflow.
        if MLFLOW_ENABLED and (full_reply or _trace_error):
            _elapsed = time.perf_counter() - _trace_start
            try:
                import mlflow
                from mlflow.entities import SpanStatus, SpanStatusCode
                # start_span outside of a trace context creates a new
                # root span (i.e. a new trace).
                with mlflow.start_span("llm_chat_stream") as span:
                    # Keep these values in sync with the request body
                    # built in stream_llm().
                    span.set_inputs({
                        "model": LLM_MODEL,
                        "messages": messages,
                        "temperature": 0.7,
                        "top_p": 0.9,
                        "max_tokens": 512,
                    })
                    span.set_outputs({
                        "response": full_reply,
                        "token_count": _trace_token_count,
                        "latency_seconds": round(_elapsed, 3),
                        "tokens_per_second": round(
                            _trace_token_count / _elapsed, 1
                        ) if _elapsed > 0 else 0,
                    })
                    if _trace_error:
                        # set_status takes a single status object/string.
                        span.set_status(
                            SpanStatus(
                                SpanStatusCode.ERROR,
                                description=str(_trace_error),
                            )
                        )
            except Exception as trace_err:
                # Don't crash the app if the MLflow server is unreachable.
                logger.warning("MLflow trace recording failed: %s", trace_err)
# ══════════════════════════════════════════════════════════════════════════
# GRADIO TURN HANDLERS
# ══════════════════════════════════════════════════════════════════════════
async def run_turn(user_text: str, history: list):
    """Drive one conversation turn shared by the text and voice handlers.

    Yields 5-tuples ``(chat_messages, history, None, last_audio, gr.update())``:
      - while streaming: the reply so far with a "▌" cursor appended, and
        ``None`` for ``last_audio``;
      - one final frame: the cursor removed and ``last_audio`` set to
        ``(sample_rate, all sentence chunks concatenated)``, or ``None``
        when no audio was produced.
    The third slot is always ``None``: partial audio is not emitted.
    """
    past_turns = history or []
    messages = build_messages(past_turns, user_text)
    new_history = past_turns + [(user_text, "▌")]

    collected: list[np.ndarray] = []
    sample_rate = None

    async for partial_reply, audio in stream_reply(messages):
        new_history[-1] = (user_text, f"{partial_reply}▌")
        if audio is not None:
            # Every chunk overwrites the rate; the last one is reported.
            sample_rate, samples = audio
            collected.append(samples)
        yield to_chatbot(new_history), new_history, None, None, gr.update()

    # Drop the streaming cursor (and trailing whitespace) from the reply.
    displayed = new_history[-1][1]
    new_history[-1] = (user_text, displayed.rstrip("▌").rstrip())

    if collected:
        full_audio = (sample_rate, np.concatenate(collected))
    else:
        full_audio = None
    yield to_chatbot(new_history), new_history, None, full_audio, gr.update()
async def text_turn(user_text: str, history: list):
    """Handle a typed message.

    Yields 6-tuples ``(chat, history, audio, last_audio, textbox, stats)``.
    Blank input yields the current state once and leaves the textbox as
    is. Otherwise the frames of run_turn() are forwarded; the first frame
    clears the textbox and later frames leave it untouched.
    """
    if user_text.strip():
        # "" on the first frame empties the textbox; afterwards a no-op update.
        textbox_value = ""
        async for chat, hist, audio, saved_audio, _ in run_turn(user_text, history):
            yield chat, hist, audio, saved_audio, textbox_value, gr.update()
            textbox_value = gr.update()
    else:
        yield to_chatbot(history or []), history or [], None, None, user_text, gr.update()
async def voice_turn(audio_path: str, history: list):
    """Handle a recorded voice message.

    Yields 5-tuples ``(chat, history, audio, last_audio, stats)``.
    Without a recording the current state is yielded once. If
    transcription fails, the error is shown in the chat as a "[voice]"
    turn while the stored history is left unchanged. Otherwise the
    transcript is passed to run_turn() and its frames are forwarded.
    """
    if audio_path is None:
        yield to_chatbot(history or []), history or [], None, None, gr.update()
        return

    try:
        user_text = await transcribe(audio_path)
    except Exception as e:
        logger.exception("STT error")
        err = f"⚠️ STT error: {type(e).__name__}: {e}"
        shown = (history or []) + [("[voice]", err)]
        yield to_chatbot(shown), history or [], None, None, gr.update()
    else:
        async for chat, hist, audio, saved_audio, _ in run_turn(user_text, history):
            yield chat, hist, audio, saved_audio, gr.update()
# ══════════════════════════════════════════════════════════════════════════
# GRADIO UI
# ══════════════════════════════════════════════════════════════════════════
# Header markdown shown at the top of the page; lists the configured models.
DESCRIPTION = f"""## Voice Chatbot
_API LLM: `{LLM_MODEL}` · STT: `{STT_MODEL}` · TTS: `{TTS_MODEL}`_
"""
# Components are created in display order inside the Blocks context.
with gr.Blocks(
    title="Voice Chatbot",
    theme=gr.themes.Soft(),
    css="footer {visibility: hidden}",
) as demo:
    gr.Markdown(DESCRIPTION)
    chatbot = gr.Chatbot(label="Conversation", height=380)
    # Conversation history as a list of (user, assistant) tuples.
    state = gr.State([])
    # Audio of the last complete reply, kept for the Replay button.
    last_audio = gr.State(None)
    # ── Text input row ────────────────────────────────────────────────
    with gr.Row():
        text_box = gr.Textbox(
            placeholder="Type a message and press Enter …",
            show_label=False,
            scale=5,
        )
        send_btn = gr.Button("Send", scale=1, variant="primary")
    # ── Voice input row ───────────────────────────────────────────────
    with gr.Row():
        mic_input = gr.Audio(
            sources=["microphone"],
            type="filepath",
            label="Voice input",
            scale=5,
        )
        voice_btn = gr.Button("Submit voice", scale=1)
    # Known issue with streaming audio output on Hugging Face; a reference
    # implementation is the Gradio Space
    # https://huggingface.co/spaces/gradio/stream_audio_out/blob/main/app.py
    # NOTE(review): the original trailing comment ended with ", autoplay=True)",
    # which suggests autoplay was removed from this call — confirm.
    audio_out = gr.Audio(label="Bot response (audio)")
    replay_btn = gr.Button("🔁 Replay", size="sm")
    # ── Observability panel ───────────────────────────────────────────
    with gr.Accordion("📊 Stats & Observability", open=False):
        stats_md = gr.Markdown(_metrics.snapshot_md())
        with gr.Row():
            refresh_btn = gr.Button("🔄 Refresh", size="sm")
            reset_btn = gr.Button("🗑 Reset", size="sm")
        refresh_btn.click(fn=_metrics.snapshot_md, outputs=[stats_md])
        reset_btn.click(fn=reset_metrics, outputs=[stats_md])
    # ── Event wiring ──────────────────────────────────────────────────
    # The output lists must match the tuple shapes yielded by the handlers:
    # text_turn yields 6 values, voice_turn yields 5. Both handlers yield
    # gr.update() for stats_md, so the stats panel only changes through
    # the Refresh / Reset buttons above.
    send_btn.click(
        text_turn,
        inputs=[text_box, state],
        outputs=[chatbot, state, audio_out, last_audio, text_box, stats_md],
    )
    text_box.submit(
        text_turn,
        inputs=[text_box, state],
        outputs=[chatbot, state, audio_out, last_audio, text_box, stats_md],
    )
    voice_btn.click(
        voice_turn,
        inputs=[mic_input, state],
        outputs=[chatbot, state, audio_out, last_audio, stats_md],
    )
    # Replay copies the stored reply audio into the audio player.
    replay_btn.click(
        fn=lambda a: a,
        inputs=[last_audio],
        outputs=[audio_out],
    )
# Enable the request queue (the handlers are async generators that stream).
demo.queue()
if __name__ == "__main__":
    # Listen on all interfaces.
    demo.launch(server_name="0.0.0.0")