# SherpaAI / app.py
# Last commit by hbchiu: "Update app.py" (60c5cb8, verified)
import os
import faiss
import pickle
import numpy as np
import gradio as gr
import torch
import scipy.io.wavfile
import tempfile
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient
from transformers import VitsModel, AutoTokenizer, pipeline
# ── Auth ───────────────────────────────────────────────────────
# The Inference API client in respond_to_message needs a Hugging Face token;
# fail fast at startup instead of on the first chat request.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN is not set. Add it in Space Settings → Repository Secrets.")
# ── Device ────────────────────────────────────────────────────
# Use the GPU when one is visible; half precision is only enabled on CUDA.
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32
print(f"Running on: {device}")
# ── RAG: FAISS index ──────────────────────────────────────────
print("Loading FAISS index...")
FAISS_FILE = "alzheimers_index_233.faiss"
CHUNKS_FILE = "chunks_233.pkl"
index = faiss.read_index(FAISS_FILE)
# NOTE(review): pickle.load can execute arbitrary code. This is acceptable only
# because CHUNKS_FILE ships with the app; never point it at user-supplied data.
with open(CHUNKS_FILE, "rb") as f:
    chunks = pickle.load(f)
print(f"Total chunks: {len(chunks)}")
# Debug preview of the first chunks, guarded so an empty or one-element chunk
# list does not abort startup with an IndexError.
if chunks:
    print(f"Type: {type(chunks[0])}")
for preview_idx, preview_chunk in enumerate(chunks[:2]):
    print(f"\n--- Chunk {preview_idx} ---\n{preview_chunk}")
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
# ── Index file debug ──────────────────────────────────────────
print(f"Loaded FAISS index with {index.ntotal} vectors")
print(f"Loaded {len(chunks)} chunks")
# Retrieval uses FAISS row ids directly as positions into `chunks`, so the two
# must line up one-to-one; make a mismatch visible in the logs.
if index.ntotal != len(chunks):
    print(f"WARNING: index has {index.ntotal} vectors but {len(chunks)} chunks were loaded")
def retrieve_rag_context(query, k=5):
    """Return the chunks nearest to ``query`` joined into one context string.

    The query is embedded with ``embed_model`` and searched against the
    module-level FAISS ``index``. Each hit's row id is used as a position into
    ``chunks``, whose items are expected to be dicts with a ``"text"`` key
    (``"source"`` and ``"topic"`` are optional and only used for debug logging).

    Args:
        query: The user question to search for.
        k: Maximum number of chunks to return.

    Returns:
        The chunk texts separated by a ``---`` divider line, or ``""`` when
        ``k`` is not positive, the index is empty, or nothing is found.
    """
    if k <= 0 or index.ntotal == 0:
        return ""
    # FAISS expects a float32 matrix with one row per query.
    query_embedding = np.asarray(embed_model.encode([query]), dtype="float32")
    _, indices = index.search(query_embedding, min(k, index.ntotal))
    results = []
    topics = []
    for i in indices[0]:
        # FAISS pads missing results with -1; chunks[-1] would silently
        # return the last chunk, so skip those slots.
        if i < 0:
            continue
        chunk = chunks[i]
        print(f" RAG chunk: source={chunk.get('source')} text={chunk['text'][:80]}")
        results.append(chunk["text"])
        topics.append(chunk.get("topic"))
    # ── Retrieval debug ──────────────────────────────────────────
    # Logged here because the retrieved ids only exist inside this function.
    print(f"Retrieved chunks: {topics}")
    return "\n\n---\n\n".join(results)
# ── SYSTEM PROMPTS ─────────────────────────────
# Per-language system prompts; these strings are sent verbatim to the LLM.
_SYSTEM_PROMPTS = {
    "ca": """Ets un assistent càlid i empàtic per a cuidadors de persones amb Alzheimer a Barcelona.
Proporciona orientació clara, menciona serveis locals si existeixen en el context i mantén les respostes breus i comprensibles.""",
    "en": """You are a warm, empathetic assistant for caregivers of people with Alzheimer's in Barcelona.
Provide clear guidance, mention local services if they appear in the context, and keep responses brief and easy to understand.""",
    "es": """Eres un asistente cálido y empático para cuidadores de personas con Alzheimer en Barcelona.
Proporciona orientación clara, menciona recursos locales si existen en el contexto y mantén las respuestas breves y comprensibles.""",
}


