|
|
|
|
|
"""
AZR checkpoint management utilities.

Checkpoint save/load and path management.
"""
|
|
|
|
|
import glob
import json
import os
import shutil
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, Any, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
|
|
class CheckpointManager: |
|
|
"""AZR ์ฒดํฌํฌ์ธํธ ๊ด๋ฆฌ์""" |
|
|
|
|
|
def __init__(self, base_checkpoint_path: str = "/data/RLVR/checkpoints/ttrlvr_azr", logger: Optional[Any] = None): |
|
|
""" |
|
|
Args: |
|
|
base_checkpoint_path: ์ฒดํฌํฌ์ธํธ ๊ธฐ๋ณธ ๊ฒฝ๋ก |
|
|
logger: ๋ก๊ฑฐ ๊ฐ์ฒด |
|
|
""" |
|
|
self.base_checkpoint_path = Path(base_checkpoint_path) |
|
|
self.logger = logger |
|
|
|
|
|
def log_info(self, msg: str): |
|
|
"""๋ก๊น
ํฌํผ""" |
|
|
if self.logger: |
|
|
self.logger.log_info(msg) |
|
|
else: |
|
|
print(f"[INFO] {msg}") |
|
|
|
|
|
def log_error(self, msg: str): |
|
|
"""์๋ฌ ๋ก๊น
ํฌํผ""" |
|
|
if self.logger: |
|
|
self.logger.log_error(msg) |
|
|
else: |
|
|
print(f"[ERROR] {msg}") |
|
|
|
|
|
def find_latest_checkpoint(self, experiment_name: str, round_num: Optional[int] = None) -> Optional[str]: |
|
|
""" |
|
|
์ต์ ์ฒดํฌํฌ์ธํธ ๊ฒฝ๋ก ์ฐพ๊ธฐ |
|
|
|
|
|
Args: |
|
|
experiment_name: ์คํ ์ด๋ฆ |
|
|
round_num: ํน์ ๋ผ์ด๋ ๋ฒํธ (None์ด๋ฉด ์ต์ ) |
|
|
|
|
|
Returns: |
|
|
์ฒดํฌํฌ์ธํธ ๊ฒฝ๋ก ๋๋ None |
|
|
""" |
|
|
try: |
|
|
|
|
|
|
|
|
|
|
|
exp_dir = self.base_checkpoint_path / experiment_name |
|
|
|
|
|
if not exp_dir.exists(): |
|
|
self.log_info(f"Checkpoint directory not found: {exp_dir}") |
|
|
return None |
|
|
|
|
|
|
|
|
checkpoint_patterns = [ |
|
|
"actor_checkpoint_*", |
|
|
"checkpoint_*", |
|
|
"model_*" |
|
|
] |
|
|
|
|
|
all_checkpoints = [] |
|
|
for pattern in checkpoint_patterns: |
|
|
checkpoints = list(exp_dir.glob(pattern)) |
|
|
all_checkpoints.extend(checkpoints) |
|
|
|
|
|
if not all_checkpoints: |
|
|
self.log_info(f"No checkpoints found in {exp_dir}") |
|
|
return None |
|
|
|
|
|
|
|
|
latest_checkpoint = max(all_checkpoints, key=lambda p: p.stat().st_mtime) |
|
|
|
|
|
self.log_info(f"Found latest checkpoint: {latest_checkpoint}") |
|
|
return str(latest_checkpoint) |
|
|
|
|
|
except Exception as e: |
|
|
self.log_error(f"Error finding checkpoint: {e}") |
|
|
return None |
|
|
|
|
|
def load_checkpoint(self, checkpoint_path: str, device_map: str = "auto", |
|
|
torch_dtype: Any = torch.float16) -> Optional[Tuple[Any, Any]]: |
|
|
""" |
|
|
์ฒดํฌํฌ์ธํธ์์ ๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ๋ก๋ |
|
|
|
|
|
Args: |
|
|
checkpoint_path: ์ฒดํฌํฌ์ธํธ ๊ฒฝ๋ก |
|
|
device_map: ๋๋ฐ์ด์ค ๋งคํ |
|
|
torch_dtype: ๋ฐ์ดํฐ ํ์
|
|
|
|
|
|
Returns: |
|
|
(model, tokenizer) ํํ ๋๋ None |
|
|
""" |
|
|
try: |
|
|
if not os.path.exists(checkpoint_path): |
|
|
self.log_error(f"Checkpoint not found: {checkpoint_path}") |
|
|
return None |
|
|
|
|
|
self.log_info(f"Loading checkpoint from: {checkpoint_path}") |
|
|
|
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
checkpoint_path, |
|
|
torch_dtype=torch_dtype, |
|
|
device_map=device_map, |
|
|
trust_remote_code=True, |
|
|
use_cache=False |
|
|
) |
|
|
|
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained( |
|
|
checkpoint_path, |
|
|
trust_remote_code=True |
|
|
) |
|
|
|
|
|
if tokenizer.pad_token is None: |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
|
self.log_info(f"โ
Successfully loaded model and tokenizer from checkpoint") |
|
|
return model, tokenizer |
|
|
|
|
|
except Exception as e: |
|
|
self.log_error(f"Error loading checkpoint: {e}") |
|
|
return None |
|
|
|
|
|
def save_checkpoint_info(self, checkpoint_path: str, round_num: int, |
|
|
metrics: Optional[Dict[str, Any]] = None): |
|
|
""" |
|
|
์ฒดํฌํฌ์ธํธ ์ ๋ณด ์ ์ฅ (๋ฉํ๋ฐ์ดํฐ) |
|
|
|
|
|
Args: |
|
|
checkpoint_path: ์ฒดํฌํฌ์ธํธ ๊ฒฝ๋ก |
|
|
round_num: ๋ผ์ด๋ ๋ฒํธ |
|
|
metrics: ํ์ต ๋ฉํธ๋ฆญ |
|
|
""" |
|
|
try: |
|
|
info = { |
|
|
"checkpoint_path": checkpoint_path, |
|
|
"round_num": round_num, |
|
|
"timestamp": datetime.now().isoformat(), |
|
|
"metrics": metrics or {} |
|
|
} |
|
|
|
|
|
info_path = Path(checkpoint_path) / "checkpoint_info.json" |
|
|
with open(info_path, 'w') as f: |
|
|
json.dump(info, f, indent=2) |
|
|
|
|
|
self.log_info(f"Saved checkpoint info to: {info_path}") |
|
|
|
|
|
except Exception as e: |
|
|
self.log_error(f"Error saving checkpoint info: {e}") |
|
|
|
|
|
def get_checkpoint_for_round(self, round_num: int, experiment_name: str) -> Optional[str]: |
|
|
""" |
|
|
ํน์ ๋ผ์ด๋์ ์ฒดํฌํฌ์ธํธ ์ฐพ๊ธฐ |
|
|
|
|
|
Args: |
|
|
round_num: ๋ผ์ด๋ ๋ฒํธ |
|
|
experiment_name: ์คํ ์ด๋ฆ |
|
|
|
|
|
Returns: |
|
|
์ฒดํฌํฌ์ธํธ ๊ฒฝ๋ก ๋๋ None |
|
|
""" |
|
|
|
|
|
round_exp_name = f"{experiment_name}_round_{round_num}" |
|
|
|
|
|
|
|
|
checkpoint = self.find_latest_checkpoint(round_exp_name) |
|
|
|
|
|
if not checkpoint: |
|
|
|
|
|
checkpoint = self.find_latest_checkpoint(experiment_name) |
|
|
|
|
|
return checkpoint |
|
|
|
|
|
def clean_old_checkpoints(self, experiment_name: str, keep_last: int = 5): |
|
|
""" |
|
|
์ค๋๋ ์ฒดํฌํฌ์ธํธ ์ ๋ฆฌ |
|
|
|
|
|
Args: |
|
|
experiment_name: ์คํ ์ด๋ฆ |
|
|
keep_last: ์ ์งํ ์ต๊ทผ ์ฒดํฌํฌ์ธํธ ์ |
|
|
""" |
|
|
try: |
|
|
exp_dir = self.base_checkpoint_path / experiment_name |
|
|
if not exp_dir.exists(): |
|
|
return |
|
|
|
|
|
|
|
|
all_checkpoints = list(exp_dir.glob("actor_checkpoint_*")) |
|
|
|
|
|
if len(all_checkpoints) <= keep_last: |
|
|
return |
|
|
|
|
|
|
|
|
all_checkpoints.sort(key=lambda p: p.stat().st_mtime, reverse=True) |
|
|
|
|
|
|
|
|
for checkpoint in all_checkpoints[keep_last:]: |
|
|
self.log_info(f"Removing old checkpoint: {checkpoint}") |
|
|
|
|
|
|
|
|
|
|
|
except Exception as e: |
|
|
self.log_error(f"Error cleaning checkpoints: {e}") |
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: locate the newest checkpoint for a sample experiment
    # and try to load it end-to-end.
    ckpt_manager = CheckpointManager()

    latest = ckpt_manager.find_latest_checkpoint("ttrlvr_azr_gpu5")
    if latest:
        print(f"Latest checkpoint: {latest}")

        loaded = ckpt_manager.load_checkpoint(latest)
        if loaded:
            loaded_model, loaded_tokenizer = loaded
            print(f"Model loaded: {type(loaded_model).__name__}")
            print(f"Tokenizer loaded: {type(loaded_tokenizer).__name__}")