# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import math
from typing import Dict, Optional

import torch
from torch import nn

from einops import rearrange
from timm.models.vision_transformer import Block

from .enable_spectral_reparam import disable_spectral_reparam, enable_spectral_reparam


class MLPBase(nn.Module):
    """Base class for the MLP heads defined in this module.

    It only records two capability flags as attributes. Nothing in this class
    reads them, so their meaning is defined by the calling code. Presumably they
    say whether the head needs both the summary and the spatial features, and
    whether it consumes both itself (confirm against callers).

    Args:
        requires_summary_and_spatial: Stored as
            ``self.requires_summary_and_spatial``.
        handles_summary_and_spatial: Stored as
            ``self.handles_summary_and_spatial``. May only be ``True`` when
            ``requires_summary_and_spatial`` is also ``True``.

    Raises:
        ValueError: If ``handles_summary_and_spatial`` is set without
            ``requires_summary_and_spatial``.
    """

    def __init__(
        self, requires_summary_and_spatial: bool, handles_summary_and_spatial: bool = False
    ) -> None:
        super().__init__()
        # Explicit raise rather than ``assert``: asserts are stripped under ``python -O``.
        if handles_summary_and_spatial and not requires_summary_and_spatial:
            raise ValueError("If handles summary and spatial, must require it too!")
        self.requires_summary_and_spatial = requires_summary_and_spatial
        self.handles_summary_and_spatial = handles_summary_and_spatial
class MLP(MLPBase):
    """Version ``'v1'`` head: Linear -> LayerNorm -> ReLU -> [inner layers] -> Linear.

    Args:
        input_size: Size of the last dimension of the input.
        hidden_size: Width of the hidden layers.
        output_size: Size of the last dimension of the output.
        num_inner: Number of extra (Linear, LayerNorm, ReLU) hidden layers.
        device: Device on which the parameters are created.
        **kwargs: Ignored. Accepts e.g. ``from_config`` as passed by
            ``create_mlp_from_config``.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int,
                 num_inner: int = 0, device: torch.device = None, **kwargs):
        super().__init__(requires_summary_and_spatial=False)
        self.fc1 = nn.Linear(input_size, hidden_size, device=device)
        self.norm = nn.LayerNorm(hidden_size, device=device)
        self.relu = nn.ReLU()

        # All inner layers go into ONE flat Sequential. The i-th layer's Linear is
        # therefore registered as ``inner.{3 * i}`` in the state dict, which
        # get_mlp_info_from_state relies on to recover ``num_inner``.
        inner = []
        for _ in range(num_inner):
            inner.extend([
                nn.Linear(hidden_size, hidden_size, device=device),
                nn.LayerNorm(hidden_size, device=device),
                nn.ReLU(),
            ])
        self.inner = nn.Sequential(*inner) if inner else nn.Identity()

        self.fc2 = nn.Linear(hidden_size, output_size, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the head to ``x``, whose last dimension must equal ``input_size``."""
        x = self.fc1(x)
        x = self.norm(x)
        x = self.relu(x)
        x = self.inner(x)
        x = self.fc2(x)
        return x


