diff --git "a/absolute_zero_reasoner/trainer/ppo/azr_ray_trainer.py" "b/absolute_zero_reasoner/trainer/ppo/azr_ray_trainer.py" new file mode 100644--- /dev/null +++ "b/absolute_zero_reasoner/trainer/ppo/azr_ray_trainer.py" @@ -0,0 +1,2804 @@ +import uuid +from pathlib import Path +from copy import deepcopy +from typing import List, Dict, Tuple +import random +import json +from collections import defaultdict +import threading +import gc +import os +import pickle +import ast +from datetime import datetime + +import ray +import torch +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from omegaconf import OmegaConf +import numpy as np +from verl.utils.dataset.rl_dataset import collate_fn +from verl.utils.debug import marked_timer +from verl.trainer.ppo.ray_trainer import ( + apply_kl_penalty, + compute_advantage, + compute_response_mask, + reduce_metrics, + compute_timing_metrics, + agg_loss, +) +from verl.trainer.ppo.metric_utils import _compute_response_info +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto, DataProto + +from absolute_zero_reasoner.utils.tracking import ReasonRLTracking +from absolute_zero_reasoner.data_construction.constructor import get_gen_code_io_data, get_pred_code_io_data +from absolute_zero_reasoner.trainer.ppo.reason_rl_ray_trainer import ReasonRLRayPPOTrainer +from absolute_zero_reasoner.utils.dataset.rl_dataset import RLHFDataset +from absolute_zero_reasoner.rewards.code_reward import parse_code_input_output, parse_inputs_message +from absolute_zero_reasoner.utils.code_utils.python_executor import PythonExecutor +# SandboxfusionExecutor는 필요할 때만 import (Docker 의존성 회피) +from absolute_zero_reasoner.utils.auxiliary import reflection_keywords +from absolute_zero_reasoner.utils.logging_utils.stdout import PrettyPrinter + + +seed_program = """def f(a): + return a""" + + +def create_default_dict(): + return defaultdict(int) + + +def compute_data_metrics(batch, use_critic=True, tokenizer=None): + sequence_score = batch.batch['token_level_scores'].sum(-1) + sequence_reward = batch.batch['token_level_rewards'].sum(-1) + + advantages = batch.batch['advantages'] + returns = batch.batch['returns'] + + max_response_length = batch.batch['responses'].shape[-1] + + prompt_mask = batch.batch['attention_mask'][:, :-max_response_length].bool() + response_mask = batch.batch['attention_mask'][:, -max_response_length:].bool() + + max_prompt_length = prompt_mask.size(-1) + + response_info = _compute_response_info(batch) + prompt_length = response_info['prompt_length'] + response_length = response_info['response_length'] + + valid_adv = torch.masked_select(advantages, response_mask) + valid_returns = torch.masked_select(returns, response_mask) + + if use_critic: + values = batch.batch['values'] + valid_values = torch.masked_select(values, response_mask) + return_diff_var = torch.var(valid_returns - valid_values) + return_var = torch.var(valid_returns) + + reflect_list = [] + correct_list = [] + correct_response_length = [] + incorrect_response_length = [] + for i in range(len(batch)): + data_item = batch[i] # DataProtoItem + prompt_ids = data_item.batch['prompts'] + _prompt_length = prompt_ids.shape[-1] + response_ids = data_item.batch['responses'] + valid_response_length = data_item.batch['attention_mask'][_prompt_length:].sum() + valid_response_ids = response_ids[:valid_response_length] + # decode + responses_str = tokenizer.decode(valid_response_ids, skip_special_tokens=True) + reflect = any([kw in responses_str.lower() for kw in reflection_keywords]) + reflect_list.append(reflect) + + reward = data_item.batch['token_level_rewards'].sum(-1) + correct = reward >= 1 + correct_list.append(correct) + if correct: + correct_response_length.append(valid_response_length.item()) + else: + incorrect_response_length.append(valid_response_length.item()) + + # the ratio of reflection + reflect_ratio = sum(reflect_list) / len(reflect_list) \ + if len(reflect_list) > 0 else 0 + # the ratio of correct response in relfection samples + correct_ratio = sum([reflect_list[i] and correct_list[i] for i in range(len(reflect_list))]) / \ + sum(reflect_list) if sum(reflect_list) > 0 else 0 + + # separate lengths + length_metrics = {} + if len(correct_response_length) > 0: + length_metrics['correct_response_length/mean'] = sum(correct_response_length) / len(correct_response_length) + if len(incorrect_response_length) > 0: + length_metrics['incorrect_response_length/mean'] = sum(incorrect_response_length) / len(incorrect_response_length) + + metrics = { + # score + 'critic/score/mean': + torch.mean(sequence_score).detach().item(), + 'critic/score/max': + torch.max(sequence_score).detach().item(), + 'critic/score/min': + torch.min(sequence_score).detach().item(), + # reward + 'critic/rewards/mean': + torch.mean(sequence_reward).detach().item(), + 'critic/rewards/max': + torch.max(sequence_reward).detach().item(), + 'critic/rewards/min': + torch.min(sequence_reward).detach().item(), + # adv + 'critic/advantages/mean': + torch.mean(valid_adv).detach().item(), + 'critic/advantages/max': + torch.max(valid_adv).detach().item(), + 'critic/advantages/min': + torch.min(valid_adv).detach().item(), + # returns + 'critic/returns/mean': + torch.mean(valid_returns).detach().item(), + 'critic/returns/max': + torch.max(valid_returns).detach().item(), + 'critic/returns/min': + torch.min(valid_returns).detach().item(), + **({ + # values + 'critic/values/mean': torch.mean(valid_values).detach().item(), + 'critic/values/max': torch.max(valid_values).detach().item(), + 'critic/values/min': torch.min(valid_values).detach().item(), + # vf explained var + 'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), + } if use_critic else {}), + + # response length + 'response_length/mean': + torch.mean(response_length).detach().item(), + 'response_length/max': + torch.max(response_length).detach().item(), + 'response_length/min': + torch.min(response_length).detach().item(), + 'response_length/clip_ratio': + torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(), + "response_length/reflect_ratio": reflect_ratio, + "response_length/correct_reflect_ratio": correct_ratio, + **length_metrics, + # prompt length + 'prompt_length/mean': + torch.mean(prompt_length).detach().item(), + 'prompt_length/max': + torch.max(prompt_length).detach().item(), + 'prompt_length/min': + torch.min(prompt_length).detach().item(), + 'prompt_length/clip_ratio': + torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), + } + return metrics + + +# Create a local function to process elements before sending to manager +def process_elements(entries): + """Process element types locally before sending to manager""" + processed = [] + for entry in entries: + entry_copy = entry.copy() + if 'input' in entry: + try: + input_type = determine_type(entry['input']) + entry_copy['_input_type'] = input_type + except: + entry_copy['_input_type'] = "str" + + if 'output' in entry: + try: + output_type = determine_type(entry['output']) + entry_copy['_output_type'] = output_type + except: + entry_copy['_output_type'] = "str" + + if 'inputs' in entry: + try: + entry_copy['_input_types'] = [determine_type(inp) for inp in entry['inputs']] + except: + entry_copy['_input_types'] = ["str"] * len(entry['inputs']) + + if 'outputs' in entry: + try: + entry_copy['_output_types'] = [determine_type(out) for out in entry['outputs']] + except: + entry_copy['_output_types'] = ["str"] * len(entry['outputs']) + + processed.append(entry_copy) + return processed + +def determine_type(element): + """Determine type safely without eval""" + try: + # Handle potential tuple strings without parentheses + if isinstance(element, str) and ',' in element: + # Attempt to parse as tuple by wrapping in parentheses + try: + wrapped = f'({element})' + parsed_tuple = ast.literal_eval(wrapped) + if isinstance(parsed_tuple, tuple): + return 'tuple' + except: + pass # Proceed to normal parsing + + # Try using ast.literal_eval for safety + parsed = ast.literal_eval(element) + if is_pickleable(parsed): + return type(parsed).__name__ + else: + return "str" + except: + return "str" + + +def is_pickleable(obj): + try: + pickle.dumps(obj) + return True + except (pickle.PicklingError, TypeError, AttributeError): + return False + + +def save_programs_to_file(programs, problem_type, global_step, save_path, data_type="valid"): + """ + Save programs to JSONL file with metadata + + Args: + programs: List of program dictionaries + problem_type: Type of problem (gen_code_i, gen_code_o, etc.) + global_step: Current training step + save_path: Base path for saving + data_type: "valid" or "invalid" + """ + if not programs or not save_path: + return + + # Create directory structure + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + save_dir = os.path.join(save_path, problem_type, data_type) + os.makedirs(save_dir, exist_ok=True) + + # Create filename with step and timestamp + filename = f"step_{global_step:06d}_{timestamp}.jsonl" + file_path = os.path.join(save_dir, filename) + + # Save programs with metadata + try: + with open(file_path, 'w', encoding='utf-8') as f: + for i, program in enumerate(programs): + program_data = { + **program, + 'meta': { + 'global_step': global_step, + 'timestamp': timestamp, + 'program_id': f"{global_step}_{problem_type}_{data_type}_{i}", + 'problem_type': problem_type, + 'data_type': data_type, + 'save_time': datetime.now().isoformat() + } + } + f.write(json.dumps(program_data, ensure_ascii=False) + '\n') + + print(f"💾 Saved {len(programs)} {data_type} {problem_type} programs to {file_path}") + return file_path + except Exception as e: + print(f"❌ Failed to save {data_type} {problem_type} programs: {e}") + return None + + +@ray.remote +class DatasetManager: + def __init__(self): + self.datasets = { + 'input': [], # Stores only data entries + 'output': [], # Stores only data entries + 'seed': [], + 'error': [], + 'problem': [], + 'error_seed': [], + 'input_steps': [], # Parallel list storing step numbers + 'output_steps': [], # Parallel list storing step numbers + 'error_steps': [], # Parallel list storing step numbers + 'problem_steps': [], # Parallel list storing step numbers + 'input_steps_counter': defaultdict(int), + 'output_steps_counter': defaultdict(int), + 'error_steps_counter': defaultdict(int), + 'problem_steps_counter': defaultdict(int), + } + self.type_counters = { + 'input_types': defaultdict(create_default_dict), + 'output_types': defaultdict(create_default_dict), + 'error_types': defaultdict(create_default_dict), + } + self.locks = { + 'input': threading.Lock(), + 'output': threading.Lock(), + 'seed': threading.Lock(), + 'error': threading.Lock(), + 'problem': threading.Lock(), + 'error_seed': threading.Lock(), + 'input_steps': threading.Lock(), + 'output_steps': threading.Lock(), + 'error_steps': threading.Lock(), + 'problem_steps': threading.Lock(), + 'input_steps_counter': threading.Lock(), + 'output_steps_counter': threading.Lock(), + 'error_steps_counter': threading.Lock(), + 'problem_steps_counter': threading.Lock(), + 'input_types': threading.RLock(), + 'output_types': threading.RLock(), + 'error_types': threading.RLock(), + } + + def update_seed(self, entries): + with self.locks['seed']: + existing = {json.dumps(d, sort_keys=True): True for d in self.datasets['seed']} + new_entries = [e for e in entries if json.dumps(e, sort_keys=True) not in existing] + + for entry in new_entries: + if 'input' in entry and '_input_type' in entry: + self.count_element(entry['input'], entry['_input_type'], 'input') + if 'output' in entry and '_output_type' in entry: + self.count_element(entry['output'], entry['_output_type'], 'output') + + self.datasets['seed'].extend(new_entries) + return len(new_entries) + + def update_error_seed(self, entries): + with self.locks['error_seed'], self.locks['error_types']: + existing = {json.dumps(d, sort_keys=True): True for d in self.datasets['error_seed']} + new_entries = [e for e in entries if json.dumps(e, sort_keys=True) not in existing] + + # Process using pre-computed types + for entry in new_entries: + if 'output' in entry and '_output_type' in entry: + self.count_element(entry['output'], entry['_output_type'], 'error') + + self.datasets['error_seed'].extend(new_entries) + return len(new_entries) + + def get_dataset(self, name) -> List[Dict]: + """Returns only the data entries without step information""" + return deepcopy(self.datasets[name]) + + def get_all_datasets(self) -> Dict[str, List[Dict]]: + """Returns all datasets without step information""" + return { + 'input': deepcopy(self.datasets['input']), + 'output': deepcopy(self.datasets['output']), + 'seed': deepcopy(self.datasets['seed']), + 'error': deepcopy(self.datasets['error']), + 'problem': deepcopy(self.datasets['problem']), + 'error_seed': deepcopy(self.datasets['error_seed']), + 'input_steps': deepcopy(self.datasets['input_steps']), + 'output_steps': deepcopy(self.datasets['output_steps']), + 'error_steps': deepcopy(self.datasets['error_steps']), + 'problem_steps': deepcopy(self.datasets['problem_steps']), + 'input_steps_counter': deepcopy(self.datasets['input_steps_counter']), + 'output_steps_counter': deepcopy(self.datasets['output_steps_counter']), + 'error_steps_counter': deepcopy(self.datasets['error_steps_counter']), + 'problem_steps_counter': deepcopy(self.datasets['problem_steps_counter']), + } + + def add_input_batch(self, entries: List[Dict], global_step: int): + with self.locks['input'], self.locks['input_steps'], self.locks['input_types']: + for entry in entries: + if 'input' in entry and '_input_type' in entry: + self.count_element(entry['input'], entry['_input_type'], 'input') + + self.datasets['input'].extend(entries) + self.datasets['input_steps'].extend([global_step]*len(entries)) + self.datasets['input_steps_counter'][global_step] += len(entries) + return len(self.datasets['input']) + + def add_output_batch(self, entries: List[Dict], global_step: int): + with self.locks['output'], self.locks['output_steps'], self.locks['output_types']: + for entry in entries: + if 'output' in entry and '_output_type' in entry: + self.count_element(entry['output'], entry['_output_type'], 'output') + + self.datasets['output'].extend(entries) + self.datasets['output_steps'].extend([global_step]*len(entries)) + self.datasets['output_steps_counter'][global_step] += len(entries) + return len(self.datasets['output']) + + def add_error_batch(self, entries: List[Dict], global_step: int): + with self.locks['error'], self.locks['error_steps'], self.locks['error_types'], self.locks['error_types']: + for entry in entries: + if 'output' in entry and '_output_type' in entry: + self.count_element(entry['output'], entry['_output_type'], 'error') + + self.datasets['error'].extend(entries) + self.datasets['error_steps'].extend([global_step]*len(entries)) + self.datasets['error_steps_counter'][global_step] += len(entries) + return len(self.datasets['error']) + + def add_error_seed_batch(self, entries: List[Dict], global_step: int): + with self.locks['error_seed'], self.locks['error_steps']: + for entry in entries: + if 'output' in entry and '_output_type' in entry: + self.count_element(entry['output'], entry['_output_type'], 'error') + + self.datasets['error_seed'].extend(entries) + self.datasets['error_steps'].extend([global_step]*len(entries)) + self.datasets['error_steps_counter'][global_step] += len(entries) + return len(self.datasets['error_seed']) + + def add_problem_batch(self, entries: List[Dict], global_step: int): + with self.locks['problem'], self.locks['problem_steps'], self.locks['problem_steps_counter']: + for entry in entries: + if 'inputs' in entry and '_input_types' in entry: + for inp, inp_type in zip(entry['inputs'], entry['_input_types']): + self.count_element(inp, inp_type, 'input') + if 'outputs' in entry and '_output_types' in entry: + for out, out_type in zip(entry['outputs'], entry['_output_types']): + self.count_element(out, out_type, 'output') + + self.datasets['problem'].extend(entries) + self.datasets['problem_steps'].extend([global_step]*len(entries)) + self.datasets['problem_steps_counter'][global_step] += len(entries) + return len(entries) + + def get_snippets(self) -> List[Dict]: + # get the snippets from input and output datasets merged together + snippets = [] + if self.datasets['input'] or self.datasets['output']: + for d in self.datasets['input']: + snippets.append({'snippet': d['snippet'], 'original_snippet': d['original_snippet'], 'imports': d['imports']}) + for d in self.datasets['output']: + snippets.append({'snippet': d['snippet'], 'original_snippet': d['original_snippet'], 'imports': d['imports']}) + return list(snippets) + else: # we are in the seed stage + for d in self.datasets['seed']: + snippets.append({'snippet': d['snippet'], 'original_snippet': d['original_snippet'], 'imports': d['imports']}) + return list(snippets) + + def get_snippets_with_steps(self) -> List[Tuple[Dict, int]]: + snippets = self.get_snippets() + return list(zip(snippets, self.datasets['input_steps'] + self.datasets['output_steps'])) + + def get_recent_additions(self, dataset_key: str, current_step: int, window: int) -> int: + counter_key = f"{dataset_key}_steps_counter" + with self.locks[counter_key]: + # Get steps from the counter dictionary instead of list + recent_steps = [ + step for step in self.datasets[counter_key].keys() + if (current_step - step) <= window + ] + total_recent = sum( + self.datasets[counter_key][step] + for step in recent_steps + ) + return total_recent + + def get_dataset_with_steps(self, name) -> List[Tuple[Dict, int]]: + if name == 'input': + assert len(self.datasets['input']) == len(self.datasets['input_steps']), \ + "Input data/steps mismatch!" + return list(zip(deepcopy(self.datasets['input']), self.datasets['input_steps'])) + elif name == 'output': + assert len(self.datasets['output']) == len(self.datasets['output_steps']), \ + "Output data/steps mismatch!" + return list(zip(deepcopy(self.datasets['output']), self.datasets['output_steps'])) + elif name == 'error': + assert len(self.datasets['error']) == len(self.datasets['error_steps']), \ + "Error data/steps mismatch!" + return list(zip(deepcopy(self.datasets['error']), self.datasets['error_steps'])) + elif name == 'problem': + assert len(self.datasets['problem']) == len(self.datasets['problem_steps']), \ + "Problem data/steps mismatch!" + return list(zip(deepcopy(self.datasets['problem']), self.datasets['problem_steps'])) + raise ValueError(f"Invalid dataset name: {name}") + + def get_steps_dataset(self, name) -> List[int]: + if name == 'input': + return self.datasets['input_steps'] + elif name == 'output': + return self.datasets['output_steps'] + elif name == 'error': + return self.datasets['error_steps'] + elif name == 'problem': + return self.datasets['problem_steps'] + raise ValueError(f"Invalid dataset name: {name}") + + def truncate_datasets(self, max_length: int, name: str) -> Tuple[int, int]: + if name == 'input': + with self.locks['input'], self.locks['input_steps']: + before_length = len(self.datasets['input']) + self.datasets['input'] = self.datasets['input'][:max_length] + self.datasets['input_steps'] = self.datasets['input_steps'][:max_length] + truncated_length = before_length - len(self.datasets['input']) + return truncated_length, before_length + elif name == 'output': + with self.locks['output'], self.locks['output_steps']: + before_length = len(self.datasets['output']) + self.datasets['output'] = self.datasets['output'][:max_length] + self.datasets['output_steps'] = self.datasets['output_steps'][:max_length] + truncated_length = before_length - len(self.datasets['output']) + return truncated_length, before_length + elif name == 'seed': + with self.locks['seed']: + before_length = len(self.datasets['seed']) + self.datasets['seed'] = self.datasets['seed'][:max_length] + truncated_length = before_length - len(self.datasets['seed']) + return truncated_length, before_length + elif name == 'error': + with self.locks['error']: + before_length = len(self.datasets['error']) + self.datasets['error'] = self.datasets['error'][:max_length] + truncated_length = before_length - len(self.datasets['error']) + return truncated_length, before_length + elif name == 'error_seed': + with self.locks['error_seed']: + before_length = len(self.datasets['error_seed']) + self.datasets['error_seed'] = self.datasets['error_seed'][:max_length] + truncated_length = before_length - len(self.datasets['error_seed']) + return truncated_length, before_length + elif name == 'problem': + with self.locks['problem']: + before_length = len(self.datasets['problem']) + self.datasets['problem'] = self.datasets['problem'][:max_length] + truncated_length = before_length - len(self.datasets['problem']) + return truncated_length, before_length + else: + raise ValueError(f"Invalid dataset name: {name}") + + def get_dataset_size(self, name: str) -> int: + with self.locks[name]: + return len(self.datasets[name]) + + def full_load_datasets(self, datasets): + """Load all datasets from a dictionary""" + self.datasets = datasets + + def full_load_data_with_type_counters(self, data: Dict): + """Load datasets and type counters""" + # First create a copy of the current empty structure + default_structure = { + 'input': [], 'output': [], 'seed': [], 'error': [], 'problem': [], + 'error_seed': [], 'input_steps': [], 'output_steps': [], 'error_steps': [], + 'problem_steps': [], 'input_steps_counter': defaultdict(int), + 'output_steps_counter': defaultdict(int), 'error_steps_counter': defaultdict(int), + 'problem_steps_counter': defaultdict(int) + } + + # Extract datasets + datasets_only = {k: v for k, v in data.items() if k != 'type_counters'} + + # Merge loaded data with default structure + merged_datasets = default_structure.copy() + merged_datasets.update(datasets_only) + + # Set the merged result + self.datasets = merged_datasets + + # Then load type counters if available + if 'type_counters' in data: + with self.locks['input_types'], self.locks['output_types'], self.locks['error_types']: + for counter_key in ['input_types', 'output_types', 'error_types']: + if counter_key in data['type_counters']: + self.type_counters[counter_key] = defaultdict(create_default_dict) + for type_name, values in data['type_counters'][counter_key].items(): + for value, count in values.items(): + self.type_counters[counter_key][type_name][value] = count + + def get_type_statistics(self, counter_key): + """Get statistics about the types and their counts.""" + with self.locks[counter_key]: + return { + type_name: { + "total_unique": len(values), + "total_count": sum(values.values()), + "examples": list(values.keys())[:5] # First 5 examples + } + for type_name, values in self.type_counters[counter_key].items() + } + + def get_all_type_statistics(self): + """Get all type statistics for inputs, outputs, and errors.""" + return { + 'input_types': self.get_type_statistics('input_types'), + 'output_types': self.get_type_statistics('output_types'), + 'error_types': self.get_type_statistics('error_types') + } + + def get_all_data_with_type_counters(self) -> Dict: + """Returns all datasets and type counters""" + all_data = self.get_all_datasets() + all_data.update({ + 'type_counters': { + 'input_types': deepcopy(self.type_counters['input_types']), + 'output_types': deepcopy(self.type_counters['output_types']), + 'error_types': deepcopy(self.type_counters['error_types']), + } + }) + return all_data + + def get_type_counter(self, counter_key): + counter_type = f"{counter_key}_types" + with self.locks[counter_type]: + return self.type_counters[counter_type] + + def count_element(self, element, element_type, counter_key): + counter_type = f"{counter_key}_types" + with self.locks[counter_type]: + self.type_counters[counter_type][element_type][element] += 1 + + +class CodeIORayPPOTrainer(ReasonRLRayPPOTrainer): + _supported_tasks = {'code_i', 'code_o', 'code_e', 'code_f'} + def __init__(self, past_epoch_window: int = 10, *args, **kwargs): + # TTRLVR integration 체크 - super().__init__() 호출 전에 설정 + config = kwargs.get('config', args[0] if args else None) + self._use_ttrlvr_rewards = getattr(config.azr, 'use_ttrlvr_rewards', False) if config else False + + super().__init__(*args, **kwargs) + assert self.config.actor_rollout_ref.rollout.n == 1, "CodeIO only supports n=1 for now" + assert all(problem_type in self._supported_tasks for problem_type in self.config.azr.problem_types), \ + f"Invalid problem type: {self.config.azr.problem_types}" + self._past_epoch_window = past_epoch_window + + # TTRLVR processor 초기화 + if self._use_ttrlvr_rewards: + from .ttrlvr_azr_integration import TTRLVRAZRDataProcessor + self.ttrlvr_processor = TTRLVRAZRDataProcessor(self.tokenizer) + PrettyPrinter.status("TTRLVR", "Using TTRLVR reward calculation", "success") + + if self.config.azr.executor == 'qwq': + self._executor = PythonExecutor( + timeout_length=self.config.azr.execute_max_timeout, + ast_check=self.config.azr.ast_check, + max_workers=self.config.azr.get('executor_max_workers', 1) + ) + elif self.config.azr.executor == 'sandboxfusion': + try: + from absolute_zero_reasoner.utils.code_utils.sandboxfusion_executor import SandboxfusionExecutor + self._executor = SandboxfusionExecutor( + timeout_length=self.config.azr.execute_max_timeout, + ast_check=self.config.azr.ast_check, + max_workers=self.config.azr.get('executor_max_workers', 1), + use_china_mirror=self.config.azr.get('use_china_mirror', True) + ) + except ImportError as e: + raise ImportError(f"SandboxfusionExecutor requires Docker. Please install Docker or use executor='qwq'. Error: {e}") + else: + raise ValueError(f'Invalid executor: {self.config.azr.executor}') + self.dataset_manager = DatasetManager.remote() + self._last_cleanup_step = 0 + self._cleanup_frequency = self.config.azr.get('executor_cleanup_frequency', 5) + + # TTRLVR dataloader iterators + self.ttrlvr_iterators = {} + + def _get_ttrlvr_batch(self, task_type: str): + """TTRLVR task type별 배치를 가져오는 메서드""" + # Iterator가 없거나 소진된 경우 새로 생성 + if task_type not in self.ttrlvr_iterators: + if task_type in self.ttrlvr_dataloaders: + self.ttrlvr_iterators[task_type] = iter(self.ttrlvr_dataloaders[task_type]) + else: + raise ValueError(f"No dataloader for task type: {task_type}") + + try: + return next(self.ttrlvr_iterators[task_type]) + except StopIteration: + # Iterator 소진시 재생성 + self.ttrlvr_iterators[task_type] = iter(self.ttrlvr_dataloaders[task_type]) + return next(self.ttrlvr_iterators[task_type]) + + def _create_dataloader(self): + """ + TTRLVR 데이터를 사용하는 경우 TTRLVRDataset 사용 + """ + if self._use_ttrlvr_rewards: + from absolute_zero_reasoner.utils.dataset.ttrlvr_dataset import TTRLVRDataset + from torch.utils.data import DataLoader, RandomSampler, SequentialSampler + + # Task type별 dataloader 생성 + self.ttrlvr_dataloaders = {} + + # TTRLVR 데이터 경로가 디렉토리인지 확인 + import os + from omegaconf import ListConfig + + train_files = self.config.data.train_files + + # OmegaConf ListConfig를 일반 리스트로 변환 + if isinstance(train_files, ListConfig): + train_files = list(train_files) + + # train_files가 리스트인 경우 첫 번째 파일의 디렉토리 사용 + if isinstance(train_files, (list, tuple)) and len(train_files) > 0: + data_dir = os.path.dirname(str(train_files[0])) + else: + data_dir = str(train_files) + + if os.path.isdir(data_dir): + task_types = ['induction', 'deduction', 'abduction'] + + for task_type in task_types: + task_file = os.path.join(data_dir, f"{task_type}.parquet") + if os.path.exists(task_file): + # 각 task type별 데이터셋 생성 + task_dataset = TTRLVRDataset( + parquet_files=task_file, + tokenizer=self.tokenizer, + max_prompt_length=self.config.data.max_prompt_length, + filter_prompts=True, + return_raw_chat=self.config.data.get('return_raw_chat', False), + truncation='error', + extra_source_key=f"ttrlvr_{task_type}", + task_type=task_type # task type 명시 + ) + + # Sampler 설정 + if self.config.data.shuffle: + generator = torch.Generator() + generator.manual_seed(self.config.data.get('seed', 1) + hash(task_type) % 1000) + sampler = RandomSampler(data_source=task_dataset, generator=generator) + else: + sampler = SequentialSampler(data_source=task_dataset) + + # DataLoader 생성 + self.ttrlvr_dataloaders[task_type] = DataLoader( + dataset=task_dataset, + batch_size=self.config.data.train_batch_size, + drop_last=True, + collate_fn=self._ttrlvr_collate_fn, + sampler=sampler + ) + + PrettyPrinter.status("TTRLVR", + f"Created dataloader for {task_type} with {len(task_dataset)} samples", + "success") + else: + # 단일 파일인 경우 기존 방식 사용 + self.train_dataset = TTRLVRDataset( + parquet_files=self.config.data.train_files, + tokenizer=self.tokenizer, + max_prompt_length=self.config.data.max_prompt_length, + filter_prompts=True, + return_raw_chat=self.config.data.get('return_raw_chat', False), + truncation='error', + extra_source_key="ttrlvr_train" + ) + + # Sampler 설정 + if self.config.data.shuffle: + train_dataloader_generator = torch.Generator() + train_dataloader_generator.manual_seed(self.config.data.get('seed', 1)) + sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator) + else: + sampler = SequentialSampler(data_source=self.train_dataset) + + self.train_dataloader = DataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.train_batch_size, + drop_last=True, + collate_fn=self._ttrlvr_collate_fn, + sampler=sampler + ) + + # Validation 데이터로더 (필요시) + if self.config.data.val_files: + self.val_dataset = TTRLVRDataset( + parquet_files=self.config.data.val_files, + tokenizer=self.tokenizer, + max_prompt_length=self.config.data.max_prompt_length, + filter_prompts=True, + return_raw_chat=self.config.data.get('return_raw_chat', False), + truncation='error', + extra_source_key="ttrlvr_val" + ) + + self.val_dataloader = DataLoader( + dataset=self.val_dataset, + batch_size=self.config.data.val_batch_size, + drop_last=False, + collate_fn=self._ttrlvr_collate_fn + ) + else: + self.val_dataloader = None + + # total_training_steps 설정 (부모 클래스와 동일한 로직) + # TTRLVR은 각 task type별 dataloader의 최소 길이 사용 (모든 task에서 배치를 가져와야 하므로) + if hasattr(self, 'ttrlvr_dataloaders') and self.ttrlvr_dataloaders: + # 가장 작은 dataloader 길이가 실제 훈련 스텝 수 (모든 task에서 배치를 동시에 가져와야 함) + total_batches_per_epoch = min(len(dl) for dl in self.ttrlvr_dataloaders.values()) + total_training_steps = total_batches_per_epoch * self.config.trainer.total_epochs + else: + # 단일 dataloader 사용하는 경우 + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + # Config에 total_training_steps 설정 + try: + from omegaconf import open_dict, OmegaConf + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + if OmegaConf.select(self.config, "critic.optim"): + self.config.critic.optim.total_training_steps = total_training_steps + except Exception as e: + print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") + + else: + # 기존 AZR 데이터로더 사용 + super()._create_dataloader() + + def _ttrlvr_collate_fn(self, batch): + """TTRLVR 데이터를 위한 collate function""" + from verl.utils.dataset.rl_dataset import collate_fn + import numpy as np + + # 기본 collate 실행 + collated = collate_fn(batch) + + # TTRLVR 메타데이터 추가 + ttrlvr_metadata = [] + prompts = [] + + for item in batch: + if 'ttrlvr_metadata' in item: + ttrlvr_metadata.append(item['ttrlvr_metadata']) + prompts.append(item.get('prompt', '')) + + # DataProto가 요구하는 형식으로 변환 + # dict의 리스트를 numpy array로 변환 (dtype=object) + if ttrlvr_metadata: + collated['ttrlvr_metadata'] = np.array(ttrlvr_metadata, dtype=object) + if prompts: + collated['prompts'] = np.array(prompts, dtype=object) + + # DEBUG: 첫 번째 배치만 로그 + if not hasattr(self, '_ttrlvr_collate_logged'): + print(f"\n[_ttrlvr_collate_fn] Called with batch size: {len(batch)}") + print(f"[_ttrlvr_collate_fn] Number of prompts collected: {len(prompts)}") + print(f"[_ttrlvr_collate_fn] Added to collated: prompts={len(prompts)}, metadata={len(ttrlvr_metadata)}") + self._ttrlvr_collate_logged = True + + return collated + + def cleanup(self): + """Clean up the executor and other resources""" + if hasattr(self._executor, 'cleanup'): + PrettyPrinter.status("CLEANUP", "Cleaning up executor...", "info") + self._executor.cleanup() + # Force garbage collection + gc.collect() + + def _create_train_code_gen_dataloader( + self, + problem_type: str, + data_len: int, + dataset_key: str = None, + seeding: bool = False, + ) -> DataLoader: + if dataset_key is None: + if problem_type == 'code_i': + dataset_key = 'input' + elif problem_type == 'code_o': + dataset_key = 'output' + elif problem_type == 'code_e': + dataset_key = 'error' + elif problem_type == 'code_f': + # For code_f we use merged snippets from all datasets + io_data = ray.get(self.dataset_manager.get_snippets.remote()) + else: + raise ValueError(f'Invalid problem type: {problem_type}') + + if problem_type != 'code_f': + io_data = ray.get(self.dataset_manager.get_dataset.remote(dataset_key)) + + parquet_path = (self._code_dir / f'train_gen_{problem_type}.parquet').as_posix() + os.makedirs(os.path.dirname(parquet_path), exist_ok=True) + + # Handle weights strategy + if problem_type == 'code_f' and not seeding: + if self.config.azr.gen_data_probabilities_strategy == 'step': + entries_with_steps = ray.get(self.dataset_manager.get_snippets_with_steps.remote()) + weights = [w + 1 for _, w in entries_with_steps] if entries_with_steps else [1.0]*len(io_data) + else: + weights = [1.0] * len(io_data) + elif dataset_key == 'seed': + weights = None + elif self.config.azr.gen_data_probabilities_strategy == 'uniform': + weights = [1.0] * len(io_data) + elif self.config.azr.gen_data_probabilities_strategy == 'step': + weights = [w + 1 for w in ray.get(self.dataset_manager.get_steps_dataset.remote(dataset_key))] + else: + raise ValueError(f"Unknown strategy: {self.config.azr.gen_data_probabilities_strategy}") + + # Common parameters for get_gen_code_io_data + gen_params = { + 'io_data': io_data, + 'target_data_len': data_len, + 'problem_type': problem_type, + 'content_max_length': self.config.azr.data_selection_strategy.content_max_length, + 'io_n': 1 if problem_type == 'code_f' else self.config.azr.data_selection_strategy.io_n, + 'instruction_type': self.config.reward_fn.extraction_type, + 'output_path': parquet_path, + 'split': 'train', + 'tokenizer': self.tokenizer, + 'banned_keywords': self.config.azr.data_selection_strategy.banned_words, + 'banned_assertion_keywords': self.config.azr.data_selection_strategy.banned_keywords_for_errors_and_exceptions, + 'weights': weights, + 'enable_composite_function': self.config.azr.data_selection_strategy.composite_start_step > 0 and self.global_steps >= self.config.azr.data_selection_strategy.composite_start_step, + 'composite_function_n_min': self.config.azr.data_selection_strategy.composite_function_n_min, + 'composite_function_n_max': self.config.azr.data_selection_strategy.composite_function_n_max, + 'composite_chance': self.config.azr.data_selection_strategy.composite_chance, + 'remove_after_return': self.config.azr.reward.generation_reward_config.remove_after_return, + 'remove_input_from_snippet': self.config.azr.reward.generation_reward_config.remove_input_from_snippet, + 'include_references': self.config.azr.reward.generation_reward_config.include_references, + } + + # Add code_f specific parameters + if problem_type == 'code_f': + gen_params.update({ + 'num_inputs': self.config.azr.data_selection_strategy.num_inputs, + }) + + get_gen_code_io_data(**gen_params) + + code_gen_train_dataset = RLHFDataset( + parquet_files=parquet_path, + tokenizer=self.tokenizer, + prompt_key=self.config.data.prompt_key, + max_prompt_length=self.config.data.max_prompt_length, + filter_prompts=True, + return_raw_chat=self.config.data.get('return_raw_chat', False), + truncation='error', + extra_source_key=f"gen_{problem_type}_train" + ) + + if self.config.data.shuffle: + train_dataloader_generator = torch.Generator() + train_dataloader_generator.manual_seed(self.config.data.get('seed', 1)) + sampler = RandomSampler(code_gen_train_dataset, generator=train_dataloader_generator) + else: + sampler = SequentialSampler(code_gen_train_dataset) + + return iter(DataLoader( + dataset=code_gen_train_dataset, + batch_size=self.config.data.train_batch_size, + drop_last=True, + collate_fn=collate_fn, + sampler=sampler + )) + + def _create_train_code_pred_dataloader(self, problem_type: str, data_len: int) -> DataLoader: + if problem_type == 'code_i': + dataset_key = 'input' + elif problem_type == 'code_o': + dataset_key = 'output' + elif problem_type == 'code_e': + dataset_key = 'error' + elif problem_type == 'code_f': + dataset_key = 'problem' + else: + raise ValueError(f'Invalid problem type: {problem_type}') + full_dataset = ray.get(self.dataset_manager.get_dataset.remote(dataset_key)) + + strategy = self.config.azr.pred_data_mix_strategy + + if strategy == "step": + # Get entries with their creation steps + entries_with_steps = ray.get(self.dataset_manager.get_dataset_with_steps.remote(dataset_key)) + if not entries_with_steps: + selected_data = [] + else: + entries, steps = zip(*entries_with_steps) + # Calculate inverse step weights (newer entries get higher weight) + selected_indices = random.choices( + range(len(entries)), + weights=steps, + k=min(data_len, len(entries)) + ) + selected_data = [entries[i] for i in selected_indices] + elif strategy == "uniform_total": + selected_data = random.sample(full_dataset, min(len(full_dataset), data_len)) + elif strategy == "max_new": + total_recent = ray.get(self.dataset_manager.get_recent_additions.remote( + dataset_key, self.global_steps, self._past_epoch_window + )) + new_programs = full_dataset[-total_recent:] if total_recent > 0 else [] + new_samples = random.sample(new_programs, min(len(new_programs), data_len)) + remaining = data_len - len(new_samples) + selected_data = new_samples + random.sample(full_dataset, remaining) + elif strategy == "half_new": + total_recent = ray.get(self.dataset_manager.get_recent_additions.remote( + dataset_key, self.global_steps, self._past_epoch_window + )) + new_programs = full_dataset[-total_recent:] if total_recent > 0 else [] + new_count = min(len(new_programs), data_len//2) + base_count = data_len - new_count + selected_data = random.sample(new_programs, new_count) + random.sample(full_dataset, base_count) + else: + raise ValueError(f"Unknown strategy: {strategy}") + + parquet_path = (self._code_dir / f'train_pred_{problem_type}.parquet').as_posix() + get_pred_code_io_data( + io_data=selected_data, + target_data_len=data_len, + problem_type=problem_type, + content_max_length=self.config.azr.data_selection_strategy.content_max_length, + output_path=parquet_path, + split='train', + tokenizer=self.tokenizer, + instruction_type=self.config.reward_fn.extraction_type, + ) + code_pred_train_dataset = RLHFDataset(parquet_files=parquet_path, + tokenizer=self.tokenizer, + prompt_key=self.config.data.prompt_key, + max_prompt_length=self.config.data.max_prompt_length, + filter_prompts=True, + return_raw_chat=self.config.data.get('return_raw_chat', False), + truncation='error', + extra_source_key=f"pred_{problem_type}_train") + # use sampler for better ckpt resume + if self.config.data.shuffle: + train_dataloader_generator = torch.Generator() + train_dataloader_generator.manual_seed(self.config.data.get('seed', 1)) + sampler = RandomSampler(data_source=code_pred_train_dataset, generator=train_dataloader_generator) + else: + sampler = SequentialSampler(data_source=code_pred_train_dataset) + + code_pred_train_dataloader = DataLoader(dataset=code_pred_train_dataset, + batch_size=self.config.data.train_batch_size, + drop_last=True, + collate_fn=collate_fn, + sampler=sampler) + + assert len(code_pred_train_dataloader) >= 1 + return iter(code_pred_train_dataloader) + + def _compute_batch(self, batch: DataProto, metrics: dict, timing_raw: dict, problem_type: str, executor: PythonExecutor) -> tuple[DataProto, dict]: + PrettyPrinter.section_header(f"Computing batch for {problem_type}") + # pop those keys for generation + gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids']) + + # generate a batch + with marked_timer(f'gen/{problem_type}', timing_raw): + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + + batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], + dtype=object) + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + batch.batch["response_mask"] = compute_response_mask(batch) + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + # TODO: Decouple the DP balancing and mini-batching. + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist() + + # recompute old_log_probs + with marked_timer(f'old_log_prob/{problem_type}', timing_raw): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode + entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + rollout_old_log_probs = batch.batch["rollout_log_probs"] + actor_old_log_probs = batch.batch["old_log_probs"] + attention_mask = batch.batch["attention_mask"] + responses = batch.batch["responses"] + response_length = responses.size(1) + response_mask = attention_mask[:, -response_length:] + + rollout_probs = torch.exp(rollout_old_log_probs) + actor_probs = torch.exp(actor_old_log_probs) + rollout_probs_diff = torch.abs(rollout_probs - actor_probs) + rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool()) + rollout_probs_diff_max = torch.max(rollout_probs_diff) + rollout_probs_diff_mean = torch.mean(rollout_probs_diff) + rollout_probs_diff_std = torch.std(rollout_probs_diff) + metrics.update( + { + "training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(), + "training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(), + "training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(), + } + ) + + if self.use_reference_policy: + with marked_timer(f'ref/{problem_type}', timing_raw): + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with marked_timer(f'values/{problem_type}', timing_raw): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with marked_timer(f'adv/{problem_type}', timing_raw): + if self.use_rm: + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + input_type_counters, output_type_counters, error_type_counters = None, None, None + # Get the appropriate type counters based on problem type + if problem_type == 'gen_code_i': + input_type_counters = ray.get(self.dataset_manager.get_type_counter.remote('input')) + elif problem_type == 'gen_code_o': + output_type_counters = ray.get(self.dataset_manager.get_type_counter.remote('output')) + elif problem_type == 'gen_code_e': + error_type_counters = ray.get(self.dataset_manager.get_type_counter.remote('error')) + elif problem_type == 'gen_code_f': + input_type_counters = ray.get(self.dataset_manager.get_type_counter.remote('input')) + output_type_counters = ray.get(self.dataset_manager.get_type_counter.remote('output')) + + # make sure actor_rollout_wg n > 1 + if problem_type.startswith('gen'): + reward_fn_kwargs = { + 'data': batch, + 'problem_type': problem_type, + 'executor': executor, # need this to check for execution errors + 'rollout_actor_wg': self.actor_rollout_wg, # need this to estimate difficulty reward + 'banned_words': self.config.azr.data_selection_strategy.banned_words, # need this to check for banned words + 'n_samples': self.config.azr.reward.n_samples, + 'input_type_counters': input_type_counters, + 'output_type_counters': output_type_counters, + 'error_type_counters': error_type_counters, + } + elif problem_type.startswith('pred'): + reward_fn_kwargs = { + 'data': batch, + 'problem_type': problem_type, + 'executor': executor, + } + # For prediction tasks, initialize empty invalid_programs + invalid_programs = [] + else: + # TTRLVR problem types (induction, deduction, abduction) 또는 기타 + reward_fn_kwargs = { + 'data': batch, + 'problem_type': problem_type, + 'executor': executor, + } + invalid_programs = [] + with marked_timer(f'reward_fn/{problem_type}', timing_raw): + PrettyPrinter.status("REWARD", f"Computing rewards for {problem_type}...", "info") + + # TTRLVR reward 사용 여부 확인 + # TTRLVR problem types: induction, deduction, abduction + is_ttrlvr_problem = problem_type in ['induction', 'deduction', 'abduction'] + + # 디버깅 로그 + PrettyPrinter.status("DEBUG", f"problem_type: {problem_type}", "info") + PrettyPrinter.status("DEBUG", f"is_ttrlvr_problem: {is_ttrlvr_problem}", "info") + PrettyPrinter.status("DEBUG", f"_use_ttrlvr_rewards: {self._use_ttrlvr_rewards}", "info") + + if self._use_ttrlvr_rewards and is_ttrlvr_problem: + # TTRLVR reward 계산 + PrettyPrinter.status("TTRLVR", f"Using TTRLVR reward calculation for {problem_type}", "info") + + + # 프롬프트와 응답 추출 + # prompts와 metadata는 non_tensor_batch에 있음 + + # DEBUG: batch 구조 확인 + if self.global_steps == 1: + print(f"\n=== TTRLVR Batch Structure Debug ===") + print(f"batch.batch keys: {list(batch.batch.keys())[:10]}") + print(f"batch.non_tensor_batch keys: {list(batch.non_tensor_batch.keys())[:10]}") + print(f"'prompts' in batch.batch: {'prompts' in batch.batch}") + print(f"'prompts' in batch.non_tensor_batch: {'prompts' in batch.non_tensor_batch}") + print(f"'ttrlvr_metadata' in batch.non_tensor_batch: {'ttrlvr_metadata' in batch.non_tensor_batch}") + print("===================================\n") + + prompts = batch.non_tensor_batch.get('prompts', []) + if isinstance(prompts, np.ndarray): + prompts = prompts.tolist() + elif torch.is_tensor(prompts): + prompts = prompts.tolist() + + # responses는 batch에 있음 + if hasattr(batch, 'batch'): + print(f"\n[DEBUG] Step {self.global_steps} - Batch keys: {list(batch.batch.keys())}") + + if hasattr(batch, 'batch') and 'responses' in batch.batch: + # 디버깅: responses 텐서 상태 확인 + response_tensor = batch.batch['responses'] + print(f"\n[DEBUG] Step {self.global_steps} - Response tensor shape: {response_tensor.shape}") + print(f"[DEBUG] Response tensor dtype: {response_tensor.dtype}") + print(f"[DEBUG] Response tensor device: {response_tensor.device}") + print(f"[DEBUG] First 10 tokens of first response: {response_tensor[0][:10].tolist() if response_tensor.shape[0] > 0 else 'Empty'}") + + try: + responses = self.tokenizer.batch_decode(batch.batch['responses'], skip_special_tokens=True) + print(f"[DEBUG] Successfully decoded {len(responses)} responses") + if len(responses) > 0 and responses[0]: + print(f"[DEBUG] First response preview (first 100 chars): {responses[0][:100]}") + except Exception as e: + print(f"[DEBUG] Error decoding responses: {e}") + responses = [] + else: + print(f"[DEBUG] Step {self.global_steps} - No responses in batch") + responses = [] + + # metadata도 non_tensor_batch에 있음 + metadata = batch.non_tensor_batch.get('ttrlvr_metadata', []) + if isinstance(metadata, np.ndarray): + metadata = metadata.tolist() + elif torch.is_tensor(metadata): + metadata = metadata.tolist() + + # TTRLVR reward 계산 + rewards = self.ttrlvr_processor.compute_rewards_for_responses( + prompts=prompts, + responses=responses, + metadata=metadata + ) + + # 첫 번째 배치에서만 샘플 응답 출력 + if self.global_steps <= 3 and len(responses) > 0 and len(prompts) > 0: + PrettyPrinter.section_header(f"TTRLVR Sample Task & Response (Step {self.global_steps})") + i = 0 + meta = metadata[i] if i < len(metadata) else {} + task_type = meta.get('task_type', 'unknown') + + print(f"\n=== Step {self.global_steps} - Task Type: {task_type} ===\n") + + # 전체 프롬프트 출력 + if i < len(prompts): + prompt_str = str(prompts[i]) if isinstance(prompts[i], list) else prompts[i] + print("=" * 80) + print("FULL PROMPT:") + print("=" * 80) + print(prompt_str) + print("=" * 80) + else: + print(f"Prompt: (no prompt available, prompts length: {len(prompts)})") + + # 전체 응답 출력 + if i < len(responses): + print("\n" + "=" * 80) + print("LLM FULL RESPONSE:") + print("=" * 80) + print(responses[i]) + print("=" * 80) + else: + print(f"Response: (no response available)") + + # 보상 출력 + if i < len(rewards): + print(f"\nReward: {rewards[i]:.2f} {'✅ Correct' if rewards[i] > 0 else '❌ Wrong'}") + else: + print(f"Reward: (no reward available)") + + # 기대했던 정답도 출력 + if 'expected_solution' in meta: + print("\n" + "=" * 80) + print("EXPECTED SOLUTION:") + print("=" * 80) + print(meta['expected_solution']) + print("=" * 80) + + # tensor로 변환 + reward_tensor = torch.tensor(rewards, dtype=torch.float32, device=batch.batch['responses'].device) + reward_tensor = reward_tensor.unsqueeze(-1).expand(-1, batch.batch['responses'].size(1)) + + # 기본 메트릭 + train_metrics = { + f'{problem_type}/reward_mean': float(torch.mean(reward_tensor)), + f'{problem_type}/accuracy': sum(1 for r in rewards if r > 0) / len(rewards) if rewards else 0.0 + } + + # 더미 값들 (TTRLVR에서는 사용하지 않음) + valid_programs = [] + correct_predictions = [] + invalid_programs = [] + + else: + # 기존 AZR reward 계산 + if problem_type.startswith('gen'): + reward_tensor, train_metrics, valid_programs, correct_predictions, invalid_programs = self.reward_fn(**reward_fn_kwargs) + else: # pred type + reward_tensor, train_metrics, valid_programs, correct_predictions, invalid_programs = self.reward_fn(**reward_fn_kwargs) + PrettyPrinter.status("REWARD", f"Found {len(valid_programs) if valid_programs else 0} valid programs", "success") + + # get avg_program lines + avg_program_lines = sum(len(program['snippet'].split('\n')) for program in valid_programs) / len(valid_programs) if valid_programs else 0 + train_metrics[f'{problem_type}/avg_program_lines'] = avg_program_lines + + # Save generated data if enabled + if (self.config.azr.save_generated_data and + self.config.azr.save_data_path and + self.global_steps % self.config.azr.save_frequency == 0): + + # Save valid programs + if valid_programs and self.config.azr.save_valid_data: + save_programs_to_file( + programs=valid_programs, + problem_type=problem_type, + global_step=self.global_steps, + save_path=self.config.azr.save_data_path, + data_type="valid" + ) + + # Save invalid programs + if invalid_programs and self.config.azr.save_invalid_data: + save_programs_to_file( + programs=invalid_programs, + problem_type=problem_type, + global_step=self.global_steps, + save_path=self.config.azr.save_data_path, + data_type="invalid" + ) + + # Add statistics to metrics + train_metrics[f'{problem_type}/num_valid_saved'] = len(valid_programs) if valid_programs else 0 + train_metrics[f'{problem_type}/num_invalid_saved'] = len(invalid_programs) if invalid_programs else 0 + + # Log new programs if available + if valid_programs and self.config.azr.random_print_max_programs > 0: + PrettyPrinter.section_header(f"New {problem_type} Programs") + max_print = min(self.config.azr.random_print_max_programs, len(valid_programs)) + for program in random.sample(valid_programs, max_print): + PrettyPrinter.status(f"PROBLEM TYPE", problem_type, "info") + if 'code_f' not in problem_type: + PrettyPrinter.code_block(program['snippet'], "python") + PrettyPrinter.status("INPUT", program['input'], "info") + PrettyPrinter.status("OUTPUT", program['output'], "info") + PrettyPrinter.status("THOUGHT", program['thought'], "info") + PrettyPrinter.status("COMPOSITE FUNCTION", "YES!" if len(program['composite_functions']) > 0 else "NO!", "info") + else: + PrettyPrinter.code_block(program['snippet'], "python") + PrettyPrinter.status("INPUT", program['inputs'], "info") + PrettyPrinter.status("OUTPUT", program['outputs'], "info") + PrettyPrinter.status("MESSAGE", program['message'], "info") + PrettyPrinter.status("THOUGHT", program['thought'], "info") + print("\n" + "-"*80 + "\n") + if correct_predictions and self.config.azr.random_print_max_programs > 0: + PrettyPrinter.section_header(f"New {problem_type} Programs") + max_print = min(self.config.azr.random_print_max_programs, len(correct_predictions)) + for program in random.sample(correct_predictions, max_print): + if 'code_f' not in problem_type: + PrettyPrinter.code_block(program['program'], "python") + # also print the problem_type + PrettyPrinter.status(f"PROBLEM TYPE", problem_type, "info") + PrettyPrinter.status("INPUT", program['input'], "info") + PrettyPrinter.status("OUTPUT", program['output'], "info") + PrettyPrinter.status("THOUGHT", program['thought'], "info") + else: + PrettyPrinter.code_block(program['answer']['snippet'], "python") + PrettyPrinter.code_block(program['answer']['gold_program'], "python (gold)") + PrettyPrinter.status("HIDDEN INPUT", program['hidden_inputs'], "info") + PrettyPrinter.status("HIDDEN OUTPUT", program['hidden_outputs'], "info") + PrettyPrinter.status("GIVEN INPUT", program['given_inputs'], "info") + PrettyPrinter.status("GIVEN OUTPUT", program['given_outputs'], "info") + PrettyPrinter.status("MESSAGE", program['answer']['message'], "info") + PrettyPrinter.status("THOUGHT", program['answer']['thought'], "info") + print("\n" + "-"*80 + "\n") + + # TTRLVR problem types는 dataset manager 업데이트 불필요 + if problem_type in ['induction', 'deduction', 'abduction']: + # TTRLVR은 고정된 벤치마크 데이터를 사용하므로 dataset manager 업데이트 불필요 + PrettyPrinter.status("TTRLVR", f"Skipping dataset update for {problem_type} (benchmark data)", "info") + # ✅ TTRLVR도 token_level_scores를 설정해야 함! + batch.batch['token_level_scores'] = reward_tensor + + # 🔍 Reward matrix 로깅 + PrettyPrinter.section_header(f"TTRLVR Reward Matrix for {problem_type}") + print(f"Reward tensor shape BEFORE expansion: {torch.tensor(rewards).shape}") + print(f"Reward tensor shape AFTER expansion: {reward_tensor.shape}") + print(f"Response shape: {batch.batch['responses'].shape}") + print(f"First few rewards (sequence-level): {rewards[:4]}") + print(f"Token-level scores sum for first sample: {reward_tensor[0].sum().item()}") + print(f"Reward statistics - Mean: {reward_tensor.mean():.4f}, Std: {reward_tensor.std():.4f}, Min: {reward_tensor.min():.4f}, Max: {reward_tensor.max():.4f}") + elif problem_type.endswith('code_i'): + if valid_programs: + # Process locally first + processed_programs = process_elements(valid_programs) + # Then batch add to dataset + ray.get(self.dataset_manager.add_input_batch.remote(processed_programs, self.global_steps)) + elif problem_type.endswith('code_o'): + if valid_programs: + processed_programs = process_elements(valid_programs) + ray.get(self.dataset_manager.add_output_batch.remote(processed_programs, self.global_steps)) + elif problem_type.endswith('code_e'): + if valid_programs: + processed_programs = process_elements(valid_programs) + ray.get(self.dataset_manager.add_error_batch.remote(processed_programs, self.global_steps)) + elif problem_type.endswith('code_f'): + if valid_programs: + processed_programs = process_elements(valid_programs) + ray.get(self.dataset_manager.add_problem_batch.remote(processed_programs, self.global_steps)) + else: + raise ValueError(f'Invalid problem type: {problem_type}') + + if self.config.azr.data_selection_strategy.max_programs is not None and problem_type.endswith('code_i'): + truncated_length, before_length = ray.get(self.dataset_manager.truncate_datasets.remote(self.config.azr.data_selection_strategy.max_programs, 'input')) + PrettyPrinter.status("DATA", f"Truncated {truncated_length} programs from input dataset, max programs is {self.config.azr.data_selection_strategy.max_programs}, dataset size was {before_length} before truncation", "info") + if self.config.azr.data_selection_strategy.max_programs is not None and problem_type.endswith('code_o'): + truncated_length, before_length = ray.get(self.dataset_manager.truncate_datasets.remote(self.config.azr.data_selection_strategy.max_programs, 'output')) + PrettyPrinter.status("DATA", f"Truncated {truncated_length} programs from output dataset, max programs is {self.config.azr.data_selection_strategy.max_programs}, dataset size was {before_length} before truncation", "info") + if self.config.azr.data_selection_strategy.max_programs is not None and problem_type.endswith('code_e'): + truncated_length, before_length = ray.get(self.dataset_manager.truncate_datasets.remote(self.config.azr.data_selection_strategy.max_programs, 'error')) + PrettyPrinter.status("DATA", f"Truncated {truncated_length} programs from error dataset, max programs is {self.config.azr.data_selection_strategy.max_programs}, dataset size was {before_length} before truncation", "info") + if self.config.azr.data_selection_strategy.max_programs is not None and problem_type.endswith('code_f'): + truncated_length, before_length = ray.get(self.dataset_manager.truncate_datasets.remote(self.config.azr.data_selection_strategy.max_programs, 'problem')) + PrettyPrinter.status("DATA", f"Truncated {truncated_length} programs from problem dataset, max programs is {self.config.azr.data_selection_strategy.max_programs}, dataset size was {before_length} before truncation", "info") + + train_metrics = {f'{problem_type}/{k}': np.mean(v) for k, v in train_metrics.items()} + # log the number of valid programs added to the dataset + if problem_type.startswith('gen'): + if problem_type.endswith('code_i'): + dataset_key = 'input' + elif problem_type.endswith('code_o'): + dataset_key = 'output' + elif problem_type.endswith('code_e'): + dataset_key = 'error' + elif problem_type.endswith('code_f'): + dataset_key = 'problem' + else: + raise ValueError(f'Invalid problem type: {problem_type}') + train_metrics[f'{problem_type}/num_valid_programs'] = ray.get( + self.dataset_manager.get_recent_additions.remote( + dataset_key, self.global_steps, self._past_epoch_window + ) + ) + metrics.update(train_metrics) + batch.batch['token_level_scores'] = reward_tensor + + if self.config.algorithm.use_kl_in_reward: + batch, kl_metrics = apply_kl_penalty(batch, + kl_ctrl=self.kl_ctrl, + kl_penalty=self.config.algorithm.kl_penalty) + metrics.update(kl_metrics) + else: + batch.batch['token_level_rewards'] = batch.batch['token_level_scores'] + + batch = compute_advantage(batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + config=self.config.algorithm) + + # 🔍 Advantage 로깅 (TTRLVR에 대해서만) + if problem_type in ['induction', 'deduction', 'abduction']: + PrettyPrinter.section_header(f"TTRLVR Advantages for {problem_type}") + advantages = batch.batch.get('advantages') + if advantages is not None: + print(f"Advantages shape: {advantages.shape}") + # 첫 몇 개 샘플의 advantage 출력 + for i in range(min(4, advantages.shape[0])): + print(f"Sample {i+1} advantages: {advantages[i].tolist()[:10]}..." if advantages.shape[1] > 10 else f"Sample {i+1} advantages: {advantages[i].tolist()}") + print(f"Advantages statistics - Mean: {advantages.mean():.4f}, Std: {advantages.std():.4f}, Min: {advantages.min():.4f}, Max: {advantages.max():.4f}") + + # response_mask와 함께 확인 + response_mask = batch.batch.get('response_mask') + if response_mask is not None: + masked_advantages = advantages * response_mask + print(f"Masked advantages statistics - Mean: {masked_advantages.sum() / response_mask.sum():.4f}") + else: + print("No advantages found in batch!") + + gc.collect() + return batch, metrics + + def _init_seed_dataset(self, problem_types: List[str]) -> Tuple[List[Dict], List[Dict], List[Dict]]: + # Initialize with seed program using the coordinator + if ('code_i' in problem_types or 'code_o' in problem_types) and ray.get(self.dataset_manager.get_dataset.remote('seed')) == []: + ray.get(self.dataset_manager.update_seed.remote([ + {'snippet': seed_program, 'input': '"Hello world"', 'output': '"Hello world"', 'imports': [], 'original_snippet': seed_program, 'composite_functions': []} + ])) + if 'code_e' in problem_types and ray.get(self.dataset_manager.get_dataset.remote('error_seed')) == []: + ray.get(self.dataset_manager.update_error_seed.remote([ + {'snippet': seed_program, 'input': '"Hello world"', 'output': 'NoError', 'imports': [], 'original_snippet': seed_program, 'composite_functions': []} + ])) + if 'code_f' in problem_types and ray.get(self.dataset_manager.get_dataset.remote('problem')) == []: + ray.get(self.dataset_manager.add_problem_batch.remote([ + { + 'snippet': seed_program, + 'inputs': ['"Hello world"', '1', "dict(a=1, b=2)", '(1.1, 1.2, 1.3)', '"[[1, 0, 0], [0, 0, 0], [0, 0, 0]]"', '1001101100010001'], + 'outputs': ['"Hello world"', '1', "dict(a=1, b=2)", '(1.1, 1.2, 1.3)', '"[[1, 0, 0], [0, 0, 0], [0, 0, 0]]"', '1001101100010001'], + 'message': 'Write a function that returns whatever you input', + 'imports': [], + } + ], self.global_steps)) + + target_size = self.config.azr.data_selection_strategy.data_len * self.config.azr.data_selection_strategy.seed_batch_factor + + while problem_types != ['code_f']: # we can skip this loop if we are only generating code_f dataset + # Get current dataset state + seed_dataset = ray.get(self.dataset_manager.get_dataset.remote('seed')) + error_dataset = ray.get(self.dataset_manager.get_dataset.remote('error_seed')) + if problem_types == ['code_e'] and len(error_dataset) >= target_size: # only generate error seed dataset + break + if 'code_e' not in problem_types and len(seed_dataset) >= target_size: # only generate seed dataset + break + if len(seed_dataset) >= target_size and len(error_dataset) >= target_size: # generating both seed and error seed dataset + break + + for problem_type in problem_types: + if problem_type == 'code_f': # skip code_f dataset, we will generate it later + continue + if problem_type == 'code_e' and len(error_dataset) >= target_size: + continue + if problem_type != 'code_e' and len(seed_dataset) >= target_size: + continue + seed_dataloader = self._create_train_code_gen_dataloader( + problem_type=problem_type, + data_len=self.config.data.train_batch_size, + dataset_key='error_seed' if problem_type == 'code_e' else 'seed', + seeding=True, + ) + for batch_dict in seed_dataloader: + batch: DataProto = DataProto.from_single_dict(batch_dict) + gen_batch = batch.pop(['input_ids', 'attention_mask', 'position_ids']) + gen_batch.meta_info = { + 'eos_token_id': self.tokenizer.eos_token_id, + 'pad_token_id': self.tokenizer.pad_token_id, + 'recompute_log_prob': False, + 'do_sample': True, + 'validate': True, + } + + # pad to be divisible by dp_size + gen_batch_padded, pad_size = pad_dataproto_to_divisor(gen_batch, self.actor_rollout_wg.world_size) + output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(gen_batch_padded) + pad_size *= self.config.actor_rollout_ref.rollout.n + + # unpad + output_gen_batch = unpad_dataproto(output_gen_batch_padded, pad_size=pad_size) + + # If we're doing multiple samples, repeat the original batch + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + + batch = batch.union(output_gen_batch) + + # Store generated outputs + output_ids = batch.batch['responses'] + output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + local_entries = [] + local_error_entries = [] + for output_text in output_texts: + success, result = parse_code_input_output( + output_text, + parse_output=False, + remove_after_return=self.config.azr.reward.generation_reward_config.remove_after_return, + remove_comments=self.config.azr.reward.generation_reward_config.remove_comments, + remove_print=self.config.azr.reward.generation_reward_config.remove_print, + reject_multiple_functions=self.config.azr.reward.generation_reward_config.reject_multiple_functions, + f_replace_location=self.config.azr.reward.generation_reward_config.f_replace_location, + reject_test_input_in_code=self.config.azr.reward.generation_reward_config.reject_test_input_in_code, + code_location=self.config.azr.reward.generation_reward_config.code_location, + ) + if success: + code_validity, output = self._executor.check_all( + code=result['code'], + inputs=result['input'], + banned_keywords=self.config.azr.data_selection_strategy.banned_words, + check_determinism=True, + imports=list(set(result['imports'])), + check_error=problem_type == 'code_e', + banned_keywords_for_errors_and_exceptions=self.config.azr.data_selection_strategy.banned_keywords_for_errors_and_exceptions, + ) + if code_validity: + if problem_type == 'code_e': + local_error_entries.append( + { + 'snippet': result['code'], + 'input': result['input'], + 'output': output, + 'imports': result['imports'], + 'original_snippet': result['code'], + 'composite_functions': [] + } + ) + else: + local_entries.append( + { + 'snippet': result['code'], + 'input': result['input'], + 'output': output, + 'imports': result['imports'], + 'original_snippet': result['code'], + 'composite_functions': [] + } + ) + + if self.config.azr.data_selection_strategy.get('generate_seed_dataset_only', False): + with open(self.config.azr.data_selection_strategy.output_seed_path.replace('.jsonl', f'_temp.jsonl'), 'a') as f: + for entry in local_entries: + f.write(json.dumps(entry) + '\n') + + break # only use the first batch, to continuously generate more diverse data + + # Atomically update shared dataset + if problem_type != 'code_e': + # Process locally first + processed_entries = process_elements(local_entries) + # Then send to ray + added_count = ray.get( + self.dataset_manager.update_seed.remote(processed_entries) + ) + # Get updated dataset + seed_dataset = ray.get(self.dataset_manager.get_dataset.remote('seed')) + PrettyPrinter.status( + "WORKER", + f"Added {added_count} new entries (Total: {len(seed_dataset)})", + "info" + ) + PrettyPrinter.progress_bar( + current=len(seed_dataset), + total=target_size, + label="Dataset Growth" + ) + # Early exit if we've reached target size + if len(seed_dataset) >= target_size: + break + + elif problem_type == 'code_e': + # Process locally first + processed_entries = process_elements(local_error_entries) + # Then send to ray + error_added_count = ray.get( + self.dataset_manager.update_error_seed.remote(processed_entries) + ) + error_dataset = ray.get(self.dataset_manager.get_dataset.remote('error_seed')) + PrettyPrinter.status( + "WORKER", + f"Added {error_added_count} new entries (Total: {len(error_dataset)})", + "info" + ) + PrettyPrinter.progress_bar( + current=len(error_dataset), + total=target_size, + label="Error Dataset Growth" + ) + if len(error_dataset) >= target_size: + break + + # now get the code_f dataset + if 'code_f' in problem_types: + code_f_dataset = [] + all_snippets = ray.get(self.dataset_manager.get_snippets.remote()) + while len(code_f_dataset) < target_size: + # randomly sample a snippet from all_snippets + code_f_dataset = ray.get(self.dataset_manager.get_dataset.remote('problem')) + code_f_seed_dataloader = self._create_train_code_gen_dataloader( + data_len=len(all_snippets), + problem_type='code_f', + seeding=True + ) + epoch_entries = [] + for batch in code_f_seed_dataloader: + batch: DataProto = DataProto.from_single_dict(batch) + gen_batch = batch.pop(['input_ids', 'attention_mask', 'position_ids']) + gen_batch.meta_info = { + 'eos_token_id': self.tokenizer.eos_token_id, + 'pad_token_id': self.tokenizer.pad_token_id, + 'recompute_log_prob': False, + 'do_sample': True, + 'validate': True, + } + + # pad to be divisible by dp_size + gen_batch_padded, pad_size = pad_dataproto_to_divisor(gen_batch, self.actor_rollout_wg.world_size) + output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(gen_batch_padded) + pad_size *= self.config.actor_rollout_ref.rollout.n + + # unpad + output_gen_batch = unpad_dataproto(output_gen_batch_padded, pad_size=pad_size) + + # If we're doing multiple samples, repeat the original batch + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + + batch = batch.union(output_gen_batch) + + # Store generated outputs + output_ids = batch.batch['responses'] + output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + for idx, output_text in enumerate(output_texts): + success, result = parse_inputs_message(output_text, num_inputs=self.config.azr.data_selection_strategy.num_inputs) + if success: + outputs = [] + for ipt in result['inputs']: + code_validity, output = self._executor.check_all( + code=batch.non_tensor_batch['extra_info'][idx]['chosen_references'][0]['snippet'], + inputs=ipt, + banned_keywords=[], + check_determinism=True, + imports=batch.non_tensor_batch['extra_info'][idx]['imports'], + check_error=False, + banned_keywords_for_errors_and_exceptions=[], + ) + outputs.append(output) + if not code_validity: + break + if code_validity: + epoch_entries.append( + { + 'snippet': batch.non_tensor_batch['extra_info'][idx]['chosen_references'][0]['snippet'], + 'inputs': result['inputs'], + 'outputs': outputs, + 'message': result['message'], + 'imports': batch.non_tensor_batch['extra_info'][idx]['imports'].tolist(), + } + ) + + # Then send to ray + processed_entries = process_elements(epoch_entries) + added_count = ray.get(self.dataset_manager.add_problem_batch.remote(processed_entries, self.global_steps)) + # Get updated dataset + code_f_dataset = ray.get(self.dataset_manager.get_dataset.remote('problem')) + PrettyPrinter.status( + "WORKER", + f"Added {added_count} new entries (Total: {len(code_f_dataset)})", + "info" + ) + PrettyPrinter.progress_bar( + current=len(code_f_dataset), + total=target_size, + label="Code F Dataset Growth" + ) + if self.config.azr.data_selection_strategy.get('generate_seed_dataset_only', False): + with open(self.config.azr.data_selection_strategy.output_code_f_seed_path.replace('.jsonl', f'_temp.jsonl'), 'a') as f: + for entry in code_f_dataset: + f.write(json.dumps(entry) + '\n') + # Early exit if we've reached target size + if len(code_f_dataset) >= target_size: + break + + # truncate the dataset to the target size + ray.get(self.dataset_manager.truncate_datasets.remote(target_size, 'seed')) + ray.get(self.dataset_manager.truncate_datasets.remote(target_size, 'error_seed')) + ray.get(self.dataset_manager.truncate_datasets.remote(target_size, 'problem')) + + # Sample type statistics after seed initialization + if self.global_steps == 0: # Only log this on first initialization + type_stats = ray.get(self.dataset_manager.get_all_type_statistics.remote()) + PrettyPrinter.section_header("Initial Type Statistics") + for category, type_data in type_stats.items(): + category_display = { + 'input_types': 'Input Types', + 'output_types': 'Output Types', + 'error_types': 'Error Types' + }.get(category, category) + + if type_data: # Only show if we have data + PrettyPrinter.status(category_display.upper(), f"Total types: {len(type_data)}", "info") + for type_name, stats in sorted(type_data.items(), + key=lambda x: x[1]['total_unique'], + reverse=True)[:5]: # Show top 5 by unique count + PrettyPrinter.status( + f" {type_name}", + f"Unique: {stats['total_unique']}, Total: {stats['total_count']}", + "info" + ) + if stats['examples']: + example = stats['examples'][0] + if len(example) > 100: + example = example[:97] + "..." + PrettyPrinter.status(" Example", example, "info") + + # Final dataset from coordinator + seed_dataset = ray.get(self.dataset_manager.get_dataset.remote('seed')) + error_dataset = ray.get(self.dataset_manager.get_dataset.remote('error_seed')) + code_f_dataset = ray.get(self.dataset_manager.get_dataset.remote('problem')) + + # Modify dataset saving condition + if ('code_i' in problem_types or 'code_o' in problem_types) and self.config.azr.output_seed_path is not None: + PrettyPrinter.status("DATASET", "Writing seed dataset to JSONL file...", "info") + output_path = Path(self.config.azr.output_seed_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w') as f: + for item in seed_dataset: + f.write(json.dumps(item) + '\n') + PrettyPrinter.status("DATASET", f"Saved {len(seed_dataset)} entries to {str(output_path)}", "success") + + if 'code_e' in problem_types and self.config.azr.output_error_seed_path is not None: + PrettyPrinter.status("DATASET", "Writing error seed dataset to JSONL file...", "info") + output_path = Path(self.config.azr.output_error_seed_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, 'w') as f: + for item in error_dataset: + f.write(json.dumps(item) + '\n') + PrettyPrinter.status("DATASET", f"Saved {len(error_dataset)} entries to {str(output_path)}", "success") + + if 'code_f' in problem_types and self.config.azr.output_code_f_seed_path is not None: + PrettyPrinter.status("DATASET", "Writing code f seed dataset to JSONL file...", "info") + output_path = Path(self.config.azr.output_code_f_seed_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, 'w') as f: + for item in code_f_dataset: + try: + f.write(json.dumps(item) + '\n') + except: + print(item) + raise Exception("Failed to save code f dataset") + PrettyPrinter.status("DATASET", f"Saved {len(code_f_dataset)} entries to {str(output_path)}", "success") + + # Show a few sample entries + if 'code_i' in problem_types or 'code_o' in problem_types: + # Print detailed dataset summary + PrettyPrinter.section_header("Seed Dataset Summary") + PrettyPrinter.table( + ["Key", "Value"], + [ + ["Total Samples", len(seed_dataset)], + ["Target Size", target_size], + ["Storage Path", self.config.azr.output_seed_path], + ["Sample Types", len(set(item['snippet'] for item in seed_dataset))], + ["Average Snippet Length", sum(len(item['snippet']) for item in seed_dataset) // len(seed_dataset) if seed_dataset else 0] + ], + title="Dataset Statistics" + ) + + PrettyPrinter.section_header("Sample Entries") + # sample 3 entries + for i, item in enumerate(random.sample(seed_dataset, self.config.azr.random_print_max_programs)): + PrettyPrinter.code_block(item['snippet'], "python") + PrettyPrinter.status("INPUT", item['input'], "info") + PrettyPrinter.status("OUTPUT", item['output'], "info") + if i < 2: # Don't print separator after last item + print("\n" + "-" * 80 + "\n") + + if 'code_e' in problem_types: + PrettyPrinter.section_header("Error Dataset Summary") + PrettyPrinter.table( + ["Key", "Value"], + [ + ["Total Samples", len(error_dataset)], + ["Target Size", target_size], + ["Storage Path", self.config.azr.output_error_seed_path], + ["Sample Types", len(set(item['snippet'] for item in error_dataset))], + ["Average Snippet Length", sum(len(item['snippet']) for item in error_dataset) // len(error_dataset) if error_dataset else 0] + ], + title="Error Dataset Statistics" + ) + PrettyPrinter.section_header("Error Sample Entries") + # sample 3 entries + for i, item in enumerate(random.sample(error_dataset, self.config.azr.random_print_max_programs)): + PrettyPrinter.code_block(item['snippet'], "python") + PrettyPrinter.status("INPUT", item['input'], "info") + PrettyPrinter.status("OUTPUT", item['output'], "info") + if i < 2: # Don't print separator after last item + print("\n" + "-" * 80 + "\n") + + if 'code_f' in problem_types: + PrettyPrinter.section_header("Code F Dataset Summary") + PrettyPrinter.table( + ["Key", "Value"], + [ + ["Total Samples", len(code_f_dataset)], + ["Target Size", target_size], + ["Storage Path", self.config.azr.output_code_f_seed_path], + ["Sample Types", len(set(item['snippet'] for item in code_f_dataset))], + ["Average Snippet Length", sum(len(item['snippet']) for item in code_f_dataset) // len(code_f_dataset) if code_f_dataset else 0] + ], + title="Code F Dataset Statistics" + ) + + PrettyPrinter.section_header("Code F Sample Entries") + # sample 3 entries + for i, item in enumerate(random.sample(code_f_dataset, self.config.azr.random_print_max_programs)): + PrettyPrinter.code_block(item['snippet'], "python") + PrettyPrinter.status("INPUTS", item['inputs'], "info") + PrettyPrinter.status("OUTPUTS", item['outputs'], "info") + PrettyPrinter.status("MESSAGE", item['message'], "info") + if i < 2: # Don't print separator after last item + print("\n" + "-" * 80 + "\n") + + return seed_dataset, error_dataset, code_f_dataset + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + + logger = ReasonRLTracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + tags=self.config.trainer.wandb_tags, + resume="must" if self.config.trainer.resume_mode == 'auto' and \ + self.config.trainer.wandb_run_id is not None else False, # Add resume flag + run_id=self.config.trainer.wandb_run_id \ + if self.config.trainer.wandb_run_id is not None else None # Pass existing run ID + ) + + self.global_steps = 0 + + # TTRLVR: 체크포인트 로딩 완전히 비활성화 (메모리 기반 학습) + # self._load_checkpoint() # 주석 처리 + + # base model chat template + if self.config.actor_rollout_ref.model.pretrained_tokenizer: + self.tokenizer.chat_template = "{%- for message in messages -%}{{- '\n' if not loop.first -}}{{- message['content'] -}}{%- endfor -%}" + + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True) and self.global_steps == 0: + PrettyPrinter.section_header(f"Starting Initial Validation") + val_metrics = self._validate() + PrettyPrinter.table( + ["Metric", "Value"], + [[k, v] for k, v in val_metrics.items()], + title="Initial Validation Metrics" + ) + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get('val_only', False): + return + + # TTRLVR: 항상 새로운 데이터 사용 (체크포인트 복원 없음) + # if getattr(self, 'loaded_datasets', False): # 불필요한 체크포인트 복원 로직 제거 + + # TTRLVR: Worker group 속성 확인 및 디버그 + if not hasattr(self, 'actor_rollout_wg'): + print("DEBUG: actor_rollout_wg not found, checking all attributes:") + all_attrs = [attr for attr in dir(self) if not attr.startswith('_')] + print(f"Total non-private attributes: {len(all_attrs)}") + + # 관련 속성들 찾기 + relevant_attrs = [] + for attr in all_attrs: + attr_lower = attr.lower() + if any(keyword in attr_lower for keyword in ['wg', 'worker', 'group', 'actor', 'rollout']): + try: + attr_value = getattr(self, attr, None) + relevant_attrs.append(f" - {attr}: {type(attr_value)}") + except: + relevant_attrs.append(f" - {attr}: ") + + print("Relevant attributes:") + for attr_info in relevant_attrs[:20]: # 처음 20개만 출력 + print(attr_info) + + # 특별히 찾아야 할 속성들 + critical_attrs = ['worker_group', 'all_worker_groups', 'worker_groups', 'actor_rollout', 'rollout_ref'] + print("\nChecking critical attributes:") + for attr in critical_attrs: + if hasattr(self, attr): + print(f" ✓ {attr}: {type(getattr(self, attr, None))}") + else: + print(f" ✗ {attr}: Not found") + PrettyPrinter.section_header(f"Creating initial seed datasets") + # create init dataset + need_seed_dataset = any(problem_type != 'code_e' for problem_type in self.config.azr.problem_types) or 'code_f' in self.config.azr.problem_types + need_error_dataset = 'code_e' in self.config.azr.problem_types + need_code_f_dataset = 'code_f' in self.config.azr.problem_types + + # Initialize with defaults + seed_dataset = [] + error_dataset = [] + code_f_dataset = [] + + # Load or generate seed dataset if needed + if need_seed_dataset: + if self.config.azr.seed_dataset is not None: + PrettyPrinter.status("DATA", "Loading seed dataset from file...", "info") + with open(self.config.azr.seed_dataset, 'r') as file: + seed_dataset = [json.loads(line) for line in file] + seed_dataset = seed_dataset[:self.config.azr.data_selection_strategy.data_len * + self.config.azr.data_selection_strategy.seed_batch_factor] + PrettyPrinter.status("DATA", f"Loaded {len(seed_dataset)} seed entries", "success") + if 'code_f' in self.config.azr.problem_types: # we need seed to generate code_f + ray.get(self.dataset_manager.update_seed.remote(seed_dataset)) + else: + PrettyPrinter.status("DATA", "Seed dataset not provided, will generate", "info") + + # Load or prepare to generate error dataset if needed + if need_error_dataset: + if self.config.azr.error_seed_dataset is not None: + PrettyPrinter.status("DATA", "Loading error seed dataset from file...", "info") + with open(self.config.azr.error_seed_dataset, 'r') as file: + error_dataset = [json.loads(line) for line in file] + error_dataset = error_dataset[:self.config.azr.data_selection_strategy.data_len * + self.config.azr.data_selection_strategy.seed_batch_factor] + PrettyPrinter.status("DATA", f"Loaded {len(error_dataset)} error entries", "success") + else: + PrettyPrinter.status("DATA", "Error seed dataset not provided, will generate", "info") + + if need_code_f_dataset: + if self.config.azr.code_f_seed_dataset is not None: + PrettyPrinter.status("DATA", "Loading code f seed dataset from file...", "info") + with open(self.config.azr.code_f_seed_dataset, 'r') as file: + code_f_dataset = [json.loads(line) for line in file] + code_f_dataset = code_f_dataset[:self.config.azr.data_selection_strategy.data_len * + self.config.azr.data_selection_strategy.seed_batch_factor] + PrettyPrinter.status("DATA", f"Loaded {len(code_f_dataset)} code f entries", "success") + + # Generate missing datasets if needed + need_to_generate_seed = need_seed_dataset and len(seed_dataset) == 0 + need_to_generate_error = need_error_dataset and len(error_dataset) == 0 + need_to_generate_code_f = need_code_f_dataset and len(code_f_dataset) == 0 + + # TTRLVR: Skip seed generation when train_propose is False + if self.config.azr.train_propose and (need_to_generate_seed or need_to_generate_error or need_to_generate_code_f): + sample_problem_types = [] + for problem_type in self.config.azr.problem_types: + if problem_type == 'code_e' and need_to_generate_error: + sample_problem_types.append(problem_type) + elif problem_type != 'code_e' and need_to_generate_seed: + sample_problem_types.append(problem_type) + elif problem_type == 'code_f' and need_to_generate_code_f: + sample_problem_types.append(problem_type) + PrettyPrinter.status("DATA", f"Generating missing datasets for {', '.join(sample_problem_types)}...", "info") + generated_seed, generated_error, generated_code_f = self._init_seed_dataset(problem_types=sample_problem_types) + elif not self.config.azr.train_propose and (need_to_generate_seed or need_to_generate_error or need_to_generate_code_f): + PrettyPrinter.status("DATA", "Skipping seed generation (train_propose=False)", "info") + generated_seed, generated_error, generated_code_f = [], [], [] + + if need_to_generate_seed: + seed_dataset = generated_seed + PrettyPrinter.status("DATA", f"Generated {len(seed_dataset)} seed entries", "success") + + if need_to_generate_error: + error_dataset = generated_error + PrettyPrinter.status("DATA", f"Generated {len(error_dataset)} error entries", "success") + + if need_to_generate_code_f: + code_f_dataset = generated_code_f + PrettyPrinter.status("DATA", f"Generated {len(code_f_dataset)} code f entries", "success") + + if self.config.azr.get('generate_seed_dataset_only', False): + return + + # Now initialize datasets in dataset manager + if need_seed_dataset and self.config.azr.train_propose: + assert len(seed_dataset) >= self.config.azr.data_selection_strategy.data_len + + if 'code_i' in self.config.azr.problem_types: + # Process locally first + processed_seed_dataset = process_elements(seed_dataset) + # Initialize input dataset with seed data + ray.get(self.dataset_manager.add_input_batch.remote(processed_seed_dataset, self.global_steps)) + PrettyPrinter.status( + "DATA", + f"Input dataset initialized with {len(seed_dataset)} entries", + "success" + ) + + if 'code_o' in self.config.azr.problem_types: + processed_seed_dataset = process_elements(seed_dataset) + ray.get(self.dataset_manager.add_output_batch.remote(processed_seed_dataset, self.global_steps)) + PrettyPrinter.status( + "DATA", + f"Output dataset initialized with {len(seed_dataset)} entries", + "success" + ) + + if need_error_dataset and self.config.azr.train_propose: + assert len(error_dataset) >= self.config.azr.data_selection_strategy.data_len + processed_error_dataset = process_elements(error_dataset) + ray.get(self.dataset_manager.add_error_batch.remote(processed_error_dataset, self.global_steps)) + PrettyPrinter.status( + "DATA", + f"Error dataset initialized with {len(error_dataset)} entries", + "success" + ) + + if need_code_f_dataset and self.config.azr.train_propose: + assert len(code_f_dataset) >= self.config.azr.data_selection_strategy.data_len + processed_code_f_dataset = process_elements(code_f_dataset) + ray.get(self.dataset_manager.add_problem_batch.remote(processed_code_f_dataset, self.global_steps)) + PrettyPrinter.status( + "DATA", + f"Code f dataset initialized with {len(code_f_dataset)} entries", + "success" + ) + + # we start from step 1 + self.global_steps += 1 + if self.config.azr.pretrain_pred_steps > 0 and self.global_steps <= self.config.azr.pretrain_pred_steps: + self.pretrain_pred = True + else: + self.pretrain_pred = False + + while self.global_steps < self.total_training_steps: + PrettyPrinter.section_header(f"Training Step {self.global_steps}") + + # TTRLVR 사용시 모든 task type 배치 수집 후 업데이트 + if self._use_ttrlvr_rewards: + metrics = {} + timing_raw = {} + + with marked_timer('step', timing_raw): + # Task type별 dataloader가 있는지 확인 + if hasattr(self, 'ttrlvr_dataloaders') and self.ttrlvr_dataloaders: + # Task type별 분리된 dataloader 사용 + batches = {} + task_types = ['induction', 'deduction', 'abduction'] + + PrettyPrinter.section_header("Collecting TTRLVR Batches") + + # 각 task type별로 배치 수집 + for task_type in task_types: + if task_type not in self.ttrlvr_dataloaders: + continue + + with marked_timer(f'get_batch_{task_type}', timing_raw): + try: + # 해당 task type의 배치 가져오기 (자동 재생성 포함) + batch_dict = self._get_ttrlvr_batch(task_type) + batch: DataProto = DataProto.from_single_dict(batch_dict) + + PrettyPrinter.status("TTRLVR", + f"Processing {task_type} batch with {len(batch_dict['prompt'])} samples", + "info") + + # 배치 계산 + with marked_timer(f'compute_batch_{task_type}', timing_raw): + batch, metrics = self._compute_batch(batch, metrics, timing_raw, + problem_type=task_type, + executor=self._executor) + + # 배치 저장 + batches[task_type] = batch + + except Exception as e: + PrettyPrinter.status("TTRLVR", + f"Error processing {task_type}: {str(e)}", + "error") + continue + + # 수집된 배치가 있는 경우에만 진행 + if batches: + # 모든 배치 연결 + PrettyPrinter.status("TTRLVR", + f"Concatenating {len(batches)} batches from task types: {list(batches.keys())}", + "info") + combined_batch = DataProto.concat(list(batches.values())) + + # Actor/Critic 업데이트 (결합된 배치로) + PrettyPrinter.section_header("Starting Parameter Updates with Combined Batch") + + # update critic + if self.use_critic: + with marked_timer('update_critic', timing_raw): + critic_output = self.critic_wg.update_critic(combined_batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) + metrics.update(critic_output_metrics) + + # update actor + if self.config.trainer.critic_warmup <= self.global_steps: + with marked_timer('update_actor', timing_raw): + actor_output = self.actor_rollout_wg.update_actor(combined_batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) + metrics.update(actor_output_metrics) + + # 🔍 정밀 값 디버그 출력 + if 'actor/lr' in actor_output_metrics: + print(f"\n[DEBUG] Learning Rate: {actor_output_metrics['actor/lr']:.10f}") + if 'actor/pg_loss' in actor_output_metrics: + print(f"[DEBUG] PG Loss: {actor_output_metrics['actor/pg_loss']:.10f}") + if 'actor/grad_norm' in actor_output_metrics: + print(f"[DEBUG] Grad Norm: {actor_output_metrics['actor/grad_norm']:.10f}") + if 'actor/ppo_kl' in actor_output_metrics: + print(f"[DEBUG] PPO KL: {actor_output_metrics['actor/ppo_kl']:.10f}\n") + + # 각 task type별 메트릭 계산 + sep_batches = combined_batch.chunk(len(batches)) + for (task_type, _), sep_batch in zip(batches.items(), sep_batches): + sep_metrics = compute_data_metrics(batch=sep_batch, + use_critic=self.use_critic, + tokenizer=self.tokenizer) + sep_metrics = {f'{task_type}/{k}': v for k, v in sep_metrics.items()} + metrics.update(sep_metrics) + else: + # 배치가 없는 경우 경고만 출력 + PrettyPrinter.status("TTRLVR", "No batches collected, skipping update", "warning") + + else: + # 기존 단일 dataloader 방식 (fallback) + with marked_timer('get_batch', timing_raw): + batch_dict = next(iter(self.train_dataloader)) + batch: DataProto = DataProto.from_single_dict(batch_dict) + + # TTRLVR 전용 batch 처리 + actual_problem_type = 'induction' # 기본값 + if 'problem_type' in batch_dict: + actual_problem_type = batch_dict['problem_type'][0] + + PrettyPrinter.status("TTRLVR", f"Processing batch with problem_type: {actual_problem_type}", "info") + + with marked_timer('ttrlvr_compute_batch', timing_raw): + batch, metrics = self._compute_batch(batch, metrics, timing_raw, + problem_type=actual_problem_type, executor=self._executor) + + # Actor/Critic 업데이트 + PrettyPrinter.section_header(f"Starting Parameter Updates") + + if self.use_critic: + with marked_timer('update_critic', timing_raw): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) + metrics.update(critic_output_metrics) + + if self.config.trainer.critic_warmup <= self.global_steps: + with marked_timer('update_actor', timing_raw): + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) + metrics.update(actor_output_metrics) + + # 🔍 정밀 값 디버그 출력 + if 'actor/lr_precise' in actor_output_metrics: + print(f"\n[DEBUG] LR precise value: {actor_output_metrics['actor/lr_precise']}") + if 'actor/pg_loss_precise' in actor_output_metrics: + print(f"[DEBUG] PG loss precise value: {actor_output_metrics['actor/pg_loss_precise']}") + if 'actor/grad_norm_precise' in actor_output_metrics: + print(f"[DEBUG] Grad norm precise value: {actor_output_metrics['actor/grad_norm_precise']}\n") + + + # step 증가 + self.global_steps += 1 + # step 타이머 블록 종료 + + # 메트릭 로깅 (원본 AZR과 동일하게 단순화) + logger.log(data=metrics, step=self.global_steps) + + # 계속 진행 (TTRLVR은 나머지 AZR 로직 스킵) + continue + if self.config.azr.data_selection_strategy.composite_scheduler.enabled: + self.scheduler_step() + + # Calculate progress metrics + progress_percentage = (self.global_steps / self.total_training_steps) * 100 + import time + if not hasattr(self, '_start_time'): + self._start_time = time.time() + elapsed_time = time.time() - self._start_time + if self.global_steps > 0: + time_per_step = elapsed_time / self.global_steps + remaining_steps = self.total_training_steps - self.global_steps + eta_seconds = remaining_steps * time_per_step + eta_hours = eta_seconds / 3600 + + PrettyPrinter.progress_bar( + current=self.global_steps, + total=self.total_training_steps, + label=f"Training Progress - {progress_percentage:.1f}% - ETA: {eta_hours:.1f}h" + ) + + # Print detailed timing info + PrettyPrinter.status("TIMING", + f"Elapsed: {elapsed_time/3600:.1f}h | Step: {time_per_step/60:.1f}min | ETA: {eta_hours:.1f}h", + "info") + else: + PrettyPrinter.progress_bar( + current=self.global_steps, + total=self.total_training_steps, + label="Training Progress - Starting..." + ) + + data_len = self.config.data.train_batch_size * self.config.azr.data_selection_strategy.update_iteration + if 'code_i' in self.config.azr.problem_types: + gen_code_i_dataloader = self._create_train_code_gen_dataloader( + problem_type='code_i', + data_len=data_len, + ) + pred_code_i_dataloader = self._create_train_code_pred_dataloader( + problem_type='code_i', + data_len=data_len, + ) + if 'code_o' in self.config.azr.problem_types: + gen_code_o_dataloader = self._create_train_code_gen_dataloader( + problem_type='code_o', + data_len=data_len, + ) + pred_code_o_dataloader = self._create_train_code_pred_dataloader( + problem_type='code_o', + data_len=data_len, + ) + + if 'code_e' in self.config.azr.problem_types: + gen_code_e_dataloader = self._create_train_code_gen_dataloader( + problem_type='code_e', + data_len=data_len, + ) + pred_code_e_dataloader = self._create_train_code_pred_dataloader( + problem_type='code_e', + data_len=data_len, + ) + + if 'code_f' in self.config.azr.problem_types: + gen_code_f_dataloader = self._create_train_code_gen_dataloader( + data_len=data_len, + problem_type='code_f', + seeding=False, + ) + pred_code_f_dataloader = self._create_train_code_pred_dataloader( + problem_type='code_f', + data_len=data_len, + ) + for _ in range(self.config.azr.data_selection_strategy.update_iteration): + metrics = {} + timing_raw = {} + batches = {} + with marked_timer('step', timing_raw): + # Clean up executor periodically + if self.global_steps - self._last_cleanup_step >= self._cleanup_frequency: + PrettyPrinter.section_header("Periodic Cleanup") + with marked_timer('cleanup', timing_raw): + self.cleanup() + self._last_cleanup_step = self.global_steps + + if 'code_i' in self.config.azr.problem_types: + if not self.pretrain_pred: + batch_dict = next(gen_code_i_dataloader) + batch: DataProto = DataProto.from_single_dict(batch_dict) + batch, metrics = self._compute_batch(batch, metrics, timing_raw, problem_type='gen_code_i', executor=self._executor) + if self.config.azr.train_propose: + batches[f'gen_code_i'] = batch + batch_dict = next(pred_code_i_dataloader) + batch: DataProto = DataProto.from_single_dict(batch_dict) + batch, metrics = self._compute_batch(batch, metrics, timing_raw, problem_type='pred_code_i', executor=self._executor) + batches[f'pred_code_i'] = batch + + if 'code_o' in self.config.azr.problem_types: + if not self.pretrain_pred: + batch_dict = next(gen_code_o_dataloader) + batch: DataProto = DataProto.from_single_dict(batch_dict) + batch, metrics = self._compute_batch(batch, metrics, timing_raw, problem_type='gen_code_o', executor=self._executor) + if self.config.azr.train_propose: + batches[f'gen_code_o'] = batch + batch_dict = next(pred_code_o_dataloader) + batch: DataProto = DataProto.from_single_dict(batch_dict) + batch, metrics = self._compute_batch(batch, metrics, timing_raw, problem_type='pred_code_o', executor=self._executor) + batches[f'pred_code_o'] = batch + + if 'code_e' in self.config.azr.problem_types: + if not self.pretrain_pred: + batch_dict = next(gen_code_e_dataloader) + batch: DataProto = DataProto.from_single_dict(batch_dict) + batch, metrics = self._compute_batch(batch, metrics, timing_raw, problem_type='gen_code_e', executor=self._executor) + if self.config.azr.train_propose: + batches[f'gen_code_e'] = batch + batch_dict = next(pred_code_e_dataloader) + batch: DataProto = DataProto.from_single_dict(batch_dict) + batch, metrics = self._compute_batch(batch, metrics, timing_raw, problem_type='pred_code_e', executor=self._executor) + batches[f'pred_code_e'] = batch + + if 'code_f' in self.config.azr.problem_types: + if not self.pretrain_pred: + batch_dict = next(gen_code_f_dataloader) + batch: DataProto = DataProto.from_single_dict(batch_dict) + batch, metrics = self._compute_batch(batch, metrics, timing_raw, problem_type='gen_code_f', executor=self._executor) + if self.config.azr.train_propose: + batches[f'gen_code_f'] = batch + batch_dict = next(pred_code_f_dataloader) + batch: DataProto = DataProto.from_single_dict(batch_dict) + batch, metrics = self._compute_batch(batch, metrics, timing_raw, problem_type='pred_code_f', executor=self._executor) + batches[f'pred_code_f'] = batch + + # concatenate batches + batch = DataProto.concat(list(batches.values())) + + PrettyPrinter.section_header(f"Starting Parameter Updates") + # update critic + if self.use_critic: + with marked_timer('update_critic', timing_raw): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + with marked_timer('update_actor', timing_raw): + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) + metrics.update(actor_output_metrics) + + # validate + PrettyPrinter.section_header(f"Starting Validation") + if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \ + self.global_steps % self.config.trainer.test_freq == 0: + with marked_timer('testing', timing_raw): + val_metrics: dict = self._validate() + PrettyPrinter.table( + ["Data Source", "Average Score"], + [[k, v] for k, v in val_metrics.items()], + title="Validation Results" + ) + metrics.update(val_metrics) + + # print the statistics of the number of programs in the dataset + if 'code_i' in self.config.azr.problem_types: + PrettyPrinter.status( + "DATA", + f"Number of programs in the input dataset: {ray.get(self.dataset_manager.get_dataset_size.remote('input'))}", + "info" + ) + if 'code_o' in self.config.azr.problem_types: + PrettyPrinter.status( + "DATA", + f"Number of programs in the output dataset: {ray.get(self.dataset_manager.get_dataset_size.remote('output'))}", + "info" + ) + if 'code_e' in self.config.azr.problem_types: + PrettyPrinter.status( + "DATA", + f"Number of programs in the error dataset: {ray.get(self.dataset_manager.get_dataset_size.remote('error'))}", + "info" + ) + if 'code_f' in self.config.azr.problem_types: + PrettyPrinter.status( + "DATA", + f"Number of programs in the code_f dataset: {ray.get(self.dataset_manager.get_dataset_size.remote('problem'))}", + "info" + ) + if self.config.trainer.save_freq > 0 and \ + self.global_steps % self.config.trainer.save_freq == 0: + with marked_timer('save_checkpoint', timing_raw): + self._save_checkpoint() + + # collect metrics, separate problem types + all_types = [] + if 'code_i' in self.config.azr.problem_types: + if not self.pretrain_pred: + all_types.append('gen_code_i') + all_types.append('pred_code_i') + if 'code_o' in self.config.azr.problem_types: + if not self.pretrain_pred: + all_types.append('gen_code_o') + all_types.append('pred_code_o') + if 'code_e' in self.config.azr.problem_types: + if not self.pretrain_pred: + all_types.append('gen_code_e') + all_types.append('pred_code_e') + if 'code_f' in self.config.azr.problem_types: + if not self.pretrain_pred: + all_types.append('gen_code_f') + all_types.append('pred_code_f') + sep_batches = batch.chunk(len(all_types)) + for sep_batch, problem_type in zip(sep_batches, all_types): + sep_metrics = compute_data_metrics(batch=sep_batch, use_critic=self.use_critic, tokenizer=self.tokenizer) + sep_metrics = {f'{problem_type}/{k}': v for k, v in sep_metrics.items()} + metrics.update(sep_metrics) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + + # Get and log type statistics periodically + type_stats = ray.get(self.dataset_manager.get_all_type_statistics.remote()) + + # Log summary metrics about types + for category, type_data in type_stats.items(): + # Calculate diversity metrics + total_types = len(type_data) + total_unique_values = sum(stats["total_unique"] for stats in type_data.values()) + total_instances = sum(stats["total_count"] for stats in type_data.values()) + + # Add to metrics + metrics[f"types/{category}/distinct_types"] = total_types + metrics[f"types/{category}/total_unique_values"] = total_unique_values + metrics[f"types/{category}/total_instances"] = total_instances + + # Per-type metrics + for type_name, stats in type_data.items(): + metrics[f"types/{category}/{type_name}/unique_count"] = stats["total_unique"] + metrics[f"types/{category}/{type_name}/total_count"] = stats["total_count"] + + # Print type statistics summary + PrettyPrinter.section_header("Type Statistics Summary") + for category, type_data in type_stats.items(): + category_display = { + 'input_types': 'Input Types', + 'output_types': 'Output Types', + 'error_types': 'Error Types', + }.get(category, category) + + PrettyPrinter.status(category_display.upper(), f"Total types: {len(type_data)}", "info") + for type_name, stats in sorted(type_data.items(), + key=lambda x: x[1]['total_unique'], + reverse=True)[:5]: # Show top 5 by unique count + PrettyPrinter.status( + f" {type_name}", + f"Unique: {stats['total_unique']}, Total: {stats['total_count']}", + "info" + ) + if stats['examples']: + example = stats['examples'][0] + if len(example) > 100: + example = example[:97] + "..." + PrettyPrinter.status(" Example", example, "info") + + PrettyPrinter.table( + ["Category", "Value"], + [[k, v] for k, v in metrics.items()], + title="Step Metrics" + ) + + logger.log(data=metrics, step=self.global_steps) + + if self.global_steps >= self.config.azr.pretrain_pred_steps: + self.pretrain_pred = False + + self.global_steps += 1 + + gc.collect() + + if self.global_steps >= self.total_training_steps: + + # perform validation after training + if self.val_reward_fn is not None: + PrettyPrinter.section_header(f"Starting Final Validation") + val_metrics = self._validate() + PrettyPrinter.table( + ["Data Source", "Average Score"], + [[k, v] for k, v in val_metrics.items()], + title="Final Validation Results" + ) + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.save_freq > 0 and \ + (self.global_steps - 1) % self.config.trainer.save_freq != 0: + with marked_timer('save_checkpoint', timing_raw): + self._save_checkpoint() + return + + def _validate(self): + """ + The validation loop of PPO. + The only difference is logging more metrics. + """ + reward_tensor_lst = [] + data_source_lst = [] + + # Lists to collect samples for the table + sample_inputs = [] + sample_outputs = [] + sample_scores = [] + all_eval_metrics = defaultdict(list) + + for test_data in self.val_dataloader: + test_batch = DataProto.from_single_dict(test_data) + + # we only do validation on rule-based rm + if self.config.reward_model.enable and test_batch[0].non_tensor_batch['reward_model']['style'] == 'model': + return {} + + # Store original inputs + input_ids = test_batch.batch['input_ids'] + input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + sample_inputs.extend(input_texts) + + test_gen_batch = test_batch.pop(['input_ids', 'attention_mask', 'position_ids']) + test_gen_batch.meta_info = { + 'eos_token_id': self.tokenizer.eos_token_id, + 'pad_token_id': self.tokenizer.pad_token_id, + 'recompute_log_prob': False, + 'do_sample': False, + 'validate': True, + } + + # pad to be divisible by dp_size + test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size) + test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + # unpad + test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) + PrettyPrinter.status("VALID", "Generation completed", "success") + + # Store generated outputs + output_ids = test_output_gen_batch.batch['responses'] + output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + sample_outputs.extend(output_texts) + + test_batch = test_batch.union(test_output_gen_batch) + + # evaluate using reward_function + reward_tensor, eval_metrics, _, _, _ = self.val_reward_fn( + test_batch, + problem_type=None, + executor=self._executor, + ) + for k, v in eval_metrics.items(): + all_eval_metrics[k].append(v) + + # Store scores + scores = reward_tensor.sum(-1).cpu().tolist() + sample_scores.extend(scores) + + reward_tensor_lst.append(reward_tensor) + data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0])) + + self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) + + reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu() # (batch_size,) + data_sources = np.concatenate(data_source_lst, axis=0) + + # evaluate test_score based on data source + data_source_reward = {} + for i in range(reward_tensor.shape[0]): + data_source = data_sources[i] + if data_source not in data_source_reward: + data_source_reward[data_source] = [] + data_source_reward[data_source].append(reward_tensor[i].item()) + + metric_dict = {} + for data_source, rewards in data_source_reward.items(): + metric_dict[f'val/test_score/{data_source}'] = np.mean(rewards) + + for k, v in all_eval_metrics.items(): + metric_dict[k] = np.mean(v) + + return metric_dict + + def _save_datasets(self, save_dir: Path): + """Save input/output datasets as JSONL files""" + save_dir.mkdir(parents=True, exist_ok=True) + + # get all datasets and type counters + datasets_with_types = ray.get(self.dataset_manager.get_all_data_with_type_counters.remote()) + + # save datasets + pickle.dump(datasets_with_types, open(save_dir / 'datasets.pkl', 'wb')) + PrettyPrinter.status("SAVE", f"Saved datasets and type counters to {save_dir}", "success") + + def _load_datasets(self, save_dir: Path): + """Load input/output datasets from JSONL files""" + datasets_with_types = pickle.load(open(Path(save_dir) / 'datasets' / 'datasets.pkl', 'rb')) + + # Filter datasets based on global step + if self.global_steps > 0: + # Filter datasets that have step info + for dataset_key in ['input', 'output', 'error', 'problem']: + steps_key = f"{dataset_key}_steps" + if steps_key in datasets_with_types and dataset_key in datasets_with_types: + # Create lists of entries to keep + filtered_data = [] + filtered_steps = [] + + # Only keep entries with steps less than current global_steps + for entry, step in zip(datasets_with_types[dataset_key], datasets_with_types[steps_key]): + if step <= self.global_steps: + filtered_data.append(entry) + filtered_steps.append(step) + + # Update the datasets + datasets_with_types[dataset_key] = filtered_data + datasets_with_types[steps_key] = filtered_steps + + # Also filter the step counter dictionaries + counter_key = f"{dataset_key}_steps_counter" + if counter_key in datasets_with_types: + filtered_counter = defaultdict(int) + for step, count in datasets_with_types[counter_key].items(): + if step <= self.global_steps: + filtered_counter[step] = count + datasets_with_types[counter_key] = filtered_counter + + PrettyPrinter.status("FILTER", f"Filtered datasets to only include entries with steps <= {self.global_steps}", "info") + + ray.get(self.dataset_manager.full_load_data_with_type_counters.remote(datasets_with_types)) + PrettyPrinter.status("LOAD", f"Loaded datasets and type counters from {self.config.trainer.default_local_dir}", "success") + self.loaded_datasets = True + + def _save_checkpoint(self): + super()._save_checkpoint() + # save datasets + self._save_datasets(Path(self.config.trainer.default_local_dir) / 'datasets') + PrettyPrinter.status("SAVE", f"Saved checkpoint to {self.config.trainer.default_local_dir}", "success") + + def _load_checkpoint(self): + super()._load_checkpoint() + if self.global_steps == 0: + PrettyPrinter.section_header(f"Training from scratch") + else: + PrettyPrinter.section_header(f"Resuming training from checkpoint, step {self.global_steps}") + + # load datasets + # first check if all the datasets exist + code_dir = Path(self.config.trainer.default_local_dir) / 'code' + self._code_dir = code_dir + self.loaded_datasets = False + if self.config.trainer.resume_mode == 'auto' and os.path.exists(os.path.join(self.config.trainer.default_local_dir, 'datasets', 'datasets.pkl')): + self._load_datasets(self.config.trainer.default_local_dir) + elif self.config.trainer.resume_mode == 'disable': + if code_dir.exists(): + # delete all files and subdirectories in the code_dir + for file in code_dir.glob('**/*'): + if file.is_file(): + file.unlink() + elif file.is_dir(): + file.rmdir() + PrettyPrinter.status("Directory", f"Cleaned existing code directory at {code_dir}", "info") + elif not code_dir.exists(): + code_dir.mkdir(parents=True, exist_ok=True) + PrettyPrinter.status("Directory", f"Created new code directory at {code_dir}", "info") + + def scheduler_step(self): + if self.config.azr.data_selection_strategy.composite_scheduler.enabled: + # Update number of programs - calculate directly based on global steps + if self.global_steps >= self.config.azr.data_selection_strategy.composite_scheduler.update_num_programs_start: + steps_since_start = self.global_steps - self.config.azr.data_selection_strategy.composite_scheduler.update_num_programs_start + num_updates = steps_since_start // self.config.azr.data_selection_strategy.composite_scheduler.update_num_programs_interval + + # Calculate new value directly from initial value + increments + initial_max = self.config.azr.data_selection_strategy.max_programs_initial + new_max = min(initial_max + num_updates, self.config.azr.data_selection_strategy.composite_scheduler.num_programs_max) + + # Only log if value changed + if new_max != self.config.azr.data_selection_strategy.composite_function_n_max: + current_max = self.config.azr.data_selection_strategy.composite_function_n_max + self.config.azr.data_selection_strategy.composite_function_n_max = new_max + PrettyPrinter.status("Scheduler", f"Updated max programs from {current_max} to {new_max}", "info") + + # Update composite probability - calculate directly based on global steps + if self.global_steps >= self.config.azr.data_selection_strategy.composite_scheduler.update_probability_start: + steps_since_start = self.global_steps - self.config.azr.data_selection_strategy.composite_scheduler.update_probability_start + num_updates = steps_since_start // self.config.azr.data_selection_strategy.composite_scheduler.update_probability_interval + + # Calculate new value directly from initial value + increments + initial_prob = self.config.azr.data_selection_strategy.composite_chance_initial + new_prob = min(initial_prob + (num_updates * self.config.azr.data_selection_strategy.composite_scheduler.update_probability_increment), + self.config.azr.data_selection_strategy.composite_scheduler.update_probability_max) + + # Only log if value changed + if new_prob != self.config.azr.data_selection_strategy.composite_chance: + current_prob = self.config.azr.data_selection_strategy.composite_chance + self.config.azr.data_selection_strategy.composite_chance = new_prob + PrettyPrinter.status("Scheduler", f"Updated composite probability from {current_prob:.2f} to {new_prob:.2f}", "info") + + def save_final_datasets(self): + """Save all accumulated datasets at the end of training""" + if not (self.config.azr.save_final_datasets and self.config.azr.save_data_path): + return + + try: + PrettyPrinter.section_header("Saving Final Datasets") + all_datasets = ray.get(self.dataset_manager.get_all_datasets.remote()) + + save_dir = os.path.join(self.config.azr.save_data_path, "final_datasets") + os.makedirs(save_dir, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + for dataset_name, dataset_data in all_datasets.items(): + if dataset_data and isinstance(dataset_data, list): + filename = f"{dataset_name}_final_{timestamp}.jsonl" + file_path = os.path.join(save_dir, filename) + + with open(file_path, 'w', encoding='utf-8') as f: + for item in dataset_data: + # Add final metadata + item_with_meta = { + **item, + 'final_meta': { + 'dataset_type': dataset_name, + 'total_items': len(dataset_data), + 'final_timestamp': timestamp, + 'training_completed': True + } + } + f.write(json.dumps(item_with_meta, ensure_ascii=False) + '\n') + + PrettyPrinter.status("FINAL SAVE", f"Saved {len(dataset_data)} items to {filename}", "success") + + except Exception as e: + PrettyPrinter.status("FINAL SAVE", f"Failed to save final datasets: {e}", "error") + + def cleanup_and_save(self): + """Cleanup and save final data before training ends""" + self.save_final_datasets() + # Any other cleanup operations can be added here