Pant0x commited on
Commit
6607367
·
verified ·
1 Parent(s): 6c7327f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -52
app.py CHANGED
@@ -3,21 +3,21 @@ from huggingface_hub import InferenceClient
3
  import random
4
  import re
5
  import torch
6
- from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
7
  import librosa
8
  from gtts import gTTS
9
  import tempfile
10
  import os
11
 
12
- # ===== Mental health keywords (EN + AR + transliterated AR)
 
13
  MENTAL_KEYWORDS = [
14
  "depression", "depressed", "anxiety", "anxious", "panic", "stress", "sad", "lonely",
15
- "trauma", "mental", "therapy", "therapist", "counselor", "mood", "overwhelmed",
16
- "anger", "fear", "worry", "self-esteem", "confidence", "motivation", "relationship",
17
- "cope", "coping", "relax", "calm", "sleep", "emotion", "feeling", "feel", "thoughts",
18
- "help", "life", "advice", "unmotivated", "lost", "hopeless", "tired", "burnout",
19
- "cry", "hurt", "love", "breakup", "friend", "family", "alone", "heartbroken",
20
- "scared", "fearful",
21
  # Transliterated Arabic
22
  "ana", "zahqan", "daye2", "ha2t", "mota3ab", "mota3eb", "za3lan", "malo", "khalni",
23
  "mash3or", "bakhaf", "w7ed", "msh 3aref", "mash fahem", "malish", "3ayez", "ayez",
@@ -31,7 +31,6 @@ OFF_TOPIC = [
31
  "recipe", "song", "music", "lyrics", "joke", "funny", "laugh", "code", "python",
32
  "program", "game", "food", "cook", "movie", "film", "series", "sport", "football",
33
  "instagram", "tiktok", "money", "business", "crypto", "ai", "computer",
34
- # Arabic
35
  "نكتة", "ضحك", "اغنية", "اغاني", "طبخ", "اكل", "فيلم", "مسلسل", "كورة", "رياضة",
36
  "بيزنس", "فلوس", "العاب", "لعبة", "كود", "برمجة", "ذكاء اصطناعي"
37
  ]
@@ -42,10 +41,11 @@ OFF_TOPIC_RESPONSES = [
42
  "Let's bring it back to your emotions — I'm here to help process stress or challenges.",
43
  ]
44
 
45
- # Detect Arabic text
46
  def contains_arabic(text: str) -> bool:
47
  return bool(re.search(r"[\u0600-\u06FF]", text))
48
 
 
49
  def is_mental_health_related(text: str) -> bool:
50
  text_lower = text.lower()
51
  if any(word in text_lower for word in OFF_TOPIC):
@@ -56,84 +56,82 @@ def is_mental_health_related(text: str) -> bool:
56
  return True
57
  return False
58
 
59
- # ===== Voice emotion detection
 
 
60
  voice_model_name = "Hatman/audio-emotion-detection"
61
  voice_model = Wav2Vec2ForSequenceClassification.from_pretrained(voice_model_name)
62
- voice_processor = Wav2Vec2Processor.from_pretrained(voice_model_name)
 
63
 
64
  def detect_voice_emotion(audio_file):
65
  audio, sr = librosa.load(audio_file, sr=16000)
66
- inputs = voice_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
67
  with torch.no_grad():
68
  logits = voice_model(**inputs).logits
69
  predicted_id = torch.argmax(logits, dim=-1).item()
70
  return voice_model.config.id2label[predicted_id]
71
 
72
- # ===== Chat function with mood, TTS, transcript
 
 
73
  def respond(message, history, system_message, max_tokens, temperature, top_p, hf_token: gr.OAuthToken, audio=None):
74
  transcript = []
75
  response_text = ""
76
 
 
77
  if audio:
78
  mood = detect_voice_emotion(audio)
79
  response_text += f"[Detected mood: {mood}] "
80
 
 
81
  if not is_mental_health_related(message):
82
- response_text += random.choice(OFF_TOPIC_RESPONSES)
83
- transcript.append(("User", message))
84
- transcript.append(("Bot", response_text))
85
  tts_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3").name
86
- tts = gTTS(response_text)
87
- tts.save(tts_file)
88
- return response_text, tts_file, transcript
89
 
 
 
90
  locked_system_message = (
91
  "You are a licensed mental health therapy assistant. "
92
  "You respond with empathy, emotional intelligence, and a therapeutic tone. "
93
- "Never answer questions unrelated to emotional or mental wellness."
94
  )
95
 
96
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
97
  messages = [{"role": "system", "content": locked_system_message}]
98
  messages.extend(history)
99
  messages.append({"role": "user", "content": message})
100
 
101
  for msg in client.chat_completion(messages, max_tokens=max_tokens, stream=True,
102
  temperature=temperature, top_p=top_p):
103
- choices = msg.choices
104
- token = ""
105
- if len(choices) and choices[0].delta.content:
106
- token = choices[0].delta.content
107
- response_text += token
108
 
109
- transcript.append(("User", message))
110
- transcript.append(("Bot", response_text))
111
 
112
  tts_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3").name
113
- tts = gTTS(response_text)
114
- tts.save(tts_file)
115
 
116
  return response_text, tts_file, transcript
117
 
118
- # ===== Gradio UI
 
 
119
  with gr.Blocks() as demo:
120
- gr.Markdown("## Mental Health Chatbot with Voice Mood Detection")
121
- with gr.Row():
122
- with gr.Column():
123
- chatbot = gr.ChatInterface(
124
- respond,
125
- type="messages",
126
- additional_inputs=[
127
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
128
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
129
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
130
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
131
- gr.Audio(label="Record your voice (optional)", type="filepath"),
132
- gr.OAuthToken(label="Hugging Face Token"),
133
- ],
134
- )
135
-
136
- # Output area for transcript
137
- transcript_box = gr.Textbox(label="Transcript (User & Bot)", interactive=False)
138
-
139
- demo.launch()
 
3
  import random
4
  import re
5
  import torch
6
+ from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
7
  import librosa
8
  from gtts import gTTS
9
  import tempfile
10
  import os
11
 
12
+ # ========== MENTAL HEALTH FILTERS ==========
13
+
14
  MENTAL_KEYWORDS = [
15
  "depression", "depressed", "anxiety", "anxious", "panic", "stress", "sad", "lonely",
16
+ "trauma", "mental", "therapy", "therapist", "mood", "overwhelmed", "anger", "fear",
17
+ "worry", "self-esteem", "confidence", "motivation", "relationship", "cope", "coping",
18
+ "relax", "calm", "sleep", "emotion", "feeling", "feel", "thoughts", "help", "life",
19
+ "advice", "unmotivated", "lost", "hopeless", "tired", "burnout", "cry", "hurt", "love",
20
+ "breakup", "friend", "family", "alone", "heartbroken", "scared", "fearful",
 
21
  # Transliterated Arabic
22
  "ana", "zahqan", "daye2", "ha2t", "mota3ab", "mota3eb", "za3lan", "malo", "khalni",
23
  "mash3or", "bakhaf", "w7ed", "msh 3aref", "mash fahem", "malish", "3ayez", "ayez",
 
31
  "recipe", "song", "music", "lyrics", "joke", "funny", "laugh", "code", "python",
32
  "program", "game", "food", "cook", "movie", "film", "series", "sport", "football",
33
  "instagram", "tiktok", "money", "business", "crypto", "ai", "computer",
 
34
  "نكتة", "ضحك", "اغنية", "اغاني", "طبخ", "اكل", "فيلم", "مسلسل", "كورة", "رياضة",
35
  "بيزنس", "فلوس", "العاب", "لعبة", "كود", "برمجة", "ذكاء اصطناعي"
36
  ]
 
41
  "Let's bring it back to your emotions — I'm here to help process stress or challenges.",
42
  ]
43
 
44
+
45
  def contains_arabic(text: str) -> bool:
46
  return bool(re.search(r"[\u0600-\u06FF]", text))
47
 
48
+
49
  def is_mental_health_related(text: str) -> bool:
50
  text_lower = text.lower()
51
  if any(word in text_lower for word in OFF_TOPIC):
 
56
  return True
57
  return False
58
 
59
+
60
+ # ========== EMOTION DETECTION MODEL ==========
61
+
62
  voice_model_name = "Hatman/audio-emotion-detection"
63
  voice_model = Wav2Vec2ForSequenceClassification.from_pretrained(voice_model_name)
64
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(voice_model_name)
65
+
66
 
67
  def detect_voice_emotion(audio_file):
68
  audio, sr = librosa.load(audio_file, sr=16000)
69
+ inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
70
  with torch.no_grad():
71
  logits = voice_model(**inputs).logits
72
  predicted_id = torch.argmax(logits, dim=-1).item()
73
  return voice_model.config.id2label[predicted_id]
74
 
75
+
76
+ # ========== RESPONSE LOGIC ==========
77
+
78
  def respond(message, history, system_message, max_tokens, temperature, top_p, hf_token: gr.OAuthToken, audio=None):
79
  transcript = []
80
  response_text = ""
81
 
82
+ # Mood detection from voice
83
  if audio:
84
  mood = detect_voice_emotion(audio)
85
  response_text += f"[Detected mood: {mood}] "
86
 
87
+ # Mental health filtering
88
  if not is_mental_health_related(message):
89
+ bot_reply = random.choice(OFF_TOPIC_RESPONSES)
90
+ transcript.extend([("User", message), ("Bot", bot_reply)])
 
91
  tts_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3").name
92
+ gTTS(bot_reply).save(tts_file)
93
+ return bot_reply, tts_file, transcript
 
94
 
95
+ # GPT-based mental health conversation
96
+ client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
97
  locked_system_message = (
98
  "You are a licensed mental health therapy assistant. "
99
  "You respond with empathy, emotional intelligence, and a therapeutic tone. "
100
+ "Never answer unrelated questions."
101
  )
102
 
 
103
  messages = [{"role": "system", "content": locked_system_message}]
104
  messages.extend(history)
105
  messages.append({"role": "user", "content": message})
106
 
107
  for msg in client.chat_completion(messages, max_tokens=max_tokens, stream=True,
108
  temperature=temperature, top_p=top_p):
109
+ if msg.choices and msg.choices[0].delta.content:
110
+ response_text += msg.choices[0].delta.content
 
 
 
111
 
112
+ transcript.extend([("User", message), ("Bot", response_text)])
 
113
 
114
  tts_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3").name
115
+ gTTS(response_text).save(tts_file)
 
116
 
117
  return response_text, tts_file, transcript
118
 
119
+
120
+ # ========== GRADIO UI ==========
121
+
122
  with gr.Blocks() as demo:
123
+ gr.Markdown("## 🧠 Mental Health Chatbot with Voice Mood Detection & TTS")
124
+ chatbot = gr.ChatInterface(
125
+ respond,
126
+ type="messages",
127
+ additional_inputs=[
128
+ gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
129
+ gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
130
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
131
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
132
+ gr.Audio(label="🎙️ Speak (optional)", type="filepath"),
133
+ gr.OAuthToken(label="Hugging Face Token"),
134
+ ],
135
+ )
136
+
137
+ demo.launch()