File size: 8,055 Bytes
24c2665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
#!/usr/bin/env python3
"""
AZR 체크포인트 관리 유틸리티
체크포인트 저장/로드 및 경로 관리
"""

import os
import json
import glob
from pathlib import Path
from typing import Optional, Dict, Any, Tuple
from datetime import datetime
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class CheckpointManager:
    """AZR 체크포인트 관리자"""
    
    def __init__(self, base_checkpoint_path: str = "/data/RLVR/checkpoints/ttrlvr_azr", logger: Optional[Any] = None):
        """
        Args:
            base_checkpoint_path: 체크포인트 기본 경로
            logger: 로거 객체
        """
        self.base_checkpoint_path = Path(base_checkpoint_path)
        self.logger = logger
        
    def log_info(self, msg: str):
        """로깅 헬퍼"""
        if self.logger:
            self.logger.log_info(msg)
        else:
            print(f"[INFO] {msg}")
            
    def log_error(self, msg: str):
        """에러 로깅 헬퍼"""
        if self.logger:
            self.logger.log_error(msg)
        else:
            print(f"[ERROR] {msg}")
    
    def find_latest_checkpoint(self, experiment_name: str, round_num: Optional[int] = None) -> Optional[str]:
        """
        최신 체크포인트 경로 찾기
        
        Args:
            experiment_name: 실험 이름
            round_num: 특정 라운드 번호 (None이면 최신)
            
        Returns:
            체크포인트 경로 또는 None
        """
        try:
            # 체크포인트 디렉토리 경로들
            # AZR은 보통 다음과 같은 패턴으로 저장:
            # /data/RLVR/checkpoints/ttrlvr_azr/{experiment_name}/actor_checkpoint_{step}
            exp_dir = self.base_checkpoint_path / experiment_name
            
            if not exp_dir.exists():
                self.log_info(f"Checkpoint directory not found: {exp_dir}")
                return None
            
            # actor_checkpoint_* 패턴 찾기
            checkpoint_patterns = [
                "actor_checkpoint_*",
                "checkpoint_*",
                "model_*"
            ]
            
            all_checkpoints = []
            for pattern in checkpoint_patterns:
                checkpoints = list(exp_dir.glob(pattern))
                all_checkpoints.extend(checkpoints)
            
            if not all_checkpoints:
                self.log_info(f"No checkpoints found in {exp_dir}")
                return None
            
            # 최신 체크포인트 찾기 (수정 시간 기준)
            latest_checkpoint = max(all_checkpoints, key=lambda p: p.stat().st_mtime)
            
            self.log_info(f"Found latest checkpoint: {latest_checkpoint}")
            return str(latest_checkpoint)
            
        except Exception as e:
            self.log_error(f"Error finding checkpoint: {e}")
            return None
    
    def load_checkpoint(self, checkpoint_path: str, device_map: str = "auto", 
                       torch_dtype: Any = torch.float16) -> Optional[Tuple[Any, Any]]:
        """
        체크포인트에서 모델과 토크나이저 로드
        
        Args:
            checkpoint_path: 체크포인트 경로
            device_map: 디바이스 매핑
            torch_dtype: 데이터 타입
            
        Returns:
            (model, tokenizer) 튜플 또는 None
        """
        try:
            if not os.path.exists(checkpoint_path):
                self.log_error(f"Checkpoint not found: {checkpoint_path}")
                return None
            
            self.log_info(f"Loading checkpoint from: {checkpoint_path}")
            
            # 모델 로드
            model = AutoModelForCausalLM.from_pretrained(
                checkpoint_path,
                torch_dtype=torch_dtype,
                device_map=device_map,
                trust_remote_code=True,
                use_cache=False  # 학습용
            )
            
            # 토크나이저 로드
            tokenizer = AutoTokenizer.from_pretrained(
                checkpoint_path,
                trust_remote_code=True
            )
            
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            
            self.log_info(f"✅ Successfully loaded model and tokenizer from checkpoint")
            return model, tokenizer
            
        except Exception as e:
            self.log_error(f"Error loading checkpoint: {e}")
            return None
    
    def save_checkpoint_info(self, checkpoint_path: str, round_num: int, 
                           metrics: Optional[Dict[str, Any]] = None):
        """
        체크포인트 정보 저장 (메타데이터)
        
        Args:
            checkpoint_path: 체크포인트 경로
            round_num: 라운드 번호
            metrics: 학습 메트릭
        """
        try:
            info = {
                "checkpoint_path": checkpoint_path,
                "round_num": round_num,
                "timestamp": datetime.now().isoformat(),
                "metrics": metrics or {}
            }
            
            info_path = Path(checkpoint_path) / "checkpoint_info.json"
            with open(info_path, 'w') as f:
                json.dump(info, f, indent=2)
                
            self.log_info(f"Saved checkpoint info to: {info_path}")
            
        except Exception as e:
            self.log_error(f"Error saving checkpoint info: {e}")
    
    def get_checkpoint_for_round(self, round_num: int, experiment_name: str) -> Optional[str]:
        """
        특정 라운드의 체크포인트 찾기
        
        Args:
            round_num: 라운드 번호
            experiment_name: 실험 이름
            
        Returns:
            체크포인트 경로 또는 None
        """
        # 라운드별 실험 이름 패턴
        round_exp_name = f"{experiment_name}_round_{round_num}"
        
        # 먼저 정확한 라운드 체크포인트 찾기
        checkpoint = self.find_latest_checkpoint(round_exp_name)
        
        if not checkpoint:
            # 없으면 일반 실험 이름으로 찾기
            checkpoint = self.find_latest_checkpoint(experiment_name)
        
        return checkpoint
    
    def clean_old_checkpoints(self, experiment_name: str, keep_last: int = 5):
        """
        오래된 체크포인트 정리
        
        Args:
            experiment_name: 실험 이름
            keep_last: 유지할 최근 체크포인트 수
        """
        try:
            exp_dir = self.base_checkpoint_path / experiment_name
            if not exp_dir.exists():
                return
            
            # 모든 체크포인트 찾기
            all_checkpoints = list(exp_dir.glob("actor_checkpoint_*"))
            
            if len(all_checkpoints) <= keep_last:
                return
            
            # 수정 시간 기준 정렬
            all_checkpoints.sort(key=lambda p: p.stat().st_mtime, reverse=True)
            
            # 오래된 것들 삭제
            for checkpoint in all_checkpoints[keep_last:]:
                self.log_info(f"Removing old checkpoint: {checkpoint}")
                # 실제 삭제는 주의해서 수행
                # shutil.rmtree(checkpoint)
                
        except Exception as e:
            self.log_error(f"Error cleaning checkpoints: {e}")


if __name__ == "__main__":
    # 테스트
    manager = CheckpointManager()
    
    # 최신 체크포인트 찾기
    checkpoint = manager.find_latest_checkpoint("ttrlvr_azr_gpu5")
    if checkpoint:
        print(f"Latest checkpoint: {checkpoint}")
        
        # 모델 로드 테스트
        result = manager.load_checkpoint(checkpoint)
        if result:
            model, tokenizer = result
            print(f"Model loaded: {type(model).__name__}")
            print(f"Tokenizer loaded: {type(tokenizer).__name__}")