MedSSS-8B-PRM
Introduction
MedSSS-PRM is trained with a newly proposed soft dual-sided objective, designed to identify erroneous intermediate steps within a correct medical reasoning trajectory.
It assigns a float value in [0, 1] to every intermediate reasoning step produced by MedSSS-Policy.
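A list of per-step scores becomes actionable once you decide what counts as a suspect step. The sketch below is a hypothetical post-processing helper, not part of MedSSS: the name `first_suspect_step` and its 0.5 threshold are assumptions for illustration, and the scores are made-up numbers rather than real model output. It returns the index of the first step whose score falls below the threshold, or `None` if every step passes.

```python
from typing import List, Optional


def first_suspect_step(step_values: List[float], threshold: float = 0.5) -> Optional[int]:
    """Return the index of the first step scored below `threshold`, else None.

    Hypothetical helper; threshold=0.5 is an illustrative assumption,
    not a value recommended by the MedSSS authors.
    """
    for i, v in enumerate(step_values):
        if not 0.0 <= v <= 1.0:
            raise ValueError(f"step {i} has score {v}, expected a value in [0, 1]")
        if v < threshold:
            return i
    return None


# Made-up scores for a five-step trajectory (not real MedSSS-PRM output).
scores = [0.97, 0.91, 0.34, 0.88, 0.93]
print(first_suspect_step(scores))                 # 2
print(first_suspect_step(scores, threshold=0.2))  # None
```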
For more information, visit our GitHub repository: https://github.com/pixas/MedSSS.
Usage
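The PRM is released as a LoRA adapter rather than a full set of weights, which saves memory: only the small adapter has to be stored and loaded on top of the base model.
Because this LoRA adapter is built on pixas/MedSSS_Policy, you first need to make that base model available on your platform.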
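A back-of-the-envelope count shows why an adapter is small. LoRA represents the update to a d_out × d_in weight with two factors of shape (rank × d_in) and (d_out × rank), so it adds rank × (d_in + d_out) parameters per adapted matrix. The sketch below uses Llama-3.1-8B's hidden size of 4096 and a hypothetical rank of 16; which modules MedSSS-PRM actually adapts, and with what rank, is not stated here, so check the adapter's `adapter_config.json` for the real values.

```python
def lora_extra_params(d_in: int, d_out: int, rank: int) -> int:
    # LoRA factors: A has shape (rank x d_in), B has shape (d_out x rank).
    return rank * (d_in + d_out)


hidden = 4096  # hidden size of Llama-3.1-8B
rank = 16      # hypothetical rank, for illustration only

full = hidden * hidden                       # one full 4096 x 4096 projection
extra = lora_extra_params(hidden, hidden, rank)
print(full, extra, f"{extra / full:.2%}")    # 16777216 131072 0.78%
```

For a single 4096 × 4096 projection, the low-rank factors amount to under 1% of that matrix's parameters under this assumed rank.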
from itertools import chain

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel


def obtain_prm_value_for_single_pair(tokenizer, value_model, inputs, response):
    # `inputs` is the question; `response` is the step-wise answer generated by MedSSS-Policy,
    # with steps separated by "\n\nStep".
    messages = [
        {"role": "user", "content": inputs},
        {"role": "assistant", "content": response},
    ]
    # Prompt = chat-formatted user turn followed by the assistant generation header.
    prompt_text = tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True)
    # Split the response into steps and restore the "Step" prefix that split() removes.
    completions = ["Step" + completion if not completion.startswith("Step") else completion for completion in response.split("\n\nStep")]
    # Tokenize each step separately (with its trailing "\n\n") so step boundaries are known.
    completion_ids = [
        tokenizer(completion + "\n\n", add_special_tokens=False)['input_ids'] for completion in completions
    ]
    response_id = list(chain(*completion_ids))
    pre_response_id = tokenizer(prompt_text, add_special_tokens=False)['input_ids']
    input_ids = pre_response_id + response_id
    with torch.no_grad():
        model_outputs = value_model(input_ids=torch.tensor(input_ids).unsqueeze(0).to(value_model.device))
    # logits: [1, N, 2]; keep the probability of label 1 for every token -> [1, N]
    value = torch.softmax(model_outputs.logits, dim=-1)[..., 1]
    # Position of the last token of each step (its trailing "\n\n").
    completion_index = []
    for i, completion in enumerate(completion_ids):
        if i == 0:
            completion_index.append(len(completion) + len(pre_response_id) - 1)
        else:
            completion_index.append(completion_index[-1] + len(completion))
    # Cast to float32 first: .numpy() does not support bfloat16 tensors.
    step_value = value[0, completion_index].float().cpu().tolist()
    return step_value
base_model = AutoModelForTokenClassification.from_pretrained("pixas/MedSSS_Policy", torch_dtype="auto", device_map="auto")
model = PeftModel.from_pretrained(base_model, "pixas/MedSSS_PRM", torch_dtype="auto", device_map="auto")
model.eval()
tokenizer = AutoTokenizer.from_pretrained("pixas/MedSSS_PRM")

input_text = "How to stop a cough?"
step_wise_generation = "Step 0: Let's break down this problem step by step.\n\nStep 1: First [omitted]"
value = obtain_prm_value_for_single_pair(tokenizer, model, input_text, step_wise_generation)
print(value)  # one score in [0, 1] per step
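The index bookkeeping and the softmax at the end of `obtain_prm_value_for_single_pair` can be checked without loading the model. This pure-Python sketch mirrors that logic on toy data: `step_end_positions` and `step_scores` are hypothetical helper names, and the logits are made-up numbers, not MedSSS-PRM output. Like the code above, it assumes two labels per token and reads the probability of label 1 as the step score.

```python
import math
from typing import List, Sequence


def step_end_positions(prompt_len: int, step_lens: Sequence[int]) -> List[int]:
    # Index of the last token of each step in (prompt tokens + concatenated step tokens).
    positions, end = [], prompt_len - 1
    for n in step_lens:
        end += n
        positions.append(end)
    return positions


def step_scores(logits: Sequence[Sequence[float]], positions: Sequence[int]) -> List[float]:
    # Two-class softmax at each position, keeping the probability of label 1.
    scores = []
    for pos in positions:
        l0, l1 = logits[pos]
        m = max(l0, l1)  # subtract the max for numerical stability
        e0, e1 = math.exp(l0 - m), math.exp(l1 - m)
        scores.append(e1 / (e0 + e1))
    return scores


# Toy example: a 3-token prompt followed by steps of 2 and 4 tokens (9 tokens in total).
positions = step_end_positions(3, [2, 4])
print(positions)  # [4, 8]
toy_logits = [[0.0, 0.0]] * 9
toy_logits[8] = [0.0, math.log(3)]  # odds 3:1 in favour of label 1
print([round(s, 4) for s in step_scores(toy_logits, positions)])  # [0.5, 0.75]
```

Only the logits at step-end positions matter; every other token's classification is ignored.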
MedSSS-PRM uses "\n\nStep" to separate intermediate steps, so each step is scored by the token classification at its last token, i.e. just before the next "Step k: " or at the end of the sequence.
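The splitting rule can be tried on its own. This sketch reproduces the list comprehension from `obtain_prm_value_for_single_pair` as a standalone helper (`split_steps` is a hypothetical name) and prints each step as it would be tokenized, with the trailing "\n\n" appended. The example response is placeholder text.

```python
from typing import List


def split_steps(response: str) -> List[str]:
    # Split on the "\n\nStep" separator, then restore the "Step" prefix that split() strips.
    return [
        piece if piece.startswith("Step") else "Step" + piece
        for piece in response.split("\n\nStep")
    ]


response = "Step 0: Restate the question.\n\nStep 1: List possible causes.\n\nStep 2: Conclude."
for step in split_steps(response):
    print(repr(step + "\n\n"))
```

The first segment is expected to start with "Step 0:" already; if it does not, the rule glues "Step" directly onto it, so inputs should follow the "Step k: " format produced by MedSSS-Policy.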
Model tree for pixas/MedSSS_PRM
Base model: meta-llama/Llama-3.1-8B
Finetuned: meta-llama/Llama-3.1-8B-Instruct
Finetuned: pixas/MedSSS_Policy