| """Modified from https://github.com/khanrc/honeybee |
| """ |
|
|
| import math |
| from functools import partial |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| from einops import rearrange |
| from timm.layers import LayerNorm, LayerNorm2d |
| from timm.models.regnet import RegStage |
| from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling |
| from transformers.models.deformable_detr import DeformableDetrConfig |
| from transformers.models.deformable_detr.modeling_deformable_detr import ( |
| DeformableDetrDecoder, |
| DeformableDetrDecoderLayer, |
| DeformableDetrDecoderOutput, |
| ) |
| from transformers.pytorch_utils import ( |
| find_pruneable_heads_and_indices, |
| prune_linear_layer, |
| ) |
|
|
| from .common_layers import HoneybeePreTrainedModel, LayerNormFp32 |
| from .configuration_m4cxr import HoneybeeVisualProjectorConfig |
|
|
|
|
def build_pos_embeds(
    config: HoneybeeVisualProjectorConfig,
    num_input_tokens: int,
    vision_hidden_size: int,
):
    """Create the optional learnable positional embedding for visual tokens.

    Args:
        config: projector config; only ``config.pos_emb`` is read here.
        num_input_tokens: number of positions in the embedding table.
        vision_hidden_size: feature width of every position.

    Returns:
        An ``nn.Parameter`` of shape ``(1, num_input_tokens, vision_hidden_size)``
        initialised from a truncated normal (std 0.02), or ``None`` when
        ``config.pos_emb`` is falsy.
    """
    if not config.pos_emb:
        return None

    table = torch.zeros(1, num_input_tokens, vision_hidden_size)
    embedding = torch.nn.Parameter(table)
    nn.init.trunc_normal_(embedding, mean=0.0, std=0.02)
    return embedding
|
|
|
|
def build_eos_tokens(config: HoneybeeVisualProjectorConfig, output_hidden_size: int):
    """Create the optional learnable end-of-sequence tokens.

    Args:
        config: projector config; ``config.num_eos_tokens`` decides how many
            tokens to create and ``config.initializer_range`` is the std of
            their truncated-normal initialisation.
        output_hidden_size: feature width of every token.

    Returns:
        An ``nn.Parameter`` of shape ``(1, num_eos_tokens, output_hidden_size)``,
        or ``None`` when ``config.num_eos_tokens`` is falsy (0 / ``None``).
    """
    count = config.num_eos_tokens
    if not count:
        return None

    tokens = torch.nn.Parameter(torch.randn(1, count, output_hidden_size))
    nn.init.trunc_normal_(tokens, mean=0.0, std=config.initializer_range)
    return tokens
|
|
|
|
def build_prenorm(config: HoneybeeVisualProjectorConfig):
    """Return the optional input ``LayerNorm`` for a projector.

    A ``LayerNorm`` over ``config.encoder_hidden_size`` is created only when
    ``config.prenorm`` exists and is truthy; otherwise ``None`` is returned.
    """
    if not getattr(config, "prenorm", False):
        return None
    return LayerNorm(config.encoder_hidden_size)
|
|
|
|
def build_mlp(depth: int, hidden_size: int, output_hidden_size: int):
    """Build a SiLU MLP with ``depth`` linear layers.

    The first layer maps ``hidden_size -> output_hidden_size``; every further
    layer is preceded by ``nn.SiLU`` and maps
    ``output_hidden_size -> output_hidden_size``. A ``depth`` of 1 or less
    yields just the first linear layer.

    Returns:
        The layers wrapped in an ``nn.Sequential``.
    """
    stack = [nn.Linear(hidden_size, output_hidden_size)]
    for _ in range(depth - 1):
        stack.extend(
            (nn.SiLU(), nn.Linear(output_hidden_size, output_hidden_size))
        )
    return nn.Sequential(*stack)
