# wami/app.py — Dioula STT, TTS & translation API (Hugging Face Space)
# Commit 9b5d5f1 — Fix: use NllbTokenizer instead of AutoTokenizer
import base64
import io
import json
import os
import tempfile
from pathlib import Path

import numpy as np
import scipy.io.wavfile
import soundfile as sf
import torch
import torchaudio
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from starlette.background import BackgroundTask
# FastAPI application — metadata below is shown in the generated OpenAPI docs.
app = FastAPI(
    title="Wami - Dioula STT, TTS & Traduction API",
    description="API de reconnaissance vocale (STT), synthèse vocale (TTS) et traduction (Dioula ↔ Français)",
    version="1.1.0"
)
# CORS: allow calls from any origin — this is a public demo API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global error handlers: normalize every error into a JSON body {"error": ...}
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Turn an HTTPException into a JSON response of the form {"error": detail}."""
    payload = {"error": exc.detail}
    return JSONResponse(status_code=exc.status_code, content=payload)
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: any uncaught exception becomes a 500 JSON error."""
    detail = f"Erreur serveur: {str(exc)}"
    return JSONResponse(status_code=500, content={"error": detail})
# Globals — model handles, populated once by load_models() at startup.
stt_processor = None        # transformers AutoProcessor for the MMS STT model
stt_model = None            # Wav2Vec2ForCTC with the "dyu" (Dioula) adapter
tts_tokenizer = None        # tokenizer for the MMS TTS model
tts_model = None            # VitsModel for Dioula speech synthesis
translate_tokenizer = None  # NllbTokenizer (Dioula <-> French)
translate_model = None      # NLLB-200 seq2seq translation model
device = "cpu"              # replaced by "cuda" at startup when available
@app.on_event("startup")
def load_models():
    """Load the STT, TTS and translation models once at application startup.

    Imports of `transformers` classes are kept local so the module can be
    imported without pulling the heavy ML stack until startup actually runs.
    """
    global stt_processor, stt_model, tts_tokenizer, tts_model, translate_tokenizer, translate_model, device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"🚀 Device: {device}")
    # STT: MMS multilingual ASR, restricted to the Dioula ("dyu") adapter.
    from transformers import AutoProcessor, Wav2Vec2ForCTC
    print("⏳ Chargement du modèle STT (Dioula)...")
    stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all", target_lang="dyu")
    stt_model = Wav2Vec2ForCTC.from_pretrained(
        "facebook/mms-1b-all",
        target_lang="dyu",
        # The adapter head has a different vocab size than the base model.
        ignore_mismatched_sizes=True
    )
    stt_model.load_adapter("dyu")
    stt_model.to(device)
    print("✅ STT prêt!")
    # TTS: MMS VITS model for Dioula speech synthesis.
    from transformers import AutoTokenizer, VitsModel
    print("⏳ Chargement du modèle TTS (Dioula)...")
    tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-dyu")
    tts_model = VitsModel.from_pretrained("facebook/mms-tts-dyu").to(device)
    print("✅ TTS prêt!")
    # Translation: NLLB-200 distilled, Dioula <-> French.
    # NllbTokenizer is used explicitly rather than AutoTokenizer.
    from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
    print("⏳ Chargement du modèle de traduction (Dioula → Français)...")
    translate_tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
    translate_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M").to(device)
    print("✅ Traduction prête!")
