import asyncio
import json
import re
import textwrap
from copy import deepcopy
from typing import List

import torch

from swift.llm import Template, to_device
from swift.plugin import ORM, orms, rm_plugins
from swift.utils import get_logger

logger = get_logger()
"""
Step 1: Define a Reward Class
    Implement your custom reward calculation logic within the __call__ method.
    The method accepts the model's output completions and dataset columns (passed as kwargs) as input parameters.

Step 2: Register the Reward Class in orms
    For example:
    python orms['external_math_acc'] = MathAccuracy

Step 3: Configure the Arguments
    Use the following arguments when running the script:
    bash --plugin /path/to/plugin.py --reward_funcs external_math_acc
"""


def _extract_last_code_block(completion: str, language) -> str:
    """Return the body of the last fenced code block tagged with ``language`` in ``completion``.

    The language name is regex-escaped so that names such as ``c++`` or ``c#`` are matched literally.
    An empty string is returned when no such fenced block exists.
    """
    pattern = re.compile(rf'```{re.escape(str(language))}\n(.*?)```', re.DOTALL)
    matches = pattern.findall(completion)
    return matches[-1] if matches else ''


# Code borrowed from plugin/orm.py
class MathAccuracy(ORM):
    """Reward based on whether the LaTeX answer in a completion verifies against the gold solution."""

    def __init__(self):
        import importlib.util
        assert importlib.util.find_spec('math_verify') is not None, (
            "The math_verify package is required but not installed. Please install it using 'pip install math_verify'.")

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """Score each completion against its solution.

        Args:
            completions: Generated outputs.
            solution: Ground-truth solutions, aligned with ``completions``.

        Returns:
            One reward per (completion, solution) pair: the float value of ``verify(...)`` when the gold
            solution parses, and 1.0 when it does not (so the sample is effectively skipped).
        """
        from latex2sympy2_extended import NormalizationConfig
        from math_verify import LatexExtractionConfig, parse, verify

        # We require the answer to be provided in correct latex (no malformed operators).
        # The configuration does not depend on the sample, so it is built once.
        answer_extraction_config = [
            LatexExtractionConfig(
                normalization_config=NormalizationConfig(
                    nits=False,
                    malformed_operators=False,
                    basic_latex=True,
                    equations=True,
                    boxed=True,
                    units=True,
                ),
                # Ensures that boxed is tried first
                boxed_match_priority=0,
                try_extract_without_anchor=False,
            )
        ]

        rewards = []
        for content, sol in zip(completions, solution):
            gold_parsed = parse(sol, extraction_mode='first_match', extraction_config=[LatexExtractionConfig()])
            if len(gold_parsed) != 0:
                answer_parsed = parse(
                    content,
                    extraction_config=answer_extraction_config,
                    extraction_mode='first_match',
                )
                # Reward 1 if the content is the same as the ground truth, 0 otherwise
                reward = float(verify(answer_parsed, gold_parsed))
            else:
                # If the gold solution is not parseable, we reward 1 to skip this example
                reward = 1.0
            rewards.append(reward)
        return rewards


