| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from einops.einops import rearrange, repeat |
| |
|
| |
|
class FinePreprocess(nn.Module):
    """Extract W x W windows from two fine feature maps at selected positions.

    For each entry of ``data['b_ids']`` / ``data['i_ids']`` / ``data['j_ids']``
    the module picks one unfolded window from ``feat_f0`` and one from
    ``feat_f1``.  When ``config['fine_concat_coarse_feat']`` is truthy, the
    coarse features at the same indices are projected to the fine width,
    broadcast over the window and fused with the fine windows through a
    linear layer.

    Config keys read: ``'fine_concat_coarse_feat'``, ``'fine_window_size'``,
    ``config['coarse']['d_model']`` and ``config['fine']['d_model']``.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.cat_c_feat = config['fine_concat_coarse_feat']
        self.W = self.config['fine_window_size']

        d_model_c = self.config['coarse']['d_model']
        d_model_f = self.config['fine']['d_model']
        self.d_model_f = d_model_f
        if self.cat_c_feat:
            # Project coarse features down to the fine width, then fuse the
            # concatenated [fine, projected-coarse] vector back to d_model_f.
            self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
            self.merge_feat = nn.Linear(2 * d_model_f, d_model_f, bias=True)

        self._reset_parameters()

    def _reset_parameters(self):
        """Kaiming-normal initialise every parameter with more than one dim.

        One-dimensional parameters (e.g. the linear biases) keep their
        default initialisation.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")

    def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
        """Return the selected fine windows for both inputs.

        Args:
            feat_f0, feat_f1: fine feature maps passed to ``F.unfold``
                (presumably ``(N, C, H, W)`` -- confirm against caller).
            feat_c0, feat_c1: coarse features indexed as
                ``feat_c[b_ids, ids]``; only used when coarse concatenation
                is enabled (presumably ``(N, L, C)`` -- confirm).
            data: dict providing ``'hw0_f'``, ``'hw0_c'``, ``'b_ids'``,
                ``'i_ids'`` and ``'j_ids'``.  It is updated in place with
                ``'W'`` (the window size).

        Returns:
            Tuple ``(feat_f0_unfold, feat_f1_unfold)``, each of shape
            ``(M, W**2, C)`` where ``M`` is the number of entries in
            ``data['b_ids']``.  When ``M == 0`` both are empty tensors of
            shape ``(0, W**2, d_model_f)``.
        """
        W = self.W
        # Unfold stride: ratio between the first 'hw0_f' and 'hw0_c' entries.
        stride = data['hw0_f'][0] // data['hw0_c'][0]

        data.update({'W': W})
        if data['b_ids'].shape[0] == 0:
            # Nothing selected: return empty placeholders.  They follow the
            # dtype and device of the fine features so the empty path yields
            # the same tensor type as the regular path (a bare torch.empty
            # would always produce the global default dtype).
            empty_shape = (0, W ** 2, self.d_model_f)
            feat0 = torch.empty(empty_shape, dtype=feat_f0.dtype, device=feat_f0.device)
            feat1 = torch.empty(empty_shape, dtype=feat_f0.dtype, device=feat_f0.device)
            return feat0, feat1

        # 1. Unfold every W x W window, then move the window axis in front of
        #    the channel axis: (n, c*ww, l) -> (n, l, ww, c).
        feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W // 2)
        feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W ** 2)
        feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W // 2)
        feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W ** 2)

        # 2. Keep only the windows addressed by the (batch, location) indices.
        feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']]
        feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]

        # 3. Optionally fuse in the coarse features at the same indices.  Both
        #    inputs are stacked along dim 0 so each linear layer runs once,
        #    and the result is split back into two halves afterwards.
        if self.cat_c_feat:
            feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
                                                   feat_c1[data['b_ids'], data['j_ids']]], 0))
            feat_cf_win = self.merge_feat(torch.cat([
                torch.cat([feat_f0_unfold, feat_f1_unfold], 0),
                repeat(feat_c_win, 'n c -> n ww c', ww=W ** 2),
            ], -1))
            feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)

        return feat_f0_unfold, feat_f1_unfold
| |
|