#!/usr/bin/env python3
"""
TTRLVR + AZR ๋ฐ˜๋ณต ํ•™์Šต ํŠธ๋ ˆ์ด๋„ˆ
30๋ผ์šด๋“œ ๋ฐ˜๋ณต ํ•™์Šต์„ ๊ด€๋ฆฌํ•˜๋ฉฐ, ๊ฐ ๋ผ์šด๋“œ๋งˆ๋‹ค:
1. TTRLVR ํŒŒ์ดํ”„๋ผ์ธ์œผ๋กœ (i,p,o) โ†’ tasks ์ƒ์„ฑ
2. ํ•ด๋‹น ๋ผ์šด๋“œ ๋ฐ์ดํ„ฐ๋กœ ์‹ค์ œ AZR์˜ CodeIORayPPOTrainer ํ•™์Šต
3. ๊ฐœ์„ ๋œ ๋ชจ๋ธ๋กœ ๋‹ค์Œ ๋ผ์šด๋“œ ์ง„ํ–‰
"""
import os
import sys
import json
import pandas as pd
import ray
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Any, Optional
# TTRLVR module imports
sys.path.append('/home/ubuntu/RLVR/TestTime-RLVR-v2')
from absolute_zero_reasoner.testtime.complete_pipeline import CompleteTestTimePipeline
from absolute_zero_reasoner.testtime.config import TestTimeConfig, BenchmarkConfig
from absolute_zero_reasoner.testtime.logger import TestTimeLogger
# Imports for running VeRL-based AZR
from utils.checkpoint_manager import CheckpointManager
from absolute_zero_reasoner.trainer.ppo.azr_ray_trainer import CodeIORayPPOTrainer
from utils.custom_ray_trainer import CustomCodeIORayPPOTrainer
import hydra
from hydra.core.global_hydra import GlobalHydra
class IterativeTrainer:
"""TTRLVR + AZR ๋ฐ˜๋ณต ํ•™์Šต ๊ด€๋ฆฌ์ž"""
def __init__(self, config: TestTimeConfig, logger: Optional[TestTimeLogger] = None, batch_epochs: int = 1, verl_config_path: str = None, save_every_round: bool = False, save_round_interval: int = 5):
self.config = config
self.logger = logger or TestTimeLogger()
self.batch_epochs = batch_epochs # ๋ฐฐ์น˜๋‹น ์—ํญ ์ˆ˜ ์ €์žฅ
self.verl_config_path = verl_config_path # VeRL config ํŒŒ์ผ ๊ฒฝ๋กœ
self.save_every_round = save_every_round # ๋งค ๋ผ์šด๋“œ ์ €์žฅ ์—ฌ๋ถ€
self.save_round_interval = save_round_interval # ์ €์žฅ ๊ฐ„๊ฒฉ
# GPU ๊ฐœ์ˆ˜ ๊ฐ์ง€ ๋ฐ ์‹คํ–‰ ๋ชจ๋“œ ๊ฒฐ์ •
self.available_gpus = self._detect_available_gpus()
self.execution_mode = self._determine_execution_mode()
self.logger.log_info(f"๐ŸŽฏ Detected {len(self.available_gpus)} GPUs: {self.available_gpus}")
self.logger.log_info(f"๐ŸŽฏ Execution mode: {self.execution_mode}")
# ์™„์ „ํ•œ ํŒŒ์ดํ”„๋ผ์ธ ์ธ์Šคํ„ด์Šค (lazy initialization)
self.complete_pipeline = None
# ์ฒดํฌํฌ์ธํŠธ ๋งค๋‹ˆ์ € ์ดˆ๊ธฐํ™”
self.checkpoint_manager = CheckpointManager(logger=self.logger)
# ํ•™์Šต ์ƒํƒœ ์ถ”์ 
self.original_model_name = config.model_name # ์›๋ณธ ๋ชจ๋ธ ์ด๋ฆ„ ์ €์žฅ (tokenizer ๋กœ๋“œ์šฉ)
self.current_model_path = config.model_name # config์—์„œ ๋ชจ๋ธ ์ด๋ฆ„ ๊ฐ€์ ธ์˜ค๊ธฐ
self.current_model = None # ํ˜„์žฌ ๋ชจ๋ธ ์ธ์Šคํ„ด์Šค ์ €์žฅ (VeRL๊ณผ ๊ณต์œ ์šฉ)
self.round_results = {}
self.checkpoint_dir = "/data/RLVR/checkpoints/ttrlvr_azr"
# Ray Actor๋กœ ํŒŒ์ดํ”„๋ผ์ธ ๊ด€๋ฆฌ (VeRL ํŒจํ„ด)
self.remote_pipeline = None
# VeRL trainer ์ธ์Šคํ„ด์Šค (ํ•œ ๋ฒˆ๋งŒ ์ดˆ๊ธฐํ™”, ๋ฉ”๋ชจ๋ฆฌ์—์„œ ๊ณ„์† ์‚ฌ์šฉ)
self.verl_trainer = None
self.verl_config = None
self.ray_initialized = False
# ํ•™์Šต ์‹คํ–‰ ์‹œ๊ฐ„ ๊ธฐ๋ก
self.start_time = None
self.round_times = {}
def cleanup(self):
"""Ray ํด๋Ÿฌ์Šคํ„ฐ ๋ฐ ๊ด€๋ จ ๋ฆฌ์†Œ์Šค ์ •๋ฆฌ"""
try:
self.logger.log_info("๐Ÿงน Starting cleanup process...")
# VeRL trainer ์ •๋ฆฌ
if hasattr(self, 'verl_trainer') and self.verl_trainer is not None:
try:
self.logger.log_info(" - Cleaning up VeRL trainer...")
# VeRL trainer์˜ Ray actors ์ •๋ฆฌ
if hasattr(self.verl_trainer, 'shutdown'):
self.verl_trainer.shutdown()
self.verl_trainer = None
except Exception as e:
self.logger.log_warning(f" - VeRL trainer cleanup warning: {e}")
# Remote pipeline actor ์ข…๋ฃŒ
if self.remote_pipeline is not None:
try:
self.logger.log_info(" - Killing remote pipeline actor...")
ray.kill(self.remote_pipeline)
except:
pass
self.remote_pipeline = None
# Ray ํด๋Ÿฌ์Šคํ„ฐ ์ข…๋ฃŒ
if self.ray_initialized and ray.is_initialized():
self.logger.log_info(" - Shutting down Ray cluster...")
# ๋ชจ๋“  Ray actors ๊ฐ•์ œ ์ข…๋ฃŒ
try:
# ํ˜„์žฌ ์‹คํ–‰ ์ค‘์ธ ๋ชจ๋“  actors ๊ฐ€์ ธ์˜ค๊ธฐ
actors = ray.util.list_named_actors()
if actors:
self.logger.log_info(f" - Found {len(actors)} named actors to kill")
for actor in actors:
try:
ray.kill(ray.get_actor(actor['name']))
except:
pass
except:
pass
# Ray shutdown with force
ray.shutdown()
self.ray_initialized = False
# Ray ํ”„๋กœ์„ธ์Šค๊ฐ€ ์™„์ „ํžˆ ์ข…๋ฃŒ๋  ๋•Œ๊นŒ์ง€ ์ž ์‹œ ๋Œ€๊ธฐ
import time
time.sleep(2)
self.logger.log_info("โœ… Ray cluster shutdown complete")
# GPU ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.logger.log_info(" - GPU memory cleared")
except:
pass
except Exception as e:
self.logger.log_error(f"Error during cleanup: {e}")
# ๊ทธ๋ž˜๋„ Ray๋Š” ๊ฐ•์ œ ์ข…๋ฃŒ ์‹œ๋„
try:
ray.shutdown()
except:
pass
def run_iterative_training(self, benchmark_config: BenchmarkConfig,
problem_ids: List[str],
total_rounds: int = 30,
resume_from_round: int = 1) -> Dict[str, Any]:
"""30๋ผ์šด๋“œ ๋ฐ˜๋ณต ํ•™์Šต ๋ฉ”์ธ ๋ฃจํ”„"""
self.start_time = datetime.now()
# ์„ธ์…˜ ์ „์ฒด์—์„œ ์‚ฌ์šฉํ•  timestamp ์ƒ์„ฑ (ํ•œ ๋ฒˆ๋งŒ)
self.session_timestamp = self.start_time.strftime('%Y%m%d_%H%M%S')
self.logger.log_info(f"๐Ÿš€ Starting TTRLVR + AZR iterative training")
self.logger.log_info(f"๐Ÿ“Š Configuration: {len(problem_ids)} problems, {total_rounds} rounds")
self.logger.log_info(f"๐ŸŽฏ Problems: {problem_ids}")
self.logger.log_info(f"๐Ÿ“ Session timestamp: {self.session_timestamp}")
# ์ฒดํฌํฌ์ธํŠธ์—์„œ ์žฌ๊ฐœํ•˜๋Š” ๊ฒฝ์šฐ
if resume_from_round > 1:
self.logger.log_info(f"๐Ÿ”„ Resuming from round {resume_from_round}")
checkpoint_model = self._load_checkpoint(resume_from_round - 1)
if checkpoint_model:
self.current_model_path = checkpoint_model
training_results = {
'start_time': self.start_time.isoformat(),
'session_timestamp': self.session_timestamp,
'benchmark': benchmark_config.name,
'problem_ids': problem_ids,
'total_rounds': total_rounds,
'resume_from_round': resume_from_round,
'rounds': {},
'success': False,
'error': None
}
        try:
            # Main iterative training loop
            for round_num in range(resume_from_round, total_rounds + 1):
                round_start_time = datetime.now()
                self.logger.log_info("=" * 80)
                self.logger.log_info(f"🔄 ROUND {round_num}/{total_rounds} - Starting")
                self.logger.log_info(f"🤖 Current model: {self.current_model_path}")
                self.logger.log_info("=" * 80)
                # Run a single round
                round_result = self._run_single_round(
                    benchmark_config, problem_ids, round_num
                )
                # Record the round result
                round_end_time = datetime.now()
                round_duration = (round_end_time - round_start_time).total_seconds()
                self.round_times[round_num] = round_duration
                round_result['duration_seconds'] = round_duration
                round_result['model_before'] = self.current_model_path
                training_results['rounds'][round_num] = round_result
                if not round_result['success']:
                    self.logger.log_error(f"❌ Round {round_num} failed: {round_result.get('error', 'Unknown error')}")
                    continue
                # Run AZR training
                if round_result['training_data_files']:
                    self.logger.log_info(f"🎓 Starting AZR training for round {round_num}")
                    new_model_path = self._train_azr_with_round_data(
                        round_result['training_data_files'], round_num
                    )
                    if new_model_path:
                        self.current_model_path = new_model_path
                        round_result['model_after'] = new_model_path
                        self.logger.log_info(f"✅ Round {round_num} completed successfully")
                        self.logger.log_info(f"🎯 New model: {new_model_path}")
                        # ⭐ Also push the trained weights into the VLLM Ray actor (true model sharing)
                        if hasattr(self, 'remote_pipeline') and self.remote_pipeline is not None:
                            self.logger.log_info("🔄 Updating VLLM Ray Actor weights with trained model")
                            update_success = ray.get(self.remote_pipeline.update_model_weights.remote(new_model_path))
                            if update_success:
                                self.logger.log_info("✅ VLLM weights updated successfully for next round")
                            else:
                                self.logger.log_warning("⚠️ Failed to update VLLM weights, using old model")
                    else:
                        self.logger.log_error(f"❌ AZR training failed for round {round_num}")
                        round_result['training_error'] = "AZR training failed"
                # Save a checkpoint every 5 rounds
                if round_num % 5 == 0:
                    self._save_checkpoint(round_num, self.current_model_path, training_results)
                    self.logger.log_info(f"💾 Checkpoint saved for round {round_num}")
                # Log the round summary
                self._log_round_summary(round_num, round_result, round_duration)
            # Whole training run complete
            training_results['success'] = True
            training_results['end_time'] = datetime.now().isoformat()
            training_results['total_duration_seconds'] = (datetime.now() - self.start_time).total_seconds()
            training_results['final_model'] = self.current_model_path
            self.logger.log_info("🎉 TTRLVR + AZR iterative training completed successfully!")
            # Clean up the VeRL trainer
            if hasattr(self, 'verl_trainer') and self.verl_trainer is not None:
                self.logger.log_info("🧹 Cleaning up VeRL Trainer...")
                try:
                    # VeRL trainer cleanup (Ray actors, etc.)
                    if hasattr(self.verl_trainer, 'cleanup'):
                        self.verl_trainer.cleanup()
                    self.verl_trainer = None
                except Exception as cleanup_error:
                    self.logger.log_warning(f"Cleanup warning: {cleanup_error}")
            self._log_final_summary(training_results)
            return training_results
        except Exception as e:
            self.logger.log_error(f"💥 Iterative training failed: {e}")
            import traceback
            traceback.print_exc()
return {
'success': False,
'error': str(e),
'rounds': self.round_results
}
def run_verl_training_only(self, training_data_path: str, round_num: int = 1,
experiment_name: Optional[str] = None) -> Dict[str, Any]:
"""
VeRL training(5๋‹จ๊ณ„)๋งŒ ๋ณ„๋„๋กœ ์‹คํ–‰
1-4๋‹จ๊ณ„์—์„œ ์ƒ์„ฑ๋œ ๋ฐ์ดํ„ฐ๋กœ VeRL PPO ํ•™์Šต๋งŒ ์ˆ˜ํ–‰
Args:
training_data_path: TTRLVR์—์„œ ์ƒ์„ฑ๋œ ํ•™์Šต ๋ฐ์ดํ„ฐ ๊ฒฝ๋กœ (parquet ํŒŒ์ผ๋“ค)
round_num: ๋ผ์šด๋“œ ๋ฒˆํ˜ธ (๋กœ๊ทธ์šฉ)
experiment_name: ์‹คํ—˜ ์ด๋ฆ„ (์„ ํƒ์‚ฌํ•ญ)
Returns:
ํ•™์Šต ๊ฒฐ๊ณผ ๋”•์…”๋„ˆ๋ฆฌ
"""
        try:
            self.logger.log_info("🚀 Starting VeRL training ONLY (Step 5)")
            self.logger.log_info("=" * 80)
            self.logger.log_info(f"📂 Training data path: {training_data_path}")
            self.logger.log_info(f"🔄 Round: {round_num}")
            # Load the VeRL config (if needed)
            if not hasattr(self, 'verl_config') or self.verl_config is None:
                self.logger.log_info("🔧 Loading VeRL config for standalone training")
                self._load_verl_config()
            # Validate the training data path
            if not os.path.exists(training_data_path):
                raise FileNotFoundError(f"Training data path not found: {training_data_path}")
            # Find the parquet files
            parquet_files = list(Path(training_data_path).glob("*.parquet"))
            if not parquet_files:
                raise FileNotFoundError(f"No parquet files found in: {training_data_path}")
            # Update the VeRL config
            self.verl_config.data.train_files = [str(f) for f in parquet_files]
            self.verl_config.data.val_files = [str(f) for f in parquet_files[:1]]  # first file doubles as validation
            self.logger.log_info(f"📊 Found {len(parquet_files)} training files")
            for i, f in enumerate(parquet_files):
                self.logger.log_info(f"  {i+1}. {f.name}")
# โญ VeRL trainer ์ดˆ๊ธฐํ™” (์‹ค์ œ ๋ฐ์ดํ„ฐ๋กœ)
if not hasattr(self, 'verl_trainer') or self.verl_trainer is None:
self.logger.log_info("๐Ÿš€ Initializing VeRL trainer with actual training data")
# GPU ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ (FSDP ๋กœ๋“œ ์ „)
import torch
torch.cuda.empty_cache()
# VeRL trainer ์ดˆ๊ธฐํ™”
self._initialize_verl_trainer(training_data_path)
self.logger.log_info(f"โœ… FSDP model loaded on GPU {self.available_gpus}")
self.logger.log_info("โœ… GPU sharing enabled: VLLM + FSDP on same GPUs")
else:
# ๊ธฐ์กด VeRL trainer๊ฐ€ ์žˆ๋‹ค๋ฉด ๋ฐ์ดํ„ฐ ๊ฒฝ๋กœ ์—…๋ฐ์ดํŠธ
self.logger.log_info("๐Ÿ”„ Updating existing VeRL trainer with new data files")
# Trainer์˜ config ์—…๋ฐ์ดํŠธ
self.verl_trainer.config.data.train_files = self.verl_config.data.train_files
self.verl_trainer.config.data.val_files = self.verl_config.data.val_files
# init_workers๊ฐ€ ๋ฐ์ดํ„ฐ๋กœ๋”๋ฅผ ์ƒ์„ฑํ•˜๋ฏ€๋กœ ๋‹ค์‹œ ํ˜ธ์ถœ
self.logger.log_info("๐Ÿ”ง Re-initializing VeRL workers with new data...")
self.verl_trainer.init_workers()
self.logger.log_info("โœ… VeRL workers re-initialized with actual training data")
# ์‹คํ—˜๋ช… ์„ค์ •
if experiment_name:
self.verl_config.experiment.name = experiment_name
else:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.verl_config.experiment.name = f"verl_only_round_{round_num}_{timestamp}"
self.logger.log_info(f"๐Ÿท๏ธ Experiment: {self.verl_config.experiment.name}")
# 5๋‹จ๊ณ„: VeRL PPO ํ•™์Šต ์‹คํ–‰
self.logger.log_info("๐ŸŽ“ Starting VeRL PPO training...")
start_time = datetime.now()
            # Run the training directly with the VeRL trainer
            try:
                if hasattr(self, 'verl_trainer') and self.verl_trainer is not None:
                    # Directory for responses generated during training (added to the existing path structure)
                    llm_responses_dir = os.path.join(os.path.dirname(training_data_path), "llm_responses")
                    os.makedirs(llm_responses_dir, exist_ok=True)
                    self.logger.log_info(f"📁 LLM responses will be saved to: {llm_responses_dir}")
                    # Point the VeRL config at the rollout data save path
                    self.verl_trainer.config.trainer.rollout_data_dir = llm_responses_dir
                    # Callback state for custom logging
                    self.llm_responses_dir = llm_responses_dir
                    self.response_counter = 0
                    self.logger.log_info("🎯 Running VeRL PPO training...")
                    self.verl_trainer.fit()
                    training_result = {'success': True, 'model_path': self.current_model_path}
                    self.logger.log_info("✅ VeRL training completed successfully")
                    # Convert the JSONL files into TTRLVR format
                    if hasattr(self, 'llm_responses_dir') and os.path.exists(self.llm_responses_dir):
                        self.logger.log_info("📝 Converting VeRL outputs to TTRLVR format...")
                        jsonl_files = list(Path(self.llm_responses_dir).glob("*.jsonl"))
                        for jsonl_file in jsonl_files:
                            self._convert_jsonl_to_ttrlvr_format(str(jsonl_file), self.llm_responses_dir)
                        self.logger.log_info(f"✅ Converted {len(jsonl_files)} JSONL files to TTRLVR format")
else:
raise ValueError("VeRL trainer not initialized")
except Exception as e:
self.logger.log_error(f"VeRL training failed: {e}")
training_result = {'success': False, 'error': str(e)}
end_time = datetime.now()
duration = (end_time - start_time).total_seconds()
self.logger.log_info(f"โฑ๏ธ VeRL training completed in {duration:.1f} seconds")
# ๊ฒฐ๊ณผ ๊ตฌ์„ฑ
result = {
'success': training_result.get('success', False),
'round': round_num,
'experiment_name': self.verl_config.experiment.name,
'training_data_path': training_data_path,
'duration_seconds': duration,
'start_time': start_time.isoformat(),
'end_time': end_time.isoformat(),
'training_files': len(parquet_files),
                'model_path': training_result.get('model_path', self.current_model_path),
'details': training_result
}
            if result['success']:
                self.logger.log_info("🎉 VeRL training completed successfully!")
                if 'model_path' in training_result:
                    self.current_model_path = training_result['model_path']
                    self.logger.log_info(f"🤖 Updated model path: {self.current_model_path}")
            else:
                self.logger.log_error("❌ VeRL training failed")
                self.logger.log_error(f"Error: {training_result.get('error', 'Unknown error')}")
return result
except Exception as e:
self.logger.log_error(f"๐Ÿ’ฅ VeRL-only training failed: {e}")
import traceback
traceback.print_exc()
return {
'success': False,
'error': str(e),
'round': round_num,
'training_data_path': training_data_path
}
def _process_single_round(self, benchmark_config: BenchmarkConfig,
problem_ids: List[str], round_num: int) -> Dict[str, Any]:
"""๋‹จ์ผ ๋ผ์šด๋“œ ์ฒ˜๋ฆฌ"""
round_start_time = datetime.now()
self.logger.log_info(f"๐Ÿ”„ ROUND {round_num} - Starting")
try:
# ๋ผ์šด๋“œ ์‹คํ–‰ ๋กœ์ง์„ _run_single_round๋กœ ์œ„์ž„
return self._run_single_round(benchmark_config, problem_ids, round_num)
except Exception as e:
self.logger.log_error(f"Round {round_num} failed: {e}")
return {
'success': False,
'error': str(e),
'round': round_num,
'duration_seconds': (datetime.now() - round_start_time).total_seconds()
}
def _run_single_round(self, benchmark_config: BenchmarkConfig,
problem_ids: List[str], round_num: int) -> Dict[str, Any]:
"""๋‹จ์ผ ๋ผ์šด๋“œ ์‹คํ–‰ - Ray๋ฅผ ํ™œ์šฉํ•œ ๋ณ‘๋ ฌ TTRLVR ํŒŒ์ดํ”„๋ผ์ธ ์‹คํ–‰"""
round_result = {
'round_num': round_num,
'problems': {},
'training_data_files': [],
'success': False,
'error': None,
'stats': {
'total_problems': len(problem_ids),
'successful_problems': 0,
'failed_problems': 0,
'total_tasks': 0,
'tasks_by_type': {'induction': 0, 'deduction': 0, 'abduction': 0}
}
}
        try:
            # Load the VeRL config (to check the Ray settings)
            if not hasattr(self, 'verl_config') or self.verl_config is None:
                self.logger.log_info("🔧 Loading VeRL config for Ray settings")
                self._load_verl_config()
            # Make sure the Ray cluster is initialized
            if not self.ray_initialized:
                self.logger.log_info("🚀 Initializing Ray for data generation")
                self._initialize_ray_cluster()
            # ⭐ The VeRL trainer is initialized in step 5 with the actual data;
            # here we only announce the GPU sharing plan
            if round_num == 1:
                self.logger.log_info("📌 VeRL trainer will be initialized in Step 5 with actual data")
                self.logger.log_info("🔧 GPU sharing plan: VLLM (GPU 1,2) + FSDP (GPU 0,1,2,3)")
            # Point the pipeline at the current model
            self._update_pipeline_model(self.current_model_path)
            successful_problems = 0
            # Always use sequential processing (no cross-problem parallelism);
            # VLLM batch parallelism applies only within a single problem
            self.logger.log_info(f"📝 Using sequential processing for {len(problem_ids)} problems")
            self.logger.log_info("  - Multi-problem parallelization disabled")
            self.logger.log_info("  - Single-problem VLLM batch processing enabled")
            results = self._process_problems_sequential(benchmark_config, problem_ids, round_num)
            # Merge the results
            for problem_id, pipeline_result in results.items():
                round_result['problems'][problem_id] = pipeline_result
                if pipeline_result['success']:
                    successful_problems += 1
                    # Collect the AZR training data files
                    if 'azr_training_data' in pipeline_result:
                        round_result['training_data_files'].append({
                            'problem_id': problem_id,
                            'files': pipeline_result['azr_training_data']
                        })
                    # Update the statistics
                    if 'azr_data_saving' in pipeline_result['steps']:
                        total_tasks = pipeline_result['steps']['azr_data_saving']['total_tasks']
                        round_result['stats']['total_tasks'] += total_tasks
                    self.logger.log_info(f"✅ {problem_id} completed successfully")
                else:
                    self.logger.log_error(f"❌ {problem_id} failed: {pipeline_result.get('error', 'Unknown error')}")
            # Update the round statistics
            round_result['stats']['successful_problems'] = successful_problems
            round_result['stats']['failed_problems'] = len(problem_ids) - successful_problems
            round_result['success'] = successful_problems > 0
            if successful_problems == 0:
                round_result['error'] = "No problems completed successfully"
            return round_result
return round_result
except Exception as e:
round_result['error'] = str(e)
return round_result
def _initialize_pipeline(self):
"""Ray Actor๋กœ ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™” (VeRL ํŒจํ„ด)"""
if self.remote_pipeline is None:
try:
# TTRLVR ํŒŒ์ดํ”„๋ผ์ธ์šฉ config ์—…๋ฐ์ดํŠธ
ttrlvr_config = self.config
# ์‹คํ–‰ ๋ชจ๋“œ์— ๋”ฐ๋ฅธ ์—”์ง„ ์„ ํƒ
# VeRL config์—์„œ rollout name ํ™•์ธ
if hasattr(self, 'verl_config') and self.verl_config and hasattr(self.verl_config, 'actor_rollout_ref'):
rollout_name = self.verl_config.actor_rollout_ref.rollout.name
# HuggingFace rollout์ด๋ฉด HuggingFace ์‚ฌ์šฉ
use_vllm = (rollout_name == "vllm")
else:
# ๊ธฐ๋ณธ๊ฐ’: distributed๋ฉด vllm, single_gpu๋ฉด huggingface
use_vllm = (self.execution_mode == "distributed")
ttrlvr_config.use_vllm_for_data_generation = use_vllm
engine_name = "vllm" if use_vllm else "huggingface"
self.logger.log_info(f"๐Ÿ”ง TTRLVR data generation using: {engine_name} (execution_mode: {self.execution_mode})")
# Config ๋””๋ฒ„๊น… ๋กœ๊ทธ ์ถ”๊ฐ€
self.logger.log_info(f"๐Ÿ” Config debug: num_program_variations = {ttrlvr_config.num_program_variations}")
self.logger.log_info(f"๐Ÿ” Config debug: skip_task_evaluation = {getattr(ttrlvr_config, 'skip_task_evaluation', False)}")
                # Create the pipeline as a Ray actor (sized dynamically to the GPU count)
                from absolute_zero_reasoner.testtime.complete_pipeline import RemoteTestTimePipeline
                gpu_count = len(self.available_gpus)
                self.logger.log_info(f"🚀 Creating RemoteTestTimePipeline with {gpu_count} GPUs, model: {self.current_model_path}")
                # Pass the log file path to the worker via an environment variable
                import os
                if hasattr(self.logger, 'log_file_path') and self.logger.log_file_path:
                    runtime_env = {
                        "env_vars": {
                            "TTRLVR_LOG_FILE": self.logger.log_file_path,
                            "CUDA_VISIBLE_DEVICES": os.environ.get("CUDA_VISIBLE_DEVICES", "0")
                        }
                    }
                    self.logger.log_info(f"🔧 Setting TTRLVR_LOG_FILE for Ray worker: {self.logger.log_file_path}")
                else:
                    runtime_env = {
                        "env_vars": {
                            "CUDA_VISIBLE_DEVICES": os.environ.get("CUDA_VISIBLE_DEVICES", "0")
                        }
                    }
                    self.logger.log_warning("⚠️ Logger does not have log_file_path attribute, Ray worker will create its own log file")
                # Create the Ray actor according to the GPU count.
                # ⭐ num_gpus is left unset so the actor does not monopolize GPUs;
                # VLLM internally uses only the first two GPUs via CUDA_VISIBLE_DEVICES.
                gpu_count = len(self.available_gpus)
                # Assign the first two available GPUs to VLLM
                if gpu_count >= 2:
                    # read the actual GPU indices
                    cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0,1,2,3").split(',')
                    vllm_gpus = f"{cuda_devices[0]},{cuda_devices[1]}"  # first two GPUs
                else:
                    vllm_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
                # Add the VLLM GPU restriction to the runtime env.
                # Set explicitly as strings so Ray's runtime_env handles comma-separated values correctly.
                runtime_env["env_vars"]["CUDA_VISIBLE_DEVICES"] = str(vllm_gpus)
                runtime_env["env_vars"]["VLLM_USE_SPECIFIC_GPUS"] = str(vllm_gpus)
                self.logger.log_info("🎯 Creating Ray Actor without exclusive GPU allocation")
                self.logger.log_info(f"  - VLLM will use GPUs: {vllm_gpus} (via CUDA_VISIBLE_DEVICES)")
                self.logger.log_info(f"  - FSDP can use all GPUs: {os.environ.get('CUDA_VISIBLE_DEVICES', '0,1,2,3')}")
                # AZR style: create the Ray actor without explicit GPU allocation;
                # GPUs are controlled solely via CUDA_VISIBLE_DEVICES
                RemoteTestTimePipelineWithGPUs = RemoteTestTimePipeline.options(
                    num_cpus=4,  # enough CPUs for VLLM processing
                    # num_gpus intentionally unset - GPUs are controlled via CUDA_VISIBLE_DEVICES
                    runtime_env=runtime_env
                )
                self.remote_pipeline = RemoteTestTimePipelineWithGPUs.remote(
                    config=ttrlvr_config,
                    model_path=self.current_model_path
                )
                self.logger.log_info(f"🔄 RemoteTestTimePipeline initialized in {self.execution_mode} mode")
                if self.execution_mode == "distributed":
                    self.logger.log_info("  - Using VLLM distributed inference on GPU 0,1")
                    self.logger.log_info("  - FSDP can use all GPUs: 0,1,2,3 with sharing")
                    self.logger.log_info("  - Model loading handled inside Ray worker")
                else:
                    self.logger.log_info("  - Using HuggingFace single GPU inference inside Ray worker")
except Exception as e:
self.logger.log_error(f"Failed to initialize pipeline: {e}")
raise
def _update_pipeline_model(self, model_path: str):
"""ํŒŒ์ดํ”„๋ผ์ธ์˜ ๋ชจ๋ธ ๋ ˆํผ๋Ÿฐ์Šค ์—…๋ฐ์ดํŠธ (๋ฉ”๋ชจ๋ฆฌ ๋‚ด ๋™์ผ ๋ชจ๋ธ ์œ ์ง€)"""
try:
# Ray Actor ํŒŒ์ดํ”„๋ผ์ธ์ด ์—†์œผ๋ฉด ์ดˆ๊ธฐํ™”
if self.remote_pipeline is None:
self._initialize_pipeline()
# ๋ชจ๋ธ ๊ฒฝ๋กœ ์—…๋ฐ์ดํŠธ (Ray worker๋Š” ์ƒˆ๋กœ์šด ๋ชจ๋ธ ๊ฒฝ๋กœ๋กœ ์žฌ์ƒ์„ฑ)
self.current_model_path = model_path
self.logger.log_info(f"๐Ÿ”„ Pipeline model path updated to: {model_path}")
# ์ƒˆ๋กœ์šด ๋ชจ๋ธ์ด๋ฉด Ray Actor ์žฌ์ƒ์„ฑ
if hasattr(self, '_last_model_path') and self._last_model_path != model_path:
self.logger.log_info("๐Ÿ”„ Model path changed, recreating Ray Actor")
self.remote_pipeline = None
self._initialize_pipeline()
self._last_model_path = model_path
except Exception as e:
self.logger.log_warning(f"Failed to update pipeline model: {e}")
def _train_azr_with_round_data(self, training_data_files: List[Dict[str, Any]],
round_num: int) -> Optional[str]:
"""ํ•ด๋‹น ๋ผ์šด๋“œ ๋ฐ์ดํ„ฐ๋กœ ๋ฉ”๋ชจ๋ฆฌ ๋‚ด ๋ชจ๋ธ ์—…๋ฐ์ดํŠธ"""
try:
# 1. ๋ผ์šด๋“œ๋ณ„ ํ†ตํ•ฉ ๋ฐ์ดํ„ฐ ์ƒ์„ฑ
combined_data_path = self._combine_round_data(training_data_files, round_num)
if not combined_data_path:
self.logger.log_error(f"Failed to combine training data for round {round_num}")
return None
self.logger.log_info(f"๐Ÿ“Š Combined training data: {combined_data_path}")
# 2. ๋ฉ”๋ชจ๋ฆฌ ๋‚ด์—์„œ ์ง์ ‘ ๋ชจ๋ธ ์—…๋ฐ์ดํŠธ (๋””์Šคํฌ ์ €์žฅ/๋กœ๋“œ ์—†์Œ)
updated_model = self._update_model_in_memory(combined_data_path, round_num)
if updated_model:
# ๋ฉ”๋ชจ๋ฆฌ ๋‚ด ๋ชจ๋ธ ์ธ์Šคํ„ด์Šค ์—…๋ฐ์ดํŠธ
self.current_model = updated_model
# ํŒŒ์ดํ”„๋ผ์ธ ์ปดํฌ๋„ŒํŠธ๋“ค๋„ ์—…๋ฐ์ดํŠธ๋œ ๋ชจ๋ธ ์‚ฌ์šฉ
if self.complete_pipeline:
self.complete_pipeline.model = self.current_model
if hasattr(self.complete_pipeline, 'solution_generator') and self.complete_pipeline.solution_generator:
self.complete_pipeline.solution_generator.model = self.current_model
if hasattr(self.complete_pipeline, 'ipo_extractor') and self.complete_pipeline.ipo_extractor:
self.complete_pipeline.ipo_extractor.model = self.current_model
# ์ฐธ์กฐ์šฉ ๊ฒฝ๋กœ ์—…๋ฐ์ดํŠธ
virtual_model_path = f"memory://round_{round_num}_model"
self.logger.log_info(f"โœ… Model updated in memory for round {round_num}")
self.logger.log_info(f"๐ŸŽฏ Virtual model path: {virtual_model_path}")
return virtual_model_path
else:
self.logger.log_error(f"โŒ Model update failed for round {round_num}")
return None
except Exception as e:
self.logger.log_error(f"๐Ÿ’ฅ Model update execution failed: {e}")
return None
def _update_model_in_memory(self, training_data_path: str, round_num: int) -> Optional[Any]:
"""๋ฉ”๋ชจ๋ฆฌ ๋‚ด์—์„œ VeRL์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ์„ ์ง์ ‘ ์—…๋ฐ์ดํŠธ (AZR REINFORCE++ ํ•™์Šต)"""
try:
self.logger.log_info(f"๐ŸŽ“ Starting VeRL-based AZR training for round {round_num}")
self.logger.log_info(f"๐Ÿ“‚ Training data: {training_data_path}")
# ๊ฐ„๋‹จํ•œ ์ฒดํฌ: ํ•™์Šต ๋ฐ์ดํ„ฐ๊ฐ€ ์กด์žฌํ•˜๋Š”์ง€ ํ™•์ธ
task_files = ['induction.parquet', 'deduction.parquet', 'abduction.parquet']
available_files = []
for task_file in task_files:
file_path = os.path.join(training_data_path, task_file)
if os.path.exists(file_path):
available_files.append(task_file)
# ํŒŒ์ผ ํฌ๊ธฐ ํ™•์ธ
file_size = os.path.getsize(file_path)
self.logger.log_info(f" ๐Ÿ“„ {task_file}: {file_size} bytes")
if not available_files:
self.logger.log_warning("โš ๏ธ No training data files found in specified directory")
# ์‹ค์ œ ์ƒ์„ฑ๋œ ๋ฐ์ดํ„ฐ ๋””๋ ‰ํ† ๋ฆฌ ๊ฒ€์ƒ‰
self.logger.log_info("๐Ÿ” Searching for actual training data...")
actual_data_path = self._find_actual_training_data()
if actual_data_path:
self.logger.log_info(f"โœ… Found actual training data: {actual_data_path}")
training_data_path = actual_data_path
# ๋‹ค์‹œ ํŒŒ์ผ ํ™•์ธ
for task_file in task_files:
file_path = os.path.join(training_data_path, task_file)
if os.path.exists(file_path):
available_files.append(task_file)
file_size = os.path.getsize(file_path)
self.logger.log_info(f" ๐Ÿ“„ {task_file}: {file_size} bytes")
else:
self.logger.log_error("โŒ No actual training data found anywhere")
return None
self.logger.log_info(f"๐Ÿ“š Processing {len(available_files)} task types")
# โญ Step 1-4๊ฐ€ ์™„๋ฃŒ๋˜์—ˆ์œผ๋ฏ€๋กœ VLLM Ray Actor ํ•ด์ œํ•˜์—ฌ ๋ฉ”๋ชจ๋ฆฌ ํ™•๋ณด
if hasattr(self, 'remote_pipeline') and self.remote_pipeline is not None:
self.logger.log_info("๐Ÿงน Releasing VLLM Ray Actor memory before Step 5...")
try:
# ๋จผ์ € cleanup ๋ฉ”์„œ๋“œ ํ˜ธ์ถœํ•˜์—ฌ ๋‚ด๋ถ€ ๋ฆฌ์†Œ์Šค ์ •๋ฆฌ
cleanup_result = ray.get(self.remote_pipeline.cleanup.remote())
if cleanup_result:
self.logger.log_info("โœ… VLLM internal resources cleaned up")
# ๊ทธ ๋‹ค์Œ Ray Actor ์ข…๋ฃŒ
ray.kill(self.remote_pipeline)
self.remote_pipeline = None
self.logger.log_info("โœ… VLLM Ray Actor killed successfully")
# GPU ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
import torch
torch.cuda.empty_cache()
# ์ž ์‹œ ๋Œ€๊ธฐํ•˜์—ฌ ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ ์™„์ „ํžˆ ํ•ด์ œ๋˜๋„๋ก ํ•จ
import time
time.sleep(2)
# GPU ๋ฉ”๋ชจ๋ฆฌ ์ƒํƒœ ํ™•์ธ
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
memory_allocated = torch.cuda.memory_allocated(i) / 1024**3
memory_reserved = torch.cuda.memory_reserved(i) / 1024**3
self.logger.log_info(f" GPU {i}: Allocated={memory_allocated:.2f}GB, Reserved={memory_reserved:.2f}GB")
except Exception as e:
self.logger.log_warning(f"โš ๏ธ Error during VLLM cleanup: {e}")
            # Initialize the VeRL trainer (first round only)
            if self.verl_trainer is None:
                self._initialize_verl_trainer(training_data_path)
            else:
                # Only refresh the data on the existing trainer
                self._update_verl_trainer_data(training_data_path)
            if self.verl_trainer is None:
                self.logger.log_error("Failed to initialize VeRL trainer")
                return self.current_model
            # Run the in-memory VeRL training
            self.logger.log_info(f"🚀 Starting in-memory VeRL training for round {round_num}")
            # Directory for responses generated during training
            llm_responses_dir = os.path.join(os.path.dirname(training_data_path), "llm_responses")
            os.makedirs(llm_responses_dir, exist_ok=True)
            self.logger.log_info(f"📁 LLM responses will be saved to: {llm_responses_dir}")
            # Point the VeRL config at the rollout data save path
            self.verl_config.trainer.rollout_data_dir = llm_responses_dir
            # Dynamically adjust the epoch count (if requested)
            if hasattr(self, 'batch_epochs') and self.batch_epochs > 1:
                original_epochs = self.verl_config.trainer.total_epochs
                self.verl_config.trainer.total_epochs = self.batch_epochs
                self.logger.log_info(f"🔧 Adjusted epochs from {original_epochs} to {self.batch_epochs}")
            # Auto-calculate ppo_mini_batch_size as main_azr_ppo.py does (important!)
            train_batch_size = self.verl_config.data.train_batch_size
            problem_types = getattr(self.verl_config.azr, 'problem_types', ['code_i', 'code_o', 'code_f'])
            train_propose = getattr(self.verl_config.azr, 'train_propose', False)
            # keep the original value
            original_ppo_mini_batch_size = self.verl_config.actor_rollout_ref.actor.ppo_mini_batch_size
            # auto-calculation: train_batch_size * number of problem types * (2 if propose is trained)
            calculated_ppo_mini_batch_size = train_batch_size * len(problem_types) * (2 if train_propose else 1)
            self.verl_config.actor_rollout_ref.actor.ppo_mini_batch_size = calculated_ppo_mini_batch_size
            # data_len is auto-calculated too (same as main_azr_ppo.py)
            update_iteration = getattr(self.verl_config.azr.data_selection_strategy, 'update_iteration', 1)
            self.verl_config.azr.data_selection_strategy.data_len = train_batch_size * update_iteration
            self.logger.log_info(f"🔧 Auto-calculated ppo_mini_batch_size: {original_ppo_mini_batch_size} → {calculated_ppo_mini_batch_size}")
            self.logger.log_info(f"  - train_batch_size: {train_batch_size}")
            self.logger.log_info(f"  - problem_types: {len(problem_types)} ({problem_types})")
            self.logger.log_info(f"  - train_propose: {train_propose}")
            self.logger.log_info(f"🔧 Auto-calculated data_len: {self.verl_config.azr.data_selection_strategy.data_len}")
            # Run the VeRL training
            self.logger.log_info(f"🏃 Calling verl_trainer.fit() for round {round_num}")
            self.logger.log_info(f"📊 Config - total_epochs: {self.verl_config.trainer.total_epochs}")
            self.logger.log_info(f"📊 Config - train_batch_size: {self.verl_config.data.train_batch_size}")
            self.logger.log_info(f"📊 Config - total_training_steps: {self.verl_config.trainer.total_training_steps}")
            # Also update the trainer instance's config (important!)
            if hasattr(self.verl_trainer, 'config'):
                self.verl_trainer.config.trainer.rollout_data_dir = llm_responses_dir
            # The actual fit call
            fit_start = datetime.now()
            self.verl_trainer.fit()
            fit_end = datetime.now()
            fit_duration = (fit_end - fit_start).total_seconds()
            self.logger.log_info(f"⏱️ verl_trainer.fit() completed in {fit_duration:.3f} seconds")
            # Convert the JSONL files into TTRLVR format
            if os.path.exists(llm_responses_dir):
                self.logger.log_info("📝 Converting VeRL outputs to TTRLVR format...")
                jsonl_files = list(Path(llm_responses_dir).glob("*.jsonl"))
                for jsonl_file in jsonl_files:
                    self._convert_jsonl_to_ttrlvr_format(str(jsonl_file), llm_responses_dir)
                self.logger.log_info(f"✅ Converted {len(jsonl_files)} JSONL files to TTRLVR format")
            # The trained model has already been updated inside the VeRL trainer;
            # the model instance stays resident in memory
            self.logger.log_info(f"✅ Model updated successfully with REINFORCE++ for round {round_num}")
            # Save a checkpoint after training (conditional)
            if self._should_save_checkpoint(round_num):
                self._save_round_checkpoint(round_num)
            # Return the current model object (its weights were updated).
            # With VeRL the model is updated inside Ray workers, so return a symbolic reference.
            if self.current_model is None:
                self.current_model = "verl_trained_model"  # symbolic reference
            return self.current_model
except Exception as e:
self.logger.log_error(f"Failed to update model in memory: {e}")
import traceback
traceback.print_exc()
return None
def _combine_round_data(self, training_data_files: List[Dict[str, Any]],
round_num: int) -> Optional[str]:
"""๋ผ์šด๋“œ์˜ ๋ชจ๋“  ๋ฌธ์ œ ๋ฐ์ดํ„ฐ๋ฅผ task๋ณ„๋กœ ํ†ตํ•ฉ"""
try:
# ํ†ตํ•ฉ ๋ฐ์ดํ„ฐ ์ €์žฅ ๋””๋ ‰ํ† ๋ฆฌ
output_dir = f"/tmp/ttrlvr_azr_training/round_{round_num}"
os.makedirs(output_dir, exist_ok=True)
# Task ํƒ€์ž…๋ณ„ ๋ฐ์ดํ„ฐ ์ˆ˜์ง‘
combined_data = {
'induction': [],
'deduction': [],
'abduction': []
}
total_files_processed = 0
# ๊ฐ ๋ฌธ์ œ์˜ ๋ฐ์ดํ„ฐ ํŒŒ์ผ๋“ค์„ ์ˆœํšŒ
for problem_data in training_data_files:
problem_id = problem_data['problem_id']
files = problem_data['files']
self.logger.log_info(f"๐Ÿ“ Processing data for {problem_id}")
# Task ํƒ€์ž…๋ณ„ ํŒŒ์ผ ์ฒ˜๋ฆฌ
for task_type in ['induction', 'deduction', 'abduction']:
if task_type in files:
file_path = files[task_type]
if os.path.exists(file_path):
try:
df = pd.read_parquet(file_path)
task_data = df.to_dict('records')
combined_data[task_type].extend(task_data)
total_files_processed += 1
self.logger.log_info(f" โœ… {task_type}: {len(task_data)} tasks from {file_path}")
except Exception as e:
self.logger.log_warning(f" โš ๏ธ Failed to read {file_path}: {e}")
else:
self.logger.log_warning(f" โš ๏ธ File not found: {file_path}")
if total_files_processed == 0:
self.logger.log_error("No training data files found to combine")
return None
            # Save the combined data as per-task parquet files
            combined_files = {}
            total_tasks = 0
            for task_type, data in combined_data.items():
                if data:
                    # sort by ipo_group_id so batches stay grouped
                    df = pd.DataFrame(data)
                    df = df.sort_values('ipo_group_id')
                    # save the file
                    file_path = os.path.join(output_dir, f"{task_type}.parquet")
                    df.to_parquet(file_path, index=False)
                    combined_files[task_type] = file_path
                    total_tasks += len(data)
                    self.logger.log_info(f"💾 Saved {len(data)} {task_type} tasks to {file_path}")
                else:
                    self.logger.log_warning(f"No {task_type} tasks found for round {round_num}")
            # Save the statistics
stats = {
'round': round_num,
'total_tasks': total_tasks,
'tasks_by_type': {k: len(v) for k, v in combined_data.items()},
'files': combined_files,
'problems_processed': len(training_data_files),
'batch_groups': len(set(
task['ipo_group_id']
for task_data in combined_data.values()
for task in task_data
))
}
stats_file = os.path.join(output_dir, 'round_training_stats.json')
with open(stats_file, 'w') as f:
json.dump(stats, f, indent=2)
self.logger.log_info(f"๐Ÿ“Š Round {round_num} data summary:")
self.logger.log_info(f" - Total tasks: {total_tasks}")
self.logger.log_info(f" - Batch groups: {stats['batch_groups']}")
self.logger.log_info(f" - Files: {list(combined_files.keys())}")
return output_dir
except Exception as e:
self.logger.log_error(f"Failed to combine round data: {e}")
return None
def _save_checkpoint(self, round_num: int, model_path: str,
training_results: Dict[str, Any]):
"""์ฒดํฌํฌ์ธํŠธ ์ €์žฅ (๋ชจ๋ธ ์ƒํƒœ, ํ•™์Šต ํ†ต๊ณ„, ๋ผ์šด๋“œ ์ •๋ณด)"""
try:
checkpoint_path = os.path.join(self.checkpoint_dir, f"checkpoint_round_{round_num}")
os.makedirs(checkpoint_path, exist_ok=True)
# ์ฒดํฌํฌ์ธํŠธ ๋ฉ”ํƒ€๋ฐ์ดํ„ฐ
checkpoint_data = {
'round_num': round_num,
'model_path': model_path,
'timestamp': datetime.now().isoformat(),
'total_rounds': training_results.get('total_rounds', 30),
'completed_rounds': round_num,
'training_results': training_results,
'round_times': self.round_times
}
            # Save as JSON
checkpoint_file = os.path.join(checkpoint_path, 'checkpoint.json')
with open(checkpoint_file, 'w') as f:
json.dump(checkpoint_data, f, indent=2)
            # Save a summary text file
summary_file = os.path.join(checkpoint_path, 'summary.txt')
with open(summary_file, 'w') as f:
f.write(f"TTRLVR + AZR Training Checkpoint - Round {round_num}\n")
f.write("=" * 60 + "\n\n")
f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"Completed Rounds: {round_num}/{training_results.get('total_rounds', 30)}\n")
f.write(f"Current Model: {model_path}\n")
f.write(f"Total Training Time: {sum(self.round_times.values()):.1f} seconds\n\n")
                # Per-round statistics
f.write("Round Statistics:\n")
f.write("-" * 20 + "\n")
for r_num, r_time in self.round_times.items():
if r_num <= round_num:
round_result = training_results['rounds'].get(r_num, {})
                        success = "✅" if round_result.get('success', False) else "❌"
f.write(f"Round {r_num:2d}: {success} ({r_time:.1f}s)\n")
self.logger.log_info(f"๐Ÿ’พ Checkpoint saved: {checkpoint_path}")
except Exception as e:
self.logger.log_error(f"Failed to save checkpoint: {e}")
def _load_checkpoint(self, round_num: int) -> Optional[str]:
"""์ฒดํฌํฌ์ธํŠธ์—์„œ ๋ชจ๋ธ ๋กœ๋“œ"""
try:
checkpoint_path = os.path.join(self.checkpoint_dir, f"checkpoint_round_{round_num}")
checkpoint_file = os.path.join(checkpoint_path, 'checkpoint.json')
if not os.path.exists(checkpoint_file):
self.logger.log_warning(f"Checkpoint not found: {checkpoint_file}")
return None
with open(checkpoint_file, 'r') as f:
checkpoint_data = json.load(f)
model_path = checkpoint_data.get('model_path')
if model_path and os.path.exists(model_path):
self.logger.log_info(f"๐Ÿ“‚ Loaded checkpoint from round {round_num}")
self.logger.log_info(f"๐Ÿค– Model: {model_path}")
# ์ด์ „ ๋ผ์šด๋“œ ์‹œ๊ฐ„ ๋ณต์›
if 'round_times' in checkpoint_data:
self.round_times.update(checkpoint_data['round_times'])
return model_path
else:
self.logger.log_warning(f"Model path in checkpoint does not exist: {model_path}")
return None
except Exception as e:
self.logger.log_error(f"Failed to load checkpoint: {e}")
return None
def _log_round_summary(self, round_num: int, round_result: Dict[str, Any],
duration: float):
"""๋ผ์šด๋“œ ์™„๋ฃŒ ์š”์•ฝ ๋กœ๊ทธ"""
stats = round_result.get('stats', {})
self.logger.log_info(f"")
self.logger.log_info(f"๐Ÿ“Š ROUND {round_num} SUMMARY")
self.logger.log_info(f"" + "="*50)
self.logger.log_info(f"โฑ๏ธ Duration: {duration:.1f} seconds")
self.logger.log_info(f"๐Ÿ“ Problems: {stats.get('successful_problems', 0)}/{stats.get('total_problems', 0)} successful")
self.logger.log_info(f"๐ŸŽฏ Total tasks: {stats.get('total_tasks', 0)}")
tasks_by_type = stats.get('tasks_by_type', {})
for task_type, count in tasks_by_type.items():
if count > 0:
self.logger.log_info(f" - {task_type}: {count}")
if round_result.get('success'):
self.logger.log_info(f"โœ… Round {round_num} completed successfully")
else:
self.logger.log_info(f"โŒ Round {round_num} failed")
self.logger.log_info(f"")
def _log_final_summary(self, training_results: Dict[str, Any]):
"""์ „์ฒด ํ•™์Šต ์™„๋ฃŒ ์š”์•ฝ ๋กœ๊ทธ"""
total_duration = training_results.get('total_duration_seconds', 0)
total_rounds = training_results.get('total_rounds', 0)
completed_rounds = len(training_results.get('rounds', {}))
self.logger.log_info(f"")
self.logger.log_info(f"๐ŸŽ‰ TTRLVR + AZR TRAINING COMPLETED")
self.logger.log_info(f"" + "="*60)
self.logger.log_info(f"โฑ๏ธ Total Duration: {total_duration:.1f} seconds ({total_duration/3600:.1f} hours)")
self.logger.log_info(f"๐Ÿ”„ Completed Rounds: {completed_rounds}/{total_rounds}")
self.logger.log_info(f"๐Ÿค– Final Model: {training_results.get('final_model', 'N/A')}")
# ๋ผ์šด๋“œ๋ณ„ ์„ฑ๊ณต/์‹คํŒจ ํ†ต๊ณ„
successful_rounds = 0
failed_rounds = 0
for round_result in training_results['rounds'].values():
if round_result.get('success'):
successful_rounds += 1
else:
failed_rounds += 1
self.logger.log_info(f"๐Ÿ“Š Round Statistics:")
self.logger.log_info(f" - Successful: {successful_rounds}")
self.logger.log_info(f" - Failed: {failed_rounds}")
self.logger.log_info(f" - Success Rate: {successful_rounds/completed_rounds*100:.1f}%")
# ํ‰๊ท  ๋ผ์šด๋“œ ์‹œ๊ฐ„
if self.round_times:
avg_round_time = sum(self.round_times.values()) / len(self.round_times)
self.logger.log_info(f"โŒ› Average Round Time: {avg_round_time:.1f} seconds")
self.logger.log_info(f"")
self.logger.log_info(f"๐Ÿ’พ All results saved to: {self.checkpoint_dir}")
self.logger.log_info(f"๐ŸŽฏ Training completed successfully!")
self.logger.log_info(f"")
def _initialize_verl_trainer(self, training_data_path: str):
"""์ฒซ ๋ฒˆ์งธ ๋ผ์šด๋“œ์—์„œ VeRL trainer ๋ฐ Ray ํด๋Ÿฌ์Šคํ„ฐ ์ดˆ๊ธฐํ™”"""
try:
self.logger.log_info("๐Ÿš€ Initializing VeRL trainer for AZR training")
# Ray ์ดˆ๊ธฐํ™” (์ „์ฒด ์„ธ์…˜์—์„œ ํ•œ ๋ฒˆ๋งŒ)
if not self.ray_initialized:
self.logger.log_info("๐Ÿš€ Initializing Ray cluster for first time")
self._initialize_ray_cluster()
else:
self.logger.log_info("โ™ป๏ธ Using existing Ray cluster")
# VeRL config ๋กœ๋“œ (์•„์ง ๋กœ๋“œ๋˜์ง€ ์•Š์€ ๊ฒฝ์šฐ)
if not hasattr(self, 'verl_config') or self.verl_config is None:
self._load_verl_config()
# ๋ฐ์ดํ„ฐ ๊ฒฝ๋กœ ๋™์  ์„ค์ • (parquet ํŒŒ์ผ ๋ฆฌ์ŠคํŠธ)
train_files = [
os.path.join(training_data_path, "induction.parquet"),
os.path.join(training_data_path, "deduction.parquet"),
os.path.join(training_data_path, "abduction.parquet")
]
# ์กด์žฌํ•˜๋Š” ํŒŒ์ผ๋งŒ ์„ ํƒ
valid_train_files = [f for f in train_files if os.path.exists(f)]
self.verl_config.data.train_files = valid_train_files
self.verl_config.data.val_files = valid_train_files
# ์ฒดํฌํฌ์ธํŠธ ๋น„ํ™œ์„ฑํ™”๋กœ ์ธํ•ด ๊ณ ์œ  ๋””๋ ‰ํ† ๋ฆฌ ์„ค์ • ๋ถˆํ•„์š”
# VeRL์ด ๊ธฐ์กด ๋ชจ๋ธ ์ธ์Šคํ„ด์Šค๋ฅผ ์‚ฌ์šฉํ•˜๋„๋ก ์„ค์ • (๋ฉ”๋ชจ๋ฆฌ ์ ˆ์•ฝ)
# Config์—๋Š” ์›๋ณธ ๊ฒฝ๋กœ ์œ ์ง€ (tokenizer ๋กœ๋“œ์šฉ)
self.verl_config.actor_rollout_ref.model.path = self.original_model_name
self.logger.log_info(f"๐Ÿ”ง VeRL config set to original model path: {self.original_model_name}")
# TTRLVR ๋ฐ์ดํ„ฐ ์ƒ์„ฑ์—์„œ ์‚ฌ์šฉํ•  ์—”์ง„ ์„ค์ • ์ถ”๊ฐ€
inference_engine = getattr(self.verl_config.data, 'ttrlvr_inference_engine', 'vllm')
self.logger.log_info(f"๐Ÿ”ง TTRLVR inference engine: {inference_engine}")
self.logger.log_info(f"๐Ÿ“ VeRL config loaded successfully")
self.logger.log_info(f"๐Ÿ“‚ Training data files: {len(self.verl_config.data.train_files)}")
            # Prepare the components required to build the VeRL trainer
            from transformers import AutoTokenizer
            from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
            # Import the worker classes (same as main_azr_ppo.py)
            import ray
            # Route VeRL logging into the TTRLVR log
            import logging
            # the VeRL-related loggers to configure
            verl_loggers = [
                "verl",
                "verl.workers",
                "verl.trainer",
                "verl.workers.fsdp_workers",
                "verl.workers.sharding_manager",
                "absolute_zero_reasoner.trainer.ppo"
            ]
            # Append VeRL logs to the TTRLVR log file
            if hasattr(self.logger, 'log_file_path') and self.logger.log_file_path:
                file_handler = logging.FileHandler(self.logger.log_file_path)
                file_handler.setFormatter(logging.Formatter('[VeRL] %(asctime)s - %(name)s - %(levelname)s - %(message)s'))
                for logger_name in verl_loggers:
                    verl_logger = logging.getLogger(logger_name)
                    verl_logger.setLevel(logging.INFO)
                    verl_logger.addHandler(file_handler)
            # Check the strategy
            strategy = self.verl_config.actor_rollout_ref.actor.strategy
            self.logger.log_info(f"🔧 Actor strategy: {strategy}")
            if strategy in ["fsdp", "fsdp2"]:
                from verl.single_controller.ray import RayWorkerGroup
                from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
                # Select the VeRL worker (like AZR, a fresh vLLM is created each time)
                actor_rollout_cls = AsyncActorRolloutRefWorker if self.verl_config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
                ray_worker_group_cls = RayWorkerGroup
            elif strategy == "none":
                # Single-GPU environment - use the FSDP workers with FSDP itself disabled
                self.logger.log_info("🔧 Using single GPU configuration (FSDP workers without FSDP)")
                from verl.single_controller.ray import RayWorkerGroup
                from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
                # FSDP workers are used on a single GPU too (FSDP is disabled internally)
                actor_rollout_cls = AsyncActorRolloutRefWorker if self.verl_config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
                ray_worker_group_cls = RayWorkerGroup
            else:
                raise NotImplementedError(f"Strategy '{strategy}' not supported. Supported: fsdp, fsdp2, none")
            # Initialize the tokenizer (using the original model path)
            if self.current_model_path.startswith('memory://'):
                # for virtual paths, fall back to the original model path
                tokenizer_path = self.original_model_name
                self.logger.log_info(f"🔧 Using original model path for tokenizer: {self.original_model_name}")
            else:
                tokenizer_path = self.current_model_path
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            # Role mapping (wrapped with ray.remote, same as main_azr_ppo.py)
            role_worker_mapping = {
                Role.ActorRollout: ray.remote(actor_rollout_cls),
                Role.Critic: ray.remote(CriticWorker),
            }
            # Initialize the resource pool manager (same API as main_azr_ppo.py);
            # a single global pool spans all GPUs on every node
            global_pool_id = "global_pool"
            resource_pool_spec = {
                global_pool_id: [self.verl_config.trainer.n_gpus_per_node] * self.verl_config.trainer.nnodes,
            }
            mapping = {
                Role.ActorRollout: global_pool_id,
                Role.Critic: global_pool_id,
            }
            resource_pool_manager = ResourcePoolManager(
                resource_pool_spec=resource_pool_spec,
                mapping=mapping
            )
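            # e.g. (illustrative values): n_gpus_per_node=4 and nnodes=1 yield
            # resource_pool_spec == {"global_pool": [4]}, i.e. one pool of 4 GPUs.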
# โญ ํ•ต์‹ฌ: VeRL trainer ์ƒ์„ฑ ์ „์— total_training_steps ๋ฏธ๋ฆฌ ๊ณ„์‚ฐ
self.logger.log_info("๐Ÿ”ข Pre-calculating total_training_steps before trainer creation")
# ๋ฐ์ดํ„ฐ๋กœ๋” ํฌ๊ธฐ ์˜ˆ์ƒ ๊ณ„์‚ฐ (parquet ํŒŒ์ผ ๊ธฐ๋ฐ˜)
import pandas as pd
task_files = ['induction.parquet', 'deduction.parquet', 'abduction.parquet']
task_dataloader_sizes = []
for task_file in task_files:
file_path = os.path.join(training_data_path, task_file)
if os.path.exists(file_path):
df = pd.read_parquet(file_path)
train_batch_size = self.verl_config.data.train_batch_size
task_dataloader_size = (len(df) + train_batch_size - 1) // train_batch_size
task_dataloader_sizes.append(task_dataloader_size)
self.logger.log_info(f" ๐Ÿ“„ {task_file}: {len(df)} samples โ†’ {task_dataloader_size} batches")
# TTRLVR์—์„œ๋Š” ๋ชจ๋“  task์—์„œ ๋™์‹œ์— ๋ฐฐ์น˜๋ฅผ ๊ฐ€์ ธ์™€์•ผ ํ•˜๋ฏ€๋กœ ์ตœ์†Œ๊ฐ’ ์‚ฌ์šฉ
if task_dataloader_sizes:
estimated_dataloader_size = min(task_dataloader_sizes)
estimated_total_training_steps = estimated_dataloader_size * self.verl_config.trainer.total_epochs
self.logger.log_info(f" ๐Ÿ”ข Min dataloader size: {estimated_dataloader_size}, Total steps: {estimated_total_training_steps}")
else:
estimated_total_training_steps = 100 # fallback
self.logger.log_info(f"๐Ÿ“Š Pre-calculated training steps:")
self.logger.log_info(f" - Task dataloader sizes: {task_dataloader_sizes}")
self.logger.log_info(f" - Min dataloader size: {estimated_dataloader_size}")
self.logger.log_info(f" - Total epochs: {self.verl_config.trainer.total_epochs}")
self.logger.log_info(f" - Estimated total_training_steps: {estimated_total_training_steps}")
            # Inject into the VeRL config ahead of time (before trainer creation!)
            from omegaconf import OmegaConf, open_dict
            OmegaConf.set_struct(self.verl_config, True)
            with open_dict(self.verl_config):
                # inject into the actor optim
                self.verl_config.actor_rollout_ref.actor.optim.total_training_steps = estimated_total_training_steps
                # also inject at the trainer level (VeRL reads this value)
                self.verl_config.trainer.total_training_steps = estimated_total_training_steps
                # inject for the critic when it is used
                if hasattr(self.verl_config, 'critic') and self.verl_config.critic.get('include_critic', False):
                    self.verl_config.critic.optim.total_training_steps = estimated_total_training_steps
            self.logger.log_info(f"✅ Injected total_training_steps={estimated_total_training_steps} into config before trainer creation")
            # Verify the injected values
            actor_value = OmegaConf.select(self.verl_config, "actor_rollout_ref.actor.optim.total_training_steps")
            trainer_value = OmegaConf.select(self.verl_config, "trainer.total_training_steps")
            self.logger.log_info(f"🔍 Verification: actor.optim.total_training_steps = {actor_value}")
            self.logger.log_info(f"🔍 Verification: trainer.total_training_steps = {trainer_value}")
            # Create the VeRL trainer (same approach as main_azr_ppo.py)
            self.logger.log_info("🚀 Creating new VLLM for VeRL (AZR pattern)")
            self.verl_trainer = CodeIORayPPOTrainer(
                config=self.verl_config,
                tokenizer=tokenizer,
                role_worker_mapping=role_worker_mapping,
                resource_pool_manager=resource_pool_manager,
                ray_worker_group_cls=ray_worker_group_cls
            )
            # ⭐ Key step: initialize the workers (same as main_azr_ppo.py)
            self.logger.log_info("🔧 Initializing VeRL workers...")
            self.verl_trainer.init_workers()
            self.logger.log_info("✅ VeRL workers initialized")
# โญ ๊ฒ€์ฆ: ์‹ค์ œ dataloader ํฌ๊ธฐ์™€ ๋น„๊ต
self.logger.log_info(f"๐Ÿ” Verifying dataloader after trainer creation:")
if hasattr(self.verl_trainer, 'train_dataloader'):
actual_dataloader_size = len(self.verl_trainer.train_dataloader) if self.verl_trainer.train_dataloader else 0
self.logger.log_info(f" - Actual dataloader size: {actual_dataloader_size}")
self.logger.log_info(f" - Estimated dataloader size: {estimated_dataloader_size}")
if actual_dataloader_size != estimated_dataloader_size:
self.logger.log_warning(f"โš ๏ธ Dataloader size mismatch! Estimated: {estimated_dataloader_size}, Actual: {actual_dataloader_size}")
else:
self.logger.log_info("โœ… Dataloader size estimation was correct")
else:
self.logger.log_warning("โš ๏ธ No train_dataloader found after trainer creation")
# โญ ํ•ต์‹ฌ: VeRL trainer์˜ ๋ชจ๋ธ์„ ๊ธฐ์กด ์ธ์Šคํ„ด์Šค๋กœ ๊ต์ฒด (๋ฉ”๋ชจ๋ฆฌ ์ ˆ์•ฝ)
self._replace_verl_model_with_existing_instance()
self.logger.log_info("โœ… VeRL trainer initialized successfully")
except Exception as e:
self.logger.log_error(f"Failed to initialize VeRL trainer: {e}")
import traceback
traceback.print_exc()
self.verl_trainer = None
def _replace_verl_model_with_existing_instance(self):
"""VeRL trainer์˜ ๋ชจ๋ธ์„ ๊ธฐ์กด ์ธ์Šคํ„ด์Šค๋กœ ๊ต์ฒดํ•˜์—ฌ ๋ฉ”๋ชจ๋ฆฌ ์ ˆ์•ฝ"""
try:
self.logger.log_info("๐Ÿ”„ Replacing VeRL models with existing instance for memory efficiency")
# Actor ๋ชจ๋ธ ๊ต์ฒด
if hasattr(self.verl_trainer, 'actor_rollout_ref'):
if hasattr(self.verl_trainer.actor_rollout_ref, 'actor'):
if hasattr(self.verl_trainer.actor_rollout_ref.actor, 'model'):
# ๊ธฐ์กด VeRL ๋ชจ๋ธ ์‚ญ์ œ (๋ฉ”๋ชจ๋ฆฌ ํ•ด์ œ)
del self.verl_trainer.actor_rollout_ref.actor.model
# ๊ธฐ์กด ์ธ์Šคํ„ด์Šค๋กœ ๊ต์ฒด
self.verl_trainer.actor_rollout_ref.actor.model = self.current_model
self.logger.log_info("โœ… Actor model replaced with existing instance")
# Rollout ๋ชจ๋ธ๋„ ๋™์ผํ•˜๊ฒŒ ๊ต์ฒด (ํ•„์š”์‹œ)
if hasattr(self.verl_trainer.actor_rollout_ref, 'rollout'):
if hasattr(self.verl_trainer.actor_rollout_ref.rollout, 'llm'):
# VLLM ์—”์ง„์ด ์žˆ๋Š” ๊ฒฝ์šฐ ๊ธฐ์กด ๋ชจ๋ธ๊ณผ ์—ฐ๊ฒฐ
self.logger.log_info("๐Ÿ”ง Rollout engine detected - using existing model weights")
# ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
current_memory = torch.cuda.memory_allocated() / 1024**3
self.logger.log_info(f"๐Ÿ“Š GPU memory after model replacement: {current_memory:.1f}GB")
self.logger.log_info("๐ŸŽฏ Single model instance now used across all steps (1-5)!")
except Exception as e:
self.logger.log_warning(f"Model replacement failed, using default VeRL behavior: {e}")
            # Even on failure VeRL falls back to its own model, so continue
def _update_verl_trainer_data(self, training_data_path: str):
"""๊ธฐ์กด VeRL trainer์—์„œ ๋ฐ์ดํ„ฐ ๊ฒฝ๋กœ๋งŒ ์—…๋ฐ์ดํŠธ"""
try:
self.logger.log_info("๐Ÿ”„ Updating VeRL trainer data for new round")
# ๋ฐ์ดํ„ฐ ๊ฒฝ๋กœ ์—…๋ฐ์ดํŠธ
new_train_files = [
os.path.join(training_data_path, "induction.parquet"),
os.path.join(training_data_path, "deduction.parquet"),
os.path.join(training_data_path, "abduction.parquet")
]
# ์กด์žฌํ•˜๋Š” ํŒŒ์ผ๋งŒ ์„ ํƒ
valid_files = [f for f in new_train_files if os.path.exists(f)]
if not valid_files:
self.logger.log_warning("โš ๏ธ No valid training files found for update")
return
# Config ์—…๋ฐ์ดํŠธ
self.verl_config.data.train_files = valid_files
self.verl_config.data.val_files = valid_files
# Trainer์˜ ๋ฐ์ดํ„ฐ ๋กœ๋” ์—…๋ฐ์ดํŠธ
if hasattr(self.verl_trainer, 'update_data_files'):
self.verl_trainer.update_data_files(valid_files)
else:
# Trainer ๋‚ด๋ถ€์˜ config ์—…๋ฐ์ดํŠธ
self.verl_trainer.config.data.train_files = valid_files
self.verl_trainer.config.data.val_files = valid_files
self.logger.log_info(f"โœ… Updated training data: {len(valid_files)} files")
# โญ ์ค‘์š”: ๋ฐ์ดํ„ฐ๊ฐ€ ๋ณ€๊ฒฝ๋˜์—ˆ์œผ๋ฏ€๋กœ worker๋ฅผ ์žฌ์ดˆ๊ธฐํ™”ํ•ด์•ผ ํ•จ
self.logger.log_info("๐Ÿ”ง Re-initializing VeRL workers with new data...")
self.verl_trainer.init_workers()
self.logger.log_info("โœ… VeRL workers re-initialized with actual training data")
except Exception as e:
self.logger.log_error(f"Failed to update VeRL trainer data: {e}")
import traceback
traceback.print_exc()
def _load_verl_config(self):
"""VeRL config ๋กœ๋“œ - ๊ธฐ์กด YAML ํŒŒ์ผ ์‚ฌ์šฉ"""
try:
# VeRL config ํŒŒ์ผ ๊ฒฝ๋กœ ์„ค์ • (์‹คํ–‰ ๋ชจ๋“œ์— ๋”ฐ๋ผ ์ž๋™ ์„ ํƒ)
if self.verl_config_path:
config_path = self.verl_config_path
else:
config_path = self._get_default_config_path()
self.logger.log_info(f"๐Ÿ”ง Loading VeRL config from: {config_path}")
self.logger.log_info(f"๐Ÿ”ง Config selected for {self.execution_mode} mode")
from omegaconf import OmegaConf
import os
if not os.path.exists(config_path):
raise FileNotFoundError(f"Config file not found: {config_path}")
# YAML ํŒŒ์ผ ๋กœ๋“œ
self.verl_config = OmegaConf.load(config_path)
# ๋ชจ๋ธ ๊ฒฝ๋กœ๋ฅผ ์‹ค์ œ HuggingFace ๊ฒฝ๋กœ๋กœ ์„ค์ •
if hasattr(self, 'current_model_path') and self.current_model_path:
if self.current_model_path.startswith('memory://'):
# ๊ฐ€์ƒ ๊ฒฝ๋กœ์ธ ๊ฒฝ์šฐ ์›๋ณธ ๋ชจ๋ธ ๊ฒฝ๋กœ ์‚ฌ์šฉ
model_path_for_verl = self.original_model_name
self.logger.log_info(f"๐Ÿ”ง Using original model path for VeRL: {model_path_for_verl}")
else:
model_path_for_verl = self.current_model_path
self.verl_config.actor_rollout_ref.model.path = model_path_for_verl
self.logger.log_info(f"๐Ÿ”ง Updated VeRL model path to: {model_path_for_verl}")
self.logger.log_info("โœ… VeRL config loaded successfully from YAML")
self.logger.log_info(f" - TTRLVR Ray parallel processing: {self.verl_config.data.ttrlvr_ray_config.parallel_processing}")
self.logger.log_info(f" - TTRLVR inference engine: {self.verl_config.data.ttrlvr_inference_engine}")
except Exception as e:
self.logger.log_error(f"Config loading failed: {e}")
self.verl_config = None
def _detect_available_gpus(self):
"""์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ GPU ๋ฆฌ์ŠคํŠธ ๊ฐ์ง€"""
import os
cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0')
if cuda_visible_devices:
return [int(gpu.strip()) for gpu in cuda_visible_devices.split(',') if gpu.strip()]
else:
return [0] # ๊ธฐ๋ณธ๊ฐ’
def _determine_execution_mode(self):
"""GPU ๊ฐœ์ˆ˜์— ๋”ฐ๋ฅธ ์‹คํ–‰ ๋ชจ๋“œ ๊ฒฐ์ •"""
num_gpus = len(self.available_gpus)
if num_gpus == 1:
return "single_gpu" # ์ „์ฒด ํŒŒ์ดํ”„๋ผ์ธ ๋‹จ์ผ GPU
else:
return "distributed" # ์ „์ฒด ํŒŒ์ดํ”„๋ผ์ธ ๋ถ„์‚ฐ GPU
def _get_default_config_path(self):
"""์‹คํ–‰ ๋ชจ๋“œ์— ๋”ฐ๋ฅธ ๊ธฐ๋ณธ config ํŒŒ์ผ ๊ฒฝ๋กœ ๋ฐ˜ํ™˜"""
base_path = "/home/ubuntu/RLVR/TestTime-RLVR-v2/test/configs"
if self.execution_mode == "single_gpu":
return f"{base_path}/ttrlvr_azr_ppo.yaml" # ๋‹จ์ผ GPU config
else:
return f"{base_path}/ttrlvr_azr_ppo_4gpu.yaml" # ๋‹ค์ค‘ GPU config
def _should_save_checkpoint(self, round_num: int) -> bool:
"""์ฒดํฌํฌ์ธํŠธ ์ €์žฅ ์—ฌ๋ถ€ ๊ฒฐ์ •"""
if self.save_every_round:
return True
if round_num % self.save_round_interval == 0:
return True
return False
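    # Checkpoint cadence example (illustrative): with save_every_round=False
    # and the default save_round_interval=5, _should_save_checkpoint is True
    # on rounds 5, 10, 15, ... and False otherwise.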
def _convert_jsonl_to_ttrlvr_format(self, jsonl_path: str, output_dir: str):
"""VeRL์˜ JSONL ์ถœ๋ ฅ์„ TTRLVR ํ˜•์‹์˜ ๊ฐœ๋ณ„ ํ…์ŠคํŠธ ํŒŒ์ผ๋กœ ๋ณ€ํ™˜"""
import json
try:
with open(jsonl_path, 'r') as f:
data = json.load(f)
# ๊ฐ ์ƒ˜ํ”Œ์— ๋Œ€ํ•ด ๊ฐœ๋ณ„ ํŒŒ์ผ ์ƒ์„ฑ
for i in range(len(data.get('input', []))):
prompt = data['input'][i] if 'input' in data else ""
response = data['output'][i] if 'output' in data else ""
score = data['score'][i] if 'score' in data else 0.0
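                # Assumed record structure (illustrative; inferred from the keys
                # used here, not from VeRL's documentation): a dict of parallel
                # lists, e.g. {"input": [...], "output": [...], "score": [...], "step": [...]}.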
                # Extract the task type from the prompt (induction/deduction/abduction)
                task_type = "unknown"
                if "induction" in prompt.lower() or "input/output pairs" in prompt:
                    task_type = "induction"
                elif "deduction" in prompt.lower() or "observed output" in prompt:
                    task_type = "deduction"
                elif "abduction" in prompt.lower() or "which input produces" in prompt:
                    task_type = "abduction"
                # create a subdirectory per task type
                task_dir = os.path.join(output_dir, task_type)
                os.makedirs(task_dir, exist_ok=True)
                # build the filename
                filename = f"verl_training_{task_type}_{self.response_counter}_response.txt"
                filepath = os.path.join(task_dir, filename)
                # Save in TTRLVR format
                with open(filepath, 'w') as f:
                    f.write(f"Task Type: {task_type}\n")
                    step_value = data['step'][i] if 'step' in data and i < len(data['step']) else 0
                    f.write(f"Task ID: verl_step_{step_value}_{i}\n")
f.write(f"Generated: {datetime.now().strftime('%Y%m%d_%H%M%S')}\n")
f.write("="*80 + "\n")
f.write("ORIGINAL PROMPT:\n")
f.write("="*80 + "\n")
f.write(prompt + "\n")
f.write("="*80 + "\n")
f.write("LLM RESPONSE:\n")
f.write("="*80 + "\n")
f.write(response + "\n")
f.write("="*80 + "\n")
f.write("REWARD SCORE:\n")
f.write("="*80 + "\n")
f.write(f"Score: {score:.3f}\n")
# Include any extra per-sample fields if present
for key in data.keys():
if key not in ['input', 'output', 'score', 'step'] and isinstance(data[key], list):
f.write("="*80 + "\n")
f.write(f"{key.upper()}:\n")
f.write("="*80 + "\n")
f.write(f"{data[key][i] if i < len(data[key]) else 'N/A'}\n")
self.response_counter += 1
except Exception as e:
self.logger.log_error(f"Failed to convert JSONL to TTRLVR format: {e}")
def _save_round_checkpoint(self, round_num: int):
"""๋งค ๋ผ์šด๋“œ๋งˆ๋‹ค VeRL ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ"""
try:
if hasattr(self, 'verl_trainer') and self.verl_trainer:
# VeRL trainer์˜ ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ ๋ฉ”์„œ๋“œ ํ˜ธ์ถœ
checkpoint_path = f"checkpoint_round_{round_num}"
# VeRL trainer ์„ค์ •์—์„œ ์ €์žฅ ๊ฒฝ๋กœ ์—…๋ฐ์ดํŠธ
original_dir = self.verl_trainer.config.trainer.default_local_dir
round_checkpoint_dir = f"{original_dir}/{checkpoint_path}"
# ์ž„์‹œ๋กœ ์ €์žฅ ๊ฒฝ๋กœ ๋ณ€๊ฒฝ
self.verl_trainer.config.trainer.default_local_dir = round_checkpoint_dir
# ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ
self.verl_trainer._save_checkpoint()
# ์›๋ž˜ ๊ฒฝ๋กœ ๋ณต์›
self.verl_trainer.config.trainer.default_local_dir = original_dir
self.logger.log_info(f"๐Ÿ’พ Round {round_num} checkpoint saved to: {round_checkpoint_dir}")
return round_checkpoint_dir
else:
self.logger.log_warning("โš ๏ธ VeRL trainer not available for checkpoint saving")
return None
except Exception as e:
self.logger.log_error(f"Failed to save round {round_num} checkpoint: {e}")
import traceback
traceback.print_exc()
return None
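# Resulting layout sketch (directory names follow the code above; what lands
# inside each round directory depends on VeRL's _save_checkpoint and is an
# assumption here):
#
#     <default_local_dir>/
#         checkpoint_round_1/...
#         checkpoint_round_2/...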
def _initialize_ray_cluster(self):
"""Ray ํด๋Ÿฌ์Šคํ„ฐ ์ดˆ๊ธฐํ™” (์ „์ฒด ์„ธ์…˜์—์„œ ํ•œ ๋ฒˆ๋งŒ)"""
try:
import ray
import os
# Ray๊ฐ€ ์ด๋ฏธ ์ดˆ๊ธฐํ™”๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์ธ
if ray.is_initialized():
self.logger.log_info("โš ๏ธ Ray already initialized, using existing cluster")
self.ray_initialized = True
return
self.logger.log_info("๐Ÿš€ Initializing Ray cluster with all GPUs for shared usage")
# GPU ์„ค์ •
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
self.logger.log_info(f"๐ŸŽฏ Ray initialization with CUDA_VISIBLE_DEVICES: {cuda_visible_devices}")
# GPU ๊ฐœ์ˆ˜ ํ™•์ธ
available_gpus = cuda_visible_devices.split(',') if cuda_visible_devices else ['0']
self.logger.log_info(f"๐ŸŽฏ Available GPUs: {available_gpus} (count: {len(available_gpus)})")
# VeRL config์—์„œ Ray ์„ค์ • ๊ฐ€์ ธ์˜ค๊ธฐ
ray_config = getattr(self.verl_config, 'ray_init', None) if hasattr(self, 'verl_config') and self.verl_config else None
# Ray ์ดˆ๊ธฐํ™” - AZR ๋ฐฉ์‹๋Œ€๋กœ GPU ๊ฐœ์ˆ˜๋ฅผ ๋ช…์‹œํ•˜์ง€ ์•Š์Œ
# Ray๊ฐ€ GPU๋ฅผ ์ง์ ‘ ๊ด€๋ฆฌํ•˜์ง€ ์•Š๊ณ  CUDA_VISIBLE_DEVICES๋กœ ์ œ์–ด
ray.init(
runtime_env={"env_vars": {
"TOKENIZERS_PARALLELISM": "true",
"NCCL_DEBUG": "WARN",
"VLLM_LOGGING_LEVEL": "WARN",
"VERL_LOGGING_LEVEL": "INFO", # VeRL ๋กœ๊น… ๋ ˆ๋ฒจ ์„ค์ •
"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true",
"CUDA_VISIBLE_DEVICES": cuda_visible_devices
}},
num_cpus=ray_config.num_cpus if ray_config else 16, # AZR config์™€ ๋™์ผ
# num_gpus ์„ค์ •ํ•˜์ง€ ์•Š์Œ - AZR ๋ฐฉ์‹
ignore_reinit_error=True # ์žฌ์ดˆ๊ธฐํ™” ์—๋Ÿฌ ๋ฌด์‹œ
)
self.ray_initialized = True
self.logger.log_info("โœ… Ray cluster initialized successfully")
self.logger.log_info(f" - GPUs available via CUDA: {cuda_visible_devices}")
self.logger.log_info(f" - CPUs: {ray_config.num_cpus if ray_config else 16}")
self.logger.log_info(" - GPU management: CUDA_VISIBLE_DEVICES (not Ray)")
self.logger.log_info(" - GPU sharing enabled: VLLM (GPU 0,1) + FSDP (GPU 0,1,2,3)")
except Exception as e:
self.logger.log_error(f"Failed to initialize Ray cluster: {e}")
import traceback
traceback.print_exc()
raise
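# Optional sanity check after initialization (a sketch using the public Ray
# API; the exact resource keys depend on what Ray auto-detected on this host):
#
#     import ray
#     print(ray.cluster_resources())    # e.g. {'CPU': 16.0, 'GPU': 4.0, ...}
#     print(ray.available_resources())  # whatever actors have not yet claimed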
def _process_problems_sequential(self, benchmark_config: BenchmarkConfig,
problem_ids: List[str], round_num: int) -> Dict[str, Dict]:
"""์ˆœ์ฐจ์ ์œผ๋กœ ๋ฌธ์ œ๋“ค์„ ์ฒ˜๋ฆฌ (๊ธฐ์กด ๋ฐฉ์‹)"""
results = {}
for i, problem_id in enumerate(problem_ids):
self.logger.log_info(f"๐Ÿ“ Processing problem {i+1}/{len(problem_ids)}: {problem_id}")
try:
# Ray Actor ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™” (ํ•„์š”์‹œ)
if self.remote_pipeline is None:
self._initialize_pipeline()
# Ray Actor์—์„œ ์™„์ „ํ•œ ํŒŒ์ดํ”„๋ผ์ธ ์‹คํ–‰ (์›๊ฒฉ ํ˜ธ์ถœ)
pipeline_result = ray.get(self.remote_pipeline.run_complete_pipeline.remote(
benchmark_config, problem_id, round_num, self.session_timestamp
))
results[problem_id] = pipeline_result
except Exception as e:
self.logger.log_error(f"๐Ÿ’ฅ Failed to process {problem_id}: {e}")
results[problem_id] = {
'success': False,
'error': str(e)
}
return results
def _process_problems_parallel(self, benchmark_config: BenchmarkConfig,
problem_ids: List[str], round_num: int) -> Dict[str, Dict]:
"""[DEPRECATED] Ray๋ฅผ ์‚ฌ์šฉํ•œ ๋ณ‘๋ ฌ ๋ฌธ์ œ ์ฒ˜๋ฆฌ - ํ˜„์žฌ ์‚ฌ์šฉํ•˜์ง€ ์•Š์Œ
Note: ๋ฌธ์ œ ๊ฐ„ ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ๋Š” ๋น„ํ™œ์„ฑํ™”๋จ. ๋‹จ์ผ ๋ฌธ์ œ ๋‚ด VLLM ๋ฐฐ์น˜ ์ฒ˜๋ฆฌ๋งŒ ์‚ฌ์šฉ.
"""
try:
# Pull the TTRLVR Ray settings from the VeRL config
ray_config = getattr(self.verl_config.data, 'ttrlvr_ray_config', {}) if hasattr(self, 'verl_config') and self.verl_config else {}
# Check whether parallel processing is enabled
parallel_enabled = ray_config.get('parallel_processing', False)
max_concurrent = ray_config.get('max_concurrent_problems', 4)
if not parallel_enabled or len(problem_ids) <= 1:
self.logger.log_info("📝 Using sequential processing (parallel_processing=False or single problem)")
return self._process_problems_sequential(benchmark_config, problem_ids, round_num)
# Actual Ray parallel-processing path
self.logger.log_info(f"🚀 Using Ray parallel processing for {len(problem_ids)} problems")
self.logger.log_info(f" - Max concurrent: {min(max_concurrent, len(problem_ids))}")
import ray
# Parallel TTRLVR pipeline processing via Ray actors
import os
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
available_gpus = cuda_visible_devices.split(',') if cuda_visible_devices else ['0']
@ray.remote(num_gpus=1)
class TTRLVRPipelineActor:
def __init__(self, config, logger_config, gpu_id=0):
# Configure the GPU first
import os
import torch
# Resolve the physical GPU index from the environment variable
cuda_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0')
available_gpus = cuda_devices.split(',') if cuda_devices else ['0']
actual_gpu = available_gpus[gpu_id % len(available_gpus)]
# Pin this process to its assigned GPU
os.environ['CUDA_VISIBLE_DEVICES'] = actual_gpu
# Force CUDA initialization
if torch.cuda.is_available():
torch.cuda.set_device(0)  # locally always device 0 (physically actual_gpu)
print(f"🎯 Actor initialized on GPU {actual_gpu} (local:0)")
# Initialize the TTRLVR pipeline
from absolute_zero_reasoner.testtime.complete_pipeline import CompleteTestTimePipeline
from absolute_zero_reasoner.testtime.logger import TestTimeLogger
# Recreate the logger (avoids serialization issues)
logger = TestTimeLogger(log_dir=logger_config.get('log_dir', '/tmp'))
# Load the model (independently per actor)
model, tokenizer = self._load_pipeline_model(config)
self.pipeline = CompleteTestTimePipeline(
model=model,
tokenizer=tokenizer,
config=config,
logger=logger
)
def _load_pipeline_model(self, config):
"""Load the model independently inside each actor."""
from absolute_zero_reasoner.testtime.solution_generator import InitialSolutionGenerator
import torch
import os
# Choose the inference engine (use the config default)
use_vllm = getattr(config, 'use_vllm_for_data_generation', True)
# Use the GPU assigned to this actor
device = 'cuda:0'  # locally always device 0 (physical GPU set via CUDA_VISIBLE_DEVICES)
print(f"🔄 Loading model on device {device} (actual GPU: {os.environ.get('CUDA_VISIBLE_DEVICES', 'unknown')})")
return InitialSolutionGenerator.load_model_with_optimizations(
config.model_name, device, config, use_vllm=use_vllm
)
def process_problem(self, benchmark_config, problem_id, round_num, session_timestamp):
"""๋‹จ์ผ ๋ฌธ์ œ ์ฒ˜๋ฆฌ"""
try:
result = self.pipeline.run_complete_pipeline(
benchmark_config, problem_id, round_num, session_timestamp
)
return problem_id, result
except Exception as e:
return problem_id, {
'success': False,
'error': str(e)
}
# Create actors (bounded by max concurrency, problem count, and GPU count)
num_actors = min(max_concurrent, len(problem_ids), len(available_gpus))
self.logger.log_info(f"🎭 Creating {num_actors} Ray actors across {len(available_gpus)} GPUs")
self.logger.log_info(f" - Available GPUs: {available_gpus}")
self.logger.log_info(f" - Debug: max_concurrent={max_concurrent}, len(problem_ids)={len(problem_ids)}, len(available_gpus)={len(available_gpus)}")
# Serialize the logger settings
logger_config = {
'log_dir': self.logger.log_dir if hasattr(self.logger, 'log_dir') else '/tmp'
}
# Create one actor per GPU
actors = []
for i in range(num_actors):
gpu_id = i % len(available_gpus)
self.logger.log_info(f" - Actor {i} -> GPU {available_gpus[gpu_id]}")
actors.append(TTRLVRPipelineActor.remote(self.config, logger_config, gpu_id))
# Distribute the work round-robin and launch
futures = []
for i, problem_id in enumerate(problem_ids):
actor_idx = i % num_actors
future = actors[actor_idx].process_problem.remote(
benchmark_config, problem_id, round_num, self.session_timestamp
)
futures.append(future)
# Collect the results
self.logger.log_info(f"⏳ Waiting for {len(futures)} parallel tasks to complete...")
results_list = ray.get(futures)
# Build the results dictionary
results = {}
for problem_id, result in results_list:
results[problem_id] = result
self.logger.log_info(f"✅ Parallel processing completed: {len(results)} problems processed")
return results
except Exception as e:
self.logger.log_error(f"๐Ÿ’ฅ Parallel processing failed: {e}")
self.logger.log_info("๐Ÿ“ Falling back to sequential processing")
return self._process_problems_sequential(benchmark_config, problem_ids, round_num)
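# Round-robin assignment sketch (values illustrative): with num_actors=2 and
# problem_ids=["p0", "p1", "p2", "p3"], the modulo mapping above yields
#
#     p0 -> actor 0, p1 -> actor 1, p2 -> actor 0, p3 -> actor 1
#
# so each actor works through its problems sequentially while the actors
# themselves run in parallel on separate GPUs.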
def _find_actual_training_data(self) -> Optional[str]:
"""์ตœ๊ทผ ์ƒ์„ฑ๋œ ์‹ค์ œ ํ•™์Šต ๋ฐ์ดํ„ฐ ๋””๋ ‰ํ† ๋ฆฌ ์ฐพ๊ธฐ"""
try:
# tmp/batch_results์—์„œ ์ตœ๊ทผ ์ƒ์„ฑ๋œ ๋””๋ ‰ํ† ๋ฆฌ ๊ฒ€์ƒ‰
base_path = "/home/ubuntu/RLVR/TestTime-RLVR-v2/tmp/batch_results"
# azr_training_data ํด๋” ๊ฒ€์ƒ‰
import glob
search_pattern = os.path.join(base_path, "**/azr_training_data")
data_dirs = glob.glob(search_pattern, recursive=True)
if not data_dirs:
return None
# ๊ฐ€์žฅ ์ตœ๊ทผ ์ˆ˜์ •๋œ ๋””๋ ‰ํ† ๋ฆฌ ์ฐพ๊ธฐ
latest_dir = max(data_dirs, key=os.path.getmtime)
# parquet ํŒŒ์ผ์ด ์‹ค์ œ๋กœ ์žˆ๋Š”์ง€ ํ™•์ธ
task_files = ['induction.parquet', 'deduction.parquet', 'abduction.parquet']
parquet_count = 0
for task_file in task_files:
if os.path.exists(os.path.join(latest_dir, task_file)):
parquet_count += 1
if parquet_count > 0:
return latest_dir
else:
return None
except Exception as e:
self.logger.log_error(f"Error finding actual training data: {e}")
return None
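# Expected layout sketch (paths follow the glob search above; the timestamped
# run directory in the middle is an assumption about how runs are organized):
#
#     tmp/batch_results/<run>/.../azr_training_data/
#         induction.parquet
#         deduction.parquet
#         abduction.parquet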