|
|
import gradio as gr |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import torch |
|
|
|
|
|
class MentalHealthChatbot:
    """Culturally sensitive mental-health support chatbot backed by a causal LM.

    Wraps a Hugging Face causal language model (default: a Gemma-2-2b-it
    fine-tune) and its tokenizer, and exposes a generator-style ``respond``
    method compatible with ``gr.ChatInterface``.
    """

    def __init__(self, model_path="hemhemoh/Gemma-2-2b-it-wazobia-wellness-bot"):
        """Load model and tokenizer from *model_path* (hub ID or local dir)."""
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Some checkpoints ship without a pad token; the padding=True call in
        # respond() would then raise. Fall back to the EOS token in that case.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # System prompt prepended to every conversation turn.
        self.predefined_instruction = (
            "CORE THERAPEUTIC GUIDELINES:\n"
            "1. Professional Role: You are an advanced AI mental health support assistant, "
            "trained to provide compassionate, culturally sensitive, and professional psychological support.\n"
            "2. Communication Principles:\n"
            "   - Listen actively and empathetically\n"
            "   - Validate the user's emotions without judgment\n"
            "   - Provide support that is culturally nuanced (English, Yoruba, Igbo, Hausa)\n"
            "   - Maintain professional boundaries\n"
            "3. Response Strategy:\n"
            "   - Use a warm, supportive, and non-threatening tone\n"
            "   - Ask open-ended, reflective questions\n"
            "   - Offer practical coping strategies\n"
            "   - Avoid direct medical diagnosis\n"
            "4. Safety Protocol:\n"
            "   - If user expresses thoughts of self-harm or suicide, respond with immediate compassion and provide crisis resource information\n"
            "   - Never minimize the user's feelings\n"
            "   - Encourage professional help when issues seem complex\n"
            "5. Linguistic Flexibility:\n"
            "   - Respond in the language of the user\n"
            "   - Use culturally appropriate language and metaphors\n"
        )

    def prepare_prompt(self, message, history):
        """Build the full LM prompt from the system instruction, chat history,
        and the current user message.

        *history* is a list of ``(user_input, assistant_response)`` pairs, as
        provided by Gradio's tuple-format chat history; either element of a
        pair may be falsy and is then skipped.
        """
        conversation_context = ""
        for user_input, assistant_response in history:
            if user_input:
                conversation_context += f"User: {user_input}\n"
            if assistant_response:
                conversation_context += f"Assistant: {assistant_response}\n"

        full_prompt = (
            f"{self.predefined_instruction}\n"
            f"CONVERSATION HISTORY:\n{conversation_context}\n"
            f"CURRENT USER MESSAGE:\n{message}\n"
            f"ASSISTANT'S COMPASSIONATE RESPONSE:"
        )

        return full_prompt

    def respond(
        self,
        message,
        history,
        max_tokens=512,
        temperature=0.2,
        top_p=0.5
    ):
        """Generate a reply to *message* given *history*.

        Yields a single response string (generator form, as expected by
        ``gr.ChatInterface`` streaming callbacks).

        max_tokens  -- budget of NEW tokens to generate (see fix note below).
        temperature -- sampling temperature.
        top_p       -- nucleus-sampling threshold.
        """
        full_prompt = self.prepare_prompt(message, history)

        inputs = self.tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)

        outputs = self.model.generate(
            **inputs,
            # FIX: the original passed max_length=max(512, input_len + max_tokens)
            # and was missing the trailing comma (a SyntaxError). max_length caps
            # TOTAL length, conflating prompt size with the generation budget;
            # max_new_tokens matches the UI's "Max new tokens" slider semantics.
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            no_repeat_ngram_size=3,
            do_sample=True
        )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # generate() echoes the prompt; keep only the text after the marker.
        if "ASSISTANT'S COMPASSIONATE RESPONSE:" in response:
            response = response.split("ASSISTANT'S COMPASSIONATE RESPONSE:")[-1].strip()

        yield response
|
|
|
|
|
def main():
    """Build and launch the Gradio chat UI for the wellness bot."""
    bot = MentalHealthChatbot()

    # Sliders surface the generation knobs of MentalHealthChatbot.respond.
    generation_controls = [
        gr.Slider(minimum=1, maximum=512, value=170, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.2, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.05, label="Top-p (nucleus sampling)"),
    ]

    gr.ChatInterface(
        bot.respond,
        additional_inputs=generation_controls,
        title="Mental Health Support Chatbot",
        description="An AI assistant providing compassionate, culturally sensitive mental health support.",
    ).launch()
|
|
|
|
|
# Launch the app only when executed as a script (not on import).
if __name__ == "__main__":


    main()