AutoDataLab2.0 / ceo_brief_env /checkpoint_load.py
uchihamadara1816's picture
Upload 172 files
d02bacd verified
raw
history blame contribute delete
446 Bytes
"""Load trusted local PyTorch state-dict checkpoints (CoS policy weights)."""
from __future__ import annotations
from pathlib import Path
from typing import Any
def load_state_dict(ckpt_path: Path | str, map_location: str | Any = "cpu") -> Any:
    """Load a local PyTorch checkpoint, preferring the restricted unpickler.

    The checkpoint is first loaded with ``torch.load(..., weights_only=True)``.
    Only if the installed torch is too old to accept the ``weights_only``
    keyword does this fall back to a plain ``torch.load``, which uses the
    full pickle machinery and is therefore only appropriate for trusted files.

    Args:
        ckpt_path: Path (or path string) of the checkpoint file.
        map_location: Forwarded unchanged to ``torch.load`` (e.g. ``"cpu"``).

    Returns:
        Whatever object ``torch.load`` deserializes from the file.

    Raises:
        TypeError: Re-raised when ``torch.load`` fails with a ``TypeError``
            that is not about the ``weights_only`` keyword being unsupported.
        Any other exception raised by ``torch.load`` propagates unchanged.
    """
    # Deferred import: torch is only required when a checkpoint is loaded.
    import torch

    path = Path(ckpt_path)
    try:
        return torch.load(path, map_location=map_location, weights_only=True)
    except TypeError as exc:
        # Fall back only for torch builds that predate ``weights_only``
        # ("unexpected keyword argument 'weights_only'"). Any other TypeError
        # is a real failure. Retrying it through the unrestricted pickle
        # loader would hide the error and bypass the safe-loading path.
        if "weights_only" not in str(exc):
            raise
    return torch.load(path, map_location=map_location)