# app.py  (Gradio – Original / Research Demo)

import re
from functools import lru_cache

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
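
# On an actual ZeroGPU Space, the GPU-touching entry point usually also needs the
# `spaces` package and an `@spaces.GPU` decorator (here that would be
# generate_response). Omitted so this file also runs on a plain GPU box:
#
#   import spaces
#
#   @spaces.GPU
#   def generate_response(...):
#       ...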

MODEL_ID = "Psychotherapy-LLM/PsychoCounsel-Llama3-8B"


# -----------------------------
# Load model once (cached)
# -----------------------------
@lru_cache(maxsize=1)
def get_model():
    """

    Load PsychoCounsel-Llama3-8B in full precision on GPU (ZeroGPU) with device_map='auto'.

    This is called lazily the first time a request comes in and then cached.

    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
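    # Llama 3 tokenizers typically ship without a pad token; fall back to EOS so
    # generate() has a valid pad_token_id (single-sequence calls never actually pad).
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token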

    # No bitsandbytes here: ZeroGPU gives you a GPU so we let Transformers
    # place layers automatically with device_map="auto".
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    return tokenizer, model
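
# If you are not on ZeroGPU and VRAM is tight, a 4-bit bitsandbytes variant of the
# loader above is one option -- a sketch, not used by this demo:
#
#   from transformers import BitsAndBytesConfig
#
#   quant_config = BitsAndBytesConfig(
#       load_in_4bit=True,
#       bnb_4bit_compute_dtype=torch.bfloat16,
#   )
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_ID,
#       quantization_config=quant_config,
#       device_map="auto",
#   )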


# -----------------------------
# Core generation logic
# -----------------------------
def build_original_prompt(client_text: str, context: str, mode: str) -> str:
    client_text = (client_text or "").strip()
    context = (context or "").strip()

    # Hard cap length so extremely long vignettes don't explode cost/time
    MAX_CHARS = 2000
    if len(client_text) > MAX_CHARS:
        client_text = client_text[:MAX_CHARS] + " [...]"

    if mode == "Brief (5–7 sentences)":
        instruction = (
            "You are a professional psychotherapist conducting a session with a client. "
            "Write 5–7 sentences in a warm, empathic, reflective tone, similar to the "
            "PsychoCounsel-Llama3-8B Appendix case studies. You may ask some open-ended "
            "questions and use gentle cognitive and reflective exploration. "
            "Only output what the therapist says to the client."
        )
    else:
        instruction = (
            "You are a professional psychotherapist conducting a session with a client. "
            "Generate a detailed, multi-paragraph therapeutic response in the tone and "
            "structure of the Appendix case study for PsychoCounsel-Llama3-8B. Start with "
            "validation and normalization, explore fears and beliefs, reflect on self-trust "
            "and values, consider introducing a simple exercise, and close by inviting the "
            "client to share what resonates. Only output what the therapist says."
        )

    if context:
        instruction += " Consider this additional context about the therapist's stance: " + context

    prompt = f"""{instruction}



Client Speech:

{client_text}



Therapist:

"""
    return prompt
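
# If the checkpoint expects Llama 3 chat formatting rather than a raw completion
# prompt (an assumption; this demo follows the paper's plain-prompt style), the
# same content could be wrapped with the tokenizer's chat template instead:
#
#   messages = [
#       {"role": "system", "content": instruction},
#       {"role": "user", "content": client_text},
#   ]
#   prompt = tokenizer.apply_chat_template(
#       messages, tokenize=False, add_generation_prompt=True
#   )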


def generate_response(
    client_speech: str,
    therapist_context: str,
    mode: str,
    temperature: float,
    top_p: float,
):
    if not client_speech or not client_speech.strip():
        return "Please enter some client speech."

    tokenizer, model = get_model()

    prompt = build_original_prompt(client_speech, therapist_context, mode)

    # Tokenize on the model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Keep generation lengths moderate to avoid timeouts
    if mode == "Brief (5–7 sentences)":
        max_tokens = 140
    else:
        max_tokens = 260

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,                 # use sampling for some variability
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,  # set in get_model(); silences the missing-pad-token warning
        )

    generated = outputs[0][inputs["input_ids"].shape[1]:]
    raw = tokenizer.decode(generated, skip_special_tokens=True)

    # Light cleanup of known artifacts
    clean = raw.split("Note:")[0].split("FINAL ANSWER")[0].strip()

    if mode == "Brief (5–7 sentences)":
        sents = re.split(r'(?<=[.!?])\s+', clean)
        sents = [s.strip() for s in sents if s.strip()]
        clean = " ".join(sents[:7])

    return clean
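
# For a more responsive UI, generation could be streamed token-by-token instead of
# returned in one piece -- a sketch using transformers' TextIteratorStreamer,
# not wired into the UI below:
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#   Thread(target=model.generate,
#          kwargs=dict(**inputs, max_new_tokens=max_tokens, streamer=streamer)).start()
#   partial = ""
#   for chunk in streamer:
#       partial += chunk
#       yield partial   # Gradio renders each yield as a progressive update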


# -----------------------------
# Gradio UI
# -----------------------------
DESCRIPTION = """

This app uses **Psychotherapy-LLM/PsychoCounsel-Llama3-8B** in a style similar to the paper's Appendix case studies.



> ⚠️ **Important:** This version does *not* include additional safety logic for paranoia / harm content.

> It is intended for research, benchmarking, and model analysis by professionals.

> It is **not** a standalone clinical tool, nor a substitute for real-world psychiatric or psychological care.

"""

default_example = (
    "Anxiety often strikes when I’m faced with making decisions. The fear of making "
    "the wrong choice or disappointing others paralyzes me, leaving me stuck in indecision. "
    "I want to learn how to trust myself and make confident choices."
)

with gr.Blocks(title="PsychoCounsel-Llama3-8B — Original / Research Demo") as demo:
    gr.Markdown("# 🧠 PsychoCounsel-Llama3-8B — Original / Research Demo")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            mode = gr.Radio(
                ["Brief (5–7 sentences)", "Extended (Appendix-style)"],
                value="Brief (5–7 sentences)",
                label="Response Style",
            )
            temperature = gr.Slider(
                0.1, 1.0, value=0.6, step=0.05, label="Temperature"
            )
            top_p = gr.Slider(
                0.5, 1.0, value=0.9, step=0.05, label="Top-p"
            )
            gr.Markdown(
                "This version is for **research / replication** and may generate content "
                "that is not appropriate for direct use with vulnerable clients."
            )

        with gr.Column(scale=2):
            client_speech_box = gr.Textbox(
                label="Client Speech",
                value=default_example,
                lines=10,
                placeholder="Paste or type the client's speech / vignette here…",
            )
            therapist_context_box = gr.Textbox(
                label="Optional: Therapist context (e.g., modality, goals)",
                value="",
                lines=5,
            )
            generate_btn = gr.Button("Generate Therapist Response", variant="primary")
            output_box = gr.Markdown(label="Therapist Response (Model Output)")

    generate_btn.click(
        fn=generate_response,
        inputs=[
            client_speech_box,
            therapist_context_box,
            mode,
            temperature,
            top_p,
        ],
        outputs=output_box,
    )
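
# On Spaces it is common to call demo.queue() before launch() so concurrent GPU
# requests are serialized; launching directly also works for a single user.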

if __name__ == "__main__":
    demo.launch()