# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Streaming TAEHV autoencoder for WorldEngine wp-1.5 temporal-compressed latent decoding."""

from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


# ---------------------------------------------------------------------------
# Building blocks (mirror the taehv library)
# ---------------------------------------------------------------------------

# One pending unit of streaming work: a tensor plus the index of the block in
# the ``nn.Sequential`` that should consume it next.
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))


def _conv(n_in, n_out, **kwargs):
    """Return a 3x3 ``nn.Conv2d`` with padding 1; extra kwargs are forwarded."""
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)


class Clamp(nn.Module):
    """Smoothly squash values into the open interval (-3, 3) using tanh."""

    def forward(self, x):
        return torch.tanh(x / 3) * 3


class MemBlock(nn.Module):
    """Residual conv block that also sees a second ("past") feature map.

    ``x`` and ``past`` are concatenated along dim 1 (hence ``n_in * 2`` input
    channels for the first conv), passed through three 3x3 convs, added to a
    skip projection of ``x`` and rectified.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(
            _conv(n_in * 2, n_out),
            nn.ReLU(inplace=True),
            _conv(n_out, n_out),
            nn.ReLU(inplace=True),
            _conv(n_out, n_out),
        )
        # A 1x1 projection is only needed when the channel count changes.
        self.skip = (
            nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        )
        self.act = nn.ReLU(inplace=True)

    def forward(self, x, past):
        return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))


class TPool(nn.Module):
    """Merge ``stride`` consecutive entries of the leading axis into one.

    The input is reshaped so that groups of ``stride`` leading-axis entries are
    stacked along the channel axis, then a 1x1 conv maps ``n_f * stride``
    channels back to ``n_f``.
    """

    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)

    def forward(self, x):
        _NT, C, H, W = x.shape
        return self.conv(x.reshape(-1, self.stride * C, H, W))


class TGrow(nn.Module):
    """Expand each leading-axis entry into ``stride`` entries.

    A 1x1 conv produces ``n_f * stride`` channels which are then reshaped back
    to ``n_f`` channels, multiplying the leading axis by ``stride``.
    """

    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)

    def forward(self, x):
        _NT, C, H, W = x.shape
        x = self.conv(x)
        return x.reshape(-1, C, H, W)


# ---------------------------------------------------------------------------
# Sequential streaming helpers
# ---------------------------------------------------------------------------


def _sequential_single_step(model, memory, work_queue):
    """Process the work queue until an output frame is produced or the queue is empty.

    Args:
        model: indexable sequence of blocks (e.g. ``nn.Sequential``).
        memory: list with one slot per block. MemBlock slots hold the previous
            input tensor, TPool slots hold a list of buffered inputs; both are
            updated in place.
        work_queue: list of ``TWorkItem``; consumed from the front and mutated
            in place.

    Returns:
        The first tensor that has passed through every block, with a new axis
        inserted at dim 1, or ``None`` if the queue ran dry first.
    """
    while work_queue:
        xt, i = work_queue.pop(0)
        if i == len(model):
            # Past the last block: this item is a finished output frame.
            return xt.unsqueeze(1)
        b = model[i]
        if isinstance(b, MemBlock):
            if memory[i] is None:
                # No history yet: use an all-zero "past" of matching shape.
                xt_new = b(xt, torch.zeros_like(xt))
            else:
                xt_new = b(xt, memory[i])
            # The block's *input* becomes the "past" for the next item.
            memory[i] = xt
            work_queue.insert(0, TWorkItem(xt_new, i + 1))
        elif isinstance(b, TPool):
            if memory[i] is None:
                memory[i] = []
            memory[i].append(xt)
            # Only emit once ``stride`` inputs have been buffered.
            if len(memory[i]) == b.stride:
                N, C, H, W = xt.shape
                # reshape (not view) so non-contiguous inputs are accepted.
                xt = b(torch.cat(memory[i], 1).reshape(N * b.stride, C, H, W))
                memory[i] = []
                work_queue.insert(0, TWorkItem(xt, i + 1))
        elif isinstance(b, TGrow):
            xt = b(xt)
            NT, C, H, W = xt.shape
            # Split the grown tensor into ``stride`` items. Inserting them in
            # reverse at the front keeps them in order ahead of older work.
            for xt_next in reversed(
                xt.reshape(NT // b.stride, b.stride * C, H, W).chunk(b.stride, 1)
            ):
                work_queue.insert(0, TWorkItem(xt_next, i + 1))
        else:
            # Stateless block: apply and continue with the next block.
            xt = b(xt)
            work_queue.insert(0, TWorkItem(xt, i + 1))
    return None


def _apply_parallel(model, x):
    """Apply model with parallelization over time axis.

    Args:
        model: iterable of blocks.
        x: NTCHW tensor.

    Returns:
        NTCHW tensor; T, C, H and W reflect whatever the blocks produced.
    """
    N, T, C, H, W = x.shape
    x = x.reshape(N * T, C, H, W)
    for b in model:
        if isinstance(b, MemBlock):
            NT, C, H, W = x.shape
            T = NT // N
            _x = x.reshape(N, T, C, H, W)
            # Shift by one step along T with a zero frame in front, so each
            # time step sees the previous step's input as its "past".
            block_memory = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(
                x.shape
            )
            x = b(x, block_memory)
        else:
            x = b(x)
    NT, C, H, W = x.shape
    T = NT // N
    # reshape (not view) so non-contiguous block outputs are accepted.
    return x.reshape(N, T, C, H, W)


# ---------------------------------------------------------------------------
# ChunkedStreamingTAEHV
# ---------------------------------------------------------------------------


class ChunkedStreamingTAEHV(ModelMixin, ConfigMixin):
    """Streaming TAEHV autoencoder for temporal-compressed latent decoding.

    Owns the encoder/decoder weights directly so diffusers can load them from
    safetensors. Provides a streaming interface that processes one temporal
    chunk at a time, maintaining internal state across calls.
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        latent_channels: int = 32,
        patch_size: int = 2,
        image_channels: int = 3,
        encoder_time_downscale: tuple[bool, ...] = (True, True, False),
        decoder_time_upscale: tuple[bool, ...] = (False, True, True),
        decoder_space_upscale: tuple[bool, ...] = (True, True, True),
    ):
        super().__init__()
        # Channel count after pixel-unshuffling the input by ``patch_size``.
        in_ch = image_channels * patch_size ** 2

        self.encoder = nn.Sequential(
            _conv(in_ch, 64),
            nn.ReLU(inplace=True),
            TPool(64, 2 if encoder_time_downscale[0] else 1),
            _conv(64, 64, stride=2, bias=False),
            MemBlock(64, 64),
            MemBlock(64, 64),
            MemBlock(64, 64),
            TPool(64, 2 if encoder_time_downscale[1] else 1),
            _conv(64, 64, stride=2, bias=False),
            MemBlock(64, 64),
            MemBlock(64, 64),
            MemBlock(64, 64),
            TPool(64, 2 if encoder_time_downscale[2] else 1),
            _conv(64, 64, stride=2, bias=False),
            MemBlock(64, 64),
            MemBlock(64, 64),
            MemBlock(64, 64),
            _conv(64, latent_channels),
        )

        n_f = [256, 128, 64, 64]
        self.decoder = nn.Sequential(
            Clamp(),
            _conv(latent_channels, n_f[0]),
            nn.ReLU(inplace=True),
            MemBlock(n_f[0], n_f[0]),
            MemBlock(n_f[0], n_f[0]),
            MemBlock(n_f[0], n_f[0]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
            TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1),
            _conv(n_f[0], n_f[1], bias=False),
            MemBlock(n_f[1], n_f[1]),
            MemBlock(n_f[1], n_f[1]),
            MemBlock(n_f[1], n_f[1]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
            TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1),
            _conv(n_f[1], n_f[2], bias=False),
            MemBlock(n_f[2], n_f[2]),
            MemBlock(n_f[2], n_f[2]),
            MemBlock(n_f[2], n_f[2]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
            TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1),
            _conv(n_f[2], n_f[3], bias=False),
            nn.ReLU(inplace=True),
            _conv(n_f[3], image_channels * patch_size ** 2),
        )

        # Computed properties: overall temporal factors are 2 ** (number of
        # stride-2 TPool / TGrow blocks).
        self.t_downscale = 2 ** sum(
            t.stride == 2 for t in self.encoder if isinstance(t, TPool)
        )
        self.t_upscale = 2 ** sum(
            t.stride == 2 for t in self.decoder if isinstance(t, TGrow)
        )
        # Number of leading decoder outputs dropped after a reset.
        self.frames_to_trim = self.t_upscale - 1
        self.patch_size = patch_size

        # Streaming state lives in one place: reset() creates every field.
        self.reset()

    # ------------------------------------------------------------------
    # Streaming state management
    # ------------------------------------------------------------------

    def reset(self):
        """Reset streaming state for a new sequence."""
        self._encoder_work_queue: list[TWorkItem] = []
        self._encoder_memory: list = [None] * len(self.encoder)
        self._decoder_work_queue: list[TWorkItem] = []
        self._decoder_memory: list = [None] * len(self.decoder)
        self._n_frames_encoded: int = 0
        self._n_frames_decoded: int = 0
        self._last_encoder_input_frame: Tensor | None = None

    # ------------------------------------------------------------------
    # Pre/post processing
    # ------------------------------------------------------------------

    def _preprocess_input_frames(self, x: Tensor) -> Tensor:
        """Pixel-unshuffle ``x`` by ``patch_size`` (no-op when ``patch_size`` is 1)."""
        if self.patch_size > 1:
            x = F.pixel_unshuffle(x, self.patch_size)
        return x

    def _postprocess_output_frames(self, x: Tensor) -> Tensor:
        """Pixel-shuffle ``x`` by ``patch_size`` and clamp it in place to [0, 1]."""
        if self.patch_size > 1:
            x = F.pixel_shuffle(x, self.patch_size)
        return x.clamp_(0, 1)

    # ------------------------------------------------------------------
    # Streaming encode / decode (one chunk at a time)
    # ------------------------------------------------------------------

    def _streaming_encode_step(self, x: Tensor | None = None) -> Tensor | None:
        """Feed input frames and try to produce an encoder output.

        Args:
            x: NTCHW RGB frame tensor with values in [0, 1], or None to only
                continue processing work that is already queued. Each slice
                along axis 1 is queued as one frame.

        Returns:
            N1CHW latent tensor, or None if not enough input accumulated.
        """
        if x is not None:
            self._last_encoder_input_frame = x[:, -1:]
            x = self._preprocess_input_frames(x)
            self._encoder_work_queue.extend(
                TWorkItem(xt, 0) for xt in x.unbind(1)
            )
            self._n_frames_encoded += x.shape[1]
        return _sequential_single_step(
            self.encoder, self._encoder_memory, self._encoder_work_queue
        )

    def _streaming_decode_step(self, x: Tensor | None = None) -> Tensor | None:
        """Feed latents and try to produce a decoded frame.

        Args:
            x: NTCHW latent tensor, or None to retrieve the next pending
                frame. Each slice along axis 1 is queued as one latent.

        Returns:
            N1CHW decoded RGB frame tensor, or None.
        """
        if x is not None:
            self._decoder_work_queue.extend(
                TWorkItem(xt, 0) for xt in x.unbind(1)
            )
        while True:
            xt = _sequential_single_step(
                self.decoder, self._decoder_memory, self._decoder_work_queue
            )
            if xt is None:
                return None
            self._n_frames_decoded += 1
            # Silently drop the first ``frames_to_trim`` outputs after a reset.
            if self._n_frames_decoded <= self.frames_to_trim:
                continue
            return self._postprocess_output_frames(xt)

    def _flush_decoder(self) -> list[Tensor]:
        """Drain all remaining decoded frames from the decoder."""
        frames = []
        while (frame := self._streaming_decode_step()) is not None:
            frames.append(frame)
        return frames

    # ------------------------------------------------------------------
    # Pipeline-facing encode / decode
    # ------------------------------------------------------------------

    @torch.inference_mode()
    def encode(self, img: Tensor) -> Tensor:
        """Encode a chunk of frames to a single latent.

        Args:
            img: [T, H, W, C] uint8 where T == t_downscale

        Returns:
            latent: [1, C, h, w]

        Raises:
            ValueError: if ``img`` does not contain exactly ``t_downscale`` frames.
            RuntimeError: if the encoder produced no latent for the chunk.
        """
        assert img.dim() == 4 and img.shape[-1] == 3, "Expected [T, H, W, C] RGB uint8"
        if img.shape[0] != self.t_downscale:
            raise ValueError(
                f"Expected {self.t_downscale} frames, got {img.shape[0]}"
            )
        # [T, H, W, C] uint8 -> [1, T, C, H, W] in the model dtype, scaled to [0, 1].
        rgb = (
            img.unsqueeze(0)
            .to(device=self.device, dtype=self.dtype)
            .permute(0, 1, 4, 2, 3)
            .contiguous()
            .div(255)
        )
        latent = self._streaming_encode_step(rgb)
        if latent is None:
            raise RuntimeError("Expected a latent after a full chunk")
        return latent.squeeze(1)

    @torch.inference_mode()
    def decode(self, latent: Tensor) -> Tensor:
        """Decode a latent to frames.

        Args:
            latent: [1, C, h, w]

        Returns:
            frames: [T, H, W, C] uint8

        Raises:
            ValueError: if the latent batch size is not 1.
        """
        assert latent.dim() == 4, "Expected [B, C, h, w] latent tensor"
        # The output layout below drops the batch axis, so only batch size 1
        # can work. Fail before the streaming state is advanced.
        if latent.shape[0] != 1:
            raise ValueError(
                f"Expected a latent with batch size 1, got {latent.shape[0]}"
            )
        z = latent.unsqueeze(1).to(device=self.device, dtype=self.dtype)
        if self._n_frames_decoded == 0:
            # First call after a reset: feed the same latent ``frames_to_trim``
            # extra times and discard everything it yields. Presumably this
            # warms up the decoder memory so the trimmed leading frames do not
            # shorten the first returned chunk.
            for _ in range(self.frames_to_trim):
                self._streaming_decode_step(z)
            self._flush_decoder()
        first = self._streaming_decode_step(z)
        assert first is not None, "Expected decoded output after a latent"
        frames = [first, *self._flush_decoder()]
        decoded = torch.cat(frames, dim=1)
        decoded = (decoded.clamp(0, 1) * 255).round().to(torch.uint8)
        # [1, T, C, H, W] -> [T, H, W, C], keeping at most the first 3 channels.
        return decoded.squeeze(0).permute(0, 2, 3, 1)[..., :3]