Add diffusers support
#1
by
dn6
HF Staff
- opened
- README.md +91 -3
- __init__.py +61 -0
- before_denoise.py +612 -0
- decoders.py +122 -0
- denoise.py +210 -0
- encoders.py +318 -0
- modular_blocks.py +45 -0
- modular_config.json +7 -0
- modular_model_index.json +76 -0
- transformer/__init__.py +31 -0
- transformer/attn.py +297 -0
- transformer/cache.py +112 -0
- transformer/config.json +49 -0
- transformer/diffusion_pytorch_model.safetensors +3 -0
- transformer/model.py +452 -0
- transformer/nn.py +153 -0
- transformer/quantize.py +245 -0
- vae/__init__.py +19 -0
- vae/__pycache__/__init__.cpython-311.pyc +0 -0
- vae/__pycache__/ae_model.cpython-311.pyc +0 -0
- vae/__pycache__/dcae.cpython-311.pyc +0 -0
- vae/ae_model.py +141 -0
- vae/config.json +33 -0
- vae/dcae.py +271 -0
- vae/diffusion_pytorch_model.safetensors +3 -0
- vae/model.py +142 -0
README.md
CHANGED
|
@@ -8,7 +8,7 @@ tags:
|
|
| 8 |
- Egocentric
|
| 9 |
---
|
| 10 |
|
| 11 |
-
Waypoint-1-Small is a 2.3 billion parameter control-and-text-conditioned causal diffusion model. It is a transformer architecture utilizing rectified flow, distilled via self forcing with DMD. The model can autoregressively generate new frames given historical frames, actions, and text.
|
| 12 |
|
| 13 |
# Capabilities:
|
| 14 |
|
|
@@ -23,12 +23,99 @@ In order to simply use Waypoint-1-Small, we recommend [Biome](https://github.com
|
|
| 23 |
|
| 24 |
To run the model locally, we recommend an NVIDIA RTX 5090, which should achieve 20-30 FPS, or an RTX 6000 Pro Blackwell, which should achieve ~35 FPS.
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
# Keywords
|
| 27 |
|
| 28 |
To properly explain limitations and misuse we must define some terms. While the model can be used for general interactive video generation tasks, we herein define interacting with the model via sending controls and receiving new frames as “playing” the model, and the agent/user inputting controls as the “player”. The model has two forms of output, continuations and generations. Continuations occur when seed frames are given and no inputs are given. For example, if a scene has fire or water, you may see them evolve progressively in the generated frames even if no action is given. Likewise, if you seed with an image of a humanoid entity, the entity will persist on the screen as you move/look around. However, generations occur when the player plays with the model extensively, for example moving around, turning around fully, or interacting with objects/items. Continuations roughly correspond to moving around already existing information in the given context frames while generations correspond to creating entirely new information.
|
| 29 |
|
| 30 |
# Limitations
|
| 31 |
-
|
| 32 |
- Continuations can plausibly model any inputted scene or photo, and will depend largely on the seed frame given. For generations, the model may occasionally:
|
| 33 |
- Ignore given text prompt
|
| 34 |
- Ignore certain controls in specific contexts
|
|
@@ -46,4 +133,5 @@ To properly explain limitations and misuse we must define some terms. While the
|
|
| 46 |
- For simulating extremely violent acts
|
| 47 |
- For generating violent/gory video
|
| 48 |
- For facilitation of large-scale disinformation campaigns
|
| 49 |
-
- For the purpose of generating any sexually explicit or suggestive material
|
|
|
|
|
|
| 8 |
- Egocentric
|
| 9 |
---
|
| 10 |
|
| 11 |
+
Waypoint-1-Small is a 2.3 billion parameter control-and-text-conditioned causal diffusion model. It is a transformer architecture utilizing rectified flow, distilled via self forcing with DMD. The model can autoregressively generate new frames given historical frames, actions, and text.
|
| 12 |
|
| 13 |
# Capabilities:
|
| 14 |
|
|
|
|
| 23 |
|
| 24 |
To run the model locally, we recommend an NVIDIA RTX 5090, which should achieve 20-30 FPS, or an RTX 6000 Pro Blackwell, which should achieve ~35 FPS.
|
| 25 |
|
| 26 |
+
# Run with Diffusers Modular Pipelines
|
| 27 |
+
|
| 28 |
+
World Engine and Waypoint-1 can be used with [Diffusers Modular Pipelines](https://huggingface.co/docs/diffusers/main/en/api/modular_pipelines/modular_pipeline).
|
| 29 |
+
|
| 30 |
+
## Setup
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
uv venv -p 3.11 && uv pip install \
|
| 34 |
+
torch>=2.9.0 \
|
| 35 |
+
diffusers>=0.36.0 \
|
| 36 |
+
transformers>=4.57.1 \
|
| 37 |
+
einops>=0.8.0 \
|
| 38 |
+
tensordict>=0.5.0 \
|
| 39 |
+
regex \
|
| 40 |
+
ftfy \
|
| 41 |
+
imageio \
|
| 42 |
+
imageio-ffmpeg \
|
| 43 |
+
tqdm
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
## Usage Example
|
| 47 |
+
|
| 48 |
+
```python
|
| 49 |
+
import random
|
| 50 |
+
import torch
|
| 51 |
+
|
| 52 |
+
from tqdm import tqdm
|
| 53 |
+
from dataclasses import dataclass, field
|
| 54 |
+
from typing import Set, Tuple
|
| 55 |
+
from diffusers.modular_pipelines import ModularPipeline
|
| 56 |
+
from diffusers.utils import load_image, export_to_video
|
| 57 |
+
|
| 58 |
+
@dataclass
|
| 59 |
+
class CtrlInput:
|
| 60 |
+
button: Set[int] = field(default_factory=set) # pressed button IDs
|
| 61 |
+
mouse: Tuple[float, float] = (0.0, 0.0) # (x, y) velocity
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Generate random control trajectories
|
| 65 |
+
ctrl = lambda: random.choice(
|
| 66 |
+
[
|
| 67 |
+
CtrlInput(button={48, 42}, mouse=[0.4, 0.3]),
|
| 68 |
+
CtrlInput(mouse=[0.1, 0.2]),
|
| 69 |
+
CtrlInput(button={95, 32, 105}),
|
| 70 |
+
]
|
| 71 |
+
)
|
| 72 |
+
model_id = "Overworld/Waypoint-1-Small"
|
| 73 |
+
|
| 74 |
+
pipe = ModularPipeline.from_pretrained(model_id, trust_remote_code=True)
|
| 75 |
+
pipe.load_components(
|
| 76 |
+
device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True
|
| 77 |
+
)
|
| 78 |
+
pipe.transformer.apply_inference_patches()
|
| 79 |
+
|
| 80 |
+
# Optional Quantization Step
|
| 81 |
+
# Available options are: nvfp4 (if running on Blackwell hardware), fp8, w8a8
|
| 82 |
+
# pipe.transformer.quantize("nvfp4")
|
| 83 |
+
pipe.transformer.compile(fullgraph=True, mode="max-autotune", dynamic=False)
|
| 84 |
+
pipe.vae.bake_weight_norm()
|
| 85 |
+
pipe.vae.compile(fullgraph=True, mode="max-autotune")
|
| 86 |
+
|
| 87 |
+
prompt = "A fun game"
|
| 88 |
+
image = load_image(
|
| 89 |
+
"https://gist.github.com/user-attachments/assets/4adc5a3d-6980-4d1e-b6e8-9033cdf61c66"
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
num_frames = 240
|
| 93 |
+
outputs = []
|
| 94 |
+
|
| 95 |
+
# create world state based on an initial image
|
| 96 |
+
state = pipe(prompt=prompt, image=image, button=ctrl().button, mouse=ctrl().mouse)
|
| 97 |
+
outputs.append(state.values["images"])
|
| 98 |
+
|
| 99 |
+
state.values["image"] = None
|
| 100 |
+
for _ in tqdm(range(1, num_frames)):
|
| 101 |
+
state = pipe(
|
| 102 |
+
state,
|
| 103 |
+
prompt=prompt,
|
| 104 |
+
button=ctrl().button,
|
| 105 |
+
mouse=ctrl().mouse,
|
| 106 |
+
output_type="pil",
|
| 107 |
+
)
|
| 108 |
+
outputs.append(state.values["images"])
|
| 109 |
+
|
| 110 |
+
export_to_video(outputs, "waypoint-1-small.mp4", fps=60)
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
# Keywords
|
| 114 |
|
| 115 |
To properly explain limitations and misuse we must define some terms. While the model can be used for general interactive video generation tasks, we herein define interacting with the model via sending controls and receiving new frames as “playing” the model, and the agent/user inputting controls as the “player”. The model has two forms of output, continuations and generations. Continuations occur when seed frames are given and no inputs are given. For example, if a scene has fire or water, you may see them evolve progressively in the generated frames even if no action is given. Likewise, if you seed with an image of a humanoid entity, the entity will persist on the screen as you move/look around. However, generations occur when the player plays with the model extensively, for example moving around, turning around fully, or interacting with objects/items. Continuations roughly correspond to moving around already existing information in the given context frames while generations correspond to creating entirely new information.
|
| 116 |
|
| 117 |
# Limitations
|
| 118 |
+
|
| 119 |
- Continuations can plausibly model any inputted scene or photo, and will depend largely on the seed frame given. For generations, the model may occasionally:
|
| 120 |
- Ignore given text prompt
|
| 121 |
- Ignore certain controls in specific contexts
|
|
|
|
| 133 |
- For simulating extremely violent acts
|
| 134 |
- For generating violent/gory video
|
| 135 |
- For facilitation of large-scale disinformation campaigns
|
| 136 |
+
- For the purpose of generating any sexually explicit or suggestive material
|
| 137 |
+
|
__init__.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Hugging Face Team and Overworld
|
| 2 |
+
#
|
| 3 |
+
# This program is free software: you can redistribute it and/or modify
|
| 4 |
+
# it under the terms of the GNU General Public License as published by
|
| 5 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 6 |
+
# (at your option) any later version.
|
| 7 |
+
#
|
| 8 |
+
# This program is distributed in the hope that it will be useful,
|
| 9 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 10 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 11 |
+
# GNU General Public License for more details.
|
| 12 |
+
#
|
| 13 |
+
# You should have received a copy of the GNU General Public License
|
| 14 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
WorldEngine Modular Pipeline
|
| 18 |
+
|
| 19 |
+
A Diffusers-compatible modular pipeline for frame-by-frame world model generation.
|
| 20 |
+
Supports text and controller (mouse + button + scroll) conditioning.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from .modular_blocks import WorldEngineBlocks, AUTO_BLOCKS
|
| 24 |
+
from .encoders import WorldEngineTextEncoderStep, WorldEngineControllerEncoderStep
|
| 25 |
+
from .before_denoise import (
|
| 26 |
+
WorldEngineBeforeDenoiseStep,
|
| 27 |
+
WorldEngineSetTimestepsStep,
|
| 28 |
+
WorldEnginePrepareLatentsStep,
|
| 29 |
+
WorldEngineSetupKVCacheStep,
|
| 30 |
+
StaticKVCache,
|
| 31 |
+
LayerKVCache,
|
| 32 |
+
)
|
| 33 |
+
from .denoise import WorldEngineDenoiseLoop
|
| 34 |
+
from .decoders import WorldEngineDecodeStep
|
| 35 |
+
from .vae import WorldEngineVAE
|
| 36 |
+
|
| 37 |
+
__version__ = "0.1.0"
|
| 38 |
+
|
| 39 |
+
__all__ = [
|
| 40 |
+
# Main pipeline blocks
|
| 41 |
+
"WorldEngineBlocks",
|
| 42 |
+
"AUTO_BLOCKS",
|
| 43 |
+
# Encoder blocks
|
| 44 |
+
"WorldEngineTextEncoderStep",
|
| 45 |
+
"WorldEngineControllerEncoderStep",
|
| 46 |
+
# Before denoise blocks
|
| 47 |
+
"WorldEngineBeforeDenoiseStep",
|
| 48 |
+
"WorldEngineSetTimestepsStep",
|
| 49 |
+
"WorldEnginePrepareLatentsStep",
|
| 50 |
+
"WorldEngineSetupKVCacheStep",
|
| 51 |
+
# Denoise block
|
| 52 |
+
"WorldEngineDenoiseLoop",
|
| 53 |
+
# Decoder blocks
|
| 54 |
+
"WorldEngineDecodeStep",
|
| 55 |
+
# Models
|
| 56 |
+
"WorldModel",
|
| 57 |
+
"WorldEngineVAE",
|
| 58 |
+
# KV Cache
|
| 59 |
+
"StaticKVCache",
|
| 60 |
+
"LayerKVCache",
|
| 61 |
+
]
|
before_denoise.py
ADDED
|
@@ -0,0 +1,612 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Hugging Face Team and Overworld
|
| 2 |
+
#
|
| 3 |
+
# This program is free software: you can redistribute it and/or modify
|
| 4 |
+
# it under the terms of the GNU General Public License as published by
|
| 5 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 6 |
+
# (at your option) any later version.
|
| 7 |
+
#
|
| 8 |
+
# This program is distributed in the hope that it will be useful,
|
| 9 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 10 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 11 |
+
# GNU General Public License for more details.
|
| 12 |
+
#
|
| 13 |
+
# You should have received a copy of the GNU General Public License
|
| 14 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 15 |
+
|
| 16 |
+
"""Before-denoise blocks for WorldEngine modular pipeline."""
|
| 17 |
+
|
| 18 |
+
from typing import List, Optional, Union
|
| 19 |
+
|
| 20 |
+
import PIL.Image
|
| 21 |
+
import torch
|
| 22 |
+
from torch import nn, Tensor
|
| 23 |
+
from tensordict import TensorDict
|
| 24 |
+
from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE, BlockMask
|
| 25 |
+
|
| 26 |
+
from diffusers.configuration_utils import FrozenDict
|
| 27 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 28 |
+
from diffusers.utils import logging
|
| 29 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 30 |
+
from diffusers.modular_pipelines import (
|
| 31 |
+
ModularPipelineBlocks,
|
| 32 |
+
ModularPipeline,
|
| 33 |
+
PipelineState,
|
| 34 |
+
SequentialPipelineBlocks,
|
| 35 |
+
)
|
| 36 |
+
from diffusers.modular_pipelines.modular_pipeline_utils import (
|
| 37 |
+
ComponentSpec,
|
| 38 |
+
ConfigSpec,
|
| 39 |
+
InputParam,
|
| 40 |
+
OutputParam,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
logger = logging.get_logger(__name__)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def make_block_mask(T: int, L: int, written: torch.Tensor) -> BlockMask:
|
| 47 |
+
"""
|
| 48 |
+
Create a block mask for flex_attention.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
T: Q length for this frame
|
| 52 |
+
L: KV capacity == written.numel()
|
| 53 |
+
written: [L] bool, True where there is valid KV data
|
| 54 |
+
"""
|
| 55 |
+
BS = _DEFAULT_SPARSE_BLOCK_SIZE
|
| 56 |
+
KV_blocks = (L + BS - 1) // BS
|
| 57 |
+
Q_blocks = (T + BS - 1) // BS
|
| 58 |
+
|
| 59 |
+
# [KV_blocks, BS]
|
| 60 |
+
written_blocks = torch.nn.functional.pad(written, (0, KV_blocks * BS - L)).view(
|
| 61 |
+
KV_blocks, BS
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# Block-level occupancy
|
| 65 |
+
block_any = written_blocks.any(-1) # block has at least one written token
|
| 66 |
+
block_all = written_blocks.all(-1) # block is fully written
|
| 67 |
+
|
| 68 |
+
# Every Q-block sees the same KV-block pattern
|
| 69 |
+
nonzero_bm = block_any[None, :].expand(Q_blocks, KV_blocks) # [Q_blocks, KV_blocks]
|
| 70 |
+
full_bm = block_all[None, :].expand_as(nonzero_bm) # [Q_blocks, KV_blocks]
|
| 71 |
+
partial_bm = nonzero_bm & ~full_bm # [Q_blocks, KV_blocks]
|
| 72 |
+
|
| 73 |
+
def dense_to_ordered(dense_mask: torch.Tensor):
|
| 74 |
+
# dense_mask: [Q_blocks, KV_blocks] bool
|
| 75 |
+
# returns: [1,1,Q_blocks], [1,1,Q_blocks,KV_blocks]
|
| 76 |
+
num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32) # [Q_blocks]
|
| 77 |
+
indices = dense_mask.argsort(dim=-1, descending=True, stable=True).to(
|
| 78 |
+
torch.int32
|
| 79 |
+
)
|
| 80 |
+
return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
|
| 81 |
+
|
| 82 |
+
# Partial blocks (need mask_mod)
|
| 83 |
+
kv_num_blocks, kv_indices = dense_to_ordered(partial_bm)
|
| 84 |
+
|
| 85 |
+
# Full blocks (mask_mod can be skipped entirely)
|
| 86 |
+
full_kv_num_blocks, full_kv_indices = dense_to_ordered(full_bm)
|
| 87 |
+
|
| 88 |
+
def mask_mod(b, h, q, kv):
|
| 89 |
+
return written[kv]
|
| 90 |
+
|
| 91 |
+
bm = BlockMask.from_kv_blocks(
|
| 92 |
+
kv_num_blocks,
|
| 93 |
+
kv_indices,
|
| 94 |
+
full_kv_num_blocks,
|
| 95 |
+
full_kv_indices,
|
| 96 |
+
BLOCK_SIZE=BS,
|
| 97 |
+
mask_mod=mask_mod,
|
| 98 |
+
seq_lengths=(T, L),
|
| 99 |
+
compute_q_blocks=False, # no backward, avoids the transpose/_ordered_to_dense path
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
return bm
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class LayerKVCache(nn.Module):
|
| 106 |
+
"""
|
| 107 |
+
Ring-buffer KV cache with fixed capacity L (tokens) for history plus
|
| 108 |
+
one extra frame (tokens_per_frame) at the tail holding the current frame.
|
| 109 |
+
"""
|
| 110 |
+
|
| 111 |
+
def __init__(
|
| 112 |
+
self, B, H, L, Dh, dtype, tokens_per_frame: int, pinned_dilation: int = 1
|
| 113 |
+
):
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.tpf = tokens_per_frame
|
| 116 |
+
self.L = L
|
| 117 |
+
# total KV capacity: ring (L) + tail frame (tpf)
|
| 118 |
+
self.capacity = L + self.tpf
|
| 119 |
+
self.pinned_dilation = pinned_dilation
|
| 120 |
+
self.num_buckets = (L // self.tpf) // self.pinned_dilation
|
| 121 |
+
assert (L // self.tpf) % pinned_dilation == 0 and L % self.tpf == 0
|
| 122 |
+
|
| 123 |
+
# KV buffer: [2, B, H, capacity, Dh]
|
| 124 |
+
self.kv = nn.Buffer(
|
| 125 |
+
torch.zeros(2, B, H, self.capacity, Dh, dtype=dtype),
|
| 126 |
+
persistent=False,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
# which slots have ever been written
|
| 130 |
+
# tail slice [L, L+tpf) always holds the current frame and is considered written
|
| 131 |
+
written = torch.zeros(self.capacity, dtype=torch.bool)
|
| 132 |
+
written[L:] = True
|
| 133 |
+
self.written = nn.Buffer(written, persistent=False)
|
| 134 |
+
|
| 135 |
+
# Precompute indices:
|
| 136 |
+
# frame_offsets: [0, 1, ..., tpf-1] (for ring indexing)
|
| 137 |
+
# current_idx: [L, L+1, ..., L+tpf-1] (tail slice)
|
| 138 |
+
self.frame_offsets = nn.Buffer(
|
| 139 |
+
torch.arange(self.tpf, dtype=torch.long), persistent=False
|
| 140 |
+
)
|
| 141 |
+
self.current_idx = nn.Buffer(self.frame_offsets + L, persistent=False)
|
| 142 |
+
|
| 143 |
+
def reset(self):
|
| 144 |
+
self.kv.zero_()
|
| 145 |
+
self.written.zero_()
|
| 146 |
+
self.written[self.L :].fill_(True)
|
| 147 |
+
|
| 148 |
+
def upsert(self, kv: Tensor, pos_ids: TensorDict, is_frozen: bool):
|
| 149 |
+
"""
|
| 150 |
+
Args:
|
| 151 |
+
kv: [2, B, H, T, Dh] for a single frame (T = tokens_per_frame)
|
| 152 |
+
pos_ids: TensorDict with t_pos [B, T], all equal per frame (ignoring -1)
|
| 153 |
+
"""
|
| 154 |
+
T = self.tpf
|
| 155 |
+
t_pos = pos_ids["t_pos"]
|
| 156 |
+
|
| 157 |
+
if not torch.compiler.is_compiling():
|
| 158 |
+
torch._check(
|
| 159 |
+
kv.size(3) == self.tpf, "KV cache expects exactly one frame per upsert"
|
| 160 |
+
)
|
| 161 |
+
torch._check(t_pos.shape == (kv.size(1), T), "t_pos must be [B, T]")
|
| 162 |
+
torch._check(self.tpf <= self.L, "frame longer than KV ring capacity")
|
| 163 |
+
torch._check(
|
| 164 |
+
self.L % self.tpf == 0,
|
| 165 |
+
f"L ({self.L}) must be a multiple of tokens_per_frame ({self.tpf})",
|
| 166 |
+
)
|
| 167 |
+
torch._check(
|
| 168 |
+
self.kv.size(3) == self.capacity,
|
| 169 |
+
"KV buffer has unexpected length (expected L + tokens_per_frame)",
|
| 170 |
+
)
|
| 171 |
+
torch._check(
|
| 172 |
+
(t_pos >= 0).all().item(),
|
| 173 |
+
"t_pos must be non-negative during inference",
|
| 174 |
+
)
|
| 175 |
+
torch._check(
|
| 176 |
+
((t_pos == t_pos[:, :1]).all()).item(),
|
| 177 |
+
"t_pos must be constant within frame",
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
frame_t = t_pos[0, 0]
|
| 181 |
+
|
| 182 |
+
# map frame_t to a bucket, each bucket owns T contiguous slots
|
| 183 |
+
bucket = (frame_t + (self.pinned_dilation - 1)) // self.pinned_dilation
|
| 184 |
+
slot = bucket % self.num_buckets
|
| 185 |
+
base = slot * T
|
| 186 |
+
|
| 187 |
+
# indices in the ring for this frame: [T] in [0, L)
|
| 188 |
+
ring_idx = self.frame_offsets + base
|
| 189 |
+
|
| 190 |
+
# Always write current frame into the tail slice [L, L+T):
|
| 191 |
+
# this is the "self-attention component" for the current frame.
|
| 192 |
+
self.kv.index_copy_(3, self.current_idx, kv)
|
| 193 |
+
|
| 194 |
+
write_step = frame_t.remainder(self.pinned_dilation) == 0
|
| 195 |
+
mask_written = self.written.clone()
|
| 196 |
+
mask_written[ring_idx] = mask_written[ring_idx] & ~write_step
|
| 197 |
+
bm = make_block_mask(T, self.capacity, mask_written)
|
| 198 |
+
|
| 199 |
+
# Persist current frame into the ring for future queries when unfrozen.
|
| 200 |
+
if not is_frozen:
|
| 201 |
+
# Persist current frame into the ring for future queries.
|
| 202 |
+
dst = torch.where(write_step, ring_idx, self.current_idx)
|
| 203 |
+
self.kv.index_copy_(3, dst, kv)
|
| 204 |
+
self.written[dst] = True
|
| 205 |
+
|
| 206 |
+
k, v = self.kv.unbind(0)
|
| 207 |
+
return k, v, bm
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class StaticKVCache(nn.Module):
|
| 211 |
+
"""Static KV cache with per-layer configuration for local/global attention."""
|
| 212 |
+
|
| 213 |
+
def __init__(self, config, batch_size, dtype):
|
| 214 |
+
super().__init__()
|
| 215 |
+
|
| 216 |
+
self.tpf = config.tokens_per_frame
|
| 217 |
+
|
| 218 |
+
local_L = config.local_window * self.tpf
|
| 219 |
+
global_L = config.global_window * self.tpf
|
| 220 |
+
|
| 221 |
+
period = config.global_attn_period
|
| 222 |
+
off = getattr(config, "global_attn_offset", 0) % period
|
| 223 |
+
self.layers = nn.ModuleList(
|
| 224 |
+
[
|
| 225 |
+
LayerKVCache(
|
| 226 |
+
batch_size,
|
| 227 |
+
getattr(config, "n_kv_heads", config.n_heads),
|
| 228 |
+
global_L if ((layer_idx - off) % period == 0) else local_L,
|
| 229 |
+
config.d_model // config.n_heads,
|
| 230 |
+
dtype,
|
| 231 |
+
self.tpf,
|
| 232 |
+
(
|
| 233 |
+
config.global_pinned_dilation
|
| 234 |
+
if ((layer_idx - off) % period == 0)
|
| 235 |
+
else 1
|
| 236 |
+
),
|
| 237 |
+
)
|
| 238 |
+
for layer_idx in range(config.n_layers)
|
| 239 |
+
]
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
self._is_frozen = True
|
| 243 |
+
|
| 244 |
+
def reset(self):
|
| 245 |
+
for layer in self.layers:
|
| 246 |
+
layer.reset()
|
| 247 |
+
self._is_frozen = True
|
| 248 |
+
|
| 249 |
+
def set_frozen(self, is_frozen: bool):
|
| 250 |
+
self._is_frozen = is_frozen
|
| 251 |
+
|
| 252 |
+
def upsert(self, k: Tensor, v: Tensor, pos_ids: TensorDict, layer: int):
|
| 253 |
+
kv = torch.stack([k, v], dim=0)
|
| 254 |
+
return self.layers[layer].upsert(kv, pos_ids, self._is_frozen)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class WorldEngineSetTimestepsStep(ModularPipelineBlocks):
|
| 258 |
+
"""Sets up the scheduler sigmas for rectified flow denoising."""
|
| 259 |
+
|
| 260 |
+
model_name = "world_engine"
|
| 261 |
+
|
| 262 |
+
@property
|
| 263 |
+
def description(self) -> str:
|
| 264 |
+
return "Sets up scheduler sigmas for rectified flow denoising"
|
| 265 |
+
|
| 266 |
+
@property
|
| 267 |
+
def expected_components(self) -> List[ComponentSpec]:
|
| 268 |
+
return []
|
| 269 |
+
|
| 270 |
+
@property
|
| 271 |
+
def expected_configs(self) -> List[ConfigSpec]:
|
| 272 |
+
return [ConfigSpec("scheduler_sigmas", [1.0, 0.94921875, 0.83984375, 0.0])]
|
| 273 |
+
|
| 274 |
+
@property
|
| 275 |
+
def inputs(self) -> List[InputParam]:
|
| 276 |
+
return [
|
| 277 |
+
InputParam(
|
| 278 |
+
"scheduler_sigmas",
|
| 279 |
+
type_hint=List[float],
|
| 280 |
+
description="Custom scheduler sigmas (overrides config)",
|
| 281 |
+
),
|
| 282 |
+
InputParam(
|
| 283 |
+
"frame_timestamp",
|
| 284 |
+
type_hint=torch.Tensor,
|
| 285 |
+
description="Current frame timestamp",
|
| 286 |
+
),
|
| 287 |
+
]
|
| 288 |
+
|
| 289 |
+
@property
|
| 290 |
+
def intermediate_outputs(self) -> List[OutputParam]:
|
| 291 |
+
return [
|
| 292 |
+
OutputParam(
|
| 293 |
+
"scheduler_sigmas",
|
| 294 |
+
type_hint=torch.Tensor,
|
| 295 |
+
description="Tensor of scheduler sigmas for denoising",
|
| 296 |
+
),
|
| 297 |
+
OutputParam(
|
| 298 |
+
"frame_timestamp",
|
| 299 |
+
type_hint=torch.Tensor,
|
| 300 |
+
description="Current frame timestamp",
|
| 301 |
+
),
|
| 302 |
+
]
|
| 303 |
+
|
| 304 |
+
@torch.no_grad()
|
| 305 |
+
def __call__(
|
| 306 |
+
self, components: ModularPipeline, state: PipelineState
|
| 307 |
+
) -> PipelineState:
|
| 308 |
+
block_state = self.get_block_state(state)
|
| 309 |
+
device = components._execution_device
|
| 310 |
+
dtype = components.transformer.dtype
|
| 311 |
+
|
| 312 |
+
# Use provided sigmas or get from config
|
| 313 |
+
sigmas = block_state.scheduler_sigmas
|
| 314 |
+
if sigmas is None:
|
| 315 |
+
sigmas = components.config.scheduler_sigmas
|
| 316 |
+
block_state.scheduler_sigmas = torch.tensor(
|
| 317 |
+
sigmas, device=device, dtype=dtype
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
frame_ts = block_state.frame_timestamp
|
| 321 |
+
if frame_ts is None:
|
| 322 |
+
frame_ts = torch.tensor([[0]], dtype=torch.long, device=device)
|
| 323 |
+
elif isinstance(frame_ts, int):
|
| 324 |
+
frame_ts = torch.tensor([[frame_ts]], dtype=torch.long, device=device)
|
| 325 |
+
|
| 326 |
+
block_state.frame_timestamp = frame_ts
|
| 327 |
+
|
| 328 |
+
self.set_block_state(state, block_state)
|
| 329 |
+
return components, state
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
class WorldEngineSetupKVCacheStep(ModularPipelineBlocks):
|
| 333 |
+
"""Initializes or reuses the KV cache for autoregressive generation."""
|
| 334 |
+
|
| 335 |
+
model_name = "world_engine"
|
| 336 |
+
|
| 337 |
+
@property
|
| 338 |
+
def description(self) -> str:
|
| 339 |
+
return "Initializes or reuses KV cache for autoregressive frame generation"
|
| 340 |
+
|
| 341 |
+
@property
|
| 342 |
+
def expected_components(self) -> List[ComponentSpec]:
|
| 343 |
+
return []
|
| 344 |
+
|
| 345 |
+
@property
|
| 346 |
+
def inputs(self) -> List[InputParam]:
|
| 347 |
+
return [
|
| 348 |
+
InputParam(
|
| 349 |
+
"kv_cache",
|
| 350 |
+
type_hint=Optional[StaticKVCache],
|
| 351 |
+
description="Existing KV cache (will be reused if provided)",
|
| 352 |
+
),
|
| 353 |
+
InputParam(
|
| 354 |
+
"reset_cache",
|
| 355 |
+
type_hint=bool,
|
| 356 |
+
default=False,
|
| 357 |
+
description="If True, reset the KV cache even if one exists",
|
| 358 |
+
),
|
| 359 |
+
]
|
| 360 |
+
|
| 361 |
+
@property
|
| 362 |
+
def intermediate_outputs(self) -> List[OutputParam]:
|
| 363 |
+
return [
|
| 364 |
+
OutputParam(
|
| 365 |
+
"kv_cache",
|
| 366 |
+
type_hint=StaticKVCache,
|
| 367 |
+
description="KV cache for transformer attention",
|
| 368 |
+
),
|
| 369 |
+
]
|
| 370 |
+
|
| 371 |
+
@torch.no_grad()
|
| 372 |
+
def __call__(
|
| 373 |
+
self, components: ModularPipeline, state: PipelineState
|
| 374 |
+
) -> PipelineState:
|
| 375 |
+
block_state = self.get_block_state(state)
|
| 376 |
+
device = components._execution_device
|
| 377 |
+
dtype = components.transformer.dtype
|
| 378 |
+
|
| 379 |
+
# Create or reuse KV cache
|
| 380 |
+
if block_state.kv_cache is None:
|
| 381 |
+
block_state.kv_cache = StaticKVCache(
|
| 382 |
+
components.transformer.config,
|
| 383 |
+
batch_size=1,
|
| 384 |
+
dtype=dtype,
|
| 385 |
+
).to(device)
|
| 386 |
+
elif block_state.reset_cache:
|
| 387 |
+
block_state.kv_cache.reset()
|
| 388 |
+
|
| 389 |
+
self.set_block_state(state, block_state)
|
| 390 |
+
return components, state
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
class WorldEnginePrepareLatentsStep(ModularPipelineBlocks):
|
| 394 |
+
"""Prepares latents for frame generation, optionally encoding an input image."""
|
| 395 |
+
|
| 396 |
+
model_name = "world_engine"
|
| 397 |
+
|
| 398 |
+
@property
|
| 399 |
+
def description(self) -> str:
|
| 400 |
+
return (
|
| 401 |
+
"Prepares latents for frame generation. If an image is provided on the "
|
| 402 |
+
"first frame, encodes it and caches it as context. Always creates fresh "
|
| 403 |
+
"random noise for the actual denoising."
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
@property
|
| 407 |
+
def expected_components(self) -> List[ComponentSpec]:
|
| 408 |
+
return [
|
| 409 |
+
ComponentSpec(
|
| 410 |
+
"image_processor",
|
| 411 |
+
VaeImageProcessor,
|
| 412 |
+
config=FrozenDict(
|
| 413 |
+
{
|
| 414 |
+
"vae_scale_factor": 16,
|
| 415 |
+
"do_normalize": False,
|
| 416 |
+
"do_convert_rgb": False,
|
| 417 |
+
}
|
| 418 |
+
),
|
| 419 |
+
default_creation_method="from_config",
|
| 420 |
+
),
|
| 421 |
+
]
|
| 422 |
+
|
| 423 |
+
@property
|
| 424 |
+
def expected_configs(self) -> List[ConfigSpec]:
|
| 425 |
+
return [
|
| 426 |
+
ConfigSpec("channels", 16),
|
| 427 |
+
ConfigSpec("height", 16),
|
| 428 |
+
ConfigSpec("width", 16),
|
| 429 |
+
ConfigSpec("patch", [2, 2]),
|
| 430 |
+
ConfigSpec("vae_scale_factor", 16),
|
| 431 |
+
]
|
| 432 |
+
|
| 433 |
+
@property
|
| 434 |
+
def inputs(self) -> List[InputParam]:
|
| 435 |
+
return [
|
| 436 |
+
InputParam(
|
| 437 |
+
"image",
|
| 438 |
+
type_hint=Union[PIL.Image.Image, torch.Tensor],
|
| 439 |
+
description="Input image (PIL Image or [H, W, 3] uint8 tensor), only used on first frame",
|
| 440 |
+
),
|
| 441 |
+
InputParam(
|
| 442 |
+
"latents",
|
| 443 |
+
type_hint=torch.Tensor,
|
| 444 |
+
description="Latent tensor for denoising [1, 1, C, H, W]. Only used if use_random_latents=False.",
|
| 445 |
+
),
|
| 446 |
+
InputParam(
|
| 447 |
+
"use_random_latents",
|
| 448 |
+
type_hint=bool,
|
| 449 |
+
default=True,
|
| 450 |
+
description="If True, always generate fresh random latents. If False, use provided latents.",
|
| 451 |
+
),
|
| 452 |
+
InputParam(
|
| 453 |
+
"kv_cache",
|
| 454 |
+
description="KV cache to update",
|
| 455 |
+
),
|
| 456 |
+
InputParam(
|
| 457 |
+
"frame_timestamp",
|
| 458 |
+
type_hint=torch.Tensor,
|
| 459 |
+
description="Current frame timestamp",
|
| 460 |
+
),
|
| 461 |
+
InputParam(
|
| 462 |
+
"prompt_embeds",
|
| 463 |
+
type_hint=torch.Tensor,
|
| 464 |
+
description="Prompt embeddings for cache pass",
|
| 465 |
+
),
|
| 466 |
+
InputParam(
|
| 467 |
+
"prompt_pad_mask",
|
| 468 |
+
type_hint=torch.Tensor,
|
| 469 |
+
description="Prompt padding mask",
|
| 470 |
+
),
|
| 471 |
+
InputParam(
|
| 472 |
+
"button_tensor",
|
| 473 |
+
type_hint=torch.Tensor,
|
| 474 |
+
description="Button tensor for cache pass",
|
| 475 |
+
),
|
| 476 |
+
InputParam(
|
| 477 |
+
"mouse_tensor",
|
| 478 |
+
type_hint=torch.Tensor,
|
| 479 |
+
description="Mouse tensor for cache pass",
|
| 480 |
+
),
|
| 481 |
+
InputParam(
|
| 482 |
+
"scroll_tensor",
|
| 483 |
+
type_hint=torch.Tensor,
|
| 484 |
+
description="Scroll tensor for cache pass",
|
| 485 |
+
),
|
| 486 |
+
InputParam(
|
| 487 |
+
"generator",
|
| 488 |
+
type_hint=torch.Generator,
|
| 489 |
+
default=None,
|
| 490 |
+
description="torch Generator for deterministic output",
|
| 491 |
+
),
|
| 492 |
+
]
|
| 493 |
+
|
| 494 |
+
@property
|
| 495 |
+
def intermediate_outputs(self) -> List[OutputParam]:
|
| 496 |
+
return [
|
| 497 |
+
OutputParam(
|
| 498 |
+
"latents",
|
| 499 |
+
type_hint=torch.Tensor,
|
| 500 |
+
description="Latent tensor for denoising [1, 1, C, H, W]",
|
| 501 |
+
),
|
| 502 |
+
]
|
| 503 |
+
|
| 504 |
+
@staticmethod
|
| 505 |
+
def _cache_pass(
|
| 506 |
+
transformer,
|
| 507 |
+
x,
|
| 508 |
+
frame_timestamp,
|
| 509 |
+
prompt_emb,
|
| 510 |
+
prompt_pad_mask,
|
| 511 |
+
mouse,
|
| 512 |
+
button,
|
| 513 |
+
scroll,
|
| 514 |
+
kv_cache,
|
| 515 |
+
):
|
| 516 |
+
"""Cache pass to persist frame in KV cache."""
|
| 517 |
+
kv_cache.set_frozen(False)
|
| 518 |
+
transformer(
|
| 519 |
+
x=x,
|
| 520 |
+
sigma=x.new_zeros((x.size(0), x.size(1))),
|
| 521 |
+
frame_timestamp=frame_timestamp,
|
| 522 |
+
prompt_emb=prompt_emb,
|
| 523 |
+
prompt_pad_mask=prompt_pad_mask,
|
| 524 |
+
mouse=mouse,
|
| 525 |
+
button=button,
|
| 526 |
+
scroll=scroll,
|
| 527 |
+
kv_cache=kv_cache,
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
@torch.inference_mode()
|
| 531 |
+
def __call__(
|
| 532 |
+
self, components: ModularPipeline, state: PipelineState
|
| 533 |
+
) -> PipelineState:
|
| 534 |
+
block_state = self.get_block_state(state)
|
| 535 |
+
device = components._execution_device
|
| 536 |
+
dtype = components.transformer.dtype
|
| 537 |
+
|
| 538 |
+
# Get latent shape info
|
| 539 |
+
channels = components.config.channels
|
| 540 |
+
height = components.config.height
|
| 541 |
+
width = components.config.width
|
| 542 |
+
patch = components.config.patch
|
| 543 |
+
|
| 544 |
+
pH, pW = patch if isinstance(patch, (list, tuple)) else (patch, patch)
|
| 545 |
+
shape = (
|
| 546 |
+
1,
|
| 547 |
+
1,
|
| 548 |
+
channels,
|
| 549 |
+
components.config.vae_scale_factor * pH,
|
| 550 |
+
components.config.vae_scale_factor * pW,
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
if block_state.image is not None:
|
| 554 |
+
image = block_state.image
|
| 555 |
+
# Preprocess: PIL/tensor -> [B, C, H, W] float32 in [0, 1]
|
| 556 |
+
image = components.image_processor.preprocess(
|
| 557 |
+
image,
|
| 558 |
+
height=height,
|
| 559 |
+
width=width,
|
| 560 |
+
)
|
| 561 |
+
# Convert to [H, W, 3] uint8 for VAE encoder
|
| 562 |
+
image = (image[0].permute(1, 2, 0) * 255).to(torch.uint8)
|
| 563 |
+
|
| 564 |
+
assert image.dtype == torch.uint8, (
|
| 565 |
+
f"Expected uint8 image, got {image.dtype}"
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
latents = components.vae.encode(image)
|
| 569 |
+
latents = latents.unsqueeze(1)
|
| 570 |
+
|
| 571 |
+
# Run cache pass to persist encoded frame
|
| 572 |
+
self._cache_pass(
|
| 573 |
+
components.transformer,
|
| 574 |
+
latents,
|
| 575 |
+
block_state.frame_timestamp,
|
| 576 |
+
block_state.prompt_embeds,
|
| 577 |
+
block_state.prompt_pad_mask,
|
| 578 |
+
block_state.mouse_tensor,
|
| 579 |
+
block_state.button_tensor,
|
| 580 |
+
block_state.scroll_tensor,
|
| 581 |
+
block_state.kv_cache,
|
| 582 |
+
)
|
| 583 |
+
block_state.frame_timestamp.add_(1)
|
| 584 |
+
|
| 585 |
+
# Generate latents based on use_random_latents flag
|
| 586 |
+
if block_state.use_random_latents or block_state.latents is None:
|
| 587 |
+
block_state.latents = torch.randn(
|
| 588 |
+
shape, device=device, dtype=torch.bfloat16
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
self.set_block_state(state, block_state)
|
| 592 |
+
return components, state
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
class WorldEngineBeforeDenoiseStep(SequentialPipelineBlocks):
|
| 596 |
+
"""Sequential pipeline that prepares all inputs for denoising."""
|
| 597 |
+
|
| 598 |
+
block_classes = [
|
| 599 |
+
WorldEngineSetTimestepsStep,
|
| 600 |
+
WorldEngineSetupKVCacheStep,
|
| 601 |
+
WorldEnginePrepareLatentsStep,
|
| 602 |
+
]
|
| 603 |
+
block_names = ["set_timesteps", "setup_kv_cache", "prepare_latents"]
|
| 604 |
+
|
| 605 |
+
@property
|
| 606 |
+
def description(self) -> str:
|
| 607 |
+
return (
|
| 608 |
+
"Before denoise step that prepares inputs for denoising:\n"
|
| 609 |
+
" - WorldEngineSetTimestepsStep: Set up scheduler sigmas\n"
|
| 610 |
+
" - WorldEngineSetupKVCacheStep: Initialize or reuse KV cache\n"
|
| 611 |
+
" - WorldEnginePrepareLatentsStep: Encode image (if first frame) and create noise"
|
| 612 |
+
)
|
decoders.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Hugging Face Team and Overworld
|
| 2 |
+
#
|
| 3 |
+
# This program is free software: you can redistribute it and/or modify
|
| 4 |
+
# it under the terms of the GNU General Public License as published by
|
| 5 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 6 |
+
# (at your option) any later version.
|
| 7 |
+
#
|
| 8 |
+
# This program is distributed in the hope that it will be useful,
|
| 9 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 10 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 11 |
+
# GNU General Public License for more details.
|
| 12 |
+
#
|
| 13 |
+
# You should have received a copy of the GNU General Public License
|
| 14 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 15 |
+
|
| 16 |
+
"""Decoder blocks for WorldEngine modular pipeline."""
|
| 17 |
+
|
| 18 |
+
from typing import List, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import PIL.Image
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
from diffusers import AutoModel
|
| 25 |
+
from diffusers.configuration_utils import FrozenDict
|
| 26 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 27 |
+
from diffusers.utils import logging
|
| 28 |
+
from diffusers.modular_pipelines import (
|
| 29 |
+
ModularPipelineBlocks,
|
| 30 |
+
ModularPipeline,
|
| 31 |
+
PipelineState,
|
| 32 |
+
)
|
| 33 |
+
from diffusers.modular_pipelines.modular_pipeline_utils import (
|
| 34 |
+
ComponentSpec,
|
| 35 |
+
InputParam,
|
| 36 |
+
OutputParam,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class WorldEngineDecodeStep(ModularPipelineBlocks):
|
| 43 |
+
"""Decodes denoised latents back to RGB image using VAE."""
|
| 44 |
+
|
| 45 |
+
model_name = "world_engine"
|
| 46 |
+
|
| 47 |
+
@property
|
| 48 |
+
def expected_components(self) -> List[ComponentSpec]:
|
| 49 |
+
return [
|
| 50 |
+
ComponentSpec("vae", AutoModel),
|
| 51 |
+
ComponentSpec(
|
| 52 |
+
"image_processor",
|
| 53 |
+
VaeImageProcessor,
|
| 54 |
+
config=FrozenDict(
|
| 55 |
+
{
|
| 56 |
+
"vae_scale_factor": 16,
|
| 57 |
+
"do_normalize": False,
|
| 58 |
+
"do_convert_rgb": True,
|
| 59 |
+
}
|
| 60 |
+
),
|
| 61 |
+
default_creation_method="from_config",
|
| 62 |
+
),
|
| 63 |
+
]
|
| 64 |
+
|
| 65 |
+
@property
|
| 66 |
+
def description(self) -> str:
|
| 67 |
+
return "Decodes denoised latents to RGB image using the VAE decoder"
|
| 68 |
+
|
| 69 |
+
@property
|
| 70 |
+
def inputs(self) -> List[InputParam]:
|
| 71 |
+
return [
|
| 72 |
+
InputParam(
|
| 73 |
+
"latents",
|
| 74 |
+
required=True,
|
| 75 |
+
type_hint=torch.Tensor,
|
| 76 |
+
description="Denoised latent tensor [1, 1, C, H, W]",
|
| 77 |
+
),
|
| 78 |
+
InputParam(
|
| 79 |
+
"output_type",
|
| 80 |
+
default="pil",
|
| 81 |
+
description="The output format for the generated images (pil, latent, pt, or np)",
|
| 82 |
+
),
|
| 83 |
+
]
|
| 84 |
+
|
| 85 |
+
@property
|
| 86 |
+
def intermediate_outputs(self) -> List[OutputParam]:
|
| 87 |
+
return [
|
| 88 |
+
OutputParam(
|
| 89 |
+
"images",
|
| 90 |
+
type_hint=Union[PIL.Image.Image, torch.Tensor, np.ndarray],
|
| 91 |
+
description="Decoded RGB image in requested output format",
|
| 92 |
+
),
|
| 93 |
+
]
|
| 94 |
+
|
| 95 |
+
@torch.no_grad()
|
| 96 |
+
def __call__(
|
| 97 |
+
self, components: ModularPipeline, state: PipelineState
|
| 98 |
+
) -> PipelineState:
|
| 99 |
+
block_state = self.get_block_state(state)
|
| 100 |
+
latents = block_state.latents
|
| 101 |
+
output_type = block_state.output_type or "pil"
|
| 102 |
+
|
| 103 |
+
if output_type == "latent":
|
| 104 |
+
block_state.images = latents
|
| 105 |
+
else:
|
| 106 |
+
# Decode to image
|
| 107 |
+
# VAE expects [B, C, H, W] input, squeeze frame dim
|
| 108 |
+
# VAE returns [H, W, 3] uint8 tensor
|
| 109 |
+
image = components.vae.decode(latents.squeeze(1))
|
| 110 |
+
|
| 111 |
+
# Postprocess based on output_type
|
| 112 |
+
if output_type == "pt":
|
| 113 |
+
block_state.images = image
|
| 114 |
+
elif output_type == "np":
|
| 115 |
+
block_state.images = image.cpu().numpy()
|
| 116 |
+
else: # "pil"
|
| 117 |
+
block_state.images = PIL.Image.fromarray(image.cpu().numpy())
|
| 118 |
+
|
| 119 |
+
# Clear latents so next frame generates fresh random noise
|
| 120 |
+
block_state.latents = None
|
| 121 |
+
self.set_block_state(state, block_state)
|
| 122 |
+
return components, state
|
denoise.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Hugging Face Team and Overworld
|
| 2 |
+
#
|
| 3 |
+
# This program is free software: you can redistribute it and/or modify
|
| 4 |
+
# it under the terms of the GNU General Public License as published by
|
| 5 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 6 |
+
# (at your option) any later version.
|
| 7 |
+
#
|
| 8 |
+
# This program is distributed in the hope that it will be useful,
|
| 9 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 10 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 11 |
+
# GNU General Public License for more details.
|
| 12 |
+
#
|
| 13 |
+
# You should have received a copy of the GNU General Public License
|
| 14 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 15 |
+
|
| 16 |
+
"""Denoising block for WorldEngine modular pipeline."""
|
| 17 |
+
|
| 18 |
+
from typing import List
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from diffusers.utils import logging
|
| 23 |
+
from diffusers.modular_pipelines import (
|
| 24 |
+
ModularPipelineBlocks,
|
| 25 |
+
ModularPipeline,
|
| 26 |
+
PipelineState,
|
| 27 |
+
)
|
| 28 |
+
from diffusers.modular_pipelines.modular_pipeline_utils import (
|
| 29 |
+
ComponentSpec,
|
| 30 |
+
InputParam,
|
| 31 |
+
OutputParam,
|
| 32 |
+
)
|
| 33 |
+
from diffusers import AutoModel
|
| 34 |
+
|
| 35 |
+
logger = logging.get_logger(__name__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class WorldEngineDenoiseLoop(ModularPipelineBlocks):
|
| 39 |
+
"""Denoises latents using rectified flow and updates KV cache."""
|
| 40 |
+
|
| 41 |
+
model_name = "world_engine"
|
| 42 |
+
|
| 43 |
+
@property
|
| 44 |
+
def expected_components(self) -> List[ComponentSpec]:
|
| 45 |
+
return [ComponentSpec("transformer", AutoModel)]
|
| 46 |
+
|
| 47 |
+
@property
|
| 48 |
+
def description(self) -> str:
|
| 49 |
+
return (
|
| 50 |
+
"Denoises latents using rectified flow (x = x + dsigma * v) "
|
| 51 |
+
"and updates KV cache for autoregressive generation."
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
@property
|
| 55 |
+
def inputs(self) -> List[InputParam]:
|
| 56 |
+
return [
|
| 57 |
+
InputParam(
|
| 58 |
+
"scheduler_sigmas",
|
| 59 |
+
required=True,
|
| 60 |
+
type_hint=torch.Tensor,
|
| 61 |
+
description="Scheduler sigmas for denoising",
|
| 62 |
+
),
|
| 63 |
+
InputParam(
|
| 64 |
+
"latents",
|
| 65 |
+
required=True,
|
| 66 |
+
type_hint=torch.Tensor,
|
| 67 |
+
description="Initial noisy latents [1, 1, C, H, W]",
|
| 68 |
+
),
|
| 69 |
+
InputParam(
|
| 70 |
+
"kv_cache",
|
| 71 |
+
required=True,
|
| 72 |
+
description="KV cache for transformer attention",
|
| 73 |
+
),
|
| 74 |
+
InputParam(
|
| 75 |
+
"frame_timestamp",
|
| 76 |
+
required=True,
|
| 77 |
+
type_hint=torch.Tensor,
|
| 78 |
+
description="Current frame timestamp",
|
| 79 |
+
),
|
| 80 |
+
InputParam(
|
| 81 |
+
"prompt_embeds",
|
| 82 |
+
required=True,
|
| 83 |
+
type_hint=torch.Tensor,
|
| 84 |
+
description="Text embeddings for conditioning",
|
| 85 |
+
),
|
| 86 |
+
InputParam(
|
| 87 |
+
"prompt_pad_mask",
|
| 88 |
+
type_hint=torch.Tensor,
|
| 89 |
+
description="Padding mask for prompt embeddings",
|
| 90 |
+
),
|
| 91 |
+
InputParam(
|
| 92 |
+
"button_tensor",
|
| 93 |
+
required=True,
|
| 94 |
+
type_hint=torch.Tensor,
|
| 95 |
+
description="One-hot encoded button tensor",
|
| 96 |
+
),
|
| 97 |
+
InputParam(
|
| 98 |
+
"mouse_tensor",
|
| 99 |
+
required=True,
|
| 100 |
+
type_hint=torch.Tensor,
|
| 101 |
+
description="Mouse velocity tensor",
|
| 102 |
+
),
|
| 103 |
+
InputParam(
|
| 104 |
+
"scroll_tensor",
|
| 105 |
+
required=True,
|
| 106 |
+
type_hint=torch.Tensor,
|
| 107 |
+
description="Scroll wheel sign tensor",
|
| 108 |
+
),
|
| 109 |
+
]
|
| 110 |
+
|
| 111 |
+
@property
|
| 112 |
+
def intermediate_outputs(self) -> List[OutputParam]:
|
| 113 |
+
return [
|
| 114 |
+
OutputParam(
|
| 115 |
+
"latents",
|
| 116 |
+
type_hint=torch.Tensor,
|
| 117 |
+
description="Denoised latents",
|
| 118 |
+
),
|
| 119 |
+
]
|
| 120 |
+
|
| 121 |
+
@staticmethod
|
| 122 |
+
def _denoise_pass(
|
| 123 |
+
transformer,
|
| 124 |
+
x,
|
| 125 |
+
sigmas,
|
| 126 |
+
frame_timestamp,
|
| 127 |
+
prompt_emb,
|
| 128 |
+
prompt_pad_mask,
|
| 129 |
+
mouse,
|
| 130 |
+
button,
|
| 131 |
+
scroll,
|
| 132 |
+
kv_cache,
|
| 133 |
+
):
|
| 134 |
+
"""Denoising loop using rectified flow."""
|
| 135 |
+
kv_cache.set_frozen(True)
|
| 136 |
+
sigma = x.new_empty((x.size(0), x.size(1)))
|
| 137 |
+
for step_sig, step_dsig in zip(sigmas, sigmas.diff()):
|
| 138 |
+
v = transformer(
|
| 139 |
+
x=x,
|
| 140 |
+
sigma=sigma.fill_(step_sig),
|
| 141 |
+
frame_timestamp=frame_timestamp,
|
| 142 |
+
prompt_emb=prompt_emb,
|
| 143 |
+
prompt_pad_mask=prompt_pad_mask,
|
| 144 |
+
mouse=mouse,
|
| 145 |
+
button=button,
|
| 146 |
+
scroll=scroll,
|
| 147 |
+
kv_cache=kv_cache,
|
| 148 |
+
)
|
| 149 |
+
x = x + step_dsig * v
|
| 150 |
+
return x
|
| 151 |
+
|
| 152 |
+
@staticmethod
|
| 153 |
+
def _cache_pass(
|
| 154 |
+
transformer,
|
| 155 |
+
x,
|
| 156 |
+
frame_timestamp,
|
| 157 |
+
prompt_emb,
|
| 158 |
+
prompt_pad_mask,
|
| 159 |
+
mouse,
|
| 160 |
+
button,
|
| 161 |
+
scroll,
|
| 162 |
+
kv_cache,
|
| 163 |
+
):
|
| 164 |
+
"""Cache pass to persist frame for next generation."""
|
| 165 |
+
kv_cache.set_frozen(False)
|
| 166 |
+
transformer(
|
| 167 |
+
x=x,
|
| 168 |
+
sigma=x.new_zeros((x.size(0), x.size(1))),
|
| 169 |
+
frame_timestamp=frame_timestamp,
|
| 170 |
+
prompt_emb=prompt_emb,
|
| 171 |
+
prompt_pad_mask=prompt_pad_mask,
|
| 172 |
+
mouse=mouse,
|
| 173 |
+
button=button,
|
| 174 |
+
scroll=scroll,
|
| 175 |
+
kv_cache=kv_cache,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
@torch.inference_mode()
|
| 179 |
+
def __call__(
|
| 180 |
+
self, components: ModularPipeline, state: PipelineState
|
| 181 |
+
) -> PipelineState:
|
| 182 |
+
block_state = self.get_block_state(state)
|
| 183 |
+
block_state.latents = self._denoise_pass(
|
| 184 |
+
components.transformer,
|
| 185 |
+
block_state.latents,
|
| 186 |
+
block_state.scheduler_sigmas,
|
| 187 |
+
block_state.frame_timestamp,
|
| 188 |
+
block_state.prompt_embeds,
|
| 189 |
+
block_state.prompt_pad_mask,
|
| 190 |
+
block_state.mouse_tensor,
|
| 191 |
+
block_state.button_tensor,
|
| 192 |
+
block_state.scroll_tensor,
|
| 193 |
+
block_state.kv_cache,
|
| 194 |
+
).clone()
|
| 195 |
+
|
| 196 |
+
self._cache_pass(
|
| 197 |
+
components.transformer,
|
| 198 |
+
block_state.latents,
|
| 199 |
+
block_state.frame_timestamp,
|
| 200 |
+
block_state.prompt_embeds,
|
| 201 |
+
block_state.prompt_pad_mask,
|
| 202 |
+
block_state.mouse_tensor,
|
| 203 |
+
block_state.button_tensor,
|
| 204 |
+
block_state.scroll_tensor,
|
| 205 |
+
block_state.kv_cache,
|
| 206 |
+
)
|
| 207 |
+
block_state.frame_timestamp.add_(1)
|
| 208 |
+
|
| 209 |
+
self.set_block_state(state, block_state)
|
| 210 |
+
return components, state
|
encoders.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Hugging Face Team and Overworld
|
| 2 |
+
#
|
| 3 |
+
# This program is free software: you can redistribute it and/or modify
|
| 4 |
+
# it under the terms of the GNU General Public License as published by
|
| 5 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 6 |
+
# (at your option) any later version.
|
| 7 |
+
#
|
| 8 |
+
# This program is distributed in the hope that it will be useful,
|
| 9 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 10 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 11 |
+
# GNU General Public License for more details.
|
| 12 |
+
#
|
| 13 |
+
# You should have received a copy of the GNU General Public License
|
| 14 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 15 |
+
|
| 16 |
+
"""Text and controller encoder blocks for WorldEngine modular pipeline."""
|
| 17 |
+
|
| 18 |
+
import html
|
| 19 |
+
from typing import List, Set, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import regex as re
|
| 22 |
+
import torch
|
| 23 |
+
from transformers import AutoTokenizer, UMT5EncoderModel
|
| 24 |
+
|
| 25 |
+
from diffusers.utils import is_ftfy_available, logging
|
| 26 |
+
from diffusers.modular_pipelines import (
|
| 27 |
+
ModularPipelineBlocks,
|
| 28 |
+
ModularPipeline,
|
| 29 |
+
PipelineState,
|
| 30 |
+
)
|
| 31 |
+
from diffusers.modular_pipelines.modular_pipeline_utils import (
|
| 32 |
+
ComponentSpec,
|
| 33 |
+
ConfigSpec,
|
| 34 |
+
InputParam,
|
| 35 |
+
OutputParam,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
if is_ftfy_available():
|
| 39 |
+
import ftfy
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
logger = logging.get_logger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def basic_clean(text):
|
| 46 |
+
text = ftfy.fix_text(text)
|
| 47 |
+
text = html.unescape(html.unescape(text))
|
| 48 |
+
return text.strip()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def whitespace_clean(text):
|
| 52 |
+
text = re.sub(r"\s+", " ", text)
|
| 53 |
+
text = text.strip()
|
| 54 |
+
return text
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def prompt_clean(text):
|
| 58 |
+
text = whitespace_clean(basic_clean(text))
|
| 59 |
+
return text
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class WorldEngineTextEncoderStep(ModularPipelineBlocks):
|
| 63 |
+
"""Encodes text prompts using UMT5-XL for conditioning."""
|
| 64 |
+
|
| 65 |
+
model_name = "world_engine"
|
| 66 |
+
|
| 67 |
+
@property
|
| 68 |
+
def description(self) -> str:
|
| 69 |
+
return (
|
| 70 |
+
"Text Encoder step that generates text embeddings to guide frame generation"
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
@property
|
| 74 |
+
def expected_components(self) -> List[ComponentSpec]:
|
| 75 |
+
return [
|
| 76 |
+
ComponentSpec("text_encoder", UMT5EncoderModel),
|
| 77 |
+
ComponentSpec("tokenizer", AutoTokenizer),
|
| 78 |
+
]
|
| 79 |
+
|
| 80 |
+
@property
|
| 81 |
+
def inputs(self) -> List[InputParam]:
|
| 82 |
+
return [
|
| 83 |
+
InputParam(
|
| 84 |
+
"prompt",
|
| 85 |
+
description="The prompt or prompts to guide the frame generation",
|
| 86 |
+
),
|
| 87 |
+
InputParam(
|
| 88 |
+
"prompt_embeds",
|
| 89 |
+
type_hint=torch.Tensor,
|
| 90 |
+
description="Pre-computed text embeddings",
|
| 91 |
+
),
|
| 92 |
+
InputParam(
|
| 93 |
+
"prompt_pad_mask",
|
| 94 |
+
type_hint=torch.Tensor,
|
| 95 |
+
description="Padding mask for prompt embeddings",
|
| 96 |
+
),
|
| 97 |
+
]
|
| 98 |
+
|
| 99 |
+
@property
|
| 100 |
+
def intermediate_outputs(self) -> List[OutputParam]:
|
| 101 |
+
return [
|
| 102 |
+
OutputParam(
|
| 103 |
+
"prompt_embeds",
|
| 104 |
+
type_hint=torch.Tensor,
|
| 105 |
+
kwargs_type="denoiser_input_fields",
|
| 106 |
+
description="Text embeddings used to guide frame generation",
|
| 107 |
+
),
|
| 108 |
+
OutputParam(
|
| 109 |
+
"prompt_pad_mask",
|
| 110 |
+
type_hint=torch.Tensor,
|
| 111 |
+
kwargs_type="denoiser_input_fields",
|
| 112 |
+
description="Padding mask for prompt embeddings",
|
| 113 |
+
),
|
| 114 |
+
]
|
| 115 |
+
|
| 116 |
+
@staticmethod
|
| 117 |
+
def check_inputs(block_state):
|
| 118 |
+
if block_state.prompt is not None and (
|
| 119 |
+
not isinstance(block_state.prompt, str)
|
| 120 |
+
and not isinstance(block_state.prompt, list)
|
| 121 |
+
):
|
| 122 |
+
raise ValueError(
|
| 123 |
+
f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}"
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
@staticmethod
|
| 127 |
+
def encode_prompt(
|
| 128 |
+
components,
|
| 129 |
+
prompt: Union[str, List[str]],
|
| 130 |
+
device: torch.device,
|
| 131 |
+
max_sequence_length: int = 512,
|
| 132 |
+
):
|
| 133 |
+
dtype = components.text_encoder.dtype
|
| 134 |
+
|
| 135 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 136 |
+
prompt = [prompt_clean(p) for p in prompt]
|
| 137 |
+
|
| 138 |
+
text_inputs = components.tokenizer(
|
| 139 |
+
prompt,
|
| 140 |
+
padding="max_length",
|
| 141 |
+
max_length=max_sequence_length,
|
| 142 |
+
truncation=True,
|
| 143 |
+
return_attention_mask=True,
|
| 144 |
+
return_tensors="pt",
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
text_input_ids = text_inputs.input_ids.to(device)
|
| 148 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
| 149 |
+
|
| 150 |
+
prompt_embeds = components.text_encoder(
|
| 151 |
+
text_input_ids, attention_mask
|
| 152 |
+
).last_hidden_state
|
| 153 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
| 154 |
+
|
| 155 |
+
# Zero out padding
|
| 156 |
+
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).type_as(
|
| 157 |
+
prompt_embeds
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
# Create padding mask (True where padded)
|
| 161 |
+
prompt_pad_mask = attention_mask.eq(0)
|
| 162 |
+
|
| 163 |
+
return prompt_embeds, prompt_pad_mask
|
| 164 |
+
|
| 165 |
+
@torch.no_grad()
|
| 166 |
+
def __call__(
|
| 167 |
+
self, components: ModularPipeline, state: PipelineState
|
| 168 |
+
) -> PipelineState:
|
| 169 |
+
block_state = self.get_block_state(state)
|
| 170 |
+
self.check_inputs(block_state)
|
| 171 |
+
|
| 172 |
+
device = components._execution_device
|
| 173 |
+
if block_state.prompt_embeds is None:
|
| 174 |
+
block_state.prompt = block_state.prompt or "An explorable world"
|
| 175 |
+
(
|
| 176 |
+
block_state.prompt_embeds,
|
| 177 |
+
block_state.prompt_pad_mask,
|
| 178 |
+
) = self.encode_prompt(components, block_state.prompt, device)
|
| 179 |
+
block_state.prompt_embeds = block_state.prompt_embeds.contiguous()
|
| 180 |
+
|
| 181 |
+
if block_state.prompt_pad_mask is None:
|
| 182 |
+
block_state.prompt_pad_mask = torch.zeros(
|
| 183 |
+
block_state.prompt_embeds.shape[:2],
|
| 184 |
+
dtype=torch.bool,
|
| 185 |
+
device=device,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
self.set_block_state(state, block_state)
|
| 189 |
+
return components, state
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class WorldEngineControllerEncoderStep(ModularPipelineBlocks):
|
| 193 |
+
"""Encodes controller inputs (mouse + buttons + scroll) for conditioning."""
|
| 194 |
+
|
| 195 |
+
model_name = "world_engine"
|
| 196 |
+
|
| 197 |
+
@property
|
| 198 |
+
def description(self) -> str:
|
| 199 |
+
return "Controller Encoder step that encodes mouse, button, and scroll inputs for conditioning"
|
| 200 |
+
|
| 201 |
+
@property
|
| 202 |
+
def expected_components(self) -> List[ComponentSpec]:
|
| 203 |
+
return [] # Controller embedding is part of transformer
|
| 204 |
+
|
| 205 |
+
@property
|
| 206 |
+
def expected_configs(self) -> List[ComponentSpec]:
|
| 207 |
+
return [ConfigSpec("n_buttons", 256)]
|
| 208 |
+
|
| 209 |
+
@property
|
| 210 |
+
def inputs(self) -> List[InputParam]:
|
| 211 |
+
return [
|
| 212 |
+
InputParam(
|
| 213 |
+
"button",
|
| 214 |
+
type_hint=Set[int],
|
| 215 |
+
default=set(),
|
| 216 |
+
description="Set of pressed button IDs",
|
| 217 |
+
),
|
| 218 |
+
InputParam(
|
| 219 |
+
"mouse",
|
| 220 |
+
type_hint=Tuple[float, float],
|
| 221 |
+
default=(0.0, 0.0),
|
| 222 |
+
description="Mouse velocity (x, y)",
|
| 223 |
+
),
|
| 224 |
+
InputParam(
|
| 225 |
+
"scroll",
|
| 226 |
+
type_hint=int,
|
| 227 |
+
default=0,
|
| 228 |
+
description="Scroll wheel direction (-1, 0, 1)",
|
| 229 |
+
),
|
| 230 |
+
InputParam(
|
| 231 |
+
"button_tensor",
|
| 232 |
+
type_hint=torch.Tensor,
|
| 233 |
+
kwargs_type="denoiser_input_fields",
|
| 234 |
+
description="One-hot encoded button tensor",
|
| 235 |
+
),
|
| 236 |
+
InputParam(
|
| 237 |
+
"mouse_tensor",
|
| 238 |
+
type_hint=torch.Tensor,
|
| 239 |
+
kwargs_type="denoiser_input_fields",
|
| 240 |
+
description="Mouse velocity tensor",
|
| 241 |
+
),
|
| 242 |
+
InputParam(
|
| 243 |
+
"scroll_tensor",
|
| 244 |
+
type_hint=torch.Tensor,
|
| 245 |
+
kwargs_type="denoiser_input_fields",
|
| 246 |
+
description="Scroll wheel sign tensor",
|
| 247 |
+
),
|
| 248 |
+
]
|
| 249 |
+
|
| 250 |
+
@property
|
| 251 |
+
def intermediate_outputs(self) -> List[OutputParam]:
|
| 252 |
+
return [
|
| 253 |
+
OutputParam(
|
| 254 |
+
"button_tensor",
|
| 255 |
+
type_hint=torch.Tensor,
|
| 256 |
+
kwargs_type="denoiser_input_fields",
|
| 257 |
+
description="One-hot encoded button tensor",
|
| 258 |
+
),
|
| 259 |
+
OutputParam(
|
| 260 |
+
"mouse_tensor",
|
| 261 |
+
type_hint=torch.Tensor,
|
| 262 |
+
kwargs_type="denoiser_input_fields",
|
| 263 |
+
description="Mouse velocity tensor",
|
| 264 |
+
),
|
| 265 |
+
OutputParam(
|
| 266 |
+
"scroll_tensor",
|
| 267 |
+
type_hint=torch.Tensor,
|
| 268 |
+
kwargs_type="denoiser_input_fields",
|
| 269 |
+
description="Scroll wheel sign tensor",
|
| 270 |
+
),
|
| 271 |
+
]
|
| 272 |
+
|
| 273 |
+
@torch.no_grad()
|
| 274 |
+
def __call__(
|
| 275 |
+
self, components: ModularPipeline, state: PipelineState
|
| 276 |
+
) -> PipelineState:
|
| 277 |
+
block_state = self.get_block_state(state)
|
| 278 |
+
device = components._execution_device
|
| 279 |
+
dtype = components.transformer.dtype
|
| 280 |
+
|
| 281 |
+
n_buttons = components.config.n_buttons
|
| 282 |
+
|
| 283 |
+
# Create or reuse button tensor [1, 1, n_buttons]
|
| 284 |
+
if block_state.button_tensor is None:
|
| 285 |
+
block_state.button_tensor = torch.zeros(
|
| 286 |
+
(1, 1, n_buttons), device=device, dtype=dtype
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
# Update button tensor in-place (avoid dynamic shapes for torch.compile)
|
| 290 |
+
block_state.button_tensor.zero_()
|
| 291 |
+
if block_state.button:
|
| 292 |
+
for btn_id in block_state.button:
|
| 293 |
+
if 0 <= btn_id < n_buttons:
|
| 294 |
+
block_state.button_tensor[0, 0, btn_id] = 1.0
|
| 295 |
+
|
| 296 |
+
# Create or reuse mouse tensor [1, 1, 2]
|
| 297 |
+
if block_state.mouse_tensor is None:
|
| 298 |
+
block_state.mouse_tensor = torch.zeros(
|
| 299 |
+
(1, 1, 2), device=device, dtype=dtype
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
# Update mouse tensor in-place
|
| 303 |
+
mouse = block_state.mouse if block_state.mouse is not None else (0.0, 0.0)
|
| 304 |
+
block_state.mouse_tensor[0, 0, 0] = mouse[0]
|
| 305 |
+
block_state.mouse_tensor[0, 0, 1] = mouse[1]
|
| 306 |
+
|
| 307 |
+
# Create or reuse scroll tensor [1, 1, 1]
|
| 308 |
+
if block_state.scroll_tensor is None:
|
| 309 |
+
block_state.scroll_tensor = torch.zeros(
|
| 310 |
+
(1, 1, 1), device=device, dtype=dtype
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# Update scroll tensor in-place (sign of scroll value: -1, 0, or 1)
|
| 314 |
+
scroll = block_state.scroll if block_state.scroll is not None else 0
|
| 315 |
+
block_state.scroll_tensor[0, 0, 0] = float(scroll > 0) - float(scroll < 0)
|
| 316 |
+
|
| 317 |
+
self.set_block_state(state, block_state)
|
| 318 |
+
return components, state
|
modular_blocks.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Hugging Face Team and Overworld
|
| 2 |
+
#
|
| 3 |
+
# This program is free software: you can redistribute it and/or modify
|
| 4 |
+
# it under the terms of the GNU General Public License as published by
|
| 5 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 6 |
+
# (at your option) any later version.
|
| 7 |
+
#
|
| 8 |
+
# This program is distributed in the hope that it will be useful,
|
| 9 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 10 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 11 |
+
# GNU General Public License for more details.
|
| 12 |
+
#
|
| 13 |
+
# You should have received a copy of the GNU General Public License
|
| 14 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 15 |
+
|
| 16 |
+
"""Block registry for WorldEngine modular pipeline."""
|
| 17 |
+
|
| 18 |
+
from diffusers.utils import logging
|
| 19 |
+
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
| 20 |
+
from diffusers.modular_pipelines.modular_pipeline_utils import InsertableDict
|
| 21 |
+
|
| 22 |
+
from .encoders import WorldEngineTextEncoderStep, WorldEngineControllerEncoderStep
|
| 23 |
+
from .before_denoise import WorldEngineBeforeDenoiseStep
|
| 24 |
+
from .denoise import WorldEngineDenoiseLoop
|
| 25 |
+
from .decoders import WorldEngineDecodeStep
|
| 26 |
+
|
| 27 |
+
logger = logging.get_logger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
AUTO_BLOCKS = InsertableDict(
|
| 31 |
+
[
|
| 32 |
+
("text_encoder", WorldEngineTextEncoderStep),
|
| 33 |
+
("controller_encoder", WorldEngineControllerEncoderStep),
|
| 34 |
+
("before_denoise", WorldEngineBeforeDenoiseStep),
|
| 35 |
+
("denoise", WorldEngineDenoiseLoop),
|
| 36 |
+
("decode", WorldEngineDecodeStep),
|
| 37 |
+
]
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class WorldEngineBlocks(SequentialPipelineBlocks):
|
| 42 |
+
"""Sequential pipeline blocks for WorldEngine frame generation."""
|
| 43 |
+
|
| 44 |
+
block_classes = list(AUTO_BLOCKS.copy().values())
|
| 45 |
+
block_names = list(AUTO_BLOCKS.copy().keys())
|
modular_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "WorldEngineBlocks",
|
| 3 |
+
"_diffusers_version": "0.36.0.dev0",
|
| 4 |
+
"auto_map": {
|
| 5 |
+
"ModularPipelineBlocks": "modular_blocks.WorldEngineBlocks"
|
| 6 |
+
}
|
| 7 |
+
}
|
modular_model_index.json
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_blocks_class_name": "WorldEngineBlocks",
|
| 3 |
+
"_class_name": "ModularPipeline",
|
| 4 |
+
"_diffusers_version": "0.36.0.dev0",
|
| 5 |
+
"channels": 16,
|
| 6 |
+
"height": 360,
|
| 7 |
+
"width": 640,
|
| 8 |
+
"patch": [
|
| 9 |
+
2,
|
| 10 |
+
2
|
| 11 |
+
],
|
| 12 |
+
"vae_scale_factor": 16,
|
| 13 |
+
"n_buttons": 256,
|
| 14 |
+
"tokens_per_frame": 256,
|
| 15 |
+
"scheduler_sigmas": [
|
| 16 |
+
1.0,
|
| 17 |
+
0.8609585762023926,
|
| 18 |
+
0.729332447052002,
|
| 19 |
+
0.3205108940601349,
|
| 20 |
+
0.0
|
| 21 |
+
],
|
| 22 |
+
"transformer": [
|
| 23 |
+
null,
|
| 24 |
+
null,
|
| 25 |
+
{
|
| 26 |
+
"pretrained_model_name_or_path": "Overworld/Waypoint-1-Small",
|
| 27 |
+
"subfolder": "transformer",
|
| 28 |
+
"type_hint": [
|
| 29 |
+
"diffusers",
|
| 30 |
+
"AutoModel"
|
| 31 |
+
],
|
| 32 |
+
"revision": null,
|
| 33 |
+
"variant": null
|
| 34 |
+
}
|
| 35 |
+
],
|
| 36 |
+
"vae": [
|
| 37 |
+
null,
|
| 38 |
+
null,
|
| 39 |
+
{
|
| 40 |
+
"pretrained_model_name_or_path": "Overworld/Waypoint-1-Small",
|
| 41 |
+
"subfolder": "vae",
|
| 42 |
+
"type_hint": [
|
| 43 |
+
"diffusers",
|
| 44 |
+
"AutoModel"
|
| 45 |
+
],
|
| 46 |
+
"revision": null,
|
| 47 |
+
"variant": null
|
| 48 |
+
}
|
| 49 |
+
],
|
| 50 |
+
"text_encoder": [
|
| 51 |
+
null,
|
| 52 |
+
null,
|
| 53 |
+
{
|
| 54 |
+
"pretrained_model_name_or_path": "google/umt5-xl",
|
| 55 |
+
"type_hint": [
|
| 56 |
+
"transformers",
|
| 57 |
+
"UMT5EncoderModel"
|
| 58 |
+
],
|
| 59 |
+
"revision": null,
|
| 60 |
+
"variant": null
|
| 61 |
+
}
|
| 62 |
+
],
|
| 63 |
+
"tokenizer": [
|
| 64 |
+
null,
|
| 65 |
+
null,
|
| 66 |
+
{
|
| 67 |
+
"pretrained_model_name_or_path": "google/umt5-xl",
|
| 68 |
+
"type_hint": [
|
| 69 |
+
"transformers",
|
| 70 |
+
"AutoTokenizer"
|
| 71 |
+
],
|
| 72 |
+
"revision": null,
|
| 73 |
+
"variant": null
|
| 74 |
+
}
|
| 75 |
+
]
|
| 76 |
+
}
|
transformer/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Hugging Face Team and Overworld
|
| 2 |
+
#
|
| 3 |
+
# This program is free software: you can redistribute it and/or modify
|
| 4 |
+
# it under the terms of the GNU General Public License as published by
|
| 5 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 6 |
+
# (at your option) any later version.
|
| 7 |
+
#
|
| 8 |
+
# This program is distributed in the hope that it will be useful,
|
| 9 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 10 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 11 |
+
# GNU General Public License for more details.
|
| 12 |
+
#
|
| 13 |
+
# You should have received a copy of the GNU General Public License
|
| 14 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 15 |
+
|
| 16 |
+
from .model import WorldModel
|
| 17 |
+
from .attn import Attn, CrossAttention, OrthoRoPE
|
| 18 |
+
from .nn import MLP, AdaLN, NoiseConditioner, rms_norm, ada_rmsnorm, ada_gate
|
| 19 |
+
|
| 20 |
+
__all__ = [
|
| 21 |
+
"WorldModel",
|
| 22 |
+
"Attn",
|
| 23 |
+
"CrossAttention",
|
| 24 |
+
"OrthoRoPE",
|
| 25 |
+
"MLP",
|
| 26 |
+
"AdaLN",
|
| 27 |
+
"NoiseConditioner",
|
| 28 |
+
"rms_norm",
|
| 29 |
+
"ada_rmsnorm",
|
| 30 |
+
"ada_gate",
|
| 31 |
+
]
|
transformer/attn.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Hugging Face Team and Overworld
|
| 2 |
+
#
|
| 3 |
+
# This program is free software: you can redistribute it and/or modify
|
| 4 |
+
# it under the terms of the GNU General Public License as published by
|
| 5 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 6 |
+
# (at your option) any later version.
|
| 7 |
+
#
|
| 8 |
+
# This program is distributed in the hope that it will be useful,
|
| 9 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 10 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 11 |
+
# GNU General Public License for more details.
|
| 12 |
+
#
|
| 13 |
+
# You should have received a copy of the GNU General Public License
|
| 14 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 15 |
+
|
| 16 |
+
"""Attention mechanisms for WorldModel transformer."""
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
|
| 20 |
+
import einops as eo
|
| 21 |
+
import torch
|
| 22 |
+
from torch import nn
|
| 23 |
+
from torch.nn.attention.flex_attention import flex_attention
|
| 24 |
+
|
| 25 |
+
from .nn import rms_norm, NoCastModule
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def pixel_frequencies(dim: int, max_freq: float) -> torch.Tensor:
|
| 29 |
+
"""Linear frequency spectrum for spatial RoPE (pixel positions).
|
| 30 |
+
|
| 31 |
+
Matches rotary_embedding_torch RotaryEmbedding(freqs_for='pixel').
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
dim: Output dimension (freqs will be repeated to fill this)
|
| 35 |
+
max_freq: Maximum frequency (should be below Nyquist)
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
Tensor of shape [dim // 2] with linear frequencies
|
| 39 |
+
"""
|
| 40 |
+
# Library uses max_freq/2 as the upper bound
|
| 41 |
+
return torch.linspace(1.0, max_freq / 2, dim // 2) * math.pi
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def lang_frequencies(dim: int) -> torch.Tensor:
|
| 45 |
+
"""Geometric frequency spectrum for temporal RoPE (language-style).
|
| 46 |
+
|
| 47 |
+
Matches rotary_embedding_torch RotaryEmbedding(freqs_for='lang').
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
dim: Output dimension (freqs will be repeated to fill this)
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
Tensor of shape [dim // 2] with geometric frequencies
|
| 54 |
+
"""
|
| 55 |
+
# Library uses 10^(-i/2) pattern
|
| 56 |
+
return 10.0 ** (-torch.arange(dim // 2).float() / 2)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class OrthoRoPE(NoCastModule):
|
| 60 |
+
"""Rotary Position Embeddings for orthogonal axes: time, height, and width.
|
| 61 |
+
|
| 62 |
+
- Time: Geometric spectrum (like language models) -- rotates 1/2 of head dim
|
| 63 |
+
- Height/Width: Linear spectrum (for pixels) -- rotates 1/4 of head dim each
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__(self, config):
|
| 67 |
+
super().__init__()
|
| 68 |
+
self.config = config
|
| 69 |
+
assert not getattr(self.config, "has_audio", False)
|
| 70 |
+
|
| 71 |
+
# Compute frequencies and store cos/sin buffers
|
| 72 |
+
freqs = self._compute_freqs()
|
| 73 |
+
self.cos = nn.Buffer(freqs.cos().contiguous(), persistent=False)
|
| 74 |
+
self.sin = nn.Buffer(freqs.sin().contiguous(), persistent=False)
|
| 75 |
+
|
| 76 |
+
def _compute_freqs(self):
|
| 77 |
+
"""Compute frequency table for all positions.
|
| 78 |
+
|
| 79 |
+
Matches the behavior of rotary_embedding_torch.RotaryEmbedding.
|
| 80 |
+
The library interleaves frequencies so each freq value is used twice.
|
| 81 |
+
"""
|
| 82 |
+
config = self.config
|
| 83 |
+
H, W, T = config.height, config.width, config.n_frames
|
| 84 |
+
head_dim = config.d_model // config.n_heads
|
| 85 |
+
|
| 86 |
+
# Spatial frequencies (linear spectrum, below Nyquist)
|
| 87 |
+
# Library: RotaryEmbedding(dim=head_dim//8) creates head_dim//16 freqs,
|
| 88 |
+
# outputs head_dim//8 values (each freq repeated twice)
|
| 89 |
+
max_freq = min(H, W) * 0.8
|
| 90 |
+
spatial_freqs = pixel_frequencies(head_dim // 8, max_freq) # [D/16]
|
| 91 |
+
|
| 92 |
+
# Positions in [-1, 1] range
|
| 93 |
+
pos_x = torch.linspace(-1 + 1 / W, 1 - 1 / W, W) # [W]
|
| 94 |
+
pos_y = torch.linspace(-1 + 1 / H, 1 - 1 / H, H) # [H]
|
| 95 |
+
|
| 96 |
+
# Spatial frequency embeddings with interleaving (like library)
|
| 97 |
+
freqs_x = torch.outer(pos_x, spatial_freqs) # [W, D/16]
|
| 98 |
+
freqs_y = torch.outer(pos_y, spatial_freqs) # [H, D/16]
|
| 99 |
+
freqs_x = freqs_x.repeat_interleave(2, dim=-1) # [W, D/8]
|
| 100 |
+
freqs_y = freqs_y.repeat_interleave(2, dim=-1) # [H, D/8]
|
| 101 |
+
|
| 102 |
+
# Expand to grid and repeat for all frames
|
| 103 |
+
freqs_x = freqs_x[None, :, :].expand(H, W, -1) # [H, W, D/8]
|
| 104 |
+
freqs_y = freqs_y[:, None, :].expand(H, W, -1) # [H, W, D/8]
|
| 105 |
+
|
| 106 |
+
freqs_x = eo.repeat(freqs_x, "h w d -> (t h w) d", t=T) # [T*H*W, D/8]
|
| 107 |
+
freqs_y = eo.repeat(freqs_y, "h w d -> (t h w) d", t=T) # [T*H*W, D/8]
|
| 108 |
+
|
| 109 |
+
# Temporal frequencies (geometric spectrum)
|
| 110 |
+
# Library: RotaryEmbedding(dim=head_dim//4) creates head_dim//8 freqs,
|
| 111 |
+
# outputs head_dim//4 values (each freq repeated twice)
|
| 112 |
+
temporal_freqs = lang_frequencies(head_dim // 4) # [D/8]
|
| 113 |
+
pos_t = torch.arange(T).float() # [T]
|
| 114 |
+
freqs_t = torch.outer(pos_t, temporal_freqs) # [T, D/8]
|
| 115 |
+
freqs_t = freqs_t.repeat_interleave(2, dim=-1) # [T, D/4]
|
| 116 |
+
freqs_t = eo.repeat(freqs_t, "t d -> (t h w) d", h=H, w=W) # [T*H*W, D/4]
|
| 117 |
+
|
| 118 |
+
# Concatenate: [X, Y, T] -> [T*H*W, D/2]
|
| 119 |
+
return torch.cat([freqs_x, freqs_y, freqs_t], dim=-1)
|
| 120 |
+
|
| 121 |
+
def get_angles(self, pos_ids):
|
| 122 |
+
"""Look up cos/sin angles for given position IDs."""
|
| 123 |
+
t, y, x = pos_ids["t_pos"], pos_ids["y_pos"], pos_ids["x_pos"] # [B,T]
|
| 124 |
+
H, W = self.config.height, self.config.width
|
| 125 |
+
if not torch.compiler.is_compiling():
|
| 126 |
+
torch._assert(
|
| 127 |
+
(y.max() < H) & (x.max() < W),
|
| 128 |
+
f"pos_ids out of bounds, {y.max()}, {x.max()}",
|
| 129 |
+
)
|
| 130 |
+
flat = t * (H * W) + y * W + x # [B,T]
|
| 131 |
+
idx = flat.reshape(-1).to(torch.long)
|
| 132 |
+
cos = self.cos.index_select(0, idx).view(*flat.shape, -1)
|
| 133 |
+
sin = self.sin.index_select(0, idx).view(*flat.shape, -1)
|
| 134 |
+
return cos[:, None], sin[:, None] # add head dim for broadcast
|
| 135 |
+
|
| 136 |
+
@torch.autocast("cuda", enabled=False)
|
| 137 |
+
def forward(self, x, pos_ids):
|
| 138 |
+
assert self.cos.dtype == self.sin.dtype == torch.float32
|
| 139 |
+
cos, sin = self.get_angles(pos_ids)
|
| 140 |
+
x0, x1 = x.float().unfold(-1, 2, 2).unbind(-1)
|
| 141 |
+
y0 = x0 * cos - x1 * sin
|
| 142 |
+
y1 = x1 * cos + x0 * sin
|
| 143 |
+
return torch.cat((y0, y1), dim=-1).type_as(x)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class Attn(nn.Module):
|
| 147 |
+
"""Self-attention with RoPE and optional GQA, value residual, and gated attention."""
|
| 148 |
+
|
| 149 |
+
def __init__(self, config, layer_idx):
|
| 150 |
+
super().__init__()
|
| 151 |
+
self.config = config
|
| 152 |
+
self.layer_idx = layer_idx
|
| 153 |
+
|
| 154 |
+
self.value_residual = getattr(config, "value_residual", False)
|
| 155 |
+
if self.value_residual:
|
| 156 |
+
self.v_lamb = nn.Parameter(torch.tensor(0.5))
|
| 157 |
+
|
| 158 |
+
self.n_heads = config.n_heads
|
| 159 |
+
self.n_kv_heads = getattr(config, "n_kv_heads", config.n_heads)
|
| 160 |
+
self.d_head = config.d_model // self.n_heads
|
| 161 |
+
assert config.d_model % self.n_heads == 0
|
| 162 |
+
|
| 163 |
+
self.enable_gqa = self.n_heads != self.n_kv_heads
|
| 164 |
+
|
| 165 |
+
self.q_proj = nn.Linear(config.d_model, self.n_heads * self.d_head, bias=False)
|
| 166 |
+
self.k_proj = nn.Linear(
|
| 167 |
+
config.d_model, self.n_kv_heads * self.d_head, bias=False
|
| 168 |
+
)
|
| 169 |
+
self.v_proj = nn.Linear(
|
| 170 |
+
config.d_model, self.n_kv_heads * self.d_head, bias=False
|
| 171 |
+
)
|
| 172 |
+
self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
| 173 |
+
|
| 174 |
+
self.rope = OrthoRoPE(config)
|
| 175 |
+
|
| 176 |
+
self.gated_attn = getattr(config, "gated_attn", False)
|
| 177 |
+
if self.gated_attn:
|
| 178 |
+
self.gate_proj = nn.Linear(
|
| 179 |
+
self.n_heads, self.n_heads, bias=False
|
| 180 |
+
) # sparse attn gate
|
| 181 |
+
nn.init.zeros_(self.gate_proj.weight)
|
| 182 |
+
|
| 183 |
+
def forward(self, x, pos_ids, v1, kv_cache):
|
| 184 |
+
# Q, K, V proj -> QK-norm -> RoPE
|
| 185 |
+
q = eo.rearrange(
|
| 186 |
+
self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads, d=self.d_head
|
| 187 |
+
)
|
| 188 |
+
k = eo.rearrange(
|
| 189 |
+
self.k_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head
|
| 190 |
+
)
|
| 191 |
+
v = eo.rearrange(
|
| 192 |
+
self.v_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
if self.value_residual:
|
| 196 |
+
v1 = v if v1 is None else v1
|
| 197 |
+
v = torch.lerp(v, v1.view_as(v), self.v_lamb)
|
| 198 |
+
|
| 199 |
+
q, k = rms_norm(q), rms_norm(k)
|
| 200 |
+
q, k = self.rope(q, pos_ids), self.rope(k, pos_ids)
|
| 201 |
+
|
| 202 |
+
k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
|
| 203 |
+
y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)
|
| 204 |
+
|
| 205 |
+
if self.gated_attn:
|
| 206 |
+
gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
|
| 207 |
+
y = y * gates.permute(0, 2, 1).unsqueeze(-1)
|
| 208 |
+
y = eo.rearrange(y, "b h t d -> b t (h d)")
|
| 209 |
+
y = self.out_proj(y)
|
| 210 |
+
return y, v1
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class MergedQKVAttn(Attn):
|
| 214 |
+
def __init__(self, src: Attn, config):
|
| 215 |
+
super().__init__(config, src.layer_idx) # makes fresh q/k/v/out/etc
|
| 216 |
+
self.to(device=src.q_proj.weight.device, dtype=src.q_proj.weight.dtype)
|
| 217 |
+
self.load_state_dict(
|
| 218 |
+
src.state_dict(), strict=False
|
| 219 |
+
) # copies trained weights/buffers
|
| 220 |
+
self.train(src.training) # preserve train/eval mode
|
| 221 |
+
|
| 222 |
+
self.q_out = self.n_heads * self.d_head
|
| 223 |
+
self.kv_out = self.n_kv_heads * self.d_head
|
| 224 |
+
|
| 225 |
+
self.qkv_proj = nn.Linear(
|
| 226 |
+
self.q_proj.in_features,
|
| 227 |
+
self.q_out + 2 * self.kv_out,
|
| 228 |
+
bias=False,
|
| 229 |
+
device=self.q_proj.weight.device,
|
| 230 |
+
dtype=self.q_proj.weight.dtype,
|
| 231 |
+
)
|
| 232 |
+
with torch.no_grad():
|
| 233 |
+
self.qkv_proj.weight.copy_(
|
| 234 |
+
torch.cat(
|
| 235 |
+
[self.q_proj.weight, self.k_proj.weight, self.v_proj.weight], dim=0
|
| 236 |
+
)
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
del self.q_proj, self.k_proj, self.v_proj
|
| 240 |
+
|
| 241 |
+
def forward(self, x, pos_ids, v1, kv_cache):
|
| 242 |
+
q, k, v = self.qkv_proj(x).split((self.q_out, self.kv_out, self.kv_out), dim=-1)
|
| 243 |
+
|
| 244 |
+
B, T = x.shape[:2]
|
| 245 |
+
q = q.reshape(B, T, self.n_heads, self.d_head).transpose(1, 2)
|
| 246 |
+
k = k.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
|
| 247 |
+
v = v.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
|
| 248 |
+
|
| 249 |
+
if self.value_residual:
|
| 250 |
+
v1 = v if v1 is None else v1
|
| 251 |
+
v = torch.lerp(v, v1.view_as(v), self.v_lamb)
|
| 252 |
+
|
| 253 |
+
q, k = rms_norm(q), rms_norm(k)
|
| 254 |
+
q, k = self.rope(q, pos_ids), self.rope(k, pos_ids)
|
| 255 |
+
|
| 256 |
+
k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
|
| 257 |
+
y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)
|
| 258 |
+
|
| 259 |
+
if self.gated_attn:
|
| 260 |
+
gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
|
| 261 |
+
y = y * gates.permute(0, 2, 1).unsqueeze(-1)
|
| 262 |
+
|
| 263 |
+
y = y.transpose(1, 2).reshape(B, T, -1)
|
| 264 |
+
y = self.out_proj(y)
|
| 265 |
+
return y, v1
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class CrossAttention(nn.Module):
|
| 269 |
+
"""Cross-attention for prompt conditioning."""
|
| 270 |
+
|
| 271 |
+
def __init__(self, config, context_dim=None):
|
| 272 |
+
super().__init__()
|
| 273 |
+
assert config.d_model % config.n_heads == 0
|
| 274 |
+
|
| 275 |
+
self.d_head = config.d_model // config.n_heads
|
| 276 |
+
self.inner_dim = context_dim or config.d_model
|
| 277 |
+
assert self.inner_dim % self.d_head == 0
|
| 278 |
+
self.n_heads = self.inner_dim // self.d_head
|
| 279 |
+
self.q_proj = nn.Linear(config.d_model, self.inner_dim, bias=False)
|
| 280 |
+
self.k_proj = nn.Linear(
|
| 281 |
+
context_dim or config.d_model, self.inner_dim, bias=False
|
| 282 |
+
)
|
| 283 |
+
self.v_proj = nn.Linear(
|
| 284 |
+
context_dim or config.d_model, self.inner_dim, bias=False
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
self.out_proj = nn.Linear(self.inner_dim, config.d_model, bias=False)
|
| 288 |
+
self.out_proj.weight.detach().zero_()
|
| 289 |
+
|
| 290 |
+
def forward(self, x, context, context_pad_mask=None):
|
| 291 |
+
q = eo.rearrange(self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads)
|
| 292 |
+
k = eo.rearrange(self.k_proj(context), "b t (h d) -> b h t d", h=self.n_heads)
|
| 293 |
+
v = eo.rearrange(self.v_proj(context), "b t (h d) -> b h t d", h=self.n_heads)
|
| 294 |
+
q, k = rms_norm(q), rms_norm(k)
|
| 295 |
+
out = flex_attention(q, k, v)
|
| 296 |
+
out = out.transpose(1, 2).contiguous().reshape(x.size(0), x.size(1), -1)
|
| 297 |
+
return self.out_proj(out)
|
transformer/cache.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Hugging Face Team and Overworld
|
| 2 |
+
#
|
| 3 |
+
# This program is free software: you can redistribute it and/or modify
|
| 4 |
+
# it under the terms of the GNU General Public License as published by
|
| 5 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 6 |
+
# (at your option) any later version.
|
| 7 |
+
#
|
| 8 |
+
# This program is distributed in the hope that it will be useful,
|
| 9 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 10 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 11 |
+
# GNU General Public License for more details.
|
| 12 |
+
#
|
| 13 |
+
# You should have received a copy of the GNU General Public License
|
| 14 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from torch import nn, Tensor
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _bf16_u16(x: Tensor) -> Tensor:
|
| 21 |
+
# reinterpret bf16 storage as int16 -> unsigned 0..65535 in int32
|
| 22 |
+
return x.contiguous().view(torch.int16).to(torch.int32) & 0xFFFF
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class CachedDenoiseStepEmb(nn.Module):
|
| 26 |
+
"""bf16 sigma -> bf16 embedding via 64k LUT; invalid sigma => OOB index error (no silent wrong)."""
|
| 27 |
+
|
| 28 |
+
def __init__(self, base: nn.Module, sigmas: list[float]):
|
| 29 |
+
super().__init__()
|
| 30 |
+
device = next(base.parameters()).device
|
| 31 |
+
|
| 32 |
+
levels = torch.tensor(sigmas, device=device, dtype=torch.bfloat16) # [S]
|
| 33 |
+
bits = _bf16_u16(levels) # [S]
|
| 34 |
+
if torch.unique(bits).numel() != bits.numel():
|
| 35 |
+
raise ValueError(
|
| 36 |
+
"scheduler_sigmas collide in bf16; caching would be ambiguous"
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
with torch.no_grad():
|
| 40 |
+
table = (
|
| 41 |
+
base(levels[:, None]).squeeze(1).to(torch.bfloat16).contiguous()
|
| 42 |
+
) # [S,D]
|
| 43 |
+
|
| 44 |
+
lut = torch.full((65536,), -1, device=device, dtype=torch.int32)
|
| 45 |
+
lut[bits] = torch.arange(bits.numel(), device=device, dtype=torch.int32)
|
| 46 |
+
|
| 47 |
+
self.register_buffer("table", table, persistent=False) # [S,D] bf16
|
| 48 |
+
self.register_buffer("lut", lut, persistent=False) # [65536] int32
|
| 49 |
+
self.register_buffer(
|
| 50 |
+
"oob",
|
| 51 |
+
torch.tensor(bits.numel(), device=device, dtype=torch.int32),
|
| 52 |
+
persistent=False,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
def forward(self, sigma: Tensor) -> Tensor:
|
| 56 |
+
if sigma.dtype is not torch.bfloat16:
|
| 57 |
+
raise RuntimeError("CachedDenoiseStepEmb expects sigma bf16")
|
| 58 |
+
idx = self.lut[_bf16_u16(sigma)]
|
| 59 |
+
idx = torch.where(idx >= 0, idx, self.oob) # invalid -> S (OOB)
|
| 60 |
+
return self.table[idx.to(torch.int64)] # [...,D] bf16
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class CachedCondHead(nn.Module):
|
| 64 |
+
"""bf16 cond -> cached (s0,b0,g0,s1,b1,g1); invalid cond => OOB index error (no silent wrong)."""
|
| 65 |
+
|
| 66 |
+
def __init__(
|
| 67 |
+
self, base, cached_denoise_step_emb: CachedDenoiseStepEmb, max_key_dims: int = 8
|
| 68 |
+
):
|
| 69 |
+
super().__init__()
|
| 70 |
+
table = cached_denoise_step_emb.table # [S,D] bf16
|
| 71 |
+
S, D = table.shape
|
| 72 |
+
|
| 73 |
+
with torch.no_grad():
|
| 74 |
+
emb = table[:, None, :] # [S,1,D]
|
| 75 |
+
cache = (
|
| 76 |
+
torch.stack([t.squeeze(1) for t in base(emb)], 0)
|
| 77 |
+
.to(torch.bfloat16)
|
| 78 |
+
.contiguous()
|
| 79 |
+
) # [6,S,D]
|
| 80 |
+
|
| 81 |
+
# pick a single embedding dimension whose bf16 bits uniquely identify sigma
|
| 82 |
+
key_dim = None
|
| 83 |
+
for d in range(min(D, max_key_dims)):
|
| 84 |
+
b = _bf16_u16(table[:, d])
|
| 85 |
+
if torch.unique(b).numel() == S:
|
| 86 |
+
key_dim = d
|
| 87 |
+
key_bits = b
|
| 88 |
+
break
|
| 89 |
+
if key_dim is None:
|
| 90 |
+
raise ValueError(
|
| 91 |
+
"Could not find a unique bf16 key dim for cond->sigma mapping; increase max_key_dims"
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
lut = torch.full((65536,), -1, device=table.device, dtype=torch.int32)
|
| 95 |
+
lut[key_bits] = torch.arange(S, device=table.device, dtype=torch.int32)
|
| 96 |
+
|
| 97 |
+
self.key_dim = int(key_dim)
|
| 98 |
+
self.register_buffer("cache", cache, persistent=False) # [6,S,D] bf16
|
| 99 |
+
self.register_buffer("lut", lut, persistent=False) # [65536] int32
|
| 100 |
+
self.register_buffer(
|
| 101 |
+
"oob",
|
| 102 |
+
torch.tensor(S, device=table.device, dtype=torch.int32),
|
| 103 |
+
persistent=False,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
def forward(self, cond: Tensor):
|
| 107 |
+
if cond.dtype is not torch.bfloat16:
|
| 108 |
+
raise RuntimeError("CachedCondHead expects cond bf16")
|
| 109 |
+
idx = self.lut[_bf16_u16(cond[..., self.key_dim])]
|
| 110 |
+
idx = torch.where(idx >= 0, idx, self.oob) # invalid -> S (OOB)
|
| 111 |
+
g = self.cache[:, idx.to(torch.int64)] # [6,...,D] bf16 (or errors)
|
| 112 |
+
return tuple(g.unbind(0)) # (s0,b0,g0,s1,b1,g1)
|
transformer/config.json
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "WorldModel",
|
| 3 |
+
"_diffusers_version": "0.36.0.dev0",
|
| 4 |
+
"auto_map": {
|
| 5 |
+
"AutoModel": "model.WorldModel"
|
| 6 |
+
},
|
| 7 |
+
"d_model": 2560,
|
| 8 |
+
"n_heads": 40,
|
| 9 |
+
"n_kv_heads": 20,
|
| 10 |
+
"n_layers": 22,
|
| 11 |
+
"mlp_ratio": 5,
|
| 12 |
+
"channels": 16,
|
| 13 |
+
"height": 16,
|
| 14 |
+
"width": 16,
|
| 15 |
+
"patch": [
|
| 16 |
+
2,
|
| 17 |
+
2
|
| 18 |
+
],
|
| 19 |
+
"tokens_per_frame": 256,
|
| 20 |
+
"n_frames": 4096,
|
| 21 |
+
"local_window": 16,
|
| 22 |
+
"global_window": 128,
|
| 23 |
+
"global_attn_period": 4,
|
| 24 |
+
"global_pinned_dilation": 8,
|
| 25 |
+
"global_attn_offset": 0,
|
| 26 |
+
"value_residual": false,
|
| 27 |
+
"gated_attn": true,
|
| 28 |
+
"n_buttons": 256,
|
| 29 |
+
"ctrl_conditioning": "mlp_fusion",
|
| 30 |
+
"ctrl_conditioning_period": 3,
|
| 31 |
+
"ctrl_cond_dropout": 0.0,
|
| 32 |
+
"prompt_conditioning": "cross_attention",
|
| 33 |
+
"prompt_conditioning_period": 3,
|
| 34 |
+
"prompt_embedding_dim": 2048,
|
| 35 |
+
"prompt_cond_dropout": 0.0,
|
| 36 |
+
"noise_conditioning": "wan",
|
| 37 |
+
"base_fps": 60,
|
| 38 |
+
"causal": true,
|
| 39 |
+
"mlp_gradient_checkpointing": true,
|
| 40 |
+
"block_gradient_checkpointing": true,
|
| 41 |
+
"rope_impl": "ortho",
|
| 42 |
+
"scheduler_sigmas": [
|
| 43 |
+
1.0,
|
| 44 |
+
0.8609585762023926,
|
| 45 |
+
0.729332447052002,
|
| 46 |
+
0.3205108940601349,
|
| 47 |
+
0.0
|
| 48 |
+
]
|
| 49 |
+
}
|
transformer/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:14356db9229453850f9ad650f31c3e1c4744066abd43562f6fbee161fb36c9e6
|
| 3 |
+
size 12515075376
|
transformer/model.py
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Hugging Face Team and Overworld
|
| 2 |
+
#
|
| 3 |
+
# This program is free software: you can redistribute it and/or modify
|
| 4 |
+
# it under the terms of the GNU General Public License as published by
|
| 5 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 6 |
+
# (at your option) any later version.
|
| 7 |
+
#
|
| 8 |
+
# This program is distributed in the hope that it will be useful,
|
| 9 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 10 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 11 |
+
# GNU General Public License for more details.
|
| 12 |
+
#
|
| 13 |
+
# You should have received a copy of the GNU General Public License
|
| 14 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 15 |
+
|
| 16 |
+
"""WorldModel transformer for frame generation."""
|
| 17 |
+
|
| 18 |
+
from typing import Optional, List
|
| 19 |
+
import math
|
| 20 |
+
|
| 21 |
+
import einops as eo
|
| 22 |
+
import torch
|
| 23 |
+
from torch import nn, Tensor
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
from tensordict import TensorDict
|
| 26 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 27 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 28 |
+
|
| 29 |
+
from .attn import Attn, MergedQKVAttn, CrossAttention
|
| 30 |
+
from .nn import AdaLN, MLP, NoiseConditioner, ada_gate, ada_rmsnorm, rms_norm
|
| 31 |
+
from .quantize import quantize_model
|
| 32 |
+
from .cache import CachedDenoiseStepEmb, CachedCondHead
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def patch_cached_noise_conditioning(model) -> None:
|
| 36 |
+
# Call AFTER: model.to(device="cuda", dtype=torch.bfloat16).eval()
|
| 37 |
+
cached_denoise_step_emb = CachedDenoiseStepEmb(
|
| 38 |
+
model.denoise_step_emb, model.config.scheduler_sigmas
|
| 39 |
+
)
|
| 40 |
+
model.denoise_step_emb = cached_denoise_step_emb
|
| 41 |
+
for blk in model.transformer.blocks:
|
| 42 |
+
blk.cond_head = CachedCondHead(blk.cond_head, cached_denoise_step_emb)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def patch_Attn_merge_qkv(model) -> None:
|
| 46 |
+
for name, mod in list(model.named_modules()):
|
| 47 |
+
if isinstance(mod, Attn) and not isinstance(mod, MergedQKVAttn):
|
| 48 |
+
model.set_submodule(name, MergedQKVAttn(mod, model.config))
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def patch_MLPFusion_split(model) -> None:
|
| 52 |
+
for name, mod in list(model.named_modules()):
|
| 53 |
+
if isinstance(mod, MLPFusion) and not isinstance(mod, SplitMLPFusion):
|
| 54 |
+
model.set_submodule(name, SplitMLPFusion(mod))
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _apply_inference_patches(model) -> None:
|
| 58 |
+
patch_cached_noise_conditioning(model)
|
| 59 |
+
patch_Attn_merge_qkv(model)
|
| 60 |
+
patch_MLPFusion_split(model)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class CFG(nn.Module):
|
| 64 |
+
def __init__(self, d_model: int, dropout: float):
|
| 65 |
+
super().__init__()
|
| 66 |
+
self.dropout = dropout
|
| 67 |
+
self.null_emb = nn.Parameter(torch.zeros(1, 1, d_model))
|
| 68 |
+
|
| 69 |
+
def forward(
|
| 70 |
+
self, x: torch.Tensor, is_conditioned: Optional[bool] = None
|
| 71 |
+
) -> torch.Tensor:
|
| 72 |
+
"""
|
| 73 |
+
x: [B, L, D]
|
| 74 |
+
is_conditioned:
|
| 75 |
+
- None: training-style random dropout
|
| 76 |
+
- bool: whole batch conditioned / unconditioned at sampling
|
| 77 |
+
"""
|
| 78 |
+
B, L, _ = x.shape
|
| 79 |
+
null = self.null_emb.expand(B, L, -1)
|
| 80 |
+
|
| 81 |
+
# training-style dropout OR unspecified
|
| 82 |
+
if self.training or is_conditioned is None:
|
| 83 |
+
if self.dropout == 0.0:
|
| 84 |
+
return x
|
| 85 |
+
drop = torch.rand(B, 1, 1, device=x.device) < self.dropout # [B,1,1]
|
| 86 |
+
return torch.where(drop, null, x)
|
| 87 |
+
|
| 88 |
+
# sampling-time switch
|
| 89 |
+
return x if is_conditioned else null
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class ControllerInputEmbedding(nn.Module):
|
| 93 |
+
"""Embeds controller inputs (mouse + buttons) into model dimension."""
|
| 94 |
+
|
| 95 |
+
def __init__(self, n_buttons: int, d_model: int, mlp_ratio: int = 4):
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.mlp = MLP(n_buttons + 3, d_model * mlp_ratio, d_model) # mouse velocity (x,y) + scroll sign
|
| 98 |
+
|
| 99 |
+
def forward(self, mouse: Tensor, button: Tensor, scroll: Tensor):
|
| 100 |
+
assert len(mouse.shape) == 3
|
| 101 |
+
x = torch.cat((mouse, button, scroll), dim=-1)
|
| 102 |
+
return self.mlp(x)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class MLPFusion(nn.Module):
|
| 106 |
+
"""Fuses per-group conditioning into tokens by applying an MLP to cat([x, cond])."""
|
| 107 |
+
|
| 108 |
+
def __init__(self, d_model: int):
|
| 109 |
+
super().__init__()
|
| 110 |
+
self.mlp = MLP(2 * d_model, d_model, d_model)
|
| 111 |
+
|
| 112 |
+
def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
| 113 |
+
B, _, D = x.shape
|
| 114 |
+
L = cond.shape[1]
|
| 115 |
+
|
| 116 |
+
Wx, Wc = self.mlp.fc1.weight.chunk(2, dim=1) # each [D, D]
|
| 117 |
+
|
| 118 |
+
x = x.view(B, L, -1, D)
|
| 119 |
+
h = F.linear(x, Wx) + F.linear(cond, Wc).unsqueeze(
|
| 120 |
+
2
|
| 121 |
+
) # broadcast, no repeat/cat
|
| 122 |
+
h = F.silu(h)
|
| 123 |
+
y = F.linear(h, self.mlp.fc2.weight)
|
| 124 |
+
return y.flatten(1, 2)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class SplitMLPFusion(nn.Module):
|
| 128 |
+
"""Packed MLPFusion -> split linears (no cat, quant-friendly)."""
|
| 129 |
+
|
| 130 |
+
def __init__(self, src: MLPFusion):
|
| 131 |
+
super().__init__()
|
| 132 |
+
D = src.mlp.fc2.in_features
|
| 133 |
+
dev, dt = src.mlp.fc2.weight.device, src.mlp.fc2.weight.dtype
|
| 134 |
+
|
| 135 |
+
self.fc1_x = nn.Linear(D, D, bias=False, device=dev, dtype=dt)
|
| 136 |
+
self.fc1_c = nn.Linear(D, D, bias=False, device=dev, dtype=dt)
|
| 137 |
+
self.fc2 = nn.Linear(D, D, bias=False, device=dev, dtype=dt)
|
| 138 |
+
|
| 139 |
+
with torch.no_grad():
|
| 140 |
+
Wx, Wc = src.mlp.fc1.weight.chunk(2, dim=1)
|
| 141 |
+
self.fc1_x.weight.copy_(Wx)
|
| 142 |
+
self.fc1_c.weight.copy_(Wc)
|
| 143 |
+
self.fc2.weight.copy_(src.mlp.fc2.weight)
|
| 144 |
+
|
| 145 |
+
self.train(src.training)
|
| 146 |
+
|
| 147 |
+
def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
| 148 |
+
B, _, D = x.shape
|
| 149 |
+
L = cond.shape[1]
|
| 150 |
+
x = x.reshape(B, L, -1, D)
|
| 151 |
+
return self.fc2(F.silu(self.fc1_x(x) + self.fc1_c(cond).unsqueeze(2))).flatten(
|
| 152 |
+
1, 2
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class CondHead(nn.Module):
|
| 157 |
+
"""Per-layer conditioning head: bias_in -> SiLU -> Linear -> chunk(n_cond)."""
|
| 158 |
+
|
| 159 |
+
n_cond = 6
|
| 160 |
+
|
| 161 |
+
def __init__(self, d_model: int, noise_conditioning: str = "wan"):
|
| 162 |
+
super().__init__()
|
| 163 |
+
self.bias_in = (
|
| 164 |
+
nn.Parameter(torch.zeros(d_model)) if noise_conditioning == "wan" else None
|
| 165 |
+
)
|
| 166 |
+
self.cond_proj = nn.ModuleList(
|
| 167 |
+
[nn.Linear(d_model, d_model, bias=False) for _ in range(self.n_cond)]
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
def forward(self, cond):
|
| 171 |
+
cond = cond + self.bias_in if self.bias_in is not None else cond
|
| 172 |
+
h = F.silu(cond)
|
| 173 |
+
return tuple(p(h) for p in self.cond_proj)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class WorldDiTBlock(nn.Module):
|
| 177 |
+
"""Single transformer block with self-attention, optional cross-attention, and MLP."""
|
| 178 |
+
|
| 179 |
+
def __init__(
|
| 180 |
+
self,
|
| 181 |
+
d_model: int,
|
| 182 |
+
n_heads: int,
|
| 183 |
+
mlp_ratio: int,
|
| 184 |
+
layer_idx: int,
|
| 185 |
+
prompt_conditioning: Optional[str],
|
| 186 |
+
prompt_conditioning_period: int,
|
| 187 |
+
prompt_embedding_dim: int,
|
| 188 |
+
ctrl_conditioning_period: int,
|
| 189 |
+
noise_conditioning: str,
|
| 190 |
+
config,
|
| 191 |
+
):
|
| 192 |
+
super().__init__()
|
| 193 |
+
self.config = config
|
| 194 |
+
self.attn = Attn(config, layer_idx)
|
| 195 |
+
self.mlp = MLP(d_model, d_model * mlp_ratio, d_model)
|
| 196 |
+
self.cond_head = CondHead(d_model, noise_conditioning)
|
| 197 |
+
|
| 198 |
+
do_prompt_cond = (
|
| 199 |
+
prompt_conditioning is not None
|
| 200 |
+
and layer_idx % prompt_conditioning_period == 0
|
| 201 |
+
)
|
| 202 |
+
self.prompt_cross_attn = (
|
| 203 |
+
CrossAttention(config, prompt_embedding_dim) if do_prompt_cond else None
|
| 204 |
+
)
|
| 205 |
+
do_ctrl_cond = layer_idx % ctrl_conditioning_period == 0
|
| 206 |
+
self.ctrl_mlpfusion = MLPFusion(d_model) if do_ctrl_cond else None
|
| 207 |
+
|
| 208 |
+
def forward(self, x, pos_ids, cond, ctx, v, kv_cache=None):
|
| 209 |
+
"""
|
| 210 |
+
0) Causal Frame Attention
|
| 211 |
+
1) Frame->CTX Cross Attention
|
| 212 |
+
2) MLP
|
| 213 |
+
"""
|
| 214 |
+
s0, b0, g0, s1, b1, g1 = self.cond_head(cond)
|
| 215 |
+
|
| 216 |
+
# Self / Causal Attention
|
| 217 |
+
residual = x
|
| 218 |
+
x = ada_rmsnorm(x, s0, b0)
|
| 219 |
+
x, v = self.attn(x, pos_ids, v, kv_cache=kv_cache)
|
| 220 |
+
x = ada_gate(x, g0) + residual
|
| 221 |
+
|
| 222 |
+
# Cross Attention Prompt Conditioning
|
| 223 |
+
if self.prompt_cross_attn is not None:
|
| 224 |
+
x = (
|
| 225 |
+
self.prompt_cross_attn(
|
| 226 |
+
rms_norm(x),
|
| 227 |
+
context=rms_norm(ctx["prompt_emb"]),
|
| 228 |
+
context_pad_mask=ctx["prompt_pad_mask"],
|
| 229 |
+
)
|
| 230 |
+
+ x
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
# MLPFusion Controller Conditioning
|
| 234 |
+
if self.ctrl_mlpfusion is not None:
|
| 235 |
+
x = self.ctrl_mlpfusion(rms_norm(x), rms_norm(ctx["ctrl_emb"])) + x
|
| 236 |
+
|
| 237 |
+
# MLP
|
| 238 |
+
x = ada_gate(self.mlp(ada_rmsnorm(x, s1, b1)), g1) + x
|
| 239 |
+
|
| 240 |
+
return x, v
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class WorldDiT(nn.Module):
|
| 244 |
+
"""Stack of WorldDiTBlocks with shared parameters."""
|
| 245 |
+
|
| 246 |
+
def __init__(self, config):
|
| 247 |
+
super().__init__()
|
| 248 |
+
self.config = config
|
| 249 |
+
self.blocks = nn.ModuleList(
|
| 250 |
+
[
|
| 251 |
+
WorldDiTBlock(
|
| 252 |
+
d_model=config.d_model,
|
| 253 |
+
n_heads=config.n_heads,
|
| 254 |
+
mlp_ratio=config.mlp_ratio,
|
| 255 |
+
layer_idx=idx,
|
| 256 |
+
prompt_conditioning=config.prompt_conditioning,
|
| 257 |
+
prompt_conditioning_period=config.prompt_conditioning_period,
|
| 258 |
+
prompt_embedding_dim=config.prompt_embedding_dim,
|
| 259 |
+
ctrl_conditioning_period=config.ctrl_conditioning_period,
|
| 260 |
+
noise_conditioning=config.noise_conditioning,
|
| 261 |
+
config=config,
|
| 262 |
+
)
|
| 263 |
+
for idx in range(config.n_layers)
|
| 264 |
+
]
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
if config.noise_conditioning in ("dit_air", "wan"):
|
| 268 |
+
ref_proj = self.blocks[0].cond_head.cond_proj
|
| 269 |
+
for blk in self.blocks[1:]:
|
| 270 |
+
for blk_mod, ref_mod in zip(blk.cond_head.cond_proj, ref_proj):
|
| 271 |
+
blk_mod.weight = ref_mod.weight
|
| 272 |
+
|
| 273 |
+
# Shared RoPE buffers
|
| 274 |
+
ref_rope = self.blocks[0].attn.rope
|
| 275 |
+
for blk in self.blocks[1:]:
|
| 276 |
+
blk.attn.rope = ref_rope
|
| 277 |
+
|
| 278 |
+
def forward(self, x, pos_ids, cond, ctx, kv_cache=None):
|
| 279 |
+
v = None
|
| 280 |
+
for i, block in enumerate(self.blocks):
|
| 281 |
+
x, v = block(x, pos_ids, cond, ctx, v, kv_cache=kv_cache)
|
| 282 |
+
return x
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class WorldModel(ModelMixin, ConfigMixin):
|
| 286 |
+
"""
|
| 287 |
+
WORLD: Wayfarer Operator-driven Rectified-flow Long-context Diffuser.
|
| 288 |
+
|
| 289 |
+
Denoises a frame given:
|
| 290 |
+
- All previous frames (via KV cache)
|
| 291 |
+
- The prompt embedding
|
| 292 |
+
- The controller input embedding
|
| 293 |
+
- The current noise level
|
| 294 |
+
"""
|
| 295 |
+
|
| 296 |
+
_supports_gradient_checkpointing = False
|
| 297 |
+
_keep_in_fp32_modules = ["denoise_step_emb", "rope"]
|
| 298 |
+
|
| 299 |
+
@register_to_config
|
| 300 |
+
def __init__(
|
| 301 |
+
self,
|
| 302 |
+
# Model architecture
|
| 303 |
+
d_model: int = 2560,
|
| 304 |
+
n_heads: int = 40,
|
| 305 |
+
n_kv_heads: Optional[int] = 20,
|
| 306 |
+
n_layers: int = 22,
|
| 307 |
+
mlp_ratio: int = 5,
|
| 308 |
+
channels: int = 16,
|
| 309 |
+
height: int = 16,
|
| 310 |
+
width: int = 16,
|
| 311 |
+
patch: tuple = (2, 2),
|
| 312 |
+
tokens_per_frame: int = 256,
|
| 313 |
+
n_frames: int = 512,
|
| 314 |
+
local_window: int = 16,
|
| 315 |
+
global_window: int = 128,
|
| 316 |
+
global_attn_period: int = 4,
|
| 317 |
+
global_pinned_dilation: int = 8,
|
| 318 |
+
global_attn_offset: int = -1,
|
| 319 |
+
value_residual: bool = False,
|
| 320 |
+
gated_attn: bool = True,
|
| 321 |
+
n_buttons: int = 256,
|
| 322 |
+
ctrl_conditioning: Optional[str] = "mlp_fusion",
|
| 323 |
+
ctrl_conditioning_period: int = 3,
|
| 324 |
+
ctrl_cond_dropout: float = 0.0,
|
| 325 |
+
prompt_conditioning: Optional[str] = "cross_attention",
|
| 326 |
+
prompt_conditioning_period: int = 3,
|
| 327 |
+
prompt_embedding_dim: int = 2048,
|
| 328 |
+
prompt_cond_dropout: float = 0.0,
|
| 329 |
+
noise_conditioning: str = "wan",
|
| 330 |
+
scheduler_sigmas: Optional[List[float]] = [
|
| 331 |
+
1.0,
|
| 332 |
+
0.9483006596565247,
|
| 333 |
+
0.8379597067832947,
|
| 334 |
+
0.0,
|
| 335 |
+
],
|
| 336 |
+
base_fps: int = 60,
|
| 337 |
+
causal: bool = True,
|
| 338 |
+
mlp_gradient_checkpointing: bool = True,
|
| 339 |
+
block_gradient_checkpointing: bool = True,
|
| 340 |
+
rope_impl: str = "ortho",
|
| 341 |
+
):
|
| 342 |
+
super().__init__()
|
| 343 |
+
|
| 344 |
+
self.denoise_step_emb = NoiseConditioner(d_model)
|
| 345 |
+
self.ctrl_emb = ControllerInputEmbedding(n_buttons, d_model, mlp_ratio)
|
| 346 |
+
|
| 347 |
+
if self.config.ctrl_conditioning is not None:
|
| 348 |
+
self.ctrl_cfg = CFG(self.config.d_model, self.config.ctrl_cond_dropout)
|
| 349 |
+
if self.config.prompt_conditioning is not None:
|
| 350 |
+
self.prompt_cfg = CFG(
|
| 351 |
+
self.config.prompt_embedding_dim, self.config.prompt_cond_dropout
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
self.transformer = WorldDiT(self.config)
|
| 355 |
+
self.patch = tuple(patch)
|
| 356 |
+
|
| 357 |
+
C, D = channels, d_model
|
| 358 |
+
self.patchify = nn.Conv2d(
|
| 359 |
+
C, D, kernel_size=self.patch, stride=self.patch, bias=False
|
| 360 |
+
)
|
| 361 |
+
self.unpatchify = nn.Linear(D, C * math.prod(self.patch), bias=True)
|
| 362 |
+
self.out_norm = AdaLN(d_model)
|
| 363 |
+
|
| 364 |
+
# Cached 1-frame pos_ids (buffers + cached TensorDict view)
|
| 365 |
+
T = tokens_per_frame
|
| 366 |
+
idx = torch.arange(T, dtype=torch.long)
|
| 367 |
+
self.register_buffer(
|
| 368 |
+
"_t_pos_1f", torch.empty(T, dtype=torch.long), persistent=False
|
| 369 |
+
)
|
| 370 |
+
self.register_buffer(
|
| 371 |
+
"_y_pos_1f", idx.div(width, rounding_mode="floor"), persistent=False
|
| 372 |
+
)
|
| 373 |
+
self.register_buffer("_x_pos_1f", idx.remainder(width), persistent=False)
|
| 374 |
+
|
| 375 |
+
def forward(
|
| 376 |
+
self,
|
| 377 |
+
x: Tensor,
|
| 378 |
+
sigma: Tensor,
|
| 379 |
+
frame_timestamp: Tensor,
|
| 380 |
+
prompt_emb: Optional[Tensor] = None,
|
| 381 |
+
prompt_pad_mask: Optional[Tensor] = None,
|
| 382 |
+
mouse: Optional[Tensor] = None,
|
| 383 |
+
button: Optional[Tensor] = None,
|
| 384 |
+
scroll: Optional[Tensor] = None,
|
| 385 |
+
kv_cache=None,
|
| 386 |
+
):
|
| 387 |
+
"""
|
| 388 |
+
Args:
|
| 389 |
+
x: [B, N, C, H, W] - latent frames
|
| 390 |
+
sigma: [B, N] - noise levels
|
| 391 |
+
frame_timestamp: [B, N] - frame indices
|
| 392 |
+
prompt_emb: [B, P, D] - prompt embeddings
|
| 393 |
+
prompt_pad_mask: [B, P] - padding mask for prompts
|
| 394 |
+
mouse: [B, N, 2] - mouse velocity
|
| 395 |
+
button: [B, N, n_buttons] - button states
|
| 396 |
+
scroll: [B, N, 1] - scroll wheel sign (-1, 0, 1)
|
| 397 |
+
kv_cache: StaticKVCache instance
|
| 398 |
+
ctrl_cond: whether to apply controller conditioning (inference only)
|
| 399 |
+
prompt_cond: whether to apply prompt conditioning (inference only)
|
| 400 |
+
"""
|
| 401 |
+
B, N, C, H, W = x.shape
|
| 402 |
+
ph, pw = self.patch
|
| 403 |
+
assert (H % ph == 0) and (W % pw == 0), "H, W must be divisible by patch"
|
| 404 |
+
Hp, Wp = H // ph, W // pw
|
| 405 |
+
torch._assert(
|
| 406 |
+
Hp * Wp == self.config.tokens_per_frame,
|
| 407 |
+
f"{Hp} * {Wp} != {self.config.tokens_per_frame}",
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
torch._assert(
|
| 411 |
+
B == 1 and N == 1, "WorldModel.forward currently supports B==1, N==1"
|
| 412 |
+
)
|
| 413 |
+
self._t_pos_1f.copy_(frame_timestamp[0, 0].expand_as(self._t_pos_1f))
|
| 414 |
+
pos_ids = TensorDict(
|
| 415 |
+
{
|
| 416 |
+
"t_pos": self._t_pos_1f[None],
|
| 417 |
+
"y_pos": self._y_pos_1f[None],
|
| 418 |
+
"x_pos": self._x_pos_1f[None],
|
| 419 |
+
},
|
| 420 |
+
batch_size=[1, self._t_pos_1f.numel()],
|
| 421 |
+
)
|
| 422 |
+
cond = self.denoise_step_emb(sigma) # [B, N, d]
|
| 423 |
+
|
| 424 |
+
assert button is not None
|
| 425 |
+
ctx = {
|
| 426 |
+
"ctrl_emb": self.ctrl_emb(mouse, button, scroll),
|
| 427 |
+
"prompt_emb": prompt_emb,
|
| 428 |
+
"prompt_pad_mask": prompt_pad_mask,
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
D = self.unpatchify.in_features
|
| 432 |
+
x = self.patchify(x.reshape(B * N, C, H, W))
|
| 433 |
+
x = eo.rearrange(x.view(B, N, D, Hp, Wp), "b n d hp wp -> b (n hp wp) d")
|
| 434 |
+
x = self.transformer(x, pos_ids, cond, ctx, kv_cache)
|
| 435 |
+
x = F.silu(self.out_norm(x, cond))
|
| 436 |
+
x = eo.rearrange(
|
| 437 |
+
self.unpatchify(x),
|
| 438 |
+
"b (n hp wp) (c ph pw) -> b n c (hp ph) (wp pw)",
|
| 439 |
+
n=N,
|
| 440 |
+
hp=Hp,
|
| 441 |
+
wp=Wp,
|
| 442 |
+
ph=ph,
|
| 443 |
+
pw=pw,
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
return x
|
| 447 |
+
|
| 448 |
+
def quantize(self, quant_type: str):
|
| 449 |
+
quantize_model(self, quant_type)
|
| 450 |
+
|
| 451 |
+
def apply_inference_patches(self):
|
| 452 |
+
_apply_inference_patches(self)
|
transformer/nn.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Hugging Face Team and Overworld
|
| 2 |
+
#
|
| 3 |
+
# This program is free software: you can redistribute it and/or modify
|
| 4 |
+
# it under the terms of the GNU General Public License as published by
|
| 5 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 6 |
+
# (at your option) any later version.
|
| 7 |
+
#
|
| 8 |
+
# This program is distributed in the hope that it will be useful,
|
| 9 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 10 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 11 |
+
# GNU General Public License for more details.
|
| 12 |
+
#
|
| 13 |
+
# You should have received a copy of the GNU General Public License
|
| 14 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 15 |
+
|
| 16 |
+
"""Neural network building blocks for WorldModel transformer."""
|
| 17 |
+
|
| 18 |
+
import warnings
|
| 19 |
+
|
| 20 |
+
import einops as eo
|
| 21 |
+
import torch
|
| 22 |
+
from torch import nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class NoCastModule(torch.nn.Module):
|
| 27 |
+
"""Module that prevents dtype casting during .to() calls."""
|
| 28 |
+
|
| 29 |
+
def _apply(self, fn):
|
| 30 |
+
def keep_dtype(t):
|
| 31 |
+
old_dtype = t.dtype
|
| 32 |
+
out = fn(t)
|
| 33 |
+
if out.dtype is not old_dtype:
|
| 34 |
+
warnings.warn(
|
| 35 |
+
f"{self.__class__.__name__}: requested dtype cast ignored; "
|
| 36 |
+
f"keeping {old_dtype}.",
|
| 37 |
+
stacklevel=3,
|
| 38 |
+
)
|
| 39 |
+
out = out.to(dtype=old_dtype)
|
| 40 |
+
return out
|
| 41 |
+
|
| 42 |
+
return super()._apply(keep_dtype)
|
| 43 |
+
|
| 44 |
+
def to(self, *args, **kwargs):
|
| 45 |
+
warn_cast = False
|
| 46 |
+
|
| 47 |
+
# m.to(ref_tensor): use ref's device, ignore its dtype
|
| 48 |
+
if args and isinstance(args[0], torch.Tensor):
|
| 49 |
+
ref, *rest = args
|
| 50 |
+
args = (ref.device, *rest)
|
| 51 |
+
base = next(self.parameters(), None) or next(self.buffers(), None)
|
| 52 |
+
if base is not None and ref.dtype is not base.dtype:
|
| 53 |
+
warn_cast = True
|
| 54 |
+
|
| 55 |
+
# keyword dtype
|
| 56 |
+
if kwargs.pop("dtype", None) is not None:
|
| 57 |
+
warn_cast = True
|
| 58 |
+
|
| 59 |
+
# positional dtype
|
| 60 |
+
args = tuple(a for a in args if not isinstance(a, torch.dtype))
|
| 61 |
+
|
| 62 |
+
if warn_cast:
|
| 63 |
+
warnings.warn(
|
| 64 |
+
f"{self.__class__.__name__}.to: requested dtype cast ignored; "
|
| 65 |
+
"keeping existing dtypes.",
|
| 66 |
+
stacklevel=2,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
return super().to(*args, **kwargs)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def rms_norm(x: torch.Tensor) -> torch.Tensor:
|
| 73 |
+
"""Root mean square layer normalization."""
|
| 74 |
+
return F.rms_norm(x, (x.size(-1),))
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class MLP(nn.Module):
|
| 78 |
+
"""Simple MLP with SiLU activation."""
|
| 79 |
+
|
| 80 |
+
def __init__(self, dim_in, dim_middle, dim_out):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.fc1 = nn.Linear(dim_in, dim_middle, bias=False)
|
| 83 |
+
self.fc2 = nn.Linear(dim_middle, dim_out, bias=False)
|
| 84 |
+
|
| 85 |
+
def forward(self, x):
|
| 86 |
+
return self.fc2(F.silu(self.fc1(x)))
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class AdaLN(nn.Module):
|
| 90 |
+
"""Adaptive Layer Normalization."""
|
| 91 |
+
|
| 92 |
+
def __init__(self, dim):
|
| 93 |
+
super().__init__()
|
| 94 |
+
self.fc = nn.Linear(dim, 2 * dim, bias=False)
|
| 95 |
+
|
| 96 |
+
def forward(self, x, cond):
|
| 97 |
+
# cond: [b, n, d], x: [b, n*m, d]
|
| 98 |
+
b, n, d = cond.shape
|
| 99 |
+
_, nm, _ = x.shape
|
| 100 |
+
m = nm // n
|
| 101 |
+
|
| 102 |
+
y = F.silu(cond)
|
| 103 |
+
ab = self.fc(y) # [b, n, 2d]
|
| 104 |
+
ab = ab.view(b, n, 1, 2 * d) # [b, n, 1, 2d]
|
| 105 |
+
ab = ab.expand(-1, -1, m, -1) # [b, n, m, 2d]
|
| 106 |
+
ab = ab.reshape(b, nm, 2 * d) # [b, nm, 2d]
|
| 107 |
+
|
| 108 |
+
a, b_ = ab.chunk(2, dim=-1) # [b, nm, d] each
|
| 109 |
+
x = rms_norm(x) * (1 + a) + b_
|
| 110 |
+
return x
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def ada_rmsnorm(x, scale, bias):
|
| 114 |
+
"""Adaptive RMS normalization with scale and bias."""
|
| 115 |
+
x4 = eo.rearrange(x, "b (n m) d -> b n m d", n=scale.size(1))
|
| 116 |
+
y4 = rms_norm(x4) * (1 + scale.unsqueeze(2)) + bias.unsqueeze(2)
|
| 117 |
+
return eo.rearrange(y4, "b n m d -> b (n m) d")
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def ada_gate(x, gate):
|
| 121 |
+
"""Apply gating to x with per-frame gates."""
|
| 122 |
+
x4 = eo.rearrange(x, "b (n m) d -> b n m d", n=gate.size(1))
|
| 123 |
+
return eo.rearrange(x4 * gate.unsqueeze(2), "b n m d -> b (n m) d")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class NoiseConditioner(NoCastModule):
|
| 127 |
+
"""Sigma -> logSNR -> Fourier Features -> Dense embedding."""
|
| 128 |
+
|
| 129 |
+
def __init__(self, dim, fourier_dim=512, base=10_000.0):
|
| 130 |
+
super().__init__()
|
| 131 |
+
assert fourier_dim % 2 == 0
|
| 132 |
+
half = fourier_dim // 2
|
| 133 |
+
self.freq = nn.Buffer(
|
| 134 |
+
torch.logspace(0, -1, steps=half, base=base, dtype=torch.float32),
|
| 135 |
+
persistent=False,
|
| 136 |
+
)
|
| 137 |
+
self.mlp = MLP(fourier_dim, dim * 4, dim)
|
| 138 |
+
|
| 139 |
+
def forward(self, s, eps=torch.finfo(torch.float32).eps):
|
| 140 |
+
assert self.freq.dtype == torch.float32
|
| 141 |
+
orig_dtype, shape = s.dtype, s.shape
|
| 142 |
+
|
| 143 |
+
with torch.autocast("cuda", enabled=False):
|
| 144 |
+
s = s.reshape(-1).float() # fp32 for fourier numerical stability
|
| 145 |
+
s = s * 1000 # expressive rotation range
|
| 146 |
+
|
| 147 |
+
# calculate fourier features
|
| 148 |
+
phase = s[:, None] * self.freq[None, :]
|
| 149 |
+
emb = torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)
|
| 150 |
+
emb = emb * 2**0.5 # Ensure unit variance
|
| 151 |
+
emb = self.mlp(emb)
|
| 152 |
+
|
| 153 |
+
return emb.to(orig_dtype).view(*shape, -1)
|
transformer/quantize.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Hugging Face Team and Overworld
|
| 2 |
+
#
|
| 3 |
+
# This program is free software: you can redistribute it and/or modify
|
| 4 |
+
# it under the terms of the GNU General Public License as published by
|
| 5 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 6 |
+
# (at your option) any later version.
|
| 7 |
+
#
|
| 8 |
+
# This program is distributed in the hope that it will be useful,
|
| 9 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 10 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 11 |
+
# GNU General Public License for more details.
|
| 12 |
+
#
|
| 13 |
+
# You should have received a copy of the GNU General Public License
|
| 14 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 15 |
+
|
| 16 |
+
from typing import Optional
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
QUANTS = [
|
| 23 |
+
None
|
| 24 |
+
] # TODO: enable specific quant based on model config, which should specify compatible quants
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
from flashinfer import nvfp4_quantize, mm_fp4, SfLayout
|
| 29 |
+
|
| 30 |
+
QUANTS.append("nvfp4")
|
| 31 |
+
except ImportError:
|
| 32 |
+
pass
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@torch.library.custom_op("world_engine::fp4_linear", mutates_args=())
|
| 36 |
+
def fp4_linear(
|
| 37 |
+
a_bf16: torch.Tensor,
|
| 38 |
+
b_fp4_T: torch.Tensor,
|
| 39 |
+
a_global_sf: torch.Tensor,
|
| 40 |
+
b_sf_T: torch.Tensor,
|
| 41 |
+
alpha: torch.Tensor,
|
| 42 |
+
) -> torch.Tensor:
|
| 43 |
+
a_fp4, a_sf = nvfp4_quantize(
|
| 44 |
+
a_bf16,
|
| 45 |
+
a_global_sf,
|
| 46 |
+
sfLayout=SfLayout.layout_128x4,
|
| 47 |
+
do_shuffle=False,
|
| 48 |
+
)
|
| 49 |
+
return mm_fp4(
|
| 50 |
+
a_fp4, b_fp4_T, a_sf, b_sf_T, alpha, out_dtype=torch.bfloat16, backend="cutlass"
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@fp4_linear.register_fake
|
| 55 |
+
def _fp4_linear_fake(
|
| 56 |
+
a_bf16: torch.Tensor,
|
| 57 |
+
b_fp4_T: torch.Tensor,
|
| 58 |
+
a_global_sf: torch.Tensor,
|
| 59 |
+
b_sf_T: torch.Tensor,
|
| 60 |
+
alpha: torch.Tensor,
|
| 61 |
+
) -> torch.Tensor:
|
| 62 |
+
return torch.empty(
|
| 63 |
+
(a_bf16.shape[0], b_fp4_T.shape[1]), device=a_bf16.device, dtype=torch.bfloat16
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class FP4Linear(nn.Module):
|
| 68 |
+
"""FP4 Linear layer using FlashInfer's NVFP4 quantization."""
|
| 69 |
+
|
| 70 |
+
def __init__(self, lin: nn.Linear):
|
| 71 |
+
super().__init__()
|
| 72 |
+
|
| 73 |
+
self.in_features = lin.in_features
|
| 74 |
+
self.out_features = lin.out_features
|
| 75 |
+
|
| 76 |
+
# Check alignment requirements for NVFP4 TMA
|
| 77 |
+
assert self.in_features % 32 == 0 and self.out_features % 32 == 0, (
|
| 78 |
+
"features % 32 != 0, nvfp4 disallowed"
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# Store weight from original linear layer
|
| 82 |
+
self.weight = nn.Parameter(lin.weight.detach().clone())
|
| 83 |
+
|
| 84 |
+
# Cached FP4 weight and scales (populated on first forward)
|
| 85 |
+
self._weight_fp4_T: Optional[torch.Tensor] = None
|
| 86 |
+
self._weight_scales_T: Optional[torch.Tensor] = None
|
| 87 |
+
self._alpha: Optional[torch.Tensor] = None
|
| 88 |
+
self._dummy_scale: Optional[torch.Tensor] = None
|
| 89 |
+
self._weight_global_sf = None
|
| 90 |
+
|
| 91 |
+
with torch.no_grad():
|
| 92 |
+
# Quantize weights eagerly (no lazy path)
|
| 93 |
+
self._dummy_scale = torch.full(
|
| 94 |
+
(1,), 1.0, device=self.weight.device, dtype=torch.float32
|
| 95 |
+
)
|
| 96 |
+
weight_bf16 = (
|
| 97 |
+
self.weight.to(torch.bfloat16).to(self.weight.device).contiguous()
|
| 98 |
+
)
|
| 99 |
+
weight_amax = weight_bf16.float().abs().nan_to_num().max()
|
| 100 |
+
self._weight_global_sf = (1.0) / weight_amax
|
| 101 |
+
self._alpha = 1.0 / (self._weight_global_sf * self._dummy_scale)
|
| 102 |
+
w_fp4, w_sf = nvfp4_quantize(
|
| 103 |
+
weight_bf16,
|
| 104 |
+
self._weight_global_sf,
|
| 105 |
+
sfLayout=SfLayout.layout_128x4,
|
| 106 |
+
do_shuffle=False,
|
| 107 |
+
)
|
| 108 |
+
self._weight_fp4_T = w_fp4.t()
|
| 109 |
+
self._weight_scales_T = w_sf.t()
|
| 110 |
+
|
| 111 |
+
# Warmup flashinfer fp4 graphs
|
| 112 |
+
assert self.weight.is_cuda, "Weights need to be on GPU before quantization"
|
| 113 |
+
# TODO: test actual shape warmup, might perform better
|
| 114 |
+
lazy_x = torch.zeros(
|
| 115 |
+
(1, lin.in_features), device=self.weight.device, dtype=torch.bfloat16
|
| 116 |
+
)
|
| 117 |
+
fp4_linear(
|
| 118 |
+
lazy_x,
|
| 119 |
+
self._weight_fp4_T,
|
| 120 |
+
self._dummy_scale,
|
| 121 |
+
self._weight_scales_T,
|
| 122 |
+
self._alpha,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 126 |
+
"""Forward pass using FP4 quantization and FlashInfer GEMM."""
|
| 127 |
+
x_flat = x.reshape(-1, x.shape[-1])
|
| 128 |
+
y = fp4_linear(
|
| 129 |
+
x_flat.to(torch.bfloat16).contiguous(),
|
| 130 |
+
self._weight_fp4_T,
|
| 131 |
+
self._dummy_scale,
|
| 132 |
+
self._weight_scales_T,
|
| 133 |
+
self._alpha,
|
| 134 |
+
)
|
| 135 |
+
return y.reshape(x.shape[:-1] + (-1,))
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class FP8W8A8Linear(nn.Module):
|
| 139 |
+
__constants__ = ("in_features", "out_features")
|
| 140 |
+
|
| 141 |
+
def __init__(self, lin: nn.Linear):
|
| 142 |
+
super().__init__()
|
| 143 |
+
self.in_features, self.out_features = lin.in_features, lin.out_features
|
| 144 |
+
|
| 145 |
+
f8 = torch.float8_e4m3fn
|
| 146 |
+
inv = 1.0 / float(torch.finfo(f8).max)
|
| 147 |
+
self._inv = inv
|
| 148 |
+
|
| 149 |
+
w = lin.weight.detach()
|
| 150 |
+
ws = (w.abs().amax() * inv).clamp_min(1e-8).float() # 0-d
|
| 151 |
+
wf8 = (w / ws.to(w.dtype)).to(f8).contiguous() # row-major
|
| 152 |
+
self.register_buffer("wT", wf8.t()) # col-major view (no contiguous)
|
| 153 |
+
self.register_buffer("ws", ws)
|
| 154 |
+
|
| 155 |
+
if lin.bias is None:
|
| 156 |
+
self.bias = None
|
| 157 |
+
else:
|
| 158 |
+
self.register_buffer("bias", lin.bias.detach().to(torch.float16))
|
| 159 |
+
|
| 160 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 161 |
+
s = x.shape
|
| 162 |
+
x2 = x.reshape(-1, s[-1])
|
| 163 |
+
|
| 164 |
+
xs = (x2.abs().amax() * self._inv).clamp_min(1e-8).float() # 0-d
|
| 165 |
+
xf8 = (x2 / xs.to(x2.dtype)).to(torch.float8_e4m3fn).contiguous()
|
| 166 |
+
|
| 167 |
+
y = torch._scaled_mm(
|
| 168 |
+
xf8,
|
| 169 |
+
self.wT,
|
| 170 |
+
xs,
|
| 171 |
+
self.ws,
|
| 172 |
+
bias=self.bias,
|
| 173 |
+
out_dtype=torch.float16,
|
| 174 |
+
use_fast_accum=True,
|
| 175 |
+
)
|
| 176 |
+
return y.reshape(*s[:-1], self.out_features).to(x.dtype)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class FP8Linear(nn.Module):
|
| 180 |
+
def __init__(self, lin: nn.Linear):
|
| 181 |
+
super().__init__()
|
| 182 |
+
self.in_features, self.out_features = lin.in_features, lin.out_features
|
| 183 |
+
|
| 184 |
+
self.bias = (
|
| 185 |
+
nn.Parameter(lin.bias.data.clone().to(torch.float8_e4m3fn))
|
| 186 |
+
if lin.bias is not None
|
| 187 |
+
else None
|
| 188 |
+
)
|
| 189 |
+
w_amax = lin.weight.data.clone().amax().float().squeeze()
|
| 190 |
+
w = lin.weight.data.clone().div(w_amax).to(torch.float8_e4m3fn)
|
| 191 |
+
self.register_buffer("w_amax", w_amax)
|
| 192 |
+
self.register_buffer("weightT", w.t())
|
| 193 |
+
self.dummy_scale = torch.ones((), device=lin.weight.device, dtype=torch.float32)
|
| 194 |
+
|
| 195 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 196 |
+
"""
|
| 197 |
+
Forward pass using FP8 matmul.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
x: Input tensor of shape [..., in_features] (flattens if > 2D)
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
Output tensor of shape [..., out_features] in BF16 format, unflattened if input is > 2D
|
| 204 |
+
"""
|
| 205 |
+
|
| 206 |
+
# Convert input to FP8 e4m3
|
| 207 |
+
x_fp8 = x.to(torch.float8_e4m3fn).reshape(-1, x.size(-1)).contiguous()
|
| 208 |
+
|
| 209 |
+
result = torch._scaled_mm(
|
| 210 |
+
x_fp8,
|
| 211 |
+
self.weightT,
|
| 212 |
+
bias=self.bias,
|
| 213 |
+
scale_a=self.dummy_scale,
|
| 214 |
+
scale_b=self.w_amax,
|
| 215 |
+
out_dtype=torch.bfloat16,
|
| 216 |
+
use_fast_accum=True,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
return result.reshape(x.shape[:-1] + (-1,))
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def quantize_model(model: nn.Module, quant: str):
|
| 223 |
+
if quant is None:
|
| 224 |
+
return model
|
| 225 |
+
|
| 226 |
+
def eligible(m: nn.Module) -> bool:
|
| 227 |
+
w = getattr(m, "weight", None)
|
| 228 |
+
if not isinstance(m, nn.Linear):
|
| 229 |
+
return False
|
| 230 |
+
if getattr(w, "dtype", None) != torch.bfloat16:
|
| 231 |
+
return False
|
| 232 |
+
o, k = w.shape
|
| 233 |
+
return (o % 32 == 0) and (k % 32 == 0)
|
| 234 |
+
|
| 235 |
+
new_linear = {
|
| 236 |
+
"w8a8": FP8W8A8Linear,
|
| 237 |
+
"nvfp4": FP4Linear,
|
| 238 |
+
"fp8": FP8Linear,
|
| 239 |
+
}[quant]
|
| 240 |
+
|
| 241 |
+
for name, child in model.named_children():
|
| 242 |
+
setattr(model, name, new_linear(child)) if eligible(child) else quantize_model(
|
| 243 |
+
child, quant
|
| 244 |
+
)
|
| 245 |
+
return model
|
vae/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Hugging Face Team and Overworld
|
| 2 |
+
#
|
| 3 |
+
# This program is free software: you can redistribute it and/or modify
|
| 4 |
+
# it under the terms of the GNU General Public License as published by
|
| 5 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 6 |
+
# (at your option) any later version.
|
| 7 |
+
#
|
| 8 |
+
# This program is distributed in the hope that it will be useful,
|
| 9 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 10 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 11 |
+
# GNU General Public License for more details.
|
| 12 |
+
#
|
| 13 |
+
# You should have received a copy of the GNU General Public License
|
| 14 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 15 |
+
|
| 16 |
+
from .ae_model import WorldEngineVAE
|
| 17 |
+
from .dcae import bake_weight_norm
|
| 18 |
+
|
| 19 |
+
__all__ = ["WorldEngineVAE", "bake_weight_norm"]
|
vae/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (265 Bytes). View file
|
|
|
vae/__pycache__/ae_model.cpython-311.pyc
ADDED
|
Binary file (5.64 kB). View file
|
|
|
vae/__pycache__/dcae.cpython-311.pyc
ADDED
|
Binary file (17.1 kB). View file
|
|
|
vae/ae_model.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Hugging Face Team and Overworld
|
| 2 |
+
#
|
| 3 |
+
# This program is free software: you can redistribute it and/or modify
|
| 4 |
+
# it under the terms of the GNU General Public License as published by
|
| 5 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 6 |
+
# (at your option) any later version.
|
| 7 |
+
#
|
| 8 |
+
# This program is distributed in the hope that it will be useful,
|
| 9 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 10 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 11 |
+
# GNU General Public License for more details.
|
| 12 |
+
#
|
| 13 |
+
# You should have received a copy of the GNU General Public License
|
| 14 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 15 |
+
|
| 16 |
+
"""VAE model for WorldEngine frame encoding/decoding."""
|
| 17 |
+
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import List, Tuple
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from torch import Tensor
|
| 23 |
+
|
| 24 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 25 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 26 |
+
from .dcae import Encoder, Decoder, bake_weight_norm
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class EncoderDecoderConfig:
|
| 31 |
+
"""Config object for Encoder/Decoder initialization."""
|
| 32 |
+
|
| 33 |
+
channels: int
|
| 34 |
+
latent_channels: int
|
| 35 |
+
ch_0: int
|
| 36 |
+
ch_max: int
|
| 37 |
+
encoder_blocks_per_stage: List[int]
|
| 38 |
+
decoder_blocks_per_stage: List[int]
|
| 39 |
+
skip_logvar: bool = False
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class WorldEngineVAE(ModelMixin, ConfigMixin):
|
| 43 |
+
"""
|
| 44 |
+
VAE for encoding/decoding video frames using DCAE architecture.
|
| 45 |
+
|
| 46 |
+
Encodes RGB uint8 images to latent space and decodes latents back to RGB.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
_supports_gradient_checkpointing = False
|
| 50 |
+
|
| 51 |
+
@register_to_config
|
| 52 |
+
def __init__(
|
| 53 |
+
self,
|
| 54 |
+
# Common parameters
|
| 55 |
+
sample_size: Tuple[int, int] = (360, 640),
|
| 56 |
+
channels: int = 3,
|
| 57 |
+
latent_channels: int = 16,
|
| 58 |
+
# Encoder parameters
|
| 59 |
+
encoder_ch_0: int = 64,
|
| 60 |
+
encoder_ch_max: int = 256,
|
| 61 |
+
encoder_blocks_per_stage: List[int] = None,
|
| 62 |
+
# Decoder parameters
|
| 63 |
+
decoder_ch_0: int = 128,
|
| 64 |
+
decoder_ch_max: int = 1024,
|
| 65 |
+
decoder_blocks_per_stage: List[int] = None,
|
| 66 |
+
# Shared parameters
|
| 67 |
+
skip_logvar: bool = False,
|
| 68 |
+
# Scaling factors
|
| 69 |
+
scale_factor: float = 1.0,
|
| 70 |
+
shift_factor: float = 0.0,
|
| 71 |
+
):
|
| 72 |
+
super().__init__()
|
| 73 |
+
|
| 74 |
+
# Default blocks per stage
|
| 75 |
+
if encoder_blocks_per_stage is None:
|
| 76 |
+
encoder_blocks_per_stage = [1, 1, 1, 1]
|
| 77 |
+
if decoder_blocks_per_stage is None:
|
| 78 |
+
decoder_blocks_per_stage = [1, 1, 1, 1]
|
| 79 |
+
|
| 80 |
+
# Create encoder config
|
| 81 |
+
encoder_config = EncoderDecoderConfig(
|
| 82 |
+
channels=channels,
|
| 83 |
+
latent_channels=latent_channels,
|
| 84 |
+
ch_0=encoder_ch_0,
|
| 85 |
+
ch_max=encoder_ch_max,
|
| 86 |
+
encoder_blocks_per_stage=list(encoder_blocks_per_stage),
|
| 87 |
+
decoder_blocks_per_stage=list(decoder_blocks_per_stage),
|
| 88 |
+
skip_logvar=skip_logvar,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# Create decoder config
|
| 92 |
+
decoder_config = EncoderDecoderConfig(
|
| 93 |
+
channels=channels,
|
| 94 |
+
latent_channels=latent_channels,
|
| 95 |
+
ch_0=decoder_ch_0,
|
| 96 |
+
ch_max=decoder_ch_max,
|
| 97 |
+
encoder_blocks_per_stage=list(encoder_blocks_per_stage),
|
| 98 |
+
decoder_blocks_per_stage=list(decoder_blocks_per_stage),
|
| 99 |
+
skip_logvar=skip_logvar,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
self.encoder = Encoder(encoder_config)
|
| 103 |
+
self.decoder = Decoder(decoder_config)
|
| 104 |
+
|
| 105 |
+
def encode(self, img: Tensor):
|
| 106 |
+
"""RGB -> RGB+D -> latent"""
|
| 107 |
+
assert img.dim() == 3, "Expected [H, W, C] image tensor"
|
| 108 |
+
img = img.unsqueeze(0).to(device=self.device, dtype=self.dtype)
|
| 109 |
+
rgb = img.permute(0, 3, 1, 2).contiguous().div(255).mul(2).sub(1)
|
| 110 |
+
return self.encoder(rgb)
|
| 111 |
+
|
| 112 |
+
def decode(self, latent: Tensor):
|
| 113 |
+
decoded = self.decoder(latent)
|
| 114 |
+
decoded = (decoded / 2 + 0.5).clamp(0, 1)
|
| 115 |
+
decoded = (decoded * 255).round().to(torch.uint8)
|
| 116 |
+
return decoded.squeeze(0).permute(1, 2, 0)[..., :3]
|
| 117 |
+
|
| 118 |
+
def forward(self, x: Tensor, encode: bool = True) -> Tensor:
|
| 119 |
+
"""
|
| 120 |
+
Forward pass - encode or decode based on flag.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
x: Input tensor (image for encode, latent for decode)
|
| 124 |
+
encode: If True, encode; if False, decode
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
Encoded latent or decoded image
|
| 128 |
+
"""
|
| 129 |
+
if encode:
|
| 130 |
+
return self.encode(x)
|
| 131 |
+
else:
|
| 132 |
+
return self.decode(x)
|
| 133 |
+
|
| 134 |
+
def bake_weight_norm(self):
|
| 135 |
+
"""Remove weight_norm parametrizations, baking normalized weights into regular tensors.
|
| 136 |
+
|
| 137 |
+
Call this after loading weights and before torch.compile to avoid
|
| 138 |
+
CUDA graph capture errors from in-place weight updates.
|
| 139 |
+
"""
|
| 140 |
+
bake_weight_norm(self)
|
| 141 |
+
return self
|
vae/config.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "WorldEngineVAE",
|
| 3 |
+
"_diffusers_version": "0.36.0.dev0",
|
| 4 |
+
"auto_map": {
|
| 5 |
+
"AutoModel": "ae_model.WorldEngineVAE"
|
| 6 |
+
},
|
| 7 |
+
"sample_size": [
|
| 8 |
+
360,
|
| 9 |
+
640
|
| 10 |
+
],
|
| 11 |
+
"channels": 3,
|
| 12 |
+
"latent_channels": 16,
|
| 13 |
+
"encoder_ch_0": 64,
|
| 14 |
+
"encoder_ch_max": 256,
|
| 15 |
+
"encoder_blocks_per_stage": [
|
| 16 |
+
1,
|
| 17 |
+
1,
|
| 18 |
+
1,
|
| 19 |
+
1
|
| 20 |
+
],
|
| 21 |
+
"decoder_ch_0": 128,
|
| 22 |
+
"decoder_ch_max": 1024,
|
| 23 |
+
"decoder_blocks_per_stage": [
|
| 24 |
+
1,
|
| 25 |
+
1,
|
| 26 |
+
1,
|
| 27 |
+
1
|
| 28 |
+
],
|
| 29 |
+
"use_middle_block": false,
|
| 30 |
+
"skip_logvar": false,
|
| 31 |
+
"scale_factor": 1.0,
|
| 32 |
+
"shift_factor": 0.0
|
| 33 |
+
}
|
vae/dcae.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Hugging Face Team and Overworld
|
| 2 |
+
#
|
| 3 |
+
# This program is free software: you can redistribute it and/or modify
|
| 4 |
+
# it under the terms of the GNU General Public License as published by
|
| 5 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 6 |
+
# (at your option) any later version.
|
| 7 |
+
#
|
| 8 |
+
# This program is distributed in the hope that it will be useful,
|
| 9 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 10 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 11 |
+
# GNU General Public License for more details.
|
| 12 |
+
#
|
| 13 |
+
# You should have received a copy of the GNU General Public License
|
| 14 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from torch import nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
|
| 20 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 21 |
+
from torch.nn.utils.parametrize import remove_parametrizations
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def bake_weight_norm(model: nn.Module) -> nn.Module:
|
| 25 |
+
"""Remove weight_norm parametrizations, baking normalized weights into regular tensors.
|
| 26 |
+
|
| 27 |
+
This is required for torch.compile/CUDA graph compatibility since weight_norm
|
| 28 |
+
performs in-place updates during forward passes.
|
| 29 |
+
"""
|
| 30 |
+
for module in model.modules():
|
| 31 |
+
if hasattr(module, "parametrizations") and "weight" in getattr(module, "parametrizations", {}):
|
| 32 |
+
remove_parametrizations(module, "weight", leave_parametrized=True)
|
| 33 |
+
return model
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# === General Blocks ===
|
| 37 |
+
|
| 38 |
+
def WeightNormConv2d(*args, **kwargs):
|
| 39 |
+
return weight_norm(nn.Conv2d(*args, **kwargs))
|
| 40 |
+
|
| 41 |
+
class ResBlock(nn.Module):
|
| 42 |
+
def __init__(self, ch):
|
| 43 |
+
super().__init__()
|
| 44 |
+
|
| 45 |
+
hidden = 2 * ch
|
| 46 |
+
# 16 channels per group (matches checkpoint shapes like [128,16,3,3] when ch=64)
|
| 47 |
+
n_grps = max(1, hidden // 16)
|
| 48 |
+
|
| 49 |
+
self.conv1 = WeightNormConv2d(ch, hidden, 1, 1, 0)
|
| 50 |
+
self.conv2 = WeightNormConv2d(hidden, hidden, 3, 1, 1, groups=n_grps)
|
| 51 |
+
self.conv3 = WeightNormConv2d(hidden, ch, 1, 1, 0, bias=False)
|
| 52 |
+
|
| 53 |
+
self.act1 = nn.LeakyReLU(inplace=False)
|
| 54 |
+
self.act2 = nn.LeakyReLU(inplace=False)
|
| 55 |
+
|
| 56 |
+
def forward(self, x):
|
| 57 |
+
h = self.conv1(x)
|
| 58 |
+
h = self.act1(h)
|
| 59 |
+
h = self.conv2(h)
|
| 60 |
+
h = self.act2(h)
|
| 61 |
+
h = self.conv3(h)
|
| 62 |
+
return x + h
|
| 63 |
+
|
| 64 |
+
# === Encoder ===
|
| 65 |
+
|
| 66 |
+
class LandscapeToSquare(nn.Module):
|
| 67 |
+
# Strict assumption of 360p
|
| 68 |
+
def __init__(self, ch_in, ch_out):
|
| 69 |
+
super().__init__()
|
| 70 |
+
|
| 71 |
+
self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
|
| 72 |
+
|
| 73 |
+
def forward(self, x):
|
| 74 |
+
x = F.interpolate(x, (512, 512), mode='bicubic')
|
| 75 |
+
x = self.proj(x)
|
| 76 |
+
return x
|
| 77 |
+
|
| 78 |
+
class Downsample(nn.Module):
|
| 79 |
+
def __init__(self, ch_in, ch_out):
|
| 80 |
+
super().__init__()
|
| 81 |
+
|
| 82 |
+
self.proj = WeightNormConv2d(ch_in, ch_out, 1, 1, 0, bias=False)
|
| 83 |
+
|
| 84 |
+
def forward(self, x):
|
| 85 |
+
x = F.interpolate(x, scale_factor=0.5, mode='bicubic')
|
| 86 |
+
x = self.proj(x)
|
| 87 |
+
return x
|
| 88 |
+
|
| 89 |
+
class DownBlock(nn.Module):
|
| 90 |
+
def __init__(self, ch_in, ch_out, num_res=1):
|
| 91 |
+
super().__init__()
|
| 92 |
+
|
| 93 |
+
self.down = Downsample(ch_in, ch_out)
|
| 94 |
+
blocks = []
|
| 95 |
+
for _ in range(num_res):
|
| 96 |
+
blocks.append(ResBlock(ch_in))
|
| 97 |
+
self.blocks = nn.ModuleList(blocks)
|
| 98 |
+
|
| 99 |
+
def forward(self, x):
|
| 100 |
+
for block in self.blocks:
|
| 101 |
+
x = block(x)
|
| 102 |
+
x = self.down(x)
|
| 103 |
+
return x
|
| 104 |
+
|
| 105 |
+
class SpaceToChannel(nn.Module):
|
| 106 |
+
def __init__(self, ch_in, ch_out):
|
| 107 |
+
super().__init__()
|
| 108 |
+
|
| 109 |
+
self.proj = WeightNormConv2d(ch_in, ch_out // 4, 3, 1, 1)
|
| 110 |
+
|
| 111 |
+
def forward(self, x):
|
| 112 |
+
x = self.proj(x)
|
| 113 |
+
x = F.pixel_unshuffle(x, 2).contiguous()
|
| 114 |
+
return x
|
| 115 |
+
|
| 116 |
+
class ChannelAverage(nn.Module):
|
| 117 |
+
def __init__(self, ch_in, ch_out):
|
| 118 |
+
super().__init__()
|
| 119 |
+
|
| 120 |
+
self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
|
| 121 |
+
self.grps = ch_in // ch_out
|
| 122 |
+
self.scale = (self.grps) ** 0.5
|
| 123 |
+
|
| 124 |
+
def forward(self, x):
|
| 125 |
+
res = x
|
| 126 |
+
x = self.proj(x.contiguous()) # [b, ch_out, h, w]
|
| 127 |
+
|
| 128 |
+
# Residual goes through channel avg
|
| 129 |
+
res = res.view(res.shape[0], self.grps, res.shape[1] // self.grps, res.shape[2], res.shape[3]).contiguous()
|
| 130 |
+
res = res.mean(dim=1) * self.scale # [b, ch_out, h, w]
|
| 131 |
+
|
| 132 |
+
return res + x
|
| 133 |
+
|
| 134 |
+
# === Decoder ===
|
| 135 |
+
|
| 136 |
+
class SquareToLandscape(nn.Module):
|
| 137 |
+
def __init__(self, ch_in, ch_out):
|
| 138 |
+
super().__init__()
|
| 139 |
+
|
| 140 |
+
self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
|
| 141 |
+
|
| 142 |
+
def forward(self, x):
|
| 143 |
+
x = self.proj(x) # TODO This ordering is wrong for both
|
| 144 |
+
x = F.interpolate(x, (360, 640), mode='bicubic')
|
| 145 |
+
return x
|
| 146 |
+
|
| 147 |
+
class Upsample(nn.Module):
|
| 148 |
+
def __init__(self, ch_in, ch_out):
|
| 149 |
+
super().__init__()
|
| 150 |
+
|
| 151 |
+
self.proj = nn.Identity() if ch_in == ch_out else WeightNormConv2d(
|
| 152 |
+
ch_in, ch_out, 1, 1, 0, bias=False
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
def forward(self, x):
|
| 156 |
+
x = self.proj(x)
|
| 157 |
+
x = F.interpolate(x, scale_factor=2.0, mode='bicubic')
|
| 158 |
+
return x
|
| 159 |
+
|
| 160 |
+
class UpBlock(nn.Module):
|
| 161 |
+
def __init__(self, ch_in, ch_out, num_res=1):
|
| 162 |
+
super().__init__()
|
| 163 |
+
|
| 164 |
+
self.up = Upsample(ch_in, ch_out)
|
| 165 |
+
blocks = []
|
| 166 |
+
for _ in range(num_res):
|
| 167 |
+
blocks.append(ResBlock(ch_out))
|
| 168 |
+
self.blocks = nn.ModuleList(blocks)
|
| 169 |
+
|
| 170 |
+
def forward(self, x):
|
| 171 |
+
x = self.up(x)
|
| 172 |
+
for block in self.blocks:
|
| 173 |
+
x = block(x)
|
| 174 |
+
return x
|
| 175 |
+
|
| 176 |
+
class ChannelToSpace(nn.Module):
|
| 177 |
+
def __init__(self, ch_in, ch_out):
|
| 178 |
+
super().__init__()
|
| 179 |
+
|
| 180 |
+
self.proj = WeightNormConv2d(ch_in, ch_out * 4, 3, 1, 1)
|
| 181 |
+
|
| 182 |
+
def forward(self, x):
|
| 183 |
+
x = self.proj(x)
|
| 184 |
+
x = F.pixel_shuffle(x, 2).contiguous()
|
| 185 |
+
return x
|
| 186 |
+
|
| 187 |
+
class ChannelDuplication(nn.Module):
|
| 188 |
+
def __init__(self, ch_in, ch_out):
|
| 189 |
+
super().__init__()
|
| 190 |
+
|
| 191 |
+
self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
|
| 192 |
+
self.reps = ch_out // ch_in
|
| 193 |
+
self.scale = (self.reps) ** -0.5
|
| 194 |
+
|
| 195 |
+
def forward(self, x):
|
| 196 |
+
res = x
|
| 197 |
+
x = self.proj(x.contiguous())
|
| 198 |
+
|
| 199 |
+
b, c, h, w = res.shape
|
| 200 |
+
res = res.unsqueeze(2) # [b, c, 1, h, w]
|
| 201 |
+
res = res.expand(b, c, self.reps, h, w) # [b, c, reps, h, w]
|
| 202 |
+
res = res.reshape(b, c * self.reps, h, w).contiguous()
|
| 203 |
+
res = res * self.scale
|
| 204 |
+
|
| 205 |
+
return res + x
|
| 206 |
+
|
| 207 |
+
# === Main AE ===
|
| 208 |
+
|
| 209 |
+
class Encoder(nn.Module):
|
| 210 |
+
def __init__(self, config):
|
| 211 |
+
super().__init__()
|
| 212 |
+
|
| 213 |
+
self.conv_in = LandscapeToSquare(config.channels, config.ch_0)
|
| 214 |
+
|
| 215 |
+
blocks = []
|
| 216 |
+
residuals = []
|
| 217 |
+
|
| 218 |
+
ch = config.ch_0
|
| 219 |
+
for block_count in config.encoder_blocks_per_stage:
|
| 220 |
+
next_ch = min(ch*2, config.ch_max)
|
| 221 |
+
|
| 222 |
+
blocks.append(DownBlock(ch, next_ch, block_count))
|
| 223 |
+
residuals.append(SpaceToChannel(ch, next_ch))
|
| 224 |
+
|
| 225 |
+
ch = next_ch
|
| 226 |
+
|
| 227 |
+
self.blocks = nn.ModuleList(blocks)
|
| 228 |
+
self.residuals = nn.ModuleList(residuals)
|
| 229 |
+
self.conv_out = ChannelAverage(ch, config.latent_channels)
|
| 230 |
+
|
| 231 |
+
self.skip_logvar = bool(getattr(config, "skip_logvar", False))
|
| 232 |
+
if not self.skip_logvar:
|
| 233 |
+
# Checkpoint expects a 1-channel logvar head: [1, ch, 3, 3]
|
| 234 |
+
self.conv_out_logvar = WeightNormConv2d(ch, 1, 3, 1, 1)
|
| 235 |
+
|
| 236 |
+
def forward(self, x):
|
| 237 |
+
x = self.conv_in(x)
|
| 238 |
+
for block, residual in zip(self.blocks, self.residuals):
|
| 239 |
+
x = block(x) + residual(x)
|
| 240 |
+
return self.conv_out(x)
|
| 241 |
+
|
| 242 |
+
class Decoder(nn.Module):
|
| 243 |
+
def __init__(self, config):
|
| 244 |
+
super().__init__()
|
| 245 |
+
|
| 246 |
+
self.conv_in = ChannelDuplication(config.latent_channels, config.ch_max)
|
| 247 |
+
|
| 248 |
+
blocks = []
|
| 249 |
+
residuals = []
|
| 250 |
+
|
| 251 |
+
ch = config.ch_0
|
| 252 |
+
for block_count in reversed(config.decoder_blocks_per_stage):
|
| 253 |
+
next_ch = min(ch*2, config.ch_max)
|
| 254 |
+
|
| 255 |
+
blocks.append(UpBlock(next_ch, ch, block_count))
|
| 256 |
+
residuals.append(ChannelToSpace(next_ch, ch))
|
| 257 |
+
|
| 258 |
+
ch = next_ch
|
| 259 |
+
|
| 260 |
+
self.blocks = nn.ModuleList(reversed(blocks))
|
| 261 |
+
self.residuals = nn.ModuleList(reversed(residuals))
|
| 262 |
+
|
| 263 |
+
self.act_out = nn.SiLU()
|
| 264 |
+
self.conv_out = SquareToLandscape(config.ch_0, config.channels)
|
| 265 |
+
|
| 266 |
+
def forward(self, x):
|
| 267 |
+
x = self.conv_in(x)
|
| 268 |
+
for block, residual in zip(self.blocks, self.residuals):
|
| 269 |
+
x = block(x) + residual(x)
|
| 270 |
+
x = self.act_out(x)
|
| 271 |
+
return self.conv_out(x)
|
vae/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ecdebf692a6b02610163948251dcf264c5793da12b3729986fb4a3e3c4dc4d1f
|
| 3 |
+
size 141887736
|
vae/model.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Hugging Face Team and Overworld
|
| 2 |
+
#
|
| 3 |
+
# This program is free software: you can redistribute it and/or modify
|
| 4 |
+
# it under the terms of the GNU General Public License as published by
|
| 5 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 6 |
+
# (at your option) any later version.
|
| 7 |
+
#
|
| 8 |
+
# This program is distributed in the hope that it will be useful,
|
| 9 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 10 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 11 |
+
# GNU General Public License for more details.
|
| 12 |
+
#
|
| 13 |
+
# You should have received a copy of the GNU General Public License
|
| 14 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 15 |
+
|
| 16 |
+
"""VAE model for WorldEngine frame encoding/decoding."""
|
| 17 |
+
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import List, Tuple
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from torch import Tensor
|
| 23 |
+
|
| 24 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 25 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 26 |
+
from .dcae import Encoder, Decoder
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class EncoderDecoderConfig:
|
| 31 |
+
"""Config object for Encoder/Decoder initialization."""
|
| 32 |
+
|
| 33 |
+
sample_size: Tuple[int, int]
|
| 34 |
+
channels: int
|
| 35 |
+
latent_channels: int
|
| 36 |
+
ch_0: int
|
| 37 |
+
ch_max: int
|
| 38 |
+
encoder_blocks_per_stage: List[int]
|
| 39 |
+
decoder_blocks_per_stage: List[int]
|
| 40 |
+
use_middle_block: bool
|
| 41 |
+
skip_logvar: bool = False
|
| 42 |
+
skip_residuals: bool = False
|
| 43 |
+
normalize_mu: bool = False
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class WorldEngineVAE(ModelMixin, ConfigMixin):
|
| 47 |
+
"""
|
| 48 |
+
VAE for encoding/decoding video frames using DCAE architecture.
|
| 49 |
+
|
| 50 |
+
Encodes RGB uint8 images to latent space and decodes latents back to RGB.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
_supports_gradient_checkpointing = False
|
| 54 |
+
|
| 55 |
+
@register_to_config
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
# Common parameters
|
| 59 |
+
sample_size: Tuple[int, int] = (360, 640),
|
| 60 |
+
channels: int = 3,
|
| 61 |
+
latent_channels: int = 16,
|
| 62 |
+
# Encoder parameters
|
| 63 |
+
encoder_ch_0: int = 64,
|
| 64 |
+
encoder_ch_max: int = 256,
|
| 65 |
+
encoder_blocks_per_stage: List[int] = None,
|
| 66 |
+
# Decoder parameters
|
| 67 |
+
decoder_ch_0: int = 128,
|
| 68 |
+
decoder_ch_max: int = 1024,
|
| 69 |
+
decoder_blocks_per_stage: List[int] = None,
|
| 70 |
+
# Shared parameters
|
| 71 |
+
use_middle_block: bool = False,
|
| 72 |
+
skip_logvar: bool = False,
|
| 73 |
+
# Scaling factors
|
| 74 |
+
scale_factor: float = 1.0,
|
| 75 |
+
shift_factor: float = 0.0,
|
| 76 |
+
):
|
| 77 |
+
super().__init__()
|
| 78 |
+
|
| 79 |
+
# Default blocks per stage
|
| 80 |
+
if encoder_blocks_per_stage is None:
|
| 81 |
+
encoder_blocks_per_stage = [1, 1, 1, 1]
|
| 82 |
+
if decoder_blocks_per_stage is None:
|
| 83 |
+
decoder_blocks_per_stage = [1, 1, 1, 1]
|
| 84 |
+
|
| 85 |
+
# Create encoder config
|
| 86 |
+
encoder_config = EncoderDecoderConfig(
|
| 87 |
+
sample_size=tuple(sample_size),
|
| 88 |
+
channels=channels,
|
| 89 |
+
latent_channels=latent_channels,
|
| 90 |
+
ch_0=encoder_ch_0,
|
| 91 |
+
ch_max=encoder_ch_max,
|
| 92 |
+
encoder_blocks_per_stage=list(encoder_blocks_per_stage),
|
| 93 |
+
decoder_blocks_per_stage=list(decoder_blocks_per_stage),
|
| 94 |
+
use_middle_block=use_middle_block,
|
| 95 |
+
skip_logvar=skip_logvar,
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# Create decoder config
|
| 99 |
+
decoder_config = EncoderDecoderConfig(
|
| 100 |
+
sample_size=tuple(sample_size),
|
| 101 |
+
channels=channels,
|
| 102 |
+
latent_channels=latent_channels,
|
| 103 |
+
ch_0=decoder_ch_0,
|
| 104 |
+
ch_max=decoder_ch_max,
|
| 105 |
+
encoder_blocks_per_stage=list(encoder_blocks_per_stage),
|
| 106 |
+
decoder_blocks_per_stage=list(decoder_blocks_per_stage),
|
| 107 |
+
use_middle_block=use_middle_block,
|
| 108 |
+
skip_logvar=skip_logvar,
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
self.encoder = Encoder(encoder_config)
|
| 112 |
+
self.decoder = Decoder(decoder_config)
|
| 113 |
+
|
| 114 |
+
def encode(self, img: Tensor):
|
| 115 |
+
"""RGB -> RGB+D -> latent"""
|
| 116 |
+
assert img.dim() == 3, "Expected [H, W, C] image tensor"
|
| 117 |
+
img = img.unsqueeze(0).to(device=self.device, dtype=self.dtype)
|
| 118 |
+
rgb = img.permute(0, 3, 1, 2).contiguous().div(255).mul(2).sub(1)
|
| 119 |
+
return self.encoder(rgb)
|
| 120 |
+
|
| 121 |
+
@torch.compile
|
| 122 |
+
def decode(self, latent: Tensor):
|
| 123 |
+
decoded = self.decoder(latent)
|
| 124 |
+
decoded = (decoded / 2 + 0.5).clamp(0, 1)
|
| 125 |
+
decoded = (decoded * 255).round().to(torch.uint8)
|
| 126 |
+
return decoded.squeeze(0).permute(1, 2, 0)[..., :3]
|
| 127 |
+
|
| 128 |
+
def forward(self, x: Tensor, encode: bool = True) -> Tensor:
|
| 129 |
+
"""
|
| 130 |
+
Forward pass - encode or decode based on flag.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
x: Input tensor (image for encode, latent for decode)
|
| 134 |
+
encode: If True, encode; if False, decode
|
| 135 |
+
|
| 136 |
+
Returns:
|
| 137 |
+
Encoded latent or decoded image
|
| 138 |
+
"""
|
| 139 |
+
if encode:
|
| 140 |
+
return self.encode(x)
|
| 141 |
+
else:
|
| 142 |
+
return self.decode(x)
|