import re
import torch
# Import-time side effect: this call runs as soon as the plugin module is loaded.
# NOTE(review): presumably meant to release cached GPU memory before training
# starts -- confirm it is still needed here.
torch.cuda.empty_cache()
from typing import List
from copy import deepcopy
| |
|
| | from swift.plugin import ORM, orms |
| | from swift.utils import get_logger |
| |
|
# Module-level logger from swift.utils; correctness_reward_func emits its
# debug output through it.
logger = get_logger()
| | """ |
| | Step 1: Define a Reward Class |
| | Implement your custom reward calculation logic within the __call__ method. |
| | The method accepts the model's output completions and dataset columns (passed as kwargs) as input parameters. |
| | |
| | Step 2: Register the Reward Class in orms |
| | For example: |
| | python orms['external_math_acc'] = MathAccuracy |
| | |
| | Step 3: Configure the Arguments |
| | Use the following arguments when running the script: |
| | bash --plugin /path/to/plugin.py --reward_funcs external_math_acc |
| | """ |
| |
|
def count_xml(text) -> float:
    """
    Score a response by how cleanly it uses the reasoning tags.

    Each of ``<think>`` and ``</think>`` contributes 0.5 when it occurs
    exactly once in ``text``; a tag that is missing or repeated contributes
    nothing.

    Args:
        text: Response text to inspect.

    Returns:
        0.0, 0.5 or 1.0 depending on how many of the two tags appear exactly once.
    """
    score = 0.0
    for tag in ("<think>", "</think>"):
        if text.count(tag) == 1:
            score += 0.5
    return score
| |
|
def extract_xml_answer(text: str) -> str:
    """
    Extract the answer that follows the reasoning section of a response.

    The text is split on ``</think>`` and the segment immediately after the
    first closing tag is returned with surrounding whitespace removed. If the
    tag occurs more than once, only the text between the first and second
    occurrence is returned.

    Args:
        text: Response text, expected to contain a ``</think>`` tag.

    Returns:
        The stripped answer segment, or ``""`` when ``text`` is not a string
        or contains no ``</think>`` tag.
    """
    # Explicit checks instead of a bare ``except:`` so that unrelated errors
    # (including KeyboardInterrupt / SystemExit) are no longer swallowed.
    if not isinstance(text, str):
        return ""
    segments = text.split("</think>")
    if len(segments) < 2:
        return ""
    return segments[1].strip()
| |
|
def xmlcount_reward_func(completions, **kwargs) -> List[float]:
    """
    Reward each completion for its use of the ``<think>`` tags.

    Args:
        completions: Model completions, one text per sample.
        **kwargs: Accepted and ignored.

    Returns:
        One ``count_xml`` score per completion, in input order.
    """
    scores = []
    for completion in completions:
        scores.append(count_xml(completion))
    return scores
| |
|
def int_reward_func(completions, **kwargs) -> List[float]:
    """
    Reward completions whose answer is made only of direction tokens.

    The answer is taken from each completion with ``extract_xml_answer``,
    all whitespace is removed, and the result must be a non-empty
    concatenation of ``<|up|>``, ``<|down|>``, ``<|right|>`` and ``<|left|>``
    with nothing else in between.

    Args:
        completions: Model completions, one text per sample.
        **kwargs: Accepted and ignored.

    Returns:
        1.0 for each completion with a valid token sequence, otherwise 0.0.
    """
    allowed_tokens = {"<|up|>", "<|down|>", "<|right|>", "<|left|>"}

    def score(candidate):
        compact = re.sub(r'\s+', '', candidate)
        if not compact:
            return 0.0
        tokens = re.findall(r'<\|(?:up|down|right|left)\|>', compact)
        # Any leftover character between tokens makes the rebuilt string differ.
        if ''.join(tokens) != compact:
            return 0.0
        if any(token not in allowed_tokens for token in tokens):
            return 0.0
        return 1.0

    return [score(extract_xml_answer(completion)) for completion in completions]
