Jatila committed on
Commit
b79bf8c
·
verified ·
1 Parent(s): f6a9546

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -51
app.py CHANGED
@@ -113,31 +113,42 @@ VOICE_DESCRIPTION = (
113
  "The audio is very clean with no background noise."
114
  )
115
 
116
def text_to_speech(text, lang="Español"):
    """Synthesize `text` to a temporary WAV file and return its path.

    Returns None when `text` is empty or when any step of the TTS
    pipeline raises (the error is printed, not re-raised).

    NOTE(review): `lang` is accepted but never used in this version.
    NOTE(review): `model_repo` is a Coqui-TTS identifier
    ("tts_models/es/...") but is loaded through
    ParlerTTSForConditionalGeneration — different model families, so
    this load most likely fails at runtime and every call ends in the
    `except` branch; confirm against the model hub.
    NOTE(review): the model and tokenizer are re-instantiated on every
    call, which is expensive — they should be loaded once at module level.
    """
    if not text:
        return None

    try:
        # Spanish-capable TTS, adjust for Catalan if a model exists
        model_repo = "tts_models/es/tacotron2-DDC"
        tts_model = ParlerTTSForConditionalGeneration.from_pretrained(
            model_repo, torch_dtype=torch_dtype
        ).to(device)
        tts_tokenizer = AutoTokenizer.from_pretrained(model_repo)
        # Hard-coded rate; presumably matches the model's output — TODO confirm.
        sampling_rate = 22050

        input_ids = tts_tokenizer(text, return_tensors="pt").input_ids.to(device)

        with torch.no_grad():
            generation = tts_model.generate(input_ids=input_ids)

        audio_array = generation.cpu().to(torch.float32).numpy().squeeze()
        # delete=False: the file path is handed to the caller (Gradio) and
        # must outlive this function.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            scipy.io.wavfile.write(f.name, rate=sampling_rate, data=audio_array)
        return f.name

    except Exception as e:
        print(f"TTS error: {e}")
        return None
 
141
  # ── LLM: HF Inference API + RAG ───────────────────────────────
142
  SYSTEM_PROMPT = """You are a warm, calm, and knowledgeable support assistant for caregivers of people with Alzheimer's disease.
143
 
@@ -185,31 +196,40 @@ def respond_to_message(message, history, lang="EspaΓ±ol"):
185
 
186
 
187
  # ── Pipelines ─────────────────────────────────────────────────
188
def voice_pipeline(audio_input, history):
    """One voice turn: transcribe -> detect language -> LLM reply -> TTS.

    Returns (history, audio_path_or_None, transcript_or_warning).

    NOTE(review): the UI wiring below passes three inputs
    (audio, history, lang_selector) into this two-parameter function,
    which would raise a TypeError when the voice button is clicked.
    """
    transcript = transcribe_audio(audio_input)
    if not transcript:
        # Transcription failed: keep history unchanged, show a warning.
        return history, None, "⚠️ Could not transcribe audio."

    # Reply and TTS language are auto-detected from the transcript.
    lang = detect_language(transcript)
    reply = respond_to_message(transcript, history, lang=lang)

    # Gradio state may start as None; normalize to a list before appending.
    history = history or []
    history.append({"role": "user", "content": transcript})
    history.append({"role": "assistant", "content": reply})

    audio_out = text_to_speech(reply, lang=lang)
    return history, audio_out, transcript
 
202
 
203
def text_pipeline(text_input, history, lang):
    """One typed turn: LLM reply + TTS audio.

    Returns (history, audio_path_or_None, "") — the empty string clears
    the transcript display. Blank / whitespace-only input is a no-op.
    """
    if not text_input.strip():
        return history, None, ""
    reply = respond_to_message(text_input, history, lang=lang)

    # Gradio state may start as None; normalize to a list before appending.
    history = history or []
    history.append({"role": "user", "content": text_input})
    history.append({"role": "assistant", "content": reply})
    audio_out = text_to_speech(reply, lang=lang)
    return history, audio_out, ""
212
 
 
213
  # ── Gradio UI ─────────────────────────────────────────────────
214
  with gr.Blocks(
215
  theme=gr.themes.Soft(
@@ -299,39 +319,46 @@ with gr.Blocks(
299
  )
300
 
301
  # Update chatbot function (dummy, required for Gradio workflow)
302
    def update_chatbot(history):
        # Identity pass-through: second step of each event chain, used
        # only to push the updated history state into the chatbot widget.
        return history

    # Button callbacks
    # NOTE(review): voice_pipeline above accepts only (audio_input, history),
    # but three inputs are wired here — the extra lang_selector would cause
    # a TypeError when the button is clicked.
    voice_btn.click(
        fn=voice_pipeline,
        inputs=[audio_input, chat_history, lang_selector],
        outputs=[chat_history, audio_output, transcript_display],
    ).then(
        fn=update_chatbot,
        inputs=[chat_history],
        outputs=[chatbot],
    )

    # Text button: same flow for typed input.
    text_btn.click(
        fn=text_pipeline,
        inputs=[text_input, chat_history, lang_selector],
        outputs=[chat_history, audio_output, transcript_display],
    ).then(
        fn=update_chatbot,
        inputs=[chat_history],
        outputs=[chatbot],
    )

    # Pressing Enter in the textbox behaves like the text button.
    text_input.submit(
        fn=text_pipeline,
        inputs=[text_input, chat_history, lang_selector],
        outputs=[chat_history, audio_output, transcript_display],
    ).then(
        fn=update_chatbot,
        inputs=[chat_history],
        outputs=[chatbot],
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
 
336
# Launch the Gradio app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()
 
113
  "The audio is very clean with no background noise."
114
  )
115
 
116
def text_to_speech(text, lang="es"):
    """Synthesize *text* with Parler-TTS and return a temp WAV path.

    Parameters
    ----------
    text : str
        The reply to speak. Falsy input short-circuits to None.
    lang : str
        "ca" selects the Catalan voice prompt; anything else falls back
        to the Spanish voice.

    Returns the path of a temporary ``.wav`` file, or None when `text`
    is empty or synthesis raises (the error is printed, not re-raised).
    """
    if not text:
        return None

    # Parler-TTS is steered by a natural-language voice description.
    voice_prompts = {
        "ca": (
            "Clara speaks Catalan with a calm, clear, empathetic voice. "
            "She speaks slowly, like a caring nurse."
        ),
        "es": (
            "Laura speaks Spanish with a warm, clear, empathetic voice. "
            "She speaks slowly, like a caring nurse."
        ),
    }
    # Unknown language codes fall back to the Spanish voice.
    voice_prompt = voice_prompts.get(lang, voice_prompts["es"])

    try:
        desc_ids = tts_tokenizer(voice_prompt, return_tensors="pt").input_ids.to(device)
        text_ids = tts_tokenizer(text, return_tensors="pt").input_ids.to(device)

        with torch.no_grad():
            waveform = tts_model.generate(
                input_ids=desc_ids,
                prompt_input_ids=text_ids,
            )

        samples = waveform.cpu().to(torch.float32).numpy().squeeze()

        # delete=False: Gradio reads the file after this function returns.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as out:
            scipy.io.wavfile.write(out.name, rate=sampling_rate, data=samples)
            return out.name

    except Exception as e:
        print(f"TTS error: {e}")
        return None
151
+
152
  # ── LLM: HF Inference API + RAG ───────────────────────────────
153
  SYSTEM_PROMPT = """You are a warm, calm, and knowledgeable support assistant for caregivers of people with Alzheimer's disease.
