Spaces:
Runtime error
Runtime error
| # Copyright 2024 EPFL and Apple Inc. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
import copy
import math
from functools import partial
from typing import Optional, Type, Union

import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn

from fourm.data.modality_info import MODALITY_INFO
from fourm.utils.timm.registry import register_model

from .encoder_embeddings import ImageEncoderEmbedding
from .fm_utils import Block, LayerNorm
# Names of the model-builder functions this module exports, grouped by the
# feedforward / normalization variant they configure.
__all__ = [
    # GELU feedforward
    'fm_vit_tiny_6e_gelu',
    'fm_vit_small_8e_gelu',
    'fm_vit_base_12e_gelu',
    'fm_vit_large_24e_gelu',
    'fm_vit_xlarge_24e_gelu',
    # SwiGLU feedforward, no bias terms
    'fm_vit_tiny_6e_swiglu_nobias',
    'fm_vit_small_8e_swiglu_nobias',
    'fm_vit_base_12e_swiglu_nobias',
    'fm_vit_large_24e_swiglu_nobias',
    'fm_vit_xlarge_24e_swiglu_nobias',
    # SwiGLU feedforward with query/key normalization, no bias terms
    'fm_vit_base_12e_swiglu_qknorm_nobias',
    'fm_vit_large_24e_swiglu_qknorm_nobias',
    'fm_vit_xlarge_24e_swiglu_qknorm_nobias',
]
class FourMViT(nn.Module):
    """Modified 4M model, adapted to behave as a simple RGB-only ViT.

    Args:
        img_size (int): Input image size.
        patch_size (int): Patch size.
        in_chans (int): Number of input image channels.
        dim (int): Patch embedding dimension.
        encoder_depth (int): Depth of ViT / number of encoder blocks.
        num_heads (int): Number of attention heads in each ViT block.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool): If True, add a learnable bias to query, key, value.
        proj_bias (bool): If True, adds a bias to the attention out proj layer.
        mlp_bias (bool): If True, adds a learnable bias for the feedforward.
        drop_path_rate (float): Stochastic depth rate. The per-block rate grows
            linearly from 0 at the first block to this value at the last one.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate.
        act_layer (Type[nn.Module]): Activation layer class (not an instance).
        norm_layer (Union[partial, Type[nn.Module]]): Normalization layer class
            or factory; it is called with ``dim`` to build each norm layer.
        gated_mlp (bool): If True, makes the feedforward gated (e.g., for SwiGLU)
        qk_norm (bool): If True, normalizes the query and keys (as in ViT-22B)
        encoder_norm (bool): If True, adds a norm layer after the last encoder block.
        output_head (Optional[nn.Module]): Optional output head after the encoder.
            If it exposes an ``init`` method, that method is called with ``dim``
            once the rest of the model has been initialized.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        dim: int = 768,
        encoder_depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        mlp_bias: bool = True,
        drop_path_rate: float = 0.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Union[partial, Type[nn.Module]] = partial(LayerNorm, eps=1e-6),
        gated_mlp: bool = False,  # Make the feedforward gated for e.g. SwiGLU
        qk_norm: bool = False,
        encoder_norm: bool = True,
        output_head: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.img_size = img_size
        self.init_std = 0.02

        rgb_embedding = ImageEncoderEmbedding(
            num_channels=in_chans, patch_size=patch_size,
            dim_tokens=dim, sincos_pos_emb=True, image_size=img_size,
        )
        self.num_patches = rgb_embedding.num_patches
        # Stored under a "rgb@<img_size>" key; forward() looks it up with the same key.
        self.encoder_embeddings = nn.ModuleDict({f"rgb@{img_size}": rgb_embedding})

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, encoder_depth)]

        self.encoder = nn.ModuleList([
            Block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, mlp_bias=mlp_bias,
                  drop_path=dpr[i], drop=drop_rate, attn_drop=attn_drop_rate, act_layer=act_layer, norm_layer=norm_layer,
                  gated_mlp=gated_mlp, qk_norm=qk_norm)
            for i in range(encoder_depth)
        ])

        self.encoder_norm = norm_layer(dim) if encoder_norm else nn.Identity()

        # Weight init
        self.init_weights()

        # Classification head is initialized after init_weights() to allow for special init scale
        if output_head is not None:
            self.output_head = output_head
            if hasattr(self.output_head, 'init'):
                self.output_head.init(dim)
        else:
            self.output_head = nn.Identity()

    def init_weights(self):
        """Weight initialization following MAE's initialization scheme"""
        for name, m in self.named_modules():
            # Skipping tokenizers to avoid reinitializing them
            if "tokenizer" in name:
                continue

            # Linear
            if isinstance(m, nn.Linear):
                if 'qkv' in name:
                    # treat the weights of Q, K, V separately
                    val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)
                elif 'kv' in name:
                    # treat the weights of K, V separately
                    val = math.sqrt(6. / float(m.weight.shape[0] // 2 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)
                else:
                    nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            # LayerNorm
            elif isinstance(m, (nn.LayerNorm, LayerNorm)):
                # Affine parameters can be absent (e.g. nn.LayerNorm built with
                # bias=False or elementwise_affine=False), so only reset the
                # ones that exist.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1.0)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            # Embedding
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=self.init_std)
            # Conv2d
            elif isinstance(m, nn.Conv2d):
                if '.proj' in name:
                    # From MAE, initialize projection like nn.Linear (instead of nn.Conv2d)
                    w = m.weight.data
                    nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    def get_num_layers_encoder(self):
        """Return the number of encoder blocks."""
        return len(self.encoder)

    def get_num_layers(self):
        """Return the total number of layers (this model only has an encoder)."""
        return self.get_num_layers_encoder()

    def no_weight_decay(self):
        """Collect parameter names that the embedding modules exclude from weight decay.

        Returns:
            set: Names reported by each embedding module's ``no_weight_decay()``,
            prefixed with ``encoder_embeddings.<modality>.``. Embedding modules
            without such a method contribute nothing.
        """
        no_wd_set = set()
        for mod, emb_module in self.encoder_embeddings.items():
            if hasattr(emb_module, 'no_weight_decay'):
                no_wd_set |= {
                    f'encoder_embeddings.{mod}.{name}'
                    for name in emb_module.no_weight_decay()
                }
        return no_wd_set

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor. Shape (B, C, H, W)

        Returns:
            torch.Tensor: Whatever ``output_head`` produces from the normalized
            encoder tokens. With the default ``nn.Identity`` head this is the
            token tensor itself, shape (B, N, D) with N = num_patches.
        """
        rgb_dict = {'tensor': x}
        rgb_dict = self.encoder_embeddings[f'rgb@{self.img_size}'](rgb_dict)

        # Add embeddings to patchified RGB image
        x = rgb_dict['x'] + rgb_dict['emb']  # Shape: (B, N, D) with N = num_patches

        for blk in self.encoder:
            x = blk(x)

        x = self.encoder_norm(x)  # Shape: (B, N, D)

        out = self.output_head(x)

        return out

    def _set_encoder_requires_grad(self, requires_grad: bool, include_embeddings: bool) -> None:
        """Set ``requires_grad`` on the encoder blocks, the final norm and, optionally, the embeddings."""
        for param in self.encoder.parameters():
            param.requires_grad = requires_grad
        for param in self.encoder_norm.parameters():
            param.requires_grad = requires_grad
        if include_embeddings:
            for param in self.encoder_embeddings.parameters():
                param.requires_grad = requires_grad

    def freeze_encoder(self, freeze_embeddings=True):
        """Disable gradients for the encoder blocks and final norm.

        Args:
            freeze_embeddings (bool): If True, also freeze the encoder embeddings.
                The output head is never touched.
        """
        self._set_encoder_requires_grad(False, include_embeddings=freeze_embeddings)

    def unfreeze_encoder(self, unfreeze_embeddings=True):
        """Re-enable gradients for the encoder blocks and final norm.

        Args:
            unfreeze_embeddings (bool): If True, also unfreeze the encoder embeddings.
                The output head is never touched.
        """
        self._set_encoder_requires_grad(True, include_embeddings=unfreeze_embeddings)
| ################################################ | |
| # Wrapper for easy loading with Huggingface Hub | |
class FMViT(FourMViT, PyTorchModelHubMixin):
    """FourMViT variant that is built from a config dict, for Huggingface Hub loading.

    Args:
        config (dict): Dictionary containing the model and modality configuration,
            used for loading from Huggingface Hub. It is deep-copied, so the
            caller's dictionary is left unmodified.
        output_head (nn.Module): Optional output head.
    """

    def __init__(self, config: dict, output_head: Optional[nn.Module] = None):
        vit_kwargs = copy.deepcopy(config)

        # Turn the serializable config entries into the objects FourMViT expects.
        vit_kwargs['norm_layer'] = partial(LayerNorm, eps=1e-6, bias=vit_kwargs['norm_bias'])
        vit_kwargs['act_layer'] = getattr(torch.nn, vit_kwargs['act_layer'])

        # Patch size and channel count come from the RGB modality entry when it
        # defines them; otherwise the config's patch size / 3 channels are used.
        img_size = vit_kwargs['image_size']
        vit_kwargs['img_size'] = img_size
        rgb_info = MODALITY_INFO[f'rgb@{img_size}']
        vit_kwargs['patch_size'] = rgb_info.get('patch_size', vit_kwargs['patch_size'])
        vit_kwargs['in_chans'] = rgb_info.get('num_channels', 3)

        # Drop config entries that FourMViT.__init__ does not accept.
        unused_keys = (
            'image_size', 'norm_bias', 'domains_in', 'domains_out',
            'decoder_depth', 'share_modality_embeddings',
        )
        for unused_key in unused_keys:
            vit_kwargs.pop(unused_key, None)

        super().__init__(output_head=output_head, **vit_kwargs)
| ################################################ | |
| # Model definitions | |
| # GELU variants | |
def fm_vit_tiny_6e_gelu(**kwargs):
    """Build a FourMViT with 6 encoder blocks, dim 384 and 6 attention heads.

    QKV bias is enabled and ``nn.LayerNorm`` (eps=1e-6) is used for
    normalization; the activation is left at the ``FourMViT`` default.

    Args:
        **kwargs: Further keyword arguments forwarded to ``FourMViT``. A key
            that is already fixed here raises ``TypeError``.

    Returns:
        FourMViT: The constructed model.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return FourMViT(
        dim=384,
        encoder_depth=6,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
def fm_vit_small_8e_gelu(**kwargs):
    """Build a FourMViT with 8 encoder blocks, dim 512 and 8 attention heads.

    QKV bias is enabled and ``nn.LayerNorm`` (eps=1e-6) is used for
    normalization; the activation is left at the ``FourMViT`` default.

    Args:
        **kwargs: Further keyword arguments forwarded to ``FourMViT``. A key
            that is already fixed here raises ``TypeError``.

    Returns:
        FourMViT: The constructed model.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return FourMViT(
        dim=512,
        encoder_depth=8,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
def fm_vit_base_12e_gelu(**kwargs):
    """Build a FourMViT with 12 encoder blocks, dim 768 and 12 attention heads.

    QKV bias is enabled and ``nn.LayerNorm`` (eps=1e-6) is used for
    normalization; the activation is left at the ``FourMViT`` default.

    Args:
        **kwargs: Further keyword arguments forwarded to ``FourMViT``. A key
            that is already fixed here raises ``TypeError``.

    Returns:
        FourMViT: The constructed model.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return FourMViT(
        dim=768,
        encoder_depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
def fm_vit_large_24e_gelu(**kwargs):
    """Build a FourMViT with 24 encoder blocks, dim 1024 and 16 attention heads.

    QKV bias is enabled and ``nn.LayerNorm`` (eps=1e-6) is used for
    normalization; the activation is left at the ``FourMViT`` default.

    Args:
        **kwargs: Further keyword arguments forwarded to ``FourMViT``. A key
            that is already fixed here raises ``TypeError``.

    Returns:
        FourMViT: The constructed model.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return FourMViT(
        dim=1024,
        encoder_depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
def fm_vit_xlarge_24e_gelu(**kwargs):
    """Build a FourMViT with 24 encoder blocks, dim 2048 and 32 attention heads.

    QKV bias is enabled and ``nn.LayerNorm`` (eps=1e-6) is used for
    normalization; the activation is left at the ``FourMViT`` default.

    Args:
        **kwargs: Further keyword arguments forwarded to ``FourMViT``. A key
            that is already fixed here raises ``TypeError``.

    Returns:
        FourMViT: The constructed model.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return FourMViT(
        dim=2048,
        encoder_depth=24,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )
| # SwiGLU variants | |
def fm_vit_tiny_6e_swiglu_nobias(**kwargs):
    """Build a bias-free, gated-MLP FourMViT: 6 encoder blocks, dim 384, 6 heads.

    QKV, attention-projection and MLP biases are all disabled, the project
    ``LayerNorm`` is created with ``bias=False`` (eps=1e-6), and the
    feedforward is gated with a SiLU activation.

    Args:
        **kwargs: Further keyword arguments forwarded to ``FourMViT``. A key
            that is already fixed here raises ``TypeError``.

    Returns:
        FourMViT: The constructed model.
    """
    bias_free_norm = partial(LayerNorm, eps=1e-6, bias=False)
    return FourMViT(
        dim=384,
        encoder_depth=6,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        act_layer=nn.SiLU,
        gated_mlp=True,
        norm_layer=bias_free_norm,
        **kwargs,
    )
def fm_vit_small_8e_swiglu_nobias(**kwargs):
    """Build a bias-free, gated-MLP FourMViT: 8 encoder blocks, dim 512, 8 heads.

    QKV, attention-projection and MLP biases are all disabled, the project
    ``LayerNorm`` is created with ``bias=False`` (eps=1e-6), and the
    feedforward is gated with a SiLU activation.

    Args:
        **kwargs: Further keyword arguments forwarded to ``FourMViT``. A key
            that is already fixed here raises ``TypeError``.

    Returns:
        FourMViT: The constructed model.
    """
    bias_free_norm = partial(LayerNorm, eps=1e-6, bias=False)
    return FourMViT(
        dim=512,
        encoder_depth=8,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        act_layer=nn.SiLU,
        gated_mlp=True,
        norm_layer=bias_free_norm,
        **kwargs,
    )
def fm_vit_base_12e_swiglu_nobias(**kwargs):
    """Build a bias-free, gated-MLP FourMViT: 12 encoder blocks, dim 768, 12 heads.

    QKV, attention-projection and MLP biases are all disabled, the project
    ``LayerNorm`` is created with ``bias=False`` (eps=1e-6), and the
    feedforward is gated with a SiLU activation.

    Args:
        **kwargs: Further keyword arguments forwarded to ``FourMViT``. A key
            that is already fixed here raises ``TypeError``.

    Returns:
        FourMViT: The constructed model.
    """
    bias_free_norm = partial(LayerNorm, eps=1e-6, bias=False)
    return FourMViT(
        dim=768,
        encoder_depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        act_layer=nn.SiLU,
        gated_mlp=True,
        norm_layer=bias_free_norm,
        **kwargs,
    )
def fm_vit_large_24e_swiglu_nobias(**kwargs):
    """Build a bias-free, gated-MLP FourMViT: 24 encoder blocks, dim 1024, 16 heads.

    QKV, attention-projection and MLP biases are all disabled, the project
    ``LayerNorm`` is created with ``bias=False`` (eps=1e-6), and the
    feedforward is gated with a SiLU activation.

    Args:
        **kwargs: Further keyword arguments forwarded to ``FourMViT``. A key
            that is already fixed here raises ``TypeError``.

    Returns:
        FourMViT: The constructed model.
    """
    bias_free_norm = partial(LayerNorm, eps=1e-6, bias=False)
    return FourMViT(
        dim=1024,
        encoder_depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        act_layer=nn.SiLU,
        gated_mlp=True,
        norm_layer=bias_free_norm,
        **kwargs,
    )
def fm_vit_xlarge_24e_swiglu_nobias(**kwargs):
    """Build a bias-free, gated-MLP FourMViT: 24 encoder blocks, dim 2048, 32 heads.

    QKV, attention-projection and MLP biases are all disabled, the project
    ``LayerNorm`` is created with ``bias=False`` (eps=1e-6), and the
    feedforward is gated with a SiLU activation.

    Args:
        **kwargs: Further keyword arguments forwarded to ``FourMViT``. A key
            that is already fixed here raises ``TypeError``.

    Returns:
        FourMViT: The constructed model.
    """
    bias_free_norm = partial(LayerNorm, eps=1e-6, bias=False)
    return FourMViT(
        dim=2048,
        encoder_depth=24,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        act_layer=nn.SiLU,
        gated_mlp=True,
        norm_layer=bias_free_norm,
        **kwargs,
    )
| # SwiGLU + QKNorm variants | |
def fm_vit_base_12e_swiglu_qknorm_nobias(**kwargs):
    """Build a bias-free, gated-MLP, QK-normalized FourMViT: 12 blocks, dim 768, 12 heads.

    QKV, attention-projection and MLP biases are all disabled, the project
    ``LayerNorm`` is created with ``bias=False`` (eps=1e-6), the feedforward
    is gated with a SiLU activation, and ``qk_norm`` is switched on.

    Args:
        **kwargs: Further keyword arguments forwarded to ``FourMViT``. A key
            that is already fixed here raises ``TypeError``.

    Returns:
        FourMViT: The constructed model.
    """
    bias_free_norm = partial(LayerNorm, eps=1e-6, bias=False)
    return FourMViT(
        dim=768,
        encoder_depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        act_layer=nn.SiLU,
        gated_mlp=True,
        qk_norm=True,
        norm_layer=bias_free_norm,
        **kwargs,
    )
def fm_vit_large_24e_swiglu_qknorm_nobias(**kwargs):
    """Build a bias-free, gated-MLP, QK-normalized FourMViT: 24 blocks, dim 1024, 16 heads.

    QKV, attention-projection and MLP biases are all disabled, the project
    ``LayerNorm`` is created with ``bias=False`` (eps=1e-6), the feedforward
    is gated with a SiLU activation, and ``qk_norm`` is switched on.

    Args:
        **kwargs: Further keyword arguments forwarded to ``FourMViT``. A key
            that is already fixed here raises ``TypeError``.

    Returns:
        FourMViT: The constructed model.
    """
    bias_free_norm = partial(LayerNorm, eps=1e-6, bias=False)
    return FourMViT(
        dim=1024,
        encoder_depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        act_layer=nn.SiLU,
        gated_mlp=True,
        qk_norm=True,
        norm_layer=bias_free_norm,
        **kwargs,
    )
def fm_vit_xlarge_24e_swiglu_qknorm_nobias(**kwargs):
    """Build a bias-free, gated-MLP, QK-normalized FourMViT: 24 blocks, dim 2048, 32 heads.

    QKV, attention-projection and MLP biases are all disabled, the project
    ``LayerNorm`` is created with ``bias=False`` (eps=1e-6), the feedforward
    is gated with a SiLU activation, and ``qk_norm`` is switched on.

    Args:
        **kwargs: Further keyword arguments forwarded to ``FourMViT``. A key
            that is already fixed here raises ``TypeError``.

    Returns:
        FourMViT: The constructed model.
    """
    bias_free_norm = partial(LayerNorm, eps=1e-6, bias=False)
    return FourMViT(
        dim=2048,
        encoder_depth=24,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        act_layer=nn.SiLU,
        gated_mlp=True,
        qk_norm=True,
        norm_layer=bias_free_norm,
        **kwargs,
    )