# Student0809's picture
# Add files using upload-large-folder tool
# cb2428f verified
import asyncio
import re
from copy import deepcopy
from typing import List
import json
import torch
from swift.llm import Template, to_device
from swift.plugin import ORM, orms, rm_plugins
from swift.utils import get_logger
logger = get_logger()
"""
Step 1: Define a Reward Class
Implement your custom reward calculation logic within the __call__ method.
The method accepts the model's output completions and dataset columns (passed as kwargs) as input parameters.
Step 2: Register the Reward Class in orms
For example:
    orms['external_math_acc'] = MathAccuracy
Step 3: Configure the Arguments
Use the following arguments when running the script:
    --plugin /path/to/plugin.py --reward_funcs external_math_acc
"""
# Code borrowed from plugin/orm.py
class MathAccuracy(ORM):

    def __init__(self):
        # Fail fast if the math_verify dependency is missing.
        import importlib.util
        assert importlib.util.find_spec('math_verify') is not None, (
            "The math_verify package is required but not installed. Please install it using 'pip install math_verify'.")

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """Score each completion 1.0 when it verifies against the parsed gold solution, else 0.0.

        A gold solution that cannot be parsed earns 1.0 so the sample is skipped.
        """
        from latex2sympy2_extended import NormalizationConfig
        from math_verify import LatexExtractionConfig, parse, verify

        scores = []
        for completion, gold in zip(completions, solution):
            parsed_gold = parse(gold, extraction_mode='first_match', extraction_config=[LatexExtractionConfig()])
            if not parsed_gold:
                # If the gold solution is not parseable, we reward 1 to skip this example.
                scores.append(1.0)
                continue
            # We require the answer to be provided in correct latex (no malformed operators).
            answer_config = LatexExtractionConfig(
                normalization_config=NormalizationConfig(
                    nits=False,
                    malformed_operators=False,
                    basic_latex=True,
                    equations=True,
                    boxed=True,
                    units=True,
                ),
                # Ensures that boxed is tried first
                boxed_match_priority=0,
                try_extract_without_anchor=False,
            )
            parsed_answer = parse(completion, extraction_config=[answer_config], extraction_mode='first_match')
            # Reward 1 if the content is the same as the ground truth, 0 otherwise.
            scores.append(float(verify(parsed_answer, parsed_gold)))
        return scores
