Spaces:
Sleeping
Sleeping
| import os | |
| import faiss | |
| import pickle | |
| import numpy as np | |
| import gradio as gr | |
| import torch | |
| import scipy.io.wavfile | |
| import tempfile | |
| from huggingface_hub import hf_hub_download | |
| from sentence_transformers import SentenceTransformer | |
| from huggingface_hub import InferenceClient | |
| from transformers import VitsModel, AutoTokenizer, pipeline | |
# ── Auth ──────────────────────────────────────────────────────
# The Space needs a Hugging Face token to call the hosted Inference API
# (used by respond_to_message below).
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    # Fail fast at startup: without a token every LLM call would fail later.
    raise ValueError("HF_TOKEN is not set. Add it in Space Settings β Repository Secrets.")
# ── Device ────────────────────────────────────────────────────
# Prefer GPU with half precision; fall back to CPU with float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
print(f"Running on: {device}")
# ── RAG: FAISS index ──────────────────────────────────────────
print("Loading FAISS index...")
FAISS_FILE = "alzheimers_index_233.faiss"  # prebuilt FAISS index shipped with the Space
CHUNKS_FILE = "chunks_233.pkl"             # pickled chunk records aligned row-for-row with the index
index = faiss.read_index(FAISS_FILE)
# NOTE(review): pickle.load is fine for this bundled file, but never unpickle
# untrusted data.
with open(CHUNKS_FILE, "rb") as f:
    chunks = pickle.load(f)
# Startup diagnostics: confirm the chunk store loaded and inspect its shape.
print(f"Total chunks: {len(chunks)}")
print(f"Type: {type(chunks[0])}")
print(f"\n--- Chunk 0 ---\n{chunks[0]}")
print(f"\n--- Chunk 1 ---\n{chunks[1]}")
# Query embedder — presumably the same model that built the index; confirm,
# otherwise retrieval quality degrades silently.
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
# ── Index file debug ──────────────────────────────────────────
print(f"Loaded FAISS index with {index.ntotal} vectors")
print(f"Loaded {len(chunks)} chunks")
def retrieve_rag_context(query, k=5):
    """Embed *query*, search the FAISS index, and return the top-*k*
    chunk texts joined by a '---' separator."""
    vec = embed_model.encode([query])
    _, hit_ids = index.search(np.array(vec), k)
    texts = []
    for hit in hit_ids[0]:
        matched = chunks[hit]
        # Debug trace of each retrieved chunk (source tag + text preview).
        print(f" RAG chunk: source={matched.get('source')} text={matched['text'][:80]}")
        texts.append(matched["text"])
    return "\n\n---\n\n".join(texts)
# ── Retrieval debug ──────────────────────────────────────────
# BUGFIX: a debug print here referenced `indices`, a local variable of
# retrieve_rag_context, at module scope — a NameError that crashed the app
# at import time. Removed; per-hit debug output is already printed inside
# retrieve_rag_context.
| # ββ SYSTEM PROMPTS βββββββββββββββββββββββββββββ | |
def get_system_prompt(lang="es"):
    """Return the caregiver-assistant system prompt for *lang*.

    Supported codes are 'ca' (Catalan), 'en' (English) and 'es' (Spanish);
    any other value falls back to Spanish, the app's default.
    """
    prompts = {
        "ca": """Ets un assistent cΓ lid i empΓ tic per a cuidadors de persones amb Alzheimer a Barcelona.
Proporciona orientaciΓ³ clara, menciona serveis locals si existeixen en el context i mantΓ©n les respostes breus i comprensibles.""",
        "en": """You are a warm, empathetic assistant for caregivers of people with Alzheimer's in Barcelona.
Provide clear guidance, mention local services if they appear in the context, and keep responses brief and easy to understand.""",
    }
    default = """Eres un asistente cΓ‘lido y empΓ‘tico para cuidadores de personas con Alzheimer en Barcelona.
Proporciona orientaciΓ³n clara, menciona recursos locales si existen en el contexto y mantΓ©n las respuestas breves y comprensibles."""
    return prompts.get(lang, default)
