Graph2Counsel
Llama-3-8B-Instruct fine-tuned on the Graph2Counsel dataset for mental health counseling dialogue generation. Given a multi-turn counseling dialogue, the model generates the next counselor response to the client's most recent utterance.
Base model: meta-llama/Meta-Llama-3-8B-Instruct
Input: A system prompt with a fixed counselor instruction, followed by the dialogue history and the client profile.
Output: The next counselor turn in the therapeutic dialogue.
Training data: Graph2Counsel — a dataset of synthetic counseling sessions grounded in CPGs derived from real counseling sessions.
Fine-tuning method: QLoRA
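The training script is not part of this card; the snippet below is only an illustrative sketch of what a QLoRA setup looks like (4-bit quantized base model plus a PEFT LoRA adapter). All hyperparameters here (rank, alpha, target modules, NF4 settings) are assumptions for illustration, not the values used to train this model.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# 4-bit NF4 quantization of the frozen base weights (the "Q" in QLoRA).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    quantization_config=bnb_config,
    device_map="auto",
)

# Trainable low-rank adapters on top of the quantized model.
# r, lora_alpha, dropout, and target_modules are assumed values.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_config)
```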
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "UKPLab/Llama3-G2C"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Fixed counselor instruction used as the system prompt.
system_prompt = (
    "You are a professional counselor. Your task is to generate a natural, empathetic "
    "and therapeutic response to the client's most recent utterance while adhering to "
    "established psychological techniques. You are provided with the current dialogue "
    "history and the client profile. Please be mindful to only generate the counselor "
    "response for a single turn, and do not include extra text like \"here is the next "
    "counselor utterance\" or \"Here is a possible next utterance\" or anything mentioning "
    "or explaining the used technique."
)

# Dialogue history and client profile, formatted as expected by the model.
history = (
    "Counselor: What brings you in today?\n"
    "Client: I've been feeling really anxious at work lately. "
    "It usually happens when I have to give feedback. "
    "I worry my comments won't be taken seriously."
)
profile = (
    "Client is a 28-year-old graphic designer who overthinks interactions "
    "with colleagues and struggles to articulate her feelings in stressful situations."
)

user_content = f"Dialogue History:\n{history}\nClient Profile:\n{profile}"
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_content},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

with torch.no_grad():
    output = model.generate(
        input_ids,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

# Decode only the newly generated tokens, i.e. the counselor's next turn.
response = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
print(response)
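For longer sessions, the prompt construction above can be wrapped in a small helper. This is an illustrative sketch, not part of the released code; `build_messages` is a hypothetical name, and it simply mirrors the "Dialogue History" / "Client Profile" layout used in the example.

```python
def build_messages(system_prompt: str,
                   turns: list[tuple[str, str]],
                   profile: str) -> list[dict]:
    """Format a multi-turn session into chat messages.

    `turns` is a list of (speaker, utterance) pairs,
    e.g. ("Counselor", "What brings you in today?").
    Hypothetical helper; matches the prompt layout shown above.
    """
    history = "\n".join(f"{speaker}: {utterance}" for speaker, utterance in turns)
    user_content = f"Dialogue History:\n{history}\nClient Profile:\n{profile}"
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]
```

The returned list can be passed directly to `tokenizer.apply_chat_template` as in the example above.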
Please cite this model using:
@misc{PLACEHOLDER,
title = {TITLE},
author = {AUTHOR(S)},
year = {YEAR},
eprint = {ARXIV_ID},
archivePrefix = {arXiv},
primaryClass = {cs.CL},
url = {https://arxiv.org/abs/ARXIV_ID},
}
For questions or feedback, please contact: aishik.mandal@tu-darmstadt.de