class MathFormat(ORM):

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward 1.0 when a completion is exactly '<think>...</think><answer>...</answer>'."""
        # (?![\s\S]) anchors the pattern at the very end of the string.
        required = r'^<think>.*?</think>\s*<answer>.*?</answer>(?![\s\S])'
        rewards = []
        for text in completions:
            matched = re.match(required, text, re.DOTALL | re.MULTILINE)
            rewards.append(1.0 if matched else 0.0)
        return rewards
class CountdownORM(ORM):

    def __call__(self, completions, target, nums, **kwargs) -> List[float]:
        """
        Evaluates completions based on mathematical correctness of the answer.

        Args:
            completions (list[str]): Generated outputs
            target (list[str]): Expected answers
            nums (list[str]): Available numbers

        Returns:
            list[float]: Reward scores (1.0 for a correct equation, 0.0 otherwise)
        """
        rewards = []
        for completion, gt, numbers in zip(completions, target, nums):
            try:
                # The answer must be wrapped in <answer>...</answer> tags.
                match = re.search(r'<answer>(.*?)<\/answer>', completion)
                if match is None:
                    rewards.append(0.0)
                    continue
                # Extract the "answer" part from the completion.
                equation = match.group(1).strip()
                # Keep only the left-hand side if the model echoed '= result'.
                if '=' in equation:
                    equation = equation.split('=')[0]
                # Extract all numbers from the equation.
                used_numbers = [int(n) for n in re.findall(r'\d+', equation)]
                # Check if all numbers are used exactly once.
                if sorted(used_numbers) != sorted(numbers):
                    rewards.append(0.0)
                    continue
                # Only digits, arithmetic operators, parentheses and whitespace are allowed.
                allowed_pattern = r'^[\d+\-*/().\s]+$'
                if not re.match(allowed_pattern, equation):
                    rewards.append(0.0)
                    continue
                # NOTE: eval on model output is inherently risky. The whitelist regex
                # above plus disabling builtins limits the attack surface. The original
                # code passed a misspelled "__builti'ns__" key, which left the real
                # builtins injected into the eval globals.
                result = eval(equation, {'__builtins__': None}, {})
                # Check if the equation is correct and matches the ground truth.
                if abs(float(result) - float(gt)) < 1e-5:
                    rewards.append(1.0)
                else:
                    rewards.append(0.0)
            except Exception:
                # If evaluation fails, reward is 0.
                rewards.append(0.0)
        return rewards
class MultiModalAccuracyORM(ORM):

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """
        Reward function that checks if the completion is correct.

        Args:
            completions (list[str]): Generated outputs
            solution (list[str]): Ground Truths.

        Returns:
            list[float]: Reward scores
        """
        from math_verify import parse, verify
        rewards = []
        for content, sol in zip(completions, solution):
            score = 0.0
            # First attempt: symbolic verification via math_verify.
            try:
                if float(verify(parse(content), parse(sol))) > 0:
                    score = 1.0
            except Exception:
                pass  # Fall through to string matching below.
            # Second attempt: plain string comparison of the <answer> contents.
            if score == 0.0:
                try:
                    # Extract answer from solution if it has think/answer tags.
                    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
                    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
                    # Extract answer from content if it has think/answer tags.
                    content_match = re.search(r'<answer>(.*?)</answer>', content)
                    student_answer = content_match.group(1).strip() if content_match else content.strip()
                    if student_answer == ground_truth:
                        score = 1.0
                except Exception:
                    pass  # Both methods failed; keep 0.0.
            rewards.append(score)
        return rewards
# ref implementation: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
class CodeReward(ORM):
    """Reward that executes extracted code snippets against test cases in an E2B sandbox."""

    def __init__(self):
        import importlib.util
        assert importlib.util.find_spec('e2b') is not None, (
            "The e2b package is required but not installed. Please install it using 'pip install e2b-code-interpreter'."
        )
        # E2B credentials are read from the environment (.env).
        from dotenv import load_dotenv
        load_dotenv()

    @staticmethod
    def extract_code(completion: str, language: str) -> str:
        """Return the last ```<language> ...``` fenced snippet in *completion*, or '' if none.

        re.escape guards against language names containing regex metacharacters
        (e.g. 'c++' would otherwise make the pattern fail to compile).
        """
        pattern = re.compile(rf'```{re.escape(language)}\n(.*?)```', re.DOTALL)
        matches = pattern.findall(completion)
        return matches[-1] if matches else ''

    def run_async_from_sync(self, scripts: List[str], languages: List[str]) -> List[float]:
        """Function wrapping the `run_async` function in a fresh event loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            # Run the async function and get the result.
            rewards = loop.run_until_complete(self.run_async(scripts, languages))
        finally:
            loop.close()
        return rewards

    async def run_async(self, scripts: List[str], languages: List[str]) -> List[float]:
        from e2b_code_interpreter import AsyncSandbox
        # Create the sandbox by hand, currently there's no context manager for this version.
        try:
            sbx = await AsyncSandbox.create(timeout=30, request_timeout=3)
        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            return [0.0] * len(scripts)
        # Run all scripts concurrently and gather their results.
        tasks = [self.run_script(sbx, script, language) for script, language in zip(scripts, languages)]
        results = await asyncio.gather(*tasks)
        rewards = list(results)
        # Kill the sandbox after all the tasks are complete.
        await sbx.kill()
        return rewards

    async def run_script(self, sbx, script: str, language: str) -> float:
        """Run one evaluation script in the sandbox; its printed/last value is the reward."""
        try:
            execution = await sbx.run_code(script, language=language, timeout=30)
        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            return 0.0
        try:
            return float(execution.text)
        except (TypeError, ValueError):
            # Non-numeric (or missing) output means the evaluation failed.
            return 0.0

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that evaluates code snippets using the E2B code interpreter.

        Assumes the dataset contains a `verification_info` column with test cases.
        """
        # The script below runs inside the sandbox. A per-case try/except around
        # subprocess.run is required: without it a single hanging test case raises
        # TimeoutExpired and aborts the whole evaluation (reward 0 for every case).
        evaluation_script_template = """
import subprocess
import json

