File size: 14,250 Bytes
f748552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
"""Dual-stream denoising transformer for MiniHack.

Ported from minihack_reference/src/model.py. Architecture follows the
Craftax denoiser conventions (forward return format, obs-encoder pattern)
while using the MiniHack dual-stream design (local CNN + gated global
CNN + auxiliary goal head).
"""

from __future__ import annotations

import copy
import logging
import shutil
from types import SimpleNamespace

import torch
import torch.nn as nn
from torch import Tensor

logger = logging.getLogger(__name__)


class LocalDiffusionPlannerWithGlobal(nn.Module):
    """Dual-stream transformer for masked diffusion action planning.

    Combines a local 9x9 glyph crop with a gated global 21x79 map
    context. Produces action logits and an auxiliary staircase-coordinate
    prediction.

    Architecture:
        Local stream:  Embedding(6000,64) -> CNN(64->32->64) -> Linear -> 1 token
        Global stream: Embedding(6000,32) -> CNN(32->32->64) -> Pool(2,4)
                       -> Linear -> 8 tokens, gated by sigmoid(learnable scalar)
        Goal head:     mean(global_tokens) -> MLP -> [B,2] (before gate)
        Action stream: Embedding(action_dim + 2, n_embd) + timestep + position
        Transformer:   concat all -> TransformerEncoder -> last seq_len
                       tokens -> head

    Args:
        cfg: Config namespace with ``action_dim``, ``n_embd``, ``n_head``,
            ``n_layer``, ``n_global_tokens``, ``seq_len``,
            ``global_gate_init``, ``num_diffusion_steps``. An optional
            ``dropout`` attribute is read with a default of ``0.0``.

    Raises:
        AssertionError: If ``n_embd`` is not divisible by ``n_head``.
        ValueError: If ``n_global_tokens`` differs from the number of
            cells in the fixed 2x4 pooling grid.
    """

    # Output grid of the adaptive pool applied to the global feature map.
    # Each pooled cell becomes one global token, so the token count is
    # fixed by this grid.
    _GLOBAL_POOL_GRID = (2, 4)

    def __init__(self, cfg: SimpleNamespace) -> None:
        super().__init__()
        action_dim = cfg.action_dim
        n_embd = cfg.n_embd
        n_head = cfg.n_head
        n_layer = cfg.n_layer
        n_global_tokens = cfg.n_global_tokens
        seq_len = cfg.seq_len

        assert n_embd % n_head == 0, (
            f"n_embd ({n_embd}) must be divisible by n_head ({n_head})"
        )

        # Fail fast: forward() reshapes the pooled map into
        # ``n_global_tokens`` rows of 64 channels, which only works when
        # the token count equals the number of pooled cells.
        pool_h, pool_w = self._GLOBAL_POOL_GRID
        if n_global_tokens != pool_h * pool_w:
            raise ValueError(
                f"n_global_tokens ({n_global_tokens}) must equal "
                f"{pool_h * pool_w}, the size of the {pool_h}x{pool_w} "
                "global pooling grid"
            )

        self.n_global_tokens = n_global_tokens

        # ── Local stream: 9x9 crop -> 1 token ──────────────────────
        self.embedding = nn.Embedding(6000, 64)
        self.cnn = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.GELU(),
            nn.Flatten(),
            # padding=1 keeps the 9x9 spatial size, hence 64 * 9 * 9.
            nn.Linear(64 * 9 * 9, n_embd),
        )

        # ── Action stream ──────────────────────────────────────────
        # Two ids beyond the real actions are reserved in the table
        # (presumably mask / padding tokens; confirm against the
        # diffusion code that produces ``action_seq``).
        self.action_emb = nn.Embedding(action_dim + 2, n_embd)
        self.timestep_emb = nn.Embedding(
            cfg.num_diffusion_steps, n_embd,
        )
        self.pos_emb = nn.Embedding(seq_len, n_embd)

        # ── Transformer ───────────────────────────────────────────
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=n_embd,
            nhead=n_head,
            dim_feedforward=n_embd * 4,
            dropout=getattr(cfg, "dropout", 0.0),
            activation="gelu",
            norm_first=True,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=n_layer, enable_nested_tensor=False,
        )
        self.head = nn.Linear(n_embd, action_dim)

        # ── Global stream: 21x79 map -> 8 tokens ──────────────────
        self.global_embedding = nn.Embedding(6000, 32)
        self.global_cnn = nn.Sequential(
            nn.Conv2d(32, 32, 5, stride=2, padding=2),
            nn.GELU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.GELU(),
        )
        self.global_pool = nn.AdaptiveAvgPool2d(self._GLOBAL_POOL_GRID)
        self.global_proj = nn.Linear(64, n_embd)
        # float() keeps the gate a floating-point tensor even when the
        # config stores an integer; an integer tensor cannot be a
        # trainable Parameter.
        self.global_gate = nn.Parameter(
            torch.tensor(float(cfg.global_gate_init))
        )

        # ── Auxiliary goal head (before gate) ──────────────────────
        self.goal_head = nn.Sequential(
            nn.Linear(n_embd, 128),
            nn.GELU(),
            nn.Linear(128, 2),
        )

    @staticmethod
    def _batched_timesteps(
        t_discrete: int | Tensor,
        batch_size: int,
        device: torch.device,
    ) -> Tensor:
        """Return *t_discrete* as a ``[B]`` long tensor on *device*.

        Python scalars and 0-dim tensors are broadcast to every batch
        element; tensors that already have a batch dimension are only
        cast and moved.
        """
        if isinstance(t_discrete, Tensor):
            t_tensor = t_discrete.to(device=device, dtype=torch.long)
            if t_tensor.dim() == 0:
                # A 0-dim index would embed to [n_embd] and mis-broadcast
                # against the [B, seq_len, n_embd] action embeddings.
                t_tensor = t_tensor.expand(batch_size)
            return t_tensor
        return torch.full(
            (batch_size,), int(t_discrete), dtype=torch.long, device=device,
        )

    def forward(
        self,
        local_obs: Tensor,
        global_obs: Tensor,
        action_seq: Tensor,
        t_discrete: int | Tensor,
    ) -> dict[str, Tensor]:
        """Forward pass producing action logits and goal prediction.

        Args:
            local_obs: Local glyph crop. Shape ``[B, 9, 9]``, int.
            global_obs: Full glyph map. Shape ``[B, 21, 79]``, int.
            action_seq: Noisy action sequence. Shape ``[B, seq_len]``, int.
            t_discrete: Discrete timestep index (Python int, 0-dim
                tensor, or ``[B]`` tensor).

        Returns:
            Dict with keys:
            - ``"actions"``: ``[B, seq_len, action_dim]`` logits.
            - ``"goal_pred"``: ``[B, 2]`` normalised staircase coords.
        """
        B, Seq = action_seq.shape
        device = local_obs.device

        # Local stream -> [B, 1, n_embd]
        x_local = self.embedding(local_obs)  # [B, 9, 9, 64]
        x_local = x_local.permute(0, 3, 1, 2)  # [B, 64, 9, 9]
        local_token = self.cnn(x_local).unsqueeze(1)  # [B, 1, n_embd]

        # Global stream -> [B, 8, n_embd]
        x_global = self.global_embedding(global_obs)  # [B, 21, 79, 32]
        x_global = x_global.permute(0, 3, 1, 2)  # [B, 32, 21, 79]
        gf = self.global_cnn(x_global)  # [B, 64, 6, 20] for a 21x79 map
        gf = self.global_pool(gf)  # [B, 64, 2, 4]
        global_tokens = gf.permute(0, 2, 3, 1)  # [B, 2, 4, 64]
        global_tokens = global_tokens.reshape(
            B, self.n_global_tokens, -1
        )  # [B, 8, 64]
        global_tokens = self.global_proj(global_tokens)  # [B, 8, n_embd]

        # Aux goal head (before gate for direct gradient to CNN)
        goal_pred = self.goal_head(
            global_tokens.mean(dim=1)
        )  # [B, 2]

        # Apply gate: one learnable scalar scales every global token.
        gate = torch.sigmoid(self.global_gate)
        global_tokens = global_tokens * gate  # [B, 8, n_embd]

        # Action stream -> [B, seq_len, n_embd]
        positions = torch.arange(
            Seq, device=device,
        ).unsqueeze(0).expand(B, -1)  # [B, seq_len]
        t_tensor = self._batched_timesteps(t_discrete, B, device)  # [B]

        seq_emb = (
            self.action_emb(action_seq)
            + self.timestep_emb(t_tensor).unsqueeze(1)
            + self.pos_emb(positions)
        )  # [B, seq_len, n_embd]

        # Concatenate: [local(1), global(8), actions(seq_len)]
        x = torch.cat(
            [local_token, global_tokens, seq_emb], dim=1,
        )  # [B, 1+8+seq_len, n_embd]

        # Transformer
        out = self.transformer(x)  # [B, 1+8+seq_len, n_embd]

        # Drop the observation prefix; only the action positions are
        # decoded into logits.
        n_prefix = 1 + self.n_global_tokens
        action_logits = self.head(
            out[:, n_prefix:, :]
        )  # [B, seq_len, action_dim]

        return {"actions": action_logits, "goal_pred": goal_pred}


