|
|
""" |
|
|
Optimized checkpoint utilities for faster saving/loading. |
|
|
|
|
|
Features: |
|
|
- Async checkpoint saving (non-blocking) |
|
|
- Compression (gzip) for smaller files |
|
|
- Incremental checkpoints (only save changed weights) |
|
|
- Checkpoint validation |
|
|
""" |
|
|
|
|
|
import copy
import gzip
import logging
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
|
|
|
|
|
logger = logging.getLogger(__name__)


# Background pool shared by every save_checkpoint_async call. Two workers let
# at most two checkpoint writes run at once while the caller keeps going.
_executor = ThreadPoolExecutor(max_workers=2)


def save_checkpoint_async(
    checkpoint_data: Dict[str, Any],
    checkpoint_path: Path,
    compress: bool = True,
    validate: bool = True,
    snapshot: bool = True,
) -> Future:
    """
    Save checkpoint asynchronously (non-blocking).

    The write (and optional read-back validation) runs on a worker thread of
    the module-level ``_executor``; this function returns once the job is
    queued.

    Args:
        checkpoint_data: Checkpoint data dict
        checkpoint_path: Path to save checkpoint. With ``compress=True`` the
            file actually written is ``f"{checkpoint_path}.gz"``.
        compress: Whether to compress checkpoint (gzip)
        validate: Whether to validate checkpoint after saving by loading it
            back (onto the CPU)
        snapshot: Deep-copy ``checkpoint_data`` before returning (default).
            Tensors returned by ``model.state_dict()`` /
            ``optimizer.state_dict()`` share storage with the live
            parameters, so without a copy any update made while the
            background write is running ends up in (or tears) the file. The
            copy lives on the same device as the source tensors. Pass False
            only if the data is not mutated until the returned future is done.

    Returns:
        Future of the background job. Errors are logged inside the job and not
        re-raised, so ``future.result()`` always returns None; call it to
        wait for the write to finish (e.g. before exiting or re-saving to the
        same path).
    """
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    target = Path(f"{checkpoint_path}.gz") if compress else checkpoint_path

    # Snapshot on the caller's thread so the file reflects the state at call
    # time rather than whatever training has done by the time the worker runs.
    data = copy.deepcopy(checkpoint_data) if snapshot else checkpoint_data

    def _save() -> None:
        try:
            if compress:
                with gzip.open(target, "wb") as f:
                    torch.save(data, f)
                logger.debug(f"Saved compressed checkpoint to {checkpoint_path}.gz")
            else:
                torch.save(data, target)
                logger.debug(f"Saved checkpoint to {checkpoint_path}")

            if validate:
                # Load onto the CPU: restoring tensors to their original
                # device would duplicate the whole checkpoint there.
                if compress:
                    with gzip.open(target, "rb") as f:
                        torch.load(f, map_location="cpu")
                else:
                    torch.load(target, map_location="cpu")
                logger.debug(f"Validated checkpoint: {checkpoint_path}")

        except Exception as e:
            # Best-effort background job: keep the traceback in the log, since
            # nobody is guaranteed to inspect the returned future.
            logger.exception(f"Error saving checkpoint asynchronously: {e}")

    return _executor.submit(_save)
|
|
|
|
|
|
|
|
def save_checkpoint_compressed(
    checkpoint_data: Dict[str, Any],
    checkpoint_path: Path,
    compression_level: int = 6,
) -> Path:
    """
    Save checkpoint with compression.

    ``checkpoint_data`` is serialized with ``torch.save`` directly into a gzip
    stream, and the achieved compression is logged.

    Args:
        checkpoint_data: Checkpoint data dict
        checkpoint_path: Path to save checkpoint
        compression_level: Gzip compression level (0-9)

    Returns:
        Path to saved checkpoint: ``checkpoint_path`` with ``.gz`` appended to
        its suffix (e.g. ``model.pt`` -> ``model.pt.gz``).
    """
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    compressed_path = checkpoint_path.with_suffix(checkpoint_path.suffix + ".gz")

    with gzip.open(compressed_path, "wb", compresslevel=compression_level) as f:
        torch.save(checkpoint_data, f)
        # In write mode GzipFile.tell() reports the number of *uncompressed*
        # bytes written, i.e. the size of the serialized checkpoint. (Stat-ing
        # an uncompressed sibling file does not work: this function never
        # writes one, so the ratio came out as 0% or against a stale file.)
        original_size = f.tell()

    compressed_size = compressed_path.stat().st_size
    compression_ratio = (1 - compressed_size / original_size) * 100 if original_size > 0 else 0

    logger.info(
        f"Saved compressed checkpoint: {compressed_path} "
        f"({compressed_size / 1024 / 1024:.2f} MB, "
        f"{compression_ratio:.1f}% compression)"
    )

    return compressed_path
