Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCTC, AutoProcessor, VitsModel, AutoTokenizer | |
| import librosa | |
| import numpy as np | |
| import io | |
| import soundfile as sf | |
| import logging | |
# Logging: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Audio settings.
SAMPLE_RATE = 16000  # Hz
MAX_AUDIO_LENGTH = 30  # multiplied by SAMPLE_RATE to get the sample limit

# (display name, language code) of every TTS language, in display order.
_TTS_LANGUAGES = (
    ("Biali", "beh"),
    ("Baatombu", "bba"),
    ("Dendi", "ddn"),
    ("Éwé", "ewe"),
    ("Mina", "gej"),
    ("Ditammari", "tbz"),
    ("Yoruba", "yor"),
    ("Fon", "fon"),
    ("English", "eng"),
)

# TTS language mapping: UI label "Name (code)" -> MMS-TTS checkpoint id.
LANGUAGE_MAPPING = {
    f"{name} ({code})": f"facebook/mms-tts-{code}"
    for name, code in _TTS_LANGUAGES
}

# Model cache: filled lazily by the load_* helpers.
models_cache = {}
def get_device():
    """Return the torch device name to use: "cuda" if available, else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def load_asr_model():
    """Return the ``(model, processor)`` pair for ASR, loading it on first use.

    The first call loads the "facebook/mms-1b-all" processor and CTC model,
    moves the model to the device reported by ``get_device()``, puts it in
    eval mode and stores both under ``models_cache["asr"]``. Later calls
    return the cached objects.
    """
    if "asr" in models_cache:
        cached = models_cache["asr"]
        return cached["model"], cached["processor"]

    checkpoint = "facebook/mms-1b-all"
    target = get_device()
    logger.info("⏳ Chargement du modèle ASR...")
    asr_processor = AutoProcessor.from_pretrained(checkpoint)
    asr_model = AutoModelForCTC.from_pretrained(checkpoint).to(target)
    asr_model.eval()
    models_cache["asr"] = {"model": asr_model, "processor": asr_processor}
    logger.info("✅ Modèle ASR chargé")
    return asr_model, asr_processor
def load_tts_model(language_name):
    """Return the ``(model, tokenizer)`` pair for a TTS language.

    Args:
        language_name: A key of ``LANGUAGE_MAPPING`` (UI label of the
            language).

    Returns:
        The VITS model (in eval mode, on the device from ``get_device()``)
        and its tokenizer. Both are cached in ``models_cache`` under
        ``language_name`` after the first load.

    Raises:
        ValueError: If ``language_name`` has no entry in ``LANGUAGE_MAPPING``.
    """
    if language_name in models_cache:
        hit = models_cache[language_name]
        return hit["model"], hit["tokenizer"]

    checkpoint = LANGUAGE_MAPPING.get(language_name)
    if not checkpoint:
        raise ValueError(f"Langue non supportée: {language_name}")

    target = get_device()
    logger.info(f"⏳ Chargement du modèle TTS {language_name}...")
    tts_model = VitsModel.from_pretrained(checkpoint).to(target)
    tts_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    tts_model.eval()
    models_cache[language_name] = {"model": tts_model, "tokenizer": tts_tokenizer}
    logger.info(f"✅ Modèle TTS {language_name} chargé")
    return tts_model, tts_tokenizer
def process_audio(audio_data):
    """Convert raw audio into a mono float32 waveform at ``SAMPLE_RATE``.

    Args:
        audio_data: Either a ``(sample_rate, samples)`` tuple (the format
            Gradio returns) or a bare sequence of samples, which is assumed
            to already be at ``SAMPLE_RATE``.

    Returns:
        A 1-D float32 NumPy array at ``SAMPLE_RATE``, peak-normalised so the
        largest absolute sample is 1 (unless the signal is all zeros), and at
        most ``MAX_AUDIO_LENGTH`` seconds long.

    Raises:
        ValueError: If the input contains no samples.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        if isinstance(audio_data, tuple):
            # Gradio returns (sample_rate, audio_array).
            sr, audio = audio_data
        else:
            sr, audio = SAMPLE_RATE, audio_data

        # Convert to float32 (integer PCM is rescaled by the normalisation).
        audio = np.array(audio, dtype=np.float32)
        if audio.size == 0:
            # np.max() on an empty array fails with an opaque NumPy message.
            raise ValueError("Audio vide: aucun échantillon à traiter")

        # Down-mix to mono by averaging over the second axis.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)

        # Truncate at the source rate first: this avoids resampling audio that
        # would be discarded anyway, and keeps a loud sample in the dropped
        # tail from under-scaling the part that is kept.
        audio = audio[: int(MAX_AUDIO_LENGTH * sr)]

        if sr != SAMPLE_RATE:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)
            # The resampled length is rounded, so enforce the limit again.
            audio = audio[: MAX_AUDIO_LENGTH * SAMPLE_RATE]

        # Peak-normalise; an all-zero signal is returned as is.
        peak = np.max(np.abs(audio))
        if peak > 0:
            audio = audio / peak
        return audio
    except Exception as e:
        logger.exception("Erreur traitement audio: %s", e)
        raise
