# app.py — PsychoCounsel-Llama3-8B Gradio research demo (Hugging Face Space).
# (Removed pasted Spaces page chrome: "Spaces: Runtime error".)
# app.py (Gradio – Original / Research Demo)
import re
from functools import lru_cache

import torch

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_ID = "Psychotherapy-LLM/PsychoCounsel-Llama3-8B"


# -----------------------------
# Load model once (cached)
# -----------------------------
@lru_cache(maxsize=1)
def get_model():
    """
    Load PsychoCounsel-Llama3-8B on GPU (ZeroGPU) with device_map='auto'.

    Called lazily on the first request and then memoized via ``lru_cache``
    so the 8B checkpoint is loaded exactly once per process.  (The cache
    decorator was previously missing even though the docstring promised
    caching and ``lru_cache`` was imported — every call reloaded the model.)

    Returns
    -------
    (tokenizer, model) : the HF tokenizer and causal-LM for ``MODEL_ID``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # No bitsandbytes here: ZeroGPU gives you a GPU so we let Transformers
    # place layers automatically with device_map="auto".
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    return tokenizer, model
# -----------------------------
# Core generation logic
# -----------------------------
def build_original_prompt(
    client_text: str,
    context: str,
    mode: str,
    max_chars: int = 2000,
) -> str:
    """
    Assemble the raw completion prompt sent to the model.

    Parameters
    ----------
    client_text : the client's speech / vignette; ``None`` is treated as "".
    context : optional note about the therapist's stance; appended to the
        instruction when non-empty.
    mode : the exact string "Brief (5–7 sentences)" selects the short
        instruction; any other value selects the extended Appendix-style one.
    max_chars : hard cap on the vignette length so extremely long inputs
        don't explode cost/time.  Defaults to 2000, the previously
        hard-coded limit, so existing callers are unaffected.

    Returns
    -------
    The full prompt string, ending with "Therapist:" so the model continues
    in the therapist's voice.
    """
    client_text = (client_text or "").strip()
    context = (context or "").strip()
    if len(client_text) > max_chars:
        client_text = client_text[:max_chars] + " [...]"
    if mode == "Brief (5–7 sentences)":
        instruction = (
            "You are a professional psychotherapist conducting a session with a client. "
            "Write 5–7 sentences in a warm, empathic, reflective tone, similar to the "
            "PsychoCounsel-Llama3-8B Appendix case studies. You may ask some open-ended "
            "questions and use gentle cognitive and reflective exploration. "
            "Only output what the therapist says to the client."
        )
    else:
        instruction = (
            "You are a professional psychotherapist conducting a session with a client. "
            "Generate a detailed, multi-paragraph therapeutic response in the tone and "
            "structure of the Appendix case study for PsychoCounsel-Llama3-8B. Start with "
            "validation and normalization, explore fears and beliefs, reflect on self-trust "
            "and values, consider introducing a simple exercise, and close by inviting the "
            "client to share what resonates. Only output what the therapist says."
        )
    if context:
        instruction += " Consider this additional context about the therapist's stance: " + context
    prompt = f"""{instruction}