# ── STT: Distil-Whisper ───────────────────────────────────────
print("Loading Whisper STT model...")
# Speech-to-text pipeline; distil-large-v3 is a distilled Whisper checkpoint.
stt_pipe = pipeline(
    "automatic-speech-recognition",
    model="distil-whisper/distil-large-v3",
    torch_dtype=torch_dtype,  # fp16 on GPU, fp32 on CPU (chosen above)
    device=device,
)
def transcribe_audio(audio_path):
    """Run Whisper STT on the file at *audio_path*.

    Returns the stripped transcript, or an empty string when no audio
    file was provided.
    """
    if audio_path is None:
        return ""
    out = stt_pipe(audio_path, generate_kwargs={"task": "transcribe"}, return_timestamps=False)
    return out["text"].strip()
def detect_language(text):
    """Best-effort language detection for *text*.

    Returns 'ca', 'es', or 'en'. Any other detected language, an empty
    string, or a detection failure falls back to 'es' (the app default).
    """
    try:
        # BUGFIX: `detect` was never imported, so the previous bare `except:`
        # silently swallowed the NameError and this function ALWAYS returned
        # "es". Import lazily so the app still starts if langdetect is absent.
        from langdetect import detect  # third-party dependency of this Space
        code = detect(text)
    except Exception:
        # Detection failure (too little text, missing package, ...) -> default.
        return "es"
    return code if code in ("ca", "es", "en") else "es"
# ── TTS: Meta MMS (VITS) ──────────────────────────
# NOTE(review): the old header said "Parler TTS mini v1", but the code loads
# facebook/mms-tts-* VITS checkpoints — one model + tokenizer per language.
print("Loading MMS TTS models...")
tts_models, tts_tokenizers = {}, {}
for lang_code, repo in {"en": "facebook/mms-tts-eng", "es": "facebook/mms-tts-spa", "ca": "facebook/mms-tts-cat"}.items():
    tts_tokenizers[lang_code] = AutoTokenizer.from_pretrained(repo)
    tts_models[lang_code] = VitsModel.from_pretrained(repo).to(device)
    tts_models[lang_code].eval()  # inference only: disables dropout etc.
# NOTE(review): redundant status message — models were already loaded above.
print("Loading TTS models...")
def text_to_speech(text, lang="es"):
    """Synthesize *text* with the MMS VITS model for *lang*.

    Returns the path of a temporary 16-bit PCM .wav file, or None when the
    text is empty, the language has no loaded model, or synthesis fails.
    """
    if not text or lang not in tts_models:
        return None
    try:
        inputs = tts_tokenizers[lang](text, return_tensors="pt").to(device)
        with torch.no_grad():
            audio = tts_models[lang](**inputs).waveform
        # Scale float waveform in [-1, 1] to int16 PCM for the wav container.
        audio_int16 = (audio.squeeze().cpu().float().numpy() * 32767).clip(-32768, 32767).astype("int16")
        # delete=False: Gradio must be able to read the file after we return.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            scipy.io.wavfile.write(f.name, rate=tts_models[lang].config.sampling_rate, data=audio_int16)
        return f.name
    except Exception as e:
        # BUGFIX: the original had a second, identical `except Exception`
        # clause after this one — unreachable dead code, now removed.
        print(f"TTS error: {e}")
        return None
