# Hugging Face Space: Empathetic AI Therapist (voice + text demo app)
| import gradio as gr | |
| from transformers import pipeline | |
| from gtts import gTTS | |
| import tempfile | |
# ==========================================
# 1. LOAD MODELS (100% LOCAL FOR CPU)
# ==========================================
print("Booting up AI Therapist Pipeline...")

# The fine-tuned state classifier lives in the local repo; bind its path once
# so model and tokenizer are guaranteed to load from the same directory.
_BRAIN_PATH = "models/therapy_brain_V1"

asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
state_classifier = pipeline("text-classification", model=_BRAIN_PATH, tokenizer=_BRAIN_PATH)
chat_pipeline = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

print("All models loaded successfully!")
| # ========================================== | |
| # 2. CORE LOGIC | |
| # ========================================== | |
# Exact-match greeting guardrail: these inputs skip the LLM entirely
# (zero hallucination risk for trivial openers).
_GREETINGS = frozenset(
    ["hi", "hello", "hey", "good morning", "good evening", "hi there", "sup"]
)

# Maps the classifier's raw label ids to human-readable mental-state names.
_LABEL_MAP = {
    "LABEL_0": "General Statistics",
    "LABEL_1": "Negative Emotion",
    "LABEL_2": "Neutral",
    "LABEL_3": "Obsessive-Compulsive Disorder (OCD)",
    "LABEL_4": "Positive Emotion",
    "LABEL_5": "Post-Traumatic Stress Disorder (PTSD)",
    "LABEL_6": "Severe Stress and Anxiety",
}

# Tags TinyLlama sometimes hallucinates when it tries to write a script
# instead of replying directly; all are stripped from the final reply.
_SCRIPT_TAGS = ("[Therapist's voice]", "[Therapist]", "[Patient]", "Therapist:", "Patient:")


def _synthesize_speech(text):
    """Render *text* to speech with gTTS and return the MP3 file path.

    The temp file is created with delete=False on purpose: Gradio reads the
    file after this function returns, so it must outlive the handle.
    """
    tts = gTTS(text=text, lang='en', slow=False)
    temp_audio = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
    tts.save(temp_audio.name)
    return temp_audio.name


def generate_bot_response(user_text, history):
    """Produce one therapist turn for *user_text*.

    Args:
        user_text: The user's transcribed or typed message.
        history: Prior chat turns as role/content dicts (fed to the LLM).

    Returns:
        (bot_reply, mental_state, audio_path) — the reply text, the
        human-readable diagnosed state, and the path to its spoken MP3.
    """
    # 1. THE GREETING GUARDRAIL (Zero Hallucinations for "hi")
    clean_text = user_text.lower().strip()
    if clean_text in _GREETINGS:
        bot_reply = "Hello there. I'm here to listen and support you. How are you feeling today?"
        return bot_reply, "Neutral", _synthesize_speech(bot_reply)

    # 2. Diagnose the user's mental state from this single utterance.
    pred = state_classifier(user_text)[0]
    # Fall back to the raw label if the classifier emits an unmapped id.
    mental_state = _LABEL_MAP.get(pred['label'], pred['label'])

    # 3. Prompt setup: system instructions + prior turns + current message.
    messages = [
        {
            "role": "system",
            "content": (
                "You are a professional clinical therapist talking directly to a patient. "
                "Reply directly with your response. Do NOT write a script. Do NOT write [Patient] or [Therapist]. "
                "Be warm, concise, and helpful. "
                f"The patient's current diagnosed mental state is: {mental_state}."
            ),
        }
    ]
    messages.extend(history)
    messages.append({"role": "user", "content": user_text})

    # 4. Generate (low temperature keeps the small model on-script).
    prompt = chat_pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    outputs = chat_pipeline(prompt, max_new_tokens=150, temperature=0.1, do_sample=True, return_full_text=False)
    bot_reply = outputs[0]["generated_text"].strip()

    # 5. Clean hallucinated tags (destroys any script-writing attempts).
    for tag in _SCRIPT_TAGS:
        bot_reply = bot_reply.replace(tag, "")
    bot_reply = bot_reply.strip()

    # 6. Audio
    return bot_reply, mental_state, _synthesize_speech(bot_reply)
def process_interaction(audio_filepath, text_input, history):
    """Gradio callback: turn one user input (voice or text) into a bot turn.

    Args:
        audio_filepath: Path to a recorded/uploaded audio file, or None.
        text_input: Typed message, or None/empty.
        history: Chat history as role/content dicts (Gradio State), or None.

    Returns:
        A 6-tuple matching the wired outputs: (chatbot messages, diagnosis
        markdown, bot audio path, cleared audio input, cleared text input,
        updated history state).
    """
    if history is None:
        history = []

    # Prefer the voice note; fall back to typed text.
    if audio_filepath is not None:
        user_text = asr_pipeline(audio_filepath)['text'].strip()
    elif text_input and text_input.strip():
        user_text = text_input.strip()
    else:
        # Nothing to process (e.g. the audio widget was cleared programmatically).
        return history, "🧠 Current State Detected: **Awaiting Input...**", None, None, "", history

    # Get Response
    bot_reply, state, bot_audio_path = generate_bot_response(user_text, history)

    # Append in the role/content dict format the Chatbot component expects.
    history.append({"role": "user", "content": user_text})
    history.append({"role": "assistant", "content": bot_reply})

    formatted_diagnosis = f"🧠 Current State Detected: **{state.upper()}**"
    return history, formatted_diagnosis, bot_audio_path, None, "", history
# ==========================================
# 3. GRADIO UI DESIGN
# ==========================================
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("# 🧠 Empathetic AI Therapist")
    gr.Markdown("A private, local AI companion. Type a message or upload a voice note to begin.")

    # ONE single memory using Dictionaries
    history_state = gr.State([])

    with gr.Row():
        with gr.Column(scale=2):
            # No 'type' argument here so it doesn't crash older Gradio versions, but it accepts our dicts perfectly!
            chatbot = gr.Chatbot(label="Therapy Session", height=400)
            with gr.Row():
                text_in = gr.Textbox(show_label=False, placeholder="Type how you are feeling here...", scale=3)
                submit_btn = gr.Button("Send", variant="primary", scale=1)
            gr.Markdown("---")
            audio_in = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Voice Input")
        with gr.Column(scale=1):
            diagnosis_out = gr.Markdown("🧠 Current State Detected: **Awaiting Input...**")
            audio_out = gr.Audio(label="Therapist Voice", autoplay=True)
            gr.Markdown("""
            ### System Info
            * **Voice Model:** Whisper-Tiny
            * **Diagnostician:** Custom DistilBERT
            * **Brain:** TinyLlama-1.1B (Guarded CPU Mode)
            """)

    # Triggers: all three events share identical wiring, so declare the
    # fn/inputs/outputs once instead of repeating them per event.
    _wiring = dict(
        fn=process_interaction,
        inputs=[audio_in, text_in, history_state],
        outputs=[chatbot, diagnosis_out, audio_out, audio_in, text_in, history_state],
    )
    submit_btn.click(**_wiring)
    text_in.submit(**_wiring)
    audio_in.change(**_wiring)

if __name__ == "__main__":
    demo.launch()