sarahwei commited on
Commit
98150db
·
verified ·
1 Parent(s): efbc9e2

Upload files

Browse files
Files changed (3) hide show
  1. app.py +210 -0
  2. enum_.py +26 -0
  3. requirements.txt +16 -0
app.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import (
3
+ AutoModelForSeq2SeqLM,
4
+ AutoTokenizer,
5
+ pipeline,
6
+ VitsTokenizer,
7
+ VitsModel,
8
+ set_seed,
9
+ )
10
+ from enum_ import trans_languages, tts_languages, whisper_languages
11
+ import logging
12
+ import torch
13
+ from TTS.api import TTS
14
+ from functools import lru_cache
15
+ import numpy as np
16
+ from faster_whisper import WhisperModel
17
+ import librosa
18
+ import numpy as np
19
+ import torch
20
+ import os
21
+ from pydub import AudioSegment
22
+ import io
23
+
24
# --- Translation backend (NLLB-200, distilled 600M) ---
# Loaded once at import time; translate_sentence() builds a pipeline on top
# of these shared objects.
translation_model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_name)
28
+
29
+
30
@lru_cache(maxsize=10)
def translate_sentence(sentence, src_lang, tgt_lang):
    """Translate *sentence* between two languages with NLLB-200.

    Args:
        sentence: Text to translate (empty/None yields an error string).
        src_lang: Human-readable source language name (key of ``trans_languages``).
        tgt_lang: Human-readable target language name (key of ``trans_languages``).

    Returns:
        The translated text, or a human-readable error string on failure
        (the UI displays whatever string comes back).
    """
    # BUG FIX: logging.info(src_lang, tgt_lang) used src_lang as the logging
    # *format string* (raising "not all arguments converted" in the handler
    # when the name contains no %-placeholders). Use lazy %-style args.
    logging.info("Translating %s -> %s", src_lang, tgt_lang)
    if not sentence:
        return "Error: no input sentence"
    try:
        translator = pipeline(
            "translation",
            model=translation_model,
            tokenizer=tokenizer,
            src_lang=trans_languages[src_lang],
            tgt_lang=trans_languages[tgt_lang],
            max_length=400,
        )
        result = translator(sentence)
        logging.info("Translation: %s", result)
    except Exception as e:
        # Best-effort: surface the error text in the UI instead of crashing.
        return f"Translation error: {e}"
    if len(result) == 0:
        return "No output from translator"
    return result[0].get("translation_text", "No translation_text key in output")