def evaluate_code(code, test_cases):
    passed = 0
    total = len(test_cases)
    exec_timeout = 5

    for case in test_cases:
        try:
            process = subprocess.run(
                ["python3", "-c", code],
                input=case["input"],
                text=True,
                capture_output=True,
                timeout=exec_timeout
            )
        except subprocess.TimeoutExpired:
            continue  # A timed-out case counts as failed

        if process.returncode != 0:  # Error in execution
            continue

        output = process.stdout.strip()
        if output.strip() == case["output"].strip():
            passed += 1

    success_rate = (passed / total)
    return success_rate

code_snippet = {code}
test_cases = json.loads({test_cases})

evaluate_code(code_snippet, test_cases)
"""
        verification_info = kwargs['verification_info']
        languages = [info['language'] for info in verification_info]
        code_snippets = [
            self.extract_code(completion, language) for completion, language in zip(completions, languages)
        ]
        # Double json.dumps so test_cases lands in the script as a JSON string literal.
        scripts = [
            evaluation_script_template.format(
                code=json.dumps(code), test_cases=json.dumps(json.dumps(info['test_cases'])))
            for code, info in zip(code_snippets, verification_info)
        ]
        try:
            rewards = self.run_async_from_sync(scripts, languages)
        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            rewards = [0.0] * len(completions)
        return rewards
class CodeFormat(ORM):

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward 1.0 when a completion matches '<think>...</think><answer>...```<lang>...```...</answer>'."""
        verification_info = kwargs['verification_info']
        rewards = []
        for content, info in zip(completions, verification_info):
            # re.escape keeps language names such as 'c++' or 'c#' from producing
            # an invalid regex (an unescaped '+' is a quantifier and raises
            # re.error: multiple repeat).
            pattern = r'^<think>.*?</think>\s*<answer>.*?```{}.*?```.*?</answer>(?![\s\S])'.format(
                re.escape(info['language']))
            match = re.match(pattern, content, re.DOTALL | re.MULTILINE)
            reward = 1.0 if match else 0.0
            rewards.append(reward)
        return rewards
