# Therapy_Ai / app.py
# Author: Pant0x — last updated in commit 2b57a17 (verified)
import gradio as gr
from transformers import pipeline
from gtts import gTTS
import tempfile
# ==========================================
# 1. LOAD MODELS (100% LOCAL FOR CPU)
# ==========================================
# All three pipelines are loaded once at import time so every request
# reuses the same in-memory models (no per-call reload cost).
print("Booting up AI Therapist Pipeline...")
# Speech-to-text: whisper-tiny keeps CPU transcription latency tolerable.
asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
# Mental-state classifier: custom fine-tuned checkpoint loaded from a local
# directory (models/therapy_brain_V1) — tokenizer comes from the same path.
state_classifier = pipeline("text-classification", model="models/therapy_brain_V1", tokenizer="models/therapy_brain_V1")
# Reply generator: small 1.1B chat LLM that is feasible on CPU-only hosts.
chat_pipeline = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
print("All models loaded successfully!")
# ==========================================
# 2. CORE LOGIC
# ==========================================
def _synthesize_speech(text):
    """Render *text* to speech with gTTS and return the MP3 file path.

    The NamedTemporaryFile handle is closed before ``tts.save`` so that
    gTTS can reopen the file by name — required on Windows, and it avoids
    leaking a file descriptor per reply on every platform. ``delete=False``
    keeps the file on disk for Gradio to serve.
    """
    tts = gTTS(text=text, lang='en', slow=False)
    temp_audio = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
    temp_audio.close()  # release the handle; gTTS writes via the path
    tts.save(temp_audio.name)
    return temp_audio.name


def generate_bot_response(user_text, history):
    """Classify the user's mental state and generate a spoken therapist reply.

    Parameters
    ----------
    user_text : str
        The user's message (already transcribed if it arrived as audio).
    history : list[dict]
        Prior turns as ``{"role": ..., "content": ...}`` dicts.

    Returns
    -------
    tuple[str, str, str]
        ``(bot_reply, mental_state_label, path_to_mp3)``
    """
    # 1. Greeting guardrail: answer trivial greetings with a canned reply so
    #    the LLM never gets a chance to hallucinate on near-empty input.
    clean_text = user_text.lower().strip()
    if clean_text in {"hi", "hello", "hey", "good morning", "good evening", "hi there", "sup"}:
        bot_reply = "Hello there. I'm here to listen and support you. How are you feeling today?"
        return bot_reply, "Neutral", _synthesize_speech(bot_reply)

    # 2. Diagnose: map the classifier's raw LABEL_n id to a readable name.
    raw_label = state_classifier(user_text)[0]['label']
    label_map = {
        "LABEL_0": "General Statistics",
        "LABEL_1": "Negative Emotion",
        "LABEL_2": "Neutral",
        "LABEL_3": "Obsessive-Compulsive Disorder (OCD)",
        "LABEL_4": "Positive Emotion",
        "LABEL_5": "Post-Traumatic Stress Disorder (PTSD)",
        "LABEL_6": "Severe Stress and Anxiety"
    }
    # Fall back to the raw id if the checkpoint ever emits an unmapped label.
    mental_state = label_map.get(raw_label, raw_label)

    # 3. Prompt setup: system instructions forbid script-style output and
    #    inject the diagnosed state so the reply is tailored to it.
    messages = [
        {
            "role": "system",
            "content": (
                "You are a professional clinical therapist talking directly to a patient. "
                "Reply directly with your response. Do NOT write a script. Do NOT write [Patient] or [Therapist]. "
                "Be warm, concise, and helpful. "
                f"The patient's current diagnosed mental state is: {mental_state}."
            )
        }
    ]
    messages.extend(history)
    messages.append({"role": "user", "content": user_text})

    # 4. Generate: low temperature keeps the small model on-script while
    #    do_sample avoids fully greedy, repetitive output.
    prompt = chat_pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    outputs = chat_pipeline(prompt, max_new_tokens=150, temperature=0.1, do_sample=True, return_full_text=False)
    bot_reply = outputs[0]["generated_text"].strip()

    # 5. Strip any hallucinated screenplay tags the model may still emit.
    for tag in ("[Therapist's voice]", "[Therapist]", "[Patient]", "Therapist:", "Patient:"):
        bot_reply = bot_reply.replace(tag, "")
    bot_reply = bot_reply.strip()

    # 6. Voice the reply.
    return bot_reply, mental_state, _synthesize_speech(bot_reply)