51
+
52
+
53
@lru_cache(maxsize=10)
def load_tts():
    """Load the multilingual XTTS-v2 model once and cache the instance.

    Returns:
        A Coqui ``TTS`` model placed on CUDA when available, else CPU.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    return TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(target_device)
60
+
61
+
62
@lru_cache(maxsize=10)
def load_mms_tts(language):
    """Load and cache the facebook/mms-tts checkpoint for *language*.

    Args:
        language: Code used to pick the checkpoint, e.g. "vie" ->
            "facebook/mms-tts-vie".

    Returns:
        A ``(model, tokenizer)`` pair (VitsModel, VitsTokenizer).
    """
    checkpoint = f"facebook/mms-tts-{language}"
    mms_tokenizer = VitsTokenizer.from_pretrained(checkpoint)
    mms_model = VitsModel.from_pretrained(checkpoint)
    return mms_model, mms_tokenizer
67
+
68
+
69
def convert_vits_output_to_wav(vits_output, orig_sr=16000, target_sr=24000):
    """Convert raw VITS model output into a resampled float32 waveform.

    Parameters
    ----------
    vits_output : torch.Tensor or array-like
        Audio produced by the VITS model (float samples, nominally in [-1, 1]).
    orig_sr : int, default 16000
        Sample rate of the model output. The previous hard-coded value was
        16000; kept as the default for backward compatibility.
    target_sr : int, default 24000
        Sample rate to resample to (matches the rate the UI reports).

    Returns
    -------
    np.ndarray
        A squeezed float32 array resampled to ``target_sr``.

    Notes
    -----
    DOC FIX: the original docstring claimed this function "saves a file as
    'output.wav'" and returned None, and documented a ``sample_rate``
    parameter that did not exist — none of that matched the code.
    """
    if isinstance(vits_output, torch.Tensor):
        arr = vits_output.detach().cpu().numpy()
    else:
        arr = np.asarray(vits_output)

    # Drop singleton batch/channel dimensions.
    arr = np.squeeze(arr)

    # Clip to the valid audio range before resampling.
    arr = np.clip(arr, -1.0, 1.0).astype(np.float32)
    return librosa.resample(arr, orig_sr=orig_sr, target_sr=target_sr)
94
+
95
+
96
def tts(sentence, language):
    """Synthesize speech for *sentence* in *language*.

    Args:
        sentence: Text to speak; blank/None input yields None.
        language: Human-readable language name (key of ``tts_languages``).

    Returns:
        A ``(sample_rate, waveform)`` tuple for Gradio, or None on
        empty input or any synthesis error.
    """
    # Guard clause: nothing to synthesize.
    if not sentence or sentence.strip() == "":
        return None
    try:
        code = tts_languages[language]
        if code not in ("en", "ko", "ja"):
            # MMS-TTS (VITS) path for languages not routed to XTTS here.
            model, mms_tokenizer = load_mms_tts(code)
            encoded = mms_tokenizer(text=sentence, return_tensors="pt")
            set_seed(555)  # make deterministic
            with torch.no_grad():
                outputs = model(encoded["input_ids"])
            resampled = convert_vits_output_to_wav(outputs.waveform)
            return (24000, resampled)

        # XTTS-v2 path, cloning the bundled reference speaker clip.
        xtts_model = load_tts()
        speaker_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "example.mp3"
        )
        wav = xtts_model.tts(text=sentence, speaker_wav=speaker_path, language=code)
        # Return as (sample_rate, audio_array) tuple for Gradio.
        return (24000, np.array(wav))
    except Exception as e:
        logging.error(f"TTS error: {e}")
        return None
123
+
124
+
125
@lru_cache(maxsize=10)
def load_whisper(type):
    """Return a cached faster-whisper model for the given size (e.g. "large-v2").

    NOTE: the parameter name shadows the ``type`` builtin; it is kept
    unchanged because it is part of the caller-visible signature.
    """
    return WhisperModel(type)
129
+
130
+
131
def transcribe(audio, language=None):
    """Transcribe recorded audio with faster-whisper.

    Parameters
    ----------
    audio : str | tuple[int, np.ndarray] | None
        Either a file path (as produced by ``gr.Audio(type="filepath")`` —
        which is how the mic component in this file is configured) or a
        ``(sample_rate, samples)`` tuple (``type="numpy"``).
    language : str | None
        Human-readable language name (key of ``whisper_languages``); the
        language is auto-detected when None.

    Returns
    -------
    str
        Newline-joined segment texts, or "" when there is no input.
    """
    if audio is None:
        return ""

    if isinstance(audio, str):
        # BUG FIX: the mic component uses type="filepath", so this function
        # receives a path string — the old code unconditionally unpacked it
        # as (sr, y) and crashed. Load the file, downmixed to mono at 16 kHz.
        y, sr = librosa.load(audio, sr=16000, mono=True)
    else:
        sr, y = audio
        if y.ndim > 1:
            y = y.mean(axis=1)  # downmix multi-channel to mono
        if np.issubdtype(y.dtype, np.integer):
            # BUG FIX: only rescale integer PCM (Gradio delivers int16);
            # the old code divided float input by 32768 too, silencing it.
            y = y.astype(np.float32) / 32768.0
        else:
            y = y.astype(np.float32)
        if sr != 16000:
            y = librosa.resample(y, orig_sr=sr, target_sr=16000)
            sr = 16000

    model = load_whisper("large-v2")
    if language:
        segments, info = model.transcribe(y, language=whisper_languages[language])
    else:
        segments, info = model.transcribe(y)
    # Use logging instead of bare print for detection/segment diagnostics.
    logging.info("Detected language: %s", info.language)
    transcription = ""
    for segment in segments:
        logging.info("%s", segment.text)
        transcription += f"{segment.text}\n"
    return transcription
155
+
156
+
157
# --- Gradio UI: three-column translate / listen / practice layout ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        ## Language Learning Assistant

        Learn a new language interactively:

        1. **Type a Sentence**: Enter a sentence you want to learn and get an instant translation.
        2. **Listen to Pronunciation**: Generate and listen to the correct pronunciation.
        3. **Practice Speaking**: Record your pronunciation and compare it to the audio.
        4. **Speech-to-Text Feedback**: Check if your pronunciation is recognized using speech-to-text and get real-time feedback.

        Improve your speaking and comprehension skills, all in one place!
        """
    )
    with gr.Row():
        # Left column: language pickers + sentence input.
        with gr.Column(scale=1, min_width=300):
            with gr.Row():
                src = gr.Dropdown(
                    list(trans_languages.keys()),
                    label="Input Language",
                    value="Traditional Chinese",
                )
                tgt = gr.Dropdown(
                    list(trans_languages.keys()),
                    label="Output Language",
                    value="English",
                )
            sentence = gr.Textbox(label="Sentence", interactive=True)
            translate_btn = gr.Button("Translate Sentence")

        # Middle column: translation text + synthesized pronunciation.
        with gr.Column(scale=1, min_width=300):
            translation = gr.Textbox(label="Translation", interactive=False)
            speech = gr.Audio()

        # Right column: record yourself, see transcription/feedback.
        with gr.Column(scale=1, min_width=300):
            mic = gr.Audio(
                sources=["microphone"], type="filepath", label="Record yourself"
            )
            transcription = gr.Textbox(label="Your transcription")
            feedback = gr.Textbox(label="Feedback")

    # FIX: pass translate_sentence directly — the previous pass-through
    # lambda added nothing and defeated introspection of the handler.
    translate_btn.click(
        fn=translate_sentence,
        inputs=[sentence, src, tgt],
        outputs=translation,
    )

    # Synthesize speech whenever the translation text changes.
    translation.change(fn=tts, inputs=[translation, tgt], outputs=speech)

    # NOTE(review): mic is type="filepath", so transcribe receives a path —
    # transcribe must accept a filepath input; verify before deploying.
    mic.change(fn=transcribe, inputs=[mic, tgt], outputs=[transcription])

