"""SherpaAI / CareCompanion: voice + text RAG assistant for Alzheimer's caregivers.

Flow: speech (Whisper STT) or typed text -> FAISS retrieval over a local chunk
store -> chat completion via the Hugging Face Inference API -> MMS (VITS)
text-to-speech, all wired into a Gradio Blocks UI.
"""
import os
import pickle
import tempfile

import faiss
import gradio as gr
import numpy as np
import scipy.io.wavfile
import torch
from huggingface_hub import hf_hub_download  # NOTE(review): currently unused
from huggingface_hub import InferenceClient
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, VitsModel, pipeline

# ── Auth ───────────────────────────────────────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN is not set. Add it in Space Settings → Repository Secrets.")

# ── Device ────────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU; CPU inference stays in float32.
torch_dtype = torch.float16 if device == "cuda" else torch.float32
print(f"Running on: {device}")

# ── RAG: FAISS index ──────────────────────────────────────────
print("Loading FAISS index...")
FAISS_FILE = "alzheimers_index_233.faiss"
CHUNKS_FILE = "chunks_233.pkl"
index = faiss.read_index(FAISS_FILE)
# SECURITY: unpickling can execute arbitrary code. This is acceptable only
# because CHUNKS_FILE is an artifact shipped with the app; never point it at
# user-supplied or untrusted data.
with open(CHUNKS_FILE, "rb") as f:
    chunks = pickle.load(f)

print(f"Total chunks: {len(chunks)}")
if chunks:
    print(f"Type: {type(chunks[0])}")
# Preview the first two chunks; slicing keeps a tiny store from crashing startup.
for n, sample in enumerate(chunks[:2]):
    print(f"\n--- Chunk {n} ---\n{sample}")

embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# ── Index file debug ──────────────────────────────────────────
print(f"Loaded FAISS index with {index.ntotal} vectors")
print(f"Loaded {len(chunks)} chunks")
if index.ntotal != len(chunks):
    # Retrieval maps FAISS row i to chunks[i]; differing sizes suggest the two
    # files were built from different runs.
    print(f"WARNING: index has {index.ntotal} vectors but there are {len(chunks)} chunks")


def retrieve_rag_context(query, k=5):
    """Return the text of the ``k`` chunks nearest to ``query``, joined by ``---``.

    Each chunk is read as a dict with a ``"text"`` key; ``"source"`` and
    ``"topic"`` are read with ``.get`` for debug logging only.

    NOTE(review): the query embedding is not normalized; this assumes the index
    was built the same way with the same embedding model — confirm.
    """
    k = min(k, index.ntotal)
    if k <= 0:
        return ""
    query_embedding = np.asarray(embed_model.encode([query]), dtype="float32")
    _, indices = index.search(query_embedding, k)

    hits = []
    for i in indices[0]:
        # FAISS pads with -1 when it finds fewer than k neighbours; a negative
        # label would otherwise silently select chunks[-1].
        if i < 0 or i >= len(chunks):
            continue
        hits.append(chunks[i])

    # ── Retrieval debug ──────────────────────────────────────────
    for chunk in hits:
        print(f" RAG chunk: source={chunk.get('source')} text={chunk['text'][:80]}")
    print(f"Retrieved chunks: {[c.get('topic') for c in hits]}")

    return "\n\n---\n\n".join(chunk["text"] for chunk in hits)


# ── SYSTEM PROMPTS ─────────────────────────────
_SYSTEM_PROMPTS = {
    "ca": """Ets un assistent càlid i empàtic per a cuidadors de persones amb Alzheimer a Barcelona. Proporciona orientació clara, menciona serveis locals si existeixen en el context i mantén les respostes breus i comprensibles.""",
    "en": """You are a warm, empathetic assistant for caregivers of people with Alzheimer's in Barcelona. Provide clear guidance, mention local services if they appear in the context, and keep responses brief and easy to understand.""",
    "es": """Eres un asistente cálido y empático para cuidadores de personas con Alzheimer en Barcelona. Proporciona orientación clara, menciona recursos locales si existen en el contexto y mantén las respuestas breves y comprensibles.""",
}


