Spaces:
Sleeping
Sleeping
import base64
import io
import json
import os
import tempfile
from pathlib import Path

import numpy as np
import scipy.io.wavfile
import soundfile as sf
import torch
import torchaudio
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, Response
# FastAPI application exposing Dioula STT, TTS and translation endpoints.
app = FastAPI(
    title="Wami - Dioula STT, TTS & Traduction API",
    description="API de reconnaissance vocale (STT), synthèse vocale (TTS) et traduction (Dioula ↔ Français)",
    version="1.1.0"
)

# CORS: allow calls from any origin (this is a public, credential-free API).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global error handlers: normalize every error into a {"error": ...} JSON body.
# NOTE: the handler was defined but never registered on the app — without the
# decorator FastAPI keeps its default {"detail": ...} envelope.
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Convert raised HTTPExceptions into a uniform JSON error envelope."""
    return JSONResponse(
        status_code=exc.status_code,
        content={"error": exc.detail}
    )
# Registered as the catch-all handler so uncaught exceptions surface as JSON
# (was defined but never attached to the app).
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: report any uncaught error as a 500 JSON response."""
    return JSONResponse(
        status_code=500,
        content={"error": f"Erreur serveur: {str(exc)}"}
    )
# Globals: model/processor handles, populated by load_models() at startup.
stt_processor = None        # Wav2Vec2 processor for Dioula STT
stt_model = None            # Wav2Vec2ForCTC (facebook/mms-1b-all, "dyu" adapter)
tts_tokenizer = None        # tokenizer for the MMS Dioula TTS model
tts_model = None            # VITS model (facebook/mms-tts-dyu)
translate_tokenizer = None  # NLLB tokenizer (Dioula <-> French)
translate_model = None      # NLLB-200 seq2seq translation model
device = "cpu"              # switched to "cuda" in load_models() when available
# Registered as a startup hook: load_models() was never invoked anywhere,
# so every endpoint would have hit a None model without this.
@app.on_event("startup")
def load_models():
    """Load the STT, TTS and translation models once at application startup.

    Populates the module-level globals above and moves each model to the GPU
    when CUDA is available. Transformers imports are kept local so importing
    this module stays cheap.
    """
    global stt_processor, stt_model, tts_tokenizer, tts_model, translate_tokenizer, translate_model, device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"🚀 Device: {device}")

    # STT: multilingual MMS wav2vec2 checkpoint with the Dioula ("dyu") adapter.
    from transformers import AutoProcessor, Wav2Vec2ForCTC
    print("⏳ Chargement du modèle STT (Dioula)...")
    stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all", target_lang="dyu")
    stt_model = Wav2Vec2ForCTC.from_pretrained(
        "facebook/mms-1b-all",
        target_lang="dyu",
        # The adapter head is resized for "dyu", so size mismatches are expected.
        ignore_mismatched_sizes=True
    )
    stt_model.load_adapter("dyu")
    stt_model.to(device)
    print("✅ STT prêt!")

    # TTS: MMS VITS voice for Dioula.
    from transformers import AutoTokenizer, VitsModel
    print("⏳ Chargement du modèle TTS (Dioula)...")
    tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-dyu")
    tts_model = VitsModel.from_pretrained("facebook/mms-tts-dyu").to(device)
    print("✅ TTS prêt!")

    # Translation: NLLB-200 covers both dyu_Latn and fra_Latn.
    from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
    print("⏳ Chargement du modèle de traduction (Dioula → Français)...")
    translate_tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
    translate_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M").to(device)
    print("✅ Traduction prête!")
