Minte commited on
Commit
cb4630e
·
1 Parent(s): 018fb8e

fix: refactor model loading and enhance ASR and translation functionality with SeamlessM4T integration

Browse files
Files changed (1) hide show
  1. app.py +66 -82
app.py CHANGED
@@ -3,58 +3,44 @@ import soundfile as sf
3
  import torch
4
  import numpy as np
5
  from transformers import (
6
- AutoProcessor, AutoModelForSpeechSeq2Seq,
7
- pipeline, M2M100ForConditionalGeneration, M2M100Tokenizer,
8
- VitsModel, AutoTokenizer
9
  )
10
  import gradio as gr
11
  import resampy
12
  import tempfile
13
  import subprocess
14
 
15
- # --- Load ASR model ---
16
  try:
17
  model_id = "facebook/seamless-m4t-v2-large"
18
  processor = AutoProcessor.from_pretrained(model_id)
19
- asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id).to("cpu")
20
- print("[INFO] ASR model loaded.")
21
  except Exception as e:
22
- print("[ERROR] Failed to load ASR model:", e)
23
  traceback.print_exc()
24
- asr_model = None
25
  processor = None
26
 
27
- # --- Load translation models ---
28
  try:
29
- back_translate_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B").to("cpu")
30
- back_translate_tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")
31
- print("[INFO] Back translation model loaded.")
32
- except Exception as e:
33
- print("[ERROR] Failed to load back translation model:", e)
34
- traceback.print_exc()
35
- back_translate_model = None
36
- back_translate_tokenizer = None
37
-
38
- # --- Load other pipelines ---
39
- try:
40
- translate_to_en = pipeline("translation", model="Helsinki-NLP/opus-mt-mul-en")
41
  chat_model = pipeline("text2text-generation", model="google/flan-t5-base")
42
- print("[INFO] Translation and chat models loaded successfully.")
43
  except Exception as e:
44
- print("[ERROR] Failed to load pipelines:", e)
45
  traceback.print_exc()
46
- translate_to_en = None
47
  chat_model = None
48
 
49
  # --- Load TTS model (Facebook MMS for Amharic) ---
50
  try:
51
- tts_processor = AutoProcessor.from_pretrained("facebook/mms-tts-amh")
52
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-amh").to("cpu")
53
  print("[INFO] Facebook MMS TTS model for Amharic loaded successfully.")
54
  except Exception as e:
55
  print("[ERROR] Failed to load Facebook MMS TTS model:", e)
56
  traceback.print_exc()
57
- tts_processor = None
58
  tts_model = None
59
 
60
  # --- Romanization helper ---
@@ -66,44 +52,65 @@ def romanize(text):
66
  print("[ERROR] Romanization failed:", e)
67
  return text # fallback
68
 
69
- # --- ASR ---
70
  def transcribe_amharic(audio_file):
71
- if asr_model is None or processor is None:
72
  return "ASR Model loading failed"
73
  try:
74
  audio, sr = sf.read(audio_file)
75
  if audio.ndim > 1:
76
  audio = audio.mean(axis=1)
77
  audio = resampy.resample(audio, sr, 16000)
 
 
78
  inputs = processor(audio=audio, sampling_rate=16000, return_tensors="pt")
79
  with torch.no_grad():
80
- generated_ids = asr_model.generate(**inputs, tgt_lang="amh")
81
- transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
 
 
 
82
  return transcription.strip()
83
  except Exception as e:
84
  print("[ERROR] ASR transcription failed:", e)
85
  traceback.print_exc()
86
  return f"ASR failed: {str(e)[:50]}..."
87
 
88
- # --- Back translation ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def back_translate_en_to_am(en_text):
90
- if back_translate_model is None or back_translate_tokenizer is None:
91
  return "Back translation model not loaded"
92
  try:
93
- back_translate_tokenizer.src_lang = "en"
94
- inputs = back_translate_tokenizer(en_text, return_tensors="pt")
95
  with torch.no_grad():