class CodeRewardByJudge0(ORM):
    """Reward that submits extracted code snippets to a Judge0 instance for evaluation."""

    # Maps a lower-cased language name to the Judge0 language id.
    LANGUAGE_ID_MAP = {
        'assembly': 45,
        'bash': 46,
        'basic': 47,
        'c': 50,
        'c++': 54,
        'clojure': 86,
        'c#': 51,
        'cobol': 77,
        'common lisp': 55,
        'd': 56,
        'elixir': 57,
        'erlang': 58,
        'executable': 44,
        'f#': 87,
        'fortran': 59,
        'go': 60,
        'groovy': 88,
        'haskell': 61,
        'java': 62,
        'javascript': 63,
        'kotlin': 78,
        'lua': 64,
        'multi-file program': 89,
        'objective-c': 79,
        'ocaml': 65,
        'octave': 66,
        'pascal': 67,
        'perl': 85,
        'php': 68,
        'plain text': 43,
        'prolog': 69,
        'python': 71,
        'python2': 70,
        'python3': 71,
        'r': 80,
        'ruby': 72,
        'rust': 73,
        'scala': 81,
        'sql': 82,
        'swift': 83,
        'typescript': 74,
        'visual basic.net': 84
    }
    # Fallback id when the language is missing or unknown.
    PYTHON_ID = 71

    def __init__(self):
        import os
        # The endpoint is mandatory; the auth token is optional.
        self.endpoint = os.getenv('JUDGE0_ENDPOINT')
        assert self.endpoint is not None, (
            'Judge0 endpoint is not set. Please set the JUDGE0_ENDPOINT environment variable.')
        x_auth_token = os.getenv('JUDGE0_X_AUTH_TOKEN')
        self.headers = {'Content-Type': 'application/json'}
        if x_auth_token is not None:
            self.headers['X-Auth-Token'] = x_auth_token

    @staticmethod
    def extract_code(completion: str, language: str) -> str:
        """Return the last ```<language> ...``` fenced snippet in *completion*, or '' if none.

        re.escape guards against language names containing regex metacharacters;
        this map includes 'c++', 'c#' and 'f#', all of which would otherwise
        produce an invalid or wrong pattern.
        """
        pattern = re.compile(rf'```{re.escape(language)}\n(.*?)```', re.DOTALL)
        matches = pattern.findall(completion)
        return matches[-1] if matches else ''

    @classmethod
    def get_language_id(cls, language):
        """Resolve a language name to a Judge0 id, defaulting to Python."""
        if language is None:
            return cls.PYTHON_ID
        return cls.LANGUAGE_ID_MAP.get(language.lower().strip(), cls.PYTHON_ID)

    async def _evaluate_code(self, code, test_cases, language_id):
        """Submit *code* once per test case and return the pass rate (0.0 on any error)."""
        import aiohttp
        try:
            passed = 0
            total = len(test_cases)
            for case in test_cases:
                # Empty snippets are never submitted; they simply fail the case.
                if code is not None and code != '':
                    async with aiohttp.ClientSession() as session:
                        payload = {
                            'source_code': code,
                            'language_id': language_id,
                            'stdin': case['input'],
                            'expected_output': case['output']
                        }
                        logger.debug(f'Payload: {payload}')
                        async with session.post(
                                self.endpoint + '/submissions/?wait=true', json=payload,
                                headers=self.headers) as response:
                            response_json = await response.json()
                            logger.debug(f'Response: {response_json}')
                            if response_json['status']['description'] == 'Accepted':
                                passed += 1
            success_rate = (passed / total)
            return success_rate
        except Exception as e:
            # Also covers ZeroDivisionError when test_cases is empty.
            logger.warning(f'Error from Judge0 executor: {e}')
            return 0.0

    def run_async_from_sync(self):
        """Run `run_async` to completion on a fresh event loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            rewards = loop.run_until_complete(self.run_async())
        finally:
            loop.close()
        return rewards

    async def run_async(self):
        # Evaluate all snippets concurrently against their test cases.
        tasks = [
            self._evaluate_code(code, info['test_cases'], CodeRewardByJudge0.get_language_id(info['language']))
            for code, info in zip(self.code_snippets, self.verification_info)
        ]
        results = await asyncio.gather(*tasks)
        rewards = list(results)
        return rewards

    def __call__(self, completions, **kwargs) -> List[float]:
        """Extract code from each completion and score it via Judge0.

        Assumes the dataset contains a `verification_info` column with test cases.
        """
        self.verification_info = kwargs['verification_info']
        languages = [info['language'] for info in self.verification_info]
        self.code_snippets = [
            self.extract_code(completion, language) for completion, language in zip(completions, languages)
        ]
        try:
            rewards = self.run_async_from_sync()
        except Exception as e:
            logger.warning(f'Error from Judge0 executor: {e}')
            rewards = [0.0] * len(completions)
        return rewards
# Step 2: register each reward class so it can be selected via --reward_funcs.
orms['external_math_acc'] = MathAccuracy
orms['external_math_format'] = MathFormat
orms['external_countdown'] = CountdownORM
orms['external_r1v_acc'] = MultiModalAccuracyORM
orms['external_code_reward'] = CodeReward
orms['external_code_format'] = CodeFormat
orms['external_code_reward_by_judge0'] = CodeRewardByJudge0
# For genrm you can refer to swift/llm/plugin/rm_plugin/GenRMPlugin
class CustomizedRMPlugin:
    """
    Customized Reward Model Plugin, same as DefaultRMPlugin.

    It assumes that `self.model` is a classification model with a value head (output dimension 1).
    The first logits value from the model's output is used as the reward score.
    """

    def __init__(self, model, template):
        self.model = model
        self.template: Template = template

    def __call__(self, inputs):
        # Encode each request with the reward template and collate into one batch.
        batched_inputs = [self.template.encode(deepcopy(infer_request)) for infer_request in inputs]
        reward_inputs = to_device(self.template.data_collator(batched_inputs), self.model.device)
        # Labels are irrelevant for reward scoring; drop them if present
        # (a bare pop('labels') raises KeyError when the collator emits none).
        reward_inputs.pop('labels', None)
        with torch.inference_mode():
            # First logit of the value head is the scalar reward per sample.
            return self.model(**reward_inputs).logits[:, 0]
# Register the custom reward-model plugin under the name used by --rm_plugin.
rm_plugins['my_rmplugin'] = CustomizedRMPlugin