class LocalDiffusionPlanner(nn.Module):
    """Local-only ablation model (no global stream, no goal head).

    Uses the same local 9x9 glyph encoder, action/timestep/position
    embeddings and transformer encoder settings as the dual-stream
    model, but the transformer only sees one local-observation token
    followed by the action tokens.

    Args:
        cfg: Config namespace with ``action_dim``, ``n_embd``, ``n_head``,
            ``n_layer``, ``seq_len``, ``num_diffusion_steps``. An optional
            ``dropout`` attribute is read with a default of ``0.0``.
    """

    def __init__(self, cfg: SimpleNamespace) -> None:
        super().__init__()
        action_dim = cfg.action_dim
        n_embd = cfg.n_embd
        seq_len = cfg.seq_len

        # Local stream: 9x9 crop -> 1 token.
        self.embedding = nn.Embedding(6000, 64)
        self.cnn = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.GELU(),
            nn.Flatten(),
            # padding=1 keeps the 9x9 spatial size, hence 64 * 9 * 9.
            nn.Linear(64 * 9 * 9, n_embd),
        )

        # Action stream. Two ids beyond the real actions are reserved
        # (presumably mask / padding tokens; confirm against the
        # diffusion code that produces ``action_seq``).
        self.action_emb = nn.Embedding(action_dim + 2, n_embd)
        self.timestep_emb = nn.Embedding(cfg.num_diffusion_steps, n_embd)
        self.pos_emb = nn.Embedding(seq_len, n_embd)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=n_embd,
            nhead=cfg.n_head,
            dim_feedforward=n_embd * 4,
            dropout=getattr(cfg, "dropout", 0.0),
            activation="gelu",
            norm_first=True,
            batch_first=True,
        )
        # norm_first=True already rules out the nested-tensor path, so
        # disable it explicitly (as the dual-stream model does) instead
        # of letting PyTorch warn about it at construction time.
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=cfg.n_layer, enable_nested_tensor=False,
        )
        self.head = nn.Linear(n_embd, action_dim)

    @staticmethod
    def _batched_timesteps(
        t_discrete: int | Tensor,
        batch_size: int,
        device: torch.device,
    ) -> Tensor:
        """Return *t_discrete* as a ``[B]`` long tensor on *device*.

        Python scalars and 0-dim tensors are broadcast to every batch
        element; tensors that already have a batch dimension are only
        cast and moved.
        """
        if isinstance(t_discrete, Tensor):
            t_tensor = t_discrete.to(device=device, dtype=torch.long)
            if t_tensor.dim() == 0:
                # A 0-dim index would embed to [n_embd] and mis-broadcast
                # against the [B, seq_len, n_embd] action embeddings.
                t_tensor = t_tensor.expand(batch_size)
            return t_tensor
        return torch.full(
            (batch_size,), int(t_discrete), dtype=torch.long, device=device,
        )

    def forward(
        self,
        local_obs: Tensor,
        global_obs: Tensor,
        action_seq: Tensor,
        t_discrete: int | Tensor,
    ) -> dict[str, Tensor]:
        """Forward pass (ignores global_obs).

        The signature mirrors the dual-stream model so both can be
        called the same way.

        Args:
            local_obs: ``[B, 9, 9]`` int.
            global_obs: ``[B, 21, 79]`` int (ignored).
            action_seq: ``[B, seq_len]`` int.
            t_discrete: Timestep index (Python int, 0-dim tensor, or
                ``[B]`` tensor).

        Returns:
            Dict with ``"actions"`` key only (no goal_pred); the value
            has shape ``[B, seq_len, action_dim]``.
        """
        B, Seq = action_seq.shape
        device = local_obs.device

        # [B, 9, 9, 64] -> [B, 64, 9, 9] -> one token per sample.
        x_state = self.embedding(local_obs).permute(0, 3, 1, 2)
        state_emb = self.cnn(x_state).unsqueeze(1)  # [B, 1, n_embd]

        positions = torch.arange(
            Seq, device=device,
        ).unsqueeze(0).expand(B, -1)  # [B, seq_len]
        t_tensor = self._batched_timesteps(t_discrete, B, device)  # [B]

        seq_emb = (
            self.action_emb(action_seq)
            + self.timestep_emb(t_tensor).unsqueeze(1)
            + self.pos_emb(positions)
        )  # [B, seq_len, n_embd]

        x = torch.cat([state_emb, seq_emb], dim=1)  # [B, 1+seq_len, n_embd]
        out = self.transformer(x)
        # Skip the observation token; decode only the action positions.
        return {"actions": self.head(out[:, 1:, :])}