154
 
 
196
 
197
 
198
  # ── Pipelines ─────────────────────────────────────────────────
199
+ # ── Voice Pipeline with Language Support ─────────────────────────
200
def voice_pipeline(audio_input, history, tts_lang):
    """One voice turn: transcribe, answer via LLM+RAG, then speak.

    Parameters
    ----------
    audio_input : recorded audio from the Gradio microphone component.
    history : list[dict] | None — chat history in messages format.
    tts_lang : str — language code forwarded to text_to_speech.

    Returns (history, audio_path_or_None, quoted_transcript_or_warning).
    """
    transcript = transcribe_audio(audio_input)
    if not transcript:
        # Transcription failed: leave history untouched, surface a warning.
        return history, None, "⚠️ Could not transcribe audio. Please try again."

    answer = respond_to_message(transcript, history)

    # Gradio state starts as None; normalize before mutating.
    history = history or []
    history.extend(
        (
            {"role": "user", "content": transcript},
            {"role": "assistant", "content": answer},
        )
    )

    speech_path = text_to_speech(answer, tts_lang)
    return history, speech_path, f'"{transcript}"'
217
 
218
+ # ── Text Pipeline with Language Support ─────────────────────────
219
def text_pipeline(text_input, history, tts_lang):
    """One typed turn: answer via LLM+RAG, then speak the reply.

    Parameters
    ----------
    text_input : str — the user's typed message.
    history : list[dict] | None — chat history in messages format.
    tts_lang : str — language code forwarded to text_to_speech.

    Returns (history, audio_path_or_None, "") — the trailing empty string
    clears the transcript display. Blank input is a no-op.
    """
    # Guard: ignore empty / whitespace-only submissions.
    if not text_input.strip():
        return history, None, ""

    answer = respond_to_message(text_input, history)

    # Gradio state starts as None; normalize before mutating.
    history = history or []
    history.extend(
        (
            {"role": "user", "content": text_input},
            {"role": "assistant", "content": answer},
        )
    )

    return history, text_to_speech(answer, tts_lang), ""
231
 
232
+
233
  # ── Gradio UI ─────────────────────────────────────────────────
234
  with gr.Blocks(
235
  theme=gr.themes.Soft(
 
319
  )
320
 
321
  # Update chatbot function (dummy, required for Gradio workflow)
322
+ # Helper to refresh chatbot UI
323
    def update_chatbot(history):
        """Return *history* unchanged.

        Identity pass-through used as the second step of each event
        chain below, solely to push the updated history state into the
        chatbot component.
        """
        return history
325
+
326
+
327
    # 🎤 Voice button click: voice pipeline first, then refresh the chatbot.
    voice_btn.click(
        fn=voice_pipeline,
        inputs=[audio_input, chat_history, lang_selector],
        outputs=[chat_history, audio_output, transcript_display],
    ).then(
        fn=update_chatbot,
        inputs=[chat_history],
        outputs=[chatbot],
    )

    # ⌨️ Text button click: same flow for typed input.
    text_btn.click(
        fn=text_pipeline,
        inputs=[text_input, chat_history, lang_selector],
        outputs=[chat_history, audio_output, transcript_display],
    ).then(
        fn=update_chatbot,
        inputs=[chat_history],
        outputs=[chatbot],
    )

    # ⌨️ Press Enter to send text — mirrors the text button wiring.
    text_input.submit(
        fn=text_pipeline,
        inputs=[text_input, chat_history, lang_selector],
        outputs=[chat_history, audio_output, transcript_display],
    ).then(
        fn=update_chatbot,
        inputs=[chat_history],
        outputs=[chatbot],
    )
361
+
362
 
363
# Launch the Gradio app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()