"""
Benchmark Problem Loader

Benchmark problem loading system for AZR-based TestTime RLVR.

Extends the load_humaneval_problem function from the original Test-Time-RLVR.
"""

import json
import os
import re
from typing import Dict, List, Any, Tuple, Optional

from .config import BenchmarkConfig, TestTimeConfig
from .logger import TestTimeLogger


class BenchmarkProblemLoader:
    """Loads and manages benchmark problems (using the standard EvalPlus pipeline)."""

    def __init__(self, config: TestTimeConfig, logger: Optional[TestTimeLogger] = None):
        self.config = config
        self.logger = logger or TestTimeLogger()
        self.loaded_problems: Dict[str, Dict[str, Any]] = {}
        self.evalplus_cache: Dict[str, Dict[str, Dict[str, Any]]] = {}

    def _load_evalplus_data(self, benchmark_name: str) -> Dict[str, Dict[str, Any]]:
        """Load and cache the EvalPlus dataset for the given benchmark."""
        if benchmark_name in self.evalplus_cache:
            return self.evalplus_cache[benchmark_name]

        try:
            if benchmark_name == 'mbpp':
                from evalplus.data.mbpp import get_mbpp_plus
                problems = get_mbpp_plus()
                self.logger.log_info(f"✅ MBPP+ EvalPlus data loaded: {len(problems)} problems")
            elif benchmark_name == 'humaneval':
                from evalplus.data.humaneval import get_human_eval_plus
                problems = get_human_eval_plus()
                self.logger.log_info(f"✅ HumanEval+ EvalPlus data loaded: {len(problems)} problems")
            else:
                raise ValueError(f"Unsupported benchmark for EvalPlus: {benchmark_name}")

            self.evalplus_cache[benchmark_name] = problems
            return problems

        except Exception as e:
            self.logger.log_error(f"❌ {benchmark_name.upper()}+ EvalPlus loading failed: {e}")
            return {}
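
    # Note (schema assumption): get_mbpp_plus() / get_human_eval_plus() return
    # dicts keyed by task_id (e.g. "Mbpp/2", "HumanEval/0"); each value carries
    # EvalPlus problem fields such as "prompt" and "entry_point". The exact
    # schema depends on the installed evalplus version.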

    def load_problem(self, benchmark_config: BenchmarkConfig, problem_id: str) -> Dict[str, Any]:
        """Load a single benchmark problem, preferring the EvalPlus dataset."""
        cache_key = f"{benchmark_config.name}_{problem_id}"
        if cache_key in self.loaded_problems:
            return self.loaded_problems[cache_key]

        # Preferred path: pull the problem from the EvalPlus dataset.
        if benchmark_config.name in ['mbpp', 'humaneval']:
            evalplus_problems = self._load_evalplus_data(benchmark_config.name)
            if problem_id in evalplus_problems:
                problem = evalplus_problems[problem_id].copy()
                problem['benchmark_name'] = benchmark_config.name
                problem['benchmark_config'] = benchmark_config

                self.loaded_problems[cache_key] = problem
                self.logger.log_info(f"✅ Problem loaded: {problem_id} from {benchmark_config.name} (EvalPlus)")
                return problem

        # Fallback: scan the local JSONL data file.
        self.logger.log_info(f"⚠️ {problem_id} EvalPlus load failed, falling back to the original loader")
        problem_file = benchmark_config.data_path

        if not os.path.exists(problem_file):
            raise FileNotFoundError(f"Benchmark file not found: {problem_file}")

        with open(problem_file, 'r', encoding='utf-8') as f:
            problems = [json.loads(line) for line in f if line.strip()]

        for problem in problems:
            if problem['task_id'] == problem_id:
                problem['benchmark_name'] = benchmark_config.name
                problem['benchmark_config'] = benchmark_config

                self.loaded_problems[cache_key] = problem
                self.logger.log_info(f"✅ Problem loaded: {problem_id} from {benchmark_config.name} (Original)")
                return problem

        raise ValueError(f"Problem {problem_id} not found in {problem_file}")
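
    # Example: load_problem(cfg, "HumanEval/0") first looks up "HumanEval/0" in
    # the cached EvalPlus dict, then falls back to scanning cfg.data_path (a
    # JSONL file matched on each record's "task_id" field).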

    def load_problem_batch(self, benchmark_config: BenchmarkConfig,
                           problem_ids: List[str]) -> List[Dict[str, Any]]:
        """Load multiple problems in one batch, skipping those that fail."""
        problems = []
        for problem_id in problem_ids:
            try:
                problem = self.load_problem(benchmark_config, problem_id)
                problems.append(problem)
            except Exception as e:
                self.logger.log_error(f"Failed to load {problem_id}: {e}")

        return problems

    def get_test_cases(self, problem: Dict[str, Any]) -> List[Tuple[str, str]]:
        """Extract (input, expected output) test cases from a problem."""
        test_cases = []

        # Standard tests: parse assert statements out of the 'test' field.
        if 'test' in problem:
            test_code = problem['test']
            test_cases.extend(self._parse_assert_statements(test_code))

        # EvalPlus extended tests: paired 'plus_input' / 'plus_output' lists,
        # occasionally JSON-encoded as strings.
        if 'plus_input' in problem and 'plus_output' in problem:
            plus_inputs = problem['plus_input']
            plus_outputs = problem['plus_output']

            if isinstance(plus_inputs, str):
                plus_inputs = json.loads(plus_inputs)
            if isinstance(plus_outputs, str):
                plus_outputs = json.loads(plus_outputs)

            for inp, out in zip(plus_inputs, plus_outputs):
                test_cases.append((str(inp), str(out)))

        return test_cases

    def _parse_assert_statements(self, test_code: str) -> List[Tuple[str, str]]:
        """Extract (arguments, expected value) pairs from assert statements."""
        test_cases = []
        lines = test_code.strip().split('\n')

        for line in lines:
            line = line.strip()
            if line.startswith('assert '):
                # Match the common single-call pattern: assert func(args) == expected.
                # Note: this regex does not handle nested parentheses inside args.
                match = re.match(r'assert\s+(\w+)\(([^)]*)\)\s*==\s*(.+)', line)
                if match:
                    func_name, args, expected = match.groups()
                    test_cases.append((args.strip(), expected.strip()))

        return test_cases
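
    # Example: for test code containing the line
    #     assert add(1, 2) == 3
    # _parse_assert_statements returns [("1, 2", "3")].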

    def validate_solution(self, problem: Dict[str, Any], solution: str) -> Dict[str, Any]:
        """Validate a solution (execution via the AZR Python Executor is planned)."""
        validation_result = {
            'problem_id': problem['task_id'],
            'solution': solution,
            'syntax_valid': False,
            'test_results': [],
            'overall_success': False,
            'error_message': None
        }

        try:
            # Syntax check only; the test cases below are not executed yet.
            compile(solution, '<string>', 'exec')
            validation_result['syntax_valid'] = True

            # Collect test cases as placeholders ('passed' stays False until the
            # AZR Python Executor is wired in to actually run them).
            test_cases = self.get_test_cases(problem)
            validation_result['test_results'] = [
                {'input': inp, 'expected': out, 'passed': False}
                for inp, out in test_cases
            ]

            # Placeholder: a syntactically valid solution counts as success
            # until real test execution is integrated.
            validation_result['overall_success'] = True

        except SyntaxError as e:
            validation_result['error_message'] = f"Syntax Error: {e}"
        except Exception as e:
            validation_result['error_message'] = f"Validation Error: {e}"

        return validation_result
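
    # Example result shape (tests are not yet executed, so 'passed' is False):
    # {'problem_id': 'HumanEval/0', 'solution': '...', 'syntax_valid': True,
    #  'test_results': [{'input': '1, 2', 'expected': '3', 'passed': False}],
    #  'overall_success': True, 'error_message': None}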

    def get_sequential_problems(self, benchmark_config: BenchmarkConfig,
                                num_problems: int) -> List[Dict[str, Any]]:
        """Load N problems sequentially, starting from the configured start index."""
        problems = []

        for i in range(num_problems):
            problem_index = benchmark_config.start_index + i
            problem_id = f"{benchmark_config.problem_prefix}/{problem_index}"

            try:
                problem = self.load_problem(benchmark_config, problem_id)
                problems.append(problem)
            except Exception as e:
                self.logger.log_error(f"Failed to load {problem_id}: {e}")
                continue

        return problems

    def get_problem_statistics(self, benchmark_config: BenchmarkConfig) -> Dict[str, Any]:
        """Summary statistics for a benchmark's data file."""
        problem_file = benchmark_config.data_path

        if not os.path.exists(problem_file):
            return {"error": f"File not found: {problem_file}"}

        with open(problem_file, 'r', encoding='utf-8') as f:
            problems = [json.loads(line) for line in f if line.strip()]

        stats = {
            'total_problems': len(problems),
            'benchmark_name': benchmark_config.name,
            'data_file': problem_file,
            'sample_problem_ids': [p['task_id'] for p in problems[:5]]
        }

        return stats
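

# Minimal usage sketch (illustrative only). The BenchmarkConfig keyword arguments
# below are assumptions inferred from the fields this module reads (name,
# data_path, problem_prefix, start_index), and default construction of
# TestTimeConfig is likewise assumed; see .config for the real signatures.
if __name__ == "__main__":  # run via `python -m <package>.<this_module>`
    config = TestTimeConfig()
    benchmark = BenchmarkConfig(
        name="humaneval",
        data_path="data/HumanEval.jsonl",
        problem_prefix="HumanEval",
        start_index=0,
    )

    loader = BenchmarkProblemLoader(config)
    for problem in loader.get_sequential_problems(benchmark, num_problems=3):
        cases = loader.get_test_cases(problem)
        print(problem["task_id"], len(cases), "test cases")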