# ── Factory ──────────────────────────────────────────────────────────


def make_model(cfg: SimpleNamespace) -> nn.Module:
    """Build the default MiniHack denoiser from *cfg*.

    Single construction entry point so callers do not have to name the
    concrete model class.

    Args:
        cfg: Config namespace forwarded unchanged to the model
            constructor.

    Returns:
        A freshly constructed ``LocalDiffusionPlannerWithGlobal``.
    """
    model = LocalDiffusionPlannerWithGlobal(cfg)
    return model


def _has_c_compiler() -> bool:
    """Report whether a C compiler executable can be located.

    The compiler named by the ``CC`` environment variable is tried
    first; if it is unset, empty, or does not resolve to an executable,
    ``cc`` and then ``gcc`` are looked up on ``PATH``.

    Returns:
        ``True`` if any of the candidates resolves via
        ``shutil.which``, otherwise ``False``.
    """
    import os

    configured = os.environ.get("CC")
    if configured and shutil.which(configured):
        return True
    # any() stops at the first hit, matching the short-circuit of `or`.
    return any(
        shutil.which(candidate) is not None for candidate in ("cc", "gcc")
    )


def try_compile(model: nn.Module, cfg: SimpleNamespace) -> nn.Module:
    """Return *model* wrapped by ``torch.compile`` when that is possible.

    Compilation is attempted only if ``cfg.torch_compile`` is truthy,
    the installed torch exposes ``torch.compile``, and a C compiler can
    be found. In every other case the eager model is handed back
    untouched; a missing compiler additionally logs a warning.

    Args:
        model: The raw (uncompiled) model.
        cfg: Config namespace; the optional ``torch_compile`` flag is
            treated as ``False`` when absent.

    Returns:
        The compiled wrapper, or *model* itself on fallback.
    """
    wants_compile = getattr(cfg, "torch_compile", False)
    if not wants_compile or not hasattr(torch, "compile"):
        return model

    if _has_c_compiler():
        logger.info("Compiling model with torch.compile")
        return torch.compile(model, mode="default")  # type: ignore[return-value]

    logger.warning(
        "torch.compile requested but no C compiler found "
        "(CC env var, cc, gcc); falling back to eager mode"
    )
    return model


