|
|
from typing import Dict, List |
|
|
import re |
|
|
|
|
|
from larm.data.envs.base_env import StaticEnv |
|
|
from larm.common.registry import registry |
|
|
from larm.data.utils.code_utils import PyExecutor, extract_python_code |
|
|
|
|
|
|
|
|
|
|
|
@registry.register_env("kodcode")
class KodCodeEnv(StaticEnv):
    """Static environment, registered as ``"kodcode"``, that rewards Python
    completions by executing the extracted code against per-sample test code.
    """

    # Matches a function header and captures the function name. The leading
    # ``\b`` keeps identifiers that merely end in "def" (e.g. ``undef foo(``)
    # from being mistaken for a definition.
    _FUNC_DEF_PATTERN = re.compile(r"\bdef\s+(\w+)\s*\(")

    def __init__(self, config):
        """Initialise the environment; all config handling is done by ``StaticEnv``."""
        super().__init__(config)

    @classmethod
    def _rename_func(cls, answer: str, function_name: str) -> str:
        """Make sure ``answer`` defines a function named ``function_name``.

        Behaviour:

        * If ``answer`` already defines a function called ``function_name``,
          it is returned unchanged. Renaming the first function in that case
          would clobber any helper defined before the target.
        * Otherwise the name in the first ``def`` header is replaced with
          ``function_name``. If that function is defined at module level, an
          alias ``<old_name> = <function_name>`` is appended. The alias lets
          recursive calls and other references to the old name still resolve.

        Args:
            answer: Python source code.
            function_name: Name the tests expect to be able to call.

        Returns:
            The possibly rewritten source. ``answer`` is returned unchanged
            when it contains no function definition.
        """
        defined_names = {m.group(1) for m in cls._FUNC_DEF_PATTERN.finditer(answer)}
        if function_name in defined_names:
            return answer

        match = cls._FUNC_DEF_PATTERN.search(answer)
        if match is None:
            return answer

        old_name = match.group(1)
        # Splice the new header in by position instead of using ``re.sub``, so
        # the replacement text is never interpreted as a regex template.
        renamed = f"{answer[:match.start()]}def {function_name}({answer[match.end():]}"

        # Only a module-level function can be aliased at module level. An
        # indented ``def`` (a method or nested function) has no global binding
        # to point the alias at.
        line_start = answer.rfind("\n", 0, match.start()) + 1
        if answer[line_start:line_start + 1].isspace():
            return renamed
        return f"{renamed}\n\n{old_name} = {function_name}\n"

    @classmethod
    def _accuracy_reward(cls, completions: List[str], test: List[str], test_info, **kwargs) -> List[float]:
        """Score each completion by the fraction of its executed tests that pass.

        For every completion:

        1. Extract its Python code blocks and join them with newlines.
        2. Pass the joined code through ``_rename_func``, so the function the
           tests expect exists.
        3. Execute the result with ``PyExecutor`` against the matching entry
           of ``test``.

        Args:
            completions: Model outputs that may contain Python code blocks.
            test: Test code, one entry per completion.
            test_info: Per-completion metadata. Entry ``[0]["function_name"]``
                is read as the expected function name. The schema beyond that
                is an assumption (TODO: confirm).
            **kwargs: Unused.

        Returns:
            One score per position, up to the shortest of the three inputs
            (``zip`` truncates). Each score is ``sum(results) / len(results)``.
            This presumably lies in [0, 1] if the executor reports per-test
            booleans (TODO: confirm). The score is 0.0 when a completion
            contains no code block or the executor reports no results.
        """
        py_executor = PyExecutor()
        scores: List[float] = []
        for completion, test_code, info in zip(completions, test, test_info):
            code_blocks = extract_python_code(completion.strip())
            if not code_blocks:
                # An empty program cannot implement the requested function.
                scores.append(0.0)
                continue

            program = cls._rename_func("\n".join(code_blocks), info[0]["function_name"])
            _, _, results = py_executor.execute(program, [test_code])

            # Guard against an empty result list, which would otherwise raise
            # ZeroDivisionError and abort scoring for the whole batch.
            scores.append(sum(results) / len(results) if results else 0.0)
        return scores

    @classmethod
    def _format_reward(cls, completions: List[str], **kwargs):
        """Placeholder format reward: not implemented; currently returns ``None``."""
        # TODO: implement a format check, or confirm no caller relies on this reward.
        return None
|
|
|