""" Benchmark Problem Loader AZR 기반 TestTime RLVR을 위한 벤치마크 문제 로딩 시스템 기존 Test-Time-RLVR의 load_humaneval_problem 함수를 확장 """ import json import os from typing import Dict, List, Any, Tuple, Optional from pathlib import Path from .config import BenchmarkConfig, TestTimeConfig from .logger import TestTimeLogger class BenchmarkProblemLoader: """벤치마크 문제 로딩 및 관리 (EvalPlus 표준 방식 사용)""" def __init__(self, config: TestTimeConfig, logger: Optional[TestTimeLogger] = None): self.config = config self.logger = logger or TestTimeLogger() self.loaded_problems = {} # 캐시 self.evalplus_cache = {} # EvalPlus 데이터 캐시 def _load_evalplus_data(self, benchmark_name: str) -> Dict[str, Dict[str, Any]]: """EvalPlus 데이터 로드 및 캐시""" if benchmark_name in self.evalplus_cache: return self.evalplus_cache[benchmark_name] try: if benchmark_name == 'mbpp': from evalplus.data.mbpp import get_mbpp_plus problems = get_mbpp_plus() # 자동으로 mbpp_deserialize_inputs 적용됨 self.logger.log_info(f"✅ MBPP+ EvalPlus 데이터 로드 성공: {len(problems)}개 문제") elif benchmark_name == 'humaneval': from evalplus.data.humaneval import get_human_eval_plus problems = get_human_eval_plus() # EvalPlus 표준 방식 self.logger.log_info(f"✅ HumanEval+ EvalPlus 데이터 로드 성공: {len(problems)}개 문제") else: raise ValueError(f"Unsupported benchmark for EvalPlus: {benchmark_name}") self.evalplus_cache[benchmark_name] = problems return problems except Exception as e: self.logger.log_error(f"❌ {benchmark_name.upper()}+ EvalPlus 로딩 실패: {e}") return {} def load_problem(self, benchmark_config: BenchmarkConfig, problem_id: str) -> Dict[str, Any]: """특정 벤치마크 문제 로드 (EvalPlus 표준 방식 우선 사용)""" cache_key = f"{benchmark_config.name}_{problem_id}" if cache_key in self.loaded_problems: return self.loaded_problems[cache_key] # EvalPlus 방식 시도 if benchmark_config.name in ['mbpp', 'humaneval']: evalplus_problems = self._load_evalplus_data(benchmark_config.name) if problem_id in evalplus_problems: problem = evalplus_problems[problem_id].copy() # 추가 메타데이터 설정 problem['benchmark_name'] = benchmark_config.name problem['benchmark_config'] = benchmark_config # 캐시에 저장 self.loaded_problems[cache_key] = problem self.logger.log_info(f"✅ Problem loaded: {problem_id} from {benchmark_config.name} (EvalPlus)") return problem # Fallback: 기존 방식 self.logger.log_info(f"⚠️ {problem_id} EvalPlus 로딩 실패, 기존 방식 사용") problem_file = benchmark_config.data_path # 파일 존재 확인 if not os.path.exists(problem_file): raise FileNotFoundError(f"Benchmark file not found: {problem_file}") # JSONL 파일 로드 (기존 방식과 동일) with open(problem_file, 'r', encoding='utf-8') as f: problems = [json.loads(line) for line in f] # 문제 ID로 검색 for problem in problems: if problem['task_id'] == problem_id: # 추가 메타데이터 설정 problem['benchmark_name'] = benchmark_config.name problem['benchmark_config'] = benchmark_config # 캐시에 저장 self.loaded_problems[cache_key] = problem self.logger.log_info(f"✅ Problem loaded: {problem_id} from {benchmark_config.name} (Original)") return problem raise ValueError(f"Problem {problem_id} not found in {problem_file}") def load_problem_batch(self, benchmark_config: BenchmarkConfig, problem_ids: List[str]) -> List[Dict[str, Any]]: """여러 문제 배치 로딩""" problems = [] for problem_id in problem_ids: try: problem = self.load_problem(benchmark_config, problem_id) problems.append(problem) except Exception as e: self.logger.log_error(f"Failed to load {problem_id}: {e}") return problems def get_test_cases(self, problem: Dict[str, Any]) -> List[Tuple[str, str]]: """문제에서 테스트 케이스 추출""" test_cases = [] # 기본 테스트 케이스 (test 필드) if 'test' in problem: test_code = problem['test'] # assert 문에서 입력-출력 쌍 추출 test_cases.extend(self._parse_assert_statements(test_code)) # Plus 테스트 케이스 (plus_input, plus_output) if 'plus_input' in problem and 'plus_output' in problem: plus_inputs = problem['plus_input'] plus_outputs = problem['plus_output'] if isinstance(plus_inputs, str): plus_inputs = json.loads(plus_inputs) if isinstance(plus_outputs, str): plus_outputs = json.loads(plus_outputs) for inp, out in zip(plus_inputs, plus_outputs): test_cases.append((str(inp), str(out))) return test_cases def _parse_assert_statements(self, test_code: str) -> List[Tuple[str, str]]: """assert 문에서 입력-출력 쌍 추출""" import re test_cases = [] lines = test_code.strip().split('\n') for line in lines: line = line.strip() if line.startswith('assert '): # assert function(args) == expected 형태 파싱 match = re.match(r'assert\s+(\w+)\(([^)]*)\)\s*==\s*(.+)', line) if match: func_name, args, expected = match.groups() test_cases.append((args.strip(), expected.strip())) return test_cases def validate_solution(self, problem: Dict[str, Any], solution: str) -> Dict[str, Any]: """솔루션 검증 (AZR Python Executor 사용 예정)""" validation_result = { 'problem_id': problem['task_id'], 'solution': solution, 'syntax_valid': False, 'test_results': [], 'overall_success': False, 'error_message': None } try: # 1. 구문 검증 compile(solution, '', 'exec') validation_result['syntax_valid'] = True # 2. 테스트 케이스 실행 (향후 AZR Python Executor 연동) test_cases = self.get_test_cases(problem) validation_result['test_results'] = [ {'input': inp, 'expected': out, 'passed': False} for inp, out in test_cases ] # 임시: 구문만 통과하면 성공으로 처리 validation_result['overall_success'] = True except SyntaxError as e: validation_result['error_message'] = f"Syntax Error: {e}" except Exception as e: validation_result['error_message'] = f"Validation Error: {e}" return validation_result def get_sequential_problems(self, benchmark_config: BenchmarkConfig, num_problems: int) -> List[Dict[str, Any]]: """순차적으로 N개 문제 로드""" problems = [] for i in range(num_problems): problem_index = benchmark_config.start_index + i problem_id = f"{benchmark_config.problem_prefix}/{problem_index}" try: problem = self.load_problem(benchmark_config, problem_id) problems.append(problem) except Exception as e: self.logger.log_error(f"Failed to load {problem_id}: {e}") continue return problems def get_problem_statistics(self, benchmark_config: BenchmarkConfig) -> Dict[str, Any]: """벤치마크 통계 정보""" problem_file = benchmark_config.data_path if not os.path.exists(problem_file): return {"error": f"File not found: {problem_file}"} with open(problem_file, 'r', encoding='utf-8') as f: problems = [json.loads(line) for line in f] stats = { 'total_problems': len(problems), 'benchmark_name': benchmark_config.name, 'data_file': problem_file, 'sample_problem_ids': [p['task_id'] for p in problems[:5]] } return stats