# 3d_model/ylff/utils/checkpoint_utils.py
# Azan
# Clean deployment build (Squashed)
# 7a87926
"""
Optimized checkpoint utilities for faster saving/loading.
Features:
- Async checkpoint saving (non-blocking)
- Compression (gzip) for smaller files
- Incremental checkpoints (only save changed weights)
- Checkpoint validation
"""
import gzip
import io
import logging
import os
import uuid
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
# Module-level logger used by the checkpoint helpers in this file.
logger = logging.getLogger(__name__)
# Global thread pool for async operations.
# save_checkpoint_async submits its background work here. A ThreadPoolExecutor
# runs at most `max_workers` tasks at a time and queues the rest, so no more
# than two asynchronous checkpoint saves execute concurrently.
# NOTE(review): the pool is never shut down explicitly in this file; confirm
# that relying on interpreter-exit joining of worker threads is acceptable for
# saves that are still pending when training finishes.
_executor = ThreadPoolExecutor(max_workers=2)
def save_checkpoint_async(
    checkpoint_data: Dict[str, Any],
    checkpoint_path: Path,
    compress: bool = True,
    validate: bool = True,
    snapshot: bool = True,
) -> Future:
    """
    Save checkpoint asynchronously.

    Disk I/O, gzip compression and validation run on the module-level thread
    pool. The file is first written to a uniquely named temporary file next to
    the target and then moved over the target with ``os.replace``. A save that
    fails or is interrupted while writing therefore never leaves a truncated
    file at the target path.

    Args:
        checkpoint_data: Checkpoint data dict
        checkpoint_path: Path to save checkpoint. With ``compress=True`` the
            file is written to ``f"{checkpoint_path}.gz"``.
        compress: Whether to compress checkpoint (gzip)
        validate: Whether to validate checkpoint after saving. The saved file
            is loaded back with ``map_location="cpu"``.
        snapshot: If True (default), serialize ``checkpoint_data`` into memory
            on the calling thread before returning. Tensors returned by
            ``state_dict()`` share storage with the live parameters. If
            serialization were deferred, in-place updates made after this call
            (e.g. the next optimizer step) could leak into the saved file.
            If False, serialization is also deferred to the background thread.
            That is fully non-blocking, but the caller must not mutate the data
            until the returned future completes.

    Returns:
        A ``concurrent.futures.Future`` that completes when the save finishes.
        Failures are logged with a traceback and stored in the future, so
        ``future.result()`` re-raises them. Callers that ignore the future
        still get the log entry.
    """
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    target_path = Path(f"{checkpoint_path}.gz") if compress else checkpoint_path

    serialized: Optional[memoryview] = None
    if snapshot:
        buffer = io.BytesIO()
        try:
            torch.save(checkpoint_data, buffer)
        except Exception as e:
            # Keep the contract that save errors never propagate to the caller
            # directly: report them through the log and the returned future.
            logger.exception(f"Error saving checkpoint asynchronously: {e}")
            failed: Future = Future()
            failed.set_exception(e)
            return failed
        serialized = buffer.getbuffer()

    def _save() -> None:
        # Unique temp name so two in-flight saves to the same target (the pool
        # has more than one worker) never share a temporary file.
        tmp_path = target_path.with_name(f".{target_path.name}.{uuid.uuid4().hex}.tmp")
        open_for_write = gzip.open if compress else open
        try:
            with open_for_write(tmp_path, "wb") as f:
                if serialized is None:
                    torch.save(checkpoint_data, f)
                else:
                    f.write(serialized)
            os.replace(tmp_path, target_path)
            if compress:
                logger.debug(f"Saved compressed checkpoint to {checkpoint_path}.gz")
            else:
                logger.debug(f"Saved checkpoint to {checkpoint_path}")

            if validate:
                # Load onto CPU so validation never allocates accelerator memory
                # from this background thread.
                if compress:
                    # GzipFile only emulates seek(); inflate once into memory
                    # because torch.load needs random access.
                    with gzip.open(target_path, "rb") as f:
                        torch.load(io.BytesIO(f.read()), map_location="cpu")
                else:
                    torch.load(target_path, map_location="cpu")
                logger.debug(f"Validated checkpoint: {checkpoint_path}")
        except Exception as e:
            logger.exception(f"Error saving checkpoint asynchronously: {e}")
            raise
        finally:
            # Only present if writing or the rename failed.
            if tmp_path.exists():
                tmp_path.unlink()

    # Submit to thread pool (non-blocking).
    return _executor.submit(_save)