class MathFormat(ORM):
    """Reward completions that follow the ``<think>...</think><answer>...</answer>`` layout."""

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that checks if the completion has a specific format.

        Returns 1.0 for a completion that consists of a ``<think>`` section followed (after optional
        whitespace) by an ``<answer>`` section with nothing after it, and 0.0 otherwise.
        """
        # `(?![\s\S])` asserts that no character at all follows, i.e. true end of input even with MULTILINE.
        pattern = r'^<think>.*?</think>\s*<answer>.*?</answer>(?![\s\S])'
        matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]
        return [1.0 if match else 0.0 for match in matches]


class CountdownORM(ORM):
    """Reward for the countdown task: combine the given numbers into an equation hitting a target."""

    def __call__(self, completions, target, nums, **kwargs) -> List[float]:
        """
        Evaluates completions based on Mathematical correctness of the answer

        Args:
            completions (list[str]): Generated outputs
            target (list[str]): Expected answers
            nums (list[str]): Available numbers. Each entry is compared (after sorting) with the integers
                parsed from the equation, so it is presumably a list of ints per sample - TODO confirm.

        Returns:
            list[float]: Reward scores
        """
        # Only numbers, operators, parentheses, and whitespace may reach `eval`.
        allowed_pattern = r'^[\d+\-*/().\s]+$'
        rewards = []
        for completion, gt, numbers in zip(completions, target, nums):
            try:
                # Check if the format is correct
                match = re.search(r'<answer>(.*?)<\/answer>', completion)
                if match is None:
                    rewards.append(0.0)
                    continue
                # Extract the "answer" part from the completion
                equation = match.group(1).strip()
                if '=' in equation:
                    # Keep only the left-hand side of `expr = value`.
                    equation = equation.split('=')[0]
                # Extract all numbers from the equation
                used_numbers = [int(n) for n in re.findall(r'\d+', equation)]

                # Check if all numbers are used exactly once
                if sorted(used_numbers) != sorted(numbers):
                    rewards.append(0.0)
                    continue
                if not re.match(allowed_pattern, equation):
                    rewards.append(0.0)
                    continue
                # SECURITY: `equation` is model output. It is evaluated only after passing the character
                # whitelist above and with builtins disabled. Note that the whitelist still admits `**`,
                # so pathological exponent towers could be expensive to evaluate.
                result = eval(equation, {'__builtins__': None}, {})
                # Check if the equation is correct and matches the ground truth
                if abs(float(result) - float(gt)) < 1e-5:
                    rewards.append(1.0)
                else:
                    rewards.append(0.0)
            except Exception:
                # If evaluation fails, reward is 0
                rewards.append(0.0)
        return rewards


class MultiModalAccuracyORM(ORM):
    """Accuracy reward that tries symbolic verification first and exact string matching second."""

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """
        Reward function that checks if the completion is correct.

        Args:
            completions (list[str]): Generated outputs
            solution (list[str]): Ground Truths.

        Returns:
            list[float]: Reward scores
        """
        from math_verify import parse, verify

        answer_tag = re.compile(r'<answer>(.*?)</answer>')
        rewards = []
        for content, sol in zip(completions, solution):
            reward = 0.0
            # Try symbolic verification first
            try:
                answer = parse(content)
                if float(verify(answer, parse(sol))) > 0:
                    reward = 1.0
            except Exception:
                pass  # Continue to next verification method if this fails

            # If symbolic verification failed, try string matching
            if reward == 0.0:
                try:
                    # Extract answer from solution if it has think/answer tags
                    sol_match = answer_tag.search(sol)
                    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()

                    # Extract answer from content if it has think/answer tags
                    content_match = answer_tag.search(content)
                    student_answer = content_match.group(1).strip() if content_match else content.strip()

                    # Compare the extracted answers
                    if student_answer == ground_truth:
                        reward = 1.0
                except Exception:
                    pass  # Keep reward as 0.0 if both methods fail
            rewards.append(reward)
        return rewards


# ref implementation: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
class CodeReward(ORM):
    """Reward equal to the fraction of test cases a generated snippet passes inside an E2B sandbox."""

    # Script executed in the sandbox. `{code}` and `{test_cases}` are filled with JSON-encoded literals.
    # The value of its last expression (the success rate) is read back as the reward.
    _EVALUATION_SCRIPT_TEMPLATE = textwrap.dedent("""
        import subprocess
        import json

        def evaluate_code(code, test_cases):
            passed = 0
            total = len(test_cases)
            exec_timeout = 5

            for case in test_cases:
                process = subprocess.run(
                    ["python3", "-c", code],
                    input=case["input"],
                    text=True,
                    capture_output=True,
                    timeout=exec_timeout
                )

                if process.returncode != 0:  # Error in execution
                    continue

                output = process.stdout.strip()
                if output.strip() == case["output"].strip():
                    passed += 1

            success_rate = (passed / total)
            return success_rate

        code_snippet = {code}
        test_cases = json.loads({test_cases})

        evaluate_code(code_snippet, test_cases)
        """)

    def __init__(self):
        import importlib.util
        assert importlib.util.find_spec('e2b') is not None, (
            "The e2b package is required but not installed. Please install it using 'pip install e2b-code-interpreter'."
        )
        from dotenv import load_dotenv
        load_dotenv()

    @staticmethod
    def extract_code(completion: str, language: str) -> str:
        """Return the last fenced code block tagged ``language`` in ``completion`` ('' if none)."""
        return _extract_last_code_block(completion, language)

    def run_async_from_sync(self, scripts: List[str], languages: List[str]) -> List[float]:
        """Function wrapping the `run_async` function."""
        # Create a new event loop and set it
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            # Run the async function and get the result
            rewards = loop.run_until_complete(self.run_async(scripts, languages))
        finally:
            loop.close()

        return rewards

    async def run_async(self, scripts: List[str], languages: List[str]) -> List[float]:
        """Run every script concurrently in one sandbox and return one reward per script.

        Returns a list of zeros when the sandbox cannot be created.
        """
        from e2b_code_interpreter import AsyncSandbox

        # Create the sandbox by hand, currently there's no context manager for this version
        try:
            sbx = await AsyncSandbox.create(timeout=30, request_timeout=3)
        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            return [0.0] * len(scripts)

        try:
            # Create a list of tasks for running scripts concurrently
            tasks = [self.run_script(sbx, script, language) for script, language in zip(scripts, languages)]

            # Wait for all tasks to complete and gather their results as they finish
            results = await asyncio.gather(*tasks)
        finally:
            # Kill the sandbox even when a task fails, so it is never leaked
            await sbx.kill()

        return list(results)

    async def run_script(self, sbx, script: str, language: str) -> float:
        """Execute one script in the sandbox; return its result as a float, or 0.0 on any failure."""
        try:
            execution = await sbx.run_code(script, language=language, timeout=30)
        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            return 0.0
        try:
            return float(execution.text)
        except (TypeError, ValueError):
            # No result, or a result that is not a number
            return 0.0

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that evaluates code snippets using the E2B code interpreter.

        Assumes the dataset contains a `verification_info` column with test cases.
        Each entry is read for its `language` and `test_cases` keys.
        """
        verification_info = kwargs['verification_info']
        languages = [info['language'] for info in verification_info]
        code_snippets = [
            self.extract_code(completion, language) for completion, language in zip(completions, languages)
        ]
        # The test cases are JSON-encoded twice: the outer encoding turns the JSON text into a string
        # literal that the generated script passes to `json.loads`.
        scripts = [
            self._EVALUATION_SCRIPT_TEMPLATE.format(
                code=json.dumps(code), test_cases=json.dumps(json.dumps(info['test_cases'])))
            for code, info in zip(code_snippets, verification_info)
        ]
        try:
            rewards = self.run_async_from_sync(scripts, languages)
        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            rewards = [0.0] * len(completions)

        return rewards


