Spaces:
Runtime error
Runtime error
| # Copyright 2024 EPFL and Apple Inc. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import math | |
| import random | |
| import copy | |
| from functools import partial | |
| from typing import Any, Dict, Optional, Tuple, Union | |
| import torch | |
| from einops import rearrange, repeat | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from fourm.utils.timm.registry import register_model | |
| from huggingface_hub import PyTorchModelHubMixin | |
| from .fm_utils import Block, DecoderBlock, LayerNorm | |
| from fourm.data.modality_info import MODALITY_INFO | |
# Model definitions
# Public API of this module: the names of the model factory functions
# (each registered via the timm-style @register_model decorator below).
__all__ = [
    # GELU models
    'fm_tiny_6e_6d_gelu',
    'fm_small_8e_8d_gelu',
    'fm_base_12e_12d_gelu',
    'fm_large_24e_24d_gelu',
    'fm_xlarge_24e_24d_gelu',
    # SwiGLU models
    'fm_tiny_6e_6d_swiglu_nobias',
    'fm_small_8e_8d_swiglu_nobias',
    'fm_base_12e_12d_swiglu_nobias',
    'fm_large_24e_24d_swiglu_nobias',
    'fm_xlarge_24e_24d_swiglu_nobias',
    # SwiGLU + QKNorm models
    'fm_base_12e_12d_swiglu_qknorm_nobias',
    'fm_large_24e_24d_swiglu_qknorm_nobias',
    'fm_xlarge_24e_24d_swiglu_qknorm_nobias',
]
| class FourM(nn.Module): | |
| """4M model. | |
| Args: | |
| encoder_embeddings: Dict of encoder embedding modules. | |
| decoder_embeddings: Dict of decoder embedding modules. | |
| modality_info: Dict containing modality information. | |
| dim: Embedding dimension. | |
| encoder_depth: Number of encoder blocks. | |
| decoder_depth: Number of decoder blocks. | |
| num_heads: Number of attention heads. | |
| mlp_ratio: Ratio of mlp hidden dim to embedding dim. | |
| qkv_bias: If True, add a learnable bias to query, key, value projections. | |
| proj_bias: If True, add a learnable bias to the last projection of the attention block. | |
| mlp_bias: If True, add a learnable bias to linear layers in the MLP / feed-forward. | |
| drop_path_rate_encoder: Stochastic depth rate for encoder. | |
| drop_path_rate_decoder: Stochastic depth rate for decoder. | |
| shared_drop_path: If True, shares drop path between encoder and decoder. | |
| act_layer: Activation layer to be used. | |
| norm_layer: Normalization layer to be used. | |
| gated_mlp: If True, make the feedforward gated (e.g., SwiGLU). | |
| qk_norm: If True, applies normalization to queries and keys (QKNorm). | |
| decoder_causal_mask: If True, decoder will use a causal mask for all tokens. | |
| decoder_sep_mask: If True, decoder attention is restricted to within each modality only. | |
| num_register_tokens: Number of register tokens. | |
| use_act_checkpoint: If True, use activation checkpoint for each block. | |
| """ | |
    def __init__(self,
                 encoder_embeddings: Dict[str, nn.Module],
                 decoder_embeddings: Dict[str, nn.Module],
                 modality_info: Dict[str, Any],
                 dim: int = 768,
                 encoder_depth: int = 12,
                 decoder_depth: int = 12,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.0,
                 qkv_bias: bool = True,
                 proj_bias: bool = True,
                 mlp_bias: bool = True,
                 drop_path_rate_encoder: float = 0.0,
                 drop_path_rate_decoder: float = 0.0,
                 shared_drop_path: bool = False,
                 act_layer: nn.Module = nn.GELU,
                 norm_layer: Union[partial, nn.Module] = partial(LayerNorm, eps=1e-6),
                 gated_mlp: bool = False,  # Make the feedforward gated for e.g. SwiGLU
                 qk_norm: bool = False,
                 decoder_causal_mask: bool = False,
                 decoder_sep_mask: bool = True,
                 num_register_tokens: int = 0,
                 use_act_checkpoint: bool = False,
                 share_modality_embeddings: bool = True,
                 ):
        """Build the 4M encoder-decoder. See the class docstring for argument semantics."""
        super().__init__()
        self.modality_info = modality_info
        self.dim = dim
        self.decoder_causal_mask = decoder_causal_mask
        self.decoder_sep_mask = decoder_sep_mask
        # Std-dev used for the normal initializations below (MAE convention).
        self.init_std = 0.02
        self.use_act_checkpoint = use_act_checkpoint
        self.num_register_tokens = num_register_tokens

        # Encoder embeddings & init
        self.encoder_modalities = set(encoder_embeddings.keys())
        for emb in encoder_embeddings.values():
            # Each embedding module finalizes its parameters once the model dim is known.
            emb.init(dim_tokens=dim, init_std=self.init_std)
        self.encoder_embeddings = nn.ModuleDict(encoder_embeddings)

        # Decoder embeddings & init
        self.decoder_modalities = set(decoder_embeddings.keys())
        for emb in decoder_embeddings.values():
            emb.init(dim_tokens=dim, init_std=self.init_std)
        self.decoder_embeddings = nn.ModuleDict(decoder_embeddings)

        # Share modality embeddings across the encoder and decoder embedding modules
        if share_modality_embeddings:
            self.share_modality_embeddings()

        ## Transformer encoder
        if shared_drop_path:
            # Drop-path rates increase over the combined encoder+decoder depth;
            # the encoder takes the first encoder_depth entries of the schedule.
            dpr_encoder = [x.item() for x in torch.linspace(0, drop_path_rate_encoder, encoder_depth + decoder_depth)][:encoder_depth]
        else:
            dpr_encoder = [x.item() for x in torch.linspace(0, drop_path_rate_encoder, encoder_depth)]  # stochastic depth decay rule
        self.encoder = nn.ModuleList([
            Block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, mlp_bias=mlp_bias,
                  drop_path=dpr_encoder[i], act_layer=act_layer, norm_layer=norm_layer, gated_mlp=gated_mlp, qk_norm=qk_norm)
            for i in range(encoder_depth)
        ])
        self.encoder_norm = norm_layer(dim)

        ## Transformer decoder
        if shared_drop_path:
            # The decoder continues the shared schedule after the encoder's entries.
            dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate_decoder, encoder_depth + decoder_depth)][encoder_depth:]
        else:
            dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate_decoder, decoder_depth)]  # stochastic depth decay rule
        # Projection of encoder tokens before adding the embeddings again
        self.decoder_proj_context = nn.Linear(dim, dim)
        self.decoder = nn.ModuleList([
            DecoderBlock(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, mlp_bias=mlp_bias,
                         drop_path=dpr_decoder[i], act_layer=act_layer, norm_layer=norm_layer, gated_mlp=gated_mlp, qk_norm=qk_norm)
            for i in range(decoder_depth)
        ])
        self.decoder_norm = norm_layer(dim)

        # Learned token substituted for masked-out decoder inputs (see cat_decoder_tensors).
        self.mask_token = nn.Parameter(torch.zeros(1, 1, dim))
        nn.init.normal_(self.mask_token, std=self.init_std)

        # Additional register tokens that can be used by the encoder during fine-tuning
        if self.num_register_tokens > 0:
            self.register_tokens = nn.Parameter(torch.zeros(1, self.num_register_tokens, dim))
            nn.init.normal_(self.register_tokens, std=self.init_std)
        else:
            self.register_tokens = None

        # Weight init
        self.init_weights()
