Gaaaavin committed on
Commit
cec7d2d
·
verified ·
1 Parent(s): bc06f9e

v2: full encoder weights + auto_map for trust_remote_code=True

Browse files

Re-converted from CityWalker_2000hr.ckpt with the reworked encoder pipeline (transformers.Dinov2Model in __init__, no separate load_obs_encoder). Adds auto_map + modeling_citywalker.py + configuration_citywalker.py so users can do AutoModel.from_pretrained("ai4ce/citywalker", trust_remote_code=True) without pip-installing wanderland-lab. The DINOv2 backbone build path is meta-device-aware: under the outer from_pretrained context it constructs an empty Dinov2Model(Dinov2Config) shell that the safetensors blob then populates; under direct CityWalkerModel(cfg) construction it pulls real weights from facebook/dinov2-base.

config.json CHANGED
@@ -2,6 +2,10 @@
2
  "architectures": [
3
  "CityWalkerModel"
4
  ],
 
 
 
 
5
  "context_size": 5,
6
  "cord_include_input": true,
7
  "cord_num_freqs": 6,
 
2
  "architectures": [
3
  "CityWalkerModel"
4
  ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_citywalker.CityWalkerConfig",
7
+ "AutoModel": "modeling_citywalker.CityWalkerModel"
8
+ },
9
  "context_size": 5,
10
  "cord_include_input": true,
11
  "cord_num_freqs": 6,
