File size: 7,206 Bytes
98150db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40e4072
98150db
 
 
 
 
 
40e4072
 
 
98150db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40e4072
98150db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40e4072
98150db
 
40e4072
98150db
 
 
 
40e4072
 
 
a71a5b3
40e4072
 
 
a71a5b3
40e4072
 
 
98150db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40e4072
98150db
 
40e4072
98150db
 
 
 
 
 
 
 
 
54cf316
40e4072
 
 
98150db
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import gradio as gr
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    pipeline,
    VitsTokenizer,
    VitsModel,
    set_seed,
)
from enum_ import trans_languages, tts_languages, whisper_languages
import logging
import torch
from TTS.api import TTS
from functools import lru_cache
import numpy as np
from faster_whisper import WhisperModel
import librosa
import numpy as np
import torch
import os
from evaluate import load

# Translation: load the NLLB-200 distilled 600M checkpoint and its tokenizer once
# at import time. translate_sentence() reuses these module-level objects.
translation_model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_name)

# Error-rate metrics from the `evaluate` library ("wer" and "cer"), loaded once
# at import time and used by evaluate() to score the user's transcription.
wer_metric = load("wer")
cer_metric = load("cer")


@lru_cache(maxsize=10)
def _translate_cached(sentence, src_lang, tgt_lang):
    """Run the translation pipeline and return the translated text.

    Raises whatever the pipeline (or the language lookup) raises. Because
    ``lru_cache`` only stores return values, a failed call is not memoised
    and will be retried the next time the same input is submitted.
    """
    translator = pipeline(
        "translation",
        model=translation_model,
        tokenizer=tokenizer,
        src_lang=trans_languages[src_lang],
        tgt_lang=trans_languages[tgt_lang],
        max_length=400,
    )
    result = translator(sentence)
    logging.info("Translation: %s", result)
    if not result:
        return "No output from translator"
    return result[0].get("translation_text", "No translation_text key in output")


def translate_sentence(sentence, src_lang, tgt_lang):
    """Translate ``sentence`` from ``src_lang`` to ``tgt_lang``.

    Parameters:
        sentence: text to translate.
        src_lang: key of ``trans_languages`` naming the input language.
        tgt_lang: key of ``trans_languages`` naming the output language.

    Returns:
        The translated text on success. Problems are reported as a
        human-readable string instead of an exception, so the result can
        always be displayed in the UI:
        ``"Error: no input sentence"`` for empty/whitespace-only input, or
        ``"Translation error: ..."`` when the lookup or the pipeline fails.
    """
    # Pass the values as logging arguments; using src_lang as the format
    # string (as `logging.info(src_lang, tgt_lang)` did) makes the logging
    # module report a formatting error on every call.
    logging.info("Translating from %s to %s", src_lang, tgt_lang)
    if not sentence or not sentence.strip():
        return "Error: no input sentence"
    try:
        # Successful translations are memoised by the helper; errors are not.
        return _translate_cached(sentence, src_lang, tgt_lang)
    except Exception as e:
        logging.exception("Translation failed")
        return f"Translation error: {e}"


@lru_cache(maxsize=10)
def load_tts():
    """Return the XTTS v2 model, placed on CUDA when available, else on CPU.

    The function takes no arguments, so the cache holds a single entry and
    the model is only constructed on the first call.
    """
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"
    return TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(target_device)


@lru_cache(maxsize=10)
def load_mms_tts(language):
    """Load the ``facebook/mms-tts-<language>`` checkpoint.

    Parameters:
        language: language code appended to ``facebook/mms-tts-``.

    Returns:
        A ``(model, tokenizer)`` tuple. Results are cached per language code.
    """
    checkpoint = f"facebook/mms-tts-{language}"
    mms_tokenizer = VitsTokenizer.from_pretrained(checkpoint)
    mms_model = VitsModel.from_pretrained(checkpoint)
    return mms_model, mms_tokenizer