class CodeFormat(ORM):
    """Reward completions whose ``<answer>`` section contains a fenced code block in the expected language."""

    def __call__(self, completions, **kwargs) -> List[float]:
        """Return 1.0 per completion matching the think/answer layout with a fenced code block, else 0.0.

        The expected language of each completion is read from ``kwargs['verification_info'][i]['language']``.
        """
        verification_info = kwargs['verification_info']
        rewards = []
        for content, info in zip(completions, verification_info):
            # Escape the language so names such as `c++` are matched literally.
            language = re.escape(str(info['language']))
            pattern = r'^<think>.*?</think>\s*<answer>.*?```{}.*?```.*?</answer>(?![\s\S])'.format(language)
            match = re.match(pattern, content, re.DOTALL | re.MULTILINE)
            rewards.append(1.0 if match else 0.0)
        return rewards


class CodeRewardByJudge0(ORM):
    """Reward equal to the fraction of test cases accepted by a Judge0 endpoint."""

    # Language name (lower-case) -> Judge0 language id
    LANGUAGE_ID_MAP = {
        'assembly': 45,
        'bash': 46,
        'basic': 47,
        'c': 50,
        'c++': 54,
        'clojure': 86,
        'c#': 51,
        'cobol': 77,
        'common lisp': 55,
        'd': 56,
        'elixir': 57,
        'erlang': 58,
        'executable': 44,
        'f#': 87,
        'fortran': 59,
        'go': 60,
        'groovy': 88,
        'haskell': 61,
        'java': 62,
        'javascript': 63,
        'kotlin': 78,
        'lua': 64,
        'multi-file program': 89,
        'objective-c': 79,
        'ocaml': 65,
        'octave': 66,
        'pascal': 67,
        'perl': 85,
        'php': 68,
        'plain text': 43,
        'prolog': 69,
        'python': 71,
        'python2': 70,
        'python3': 71,
        'r': 80,
        'ruby': 72,
        'rust': 73,
        'scala': 81,
        'sql': 82,
        'swift': 83,
        'typescript': 74,
        'visual basic.net': 84
    }
    # Fallback id used when the language is missing or unknown
    PYTHON_ID = 71

    def __init__(self):
        import os
        self.endpoint = os.getenv('JUDGE0_ENDPOINT')
        assert self.endpoint is not None, (
            'Judge0 endpoint is not set. Please set the JUDGE0_ENDPOINT environment variable.')
        x_auth_token = os.getenv('JUDGE0_X_AUTH_TOKEN')
        self.headers = {'Content-Type': 'application/json'}
        if x_auth_token is not None:
            self.headers['X-Auth-Token'] = x_auth_token

    @staticmethod
    def extract_code(completion: str, language: str) -> str:
        """Return the last fenced code block tagged ``language`` in ``completion`` ('' if none)."""
        return _extract_last_code_block(completion, language)

    @classmethod
    def get_language_id(cls, language):
        """Map a language name to its Judge0 id, falling back to Python for ``None`` or unknown names."""
        if language is None:
            return cls.PYTHON_ID
        return cls.LANGUAGE_ID_MAP.get(language.lower().strip(), cls.PYTHON_ID)

    async def _evaluate_code(self, code, test_cases, language_id):
        """Submit ``code`` once per test case and return the fraction of cases Judge0 reports as 'Accepted'.

        Returns 0.0 for an empty snippet, for an empty list of test cases, and on any request error.
        """
        import aiohttp
        if not code or not test_cases:
            return 0.0
        try:
            passed = 0
            # One HTTP session is reused for all test cases of this snippet.
            async with aiohttp.ClientSession() as session:
                for case in test_cases:
                    payload = {
                        'source_code': code,
                        'language_id': language_id,
                        'stdin': case['input'],
                        'expected_output': case['output']
                    }
                    logger.debug(f'Payload: {payload}')
                    async with session.post(
                            self.endpoint + '/submissions/?wait=true', json=payload,
                            headers=self.headers) as response:
                        response_json = await response.json()
                    logger.debug(f'Response: {response_json}')
                    if response_json['status']['description'] == 'Accepted':
                        passed += 1

            return passed / len(test_cases)
        except Exception as e:
            logger.warning(f'Error from Judge0 executor: {e}')
            return 0.0

    def run_async_from_sync(self):
        """Run `run_async` to completion on a fresh event loop and return its rewards."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            rewards = loop.run_until_complete(self.run_async())
        finally:
            loop.close()
        return rewards

    async def run_async(self):
        """Evaluate all snippets stored on ``self`` concurrently; returns one reward per snippet.

        Reads ``self.code_snippets`` and ``self.verification_info``, which `__call__` sets beforehand.
        """
        tasks = [
            self._evaluate_code(code, info['test_cases'], CodeRewardByJudge0.get_language_id(info['language']))
            for code, info in zip(self.code_snippets, self.verification_info)
        ]
        results = await asyncio.gather(*tasks)
        return list(results)

    def __call__(self, completions, **kwargs) -> List[float]:
        """Extract the code from each completion and score it against ``kwargs['verification_info']``.

        Returns a list of zeros when the evaluation as a whole raises.
        """
        self.verification_info = kwargs['verification_info']

        languages = [info['language'] for info in self.verification_info]
        self.code_snippets = [
            self.extract_code(completion, language) for completion, language in zip(completions, languages)
        ]

        try:
            rewards = self.run_async_from_sync()
        except Exception as e:
            logger.warning(f'Error from Judge0 executor: {e}')
            rewards = [0.0] * len(completions)
        return rewards


orms['external_math_acc'] = MathAccuracy
orms['external_math_format'] = MathFormat
orms['external_countdown'] = CountdownORM
orms['external_r1v_acc'] = MultiModalAccuracyORM
orms['external_code_reward'] = CodeReward
orms['external_code_format'] = CodeFormat
orms['external_code_reward_by_judge0'] = CodeRewardByJudge0


# For genrm you can refer to swift/llm/plugin/rm_plugin/GenRMPlugin
class CustomizedRMPlugin:
    """
    Customized Reward Model Plugin, same as DefaultRMPlugin

    It assumes that `self.model` is a classification model with a value head (output dimension 1).
    The first logits value from the model's output is used as the reward score.
    """

    def __init__(self, model, template):
        self.model = model
        self.template: Template = template

    def __call__(self, inputs):
        """Encode and collate ``inputs`` with the template, run the model, and return ``logits[:, 0]``."""
        # Encode a copy of each request so the caller's objects are not modified by the template.
        batched_inputs = [self.template.encode(deepcopy(infer_request)) for infer_request in inputs]
        reward_inputs = to_device(self.template.data_collator(batched_inputs), self.model.device)
        # Labels are not passed to the model; tolerate collators that do not emit them.
        reward_inputs.pop('labels', None)

        with torch.inference_mode():
            return self.model(**reward_inputs).logits[:, 0]


rm_plugins['my_rmplugin'] = CustomizedRMPlugin