# ── LLM: HF Inference API + RAG ───────────────────────────────
# NOTE(review): this English-only prompt appears unused — respond_to_message
# builds its system message from get_system_prompt(lang) instead. Kept so the
# module-level name remains available; confirm before deleting.
SYSTEM_PROMPT = """You are a warm, calm, and knowledgeable support assistant for caregivers of people with Alzheimer's disease.
Your role is to:
- Provide clear, compassionate guidance for caregiving challenges
- Suggest relevant local support services when available in the retrieved context
- Give practical, actionable advice
- Keep responses concise β under 120 words β so they are easy to listen to
- Always be encouraging and non-judgmental
If asked about local resources, ONLY reference services mentioned in the retrieved context. Do not invent services.
If no relevant local services are in the context, say so honestly.
Always remind caregivers that asking for help is a sign of strength, not weakness."""
def respond_to_message(message, history, lang="es"):
    """Generate an LLM reply for *message*, grounded in retrieved RAG context.

    history: prior turns in Gradio "messages" format ({"role", "content"}
    dicts); only the last six are forwarded. Returns "" for blank input and
    a localized apology string when the inference call fails.
    """
    if not message.strip():
        return ""
    client = InferenceClient(token=HF_TOKEN, model="openai/gpt-oss-20b")
    rag_context = retrieve_rag_context(message)
    full_system = f"{get_system_prompt(lang)}\n\n=== RETRIEVED CONTEXT ===\n{rag_context}"
    convo = [{"role": "system", "content": full_system}]
    # Forward only dict-shaped turns (skips any legacy tuple entries).
    convo.extend(
        {"role": turn["role"], "content": turn["content"]}
        for turn in history[-6:]
        if isinstance(turn, dict)
    )
    convo.append({"role": "user", "content": message})
    try:
        pieces = []
        stream = client.chat_completion(
            convo,
            max_tokens=150,
            stream=True,
            temperature=0.7,
            top_p=0.95,
        )
        for event in stream:
            delta = event.choices[0].delta.content if event.choices else None
            if delta:
                pieces.append(delta)
        return "".join(pieces).strip()
    except Exception as e:
        print(f"LLM error: {e}")
        return "Ho sento, no puc generar una resposta en aquest moment." if lang=="ca" else "Lo siento, no puedo generar una respuesta en este momento."
# ── RAG debug ──────────────────────────────────────────
# BUGFIX: debug prints here referenced `full_system` and `rag_context`,
# locals of respond_to_message, at module scope — a NameError at import
# time. Removed; add logging inside respond_to_message if this trace is
# still wanted.
| # ββ User Onboarding βββββββββββββββββββββββββββββ | |
| # | |
| # For new user, initiate introductory conversation | |
| # Capture user info and preferences | |
| ### Adapt questions from Zarit Burden Interview, Caregiver Quality of Life Index, COPE inventory | |
| ### What is their knowledge of AD? How long have you been their caregiver? Self-rate stress level? Etc | |
| # | |
| # Capture care recipient info and preferences | |
| ### Adapt questions from functional staging tool (FAST), Global Deterioration Scale (GDS) | |
| ### Do they live alone, with caregiver, with someone else? Urban (public transportation) or suburban? (Driving) | |
| ### Is home smart-device enabled, or is it a possibility? (Fire alarms, elopement alarms, bed alarms, auto-lighting, voice asst) | |
| # | |
| # Option to complete by voice or text | |
| # Store in caregiver profile -> json | |
| # Dynamic questions using responses to personalize. ("Hi, Maria. It's nice to meet you! Can you tell me more about...") | |
| # Inject profile into LLM for personalization | |
# ── Pipelines ─────────────────────────────────────────────────
# ── Voice Pipeline with Language Support ─────────────────────────
def voice_pipeline(audio_input, history, tts_lang):
    """Full voice turn: STT -> RAG/LLM reply -> TTS.

    Returns (updated history, reply-audio path or None, transcript display).
    """
    # 1) Speech to text.
    transcript = transcribe_audio(audio_input)
    if not transcript:
        return history, None, "β οΈ Could not transcribe audio. Please try again."
    # 2) RAG-grounded LLM reply.
    reply = respond_to_message(transcript, history, tts_lang)
    # 3) Record both turns in "messages" format.
    history = history or []
    history.append({"role": "user", "content": transcript})
    history.append({"role": "assistant", "content": reply})
    # 4) Reply to speech; show the quoted transcript to the user.
    return history, text_to_speech(reply, tts_lang), f'"{transcript}"'
