| """ |
| Copyright (c) Microsoft Corporation. |
| Licensed under the MIT license. |
| |
| """ |

from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import math
import os

import torch
from torch import nn

from .modeling_bert import (BertPreTrainedModel, BertEmbeddings, BertPooler,
                            BertIntermediate, BertOutput, BertSelfOutput)
from .modeling_utils import prune_linear_layer
# Assumed import path: GraphResBlock is used by GraphormerLayer below but was
# never imported in this file; adjust the path to wherever the graph residual
# block lives in your project.
from ._gcnn import GraphResBlock

LayerNormClass = torch.nn.LayerNorm
BertLayerNorm = torch.nn.LayerNorm

class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)
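
    # Note: transpose_for_scores maps (batch, seq_len, all_head_size) to
    # (batch, num_heads, seq_len, head_size). E.g. with 4 heads of size 16,
    # a (B, L, 64) tensor becomes (B, 4, L, 16), one slice per head for the
    # matmuls in forward() below.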

    def forward(self, hidden_states, attention_mask, head_mask=None, history_state=None):
        if history_state is not None:
            # Prepend cached history states so keys and values also attend
            # over the earlier tokens; queries come from the current states.
            x_states = torch.cat([history_state, hidden_states], dim=1)
            mixed_query_layer = self.query(hidden_states)
            mixed_key_layer = self.key(x_states)
            mixed_value_layer = self.value(x_states)
        else:
            mixed_query_layer = self.query(hidden_states)
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Dot product between "query" and "key" gives the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the pre-computed additive attention mask.
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # Dropout on the attention probabilities, as in the original Transformer.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if requested.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
        return outputs
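
    # Hedged smoke-test sketch (comments only). BertSelfAttention reads just
    # hidden_size, num_attention_heads, output_attentions and
    # attention_probs_dropout_prob from `config`, so a SimpleNamespace can
    # stand in for a full BertConfig here:
    #
    #   from types import SimpleNamespace
    #   cfg = SimpleNamespace(hidden_size=64, num_attention_heads=4,
    #                         output_attentions=False,
    #                         attention_probs_dropout_prob=0.1)
    #   attn = BertSelfAttention(cfg)
    #   hidden = torch.randn(2, 10, 64)    # (batch, seq_len, hidden)
    #   mask = torch.zeros(2, 1, 1, 10)    # additive mask, 0 = attend
    #   context, = attn(hidden, mask)      # context: (2, 10, 64)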


class BertAttention(nn.Module):
    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
        for head in heads:
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()

        # Prune the query/key/value projections and the output projection.
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update the hyper-parameters to reflect the removed heads.
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads

    def forward(self, input_tensor, attention_mask, head_mask=None, history_state=None):
        self_outputs = self.self(input_tensor, attention_mask, head_mask, history_state)
        attention_output = self.output(self_outputs[0], input_tensor)
        outputs = (attention_output,) + self_outputs[1:]
        return outputs


class GraphormerLayer(nn.Module):
    def __init__(self, config):
        super(GraphormerLayer, self).__init__()
        self.attention = BertAttention(config)
        self.has_graph_conv = config.graph_conv
        self.mesh_type = config.mesh_type

        if self.has_graph_conv and self.mesh_type in ('hand', 'body'):
            # Both mesh types use the same residual graph convolution; the
            # mesh topology is selected inside GraphResBlock via mesh_type.
            self.graph_conv = GraphResBlock(
                config.hidden_size, config.hidden_size, mesh_type=self.mesh_type)

        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def MHA_GCN(self, hidden_states, attention_mask, head_mask=None, history_state=None):
        attention_outputs = self.attention(hidden_states, attention_mask, head_mask, history_state)
        attention_output = attention_outputs[0]

        if self.has_graph_conv:
            # Token layout: [joint queries | mesh-vertex queries | 49
            # image-grid tokens]; the body model has 14 joints, the hand 21.
            if self.mesh_type == 'body':
                joints = attention_output[:, 0:14, :]
                vertices = attention_output[:, 14:-49, :]
                img_tokens = attention_output[:, -49:, :]
            elif self.mesh_type == 'hand':
                joints = attention_output[:, 0:21, :]
                vertices = attention_output[:, 21:-49, :]
                img_tokens = attention_output[:, -49:, :]

            # Only the mesh-vertex tokens pass through the graph convolution.
            vertices = self.graph_conv(vertices)
            joints_vertices = torch.cat([joints, vertices, img_tokens], dim=1)
        else:
            joints_vertices = attention_output

        intermediate_output = self.intermediate(joints_vertices)
        layer_output = self.output(intermediate_output, joints_vertices)
        outputs = (layer_output,) + attention_outputs[1:]
        return outputs

    def forward(self, hidden_states, attention_mask, head_mask=None, history_state=None):
        return self.MHA_GCN(hidden_states, attention_mask, head_mask, history_state)
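
    # Hedged smoke-test sketch (comments only): with graph_conv disabled the
    # layer behaves like a standard BERT layer, so a SimpleNamespace config is
    # enough. The extra fields (intermediate_size, hidden_act, layer_norm_eps)
    # are assumed to be what the imported Bert* sub-blocks read:
    #
    #   from types import SimpleNamespace
    #   cfg = SimpleNamespace(hidden_size=64, num_attention_heads=4,
    #                         output_attentions=False,
    #                         attention_probs_dropout_prob=0.1,
    #                         hidden_dropout_prob=0.1, intermediate_size=128,
    #                         hidden_act='gelu', layer_norm_eps=1e-12,
    #                         graph_conv=False, mesh_type='body')
    #   layer = GraphormerLayer(cfg)
    #   out, = layer(torch.randn(2, 494, 64), torch.zeros(2, 1, 1, 494))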


class GraphormerEncoder(nn.Module):
    def __init__(self, config):
        super(GraphormerEncoder, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList(
            [GraphormerLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, head_mask=None, encoder_history_states=None):
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            history_state = None if encoder_history_states is None else encoder_history_states[i]
            layer_head_mask = None if head_mask is None else head_mask[i]
            layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, history_state)
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add the hidden states from the last layer.
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)

        return outputs


class EncoderBlock(BertPreTrainedModel):
    def __init__(self, config):
        super(EncoderBlock, self).__init__(config)
        self.config = config
        self.encoder = GraphormerEncoder(config)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.img_dim = config.img_feature_dim

        # Older configs may not define use_img_layernorm; default to None.
        self.use_img_layernorm = getattr(config, 'use_img_layernorm', None)

        self.img_embedding = nn.Linear(self.img_dim, self.config.hidden_size, bias=True)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        if self.use_img_layernorm:
            self.LayerNorm = LayerNormClass(config.hidden_size, eps=config.img_layer_norm_eps)

        self.apply(self.init_weights)

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}.
        See base class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
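
    # Usage note (hypothetical layer/head indices): to drop heads 0 and 2 in
    # layer 0 and head 1 in layer 3,
    #   model._prune_heads({0: [0, 2], 3: [1]})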

    def forward(self, img_feats, input_ids=None, token_type_ids=None,
                attention_mask=None, position_ids=None, head_mask=None):
        batch_size = len(img_feats)
        seq_length = len(img_feats[0])
        # Dummy token ids: one position per input image feature. Built on the
        # same device as the inputs so the module also runs on CPU.
        input_ids = torch.zeros([batch_size, seq_length], dtype=torch.long,
                                device=img_feats.device)

        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        position_embeddings = self.position_embeddings(position_ids)

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Broadcast the mask to (batch, num_heads, from_seq, to_seq).
        if attention_mask.dim() == 2:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        elif attention_mask.dim() == 3:
            extended_attention_mask = attention_mask.unsqueeze(1)
        else:
            raise NotImplementedError

        # Turn the {0, 1} mask into an additive mask: 0 where we attend,
        # -10000 where we don't, e.g. [[1, 1, 0]] -> [[[[0, 0, -10000]]]],
        # so masked positions get ~zero probability after the softmax.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        if head_mask is not None:
            if head_mask.dim() == 1:
                # One mask shared across layers; expand to
                # (num_layers, batch, num_heads, from_seq, to_seq).
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                # Per-layer, per-head mask.
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
        else:
            head_mask = [None] * self.config.num_hidden_layers

        # Project image features into the transformer's hidden size.
        img_embedding_output = self.img_embedding(img_feats)

        # Combine positional embeddings with the projected image features.
        embeddings = position_embeddings + img_embedding_output

        if self.use_img_layernorm:
            embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)

        encoder_outputs = self.encoder(embeddings, extended_attention_mask, head_mask=head_mask)
        sequence_output = encoder_outputs[0]

        outputs = (sequence_output,)
        if self.config.output_hidden_states:
            all_hidden_states = encoder_outputs[1]
            outputs = outputs + (all_hidden_states,)
        if self.config.output_attentions:
            all_attentions = encoder_outputs[-1]
            outputs = outputs + (all_attentions,)

        return outputs


class Graphormer(BertPreTrainedModel):
    '''
    The architecture of a transformer encoder block used in Graphormer.
    '''
    def __init__(self, config):
        super(Graphormer, self).__init__(config)
        self.config = config
        self.bert = EncoderBlock(config)
        self.cls_head = nn.Linear(config.hidden_size, self.config.output_feature_dim)
        self.residual = nn.Linear(config.img_feature_dim, self.config.output_feature_dim)
        self.apply(self.init_weights)

    def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None,
                masked_lm_labels=None, next_sentence_label=None, position_ids=None,
                head_mask=None):
        '''
        self.bert returns up to three outputs:
        predictions[0]: output tokens
        predictions[1]: all hidden states, if self.config.output_hidden_states is enabled
        predictions[2]: attentions, if self.config.output_attentions is enabled
        '''
        predictions = self.bert(img_feats=img_feats, input_ids=input_ids,
                                position_ids=position_ids, token_type_ids=token_type_ids,
                                attention_mask=attention_mask, head_mask=head_mask)

        # Regress the per-token output features, with a linear residual
        # connection from the raw image features.
        pred_score = self.cls_head(predictions[0])
        res_img_feats = self.residual(img_feats)
        pred_score = pred_score + res_img_feats

        if self.config.output_attentions and self.config.output_hidden_states:
            return pred_score, predictions[1], predictions[-1]
        else:
            return pred_score
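
# Hedged end-to-end sketch (comments only). BertConfig is assumed importable
# from .modeling_bert, and the token counts follow GraphormerLayer's body
# slicing (14 joints, 49 image-grid tokens) with an illustrative vertex count:
#
#   from .modeling_bert import BertConfig
#   config = BertConfig.from_pretrained('bert-base-uncased')
#   config.img_feature_dim = 2051          # example input feature size
#   config.output_feature_dim = 3          # e.g. one 3-D coordinate per token
#   config.graph_conv = False              # skip GraphResBlock for the smoke test
#   config.mesh_type = 'body'
#
#   model = Graphormer(config)
#   img_feats = torch.randn(1, 14 + 431 + 49, config.img_feature_dim)
#   pred = model(img_feats)                # (1, 494, config.output_feature_dim)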