def convert_vits_output_to_wav(vits_output, orig_sr=16000, target_sr=24000):
    """
    Convert VITS model output to a clipped, resampled float32 waveform.

    Despite the name, nothing is written to disk: the waveform is returned
    as an array.

    Parameters:
        vits_output: torch.Tensor or array-like
            The audio output from the VITS model. Singleton dimensions
            (e.g. a leading batch axis of size 1) are squeezed away.
        orig_sr: int, default 16000
            Sample rate of ``vits_output``.
        target_sr: int, default 24000
            Sample rate of the returned audio.

    Returns:
        np.ndarray of dtype float32 with values clipped to [-1.0, 1.0],
        sampled at ``target_sr``.
    """

    if isinstance(vits_output, torch.Tensor):
        # Detach from the graph and move to host memory before converting.
        arr = vits_output.detach().cpu().numpy()
    else:
        arr = np.asarray(vits_output)

    arr = np.squeeze(arr)

    # Clip to valid range
    arr = np.clip(arr, -1.0, 1.0).astype(np.float32)
    if orig_sr == target_sr:
        # Nothing to resample; return the clipped signal as is.
        return arr
    return librosa.resample(arr, orig_sr=orig_sr, target_sr=target_sr)


def tts(sentence, language):
    """Synthesize speech for ``sentence`` in the given UI language.

    Parameters:
        sentence: text to speak.
        language: key of ``tts_languages``.

    Returns:
        ``(24000, audio_array)`` for Gradio, or ``None`` when the sentence
        is empty/whitespace-only or when synthesis fails (the error is
        logged).
    """
    if not sentence or not sentence.strip():
        return None
    try:
        code = tts_languages[language]
        if code not in ("en", "ko", "ja", "zh-cn"):
            # Every other language goes through the per-language MMS-TTS model.
            mms_model, mms_tokenizer = load_mms_tts(code)
            encoded = mms_tokenizer(text=sentence, return_tensors="pt")
            set_seed(555)  # make deterministic

            with torch.no_grad():
                generated = mms_model(encoded["input_ids"])
            return (24000, convert_vits_output_to_wav(generated.waveform))

        # These four codes are handled by XTTS, using the example.mp3 that
        # sits next to this file as the speaker reference.
        xtts = load_tts()
        here = os.path.dirname(os.path.abspath(__file__))
        reference_clip = os.path.join(here, "example.mp3")
        samples = xtts.tts(text=sentence, speaker_wav=reference_clip, language=code)
        # Return as (sample_rate, audio_array) tuple for Gradio
        return (24000, np.array(samples))
    except Exception as err:
        logging.error(f"TTS error: {err}")
        return None


@lru_cache(maxsize=10)
def load_whisper(type):
    """Build a ``WhisperModel`` for the given model identifier.

    Parameters:
        type: value passed straight to ``WhisperModel`` (this file calls it
            with ``"large-v2"``).

    Returns:
        The ``WhisperModel`` instance; cached per identifier so repeated
        calls reuse the same object.
    """
    return WhisperModel(type)


def transcribe(audio, language=None):
    """Transcribe recorded audio with Whisper.

    Parameters:
        audio: ``(sample_rate, samples)`` tuple as delivered by the Gradio
            microphone component, or ``None`` when nothing was recorded.
        language: optional key of ``whisper_languages``. When omitted or
            unknown, the language is left for the model to detect.

    Returns:
        The transcription, one segment per line (each line ends with a
        newline), or ``""`` when there is no audio.
    """
    if audio is None:
        return ""

    sr, y = audio
    y = np.asarray(y)
    if y.size == 0:
        # An empty recording has nothing to transcribe.
        return ""

    # Integer PCM is scaled by its own full-scale value (32768 for int16, as
    # before). Other dtypes are presumed to be floats already in [-1, 1] and
    # are left unscaled — TODO confirm if non-integer input can occur.
    # The dtype must be read before the down-mix, which turns ints into floats.
    if np.issubdtype(y.dtype, np.integer):
        full_scale = float(2 ** (8 * y.dtype.itemsize - 1))
    else:
        full_scale = 1.0
    if y.ndim > 1:
        # Down-mix to mono; assumes a (samples, channels) layout.
        y = y.mean(axis=1)
    y = y.astype(np.float32) / full_scale

    if sr != 16000:
        y = librosa.resample(y, orig_sr=sr, target_sr=16000)

    language_code = None
    if language:
        try:
            language_code = whisper_languages[language]
        except KeyError:
            # Fall back to auto-detection instead of failing the callback.
            logging.warning(
                "No Whisper language code for %r; using auto-detection", language
            )

    model = load_whisper("large-v2")
    segments, info = model.transcribe(y, language=language_code)
    if language_code is None:
        logging.info(f"Detected language: {info.language}")

    lines = []
    for segment in segments:
        logging.info(segment.text)
        lines.append(f"{segment.text}\n")
    return "".join(lines)


