"""
Benchmark Problem Loader
AZR ๊ธฐ๋ฐ˜ TestTime RLVR์„ ์œ„ํ•œ ๋ฒค์น˜๋งˆํฌ ๋ฌธ์ œ ๋กœ๋”ฉ ์‹œ์Šคํ…œ
๊ธฐ์กด Test-Time-RLVR์˜ load_humaneval_problem ํ•จ์ˆ˜๋ฅผ ํ™•์žฅ
"""
import json
import os
from typing import Dict, List, Any, Tuple, Optional
from pathlib import Path
from .config import BenchmarkConfig, TestTimeConfig
from .logger import TestTimeLogger


class BenchmarkProblemLoader:
    """Loads and manages benchmark problems (uses the standard EvalPlus approach)."""

    def __init__(self, config: TestTimeConfig, logger: Optional[TestTimeLogger] = None):
        self.config = config
        self.logger = logger or TestTimeLogger()
        self.loaded_problems = {}   # problem cache
        self.evalplus_cache = {}    # EvalPlus data cache

    def _load_evalplus_data(self, benchmark_name: str) -> Dict[str, Dict[str, Any]]:
        """Load and cache the EvalPlus dataset for the given benchmark."""
        if benchmark_name in self.evalplus_cache:
            return self.evalplus_cache[benchmark_name]
        try:
            if benchmark_name == 'mbpp':
                from evalplus.data.mbpp import get_mbpp_plus
                problems = get_mbpp_plus()  # mbpp_deserialize_inputs is applied automatically
                self.logger.log_info(f"✅ MBPP+ EvalPlus data loaded: {len(problems)} problems")
            elif benchmark_name == 'humaneval':
                from evalplus.data.humaneval import get_human_eval_plus
                problems = get_human_eval_plus()  # standard EvalPlus loader
                self.logger.log_info(f"✅ HumanEval+ EvalPlus data loaded: {len(problems)} problems")
            else:
                raise ValueError(f"Unsupported benchmark for EvalPlus: {benchmark_name}")
            self.evalplus_cache[benchmark_name] = problems
            return problems
        except Exception as e:
            self.logger.log_error(f"❌ Failed to load {benchmark_name.upper()}+ EvalPlus data: {e}")
            return {}
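
    # Note: get_mbpp_plus() / get_human_eval_plus() return dicts keyed by task_id
    # (e.g. "Mbpp/2", "HumanEval/0"), which is why load_problem below can index
    # them directly with problem_id.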

    def load_problem(self, benchmark_config: BenchmarkConfig, problem_id: str) -> Dict[str, Any]:
        """Load a single benchmark problem, preferring the standard EvalPlus data."""
        cache_key = f"{benchmark_config.name}_{problem_id}"
        if cache_key in self.loaded_problems:
            return self.loaded_problems[cache_key]

        # Try the EvalPlus data first
        if benchmark_config.name in ['mbpp', 'humaneval']:
            evalplus_problems = self._load_evalplus_data(benchmark_config.name)
            if problem_id in evalplus_problems:
                problem = evalplus_problems[problem_id].copy()
                # Attach additional metadata
                problem['benchmark_name'] = benchmark_config.name
                problem['benchmark_config'] = benchmark_config
                # Store in cache
                self.loaded_problems[cache_key] = problem
                self.logger.log_info(f"✅ Problem loaded: {problem_id} from {benchmark_config.name} (EvalPlus)")
                return problem

        # Fallback: original JSONL-based loading
        self.logger.log_info(f"⚠️ EvalPlus lookup failed for {problem_id}, falling back to the original method")
        problem_file = benchmark_config.data_path

        # Check that the data file exists
        if not os.path.exists(problem_file):
            raise FileNotFoundError(f"Benchmark file not found: {problem_file}")

        # Load the JSONL file (same as the original method)
        with open(problem_file, 'r', encoding='utf-8') as f:
            problems = [json.loads(line) for line in f]

        # Search by problem ID
        for problem in problems:
            if problem['task_id'] == problem_id:
                # Attach additional metadata
                problem['benchmark_name'] = benchmark_config.name
                problem['benchmark_config'] = benchmark_config
                # Store in cache
                self.loaded_problems[cache_key] = problem
                self.logger.log_info(f"✅ Problem loaded: {problem_id} from {benchmark_config.name} (Original)")
                return problem

        raise ValueError(f"Problem {problem_id} not found in {problem_file}")

    def load_problem_batch(self, benchmark_config: BenchmarkConfig,
                           problem_ids: List[str]) -> List[Dict[str, Any]]:
        """Load a batch of problems by ID."""
        problems = []
        for problem_id in problem_ids:
            try:
                problem = self.load_problem(benchmark_config, problem_id)
                problems.append(problem)
            except Exception as e:
                self.logger.log_error(f"Failed to load {problem_id}: {e}")
        return problems

    def get_test_cases(self, problem: Dict[str, Any]) -> List[Tuple[str, str]]:
        """Extract test cases from a problem."""
        test_cases = []

        # Base test cases (the 'test' field)
        if 'test' in problem:
            test_code = problem['test']
            # Extract input-output pairs from assert statements
            test_cases.extend(self._parse_assert_statements(test_code))

        # Plus test cases (plus_input, plus_output)
        if 'plus_input' in problem and 'plus_output' in problem:
            plus_inputs = problem['plus_input']
            plus_outputs = problem['plus_output']
            if isinstance(plus_inputs, str):
                plus_inputs = json.loads(plus_inputs)
            if isinstance(plus_outputs, str):
                plus_outputs = json.loads(plus_outputs)
            for inp, out in zip(plus_inputs, plus_outputs):
                test_cases.append((str(inp), str(out)))

        return test_cases

    def _parse_assert_statements(self, test_code: str) -> List[Tuple[str, str]]:
        """Extract input-output pairs from assert statements."""
        import re
        test_cases = []
        lines = test_code.strip().split('\n')
        for line in lines:
            line = line.strip()
            if line.startswith('assert '):
                # Parse the form: assert function(args) == expected
                match = re.match(r'assert\s+(\w+)\(([^)]*)\)\s*==\s*(.+)', line)
                if match:
                    func_name, args, expected = match.groups()
                    test_cases.append((args.strip(), expected.strip()))
        return test_cases
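
    # Example (illustrative): for a line such as
    #   assert add(2, 3) == 5
    # _parse_assert_statements returns [("2, 3", "5")]. Note that the simple regex
    # above stops at the first ')' and therefore does not handle nested parentheses
    # or multi-line assertions.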

    def validate_solution(self, problem: Dict[str, Any], solution: str) -> Dict[str, Any]:
        """Validate a solution (AZR Python Executor integration planned)."""
        validation_result = {
            'problem_id': problem['task_id'],
            'solution': solution,
            'syntax_valid': False,
            'test_results': [],
            'overall_success': False,
            'error_message': None
        }
        try:
            # 1. Syntax check
            compile(solution, '<string>', 'exec')
            validation_result['syntax_valid'] = True

            # 2. Collect test-case records (to be executed via the AZR Python Executor later)
            test_cases = self.get_test_cases(problem)
            validation_result['test_results'] = [
                {'input': inp, 'expected': out, 'passed': False}
                for inp, out in test_cases
            ]

            # Temporary: count the solution as successful if it merely compiles
            validation_result['overall_success'] = True
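
            # Sketch (assumption, not implemented here): once the AZR Python Executor
            # is wired in, each test case could be run roughly like
            #   namespace = {}
            #   exec(solution, namespace)
            #   result = eval(f"{problem['entry_point']}({inp})", namespace)
            #   passed = str(result) == out
            # where 'entry_point' is the function-name field present in
            # HumanEval/MBPP+ problem dicts.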
        except SyntaxError as e:
            validation_result['error_message'] = f"Syntax Error: {e}"
        except Exception as e:
            validation_result['error_message'] = f"Validation Error: {e}"

        return validation_result

    def get_sequential_problems(self, benchmark_config: BenchmarkConfig,
                                num_problems: int) -> List[Dict[str, Any]]:
        """Load N problems sequentially, starting from benchmark_config.start_index."""
        problems = []
        for i in range(num_problems):
            problem_index = benchmark_config.start_index + i
            problem_id = f"{benchmark_config.problem_prefix}/{problem_index}"
            try:
                problem = self.load_problem(benchmark_config, problem_id)
                problems.append(problem)
            except Exception as e:
                self.logger.log_error(f"Failed to load {problem_id}: {e}")
                continue
        return problems

    def get_problem_statistics(self, benchmark_config: BenchmarkConfig) -> Dict[str, Any]:
        """Return summary statistics for a benchmark data file."""
        problem_file = benchmark_config.data_path
        if not os.path.exists(problem_file):
            return {"error": f"File not found: {problem_file}"}

        with open(problem_file, 'r', encoding='utf-8') as f:
            problems = [json.loads(line) for line in f]

        stats = {
            'total_problems': len(problems),
            'benchmark_name': benchmark_config.name,
            'data_file': problem_file,
            'sample_problem_ids': [p['task_id'] for p in problems[:5]]
        }
        return stats
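

# Illustrative usage (a sketch; assumes TestTimeConfig is default-constructible and
# that BenchmarkConfig exposes the attributes referenced above, i.e. name, data_path,
# problem_prefix, start_index; the actual constructors may differ):
#
#   config = TestTimeConfig()
#   bench = BenchmarkConfig(
#       name="humaneval",
#       data_path="data/HumanEvalPlus.jsonl",   # hypothetical fallback JSONL path
#       problem_prefix="HumanEval",
#       start_index=0,
#   )
#   loader = BenchmarkProblemLoader(config)
#   problems = loader.get_sequential_problems(bench, num_problems=3)
#   for p in problems:
#       print(p["task_id"], len(loader.get_test_cases(p)))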