# ── EMA ──────────────────────────────────────────────────────────────


class ModelEMA:
    """Exponential moving average of model parameters.

    Maintains a shadow copy of parameters updated as
    ``theta_ema <- decay * theta_ema + (1 - decay) * theta``.

    Only entries of ``model.named_parameters()`` are tracked; buffers
    are not. Shadow tensors are keyed by the parameter names of the
    model passed to the constructor. Lookups also accept the same name
    with or without the ``_orig_mod.`` prefix that ``torch.compile``
    wrappers add, so an EMA built from the eager model keeps working
    when it is later given the compiled wrapper (and vice versa).

    Args:
        model: Source model.
        decay: EMA decay factor (default 0.999).
    """

    # Prefix that ``torch.compile`` wrappers put in front of the
    # parameter names of the module they wrap.
    _COMPILE_PREFIX = "_orig_mod."

    def __init__(self, model: nn.Module, decay: float = 0.999) -> None:
        self._decay = decay
        self._shadow: dict[str, Tensor] = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
        }

    def _shadow_key(self, name: str) -> str | None:
        """Map a parameter name to its shadow key, or ``None`` if unknown.

        Tries *name* itself first, then the variant with the
        ``torch.compile`` prefix removed or added.
        """
        if name in self._shadow:
            return name
        prefix = self._COMPILE_PREFIX
        if name.startswith(prefix):
            alias = name[len(prefix):]
        else:
            alias = prefix + name
        return alias if alias in self._shadow else None

    def _shadow_for(self, name: str) -> Tensor:
        """Return the shadow tensor for *name*.

        Raises:
            KeyError: If no shadow entry matches *name*.
        """
        key = self._shadow_key(name)
        if key is None:
            raise KeyError(name)
        return self._shadow[key]

    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        """Update shadow parameters from *model*.

        Args:
            model: Source model whose parameters are blended in.

        Raises:
            KeyError: If *model* has a parameter with no shadow entry.
        """
        for name, param in model.named_parameters():
            self._shadow_for(name).mul_(self._decay).add_(
                param.detach(), alpha=1.0 - self._decay,
            )

    @torch.no_grad()
    def apply_to(self, model: nn.Module) -> None:
        """Copy shadow parameters into *model* (for inference).

        Args:
            model: Target model to overwrite.

        Raises:
            KeyError: If *model* has a parameter with no shadow entry.
        """
        for name, param in model.named_parameters():
            param.data.copy_(self._shadow_for(name))

    def state_dict(self) -> dict[str, Tensor]:
        """Return shadow parameter dict for serialisation.

        Returns:
            Dict mapping parameter names to cloned EMA tensors, so later
            updates do not alter the returned values.
        """
        return {k: v.clone() for k, v in self._shadow.items()}

    @torch.no_grad()
    def load_state_dict(self, sd: dict[str, Tensor]) -> None:
        """Restore shadow parameters from *sd*.

        Entries that match no shadow parameter are skipped (best-effort
        restore) and reported once through a logged warning rather than
        dropped silently.

        Args:
            sd: State dict from a prior ``state_dict()`` call.
        """
        skipped: list[str] = []
        for name, value in sd.items():
            key = self._shadow_key(name)
            if key is None:
                skipped.append(str(name))
                continue
            self._shadow[key].copy_(value)
        if skipped:
            logging.getLogger(__name__).warning(
                "ModelEMA.load_state_dict ignored %d unknown key(s): %s",
                len(skipped), ", ".join(skipped),
            )

    def parameters(self):
        """Iterate over shadow parameter tensors.

        Yields:
            EMA parameter tensors.
        """
        yield from self._shadow.values()

    def make_eval_model(self, model: nn.Module) -> nn.Module:
        """Return a deep copy of *model* with EMA weights applied.

        *model* itself is left unmodified.

        Args:
            model: Template model (architecture).

        Returns:
            New model in eval mode carrying the shadow parameters.
        """
        eval_model = copy.deepcopy(model)
        self.apply_to(eval_model)
        eval_model.eval()
        return eval_model