# NOTE(review): the lines originally here were Hugging Face Spaces page
# residue from scraping ("Spaces:", "Runtime error", file size, commit hash,
# and a line-number gutter) — not part of the program. Converted to a comment
# so the file parses as Python.
# app.py (Gradio – Original / Research Demo)
import re
from functools import lru_cache
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_ID = "Psychotherapy-LLM/PsychoCounsel-Llama3-8B"
# -----------------------------
# Load model once (cached)
# -----------------------------
@lru_cache(maxsize=1)
def get_model():
    """Return the (tokenizer, model) pair for MODEL_ID, loading it on first use.

    The ``lru_cache`` decorator turns this into a lazy singleton: the heavy
    download/load work runs only when the first request arrives, and every
    later call returns the same cached pair.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    # No bitsandbytes quantization here: ZeroGPU provides a real GPU, so we
    # simply let Transformers place the bf16 weights with device_map="auto".
    lm = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    return tok, lm
# -----------------------------
# Core generation logic
# -----------------------------
def build_original_prompt(client_text: str, context: str, mode: str, max_chars: int = 2000) -> str:
    """Build the raw text prompt sent to the model.

    Args:
        client_text: The client's speech / vignette. Truncated to ``max_chars``
            characters (with a trailing ``" [...]"`` marker) so extremely long
            vignettes don't explode cost/time.
        context: Optional extra guidance about the therapist's stance; appended
            to the instruction only when non-empty.
        mode: ``"Brief (5–7 sentences)"`` selects the short instruction; any
            other value selects the extended Appendix-style instruction.
        max_chars: Hard cap on the client text length. Defaults to 2000, the
            value that was previously hard-coded.

    Returns:
        The complete prompt string, ending with a ``Therapist:`` line so the
        model continues in the therapist's voice.
    """
    client_text = (client_text or "").strip()
    context = (context or "").strip()
    # Hard cap length so extremely long vignettes don't explode cost/time.
    if len(client_text) > max_chars:
        client_text = client_text[:max_chars] + " [...]"
    if mode == "Brief (5–7 sentences)":
        instruction = (
            "You are a professional psychotherapist conducting a session with a client. "
            "Write 5–7 sentences in a warm, empathic, reflective tone, similar to the "
            "PsychoCounsel-Llama3-8B Appendix case studies. You may ask some open-ended "
            "questions and use gentle cognitive and reflective exploration. "
            "Only output what the therapist says to the client."
        )
    else:
        instruction = (
            "You are a professional psychotherapist conducting a session with a client. "
            "Generate a detailed, multi-paragraph therapeutic response in the tone and "
            "structure of the Appendix case study for PsychoCounsel-Llama3-8B. Start with "
            "validation and normalization, explore fears and beliefs, reflect on self-trust "
            "and values, consider introducing a simple exercise, and close by inviting the "
            "client to share what resonates. Only output what the therapist says."
        )
    if context:
        instruction += " Consider this additional context about the therapist's stance: " + context
    prompt = f"""{instruction}
Client Speech:
{client_text}
Therapist:
"""
    return prompt
def generate_response(
    client_speech: str,
    therapist_context: str,
    mode: str,
    temperature: float,
    top_p: float,
) -> str:
    """Generate a therapist reply for the given client speech.

    Args:
        client_speech: The client's vignette; a placeholder message is returned
            when it is empty/blank.
        therapist_context: Optional stance/modality notes folded into the prompt.
        mode: Response style; ``"Brief (5–7 sentences)"`` caps output at 140 new
            tokens and at most 7 sentences, anything else allows 260 tokens.
        temperature: Sampling temperature passed straight to ``generate``.
        top_p: Nucleus-sampling cutoff passed straight to ``generate``.

    Returns:
        The cleaned-up model completion (therapist text only).
    """
    if not client_speech or not client_speech.strip():
        return "Please enter some client speech."
    tokenizer, model = get_model()
    prompt = build_original_prompt(client_speech, therapist_context, mode)
    # Tokenize on the model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Keep generation lengths moderate to avoid timeouts.
    max_tokens = 140 if mode == "Brief (5–7 sentences)" else 260
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,  # use sampling for some variability
            eos_token_id=tokenizer.eos_token_id,
            # Fix: Llama-3 tokenizers define no pad token, so without an
            # explicit pad_token_id transformers warns and picks one itself.
            # Padding with EOS is the standard choice for decoder-only LMs.
            pad_token_id=tokenizer.eos_token_id,
        )
    # Drop the echoed prompt tokens; keep only the newly generated tail.
    generated = outputs[0][inputs["input_ids"].shape[1]:]
    raw = tokenizer.decode(generated, skip_special_tokens=True)
    # Light cleanup of known artifacts the model sometimes appends.
    clean = raw.split("Note:")[0].split("FINAL ANSWER")[0].strip()
    if mode == "Brief (5–7 sentences)":
        # Enforce the advertised sentence cap by splitting on terminal punctuation.
        sents = re.split(r'(?<=[.!?])\s+', clean)
        sents = [s.strip() for s in sents if s.strip()]
        clean = " ".join(sents[:7])
    return clean
# -----------------------------
# Gradio UI
# -----------------------------
# Markdown banner shown at the top of the app; warns that this research build
# carries no extra safety layer and is not a clinical tool.
DESCRIPTION = """
This app uses **Psychotherapy-LLM/PsychoCounsel-Llama3-8B** in a style similar to the paper's Appendix case studies.
> ⚠️ **Important:** This version does *not* include additional safety logic for paranoia / harm content.
> It is intended for research, benchmarking, and model analysis by professionals.
> It is **not** a standalone clinical tool, nor a substitute for real-world psychiatric or psychological care.
"""
# Pre-filled client vignette so the demo produces output with a single click.
default_example = (
    "Anxiety often strikes when I’m faced with making decisions. The fear of making "
    "the wrong choice or disappointing others paralyzes me, leaving me stuck in indecision. "
    "I want to learn how to trust myself and make confident choices."
)
# Assemble the Gradio UI: controls on the left, input/output on the right.
with gr.Blocks(title="PsychoCounsel-Llama3-8B — Original / Research Demo") as demo:
    gr.Markdown("# 🧠 PsychoCounsel-Llama3-8B — Original / Research Demo")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: response style and sampling controls.
        with gr.Column(scale=1):
            mode = gr.Radio(
                ["Brief (5–7 sentences)", "Extended (Appendix-style)"],
                value="Brief (5–7 sentences)",
                label="Response Style",
            )
            temperature = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="Temperature")
            top_p = gr.Slider(0.5, 1.0, value=0.9, step=0.05, label="Top-p")
            gr.Markdown(
                "This version is for **research / replication** and may generate content "
                "that is not appropriate for direct use with vulnerable clients."
            )

        # Right column: client input, optional context, trigger, and output.
        with gr.Column(scale=2):
            client_speech_box = gr.Textbox(
                label="Client Speech",
                value=default_example,
                lines=10,
                placeholder="Paste or type the client's speech / vignette here…",
            )
            therapist_context_box = gr.Textbox(
                label="Optional: Therapist context (e.g., modality, goals)",
                value="",
                lines=5,
            )
            generate_btn = gr.Button("Generate Therapist Response", variant="primary")
            output_box = gr.Markdown(label="Therapist Response (Model Output)")

    # Wire the button to the generation function.
    generate_btn.click(
        fn=generate_response,
        inputs=[client_speech_box, therapist_context_box, mode, temperature, top_p],
        outputs=output_box,
    )

if __name__ == "__main__":
    demo.launch()
# (scrape artifact "|" removed)