hbchiu commited on
Commit
215caa8
·
verified ·
1 Parent(s): e536ade

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -35
app.py CHANGED
@@ -20,7 +20,7 @@ import tempfile
20
  from huggingface_hub import hf_hub_download
21
  from sentence_transformers import SentenceTransformer
22
  from huggingface_hub import InferenceClient
23
- from transformers import VitsModel, AutoTokenizer, pipeline
24
 
25
 
26
  # ── Auth ───────────────────────────────────────────────────────
@@ -104,47 +104,52 @@ def detect_language(text):
104
  return "Español"
105
 
106
  # ── TTS: Parler TTS mini v1 (neutral català/spanish voice) ─────────
107
- print("Loading MMS TTS models...")
108
- tts_models, tts_tokenizers = {}, {}
109
- for lang_code, repo in {"en": "facebook/mms-tts-eng", "es": "facebook/mms-tts-spa", "ca": "facebook/mms-tts-cat"}.items():
110
- tts_tokenizers[lang_code] = AutoTokenizer.from_pretrained(repo)
111
- tts_models[lang_code] = VitsModel.from_pretrained(repo).to(device)
112
- tts_models[lang_code].eval()
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  def text_to_speech(text, lang="es"):
115
- if not text or lang not in tts_models:
116
  return None
117
  try:
118
- inputs = tts_tokenizers[lang](text, return_tensors="pt").to(device)
119
- with torch.no_grad():
120
- audio = tts_models[lang](**inputs).waveform
121
- audio_int16 = (audio.squeeze().cpu().float().numpy() * 32767).clip(-32768, 32767).astype("int16")
122
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
123
- scipy.io.wavfile.write(f.name, rate=tts_models[lang].config.sampling_rate, data=audio_int16)
124
- return f.name
 
 
 
 
 
 
 
 
 
 
 
 
125
  except Exception as e:
126
- print(f"TTS error: {e}")
127
  return None
128
 
129
- try:
130
- input_ids = tts_tokenizer(voice_desc, return_tensors="pt").input_ids.to(device)
131
- prompt_ids = tts_tokenizer(text, return_tensors="pt").input_ids.to(device)
132
-
133
- with torch.no_grad():
134
- generation = tts_model.generate(
135
- input_ids=input_ids,
136
- prompt_input_ids=prompt_ids,
137
- )
138
-
139
- audio_array = generation.cpu().to(torch.float32).numpy().squeeze()
140
-
141
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
142
- scipy.io.wavfile.write(f.name, rate=sampling_rate, data=audio_array)
143
- return f.name
144
-
145
- except Exception as e:
146
- print(f"TTS error: {e}")
147
- return None
148
 
149
  # ── LLM: HF Inference API + RAG ───────────────────────────────
150
  SYSTEM_PROMPT = """You are a warm, calm, and knowledgeable support assistant for caregivers of people with Alzheimer's disease.
 
20
  from huggingface_hub import hf_hub_download
21
  from sentence_transformers import SentenceTransformer
22
  from huggingface_hub import InferenceClient
23
+ from transformers import VitsModel, AutoTokenizer, pipeline, SpeechT5HifiGan
24
 
25
 
26
  # ── Auth ───────────────────────────────────────────────────────
 
104
  return "Español"
105
 
106
  # ── TTS: Parler TTS mini v1 (neutral català/spanish voice) ─────────
107
# ── TTS model loading (module level, runs once at startup) ─────────
# Commented-out legacy MMS-TTS loading code removed; see git history.
print("Loading TTS models...")

# Kokoro pipelines for English and Spanish synthesis.
# NOTE(review): `kokoro` is a third-party package — confirm it is listed
# in the Space's requirements.txt.
from kokoro import KPipeline
kokoro_en = KPipeline(lang_code='en')
kokoro_es = KPipeline(lang_code='es')

# Matxa (BSC / projecte-aina) model for Catalan.
# NOTE(review): "matxa-tts-cat-multiaccent" is published as a Matcha-TTS
# checkpoint — confirm VitsModel.from_pretrained can actually load it.
tts_tokenizers, tts_models = {}, {}
tts_tokenizers["ca"] = AutoTokenizer.from_pretrained("projecte-aina/matxa-tts-cat-multiaccent")
tts_models["ca"] = VitsModel.from_pretrained("projecte-aina/matxa-tts-cat-multiaccent").to(device)
tts_models["ca"].eval()  # inference only: disable dropout / training behavior
125
 
126
def text_to_speech(text, lang="es"):
    """Synthesize `text` to a 16-bit WAV file and return its path.

    Catalan ("ca") goes through the Matxa model in `tts_models`; any other
    language goes through the Kokoro pipelines (English → `kokoro_en`,
    otherwise `kokoro_es`), which emit audio at 24 kHz.

    Returns None for empty input or on any synthesis error (best-effort:
    the app degrades to text-only rather than crashing). The temp file is
    created with delete=False so the caller (e.g. a Gradio audio widget)
    can read it after this function returns.
    """
    if not text:
        return None
    try:
        if lang == "ca":
            inputs = tts_tokenizers["ca"](text, return_tensors="pt").to(device)
            with torch.no_grad():
                audio = tts_models["ca"](**inputs).waveform
            # Scale float waveform in [-1, 1] to int16 PCM.
            audio_int16 = (
                (audio.squeeze().cpu().float().numpy() * 32767)
                .clip(-32768, 32767)
                .astype("int16")
            )
            rate = tts_models["ca"].config.sampling_rate
        else:
            # Renamed from `pipeline` to avoid shadowing the
            # `transformers.pipeline` name imported at the top of the file.
            tts_pipeline = kokoro_en if lang == "en" else kokoro_es
            voice = "af_heart" if lang == "en" else "ef_dora"
            # Kokoro yields (graphemes, phonemes, audio) chunks; keep audio.
            chunks = [audio for _, _, audio in tts_pipeline(text, voice=voice)]
            audio_int16 = (
                (np.concatenate(chunks) * 32767).clip(-32768, 32767).astype("int16")
            )
            rate = 24000  # Kokoro's fixed output sample rate
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            scipy.io.wavfile.write(f.name, rate=rate, data=audio_int16)
        return f.name
    except Exception as e:
        # Best-effort: log and return no audio instead of crashing the app.
        print(f"TTS error ({lang}): {e}")
        return None
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  # ── LLM: HF Inference API + RAG ───────────────────────────────
155
  SYSTEM_PROMPT = """You are a warm, calm, and knowledgeable support assistant for caregivers of people with Alzheimer's disease.