import uuid from pathlib import Path from copy import deepcopy from typing import List, Dict, Tuple, Any, Optional import random import json from collections import defaultdict import threading import gc import os import pickle import ast from datetime import datetime import ray import torch from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from omegaconf import OmegaConf import numpy as np from verl.utils.dataset.rl_dataset import collate_fn from verl.utils.debug import marked_timer from verl.trainer.ppo.ray_trainer import ( apply_kl_penalty, compute_advantage, compute_response_mask, reduce_metrics, compute_timing_metrics, agg_loss, ) from verl.trainer.ppo.metric_utils import _compute_response_info from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto, DataProto from absolute_zero_reasoner.utils.tracking import ReasonRLTracking from absolute_zero_reasoner.data_construction.constructor import get_gen_code_io_data, get_pred_code_io_data from absolute_zero_reasoner.trainer.ppo.azr_ray_trainer import CodeIORayPPOTrainer from absolute_zero_reasoner.utils.dataset.rl_dataset import RLHFDataset from absolute_zero_reasoner.rewards.code_reward import parse_code_input_output, parse_inputs_message from absolute_zero_reasoner.utils.code_utils.python_executor import PythonExecutor # Import SandboxfusionExecutor only when needed to avoid docker dependency SandboxfusionExecutor = None from absolute_zero_reasoner.utils.auxiliary import reflection_keywords from absolute_zero_reasoner.utils.logging_utils.stdout import PrettyPrinter # TTRLVR 모듈 import 추가 import sys from pathlib import Path # 상대 경로 사용 project_root = Path(__file__).parent.parent.parent # TestTime-RLVR-v2 directory sys.path.append(str(project_root)) from absolute_zero_reasoner.testtime.complete_pipeline import CompleteTestTimePipeline from absolute_zero_reasoner.testtime.config import TestTimeConfig, BenchmarkConfig from absolute_zero_reasoner.testtime.logger import TestTimeLogger 
from absolute_zero_reasoner.testtime.prompts import get_prompt, get_diversity_instruction, get_temperature
from absolute_zero_reasoner.rewards.custom_evaluate import extract_code
from absolute_zero_reasoner.testtime.ipo_extractor import IPOTripleExtractor, IPOBuffer
from absolute_zero_reasoner.testtime.solution_generator import InitialSolutionGenerator
import re

# NOTE(review): whitespace inside this literal may have been collapsed in this
# copy of the file (e.g. a newline + indent before `return`) — confirm upstream.
seed_program = """def f(a): return a"""


def create_default_dict():
    """Return a fresh ``defaultdict(int)``.

    A named module-level factory (rather than a lambda) so that nested
    defaultdicts built from it stay picklable for Ray / deepcopy.
    """
    return defaultdict(int)


def compute_data_metrics(batch, use_critic=True, tokenizer=None):
    """Compute scalar training metrics for a PPO batch.

    Args:
        batch: DataProto with 'token_level_scores', 'token_level_rewards',
            'advantages', 'returns', 'responses', 'attention_mask', 'prompts'
            (and 'values' when ``use_critic``).
        use_critic: Whether to also report critic value / explained-variance metrics.
        tokenizer: Used to decode every response; it is called unconditionally,
            so it must be provided despite the ``None`` default.

    Returns:
        Dict mapping metric names to Python numbers.
    """
    sequence_score = batch.batch['token_level_scores'].sum(-1)
    sequence_reward = batch.batch['token_level_rewards'].sum(-1)

    advantages = batch.batch['advantages']
    returns = batch.batch['returns']

    max_response_length = batch.batch['responses'].shape[-1]

    # attention_mask = [prompt part | response part]; split it at the response width.
    prompt_mask = batch.batch['attention_mask'][:, :-max_response_length].bool()
    response_mask = batch.batch['attention_mask'][:, -max_response_length:].bool()

    max_prompt_length = prompt_mask.size(-1)

    response_info = _compute_response_info(batch)
    prompt_length = response_info['prompt_length']
    response_length = response_info['response_length']

    valid_adv = torch.masked_select(advantages, response_mask)
    valid_returns = torch.masked_select(returns, response_mask)

    if use_critic:
        values = batch.batch['values']
        valid_values = torch.masked_select(values, response_mask)
        return_diff_var = torch.var(valid_returns - valid_values)
        return_var = torch.var(valid_returns)

    reflect_list = []
    correct_list = []
    correct_response_length = []
    incorrect_response_length = []
    for i in range(len(batch)):
        data_item = batch[i]  # DataProtoItem
        prompt_ids = data_item.batch['prompts']
        _prompt_length = prompt_ids.shape[-1]
        response_ids = data_item.batch['responses']
        valid_response_length = data_item.batch['attention_mask'][_prompt_length:].sum()
        valid_response_ids = response_ids[:valid_response_length]
        # decode
        responses_str = tokenizer.decode(valid_response_ids, skip_special_tokens=True)
        reflect = any(kw in responses_str.lower() for kw in reflection_keywords)
        reflect_list.append(reflect)

        reward = data_item.batch['token_level_rewards'].sum(-1)
        # Convert to a Python bool so the derived ratios below are plain floats,
        # not tensors, like every other entry in the metrics dict.
        correct = bool(reward >= 1)
        correct_list.append(correct)
        if correct:
            correct_response_length.append(valid_response_length.item())
        else:
            incorrect_response_length.append(valid_response_length.item())

    # the ratio of reflection
    reflect_ratio = sum(reflect_list) / len(reflect_list) if reflect_list else 0
    # the ratio of correct responses among reflection samples
    num_reflect = sum(reflect_list)
    num_reflect_correct = sum(r and c for r, c in zip(reflect_list, correct_list))
    correct_ratio = num_reflect_correct / num_reflect if num_reflect > 0 else 0

    # separate lengths
    length_metrics = {}
    if correct_response_length:
        length_metrics['correct_response_length/mean'] = sum(correct_response_length) / len(correct_response_length)
    if incorrect_response_length:
        length_metrics['incorrect_response_length/mean'] = sum(incorrect_response_length) / len(incorrect_response_length)

    metrics = {
        # score
        'critic/score/mean': torch.mean(sequence_score).detach().item(),
        'critic/score/max': torch.max(sequence_score).detach().item(),
        'critic/score/min': torch.min(sequence_score).detach().item(),
        # reward
        'critic/rewards/mean': torch.mean(sequence_reward).detach().item(),
        'critic/rewards/max': torch.max(sequence_reward).detach().item(),
        'critic/rewards/min': torch.min(sequence_reward).detach().item(),
        # adv
        'critic/advantages/mean': torch.mean(valid_adv).detach().item(),
        'critic/advantages/max': torch.max(valid_adv).detach().item(),
        'critic/advantages/min': torch.min(valid_adv).detach().item(),
        # returns
        'critic/returns/mean': torch.mean(valid_returns).detach().item(),
        'critic/returns/max': torch.max(valid_returns).detach().item(),
        'critic/returns/min': torch.min(valid_returns).detach().item(),
        **({
            # values
            'critic/values/mean': torch.mean(valid_values).detach().item(),
            'critic/values/max': torch.max(valid_values).detach().item(),
            'critic/values/min': torch.min(valid_values).detach().item(),
            # vf explained var
            'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
        } if use_critic else {}),
        # response length
        'response_length/mean': torch.mean(response_length).detach().item(),
        'response_length/max': torch.max(response_length).detach().item(),
        'response_length/min': torch.min(response_length).detach().item(),
        'response_length/clip_ratio': torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(),
        "response_length/reflect_ratio": reflect_ratio,
        "response_length/correct_reflect_ratio": correct_ratio,
        **length_metrics,
        # prompt length
        'prompt_length/mean': torch.mean(prompt_length).detach().item(),
        'prompt_length/max': torch.max(prompt_length).detach().item(),
        'prompt_length/min': torch.min(prompt_length).detach().item(),
        'prompt_length/clip_ratio': torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
    }
    return metrics


# Create a local function to process elements before sending to manager
def process_elements(entries):
    """Annotate entries with detected value types before sending them to the manager.

    Returns shallow copies of ``entries`` with '_input_type' / '_output_type'
    (for 'input' / 'output') and '_input_types' / '_output_types' (for
    'inputs' / 'outputs') added. On failure the type falls back to "str".
    """
    processed = []
    for entry in entries:
        entry_copy = entry.copy()
        if 'input' in entry:
            try:
                entry_copy['_input_type'] = determine_type(entry['input'])
            except Exception:
                entry_copy['_input_type'] = "str"
        if 'output' in entry:
            try:
                entry_copy['_output_type'] = determine_type(entry['output'])
            except Exception:
                entry_copy['_output_type'] = "str"
        if 'inputs' in entry:
            try:
                entry_copy['_input_types'] = [determine_type(inp) for inp in entry['inputs']]
            except Exception:
                entry_copy['_input_types'] = ["str"] * len(entry['inputs'])
        if 'outputs' in entry:
            try:
                entry_copy['_output_types'] = [determine_type(out) for out in entry['outputs']]
            except Exception:
                entry_copy['_output_types'] = ["str"] * len(entry['outputs'])
        processed.append(entry_copy)
    return processed


def determine_type(element):
    """Determine type safely without eval.

    Parses ``element`` with ``ast.literal_eval`` and returns the parsed value's
    type name (e.g. 'int', 'list'). A comma-containing string that parses as a
    tuple once wrapped in parentheses yields 'tuple'. Anything that fails to
    parse yields "str".
    """
    try:
        # Handle potential tuple strings without parentheses
        if isinstance(element, str) and ',' in element:
            # Attempt to parse as tuple by wrapping in parentheses
            try:
                wrapped = f'({element})'
                parsed_tuple = ast.literal_eval(wrapped)
                if isinstance(parsed_tuple, tuple):
                    return 'tuple'
            except Exception:
                pass  # Proceed to normal parsing

        # Try using ast.literal_eval for safety
        parsed = ast.literal_eval(element)
        if is_pickleable(parsed):
            return type(parsed).__name__
        return "str"
    except Exception:
        return "str"