def get_system_prompt(lang="es"):
    """Return the system prompt for ``lang``.

    Args:
        lang: ``"ca"`` (Catalan), ``"en"`` (English) or ``"es"`` (Spanish).
            Any other value falls back to Spanish.

    Returns:
        The prompt text for that language.
    """
    return _SYSTEM_PROMPTS.get(lang, _SYSTEM_PROMPTS["es"])
# ── STT: Distil-Whisper ───────────────────────────────────────
# NOTE(review): the distil-large-v3 model card describes it as English-only;
# Spanish/Catalan recordings may transcribe poorly — confirm before relying on it.
print("Loading Whisper STT model...")
stt_pipe = pipeline(
    task="automatic-speech-recognition",
    model="distil-whisper/distil-large-v3",
    device=device,
    torch_dtype=torch_dtype,
)
def transcribe_audio(audio_path):
    """Transcribe the recording at ``audio_path`` with ``stt_pipe``.

    Returns the transcript with surrounding whitespace stripped, or ``""``
    when no recording was supplied (``audio_path is None``).
    """
    if audio_path is not None:
        output = stt_pipe(
            audio_path,
            return_timestamps=False,
            generate_kwargs={"task": "transcribe"},
        )
        return output["text"].strip()
    return ""
def detect_language(text):
    """Map ``text`` to one of the app's supported language codes.

    Returns ``"ca"``, ``"es"`` or ``"en"`` when the detector reports one of
    them; any other result, blank/non-string input, or a detector failure
    yields the default ``"es"``.
    """
    # NOTE(review): ``detect`` is not imported anywhere in the visible file
    # (presumably ``from langdetect import detect``). Until it is, every call
    # raises NameError inside the try and falls back to "es".
    if not isinstance(text, str) or not text.strip():
        return "es"
    try:
        lang = detect(text)
    except Exception:
        # Best-effort: detection errors must not break the chat flow. Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
        return "es"
    return lang if lang in {"ca", "es", "en"} else "es"
# ── TTS: MMS-TTS (VITS), one model per supported language ─────────
print("Loading MMS TTS models...")
_TTS_REPOS = {
    "en": "facebook/mms-tts-eng",
    "es": "facebook/mms-tts-spa",
    "ca": "facebook/mms-tts-cat",
}
tts_models, tts_tokenizers = {}, {}
for lang_code, repo in _TTS_REPOS.items():
    tts_tokenizers[lang_code] = AutoTokenizer.from_pretrained(repo)
    tts_models[lang_code] = VitsModel.from_pretrained(repo).to(device)
    # Inference only: switch the model to eval mode.
    tts_models[lang_code].eval()
print("TTS models loaded.")
def text_to_speech(text, lang="es"):
    """Synthesize ``text`` with the MMS-TTS model for ``lang`` into a WAV file.

    Args:
        text: Text to speak.
        lang: Key into ``tts_models`` / ``tts_tokenizers``.

    Returns:
        Path of a temporary 16-bit PCM ``.wav`` file (created with
        ``delete=False``, so it is not removed automatically), or ``None``
        when ``text`` is empty, ``lang`` has no model, or synthesis fails.
    """
    if not text or lang not in tts_models:
        return None
    try:
        inputs = tts_tokenizers[lang](text, return_tensors="pt").to(device)
        with torch.no_grad():
            waveform = tts_models[lang](**inputs).waveform
        # Scale float samples to the int16 range; clip guards against overshoot.
        audio_int16 = (waveform.squeeze().cpu().float().numpy() * 32767).clip(-32768, 32767).astype("int16")
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            scipy.io.wavfile.write(f.name, rate=tts_models[lang].config.sampling_rate, data=audio_int16)
        return f.name
    except Exception as e:
        # Best-effort: a TTS failure should not block the text reply.
        print(f"TTS error: {e}")
        return None
