File size: 1,475 Bytes
e34b94f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
from typing import Dict, List
import re
from larm.data.envs.base_env import StaticEnv
from larm.common.registry import registry
from larm.data.utils.code_utils import PyExecutor, extract_python_code
@registry.register_env("kodcode")
class KodCodeEnv(StaticEnv):
    """Static environment for KodCode-style Python coding tasks.

    A completion is scored by extracting the Python code it contains,
    renaming the first defined function to the name the test expects, and
    running the test through ``PyExecutor``.
    """

    # Matches the first ``def <name>(`` header in a code string. The leading
    # ``\b`` stops identifiers that merely end in "def" (e.g. ``undef foo(``)
    # from being mistaken for a function definition. Compiled once and reused.
    _FUNC_DEF_PATTERN = re.compile(r"\bdef\s+(\w+)\s*\(")

    def __init__(self, config):
        super().__init__(config)

    @classmethod
    def _rename_func(cls, answer: str, function_name: str) -> str:
        """Rename the first function defined in ``answer`` to ``function_name``.

        Only the name in the first ``def <name>(`` header is replaced. Every
        other part of the code, including other definitions and call sites of
        the old name, is left untouched. If no function header is found,
        ``answer`` is returned unchanged.

        Args:
            answer: Python source code.
            function_name: New name for the first defined function.

        Returns:
            The source with the first function header renamed.
        """
        # A callable replacement inserts ``function_name`` literally, so it
        # is never interpreted as a regex template (backslashes, ``\g<...>``).
        return cls._FUNC_DEF_PATTERN.sub(
            lambda _match: f"def {function_name}(", answer, count=1
        )

    @classmethod
    def _accuracy_reward(cls, completions: List[str], test: List[str], test_info, **kwargs) -> List[float]:
        """Score each completion by the fraction of executed tests that pass.

        Args:
            completions: Model outputs. Python code is pulled out of each one
                with ``extract_python_code`` and the pieces are joined with
                newlines.
            test: One test-source string per completion.
            test_info: Per-completion metadata. ``info[0]["function_name"]``
                is read as the function name the test calls (schema assumed
                from this access; confirm against the dataset loader).
            **kwargs: Unused.

        Returns:
            One score per ``(completion, test, test_info)`` triple, equal to
            ``sum(results) / len(results)`` over the ``results`` returned by
            ``PyExecutor.execute`` (presumably one pass/fail flag per test).
            The score is ``0.0`` when the executor reports no results. As
            with ``zip``, extra items in longer inputs are ignored.
        """
        py_executor = PyExecutor()
        scores = []
        for completion, test_code, info in zip(completions, test, test_info):
            code_blocks = extract_python_code(completion.strip())
            candidate = "\n".join(code_blocks)
            # Tests call a fixed function name, but the model may have chosen
            # another one; align the first definition with what the test calls.
            candidate = cls._rename_func(candidate, info[0]["function_name"])
            _, _, results = py_executor.execute(candidate, [test_code])
            num_results = len(results)
            # Guard against ZeroDivisionError when the executor returns no
            # results; treat that as a total failure instead of crashing.
            scores.append(sum(results) / num_results if num_results else 0.0)
        return scores

    @classmethod
    def _format_reward(cls, completions: List[str], **kwargs):
        """Format reward for ``completions``. Not implemented yet.

        This is currently a stub: it ignores its arguments and returns
        ``None``.
        """
        # TODO: implement a format check; until then this returns None.
        return None
|