diff --git "a/absolute_zero_reasoner/testtime/complete_pipeline.py" "b/absolute_zero_reasoner/testtime/complete_pipeline.py" new file mode 100644--- /dev/null +++ "b/absolute_zero_reasoner/testtime/complete_pipeline.py" @@ -0,0 +1,1943 @@ +""" +Complete TestTime RLVR Pipeline + +LLM 솔루션 생성 → IPO 추출 → 태스크 생성 → LLM 평가 → Reward 계산 +모든 단계에서 AZR 코드를 최대한 그대로 활용 +""" + +from typing import Dict, List, Any, Optional +import torch +import re +import os +import json +import ray +import math +from pathlib import Path +from datetime import datetime + +from .benchmark_loader import BenchmarkProblemLoader +from .solution_generator import InitialSolutionGenerator +from .ipo_extractor import IPOTripleExtractor, IPOBuffer +from .task_generator import TestTimeTaskGenerator +from .config import TestTimeConfig, BenchmarkConfig +# Ray Actor 제거 - VLLM 배치 처리 사용 +from .logger import TestTimeLogger + +# AZR Reward Manager 직접 사용 +from ..rewards.reward_managers import CodeIORewardManager + + +@ray.remote +class RemoteTestTimePipeline: + """Ray Actor로 작동하는 TestTime Pipeline (VeRL 패턴)""" + + def __init__(self, config: TestTimeConfig, model_path: str): + """Ray worker 내부에서 모델 로딩""" + self.config = config + self.model_path = model_path + + # Ray worker에서 VLLM 모델 로딩 + from .solution_generator import InitialSolutionGenerator + import os + + # Ray runtime_env에서 쉼표가 잘리는 문제 해결 + # VLLM_USE_SPECIFIC_GPUS가 설정되어 있으면 그걸 사용 + if 'VLLM_USE_SPECIFIC_GPUS' in os.environ: + os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['VLLM_USE_SPECIFIC_GPUS'] + print(f"[RemoteTestTimePipeline] Restored CUDA_VISIBLE_DEVICES from VLLM_USE_SPECIFIC_GPUS: {os.environ['CUDA_VISIBLE_DEVICES']}") + + # GPU 설정 + device = 'cuda:0' + if 'CUDA_VISIBLE_DEVICES' in os.environ: + device = f"cuda:0" + + # 멀티 GPU 환경에서는 VLLM 사용 + cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + print(f"[RemoteTestTimePipeline] CUDA_VISIBLE_DEVICES: {cuda_devices}") + + # config에서 명시적으로 use_vllm_for_data_generation 설정 확인 + use_vllm = 
getattr(config, 'use_vllm_for_data_generation', len(cuda_devices.split(',')) > 1) + gpu_count = len(cuda_devices.split(',')) + + # Step 5 (VeRL)와 GPU 공유를 위해 Step 1-4는 VLLM 2개 GPU만 사용 + vllm_tensor_parallel_size = min(2, gpu_count) if use_vllm else 1 + + print(f"[RemoteTestTimePipeline] GPU count: {gpu_count}, use_vllm: {use_vllm}, tensor_parallel_size: {vllm_tensor_parallel_size}") + + self.model, self.tokenizer = InitialSolutionGenerator.load_model_with_optimizations( + model_path, device, config, use_vllm=use_vllm, tensor_parallel_size=vllm_tensor_parallel_size + ) + + # 로거 설정 - Ray worker에서도 동일한 로그 파일 사용 + import os + log_file = os.environ.get('TTRLVR_LOG_FILE', None) + if log_file: + self.logger = TestTimeLogger(log_file=log_file) + else: + self.logger = TestTimeLogger() + + # CompleteTestTimePipeline 초기화 + self.pipeline = CompleteTestTimePipeline( + model=self.model, + tokenizer=self.tokenizer, + config=config, + logger=self.logger + ) + + def run_complete_pipeline(self, benchmark_config: BenchmarkConfig, + problem_id: str, round_num: int = 1, session_timestamp: str = None, + output_base_dir: str = None) -> Dict[str, Any]: + """원격에서 파이프라인 실행""" + return self.pipeline.run_complete_pipeline(benchmark_config, problem_id, round_num, session_timestamp, output_base_dir) + + def generate_batch_vllm(self, prompts: List[str], max_tokens: int = 512, + temperature: float = 0.7, top_p: float = 1.0, n: int = 1) -> Dict[str, Any]: + """ + VeRL에서 호출할 수 있는 배치 생성 메서드 + Step 5의 SharedVLLMRollout에서 사용 + """ + from .solution_generator import InitialSolutionGenerator + + # VLLM 배치 생성 + if hasattr(self.model, 'generate'): + # VLLM 모델 + outputs = InitialSolutionGenerator.generate_batch_vllm( + self.model, prompts, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + n=n + ) + + # 결과 포맷팅 + responses = [] + for output in outputs: + generated_text = output.outputs[0].text + responses.append(generated_text) + + # Tokenizer로 input_ids 생성 (VeRL이 필요로 함) + tokenized = 
self.tokenizer(prompts, padding=True, truncation=True, return_tensors="pt") + + return { + 'responses': responses, + 'input_ids': tokenized['input_ids'].tolist(), + 'attention_mask': tokenized['attention_mask'].tolist() + } + else: + # HuggingFace 모델 (fallback) + raise NotImplementedError("HuggingFace batch generation not implemented for VeRL sharing") + + def update_model_weights(self, model_path: str) -> bool: + """ + 학습된 모델로 VLLM 가중치 업데이트 + 매 라운드 후 호출됨 + """ + try: + self.logger.log_info(f"🔄 Updating VLLM weights from: {model_path}") + + # VLLM은 동적 가중치 업데이트를 지원하지 않으므로 + # 새로운 엔진으로 교체해야 함 + from .solution_generator import InitialSolutionGenerator + import os + + device = 'cuda:0' + use_vllm = len(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(',')) > 1 + gpu_count = len(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(',')) + vllm_tensor_parallel_size = min(2, gpu_count) if use_vllm else 1 + + # 기존 VLLM 엔진 정리 + if hasattr(self.model, 'llm_engine'): + del self.model + import torch + torch.cuda.empty_cache() + self.logger.log_info(" - Old VLLM engine cleaned up") + + # 새로운 모델 로드 + self.model, _ = InitialSolutionGenerator.load_model_with_optimizations( + model_path, device, self.config, + use_vllm=use_vllm, + tensor_parallel_size=vllm_tensor_parallel_size + ) + + # Pipeline 인스턴스 업데이트 + self.pipeline.model = self.model + self.pipeline.solution_generator.model = self.model + + self.logger.log_info(f"✅ VLLM weights updated successfully") + return True + + except Exception as e: + self.logger.log_error(f"Failed to update VLLM weights: {e}") + return False + + def update_model_weights_from_state_dict(self, state_dict: Dict[str, Any]) -> bool: + """ + State dict로 직접 가중치 업데이트 (더 효율적) + VeRL에서 학습된 가중치를 직접 전달받아 업데이트 + """ + try: + self.logger.log_info("🔄 Updating VLLM weights from state dict") + + # VLLM은 동적 업데이트를 지원하지 않으므로 이 방법은 제한적 + # HuggingFace 모델인 경우에만 가능 + if not hasattr(self.model, 'llm_engine'): + # HuggingFace 모델 + self.model.load_state_dict(state_dict) + 
self.logger.log_info("✅ Model weights updated via state dict") + return True + else: + # VLLM은 파일 기반 로드만 지원 + self.logger.log_warning("⚠️ VLLM requires file-based weight loading") + return False + + except Exception as e: + self.logger.log_error(f"Failed to update weights from state dict: {e}") + return False + + def cleanup(self): + """Ray Actor 종료 전 리소스 정리""" + try: + self.logger.log_info("🧹 Cleaning up RemoteTestTimePipeline resources...") + + # VLLM 모델이 있으면 정리 + if hasattr(self, 'model') and self.model is not None: + self.logger.log_info(" - Cleaning up VLLM model...") + # VLLM 인스턴스 삭제 + del self.model + self.model = None + + # Pipeline 정리 + if hasattr(self, 'pipeline') and self.pipeline is not None: + if hasattr(self.pipeline, 'cleanup'): + self.pipeline.cleanup() + del self.pipeline + self.pipeline = None + + # GPU 메모리 정리 + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Garbage collection 강제 실행 + import gc + gc.collect() + + self.logger.log_info("✅ Cleanup completed") + return True + except Exception as e: + self.logger.log_error(f"⚠️ Cleanup error: {e}") + return False + + +class CompleteTestTimePipeline: + """완전한 TestTime RLVR 파이프라인""" + + def __init__(self, model, tokenizer, config: TestTimeConfig, + logger: Optional[TestTimeLogger] = None): + self.model = model + self.tokenizer = tokenizer + self.config = config + # 로거 초기화 시 task_output_dir 전달 (라운드별 경로는 run() 호출 시 설정) + self.logger = logger or TestTimeLogger() + + # 각 컴포넌트 초기화 + self.benchmark_loader = BenchmarkProblemLoader(config, self.logger) + + # 모델이 None이 아닌 경우에만 컴포넌트 초기화 + if model is not None and tokenizer is not None: + # 엔진 선택 설정 (config에서 가져오기, 기본값: VLLM) + use_vllm = getattr(config, 'use_vllm_for_data_generation', True) + self.solution_generator = InitialSolutionGenerator(model, tokenizer, config, self.logger, use_vllm=use_vllm) + self.ipo_extractor = IPOTripleExtractor(config, self.logger, model, tokenizer) + # IPO extractor에 solution generator 참조 전달 (배치 처리용) + 
self.ipo_extractor.solution_generator = self.solution_generator + self.reward_manager = self._setup_azr_reward_manager() + else: + # Lazy initialization 플래그 + self.solution_generator = None + self.ipo_extractor = None + self.reward_manager = None + + self.task_generator = TestTimeTaskGenerator(config, self.logger) + + # 실행 모드 설정 + self.execution_mode = "single_gpu" # 기본값, iterative_trainer에서 설정됨 + self.available_gpus = [] + + # IPO Buffer 초기화 + self.ipo_buffer = IPOBuffer() + + # Task output directory 설정 + self.task_output_dir = Path('./tmp/batch_results') + + def _ensure_models_loaded(self): + """모델과 컴포넌트들이 로드되었는지 확인하고 필요시 초기화""" + + if self.model is None or self.tokenizer is None: + raise RuntimeError("Model and tokenizer must be provided during initialization") + + # 컴포넌트들이 None인 경우 초기화 + if self.solution_generator is None: + # 엔진 선택 설정 (config에서 가져오기, 기본값: VLLM) + use_vllm = getattr(self.config, 'use_vllm_for_data_generation', True) + self.logger.log_info(f"🔧 Initializing solution generator with use_vllm={use_vllm}") + try: + self.solution_generator = InitialSolutionGenerator( + self.model, self.tokenizer, self.config, self.logger, use_vllm=use_vllm + ) + self.logger.log_info(f"✅ Solution generator initialized successfully") + except Exception as e: + self.logger.log_error(f"❌ Failed to initialize solution generator: {e}") + import traceback + self.logger.log_error(f"Traceback: {traceback.format_exc()}") + raise + + if self.ipo_extractor is None: + self.ipo_extractor = IPOTripleExtractor( + self.config, self.logger, self.model, self.tokenizer + ) + # IPO extractor에 solution generator 참조 전달 + self.ipo_extractor.solution_generator = self.solution_generator + + if self.reward_manager is None: + self.reward_manager = self._setup_azr_reward_manager() + + self.logger.log_info("✅ All components ready") + + def set_execution_mode(self, execution_mode: str, available_gpus: List[int]): + """실행 모드 설정 (iterative_trainer에서 호출)""" + self.execution_mode = execution_mode + 
self.available_gpus = available_gpus + + self.logger.log_info(f"🎯 Execution mode set to: {execution_mode}") + self.logger.log_info(f"🎯 Available GPUs: {available_gpus}") + + + def _setup_azr_reward_manager(self) -> CodeIORewardManager: + """AZR Reward Manager 설정 (기존 설정 그대로 사용)""" + + # AZR에서 사용하는 설정으로 초기화 + class SimpleConfig: + def __init__(self): + self.use_original_code_as_ref = False + self.reward_type = 'code_execution' + self.weight = 1.0 + + reward_manager = CodeIORewardManager( + tokenizer=self.tokenizer, + num_examine=0, + reward_fn_extraction_type='rule', + math_metric='accuracy', + split='test', + splitter='boxed', + output_path='./testtime_output', + max_prompt_length=1024, + generation_reward_config=SimpleConfig() + ) + + return reward_manager + + def run_complete_pipeline(self, benchmark_config: BenchmarkConfig, + problem_id: str, round_num: int = 1, session_timestamp: str = None, + output_base_dir: str = None) -> Dict[str, Any]: + """완전한 파이프라인 실행 + + Args: + benchmark_config: 벤치마크 설정 + problem_id: 문제 ID + round_num: 라운드 번호 + session_timestamp: 세션 타임스탬프 + output_base_dir: 로그 저장 기본 디렉토리 (None이면 기본 경로 사용) + """ + + # 설��된 디렉토리 구조에 맞는 로거 재설정 + if session_timestamp is None: + # 독립 실행 시 새 timestamp 생성 + session_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + + benchmark_safe = benchmark_config.name + problem_safe = problem_id.replace('/', '_') + + # output_base_dir이 제공되면 사용, 아니면 기본 경로 사용 + if output_base_dir: + round_log_dir = os.path.join(output_base_dir, benchmark_safe, problem_safe, f'round_{round_num}') + else: + round_log_dir = f'/home/ubuntu/RLVR/TestTime-RLVR-v2/tmp/batch_results/ttrlvr_azr_{session_timestamp}/{benchmark_safe}/{problem_safe}/round_{round_num}' + + # 새 로거로 재설정 (설계된 구조 사용) + self.logger = TestTimeLogger(task_output_dir=round_log_dir) + + self.logger.log_info(f"🚀 Starting complete TestTime RLVR pipeline for {problem_id}") + + # 모델이 로드되지 않았으면 로드 + self._ensure_models_loaded() + + pipeline_result = { + 'problem_id': problem_id, + 
'benchmark': benchmark_config.name, + 'round': round_num, + 'output_dir': round_log_dir, + 'steps': {}, + 'success': False, + 'error': None + } + + try: + # Step 1: 벤치마크 문제 로딩 + self.logger.log_info("📄 Step 1: Loading benchmark problem") + problem = self.benchmark_loader.load_problem(benchmark_config, problem_id) + pipeline_result['steps']['problem_loading'] = { + 'success': True, + 'problem': problem + } + + # Step 1.5: 베이스라인 성능 측정 (NEW) + self.logger.log_info("📊 Step 1.5: Baseline performance evaluation") + baseline_results = self._evaluate_baseline_performance(problem) + pipeline_result['steps']['baseline_evaluation'] = baseline_results + + # 🔄 라운드별 IPO buffer 초기화 (각 라운드는 독립적) + self.logger.log_info(f"🔄 Clearing IPO buffer for round {round_num}") + self.ipo_buffer.clear(problem_id) + + # Step 2: 다양한 프로그램 생성 및 IPO 처리 (NEW) + diverse_programs_results = self._generate_diverse_programs_and_ipo(problem) + pipeline_result['steps']['diverse_programs'] = diverse_programs_results + + # Diverse programs 평가 결과 저장 + self._save_diverse_programs_evaluation(problem, diverse_programs_results) + + if not diverse_programs_results['success']: + self.logger.log_error(f"❌ No valid diverse programs generated") + pipeline_result['error'] = "No valid diverse programs could be generated" + return pipeline_result + + # Step 3: 현재 라운드에서 생성된 IPO triples로만 태스크 생성 + self.logger.log_info("🎯 Step 3: Generating tasks from current round IPO triples") + current_round_triples = self.ipo_buffer.get_all(problem_id) + self.logger.log_info(f"🎯 Using {len(current_round_triples)} IPO triples from current round") + all_tasks = self.task_generator.generate_tasks(current_round_triples, problem_id, round_num) + + total_tasks = sum(len(tasks) for tasks in all_tasks.values()) + pipeline_result['steps']['task_generation'] = { + 'success': total_tasks > 0, + 'total_tasks': total_tasks, + 'tasks_by_type': {k: len(v) for k, v in all_tasks.items()}, + 'all_tasks': all_tasks + } + # Step 4: LLM으로 태스크 평가 (스킵 가능) + 
if getattr(self.config, 'skip_task_evaluation', False): + self.logger.log_info("⏭️ Step 4: Skipping task evaluation (fast mode)") + task_evaluations = {task_type: [] for task_type in all_tasks.keys()} + pipeline_result['steps']['task_evaluation'] = { + 'success': True, + 'skipped': True, + 'evaluations': task_evaluations + } + else: + self.logger.log_info("💭 Step 4: Evaluating tasks with LLM") + task_evaluations = self._evaluate_tasks_with_llm(all_tasks) + + pipeline_result['steps']['task_evaluation'] = { + 'success': True, + 'evaluations': task_evaluations + } + + # Step 5: Reward 계산 (AZR Reward Manager 사용) + self.logger.log_info("🏆 Step 5: Computing rewards") + buffered_triples = self.ipo_buffer.get_all(problem_id) + rewards = self._compute_rewards_with_azr(task_evaluations, buffered_triples) + + pipeline_result['steps']['reward_computation'] = { + 'success': True, + 'rewards': rewards + } + + # Step 6: AZR 학습용 데이터 저장 + self.logger.log_info("💾 Step 6: Saving AZR training data") + output_dir = pipeline_result.get('output_dir', './testtime_output') + azr_files = self._save_azr_training_data(all_tasks, problem_id, round_num, output_dir) + + pipeline_result['steps']['azr_data_saving'] = { + 'success': len(azr_files) > 0, + 'files': azr_files, + 'total_tasks': sum(len(tasks) for tasks in all_tasks.values()) + } + + # Step 7: Summary 파일 생성 (batch evaluation과 동일한 형식) + self.logger.log_info("📋 Step 7: Generating task summary") + self._save_task_summary_json(problem, baseline_results, task_evaluations, round_num) + + # 전체 성공 + pipeline_result['success'] = True + pipeline_result['azr_training_data'] = azr_files # AZR 데이터 파일 경로 추가 + self.logger.log_info("✅ Complete pipeline executed successfully") + + return pipeline_result + + except Exception as e: + self.logger.log_error(f"💥 Pipeline failed: {e}") + pipeline_result['error'] = str(e) + return pipeline_result + + finally: + # 리소스 정리 + self.ipo_extractor.cleanup() + + def _evaluate_tasks_with_llm(self, all_tasks: Dict[str, 
List[Dict[str, Any]]]) -> Dict[str, List[Dict[str, Any]]]: + """LLM으로 생성된 태스크들 평가하고 basic_accuracy 업데이트""" + + evaluations = {} + + # 정확도 계산용 executor 초기화 + from ..utils.code_utils.python_executor import PythonExecutor + executor = PythonExecutor() + + for task_type, tasks in all_tasks.items(): + self.logger.log_info(f"🔄 Evaluating {len(tasks)} {task_type} tasks") + task_evaluations = [] + + for task in tasks: + # LLM으로 태스크 해결 + task_prompt = task['prompt'] + + # AZR 방식으로 생성 + llm_response = self._generate_task_response(task_prompt) + + # 평가 결과 저장 + evaluation = { + 'task_id': task['task_id'], + 'task_type': task_type, + 'prompt': task_prompt, + 'llm_response': llm_response, + 'expected_solution': task['expected_solution'], + 'evaluation_data': task['evaluation_data'] + } + + # 🆕 정확도 계산 및 task 업데이트 + accuracy = self._calculate_task_accuracy(evaluation, task_type, executor) + task['basic_accuracy'] = accuracy # 원본 task 객체 업데이트 + evaluation['basic_accuracy'] = accuracy # evaluation에도 추가 + + task_evaluations.append(evaluation) + + evaluations[task_type] = task_evaluations + + # LLM 응답 저장 + self._save_llm_responses(task_type, task_evaluations) + + return evaluations + + def _generate_task_response(self, prompt: str) -> str: + """단일 태스크에 대한 LLM 응답 생성 (AZR 방식)""" + + # VLLM 사용 여부 확인 + try: + from vllm import LLM + if isinstance(self.model, LLM): + # VLLM 모델인 경우 + from vllm import SamplingParams + + sampling_params = SamplingParams( + temperature=0.05, + max_tokens=512, + top_p=0.95, + stop=["\n\n\n", "# Task:", "================================================================================"] # 더 구체적인 stop token + ) + + outputs = self.model.generate([prompt], sampling_params, use_tqdm=False) + response = outputs[0].outputs[0].text.replace("\t", " ") + return response.strip() + except ImportError: + pass + + # HuggingFace 모델인 경우 + inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True, max_length=4096) + + # attention mask 명시적으로 설정 + if 'attention_mask' not 
in inputs: + inputs['attention_mask'] = torch.ones_like(inputs['input_ids']) + + inputs = {k: v.to(self.model.device) for k, v in inputs.items()} + + with torch.no_grad(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # AZR evaluation과 동���한 설정 (VLLM과 동일한 temperature 사용) + outputs = self.model.generate( + inputs['input_ids'], + attention_mask=inputs['attention_mask'], # attention mask 명시적으로 전달 + max_new_tokens=256, # 태스크 응답용으로 적당한 길이 + do_sample=True, # sampling 활성화 + temperature=0.05, # VLLM과 동일한 temperature + top_p=0.95, # VLLM과 동일한 top_p + pad_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.eos_token_id + ) + + response = self.tokenizer.decode(outputs[0], skip_special_tokens=True) + response = response[len(prompt):].strip() + + return response + + def _calculate_task_accuracy(self, evaluation: Dict[str, Any], task_type: str, executor) -> float: + """개별 task의 정확도 계산""" + + try: + llm_response = evaluation['llm_response'] + expected = evaluation['expected_solution'] + evaluation_data = evaluation['evaluation_data'] + + # AZR 방식으로 답변 추출 + extracted_answer = self._extract_answer_by_task_type(llm_response, task_type) + + if task_type == 'abduction': + # Abduction: LLM이 생성한 input을 function에 입력 → 결과와 expected output 비교 + code = evaluation_data['function_code'] + expected_output_value = evaluation_data['expected_output'] + agent_input = extracted_answer + + try: + # 함수명을 f로 변경 (EVAL_INPUT_PREDICTION_TEMPLATE이 f를 기대함) + import re + func_name_match = re.search(r'def\s+(\w+)\s*\(', code) + if func_name_match: + original_func_name = func_name_match.group(1) + # 함수명을 f로 변경 + code = re.sub(r'def\s+' + re.escape(original_func_name) + r'\s*\(', 'def f(', code) + + from ..utils.code_utils.templates import EVAL_INPUT_PREDICTION_TEMPLATE + code_snippet = EVAL_INPUT_PREDICTION_TEMPLATE.format( + code=code, + gold_output=expected_output_value, + agent_input=agent_input + ) + result, status = executor.apply(code_snippet) + + if 'error' in 
status.lower(): + accuracy = 0.0 + else: + try: + if isinstance(result, bool): + agent_output = result + else: + agent_output = eval(result) + accuracy = 1.0 if agent_output else 0.0 + except: + accuracy = 0.0 + except: + accuracy = 0.0 + + elif task_type == 'deduction': + # Deduction: LLM이 생성한 output을 expected output과 비교 + expected_output = expected + agent_output = extracted_answer + + try: + accuracy = 1.0 if eval(expected_output) == eval(agent_output) else 0.0 + except: + accuracy = 0.0 + + elif task_type == 'induction': + # Induction: LLM이 생성한 program으로 input을 실행 → 결과와 expected output 비교 + input_output_pairs = evaluation_data['input_output_pairs'] + agent_code = extracted_answer + + accuracies = [] + for test_input, expected_output in input_output_pairs: + try: + accuracy = executor.eval_input_prediction(agent_code, expected_output, test_input) + accuracies.append(accuracy if accuracy is not None else 0.0) + except: + accuracies.append(0.0) + + # 평균 정확도 + accuracy = sum(accuracies) / len(accuracies) if accuracies else 0.0 + + else: + # 기본값: 문자열 매칭 + accuracy = 1.0 if expected.strip() == extracted_answer.strip() else 0.0 + + except Exception as e: + self.logger.log_error(f"Error calculating accuracy for {task_type}: {e}") + accuracy = 0.0 + + return accuracy + + def _extract_answer_by_task_type(self, llm_response: str, task_type: str) -> str: + """태스크 타입별 AZR 방식 정답 추출""" + + if task_type == 'induction': + # 태그 추출 (AZR 새 포맷) + pattern = re.compile(r"(.*?)", re.DOTALL | re.IGNORECASE) + matches = pattern.findall(llm_response) + return matches[-1].strip() if matches else llm_response.strip() + + elif task_type == 'abduction': + # 태그 추출 (AZR 새 포맷) + pattern = re.compile(r"(.*?)", re.DOTALL | re.IGNORECASE) + matches = pattern.findall(llm_response) + return matches[-1].strip() if matches else llm_response.strip() + + elif task_type == 'deduction': + # 태그 추출 (abduction과 동일한 포맷) + pattern = re.compile(r"(.*?)", re.DOTALL | re.IGNORECASE) + matches = 
pattern.findall(llm_response) + return matches[-1].strip() if matches else llm_response.strip() + + else: + # 기본값: 전체 응답 반환 + return llm_response.strip() + + def _compute_rewards_with_azr(self, task_evaluations: Dict[str, List[Dict[str, Any]]], + ipo_triples: List[Dict[str, Any]]) -> Dict[str, Any]: + """AZR Reward Manager로 보상 계산 (실제 코드 실행 기반 평가)""" + + # PythonExecutor 가져오기 + from ..utils.code_utils.python_executor import PythonExecutor + executor = PythonExecutor() + + rewards_by_type = {} + total_rewards = [] + + for task_type, evaluations in task_evaluations.items(): + self.logger.log_info(f"🎯 Computing rewards for {task_type} tasks") + + type_rewards = [] + + for evaluation in evaluations: + expected = evaluation['expected_solution'] + llm_response = evaluation['llm_response'] + evaluation_data = evaluation['evaluation_data'] + + # AZR 방식으로 정답 추출 + extracted_answer = self._extract_answer_by_task_type(llm_response, task_type) + + # 실제 코드 실행 기반 평가 (AZR 방식) + try: + if task_type == 'abduction': + # Abduction: LLM이 예측한 input으로 program을 실행한 결과와 expected output이 같은지 비교 + code = evaluation_data['function_code'] + expected_output = evaluation_data['expected_output'] + agent_input = extracted_answer + + # 함수 정의만 추출 (assert 문 등 제거) + import re + def extract_function_definition(code): + """코드에서 import문과 함수 정의를 추출""" + lines = code.split('\n') + import_lines = [] + func_lines = [] + in_function = False + base_indent = None + + for line in lines: + # import 문 수집 + if line.strip().startswith('from ') or line.strip().startswith('import '): + import_lines.append(line) + # 함수 정의 시작 + elif line.strip().startswith('def '): + in_function = True + base_indent = len(line) - len(line.lstrip()) + func_lines.append(line) + elif in_function: + # 빈 줄이거나 함수 내부인 경우 + if line.strip() == '': + func_lines.append(line) + elif line.startswith(' ' * (base_indent + 1)) or line.startswith('\t'): + # 함수 내부 (들여쓰기가 더 깊음) + func_lines.append(line) + else: + # 함수 외부 코드 (assert 문 등) - 중단 + break + + # 
import문과 함수를 합쳐서 반환 + if import_lines: + return '\n'.join(import_lines) + '\n\n' + '\n'.join(func_lines) + else: + return '\n'.join(func_lines) + + # 함수 정의만 추출 + code = extract_function_definition(code) + + # AZR 방식: 함수명을 f로 통일 (process_code_reasoning_data.py:34 참조) + # 함수명 ��출 + func_name_match = re.search(r'def\s+(\w+)\s*\(', code) + if func_name_match: + original_func_name = func_name_match.group(1) + # 함수명을 f로 변경 (AZR 방식) + code = re.sub(r'def\s+' + re.escape(original_func_name) + r'\s*\(', 'def f(', code) + + # expected_output을 실제 값으로 변환 + try: + expected_output_value = eval(expected_output) + except: + expected_output_value = expected_output + + # AZR 방식: EVAL_INPUT_PREDICTION_TEMPLATE 사용 + try: + from ..utils.code_utils.templates import EVAL_INPUT_PREDICTION_TEMPLATE + code_snippet = EVAL_INPUT_PREDICTION_TEMPLATE.format( + code=code, + gold_output=expected_output_value, + agent_input=agent_input + ) + result, status = executor.apply(code_snippet) + + if 'error' in status.lower(): + accuracy = 0.0 + else: + # 실행 결과와 expected output 비교 + try: + # AZR 방식: 결과는 Boolean 값 (gold_output == f(agent_input)) + if isinstance(result, bool): + # result가 이미 boolean인 경우 + agent_output = result + else: + # result가 문자열인 경우 eval 사용 + agent_output = eval(result) + accuracy = 1.0 if agent_output else 0.0 + except: + accuracy = 0.0 + except: + accuracy = 0.0 + + elif task_type == 'deduction': + # Deduction: LLM이 생성한 output을 expected output과 비교 + expected_output = expected + agent_output = extracted_answer + + # 간단한 eval 비교 (AZR 방식) + try: + accuracy = 1.0 if eval(expected_output) == eval(agent_output) else 0.0 + except: + accuracy = 0.0 + + elif task_type == 'induction': + # Induction: LLM이 생성한 program으로 input을 실행 → 결과와 expected output 비교 + input_output_pairs = evaluation_data['input_output_pairs'] + agent_code = extracted_answer + + # 모든 input-output 쌍에 대해 테스트 + accuracies = [] + for test_input, expected_output in input_output_pairs: + try: + accuracy = 
executor.eval_input_prediction(agent_code, expected_output, test_input) + accuracies.append(accuracy if accuracy is not None else 0.0) + except: + accuracies.append(0.0) + + # 평균 정확도 + accuracy = sum(accuracies) / len(accuracies) if accuracies else 0.0 + + else: + # 기본값: 문자열 매칭 + accuracy = 1.0 if expected.strip() == extracted_answer.strip() else 0.0 + + except Exception as e: + self.logger.log_error(f"Error in {task_type} evaluation: {e}") + accuracy = 0.0 + + # 보상 정보 저장 + reward = { + 'task_id': evaluation['task_id'], + 'task_type': task_type, + 'extracted_answer': extracted_answer, + 'expected_solution': expected, + 'basic_accuracy': accuracy, + 'final_reward': accuracy + } + + type_rewards.append(reward) + total_rewards.append(reward['final_reward']) + + rewards_by_type[task_type] = type_rewards + + # 전체 통계 + avg_reward = sum(total_rewards) / len(total_rewards) if total_rewards else 0.0 + + return { + 'rewards_by_type': rewards_by_type, + 'total_tasks': len(total_rewards), + 'average_reward': avg_reward, + 'reward_distribution': { + task_type: sum(r['final_reward'] for r in rewards) / len(rewards) if rewards else 0.0 + for task_type, rewards in rewards_by_type.items() + } + } + + def _compute_similarity(self, expected: str, actual: str) -> float: + """문자열 유사성 계산 (간단한 방식)""" + + expected_words = set(expected.lower().split()) + actual_words = set(actual.lower().split()) + + if not expected_words and not actual_words: + return 1.0 + if not expected_words or not actual_words: + return 0.0 + + intersection = expected_words & actual_words + union = expected_words | actual_words + + return len(intersection) / len(union) # Jaccard similarity + + def _evaluate_baseline_performance(self, problem: Dict[str, Any]) -> Dict[str, Any]: + """베이스라인 성능 측정 (temperature=0.05로 5번 실행)""" + + self.logger.log_info(f"📊 Evaluating baseline performance for {problem.get('task_id', 'unknown')}") + + baseline_results = { + 'success': True, + 'total_rounds': 
self.config.baseline_evaluation_rounds, + 'solutions': [], + 'evaluations': [], + 'success_count': 0, + 'average_accuracy': 0.0, + 'error': None + } + + try: + for round_id in range(self.config.baseline_evaluation_rounds): + self.logger.log_info(f" 🔄 Baseline round {round_id + 1}/{self.config.baseline_evaluation_rounds}") + + # 베이스라인 temperature로 솔루션 생성 + solution = self.solution_generator.generate(problem) + + # 구문 검증 + is_valid, syntax_error = self.solution_generator.validate_syntax(solution) + + solution_result = { + 'round_id': round_id, + 'solution': solution, + 'syntax_valid': is_valid, + 'syntax_error': syntax_error, + 'evaluation': None + } + + # 정확성 평가 + if is_valid: + evaluation = self.solution_generator.evaluate_solution(problem, solution) + solution_result['evaluation'] = evaluation + + if evaluation['correct']: + baseline_results['success_count'] += 1 + self.logger.log_info(f" ✅ Round {round_id + 1}: PASSED ({evaluation['passed_tests']}/{evaluation['total_tests']} tests)") + + # 베이스라인 성공 케이스 로그 + self.logger.log_problem_attempt(problem, solution, True, evaluation) + else: + self.logger.log_info(f" ❌ Round {round_id + 1}: FAILED ({evaluation['passed_tests']}/{evaluation['total_tests']} tests)") + + # 베이스라인 실패 케이스 로그 + self.logger.log_problem_attempt(problem, solution, False, evaluation) + else: + self.logger.log_warning(f" ❌ Round {round_id + 1}: Syntax error - {syntax_error}") + + # 구문 오류 케이스 로그 + syntax_validation = {'syntax_valid': False, 'syntax_error': syntax_error} + self.logger.log_problem_attempt(problem, solution, False, syntax_validation) + + baseline_results['solutions'].append(solution_result) + + # Batch evaluation 형식으로 상세 로그 저장 (모든 라운드) + self._save_batch_evaluation_format(problem, solution_result, attempt_num=round_id + 1) + + # 평균 정확도 계산 + if baseline_results['success_count'] > 0: + baseline_results['average_accuracy'] = baseline_results['success_count'] / baseline_results['total_rounds'] + + self.logger.log_info(f" 📈 Baseline 
performance: {baseline_results['success_count']}/{baseline_results['total_rounds']} success ({baseline_results['average_accuracy']:.3f})")

        except Exception as e:
            self.logger.log_error(f"❌ Baseline evaluation failed: {e}")
            baseline_results['success'] = False
            baseline_results['error'] = str(e)

        return baseline_results

    def _save_batch_evaluation_format(self, problem: Dict[str, Any], solution_result: Dict[str, Any], attempt_num: int):
        """Persist a detailed per-attempt log in the same layout batch evaluation uses.

        Writes current_evaluation/attempt_<n>.txt with four sections: the original
        problem, the exact prompt fed to the LLM, the raw LLM response, and the
        correctness verdict.
        """

        from ..testtime.prompts import get_prompt

        # Create the current_evaluation directory under the logger's log dir.
        current_dir = os.path.join(self.logger.log_dir, "current_evaluation")
        os.makedirs(current_dir, exist_ok=True)

        # One log file per attempt.
        attempt_file = os.path.join(current_dir, f"attempt_{attempt_num}.txt")
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

        with open(attempt_file, 'w', encoding='utf-8') as f:
            # Header block.
            f.write(f"Current Evaluation - Attempt {attempt_num}\n")
            f.write(f"Problem ID: {problem.get('task_id', 'unknown')}\n")
            f.write(f"Benchmark: {problem.get('benchmark_name', 'unknown')}\n")
            f.write(f"Generated: {timestamp}\n")
            f.write("="*80 + "\n\n")

            # 1. Original problem statement.
            f.write("1. ORIGINAL PROBLEM:\n")
            f.write("="*80 + "\n")
            f.write(problem.get('prompt', 'No prompt available'))
            f.write("\n" + "="*80 + "\n\n")

            # 2. The exact script (prompt) sent to the LLM.
            f.write("2. LLM INPUT SCRIPT (PROMPT):\n")
            f.write("="*80 + "\n")
            problem_prompt = problem.get('prompt', '')

            # Use the central prompt registry; template choice keys off the task id.
            try:
                if 'HumanEval' in problem.get('task_id', ''):
                    full_prompt = get_prompt("solution_humaneval_basic",
                                             problem_prompt=problem_prompt)
                else:
                    full_prompt = get_prompt("solution_mbpp_basic",
                                             problem_prompt=problem_prompt)
                f.write(full_prompt.strip())
            except Exception as e:
                # Fall back to a plain default prompt when template rendering fails.
                # NOTE(review): the caught exception `e` is discarded — consider logging it.
                f.write(f"You are a Python writing assistant. Complete the following Python function.\n\n{problem_prompt}\n\nPlease provide a complete implementation of the function.")

            f.write("\n" + "="*80 + "\n\n")

            # 3. Raw LLM response.
            f.write("3. LLM RESPONSE:\n")
            f.write("="*80 + "\n")
            f.write(solution_result.get('solution', 'No solution generated'))
            f.write("\n" + "="*80 + "\n\n")

            # 4. Correctness verdict.
            f.write("4. CORRECTNESS EVALUATION:\n")
            f.write("="*80 + "\n")

            # Syntax validation result.
            f.write(f"Syntax Valid: {'✅ YES' if solution_result.get('syntax_valid', False) else '❌ NO'}\n")
            if solution_result.get('syntax_error'):
                f.write(f"Syntax Error: {solution_result['syntax_error']}\n")

            # Test-based accuracy result; absent when syntax/evaluation failed.
            evaluation = solution_result.get('evaluation')
            if evaluation:
                if evaluation.get('correct', False):
                    f.write(f"Result: ✅ CORRECT ({evaluation.get('passed_tests', 0)}/{evaluation.get('total_tests', 0)} tests passed)\n")
                else:
                    f.write(f"Result: ❌ INCORRECT ({evaluation.get('passed_tests', 0)}/{evaluation.get('total_tests', 0)} tests passed)\n")

                if evaluation.get('error'):
                    f.write(f"Evaluation Error: {evaluation['error']}\n")
            else:
                f.write("Result: ❌ NO EVALUATION (syntax error or evaluation failed)\n")

            f.write("="*80 + "\n")

        self.logger.log_info(f"📝 Batch evaluation format saved: {attempt_file}")

    def _save_llm_responses(self, task_type: str, evaluations: List[Dict[str, Any]]):
        """Save each task's LLM response under llm_responses/ in the batch-evaluation file format."""

        try:
            # Create the llm_responses directory.
            llm_dir = os.path.join(self.logger.log_dir, "llm_responses")
            os.makedirs(llm_dir, exist_ok=True)

            # One file per task, mirroring batch evaluation output.
            for i, evaluation in enumerate(evaluations, 1):
                # Heuristic: problem id is the task-id prefix before the first '_'.
                problem_id = evaluation['task_id'].split('_')[0] if '_' in evaluation['task_id'] else evaluation['task_id']
                response_file = os.path.join(llm_dir, f"{problem_id}_{task_type}_{i}_response.txt")

                with open(response_file, 'w', encoding='utf-8') as f:
                    f.write(f"Task Type: {task_type}\n")
                    f.write(f"Task ID: {evaluation['task_id']}\n")
                    f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n")
                    f.write("="*80 + "\nORIGINAL PROMPT:\n")
                    f.write("="*80 + "\n")
                    f.write(evaluation['prompt'])
                    f.write("\n" + "="*80 + "\n")
                    f.write("LLM RESPONSE:\n")
                    f.write("="*80 + "\n")
                    f.write(evaluation['llm_response'])
                    f.write("\n" + "="*80 + "\n")
                    f.write("EXPECTED SOLUTION:\n")
                    f.write("="*80 + "\n")
                    f.write(str(evaluation['expected_solution']))
                    f.write("\n" + "="*80 + "\n")
                    f.write("EXTRACTED ANSWER:\n")
                    f.write("="*80 + "\n")
                    # Extract the final answer the AZR way for this task type.
                    extracted_answer = self._extract_answer_by_task_type(evaluation['llm_response'], task_type)
                    f.write(extracted_answer)
                    f.write("\n" + "="*80 + "\n")
                    f.write("MATCH RESULT:\n")
                    f.write("="*80 + "\n")
                    # basic_accuracy > 0.5 is the pass threshold used throughout this file.
                    accuracy = evaluation.get('basic_accuracy', 0.0)
                    if accuracy > 0.5:
                        f.write(f"✅ CORRECT (Score: {accuracy:.3f})")
                    else:
                        f.write(f"❌ INCORRECT (Score: {accuracy:.3f})")

            self.logger.log_info(f"📝 LLM responses saved to {llm_dir} (batch evaluation format)")

        except Exception as e:
            self.logger.log_warning(f"Failed to save LLM responses: {e}")

    def _save_task_summary_json(self, problem: Dict[str, Any], baseline_results: Dict[str, Any],
                                task_evaluations: Dict[str, List[Dict[str, Any]]], round_num: int = None):
        """Generate summary.json files in the same shape batch evaluation produces."""

        try:
            problem_id = problem.get('task_id', 'unknown')
            problem_id_safe = problem_id.replace('/', '_')

            # Emit both the per-round summary and the problem-level summary.
            if round_num is not None:
                # Per-round summary file (inside the current round directory).
                round_summary_file = os.path.join(self.logger.log_dir, f"{problem_id_safe}_round_{round_num}_summary.json")
                self._save_single_summary(problem, baseline_results, task_evaluations, round_summary_file, round_num)

                # Problem-level summary file (one directory up).
                # NOTE(review): `.parent` assumes self.logger.log_dir is a pathlib.Path,
                # while os.path.join is used elsewhere — confirm log_dir's type.
                summary_file = os.path.join(self.logger.log_dir.parent, f"{problem_id_safe}_summary.json")
                self._save_single_summary(problem, baseline_results, task_evaluations, summary_file)
        except Exception as e:
            self.logger.log_warning(f"Failed to save task summary: {e}")

    def _save_single_summary(self, problem: Dict[str, Any], baseline_results: Dict[str, Any],
                             task_evaluations: Dict[str, List[Dict[str, Any]]], summary_file: str, round_num: int = None):
        """Write a single summary JSON file (overall or per-round)."""

        with open(summary_file, 'w', encoding='utf-8') as f:
            problem_id = problem.get('task_id', 'unknown')

            summary = {
                'problem_id': problem_id,
                'benchmark': problem.get('benchmark_name', 'unknown'),
                'success': True,
                'timestamp': datetime.now().strftime('%Y%m%d_%H%M%S'),
                'initial_solution_correct': False,
                'ipo_extraction_success': True,  # assumed: IPO extraction always succeeds
                'reasoning_task_results': {}
            }

            # Add the round number only for per-round summaries.
            if round_num is not None:
                summary['round'] = round_num

            # Initial-solution verdict comes from the baseline evaluation counts.
            if baseline_results.get('success_count', 0) > 0:
                summary['initial_solution_correct'] = True

            # Reasoning-task results, same shape as batch evaluation.
            for task_type, evaluations in task_evaluations.items():
                if evaluations:
                    # basic_accuracy > 0.5 counts as correct (shared threshold).
                    correct_count = sum(1 for eval_data in evaluations if eval_data.get('basic_accuracy', 0) > 0.5)
                    total_count = len(evaluations)

                    summary['reasoning_task_results'][task_type] = {
                        'correct': correct_count,
                        'total': total_count,
                        'accuracy': correct_count / total_count if total_count > 0 else 0
                    }

            json.dump(summary, f, indent=2, ensure_ascii=False)

        self.logger.log_info(f"📋 Summary saved: {summary_file}")

    def _generate_diverse_programs_and_ipo(self, problem: Dict[str, Any]) -> Dict[str, Any]:
        """Generate diverse programs and extract IPO triples (delegates to VLLM batch path)."""

        # Always use the VLLM batch implementation; it falls back to sequential on failure.
        return self._generate_programs_batch_vllm(problem)

    def _generate_programs_batch_vllm(self, problem: Dict[str, Any]) -> Dict[str, Any]:
        """Generate program variations with VLLM batch inference, extract IPO triples,
        run batched input generation for the successful programs, and fill the IPO buffer.

        Falls back to _generate_programs_sequential when the batch path raises.
        """

        problem_id = problem.get('task_id', 'unknown')
        batch_size = getattr(self.config, 'parallel_batch_size', 4)

        self.logger.log_info(f"🎨 Generating {self.config.num_program_variations} diverse programs for {problem_id} (BATCH)")
        self.logger.log_info(f"📊 Using batch size: {batch_size} (concurrent prompts)")

        # Aggregate result record for the whole batch run.
        diverse_results = {
            'success': False,
            'total_programs': self.config.num_program_variations,
            'valid_programs': 0,
            'programs': [],
            'total_ipo_triples': 0,
            'error': None,
            'batch_processing': True,
            'batch_size': batch_size
        }

        try:
            # Generate programs batch by batch.
            all_programs = []

            for batch_idx, batch_start in enumerate(range(0, self.config.num_program_variations, batch_size)):
                batch_end = min(batch_start + batch_size, self.config.num_program_variations)
                batch_ids = list(range(batch_start, batch_end))

                self.logger.log_info(f" 🎯 Processing batch {batch_idx + 1}: programs {batch_start}-{batch_end-1}")

                # Build the prompts for this batch (one per variation id).
                batch_prompts = []
                for variation_id in batch_ids:
                    prompt = self._create_diverse_generation_prompt(problem, variation_id)
                    batch_prompts.append(prompt)

                # VLLM batch inference.
                batch_solutions = self.solution_generator.generate_batch(
                    batch_prompts,
                    temperature=self.config.diverse_generation_temperature
                )

                self.logger.log_info(f" 📊 Generated {len(batch_solutions)} solutions")

                # Process batch results — syntax check and base IPO extraction first.
                batch_program_results = []
                for i, (variation_id, solution) in enumerate(zip(batch_ids, batch_solutions)):
                    program_result = self._process_single_program_basic(problem, solution, variation_id)
                    batch_program_results.append(program_result)

                # Run input generation as a batch over the successful programs only.
                successful_programs = [p for p in batch_program_results if p.get('success', False)]

                self.logger.log_info(f" 📊 Batch results: {len(batch_program_results)} programs, {len(successful_programs)} successful")
                for i, prog in enumerate(batch_program_results):
                    self.logger.log_info(f" Program {i}: success={prog.get('success')}, IPO triples={prog.get('num_ipo_triples', 0)}")

                if successful_programs:
                    self.logger.log_info(f" 🎲 Generating inputs for {len(successful_programs)} valid programs (BATCH)")

                    # Build one (program, round) request per input-generation round.
                    input_generation_pairs = []
                    for program_result in successful_programs:
                        for round_num in range(getattr(self.config, 'input_generation_rounds', 3)):
                            # Collect every example accumulated so far for this program.
                            existing_examples = [(triple['full_input_str'], triple['actual_output'])
                                                 for triple in program_result['ipo_triples']]

                            # Also include previously generated inputs as examples.
                            for prev_input in program_result.get('all_generated_inputs', []):
                                if 'input_args' in prev_input and 'expected_output' in prev_input:
                                    existing_examples.append((
                                        str(prev_input['input_args']),
                                        str(prev_input['expected_output'])
                                    ))

                            input_generation_pairs.append({
                                'problem': problem,
                                'solution': program_result['extracted_function_code'],
                                'existing_examples': existing_examples,
                                'program_result': program_result,
                                'round_num': round_num
                            })

                    # Execute input generation as one batch call.
                    if input_generation_pairs:
                        self.logger.log_info(f" 📊 Total input generation pairs: {len(input_generation_pairs)}")
                        batch_input_results, batch_generation_info = self.ipo_extractor.generate_diverse_inputs_batch(input_generation_pairs)
                        self.logger.log_info(f" 📊 Batch input results: {len(batch_input_results)} responses")

                        # Regroup the flat batch results back onto their source programs.
                        # (Order matches the pair construction above: program-major, round-minor.)
                        pair_idx = 0
                        for program_result in successful_programs:
                            program_result['all_generated_inputs'] = []
                            program_result['input_generation_info'] = []  # generation info for each round
                            input_generation_rounds = getattr(self.config, 'input_generation_rounds', 3)

                            for round_num in range(input_generation_rounds):
                                if pair_idx < len(batch_input_results):
                                    round_inputs = batch_input_results[pair_idx]
                                    program_result['all_generated_inputs'].extend(round_inputs)

                                    # Store generation info for this round when present.
                                    if pair_idx < len(batch_generation_info) and batch_generation_info[pair_idx]:
                                        program_result['input_generation_info'].append(batch_generation_info[pair_idx])

                                    # Create an IPO triple from each newly generated input.
                                    for new_input in round_inputs:
                                        new_triple = self.ipo_extractor.create_ipo_from_input(
                                            problem, program_result['extracted_function_code'], new_input
                                        )
                                        if new_triple:
                                            program_result['ipo_triples'].append(new_triple)

                                pair_idx += 1

                            # Final per-program statistics.
                            program_result['num_generated_inputs'] = len(program_result['all_generated_inputs'])
                            program_result['num_ipo_triples'] = len(program_result['ipo_triples'])
                            program_result['input_generation_rounds'] = input_generation_rounds

                # Merge this batch into the aggregate result.
                for prog_idx, program_result in enumerate(batch_program_results):
                    all_programs.append(program_result)

                    if program_result.get('success', False):
                        diverse_results['valid_programs'] += 1
                        diverse_results['total_ipo_triples'] += program_result.get('num_ipo_triples', 0)

                        # Push IPO triples into the buffer (same as sequential mode).
                        program_id = f'program_{batch_idx * batch_size + prog_idx}'
                        for ipo_idx, triple in enumerate(program_result.get('ipo_triples', [])):
                            # Tag each triple with its provenance for later batch alignment.
                            triple['source_program_id'] = program_id
                            triple['ipo_index'] = ipo_idx
                            self.ipo_buffer.add(problem_id, triple)

                        if program_result.get('ipo_triples'):
                            self.logger.log_info(f" 📥 Added {len(program_result['ipo_triples'])} IPO triples to buffer from {program_id}")

            diverse_results['programs'] = all_programs
            diverse_results['success'] = diverse_results['valid_programs'] > 0

            self.logger.log_info(f"✅ Batch processing completed:")
            self.logger.log_info(f" - Valid programs: {diverse_results['valid_programs']}/{diverse_results['total_programs']}")
            self.logger.log_info(f" - Total IPO triples: {diverse_results['total_ipo_triples']}")

            # Persist the per-program directory structure (shared with sequential mode).
            self._save_diverse_programs_evaluation(problem, diverse_results)

        except Exception as e:
            self.logger.log_error(f"Batch processing failed: {e}")
            diverse_results['error'] = str(e)
            # On failure, fall back to the sequential implementation.
            self.logger.log_info("🔄 Falling back to sequential processing")
            return self._generate_programs_sequential(problem)

        return diverse_results

    def _create_diverse_generation_prompt(self, problem: Dict[str, Any], variation_id: int) -> str:
        """Build the prompt for one diverse-program variation.

        Rotates through four diversity instructions keyed by variation_id so
        consecutive variations are nudged toward different approaches.
        """

        # Reuse the problem text; fall back to 'description' when 'prompt' is absent.
        problem_description = problem.get('prompt', '')
        if not problem_description:
            problem_description = problem.get('description', '')

        # Rotating diversity instructions.
        diversity_prompts = [
            "Please generate a complete, self-contained Python script that solves the following problem.",
            "Write a Python solution for this problem using a different approach.",
            "Create an alternative Python implementation for the given problem.",
            "Solve this problem with a unique Python solution approach."
        ]

        base_prompt = diversity_prompts[variation_id % len(diversity_prompts)]

        prompt = f"""{base_prompt}

Problem statement:
\"\"\"
{problem_description}
\"\"\"

Please provide a complete solution with proper function implementation."""

        return prompt

    def _process_single_program(self, problem: Dict[str, Any], solution: str, variation_id: int) -> Dict[str, Any]:
        """Process a single generated program: syntax validation, IPO extraction,
        and multi-round input augmentation (sequential counterpart of the batch path)."""

        # Syntax validation.
        is_valid, syntax_error = self.solution_generator.validate_syntax(solution)

        program_result = {
            'variation_id': variation_id,
            'solution': solution,
            'syntax_valid': is_valid,
            'syntax_error': syntax_error,
            'ipo_triples': [],
            'num_ipo_triples': 0,
            'generated_inputs': [],
            'num_generated_inputs': 0,
            'input_generation_rounds': 0,
            'success': False
        }

        if is_valid:
            try:
                # IPO extraction from the extracted function body.
                extracted_function_code = self.solution_generator._extract_function_code(solution)
                ipo_triples = self.ipo_extractor.extract_triples(problem, extracted_function_code)

                if ipo_triples:
                    program_result['ipo_triples'] = ipo_triples
                    program_result['num_ipo_triples'] = len(ipo_triples)

                    # Multi-round input augmentation.
                    all_generated_inputs = []
                    input_generation_rounds = getattr(self.config, 'input_generation_rounds', 3)

                    for round_num in range(input_generation_rounds):
                        # Collect every example accumulated so far.
                        existing_examples = [(triple['full_input_str'], triple['actual_output'])
                                             for triple in ipo_triples]

                        # Include previously generated inputs among the examples.
                        for prev_input in all_generated_inputs:
                            if 'input_args' in prev_input and 'expected_output' in prev_input:
                                existing_examples.append((
                                    str(prev_input['input_args']),
                                    str(prev_input['expected_output'])
                                ))

                        # Generate new diverse inputs.
                        diverse_inputs = self.ipo_extractor.generate_diverse_inputs(
                            problem, extracted_function_code, existing_examples
                        )

                        if diverse_inputs:
                            # Create additional IPO triples from the new inputs.
                            for new_input in diverse_inputs:
                                new_triple = self.ipo_extractor.create_ipo_from_input(
                                    problem, extracted_function_code, new_input
                                )
                                if new_triple:
                                    ipo_triples.append(new_triple)

                            all_generated_inputs.extend(diverse_inputs)

                    program_result['generated_inputs'] = all_generated_inputs
                    program_result['num_generated_inputs'] = len(all_generated_inputs)
                    program_result['ipo_triples'] = ipo_triples  # updated triple list
                    program_result['num_ipo_triples'] = len(ipo_triples)
                    program_result['input_generation_rounds'] = input_generation_rounds
                    program_result['success'] = True

            except Exception as e:
                program_result['error'] = str(e)
                self.logger.log_error(f"IPO extraction failed for variation {variation_id}: {e}")

        return program_result

    def _process_single_program_basic(self, problem: Dict[str, Any], solution: str, variation_id: int) -> Dict[str, Any]:
        """Basic processing of a single program: syntax validation and IPO extraction only.

        Input generation is deliberately excluded here — the batch path runs it
        later across all successful programs at once.
        """

        # Syntax validation.
        is_valid, syntax_error = self.solution_generator.validate_syntax(solution)

        program_result = {
            'variation_id': variation_id,
            'solution': solution,
            'syntax_valid': is_valid,
            'syntax_error': syntax_error,
            'ipo_triples': [],
            'num_ipo_triples': 0,
            'all_generated_inputs': [],
            'num_generated_inputs': 0,
            'input_generation_rounds': 0,
            'extracted_function_code': None,
            'success': False
        }

        if is_valid:
            try:
                # IPO extraction only; the extracted code is kept for the later batch stage.
                extracted_function_code = self.solution_generator._extract_function_code(solution)
                program_result['extracted_function_code'] = extracted_function_code
                ipo_triples = self.ipo_extractor.extract_triples(problem, extracted_function_code)

                if ipo_triples:
                    program_result['ipo_triples'] = ipo_triples
                    program_result['num_ipo_triples'] = len(ipo_triples)
                    program_result['success'] = True

            except Exception as e:
                program_result['error'] = str(e)
                self.logger.log_error(f"IPO extraction failed for variation {variation_id}: {e}")

        return program_result


    def _generate_programs_sequential(self, problem: Dict[str, Any]) -> Dict[str, Any]:
        """Generate diverse programs one at a time and extract IPO triples from each
        (fallback path when VLLM batch processing fails)."""

        self.logger.log_info(f"🎨 Generating {self.config.num_program_variations} diverse programs for {problem.get('task_id', 'unknown')}")

        diverse_results = {
            'success': False,
            'total_programs': self.config.num_program_variations,
            'valid_programs': 0,
            'programs': [],
            'total_ipo_triples': 0,
            'error': None
        }

        try:
            for variation_id in range(self.config.num_program_variations):
                self.logger.log_info(f" 🎯 Generating program variation {variation_id + 1}/{self.config.num_program_variations}")

                # Generate a diverse solution (sampled with the diverse-generation temperature).
                diverse_solution = self.solution_generator.generate_diverse(
                    problem,
                    temperature=self.config.diverse_generation_temperature,
                    variation_id=variation_id
                )

                # Syntax validation.
                is_valid, syntax_error = self.solution_generator.validate_syntax(diverse_solution)

                program_result = {
                    'variation_id': variation_id,
                    'solution': diverse_solution,
                    'syntax_valid': is_valid,
                    'syntax_error': syntax_error,
                    'ipo_triples': [],
                    'num_ipo_triples': 0,
                    'generated_inputs': [],
                    'num_generated_inputs': 0
                }

                if is_valid:
                    diverse_results['valid_programs'] += 1

                    # IPO extraction.
                    extracted_function_code = self.solution_generator._extract_function_code(diverse_solution)
                    ipo_triples = self.ipo_extractor.extract_triples(problem, extracted_function_code)

                    if ipo_triples:
                        program_result['ipo_triples'] = ipo_triples
                        program_result['num_ipo_triples'] = len(ipo_triples)

                        # Multi-round input augmentation (default: 3 rounds).
                        all_generated_inputs = []
                        input_generation_rounds = getattr(self.config, 'input_generation_rounds', 3)

                        for round_num in range(input_generation_rounds):
                            self.logger.log_info(f" 🎯 Input generation round {round_num + 1}/{input_generation_rounds}")

                            # Collect every example accumulated so far.
                            existing_examples = [(triple['full_input_str'], triple['actual_output'])
                                                 for triple in ipo_triples]

                            # Include previously generated inputs among the examples.
                            for prev_input in all_generated_inputs:
                                if 'input_args' in prev_input and 'expected_output' in prev_input:
                                    existing_examples.append((
                                        str(prev_input['input_args']),
                                        str(prev_input['expected_output'])
                                    ))

                            # Generate new diverse inputs.
                            diverse_inputs = self.ipo_extractor.generate_diverse_inputs(
                                problem, extracted_function_code, existing_examples
                            )

                            if not diverse_inputs:
                                # NOTE(review): the marker below is mojibake in the original
                                # (likely a corrupted ⚠️); kept byte-identical since it is a runtime string.
                                self.logger.log_warning(f" ��️ Round {round_num + 1}: No valid inputs generated")
                                continue

                            self.logger.log_info(f" ✅ Round {round_num + 1}: Generated {len(diverse_inputs)} new inputs")

                            # Persist input-generation info (first round only).
                            if round_num == 0 and hasattr(self.ipo_extractor, 'last_input_generation_info'):
                                self._save_input_generation_details(
                                    problem,
                                    variation_id + 1,
                                    self.ipo_extractor.last_input_generation_info
                                )

                            # Create additional IPO triples from the new inputs.
                            round_ipo_count = 0
                            for new_input in diverse_inputs:
                                new_triple = self.ipo_extractor.create_ipo_from_input(
                                    problem, extracted_function_code, new_input
                                )
                                if new_triple:
                                    ipo_triples.append(new_triple)
                                    round_ipo_count += 1

                            self.logger.log_info(f" 📊 Round {round_num + 1}: Created {round_ipo_count} IPO triples")
                            all_generated_inputs.extend(diverse_inputs)

                        program_result['generated_inputs'] = all_generated_inputs
                        program_result['input_generation_rounds'] = input_generation_rounds
                        # NOTE(review): counts only the LAST round's inputs; the batch path
                        # uses len(all_generated_inputs) — likely an inconsistency bug.
                        program_result['num_generated_inputs'] = len(diverse_inputs)
                        program_result['num_ipo_triples'] = len(ipo_triples)

                        # Attach input-generation info when the extractor recorded it.
                        if hasattr(self.ipo_extractor, 'last_input_generation_info'):
                            program_result['input_generation_info'] = self.ipo_extractor.last_input_generation_info

                        # Store triples in the buffer with provenance tags.
                        problem_id = problem.get('task_id', 'unknown')
                        program_id = f'program_{variation_id}'
                        for ipo_idx, triple in enumerate(ipo_triples):
                            # Tag each IPO triple with its source mapping.
                            triple['source_program_id'] = program_id
                            triple['ipo_index'] = ipo_idx
                            self.ipo_buffer.add(problem_id, triple)

                        diverse_results['total_ipo_triples'] += len(ipo_triples)

                        self.logger.log_info(f" ✅ Program {variation_id + 1}: {len(ipo_triples)} IPO triples generated")
                    else:
                        self.logger.log_warning(f" ⚠️ Program {variation_id + 1}: No IPO triples extracted")
                else:
                    self.logger.log_warning(f" ❌ Program {variation_id + 1}: Syntax error - {syntax_error}")

                diverse_results['programs'].append(program_result)

            # Success requires at least one valid program.
            diverse_results['success'] = diverse_results['valid_programs'] > 0

            self.logger.log_info(f" 📊 Diverse programs: {diverse_results['valid_programs']}/{diverse_results['total_programs']} valid, {diverse_results['total_ipo_triples']} total IPO triples")

        except Exception as e:
            self.logger.log_error(f"❌ Diverse program generation failed: {e}")
            diverse_results['error'] = str(e)

        return diverse_results

    def _save_input_generation_details(self, problem: Dict[str, Any], program_id: int,
                                       input_gen_info: Dict[str, Any]) -> None:
        """Record input-generation details for one program.

        Currently only caches the info in memory (self._input_generation_infos);
        the on-disk write below the early return is dead code.
        """
        try:
            problem_id = problem.get('task_id', 'unknown')

            # The info is kept in memory only; batch_evaluate_testtime.py persists it later.
            if not hasattr(self, '_input_generation_infos'):
                self._input_generation_infos = {}

            key = f"{problem_id}_program_{program_id}"
            self._input_generation_infos[key] = input_gen_info

            self.logger.log_info(f"Input generation info collected for {problem_id} program {program_id}")
            return

            # NOTE(review): everything below is UNREACHABLE (early return above) and
            # references `input_gen_dir` and `self.timestamp`, which are undefined here —
            # delete or restore the directory setup before re-enabling.
            # Per-program detail file.
            detail_file = input_gen_dir / f"program_{program_id}_details.txt"

            with open(detail_file, 'w', encoding='utf-8') as f:
                f.write(f"Input Generation Details\n")
                f.write(f"Problem ID: {problem_id}\n")
                f.write(f"Program ID: {program_id}\n")
                f.write(f"Generated: {self.timestamp}\n")
                f.write("=" * 80 + "\n\n")

                f.write("1. FUNCTION INFO:\n")
                f.write("=" * 80 + "\n")
                func_info = input_gen_info.get('function_info', {})
                f.write(f"Function Name: {func_info.get('name', 'N/A')}\n")
                f.write(f"Parameters: {func_info.get('args', 'N/A')}\n")
                f.write(f"Return Type: {func_info.get('return_type', 'N/A')}\n\n")

                f.write("2. ARGUMENT TYPE INFO:\n")
                f.write("=" * 80 + "\n")
                f.write(input_gen_info.get('arg_type_info', 'N/A') + "\n\n")

                f.write("3. EXISTING EXAMPLES:\n")
                f.write("=" * 80 + "\n")
                for i, (inp, out) in enumerate(input_gen_info.get('existing_examples', [])):
                    f.write(f"Example {i+1}: Input: {inp} → Output: {out}\n")
                f.write("\n")

                f.write("4. LLM PROMPT:\n")
                f.write("=" * 80 + "\n")
                f.write(input_gen_info.get('prompt', 'N/A') + "\n\n")

                f.write("5. LLM RESPONSE:\n")
                f.write("=" * 80 + "\n")
                f.write(input_gen_info.get('llm_response', 'N/A') + "\n\n")

                f.write("6. EXTRACTED INPUTS:\n")
                f.write("=" * 80 + "\n")
                extracted = input_gen_info.get('extracted_inputs', [])
                if extracted:
                    for i, inp_data in enumerate(extracted):
                        f.write(f"Input {i+1}: {inp_data}\n")
                else:
                    f.write("No inputs extracted\n")

            # Update the overall summary file.
            summary_file = input_gen_dir / "input_generation_summary.txt"
            mode = 'a' if summary_file.exists() else 'w'

            with open(summary_file, mode, encoding='utf-8') as f:
                if mode == 'w':
                    f.write(f"Input Generation Summary\n")
                    f.write(f"Problem ID: {problem_id}\n")
                    f.write(f"Generated: {self.timestamp}\n")
                    f.write("=" * 80 + "\n\n")

                f.write(f"Program {program_id}: {len(extracted)} inputs generated\n")

        except Exception as e:
            self.logger.log_warning(f"Failed to save input generation details: {e}")

    def _save_diverse_programs_evaluation(self, problem: Dict[str, Any],
                                          diverse_results: Dict[str, Any]) -> None:
        """Save the evaluation of all diverse programs in the batch-evaluation layout:
        a diverse_summary.txt plus one program_<i>/ directory with details, solution,
        IPO triples, and input-generation logs per program."""
        try:
            problem_id = problem.get('task_id', 'unknown')

            # Create the diverse_programs directory.
            diverse_dir = os.path.join(self.logger.log_dir, "diverse_programs")
            os.makedirs(diverse_dir, exist_ok=True)

            # Summary file (same as batch evaluation).
            summary_file = os.path.join(diverse_dir, "diverse_summary.txt")
            with open(summary_file, 'w', encoding='utf-8') as f:
                f.write(f"Diverse Programs Summary\n")
                f.write(f"Problem ID: {problem_id}\n")
                f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n")
                f.write(f"Total Programs: {len(diverse_results.get('programs', []))}\n")
                f.write("="*50 + "\n\n")

                for i, program in enumerate(diverse_results.get('programs', []), 1):
                    f.write(f"Program {i}: {'✅ Valid' if program.get('syntax_valid', False) else '❌ Invalid'}\n")
                    f.write(f"IPO Triples: {program.get('num_ipo_triples', 0)}\n")
                    f.write(f"Generated Inputs: {program.get('num_generated_inputs', 0)}\n\n")

            # Per-program directory and files (same structure as batch evaluation).
            for i, program in enumerate(diverse_results.get('programs', []), 1):
                program_dir = os.path.join(diverse_dir, f"program_{i}")
                os.makedirs(program_dir, exist_ok=True)

                # 1. generation_details.txt (batch-evaluation format).
                details_file = os.path.join(program_dir, "generation_details.txt")
                with open(details_file, 'w', encoding='utf-8') as f:
                    f.write(f"Diverse Program {i} - Generation Details\n")
                    f.write(f"Problem ID: {problem_id}\n")
                    f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n")
                    f.write("="*80 + "\n\n")

                    f.write("1. ORIGINAL PROBLEM:\n")
                    f.write("="*80 + "\n")
                    f.write(problem.get('prompt', 'No prompt available') + "\n")
                    f.write("="*80 + "\n\n")

                    f.write("2. DIVERSITY PROMPT USED:\n")
                    f.write("="*80 + "\n")
                    f.write(program.get('diversity_instruction', 'Standard generation') + "\n")
                    f.write("="*80 + "\n\n")

                    f.write("3. LLM RESPONSE:\n")
                    f.write("="*80 + "\n")
                    f.write(program.get('solution', 'N/A') + "\n")
                    f.write("="*80 + "\n\n")

                    f.write("4. EVALUATION RESULTS:\n")
                    f.write("="*80 + "\n")
                    f.write(f"Syntax Valid: {'✅ YES' if program.get('syntax_valid', False) else '❌ NO'}\n")
                    f.write(f"IPO Triples Generated: {program.get('num_ipo_triples', 0)}\n")
                    f.write(f"Input Generation: {program.get('num_generated_inputs', 0)} new inputs\n")
                    f.write("="*80 + "\n")

                # 2. solution.py
                solution_file = os.path.join(program_dir, "solution.py")
                with open(solution_file, 'w', encoding='utf-8') as f:
                    f.write(f"# Diverse Program {i}\n")
                    f.write(f"# Problem ID: {problem_id}\n")
                    f.write(f"# Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n")
                    f.write(f"# Syntax Valid: {program.get('syntax_valid', False)}\n")
                    f.write(f"# IPO Triples: {program.get('num_ipo_triples', 0)}\n\n")

                    # The solution was already post-processed by generate_batch,
                    # so write it as-is (no re-extraction).
                    f.write(program.get('solution', '# No solution available'))

                # 3. ipo_triples directory and files.
                if program.get('ipo_triples'):
                    ipo_dir = os.path.join(program_dir, "ipo_triples")
                    os.makedirs(ipo_dir, exist_ok=True)

                    for j, triple in enumerate(program['ipo_triples'], 1):
                        triple_file = os.path.join(ipo_dir, f"triple_{j}.json")
                        with open(triple_file, 'w', encoding='utf-8') as f:
                            json.dump(triple, f, indent=2)

                # 4. input_generation_details.txt
                if program.get('input_generation_info'):
                    details_file = os.path.join(program_dir, "input_generation_details.txt")
                    with open(details_file, 'w', encoding='utf-8') as f:
                        f.write(f"Input Generation Details - Program {i}\n")
                        f.write(f"Problem ID: {problem_id}\n")
                        f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n")
                        f.write("="*80 + "\n\n")

                        # One section per input-generation round.
                        for round_idx, gen_info in enumerate(program['input_generation_info'], 1):
                            if 'error' in gen_info:
                                f.write(f"ROUND {round_idx} - ERROR:\n")
                                f.write("="*80 + "\n")
                                f.write(f"Error: {gen_info.get('error', 'Unknown error')}\n")
                                f.write(f"Traceback:\n{gen_info.get('traceback', 'No traceback')}\n")
                                f.write("\n")
                                continue

                            f.write(f"ROUND {round_idx}:\n")
                            f.write("="*80 + "\n\n")

                            # Function info
                            func_info = gen_info.get('function_info', {})
                            f.write("1. FUNCTION INFO:\n")
                            f.write("="*80 + "\n")
                            f.write(f"Function Name: {func_info.get('name', 'N/A')}\n")
                            f.write(f"Parameters: {func_info.get('args', 'N/A')}\n")
                            f.write(f"Parameters String: {func_info.get('signature', 'N/A')}\n\n")

                            # Argument type info
                            f.write("2. ARGUMENT TYPE INFO:\n")
                            f.write("="*80 + "\n")
                            arg_types = gen_info.get('arg_type_info', {})
                            if arg_types:
                                f.write("Argument types:\n")
                                for arg, arg_type in arg_types.items():
                                    f.write(f"- {arg}: {arg_type}\n")
                            else:
                                f.write("No argument type information available\n")
                            f.write("\n")

                            # Existing examples
                            f.write("3. EXISTING EXAMPLES:\n")
                            f.write("="*80 + "\n")
                            existing_examples = gen_info.get('existing_examples', [])
                            if existing_examples:
                                for idx, example in enumerate(existing_examples, 1):
                                    f.write(f"Example {idx}: Input: {example[0]} → Output: {example[1]}\n")
                            else:
                                f.write("No existing examples\n")
                            f.write("\n")

                            # LLM prompt
                            f.write("4. LLM PROMPT:\n")
                            f.write("="*80 + "\n")
                            f.write(gen_info.get('prompt', 'No prompt available'))
                            f.write("\n"*2 + "="*80 + "\n\n")

                            # LLM response
                            f.write("5. LLM RESPONSE:\n")
                            f.write("="*80 + "\n")
                            f.write(gen_info.get('llm_response', 'No response available'))
                            f.write("\n"*2 + "="*80 + "\n\n")

                            # Extracted inputs
                            f.write("6. EXTRACTED INPUTS:\n")
                            f.write("="*80 + "\n")
                            extracted = gen_info.get('extracted_inputs', [])
                            if extracted:
                                for idx, inp in enumerate(extracted, 1):
                                    f.write(f"Input {idx}: {inp}\n")
                            else:
                                f.write("No inputs extracted\n")

                            # Valid inputs (only if different from extracted)
                            valid = gen_info.get('valid_inputs', [])
                            if valid != extracted:
                                f.write("\n7. VALID INPUTS (after validation):\n")
                                f.write("="*80 + "\n")
                                if valid:
                                    for idx, inp in enumerate(valid, 1):
                                        f.write(f"Input {idx}: {inp}\n")
                                else:
                                    f.write("No valid inputs after validation\n")

                            f.write("\n")

            self.logger.log_info(f"💾 Diverse programs saved to {diverse_dir} (batch evaluation format)")

        except Exception as e:
            self.logger.log_error(f"Failed to save diverse programs evaluation: {e}")

    def _save_azr_training_data(self, all_tasks: Dict[str, List[Dict[str, Any]]],
                                problem_id: str, round_num: int,
                                output_dir: str) -> Dict[str, str]:
        """Save AZR training data as one parquet file per task type, plus a stats JSON.

        Returns a mapping of task_type -> saved parquet path ({} on failure).
        """

        try:
            import pandas as pd
            # NOTE(review): `os` (and `json` below) are already imported at module
            # top per the file header — these local imports are redundant.
            import os

            # Create the AZR training-data directory.
            azr_dir = os.path.join(output_dir, 'azr_training_data')
            os.makedirs(azr_dir, exist_ok=True)

            saved_files = {}
            total_tasks = 0

            # One parquet file per task type.
            for task_type, tasks in all_tasks.items():
                if not tasks:
                    continue

                # Convert to AZR parquet records.
                azr_data = []
                for task in tasks:
                    # The prompt is an already-formatted string; convert it to the
                    # AZR chat form (list of role/content dicts) so RLHFDataset
                    # applies the chat template correctly in Phase 5.
                    prompt_dict_list = [{"role": "user", "content": task['prompt']}]
                    # NOTE(review): debug prints left in — consider routing through self.logger.
                    print(f"[DEBUG AZR DATA SAVE] Converting prompt to dict list format")
                    print(f"[DEBUG] Original prompt type: {type(task['prompt'])}, length: {len(task['prompt']) if isinstance(task['prompt'], str) else 'N/A'}")
                    print(f"[DEBUG] Converted prompt type: {type(prompt_dict_list)}, first elem: {type(prompt_dict_list[0])}")
                    azr_record = {
                        'prompt': prompt_dict_list,  # stored as a dict list (same as AZR)
                        'uid': task['uid'],
                        'ipo_group_id': task['ipo_group_id'],
                        'source_program_id': task['source_program_id'],
                        'ipo_index': task['ipo_index'],
                        'problem': {
                            'input': task['ipo_triple']['input'],
                            'output': task['ipo_triple']['output'],
                            'snippet': task['ipo_triple']['program']
                        },
                        'ground_truth': task['ground_truth'],
                        'extra_info': task['extra_info'],
                        'basic_accuracy': task['basic_accuracy'],
                        'original_problem_id': task['original_problem_id'],
                        'round': task['round']
                    }
                    azr_data.append(azr_record)

                # Sort by ipo_group_id to guarantee batch grouping.
                azr_data.sort(key=lambda x: x['ipo_group_id'])

                # Save as a parquet file.
                df = pd.DataFrame(azr_data)
                file_path = os.path.join(azr_dir, f'{task_type}.parquet')
                df.to_parquet(file_path, index=False)

                # Debug: verify the saved data.
                print(f"[DEBUG] Saved {task_type}.parquet with {len(df)} records")
                if len(df) > 0:
                    saved_prompt = df.iloc[0]['prompt']
                    print(f"[DEBUG] First saved prompt type: {type(saved_prompt)}")

                saved_files[task_type] = file_path
                total_tasks += len(tasks)

                self.logger.log_info(f"💾 Saved {len(tasks)} {task_type} tasks to {file_path}")

            # Persist run statistics alongside the parquet files.
            stats = {
                'problem_id': problem_id,
                'round': round_num,
                'total_tasks': total_tasks,
                'tasks_by_type': {k: len(v) for k, v in all_tasks.items()},
                'files': saved_files,
                'batch_groups': len(set(task['ipo_group_id'] for tasks in all_tasks.values() for task in tasks))
            }

            stats_file = os.path.join(azr_dir, 'training_stats.json')
            import json
            with open(stats_file, 'w') as f:
                json.dump(stats, f, indent=2)

            self.logger.log_info(f"✅ AZR training data saved: {total_tasks} tasks in {len(saved_files)} files")
            self.logger.log_info(f"📊 Batch groups: {stats['batch_groups']} (for batch alignment)")

            return saved_files

        except Exception as e:
            self.logger.log_error(f"Failed to save AZR training data: {e}")
            return {}