razaali10 committed on
Commit
dc879ce
·
verified ·
1 Parent(s): 35d9d7e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -0
app.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py (Gradio – Original / Research Demo)

import re
from functools import lru_cache

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face Hub id of the counseling-tuned Llama-3 8B checkpoint used by this demo.
MODEL_ID = "Psychotherapy-LLM/PsychoCounsel-Llama3-8B"
13
# -----------------------------
# Load model once (cached)
# -----------------------------
@lru_cache(maxsize=1)
def get_model():
    """
    Lazily load and cache the (tokenizer, model) pair for MODEL_ID.

    The first request pays the download/initialization cost; because of
    lru_cache(maxsize=1), every later call returns the same pair.

    No bitsandbytes quantization here: ZeroGPU provides a GPU, so we let
    Transformers place the bf16 layers automatically via device_map="auto".
    """
    tok = AutoTokenizer.from_pretrained(MODEL_ID)

    lm = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,  # full-precision bf16 weights, no quantization
        device_map="auto",           # let Transformers assign layers to the GPU
    )

    return tok, lm
33
+
34
+
35
# -----------------------------
# Core generation logic
# -----------------------------
def build_original_prompt(client_text: str, context: str, mode: str) -> str:
    """
    Assemble the raw text prompt sent to the model.

    Parameters:
        client_text: the client's speech/vignette (None-safe, stripped).
        context: optional extra description of the therapist's stance.
        mode: response style; "Brief (5–7 sentences)" selects the short
            instruction, anything else the extended Appendix-style one.

    Returns:
        A single prompt string ending in "Therapist:\n" for completion.
    """
    text = (client_text or "").strip()
    stance = (context or "").strip()

    # Truncate extremely long vignettes so they don't explode cost/time.
    limit = 2000
    if len(text) > limit:
        text = text[:limit] + " [...]"

    if mode == "Brief (5–7 sentences)":
        instruction = (
            "You are a professional psychotherapist conducting a session with a client. "
            "Write 5–7 sentences in a warm, empathic, reflective tone, similar to the "
            "PsychoCounsel-Llama3-8B Appendix case studies. You may ask some open-ended "
            "questions and use gentle cognitive and reflective exploration. "
            "Only output what the therapist says to the client."
        )
    else:
        instruction = (
            "You are a professional psychotherapist conducting a session with a client. "
            "Generate a detailed, multi-paragraph therapeutic response in the tone and "
            "structure of the Appendix case study for PsychoCounsel-Llama3-8B. Start with "
            "validation and normalization, explore fears and beliefs, reflect on self-trust "
            "and values, consider introducing a simple exercise, and close by inviting the "
            "client to share what resonates. Only output what the therapist says."
        )

    if stance:
        instruction = (
            instruction
            + " Consider this additional context about the therapist's stance: "
            + stance
        )

    # Equivalent to the original triple-quoted f-string template.
    return (
        f"{instruction}\n"
        "\n"
        "Client Speech:\n"
        f"{text}\n"
        "\n"
        "Therapist:\n"
    )
76
+
77
+
78
def generate_response(
    client_speech: str,
    therapist_context: str,
    mode: str,
    temperature: float,
    top_p: float,
):
    """
    Generate one therapist turn from the model for the given client speech.

    Parameters:
        client_speech: the client's vignette; blank input short-circuits.
        therapist_context: optional stance/modality text woven into the prompt.
        mode: response style radio value (brief vs. extended).
        temperature, top_p: sampling parameters forwarded to model.generate.

    Returns:
        The cleaned model completion as a string, or a usage hint when the
        client speech is empty.
    """
    # Guard: nothing to respond to.
    if not client_speech or not client_speech.strip():
        return "Please enter some client speech."

    tokenizer, model = get_model()
    prompt = build_original_prompt(client_speech, therapist_context, mode)

    # Move token tensors onto whichever device the model was placed on.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Moderate generation budgets keep requests from timing out.
    is_brief = mode == "Brief (5–7 sentences)"
    max_tokens = 140 if is_brief else 260

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,  # sampling gives some variability between runs
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tail, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[1]
    raw = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # Trim known trailing artifacts the model sometimes appends.
    clean = raw.split("Note:")[0].split("FINAL ANSWER")[0].strip()

    if is_brief:
        # Enforce the advertised 7-sentence ceiling on brief responses.
        pieces = [s.strip() for s in re.split(r'(?<=[.!?])\s+', clean) if s.strip()]
        clean = " ".join(pieces[:7])

    return clean
123
+
124
+
125
# -----------------------------
# Gradio UI
# -----------------------------
# Markdown shown at the top of the app; warns that this demo ships no extra
# safety logic and is intended for research use by professionals only.
DESCRIPTION = """
This app uses **Psychotherapy-LLM/PsychoCounsel-Llama3-8B** in a style similar to the paper's Appendix case studies.

> ⚠️ **Important:** This version does *not* include additional safety logic for paranoia / harm content.
> It is intended for research, benchmarking, and model analysis by professionals.
> It is **not** a standalone clinical tool, nor a substitute for real-world psychiatric or psychological care.
"""

# Pre-filled client vignette so the demo produces output with a single click.
default_example = (
    "Anxiety often strikes when I’m faced with making decisions. The fear of making "
    "the wrong choice or disappointing others paralyzes me, leaving me stuck in indecision. "
    "I want to learn how to trust myself and make confident choices."
)
141
+
142
with gr.Blocks(title="PsychoCounsel-Llama3-8B — Original / Research Demo") as demo:
    gr.Markdown("# 🧠 PsychoCounsel-Llama3-8B — Original / Research Demo")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: generation controls.
        with gr.Column(scale=1):
            style_choice = gr.Radio(
                ["Brief (5–7 sentences)", "Extended (Appendix-style)"],
                value="Brief (5–7 sentences)",
                label="Response Style",
            )
            temp_slider = gr.Slider(
                0.1, 1.0, value=0.6, step=0.05, label="Temperature"
            )
            top_p_slider = gr.Slider(
                0.5, 1.0, value=0.9, step=0.05, label="Top-p"
            )
            gr.Markdown(
                "This version is for **research / replication** and may generate content "
                "that is not appropriate for direct use with vulnerable clients."
            )

        # Right column: free-text inputs and the model's reply.
        with gr.Column(scale=2):
            client_box = gr.Textbox(
                label="Client Speech",
                value=default_example,
                lines=10,
                placeholder="Paste or type the client's speech / vignette here…",
            )
            context_box = gr.Textbox(
                label="Optional: Therapist context (e.g., modality, goals)",
                value="",
                lines=5,
            )
            run_btn = gr.Button("Generate Therapist Response", variant="primary")
            response_md = gr.Markdown(label="Therapist Response (Model Output)")

    # Wire the button to the generation function.
    run_btn.click(
        fn=generate_response,
        inputs=[client_box, context_box, style_choice, temp_slider, top_p_slider],
        outputs=response_md,
    )


if __name__ == "__main__":
    demo.launch()