|
|
import asyncio |
|
|
import re |
|
|
from copy import deepcopy |
|
|
from typing import List |
|
|
|
|
|
import json |
|
|
import torch |
|
|
|
|
|
from swift.llm import Template, to_device |
|
|
from swift.plugin import ORM, orms, rm_plugins |
|
|
from swift.utils import get_logger |
|
|
|
|
|
logger = get_logger() |
|
|
""" |
|
|
Step 1: Define a Reward Class |
|
|
Implement your custom reward calculation logic within the __call__ method. |
|
|
The method accepts the model's output completions and dataset columns (passed as kwargs) as input parameters. |
|
|
|
|
|
Step 2: Register the Reward Class in orms |
|
|
For example: |
|
|
orms['external_math_acc'] = MathAccuracy
|
|
|
|
|
Step 3: Configure the Arguments |
|
|
Use the following arguments when running the script: |
|
|
--plugin /path/to/plugin.py --reward_funcs external_math_acc
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
class MathAccuracy(ORM):
    """Accuracy reward based on symbolic LaTeX equivalence via ``math_verify``.

    Each completion is compared against the corresponding entry of the
    ``solution`` dataset column; the reward is 1.0 on a verified match,
    0.0 otherwise.
    """

    def __init__(self):
        import importlib.util
        # Fail fast if the optional dependency is missing.
        assert importlib.util.find_spec('math_verify') is not None, (
            "The math_verify package is required but not installed. Please install it using 'pip install math_verify'.")

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        from latex2sympy2_extended import NormalizationConfig
        from math_verify import LatexExtractionConfig, parse, verify

        scores: List[float] = []
        for prediction, reference in zip(completions, solution):
            gold = parse(reference, extraction_mode='first_match', extraction_config=[LatexExtractionConfig()])
            if not gold:
                # The gold answer itself is not parseable: skip this sample by
                # granting full reward so the model is not penalized for it.
                scores.append(1.0)
                continue
            # Parse the model output with strict, boxed-first extraction.
            predicted = parse(
                prediction,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed=True,
                            units=True,
                        ),
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode='first_match',
            )
            scores.append(float(verify(predicted, gold)))
        return scores
|
|
|
|
|
|
|
|
class MathFormat(ORM):
    """Format reward: 1.0 when a completion is exactly
    ``<think>...</think>`` followed by ``<answer>...</answer>``."""

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that checks if the completion has a specific format."""
        fmt = r'^<think>.*?</think>\s*<answer>.*?</answer>(?![\s\S])'
        return [1.0 if re.match(fmt, text, re.DOTALL | re.MULTILINE) else 0.0 for text in completions]
|
|
|
|
|
|
|
|
class CountdownORM(ORM):
    """Reward for the Countdown numbers game: the answer must be an arithmetic
    equation that uses each provided number exactly once and evaluates to the
    target value."""

    def __call__(self, completions, target, nums, **kwargs) -> List[float]:
        """
        Evaluates completions based on Mathematical correctness of the answer

        Args:
            completions (list[str]): Generated outputs
            target (list[str]): Expected answers
            nums (list[str]): Available numbers

        Returns:
            list[float]: Reward scores
        """
        rewards = []
        for completion, gt, numbers in zip(completions, target, nums):
            try:
                # The equation must appear inside <answer>...</answer> tags.
                match = re.search(r'<answer>(.*?)<\/answer>', completion)
                if match is None:
                    rewards.append(0.0)
                    continue
                equation = match.group(1).strip()
                # Keep only the left-hand side if the model wrote "expr = result".
                if '=' in equation:
                    equation = equation.split('=')[0]
                used_numbers = [int(n) for n in re.findall(r'\d+', equation)]
                # Every provided number must be used exactly once.
                # NOTE(review): the docstring says `nums` holds strings; this
                # comparison only succeeds when it actually holds ints — confirm.
                if sorted(used_numbers) != sorted(numbers):
                    rewards.append(0.0)
                    continue
                # Restrict the expression to digits, arithmetic operators,
                # parentheses, and whitespace before evaluating it.
                allowed_pattern = r'^[\d+\-*/().\s]+$'
                if not re.match(allowed_pattern, equation):
                    rewards.append(0.0)
                    continue
                # Bug fix: the globals key was misspelled as "__builti'ns__",
                # which let Python inject the real builtins and defeated the
                # intended sandbox. The allowed-character check above remains
                # the primary guard; eval on model output is still best-effort.
                result = eval(equation, {'__builtins__': None}, {})
                # Float tolerance for the comparison against the target.
                if abs(float(result) - float(gt)) < 1e-5:
                    rewards.append(1.0)
                else:
                    rewards.append(0.0)
            except Exception:
                # Any parse/eval failure yields zero reward for this sample.
                rewards.append(0.0)
        return rewards