def evaluate(language, reference, prediction):
    """Score ``prediction`` against ``reference`` as an accuracy percentage.

    Parameters:
        language: UI language name; selects word- or character-level scoring.
        reference: the expected text (the translation shown to the user).
        prediction: the text recognised from the user's recording.

    Returns:
        ``"<accuracy>%"`` where accuracy is ``(1 - error_rate) * 100``
        clamped at 0, or ``""`` when either text is empty.
    """
    # Surrounding whitespace (including the trailing newline produced by
    # transcribe) should not be counted as an error.
    reference = (reference or "").strip()
    prediction = (prediction or "").strip()
    if not reference or not prediction:
        return ""

    # "Vitetnamese" is kept alongside the corrected spelling in case the
    # language tables use the misspelled key.
    # NOTE(review): word-level scoring is used for these languages and
    # character-level scoring for all others, as in the original code.
    # Unsegmented Chinese text may be better served by CER — confirm intent.
    if language in ("Traditional Chinese", "Vietnamese", "Vitetnamese"):
        ### wer
        metric = wer_metric
    else:
        ### cer
        metric = cer_metric

    # The metrics take parallel lists under `predictions` / `references`.
    error_rate = metric.compute(predictions=[prediction], references=[reference])
    # The error rate can exceed 1 (many insertions); do not report a
    # negative accuracy.
    accuracy = max(0.0, 1 - error_rate) * 100
    return str(accuracy) + "%"


# Build the Gradio UI: three columns (input, translation + audio, recording +
# feedback) followed by the event wiring that chains the steps together.
with gr.Blocks() as demo:
    gr.Markdown(
        """
    ## Language Learning Assistant

    Learn a new language interactively:

    1. **Type a Sentence**: Enter a sentence you want to learn and get an instant translation.
    2. **Listen to Pronunciation**: Generate and listen to the correct pronunciation.
    3. **Practice Speaking**: Record your pronunciation and compare it to the audio.
    4. **Speech-to-Text Feedback**: Check if your pronunciation is recognized using speech-to-text and get real-time feedback.

    Improve your speaking and comprehension skills, all in one place!
    """
    )
    with gr.Row():
        # Left column: language selection and the sentence to translate
        with gr.Column(scale=1, min_width=300):
            with gr.Row():
                src = gr.Dropdown(
                    list(trans_languages.keys()),
                    label="Input Language",
                    value="Traditional Chinese",
                )
                tgt = gr.Dropdown(
                    list(trans_languages.keys()),
                    label="Output Language",
                    value="English",
                )
            sentence = gr.Textbox(label="Sentence", interactive=True)
            translate_btn = gr.Button("Translate Sentence")
        # Middle column: read-only translation and the synthesized speech
        with gr.Column(scale=1, min_width=300):
            translation = gr.Textbox(label="Translation", interactive=False)
            speech = gr.Audio()

        # Right column: microphone recording, its transcription and the score
        with gr.Column(scale=1, min_width=300):
            mic = gr.Audio(
                sources=["microphone"], type="numpy", label="Record yourself"
            )
            transcription = gr.Textbox(label="Your transcription")
            accuracy = gr.Textbox(label="Accuracy")

    # Button click: translate the sentence into the output language.
    translate_btn.click(
        fn=lambda txt, s_lang, t_lang: translate_sentence(txt, s_lang, t_lang),
        inputs=[sentence, src, tgt],
        outputs=translation,
    )

    # Whenever the translation text changes, synthesize it in the output language.
    translation.change(fn=tts, inputs=[translation, tgt], outputs=speech)

    # On the microphone's stop event, transcribe the recording; the output
    # language is passed as the transcription language.
    mic.stop(fn=transcribe, inputs=[mic, tgt], outputs=[transcription])
    # Whenever the transcription changes, score it against the translation.
    transcription.change(
        fn=evaluate, inputs=[tgt, translation, transcription], outputs=[accuracy]
    )
    # Further callbacks could be added here, e.g. follow-up actions after a
    # sentence has been generated.

# Start the app at import time; share=True asks Gradio for a public share link.
demo.launch(share=True)