Spaces:
Runtime error
Runtime error
| import copy | |
| import math | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from util.misc import inverse_sigmoid | |
| from .deformable_transformer import ( | |
| DeformableTransformerEncoder, | |
| DeformableTransformerEncoderLayer, | |
| MSDeformAttn | |
| ) | |
| from .kv_cache import KVCache, VCache | |
def Embedding(num_embeddings, embedding_dim, padding_idx=None, zero_init=False):
    """Create an ``nn.Embedding`` whose weights start from a scaled normal.

    Weights are drawn from a normal distribution with mean 0 and standard
    deviation ``embedding_dim ** -0.5``. With ``zero_init`` the whole table is
    then set to zero; otherwise, when ``padding_idx`` is given, only that row
    is set to zero.
    """
    table = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    init_std = embedding_dim**-0.5
    nn.init.normal_(table.weight, mean=0, std=init_std)

    if zero_init:
        # Zeroing the full table also covers the padding row.
        nn.init.constant_(table.weight, 0)
    elif padding_idx is not None:
        nn.init.constant_(table.weight[padding_idx], 0)
    return table
def get_1d_sincos_pos_embed_from_grid(embed_dim, seq_len):
    """
    Build a fixed 1-D sine/cosine positional table.

    embed_dim: output dimension for each position (must be even)
    seq_len: number of positions M; positions 0 .. M-1 are encoded
    out: (M, D) array, sine features in the first D/2 columns and cosine
         features in the last D/2 columns
    """
    positions = np.arange(seq_len, dtype=np.float32).reshape(-1)  # (M,)
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    exponents = np.arange(half_dim, dtype=np.float64) / (embed_dim / 2.0)
    inv_freq = 1.0 / 10000**exponents  # (D/2,)

    # Outer product of positions and frequencies via broadcasting.
    angles = positions[:, None] * inv_freq[None, :]  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
