Spaces:
Running on Zero
Running on Zero
| # Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from dataclasses import dataclass | |
| from math import sqrt | |
| from typing import Any | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from ...configuration_utils import ConfigMixin, register_to_config | |
| from ...utils import BaseOutput, logging | |
| from ...utils.accelerate_utils import apply_forward_hook | |
| from ...utils.import_utils import is_transformers_available | |
| from ...utils.torch_utils import randn_tensor | |
| if is_transformers_available(): | |
| from transformers import ( | |
| Dinov2WithRegistersConfig, | |
| Dinov2WithRegistersModel, | |
| SiglipVisionConfig, | |
| SiglipVisionModel, | |
| ViTMAEConfig, | |
| ViTMAEModel, | |
| ) | |
| from ..activations import get_activation | |
| from ..attention import AttentionMixin | |
| from ..attention_processor import Attention | |
| from ..embeddings import get_2d_sincos_pos_embed | |
| from ..modeling_utils import ModelMixin | |
| from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput | |
| logger = logging.get_logger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Per-encoder forward functions | |
| # --------------------------------------------------------------------------- | |
| # Each function takes the raw transformers model + images and returns patch | |
| # tokens of shape (B, N, C), stripping CLS / register tokens as needed. | |
| def _dinov2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor: | |
| outputs = model(images, output_hidden_states=True) | |
| unused_token_num = 5 # 1 CLS + 4 register tokens | |
| return outputs.last_hidden_state[:, unused_token_num:] | |
| def _siglip2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor: | |
| outputs = model(images, output_hidden_states=True, interpolate_pos_encoding=True) | |
| return outputs.last_hidden_state | |
| def _mae_encoder_forward(model: nn.Module, images: torch.Tensor, patch_size: int) -> torch.Tensor: | |
| h, w = images.shape[2], images.shape[3] | |
| patch_num = int(h * w // patch_size**2) | |
| if patch_num * patch_size**2 != h * w: | |
| raise ValueError("Image size should be divisible by patch size.") | |
| noise = torch.arange(patch_num).unsqueeze(0).expand(images.shape[0], -1).to(images.device).to(images.dtype) | |
| outputs = model(images, noise, interpolate_pos_encoding=True) | |
| return outputs.last_hidden_state[:, 1:] # remove cls token | |
| # --------------------------------------------------------------------------- | |
| # Encoder construction helpers | |
| # --------------------------------------------------------------------------- | |
| def _build_encoder( | |
| encoder_type: str, hidden_size: int, patch_size: int, num_hidden_layers: int, head_dim: int = 64 | |
| ) -> nn.Module: | |
| """Build a frozen encoder from config (no pretrained download).""" | |
| num_attention_heads = hidden_size // head_dim # all supported encoders use head_dim=64 | |
| if encoder_type == "dinov2": | |
| config = Dinov2WithRegistersConfig( | |
| hidden_size=hidden_size, | |
| patch_size=patch_size, | |
| image_size=518, | |
| num_attention_heads=num_attention_heads, | |
| num_hidden_layers=num_hidden_layers, | |
| ) | |
| model = Dinov2WithRegistersModel(config) | |
| # RAE strips the final layernorm affine params (identity LN). Remove them from | |
| # the architecture so `from_pretrained` doesn't leave them on the meta device. | |
| model.layernorm.weight = None | |
| model.layernorm.bias = None | |
| elif encoder_type == "siglip2": | |
| config = SiglipVisionConfig( | |
| hidden_size=hidden_size, | |
| patch_size=patch_size, | |
| image_size=256, | |
| num_attention_heads=num_attention_heads, | |
| num_hidden_layers=num_hidden_layers, | |
| ) | |
| model = SiglipVisionModel(config) | |
| # See dinov2 comment above. | |
| model.vision_model.post_layernorm.weight = None | |
| model.vision_model.post_layernorm.bias = None | |
| elif encoder_type == "mae": | |
| config = ViTMAEConfig( | |
| hidden_size=hidden_size, | |
| patch_size=patch_size, | |
| image_size=224, | |
| num_attention_heads=num_attention_heads, | |
| num_hidden_layers=num_hidden_layers, | |
| mask_ratio=0.0, | |
| ) | |
| model = ViTMAEModel(config) | |
| # See dinov2 comment above. | |
| model.layernorm.weight = None | |
| model.layernorm.bias = None | |
| else: | |
| raise ValueError(f"Unknown encoder_type='{encoder_type}'. Available: dinov2, siglip2, mae") | |
| model.requires_grad_(False) | |
| return model | |
| _ENCODER_FORWARD_FNS = { | |
| "dinov2": _dinov2_encoder_forward, | |
| "siglip2": _siglip2_encoder_forward, | |
| "mae": _mae_encoder_forward, | |
| } | |
| class RAEDecoderOutput(BaseOutput): | |
| """ | |
| Output of `RAEDecoder`. | |
| Args: | |
| logits (`torch.Tensor`): | |
| Patch reconstruction logits of shape `(batch_size, num_patches, patch_size**2 * num_channels)`. | |
| """ | |
| logits: torch.Tensor | |
| class ViTMAEIntermediate(nn.Module): | |
| def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str = "gelu"): | |
| super().__init__() | |
| self.dense = nn.Linear(hidden_size, intermediate_size) | |
| self.intermediate_act_fn = get_activation(hidden_act) | |
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | |
| hidden_states = self.dense(hidden_states) | |
| hidden_states = self.intermediate_act_fn(hidden_states) | |
| return hidden_states | |
| class ViTMAEOutput(nn.Module): | |
| def __init__(self, hidden_size: int, intermediate_size: int, hidden_dropout_prob: float = 0.0): | |
| super().__init__() | |
| self.dense = nn.Linear(intermediate_size, hidden_size) | |
| self.dropout = nn.Dropout(hidden_dropout_prob) | |
| def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: | |
| hidden_states = self.dense(hidden_states) | |
| hidden_states = self.dropout(hidden_states) | |
| hidden_states = hidden_states + input_tensor | |
| return hidden_states | |
| class ViTMAELayer(nn.Module): | |
| """ | |
| This matches the naming/parameter structure used in RAE-main (ViTMAE decoder block). | |
| """ | |
| def __init__( | |
| self, | |
| *, | |
| hidden_size: int, | |
| num_attention_heads: int, | |
| intermediate_size: int, | |
| qkv_bias: bool = True, | |
| layer_norm_eps: float = 1e-12, | |
| hidden_dropout_prob: float = 0.0, | |
| attention_probs_dropout_prob: float = 0.0, | |
| hidden_act: str = "gelu", | |
| ): | |
| super().__init__() | |
| if hidden_size % num_attention_heads != 0: | |
| raise ValueError( | |
| f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_attention_heads}" | |
| ) | |
| self.attention = Attention( | |
| query_dim=hidden_size, | |
| heads=num_attention_heads, | |
| dim_head=hidden_size // num_attention_heads, | |
| dropout=attention_probs_dropout_prob, | |
| bias=qkv_bias, | |
| ) | |
| self.intermediate = ViTMAEIntermediate( | |
| hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act | |
| ) | |
| self.output = ViTMAEOutput( | |
| hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_dropout_prob=hidden_dropout_prob | |
| ) | |
| self.layernorm_before = nn.LayerNorm(hidden_size, eps=layer_norm_eps) | |
| self.layernorm_after = nn.LayerNorm(hidden_size, eps=layer_norm_eps) | |
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | |
| attention_output = self.attention(self.layernorm_before(hidden_states)) | |
| hidden_states = attention_output + hidden_states | |
| layer_output = self.layernorm_after(hidden_states) | |
| layer_output = self.intermediate(layer_output) | |
| layer_output = self.output(layer_output, hidden_states) | |
| return layer_output | |
| class RAEDecoder(nn.Module): | |
| """ | |
| Decoder implementation ported from RAE-main to keep checkpoint compatibility. | |
| Key attributes (must match checkpoint keys): | |
| - decoder_embed | |
| - decoder_pos_embed | |
| - decoder_layers | |
| - decoder_norm | |
| - decoder_pred | |
| - trainable_cls_token | |
| """ | |
| def __init__( | |
| self, | |
| hidden_size: int = 768, | |
| decoder_hidden_size: int = 512, | |
| decoder_num_hidden_layers: int = 8, | |
| decoder_num_attention_heads: int = 16, | |
| decoder_intermediate_size: int = 2048, | |
| num_patches: int = 256, | |
| patch_size: int = 16, | |
| num_channels: int = 3, | |
| image_size: int = 256, | |
| qkv_bias: bool = True, | |
| layer_norm_eps: float = 1e-12, | |
| hidden_dropout_prob: float = 0.0, | |
| attention_probs_dropout_prob: float = 0.0, | |
| hidden_act: str = "gelu", | |
| ): | |
| super().__init__() | |
| self.decoder_hidden_size = decoder_hidden_size | |
| self.patch_size = patch_size | |
| self.num_channels = num_channels | |
| self.image_size = image_size | |
| self.num_patches = num_patches | |
| self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size, bias=True) | |
| grid_size = int(num_patches**0.5) | |
| pos_embed = get_2d_sincos_pos_embed( | |
| decoder_hidden_size, grid_size, cls_token=True, extra_tokens=1, output_type="pt" | |
| ) | |
| self.register_buffer("decoder_pos_embed", pos_embed.unsqueeze(0).float(), persistent=False) | |
| self.decoder_layers = nn.ModuleList( | |
| [ | |
| ViTMAELayer( | |
| hidden_size=decoder_hidden_size, | |
| num_attention_heads=decoder_num_attention_heads, | |
| intermediate_size=decoder_intermediate_size, | |
| qkv_bias=qkv_bias, | |
| layer_norm_eps=layer_norm_eps, | |
| hidden_dropout_prob=hidden_dropout_prob, | |
| attention_probs_dropout_prob=attention_probs_dropout_prob, | |
| hidden_act=hidden_act, | |
| ) | |
| for _ in range(decoder_num_hidden_layers) | |
| ] | |
| ) | |
| self.decoder_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps) | |
| self.decoder_pred = nn.Linear(decoder_hidden_size, patch_size**2 * num_channels, bias=True) | |
| self.gradient_checkpointing = False | |
| self.trainable_cls_token = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size)) | |
| def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor: | |
| embeddings_positions = embeddings.shape[1] - 1 | |
| num_positions = self.decoder_pos_embed.shape[1] - 1 | |
| class_pos_embed = self.decoder_pos_embed[:, 0, :] | |
| patch_pos_embed = self.decoder_pos_embed[:, 1:, :] | |
| dim = self.decoder_pos_embed.shape[-1] | |
| patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim).permute(0, 3, 1, 2) | |
| patch_pos_embed = F.interpolate( | |
| patch_pos_embed, | |
| scale_factor=(1, embeddings_positions / num_positions), | |
| mode="bicubic", | |
| align_corners=False, | |
| ) | |
| patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) | |
| return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) | |
| def interpolate_latent(self, x: torch.Tensor) -> torch.Tensor: | |
| b, l, c = x.shape | |
| if l == self.num_patches: | |
| return x | |
| h = w = int(l**0.5) | |
| x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) | |
| target_size = (int(self.num_patches**0.5), int(self.num_patches**0.5)) | |
| x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False) | |
| x = x.permute(0, 2, 3, 1).contiguous().view(b, self.num_patches, c) | |
| return x | |
| def unpatchify(self, patchified_pixel_values: torch.Tensor, original_image_size: tuple[int, int] | None = None): | |
| patch_size, num_channels = self.patch_size, self.num_channels | |
| original_image_size = ( | |
| original_image_size if original_image_size is not None else (self.image_size, self.image_size) | |
| ) | |
| original_height, original_width = original_image_size | |
| num_patches_h = original_height // patch_size | |
| num_patches_w = original_width // patch_size | |
| if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]: | |
| raise ValueError( | |
| f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}" | |
| ) | |
| batch_size = patchified_pixel_values.shape[0] | |
| patchified_pixel_values = patchified_pixel_values.reshape( | |
| batch_size, | |
| num_patches_h, | |
| num_patches_w, | |
| patch_size, | |
| patch_size, | |
| num_channels, | |
| ) | |
| patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values) | |
| pixel_values = patchified_pixel_values.reshape( | |
| batch_size, | |
| num_channels, | |
| num_patches_h * patch_size, | |
| num_patches_w * patch_size, | |
| ) | |
| return pixel_values | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| *, | |
| interpolate_pos_encoding: bool = False, | |
| drop_cls_token: bool = False, | |
| return_dict: bool = True, | |
| ) -> RAEDecoderOutput | tuple[torch.Tensor]: | |
| x = self.decoder_embed(hidden_states) | |
| if drop_cls_token: | |
| x_ = x[:, 1:, :] | |
| x_ = self.interpolate_latent(x_) | |
| else: | |
| x_ = self.interpolate_latent(x) | |
| cls_token = self.trainable_cls_token.expand(x_.shape[0], -1, -1) | |
| x = torch.cat([cls_token, x_], dim=1) | |
| if interpolate_pos_encoding: | |
| if not drop_cls_token: | |
| raise ValueError("interpolate_pos_encoding only supports drop_cls_token=True") | |
| decoder_pos_embed = self.interpolate_pos_encoding(x) | |
| else: | |
| decoder_pos_embed = self.decoder_pos_embed | |
| hidden_states = x + decoder_pos_embed.to(device=x.device, dtype=x.dtype) | |
| for layer_module in self.decoder_layers: | |
| hidden_states = layer_module(hidden_states) | |
| hidden_states = self.decoder_norm(hidden_states) | |
| logits = self.decoder_pred(hidden_states) | |
| logits = logits[:, 1:, :] | |
| if not return_dict: | |
| return (logits,) | |
| return RAEDecoderOutput(logits=logits) | |
| class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin): | |
| r""" | |
| Representation Autoencoder (RAE) model for encoding images to latents and decoding latents to images. | |
| This model uses a frozen pretrained encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT decoder to reconstruct | |
| images from learned representations. | |
| This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for | |
| all models (such as downloading or saving). | |
| Args: | |
| encoder_type (`str`, *optional*, defaults to `"dinov2"`): | |
| Type of frozen encoder to use. One of `"dinov2"`, `"siglip2"`, or `"mae"`. | |
| encoder_hidden_size (`int`, *optional*, defaults to `768`): | |
| Hidden size of the encoder model. | |
| encoder_patch_size (`int`, *optional*, defaults to `14`): | |
| Patch size of the encoder model. | |
| encoder_num_hidden_layers (`int`, *optional*, defaults to `12`): | |
| Number of hidden layers in the encoder model. | |
| patch_size (`int`, *optional*, defaults to `16`): | |
| Decoder patch size (used for unpatchify and decoder head). | |
| encoder_input_size (`int`, *optional*, defaults to `224`): | |
| Input size expected by the encoder. | |
| image_size (`int`, *optional*): | |
| Decoder output image size. If `None`, it is derived from encoder token count and `patch_size` like | |
| RAE-main: `image_size = patch_size * sqrt(num_patches)`, where `num_patches = (encoder_input_size // | |
| encoder_patch_size) ** 2`. | |
| num_channels (`int`, *optional*, defaults to `3`): | |
| Number of input/output channels. | |
| encoder_norm_mean (`list`, *optional*, defaults to `[0.485, 0.456, 0.406]`): | |
| Channel-wise mean for encoder input normalization (ImageNet defaults). | |
| encoder_norm_std (`list`, *optional*, defaults to `[0.229, 0.224, 0.225]`): | |
| Channel-wise std for encoder input normalization (ImageNet defaults). | |
| latents_mean (`list` or `tuple`, *optional*): | |
| Optional mean for latent normalization. Tensor inputs are accepted and converted to config-serializable | |
| lists. | |
| latents_std (`list` or `tuple`, *optional*): | |
| Optional standard deviation for latent normalization. Tensor inputs are accepted and converted to | |
| config-serializable lists. | |
| noise_tau (`float`, *optional*, defaults to `0.0`): | |
| Noise level for training (adds noise to latents during training). | |
| reshape_to_2d (`bool`, *optional*, defaults to `True`): | |
| Whether to reshape latents to 2D (B, C, H, W) format. | |
| use_encoder_loss (`bool`, *optional*, defaults to `False`): | |
| Whether to use encoder hidden states in the loss (for advanced training). | |
| """ | |
| # NOTE: gradient checkpointing is not wired up for this model yet. | |
| _supports_gradient_checkpointing = False | |
| _no_split_modules = ["ViTMAELayer"] | |
| _keys_to_ignore_on_load_unexpected = ["decoder.decoder_pos_embed"] | |
| def __init__( | |
| self, | |
| encoder_type: str = "dinov2", | |
| encoder_hidden_size: int = 768, | |
| encoder_patch_size: int = 14, | |
| encoder_num_hidden_layers: int = 12, | |
| decoder_hidden_size: int = 512, | |
| decoder_num_hidden_layers: int = 8, | |
| decoder_num_attention_heads: int = 16, | |
| decoder_intermediate_size: int = 2048, | |
| patch_size: int = 16, | |
| encoder_input_size: int = 224, | |
| image_size: int | None = None, | |
| num_channels: int = 3, | |
| encoder_norm_mean: list | None = None, | |
| encoder_norm_std: list | None = None, | |
| latents_mean: list | tuple | torch.Tensor | None = None, | |
| latents_std: list | tuple | torch.Tensor | None = None, | |
| noise_tau: float = 0.0, | |
| reshape_to_2d: bool = True, | |
| use_encoder_loss: bool = False, | |
| scaling_factor: float = 1.0, | |
| ): | |
| super().__init__() | |
| if encoder_type not in _ENCODER_FORWARD_FNS: | |
| raise ValueError( | |
| f"Unknown encoder_type='{encoder_type}'. Available: {sorted(_ENCODER_FORWARD_FNS.keys())}" | |
| ) | |
| def _to_config_compatible(value: Any) -> Any: | |
| if isinstance(value, torch.Tensor): | |
| return value.detach().cpu().tolist() | |
| if isinstance(value, tuple): | |
| return [_to_config_compatible(v) for v in value] | |
| if isinstance(value, list): | |
| return [_to_config_compatible(v) for v in value] | |
| return value | |
| def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tensor | None: | |
| if value is None: | |
| return None | |
| if isinstance(value, torch.Tensor): | |
| return value.detach().clone() | |
| return torch.tensor(value, dtype=torch.float32) | |
| latents_std_tensor = _as_optional_tensor(latents_std) | |
| # Ensure config values are JSON-serializable (list/None), even if caller passes torch.Tensors. | |
| self.register_to_config( | |
| latents_mean=_to_config_compatible(latents_mean), | |
| latents_std=_to_config_compatible(latents_std), | |
| ) | |
| self.encoder_input_size = encoder_input_size | |
| self.noise_tau = float(noise_tau) | |
| self.reshape_to_2d = bool(reshape_to_2d) | |
| self.use_encoder_loss = bool(use_encoder_loss) | |
| # Validate early, before building the (potentially large) encoder/decoder. | |
| encoder_patch_size = int(encoder_patch_size) | |
| if self.encoder_input_size % encoder_patch_size != 0: | |
| raise ValueError( | |
| f"encoder_input_size={self.encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}." | |
| ) | |
| decoder_patch_size = int(patch_size) | |
| if decoder_patch_size <= 0: | |
| raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).") | |
| # Frozen representation encoder (built from config, no downloads) | |
| self.encoder: nn.Module = _build_encoder( | |
| encoder_type=encoder_type, | |
| hidden_size=encoder_hidden_size, | |
| patch_size=encoder_patch_size, | |
| num_hidden_layers=encoder_num_hidden_layers, | |
| ) | |
| self._encoder_forward_fn = _ENCODER_FORWARD_FNS[encoder_type] | |
| num_patches = (self.encoder_input_size // encoder_patch_size) ** 2 | |
| grid = int(sqrt(num_patches)) | |
| if grid * grid != num_patches: | |
| raise ValueError(f"Computed num_patches={num_patches} must be a perfect square.") | |
| derived_image_size = decoder_patch_size * grid | |
| if image_size is None: | |
| image_size = derived_image_size | |
| else: | |
| image_size = int(image_size) | |
| if image_size != derived_image_size: | |
| raise ValueError( | |
| f"image_size={image_size} must equal decoder_patch_size*sqrt(num_patches)={derived_image_size} " | |
| f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}." | |
| ) | |
| # Encoder input normalization stats (ImageNet defaults) | |
| if encoder_norm_mean is None: | |
| encoder_norm_mean = [0.485, 0.456, 0.406] | |
| if encoder_norm_std is None: | |
| encoder_norm_std = [0.229, 0.224, 0.225] | |
| encoder_mean_tensor = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1) | |
| encoder_std_tensor = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1) | |
| self.register_buffer("encoder_mean", encoder_mean_tensor, persistent=True) | |
| self.register_buffer("encoder_std", encoder_std_tensor, persistent=True) | |
| # Latent normalization buffers (defaults are no-ops; actual values come from checkpoint) | |
| latents_mean_tensor = _as_optional_tensor(latents_mean) | |
| if latents_mean_tensor is None: | |
| latents_mean_tensor = torch.zeros(1) | |
| self.register_buffer("_latents_mean", latents_mean_tensor, persistent=True) | |
| if latents_std_tensor is None: | |
| latents_std_tensor = torch.ones(1) | |
| self.register_buffer("_latents_std", latents_std_tensor, persistent=True) | |
| # ViT-MAE style decoder | |
| self.decoder = RAEDecoder( | |
| hidden_size=int(encoder_hidden_size), | |
| decoder_hidden_size=int(decoder_hidden_size), | |
| decoder_num_hidden_layers=int(decoder_num_hidden_layers), | |
| decoder_num_attention_heads=int(decoder_num_attention_heads), | |
| decoder_intermediate_size=int(decoder_intermediate_size), | |
| num_patches=int(num_patches), | |
| patch_size=int(decoder_patch_size), | |
| num_channels=int(num_channels), | |
| image_size=int(image_size), | |
| ) | |
| self.num_patches = int(num_patches) | |
| self.decoder_patch_size = int(decoder_patch_size) | |
| self.decoder_image_size = int(image_size) | |
| # Slicing support (batch dimension) similar to other diffusers autoencoders | |
| self.use_slicing = False | |
| def _noising(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor: | |
| # Per-sample random sigma in [0, noise_tau] | |
| noise_sigma = self.noise_tau * torch.rand( | |
| (x.size(0),) + (1,) * (x.ndim - 1), device=x.device, dtype=x.dtype, generator=generator | |
| ) | |
| return x + noise_sigma * randn_tensor(x.shape, generator=generator, device=x.device, dtype=x.dtype) | |
| def _resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor: | |
| _, _, h, w = x.shape | |
| if h != self.encoder_input_size or w != self.encoder_input_size: | |
| x = F.interpolate( | |
| x, size=(self.encoder_input_size, self.encoder_input_size), mode="bicubic", align_corners=False | |
| ) | |
| mean = self.encoder_mean.to(device=x.device, dtype=x.dtype) | |
| std = self.encoder_std.to(device=x.device, dtype=x.dtype) | |
| return (x - mean) / std | |
| def _denormalize_image(self, x: torch.Tensor) -> torch.Tensor: | |
| mean = self.encoder_mean.to(device=x.device, dtype=x.dtype) | |
| std = self.encoder_std.to(device=x.device, dtype=x.dtype) | |
| return x * std + mean | |
| def _normalize_latents(self, z: torch.Tensor) -> torch.Tensor: | |
| latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype) | |
| latents_std = self._latents_std.to(device=z.device, dtype=z.dtype) | |
| return (z - latents_mean) / (latents_std + 1e-5) | |
| def _denormalize_latents(self, z: torch.Tensor) -> torch.Tensor: | |
| latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype) | |
| latents_std = self._latents_std.to(device=z.device, dtype=z.dtype) | |
| return z * (latents_std + 1e-5) + latents_mean | |
| def _encode(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor: | |
| x = self._resize_and_normalize(x) | |
| if self.config.encoder_type == "mae": | |
| tokens = self._encoder_forward_fn(self.encoder, x, self.config.encoder_patch_size) | |
| else: | |
| tokens = self._encoder_forward_fn(self.encoder, x) # (B, N, C) | |
| if self.training and self.noise_tau > 0: | |
| tokens = self._noising(tokens, generator=generator) | |
| if self.reshape_to_2d: | |
| b, n, c = tokens.shape | |
| side = int(sqrt(n)) | |
| if side * side != n: | |
| raise ValueError(f"Token length n={n} is not a perfect square; cannot reshape to 2D.") | |
| z = tokens.transpose(1, 2).contiguous().view(b, c, side, side) # (B, C, h, w) | |
| else: | |
| z = tokens | |
| z = self._normalize_latents(z) | |
| # Follow diffusers convention: optionally scale latents for diffusion | |
| if self.config.scaling_factor != 1.0: | |
| z = z * self.config.scaling_factor | |
| return z | |
| def encode( | |
| self, x: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None | |
| ) -> EncoderOutput | tuple[torch.Tensor]: | |
| if self.use_slicing and x.shape[0] > 1: | |
| latents = torch.cat([self._encode(x_slice, generator=generator) for x_slice in x.split(1)], dim=0) | |
| else: | |
| latents = self._encode(x, generator=generator) | |
| if not return_dict: | |
| return (latents,) | |
| return EncoderOutput(latent=latents) | |
| def _decode(self, z: torch.Tensor) -> torch.Tensor: | |
| # Undo scaling factor if applied at encode time | |
| if self.config.scaling_factor != 1.0: | |
| z = z / self.config.scaling_factor | |
| z = self._denormalize_latents(z) | |
| if self.reshape_to_2d: | |
| b, c, h, w = z.shape | |
| tokens = z.view(b, c, h * w).transpose(1, 2).contiguous() # (B, N, C) | |
| else: | |
| tokens = z | |
| logits = self.decoder(tokens, return_dict=True).logits | |
| x_rec = self.decoder.unpatchify(logits) | |
| x_rec = self._denormalize_image(x_rec) | |
| return x_rec.to(device=z.device) | |
| def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]: | |
| if self.use_slicing and z.shape[0] > 1: | |
| decoded = torch.cat([self._decode(z_slice) for z_slice in z.split(1)], dim=0) | |
| else: | |
| decoded = self._decode(z) | |
| if not return_dict: | |
| return (decoded,) | |
| return DecoderOutput(sample=decoded) | |
| def forward( | |
| self, sample: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None | |
| ) -> DecoderOutput | tuple[torch.Tensor]: | |
| r""" | |
| Args: | |
| sample (`torch.Tensor`): Input sample. | |
| return_dict (`bool`, *optional*, defaults to `True`): | |
| Whether or not to return a [`DecoderOutput`] instead of a plain tuple. | |
| generator (`torch.Generator`, *optional*): | |
| A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make sampling | |
| deterministic. | |
| Returns: | |
| [`DecoderOutput`] or `tuple`: | |
| If `return_dict` is True, a [`DecoderOutput`] is returned, otherwise a plain `tuple` is returned. | |
| """ | |
| latents = self.encode(sample, return_dict=False, generator=generator)[0] | |
| decoded = self.decode(latents, return_dict=False)[0] | |
| if not return_dict: | |
| return (decoded,) | |
| return DecoderOutput(sample=decoded) | |