# ── LLM: HF Inference API + RAG ───────────────────────────────
# General English system prompt.
# NOTE(review): respond_to_message builds its prompt from get_system_prompt(lang),
# so this constant is not referenced in the visible code — confirm whether it
# is still needed.
SYSTEM_PROMPT = """You are a warm, calm, and knowledgeable support assistant for caregivers of people with Alzheimer's disease.
Your role is to:
- Provide clear, compassionate guidance for caregiving challenges
- Suggest relevant local support services when available in the retrieved context
- Give practical, actionable advice
- Keep responses concise — under 120 words — so they are easy to listen to
- Always be encouraging and non-judgmental
If asked about local resources, ONLY reference services mentioned in the retrieved context. Do not invent services.
If no relevant local services are in the context, say so honestly.
Always remind caregivers that asking for help is a sign of strength, not weakness."""
def respond_to_message(message, history, lang="es"):
    """Answer ``message`` with the hosted LLM, grounded in retrieved RAG context.

    The system prompt is ``get_system_prompt(lang)`` followed by the text
    returned by ``retrieve_rag_context(message)``. The reply is streamed from
    the Hugging Face Inference API and concatenated.

    Args:
        message: The user's latest question (typed or transcribed).
        history: Prior turns as messages-style dicts with ``"role"`` and
            ``"content"``. Only the last 6 entries are forwarded, non-dict
            entries are skipped, and ``None`` is treated as an empty history.
        lang: ``"ca"``, ``"es"`` or ``"en"``; selects the system prompt and the
            language of the fallback message.

    Returns:
        The stripped model reply; ``""`` for a blank or ``None`` message; or a
        localized apology if the API call raises.
    """
    if not message or not message.strip():
        return ""
    client = InferenceClient(token=HF_TOKEN, model="openai/gpt-oss-20b")
    rag_context = retrieve_rag_context(message)
    full_system = f"{get_system_prompt(lang)}\n\n=== RETRIEVED CONTEXT ===\n{rag_context}"
    # ── RAG debug ──────────────────────────────────────────
    # Logged inside the function because these values are per-request locals.
    print(f"Full system prompt length: {len(full_system)} chars")
    print(f"RAG context preview: {rag_context[:300]}")

    messages = [{"role": "system", "content": full_system}]
    # Forward only recent turns to bound the prompt size, copying just the
    # keys the chat API expects.
    messages.extend(
        {"role": turn["role"], "content": turn["content"]}
        for turn in (history or [])[-6:]
        if isinstance(turn, dict)
    )
    messages.append({"role": "user", "content": message})

    pieces = []
    try:
        for chunk in client.chat_completion(
            messages,
            max_tokens=150,
            stream=True,
            temperature=0.7,
            top_p=0.95,
        ):
            if chunk.choices and chunk.choices[0].delta.content:
                pieces.append(chunk.choices[0].delta.content)
    except Exception as e:
        # Best-effort boundary: log the error and show an apology in the
        # user's language instead of surfacing it in the chat.
        print(f"LLM error: {e}")
        fallbacks = {
            "ca": "Ho sento, no puc generar una resposta en aquest moment.",
            "en": "Sorry, I can't generate a response right now.",
        }
        return fallbacks.get(lang, "Lo siento, no puedo generar una respuesta en este momento.")
    return "".join(pieces).strip()
# ── User Onboarding ─────────────────────────────
#
# For a new user, start an introductory conversation.
# Capture user info and preferences
### Adapt questions from the Zarit Burden Interview, the Caregiver Quality of Life Index, and the COPE inventory.
### What is their knowledge of AD? How long have you been their caregiver? Self-rate stress level? Etc
#
# Capture care recipient info and preferences
### Adapt questions from the Functional Assessment Staging Tool (FAST) and the Global Deterioration Scale (GDS).
### Do they live alone, with caregiver, with someone else? Urban (public transportation) or suburban? (Driving)
### Is the home smart-device enabled, or could it be? (Fire alarms, elopement alarms, bed alarms, auto-lighting, voice assistant)
#
# Option to complete by voice or text
# Store in caregiver profile -> json
# Dynamic questions using responses to personalize. ("Hi, Maria. It's nice to meet you! Can you tell me more about...")
# Inject profile into LLM for personalization
# ── Pipelines ─────────────────────────────────────────────────
# ── Voice Pipeline with Language Support ─────────────────────────
def voice_pipeline(audio_input, history, tts_lang):
    """Handle one voice turn: transcribe, answer with LLM + RAG, then speak.

    Args:
        audio_input: Path to the recorded audio file, or ``None``.
        history: Current chat history (messages-style dicts); may be ``None``.
        tts_lang: ``"es"``, ``"ca"`` or ``"en"``; drives both the reply
            language and the TTS voice.

    Returns:
        ``(history, audio_path, status_text)``, wired to the Gradio outputs
        ``[chat_history, audio_output, transcript_display]``.
    """
    transcript = transcribe_audio(audio_input)
    if not transcript:
        return history, None, "⚠️ Could not transcribe audio. Please try again."
    # Normalize before the LLM call (a None state would crash history slicing)
    # and copy so the caller's list is not mutated in place.
    history = list(history or [])
    reply = respond_to_message(transcript, history, tts_lang)
    history.append({"role": "user", "content": transcript})
    history.append({"role": "assistant", "content": reply})
    audio_out = text_to_speech(reply, tts_lang)
    return history, audio_out, f'"{transcript}"'