|
|
|
|
class Projector(nn.Module):
    """Base class for visual projectors.

    Subclasses implement :meth:`build_net` (create the trainable layers) and
    :meth:`_forward` (map visual tokens to projected tokens). This base class
    owns the optional pre-normalisation, positional embedding and EOS tokens,
    all of which are built from ``config``.
    """

    def __init__(
        self,
        config: HoneybeeVisualProjectorConfig,
        num_input_tokens: int,
    ):
        """
        Args:
            config: projector configuration.
            num_input_tokens: number of visual tokens; sizes the positional
                embedding.
        """
        super().__init__()
        self.config = config
        self.num_input_tokens = num_input_tokens

        # Learnable tokens appended after the projected sequence (or None).
        self.eos_tokens = build_eos_tokens(config, config.output_hidden_size)

        # Learnable positional embedding added to the input features (or None).
        self.pos_emb = build_pos_embeds(
            config, num_input_tokens, config.encoder_hidden_size
        )

        self.prenorm = build_prenorm(config)

        self.build_net()

    def build_net(self):
        """Create the projector's layers; must be provided by subclasses."""
        raise NotImplementedError()

    def _forward(self, x):
        """Project the prepared features; must be provided by subclasses."""
        raise NotImplementedError()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply prenorm, positional embedding, the projector and EOS tokens.

        Args:
            x: (B, L, encoder_hidden_size) tensor from the visual backbone.
                The tensor passed in is never modified.

        Returns:
            A ``BaseModelOutput`` whose ``last_hidden_state`` is the output of
            ``_forward`` with the EOS tokens (if any) concatenated along the
            sequence dimension.
        """
        if self.prenorm is not None:
            x = self.prenorm(x)

        if self.pos_emb is not None:
            # Out-of-place add: without a prenorm `x` is still the caller's own
            # tensor and an in-place `+=` would silently overwrite it. The cast
            # keeps the dtype the in-place version produced.
            x = (x + self.pos_emb).to(x.dtype)

        x = self._forward(x)

        batch_size = x.size(0)
        if self.eos_tokens is not None:
            x = torch.cat([x, self.eos_tokens.expand(batch_size, -1, -1)], dim=1)

        return BaseModelOutput(last_hidden_state=x)
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
|
|
class MLPProjector(Projector):
    """Projector that applies an MLP (see ``build_mlp``) to the visual tokens."""

    def build_net(self):
        """Create ``self.net`` from ``config.depth`` and the hidden sizes."""
        cfg = self.config
        in_dim, out_dim = cfg.encoder_hidden_size, cfg.output_hidden_size
        self.net = build_mlp(cfg.depth, in_dim, out_dim)

    def _forward(self, x):
        """Run the MLP on ``x`` and return the result."""
        projected = self.net(x)
        return projected
|
|
|
|
class ConvProjector(Projector):
    """Projector base for networks that operate on a square token grid.

    Subclasses are expected to create ``self.net`` (applied to the
    ``(B, D, H, W)`` map) and ``self.readout`` (applied to the flattened
    tokens) in ``build_net``.
    """

    def _forward(self, x):
        """Reshape tokens to a square map, run ``net``, flatten, run ``readout``.

        The grid side is ``int(sqrt(L))`` for an input of shape ``(B, L, D)``.
        """
        side = int(x.size(1) ** 0.5)
        grid = rearrange(x, "b (h w) d -> b d h w", h=side, w=side)
        grid = self.net(grid)
        tokens = rearrange(grid, "b d h w -> b (h w) d")
        return self.readout(tokens)
|
|
|
|
class CAbstractor(ConvProjector):
    """C-Abstractor based on RegBlock"""

    def build_net(self):
        """Create ``self.net`` and ``self.readout``.

        With a non-zero ``config.depth`` the network is
        ``RegStage -> AdaptiveAvgPool2d -> RegStage`` followed by an MLP
        readout from ``hidden_size``. With ``depth == 0`` only the adaptive
        pooling is used and the readout starts from ``encoder_hidden_size``.
        ``config.num_query_tokens`` must be a square number; it fixes the
        pooled grid size.
        """
        cfg = self.config
        encoder_hidden_size = cfg.encoder_hidden_size
        hidden_size = cfg.hidden_size
        output_hidden_size = cfg.output_hidden_size
        depth = cfg.depth
        mlp_depth = cfg.mlp_depth

        n_queries = cfg.num_query_tokens
        assert (n_queries**0.5).is_integer(), "n_queries must be square number"
        grid_side = int(n_queries**0.5)

        def make_stage(in_dim, out_dim):
            # One RegStage of `depth` blocks that keeps the spatial resolution.
            return RegStage(
                depth,
                in_dim,
                out_dim,
                stride=1,
                dilation=1,
                act_layer=nn.SiLU,
                norm_layer=LayerNorm2d,
            )

        first_stage = make_stage(encoder_hidden_size, hidden_size)
        pool = nn.AdaptiveAvgPool2d((grid_side, grid_side))
        second_stage = make_stage(hidden_size, hidden_size)

        if not depth:
            # No conv stages requested: pool the encoder features directly.
            self.net = pool
            self.readout = build_mlp(mlp_depth, encoder_hidden_size, output_hidden_size)
            return

        self.net = nn.Sequential(first_stage, pool, second_stage)
        self.readout = build_mlp(mlp_depth, hidden_size, output_hidden_size)
|
|
|
|
class HoneybeeVisualProjectorMLP(nn.Module):
    """Gated feed-forward block: ``w2(LN(SiLU(w1(x)) * w3(x)))``."""

    def __init__(self, config: HoneybeeVisualProjectorConfig):
        """Build the three projections and the inner layer norm from ``config``.

        Reads ``hidden_size``, ``intermediate_size`` and ``layer_norm_eps``.
        """
        super().__init__()
        self.config = config
        model_dim = config.hidden_size
        self.act = nn.SiLU()
        inner_dim = config.intermediate_size

        self.w1 = nn.Linear(model_dim, inner_dim)
        self.w2 = nn.Linear(inner_dim, model_dim)
        self.w3 = nn.Linear(model_dim, inner_dim)
        self.ffn_ln = LayerNormFp32(inner_dim, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP; the last dimension is preserved."""
        gate = self.act(self.w1(hidden_states))
        value = self.w3(hidden_states)
        normed = self.ffn_ln(gate * value)
        return self.w2(normed)