| def share_modality_embeddings(self): | |
| """Share modality embeddings across the encoder and decoder embedding modules.""" | |
| shared_modalities = self.encoder_modalities & self.decoder_modalities | |
| for mod in shared_modalities: | |
| self.decoder_embeddings[mod].mod_emb = self.encoder_embeddings[mod].mod_emb | |
| def init_weights(self): | |
| """Weight initialization following MAE's initialization scheme""" | |
| for name, m in self.named_modules(): | |
| # Skipping tokenizers to avoid reinitializing them | |
| if "tokenizer" in name: | |
| continue | |
| # Linear | |
| elif isinstance(m, nn.Linear): | |
| if 'qkv' in name: | |
| # treat the weights of Q, K, V separately | |
| val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1])) | |
| nn.init.uniform_(m.weight, -val, val) | |
| elif 'kv' in name: | |
| # treat the weights of K, V separately | |
| val = math.sqrt(6. / float(m.weight.shape[0] // 2 + m.weight.shape[1])) | |
| nn.init.uniform_(m.weight, -val, val) | |
| else: | |
| nn.init.xavier_uniform_(m.weight) | |
| if isinstance(m, nn.Linear) and m.bias is not None: | |
| nn.init.constant_(m.bias, 0) | |
| # LayerNorm | |
| elif isinstance(m, nn.LayerNorm) or isinstance(m, LayerNorm): | |
| nn.init.constant_(m.weight, 1.0) | |
| if m.bias is not None: | |
| nn.init.constant_(m.bias, 0) | |
| # Embedding | |
| elif isinstance(m, nn.Embedding): | |
| nn.init.normal_(m.weight, std=self.init_std) | |
| # Conv2d | |
| elif isinstance(m, nn.Conv2d): | |
| if '.proj' in name: | |
| # From MAE, initialize projection like nn.Linear (instead of nn.Conv2d) | |
| w = m.weight.data | |
| nn.init.xavier_uniform_(w.view([w.shape[0], -1])) | |
| def get_num_layers_encoder(self): | |
| return len(self.encoder) | |
| def get_num_layers_decoder(self): | |
| return len(self.decoder) | |
| def get_num_layers(self): | |
| return self.get_num_layers_encoder() + self.get_num_layers_decoder() | |
| def no_weight_decay(self): | |
| no_wd_set = set() | |
| for mod, emb_module in self.encoder_embeddings.items(): | |
| if hasattr(emb_module, 'no_weight_decay'): | |
| to_skip = emb_module.no_weight_decay() | |
| to_skip = set([f'encoder_embeddings.{mod}.{name}' for name in to_skip]) | |
| no_wd_set = no_wd_set | to_skip | |
| for mod, emb_module in self.decoder_embeddings.items(): | |
| if hasattr(emb_module, 'no_weight_decay'): | |
| to_skip = emb_module.no_weight_decay() | |
| to_skip = set([f'decoder_embeddings.{mod}.{name}' for name in to_skip]) | |
| no_wd_set = no_wd_set | to_skip | |
| return no_wd_set | |
| def cat_encoder_tensors(self, mod_dict: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor]: | |
| """Concatenate encoder tensors from different modalities. | |
| Args: | |
| mod_dict (dict): A dictionary containing information for each modality. | |
| Expected keys for each modality are 'x' (input tokens), | |
| 'emb' (embeddings), 'input_mask', etc. | |
| Returns: | |
| tuple: | |
| - encoder_tokens_all (torch.Tensor): Concatenated encoder tokens from all modalities. Shape (B, O, D) where O is the total number of all encoder tokens. | |
| - emb_all (torch.Tensor): Concatenated encoder embeddings from all modalities. Shape (B, O, D) | |
| - encoder_mask_all (torch.Tensor): Concatenated boolean masks indicating which tokens are part of the encoder input (set to 0 for valid tokens, 1 otherwise). Shape (B, O) | |
| - mod_mask_all (torch.Tensor): Concatenated integer mask marking the modality type for each encoder token. Shape (B, O) | |
| """ | |
| encoder_tokens_all = [] | |
| emb_all = [] | |
| encoder_mask_all = [] | |
| mod_mask_all = [] | |
| for mod, d in mod_dict.items(): | |
| encoder_tokens_all.append(d['x']) | |
| emb_all.append(d['emb']) | |
| encoder_mask_all.append(d['input_mask']) | |
| mod_mask_all.append(torch.full_like(d['input_mask'], self.modality_info[mod]['id'], dtype=torch.int16)) | |
| encoder_tokens_all = torch.cat(encoder_tokens_all, dim=1) | |
| emb_all = torch.cat(emb_all, dim=1) | |
| encoder_mask_all = torch.cat(encoder_mask_all, dim=1) | |
| mod_mask_all = torch.cat(mod_mask_all, dim=1) | |
| return encoder_tokens_all, emb_all, encoder_mask_all, mod_mask_all | |
| def cat_decoder_tensors(self, mod_dict: Dict[str, Dict[str, torch.Tensor]]) -> Tuple[torch.Tensor]: | |
| """Concatenate decoder tensors from different modalities. | |
| Args: | |
| mod_dict (dict): A dictionary containing information for each modality. | |
| Expected keys for each modality include 'x' (input tokens), | |
| 'ids' (target IDs), 'emb' (embeddings), 'target_mask', 'decoder_attention_mask', etc. | |
| Returns: | |
| tuple: | |
| - decoder_tokens_all (torch.Tensor): Concatenated decoder tokens from all modalities. Shape (B, P, D) where P is the total number of all decoder tokens. | |
| - emb_all (torch.Tensor): Concatenated decoder embeddings from all modalities. Shape (B, P, D) | |
| - decoder_mask_all (torch.Tensor): Concatenated boolean masks indicating which tokens are part of the decoder input / target (set to 0 for valid tokens, 1 otherwise). Shape (B, P) | |
| - target_ids_all (torch.Tensor): Concatenated target IDs from all modalities. Shape (B, P) | |
| - attention_mask_all (torch.Tensor): Concatenated attention masks in compressed format, needs to be passed to adapt_decoder_attention_mask() to obtain the final attention mask. Shape (B, P) | |
| - mod_mask_all (torch.Tensor): Concatenated integer mask marking the modality type for each decoder token. Shape (B, P) | |
| """ | |
| decoder_tokens_all = [] | |
| target_ids_all = [] | |
| emb_all = [] | |
| decoder_mask_all = [] | |
| attention_mask_all = [] | |
| mod_mask_all = [] | |
| # Shuffle order in which modalities are provided (useful for modality causal mask) | |
| mod_dict = {mod: d for mod, d in random.sample(mod_dict.items(), len(mod_dict))} | |
| for mod, d in mod_dict.items(): | |
| if self.modality_info[mod]['type'] in ['seq', 'seq_emb', 'seq_token']: | |
| # Important: This makes the assumption that the target sequence appears sequentially | |
| # before sorting / gathering | |
| decoder_tokens_all.append(d['x'][:, :-1]) | |
| target_ids_all.append(d['ids'][:, 1:]) # Shifted left | |
| emb_all.append(d['emb'][:, :-1]) | |
| # Logical or with left shifting removes the last unmasked position | |
| decoder_mask_all.append(torch.logical_or(d['target_mask'][:, 1:], d['target_mask'][:, :-1])) | |
| # Add attention mask ids | |
| attention_mask_all.append(d['decoder_attention_mask'][:, :-1]) | |
| mod_mask_all.append(torch.full_like(d['ids'][:, :-1], self.modality_info[mod]['id'], dtype=torch.int16)) | |
| else: | |
| # Important: For 2d / image modalities, the decoder input tokens are replaced by the mask token | |
| decoder_tokens_all.append(torch.zeros_like(d['x']) + self.mask_token) # Replace x by mask token | |
| target_ids_all.append(d['ids']) | |
| emb_all.append(d['emb']) | |
| decoder_mask_all.append(d['target_mask']) | |
| attention_mask_all.append(d['decoder_attention_mask']) | |
| mod_mask_all.append(torch.full_like(d['ids'], self.modality_info[mod]['id'], dtype=torch.int16)) | |
| decoder_tokens_all = torch.cat(decoder_tokens_all, dim=1) | |
| emb_all = torch.cat(emb_all, dim=1) | |
| decoder_mask_all = torch.cat(decoder_mask_all, dim=1) | |
| target_ids_all = torch.cat(target_ids_all, dim=1) | |
| attention_mask_all = torch.cat(attention_mask_all, dim=1) | |
| mod_mask_all = torch.cat(mod_mask_all, dim=1) | |
| return decoder_tokens_all, emb_all, decoder_mask_all, target_ids_all, attention_mask_all, mod_mask_all | |
    def forward_mask_encoder(self, mod_dict: Dict[str, Dict[str, torch.Tensor]], num_encoder_tokens: int) -> Tuple[torch.Tensor]:
        """Concatenates and mask encoder tensors based on provided modality information.

        This function consolidates encoder tokens from multiple modalities, then selects a specified number of them based on modality information (i.e. masking).

        Args:
            mod_dict (dict): Dictionary containing tensors for different modalities.
                It is expected to have keys for each modality and values
                containing the modalities' associated tensors.
            num_encoder_tokens (int): Number of encoder tokens to retain after masking.

        Returns:
            tuple:
                - encoder_tokens (torch.Tensor): Selected encoder tokens from all modalities. Shape (B, N, D) where N is the number of selected encoder tokens.
                - encoder_emb (torch.Tensor): Corresponding embeddings for encoder tokens. Shape (B, N, D)
                - encoder_mask (torch.Tensor): A boolean mask indicating which encoder tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, N)
                - mod_mask (torch.Tensor): An integer mask marking the modality type for each encoder token (with -1 indicating unassigned pad tokens). Shape (B, N)

        Notes:
            - If `num_register_tokens` is set and greater than 0, register tokens are added at the beginning of the sequence.
        """
        # Batch size, read from the first modality's raw 'tensor' entry.
        B = list(mod_dict.values())[0]['tensor'].shape[0]

        encoder_tokens_all, emb_all, encoder_mask_all, mod_mask_all = self.cat_encoder_tensors(mod_dict)

        # Add arange multiplied by small constant to mask so they get sorted in a deterministic way
        mask_arange = torch.arange(encoder_mask_all.shape[1], device=encoder_mask_all.device).unsqueeze(0) * 1e-6
        # Sorting moves valid (mask == 0) tokens to the front; keep the first num_encoder_tokens.
        ids_shuffle = torch.argsort(encoder_mask_all + mask_arange, dim=1)
        # ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :num_encoder_tokens]

        # Gather the kept positions from tokens, embeddings, and both masks.
        encoder_tokens = torch.gather(encoder_tokens_all, dim=1,
                                      index=repeat(ids_keep, "b n -> b n d", d=encoder_tokens_all.shape[2]))
        encoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
        encoder_mask = torch.gather(encoder_mask_all, dim=1, index=ids_keep)
        mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)

        if self.num_register_tokens > 0:
            register_tokens = repeat(self.register_tokens, '() n d -> b n d', b=B)
            # We add register tokens at the beginning of the sequence
            encoder_tokens = torch.cat([register_tokens, encoder_tokens], dim=1)
            encoder_emb = torch.cat([torch.zeros_like(register_tokens), encoder_emb], dim=1)
            encoder_mask = torch.cat([torch.zeros((B, register_tokens.shape[1]), dtype=torch.bool, device=encoder_mask.device), encoder_mask], dim=1)
            mod_mask = torch.cat([torch.full((B, register_tokens.shape[1]), -1, dtype=torch.int16, device=mod_mask.device), mod_mask], dim=1)

        # Zero out any remaining invalid positions so they carry no information.
        encoder_tokens[encoder_mask] = 0.
        encoder_emb[encoder_mask] = 0.
        mod_mask[encoder_mask] = -1
        # Mask could be of shape 'b n1 n2' but not needed for masked_fill
        # This means this mask can then be re-used for decoder cross-attention
        encoder_mask = rearrange(encoder_mask, 'b n2 -> b 1 n2')

        return encoder_tokens, encoder_emb, encoder_mask, mod_mask
    def forward_mask_decoder(self, mod_dict: Dict[str, Dict[str, torch.Tensor]], num_decoder_tokens: int) -> Tuple[torch.Tensor]:
        """Concatenates and mask decoder tensors based on provided modality information.

        This function consolidates decoder tokens from multiple modalities, selects a specified number of them based on modality information, and applies appropriate masking.

        Args:
            mod_dict (dict): Dictionary containing tensors for different modalities.
                It is expected to have keys for each modality and values
                containing the modalities' associated tensors.
            num_decoder_tokens (int): Number of decoder tokens to retain after masking.

        Returns:
            tuple:
                - decoder_tokens (torch.Tensor): Selected decoder tokens from all modalities. Shape (B, M, D) where M is the number of selected decoder tokens.
                - decoder_emb (torch.Tensor): Corresponding embeddings for decoder tokens. Shape (B, M, D)
                - decoder_mask (torch.Tensor): A boolean mask indicating which decoder tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, M)
                - target_ids (torch.Tensor): IDs of the target tokens corresponding to the decoder tokens. Shape (B, M)
                - decoder_attention_mask (torch.Tensor): Mask for the decoder self-attention layers. Shape (B, M, M)
                - mod_mask (torch.Tensor): An integer mask marking the modality type for each decoder token (with -1 indicating unassigned pad tokens). Shape (B, M)
        """
        # decoder_mask and target_mask are equivalent, we rename it here to harmonize with forward_mask_encoder
        decoder_tokens_all, emb_all, decoder_mask_all, target_ids_all, decoder_attention_mask_all, mod_mask_all = self.cat_decoder_tensors(mod_dict)

        # Add arange multiplied by small constant to mask so they get sorted in a deterministic way
        mask_arange = torch.arange(decoder_mask_all.shape[1], device=decoder_mask_all.device).unsqueeze(0) * 1e-6
        # Sorting moves valid (mask == 0) tokens to the front; keep the first num_decoder_tokens.
        ids_shuffle = torch.argsort(decoder_mask_all + mask_arange, dim=1)
        # ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :num_decoder_tokens]

        # Gather the kept positions from every per-token tensor.
        decoder_tokens = torch.gather(decoder_tokens_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=decoder_tokens_all.shape[2]))
        decoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
        decoder_mask = torch.gather(decoder_mask_all, dim=1, index=ids_keep)
        target_ids = torch.gather(target_ids_all, dim=1, index=ids_keep)
        decoder_attention_mask = torch.gather(decoder_attention_mask_all, dim=1, index=ids_keep)
        mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)

        # Zero out invalid positions so they carry no information.
        decoder_tokens[decoder_mask] = 0.
        decoder_emb[decoder_mask] = 0.
        target_ids[decoder_mask] = 0

        # Expand the compressed per-token attention ids into a full (B, M, M) mask.
        decoder_attention_mask = self.adapt_decoder_attention_mask(decoder_attention_mask, mod_mask)
        mod_mask[decoder_mask] = -1

        # This means this mask can then be re-used for decoder cross-attention
        decoder_mask = rearrange(decoder_mask, 'b n2 -> b 1 n2')

        return decoder_tokens, decoder_emb, decoder_mask, target_ids, decoder_attention_mask, mod_mask