Client Speech:
{client_text}
Therapist:
"""
    return prompt
def generate_response(
    client_speech: str,
    therapist_context: str,
    mode: str,
    temperature: float,
    top_p: float,
) -> str:
    """
    Generate a therapist-style reply for the given client speech.

    Parameters
    ----------
    client_speech : the client's vignette; an empty/blank value short-circuits
        with a user-facing hint string instead of calling the model.
    therapist_context : optional stance note, forwarded into the prompt.
    mode : response style label from the UI radio; "Brief (5–7 sentences)"
        lowers the token budget and trims to at most 7 sentences.
    temperature, top_p : sampling parameters passed straight to ``generate``.

    Returns
    -------
    The decoded, lightly cleaned therapist response.
    """
    if not client_speech or not client_speech.strip():
        return "Please enter some client speech."
    tokenizer, model = get_model()
    prompt = build_original_prompt(client_speech, therapist_context, mode)
    # Tokenize on the model's device so inputs match the auto-placed layers.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Keep generation lengths moderate to avoid timeouts.
    max_tokens = 140 if mode == "Brief (5–7 sentences)" else 260
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,  # use sampling for some variability
            eos_token_id=tokenizer.eos_token_id,
            # Llama-3 tokenizers ship without a pad token; defaulting it to
            # EOS avoids the per-call "Setting pad_token_id" warning/fallback.
            pad_token_id=(
                tokenizer.pad_token_id
                if tokenizer.pad_token_id is not None
                else tokenizer.eos_token_id
            ),
        )
    # Drop the prompt tokens; keep only the newly generated continuation.
    generated = outputs[0][inputs["input_ids"].shape[1]:]
    raw = tokenizer.decode(generated, skip_special_tokens=True)
    # Light cleanup of known artifacts.
    clean = raw.split("Note:")[0].split("FINAL ANSWER")[0].strip()
    if mode == "Brief (5–7 sentences)":
        # Enforce the advertised length: keep at most 7 sentences.
        sents = re.split(r'(?<=[.!?])\s+', clean)
        sents = [s.strip() for s in sents if s.strip()]
        clean = " ".join(sents[:7])
    return clean
# -----------------------------
# Gradio UI
# -----------------------------
# Top-of-page markdown; warns that this research build has no extra safety
# logic and is not a clinical tool.
DESCRIPTION = """
This app uses **Psychotherapy-LLM/PsychoCounsel-Llama3-8B** in a style similar to the paper's Appendix case studies.
> ⚠️ **Important:** This version does *not* include additional safety logic for paranoia / harm content.
> It is intended for research, benchmarking, and model analysis by professionals.
> It is **not** a standalone clinical tool, nor a substitute for real-world psychiatric or psychological care.
"""

# Pre-filled example vignette shown in the "Client Speech" textbox.
default_example = (
    "Anxiety often strikes when I’m faced with making decisions. The fear of making "
    "the wrong choice or disappointing others paralyzes me, leaving me stuck in indecision. "
    "I want to learn how to trust myself and make confident choices."
)
# Build the Blocks app: controls on the left, input/output on the right.
with gr.Blocks(title="PsychoCounsel-Llama3-8B — Original / Research Demo") as demo:
    gr.Markdown("# 🧠 PsychoCounsel-Llama3-8B — Original / Research Demo")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Left column: response style + sampling controls.
        with gr.Column(scale=1):
            mode = gr.Radio(
                ["Brief (5–7 sentences)", "Extended (Appendix-style)"],
                value="Brief (5–7 sentences)",
                label="Response Style",
            )
            # Slider bounds keep temperature >= 0.1 so do_sample never sees 0.
            temperature = gr.Slider(
                0.1, 1.0, value=0.6, step=0.05, label="Temperature"
            )
            top_p = gr.Slider(
                0.5, 1.0, value=0.9, step=0.05, label="Top-p"
            )
            gr.Markdown(
                "This version is for **research / replication** and may generate content "
                "that is not appropriate for direct use with vulnerable clients."
            )
        # Right column: client input, optional context, and the output pane.
        with gr.Column(scale=2):
            client_speech_box = gr.Textbox(
                label="Client Speech",
                value=default_example,
                lines=10,
                placeholder="Paste or type the client's speech / vignette here…",
            )
            therapist_context_box = gr.Textbox(
                label="Optional: Therapist context (e.g., modality, goals)",
                value="",
                lines=5,
            )
            generate_btn = gr.Button("Generate Therapist Response", variant="primary")
            output_box = gr.Markdown(label="Therapist Response (Model Output)")
    # Wire the button to the generation function; input order must match
    # generate_response's signature.
    generate_btn.click(
        fn=generate_response,
        inputs=[
            client_speech_box,
            therapist_context_box,
            mode,
            temperature,
            top_p,
        ],
        outputs=output_box,
    )

if __name__ == "__main__":
    demo.launch()