|
|
|
|
class HoneybeeVisualProjectorMultiHeadAttention(nn.Module):
    """Multi-head attention with keys/values taken from ``encoder_hidden_states``.

    Queries are projected from ``hidden_states``; keys and values are always
    projected from ``encoder_hidden_states``. When ``save_attention`` is set,
    the attention probabilities (and, if a backward pass runs, their
    gradients) are stored on the module for later inspection.
    """

    def __init__(self, config: HoneybeeVisualProjectorConfig):
        """
        Args:
            config: provides ``hidden_size``, ``num_attention_heads`` and
                ``attention_probs_dropout_prob``.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
                % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.save_attention = False
        # Populated only while `save_attention` is enabled; None until then so
        # the getters below never raise AttributeError.
        self.attention_map = None
        self.attn_gradients = None

    def save_attn_gradients(self, attn_gradients):
        """Store the gradient of the attention probabilities (backward hook)."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the last stored attention gradients, or ``None``."""
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        """Store the attention probabilities of the latest forward pass."""
        self.attention_map = attention_map

    def get_attention_map(self):
        """Return the last stored attention probabilities, or ``None``."""
        return self.attention_map

    def transpose_for_scores(self, x):
        """Split the last dim into heads: ``(B, N, H*D) -> (B, H, N, D)``."""
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Attend from ``hidden_states`` to ``encoder_hidden_states``.

        Args:
            hidden_states: query-side features.
            attention_mask: ignored; only ``encoder_attention_mask`` is applied.
            head_mask: optional multiplier applied to the (dropped-out)
                attention probabilities.
            encoder_hidden_states: key/value-side features (required).
            encoder_attention_mask: optional mask that is *added* to the raw
                attention scores before the softmax.
            past_key_value: ignored on input; a fresh ``(key, value)`` pair is
                always returned.
            output_attentions: also return the attention probabilities.

        Returns:
            ``(context, past_key_value)`` or, with ``output_attentions``,
            ``(context, attention_probs, past_key_value)``.
        """
        key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
        value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
        # Only the mask over the key/value sequence is meaningful here.
        attention_mask = encoder_attention_mask

        query_layer = self.transpose_for_scores(self.query(hidden_states))

        past_key_value = (key_layer, value_layer)

        # Scaled dot-product scores between every query and every key.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = torch.softmax(attention_scores, dim=-1)

        if self.save_attention:
            self.save_attention_map(attention_probs)
            # A hook can only be registered on tensors that take part in
            # autograd; under no_grad()/inference just keep the map.
            if attention_probs.requires_grad:
                attention_probs.register_hook(self.save_attn_gradients)

        # Dropout removes whole key positions for a query; the un-dropped
        # probabilities are what gets returned/saved.
        attention_probs_dropped = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # Merge the heads back: (B, H, N, D) -> (B, N, H*D).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )
        return outputs + (past_key_value,)
|
|
|
|
class HoneybeeVisualProjectorCrossOutput(nn.Module):
    """Output stage of an attention block: projection + residual, then MLP + residual."""

    def __init__(self, config: HoneybeeVisualProjectorConfig):
        """Create the output projection, the pre-MLP norm and the MLP."""
        super().__init__()
        width = config.hidden_size
        self.out_proj = nn.Linear(width, width, bias=True)
        self.norm2 = LayerNormFp32(width)
        self.mlp = HoneybeeVisualProjectorMLP(config)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Combine the attention context with its residual stream.

        Args:
            hidden_states: attention context to be projected.
            input_tensor: residual input the projected context is added to.

        Returns:
            ``r + mlp(norm2(r))`` with ``r = input_tensor + out_proj(hidden_states)``.
        """
        residual = input_tensor + self.out_proj(hidden_states)
        return residual + self.mlp(self.norm2(residual))
