Waypoint-1.5-1B / vae /ae_model.py
dn6's picture
dn6 HF Staff
Add diffusers support
53f536d verified
raw
history blame
13.5 kB
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Streaming TAEHV autoencoder for WorldEngine wp-1.5 temporal-compressed latent decoding."""
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
# ---------------------------------------------------------------------------
# Building blocks (mirror the taehv library)
# ---------------------------------------------------------------------------
# One pending unit of streaming work: the tensor to process next and the
# index of the module (within an nn.Sequential) that should consume it.
TWorkItem = namedtuple("TWorkItem", ["input_tensor", "block_index"])
def _conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
    """Smoothly squash activations into the range [-3, 3].

    Computes ``3 * tanh(x / 3)``, a smooth alternative to a hard clamp.
    """

    def forward(self, x):
        limit = 3
        return limit * torch.tanh(x / limit)
class MemBlock(nn.Module):
    """Residual conv block that mixes the current input with a ``past`` tensor.

    ``x`` and ``past`` are concatenated along dim 1 and run through three
    3x3 convolutions (ReLU between them). A skip path — a bias-free 1x1 conv
    when ``n_in != n_out``, identity otherwise — is added to that result
    before a final ReLU.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        # The layer order fixes the state-dict keys (conv.0 / conv.2 / conv.4);
        # keep it stable so pretrained weights still load.
        layers = [
            _conv(n_in * 2, n_out),
            nn.ReLU(inplace=True),
            _conv(n_out, n_out),
            nn.ReLU(inplace=True),
            _conv(n_out, n_out),
        ]
        self.conv = nn.Sequential(*layers)
        if n_in == n_out:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(n_in, n_out, 1, bias=False)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x, past):
        stacked = torch.cat([x, past], dim=1)
        mixed = self.conv(stacked)
        return self.act(mixed + self.skip(x))
class TPool(nn.Module):
    """Temporal pooling: merge every ``stride`` consecutive rows into one.

    The 4-D input's leading axis holds frames. Each group of ``stride``
    consecutive rows is stacked along the channel axis, then a bias-free 1x1
    conv maps the ``n_f * stride`` stacked channels back to ``n_f``. The
    output therefore has ``stride`` times fewer rows.
    """

    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)

    def forward(self, x):
        _, channels, height, width = x.shape
        grouped = x.reshape(-1, channels * self.stride, height, width)
        return self.conv(grouped)
class TGrow(nn.Module):
    """Temporal upsampling: expand every row into ``stride`` consecutive rows.

    A bias-free 1x1 conv lifts ``n_f`` channels to ``n_f * stride``. The
    result is then reshaped so that each input row becomes ``stride`` rows
    of ``n_f`` channels.
    """

    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)

    def forward(self, x):
        _, channels, height, width = x.shape
        expanded = self.conv(x)
        return expanded.reshape(-1, channels, height, width)
