Spaces:
Running on Zero
Running on Zero
| import os | |
| import torch | |
| from typing import Dict, Any | |
| from longstream.models.longstream import LongStream | |
| from longstream.utils.hub import resolve_checkpoint_path | |
| class LongStreamModel(torch.nn.Module): | |
| def __init__(self, cfg: Dict[str, Any] | None): | |
| super().__init__() | |
| cfg = cfg or {} | |
| ckpt_path = resolve_checkpoint_path( | |
| cfg.get("checkpoint", None), cfg.get("hf", None) | |
| ) | |
| stream_cfg = dict(cfg.get("longstream_cfg", {}) or {}) | |
| rel_pose_cfg = stream_cfg.pop( | |
| "rel_pose_head_cfg", cfg.get("rel_pose_head_cfg", None) | |
| ) | |
| use_rel_pose_head = bool(stream_cfg.pop("use_rel_pose_head", False)) | |
| if use_rel_pose_head and rel_pose_cfg is not None: | |
| stream_cfg["rel_pose_head_cfg"] = rel_pose_cfg | |
| self.longstream = LongStream(**stream_cfg) | |
| if ckpt_path: | |
| self.load_checkpoint(ckpt_path, strict=bool(cfg.get("strict_load", True))) | |
| def load_checkpoint(self, ckpt_path: str, strict: bool = True): | |
| if not os.path.exists(ckpt_path): | |
| raise FileNotFoundError(ckpt_path) | |
| ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) | |
| if isinstance(ckpt, dict): | |
| if "model" in ckpt and isinstance(ckpt["model"], dict): | |
| state = ckpt["model"] | |
| elif "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict): | |
| state = ckpt["state_dict"] | |
| else: | |
| state = ckpt | |
| else: | |
| raise TypeError("Unsupported checkpoint format") | |
| if state: | |
| first_key = next(iter(state.keys())) | |
| if first_key.startswith("sampler.longstream."): | |
| state = {k.replace("sampler.", "", 1): v for k, v in state.items()} | |
| missing, unexpected = self.load_state_dict(state, strict=False) | |
| if missing or unexpected: | |
| msg = f"checkpoint mismatch: missing={len(missing)} unexpected={len(unexpected)}" | |
| if strict: | |
| raise RuntimeError(msg) | |
| print(msg) | |
| def forward(self, *args, **kwargs): | |
| return self.longstream(*args, **kwargs) | |
| def aggregator(self): | |
| return self.longstream.aggregator | |
| def camera_head(self): | |
| return getattr(self.longstream, "camera_head", None) | |
| def rel_pose_head(self): | |
| return getattr(self.longstream, "rel_pose_head", None) | |