# Page header residue (not part of the module):
# StarMist0012's picture
# Add files using upload-large-folder tool
# 3270dae verified
"""Checkpoint management utilities.
Canonical Checkpoint Format (new):
{
'step': int, # Training step number
'model_state': Dict[str, Tensor], # Model state dict
'optimizer_state': Dict, # Optimizer state dict (optional)
'config': Dict, # TrainingConfig as dict
'metrics': Dict[str, float], # Training metrics
'global_step': int, # (deprecated, kept for compat) same as step
'current_epoch': int, # (optional) current epoch number
'best_loss': float, # (optional) best validation loss
}
Legacy Checkpoint Format (old, from BaseTrainer):
{
'global_step': int,
'current_epoch': int,
'best_loss': float,
'model_state_dict': Dict[str, Tensor], # ← Note: uses '_dict' suffix
'optimizer_state_dict': Dict,
'config': Dict,
}
CheckpointManager.load() auto-detects the legacy format and migrates it to the
canonical format in memory; the checkpoint file itself is not modified.
"""
from pathlib import Path
from typing import Dict, Any, Optional
import torch
from taoTrain.config import TrainingConfig
class CheckpointManager:
"""Manage model checkpoints with versioning."""
def __init__(
self,
checkpoint_dir: str | Path,
keep_last_n: int = 3,
track_best: bool = True,
):
"""
Initialize checkpoint manager.
Args:
checkpoint_dir: Directory to save checkpoints
keep_last_n: Number of recent checkpoints to keep
track_best: Whether to track best model
"""
self.checkpoint_dir = Path(checkpoint_dir)
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
self.keep_last_n = keep_last_n
self.track_best = track_best
self.best_metric = None
self.best_metric_name = None
self.saved_checkpoints = []
def save(
self,
step: int,
model_state: Dict[str, Any],
optimizer_state: Optional[Dict[str, Any]] = None,
config: Optional[TrainingConfig] = None,
metrics: Optional[Dict[str, float]] = None,
is_best: bool = False,
) -> Path:
"""
Save a checkpoint.
Args:
step: Training step
model_state: Model state dict
optimizer_state: Optimizer state dict
config: Training config
metrics: Metrics dict
is_best: Whether this is the best model so far
Returns:
Path to saved checkpoint
"""
checkpoint = {
"step": step,
"model_state": model_state,
"optimizer_state": optimizer_state,
"config": config.to_dict() if config else None,
"metrics": metrics or {},
}
filename = f"checkpoint_step_{step:06d}.pt"
if is_best:
filename = "best_model.pt"
path = self.checkpoint_dir / filename
torch.save(checkpoint, path)
# Track saved checkpoints
if not is_best:
self.saved_checkpoints.append((step, path))
# Clean up old checkpoints
if len(self.saved_checkpoints) > self.keep_last_n:
_, old_path = self.saved_checkpoints.pop(0)
if old_path.exists():
old_path.unlink()
return path
def load(
self,
checkpoint_path: str | Path,
device: Optional[torch.device] = None,
) -> Dict[str, Any]:
"""
Load a checkpoint with backward-compatible format handling.
Auto-detects checkpoint format (canonical or legacy) and normalizes
to canonical format in-memory. Legacy checkpoints are migrated without
modifying the file.
Args:
checkpoint_path: Path to checkpoint
device: Device to load to
Returns:
Checkpoint dict in canonical format with 'model_state' key
"""
if device is None:
device = torch.device("cpu")
checkpoint = torch.load(checkpoint_path, map_location=device)
# Auto-detect and migrate legacy format to canonical format
checkpoint = self._normalize_checkpoint_format(checkpoint)
return checkpoint
def _normalize_checkpoint_format(self, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
"""
Normalize checkpoint to canonical format.
Detects if checkpoint is in legacy format (from BaseTrainer with 'model_state_dict')
and migrates it to canonical format (with 'model_state').
Args:
checkpoint: Raw checkpoint dict
Returns:
Normalized checkpoint dict with canonical keys
"""
# Check if this is a legacy checkpoint (has 'model_state_dict' but not 'model_state')
if "model_state_dict" in checkpoint and "model_state" not in checkpoint:
# Migrate legacy format to canonical
migrated = {
"step": checkpoint.get("global_step", 0),
"model_state": checkpoint["model_state_dict"],
"optimizer_state": checkpoint.get("optimizer_state_dict"),
"config": checkpoint.get("config"),
"metrics": {},
# Keep legacy keys for backward compatibility in code that uses them
"global_step": checkpoint.get("global_step", 0),
"current_epoch": checkpoint.get("current_epoch", 0),
"best_loss": checkpoint.get("best_loss", float('inf')),
}
print(f"\n✓ [CheckpointManager] Detected legacy checkpoint format. Auto-migrated to canonical format.")
return migrated
# Already in canonical format or unknown format
if "model_state" not in checkpoint:
# If neither format detected, ensure model_state is accessible
# (might be a raw state_dict)
print(f"\n⚠ [CheckpointManager] Checkpoint format unclear. Assuming raw state_dict format.")
checkpoint["model_state"] = checkpoint
return checkpoint
def get_latest(self) -> Optional[Path]:
"""Get path to latest checkpoint."""
if not self.saved_checkpoints:
return None
return self.saved_checkpoints[-1][1]
def get_best(self) -> Optional[Path]:
"""Get path to best checkpoint."""
best_path = self.checkpoint_dir / "best_model.pt"
if best_path.exists():
return best_path
return None
def list_checkpoints(self) -> list[Path]:
"""List all saved checkpoints."""
return sorted(self.checkpoint_dir.glob("checkpoint_step_*.pt"))