Minte committed on
Commit
0717063
·
1 Parent(s): e04060f

refactor: replace subprocess with uroman for romanization and update TTS generation

Browse files
Files changed (2) hide show
  1. app.py +13 -19
  2. requirements.txt +3 -1
app.py CHANGED
@@ -10,7 +10,7 @@ from transformers import (
10
  import gradio as gr
11
  import resampy
12
  import tempfile
13
- import subprocess
14
 
15
  # --- Load ASR model ---
16
  try:
@@ -62,11 +62,11 @@ except Exception as e:
62
  # --- Romanization helper ---
63
  def romanize(text):
64
  try:
65
- result = subprocess.run(["uroman"], input=text.encode("utf-8"), stdout=subprocess.PIPE)
66
- return result.stdout.decode("utf-8").strip()
67
  except Exception as e:
68
  print("[ERROR] Romanization failed:", e)
69
- return text # fallback
70
 
71
  # --- ASR ---
72
  def transcribe_amharic(audio_file):
@@ -117,18 +117,8 @@ def generate_chat_response(text):
117
  if chat_model is None:
118
  return "Chat model not loaded"
119
  try:
120
- inputs = chat_model.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
121
- with torch.no_grad():
122
- outputs = chat_model.model.generate(
123
- inputs.input_ids,
124
- max_length=128,
125
- num_beams=4,
126
- no_repeat_ngram_size=2,
127
- early_stopping=True,
128
- repetition_penalty=1.3,
129
- do_sample=True
130
- )
131
- response = chat_model.tokenizer.decode(outputs[0], skip_special_tokens=True)
132
  return response.strip()
133
  except Exception as e:
134
  print("[ERROR] Chat generation failed:", e)
@@ -145,7 +135,11 @@ def generate_tts(text):
145
  romanized_text = romanize(text)
146
  inputs = tts_processor(text=romanized_text, return_tensors="pt")
147
  with torch.no_grad():
148
- speech = tts_model.generate_speech(inputs["input_ids"], tts_vocoder)
 
 
 
 
149
  audio_data = speech.numpy()
150
  max_val = np.max(np.abs(audio_data))
151
  if max_val > 0:
@@ -226,10 +220,10 @@ def assistant_pipeline(audio):
226
  # --- Gradio UI ---
227
  with gr.Blocks(title="🌍 Local Language AI Assistant") as demo:
228
  gr.Markdown("# 🌍 Local Language AI Assistant")
229
- gr.Markdown("🎙️ Speak **or upload** Amharic audio and get AI responses with synthesized Amharic speech!")
230
 
231
  with gr.Row():
232
- audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="🎤 Record or Upload your voice")
233
 
234
  submit_btn = gr.Button("Process", variant="primary")
235
 
 
10
  import gradio as gr
11
  import resampy
12
  import tempfile
13
+ from uroman import uroman # ✅ Use Python version instead of subprocess
14
 
15
  # --- Load ASR model ---
16
  try:
 
62
  # --- Romanization helper ---
63
  def romanize(text):
64
  try:
65
+ romanized = uroman.romanize_string(text)
66
+ return romanized.strip()
67
  except Exception as e:
68
  print("[ERROR] Romanization failed:", e)
69
+ return text
70
 
71
  # --- ASR ---
72
  def transcribe_amharic(audio_file):
 
117
  if chat_model is None:
118
  return "Chat model not loaded"
119
  try:
120
+ result = chat_model(text, max_length=128, num_beams=4, do_sample=True)
121
+ response = result[0]["generated_text"]
 
 
 
 
 
 
 
 
 
 
122
  return response.strip()
123
  except Exception as e:
124
  print("[ERROR] Chat generation failed:", e)
 
135
  romanized_text = romanize(text)
136
  inputs = tts_processor(text=romanized_text, return_tensors="pt")
137
  with torch.no_grad():
138
+ speech = tts_model.generate_speech(
139
+ inputs["input_ids"],
140
+ speaker_embeddings=torch.zeros((1, 512)), # ✅ fixed
141
+ vocoder=tts_vocoder
142
+ )
143
  audio_data = speech.numpy()
144
  max_val = np.max(np.abs(audio_data))
145
  if max_val > 0:
 
220
  # --- Gradio UI ---
221
  with gr.Blocks(title="🌍 Local Language AI Assistant") as demo:
222
  gr.Markdown("# 🌍 Local Language AI Assistant")
223
+ gr.Markdown("Speak or upload Amharic audio and get AI responses with voice output!")
224
 
225
  with gr.Row():
226
+ audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="🎤 Record or Upload Audio")
227
 
228
  submit_btn = gr.Button("Process", variant="primary")
229
 
requirements.txt CHANGED
@@ -8,4 +8,6 @@ accelerate
8
  sentencepiece
9
  scipy
10
  numpy
11
- sacremoses
 
 
 
8
  sentencepiece
9
  scipy
10
  numpy
11
+ sacremoses
12
+ sacremoses
13
+ git+https://github.com/isi-nlp/uroman.git