|
|
|
|
|
|
|
|
def load_checkpoint_compressed(checkpoint_path: Path) -> Dict[str, Any]:
    """
    Load a checkpoint, preferring its gzip-compressed form.

    The ``.gz`` variant (``checkpoint_path`` itself if it already ends in
    ``.gz``, otherwise ``checkpoint_path`` with ``.gz`` appended) is tried
    first; if it is absent, ``checkpoint_path`` is loaded as a plain file.
    Tensors are always mapped onto the CPU.

    Args:
        checkpoint_path: Path to checkpoint (with or without .gz extension)

    Returns:
        Checkpoint data dict

    Raises:
        FileNotFoundError: neither the compressed nor the plain file exists.
    """
    gz_path = (
        checkpoint_path
        if checkpoint_path.suffix == ".gz"
        else checkpoint_path.with_suffix(checkpoint_path.suffix + ".gz")
    )

    if gz_path.exists():
        with gzip.open(gz_path, "rb") as fh:
            data = torch.load(fh, map_location="cpu")
        logger.info(f"Loaded compressed checkpoint from {gz_path}")
        return data

    # Fall back to the uncompressed file.
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    data = torch.load(checkpoint_path, map_location="cpu")
    logger.info(f"Loaded checkpoint from {checkpoint_path}")
    return data
|
|
|
|
|
|
|
|
def _full_checkpoint_data(model: nn.Module, optimizer, scheduler, epoch: int, loss: float) -> Dict[str, Any]:
    """Build the dict written for a full (self-contained) checkpoint."""
    return {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss": loss,
        "is_full": True,
    }


def _state_value_unchanged(current: Any, base: Any) -> bool:
    """Return True only when ``current`` is a tensor provably equal to ``base``."""
    # Anything that is not a tensor pair (e.g. module "extra state") is
    # conservatively treated as changed and stored in the diff.
    if not (isinstance(current, torch.Tensor) and isinstance(base, torch.Tensor)):
        return False
    if current.shape != base.shape or current.dtype != base.dtype:
        return False
    # The base state is loaded on the CPU while the live model may sit on an
    # accelerator; torch.equal needs both operands on the same device.
    return torch.equal(current.detach().cpu(), base)


def save_incremental_checkpoint(
    model: nn.Module,
    optimizer,
    scheduler,
    epoch: int,
    loss: float,
    checkpoint_path: Path,
    base_checkpoint_path: Optional[Path] = None,
    save_full_every: int = 10,
) -> Path:
    """
    Save incremental checkpoint (only changed weights).

    A full checkpoint is written when no base is given, when ``epoch`` is a
    multiple of ``save_full_every``, or when the base file does not exist.
    Otherwise only the model entries that differ from the base checkpoint's
    ``model_state_dict`` are stored, together with the base path. The
    optimizer and scheduler states are always stored in full.

    Args:
        model: Model to save
        optimizer: Optimizer state
        scheduler: Scheduler state
        epoch: Current epoch
        loss: Current loss
        checkpoint_path: Path to save checkpoint
        base_checkpoint_path: Path to base checkpoint (for diff). Its path is
            recorded in the incremental file, so it must stay in place for the
            incremental checkpoint to be loadable.
        save_full_every: Save full checkpoint every N epochs

    Returns:
        Path to saved checkpoint

    Raises:
        ZeroDivisionError: ``save_full_every`` is 0 while a base path is given.
    """
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

    force_full = base_checkpoint_path is None or epoch % save_full_every == 0
    if force_full or not base_checkpoint_path.exists():
        torch.save(_full_checkpoint_data(model, optimizer, scheduler, epoch, loss), checkpoint_path)
        logger.info(f"Saved full checkpoint to {checkpoint_path}")
        return checkpoint_path

    base_checkpoint = torch.load(base_checkpoint_path, map_location="cpu")
    base_state = base_checkpoint.get("model_state_dict", {})
    current_state = model.state_dict()

    # Keep entries that are new or whose value differs from the base.
    diff_state = {
        key: value
        for key, value in current_state.items()
        if key not in base_state or not _state_value_unchanged(value, base_state[key])
    }
    # The base copy is no longer needed; release it before serializing.
    del base_checkpoint, base_state

    checkpoint_data = {
        "epoch": epoch,
        "model_state_dict": diff_state,
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss": loss,
        "is_full": False,
        "base_checkpoint": str(base_checkpoint_path),
    }
    torch.save(checkpoint_data, checkpoint_path)
    logger.info(
        f"Saved incremental checkpoint to {checkpoint_path} "
        f"({len(diff_state)}/{len(current_state)} parameters changed)"
    )
    return checkpoint_path
