| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """GeoMatch model definition.""" |
|
|
| from models.gnn import GCN |
| from models.mlp import MLP |
| import torch |
| from torch import nn |
|
|
|
|
class GeoMatchARModule(nn.Module):
  """Autoregressive matching head for a single robot keypoint.

  The head built with ``n_kp = k`` scores every object point using three
  per-point inputs:
    * the projected object embedding,
    * the projected robot embedding at row ``k``,
    * the offsets from the point to the first ``k`` entries of ``xyz_prev``.
  """

  def __init__(self, config, n_kp) -> None:
    """Builds the head.

    Args:
      config: Model config. Stored as ``self.config``; ``forward`` does not
        read it (the number of object points is taken from the inputs).
      n_kp: Row of ``robot_proj_embed`` this head uses, and the number of
        leading ``xyz_prev`` entries it conditions on.
    """
    super().__init__()

    self.config = config
    self.n_kp = n_kp
    # Input width: 128 embedding features (two 64-d projections, see
    # GeoMatch.obj_proj / GeoMatch.robot_proj) plus a 3-d offset per
    # conditioning keypoint. The remaining MLP arguments are presumably
    # (out_dim=1, num_layers=3, hidden=256) -- confirm against models.mlp.
    self.final_fc = MLP(128 + 3 * self.n_kp, 1, 3, 256)

  def forward(self, obj_proj_embed, obj_pc, robot_proj_embed, xyz_prev):
    """Scores each object point for this head's keypoint.

    Args:
      obj_proj_embed: (B, N, D_o) projected object point embeddings.
      obj_pc: Object points whose last axis broadcasts against the 3-d
        keypoint positions -- presumably (B, N, 3) xyz; TODO confirm.
      robot_proj_embed: (B, K, D_r) projected robot embeddings; only row
        ``self.n_kp`` is used.
      xyz_prev: (B, M, 3) previously placed keypoint positions with
        M >= n_kp; only the first ``self.n_kp`` rows are used.

    Returns:
      ``final_fc`` applied to the concatenated (B, N, D_o + D_r + 3 * n_kp)
      features -- presumably per-point logits of shape (B, N, 1).
    """
    num_points = obj_proj_embed.shape[1]

    # Broadcast this head's robot embedding to every object point:
    # (B, D_r) -> (B, 1, D_r) -> (B, N, D_r). N is read from the input rather
    # than config.obj_pc_n so point clouds of any size are accepted.
    robot_i_embed = robot_proj_embed[:, self.n_kp, None, :].expand(
        -1, num_points, -1
    )

    # Offsets from every object point to each conditioning keypoint in one
    # broadcast: (B, N, 3, 1) - (B, 1, 3, n_kp) -> (B, N, 3, n_kp).
    # Flattening the last two axes gives the coordinate-major layout
    # [x_0..x_{n_kp-1}, y_0.., z_0..], i.e. the same layout as stacking the
    # per-keypoint offsets on a new last axis.
    prev_xyz = xyz_prev[:, : self.n_kp, :].transpose(1, 2)[:, None]
    diff_xyz = (obj_pc[..., None] - prev_xyz).flatten(start_dim=2)

    features = torch.cat((obj_proj_embed, robot_i_embed, diff_xyz), dim=-1)
    return self.final_fc(features)

  def calc_loss(self, pred, label):
    """Positive-weighted BCE-with-logits between head logits and targets.

    Args:
      pred: Logits from ``forward``.
      label: Targets for BCE (presumably 0/1) with as many elements as
        ``pred``.

    Returns:
      Scalar loss tensor; positive targets are weighted by 1000.
    """
    flat_pred = pred.reshape(-1, 1)
    flat_label = label.reshape(-1, 1)
    # Create pos_weight on the logits' device instead of hard-coding .cuda(),
    # so the loss also works on CPU and on non-default GPUs.
    pos_weight = torch.tensor([1000.0], device=flat_pred.device)
    return nn.BCEWithLogitsLoss(pos_weight=pos_weight)(flat_pred, flat_label)