# Landing page with human-readable API documentation.
# Route restored: the page itself documents "GET /" as this endpoint.
@app.get("/", response_class=HTMLResponse)
def home():
    """Serve the static HTML documentation page for the API."""
    return """
<!DOCTYPE html>
<html lang="fr">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Wami - API Dioula STT & TTS</title>
    <style>
        body { font-family: system-ui; max-width: 800px; margin: 40px auto; padding: 20px; line-height: 1.6; }
        h1 { color: #2563eb; }
        h2 { color: #1e40af; margin-top: 30px; }
        code { background: #f1f5f9; padding: 2px 6px; border-radius: 4px; }
        pre { background: #0f172a; color: #e2e8f0; padding: 16px; border-radius: 8px; overflow-x: auto; }
        .endpoint { background: #f8fafc; padding: 16px; border-left: 4px solid #3b82f6; margin: 16px 0; }
        .method { display: inline-block; padding: 4px 8px; border-radius: 4px; font-weight: bold; margin-right: 8px; }
        .post { background: #10b981; color: white; }
        .get { background: #3b82f6; color: white; }
    </style>
</head>
<body>
    <h1>🎙️ Wami - API Dioula STT, TTS & Traduction</h1>
    <p>API de reconnaissance vocale (STT), synthèse vocale (TTS) et traduction (Dioula ↔ Français).</p>
    <h2>📖 Endpoints</h2>
    <div class="endpoint">
        <p><span class="method get">GET</span> <code>/</code></p>
        <p>Cette page de documentation</p>
    </div>
    <div class="endpoint">
        <p><span class="method get">GET</span> <code>/health</code></p>
        <p>Statut de l'API et des modèles</p>
    </div>
    <div class="endpoint">
        <p><span class="method post">POST</span> <code>/api/stt</code></p>
        <p><strong>Speech-to-Text</strong> - Transcrit un fichier audio en texte Dioula</p>
        <p><strong>Entrée:</strong> Fichier audio (WebM, WAV, MP3)</p>
        <p><strong>Sortie:</strong> <code>{"transcription": "texte en dioula"}</code></p>
        <pre>curl -X POST https://votre-space.hf.space/api/stt \\
  -F "audio=@recording.wav"</pre>
    </div>
    <div class="endpoint">
        <p><span class="method post">POST</span> <code>/api/tts</code></p>
        <p><strong>Text-to-Speech</strong> - Génère un audio en Dioula depuis du texte</p>
        <p><strong>Entrée:</strong> Texte en Dioula (paramètre <code>text</code>)</p>
        <p><strong>Sortie:</strong> Fichier WAV</p>
        <pre>curl -X POST https://votre-space.hf.space/api/tts \\
  -F "text=na an be do minkɛ" \\
  -o output.wav</pre>
    </div>
    <div class="endpoint">
        <p><span class="method post">POST</span> <code>/api/translate/dyu-fr</code></p>
        <p><strong>Traduction Dioula → Français</strong></p>
        <p><strong>Entrée:</strong> Texte en Dioula (paramètre <code>text</code>)</p>
        <p><strong>Sortie:</strong> JSON avec traduction française</p>
        <pre>curl -X POST https://votre-space.hf.space/api/translate/dyu-fr \\
  -F "text=Sanji bɛna kɛ bi"</pre>
    </div>
    <div class="endpoint">
        <p><span class="method post">POST</span> <code>/api/translate/fr-dyu</code></p>
        <p><strong>Traduction Français → Dioula</strong></p>
        <p><strong>Entrée:</strong> Texte en Français (paramètre <code>text</code>)</p>
        <p><strong>Sortie:</strong> JSON avec traduction dioula</p>
        <pre>curl -X POST https://votre-space.hf.space/api/translate/fr-dyu \\
  -F "text=Il va pleuvoir aujourd'hui"</pre>
    </div>
    <div class="endpoint">
        <p><span class="method get" style="background: #f59e0b;">WS</span> <code>/ws/pipeline</code></p>
        <p><strong>Pipeline WebSocket</strong> - Audio → STT → Traduction (temps réel)</p>
        <p><strong>Entrée:</strong> JSON avec audio base64</p>
        <p><strong>Sortie:</strong> Progression en temps réel + résultats</p>
        <pre>const ws = new WebSocket('wss://votre-space.hf.space/ws/pipeline');
ws.send(JSON.stringify({
    action: "process",
    audio: "base64_audio",
    target_lang: "fr"
}));</pre>
    </div>
    <h2>🔗 Liens utiles</h2>
    <p>
        <a href="/demo" style="color: #3b82f6; font-weight: bold;">🎯 Demo Live</a> |
        <a href="/ws-demo" style="color: #f59e0b; font-weight: bold;">🔄 WebSocket Demo</a> |
        <a href="/docs">Swagger UI</a> |
        <a href="/redoc">ReDoc</a>
    </p>
    <h2>ℹ️ Modèles</h2>
    <ul>
        <li><strong>STT:</strong> facebook/mms-1b-all (adapter Dioula)</li>
        <li><strong>TTS:</strong> facebook/mms-tts-dyu</li>
        <li><strong>Traduction:</strong> facebook/nllb-200-distilled-600M (Dioula ↔ Français)</li>
    </ul>
    <h2>🔄 Flux de travail complet</h2>
    <p><strong>Exemple :</strong> Audio Dioula → Texte Dioula → Traduction Français → Audio Français</p>
    <ol>
        <li><code>/api/stt</code> : Convertit audio en texte dioula</li>
        <li><code>/api/translate/dyu-fr</code> : Traduit en français</li>
        <li>(Optionnel) Utiliser un TTS français pour générer l'audio</li>
    </ol>
</body>
</html>
"""
# Route restored per the landing page's "/demo" link.
@app.get("/demo", response_class=HTMLResponse)
def demo_page():
    """Serve the interactive demo page (demo.html shipped alongside this file)."""
    return Path("demo.html").read_text(encoding="utf-8")
