import io import os import tempfile from pathlib import Path import numpy as np import scipy.io.wavfile import soundfile as sf import torch import torchaudio from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile, WebSocket, WebSocketDisconnect from fastapi.responses import FileResponse, HTMLResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware import json import base64 app = FastAPI( title="Wami - Dioula STT, TTS & Traduction API", description="API de reconnaissance vocale (STT), synthèse vocale (TTS) et traduction (Dioula ↔ Français)", version="1.1.0" ) # CORS pour permettre les appels depuis n'importe quel domaine app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Gestionnaires d'erreur globaux @app.exception_handler(HTTPException) async def http_exception_handler(request: Request, exc: HTTPException): return JSONResponse( status_code=exc.status_code, content={"error": exc.detail} ) @app.exception_handler(Exception) async def global_exception_handler(request: Request, exc: Exception): return JSONResponse( status_code=500, content={"error": f"Erreur serveur: {str(exc)}"} ) # Globals stt_processor = None stt_model = None tts_tokenizer = None tts_model = None translate_tokenizer = None translate_model = None device = "cpu" @app.on_event("startup") def load_models(): global stt_processor, stt_model, tts_tokenizer, tts_model, translate_tokenizer, translate_model, device device = "cuda" if torch.cuda.is_available() else "cpu" print(f"🚀 Device: {device}") # STT from transformers import AutoProcessor, Wav2Vec2ForCTC print("⏳ Chargement du modèle STT (Dioula)...") stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all", target_lang="dyu") stt_model = Wav2Vec2ForCTC.from_pretrained( "facebook/mms-1b-all", target_lang="dyu", ignore_mismatched_sizes=True ) stt_model.load_adapter("dyu") stt_model.to(device) print("✅ STT prêt!") # TTS from transformers import AutoTokenizer, VitsModel print("⏳ Chargement du modèle TTS (Dioula)...") tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-dyu") tts_model = VitsModel.from_pretrained("facebook/mms-tts-dyu").to(device) print("✅ TTS prêt!") # Translation (Dioula → Français) from transformers import AutoModelForSeq2SeqLM, NllbTokenizer print("⏳ Chargement du modèle de traduction (Dioula → Français)...") translate_tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M") translate_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M").to(device) print("✅ Traduction prête!") # Page d'accueil avec documentation @app.get("/", response_class=HTMLResponse) def home(): return """ Wami - API Dioula STT & TTS

🎙️ Wami - API Dioula STT, TTS & Traduction

API de reconnaissance vocale (STT), synthèse vocale (TTS) et traduction (Dioula ↔ Français).

📖 Endpoints

GET /

Cette page de documentation

GET /health

Statut de l'API et des modèles

POST /api/stt

Speech-to-Text - Transcrit un fichier audio en texte Dioula

Entrée: Fichier audio (WebM, WAV, MP3)

Sortie: {"transcription": "texte en dioula"}

curl -X POST https://votre-space.hf.space/api/stt \\
  -F "audio=@recording.wav"

POST /api/tts

Text-to-Speech - Génère un audio en Dioula depuis du texte

Entrée: Texte en Dioula (paramètre text)

Sortie: Fichier WAV

curl -X POST https://votre-space.hf.space/api/tts \\
  -F "text=na an be do minkɛ" \\
  -o output.wav

POST /api/translate/dyu-fr

Traduction Dioula → Français

Entrée: Texte en Dioula (paramètre text)

Sortie: JSON avec traduction française

curl -X POST https://votre-space.hf.space/api/translate/dyu-fr \\
  -F "text=Sanji bɛna kɛ bi"

POST /api/translate/fr-dyu

Traduction Français → Dioula

Entrée: Texte en Français (paramètre text)

Sortie: JSON avec traduction dioula

curl -X POST https://votre-space.hf.space/api/translate/fr-dyu \\
  -F "text=Il va pleuvoir aujourd'hui"

WS /ws/pipeline

Pipeline WebSocket - Audio → STT → Traduction (temps réel)

Entrée: JSON avec audio base64

Sortie: Progression en temps réel + résultats

const ws = new WebSocket('wss://votre-space.hf.space/ws/pipeline');
ws.send(JSON.stringify({
  action: "process",
  audio: "base64_audio",
  target_lang: "fr"
}));

🔗 Liens utiles

🎯 Demo Live | 🔄 WebSocket Demo | Swagger UI | ReDoc

ℹ️ Modèles

🔄 Flux de travail complet

Exemple : Audio Dioula → Texte Dioula → Traduction Français → Audio Français

  1. /api/stt : Convertit audio en texte dioula
  2. /api/translate/dyu-fr : Traduit en français
  3. (Optionnel) Utiliser un TTS français pour générer l'audio
