pixas committed on
Commit
4bd4fd1
·
verified ·
1 Parent(s): 08d8702

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +16 -19
README.md CHANGED
@@ -32,34 +32,31 @@ We build the PRM model as a LoRA adapter, which saves the memory to use it.
32
  As this LoRA adapter is built on `pixas/MedSSS_Policy`, you first need to prepare the base model on your platform.
33
 
34
  ```python
 
 
 
 
35
 
36
  def obtain_prm_value_for_single_pair(tokenizer, value_model, inputs, outputs):
37
  # `outputs` generated by the MedSSS-Policy
38
- response = outputs
39
- completions = [f"Step" + completion if not completion.startswith("Step") else completion for k, completion in enumerate(outputs.split("\n\nStep"))]
40
-
41
  messages = [
42
  {"role": "user", "content": inputs},
43
  {"role": "assistant", "content": response}
44
  ]
45
- input_text = tokenizer.apply_chat_template(messages, tokenize=False)
46
-
47
- response_begin_index = input_text.index(response)
48
 
49
- pre_response_input = input_text[:response_begin_index]
50
- after_response_input = input_text[response_begin_index + len(response):]
51
  completion_ids = [
52
  tokenizer(completion + "\n\n", add_special_tokens=False)['input_ids'] for completion in completions
53
  ]
54
-
55
  response_id = list(chain(*completion_ids))
56
- pre_response_id = tokenizer(pre_response_input, add_special_tokens=False)['input_ids']
57
- after_response_id = tokenizer(after_response_input, add_special_tokens=False)['input_ids']
58
 
59
-
60
- input_ids = pre_response_id + response_id + after_response_id
61
-
62
- value = value_model(input_ids=torch.tensor(input_ids).unsqueeze(0).to(value_model.device)) # [1, N]
63
 
64
  completion_index = []
65
  for i, completion in enumerate(completion_ids):
@@ -70,12 +67,12 @@ def obtain_prm_value_for_single_pair(tokenizer, value_model, inputs, outputs):
70
 
71
  step_value = value[0, completion_index].cpu().numpy().tolist()
72
  return step_value
73
- from transformers import AutoModelForTokenClassification, AutoTokenizer
74
- from peft import PeftModel
75
- base_model = AutoModelForTokenClassification.from_pretrained("meta-llama/Llama-3.1-8B-Instruct",torch_dtype="auto",device_map="auto")
76
  model = PeftModel.from_pretrained(base_model, "pixas/MedSSS_PRM", torch_dtype="auto", device_map="auto")
77
  tokenizer = AutoTokenizer.from_pretrained("pixas/MedSSS_PRM")
78
- steps
79
  input_text = "How to stop a cough?"
80
  step_wise_generation = "Step 0: Let's break down this problem step by step.\n\nStep 1: First [omitted]"
81
 
 
32
  As this LoRA adapter is built on `pixas/MedSSS_Policy`, you first need to prepare the base model on your platform.
33
 
34
  ```python
35
+ from itertools import chain
36
+ import torch
37
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
38
+ from peft import PeftModel
39
 
40
  def obtain_prm_value_for_single_pair(tokenizer, value_model, inputs, outputs):
41
  # `outputs` generated by the MedSSS-Policy
 
 
 
42
  messages = [
43
  {"role": "user", "content": inputs},
44
  {"role": "assistant", "content": response}
45
  ]
46
+
47
+ prompt_text = tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True)
48
+ completions = ["Step" + completion if not completion.startswith("Step") else completion for completion in response.split("\n\nStep")]
49
 
 
 
50
  completion_ids = [
51
  tokenizer(completion + "\n\n", add_special_tokens=False)['input_ids'] for completion in completions
52
  ]
 
53
  response_id = list(chain(*completion_ids))
54
+ pre_response_id = tokenizer(prompt_text, add_special_tokens=False)['input_ids']
 
55
 
56
+ input_ids = pre_response_id + response_id
57
+
58
+ outputs = value_model(input_ids=torch.tensor(input_ids).unsqueeze(0).to(value_model.device)) # [1, N]
59
+ value = torch.softmax(outputs[0], dim=-1)[..., 1]
60
 
61
  completion_index = []
62
  for i, completion in enumerate(completion_ids):
 
67
 
68
  step_value = value[0, completion_index].cpu().numpy().tolist()
69
  return step_value
70
+
71
+
72
+ base_model = AutoModelForTokenClassification.from_pretrained("pixas/MedSSS_Policy",torch_dtype="auto",device_map="auto")
73
  model = PeftModel.from_pretrained(base_model, "pixas/MedSSS_PRM", torch_dtype="auto", device_map="auto")
74
  tokenizer = AutoTokenizer.from_pretrained("pixas/MedSSS_PRM")
75
+
76
  input_text = "How to stop a cough?"
77
  step_wise_generation = "Step 0: Let's break down this problem step by step.\n\nStep 1: First [omitted]"
78