class _ByteCountingWriter:
    """Write-only stream wrapper that forwards data and counts the bytes passed through."""

    def __init__(self, stream: Any) -> None:
        self._stream = stream
        self.count = 0

    def write(self, data: Any) -> int:
        nbytes = memoryview(data).nbytes
        self._stream.write(data)
        self.count += nbytes
        return nbytes

    def flush(self) -> None:
        self._stream.flush()


def save_checkpoint_compressed(
    checkpoint_data: Dict[str, Any],
    checkpoint_path: Path,
    compression_level: int = 6,
) -> Path:
    """
    Save checkpoint with compression.

    The logged compression ratio compares the size of the ``.gz`` file with
    the number of bytes ``torch.save`` produced. Those bytes are counted while
    they stream into the gzip writer, so no uncompressed copy is written to
    disk or buffered in memory.

    Args:
        checkpoint_data: Checkpoint data dict
        checkpoint_path: Path to save checkpoint
        compression_level: Gzip compression level (0-9)

    Returns:
        Path to saved checkpoint (``checkpoint_path`` with ``.gz`` appended to
        its suffix)
    """
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    compressed_path = checkpoint_path.with_suffix(checkpoint_path.suffix + ".gz")

    with gzip.open(compressed_path, "wb", compresslevel=compression_level) as gz_file:
        counter = _ByteCountingWriter(gz_file)
        torch.save(checkpoint_data, counter)

    original_size = counter.count
    compressed_size = compressed_path.stat().st_size
    compression_ratio = (1 - compressed_size / original_size) * 100 if original_size > 0 else 0
    logger.info(
        f"Saved compressed checkpoint: {compressed_path} "
        f"({compressed_size / 1024 / 1024:.2f} MB, "
        f"{compression_ratio:.1f}% compression)"
    )
    return compressed_path
def load_checkpoint_compressed(checkpoint_path: Path) -> Dict[str, Any]:
    """
    Load compressed checkpoint.

    Prefers the gzip file and falls back to the uncompressed
    ``checkpoint_path``. The gzip file is ``checkpoint_path`` itself if it
    ends in ``.gz``; otherwise it is ``checkpoint_path`` with ``.gz`` appended.
    Tensors are mapped to CPU.

    Args:
        checkpoint_path: Path to checkpoint (with or without .gz extension)

    Returns:
        Checkpoint data dict

    Raises:
        FileNotFoundError: If neither the compressed nor the uncompressed
            file exists.
    """
    if checkpoint_path.suffix == ".gz":
        compressed_path = checkpoint_path
    else:
        compressed_path = checkpoint_path.with_suffix(checkpoint_path.suffix + ".gz")

    # Try compressed first.
    if compressed_path.exists():
        # torch.load needs a seekable file. GzipFile only emulates seek(), and
        # a backward seek re-decompresses from the start. Inflate the whole
        # stream into memory once and load from that instead.
        with gzip.open(compressed_path, "rb") as gz_file:
            buffer = io.BytesIO(gz_file.read())
        checkpoint = torch.load(buffer, map_location="cpu")
        logger.info(f"Loaded compressed checkpoint from {compressed_path}")
        return checkpoint

    # Fallback to uncompressed.
    if checkpoint_path.exists():
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        logger.info(f"Loaded checkpoint from {checkpoint_path}")
        return checkpoint

    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
def _full_checkpoint_data(
    model: nn.Module, optimizer, scheduler, epoch: int, loss: float
) -> Dict[str, Any]:
    """Build a self-contained (``is_full=True``) checkpoint dict."""
    return {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss": loss,
        "is_full": True,
    }


def _state_entry_changed(current: Any, base: Any) -> bool:
    """Return True if a model state entry differs from the base entry and must be stored."""
    if not (isinstance(current, torch.Tensor) and isinstance(base, torch.Tensor)):
        # Non-tensor entries cannot go through torch.equal; store them
        # unconditionally rather than fail the save.
        return True
    if current.shape != base.shape or current.dtype != base.dtype:
        return True
    # The base checkpoint is loaded onto CPU. Compare on CPU so parameters that
    # live on an accelerator do not trigger torch.equal's device-mismatch error.
    return not torch.equal(current.detach().cpu(), base.detach().cpu())