|
|
|
|
|
|
|
|
class MultiModalAccuracyORM(ORM):
    """Accuracy reward with a two-stage check: symbolic verification first,
    exact string match of the ``<answer>`` tag contents as a fallback."""

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """
        Reward function that checks if the completion is correct.
        Args:
            completions (list[str]): Generated outputs
            solution (list[str]): Ground Truths.

        Returns:
            list[float]: Reward scores
        """
        from math_verify import parse, verify

        scores = []
        for content, sol in zip(completions, solution):
            score = 0.0

            # Stage 1: symbolic verification via math_verify.
            try:
                if float(verify(parse(content), parse(sol))) > 0:
                    score = 1.0
            except Exception:
                pass

            # Stage 2: fall back to comparing the raw <answer> strings.
            if score == 0.0:
                try:
                    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
                    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
                    content_match = re.search(r'<answer>(.*?)</answer>', content)
                    student_answer = content_match.group(1).strip() if content_match else content.strip()
                    if student_answer == ground_truth:
                        score = 1.0
                except Exception:
                    pass

            scores.append(score)
        return scores
|
|
|
|
|
|
|
|
|
|
|
class CodeReward(ORM):
    """Rewards code completions by running the extracted snippet against test
    cases inside an E2B cloud sandbox.

    Expects the dataset to provide a ``verification_info`` column whose entries
    contain at least ``language`` and ``test_cases`` (each case with ``input``
    and ``output`` fields, as consumed by the evaluation script below).
    """

    def __init__(self):
        import importlib.util
        # Fail fast if the sandbox SDK is not installed.
        assert importlib.util.find_spec('e2b') is not None, (
            "The e2b package is required but not installed. Please install it using 'pip install e2b-code-interpreter'."
        )
        from dotenv import load_dotenv
        # Load E2B credentials (e.g. E2B_API_KEY) from a local .env file.
        load_dotenv()

    @staticmethod
    def extract_code(completion: str, language: str) -> str:
        # Take the LAST fenced code block of the requested language;
        # returns '' when the completion contains no such block.
        pattern = re.compile(rf'```{language}\n(.*?)```', re.DOTALL)
        matches = pattern.findall(completion)
        extracted_answer = matches[-1] if len(matches) >= 1 else ''
        return extracted_answer

    def run_async_from_sync(self, scripts: List[str], languages: List[str]) -> List[float]:
        """Function wrapping the `run_async` function."""
        # A fresh event loop is created (and closed) per call so this also
        # works when invoked from a context without a running loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            rewards = loop.run_until_complete(self.run_async(scripts, languages))
        finally:
            loop.close()

        return rewards

    async def run_async(self, scripts: List[str], languages: List[str]) -> List[float]:
        """Create one sandbox and evaluate all scripts in it concurrently."""
        from e2b_code_interpreter import AsyncSandbox

        # If the sandbox cannot be created, every sample gets reward 0.0.
        try:
            sbx = await AsyncSandbox.create(timeout=30, request_timeout=3)
        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            return [0.0] * len(scripts)

        tasks = [self.run_script(sbx, script, language) for script, language in zip(scripts, languages)]

        # gather preserves input order, so rewards align with completions.
        results = await asyncio.gather(*tasks)
        rewards = list(results)

        await sbx.kill()

        return rewards

    async def run_script(self, sbx, script: str, language: str) -> float:
        """Run one evaluation script in the sandbox and return its pass rate."""
        try:
            execution = await sbx.run_code(script, language=language, timeout=30)
        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            return 0.0
        try:
            # The script's final expression is the success rate; anything
            # non-numeric (or missing) yields 0.0.
            return float(execution.text)
        except (TypeError, ValueError):
            return 0.0

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that evaluates code snippets using the E2B code interpreter.

        Assumes the dataset contains a `verification_info` column with test cases.
        """
        evaluation_script_template = """
import subprocess
import json