96
- generated_tokens = back_translate_model.generate(
97
- **inputs,
98
- forced_bos_token_id=back_translate_tokenizer.get_lang_id("am"),
99
- max_length=128,
100
- num_beams=5,
101
- no_repeat_ngram_size=2,
102
- early_stopping=True,
103
- repetition_penalty=1.5,
104
- length_penalty=0.8
105
  )
106
- am_response = back_translate_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
107
  return am_response.strip()
108
  except Exception as e:
109
  print("[ERROR] Back translation failed:", e)
@@ -117,19 +124,7 @@ def generate_chat_response(text):
117
  try:
118
  # Add context to make responses more meaningful
119
  prompt = f"Respond to this in a helpful and conversational way: {text}"
120
- inputs = chat_model.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
121
- with torch.no_grad():
122
- outputs = chat_model.model.generate(
123
- inputs.input_ids,
124
- max_length=150,
125
- num_beams=5,
126
- no_repeat_ngram_size=3,
127
- early_stopping=True,
128
- repetition_penalty=2.0,
129
- temperature=0.7,
130
- do_sample=True
131
- )
132
- response = chat_model.tokenizer.decode(outputs[0], skip_special_tokens=True)
133
  return response.strip()
134
  except Exception as e:
135
  print("[ERROR] Chat generation failed:", e)
@@ -137,34 +132,28 @@ def generate_chat_response(text):
137
 
138
  # --- TTS with Facebook MMS ---
139
  def generate_tts(text):
140
- if tts_model is None or tts_processor is None:
141
  print("[ERROR] TTS model not loaded")
142
  return None
143
  try:
144
  if not text.strip():
145
  return None
146
 
147
- # Process text and generate speech
148
- inputs = tts_processor(text=text, return_tensors="pt")
149
 
150
  with torch.no_grad():
151
- speech = tts_model(**inputs).waveform
 
152
 
153
  # Convert to numpy and normalize
154
- if isinstance(speech, torch.Tensor):
155
- audio_data = speech.cpu().numpy()
156
- else:
157
- audio_data = speech
158
-
159
- # Handle mono/stereo and normalize
160
- if audio_data.ndim > 1:
161
- audio_data = audio_data.squeeze()
162
 
163
  max_val = np.max(np.abs(audio_data))
164
  if max_val > 0:
165
  audio_data = audio_data / max_val
166
 
167
- return audio_data, 16000 # MMS TTS typically uses 16kHz
168
 
169
  except Exception as e:
170
  print("[ERROR] MMS TTS generation failed:", e)
@@ -219,29 +208,24 @@ def create_wav_file(audio_array, sample_rate):
219
  def assistant_pipeline(audio):
220
  if not audio:
221
  return "No audio", "", "", "", None
 
 
222
  asr_result = transcribe_amharic(audio)
223
  print(f"ASR Result: {asr_result}")
224
 
225
- # Translation
226
- if translate_to_en is None:
227
- en_text = "Translation model not loaded"
228
- else:
229
- try:
230
- en_text = translate_to_en(asr_result)[0]["translation_text"]
231
- except Exception as e:
232
- print("[ERROR] Translation to English failed:", e)
233
- en_text = f"Translation failed: {str(e)[:50]}..."
234
  print(f"English Translation: {en_text}")
235
 
236
- # Chat
237
  en_response = generate_chat_response(en_text)
238
  print(f"Chat Response: {en_response}")
239
 
240
- # Back translation
241
  am_response = back_translate_en_to_am(en_response)
242
  print(f"Amharic Response: {am_response}")
243
 
244
- # TTS with multiple fallbacks
245
  audio_file_path = None
246
  if am_response and not am_response.startswith("Back translation failed"):
247
  # Try MMS TTS first
 
3
  import torch
4
  import numpy as np
5
  from transformers import (
6
+ SeamlessM4TModel, AutoProcessor,
7
+ pipeline, VitsModel, AutoTokenizer
 
8
  )
9
  import gradio as gr
10
  import resampy
11
  import tempfile
12
  import subprocess
13
 