def transcribe_audio(audio, language_label):
    """Transcribe audio to text with the MMS ASR model.

    Args:
        audio: Audio value from the UI, forwarded to ``process_audio()``;
            ``None`` when nothing was recorded or uploaded.
        language_label: Label of the form "Name (code)". The code between
            parentheses selects the MMS target language.

    Returns:
        A user-facing string: the transcription prefixed with "✅", or a
        message prefixed with "❌" when no audio was given or an error
        occurred. Exceptions are caught, logged and reported in that string.
    """
    if audio is None:
        return "❌ Veuillez enregistrer ou uploader un fichier audio"
    try:
        # Extract the language code from the "Name (code)" label.
        language = language_label.split("(")[-1].rstrip(")").strip()
        audio_processed = process_audio(audio)
        model, processor = load_asr_model()

        # MMS changes language by switching the tokenizer vocabulary and
        # loading the matching adapter weights into the model. Assigning an
        # arbitrary attribute on the processor does neither, which left every
        # request decoded with the checkpoint's default language.
        processor.tokenizer.set_target_lang(language)
        model.load_adapter(language)

        device = get_device()
        with torch.no_grad():
            inputs = processor(
                audio_processed, sampling_rate=SAMPLE_RATE, return_tensors="pt"
            ).to(device)
            outputs = model(**inputs)

        # Greedy CTC decoding of the first (only) item in the batch.
        ids = torch.argmax(outputs.logits, dim=-1)[0]
        transcription = processor.decode(ids)
        return f"✅ Transcription:\n{transcription}"
    except Exception as e:
        logger.exception("Erreur ASR: %s", e)
        return f"❌ Erreur: {str(e)}"
def synthesize_speech(text, language):
    """Synthesise speech from text with the TTS model of the given language.

    Args:
        text: Text to speak; only the first 1000 characters (after stripping
            surrounding whitespace) are used.
        language: Language label passed to ``load_tts_model()``.

    Returns:
        A ``(audio, message)`` pair. On success ``audio`` is
        ``(sample_rate, int16_samples)`` and ``message`` starts with "✅".
        On empty text or any error ``audio`` is ``None`` and ``message``
        starts with "❌".
    """
    if not (text and text.strip()):
        return None, "❌ Veuillez entrer du texte"

    try:
        snippet = text.strip()[:1000]  # cap at 1000 characters
        tts_model, tts_tokenizer = load_tts_model(language)
        target = get_device()

        with torch.no_grad():
            encoded = tts_tokenizer(snippet, return_tensors="pt").to(target)
            result = tts_model(**encoded)

        samples = result.waveform.cpu().numpy().flatten()
        # Use the sampling rate declared by the model's own configuration.
        rate = tts_model.config.sampling_rate

        # Scale the peak to 0.95 so the int16 conversion cannot clip.
        peak = np.abs(samples).max()
        if peak > 0:
            samples = (samples / peak) * 0.95
        pcm = (samples * 32767).astype(np.int16)

        # (sample_rate, audio_array) is the format handed back to Gradio.
        status = f"✅ Audio généré ({len(pcm)} samples @ {rate}Hz)!"
        return (rate, pcm), status
    except Exception as err:
        logger.error(f"Erreur TTS: {err}")
        return None, f"❌ Erreur: {str(err)}"