# ---------------------------------------------------------------------------
# Sequential streaming helpers
# ---------------------------------------------------------------------------
def _sequential_single_step(model, memory, work_queue):
    """Process the work queue until an output frame is produced or the queue is empty.

    Drives ``model`` one frame at a time. Each queue entry is a
    ``TWorkItem(input_tensor, block_index)``: a tensor plus the index of the
    module in ``model`` that should consume it next. Items are popped from
    the front, and every result is pushed back onto the front. Each frame is
    therefore carried as deep through the model as possible before the next
    queued item is started.

    Args:
        model: indexable sequence of modules (``len(model)`` / ``model[i]``).
        memory: per-module state list indexed like ``model``, mutated in
            place. ``MemBlock`` slots hold that block's previous input,
            ``TPool`` slots hold a list of buffered inputs, and other slots
            are not touched.
        work_queue: list of ``TWorkItem``, mutated in place. Unprocessed
            items stay queued for the next call.

    Returns:
        The first tensor that has passed through every module, with a new
        dim inserted at position 1 (``unsqueeze(1)``), or ``None`` if the
        queue ran dry first.
    """
    while work_queue:
        xt, i = work_queue.pop(0)
        if i == len(model):
            # Passed through every module: emit as an output frame.
            return xt.unsqueeze(1)
        b = model[i]
        if isinstance(b, MemBlock):
            # The first frame has no predecessor, so a zero tensor shaped like
            # xt stands in for the memory. NOTE(review): xt * 0 propagates
            # NaN/inf from xt, unlike torch.zeros_like.
            if memory[i] is None:
                xt_new = b(xt, xt * 0)
            else:
                xt_new = b(xt, memory[i])
            # This block's input becomes the ``past`` for the next frame.
            memory[i] = xt
            work_queue.insert(0, TWorkItem(xt_new, i + 1))
        elif isinstance(b, TPool):
            # Buffer inputs until ``stride`` of them are available, then pool.
            # If the buffer is not full yet, nothing is queued for i + 1.
            if memory[i] is None:
                memory[i] = []
            memory[i].append(xt)
            if len(memory[i]) == b.stride:
                N, C, H, W = xt.shape
                # cat on dim 1 -> (N, stride*C, H, W); the view lays each batch
                # element's buffered frames out as consecutive rows, which is
                # the grouping TPool.forward's reshape expects.
                xt = b(torch.cat(memory[i], 1).view(N * b.stride, C, H, W))
                memory[i] = []
                work_queue.insert(0, TWorkItem(xt, i + 1))
        elif isinstance(b, TGrow):
            xt = b(xt)
            NT, C, H, W = xt.shape
            # Split the grown tensor into ``stride`` chunks along dim 1 and push
            # them in reverse, so they are popped in their original order.
            for xt_next in reversed(
                xt.view(NT // b.stride, b.stride * C, H, W).chunk(b.stride, 1)
            ):
                work_queue.insert(0, TWorkItem(xt_next, i + 1))
        else:
            # Any other module is applied directly and keeps no memory.
            xt = b(xt)
            work_queue.insert(0, TWorkItem(xt, i + 1))
    return None
def _apply_parallel(model, x):
    """Apply ``model`` to all frames at once instead of streaming them.

    Args:
        model: iterable of modules.
        x: 5-D tensor, unpacked as (N, T, C, H, W).

    Returns:
        Tensor viewed as (N, T', C', H', W'). T' is the leading-axis size of
        the final activations divided by N, so it reflects any temporal
        pooling or growth inside ``model``.

    Frames are folded into the batch axis. Each ``MemBlock`` receives, as its
    ``past`` argument, the same activations shifted one step later in time,
    with a zero frame at t = 0.
    """
    n, t, c, h, w = x.shape
    flat = x.reshape(n * t, c, h, w)
    for block in model:
        if not isinstance(block, MemBlock):
            flat = block(flat)
            continue
        nt, c, h, w = flat.shape
        frames = nt // n
        seq = flat.reshape(n, frames, c, h, w)
        # Pad one zero frame at the start of the time axis and drop the last
        # frame, so frame k is paired with frame k - 1.
        shifted = F.pad(seq, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :frames]
        flat = block(flat, shifted.reshape(flat.shape))
    nt, c, h, w = flat.shape
    return flat.view(n, nt // n, c, h, w)
# ---------------------------------------------------------------------------
# ChunkedStreamingTAEHV
# ---------------------------------------------------------------------------
class ChunkedStreamingTAEHV(ModelMixin, ConfigMixin):
    """Streaming TAEHV autoencoder for temporal-compressed latent decoding.

    Owns the encoder/decoder weights directly so diffusers can load them
    from safetensors. Provides a streaming interface that processes one
    temporal chunk at a time and keeps internal state across calls. Call
    ``reset()`` before starting an unrelated sequence.
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        latent_channels: int = 32,
        patch_size: int = 2,
        image_channels: int = 3,
        encoder_time_downscale: tuple[bool, ...] = (True, True, False),
        decoder_time_upscale: tuple[bool, ...] = (False, True, True),
        decoder_space_upscale: tuple[bool, ...] = (True, True, True),
    ):
        super().__init__()
        # Input frames are pixel-unshuffled by patch_size before the encoder
        # (see _preprocess_input_frames), which multiplies the channel count.
        in_ch = image_channels * patch_size ** 2
        # NOTE: module order below defines the state-dict keys the weights
        # are loaded with; do not insert, remove or reorder layers.
        self.encoder = nn.Sequential(
            _conv(in_ch, 64), nn.ReLU(inplace=True),
            TPool(64, 2 if encoder_time_downscale[0] else 1),
            _conv(64, 64, stride=2, bias=False),
            MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 2 if encoder_time_downscale[1] else 1),
            _conv(64, 64, stride=2, bias=False),
            MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 2 if encoder_time_downscale[2] else 1),
            _conv(64, 64, stride=2, bias=False),
            MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            _conv(64, latent_channels),
        )
        n_f = [256, 128, 64, 64]
        self.decoder = nn.Sequential(
            Clamp(),
            _conv(latent_channels, n_f[0]), nn.ReLU(inplace=True),
            MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
            TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1),
            _conv(n_f[0], n_f[1], bias=False),
            MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
            TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1),
            _conv(n_f[1], n_f[2], bias=False),
            MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
            TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1),
            _conv(n_f[2], n_f[3], bias=False),
            nn.ReLU(inplace=True),
            _conv(n_f[3], image_channels * patch_size ** 2),
        )
        # Overall temporal factors: 2 ** (number of stride-2 TPool / TGrow).
        self.t_downscale = 2 ** sum(
            t.stride == 2 for t in self.encoder if isinstance(t, TPool)
        )
        self.t_upscale = 2 ** sum(
            t.stride == 2 for t in self.decoder if isinstance(t, TGrow)
        )
        # Leading decoded frames dropped by _streaming_decode_step.
        self.frames_to_trim = self.t_upscale - 1
        self.patch_size = patch_size
        # Streaming state is defined once, in reset(), and initialised here.
        self.reset()

    # ------------------------------------------------------------------
    # Streaming state management
    # ------------------------------------------------------------------
    def reset(self):
        """Reset streaming state for a new sequence.

        Clears both work queues and per-module memories and zeroes the frame
        counters. The next ``decode`` call therefore runs its warm-up again.
        """
        self._encoder_work_queue: list[TWorkItem] = []
        self._encoder_memory: list = [None] * len(self.encoder)
        self._decoder_work_queue: list[TWorkItem] = []
        self._decoder_memory: list = [None] * len(self.decoder)
        self._n_frames_encoded: int = 0
        self._n_frames_decoded: int = 0
        self._last_encoder_input_frame: Tensor | None = None

    # ------------------------------------------------------------------
    # Pre/post processing
    # ------------------------------------------------------------------
    def _preprocess_input_frames(self, x: Tensor) -> Tensor:
        """Pixel-unshuffle frames by ``patch_size`` (no-op when it is 1)."""
        if self.patch_size > 1:
            x = F.pixel_unshuffle(x, self.patch_size)
        return x

    def _postprocess_output_frames(self, x: Tensor) -> Tensor:
        """Pixel-shuffle frames by ``patch_size``, then clamp to [0, 1] in place."""
        if self.patch_size > 1:
            x = F.pixel_shuffle(x, self.patch_size)
        return x.clamp_(0, 1)

    # ------------------------------------------------------------------
    # Streaming encode / decode (one chunk at a time)
    # ------------------------------------------------------------------
    def _streaming_encode_step(self, x: Tensor | None = None) -> Tensor | None:
        """Feed input frames and try to produce an encoder output.

        Args:
            x: NTCHW RGB frame tensor with values in [0, 1], or None to only
                continue processing already-queued frames.
        Returns:
            N1CHW latent tensor, or None if not enough input accumulated.
        """
        if x is not None:
            self._last_encoder_input_frame = x[:, -1:]
            x = self._preprocess_input_frames(x)
            # Queue each frame separately at the first encoder module.
            self._encoder_work_queue.extend(
                TWorkItem(xt, 0) for xt in x.unbind(1)
            )
            self._n_frames_encoded += x.shape[1]
        return _sequential_single_step(
            self.encoder, self._encoder_memory, self._encoder_work_queue
        )

    def _streaming_decode_step(self, x: Tensor | None = None) -> Tensor | None:
        """Feed latents and try to produce one decoded frame.

        The first ``frames_to_trim`` frames produced since the last reset are
        consumed and discarded here.

        Args:
            x: NTCHW latent tensor, or None to retrieve the next pending frame.
        Returns:
            N1CHW decoded RGB frame tensor in [0, 1], or None.
        """
        if x is not None:
            self._decoder_work_queue.extend(
                TWorkItem(xt, 0) for xt in x.unbind(1)
            )
        while True:
            xt = _sequential_single_step(
                self.decoder, self._decoder_memory, self._decoder_work_queue
            )
            if xt is None:
                return None
            self._n_frames_decoded += 1
            if self._n_frames_decoded <= self.frames_to_trim:
                continue
            return self._postprocess_output_frames(xt)

    def _flush_decoder(self) -> list[Tensor]:
        """Drain all remaining decoded frames from the decoder."""
        frames = []
        while (frame := self._streaming_decode_step()) is not None:
            frames.append(frame)
        return frames

    # ------------------------------------------------------------------
    # Pipeline-facing encode / decode
    # ------------------------------------------------------------------
    @torch.inference_mode()
    def encode(self, img: Tensor) -> Tensor:
        """Encode a chunk of frames to a single latent.

        Args:
            img: [T, H, W, C] RGB tensor with T == ``t_downscale``. Values are
                divided by 255, so uint8 (0-255) input is expected.
        Returns:
            latent: [1, C, h, w]. The chunk is always encoded as a batch of one.
        Raises:
            ValueError: if ``img`` is not 4-D with 3 trailing channels, or does
                not hold exactly ``t_downscale`` frames.
            RuntimeError: if the encoder emitted no latent for the chunk.
        """
        # Explicit checks rather than ``assert`` so they survive ``python -O``.
        if img.dim() != 4 or img.shape[-1] != 3:
            raise ValueError("Expected [T, H, W, C] RGB uint8")
        if img.shape[0] != self.t_downscale:
            raise ValueError(
                f"Expected {self.t_downscale} frames, got {img.shape[0]}"
            )
        rgb = (
            img.unsqueeze(0)
            .to(device=self.device, dtype=self.dtype)
            .permute(0, 1, 4, 2, 3)
            .contiguous()
            .div(255)
        )
        latent = self._streaming_encode_step(rgb)
        if latent is None:
            raise RuntimeError("Expected a latent after a full chunk")
        return latent.squeeze(1)

    @torch.inference_mode()
    def decode(self, latent: Tensor) -> Tensor:
        """Decode a latent to frames.

        On the first call after construction or ``reset()``, the same latent
        is first fed ``frames_to_trim`` extra times and all resulting frames
        are discarded. The returned frames therefore come from a decoder
        whose temporal memory has already been populated.

        Args:
            latent: [B, C, h, w] with B == 1.
        Returns:
            frames: [T, H, W, C] uint8, keeping at most the first 3 channels.
        Raises:
            ValueError: if ``latent`` is not 4-D or its batch size is not 1.
            RuntimeError: if the decoder produced no frame for the latent.
        """
        if latent.dim() != 4:
            raise ValueError("Expected [B, C, h, w] latent tensor")
        # The [T, H, W, C] output has no batch axis. Reject B != 1 before any
        # streaming state is mutated (it previously crashed in permute()).
        if latent.shape[0] != 1:
            raise ValueError(
                f"Expected a latent batch size of 1, got {latent.shape[0]}"
            )
        z = latent.unsqueeze(1).to(device=self.device, dtype=self.dtype)
        if self._n_frames_decoded == 0:
            # Warm-up: prime the decoder memory and discard its output.
            for _ in range(self.frames_to_trim):
                self._streaming_decode_step(z)
            self._flush_decoder()
        first = self._streaming_decode_step(z)
        if first is None:
            raise RuntimeError("Expected decoded output after a latent")
        frames = [first, *self._flush_decoder()]
        decoded = torch.cat(frames, dim=1)
        decoded = (decoded.clamp(0, 1) * 255).round().to(torch.uint8)
        # [1, T, C, H, W] -> [T, H, W, C], keeping only the RGB channels.
        return decoded.squeeze(0).permute(0, 2, 3, 1)[..., :3]