class DeformableTransformer(nn.Module):
    """Multi-scale deformable encoder followed by a causal token decoder.

    The encoder consumes flattened multi-level feature maps. The decoder
    (``TransformerDecoder``) embeds the token sequences supplied through
    ``seq_kwargs`` and attends to the encoder memory.
    """

    def __init__(
        self,
        d_model=256,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=1024,
        dropout=0.1,
        activation="relu",
        poly_refine=True,
        return_intermediate_dec=False,
        aux_loss=False,
        num_feature_levels=4,
        dec_n_points=4,
        enc_n_points=4,
        query_pos_type="none",
        vocab_size=None,
        seq_len=1024,
        pre_decoder_pos_embed=False,
        learnable_dec_pe=False,
        dec_attn_concat_src=False,
        dec_qkv_proj=True,
        pad_idx=None,
        use_anchor=False,
        inject_cls_embed=False,
    ):
        """Assemble encoder, decoder and the embeddings owned by this module.

        Args:
            d_model: Model width shared by encoder and decoder.
            nhead: Number of attention heads.
            num_encoder_layers: Number of encoder layers.
            num_decoder_layers: Number of decoder layers.
            dim_feedforward: Hidden width of the feed-forward sub-layers.
            dropout: Dropout probability forwarded to the layers.
            activation: Activation name forwarded to the layers.
            poly_refine: Forwarded to the decoder; also selects the
                reference-point path in ``forward``.
            return_intermediate_dec: Forwarded to the decoder.
            aux_loss: Forwarded to the decoder.
            num_feature_levels: Number of feature levels (rows of
                ``level_embed``).
            dec_n_points: Sampling points for decoder cross attention.
            enc_n_points: Sampling points for encoder attention.
            query_pos_type: Forwarded to the decoder. With ``"sine"`` and
                either ``poly_refine`` or ``use_anchor`` the decoder also gets
                ``pos_trans`` / ``pos_trans_norm`` layers.
            vocab_size: Forwarded to the decoder's token embedding.
            seq_len: Number of positions in ``pos_embed``.
            pre_decoder_pos_embed: Passed to the decoder on every forward.
            learnable_dec_pe: Whether ``pos_embed`` requires gradients.
            dec_attn_concat_src: Passed to the decoder on every forward; when
                true the decoder layers are built without q/k/v projections.
            dec_qkv_proj: Enables q/k/v projections in decoder layers (unless
                ``dec_attn_concat_src`` is set).
            pad_idx: Padding index of the decoder's token embedding.
            use_anchor: Forwarded to the decoder; mutually exclusive with
                ``poly_refine`` (asserted in ``forward``).
            inject_cls_embed: When true the decoder gets a
                ``room_class_trans`` projection.
        """
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.poly_refine = poly_refine
        self.use_anchor = use_anchor
        self.inject_cls_embed = inject_cls_embed

        encoder_layer = DeformableTransformerEncoderLayer(
            d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points
        )
        self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)

        decoder_layer = TransformerDecoderLayer(
            d_model,
            dim_feedforward,
            dropout,
            activation,
            num_feature_levels,
            nhead,
            dec_n_points,
            use_qkv_proj=(dec_qkv_proj and not dec_attn_concat_src),
        )
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            poly_refine,
            return_intermediate_dec,
            aux_loss,
            query_pos_type,
            vocab_size,
            pad_idx,
            use_anchor=use_anchor,
        )

        # Uninitialised storage; filled by nn.init.normal_ in _reset_parameters.
        self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))

        if query_pos_type == "sine" and (poly_refine or use_anchor):
            self.decoder.pos_trans = nn.Linear(d_model, d_model)
            self.decoder.pos_trans_norm = nn.LayerNorm(d_model)

        self.pre_decoder_pos_embed = pre_decoder_pos_embed
        # The sin/cos values are written by _reset_parameters; doing it here
        # would be undone by the Xavier pass that runs afterwards.
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, d_model), requires_grad=learnable_dec_pe)
        self.dec_attn_concat_src = dec_attn_concat_src

        if self.inject_cls_embed:
            self.decoder.room_class_trans = nn.Sequential(
                nn.Linear(d_model, d_model, bias=False), nn.LayerNorm(d_model)
            )

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise all parameters, then restore the sin/cos ``pos_embed``.

        Every parameter with more than one dimension gets Xavier-uniform
        values, ``MSDeformAttn`` modules then run their own reset, and
        ``level_embed`` is drawn from a standard normal.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        nn.init.normal_(self.level_embed)

        # pos_embed is 3-D, so the Xavier loop above overwrote it. The loop is
        # left untouched (keeping the random stream of every other parameter
        # as before) and the fixed sin/cos table is written afterwards.
        # NOTE(review): the same loop also re-initialises
        # decoder.token_embed.weight, including its padding row.
        _, seq_len, d_model = self.pos_embed.shape
        sincos = get_1d_sincos_pos_embed_from_grid(d_model, seq_len)
        self.pos_embed.data.copy_(torch.from_numpy(sincos).float().unsqueeze(0))

    def get_valid_ratio(self, mask):
        """Return the (width, height) fraction of unmasked positions per sample.

        ``mask`` is unpacked as ``(batch, H, W)``; positions where it is False
        are counted along the first column (height) and first row (width).
        The result has a trailing dimension of 2 ordered ``[w, h]``.
        """
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        return torch.stack([valid_ratio_w, valid_ratio_h], -1)

    def _create_causal_attention_mask(self, seq_len):
        """
        Creates a causal attention mask for a sequence of length `seq_len`.

        Entry (i, j) is ``-inf`` when j > i and ``0.0`` otherwise.
        """
        # 1s strictly above the diagonal mark the positions to hide.
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
        # 1 -> -inf (masked), 0 -> 0 (unmasked)
        causal_mask = mask.masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0)
        return causal_mask

    def forward(
        self,
        srcs,
        masks,
        pos_embeds,
        query_embed=None,
        tgt=None,
        tgt_masks=None,
        seq_kwargs=None,
        force_simple_returns=False,
        return_enc_cache=False,
        enc_cache=None,
        decode_token_pos=None,
    ):
        """Encode the feature maps (or reuse a cache) and run the decoder.

        Args:
            srcs: Per-level feature maps, each unpacked as ``(bs, c, h, w)``.
            masks: Per-level masks, each ``(bs, h, w)``; see
                ``get_valid_ratio``.
            pos_embeds: Per-level positional encodings, flattened the same way
                as ``srcs``.
            query_embed: Required when ``poly_refine`` or ``use_anchor`` is
                set; its sigmoid provides the initial reference points.
            tgt: Passed through to the decoder.
            tgt_masks: Self-attention mask for the decoder. When omitted, a
                causal mask sized from ``seq_kwargs["seq11"]`` is built, or an
                all-zero row when ``decode_token_pos`` is given.
            seq_kwargs: Dict of token sequences and interpolation weights
                consumed by the decoder.
            force_simple_returns: Passed through to the decoder.
            return_enc_cache: When true the encoder cache dict is appended to
                the returned tuple.
            enc_cache: A dict previously returned through
                ``return_enc_cache``; when given, the encoder is skipped.
            decode_token_pos: Position of the token being decoded, passed
                through to the decoder.

        Returns:
            ``(hs, init_reference_out, inter_references, inter_classes)``,
            followed by the encoder cache dict when ``return_enc_cache`` is
            true.

        Raises:
            ValueError: If ``query_embed`` is missing although ``poly_refine``
                or ``use_anchor`` is enabled.
        """
        if enc_cache is None:
            # prepare input for encoder
            src_flatten = []
            mask_flatten = []
            lvl_pos_embed_flatten = []
            spatial_shapes = []
            for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
                _, _, h, w = src.shape
                spatial_shapes.append((h, w))
                src_flatten.append(src.flatten(2).transpose(1, 2))
                mask_flatten.append(mask.flatten(1))
                pos_embed = pos_embed.flatten(2).transpose(1, 2)
                lvl_pos_embed_flatten.append(pos_embed + self.level_embed[lvl].view(1, 1, -1))
            src_flatten = torch.cat(src_flatten, 1)
            mask_flatten = torch.cat(mask_flatten, 1)
            lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
            spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
            # Offset of each level inside the flattened token axis.
            level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
            valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

            # encoder
            memory = self.encoder(
                src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten
            )
            enc_cache_output = {
                "memory": memory,
                "spatial_shapes": spatial_shapes,
                "level_start_index": level_start_index,
                "valid_ratios": valid_ratios,
                "mask_flatten": mask_flatten,
                "src_flatten": src_flatten,
            }
        else:
            memory = enc_cache["memory"]
            spatial_shapes = enc_cache["spatial_shapes"]
            level_start_index = enc_cache["level_start_index"]
            valid_ratios = enc_cache["valid_ratios"]
            mask_flatten = enc_cache["mask_flatten"]
            src_flatten = enc_cache["src_flatten"]
            enc_cache_output = enc_cache

        # prepare input for decoder
        bs, _, c = memory.shape
        assert not (self.use_anchor and self.poly_refine), "use_anchor and poly_refine cannot be used together"
        if self.poly_refine or self.use_anchor:
            if query_embed is None:
                raise ValueError("query_embed is required when poly_refine or use_anchor is enabled")
            query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
            reference_points = query_embed.sigmoid()
            query_pos = None  # inferred from reference_points
        else:
            reference_points = None
            query_pos = self.pos_embed
        init_reference_out = reference_points

        if tgt_masks is None:
            if decode_token_pos is not None:
                # Single-step decoding: one query row over every position so far.
                tgt_masks = torch.zeros(1, int(decode_token_pos.max()) + 1, dtype=torch.float).to(memory.device)
            else:
                tgt_masks = self._create_causal_attention_mask(seq_kwargs["seq11"].shape[1]).to(memory.device)

        # decoder
        hs, inter_references, inter_classes = self.decoder(
            tgt,
            reference_points,
            memory,
            src_flatten,
            spatial_shapes,
            level_start_index,
            valid_ratios,
            query_pos,
            mask_flatten,
            tgt_masks,
            seq_kwargs,
            force_simple_returns=force_simple_returns,
            pre_decoder_pos_embed=self.pre_decoder_pos_embed,
            attn_concat_src=self.dec_attn_concat_src,
            decode_token_pos=decode_token_pos,
        )

        if return_enc_cache:
            return hs, init_reference_out, inter_references, inter_classes, enc_cache_output
        return hs, init_reference_out, inter_references, inter_classes

    def _setup_caches(self, max_batch_size, max_seq_length, max_vision_length, model_dim, nhead, dtype, device):
        """Attach a ``KVCache`` to each decoder layer and a ``VCache`` to its cross attention."""
        head_dim = int(model_dim // nhead)
        for layer in self.decoder.layers:
            layer.kv_cache = KVCache(max_batch_size, max_seq_length, model_dim, dtype).to(device)
            layer.cross_attn.cache = VCache(max_batch_size, max_vision_length, nhead, head_dim, dtype).to(device)
class TransformerDecoderLayer(nn.Module):
    """One decoder layer: masked self attention, deformable cross attention, FFN.

    Each sub-layer is followed by dropout, a residual connection and a
    LayerNorm (post-norm ordering).
    """

    def __init__(
        self,
        d_model=256,
        d_ffn=1024,
        dropout=0.1,
        activation="relu",
        n_levels=4,
        n_heads=8,
        n_points=4,
        use_qkv_proj=True,
    ):
        """Build the sub-layers.

        Args:
            d_model: Model width.
            d_ffn: Hidden width of the feed-forward block.
            dropout: Dropout probability used throughout the layer.
            activation: Name resolved through ``_get_activation_fn``.
            n_levels: Feature levels for ``MSDeformAttn``.
            n_heads: Attention heads for both attention modules.
            n_points: Sampling points for ``MSDeformAttn``.
            use_qkv_proj: When false, the extra q/k/v projections applied
                before self attention are replaced by identities.
        """
        super().__init__()
        self.d_model = d_model

        if use_qkv_proj:
            self.attn_q = nn.Linear(d_model, d_model, bias=False)
            self.attn_k = nn.Linear(d_model, d_model, bias=False)
            self.attn_v = nn.Linear(d_model, d_model, bias=False)
        else:
            self.attn_q = nn.Identity()
            self.attn_k = nn.Identity()
            self.attn_v = nn.Identity()

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # cross attention
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

        # Stays None until a cache object is assigned from outside the layer.
        self.kv_cache = None

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Add ``pos`` (cropped along dim 1 to ``tensor``'s length) to ``tensor``.

        Returns ``tensor`` unchanged when ``pos`` is None. Declared static
        because it takes no ``self`` yet is invoked as ``self.with_pos_embed``.
        """
        return tensor if pos is None else tensor + pos[:, : tensor.size(1)]

    def forward_ffn(self, tgt):
        """Apply the feed-forward block with residual connection and LayerNorm."""
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        return self.norm3(tgt)

    def forward(
        self,
        tgt,
        query_pos,
        reference_points,
        src,
        src_spatial_shapes,
        level_start_index,
        src_padding_mask=None,
        tgt_masks=None,
        attn_concat_src=False,
        input_pos=None,
    ):
        """Run self attention, cross attention and the FFN on ``tgt``.

        Args:
            tgt: Decoder hidden states; dim 1 is treated as the sequence axis
                (the tensors are transposed before ``nn.MultiheadAttention``).
            query_pos: Optional positional term added to the queries.
            reference_points: Reference points handed to ``MSDeformAttn``.
            src: Encoder memory.
            src_spatial_shapes: Per-level spatial shapes of ``src``.
            level_start_index: Per-level start offsets inside ``src``.
            src_padding_mask: Padding mask for ``src``.
            tgt_masks: Additive self-attention mask.
            attn_concat_src: When true ``src`` is prepended to the self
                attention keys/values and the mask is widened with zeros.
            input_pos: Token position(s) for cached decoding. Enables the
                self-attention KV cache when one is attached, and the
                cross-attention cache when its first entry is non-zero.

        Returns:
            ``(tgt, None)``.
        """
        q = self.with_pos_embed(self.attn_q(tgt), query_pos)

        # self attention
        k = self.attn_k(tgt)
        v = self.attn_v(tgt)
        if self.kv_cache is not None and input_pos is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        if attn_concat_src:
            k = torch.cat([src, k], dim=1)
            v = torch.cat([src, v], dim=1)
            # Every query may attend to all of the prepended memory tokens.
            tgt_masks = torch.cat([torch.zeros(q.size(1), src.size(1), device=q.device), tgt_masks], dim=1).to(
                dtype=torch.float32
            )

        tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), attn_mask=tgt_masks)[
            0
        ].transpose(0, 1)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # cross attention
        tgt2 = self.cross_attn(
            self.with_pos_embed(tgt, query_pos),
            reference_points,
            src,
            src_spatial_shapes,
            level_start_index,
            src_padding_mask,
            use_cache=(input_pos is not None and input_pos[0] != 0),
        )  # disable cache when processing first token
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ffn
        tgt = self.forward_ffn(tgt)
        return tgt, None
class TransformerDecoder(nn.Module):
    """Stack of decoder layers with token embedding and per-layer prediction heads.

    ``coords_embed``, ``class_embed``, ``room_class_embed``, ``pos_trans``,
    ``pos_trans_norm`` and ``room_class_trans`` start as None here. ``forward``
    indexes ``coords_embed`` and ``class_embed``, so they have to be assigned
    from outside before the decoder is called.
    """

    def __init__(
        self,
        decoder_layer,
        num_layers,
        poly_refine=True,
        return_intermediate=False,
        aux_loss=False,
        query_pos_type="none",
        vocab_size=None,
        pad_idx=None,
        use_anchor=None,
    ):
        """Clone ``decoder_layer`` ``num_layers`` times and create the token embedding.

        Args:
            decoder_layer: Template layer (or a list of layers, see
                ``_get_clones``); its ``d_model`` sets the embedding width.
            num_layers: Number of layers in the stack.
            poly_refine: Refine reference points after every layer.
            return_intermediate: Collect the outputs of every layer.
            aux_loss: Predict classes from every layer rather than only the last.
            query_pos_type: ``"sine"`` derives the query position from the
                reference points; ``"none"`` disables it.
            vocab_size: Number of rows in the token embedding.
            pad_idx: Padding index of the token embedding.
            use_anchor: Refine the initial reference points once, after the
                last layer.
        """
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.poly_refine = poly_refine
        self.return_intermediate = return_intermediate
        self.aux_loss = aux_loss
        self.query_pos_type = query_pos_type
        self.coords_embed = None
        self.class_embed = None
        self.pos_trans = None
        self.pos_trans_norm = None
        self.use_anchor = use_anchor
        self.room_class_embed = None
        self.room_class_trans = None
        self.token_embed = Embedding(vocab_size, self.layers[0].d_model, padding_idx=pad_idx, zero_init=False)

    def _seq_embed(self, seq11, seq12, seq21, seq22, delta_x1, delta_x2, delta_y1, delta_y2):
        """Blend four token embeddings with bilinear-interpolation weights.

        Each ``seq*`` holds token ids and each ``delta_*`` a weight of the same
        leading shape; the result has an extra trailing embedding dimension.
        """
        # embedding [B, L, D]
        e11 = self.token_embed(seq11)
        e21 = self.token_embed(seq21)
        e12 = self.token_embed(seq12)
        e22 = self.token_embed(seq22)

        # bilinear interpolation [B, L, D]
        out = (
            e11 * delta_x2[..., None] * delta_y2[..., None]
            + e21 * delta_x1[..., None] * delta_y2[..., None]
            + e12 * delta_x2[..., None] * delta_y1[..., None]
            + e22 * delta_x1[..., None] * delta_y1[..., None]
        )
        return out

    def _add_cls_embed(self, x, input_cls_seq):
        """Add a projected class embedding, built from one-hot labels, to ``x``."""
        # NOTE(review): room_class_embed is read both as a single layer
        # (``.out_features``) and as an indexable container (``[-1]``);
        # confirm which object the owner assigns here.
        one_hot = F.one_hot(input_cls_seq, num_classes=self.room_class_embed.out_features).float()
        x = x + self.room_class_trans(self.room_class_embed[-1](one_hot))
        return x

    def get_query_pos_embed(self, ref_points):
        """Encode 2-D reference points as sine/cosine features.

        ``ref_points`` is indexed as ``(N, L, 2)``; the result is ``(N, L, 256)``
        (128 features per coordinate).
        """
        num_pos_feats = 128
        temperature = 10000
        scale = 2 * math.pi

        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=ref_points.device)
        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)  # [128]
        # N, L, 2
        ref_points = ref_points * scale
        # N, L, 2, 128
        pos = ref_points[:, :, :, None] / dim_t
        # N, L, 256
        pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
        return pos

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Add ``pos`` (cropped along dim 1 to ``tensor``'s length) to ``tensor``.

        Returns ``tensor`` unchanged when ``pos`` is None. Declared static
        because it takes no ``self`` yet is invoked as ``self.with_pos_embed``.
        """
        return tensor if pos is None else tensor + pos[:, : tensor.size(1)]

    def forward(
        self,
        tgt,
        reference_points,
        src,
        src_flatten,
        src_spatial_shapes,
        src_level_start_index,
        src_valid_ratios,
        query_pos=None,
        src_padding_mask=None,
        tgt_masks=None,
        seq_kwargs=None,
        force_simple_returns=False,
        pre_decoder_pos_embed=False,
        attn_concat_src=False,
        decode_token_pos=None,
    ):
        """Embed the token sequences and run them through every decoder layer.

        Args:
            tgt: Not read by this method.
            reference_points: Initial reference points, or None (zeros are
                substituted).
            src: Encoder memory handed to every layer.
            src_flatten: Not read by this method.
            src_spatial_shapes: Per-level spatial shapes of ``src``.
            src_level_start_index: Per-level start offsets inside ``src``.
            src_valid_ratios: Per-level valid ratios used to scale the
                reference points.
            query_pos: Optional absolute positional embedding.
            src_padding_mask: Padding mask for ``src``.
            tgt_masks: Self-attention mask handed to every layer.
            seq_kwargs: Dict with ``seq11/seq12/seq21/seq22`` and
                ``delta_x1/delta_x2/delta_y1/delta_y2``, plus
                ``input_polygon_labels`` when ``room_class_trans`` is set.
            force_simple_returns: Return only the final-layer values even when
                ``return_intermediate`` is enabled.
            pre_decoder_pos_embed: Add the positional term once before the
                first layer instead of inside every layer.
            attn_concat_src: Forwarded to every layer.
            decode_token_pos: Position of the token being decoded; selects the
                matching slice of ``query_pos`` / ``reference_points`` and is
                forwarded to every layer as ``input_pos``.

        Returns:
            ``(output, reference_points, point_classes)`` of the last layer, or
            the three per-layer stacks when ``return_intermediate`` is set and
            ``force_simple_returns`` is false.
        """
        output = self._seq_embed(
            seq11=seq_kwargs["seq11"],
            seq12=seq_kwargs["seq12"],
            seq21=seq_kwargs["seq21"],
            seq22=seq_kwargs["seq22"],
            delta_x1=seq_kwargs["delta_x1"],
            delta_x2=seq_kwargs["delta_x2"],
            delta_y1=seq_kwargs["delta_y1"],
            delta_y2=seq_kwargs["delta_y2"],
        )  # [B, L, D]

        if decode_token_pos is not None:
            if query_pos is not None:  # if using abs pos_embed
                query_pos = query_pos[:, decode_token_pos]
            if reference_points is not None:
                reference_points = reference_points[:, decode_token_pos : decode_token_pos + 1]

        if reference_points is None:
            reference_points = torch.zeros(output.shape[0], output.shape[1], 2, device=output.device)

        if pre_decoder_pos_embed:
            # infer query_pos from reference_points
            if (self.poly_refine or self.use_anchor) and self.query_pos_type == "sine":
                query_pos = self.pos_trans_norm(self.pos_trans(self.get_query_pos_embed(reference_points)))
            output = self.with_pos_embed(output, query_pos)
            # Already folded into the input; the layers must not add it again.
            query_pos = None

        if self.room_class_trans is not None:
            # add class embedding
            output = self._add_cls_embed(output, seq_kwargs["input_polygon_labels"])

        intermediate = []
        intermediate_reference_points = []
        intermediate_classes = []
        point_classes = torch.zeros(
            output.shape[0], output.shape[1], self.class_embed[0].out_features, device=output.device
        )
        last_lid = len(self.layers) - 1

        for lid, layer in enumerate(self.layers):
            if self.poly_refine or self.use_anchor:
                assert reference_points.shape[-1] == 2
                reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
                # disable adding query_pos for every layer
                if not pre_decoder_pos_embed:
                    if self.query_pos_type == "sine":
                        query_pos = self.pos_trans_norm(self.pos_trans(self.get_query_pos_embed(reference_points)))
                    elif self.query_pos_type == "none":
                        query_pos = None
            else:
                reference_points_input = None

            output, src_tmp = layer(
                output,
                query_pos,
                reference_points_input,
                src,
                src_spatial_shapes,
                src_level_start_index,
                src_padding_mask,
                tgt_masks,
                attn_concat_src=attn_concat_src,
                input_pos=decode_token_pos,
            )
            if src_tmp is not None:
                src = src_tmp

            if self.poly_refine:
                # iterative polygon refinement: offset in logit space, then squash
                offset = self.coords_embed[lid](output)
                assert reference_points.shape[-1] == 2
                reference_points = (offset + inverse_sigmoid(reference_points)).sigmoid()
            elif lid == last_lid:
                # without iterative refinement only the last layer yields coordinates
                if self.use_anchor:
                    offset = self.coords_embed[-1](output)
                    assert reference_points.shape[-1] == 2
                    reference_points = (offset + inverse_sigmoid(reference_points)).sigmoid()
                else:
                    reference_points = self.coords_embed[-1](output).sigmoid()

            if self.aux_loss:
                # with auxiliary supervision every layer predicts class labels
                point_classes = self.class_embed[lid](output)
            elif lid == last_lid:
                # otherwise only the last layer predicts class labels
                point_classes = self.class_embed[-1](output)

            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)
                intermediate_classes.append(point_classes)

        if self.return_intermediate and not force_simple_returns:
            return (
                torch.stack(intermediate),
                torch.stack(intermediate_reference_points),
                torch.stack(intermediate_classes),
            )
        return output, reference_points, point_classes
def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` deep copies of ``module``.

    When ``module`` is already a list, it is wrapped as given and ``N`` is
    ignored.
    """
    if not isinstance(module, list):
        module = [copy.deepcopy(module) for _ in range(N)]
    return nn.ModuleList(module)
def _get_activation_fn(activation):
    """Return an activation function given a string.

    Accepts ``"relu"``, ``"gelu"`` and ``"glu"`` and returns the matching
    function from ``torch.nn.functional``.

    Raises:
        RuntimeError: If ``activation`` is none of the supported names.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    # The message lists every accepted name, including "glu".
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
def build_deforamble_transformer(args, pad_idx=None):
    """Construct a ``DeformableTransformer`` from a parsed-argument namespace.

    Each option is read from the attribute of ``args`` shown below. The
    activation is fixed to ``"relu"`` and intermediate decoder outputs are
    always requested. ``inject_cls_embed`` falls back to False when ``args``
    has no such attribute.
    """
    model_kwargs = dict(
        d_model=args.hidden_dim,
        nhead=args.nheads,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        dim_feedforward=args.dim_feedforward,
        dropout=args.dropout,
        activation="relu",
        poly_refine=args.with_poly_refine,
        return_intermediate_dec=True,
        aux_loss=args.aux_loss,
        num_feature_levels=args.num_feature_levels,
        dec_n_points=args.dec_n_points,
        enc_n_points=args.enc_n_points,
        query_pos_type=args.query_pos_type,
        vocab_size=args.vocab_size,
        seq_len=args.seq_len,
        pre_decoder_pos_embed=args.pre_decoder_pos_embed,
        learnable_dec_pe=args.learnable_dec_pe,
        dec_attn_concat_src=args.dec_attn_concat_src,
        dec_qkv_proj=args.dec_qkv_proj,
        pad_idx=pad_idx,
        use_anchor=args.use_anchor,
        inject_cls_embed=getattr(args, "inject_cls_embed", False),
    )
    return DeformableTransformer(**model_kwargs)