#!/usr/bin/env python3 """ TTRLVR + AZR 통합 학습 메인 스크립트 (Unified Version) UnifiedTTRLVRTrainer를 사용하여 하나의 VeRL 세션에서 전체 학습 진행: 1. VeRL worker 한 번만 초기화 2. 각 라운드마다 같은 vLLM으로 Phase 1-4 실행 3. 같은 vLLM으로 Phase 5 PPO 학습 4. 동기화 문제 완전 해결 (dummy_dtensor 사용 가능) 사용 예시: # 일반 학습 python train_ttrlvr_azr_unified.py --benchmark mbpp --problems 10 --rounds 30 python train_ttrlvr_azr_unified.py --benchmark humaneval --problems 5 --rounds 10 python train_ttrlvr_azr_unified.py --benchmark mbpp --problem-id Mbpp/2 --rounds 5 # GPU 지정 python train_ttrlvr_azr_unified.py --benchmark mbpp --problems 10 --rounds 30 --gpu 0,1,2,3 """ import os import sys import argparse import json from datetime import datetime from pathlib import Path from typing import List import warnings import ray import torch # Gradient checkpointing 관련 경고 필터링 warnings.filterwarnings("ignore", message=".*Caching is incompatible with gradient checkpointing.*") # 경로 설정 - 상대 경로 사용 project_root = Path(__file__).parent.parent # TestTime-RLVR-v2 디렉토리 sys.path.append(str(project_root)) # verl과 Absolute-Zero-Reasoner는 상위 디렉토리에서 찾기 parent_dir = project_root.parent for lib_name in ['verl', 'Absolute-Zero-Reasoner']: lib_path = parent_dir / lib_name if lib_path.exists(): sys.path.append(str(lib_path)) # pip로 설치된 경우는 자동으로 import 됨 # AZR/VeRL 모듈 임포트 (main_azr_ppo.py와 동일한 구조) from verl import DataProto from omegaconf import OmegaConf import ray from verl.utils import hf_tokenizer from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role from verl.single_controller.ray import RayWorkerGroup from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker from absolute_zero_reasoner.utils.logging_utils.stdout import PrettyPrinter # TTRLVR 모듈 임포트 from absolute_zero_reasoner.testtime.config import TestTimeConfig, BenchmarkConfig from absolute_zero_reasoner.testtime.logger import TestTimeLogger # Ray 정리 변수 _trainer_instance = None _logger_instance = None def cleanup_ray(): """Ray 클러스터 정리 함수""" global _trainer_instance, _logger_instance try: if _logger_instance: _logger_instance.log_info("🔄 강제 종료 감지: Ray 클러스터 정리 중...") except: print("🔄 강제 종료 감지: Ray 클러스터 정리 중...") try: # IterativeTrainer 정리 if _trainer_instance: _trainer_instance.cleanup_ray() except Exception as e: try: if _logger_instance: _logger_instance.log_error(f"IterativeTrainer 정리 실패: {e}") except: print(f"IterativeTrainer 정리 실패: {e}") try: # 현재 프로그램의 Ray만 종료 (안전한 방법) import ray if ray.is_initialized(): ray.shutdown() except Exception as e: try: if _logger_instance: _logger_instance.log_error(f"Ray 종료 실패: {e}") except: print(f"Ray 종료 실패: {e}") try: if _logger_instance: _logger_instance.log_info("✅ Ray 정리 완료") except: print("✅ Ray 정리 완료") def signal_handler(signum, frame): """시그널 핸들러 (Ctrl+C, 강제 종료 등)""" try: if _logger_instance: _logger_instance.log_info(f"🛑 시그널 {signum} 수신: 프로그램 종료 중...") except: print(f"🛑 시그널 {signum} 수신: 프로그램 종료 중...") cleanup_ray() sys.exit(1) def parse_arguments(): """명령행 인자 파싱""" parser = argparse.ArgumentParser( description='TTRLVR + AZR 통합 반복 학습', formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 예시: # MBPP 10문제로 30라운드 학습 python train_ttrlvr_azr.py --benchmark mbpp --problems 10 --rounds 30 # HumanEval 5문제로 10라운드 학습 python train_ttrlvr_azr.py --benchmark humaneval --problems 5 --rounds 10 # 15라운드부터 재개 python train_ttrlvr_azr.py --benchmark mbpp --problems 10 --rounds 30 --resume 15 # 특정 GPU 사용 python train_ttrlvr_azr.py --benchmark mbpp --problems 10 --rounds 30 --gpu 4 """ ) parser.add_argument( '--benchmark', choices=['mbpp', 'humaneval'], default='mbpp', help='벤치마크 선택 (기본값: mbpp)' ) parser.add_argument( '--problems', type=int, default=10, help='문제 수 (기본값: 10)' ) parser.add_argument( '--problem-id', type=str, help='특정 문제 ID (예: HumanEval/1, Mbpp/10)' ) parser.add_argument( '--rounds', type=int, default=30, help='총 라운드 수 (기본값: 30)' ) parser.add_argument( '--resume', type=int, default=1, help='재개할 라운드 번호 (기본값: 1)' ) parser.add_argument( '--gpu', type=str, default='5', help='사용할 GPU 번호 (단일: 5, 다중: 1,2,3,5)' ) parser.add_argument( '--output-dir', type=str, default='./results/ttrlvr_azr', help='결과 저장 디렉토리 (기본값: ./results/ttrlvr_azr)' ) parser.add_argument( '--config', type=str, help='설정 파일 경로 (선택사항)' ) parser.add_argument( '--model', type=str, default='Qwen/Qwen2.5-7B', help='사용할 모델 (기본값: Qwen/Qwen2.5-7B)' ) parser.add_argument( '--debug', action='store_true', help='디버그 모드 활성화' ) parser.add_argument( '--batch-size', type=int, default=24, help='학습 배치 크기 (기본값: 24, OOM 시 줄이기)' ) parser.add_argument( '--batch-epochs', type=int, default=1, help='배치당 에폭 수 (기본값: 1, 더 많은 학습을 위해 증가 가능)' ) parser.add_argument( '--num-programs', type=int, default=4, help='생성할 다양한 프로그램 수 (기본값: 4, 더 다양한 데이터를 위해 증가 가능)' ) parser.add_argument( '--input-generation-rounds', type=int, default=3, help='다양한 입력 생성 라운드 수 (기본값: 3, 라운드당 5개씩 생성)' ) parser.add_argument( '--parallel-batch-size', type=int, default=4, help='동시 처리할 프롬프트 수 (기본값: 4, GPU 메모리에 따라 조정)' ) parser.add_argument( '--eval-rounds', type=int, default=5, help='매 라운드 정확도 측정 횟수 (기본값: 5, 더 정확한 평가를 위해 증가 가능)' ) parser.add_argument( '--skip-task-eval', action='store_true', help='Task evaluation(4단계) 스킵하여 빠른 테스트 (데이터 생성 후 바로 VeRL 학습)' ) parser.add_argument( '--save-every-round', action='store_true', help='매 라운드마다 체크포인트 저장 (기본값: False)' ) parser.add_argument( '--save-round-interval', type=int, default=5, help='체크포인트 저장 간격 (예: 5 = 5라운드마다 저장, 기본값: 5)' ) parser.add_argument( '--offline', action='store_true', help='Ray를 오프라인 모드로 실행 (외부 연결 차단, 기본값: False)' ) return parser.parse_args() def setup_environment(gpu_id: str, batch_size: int = None): """환경 변수 설정 - run_ttrlvr_azr_training.sh와 동일하게""" # GPU 설정 - 명령행 인자를 우선 사용하고, 없으면 기존 환경변수 사용 if gpu_id: os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id print(f"🎯 Using command line GPU setting: {gpu_id}") elif 'CUDA_VISIBLE_DEVICES' in os.environ and os.environ['CUDA_VISIBLE_DEVICES']: print(f"🎯 Using existing CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}") else: os.environ['CUDA_VISIBLE_DEVICES'] = '5' # 기본값 print(f"🎯 Using default GPU: 5") # VLLM 설정 (run_ttrlvr_azr_training.sh와 동일) os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASH_ATTN' # Ray 설정 (run_ttrlvr_azr_training.sh와 동일) os.environ['RAY_memory_monitor_refresh_ms'] = '0' os.environ['RAY_LOGGING_LEVEL'] = 'DEBUG' # Hydra 설정 os.environ['HYDRA_FULL_ERROR'] = '1' # Python 경로 설정 (verl 경로 추가) - 상대 경로 사용 pythonpath = os.environ.get('PYTHONPATH', '') project_root = Path(__file__).parent.parent # TestTime-RLVR-v2 directory # 프로젝트 경로들 설정 paths_to_add = [str(project_root)] parent_dir = project_root.parent # verl과 Absolute-Zero-Reasoner 경로 추가 (존재하는 경우) if (parent_dir / 'verl').exists(): paths_to_add.append(str(parent_dir / 'verl')) if (parent_dir / 'Absolute-Zero-Reasoner').exists(): paths_to_add.append(str(parent_dir / 'Absolute-Zero-Reasoner')) # PYTHONPATH 업데이트 for path in paths_to_add: if path not in pythonpath: pythonpath = f"{path}:{pythonpath}" if pythonpath else path os.environ['PYTHONPATH'] = pythonpath # batch size 설정 if batch_size is not None: os.environ['TRAIN_BATCH_SIZE'] = str(batch_size) # 추가 환경 변수 설정 (위에서 설정하지 않은 것들만) # 기본값으로 홈 디렉토리 사용, 환경변수로 오버라이드 가능 os.environ.setdefault('HF_HOME', os.path.expanduser('~/.cache/huggingface')) os.environ.setdefault('TRANSFORMERS_CACHE', os.path.expanduser('~/.cache/huggingface')) os.environ['TOKENIZERS_PARALLELISM'] = 'false' # PYTHONPATH 설정 - 상대 경로 사용 current_pythonpath = os.environ.get('PYTHONPATH', '') project_root = Path(__file__).parent.parent # TestTime-RLVR-v2 directory new_paths = [ str(project_root) # site-packages는 자동으로 포함되므로 제거 ] for path in new_paths: if path not in current_pythonpath: current_pythonpath = f"{path}:{current_pythonpath}" if current_pythonpath else path os.environ['PYTHONPATH'] = current_pythonpath def setup_offline_environment(): """오프라인 Ray 환경 설정 (AZR 방식과 동일)""" import os import subprocess # Ray 오프라인 환경변수 설정 os.environ["RAY_DISABLE_IMPORT_WARNING"] = "1" os.environ["RAY_USAGE_STATS_ENABLED"] = "0" os.environ["RAY_DISABLE_RUNTIME_METRICS"] = "1" os.environ["RAY_ADDRESS"] = "" os.environ["RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE"] = "1" # Ray GPU 자동 설정 비활성화 (핵심!) os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" # AMD GPU 환경변수 완전 제거 os.environ.pop("HIP_VISIBLE_DEVICES", None) os.environ.pop("ROCR_VISIBLE_DEVICES", None) os.environ.pop("HSA_VISIBLE_DEVICES", None) # Ray 프로세스 정리 try: subprocess.run(["pkill", "-f", "ray"], capture_output=True, timeout=2) except: pass try: subprocess.run(["ray", "stop", "--force"], capture_output=True, timeout=2) except: pass print("🔧 오프라인 Ray 환경 설정 완료") print(" - 외부 연결 차단") print(" - GPU 충돌 방지 설정") print(" - Ray 프로세스 정리") def load_benchmark_problems(benchmark_config: BenchmarkConfig) -> List[str]: """벤치마크에서 문제 ID 목록 로드 (기존 TTRLVR 방식 사용)""" problems = [] if benchmark_config.name == 'mbpp': # MBPP+ EvalPlus 표준 데이터 로딩 try: from evalplus.data.mbpp import get_mbpp_plus mbpp_problems = get_mbpp_plus() # 자동으로 mbpp_deserialize_inputs 적용됨 problems = list(mbpp_problems.keys()) print(f"✅ MBPP+ 데이터 로드 성공: {len(problems)}개 문제 (EvalPlus 표준 방식)") except Exception as e: print(f"❌ MBPP+ EvalPlus 로딩 실패, 기존 방식 사용: {e}") # Fallback to original method data_path = benchmark_config.data_path if os.path.exists(data_path): with open(data_path, 'r') as f: for line in f: problem = json.loads(line.strip()) problems.append(problem['task_id']) elif benchmark_config.name == 'humaneval': # HumanEval+ EvalPlus 표준 데이터 로딩 try: from evalplus.data.humaneval import get_human_eval_plus humaneval_problems = get_human_eval_plus() problems = list(humaneval_problems.keys()) print(f"✅ HumanEval+ 데이터 로드 성공: {len(problems)}개 문제 (EvalPlus 표준 방식)") except Exception as e: print(f"❌ HumanEval+ EvalPlus 로딩 실패, 기존 방식 사용: {e}") # Fallback to original method data_path = benchmark_config.data_path if os.path.exists(data_path): with open(data_path, 'r') as f: for line in f: problem = json.loads(line.strip()) problems.append(problem['task_id']) return problems def create_problem_list(benchmark: str, num_problems: int, specific_problem_id: str = None) -> list: """벤치마크별 문제 ID 리스트 생성 (기존 TTRLVR 방식 사용)""" # BenchmarkConfig 생성 benchmark_config = create_benchmark_config(benchmark) # 전체 문제 목록 로드 all_problems = load_benchmark_problems(benchmark_config) if not all_problems: raise ValueError(f"No problems found for benchmark: {benchmark}") # 특정 문제 ID가 지정된 경우 if specific_problem_id: if specific_problem_id in all_problems: return [specific_problem_id] else: raise ValueError(f"Problem ID '{specific_problem_id}' not found in {benchmark} benchmark") # 요청된 수만큼 문제 선택 if num_problems <= 0 or num_problems > len(all_problems): return all_problems else: return all_problems[:num_problems] def create_config(args) -> TestTimeConfig: """TestTimeConfig 생성""" config = TestTimeConfig() # 기본 설정 config.model_name = args.model # 인자로 받은 모델 사용 config.max_new_tokens = 512 config.temperature = 0.05 config.baseline_evaluation_rounds = args.eval_rounds # 평가 횟수 # 프로그램 생성 설정 config.num_program_variations = args.num_programs # 다양한 프로그램 개수 config.input_generation_rounds = args.input_generation_rounds # 입력 생성 라운드 수 config.parallel_batch_size = args.parallel_batch_size # 동시 처리 프롬프트 수 # Task evaluation 스킵 설정 config.skip_task_evaluation = args.skip_task_eval # Task evaluation 스킵 여부 # 디버그 모드 if args.debug: config.debug = True config.verbose = True return config def create_benchmark_config(benchmark: str) -> BenchmarkConfig: """BenchmarkConfig 생성 (기존 TTRLVR 방식 사용)""" # 기존 TTRLVR 시스템과 동일한 방식으로 BenchmarkConfig 생성 # TestTime-RLVR-v2 디렉토리를 base로 사용 base_dir = Path(__file__).parent.parent # TestTime-RLVR-v2 directory if benchmark == 'mbpp': benchmark_config = BenchmarkConfig.get_mbpp_config() benchmark_config.data_path = str(base_dir / 'evaluation/code_eval/data/MbppPlus.jsonl') return benchmark_config elif benchmark == 'humaneval': benchmark_config = BenchmarkConfig.get_humaneval_config() benchmark_config.data_path = str(base_dir / 'evaluation/code_eval/data/HumanEvalPlus.jsonl') return benchmark_config else: raise ValueError(f"Unknown benchmark: {benchmark}") def run_step5_only_mode(args): """Step 5 전용 모드 실행""" from pathlib import Path print(f"🎓 Running Step 5 (VeRL training) only mode") print(f"📂 Data path: {args.data_path}") # 데이터 경로 검증 data_path = Path(args.data_path) if not data_path.exists(): print(f"❌ Error: Data path does not exist: {data_path}") return 1 # 필수 파일들 확인 required_files = ['induction.parquet', 'deduction.parquet', 'abduction.parquet'] missing_files = [] for file_name in required_files: if not (data_path / file_name).exists(): missing_files.append(file_name) if missing_files: print(f"❌ Error: Missing required files: {missing_files}") return 1 print(f"✅ Found all required training data files in: {data_path}") # 파일 크기 정보 출력 for file_name in required_files: file_path = data_path / file_name file_size = file_path.stat().st_size print(f" 📄 {file_name}: {file_size:,} bytes") # 환경 설정 setup_environment(args.gpu, args.batch_size) # 설정 파일 경로 결정 config_path = args.config if not config_path: # GPU 개수에 따라 기본 설정 파일 선택 gpu_count = len(args.gpu.split(',')) if args.gpu else 1 if gpu_count >= 4: config_path = str(Path(__file__).parent / 'configs/ttrlvr_azr_ppo_4gpu.yaml') else: config_path = str(Path(__file__).parent / 'configs/ttrlvr_azr_ppo_1gpu.yaml') print(f"🚀 Initializing trainer with config: {config_path}") # TestTimeConfig 생성 (기존 create_config 함수 사용) config = create_config(args) # 로거 초기화 logger = TestTimeLogger() # IterativeTrainer 초기화 global _trainer_instance _trainer_instance = IterativeTrainer( config=config, logger=logger, verl_config_path=config_path ) # Step 5 전용 VeRL 학습 실행 try: result = _trainer_instance.run_verl_training_only( training_data_path=str(data_path), round_num=args.resume, # resume을 round number로 사용 experiment_name=f"step5_only_{args.benchmark}" ) if result.get('success', False): print(f"✅ VeRL training completed successfully!") print(f"⏱️ Duration: {result.get('duration', 'N/A')} seconds") if 'model_path' in result: print(f"🤖 Updated model: {result['model_path']}") return 0 else: print(f"❌ VeRL training failed: {result.get('error', 'Unknown error')}") return 1 except Exception as e: print(f"💥 Training failed with exception: {e}") import traceback traceback.print_exc() return 1 def main(): """메인 실행 함수 - UnifiedTTRLVRTrainer 사용""" # 인자 파싱 args = parse_arguments() # 환경 설정 setup_environment(args.gpu) # 출력 디렉토리 생성 timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') output_dir = os.path.join( args.output_dir, f'ttrlvr_unified_{args.benchmark}_{args.rounds}rounds_{timestamp}' ) os.makedirs(output_dir, exist_ok=True) PrettyPrinter.section_header("🚀 TTRLVR Unified Training") PrettyPrinter.status("Config", f"Benchmark: {args.benchmark}", "info") PrettyPrinter.status("Config", f"Rounds: {args.rounds}", "info") PrettyPrinter.status("Config", f"Output: {output_dir}", "info") # 문제 리스트 생성 problem_ids = create_problem_list(args.benchmark, args.problems, args.problem_id) PrettyPrinter.status("Problems", f"Selected {len(problem_ids)} problems", "info") # TTRLVR 설정 ttrlvr_config = { 'num_programs': args.num_programs, 'input_generation_rounds': args.input_generation_rounds, 'parallel_batch_size': args.parallel_batch_size, } # VeRL config 파일 경로 if args.config: config_path = os.path.abspath(args.config) else: # 현재는 4GPU config만 사용 (추후 1GPU config 추가 시 수정) config_path = str(Path(__file__).parent / 'configs/ttrlvr_azr_unified_4gpu.yaml') PrettyPrinter.status("Config", f"Using VeRL config: {config_path}", "info") try: # ============================================ # VeRL을 통해 UnifiedTTRLVRTrainer 실행 # ============================================ # VeRL 실행을 위한 환경 변수 설정 os.environ['TTRLVR_PROBLEM_IDS'] = json.dumps(problem_ids) os.environ['TTRLVR_TOTAL_ROUNDS'] = str(args.rounds) os.environ['TTRLVR_OUTPUT_DIR'] = output_dir os.environ['TTRLVR_CONFIG'] = json.dumps(ttrlvr_config) # ============================================ # AZR 형식으로 초기화, TTRLVR 방식으로 실행 # (main_azr_ppo.py의 구조를 따르되 UnifiedTTRLVRTrainer 사용) # ============================================ PrettyPrinter.section_header("🎯 Starting UnifiedTTRLVRTrainer (AZR-style initialization)") # 1. Config 로드 (main_azr_ppo.py와 동일) PrettyPrinter.status("Config", f"Loading {config_path}", "info") verl_config = OmegaConf.load(config_path) # Config 업데이트 verl_config.trainer.project_name = f'ttrlvr_unified_{args.benchmark}' verl_config.trainer.experiment_name = f'round_{args.rounds}_{timestamp}' verl_config.trainer.total_epochs = args.rounds # 2. Ray 초기화 (오프라인/온라인 모드 분기) if not ray.is_initialized(): cuda_visible_devices = args.gpu or "0,1,2,3" if args.offline: # 오프라인 모드: AZR 수정과 동일 PrettyPrinter.status("Ray", f"Initializing Ray in OFFLINE mode (GPUs: {cuda_visible_devices})", "warning") # 오프라인 환경 설정 (이미 함수가 정의되어 있음) setup_offline_environment() # GPU 환경변수 강제 설정 os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices # AMD GPU 변수 제거 (setup_offline_environment에서 이미 처리) # 기존 Ray 강제 종료 if ray.is_initialized(): ray.shutdown() # 오프라인용 단순 초기화 (runtime_env 완전 제거) ray.init( local_mode=False, # 클러스터 모드지만 로컬에서 실행 ignore_reinit_error=True, # 재초기화 에러 무시 num_cpus=verl_config.ray_init.num_cpus, ) PrettyPrinter.status("Ray", "Ray initialized in OFFLINE mode", "success") else: # 온라인 모드: 기존 방식 유지 PrettyPrinter.status("Ray", f"Initializing Ray cluster (GPUs: {cuda_visible_devices})", "info") ray.init( runtime_env={"env_vars": { "TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN", "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true", "CUDA_VISIBLE_DEVICES": cuda_visible_devices }}, num_cpus=verl_config.ray_init.num_cpus, # num_gpus 지정하지 않음 - Ray가 자동으로 GPU 감지 (AZR 원본과 동일) ) # 3. Tokenizer 로드 (main_azr_ppo.py와 동일) model_path = verl_config.actor_rollout_ref.model.path PrettyPrinter.status("Model", f"Loading tokenizer from {model_path}", "info") tokenizer = hf_tokenizer(model_path) # 4. Worker 매핑 설정 (main_azr_ppo.py와 동일) role_worker_mapping = {} # Actor/Rollout Worker 선택 if verl_config.actor_rollout_ref.rollout.name == 'vllm': if verl_config.actor_rollout_ref.rollout.mode == 'async': actor_rollout_cls = AsyncActorRolloutRefWorker else: actor_rollout_cls = ActorRolloutRefWorker # AZR 원본과 동일하게 ray.remote() 사용 role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls) PrettyPrinter.status("Workers", f"Using {actor_rollout_cls.__name__} for ActorRollout", "info") # Critic Worker (REINFORCE++는 사용 안함) if verl_config.critic.include_critic: # AZR 원본과 동일하게 ray.remote() 사용 role_worker_mapping[Role.Critic] = ray.remote(CriticWorker) PrettyPrinter.status("Workers", "Including Critic worker", "info") else: PrettyPrinter.status("Workers", "No Critic (using REINFORCE++)", "info") # 5. ResourcePoolManager 생성 (main_azr_ppo.py와 동일) # AZR 스타일로 resource_pool_spec 직접 생성 global_pool_id = "global_pool" n_gpus_per_node = verl_config.trainer.n_gpus_per_node nnodes = verl_config.trainer.nnodes resource_pool_spec = { global_pool_id: [n_gpus_per_node] * nnodes, } mapping = { Role.ActorRollout: global_pool_id, } if verl_config.critic.include_critic: mapping[Role.Critic] = global_pool_id resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) PrettyPrinter.status("Resources", f"Created ResourcePoolManager with {len(resource_pool_spec)} pools", "info") # 6. UnifiedTTRLVRTrainer 생성 (CodeIORayPPOTrainer 대신) from trainer.unified_ttrlvr_trainer import UnifiedTTRLVRTrainer PrettyPrinter.status("Trainer", "Creating UnifiedTTRLVRTrainer", "info") trainer = UnifiedTTRLVRTrainer( past_epoch_window=verl_config.azr.past_epoch_window, # AZR 필수 파라미터 (TTRLVR은 매 라운드 새 데이터) config=verl_config, tokenizer=tokenizer, processor=None, # TTRLVR은 텍스트 전용이므로 불필요 role_worker_mapping=role_worker_mapping, resource_pool_manager=resource_pool_manager, ray_worker_group_cls=RayWorkerGroup, reward_fn=None, # TTRLVR은 자체 보상 계산 사용 (use_ttrlvr_rewards=True) val_reward_fn=None, # TTRLVR은 검증 없음 # TTRLVR 특화 파라미터 ttrlvr_config=ttrlvr_config, problem_ids=problem_ids, total_rounds=args.rounds, output_dir=output_dir ) # 7. 학습 실행 (main_azr_ppo.py와 동일) PrettyPrinter.section_header("🚀 Starting Training") PrettyPrinter.status("Training", f"Running {args.rounds} rounds with {len(problem_ids)} problems", "info") trainer.fit() # 내부에서 TTRLVR Phase 1-5 실행 PrettyPrinter.section_header("✅ Training Complete") return 0 except KeyboardInterrupt: PrettyPrinter.status("Interrupt", "Training interrupted by user", "warning") return 130 except Exception as e: PrettyPrinter.status("Error", f"Training failed: {e}", "error") import traceback traceback.print_exc() return 1 finally: # Ray cleanup if ray.is_initialized(): ray.shutdown() PrettyPrinter.status("Cleanup", "Resources cleaned up", "success") if __name__ == '__main__': exit_code = main() sys.exit(exit_code)