Jatila committed on
Commit
23c8db3
·
verified ·
1 Parent(s): e2c032e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -97
app.py CHANGED
@@ -58,6 +58,15 @@ def retrieve_rag_context(query, k=5):
58
  results.append(chunk["text"])
59
  return "\n\n---\n\n".join(results)
60
 
 
 
 
 
 
 
 
 
 
61
  # ── STT: Distil-Whisper ───────────────────────────────────────
62
  print("Loading Whisper STT model...")
63
  stt_pipe = pipeline(
@@ -70,15 +79,21 @@ stt_pipe = pipeline(
70
  def transcribe_audio(audio_path):
71
  if audio_path is None:
72
  return ""
73
- result = stt_pipe(
74
- audio_path,
75
- generate_kwargs={"task": "transcribe"},
76
- return_timestamps=False,
77
- )
78
  transcript = result["text"].strip()
79
- print(f"Transcript: '{transcript}'")
80
  return transcript
81
 
 
 
 
 
 
 
 
 
 
 
 
82
  # ── TTS: Parler TTS mini v1 (neutral American voice) ─────────
83
  print("Loading Parler TTS model...")
84
  TTS_REPO = "parler-tts/parler-tts-mini-v1"
@@ -98,34 +113,31 @@ VOICE_DESCRIPTION = (
98
  "The audio is very clean with no background noise."
99
  )
100
 
101
- def text_to_speech(text):
102
  if not text:
103
  return None
 
104
  try:
105
- input_ids = tts_tokenizer(
106
- VOICE_DESCRIPTION, return_tensors="pt"
107
- ).input_ids.to(device)
108
- prompt_ids = tts_tokenizer(
109
- text, return_tensors="pt"
110
- ).input_ids.to(device)
 
 
 
111
 
112
  with torch.no_grad():
113
- generation = tts_model.generate(
114
- input_ids=input_ids,
115
- prompt_input_ids=prompt_ids,
116
- )
117
 
118
- # ← float16 fix: convert to float32 before writing WAV
119
  audio_array = generation.cpu().to(torch.float32).numpy().squeeze()
120
-
121
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
122
  scipy.io.wavfile.write(f.name, rate=sampling_rate, data=audio_array)
123
  return f.name
124
-
125
  except Exception as e:
126
  print(f"TTS error: {e}")
127
- return None # silently skip audio, text response still shows
128
-
129
  # ── LLM: HF Inference API + RAG ───────────────────────────────
130
  SYSTEM_PROMPT = """You are a warm, calm, and knowledgeable support assistant for caregivers of people with Alzheimer's disease.
131
 
@@ -140,19 +152,14 @@ If asked about local resources, ONLY reference services mentioned in the retriev
140
  If no relevant local services are in the context, say so honestly.
141
  Always remind caregivers that asking for help is a sign of strength, not weakness."""
142
 
143
- def respond_to_message(message, history):
144
  if not message.strip():
145
  return ""
146
 
147
  client = InferenceClient(token=HF_TOKEN, model="openai/gpt-oss-20b")
148
 
149
  rag_context = retrieve_rag_context(message)
150
- full_system = (
151
- f"{SYSTEM_PROMPT}\n\n"
152
- f"=== RETRIEVED KNOWLEDGE BASE CONTEXT ===\n{rag_context}\n"
153
- f"========================================\n"
154
- f"Only use the above context for local resource recommendations."
155
- )
156
 
157
  messages = [{"role": "system", "content": full_system}]
158
  for h in history[-6:]:
@@ -174,102 +181,65 @@ def respond_to_message(message, history):
174
  return response.strip()
175
  except Exception as e:
176
  print(f"LLM error: {e}")
177
- return "I'm sorry, I had trouble generating a response. Please try again."
 
178
 
179
  # ── Pipelines ─────────────────────────────────────────────────
180
  def voice_pipeline(audio_input, history):
181
  transcript = transcribe_audio(audio_input)
182
  if not transcript:
183
- return history, None, "⚠️ Could not transcribe audio. Please try again."
184
 
185
- reply = respond_to_message(transcript, history)
 
186
 
187
  history = history or []
188
  history.append({"role": "user", "content": transcript})
189
  history.append({"role": "assistant", "content": reply})
190
 
191
- audio_out = text_to_speech(reply)
192
- return history, audio_out, f'"{transcript}"'
193
 
194
- def text_pipeline(text_input, history):
195
  if not text_input.strip():
196
  return history, None, ""
197
-
198
- reply = respond_to_message(text_input, history)
199
-
200
  history = history or []
201
  history.append({"role": "user", "content": text_input})
202
  history.append({"role": "assistant", "content": reply})
203
-
204
- audio_out = text_to_speech(reply)
205
  return history, audio_out, ""
206
 
207
  # ── Gradio UI ─────────────────────────────────────────────────
208
- with gr.Blocks(
209
- theme=gr.themes.Soft(
210
- primary_hue="green",
211
- neutral_hue="slate",
212
- font=gr.themes.GoogleFont("DM Sans"),
213
- ),
214
- title="CareCompanion",
215
- ) as demo:
216
-
217
  chat_history = gr.State([])
218
 
219
- gr.Markdown(
220
- """
221
- # SherpaAI
222
- ### Smart support for AD caregivers in Barcelona
223
- """
224
- )
225
 
226
  with gr.Row():
227
  with gr.Column(scale=2):
228
- chatbot = gr.Chatbot(
229
- label="Conversation",
230
- height=420,
231
- type="messages",
232
- show_label=False,
233
- bubble_full_width=False,
234
- )
235
- audio_output = gr.Audio(
236
- label="πŸ”Š Voice Response",
237
- autoplay=True,
238
- show_download_button=False,
239
- )
240
 
241
  with gr.Column(scale=1):
 
 
 
242
  gr.Markdown("### 🎀 Voice Input")
243
- audio_input = gr.Audio(
244
- sources=["microphone"],
245
- type="filepath",
246
- label="Record your question",
247
- )
248
- voice_btn = gr.Button(
249
- "🎀 Send Voice Message",
250
- variant="primary",
251
- size="lg",
252
- )
253
-
254
- gr.Markdown("---")
255
- gr.Markdown("### ⌨️ Text Input")
256
- text_input = gr.Textbox(
257
- placeholder="Or type your question here…",
258
- label="",
259
- lines=3,
260
- )
261
- text_btn = gr.Button(
262
- "➀ Send Text Message",
263
- variant="secondary",
264
- size="lg",
265
- )
266
-
267
- transcript_display = gr.Textbox(
268
- label="πŸ“ What I heard",
269
- interactive=False,
270
- lines=2,
271
- placeholder="Your transcribed speech will appear here…",
272
- )
273
 
274
  gr.Markdown(
275
  """
 
58
  results.append(chunk["text"])
59
  return "\n\n---\n\n".join(results)
60
 
61
# ── SYSTEM PROMPTS ─────────────────────────────
def get_system_prompt(lang="Español"):
    """Return the caregiver-assistant system prompt for the given UI language.

    Args:
        lang: Either "Català" or "Español". Any other value falls back to
            the Spanish prompt (the default).

    Returns:
        The system-prompt string in the requested language.
    """
    catalan_prompt = """Ets un assistent càlid i empàtic per a cuidadors de persones amb Alzheimer a Barcelona.
Proporciona orientació clara, menciona serveis locals si existeixen en el context i mantén les respostes breus i comprensibles."""
    spanish_prompt = """Eres un asistente cálido y empático para cuidadores de personas con Alzheimer en Barcelona.
Proporciona orientación clara, menciona recursos locales si existen en el contexto y mantén las respuestas breves y comprensibles."""
    return catalan_prompt if lang == "Català" else spanish_prompt
+
70
  # ── STT: Distil-Whisper ───────────────────────────────────────
71
  print("Loading Whisper STT model...")
72
  stt_pipe = pipeline(
 
79
def transcribe_audio(audio_path):
    """Transcribe a recorded audio file with the Whisper STT pipeline.

    Args:
        audio_path: Filesystem path of the recording, or None when no
            audio was captured.

    Returns:
        The stripped transcript text, or "" when *audio_path* is None.
    """
    if audio_path is None:
        return ""
    stt_output = stt_pipe(
        audio_path,
        generate_kwargs={"task": "transcribe"},
        return_timestamps=False,
    )
    return stt_output["text"].strip()
85
 
86
def detect_language(text):
    """Best-effort language detection used to route prompts and TTS.

    Args:
        text: The user utterance (transcript or typed input).

    Returns:
        "Català" when langdetect reports Catalan ("ca"); "Español" for
        everything else, including detection failures (langdetect raises
        on empty or ambiguous input).
    """
    try:
        # Only Catalan gets special handling; Spanish is the safe default
        # for every other detected code.
        return "Català" if detect(text) == "ca" else "Español"
    except Exception:  # narrow from bare except: don't swallow KeyboardInterrupt/SystemExit
        return "Español"
97
  # ── TTS: Parler TTS mini v1 (neutral American voice) ─────────
98
  print("Loading Parler TTS model...")
99
  TTS_REPO = "parler-tts/parler-tts-mini-v1"
 
113
  "The audio is very clean with no background noise."
114
  )
115
 
116
def text_to_speech(text, lang="Español"):
    """Synthesize *text* to a temporary WAV file with the Parler TTS model.

    Args:
        text: Text to speak. Empty/None yields no audio.
        lang: Accepted for interface compatibility with the pipelines;
            the loaded Parler voice is description-driven, so this is not
            yet used for voice selection. TODO(review): pick a
            Catalan/Spanish-capable voice per *lang*.

    Returns:
        Path of a temp .wav file, or None for empty text or on any TTS
        failure (the text response is still shown in the chat).
    """
    if not text:
        return None

    try:
        # Reuse the model/tokenizer loaded once at import time. The
        # previous per-call load was both slow and pointed at a Coqui-TTS
        # repo id incompatible with ParlerTTSForConditionalGeneration.
        input_ids = tts_tokenizer(
            VOICE_DESCRIPTION, return_tensors="pt"
        ).input_ids.to(device)
        prompt_ids = tts_tokenizer(text, return_tensors="pt").input_ids.to(device)

        with torch.no_grad():
            # Parler conditions on the voice description (input_ids) and
            # the spoken text (prompt_input_ids) separately.
            generation = tts_model.generate(
                input_ids=input_ids,
                prompt_input_ids=prompt_ids,
            )

        # float16 → float32 before writing: scipy cannot write fp16 WAVs.
        audio_array = generation.cpu().to(torch.float32).numpy().squeeze()

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            scipy.io.wavfile.write(f.name, rate=sampling_rate, data=audio_array)
            return f.name
    except Exception as e:
        print(f"TTS error: {e}")
        return None  # silently skip audio; text response still shows
 
141
  # ── LLM: HF Inference API + RAG ───────────────────────────────
142
  SYSTEM_PROMPT = """You are a warm, calm, and knowledgeable support assistant for caregivers of people with Alzheimer's disease.
143
 
 
152
  If no relevant local services are in the context, say so honestly.
153
  Always remind caregivers that asking for help is a sign of strength, not weakness."""
154
 
155
+ def respond_to_message(message, history, lang="EspaΓ±ol"):
156
  if not message.strip():
157
  return ""
158
 
159
  client = InferenceClient(token=HF_TOKEN, model="openai/gpt-oss-20b")
160
 
161
  rag_context = retrieve_rag_context(message)
162
+ full_system = f"{get_system_prompt(lang)}\n\n=== RETRIEVED CONTEXT ===\n{rag_context}"
 
 
 
 
 
163
 
164
  messages = [{"role": "system", "content": full_system}]
165
  for h in history[-6:]:
 
181
  return response.strip()
182
  except Exception as e:
183
  print(f"LLM error: {e}")
184
+ return "Ho sento, no puc generar una resposta en aquest moment." if lang=="CatalΓ " else "Lo siento, no puedo generar una respuesta en este momento."
185
+
186
 
187
  # ── Pipelines ─────────────────────────────────────────────────
188
def voice_pipeline(audio_input, history):
    """Run one full voice turn: STT → language detection → LLM → TTS.

    Returns (history, audio_path_or_None, transcript_or_warning) for the
    Gradio outputs.
    """
    spoken_text = transcribe_audio(audio_input)
    if not spoken_text:
        return history, None, "⚠️ Could not transcribe audio."

    user_lang = detect_language(spoken_text)
    assistant_reply = respond_to_message(spoken_text, history, lang=user_lang)

    if not history:
        history = []
    history.extend(
        [
            {"role": "user", "content": spoken_text},
            {"role": "assistant", "content": assistant_reply},
        ]
    )

    return history, text_to_speech(assistant_reply, lang=user_lang), spoken_text
202
 
203
def text_pipeline(text_input, history, lang):
    """Run one typed-text turn: LLM → TTS, using the UI-selected language.

    Returns (history, audio_path_or_None, "") — the empty string clears
    the textbox in the Gradio outputs.
    """
    if not text_input.strip():
        return history, None, ""

    assistant_reply = respond_to_message(text_input, history, lang=lang)

    if not history:
        history = []
    history.extend(
        [
            {"role": "user", "content": text_input},
            {"role": "assistant", "content": assistant_reply},
        ]
    )

    return history, text_to_speech(assistant_reply, lang=lang), ""
212
 
213
  # ── Gradio UI ─────────────────────────────────────────────────
214
+ with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
215
  chat_history = gr.State([])
216
 
217
+ gr.Markdown("## SherpaAI β€” Suport intelΒ·ligent per a cuidadors d’Alzheimer a Barcelona")
 
 
 
 
 
218
 
219
  with gr.Row():
220
  with gr.Column(scale=2):
221
+ chatbot = gr.Chatbot(label="Conversation", height=420)
222
+ audio_output = gr.Audio(label="πŸ”Š Voice Response", autoplay=True)
 
 
 
 
 
 
 
 
 
 
223
 
224
  with gr.Column(scale=1):
225
+ lang_selector = gr.Dropdown(["EspaΓ±ol", "CatalΓ "], label="Language", value="EspaΓ±ol")
226
+ text_input = gr.Textbox(placeholder="Escriu la teva pregunta aquí…", lines=3)
227
+ text_btn = gr.Button("Enviar / Send")
228
  gr.Markdown("### 🎀 Voice Input")
229
+ audio_input = gr.Audio(sources=["microphone"], type="filepath", label="Record your question")
230
+ voice_btn = gr.Button("🎀 Send Voice Message")
231
+
232
+ text_btn.click(
233
+ fn=text_pipeline,
234
+ inputs=[text_input, chat_history, lang_selector],
235
+ outputs=[chat_history, audio_output, text_input],
236
+ )
237
+
238
+ voice_btn.click(
239
+ fn=voice_pipeline,
240
+ inputs=[audio_input, chat_history],
241
+ outputs=[chat_history, audio_output, text_input],
242
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
  gr.Markdown(
245
  """