import gradio as gr
import torch
from transformers import AutoModelForCTC, AutoProcessor, VitsModel, AutoTokenizer
import librosa
import numpy as np
import soundfile as sf
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
SAMPLE_RATE = 16000
MAX_AUDIO_LENGTH = 30  # seconds

# TTS language mapping
LANGUAGE_MAPPING = {
    "Biali (beh)": "facebook/mms-tts-beh",
    "Baatombu (bba)": "facebook/mms-tts-bba",
    "Dendi (ddn)": "facebook/mms-tts-ddn",
    "Éwé (ewe)": "facebook/mms-tts-ewe",
    "Mina (gej)": "facebook/mms-tts-gej",
    "Ditammari (tbz)": "facebook/mms-tts-tbz",
    "Yoruba (yor)": "facebook/mms-tts-yor",
    "Fon (fon)": "facebook/mms-tts-fon",
    "English (eng)": "facebook/mms-tts-eng",
}

# Model cache
models_cache = {}


def get_device():
    """Return the available device."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def load_asr_model():
    """Load (and cache) the ASR model."""
    if "asr" not in models_cache:
        device = get_device()
        logger.info("⏳ Loading ASR model...")
        processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
        model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all").to(device)
        model.eval()
        models_cache["asr"] = {"model": model, "processor": processor}
        logger.info("✅ ASR model loaded")
    return models_cache["asr"]["model"], models_cache["asr"]["processor"]


def load_tts_model(language_name):
    """Load (and cache) the TTS model for a language."""
    if language_name not in models_cache:
        device = get_device()
        model_id = LANGUAGE_MAPPING.get(language_name)
        if not model_id:
            raise ValueError(f"Unsupported language: {language_name}")
        logger.info(f"⏳ Loading TTS model for {language_name}...")
        model = VitsModel.from_pretrained(model_id).to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model.eval()
        models_cache[language_name] = {"model": model, "tokenizer": tokenizer}
        logger.info(f"✅ TTS model for {language_name} loaded")
    return models_cache[language_name]["model"], models_cache[language_name]["tokenizer"]


def process_audio(audio_data):
    """Preprocess audio: mono, 16 kHz, peak-normalized, truncated."""
    try:
        if isinstance(audio_data, tuple):
            # Gradio returns (sample_rate, audio_array)
            sr, audio = audio_data
        else:
            sr = SAMPLE_RATE
            audio = audio_data

        # Convert to float32 if needed
        audio = np.array(audio, dtype=np.float32)

        # Downmix to mono
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)

        # Resample
        if sr != SAMPLE_RATE:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)

        # Peak-normalize
        if np.max(np.abs(audio)) > 0:
            audio = audio / np.max(np.abs(audio))

        # Truncate
        max_samples = MAX_AUDIO_LENGTH * SAMPLE_RATE
        if len(audio) > max_samples:
            audio = audio[:max_samples]

        return audio
    except Exception as e:
        logger.error(f"Audio processing error: {e}")
        raise


def transcribe_audio(audio, language_label):
    """Transcribe audio to text (ASR)."""
    if audio is None:
        return "❌ Please record or upload an audio file"

    try:
        # Extract the language code from the "Language (code)" format
        language = language_label.split("(")[-1].rstrip(")")

        audio_processed = process_audio(audio)
        model, processor = load_asr_model()

        # Switch the MMS checkpoint to the target language:
        # update the tokenizer vocabulary and load the language adapter weights
        processor.tokenizer.set_target_lang(language)
        model.load_adapter(language)

        device = get_device()
        with torch.no_grad():
            inputs = processor(audio_processed, sampling_rate=SAMPLE_RATE, return_tensors="pt").to(device)
            outputs = model(**inputs)
            ids = torch.argmax(outputs.logits, dim=-1)[0]
            transcription = processor.decode(ids)

        return f"✅ Transcription:\n{transcription}"
    except Exception as e:
        logger.error(f"ASR error: {e}")
        return f"❌ Error: {str(e)}"
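
# A minimal programmatic usage sketch (not wired into the UI), assuming a local
# audio file at "sample.wav" (a hypothetical path). soundfile returns
# (samples, rate); transcribe_audio() expects the (rate, samples) tuple that
# Gradio produces, and process_audio() handles resampling and downmixing.
def _example_transcribe_file(path="sample.wav", language_label="English (eng)"):
    """Transcribe a local audio file by reusing the handlers above."""
    audio, sr = sf.read(path)
    return transcribe_audio((sr, audio), language_label)
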
def synthesize_speech(text, language):
    """Synthesize text to speech (TTS)."""
    if not text or not text.strip():
        return None, "❌ Please enter some text"

    try:
        text = text.strip()[:1000]  # Limit to 1000 chars

        model, tokenizer = load_tts_model(language)
        device = get_device()

        with torch.no_grad():
            inputs = tokenizer(text, return_tensors="pt").to(device)
            outputs = model(**inputs)
            waveform = outputs.waveform.cpu().numpy().flatten()

        # Use the model's actual sampling rate
        sample_rate = model.config.sampling_rate

        # Normalize the amplitude
        max_val = np.abs(waveform).max()
        if max_val > 0:
            # Scale to ±0.95 of full scale to avoid clipping
            waveform = (waveform / max_val) * 0.95

        # Convert to int16 PCM for playback
        waveform_int16 = (waveform * 32767).astype(np.int16)

        # Return (sample_rate, audio_array), the format Gradio expects
        return (sample_rate, waveform_int16), f"✅ Audio generated ({len(waveform_int16)} samples @ {sample_rate} Hz)!"
    except Exception as e:
        logger.error(f"TTS error: {e}")
        return None, f"❌ Error: {str(e)}"
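
# A minimal sketch of saving synthesized speech to disk by reusing
# synthesize_speech(); "output.wav" is a hypothetical output path.
def _example_synthesize_to_file(text="Hello world", language="English (eng)", path="output.wav"):
    """Synthesize text and write the result as a 16-bit WAV file."""
    result, status = synthesize_speech(text, language)
    if result is not None:
        sample_rate, waveform = result
        sf.write(path, waveform, sample_rate)
    return status
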

# ============= GRADIO INTERFACE =============
with gr.Blocks(title="🎙️ MMS ASR/TTS - Speech AI", theme=gr.themes.Soft()) as demo:
    gr.HTML("""
        <div style="text-align: center;">
            <h1>🎙️ Meta MMS Speech AI</h1>
            <p>Multilingual speech recognition (ASR) + speech synthesis (TTS)</p>
            <p>Uses the facebook/mms-1b-all and facebook/mms-tts models</p>
        </div>
    """)

    with gr.Tabs():
        # ============= TAB 1: ASR =============
        with gr.TabItem("🔊 ASR (Audio → Text)", id="asr"):
            gr.HTML("""
                <h3>Multilingual Speech Recognition</h3>
            """)
            gr.HTML("""
                <p>Record or upload an audio file to get its transcription.</p>
            """)

            with gr.Row():
                with gr.Column():
                    audio_input = gr.Audio(
                        label="📁 Audio file",
                        type="numpy",
                        sources=["upload", "microphone"]
                    )
                    language_asr = gr.Dropdown(
                        choices=[
                            "Biali (beh)",
                            "Baatombu (bba)",
                            "Dendi (ddn)",
                            "Éwé (ewe)",
                            "Mina (gej)",
                            "Ditammari (tbz)",
                            "Yoruba (yor)",
                            "Fon (fon)",
                            "English (eng)",
                        ],
                        value="English (eng)",
                        label="🌐 Language"
                    )
                    btn_asr = gr.Button("🎯 Transcribe", variant="primary", size="lg")
                with gr.Column():
                    output_asr = gr.Textbox(
                        label="📝 Transcription",
                        lines=6,
                        interactive=False
                    )

            btn_asr.click(
                fn=transcribe_audio,
                inputs=[audio_input, language_asr],
                outputs=output_asr
            )

        # ============= TAB 2: TTS =============
        with gr.TabItem("📢 TTS (Text → Audio)", id="tts"):
            gr.HTML("""
                <h3>Speech Synthesis</h3>
            """)
            gr.HTML("""
                <p>Enter some text and listen to it spoken in the chosen language.</p>
            """)

            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(
                        label="✍️ Text to convert",
                        placeholder="Write some text here...",
                        lines=4
                    )
                    language_tts = gr.Dropdown(
                        choices=list(LANGUAGE_MAPPING.keys()),
                        value="English (eng)",
                        label="🌐 Language"
                    )
                    btn_tts = gr.Button("🔊 Generate audio", variant="primary", size="lg")
                    info_tts = gr.Textbox(
                        label="📊 Info",
                        interactive=False,
                        value="Click 'Generate audio' to start"
                    )
                with gr.Column():
                    audio_output = gr.Audio(
                        label="🎵 Generated audio",
                        type="numpy"
                    )

            btn_tts.click(
                fn=synthesize_speech,
                inputs=[text_input, language_tts],
                outputs=[audio_output, info_tts]
            )

            # Examples (optional - commented out to avoid errors)
            # Uncomment to enable after testing
            # gr.Examples(
            #     examples=[
            #         ["Hello world", "English (eng)"],
            #         ["Àbọ̀ wa", "Yoruba (yor)"],
            #         ["Bonjour", "English (eng)"],
            #     ],
            #     fn=synthesize_speech,
            #     inputs=[text_input, language_tts],
            #     outputs=[audio_output, info_tts],
            #     label="💡 Examples"
            # )

        # ============= TAB 3: ABOUT =============
        with gr.TabItem("ℹ️ About", id="about"):
            gr.HTML("""
                <h2>About this API</h2>

                <h3>🎙️ ASR (Automatic Speech Recognition)</h3>

                <h3>📢 TTS (Text-to-Speech)</h3>

                <h3>🌍 TTS Languages</h3>

                <h3>🚀 Deployment</h3>
                <p>This application is deployed on Hugging Face Spaces</p>
                <p>Source code: GitHub</p>

                <h3>📚 Resources</h3>

                <h3>⚖️ License</h3>
                <p>CC-BY-NC-4.0 (same as the Meta MMS models)</p>
            """)

    # Footer
    gr.HTML(f"""
        <p style="text-align: center;">🏠 Powered by Gradio + Hugging Face | Device: {get_device()}</p>
    """)


if __name__ == "__main__":
    logger.info("🚀 Starting the Gradio interface")
    logger.info(f"📊 Device: {get_device()}")
    demo.launch(share=False, debug=False)
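
# --- Optional: remote usage sketch ------------------------------------------
# Once the app is deployed (e.g. on Hugging Face Spaces), it can also be driven
# programmatically with gradio_client. "user/mms-asr-tts" is a hypothetical
# Space id and the exact endpoint names may differ; inspect them with
# Client.view_api() before calling predict().
#
# from gradio_client import Client
# client = Client("user/mms-asr-tts")
# print(client.view_api())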