def get_system_prompt(lang="es"):
    """Return the localized system prompt for ``lang`` ("ca", "en" or "es").

    Any other value falls back to the Spanish prompt.
    """
    return _SYSTEM_PROMPTS.get(lang, _SYSTEM_PROMPTS["es"])


# ── STT: Whisper ──────────────────────────────────────────────
# distil-whisper/distil-large-v3 (used previously) is English-only, which made
# Spanish/Catalan voice input unreliable; this checkpoint is multilingual.
STT_MODEL = "openai/whisper-large-v3-turbo"
print("Loading Whisper STT model...")
stt_pipe = pipeline(
    "automatic-speech-recognition",
    model=STT_MODEL,
    torch_dtype=torch_dtype,
    device=device,
)


def transcribe_audio(audio_path, lang=None):
    """Transcribe the audio file at ``audio_path`` and return the stripped text.

    Args:
        audio_path: filepath from the Gradio microphone component, or None.
        lang: optional Whisper language hint ("es", "ca", "en"); when None the
            model auto-detects the spoken language.

    Returns:
        The transcript, or "" when no audio was recorded.
    """
    if audio_path is None:
        return ""
    generate_kwargs = {"task": "transcribe"}
    if lang:
        generate_kwargs["language"] = lang
    result = stt_pipe(audio_path, generate_kwargs=generate_kwargs, return_timestamps=False)
    return result["text"].strip()


_SUPPORTED_LANGS = ("ca", "es", "en")


def detect_language(text):
    """Map ``text`` to "ca", "es" or "en", defaulting to "es".

    NOTE(review): ``detect`` is never imported in this file (it looks like
    langdetect's ``detect``), so the call raises NameError and this always
    returns "es". The function is currently unused — the UI passes the dropdown
    language instead.
    """
    try:
        lang = detect(text)  # noqa: F821
    except Exception:
        return "es"
    return lang if lang in _SUPPORTED_LANGS else "es"


# ── TTS: MMS (VITS) ───────────────────────────────────────────
print("Loading MMS TTS models...")
TTS_REPOS = {"en": "facebook/mms-tts-eng", "es": "facebook/mms-tts-spa", "ca": "facebook/mms-tts-cat"}
tts_models, tts_tokenizers = {}, {}
for lang_code, repo in TTS_REPOS.items():
    tts_tokenizers[lang_code] = AutoTokenizer.from_pretrained(repo)
    tts_models[lang_code] = VitsModel.from_pretrained(repo).to(device)
    tts_models[lang_code].eval()
print("TTS models loaded.")


def text_to_speech(text, lang="es"):
    """Synthesize ``text`` in ``lang`` and return the path of a temporary WAV file.

    Returns None for empty text, a language without a loaded TTS model, or any
    synthesis error (which is logged). The file is created with delete=False
    so Gradio can serve it; nothing in this module deletes it afterwards.
    """
    if not text or lang not in tts_models:
        return None
    try:
        inputs = tts_tokenizers[lang](text, return_tensors="pt").to(device)
        with torch.no_grad():
            waveform = tts_models[lang](**inputs).waveform
        # Float waveform -> 16-bit PCM samples for the WAV container.
        samples = waveform.squeeze().cpu().float().numpy()
        pcm16 = (samples * 32767).clip(-32768, 32767).astype("int16")
        # Only reserve the path here; write after the handle is closed so the
        # file is not opened twice (which fails on some platforms).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            wav_path = f.name
        scipy.io.wavfile.write(wav_path, rate=tts_models[lang].config.sampling_rate, data=pcm16)
        return wav_path
    except Exception as e:
        print(f"TTS error: {e}")
        return None


# ── LLM: HF Inference API + RAG ───────────────────────────────
LLM_MODEL = "openai/gpt-oss-20b"

