from dataclasses import dataclass
from math import sqrt
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, logging
from ...utils.accelerate_utils import apply_forward_hook
from ...utils.import_utils import is_transformers_available
from ...utils.torch_utils import randn_tensor


if is_transformers_available():
    from transformers import (
        Dinov2WithRegistersConfig,
        Dinov2WithRegistersModel,
        SiglipVisionConfig,
        SiglipVisionModel,
        ViTMAEConfig,
        ViTMAEModel,
    )

from ..activations import get_activation
from ..attention import AttentionMixin
from ..attention_processor import Attention
from ..embeddings import get_2d_sincos_pos_embed
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput


logger = logging.get_logger(__name__)


def _dinov2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
    outputs = model(images, output_hidden_states=True)
    # Drop the 1 CLS token and the 4 register tokens; keep only patch tokens.
    unused_token_num = 5
    return outputs.last_hidden_state[:, unused_token_num:]
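
# Scale sketch (illustrative numbers, not tied to any released checkpoint): a
# 224x224 input at patch size 14 gives 16 * 16 = 256 patch tokens plus 1 CLS and
# 4 register tokens, so the function above returns (batch_size, 256, hidden_size).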


def _siglip2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
    outputs = model(images, output_hidden_states=True, interpolate_pos_encoding=True)
    return outputs.last_hidden_state


def _mae_encoder_forward(model: nn.Module, images: torch.Tensor, patch_size: int) -> torch.Tensor:
    h, w = images.shape[2], images.shape[3]
    if h % patch_size != 0 or w % patch_size != 0:
        raise ValueError("Image size should be divisible by patch size.")
    patch_num = (h // patch_size) * (w // patch_size)
    # Deterministic "noise": ViTMAE keeps patches in ascending noise order, so an
    # ascending arange is an identity shuffle; with mask_ratio=0.0 (set in
    # `_build_encoder`) no patches are dropped.
    noise = torch.arange(patch_num).unsqueeze(0).expand(images.shape[0], -1).to(device=images.device, dtype=images.dtype)
    outputs = model(images, noise, interpolate_pos_encoding=True)
    # Drop the CLS token; keep only patch tokens.
    return outputs.last_hidden_state[:, 1:]
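
# Scale sketch (illustrative): 224x224 inputs at patch size 16 give 14 * 14 = 196
# patch tokens, i.e. a (batch_size, 196, hidden_size) tensor once the CLS token
# is dropped.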


def _build_encoder(
    encoder_type: str, hidden_size: int, patch_size: int, num_hidden_layers: int, head_dim: int = 64
) -> nn.Module:
    """Build a frozen encoder from config (no pretrained download)."""
    num_attention_heads = hidden_size // head_dim

    if encoder_type == "dinov2":
        config = Dinov2WithRegistersConfig(
            hidden_size=hidden_size,
            patch_size=patch_size,
            image_size=518,
            num_attention_heads=num_attention_heads,
            num_hidden_layers=num_hidden_layers,
        )
        model = Dinov2WithRegistersModel(config)
        # Remove the final LayerNorm's learned affine parameters (F.layer_norm still
        # normalizes when weight/bias are None); this keeps the parameter set aligned
        # with RAE checkpoints.
        model.layernorm.weight = None
        model.layernorm.bias = None
    elif encoder_type == "siglip2":
        config = SiglipVisionConfig(
            hidden_size=hidden_size,
            patch_size=patch_size,
            image_size=256,
            num_attention_heads=num_attention_heads,
            num_hidden_layers=num_hidden_layers,
        )
        model = SiglipVisionModel(config)
        # Same affine removal for the SigLIP post-layernorm.
        model.vision_model.post_layernorm.weight = None
        model.vision_model.post_layernorm.bias = None
    elif encoder_type == "mae":
        config = ViTMAEConfig(
            hidden_size=hidden_size,
            patch_size=patch_size,
            image_size=224,
            num_attention_heads=num_attention_heads,
            num_hidden_layers=num_hidden_layers,
            mask_ratio=0.0,
        )
        model = ViTMAEModel(config)
        model.layernorm.weight = None
        model.layernorm.bias = None
    else:
        raise ValueError(f"Unknown encoder_type='{encoder_type}'. Available: dinov2, siglip2, mae")

    model.requires_grad_(False)
    return model
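
# Usage sketch (illustrative sizes, not tied to any released checkpoint):
#
#     encoder = _build_encoder("dinov2", hidden_size=768, patch_size=14, num_hidden_layers=12)
#     assert not any(p.requires_grad for p in encoder.parameters())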


_ENCODER_FORWARD_FNS = {
    "dinov2": _dinov2_encoder_forward,
    "siglip2": _siglip2_encoder_forward,
    "mae": _mae_encoder_forward,
}


@dataclass
class RAEDecoderOutput(BaseOutput):
    """
    Output of `RAEDecoder`.

    Args:
        logits (`torch.Tensor`):
            Patch reconstruction logits of shape `(batch_size, num_patches, patch_size**2 * num_channels)`.
    """

    logits: torch.Tensor


class ViTMAEIntermediate(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str = "gelu"):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)
        self.intermediate_act_fn = get_activation(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class ViTMAEOutput(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, hidden_dropout_prob: float = 0.0):
        super().__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = hidden_states + input_tensor
        return hidden_states


class ViTMAELayer(nn.Module):
    """
    Transformer block matching the naming/parameter structure of the ViTMAE decoder block in RAE-main.
    """

    def __init__(
        self,
        *,
        hidden_size: int,
        num_attention_heads: int,
        intermediate_size: int,
        qkv_bias: bool = True,
        layer_norm_eps: float = 1e-12,
        hidden_dropout_prob: float = 0.0,
        attention_probs_dropout_prob: float = 0.0,
        hidden_act: str = "gelu",
    ):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_attention_heads}"
            )
        self.attention = Attention(
            query_dim=hidden_size,
            heads=num_attention_heads,
            dim_head=hidden_size // num_attention_heads,
            dropout=attention_probs_dropout_prob,
            bias=qkv_bias,
        )
        self.intermediate = ViTMAEIntermediate(
            hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act
        )
        self.output = ViTMAEOutput(
            hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_dropout_prob=hidden_dropout_prob
        )
        self.layernorm_before = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Pre-norm self-attention with a residual connection.
        attention_output = self.attention(self.layernorm_before(hidden_states))
        hidden_states = attention_output + hidden_states

        # Pre-norm MLP; `self.output` adds the residual internally.
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)
        layer_output = self.output(layer_output, hidden_states)
        return layer_output


class RAEDecoder(nn.Module):
    """
    Decoder implementation ported from RAE-main to keep checkpoint compatibility.

    Key attributes (must match checkpoint keys):
    - decoder_embed
    - decoder_pos_embed
    - decoder_layers
    - decoder_norm
    - decoder_pred
    - trainable_cls_token
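
    Example (shape sketch with the defaults below; weights are randomly initialized):

        >>> import torch
        >>> decoder = RAEDecoder()
        >>> logits = decoder(torch.randn(2, 256, 768)).logits  # 16**2 * 3 = 768 outputs per patch
        >>> decoder.unpatchify(logits).shape
        torch.Size([2, 3, 256, 256])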
| """ |
|
|
| def __init__( |
| self, |
| hidden_size: int = 768, |
| decoder_hidden_size: int = 512, |
| decoder_num_hidden_layers: int = 8, |
| decoder_num_attention_heads: int = 16, |
| decoder_intermediate_size: int = 2048, |
| num_patches: int = 256, |
| patch_size: int = 16, |
| num_channels: int = 3, |
| image_size: int = 256, |
| qkv_bias: bool = True, |
| layer_norm_eps: float = 1e-12, |
| hidden_dropout_prob: float = 0.0, |
| attention_probs_dropout_prob: float = 0.0, |
| hidden_act: str = "gelu", |
| ): |
| super().__init__() |
| self.decoder_hidden_size = decoder_hidden_size |
| self.patch_size = patch_size |
| self.num_channels = num_channels |
| self.image_size = image_size |
| self.num_patches = num_patches |
|
|
| self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size, bias=True) |
| grid_size = int(num_patches**0.5) |
| pos_embed = get_2d_sincos_pos_embed( |
| decoder_hidden_size, grid_size, cls_token=True, extra_tokens=1, output_type="pt" |
| ) |
| self.register_buffer("decoder_pos_embed", pos_embed.unsqueeze(0).float(), persistent=False) |
|
|
| self.decoder_layers = nn.ModuleList( |
| [ |
| ViTMAELayer( |
| hidden_size=decoder_hidden_size, |
| num_attention_heads=decoder_num_attention_heads, |
| intermediate_size=decoder_intermediate_size, |
| qkv_bias=qkv_bias, |
| layer_norm_eps=layer_norm_eps, |
| hidden_dropout_prob=hidden_dropout_prob, |
| attention_probs_dropout_prob=attention_probs_dropout_prob, |
| hidden_act=hidden_act, |
| ) |
| for _ in range(decoder_num_hidden_layers) |
| ] |
| ) |
|
|
| self.decoder_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps) |
| self.decoder_pred = nn.Linear(decoder_hidden_size, patch_size**2 * num_channels, bias=True) |
| self.gradient_checkpointing = False |
|
|
| self.trainable_cls_token = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size)) |

    def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
        # Interpolate the (non-CLS) positional embeddings along the token axis to
        # match the current sequence length.
        embeddings_positions = embeddings.shape[1] - 1
        num_positions = self.decoder_pos_embed.shape[1] - 1

        class_pos_embed = self.decoder_pos_embed[:, 0, :]
        patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
        dim = self.decoder_pos_embed.shape[-1]

        patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim).permute(0, 3, 1, 2)
        patch_pos_embed = F.interpolate(
            patch_pos_embed,
            scale_factor=(1, embeddings_positions / num_positions),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def interpolate_latent(self, x: torch.Tensor) -> torch.Tensor:
        # Bilinearly resize a square token grid to the `num_patches` grid the decoder
        # was built for (e.g. 37x37 DINOv2 tokens at 518px -> the 16x16 decoder grid).
        b, l, c = x.shape
        if l == self.num_patches:
            return x
        h = w = int(l**0.5)
        if h * w != l:
            raise ValueError(f"Token length l={l} is not a perfect square; cannot interpolate.")
        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
        target_size = (int(self.num_patches**0.5), int(self.num_patches**0.5))
        x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
        x = x.permute(0, 2, 3, 1).contiguous().view(b, self.num_patches, c)
        return x

    def unpatchify(
        self, patchified_pixel_values: torch.Tensor, original_image_size: tuple[int, int] | None = None
    ) -> torch.Tensor:
        patch_size, num_channels = self.patch_size, self.num_channels
        original_image_size = (
            original_image_size if original_image_size is not None else (self.image_size, self.image_size)
        )
        original_height, original_width = original_image_size
        num_patches_h = original_height // patch_size
        num_patches_w = original_width // patch_size
        if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
            raise ValueError(
                f"The number of patches in the patchified pixel values ({patchified_pixel_values.shape[1]}) does not "
                f"match the number of patches implied by the original image size ({num_patches_h}*{num_patches_w})."
            )

        batch_size = patchified_pixel_values.shape[0]
        patchified_pixel_values = patchified_pixel_values.reshape(
            batch_size,
            num_patches_h,
            num_patches_w,
            patch_size,
            patch_size,
            num_channels,
        )
        # (N, h, w, p, q, C) -> (N, C, h, p, w, q): interleave the patch grid with per-patch pixels.
        patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
        pixel_values = patchified_pixel_values.reshape(
            batch_size,
            num_channels,
            num_patches_h * patch_size,
            num_patches_w * patch_size,
        )
        return pixel_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        interpolate_pos_encoding: bool = False,
        drop_cls_token: bool = False,
        return_dict: bool = True,
    ) -> RAEDecoderOutput | tuple[torch.Tensor]:
        x = self.decoder_embed(hidden_states)
        if drop_cls_token:
            # Discard the encoder's CLS token before resizing the patch grid.
            x_ = x[:, 1:, :]
            x_ = self.interpolate_latent(x_)
        else:
            x_ = self.interpolate_latent(x)

        # Prepend the decoder's own trainable CLS token.
        cls_token = self.trainable_cls_token.expand(x_.shape[0], -1, -1)
        x = torch.cat([cls_token, x_], dim=1)

        if interpolate_pos_encoding:
            if not drop_cls_token:
                raise ValueError("interpolate_pos_encoding only supports drop_cls_token=True")
            decoder_pos_embed = self.interpolate_pos_encoding(x)
        else:
            decoder_pos_embed = self.decoder_pos_embed

        hidden_states = x + decoder_pos_embed.to(device=x.device, dtype=x.dtype)

        for layer_module in self.decoder_layers:
            hidden_states = layer_module(hidden_states)

        hidden_states = self.decoder_norm(hidden_states)
        logits = self.decoder_pred(hidden_states)
        # Drop the CLS position; only patch logits are reconstructed.
        logits = logits[:, 1:, :]

        if not return_dict:
            return (logits,)
        return RAEDecoderOutput(logits=logits)


class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
    r"""
    Representation Autoencoder (RAE) model for encoding images to latents and decoding latents to images.

    This model uses a frozen pretrained encoder (DINOv2, SigLIP2, or MAE) together with a trainable ViT decoder to
    reconstruct images from the learned representations.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Args:
        encoder_type (`str`, *optional*, defaults to `"dinov2"`):
            Type of frozen encoder to use. One of `"dinov2"`, `"siglip2"`, or `"mae"`.
        encoder_hidden_size (`int`, *optional*, defaults to `768`):
            Hidden size of the encoder model.
        encoder_patch_size (`int`, *optional*, defaults to `14`):
            Patch size of the encoder model.
        encoder_num_hidden_layers (`int`, *optional*, defaults to `12`):
            Number of hidden layers in the encoder model.
        decoder_hidden_size (`int`, *optional*, defaults to `512`):
            Hidden size of the decoder.
        decoder_num_hidden_layers (`int`, *optional*, defaults to `8`):
            Number of transformer blocks in the decoder.
        decoder_num_attention_heads (`int`, *optional*, defaults to `16`):
            Number of attention heads in each decoder block.
        decoder_intermediate_size (`int`, *optional*, defaults to `2048`):
            Intermediate (MLP) size of each decoder block.
        patch_size (`int`, *optional*, defaults to `16`):
            Decoder patch size (used for unpatchify and the decoder prediction head).
        encoder_input_size (`int`, *optional*, defaults to `224`):
            Input size expected by the encoder.
        image_size (`int`, *optional*):
            Decoder output image size. If `None`, it is derived from the encoder token count and `patch_size` as in
            RAE-main: `image_size = patch_size * sqrt(num_patches)`, where `num_patches = (encoder_input_size //
            encoder_patch_size) ** 2`.
        num_channels (`int`, *optional*, defaults to `3`):
            Number of input/output channels.
        encoder_norm_mean (`list`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
            Channel-wise mean for encoder input normalization (ImageNet defaults).
        encoder_norm_std (`list`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
            Channel-wise standard deviation for encoder input normalization (ImageNet defaults).
        latents_mean (`list` or `tuple`, *optional*):
            Optional mean for latent normalization. Tensor inputs are accepted and converted to config-serializable
            lists.
        latents_std (`list` or `tuple`, *optional*):
            Optional standard deviation for latent normalization. Tensor inputs are accepted and converted to
            config-serializable lists.
        noise_tau (`float`, *optional*, defaults to `0.0`):
            Maximum noise level added to latents during training; `0.0` disables noising.
        reshape_to_2d (`bool`, *optional*, defaults to `True`):
            Whether to reshape latents to 2D `(B, C, H, W)` format.
        use_encoder_loss (`bool`, *optional*, defaults to `False`):
            Whether to use encoder hidden states in the loss (for advanced training).
        scaling_factor (`float`, *optional*, defaults to `1.0`):
            Multiplier applied to latents after encoding and divided out before decoding.
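
    Example (minimal round-trip sketch; weights are randomly initialized here, and real use is assumed to load a
    checkpoint through the usual `ModelMixin.from_pretrained` API):

        >>> import torch
        >>> model = AutoencoderRAE()
        >>> images = torch.rand(2, 3, 224, 224)
        >>> with torch.no_grad():
        ...     latents = model.encode(images).latent  # (2, 768, 16, 16)
        ...     reconstruction = model.decode(latents).sample  # (2, 3, 256, 256)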
| """ |
|
|
| |
| _supports_gradient_checkpointing = False |
| _no_split_modules = ["ViTMAELayer"] |
| _keys_to_ignore_on_load_unexpected = ["decoder.decoder_pos_embed"] |

    @register_to_config
    def __init__(
        self,
        encoder_type: str = "dinov2",
        encoder_hidden_size: int = 768,
        encoder_patch_size: int = 14,
        encoder_num_hidden_layers: int = 12,
        decoder_hidden_size: int = 512,
        decoder_num_hidden_layers: int = 8,
        decoder_num_attention_heads: int = 16,
        decoder_intermediate_size: int = 2048,
        patch_size: int = 16,
        encoder_input_size: int = 224,
        image_size: int | None = None,
        num_channels: int = 3,
        encoder_norm_mean: list | None = None,
        encoder_norm_std: list | None = None,
        latents_mean: list | tuple | torch.Tensor | None = None,
        latents_std: list | tuple | torch.Tensor | None = None,
        noise_tau: float = 0.0,
        reshape_to_2d: bool = True,
        use_encoder_loss: bool = False,
        scaling_factor: float = 1.0,
    ):
        super().__init__()

        if encoder_type not in _ENCODER_FORWARD_FNS:
            raise ValueError(
                f"Unknown encoder_type='{encoder_type}'. Available: {sorted(_ENCODER_FORWARD_FNS.keys())}"
            )

        def _to_config_compatible(value: Any) -> Any:
            if isinstance(value, torch.Tensor):
                return value.detach().cpu().tolist()
            if isinstance(value, (list, tuple)):
                return [_to_config_compatible(v) for v in value]
            return value

        def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tensor | None:
            if value is None:
                return None
            if isinstance(value, torch.Tensor):
                return value.detach().clone()
            return torch.tensor(value, dtype=torch.float32)

        latents_std_tensor = _as_optional_tensor(latents_std)

        # Store config-serializable copies (tensors are converted to lists).
        self.register_to_config(
            latents_mean=_to_config_compatible(latents_mean),
            latents_std=_to_config_compatible(latents_std),
        )

        self.encoder_input_size = encoder_input_size
        self.noise_tau = float(noise_tau)
        self.reshape_to_2d = bool(reshape_to_2d)
        self.use_encoder_loss = bool(use_encoder_loss)

        # Validate encoder/decoder patch sizes.
        encoder_patch_size = int(encoder_patch_size)
        if self.encoder_input_size % encoder_patch_size != 0:
            raise ValueError(
                f"encoder_input_size={self.encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}."
            )
        decoder_patch_size = int(patch_size)
        if decoder_patch_size <= 0:
            raise ValueError("patch_size must be a positive integer (this is the decoder patch size).")

        # Frozen encoder built from config; weights are expected to come from the loaded checkpoint.
        self.encoder: nn.Module = _build_encoder(
            encoder_type=encoder_type,
            hidden_size=encoder_hidden_size,
            patch_size=encoder_patch_size,
            num_hidden_layers=encoder_num_hidden_layers,
        )
        self._encoder_forward_fn = _ENCODER_FORWARD_FNS[encoder_type]
        num_patches = (self.encoder_input_size // encoder_patch_size) ** 2

        grid = int(sqrt(num_patches))
        if grid * grid != num_patches:
            raise ValueError(f"Computed num_patches={num_patches} must be a perfect square.")

        derived_image_size = decoder_patch_size * grid
        if image_size is None:
            image_size = derived_image_size
        else:
            image_size = int(image_size)
            if image_size != derived_image_size:
                raise ValueError(
                    f"image_size={image_size} must equal decoder_patch_size*sqrt(num_patches)={derived_image_size} "
                    f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}."
                )

        # Encoder input normalization statistics (ImageNet defaults).
        if encoder_norm_mean is None:
            encoder_norm_mean = [0.485, 0.456, 0.406]
        if encoder_norm_std is None:
            encoder_norm_std = [0.229, 0.224, 0.225]
        encoder_mean_tensor = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1)
        encoder_std_tensor = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1)

        self.register_buffer("encoder_mean", encoder_mean_tensor, persistent=True)
        self.register_buffer("encoder_std", encoder_std_tensor, persistent=True)

        # Latent normalization buffers; identity (zero mean, unit std) if not provided.
        latents_mean_tensor = _as_optional_tensor(latents_mean)
        if latents_mean_tensor is None:
            latents_mean_tensor = torch.zeros(1)
        self.register_buffer("_latents_mean", latents_mean_tensor, persistent=True)

        if latents_std_tensor is None:
            latents_std_tensor = torch.ones(1)
        self.register_buffer("_latents_std", latents_std_tensor, persistent=True)

        # Trainable ViT decoder (the encoder above stays frozen).
        self.decoder = RAEDecoder(
            hidden_size=int(encoder_hidden_size),
            decoder_hidden_size=int(decoder_hidden_size),
            decoder_num_hidden_layers=int(decoder_num_hidden_layers),
            decoder_num_attention_heads=int(decoder_num_attention_heads),
            decoder_intermediate_size=int(decoder_intermediate_size),
            num_patches=int(num_patches),
            patch_size=int(decoder_patch_size),
            num_channels=int(num_channels),
            image_size=int(image_size),
        )
        self.num_patches = int(num_patches)
        self.decoder_patch_size = int(decoder_patch_size)
        self.decoder_image_size = int(image_size)

        # When slicing is enabled, encode/decode process batches one sample at a time.
        self.use_slicing = False

    def _noising(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
        # Per-sample noise scale sigma ~ U(0, noise_tau), broadcast over all non-batch dims.
        noise_sigma = self.noise_tau * torch.rand(
            (x.size(0),) + (1,) * (x.ndim - 1), device=x.device, dtype=x.dtype, generator=generator
        )
        return x + noise_sigma * randn_tensor(x.shape, generator=generator, device=x.device, dtype=x.dtype)

    def _resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor:
        _, _, h, w = x.shape
        if h != self.encoder_input_size or w != self.encoder_input_size:
            x = F.interpolate(
                x, size=(self.encoder_input_size, self.encoder_input_size), mode="bicubic", align_corners=False
            )
        mean = self.encoder_mean.to(device=x.device, dtype=x.dtype)
        std = self.encoder_std.to(device=x.device, dtype=x.dtype)
        return (x - mean) / std

    def _denormalize_image(self, x: torch.Tensor) -> torch.Tensor:
        mean = self.encoder_mean.to(device=x.device, dtype=x.dtype)
        std = self.encoder_std.to(device=x.device, dtype=x.dtype)
        return x * std + mean

    def _normalize_latents(self, z: torch.Tensor) -> torch.Tensor:
        latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype)
        latents_std = self._latents_std.to(device=z.device, dtype=z.dtype)
        return (z - latents_mean) / (latents_std + 1e-5)

    def _denormalize_latents(self, z: torch.Tensor) -> torch.Tensor:
        latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype)
        latents_std = self._latents_std.to(device=z.device, dtype=z.dtype)
        return z * (latents_std + 1e-5) + latents_mean

    def _encode(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
        x = self._resize_and_normalize(x)

        if self.config.encoder_type == "mae":
            tokens = self._encoder_forward_fn(self.encoder, x, self.config.encoder_patch_size)
        else:
            tokens = self._encoder_forward_fn(self.encoder, x)

        if self.training and self.noise_tau > 0:
            tokens = self._noising(tokens, generator=generator)

        if self.reshape_to_2d:
            b, n, c = tokens.shape
            side = int(sqrt(n))
            if side * side != n:
                raise ValueError(f"Token length n={n} is not a perfect square; cannot reshape to 2D.")
            z = tokens.transpose(1, 2).contiguous().view(b, c, side, side)
        else:
            z = tokens

        z = self._normalize_latents(z)

        # Apply the configured scaling factor (1.0 by default, i.e. a no-op).
        if self.config.scaling_factor != 1.0:
            z = z * self.config.scaling_factor

        return z

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None
    ) -> EncoderOutput | tuple[torch.Tensor]:
        if self.use_slicing and x.shape[0] > 1:
            latents = torch.cat([self._encode(x_slice, generator=generator) for x_slice in x.split(1)], dim=0)
        else:
            latents = self._encode(x, generator=generator)

        if not return_dict:
            return (latents,)
        return EncoderOutput(latent=latents)

    def _decode(self, z: torch.Tensor) -> torch.Tensor:
        # Undo the configured scaling factor before denormalizing.
        if self.config.scaling_factor != 1.0:
            z = z / self.config.scaling_factor

        z = self._denormalize_latents(z)

        if self.reshape_to_2d:
            b, c, h, w = z.shape
            tokens = z.view(b, c, h * w).transpose(1, 2).contiguous()
        else:
            tokens = z

        logits = self.decoder(tokens, return_dict=True).logits
        x_rec = self.decoder.unpatchify(logits)
        x_rec = self._denormalize_image(x_rec)
        return x_rec.to(device=z.device)

    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
        if self.use_slicing and z.shape[0] > 1:
            decoded = torch.cat([self._decode(z_slice) for z_slice in z.split(1)], dim=0)
        else:
            decoded = self._decode(z)

        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    def forward(
        self, sample: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None
    ) -> DecoderOutput | tuple[torch.Tensor]:
        latents = self.encode(sample, return_dict=False, generator=generator)[0]
        decoded = self.decode(latents, return_dict=False)[0]
        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)