|
|
|
|
class GeoMatch(nn.Module):
  """GeoMatch model.

  The model has three stages:
    1. Encode the object and robot point clouds with GCNs.
    2. Predict a contact map as dot products between normalised object-point
       embeddings and robot-keypoint embeddings.
    3. Run one autoregressive head (GeoMatchARModule) per keypoint after the
       first to score object points as keypoint matches.
  """

  def __init__(self, config) -> None:
    """Builds the encoders, projections and autoregressive heads.

    Args:
      config: Model config. Reads keypoint_n, robot_weighting,
        matchnet_weighting, dist_loss_weight, match_loss_weight,
        obj_in_feats, robot_in_feats, hidden_n, obj_out_feats,
        robot_out_feats and num_hidden, and is passed to every head.
    """
    super().__init__()

    self.config = config
    self.n_kp = config.keypoint_n
    self.robot_weighting = config.robot_weighting
    self.match_weighting = config.matchnet_weighting
    self.dist_loss_weight = config.dist_loss_weight
    self.match_loss_weight = config.match_loss_weight

    self.obj_encoder = GCN(
        nfeat=config.obj_in_feats,
        nhid=config.hidden_n,
        nout=config.obj_out_feats,
        dropout=0.5,
        num_hidden=config.num_hidden,
    )

    self.robot_encoder = GCN(
        nfeat=config.robot_in_feats,
        nhid=config.hidden_n,
        nout=config.robot_out_feats,
        dropout=0.5,
        num_hidden=config.num_hidden,
    )

    self.obj_proj = nn.Linear(self.config.obj_out_feats, 64, bias=False)
    self.robot_proj = nn.Linear(self.config.robot_out_feats, 64, bias=False)

    # Create one head per keypoint k in 1..n_kp-1 (keypoint 0 has no head).
    # This keeps the number of heads equal to the n_kp - 1 terms averaged in
    # calc_loss. Heads are registered in increasing k under the attribute
    # names kp_ar_model_<k>; for keypoint_n == 6 this registers
    # kp_ar_model_1..kp_ar_model_5.
    for k in range(1, self.n_kp):
      setattr(self, f"kp_ar_model_{k}", GeoMatchARModule(config, k))

  def _ar_heads(self):
    """Returns the heads kp_ar_model_1..kp_ar_model_{n_kp-1}, in order."""
    return [getattr(self, f"kp_ar_model_{k}") for k in range(1, self.n_kp)]

  def encode_embed(self, encoder, feature, adj_mat, normalize_emb=True):
    """Runs ``encoder(feature, adj_mat)`` and optionally L2-normalises it.

    Args:
      encoder: Callable taking (features, adjacency).
      feature: Node features passed to ``encoder``.
      adj_mat: Adjacency passed to ``encoder``.
      normalize_emb: If True, divide each embedding (last axis) by its L2
        norm plus 1e-6, which guards against division by zero.

    Returns:
      The (optionally normalised) encoder output.
    """
    x = encoder(feature, adj_mat)
    if normalize_emb:
      # Division already allocates a new tensor, so no clone() is needed.
      x = x / (torch.norm(x, dim=-1, keepdim=True) + 1e-6)
    return x

  def forward(
      self, obj_pc, robot_pc, robot_key_point_idx, obj_adj, robot_adj, xyz_prev
  ):
    """Predicts the contact map and the per-keypoint matching logits.

    Args:
      obj_pc: Object point features, fed to the object encoder and to every
        head.
      robot_pc: Robot point features, fed to the robot encoder.
      robot_key_point_idx: (B, K) indices of the robot keypoints among the
        robot points.
      obj_adj: Object adjacency passed to the object encoder.
      robot_adj: Robot adjacency passed to the robot encoder.
      xyz_prev: Previously placed keypoint positions, forwarded to every
        head.

    Returns:
      A tuple (contact_map_pred, output):
        contact_map_pred: (B, N, K, 1) dot products between normalised
          object embeddings and keypoint embeddings.
        output: The heads' outputs concatenated on the last axis, with a
          trailing axis appended -- (B, N, n_kp - 1, 1) assuming each head
          returns (B, N, 1).
    """
    obj_embed = self.encode_embed(self.obj_encoder, obj_pc, obj_adj)
    robot_embed = self.encode_embed(self.robot_encoder, robot_pc, robot_adj)

    # Gather the K keypoint embeddings: (B, R, D) -> (B, K, D).
    robot_feat_size = robot_embed.shape[2]
    keypoint_feat = torch.gather(
        robot_embed,
        1,
        robot_key_point_idx[..., None].long().repeat(1, 1, robot_feat_size),
    )
    contact_map_pred = torch.matmul(obj_embed, keypoint_feat.transpose(2, 1))[
        ..., None
    ]

    obj_proj_embed = self.obj_proj(obj_embed)
    # Project the keypoint embeddings, not the whole robot cloud. Head k
    # reads row k of this tensor as its keypoint embedding, so the rows must
    # be keypoints (robot_key_point_idx order) rather than the first robot
    # points.
    robot_proj_embed = self.robot_proj(keypoint_feat)

    head_outputs = [
        head(obj_proj_embed, obj_pc, robot_proj_embed, xyz_prev)
        for head in self._ar_heads()
    ]
    output = torch.cat(head_outputs, dim=-1)[..., None]

    return contact_map_pred, output

  def calc_loss(self, gt_contact_map, contact_map_pred, pred, label):
    """Weighted sum of the contact-map loss and the keypoint-matching loss.

    Args:
      gt_contact_map: Contact-map targets (presumably 0/1) with as many
        elements as ``contact_map_pred``.
      contact_map_pred: Contact-map logits from ``forward``.
      pred: Matching logits from ``forward``; axis 2 indexes the heads.
      label: Matching targets (presumably 0/1); axis 2 indexes the heads.

    Returns:
      ``dist_loss_weight * l_dist + match_loss_weight * l_match``, where:
        l_dist: BCE-with-logits over the contact map, with positives
          weighted by ``robot_weighting``.
        l_match: mean over the n_kp - 1 heads of each head's
          BCE-with-logits, with positives weighted by ``match_weighting``.
    """
    # pos_weight tensors are created on the inputs' devices instead of
    # hard-coding .cuda(), so the loss works on CPU and any GPU.
    dist_pos_weight = torch.tensor(
        [self.robot_weighting],
        dtype=torch.float32,
        device=contact_map_pred.device,
    )
    l_dist = nn.BCEWithLogitsLoss(pos_weight=dist_pos_weight)(
        contact_map_pred.reshape(-1, 1), gt_contact_map.reshape(-1, 1)
    )

    match_pos_weight = torch.tensor(
        [self.match_weighting], dtype=torch.float32, device=pred.device
    )
    match_criterion = nn.BCEWithLogitsLoss(pos_weight=match_pos_weight)
    per_head_losses = [
        match_criterion(
            pred[:, :, i].reshape(-1, 1), label[:, :, i].reshape(-1, 1)
        )
        for i in range(self.n_kp - 1)
    ]
    l_match = torch.stack(per_head_losses).mean()

    return self.dist_loss_weight * l_dist + self.match_loss_weight * l_match
|
|