# NOTE(review): SYSTEM_PROMPT is not referenced anywhere in this file;
# respond_to_message uses get_system_prompt(lang) instead, so these guardrails
# (word limit, "do not invent services") are not currently sent to the model.
SYSTEM_PROMPT = """You are a warm, calm, and knowledgeable support assistant for caregivers of people with Alzheimer's disease.

Your role is to:
- Provide clear, compassionate guidance for caregiving challenges
- Suggest relevant local support services when available in the retrieved context
- Give practical, actionable advice
- Keep responses concise — under 120 words — so they are easy to listen to
- Always be encouraging and non-judgmental

If asked about local resources, ONLY reference services mentioned in the retrieved context. Do not invent services. If no relevant local services are in the context, say so honestly.

Always remind caregivers that asking for help is a sign of strength, not weakness."""

_LLM_FALLBACK = {
    "ca": "Ho sento, no puc generar una resposta en aquest moment.",
    "es": "Lo siento, no puedo generar una respuesta en este momento.",
    "en": "Sorry, I can't generate a response right now.",
}


def respond_to_message(message, history, lang="es"):
    """Answer ``message`` with the LLM, grounded in retrieved RAG context.

    Args:
        message: the user's question (transcribed or typed).
        history: prior turns as {"role": ..., "content": ...} dicts; only the
            last 6 entries are sent and non-dict entries are skipped. None is
            treated as an empty history.
        lang: "ca", "es" or "en"; selects the system prompt and the fallback
            message (unknown values fall back to Spanish).

    Returns:
        The stripped model reply, "" for a blank message, or a localized
        apology if retrieval or the API call fails or no text is streamed back.
    """
    if not message or not message.strip():
        return ""
    fallback = _LLM_FALLBACK.get(lang, _LLM_FALLBACK["es"])

    try:
        rag_context = retrieve_rag_context(message)
        full_system = f"{get_system_prompt(lang)}\n\n=== RETRIEVED CONTEXT ===\n{rag_context}"

        # ── RAG debug ──────────────────────────────────────────
        print(f"Full system prompt length: {len(full_system)} chars")
        print(f"RAG context preview: {rag_context[:300]}")

        messages = [{"role": "system", "content": full_system}]
        messages.extend(
            {"role": turn["role"], "content": turn["content"]}
            for turn in (history or [])[-6:]
            if isinstance(turn, dict)
        )
        messages.append({"role": "user", "content": message})

        client = InferenceClient(token=HF_TOKEN, model=LLM_MODEL)
        parts = []
        for event in client.chat_completion(
            messages,
            max_tokens=150,
            stream=True,
            temperature=0.7,
            top_p=0.95,
        ):
            if event.choices and event.choices[0].delta.content:
                parts.append(event.choices[0].delta.content)
    except Exception as e:
        print(f"LLM error: {e}")
        return fallback

    response = "".join(parts).strip()
    if not response:
        # The stream produced no content deltas; show an apology rather than
        # an empty chat bubble with no audio.
        print("LLM error: empty response")
        return fallback
    return response


# ── User Onboarding (TODO) ────────────────────────────────────
# For a new user, start an introductory conversation that:
# - Captures caregiver info and preferences.
#   - Adapt questions from the Zarit Burden Interview, the Caregiver Quality of
#     Life Index, and the COPE inventory.
#   - E.g. What do they know about AD? How long have they been a caregiver?
#     Self-rated stress level?
# - Captures care-recipient info and preferences.
#   - Adapt questions from the Functional Assessment Staging Tool (FAST) and the
#     Global Deterioration Scale (GDS).
#   - Do they live alone, with the caregiver, or with someone else? Urban
#     (public transportation) or suburban (driving)?
#   - Is the home smart-device enabled, or could it be? (Fire alarms, elopement
#     alarms, bed alarms, auto-lighting, voice assistant.)
# - Can be completed by voice or text.
# - Stores answers in a caregiver profile (JSON).
# - Uses earlier answers to personalize follow-up questions
#   ("Hi, Maria. It's nice to meet you! Can you tell me more about...").
# - Injects the profile into the LLM prompt for personalization.