# Landing page with human-readable API documentation
@app.get("/", response_class=HTMLResponse)
def home():
    """Serve the static HTML landing page that documents the API endpoints."""
    # The HTML lives at column 0 inside the triple-quoted string so the
    # served markup is exactly this text.
    return """
<!DOCTYPE html>
<html lang="fr">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Wami - API Dioula STT & TTS</title>
<style>
body { font-family: system-ui; max-width: 800px; margin: 40px auto; padding: 20px; line-height: 1.6; }
h1 { color: #2563eb; }
h2 { color: #1e40af; margin-top: 30px; }
code { background: #f1f5f9; padding: 2px 6px; border-radius: 4px; }
pre { background: #0f172a; color: #e2e8f0; padding: 16px; border-radius: 8px; overflow-x: auto; }
.endpoint { background: #f8fafc; padding: 16px; border-left: 4px solid #3b82f6; margin: 16px 0; }
.method { display: inline-block; padding: 4px 8px; border-radius: 4px; font-weight: bold; margin-right: 8px; }
.post { background: #10b981; color: white; }
.get { background: #3b82f6; color: white; }
</style>
</head>
<body>
<h1>🎙️ Wami - API Dioula STT, TTS & Traduction</h1>
<p>API de reconnaissance vocale (STT), synthèse vocale (TTS) et traduction (Dioula ↔ Français).</p>
<h2>📖 Endpoints</h2>
<div class="endpoint">
<p><span class="method get">GET</span> <code>/</code></p>
<p>Cette page de documentation</p>
</div>
<div class="endpoint">
<p><span class="method get">GET</span> <code>/health</code></p>
<p>Statut de l'API et des modèles</p>
</div>
<div class="endpoint">
<p><span class="method post">POST</span> <code>/api/stt</code></p>
<p><strong>Speech-to-Text</strong> - Transcrit un fichier audio en texte Dioula</p>
<p><strong>Entrée:</strong> Fichier audio (WebM, WAV, MP3)</p>
<p><strong>Sortie:</strong> <code>{"transcription": "texte en dioula"}</code></p>
<pre>curl -X POST https://votre-space.hf.space/api/stt \\
-F "audio=@recording.wav"</pre>
</div>
<div class="endpoint">
<p><span class="method post">POST</span> <code>/api/tts</code></p>
<p><strong>Text-to-Speech</strong> - Génère un audio en Dioula depuis du texte</p>
<p><strong>Entrée:</strong> Texte en Dioula (paramètre <code>text</code>)</p>
<p><strong>Sortie:</strong> Fichier WAV</p>
<pre>curl -X POST https://votre-space.hf.space/api/tts \\
-F "text=na an be do minkɛ" \\
-o output.wav</pre>
</div>
<div class="endpoint">
<p><span class="method post">POST</span> <code>/api/translate/dyu-fr</code></p>
<p><strong>Traduction Dioula → Français</strong></p>
<p><strong>Entrée:</strong> Texte en Dioula (paramètre <code>text</code>)</p>
<p><strong>Sortie:</strong> JSON avec traduction française</p>
<pre>curl -X POST https://votre-space.hf.space/api/translate/dyu-fr \\
-F "text=Sanji bɛna kɛ bi"</pre>
</div>
<div class="endpoint">
<p><span class="method post">POST</span> <code>/api/translate/fr-dyu</code></p>
<p><strong>Traduction Français → Dioula</strong></p>
<p><strong>Entrée:</strong> Texte en Français (paramètre <code>text</code>)</p>
<p><strong>Sortie:</strong> JSON avec traduction dioula</p>
<pre>curl -X POST https://votre-space.hf.space/api/translate/fr-dyu \\
-F "text=Il va pleuvoir aujourd'hui"</pre>
</div>
<div class="endpoint">
<p><span class="method get" style="background: #f59e0b;">WS</span> <code>/ws/pipeline</code></p>
<p><strong>Pipeline WebSocket</strong> - Audio → STT → Traduction (temps réel)</p>
<p><strong>Entrée:</strong> JSON avec audio base64</p>
<p><strong>Sortie:</strong> Progression en temps réel + résultats</p>
<pre>const ws = new WebSocket('wss://votre-space.hf.space/ws/pipeline');
ws.send(JSON.stringify({
action: "process",
audio: "base64_audio",
target_lang: "fr"
}));</pre>
</div>
<h2>🔗 Liens utiles</h2>
<p>
<a href="/demo" style="color: #3b82f6; font-weight: bold;">🎯 Demo Live</a> |
<a href="/ws-demo" style="color: #f59e0b; font-weight: bold;">🔄 WebSocket Demo</a> |
<a href="/docs">Swagger UI</a> |
<a href="/redoc">ReDoc</a>
</p>
<h2>ℹ️ Modèles</h2>
<ul>
<li><strong>STT:</strong> facebook/mms-1b-all (adapter Dioula)</li>
<li><strong>TTS:</strong> facebook/mms-tts-dyu</li>
<li><strong>Traduction:</strong> facebook/nllb-200-distilled-600M (Dioula ↔ Français)</li>
</ul>
<h2>🔄 Flux de travail complet</h2>
<p><strong>Exemple :</strong> Audio Dioula → Texte Dioula → Traduction Français → Audio Français</p>
<ol>
<li><code>/api/stt</code> : Convertit audio en texte dioula</li>
<li><code>/api/translate/dyu-fr</code> : Traduit en français</li>
<li>(Optionnel) Utiliser un TTS français pour générer l'audio</li>
</ol>
</body>
</html>
"""
@app.get("/demo", response_class=HTMLResponse)
def demo_page():
    """Serve the interactive demo page from demo.html in the working directory."""
    demo_file = Path("demo.html")
    return demo_file.read_text(encoding="utf-8")
