from __future__ import annotations from pathlib import Path from typing import Any import torch from sllm.utils import ensure_dir def save_checkpoint( path: str | Path, model: torch.nn.Module, optimizer: torch.optim.Optimizer | None, step: int, model_config: dict[str, Any], train_config: dict[str, Any], extra_state: dict[str, Any] | None = None, ) -> None: payload = { "step": step, "model": model.state_dict(), "optimizer": optimizer.state_dict() if optimizer is not None else None, "model_config": model_config, "train_config": train_config, "extra_state": extra_state or {}, } path = Path(path) ensure_dir(path.parent) torch.save(payload, path) def load_checkpoint(path: str | Path, map_location: str | torch.device = "cpu") -> dict[str, Any]: return torch.load(Path(path), map_location=map_location)