# File size: 4,805 Bytes (revision 0ad9ab3)
import re
import torch
torch.cuda.empty_cache()
from typing import List
from copy import deepcopy
from swift.plugin import ORM, orms
from swift.utils import get_logger
logger = get_logger()
"""
Step 1: Define a Reward Class
    Implement your custom reward calculation logic within the __call__ method.
    The method accepts the model's output completions and dataset columns
    (passed as kwargs) as input parameters.

Step 2: Register the Reward Class in orms
    For example (Python):
        orms['external_math_acc'] = MathAccuracy

Step 3: Configure the Arguments
    Use the following arguments when running the script (shell):
        --plugin /path/to/plugin.py --reward_funcs external_math_acc
"""
def count_xml(text) -> float:
    """
    Score a response on its use of the <think> / </think> tag pair.

    Args:
        text: Response text to inspect
    Returns:
        0.5 for each of "<think>" and "</think>" that occurs exactly once,
        i.e. 0.0, 0.5 or 1.0 in total
    """
    score = 0.0
    for tag in ("<think>", "</think>"):
        # A tag only earns credit when it appears exactly once.
        if text.count(tag) == 1:
            score += 0.5
    return score
def extract_xml_answer(text: str) -> str:
    """
    Extract the answer that follows the closing </think> tag.

    Args:
        text: Input text, expected to contain a "</think>" tag
    Returns:
        The text between the first "</think>" and the next "</think>"
        (or the end of the text), stripped of surrounding whitespace.
        Returns "" when text is not a string or contains no "</think>".
    """
    # Explicit checks replace the previous bare `except:`, which also
    # swallowed KeyboardInterrupt / SystemExit.
    if not isinstance(text, str):
        return ""
    parts = text.split("</think>")
    if len(parts) < 2:
        # No closing tag: there is no answer section to extract.
        return ""
    return parts[1].strip()
def xmlcount_reward_func(completions, **kwargs) -> List[float]:
    """
    Score every completion on its <think> tag usage.

    Args:
        completions: Model completions; each item is handed to count_xml as-is
    Returns:
        One count_xml score per completion, in input order
    """
    # Completions are scored directly; no message-dict unwrapping is done here.
    return list(map(count_xml, completions))
def int_reward_func(completions, **kwargs) -> List[float]:
    """
    Reward completions whose extracted answer is purely direction tokens.

    Args:
        completions: Model completions
    Returns:
        1.0 for each completion whose answer (the part returned by
        extract_xml_answer) is a non-empty run of <|up|>, <|down|>,
        <|right|>, <|left|> tokens separated only by whitespace; 0.0 otherwise
    """
    allowed_tokens = {"<|up|>", "<|down|>", "<|right|>", "<|left|>"}

    def is_valid_sequence(seq):
        compact = re.sub(r'\s+', '', seq)
        if not compact:
            return False
        tokens = re.findall(r'<\|(?:up|down|right|left)\|>', compact)
        # The tokens must account for every character; any leftover text
        # between or around them makes the sequence invalid.
        if ''.join(tokens) != compact:
            return False
        return all(token in allowed_tokens for token in tokens)

    rewards = []
    for completion in completions:
        candidate = extract_xml_answer(completion)
        rewards.append(1.0 if is_valid_sequence(candidate) else 0.0)
    return rewards
def count_turns(steps):
    """
    Parse a move string and count its direction changes.

    Args:
        steps: Text containing moves written as <|name|> tokens
    Returns:
        Tuple (moves, turns): the token names in order of appearance, and
        the number of adjacent move pairs that differ
    """
    moves = re.findall(r"<\|(.*?)\|>", steps)
    turns = 0
    for previous, current in zip(moves, moves[1:]):
        if current != previous:
            turns += 1
    return moves, turns
def correctness_reward_func(completions, answer, **kwargs) -> List[float]:
    """
    Reward function that scores each answer against its ground truth.

    The answer section of every completion (see extract_xml_answer) is
    compared with the matching ground-truth string:
      * exact string match: len(moves) * 2 * (turns + 1)
      * otherwise:          k * (prefix_turns + 1), where k is the number of
        leading moves that agree with the ground truth and prefix_turns is
        the number of direction changes within those k moves.

    Args:
        completions: Model completions
        answer: Ground truth answers, paired with completions by position
    Returns:
        List of reward scores, one per (completion, answer) pair
    """
    responses = completions
    extracted_responses = [extract_xml_answer(r) for r in responses]

    # Log the first sample only when the batch is non-empty; indexing [0]
    # unconditionally raised IndexError on an empty batch.
    if len(extracted_responses) > 0 and len(answer) > 0:
        logger.debug('-' * 20)
        logger.debug("\nAnswer:\n%s", answer[0])
        logger.debug("\nResponse:\n%s", responses[0])
        logger.debug("\nExtracted:\n%s", extracted_responses[0])

    rewards = []
    for r, a in zip(extracted_responses, answer):
        r_steps, r_turns = count_turns(r)
        if r == a:
            reward = len(r_steps) * 2 * (r_turns + 1)
        else:
            a_steps, _ = count_turns(a)
            # Length of the longest common prefix of moves.
            k = 0
            for r_s, a_s in zip(r_steps, a_steps):
                if r_s != a_s:
                    break
                k += 1
            prefix = r_steps[:k]
            # Count turns directly on the move names. Re-parsing
            # "".join(prefix) through count_turns always gave 0 because the
            # names no longer carry their <|...|> delimiters.
            prefix_turns = sum(
                1 for prev, cur in zip(prefix, prefix[1:]) if prev != cur
            )
            reward = k * (prefix_turns + 1)
        # Return floats, as the annotation promises.
        rewards.append(float(reward))
    return rewards
class MazeReward(ORM):
    """Reward class that delegates scoring to correctness_reward_func."""

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """
        Args:
            completions: Model completions
            solution: Ground truth answers, forwarded as `answer`
        Returns:
            The rewards computed by correctness_reward_func
        """
        return correctness_reward_func(completions, answer=solution)
class MazeFormat(ORM):
    """Reward class that delegates scoring to int_reward_func."""

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """
        Args:
            completions: Model completions
            solution: Accepted for interface compatibility; not used here
        Returns:
            The rewards computed by int_reward_func
        """
        return int_reward_func(completions)
class Format(ORM):
    """Reward class that delegates scoring to xmlcount_reward_func."""

    def __call__(self, completions, **kwargs) -> List[float]:
        """
        Args:
            completions: Model completions
        Returns:
            The rewards computed by xmlcount_reward_func
        """
        return xmlcount_reward_func(completions)
# Register the reward classes in orms under the names passed to --reward_funcs.
orms['external_r1v_acc'] = MazeReward
orms['external_r1v_format'] = MazeFormat
orms['format'] = Format