import os
import faiss
import pickle
import numpy as np
import gradio as gr
import torch
import scipy.io.wavfile
import tempfile

from huggingface_hub import hf_hub_download, InferenceClient
from sentence_transformers import SentenceTransformer
from transformers import VitsModel, AutoTokenizer, pipeline
from langdetect import detect  # used by detect_language() below


# ── Auth ───────────────────────────────────────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN is not set. Add it in Space Settings β†’ Repository Secrets.")

# ── Device ────────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
print(f"Running on: {device}")

# ── RAG: FAISS index ──────────────────────────────────────────
print("Loading FAISS index...")

FAISS_FILE = "alzheimers_index_233.faiss"
CHUNKS_FILE = "chunks_233.pkl"

index = faiss.read_index(FAISS_FILE)

with open(CHUNKS_FILE, "rb") as f:
    chunks = pickle.load(f)

print(f"Total chunks: {len(chunks)}")
print(f"Type: {type(chunks[0])}")
print(f"\n--- Chunk 0 ---\n{chunks[0]}")
print(f"\n--- Chunk 1 ---\n{chunks[1]}")

embed_model = SentenceTransformer("all-MiniLM-L6-v2")
# ── Index file debug ──────────────────────────────────────────
print(f"Loaded FAISS index with {index.ntotal} vectors")
print(f"Loaded {len(chunks)} chunks")

def retrieve_rag_context(query, k=5):
    query_embedding = embed_model.encode([query])
    distances, indices = index.search(np.array(query_embedding), k)
    results = []
    for i in indices[0]:
        chunk = chunks[i]
        print(f"  RAG chunk: source={chunk.get('source')} text={chunk['text'][:80]}")
        results.append(chunk["text"])
    return "\n\n---\n\n".join(results)
# ── Retrieval debug ──────────────────────────────────────────
    print(f"Retrieved chunks: {[c.get('topic') for c in [chunks[i] for i in indices[0]]]}")

# ── SYSTEM PROMPTS ─────────────────────────────
def get_system_prompt(lang="es"):
    if lang == "ca":
        return """Ets un assistent càlid i empàtic per a cuidadors de persones amb Alzheimer a Barcelona.
                Proporciona orientació clara, menciona serveis locals si existeixen en el context i mantén les respostes breus i comprensibles."""
    elif lang == "en":
        return """You are a warm, empathetic assistant for caregivers of people with Alzheimer's in Barcelona.
                Provide clear guidance, mention local services if they appear in the context, and keep responses brief and easy to understand."""
    else:
        return """Eres un asistente cálido y empático para cuidadores de personas con Alzheimer en Barcelona.
                Proporciona orientación clara, menciona recursos locales si existen en el contexto y mantén las respuestas breves y comprensibles."""

# ── STT: Distil-Whisper ───────────────────────────────────────
print("Loading Whisper STT model...")
stt_pipe = pipeline(
    "automatic-speech-recognition",
    model="distil-whisper/distil-large-v3",
    torch_dtype=torch_dtype,
    device=device,
)

def transcribe_audio(audio_path):
    if audio_path is None:
        return ""
    result = stt_pipe(audio_path, generate_kwargs={"task": "transcribe"}, return_timestamps=False)
    transcript = result["text"].strip()
    return transcript

def detect_language(text):
    # NOTE: not currently called by the pipelines; the lang_selector
    # dropdown in the UI chooses the response/TTS language instead.
    try:
        lang = detect(text)
        return lang if lang in ("ca", "es", "en") else "es"
    except Exception:
        return "es"
        
# ── TTS: Facebook MMS (VITS) ──────────────────────────────────
print("Loading MMS TTS models...")
tts_models, tts_tokenizers = {}, {}
for lang_code, repo in {"en": "facebook/mms-tts-eng", "es": "facebook/mms-tts-spa", "ca": "facebook/mms-tts-cat"}.items():
    tts_tokenizers[lang_code] = AutoTokenizer.from_pretrained(repo)
    tts_models[lang_code] = VitsModel.from_pretrained(repo).to(device)
    tts_models[lang_code].eval()
print("TTS models loaded.")

def text_to_speech(text, lang="es"):
    if not text or lang not in tts_models:
        return None
    try:
        inputs = tts_tokenizers[lang](text, return_tensors="pt").to(device)
        with torch.no_grad():
            audio = tts_models[lang](**inputs).waveform
        audio_int16 = (audio.squeeze().cpu().float().numpy() * 32767).clip(-32768, 32767).astype("int16")
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            scipy.io.wavfile.write(f.name, rate=tts_models[lang].config.sampling_rate, data=audio_int16)
            return f.name
    except Exception as e:
        print(f"TTS error: {e}")
        return None


# ── LLM: HF Inference API + RAG ───────────────────────────────
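# NOTE: respond_to_message() below builds its system message from
# get_system_prompt(lang); this English SYSTEM_PROMPT constant is not
# currently injected and serves as the reference version of the prompt.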
SYSTEM_PROMPT = """You are a warm, calm, and knowledgeable support assistant for caregivers of people with Alzheimer's disease.

Your role is to:
- Provide clear, compassionate guidance for caregiving challenges
- Suggest relevant local support services when available in the retrieved context
- Give practical, actionable advice
- Keep responses concise (under 120 words) so they are easy to listen to
- Always be encouraging and non-judgmental

If asked about local resources, ONLY reference services mentioned in the retrieved context. Do not invent services.
If no relevant local services are in the context, say so honestly.
Always remind caregivers that asking for help is a sign of strength, not weakness."""

def respond_to_message(message, history, lang="es"):
    if not message.strip():
        return ""

    client = InferenceClient(token=HF_TOKEN, model="openai/gpt-oss-20b")

    rag_context = retrieve_rag_context(message)
    full_system = f"{get_system_prompt(lang)}\n\n=== RETRIEVED CONTEXT ===\n{rag_context}"

    # ── RAG debug ──────────────────────────────────────────
    print(f"Full system prompt length: {len(full_system)} chars")
    print(f"RAG context preview: {rag_context[:300]}")

    messages = [{"role": "system", "content": full_system}]
    for h in history[-6:]:
        if isinstance(h, dict):
            messages.append({"role": h["role"], "content": h["content"]})
    messages.append({"role": "user", "content": message})

    response = ""
    try:
        for chunk in client.chat_completion(
            messages,
            max_tokens=150,
            stream=True,
            temperature=0.7,
            top_p=0.95,
        ):
            if chunk.choices and chunk.choices[0].delta.content:
                response += chunk.choices[0].delta.content
        return response.strip()
    except Exception as e:
        print(f"LLM error: {e}")
        if lang == "ca":
            return "Ho sento, no puc generar una resposta en aquest moment."
        if lang == "en":
            return "Sorry, I cannot generate a response right now."
        return "Lo siento, no puedo generar una respuesta en este momento."


# ── User Onboarding (planned) ─────────────────
#
# For new users, initiate an introductory conversation
# Capture user info and preferences
### Adapt questions from the Zarit Burden Interview, Caregiver Quality of Life Index, COPE inventory
### What is their knowledge of AD? How long have they been caregiving? Self-rated stress level? Etc.
#
# Capture care recipient info and preferences
### Adapt questions from the Functional Assessment Staging Tool (FAST), Global Deterioration Scale (GDS)
### Do they live alone, with the caregiver, or with someone else? Urban (public transportation) or suburban (driving)?
### Is the home smart-device enabled, or could it be? (Fire alarms, elopement alarms, bed alarms, auto-lighting, voice assistant)
#
# Option to complete by voice or text
# Store in a caregiver profile -> JSON
# Dynamic questions using responses to personalize ("Hi, Maria. It's nice to meet you! Can you tell me more about...")
# Inject the profile into the LLM for personalization (see the sketch below)
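
# ── Onboarding sketch (not wired in) ──────────────────────────
# A minimal sketch of the "store profile -> JSON, inject into LLM" plan
# above. PROFILE_PATH, the helper names, and the example keys are all
# hypothetical; nothing here is called by the app yet.
import json

PROFILE_PATH = "caregiver_profile.json"  # hypothetical storage location

def save_profile(profile, path=PROFILE_PATH):
    # Persist onboarding answers, e.g. {"name": "Maria", "stress_level": 4}.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(profile, f, ensure_ascii=False, indent=2)

def load_profile(path=PROFILE_PATH):
    # Return the stored profile, or an empty dict for first-time users.
    if not os.path.exists(path):
        return {}
    with open(path, encoding="utf-8") as f:
        return json.load(f)

def profile_to_prompt(profile):
    # One bullet per answer; respond_to_message() could append this to
    # full_system under a "=== CAREGIVER PROFILE ===" header.
    return "\n".join(f"- {k}: {v}" for k, v in profile.items())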


# ── Pipelines ─────────────────────────────────────────────────
# ── Voice Pipeline with Language Support ─────────────────────────
def voice_pipeline(audio_input, history, tts_lang):
    # Transcribe audio
    transcript = transcribe_audio(audio_input)
    if not transcript:
        return history, None, "⚠️ Could not transcribe audio. Please try again."

    # Generate response from LLM + RAG
    reply = respond_to_message(transcript, history, tts_lang)

    # Update chat history
    history = history or []
    history.append({"role": "user", "content": transcript})
    history.append({"role": "assistant", "content": reply})

    # Convert to speech
    audio_out = text_to_speech(reply, tts_lang)
    return history, audio_out, f'"{transcript}"'

# ── Text Pipeline with Language Support ─────────────────────────
def text_pipeline(text_input, history, tts_lang):
    if not text_input.strip():
        return history, None, ""

    reply = respond_to_message(text_input, history, tts_lang)

    history = history or []
    history.append({"role": "user", "content": text_input})
    history.append({"role": "assistant", "content": reply})

    audio_out = text_to_speech(reply, tts_lang)
    return history, audio_out, ""


# ── Gradio UI ─────────────────────────────────────────────────
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="green",
        neutral_hue="slate",
        font=gr.themes.GoogleFont("DM Sans"),
    ),
    title="CareCompanion",
) as demo:

    chat_history = gr.State([])

    gr.Markdown("""
    # SherpaAI
    ### Smart support for AD caregivers in Barcelona
    """)

    with gr.Row():

        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Conversation",
                height=420,
                type="messages",
                show_label=False,
                bubble_full_width=False,
            )

            audio_output = gr.Audio(
                label="πŸ”Š Voice Response",
                autoplay=True,
                show_download_button=False,
            )

        with gr.Column(scale=1):

            gr.Markdown("### 🎀 Voice Input")

            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="Record your question",
            )

            voice_btn = gr.Button(
                "🎀 Send Voice Message",
                variant="primary",
                size="lg",
            )

            lang_selector = gr.Dropdown(
                ["es", "ca", "en"],
                value="es",
                label="Voice",
                info="Choose Spanish, Catalan, or English",
            )

            transcript_display = gr.Textbox(
                label="πŸ“ What you said",
                interactive=False,
                lines=2,
                placeholder="Your transcribed speech will appear here…",
            )

            gr.Markdown("---")
            gr.Markdown("### ⌨️ Text Input")

            text_input = gr.Textbox(
                placeholder="Or type your question here…",
                label="",
                lines=3,
            )

            text_btn = gr.Button(
                "➀ Send Text Message",
                variant="secondary",
                size="lg",
            )

    gr.Markdown("""
    ---
    *Responses are AI-generated and do not replace professional medical advice.*
    *In emergencies, call 112 or your local emergency services.*
    """)

    # Helper function
    def update_chatbot(history):
        return history

    # 🎤 Voice button click
    voice_btn.click(
        fn=voice_pipeline,
        inputs=[audio_input, chat_history, lang_selector],
        outputs=[chat_history, audio_output, transcript_display],
    ).then(
        fn=update_chatbot,
        inputs=[chat_history],
        outputs=[chatbot],
    )

    # ⌨️ Text button click
    text_btn.click(
        fn=text_pipeline,
        inputs=[text_input, chat_history, lang_selector],
        outputs=[chat_history, audio_output, transcript_display],
    ).then(
        fn=update_chatbot,
        inputs=[chat_history],
        outputs=[chatbot],
    )

    # ⌨️ Enter key submit
    text_input.submit(
        fn=text_pipeline,
        inputs=[text_input, chat_history, lang_selector],
        outputs=[chat_history, audio_output, transcript_display],
    ).then(
        fn=update_chatbot,
        inputs=[chat_history],
        outputs=[chatbot],
    )



if __name__ == "__main__":
    demo.launch()