# neural-mesh-v2 / test/utils/checkpoint_manager.py
# hjkim00's picture
# Restore all essential files - code, configs, and MBPP/HumanEval data
# 24c2665 verified
#!/usr/bin/env python3
"""
AZR ์ฒดํฌํฌ์ธํŠธ ๊ด€๋ฆฌ ์œ ํ‹ธ๋ฆฌํ‹ฐ
์ฒดํฌํฌ์ธํŠธ ์ €์žฅ/๋กœ๋“œ ๋ฐ ๊ฒฝ๋กœ ๊ด€๋ฆฌ
"""
import os
import json
import glob
from pathlib import Path
from typing import Optional, Dict, Any, Tuple
from datetime import datetime
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class CheckpointManager:
    """Manage AZR checkpoints: discovery, loading, metadata, and cleanup."""

    def __init__(self, base_checkpoint_path: str = "/data/RLVR/checkpoints/ttrlvr_azr", logger: Optional[Any] = None):
        """
        Args:
            base_checkpoint_path: Root directory under which per-experiment
                checkpoint folders live.
            logger: Optional logger object exposing ``log_info``/``log_error``.
                When None, messages fall back to ``print()``.
        """
        self.base_checkpoint_path = Path(base_checkpoint_path)
        self.logger = logger

    def log_info(self, msg: str) -> None:
        """Log an info message via the injected logger, or print()."""
        if self.logger:
            self.logger.log_info(msg)
        else:
            print(f"[INFO] {msg}")

    def log_error(self, msg: str) -> None:
        """Log an error message via the injected logger, or print()."""
        if self.logger:
            self.logger.log_error(msg)
        else:
            print(f"[ERROR] {msg}")

    def find_latest_checkpoint(self, experiment_name: str, round_num: Optional[int] = None) -> Optional[str]:
        """
        Find the most recently modified checkpoint for an experiment.

        AZR typically saves checkpoints as
        ``{base_checkpoint_path}/{experiment_name}/actor_checkpoint_{step}``;
        ``checkpoint_*`` and ``model_*`` are also searched.

        Args:
            experiment_name: Experiment directory name under the base path.
            round_num: If given, prefer checkpoints whose name ends with
                ``_{round_num}``; when no candidate matches, all checkpoints
                are considered (backward compatible — this parameter was
                previously accepted but ignored).

        Returns:
            Path string of the newest matching checkpoint, or None when the
            experiment directory or any checkpoint is missing.
        """
        try:
            exp_dir = self.base_checkpoint_path / experiment_name
            if not exp_dir.exists():
                self.log_info(f"Checkpoint directory not found: {exp_dir}")
                return None

            # Candidate naming schemes; glob anchors at the start, so the
            # patterns are mutually exclusive (no duplicates to dedupe).
            checkpoint_patterns = [
                "actor_checkpoint_*",
                "checkpoint_*",
                "model_*"
            ]
            all_checkpoints = []
            for pattern in checkpoint_patterns:
                all_checkpoints.extend(exp_dir.glob(pattern))
            if not all_checkpoints:
                self.log_info(f"No checkpoints found in {exp_dir}")
                return None

            # Honor round_num when the checkpoint naming encodes it; fall
            # back to every candidate otherwise so legacy layouts still work.
            if round_num is not None:
                matching = [p for p in all_checkpoints
                            if p.name.endswith(f"_{round_num}")]
                if matching:
                    all_checkpoints = matching

            # Newest by filesystem modification time.
            latest_checkpoint = max(all_checkpoints, key=lambda p: p.stat().st_mtime)
            self.log_info(f"Found latest checkpoint: {latest_checkpoint}")
            return str(latest_checkpoint)
        except Exception as e:
            # Best-effort lookup: report the failure and signal "not found".
            self.log_error(f"Error finding checkpoint: {e}")
            return None

    def load_checkpoint(self, checkpoint_path: str, device_map: str = "auto",
                        torch_dtype: Any = torch.float16) -> Optional[Tuple[Any, Any]]:
        """
        Load a model and tokenizer from a checkpoint directory.

        Args:
            checkpoint_path: Path to the checkpoint (HF ``from_pretrained``
                compatible directory).
            device_map: Device placement passed to ``from_pretrained``.
            torch_dtype: Torch dtype for the model weights.

        Returns:
            ``(model, tokenizer)`` tuple, or None on any failure.
        """
        try:
            if not os.path.exists(checkpoint_path):
                self.log_error(f"Checkpoint not found: {checkpoint_path}")
                return None

            self.log_info(f"Loading checkpoint from: {checkpoint_path}")

            # Load the model; use_cache disabled because the model is
            # intended for further training, not generation.
            model = AutoModelForCausalLM.from_pretrained(
                checkpoint_path,
                torch_dtype=torch_dtype,
                device_map=device_map,
                trust_remote_code=True,
                use_cache=False
            )

            # Load the tokenizer from the same directory.
            tokenizer = AutoTokenizer.from_pretrained(
                checkpoint_path,
                trust_remote_code=True
            )
            if tokenizer.pad_token is None:
                # Common HF convention: reuse EOS as padding when unset.
                tokenizer.pad_token = tokenizer.eos_token

            self.log_info(f"✅ Successfully loaded model and tokenizer from checkpoint")
            return model, tokenizer
        except Exception as e:
            self.log_error(f"Error loading checkpoint: {e}")
            return None

    def save_checkpoint_info(self, checkpoint_path: str, round_num: int,
                             metrics: Optional[Dict[str, Any]] = None) -> None:
        """
        Write checkpoint metadata (``checkpoint_info.json``) into the
        checkpoint directory.

        Args:
            checkpoint_path: Checkpoint directory to annotate.
            round_num: Training round number.
            metrics: Optional training metrics to record.
        """
        try:
            info = {
                "checkpoint_path": checkpoint_path,
                "round_num": round_num,
                "timestamp": datetime.now().isoformat(),
                "metrics": metrics or {}
            }
            info_path = Path(checkpoint_path) / "checkpoint_info.json"
            with open(info_path, 'w') as f:
                # ensure_ascii=False keeps non-ASCII metric text readable
                # instead of \uXXXX escapes.
                json.dump(info, f, indent=2, ensure_ascii=False)
            self.log_info(f"Saved checkpoint info to: {info_path}")
        except Exception as e:
            self.log_error(f"Error saving checkpoint info: {e}")

    def get_checkpoint_for_round(self, round_num: int, experiment_name: str) -> Optional[str]:
        """
        Find the checkpoint for a specific round.

        Tries the round-suffixed experiment directory
        (``{experiment_name}_round_{round_num}``) first, then falls back to
        the plain experiment name.

        Args:
            round_num: Round number.
            experiment_name: Base experiment name.

        Returns:
            Checkpoint path string, or None if nothing is found.
        """
        round_exp_name = f"{experiment_name}_round_{round_num}"
        checkpoint = self.find_latest_checkpoint(round_exp_name)
        if not checkpoint:
            # Fall back to the un-suffixed experiment directory.
            checkpoint = self.find_latest_checkpoint(experiment_name)
        return checkpoint

    def clean_old_checkpoints(self, experiment_name: str, keep_last: int = 5) -> None:
        """
        Identify old checkpoints beyond the newest ``keep_last``.

        NOTE: Deletion is intentionally disabled — candidates are only
        logged. Uncomment the rmtree call to actually remove them.

        Args:
            experiment_name: Experiment directory name.
            keep_last: Number of most recent checkpoints to keep.
        """
        try:
            exp_dir = self.base_checkpoint_path / experiment_name
            if not exp_dir.exists():
                return

            all_checkpoints = list(exp_dir.glob("actor_checkpoint_*"))
            if len(all_checkpoints) <= keep_last:
                return

            # Newest first by modification time.
            all_checkpoints.sort(key=lambda p: p.stat().st_mtime, reverse=True)

            for checkpoint in all_checkpoints[keep_last:]:
                self.log_info(f"Removing old checkpoint: {checkpoint}")
                # Actual deletion deliberately left disabled for safety.
                # shutil.rmtree(checkpoint)
        except Exception as e:
            self.log_error(f"Error cleaning checkpoints: {e}")
if __name__ == "__main__":
    # Smoke test: locate the newest checkpoint for one experiment and
    # try loading it.
    manager = CheckpointManager()

    checkpoint = manager.find_latest_checkpoint("ttrlvr_azr_gpu5")
    if checkpoint:
        print(f"Latest checkpoint: {checkpoint}")

        loaded = manager.load_checkpoint(checkpoint)
        if loaded is not None:
            model, tokenizer = loaded
            print(f"Model loaded: {type(model).__name__}")
            print(f"Tokenizer loaded: {type(tokenizer).__name__}")