def evaluate_code(code, test_cases):
    passed = 0
    total = len(test_cases)
    exec_timeout = 5

    for case in test_cases:
        process = subprocess.run(
            ["python3", "-c", code],
            input=case["input"],
            text=True,
            capture_output=True,
            timeout=exec_timeout
        )

        if process.returncode != 0:  # Error in execution
            continue

        output = process.stdout.strip()
        if output.strip() == case["output"].strip():
            passed += 1

    success_rate = (passed / total)
    return success_rate

code_snippet = {code}
test_cases = json.loads({test_cases})

evaluate_code(code_snippet, test_cases)
"""
        verification_info = kwargs['verification_info']
        languages = [info['language'] for info in verification_info]
        code_snippets = [
            self.extract_code(completion, language) for completion, language in zip(completions, languages)
        ]
        # json.dumps(code) embeds the snippet as a valid Python string literal;
        # the double dumps on test_cases yields a JSON string literal that
        # json.loads decodes back inside the generated script.
        scripts = [
            evaluation_script_template.format(
                code=json.dumps(code), test_cases=json.dumps(json.dumps(info['test_cases'])))
            for code, info in zip(code_snippets, verification_info)
        ]
        try:
            rewards = self.run_async_from_sync(scripts, languages)

        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            rewards = [0.0] * len(completions)

        return rewards
|
|
|
|
|
|
|
|
class CodeFormat(ORM):
    """Format reward: 1.0 when a completion is exactly ``<think>...</think>``
    followed by ``<answer>...</answer>`` whose answer contains a fenced code
    block in the sample's expected language."""

    def __call__(self, completions, **kwargs) -> List[float]:
        verification_info = kwargs['verification_info']
        scores = []
        for completion, info in zip(completions, verification_info):
            # Build the per-sample pattern around the expected fence language.
            expected = r'^<think>.*?</think>\s*<answer>.*?```{}.*?```.*?</answer>(?![\s\S])'.format(info['language'])
            matched = re.match(expected, completion, re.DOTALL | re.MULTILINE)
            scores.append(1.0 if matched else 0.0)
        return scores