configuration_citywalker.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HuggingFace `PretrainedConfig` for the CityWalker waypoint-prediction model.
2
+
3
+ Mirrors the fields of upstream CityWalker's nested OmegaConf struct
4
+ (`config/finetune.yaml`) but in a flat, typed, JSON-serializable form so the
5
+ model round-trips through `save_pretrained` / `from_pretrained`.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from transformers import PretrainedConfig
11
+
12
+
13
class CityWalkerConfig(PretrainedConfig):
    """Flat, JSON-serializable hyperparameters for the CityWalker model.

    Every argument is coerced to a plain Python type before being stored so
    the config survives a ``save_pretrained`` / ``from_pretrained`` round trip
    (JSON turns tuples into lists; they are converted back to tuples here).

    Args:
        obs_encoder_type: DINOv2 variant name. Must be one of the keys known
            to :attr:`feature_dim` (``dinov2_vit{s,b,l,g}14``).
        context_size: Number of observation frames fed to the model.
        crop: ``(height, width)`` passed to the center crop in preprocessing.
        resize: ``(height, width)`` passed to the resize in preprocessing.
        freeze_obs_encoder: Whether the backbone parameters are frozen.
        cord_num_freqs: Number of frequency bands of the coordinate embedding.
        cord_include_input: Whether the raw polar coordinates are part of the
            coordinate embedding.
        do_rgb_normalize: Whether the model normalizes RGB inputs itself.
        do_resize: Whether the model crops and resizes inputs itself.
        decoder_num_heads: Attention heads per decoder layer.
        decoder_num_layers: Number of decoder layers.
        decoder_ff_dim_factor: Feed-forward width as a multiple of the
            feature width.
        len_traj_pred: Number of predicted waypoints.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "citywalker"

    def __init__(
        self,
        # Observation encoder (DINOv2 backbone).
        obs_encoder_type: str = "dinov2_vitb14",
        context_size: int = 5,
        crop: tuple[int, int] = (400, 400),
        resize: tuple[int, int] = (392, 392),
        freeze_obs_encoder: bool = True,
        # Coordinate embedding.
        cord_num_freqs: int = 6,
        cord_include_input: bool = True,
        # Image preprocessing inside the model forward pass (upstream behavior).
        do_rgb_normalize: bool = True,
        do_resize: bool = True,
        # Transformer decoder.
        decoder_num_heads: int = 8,
        decoder_num_layers: int = 16,
        decoder_ff_dim_factor: int = 4,
        # Output head.
        len_traj_pred: int = 5,
        **kwargs,
    ):
        self.obs_encoder_type = obs_encoder_type
        self.context_size = int(context_size)
        # Coerce the elements as well as the container, consistent with the
        # scalar fields: a JSON round trip or a YAML-derived value may hand us
        # a list and/or non-int numbers.
        self.crop = tuple(int(v) for v in crop)
        self.resize = tuple(int(v) for v in resize)
        self.freeze_obs_encoder = bool(freeze_obs_encoder)
        self.cord_num_freqs = int(cord_num_freqs)
        self.cord_include_input = bool(cord_include_input)
        self.do_rgb_normalize = bool(do_rgb_normalize)
        self.do_resize = bool(do_resize)
        self.decoder_num_heads = int(decoder_num_heads)
        self.decoder_num_layers = int(decoder_num_layers)
        self.decoder_ff_dim_factor = int(decoder_ff_dim_factor)
        self.len_traj_pred = int(len_traj_pred)
        # Own fields are assigned before delegating, as in the original.
        super().__init__(**kwargs)

    @property
    def feature_dim(self) -> int:
        """Feature width of the chosen DINOv2 variant.

        Raises:
            KeyError: If ``obs_encoder_type`` is not a supported variant.
        """
        widths = {
            "dinov2_vits14": 384,
            "dinov2_vitb14": 768,
            "dinov2_vitl14": 1024,
            "dinov2_vitg14": 1536,
        }
        try:
            return widths[self.obs_encoder_type]
        except KeyError:
            # Same exception type as a bare dict lookup, but with a message
            # that names the valid choices.
            raise KeyError(
                f"Unsupported obs_encoder_type: {self.obs_encoder_type!r}. "
                f"Expected one of {sorted(widths)}."
            ) from None
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:cb3c609a411eb901cdf4500a542c324e33bcf7a2b6ce328de6590cc55b8b8ca9
3
- size 833735756
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:180913e72708fae8317621d940a236d02caf41f2f0086217530cdde0f19d6538
3
+ size 833744196
modeling_citywalker.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CityWalker waypoint-prediction model, ported to a HuggingFace `PreTrainedModel`.
2
+
3
+ Port of `model/citywalker_feat.py` + supporting modules from
4
+ https://github.com/ai4ce/CityWalker, stripped of Lightning/OmegaConf.
5
+
6
+ Architecture (inference-only):
7
+
8
+ images (B,T,3,H,W) ──► DINOv2 ──► obs tokens (B,T,D)
9
+ coords (B,T+1,2) ──► PolarEmbedding + Linear ──► goal token (B,1,D)
10
+ ──► concat ──► (B,T+1,D)
11
+ ──► TransformerEncoder (self-attention decoder)
12
+ ──► MLP head ──► (waypoints_pred, arrive_pred)
13
+
14
+ Outputs:
15
+ waypoints_pred : (B, len_traj_pred, 2) cumulative XY deltas in body frame
16
+ arrive_pred : (B, 1) logits
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import math
22
+ from dataclasses import dataclass
23
+ from typing import Optional
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torchvision.transforms.functional as TF
28
+ from transformers import Dinov2Config, Dinov2Model, PreTrainedModel
29
+ from transformers.modeling_outputs import ModelOutput
30
+
31
+ from .configuration_citywalker import CityWalkerConfig
32
+
33
+
34
def _build_obs_encoder(name: str) -> Dinov2Model:
    """Instantiate the DINOv2 backbone for the HF repo id ``name``.

    The construction path depends on the default device reported by
    ``_peek_default_device()``:

    * ``meta`` device: only the backbone *config* is fetched and a
      weight-less ``Dinov2Model`` shell is returned. Per the original
      author's notes this is the situation inside an outer
      ``from_pretrained`` call (which wraps ``__init__`` in a meta-device
      context in transformers 5.x), where a nested
      ``Dinov2Model.from_pretrained`` would raise; the outer loader is
      expected to fill the shell from the bundled checkpoint.
    * anything else (including "no device context detected"): the
      pretrained backbone weights are loaded from ``name``, so a model built
      directly from a config still gets a usable encoder.

    Args:
        name: HF Hub repo id of the DINOv2 checkpoint, e.g.
            ``"facebook/dinov2-base"``.

    Returns:
        The backbone, either empty (meta path) or with pretrained weights.
    """
    if _peek_default_device() == torch.device("meta"):
        backbone_cfg = Dinov2Config.from_pretrained(name)
        return Dinov2Model(backbone_cfg)
    return Dinov2Model.from_pretrained(name)
57
+
58
+
59
def _peek_default_device() -> Optional[torch.device]:
    """Best-effort lookup of the currently active default torch device.

    Delegates to
    ``transformers.modeling_utils.get_torch_context_manager_or_global_device``.
    Any failure (``transformers`` missing, helper absent in the installed
    version, or the helper itself raising) is deliberately swallowed and
    reported as ``None``, which callers treat as "no special device context".

    Returns:
        Whatever the transformers helper returns, or ``None`` on any error.
    """
    try:
        import transformers.modeling_utils as hf_modeling_utils

        return hf_modeling_utils.get_torch_context_manager_or_global_device()
    except Exception:
        return None
70
+
71
+
72
# Lookup table: CityWalker `obs_encoder_type` string (the upstream torch.hub
# model name) -> HF Hub repo that ships the matching DINOv2 backbone.
# Per the original author's note these are the four LVD142M no-register
# variants, so the legacy `obs_encoder_type` strings keep working while the
# weights are pulled from HF instead of torch.hub.
_DINOV2_HF_REPOS = dict(
    dinov2_vits14="facebook/dinov2-small",
    dinov2_vitb14="facebook/dinov2-base",
    dinov2_vitl14="facebook/dinov2-large",
    dinov2_vitg14="facebook/dinov2-giant",
)
83
+
84
+
85
@dataclass
class CityWalkerOutput(ModelOutput):
    """Structured return value of ``CityWalkerModel.forward``.

    All fields default to ``None`` and are therefore annotated as optional.
    """

    # (B, len_traj_pred, 2) waypoints after the cumulative sum over the
    # trajectory axis.
    waypoints: Optional[torch.FloatTensor] = None
    # (B, 1) raw arrival logits (no sigmoid applied).
    arrive_logits: Optional[torch.FloatTensor] = None
    # Predictor output with the last (goal) token dropped.
    token_features: Optional[torch.FloatTensor] = None
    # Backbone features of `future_images`; None when those were not given.
    future_features: Optional[torch.FloatTensor] = None
91
+
92
+
93
class PolarEmbedding(nn.Module):
    """Fourier-feature encoding of 2D body-frame coordinates in polar form.

    Each ``(x, y)`` pair is converted to ``(r, theta)`` and expanded to
    ``[r, theta]`` (only if ``include_input``) followed by
    ``sin(theta*f), cos(theta*f), sin(r*f), cos(r*f)`` for the frequency
    bands ``f = 2**0 .. 2**(num_freqs-1)``.

    Args:
        num_freqs: Number of frequency bands.
        include_input: Whether to prepend the raw ``r`` and ``theta``.

    Attributes:
        out_dim: Size of the last axis of the embedding,
            ``(2 if include_input else 0) + 4 * num_freqs``.
    """

    def __init__(self, num_freqs: int, include_input: bool):
        super().__init__()
        self.num_freqs = num_freqs
        self.include_input = include_input
        freq_bands = 2.0 ** torch.linspace(0, num_freqs - 1, num_freqs)
        # Registered as a buffer so it follows the module across devices and
        # stays in the state dict under the key "freq_bands".
        self.register_buffer("freq_bands", freq_bands)
        self.out_dim = (2 if include_input else 0) + 4 * num_freqs

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """Embed coordinates.

        Args:
            coords: Tensor of shape ``(..., 2)`` holding ``(x, y)`` in the
                last axis; any number of leading axes is accepted.

        Returns:
            Tensor of shape ``(..., out_dim)``.
        """
        x, y = coords[..., 0], coords[..., 1]
        r = torch.sqrt(x * x + y * y).unsqueeze(-1)
        theta = torch.atan2(y, x).unsqueeze(-1)

        # The 1-D (num_freqs,) buffer broadcasts against the trailing
        # singleton axis of r / theta, so no fixed-rank reshape is needed and
        # inputs of any rank (not only 3-D) are handled.
        fb = self.freq_bands

        parts = [r, theta] if self.include_input else []
        parts.extend(
            (
                torch.sin(theta * fb),
                torch.cos(theta * fb),
                torch.sin(r * fb),
                torch.cos(r * fb),
            )
        )
        return torch.cat(parts, dim=-1)
116
+
117
+
118
class _PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (upstream naming preserved for weight-key parity).

    A fixed ``(1, max_seq_len, d_model)`` table is precomputed and stored in
    the buffer ``pos_enc``: even feature columns hold sines, odd columns hold
    cosines.

    Args:
        d_model: Feature width of the tokens. Odd widths are supported.
        max_seq_len: Longest sequence the table covers.
    """

    def __init__(self, d_model: int, max_seq_len: int):
        super().__init__()
        pos = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = pos * div_term  # (max_seq_len, ceil(d_model / 2))

        pos_enc = torch.zeros(max_seq_len, d_model)
        pos_enc[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns;
        # trimming the angles keeps the assignment shape-compatible. For even
        # d_model the slice is a no-op.
        pos_enc[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pos_enc", pos_enc.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the positional table to ``x``.

        Args:
            x: Tensor of shape ``(B, L, d_model)`` with ``L <= max_seq_len``.

        Returns:
            ``x`` plus the first ``L`` rows of the table, same shape as ``x``.
        """
        return x + self.pos_enc[:, : x.size(1), :]
132
+
133
+
134
class _FeatPredictor(nn.Module):
    """Transformer self-attention stack over the token sequence.

    The enclosing model feeds ``context_size + 1`` tokens (one per
    observation frame plus one goal token) and passes that count as
    ``seq_len``.

    Args:
        embed_dim: Token feature width.
        seq_len: Maximum sequence length covered by the positional encoding.
        nhead: Attention heads per layer.
        num_layers: Number of stacked encoder layers.
        ff_dim_factor: Feed-forward width as a multiple of ``embed_dim``.
    """

    def __init__(self, embed_dim: int, seq_len: int, nhead: int, num_layers: int, ff_dim_factor: int):
        super().__init__()
        self.positional_encoding = _PositionalEncoding(embed_dim, max_seq_len=seq_len)
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=nhead,
            dim_feedforward=ff_dim_factor * embed_dim,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        # The template layer stays registered as a submodule, so its
        # parameters appear under "sa_layer.*" in the state dict even though
        # only the stack below is used in forward().
        # NOTE(review): presumably kept for checkpoint-key parity with
        # upstream; confirm against the safetensors keys before removing.
        self.sa_layer = layer
        # With norm_first=True PyTorch never takes the nested-tensor fast
        # path; requesting it (the default) only produces a UserWarning at
        # construction. Disabling it explicitly leaves outputs unchanged.
        self.sa_decoder = nn.TransformerEncoder(
            layer,
            num_layers=num_layers,
            enable_nested_tensor=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to ``x`` and run the self-attention stack.

        Args:
            x: Tensor of shape ``(B, L, embed_dim)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        return self.sa_decoder(self.positional_encoding(x))
153
+
154
+
155
class CityWalkerModel(PreTrainedModel):
    """HF-compatible CityWalker model. Inference path only; training stays upstream.

    Pipeline: each observation frame is (optionally) normalized, center
    cropped and resized, then encoded by a DINOv2 backbone into one token.
    The coordinate history is embedded into a single goal token. The
    ``context_size + 1`` tokens go through a self-attention stack and an MLP
    head that predicts waypoints and an arrival logit.
    """

    config_class = CityWalkerConfig
    base_model_prefix = "citywalker"
    supports_gradient_checkpointing = False
    main_input_name = "images"

    def __init__(self, config: CityWalkerConfig):
        # The base-class constructor is handed the config, so it is not
        # re-assigned to `self.config` here.
        super().__init__(config)

        if config.do_rgb_normalize:
            # ImageNet channel statistics, shaped to broadcast over (N, 3, H, W).
            self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        if config.obs_encoder_type not in _DINOV2_HF_REPOS:
            raise ValueError(
                f"Unsupported obs_encoder_type: {config.obs_encoder_type!r}. "
                f"Expected one of {sorted(_DINOV2_HF_REPOS)}."
            )
        # DINOv2 backbone. `_build_obs_encoder` returns an empty shell when a
        # meta-device context is active (the outer `from_pretrained` then
        # fills it from the checkpoint) and pretrained weights otherwise.
        self.obs_encoder = _build_obs_encoder(
            _DINOV2_HF_REPOS[config.obs_encoder_type]
        )
        if config.freeze_obs_encoder:
            for p in self.obs_encoder.parameters():
                p.requires_grad = False
            self.obs_encoder.eval()
        self._feature_dim = config.feature_dim

        self.cord_embedding = PolarEmbedding(
            num_freqs=config.cord_num_freqs,
            include_input=config.cord_include_input,
        )
        # All (context_size + 1) embedded coordinates are flattened and
        # compressed into one goal token of backbone width.
        num_tokens = config.context_size + 1
        cord_enc_dim = self.cord_embedding.out_dim * num_tokens
        self.compress_goal_enc = nn.Linear(cord_enc_dim, self._feature_dim)

        self.predictor = _FeatPredictor(
            embed_dim=self._feature_dim,
            seq_len=num_tokens,
            nhead=config.decoder_num_heads,
            num_layers=config.decoder_num_layers,
            ff_dim_factor=config.decoder_ff_dim_factor,
        )
        self.predictor_mlp = nn.Sequential(
            nn.Linear(num_tokens * self._feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
        )
        self.wp_predictor = nn.Linear(32, config.len_traj_pred * 2)
        self.arrive_predictor = nn.Linear(32, 1)

        self.post_init()

    def _encode_obs(self, x: torch.Tensor) -> torch.Tensor:
        """Run a batch through the DINOv2 backbone and return the CLS token.

        Args:
            x: Preprocessed images of shape ``(N, 3, H, W)``.

        Returns:
            ``(N, feature_dim)`` tensor: position 0 of the backbone's
            ``last_hidden_state``. Per the porting notes this corresponds to
            what the upstream torch.hub backbone returns
            (``head(x_norm_clstoken)`` with an identity head).
        """
        out = self.obs_encoder(pixel_values=x)
        return out.last_hidden_state[:, 0]

    def _preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the config-selected normalization and crop/resize.

        Args:
            x: Images of shape ``(N, 3, H, W)``.

        Returns:
            The normalized (if ``do_rgb_normalize``) and center-cropped then
            resized (if ``do_resize``) images.
        """
        if self.config.do_rgb_normalize:
            x = (x - self.mean) / self.std
        if self.config.do_resize:
            x = TF.center_crop(x, list(self.config.crop))
            x = TF.resize(x, list(self.config.resize))
        return x

    def forward(
        self,
        images: torch.Tensor,
        coords: torch.Tensor,
        future_images: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """Predict waypoints and an arrival logit.

        Args:
            images: (B, context_size, 3, H, W) float tensor in [0, 1].
            coords: (B, context_size + 1, 2) recent body-frame XY positions.
            future_images: optional (B, context_size, 3, H, W) for the
                feature-prediction head (unused at inference).
            return_dict: If False, return a plain tuple instead of a
                ``CityWalkerOutput``.

        Returns:
            ``CityWalkerOutput`` (or the tuple ``(waypoints, arrive_logits,
            token_features, future_features)``) where ``waypoints`` is
            ``(B, len_traj_pred, 2)``, ``arrive_logits`` is ``(B, 1)``,
            ``token_features`` is the predictor output without the goal
            token, and ``future_features`` is the backbone encoding of
            ``future_images`` or ``None``.

        Raises:
            ValueError: If ``images`` is not 5-dimensional.
        """
        if images.ndim != 5:
            raise ValueError(
                f"images must have shape (B, T, 3, H, W); got {tuple(images.shape)}"
            )
        B, T, _, H, W = images.shape

        # reshape() instead of view(): callers often pass permuted
        # (channel-last -> channel-first) frames, which are non-contiguous
        # and would make view() raise.
        x = self._preprocess(images.reshape(B * T, 3, H, W))
        obs_enc = self._encode_obs(x).reshape(B, T, -1)

        future_enc: Optional[torch.Tensor] = None
        if future_images is not None:
            fx = self._preprocess(future_images.reshape(B * T, 3, H, W))
            future_enc = self._encode_obs(fx).reshape(B, T, -1)

        # Coordinate history -> one goal token.
        cord_enc = self.cord_embedding(coords).reshape(B, -1)
        cord_enc = self.compress_goal_enc(cord_enc).reshape(B, 1, -1)

        tokens = torch.cat([obs_enc, cord_enc], dim=1)
        features = self.predictor(tokens)
        dec_out = self.predictor_mlp(features.reshape(B, -1))

        # The head predicts per-step deltas; the cumulative sum turns them
        # into positions along the trajectory.
        wp = self.wp_predictor(dec_out).reshape(B, self.config.len_traj_pred, 2)
        wp = torch.cumsum(wp, dim=1)
        arrive = self.arrive_predictor(dec_out).reshape(B, 1)

        # Drop the trailing goal token; only observation tokens are exposed.
        token_features = features[:, :-1]

        if not return_dict:
            return wp, arrive, token_features, future_enc
        return CityWalkerOutput(
            waypoints=wp,
            arrive_logits=arrive,
            token_features=token_features,
            future_features=future_enc,
        )