hbchiu committed on
Commit
caac5cc
·
verified ·
1 Parent(s): e8c2fbf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -54
app.py CHANGED
@@ -1,13 +1,3 @@
1
- # app.py β€” CareCompanion: Alzheimer's Caregiver Voice Assistant
2
- #
3
- # Stack:
4
- # STT: distil-whisper/distil-large-v3 (local, fast)
5
- # LLM: openai/gpt-oss-20b + FAISS RAG (HF Inference API)
6
- # TTS: parler-tts/parler-tts-mini-v1 (local, neutral American voice)
7
- #
8
- # Secrets needed in HF Space Settings:
9
- # HF_TOKEN β€” your Hugging Face access token
10
-
11
  import os
12
  import faiss
13
  import pickle
@@ -20,7 +10,7 @@ import tempfile
20
  from huggingface_hub import hf_hub_download
21
  from sentence_transformers import SentenceTransformer
22
  from huggingface_hub import InferenceClient
23
- from transformers import VitsModel, AutoTokenizer, pipeline, SpeechT5HifiGan
24
 
25
 
26
  # ── Auth ───────────────────────────────────────────────────────
@@ -95,59 +85,58 @@ def detect_language(text):
95
  try:
96
  lang = detect(text)
97
  if lang == "ca":
98
- return "CatalΓ "
99
  elif lang == "es":
100
- return "EspaΓ±ol"
 
 
101
  else:
102
- return "EspaΓ±ol"
103
  except:
104
- return "EspaΓ±ol"
105
 
106
  # ── TTS: Parler TTS mini v1 (neutral catalΓ /spanish voice) ─────────
107
- #print("Loading MMS TTS models...")
108
- #tts_models, tts_tokenizers = {}, {}
109
- #for lang_code, repo in {"en": "facebook/mms-tts-eng", "es": "facebook/mms-tts-spa", "ca": "facebook/mms-tts-cat"}.items():
110
- # tts_tokenizers[lang_code] = AutoTokenizer.from_pretrained(repo)
111
- # tts_models[lang_code] = VitsModel.from_pretrained(repo).to(device)
112
- # tts_models[lang_code].eval()
113
  print("Loading TTS models...")
114
 
115
- # Kokoro for English and Spanish
116
- from kokoro import KPipeline
117
- kokoro_en = KPipeline(lang_code='en')
118
- kokoro_es = KPipeline(lang_code='es')
119
-
120
- # Matxa (BSC) for Catalan
121
- tts_tokenizers, tts_models = {}, {}
122
- tts_tokenizers["ca"] = AutoTokenizer.from_pretrained("projecte-aina/matxa-tts-cat-multiaccent")
123
- tts_models["ca"] = VitsModel.from_pretrained("projecte-aina/matxa-tts-cat-multiaccent").to(device)
124
- tts_models["ca"].eval()
125
-
126
  def text_to_speech(text, lang="es"):
127
- if not text:
128
  return None
129
  try:
130
- if lang == "ca":
131
- inputs = tts_tokenizers["ca"](text, return_tensors="pt").to(device)
132
- with torch.no_grad():
133
- audio = tts_models["ca"](**inputs).waveform
134
- audio_int16 = (audio.squeeze().cpu().float().numpy() * 32767).clip(-32768, 32767).astype("int16")
135
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
136
- scipy.io.wavfile.write(f.name, rate=tts_models["ca"].config.sampling_rate, data=audio_int16)
137
- return f.name
138
- else:
139
- pipeline = kokoro_en if lang == "en" else kokoro_es
140
- voice = "af_heart" if lang == "en" else "ef_dora"
141
- audio_chunks = []
142
- for _, _, audio in pipeline(text, voice=voice):
143
- audio_chunks.append(audio)
144
- audio_np = np.concatenate(audio_chunks)
145
- audio_int16 = (audio_np * 32767).clip(-32768, 32767).astype("int16")
146
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
147
- scipy.io.wavfile.write(f.name, rate=24000, data=audio_int16)
148
- return f.name
 
 
 
 
 
 
 
 
149
  except Exception as e:
150
- print(f"TTS error ({lang}): {e}")
151
  return None
152
 
153
 
@@ -165,7 +154,7 @@ If asked about local resources, ONLY reference services mentioned in the retriev
165
  If no relevant local services are in the context, say so honestly.
166
  Always remind caregivers that asking for help is a sign of strength, not weakness."""
167
 
168
- def respond_to_message(message, history, lang="EspaΓ±ol"):
169
  if not message.strip():
170
  return ""
171
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import faiss
3
  import pickle
 
10
  from huggingface_hub import hf_hub_download
11
  from sentence_transformers import SentenceTransformer
12
  from huggingface_hub import InferenceClient
13
+ from transformers import VitsModel, AutoTokenizer, pipeline
14
 
15
 
16
  # ── Auth ───────────────────────────────────────────────────────
 
85
  try:
86
  lang = detect(text)
87
  if lang == "ca":
88
+ return "ca"
89
  elif lang == "es":
90
+ return "es"
91
+ elif lang == "en":
92
+ return "en"
93
  else:
94
+ return "es"
95
  except:
96
+ return "es"
97
 
98
  # ── TTS: Parler TTS mini v1 (neutral catalΓ /spanish voice) ─────────
99
+ print("Loading MMS TTS models...")
100
+ tts_models, tts_tokenizers = {}, {}
101
+ for lang_code, repo in {"en": "facebook/mms-tts-eng", "es": "facebook/mms-tts-spa", "ca": "facebook/mms-tts-cat"}.items():
102
+ tts_tokenizers[lang_code] = AutoTokenizer.from_pretrained(repo)
103
+ tts_models[lang_code] = VitsModel.from_pretrained(repo).to(device)
104
+ tts_models[lang_code].eval()
105
  print("Loading TTS models...")
106
 
 
 
 
 
 
 
 
 
 
 
 
107
  def text_to_speech(text, lang="es"):
108
+ if not text or lang not in tts_models:
109
  return None
110
  try:
111
+ inputs = tts_tokenizers[lang](text, return_tensors="pt").to(device)
112
+ with torch.no_grad():
113
+ audio = tts_models[lang](**inputs).waveform
114
+ audio_int16 = (audio.squeeze().cpu().float().numpy() * 32767).clip(-32768, 32767).astype("int16")
115
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
116
+ scipy.io.wavfile.write(f.name, rate=tts_models[lang].config.sampling_rate, data=audio_int16)
117
+ return f.name
118
+ except Exception as e:
119
+ print(f"TTS error: {e}")
120
+ return None
121
+
122
+ try:
123
+ input_ids = tts_tokenizer(voice_desc, return_tensors="pt").input_ids.to(device)
124
+ prompt_ids = tts_tokenizer(text, return_tensors="pt").input_ids.to(device)
125
+
126
+ with torch.no_grad():
127
+ generation = tts_model.generate(
128
+ input_ids=input_ids,
129
+ prompt_input_ids=prompt_ids,
130
+ )
131
+
132
+ audio_array = generation.cpu().to(torch.float32).numpy().squeeze()
133
+
134
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
135
+ scipy.io.wavfile.write(f.name, rate=sampling_rate, data=audio_array)
136
+ return f.name
137
+
138
  except Exception as e:
139
+ print(f"TTS error: {e}")
140
  return None
141
 
142
 
 
154
  If no relevant local services are in the context, say so honestly.
155
  Always remind caregivers that asking for help is a sign of strength, not weakness."""
156
 
157
+ def respond_to_message(message, history, lang="es"):
158
  if not message.strip():
159
  return ""
160