14
# --- Load SeamlessM4T model for ASR and translation ---
# A single checkpoint is shared by speech-to-text and both text-translation
# directions. On any failure both globals are set to None so the functions
# that use them can return an error string instead of crashing.
try:
    # The "seamless-m4t-v2-*" checkpoints have to be loaded with the v2 model
    # class; the v1 SeamlessM4TModel has a different architecture and does not
    # match the v2 weights. Imported here so that a transformers build without
    # v2 support ends up in the except branch below.
    from transformers import SeamlessM4Tv2Model

    model_id = "facebook/seamless-m4t-v2-large"
    processor = AutoProcessor.from_pretrained(model_id)
    model = SeamlessM4Tv2Model.from_pretrained(model_id).to("cpu")
    print("[INFO] SeamlessM4T model loaded for ASR and translation.")
except Exception as e:
    print("[ERROR] Failed to load SeamlessM4T model:", e)
    traceback.print_exc()
    model = None
    processor = None
25
 
26
+ # --- Load chat model ---
27
  try:
 
 
 
 
 
 
 
 
 
 
 
 
28
  chat_model = pipeline("text2text-generation", model="google/flan-t5-base")
29
+ print("[INFO] Chat model loaded successfully.")
30
  except Exception as e:
31
+ print("[ERROR] Failed to load chat model:", e)
32
  traceback.print_exc()
 
33
  chat_model = None
34
 
35
  # --- Load TTS model (Facebook MMS for Amharic) ---
36
  try:
37
+ tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-amh")
38
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-amh").to("cpu")
39
  print("[INFO] Facebook MMS TTS model for Amharic loaded successfully.")
40
  except Exception as e:
41
  print("[ERROR] Failed to load Facebook MMS TTS model:", e)
42
  traceback.print_exc()
43
+ tts_tokenizer = None
44
  tts_model = None
45
 
46
  # --- Romanization helper ---
 
52
  print("[ERROR] Romanization failed:", e)
53
  return text # fallback
54
 
55
# --- ASR with SeamlessM4T ---
def transcribe_amharic(audio_file):
    """Transcribe an audio file to Amharic text with the shared SeamlessM4T model.

    Args:
        audio_file: Audio source handed directly to ``sf.read``.

    Returns:
        The stripped transcription on success. If the model or processor
        failed to load, returns "ASR Model loading failed"; if anything
        raises during processing, returns "ASR failed: <first 50 chars>...".
        The function never raises.
    """
    if model is None or processor is None:
        return "ASR Model loading failed"
    try:
        audio, sr = sf.read(audio_file)
        if audio.ndim > 1:
            # Average the channels to get a mono signal.
            audio = audio.mean(axis=1)
        # The processor is called with sampling_rate=16000 below.
        audio = resampy.resample(audio, sr, 16000)

        # Direct Amharic transcription
        inputs = processor(audio=audio, sampling_rate=16000, return_tensors="pt")
        with torch.no_grad():
            generated = model.generate(
                **inputs,
                tgt_lang="amh",
                generate_speech=False,
            )
        # With generate_speech=False the first element of the result is the
        # batched token tensor (batch, seq_len). decode() expects the ids of a
        # single sequence, so take row 0 as a plain list.
        token_ids = generated[0].tolist()[0]
        transcription = processor.decode(token_ids, skip_special_tokens=True)
        return transcription.strip()
    except Exception as e:
        print("[ERROR] ASR transcription failed:", e)
        traceback.print_exc()
        return f"ASR failed: {str(e)[:50]}..."
79
 
