# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
RADIO1D: Vision Transformer with Variable-Length 1D Token Compression

This module implements RADIO1D, a Vision Transformer variant that compresses spatial tokens
into a variable-length 1D sequence of "global tokens" during encoding, then reconstructs
the full spatial resolution via a decoder.

Architecture Overview:
======================

Input Image (B, 3, H_img, W_img)  # Any size divisible by patch_size
┌───────────────────────────────────────────────────────────────┐
│                            ENCODER                            │
├───────────────────────────────────────────────────────────────┤
│  Patch Embedding → [cls, registers, patches]                  │
│  (B, num_prefix + H*W, embed_dim)                             │
│          │                                                    │
│          ▼                                                    │
│  Transformer Blocks (before downscale)                        │
│  (B, num_prefix + H*W, embed_dim)                             │
│          │                                                    │
│          ▼                                                    │
│  ┌─────────────────────────────────────┐                      │
│  │  PatchMerging (at downscale_levels) │                      │
│  │  - Halves H, W: (H, W) → (H/2, W/2) │                      │
│  │  - Doubles embed_dim: C → 2C        │                      │
│  └─────────────────────────────────────┘                      │
│          │                                                    │
│          ▼                                                    │
│  Transformer Blocks (after downscale)                         │
│  (B, num_prefix + H/2*W/2, 2*embed_dim)                       │
│          │                                                    │
│          ▼                                                    │
│  Final Norm                                                   │
└───────────────────────────────────────────────────────────────┘
┌───────────────────────────────────────────────────────────────┐
│                         TOKEN SLICING                         │
├───────────────────────────────────────────────────────────────┤
│  Sample num_tokens per sample from mode distribution          │
│  (clamped to available spatial tokens H*W)                    │
│                                                               │
│  slice_1d_tokens():                                           │
│    - prefix_tokens: (B, num_prefix, C)                        │
│    - global_tokens: (B, max_tokens, C) - sliced patches       │
│    - global_token_mask: (B, max_tokens) - validity mask       │
└───────────────────────────────────────────────────────────────┘
┌───────────────────────────────────────────────────────────────┐
│                            DECODER                            │
├───────────────────────────────────────────────────────────────┤
│  1. Pad global_tokens with filler_tokens to H*W               │
│     (filler_tokens interpolated if size differs from ref)     │
│                                                               │
│  2. Use decoder's own learnable prefix_tokens (NOT encoder's) │
│     This ensures decoder only reconstructs from 1D tokens     │
│                                                               │
│  3. Concatenate: [prefix_tokens, padded_patches]              │
│                                                               │
│  4. Decoder Blocks (before upscale)                           │
│                                                               │
│  5. ┌─────────────────────────────────────┐                   │
│     │  PatchSplitting (at upscale_levels) │                   │
│     │  - Doubles H, W: (H, W) → (2H, 2W)  │                   │
│     │  - Halves embed_dim: C → C/2        │                   │
│     └─────────────────────────────────────┘                   │
│                                                               │
│  6. Decoder Blocks (after upscale)                            │
│                                                               │
│  7. Final Norm                                                │
│     (B, num_prefix + H_out*W_out, target_embed_dim)           │
└───────────────────────────────────────────────────────────────┘
┌───────────────────────────────────────────────────────────────┐
│                            OUTPUT                             │
├───────────────────────────────────────────────────────────────┤
│  {                                                            │
│    "encoder": (B, num_prefix + max_tokens, encoder_C)         │
│    "decoder": (B, num_prefix + H_out*W_out, target_C)         │
│  }                                                            │
│                                                               │
│  Sizes depend on input image size, not fixed!                 │
└───────────────────────────────────────────────────────────────┘

Key Components:
===============

1. PatchMerging (Encoder Downscaling)
   - Merges 2x2 patches into 1: (H, W) -> (H/2, W/2)
   - Doubles embedding dimension: C -> 2C
   - Applied at specified downscale_levels (e.g., before block 19)

2. PatchSplitting (Decoder Upscaling)
   - Splits 1 patch into 2x2: (H, W) -> (2H, 2W)
   - Halves embedding dimension: C -> C/2
   - Inverse of PatchMerging

3. Token Slicing (slice_1d_tokens)
   - Takes a per-sample num_tokens, sampled upstream by the configured sampler
     (e.g., a Gaussian mixture over the modes)
   - Clamps num_tokens to available spatial tokens (important for smaller images)
   - Returns separate prefix tokens (cls + registers) and global tokens (spatial)
   - Pads to max tokens with validity mask

4. RADIO1D_Decoder
   - Uses its own learnable prefix_tokens (NOT the encoder's, to avoid information leak)
   - Uses learnable filler_tokens to pad sliced tokens back to full sequence
   - Applies transformer blocks with PatchSplitting for upscaling
   - Reconstructs original spatial resolution and embedding dimension
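
Illustrative sketch (not part of this module): the token slicing and filler padding
described above, shown on plain Python lists for a single sample. The helper names
slice_tokens and pad_with_filler and the '<pad>' placeholder are hypothetical; the real
slice_1d_tokens and RADIO1D_Decoder work on batched tensors and pad with zeros.

```python
def slice_tokens(seq, num_tokens, num_prefix, max_tokens):
    # Split one sample's sequence into prefix tokens and spatial tokens.
    prefix, spatial = seq[:num_prefix], seq[num_prefix:]
    # Clamp the request to the spatial tokens (and padded length) actually available.
    n = min(num_tokens, len(spatial), max_tokens)
    # Keep the first n spatial tokens and pad up to max_tokens.
    kept = spatial[:n] + ['<pad>'] * (max_tokens - n)
    # Validity mask: True marks a real token, False marks padding.
    mask = [i < n for i in range(max_tokens)]
    return prefix, kept, mask


def pad_with_filler(kept, mask, filler):
    # Decoder side: valid tokens overwrite the leading filler slots,
    # the remaining filler tokens complete the full-length sequence.
    n_valid = sum(mask)
    return kept[:n_valid] + filler[n_valid:]


seq = ['cls', 'reg', 'p0', 'p1', 'p2', 'p3']
prefix, kept, mask = slice_tokens(seq, num_tokens=3, num_prefix=2, max_tokens=4)
decoder_patches = pad_with_filler(kept, mask, ['f0', 'f1', 'f2', 'f3'])
# decoder_patches == ['p0', 'p1', 'p2', 'f3']
```

Note that the encoder prefix is returned separately and never fed to the decoder,
which uses its own learnable prefix tokens.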

Key Parameters:
===============

- downscale_levels: Block indices for encoder downscaling, e.g., [19]
- modes: Token count modes for sampling, e.g., [64, 128, 196]
  (set through k_sample_config for the 'multimode_gaussian' sampler)
- mode_weights: Probability weights for modes, e.g., [0.33, 0.34, 0.33]
  (must sum to 1; also set through k_sample_config)
- decoder_depth: Number of decoder blocks, e.g., 6
- decoder_upscale_levels: Block indices for decoder upscaling, e.g., [3]
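
Illustrative sketch (not part of this module): how modes and mode_weights could define a
distribution over token counts. It mirrors the Gaussian-mixture logic of
sample_multinomial_batch in this file (sigma defaults to 30.0 there), but uses only the
standard library instead of torch.multinomial. The helper names mode_mixture_probs and
sample_token_counts and the seed argument are hypothetical.

```python
import math
import random


def mode_mixture_probs(modes, mode_weights, sigma=30.0):
    # Candidate token counts span [min(modes), max(modes)] inclusive.
    values = list(range(min(modes), max(modes) + 1))
    # Unnormalized density: weighted sum of Gaussians centred on each mode.
    density = [
        sum(w * math.exp(-((v - m) ** 2) / (2 * sigma ** 2))
            for m, w in zip(modes, mode_weights))
        for v in values
    ]
    total = sum(density)
    return values, [d / total for d in density]


def sample_token_counts(modes, mode_weights, batch_size, sigma=30.0, seed=0):
    values, probs = mode_mixture_probs(modes, mode_weights, sigma)
    rng = random.Random(seed)  # local RNG so the sketch is reproducible
    return rng.choices(values, weights=probs, k=batch_size)


values, probs = mode_mixture_probs([64, 128, 196], [0.33, 0.34, 0.33])
counts = sample_token_counts([64, 128, 196], [0.33, 0.34, 0.33], batch_size=8)
```

Every sampled count lies between the smallest and largest mode, so max(modes) bounds
the padded length of the encoder output.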

Training vs Inference:
======================

- Training: Samples different num_tokens per sample from the mode distribution.
  The encoder output is padded to the max sampled value in the micro-batch.
- Inference: Uses provided num_tokens or defaults to max(modes).

Return Format:
==============

Both outputs are full sequences [cls, registers, patches/global_tokens]:
- output["encoder"]: (B, num_prefix + max_tokens, encoder_C) - compressed representation
- output["decoder"]: (B, num_prefix + output_patches, target_C) - reconstructed full resolution

This format is compatible with the RADIO1D framework's expected return structure.
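
Illustrative sketch (not part of this module): splitting one sample's returned sequence
back into its parts using the [cls, registers, patches/global_tokens] layout above.
The helper name split_sequence is hypothetical, and a plain list stands in for the
token dimension of a (B, N, C) tensor.

```python
def split_sequence(seq, num_cls, num_registers):
    # Prefix tokens come first: class tokens, then register tokens.
    num_prefix = num_cls + num_registers
    cls_tokens = seq[:num_cls]
    registers = seq[num_cls:num_prefix]
    # Everything after the prefix is spatial (patches or global tokens).
    spatial = seq[num_prefix:]
    return cls_tokens, registers, spatial


cls_tokens, registers, spatial = split_sequence(
    ['cls', 'r0', 'r1', 'g0', 'g1', 'g2'], num_cls=1, num_registers=2
)
# spatial == ['g0', 'g1', 'g2']
```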

Variable Image Size Support:
============================

The model supports any image size that is a multiple of the patch size in each dimension.
The decoder uses bilinear interpolation of filler_tokens to adapt to different input sizes.

Example sizes (with patch_size=16, downscale_levels=[19]):
- 224x224 -> 14x14 patches -> 7x7 after downscale -> 14x14 after decode = 196 output patches
- 448x448 -> 28x28 patches -> 14x14 after downscale -> 28x28 after decode = 784 output patches
- 320x384 -> 20x24 patches -> 10x12 after downscale -> 20x24 after decode = 480 output patches
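
Illustrative sketch (not part of this module): the grid arithmetic behind the example
sizes above. The helper name grid_sizes is hypothetical, and it assumes the decoder
applies as many 2x upscales as the encoder applies 2x downscales.

```python
def grid_sizes(img_h, img_w, patch_size=16, num_downscales=1):
    if img_h % patch_size or img_w % patch_size:
        raise ValueError('image size must be a multiple of patch_size')
    # Patch grid after patch embedding.
    h, w = img_h // patch_size, img_w // patch_size
    patches = (h, w)
    # Each PatchMerging halves the grid; odd sizes are padded first, hence the ceil.
    for _ in range(num_downscales):
        h, w = (h + 1) // 2, (w + 1) // 2
    downscaled = (h, w)
    # Each PatchSplitting doubles the grid again.
    for _ in range(num_downscales):
        h, w = 2 * h, 2 * w
    return patches, downscaled, (h, w), h * w


print(grid_sizes(224, 224))  # ((14, 14), (7, 7), (14, 14), 196)
print(grid_sizes(320, 384))  # ((20, 24), (10, 12), (20, 24), 480)
```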
"""
from abc import ABC
from copy import deepcopy
from functools import partial
from logging import getLogger
import math
from typing import (
Callable,
Dict,
Final,
List,
Literal,
Optional,
Tuple,
Type,
Union,
)
import torch
import torch.distributed as dist
from torch import nn
from torch.nn.init import xavier_normal_
from torch.nn import ModuleList
from torch.nn import functional as F
from torch.distributed.nn.functional import all_reduce as all_reduce_with_gradients
from torch.utils.checkpoint import checkpoint
from timm.models import register_model, build_model_with_cfg
from timm.models.vision_transformer import (
VisionTransformer,
Mlp,
Attention,
Block,
)
from timm.layers import (
AttentionPoolLatent,
PatchEmbed,
PatchDropout,
LayerNorm,
trunc_normal_,
get_norm_layer,
get_act_layer,
to_2tuple,
)
from .vit_patch_generator import ViTPatchGenerator
from .utils import get_rank
logger = getLogger(__name__)


def round_ste(x: torch.Tensor) -> torch.Tensor:
    """Straight-through estimator for the rounding operation."""
    x_hat = x.detach().round()
    return x + (x_hat - x).detach()


def const_ste(x: torch.Tensor, c: float) -> torch.Tensor:
    """Straight-through estimator that returns a constant `c` with the same shape as `x`,
    while routing gradients through `x` (identity)."""
    return c - x.detach() + x

# Type definitions
LayerType = Union[str, Callable, Type[nn.Module]]


class PatchMerging(nn.Module):
    """Patch Merging Layer.

    Downsample features by merging 2x2 neighboring patches.
    """

    def __init__(
        self,
        dim: int,
        out_dim: Optional[int] = None,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        size: Union[int, Tuple[int, int]] = 2,
        device=None,
        dtype=None,
    ):
        """
        Args:
            dim: Number of input channels.
            out_dim: Number of output channels (or 2 * dim if None).
            norm_layer: Normalization layer.
            size: Merge window as (height, width); an int is used for both. Defaults to 2.
            device: Device for the layer parameters.
            dtype: Dtype for the layer parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim or 2 * dim
        self.size: Tuple[int, int] = to_2tuple(size)
        self.downscale = math.prod(self.size)
        self.norm = norm_layer(self.downscale * dim, **dd)
        self.reduction = nn.Linear(self.downscale * dim, self.out_dim, bias=False, **dd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Input features with shape (B, H, W, C).

        Returns:
            Output features with shape (B, H//2, W//2, out_dim).
        """
        B, H, W, C = x.shape
        # Pad W and H up to the next multiple of the merge window.
        # F.pad takes (left, right) pairs starting from the last dim: C, then W, then H.
        # `(-W) % s` equals `W % s` for s == 2 and stays correct for larger windows.
        pad_values = (0, 0, 0, (-W) % self.size[1], 0, (-H) % self.size[0])
        x = nn.functional.pad(x, pad_values)
        _, H, W, _ = x.shape

        x = x.reshape(B, H // self.size[0], self.size[0], W // self.size[1], self.size[1], C).permute(0, 1, 3, 4, 2, 5).flatten(3)
        x = self.norm(x)
        x = self.reduction(x)
        return x


def sample_multinomial_batch(
    modes: List[int],
    weights: List[float],
    batch_size: int,
    sigma: float = 30.0,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Sample token counts for each sample in a batch.

    Uses torch.multinomial for sampling, ensuring reproducibility with torch.manual_seed().

    Args:
        modes: List of mode values (e.g., [128, 256, 512])
        weights: List of weights for each mode
        batch_size: Number of samples to generate
        sigma: Standard deviation for Gaussian mixture
        generator: Optional torch.Generator passed to torch.multinomial

    Returns:
        Tensor of shape (batch_size,) with sampled token counts
    """
    min_val = min(modes)
    max_val = max(modes)

    # Create values tensor
    values = torch.arange(min_val, max_val + 1, dtype=torch.long)

    # Compute the probability density using a mixture of Gaussians
    modes_t = torch.tensor(modes, dtype=torch.float32)
    weights_t = torch.tensor(weights, dtype=torch.float32)

    # values: (num_values,), modes_t: (num_modes,) -> broadcast to (num_values, num_modes)
    diff = values.unsqueeze(1).float() - modes_t.unsqueeze(0)  # (num_values, num_modes)
    gaussian = torch.exp(-diff.pow(2) / (2 * sigma ** 2))  # (num_values, num_modes)
    probs = (gaussian * weights_t.unsqueeze(0)).sum(dim=1)  # (num_values,)
    probs = probs / probs.sum()

    # Sample indices using torch.multinomial
    sampled_indices = torch.multinomial(probs, batch_size, replacement=True, generator=generator)

    # Map indices to actual token counts
    sampled_values = values[sampled_indices]
    return sampled_values


class GradScale(torch.autograd.Function):
    """Identity in the forward pass; scales the incoming gradient by `lambda_` in the backward pass."""

    @staticmethod
    def forward(ctx: torch.autograd.function.FunctionCtx, x: torch.Tensor, lambda_: Union[float, torch.Tensor] = 1.0):
        lambda_ = torch.as_tensor(lambda_, dtype=x.dtype, device=x.device)
        ctx.save_for_backward(lambda_)
        return x.view_as(x)

    @staticmethod
    def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor):
        lambda_, = ctx.saved_tensors
        # No gradient is returned for lambda_ itself.
        return lambda_ * grad_output, None


def slice_1d_tokens(
    x: torch.Tensor,
    num_tokens: torch.Tensor,
    num_prefix_tokens: int,
    max_tokens: Optional[int] = None,
    use_last_tokens: bool = False,
    dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Slice variable numbers of 1D tokens per sample.

    Args:
        x: Input tensor of shape (B, N, C) where N = num_prefix + num_spatial
        num_tokens: Tensor of shape (B,) with number of tokens to keep per sample
        num_prefix_tokens: Number of prefix tokens (cls, registers) to always keep
        max_tokens: Maximum number of global tokens (for padding). If None, uses max(num_tokens)
        use_last_tokens: If True, take the last num_tokens instead of the first
        dynamic: If True, multiply the spatial tokens by a straight-through 0/1 gate
            so that gradients are routed through num_tokens

    Returns:
        Tuple of (prefix_tokens, global_tokens, global_token_mask):
        - prefix_tokens: (B, num_prefix, C) prefix tokens (cls, registers)
        - global_tokens: (B, max_tokens, C) padded 1D global tokens
        - global_token_mask: (B, max_tokens) boolean mask (True = valid token)
    """
    B, N, C = x.shape
    device = x.device
    num_spatial = N - num_prefix_tokens

    # Ensure num_tokens is on the right device and clamp to available spatial tokens
    num_tokens = num_tokens.to(device)
    num_tokens = num_tokens.clamp(max=num_spatial)

    if max_tokens is None:
        max_tokens = int(num_tokens.max().item())

    # Separate prefix and global tokens
    prefix = x[:, :num_prefix_tokens]  # (B, num_prefix, C)
    global_feats = x[:, num_prefix_tokens:]  # (B, num_spatial, C)

    # Create output tensor with padding for global tokens
    global_tokens = torch.zeros(B, max_tokens, C, device=device, dtype=x.dtype)

    # Create global token mask over all spatial positions; it is cut to max_tokens on return
    token_indices = torch.arange(global_feats.shape[1], device=device).unsqueeze(0)  # (1, num_spatial)
    global_token_mask = token_indices < num_tokens.unsqueeze(1)  # (B, num_spatial)

    if dynamic:
        zero_ste = const_ste(num_tokens, 0.0)
        one_ste = const_ste(num_tokens, 1.0)
        where_ste = torch.where(global_token_mask, one_ste.unsqueeze(1), zero_ste.unsqueeze(1))
        # Scaling the gradient through this gate by `1 / num_tokens` stabilizes training.
        # The gate multiplies every surviving token, so the scaling keeps the gradient
        # signal consistent across different token counts.
        where_ste = GradScale.apply(where_ste, 1 / num_tokens.clamp_min(1).unsqueeze(-1))
        global_feats = global_feats * where_ste.unsqueeze(-1)

    cpu_num_tokens = num_tokens.tolist()

    # Copy tokens for each sample (clamped to available)
    for i in range(B):
        n = int(cpu_num_tokens[i])
        if n == 0:
            # Nothing to copy; this also avoids the `-0:` slice, which selects every token
            continue
        if use_last_tokens:
            # Take the last n tokens from the spatial sequence
            global_tokens[i, :n] = global_feats[i, -n:]
        else:
            # Take the first n tokens from the spatial sequence
            global_tokens[i, :n] = global_feats[i, :n]

    return prefix, global_tokens, global_token_mask[:, :max_tokens]


class PatchSplitting(nn.Module):
    """Patch Splitting Layer - Inverse of PatchMerging.

    Upsample features by splitting each patch into 2x2 neighboring patches.
    """

    def __init__(
        self,
        dim: int,
        out_dim: Optional[int] = None,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
    ):
        """
        Args:
            dim: Number of input channels.
            out_dim: Number of output channels (or dim // 2 if None)
            norm_layer: Normalization layer.
        """
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim or dim // 2

        # Expand channels to 4x output dim, then reshape to 2x2 spatial
        self.expansion = nn.Linear(dim, 4 * self.out_dim, bias=False)
        self.norm = norm_layer(self.out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Input features with shape (B, H, W, C).

        Returns:
            Output features with shape (B, 2*H, 2*W, out_dim).
        """
        B, H, W, C = x.shape

        # Expand channels: (B, H, W, C) -> (B, H, W, 4 * out_dim)
        x = self.expansion(x)

        # Reshape to split each token into 2x2 neighbors
        # (B, H, W, 4 * out_dim) -> (B, H, W, 2, 2, out_dim) -> (B, 2*H, 2*W, out_dim)
        x = x.reshape(B, H, W, 2, 2, self.out_dim)
        x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, 2 * H, 2 * W, self.out_dim)

        x = self.norm(x)
        return x


class RADIO1D_Decoder(nn.Module):
    """Decoder for RADIO1D that reconstructs the original sequence length and embedding dimension.

    Takes compressed global tokens from the encoder and reconstructs the full spatial resolution
    by applying inverse patch merging (splitting) operations.
    """

    def __init__(
        self,
        input_embed_dim: int,
        target_embed_dim: int,
        ref_spatial_size: Tuple[int, int],
        num_prefix_tokens: int,
        depth: int,
        upscale_levels: List[int],
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
    ):
        """
        Args:
            input_embed_dim: Embedding dimension of input (after encoder downscaling).
            target_embed_dim: Target embedding dimension (original encoder dimension).
            ref_spatial_size: Reference spatial size (H, W) for filler token initialization.
                This is the expected spatial dimensions at the model's nominal image size
                (e.g., (7, 7) for 224x224 with patch_size=16 and one downscale).
                At runtime, filler tokens are interpolated to match actual input size.
            num_prefix_tokens: Number of prefix tokens (cls + registers).
            depth: Number of decoder blocks.
            upscale_levels: List of block indices where upscaling should happen.
            num_heads: Number of attention heads.
            mlp_ratio: MLP ratio for transformer blocks.
            norm_layer: Normalization layer.
        """
        super().__init__()
        self.input_embed_dim = input_embed_dim
        self.target_embed_dim = target_embed_dim
        self.ref_H, self.ref_W = ref_spatial_size
        self.num_prefix_tokens = num_prefix_tokens
        self.upscale_levels = set(upscale_levels) if upscale_levels else set()

        # Learnable filler tokens - initialized at reference size, interpolated at runtime if needed
        ref_num_patches = self.ref_H * self.ref_W
        scale = input_embed_dim ** -0.5
        self.filler_tokens = nn.Parameter(torch.randn(ref_num_patches, input_embed_dim) * scale)

        # Learnable prefix tokens for the decoder (independent from encoder's prefix tokens)
        # This ensures the decoder only reconstructs from 1D global tokens, not encoder prefix info
        self.prefix_tokens = nn.Parameter(torch.randn(num_prefix_tokens, input_embed_dim) * scale)

        # Build blocks and upscale layers
        embed_dim = input_embed_dim
        blocks = []
        upscale_blocks = []
        prefix_proj_blocks = []

        for i in range(depth):
            if upscale_levels is not None and i in upscale_levels:
                upscale_block = PatchSplitting(embed_dim)
                # Projection for prefix tokens to match new (reduced) embed_dim
                prefix_proj = nn.Linear(embed_dim, upscale_block.out_dim, bias=False)
                num_heads = max(1, num_heads * upscale_block.out_dim // embed_dim)
                embed_dim = upscale_block.out_dim
                upscale_blocks.append(upscale_block)
                prefix_proj_blocks.append(prefix_proj)

            blocks.append(Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
            ))

        self.blocks = nn.ModuleList(blocks)
        self.upscale_blocks = nn.ModuleList(upscale_blocks)
        self.prefix_proj_blocks = nn.ModuleList(prefix_proj_blocks)

        # Final norm
        self.norm = norm_layer(embed_dim)

        # Verify output dimension matches target
        assert embed_dim == target_embed_dim, \
            f"Decoder output dim {embed_dim} doesn't match target {target_embed_dim}"

    def _apply_upscale(
        self,
        x: torch.Tensor,
        upscale_idx: int,
        H: int,
        W: int,
    ) -> Tuple[torch.Tensor, int, int]:
        """Apply patch splitting upscale operation.

        Args:
            x: Input tensor of shape (B, N, C) where N = num_prefix_tokens + H*W
            upscale_idx: Index into self.upscale_blocks and self.prefix_proj_blocks
            H: Current spatial height (in patches)
            W: Current spatial width (in patches)

        Returns:
            Tuple of (upscaled tensor, new H, new W)
        """
        B, N, C = x.shape

        # Separate prefix tokens from patch tokens
        prefix_tokens = x[:, :self.num_prefix_tokens]  # (B, num_prefix, C)
        patch_tokens = x[:, self.num_prefix_tokens:]  # (B, H*W, C)

        # Reshape patch tokens to spatial format for PatchSplitting
        patch_tokens = patch_tokens.reshape(B, H, W, C)

        # Apply patch splitting (spatial upsampling)
        patch_tokens = self.upscale_blocks[upscale_idx](patch_tokens)  # (B, 2H, 2W, C')

        # Get new dimensions
        _, H_new, W_new, C_new = patch_tokens.shape

        # Reshape back to sequence format
        patch_tokens = patch_tokens.reshape(B, H_new * W_new, C_new)

        # Project prefix tokens to match new channel dimension
        prefix_tokens = self.prefix_proj_blocks[upscale_idx](prefix_tokens)  # (B, num_prefix, C')

        # Concatenate prefix and patch tokens
        x = torch.cat([prefix_tokens, patch_tokens], dim=1)
        return x, H_new, W_new

    def _get_filler_tokens(self, H: int, W: int, B: int, device: torch.device) -> torch.Tensor:
        """Get filler tokens interpolated to the required spatial size.

        Args:
            H: Target height in patches
            W: Target width in patches
            B: Batch size
            device: Target device

        Returns:
            Filler tokens of shape (B, H*W, C)
        """
        if H == self.ref_H and W == self.ref_W:
            # No interpolation needed - matches reference size
            filler = self.filler_tokens.unsqueeze(0).expand(B, -1, -1)
        else:
            # Interpolate filler tokens to match the required size
            # Reshape to 2D grid, interpolate, then flatten
            filler_2d = self.filler_tokens.reshape(self.ref_H, self.ref_W, -1).permute(2, 0, 1).unsqueeze(0)
            # Interpolate: (1, C, ref_H, ref_W) -> (1, C, H, W)
            filler_2d = nn.functional.interpolate(filler_2d, size=(H, W), mode='bilinear', align_corners=False)
            # Reshape back: (1, C, H, W) -> (1, H*W, C)
            filler = filler_2d.squeeze(0).permute(1, 2, 0).reshape(1, H * W, -1)
            filler = filler.expand(B, -1, -1)
        return filler.to(device)

    def forward(
        self,
        global_tokens: torch.Tensor,
        global_token_mask: torch.Tensor,
        input_size: Tuple[int, int],
    ) -> Tuple[torch.Tensor, int, int]:
        """Forward pass through decoder.

        Args:
            global_tokens: Global tokens from encoder (B, num_tokens, C_in), possibly padded
            global_token_mask: Boolean mask for valid global tokens (B, num_tokens)
            input_size: Tuple of (H, W) spatial dimensions of the downscaled patches

        Returns:
            Tuple of (features, H, W):
            - features: Reconstructed features (B, num_prefix + H*W, target_embed_dim)
            - H: Output spatial height
            - W: Output spatial width
        """
        B = global_tokens.shape[0]
        H, W = input_size
        device = global_tokens.device

        # Get filler tokens (interpolated if needed for variable image sizes)
        filler = self._get_filler_tokens(H, W, B, device)

        # Combine global tokens with filler tokens
        # Valid global tokens replace the corresponding filler tokens
        patch_tokens = filler.clone()

        # Use the mask to place valid global tokens
        for i in range(B):
            n_valid = global_token_mask[i].sum().int().item()
            patch_tokens[i, :n_valid] = global_tokens[i, :n_valid]

        # Use decoder's own learnable prefix tokens (not encoder's, to avoid information leak)
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(B, -1, -1)

        # Concatenate prefix tokens and patch tokens
        x = torch.cat([prefix_tokens, patch_tokens], dim=1)  # (B, num_prefix + H*W, C)

        # Apply decoder blocks with optional upscaling
        upscale_idx = 0
        for i, blk in enumerate(self.blocks):
            # Apply upscale before this block if specified
            if i in self.upscale_levels:
                x, H, W = self._apply_upscale(x, upscale_idx, H, W)
                upscale_idx += 1

            # Apply transformer block
            x = blk(x)

        x = self.norm(x)
        return x, H, W


class KSampleDistribution(ABC):
    def __init__(self, synchronized: bool = False):
        self.synchronized = synchronized
        g = None
        if synchronized:
            g = torch.Generator(device='cuda')
            g.manual_seed(42)
        self.generator = g

    def set_curr_step(self, step: int):
        if self.generator is not None:
            self.generator.manual_seed(step)

    def get_max_tokens(self, outside_max: int) -> int:
        return outside_max

    def get_expected_tokens(self, outside_max: int) -> int:
        return self.get_max_tokens(outside_max)

    def _sample(self, batch_size: int, max_tokens: int) -> torch.Tensor:
        ...

    def sample(self, batch_size: int, max_tokens: int) -> torch.Tensor:
        # When synchronized, draw one value and share it across the whole batch
        inner_bs = 1 if self.synchronized else batch_size
        inner_sample = self._sample(inner_bs, max_tokens)
        if self.synchronized:
            inner_sample = inner_sample.expand(batch_size)
        return inner_sample


class MultiModeGaussSampleDistribution(KSampleDistribution):
    def __init__(self, modes: List[int], mode_weights: List[float], max_tokens: Optional[int] = None, synchronized: bool = False):
        super().__init__(synchronized=synchronized)
        if len(modes) != len(mode_weights):
            raise ValueError("modes and mode_weights must have the same length")
        assert all(mode > 0 for mode in modes)
        assert all(weight >= 0 for weight in mode_weights)
        # Compare with a tolerance: an exact `== 1.0` check can fail on float rounding
        assert math.isclose(sum(mode_weights), 1.0, rel_tol=1e-6)
        self.modes = modes
        self.mode_weights = mode_weights
        self._max_tokens = max_tokens

    def get_max_tokens(self, outside_max: int) -> int:
        my_max = self._max_tokens or max(self.modes)
        return min(my_max, outside_max)

    def _sample(self, batch_size: int, max_tokens: int) -> torch.Tensor:
        num_tokens_per_sample = sample_multinomial_batch(self.modes, self.mode_weights, batch_size, generator=self.generator)
        torch.clamp_max_(num_tokens_per_sample, max_tokens)
        return num_tokens_per_sample


class UniformKSampleDistribution(KSampleDistribution):
    def _sample(self, batch_size: int, max_tokens: int) -> torch.Tensor:
        f = torch.rand(batch_size, dtype=torch.float32, device='cuda', generator=self.generator)
        f = torch.round(f * max_tokens)
        f = torch.clamp(f, min=1.0, max=max_tokens)
        return f.long()

    def get_expected_tokens(self, outside_max: int) -> int:
        return outside_max // 2


class BetaKSampleDistribution(KSampleDistribution):
    def __init__(self, target_pct: float = 0.25, synchronized: bool = False):
        super().__init__(synchronized=synchronized)
        self.target_pct = target_pct

        # This is one particular solution where the mode of the beta distribution is equal to target_pct
        # `mode = (alpha - 1) / (alpha + beta - 2)`
        # and we add the additional constraint that
        # `alpha + beta = 2 / rate`
        rate = torch.as_tensor(target_pct, dtype=torch.float32, device='cuda')
        alpha = 3 - (2 * rate)
        beta = (2 / rate) - 3 + (2 * rate)
        self.alpha = alpha
        self.beta = beta
        self.beta_dist = torch.distributions.Beta(alpha, beta)

    def _sample(self, batch_size: int, max_tokens: int) -> torch.Tensor:
        # Expand the (2,) concentration to (batch_size, 2) so that one value is drawn per sample
        concentration = self.beta_dist._dirichlet.concentration.expand(batch_size, -1)
        f = torch._sample_dirichlet(
            concentration,
            generator=self.generator,
        ).select(-1, 0)
        f = torch.round(f * max_tokens)
        f = torch.clamp(f, min=1.0, max=max_tokens)
        return f.long()

    def get_expected_tokens(self, outside_max: int) -> int:
        beta_mean = self.alpha / (self.alpha + self.beta)
        return int(beta_mean * outside_max)


class TriangleKSampleDistribution(KSampleDistribution):
    '''
    Triangle distribution, defined as p(x) = 2 - 2x for x in [0, 1]
    with expected value 1/3.
    '''
    def _sample(self, batch_size: int, max_tokens: int) -> torch.Tensor:
        u = torch.rand(batch_size, dtype=torch.float32, device='cuda', generator=self.generator)
        # Use inverse transform sampling
        f = 1 - torch.sqrt(u.clamp_min_(1e-8))
        f = torch.round(f * (max_tokens - 1)) + 1
        f = torch.clamp(f, min=1.0, max=max_tokens)
        return f.long()

    def get_expected_tokens(self, outside_max: int) -> int:
        return outside_max // 3


class InterpolateKSampleDistributions(KSampleDistribution):
    def __init__(self, dist_a: Union[KSampleDistribution, dict, str], dist_b: Union[KSampleDistribution, dict, str], num_steps: int, synchronized: bool = False):
        super().__init__(synchronized=synchronized)
        self.dist_a = self._instantiate(dist_a)
        self.dist_b = self._instantiate(dist_b)
        self.num_steps = num_steps
        self.curr_step = 0

    def set_curr_step(self, step: int):
        super().set_curr_step(step)
        self.curr_step = step
        self.dist_a.set_curr_step(step)
        self.dist_b.set_curr_step(step)

    def get_max_tokens(self, outside_max: int) -> int:
        return max(self.dist_a.get_max_tokens(outside_max), self.dist_b.get_max_tokens(outside_max))

    def get_expected_tokens(self, outside_max: int) -> int:
        ea = self.dist_a.get_expected_tokens(outside_max)
        eb = self.dist_b.get_expected_tokens(outside_max)
        alpha = max(0, min(1, self.curr_step / self.num_steps))
        return int((1 - alpha) * ea + alpha * eb)

    def _sample(self, batch_size: int, max_tokens: int) -> torch.Tensor:
        sample_a = self.dist_a.sample(batch_size, max_tokens).float()
        sample_b = self.dist_b.sample(batch_size, max_tokens).float()
        # Linear blend from dist_a (step 0) to dist_b (step >= num_steps)
        alpha = max(0, min(1, self.curr_step / self.num_steps))
        f = (1 - alpha) * sample_a + alpha * sample_b
        f = torch.round(f).clamp(min=1.0, max=max_tokens)
        return f.long()

    def _instantiate(self, dist: Union[KSampleDistribution, dict, str]) -> KSampleDistribution:
        if isinstance(dist, KSampleDistribution):
            return dist
        elif isinstance(dist, dict):
            dist_type = dist.pop('type')
            if dist_type not in _K_SAMPLER_FACTORY:
                raise ValueError(f"Unknown KSampleDistribution type: {dist_type}")
            return _K_SAMPLER_FACTORY[dist_type](**dist)
        elif isinstance(dist, str):
            if dist not in _K_SAMPLER_FACTORY:
                raise ValueError(f"Unknown KSampleDistribution type: {dist}")
            return _K_SAMPLER_FACTORY[dist]()
        else:
            raise ValueError("dist must be a KSampleDistribution instance, a dict, or a str")


_K_SAMPLER_FACTORY = {
    'multimode_gaussian': MultiModeGaussSampleDistribution,
    'uniform': UniformKSampleDistribution,
    'beta': BetaKSampleDistribution,
    'triangle': TriangleKSampleDistribution,
    'interpolate': InterpolateKSampleDistributions,
}


class RADIO1D(VisionTransformer):
    """RADIO1D encoder built on timm's VisionTransformer.

    Extends the Vision Transformer from
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    with PatchMerging downscaling, variable-length 1D token slicing, and a
    RADIO1D_Decoder (see the module docstring).
    """
    dynamic_img_size: Final[bool]
    _iter_count: torch.Tensor
    dynamic_rate_vec: Optional[torch.Tensor]

    def __init__(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 16,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            # scale_attn_norm: bool = False,
            # scale_mlp_norm: bool = False,
            proj_bias: bool = True,
            init_values: Optional[float] = None,
            class_token: bool = True,
            pos_embed: str = 'learn',
            no_embed_class: bool = False,
            reg_tokens: int = 0,
            pre_norm: bool = False,
            final_norm: bool = True,
            fc_norm: Optional[bool] = None,
            pool_include_prefix: bool = False,
            dynamic_img_size: bool = False,
            dynamic_img_pad: bool = False,
            drop_rate: float = 0.,
            pos_drop_rate: float = 0.,
            patch_drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
            fix_init: bool = False,
            embed_layer: Callable = PatchEmbed,
            embed_norm_layer: Optional[LayerType] = None,
            norm_layer: Optional[LayerType] = None,
            act_layer: Optional[LayerType] = None,
            block_fn: Type[nn.Module] = Block,
            mlp_layer: Type[nn.Module] = Mlp,
            num_cls_tokens: Optional[int] = None,
            cpe_max_size: Optional[int] = None,
            num_registers: Optional[int] = None,
            register_multiple: Optional[int] = None,
            downscale_levels: Optional[List[int]] = None,
            k_sample_config: Optional[dict] = None,
            decoder_depth: int = 6,
            decoder_upscale_levels: Optional[List[int]] = None,
            dynamic_rate: bool = False,
            dynamic_temperature: float = 1.0,
            progressive_reduction: bool = False,
            cka_weight: float = 0.0,
            cka_weight_final: Optional[float] = None,
            uniform_k: bool = False,
            grad_checkpointing: Union[bool, int] = False,
            decoder_grad_checkpointing: Union[bool, int] = False,
            downscale_expansion_factor: float = 2.0,
    ) -> None:
        """
        Args:
            img_size: Input image size.
            patch_size: Patch size.
            in_chans: Number of image input channels.
            num_classes: Number of classes for classification head.
            global_pool: Type of global pooling for final sequence (default: 'token').
            embed_dim: Transformer embedding dimension.
            depth: Depth of transformer.
            num_heads: Number of attention heads.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: Enable bias for qkv projections if True.
            init_values: Layer-scale init values (layer-scale enabled if not None).
            class_token: Use class token.
            no_embed_class: Don't include position embeddings for class (or reg) tokens.
            reg_tokens: Number of register tokens.
            pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
            final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
            fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
            drop_rate: Head dropout rate.
            pos_drop_rate: Position embedding dropout rate.
            attn_drop_rate: Attention dropout rate.
            drop_path_rate: Stochastic depth rate.
            weight_init: Weight initialization scheme.
            fix_init: Apply weight initialization fix (scaling w/ layer index).
            embed_layer: Patch embedding layer.
            embed_norm_layer: Normalization layer to use / override in patch embed module.
            norm_layer: Normalization layer.
            act_layer: MLP activation layer.
            block_fn: Transformer block layer.
            num_cls_tokens: Number of class tokens.
            cpe_max_size: Maximum size of the input image.
            num_registers: Number of registers.
            register_multiple: Register multiple.
            downscale_levels: Block indices at which encoder downscaling (PatchMerging) is applied, e.g., [19].
            k_sample_config: Token-count sampler config. A dict whose 'type' key selects an entry of
                _K_SAMPLER_FACTORY; the remaining keys are passed to that sampler (e.g., modes and
                mode_weights for 'multimode_gaussian'). If None, UniformKSampleDistribution is used.
            decoder_depth: Number of decoder blocks.
            decoder_upscale_levels: Block indices in decoder where upscaling should happen.
        """
        super().__init__()
        assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
        assert class_token or global_pool != 'token'
        assert pos_embed in ('', 'none', 'learn')
        use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
        norm_layer = get_norm_layer(norm_layer) or LayerNorm
        embed_norm_layer = get_norm_layer(embed_norm_layer)
        act_layer = get_act_layer(act_layer) or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.head_hidden_size = self.embed_dim = embed_dim  # for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.num_prefix_tokens += reg_tokens
        self.num_reg_tokens = reg_tokens
        self.has_class_token = class_token
        self.no_embed_class = no_embed_class
        self.pool_include_prefix = pool_include_prefix
        self.dynamic_img_size = dynamic_img_size
        self.grad_checkpointing = False
        self.dynamic_rate = dynamic_rate
        self.dynamic_temperature = dynamic_temperature
        self.progressive_reduction = progressive_reduction
        self.cka_weight = cka_weight
        # Fall back to cka_weight only when no final weight is given, so an explicit 0.0 is kept
        self.cka_weight_final = cka_weight if cka_weight_final is None else cka_weight_final

        if dynamic_rate:
            if num_registers is None or num_registers == 0:
                raise ValueError("dynamic_rate requires at least one register token")
            self.register_buffer('dynamic_rate_vec', torch.randn(embed_dim))
            self.dynamic_rate_projector = nn.Linear(embed_dim, 1)

        if k_sample_config is None:
            self.k_sampler = UniformKSampleDistribution(synchronized=dynamic_rate or uniform_k)
        else:
            # Copy first so that popping 'type' does not mutate the caller's config
            k_sample_config = deepcopy(k_sample_config)
            sampler_type = k_sample_config.pop('type')
            self.k_sampler = _K_SAMPLER_FACTORY[sampler_type](**k_sample_config, synchronized=dynamic_rate or uniform_k)

        embed_args = {}
        if dynamic_img_size:
            # flatten deferred until after pos embed
            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
        if embed_norm_layer is not None:
            embed_args['norm_layer'] = embed_norm_layer
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            dynamic_img_pad=dynamic_img_pad,
            **embed_args,
        )
        num_patches = self.patch_embed.num_patches
        reduction = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        if not pos_embed or pos_embed == 'none':
            self.pos_embed = None
else:
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=pos_drop_rate)
if patch_drop_rate > 0:
self.patch_drop = PatchDropout(
patch_drop_rate,
num_prefix_tokens=self.num_prefix_tokens,
)
else:
self.patch_drop = nn.Identity()
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
if cpe_max_size is not None:
# Replace patch embed with CPE patch generator
input_dims = img_size
max_img_size = int(round(cpe_max_size / patch_size) * patch_size)
self.patch_generator = ViTPatchGenerator(
patch_size=patch_size,
embed_dim=embed_dim,
input_dims=input_dims,
normalize_patches=pre_norm,
cls_token=self.has_class_token,
max_input_dims=max_img_size,
pos_dropout=pos_drop_rate,
num_cls_tokens=num_cls_tokens,
register_multiple=register_multiple,
num_registers=num_registers,
#init_from=self,
#adaptive_patch_tokenizer_config=None,
)
self.patch_embed = None
self.cls_token = None
self.pos_embed = None
self.pos_drop = None
self.num_cls_tokens = num_cls_tokens
self.num_registers = num_registers
self.num_prefix_tokens = self.patch_generator.num_cls_patches
else:
self.patch_generator = None
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
# Save original dimensions for decoder creation
original_embed_dim = embed_dim
original_num_patches = num_patches
original_num_heads = num_heads
downscale_blocks = []
prefix_proj_blocks = [] # Projection layers for prefix tokens during downscaling
blocks = []
feature_info = []
for i in range(depth):
if downscale_levels is not None and i in downscale_levels:
downscale_block = PatchMerging(embed_dim)
# Projection for prefix tokens to match new embed_dim
prefix_proj = nn.Linear(embed_dim, downscale_block.out_dim, bias=False)
num_heads = int(num_heads * downscale_block.out_dim // embed_dim)
embed_dim = downscale_block.out_dim
downscale_blocks.append(downscale_block)
prefix_proj_blocks.append(prefix_proj)
blocks.append(block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
# scale_attn_norm=scale_attn_norm,
# scale_mlp_norm=scale_mlp_norm,
proj_bias=proj_bias,
init_values=init_values,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
mlp_layer=mlp_layer,
))
feature_info.append(dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction))
self.blocks = ModuleList(blocks)
self.downscale_blocks = ModuleList(downscale_blocks)
self.prefix_proj_blocks = ModuleList(prefix_proj_blocks)
self.downscale_levels = set(downscale_levels) if downscale_levels else set()
self.feature_info = feature_info
self.norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity()
if isinstance(grad_checkpointing, bool):
self.grad_checkpointing = len(blocks) if grad_checkpointing else 0
else:
self.grad_checkpointing = min(grad_checkpointing, len(blocks))
# Classifier Head
if global_pool == 'map':
self.attn_pool = AttentionPoolLatent(
self.embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
act_layer=act_layer,
)
else:
self.attn_pool = None
self.fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity()
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
# Create decoder (always needed to reconstruct from sliced global tokens)
# Compute dimensions after encoder (with or without downscaling)
if downscale_levels:
num_downscales = len(downscale_levels)
encoder_num_patches = original_num_patches // (4 ** num_downscales)
encoder_embed_dim = embed_dim # embed_dim after all downscales
# Default upscale_levels: upscale at the midpoint of decoder
if decoder_upscale_levels is None:
decoder_upscale_levels = [decoder_depth // 2] * num_downscales
else:
encoder_num_patches = original_num_patches
encoder_embed_dim = original_embed_dim
decoder_upscale_levels = [] # No upscaling needed
# Compute reference spatial size for decoder filler tokens
ref_H = ref_W = int(encoder_num_patches ** 0.5)
self.decoder = RADIO1D_Decoder(
input_embed_dim=encoder_embed_dim,
target_embed_dim=original_embed_dim,
ref_spatial_size=(ref_H, ref_W), # Reference size for filler token init
num_prefix_tokens=self.num_prefix_tokens,
depth=decoder_depth,
upscale_levels=decoder_upscale_levels,
num_heads=original_num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
)
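# Iteration counter for logging. It is a persistent buffer, so it is saved in the state_dict and
# used to restore the k-sampler step on load; the two running totals below are non-persistent.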
self.register_buffer('_iter_count', torch.tensor(0, dtype=torch.long))
self.register_buffer('_total_num_tokens', torch.tensor(0, dtype=torch.long), persistent=False)
self.register_buffer('_total_num_samples', torch.tensor(0, dtype=torch.long), persistent=False)
if weight_init != 'skip':
self.init_weights(weight_init)
if fix_init:
self.fix_init_weight()
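# The block-construction loop above widens the model at each downscale level and scales the
# head count by the same ratio. A minimal standalone sketch of that bookkeeping follows.
# `stage_widths` and its `widen` argument are hypothetical names, and the factor of 2 is an
# assumption taken from the module docstring ("Doubles embed_dim: C -> 2C"), not read from
# PatchMerging itself.

```python
def stage_widths(embed_dim, num_heads, depth, downscale_levels, widen=2):
    """Return (embed_dim, num_heads) for each block index.

    Mirrors the constructor loop: a downscale at index i takes effect
    before block i is built, so block i already uses the wider dimension.
    """
    widths = []
    for i in range(depth):
        if i in downscale_levels:
            new_dim = embed_dim * widen  # assumed PatchMerging.out_dim
            num_heads = num_heads * new_dim // embed_dim
            embed_dim = new_dim
        widths.append((embed_dim, num_heads))
    return widths


if __name__ == "__main__":
    for idx, (dim, heads) in enumerate(stage_widths(1024, 16, 4, {2})):
        print(idx, dim, heads)
```

# Keeping the head count proportional to the width leaves the per-head dimension unchanged
# across the downscale.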
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None:
ret = super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
self.k_sampler.set_curr_step(int(self._iter_count.item()))
return ret
def _apply_downscale(self, x: torch.Tensor, downscale_idx: int, H: int, W: int) -> Tuple[torch.Tensor, int, int]:
"""Apply patch merging downscale operation.
Args:
x: Input tensor of shape (B, N, C) where N = num_prefix_tokens + H*W
downscale_idx: Index into self.downscale_blocks and self.prefix_proj_blocks
H: Current spatial height (in patches)
W: Current spatial width (in patches)
Returns:
Tuple of (downscaled tensor, new H, new W)
"""
B, N, C = x.shape
num_prefix = self.num_prefix_tokens
# Separate prefix tokens (cls, registers) from patch tokens
prefix_tokens = x[:, :num_prefix] # (B, num_prefix, C)
patch_tokens = x[:, num_prefix:] # (B, H*W, C)
# Reshape patch tokens to spatial format for PatchMerging
patch_tokens = patch_tokens.reshape(B, H, W, C)
# Apply patch merging (spatial downsampling)
patch_tokens = self.downscale_blocks[downscale_idx](patch_tokens) # (B, H', W', C')
# Get new dimensions
_, H_new, W_new, C_new = patch_tokens.shape
# Reshape back to sequence format
patch_tokens = patch_tokens.reshape(B, H_new * W_new, C_new)
# Project prefix tokens to match new channel dimension
prefix_tokens = self.prefix_proj_blocks[downscale_idx](prefix_tokens) # (B, num_prefix, C')
# Concatenate prefix and patch tokens
x = torch.cat([prefix_tokens, patch_tokens], dim=1)
return x, H_new, W_new
def forward_encoder(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
num_tokens: Optional[int] = None,
use_last_tokens: bool = False,
) -> dict:
"""Forward pass through encoder only (embeddings, transformer blocks, token slicing).
Args:
x: Input image tensor of shape (B, C, H, W)
attn_mask: Optional attention mask
num_tokens: Number of 1D tokens to output per sample.
If None during training: samples a count per sample from self.k_sampler
If None during inference: uses all available spatial tokens
If negative (or None during training) with dynamic_rate enabled: uses dynamic rate prediction
use_last_tokens: If True, take the last num_tokens instead of the first
Returns:
Dict with keys:
- "encoder": (B, num_prefix + max_tokens, C) - prefix tokens + 1D global tokens
- "global_tokens": (B, max_tokens, C) - sliced global tokens (for decoder input)
- "global_token_mask": (B, max_tokens) - validity mask for global tokens
- "encoder_spatial_size": (H, W) - spatial dimensions after encoding
- "original_spatial_size": (H, W) - original spatial dimensions before padding
"""
B = x.shape[0]
# Infer spatial dimensions from input image before patch embedding
_, _, H_img, W_img = x.shape
if self.patch_embed is not None:
patch_size = self.patch_embed.patch_size[0]
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
else:
images = x # Save for visualization
patch_size = self.patch_generator.patch_size
x = self.patch_generator(x)
# Compute spatial dimensions (in patches) for downscaling
H = H_img // patch_size
W = W_img // patch_size
# Save original dimensions before any padding during downscaling
original_H, original_W = H, W
# Sample num_tokens per sample, clamped to available spatial tokens
num_spatial_tokens = H * W
total_downscale = math.prod(ds.downscale for ds in self.downscale_blocks)
num_spatial_tokens //= total_downscale
if num_tokens is not None:
num_tokens = min(num_tokens, num_spatial_tokens)
num_tokens_per_sample = torch.full((B,), num_tokens, dtype=torch.long, device=x.device)
else:
if self.training:
num_tokens_per_sample = self.k_sampler.sample(B, max_tokens=num_spatial_tokens)
else:
# In eval mode, return all available tokens if num_tokens not specified
num_tokens = num_spatial_tokens
num_tokens_per_sample = torch.full((B,), num_tokens, dtype=torch.long, device=x.device)
is_dynamic = False
if self.dynamic_rate and (num_tokens is None or num_tokens < 0):
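if num_tokens is not None:
# A negative num_tokens requests dynamic-rate prediction with |num_tokens| as the target,
# so the per-sample targets must be made positive as well.
num_tokens = -num_tokens
num_tokens_per_sample = num_tokens_per_sample.abs().clamp_(max=num_spatial_tokens)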
target_rate_pct = num_tokens_per_sample.float() / num_spatial_tokens
rate_vec = (target_rate_pct * 2 - 1).unsqueeze(1) * self.dynamic_rate_vec.unsqueeze(0)
x0 = x[:, :self.num_cls_tokens]
x1 = x[:, self.num_cls_tokens + 1:]
# Replace the first register instead with this dynamic rate vector, allowing
# the model to know what the target rate will be
x = torch.cat([x0, rate_vec.unsqueeze(1), x1], dim=1)
is_dynamic = True
# Apply transformer blocks with optional downscaling
downscale_idx = 0
use_checkpoint = self.grad_checkpointing and not torch.jit.is_scripting()
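# default=-1 avoids a ValueError from min()/max() when no downscale levels are configured
first_downscale_level = min(self.downscale_levels, default=-1)
last_downscale_level = max(self.downscale_levels, default=-1)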
curr_num_tokens = None
total_num_to_drop = None
if attn_mask is not None:
raise NotImplementedError("attn_mask not currently supported at input.")
for i, blk in enumerate(self.blocks):
# Apply downscale before this block if specified
if i in self.downscale_levels:
if i == first_downscale_level and is_dynamic:
dyn_token = x[:, self.num_cls_tokens]
dyn_pred_logits = self.dynamic_rate_projector(dyn_token).squeeze(1)
if self.training:
gumbels = torch.rand(*dyn_pred_logits.shape, 2, dtype=dyn_pred_logits.dtype, device=dyn_pred_logits.device).clamp_min_(1e-8)
gumbels = -(-gumbels.log()).log() # Two independent Gumbel(0, 1) samples per logit
gumbels[..., 1].neg_()
gumbels = gumbels.sum(dim=-1) # g0 - g1 is Logistic(0, 1) noise
dyn_pred = dyn_pred_logits + gumbels / self.dynamic_temperature
else:
dyn_pred = dyn_pred_logits
dyn_pred = F.sigmoid(dyn_pred)
dyn_keep = 1 + dyn_pred * (num_spatial_tokens - 1)
dyn_keep = round_ste(dyn_keep)
target_num_tokens_per_sample = num_tokens_per_sample
num_tokens_per_sample = dyn_keep
x, H, W = self._apply_downscale(x, downscale_idx, H, W)
downscale_idx += 1
if self.progressive_reduction:
if i == last_downscale_level:
assert attn_mask is None
attn_mask = torch.ones(B, x.shape[1], dtype=torch.bool, device=x.device)
curr_num_tokens = torch.full((B,), x.shape[1] - self.num_prefix_tokens, dtype=torch.int64, device=x.device)
ct_seq = torch.arange(0, x.shape[1], dtype=torch.int64, device=x.device).unsqueeze(0).expand(B, -1)
elif i > last_downscale_level:
assert curr_num_tokens is not None
num_remaining_blocks = len(self.blocks) - i
curr_num_to_drop = (curr_num_tokens - num_tokens_per_sample.long()) // num_remaining_blocks
max_keep = curr_num_tokens + self.num_prefix_tokens - curr_num_to_drop
curr_attn_mask = ct_seq[:, :x.shape[1]] < max_keep.unsqueeze(1)
attn_mask = attn_mask & curr_attn_mask
curr_num_tokens -= curr_num_to_drop
# We only need to keep up to the longest valid sequence, so this allows us to progressively
# truncate x, saving on compute
is_valid_for_any = torch.any(attn_mask, dim=0)
curr_valid_length = torch.count_nonzero(is_valid_for_any, dim=0).item()
x = x[:, :curr_valid_length]
attn_mask = attn_mask[:, :curr_valid_length]
my_use_checkpoint = use_checkpoint and i < self.grad_checkpointing
fwd_fn = partial(checkpoint, blk, use_reentrant=False) if my_use_checkpoint else blk
# Apply transformer block
if attn_mask is not None:
if attn_mask.ndim == 4:
my_mask = attn_mask
elif attn_mask.ndim == 2:
my_mask = attn_mask.unsqueeze(1).unsqueeze(1)
else:
my_mask = None
# x = fwd_fn(x, attn_mask=my_mask)
x = fwd_fn(x)
x = self.norm(x)
#if self.patch_generator is not None:
# x = self.patch_generator.broadcast_masks(x, apt_masks, pos_enc=pos_enc)
# self.patch_generator.maybe_visualize(images, x, apt_masks, self)
self._total_num_tokens += num_tokens_per_sample.sum().long()
self._total_num_samples += B
# Log token counts every 100 iterations
self._iter_count += 1
if self._iter_count % 100 == 0 and get_rank() == 0:
# min_tokens = num_tokens_per_sample.min().item()
# max_tokens = num_tokens_per_sample.max().item()
# mean_tokens = num_tokens_per_sample.float().mean().item()
# print(f"[RADIO1D iter {self._iter_count.item()}] 1D tokens: min={min_tokens}, max={max_tokens}, mean={mean_tokens:.1f}, use_last_tokens={use_last_tokens}")
mean_tokens = self._total_num_tokens.float() / self._total_num_samples.float()
print(f"[RADIO1D iter {self._iter_count.item()}] avg 1D tokens: {mean_tokens.item():.2f}")
# Slice tokens with per-sample counts, padded to max in this micro-batch
prefix_tokens, global_tokens, global_token_mask = slice_1d_tokens(
x,
num_tokens_per_sample,
num_prefix_tokens=self.num_prefix_tokens,
use_last_tokens=use_last_tokens,
dynamic=self.dynamic_rate,
)
# Return encoder output with metadata needed for decoder
encoder_output = torch.cat([prefix_tokens, global_tokens], dim=1)
ret = {
"encoder": encoder_output,
"global_tokens": global_tokens,
"global_token_mask": global_token_mask,
"encoder_spatial_size": (H, W),
"original_spatial_size": (original_H, original_W),
}
if is_dynamic:
ret["dynamic_rate"] = dict(
pred_rate=num_tokens_per_sample,
pred_pct=dyn_pred,
target_rate=target_num_tokens_per_sample,
target_pct=target_num_tokens_per_sample / num_spatial_tokens,
pred_logits=dyn_pred_logits,
num_spatial_tokens=num_spatial_tokens,
)
self.apply_aux_losses(ret)
return ret
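# The token budget and the progressive reduction schedule used in forward_encoder can be
# sketched with plain integers. This is an illustrative sketch only: `token_budget` and
# `progressive_schedule` are hypothetical helpers, and the per-merge factor of 4 in the
# usage below is an assumption (a 2x2 merge), consistent with the `4 ** num_downscales`
# arithmetic in the constructor.

```python
import math


def token_budget(h_patches, w_patches, downscale_factors, requested=None):
    """Return (available, kept): spatial tokens after merging and the clamped request."""
    available = (h_patches * w_patches) // math.prod(downscale_factors)
    kept = available if requested is None else min(requested, available)
    return available, kept


def progressive_schedule(current, target, remaining_blocks):
    """Token count after each remaining block when dropping evenly toward `target`.

    Each step drops (current - target) // blocks_left tokens, so the final
    block always lands exactly on the target.
    """
    kept = []
    for blocks_left in range(remaining_blocks, 0, -1):
        current -= (current - target) // blocks_left
        kept.append(current)
    return kept


if __name__ == "__main__":
    available, kept = token_budget(14, 14, [4], requested=64)
    print(available, kept)
    print(progressive_schedule(available, 9, 4))
```

# Because the drop is recomputed from the remaining gap at every block, rounding from the
# integer division is absorbed by later blocks instead of accumulating.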
def forward_decoder(
self,
global_tokens: torch.Tensor,
global_token_mask: torch.Tensor,
encoder_spatial_size: Tuple[int, int],
original_spatial_size: Tuple[int, int],
) -> torch.Tensor:
"""Forward pass through decoder only.
Args:
global_tokens: Global tokens from encoder (B, max_tokens, C)
global_token_mask: Boolean mask for valid global tokens (B, max_tokens)
encoder_spatial_size: Tuple of (H, W) spatial dimensions after encoding
original_spatial_size: Tuple of (H, W) original spatial dimensions before padding
Returns:
Decoded features (B, num_prefix + H*W, target_embed_dim)
"""
B = global_tokens.shape[0]
H, W = encoder_spatial_size
original_H, original_W = original_spatial_size
# Decode back to original resolution (decoder uses its own prefix tokens)
decoded, decoded_H, decoded_W = self.decoder(
global_tokens=global_tokens,
global_token_mask=global_token_mask,
input_size=(H, W),
)
# Crop to original dimensions if needed (handles odd H/W that were padded during downscaling)
if decoded_H != original_H or decoded_W != original_W:
prefix = decoded[:, :self.num_prefix_tokens]
patches = decoded[:, self.num_prefix_tokens:] # (B, decoded_H*decoded_W, C)
patches = patches.reshape(B, decoded_H, decoded_W, -1)
patches = patches[:, :original_H, :original_W, :].reshape(B, original_H * original_W, -1)
decoded = torch.cat([prefix, patches], dim=1)
return decoded
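# The crop in forward_decoder keeps the top-left (original_H, original_W) window of the
# decoded patch grid. Assuming row-major flattening (which is what the reshape calls imply),
# the surviving token positions can be listed directly. `crop_indices` is a hypothetical
# helper for illustration, not part of this module.

```python
def crop_indices(decoded_hw, original_hw):
    """Flat indices of the patch tokens kept when cropping a row-major grid."""
    decoded_h, decoded_w = decoded_hw
    original_h, original_w = original_hw
    if original_h > decoded_h or original_w > decoded_w:
        raise ValueError("original size must fit inside the decoded size")
    return [r * decoded_w + c for r in range(original_h) for c in range(original_w)]


if __name__ == "__main__":
    # A 3x3 grid padded to 4x4 during downscaling loses its last row and column.
    print(crop_indices((4, 4), (3, 3)))
```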
def apply_aux_losses(self, encoder_result: dict):
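# The 'dynamic_rate' entry is only present when the encoder ran in dynamic mode
if not self.training or not self.dynamic_rate or 'dynamic_rate' not in encoder_result: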
return
dyn_dict = encoder_result['dynamic_rate']
pred_rate = dyn_dict['pred_rate']
pred_pct = dyn_dict['pred_pct']
target_rate = dyn_dict['target_rate']
target_pct = dyn_dict['target_pct']
pred_logits = dyn_dict['pred_logits']
num_spatial_tokens = dyn_dict['num_spatial_tokens']
pred_local_num_tokens = pred_rate.sum()
pred_global_num_tokens = pred_local_num_tokens.clone()
local_num_tokens = torch.tensor(num_spatial_tokens * pred_rate.shape[0], dtype=torch.float32, device=pred_rate.device)
global_num_tokens = local_num_tokens.clone()
if dist.is_initialized():
dist.all_reduce(global_num_tokens, op=dist.ReduceOp.SUM)
pred_global_num_tokens = all_reduce_with_gradients(pred_global_num_tokens, op=dist.ReduceOp.SUM)
global_pred_pct = pred_global_num_tokens / global_num_tokens
loss_rate = F.mse_loss(global_pred_pct, target_pct[0])
aux_losses: Dict[str, torch.Tensor] = getattr(self, 'auxiliary_losses', dict())
self.auxiliary_losses = aux_losses
aux_losses['dynamic_rate_mse'] = 1.0 * loss_rate.mean()
quantile = 0.98
quantile_sym = (1.0 - quantile) / 2 + quantile
log_q = math.log(quantile_sym / (1 - quantile_sym))
logit_threshold = log_q / self.dynamic_temperature
logit_excess = F.relu(torch.abs(pred_logits) - logit_threshold).pow(2)
aux_losses['dynamic_rate_logit_penalty'] = 0.1 * logit_excess.mean()
aux_losses['dynamic_rate_abs_diff'] = (global_pred_pct - target_pct[0]).abs().detach()
# caps = ', '.join(f'{v * 100:.1f}%' for v in pred_pct[:4].tolist())
# viz_caption = f"Dynamic Rate Pred: Target: {target_pct[0].item() * 100:.1f}%, Achieved: {global_pred_pct.item() * 100:.1f}%, Pred: [{caps}]"
# FeatureDistillationLoss.VIZ_CAPTION = viz_caption
pass
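# The logit penalty above only activates once |logit| exceeds a threshold derived from a
# quantile and the temperature. The arithmetic is sketched below with scalars.
# `logit_threshold` and `logit_penalty` are hypothetical names. Reading the threshold as a
# quantile of Logistic(0, 1) noise is an interpretation of the code, not something the
# source states.

```python
import math


def logit_threshold(quantile=0.98, temperature=1.0):
    """Logit beyond which the squared penalty becomes non-zero."""
    upper = (1.0 - quantile) / 2 + quantile  # symmetric upper tail, 0.99 for 0.98
    return math.log(upper / (1.0 - upper)) / temperature


def logit_penalty(logit, threshold):
    """Squared excess of |logit| over the threshold (zero inside the band)."""
    return max(abs(logit) - threshold, 0.0) ** 2


if __name__ == "__main__":
    t = logit_threshold(0.98, temperature=1.0)
    print(round(t, 4), logit_penalty(2.0, t), logit_penalty(-6.0, t))
```

# With the default quantile the threshold is log(99) divided by the temperature, so a higher
# temperature narrows the band in which logits go unpenalized.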
def forward_features(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
num_tokens: Optional[int] = None,
use_last_tokens: bool = False,
neck_name: Optional[str] = None,
) -> dict:
"""Forward pass through encoder and (optionally) decoder.
Args:
x: Input image tensor of shape (B, C, H, W)
attn_mask: Optional attention mask
num_tokens: Number of 1D tokens to output per sample.
If None during training: samples a count per sample from self.k_sampler
If None during inference: uses all available spatial tokens
use_last_tokens: If True, take the last num_tokens instead of the first
neck_name: If "encoder", skip the decoder pass and return only the
encoder output. If "decoder" or None, run both (the decoder
always depends on the encoder output).
Returns:
Dict with keys:
- "encoder": (B, num_prefix + max_tokens, C) - prefix tokens + 1D global tokens
- "decoder": (B, num_prefix + H*W, target_embed_dim) - reconstructed full sequence
(omitted when neck_name == "encoder")
"""
encoder_result = self.forward_encoder(x, attn_mask=attn_mask, num_tokens=num_tokens, use_last_tokens=use_last_tokens)
if neck_name == "encoder":
return {"encoder": encoder_result["encoder"]}
decoded = self.forward_decoder(
global_tokens=encoder_result["global_tokens"],
global_token_mask=encoder_result["global_token_mask"],
encoder_spatial_size=encoder_result["encoder_spatial_size"],
original_spatial_size=encoder_result["original_spatial_size"],
)
# encoder: [prefix_tokens (cls + registers), global_tokens]
# decoder: [prefix_tokens, decoded_patches] (already concatenated by decoder)
return {"encoder": encoder_result["encoder"], "decoder": decoded}
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
return_prefix_tokens: bool = False,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
output_dict: bool = False,
attn_mask: Optional[torch.Tensor] = None,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], dict]:
"""Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
return_prefix_tokens: Return both prefix and spatial intermediate tokens
norm: Apply norm layer to all intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs ('NCHW' or 'NLC')
intermediates_only: Only return intermediate features
output_dict: Return outputs as a dictionary
attn_mask: Optional attention mask
Returns:
Depends on flags:
- intermediates_only=True: List of intermediate features
- output_dict=True: Dict with 'image_features' and 'image_intermediates'
- Otherwise: Tuple of (final_features, intermediates)
"""
assert output_fmt in ('NCHW', 'NLC'), f"Invalid output_fmt: {output_fmt}"
# Determine which block indices to collect. Match timm semantics:
# - int -> "take the last N blocks"
# - list -> verbatim, with Python-style negative indexing relative to num_blocks
num_blocks = len(self.blocks)
if indices is None:
take_indices = list(range(num_blocks))
elif isinstance(indices, int):
take_indices = list(range(max(0, num_blocks - indices), num_blocks))
else:
take_indices = [i if i >= 0 else num_blocks + i for i in indices]
max_index = max(take_indices) if take_indices else num_blocks - 1
# Infer spatial dimensions from input image before patch embedding
B, _, H_img, W_img = x.shape
if self.patch_embed is not None:
patch_size = self.patch_embed.patch_size[0]
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
apt_masks = None
pos_enc = None
else:
images = x
patch_size = self.patch_generator.patch_size
x = self.patch_generator(x)
#if apt_attn_mask is not None:
# attn_mask = apt_attn_mask
# Compute spatial dimensions (in patches) for downscaling
H = H_img // patch_size
W = W_img // patch_size
# Collect intermediate activations
intermediates = []
intermediates_prefix = [] if return_prefix_tokens else None
downscale_idx = 0
for i, blk in enumerate(self.blocks):
# Apply downscale before this block if specified
if i in self.downscale_levels:
x, H, W = self._apply_downscale(x, downscale_idx, H, W)
downscale_idx += 1
# Apply transformer block
if attn_mask is not None:
x = blk(x, attn_mask=attn_mask)
else:
x = blk(x)
# Collect intermediate if this index is requested
if i in take_indices:
# Get spatial tokens (excluding prefix tokens)
num_prefix = self.num_prefix_tokens
feat = x[:, num_prefix:] # (B, H*W, C)
if norm:
feat = self.norm(feat)
# Reshape to output format
if output_fmt == 'NCHW':
C = feat.shape[-1]
feat = feat.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
# else 'NLC' - keep as is
intermediates.append(feat)
if return_prefix_tokens:
prefix = x[:, :num_prefix]
if norm:
prefix = self.norm(prefix)
intermediates_prefix.append(prefix)
# Stop early if we've collected all needed intermediates
if stop_early and i >= max_index:
break
# Compute final features if needed
if not intermediates_only:
# Continue from where we left off if we stopped early
if stop_early and max_index < num_blocks - 1:
for i in range(max_index + 1, num_blocks):
if i in self.downscale_levels:
x, H, W = self._apply_downscale(x, downscale_idx, H, W)
downscale_idx += 1
if attn_mask is not None:
x = self.blocks[i](x, attn_mask=attn_mask) # index the block explicitly; `blk` is stale here
else:
x = self.blocks[i](x)
x = self.norm(x)
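# Mask broadcasting is disabled here, as in forward_encoder: `apt_masks` and `pos_enc`
# are undefined on the patch_generator path, so the call below would raise NameError.
# if self.patch_generator is not None:
# x = self.patch_generator.broadcast_masks(x, apt_masks, pos_enc=pos_enc)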
# Match the canonical (timm-ViT) `forward_intermediates` output shape:
# when prefix tokens are requested, return a list of (prefix, features)
# tuples so callers can iterate `for summary, features in intermediates`.
if return_prefix_tokens and not output_dict:
intermediates = list(zip(intermediates_prefix, intermediates))
if output_dict:
result = {
'image_intermediates': intermediates,
}
if not intermediates_only:
result['image_features'] = x
if return_prefix_tokens:
result['image_intermediates_prefix'] = intermediates_prefix
return result
elif intermediates_only:
return intermediates
else:
return x, intermediates
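# The index handling at the top of forward_intermediates can be exercised on its own. The
# sketch below copies that logic into a hypothetical standalone function
# (`resolve_take_indices` is not part of this module) to show the three accepted forms.

```python
def resolve_take_indices(num_blocks, indices=None):
    """Return (take_indices, max_index) for a model with `num_blocks` blocks.

    None -> every block; int n -> the last n blocks; sequence -> used verbatim,
    with negative entries counted from the end.
    """
    if indices is None:
        take = list(range(num_blocks))
    elif isinstance(indices, int):
        take = list(range(max(0, num_blocks - indices), num_blocks))
    else:
        take = [i if i >= 0 else num_blocks + i for i in indices]
    max_index = max(take) if take else num_blocks - 1
    return take, max_index


if __name__ == "__main__":
    print(resolve_take_indices(24, 3))
    print(resolve_take_indices(24, [-1, 5]))
```

# Note that an int is a count rather than a position: passing 3 selects three blocks, not
# block 3.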
def get_first_downscale_block_idx(self) -> Optional[int]:
"""Return the index of the first downscaling block, or None if no downscaling."""
if not self.downscale_levels:
return None
return min(self.downscale_levels)
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
num_tokens: Optional[int] = None,
use_last_tokens: bool = False,
) -> dict:
"""Forward pass through encoder only.
Args:
x: Input image tensor of shape (B, C, H, W)
attn_mask: Optional attention mask
num_tokens: Number of 1D tokens to use per sample (for slicing)
use_last_tokens: If True, take the last num_tokens instead of the first
Returns:
Dict with keys:
- "encoder": (B, num_prefix + max_tokens, C) - prefix tokens + 1D global tokens
- "global_tokens": (B, max_tokens, C) - sliced global tokens (for decoder input)
- "global_token_mask": (B, max_tokens) - validity mask for global tokens
- "encoder_spatial_size": (H, W) - spatial dimensions after encoding
- "original_spatial_size": (H, W) - original spatial dimensions before padding
"""
return self.forward_encoder(x, attn_mask=attn_mask, num_tokens=num_tokens, use_last_tokens=use_last_tokens)
@register_model
def radio1d_large_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
"""
if pretrained:
raise ValueError('There are no pretrained weights for radio1d_large_patch16_224')
model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
model = _create_vision_transformer('radio1d_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def radio1d_so400m_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT model matching the shape-optimized SoViT-400m architecture from
"Getting ViT in Shape: Scaling Laws for Compute-Optimal Model Design" (https://arxiv.org/abs/2305.13035).
"""
if pretrained:
raise ValueError('There are no pretrained weights for radio1d_so400m_patch16_224')
mlp_ratio = 4304 / 1152
model_args = dict(patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=mlp_ratio)
model = _create_vision_transformer('radio1d_so400m_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def radio1d_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).
"""
if pretrained:
raise ValueError('There are no pretrained weights for radio1d_huge_patch16_224')
model_args = dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16)
model = _create_vision_transformer('radio1d_huge_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
def magneto_init(model: VisionTransformer, num_blocks: Optional[int] = None):
'''
Initialization following [Magneto](http://arxiv.org/abs/2210.06423)
'''
attention_modules = [m for m in model.modules() if isinstance(m, Attention)]
mlp_modules = [m for m in model.modules() if isinstance(m, Mlp)]
if num_blocks is None:
num_blocks = len(model.blocks)
gamma = math.sqrt(math.log(2 * num_blocks))
for m in attention_modules:
qkv = m.qkv
q, k, v = qkv.weight.data.chunk(3, dim=0)
xavier_normal_(q, gain=1)
xavier_normal_(k, gain=1)
xavier_normal_(v, gain=gamma)
xavier_normal_(m.proj.weight.data, gain=gamma)
for m in mlp_modules:
xavier_normal_(m.fc1.weight.data, gain=gamma)
xavier_normal_(m.fc2.weight.data, gain=gamma)
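# magneto_init scales the value, output-projection and MLP weights by a depth-dependent gain
# and leaves the query/key gain at 1. A small sketch of how that gain grows with depth
# follows. `magneto_gain` is a hypothetical helper that only reproduces the gamma formula
# used above.

```python
import math


def magneto_gain(num_blocks):
    """Xavier gain used for v, proj, fc1 and fc2: sqrt(log(2 * num_blocks))."""
    return math.sqrt(math.log(2 * num_blocks))


if __name__ == "__main__":
    # Depths of the three registered variants above: large, so400m, huge.
    for depth in (24, 27, 32):
        print(depth, round(magneto_gain(depth), 4))
```

# The gain grows only with the square root of the log of the depth, so the three variants
# end up with similar initial scales.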
def _init_layerscale(model: VisionTransformer):
# https://proceedings.neurips.cc/paper_files/paper/2022/file/ae0cba715b60c4052359b3d52a2cff7f-Paper-Conference.pdf
for i, block in enumerate(model.blocks):
if isinstance(block, Block):
ls = 1 / math.sqrt(i + 1)
block.ls1.gamma.data.fill_(ls)
block.ls2.gamma.data.fill_(ls)
elif isinstance(block, PatchMerging):
ls = 1 / math.sqrt(i + 1)
block.reduction.weight.data.fill_(ls)
def _create_vision_transformer(name, pretrained=False, **kwargs):
model = build_model_with_cfg(RADIO1D, name, pretrained=pretrained, **kwargs)
if not pretrained:
magneto_init(model)
if kwargs.get('init_values', None) == -1234:
_init_layerscale(model)
return model