Spaces:
Runtime error
Runtime error
| import torch | |
| import einops | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from torch import Tensor, nn | |
| from typing import Optional | |
| from leo.utils import get_activation_fn, layer_repeat, calc_pairwise_locs | |
def disabled_train(self, mode=True):
    """No-op replacement for ``nn.Module.train``.

    Assign this function over a module's ``train`` to freeze that module in its
    current train/eval state: later ``train(...)`` calls no longer change anything.
    ``mode`` is accepted only for signature compatibility and is ignored.

    Returns the first positional argument unchanged. Bound as a method, that is
    the module. Assigned as a plain instance attribute and invoked by a parent's
    ``train(mode)``, that is the mode value.
    """
    # Deliberately ignore `mode`: the whole point is that the state never changes.
    return self
class TransformerEncoderLayer(nn.Module):
    """Transformer encoder block: self-attention sub-layer then feed-forward sub-layer.

    Each sub-layer is wrapped in a residual connection with dropout.

    * ``prenorm=False`` (default): post-norm,
      ``x = norm1(x + attn(x))`` then ``x = norm2(x + ffn(x))``.
    * ``prenorm=True``: ``x = x + attn(norm1(x))`` then ``x = norm2(x)``
      followed by ``x = x + ffn(x)``.

    NOTE(review): in the ``prenorm=True`` path ``norm2`` replaces the residual
    stream itself instead of only normalizing the FFN input as in standard pre-LN
    (``x = x + ffn(norm2(x))``). Kept as-is because models trained with this flag
    depend on it — confirm before changing.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        batch_first=True,
        dropout=0.1,
        activation="relu",
        prenorm=False,
    ):
        super().__init__()
        # Submodules are created in a fixed order. Construction consumes the global RNG
        # and registration order drives `.apply(...)`, so keep this order stable.
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        # Feed-forward sub-layer: linear1 -> activation -> dropout -> linear2.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)  # on the attention residual branch
        self.dropout2 = nn.Dropout(dropout)  # on the feed-forward residual branch
        self.activation = get_activation_fn(activation)
        self.prenorm = prenorm

    def forward(
        self,
        tgt,
        tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
    ):
        """Run the block on ``tgt``.

        Args:
            tgt: input sequence, laid out as ``nn.MultiheadAttention`` expects
                (batch-first unless the layer was built with ``batch_first=False``).
            tgt_mask: optional ``attn_mask`` forwarded to the attention.
            tgt_key_padding_mask: optional ``key_padding_mask`` forwarded to the attention.

        Returns:
            ``(output, attention_weights)``. The weights are whatever
            ``nn.MultiheadAttention`` returns with its default arguments.
        """
        # Self-attention sub-layer.
        attn_input = self.norm1(tgt) if self.prenorm else tgt
        attn_output, self_attn_matrices = self.self_attn(
            query=attn_input,
            key=attn_input,
            value=attn_input,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )
        tgt = tgt + self.dropout1(attn_output)

        # Normalization between the two sub-layers (see class NOTE for the prenorm case).
        if self.prenorm:
            tgt = self.norm2(tgt)
        else:
            tgt = self.norm1(tgt)

        # Feed-forward sub-layer.
        ffn_output = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout2(ffn_output)
        if not self.prenorm:
            tgt = self.norm2(tgt)
        return tgt, self_attn_matrices
class MultiHeadAttentionSpatial(nn.Module):
    """Multi-head attention whose weights are modulated by pairwise spatial features.

    Standard scaled dot-product logits ``q·k / sqrt(d_k)`` are fused with spatial
    scores computed from ``pairwise_locs`` according to ``spatial_attn_fusion``:

    * ``'mul'``:  spatial weights ``relu(Linear(pairwise_locs))`` multiply the
      attention probabilities (implemented as ``softmax(logits + log(w))``).
    * ``'bias'``: ``Linear(pairwise_locs)`` is added to the logits.
    * ``'add'``:  the average of ``softmax(logits)`` and ``softmax(Linear(pairwise_locs))``.
    * ``'ctx'``:  ``pairwise_locs`` is projected to ``d_model`` and dotted with each
      head's query (scaled like the logits); the result is added to the logits.
    * ``'cond'``: per-query, per-head weights and a bias are predicted from the query
      input. The spatial weight is ``sigmoid(w · pairwise_loc + b)`` and is fused
      multiplicatively as in ``'mul'``.

    With ``spatial_multihead=False`` a single spatial score is shared by all heads.
    The output goes through an output projection, dropout, a residual connection to
    the query input, and LayerNorm.
    """

    def __init__(
        self, d_model, n_head, dropout=0.1, spatial_multihead=True, spatial_dim=5,
        spatial_attn_fusion='mul',
    ):
        super().__init__()
        assert d_model % n_head == 0, 'd_model: %d, n_head: %d' % (d_model, n_head)

        self.n_head = n_head
        self.d_model = d_model
        self.d_per_head = d_model // n_head
        self.spatial_multihead = spatial_multihead
        self.spatial_dim = spatial_dim
        self.spatial_attn_fusion = spatial_attn_fusion

        # Submodule creation order is kept stable: construction consumes the global RNG.
        self.w_qs = nn.Linear(d_model, d_model)
        self.w_ks = nn.Linear(d_model, d_model)
        self.w_vs = nn.Linear(d_model, d_model)

        self.fc = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(d_model)

        self.spatial_n_head = n_head if spatial_multihead else 1
        if self.spatial_attn_fusion in ['mul', 'bias', 'add']:
            # One spatial score per (spatial) head.
            self.pairwise_loc_fc = nn.Linear(spatial_dim, self.spatial_n_head)
        elif self.spatial_attn_fusion == 'ctx':
            # Spatial features projected into the query space.
            self.pairwise_loc_fc = nn.Linear(spatial_dim, d_model)
        elif self.spatial_attn_fusion == 'cond':
            # Per spatial head: 1 bias + spatial_dim weights, predicted from the query input.
            self.lang_cond_fc = nn.Linear(d_model, self.spatial_n_head * (spatial_dim + 1))
        else:
            raise NotImplementedError('unsupported spatial_attn_fusion %s' % (self.spatial_attn_fusion))

    def forward(self, q, k, v, pairwise_locs, key_padding_mask=None, txt_embeds=None):
        """Attend from ``q`` to ``k``/``v`` with spatially fused attention weights.

        Args:
            q: (b, l, d_model) queries. Also used as the residual and, for
                ``'cond'``, as the input that predicts the spatial weights.
            k: (b, t, d_model) keys.
            v: (b, t, d_model) values.
            pairwise_locs: (b, l, t, spatial_dim) spatial features for every
                query/key pair.
            key_padding_mask: optional (b, t) bool tensor; True marks keys to ignore.
            txt_embeds: unused; kept for interface compatibility.

        Returns:
            output: (b, l, d_model), ``LayerNorm(q + dropout(fc(attended values)))``.
            fused_attn: (n_head, b, l, t) fused attention probabilities.

        Raises:
            AssertionError: if the fused attention contains NaN. For example, this
                happens when every key of a sample is padded: the softmax runs over a
                row that is all -inf.
        """
        residual = q
        # Project and split into heads: (b, seq, head * dk) -> (head, b, seq, dk).
        q = einops.rearrange(self.w_qs(q), 'b l (head k) -> head b l k', head=self.n_head)
        k = einops.rearrange(self.w_ks(k), 'b t (head k) -> head b t k', head=self.n_head)
        v = einops.rearrange(self.w_vs(v), 'b t (head v) -> head b t v', head=self.n_head)
        attn = torch.einsum('hblk,hbtk->hblt', q, k) / np.sqrt(q.shape[-1])

        if self.spatial_attn_fusion in ['mul', 'bias', 'add']:
            loc_attn = self.pairwise_loc_fc(pairwise_locs)
            loc_attn = einops.rearrange(loc_attn, 'b l t h -> h b l t')
            if self.spatial_attn_fusion == 'mul':
                loc_attn = F.relu(loc_attn)
            if not self.spatial_multihead:
                # Share the single spatial score across all attention heads.
                loc_attn = einops.repeat(loc_attn, 'h b l t -> (h nh) b l t', nh=self.n_head)
        elif self.spatial_attn_fusion == 'ctx':
            loc_attn = self.pairwise_loc_fc(pairwise_locs)
            loc_attn = einops.rearrange(loc_attn, 'b l t (h k) -> h b l t k', h=self.n_head)
            loc_attn = torch.einsum('hblk,hbltk->hblt', q, loc_attn) / np.sqrt(q.shape[-1])
        elif self.spatial_attn_fusion == 'cond':
            spatial_weights = self.lang_cond_fc(residual)
            spatial_weights = einops.rearrange(
                spatial_weights, 'b l (h d) -> h b l d', h=self.spatial_n_head, d=self.spatial_dim + 1
            )
            if self.spatial_n_head == 1:
                spatial_weights = einops.repeat(spatial_weights, '1 b l d -> h b l d', h=self.n_head)
            # Channel 0 is the bias; the remaining spatial_dim channels weight pairwise_locs.
            spatial_bias = spatial_weights[..., :1]
            spatial_weights = spatial_weights[..., 1:]
            loc_attn = torch.einsum('hbld,bltd->hblt', spatial_weights, pairwise_locs) + spatial_bias
            loc_attn = torch.sigmoid(loc_attn)

        if key_padding_mask is not None:
            # (b, t) -> (1, b, 1, t): broadcasts over heads and query positions.
            mask = key_padding_mask[None, :, None, :]
            attn = attn.masked_fill(mask, -np.inf)
            if self.spatial_attn_fusion in ['mul', 'cond']:
                # Multiplicative fusions: a zero spatial weight removes the key.
                loc_attn = loc_attn.masked_fill(mask, 0)
            else:
                loc_attn = loc_attn.masked_fill(mask, -np.inf)

        if self.spatial_attn_fusion == 'add':
            fused_attn = (torch.softmax(attn, 3) + torch.softmax(loc_attn, 3)) / 2
        else:
            if self.spatial_attn_fusion in ['mul', 'cond']:
                # softmax(logits + log w) == softmax(logits) * w, renormalized.
                # The clamp keeps log() finite where w == 0.
                fused_attn = torch.log(torch.clamp(loc_attn, min=1e-6)) + attn
            else:
                fused_attn = loc_attn + attn
            fused_attn = torch.softmax(fused_attn, 3)

        # The earlier check `torch.sum(torch.isnan(x) == 0)` counted the NON-NaN entries,
        # so it only failed when every entry was NaN. Check for any NaN instead.
        assert not torch.isnan(fused_attn).any(), (
            'NaN in fused spatial attention weights (%d entries)' % int(torch.isnan(fused_attn).sum())
        )

        output = torch.einsum('hblt,hbtv->hblv', fused_attn, v)
        output = einops.rearrange(output, 'head b l v -> b l (head v)')
        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)
        return output, fused_attn
class TransformerSpatialEncoderLayer(TransformerEncoderLayer):
    """Post-norm encoder layer whose self-attention is ``MultiHeadAttentionSpatial``.

    Reuses the feed-forward sub-layer, the two LayerNorms and the dropouts built by
    ``TransformerEncoderLayer``. It swaps ``self_attn`` for spatially modulated
    attention that also takes pairwise location features. That attention module
    applies its own residual + LayerNorm, and this layer adds a second residual
    (``norm1``) on top.
    """

    def __init__(
        self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
        spatial_multihead=True, spatial_dim=5, spatial_attn_fusion='mul'
    ):
        super().__init__(
            d_model, nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation
        )
        # Drop the parent's standard attention before installing the spatial one.
        # Deleting first means self_attn is re-registered after the other children,
        # which determines module iteration order (e.g. for `.apply(...)`).
        del self.self_attn
        self.self_attn = MultiHeadAttentionSpatial(
            d_model,
            nhead,
            dropout=dropout,
            spatial_multihead=spatial_multihead,
            spatial_dim=spatial_dim,
            spatial_attn_fusion=spatial_attn_fusion,
        )

    def forward(
        self, tgt, tgt_pairwise_locs,
        tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
    ):
        """Apply spatial self-attention then the feed-forward sub-layer (post-norm).

        Args:
            tgt: token sequence fed as query, key and value.
            tgt_pairwise_locs: pairwise spatial features passed to the attention.
            tgt_mask: accepted for interface compatibility; not used.
            tgt_key_padding_mask: forwarded as the attention's ``key_padding_mask``.

        Returns:
            ``(output, attention_weights)`` where the weights come from ``self_attn``.
        """
        attn_output, self_attn_matrices = self.self_attn(
            tgt, tgt, tgt, tgt_pairwise_locs,
            key_padding_mask=tgt_key_padding_mask,
        )
        hidden = self.norm1(tgt + self.dropout1(attn_output))
        ffn_output = self.linear2(self.dropout(self.activation(self.linear1(hidden))))
        return self.norm2(hidden + self.dropout2(ffn_output)), self_attn_matrices
| def _init_weights_bert(module, std=0.02): | |
| """ | |
| Huggingface transformer weight initialization, | |
| most commonly for bert initialization | |
| """ | |
| if isinstance(module, nn.Linear): | |
| # Slightly different from the TF version which uses truncated_normal for initialization | |
| # cf https://github.com/pytorch/pytorch/pull/5617 | |
| module.weight.data.normal_(mean=0.0, std=std) | |
| if module.bias is not None: | |
| module.bias.data.zero_() | |
| elif isinstance(module, nn.Embedding): | |
| module.weight.data.normal_(mean=0.0, std=std) | |
| if module.padding_idx is not None: | |
| module.weight.data[module.padding_idx].zero_() | |
| elif isinstance(module, nn.LayerNorm): | |
| module.bias.data.zero_() | |
| module.weight.data.fill_(1.0) | |
def generate_fourier_features(pos, num_bands=10, max_freq=15, concat_pos=True, sine_only=False):
    """Encode coordinates with sinusoidal (Fourier) features.

    Each of the ``C`` input channels is multiplied by ``num_bands`` frequencies
    linearly spaced in ``[1.0, max_freq]`` and scaled by pi. The results go through
    sin, and also cos unless ``sine_only`` is set.

    Args:
        pos: tensor of shape (..., C) with at least one dimension. Any number of
            leading dims is accepted; the previous implementation only worked for
            (B, N, C).
        num_bands: number of frequencies per channel.
        max_freq: highest frequency; the lowest is 1.0.
        concat_pos: if True, prepend the raw ``pos`` values to the features.
        sine_only: if True, emit only the sin features.

    Returns:
        Tensor of shape (..., C'). ``C' = C * num_bands * (1 if sine_only else 2)``,
        plus ``C`` when ``concat_pos``. The layout is ``[pos | sin | cos]``. Within
        each sin/cos group the layout is channel-major:
        ``[c0*f0, c0*f1, ..., c1*f0, ...]``.
    """
    min_freq = 1.0
    freq_bands = torch.linspace(start=min_freq, end=max_freq, steps=num_bands, device=pos.device)

    # (..., C, 1) * (num_bands,) -> (..., C, num_bands) -> (..., C * num_bands)
    scaled = (pos.unsqueeze(-1) * freq_bands).flatten(start_dim=-2)
    angles = np.pi * scaled

    if sine_only:
        features = torch.sin(angles)
    else:
        features = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    if concat_pos:
        features = torch.cat([pos, features], dim=-1)
    return features
class OSE3D(nn.Module):
    """Object-centric 3D scene encoder (Open-vocabulary, Spatial-attention, Embodied-token, 3D-agent).

    The encoder:

    1. Projects precomputed per-object features to ``hidden_dim``.
    2. Optionally prepends a learnable "embodied" anchor token.
    3. Adds object-type and location embeddings.
    4. Refines the tokens with a stack of transformer encoder layers. These are
       spatial-attention layers when ``use_spatial_attn`` is set.

    Args:
        use_spatial_attn: use ``TransformerSpatialEncoderLayer``, whose attention is
            modulated by pairwise object locations, instead of ``TransformerEncoderLayer``.
        use_embodied_token: prepend an anchor token built from ``anchor_locs`` and
            ``anchor_orientation``.
        hidden_dim: token width.
        fourier_size: input width of the orientation encoder. It must equal the width
            ``generate_fourier_features`` produces for ``anchor_orientation``, which is
            ``C + 2 * C * 10`` with default settings, so the default 84 implies C == 4.
        spatial_encoder: encoder config dict. ``None`` uses a copy of
            ``_DEFAULT_SPATIAL_ENCODER``.
        obj_feat_dim: width of the incoming ``obj_feats``. Defaults to 768, the value
            that was previously hard-coded.
    """

    # Default spatial-encoder config; copied per instance, never mutated.
    _DEFAULT_SPATIAL_ENCODER = {
        "num_attention_heads": 8,
        "dim_feedforward": 2048,
        "dropout": 0.1,
        "activation": "gelu",
        "spatial_dim": 5,
        "spatial_multihead": True,
        "spatial_attn_fusion": "cond",
        "num_layers": 3,
        "pairwise_rel_type": "center",
        "spatial_dist_norm": True,
        "obj_loc_encoding": "same_all",
        "dim_loc": 6,
    }

    def __init__(self, use_spatial_attn=True, use_embodied_token=False, hidden_dim=256, fourier_size=84,
                 spatial_encoder=None, obj_feat_dim=768):
        super().__init__()
        # A dict default argument would be shared by every call; build a fresh copy instead.
        if spatial_encoder is None:
            spatial_encoder = dict(self._DEFAULT_SPATIAL_ENCODER)

        self.use_spatial_attn = use_spatial_attn  # spatial attention
        self.use_embodied_token = use_embodied_token  # embodied token

        # No point-cloud backbone is built here: forward() expects precomputed
        # per-object features in data_dict['obj_feats'].
        self.obj_proj = nn.Linear(obj_feat_dim, hidden_dim)

        # Embodied (anchor) token parameters.
        if self.use_embodied_token:
            self.anchor_feat = nn.Parameter(torch.zeros(1, 1, hidden_dim))
            self.anchor_size = nn.Parameter(torch.ones(1, 1, 3))
            self.orient_encoder = nn.Linear(fourier_size, hidden_dim)
        # Type 0 = object, type 1 = anchor (see forward).
        self.obj_type_embed = nn.Embedding(2, hidden_dim)

        # Spatial encoder stack.
        if self.use_spatial_attn:
            spatial_encoder_layer = TransformerSpatialEncoderLayer(
                d_model=hidden_dim,
                nhead=spatial_encoder['num_attention_heads'],
                dim_feedforward=spatial_encoder['dim_feedforward'],
                dropout=spatial_encoder['dropout'],
                activation=spatial_encoder['activation'],
                spatial_dim=spatial_encoder['spatial_dim'],
                spatial_multihead=spatial_encoder['spatial_multihead'],
                spatial_attn_fusion=spatial_encoder['spatial_attn_fusion'],
            )
        else:
            spatial_encoder_layer = TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=spatial_encoder['num_attention_heads'],
                dim_feedforward=spatial_encoder['dim_feedforward'],
                dropout=spatial_encoder['dropout'],
                activation=spatial_encoder['activation'],
            )
        self.spatial_encoder = layer_repeat(
            spatial_encoder_layer,
            spatial_encoder['num_layers'],
        )
        self.pairwise_rel_type = spatial_encoder['pairwise_rel_type']
        self.spatial_dist_norm = spatial_encoder['spatial_dist_norm']
        self.spatial_dim = spatial_encoder['spatial_dim']
        self.obj_loc_encoding = spatial_encoder['obj_loc_encoding']

        # Location encoding:
        #   'same_0'   - one loc layer, added before the first encoder layer only;
        #   'same_all' - one loc layer, added before every encoder layer;
        #   'diff_all' - a separate loc layer per encoder layer.
        if self.obj_loc_encoding in ['same_0', 'same_all']:
            num_loc_layers = 1
        elif self.obj_loc_encoding == 'diff_all':
            num_loc_layers = spatial_encoder['num_layers']
        else:
            raise ValueError('unsupported obj_loc_encoding %s' % self.obj_loc_encoding)
        loc_layer = nn.Sequential(
            nn.Linear(spatial_encoder['dim_loc'], hidden_dim),
            nn.LayerNorm(hidden_dim),
        )
        self.loc_layers = layer_repeat(loc_layer, num_loc_layers)

        # Only the spatial encoder and the loc layers get BERT-style initialization.
        self.spatial_encoder.apply(_init_weights_bert)
        self.loc_layers.apply(_init_weights_bert)
        if self.use_embodied_token:
            nn.init.normal_(self.anchor_feat, std=0.02)

    def device(self):
        """Return the device of the module's first parameter."""
        return next(self.parameters()).device

    def forward(self, data_dict):
        """Encode the objects in ``data_dict`` into tokens.

        Reads from ``data_dict``:
            obj_feats: (B, N, obj_feat_dim) precomputed per-object features.
            obj_masks: (B, N) validity mask; True / nonzero = valid object.
            obj_locs: (B, N, dim_loc) object boxes. Channels [:3] and [3:] are passed
                to ``calc_pairwise_locs`` as its first and second arguments
                (described upstream as xyz + whd).
            anchor_locs: (B, 3); only read when ``use_embodied_token``.
            anchor_orientation: (B, C); only read when ``use_embodied_token``.

        Writes into ``data_dict`` (which is also returned):
            obj_tokens: (B, N', hidden_dim) encoded tokens. N' = N + 1 with the
                embodied token (prepended at index 0), otherwise N' = N.
            obj_masks: (B, N') bool mask, True = valid; the anchor is always valid.
        """
        obj_feats = self.obj_proj(data_dict['obj_feats'])
        # The encoder layers take a key-padding mask (True = ignore), the inverse of the
        # validity mask. `.bool()` makes `~` a logical not even for integer masks.
        obj_masks = ~data_dict['obj_masks'].bool()
        B, N = obj_feats.shape[:2]
        device = obj_feats.device

        obj_type_ids = torch.zeros((B, N), dtype=torch.long, device=device)
        obj_type_embeds = self.obj_type_embed(obj_type_ids)

        if self.use_embodied_token:
            # Anchor feature: learnable embedding + Fourier-encoded orientation.
            anchor_orient = data_dict['anchor_orientation'].unsqueeze(1)
            anchor_orient_feat = self.orient_encoder(generate_fourier_features(anchor_orient))
            anchor_feat = self.anchor_feat + anchor_orient_feat
            anchor_mask = torch.zeros((B, 1), dtype=torch.bool, device=device)
            # Anchor box: loc (3) from the input + learnable size (3).
            anchor_loc = torch.cat(
                [data_dict['anchor_locs'].unsqueeze(1), self.anchor_size.expand(B, -1, -1).to(device)], dim=-1
            )
            anchor_type_id = torch.ones((B, 1), dtype=torch.long, device=device)
            anchor_type_embed = self.obj_type_embed(anchor_type_id)

            # Prepend the anchor to the object sequence.
            all_obj_feats = torch.cat([anchor_feat, obj_feats], dim=1)
            all_obj_masks = torch.cat((anchor_mask, obj_masks), dim=1)
            all_obj_locs = torch.cat([anchor_loc, data_dict['obj_locs']], dim=1)
            all_obj_type_embeds = torch.cat((anchor_type_embed, obj_type_embeds), dim=1)
        else:
            all_obj_feats = obj_feats
            all_obj_masks = obj_masks
            all_obj_locs = data_dict['obj_locs']
            all_obj_type_embeds = obj_type_embeds

        all_obj_feats = all_obj_feats + all_obj_type_embeds

        if self.use_spatial_attn:
            pairwise_locs = calc_pairwise_locs(
                all_obj_locs[:, :, :3],
                all_obj_locs[:, :, 3:],
                pairwise_rel_type=self.pairwise_rel_type,
                spatial_dist_norm=self.spatial_dist_norm,
                spatial_dim=self.spatial_dim,
            )

        # With a single shared loc layer, its output does not depend on the layer
        # index, so compute it once.
        shared_query_pos = None
        if self.obj_loc_encoding != 'diff_all':
            shared_query_pos = self.loc_layers[0](all_obj_locs)

        for i, pc_layer in enumerate(self.spatial_encoder):
            if self.obj_loc_encoding == 'diff_all':
                query_pos = self.loc_layers[i](all_obj_locs)
            else:
                query_pos = shared_query_pos
            # 'same_0' injects the location embedding before the first layer only.
            if not (self.obj_loc_encoding == 'same_0' and i > 0):
                all_obj_feats = all_obj_feats + query_pos

            if self.use_spatial_attn:
                all_obj_feats, _ = pc_layer(
                    all_obj_feats, pairwise_locs,
                    tgt_key_padding_mask=all_obj_masks
                )
            else:
                all_obj_feats, _ = pc_layer(
                    all_obj_feats,
                    tgt_key_padding_mask=all_obj_masks
                )

        data_dict['obj_tokens'] = all_obj_feats
        data_dict['obj_masks'] = ~all_obj_masks  # back to True = valid
        return data_dict