|
|
|
|
|
|
|
|
class CodeRewardByJudge0(ORM):
    """Rewards code completions by submitting them to a Judge0 server and
    returning the fraction of test cases judged 'Accepted'.

    Requires the JUDGE0_ENDPOINT environment variable; JUDGE0_X_AUTH_TOKEN is
    optional and is sent as the X-Auth-Token header when present.
    """

    # Judge0 language name -> language_id mapping (matches the Judge0
    # /languages catalog; keys are compared lowercased/stripped).
    LANGUAGE_ID_MAP = {
        'assembly': 45,
        'bash': 46,
        'basic': 47,
        'c': 50,
        'c++': 54,
        'clojure': 86,
        'c#': 51,
        'cobol': 77,
        'common lisp': 55,
        'd': 56,
        'elixir': 57,
        'erlang': 58,
        'executable': 44,
        'f#': 87,
        'fortran': 59,
        'go': 60,
        'groovy': 88,
        'haskell': 61,
        'java': 62,
        'javascript': 63,
        'kotlin': 78,
        'lua': 64,
        'multi-file program': 89,
        'objective-c': 79,
        'ocaml': 65,
        'octave': 66,
        'pascal': 67,
        'perl': 85,
        'php': 68,
        'plain text': 43,
        'prolog': 69,
        'python': 71,
        'python2': 70,
        'python3': 71,
        'r': 80,
        'ruby': 72,
        'rust': 73,
        'scala': 81,
        'sql': 82,
        'swift': 83,
        'typescript': 74,
        'visual basic.net': 84
    }
    # Fallback language id used when the language is missing or unmapped.
    PYTHON_ID = 71

    def __init__(self):
        import os
        self.endpoint = os.getenv('JUDGE0_ENDPOINT')
        assert self.endpoint is not None, (
            'Judge0 endpoint is not set. Please set the JUDGE0_ENDPOINT environment variable.')
        x_auth_token = os.getenv('JUDGE0_X_AUTH_TOKEN')
        self.headers = {'Content-Type': 'application/json'}
        if x_auth_token is not None:
            self.headers['X-Auth-Token'] = x_auth_token

    @staticmethod
    def extract_code(completion: str, language: str) -> str:
        # Take the LAST fenced code block of the requested language;
        # returns '' when the completion contains no such block.
        pattern = re.compile(rf'```{language}\n(.*?)```', re.DOTALL)
        matches = pattern.findall(completion)
        extracted_answer = matches[-1] if len(matches) >= 1 else ''
        return extracted_answer

    @classmethod
    def get_language_id(cls, language):
        # Unknown/missing languages silently fall back to Python.
        if language is None:
            return cls.PYTHON_ID
        return cls.LANGUAGE_ID_MAP.get(language.lower().strip(), cls.PYTHON_ID)

    async def _evaluate_code(self, code, test_cases, language_id):
        """Submit `code` once per test case; return the fraction accepted.

        An empty/None snippet contributes zero passes. Any network or parsing
        error makes the whole sample score 0.0 (logged as a warning).
        """
        import aiohttp
        try:
            passed = 0
            total = len(test_cases)

            for case in test_cases:
                if code is not None and code != '':
                    # One session/submission per case; ?wait=true makes Judge0
                    # block until the run finishes so no polling is needed.
                    async with aiohttp.ClientSession() as session:
                        payload = {
                            'source_code': code,
                            'language_id': language_id,
                            'stdin': case['input'],
                            'expected_output': case['output']
                        }
                        logger.debug(f'Payload: {payload}')
                        async with session.post(
                                self.endpoint + '/submissions/?wait=true', json=payload,
                                headers=self.headers) as response:
                            response_json = await response.json()
                            logger.debug(f'Response: {response_json}')
                            if response_json['status']['description'] == 'Accepted':
                                passed += 1

            success_rate = (passed / total)
            return success_rate
        except Exception as e:
            logger.warning(f'Error from Judge0 executor: {e}')
            return 0.0

    def run_async_from_sync(self):
        # Fresh event loop per call (see CodeReward.run_async_from_sync).
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            rewards = loop.run_until_complete(self.run_async())
        finally:
            loop.close()
        return rewards

    async def run_async(self):
        # Reads self.code_snippets / self.verification_info set by __call__.
        tasks = [
            self._evaluate_code(code, info['test_cases'], CodeRewardByJudge0.get_language_id(info['language']))
            for code, info in zip(self.code_snippets, self.verification_info)
        ]
        # gather preserves input order, so rewards align with completions.
        results = await asyncio.gather(*tasks)
        rewards = list(results)
        return rewards

    def __call__(self, completions, **kwargs) -> List[float]:
        # NOTE(review): per-call state is stashed on self for run_async to
        # read; concurrent calls on one instance would race — confirm callers
        # invoke this sequentially.
        self.verification_info = kwargs['verification_info']

        languages = [info['language'] for info in self.verification_info]
        self.code_snippets = [
            self.extract_code(completion, language) for completion, language in zip(completions, languages)
        ]

        try:
            rewards = self.run_async_from_sync()
        except Exception as e:
            logger.warning(f'Error from Judge0 executor: {e}')
            rewards = [0.0] * len(completions)
        return rewards
|
|
|
|
|
|
|
|
# Register the reward classes under the names accepted by --reward_funcs
# (see the module docstring, Steps 2 and 3).
orms['external_math_acc'] = MathAccuracy
orms['external_math_format'] = MathFormat
orms['external_countdown'] = CountdownORM
orms['external_r1v_acc'] = MultiModalAccuracyORM
orms['external_code_reward'] = CodeReward
orms['external_code_format'] = CodeFormat
orms['external_code_reward_by_judge0'] = CodeRewardByJudge0
|
|
|
|
|
|
|
|
|
|
|
class CustomizedRMPlugin:
    """Customized reward-model plugin (equivalent to DefaultRMPlugin).

    Assumes ``self.model`` is a classification model with a value head of
    output dimension 1; the first logit of the model output is used as the
    reward score.
    """

    def __init__(self, model, template):
        self.model = model
        self.template: Template = template

    def __call__(self, inputs):
        # Encode every request (deep-copied so the originals stay untouched),
        # then collate into a single batch on the model's device.
        encoded = [self.template.encode(deepcopy(request)) for request in inputs]
        batch = to_device(self.template.data_collator(encoded), self.model.device)
        # Labels are irrelevant for scoring.
        batch.pop('labels')

        with torch.inference_mode():
            # One reward per input: the first logit of each row.
            return self.model(**batch).logits[:, 0]
|
|
|
|
|
|
|
|
# Register the custom reward-model plugin under the name 'my_rmplugin'.
rm_plugins['my_rmplugin'] = CustomizedRMPlugin
|
|
|