# ── Text Pipeline with Language Support ─────────────────────────
def text_pipeline(text_input, history, tts_lang):
    """Text turn: RAG/LLM reply + TTS — mirrors voice_pipeline minus STT."""
    if not text_input.strip():
        return history, None, ""
    reply = respond_to_message(text_input, history, tts_lang)
    history = history or []
    history.extend([
        {"role": "user", "content": text_input},
        {"role": "assistant", "content": reply},
    ])
    return history, text_to_speech(reply, tts_lang), ""
# ── Gradio UI ─────────────────────────────────────────────────
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="green",
        neutral_hue="slate",
        font=gr.themes.GoogleFont("DM Sans"),
    ),
    title="CareCompanion",  # browser-tab title; page header below says "SherpaAI" — confirm which branding is intended
) as demo:
    # Conversation state shared by voice and text pipelines
    # ("messages" format: list of {"role", "content"} dicts).
    chat_history = gr.State([])
    gr.Markdown("""
# SherpaAI
### Smart support for AD caregivers in Barcelona
""")
    with gr.Row():
        with gr.Column(scale=2):
            # Main conversation panel, refreshed from chat_history after each turn.
            chatbot = gr.Chatbot(
                label="Conversation",
                height=420,
                type="messages",
                show_label=False,
                bubble_full_width=False,  # NOTE(review): deprecated in newer Gradio releases — confirm installed version
            )
            # Autoplays the synthesized TTS reply produced by the pipelines.
            audio_output = gr.Audio(
                label="π Voice Response",
                autoplay=True,
                show_download_button=False,
            )
        with gr.Column(scale=1):
            gr.Markdown("### π€ Voice Input")
            # Microphone capture saved to a temp file path for the STT pipeline.
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="Record your question",
            )
            voice_btn = gr.Button(
                "π€ Send Voice Message",
                variant="primary",
                size="lg",
            )
            # Drives both the system-prompt language and the TTS voice.
            lang_selector = gr.Dropdown(
                ["es", "ca", "en"],
                value="es",
                label="Voice",
                info="Choose Spanish, Catalan, or English",
            )
            # Shows the STT transcript (or an error notice) after a voice turn.
            transcript_display = gr.Textbox(
                label="π What you said",
                interactive=False,
                lines=2,
                placeholder="Your transcribed speech will appear hereβ¦",
            )
            gr.Markdown("---")
            gr.Markdown("### β¨οΈ Text Input")
            text_input = gr.Textbox(
                placeholder="Or type your question hereβ¦",
                label="",
                lines=3,
            )
            text_btn = gr.Button(
                "β€ Send Text Message",
                variant="secondary",
                size="lg",
            )
    gr.Markdown("""
---
*Responses are AI-generated and do not replace professional medical advice.*
*In emergencies, call 112 or your local emergency services.*
""")
    # Helper: mirror the updated state list into the Chatbot component.
    def update_chatbot(history):
        return history
    # π€ Voice button: run the voice pipeline, then refresh the chat display.
    # (Event wiring must live inside the Blocks context — hence "NOW INSIDE BLOCKS".)
    voice_btn.click(
        fn=voice_pipeline,
        inputs=[audio_input, chat_history, lang_selector],
        outputs=[chat_history, audio_output, transcript_display],
    ).then(
        fn=update_chatbot,
        inputs=[chat_history],
        outputs=[chatbot],
    )
    # β¨οΈ Text button: same flow without speech-to-text.
    text_btn.click(
        fn=text_pipeline,
        inputs=[text_input, chat_history, lang_selector],
        outputs=[chat_history, audio_output, transcript_display],
    ).then(
        fn=update_chatbot,
        inputs=[chat_history],
        outputs=[chatbot],
    )
    # β¨οΈ Enter key in the textbox submits through the same text pipeline.
    text_input.submit(
        fn=text_pipeline,
        inputs=[text_input, chat_history, lang_selector],
        outputs=[chat_history, audio_output, transcript_display],
    ).then(
        fn=update_chatbot,
        inputs=[chat_history],
        outputs=[chatbot],
    )

if __name__ == "__main__":
    demo.launch()