# ── Pipelines ─────────────────────────────────────────────────
def voice_pipeline(audio_input, history, tts_lang):
    """Handle one voice turn: transcribe, answer, record in history, speak.

    Returns (history, audio_path_or_None, status_text) for the
    [chat_history, audio_output, transcript_display] outputs.
    """
    # Copy first so a None/absent state is safe and the caller's list is not mutated.
    history = list(history or [])

    transcript = transcribe_audio(audio_input, tts_lang)
    if not transcript:
        return history, None, "⚠️ Could not transcribe audio. Please try again."

    # Answer using the history *before* this turn; the new message is sent separately.
    reply = respond_to_message(transcript, history, tts_lang)

    history.append({"role": "user", "content": transcript})
    history.append({"role": "assistant", "content": reply})

    audio_out = text_to_speech(reply, tts_lang)
    return history, audio_out, f'"{transcript}"'


def text_pipeline(text_input, history, tts_lang):
    """Handle one typed turn: answer, record in history, speak.

    Returns (history, audio_path_or_None, "") for the
    [chat_history, audio_output, transcript_display] outputs.
    """
    history = list(history or [])
    if not text_input or not text_input.strip():
        return history, None, ""

    reply = respond_to_message(text_input, history, tts_lang)

    history.append({"role": "user", "content": text_input})
    history.append({"role": "assistant", "content": reply})

    audio_out = text_to_speech(reply, tts_lang)
    return history, audio_out, ""


# ── Gradio UI ─────────────────────────────────────────────────
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="green",
        neutral_hue="slate",
        font=gr.themes.GoogleFont("DM Sans"),
    ),
    title="CareCompanion",
) as demo:
    chat_history = gr.State([])

    gr.Markdown("""
# SherpaAI
### Smart support for AD caregivers in Barcelona
""")

    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Conversation",
                height=420,
                type="messages",
                show_label=False,
                bubble_full_width=False,
            )
            audio_output = gr.Audio(
                label="🔊 Voice Response",
                autoplay=True,
                show_download_button=False,
            )

        with gr.Column(scale=1):
            gr.Markdown("### 🎤 Voice Input")
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="Record your question",
            )
            voice_btn = gr.Button(
                "🎤 Send Voice Message",
                variant="primary",
                size="lg",
            )
            lang_selector = gr.Dropdown(
                ["es", "ca", "en"],
                value="es",
                label="Voice",
                info="Choose Spanish, Catalan, or English",
            )
            transcript_display = gr.Textbox(
                label="📝 What you said",
                interactive=False,
                lines=2,
                placeholder="Your transcribed speech will appear here…",
            )

            gr.Markdown("---")
            gr.Markdown("### ⌨️ Text Input")
            text_input = gr.Textbox(
                placeholder="Or type your question here…",
                label="",
                lines=3,
            )
            text_btn = gr.Button(
                "➤ Send Text Message",
                variant="secondary",
                size="lg",
            )

    gr.Markdown("""
---
*Responses are AI-generated and do not replace professional medical advice.*
*In emergencies, call 112 or your local emergency services.*
""")

    def update_chatbot(history):
        """Mirror the chat_history state into the Chatbot component."""
        return history

    # 🎤 Voice button click
    voice_btn.click(
        fn=voice_pipeline,
        inputs=[audio_input, chat_history, lang_selector],
        outputs=[chat_history, audio_output, transcript_display],
    ).then(
        fn=update_chatbot,
        inputs=[chat_history],
        outputs=[chatbot],
    )

    # ⌨️ Text button click and Enter-key submit share the same handler chain.
    for trigger in (text_btn.click, text_input.submit):
        trigger(
            fn=text_pipeline,
            inputs=[text_input, chat_history, lang_selector],
            outputs=[chat_history, audio_output, transcript_display],
        ).then(
            fn=update_chatbot,
            inputs=[chat_history],
            outputs=[chatbot],
        )


if __name__ == "__main__":
    demo.launch()