def process_interaction(audio_filepath, text_input, history):
    """Handle one user turn arriving as audio, text, or both.

    Audio is preferred when present; if transcription comes back empty
    (e.g. a silent clip), we now fall back to the textbox instead of
    feeding an empty string into the classifier and LLM.

    Parameters
    ----------
    audio_filepath : str | None
        Path to the recorded/uploaded clip, or None.
    text_input : str | None
        Contents of the textbox, or None.
    history : list[dict] | None
        Accumulated chat as role/content dicts (None on first call).

    Returns
    -------
    tuple
        ``(chatbot history, diagnosis markdown, bot audio path,
        cleared audio input, cleared textbox, updated state)`` — the exact
        6-tuple wired to the Gradio outputs.
    """
    if history is None:
        history = []

    user_text = ""
    if audio_filepath is not None:
        user_text = asr_pipeline(audio_filepath)['text'].strip()
    # Fall through to the textbox when there was no audio OR the audio
    # transcribed to nothing — previously empty transcriptions were sent
    # straight to the models.
    if not user_text and text_input:
        user_text = text_input.strip()
    if not user_text:
        # Nothing usable: keep the UI in its waiting state.
        return history, "🧠 Current State Detected: **Awaiting Input...**", None, None, "", history

    bot_reply, state, bot_audio_path = generate_bot_response(user_text, history)

    # The Chatbot component consumes OpenAI-style role/content dicts directly.
    history.append({"role": "user", "content": user_text})
    history.append({"role": "assistant", "content": bot_reply})

    formatted_diagnosis = f"🧠 Current State Detected: **{state.upper()}**"
    # None clears the audio widget; "" clears the textbox.
    return history, formatted_diagnosis, bot_audio_path, None, "", history
# ==========================================
# 3. GRADIO UI DESIGN
# ==========================================
# Layout: a two-column Blocks app — left column holds the chat transcript
# plus text/voice inputs, right column shows the diagnosis and spoken reply.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("# 🧠 Empathetic AI Therapist")
    gr.Markdown("A private, local AI companion. Type a message or upload a voice note to begin.")
    # ONE single memory using Dictionaries
    # Server-side session state: the role/content dict history shared by
    # all three triggers below.
    history_state = gr.State([])
    with gr.Row():
        with gr.Column(scale=2):
            # No 'type' argument here so it doesn't crash older Gradio versions, but it accepts our dicts perfectly!
            chatbot = gr.Chatbot(label="Therapy Session", height=400)
            with gr.Row():
                text_in = gr.Textbox(show_label=False, placeholder="Type how you are feeling here...", scale=3)
                submit_btn = gr.Button("Send", variant="primary", scale=1)
            gr.Markdown("---")
            audio_in = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Voice Input")
        with gr.Column(scale=1):
            diagnosis_out = gr.Markdown("🧠 Current State Detected: **Awaiting Input...**")
            # autoplay so the therapist's voice reply starts without a click.
            audio_out = gr.Audio(label="Therapist Voice", autoplay=True)
            gr.Markdown("""
            ### System Info
            * **Voice Model:** Whisper-Tiny
            * **Diagnostician:** Custom DistilBERT
            * **Brain:** TinyLlama-1.1B (Guarded CPU Mode)
            """)
    # Triggers
    # All three events (button click, textbox Enter, audio change) run the
    # same handler with the same input/output wiring; the handler returns
    # None/"" to clear audio_in and text_in after each turn.
    submit_btn.click(
        fn=process_interaction,
        inputs=[audio_in, text_in, history_state],
        outputs=[chatbot, diagnosis_out, audio_out, audio_in, text_in, history_state]
    )
    text_in.submit(
        fn=process_interaction,
        inputs=[audio_in, text_in, history_state],
        outputs=[chatbot, diagnosis_out, audio_out, audio_in, text_in, history_state]
    )
    # NOTE(review): clearing audio_in from the handler's outputs fires this
    # .change trigger again; the handler's empty-input early return makes
    # that second call a no-op.
    audio_in.change(
        fn=process_interaction,
        inputs=[audio_in, text_in, history_state],
        outputs=[chatbot, diagnosis_out, audio_out, audio_in, text_in, history_state]
    )

if __name__ == "__main__":
    demo.launch()