|
|
|
|
|
|
|
|
def load_incremental_checkpoint(
    model: nn.Module,
    checkpoint_path: Path,
    device: str = "cpu",
) -> Dict[str, Any]:
    """
    Load incremental checkpoint (applies diff to base).

    Full checkpoints (``is_full`` true, or the flag absent) are loaded
    directly. For incremental ones, the referenced base checkpoint's
    ``model_state_dict`` is overlaid with the stored diff and loaded strictly.
    If the base cannot be found, only the diff is loaded (``strict=False``)
    and a warning is logged.

    Args:
        model: Model to load weights into
        checkpoint_path: Path to incremental checkpoint
        device: Device to load on (used as ``map_location``)

    Returns:
        Checkpoint data dict, as stored on disk. For an incremental checkpoint
        its ``model_state_dict`` is the diff, not the merged state.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Files without an "is_full" flag (e.g. written by plain torch.save)
    # are treated as full checkpoints.
    if checkpoint.get("is_full", True):
        model.load_state_dict(checkpoint["model_state_dict"])
        logger.info(f"Loaded full checkpoint from {checkpoint_path}")
        return checkpoint

    # Path("") is Path("."), which always exists; treat a missing or empty
    # reference as "no base" instead of trying to torch.load a directory.
    base_ref = checkpoint.get("base_checkpoint")
    base_checkpoint_path = Path(base_ref) if base_ref else None
    if base_checkpoint_path is None or not base_checkpoint_path.is_file():
        logger.warning(
            f"Base checkpoint not found: {base_checkpoint_path}. "
            "Loading incremental checkpoint as-is."
        )
        model.load_state_dict(checkpoint["model_state_dict"], strict=False)
        return checkpoint

    base_checkpoint = torch.load(base_checkpoint_path, map_location=device)
    base_state = base_checkpoint.get("model_state_dict", {})

    # copy.copy (unlike OrderedDict.copy) keeps instance attributes such as
    # the _metadata that nn.Module.state_dict() attaches and load_state_dict reads.
    full_state = copy.copy(base_state)
    full_state.update(checkpoint["model_state_dict"])

    model.load_state_dict(full_state)
    logger.info(
        f"Loaded incremental checkpoint from {checkpoint_path} "
        f"(applied to base: {base_checkpoint_path})"
    )

    return checkpoint
|
|
|
|
|
|
|
|
def validate_checkpoint(checkpoint_path: Path) -> bool:
    """
    Validate checkpoint file integrity.

    Loads the file onto the CPU (through gzip when the suffix is ``.gz``) and
    checks that it is a dict with an ``epoch`` entry and a dict-valued
    ``model_state_dict`` entry. Problems are logged, never raised.

    Args:
        checkpoint_path: Path to checkpoint

    Returns:
        True if valid, False otherwise
    """
    try:
        if checkpoint_path.suffix == ".gz":
            with gzip.open(checkpoint_path, "rb") as f:
                checkpoint = torch.load(f, map_location="cpu")
        else:
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
    except Exception as e:
        # Missing, truncated or otherwise unloadable file: report, don't raise.
        logger.error(f"Checkpoint validation failed: {e}")
        return False

    # A loadable file holding e.g. a bare tensor or list is not a checkpoint.
    if not isinstance(checkpoint, dict):
        logger.error(f"Checkpoint is not a dict: {checkpoint_path}")
        return False

    required_keys = ["epoch", "model_state_dict"]
    missing_keys = [key for key in required_keys if key not in checkpoint]
    if missing_keys:
        logger.error(f"Checkpoint missing required keys: {missing_keys}")
        return False

    if not isinstance(checkpoint["model_state_dict"], dict):
        logger.error("Checkpoint model_state_dict is not a dict")
        return False

    logger.info(f"Checkpoint validated: {checkpoint_path}")
    return True
|
|
|
|
|
|
|
|
def get_checkpoint_size(checkpoint_path: Path) -> Dict[str, float]:
    """
    Get checkpoint file size information.

    ``checkpoint_path`` may name either the plain file (``model.pt``) or its
    gzip sibling (``model.pt.gz``). Both files are looked up, and only those
    that exist are reported.

    Args:
        checkpoint_path: Path to checkpoint (with or without .gz extension)

    Returns:
        Dict that may contain ``compressed_bytes``/``compressed_mb`` (gzip
        file), ``uncompressed_bytes``/``uncompressed_mb`` (plain file) and
        ``compression_ratio`` (percent of space saved). The ratio is present
        only when both files exist and the plain file is non-empty. The dict
        is empty if neither file exists.
    """
    if checkpoint_path.suffix == ".gz":
        compressed_path = checkpoint_path
        uncompressed_path = checkpoint_path.with_suffix("")
    else:
        compressed_path = checkpoint_path.with_suffix(checkpoint_path.suffix + ".gz")
        uncompressed_path = checkpoint_path

    sizes: Dict[str, float] = {}

    if compressed_path.exists():
        sizes["compressed_bytes"] = compressed_path.stat().st_size
        sizes["compressed_mb"] = sizes["compressed_bytes"] / 1024 / 1024

    if uncompressed_path.exists():
        sizes["uncompressed_bytes"] = uncompressed_path.stat().st_size
        sizes["uncompressed_mb"] = sizes["uncompressed_bytes"] / 1024 / 1024

        # Skip the ratio for an empty plain file (it would divide by zero).
        if "compressed_bytes" in sizes and sizes["uncompressed_bytes"] > 0:
            sizes["compression_ratio"] = (
                1 - sizes["compressed_bytes"] / sizes["uncompressed_bytes"]
            ) * 100

    return sizes
|
|
|