hbchiu commited on
Commit
5b34b4f
·
verified ·
1 Parent(s): 7721466

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -130
app.py CHANGED
@@ -1,184 +1,165 @@
1
  # app.py β€” CareCompanion: Alzheimer's Caregiver Voice Assistant
2
  #
3
- # Stack (all via API β€” no local model loading, fast startup):
4
- # STT: openai/whisper-large-v3 via HF Inference API
5
- # LANG: papluca/xlm-roberta-base-... via HF Inference API
6
- # LLM: openai/gpt-oss-20b + FAISS RAG via HF Inference API
7
- # TTS: facebook/mms-tts-* via HF Inference API (per language)
8
  #
9
  # Secrets needed in HF Space Settings:
10
- # HF_TOKEN β€” your Hugging Face access token (required)
11
 
12
  import os
13
- import time
14
  import faiss
15
  import pickle
16
- import tempfile
17
  import numpy as np
18
  import gradio as gr
 
 
 
19
 
20
- from huggingface_hub import InferenceClient
21
  from sentence_transformers import SentenceTransformer
 
 
 
22
 
23
  # ── Auth ───────────────────────────────────────────────────────
24
  HF_TOKEN = os.environ.get("HF_TOKEN")
25
  if not HF_TOKEN:
26
  raise ValueError("HF_TOKEN is not set. Add it in Space Settings β†’ Repository Secrets.")
27
 
28
- # Single shared API client β€” reused for all calls
29
- api_client = InferenceClient(token=HF_TOKEN)
 
 
30
 
31
- # ── RAG: FAISS + multilingual embeddings ───────────────────────
32
- print("Loading FAISS index and embedding model...")
33
  index = faiss.read_index("alzheimers_index.faiss")
34
  with open("chunks.pkl", "rb") as f:
35
  chunks = pickle.load(f)
36
 
37
- # Multilingual embeddings β€” handles English, Spanish, Catalan and 50+ others
38
- # Much better than all-MiniLM-L6-v2 for multilingual content
39
- embed_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
40
- print(f"Loaded {len(chunks)} chunks into RAG")
41
 
42
  def retrieve_rag_context(query, k=5):
43
- """Retrieve top-k relevant chunks from FAISS index."""
44
  query_embedding = embed_model.encode([query])
45
  distances, indices = index.search(np.array(query_embedding), k)
46
  results = []
47
  for i in indices[0]:
48
  chunk = chunks[i]
49
  print(f" RAG chunk: id={chunk.get('id')} topic={chunk.get('topic')} lang={chunk.get('language')}")
50
- print(f" Preview: {chunk['text'][:100]}")
51
  results.append(chunk["text"])
52
  return "\n\n---\n\n".join(results)
53
 
54
- # ── Language detection ─────────────────────────────────────────
55
- # Maps detected language codes to MMS TTS model names
56
- MMS_MODELS = {
57
- "en": "facebook/mms-tts-eng",
58
- "es": "facebook/mms-tts-spa",
59
- "ca": "facebook/mms-tts-cat", # Catalan
60
- "fr": "facebook/mms-tts-fra",
61
- "de": "facebook/mms-tts-deu",
62
- "it": "facebook/mms-tts-ita",
63
- "pt": "facebook/mms-tts-por",
64
- }
65
- DEFAULT_TTS_MODEL = "facebook/mms-tts-eng"
66
-
67
- def detect_language(text):
68
- """Detect language of text using xlm-roberta model."""
69
- try:
70
- result = api_client.text_classification(
71
- text,
72
- model="papluca/xlm-roberta-base-language-detection"
73
- )
74
- lang_code = result[0].label[:2].lower()
75
- print(f" Detected language: {lang_code} (confidence: {result[0].score:.2f})")
76
- return lang_code
77
- except Exception as e:
78
- print(f" Language detection failed: {e} β€” defaulting to English")
79
- return "en"
80
 
81
- # ── STT: Whisper via HF Inference API ─────────────────────────
82
  def transcribe_audio(audio_path):
83
- """Transcribe audio file using Whisper via HF API."""
84
  if audio_path is None:
85
- return "", "en"
86
-
87
- t0 = time.time()
88
- try:
89
- with open(audio_path, "rb") as f:
90
- result = api_client.automatic_speech_recognition(
91
- f,
92
- model="openai/whisper-large-v3"
93
- )
94
- transcript = result.text.strip()
95
- print(f"STT done in {time.time()-t0:.1f}s: '{transcript}'")
96
-
97
- # Detect language from what was spoken
98
- lang = detect_language(transcript) if transcript else "en"
99
- return transcript, lang
100
-
101
- except Exception as e:
102
- print(f"STT error: {e}")
103
- return "", "en"
104
-
105
- # ── TTS: Facebook MMS via HF Inference API ────────────────────
106
- def text_to_speech(text, language="en"):
107
- """Convert text to speech using Facebook MMS β€” proper per-language voices."""
 
 
 
 
 
 
 
 
 
108
  if not text:
109
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
- t0 = time.time()
112
- tts_model = MMS_MODELS.get(language, DEFAULT_TTS_MODEL)
113
- print(f"TTS using model: {tts_model}")
114
 
115
- try:
116
- audio_bytes = api_client.text_to_speech(
117
- text,
118
- model=tts_model
119
- )
120
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
121
- f.write(audio_bytes)
122
- print(f"TTS done in {time.time()-t0:.1f}s")
123
  return f.name
124
 
125
  except Exception as e:
126
- print(f"TTS error ({tts_model}): {e}")
127
- # Try fallback to English if language-specific model fails
128
- try:
129
- audio_bytes = api_client.text_to_speech(text, model=DEFAULT_TTS_MODEL)
130
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
131
- f.write(audio_bytes)
132
- return f.name
133
- except Exception as e2:
134
- print(f"TTS fallback also failed: {e2}")
135
- return None
136
 
137
  # ── LLM: HF Inference API + RAG ───────────────────────────────
138
  SYSTEM_PROMPT = """You are a warm, calm, and knowledgeable support assistant for caregivers of people with Alzheimer's disease.
139
 
140
  Your role is to:
141
  - Provide clear, compassionate guidance for caregiving challenges
142
- - Suggest relevant local support services when available in the retrieved context below
143
  - Give practical, actionable advice
144
- - Keep responses concise β€” under 100 words β€” so they are easy to listen to
145
  - Always be encouraging and non-judgmental
146
  - Respond in the same language the user wrote in
147
 
148
  If asked about local resources, ONLY reference services mentioned in the retrieved context. Do not invent services.
149
- If no relevant local services are found in the context, say so honestly.
150
  Always remind caregivers that asking for help is a sign of strength, not weakness."""
151
 
152
- def respond_to_message(message, history, detected_lang="en"):
153
- """Generate a response using RAG context + LLM."""
154
  if not message.strip():
155
  return ""
156
 
157
- t0 = time.time()
158
 
159
- # Retrieve relevant chunks from FAISS
160
  rag_context = retrieve_rag_context(message)
161
  full_system = (
162
  f"{SYSTEM_PROMPT}\n\n"
163
- f"User's language: {detected_lang}\n\n"
164
  f"=== RETRIEVED KNOWLEDGE BASE CONTEXT ===\n{rag_context}\n"
165
  f"========================================\n"
166
  f"Only use the above context for local resource recommendations."
167
  )
168
 
169
- # Build message history
170
  messages = [{"role": "system", "content": full_system}]
171
- for h in history[-6:]: # keep last 6 turns
172
  if isinstance(h, dict):
173
  messages.append({"role": h["role"], "content": h["content"]})
174
  messages.append({"role": "user", "content": message})
175
 
176
- # Stream response from LLM
177
  response = ""
178
  try:
179
- for chunk in api_client.chat_completion(
180
  messages,
181
- model="openai/gpt-oss-20b",
182
  max_tokens=350,
183
  stream=True,
184
  temperature=0.7,
@@ -186,53 +167,40 @@ def respond_to_message(message, history, detected_lang="en"):
186
  ):
187
  if chunk.choices and chunk.choices[0].delta.content:
188
  response += chunk.choices[0].delta.content
189
-
190
- print(f"LLM done in {time.time()-t0:.1f}s")
191
  return response.strip()
192
-
193
  except Exception as e:
194
  print(f"LLM error: {e}")
195
  return "I'm sorry, I had trouble generating a response. Please try again."
196
 
197
- # ── Voice pipeline: mic β†’ STT β†’ LLM+RAG β†’ TTS ────────────────
198
  def voice_pipeline(audio_input, history):
199
- t_start = time.time()
200
-
201
- transcript, lang = transcribe_audio(audio_input)
202
  if not transcript:
203
  return history, None, "⚠️ Could not transcribe audio. Please try again."
204
 
205
- reply = respond_to_message(transcript, history, lang)
206
 
207
  history = history or []