80
# --- Translation with SeamlessM4T (Amharic to English) ---
def translate_am_to_en(amharic_text):
    """Translate Amharic text to English with the shared SeamlessM4T model.

    Args:
        amharic_text: Source text, tokenized with ``src_lang="amh"``.

    Returns:
        The stripped English translation on success. If the model or
        processor failed to load, returns "Translation model not loaded"; if
        anything raises, returns "Translation failed: <first 50 chars>...".
        The function never raises.
    """
    if model is None or processor is None:
        return "Translation model not loaded"
    try:
        # Translate Amharic to English using SeamlessM4T
        text_inputs = processor(text=amharic_text, src_lang="amh", return_tensors="pt")
        with torch.no_grad():
            generated = model.generate(
                **text_inputs,
                tgt_lang="eng",
                generate_speech=False,
            )
        # With generate_speech=False the first element of the result is the
        # batched token tensor (batch, seq_len). decode() expects the ids of a
        # single sequence, so take row 0 as a plain list.
        token_ids = generated[0].tolist()[0]
        translated_text = processor.decode(token_ids, skip_special_tokens=True)
        return translated_text.strip()
    except Exception as e:
        print("[ERROR] Translation failed:", e)
        traceback.print_exc()
        return f"Translation failed: {str(e)[:50]}..."
99
+
100
+ # --- Back translation with SeamlessM4T (English to Amharic) ---
101
  def back_translate_en_to_am(en_text):
102
+ if model is None or processor is None:
103
  return "Back translation model not loaded"
104
  try:
105
+ # Translate English back to Amharic using SeamlessM4T
106
+ text_inputs = processor(text=en_text, src_lang="eng", return_tensors="pt")
107
  with torch.no_grad():
108
+ output_tokens = model.generate(
109
+ **text_inputs,
110
+ tgt_lang="amh",
111
+ generate_speech=False
 
 
 
 
 
112
  )
113
+ am_response = processor.decode(output_tokens[0], skip_special_tokens=True)
114
  return am_response.strip()
115
  except Exception as e:
116
  print("[ERROR] Back translation failed:", e)
 
124
  try:
125
  # Add context to make responses more meaningful
126
  prompt = f"Respond to this in a helpful and conversational way: {text}"
127
+ response = chat_model(prompt, max_length=150, num_beams=5, temperature=0.7, do_sample=True)[0]['generated_text']
 
 
 
 
 
 
 
 
 
 
 
 
128
  return response.strip()
129
  except Exception as e:
130
  print("[ERROR] Chat generation failed:", e)
 
132
 
133
  # --- TTS with Facebook MMS ---
134
  def generate_tts(text):
135
+ if tts_model is None or tts_tokenizer is None:
136
  print("[ERROR] TTS model not loaded")
137
  return None
138
  try:
139
  if not text.strip():
140
  return None
141
 
142
+ # Tokenize text and generate speech
143
+ inputs = tts_tokenizer(text, return_tensors="pt")
144
 
145
  with torch.no_grad():
146
+ output = tts_model(**inputs)
147
+ speech = output.waveform
148
 
149
  # Convert to numpy and normalize
150
+ audio_data = speech.cpu().numpy().squeeze()
 
 
 
 
 
 
 
151
 
152
  max_val = np.max(np.abs(audio_data))
153
  if max_val > 0:
154
  audio_data = audio_data / max_val
155
 
156
+ return audio_data, tts_model.config.sampling_rate
157
 
158
  except Exception as e:
159
  print("[ERROR] MMS TTS generation failed:", e)
 
208
  def assistant_pipeline(audio):
209
  if not audio:
210
  return "No audio", "", "", "", None
211
+
212
+ # Step 1: ASR with SeamlessM4T
213
  asr_result = transcribe_amharic(audio)
214
  print(f"ASR Result: {asr_result}")
215
 
216
+ # Step 2: Translation with SeamlessM4T
217
+ en_text = translate_am_to_en(asr_result)
 
 
 
 
 
 
 
218
  print(f"English Translation: {en_text}")
219
 
220
+ # Step 3: Chat response
221
  en_response = generate_chat_response(en_text)
222
  print(f"Chat Response: {en_response}")
223
 
224
+ # Step 4: Back translation with SeamlessM4T
225
  am_response = back_translate_en_to_am(en_response)
226
  print(f"Amharic Response: {am_response}")
227
 
228
+ # Step 5: TTS
229
  audio_file_path = None
230
  if am_response and not am_response.startswith("Back translation failed"):
231
  # Try MMS TTS first