Therapyspace / src /streamlit_app.py
amberroohee's picture
Update src/streamlit_app.py
eb2394a verified
import streamlit as st
from gtts import gTTS
from io import BytesIO
import json
import datetime
import re
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np
# Page config
st.set_page_config(
page_title="شفیق - AI Mental Health Assistant",
page_icon="🧠",
layout="wide"
)
# ============== LOAD MENTAL HEALTH BERT MODEL ==============
@st.cache_resource
def load_mental_health_model():
"""Load the mental health diagnosis model"""
try:
# Primary model: Mental Health BERT
model_name = "mental/mental-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Create pipeline
classifier = pipeline(
"text-classification",
model=model,
tokenizer=tokenizer,
return_all_scores=True,
device=-1 # Use CPU (change to 0 if you have GPU)
)
return classifier, model.config.id2label
except Exception as e:
st.error(f"Model loading error: {e}")
return None, None
@st.cache_resource
def load_emotion_model():
"""Load emotion detection model"""
try:
emotion_classifier = pipeline(
"text-classification",
model="j-hartmann/emotion-english-distilroberta-base",
return_all_scores=True,
device=-1
)
return emotion_classifier
except:
return None
@st.cache_resource
def load_suicide_risk_model():
"""Load suicide risk detection model"""
try:
# Using a general classifier for risk assessment
risk_classifier = pipeline(
"text-classification",
model="distilbert-base-uncased-finetuned-sst-2-english",
device=-1
)
return risk_classifier
except:
return None
# Load models
with st.spinner("🔄 AI models loading... please wait"):
mental_health_classifier, id2label = load_mental_health_model()
emotion_classifier = load_emotion_model()
risk_classifier = load_suicide_risk_model()
# ============== CSS STYLING ==============
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Noto+Nastaliq+Urdu&display=swap');
.urdu-text {
font-family: 'Noto Nastaliq Urdu', 'Jameel Noori Nastaleeq', serif;
direction: rtl;
text-align: right;
line-height: 2.5;
font-size: 20px;
}
.diagnosis-box {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 20px;
border-radius: 15px;
margin: 10px 0;
}
.risk-high { border-left: 5px solid #ff4757; background: #ffebee; }
.risk-medium { border-left: 5px solid #ffa502; background: #fff3e0; }
.risk-low { border-left: 5px solid #2ed573; background: #e8f5e9; }
.metric-card {
background: white;
padding: 15px;
border-radius: 10px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
text-align: center;
}
.severity-critical { color: #ff4757; font-weight: bold; }
.severity-high { color: #ff6348; font-weight: bold; }
.severity-moderate { color: #ffa502; font-weight: bold; }
.severity-low { color: #2ed573; font-weight: bold; }
</style>
""", unsafe_allow_html=True)
# ============== MENTAL HEALTH ANALYSIS FUNCTIONS ==============
def analyze_mental_health(text):
"""
Use BERT model to analyze mental health conditions
Returns: dict with conditions, scores, and severity
"""
results = {
'primary_condition': 'unknown',
'confidence': 0.0,
'all_conditions': {},
'severity': 'low',
'risk_factors': [],
'recommendations': []
}
if mental_health_classifier is None:
return fallback_analysis(text)
try:
# Get predictions from BERT model
predictions = mental_health_classifier(text[:512]) # Limit text length
if predictions and len(predictions) > 0:
scores = predictions[0]
# Sort by score
sorted_scores = sorted(scores, key=lambda x: x['score'], reverse=True)
# Get primary condition
primary = sorted_scores[0]
results['primary_condition'] = primary['label']
results['confidence'] = round(primary['score'] * 100, 2)
# Store all conditions
for item in sorted_scores:
results['all_conditions'][item['label']] = round(item['score'] * 100, 2)
# Determine severity based on confidence and condition type
high_risk_conditions = ['suicidal', 'self-harm', 'severe-depression', 'psychosis']
medium_risk_conditions = ['depression', 'anxiety', 'ptsd', 'bipolar']
if results['primary_condition'].lower() in high_risk_conditions or results['confidence'] > 85:
results['severity'] = 'critical'
elif results['primary_condition'].lower() in medium_risk_conditions or results['confidence'] > 70:
results['severity'] = 'moderate-high'
elif results['confidence'] > 50:
results['severity'] = 'moderate'
else:
results['severity'] = 'low'
# Extract risk factors from text
results['risk_factors'] = extract_risk_factors(text)
# Generate recommendations
results['recommendations'] = generate_recommendations(
results['primary_condition'],
results['severity']
)
except Exception as e:
st.error(f"Analysis error: {e}")
return fallback_analysis(text)
return results
def fallback_analysis(text):
"""Fallback when BERT model fails"""
text_lower = text.lower()
# Keyword-based fallback
conditions = {
'depression': ['اداس', 'مایوس', 'udas', 'mayoos', 'hopeless', 'khamoshi', 'تنہا'],
'anxiety': ['پریشان', 'ghabrahat', 'tension', 'fikar', 'bechaini', 'گھبراہٹ'],
'ptsd': ['خوف', 'khoof', 'nightmare', 'flashback', 'حادثہ', 'trauma'],
'suicidal': ['خودکشی', 'mar jaun', 'موت', 'zehar', 'مرنا', 'khatam'],
'stress': ['tension', 'دباؤ', 'stress', 'bojh', 'بوجھ', 'pressure']
}
detected = {}
for condition, keywords in conditions.items():
score = sum(1 for kw in keywords if kw in text_lower)
if score > 0:
detected[condition] = min(score * 20, 100)
if not detected:
return {
'primary_condition': 'unknown',
'confidence': 0,
'all_conditions': {},
'severity': 'low',
'risk_factors': [],
'recommendations': ['general_support']
}
primary = max(detected, key=detected.get)
return {
'primary_condition': primary,
'confidence': detected[primary],
'all_conditions': detected,
'severity': 'moderate' if detected[primary] > 50 else 'low',
'risk_factors': extract_risk_factors(text),
'recommendations': generate_recommendations(primary, 'moderate')
}
def extract_risk_factors(text):
"""Extract specific risk factors from text"""
text_lower = text.lower()
factors = []
risk_indicators = {
'sleep_issues': ['نیند', 'neend', 'neend nahi', 'جاگنا', 'so nahi pa raha'],
'social_isolation': ['اکیلا', 'tanha', 'koi nahi', 'دور', 'alone'],
'substance_abuse': ['شراب', 'drugs', 'nasha', 'سیگریٹ', 'smoking'],
'self_harm_history': ['زخم', 'cutting', 'khud ko chot', 'خون'],
'family_history': ['ghar mein', 'والدین', 'maa baap', 'خاندان'],
'work_stress': ['نوکری', 'job', 'kaam', 'boss', 'office', 'پیسے']
}
for factor, keywords in risk_indicators.items():
if any(kw in text_lower for kw in keywords):
factors.append(factor)
return factors
def generate_recommendations(condition, severity):
"""Generate therapeutic recommendations"""
recommendations = {
'critical': [
"🚨 فوری پیشہ ورانہ مدد ضروری ہے",
"کسی قریبی ہسپتال یا کلینک جائیں",
"کسی قریبی رشتہ دار کو مطلع کریں",
"ہیلپ لائن 1122 پر کال کریں"
],
'moderate-high': [
"پیشہ ورانہ مدد مشورہ دینا چاہیے",
"نفسیاتی ماہر سے ملاقات کریں",
"مستقل مانیٹرنگ ضروری ہے",
"دواؤں پر غور کریں"
],
'moderate': [
"کاؤنسلنگ سے فائدہ ہوگا",
"مشقیں اور تھراپی جاری رکھیں",
"دوستوں سے بات کریں",
"ورزش اور冥思 کریں"
],
'low': [
"خود مدد کی تکنیکیں استعمال کریں",
"مثبت سرگرمیاں جاری رکھیں",
"ضرورت ہو تو کاؤنسلنگ کریں"
]
}
specific_recs = {
'depression': ["روزانہ شیڈول بنائیں", "کھیل کود میں حصہ لیں", "نیند درست کریں"],
'anxiety': ["گہری سانسیں لیں", "زمینی حقیقتوں پر توجہ دیں", "پیشہ ورانہ مدد لیں"],
'ptsd': ["ٹریما سے نمٹنے کی تربیت", "محفوظ ماحول بنائیں", "پیشہ ورانہ تھراپی"],
'suicidal': ["فوری ہسپتال جائیں", "کسی کو بتائیں", "ہتھیار دور رکھیں"]
}
base_recs = recommendations.get(severity, recommendations['low'])
specific = specific_recs.get(condition, [])
return base_recs + specific
def analyze_emotion_enhanced(text):
"""Enhanced emotion analysis using BERT"""
if emotion_classifier:
try:
results = emotion_classifier(text[:512])
if results:
emotions = {item['label']: item['score'] for item in results[0]}
dominant = max(emotions, key=emotions.get)
return dominant, emotions
except:
pass
# Fallback to keyword
return detect_emotion_fallback(text)
def detect_emotion_fallback(text):
"""Fallback emotion detection"""
text_lower = text.lower()
emotion_keywords = {
'sadness': ['اداس', 'rona', 'udas', 'غم', 'dukh', 'tanha'],
'fear': ['ڈر', 'khoof', 'ghabrahat', 'fear', 'خوف'],
'anger': ['غصہ', 'gussa', 'naraz', 'angry'],
'joy': ['خوش', 'khush', 'happy', 'khushi'],
'surprise': ['حیران', 'heran', ' shocked', 'واہ'],
'disgust': ['نفرت', 'nafrat', 'گھن', 'ghin']
}
scores = {}
for emotion, keywords in emotion_keywords.items():
scores[emotion] = sum(1 for kw in keywords if kw in text_lower)
if max(scores.values()) == 0:
return 'neutral', {'neutral': 1.0}
dominant = max(scores, key=scores.get)
total = sum(scores.values())
normalized = {k: v/total for k, v in scores.items()}
return dominant, normalized
def get_therapeutic_response_enhanced(mental_health_data, emotion, text):
"""Generate response based on BERT diagnosis"""
condition = mental_health_data['primary_condition']
severity = mental_health_data['severity']
confidence = mental_health_data['confidence']
# Crisis response for critical cases
if severity == 'critical':
return """
🚨 **اہم انتباہ / Critical Alert**
میں نوٹ کر رہا ہوں کہ آپ بہت پریشان ہیں۔ آپ کی زندگی قیمتی ہے۔
**فوری اقدامات:**
- 📞 ہیلپ لائن: 1122
- 🏥 قریبی ہسپتال جائیں
- 👨‍👩‍👧 کسی کو بتائیں
آپ اکیلے نہیں ہیں۔ مدد دستیاب ہے۔
"""
# Condition-specific responses
responses = {
'depression': [
f"میں سمجھتا ہوں آپ {condition} کا سامنا کر رہے ہیں ({confidence}% یقین)۔",
"یہ ایک طبی حالت ہے جو علاج سے ٹھیک ہو سکتی ہے۔",
"پیشہ ورانہ مدد سے آپ بہتر محسوس کریں گے۔"
],
'anxiety': [
f"آپ کے {condition} کی نشانیاں نظر آ رہی ہیں ({confidence}% یقین)۔",
"گہری سانسیں اور زمینی تکنیکیں مددگار ثابت ہو سکتی ہیں۔",
"یہ قابل علاج ہے، امید رکھیں۔"
],
'ptsd': [
f"ممکنہ طور پر {condition} کے اثرات ({confidence}% یقین)۔",
"ٹریما بہت گہرا ہوتا ہے، پیشہ ورانہ مدد ضروری ہے۔",
"آپ محفوظ ہیں، یہ احساس گزر جائے گا۔"
],
'stress': [
"زندگی کے دباؤ آپ پر بھاری ہو رہے ہیں۔",
"چھوٹے وقفے لیں، خود کو ترجیح دیں۔",
"تناؤ کا نظم سیکھنا ضروری ہے۔"
]
}
base_response = responses.get(condition, responses.get('stress', ["میں آپ کی مدد کرنا چاہتا ہوں۔"]))
# Add recommendations
recs = mental_health_data.get('recommendations', [])
if recs:
base_response.append("\n**سفارشات:**")
for i, rec in enumerate(recs[:3], 1):
base_response.append(f"{i}. {rec}")
return "\n\n".join(base_response)
def text_to_speech(text):
"""Convert text to Urdu speech"""
try:
# Clean text for TTS
clean_text = re.sub(r'[^\w\s\u0600-\u06FF]', ' ', text)
clean_text = clean_text[:500] # Limit length
tts = gTTS(text=clean_text, lang='ur', slow=False)
mp3 = BytesIO()
tts.write_to_fp(mp3)
mp3.seek(0)
return mp3
except:
return None
# ============== MAIN APP ==============
def main():
# Header
st.markdown('<h1 style="text-align: center; color: #667eea;">🧠 شفیق Pro</h1>',
unsafe_allow_html=True)
st.markdown('<h4 style="text-align: center; color: #666;">AI-Powered Mental Health Assistant</h4>',
unsafe_allow_html=True)
# Initialize session
if 'chat_history' not in st.session_state:
st.session_state.chat_history = []
st.session_state.diagnosis_history = []
# Sidebar - Diagnosis Dashboard
with st.sidebar:
st.header("📊 طبی تجزیہ / Medical Analysis")
if st.session_state.diagnosis_history:
latest = st.session_state.diagnosis_history[-1]
# Severity indicator
severity = latest['severity']
severity_class = f"severity-{severity.replace('-', '')}"
st.markdown(f"""
<div class="metric-card {severity_class}">
<h3>سنگینی / Severity</h3>
<h2>{severity.upper()}</h2>
</div>
""", unsafe_allow_html=True)
# Primary condition
st.write(f"**Primary Condition:** {latest['primary_condition']}")
st.write(f"**Confidence:** {latest['confidence']}%")
# Risk factors
if latest['risk_factors']:
st.write("**Risk Factors:**")
for factor in latest['risk_factors']:
st.write(f"- {factor}")
# History chart
if len(st.session_state.diagnosis_history) > 1:
st.write("**Trend:**")
conditions = [d['primary_condition'] for d in st.session_state.diagnosis_history]
st.bar_chart(pd.Series(conditions).value_counts())
st.markdown("---")
st.info("""
⚠️ **Disclaimer:** This AI provides preliminary screening only.
Not a substitute for professional psychiatric evaluation.
""")
if st.button("🗑️ New Session"):
st.session_state.chat_history = []
st.session_state.diagnosis_history = []
st.rerun()
# Main chat area
st.markdown("---")
# Display chat
for msg in st.session_state.chat_history:
if msg['role'] == 'user':
st.markdown(f"""
<div style="background: #f3e5f5; padding: 15px; border-radius: 15px;
text-align: right; margin: 10px 0;">
<strong>👤 You:</strong><br>
<span class="urdu-text">{msg['content']}</span>
</div>
""", unsafe_allow_html=True)
else:
st.markdown(f"""
<div class="diagnosis-box urdu-text">
<strong>🤖 شفیق:</strong><br>
{msg['content']}
<br><small>Diagnosis: {msg.get('diagnosis', 'N/A')} |
Confidence: {msg.get('confidence', 0)}%</small>
</div>
""", unsafe_allow_html=True)
if msg.get('audio'):
st.audio(msg['audio'], format='audio/mp3')
# Input
st.markdown("---")
col1, col2 = st.columns([4, 1])
with col1:
user_input = st.text_input("Message / پیغام...",
key="input",
placeholder="اپنے جذبات بیان کریں...")
with col2:
send = st.button("📤 Send", use_container_width=True)
if send and user_input:
with st.spinner("Analyzing with AI..."):
# Run BERT analysis
mental_health_data = analyze_mental_health(user_input)
# Get emotion
emotion, emotion_scores = analyze_emotion_enhanced(user_input)
# Generate response
response = get_therapeutic_response_enhanced(
mental_health_data, emotion, user_input
)
# Text to speech
audio = text_to_speech(response)
# Save to history
st.session_state.chat_history.append({
'role': 'user',
'content': user_input
})
st.session_state.chat_history.append({
'role': 'bot',
'content': response,
'diagnosis': mental_health_data['primary_condition'],
'confidence': mental_health_data['confidence'],
'audio': audio
})
st.session_state.diagnosis_history.append(mental_health_data)
st.rerun()
if __name__ == "__main__":
import pandas as pd
main()