from __future__ import annotations

import os
from typing import Any, Dict

import torch

from longstream.models.longstream import LongStream
from longstream.utils.hub import resolve_checkpoint_path


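class LongStreamModel(torch.nn.Module):
    """Thin wrapper that builds a LongStream model from a config dict.

    Recognized ``cfg`` keys: ``checkpoint`` and ``hf`` (passed to
    ``resolve_checkpoint_path``), ``longstream_cfg`` (keyword arguments for
    ``LongStream``), ``rel_pose_head_cfg`` (fallback when not set inside
    ``longstream_cfg``), and ``strict_load`` (default True).
    """

    def __init__(self, cfg: Dict[str, Any] | None):
        super().__init__()
        cfg = cfg or {}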

        ckpt_path = resolve_checkpoint_path(
            cfg.get("checkpoint", None), cfg.get("hf", None)
        )

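        # Copy so the caller's config is not mutated by the pops below.
        stream_cfg = dict(cfg.get("longstream_cfg", {}) or {})
        # Prefer the head config nested in longstream_cfg; fall back to the
        # top-level key.
        rel_pose_cfg = stream_cfg.pop(
            "rel_pose_head_cfg", cfg.get("rel_pose_head_cfg", None)
        )
        # use_rel_pose_head is consumed here and not passed to LongStream; the
        # head config is forwarded only when the flag is set.
        use_rel_pose_head = bool(stream_cfg.pop("use_rel_pose_head", False))
        if use_rel_pose_head and rel_pose_cfg is not None:
            stream_cfg["rel_pose_head_cfg"] = rel_pose_cfg
        self.longstream = LongStream(**stream_cfg)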

        if ckpt_path:
            self.load_checkpoint(ckpt_path, strict=bool(cfg.get("strict_load", True)))

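    def load_checkpoint(self, ckpt_path: str, strict: bool = True):
        """Load weights from ``ckpt_path`` into this module.

        Accepts a raw state dict, or a dict that nests one under "model" or
        "state_dict". With ``strict=True`` any missing or unexpected key
        raises RuntimeError; otherwise the mismatch is only printed.
        """
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(ckpt_path)
        # weights_only=False unpickles arbitrary Python objects, so only load
        # checkpoints from a trusted source.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        if not isinstance(ckpt, dict):
            raise TypeError(f"Unsupported checkpoint format: {type(ckpt).__name__}")

        # Unwrap the common training-checkpoint containers.
        if isinstance(ckpt.get("model"), dict):
            state = ckpt["model"]
        elif isinstance(ckpt.get("state_dict"), dict):
            state = ckpt["state_dict"]
        else:
            state = ckpt

        # Remap "sampler.longstream.*" keys to "longstream.*". Only a leading
        # "sampler." is stripped; keys without that prefix are left unchanged.
        prefix = "sampler."
        if state and next(iter(state)).startswith(prefix + "longstream."):
            state = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state.items()
            }

        # Load non-strictly so the mismatch can be summarized, then enforce
        # `strict` here.
        missing, unexpected = self.load_state_dict(state, strict=False)
        if missing or unexpected:
            msg = (
                f"checkpoint mismatch: missing={len(missing)} "
                f"unexpected={len(unexpected)} "
                f"(first missing={list(missing)[:5]}, "
                f"first unexpected={list(unexpected)[:5]})"
            )
            if strict:
                raise RuntimeError(msg)
            print(msg)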

    def forward(self, *args, **kwargs):
        return self.longstream(*args, **kwargs)

    @property
    def aggregator(self):
        return self.longstream.aggregator

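    @property
    def camera_head(self):
        # None when the wrapped model has no camera head.
        return getattr(self.longstream, "camera_head", None)

    @property
    def rel_pose_head(self):
        # None when the wrapped model has no relative-pose head.
        return getattr(self.longstream, "rel_pose_head", None)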