| # TTRLVR-AZR ํตํฉ ๊ณํ์ | |
| ## ๊ฐ์ | |
| TTRLVR์ AZR ๋ฐฉ์์ผ๋ก ์์ ํตํฉํ์ฌ ํ๋์ VeRL ์ธ์ ์์ ๋ชจ๋ Phase๋ฅผ ์ฒ๋ฆฌํ๋๋ก ์ฌ๊ตฌ์กฐํ | |
| ## 1. ์ ์ฒด ๊ตฌ์กฐ ๋ณ๊ฒฝ | |
| ### ํ์ฌ ๊ตฌ์กฐ (๋ถ๋ฆฌํ) | |
| ``` | |
| train_ttrlvr_azr.py | |
| โโโ for round in rounds: | |
| โ โโโ Phase 1-4: RemoteTestTimePipeline (๋ ๋ฆฝ vLLM) | |
| โ โ โโโ Step 1: ํ๋ก๊ทธ๋จ ์์ฑ | |
| โ โ โโโ Step 2: I/O ์ ์์ฑ | |
| โ โ โโโ Step 3: Task ์์ฑ | |
| โ โ โโโ Step 4: ๊ฒ์ฆ | |
| โ โโโ ray.kill(pipeline) # vLLM ์ญ์ | |
| โ โโโ Phase 5: VeRL Training (์ vLLM) | |
| โ โโโ trainer.init_workers() # ๋งค ๋ผ์ด๋๋ง๋ค | |
| โ โโโ trainer.fit() # 1 epoch | |
| ``` | |
| ### ๋ชฉํ ๊ตฌ์กฐ (ํตํฉํ) | |
| ``` | |
| train_ttrlvr_azr_unified.py | |
| โโโ trainer = UnifiedTTRLVRTrainer() | |
| โโโ trainer.init_workers() # 1๋ฒ๋ง! | |
| โโโ trainer.fit() | |
| โโโ for round in rounds: # ๋ด๋ถ์์ ์ฒ๋ฆฌ | |
| โโโ Phase 1-4: ๋ฐ์ดํฐ ์์ฑ (๊ฐ์ vLLM) | |
| โโโ Phase 5: ํ์ต (๊ฐ์ vLLM) | |
| ``` | |
| ## 2. ํ์ผ๋ณ ์์ ๊ณํ | |
| ### 2.1 ์๋ก์ด ํ์ผ ์์ฑ | |
| #### `/test/trainer/unified_ttrlvr_trainer.py` | |
| ```python | |
| """ | |
| ํตํฉ TTRLVR Trainer - ๋ชจ๋ Phase๋ฅผ ํ๋์ ์ธ์ ์์ ์ฒ๋ฆฌ | |
| """ | |
| from absolute_zero_reasoner.trainer.ppo.azr_ray_trainer import CodeIORayPPOTrainer | |
| class UnifiedTTRLVRTrainer(CodeIORayPPOTrainer): | |
| def __init__(self, ttrlvr_config, problem_ids, total_rounds, *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| self.ttrlvr_config = ttrlvr_config | |
| self.problem_ids = problem_ids | |
| self.total_rounds = total_rounds | |
| self.current_round = 0 | |
| def fit(self): | |
| """๋ฉ์ธ ํ์ต ๋ฃจํ - ๋ผ์ด๋๋ณ ์ฒ๋ฆฌ ํฌํจ""" | |
| # ๋ก๊ฑฐ ์ค์ | |
| logger = self._setup_logger() | |
| # ์ ์ฒด ๋ผ์ด๋ ๋ฐ๋ณต | |
| for round_num in range(1, self.total_rounds + 1): | |
| self.current_round = round_num | |
| # Phase 1-4: ๋ฐ์ดํฐ ์์ฑ | |
| round_data = self._generate_round_data() | |
| # Phase 5: 1 epoch ํ์ต | |
| self._train_one_round(round_data) | |
| # ์ฒดํฌํฌ์ธํธ ์ ์ฅ | |
| if round_num % 5 == 0: | |
| self._save_checkpoint() | |
| def _generate_round_data(self): | |
| """Phase 1-4๋ฅผ VeRL ๋ด๋ถ์์ ์ฒ๋ฆฌ""" | |
| # ๊ธฐ์กด TestTimePipeline ๋ก์ง์ ์ด๊ณณ์ผ๋ก ์ด๋ | |
| pass | |
| ``` | |
| ### 2.2 ๊ธฐ์กด ํ์ผ ์์ | |
| #### `/test/train_ttrlvr_azr.py` โ `/test/train_ttrlvr_azr_unified.py` | |
| ๋ณ๊ฒฝ ์ : | |
| ```python | |
| # ๋ณต์กํ ๋ผ์ด๋๋ณ ์ฒ๋ฆฌ | |
| trainer = IterativeTrainer(...) | |
| for round in rounds: | |
| # Phase 1-4 | |
| pipeline = RemoteTestTimePipeline(...) | |
| data = pipeline.run_complete_pipeline(...) | |
| ray.kill(pipeline) | |
| # Phase 5 | |
| trainer.train_with_data(data) | |
| ``` | |
| ๋ณ๊ฒฝ ํ: | |
| ```python | |
| # ๋จ์ํ๋ ๋ฉ์ธ ๋ก์ง | |
| from trainer.unified_ttrlvr_trainer import UnifiedTTRLVRTrainer | |
| # ์ค์ | |
| config = load_config() | |
| trainer = UnifiedTTRLVRTrainer( | |
|     ttrlvr_config=config, | |
| problem_ids=problem_ids, | |
| total_rounds=args.rounds, | |
| tokenizer=tokenizer, | |
| ... | |
| ) | |
| # ํ ๋ฒ๋ง ์ด๊ธฐํ | |
| trainer.init_workers() | |
| # ๋ชจ๋ ๋ผ์ด๋ ์ฒ๋ฆฌ | |
| trainer.fit() | |
| ``` | |
| #### `/test/utils/testtime_pipeline.py` ๋ก์ง ์ด๋ | |
| ๊ธฐ์กด Phase 1-4 ๋ก์ง์ UnifiedTTRLVRTrainer๋ก ์ด๋: | |
| ```python | |
| class UnifiedTTRLVRTrainer(CodeIORayPPOTrainer): | |
| def _generate_programs(self, problem_data): | |
| """Step 1: ํ๋ก๊ทธ๋จ ์์ฑ - TestTimePipeline์์ ์ด๋""" | |
| prompt = self._create_program_prompt(problem_data) | |
| # VeRL์ vLLM ์ฌ์ฉ! | |
| prompts_proto = DataProto.from_dict({ | |
| "input_ids": tokenize(prompt), | |
| "attention_mask": ... | |
| }) | |
| # ๊ธฐ์กด actor_rollout_wg ์ฌ์ฉ | |
| outputs = self.actor_rollout_wg.generate_sequences(prompts_proto) | |
| return self._parse_programs(outputs) | |
| def _generate_io_pairs(self, programs): | |
| """Step 2: I/O ์์ฑ - TestTimePipeline์์ ์ด๋""" | |
| # ๊ฐ์ ๋ฐฉ์์ผ๋ก ๊ตฌํ | |
| pass | |
| ``` | |
| ### 2.3 ์ค์ ํ์ผ ์์ | |
| #### `/test/configs/ttrlvr_azr_unified.yaml` | |
| ```yaml | |
| # ํตํฉ ์ค์ | |
| actor_rollout_ref: | |
| rollout: | |
| # dummy_dtensor ์ฌ์ฉ ๊ฐ๋ฅ (๊ฐ์ vLLM ๊ณ์ ์ฌ์ฉ) | |
| load_format: dummy_dtensor | |
| # TTRLVR ํนํ ์ค์ | |
| ttrlvr: | |
| # Phase 1-4 ์ค์ | |
| num_programs: 4 | |
| input_generation_rounds: 3 | |
| # Phase 5 ์ค์ | |
| train_batch_size: 8 | |
| epochs_per_round: 1 # ๋ผ์ด๋๋น 1 epoch | |
| ``` | |
| ## 3. ๊ตฌํ ์์ธ | |
| ### 3.1 UnifiedTTRLVRTrainer ์ ์ฒด ๊ตฌํ | |
| ```python | |
| # /test/trainer/unified_ttrlvr_trainer.py | |
| import os | |
| import time | |
| import json | |
| import torch | |
| import pandas as pd | |
| from typing import List, Dict, Any, Optional | |
| from datetime import datetime | |
| import numpy as np | |
| from verl import DataProto | |
| from verl.utils.py_utils import merge_dict | |
| from absolute_zero_reasoner.trainer.ppo.azr_ray_trainer import CodeIORayPPOTrainer | |
| from absolute_zero_reasoner.testtime.config import BenchmarkConfig | |
| from absolute_zero_reasoner.testtime.execution import PythonExecutor | |
| class UnifiedTTRLVRTrainer(CodeIORayPPOTrainer): | |
| """ | |
| TTRLVR์ ๋ชจ๋ Phase๋ฅผ ํ๋์ VeRL ์ธ์ ์์ ์ฒ๋ฆฌํ๋ ํตํฉ Trainer | |
| """ | |
| def __init__( | |
| self, | |
| ttrlvr_config: Dict[str, Any], | |
| benchmark_config: BenchmarkConfig, | |
| problem_ids: List[str], | |
| total_rounds: int, | |
| output_dir: str, | |
| *args, | |
| **kwargs | |
| ): | |
| super().__init__(*args, **kwargs) | |
| self.ttrlvr_config = ttrlvr_config | |
| self.benchmark_config = benchmark_config | |
| self.problem_ids = problem_ids | |
| self.total_rounds = total_rounds | |
| self.output_dir = output_dir | |
| self.current_round = 0 | |
| # Phase 1-4์ฉ ์ค์ | |
| self.num_programs = ttrlvr_config.get('num_programs', 4) | |
| self.input_rounds = ttrlvr_config.get('input_generation_rounds', 3) | |
| self.parallel_batch_size = ttrlvr_config.get('parallel_batch_size', 4) | |
| # Python ์คํ๊ธฐ | |
| self.executor = PythonExecutor(timeout_length=10) | |
| def fit(self): | |
| """ | |
| ํตํฉ ํ์ต ๋ฃจํ - AZR์ fit()์ ํ์ฅํ์ฌ ๋ผ์ด๋๋ณ ์ฒ๋ฆฌ | |
| """ | |
| # ๊ธฐ๋ณธ ๋ก๊ฑฐ ์ค์ | |
| from verl.utils.tracking import Tracking | |
| logger = Tracking( | |
| project_name=self.config.trainer.project_name, | |
| experiment_name=self.config.trainer.experiment_name, | |
| default_backend=self.config.trainer.logger, | |
| config=self.config, | |
| tags=self.config.trainer.wandb_tags, | |
| entity=self.config.trainer.wandb.entity, | |
| wandb_run_id=self.config.trainer.wandb_run_id, | |
| ) | |
| # ์ ์ฒด ๋ผ์ด๋ ๋ฐ๋ณต | |
| for round_num in range(1, self.total_rounds + 1): | |
| self.current_round = round_num | |
| logger.log({"round": round_num}) | |
| print(f"\n{'='*80}") | |
| print(f"๐ Round {round_num}/{self.total_rounds}") | |
| print(f"{'='*80}") | |
| # Phase 1-4: ๋ฐ์ดํฐ ์์ฑ | |
| round_start = datetime.now() | |
| round_data = self._generate_round_data() | |
| data_gen_time = (datetime.now() - round_start).total_seconds() | |
| print(f"โ Data generation completed in {data_gen_time:.2f}s") | |
| print(f"๐ Generated {len(round_data)} training examples") | |
| # ๋ฐ์ดํฐ๋ฅผ parquet ํ์ผ๋ก ์ ์ฅ | |
| self._save_round_data(round_data, round_num) | |
| # Phase 5: PPO ํ์ต (1 epoch) | |
| train_start = datetime.now() | |
| metrics = self._train_one_round(round_data, logger) | |
| train_time = (datetime.now() - train_start).total_seconds() | |
| print(f"โ Training completed in {train_time:.2f}s") | |
| # ๋ฉํธ๋ฆญ ๋ก๊น | |
| logger.log({ | |
| "round_time/data_generation": data_gen_time, | |
| "round_time/training": train_time, | |
| "round_time/total": data_gen_time + train_time, | |
| **metrics | |
| }) | |
| # ์ฒดํฌํฌ์ธํธ ์ ์ฅ | |
| if round_num % 5 == 0: | |
| self._save_checkpoint() | |
| def _generate_round_data(self) -> List[Dict[str, Any]]: | |
| """ | |
| Phase 1-4: ํ์ฌ ๋ชจ๋ธ๋ก ๋ผ์ด๋ ๋ฐ์ดํฐ ์์ฑ | |
| """ | |
| all_tasks = [] | |
| for problem_id in self.problem_ids: | |
| print(f"\n๐ Processing problem: {problem_id}") | |
| try: | |
| # Step 1: ํ๋ก๊ทธ๋จ ์์ฑ | |
| programs = self._generate_programs(problem_id) | |
| print(f" โ Generated {len(programs)} programs") | |
| # Step 2: I/O ์ ์์ฑ | |
| io_pairs = self._generate_io_pairs(problem_id, programs) | |
| print(f" โ Generated {len(io_pairs)} I/O pairs") | |
| # Step 3: Task ์์ฑ | |
| tasks = self._create_reasoning_tasks(problem_id, programs, io_pairs) | |
| print(f" โ Created {len(tasks)} tasks") | |
| # Step 4: ๊ฒ์ฆ | |
| valid_tasks = self._validate_tasks(tasks) | |
| print(f" โ Validated {len(valid_tasks)}/{len(tasks)} tasks") | |
| all_tasks.extend(valid_tasks) | |
| except Exception as e: | |
| print(f" โ Error processing {problem_id}: {e}") | |
| continue | |
| return all_tasks | |
| def _generate_programs(self, problem_id: str) -> List[str]: | |
| """ | |
| Step 1: ๋ค์ํ ํ๋ก๊ทธ๋จ ์์ฑ | |
| VeRL์ vLLM์ ์ฌ์ฉํ์ฌ ์์ฑ | |
| """ | |
| # ๋ฌธ์ ๋ฐ์ดํฐ ๋ก๋ | |
| problem_data = self._load_problem_data(problem_id) | |
| # ํ๋กฌํํธ ์์ฑ | |
| prompt = f"""You are given a programming problem. Generate {self.num_programs} different solutions. | |
| Problem: {problem_data['description']} | |
| Generate {self.num_programs} different Python solutions:""" | |
| # ํ ํฐํ | |
| input_ids = self.tokenizer( | |
| prompt, | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True, | |
| max_length=self.config.data.max_prompt_length | |
| ).input_ids | |
| # DataProto ์์ฑ | |
| prompts_proto = DataProto.from_dict({ | |
| "input_ids": input_ids.cuda(), | |
| "attention_mask": torch.ones_like(input_ids).cuda(), | |
| "position_ids": torch.arange(input_ids.size(1)).unsqueeze(0).cuda() | |
| }) | |
| # ๋ฉํ ์ ๋ณด ์ถ๊ฐ | |
| prompts_proto.meta_info = { | |
| "eos_token_id": self.tokenizer.eos_token_id, | |
| "pad_token_id": self.tokenizer.pad_token_id, | |
| "temperature": 0.8, # ๋ค์์ฑ์ ์ํด ๋์ temperature | |
| "do_sample": True, | |
| "top_p": 0.95, | |
| "response_length": 512 | |
| } | |
| # VeRL์ vLLM์ผ๋ก ์์ฑ! | |
| outputs = self.actor_rollout_wg.generate_sequences(prompts_proto) | |
| # ํ๋ก๊ทธ๋จ ์ถ์ถ | |
| programs = [] | |
| # NOTE(review): outputs.batch["input_ids"] includes the prompt tokens as well; | |
| # consider decoding the "responses" key instead — verify against the VeRL API. | |
| generated_text = self.tokenizer.decode( | |
| outputs.batch["input_ids"][0], | |
| skip_special_tokens=True | |
| ) | |
| # ํ๋ก๊ทธ๋จ ํ์ฑ (์ฝ๋ ๋ธ๋ก ์ถ์ถ) | |
| code_blocks = self._extract_code_blocks(generated_text) | |
| programs.extend(code_blocks[:self.num_programs]) | |
| return programs | |
| def _generate_io_pairs( | |
| self, | |
| problem_id: str, | |
| programs: List[str] | |
| ) -> List[Dict[str, Any]]: | |
| """ | |
| Step 2: ํ๋ก๊ทธ๋จ๋ค๋ก๋ถํฐ I/O ์ ์์ฑ | |
| """ | |
| io_pairs = [] | |
| for program in programs: | |
| # ๊ฐ ํ๋ก๊ทธ๋จ์ ๋ํด ์ฌ๋ฌ ์ ๋ ฅ ์์ฑ | |
| for round_idx in range(self.input_rounds): | |
| prompt = f"""Given this Python function, generate 5 test inputs. | |
| Function: | |
| ```python | |
| {program} | |
| ``` | |
| Generate 5 different test inputs as a Python list:""" | |
| # ์ ๋ ฅ ์์ฑ | |
| inputs = self._generate_with_vllm(prompt, temperature=0.7) | |
| # ๊ฐ ์ ๋ ฅ์ ๋ํด ์ถ๋ ฅ ๊ณ์ฐ | |
| for test_input in inputs: | |
| try: | |
| output = self.executor.execute(program, test_input) | |
| if output['success']: | |
| io_pairs.append({ | |
| 'input': test_input, | |
| 'output': output['result'], | |
| 'program': program | |
| }) | |
| except Exception: | |
| continue | |
| return io_pairs | |
| def _create_reasoning_tasks( | |
| self, | |
| problem_id: str, | |
| programs: List[str], | |
| io_pairs: List[Dict[str, Any]] | |
| ) -> List[Dict[str, Any]]: | |
| """ | |
| Step 3: Induction, Deduction, Abduction task ์์ฑ | |
| """ | |
| tasks = [] | |
| for io_pair in io_pairs: | |
| # Induction: I/O โ Program | |
| tasks.append({ | |
| 'problem_id': problem_id, | |
| 'task_type': 'induction', | |
| 'input': io_pair['input'], | |
| 'output': io_pair['output'], | |
| 'target': io_pair['program'], | |
| 'prompt': self._create_induction_prompt(io_pair) | |
| }) | |
| # Deduction: Program + Input โ Output | |
| tasks.append({ | |
| 'problem_id': problem_id, | |
| 'task_type': 'deduction', | |
| 'input': io_pair['input'], | |
| 'program': io_pair['program'], | |
| 'target': io_pair['output'], | |
| 'prompt': self._create_deduction_prompt(io_pair) | |
| }) | |
| # Abduction: Program + Output โ Input | |
| tasks.append({ | |
| 'problem_id': problem_id, | |
| 'task_type': 'abduction', | |
| 'program': io_pair['program'], | |
| 'output': io_pair['output'], | |
| 'target': io_pair['input'], | |
| 'prompt': self._create_abduction_prompt(io_pair) | |
| }) | |
| return tasks | |
| def _train_one_round( | |
| self, | |
| round_data: List[Dict[str, Any]], | |
| logger | |
| ) -> Dict[str, float]: | |
| """ | |
| Phase 5: ํ ๋ผ์ด๋์ PPO ํ์ต | |
| """ | |
| # ๋ฐ์ดํฐ๋ฅผ VeRL ํ์์ผ๋ก ๋ณํ | |
| train_dataset = self._convert_to_verl_dataset(round_data) | |
| # ํ์ฌ dataloader ์ ๋ฐ์ดํธ | |
| self.train_dataloader = self._create_dataloader( | |
| train_dataset, | |
| self.val_dataset, | |
| self.collate_fn, | |
| self.train_sampler | |
| ) | |
| # 1 epoch ํ์ต | |
| epoch_metrics = {} | |
| for step, batch in enumerate(self.train_dataloader): | |
| # ๋ฐฐ์น ์ค๋น | |
| gen_batch = self._prepare_generation_batch(batch) | |
| # ์ํ์ค ์์ฑ (๊ฐ์ vLLM ์ฌ์ฉ!) | |
| gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) | |
| # ๋ฆฌ์๋ ๊ณ์ฐ | |
| batch = batch.union(gen_batch_output) | |
| reward_tensor = self.reward_fn(batch) | |
| # PPO ์ ๋ฐ์ดํธ | |
| update_metrics = self._ppo_update(batch, reward_tensor) | |
| # ๋ฉํธ๋ฆญ ์์ง | |
| for k, v in update_metrics.items(): | |
| if k not in epoch_metrics: | |
| epoch_metrics[k] = [] | |
| epoch_metrics[k].append(v) | |
| # ํ๊ท ๋ฉํธ๋ฆญ ๊ณ์ฐ | |
| avg_metrics = { | |
| k: np.mean(v) for k, v in epoch_metrics.items() | |
| } | |
| return avg_metrics | |
| def _generate_with_vllm( | |
| self, | |
| prompt: str, | |
| temperature: float = 0.7 | |
| ) -> Any: | |
| """ | |
| ํฌํผ ํจ์: VeRL์ vLLM์ ์ฌ์ฉํ ํ ์คํธ ์์ฑ | |
| """ | |
| # ํ ํฐํ | |
| input_ids = self.tokenizer( | |
| prompt, | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True | |
| ).input_ids | |
| # DataProto ์์ฑ | |
| prompts_proto = DataProto.from_dict({ | |
| "input_ids": input_ids.cuda(), | |
| "attention_mask": torch.ones_like(input_ids).cuda(), | |
| }) | |
| prompts_proto.meta_info = { | |
| "eos_token_id": self.tokenizer.eos_token_id, | |
| "pad_token_id": self.tokenizer.pad_token_id, | |
| "temperature": temperature, | |
| "do_sample": True, | |
| "response_length": 256 | |
| } | |
| # ์์ฑ | |
| outputs = self.actor_rollout_wg.generate_sequences(prompts_proto) | |
| # ๋์ฝ๋ฉ | |
| # NOTE(review): outputs.batch["input_ids"] includes the prompt tokens as well; | |
| # consider decoding the "responses" key instead — verify against the VeRL API. | |
| generated_text = self.tokenizer.decode( | |
| outputs.batch["input_ids"][0], | |
| skip_special_tokens=True | |
| ) | |
| return self._parse_output(generated_text) | |
| def _save_round_data(self, round_data: List[Dict], round_num: int): | |
| """๋ผ์ด๋ ๋ฐ์ดํฐ๋ฅผ parquet ํ์ผ๋ก ์ ์ฅ""" | |
| output_dir = os.path.join(self.output_dir, f"round_{round_num}") | |
| os.makedirs(output_dir, exist_ok=True) | |
| # Task ํ์ ๋ณ๋ก ๋ถ๋ฆฌ | |
| for task_type in ['induction', 'deduction', 'abduction']: | |
| tasks = [t for t in round_data if t['task_type'] == task_type] | |
| if tasks: | |
| df = pd.DataFrame(tasks) | |
| df.to_parquet(os.path.join(output_dir, f"{task_type}.parquet")) | |
| ``` | |
| ### 3.2 ๋ฐ์ดํฐ ํ๋ฆ ์์ธ | |
| ```python | |
| # ์ค์ ๋ฐ์ดํฐ๊ฐ ํ๋ฅด๋ ๊ณผ์ | |
| # Round 1 ์์ | |
| trainer.current_round = 1 | |
| # 1. ํ๋ก๊ทธ๋จ ์์ฑ | |
| programs = trainer._generate_programs("Mbpp/1") | |
| # โ trainer.actor_rollout_wg.generate_sequences() ํธ์ถ | |
| # โ FSDP ๋ชจ๋ธ์ ๊ฐ์ค์น๊ฐ vLLM์ ๋๊ธฐํ๋จ (์ฒซ ๋ฒ์งธ) | |
| # โ ์ถ๋ ฅ: ["def solve(x): return x*2", "def solve(x): return 2*x", ...] | |
| # 2. I/O ์์ฑ | |
| io_pairs = trainer._generate_io_pairs("Mbpp/1", programs) | |
| # โ ๊ฐ์ vLLM ์ฌ์ฉ (๋๊ธฐํ ๊ฑด๋๋ - base_sync_done=True) | |
| # โ ์ถ๋ ฅ: [{"input": 5, "output": 10}, {"input": 3, "output": 6}, ...] | |
| # 3. Task ์์ฑ | |
| tasks = trainer._create_reasoning_tasks(...) | |
| # โ ๋ฉ๋ชจ๋ฆฌ์์๋ง ์ฒ๋ฆฌ (vLLM ํธ์ถ ์์) | |
| # 4. PPO ํ์ต | |
| trainer._train_one_round(tasks) | |
| # โ ๊ฐ์ vLLM์ผ๋ก response ์์ฑ | |
| # โ FSDP ๋ชจ๋ธ ์ ๋ฐ์ดํธ | |
| # โ vLLM์ ๋ฉ๋ชจ๋ฆฌ ์ฐธ์กฐ๋ก ์๋ ์ ๋ฐ์ดํธ | |
| # Round 2 ์์ - ๊ฐ์ vLLM ๊ณ์ ์ฌ์ฉ! | |
| ``` | |
| ### 3.3 ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ ์์ธ | |
| ```python | |
| class UnifiedTTRLVRTrainer(CodeIORayPPOTrainer): | |
| def _manage_memory(self): | |
| """Phase ์ ํ ์ ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ""" | |
| # Ray actor๋ ์ ์งํ๋ฉด์ GPU ์บ์๋ง ์ ๋ฆฌ | |
| torch.cuda.empty_cache() | |
| # vLLM์ KV ์บ์ ์ ๋ฆฌ (์ ํ์ ) | |
| if hasattr(self.actor_rollout_wg, 'clear_kv_cache'): | |
| self.actor_rollout_wg.clear_kv_cache() | |
| def _monitor_memory(self): | |
| """๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ๋ชจ๋ํฐ๋ง""" | |
| for i in range(torch.cuda.device_count()): | |
| allocated = torch.cuda.memory_allocated(i) / 1024**3 | |
| reserved = torch.cuda.memory_reserved(i) / 1024**3 | |
| print(f"GPU {i}: Allocated={allocated:.2f}GB, Reserved={reserved:.2f}GB") | |
| ``` | |
| ### 3.4 ๋๊ธฐํ ๋ฉ์ปค๋์ฆ ์์ธ | |
| ```python | |
| # ๋๊ธฐํ๊ฐ ์ด๋ป๊ฒ ๋ณด์ฅ๋๋์ง | |
| # 1. ์ฒซ ๋ฒ์งธ generate_sequences ํธ์ถ | |
| with self.rollout_sharding_manager: # __enter__() ํธ์ถ | |
| # dummy_dtensor ์ฌ์ฉ ์: | |
| # - self.base_sync_done = False (์ด๊ธฐ๊ฐ) | |
| # - sync_model_weights() ์คํ โ FSDP โ vLLM ๋๊ธฐํ | |
| # - self.base_sync_done = True ์ค์ | |
| # 2. ์ดํ generate_sequences ํธ์ถ๋ค | |
| with self.rollout_sharding_manager: # __enter__() ํธ์ถ | |
| # - self.base_sync_done = True | |
| # - sync_model_weights() ๊ฑด๋๋ | |
| # - ํ์ง๋ง ๊ฐ์ vLLM์ด๋ฏ๋ก ๋ฉ๋ชจ๋ฆฌ ์ฐธ์กฐ๋ก ์ ๋ฐ์ดํธ๋จ | |
| # 3. ๋ฉ๋ชจ๋ฆฌ ์ฐธ์กฐ ๋ฉ์ปค๋์ฆ | |
| # FSDP ๋ชจ๋ธ๊ณผ vLLM ๋ชจ๋ธ์ด ๊ฐ์ tensor๋ฅผ ์ฐธ์กฐ | |
| # FSDP ์ ๋ฐ์ดํธ โ tensor ๊ฐ ๋ณ๊ฒฝ โ vLLM๋ ์๋์ผ๋ก ์ ๊ฐ ์ฌ์ฉ | |
| ``` | |
| ### 3.5 ์๋ฌ ์ฒ๋ฆฌ ๋ฐ ๋ณต๊ตฌ | |
| ```python | |
| class UnifiedTTRLVRTrainer(CodeIORayPPOTrainer): | |
| def _safe_generate(self, prompt: str, max_retries: int = 3): | |
| """์์ ํ ์์ฑ with ์ฌ์๋""" | |
| for attempt in range(max_retries): | |
| try: | |
| return self._generate_with_vllm(prompt) | |
| except Exception as e: | |
| print(f"Generation failed (attempt {attempt+1}): {e}") | |
| if attempt == max_retries - 1: | |
| raise | |
| # GPU ๋ฉ๋ชจ๋ฆฌ ์ ๋ฆฌ ํ ์ฌ์๋ | |
| torch.cuda.empty_cache() | |
| time.sleep(1) | |
| def _validate_tasks(self, tasks: List[Dict]) -> List[Dict]: | |
| """์์ฑ๋ task ๊ฒ์ฆ""" | |
| valid_tasks = [] | |
| for task in tasks: | |
| if self._is_valid_task(task): | |
| valid_tasks.append(task) | |
| else: | |
| print(f"Invalid task filtered: {task['task_type']}") | |
| return valid_tasks | |
| ``` | |
| ## 4. ๋ง์ด๊ทธ๋ ์ด์ ๊ณํ | |
| ### Phase 1: ์ฝ๋ ์ค๋น | |
| 1. UnifiedTTRLVRTrainer ํด๋์ค ์์ฑ | |
| 2. TestTimePipeline ๋ก์ง ์ด๋ | |
| 3. ๋จ์ ํ ์คํธ ์์ฑ | |
| ### Phase 2: ํตํฉ ํ ์คํธ | |
| 1. ์๊ท๋ชจ ๋ฌธ์ ๋ก ํ ์คํธ (1-2 rounds) | |
| 2. ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ๋ชจ๋ํฐ๋ง | |
| 3. ํ์ต ์ฑ๋ฅ ๋น๊ต | |
| ### Phase 3: ์ ํ | |
| 1. ๊ธฐ์กด ์คํฌ๋ฆฝํธ ๋ฐฑ์ | |
| 2. ์ ์คํฌ๋ฆฝํธ๋ก ๊ต์ฒด | |
| 3. ์ ์ฒด ํ์ต ์คํ | |
| ## 5. ์์ ๊ฒฐ๊ณผ | |
| ### ์ฅ์ | |
| - โ ๋๊ธฐํ ๋ฌธ์ ์์ ํด๊ฒฐ | |
| - โ 30-40% ๋น ๋ฅธ ์คํ (vLLM ์ฌ์์ฑ ์์) | |
| - โ ๋ฉ๋ชจ๋ฆฌ ํจ์จ 20-30% ๊ฐ์ | |
| - โ ์ฝ๋ ๊ตฌ์กฐ ๋จ์ํ | |
| ### ๋จ์ ๋ฐ ๋์ | |
| - โ Phase ๊ฐ ๊ฒฐํฉ๋ ์ฆ๊ฐ | |
| - โ ๋ช ํํ ์ธํฐํ์ด์ค ์ ์๋ก ํด๊ฒฐ | |
| - โ ๋๋ฒ๊น ๋ณต์ก๋ | |
| - โ ์์ธํ ๋ก๊น ์ถ๊ฐ | |
| - โ ๊ธฐ์กด ์ฝ๋์ ํธํ์ฑ | |
| - โ ๋ ๋ฒ์ ๋ณํ ์ ์ง | |
| ## 6. ๊ตฌํ ์ฐ์ ์์ | |
| 1. **๋์**: UnifiedTTRLVRTrainer ๊ธฐ๋ณธ ๊ตฌ์กฐ | |
| 2. **๋์**: Phase 1-4 ๋ก์ง ์ด๋ | |
| 3. **์ค๊ฐ**: ์ค์ ํ์ผ ํตํฉ | |
| 4. **๋ฎ์**: ์ถ๊ฐ ์ต์ ํ | |
| ## 7. ํ ์คํธ ๊ณํ | |
| ```bash | |
| # ๋จ๊ณ๋ณ ํ ์คํธ | |
| # 1. ์๊ท๋ชจ ํ ์คํธ | |
| python train_ttrlvr_azr_unified.py --rounds 2 --problems 1 | |
| # 2. ์ค๊ฐ ํ ์คํธ | |
| python train_ttrlvr_azr_unified.py --rounds 5 --problems 5 | |
| # 3. ์ ์ฒด ํ ์คํธ | |
| python train_ttrlvr_azr_unified.py --rounds 30 --problems 10 | |
| ``` | |
| ## 8. ๋กค๋ฐฑ ๊ณํ | |
| ๋ฌธ์ ๋ฐ์ ์: | |
| 1. ๊ธฐ์กด ๋ถ๋ฆฌํ ๊ตฌ์กฐ๋ก ์ฆ์ ๋ณต๊ท ๊ฐ๋ฅ | |
| 2. load_format: dtensor ์ฌ์ฉ์ผ๋ก ์์ ํด๊ฒฐ | |
| 3. ๋จ๊ณ์ ํตํฉ (Phase 5๋ง ๋จผ์ ) |