@app.get("/ws-demo", response_class=HTMLResponse)
def websocket_demo_page():
    """Serve the WebSocket pipeline demo page from websocket_example.html."""
    page = Path("websocket_example.html")
    return page.read_text(encoding="utf-8")
@app.get("/health")
def health_check():
    """Report API liveness, the compute device, and which models are loaded."""
    loaded = {
        "stt": stt_model is not None,
        "tts": tts_model is not None,
        "translate": translate_model is not None,
    }
    return {"status": "healthy", "device": device, "models_loaded": loaded}
@app.post("/api/stt")
async def speech_to_text(audio: UploadFile = File(...)):
    """
    Transcribe an uploaded audio file into Dioula text.

    - **audio**: audio file (WebM, WAV, MP3, ...)

    Returns ``{"transcription": "..."}``. Raises HTTP 400 when the audio
    cannot be decoded, HTTP 500 on any other failure.
    """
    tmp_input = None
    tmp_wav = None
    try:
        audio_bytes = await audio.read()
        # Choose a temp-file suffix from the declared content type.
        # Fix: browsers/clients send MP3 as "audio/mpeg", which the old
        # check ("mp3" only) missed, sending MP3s down the .webm path.
        content_type = audio.content_type or ""
        if "webm" in content_type:
            suffix = ".webm"
        elif "wav" in content_type:
            suffix = ".wav"
        elif "mp3" in content_type or "mpeg" in content_type:
            suffix = ".mp3"
        else:
            suffix = ".webm"
        # Persist the upload so the audio libraries can read it from disk.
        tmp_input = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
        tmp_input.write(audio_bytes)
        tmp_input.close()
        # Convert to WAV when needed; soundfile sniffs the real format
        # from the file content, not from the suffix.
        if suffix != ".wav":
            try:
                audio_data, sample_rate = sf.read(tmp_input.name)
                tmp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
                tmp_wav.close()
                sf.write(tmp_wav.name, audio_data, sample_rate)
                audio_path = tmp_wav.name
            except Exception as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Impossible de lire l'audio. Format non supporté. Erreur: {str(e)}"
                )
        else:
            audio_path = tmp_input.name
        # Load with torchaudio, downmix to mono, and resample to the
        # 16 kHz rate the STT model expects.
        audio_input, sample_rate = torchaudio.load(audio_path)
        if audio_input.shape[0] > 1:
            audio_input = torch.mean(audio_input, dim=0, keepdim=True)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            audio_input = resampler(audio_input)
        audio_input = audio_input.squeeze()
        # CTC inference: greedy argmax decode over the logits.
        inputs = stt_processor(audio_input, sampling_rate=16000, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = stt_model(**inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = stt_processor.batch_decode(predicted_ids)[0]
        return {"transcription": transcription}
    except HTTPException:
        raise
    except Exception as e:
        print(f"Erreur STT: {e}")
        raise HTTPException(status_code=500, detail=f"Erreur lors de la transcription: {str(e)}")
    finally:
        # Always remove the temp files, on success and on failure alike.
        if tmp_input and Path(tmp_input.name).exists():
            Path(tmp_input.name).unlink(missing_ok=True)
        if tmp_wav and Path(tmp_wav.name).exists():
            Path(tmp_wav.name).unlink(missing_ok=True)
@app.post("/api/tts")
async def text_to_speech(text: str = Form(...)):
    """
    Generate Dioula speech audio from text.

    - **text**: Dioula text to synthesize

    Returns the synthesized audio as a WAV file. Raises HTTP 400 on empty
    text, HTTP 500 on any other failure.
    """
    try:
        if not text.strip():
            raise HTTPException(status_code=400, detail="Le texte ne peut pas être vide")
        inputs = tts_tokenizer(text, return_tensors="pt").to(device)
        with torch.no_grad():
            waveform = tts_model(**inputs).waveform
        audio_data = waveform[0].cpu().numpy()
        sample_rate = tts_model.config.sampling_rate
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        scipy.io.wavfile.write(tmp.name, rate=sample_rate, data=audio_data)
        tmp.close()
        # Fix: the temp WAV used to be leaked on every request. Delete it
        # after the response body has been streamed to the client.
        cleanup = BackgroundTask(lambda: Path(tmp.name).unlink(missing_ok=True))
        return FileResponse(
            tmp.name,
            media_type="audio/wav",
            filename="tts_dioula.wav",
            background=cleanup,
        )
    except HTTPException:
        raise
    except Exception as e:
        print(f"Erreur TTS: {e}")
        raise HTTPException(status_code=500, detail=f"Erreur lors de la génération audio: {str(e)}")
@app.post("/api/translate/dyu-fr")
async def translate_dioula_to_french(text: str = Form(...)):
    """
    Translate text from Dioula to French.

    - **text**: Dioula text to translate
    """
    try:
        if not text.strip():
            raise HTTPException(status_code=400, detail="Le texte ne peut pas être vide")
        # Tell the NLLB tokenizer the source language, then encode.
        translate_tokenizer.src_lang = "dyu_Latn"
        encoded = translate_tokenizer(text, return_tensors="pt").to(device)
        # Force the decoder to start generating in the target language.
        bos_id = translate_tokenizer.convert_tokens_to_ids("fra_Latn")
        with torch.no_grad():
            output_ids = translate_model.generate(
                **encoded, forced_bos_token_id=bos_id, max_length=200
            )
        # Decode the generated token ids back into plain text.
        result = translate_tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
        return {
            "texte_source": text,
            "langue_source": "dioula",
            "texte_traduit": result,
            "langue_cible": "français",
        }
    except HTTPException:
        raise
    except Exception as e:
        print(f"Erreur Traduction DYU→FR: {e}")
        raise HTTPException(status_code=500, detail=f"Erreur lors de la traduction: {str(e)}")
@app.post("/api/translate/fr-dyu")
async def translate_french_to_dioula(text: str = Form(...)):
    """
    Translate text from French to Dioula.

    - **text**: French text to translate
    """
    try:
        if not text.strip():
            raise HTTPException(status_code=400, detail="Le texte ne peut pas être vide")
        # Tell the NLLB tokenizer the source language, then encode.
        translate_tokenizer.src_lang = "fra_Latn"
        encoded = translate_tokenizer(text, return_tensors="pt").to(device)
        # Force the decoder to start generating in the target language.
        bos_id = translate_tokenizer.convert_tokens_to_ids("dyu_Latn")
        with torch.no_grad():
            output_ids = translate_model.generate(
                **encoded, forced_bos_token_id=bos_id, max_length=200
            )
        # Decode the generated token ids back into plain text.
        result = translate_tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
        return {
            "texte_source": text,
            "langue_source": "français",
            "texte_traduit": result,
            "langue_cible": "dioula",
        }
    except HTTPException:
        raise
    except Exception as e:
        print(f"Erreur Traduction FR→DYU: {e}")
        raise HTTPException(status_code=500, detail=f"Erreur lors de la traduction: {str(e)}")
@app.websocket("/ws/pipeline")
async def websocket_pipeline(websocket: WebSocket):
    """
    WebSocket for the full pipeline: audio -> STT -> translation -> TTS.

    The client sends:
    {
        "action": "process",
        "audio": "base64_encoded_audio",
        "target_lang": "fr" or "dyu" (optional, default: "fr"),
        "include_tts": bool (optional, default: false)
    }
    The server replies with progress messages:
    {
        "status": "processing",
        "step": "stt" | "translate" | "tts" | "done",
        "transcription": "...",
        "traduction": "...",
        "audio_base64": "..." (only when TTS was requested)
    }
    """
    await websocket.accept()
    try:
        while True:
            # Receive one processing request from the client.
            data = await websocket.receive_text()
            message = json.loads(data)
            if message.get("action") != "process":
                await websocket.send_json({"error": "Action invalide. Utilisez 'process'"})
                continue
            audio_base64 = message.get("audio")
            target_lang = message.get("target_lang", "fr")
            include_tts = message.get("include_tts", False)
            if not audio_base64:
                await websocket.send_json({"error": "Aucun audio fourni"})
                continue
            try:
                # Step 1: decode the base64 audio payload to bytes.
                await websocket.send_json({"status": "processing", "step": "decoding", "message": "Décodage de l'audio..."})
                audio_bytes = base64.b64decode(audio_base64)
                # Persist to a temp file. The .webm suffix is assumed
                # (browser MediaRecorder); soundfile sniffs the real format.
                tmp_input = tempfile.NamedTemporaryFile(suffix=".webm", delete=False)
                tmp_input.write(audio_bytes)
                tmp_input.close()
                # Convert to WAV so torchaudio can load it below.
                try:
                    audio_data, sample_rate = sf.read(tmp_input.name)
                    tmp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
                    tmp_wav.close()
                    sf.write(tmp_wav.name, audio_data, sample_rate)
                    audio_path = tmp_wav.name
                except Exception as e:
                    await websocket.send_json({"error": f"Erreur de conversion audio: {str(e)}"})
                    Path(tmp_input.name).unlink(missing_ok=True)
                    continue
                # Step 2: STT — downmix to mono, resample to 16 kHz,
                # greedy CTC decode (same procedure as /api/stt).
                await websocket.send_json({"status": "processing", "step": "stt", "message": "Transcription en cours..."})
                audio_input, sample_rate = torchaudio.load(audio_path)
                if audio_input.shape[0] > 1:
                    audio_input = torch.mean(audio_input, dim=0, keepdim=True)
                if sample_rate != 16000:
                    resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                    audio_input = resampler(audio_input)
                audio_input = audio_input.squeeze()
                inputs = stt_processor(audio_input, sampling_rate=16000, return_tensors="pt")
                inputs = {k: v.to(device) for k, v in inputs.items()}
                with torch.no_grad():
                    logits = stt_model(**inputs).logits
                predicted_ids = torch.argmax(logits, dim=-1)
                transcription = stt_processor.batch_decode(predicted_ids)[0]
                await websocket.send_json({
                    "status": "processing",
                    "step": "stt_done",
                    "transcription": transcription
                })
                # Step 3: translation; direction is chosen by target_lang.
                await websocket.send_json({"status": "processing", "step": "translate", "message": "Traduction en cours..."})
                if target_lang == "fr":
                    translate_tokenizer.src_lang = "dyu_Latn"
                    target_lang_id = translate_tokenizer.convert_tokens_to_ids("fra_Latn")
                else:
                    translate_tokenizer.src_lang = "fra_Latn"
                    target_lang_id = translate_tokenizer.convert_tokens_to_ids("dyu_Latn")
                translate_inputs = translate_tokenizer(transcription, return_tensors="pt").to(device)
                with torch.no_grad():
                    translated_tokens = translate_model.generate(
                        **translate_inputs,
                        forced_bos_token_id=target_lang_id,
                        max_length=200
                    )
                traduction = translate_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
                await websocket.send_json({
                    "status": "processing",
                    "step": "translate_done",
                    "traduction": traduction
                })
                # Step 4: optional TTS — only when the target is Dioula,
                # since the loaded TTS model is Dioula-only.
                audio_b64 = None
                if include_tts and target_lang == "dyu":
                    await websocket.send_json({"status": "processing", "step": "tts", "message": "Génération audio..."})
                    tts_inputs = tts_tokenizer(traduction, return_tensors="pt").to(device)
                    with torch.no_grad():
                        waveform = tts_model(**tts_inputs).waveform
                    audio_data_tts = waveform[0].cpu().numpy()
                    sample_rate_tts = tts_model.config.sampling_rate
                    tmp_tts = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
                    scipy.io.wavfile.write(tmp_tts.name, rate=sample_rate_tts, data=audio_data_tts)
                    with open(tmp_tts.name, "rb") as f:
                        audio_b64 = base64.b64encode(f.read()).decode('utf-8')
                    Path(tmp_tts.name).unlink(missing_ok=True)
                # Final result message.
                # NOTE(review): "langue_source" is always reported as
                # "dioula", even when target_lang == "dyu" (French input) —
                # looks inconsistent; confirm intended behavior.
                result = {
                    "status": "done",
                    "step": "complete",
                    "transcription": transcription,
                    "langue_source": "dioula",
                    "traduction": traduction,
                    "langue_cible": "français" if target_lang == "fr" else "dioula"
                }
                if audio_b64:
                    result["audio_base64"] = audio_b64
                await websocket.send_json(result)
                # Clean up the temp files for this job.
                # NOTE(review): if an exception fires mid-pipeline, these
                # files are leaked — cleanup only runs on the success path.
                Path(tmp_input.name).unlink(missing_ok=True)
                Path(audio_path).unlink(missing_ok=True)
            except Exception as e:
                await websocket.send_json({"error": f"Erreur pipeline: {str(e)}"})
    except WebSocketDisconnect:
        print("Client déconnecté du WebSocket")
# Local entry point: serve the app with uvicorn. Port 7860 matches the
# URLs used in the docs above (*.hf.space) — presumably the port the
# Hugging Face Space exposes.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)