def is_pickleable(obj):
    """Return True if ``obj`` can be serialized with ``pickle.dumps``."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, TypeError, AttributeError):
        return False


def save_programs_to_file(programs, problem_type, global_step, save_path, data_type="valid"):
    """
    Save programs to JSONL file with metadata

    Writes ``<save_path>/<problem_type>/<data_type>/step_<global_step>_<timestamp>.jsonl``,
    one JSON object per program, each extended with a 'meta' dict.

    Args:
        programs: List of program dictionaries
        problem_type: Type of problem (gen_code_i, gen_code_o, etc.)
        global_step: Current training step
        save_path: Base path for saving
        data_type: "valid" or "invalid"

    Returns:
        The written file path, or None if nothing was written (empty input or
        a write error, which is printed rather than raised).
    """
    if not programs or not save_path:
        return None

    # Create directory structure
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = os.path.join(save_path, problem_type, data_type)
    os.makedirs(save_dir, exist_ok=True)

    # Create filename with step and timestamp
    filename = f"step_{global_step:06d}_{timestamp}.jsonl"
    file_path = os.path.join(save_dir, filename)

    # Save programs with metadata
    try:
        with open(file_path, 'w', encoding='utf-8') as f:
            for i, program in enumerate(programs):
                program_data = {
                    **program,
                    'meta': {
                        'global_step': global_step,
                        'timestamp': timestamp,
                        'program_id': f"{global_step}_{problem_type}_{data_type}_{i}",
                        'problem_type': problem_type,
                        'data_type': data_type,
                        'save_time': datetime.now().isoformat()
                    }
                }
                f.write(json.dumps(program_data, ensure_ascii=False) + '\n')

        print(f"💾 Saved {len(programs)} {data_type} {problem_type} programs to {file_path}")
        return file_path
    except Exception as e:
        print(f"❌ Failed to save {data_type} {problem_type} programs: {e}")
        return None


@ray.remote
class DatasetManager:
    """Ray actor holding the shared code-task datasets and their bookkeeping.

    ``self.datasets`` holds:
      * data lists: 'input', 'output', 'seed', 'error', 'problem', 'error_seed';
      * step lists parallel to 'input', 'output', 'error', 'problem'
        (``<name>_steps``, one global step per data entry);
      * per-step addition counters (``<name>_steps_counter``).
    ``self.type_counters`` counts how often each value was seen per detected type.
    """

    # Key order returned by get_all_datasets().
    _ALL_DATASET_KEYS = (
        'input', 'output', 'seed', 'error', 'problem', 'error_seed',
        'input_steps', 'output_steps', 'error_steps', 'problem_steps',
        'input_steps_counter', 'output_steps_counter',
        'error_steps_counter', 'problem_steps_counter',
    )
    # Dataset name -> its parallel step list (None when the dataset has none).
    _PARALLEL_STEPS = {
        'input': 'input_steps',
        'output': 'output_steps',
        'seed': None,
        'error': 'error_steps',
        'error_seed': None,
        'problem': 'problem_steps',
    }

    def __init__(self):
        self.datasets = {
            'input': [],  # Stores only data entries
            'output': [],  # Stores only data entries
            'seed': [],
            'error': [],
            'problem': [],
            'error_seed': [],
            'input_steps': [],  # Parallel list storing step numbers
            'output_steps': [],  # Parallel list storing step numbers
            'error_steps': [],  # Parallel list storing step numbers
            'problem_steps': [],  # Parallel list storing step numbers
            'input_steps_counter': defaultdict(int),
            'output_steps_counter': defaultdict(int),
            'error_steps_counter': defaultdict(int),
            'problem_steps_counter': defaultdict(int),
        }
        self.type_counters = {
            'input_types': defaultdict(create_default_dict),
            'output_types': defaultdict(create_default_dict),
            'error_types': defaultdict(create_default_dict),
        }
        self.locks = {
            'input': threading.Lock(),
            'output': threading.Lock(),
            'seed': threading.Lock(),
            'error': threading.Lock(),
            'problem': threading.Lock(),
            'error_seed': threading.Lock(),
            'input_steps': threading.Lock(),
            'output_steps': threading.Lock(),
            'error_steps': threading.Lock(),
            'problem_steps': threading.Lock(),
            'input_steps_counter': threading.Lock(),
            'output_steps_counter': threading.Lock(),
            'error_steps_counter': threading.Lock(),
            'problem_steps_counter': threading.Lock(),
            # Re-entrant: add_* methods hold these and then call count_element(),
            # which acquires the same lock again.
            'input_types': threading.RLock(),
            'output_types': threading.RLock(),
            'error_types': threading.RLock(),
        }

    def update_seed(self, entries):
        """Append seed entries not already present (compared by JSON form); return how many were added."""
        with self.locks['seed']:
            existing = {json.dumps(d, sort_keys=True) for d in self.datasets['seed']}
            new_entries = [e for e in entries if json.dumps(e, sort_keys=True) not in existing]
            for entry in new_entries:
                if 'input' in entry and '_input_type' in entry:
                    self.count_element(entry['input'], entry['_input_type'], 'input')
                if 'output' in entry and '_output_type' in entry:
                    self.count_element(entry['output'], entry['_output_type'], 'output')
            self.datasets['seed'].extend(new_entries)
            return len(new_entries)

    def update_error_seed(self, entries):
        """Append error-seed entries not already present (compared by JSON form); return how many were added."""
        with self.locks['error_seed'], self.locks['error_types']:
            existing = {json.dumps(d, sort_keys=True) for d in self.datasets['error_seed']}
            new_entries = [e for e in entries if json.dumps(e, sort_keys=True) not in existing]
            # Process using pre-computed types
            for entry in new_entries:
                if 'output' in entry and '_output_type' in entry:
                    self.count_element(entry['output'], entry['_output_type'], 'error')
            self.datasets['error_seed'].extend(new_entries)
            return len(new_entries)

    def get_dataset(self, name) -> List[Dict]:
        """Returns only the data entries without step information"""
        return deepcopy(self.datasets[name])

    def get_all_datasets(self) -> Dict[str, List[Dict]]:
        """Return deep copies of every dataset, step list and step counter."""
        return {key: deepcopy(self.datasets[key]) for key in self._ALL_DATASET_KEYS}

    def add_input_batch(self, entries: List[Dict], global_step: int):
        """Append input entries tagged with ``global_step``; return the new dataset size."""
        with self.locks['input'], self.locks['input_steps'], self.locks['input_types']:
            for entry in entries:
                if 'input' in entry and '_input_type' in entry:
                    self.count_element(entry['input'], entry['_input_type'], 'input')
            self.datasets['input'].extend(entries)
            self.datasets['input_steps'].extend([global_step] * len(entries))
            self.datasets['input_steps_counter'][global_step] += len(entries)
            return len(self.datasets['input'])

    def add_output_batch(self, entries: List[Dict], global_step: int):
        """Append output entries tagged with ``global_step``; return the new dataset size."""
        with self.locks['output'], self.locks['output_steps'], self.locks['output_types']:
            for entry in entries:
                if 'output' in entry and '_output_type' in entry:
                    self.count_element(entry['output'], entry['_output_type'], 'output')
            self.datasets['output'].extend(entries)
            self.datasets['output_steps'].extend([global_step] * len(entries))
            self.datasets['output_steps_counter'][global_step] += len(entries)
            return len(self.datasets['output'])

    def add_error_batch(self, entries: List[Dict], global_step: int):
        """Append error entries tagged with ``global_step``; return the new dataset size."""
        with self.locks['error'], self.locks['error_steps'], self.locks['error_types']:
            for entry in entries:
                if 'output' in entry and '_output_type' in entry:
                    self.count_element(entry['output'], entry['_output_type'], 'error')
            self.datasets['error'].extend(entries)
            self.datasets['error_steps'].extend([global_step] * len(entries))
            self.datasets['error_steps_counter'][global_step] += len(entries)
            return len(self.datasets['error'])

    def add_error_seed_batch(self, entries: List[Dict], global_step: int):
        """Append error-seed entries; return the new size of 'error_seed'.

        'error_steps' is parallel to 'error' (see get_dataset_with_steps), so
        error-seed entries must not be appended to it.
        """
        with self.locks['error_seed'], self.locks['error_steps_counter']:
            for entry in entries:
                if 'output' in entry and '_output_type' in entry:
                    self.count_element(entry['output'], entry['_output_type'], 'error')
            self.datasets['error_seed'].extend(entries)
            # NOTE(review): this still counts error-seed additions toward
            # get_recent_additions('error'); confirm that is intended.
            self.datasets['error_steps_counter'][global_step] += len(entries)
            return len(self.datasets['error_seed'])

    def add_problem_batch(self, entries: List[Dict], global_step: int):
        """Append problem entries tagged with ``global_step``; return the number of entries added."""
        with self.locks['problem'], self.locks['problem_steps'], self.locks['problem_steps_counter']:
            for entry in entries:
                if 'inputs' in entry and '_input_types' in entry:
                    for inp, inp_type in zip(entry['inputs'], entry['_input_types']):
                        self.count_element(inp, inp_type, 'input')
                if 'outputs' in entry and '_output_types' in entry:
                    for out, out_type in zip(entry['outputs'], entry['_output_types']):
                        self.count_element(out, out_type, 'output')
            self.datasets['problem'].extend(entries)
            self.datasets['problem_steps'].extend([global_step] * len(entries))
            self.datasets['problem_steps_counter'][global_step] += len(entries)
            return len(entries)

    def get_snippets(self) -> List[Dict]:
        """Return snippet records from input+output datasets, or from 'seed' while both are empty."""
        fields = ('snippet', 'original_snippet', 'imports')
        if self.datasets['input'] or self.datasets['output']:
            sources = self.datasets['input'] + self.datasets['output']
        else:
            # we are in the seed stage
            sources = self.datasets['seed']
        return [{field: d[field] for field in fields} for d in sources]

    def get_snippets_with_steps(self) -> List[Tuple[Dict, int]]:
        """Pair get_snippets() with input_steps + output_steps (empty during the seed stage)."""
        snippets = self.get_snippets()
        return list(zip(snippets, self.datasets['input_steps'] + self.datasets['output_steps']))

    def get_recent_additions(self, dataset_key: str, current_step: int, window: int) -> int:
        """Return how many entries were added to ``dataset_key`` within ``window`` steps of ``current_step``."""
        counter_key = f"{dataset_key}_steps_counter"
        with self.locks[counter_key]:
            # Use the per-step counter dictionary instead of scanning the step list
            counter = self.datasets[counter_key]
            return sum(count for step, count in counter.items() if (current_step - step) <= window)

    def get_dataset_with_steps(self, name) -> List[Tuple[Dict, int]]:
        """Return (entry, step) pairs for 'input', 'output', 'error' or 'problem'."""
        steps_key = self._PARALLEL_STEPS.get(name)
        if steps_key is None:
            raise ValueError(f"Invalid dataset name: {name}")
        data = self.datasets[name]
        steps = self.datasets[steps_key]
        assert len(data) == len(steps), f"{name.capitalize()} data/steps mismatch!"
        return list(zip(deepcopy(data), steps))

    def get_steps_dataset(self, name) -> List[int]:
        """Return the step list parallel to 'input', 'output', 'error' or 'problem'."""
        steps_key = self._PARALLEL_STEPS.get(name)
        if steps_key is None:
            raise ValueError(f"Invalid dataset name: {name}")
        return self.datasets[steps_key]

    def truncate_datasets(self, max_length: int, name: str) -> Tuple[int, int]:
        """Keep only the first ``max_length`` entries of dataset ``name``.

        The dataset's parallel step list (if any) is truncated alongside it, so
        data and steps stay aligned.

        Returns:
            (number of entries removed, length before truncation).

        Raises:
            ValueError: if ``name`` is not a known dataset.
        """
        if name not in self._PARALLEL_STEPS:
            raise ValueError(f"Invalid dataset name: {name}")
        steps_key = self._PARALLEL_STEPS[name]
        with self.locks[name]:
            before_length = len(self.datasets[name])
            self.datasets[name] = self.datasets[name][:max_length]
            if steps_key is not None:
                with self.locks[steps_key]:
                    self.datasets[steps_key] = self.datasets[steps_key][:max_length]
            truncated_length = before_length - len(self.datasets[name])
        return truncated_length, before_length

    def get_dataset_size(self, name: str) -> int:
        """Return the number of entries in dataset ``name``."""
        with self.locks[name]:
            return len(self.datasets[name])

    def full_load_datasets(self, datasets):
        """Load all datasets from a dictionary"""
        self.datasets = datasets

    def full_load_data_with_type_counters(self, data: Dict):
        """Load datasets and type counters.

        ``data`` has the shape returned by get_all_data_with_type_counters():
        dataset keys plus an optional 'type_counters' entry. Missing dataset
        keys fall back to empty defaults.
        """
        merged_datasets = {
            'input': [],
            'output': [],
            'seed': [],
            'error': [],
            'problem': [],
            'error_seed': [],
            'input_steps': [],
            'output_steps': [],
            'error_steps': [],
            'problem_steps': [],
            'input_steps_counter': defaultdict(int),
            'output_steps_counter': defaultdict(int),
            'error_steps_counter': defaultdict(int),
            'problem_steps_counter': defaultdict(int),
        }
        merged_datasets.update({k: v for k, v in data.items() if k != 'type_counters'})
        # add_* methods do `counter[step] += n`; if a checkpoint stored the
        # counters as plain dicts that would KeyError on a new step.
        for counter_key in ('input_steps_counter', 'output_steps_counter',
                            'error_steps_counter', 'problem_steps_counter'):
            merged_datasets[counter_key] = defaultdict(int, merged_datasets[counter_key] or {})
        self.datasets = merged_datasets

        # Then load type counters if available
        if 'type_counters' in data:
            with self.locks['input_types'], self.locks['output_types'], self.locks['error_types']:
                for counter_key in ['input_types', 'output_types', 'error_types']:
                    if counter_key in data['type_counters']:
                        self.type_counters[counter_key] = defaultdict(create_default_dict)
                        for type_name, values in data['type_counters'][counter_key].items():
                            for value, count in values.items():
                                self.type_counters[counter_key][type_name][value] = count

    def get_type_statistics(self, counter_key):
        """Get statistics about the types and their counts."""
        with self.locks[counter_key]:
            return {
                type_name: {
                    "total_unique": len(values),
                    "total_count": sum(values.values()),
                    "examples": list(values.keys())[:5]  # First 5 examples
                }
                for type_name, values in self.type_counters[counter_key].items()
            }

    def get_all_type_statistics(self):
        """Get all type statistics for inputs, outputs, and errors."""
        return {
            'input_types': self.get_type_statistics('input_types'),
            'output_types': self.get_type_statistics('output_types'),
            'error_types': self.get_type_statistics('error_types')
        }

    def get_all_data_with_type_counters(self) -> Dict:
        """Returns all datasets and type counters"""
        all_data = self.get_all_datasets()
        all_data['type_counters'] = {
            'input_types': deepcopy(self.type_counters['input_types']),
            'output_types': deepcopy(self.type_counters['output_types']),
            'error_types': deepcopy(self.type_counters['error_types']),
        }
        return all_data

    def get_type_counter(self, counter_key):
        """Return the type counter for 'input', 'output' or 'error'."""
        counter_type = f"{counter_key}_types"
        with self.locks[counter_type]:
            return self.type_counters[counter_type]

    def count_element(self, element, element_type, counter_key):
        """Increment the count of ``element`` under ``element_type`` in the '<counter_key>_types' counter."""
        counter_type = f"{counter_key}_types"
        with self.locks[counter_type]:
            self.type_counters[counter_type][element_type][element] += 1


class UnifiedTTRLVRTrainer(CodeIORayPPOTrainer):
    """
    Unified trainer that handles every TTRLVR phase within a single VeRL session.
    Inherits AZR's CodeIORayPPOTrainer and adapts it for TTRLVR.
    """

    def __init__(self, ttrlvr_config: Dict = None, problem_ids: List[str] = None,
                 total_rounds: int = 30, output_dir: str = None, *args, **kwargs):
        # Call CodeIORayPPOTrainer's __init__ (includes past_epoch_window)
        super().__init__(*args, **kwargs)
        # CodeIORayPPOTrainer already handles the following:
        # - setting self._use_ttrlvr_rewards
        # - initializing self.ttrlvr_processor
        # - the assert statements
        # - setting self._past_epoch_window

        # TTRLVR-specific settings
        self.ttrlvr_config = ttrlvr_config or {}
self.problem_ids = problem_ids or [] self.total_rounds = total_rounds self.output_dir = output_dir or '/tmp/ttrlvr_unified' self.current_round = 0 self.session_timestamp = None # 전체 세션 타임스탬프 (fit() 시작 시 설정) # Phase 1-4용 설정 self.num_programs = self.ttrlvr_config.get('num_programs', 4) self.input_rounds = self.ttrlvr_config.get('input_generation_rounds', 3) self.parallel_batch_size = self.ttrlvr_config.get('parallel_batch_size', 4) # TestTime 설정 생성 self.testtime_config = TestTimeConfig( num_program_variations=self.num_programs, # 다양한 프로그램 생성 수 baseline_evaluation_rounds=5, # 베이스라인 평가 라운드 skip_task_evaluation=False, # 베이스라인 평가 실행 max_adaptation_steps=10, # 적응 학습 스텝 adaptation_batch_size=self.parallel_batch_size # 병렬 배치 크기 ) # CompleteTestTimePipeline은 나중에 초기화 (모델이 준비된 후) self.ttrlvr_pipeline = None # 로거 설정 - CompleteTestTimePipeline과 동일한 구조 그대로 사용 from datetime import datetime self.session_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') # 초기에는 임시 로거 (라운드별로 재설정됨) self.ttrlvr_logger = TestTimeLogger( task_output_dir=self.output_dir, log_file=os.path.join(self.output_dir, 'ttrlvr_unified.log') ) if self.config.azr.executor == 'qwq': self._executor = PythonExecutor( timeout_length=self.config.azr.execute_max_timeout, ast_check=self.config.azr.ast_check, max_workers=self.config.azr.get('executor_max_workers', 1) ) elif self.config.azr.executor == 'sandboxfusion': # Lazy import to avoid docker dependency when not using sandboxfusion global SandboxfusionExecutor if SandboxfusionExecutor is None: from absolute_zero_reasoner.utils.code_utils.sandboxfusion_executor import SandboxfusionExecutor self._executor = SandboxfusionExecutor( timeout_length=self.config.azr.execute_max_timeout, ast_check=self.config.azr.ast_check, max_workers=self.config.azr.get('executor_max_workers', 1), use_china_mirror=self.config.azr.get('use_china_mirror', True) ) else: raise ValueError(f'Invalid executor: {self.config.azr.executor}') self.dataset_manager = DatasetManager.remote() 
self._last_cleanup_step = 0 self._cleanup_frequency = self.config.azr.get('executor_cleanup_frequency', 5) def cleanup(self): """Clean up the executor and other resources""" if hasattr(self._executor, 'cleanup'): PrettyPrinter.status("CLEANUP", "Cleaning up executor...", "info") self._executor.cleanup() # Force garbage collection gc.collect() def _create_train_code_gen_dataloader( self, problem_type: str, data_len: int, dataset_key: str = None, seeding: bool = False, ) -> DataLoader: if dataset_key is None: if problem_type == 'code_i': dataset_key = 'input' elif problem_type == 'code_o': dataset_key = 'output' elif problem_type == 'code_e': dataset_key = 'error' elif problem_type == 'code_f': # For code_f we use merged snippets from all datasets io_data = ray.get(self.dataset_manager.get_snippets.remote()) else: raise ValueError(f'Invalid problem type: {problem_type}') if problem_type != 'code_f': io_data = ray.get(self.dataset_manager.get_dataset.remote(dataset_key)) parquet_path = (self._code_dir / f'train_gen_{problem_type}.parquet').as_posix() os.makedirs(os.path.dirname(parquet_path), exist_ok=True) # Handle weights strategy if problem_type == 'code_f' and not seeding: if self.config.azr.gen_data_probabilities_strategy == 'step': entries_with_steps = ray.get(self.dataset_manager.get_snippets_with_steps.remote()) weights = [w + 1 for _, w in entries_with_steps] if entries_with_steps else [1.0]*len(io_data) else: weights = [1.0] * len(io_data) elif dataset_key == 'seed': weights = None elif self.config.azr.gen_data_probabilities_strategy == 'uniform': weights = [1.0] * len(io_data) elif self.config.azr.gen_data_probabilities_strategy == 'step': weights = [w + 1 for w in ray.get(self.dataset_manager.get_steps_dataset.remote(dataset_key))] else: raise ValueError(f"Unknown strategy: {self.config.azr.gen_data_probabilities_strategy}") # Common parameters for get_gen_code_io_data gen_params = { 'io_data': io_data, 'target_data_len': data_len, 'problem_type': 
problem_type, 'content_max_length': self.config.azr.data_selection_strategy.content_max_length, 'io_n': 1 if problem_type == 'code_f' else self.config.azr.data_selection_strategy.io_n, 'instruction_type': self.config.reward_fn.extraction_type, 'output_path': parquet_path, 'split': 'train', 'tokenizer': self.tokenizer, 'banned_keywords': self.config.azr.data_selection_strategy.banned_words, 'banned_assertion_keywords': self.config.azr.data_selection_strategy.banned_keywords_for_errors_and_exceptions, 'weights': weights, 'enable_composite_function': self.config.azr.data_selection_strategy.composite_start_step > 0 and self.global_steps >= self.config.azr.data_selection_strategy.composite_start_step, 'composite_function_n_min': self.config.azr.data_selection_strategy.composite_function_n_min, 'composite_function_n_max': self.config.azr.data_selection_strategy.composite_function_n_max, 'composite_chance': self.config.azr.data_selection_strategy.composite_chance, 'remove_after_return': self.config.azr.reward.generation_reward_config.remove_after_return, 'remove_input_from_snippet': self.config.azr.reward.generation_reward_config.remove_input_from_snippet, 'include_references': self.config.azr.reward.generation_reward_config.include_references, } # Add code_f specific parameters if problem_type == 'code_f': gen_params.update({ 'num_inputs': self.config.azr.data_selection_strategy.num_inputs, }) get_gen_code_io_data(**gen_params) code_gen_train_dataset = RLHFDataset( parquet_files=parquet_path, tokenizer=self.tokenizer, prompt_key=self.config.data.prompt_key, max_prompt_length=self.config.data.max_prompt_length, filter_prompts=True, return_raw_chat=self.config.data.get('return_raw_chat', False), truncation='error', extra_source_key=f"gen_{problem_type}_train" ) if self.config.data.shuffle: train_dataloader_generator = torch.Generator() train_dataloader_generator.manual_seed(self.config.data.get('seed', 1)) sampler = RandomSampler(code_gen_train_dataset, 
generator=train_dataloader_generator) else: sampler = SequentialSampler(code_gen_train_dataset) return iter(DataLoader( dataset=code_gen_train_dataset, batch_size=self.config.data.train_batch_size, drop_last=True, collate_fn=collate_fn, sampler=sampler )) def _create_train_code_pred_dataloader(self, problem_type: str, data_len: int) -> DataLoader: if problem_type == 'code_i': dataset_key = 'input' elif problem_type == 'code_o': dataset_key = 'output' elif problem_type == 'code_e': dataset_key = 'error' elif problem_type == 'code_f': dataset_key = 'problem' else: raise ValueError(f'Invalid problem type: {problem_type}') full_dataset = ray.get(self.dataset_manager.get_dataset.remote(dataset_key)) strategy = self.config.azr.pred_data_mix_strategy if strategy == "step": # Get entries with their creation steps entries_with_steps = ray.get(self.dataset_manager.get_dataset_with_steps.remote(dataset_key)) if not entries_with_steps: selected_data = [] else: entries, steps = zip(*entries_with_steps) # Calculate inverse step weights (newer entries get higher weight) selected_indices = random.choices( range(len(entries)), weights=steps, k=min(data_len, len(entries)) ) selected_data = [entries[i] for i in selected_indices] elif strategy == "uniform_total": selected_data = random.sample(full_dataset, min(len(full_dataset), data_len)) elif strategy == "max_new": total_recent = ray.get(self.dataset_manager.get_recent_additions.remote( dataset_key, self.global_steps, self._past_epoch_window )) new_programs = full_dataset[-total_recent:] if total_recent > 0 else [] new_samples = random.sample(new_programs, min(len(new_programs), data_len)) remaining = data_len - len(new_samples) selected_data = new_samples + random.sample(full_dataset, remaining) elif strategy == "half_new": total_recent = ray.get(self.dataset_manager.get_recent_additions.remote( dataset_key, self.global_steps, self._past_epoch_window )) new_programs = full_dataset[-total_recent:] if total_recent > 0 else [] 
new_count = min(len(new_programs), data_len//2) base_count = data_len - new_count selected_data = random.sample(new_programs, new_count) + random.sample(full_dataset, base_count) else: raise ValueError(f"Unknown strategy: {strategy}") parquet_path = (self._code_dir / f'train_pred_{problem_type}.parquet').as_posix() get_pred_code_io_data( io_data=selected_data, target_data_len=data_len, problem_type=problem_type, content_max_length=self.config.azr.data_selection_strategy.content_max_length, output_path=parquet_path, split='train', tokenizer=self.tokenizer, instruction_type=self.config.reward_fn.extraction_type, ) code_pred_train_dataset = RLHFDataset(parquet_files=parquet_path, tokenizer=self.tokenizer, prompt_key=self.config.data.prompt_key, max_prompt_length=self.config.data.max_prompt_length, filter_prompts=True, return_raw_chat=self.config.data.get('return_raw_chat', False), truncation='error', extra_source_key=f"pred_{problem_type}_train") # use sampler for better ckpt resume if self.config.data.shuffle: train_dataloader_generator = torch.Generator() train_dataloader_generator.manual_seed(self.config.data.get('seed', 1)) sampler = RandomSampler(data_source=code_pred_train_dataset, generator=train_dataloader_generator) else: sampler = SequentialSampler(data_source=code_pred_train_dataset) code_pred_train_dataloader = DataLoader(dataset=code_pred_train_dataset, batch_size=self.config.data.train_batch_size, drop_last=True, collate_fn=collate_fn, sampler=sampler) assert len(code_pred_train_dataloader) >= 1 return iter(code_pred_train_dataloader) def _compute_batch(self, batch: DataProto, metrics: dict, timing_raw: dict, problem_type: str, executor: PythonExecutor) -> tuple[DataProto, dict]: PrettyPrinter.section_header(f"Computing batch for {problem_type}") # pop those keys for generation gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids']) # generate a batch with marked_timer(f'gen/{problem_type}', timing_raw): gen_batch_output = 
self.actor_rollout_wg.generate_sequences(gen_batch) batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object) # repeat to align with repeated responses in rollout batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(gen_batch_output) batch.batch["response_mask"] = compute_response_mask(batch) # Balance the number of valid tokens across DP ranks. # NOTE: This usually changes the order of data in the `batch`, # which won't affect the advantage calculation (since it's based on uid), # but might affect the loss calculation (due to the change of mini-batching). # TODO: Decouple the DP balancing and mini-batching. if self.config.trainer.balance_batch: self._balance_batch(batch, metrics=metrics) # compute global_valid tokens batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist() # recompute old_log_probs with marked_timer(f'old_log_prob/{problem_type}', timing_raw): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) entropys = old_log_prob.batch["entropys"] response_masks = batch.batch["response_mask"] loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} metrics.update(old_log_prob_metrics) old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) if "rollout_log_probs" in batch.batch.keys(): # TODO: we may want to add diff of probs too. 
rollout_old_log_probs = batch.batch["rollout_log_probs"] actor_old_log_probs = batch.batch["old_log_probs"] attention_mask = batch.batch["attention_mask"] responses = batch.batch["responses"] response_length = responses.size(1) response_mask = attention_mask[:, -response_length:] rollout_probs = torch.exp(rollout_old_log_probs) actor_probs = torch.exp(actor_old_log_probs) rollout_probs_diff = torch.abs(rollout_probs - actor_probs) rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool()) rollout_probs_diff_max = torch.max(rollout_probs_diff) rollout_probs_diff_mean = torch.mean(rollout_probs_diff) rollout_probs_diff_std = torch.std(rollout_probs_diff) metrics.update( { "training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(), "training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(), "training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(), } ) if self.use_reference_policy: with marked_timer(f'ref/{problem_type}', timing_raw): ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) batch = batch.union(ref_log_prob) # compute values if self.use_critic: with marked_timer(f'values/{problem_type}', timing_raw): values = self.critic_wg.compute_values(batch) batch = batch.union(values) with marked_timer(f'adv/{problem_type}', timing_raw): if self.use_rm: reward_tensor = self.rm_wg.compute_rm_score(batch) batch = batch.union(reward_tensor) input_type_counters, output_type_counters, error_type_counters = None, None, None # Get the appropriate type counters based on problem type if problem_type == 'gen_code_i': input_type_counters = ray.get(self.dataset_manager.get_type_counter.remote('input')) elif problem_type == 'gen_code_o': output_type_counters = ray.get(self.dataset_manager.get_type_counter.remote('output')) elif problem_type == 'gen_code_e': error_type_counters = ray.get(self.dataset_manager.get_type_counter.remote('error')) elif problem_type == 'gen_code_f': input_type_counters = 
ray.get(self.dataset_manager.get_type_counter.remote('input')) output_type_counters = ray.get(self.dataset_manager.get_type_counter.remote('output')) # make sure actor_rollout_wg n > 1 if problem_type.startswith('gen'): reward_fn_kwargs = { 'data': batch, 'problem_type': problem_type, 'executor': executor, # need this to check for execution errors 'rollout_actor_wg': self.actor_rollout_wg, # need this to estimate difficulty reward 'banned_words': self.config.azr.data_selection_strategy.banned_words, # need this to check for banned words 'n_samples': self.config.azr.reward.n_samples, 'input_type_counters': input_type_counters, 'output_type_counters': output_type_counters, 'error_type_counters': error_type_counters, } elif problem_type.startswith('pred'): reward_fn_kwargs = { 'data': batch, 'problem_type': problem_type, 'executor': executor, } # For prediction tasks, initialize empty invalid_programs invalid_programs = [] with marked_timer(f'reward_fn/{problem_type}', timing_raw): PrettyPrinter.status("REWARD", f"Computing rewards for {problem_type}...", "info") if problem_type.startswith('gen'): reward_tensor, train_metrics, valid_programs, correct_predictions, invalid_programs = self.reward_fn(**reward_fn_kwargs) else: # pred type reward_tensor, train_metrics, valid_programs, correct_predictions, invalid_programs = self.reward_fn(**reward_fn_kwargs) PrettyPrinter.status("REWARD", f"Found {len(valid_programs) if valid_programs else 0} valid programs", "success") # get avg_program lines avg_program_lines = sum(len(program['snippet'].split('\n')) for program in valid_programs) / len(valid_programs) if valid_programs else 0 train_metrics[f'{problem_type}/avg_program_lines'] = avg_program_lines # Save generated data if enabled if (self.config.azr.save_generated_data and self.config.azr.save_data_path and self.global_steps % self.config.azr.save_frequency == 0): # Save valid programs if valid_programs and self.config.azr.save_valid_data: save_programs_to_file( 
programs=valid_programs, problem_type=problem_type, global_step=self.global_steps, save_path=self.config.azr.save_data_path, data_type="valid" ) # Save invalid programs if invalid_programs and self.config.azr.save_invalid_data: save_programs_to_file( programs=invalid_programs, problem_type=problem_type, global_step=self.global_steps, save_path=self.config.azr.save_data_path, data_type="invalid" ) # Add statistics to metrics train_metrics[f'{problem_type}/num_valid_saved'] = len(valid_programs) if valid_programs else 0 train_metrics[f'{problem_type}/num_invalid_saved'] = len(invalid_programs) if invalid_programs else 0 # Log new programs if available if valid_programs and self.config.azr.random_print_max_programs > 0: PrettyPrinter.section_header(f"New {problem_type} Programs") max_print = min(self.config.azr.random_print_max_programs, len(valid_programs)) for program in random.sample(valid_programs, max_print): PrettyPrinter.status(f"PROBLEM TYPE", problem_type, "info") if 'code_f' not in problem_type: PrettyPrinter.code_block(program['snippet'], "python") PrettyPrinter.status("INPUT", program['input'], "info") PrettyPrinter.status("OUTPUT", program['output'], "info") PrettyPrinter.status("THOUGHT", program['thought'], "info") PrettyPrinter.status("COMPOSITE FUNCTION", "YES!" 
if len(program['composite_functions']) > 0 else "NO!", "info") else: PrettyPrinter.code_block(program['snippet'], "python") PrettyPrinter.status("INPUT", program['inputs'], "info") PrettyPrinter.status("OUTPUT", program['outputs'], "info") PrettyPrinter.status("MESSAGE", program['message'], "info") PrettyPrinter.status("THOUGHT", program['thought'], "info") print("\n" + "-"*80 + "\n") if correct_predictions and self.config.azr.random_print_max_programs > 0: PrettyPrinter.section_header(f"New {problem_type} Programs") max_print = min(self.config.azr.random_print_max_programs, len(correct_predictions)) for program in random.sample(correct_predictions, max_print): if 'code_f' not in problem_type: PrettyPrinter.code_block(program['program'], "python") # also print the problem_type PrettyPrinter.status(f"PROBLEM TYPE", problem_type, "info") PrettyPrinter.status("INPUT", program['input'], "info") PrettyPrinter.status("OUTPUT", program['output'], "info") PrettyPrinter.status("THOUGHT", program['thought'], "info") else: PrettyPrinter.code_block(program['answer']['snippet'], "python") PrettyPrinter.code_block(program['answer']['gold_program'], "python (gold)") PrettyPrinter.status("HIDDEN INPUT", program['hidden_inputs'], "info") PrettyPrinter.status("HIDDEN OUTPUT", program['hidden_outputs'], "info") PrettyPrinter.status("GIVEN INPUT", program['given_inputs'], "info") PrettyPrinter.status("GIVEN OUTPUT", program['given_outputs'], "info") PrettyPrinter.status("MESSAGE", program['answer']['message'], "info") PrettyPrinter.status("THOUGHT", program['answer']['thought'], "info") print("\n" + "-"*80 + "\n") if problem_type.endswith('code_i'): if valid_programs: # Process locally first processed_programs = process_elements(valid_programs) # Then batch add to dataset ray.get(self.dataset_manager.add_input_batch.remote(processed_programs, self.global_steps)) elif problem_type.endswith('code_o'): if valid_programs: processed_programs = process_elements(valid_programs) 
ray.get(self.dataset_manager.add_output_batch.remote(processed_programs, self.global_steps)) elif problem_type.endswith('code_e'): if valid_programs: processed_programs = process_elements(valid_programs) ray.get(self.dataset_manager.add_error_batch.remote(processed_programs, self.global_steps)) elif problem_type.endswith('code_f'): if valid_programs: processed_programs = process_elements(valid_programs) ray.get(self.dataset_manager.add_problem_batch.remote(processed_programs, self.global_steps)) else: raise ValueError(f'Invalid problem type: {problem_type}') if self.config.azr.data_selection_strategy.max_programs is not None and problem_type.endswith('code_i'): truncated_length, before_length = ray.get(self.dataset_manager.truncate_datasets.remote(self.config.azr.data_selection_strategy.max_programs, 'input')) PrettyPrinter.status("DATA", f"Truncated {truncated_length} programs from input dataset, max programs is {self.config.azr.data_selection_strategy.max_programs}, dataset size was {before_length} before truncation", "info") if self.config.azr.data_selection_strategy.max_programs is not None and problem_type.endswith('code_o'): truncated_length, before_length = ray.get(self.dataset_manager.truncate_datasets.remote(self.config.azr.data_selection_strategy.max_programs, 'output')) PrettyPrinter.status("DATA", f"Truncated {truncated_length} programs from output dataset, max programs is {self.config.azr.data_selection_strategy.max_programs}, dataset size was {before_length} before truncation", "info") if self.config.azr.data_selection_strategy.max_programs is not None and problem_type.endswith('code_e'): truncated_length, before_length = ray.get(self.dataset_manager.truncate_datasets.remote(self.config.azr.data_selection_strategy.max_programs, 'error')) PrettyPrinter.status("DATA", f"Truncated {truncated_length} programs from error dataset, max programs is {self.config.azr.data_selection_strategy.max_programs}, dataset size was {before_length} before truncation", 
"info") if self.config.azr.data_selection_strategy.max_programs is not None and problem_type.endswith('code_f'): truncated_length, before_length = ray.get(self.dataset_manager.truncate_datasets.remote(self.config.azr.data_selection_strategy.max_programs, 'problem')) PrettyPrinter.status("DATA", f"Truncated {truncated_length} programs from problem dataset, max programs is {self.config.azr.data_selection_strategy.max_programs}, dataset size was {before_length} before truncation", "info") train_metrics = {f'{problem_type}/{k}': np.mean(v) for k, v in train_metrics.items()} # log the number of valid programs added to the dataset if problem_type.startswith('gen'): if problem_type.endswith('code_i'): dataset_key = 'input' elif problem_type.endswith('code_o'): dataset_key = 'output' elif problem_type.endswith('code_e'): dataset_key = 'error' elif problem_type.endswith('code_f'): dataset_key = 'problem' else: raise ValueError(f'Invalid problem type: {problem_type}') train_metrics[f'{problem_type}/num_valid_programs'] = ray.get( self.dataset_manager.get_recent_additions.remote( dataset_key, self.global_steps, self._past_epoch_window ) ) metrics.update(train_metrics) batch.batch['token_level_scores'] = reward_tensor if self.config.algorithm.use_kl_in_reward: batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.config.algorithm.kl_penalty) metrics.update(kl_metrics) else: batch.batch['token_level_rewards'] = batch.batch['token_level_scores'] batch = compute_advantage(batch, adv_estimator=self.config.algorithm.adv_estimator, gamma=self.config.algorithm.gamma, lam=self.config.algorithm.lam, num_repeat=self.config.actor_rollout_ref.rollout.n, config=self.config.algorithm) gc.collect() return batch, metrics def _init_seed_dataset(self, problem_types: List[str]) -> Tuple[List[Dict], List[Dict], List[Dict]]: # Initialize with seed program using the coordinator if ('code_i' in problem_types or 'code_o' in problem_types) and 
ray.get(self.dataset_manager.get_dataset.remote('seed')) == []: ray.get(self.dataset_manager.update_seed.remote([ {'snippet': seed_program, 'input': '"Hello world"', 'output': '"Hello world"', 'imports': [], 'original_snippet': seed_program, 'composite_functions': []} ])) if 'code_e' in problem_types and ray.get(self.dataset_manager.get_dataset.remote('error_seed')) == []: ray.get(self.dataset_manager.update_error_seed.remote([ {'snippet': seed_program, 'input': '"Hello world"', 'output': 'NoError', 'imports': [], 'original_snippet': seed_program, 'composite_functions': []} ])) if 'code_f' in problem_types and ray.get(self.dataset_manager.get_dataset.remote('problem')) == []: ray.get(self.dataset_manager.add_problem_batch.remote([ { 'snippet': seed_program, 'inputs': ['"Hello world"', '1', "dict(a=1, b=2)", '(1.1, 1.2, 1.3)', '"[[1, 0, 0], [0, 0, 0], [0, 0, 0]]"', '1001101100010001'], 'outputs': ['"Hello world"', '1', "dict(a=1, b=2)", '(1.1, 1.2, 1.3)', '"[[1, 0, 0], [0, 0, 0], [0, 0, 0]]"', '1001101100010001'], 'message': 'Write a function that returns whatever you input', 'imports': [], } ], self.global_steps)) target_size = self.config.azr.data_selection_strategy.data_len * self.config.azr.data_selection_strategy.seed_batch_factor while problem_types != ['code_f']: # we can skip this loop if we are only generating code_f dataset # Get current dataset state seed_dataset = ray.get(self.dataset_manager.get_dataset.remote('seed')) error_dataset = ray.get(self.dataset_manager.get_dataset.remote('error_seed')) if problem_types == ['code_e'] and len(error_dataset) >= target_size: # only generate error seed dataset break if 'code_e' not in problem_types and len(seed_dataset) >= target_size: # only generate seed dataset break if len(seed_dataset) >= target_size and len(error_dataset) >= target_size: # generating both seed and error seed dataset break for problem_type in problem_types: if problem_type == 'code_f': # skip code_f dataset, we will generate it later continue 
if problem_type == 'code_e' and len(error_dataset) >= target_size: continue if problem_type != 'code_e' and len(seed_dataset) >= target_size: continue seed_dataloader = self._create_train_code_gen_dataloader( problem_type=problem_type, data_len=self.config.data.train_batch_size, dataset_key='error_seed' if problem_type == 'code_e' else 'seed', seeding=True, ) for batch_dict in seed_dataloader: batch: DataProto = DataProto.from_single_dict(batch_dict) gen_batch = batch.pop(['input_ids', 'attention_mask', 'position_ids']) gen_batch.meta_info = { 'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id, 'recompute_log_prob': False, 'do_sample': True, 'validate': True, } # pad to be divisible by dp_size gen_batch_padded, pad_size = pad_dataproto_to_divisor(gen_batch, self.actor_rollout_wg.world_size) output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(gen_batch_padded) pad_size *= self.config.actor_rollout_ref.rollout.n # unpad output_gen_batch = unpad_dataproto(output_gen_batch_padded, pad_size=pad_size) # If we're doing multiple samples, repeat the original batch batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(output_gen_batch) # Store generated outputs output_ids = batch.batch['responses'] output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] local_entries = [] local_error_entries = [] for output_text in output_texts: success, result = parse_code_input_output( output_text, parse_output=False, remove_after_return=self.config.azr.reward.generation_reward_config.remove_after_return, remove_comments=self.config.azr.reward.generation_reward_config.remove_comments, remove_print=self.config.azr.reward.generation_reward_config.remove_print, reject_multiple_functions=self.config.azr.reward.generation_reward_config.reject_multiple_functions, f_replace_location=self.config.azr.reward.generation_reward_config.f_replace_location, 
reject_test_input_in_code=self.config.azr.reward.generation_reward_config.reject_test_input_in_code, code_location=self.config.azr.reward.generation_reward_config.code_location, ) if success: code_validity, output = self._executor.check_all( code=result['code'], inputs=result['input'], banned_keywords=self.config.azr.data_selection_strategy.banned_words, check_determinism=True, imports=list(set(result['imports'])), check_error=problem_type == 'code_e', banned_keywords_for_errors_and_exceptions=self.config.azr.data_selection_strategy.banned_keywords_for_errors_and_exceptions, ) if code_validity: if problem_type == 'code_e': local_error_entries.append( { 'snippet': result['code'], 'input': result['input'], 'output': output, 'imports': result['imports'], 'original_snippet': result['code'], 'composite_functions': [] } ) else: local_entries.append( { 'snippet': result['code'], 'input': result['input'], 'output': output, 'imports': result['imports'], 'original_snippet': result['code'], 'composite_functions': [] } ) if self.config.azr.data_selection_strategy.get('generate_seed_dataset_only', False): with open(self.config.azr.data_selection_strategy.output_seed_path.replace('.jsonl', f'_temp.jsonl'), 'a') as f: for entry in local_entries: f.write(json.dumps(entry) + '\n') break # only use the first batch, to continuously generate more diverse data # Atomically update shared dataset if problem_type != 'code_e': # Process locally first processed_entries = process_elements(local_entries) # Then send to ray added_count = ray.get( self.dataset_manager.update_seed.remote(processed_entries) ) # Get updated dataset seed_dataset = ray.get(self.dataset_manager.get_dataset.remote('seed')) PrettyPrinter.status( "WORKER", f"Added {added_count} new entries (Total: {len(seed_dataset)})", "info" ) PrettyPrinter.progress_bar( current=len(seed_dataset), total=target_size, label="Dataset Growth" ) # Early exit if we've reached target size if len(seed_dataset) >= target_size: break elif 
problem_type == 'code_e': # Process locally first processed_entries = process_elements(local_error_entries) # Then send to ray error_added_count = ray.get( self.dataset_manager.update_error_seed.remote(processed_entries) ) error_dataset = ray.get(self.dataset_manager.get_dataset.remote('error_seed')) PrettyPrinter.status( "WORKER", f"Added {error_added_count} new entries (Total: {len(error_dataset)})", "info" ) PrettyPrinter.progress_bar( current=len(error_dataset), total=target_size, label="Error Dataset Growth" ) if len(error_dataset) >= target_size: break # now get the code_f dataset if 'code_f' in problem_types: code_f_dataset = [] all_snippets = ray.get(self.dataset_manager.get_snippets.remote()) while len(code_f_dataset) < target_size: # randomly sample a snippet from all_snippets code_f_dataset = ray.get(self.dataset_manager.get_dataset.remote('problem')) code_f_seed_dataloader = self._create_train_code_gen_dataloader( data_len=len(all_snippets), problem_type='code_f', seeding=True ) epoch_entries = [] for batch in code_f_seed_dataloader: batch: DataProto = DataProto.from_single_dict(batch) gen_batch = batch.pop(['input_ids', 'attention_mask', 'position_ids']) gen_batch.meta_info = { 'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id, 'recompute_log_prob': False, 'do_sample': True, 'validate': True, } # pad to be divisible by dp_size gen_batch_padded, pad_size = pad_dataproto_to_divisor(gen_batch, self.actor_rollout_wg.world_size) output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(gen_batch_padded) pad_size *= self.config.actor_rollout_ref.rollout.n # unpad output_gen_batch = unpad_dataproto(output_gen_batch_padded, pad_size=pad_size) # If we're doing multiple samples, repeat the original batch batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(output_gen_batch) # Store generated outputs output_ids = batch.batch['responses'] output_texts = 
[self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] for idx, output_text in enumerate(output_texts): success, result = parse_inputs_message(output_text, num_inputs=self.config.azr.data_selection_strategy.num_inputs) if success: outputs = [] for ipt in result['inputs']: code_validity, output = self._executor.check_all( code=batch.non_tensor_batch['extra_info'][idx]['chosen_references'][0]['snippet'], inputs=ipt, banned_keywords=[], check_determinism=True, imports=batch.non_tensor_batch['extra_info'][idx]['imports'], check_error=False, banned_keywords_for_errors_and_exceptions=[], ) outputs.append(output) if not code_validity: break if code_validity: epoch_entries.append( { 'snippet': batch.non_tensor_batch['extra_info'][idx]['chosen_references'][0]['snippet'], 'inputs': result['inputs'], 'outputs': outputs, 'message': result['message'], 'imports': batch.non_tensor_batch['extra_info'][idx]['imports'].tolist(), } ) # Then send to ray processed_entries = process_elements(epoch_entries) added_count = ray.get(self.dataset_manager.add_problem_batch.remote(processed_entries, self.global_steps)) # Get updated dataset code_f_dataset = ray.get(self.dataset_manager.get_dataset.remote('problem')) PrettyPrinter.status( "WORKER", f"Added {added_count} new entries (Total: {len(code_f_dataset)})", "info" ) PrettyPrinter.progress_bar( current=len(code_f_dataset), total=target_size, label="Code F Dataset Growth" ) if self.config.azr.data_selection_strategy.get('generate_seed_dataset_only', False): with open(self.config.azr.data_selection_strategy.output_code_f_seed_path.replace('.jsonl', f'_temp.jsonl'), 'a') as f: for entry in code_f_dataset: f.write(json.dumps(entry) + '\n') # Early exit if we've reached target size if len(code_f_dataset) >= target_size: break # truncate the dataset to the target size ray.get(self.dataset_manager.truncate_datasets.remote(target_size, 'seed')) ray.get(self.dataset_manager.truncate_datasets.remote(target_size, 'error_seed')) 
ray.get(self.dataset_manager.truncate_datasets.remote(target_size, 'problem')) # Sample type statistics after seed initialization if self.global_steps == 0: # Only log this on first initialization type_stats = ray.get(self.dataset_manager.get_all_type_statistics.remote()) PrettyPrinter.section_header("Initial Type Statistics") for category, type_data in type_stats.items(): category_display = { 'input_types': 'Input Types', 'output_types': 'Output Types', 'error_types': 'Error Types' }.get(category, category) if type_data: # Only show if we have data PrettyPrinter.status(category_display.upper(), f"Total types: {len(type_data)}", "info") for type_name, stats in sorted(type_data.items(), key=lambda x: x[1]['total_unique'], reverse=True)[:5]: # Show top 5 by unique count PrettyPrinter.status( f" {type_name}", f"Unique: {stats['total_unique']}, Total: {stats['total_count']}", "info" ) if stats['examples']: example = stats['examples'][0] if len(example) > 100: example = example[:97] + "..." PrettyPrinter.status(" Example", example, "info") # Final dataset from coordinator seed_dataset = ray.get(self.dataset_manager.get_dataset.remote('seed')) error_dataset = ray.get(self.dataset_manager.get_dataset.remote('error_seed')) code_f_dataset = ray.get(self.dataset_manager.get_dataset.remote('problem')) # Modify dataset saving condition if ('code_i' in problem_types or 'code_o' in problem_types) and self.config.azr.output_seed_path is not None: PrettyPrinter.status("DATASET", "Writing seed dataset to JSONL file...", "info") output_path = Path(self.config.azr.output_seed_path) output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, 'w') as f: for item in seed_dataset: f.write(json.dumps(item) + '\n') PrettyPrinter.status("DATASET", f"Saved {len(seed_dataset)} entries to {str(output_path)}", "success") if 'code_e' in problem_types and self.config.azr.output_error_seed_path is not None: PrettyPrinter.status("DATASET", "Writing error seed dataset to JSONL 
file...", "info") output_path = Path(self.config.azr.output_error_seed_path) output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, 'w') as f: for item in error_dataset: f.write(json.dumps(item) + '\n') PrettyPrinter.status("DATASET", f"Saved {len(error_dataset)} entries to {str(output_path)}", "success") if 'code_f' in problem_types and self.config.azr.output_code_f_seed_path is not None: PrettyPrinter.status("DATASET", "Writing code f seed dataset to JSONL file...", "info") output_path = Path(self.config.azr.output_code_f_seed_path) output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, 'w') as f: for item in code_f_dataset: try: f.write(json.dumps(item) + '\n') except: print(item) raise Exception("Failed to save code f dataset") PrettyPrinter.status("DATASET", f"Saved {len(code_f_dataset)} entries to {str(output_path)}", "success") # Show a few sample entries if 'code_i' in problem_types or 'code_o' in problem_types: # Print detailed dataset summary PrettyPrinter.section_header("Seed Dataset Summary") PrettyPrinter.table( ["Key", "Value"], [ ["Total Samples", len(seed_dataset)], ["Target Size", target_size], ["Storage Path", self.config.azr.output_seed_path], ["Sample Types", len(set(item['snippet'] for item in seed_dataset))], ["Average Snippet Length", sum(len(item['snippet']) for item in seed_dataset) // len(seed_dataset) if seed_dataset else 0] ], title="Dataset Statistics" ) PrettyPrinter.section_header("Sample Entries") # sample 3 entries for i, item in enumerate(random.sample(seed_dataset, self.config.azr.random_print_max_programs)): PrettyPrinter.code_block(item['snippet'], "python") PrettyPrinter.status("INPUT", item['input'], "info") PrettyPrinter.status("OUTPUT", item['output'], "info") if i < 2: # Don't print separator after last item print("\n" + "-" * 80 + "\n") if 'code_e' in problem_types: PrettyPrinter.section_header("Error Dataset Summary") PrettyPrinter.table( ["Key", "Value"], [ ["Total Samples", 
len(error_dataset)], ["Target Size", target_size], ["Storage Path", self.config.azr.output_error_seed_path], ["Sample Types", len(set(item['snippet'] for item in error_dataset))], ["Average Snippet Length", sum(len(item['snippet']) for item in error_dataset) // len(error_dataset) if error_dataset else 0] ], title="Error Dataset Statistics" ) PrettyPrinter.section_header("Error Sample Entries") # sample 3 entries for i, item in enumerate(random.sample(error_dataset, self.config.azr.random_print_max_programs)): PrettyPrinter.code_block(item['snippet'], "python") PrettyPrinter.status("INPUT", item['input'], "info") PrettyPrinter.status("OUTPUT", item['output'], "info") if i < 2: # Don't print separator after last item print("\n" + "-" * 80 + "\n") if 'code_f' in problem_types: PrettyPrinter.section_header("Code F Dataset Summary") PrettyPrinter.table( ["Key", "Value"], [ ["Total Samples", len(code_f_dataset)], ["Target Size", target_size], ["Storage Path", self.config.azr.output_code_f_seed_path], ["Sample Types", len(set(item['snippet'] for item in code_f_dataset))], ["Average Snippet Length", sum(len(item['snippet']) for item in code_f_dataset) // len(code_f_dataset) if code_f_dataset else 0] ], title="Code F Dataset Statistics" ) PrettyPrinter.section_header("Code F Sample Entries") # sample 3 entries for i, item in enumerate(random.sample(code_f_dataset, self.config.azr.random_print_max_programs)): PrettyPrinter.code_block(item['snippet'], "python") PrettyPrinter.status("INPUTS", item['inputs'], "info") PrettyPrinter.status("OUTPUTS", item['outputs'], "info") PrettyPrinter.status("MESSAGE", item['message'], "info") if i < 2: # Don't print separator after last item print("\n" + "-" * 80 + "\n") return seed_dataset, error_dataset, code_f_dataset def _create_seed_datasets_old(self): """[OLD VERSION - DEPRECATED] Create seed datasets for generation tasks""" need_seed_dataset = any(problem_type != 'code_e' for problem_type in self.config.azr.problem_types) or 'code_f' 
in self.config.azr.problem_types need_error_dataset = 'code_e' in self.config.azr.problem_types need_code_f_dataset = 'code_f' in self.config.azr.problem_types # Initialize with defaults seed_dataset = [] error_dataset = [] code_f_dataset = [] # Load or generate seed dataset if needed if need_seed_dataset: if self.config.azr.seed_dataset is not None: PrettyPrinter.status("DATA", "Loading seed dataset from file...", "info") with open(self.config.azr.seed_dataset, 'r') as file: seed_dataset = [json.loads(line) for line in file] seed_dataset = seed_dataset[:self.config.azr.data_selection_strategy.data_len * self.config.azr.data_selection_strategy.seed_batch_factor] PrettyPrinter.status("DATA", f"Loaded {len(seed_dataset)} seed entries", "success") if 'code_f' in self.config.azr.problem_types: # we need seed to generate code_f ray.get(self.dataset_manager.update_seed.remote(seed_dataset)) else: PrettyPrinter.status("DATA", "Seed dataset not provided, will generate", "info") # Load or prepare to generate error dataset if needed if need_error_dataset: if self.config.azr.error_seed_dataset is not None: PrettyPrinter.status("DATA", "Loading error seed dataset from file...", "info") with open(self.config.azr.error_seed_dataset, 'r') as file: error_dataset = [json.loads(line) for line in file] error_dataset = error_dataset[:self.config.azr.data_selection_strategy.data_len * self.config.azr.data_selection_strategy.seed_batch_factor] PrettyPrinter.status("DATA", f"Loaded {len(error_dataset)} error entries", "success") else: PrettyPrinter.status("DATA", "Error seed dataset not provided, will generate", "info") if need_code_f_dataset: if self.config.azr.code_f_seed_dataset is not None: PrettyPrinter.status("DATA", "Loading code f seed dataset from file...", "info") with open(self.config.azr.code_f_seed_dataset, 'r') as file: code_f_dataset = [json.loads(line) for line in file] code_f_dataset = code_f_dataset[:self.config.azr.data_selection_strategy.data_len * 
self.config.azr.data_selection_strategy.seed_batch_factor] PrettyPrinter.status("DATA", f"Loaded {len(code_f_dataset)} code f entries", "success") # Generate missing datasets if needed need_to_generate_seed = need_seed_dataset and len(seed_dataset) == 0 need_to_generate_error = need_error_dataset and len(error_dataset) == 0 need_to_generate_code_f = need_code_f_dataset and len(code_f_dataset) == 0 # TTRLVR: Skip seed generation when train_propose is False if self.config.azr.train_propose and (need_to_generate_seed or need_to_generate_error or need_to_generate_code_f): sample_problem_types = [] for problem_type in self.config.azr.problem_types: if problem_type == 'code_e' and need_to_generate_error: sample_problem_types.append(problem_type) elif problem_type != 'code_e' and need_to_generate_seed: sample_problem_types.append(problem_type) elif problem_type == 'code_f' and need_to_generate_code_f: sample_problem_types.append(problem_type) PrettyPrinter.status("DATA", f"Generating missing datasets for {', '.join(sample_problem_types)}...", "info") generated_seed, generated_error, generated_code_f = self._init_seed_dataset(problem_types=sample_problem_types) elif not self.config.azr.train_propose and (need_to_generate_seed or need_to_generate_error or need_to_generate_code_f): PrettyPrinter.status("DATA", "Skipping seed generation (train_propose=False)", "info") generated_seed, generated_error, generated_code_f = [], [], [] if need_to_generate_seed: seed_dataset = generated_seed PrettyPrinter.status("DATA", f"Generated {len(seed_dataset)} seed entries", "success") if need_to_generate_error: error_dataset = generated_error PrettyPrinter.status("DATA", f"Generated {len(error_dataset)} error entries", "success") if need_to_generate_code_f: code_f_dataset = generated_code_f PrettyPrinter.status("DATA", f"Generated {len(code_f_dataset)} code f entries", "success") if self.config.azr.get('generate_seed_dataset_only', False): return # Now initialize datasets in dataset manager 
if need_seed_dataset and self.config.azr.train_propose: assert len(seed_dataset) >= self.config.azr.data_selection_strategy.data_len if 'code_i' in self.config.azr.problem_types: # Process locally first processed_seed_dataset = process_elements(seed_dataset) # Initialize input dataset with seed data ray.get(self.dataset_manager.add_input_batch.remote(processed_seed_dataset, self.global_steps)) PrettyPrinter.status( "DATA", f"Input dataset initialized with {len(seed_dataset)} entries", "success" ) if 'code_o' in self.config.azr.problem_types: processed_seed_dataset = process_elements(seed_dataset) ray.get(self.dataset_manager.add_output_batch.remote(processed_seed_dataset, self.global_steps)) PrettyPrinter.status( "DATA", f"Output dataset initialized with {len(seed_dataset)} entries", "success" ) if need_error_dataset and self.config.azr.train_propose: assert len(error_dataset) >= self.config.azr.data_selection_strategy.data_len processed_error_dataset = process_elements(error_dataset) ray.get(self.dataset_manager.add_error_batch.remote(processed_error_dataset, self.global_steps)) PrettyPrinter.status( "DATA", f"Error dataset initialized with {len(error_dataset)} entries", "success" ) if need_code_f_dataset and self.config.azr.train_propose: assert len(code_f_dataset) >= self.config.azr.data_selection_strategy.data_len processed_code_f_dataset = process_elements(code_f_dataset) ray.get(self.dataset_manager.add_problem_batch.remote(processed_code_f_dataset, self.global_steps)) PrettyPrinter.status( "DATA", f"Code f dataset initialized with {len(code_f_dataset)} entries", "success" ) # we start from step 1 self.global_steps += 1 if self.config.azr.pretrain_pred_steps > 0 and self.global_steps <= self.config.azr.pretrain_pred_steps: self.pretrain_pred = True else: self.pretrain_pred = False while self.global_steps < self.total_training_steps: PrettyPrinter.section_header(f"Training Step {self.global_steps}") if 
self.config.azr.data_selection_strategy.composite_scheduler.enabled: self.scheduler_step() # Calculate progress metrics progress_percentage = (self.global_steps / self.total_training_steps) * 100 import time if not hasattr(self, '_start_time'): self._start_time = time.time() elapsed_time = time.time() - self._start_time if self.global_steps > 0: time_per_step = elapsed_time / self.global_steps remaining_steps = self.total_training_steps - self.global_steps eta_seconds = remaining_steps * time_per_step eta_hours = eta_seconds / 3600 PrettyPrinter.progress_bar( current=self.global_steps, total=self.total_training_steps, label=f"Training Progress - {progress_percentage:.1f}% - ETA: {eta_hours:.1f}h" ) # Print detailed timing info PrettyPrinter.status("TIMING", f"Elapsed: {elapsed_time/3600:.1f}h | Step: {time_per_step/60:.1f}min | ETA: {eta_hours:.1f}h", "info") else: PrettyPrinter.progress_bar( current=self.global_steps, total=self.total_training_steps, label="Training Progress - Starting..." 
) data_len = self.config.data.train_batch_size * self.config.azr.data_selection_strategy.update_iteration if 'code_i' in self.config.azr.problem_types: gen_code_i_dataloader = self._create_train_code_gen_dataloader( problem_type='code_i', data_len=data_len, ) pred_code_i_dataloader = self._create_train_code_pred_dataloader( problem_type='code_i', data_len=data_len, ) if 'code_o' in self.config.azr.problem_types: gen_code_o_dataloader = self._create_train_code_gen_dataloader( problem_type='code_o', data_len=data_len, ) pred_code_o_dataloader = self._create_train_code_pred_dataloader( problem_type='code_o', data_len=data_len, ) if 'code_e' in self.config.azr.problem_types: gen_code_e_dataloader = self._create_train_code_gen_dataloader( problem_type='code_e', data_len=data_len, ) pred_code_e_dataloader = self._create_train_code_pred_dataloader( problem_type='code_e', data_len=data_len, ) if 'code_f' in self.config.azr.problem_types: gen_code_f_dataloader = self._create_train_code_gen_dataloader( data_len=data_len, problem_type='code_f', seeding=False, ) pred_code_f_dataloader = self._create_train_code_pred_dataloader( problem_type='code_f', data_len=data_len, ) for _ in range(self.config.azr.data_selection_strategy.update_iteration): metrics = {} timing_raw = {} batches = {} with marked_timer('step', timing_raw): # Clean up executor periodically if self.global_steps - self._last_cleanup_step >= self._cleanup_frequency: PrettyPrinter.section_header("Periodic Cleanup") with marked_timer('cleanup', timing_raw): self.cleanup() self._last_cleanup_step = self.global_steps if 'code_i' in self.config.azr.problem_types: if not self.pretrain_pred: batch_dict = next(gen_code_i_dataloader) batch: DataProto = DataProto.from_single_dict(batch_dict) batch, metrics = self._compute_batch(batch, metrics, timing_raw, problem_type='gen_code_i', executor=self._executor) if self.config.azr.train_propose: batches[f'gen_code_i'] = batch batch_dict = next(pred_code_i_dataloader) batch: 
DataProto = DataProto.from_single_dict(batch_dict) batch, metrics = self._compute_batch(batch, metrics, timing_raw, problem_type='pred_code_i', executor=self._executor) batches[f'pred_code_i'] = batch if 'code_o' in self.config.azr.problem_types: if not self.pretrain_pred: batch_dict = next(gen_code_o_dataloader) batch: DataProto = DataProto.from_single_dict(batch_dict) batch, metrics = self._compute_batch(batch, metrics, timing_raw, problem_type='gen_code_o', executor=self._executor) if self.config.azr.train_propose: batches[f'gen_code_o'] = batch batch_dict = next(pred_code_o_dataloader) batch: DataProto = DataProto.from_single_dict(batch_dict) batch, metrics = self._compute_batch(batch, metrics, timing_raw, problem_type='pred_code_o', executor=self._executor) batches[f'pred_code_o'] = batch if 'code_e' in self.config.azr.problem_types: if not self.pretrain_pred: batch_dict = next(gen_code_e_dataloader) batch: DataProto = DataProto.from_single_dict(batch_dict) batch, metrics = self._compute_batch(batch, metrics, timing_raw, problem_type='gen_code_e', executor=self._executor) if self.config.azr.train_propose: batches[f'gen_code_e'] = batch batch_dict = next(pred_code_e_dataloader) batch: DataProto = DataProto.from_single_dict(batch_dict) batch, metrics = self._compute_batch(batch, metrics, timing_raw, problem_type='pred_code_e', executor=self._executor) batches[f'pred_code_e'] = batch if 'code_f' in self.config.azr.problem_types: if not self.pretrain_pred: batch_dict = next(gen_code_f_dataloader) batch: DataProto = DataProto.from_single_dict(batch_dict) batch, metrics = self._compute_batch(batch, metrics, timing_raw, problem_type='gen_code_f', executor=self._executor) if self.config.azr.train_propose: batches[f'gen_code_f'] = batch batch_dict = next(pred_code_f_dataloader) batch: DataProto = DataProto.from_single_dict(batch_dict) batch, metrics = self._compute_batch(batch, metrics, timing_raw, problem_type='pred_code_f', executor=self._executor) 
batches[f'pred_code_f'] = batch # concatenate batches batch = DataProto.concat(list(batches.values())) PrettyPrinter.section_header(f"Starting Parameter Updates") # update critic if self.use_critic: with marked_timer('update_critic', timing_raw): critic_output = self.critic_wg.update_critic(batch) critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) metrics.update(critic_output_metrics) # implement critic warmup if self.config.trainer.critic_warmup <= self.global_steps: with marked_timer('update_actor', timing_raw): actor_output = self.actor_rollout_wg.update_actor(batch) actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) metrics.update(actor_output_metrics) # validate PrettyPrinter.section_header(f"Starting Validation") if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \ self.global_steps % self.config.trainer.test_freq == 0: with marked_timer('testing', timing_raw): val_metrics: dict = self._validate() PrettyPrinter.table( ["Data Source", "Average Score"], [[k, v] for k, v in val_metrics.items()], title="Validation Results" ) metrics.update(val_metrics) # print the statistics of the number of programs in the dataset if 'code_i' in self.config.azr.problem_types: PrettyPrinter.status( "DATA", f"Number of programs in the input dataset: {ray.get(self.dataset_manager.get_dataset_size.remote('input'))}", "info" ) if 'code_o' in self.config.azr.problem_types: PrettyPrinter.status( "DATA", f"Number of programs in the output dataset: {ray.get(self.dataset_manager.get_dataset_size.remote('output'))}", "info" ) if 'code_e' in self.config.azr.problem_types: PrettyPrinter.status( "DATA", f"Number of programs in the error dataset: {ray.get(self.dataset_manager.get_dataset_size.remote('error'))}", "info" ) if 'code_f' in self.config.azr.problem_types: PrettyPrinter.status( "DATA", f"Number of programs in the code_f dataset: {ray.get(self.dataset_manager.get_dataset_size.remote('problem'))}", "info" ) if 
self.config.trainer.save_freq > 0 and \ self.global_steps % self.config.trainer.save_freq == 0: with marked_timer('save_checkpoint', timing_raw): self._save_checkpoint() # collect metrics, separate problem types all_types = [] if 'code_i' in self.config.azr.problem_types: if not self.pretrain_pred: all_types.append('gen_code_i') all_types.append('pred_code_i') if 'code_o' in self.config.azr.problem_types: if not self.pretrain_pred: all_types.append('gen_code_o') all_types.append('pred_code_o') if 'code_e' in self.config.azr.problem_types: if not self.pretrain_pred: all_types.append('gen_code_e') all_types.append('pred_code_e') if 'code_f' in self.config.azr.problem_types: if not self.pretrain_pred: all_types.append('gen_code_f') all_types.append('pred_code_f') sep_batches = batch.chunk(len(all_types)) for sep_batch, problem_type in zip(sep_batches, all_types): sep_metrics = compute_data_metrics(batch=sep_batch, use_critic=self.use_critic, tokenizer=self.tokenizer) sep_metrics = {f'{problem_type}/{k}': v for k, v in sep_metrics.items()} metrics.update(sep_metrics) metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) # Get and log type statistics periodically type_stats = ray.get(self.dataset_manager.get_all_type_statistics.remote()) # Log summary metrics about types for category, type_data in type_stats.items(): # Calculate diversity metrics total_types = len(type_data) total_unique_values = sum(stats["total_unique"] for stats in type_data.values()) total_instances = sum(stats["total_count"] for stats in type_data.values()) # Add to metrics metrics[f"types/{category}/distinct_types"] = total_types metrics[f"types/{category}/total_unique_values"] = total_unique_values metrics[f"types/{category}/total_instances"] = total_instances # Per-type metrics for type_name, stats in type_data.items(): metrics[f"types/{category}/{type_name}/unique_count"] = stats["total_unique"] metrics[f"types/{category}/{type_name}/total_count"] = stats["total_count"] # 
Print type statistics summary PrettyPrinter.section_header("Type Statistics Summary") for category, type_data in type_stats.items(): category_display = { 'input_types': 'Input Types', 'output_types': 'Output Types', 'error_types': 'Error Types', }.get(category, category) PrettyPrinter.status(category_display.upper(), f"Total types: {len(type_data)}", "info") for type_name, stats in sorted(type_data.items(), key=lambda x: x[1]['total_unique'], reverse=True)[:5]: # Show top 5 by unique count PrettyPrinter.status( f" {type_name}", f"Unique: {stats['total_unique']}, Total: {stats['total_count']}", "info" ) if stats['examples']: example = stats['examples'][0] if len(example) > 100: example = example[:97] + "..." PrettyPrinter.status(" Example", example, "info") PrettyPrinter.table( ["Category", "Value"], [[k, v] for k, v in metrics.items()], title="Step Metrics" ) logger.log(data=metrics, step=self.global_steps) if self.global_steps >= self.config.azr.pretrain_pred_steps: self.pretrain_pred = False self.global_steps += 1 gc.collect() if self.global_steps >= self.total_training_steps: # perform validation after training if self.val_reward_fn is not None: PrettyPrinter.section_header(f"Starting Final Validation") val_metrics = self._validate() PrettyPrinter.table( ["Data Source", "Average Score"], [[k, v] for k, v in val_metrics.items()], title="Final Validation Results" ) logger.log(data=val_metrics, step=self.global_steps) if self.config.trainer.save_freq > 0 and \ (self.global_steps - 1) % self.config.trainer.save_freq != 0: with marked_timer('save_checkpoint', timing_raw): self._save_checkpoint() return def _validate(self): """ The validation loop of PPO. The only difference is logging more metrics. 
""" reward_tensor_lst = [] data_source_lst = [] # Lists to collect samples for the table sample_inputs = [] sample_outputs = [] sample_scores = [] all_eval_metrics = defaultdict(list) for test_data in self.val_dataloader: test_batch = DataProto.from_single_dict(test_data) # we only do validation on rule-based rm if self.config.reward_model.enable and test_batch[0].non_tensor_batch['reward_model']['style'] == 'model': return {} # Store original inputs input_ids = test_batch.batch['input_ids'] input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] sample_inputs.extend(input_texts) test_gen_batch = test_batch.pop(['input_ids', 'attention_mask', 'position_ids']) test_gen_batch.meta_info = { 'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id, 'recompute_log_prob': False, 'do_sample': False, 'validate': True, } # pad to be divisible by dp_size test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size) test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) # unpad test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) PrettyPrinter.status("VALID", "Generation completed", "success") # Store generated outputs output_ids = test_output_gen_batch.batch['responses'] output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] sample_outputs.extend(output_texts) test_batch = test_batch.union(test_output_gen_batch) # evaluate using reward_function reward_tensor, eval_metrics, _, _, _ = self.val_reward_fn( test_batch, problem_type=None, executor=self._executor, ) for k, v in eval_metrics.items(): all_eval_metrics[k].append(v) # Store scores scores = reward_tensor.sum(-1).cpu().tolist() sample_scores.extend(scores) reward_tensor_lst.append(reward_tensor) data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * 
reward_tensor.shape[0])) self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu() # (batch_size,) data_sources = np.concatenate(data_source_lst, axis=0) # evaluate test_score based on data source data_source_reward = {} for i in range(reward_tensor.shape[0]): data_source = data_sources[i] if data_source not in data_source_reward: data_source_reward[data_source] = [] data_source_reward[data_source].append(reward_tensor[i].item()) metric_dict = {} for data_source, rewards in data_source_reward.items(): metric_dict[f'val/test_score/{data_source}'] = np.mean(rewards) for k, v in all_eval_metrics.items(): metric_dict[k] = np.mean(v) return metric_dict def _save_datasets(self, save_dir: Path): """Save input/output datasets as JSONL files""" save_dir.mkdir(parents=True, exist_ok=True) # get all datasets and type counters datasets_with_types = ray.get(self.dataset_manager.get_all_data_with_type_counters.remote()) # save datasets pickle.dump(datasets_with_types, open(save_dir / 'datasets.pkl', 'wb')) PrettyPrinter.status("SAVE", f"Saved datasets and type counters to {save_dir}", "success") def _load_datasets(self, save_dir: Path): """Load input/output datasets from JSONL files""" datasets_with_types = pickle.load(open(Path(save_dir) / 'datasets' / 'datasets.pkl', 'rb')) # Filter datasets based on global step if self.global_steps > 0: # Filter datasets that have step info for dataset_key in ['input', 'output', 'error', 'problem']: steps_key = f"{dataset_key}_steps" if steps_key in datasets_with_types and dataset_key in datasets_with_types: # Create lists of entries to keep filtered_data = [] filtered_steps = [] # Only keep entries with steps less than current global_steps for entry, step in zip(datasets_with_types[dataset_key], datasets_with_types[steps_key]): if step <= self.global_steps: filtered_data.append(entry) filtered_steps.append(step) # Update the datasets 
datasets_with_types[dataset_key] = filtered_data datasets_with_types[steps_key] = filtered_steps # Also filter the step counter dictionaries counter_key = f"{dataset_key}_steps_counter" if counter_key in datasets_with_types: filtered_counter = defaultdict(int) for step, count in datasets_with_types[counter_key].items(): if step <= self.global_steps: filtered_counter[step] = count datasets_with_types[counter_key] = filtered_counter PrettyPrinter.status("FILTER", f"Filtered datasets to only include entries with steps <= {self.global_steps}", "info") ray.get(self.dataset_manager.full_load_data_with_type_counters.remote(datasets_with_types)) PrettyPrinter.status("LOAD", f"Loaded datasets and type counters from {self.config.trainer.default_local_dir}", "success") self.loaded_datasets = True def _save_checkpoint(self): super()._save_checkpoint() # save datasets self._save_datasets(Path(self.config.trainer.default_local_dir) / 'datasets') PrettyPrinter.status("SAVE", f"Saved checkpoint to {self.config.trainer.default_local_dir}", "success") def _load_checkpoint(self): """체크포인트 로딩 - None 체크 추가""" try: super()._load_checkpoint() if self.global_steps == 0: PrettyPrinter.section_header(f"Training from scratch") else: PrettyPrinter.section_header(f"Resuming training from checkpoint, step {self.global_steps}") except AttributeError as e: # global_step_folder가 None인 경우 (첫 실행) if "'NoneType' object has no attribute 'split'" in str(e): self.global_steps = 0 PrettyPrinter.section_header(f"Training from scratch (no checkpoint found)") else: raise # load datasets # first check if all the datasets exist code_dir = Path(self.config.trainer.default_local_dir) / 'code' self._code_dir = code_dir self.loaded_datasets = False if self.config.trainer.resume_mode == 'auto' and os.path.exists(os.path.join(self.config.trainer.default_local_dir, 'datasets', 'datasets.pkl')): self._load_datasets(self.config.trainer.default_local_dir) elif self.config.trainer.resume_mode == 'disable': if 
code_dir.exists(): # delete all files and subdirectories in the code_dir for file in code_dir.glob('**/*'): if file.is_file(): file.unlink() elif file.is_dir(): file.rmdir() PrettyPrinter.status("Directory", f"Cleaned existing code directory at {code_dir}", "info") elif not code_dir.exists(): code_dir.mkdir(parents=True, exist_ok=True) PrettyPrinter.status("Directory", f"Created new code directory at {code_dir}", "info") def scheduler_step(self): if self.config.azr.data_selection_strategy.composite_scheduler.enabled: # Update number of programs - calculate directly based on global steps if self.global_steps >= self.config.azr.data_selection_strategy.composite_scheduler.update_num_programs_start: steps_since_start = self.global_steps - self.config.azr.data_selection_strategy.composite_scheduler.update_num_programs_start num_updates = steps_since_start // self.config.azr.data_selection_strategy.composite_scheduler.update_num_programs_interval # Calculate new value directly from initial value + increments initial_max = self.config.azr.data_selection_strategy.max_programs_initial new_max = min(initial_max + num_updates, self.config.azr.data_selection_strategy.composite_scheduler.num_programs_max) # Only log if value changed if new_max != self.config.azr.data_selection_strategy.composite_function_n_max: current_max = self.config.azr.data_selection_strategy.composite_function_n_max self.config.azr.data_selection_strategy.composite_function_n_max = new_max PrettyPrinter.status("Scheduler", f"Updated max programs from {current_max} to {new_max}", "info") # Update composite probability - calculate directly based on global steps if self.global_steps >= self.config.azr.data_selection_strategy.composite_scheduler.update_probability_start: steps_since_start = self.global_steps - self.config.azr.data_selection_strategy.composite_scheduler.update_probability_start num_updates = steps_since_start // 
self.config.azr.data_selection_strategy.composite_scheduler.update_probability_interval # Calculate new value directly from initial value + increments initial_prob = self.config.azr.data_selection_strategy.composite_chance_initial new_prob = min(initial_prob + (num_updates * self.config.azr.data_selection_strategy.composite_scheduler.update_probability_increment), self.config.azr.data_selection_strategy.composite_scheduler.update_probability_max) # Only log if value changed if new_prob != self.config.azr.data_selection_strategy.composite_chance: current_prob = self.config.azr.data_selection_strategy.composite_chance self.config.azr.data_selection_strategy.composite_chance = new_prob PrettyPrinter.status("Scheduler", f"Updated composite probability from {current_prob:.2f} to {new_prob:.2f}", "info") def save_final_datasets(self): """Save all accumulated datasets at the end of training""" if not (self.config.azr.save_final_datasets and self.config.azr.save_data_path): return try: PrettyPrinter.section_header("Saving Final Datasets") all_datasets = ray.get(self.dataset_manager.get_all_datasets.remote()) save_dir = os.path.join(self.config.azr.save_data_path, "final_datasets") os.makedirs(save_dir, exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") for dataset_name, dataset_data in all_datasets.items(): if dataset_data and isinstance(dataset_data, list): filename = f"{dataset_name}_final_{timestamp}.jsonl" file_path = os.path.join(save_dir, filename) with open(file_path, 'w', encoding='utf-8') as f: for item in dataset_data: # Add final metadata item_with_meta = { **item, 'final_meta': { 'dataset_type': dataset_name, 'total_items': len(dataset_data), 'final_timestamp': timestamp, 'training_completed': True } } f.write(json.dumps(item_with_meta, ensure_ascii=False) + '\n') PrettyPrinter.status("FINAL SAVE", f"Saved {len(dataset_data)} items to {filename}", "success") except Exception as e: PrettyPrinter.status("FINAL SAVE", f"Failed to save final 
datasets: {e}", "error") def cleanup_and_save(self): """Cleanup and save final data before training ends""" self.save_final_datasets() # Any other cleanup operations can be added here # ============ TTRLVR 통합 메서드들 ============ def fit(self): """ ============================================ TTRLVR 통합 학습 루프 - 전체 5 Phase 관리 ============================================ 각 라운드마다: - Phase 1-4: 데이터 생성 (CompleteTestTimePipeline) - Phase 5: PPO 학습 (VeRL) 모든 Phase에서 같은 vLLM 인스턴스 사용 ============================================ """ logger = ReasonRLTracking( project_name=self.config.trainer.project_name, experiment_name=self.config.trainer.experiment_name, default_backend=self.config.trainer.logger, config=OmegaConf.to_container(self.config, resolve=True), tags=self.config.trainer.wandb_tags, resume="must" if self.config.trainer.resume_mode == 'auto' and \ self.config.trainer.wandb_run_id is not None else False, run_id=self.config.trainer.wandb_run_id \ if self.config.trainer.wandb_run_id is not None else None ) self.global_steps = 0 # 체크포인트 로드 self._load_checkpoint() # 토크나이저 설정 if self.config.actor_rollout_ref.model.pretrained_tokenizer: self.tokenizer.chat_template = "{%- for message in messages -%}{{- '\n' if not loop.first -}}{{- message['content'] -}}{%- endfor -%}" # 전체 세션 타임스탬프 설정 (모든 라운드에서 공유) if self.session_timestamp is None: self.session_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') PrettyPrinter.section_header("🚀 Starting Unified TTRLVR Training") PrettyPrinter.status("Config", f"Total rounds: {self.total_rounds}", "info") PrettyPrinter.status("Config", f"Problems: {self.problem_ids}", "info") PrettyPrinter.status("Config", f"Output directory: {self.output_dir}", "info") PrettyPrinter.status("Config", f"Session timestamp: {self.session_timestamp}", "info") # ============================================ # VeRL Workers 초기화 (vLLM 포함) # ============================================ PrettyPrinter.status("Init", "Initializing VeRL workers (Actor, Rollout)...", 
"info") self.init_workers() PrettyPrinter.status("Init", "✅ VeRL workers initialized", "success") # ============================================ # 메인 학습 루프: 각 라운드마다 Phase 1-5 실행 # ============================================ for round_num in range(1, self.total_rounds + 1): self.current_round = round_num logger.log({"round": round_num}, step=self.global_steps) PrettyPrinter.section_header(f"🔄 Round {round_num}/{self.total_rounds}") # ====== Phase 1-4: 데이터 생성 ====== round_start = datetime.now() round_data = self._generate_round_data() # CompleteTestTimePipeline 실행 data_gen_time = (datetime.now() - round_start).total_seconds() if not round_data: PrettyPrinter.status("Warning", "No data generated in Phase 1-4, skipping to next round", "warning") continue PrettyPrinter.status("📊 Phase 1-4 Complete", f"Generated {len(round_data)} training examples in {data_gen_time:.2f}s", "success") # 데이터를 parquet 파일로 저장 (AZR이 읽을 형식) saved_files = self._save_round_data(round_data, round_num) # ====== Phase 5: PPO 학습 (1 epoch) ====== PrettyPrinter.section_header(f"🎯 Phase 5: PPO Training") # Phase 5 로그 디렉토리 준비 (각 문제별 round 디렉토리에 phase5_training 추가) phase5_log_dirs = [] for problem_id in self.problem_ids: benchmark_safe = self.benchmark_config.name.replace('/', '_') problem_safe = problem_id.replace('/', '_') project_root = Path(__file__).parent.parent.parent # TestTime-RLVR-v2 directory phase5_dir = str(project_root / f'tmp/batch_results/ttrlvr_azr_unified_{self.session_timestamp}/{benchmark_safe}/{problem_safe}/round_{round_num}/phase5_training') os.makedirs(phase5_dir, exist_ok=True) phase5_log_dirs.append(phase5_dir) train_start = datetime.now() metrics = self._train_one_round(saved_files, logger, phase5_log_dirs) # saved_files 전달 train_time = (datetime.now() - train_start).total_seconds() PrettyPrinter.status("✅ Phase 5 Complete", f"Training completed in {train_time:.2f}s", "success") # 라운드 전체 메트릭 로깅 total_time = data_gen_time + train_time logger.log({ "round_time/phase_1_4": 
data_gen_time, "round_time/phase_5": train_time, "round_time/total": total_time, "round_data/num_tasks": len(round_data), **metrics }, step=self.global_steps) PrettyPrinter.status("⏱️ Round Summary", f"Total time: {total_time:.2f}s (Phase 1-4: {data_gen_time:.2f}s, Phase 5: {train_time:.2f}s)", "info") # 체크포인트 저장 if round_num % 5 == 0: PrettyPrinter.status("💾 Checkpoint", f"Saving checkpoint at round {round_num}", "info") self._save_checkpoint() # ============================================ # 학습 완료 및 정리 # ============================================ PrettyPrinter.section_header("🏁 Training Complete") self.cleanup_and_save() # ReasonRLTracking doesn't have finish method, only call if available if hasattr(logger, 'finish'): logger.finish() def _generate_round_data(self) -> List[Dict[str, Any]]: """ ============================================ Phase 1-4: 데이터 생성 (CompleteTestTimePipeline 사용) ============================================ Phase 1: 벤치마크 문제에서 다양한 프로그램 생성 Phase 2: 생성된 프로그램에서 I/O 쌍 추출 Phase 3: IPO 트리플 추출 (Input, Program, Output) Phase 4: I/O/P로부터 Induction/Deduction/Abduction task 생성 ============================================ """ # CompleteTestTimePipeline 초기화 (처음 호출 시) if self.ttrlvr_pipeline is None: self._init_ttrlvr_pipeline() all_tasks = [] # 전체 세션의 타임스탬프 사용 (모든 라운드가 동일한 세션에 속함) session_timestamp = self.session_timestamp # 벤치마크 설정 (MBPP) - 기존 TTRLVR config 사용 # 처음 호출 시 benchmark_config 생성 및 저장 if not hasattr(self, 'benchmark_config'): self.benchmark_config = BenchmarkConfig.get_mbpp_config() project_root = Path(__file__).parent.parent.parent # TestTime-RLVR-v2 directory self.benchmark_config.data_path = str(project_root / "evaluation/code_eval/data/MbppPlus.jsonl") for problem_id in self.problem_ids: PrettyPrinter.section_header(f"📝 Phase 1-4: Processing {problem_id}") # 현재 처리중인 문제 ID 저장 (Phase 5 로그에서 사용) self.current_problem_id = problem_id try: # ====== CompleteTestTimePipeline의 로직을 직접 구현 ====== # CompleteTestTimePipeline과 동일한 라운드별 로그 디렉토리 설정 
benchmark_safe = self.benchmark_config.name.replace('/', '_') problem_safe = problem_id.replace('/', '_') project_root = Path(__file__).parent.parent.parent # TestTime-RLVR-v2 directory round_log_dir = str(project_root / f'tmp/batch_results/ttrlvr_azr_unified_{self.session_timestamp}/{benchmark_safe}/{problem_safe}/round_{self.current_round}') # 라운드별 로거 재설정 (CompleteTestTimePipeline과 동일한 구조) self.ttrlvr_logger = TestTimeLogger( task_output_dir=round_log_dir # use_integrated_structure는 TestTimeLogger 내부에서 자동 설정됨 ) # Phase 1-1: 벤치마크 문제 로딩 self.ttrlvr_logger.log_info(f"📄 Phase 1-1: Loading benchmark problem for Round {self.current_round}") problem = self.benchmark_loader.load_problem(self.benchmark_config, problem_id) # Phase 1-2: 베이스라인 성능 측정 (current_evaluation에 저장) # 베이스라인 평가는 항상 실행 (skip_task_evaluation과 무관) self.ttrlvr_logger.log_info(f"📊 Phase 1-2: Baseline performance evaluation") baseline_results = self._evaluate_baseline_performance(problem) # 라운드별 IPO buffer 초기화 (CompleteTestTimePipeline과 동일) self.ttrlvr_logger.log_info(f"🔄 Phase 1-3: Clearing IPO buffer for round {self.current_round}") self.ipo_buffer.clear(problem_id) # Phase 1-3 & Phase 2 & Phase 3: 다양한 프로그램 생성 및 IPO 처리 self.ttrlvr_logger.log_info("🎨 Phase 1-3 → Phase 3: Generating diverse programs and extracting IPO") diverse_programs_results = self._generate_diverse_programs_and_ipo(problem) # Diverse programs 결과를 diverse_programs/ 디렉토리에 저장 (성공/실패 관계없이) self._save_diverse_programs_results(diverse_programs_results) if not diverse_programs_results['success']: self.ttrlvr_logger.log_error(f"❌ No valid diverse programs generated") continue # Phase 3-1: IPO triples 수집 self.ttrlvr_logger.log_info("🎯 Phase 3-1: Collecting IPO triples for task generation") current_round_triples = self.ipo_buffer.get_all(problem_id) self.ttrlvr_logger.log_info(f"🎯 Phase 3-2: Using {len(current_round_triples)} IPO triples from current round") # Phase 4: Task 생성 (Induction/Deduction/Abduction) self.ttrlvr_logger.log_info("📝 Phase 
4: Generating tasks from IPO triples") tasks_dict = self._generate_tasks_from_ipo(current_round_triples, problem_id) # Task는 이미 _save_round_data에서 저장됨 # 총 task 수 계산 total_tasks = sum(len(task_list) for task_list in tasks_dict.values()) result = { 'success': total_tasks > 0, 'final_tasks': tasks_dict } if result['success'] and 'final_tasks' in result: # Phase 4 완료: 검증된 task들 수집 tasks_dict = result['final_tasks'] all_tasks.append(tasks_dict) # Dict 형태로 추가 PrettyPrinter.status("✅ Phase 1-4 Complete", f"Generated {total_tasks} tasks for {problem_id}", "success") # 각 Phase별 결과 로깅 if 'steps' in result: for phase, phase_data in result['steps'].items(): if isinstance(phase_data, dict) and 'success' in phase_data: PrettyPrinter.status(f" - {phase}", f"Success: {phase_data['success']}", "info") else: PrettyPrinter.status("❌ Phase 1-4 Failed", f"Failed for {problem_id}: {result.get('error', 'Unknown error')}", "error") except Exception as e: PrettyPrinter.status("Error", f"Failed processing {problem_id}: {e}", "error") import traceback self.ttrlvr_logger.log_error(f"Traceback: {traceback.format_exc()}") continue PrettyPrinter.status("Phase 1-4 Summary", f"Total tasks generated: {len(all_tasks)}", "success") # Dict 구조 그대로 반환 (_save_round_data에서 처리) return all_tasks def _init_ttrlvr_pipeline(self): """ CompleteTestTimePipeline을 VeRL의 vLLM으로 초기화 중요: VeRL의 actor_rollout_wg가 준비된 후에 호출되어야 함 """ PrettyPrinter.status("Init", "Initializing TTRLVR Pipeline with VeRL's vLLM", "info") # VeRL의 모델과 토크나이저 사용 # actor_rollout_wg가 초기화되어 있어야 함 if not hasattr(self, 'actor_rollout_wg') or self.actor_rollout_wg is None: raise RuntimeError("actor_rollout_wg not initialized. 
Call init_workers() first.") # Option C: CompleteTestTimePipeline의 로직을 직접 구현 # 모델이 필요 없는 헬퍼 컴포넌트 초기화 from absolute_zero_reasoner.testtime.benchmark_loader import BenchmarkProblemLoader from absolute_zero_reasoner.testtime.ipo_extractor import IPOBuffer self.benchmark_loader = BenchmarkProblemLoader(self.testtime_config, self.ttrlvr_logger) self.ipo_buffer = IPOBuffer() # Executor는 이미 UnifiedTTRLVRTrainer.__init__에서 self._executor로 초기화됨 # 여기서는 별도로 초기화하지 않고 기존 것을 사용 self.python_executor = self._executor # 프롬프트 템플릿 준비 (TTRLVR의 것 그대로 사용) # 이미 상단에서 import됨 self.get_prompt = get_prompt self.get_temperature = get_temperature # 다양성 생성을 위한 프롬프트 (CompleteTestTimePipeline과 동일) self.diversity_prompts = [ "Please generate a complete, self-contained Python script that solves the following problem.", "Write a Python solution for this problem using a different approach.", "Create an alternative Python implementation for the given problem.", "Solve this problem with a unique Python solution approach." ] PrettyPrinter.status("Init", "✅ TTRLVR Pipeline initialized", "success") def _generate_diverse_programs_and_ipo(self, problem: Dict[str, Any]) -> Dict[str, Any]: """2-Phase 방식: 프로그램 생성 → 배치 Input Generation""" problem_id = problem.get('task_id', 'unknown') self.problem_id = problem_id # problem_id를 인스턴스 변수로 저장 worker_batch_size = 4 # Ray worker 수에 맞춤 self.ttrlvr_logger.log_info(f"🎨 Generating {self.num_programs} diverse programs for {problem_id} (2-PHASE)") diverse_results = { 'success': False, 'total_programs': self.num_programs, 'valid_programs': 0, 'programs': [], 'total_ipo_triples': 0, 'error': None, 'two_phase': True } try: # ============================================ # Phase 1: 프로그램 생성 및 검증 # ============================================ self.ttrlvr_logger.log_info("📝 Phase 1-3: Generating and validating programs") all_programs = [] valid_programs = [] # TTRLVR의 프롬프트 시스템 사용 import re # 이미 상단에서 import됨 # 모든 프로그램 프롬프트 생성 all_prompts = [] for i in range(self.num_programs): 
diversity_instruction = get_diversity_instruction(i) problem_description = problem.get('prompt', '') # HumanEval인지 MBPP인지 확인 if 'HumanEval' in problem_id: entry_point = problem.get('entry_point', 'unknown') # 프롬프트에서 함수가 여러 개 있는지 확인 function_count = len(re.findall(r'^\s*def\s+\w+', problem_description, re.MULTILINE)) if function_count > 1: prompt = get_prompt("diverse_humaneval_multi", diversity_instruction=diversity_instruction, problem_prompt=problem_description, entry_point=entry_point) else: prompt = get_prompt("diverse_humaneval_basic", diversity_instruction=diversity_instruction, problem_prompt=problem_description) else: # MBPP 다양성 프롬프트 사용 prompt = get_prompt("diverse_mbpp_basic", diversity_instruction=diversity_instruction, problem_prompt=problem_description) all_prompts.append(prompt) # 4개씩 배치로 프로그램 생성 for i in range(0, len(all_prompts), worker_batch_size): batch_prompts = all_prompts[i:i+worker_batch_size] # 4개가 안되면 패딩 while len(batch_prompts) < worker_batch_size: batch_prompts.append(batch_prompts[-1]) # 배치 생성 solutions = self._generate_batch_with_vllm(batch_prompts, temperature=0.7) # 실제 필요한 개수만큼만 처리 및 검증 num_real = min(len(all_prompts) - i, worker_batch_size) for j in range(num_real): solution = solutions[j] if j < len(solutions) else "" prompt = batch_prompts[j] if j < len(batch_prompts) else "" # TTRLVR의 후처리 파이프라인 적용 # 1. 마크다운 코드 블록에서 Python 코드 추출 extracted_solution = self._extract_python_code(solution) if extracted_solution and extracted_solution != solution: self.ttrlvr_logger.log_info(f" - Extracted Python code from markdown block for program {len(all_programs)+1}") solution = extracted_solution # 2. HumanEval의 경우 import 추가 if 'HumanEval' in problem_id: solution = self._add_imports_from_prompt(solution, problem.get('prompt', '')) # 3. 
함수 정의 복구 solution = self._fix_function_definition(solution, prompt, problem_id) # 프로그램 검증만 수행 (input generation은 나중에) program_result = { 'variation_id': len(all_programs), 'solution': solution, # 후처리된 솔루션 (테스트 코드 포함될 수 있음) 'raw_solution': solutions[j] if j < len(solutions) else "", # 원본 LLM 응답 'syntax_valid': False, 'ipo_triples': [], 'problem_prompt': problem.get('prompt', ''), # 원본 문제 'diversity_instruction': batch_prompts[j] if j < len(batch_prompts) else '', # 사용된 프롬프트 'num_ipo_triples': 0, 'num_generated_inputs': 0, 'input_generation_info': [] # Input generation 정보 } # 구문 검증 (후처리된 솔루션에 대해) try: compile(solution, '', 'exec') program_result['syntax_valid'] = True self.ttrlvr_logger.log_info(f" ✅ Program {len(all_programs)+1}: Syntax valid") except SyntaxError as e: program_result['syntax_error'] = str(e) self.ttrlvr_logger.log_info(f" ❌ Program {len(all_programs)+1}: Syntax error - {e}") all_programs.append(program_result) if program_result['syntax_valid']: valid_programs.append(program_result) diverse_results['valid_programs'] += 1 self.ttrlvr_logger.log_info(f"✅ Phase 1-3 complete: {len(valid_programs)} valid programs") # ============================================ # Phase 2: 배치 Input Generation (유효한 프로그램만) # ============================================ if valid_programs: self.ttrlvr_logger.log_info(f"🎲 Phase 2-1: Batch input generation for {len(valid_programs)} programs") # 4개씩 묶어서 input generation for i in range(0, len(valid_programs), worker_batch_size): batch = valid_programs[i:i+worker_batch_size] # Input generation 프롬프트 준비 input_prompts = [] for prog in batch: # 기존 _generate_diverse_inputs_with_model의 프롬프트 생성 부분 재사용 entry_point = problem.get('entry_point', 'solution') func_info = self._extract_function_info(prog['solution'], entry_point) if func_info: # Docstring/Assert에서 예제 추출 (치팅 방지 - base_input 사용 안함) existing_examples = self._extract_examples_from_prompt( problem.get('prompt', ''), entry_point, prog['solution'] ) # 인자 타입 추론 arg_type_info = 
self._infer_argument_types(func_info, existing_examples, prog['solution']) # 프롬프트 생성 prompt = self._create_input_generation_prompt( problem_description=problem.get('prompt', ''), existing_examples=existing_examples, full_code=prog['solution'], arg_type_info=arg_type_info ) input_prompts.append(prompt) prog['func_info'] = func_info # 나중에 사용하기 위해 저장 else: # 빈 프롬프트 대신 기본 프롬프트 사용 default_prompt = "Generate 5 example inputs for a Python function." input_prompts.append(default_prompt) # 4개가 안되면 유효한 프롬프트로 패딩 while len(input_prompts) < worker_batch_size: if input_prompts and input_prompts[-1]: # 마지막이 비어있지 않으면 input_prompts.append(input_prompts[-1]) else: # 마지막이 비어있으면 기본 프롬프트 사용 input_prompts.append("Generate 5 example inputs for a Python function.") # 배치로 input 생성 (모든 프롬프트가 비어있으면 건너뛰기) if not input_prompts or all(not p for p in input_prompts): self.ttrlvr_logger.log_warning(f" ⚠️ Skipping input generation - no valid prompts") continue input_responses = self._generate_batch_with_vllm(input_prompts, temperature=0.5) # 생성된 input 처리 num_real = min(len(batch), worker_batch_size) for j in range(num_real): if j < len(input_responses) and 'func_info' in batch[j]: # Input generation 정보 기록 (기존 TTRLVR 형식과 완전 동일) func_info = batch[j]['func_info'] gen_info = { 'function_info': { 'name': func_info.get('name', 'unknown'), 'parameters': func_info.get('args', []), 'signature': f"def {func_info.get('name', 'unknown')}({', '.join(func_info.get('args', []))}):" }, 'argument_types': func_info.get('arg_types', {}), 'existing_examples': [], 'prompt': input_prompts[j] if j < len(input_prompts) else "", 'response': input_responses[j], 'timestamp': datetime.now().strftime('%Y%m%d_%H%M%S') } # 기존 예제 추가 (problem의 test_list에서) if problem.get('test_list'): for test in problem['test_list'][:1]: # 첫 번째 예제만 example_str = f"Input: {func_info.get('name', 'unknown')}{test['input']} → Output: {test.get('output', 'N/A')}" gen_info['existing_examples'].append(example_str) # 응답에서 입력 파싱 parsed_inputs = 
self._parse_llm_input_response( input_responses[j], batch[j]['func_info'] ) # 파싱된 입력들 기록 gen_info['extracted_inputs'] = parsed_inputs # IPO 트리플 생성 - TTRLVR 방식 (PythonExecutor 사용) ipo_triples = [] entry_point = problem.get('entry_point', 'solution') for input_idx, input_dict in enumerate(parsed_inputs[:5]): # 최대 5개 try: # 기존 TTRLVR 방식: validation 제거, 직접 실행 # 1. 필수 인자 확인 required_args = set(batch[j]['func_info'].get('args', [])) provided_args = set(input_dict.keys()) if not required_args.issubset(provided_args): self.ttrlvr_logger.log_warning(f"Input {input_idx+1} missing required args: {required_args - provided_args}") continue # 2. 입력 준비 (TTRLVR 방식으로 문자열 생성) input_args = [] for arg_name in batch[j]['func_info'].get('args', []): if arg_name in input_dict: input_args.append(input_dict[arg_name]) # 인자를 문자열로 변환 if len(input_args) == 1 and isinstance(input_args[0], list): args_str = repr(input_args[0]) elif len(input_args) == 1: args_str = repr(input_args[0]) else: args_str = ', '.join(repr(arg) for arg in input_args) # 3. 실행 코드 생성 (TTRLVR 방식) execution_code = f""" {batch[j]['solution']} # Execute function with generated inputs try: result = {entry_point}({args_str}) print(repr(result)) except Exception as e: print(f"EXECUTION_ERROR: {{e}}") """ # 4. PythonExecutor로 실행 output, status = self._executor.apply(execution_code) # 5. 
결과 처리 success = 'error' not in status.lower() and 'EXECUTION_ERROR' not in output result = None if success and output: output_lines = output.strip().split('\n') if output_lines: result = output_lines[-1].strip() if success and result is not None: # IPO 트리플 생성 (기존 TTRLVR과 완전 동일한 형식) input_str = ', '.join(repr(arg) for arg in input_args) ipo_triple = { 'id': f"{problem_id}_triple_{len(ipo_triples)}", 'input': input_str, # 문자열 형식으로 'full_input_str': f"{entry_point}({input_str})", 'program': batch[j]['solution'], 'expected_output': str(result), 'actual_output': str(result), 'function_name': entry_point, 'function_args': batch[j]['func_info'].get('args', []), 'is_correct': True, 'extraction_method': 'generated', 'source_program_id': f"program_{batch[j].get('variation_id', j)}", 'ipo_index': len(ipo_triples) } ipo_triples.append(ipo_triple) self.ttrlvr_logger.log_info(f"✅ Input {input_idx+1}: IPO triple created successfully") else: self.ttrlvr_logger.log_warning(f"❌ Input {input_idx+1} execution failed") except Exception as e: # 에러 로깅만 하고 계속 진행 (TTRLVR 방식) self.ttrlvr_logger.log_error(f"Input {input_idx+1} IPO extraction error: {e}") continue # 다음 입력으로 계속 batch[j]['ipo_triples'] = ipo_triples batch[j]['num_ipo_triples'] = len(ipo_triples) # IPO 카운트 업데이트 batch[j]['num_generated_inputs'] = len(ipo_triples) # 생성된 입력 카운트 # Input generation 정보 추가 (기존 TTRLVR 형식) if 'input_generation_info' not in batch[j]: batch[j]['input_generation_info'] = [] batch[j]['input_generation_info'].append(gen_info) # IPO buffer에 추가 (메타데이터 포함) program_id = f"var_{j}" # variation ID를 program ID로 사용 for ipo_idx, triple in enumerate(ipo_triples): # 원본 TTRLVR과 동일하게 메타데이터 추가 triple['source_program_id'] = program_id triple['ipo_index'] = ipo_idx self.ipo_buffer.add(problem_id, triple) self.ttrlvr_logger.log_info(f"✅ Phase 2-3 complete: {len(self.ipo_buffer.get_all(problem_id))} IPO triples extracted") diverse_results['programs'] = all_programs diverse_results['success'] = diverse_results['valid_programs'] 
> 0 diverse_results['total_ipo_triples'] = len(self.ipo_buffer.get_all(problem_id)) except Exception as e: self.ttrlvr_logger.log_error(f"❌ Error in 2-phase generation: {e}") diverse_results['error'] = str(e) return diverse_results def _process_single_program(self, problem: Dict[str, Any], solution: str, variation_id: int) -> Dict[str, Any]: """단일 프로그램 처리 (CompleteTestTimePipeline과 동일)""" from absolute_zero_reasoner.rewards.custom_evaluate import check_correctness program_result = { 'variation_id': variation_id, 'solution': solution, 'syntax_valid': False, 'ipo_triples': [] } # 구문 검증 try: compile(solution, '', 'exec') program_result['syntax_valid'] = True # IPO 추출 (VeRL vLLM 사용) # 기존 TTRLVR처럼 함수만 추출해서 IPO 생성 extracted_function_code = self._extract_function_code(solution) ipo_triples = self._extract_ipo_for_program(problem, extracted_function_code) program_result['ipo_triples'] = ipo_triples program_result['extracted_function'] = extracted_function_code # 저장용 # IPO buffer에 추가 problem_id = problem.get('task_id', 'unknown') for triple in ipo_triples: self.ipo_buffer.add(problem_id, triple) except SyntaxError as e: program_result['syntax_error'] = str(e) return program_result def _extract_ipo_for_program(self, problem: Dict[str, Any], program: str) -> List[Dict]: """프로그램에서 IPO (Input-Program-Output) 트리플 추출""" ipo_triples = [] try: # IPOTripleExtractor의 방식과 동일하게 처리 entry_point = problem.get('entry_point', 'solution') # 1. Prompt에서 공개된 예제 추출 (치팅 방지 - base_input 사용 안함) existing_examples = self._extract_examples_from_prompt( problem.get('prompt', ''), entry_point, program ) # 2. 모델 기반 다양한 입력 생성 diverse_inputs = self._generate_diverse_inputs_with_model( problem, program, existing_examples ) # 3. 
생성된 입력으로 IPO 트리플 생성 for i, input_dict in enumerate(diverse_inputs[:5]): # 최대 5개 try: exec_globals = {} exec(program, exec_globals) if entry_point in exec_globals: func = exec_globals[entry_point] # input_dict에서 인자 추출 if isinstance(input_dict, dict): # 함수 시그니처에 맞게 인자 정렬 import inspect sig = inspect.signature(func) args = [input_dict.get(param, None) for param in sig.parameters] args = [arg for arg in args if arg is not None] if args: output = str(func(*args)) args_str = ', '.join(repr(arg) for arg in args) else: # dict가 비어있으면 기본값 사용 continue else: # 단순 값인 경우 output = str(func(input_dict)) args_str = repr(input_dict) # TaskGenerator가 기대하는 형식 ipo_triples.append({ 'id': f"{problem.get('task_id', 'unknown')}_triple_{i}", 'input': args_str if 'args_str' in locals() else str(input_dict), 'full_input_str': f"{entry_point}({args_str if 'args_str' in locals() else input_dict})", 'program': program, 'actual_output': output, 'source_program_id': f"program_0", 'ipo_index': i }) except Exception as e: self.ttrlvr_logger.log_warning(f"Failed to execute with input {i}: {e}") # 기존 예제도 IPO로 추가 (다양성 보장) for i, (input_str, output) in enumerate(existing_examples): # input_str에서 실제 인자 추출 import re match = re.match(r'\w+\((.*)\)', input_str) if match: args_str = match.group(1) else: args_str = str(input_str) ipo_triples.append({ 'id': f"{problem.get('task_id', 'unknown')}_triple_base_{i}", 'input': args_str, 'full_input_str': input_str, 'program': program, 'actual_output': output, 'source_program_id': f"program_0", 'ipo_index': i }) except Exception as e: # self.ttrlvr_logger.log_error(f"IPO extraction failed: {e}") pass return ipo_triples def _generate_diverse_inputs_with_model(self, problem: Dict[str, Any], program: str, existing_examples: List[Tuple[str, str]]) -> List[Dict]: """모델을 사용하여 다양한 입력 생성 (TTRLVR CompleteTestTimePipeline과 동일)""" generated_inputs = [] try: # 함수 시그니처 추출 entry_point = problem.get('entry_point', 'solution') func_info = self._extract_function_info(program, 
entry_point) if not func_info: return [] # 인자 타입 추론 arg_type_info = self._infer_argument_types(func_info, existing_examples, program) # 입력 생성 프롬프트 생성 prompt = self._create_input_generation_prompt( problem_description=problem.get('prompt', ''), existing_examples=existing_examples, full_code=program, arg_type_info=arg_type_info ) # VeRL의 vLLM을 사용하여 생성 (4의 배수 제약 처리) # _generate_with_vllm 함수 사용 (내부적으로 4개로 패딩) response = self._generate_with_vllm(prompt, temperature=0.7) # 응답에서 입력 예제 파싱 parsed_inputs = self._parse_llm_input_response(response, func_info) # 검증된 입력만 반환 for input_dict in parsed_inputs: if self._validate_input(input_dict, func_info, program): generated_inputs.append(input_dict) except Exception as e: self.ttrlvr_logger.log_warning(f"Model-based input generation failed: {e}, using fallback") # Fallback: 간단한 기본 입력 생성 import random for i in range(3): generated_inputs.append({ 'value': random.randint(1, 100) }) return generated_inputs def _extract_function_info(self, code: str, entry_point: str) -> Optional[Dict]: """함수 정보 추출 - TTRLVR 방식""" import ast try: # AST로 함수 정의 파싱 tree = ast.parse(code) # Entry point 함수 우선 검색 target_function = None all_functions = [] for node in ast.walk(tree): if isinstance(node, ast.FunctionDef): func_info = { 'name': node.name, 'args': [arg.arg for arg in node.args.args], 'signature': f"def {node.name}({', '.join([arg.arg for arg in node.args.args])}):", 'full_code': code, # TTRLVR처럼 전체 코드 포함 'defaults': node.args.defaults, 'return_type': None } all_functions.append(func_info) # Entry point와 일치하는 함수 우선 선택 if entry_point and node.name == entry_point: target_function = func_info break # Entry point 함수를 찾았으면 반환 if target_function: return target_function # Entry point를 찾지 못했으면 첫 번째 함수 반환 if all_functions: self.ttrlvr_logger.log_debug(f"Entry point '{entry_point}' not found, using first function: {all_functions[0]['name']}") return all_functions[0] except Exception as e: self.ttrlvr_logger.log_debug(f"Function parsing failed: {e}") return 
None def _infer_argument_types(self, func_info: Dict, existing_examples: List, code: str) -> Dict[str, str]: """기존 예제에서 인자 타입 추론""" arg_types = {} # 기본 타입 추론 for arg in func_info['args']: arg_types[arg] = 'Any' # 기본값 # 예제에서 타입 추론 if existing_examples: # 첫 번째 예제 분석 import re for input_str, _ in existing_examples[:1]: match = re.match(r'\w+\((.*)\)', input_str) if match: args_str = match.group(1) try: # 안전한 평가 import ast args = ast.literal_eval(f"[{args_str}]") for i, (arg_name, arg_val) in enumerate(zip(func_info['args'], args)): if isinstance(arg_val, int): arg_types[arg_name] = 'int' elif isinstance(arg_val, float): arg_types[arg_name] = 'float' elif isinstance(arg_val, str): arg_types[arg_name] = 'str' elif isinstance(arg_val, list): arg_types[arg_name] = 'List' elif isinstance(arg_val, dict): arg_types[arg_name] = 'Dict' except: pass return arg_types def _create_input_generation_prompt(self, problem_description: str, existing_examples: List[Tuple[str, str]], full_code: str, arg_type_info: Dict[str, str]) -> str: """입력 생성 프롬프트 (IPOTripleExtractor와 동일)""" examples_text = "" for i, (input_str, output_str) in enumerate(existing_examples): examples_text += f"Example {i+1}:\n" examples_text += f"Input: {input_str}\n" examples_text += f"Output: {output_str}\n\n" arg_type_text = "Argument types:\n" for arg, arg_type in arg_type_info.items(): arg_type_text += f"- {arg}: {arg_type}\n" prompt = f"""Given the following problem description and its Python function implementation, first analyze the types and valid ranges of the function arguments, then write **5 different example inputs** for the function that cover a diverse mix of typical (general) cases and edge/boundary cases. Problem Description: ''' {problem_description} ''' Existing Examples from Problem: {examples_text} Function Implementation: ```python {full_code} ``` {arg_type_text} Please generate 5 diverse test inputs, considering purpose of the python code. 
Format them as a Python list of dictionaries where each dictionary contains the function arguments. Format your response as: ```python examples = [ {{dict_with_all_function_parameters}}, # Description of this test case {{dict_with_all_function_parameters}}, # Description of this test case ... # Continue for all 5 examples ] ``` Ensure your examples include: - At least 2 typical/general cases - At least 2 edge/boundary cases - 1 special case (empty, zero, maximum values, etc.) - All examples should be DIFFERENT from the existing examples """ # prompt = f"""Given the following problem description and its Python function implementation, first analyze the types and valid ranges of the function arguments, then write **5 different example inputs** for the function that cover a diverse mix of typical (general) cases and edge/boundary cases. # Problem Description: # ''' # {problem_description} # ''' # Existing Examples from Problem: # {examples_text} # Function Implementation: # ```python # {full_code} # ``` # {arg_type_text} # Please generate 5 diverse test inputs. Format them as a Python list of dictionaries where each dictionary contains the function arguments. # ```python # examples = [ # # Your examples here # ] # ```""" return prompt def _parse_llm_input_response(self, response: str, func_info: Dict) -> List[Dict]: """LLM 응답에서 입력 예제 파싱""" parsed_inputs = [] try: # ```python ... 
``` 블록에서 코드 추출 import re code_pattern = r'```python\n(.*?)\n```' matches = re.findall(code_pattern, response, re.DOTALL) if matches: code = matches[0] exec_globals = {} exec(code, exec_globals) examples = exec_globals.get('examples', []) # 예제를 입력 딕셔너리로 변환 for example in examples: if isinstance(example, dict): parsed_inputs.append(example) elif isinstance(example, (list, tuple)): # 리스트/튜플인 경우 함수 인자에 맞게 매핑 input_dict = {} for i, (arg_name, val) in enumerate(zip(func_info['args'], example)): input_dict[arg_name] = val parsed_inputs.append(input_dict) else: # 단일 값인 경우 if func_info['args']: parsed_inputs.append({func_info['args'][0]: example}) except Exception as e: self.ttrlvr_logger.log_warning(f"Failed to parse LLM response with exec: {e}") # Fallback: 키에 따옴표 추가 후 재시도 try: if matches: code = matches[0] # 각 파라미터 이름에 따옴표 추가 for arg in func_info.get('args', []): # arg: 패턴을 "arg": 로 변경 code = re.sub(rf'\b{arg}\s*:', f'"{arg}":', code) exec_globals = {} exec(code, exec_globals) examples = exec_globals.get('examples', []) for example in examples: if isinstance(example, dict): parsed_inputs.append(example) if parsed_inputs: self.ttrlvr_logger.log_info(f"Successfully parsed {len(parsed_inputs)} inputs with fallback") except Exception as e2: self.ttrlvr_logger.log_warning(f"Fallback parsing also failed: {e2}") return parsed_inputs def _extract_examples_from_prompt(self, prompt: str, entry_point: str, solution: str) -> List[Tuple[str, str]]: """ Prompt에서 공개된 예제만 추출 (치팅 방지) 추출 순서: 1. >>> 패턴 (HumanEval docstring) 2. assert 패턴 (MBPP) 3. 예제가 없으면 빈 리스트 반환 """ examples = [] # 1. Docstring >>> 예제 추출 (HumanEval) lines = prompt.split('\n') for i in range(len(lines)): line = lines[i].strip() # >>> func_name(...) 
패턴 찾기 if line.startswith('>>>') and entry_point in line: # >>> 제거하고 입력 추출 input_line = line[3:].strip() # 다음 줄에서 출력 추출 if i + 1 < len(lines): output_line = lines[i + 1].strip() # 출력이 >>> 로 시작하지 않으면 출력값으로 간주 if not output_line.startswith('>>>'): # 실제 실행해서 검증 try: # 입력에서 인자 추출 import re match = re.match(rf'{entry_point}\((.*)\)', input_line) if match: args_str = match.group(1) # 안전한 평가 import ast try: args = ast.literal_eval(f"[{args_str}]") except: args = eval(f"[{args_str}]") # 실행해서 실제 output 계산 actual_output = self._execute_with_solution(solution, entry_point, args) if actual_output is not None: examples.append((input_line, str(actual_output))) else: # 실행 실패시 원본 사용 examples.append((input_line, output_line)) except Exception as e: self.ttrlvr_logger.log_debug(f"Failed to parse >>> example: {e}") # 파싱 실패시 원본 그대로 사용 examples.append((input_line, output_line)) # 2. Assert 문 예제 추출 (MBPP) if not examples: import re # assert 패턴: assert func_name(...) == expected # 또는 assert set(func_name(...)) == set(...) assert_patterns = [ rf'assert\s+set\({entry_point}\((.*?)\)\)\s*==\s*set\((.*?)\)', rf'assert\s+{entry_point}\((.*?)\)\s*==\s*(.*?)(?:\n|$)', rf'assert\s+.*{entry_point}\s*\((.*?)\)\s*==\s*(.*?)(?:\n|$)' ] for pattern in assert_patterns: matches = re.findall(pattern, prompt, re.MULTILINE | re.DOTALL) for args_str, expected in matches[:2]: # 최대 2개 try: # 인자 파싱 args = eval(f"[{args_str}]") # 실행해서 실제 output 계산 actual_output = self._execute_with_solution(solution, entry_point, args) if actual_output is not None: input_str = f"{entry_point}({args_str})" examples.append((input_str, str(actual_output))) except Exception as e: self.ttrlvr_logger.log_debug(f"Failed to parse assert example: {e}") # 3. 
예제가 없어도 괜찮음 (빈 리스트 반환) if not examples: self.ttrlvr_logger.log_info(f"No examples found in prompt for {entry_point}") else: self.ttrlvr_logger.log_info(f"Extracted {len(examples)} examples from prompt for {entry_point}") return examples def _execute_with_solution(self, solution: str, entry_point: str, args: List) -> Optional[Any]: """주어진 solution으로 함수 실행""" try: # 실행 코드 생성 if len(args) == 1 and isinstance(args[0], list): args_str = repr(args[0]) elif len(args) == 1: args_str = repr(args[0]) else: args_str = ', '.join(repr(arg) for arg in args) execution_code = f""" {solution} try: result = {entry_point}({args_str}) print(repr(result)) except Exception as e: print(f"EXECUTION_ERROR: {{e}}") """ # PythonExecutor로 실행 output, status = self._executor.apply(execution_code) # 결과 처리 if 'error' not in status.lower() and 'EXECUTION_ERROR' not in output: output_lines = output.strip().split('\n') if output_lines: result_str = output_lines[-1].strip() # 안전한 평가 try: import ast return ast.literal_eval(result_str) except: return result_str return None except Exception as e: self.ttrlvr_logger.log_debug(f"Failed to execute solution: {e}") return None def _validate_input(self, input_dict: Dict, func_info: Dict, program: str) -> bool: """ [DEPRECATED] 생성된 입력 검증 기존 TTRLVR 방식으로 변경됨 - validation 없이 직접 실행 시도 이 함수는 더 이상 사용되지 않지만 호환성을 위해 유지 """ try: # 필수 인자만 확인 required_args = set(func_info['args']) provided_args = set(input_dict.keys()) if not required_args.issubset(provided_args): return False return True # 기본적인 체크만 수행 except: return False def _generate_tasks_from_ipo(self, ipo_triples: List[Dict], problem_id: str) -> Dict[str, List[Dict]]: """IPO에서 Task 생성 - 실제 TTRLVR의 TaskGenerator와 완전히 동일한 프롬프트 사용""" from absolute_zero_reasoner.data_construction.prompts import get_code_problem_predictor_prompt from absolute_zero_reasoner.testtime.task_generator import TestTimeTaskGenerator # 실제 TaskGenerator 인스턴스 생성 (프롬프트 생성 로직 재사용) if not hasattr(self, 'task_generator'): self.task_generator = 
TestTimeTaskGenerator(self.testtime_config, self.ttrlvr_logger) # TaskGenerator의 generate_tasks 메서드 직접 사용 all_tasks = self.task_generator.generate_tasks(ipo_triples, problem_id, self.current_round) # all_tasks는 {'induction': [...], 'deduction': [...], 'abduction': [...]} 형태 # 원본 TTRLVR과 동일하게 Dict 구조 그대로 반환 return all_tasks def _save_diverse_programs_results(self, diverse_results: Dict[str, Any]): """다양한 프로그램 생성 결과를 diverse_programs/ 디렉토리에 저장 (기존 TTRLVR과 동일한 구조)""" try: problem_id = self.problem_id if hasattr(self, 'problem_id') else 'unknown' # Diverse programs 디렉토리 생성 diverse_dir = os.path.join(self.ttrlvr_logger.log_dir, "diverse_programs") os.makedirs(diverse_dir, exist_ok=True) # 요약 파일 생성 (batch evaluation과 동일) summary_file = os.path.join(diverse_dir, "diverse_summary.txt") with open(summary_file, 'w', encoding='utf-8') as f: f.write(f"Diverse Programs Summary\n") f.write(f"Problem ID: {problem_id}\n") f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n") f.write(f"Total Programs: {len(diverse_results.get('programs', []))}\n") f.write("="*50 + "\n\n") for i, program in enumerate(diverse_results.get('programs', []), 1): f.write(f"Program {i}: {'✅ Valid' if program.get('syntax_valid', False) else '❌ Invalid'}\n") f.write(f"IPO Triples: {program.get('num_ipo_triples', 0)}\n") f.write(f"Generated Inputs: {program.get('num_generated_inputs', 0)}\n\n") # 각 프로그램별 디렉토리와 파일 생성 (batch evaluation과 동일한 구조) for i, program in enumerate(diverse_results.get('programs', []), 1): program_dir = os.path.join(diverse_dir, f"program_{i}") os.makedirs(program_dir, exist_ok=True) # 1. generation_details.txt (batch evaluation과 동일한 형식) details_file = os.path.join(program_dir, "generation_details.txt") with open(details_file, 'w', encoding='utf-8') as f: f.write(f"Diverse Program {i} - Generation Details\n") f.write(f"Problem ID: {problem_id}\n") f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n") f.write("="*80 + "\n\n") f.write("1. 
ORIGINAL PROBLEM:\n") f.write("="*80 + "\n") f.write(program.get('problem_prompt', 'No prompt available') + "\n") f.write("="*80 + "\n\n") f.write("2. DIVERSITY PROMPT USED:\n") f.write("="*80 + "\n") f.write(program.get('diversity_instruction', 'Standard generation') + "\n") f.write("="*80 + "\n\n") f.write("3. LLM RAW RESPONSE:\n") f.write("="*80 + "\n") f.write(program.get('raw_solution', program.get('solution', 'N/A')) + "\n") f.write("="*80 + "\n\n") f.write("4. PROCESSED SOLUTION (Function only, no test code):\n") f.write("="*80 + "\n") # extracted_function이 있으면 사용, 없으면 solution 사용 f.write(program.get('extracted_function', program.get('solution', 'N/A')) + "\n") f.write("="*80 + "\n\n") f.write("5. EVALUATION RESULTS:\n") f.write("="*80 + "\n") f.write(f"Syntax Valid: {'✅ YES' if program.get('syntax_valid', False) else '❌ NO'}\n") if program.get('syntax_error'): f.write(f"Syntax Error: {program['syntax_error']}\n") f.write(f"IPO Triples Generated: {program.get('num_ipo_triples', len(program.get('ipo_triples', [])))}\n") f.write(f"Input Generation: {program.get('num_generated_inputs', 0)} new inputs\n") f.write("="*80 + "\n") # 2. solution.py solution_file = os.path.join(program_dir, "solution.py") with open(solution_file, 'w', encoding='utf-8') as f: f.write(f"# Diverse Program {i}\n") f.write(f"# Problem ID: {problem_id}\n") f.write(f"# Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n") f.write(f"# Syntax Valid: {program.get('syntax_valid', False)}\n") f.write(f"# IPO Triples: {program.get('num_ipo_triples', len(program.get('ipo_triples', [])))}\n\n") # 추출된 함수 코드가 있으면 사용, 없으면 원본 솔루션 사용 f.write(program.get('solution', '# No solution available')) # 3. 
ipo_triples 디렉토리와 파일들 if program.get('ipo_triples'): ipo_dir = os.path.join(program_dir, "ipo_triples") os.makedirs(ipo_dir, exist_ok=True) for j, triple in enumerate(program['ipo_triples'], 1): triple_file = os.path.join(ipo_dir, f"triple_{j}.json") with open(triple_file, 'w', encoding='utf-8') as f: json.dump(triple, f, indent=2) # 4. input_generation_details.txt 파일 생성 (기존 TTRLVR과 완전 동일한 형식) if program.get('input_generation_info'): details_file = os.path.join(program_dir, "input_generation_details.txt") with open(details_file, 'w', encoding='utf-8') as f: f.write(f"Input Generation Details - Program {i}\n") f.write(f"Problem ID: {problem_id}\n") f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n") f.write("="*80 + "\n\n") # 각 라운드의 정보 저장 (기존 TTRLVR 형식) for round_idx, gen_info in enumerate(program.get('input_generation_info', []), 1): f.write(f"ROUND {round_idx}:\n") f.write("="*80 + "\n\n") # 1. FUNCTION INFO f.write("1. FUNCTION INFO:\n") f.write("="*80 + "\n") func_info = gen_info.get('function_info', {}) f.write(f"Function Name: {func_info.get('name', 'unknown')}\n") f.write(f"Parameters: {func_info.get('parameters', [])}\n") f.write(f"Parameters String: {func_info.get('signature', 'N/A')}\n\n") # 2. ARGUMENT TYPE INFO f.write("2. ARGUMENT TYPE INFO:\n") f.write("="*80 + "\n") f.write("Argument types:\n") for arg_name, arg_type in gen_info.get('argument_types', {}).items(): f.write(f"- {arg_name}: {arg_type}\n") if not gen_info.get('argument_types'): f.write("(No type information available)\n") f.write("\n") # 3. EXISTING EXAMPLES f.write("3. EXISTING EXAMPLES:\n") f.write("="*80 + "\n") for ex_idx, example in enumerate(gen_info.get('existing_examples', []), 1): f.write(f"Example {ex_idx}: {example}\n") if not gen_info.get('existing_examples'): f.write("(No existing examples)\n") f.write("\n") # 4. LLM PROMPT f.write("4. 
LLM PROMPT:\n") f.write("="*80 + "\n") f.write(gen_info.get('prompt', 'No prompt available')) f.write("\n") f.write("="*80 + "\n\n") # 5. LLM RESPONSE f.write("5. LLM RESPONSE:\n") f.write("="*80 + "\n") f.write(gen_info.get('response', 'No response available')) f.write("\n") f.write("="*80 + "\n\n") # 6. EXTRACTED INPUTS f.write("6. EXTRACTED INPUTS:\n") f.write("="*80 + "\n") for inp_idx, inp in enumerate(gen_info.get('extracted_inputs', []), 1): f.write(f"Input {inp_idx}: {inp}\n") if not gen_info.get('extracted_inputs'): f.write("(No inputs extracted)\n") f.write("\n") self.ttrlvr_logger.log_info(f"📝 Diverse programs evaluation saved to: {diverse_dir}") except Exception as e: self.ttrlvr_logger.log_error(f"Failed to save diverse programs evaluation: {e}") def _save_task_responses(self, task_type: str, responses: List[str]): """Task 응답을 llm_responses/ 디렉토리에 저장""" for i, response in enumerate(responses): response_file = f"{task_type}_response_{i}.txt" response_path = os.path.join(self.ttrlvr_logger.log_dir, "llm_responses", response_file) os.makedirs(os.path.dirname(response_path), exist_ok=True) with open(response_path, 'w', encoding='utf-8') as f: f.write(response) def _save_azr_training_data(self, tasks: List[Dict]): """AZR 학습 데이터를 azr_training_data/ 디렉토리에 저장""" # Task type별로 그룹화 from collections import defaultdict tasks_by_type = defaultdict(list) for task in tasks: task_type = task.get('task_type', 'unknown') tasks_by_type[task_type].append(task) # 각 타입별로 parquet 파일로 저장 import pandas as pd for task_type, task_list in tasks_by_type.items(): df = pd.DataFrame(task_list) parquet_path = os.path.join( self.ttrlvr_logger.log_dir, "azr_training_data", f"{task_type}.parquet" ) os.makedirs(os.path.dirname(parquet_path), exist_ok=True) df.to_parquet(parquet_path, index=False) def _save_phase5_logs(self, phase5_log_dirs: List[str], step_metrics: List[Dict], generated_responses: List[Dict], aggregated_metrics: Dict): """Phase 5 학습 로그를 각 문제별 phase5_training/ 디렉토리에 저장""" 
# 모든 문제의 통합 메트릭이므로 각 디렉토리에 동일하게 저장 for log_dir in phase5_log_dirs: try: # 1. 학습 메트릭 저장 (JSON) metrics_file = os.path.join(log_dir, 'training_metrics.json') with open(metrics_file, 'w', encoding='utf-8') as f: json.dump({ 'aggregated': aggregated_metrics, 'steps': step_metrics }, f, indent=2, default=str) # 2. PPO losses 저장 (별도 파일) ppo_losses = { 'actor_loss': [], 'critic_loss': [], 'kl_divergence': [], 'entropy': [] } for metrics in step_metrics: for key in ppo_losses.keys(): if key in metrics: ppo_losses[key].append(metrics[key]) losses_file = os.path.join(log_dir, 'ppo_losses.json') with open(losses_file, 'w', encoding='utf-8') as f: json.dump(ppo_losses, f, indent=2) # 3. Rewards 저장 rewards_file = os.path.join(log_dir, 'rewards.json') rewards_data = { 'step_rewards': [m.get('reward_mean', 0) for m in step_metrics], 'average_reward': aggregated_metrics.get('reward/mean', 0), 'max_reward': max([m.get('reward_mean', 0) for m in step_metrics]) if step_metrics else 0, 'min_reward': min([m.get('reward_mean', 0) for m in step_metrics]) if step_metrics else 0 } with open(rewards_file, 'w', encoding='utf-8') as f: json.dump(rewards_data, f, indent=2) # 4. 
생성된 응답 샘플 저장 if generated_responses: responses_dir = os.path.join(log_dir, 'generated_responses') os.makedirs(responses_dir, exist_ok=True) for resp_data in generated_responses: step_num = resp_data['step'] resp_file = os.path.join(responses_dir, f'step_{step_num:04d}.json') with open(resp_file, 'w', encoding='utf-8') as f: json.dump(resp_data, f, indent=2) self.ttrlvr_logger.log_info(f"✅ Phase 5 logs saved to {log_dir}") except Exception as e: self.ttrlvr_logger.log_error(f"Failed to save Phase 5 logs to {log_dir}: {e}") def _save_phase5_batch_to_llm_responses(self, phase5_log_dirs: List[str], input_ids: List, response_ids: List, step_num: int, round_num: int): """Phase 5 배치 응답을 step별 디렉토리에 저장하고 정답률 분석""" from datetime import datetime for log_dir in phase5_log_dirs: try: # llm_responses/step_{num} 디렉토리 생성 llm_responses_dir = os.path.dirname(log_dir) # phase5_training -> round_N step_responses_path = os.path.join(llm_responses_dir, "llm_responses", f"step_{step_num:04d}") os.makedirs(step_responses_path, exist_ok=True) # 이 step의 평가 결과를 수집할 리스트 step_evaluations = [] # 각 배치 샘플을 개별 파일로 저장 (CompleteTestTimePipeline 형식) for i, (input_id, response_id) in enumerate(zip(input_ids, response_ids)): try: # 프롬프트와 응답 디코딩 prompt = self.tokenizer.decode(input_id, skip_special_tokens=True) response = self.tokenizer.decode(response_id, skip_special_tokens=True) # Task 정보 추출 (프롬프트에서) task_type, task_info = self._analyze_prompt_and_response(prompt, response) problem_id = getattr(self, 'current_problem_id', f'Problem_Unknown') if '/' in problem_id: problem_id = problem_id.replace('/', '_') # 개별 응답 파일 저장 response_file = os.path.join( step_responses_path, f"{problem_id}_{task_type}_sample{i+1}_response.txt" ) # 정답 여부 판단 - TTRLVRRewardManager가 실제 평가를 수행 # 여기서는 간단히 응답 존재 여부만 확인 is_correct = bool(response and response.strip()) accuracy_score = 1.0 if is_correct else 0.0 with open(response_file, 'w', encoding='utf-8') as f: # CompleteTestTimePipeline과 완전히 동일한 형식 f.write(f"Task Type: 
{task_type}\n") f.write(f"Task ID: phase5_round{round_num}_step{step_num}_sample{i+1}\n") f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n") f.write("="*80 + "\nORIGINAL PROMPT:\n") f.write("="*80 + "\n") f.write(prompt) f.write("\n" + "="*80 + "\n") f.write("LLM RESPONSE:\n") f.write("="*80 + "\n") f.write(response) f.write("\n" + "="*80 + "\n") f.write("EXPECTED SOLUTION:\n") f.write("="*80 + "\n") if task_info.get('expected_answer'): f.write(task_info['expected_answer']) else: f.write("(Phase 5: Expected solution extracted from prompt)") f.write("\n" + "="*80 + "\n") f.write("EXTRACTED ANSWER:\n") f.write("="*80 + "\n") f.write(self._extract_answer_from_response(response, task_type)) f.write("\n" + "="*80 + "\n") f.write("MATCH RESULT:\n") f.write("="*80 + "\n") if is_correct: f.write(f"✅ CORRECT (Score: {accuracy_score:.3f})") else: f.write(f"❌ INCORRECT (Score: {accuracy_score:.3f})") # 평가 결과 수집 step_evaluations.append({ 'sample_id': f"sample{i+1}", 'task_type': task_type, 'is_correct': is_correct, 'accuracy_score': accuracy_score, 'response_file': os.path.basename(response_file) }) except Exception as e: self.ttrlvr_logger.log_warning(f"Failed to save sample {i} from step {step_num}: {e}") # Step 요약 파일 생성 self._save_step_summary(step_responses_path, step_evaluations, step_num, round_num) except Exception as e: self.ttrlvr_logger.log_error(f"Failed to save batch responses to {log_dir}: {e}") def _analyze_prompt_and_response(self, prompt: str, response: str) -> tuple: """프롬프트에서 task type과 task 정보 추출""" task_type = "unknown" task_info = {} # Task type 판단 (프롬프트 내용 기반) prompt_lower = prompt.lower() if "given the following input-output pairs" in prompt_lower and "write a function" in prompt_lower: task_type = "induction" elif "given the following function" in prompt_lower and "what is the output" in prompt_lower: task_type = "deduction" elif "given the following function" in prompt_lower and "what input" in prompt_lower: task_type = "abduction" 
elif any(keyword in prompt_lower for keyword in ["code_f", "function", "implement"]): task_type = "induction" # 기본값 else: task_type = "general" # Expected answer 추출 시도 try: if "expected output:" in prompt_lower: lines = prompt.split('\n') for i, line in enumerate(lines): if "expected output:" in line.lower() and i + 1 < len(lines): task_info['expected_answer'] = lines[i + 1].strip() break except: pass return task_type, task_info # 휴리스틱 평가 메서드 제거됨 - TTRLVRRewardManager가 실제 평가 수행 def _extract_answer_from_response(self, response: str, task_type: str) -> str: """응답에서 핵심 답안 추출""" response_clean = response.strip() if task_type == "induction": # 함수 정의 부분 추출 lines = response_clean.split('\n') func_lines = [line for line in lines if line.strip().startswith('def ')] if func_lines: return func_lines[0].strip() # 기본적으로 첫 번째 라인이나 전체 응답 first_line = response_clean.split('\n')[0].strip() if len(first_line) > 100: return first_line[:100] + "..." return first_line or response_clean[:100] def _save_step_summary(self, step_dir: str, evaluations: List[Dict], step_num: int, round_num: int): """Step별 요약 파일 생성""" # 통계 계산 total_samples = len(evaluations) if total_samples == 0: return correct_count = sum(1 for eval in evaluations if eval['is_correct']) overall_accuracy = correct_count / total_samples # Task type별 통계 task_type_stats = {} for eval in evaluations: task_type = eval['task_type'] if task_type not in task_type_stats: task_type_stats[task_type] = {'total': 0, 'correct': 0, 'accuracy_sum': 0.0} task_type_stats[task_type]['total'] += 1 if eval['is_correct']: task_type_stats[task_type]['correct'] += 1 task_type_stats[task_type]['accuracy_sum'] += eval['accuracy_score'] # Task type별 정확도 계산 for task_type in task_type_stats: stats = task_type_stats[task_type] stats['accuracy'] = stats['correct'] / stats['total'] stats['avg_score'] = stats['accuracy_sum'] / stats['total'] # 요약 파일 저장 summary_file = os.path.join(step_dir, "step_summary.json") summary_data = { 'step_info': { 'step_number': 
step_num, 'round_number': round_num, 'total_samples': total_samples, 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S') }, 'overall_stats': { 'total_samples': total_samples, 'correct_samples': correct_count, 'overall_accuracy': overall_accuracy, 'average_score': sum(eval['accuracy_score'] for eval in evaluations) / total_samples }, 'task_type_stats': task_type_stats, 'sample_details': evaluations } with open(summary_file, 'w', encoding='utf-8') as f: json.dump(summary_data, f, indent=2, ensure_ascii=False) # 텍스트 요약도 저장 summary_txt_file = os.path.join(step_dir, "step_summary.txt") with open(summary_txt_file, 'w', encoding='utf-8') as f: f.write(f"=== Step {step_num} Summary (Round {round_num}) ===\n\n") f.write(f"Overall Results:\n") f.write(f" Total Samples: {total_samples}\n") f.write(f" Correct: {correct_count} ({overall_accuracy:.1%})\n") f.write(f" Average Score: {summary_data['overall_stats']['average_score']:.3f}\n\n") f.write(f"Task Type Breakdown:\n") for task_type, stats in task_type_stats.items(): f.write(f" {task_type.upper()}:\n") f.write(f" Total: {stats['total']}\n") f.write(f" Correct: {stats['correct']} ({stats['accuracy']:.1%})\n") f.write(f" Avg Score: {stats['avg_score']:.3f}\n\n") self.ttrlvr_logger.log_info( f"📊 Step {step_num} Summary: {correct_count}/{total_samples} correct ({overall_accuracy:.1%})" ) def _train_one_round( self, saved_files: Dict[str, str], # Parquet 파일 경로들 logger, phase5_log_dirs: List[str] = None ) -> Dict[str, float]: """ ============================================ Phase 5: PPO 학습 (기존 방식 유지) ============================================ - Parquet 파일에서 통합 DataLoader 생성 - 모든 task를 섞어서 순차 처리 - 세밀한 제어와 로깅 유지 ============================================ """ PrettyPrinter.status("Phase 5", "Calculating PPO parameters", "info") # PPO mini-batch size 자동 계산 추가 train_batch_size = self.config.data.train_batch_size # 원래 값 백업 original_ppo_mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size # 자동 계산: 
train_batch_size × 3 (induction, deduction, abduction) # 각 IPO에서 3가지 task type이 생성되므로 calculated_ppo_mini_batch_size = train_batch_size * 3 # Config에 적용 self.config.actor_rollout_ref.actor.ppo_mini_batch_size = calculated_ppo_mini_batch_size # 로깅 PrettyPrinter.status("PPO Config", f"ppo_mini_batch_size: {original_ppo_mini_batch_size} → {calculated_ppo_mini_batch_size}", "info") PrettyPrinter.status("PPO Config", f" - train_batch_size: {train_batch_size}", "info") PrettyPrinter.status("PPO Config", f" - task types: 3 (induction, deduction, abduction)", "info") # Data length 자동 계산 (update_iteration 적용) update_iteration = getattr(self.config.azr.data_selection_strategy, 'update_iteration', 1) # data_len이 이미 설정되어 있을 수도 있으므로 원래 값 백업 if hasattr(self.config.azr.data_selection_strategy, 'data_len'): original_data_len = self.config.azr.data_selection_strategy.data_len else: original_data_len = train_batch_size # 기본값 # data_len = train_batch_size × update_iteration calculated_data_len = train_batch_size * update_iteration self.config.azr.data_selection_strategy.data_len = calculated_data_len # 로깅 PrettyPrinter.status("PPO Config", f"data_len: {original_data_len} → {calculated_data_len}", "info") PrettyPrinter.status("PPO Config", f" - update_iteration: {update_iteration}", "info") PrettyPrinter.status("Phase 5", "Creating task-separated DataLoaders from parquet files", "info") # ====== Task별 분리 처리 (부모 클래스 방식) ====== epoch_metrics = {} all_step_metrics = [] # Phase 5 로그 저장용 all_generated_responses = [] # 생성된 응답 저장용 # Task별 분리된 DataLoader 생성 (부모 클래스 azr_ray_trainer.py Line 694-758 그대로 복사) self._create_ttrlvr_dataloaders_from_parent(saved_files) # 총 스텝 수 계산 (가장 긴 dataloader 기준) total_steps = max(len(loader) for loader in self.ttrlvr_dataloaders.values()) if self.ttrlvr_dataloaders else 0 PrettyPrinter.status("Phase 5", f"Starting training with {total_steps} steps", "info") # 부모 클래스처럼 task별로 처리 (azr_ray_trainer.py Line 2136-2174) for step in range(total_steps): step_num = step + 
1 # 진행 상황 로깅 (10 step마다) if step_num % 10 == 0 or step_num == total_steps: PrettyPrinter.status("Phase 5 Progress", f"Step {step_num}/{total_steps}", "info") # Task type별로 배치 수집 batches = {} task_types = ['induction', 'deduction', 'abduction'] # TTRLVR task types PrettyPrinter.status("Phase 5", f"Collecting batches for step {step_num}", "info") # 각 task type별로 배치 수집 for task_type in task_types: if task_type not in self.ttrlvr_dataloaders: continue try: # 해당 task type의 배치 가져오기 batch_dict = self._get_ttrlvr_batch(task_type) batch = DataProto.from_single_dict(batch_dict) PrettyPrinter.status("Phase 5", f"Processing {task_type} batch with {len(batch_dict['prompt'])} samples", "info") # --- PPO Step 1: Response 생성 --- gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids']) gen_batch.meta_info = { 'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id, 'recompute_log_prob': False } # 시퀀스 생성 gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) # --- PPO Step 2: Batch 결합 및 보상 계산 --- batch = batch.union(gen_batch_output) # Response mask 계산 batch.batch["response_mask"] = compute_response_mask(batch) # compute global_valid tokens (VeRL 요구사항) batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist() # old_log_probs 계산 (AZR/VeRL 표준 방식) - PPO에 필요한 단계 old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) entropys = old_log_prob.batch["entropys"] response_masks = batch.batch["response_mask"] loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} # entropys 제거하고 old_log_probs를 배치에 추가 old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) # 보상 계산 - 부모 클래스의 TTRLVR 처리 로직 사용 if self._use_ttrlvr_rewards: # TTRLVR reward 계산 prompts = batch.non_tensor_batch.get('prompts', []) if 
isinstance(prompts, np.ndarray): prompts = prompts.tolist() elif torch.is_tensor(prompts): prompts = prompts.tolist() # responses 디코딩 responses = [] if hasattr(batch, 'batch') and 'responses' in batch.batch: try: responses = self.tokenizer.batch_decode(batch.batch['responses'], skip_special_tokens=True) except Exception as e: self.ttrlvr_logger.log_error(f"Error decoding responses: {e}") responses = [] # metadata 추출 (task_type 포함) metadata = batch.non_tensor_batch.get('ttrlvr_metadata', []) if isinstance(metadata, np.ndarray): metadata = metadata.tolist() elif torch.is_tensor(metadata): metadata = metadata.tolist() # TTRLVR reward 계산 rewards = self.ttrlvr_processor.compute_rewards_for_responses( prompts=prompts, responses=responses, metadata=metadata ) # ===== TTRLVR 샘플 로깅 (원본 azr_ray_trainer.py Line 1260-1301 복사) ===== # 모든 step에서 샘플 응답 출력 if len(responses) > 0 and len(prompts) > 0: PrettyPrinter.section_header(f"TTRLVR Sample Task & Response (Step {self.global_steps})") i = 0 # 첫 번째 샘플만 출력 meta = metadata[i] if i < len(metadata) else {} current_task_type = meta.get('task_type', task_type) print(f"\n=== Step {self.global_steps} - Task Type: {current_task_type} ===\n") # 전체 프롬프트 출력 if i < len(prompts): prompt_str = str(prompts[i]) if isinstance(prompts[i], list) else prompts[i] print("=" * 80) print("FULL PROMPT:") print("=" * 80) print(prompt_str) print("=" * 80) else: print(f"Prompt: (no prompt available, prompts length: {len(prompts)})") # 전체 응답 출력 if i < len(responses): print("\n" + "=" * 80) print("LLM FULL RESPONSE:") print("=" * 80) print(responses[i]) print("=" * 80) else: print(f"Response: (no response available)") # 보상 출력 (정답 여부) if i < len(rewards): print(f"\nReward: {rewards[i]:.2f} {'✅ Correct' if rewards[i] > 0 else '❌ Wrong'}") else: print(f"Reward: (no reward available)") # 기대했던 정답도 출력 if 'expected_solution' in meta: print("\n" + "=" * 80) print("EXPECTED SOLUTION:") print("=" * 80) print(meta['expected_solution']) print("=" * 80) # 추가 정보: 평가 데이터 if 
'evaluation_data' in meta: eval_data = meta['evaluation_data'] print("\n" + "-" * 40) print("EVALUATION DATA:") print("-" * 40) if current_task_type == 'induction' and 'input_output_pairs' in eval_data: print(f"Input/Output Pairs: {eval_data['input_output_pairs']}") elif current_task_type == 'deduction' and 'input' in eval_data: print(f"Test Input: {eval_data['input']}") elif current_task_type == 'abduction' and 'expected_output' in eval_data: print(f"Expected Output: {eval_data['expected_output']}") print("-" * 40) # tensor로 변환 및 token-level로 확장 reward_tensor = torch.tensor(rewards, dtype=torch.float32, device=batch.batch['responses'].device) reward_tensor = reward_tensor.unsqueeze(-1).expand(-1, batch.batch['responses'].size(1)) batch.batch['token_level_rewards'] = reward_tensor batch.batch['token_level_scores'] = reward_tensor # Task별 정확도 메트릭 저장 (나중에 step_metrics에 추가) batch.meta_info[f'{task_type}_accuracy'] = sum(1 for r in rewards if r > 0) / len(rewards) if rewards else 0.0 batch.meta_info[f'{task_type}_reward_mean'] = float(np.mean(rewards)) if rewards else 0.0 else: # 일반 AZR 모드 rm_output = self.reward_fn(batch) batch.batch['token_level_rewards'] = rm_output.batch['token_level_rewards'] batch.batch['token_level_scores'] = rm_output.batch['token_level_scores'] # compute_advantage 호출 (REINFORCE++ 적용) - 각 task별로! 
batch = compute_advantage( batch, adv_estimator=self.config.algorithm.adv_estimator, gamma=self.config.algorithm.gamma, lam=self.config.algorithm.lam, num_repeat=self.config.actor_rollout_ref.rollout.n, config=self.config.algorithm ) # 배치 저장 batches[task_type] = batch except Exception as e: PrettyPrinter.status("Phase 5", f"Error processing {task_type}: {str(e)}", "error") continue # 수집된 배치가 있는 경우에만 진행 if batches: # 모든 배치 연결 PrettyPrinter.status("Phase 5", f"Concatenating {len(batches)} batches from task types: {list(batches.keys())}", "info") combined_batch = DataProto.concat(list(batches.values())) # temperature를 meta_info에 추가 (actor 업데이트에 필요) if not hasattr(combined_batch, 'meta_info') or combined_batch.meta_info is None: combined_batch.meta_info = {} # VeRL fsdp_workers.py와 동일하게 rollout.temperature 사용 combined_batch.meta_info['temperature'] = self.config.actor_rollout_ref.rollout.temperature # --- PPO Step 3: Actor/Critic 업데이트 (결합된 배치로) --- step_metrics = {} # Task별 정확도 메트릭 추가 for task_type, task_batch in batches.items(): if f'{task_type}_accuracy' in task_batch.meta_info: step_metrics[f'{task_type}/accuracy'] = task_batch.meta_info[f'{task_type}_accuracy'] step_metrics[f'{task_type}/reward_mean'] = task_batch.meta_info[f'{task_type}_reward_mean'] # 전체 평균 정확도 계산 accuracy_values = [v for k, v in step_metrics.items() if 'accuracy' in k] if accuracy_values: step_metrics['overall/accuracy'] = np.mean(accuracy_values) # Critic 업데이트 if self.use_critic: critic_output = self.critic_wg.update_critic(combined_batch) critic_metrics = reduce_metrics(critic_output.meta_info['metrics']) step_metrics.update(critic_metrics) # Actor 업데이트 if self.config.trainer.critic_warmup <= self.global_steps: # compute global_valid tokens for combined batch (VeRL 요구사항) combined_batch.meta_info['global_token_num'] = torch.sum(combined_batch.batch['attention_mask'], dim=-1).tolist() actor_output = self.actor_rollout_wg.update_actor(combined_batch) actor_metrics = 
reduce_metrics(actor_output.meta_info['metrics']) step_metrics.update(actor_metrics) # Phase 5 배치 응답 저장 if phase5_log_dirs: # 매 step마다 저장 try: # combined_batch에서 데이터 추출 if hasattr(combined_batch, 'batch'): input_ids = combined_batch.batch.get('input_ids', None) response_ids = combined_batch.batch.get('responses', None) if input_ids is not None and response_ids is not None: # 실제 메서드 호출 - llm_responses/step_XXXX/ 디렉토리에 저장 self._save_phase5_batch_to_llm_responses( phase5_log_dirs, input_ids, response_ids, step_num, self.current_round ) # 간단한 요약도 메모리에 저장 (나중에 phase5_training에 저장) if len(response_ids) > 0: sample_responses = [ self.tokenizer.decode(response_ids[i], skip_special_tokens=True) for i in range(min(3, len(response_ids))) ] all_generated_responses.append({ 'step': step_num, 'samples': sample_responses }) except Exception as e: PrettyPrinter.status("Warning", f"Failed to save llm_responses: {e}", "warning") else: # 배치가 없는 경우 PrettyPrinter.status("Phase 5", "No batches collected, skipping update", "warning") step_metrics = {} # 메트릭 수집 all_step_metrics.append(step_metrics) for k, v in step_metrics.items(): if k not in epoch_metrics: epoch_metrics[k] = [] epoch_metrics[k].append(v) # 원본 AZR과 동일하게 매 step마다 메트릭 로깅 if step_metrics: # 콘솔에 메트릭 테이블 출력 (원본 AZR 스타일) display_metrics = {} for k, v in step_metrics.items(): if isinstance(v, float): display_metrics[k] = f"{v:.6f}" else: display_metrics[k] = v PrettyPrinter.table( ["Metric", "Value"], [[k, v] for k, v in display_metrics.items()], title=f"Step {self.global_steps} Metrics" ) # WandB/Console logger에 기록 logger.log(data=step_metrics, step=self.global_steps) # Global step 증가 self.global_steps += 1 # 메모리 정리 (100 step마다) if step_num % 100 == 0: import gc gc.collect() torch.cuda.empty_cache() # 평균 메트릭 계산 avg_metrics = { k: np.mean(v) for k, v in epoch_metrics.items() } # Phase 5 로그를 각 문제별 디렉토리에 저장 if phase5_log_dirs: self._save_phase5_logs(phase5_log_dirs, all_step_metrics, all_generated_responses, avg_metrics) # Phase 
5 완료 상태 overall_acc = avg_metrics.get('overall/accuracy', 0) induction_acc = avg_metrics.get('induction/accuracy', 0) deduction_acc = avg_metrics.get('deduction/accuracy', 0) abduction_acc = avg_metrics.get('abduction/accuracy', 0) PrettyPrinter.section_header("Phase 5 Training Complete") print(f"\n📊 Final Accuracy Metrics:") print(f" • Overall: {overall_acc:.2%}") print(f" • Induction: {induction_acc:.2%}") print(f" • Deduction: {deduction_acc:.2%}") print(f" • Abduction: {abduction_acc:.2%}") print(f"\n📈 Average Reward: {avg_metrics.get('overall/reward_mean', 0):.4f}") return avg_metrics def _generate_batch_with_vllm(self, prompts: List[str], temperature: float = 0.0, seed: int = None, n: int = 1) -> List[str]: """ VeRL의 vLLM을 사용한 배치 생성 (4의 배수 필수) Args: prompts: 프롬프트 리스트 (4의 배수여야 함, n>1이면 첫 번째만 사용) temperature: 생성 온도 seed: 랜덤 시드 (None이면 랜덤 생성) n: 각 프롬프트당 생성할 응답 개수 (기본값 1) Returns: 생성된 텍스트 리스트 (n>1이면 n개, 아니면 len(prompts)개) """ if len(prompts) % 4 != 0: self.ttrlvr_logger.log_warning(f"Batch size {len(prompts)} is not divisible by 4, padding will be applied") # 빈 프롬프트 검증 및 교체 validated_prompts = [] for i, prompt in enumerate(prompts): if not prompt or not prompt.strip(): self.ttrlvr_logger.log_warning(f"Empty prompt at index {i}, using default") validated_prompts.append("Please provide a Python function.") else: validated_prompts.append(prompt) # n>1이면 첫 번째 프롬프트만 사용 (n개 응답이 생성됨) if n > 1: # VeRL의 4의 배수 요구사항을 충족하기 위해 첫 번째 프롬프트를 4번 반복 tokenize_prompts = [validated_prompts[0]] * 4 self.ttrlvr_logger.log_info(f"Using n={n}, processing single prompt with 4x padding") else: tokenize_prompts = validated_prompts # VeRL/AZR 스타일 토큰화 import verl.utils.torch_functional as verl_F from verl.utils.model import compute_position_id_with_mask inputs_list = [] for prompt in tokenize_prompts: # 채팅 템플릿 적용 if isinstance(prompt, list): # 이미 채팅 형태인 경우 formatted_prompt = self.tokenizer.apply_chat_template( prompt, add_generation_prompt=True, tokenize=False ) else: # 단순 문자열인 경우 user 메시지로 변환 
formatted_prompt = self.tokenizer.apply_chat_template([ {"role": "user", "content": prompt} ], add_generation_prompt=True, tokenize=False) # VeRL 토큰화 사용 input_ids, attention_mask = verl_F.tokenize_and_postprocess_data( prompt=formatted_prompt, tokenizer=self.tokenizer, max_length=2048, pad_token_id=self.tokenizer.pad_token_id, left_pad=True, truncation="error" ) inputs_list.append({ 'input_ids': input_ids[0], 'attention_mask': attention_mask[0] }) # 수동 스택킹 (collate_fn과 동일한 효과) inputs = {} inputs['input_ids'] = torch.stack([item['input_ids'] for item in inputs_list]) inputs['attention_mask'] = torch.stack([item['attention_mask'] for item in inputs_list]) # position_ids 생성 (VeRL 표준 방식) position_ids = compute_position_id_with_mask(inputs['attention_mask']) # DataProto 생성 - validate=False로 변경하여 meta_info의 파라미터 직접 사용 meta_info = { 'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id, 'do_sample': temperature > 0, 'validate': False, # False로 변경하여 meta_info 파라미터 사용 'temperature': temperature, # temperature 직접 전달 'top_p': 0.95, 'top_k': -1, 'n': n, # n 파라미터 전달 'recompute_log_prob': False, } # seed가 명시적으로 제공된 경우에만 추가 (다양성을 위해 기본적으로 seed 제거) if seed is not None: meta_info['seed'] = seed self.ttrlvr_logger.log_debug(f"Using fixed seed={seed} for generation") else: # seed=None인 경우 seed를 meta_info에 추가하지 않음 (vLLM이 자동으로 랜덤 생성) self.ttrlvr_logger.log_debug(f"No seed specified - using random generation for diversity (n={n})") gen_batch = DataProto.from_dict( tensors={ 'input_ids': inputs['input_ids'], 'attention_mask': inputs['attention_mask'], 'position_ids': position_ids, # position_ids 추가 }, meta_info=meta_info ) # 4의 배수로 패딩 (필요한 경우) if len(prompts) % 4 != 0: gen_batch_padded, pad_size = pad_dataproto_to_divisor(gen_batch, self.actor_rollout_wg.world_size) output_batch_padded = self.actor_rollout_wg.generate_sequences(gen_batch_padded) output_batch = unpad_dataproto(output_batch_padded, pad_size=pad_size) else: # 이미 4의 배수면 패딩 불필요 output_batch 
= self.actor_rollout_wg.generate_sequences(gen_batch) # 응답 디코딩 - n을 고려한 처리 try: responses = self.tokenizer.batch_decode(output_batch.batch['responses'], skip_special_tokens=True) # TTRLVR처럼 탭을 공백으로 변환 responses = [resp.replace("\t", " ") for resp in responses] # n>1일 때: 단일 프롬프트에서 n개 응답이 생성됨 if n > 1: # 첫 n개 응답이 실제 결과 (나머지는 패딩으로 인한 중복) responses = responses[:n] self.ttrlvr_logger.log_info(f"✅ Generated {len(responses)} diverse responses using n={n}") else: self.ttrlvr_logger.log_info(f"✅ Successfully generated {len(responses)} responses") except Exception as e: self.ttrlvr_logger.log_error(f"❌ Error decoding responses: {e}") expected_count = n if n > 1 else len(prompts) responses = [""] * expected_count # 빈 응답으로 fallback return responses def _generate_with_vllm( self, prompt: str, temperature: float = 0.7 ) -> str: """ 헬퍼 함수: VeRL의 vLLM을 사용한 단일 텍스트 생성 4의 배수 제약을 처리하기 위해 내부적으로 배치 함수 사용 """ # 4개로 패딩하여 배치 생성 prompts = [prompt, prompt, prompt, prompt] # 같은 프롬프트 4개 # 배치 생성 responses = self._generate_batch_with_vllm(prompts, temperature) # 첫 번째 응답만 반환 return responses[0] if responses else "" def _save_round_data(self, round_data: List[Dict], round_num: int) -> Dict[str, str]: """라운드 데이터를 parquet 파일로 저장 (AZR이 기대하는 형식)""" import pandas as pd from collections import defaultdict # AZR 학습용 디렉토리 (VeRL이 읽을 위치) # ttrlvr_logger.log_dir과 동일한 경로 사용 (Phase 5와 일치) if hasattr(self, 'ttrlvr_logger') and self.ttrlvr_logger: # ttrlvr_logger.log_dir은 이미 round_X까지 포함하므로 바로 azr_training_data 추가 output_dir = os.path.join(self.ttrlvr_logger.log_dir, "azr_training_data") else: # Fallback: self.output_dir 사용 (기존 방식) output_dir = os.path.join(self.output_dir, f"round_{round_num}", "azr_training_data") os.makedirs(output_dir, exist_ok=True) # Task 타입별로 분리 - Dict 구조 처리 tasks_by_type = defaultdict(list) # round_data는 [{'induction': [...], 'deduction': [...], 'abduction': [...]}, ...] 
형태 for tasks_dict in round_data: if isinstance(tasks_dict, dict) and any(isinstance(v, list) for v in tasks_dict.values()): # Dict 구조인 경우 (새로운 방식) - value가 리스트인 것들만 for task_type, task_list in tasks_dict.items(): if isinstance(task_list, list): # 🔧 중요: 문자열 prompt를 딕셔너리 리스트로 변환 + evaluation_data 보존 converted_tasks = [] for task in task_list: task_copy = task.copy() # 1. prompt 변환 (기존 로직) if 'prompt' in task and isinstance(task['prompt'], str): task_copy['prompt'] = [{"role": "user", "content": task['prompt']}] # 2. evaluation_data 명시적 보존 (Phase 5 reward 계산에 필수) # DataFrame 변환 시 누락되지 않도록 보장 if 'evaluation_data' in task: task_copy['evaluation_data'] = task['evaluation_data'] converted_tasks.append(task_copy) tasks_by_type[task_type].extend(converted_tasks) elif isinstance(tasks_dict, dict): # 기존 방식 (단일 task dict) - 호환성 유지 task_type = tasks_dict.get('task_type', 'unknown') task_copy = tasks_dict.copy() # 🔧 중요: 문자열 prompt를 딕셔너리 리스트로 변환 if 'prompt' in tasks_dict and isinstance(tasks_dict['prompt'], str): task_copy['prompt'] = [{"role": "user", "content": tasks_dict['prompt']}] # evaluation_data 보존 if 'evaluation_data' in tasks_dict: task_copy['evaluation_data'] = tasks_dict['evaluation_data'] tasks_by_type[task_type].append(task_copy) saved_files = {} for task_type in ['induction', 'deduction', 'abduction']: if task_type in tasks_by_type and tasks_by_type[task_type]: tasks = tasks_by_type[task_type] # IPO group ID로 정렬 (배치 일관성) df = pd.DataFrame(tasks) if 'ipo_group_id' in df.columns: df = df.sort_values('ipo_group_id') file_path = os.path.join(output_dir, f"{task_type}.parquet") df.to_parquet(file_path, index=False) saved_files[task_type] = file_path PrettyPrinter.status("Data", f"Saved {len(tasks)} {task_type} tasks to parquet", "info") return saved_files def _create_ttrlvr_dataloaders_from_parent(self, saved_files: Dict[str, str]): """ 부모 클래스(azr_ray_trainer.py Line 694-758)의 코드를 그대로 복사 Task type별로 분리된 dataloader 생성하여 각각 advantage normalization """ from 
absolute_zero_reasoner.utils.dataset.ttrlvr_dataset import TTRLVRDataset from torch.utils.data import DataLoader, RandomSampler, SequentialSampler # Task type별 dataloader 생성 self.ttrlvr_dataloaders = {} self.ttrlvr_iterators = {} # iterator 추가 task_types = ['induction', 'deduction', 'abduction'] for task_type in task_types: if task_type in saved_files and os.path.exists(saved_files[task_type]): # 각 task type별 데이터셋 생성 (부모 클래스 Line 728-737) task_dataset = TTRLVRDataset( parquet_files=saved_files[task_type], tokenizer=self.tokenizer, max_prompt_length=self.config.data.max_prompt_length, filter_prompts=True, return_raw_chat=self.config.data.get('return_raw_chat', False), truncation='error', extra_source_key=f"ttrlvr_{task_type}", task_type=task_type # task type 명시 ) # Sampler 설정 (부모 클래스 Line 740-745) if self.config.data.shuffle: generator = torch.Generator() generator.manual_seed(self.config.data.get('seed', 1) + hash(task_type) % 1000) sampler = RandomSampler(data_source=task_dataset, generator=generator) else: sampler = SequentialSampler(data_source=task_dataset) # DataLoader 생성 (부모 클래스 Line 748-754) self.ttrlvr_dataloaders[task_type] = DataLoader( dataset=task_dataset, batch_size=self.config.data.train_batch_size, drop_last=True, collate_fn=self._ttrlvr_collate_fn, sampler=sampler ) # Iterator 초기화 self.ttrlvr_iterators[task_type] = iter(self.ttrlvr_dataloaders[task_type]) PrettyPrinter.status("TTRLVR", f"Created dataloader for {task_type} with {len(task_dataset)} samples", "success") def _get_ttrlvr_batch(self, task_type: str): """ 부모 클래스(azr_ray_trainer.py Line 678-692)의 코드를 그대로 복사 TTRLVR task type별 배치를 가져오는 메서드 """ # Iterator가 없거나 소진된 경우 새로 생성 if task_type not in self.ttrlvr_iterators: if task_type in self.ttrlvr_dataloaders: self.ttrlvr_iterators[task_type] = iter(self.ttrlvr_dataloaders[task_type]) else: raise ValueError(f"No dataloader for task type: {task_type}") try: return next(self.ttrlvr_iterators[task_type]) except StopIteration: # Iterator 소진시 재생성 
self.ttrlvr_iterators[task_type] = iter(self.ttrlvr_dataloaders[task_type]) return next(self.ttrlvr_iterators[task_type]) def _ttrlvr_collate_fn(self, batch): """TTRLVR 데이터를 위한 collate function (부모 클래스와 동일)""" from verl.utils.dataset.rl_dataset import collate_fn import numpy as np # 기본 collate 실행 collated = collate_fn(batch) # TTRLVR 메타데이터 추가 ttrlvr_metadata = [] prompts = [] for item in batch: if 'ttrlvr_metadata' in item: ttrlvr_metadata.append(item['ttrlvr_metadata']) prompts.append(item.get('prompt', '')) # DataProto가 요구하는 형식으로 변환 # dict의 리스트를 numpy array로 변환 (dtype=object) if ttrlvr_metadata: collated['ttrlvr_metadata'] = np.array(ttrlvr_metadata, dtype=object) if prompts: collated['prompts'] = np.array(prompts, dtype=object) return collated def _create_unified_dataloader(self, saved_files: Dict[str, str]): """ Parquet 파일들을 읽어서 통합 DataLoader 생성 모든 task type을 섞어서 하나의 DataLoader로 """ from absolute_zero_reasoner.utils.dataset.rl_dataset import RLHFDataset from torch.utils.data import DataLoader, RandomSampler, ConcatDataset from verl.utils.dataset.rl_dataset import collate_fn import pandas as pd # 모든 parquet 파일 읽어서 하나로 합치기 all_data = [] for task_type, file_path in saved_files.items(): if os.path.exists(file_path): df = pd.read_parquet(file_path) # task_type을 데이터에 추가 (나중에 메트릭 분리용) df['task_type'] = task_type all_data.append(df) PrettyPrinter.status("Data", f"Loaded {len(df)} {task_type} samples from parquet", "info") # 데이터프레임 합치기 if all_data: combined_df = pd.concat(all_data, ignore_index=True) # 셔플 (필요한 경우) if self.config.data.shuffle: combined_df = combined_df.sample(frac=1).reset_index(drop=True) # 임시 parquet 파일로 저장 temp_parquet = os.path.join(self.output_dir, "temp_combined.parquet") combined_df.to_parquet(temp_parquet, index=False) # RLHFDataset 생성 dataset = RLHFDataset( parquet_files=temp_parquet, tokenizer=self.tokenizer, prompt_key=self.config.data.prompt_key, max_prompt_length=self.config.data.max_prompt_length, filter_prompts=True, 
return_raw_chat=self.config.data.get('return_raw_chat', False), truncation='error', extra_source_key="ttrlvr_unified" ) # DataLoader 생성 if self.config.data.shuffle: sampler = RandomSampler(dataset) else: sampler = None self.train_dataloader = DataLoader( dataset=dataset, batch_size=self.config.data.train_batch_size, drop_last=True, collate_fn=collate_fn, sampler=sampler ) PrettyPrinter.status("DataLoader", f"Created unified dataloader with {len(dataset)} total samples", "success") else: raise ValueError("No data files found to create dataloader") def _convert_to_verl_dataset(self, round_data: List[Dict]) -> Any: """TTRLVR 데이터를 VeRL 형식으로 변환""" # VeRL이 기대하는 형식으로 변환 converted_data = [] for task in round_data: # VeRL/AZR 스타일 토큰화 import verl.utils.torch_functional as verl_F # 채팅 템플릿 적용 if isinstance(task['prompt'], list): formatted_prompt = self.tokenizer.apply_chat_template( task['prompt'], add_generation_prompt=True, tokenize=False ) else: formatted_prompt = self.tokenizer.apply_chat_template([ {"role": "user", "content": task['prompt']} ], add_generation_prompt=True, tokenize=False) # VeRL 토큰화 사용 input_ids, attention_mask = verl_F.tokenize_and_postprocess_data( prompt=formatted_prompt, tokenizer=self.tokenizer, max_length=self.config.data.max_prompt_length, pad_token_id=self.tokenizer.pad_token_id, left_pad=True, truncation="error" ) # VeRL 형식의 데이터 verl_item = { 'input_ids': input_ids, 'attention_mask': attention_mask, 'prompt': task['prompt'], 'target': task['target'], 'task_type': task['task_type'], 'problem_id': task['problem_id'] } converted_data.append(verl_item) return converted_data def _create_dataloader(self, dataset=None, batch_size=None): """부모 클래스의 _create_dataloader 오버라이드 - TTRLVR은 동적으로 데이터 생성""" # 부모 클래스 호출 시 (인자 없이) if dataset is None and batch_size is None: # TTRLVR은 각 라운드마다 동적으로 데이터를 생성하므로 # 초기화 시점에는 빈 데이터로더 설정 self.train_dataset = None self.train_dataloader = None self.valid_dataset = None self.valid_dataloader = None return # 실제 데이터로더 생성 시 (인자와 함께) 
from torch.utils.data import DataLoader, Dataset class SimpleDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] dataset_obj = SimpleDataset(dataset) return DataLoader( dataset_obj, batch_size=batch_size, shuffle=True, collate_fn=self._collate_fn ) def _collate_fn(self, batch): """배치 데이터 정리""" # VeRL collate 함수 사용 return collate_fn(batch) def _prepare_generation_batch(self, batch): """생성을 위한 배치 준비""" # VeRL 형식으로 배치 준비 return batch def _ppo_update(self, batch, reward_tensor): """PPO 업데이트 수행""" # 부모 클래스의 PPO 업데이트 로직 사용 # 실제 구현은 ReasonRLRayPPOTrainer에 있음 return {} def _parse_output(self, text: str) -> List[Any]: """생성된 텍스트에서 출력 파싱""" # 간단한 파싱 로직 # 실제로는 더 복잡한 파싱이 필요할 수 있음 try: # Python 리스트 형식으로 파싱 시도 import ast return ast.literal_eval(text.strip()) except: # 파싱 실패 시 빈 리스트 return [] def _extract_python_code(self, solution: str) -> str: """개선된 Python 코드 추출 (AZR 방식 + 추가 패턴) - TTRLVR과 완전 동일""" # 1. AZR의 extract_code 함수 먼저 시도 try: extracted = extract_code(solution, language="python") if extracted: return extracted except: pass # 2. 다양한 마크다운 패턴 시도 patterns = [ r'```python\n(.*?)```', # ```python ... ``` r'```\n(.*?)```', # ``` ... ``` r'```py\n(.*?)```', # ```py ... ``` r'```Python\n(.*?)```', # ```Python ... ``` r'Here is.*?:\n\n```python\n(.*?)```', # 설명 텍스트 포함 r'Here is.*?:\n\n```\n(.*?)```', # 설명 텍스트 포함 ] for pattern in patterns: matches = re.findall(pattern, solution, re.DOTALL | re.IGNORECASE) if matches: return matches[-1].strip() # 3. def로 시작하는 함수 찾기 lines = solution.split('\n') code_lines = [] in_function = False for line in lines: if line.strip().startswith('def '): in_function = True code_lines.append(line) elif in_function and (line.startswith(' ') or line.strip() == ''): code_lines.append(line) elif in_function and line.strip() and not line.startswith(' '): # 함수 정의 끝 break if code_lines: return '\n'.join(code_lines) # 4. 
원본 반환 return solution def _add_imports_from_prompt(self, solution: str, prompt: str) -> str: """HumanEval 프롬프트에서 import 문을 추출하여 솔루션에 추가 (EvalPlus 방식)""" # 이미 import가 있으면 그대로 반환 if 'from typing import' in solution or 'import typing' in solution: return solution # 프롬프트에서 import 문 추출 import_lines = [] prompt_lines = prompt.split('\n') for line in prompt_lines: stripped = line.strip() # import 문 찾기 if (stripped.startswith('from ') and 'import' in stripped) or stripped.startswith('import '): import_lines.append(line) # 함수 정의가 시작되면 중단 elif stripped.startswith('def '): break # import가 없으면 원본 반환 if not import_lines: return solution # 솔루션이 이미 import로 시작하는지 확인 solution_lines = solution.split('\n') first_non_empty_line = None for i, line in enumerate(solution_lines): if line.strip(): first_non_empty_line = i break # import를 맨 앞에 추가 if first_non_empty_line is not None: # 기존 import 뒤에 추가하거나 맨 앞에 추가 imports_text = '\n'.join(import_lines) + '\n\n' # 첫 번째 비어있지 않은 줄이 import인 경우 if solution_lines[first_non_empty_line].strip().startswith(('import ', 'from ')): # 마지막 import 찾기 last_import_idx = first_non_empty_line for i in range(first_non_empty_line, len(solution_lines)): if solution_lines[i].strip() and not solution_lines[i].strip().startswith(('import ', 'from ')): break if solution_lines[i].strip().startswith(('import ', 'from ')): last_import_idx = i # 마지막 import 다음에 추가 solution_lines.insert(last_import_idx + 1, '') solution_lines.insert(last_import_idx + 1, '\n'.join(import_lines)) else: # 맨 앞에 추가 solution = imports_text + solution return solution return '\n'.join(solution_lines) def _fix_function_definition(self, solution: str, prompt: str, problem_id: str = "") -> str: """함수 정의가 누락된 경우 복구 + lpw 스타일 중복 처리""" # lpw 스타일: 프롬프트에서 함수 이름 추출 func_def_match = re.search(r'def\s+(\w+)\([^)]*\)(?:\s*->\s*[^:]+)?:', prompt) if not func_def_match: return solution entry_point = func_def_match.group(1) func_def_line = func_def_match.group(0) # HumanEval의 경우 전체 코드를 반환하므로 중복 처리 불필요 if 
'HumanEval' in problem_id: # 이미 전체 코드가 있으므로 그대로 반환 return solution # MBPP의 경우 기존 로직 유지 # Case 1: LLM이 전체 함수를 생성한 경우 (lpw 스타일 체크) if (prompt in solution) or (f'def {entry_point}(' in solution): # 함수가 이미 포함되어 있음 return solution # Case 2: 함수 본문만 생성한 경우 - 함수 정의 추가 if solution and not solution.startswith('def '): # 함수 정의와 함수 내용을 결합 lines = solution.split('\n') fixed_lines = [func_def_line] for line in lines: if line.strip(): # 빈 줄이 아닌 경우 # if __name__ == "__main__": 부분은 함수 밖에 있어야 함 if line.strip().startswith('if __name__'): # 함수 정의 끝내고 메인 부분 시작 fixed_lines.append('') # 빈 줄 추가 fixed_lines.append(line.strip()) else: # 함수 내용은 4칸 인덴테이션 if not line.startswith(' ') and line.strip(): line = ' ' + line.lstrip() fixed_lines.append(line) else: fixed_lines.append(line) solution = '\n'.join(fixed_lines) return solution def _evaluate_baseline_performance(self, problem: Dict[str, Any]) -> Dict[str, Any]: """베이스라인 성능 측정 (temperature=0.05로 5번 실행)""" self.ttrlvr_logger.log_info(f"📊 Evaluating baseline performance for {problem.get('task_id', 'unknown')}") baseline_results = { 'success': True, 'total_rounds': getattr(self.config, 'baseline_evaluation_rounds', 4), # 4의 배수로 변경 'solutions': [], 'evaluations': [], 'success_count': 0, 'average_accuracy': 0.0, 'error': None } try: # 프롬프트 준비 problem_prompt = problem.get('prompt', '') problem_id = problem.get('task_id', 'unknown') # 벤치마크에 따라 적절한 프롬프트 선택 if 'HumanEval' in problem_id: prompt = get_prompt("solution_humaneval_basic", problem_prompt=problem_prompt) else: # MBPP 프롬프트 사용 prompt = get_prompt("solution_mbpp_basic", problem_prompt=problem_prompt) # vLLM의 n=4 파라미터를 사용하여 다양한 응답 생성 # seed를 제거하여 다양성 확보 (테스트 결과 확인됨) # VeRL은 4의 배수를 요구하므로 1개 프롬프트를 4번 반복하지만 # n=4로 실제로는 첫 번째 프롬프트에서만 4개 응답 생성 batch_prompts = [prompt] * 4 # VeRL 요구사항 충족 # n=4를 사용하여 다양한 솔루션 생성 (seed 제거로 다양성 확보) all_solutions = self._generate_batch_with_vllm( [prompt], # 실제로는 1개 프롬프트만 필요 temperature=0.7, seed=None, # seed 제거하여 다양성 확보 n=4 # 4개의 다양한 응답 생성 ) # 각 솔루션 처리 for round_id in 
range(baseline_results['total_rounds']): self.ttrlvr_logger.log_info(f" 🔄 Baseline round {round_id + 1}/{baseline_results['total_rounds']}") solution = all_solutions[round_id] if round_id < len(all_solutions) else "" # TTRLVR과 완전 동일한 후처리 적용 # 1. 마크다운 코드 블록에서 Python 코드 추출 extracted_solution = self._extract_python_code(solution) if extracted_solution and extracted_solution != solution: solution = extracted_solution # 2. HumanEval의 경우 import 추가 if 'HumanEval' in problem_id: solution = self._add_imports_from_prompt(solution, problem_prompt) # 3. 함수 정의 복구 solution = self._fix_function_definition(solution, prompt, problem_id) # 구문 검증 syntax_valid = False syntax_error = None try: compile(solution, '', 'exec') syntax_valid = True except SyntaxError as e: syntax_error = str(e) solution_result = { 'round_id': round_id, 'solution': solution, 'syntax_valid': syntax_valid, 'syntax_error': syntax_error, 'evaluation': None } # 정확성 평가 (테스트 실행) if syntax_valid: evaluation = self._evaluate_solution(problem, solution) solution_result['evaluation'] = evaluation if evaluation.get('correct', False): baseline_results['success_count'] += 1 self.ttrlvr_logger.log_info(f" ✅ Round {round_id + 1}: PASSED (Base: {evaluation.get('base_passed', 0)}/{evaluation.get('base_total', 0)}, Plus: {evaluation.get('plus_passed', 0)}/{evaluation.get('plus_total', 0)})") else: self.ttrlvr_logger.log_info(f" ❌ Round {round_id + 1}: FAILED (Base: {evaluation.get('base_passed', 0)}/{evaluation.get('base_total', 0)}, Plus: {evaluation.get('plus_passed', 0)}/{evaluation.get('plus_total', 0)})") else: self.ttrlvr_logger.log_warning(f" ❌ Round {round_id + 1}: Syntax error - {syntax_error}") baseline_results['solutions'].append(solution_result) # Batch evaluation 형식으로 상세 로그 저장 (모든 라운드) self._save_batch_evaluation_format(problem, solution_result, attempt_num=round_id + 1) # 평균 정확도 계산 if baseline_results['success_count'] > 0: baseline_results['average_accuracy'] = baseline_results['success_count'] / 
baseline_results['total_rounds'] self.ttrlvr_logger.log_info(f" 📈 Baseline performance: {baseline_results['success_count']}/{baseline_results['total_rounds']} success ({baseline_results['average_accuracy']:.3f})") # baseline_results.json과 problem_metadata.json 저장 self._save_baseline_results_files(baseline_results, problem) except Exception as e: self.ttrlvr_logger.log_error(f"❌ Baseline evaluation failed: {e}") baseline_results['success'] = False baseline_results['error'] = str(e) return baseline_results def _save_baseline_results_files(self, baseline_results: Dict[str, Any], problem: Dict[str, Any]): """베이스라인 평가 결과를 JSON 파일들로 저장""" import json from datetime import datetime # current_evaluation 디렉토리 current_dir = os.path.join(self.ttrlvr_logger.log_dir, "current_evaluation") os.makedirs(current_dir, exist_ok=True) # 1. baseline_results.json 저장 baseline_results_file = os.path.join(current_dir, "baseline_results.json") baseline_data = { "timestamp": datetime.now().isoformat(), "problem_id": problem.get('task_id', 'unknown'), "benchmark": problem.get('benchmark', 'mbpp'), "total_attempts": baseline_results['total_rounds'], "successful_attempts": baseline_results['success_count'], "success_rate": baseline_results['average_accuracy'], "evaluation_status": "SUCCESS" if baseline_results['average_accuracy'] > 0 else "FAILED", "solutions": baseline_results['solutions'], "error": baseline_results.get('error') } with open(baseline_results_file, 'w', encoding='utf-8') as f: json.dump(baseline_data, f, indent=2, default=str) # 2. 
problem_metadata.json 저장 metadata_file = os.path.join(current_dir, "problem_metadata.json") metadata = { "task_id": problem.get('task_id', 'unknown'), "benchmark": problem.get('benchmark', 'mbpp'), "prompt": problem.get('prompt', ''), "entry_point": problem.get('entry_point', ''), "canonical_solution": problem.get('canonical_solution', ''), "test_imports": problem.get('test_imports', []), "base_input": problem.get('base_input', []), "plus_input": problem.get('plus_input', []), "base_test_count": len(problem.get('base_input', [])), "plus_test_count": len(problem.get('plus_input', [])), "timestamp": datetime.now().isoformat() } with open(metadata_file, 'w', encoding='utf-8') as f: json.dump(metadata, f, indent=2, default=str) # 3. summary.txt 저장 (기존 TTRLVR 형식 유지) summary_file = os.path.join(current_dir, "summary.txt") with open(summary_file, 'w', encoding='utf-8') as f: f.write("Current Evaluation Summary\n") f.write(f"Problem ID: {problem.get('task_id', 'unknown')}\n") f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n") f.write("=" * 80 + "\n\n") f.write("OVERALL STATISTICS:\n") f.write("=" * 80 + "\n") f.write(f"Total Attempts: {baseline_results['total_rounds']}\n") f.write(f"Successful Attempts: {baseline_results['success_count']}\n") f.write(f"Success Rate: {baseline_results['average_accuracy']:.3f}\n") status = "✅ SUCCESS" if baseline_results['average_accuracy'] > 0 else "❌ FAILED" f.write(f"Evaluation Status: {status}\n\n") # Individual attempt files 리스트 attempt_files = [f"attempt_{i+1}.txt" for i in range(baseline_results['total_rounds'])] f.write(f"Individual attempt files: {', '.join(attempt_files)}\n") self.ttrlvr_logger.log_info(f"✅ Baseline results saved to {current_dir}") def _extract_function_code(self, code: str) -> str: """코드에서 함수 정의와 필요한 import 추출 (TTRLVR solution_generator.py 복사)""" import re lines = code.strip().split('\n') import_lines = [] func_lines = [] in_function = False indent_level = 0 # 1. 
import 문 수집 for line in lines: stripped = line.strip() if (stripped.startswith('import ') or stripped.startswith('from ')) and not stripped.startswith('#'): import_lines.append(line) # 2. 함수 정의 찾기 for line in lines: if line.strip().startswith('def '): in_function = True func_lines = [line] # 첫 줄의 들여쓰기 레벨 저장 indent_level = len(line) - len(line.lstrip()) elif in_function: # 빈 줄이거나 같은/더 깊은 들여쓰기면 함수의 일부 if not line.strip() or (line.strip() and len(line) - len(line.lstrip()) > indent_level): func_lines.append(line) else: # 함수 끝 break # 3. import + function 결합 if func_lines: result_lines = import_lines + [''] + func_lines if import_lines else func_lines return '\n'.join(result_lines) else: return code def _evaluate_solution(self, problem: Dict[str, Any], solution: str) -> Dict[str, Any]: """LLM 솔루션을 벤치마크 테스트로 평가 (TTRLVR solution_generator.py evaluate_solution 완전 복사)""" try: # EvalPlus 함수들 임포트 (pip으로 설치된 버전 사용) self.ttrlvr_logger.log_info("🔄 Attempting to import EvalPlus...") from evalplus.evaluate import check_correctness from evalplus.gen.util import trusted_exec from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS from evalplus.eval import PASS self.ttrlvr_logger.log_info("✅ Using EvalPlus for evaluation") except ImportError as e: # EvalPlus가 없으면 오류로 처리 (fallback 제거) self.ttrlvr_logger.log_error(f"❌ EvalPlus is required but not available: {e}") import traceback self.ttrlvr_logger.log_error(f"📋 Import traceback: {traceback.format_exc()}") return { 'correct': False, 'passed_tests': 0, 'total_tests': 0, 'error': f"EvalPlus import failed: {e}. 
Please install EvalPlus properly.", 'execution_results': [], 'base_passed': 0, 'plus_passed': 0, 'base_total': 0, 'plus_total': 0 } except Exception as e: self.ttrlvr_logger.log_error(f"❌ EvalPlus import failed with unexpected error: {e}") return { 'correct': False, 'passed_tests': 0, 'total_tests': 0, 'error': f"EvalPlus import error: {e}", 'execution_results': [], 'base_passed': 0, 'plus_passed': 0, 'base_total': 0, 'plus_total': 0 } result = { 'correct': False, 'passed_tests': 0, 'total_tests': 0, 'error': None, 'execution_results': [], 'base_passed': 0, 'plus_passed': 0, 'base_total': 0, 'plus_total': 0 } try: # 1. 함수 정의 추출 extracted_code = self._extract_function_code(solution) if not extracted_code: result['error'] = "No function definition found" return result # 2. 데이터셋 타입 결정 task_id = problem.get('task_id', '') if task_id.startswith('Mbpp'): dataset = 'mbpp' elif task_id.startswith('HumanEval'): dataset = 'humaneval' else: # 기본값 dataset = 'mbpp' # 3. expected outputs 생성 (canonical solution 사용) entry_point = problem.get('entry_point', '') canonical_solution = problem.get('canonical_solution', '') if not canonical_solution: result['error'] = "No canonical_solution found" return result # Expected outputs 계산 expected_output = {} # Base tests base_inputs = problem.get('base_input', []) if base_inputs: expected_output['base'], expected_output['base_time'] = trusted_exec( problem.get('prompt', '') + canonical_solution, base_inputs, entry_point, record_time=True, output_not_none=entry_point in MBPP_OUTPUT_NOT_NONE_TASKS ) # Plus tests plus_inputs = problem.get('plus_input', []) if plus_inputs: expected_output['plus'], expected_output['plus_time'] = trusted_exec( problem.get('prompt', '') + canonical_solution, plus_inputs, entry_point, record_time=True, output_not_none=entry_point in MBPP_OUTPUT_NOT_NONE_TASKS ) # 4. 
EvalPlus check_correctness 호출 evalplus_result = check_correctness( dataset=dataset, completion_id=0, problem=problem, solution=extracted_code, expected_output=expected_output, base_only=False, # Plus tests도 실행 fast_check=False, # 모든 테스트 실행 identifier=task_id ) # 5. 결과 파싱 if 'base' in evalplus_result: base_stat, base_details = evalplus_result['base'] result['base_total'] = len(base_inputs) if base_stat == PASS: result['base_passed'] = result['base_total'] else: result['base_passed'] = sum(1 for d in base_details if d) if base_details else 0 result['passed_tests'] += result['base_passed'] result['total_tests'] += result['base_total'] if 'plus' in evalplus_result: plus_stat, plus_details = evalplus_result['plus'] result['plus_total'] = len(plus_inputs) if plus_stat == PASS: result['plus_passed'] = result['plus_total'] else: result['plus_passed'] = sum(1 for d in plus_details if d) if plus_details else 0 result['passed_tests'] += result['plus_passed'] result['total_tests'] += result['plus_total'] # EvalPlus 기준: 모든 테스트 통과해야 correct result['correct'] = (result['passed_tests'] == result['total_tests']) and result['total_tests'] > 0 # 에러 메시지 설정 if not result['correct']: if 'base' in evalplus_result: base_stat, _ = evalplus_result['base'] if base_stat != PASS: result['error'] = f"Base tests failed: {base_stat}" if 'plus' in evalplus_result: plus_stat, _ = evalplus_result['plus'] if plus_stat != PASS and not result['error']: result['error'] = f"Plus tests failed: {plus_stat}" # 로깅 self.ttrlvr_logger.log_info(f"EvalPlus evaluation for {task_id}:") self.ttrlvr_logger.log_info(f" Base: {result['base_passed']}/{result['base_total']}") self.ttrlvr_logger.log_info(f" Plus: {result['plus_passed']}/{result['plus_total']}") self.ttrlvr_logger.log_info(f" Total: {result['passed_tests']}/{result['total_tests']}") self.ttrlvr_logger.log_info(f" Correct: {result['correct']}") except Exception as e: result['error'] = f"Evaluation failed: {str(e)}" import traceback 
self.ttrlvr_logger.log_info(f"Evaluation traceback: {traceback.format_exc()}") return result def _save_batch_evaluation_format(self, problem: Dict[str, Any], solution_result: Dict[str, Any], attempt_num: int): """Batch evaluation과 동일한 형식으로 상세 로그 저장""" from datetime import datetime # Current evaluation 디렉토리 생성 current_dir = os.path.join(self.ttrlvr_logger.log_dir, "current_evaluation") os.makedirs(current_dir, exist_ok=True) # attempt 파일 생성 attempt_file = os.path.join(current_dir, f"attempt_{attempt_num}.txt") timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') with open(attempt_file, 'w', encoding='utf-8') as f: # 헤더 정보 f.write(f"Current Evaluation - Attempt {attempt_num}\n") f.write(f"Problem ID: {problem.get('task_id', 'unknown')}\n") f.write(f"Benchmark: {problem.get('benchmark_name', 'unknown')}\n") f.write(f"Generated: {timestamp}\n") f.write("="*80 + "\n\n") # 1. 원본 문제 f.write("1. ORIGINAL PROBLEM:\n") f.write("="*80 + "\n") f.write(problem.get('prompt', 'No prompt available')) f.write("\n" + "="*80 + "\n\n") # 2. LLM에 들어가는 스크립트 (프롬프트) f.write("2. LLM INPUT SCRIPT (PROMPT):\n") f.write("="*80 + "\n") problem_prompt = problem.get('prompt', '') # 프롬프트 생성 (간단한 버전) full_prompt = f"""You are a Python writing assistant. Complete the following Python function. {problem_prompt} Please provide a complete implementation of the function.""" f.write(full_prompt.strip()) f.write("\n" + "="*80 + "\n\n") # 3. LLM의 응답 f.write("3. LLM RESPONSE:\n") f.write("="*80 + "\n") f.write(solution_result.get('solution', 'No solution generated')) f.write("\n" + "="*80 + "\n\n") # 4. 정답 여부 f.write("4. 
CORRECTNESS EVALUATION:\n") f.write("="*80 + "\n") # 구문 검증 f.write(f"Syntax Valid: {'✅ YES' if solution_result.get('syntax_valid', False) else '❌ NO'}\n") if solution_result.get('syntax_error'): f.write(f"Syntax Error: {solution_result['syntax_error']}\n") # 정확성 평가 evaluation = solution_result.get('evaluation') if evaluation: if evaluation.get('correct', False): f.write(f"Result: ✅ CORRECT\n") else: f.write(f"Result: ❌ INCORRECT\n") # EvalPlus 상세 결과 f.write(f"Base Tests: {evaluation.get('base_passed', 0)}/{evaluation.get('base_total', 0)} passed\n") f.write(f"Plus Tests: {evaluation.get('plus_passed', 0)}/{evaluation.get('plus_total', 0)} passed\n") f.write(f"Total: {evaluation.get('passed_tests', 0)}/{evaluation.get('total_tests', 0)} tests passed\n") if evaluation.get('error'): f.write(f"Evaluation Error: {evaluation['error']}\n") else: f.write("Result: ❌ NO EVALUATION (syntax error or evaluation failed)\n") f.write("="*80 + "\n") self.ttrlvr_logger.log_info(f"📝 Batch evaluation format saved: {attempt_file}")