| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Streaming TAEHV autoencoder for WorldEngine wp-1.5 temporal-compressed latent decoding.""" |
|
|
| from collections import namedtuple |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch import Tensor |
|
|
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
| from diffusers.models.modeling_utils import ModelMixin |
|
|
|
|
| |
| |
| |
|
|
# One unit of pending work for the streaming executor: a tensor together with
# the index of the block (within an nn.Sequential) it should be fed to next.
TWorkItem = namedtuple("TWorkItem", ["input_tensor", "block_index"])
|
|
|
|
def _conv(n_in, n_out, **kwargs):
    """Build a 3x3 ``nn.Conv2d`` with padding 1.

    Args:
        n_in: Number of input channels.
        n_out: Number of output channels.
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (e.g. ``stride``, ``bias``).

    Returns:
        The freshly constructed convolution layer.
    """
    return nn.Conv2d(
        in_channels=n_in,
        out_channels=n_out,
        kernel_size=3,
        padding=1,
        **kwargs,
    )
|
|
|
|
class Clamp(nn.Module):
    """Soft clamp: squash values smoothly so their magnitude never exceeds 3."""

    def forward(self, x):
        """Return ``3 * tanh(x / 3)`` computed element-wise."""
        squashed = torch.tanh(x / 3)
        return squashed * 3
|
|
|
|
class MemBlock(nn.Module):
    """Residual conv block that also looks at a second ("past") feature map.

    The current input and the past input are concatenated along the channel
    axis, passed through three 3x3 convolutions, added to a skip projection of
    the current input, and finally passed through a ReLU.
    """

    def __init__(self, n_in, n_out):
        """Create the block.

        Args:
            n_in: Channels of ``x`` (and of ``past``).
            n_out: Channels of the block output.
        """
        super().__init__()
        # The first conv sees 2 * n_in channels because forward() concatenates
        # x and past along dim 1.
        layers = [
            _conv(n_in * 2, n_out),
            nn.ReLU(inplace=True),
            _conv(n_out, n_out),
            nn.ReLU(inplace=True),
            _conv(n_out, n_out),
        ]
        self.conv = nn.Sequential(*layers)
        # A 1x1 projection is only needed when the channel count changes.
        if n_in == n_out:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(n_in, n_out, 1, bias=False)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x, past):
        """Combine ``x`` with ``past`` and return the activated residual sum."""
        stacked = torch.cat([x, past], 1)
        residual = self.conv(stacked)
        shortcut = self.skip(x)
        return self.act(residual + shortcut)
|
|
|
|
class TPool(nn.Module):
    """Merge every ``stride`` consecutive entries of the leading axis.

    The input's leading axis is regrouped so that ``stride`` neighbouring
    entries are stacked along the channel axis, then a bias-free 1x1
    convolution maps the stacked channels back to ``n_f``.
    """

    def __init__(self, n_f, stride):
        """Store the stride and build the 1x1 mixing convolution."""
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(
            in_channels=n_f * stride,
            out_channels=n_f,
            kernel_size=1,
            bias=False,
        )

    def forward(self, x):
        """Pool a 4-D tensor; the leading axis shrinks by a factor of ``stride``."""
        _, channels, height, width = x.shape
        grouped = x.reshape(-1, channels * self.stride, height, width)
        return self.conv(grouped)
|
|
|
|
class TGrow(nn.Module):
    """Expand every entry of the leading axis into ``stride`` entries.

    A bias-free 1x1 convolution multiplies the channel count by ``stride``;
    the extra channels are then unfolded back onto the leading axis.
    """

    def __init__(self, n_f, stride):
        """Store the stride and build the 1x1 expanding convolution."""
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(
            in_channels=n_f,
            out_channels=n_f * stride,
            kernel_size=1,
            bias=False,
        )

    def forward(self, x):
        """Grow a 4-D tensor; the leading axis is multiplied by ``stride``."""
        _, channels, height, width = x.shape
        expanded = self.conv(x)
        return expanded.reshape(-1, channels, height, width)
|
|
|
|
| |
| |
| |
|
|
def _sequential_single_step(model, memory, work_queue):
    """Process the work queue until an output frame is produced or the queue is empty.

    Each queue entry is a ``(tensor, block_index)`` pair. The entry at the
    front is popped, pushed through ``model[block_index]`` and its result(s)
    are re-inserted at the front with ``block_index + 1``, so a single frame is
    driven as deep as possible before the next queued input is touched.

    Args:
        model: Indexable sequence of blocks supporting ``len()`` (e.g. an
            ``nn.Sequential``).
        memory: Mutable list with one slot per block; updated in place.
            ``MemBlock`` slots hold the previous input to that block, ``TPool``
            slots hold the list of inputs accumulated so far.
        work_queue: Mutable list of pending ``TWorkItem`` entries; consumed and
            extended in place.

    Returns:
        The first tensor that has passed through every block, with a new axis
        inserted at dim 1, or ``None`` if the queue ran dry first.
    """
    n_blocks = len(model)
    while work_queue:
        xt, i = work_queue.pop(0)
        if i == n_blocks:
            # Passed through every block: this is a finished output frame.
            return xt.unsqueeze(1)
        b = model[i]
        if isinstance(b, MemBlock):
            # The very first frame has no predecessor, so use an explicit zero
            # "past". zeros_like (rather than ``xt * 0``) stays zero even when
            # xt contains inf/NaN, matching the zero padding used by
            # _apply_parallel.
            past = torch.zeros_like(xt) if memory[i] is None else memory[i]
            xt_new = b(xt, past)
            # Remember this block's *input* as the past for the next frame.
            memory[i] = xt
            work_queue.insert(0, TWorkItem(xt_new, i + 1))
        elif isinstance(b, TPool):
            # Buffer inputs until a full group of ``stride`` has arrived.
            if memory[i] is None:
                memory[i] = []
            memory[i].append(xt)
            if len(memory[i]) == b.stride:
                N, C, H, W = xt.shape
                # Stack the buffered inputs on the channel axis, then expose
                # them as N * stride separate entries, which TPool regroups.
                xt = b(torch.cat(memory[i], 1).view(N * b.stride, C, H, W))
                memory[i] = []
                work_queue.insert(0, TWorkItem(xt, i + 1))
        elif isinstance(b, TGrow):
            xt = b(xt)
            NT, C, H, W = xt.shape
            # Split the grown tensor into ``stride`` outputs. Inserting them in
            # reverse keeps the earliest one at the front of the queue.
            for xt_next in reversed(
                xt.view(NT // b.stride, b.stride * C, H, W).chunk(b.stride, 1)
            ):
                work_queue.insert(0, TWorkItem(xt_next, i + 1))
        else:
            # Stateless block: apply it and move on to the next index.
            xt = b(xt)
            work_queue.insert(0, TWorkItem(xt, i + 1))
    return None
|
|
|
|
def _apply_parallel(model, x):
    """Apply model with parallelization over time axis. x: NTCHW.

    The batch and time axes are flattened so every block sees a 4-D tensor.
    ``MemBlock`` layers additionally receive a copy of their input shifted by
    one step along the time axis (zero-filled at the first step) as "past".

    Args:
        model: Iterable of blocks applied in order.
        x: 5-D tensor laid out as (N, T, C, H, W).

    Returns:
        5-D tensor (N, T', C', H', W') where the trailing sizes are whatever
        the final block produced.
    """
    batch, n_frames, channels, height, width = x.shape
    x = x.reshape(batch * n_frames, channels, height, width)
    for layer in model:
        if not isinstance(layer, MemBlock):
            x = layer(x)
            continue
        flat_shape = x.shape
        flat_len, channels, height, width = flat_shape
        # Earlier blocks may have changed the flattened length; recover T.
        n_frames = flat_len // batch
        per_clip = x.reshape(batch, n_frames, channels, height, width)
        # Prepend one zero step on the time axis and drop the last step, so
        # position t holds the input from position t - 1.
        shifted = F.pad(per_clip, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :n_frames]
        x = layer(x, shifted.reshape(flat_shape))
    flat_len, channels, height, width = x.shape
    n_frames = flat_len // batch
    return x.view(batch, n_frames, channels, height, width)
|
|
|
|
| |
| |
| |
|
|
class ChunkedStreamingTAEHV(ModelMixin, ConfigMixin):
    """Streaming TAEHV autoencoder for temporal-compressed latent decoding.

    Owns the encoder/decoder weights directly so diffusers can load them
    from safetensors. Provides a streaming interface that processes one
    temporal chunk at a time, maintaining internal state across calls.

    The streaming state (work queues, per-block memory and frame counters)
    is kept in plain Python attributes; call ``reset()`` before starting a
    new sequence.
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        latent_channels: int = 32,
        patch_size: int = 2,
        image_channels: int = 3,
        encoder_time_downscale: tuple[bool, ...] = (True, True, False),
        decoder_time_upscale: tuple[bool, ...] = (False, True, True),
        decoder_space_upscale: tuple[bool, ...] = (True, True, True),
    ):
        """Build the encoder/decoder stacks and initialise the streaming state.

        Args:
            latent_channels: Channels produced by the encoder's last conv and
                consumed by the decoder's first conv.
            patch_size: Pixel-(un)shuffle factor applied to frames; the conv
                stacks operate on ``image_channels * patch_size ** 2`` channels.
            image_channels: Channels of an input/output frame.
            encoder_time_downscale: One flag per encoder stage; ``True`` gives
                that stage's ``TPool`` a stride of 2, ``False`` a stride of 1.
            decoder_time_upscale: One flag per decoder stage; ``True`` gives
                that stage's ``TGrow`` a stride of 2, ``False`` a stride of 1.
            decoder_space_upscale: One flag per decoder stage; ``True`` makes
                that stage's ``nn.Upsample`` use scale factor 2, ``False`` 1.
        """
        super().__init__()

        in_ch = image_channels * patch_size ** 2

        # NOTE: layer order defines the state-dict keys; do not reorder.
        self.encoder = nn.Sequential(
            _conv(in_ch, 64), nn.ReLU(inplace=True),
            TPool(64, 2 if encoder_time_downscale[0] else 1),
            _conv(64, 64, stride=2, bias=False),
            MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 2 if encoder_time_downscale[1] else 1),
            _conv(64, 64, stride=2, bias=False),
            MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 2 if encoder_time_downscale[2] else 1),
            _conv(64, 64, stride=2, bias=False),
            MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            _conv(64, latent_channels),
        )

        n_f = [256, 128, 64, 64]
        self.decoder = nn.Sequential(
            Clamp(),
            _conv(latent_channels, n_f[0]), nn.ReLU(inplace=True),
            MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
            TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1),
            _conv(n_f[0], n_f[1], bias=False),
            MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
            TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1),
            _conv(n_f[1], n_f[2], bias=False),
            MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
            TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1),
            _conv(n_f[2], n_f[3], bias=False),
            nn.ReLU(inplace=True),
            _conv(n_f[3], image_channels * patch_size ** 2),
        )

        # Overall temporal factors: every stride-2 TPool/TGrow contributes x2.
        self.t_downscale = 2 ** sum(
            t.stride == 2 for t in self.encoder if isinstance(t, TPool)
        )
        self.t_upscale = 2 ** sum(
            t.stride == 2 for t in self.decoder if isinstance(t, TGrow)
        )
        # Number of leading decoded frames that _streaming_decode_step drops.
        self.frames_to_trim = self.t_upscale - 1
        self.patch_size = patch_size

        # Streaming state is created by the same code path used to clear it.
        self.reset()

    def reset(self):
        """Reset streaming state for a new sequence."""
        self._encoder_work_queue: list[TWorkItem] = []
        self._encoder_memory: list = [None] * len(self.encoder)
        self._decoder_work_queue: list[TWorkItem] = []
        self._decoder_memory: list = [None] * len(self.decoder)
        self._n_frames_encoded: int = 0
        self._n_frames_decoded: int = 0
        self._last_encoder_input_frame: Tensor | None = None

    def _preprocess_input_frames(self, x: Tensor) -> Tensor:
        """Fold ``patch_size`` x ``patch_size`` pixel patches into channels."""
        if self.patch_size > 1:
            x = F.pixel_unshuffle(x, self.patch_size)
        return x

    def _postprocess_output_frames(self, x: Tensor) -> Tensor:
        """Unfold channels back into pixels and clamp to [0, 1] in place."""
        if self.patch_size > 1:
            x = F.pixel_shuffle(x, self.patch_size)
        return x.clamp_(0, 1)

    def _streaming_encode_step(self, x: Tensor | None = None) -> Tensor | None:
        """Feed input frames and try to produce an encoder output.

        Args:
            x: NTCHW RGB frame tensor with values in [0, 1] (every frame along
                the T axis is queued), or None to only drain pending work.
        Returns:
            N1CHW latent tensor, or None if not enough input accumulated.
        """
        if x is not None:
            self._last_encoder_input_frame = x[:, -1:]
            x = self._preprocess_input_frames(x)
            self._encoder_work_queue.extend(
                TWorkItem(xt, 0) for xt in x.unbind(1)
            )
            self._n_frames_encoded += x.shape[1]
        return _sequential_single_step(
            self.encoder, self._encoder_memory, self._encoder_work_queue
        )

    def _streaming_decode_step(self, x: Tensor | None = None) -> Tensor | None:
        """Feed a latent and try to produce a decoded frame.

        The first ``frames_to_trim`` frames ever produced since the last
        reset are counted but discarded.

        Args:
            x: N1CHW latent tensor, or None to retrieve the next pending frame.
        Returns:
            N1CHW decoded RGB frame tensor, or None.
        """
        if x is not None:
            self._decoder_work_queue.extend(
                TWorkItem(xt, 0) for xt in x.unbind(1)
            )
        while True:
            xt = _sequential_single_step(
                self.decoder, self._decoder_memory, self._decoder_work_queue
            )
            if xt is None:
                return None
            self._n_frames_decoded += 1
            if self._n_frames_decoded <= self.frames_to_trim:
                # Still inside the trimmed prefix: drop and keep going.
                continue
            return self._postprocess_output_frames(xt)

    def _flush_decoder(self) -> list[Tensor]:
        """Drain all remaining decoded frames from the decoder."""
        frames = []
        while (frame := self._streaming_decode_step()) is not None:
            frames.append(frame)
        return frames

    @torch.inference_mode()
    def encode(self, img: Tensor) -> Tensor:
        """Encode a chunk of frames to a single latent.

        Args:
            img: [T, H, W, C] uint8 where T == t_downscale

        Returns:
            latent: [B, C, h, w] (B is 1: the chunk is treated as one clip)

        Raises:
            ValueError: If ``img`` is not 4-D with 3 trailing channels, or if
                its frame count differs from ``t_downscale``.
            RuntimeError: If the encoder produced no latent for the chunk.
        """
        # Explicit raise instead of assert so the check survives ``python -O``.
        if img.dim() != 4 or img.shape[-1] != 3:
            raise ValueError("Expected [T, H, W, C] RGB uint8")

        if img.shape[0] != self.t_downscale:
            raise ValueError(
                f"Expected {self.t_downscale} frames, got {img.shape[0]}"
            )

        # [T, H, W, C] -> [1, T, C, H, W], scaled from [0, 255] to [0, 1].
        rgb = (
            img.unsqueeze(0)
            .to(device=self.device, dtype=self.dtype)
            .permute(0, 1, 4, 2, 3)
            .contiguous()
            .div(255)
        )

        latent = self._streaming_encode_step(rgb)
        if latent is None:
            raise RuntimeError("Expected a latent after a full chunk")

        return latent.squeeze(1)

    @torch.inference_mode()
    def decode(self, latent: Tensor) -> Tensor:
        """Decode a latent to frames.

        Args:
            latent: [B, C, h, w] with B == 1

        Returns:
            frames: [T, H, W, C] uint8, where T is the number of frames the
            decoder emitted for this latent.

        Raises:
            ValueError: If ``latent`` is not 4-D or its batch size is not 1.
            RuntimeError: If the decoder produced no frame for the latent.
        """
        # Explicit raise instead of assert so the check survives ``python -O``.
        if latent.dim() != 4:
            raise ValueError("Expected [B, C, h, w] latent tensor")
        # The output layout below (squeeze(0) + 4-axis permute) only works for
        # a single clip; reject other batch sizes before touching any state.
        if latent.shape[0] != 1:
            raise ValueError(
                f"Expected a single latent (B == 1), got batch size {latent.shape[0]}"
            )

        z = latent.unsqueeze(1).to(device=self.device, dtype=self.dtype)

        if self._n_frames_decoded == 0:
            # First call since reset: feed the same latent ``frames_to_trim``
            # extra times and throw the results away, which consumes the
            # trimmed prefix and fills the decoder memory before the real pass.
            for _ in range(self.frames_to_trim):
                self._streaming_decode_step(z)
                self._flush_decoder()

        first = self._streaming_decode_step(z)
        if first is None:
            raise RuntimeError("Expected decoded output after a latent")
        frames = [first, *self._flush_decoder()]

        decoded = torch.cat(frames, dim=1)
        # Frames are already clamped to [0, 1] by _postprocess_output_frames.
        decoded = (decoded * 255).round().to(torch.uint8)
        # [1, T, C, H, W] -> [T, H, W, C], keeping the first three channels.
        return decoded.squeeze(0).permute(0, 2, 3, 1)[..., :3]
|
|