dn6 (HF Staff) committed · Commit 57eef5f · verified · 1 Parent(s): 3b9a7ab

Add diffusers support
README.md CHANGED
@@ -8,7 +8,7 @@ tags:
 - Egocentric
 ---
 
-Waypoint-1-Small is a 2.3 billion parameter control-and-text-conditioned causal diffusion model. It is a transformer architecture utilizing rectified flow, distilled via self forcing with DMD. The model can autoregressively generate new frames given historical frames, actions, and text.
+Waypoint-1-Small is a 2.3 billion parameter control- and text-conditioned causal diffusion model. It is a transformer architecture utilizing rectified flow, distilled via self-forcing with DMD. The model can autoregressively generate new frames given historical frames, actions, and text.
 
 # Capabilities:
 
@@ -23,12 +23,99 @@ In order to simply use Waypoint-1-Small, we recommend [Biome](https://github.com
 
 To run the model locally, we recommend an NVIDIA RTX 5090, which should achieve 20-30 FPS, or an RTX 6000 Pro Blackwell, which should achieve ~35 FPS.
 
+# Run with Diffusers Modular Pipelines
+
+World Engine and Waypoint-1 can be used with [Diffusers Modular Pipelines](https://huggingface.co/docs/diffusers/main/en/api/modular_pipelines/modular_pipeline).
+
+## Setup
+
+```bash
+uv venv -p 3.11 && uv pip install \
+  "torch>=2.9.0" \
+  "diffusers>=0.36.0" \
+  "transformers>=4.57.1" \
+  "einops>=0.8.0" \
+  "tensordict>=0.5.0" \
+  regex \
+  ftfy \
+  imageio \
+  imageio-ffmpeg \
+  tqdm
+```
+
+## Usage Example
+
+```python
+import random
+import torch
+
+from tqdm import tqdm
+from dataclasses import dataclass, field
+from typing import Set, Tuple
+from diffusers.modular_pipelines import ModularPipeline
+from diffusers.utils import load_image, export_to_video
+
+
+@dataclass
+class CtrlInput:
+    button: Set[int] = field(default_factory=set)  # pressed button IDs
+    mouse: Tuple[float, float] = (0.0, 0.0)  # (x, y) velocity
+
+
+# Generate random control trajectories
+ctrl = lambda: random.choice(
+    [
+        CtrlInput(button={48, 42}, mouse=(0.4, 0.3)),
+        CtrlInput(mouse=(0.1, 0.2)),
+        CtrlInput(button={95, 32, 105}),
+    ]
+)
+model_id = "Overworld/Waypoint-1-Small"
+
+pipe = ModularPipeline.from_pretrained(model_id, trust_remote_code=True)
+pipe.load_components(
+    device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True
+)
+pipe.transformer.apply_inference_patches()
+
+# Optional quantization step
+# Available options are: nvfp4 (if running on Blackwell hardware), fp8, w8a8
+# pipe.transformer.quantize("nvfp4")
+pipe.transformer.compile(fullgraph=True, mode="max-autotune", dynamic=False)
+pipe.vae.bake_weight_norm()
+pipe.vae.compile(fullgraph=True, mode="max-autotune")
+
+prompt = "A fun game"
+image = load_image(
+    "https://gist.github.com/user-attachments/assets/4adc5a3d-6980-4d1e-b6e8-9033cdf61c66"
+)
+
+num_frames = 240
+outputs = []
+
+# Create world state based on an initial image
+state = pipe(prompt=prompt, image=image, button=ctrl().button, mouse=ctrl().mouse)
+outputs.append(state.values["images"])
+
+# Drop the seed image so subsequent frames are generated autoregressively
+state.values["image"] = None
+for _ in tqdm(range(1, num_frames)):
+    state = pipe(
+        state,
+        prompt=prompt,
+        button=ctrl().button,
+        mouse=ctrl().mouse,
+        output_type="pil",
+    )
+    outputs.append(state.values["images"])
+
+export_to_video(outputs, "waypoint-1-small.mp4", fps=60)
+```
+
 # Keywords
 
 To properly explain limitations and misuse we must define some terms. While the model can be used for general interactive video generation tasks, we herein define interacting with the model via sending controls and receiving new frames as “playing” the model, and the agent/user inputting controls as the “player”. The model has two forms of output: continuations and generations. Continuations occur when seed frames are given and no inputs are given. For example, if a scene has fire or water, you may see them evolve progressively in the generated frames even if no action is given. Likewise, if you seed with an image of a humanoid entity, the entity will persist on the screen as you move/look around. Generations, by contrast, occur when the player plays with the model extensively, for example moving around, turning around fully, or interacting with objects/items. Continuations roughly correspond to moving around already existing information in the given context frames, while generations correspond to creating entirely new information.
 
 # Limitations
-
+
 - Continuations can plausibly model any inputted scene or photo, and will depend largely on the seed frame given. For generations, the model may occasionally:
   - Ignore given text prompt
   - Ignore certain controls in specific contexts
@@ -46,4 +133,5 @@ To properly explain limitations and misuse we must define some terms. While the
 - For simulating extremely violent acts
 - For generating violent/gory video
 - For facilitation of large-scale disinformation campaigns
-- For the purpose of generating any sexually explicit or suggestive material
+- For the purpose of generating any sexually explicit or suggestive material
+
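Note: successive calls reuse the KV cache carried in `state`, so frames continue the same world. To start a fresh episode from a new seed image, the `reset_cache` input exposed by `WorldEngineSetupKVCacheStep` (in `before_denoise.py` below) can be combined with a new `image`. A minimal, untested sketch continuing the usage example above (the image URL is a placeholder):

```python
# Hypothetical continuation of the usage example above.
new_seed = load_image("https://example.com/new-seed.png")  # placeholder URL
state.values["image"] = new_seed
state = pipe(state, prompt="A snowy mountain village", reset_cache=True)
outputs = [state.values["images"]]
```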
__init__.py ADDED
@@ -0,0 +1,61 @@
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""
WorldEngine Modular Pipeline

A Diffusers-compatible modular pipeline for frame-by-frame world model generation.
Supports text and controller (mouse + button + scroll) conditioning.
"""

from .modular_blocks import WorldEngineBlocks, AUTO_BLOCKS
from .encoders import WorldEngineTextEncoderStep, WorldEngineControllerEncoderStep
from .before_denoise import (
    WorldEngineBeforeDenoiseStep,
    WorldEngineSetTimestepsStep,
    WorldEnginePrepareLatentsStep,
    WorldEngineSetupKVCacheStep,
    StaticKVCache,
    LayerKVCache,
)
from .denoise import WorldEngineDenoiseLoop
from .decoders import WorldEngineDecodeStep
from .transformer import WorldModel  # listed in __all__ below, so it must be imported here
from .vae import WorldEngineVAE

__version__ = "0.1.0"

__all__ = [
    # Main pipeline blocks
    "WorldEngineBlocks",
    "AUTO_BLOCKS",
    # Encoder blocks
    "WorldEngineTextEncoderStep",
    "WorldEngineControllerEncoderStep",
    # Before denoise blocks
    "WorldEngineBeforeDenoiseStep",
    "WorldEngineSetTimestepsStep",
    "WorldEnginePrepareLatentsStep",
    "WorldEngineSetupKVCacheStep",
    # Denoise block
    "WorldEngineDenoiseLoop",
    # Decoder blocks
    "WorldEngineDecodeStep",
    # Models
    "WorldModel",
    "WorldEngineVAE",
    # KV Cache
    "StaticKVCache",
    "LayerKVCache",
]
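
Note: the exports above also allow assembling the pipeline from its blocks instead of `ModularPipeline.from_pretrained`. A rough sketch, assuming this package is importable as `world_engine` and using the generic Modular Pipelines API (`from_blocks_dict`, `init_pipeline`):

```python
import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks

from world_engine import AUTO_BLOCKS  # assumed import path for this package

blocks = SequentialPipelineBlocks.from_blocks_dict(AUTO_BLOCKS)
pipe = blocks.init_pipeline("Overworld/Waypoint-1-Small")  # binds specs/configs from the repo
pipe.load_components(torch_dtype=torch.bfloat16, device_map="cuda")
```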
before_denoise.py ADDED
@@ -0,0 +1,612 @@
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Before-denoise blocks for WorldEngine modular pipeline."""

from typing import List, Optional, Union

import PIL.Image
import torch
from torch import nn, Tensor
from tensordict import TensorDict
from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE, BlockMask

from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.modular_pipelines import (
    ModularPipelineBlocks,
    ModularPipeline,
    PipelineState,
    SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    ConfigSpec,
    InputParam,
    OutputParam,
)

logger = logging.get_logger(__name__)


def make_block_mask(T: int, L: int, written: torch.Tensor) -> BlockMask:
    """
    Create a block mask for flex_attention.

    Args:
        T: Q length for this frame
        L: KV capacity == written.numel()
        written: [L] bool, True where there is valid KV data
    """
    BS = _DEFAULT_SPARSE_BLOCK_SIZE
    KV_blocks = (L + BS - 1) // BS
    Q_blocks = (T + BS - 1) // BS

    # [KV_blocks, BS]
    written_blocks = torch.nn.functional.pad(written, (0, KV_blocks * BS - L)).view(
        KV_blocks, BS
    )

    # Block-level occupancy
    block_any = written_blocks.any(-1)  # block has at least one written token
    block_all = written_blocks.all(-1)  # block is fully written

    # Every Q-block sees the same KV-block pattern
    nonzero_bm = block_any[None, :].expand(Q_blocks, KV_blocks)  # [Q_blocks, KV_blocks]
    full_bm = block_all[None, :].expand_as(nonzero_bm)  # [Q_blocks, KV_blocks]
    partial_bm = nonzero_bm & ~full_bm  # [Q_blocks, KV_blocks]

    def dense_to_ordered(dense_mask: torch.Tensor):
        # dense_mask: [Q_blocks, KV_blocks] bool
        # returns: [1,1,Q_blocks], [1,1,Q_blocks,KV_blocks]
        num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32)  # [Q_blocks]
        indices = dense_mask.argsort(dim=-1, descending=True, stable=True).to(
            torch.int32
        )
        return num_blocks[None, None].contiguous(), indices[None, None].contiguous()

    # Partial blocks (need mask_mod)
    kv_num_blocks, kv_indices = dense_to_ordered(partial_bm)

    # Full blocks (mask_mod can be skipped entirely)
    full_kv_num_blocks, full_kv_indices = dense_to_ordered(full_bm)

    def mask_mod(b, h, q, kv):
        return written[kv]

    bm = BlockMask.from_kv_blocks(
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        BLOCK_SIZE=BS,
        mask_mod=mask_mod,
        seq_lengths=(T, L),
        compute_q_blocks=False,  # no backward, avoids the transpose/_ordered_to_dense path
    )

    return bm


class LayerKVCache(nn.Module):
    """
    Ring-buffer KV cache with fixed capacity L (tokens) for history plus
    one extra frame (tokens_per_frame) at the tail holding the current frame.
    """

    def __init__(
        self, B, H, L, Dh, dtype, tokens_per_frame: int, pinned_dilation: int = 1
    ):
        super().__init__()
        self.tpf = tokens_per_frame
        self.L = L
        # total KV capacity: ring (L) + tail frame (tpf)
        self.capacity = L + self.tpf
        self.pinned_dilation = pinned_dilation
        self.num_buckets = (L // self.tpf) // self.pinned_dilation
        assert (L // self.tpf) % pinned_dilation == 0 and L % self.tpf == 0

        # KV buffer: [2, B, H, capacity, Dh]
        self.kv = nn.Buffer(
            torch.zeros(2, B, H, self.capacity, Dh, dtype=dtype),
            persistent=False,
        )

        # which slots have ever been written
        # tail slice [L, L+tpf) always holds the current frame and is considered written
        written = torch.zeros(self.capacity, dtype=torch.bool)
        written[L:] = True
        self.written = nn.Buffer(written, persistent=False)

        # Precompute indices:
        #   frame_offsets: [0, 1, ..., tpf-1] (for ring indexing)
        #   current_idx:   [L, L+1, ..., L+tpf-1] (tail slice)
        self.frame_offsets = nn.Buffer(
            torch.arange(self.tpf, dtype=torch.long), persistent=False
        )
        self.current_idx = nn.Buffer(self.frame_offsets + L, persistent=False)

    def reset(self):
        self.kv.zero_()
        self.written.zero_()
        self.written[self.L :].fill_(True)

    def upsert(self, kv: Tensor, pos_ids: TensorDict, is_frozen: bool):
        """
        Args:
            kv: [2, B, H, T, Dh] for a single frame (T = tokens_per_frame)
            pos_ids: TensorDict with t_pos [B, T], all equal per frame (ignoring -1)
        """
        T = self.tpf
        t_pos = pos_ids["t_pos"]

        if not torch.compiler.is_compiling():
            torch._check(
                kv.size(3) == self.tpf, "KV cache expects exactly one frame per upsert"
            )
            torch._check(t_pos.shape == (kv.size(1), T), "t_pos must be [B, T]")
            torch._check(self.tpf <= self.L, "frame longer than KV ring capacity")
            torch._check(
                self.L % self.tpf == 0,
                f"L ({self.L}) must be a multiple of tokens_per_frame ({self.tpf})",
            )
            torch._check(
                self.kv.size(3) == self.capacity,
                "KV buffer has unexpected length (expected L + tokens_per_frame)",
            )
            torch._check(
                (t_pos >= 0).all().item(),
                "t_pos must be non-negative during inference",
            )
            torch._check(
                ((t_pos == t_pos[:, :1]).all()).item(),
                "t_pos must be constant within frame",
            )

        frame_t = t_pos[0, 0]

        # map frame_t to a bucket, each bucket owns T contiguous slots
        bucket = (frame_t + (self.pinned_dilation - 1)) // self.pinned_dilation
        slot = bucket % self.num_buckets
        base = slot * T

        # indices in the ring for this frame: [T] in [0, L)
        ring_idx = self.frame_offsets + base

        # Always write current frame into the tail slice [L, L+T):
        # this is the "self-attention component" for the current frame.
        self.kv.index_copy_(3, self.current_idx, kv)

        write_step = frame_t.remainder(self.pinned_dilation) == 0
        mask_written = self.written.clone()
        mask_written[ring_idx] = mask_written[ring_idx] & ~write_step
        bm = make_block_mask(T, self.capacity, mask_written)

        # Persist current frame into the ring for future queries when unfrozen.
        if not is_frozen:
            dst = torch.where(write_step, ring_idx, self.current_idx)
            self.kv.index_copy_(3, dst, kv)
            self.written[dst] = True

        k, v = self.kv.unbind(0)
        return k, v, bm


class StaticKVCache(nn.Module):
    """Static KV cache with per-layer configuration for local/global attention."""

    def __init__(self, config, batch_size, dtype):
        super().__init__()

        self.tpf = config.tokens_per_frame

        local_L = config.local_window * self.tpf
        global_L = config.global_window * self.tpf

        period = config.global_attn_period
        off = getattr(config, "global_attn_offset", 0) % period
        self.layers = nn.ModuleList(
            [
                LayerKVCache(
                    batch_size,
                    getattr(config, "n_kv_heads", config.n_heads),
                    global_L if ((layer_idx - off) % period == 0) else local_L,
                    config.d_model // config.n_heads,
                    dtype,
                    self.tpf,
                    (
                        config.global_pinned_dilation
                        if ((layer_idx - off) % period == 0)
                        else 1
                    ),
                )
                for layer_idx in range(config.n_layers)
            ]
        )

        self._is_frozen = True

    def reset(self):
        for layer in self.layers:
            layer.reset()
        self._is_frozen = True

    def set_frozen(self, is_frozen: bool):
        self._is_frozen = is_frozen

    def upsert(self, k: Tensor, v: Tensor, pos_ids: TensorDict, layer: int):
        kv = torch.stack([k, v], dim=0)
        return self.layers[layer].upsert(kv, pos_ids, self._is_frozen)


class WorldEngineSetTimestepsStep(ModularPipelineBlocks):
    """Sets up the scheduler sigmas for rectified flow denoising."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Sets up scheduler sigmas for rectified flow denoising"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [ConfigSpec("scheduler_sigmas", [1.0, 0.94921875, 0.83984375, 0.0])]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "scheduler_sigmas",
                type_hint=List[float],
                description="Custom scheduler sigmas (overrides config)",
            ),
            InputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "scheduler_sigmas",
                type_hint=torch.Tensor,
                description="Tensor of scheduler sigmas for denoising",
            ),
            OutputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        # Use provided sigmas or fall back to config; always convert to a tensor
        # so the denoise loop can call .diff() on it.
        sigmas = block_state.scheduler_sigmas
        if sigmas is None:
            sigmas = components.config.scheduler_sigmas
        if not isinstance(sigmas, torch.Tensor):
            sigmas = torch.tensor(sigmas, device=device, dtype=dtype)
        block_state.scheduler_sigmas = sigmas

        frame_ts = block_state.frame_timestamp
        if frame_ts is None:
            frame_ts = torch.tensor([[0]], dtype=torch.long, device=device)
        elif isinstance(frame_ts, int):
            frame_ts = torch.tensor([[frame_ts]], dtype=torch.long, device=device)

        block_state.frame_timestamp = frame_ts

        self.set_block_state(state, block_state)
        return components, state


class WorldEngineSetupKVCacheStep(ModularPipelineBlocks):
    """Initializes or reuses the KV cache for autoregressive generation."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Initializes or reuses KV cache for autoregressive frame generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "kv_cache",
                type_hint=Optional[StaticKVCache],
                description="Existing KV cache (will be reused if provided)",
            ),
            InputParam(
                "reset_cache",
                type_hint=bool,
                default=False,
                description="If True, reset the KV cache even if one exists",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "kv_cache",
                type_hint=StaticKVCache,
                description="KV cache for transformer attention",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        # Create or reuse KV cache
        if block_state.kv_cache is None:
            block_state.kv_cache = StaticKVCache(
                components.transformer.config,
                batch_size=1,
                dtype=dtype,
            ).to(device)
        elif block_state.reset_cache:
            block_state.kv_cache.reset()

        self.set_block_state(state, block_state)
        return components, state


class WorldEnginePrepareLatentsStep(ModularPipelineBlocks):
    """Prepares latents for frame generation, optionally encoding an input image."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return (
            "Prepares latents for frame generation. If an image is provided on the "
            "first frame, encodes it and caches it as context. Always creates fresh "
            "random noise for the actual denoising."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict(
                    {
                        "vae_scale_factor": 16,
                        "do_normalize": False,
                        "do_convert_rgb": False,
                    }
                ),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [
            ConfigSpec("channels", 16),
            ConfigSpec("height", 16),
            ConfigSpec("width", 16),
            ConfigSpec("patch", [2, 2]),
            ConfigSpec("vae_scale_factor", 16),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Union[PIL.Image.Image, torch.Tensor],
                description="Input image (PIL Image or [H, W, 3] uint8 tensor), only used on first frame",
            ),
            InputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]. Only used if use_random_latents=False.",
            ),
            InputParam(
                "use_random_latents",
                type_hint=bool,
                default=True,
                description="If True, always generate fresh random latents. If False, use provided latents.",
            ),
            InputParam(
                "kv_cache",
                description="KV cache to update",
            ),
            InputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="Prompt embeddings for cache pass",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Prompt padding mask",
            ),
            InputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                description="Button tensor for cache pass",
            ),
            InputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                description="Mouse tensor for cache pass",
            ),
            InputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                description="Scroll tensor for cache pass",
            ),
            InputParam(
                "generator",
                type_hint=torch.Generator,
                default=None,
                description="torch Generator for deterministic output",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]",
            ),
        ]

    @staticmethod
    def _cache_pass(
        transformer,
        x,
        frame_timestamp,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ):
        """Cache pass to persist frame in KV cache."""
        kv_cache.set_frozen(False)
        transformer(
            x=x,
            sigma=x.new_zeros((x.size(0), x.size(1))),
            frame_timestamp=frame_timestamp,
            prompt_emb=prompt_emb,
            prompt_pad_mask=prompt_pad_mask,
            mouse=mouse,
            button=button,
            scroll=scroll,
            kv_cache=kv_cache,
        )

    @torch.inference_mode()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        # Get latent shape info
        channels = components.config.channels
        height = components.config.height
        width = components.config.width
        patch = components.config.patch

        pH, pW = patch if isinstance(patch, (list, tuple)) else (patch, patch)
        shape = (
            1,
            1,
            channels,
            components.config.vae_scale_factor * pH,
            components.config.vae_scale_factor * pW,
        )

        if block_state.image is not None:
            image = block_state.image
            # Preprocess: PIL/tensor -> [B, C, H, W] float32 in [0, 1]
            image = components.image_processor.preprocess(
                image,
                height=height,
                width=width,
            )
            # Convert to [H, W, 3] uint8 for VAE encoder
            image = (image[0].permute(1, 2, 0) * 255).to(torch.uint8)

            assert image.dtype == torch.uint8, (
                f"Expected uint8 image, got {image.dtype}"
            )

            latents = components.vae.encode(image)
            latents = latents.unsqueeze(1)

            # Run cache pass to persist encoded frame
            self._cache_pass(
                components.transformer,
                latents,
                block_state.frame_timestamp,
                block_state.prompt_embeds,
                block_state.prompt_pad_mask,
                block_state.mouse_tensor,
                block_state.button_tensor,
                block_state.scroll_tensor,
                block_state.kv_cache,
            )
            block_state.frame_timestamp.add_(1)

        # Generate latents based on use_random_latents flag; honor the optional
        # generator input so noise can be made deterministic.
        if block_state.use_random_latents or block_state.latents is None:
            block_state.latents = randn_tensor(
                shape,
                generator=block_state.generator,
                device=device,
                dtype=torch.bfloat16,
            )

        self.set_block_state(state, block_state)
        return components, state


class WorldEngineBeforeDenoiseStep(SequentialPipelineBlocks):
    """Sequential pipeline that prepares all inputs for denoising."""

    block_classes = [
        WorldEngineSetTimestepsStep,
        WorldEngineSetupKVCacheStep,
        WorldEnginePrepareLatentsStep,
    ]
    block_names = ["set_timesteps", "setup_kv_cache", "prepare_latents"]

    @property
    def description(self) -> str:
        return (
            "Before denoise step that prepares inputs for denoising:\n"
            " - WorldEngineSetTimestepsStep: Set up scheduler sigmas\n"
            " - WorldEngineSetupKVCacheStep: Initialize or reuse KV cache\n"
            " - WorldEnginePrepareLatentsStep: Encode image (if first frame) and create noise"
        )
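
Note: the ring-buffer slot arithmetic in `LayerKVCache.upsert` is easiest to see with small numbers. A self-contained toy (sizes chosen purely for illustration):

```python
# Mirrors the bucket/slot math from LayerKVCache.upsert with toy sizes.
tpf, L, pinned_dilation = 4, 16, 2           # tokens per frame, ring capacity, dilation
num_buckets = (L // tpf) // pinned_dilation  # -> 2 buckets, each owning tpf slots

for frame_t in range(6):
    bucket = (frame_t + (pinned_dilation - 1)) // pinned_dilation
    slot = bucket % num_buckets
    base = slot * tpf                           # first ring index owned by this frame
    persists = frame_t % pinned_dilation == 0   # only every pinned_dilation-th frame is kept
    print(f"frame {frame_t}: bucket={bucket} slot={slot} base={base} persists={persists}")
```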
decoders.py ADDED
@@ -0,0 +1,122 @@
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Decoder blocks for WorldEngine modular pipeline."""

from typing import List, Union

import numpy as np
import PIL.Image
import torch

from diffusers import AutoModel
from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import logging
from diffusers.modular_pipelines import (
    ModularPipelineBlocks,
    ModularPipeline,
    PipelineState,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)

logger = logging.get_logger(__name__)


class WorldEngineDecodeStep(ModularPipelineBlocks):
    """Decodes denoised latents back to RGB image using VAE."""

    model_name = "world_engine"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoModel),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict(
                    {
                        "vae_scale_factor": 16,
                        "do_normalize": False,
                        "do_convert_rgb": True,
                    }
                ),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Decodes denoised latents to RGB image using the VAE decoder"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="Denoised latent tensor [1, 1, C, H, W]",
            ),
            InputParam(
                "output_type",
                default="pil",
                description="The output format for the generated images (pil, latent, pt, or np)",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "images",
                type_hint=Union[PIL.Image.Image, torch.Tensor, np.ndarray],
                description="Decoded RGB image in requested output format",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        latents = block_state.latents
        output_type = block_state.output_type or "pil"

        if output_type == "latent":
            block_state.images = latents
        else:
            # Decode to image
            # VAE expects [B, C, H, W] input, squeeze frame dim
            # VAE returns [H, W, 3] uint8 tensor
            image = components.vae.decode(latents.squeeze(1))

            # Postprocess based on output_type
            if output_type == "pt":
                block_state.images = image
            elif output_type == "np":
                block_state.images = image.cpu().numpy()
            else:  # "pil"
                block_state.images = PIL.Image.fromarray(image.cpu().numpy())

        # Clear latents so next frame generates fresh random noise
        block_state.latents = None
        self.set_block_state(state, block_state)
        return components, state
denoise.py ADDED
@@ -0,0 +1,210 @@
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Denoising block for WorldEngine modular pipeline."""

from typing import List

import torch

from diffusers.utils import logging
from diffusers.modular_pipelines import (
    ModularPipelineBlocks,
    ModularPipeline,
    PipelineState,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)
from diffusers import AutoModel

logger = logging.get_logger(__name__)


class WorldEngineDenoiseLoop(ModularPipelineBlocks):
    """Denoises latents using rectified flow and updates KV cache."""

    model_name = "world_engine"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("transformer", AutoModel)]

    @property
    def description(self) -> str:
        return (
            "Denoises latents using rectified flow (x = x + dsigma * v) "
            "and updates KV cache for autoregressive generation."
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "scheduler_sigmas",
                required=True,
                type_hint=torch.Tensor,
                description="Scheduler sigmas for denoising",
            ),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="Initial noisy latents [1, 1, C, H, W]",
            ),
            InputParam(
                "kv_cache",
                required=True,
                description="KV cache for transformer attention",
            ),
            InputParam(
                "frame_timestamp",
                required=True,
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Text embeddings for conditioning",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Padding mask for prompt embeddings",
            ),
            InputParam(
                "button_tensor",
                required=True,
                type_hint=torch.Tensor,
                description="One-hot encoded button tensor",
            ),
            InputParam(
                "mouse_tensor",
                required=True,
                type_hint=torch.Tensor,
                description="Mouse velocity tensor",
            ),
            InputParam(
                "scroll_tensor",
                required=True,
                type_hint=torch.Tensor,
                description="Scroll wheel sign tensor",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Denoised latents",
            ),
        ]

    @staticmethod
    def _denoise_pass(
        transformer,
        x,
        sigmas,
        frame_timestamp,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ):
        """Denoising loop using rectified flow."""
        kv_cache.set_frozen(True)
        sigma = x.new_empty((x.size(0), x.size(1)))
        for step_sig, step_dsig in zip(sigmas, sigmas.diff()):
            v = transformer(
                x=x,
                sigma=sigma.fill_(step_sig),
                frame_timestamp=frame_timestamp,
                prompt_emb=prompt_emb,
                prompt_pad_mask=prompt_pad_mask,
                mouse=mouse,
                button=button,
                scroll=scroll,
                kv_cache=kv_cache,
            )
            x = x + step_dsig * v
        return x

    @staticmethod
    def _cache_pass(
        transformer,
        x,
        frame_timestamp,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ):
        """Cache pass to persist frame for next generation."""
        kv_cache.set_frozen(False)
        transformer(
            x=x,
            sigma=x.new_zeros((x.size(0), x.size(1))),
            frame_timestamp=frame_timestamp,
            prompt_emb=prompt_emb,
            prompt_pad_mask=prompt_pad_mask,
            mouse=mouse,
            button=button,
            scroll=scroll,
            kv_cache=kv_cache,
        )

    @torch.inference_mode()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.latents = self._denoise_pass(
            components.transformer,
            block_state.latents,
            block_state.scheduler_sigmas,
            block_state.frame_timestamp,
            block_state.prompt_embeds,
            block_state.prompt_pad_mask,
            block_state.mouse_tensor,
            block_state.button_tensor,
            block_state.scroll_tensor,
            block_state.kv_cache,
        ).clone()

        self._cache_pass(
            components.transformer,
            block_state.latents,
            block_state.frame_timestamp,
            block_state.prompt_embeds,
            block_state.prompt_pad_mask,
            block_state.mouse_tensor,
            block_state.button_tensor,
            block_state.scroll_tensor,
            block_state.kv_cache,
        )
        block_state.frame_timestamp.add_(1)

        self.set_block_state(state, block_state)
        return components, state
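
Note: `_denoise_pass` is a plain Euler integration of the rectified-flow ODE over the configured sigmas. A toy 1-D illustration of the `x = x + dsigma * v` update (the velocity here is the analytic one for a zero data point, standing in for the transformer's prediction):

```python
import torch

# Sigmas match the defaults in modular_model_index.json.
sigmas = torch.tensor([1.0, 0.8609585762023926, 0.729332447052002, 0.3205108940601349, 0.0])
x = torch.randn(())  # pure noise at sigma = 1; target data point is 0

for step_sig, step_dsig in zip(sigmas, sigmas.diff()):
    v = x / step_sig       # analytic velocity when data = 0 (x_sigma = sigma * noise)
    x = x + step_dsig * v  # dsigma < 0, so each step strips away part of the noise
print(float(x))            # -> 0.0 after the final step
```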
encoders.py ADDED
@@ -0,0 +1,318 @@
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Text and controller encoder blocks for WorldEngine modular pipeline."""

import html
from typing import List, Set, Tuple, Union

import regex as re
import torch
from transformers import AutoTokenizer, UMT5EncoderModel

from diffusers.utils import is_ftfy_available, logging
from diffusers.modular_pipelines import (
    ModularPipelineBlocks,
    ModularPipeline,
    PipelineState,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    ConfigSpec,
    InputParam,
    OutputParam,
)

if is_ftfy_available():
    import ftfy


logger = logging.get_logger(__name__)


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def prompt_clean(text):
    text = whitespace_clean(basic_clean(text))
    return text


class WorldEngineTextEncoderStep(ModularPipelineBlocks):
    """Encodes text prompts using UMT5-XL for conditioning."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return (
            "Text Encoder step that generates text embeddings to guide frame generation"
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", UMT5EncoderModel),
            ComponentSpec("tokenizer", AutoTokenizer),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "prompt",
                description="The prompt or prompts to guide the frame generation",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="Pre-computed text embeddings",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Padding mask for prompt embeddings",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Text embeddings used to guide frame generation",
            ),
            OutputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Padding mask for prompt embeddings",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        if block_state.prompt is not None and (
            not isinstance(block_state.prompt, str)
            and not isinstance(block_state.prompt, list)
        ):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}"
            )

    @staticmethod
    def encode_prompt(
        components,
        prompt: Union[str, List[str]],
        device: torch.device,
        max_sequence_length: int = 512,
    ):
        dtype = components.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [prompt_clean(p) for p in prompt]

        text_inputs = components.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids.to(device)
        attention_mask = text_inputs.attention_mask.to(device)

        prompt_embeds = components.text_encoder(
            text_input_ids, attention_mask
        ).last_hidden_state
        prompt_embeds = prompt_embeds.to(dtype=dtype)

        # Zero out padding
        prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).type_as(
            prompt_embeds
        )

        # Create padding mask (True where padded)
        prompt_pad_mask = attention_mask.eq(0)

        return prompt_embeds, prompt_pad_mask

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        device = components._execution_device
        if block_state.prompt_embeds is None:
            block_state.prompt = block_state.prompt or "An explorable world"
            (
                block_state.prompt_embeds,
                block_state.prompt_pad_mask,
            ) = self.encode_prompt(components, block_state.prompt, device)
            block_state.prompt_embeds = block_state.prompt_embeds.contiguous()

        if block_state.prompt_pad_mask is None:
            block_state.prompt_pad_mask = torch.zeros(
                block_state.prompt_embeds.shape[:2],
                dtype=torch.bool,
                device=device,
            )

        self.set_block_state(state, block_state)
        return components, state


class WorldEngineControllerEncoderStep(ModularPipelineBlocks):
    """Encodes controller inputs (mouse + buttons + scroll) for conditioning."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Controller Encoder step that encodes mouse, button, and scroll inputs for conditioning"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []  # Controller embedding is part of transformer

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [ConfigSpec("n_buttons", 256)]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "button",
                type_hint=Set[int],
                default=set(),
                description="Set of pressed button IDs",
            ),
            InputParam(
                "mouse",
                type_hint=Tuple[float, float],
                default=(0.0, 0.0),
                description="Mouse velocity (x, y)",
            ),
            InputParam(
                "scroll",
                type_hint=int,
                default=0,
                description="Scroll wheel direction (-1, 0, 1)",
            ),
            InputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="One-hot encoded button tensor",
            ),
            InputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Mouse velocity tensor",
            ),
            InputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Scroll wheel sign tensor",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="One-hot encoded button tensor",
            ),
            OutputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Mouse velocity tensor",
            ),
            OutputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Scroll wheel sign tensor",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        n_buttons = components.config.n_buttons

        # Create or reuse button tensor [1, 1, n_buttons]
        if block_state.button_tensor is None:
            block_state.button_tensor = torch.zeros(
                (1, 1, n_buttons), device=device, dtype=dtype
            )

        # Update button tensor in-place (avoid dynamic shapes for torch.compile)
        block_state.button_tensor.zero_()
        if block_state.button:
            for btn_id in block_state.button:
                if 0 <= btn_id < n_buttons:
                    block_state.button_tensor[0, 0, btn_id] = 1.0

        # Create or reuse mouse tensor [1, 1, 2]
        if block_state.mouse_tensor is None:
            block_state.mouse_tensor = torch.zeros(
                (1, 1, 2), device=device, dtype=dtype
            )

        # Update mouse tensor in-place
        mouse = block_state.mouse if block_state.mouse is not None else (0.0, 0.0)
        block_state.mouse_tensor[0, 0, 0] = mouse[0]
        block_state.mouse_tensor[0, 0, 1] = mouse[1]

        # Create or reuse scroll tensor [1, 1, 1]
        if block_state.scroll_tensor is None:
            block_state.scroll_tensor = torch.zeros(
                (1, 1, 1), device=device, dtype=dtype
            )

        # Update scroll tensor in-place (sign of scroll value: -1, 0, or 1)
        scroll = block_state.scroll if block_state.scroll is not None else 0
        block_state.scroll_tensor[0, 0, 0] = float(scroll > 0) - float(scroll < 0)

        self.set_block_state(state, block_state)
        return components, state
modular_blocks.py ADDED
@@ -0,0 +1,45 @@
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Block registry for WorldEngine modular pipeline."""

from diffusers.utils import logging
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.modular_pipeline_utils import InsertableDict

from .encoders import WorldEngineTextEncoderStep, WorldEngineControllerEncoderStep
from .before_denoise import WorldEngineBeforeDenoiseStep
from .denoise import WorldEngineDenoiseLoop
from .decoders import WorldEngineDecodeStep

logger = logging.get_logger(__name__)


AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", WorldEngineTextEncoderStep),
        ("controller_encoder", WorldEngineControllerEncoderStep),
        ("before_denoise", WorldEngineBeforeDenoiseStep),
        ("denoise", WorldEngineDenoiseLoop),
        ("decode", WorldEngineDecodeStep),
    ]
)


class WorldEngineBlocks(SequentialPipelineBlocks):
    """Sequential pipeline blocks for WorldEngine frame generation."""

    block_classes = list(AUTO_BLOCKS.copy().values())
    block_names = list(AUTO_BLOCKS.copy().keys())
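
Note: because `AUTO_BLOCKS` is an `InsertableDict`, individual steps can be swapped before assembly. A hedged sketch (`MyDecodeStep` is a hypothetical custom block, not part of this repo):

```python
from diffusers.modular_pipelines import SequentialPipelineBlocks

blocks_dict = AUTO_BLOCKS.copy()
blocks_dict["decode"] = MyDecodeStep  # hypothetical ModularPipelineBlocks subclass
pipe_blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)
```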
modular_config.json ADDED
@@ -0,0 +1,7 @@
{
  "_class_name": "WorldEngineBlocks",
  "_diffusers_version": "0.36.0.dev0",
  "auto_map": {
    "ModularPipelineBlocks": "modular_blocks.WorldEngineBlocks"
  }
}
modular_model_index.json ADDED
@@ -0,0 +1,76 @@
{
  "_blocks_class_name": "WorldEngineBlocks",
  "_class_name": "ModularPipeline",
  "_diffusers_version": "0.36.0.dev0",
  "channels": 16,
  "height": 360,
  "width": 640,
  "patch": [
    2,
    2
  ],
  "vae_scale_factor": 16,
  "n_buttons": 256,
  "tokens_per_frame": 256,
  "scheduler_sigmas": [
    1.0,
    0.8609585762023926,
    0.729332447052002,
    0.3205108940601349,
    0.0
  ],
  "transformer": [
    null,
    null,
    {
      "pretrained_model_name_or_path": "Overworld/Waypoint-1-Small",
      "subfolder": "transformer",
      "type_hint": [
        "diffusers",
        "AutoModel"
      ],
      "revision": null,
      "variant": null
    }
  ],
  "vae": [
    null,
    null,
    {
      "pretrained_model_name_or_path": "Overworld/Waypoint-1-Small",
      "subfolder": "vae",
      "type_hint": [
        "diffusers",
        "AutoModel"
      ],
      "revision": null,
      "variant": null
    }
  ],
  "text_encoder": [
    null,
    null,
    {
      "pretrained_model_name_or_path": "google/umt5-xl",
      "type_hint": [
        "transformers",
        "UMT5EncoderModel"
      ],
      "revision": null,
      "variant": null
    }
  ],
  "tokenizer": [
    null,
    null,
    {
      "pretrained_model_name_or_path": "google/umt5-xl",
      "type_hint": [
        "transformers",
        "AutoTokenizer"
      ],
      "revision": null,
      "variant": null
    }
  ]
}
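
Note: the `null` slots in each component entry are filled in when `load_components` runs; components can also be created manually and swapped in afterwards. A sketch assuming `ModularPipeline.update_components` (the Modular Pipelines method for replacing loaded components) and the `pipe` from the README usage example:

```python
import torch
from transformers import UMT5EncoderModel

text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xl", torch_dtype=torch.bfloat16)
pipe.update_components(text_encoder=text_encoder)
```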
transformer/__init__.py ADDED
@@ -0,0 +1,31 @@
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from .model import WorldModel
from .attn import Attn, CrossAttention, OrthoRoPE
from .nn import MLP, AdaLN, NoiseConditioner, rms_norm, ada_rmsnorm, ada_gate

__all__ = [
    "WorldModel",
    "Attn",
    "CrossAttention",
    "OrthoRoPE",
    "MLP",
    "AdaLN",
    "NoiseConditioner",
    "rms_norm",
    "ada_rmsnorm",
    "ada_gate",
]
transformer/attn.py ADDED
@@ -0,0 +1,297 @@
+ # Copyright (C) 2025 Hugging Face Team and Overworld
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ """Attention mechanisms for WorldModel transformer."""
+
+ import math
+
+ import einops as eo
+ import torch
+ from torch import nn
+ from torch.nn.attention.flex_attention import flex_attention
+
+ from .nn import rms_norm, NoCastModule
+
+
+ def pixel_frequencies(dim: int, max_freq: float) -> torch.Tensor:
+     """Linear frequency spectrum for spatial RoPE (pixel positions).
+
+     Matches rotary_embedding_torch RotaryEmbedding(freqs_for='pixel').
+
+     Args:
+         dim: Output dimension (freqs will be repeated to fill this)
+         max_freq: Maximum frequency (should be below Nyquist)
+
+     Returns:
+         Tensor of shape [dim // 2] with linear frequencies
+     """
+     # Library uses max_freq/2 as the upper bound
+     return torch.linspace(1.0, max_freq / 2, dim // 2) * math.pi
+
+
+ def lang_frequencies(dim: int) -> torch.Tensor:
+     """Geometric frequency spectrum for temporal RoPE (language-style).
+
+     Matches rotary_embedding_torch RotaryEmbedding(freqs_for='lang').
+
+     Args:
+         dim: Output dimension (freqs will be repeated to fill this)
+
+     Returns:
+         Tensor of shape [dim // 2] with geometric frequencies
+     """
+     # Library uses 10^(-i/2) pattern
+     return 10.0 ** (-torch.arange(dim // 2).float() / 2)
+
+
+ class OrthoRoPE(NoCastModule):
+     """Rotary Position Embeddings for orthogonal axes: time, height, and width.
+
+     - Time: Geometric spectrum (like language models) -- rotates 1/2 of head dim
+     - Height/Width: Linear spectrum (for pixels) -- rotates 1/4 of head dim each
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         assert not getattr(self.config, "has_audio", False)
+
+         # Compute frequencies and store cos/sin buffers
+         freqs = self._compute_freqs()
+         self.cos = nn.Buffer(freqs.cos().contiguous(), persistent=False)
+         self.sin = nn.Buffer(freqs.sin().contiguous(), persistent=False)
+
+     def _compute_freqs(self):
+         """Compute frequency table for all positions.
+
+         Matches the behavior of rotary_embedding_torch.RotaryEmbedding.
+         The library interleaves frequencies so each freq value is used twice.
+         """
+         config = self.config
+         H, W, T = config.height, config.width, config.n_frames
+         head_dim = config.d_model // config.n_heads
+
+         # Spatial frequencies (linear spectrum, below Nyquist)
+         # Library: RotaryEmbedding(dim=head_dim//8) creates head_dim//16 freqs,
+         # outputs head_dim//8 values (each freq repeated twice)
+         max_freq = min(H, W) * 0.8
+         spatial_freqs = pixel_frequencies(head_dim // 8, max_freq)  # [D/16]
+
+         # Positions in [-1, 1] range
+         pos_x = torch.linspace(-1 + 1 / W, 1 - 1 / W, W)  # [W]
+         pos_y = torch.linspace(-1 + 1 / H, 1 - 1 / H, H)  # [H]
+
+         # Spatial frequency embeddings with interleaving (like library)
+         freqs_x = torch.outer(pos_x, spatial_freqs)  # [W, D/16]
+         freqs_y = torch.outer(pos_y, spatial_freqs)  # [H, D/16]
+         freqs_x = freqs_x.repeat_interleave(2, dim=-1)  # [W, D/8]
+         freqs_y = freqs_y.repeat_interleave(2, dim=-1)  # [H, D/8]
+
+         # Expand to grid and repeat for all frames
+         freqs_x = freqs_x[None, :, :].expand(H, W, -1)  # [H, W, D/8]
+         freqs_y = freqs_y[:, None, :].expand(H, W, -1)  # [H, W, D/8]
+
+         freqs_x = eo.repeat(freqs_x, "h w d -> (t h w) d", t=T)  # [T*H*W, D/8]
+         freqs_y = eo.repeat(freqs_y, "h w d -> (t h w) d", t=T)  # [T*H*W, D/8]
+
+         # Temporal frequencies (geometric spectrum)
+         # Library: RotaryEmbedding(dim=head_dim//4) creates head_dim//8 freqs,
+         # outputs head_dim//4 values (each freq repeated twice)
+         temporal_freqs = lang_frequencies(head_dim // 4)  # [D/8]
+         pos_t = torch.arange(T).float()  # [T]
+         freqs_t = torch.outer(pos_t, temporal_freqs)  # [T, D/8]
+         freqs_t = freqs_t.repeat_interleave(2, dim=-1)  # [T, D/4]
+         freqs_t = eo.repeat(freqs_t, "t d -> (t h w) d", h=H, w=W)  # [T*H*W, D/4]
+
+         # Concatenate: [X, Y, T] -> [T*H*W, D/2]
+         return torch.cat([freqs_x, freqs_y, freqs_t], dim=-1)
+
+     def get_angles(self, pos_ids):
+         """Look up cos/sin angles for given position IDs."""
+         t, y, x = pos_ids["t_pos"], pos_ids["y_pos"], pos_ids["x_pos"]  # [B,T]
+         H, W = self.config.height, self.config.width
+         if not torch.compiler.is_compiling():
+             torch._assert(
+                 (y.max() < H) & (x.max() < W),
+                 f"pos_ids out of bounds, {y.max()}, {x.max()}",
+             )
+         flat = t * (H * W) + y * W + x  # [B,T]
+         idx = flat.reshape(-1).to(torch.long)
+         cos = self.cos.index_select(0, idx).view(*flat.shape, -1)
+         sin = self.sin.index_select(0, idx).view(*flat.shape, -1)
+         return cos[:, None], sin[:, None]  # add head dim for broadcast
+
+     @torch.autocast("cuda", enabled=False)
+     def forward(self, x, pos_ids):
+         assert self.cos.dtype == self.sin.dtype == torch.float32
+         cos, sin = self.get_angles(pos_ids)
+         x0, x1 = x.float().unfold(-1, 2, 2).unbind(-1)
+         y0 = x0 * cos - x1 * sin
+         y1 = x1 * cos + x0 * sin
+         return torch.cat((y0, y1), dim=-1).type_as(x)
+
+
+ class Attn(nn.Module):
+     """Self-attention with RoPE and optional GQA, value residual, and gated attention."""
+
+     def __init__(self, config, layer_idx):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+
+         self.value_residual = getattr(config, "value_residual", False)
+         if self.value_residual:
+             self.v_lamb = nn.Parameter(torch.tensor(0.5))
+
+         self.n_heads = config.n_heads
+         self.n_kv_heads = getattr(config, "n_kv_heads", config.n_heads)
+         self.d_head = config.d_model // self.n_heads
+         assert config.d_model % self.n_heads == 0
+
+         self.enable_gqa = self.n_heads != self.n_kv_heads
+
+         self.q_proj = nn.Linear(config.d_model, self.n_heads * self.d_head, bias=False)
+         self.k_proj = nn.Linear(
+             config.d_model, self.n_kv_heads * self.d_head, bias=False
+         )
+         self.v_proj = nn.Linear(
+             config.d_model, self.n_kv_heads * self.d_head, bias=False
+         )
+         self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
+
+         self.rope = OrthoRoPE(config)
+
+         self.gated_attn = getattr(config, "gated_attn", False)
+         if self.gated_attn:
+             self.gate_proj = nn.Linear(
+                 self.n_heads, self.n_heads, bias=False
+             )  # sparse attn gate
+             nn.init.zeros_(self.gate_proj.weight)
+
+     def forward(self, x, pos_ids, v1, kv_cache):
+         # Q, K, V proj -> QK-norm -> RoPE
+         q = eo.rearrange(
+             self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads, d=self.d_head
+         )
+         k = eo.rearrange(
+             self.k_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head
+         )
+         v = eo.rearrange(
+             self.v_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head
+         )
+
+         if self.value_residual:
+             v1 = v if v1 is None else v1
+             v = torch.lerp(v, v1.view_as(v), self.v_lamb)
+
+         q, k = rms_norm(q), rms_norm(k)
+         q, k = self.rope(q, pos_ids), self.rope(k, pos_ids)
+
+         k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
+         y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)
+
+         if self.gated_attn:
+             gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
+             y = y * gates.permute(0, 2, 1).unsqueeze(-1)
+         y = eo.rearrange(y, "b h t d -> b t (h d)")
+         y = self.out_proj(y)
+         return y, v1
+
+
+ class MergedQKVAttn(Attn):
+     def __init__(self, src: Attn, config):
+         super().__init__(config, src.layer_idx)  # makes fresh q/k/v/out/etc
+         self.to(device=src.q_proj.weight.device, dtype=src.q_proj.weight.dtype)
+         self.load_state_dict(
+             src.state_dict(), strict=False
+         )  # copies trained weights/buffers
+         self.train(src.training)  # preserve train/eval mode
+
+         self.q_out = self.n_heads * self.d_head
+         self.kv_out = self.n_kv_heads * self.d_head
+
+         self.qkv_proj = nn.Linear(
+             self.q_proj.in_features,
+             self.q_out + 2 * self.kv_out,
+             bias=False,
+             device=self.q_proj.weight.device,
+             dtype=self.q_proj.weight.dtype,
+         )
+         with torch.no_grad():
+             self.qkv_proj.weight.copy_(
+                 torch.cat(
+                     [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight], dim=0
+                 )
+             )
+
+         del self.q_proj, self.k_proj, self.v_proj
+
+     def forward(self, x, pos_ids, v1, kv_cache):
+         q, k, v = self.qkv_proj(x).split((self.q_out, self.kv_out, self.kv_out), dim=-1)
+
+         B, T = x.shape[:2]
+         q = q.reshape(B, T, self.n_heads, self.d_head).transpose(1, 2)
+         k = k.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
+         v = v.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
+
+         if self.value_residual:
+             v1 = v if v1 is None else v1
+             v = torch.lerp(v, v1.view_as(v), self.v_lamb)
+
+         q, k = rms_norm(q), rms_norm(k)
+         q, k = self.rope(q, pos_ids), self.rope(k, pos_ids)
+
+         k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
+         y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)
+
+         if self.gated_attn:
+             gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
+             y = y * gates.permute(0, 2, 1).unsqueeze(-1)
+
+         y = y.transpose(1, 2).reshape(B, T, -1)
+         y = self.out_proj(y)
+         return y, v1
+
+
+ class CrossAttention(nn.Module):
+     """Cross-attention for prompt conditioning."""
+
+     def __init__(self, config, context_dim=None):
+         super().__init__()
+         assert config.d_model % config.n_heads == 0
+
+         self.d_head = config.d_model // config.n_heads
+         self.inner_dim = context_dim or config.d_model
+         assert self.inner_dim % self.d_head == 0
+         self.n_heads = self.inner_dim // self.d_head
+         self.q_proj = nn.Linear(config.d_model, self.inner_dim, bias=False)
+         self.k_proj = nn.Linear(
+             context_dim or config.d_model, self.inner_dim, bias=False
+         )
+         self.v_proj = nn.Linear(
+             context_dim or config.d_model, self.inner_dim, bias=False
+         )
+
+         self.out_proj = nn.Linear(self.inner_dim, config.d_model, bias=False)
+         self.out_proj.weight.detach().zero_()
+
+     def forward(self, x, context, context_pad_mask=None):
+         q = eo.rearrange(self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads)
+         k = eo.rearrange(self.k_proj(context), "b t (h d) -> b h t d", h=self.n_heads)
+         v = eo.rearrange(self.v_proj(context), "b t (h d) -> b h t d", h=self.n_heads)
+         q, k = rms_norm(q), rms_norm(k)
+         out = flex_attention(q, k, v)
+         out = out.transpose(1, 2).contiguous().reshape(x.size(0), x.size(1), -1)
+         return self.out_proj(out)
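
A quick sanity check of the OrthoRoPE layout with this repo's numbers (d_model=2560, n_heads=40, so head_dim=64): x and y each rotate head_dim//8 dims and time rotates head_dim//4, which together cover every rotation pair in the head. This is an illustrative standalone snippet, not part of the committed code.

```python
import math
import torch

head_dim = 2560 // 40  # 64
# pixel_frequencies(head_dim // 8, ...) -> (head_dim // 8) // 2 raw freqs per spatial axis
spatial = torch.linspace(1.0, (16 * 0.8) / 2, (head_dim // 8) // 2) * math.pi
# lang_frequencies(head_dim // 4) -> (head_dim // 4) // 2 raw temporal freqs
temporal = 10.0 ** (-torch.arange((head_dim // 4) // 2).float() / 2)

# After repeat_interleave(2): x and y rotate 8 dims each, time rotates 16,
# so 8 + 8 + 16 = 32 angles -- one per (x0, x1) pair of the 64-dim head.
angles = spatial.numel() * 2 * 2 + temporal.numel() * 2
assert angles == head_dim // 2
```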
transformer/cache.py ADDED
@@ -0,0 +1,112 @@
+ # Copyright (C) 2025 Hugging Face Team and Overworld
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ import torch
+ from torch import nn, Tensor
+
+
+ def _bf16_u16(x: Tensor) -> Tensor:
+     # reinterpret bf16 storage as int16 -> unsigned 0..65535 in int32
+     return x.contiguous().view(torch.int16).to(torch.int32) & 0xFFFF
+
+
+ class CachedDenoiseStepEmb(nn.Module):
+     """bf16 sigma -> bf16 embedding via a 64k LUT; an invalid sigma raises an out-of-bounds index error rather than silently returning a wrong embedding."""
+
+     def __init__(self, base: nn.Module, sigmas: list[float]):
+         super().__init__()
+         device = next(base.parameters()).device
+
+         levels = torch.tensor(sigmas, device=device, dtype=torch.bfloat16)  # [S]
+         bits = _bf16_u16(levels)  # [S]
+         if torch.unique(bits).numel() != bits.numel():
+             raise ValueError(
+                 "scheduler_sigmas collide in bf16; caching would be ambiguous"
+             )
+
+         with torch.no_grad():
+             table = (
+                 base(levels[:, None]).squeeze(1).to(torch.bfloat16).contiguous()
+             )  # [S,D]
+
+         lut = torch.full((65536,), -1, device=device, dtype=torch.int32)
+         lut[bits] = torch.arange(bits.numel(), device=device, dtype=torch.int32)
+
+         self.register_buffer("table", table, persistent=False)  # [S,D] bf16
+         self.register_buffer("lut", lut, persistent=False)  # [65536] int32
+         self.register_buffer(
+             "oob",
+             torch.tensor(bits.numel(), device=device, dtype=torch.int32),
+             persistent=False,
+         )
+
+     def forward(self, sigma: Tensor) -> Tensor:
+         if sigma.dtype is not torch.bfloat16:
+             raise RuntimeError("CachedDenoiseStepEmb expects sigma bf16")
+         idx = self.lut[_bf16_u16(sigma)]
+         idx = torch.where(idx >= 0, idx, self.oob)  # invalid -> S (OOB)
+         return self.table[idx.to(torch.int64)]  # [...,D] bf16
+
+
+ class CachedCondHead(nn.Module):
+     """bf16 cond -> cached (s0,b0,g0,s1,b1,g1); an invalid cond raises an out-of-bounds index error rather than silently returning wrong values."""
+
+     def __init__(
+         self, base, cached_denoise_step_emb: CachedDenoiseStepEmb, max_key_dims: int = 8
+     ):
+         super().__init__()
+         table = cached_denoise_step_emb.table  # [S,D] bf16
+         S, D = table.shape
+
+         with torch.no_grad():
+             emb = table[:, None, :]  # [S,1,D]
+             cache = (
+                 torch.stack([t.squeeze(1) for t in base(emb)], 0)
+                 .to(torch.bfloat16)
+                 .contiguous()
+             )  # [6,S,D]
+
+         # pick a single embedding dimension whose bf16 bits uniquely identify sigma
+         key_dim = None
+         for d in range(min(D, max_key_dims)):
+             b = _bf16_u16(table[:, d])
+             if torch.unique(b).numel() == S:
+                 key_dim = d
+                 key_bits = b
+                 break
+         if key_dim is None:
+             raise ValueError(
+                 "Could not find a unique bf16 key dim for cond->sigma mapping; increase max_key_dims"
+             )
+
+         lut = torch.full((65536,), -1, device=table.device, dtype=torch.int32)
+         lut[key_bits] = torch.arange(S, device=table.device, dtype=torch.int32)
+
+         self.key_dim = int(key_dim)
+         self.register_buffer("cache", cache, persistent=False)  # [6,S,D] bf16
+         self.register_buffer("lut", lut, persistent=False)  # [65536] int32
+         self.register_buffer(
+             "oob",
+             torch.tensor(S, device=table.device, dtype=torch.int32),
+             persistent=False,
+         )
+
+     def forward(self, cond: Tensor):
+         if cond.dtype is not torch.bfloat16:
+             raise RuntimeError("CachedCondHead expects cond bf16")
+         idx = self.lut[_bf16_u16(cond[..., self.key_dim])]
+         idx = torch.where(idx >= 0, idx, self.oob)  # invalid -> S (OOB)
+         g = self.cache[:, idx.to(torch.int64)]  # [6,...,D] bf16 (or errors)
+         return tuple(g.unbind(0))  # (s0,b0,g0,s1,b1,g1)
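
The bit trick above is easiest to see in isolation: a bf16 value has exactly 16 bits, so its bit pattern can index a 65536-entry lookup table directly. A standalone illustration, not part of the committed code:

```python
import torch

sigmas = torch.tensor([1.0, 0.8609585762023926, 0.0], dtype=torch.bfloat16)
bits = sigmas.view(torch.int16).to(torch.int32) & 0xFFFF  # unique u16 keys

lut = torch.full((65536,), -1, dtype=torch.int32)
lut[bits] = torch.arange(bits.numel(), dtype=torch.int32)

query = torch.tensor(0.8609585762023926, dtype=torch.bfloat16)
idx = lut[int(query.view(torch.int16)) & 0xFFFF]
assert idx == 1  # exact bf16 bit match -> cached row; any other sigma maps to -1
```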
transformer/config.json ADDED
@@ -0,0 +1,49 @@
+ {
+   "_class_name": "WorldModel",
+   "_diffusers_version": "0.36.0.dev0",
+   "auto_map": {
+     "AutoModel": "model.WorldModel"
+   },
+   "d_model": 2560,
+   "n_heads": 40,
+   "n_kv_heads": 20,
+   "n_layers": 22,
+   "mlp_ratio": 5,
+   "channels": 16,
+   "height": 16,
+   "width": 16,
+   "patch": [
+     2,
+     2
+   ],
+   "tokens_per_frame": 256,
+   "n_frames": 4096,
+   "local_window": 16,
+   "global_window": 128,
+   "global_attn_period": 4,
+   "global_pinned_dilation": 8,
+   "global_attn_offset": 0,
+   "value_residual": false,
+   "gated_attn": true,
+   "n_buttons": 256,
+   "ctrl_conditioning": "mlp_fusion",
+   "ctrl_conditioning_period": 3,
+   "ctrl_cond_dropout": 0.0,
+   "prompt_conditioning": "cross_attention",
+   "prompt_conditioning_period": 3,
+   "prompt_embedding_dim": 2048,
+   "prompt_cond_dropout": 0.0,
+   "noise_conditioning": "wan",
+   "base_fps": 60,
+   "causal": true,
+   "mlp_gradient_checkpointing": true,
+   "block_gradient_checkpointing": true,
+   "rope_impl": "ortho",
+   "scheduler_sigmas": [
+     1.0,
+     0.8609585762023926,
+     0.729332447052002,
+     0.3205108940601349,
+     0.0
+   ]
+ }
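
One consequence of this config worth spelling out: with n_layers=22 and both conditioning periods set to 3, only layers where layer_idx % 3 == 0 carry prompt cross-attention and controller fusion (see WorldDiTBlock below). A quick derivation, not part of the repo:

```python
n_layers, period = 22, 3
conditioned = [i for i in range(n_layers) if i % period == 0]
print(conditioned)  # [0, 3, 6, 9, 12, 15, 18, 21] -> 8 of the 22 blocks
```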
transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:14356db9229453850f9ad650f31c3e1c4744066abd43562f6fbee161fb36c9e6
+ size 12515075376
transformer/model.py ADDED
@@ -0,0 +1,452 @@
+ # Copyright (C) 2025 Hugging Face Team and Overworld
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ """WorldModel transformer for frame generation."""
+
+ from typing import Optional, List
+ import math
+
+ import einops as eo
+ import torch
+ from torch import nn, Tensor
+ import torch.nn.functional as F
+ from tensordict import TensorDict
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.modeling_utils import ModelMixin
+
+ from .attn import Attn, MergedQKVAttn, CrossAttention
+ from .nn import AdaLN, MLP, NoiseConditioner, ada_gate, ada_rmsnorm, rms_norm
+ from .quantize import quantize_model
+ from .cache import CachedDenoiseStepEmb, CachedCondHead
+
+
+ def patch_cached_noise_conditioning(model) -> None:
+     # Call AFTER: model.to(device="cuda", dtype=torch.bfloat16).eval()
+     cached_denoise_step_emb = CachedDenoiseStepEmb(
+         model.denoise_step_emb, model.config.scheduler_sigmas
+     )
+     model.denoise_step_emb = cached_denoise_step_emb
+     for blk in model.transformer.blocks:
+         blk.cond_head = CachedCondHead(blk.cond_head, cached_denoise_step_emb)
+
+
+ def patch_Attn_merge_qkv(model) -> None:
+     for name, mod in list(model.named_modules()):
+         if isinstance(mod, Attn) and not isinstance(mod, MergedQKVAttn):
+             model.set_submodule(name, MergedQKVAttn(mod, model.config))
+
+
+ def patch_MLPFusion_split(model) -> None:
+     for name, mod in list(model.named_modules()):
+         if isinstance(mod, MLPFusion) and not isinstance(mod, SplitMLPFusion):
+             model.set_submodule(name, SplitMLPFusion(mod))
+
+
+ def _apply_inference_patches(model) -> None:
+     patch_cached_noise_conditioning(model)
+     patch_Attn_merge_qkv(model)
+     patch_MLPFusion_split(model)
+
+
+ class CFG(nn.Module):
+     def __init__(self, d_model: int, dropout: float):
+         super().__init__()
+         self.dropout = dropout
+         self.null_emb = nn.Parameter(torch.zeros(1, 1, d_model))
+
+     def forward(
+         self, x: torch.Tensor, is_conditioned: Optional[bool] = None
+     ) -> torch.Tensor:
+         """
+         x: [B, L, D]
+         is_conditioned:
+             - None: training-style random dropout
+             - bool: whole batch conditioned / unconditioned at sampling
+         """
+         B, L, _ = x.shape
+         null = self.null_emb.expand(B, L, -1)
+
+         # training-style dropout OR unspecified
+         if self.training or is_conditioned is None:
+             if self.dropout == 0.0:
+                 return x
+             drop = torch.rand(B, 1, 1, device=x.device) < self.dropout  # [B,1,1]
+             return torch.where(drop, null, x)
+
+         # sampling-time switch
+         return x if is_conditioned else null
+
+
+ class ControllerInputEmbedding(nn.Module):
+     """Embeds controller inputs (mouse + buttons) into model dimension."""
+
+     def __init__(self, n_buttons: int, d_model: int, mlp_ratio: int = 4):
+         super().__init__()
+         # inputs: n_buttons + mouse velocity (x, y) + scroll sign
+         self.mlp = MLP(n_buttons + 3, d_model * mlp_ratio, d_model)
+
+     def forward(self, mouse: Tensor, button: Tensor, scroll: Tensor):
+         assert len(mouse.shape) == 3
+         x = torch.cat((mouse, button, scroll), dim=-1)
+         return self.mlp(x)
+
+
+ class MLPFusion(nn.Module):
+     """Fuses per-group conditioning into tokens by applying an MLP to cat([x, cond])."""
+
+     def __init__(self, d_model: int):
+         super().__init__()
+         self.mlp = MLP(2 * d_model, d_model, d_model)
+
+     def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
+         B, _, D = x.shape
+         L = cond.shape[1]
+
+         Wx, Wc = self.mlp.fc1.weight.chunk(2, dim=1)  # each [D, D]
+
+         x = x.view(B, L, -1, D)
+         h = F.linear(x, Wx) + F.linear(cond, Wc).unsqueeze(2)  # broadcast, no repeat/cat
+         h = F.silu(h)
+         y = F.linear(h, self.mlp.fc2.weight)
+         return y.flatten(1, 2)
+
+
+ class SplitMLPFusion(nn.Module):
+     """Packed MLPFusion -> split linears (no cat, quant-friendly)."""
+
+     def __init__(self, src: MLPFusion):
+         super().__init__()
+         D = src.mlp.fc2.in_features
+         dev, dt = src.mlp.fc2.weight.device, src.mlp.fc2.weight.dtype
+
+         self.fc1_x = nn.Linear(D, D, bias=False, device=dev, dtype=dt)
+         self.fc1_c = nn.Linear(D, D, bias=False, device=dev, dtype=dt)
+         self.fc2 = nn.Linear(D, D, bias=False, device=dev, dtype=dt)
+
+         with torch.no_grad():
+             Wx, Wc = src.mlp.fc1.weight.chunk(2, dim=1)
+             self.fc1_x.weight.copy_(Wx)
+             self.fc1_c.weight.copy_(Wc)
+             self.fc2.weight.copy_(src.mlp.fc2.weight)
+
+         self.train(src.training)
+
+     def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
+         B, _, D = x.shape
+         L = cond.shape[1]
+         x = x.reshape(B, L, -1, D)
+         return self.fc2(F.silu(self.fc1_x(x) + self.fc1_c(cond).unsqueeze(2))).flatten(
+             1, 2
+         )
+
+
+ class CondHead(nn.Module):
+     """Per-layer conditioning head: bias_in -> SiLU -> Linear -> chunk(n_cond)."""
+
+     n_cond = 6
+
+     def __init__(self, d_model: int, noise_conditioning: str = "wan"):
+         super().__init__()
+         self.bias_in = (
+             nn.Parameter(torch.zeros(d_model)) if noise_conditioning == "wan" else None
+         )
+         self.cond_proj = nn.ModuleList(
+             [nn.Linear(d_model, d_model, bias=False) for _ in range(self.n_cond)]
+         )
+
+     def forward(self, cond):
+         cond = cond + self.bias_in if self.bias_in is not None else cond
+         h = F.silu(cond)
+         return tuple(p(h) for p in self.cond_proj)
+
+
+ class WorldDiTBlock(nn.Module):
+     """Single transformer block with self-attention, optional cross-attention, and MLP."""
+
+     def __init__(
+         self,
+         d_model: int,
+         n_heads: int,
+         mlp_ratio: int,
+         layer_idx: int,
+         prompt_conditioning: Optional[str],
+         prompt_conditioning_period: int,
+         prompt_embedding_dim: int,
+         ctrl_conditioning_period: int,
+         noise_conditioning: str,
+         config,
+     ):
+         super().__init__()
+         self.config = config
+         self.attn = Attn(config, layer_idx)
+         self.mlp = MLP(d_model, d_model * mlp_ratio, d_model)
+         self.cond_head = CondHead(d_model, noise_conditioning)
+
+         do_prompt_cond = (
+             prompt_conditioning is not None
+             and layer_idx % prompt_conditioning_period == 0
+         )
+         self.prompt_cross_attn = (
+             CrossAttention(config, prompt_embedding_dim) if do_prompt_cond else None
+         )
+         do_ctrl_cond = layer_idx % ctrl_conditioning_period == 0
+         self.ctrl_mlpfusion = MLPFusion(d_model) if do_ctrl_cond else None
+
+     def forward(self, x, pos_ids, cond, ctx, v, kv_cache=None):
+         """
+         0) Causal Frame Attention
+         1) Frame->CTX Cross Attention
+         2) MLP
+         """
+         s0, b0, g0, s1, b1, g1 = self.cond_head(cond)
+
+         # Self / Causal Attention
+         residual = x
+         x = ada_rmsnorm(x, s0, b0)
+         x, v = self.attn(x, pos_ids, v, kv_cache=kv_cache)
+         x = ada_gate(x, g0) + residual
+
+         # Cross Attention Prompt Conditioning
+         if self.prompt_cross_attn is not None:
+             x = (
+                 self.prompt_cross_attn(
+                     rms_norm(x),
+                     context=rms_norm(ctx["prompt_emb"]),
+                     context_pad_mask=ctx["prompt_pad_mask"],
+                 )
+                 + x
+             )
+
+         # MLPFusion Controller Conditioning
+         if self.ctrl_mlpfusion is not None:
+             x = self.ctrl_mlpfusion(rms_norm(x), rms_norm(ctx["ctrl_emb"])) + x
+
+         # MLP
+         x = ada_gate(self.mlp(ada_rmsnorm(x, s1, b1)), g1) + x
+
+         return x, v
+
+
+ class WorldDiT(nn.Module):
+     """Stack of WorldDiTBlocks with shared parameters."""
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.blocks = nn.ModuleList(
+             [
+                 WorldDiTBlock(
+                     d_model=config.d_model,
+                     n_heads=config.n_heads,
+                     mlp_ratio=config.mlp_ratio,
+                     layer_idx=idx,
+                     prompt_conditioning=config.prompt_conditioning,
+                     prompt_conditioning_period=config.prompt_conditioning_period,
+                     prompt_embedding_dim=config.prompt_embedding_dim,
+                     ctrl_conditioning_period=config.ctrl_conditioning_period,
+                     noise_conditioning=config.noise_conditioning,
+                     config=config,
+                 )
+                 for idx in range(config.n_layers)
+             ]
+         )
+
+         if config.noise_conditioning in ("dit_air", "wan"):
+             ref_proj = self.blocks[0].cond_head.cond_proj
+             for blk in self.blocks[1:]:
+                 for blk_mod, ref_mod in zip(blk.cond_head.cond_proj, ref_proj):
+                     blk_mod.weight = ref_mod.weight
+
+         # Shared RoPE buffers
+         ref_rope = self.blocks[0].attn.rope
+         for blk in self.blocks[1:]:
+             blk.attn.rope = ref_rope
+
+     def forward(self, x, pos_ids, cond, ctx, kv_cache=None):
+         v = None
+         for block in self.blocks:
+             x, v = block(x, pos_ids, cond, ctx, v, kv_cache=kv_cache)
+         return x
+
+
+ class WorldModel(ModelMixin, ConfigMixin):
+     """
+     WORLD: Wayfarer Operator-driven Rectified-flow Long-context Diffuser.
+
+     Denoises a frame given:
+     - All previous frames (via KV cache)
+     - The prompt embedding
+     - The controller input embedding
+     - The current noise level
+     """
+
+     _supports_gradient_checkpointing = False
+     _keep_in_fp32_modules = ["denoise_step_emb", "rope"]
+
+     @register_to_config
+     def __init__(
+         self,
+         # Model architecture
+         d_model: int = 2560,
+         n_heads: int = 40,
+         n_kv_heads: Optional[int] = 20,
+         n_layers: int = 22,
+         mlp_ratio: int = 5,
+         channels: int = 16,
+         height: int = 16,
+         width: int = 16,
+         patch: tuple = (2, 2),
+         tokens_per_frame: int = 256,
+         n_frames: int = 512,
+         local_window: int = 16,
+         global_window: int = 128,
+         global_attn_period: int = 4,
+         global_pinned_dilation: int = 8,
+         global_attn_offset: int = -1,
+         value_residual: bool = False,
+         gated_attn: bool = True,
+         n_buttons: int = 256,
+         ctrl_conditioning: Optional[str] = "mlp_fusion",
+         ctrl_conditioning_period: int = 3,
+         ctrl_cond_dropout: float = 0.0,
+         prompt_conditioning: Optional[str] = "cross_attention",
+         prompt_conditioning_period: int = 3,
+         prompt_embedding_dim: int = 2048,
+         prompt_cond_dropout: float = 0.0,
+         noise_conditioning: str = "wan",
+         scheduler_sigmas: Optional[List[float]] = [
+             1.0,
+             0.9483006596565247,
+             0.8379597067832947,
+             0.0,
+         ],
+         base_fps: int = 60,
+         causal: bool = True,
+         mlp_gradient_checkpointing: bool = True,
+         block_gradient_checkpointing: bool = True,
+         rope_impl: str = "ortho",
+     ):
+         super().__init__()
+
+         self.denoise_step_emb = NoiseConditioner(d_model)
+         self.ctrl_emb = ControllerInputEmbedding(n_buttons, d_model, mlp_ratio)
+
+         if self.config.ctrl_conditioning is not None:
+             self.ctrl_cfg = CFG(self.config.d_model, self.config.ctrl_cond_dropout)
+         if self.config.prompt_conditioning is not None:
+             self.prompt_cfg = CFG(
+                 self.config.prompt_embedding_dim, self.config.prompt_cond_dropout
+             )
+
+         self.transformer = WorldDiT(self.config)
+         self.patch = tuple(patch)
+
+         C, D = channels, d_model
+         self.patchify = nn.Conv2d(
+             C, D, kernel_size=self.patch, stride=self.patch, bias=False
+         )
+         self.unpatchify = nn.Linear(D, C * math.prod(self.patch), bias=True)
+         self.out_norm = AdaLN(d_model)
+
+         # Cached 1-frame pos_ids (buffers + cached TensorDict view)
+         T = tokens_per_frame
+         idx = torch.arange(T, dtype=torch.long)
+         self.register_buffer(
+             "_t_pos_1f", torch.empty(T, dtype=torch.long), persistent=False
+         )
+         self.register_buffer(
+             "_y_pos_1f", idx.div(width, rounding_mode="floor"), persistent=False
+         )
+         self.register_buffer("_x_pos_1f", idx.remainder(width), persistent=False)
+
+     def forward(
+         self,
+         x: Tensor,
+         sigma: Tensor,
+         frame_timestamp: Tensor,
+         prompt_emb: Optional[Tensor] = None,
+         prompt_pad_mask: Optional[Tensor] = None,
+         mouse: Optional[Tensor] = None,
+         button: Optional[Tensor] = None,
+         scroll: Optional[Tensor] = None,
+         kv_cache=None,
+     ):
+         """
+         Args:
+             x: [B, N, C, H, W] - latent frames
+             sigma: [B, N] - noise levels
+             frame_timestamp: [B, N] - frame indices
+             prompt_emb: [B, P, D] - prompt embeddings
+             prompt_pad_mask: [B, P] - padding mask for prompts
+             mouse: [B, N, 2] - mouse velocity
+             button: [B, N, n_buttons] - button states
+             scroll: [B, N, 1] - scroll wheel sign (-1, 0, 1)
+             kv_cache: StaticKVCache instance
+         """
+         B, N, C, H, W = x.shape
+         ph, pw = self.patch
+         assert (H % ph == 0) and (W % pw == 0), "H, W must be divisible by patch"
+         Hp, Wp = H // ph, W // pw
+         torch._assert(
+             Hp * Wp == self.config.tokens_per_frame,
+             f"{Hp} * {Wp} != {self.config.tokens_per_frame}",
+         )
+
+         torch._assert(
+             B == 1 and N == 1, "WorldModel.forward currently supports B==1, N==1"
+         )
+         self._t_pos_1f.copy_(frame_timestamp[0, 0].expand_as(self._t_pos_1f))
+         pos_ids = TensorDict(
+             {
+                 "t_pos": self._t_pos_1f[None],
+                 "y_pos": self._y_pos_1f[None],
+                 "x_pos": self._x_pos_1f[None],
+             },
+             batch_size=[1, self._t_pos_1f.numel()],
+         )
+         cond = self.denoise_step_emb(sigma)  # [B, N, d]
+
+         assert button is not None
+         ctx = {
+             "ctrl_emb": self.ctrl_emb(mouse, button, scroll),
+             "prompt_emb": prompt_emb,
+             "prompt_pad_mask": prompt_pad_mask,
+         }
+
+         D = self.unpatchify.in_features
+         x = self.patchify(x.reshape(B * N, C, H, W))
+         x = eo.rearrange(x.view(B, N, D, Hp, Wp), "b n d hp wp -> b (n hp wp) d")
+         x = self.transformer(x, pos_ids, cond, ctx, kv_cache)
+         x = F.silu(self.out_norm(x, cond))
+         x = eo.rearrange(
+             self.unpatchify(x),
+             "b (n hp wp) (c ph pw) -> b n c (hp ph) (wp pw)",
+             n=N,
+             hp=Hp,
+             wp=Wp,
+             ph=ph,
+             pw=pw,
+         )
+
+         return x
+
+     def quantize(self, quant_type: str):
+         quantize_model(self, quant_type)
+
+     def apply_inference_patches(self):
+         _apply_inference_patches(self)
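
For orientation, a hypothetical single-step denoise call. Shapes are inferred from the configs (16 latent channels; vae_scale_factor=16 on 512x512 squares gives 32x32 latents, i.e. 256 tokens per frame after 2x2 patching). The `kv_cache` argument is assumed to be the repo's StaticKVCache, which is not part of this commit, so the call itself is left commented.

```python
import torch
from diffusers import AutoModel

model = (
    AutoModel.from_pretrained(
        "Overworld/Waypoint-1-Small", subfolder="transformer", trust_remote_code=True
    )
    .to("cuda", torch.bfloat16)
    .eval()
)

x = torch.randn(1, 1, 16, 32, 32, device="cuda", dtype=torch.bfloat16)  # [B,N,C,H,W]
sigma = torch.full((1, 1), 0.8609585762023926, device="cuda", dtype=torch.bfloat16)
ts = torch.zeros(1, 1, device="cuda", dtype=torch.long)  # frame_timestamp
mouse = torch.zeros(1, 1, 2, device="cuda", dtype=torch.bfloat16)
button = torch.zeros(1, 1, 256, device="cuda", dtype=torch.bfloat16)
scroll = torch.zeros(1, 1, 1, device="cuda", dtype=torch.bfloat16)
prompt = torch.zeros(1, 77, 2048, device="cuda", dtype=torch.bfloat16)  # UMT5-XL embeds

# denoised = model(x, sigma, ts, prompt_emb=prompt, mouse=mouse,
#                  button=button, scroll=scroll, kv_cache=kv_cache)  # kv_cache: StaticKVCache
```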
transformer/nn.py ADDED
@@ -0,0 +1,153 @@
+ # Copyright (C) 2025 Hugging Face Team and Overworld
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ """Neural network building blocks for WorldModel transformer."""
+
+ import warnings
+
+ import einops as eo
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+
+ class NoCastModule(torch.nn.Module):
+     """Module that prevents dtype casting during .to() calls."""
+
+     def _apply(self, fn):
+         def keep_dtype(t):
+             old_dtype = t.dtype
+             out = fn(t)
+             if out.dtype is not old_dtype:
+                 warnings.warn(
+                     f"{self.__class__.__name__}: requested dtype cast ignored; "
+                     f"keeping {old_dtype}.",
+                     stacklevel=3,
+                 )
+                 out = out.to(dtype=old_dtype)
+             return out
+
+         return super()._apply(keep_dtype)
+
+     def to(self, *args, **kwargs):
+         warn_cast = False
+
+         # m.to(ref_tensor): use ref's device, ignore its dtype
+         if args and isinstance(args[0], torch.Tensor):
+             ref, *rest = args
+             args = (ref.device, *rest)
+             base = next(self.parameters(), None)
+             if base is None:  # avoid `or` here: Tensor truthiness raises for multi-element tensors
+                 base = next(self.buffers(), None)
+             if base is not None and ref.dtype is not base.dtype:
+                 warn_cast = True
+
+         # keyword dtype
+         if kwargs.pop("dtype", None) is not None:
+             warn_cast = True
+
+         # positional dtype
+         args = tuple(a for a in args if not isinstance(a, torch.dtype))
+
+         if warn_cast:
+             warnings.warn(
+                 f"{self.__class__.__name__}.to: requested dtype cast ignored; "
+                 "keeping existing dtypes.",
+                 stacklevel=2,
+             )
+
+         return super().to(*args, **kwargs)
+
+
+ def rms_norm(x: torch.Tensor) -> torch.Tensor:
+     """Root mean square layer normalization."""
+     return F.rms_norm(x, (x.size(-1),))
+
+
+ class MLP(nn.Module):
+     """Simple MLP with SiLU activation."""
+
+     def __init__(self, dim_in, dim_middle, dim_out):
+         super().__init__()
+         self.fc1 = nn.Linear(dim_in, dim_middle, bias=False)
+         self.fc2 = nn.Linear(dim_middle, dim_out, bias=False)
+
+     def forward(self, x):
+         return self.fc2(F.silu(self.fc1(x)))
+
+
+ class AdaLN(nn.Module):
+     """Adaptive Layer Normalization."""
+
+     def __init__(self, dim):
+         super().__init__()
+         self.fc = nn.Linear(dim, 2 * dim, bias=False)
+
+     def forward(self, x, cond):
+         # cond: [b, n, d], x: [b, n*m, d]
+         b, n, d = cond.shape
+         _, nm, _ = x.shape
+         m = nm // n
+
+         y = F.silu(cond)
+         ab = self.fc(y)  # [b, n, 2d]
+         ab = ab.view(b, n, 1, 2 * d)  # [b, n, 1, 2d]
+         ab = ab.expand(-1, -1, m, -1)  # [b, n, m, 2d]
+         ab = ab.reshape(b, nm, 2 * d)  # [b, nm, 2d]
+
+         a, b_ = ab.chunk(2, dim=-1)  # [b, nm, d] each
+         x = rms_norm(x) * (1 + a) + b_
+         return x
+
+
+ def ada_rmsnorm(x, scale, bias):
+     """Adaptive RMS normalization with scale and bias."""
+     x4 = eo.rearrange(x, "b (n m) d -> b n m d", n=scale.size(1))
+     y4 = rms_norm(x4) * (1 + scale.unsqueeze(2)) + bias.unsqueeze(2)
+     return eo.rearrange(y4, "b n m d -> b (n m) d")
+
+
+ def ada_gate(x, gate):
+     """Apply gating to x with per-frame gates."""
+     x4 = eo.rearrange(x, "b (n m) d -> b n m d", n=gate.size(1))
+     return eo.rearrange(x4 * gate.unsqueeze(2), "b n m d -> b (n m) d")
+
+
+ class NoiseConditioner(NoCastModule):
+     """Sigma -> logSNR -> Fourier Features -> Dense embedding."""
+
+     def __init__(self, dim, fourier_dim=512, base=10_000.0):
+         super().__init__()
+         assert fourier_dim % 2 == 0
+         half = fourier_dim // 2
+         self.freq = nn.Buffer(
+             torch.logspace(0, -1, steps=half, base=base, dtype=torch.float32),
+             persistent=False,
+         )
+         self.mlp = MLP(fourier_dim, dim * 4, dim)
+
+     def forward(self, s, eps=torch.finfo(torch.float32).eps):
+         assert self.freq.dtype == torch.float32
+         orig_dtype, shape = s.dtype, s.shape
+
+         with torch.autocast("cuda", enabled=False):
+             s = s.reshape(-1).float()  # fp32 for fourier numerical stability
+             s = s * 1000  # expressive rotation range
+
+             # calculate fourier features
+             phase = s[:, None] * self.freq[None, :]
+             emb = torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)
+             emb = emb * 2**0.5  # Ensure unit variance
+             emb = self.mlp(emb)
+
+         return emb.to(orig_dtype).view(*shape, -1)
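
The per-frame broadcasting in ada_rmsnorm/ada_gate is the subtle part: conditioning carries one vector per frame ([B, N, D]) while tokens are flattened across frames ([B, N*M, D]). A standalone illustration of the same reshape, not part of the committed code:

```python
import einops as eo
import torch
import torch.nn.functional as F

B, N, M, D = 1, 2, 4, 8          # 2 frames, 4 tokens per frame
x = torch.randn(B, N * M, D)     # flattened frame tokens
scale = torch.randn(B, N, D)     # one modulation vector per frame
bias = torch.randn(B, N, D)

x4 = eo.rearrange(x, "b (n m) d -> b n m d", n=N)
y4 = F.rms_norm(x4, (D,)) * (1 + scale.unsqueeze(2)) + bias.unsqueeze(2)
y = eo.rearrange(y4, "b n m d -> b (n m) d")
assert y.shape == (B, N * M, D)  # every token modulated by its frame's (scale, bias)
```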
transformer/quantize.py ADDED
@@ -0,0 +1,245 @@
+ # Copyright (C) 2025 Hugging Face Team and Overworld
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+
+
+ # TODO: enable specific quants based on model config, which should specify compatible quants
+ QUANTS = [None]
+
+ try:
+     from flashinfer import nvfp4_quantize, mm_fp4, SfLayout
+
+     QUANTS.append("nvfp4")
+ except ImportError:
+     pass
+
+
+ @torch.library.custom_op("world_engine::fp4_linear", mutates_args=())
+ def fp4_linear(
+     a_bf16: torch.Tensor,
+     b_fp4_T: torch.Tensor,
+     a_global_sf: torch.Tensor,
+     b_sf_T: torch.Tensor,
+     alpha: torch.Tensor,
+ ) -> torch.Tensor:
+     a_fp4, a_sf = nvfp4_quantize(
+         a_bf16,
+         a_global_sf,
+         sfLayout=SfLayout.layout_128x4,
+         do_shuffle=False,
+     )
+     return mm_fp4(
+         a_fp4, b_fp4_T, a_sf, b_sf_T, alpha, out_dtype=torch.bfloat16, backend="cutlass"
+     )
+
+
+ @fp4_linear.register_fake
+ def _fp4_linear_fake(
+     a_bf16: torch.Tensor,
+     b_fp4_T: torch.Tensor,
+     a_global_sf: torch.Tensor,
+     b_sf_T: torch.Tensor,
+     alpha: torch.Tensor,
+ ) -> torch.Tensor:
+     return torch.empty(
+         (a_bf16.shape[0], b_fp4_T.shape[1]), device=a_bf16.device, dtype=torch.bfloat16
+     )
+
+
+ class FP4Linear(nn.Module):
+     """FP4 Linear layer using FlashInfer's NVFP4 quantization."""
+
+     def __init__(self, lin: nn.Linear):
+         super().__init__()
+
+         self.in_features = lin.in_features
+         self.out_features = lin.out_features
+
+         # Check alignment requirements for NVFP4 TMA
+         assert self.in_features % 32 == 0 and self.out_features % 32 == 0, (
+             "features % 32 != 0, nvfp4 disallowed"
+         )
+
+         # Store weight from original linear layer
+         self.weight = nn.Parameter(lin.weight.detach().clone())
+
+         # FP4 weight and scales, quantized eagerly below (no lazy path)
+         self._weight_fp4_T: Optional[torch.Tensor] = None
+         self._weight_scales_T: Optional[torch.Tensor] = None
+         self._alpha: Optional[torch.Tensor] = None
+         self._dummy_scale: Optional[torch.Tensor] = None
+         self._weight_global_sf = None
+
+         with torch.no_grad():
+             self._dummy_scale = torch.full(
+                 (1,), 1.0, device=self.weight.device, dtype=torch.float32
+             )
+             weight_bf16 = (
+                 self.weight.to(torch.bfloat16).to(self.weight.device).contiguous()
+             )
+             weight_amax = weight_bf16.float().abs().nan_to_num().max()
+             self._weight_global_sf = 1.0 / weight_amax
+             self._alpha = 1.0 / (self._weight_global_sf * self._dummy_scale)
+             w_fp4, w_sf = nvfp4_quantize(
+                 weight_bf16,
+                 self._weight_global_sf,
+                 sfLayout=SfLayout.layout_128x4,
+                 do_shuffle=False,
+             )
+             self._weight_fp4_T = w_fp4.t()
+             self._weight_scales_T = w_sf.t()
+
+         # Warmup flashinfer fp4 graphs
+         assert self.weight.is_cuda, "Weights need to be on GPU before quantization"
+         # TODO: test actual shape warmup, might perform better
+         lazy_x = torch.zeros(
+             (1, lin.in_features), device=self.weight.device, dtype=torch.bfloat16
+         )
+         fp4_linear(
+             lazy_x,
+             self._weight_fp4_T,
+             self._dummy_scale,
+             self._weight_scales_T,
+             self._alpha,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass using FP4 quantization and FlashInfer GEMM."""
+         x_flat = x.reshape(-1, x.shape[-1])
+         y = fp4_linear(
+             x_flat.to(torch.bfloat16).contiguous(),
+             self._weight_fp4_T,
+             self._dummy_scale,
+             self._weight_scales_T,
+             self._alpha,
+         )
+         return y.reshape(x.shape[:-1] + (-1,))
+
+
+ class FP8W8A8Linear(nn.Module):
+     __constants__ = ("in_features", "out_features")
+
+     def __init__(self, lin: nn.Linear):
+         super().__init__()
+         self.in_features, self.out_features = lin.in_features, lin.out_features
+
+         f8 = torch.float8_e4m3fn
+         inv = 1.0 / float(torch.finfo(f8).max)
+         self._inv = inv
+
+         w = lin.weight.detach()
+         ws = (w.abs().amax() * inv).clamp_min(1e-8).float()  # 0-d
+         wf8 = (w / ws.to(w.dtype)).to(f8).contiguous()  # row-major
+         self.register_buffer("wT", wf8.t())  # col-major view (no contiguous)
+         self.register_buffer("ws", ws)
+
+         if lin.bias is None:
+             self.bias = None
+         else:
+             self.register_buffer("bias", lin.bias.detach().to(torch.float16))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         s = x.shape
+         x2 = x.reshape(-1, s[-1])
+
+         xs = (x2.abs().amax() * self._inv).clamp_min(1e-8).float()  # 0-d
+         xf8 = (x2 / xs.to(x2.dtype)).to(torch.float8_e4m3fn).contiguous()
+
+         y = torch._scaled_mm(
+             xf8,
+             self.wT,
+             xs,
+             self.ws,
+             bias=self.bias,
+             out_dtype=torch.float16,
+             use_fast_accum=True,
+         )
+         return y.reshape(*s[:-1], self.out_features).to(x.dtype)
+
+
+ class FP8Linear(nn.Module):
+     def __init__(self, lin: nn.Linear):
+         super().__init__()
+         self.in_features, self.out_features = lin.in_features, lin.out_features
+
+         self.bias = (
+             nn.Parameter(lin.bias.data.clone().to(torch.float8_e4m3fn))
+             if lin.bias is not None
+             else None
+         )
+         w_amax = lin.weight.data.clone().amax().float().squeeze()
+         w = lin.weight.data.clone().div(w_amax).to(torch.float8_e4m3fn)
+         self.register_buffer("w_amax", w_amax)
+         self.register_buffer("weightT", w.t())
+         self.dummy_scale = torch.ones((), device=lin.weight.device, dtype=torch.float32)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass using FP8 matmul.
+
+         Args:
+             x: Input tensor of shape [..., in_features] (flattens if > 2D)
+
+         Returns:
+             Output tensor of shape [..., out_features] in BF16 format, unflattened if input is > 2D
+         """
+         # Convert input to FP8 e4m3
+         x_fp8 = x.to(torch.float8_e4m3fn).reshape(-1, x.size(-1)).contiguous()
+
+         result = torch._scaled_mm(
+             x_fp8,
+             self.weightT,
+             bias=self.bias,
+             scale_a=self.dummy_scale,
+             scale_b=self.w_amax,
+             out_dtype=torch.bfloat16,
+             use_fast_accum=True,
+         )
+
+         return result.reshape(x.shape[:-1] + (-1,))
+
+
+ def quantize_model(model: nn.Module, quant: str):
+     if quant is None:
+         return model
+
+     def eligible(m: nn.Module) -> bool:
+         w = getattr(m, "weight", None)
+         if not isinstance(m, nn.Linear):
+             return False
+         if getattr(w, "dtype", None) != torch.bfloat16:
+             return False
+         o, k = w.shape
+         return (o % 32 == 0) and (k % 32 == 0)
+
+     new_linear = {
+         "w8a8": FP8W8A8Linear,
+         "nvfp4": FP4Linear,
+         "fp8": FP8Linear,
+     }[quant]
+
+     # Replace eligible leaves in place; otherwise recurse into the child.
+     for name, child in model.named_children():
+         if eligible(child):
+             setattr(model, name, new_linear(child))
+         else:
+             quantize_model(child, quant)
+     return model
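
Putting the pieces together, a hypothetical inference-time setup (continuing the transformer sketch above): "fp8" and "w8a8" rely on `torch._scaled_mm` and need an FP8-capable GPU, while "nvfp4" additionally requires flashinfer to be importable.

```python
model = model.to("cuda", torch.bfloat16).eval()
model.apply_inference_patches()  # cache noise conditioning, merge QKV, split fusion MLPs
model.quantize("fp8")            # swaps eligible bf16 nn.Linear layers for FP8Linear
```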
vae/__init__.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright (C) 2025 Hugging Face Team and Overworld
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ from .ae_model import WorldEngineVAE
+ from .dcae import bake_weight_norm
+
+ __all__ = ["WorldEngineVAE", "bake_weight_norm"]
vae/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (265 Bytes)
vae/__pycache__/ae_model.cpython-311.pyc ADDED
Binary file (5.64 kB)
vae/__pycache__/dcae.cpython-311.pyc ADDED
Binary file (17.1 kB)
vae/ae_model.py ADDED
@@ -0,0 +1,141 @@
+ # Copyright (C) 2025 Hugging Face Team and Overworld
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ """VAE model for WorldEngine frame encoding/decoding."""
+
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple
+
+ import torch
+ from torch import Tensor
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.modeling_utils import ModelMixin
+ from .dcae import Encoder, Decoder, bake_weight_norm
+
+
+ @dataclass
+ class EncoderDecoderConfig:
+     """Config object for Encoder/Decoder initialization."""
+
+     channels: int
+     latent_channels: int
+     ch_0: int
+     ch_max: int
+     encoder_blocks_per_stage: List[int]
+     decoder_blocks_per_stage: List[int]
+     skip_logvar: bool = False
+
+
+ class WorldEngineVAE(ModelMixin, ConfigMixin):
+     """
+     VAE for encoding/decoding video frames using DCAE architecture.
+
+     Encodes RGB uint8 images to latent space and decodes latents back to RGB.
+     """
+
+     _supports_gradient_checkpointing = False
+
+     @register_to_config
+     def __init__(
+         self,
+         # Common parameters
+         sample_size: Tuple[int, int] = (360, 640),
+         channels: int = 3,
+         latent_channels: int = 16,
+         # Encoder parameters
+         encoder_ch_0: int = 64,
+         encoder_ch_max: int = 256,
+         encoder_blocks_per_stage: Optional[List[int]] = None,
+         # Decoder parameters
+         decoder_ch_0: int = 128,
+         decoder_ch_max: int = 1024,
+         decoder_blocks_per_stage: Optional[List[int]] = None,
+         # Shared parameters
+         skip_logvar: bool = False,
+         # Scaling factors
+         scale_factor: float = 1.0,
+         shift_factor: float = 0.0,
+     ):
+         super().__init__()
+
+         # Default blocks per stage
+         if encoder_blocks_per_stage is None:
+             encoder_blocks_per_stage = [1, 1, 1, 1]
+         if decoder_blocks_per_stage is None:
+             decoder_blocks_per_stage = [1, 1, 1, 1]
+
+         # Create encoder config
+         encoder_config = EncoderDecoderConfig(
+             channels=channels,
+             latent_channels=latent_channels,
+             ch_0=encoder_ch_0,
+             ch_max=encoder_ch_max,
+             encoder_blocks_per_stage=list(encoder_blocks_per_stage),
+             decoder_blocks_per_stage=list(decoder_blocks_per_stage),
+             skip_logvar=skip_logvar,
+         )
+
+         # Create decoder config
+         decoder_config = EncoderDecoderConfig(
+             channels=channels,
+             latent_channels=latent_channels,
+             ch_0=decoder_ch_0,
+             ch_max=decoder_ch_max,
+             encoder_blocks_per_stage=list(encoder_blocks_per_stage),
+             decoder_blocks_per_stage=list(decoder_blocks_per_stage),
+             skip_logvar=skip_logvar,
+         )
+
+         self.encoder = Encoder(encoder_config)
+         self.decoder = Decoder(decoder_config)
+
+     def encode(self, img: Tensor):
+         """RGB -> RGB+D -> latent"""
+         assert img.dim() == 3, "Expected [H, W, C] image tensor"
+         img = img.unsqueeze(0).to(device=self.device, dtype=self.dtype)
+         rgb = img.permute(0, 3, 1, 2).contiguous().div(255).mul(2).sub(1)
+         return self.encoder(rgb)
+
+     def decode(self, latent: Tensor):
+         decoded = self.decoder(latent)
+         decoded = (decoded / 2 + 0.5).clamp(0, 1)
+         decoded = (decoded * 255).round().to(torch.uint8)
+         return decoded.squeeze(0).permute(1, 2, 0)[..., :3]
+
+     def forward(self, x: Tensor, encode: bool = True) -> Tensor:
+         """
+         Forward pass - encode or decode based on flag.
+
+         Args:
+             x: Input tensor (image for encode, latent for decode)
+             encode: If True, encode; if False, decode
+
+         Returns:
+             Encoded latent or decoded image
+         """
+         if encode:
+             return self.encode(x)
+         else:
+             return self.decode(x)
+
+     def bake_weight_norm(self):
+         """Remove weight_norm parametrizations, baking normalized weights into regular tensors.
+
+         Call this after loading weights and before torch.compile to avoid
+         CUDA graph capture errors from in-place weight updates.
+         """
+         bake_weight_norm(self)
+         return self
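
A hypothetical round trip through the VAE. Input is a single [H, W, C] uint8 frame; encode() normalizes to [-1, 1] internally and decode() returns uint8 RGB. The latent spatial size is inferred from vae_scale_factor=16 in the pipeline config (512/16 = 32 per side), not something this file asserts.

```python
import torch
from diffusers import AutoModel

vae = (
    AutoModel.from_pretrained(
        "Overworld/Waypoint-1-Small", subfolder="vae", trust_remote_code=True
    )
    .to("cuda", torch.bfloat16)
    .eval()
)
vae.bake_weight_norm()  # bake weight_norm before torch.compile / CUDA graphs

frame = torch.randint(0, 256, (360, 640, 3), dtype=torch.uint8)
latent = vae.encode(frame)   # expected [1, 16, 32, 32] per vae_scale_factor=16
recon = vae.decode(latent)   # [360, 640, 3] uint8 RGB
```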
vae/config.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "_class_name": "WorldEngineVAE",
+   "_diffusers_version": "0.36.0.dev0",
+   "auto_map": {
+     "AutoModel": "ae_model.WorldEngineVAE"
+   },
+   "sample_size": [
+     360,
+     640
+   ],
+   "channels": 3,
+   "latent_channels": 16,
+   "encoder_ch_0": 64,
+   "encoder_ch_max": 256,
+   "encoder_blocks_per_stage": [
+     1,
+     1,
+     1,
+     1
+   ],
+   "decoder_ch_0": 128,
+   "decoder_ch_max": 1024,
+   "decoder_blocks_per_stage": [
+     1,
+     1,
+     1,
+     1
+   ],
+   "use_middle_block": false,
+   "skip_logvar": false,
+   "scale_factor": 1.0,
+   "shift_factor": 0.0
+ }
vae/dcae.py ADDED
@@ -0,0 +1,271 @@
1
+ # Copyright (C) 2025 Hugging Face Team and Overworld
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ import torch
17
+ from torch import nn
18
+ import torch.nn.functional as F
19
+
20
+ from torch.nn.utils.parametrizations import weight_norm
21
+ from torch.nn.utils.parametrize import remove_parametrizations
22
+
23
+
24
+ def bake_weight_norm(model: nn.Module) -> nn.Module:
25
+ """Remove weight_norm parametrizations, baking normalized weights into regular tensors.
26
+
27
+ This is required for torch.compile/CUDA graph compatibility since weight_norm
28
+ performs in-place updates during forward passes.
29
+ """
30
+ for module in model.modules():
31
+ if hasattr(module, "parametrizations") and "weight" in getattr(module, "parametrizations", {}):
32
+ remove_parametrizations(module, "weight", leave_parametrized=True)
33
+ return model
+
+
+ # === General Blocks ===
+
+ def WeightNormConv2d(*args, **kwargs):
+     return weight_norm(nn.Conv2d(*args, **kwargs))
+
+ class ResBlock(nn.Module):
+     def __init__(self, ch):
+         super().__init__()
+
+         hidden = 2 * ch
+         # 16 channels per group (matches checkpoint shapes like [128,16,3,3] when ch=64)
+         n_grps = max(1, hidden // 16)
+
+         self.conv1 = WeightNormConv2d(ch, hidden, 1, 1, 0)
+         self.conv2 = WeightNormConv2d(hidden, hidden, 3, 1, 1, groups=n_grps)
+         self.conv3 = WeightNormConv2d(hidden, ch, 1, 1, 0, bias=False)
+
+         self.act1 = nn.LeakyReLU(inplace=False)
+         self.act2 = nn.LeakyReLU(inplace=False)
+
+     def forward(self, x):
+         h = self.conv1(x)
+         h = self.act1(h)
+         h = self.conv2(h)
+         h = self.act2(h)
+         h = self.conv3(h)
+         return x + h
+
+ # === Encoder ===
+
+ class LandscapeToSquare(nn.Module):
+     # Strict assumption of 360p
+     def __init__(self, ch_in, ch_out):
+         super().__init__()
+
+         self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
+
+     def forward(self, x):
+         x = F.interpolate(x, (512, 512), mode='bicubic')
+         x = self.proj(x)
+         return x
+
+ class Downsample(nn.Module):
+     def __init__(self, ch_in, ch_out):
+         super().__init__()
+
+         self.proj = WeightNormConv2d(ch_in, ch_out, 1, 1, 0, bias=False)
+
+     def forward(self, x):
+         x = F.interpolate(x, scale_factor=0.5, mode='bicubic')
+         x = self.proj(x)
+         return x
+
+ class DownBlock(nn.Module):
+     def __init__(self, ch_in, ch_out, num_res=1):
+         super().__init__()
+
+         self.down = Downsample(ch_in, ch_out)
+         blocks = []
+         for _ in range(num_res):
+             blocks.append(ResBlock(ch_in))
+         self.blocks = nn.ModuleList(blocks)
+
+     def forward(self, x):
+         for block in self.blocks:
+             x = block(x)
+         x = self.down(x)
+         return x
+
+ class SpaceToChannel(nn.Module):
+     def __init__(self, ch_in, ch_out):
+         super().__init__()
+
+         self.proj = WeightNormConv2d(ch_in, ch_out // 4, 3, 1, 1)
+
+     def forward(self, x):
+         x = self.proj(x)
+         x = F.pixel_unshuffle(x, 2).contiguous()
+         return x
+
+ class ChannelAverage(nn.Module):
+     def __init__(self, ch_in, ch_out):
+         super().__init__()
+
+         self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
+         self.grps = ch_in // ch_out
+         self.scale = (self.grps) ** 0.5
+
+     def forward(self, x):
+         res = x
+         x = self.proj(x.contiguous())  # [b, ch_out, h, w]
+
+         # Residual goes through channel avg
+         res = res.view(res.shape[0], self.grps, res.shape[1] // self.grps, res.shape[2], res.shape[3]).contiguous()
+         res = res.mean(dim=1) * self.scale  # [b, ch_out, h, w]
+
+         return res + x
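+
+ # e.g. ChannelAverage(256, 16), the encoder's conv_out: grps=16 and
+ # scale=4.0 (16 ** 0.5), so the residual path collapses 256 channels to 16
+ # by averaging 16 of them per output channel, rescaled to preserve variance
+ # before adding the learned projection.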
+
+ # === Decoder ===
+
+ class SquareToLandscape(nn.Module):
+     def __init__(self, ch_in, ch_out):
+         super().__init__()
+
+         self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
+
+     def forward(self, x):
+         x = self.proj(x)  # TODO This ordering is wrong for both
+         x = F.interpolate(x, (360, 640), mode='bicubic')
+         return x
+
+ class Upsample(nn.Module):
+     def __init__(self, ch_in, ch_out):
+         super().__init__()
+
+         self.proj = nn.Identity() if ch_in == ch_out else WeightNormConv2d(
+             ch_in, ch_out, 1, 1, 0, bias=False
+         )
+
+     def forward(self, x):
+         x = self.proj(x)
+         x = F.interpolate(x, scale_factor=2.0, mode='bicubic')
+         return x
+
+ class UpBlock(nn.Module):
+     def __init__(self, ch_in, ch_out, num_res=1):
+         super().__init__()
+
+         self.up = Upsample(ch_in, ch_out)
+         blocks = []
+         for _ in range(num_res):
+             blocks.append(ResBlock(ch_out))
+         self.blocks = nn.ModuleList(blocks)
+
+     def forward(self, x):
+         x = self.up(x)
+         for block in self.blocks:
+             x = block(x)
+         return x
+
+ class ChannelToSpace(nn.Module):
+     def __init__(self, ch_in, ch_out):
+         super().__init__()
+
+         self.proj = WeightNormConv2d(ch_in, ch_out * 4, 3, 1, 1)
+
+     def forward(self, x):
+         x = self.proj(x)
+         x = F.pixel_shuffle(x, 2).contiguous()
+         return x
+
+ class ChannelDuplication(nn.Module):
+     def __init__(self, ch_in, ch_out):
+         super().__init__()
+
+         self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
+         self.reps = ch_out // ch_in
+         self.scale = (self.reps) ** -0.5
+
+     def forward(self, x):
+         res = x
+         x = self.proj(x.contiguous())
+
+         b, c, h, w = res.shape
+         res = res.unsqueeze(2)  # [b, c, 1, h, w]
+         res = res.expand(b, c, self.reps, h, w)  # [b, c, reps, h, w]
+         res = res.reshape(b, c * self.reps, h, w).contiguous()
+         res = res * self.scale
+
+         return res + x
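+
+ # e.g. ChannelDuplication(16, 1024), the decoder's conv_in: reps=64 and
+ # scale=0.125 (64 ** -0.5), so each latent channel is repeated 64 times and
+ # damped before adding the learned projection.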
+
+ # === Main AE ===
+
+ class Encoder(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+
+         self.conv_in = LandscapeToSquare(config.channels, config.ch_0)
+
+         blocks = []
+         residuals = []
+
+         ch = config.ch_0
+         for block_count in config.encoder_blocks_per_stage:
+             next_ch = min(ch*2, config.ch_max)
+
+             blocks.append(DownBlock(ch, next_ch, block_count))
+             residuals.append(SpaceToChannel(ch, next_ch))
+
+             ch = next_ch
+
+         self.blocks = nn.ModuleList(blocks)
+         self.residuals = nn.ModuleList(residuals)
+         self.conv_out = ChannelAverage(ch, config.latent_channels)
+
+         self.skip_logvar = bool(getattr(config, "skip_logvar", False))
+         if not self.skip_logvar:
+             # Checkpoint expects a 1-channel logvar head: [1, ch, 3, 3]
+             self.conv_out_logvar = WeightNormConv2d(ch, 1, 3, 1, 1)
+
+     def forward(self, x):
+         x = self.conv_in(x)
+         for block, residual in zip(self.blocks, self.residuals):
+             x = block(x) + residual(x)
+         return self.conv_out(x)
+
+ class Decoder(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+
+         self.conv_in = ChannelDuplication(config.latent_channels, config.ch_max)
+
+         blocks = []
+         residuals = []
+
+         ch = config.ch_0
+         for block_count in reversed(config.decoder_blocks_per_stage):
+             next_ch = min(ch*2, config.ch_max)
+
+             blocks.append(UpBlock(next_ch, ch, block_count))
+             residuals.append(ChannelToSpace(next_ch, ch))
+
+             ch = next_ch
+
+         self.blocks = nn.ModuleList(reversed(blocks))
+         self.residuals = nn.ModuleList(reversed(residuals))
+
+         self.act_out = nn.SiLU()
+         self.conv_out = SquareToLandscape(config.ch_0, config.channels)
+
+     def forward(self, x):
+         x = self.conv_in(x)
+         for block, residual in zip(self.blocks, self.residuals):
+             x = block(x) + residual(x)
+         x = self.act_out(x)
+         return self.conv_out(x)
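A shape walkthrough may help before moving on to the weights file. This sketch instantiates the encoder and decoder directly with the values from `vae/config.json`; the `SimpleNamespace` stands in for `EncoderDecoderConfig` (defined in `vae/model.py` below), and the import assumes the `vae/` folder is on `sys.path`:

```python
from types import SimpleNamespace

import torch

from dcae import Encoder, Decoder  # assumes vae/ is on sys.path


def make_config(ch_0, ch_max):
    # Stand-in for EncoderDecoderConfig with the values from vae/config.json.
    return SimpleNamespace(
        channels=3, latent_channels=16, ch_0=ch_0, ch_max=ch_max,
        encoder_blocks_per_stage=[1, 1, 1, 1],
        decoder_blocks_per_stage=[1, 1, 1, 1],
        skip_logvar=False,
    )


encoder = Encoder(make_config(64, 256))
decoder = Decoder(make_config(128, 1024))

with torch.no_grad():
    x = torch.randn(1, 3, 360, 640)  # one RGB frame in [-1, 1]
    z = encoder(x)                   # [1, 16, 32, 32]: 512x512 internally, halved per stage
    y = decoder(z)                   # [1, 3, 360, 640]: upsampled, then resized back
```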
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ecdebf692a6b02610163948251dcf264c5793da12b3729986fb4a3e3c4dc4d1f
+ size 141887736
vae/model.py ADDED
@@ -0,0 +1,142 @@
+ # Copyright (C) 2025 Hugging Face Team and Overworld
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ """VAE model for WorldEngine frame encoding/decoding."""
+
+ from dataclasses import dataclass
+ from typing import List, Tuple
+
+ import torch
+ from torch import Tensor
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.modeling_utils import ModelMixin
+ from .dcae import Encoder, Decoder
+
+
+ @dataclass
+ class EncoderDecoderConfig:
+     """Config object for Encoder/Decoder initialization."""
+
+     sample_size: Tuple[int, int]
+     channels: int
+     latent_channels: int
+     ch_0: int
+     ch_max: int
+     encoder_blocks_per_stage: List[int]
+     decoder_blocks_per_stage: List[int]
+     use_middle_block: bool
+     skip_logvar: bool = False
+     skip_residuals: bool = False
+     normalize_mu: bool = False
+
+
+ class WorldEngineVAE(ModelMixin, ConfigMixin):
+     """
+     VAE for encoding/decoding video frames using DCAE architecture.
+
+     Encodes RGB uint8 images to latent space and decodes latents back to RGB.
+     """
+
+     _supports_gradient_checkpointing = False
+
+     @register_to_config
+     def __init__(
+         self,
+         # Common parameters
+         sample_size: Tuple[int, int] = (360, 640),
+         channels: int = 3,
+         latent_channels: int = 16,
+         # Encoder parameters
+         encoder_ch_0: int = 64,
+         encoder_ch_max: int = 256,
+         encoder_blocks_per_stage: List[int] = None,
+         # Decoder parameters
+         decoder_ch_0: int = 128,
+         decoder_ch_max: int = 1024,
+         decoder_blocks_per_stage: List[int] = None,
+         # Shared parameters
+         use_middle_block: bool = False,
+         skip_logvar: bool = False,
+         # Scaling factors
+         scale_factor: float = 1.0,
+         shift_factor: float = 0.0,
+     ):
+         super().__init__()
+
+         # Default blocks per stage
+         if encoder_blocks_per_stage is None:
+             encoder_blocks_per_stage = [1, 1, 1, 1]
+         if decoder_blocks_per_stage is None:
+             decoder_blocks_per_stage = [1, 1, 1, 1]
+
+         # Create encoder config
+         encoder_config = EncoderDecoderConfig(
+             sample_size=tuple(sample_size),
+             channels=channels,
+             latent_channels=latent_channels,
+             ch_0=encoder_ch_0,
+             ch_max=encoder_ch_max,
+             encoder_blocks_per_stage=list(encoder_blocks_per_stage),
+             decoder_blocks_per_stage=list(decoder_blocks_per_stage),
+             use_middle_block=use_middle_block,
+             skip_logvar=skip_logvar,
+         )
+
+         # Create decoder config
+         decoder_config = EncoderDecoderConfig(
+             sample_size=tuple(sample_size),
+             channels=channels,
+             latent_channels=latent_channels,
+             ch_0=decoder_ch_0,
+             ch_max=decoder_ch_max,
+             encoder_blocks_per_stage=list(encoder_blocks_per_stage),
+             decoder_blocks_per_stage=list(decoder_blocks_per_stage),
+             use_middle_block=use_middle_block,
+             skip_logvar=skip_logvar,
+         )
+
+         self.encoder = Encoder(encoder_config)
+         self.decoder = Decoder(decoder_config)
+
114
+ def encode(self, img: Tensor):
115
+ """RGB -> RGB+D -> latent"""
116
+ assert img.dim() == 3, "Expected [H, W, C] image tensor"
117
+ img = img.unsqueeze(0).to(device=self.device, dtype=self.dtype)
118
+ rgb = img.permute(0, 3, 1, 2).contiguous().div(255).mul(2).sub(1)
119
+ return self.encoder(rgb)
120
+
121
+ @torch.compile
122
+ def decode(self, latent: Tensor):
123
+ decoded = self.decoder(latent)
124
+ decoded = (decoded / 2 + 0.5).clamp(0, 1)
125
+ decoded = (decoded * 255).round().to(torch.uint8)
126
+ return decoded.squeeze(0).permute(1, 2, 0)[..., :3]
127
+
128
+ def forward(self, x: Tensor, encode: bool = True) -> Tensor:
129
+ """
130
+ Forward pass - encode or decode based on flag.
131
+
132
+ Args:
133
+ x: Input tensor (image for encode, latent for decode)
134
+ encode: If True, encode; if False, decode
135
+
136
+ Returns:
137
+ Encoded latent or decoded image
138
+ """
139
+ if encode:
140
+ return self.encode(x)
141
+ else:
142
+ return self.decode(x)
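Putting the pieces together, a round trip through the VAE looks like the sketch below. The repo id is again a placeholder; `auto_map` in `vae/config.json` routes `AutoModel` to `WorldEngineVAE` when `trust_remote_code=True`, and the first `decode` call triggers `torch.compile`:

```python
import torch
from diffusers import AutoModel

# Placeholder repo id; substitute the actual model repository.
vae = AutoModel.from_pretrained(
    "<repo-id>", subfolder="vae", trust_remote_code=True
).to("cuda")

# encode() expects a single [H, W, C] uint8 frame at 360x640.
frame = torch.randint(0, 256, (360, 640, 3), dtype=torch.uint8, device="cuda")

with torch.no_grad():
    latent = vae(frame, encode=True)      # [1, 16, 32, 32]
    restored = vae(latent, encode=False)  # [360, 640, 3] uint8
```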