208
  history.append({"role": "user", "content": transcript})
209
  history.append({"role": "assistant", "content": reply})
210
 
211
- audio_out = text_to_speech(reply, language=lang)
 
212
 
213
- print(f"Total voice pipeline: {time.time()-t_start:.1f}s")
214
- return history, audio_out, f'"{transcript}" [{lang}]'
215
-
216
- # ── Text pipeline: text β†’ LLM+RAG β†’ TTS ──────────────────────
217
  def text_pipeline(text_input, history):
218
  if not text_input.strip():
219
  return history, None, ""
220
 
221
- t_start = time.time()
222
-
223
- lang = detect_language(text_input)
224
- reply = respond_to_message(text_input, history, lang)
225
 
226
  history = history or []
227
  history.append({"role": "user", "content": text_input})
228
  history.append({"role": "assistant", "content": reply})
229
 
230
- audio_out = text_to_speech(reply, language=lang)
231
-
232
- print(f"Total text pipeline: {time.time()-t_start:.1f}s")
233
  return history, audio_out, ""
234
 
235
- # ── Gradio UI ──────────────────────────────────────────────────
236
  with gr.Blocks(
237
  theme=gr.themes.Soft(
238
  primary_hue="green",
@@ -246,10 +214,9 @@ with gr.Blocks(
246
 
247
  gr.Markdown(
248
  """
249
- # Sherpa
250
- ### Smart Support for Alzheimer's Caregivers in Barcelona
251
- *Ask anything by voice or text β€” in English, Spanish, or Catalan.*
252
- *Responses draw from a curated local knowledge base.*
253
  """
254
  )
255
 
@@ -284,7 +251,7 @@ with gr.Blocks(
284
  gr.Markdown("---")
285
  gr.Markdown("### ⌨️ Text Input")
286
  text_input = gr.Textbox(
287
- placeholder="Or type your question here… (any language)",
288
  label="",
289
  lines=3,
290
  )
@@ -305,7 +272,7 @@ with gr.Blocks(
305
  """
306
  ---
307
  *Responses are AI-generated and do not replace professional medical advice.
308
- In emergencies, call 112 (EU) or your local emergency services.*
309
  """
310
  )
311
 
 
1
  # app.py β€” CareCompanion: Alzheimer's Caregiver Voice Assistant
2
  #
3
+ # Stack:
4
+ # STT: distil-whisper/distil-large-v3 (local, fast)
5
+ # LLM: openai/gpt-oss-20b + FAISS RAG (HF Inference API)
6
+ # TTS: parler-tts/parler-tts-mini-v1 (local, neutral American voice)
 
7
  #
8
  # Secrets needed in HF Space Settings:
9
+ # HF_TOKEN β€” your Hugging Face access token
10
 
11
  import os
 
12
  import faiss
13
  import pickle
 
14
  import numpy as np
15
  import gradio as gr
16
+ import torch
17
+ import scipy.io.wavfile
18
+ import tempfile
19
 
 
20
  from sentence_transformers import SentenceTransformer
21
+ from huggingface_hub import InferenceClient
22
+ from transformers import AutoTokenizer, pipeline
23
+ from parler_tts import ParlerTTSForConditionalGeneration
24
 
25
  # ── Auth ───────────────────────────────────────────────────────
26
  HF_TOKEN = os.environ.get("HF_TOKEN")
27
  if not HF_TOKEN:
28
  raise ValueError("HF_TOKEN is not set. Add it in Space Settings β†’ Repository Secrets.")
29
 
30
# ── Device ────────────────────────────────────────────────────
# Prefer GPU when available. fp16 halves memory on CUDA; fp32 on CPU,
# where half precision is slow or unsupported for many ops.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
print(f"Running on: {device}")
34
 
35
# ── RAG: FAISS index ──────────────────────────────────────────
print("Loading FAISS index...")
index = faiss.read_index("alzheimers_index.faiss")  # prebuilt vector index shipped with the Space
with open("chunks.pkl", "rb") as f:
    # NOTE(review): pickle is only safe because this file ships with the repo;
    # never load a user-supplied pickle.
    chunks = pickle.load(f)

# The query embedder must match the model used to BUILD the index —
# TODO confirm alzheimers_index.faiss was built with all-MiniLM-L6-v2.
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
print(f"Loaded {len(chunks)} chunks")
 
 
43
 
44
def retrieve_rag_context(query, k=5):
    """Return the top-k most relevant knowledge-base chunks for *query*.

    Encodes the query with the module-level sentence-transformer, searches
    the module-level FAISS index, and joins the matched chunk texts with a
    visible separator so the LLM can tell the chunks apart.
    """
    query_vec = np.array(embed_model.encode([query]))
    _, hit_ids = index.search(query_vec, k)

    texts = []
    for hit_id in hit_ids[0]:
        hit = chunks[hit_id]
        # Log which chunks were retrieved — useful for debugging RAG quality.
        print(f" RAG chunk: id={hit.get('id')} topic={hit.get('topic')} lang={hit.get('language')}")
        texts.append(hit["text"])
    return "\n\n---\n\n".join(texts)
53
 
54
# ── STT: Distil-Whisper ───────────────────────────────────────
# Loaded once at startup; reused by transcribe_audio() for every request.
print("Loading Whisper STT model...")
stt_pipe = pipeline(
    "automatic-speech-recognition",
    model="distil-whisper/distil-large-v3",
    torch_dtype=torch_dtype,
    device=device,
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
 
63
def transcribe_audio(audio_path):
    """Transcribe an audio file with the local Distil-Whisper pipeline.

    Args:
        audio_path: Path to the recorded audio file, or None when the user
            submitted without recording anything.

    Returns:
        The stripped transcript, or "" when there is no audio or speech
        recognition fails (callers treat "" as "could not transcribe").
    """
    if audio_path is None:
        return ""
    try:
        result = stt_pipe(
            audio_path,
            generate_kwargs={"task": "transcribe"},
            return_timestamps=False,
        )
    except Exception as e:
        # Don't let a bad/corrupt recording crash the Gradio callback —
        # voice_pipeline already shows a "could not transcribe" message for "".
        print(f"STT error: {e}")
        return ""
    transcript = result["text"].strip()
    print(f"Transcript: '{transcript}'")
    return transcript
74
+
75
# ── TTS: Parler TTS mini v1 (neutral American voice) ─────────
# Using base mini-v1 model — NOT jenny (which is Irish)
# Laura is a warm, calm American speaker in this model
print("Loading Parler TTS model...")
TTS_REPO = "parler-tts/parler-tts-mini-v1"

tts_model = ParlerTTSForConditionalGeneration.from_pretrained(
    TTS_REPO,
    torch_dtype=torch_dtype,  # fp16 on CUDA; output is cast back to fp32 before the WAV write
    low_cpu_mem_usage=True,
).to(device)

tts_tokenizer = AutoTokenizer.from_pretrained(TTS_REPO)
# Output sample rate comes from the model's audio codec configuration.
sampling_rate = tts_model.audio_encoder.config.sampling_rate

# Natural-language "voice prompt": Parler conditions speaker identity and
# speaking style on this description text at generation time.
VOICE_DESCRIPTION = (
    "Laura speaks with a warm, calm and empathetic American accent. "
    "She speaks clearly at a gentle, measured pace, like a caring nurse. "
    "The audio is very clean with no background noise."
)
95
+
96
def text_to_speech(text):
    """Synthesize *text* into a WAV file with Parler-TTS.

    Returns the path of a temporary .wav file, or None when *text* is empty
    or synthesis fails (the chat text is still shown without audio).
    """
    if not text:
        return None
    try:
        # Tokenize the fixed voice description and the reply text separately:
        # Parler conditions the speaker on the former and speaks the latter.
        desc_ids = tts_tokenizer(VOICE_DESCRIPTION, return_tensors="pt").input_ids.to(device)
        text_ids = tts_tokenizer(text, return_tensors="pt").input_ids.to(device)

        with torch.no_grad():
            waveform = tts_model.generate(
                input_ids=desc_ids,
                prompt_input_ids=text_ids,
            )

        # The model may run in float16; scipy needs float32 to write the WAV.
        samples = waveform.cpu().to(torch.float32).numpy().squeeze()

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as out:
            scipy.io.wavfile.write(out.name, rate=sampling_rate, data=samples)
        return out.name

    except Exception as e:
        print(f"TTS error: {e}")
        return None  # best-effort: skip the audio, the text response still shows
 
 
 
 
 
 
 
 
123
 
124
  # ── LLM: HF Inference API + RAG ───────────────────────────────
125
  SYSTEM_PROMPT = """You are a warm, calm, and knowledgeable support assistant for caregivers of people with Alzheimer's disease.
126
 
127
  Your role is to:
128
  - Provide clear, compassionate guidance for caregiving challenges
129
+ - Suggest relevant local support services when available in the retrieved context
130
  - Give practical, actionable advice
131
+ - Keep responses concise β€” under 120 words β€” so they are easy to listen to
132
  - Always be encouraging and non-judgmental
133
  - Respond in the same language the user wrote in
134
 
135
  If asked about local resources, ONLY reference services mentioned in the retrieved context. Do not invent services.
136
+ If no relevant local services are in the context, say so honestly.
137
  Always remind caregivers that asking for help is a sign of strength, not weakness."""
138
 
139
+ def respond_to_message(message, history):
 
140
  if not message.strip():
141
  return ""
142
 
143
+ client = InferenceClient(token=HF_TOKEN, model="openai/gpt-oss-20b")
144
 
 
145
  rag_context = retrieve_rag_context(message)
146
  full_system = (
147
  f"{SYSTEM_PROMPT}\n\n"
 
148
  f"=== RETRIEVED KNOWLEDGE BASE CONTEXT ===\n{rag_context}\n"
149
  f"========================================\n"
150
  f"Only use the above context for local resource recommendations."
151
  )
152
 
 
153
  messages = [{"role": "system", "content": full_system}]
154
+ for h in history[-6:]:
155
  if isinstance(h, dict):
156
  messages.append({"role": h["role"], "content": h["content"]})
157
  messages.append({"role": "user", "content": message})
158
 
 
159
  response = ""
160
  try:
161
+ for chunk in client.chat_completion(
162
  messages,
 
163
  max_tokens=350,
164
  stream=True,
165
  temperature=0.7,
 
167
  ):
168
  if chunk.choices and chunk.choices[0].delta.content:
169
  response += chunk.choices[0].delta.content
 
 
170
  return response.strip()
 
171
  except Exception as e:
172
  print(f"LLM error: {e}")
173
  return "I'm sorry, I had trouble generating a response. Please try again."
174
 
175
+ # ── Pipelines ─────────────────────────────────────────────────
176
# ── Pipelines ─────────────────────────────────────────────────
def voice_pipeline(audio_input, history):
    """Mic → STT → LLM+RAG → TTS round trip for the voice input.

    Args:
        audio_input: Path to the recorded audio (or None).
        history: Chat history as a list of {"role", "content"} dicts (or None).

    Returns:
        (updated history, reply audio path or None, status/transcript text).
    """
    # Normalize BEFORE calling respond_to_message — it slices history[-6:],
    # which raises TypeError on the first turn when Gradio passes None.
    history = history or []

    transcript = transcribe_audio(audio_input)
    if not transcript:
        return history, None, "⚠️ Could not transcribe audio. Please try again."

    reply = respond_to_message(transcript, history)

    history.append({"role": "user", "content": transcript})
    history.append({"role": "assistant", "content": reply})

    audio_out = text_to_speech(reply)
    return history, audio_out, f'"{transcript}"'
189
 
 
 
 
 
190
def text_pipeline(text_input, history):
    """Typed text → LLM+RAG → TTS for the text input box.

    Args:
        text_input: The user's typed question.
        history: Chat history as a list of {"role", "content"} dicts (or None).

    Returns:
        (updated history, reply audio path or None, "" to clear the textbox).
    """
    if not text_input.strip():
        return history, None, ""

    # Normalize BEFORE calling respond_to_message — it slices history[-6:],
    # which raises TypeError on the first turn when Gradio passes None.
    history = history or []
    reply = respond_to_message(text_input, history)

    history.append({"role": "user", "content": text_input})
    history.append({"role": "assistant", "content": reply})

    audio_out = text_to_speech(reply)
    return history, audio_out, ""
202
 
203
+ # ── Gradio UI ─────────────────────────────────────────────────
204
  with gr.Blocks(
205
  theme=gr.themes.Soft(
206
  primary_hue="green",
 
214
 
215
  gr.Markdown(
216
  """
217
+ # 🀍 CareCompanion
218
+ ### Alzheimer's Caregiver Support Assistant
219
+ *Ask anything β€” by voice or text. Responses draw from a curated Alzheimer's knowledge base.*
 
220
  """
221
  )
222
 
 
251
  gr.Markdown("---")
252
  gr.Markdown("### ⌨️ Text Input")
253
  text_input = gr.Textbox(
254
+ placeholder="Or type your question here…",
255
  label="",
256
  lines=3,
257
  )
 
272
  """
273
  ---
274
  *Responses are AI-generated and do not replace professional medical advice.
275
+ In emergencies, call 112 or your local emergency services.*
276
  """
277
  )
278