def save_incremental_checkpoint(
    model: nn.Module,
    optimizer,
    scheduler,
    epoch: int,
    loss: float,
    checkpoint_path: Path,
    base_checkpoint_path: Optional[Path] = None,
    save_full_every: int = 10,
) -> Path:
    """
    Save incremental checkpoint (only changed weights).

    A full checkpoint is written when no base checkpoint is given, when the
    base file does not exist, or on every ``save_full_every``-th epoch.

    Otherwise only the model state entries that are new or differ from the
    base checkpoint's ``model_state_dict`` are stored. The full optimizer and
    scheduler state and the base path (under ``"base_checkpoint"``) are stored
    alongside them.

    Args:
        model: Model to save
        optimizer: Optimizer whose ``state_dict()`` is saved
        scheduler: Scheduler whose ``state_dict()`` is saved
        epoch: Current epoch
        loss: Current loss
        checkpoint_path: Path to save checkpoint
        base_checkpoint_path: Path to base checkpoint (for diff)
        save_full_every: Save a full checkpoint whenever
            ``epoch % save_full_every == 0``. Values <= 0 disable periodic full
            saves (instead of raising ``ZeroDivisionError``).

    Returns:
        Path to saved checkpoint
    """
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

    periodic_full = save_full_every > 0 and epoch % save_full_every == 0
    if base_checkpoint_path is None or periodic_full or not base_checkpoint_path.exists():
        torch.save(
            _full_checkpoint_data(model, optimizer, scheduler, epoch, loss),
            checkpoint_path,
        )
        logger.info(f"Saved full checkpoint to {checkpoint_path}")
        return checkpoint_path

    base_checkpoint = torch.load(base_checkpoint_path, map_location="cpu")
    base_state = base_checkpoint.get("model_state_dict", {})
    current_state = model.state_dict()

    # Every current key ends up either in the diff or (unchanged) in the base.
    # Overlaying the diff on the base's state therefore always restores all
    # keys, even when the base is itself an incremental checkpoint.
    diff_state = {
        key: value
        for key, value in current_state.items()
        if key not in base_state or _state_entry_changed(value, base_state[key])
    }

    checkpoint_data = {
        "epoch": epoch,
        "model_state_dict": diff_state,  # Only differences
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss": loss,
        "is_full": False,
        "base_checkpoint": str(base_checkpoint_path),
    }
    torch.save(checkpoint_data, checkpoint_path)
    logger.info(
        f"Saved incremental checkpoint to {checkpoint_path} "
        f"({len(diff_state)}/{len(current_state)} parameters changed)"
    )
    return checkpoint_path