# Route restored per the landing page's "/ws-demo" link.
@app.get("/ws-demo", response_class=HTMLResponse)
def websocket_demo_page():
    """Serve the WebSocket pipeline demo page (websocket_example.html)."""
    return Path("websocket_example.html").read_text(encoding="utf-8")
# Route restored: the landing page documents "GET /health".
@app.get("/health")
def health_check():
    """Report API status, the inference device, and which models are loaded.

    Returns a JSON dict; each "models_loaded" flag is True once load_models()
    has populated the corresponding global.
    """
    return {
        "status": "healthy",
        "device": device,
        "models_loaded": {
            "stt": stt_model is not None,
            "tts": tts_model is not None,
            "translate": translate_model is not None
        }
    }
# Route restored: the landing page documents "POST /api/stt".
@app.post("/api/stt")
async def speech_to_text(audio: UploadFile = File(...)):
    """
    Transcribe an audio file into Dioula text.

    - **audio**: audio file (WebM, WAV, MP3, etc.)

    Returns ``{"transcription": ...}``. Raises 400 when the audio cannot be
    decoded, 500 on any other failure. Temporary files are always removed.
    """
    tmp_input = None
    tmp_wav = None
    try:
        audio_bytes = await audio.read()

        # Pick a file extension from the declared MIME type; default to .webm
        # since browser MediaRecorder uploads are usually WebM.
        content_type = audio.content_type or ""
        if "webm" in content_type:
            suffix = ".webm"
        elif "wav" in content_type:
            suffix = ".wav"
        elif "mp3" in content_type or "mpeg" in content_type:
            # Fix: MP3 uploads are sent as "audio/mpeg", which the previous
            # "mp3"-only check silently missed.
            suffix = ".mp3"
        else:
            suffix = ".webm"

        # Persist the upload so the decoding libraries can read from a path.
        tmp_input = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
        tmp_input.write(audio_bytes)
        tmp_input.close()

        # Convert non-WAV input to WAV (torchaudio reads WAV reliably).
        if suffix != ".wav":
            try:
                audio_data, sample_rate = sf.read(tmp_input.name)
                tmp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
                tmp_wav.close()
                sf.write(tmp_wav.name, audio_data, sample_rate)
                audio_path = tmp_wav.name
            except Exception as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Impossible de lire l'audio. Format non supporté. Erreur: {str(e)}"
                )
        else:
            audio_path = tmp_input.name

        # Downmix to mono and resample to the 16 kHz the MMS model expects.
        audio_input, sample_rate = torchaudio.load(audio_path)
        if audio_input.shape[0] > 1:
            audio_input = torch.mean(audio_input, dim=0, keepdim=True)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            audio_input = resampler(audio_input)
        audio_input = audio_input.squeeze()

        # CTC inference with greedy (argmax) decoding.
        inputs = stt_processor(audio_input, sampling_rate=16000, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = stt_model(**inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = stt_processor.batch_decode(predicted_ids)[0]
        return {"transcription": transcription}
    except HTTPException:
        raise
    except Exception as e:
        print(f"Erreur STT: {e}")
        raise HTTPException(status_code=500, detail=f"Erreur lors de la transcription: {str(e)}")
    finally:
        # Clean up temp files on every path (success, 400, 500).
        if tmp_input and Path(tmp_input.name).exists():
            Path(tmp_input.name).unlink(missing_ok=True)
        if tmp_wav and Path(tmp_wav.name).exists():
            Path(tmp_wav.name).unlink(missing_ok=True)
# Route restored: the landing page documents "POST /api/tts".
@app.post("/api/tts")
async def text_to_speech(text: str = Form(...)):
    """
    Synthesize Dioula speech from text.

    - **text**: Dioula text to synthesize

    Returns a WAV file as the response body. Raises 400 on empty text,
    500 on synthesis failure.
    """
    try:
        if not text.strip():
            raise HTTPException(status_code=400, detail="Le texte ne peut pas être vide")

        inputs = tts_tokenizer(text, return_tensors="pt").to(device)
        with torch.no_grad():
            waveform = tts_model(**inputs).waveform
        audio_data = waveform[0].cpu().numpy()
        sample_rate = tts_model.config.sampling_rate

        # Fix: write the WAV into an in-memory buffer instead of a temp file.
        # The previous NamedTemporaryFile handed to FileResponse was never
        # deleted, leaking one file per request.
        buffer = io.BytesIO()
        scipy.io.wavfile.write(buffer, rate=sample_rate, data=audio_data)
        return Response(
            content=buffer.getvalue(),
            media_type="audio/wav",
            headers={"Content-Disposition": 'attachment; filename="tts_dioula.wav"'}
        )
    except HTTPException:
        raise
    except Exception as e:
        print(f"Erreur TTS: {e}")
        raise HTTPException(status_code=500, detail=f"Erreur lors de la génération audio: {str(e)}")
# Route restored: the landing page documents "POST /api/translate/dyu-fr".
@app.post("/api/translate/dyu-fr")
async def translate_dioula_to_french(text: str = Form(...)):
    """
    Translate text from Dioula to French.

    - **text**: Dioula text to translate

    Returns a JSON dict with source/target text and language labels.
    Raises 400 on empty text, 500 on translation failure.
    """
    try:
        if not text.strip():
            raise HTTPException(status_code=400, detail="Le texte ne peut pas être vide")

        # Source language: Dioula (NLLB code dyu_Latn).
        translate_tokenizer.src_lang = "dyu_Latn"
        inputs = translate_tokenizer(text, return_tensors="pt").to(device)

        # Target language: French — NLLB steers generation via forced BOS token.
        target_lang_id = translate_tokenizer.convert_tokens_to_ids("fra_Latn")
        with torch.no_grad():
            translated_tokens = translate_model.generate(
                **inputs,
                forced_bos_token_id=target_lang_id,
                max_length=200
            )

        traduction = translate_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
        return {
            "texte_source": text,
            "langue_source": "dioula",
            "texte_traduit": traduction,
            "langue_cible": "français"
        }
    except HTTPException:
        raise
    except Exception as e:
        print(f"Erreur Traduction DYU→FR: {e}")
        raise HTTPException(status_code=500, detail=f"Erreur lors de la traduction: {str(e)}")
# Route restored: the landing page documents "POST /api/translate/fr-dyu".
@app.post("/api/translate/fr-dyu")
async def translate_french_to_dioula(text: str = Form(...)):
    """
    Translate text from French to Dioula.

    - **text**: French text to translate

    Returns a JSON dict with source/target text and language labels.
    Raises 400 on empty text, 500 on translation failure.
    """
    try:
        if not text.strip():
            raise HTTPException(status_code=400, detail="Le texte ne peut pas être vide")

        # Source language: French (NLLB code fra_Latn).
        translate_tokenizer.src_lang = "fra_Latn"
        inputs = translate_tokenizer(text, return_tensors="pt").to(device)

        # Target language: Dioula — NLLB steers generation via forced BOS token.
        target_lang_id = translate_tokenizer.convert_tokens_to_ids("dyu_Latn")
        with torch.no_grad():
            translated_tokens = translate_model.generate(
                **inputs,
                forced_bos_token_id=target_lang_id,
                max_length=200
            )

        traduction = translate_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
        return {
            "texte_source": text,
            "langue_source": "français",
            "texte_traduit": traduction,
            "langue_cible": "dioula"
        }
    except HTTPException:
        raise
    except Exception as e:
        print(f"Erreur Traduction FR→DYU: {e}")
        raise HTTPException(status_code=500, detail=f"Erreur lors de la traduction: {str(e)}")
# Route restored: the landing page documents "WS /ws/pipeline".
@app.websocket("/ws/pipeline")
async def websocket_pipeline(websocket: WebSocket):
    """
    WebSocket for the full pipeline: Audio → STT → Translation → TTS.

    The client sends:
    {
        "action": "process",
        "audio": "base64_encoded_audio",
        "target_lang": "fr" or "dyu" (optional, default: "fr"),
        "include_tts": bool (optional, default: false)
    }

    The server streams back progress messages:
    {
        "status": "processing",
        "step": "stt" | "translate" | "tts" | "done",
        "transcription": "...",
        "traduction": "...",
        "audio_base64": "..." (when TTS was requested)
    }
    """
    await websocket.accept()
    try:
        while True:
            # Wait for the next client request.
            data = await websocket.receive_text()
            message = json.loads(data)

            if message.get("action") != "process":
                await websocket.send_json({"error": "Action invalide. Utilisez 'process'"})
                continue

            audio_base64 = message.get("audio")
            target_lang = message.get("target_lang", "fr")
            include_tts = message.get("include_tts", False)
            if not audio_base64:
                await websocket.send_json({"error": "Aucun audio fourni"})
                continue

            tmp_input = None
            tmp_wav = None
            try:
                # Step 1: decode the base64 audio payload.
                await websocket.send_json({"status": "processing", "step": "decoding", "message": "Décodage de l'audio..."})
                audio_bytes = base64.b64decode(audio_base64)

                # Persist to disk so the audio libraries can read from a path.
                tmp_input = tempfile.NamedTemporaryFile(suffix=".webm", delete=False)
                tmp_input.write(audio_bytes)
                tmp_input.close()

                # Convert to WAV for torchaudio.
                try:
                    audio_data, sample_rate = sf.read(tmp_input.name)
                    tmp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
                    tmp_wav.close()
                    sf.write(tmp_wav.name, audio_data, sample_rate)
                    audio_path = tmp_wav.name
                except Exception as e:
                    await websocket.send_json({"error": f"Erreur de conversion audio: {str(e)}"})
                    continue  # temp files are removed by the finally block

                # Step 2: STT — mono, 16 kHz, greedy CTC decode.
                await websocket.send_json({"status": "processing", "step": "stt", "message": "Transcription en cours..."})
                audio_input, sample_rate = torchaudio.load(audio_path)
                if audio_input.shape[0] > 1:
                    audio_input = torch.mean(audio_input, dim=0, keepdim=True)
                if sample_rate != 16000:
                    resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                    audio_input = resampler(audio_input)
                audio_input = audio_input.squeeze()
                inputs = stt_processor(audio_input, sampling_rate=16000, return_tensors="pt")
                inputs = {k: v.to(device) for k, v in inputs.items()}
                with torch.no_grad():
                    logits = stt_model(**inputs).logits
                predicted_ids = torch.argmax(logits, dim=-1)
                transcription = stt_processor.batch_decode(predicted_ids)[0]
                await websocket.send_json({
                    "status": "processing",
                    "step": "stt_done",
                    "transcription": transcription
                })

                # Step 3: translation via NLLB (forced BOS selects the target).
                await websocket.send_json({"status": "processing", "step": "translate", "message": "Traduction en cours..."})
                if target_lang == "fr":
                    translate_tokenizer.src_lang = "dyu_Latn"
                    target_lang_id = translate_tokenizer.convert_tokens_to_ids("fra_Latn")
                else:
                    translate_tokenizer.src_lang = "fra_Latn"
                    target_lang_id = translate_tokenizer.convert_tokens_to_ids("dyu_Latn")
                translate_inputs = translate_tokenizer(transcription, return_tensors="pt").to(device)
                with torch.no_grad():
                    translated_tokens = translate_model.generate(
                        **translate_inputs,
                        forced_bos_token_id=target_lang_id,
                        max_length=200
                    )
                traduction = translate_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
                await websocket.send_json({
                    "status": "processing",
                    "step": "translate_done",
                    "traduction": traduction
                })

                # Step 4: optional TTS (only meaningful for a Dioula target).
                audio_b64 = None
                if include_tts and target_lang == "dyu":
                    await websocket.send_json({"status": "processing", "step": "tts", "message": "Génération audio..."})
                    tts_inputs = tts_tokenizer(traduction, return_tensors="pt").to(device)
                    with torch.no_grad():
                        waveform = tts_model(**tts_inputs).waveform
                    audio_data_tts = waveform[0].cpu().numpy()
                    sample_rate_tts = tts_model.config.sampling_rate
                    tmp_tts = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
                    scipy.io.wavfile.write(tmp_tts.name, rate=sample_rate_tts, data=audio_data_tts)
                    with open(tmp_tts.name, "rb") as f:
                        audio_b64 = base64.b64encode(f.read()).decode('utf-8')
                    Path(tmp_tts.name).unlink(missing_ok=True)

                # Final result message.
                result = {
                    "status": "done",
                    "step": "complete",
                    "transcription": transcription,
                    "langue_source": "dioula",
                    "traduction": traduction,
                    "langue_cible": "français" if target_lang == "fr" else "dioula"
                }
                if audio_b64:
                    result["audio_base64"] = audio_b64
                await websocket.send_json(result)
            except Exception as e:
                await websocket.send_json({"error": f"Erreur pipeline: {str(e)}"})
            finally:
                # Fix: cleanup previously ran only on the success path, leaking
                # temp files whenever any pipeline step raised.
                if tmp_input:
                    Path(tmp_input.name).unlink(missing_ok=True)
                if tmp_wav:
                    Path(tmp_wav.name).unlink(missing_ok=True)
    except WebSocketDisconnect:
        print("Client déconnecté du WebSocket")
# Local entry point: bind all interfaces on port 7860 (Hugging Face Spaces default).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)