demo.launch(share=True)
enum_.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Human-readable language name -> NLLB-200 (FLORES-200) language code,
# consumed by the translation pipeline in app.py.
trans_languages = {
    "Traditional Chinese": "zho_Hant",
    "English": "eng_Latn",
    "Korean": "kor_Hang",
    "Vietnamese": "vie_Latn",
    "Thai": "tha_Thai",
    "Japanese": "jpn_Jpan",
}

# Human-readable name -> TTS language code.
# NOTE(review): the values mix two conventions — "en"/"ko"/"ja"/"zh-tw" look
# like XTTS-style codes, while "vie"/"tha" look like ISO-639-3 codes used to
# select facebook/mms-tts-* checkpoints. app.py routes only "en"/"ko"/"ja"
# to XTTS, so "zh-tw" falls through to the MMS path ("facebook/mms-tts-zh-tw")
# — confirm that checkpoint exists.
tts_languages = {
    "Traditional Chinese": "zh-tw",
    "English": "en",
    "Korean": "ko",
    "Vietnamese": "vie",
    "Thai": "tha",
    "Japanese": "ja",
}

# Human-readable name -> Whisper ISO-639-1 language code.
whisper_languages = {
    "Traditional Chinese": "zh",
    "English": "en",
    "Korean": "ko",
    "Vietnamese": "vi",
    "Thai": "th",
    "Japanese": "ja",
}
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==5.1.0
2
+ transformers==4.36.2
3
+ torch==2.1.2
4
+ torchaudio==2.1.2
5
+ librosa==0.10.0
6
+ numpy==1.26.3
7
+ scipy==1.12.0
8
+ soundfile==0.12.1
9
+ huggingface-hub==0.36.0
10
+ accelerate==0.24.0
11
+ typing-extensions==4.7.1
12
+ faster-whisper==1.2.1
+ # NOTE(review): duplicate `librosa==0.10.0` pin removed (already pinned above).
+ # app.py imports `from TTS.api import TTS` (Coqui TTS) — that package is
+ # missing from this file; add a pinned `coqui-tts` (or `TTS`) entry.
14
+ cutlet==0.5.0
15
+ fugashi==1.5.2
16
+ pydub==0.25.1