# ============= GRADIO INTERFACE =============
# Three tabs (ASR, TTS, About) plus a footer. `demo` is the Blocks app that
# gets launched at the bottom of the file.
with gr.Blocks(title="🎙️ MMS ASR/TTS - Speech AI", theme=gr.themes.Soft()) as demo:
    gr.HTML("""
    <div style="text-align: center;">
        <h1>🎙️ Meta MMS Speech AI</h1>
        <p style="font-size: 16px; color: #666;">
            Reconnaissance vocale (ASR) + Synthèse vocale (TTS) multilingue
        </p>
        <p style="font-size: 14px; color: #999;">
            Utilise les modèles <strong>facebook/mms-1b-all</strong> et <strong>facebook/mms-tts</strong>
        </p>
    </div>
    """)

    with gr.Tabs():
        # ============= TAB 1: ASR =============
        with gr.TabItem("🔊 ASR (Audio → Texte)", id="asr"):
            gr.HTML("<h2>Reconnaissance Vocale Multilingue</h2>")
            gr.HTML("<p>Enregistre ou uploader un fichier audio pour obtenir la transcription.</p>")
            with gr.Row():
                with gr.Column():
                    audio_input = gr.Audio(
                        label="📁 Fichier audio",
                        type="numpy",
                        sources=["upload", "microphone"],
                    )
                    # Same labels as the TTS tab: take them from the single
                    # mapping instead of keeping a second hand-written list.
                    language_asr = gr.Dropdown(
                        choices=list(LANGUAGE_MAPPING.keys()),
                        value="English (eng)",
                        label="🌐 Langue",
                    )
                    btn_asr = gr.Button("🎯 Transcrire", variant="primary", size="lg")
                with gr.Column():
                    output_asr = gr.Textbox(
                        label="📝 Transcription",
                        lines=6,
                        interactive=False,
                    )
            btn_asr.click(
                fn=transcribe_audio,
                inputs=[audio_input, language_asr],
                outputs=output_asr,
            )

        # ============= TAB 2: TTS =============
        with gr.TabItem("📢 TTS (Texte → Audio)", id="tts"):
            gr.HTML("<h2>Synthèse Vocale</h2>")
            gr.HTML("<p>Entre du texte et écoute la synthèse vocale dans la langue choisie.</p>")
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(
                        label="✍️ Texte à convertir",
                        placeholder="Écris du texte ici...",
                        lines=4,
                    )
                    language_tts = gr.Dropdown(
                        choices=list(LANGUAGE_MAPPING.keys()),
                        value="English (eng)",
                        label="🌐 Langue",
                    )
                    btn_tts = gr.Button("🔊 Générer l'audio", variant="primary", size="lg")
                    info_tts = gr.Textbox(
                        label="📊 Info",
                        interactive=False,
                        value="Clique sur 'Générer l'audio' pour commencer",
                    )
                with gr.Column():
                    audio_output = gr.Audio(
                        label="🎵 Audio généré",
                        type="numpy",
                    )
            btn_tts.click(
                fn=synthesize_speech,
                inputs=[text_input, language_tts],
                outputs=[audio_output, info_tts],
            )
            # Examples (optional - commented out to avoid errors).
            # Uncomment to enable once tested.
            # gr.Examples(
            #     examples=[
            #         ["Hello world", "English (eng)"],
            #         ["Àbọ̀ wa", "Yoruba (yor)"],
            #         ["Bonjour", "English (eng)"],
            #     ],
            #     fn=synthesize_speech,
            #     inputs=[text_input, language_tts],
            #     outputs=[audio_output, info_tts],
            #     label="💡 Exemples"
            # )

        # ============= TAB 3: ABOUT =============
        with gr.TabItem("ℹ️ À propos", id="about"):
            gr.HTML("""
            <h2>À propos de cette API</h2>
            <h3>🎙️ ASR (Automatic Speech Recognition)</h3>
            <ul>
                <li><strong>Modèle:</strong> facebook/mms-1b-all (964M params)</li>
                <li><strong>Langues:</strong> 100+ langues (ISO 639-3)</li>
                <li><strong>Architecture:</strong> wav2vec2</li>
                <li><strong>Taux d'échantillonnage:</strong> 16 kHz</li>
                <li><strong>Limite:</strong> 30 secondes d'audio</li>
            </ul>
            <h3>📢 TTS (Text-to-Speech)</h3>
            <ul>
                <li><strong>Modèle:</strong> facebook/mms-tts-* (VITS)</li>
                <li><strong>Langues supportées:</strong> 9 langues</li>
                <li><strong>Taux d'échantillonnage:</strong> 16000 Hz</li>
                <li><strong>Limite:</strong> 1000 caractères</li>
            </ul>
            <h3>🌍 Langues TTS</h3>
            <ul>
                <li>🇧🇯 Biali (beh)</li>
                <li>🇧🇯 Baatombu (bba)</li>
                <li>🇧🇯 Dendi (ddn)</li>
                <li>🇬🇭 Éwé (ewe)</li>
                <li>🇧🇯 Mina (gej)</li>
                <li>🇧🇯 Ditammari (tbz)</li>
                <li>🇳🇬 Yoruba (yor)</li>
                <li>🇧🇯 Fon (fon)</li>
                <li>🇬🇧 English (eng)</li>
            </ul>
            <h3>🚀 Déploiement</h3>
            <p>Cette application est déployée sur <strong>Hugging Face Spaces</strong></p>
            <p>Code source: <a href="https://huggingface.co/spaces" target="_blank">Hugging Face Spaces</a></p>
            <h3>📚 Ressources</h3>
            <ul>
                <li><a href="https://arxiv.org/abs/2305.13516" target="_blank">Meta MMS Paper</a></li>
                <li><a href="https://huggingface.co/facebook/mms-1b-all" target="_blank">facebook/mms-1b-all</a></li>
                <li><a href="https://huggingface.co/facebook/mms-tts" target="_blank">facebook/mms-tts</a></li>
            </ul>
            <h3>⚖️ Licence</h3>
            <p>CC-BY-NC-4.0 (comme les modèles Meta MMS)</p>
            """)

    # Footer. The device label is computed here, on the server: a <script>
    # placed inside gr.HTML is not executed by the browser, so a client-side
    # placeholder would stay on "Loading..." forever.
    device_label = "🚀 GPU" if get_device() == "cuda" else "💻 CPU"
    gr.HTML(f"""
    <hr>
    <div style="text-align: center; font-size: 12px; color: #999; margin-top: 20px;">
        <p>🏠 Powered by <strong>Gradio</strong> + <strong>Hugging Face</strong> |
        Device: <span id="device">{device_label}</span></p>
    </div>
    """)
if __name__ == "__main__":
    # Script entry point: log the start-up and the selected device, then
    # launch the Gradio app with sharing and debug mode turned off.
    logger.info("🚀 Démarrage de l'interface Gradio")
    logger.info("📊 Device: %s", get_device())
    demo.launch(share=False, debug=False)