|
|
|
|
class HoneybeeVisualProjectorAttention(nn.Module):
    """Pre-norm attention block whose keys/values are ``[queries; visual features]``.

    Both inputs are layer-normalised, the normalised queries are prepended to
    the normalised visual features to form the key/value sequence, and the
    attention context is passed through ``HoneybeeVisualProjectorCrossOutput``.
    """

    def __init__(self, config: HoneybeeVisualProjectorConfig):
        super().__init__()
        self.attention = HoneybeeVisualProjectorMultiHeadAttention(config)
        self.output = HoneybeeVisualProjectorCrossOutput(config)
        self.pruned_heads = set()
        self.norm1 = LayerNormFp32(config.hidden_size)  # norm for the queries
        self.normk = LayerNormFp32(config.hidden_size)  # norm for the visual features

    def prune_heads(self, heads):
        """Remove the given attention heads from the q/k/v and output projections.

        Args:
            heads: iterable of head indices to prune; an empty iterable is a
                no-op.
        """
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.attention.num_attention_heads,
            self.attention.attention_head_size,
            self.pruned_heads,
        )

        # Prune linear layers.
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        # The output projection lives in `out_proj`; the pruned layer has to
        # replace it, otherwise its input width no longer matches the context.
        self.output.out_proj = prune_linear_layer(self.output.out_proj, index, dim=1)

        # Update hyper-parameters and remember the pruned heads.
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(
            heads
        )
        self.attention.all_head_size = (
            self.attention.attention_head_size * self.attention.num_attention_heads
        )
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _merge_masks(query_mask, visual_mask, num_queries, num_visual):
        """Concatenate the additive masks for the ``[queries; visual]`` keys.

        A missing mask is replaced by zeros (i.e. "attend everywhere") shaped
        after the other one; ``None`` is returned when both are missing.
        """
        if query_mask is None and visual_mask is None:
            return None
        if query_mask is None:
            query_mask = visual_mask.new_zeros(*visual_mask.shape[:-1], num_queries)
        elif visual_mask is None:
            visual_mask = query_mask.new_zeros(*query_mask.shape[:-1], num_visual)
        return torch.cat([query_mask, visual_mask], dim=-1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """
        hidden_states: query embeddings [B, num_queries, dim]
        encoder_hidden_states: visual features [B, num_visual_features, dim]
        Note) above two features should be the same dimensions.

        ``attention_mask`` / ``encoder_attention_mask`` are additive masks over
        the queries / visual features (last dim); either may be ``None``.

        Returns:
            ``(attention_output,)`` followed by whatever the inner attention
            returns after its context (attention probabilities if requested,
            then the key/value pair).
        """
        num_queries = hidden_states.shape[1]
        num_visual = encoder_hidden_states.shape[1]

        hidden_states = self.norm1(hidden_states)
        encoder_hidden_states = self.normk(encoder_hidden_states)

        # Keys/values cover the queries themselves followed by the visual tokens.
        encoder_hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
        encoder_attention_mask = self._merge_masks(
            attention_mask, encoder_attention_mask, num_queries, num_visual
        )
        self_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # The residual stream is the *normalised* query tensor.
        attention_output = self.output(self_outputs[0], hidden_states)

        # Append attentions / key-value pair if the inner attention returned them.
        return (attention_output,) + self_outputs[1:]
|
|
|
|
class HoneybeeVisualProjectorLayer(nn.Module):
    """One resampler layer: a single cross-attention block."""

    def __init__(self, config, layer_idx):
        """
        Args:
            config: projector config (``chunk_size_feed_forward`` is read here,
                the rest is forwarded to the attention block).
            layer_idx: position of this layer in the encoder stack.
        """
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.layer_idx = layer_idx

        self.crossattention = HoneybeeVisualProjectorAttention(config)
        self.has_cross_attention = True

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        """Run the cross-attention block and return its output tuple unchanged.

        Raises:
            ValueError: if ``encoder_hidden_states`` is ``None``.
        """
        if encoder_hidden_states is None:
            raise ValueError(
                "encoder_hidden_states must be given for cross-attention layers"
            )

        return self.crossattention(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions=output_attentions,
        )
|
|
|
|
class HoneybeeVisualProjectorEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``HoneybeeVisualProjectorLayer`` modules."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                HoneybeeVisualProjectorLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Run all layers in sequence.

        Args:
            hidden_states: query embeddings fed to the first layer.
            attention_mask: mask for the query positions, passed to every layer.
            head_mask: optional per-layer head masks, indexed by layer.
            encoder_hidden_states: visual features attended to by every layer.
            encoder_attention_mask: mask for the visual features.
            past_key_values: accepted for interface compatibility; unused.
            output_attentions: collect each layer's attention probabilities.
            output_hidden_states: collect the input of every layer plus the
                final output.
            return_dict: accepted for interface compatibility; a
                ``BaseModelOutput`` is always returned.

        Returns:
            ``BaseModelOutput`` with ``last_hidden_state`` and, when requested,
            ``hidden_states`` and ``attentions`` tuples.
        """
        all_hidden_states = () if output_hidden_states else None
        all_output_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    # checkpoint() only forwards positional tensors, so the
                    # non-tensor flag is bound through this closure.
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions,
                )
            if output_attentions:
                all_output_attentions = all_output_attentions + (layer_outputs[1],)
            hidden_states = layer_outputs[0]

        # Include the output of the last layer, then actually hand the
        # collected states back to the caller (they used to be dropped).
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_output_attentions,
        )
|
|
|
|
class HoneybeeVisualProjectorModel(HoneybeePreTrainedModel):
    """Resampler model performing cross-attention
    between learnable query_tokens (query) and visual features (key, value).
    """

    def __init__(self, config: HoneybeeVisualProjectorConfig, num_input_tokens: int):
        """
        Args:
            config: projector configuration.
            num_input_tokens: number of visual tokens; sizes the positional
                embedding.
        """
        super().__init__(config)
        self.config = config
        self.encoder = HoneybeeVisualProjectorEncoder(config)

        # Maps visual features into the resampler width ...
        self.visual_input_fc = torch.nn.Linear(
            config.encoder_hidden_size, config.hidden_size
        )
        # ... and the resampled queries into the output width.
        self.visual_output_fc = torch.nn.Linear(
            config.hidden_size, config.output_hidden_size
        )

        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_tokens, config.hidden_size)
        )

        # Optional learnable tokens appended after the projected queries.
        self.vit_eos = build_eos_tokens(config, config.output_hidden_size)

        # Optional positional embedding added to the raw visual features.
        self.pos_emb = build_pos_embeds(
            config, num_input_tokens, config.encoder_hidden_size
        )

        self.post_init()

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            # The encoder keeps its layers in `layers`, and each layer's
            # attention block is called `crossattention`.
            self.encoder.layers[layer].crossattention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: Tuple[int],
        device: torch.device,
    ) -> torch.Tensor:
        """
        Makes a broadcastable additive attention mask so that masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
                Either 2-D ``[batch, seq]`` or 3-D ``[batch, from_seq, to_seq]``.
            input_shape (`Tuple[int]`):
                The shape of the input to the model (only used in the error message).
            device: (`torch.device`):
                Unused; kept for interface compatibility.

        Returns:
            `torch.Tensor` The extended attention mask in ``self.dtype``: 0.0 where
            attention is allowed and -10000.0 where it is masked.

        Raises:
            ValueError: if ``attention_mask`` is neither 2-D nor 3-D.
        """
        if attention_mask.dim() == 3:
            # [batch, from_seq, to_seq] -> broadcast over heads.
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # [batch, seq] -> broadcast over heads and query positions.
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # The mask is added to the raw scores before the softmax, so masked
        # positions get a large negative value and allowed ones get 0.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        encoder_hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_attention_mask=None,
        past_key_values=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, encoder_hidden_size)`):
            Visual features to resample. The tensor passed in is not modified.
        attention_mask (`torch.Tensor`, `optional`):
            Mask over the query tokens (1 = attend, 0 = ignore); defaults to all ones.
        head_mask (`optional`):
            Forwarded to `get_head_mask` and then to the encoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
            Mask over the visual features. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            Defaults to all ones.
        past_key_values (`optional`):
            Forwarded to the encoder, which does not use it.
        output_attentions / output_hidden_states / return_dict (`optional`):
            Default to the corresponding `config` values. A
            `BaseModelOutputWithPooling` is returned regardless of `return_dict`.

        Returns a `BaseModelOutputWithPooling` whose `last_hidden_state` holds the projected queries
        (plus EOS tokens, if configured) and whose `pooler_output` is the first query before the
        output projection.
        """
        query_embeds = self.query_tokens.expand(encoder_hidden_states.shape[0], -1, -1)

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        input_shape = query_embeds.size()[:-1]
        device = query_embeds.device

        # Mask over the query tokens; every query is visible by default.
        if attention_mask is None:
            attention_mask = torch.ones(
                (query_embeds.shape[0], query_embeds.shape[1]),
                dtype=torch.long,
                device=query_embeds.device,
            )
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, input_shape, device
        )

        # Mask over the visual features, made broadcastable to
        # [batch_size, num_heads, seq_length, seq_length].
        (
            encoder_batch_size,
            encoder_sequence_length,
            _,
        ) = encoder_hidden_states.size()
        encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

        if isinstance(encoder_attention_mask, list):
            encoder_extended_attention_mask = [
                self.invert_attention_mask(mask) for mask in encoder_attention_mask
            ]
        else:
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask
            )

        # Expand the head mask to one entry per layer.
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if self.pos_emb is not None:
            # Out-of-place add so the caller's feature tensor is left untouched;
            # the cast keeps the dtype an in-place add would have produced.
            encoder_hidden_states = (encoder_hidden_states + self.pos_emb).to(
                encoder_hidden_states.dtype
            )

        # Bring the visual features to the width of the query tokens.
        encoder_hidden_states = self.visual_input_fc(encoder_hidden_states)
        assert query_embeds.shape[-1] == encoder_hidden_states.shape[-1]

        encoder_outputs = self.encoder(
            query_embeds,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = sequence_output[:, 0, :]

        sequence_output = self.visual_output_fc(sequence_output)
        if self.vit_eos is not None:
            sequence_output = torch.cat(
                [sequence_output, self.vit_eos.repeat(sequence_output.shape[0], 1, 1)],
                dim=1,
            )

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
|
|
|
class DAbstractor(DeformableDetrDecoder):
    """Deformable-attention abstractor built on ``DeformableDetrDecoder``.

    A fixed number of queries (``config.num_queries``) attends to the visual
    features through ``DeformableDetrDecoderLayer`` modules. Queries are either
    learned embeddings or (``config.pooled_v_target != "none"``) partly
    derived from adaptively pooled visual features.
    """

    def __init__(self, config: DeformableDetrConfig, num_input_tokens: int, *igargs):
        """
        Args:
            config: Deformable-DETR config extended with the projector fields
                read below (``encoder_hidden_size``, ``pooled_v_target``,
                ``output_hidden_size``, ...).
            num_input_tokens: number of visual tokens per feature level.
            *igargs: ignored extra positional arguments.
        """
        super().__init__(config)

        self.num_queries = config.num_queries
        self.num_input_tokens = num_input_tokens

        self.num_feature_levels = config.num_feature_levels
        self.isMs = self.num_feature_levels > 1  # multi-scale input?

        self.layers = nn.ModuleList(
            [DeformableDetrDecoderLayer(config) for _ in range(config.decoder_layers)]
        )

        # One input projection per level; only needed when the widths differ.
        is_dim_missmatch = config.d_model != config.encoder_hidden_size
        input_proj_list = []
        for _ in range(self.num_feature_levels):
            if is_dim_missmatch:
                input_proj_list.append(
                    nn.Linear(config.encoder_hidden_size, config.d_model)
                )
            else:
                input_proj_list.append(nn.Identity())

        self.input_proj = nn.ModuleList(input_proj_list)

        # Learnable per-level embedding for multi-scale features.
        if self.isMs:
            assert config.num_feature_levels == len(config.feature_layer_index)
            self.level_emb = nn.Parameter(
                torch.Tensor(1, config.num_feature_levels, 1, config.d_model)
            )
            nn.init.normal_(self.level_emb)

        # Query construction: with pooling only one half (position or target)
        # is learned, otherwise both halves live in one double-width table.
        self.pooled_v_target = config.pooled_v_target
        if self.pooled_v_target != "none":
            tgt_hw = int(config.num_queries**0.5)
            self.downsampler = nn.AdaptiveAvgPool2d((tgt_hw, tgt_hw))
            self.query_position_embeddings = nn.Embedding(
                config.num_queries, config.d_model
            )
        else:
            self.query_position_embeddings = nn.Embedding(
                config.num_queries, config.d_model * 2
            )

        # Reference points start on a regular grid and are trainable.
        valid_ratios_q, spatial_shapes_q, _ = self._prepare_ddetr_inputs(
            1, num_input_tokens, 1
        )
        reference_points = self._get_query_reference_points(
            spatial_shapes_q, valid_ratios_q
        )
        self.reference_points = nn.Parameter(reference_points)

        # Optional EOS tokens (appended before the output projection).
        self.eos_tokens = build_eos_tokens(config, config.d_model)

        # Optional positional embedding for the projected visual features.
        self.v_pos_emb = build_pos_embeds(config, num_input_tokens, config.d_model)

        if config.output_hidden_size != config.d_model:
            self.output_proj = nn.Linear(config.d_model, config.output_hidden_size)
        else:
            self.output_proj = nn.Identity()

    def _get_query_reference_points(self, spatial_shapes, valid_ratios):
        """
        Get a regular grid of normalised reference points for the queries.

        Args:
            spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
                Spatial shapes of each feature map.
            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
                Valid ratios of each feature map.
        Returns:
            `torch.FloatTensor` with `steps * steps` points per level
            (`steps = int(sqrt(num_queries))`), concatenated over levels along dim 1 and
            scaled by the valid ratios; the level dimension is squeezed away when there
            is a single level, giving `(batch_size, num_points, 2)`.
        """
        reference_points_list = []
        steps = int(self.num_queries**0.5)
        for level, (height, width) in enumerate(spatial_shapes):
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, height - 0.5, steps, dtype=torch.float32),
                torch.linspace(0.5, width - 0.5, steps, dtype=torch.float32),
                indexing="ij",
            )
            # Normalise pixel-centre coordinates by the valid map size.
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, level, 0] * width)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points.squeeze(2)

    def _forward(
        self,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        position_embeddings=None,
        reference_points=None,
        spatial_shapes=None,
        level_start_index=None,
        valid_ratios=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Run the decoder layers over the queries.

        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
                The query embeddings that are passed into the decoder (required).
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Visual features used in the cross-attention of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask for the cross-attention: 1 for real positions, 0 for padding.
            position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
                Position embeddings passed to every decoder layer.
            reference_points (`torch.FloatTensor` with last dimension 2 or 4):
                Reference points; multiplied by `valid_ratios` per level before use.
            spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
                Spatial shapes of the feature maps.
            level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
                Start index of each feature level inside the flattened sequence.
            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
                Ratio of valid area in each feature level.
            output_attentions (`bool`, *optional*):
                Whether to return the attention tensors of all layers.
            output_hidden_states (`bool`, *optional*):
                Whether to return the hidden states of all layers.
            return_dict (`bool`, *optional*):
                Whether to return a `DeformableDetrDecoderOutput` instead of a plain tuple.

        Raises:
            ValueError: if `inputs_embeds` is missing or the reference points do not
                end in a dimension of size 2 or 4.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if inputs_embeds is None:
            raise ValueError("inputs_embeds must be provided")
        hidden_states = inputs_embeds

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        intermediate = ()
        intermediate_reference_points = ()

        for decoder_layer in self.layers:
            if reference_points.shape[-1] == 4:
                reference_points_input = (
                    reference_points[:, :, None]
                    * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
                )
            else:
                if reference_points.shape[-1] != 2:
                    raise ValueError(
                        "Reference points' last dimension must be of size 2"
                    )
                reference_points_input = (
                    reference_points[:, :, None] * valid_ratios[:, None]
                )

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    # Re-create exactly the keyword call of the regular path;
                    # checkpoint() itself can only forward positional inputs.
                    def custom_forward(
                        states, pos, enc_states, enc_mask, ref_points, shapes, starts
                    ):
                        return module(
                            states,
                            position_embeddings=pos,
                            encoder_hidden_states=enc_states,
                            reference_points=ref_points,
                            spatial_shapes=shapes,
                            level_start_index=starts,
                            encoder_attention_mask=enc_mask,
                            output_attentions=output_attentions,
                        )

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    position_embeddings,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    reference_points_input,
                    spatial_shapes,
                    level_start_index,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    position_embeddings=position_embeddings,
                    encoder_hidden_states=encoder_hidden_states,
                    reference_points=reference_points_input,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    encoder_attention_mask=encoder_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            intermediate += (hidden_states,)
            intermediate_reference_points += (reference_points,)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Keep batch_size as first dimension.
        intermediate = torch.stack(intermediate, dim=1)
        intermediate_reference_points = torch.stack(
            intermediate_reference_points, dim=1
        )

        # Add hidden states from the last decoder layer.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    intermediate,
                    intermediate_reference_points,
                    all_hidden_states,
                    all_self_attns,
                ]
                if v is not None
            )
        return DeformableDetrDecoderOutput(
            last_hidden_state=hidden_states,
            intermediate_hidden_states=intermediate,
            intermediate_reference_points=intermediate_reference_points,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _process_v_features(self, visual_feat):
        """Project the visual features and add positional / level embeddings.

        Multi-scale input is indexed as ``visual_feat[:, level]`` and returned
        with the level and token dimensions flattened into one; single-scale
        input keeps its layout.
        """
        if self.isMs:
            visual_feats = []
            for level in range(self.num_feature_levels):
                visual_feats.append(self.input_proj[level](visual_feat[:, level]))
            visual_feat = torch.stack(visual_feats, 1)

            # The same positional embedding is shared by all levels.
            if self.v_pos_emb is not None:
                visual_feat = visual_feat + self.v_pos_emb.unsqueeze(1)

            # Level embedding tells the levels apart after flattening.
            visual_feat = visual_feat + self.level_emb
            visual_feat = visual_feat.flatten(1, 2)
        else:
            visual_feat = self.input_proj[0](visual_feat)
            if self.v_pos_emb is not None:
                visual_feat = visual_feat + self.v_pos_emb

        return visual_feat

    def _convert_dtype_device(self, tgt_feat, dtype=None, device=None):
        """Return ``tgt_feat`` cast/moved to ``dtype``/``device`` (its own when ``None``)."""
        _dtype = tgt_feat.dtype if dtype is None else dtype
        _device = tgt_feat.device if device is None else device

        return tgt_feat.to(device=_device, dtype=_dtype)

    def _prepare_ddetr_inputs(self, batch_size, seq_len, lvls, dtype=None, device=None):
        """Build ``valid_ratios``, ``spatial_shapes`` and ``level_start_index``.

        Every level is assumed to be a square map of ``seq_len`` tokens with
        no padding (all valid ratios are 1). The tensors are converted only
        when both ``dtype`` and ``device`` are given; the two index tensors
        are always converted to ``torch.long``.
        """
        seq_len = int(seq_len)  # tolerate float lengths from older callers
        valid_ratios = torch.ones(batch_size, lvls, 2)

        side = int(seq_len**0.5)
        spatial_shapes = torch.tensor([side, side]).repeat(lvls, 1)
        level_start_index = torch.arange(0, seq_len * lvls, seq_len)

        if dtype is not None and device is not None:
            valid_ratios = self._convert_dtype_device(
                valid_ratios, dtype=dtype, device=device
            )
            spatial_shapes = self._convert_dtype_device(
                spatial_shapes, dtype=torch.long, device=device
            )
            level_start_index = self._convert_dtype_device(
                level_start_index, dtype=torch.long, device=device
            )

        return valid_ratios, spatial_shapes, level_start_index

    def _make_pooled_queries(self, visual_feat):
        """Build the ``(query_embed, target)`` pair fed to the decoder.

        * ``pooled_v_target == "none"``: both halves are split from the learned
          embedding table.
        * ``pooled_v_target == "tgt"``: the target is the adaptively pooled
          visual feature map, the query embedding is learned.
        * any other value: the query embedding is the pooled visual feature
          map, the target is learned.

        Only single-level features are supported.
        """
        assert (
            self.num_feature_levels == 1
        )

        batch_size, seq_len, h_dim = visual_feat.shape
        query_embeds = self.query_position_embeddings.weight
        if self.pooled_v_target != "none":
            hw_v = int(seq_len**0.5)
            hw_q = int(self.num_queries**0.5)
            visual_feat = rearrange(visual_feat, "b (h w) d -> b d h w", h=hw_v, w=hw_v)
            if self.pooled_v_target == "tgt":
                query_embed = query_embeds.unsqueeze(0).expand(batch_size, -1, -1)
                target = self.downsampler(visual_feat)
                target = rearrange(target, "b d h w -> b (h w) d", h=hw_q, w=hw_q)
            else:
                target = query_embeds.unsqueeze(0).expand(batch_size, -1, -1)
                query_embed = self.downsampler(visual_feat)
                query_embed = rearrange(
                    query_embed, "b d h w -> b (h w) d", h=hw_q, w=hw_q
                )
        else:
            query_embed, target = torch.split(query_embeds, h_dim, dim=1)
            query_embed = query_embed.unsqueeze(0).expand(batch_size, -1, -1)
            target = target.unsqueeze(0).expand(batch_size, -1, -1)

        return query_embed, target

    def forward(self, visual_feat):
        """
        Resample the visual features into ``num_queries`` output tokens.

        Args:
            visual_feat (`torch.FloatTensor`): visual features, ``(batch_size, seq_len, dim)``
                for a single level or ``(batch_size, num_levels, seq_len, dim)`` for
                multi-scale input.

        Returns:
            `DeformableDetrDecoderOutput` whose ``last_hidden_state`` holds the projected
            queries (plus EOS tokens, if configured) in the dtype and on the device of
            the input.
        """
        # The decoder runs in fp32 on the module's own device (previously this
        # was hard-wired to CUDA); the result is converted back at the end.
        original_dtype = visual_feat.dtype
        original_device = visual_feat.device
        visual_feat = visual_feat.to(
            device=self.reference_points.device, dtype=torch.float32
        )
        visual_feat = self._process_v_features(visual_feat)

        batch_size, seq_len, h_dim = visual_feat.shape
        # Tokens per level; integer division keeps this usable as a length.
        seq_len //= self.num_feature_levels

        query_embed, target = self._make_pooled_queries(visual_feat)
        reference_points = self.reference_points.expand(batch_size, -1, -1)

        valid_ratios, spatial_shapes, level_start_index = self._prepare_ddetr_inputs(
            batch_size,
            seq_len,
            self.num_feature_levels,
            visual_feat.dtype,
            visual_feat.device,
        )

        decoder_outputs_dict = self._forward(
            inputs_embeds=target,
            position_embeddings=query_embed,
            encoder_hidden_states=visual_feat,
            valid_ratios=valid_ratios,
            reference_points=reference_points,
            return_dict=True,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
        )

        decoder_outputs = decoder_outputs_dict.last_hidden_state

        if self.eos_tokens is not None:
            decoder_outputs = torch.cat(
                [decoder_outputs, self.eos_tokens.expand(batch_size, -1, -1)], dim=1
            )

        decoder_outputs = self.output_proj(decoder_outputs)
        decoder_outputs = decoder_outputs.to(
            device=original_device, dtype=original_dtype
        )

        return DeformableDetrDecoderOutput(
            last_hidden_state=decoder_outputs,
        )
|
|