class MLP2(MLPBase):
    """Version ``'v2'`` head: [pre-norm] -> Linear -> residual blocks -> final projection.

    Each inner block is LayerNorm -> GELU -> Linear, applied residually
    (``x = x + block(x)``). The final projection is LayerNorm -> GELU -> Linear.

    When ``upsample_factor > 1``, ``forward`` unpacks the output channels of every
    token into ``upsample_factor ** 2`` tokens of
    ``output_size // upsample_factor ** 2`` channels each.

    Args:
        input_size: Size of the last dimension of the input.
        hidden_size: Width of the hidden layers.
        output_size: Size of the last dimension produced by the final Linear.
            With ``upsample_factor > 1`` it should be divisible by
            ``upsample_factor ** 2``; otherwise the rearrange in ``forward`` fails.
        num_inner: Number of residual blocks.
        pre_norm: If ``True``, apply LayerNorm -> GELU to the input first.
        device: Device on which the parameters are created.
        upsample_factor: Token-grid upsampling factor per spatial dimension.
        upsample_rank: Not used by this class.
        from_config: Not used by this class.
        **kwargs: Ignored.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int,
                 num_inner: int = 0, pre_norm: bool = False, device: torch.device = None,
                 upsample_factor: int = 1, upsample_rank: int = None,
                 from_config: bool = False, **kwargs):
        super().__init__(requires_summary_and_spatial=False)

        # Created on ``device`` like every other layer of this head.
        self.pre_norm = nn.Sequential(
            nn.LayerNorm(input_size, device=device),
            nn.GELU(),
        ) if pre_norm else nn.Identity()

        self.upsample_factor = upsample_factor
        # Channels per output token once the upsampled sub-tokens are unpacked.
        self._real_output_dim = output_size // (upsample_factor ** 2)

        self.fc1 = nn.Linear(input_size, hidden_size, device=device)

        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(hidden_size, device=device),
                nn.GELU(),
                nn.Linear(hidden_size, hidden_size, device=device),
            )
            for _ in range(num_inner)
        ])

        self.final = nn.Sequential(
            nn.LayerNorm(hidden_size, device=device),
            nn.GELU(),
            nn.Linear(hidden_size, output_size, device=device),
        )

    def forward(self, x: torch.Tensor, images: Optional[torch.Tensor] = None,
                patch_size: Optional[int] = None) -> torch.Tensor:
        """Project the tokens ``x`` and, if upsampling, unpack sub-tokens.

        Args:
            x: Token features whose last dimension equals ``input_size``. When
                ``upsample_factor > 1`` the rearrange pattern requires rank 3,
                ``(b, h * w, C)``, with the ``h * w`` tokens flattened h-major.
            images: Only read when ``upsample_factor > 1``. Its last two
                dimensions divided by ``patch_size`` give the token grid ``(h, w)``.
            patch_size: Only read when ``upsample_factor > 1``.

        Returns:
            The output of ``self.final``. When upsampling, it is rearranged to
            ``(b, h * u * w * u, output_size // u ** 2)`` with
            ``u = upsample_factor``, i.e. an h-major ``(h * u, w * u)`` grid.

        Raises:
            ValueError: If ``upsample_factor > 1`` and ``images`` or
                ``patch_size`` is ``None``.
        """
        x = self.pre_norm(x)
        x = self.fc1(x)
        for block in self.blocks:
            x = x + block(x)
        x = self.final(x)

        if self.upsample_factor > 1:
            if images is None:
                raise ValueError('`images` cannot be `None` when the head\'s `upsample_factor > 1`!')
            if patch_size is None:
                raise ValueError('`patch_size` cannot be `None` when the head\'s `upsample_factor > 1`!')
            h, w = (d // patch_size for d in images.shape[-2:])
            x = rearrange(x, 'b (h w) (u1 u2 c) -> b (h u1 w u2) c',
                          h=h, w=w, u1=self.upsample_factor, u2=self.upsample_factor,
                          c=self._real_output_dim)

        return x


class AttnFDHead(MLPBase):
    """Version ``'attn'`` head: two timm ViT ``Block``s followed by an ``MLP2``.

    The attention blocks use ``num_heads=16`` and ``init_values=1e-5``.
    ``input_size`` presumably has to be divisible by 16 for timm's attention
    (TODO confirm). The trailing ``MLP2`` is always built with ``num_inner=0``,
    so the ``num_inner`` argument here is ignored.

    Args:
        input_size: Size of the last dimension of the input.
        hidden_size: Hidden width of the trailing ``MLP2``.
        output_size: Output width of the trailing ``MLP2``.
        num_inner: Ignored (kept for a uniform factory signature).
        pre_norm: Forwarded to ``MLP2``.
        device: Device for all parameters, including the attention blocks.
        upsample_factor: Forwarded to ``MLP2``.
        upsample_rank: Forwarded to ``MLP2``.
        **kwargs: Forwarded to ``MLP2``, which ignores unknown keys. This lets
            the head accept arguments meant for other head versions
            (e.g. ``teacher_summary_idxs``).
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        num_inner: int = 0,
        pre_norm: bool = False,
        device: torch.device = None,
        upsample_factor: int = 1,
        upsample_rank: int = 0,
        **kwargs
    ) -> None:
        super().__init__(requires_summary_and_spatial=False)

        self.blocks = nn.Sequential(*[
            Block(input_size, num_heads=16, init_values=1e-5)
            for _ in range(2)
        ])
        if device is not None:
            # ``Block`` is constructed without a device argument. Move it so it
            # lives on the same device as the MLP2 layers built below.
            self.blocks.to(device)

        self.mlp = MLP2(input_size, hidden_size, output_size, num_inner=0, pre_norm=pre_norm,
                        device=device, upsample_factor=upsample_factor,
                        upsample_rank=upsample_rank, **kwargs)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Run the attention blocks, then the ``MLP2`` head.

        Only ``images`` and ``patch_size`` are read from ``kwargs``. They are
        forwarded to ``MLP2.forward``, which needs them when
        ``upsample_factor > 1``. Any other keyword arguments are ignored.
        """
        x = self.blocks(x)
        return self.mlp(x, images=kwargs.get('images'), patch_size=kwargs.get('patch_size'))


# Head classes by version string, used by ``create_mlp_from_config``.
MLP_SUMMARY_FACTORY = {
    'v1': MLP,
    'v2': MLP2,
}

MLP_FD_FACTORY = {
    'v1': MLP,
    'v2': MLP2,
    'attn': AttnFDHead,
}


def strip_prefix(state: Dict[str, torch.Tensor], prefix: str):
    """Return the entries of ``state`` whose keys start with ``prefix``, with the prefix removed.

    Keys that do not start with ``prefix`` are dropped. An empty ``prefix``
    yields a shallow copy of ``state``.
    """
    return {
        k[len(prefix):]: v
        for k, v in state.items()
        if k.startswith(prefix)
    }


def _count_indexed_modules(state: Dict[str, torch.Tensor], container: str, stride: int = 1) -> int:
    """Count consecutive ``i = 0, 1, ...`` for which some key starts with ``f'{container}.{i * stride}.'``."""
    count = 0
    # The trailing '.' keeps e.g. 'inner.3.' from matching 'inner.30.'.
    while any(k.startswith(f'{container}.{count * stride}.') for k in state):
        count += 1
    return count


def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False):
    """Infer a head's constructor sizes from its state dict.

    Args:
        version: ``'v1'``, ``'v2'`` or ``'attn'``.
        state: State dict containing the head's parameters.
        prefix: Key prefix to strip before lookup; keys without it are ignored.
        spectral_weights: If ``True``, Linear weights are read from
            ``<layer>.parametrizations.weight.original`` instead of
            ``<layer>.weight``.

    Returns:
        ``(input_dim, hidden_dim, output_dim, num_inner)``.

    Raises:
        ValueError: For an unsupported ``version``.
        KeyError: If the expected Linear weights are missing from ``state``.
    """
    state = strip_prefix(state, prefix)
    weight_suffix = 'weight' if not spectral_weights else 'parametrizations.weight.original'

    if version == 'v1':
        hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
        output_dim = state[f'fc2.{weight_suffix}'].shape[0]
        # MLP keeps its inner layers in one flat Sequential of (Linear, LayerNorm,
        # ReLU) triples, so the i-th inner Linear is registered as ``inner.{3 * i}``.
        num_inner = _count_indexed_modules(state, 'inner', stride=3)
    elif version == 'v2':
        hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
        output_dim = state[f'final.2.{weight_suffix}'].shape[0]
        num_inner = _count_indexed_modules(state, 'blocks')
    elif version == 'attn':
        hidden_dim, input_dim = state[f'mlp.fc1.{weight_suffix}'].shape
        output_dim = state[f'mlp.final.2.{weight_suffix}'].shape[0]
        # AttnFDHead always builds its MLP2 with num_inner=0.
        num_inner = 0
    else:
        raise ValueError(f'Unsupported MLP version: {version}')

    return input_dim, hidden_dim, output_dim, num_inner


def create_mlp_from_config(version: str, input_dim: int, hidden_dim: int, output_dim: int, num_inner: int, is_summary: bool = True, **kwargs):
    """Instantiate the head registered for ``version``.

    Args:
        version: Key into ``MLP_SUMMARY_FACTORY`` (if ``is_summary``) or
            ``MLP_FD_FACTORY``.
        input_dim, hidden_dim, output_dim, num_inner: Passed positionally to the
            head constructor.
        is_summary: Selects which factory table to use.
        **kwargs: Extra constructor arguments; ``from_config=True`` is always added.

    Raises:
        KeyError: If ``version`` is not registered in the selected factory.
    """
    factory = MLP_SUMMARY_FACTORY if is_summary else MLP_FD_FACTORY
    try:
        mlp_cls = factory[version]
    except KeyError:
        kind = 'summary' if is_summary else 'FD'
        raise KeyError(
            f'Unsupported {kind} MLP version: {version!r} (available: {sorted(factory)})'
        ) from None
    ret: nn.Module = mlp_cls(input_dim, hidden_dim, output_dim, num_inner, from_config=True, **kwargs)
    return ret


def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False, is_summary: bool = True, **kwargs):
    """Build a head whose sizes are inferred from ``state``, then load ``state`` into it.

    Args:
        version: Head version (see ``get_mlp_info_from_state``).
        state: State dict containing the head's parameters.
        prefix: Key prefix to strip; keys without it are ignored.
        spectral_weights: If ``True``, ``enable_spectral_reparam`` is applied
            before loading (with ``state`` as ``state_dict_guidance``) and
            ``disable_spectral_reparam`` after. What those helpers do is
            defined elsewhere.
        is_summary: Selects the factory table (see ``create_mlp_from_config``).
        **kwargs: Extra constructor arguments for the head.

    Returns:
        The constructed module with ``state`` loaded via ``load_state_dict``.
    """
    state = strip_prefix(state, prefix)
    input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state(
        version, state, spectral_weights=spectral_weights)
    ret: nn.Module = create_mlp_from_config(
        version, input_dim, hidden_dim, output_dim, num_inner, is_summary=is_summary, **kwargs)

    if spectral_weights:
        enable_spectral_reparam(ret, init_norm_to_current=False, state_dict_guidance=state)
    ret.load_state_dict(state)
    if spectral_weights:
        disable_spectral_reparam(ret)

    return ret