import gradio as gr from transformers import ( AutoModelForSeq2SeqLM, AutoTokenizer, pipeline, VitsTokenizer, VitsModel, set_seed, ) from enum_ import trans_languages, tts_languages, whisper_languages import logging import torch from TTS.api import TTS from functools import lru_cache import numpy as np from faster_whisper import WhisperModel import librosa import numpy as np import torch import os from evaluate import load ##translation translation_model_name = "facebook/nllb-200-distilled-600M" tokenizer = AutoTokenizer.from_pretrained(translation_model_name) translation_model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_name) wer_metric = load("wer") cer_metric = load("cer") @lru_cache(maxsize=10) def translate_sentence(sentence, src_lang, tgt_lang): logging.info(src_lang, tgt_lang) if not sentence: return "Error: no input sentence" try: translator = pipeline( "translation", model=translation_model, tokenizer=tokenizer, src_lang=trans_languages[src_lang], tgt_lang=trans_languages[tgt_lang], max_length=400, ) result = translator(sentence) logging.info(f"Translation: {result}") except Exception as e: return f"Translation error: {e}" if len(result) == 0: return "No output from translator" return result[0].get("translation_text", "No translation_text key in output") @lru_cache(maxsize=10) def load_tts(): # Get device device = "cuda" if torch.cuda.is_available() else "cpu" # Init TTS tts_model = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device) return tts_model @lru_cache(maxsize=10) def load_mms_tts(language): tokenizer = VitsTokenizer.from_pretrained(f"facebook/mms-tts-{language}") model = VitsModel.from_pretrained(f"facebook/mms-tts-{language}") return model, tokenizer def convert_vits_output_to_wav(vits_output): """ Convert VITS model output to WAV format. Parameters: vits_output: torch.Tensor or np.ndarray The audio output from the VITS model (float32). sample_rate: int, default 24000 The sample rate of the generated audio. Returns: None, but saves a file as 'output.wav' """ if isinstance(vits_output, torch.Tensor): arr = vits_output.detach().cpu().numpy() else: arr = np.asarray(vits_output) arr = np.squeeze(arr) # Clip to valid range arr = np.clip(arr, -1.0, 1.0).astype(np.float32) arr = librosa.resample(arr, orig_sr=16000, target_sr=24000) return arr def tts(sentence, language): if not sentence or sentence.strip() == "": return None try: language_code = tts_languages[language] if language_code in ["en", "ko", "ja", "zh-cn"]: tts_model = load_tts() base_dir = os.path.dirname(os.path.abspath(__file__)) wav_path = os.path.join(base_dir, "example.mp3") wav = tts_model.tts( text=sentence, speaker_wav=wav_path, language=language_code ) # Return as (sample_rate, audio_array) tuple for Gradio return (24000, np.array(wav)) else: model, tokenizer = load_mms_tts(tts_languages[language]) inputs = tokenizer(text=sentence, return_tensors="pt") set_seed(555) # make deterministic with torch.no_grad(): outputs = model(inputs["input_ids"]) outputs_resample = convert_vits_output_to_wav(outputs.waveform) return (24000, outputs_resample) except Exception as e: logging.error(f"TTS error: {e}") return None @lru_cache(maxsize=10) def load_whisper(type): model = WhisperModel(type) return model def transcribe(audio, language=None): if audio is None: return "" sr, y = audio if y.ndim > 1: y = y.mean(axis=1) y = y.astype(np.float32) / 32768.0 if sr != 16000: y = librosa.resample(y, orig_sr=sr, target_sr=16000) sr = 16000 model = load_whisper("large-v2") if language: segments, info = model.transcribe(y, language=whisper_languages[language]) else: segments, info = model.transcribe(y) logging.info(f"Detected language: {info.language}") transcription = "" for segment in segments: logging.info(segment.text) transcription += f"{segment.text}\n" return f"{transcription}" def evaluate(language, reference, prediction): ### wer if language in ["Traditional Chinese", "Vitetnamese"]: wer = wer_metric.compute(predictions=prediction, reference=reference) return str((1 - wer) * 100) + "%" ### cer else: cer = cer_metric.compute(predictions=prediction, reference=reference) return str((1 - cer) * 100) + "%" with gr.Blocks() as demo: gr.Markdown( """ ## Language Learning Assistant Learn a new language interactively: 1. **Type a Sentence**: Enter a sentence you want to learn and get an instant translation. 2. **Listen to Pronunciation**: Generate and listen to the correct pronunciation. 3. **Practice Speaking**: Record your pronunciation and compare it to the audio. 4. **Speech-to-Text Feedback**: Check if your pronunciation is recognized using speech-to-text and get real-time feedback. Improve your speaking and comprehension skills, all in one place! """ ) with gr.Row(): # Left column: translation / text output with gr.Column(scale=1, min_width=300): with gr.Row(): src = gr.Dropdown( list(trans_languages.keys()), label="Input Language", value="Traditional Chinese", ) tgt = gr.Dropdown( list(trans_languages.keys()), label="Output Language", value="English", ) sentence = gr.Textbox(label="Sentence", interactive=True) translate_btn = gr.Button("Translate Sentence") with gr.Column(scale=1, min_width=300): translation = gr.Textbox(label="Translation", interactive=False) speech = gr.Audio() with gr.Column(scale=1, min_width=300): mic = gr.Audio( sources=["microphone"], type="numpy", label="Record yourself" ) transcription = gr.Textbox(label="Your transcription") accuracy = gr.Textbox(label="Accuracy") translate_btn.click( fn=lambda txt, s_lang, t_lang: translate_sentence(txt, s_lang, t_lang), inputs=[sentence, src, tgt], outputs=translation, ) translation.change(fn=tts, inputs=[translation, tgt], outputs=speech) mic.stop(fn=transcribe, inputs=[mic, tgt], outputs=[transcription]) transcription.change( fn=evaluate, inputs=[tgt, translation, transcription], outputs=[accuracy] ) # You could add more callbacks: e.g. after generating sentence, allow translation etc. demo.launch(share=True)