def load_incremental_checkpoint(
    model: nn.Module,
    checkpoint_path: Path,
    device: str = "cpu",
) -> Dict[str, Any]:
    """
    Load incremental checkpoint (applies diff to base).

    Checkpoints without an ``is_full`` flag are treated as full.

    For an incremental checkpoint, the base checkpoint is read from the path
    stored under ``"base_checkpoint"``. Its ``model_state_dict`` is overlaid
    with the stored diff and loaded with a strict ``load_state_dict``.

    If the base reference is missing, empty, or does not point to a file, a
    warning is logged and the diff alone is loaded with ``strict=False``.

    Args:
        model: Model to load weights into
        checkpoint_path: Path to incremental checkpoint
        device: Device to load on (passed as ``map_location``)

    Returns:
        The checkpoint dict read from ``checkpoint_path``. For incremental
        checkpoints its ``model_state_dict`` still holds only the diff.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)

    if checkpoint.get("is_full", True):
        # Full checkpoint
        model.load_state_dict(checkpoint["model_state_dict"])
        logger.info(f"Loaded full checkpoint from {checkpoint_path}")
        return checkpoint

    # Incremental checkpoint - need the base first. A missing or empty
    # reference must be rejected explicitly. Path("") is Path("."), the working
    # directory, which exists and would otherwise be handed to torch.load.
    base_ref = checkpoint.get("base_checkpoint") or ""
    base_checkpoint_path = Path(base_ref)
    if not base_ref or not base_checkpoint_path.is_file():
        logger.warning(
            f"Base checkpoint not found: {base_checkpoint_path}. "
            "Loading incremental checkpoint as-is."
        )
        model.load_state_dict(checkpoint["model_state_dict"], strict=False)
        return checkpoint

    base_checkpoint = torch.load(base_checkpoint_path, map_location=device)
    base_state = base_checkpoint.get("model_state_dict", {})

    # Apply diff on top of a copy of the base state.
    full_state = base_state.copy()
    full_state.update(checkpoint["model_state_dict"])
    model.load_state_dict(full_state)

    logger.info(
        f"Loaded incremental checkpoint from {checkpoint_path} "
        f"(applied to base: {base_checkpoint_path})"
    )
    return checkpoint
def validate_checkpoint(checkpoint_path: Path) -> bool:
    """
    Validate checkpoint file integrity.

    The file is loaded onto CPU; paths ending in ``.gz`` are decompressed
    first. It is valid if it deserializes to a dict that contains ``"epoch"``
    and a dict-valued ``"model_state_dict"``.

    Args:
        checkpoint_path: Path to checkpoint

    Returns:
        True if valid, False otherwise. Every failure is logged and never
        raised.
    """
    try:
        if checkpoint_path.suffix == ".gz":
            # torch.load needs random access; GzipFile only emulates seek(),
            # so inflate the stream into memory once.
            with gzip.open(checkpoint_path, "rb") as gz_file:
                buffer = io.BytesIO(gz_file.read())
            checkpoint = torch.load(buffer, map_location="cpu")
        else:
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
    except Exception as e:
        # Deliberate boundary: any load failure simply means "invalid".
        logger.error(f"Checkpoint validation failed: {e}")
        return False

    if not isinstance(checkpoint, dict):
        logger.error(f"Checkpoint is not a dict: {type(checkpoint).__name__}")
        return False

    # Check required keys, reporting only the ones actually absent.
    required_keys = ["epoch", "model_state_dict"]
    missing_keys = [key for key in required_keys if key not in checkpoint]
    if missing_keys:
        logger.error(f"Checkpoint missing required keys: {missing_keys}")
        return False

    # Check state dict is valid
    if not isinstance(checkpoint["model_state_dict"], dict):
        logger.error("Checkpoint model_state_dict is not a dict")
        return False

    logger.info(f"Checkpoint validated: {checkpoint_path}")
    return True
def get_checkpoint_size(checkpoint_path: Path) -> Dict[str, float]:
    """
    Get checkpoint file size information.

    Inspects ``checkpoint_path`` and its gzip sibling, which is
    ``checkpoint_path`` with ``.gz`` appended to the suffix. Keys are only
    present for files that exist:

    - ``compressed_bytes`` / ``compressed_mb`` for the gzip sibling.
    - ``uncompressed_bytes`` / ``uncompressed_mb`` for ``checkpoint_path``.
    - ``compression_ratio``: percentage of space saved. It is only present
      when both files exist and the uncompressed file is non-empty.

    Args:
        checkpoint_path: Path to checkpoint

    Returns:
        Dict with size information (bytes, mb, ratio)
    """
    sizes: Dict[str, float] = {}

    # Check compressed version
    compressed_path = checkpoint_path.with_suffix(checkpoint_path.suffix + ".gz")
    if compressed_path.exists():
        sizes["compressed_bytes"] = compressed_path.stat().st_size
        sizes["compressed_mb"] = sizes["compressed_bytes"] / 1024 / 1024

    # Check uncompressed version
    if checkpoint_path.exists():
        uncompressed_bytes = checkpoint_path.stat().st_size
        sizes["uncompressed_bytes"] = uncompressed_bytes
        sizes["uncompressed_mb"] = uncompressed_bytes / 1024 / 1024
        # The ratio is undefined for an empty uncompressed file; omit it
        # rather than raise ZeroDivisionError.
        if "compressed_bytes" in sizes and uncompressed_bytes > 0:
            sizes["compression_ratio"] = (
                1 - sizes["compressed_bytes"] / uncompressed_bytes
            ) * 100

    return sizes