File size: 446 Bytes
d02bacd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
"""Load trusted local PyTorch state-dict checkpoints (CoS policy weights)."""
from __future__ import annotations

from pathlib import Path
from typing import Any


def load_state_dict(ckpt_path: str | Path, map_location: Any = "cpu") -> Any:
    """Load a PyTorch checkpoint from ``ckpt_path``, preferring the restricted loader.

    When the installed ``torch.load`` accepts a ``weights_only`` parameter, the
    checkpoint is loaded with ``weights_only=True``. Older torch releases that
    lack the parameter fall back to a plain ``torch.load`` call.

    The capability is detected from ``torch.load``'s signature rather than by
    catching ``TypeError`` around the load itself. A ``TypeError`` raised while
    deserializing a checkpoint therefore propagates to the caller. It is not
    silently retried with the unrestricted loader.

    Args:
        ckpt_path: Path (or path string) of the checkpoint file.
        map_location: Forwarded unchanged to ``torch.load``; defaults to ``"cpu"``.

    Returns:
        Whatever ``torch.load`` returns for the file (presumably a state dict
        for CoS policy weights; verify against callers).

    Raises:
        Any exception raised by ``torch.load`` (e.g. a missing file or a
        checkpoint rejected by the restricted loader) is propagated.
    """
    import inspect

    import torch

    path = Path(ckpt_path)
    # Torch >= 1.13 declares ``weights_only`` explicitly. Older versions only
    # have ``**pickle_load_args`` and would fail deep inside unpickling if it
    # were passed, so we check the signature up front instead.
    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(path, map_location=map_location, weights_only=True)
    return torch.load(path, map_location=map_location)