Spaces:
Running on Zero
Running on Zero
File size: 2,463 Bytes
e340a84 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 | import os
import torch
from typing import Dict, Any
from longstream.models.longstream import LongStream
from longstream.utils.hub import resolve_checkpoint_path
class LongStreamModel(torch.nn.Module):
def __init__(self, cfg: Dict[str, Any] | None):
super().__init__()
cfg = cfg or {}
ckpt_path = resolve_checkpoint_path(
cfg.get("checkpoint", None), cfg.get("hf", None)
)
stream_cfg = dict(cfg.get("longstream_cfg", {}) or {})
rel_pose_cfg = stream_cfg.pop(
"rel_pose_head_cfg", cfg.get("rel_pose_head_cfg", None)
)
use_rel_pose_head = bool(stream_cfg.pop("use_rel_pose_head", False))
if use_rel_pose_head and rel_pose_cfg is not None:
stream_cfg["rel_pose_head_cfg"] = rel_pose_cfg
self.longstream = LongStream(**stream_cfg)
if ckpt_path:
self.load_checkpoint(ckpt_path, strict=bool(cfg.get("strict_load", True)))
def load_checkpoint(self, ckpt_path: str, strict: bool = True):
if not os.path.exists(ckpt_path):
raise FileNotFoundError(ckpt_path)
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
if isinstance(ckpt, dict):
if "model" in ckpt and isinstance(ckpt["model"], dict):
state = ckpt["model"]
elif "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
state = ckpt["state_dict"]
else:
state = ckpt
else:
raise TypeError("Unsupported checkpoint format")
if state:
first_key = next(iter(state.keys()))
if first_key.startswith("sampler.longstream."):
state = {k.replace("sampler.", "", 1): v for k, v in state.items()}
missing, unexpected = self.load_state_dict(state, strict=False)
if missing or unexpected:
msg = f"checkpoint mismatch: missing={len(missing)} unexpected={len(unexpected)}"
if strict:
raise RuntimeError(msg)
print(msg)
def forward(self, *args, **kwargs):
return self.longstream(*args, **kwargs)
@property
def aggregator(self):
return self.longstream.aggregator
@property
def camera_head(self):
return getattr(self.longstream, "camera_head", None)
@property
def rel_pose_head(self):
return getattr(self.longstream, "rel_pose_head", None)
|