| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
|
|
| from torch import nn |
|
|
| try: |
| from transformers.modeling_bert import ( |
| BertPreTrainedModel, |
| BertModel, |
| BertEncoder, |
| BertPredictionHeadTransform, |
| ) |
| except ImportError: |
| pass |
|
|
| from ..modules import VideoTokenMLP, MMBertEmbeddings |
|
|
|
|
| |
class MMBertForJoint(BertPreTrainedModel):
    """A BertModel with isolated attention mask to separate modality.

    Raw video features are projected into token space by ``VideoTokenMLP``
    and encoded together with the text inputs by ``MMBertModel``; the
    encoder outputs are returned unchanged.
    """

    def __init__(self, config):
        super().__init__(config)
        self.videomlp = VideoTokenMLP(config)
        self.bert = MMBertModel(config)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        input_video_embeds=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        next_sentence_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        separate_forward_split=None,
    ):
        """Jointly encode text and projected video tokens.

        ``next_sentence_label`` is accepted for signature compatibility and
        is not used. Returns whatever ``self.bert`` returns.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        projected_video = self.videomlp(input_video_embeds)

        bert_kwargs = {
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
            "head_mask": head_mask,
            "inputs_embeds": inputs_embeds,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
            "separate_forward_split": separate_forward_split,
        }
        return self.bert(input_ids, projected_video, **bert_kwargs)
|
|
|
|
class MMBertForTokenClassification(BertPreTrainedModel):
    """A BertModel similar to MMJointUni, with extra wrapper layer
    to be fine-tuned from other pretrained MMFusion model.

    A linear classifier is applied to every position of the joint
    video/text sequence output.
    """

    def __init__(self, config):
        super().__init__(config)
        self.videomlp = VideoTokenMLP(config)
        self.bert = MMBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # NOTE(review): the number of output classes (779) is hard-coded;
        # presumably the label count of one target dataset — consider moving
        # it into the config.
        self.classifier = nn.Linear(config.hidden_size, 779)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        input_video_embeds=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        next_sentence_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        separate_forward_split=None,
    ):
        """Return a 1-tuple holding per-position classification logits.

        ``next_sentence_label`` is accepted for signature compatibility and
        is not used.
        """
        return_dict = (
            return_dict if return_dict is not None
            else self.config.use_return_dict
        )

        video_tokens = self.videomlp(input_video_embeds)
        outputs = self.bert(
            input_ids,
            video_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            separate_forward_split=separate_forward_split,
        )

        # The dropout module was built in __init__ but never applied; apply
        # it to the sequence output before classification (identity in eval).
        sequence_output = self.dropout(outputs[0])
        return (self.classifier(sequence_output),)
|
|
|
|
| |
|
|
class MMBertForEncoder(BertPreTrainedModel):
    """A BertModel for Contrastive Learning.

    Video input is optional: when ``input_video_embeds`` is None the
    underlying ``MMBertModel`` receives no video tokens.
    """

    def __init__(self, config):
        super().__init__(config)
        self.videomlp = VideoTokenMLP(config)
        self.bert = MMBertModel(config)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        input_video_embeds=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode text (and optionally video); return ``self.bert``'s output."""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        video_tokens = (
            None if input_video_embeds is None
            else self.videomlp(input_video_embeds)
        )

        return self.bert(
            input_ids,
            video_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
|
|
|
|
class MMBertForMFMMLM(BertPreTrainedModel):
    """A BertModel with shared prediction head on MFM-MLM.

    Masked frame modeling (MFM) and masked language modeling (MLM) scores
    are produced by a single ``MFMMLMHead``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.videomlp = VideoTokenMLP(config)
        self.bert = MMBertModel(config)
        self.cls = MFMMLMHead(config)
        self.hidden_size = config.hidden_size
        self.init_weights()

    def get_output_embeddings(self):
        """Return the text-vocabulary decoder of the prediction head."""
        return self.cls.predictions.decoder

    def forward(
        self,
        input_ids=None,
        input_video_embeds=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        masked_frame_labels=None,
        target_video_hidden_states=None,
        non_masked_frame_mask=None,
        masked_lm_labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder and, when both label tensors are given, the head.

        Returns ``(mfm_scores, prediction_scores) + encoder_outputs``; both
        score entries are None unless ``masked_frame_labels`` and
        ``masked_lm_labels`` are provided.
        """
        return_dict = (
            return_dict if return_dict is not None
            else self.config.use_return_dict
        )
        if input_video_embeds is not None:
            video_tokens = self.videomlp(input_video_embeds)
        else:
            video_tokens = None

        if target_video_hidden_states is not None:
            target_video_hidden_states = self.videomlp(
                target_video_hidden_states)

        # Projected tokens of the frames selected by ``non_masked_frame_mask``;
        # passed to ``self.cls`` as extra frame candidates. Previously this
        # was computed unconditionally and crashed when either input was None.
        non_masked_frame_hidden_states = None
        if video_tokens is not None and non_masked_frame_mask is not None:
            non_masked_frame_hidden_states = video_tokens.masked_select(
                non_masked_frame_mask.unsqueeze(-1)
            ).view(-1, self.hidden_size)

        outputs = self.bert(
            input_ids,
            video_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        mfm_scores, prediction_scores = None, None
        if masked_frame_labels is not None and masked_lm_labels is not None:
            # Sequence layout implied by the slicing below: position 0
            # (presumably [CLS]) is grouped with the text, positions
            # 1 .. text_offset-1 are video, the rest is text.
            text_offset = masked_frame_labels.size(1) + 1
            video_sequence_output = sequence_output[:, 1:text_offset]
            text_sequence_output = torch.cat(
                [sequence_output[:, :1], sequence_output[:, text_offset:]],
                dim=1,
            )

            hidden_size = video_sequence_output.size(-1)
            selected_video_output = video_sequence_output.masked_select(
                masked_frame_labels.unsqueeze(-1)
            ).view(-1, hidden_size)

            # Only text positions with a real label (!= -100) are scored.
            hidden_size = text_sequence_output.size(-1)
            labels_mask = masked_lm_labels != -100
            selected_text_output = text_sequence_output.masked_select(
                labels_mask.unsqueeze(-1)
            ).view(-1, hidden_size)

            mfm_scores, prediction_scores = self.cls(
                selected_video_output,
                target_video_hidden_states,
                non_masked_frame_hidden_states,
                selected_text_output,
            )

        return (mfm_scores, prediction_scores) + outputs
|
|
|
|
class BertMFMMLMPredictionHead(nn.Module):
    """Shared MFM/MLM prediction head.

    Video outputs are scored by dot product against their target frame
    (first column) followed by every non-masked frame; text outputs are
    decoded to vocabulary logits.
    """

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        self.decoder = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # Attach the standalone bias parameter to the decoder so that
        # ``self.decoder(...)`` applies it.
        self.decoder.bias = self.bias

    def forward(
        self,
        video_hidden_states=None,
        target_video_hidden_states=None,
        non_masked_frame_hidden_states=None,
        text_hidden_states=None,
    ):
        """Return ``(video_logits, text_logits)``; each is None if its input is."""
        video_logits = None
        if video_hidden_states is not None:
            video_proj = self.transform(video_hidden_states)
            frame_candidates = non_masked_frame_hidden_states.transpose(0, 1)
            non_masked_scores = torch.mm(video_proj, frame_candidates)
            # Row-wise dot product with each row's own target frame.
            target_scores = torch.bmm(
                video_proj.unsqueeze(1),
                target_video_hidden_states.unsqueeze(-1),
            ).squeeze(-1)
            video_logits = torch.cat(
                [target_scores, non_masked_scores], dim=1)

        text_logits = None
        if text_hidden_states is not None:
            text_logits = self.decoder(self.transform(text_hidden_states))

        return video_logits, text_logits
|
|
|
|
class MFMMLMHead(nn.Module):
    """Wrapper that exposes ``BertMFMMLMPredictionHead`` as ``predictions``."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertMFMMLMPredictionHead(config)

    def forward(
        self,
        video_hidden_states=None,
        target_video_hidden_states=None,
        non_masked_frame_hidden_states=None,
        text_hidden_states=None,
    ):
        """Delegate to ``self.predictions``; return ``(video, text)`` logits."""
        video_out, text_out = self.predictions(
            video_hidden_states,
            target_video_hidden_states,
            non_masked_frame_hidden_states,
            text_hidden_states,
        )
        return video_out, text_out
|
|
|
|
class MMBertForMTM(MMBertForMFMMLM):
    """``MMBertForMFMMLM`` variant whose prediction head is ``MTMHead``.

    ``forward`` and ``get_output_embeddings`` are inherited unchanged.
    """

    def __init__(self, config):
        # Bypass MMBertForMFMMLM.__init__ (it would build an MFMMLMHead) and
        # run the next initializer in the MRO (BertPreTrainedModel) instead.
        super(MMBertForMFMMLM, self).__init__(config)
        self.videomlp = VideoTokenMLP(config)
        self.bert = MMBertModel(config)
        self.cls = MTMHead(config)
        self.hidden_size = config.hidden_size
        self.init_weights()
|
|
|
|
class BertMTMPredictionHead(nn.Module):
    """Prediction head scoring both modalities against a joint candidate set.

    Video logits columns: [target frame | non-masked frames | vocabulary].
    Text logits columns: [vocabulary | non-masked frames].
    """

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        self.decoder = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)

    def forward(
        self,
        video_hidden_states=None,
        target_video_hidden_states=None,
        non_masked_frame_hidden_states=None,
        text_hidden_states=None,
    ):
        """Return ``(video_logits, text_logits)``; each is None if its input is.

        ``non_masked_frame_hidden_states`` is always required (it is
        transposed before either branch runs).
        """
        # Transposed once so both branches can right-multiply by it.
        frame_bank = non_masked_frame_hidden_states.transpose(0, 1)

        video_logits = None
        if video_hidden_states is not None:
            video_proj = self.transform(video_hidden_states)
            video_logits = torch.cat(
                [
                    torch.bmm(
                        video_proj.unsqueeze(1),
                        target_video_hidden_states.unsqueeze(-1),
                    ).squeeze(-1),
                    torch.mm(video_proj, frame_bank),
                    self.decoder(video_proj),
                ],
                dim=1,
            )

        text_logits = None
        if text_hidden_states is not None:
            text_proj = self.transform(text_hidden_states)
            text_logits = torch.cat(
                [self.decoder(text_proj), torch.mm(text_proj, frame_bank)],
                dim=1,
            )

        return video_logits, text_logits
|
|
|
|
class MTMHead(nn.Module):
    """Wrapper that exposes ``BertMTMPredictionHead`` as ``predictions``."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertMTMPredictionHead(config)

    def forward(
        self,
        video_hidden_states=None,
        target_video_hidden_states=None,
        non_masked_frame_hidden_states=None,
        text_hidden_states=None,
    ):
        """Delegate to ``self.predictions``; return ``(video, text)`` logits."""
        video_out, text_out = self.predictions(
            video_hidden_states,
            target_video_hidden_states,
            non_masked_frame_hidden_states,
            text_hidden_states,
        )
        return video_out, text_out
|
|
|
|
class MMBertModel(BertModel):
    """MMBertModel has MMBertEmbedding to support video tokens.

    The embeddings are replaced by ``MMBertEmbeddings`` (text + video
    tokens) and the encoder by ``MultiLayerAttentionMaskBertEncoder``, which
    accepts a per-layer attention mask.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        # Overwrite the text-only embeddings/encoder built by BertModel.
        self.embeddings = MMBertEmbeddings(config)
        self.encoder = MultiLayerAttentionMaskBertEncoder(config)
        if not add_pooling_layer:
            # The flag used to be silently ignored. Drop the pooler here so
            # this works whatever BertModel version is installed; ``forward``
            # already handles ``self.pooler is None``.
            self.pooler = None
        self.init_weights()

    @staticmethod
    def _infer_input_shape(input_ids, input_video_embeds, inputs_embeds):
        """Return ``(batch, seq_len)`` of the joint text (+ video) sequence.

        Raises:
            ValueError: if both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids "
                "and inputs_embeds at the same time"
            )
        if input_ids is not None:
            text = input_ids
        elif inputs_embeds is not None:
            # Previously this branch read ``input_ids.size()`` when no video
            # was given, crashing because ``input_ids`` is None here.
            text = inputs_embeds
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds")

        seq_len = text.size(1)
        if input_video_embeds is not None:
            seq_len += input_video_embeds.size(1)
        return (text.size(0), seq_len)

    def forward(
        self,
        input_ids=None,
        input_video_embeds=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        separate_forward_split=None,
    ):
        """Encode text and (optional) video tokens.

        If ``separate_forward_split`` is given, positions ``[:split]`` and
        ``[split:]`` are encoded by two independent encoder passes (no
        attention across the split) and concatenated along dim 1. The
        slicing used there requires a 5-D extended attention mask, and
        attention outputs cannot be merged.

        Returns:
            A tuple ``(sequence_output, pooled_output, *extra)``, where
            ``extra`` are any additional encoder outputs (e.g. hidden states).
            ``return_dict`` is only forwarded to the encoder.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None
            else self.config.use_return_dict
        )

        input_shape = self._infer_input_shape(
            input_ids, input_video_embeds, inputs_embeds)
        device = (
            input_ids.device if input_ids is not None
            else inputs_embeds.device
        )

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(
                input_shape, dtype=torch.long, device=device)

        extended_attention_mask: torch.Tensor = \
            self.get_extended_attention_mask(
                attention_mask, input_shape, device)

        # Cross-attention mask, only relevant when used as a decoder.
        if self.config.is_decoder and encoder_hidden_states is not None:
            (
                encoder_batch_size,
                encoder_sequence_length,
                _,
            ) = encoder_hidden_states.size()
            encoder_hidden_shape = (
                encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(
                    encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask
            )
        else:
            encoder_extended_attention_mask = None

        head_mask = self.get_head_mask(
            head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids,
            input_video_embeds,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        # Arguments shared by every encoder call below.
        encoder_kwargs = dict(
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if separate_forward_split is not None:
            # Merging concatenates along dim 1, which is meaningless for
            # attention maps; the old length-based check let an
            # attentions-only output through and concatenated heads.
            assert not output_attentions, \
                "we do not support merge on attention for now."
            split = separate_forward_split

            first = self.encoder(
                embedding_output[:, :split],
                attention_mask=extended_attention_mask[
                    :, :, :, :split, :split
                ],
                **encoder_kwargs,
            )
            second = self.encoder(
                embedding_output[:, split:],
                attention_mask=extended_attention_mask[
                    :, :, :, split:, split:
                ],
                **encoder_kwargs,
            )

            merged = [torch.cat([first[0], second[0]], dim=1)]
            if len(first) == 2:
                # Per-layer hidden states: concatenate layer by layer.
                merged.append([
                    torch.cat([first_h, second_h], dim=1)
                    for first_h, second_h in zip(first[1], second[1])
                ])
            encoder_outputs = tuple(merged)
        else:
            encoder_outputs = self.encoder(
                embedding_output,
                attention_mask=extended_attention_mask,
                **encoder_kwargs,
            )

        sequence_output = encoder_outputs[0]
        pooled_output = (
            self.pooler(sequence_output) if self.pooler is not None else None
        )

        return (sequence_output, pooled_output) + encoder_outputs[1:]

    def get_extended_attention_mask(self, attention_mask, input_shape, device):
        """This is borrowed from `modeling_utils.py` with the support of
        multi-layer attention masks.
        The second dim is expected to be number of layers.
        See `MMAttentionMaskProcessor`.
        Makes broadcastable attention and causal masks so that future
        and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to,
                zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` The extended attention mask, \
            in ``self.dtype``. A 4-D mask becomes 5-D with a singleton
            query-broadcast dim inserted at index 2; other masks use the
            parent implementation.
        """
        if attention_mask.dim() == 4:
            extended_attention_mask = attention_mask[:, :, None, :, :]
            extended_attention_mask = extended_attention_mask.to(
                dtype=self.dtype
            )
            # 1 -> 0.0 (attend), 0 -> -10000.0 (effectively masked out).
            extended_attention_mask = (1.0 - extended_attention_mask) \
                * -10000.0
            return extended_attention_mask
        else:
            return super().get_extended_attention_mask(
                attention_mask, input_shape, device
            )
|
|
|
|
class MultiLayerAttentionMaskBertEncoder(BertEncoder):
    """extend BertEncoder with the capability of
    multiple layers of attention mask.

    If ``attention_mask`` is 5-D, layer ``i`` receives
    ``attention_mask[:, i]``; otherwise every layer shares the same mask
    (which may be None).
    """

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        """Run all layers and return a tuple.

        The tuple is ``(last_hidden_state[, all_hidden_states][,
        all_attentions])``. ``return_dict`` is accepted for signature
        compatibility but ignored; a tuple is always returned.
        """
        # Guard against the default ``attention_mask=None``, which
        # previously crashed on ``.dim()``.
        per_layer_mask = (
            attention_mask is not None and attention_mask.dim() == 5
        )
        use_checkpointing = getattr(
            self.config, "gradient_checkpointing", False)
        if use_checkpointing:
            # ``torch.utils.checkpoint`` is a submodule that is not
            # guaranteed to be loaded by ``import torch`` alone.
            from torch.utils.checkpoint import checkpoint

        def create_custom_forward(module):
            # Checkpointing forwards positional tensors only; append the
            # non-tensor ``output_attentions`` flag here.
            def custom_forward(*inputs):
                return module(*inputs, output_attentions)

            return custom_forward

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_attention_mask = (
                attention_mask[:, i, :, :, :] if per_layer_mask
                else attention_mask
            )

            if use_checkpointing:
                layer_outputs = checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    layer_attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    layer_attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return tuple(
            v
            for v in [hidden_states, all_hidden_states, all_attentions]
            if v is not None
        )
|
|