#!/usr/bin/env python3 """ AZR 체크포인트 관리 유틸리티 체크포인트 저장/로드 및 경로 관리 """ import os import json import glob from pathlib import Path from typing import Optional, Dict, Any, Tuple from datetime import datetime import torch from transformers import AutoModelForCausalLM, AutoTokenizer class CheckpointManager: """AZR 체크포인트 관리자""" def __init__(self, base_checkpoint_path: str = "/data/RLVR/checkpoints/ttrlvr_azr", logger: Optional[Any] = None): """ Args: base_checkpoint_path: 체크포인트 기본 경로 logger: 로거 객체 """ self.base_checkpoint_path = Path(base_checkpoint_path) self.logger = logger def log_info(self, msg: str): """로깅 헬퍼""" if self.logger: self.logger.log_info(msg) else: print(f"[INFO] {msg}") def log_error(self, msg: str): """에러 로깅 헬퍼""" if self.logger: self.logger.log_error(msg) else: print(f"[ERROR] {msg}") def find_latest_checkpoint(self, experiment_name: str, round_num: Optional[int] = None) -> Optional[str]: """ 최신 체크포인트 경로 찾기 Args: experiment_name: 실험 이름 round_num: 특정 라운드 번호 (None이면 최신) Returns: 체크포인트 경로 또는 None """ try: # 체크포인트 디렉토리 경로들 # AZR은 보통 다음과 같은 패턴으로 저장: # /data/RLVR/checkpoints/ttrlvr_azr/{experiment_name}/actor_checkpoint_{step} exp_dir = self.base_checkpoint_path / experiment_name if not exp_dir.exists(): self.log_info(f"Checkpoint directory not found: {exp_dir}") return None # actor_checkpoint_* 패턴 찾기 checkpoint_patterns = [ "actor_checkpoint_*", "checkpoint_*", "model_*" ] all_checkpoints = [] for pattern in checkpoint_patterns: checkpoints = list(exp_dir.glob(pattern)) all_checkpoints.extend(checkpoints) if not all_checkpoints: self.log_info(f"No checkpoints found in {exp_dir}") return None # 최신 체크포인트 찾기 (수정 시간 기준) latest_checkpoint = max(all_checkpoints, key=lambda p: p.stat().st_mtime) self.log_info(f"Found latest checkpoint: {latest_checkpoint}") return str(latest_checkpoint) except Exception as e: self.log_error(f"Error finding checkpoint: {e}") return None def load_checkpoint(self, checkpoint_path: str, device_map: str = "auto", torch_dtype: Any = torch.float16) -> Optional[Tuple[Any, Any]]: """ 체크포인트에서 모델과 토크나이저 로드 Args: checkpoint_path: 체크포인트 경로 device_map: 디바이스 매핑 torch_dtype: 데이터 타입 Returns: (model, tokenizer) 튜플 또는 None """ try: if not os.path.exists(checkpoint_path): self.log_error(f"Checkpoint not found: {checkpoint_path}") return None self.log_info(f"Loading checkpoint from: {checkpoint_path}") # 모델 로드 model = AutoModelForCausalLM.from_pretrained( checkpoint_path, torch_dtype=torch_dtype, device_map=device_map, trust_remote_code=True, use_cache=False # 학습용 ) # 토크나이저 로드 tokenizer = AutoTokenizer.from_pretrained( checkpoint_path, trust_remote_code=True ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token self.log_info(f"✅ Successfully loaded model and tokenizer from checkpoint") return model, tokenizer except Exception as e: self.log_error(f"Error loading checkpoint: {e}") return None def save_checkpoint_info(self, checkpoint_path: str, round_num: int, metrics: Optional[Dict[str, Any]] = None): """ 체크포인트 정보 저장 (메타데이터) Args: checkpoint_path: 체크포인트 경로 round_num: 라운드 번호 metrics: 학습 메트릭 """ try: info = { "checkpoint_path": checkpoint_path, "round_num": round_num, "timestamp": datetime.now().isoformat(), "metrics": metrics or {} } info_path = Path(checkpoint_path) / "checkpoint_info.json" with open(info_path, 'w') as f: json.dump(info, f, indent=2) self.log_info(f"Saved checkpoint info to: {info_path}") except Exception as e: self.log_error(f"Error saving checkpoint info: {e}") def get_checkpoint_for_round(self, round_num: int, experiment_name: str) -> Optional[str]: """ 특정 라운드의 체크포인트 찾기 Args: round_num: 라운드 번호 experiment_name: 실험 이름 Returns: 체크포인트 경로 또는 None """ # 라운드별 실험 이름 패턴 round_exp_name = f"{experiment_name}_round_{round_num}" # 먼저 정확한 라운드 체크포인트 찾기 checkpoint = self.find_latest_checkpoint(round_exp_name) if not checkpoint: # 없으면 일반 실험 이름으로 찾기 checkpoint = self.find_latest_checkpoint(experiment_name) return checkpoint def clean_old_checkpoints(self, experiment_name: str, keep_last: int = 5): """ 오래된 체크포인트 정리 Args: experiment_name: 실험 이름 keep_last: 유지할 최근 체크포인트 수 """ try: exp_dir = self.base_checkpoint_path / experiment_name if not exp_dir.exists(): return # 모든 체크포인트 찾기 all_checkpoints = list(exp_dir.glob("actor_checkpoint_*")) if len(all_checkpoints) <= keep_last: return # 수정 시간 기준 정렬 all_checkpoints.sort(key=lambda p: p.stat().st_mtime, reverse=True) # 오래된 것들 삭제 for checkpoint in all_checkpoints[keep_last:]: self.log_info(f"Removing old checkpoint: {checkpoint}") # 실제 삭제는 주의해서 수행 # shutil.rmtree(checkpoint) except Exception as e: self.log_error(f"Error cleaning checkpoints: {e}") if __name__ == "__main__": # 테스트 manager = CheckpointManager() # 최신 체크포인트 찾기 checkpoint = manager.find_latest_checkpoint("ttrlvr_azr_gpu5") if checkpoint: print(f"Latest checkpoint: {checkpoint}") # 모델 로드 테스트 result = manager.load_checkpoint(checkpoint) if result: model, tokenizer = result print(f"Model loaded: {type(model).__name__}") print(f"Tokenizer loaded: {type(tokenizer).__name__}")