| def adapt_decoder_attention_mask(self, decoder_attention_mask: torch.Tensor, mod_mask=Optional[torch.Tensor]) -> torch.Tensor: | |
| """ | |
| Transforms the compressed decoder attention mask to a full attention mask based on the specified constraints. | |
| Args: | |
| decoder_attention_mask (torch.Tensor): Initial attention mask indicating attention constraints. Shape (B, M) where M is the number of the decoder tokens. | |
| mod_mask (torch.Tensor, optional): Modality mask to separate attention masks per modality. Shape (B, M) | |
| Returns: | |
| torch.Tensor: Adapted attention mask. Shape (B, M, M) where M is the number of the decoder tokens. | |
| """ | |
| B, N = decoder_attention_mask.shape | |
| if self.decoder_causal_mask: | |
| # For causal mode, tokens can only attend to preceding tokens and themselves. | |
| causal_mask = torch.ones((N, N), dtype=torch.bool, device=decoder_attention_mask.device).triu(1) | |
| causal_mask = repeat(causal_mask, "n1 n2 -> b n1 n2", b=B) | |
| adapted_attention_mask = causal_mask | |
| else: | |
| # Cumulatively sum the attention mask to determine token-wise attention behavior. | |
| # Examples: | |
| # Mask [4, 0, 0, 0] -> Cumsum: [4, 4, 4, 4] -> All tokens attend to each other. | |
| # Mask [1, 1, 1, 1] -> Cumsum: [1, 2, 3, 4] -> Strict autoregressive behavior. | |
| # Mask [2, 0, 1, 1] -> Cumsum: [2, 2, 3, 4] -> Tokens 1 and 2 attend to each other, token 3 attends to tokens 1-3, and token 4 to all. | |
| attention_arange = torch.arange(N, device=decoder_attention_mask.device) | |
| attention_arange = repeat(attention_arange, "n2 -> b n1 n2", b=B, n1=N) | |
| cumsum_mask = torch.cumsum(decoder_attention_mask, dim=-1) | |
| cumsum_mask = rearrange(cumsum_mask, "b n -> b n 1") | |
| adapted_attention_mask = (attention_arange >= cumsum_mask) | |
| if self.decoder_sep_mask: | |
| # Separate attention between tokens based on their modality using mod_mask. | |
| sep_mask = repeat(mod_mask, "b n2 -> b n1 n2", n1=N) != repeat(mod_mask, "b n1 -> b n1 n2", n2=N) | |
| adapted_attention_mask = adapted_attention_mask | sep_mask | |
| return adapted_attention_mask | |
| def forward_encoder(self, | |
| x: torch.Tensor, | |
| encoder_mask: torch.Tensor) -> torch.Tensor: | |
| """Forward pass for the encoder. | |
| Args: | |
| x (torch.Tensor): Encoder input tokens. Shape (B, N, D) where N is the number of encoder tokens. | |
| encoder_mask (torch.Tensor): Encoder mask indicating which tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, N) | |
| Returns: | |
| torch.Tensor: Encoder output. Shape (B, N, D) | |
| """ | |
| for blk in self.encoder: | |
| x = blk(x, mask=encoder_mask) | |
| x = self.encoder_norm(x) | |
| return x | |
| def forward_decoder(self, | |
| y: torch.Tensor, | |
| context: torch.Tensor, | |
| encoder_mask: torch.Tensor, | |
| decoder_attention_mask: torch.Tensor) -> torch.Tensor: | |
| """Forward pass for the decoder. | |
| Args: | |
| y (torch.Tensor): Decoder input tokens. Shape (B, M, D). | |
| context (torch.Tensor): Context for the decoder (i.e. encoder output). Shape (B, N, D). | |
| encoder_mask (torch.Tensor): Encoder mask indicating which tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, N). | |
| decoder_attention_mask (torch.Tensor): Decoder attention mask. Shape (B, M, M). | |
| Returns: | |
| torch.Tensor: Decoder output. Shape (B, M, D). | |
| """ | |
| for blk in self.decoder: | |
| y = blk(y, context, sa_mask=decoder_attention_mask, xa_mask=encoder_mask) | |
| y = self.decoder_norm(y) | |
| return y | |
| def forward_logits(self, | |
| y: torch.Tensor, | |
| decoder_mod_dict: Dict[str, Dict[str, torch.Tensor]], | |
| decoder_mod_mask: torch.Tensor, | |
| return_all_logits: bool = False) -> Dict[str, torch.Tensor]: | |
| """Forward computation of logits for each modality. | |
| Args: | |
| y (torch.Tensor): Decoder output. Shape (B, M, D). | |
| decoder_mod_dict (dict): Dictionary containing tensor information for each modality in the decoder. | |
| decoder_mod_mask (torch.Tensor): Integer mask indicating which tokens belong to which modality. Shape (B, M). | |
| Returns: | |
| Dict[str, torch.Tensor]: Dictionary of logits for each modality. | |
| """ | |
| mod_logits = {} | |
| for mod, d in decoder_mod_dict.items(): | |
| idx = self.modality_info[mod]["id"] | |
| if return_all_logits: | |
| logits = self.decoder_embeddings[mod].forward_logits(y) | |
| else: | |
| logits = self.decoder_embeddings[mod].forward_logits(y[decoder_mod_mask == idx]) | |
| mod_logits[mod] = logits | |
| return mod_logits | |
| def forward_loss(self, | |
| y: torch.Tensor, | |
| target_ids: torch.Tensor, | |
| decoder_mod_dict: Dict[str, Any], | |
| decoder_mod_mask: torch.Tensor, loss_type: str) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | |
| """Computes the loss based on the specified loss type. | |
| Args: | |
| y (torch.Tensor): Decoder output. Shape (B, M, D). | |
| target_ids (torch.Tensor): Ground truth token IDs. Shape (B, M). | |
| decoder_mod_dict (dict): Dictionary containing tensor information for each modality in the decoder. | |
| decoder_mod_mask (torch.Tensor): Integer mask indicating which tokens belong to which modality. Shape (B, M). | |
| loss_type (str): The type of loss to compute. Either 'mod' or 'token'. | |
| Returns: | |
| Tuple[torch.Tensor, Dict[str, torch.Tensor]]: Total loss and dictionary of loss for each modality. | |
| """ | |
| if loss_type in ['mod', 'modality']: | |
| loss, mod_loss = self.forward_mod_loss(y, target_ids, decoder_mod_dict, decoder_mod_mask) | |
| elif loss_type == 'token': | |
| loss, mod_loss = self.forward_token_loss(y, target_ids, decoder_mod_dict, decoder_mod_mask) | |
| else: | |
| raise ValueError("Invalid loss type") | |
| return loss, mod_loss | |
| def forward_mod_loss(self, | |
| y: torch.Tensor, | |
| target_ids: torch.Tensor, | |
| decoder_mod_dict: Dict[str, Any], | |
| decoder_mod_mask: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | |
| """Computes the modality-wise loss. | |
| Args: | |
| y (torch.Tensor): Decoder tokens. Shape (B, M, D). | |
| target_ids (torch.Tensor): Ground truth token IDs. Shape (B, M). | |
| decoder_mod_dict (dict): Dictionary containing tensor information for each modality in the decoder. | |
| decoder_mod_mask (torch.Tensor): Mask indicating which tokens belong to which modality. Shape (B, M). | |
| Returns: | |
| Tuple[torch.Tensor, Dict[str, torch.Tensor]]: Total modality loss and dictionary of loss for each modality. | |
| """ | |
| mod_loss = {} | |
| for mod, d in decoder_mod_dict.items(): | |
| idx = self.modality_info[mod]["id"] | |
| logits = self.decoder_embeddings[mod].forward_logits(y[decoder_mod_mask == idx]) | |
| if logits.numel() == 0: | |
| # If there are no logits / targets, set mod_loss to 0 | |
| mod_loss[mod] = torch.zeros(1, device=logits.device) | |
| else: | |
| loss = F.cross_entropy(logits, target_ids[decoder_mod_mask == idx].long(), reduction='mean') | |
| mod_loss[mod] = loss | |
| loss = sum(mod_loss.values()) / len(mod_loss) | |
| return loss, mod_loss | |
| def forward_token_loss(self, | |
| y: torch.Tensor, | |
| target_ids: torch.Tensor, | |
| decoder_mod_dict: Dict[str, Any], | |
| decoder_mod_mask: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | |
| """Computes the token-wise loss. | |
| Args: | |
| y (torch.Tensor): Decoder tokens. Shape (B, M, D). | |
| target_ids (torch.Tensor): Ground truth token IDs. Shape (B, M). | |
| decoder_mod_dict (dict): Dictionary containing tensor information for each modality in the decoder. | |
| decoder_mod_mask (torch.Tensor): Mask indicating which tokens belong to which modality. Shape (B, M). | |
| Returns: | |
| Tuple[torch.Tensor, Dict[str, torch.Tensor]]: Total token loss and dictionary of loss for each modality. | |
| """ | |
| mod_loss = {} | |
| mod_count = {} | |
| for mod, d in decoder_mod_dict.items(): | |
| idx = self.modality_info[mod]["id"] | |
| logits = self.decoder_embeddings[mod].forward_logits(y[decoder_mod_mask == idx]) | |
| if logits.numel() == 0: | |
| # If there are no logits / targets, set mod_loss to 0 | |
| mod_loss[mod] = torch.zeros(1, device=logits.device) | |
| mod_count[mod] = 0 | |
| else: | |
| loss = F.cross_entropy(logits, target_ids[decoder_mod_mask == idx].long(), reduction='mean') | |
| mod_loss[mod] = loss | |
| mod_count[mod] = logits.numel() | |
| loss = sum([mod_loss[mod] * mod_count[mod] for mod in mod_loss.keys()]) / sum(mod_count.values()) | |
| return loss, mod_loss | |
    def forward(self,
                mod_dict: Dict[str, Dict[str, torch.Tensor]],
                num_encoder_tokens: int,
                num_decoder_tokens: int,
                loss_type: str = 'mod',
                return_logits: bool = False) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Forward pass for the model.

        Args:
            mod_dict (Dict[str, Dict[str, torch.Tensor]]): Dictionary containing the tensors, masks, and other info for each modality.
                - mod_dict[modality_name]["tensor_name"]: Shape can vary based on tensor_name and modality.
            num_encoder_tokens (int): Number of tokens to keep for the encoder.
            num_decoder_tokens (int): Number of tokens to keep for the decoder.
            loss_type (str, optional): The type of loss to compute. Can be 'mod' (average of loss per modality) or 'token' (average loss per token). Default is 'mod'.
            return_logits (bool, optional): If True, return the logits. Default is False.

        Returns:
            Union[dict, tuple]:
                - If return_logits is True: Dictionary of logits for each modality.
                - Otherwise: Tuple containing the total loss and dictionary of loss for each modality.
        """
        # Mod dicts: embed each modality's inputs, then concatenate + subsample (mask).
        encoder_mod_dict = {mod: self.encoder_embeddings[mod](d)
                            for mod, d in mod_dict.items()
                            if mod in self.encoder_embeddings}
        encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask = self.forward_mask_encoder(encoder_mod_dict, num_encoder_tokens)

        decoder_mod_dict = {mod: self.decoder_embeddings[mod].forward_embed(d)
                            for mod, d in mod_dict.items()
                            if mod in self.decoder_embeddings}
        decoder_tokens, decoder_emb, decoder_mask, target_ids, decoder_attention_mask, decoder_mod_mask = self.forward_mask_decoder(decoder_mod_dict, num_decoder_tokens)

        # Encoder
        x = encoder_tokens + encoder_emb
        x = self.forward_encoder(x, encoder_mask=encoder_mask)

        # Decoder: project the encoder output and re-add positional/modality embeddings
        # before using it as cross-attention context.
        context = self.decoder_proj_context(x) + encoder_emb
        y = decoder_tokens + decoder_emb
        y = self.forward_decoder(y, context, encoder_mask=encoder_mask, decoder_attention_mask=decoder_attention_mask)

        # Logits
        if return_logits:
            mod_logits = self.forward_logits(y, decoder_mod_dict, decoder_mod_mask, return_all_logits=True)
            return mod_logits

        # Loss
        loss, mod_loss = self.forward_loss(y, target_ids, decoder_mod_dict, decoder_mod_mask, loss_type)
        return loss, mod_loss
| def freeze_encoder(self, freeze_embeddings=True): | |
| for param in self.encoder.parameters(): | |
| param.requires_grad = False | |
| for param in self.encoder_norm.parameters(): | |
| param.requires_grad = False | |
| if freeze_embeddings: | |
| for param in self.encoder_embeddings.parameters(): | |
| param.requires_grad = False | |
| def freeze_encoder_except_specific_embeddings(self, frozen_embedding_domain): | |
| frozen_embedding_domain = frozen_embedding_domain.split('-') | |
| for param in self.encoder.parameters(): | |
| param.requires_grad = False | |
| for param in self.encoder_norm.parameters(): | |
| param.requires_grad = False | |
| for name, param in self.encoder_embeddings.named_parameters(): | |
| if name.split('.')[0] in frozen_embedding_domain: | |
| param.requires_grad = False | |
| def unfreeze_encoder(self, unfreeze_embeddings=True): | |
| for param in self.encoder.parameters(): | |
| param.requires_grad = True | |
| for param in self.encoder_norm.parameters(): | |
| param.requires_grad = True | |
| if unfreeze_embeddings: | |
| for param in self.encoder_embeddings.parameters(): | |
| param.requires_grad = True | |
| def freeze_decoder(self, freeze_embeddings=True): | |
| for param in self.decoder.parameters(): | |
| param.requires_grad = False | |
| for param in self.decoder_norm.parameters(): | |
| param.requires_grad = False | |
| if freeze_embeddings: | |
| for param in self.decoder_embeddings.parameters(): | |
| param.requires_grad = False | |
| def freeze_decoder_except_specific_embeddings(self, frozen_embedding_domain): | |
| frozen_embedding_domain = frozen_embedding_domain.split('-') | |
| for param in self.decoder.parameters(): | |
| param.requires_grad = False | |
| for param in self.decoder_norm.parameters(): | |
| param.requires_grad = False | |
| for name, param in self.decoder_embeddings.named_parameters(): | |
| if name.split('.')[0] in frozen_embedding_domain: | |
| param.requires_grad = False | |
| def unfreeze_decoder(self, unfreeze_embeddings=True): | |
| for param in self.decoder.parameters(): | |
| param.requires_grad = True | |
| for param in self.decoder_norm.parameters(): | |
| param.requires_grad = True | |
| if unfreeze_embeddings: | |
| for param in self.decoder_embeddings.parameters(): | |
| param.requires_grad = True | |
| def freeze_shared_params(self): | |
| self.freeze_encoder(freeze_embeddings=False) | |
| self.freeze_decoder(freeze_embeddings=False) | |
| def freeze_params_except_specific_embeddings(self, frozen_embedding_domain): | |
| self.freeze_encoder_except_specific_embeddings(frozen_embedding_domain=frozen_embedding_domain) | |
| self.freeze_decoder_except_specific_embeddings(frozen_embedding_domain=frozen_embedding_domain) | |
| def unfreeze_shared_params(self): | |
| self.unfreeze_encoder(unfreeze_embeddings=False) | |
| self.unfreeze_decoder(unfreeze_embeddings=False) | |
| def unfreeze_all(self): | |
| self.unfreeze_encoder(unfreeze_embeddings=True) | |
| self.unfreeze_decoder(unfreeze_embeddings=True) | |
| ################################################ | |
| # Wrapper for easy loading with Huggingface Hub | |
class FM(FourM, PyTorchModelHubMixin):
    """Wrapper around FourM for easy loading with Huggingface Hub.

    Args:
        config (dict): Dictionary containing the model and modality configuration,
            used for loading from Huggingface Hub.
    """
    def __init__(self, config: dict):
        # Work on a private copy so the caller's config dict is never mutated.
        config = copy.deepcopy(config)

        all_domains = sorted(set(config['domains_in']) | set(config['domains_out']))
        modality_info = {mod: MODALITY_INFO[mod] for mod in all_domains}

        def build_embeddings(domains, key, **extra_kwargs):
            """Instantiate the embedding module of each domain that declares one under `key`."""
            embeddings = {}
            for mod in domains:
                info = modality_info[mod]
                constructor = info.get(key, None)
                if constructor is None:
                    continue
                if info["type"] == "img":
                    # Per-modality overrides take precedence over the global config.
                    embeddings[mod] = constructor(
                        patch_size=info.get('patch_size', config['patch_size']),
                        image_size=info.get('input_size', config['image_size']),
                        **extra_kwargs)
                else:
                    embeddings[mod] = constructor(**extra_kwargs)
            return embeddings

        encoder_embeddings = build_embeddings(config['domains_in'], "encoder_embedding")
        decoder_embeddings = build_embeddings(config['domains_out'], "decoder_embedding",
                                              share_embedding=False)

        # Translate serialized config entries into the constructor objects FourM expects.
        config['norm_layer'] = partial(LayerNorm, eps=1e-6, bias=config['norm_bias'])
        config['act_layer'] = getattr(torch.nn, config['act_layer'])

        # Drop keys consumed above; the remainder is forwarded verbatim to FourM.
        for consumed_key in ('norm_bias', 'domains_in', 'domains_out', 'image_size', 'patch_size'):
            del config[consumed_key]

        super().__init__(
            encoder_embeddings=encoder_embeddings,
            decoder_embeddings=decoder_embeddings,
            modality_info=modality_info,
            **config
        )
| ################################################ | |
| # Model definitions | |
| # GELU variants | |
def fm_tiny_6e_6d_gelu(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M tiny: 6 encoder / 6 decoder blocks, dim 384, 6 heads, GELU MLPs."""
    arch = dict(
        encoder_depth=6,
        decoder_depth=6,
        dim=384,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    # Double-splat keeps the original semantics: a kwarg duplicating an
    # architecture key raises TypeError instead of silently overriding it.
    return FourM(
        encoder_embeddings=encoder_embeddings,
        decoder_embeddings=decoder_embeddings,
        **arch,
        **kwargs,
    )
def fm_small_8e_8d_gelu(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M small: 8 encoder / 8 decoder blocks, dim 512, 8 heads, GELU MLPs."""
    arch = dict(
        encoder_depth=8,
        decoder_depth=8,
        dim=512,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    # Double-splat keeps the original semantics: a kwarg duplicating an
    # architecture key raises TypeError instead of silently overriding it.
    return FourM(
        encoder_embeddings=encoder_embeddings,
        decoder_embeddings=decoder_embeddings,
        **arch,
        **kwargs,
    )
def fm_base_12e_12d_gelu(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M base: 12 encoder / 12 decoder blocks, dim 768, 12 heads, GELU MLPs."""
    arch = dict(
        encoder_depth=12,
        decoder_depth=12,
        dim=768,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    # Double-splat keeps the original semantics: a kwarg duplicating an
    # architecture key raises TypeError instead of silently overriding it.
    return FourM(
        encoder_embeddings=encoder_embeddings,
        decoder_embeddings=decoder_embeddings,
        **arch,
        **kwargs,
    )
def fm_large_24e_24d_gelu(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M large: 24 encoder / 24 decoder blocks, dim 1024, 16 heads, GELU MLPs."""
    arch = dict(
        encoder_depth=24,
        decoder_depth=24,
        dim=1024,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    # Double-splat keeps the original semantics: a kwarg duplicating an
    # architecture key raises TypeError instead of silently overriding it.
    return FourM(
        encoder_embeddings=encoder_embeddings,
        decoder_embeddings=decoder_embeddings,
        **arch,
        **kwargs,
    )
def fm_xlarge_24e_24d_gelu(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M xlarge: 24 encoder / 24 decoder blocks, dim 2048, 32 heads, GELU MLPs."""
    arch = dict(
        encoder_depth=24,
        decoder_depth=24,
        dim=2048,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    # Double-splat keeps the original semantics: a kwarg duplicating an
    # architecture key raises TypeError instead of silently overriding it.
    return FourM(
        encoder_embeddings=encoder_embeddings,
        decoder_embeddings=decoder_embeddings,
        **arch,
        **kwargs,
    )
| # SwiGLU variants | |
def fm_tiny_6e_6d_swiglu_nobias(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M tiny: 6/6 blocks, dim 384, 6 heads, bias-free SwiGLU MLPs."""
    arch = dict(
        encoder_depth=6,
        decoder_depth=6,
        dim=384,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,  # SiLU gate + gated_mlp=True -> SwiGLU
        gated_mlp=True,
    )
    # Double-splat keeps the original semantics: a kwarg duplicating an
    # architecture key raises TypeError instead of silently overriding it.
    return FourM(
        encoder_embeddings=encoder_embeddings,
        decoder_embeddings=decoder_embeddings,
        **arch,
        **kwargs,
    )
def fm_small_8e_8d_swiglu_nobias(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M small: 8/8 blocks, dim 512, 8 heads, bias-free SwiGLU MLPs."""
    arch = dict(
        encoder_depth=8,
        decoder_depth=8,
        dim=512,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,  # SiLU gate + gated_mlp=True -> SwiGLU
        gated_mlp=True,
    )
    # Double-splat keeps the original semantics: a kwarg duplicating an
    # architecture key raises TypeError instead of silently overriding it.
    return FourM(
        encoder_embeddings=encoder_embeddings,
        decoder_embeddings=decoder_embeddings,
        **arch,
        **kwargs,
    )
def fm_base_12e_12d_swiglu_nobias(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M base: 12/12 blocks, dim 768, 12 heads, bias-free SwiGLU MLPs."""
    arch = dict(
        encoder_depth=12,
        decoder_depth=12,
        dim=768,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,  # SiLU gate + gated_mlp=True -> SwiGLU
        gated_mlp=True,
    )
    # Double-splat keeps the original semantics: a kwarg duplicating an
    # architecture key raises TypeError instead of silently overriding it.
    return FourM(
        encoder_embeddings=encoder_embeddings,
        decoder_embeddings=decoder_embeddings,
        **arch,
        **kwargs,
    )
def fm_large_24e_24d_swiglu_nobias(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M large: 24/24 blocks, dim 1024, 16 heads, bias-free SwiGLU MLPs."""
    arch = dict(
        encoder_depth=24,
        decoder_depth=24,
        dim=1024,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,  # SiLU gate + gated_mlp=True -> SwiGLU
        gated_mlp=True,
    )
    # Double-splat keeps the original semantics: a kwarg duplicating an
    # architecture key raises TypeError instead of silently overriding it.
    return FourM(
        encoder_embeddings=encoder_embeddings,
        decoder_embeddings=decoder_embeddings,
        **arch,
        **kwargs,
    )
def fm_xlarge_24e_24d_swiglu_nobias(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M xlarge: 24/24 blocks, dim 2048, 32 heads, bias-free SwiGLU MLPs."""
    arch = dict(
        encoder_depth=24,
        decoder_depth=24,
        dim=2048,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,  # SiLU gate + gated_mlp=True -> SwiGLU
        gated_mlp=True,
    )
    # Double-splat keeps the original semantics: a kwarg duplicating an
    # architecture key raises TypeError instead of silently overriding it.
    return FourM(
        encoder_embeddings=encoder_embeddings,
        decoder_embeddings=decoder_embeddings,
        **arch,
        **kwargs,
    )
| # SwiGLU + QKNorm variants | |
def fm_base_12e_12d_swiglu_qknorm_nobias(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M base: 12/12 blocks, dim 768, 12 heads, bias-free SwiGLU MLPs + QK-norm."""
    arch = dict(
        encoder_depth=12,
        decoder_depth=12,
        dim=768,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,  # SiLU gate + gated_mlp=True -> SwiGLU
        gated_mlp=True,
        qk_norm=True,
    )
    # Double-splat keeps the original semantics: a kwarg duplicating an
    # architecture key raises TypeError instead of silently overriding it.
    return FourM(
        encoder_embeddings=encoder_embeddings,
        decoder_embeddings=decoder_embeddings,
        **arch,
        **kwargs,
    )
def fm_large_24e_24d_swiglu_qknorm_nobias(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M large: 24/24 blocks, dim 1024, 16 heads, bias-free SwiGLU MLPs + QK-norm."""
    arch = dict(
        encoder_depth=24,
        decoder_depth=24,
        dim=1024,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,  # SiLU gate + gated_mlp=True -> SwiGLU
        gated_mlp=True,
        qk_norm=True,
    )
    # Double-splat keeps the original semantics: a kwarg duplicating an
    # architecture key raises TypeError instead of silently overriding it.
    return FourM(
        encoder_embeddings=encoder_embeddings,
        decoder_embeddings=decoder_embeddings,
        **arch,
        **kwargs,
    )
def fm_xlarge_24e_24d_swiglu_qknorm_nobias(
        encoder_embeddings: Dict[str, nn.Module],
        decoder_embeddings: Dict[str, nn.Module],
        **kwargs):
    """4M xlarge: 24/24 blocks, dim 2048, 32 heads, bias-free SwiGLU MLPs + QK-norm."""
    arch = dict(
        encoder_depth=24,
        decoder_depth=24,
        dim=2048,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,  # SiLU gate + gated_mlp=True -> SwiGLU
        gated_mlp=True,
        qk_norm=True,
    )
    # Double-splat keeps the original semantics: a kwarg duplicating an
    # architecture key raises TypeError instead of silently overriding it.
    return FourM(
        encoder_embeddings=encoder_embeddings,
        decoder_embeddings=decoder_embeddings,
        **arch,
        **kwargs,
    )