| |
|
def count_turns(steps):
    """
    Parse a move string and count its direction changes.

    Args:
        steps: Text containing moves written as ``<|name|>`` tokens.

    Returns:
        A tuple ``(moves, turns)`` where ``moves`` is the list of token names
        in order of appearance and ``turns`` is the number of adjacent pairs
        whose names differ.
    """
    moves = re.findall(r"<\|(.*?)\|>", steps)
    turns = 0
    for previous, current in zip(moves, moves[1:]):
        if previous != current:
            turns += 1
    return moves, turns
| |
|
def correctness_reward_func(completions, answer, **kwargs) -> List[float]:
    """
    Reward completions by how well their move sequence matches the reference.

    For each (completion, reference) pair the answer is extracted with
    ``extract_xml_answer`` and parsed with ``count_turns``:

    * Exact string match: ``2 * number_of_moves * (turns + 1)``.
    * Otherwise: ``k * (prefix_turns + 1)``, where ``k`` is the number of
      leading moves that agree with the reference and ``prefix_turns`` is the
      number of direction changes inside that agreed prefix.

    Args:
        completions: Model completions, one text per sample.
        answer: Ground-truth answers, aligned with ``completions``.
        **kwargs: Accepted and ignored.

    Returns:
        One reward per (completion, answer) pair, as floats.
    """
    responses = completions
    extracted_responses = [extract_xml_answer(r) for r in responses]

    # Debug-log the first sample only; the guard keeps an empty batch from
    # raising IndexError on the [0] lookups.
    if len(extracted_responses) > 0 and len(answer) > 0:
        logger.debug('-' * 20)
        logger.debug("\nAnswer:\n%s", answer[0])
        logger.debug("\nResponse:\n%s", responses[0])
        logger.debug("\nExtracted:\n%s", extracted_responses[0])

    rewards = []
    for r, a in zip(extracted_responses, answer):
        r_steps, r_turns = count_turns(r)
        if r == a:
            reward = len(r_steps) * 2 * (r_turns + 1)
        else:
            a_steps, _ = count_turns(a)
            k = 0
            for r_s, a_s in zip(r_steps, a_steps):
                if r_s != a_s:
                    break
                k += 1
            prefix = r_steps[:k]
            # Count turns directly on the parsed move names. Re-parsing the
            # joined names with count_turns finds no "<|...|>" tokens and
            # therefore always yields 0 turns.
            turns = sum(1 for prev, cur in zip(prefix, prefix[1:]) if prev != cur)
            reward = k * 1 * (turns + 1)
        rewards.append(float(reward))
    return rewards
| |
|
class MazeReward(ORM):
    """Reward plugin that scores completions against the ``solution`` column."""

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """
        Score each completion with ``correctness_reward_func``.

        Args:
            completions: Model completions, one text per sample.
            solution: Reference answers, passed on as the ground truth.
            **kwargs: Accepted and not forwarded.

        Returns:
            One reward per completion.
        """
        return correctness_reward_func(completions, solution)
| |
|
class MazeFormat(ORM):
    """Reward plugin that checks the answer consists only of direction tokens."""

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """
        Score each completion with ``int_reward_func``.

        Args:
            completions: Model completions, one text per sample.
            solution: Required by the signature but not used by this check.
            **kwargs: Accepted and not forwarded.

        Returns:
            One format reward per completion.
        """
        return int_reward_func(completions)
| |
|
class Format(ORM):
    """Reward plugin that scores use of the ``<think>`` / ``</think>`` tags."""

    def __call__(self, completions, **kwargs) -> List[float]:
        """
        Score each completion with ``xmlcount_reward_func``.

        Args:
            completions: Model completions, one text per sample.
            **kwargs: Accepted and not forwarded.

        Returns:
            One tag-usage reward per completion.
        """
        return xmlcount_reward_func(completions)
| |
|
# Register the reward classes under the keys used to look them up in `orms`.
# NOTE(review): the 'format' key is not prefixed with 'external_' like the
# other two, so it may replace an entry already present in `orms` -- confirm
# this is intended.
orms['external_r1v_acc'] = MazeReward
orms['external_r1v_format'] = MazeFormat
orms['format'] = Format