""" @app.get("/demo", response_class=HTMLResponse) def demo_page(): """Page de démonstration interactive""" return Path("demo.html").read_text(encoding="utf-8") @app.get("/ws-demo", response_class=HTMLResponse) def websocket_demo_page(): """Page de démonstration WebSocket pipeline""" return Path("websocket_example.html").read_text(encoding="utf-8") @app.get("/health") def health_check(): """Vérifie le statut de l'API et des modèles""" return { "status": "healthy", "device": device, "models_loaded": { "stt": stt_model is not None, "tts": tts_model is not None, "translate": translate_model is not None } } @app.post("/api/stt") async def speech_to_text(audio: UploadFile = File(...)): """ Transcrit un fichier audio en texte Dioula - **audio**: Fichier audio (WebM, WAV, MP3, etc.) """ tmp_input = None tmp_wav = None try: audio_bytes = await audio.read() # Déterminer l'extension content_type = audio.content_type or "" if "webm" in content_type: suffix = ".webm" elif "wav" in content_type: suffix = ".wav" elif "mp3" in content_type: suffix = ".mp3" else: suffix = ".webm" # Sauvegarder temporairement tmp_input = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) tmp_input.write(audio_bytes) tmp_input.close() # Convertir en WAV si nécessaire if suffix != ".wav": try: audio_data, sample_rate = sf.read(tmp_input.name) tmp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) tmp_wav.close() sf.write(tmp_wav.name, audio_data, sample_rate) audio_path = tmp_wav.name except Exception as e: raise HTTPException( status_code=400, detail=f"Impossible de lire l'audio. Format non supporté. Erreur: {str(e)}" ) else: audio_path = tmp_input.name # Charger avec torchaudio audio_input, sample_rate = torchaudio.load(audio_path) # Mono if audio_input.shape[0] > 1: audio_input = torch.mean(audio_input, dim=0, keepdim=True) # Resample à 16 kHz if sample_rate != 16000: resampler = torchaudio.transforms.Resample(sample_rate, 16000) audio_input = resampler(audio_input) audio_input = audio_input.squeeze() # Inférence inputs = stt_processor(audio_input, sampling_rate=16000, return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): logits = stt_model(**inputs).logits predicted_ids = torch.argmax(logits, dim=-1) transcription = stt_processor.batch_decode(predicted_ids)[0] return {"transcription": transcription} except HTTPException: raise except Exception as e: print(f"Erreur STT: {e}") raise HTTPException(status_code=500, detail=f"Erreur lors de la transcription: {str(e)}") finally: if tmp_input and Path(tmp_input.name).exists(): Path(tmp_input.name).unlink(missing_ok=True) if tmp_wav and Path(tmp_wav.name).exists(): Path(tmp_wav.name).unlink(missing_ok=True) @app.post("/api/tts") async def text_to_speech(text: str = Form(...)): """ Génère un audio en Dioula depuis du texte - **text**: Texte en Dioula à synthétiser """ try: if not text.strip(): raise HTTPException(status_code=400, detail="Le texte ne peut pas être vide") inputs = tts_tokenizer(text, return_tensors="pt").to(device) with torch.no_grad(): waveform = tts_model(**inputs).waveform audio_data = waveform[0].cpu().numpy() sample_rate = tts_model.config.sampling_rate tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) scipy.io.wavfile.write(tmp.name, rate=sample_rate, data=audio_data) tmp.close() return FileResponse( tmp.name, media_type="audio/wav", filename="tts_dioula.wav" ) except HTTPException: raise except Exception as e: print(f"Erreur TTS: {e}") raise HTTPException(status_code=500, detail=f"Erreur lors de la génération audio: {str(e)}") @app.post("/api/translate/dyu-fr") async def translate_dioula_to_french(text: str = Form(...)): """ Traduit du texte du Dioula vers le Français - **text**: Texte en Dioula à traduire """ try: if not text.strip(): raise HTTPException(status_code=400, detail="Le texte ne peut pas être vide") # Définir la langue source : Dioula translate_tokenizer.src_lang = "dyu_Latn" # Préparation des tokens inputs = translate_tokenizer(text, return_tensors="pt").to(device) # Récupérer l'ID de la langue cible : Français target_lang_id = translate_tokenizer.convert_tokens_to_ids("fra_Latn") # Génération de la traduction with torch.no_grad(): translated_tokens = translate_model.generate( **inputs, forced_bos_token_id=target_lang_id, max_length=200 ) # Décodage traduction = translate_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0] return { "texte_source": text, "langue_source": "dioula", "texte_traduit": traduction, "langue_cible": "français" } except HTTPException: raise except Exception as e: print(f"Erreur Traduction DYU→FR: {e}") raise HTTPException(status_code=500, detail=f"Erreur lors de la traduction: {str(e)}") @app.post("/api/translate/fr-dyu") async def translate_french_to_dioula(text: str = Form(...)): """ Traduit du texte du Français vers le Dioula - **text**: Texte en Français à traduire """ try: if not text.strip(): raise HTTPException(status_code=400, detail="Le texte ne peut pas être vide") # Définir la langue source : Français translate_tokenizer.src_lang = "fra_Latn" # Préparation des tokens inputs = translate_tokenizer(text, return_tensors="pt").to(device) # Récupérer l'ID de la langue cible : Dioula target_lang_id = translate_tokenizer.convert_tokens_to_ids("dyu_Latn") # Génération de la traduction with torch.no_grad(): translated_tokens = translate_model.generate( **inputs, forced_bos_token_id=target_lang_id, max_length=200 ) # Décodage traduction = translate_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0] return { "texte_source": text, "langue_source": "français", "texte_traduit": traduction, "langue_cible": "dioula" } except HTTPException: raise except Exception as e: print(f"Erreur Traduction FR→DYU: {e}") raise HTTPException(status_code=500, detail=f"Erreur lors de la traduction: {str(e)}") @app.websocket("/ws/pipeline") async def websocket_pipeline(websocket: WebSocket): """ WebSocket pour pipeline complet : Audio → STT → Traduction → TTS Le client envoie : { "action": "process", "audio": "base64_encoded_audio", "target_lang": "fr" ou "dyu" (optionnel, défaut: "fr") } Le serveur répond : { "status": "processing", "step": "stt" | "translate" | "tts" | "done", "transcription": "...", "traduction": "...", "audio_base64": "..." (si TTS demandé) } """ await websocket.accept() try: while True: # Recevoir le message du client data = await websocket.receive_text() message = json.loads(data) if message.get("action") != "process": await websocket.send_json({"error": "Action invalide. Utilisez 'process'"}) continue audio_base64 = message.get("audio") target_lang = message.get("target_lang", "fr") include_tts = message.get("include_tts", False) if not audio_base64: await websocket.send_json({"error": "Aucun audio fourni"}) continue try: # Étape 1 : Décoder l'audio await websocket.send_json({"status": "processing", "step": "decoding", "message": "Décodage de l'audio..."}) audio_bytes = base64.b64decode(audio_base64) # Sauvegarder temporairement tmp_input = tempfile.NamedTemporaryFile(suffix=".webm", delete=False) tmp_input.write(audio_bytes) tmp_input.close() # Convertir en WAV try: audio_data, sample_rate = sf.read(tmp_input.name) tmp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) tmp_wav.close() sf.write(tmp_wav.name, audio_data, sample_rate) audio_path = tmp_wav.name except Exception as e: await websocket.send_json({"error": f"Erreur de conversion audio: {str(e)}"}) Path(tmp_input.name).unlink(missing_ok=True) continue # Étape 2 : STT await websocket.send_json({"status": "processing", "step": "stt", "message": "Transcription en cours..."}) audio_input, sample_rate = torchaudio.load(audio_path) if audio_input.shape[0] > 1: audio_input = torch.mean(audio_input, dim=0, keepdim=True) if sample_rate != 16000: resampler = torchaudio.transforms.Resample(sample_rate, 16000) audio_input = resampler(audio_input) audio_input = audio_input.squeeze() inputs = stt_processor(audio_input, sampling_rate=16000, return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): logits = stt_model(**inputs).logits predicted_ids = torch.argmax(logits, dim=-1) transcription = stt_processor.batch_decode(predicted_ids)[0] await websocket.send_json({ "status": "processing", "step": "stt_done", "transcription": transcription }) # Étape 3 : Traduction await websocket.send_json({"status": "processing", "step": "translate", "message": "Traduction en cours..."}) if target_lang == "fr": translate_tokenizer.src_lang = "dyu_Latn" target_lang_id = translate_tokenizer.convert_tokens_to_ids("fra_Latn") else: translate_tokenizer.src_lang = "fra_Latn" target_lang_id = translate_tokenizer.convert_tokens_to_ids("dyu_Latn") translate_inputs = translate_tokenizer(transcription, return_tensors="pt").to(device) with torch.no_grad(): translated_tokens = translate_model.generate( **translate_inputs, forced_bos_token_id=target_lang_id, max_length=200 ) traduction = translate_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0] await websocket.send_json({ "status": "processing", "step": "translate_done", "traduction": traduction }) # Étape 4 : TTS (optionnel) audio_b64 = None if include_tts and target_lang == "dyu": await websocket.send_json({"status": "processing", "step": "tts", "message": "Génération audio..."}) tts_inputs = tts_tokenizer(traduction, return_tensors="pt").to(device) with torch.no_grad(): waveform = tts_model(**tts_inputs).waveform audio_data_tts = waveform[0].cpu().numpy() sample_rate_tts = tts_model.config.sampling_rate tmp_tts = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) scipy.io.wavfile.write(tmp_tts.name, rate=sample_rate_tts, data=audio_data_tts) with open(tmp_tts.name, "rb") as f: audio_b64 = base64.b64encode(f.read()).decode('utf-8') Path(tmp_tts.name).unlink(missing_ok=True) # Résultat final result = { "status": "done", "step": "complete", "transcription": transcription, "langue_source": "dioula", "traduction": traduction, "langue_cible": "français" if target_lang == "fr" else "dioula" } if audio_b64: result["audio_base64"] = audio_b64 await websocket.send_json(result) # Nettoyer Path(tmp_input.name).unlink(missing_ok=True) Path(audio_path).unlink(missing_ok=True) except Exception as e: await websocket.send_json({"error": f"Erreur pipeline: {str(e)}"}) except WebSocketDisconnect: print("Client déconnecté du WebSocket") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)