| | |
| | |
| | |
| | |
| |
|
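"""
OFA model definition: the ``ofa`` fairseq model, its classification head,
and the registered ``ofa_*`` architecture presets.
"""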
| | from typing import Optional |
| |
|
| | import logging |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from fairseq import utils |
| | from fairseq.models import register_model, register_model_architecture |
| | from fairseq.modules.transformer_sentence_encoder import init_bert_params |
| |
|
| | from .unify_transformer import TransformerModel |
| |
|
# Module-level logger for head/embedding registration and checkpoint-upgrade messages.
logger = logging.getLogger(__name__)
| |
|
| |
|
@register_model("ofa")
class OFAModel(TransformerModel):
    """OFA sequence-to-sequence model with optional classification heads.

    Builds on the project's ``TransformerModel`` and adds:

    * named classification heads (``register_classification_head``), applied to
      the decoder state at the last non-padding target position;
    * initialization of newly appended dictionary rows from answer-token
      embeddings (``register_embedding_tokens``);
    * checkpoint upgrading for both of the above (``upgrade_state_dict_named``).
    """

    __jit_unused_properties__ = ["supported_targets"]

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)

        # Re-initialize every submodule with fairseq's BERT-style initializer.
        self.apply(init_bert_params)

        # Task heads are added later via ``register_classification_head``.
        self.classification_heads = nn.ModuleDict()
        if hasattr(self.encoder, "dictionary"):
            self.eos: int = self.encoder.dictionary.eos()

    @staticmethod
    def add_args(parser):
        """Add OFA-specific command-line options on top of the parent model's."""
        super(OFAModel, OFAModel).add_args(parser)
        parser.add_argument(
            "--pooler-dropout",
            type=float,
            metavar="D",
            help="dropout probability in the masked_lm pooler layers",
        )
        parser.add_argument(
            "--pooler-classifier",
            type=str,
            choices=['mlp', 'linear'],
            help="type of pooler classifier",
        )
        parser.add_argument(
            "--pooler-activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use for pooler layer",
        )
        parser.add_argument(
            "--spectral-norm-classification-head",
            action="store_true",
            help="Apply spectral normalization on the classification head",
        )

    @property
    def supported_targets(self):
        return {"self"}

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        patch_images: Optional[torch.Tensor] = None,
        patch_images_2: Optional[torch.Tensor] = None,
        patch_masks: Optional[torch.Tensor] = None,
        code_masks: Optional[torch.Tensor] = None,
        sample_patch_num: Optional[int] = None,
        features_only: bool = False,
        classification_head_name: Optional[str] = None,
        token_embeddings: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """Run the encoder and decoder, optionally followed by a classification head.

        When ``classification_head_name`` is given, the decoder runs with
        ``features_only=True``. The hidden state at the last non-padding position
        of each row of ``prev_output_tokens`` is then passed to that head.

        Returns:
            ``(x, extra)``. ``x`` is the decoder output, or the head's output when
            a classification head is selected. ``extra`` is the decoder's second
            return value, passed through unchanged.
        """
        if classification_head_name is not None:
            features_only = True

        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            patch_images=patch_images,
            patch_masks=patch_masks,
            patch_images_2=patch_images_2,
            token_embeddings=token_embeddings,
            return_all_hiddens=return_all_hiddens,
            sample_patch_num=sample_patch_num
        )
        x, extra = self.decoder(
            prev_output_tokens,
            code_masks=code_masks,
            encoder_out=encoder_out,
            features_only=features_only,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )

        pad = self.encoder.padding_idx
        if classification_head_name is not None:
            # Position of the last non-pad token in each target row.
            prev_lengths = prev_output_tokens.ne(pad).sum(1)
            gather_index = prev_lengths[:, None, None].expand(x.size(0), 1, x.size(2)) - 1
            # squeeze(1), not squeeze(): a bare squeeze() would also drop the
            # batch dimension when the batch size is 1.
            sentence_representation = x.gather(1, gather_index).squeeze(1)
            if self.classification_heads[classification_head_name].use_two_images:
                hidden_size = sentence_representation.size(1)
                # Join consecutive row pairs (2i, 2i+1) into one 2*hidden row;
                # requires an even number of rows.
                sentence_representation = sentence_representation.view(-1, hidden_size * 2)
            for k, head in self.classification_heads.items():
                # NOTE(review): iterating instead of indexing presumably mirrors
                # fairseq's TorchScript-compatibility pattern.
                if k == classification_head_name:
                    x = head(sentence_representation)
                    break

        return x, extra

    def register_embedding_tokens(self, ans2label_dict, src_dict, bpe):
        """Pre-tokenize the answer symbols stored at the end of ``src_dict``.

        The last ``len(ans2label_dict)`` symbols of ``src_dict`` are processed in
        order. Each one has its first 5 characters and its last character
        stripped and ``_`` replaced by spaces. It is then lower-cased,
        BPE-encoded with a leading space, and mapped back to ``src_dict`` ids
        (no new symbols, no EOS). The resulting id tensors are stored in
        ``self.ans_tensor_list``; ``upgrade_state_dict_named`` uses them to
        initialize new embedding rows.
        """
        logger.info("Registering embedding tokens")
        num_answers = len(ans2label_dict)
        self.ans_tensor_list = []
        for i in range(num_answers):
            ans = src_dict[-num_answers + i]
            ans = ans[5:-1].replace('_', ' ')
            ans_tensor = src_dict.encode_line(
                line=bpe.encode(' {}'.format(ans.lower())),
                add_if_not_exist=False,
                append_eos=False
            ).long()
            self.ans_tensor_list.append(ans_tensor)

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, use_two_images=False, **kwargs
    ):
        """Create, or replace, the classification head stored under ``name``.

        Args:
            name: key in ``self.classification_heads``.
            num_classes: output size of the head.
            inner_dim: hidden size of an ``"mlp"`` head. A falsy value falls
                back to ``args.encoder_embed_dim``.
            use_two_images: if True, the head's input size is doubled.
            **kwargs: accepted for call-site compatibility; ignored.
        """
        logger.info("Registering classification head: {0}".format(name))
        effective_inner_dim = inner_dim or self.args.encoder_embed_dim
        if name in self.classification_heads:
            prev_head = self.classification_heads[name]
            prev_num_classes = prev_head.out_proj.out_features
            # "linear" heads have no ``dense`` layer, so no inner dimension.
            prev_dense = getattr(prev_head, "dense", None)
            prev_inner_dim = prev_dense.out_features if prev_dense is not None else None
            inner_dim_changed = (
                prev_dense is not None and effective_inner_dim != prev_inner_dim
            )
            if num_classes != prev_num_classes or inner_dim_changed:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, effective_inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = OFAClassificationHead(
            input_dim=self.args.encoder_embed_dim,
            inner_dim=effective_inner_dim,
            num_classes=num_classes,
            activation_fn=self.args.pooler_activation_fn,
            pooler_dropout=self.args.pooler_dropout,
            pooler_classifier=self.args.pooler_classifier,
            use_two_images=use_two_images,
            do_spectral_norm=getattr(
                self.args, "spectral_norm_classification_head", False
            ),
        )

    def upgrade_state_dict_named(self, state_dict, name):
        """Adapt a checkpoint ``state_dict`` (in place) to this model before loading.

        1. Reconcile classification heads. With ``args.load_checkpoint_heads``,
           heads present only in the checkpoint are registered on the model.
           Otherwise, checkpoint heads that are absent from the model, or that
           have different sizes, are removed from ``state_dict``.
        2. If the checkpoint vocabulary is exactly one row larger and the current
           dictionary has no ``<mask>`` symbol, drop the last embedding row.
        3. If the current dictionary is larger than the checkpoint's, append the
           missing embedding / output-projection rows.
        4. Copy heads that exist only in the model into ``state_dict``.

        Raises:
            ValueError: if ``ans_tensor_list`` is set but its length differs from
                the number of rows to append.
        """
        super().upgrade_state_dict_named(state_dict, name)

        prefix = name + "." if name != "" else ""
        heads_prefix = prefix + "classification_heads."
        current_head_names = (
            []
            if not hasattr(self, "classification_heads")
            else self.classification_heads.keys()
        )

        # Step 1: reconcile classification heads found in the checkpoint.
        keys_to_delete = []
        for k in state_dict.keys():
            if not k.startswith(heads_prefix):
                continue

            head_name = k[len(heads_prefix):].split(".")[0]
            head_key = heads_prefix + head_name
            out_proj_weight = state_dict.get(head_key + ".out_proj.weight")
            if out_proj_weight is None:
                # torch.nn.utils.spectral_norm stores the raw weight as ``weight_orig``.
                out_proj_weight = state_dict[head_key + ".out_proj.weight_orig"]
            num_classes = out_proj_weight.size(0)
            # "linear" heads have no dense layer and hence no inner dimension.
            dense_weight = state_dict.get(head_key + ".dense.weight")
            inner_dim = dense_weight.size(0) if dense_weight is not None else None

            if getattr(self.args, "load_checkpoint_heads", False):
                if head_name not in current_head_names:
                    self.register_classification_head(head_name, num_classes, inner_dim)
            elif head_name not in current_head_names:
                logger.warning(
                    "deleting classification head ({}) from checkpoint "
                    "not present in current model: {}".format(head_name, k)
                )
                keys_to_delete.append(k)
            else:
                current_head = self.classification_heads[head_name]
                current_dense = getattr(current_head, "dense", None)
                current_inner_dim = (
                    current_dense.out_features if current_dense is not None else None
                )
                if (
                    num_classes != current_head.out_proj.out_features
                    or inner_dim != current_inner_dim
                ):
                    logger.warning(
                        "deleting classification head ({}) from checkpoint "
                        "with different dimensions than current model: {}".format(
                            head_name, k
                        )
                    )
                    keys_to_delete.append(k)
        for k in keys_to_delete:
            del state_dict[k]

        def truncate_emb(key):
            if key in state_dict:
                state_dict[key] = state_dict[key][:-1, :]

        # Step 2. NOTE(review): the embedding keys below are not prefixed with
        # ``prefix``; the extra row is presumably a pretraining ``<mask>`` symbol.
        loaded_dict_size = state_dict["encoder.embed_tokens.weight"].size(0)
        if (
            loaded_dict_size == len(self.encoder.dictionary) + 1
            and "<mask>" not in self.encoder.dictionary
        ):
            truncate_emb("encoder.embed_tokens.weight")
            truncate_emb("decoder.embed_tokens.weight")
            truncate_emb("encoder.output_projection.weight")
            truncate_emb("decoder.output_projection.weight")

        # Step 3: grow the vocabulary-sized matrices. New rows are the mean
        # embedding of each answer's tokens when register_embedding_tokens has
        # run; otherwise they are drawn from N(0, embed_dim ** -0.5).
        if loaded_dict_size < len(self.encoder.dictionary):
            num_langids_to_add = len(self.encoder.dictionary) - loaded_dict_size
            loaded_embed = state_dict["encoder.embed_tokens.weight"]
            embed_dim = loaded_embed.size(1)

            new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim)
            ans_tensor_list = getattr(self, "ans_tensor_list", None)
            if ans_tensor_list:
                if len(ans_tensor_list) != num_langids_to_add:
                    raise ValueError(
                        "ans_tensor_list has {} entries but {} new embedding rows "
                        "are needed".format(len(ans_tensor_list), num_langids_to_add)
                    )
                for i, ans_tensor in enumerate(ans_tensor_list):
                    # Indices must live on the same device as the weight.
                    ans_embed = F.embedding(ans_tensor.to(loaded_embed.device), loaded_embed)
                    ans_embed = ans_embed.sum(0) / ans_embed.size(0)
                    new_lang_embed_to_add[i] = ans_embed
            else:
                nn.init.normal_(new_lang_embed_to_add, mean=0, std=embed_dim ** -0.5)
            # Match both dtype and device so torch.cat below does not fail.
            new_lang_embed_to_add = new_lang_embed_to_add.to(
                dtype=loaded_embed.dtype,
                device=loaded_embed.device,
            )

            for key in (
                "encoder.embed_tokens.weight",
                "decoder.embed_tokens.weight",
                "decoder.output_projection.weight",
            ):
                state_dict[key] = torch.cat([state_dict[key], new_lang_embed_to_add])

        # Step 4: heads present in the model but missing from the checkpoint
        # are written into state_dict with their current weights.
        if hasattr(self, "classification_heads"):
            cur_state = self.classification_heads.state_dict()
            for k, v in cur_state.items():
                if heads_prefix + k not in state_dict:
                    logger.info("Overwriting " + heads_prefix + k)
                    state_dict[heads_prefix + k] = v
| |
|
| |
|
class OFAClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    The layout is selected by ``pooler_classifier``:

    * ``"mlp"``:    dropout -> dense -> activation -> dropout -> out_proj
    * ``"linear"``: dropout -> out_proj

    With ``use_two_images=True`` the expected feature size is ``2 * input_dim``.
    With ``do_spectral_norm=True``, ``out_proj`` is wrapped in
    ``torch.nn.utils.spectral_norm``.

    Raises:
        NotImplementedError: if ``pooler_classifier`` is not ``"mlp"`` or ``"linear"``.
    """

    def __init__(
        self,
        input_dim,
        inner_dim,
        num_classes,
        activation_fn,
        pooler_dropout,
        pooler_classifier,
        use_two_images=False,
        do_spectral_norm=False,
    ):
        super().__init__()
        self.pooler_classifier = pooler_classifier
        self.use_two_images = use_two_images
        input_dim = input_dim * 2 if use_two_images else input_dim
        if pooler_classifier == "mlp":
            self.dense = nn.Linear(input_dim, inner_dim)
            self.activation_fn = utils.get_activation_fn(activation_fn)
            self.dropout = nn.Dropout(p=pooler_dropout)
            self.out_proj = nn.Linear(inner_dim, num_classes)
        elif pooler_classifier == "linear":
            self.dropout = nn.Dropout(p=pooler_dropout)
            self.out_proj = nn.Linear(input_dim, num_classes)
        else:
            raise NotImplementedError(
                "unsupported pooler_classifier {!r}; expected 'mlp' or 'linear'".format(
                    pooler_classifier
                )
            )

        if do_spectral_norm:
            self.out_proj = torch.nn.utils.spectral_norm(self.out_proj)

    def forward(self, features, **kwargs):
        """Map ``features`` (last dim = input size) to ``num_classes`` outputs."""
        if self.pooler_classifier not in ("mlp", "linear"):
            raise NotImplementedError(
                "unsupported pooler_classifier {!r}; expected 'mlp' or 'linear'".format(
                    self.pooler_classifier
                )
            )
        x = self.dropout(features)
        if self.pooler_classifier == "mlp":
            x = self.dense(x)
            x = self.activation_fn(x)
            x = self.dropout(x)
        return self.out_proj(x)
| |
|
| |
|
@register_model_architecture("ofa", "ofa_large")
def ofa_large_architecture(args):
    """Apply the ``ofa_large`` defaults to ``args``.

    Each default is written only when ``args`` does not already have that
    attribute. Attributes set by the caller, or by one of the smaller presets
    that delegate here, are left untouched. Settings are applied in order,
    because some defaults (decoder sizes, decoder input/output dims) are read
    back from attributes resolved earlier in the sequence.
    """

    def _fill(attr, fallback):
        # Same as ``args.<attr> = getattr(args, attr, fallback)``.
        setattr(args, attr, getattr(args, attr, fallback))

    # Encoder.
    _fill("encoder_embed_path", None)
    _fill("encoder_embed_dim", 1024)
    _fill("encoder_ffn_embed_dim", 4 * 1024)
    _fill("encoder_layers", 12)
    _fill("encoder_attention_heads", 16)
    _fill("encoder_normalize_before", True)
    _fill("encoder_learned_pos", True)

    # Decoder; embedding sizes fall back to the encoder's resolved values.
    _fill("decoder_embed_path", None)
    _fill("decoder_embed_dim", args.encoder_embed_dim)
    _fill("decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)
    _fill("decoder_layers", 12)
    _fill("decoder_attention_heads", 16)
    _fill("decoder_normalize_before", True)
    _fill("decoder_learned_pos", True)

    # Dropout, sequence lengths, adaptive softmax, embedding sharing.
    _fill("attention_dropout", 0.0)
    _fill("relu_dropout", 0.0)
    _fill("dropout", 0.0)
    _fill("max_target_positions", 1024)
    _fill("max_source_positions", 1024)
    _fill("adaptive_softmax_cutoff", None)
    _fill("adaptive_softmax_dropout", 0)
    _fill("share_decoder_input_output_embed", True)
    _fill("share_all_embeddings", True)

    # Decoder I/O dims follow the resolved decoder embedding size.
    _fill("decoder_output_dim", args.decoder_embed_dim)
    _fill("decoder_input_dim", args.decoder_embed_dim)

    _fill("no_scale_embedding", True)
    _fill("layernorm_embedding", True)

    # Activations and pooler / classification-head settings.
    _fill("activation_fn", "gelu")
    _fill("pooler_activation_fn", "tanh")
    _fill("pooler_dropout", 0.0)
    _fill("pooler_classifier", "mlp")

    # Drop-path rates.
    _fill("resnet_drop_path_rate", 0.0)
    _fill("encoder_drop_path_rate", 0.0)
    _fill("decoder_drop_path_rate", 0.0)

    _fill("resnet_type", "resnet152")
    _fill("token_bucket_size", 256)
    _fill("image_bucket_size", 42)

    _fill("freeze_encoder_embedding", False)
    _fill("freeze_decoder_embedding", False)
    _fill("add_type_embedding", True)
    _fill("attn_scale_factor", 2)

    _fill("code_image_size", 128)
    _fill("patch_layernorm_embedding", True)
    _fill("code_layernorm_embedding", True)
    _fill("entangle_position_embedding", False)
    _fill("disable_entangle", False)
    _fill("sync_bn", False)

    _fill("scale_attn", False)
    _fill("scale_fc", False)
    _fill("scale_heads", False)
    _fill("scale_resids", False)

    _fill("orig_patch_image_size", 256)
| |
|
| |
|
@register_model_architecture("ofa", "ofa_base")
def ofa_base_architecture(args):
    """``ofa_base`` preset: 768-dim, 6+6 layers, 12 heads, ResNet-101.

    Fills only attributes missing from ``args``, then delegates every remaining
    default to ``ofa_large_architecture``.
    """
    base_defaults = {
        "encoder_embed_dim": 768,
        "encoder_ffn_embed_dim": 4 * 768,
        "encoder_layers": 6,
        "encoder_attention_heads": 12,
        "decoder_layers": 6,
        "decoder_attention_heads": 12,
        "resnet_type": "resnet101",
    }
    for attr, fallback in base_defaults.items():
        setattr(args, attr, getattr(args, attr, fallback))
    ofa_large_architecture(args)
| |
|
| |
|
@register_model_architecture("ofa", "ofa_huge")
def ofa_huge_architecture(args):
    """``ofa_huge`` preset: 1280-dim, 24 encoder / 12 decoder layers, 16 heads, ResNet-152.

    Fills only attributes missing from ``args``, then delegates every remaining
    default to ``ofa_large_architecture``.
    """
    huge_defaults = {
        "encoder_embed_dim": 1280,
        "encoder_ffn_embed_dim": 4 * 1280,
        "encoder_layers": 24,
        "encoder_attention_heads": 16,
        "decoder_layers": 12,
        "decoder_attention_heads": 16,
        "resnet_type": "resnet152",
    }
    for attr, fallback in huge_defaults.items():
        setattr(args, attr, getattr(args, attr, fallback))
    ofa_large_architecture(args)
| |
|
| |
|
@register_model_architecture("ofa", "ofa_medium")
def ofa_medium_architecture(args):
    """``ofa_medium`` preset: 512-dim, 4+4 layers, 8 heads, ResNet-101.

    Fills only attributes missing from ``args``, then delegates every remaining
    default to ``ofa_large_architecture``.
    """
    medium_defaults = {
        "encoder_embed_dim": 512,
        "encoder_ffn_embed_dim": 4 * 512,
        "encoder_layers": 4,
        "encoder_attention_heads": 8,
        "decoder_layers": 4,
        "decoder_attention_heads": 8,
        "resnet_type": "resnet101",
    }
    for attr, fallback in medium_defaults.items():
        setattr(args, attr, getattr(args, attr, fallback))
    ofa_large_architecture(args)
| |
|
| |
|
@register_model_architecture("ofa", "ofa_tiny")
def ofa_tiny_architecture(args):
    """``ofa_tiny`` preset: 256-dim, 4+4 layers, 4 heads, ResNet-50.

    Fills only attributes missing from ``args``, then delegates every remaining
    default to ``ofa_large_architecture``.
    """
    tiny_defaults = {
        "encoder_embed_dim": 256,
        "encoder_ffn_embed_dim": 4 * 256,
        "encoder_layers": 4,
        "encoder_attention_heads": 4,
        "decoder_layers": 4,
        "decoder_attention_heads": 4,
        "resnet_type": "resnet50",
    }
    for attr, fallback in tiny_defaults.items():
        setattr(args, attr, getattr(args, attr, fallback))
    ofa_large_architecture(args)
| |
|