phxdev committed on
Commit
c5ed134
·
verified ·
1 Parent(s): da21388

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +352 -51
app.py CHANGED
@@ -1,64 +1,365 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
 
 
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
 
 
 
 
 
9
 
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
 
 
 
 
 
 
 
 
 
 
25
 
26
- messages.append({"role": "user", "content": message})
 
 
27
 
28
- response = ""
 
 
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
 
39
- response += token
40
- yield response
41
 
 
 
 
42
 
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- if __name__ == "__main__":
64
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
import time
import random

# Silence noisy library warnings (transformers/peft print many during load).
warnings.filterwarnings('ignore')

# NOTE(review): `time` and `random` appear unused in this file — confirm before removing.

# Module-level model state, populated by load_model(); all None until loading succeeds.
model = None
tokenizer = None
device = None
15
 
16
def load_model():
    """Load the psychology-tuned model and tokenizer into module globals.

    Downloads the Qwen2.5-0.5B base model, applies the LoRA adapter from
    `phxdev/psychology-qwen-0.5b`, and merges it for plain inference.

    Returns:
        bool: True on success, False on any failure (the error is printed).
    """
    global model, tokenizer, device

    try:
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        print(f"Loading model on {device}...")

        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
        if tokenizer.pad_token is None:
            # Qwen ships without a pad token; reuse EOS so padding works.
            tokenizer.pad_token = tokenizer.eos_token

        base_model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen2.5-0.5B",
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
            trust_remote_code=True,
        )

        # Apply the LoRA adapter, then fold it into the base weights so
        # generation runs on a plain model (no PEFT wrapper overhead).
        model = PeftModel.from_pretrained(base_model, "phxdev/psychology-qwen-0.5b")
        model = model.merge_and_unload()

        if not use_cuda:
            # device_map was None on CPU, so move the merged model explicitly.
            model = model.to(device)

        print("✅ Model loaded successfully!")
        return True

    except Exception as e:
        # Boundary handler: report and signal failure instead of crashing the app.
        print(f"❌ Error loading model: {e}")
        return False
50
 
51
def generate_response(
    message: str,
    history: list,
    temperature: float = 0.8,
    max_tokens: int = 300,
    prompt_style: str = "Therapeutic"
) -> str:
    """Generate a psychology-focused response to *message*.

    Args:
        message: The user's latest message.
        history: Chat history (currently unused by the prompt; kept for the
            Gradio callback signature).
        temperature: Sampling temperature passed to `model.generate`.
        max_tokens: Maximum number of NEW tokens to generate.
        prompt_style: One of the keys in `prompt_templates`; unknown values
            fall back to "Therapeutic".

    Returns:
        The generated reply, or a user-facing warning/error string.
    """

    if model is None:
        return "⚠️ Model is still loading. Please wait a moment and try again."

    # Define prompt templates
    prompt_templates = {
        "Therapeutic": """You're a supportive therapist in session. The client just said: "{message}"

Respond with empathy and practical guidance. Start with validation, then give 2-3 specific strategies they can try:""",

        "Supportive Friend": """You're a caring friend who studied psychology. Someone you care about just told you: "{message}"

Give them warm, understanding support with practical advice:""",

        "Crisis Support": """You are a crisis counselor. Someone is reaching out for support. They said: "{message}"

Provide immediate, caring support and grounding techniques:""",

        "Anxiety Coach": """You're an anxiety specialist. Help them manage their anxiety with evidence-based techniques.

They're struggling with: "{message}"

Here are specific techniques they can try right now:""",

        "Mindfulness Guide": """You're a mindfulness teacher. Guide them toward present-moment awareness and self-compassion.

They shared: "{message}"

Offer mindful perspective and a gentle practice:"""
    }

    # Select and format prompt (unknown styles fall back to "Therapeutic")
    template = prompt_templates.get(prompt_style, prompt_templates["Therapeutic"])
    formatted_prompt = template.format(message=message)

    try:
        # Tokenize (truncate long prompts so prompt + generation stays in budget)
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(device)

        prompt_len = inputs.input_ids.shape[1]

        # Generate.
        # FIX: use max_new_tokens instead of max_length=prompt_len + max_tokens —
        # max_length counts the prompt and is the legacy way to cap generation;
        # max_new_tokens caps exactly the newly generated text.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # FIX: strip the prompt at the TOKEN level rather than slicing the
        # decoded string by len(formatted_prompt) — decoding does not always
        # reproduce the prompt byte-for-byte, which corrupted the slice.
        new_tokens = outputs[0][prompt_len:]
        generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        # Clean up response: drop a fully wrapping pair of quotes, if any
        if generated_text.startswith('"') and generated_text.endswith('"'):
            generated_text = generated_text[1:-1]

        return generated_text

    except Exception as e:
        return f"⚠️ Error generating response: {str(e)}"