# ── Text Pipeline with Language Support ─────────────────────────
def text_pipeline(text_input, history, tts_lang):
    """Handle one typed turn: answer with LLM + RAG, then speak the reply.

    Args:
        text_input: The typed question; blank or ``None`` input is ignored.
        history: Current chat history (messages-style dicts); may be ``None``.
        tts_lang: ``"es"``, ``"ca"`` or ``"en"``; drives both the reply
            language and the TTS voice.

    Returns:
        ``(history, audio_path, "")``, wired to the Gradio outputs
        ``[chat_history, audio_output, transcript_display]``. Blank input
        returns the history unchanged with no audio.
    """
    if not text_input or not text_input.strip():
        return history, None, ""
    # Normalize before the LLM call (a None state would crash history slicing)
    # and copy so the caller's list is not mutated in place.
    history = list(history or [])
    reply = respond_to_message(text_input, history, tts_lang)
    history.append({"role": "user", "content": text_input})
    history.append({"role": "assistant", "content": reply})
    audio_out = text_to_speech(reply, tts_lang)
    return history, audio_out, ""
# ── Gradio UI ─────────────────────────────────────────────────
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="green",
        neutral_hue="slate",
        font=gr.themes.GoogleFont("DM Sans"),
    ),
    title="CareCompanion",
) as demo:
    # Conversation state passed to and returned from the pipelines; it is
    # mirrored into the visible Chatbot after each turn.
    chat_history = gr.State([])
    gr.Markdown("""
# SherpaAI
### Smart support for AD caregivers in Barcelona
""")
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Conversation",
                height=420,
                type="messages",
                show_label=False,
                bubble_full_width=False,
            )
            audio_output = gr.Audio(
                label="🔊 Voice Response",
                autoplay=True,
                show_download_button=False,
            )
        with gr.Column(scale=1):
            gr.Markdown("### 🎀 Voice Input")
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="Record your question",
            )
            voice_btn = gr.Button(
                "🎀 Send Voice Message",
                variant="primary",
                size="lg",
            )
            # Selects both the reply language and the TTS voice.
            lang_selector = gr.Dropdown(
                ["es", "ca", "en"],
                value="es",
                label="Voice",
                info="Choose Spanish, Catalan, or English",
            )
            transcript_display = gr.Textbox(
                label="📝 What you said",
                interactive=False,
                lines=2,
                placeholder="Your transcribed speech will appear here…",
            )
            gr.Markdown("---")
            gr.Markdown("### ⌨️ Text Input")
            text_input = gr.Textbox(
                placeholder="Or type your question here…",
                label="",
                lines=3,
            )
            text_btn = gr.Button(
                "➀ Send Text Message",
                variant="secondary",
                size="lg",
            )
    gr.Markdown("""
---
*Responses are AI-generated and do not replace professional medical advice.*
*In emergencies, call 112 or your local emergency services.*
""")

    def update_chatbot(history):
        """Copy the chat_history state into the Chatbot component unchanged."""
        return history

    # 🎤 Voice button click: run the voice pipeline, then refresh the chat view.
    voice_btn.click(
        fn=voice_pipeline,
        inputs=[audio_input, chat_history, lang_selector],
        outputs=[chat_history, audio_output, transcript_display],
    ).then(
        fn=update_chatbot,
        inputs=[chat_history],
        outputs=[chatbot],
    )

    # ⌨️ Text button click and Enter-key submit share one handler chain.
    for text_trigger in (text_btn.click, text_input.submit):
        text_trigger(
            fn=text_pipeline,
            inputs=[text_input, chat_history, lang_selector],
            outputs=[chat_history, audio_output, transcript_display],
        ).then(
            fn=update_chatbot,
            inputs=[chat_history],
            outputs=[chatbot],
        )

if __name__ == "__main__":
    demo.launch()