128
 
129
def get_example_prompts():
    """Return curated example prompts, grouped by scenario.

    The dict maps a category label to exactly three starter messages;
    insertion order is preserved and drives the UI layout.
    """
    return {
        "Work Stress": [
            "I feel overwhelmed with my workload and don't know how to manage everything.",
            "My boss is very demanding and I'm afraid of disappointing them.",
            "I feel like I'm not good enough at my job compared to my colleagues."
        ],
        "Anxiety & Worry": [
            "I can't stop worrying about things that might go wrong.",
            "I have panic attacks before important meetings or presentations.",
            "My mind races at night and I can't fall asleep."
        ],
        "Relationships": [
            "I have trouble setting boundaries with people who take advantage of me.",
            "I feel lonely even when I'm around other people.",
            "I'm afraid of being rejected if I show my true self."
        ],
        "Self-Esteem": [
            "I constantly criticize myself and focus on my mistakes.",
            "I feel like everyone else has their life figured out except me.",
            "I'm afraid to try new things because I might fail."
        ],
        "Life Changes": [
            "I'm going through a major life transition and feel lost.",
            "I'm grieving the loss of someone important to me.",
            "I feel stuck in patterns that aren't serving me anymore."
        ]
    }
159
+
160
def create_interface():
    """Build and return the full Gradio Blocks UI.

    Wires the chat area, example-prompt buttons, and the settings panel
    (response style / temperature / length) to generate_response() via a
    local respond() callback. Returns the gr.Blocks app (not launched).
    """

    # Custom CSS for modern, calming design
    custom_css = """
    .main-header {
        text-align: center;
        padding: 2rem;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        border-radius: 10px;
        margin-bottom: 2rem;
    }
    .disclaimer-box {
        background-color: #f8f9fa;
        border-left: 4px solid #17a2b8;
        padding: 1rem;
        margin: 1rem 0;
        border-radius: 5px;
    }
    .example-category {
        margin: 0.5rem 0;
        padding: 0.5rem;
        background-color: #f1f3f4;
        border-radius: 5px;
    }
    """

    # Get example prompts
    examples = get_example_prompts()

    with gr.Blocks(css=custom_css, title="Psychology AI Assistant", theme=gr.themes.Soft()) as demo:

        # Header
        gr.HTML("""
        <div class="main-header">
            <h1>🧠 Psychology AI Assistant</h1>
            <p>A supportive AI trained in psychology and mental health</p>
            <p><em>Powered by fine-tuned Qwen 2.5-0.5B</em></p>
        </div>
        """)

        # Disclaimer
        gr.HTML("""
        <div class="disclaimer-box">
            <strong>⚠️ Important Disclaimer:</strong> This AI assistant provides supportive guidance based on psychological principles,
            but it is not a replacement for professional therapy or medical advice. If you're experiencing a mental health crisis,
            please contact a mental health professional or crisis hotline immediately.
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=3):
                # Main chat interface.
                # NOTE(review): avatar_images normally expects image filepaths/URLs —
                # confirm emoji strings actually render in the deployed Gradio version.
                chatbot = gr.Chatbot(
                    height=500,
                    placeholder="👋 Hi! I'm here to provide supportive guidance and practical strategies for mental wellness. What's on your mind today?",
                    avatar_images=("🧑‍💼", "🧠")
                )

                with gr.Row():
                    message_input = gr.Textbox(
                        placeholder="Share what's on your mind...",
                        container=False,
                        scale=4,
                        lines=2
                    )
                    send_btn = gr.Button("Send", variant="primary", scale=1)

                # Example prompts: one row of buttons per category; clicking a
                # button copies its prompt into the message box.
                gr.HTML("<h3>💡 Try these conversation starters:</h3>")

                for category, prompts in examples.items():
                    with gr.Row():
                        gr.HTML(f"<strong>{category}:</strong>")
                    with gr.Row():
                        for prompt in prompts:
                            example_btn = gr.Button(
                                prompt[:60] + "..." if len(prompt) > 60 else prompt,
                                size="sm",
                                variant="secondary"
                            )
                            # `x=prompt` binds the CURRENT prompt as a default,
                            # avoiding the classic late-binding closure bug.
                            example_btn.click(
                                lambda x=prompt: x,
                                outputs=[message_input]
                            )

            with gr.Column(scale=1):
                # Settings panel
                gr.HTML("<h3>⚙️ Settings</h3>")

                prompt_style = gr.Dropdown(
                    choices=["Therapeutic", "Supportive Friend", "Crisis Support", "Anxiety Coach", "Mindfulness Guide"],
                    value="Therapeutic",
                    label="Response Style",
                    info="Choose the type of support you prefer"
                )

                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.8,
                    step=0.1,
                    label="Creativity",
                    info="Higher = more creative responses"
                )

                max_tokens = gr.Slider(
                    minimum=100,
                    maximum=500,
                    value=300,
                    step=50,
                    label="Response Length",
                    info="Maximum response length"
                )

                clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")

                # Model info
                gr.HTML("""
                <div style="margin-top: 2rem; padding: 1rem; background-color: #f8f9fa; border-radius: 5px;">
                    <h4>📊 Model Info</h4>
                    <p><strong>Base Model:</strong> Qwen/Qwen2.5-0.5B</p>
                    <p><strong>Fine-tuned:</strong> phxdev/psychology-qwen-0.5b</p>
                    <p><strong>Specialization:</strong> Psychology & Mental Health</p>
                    <p><strong>Training:</strong> PEFT/LoRA</p>
                </div>
                """)

        # Chat functionality: append the user turn, generate the bot turn,
        # and clear the input box. Uses list-of-pairs history format.
        def respond(message, history, temp, max_tok, style):
            if not message.strip():
                # Ignore empty/whitespace submissions; keep history unchanged.
                return history, ""

            # Add user message
            history = history + [[message, None]]

            # Generate response
            bot_response = generate_response(message, history, temp, max_tok, style)

            # Add bot response
            history[-1][1] = bot_response

            return history, ""

        def clear_chat():
            # Reset both the chat log and the input box.
            return [], ""

        # Event handlers: button click and textbox Enter share one callback.
        send_btn.click(
            respond,
            inputs=[message_input, chatbot, temperature, max_tokens, prompt_style],
            outputs=[chatbot, message_input]
        )

        message_input.submit(
            respond,
            inputs=[message_input, chatbot, temperature, max_tokens, prompt_style],
            outputs=[chatbot, message_input]
        )

        clear_btn.click(clear_chat, outputs=[chatbot, message_input])

        # Footer
        gr.HTML("""
        <div style="text-align: center; margin-top: 2rem; padding: 1rem; color: #666;">
            <p>Built with ❤️ using Gradio • Fine-tuned by @phxdev</p>
            <p>If you're in crisis, please reach out: <a href="https://988lifeline.org/" target="_blank">988 Suicide & Crisis Lifeline</a></p>
        </div>
        """)

    return demo
332
+
333
# Initialize the model at import time so hosting platforms (e.g. HF Spaces)
# find a module-level `demo` to serve.
print("🚀 Starting Psychology AI Assistant...")
model_loaded = load_model()

if model_loaded:
    # FIX: removed a duplicate "✅ Model loaded successfully!" print here —
    # load_model() already prints it on success.
    demo = create_interface()
else:
    print("❌ Failed to load model. Creating error interface...")

    def create_error_interface():
        """Minimal fallback UI shown when the model could not be loaded."""
        with gr.Blocks() as error_demo:
            gr.HTML("""
            <div style="text-align: center; padding: 2rem;">
                <h1>⚠️ Model Loading Error</h1>
                <p>Sorry, there was an issue loading the psychology model.</p>
                <p>Please try refreshing the page or contact support.</p>
            </div>
            """)
        return error_demo

    demo = create_error_interface()

# FIX: consolidated the two duplicated `if __name__ == "__main__":` guards
# into one launch path; both branches now share the same server settings.
if __name__ == "__main__":
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )