|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
class ProtoNet(nn.Module):
    """Few-shot classifier that scores queries against class prototypes.

    A prototype is the mean backbone embedding of one class's support
    samples. Each query is scored by its cosine similarity to every
    prototype, shifted by a learnable bias and multiplied by a learnable
    scale.
    """

    def __init__(self, backbone: nn.Module) -> None:
        """Store the feature extractor and create the learnable logit terms.

        Args:
            backbone: Module mapping an image batch ``[N, C, H, W]`` to
                per-image features (flattened to one vector per image here).
        """
        super().__init__()

        # Additive offset and multiplicative scale applied to the cosine
        # scores; initialised to 0 and 10 respectively.
        self.bias = nn.Parameter(torch.zeros(1), requires_grad=True)
        self.scale_cls = nn.Parameter(torch.full((1,), 10.0), requires_grad=True)

        self.backbone = backbone

    def cos_classifier(self, w: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        """Return scaled cosine-similarity logits between features and weights.

        Args:
            w: Class weights (prototypes), shape ``[B, nC, d]``.
            f: Query features, shape ``[B, M, d]``.

        Returns:
            Tensor of shape ``[B, M, nC]`` holding
            ``scale_cls * (cosine(f, w) + bias)``.
        """
        # L2-normalise along the feature axis so the batched matmul below
        # yields cosine similarities.
        f = F.normalize(f, p=2, dim=-1, eps=1e-12)
        w = F.normalize(w, p=2, dim=-1, eps=1e-12)

        cls_scores = f @ w.transpose(1, 2)
        return self.scale_cls * (cls_scores + self.bias)

    def _embed(self, images: torch.Tensor) -> torch.Tensor:
        """Run the backbone on ``[B, N, C, H, W]`` images, returning ``[B, N, d]``."""
        batch, count = images.shape[:2]
        # ``flatten`` copies only when the input is non-contiguous, so sliced
        # or permuted episodes work without an explicit ``.contiguous()``.
        flat = images.flatten(0, 1)
        # NOTE: per the original author's note, for the multi-block ViT
        # backbone this is equivalent to chaining forward_block1..4 and
        # forward_rest.
        feats = self.backbone(flat)
        return feats.reshape(batch, count, -1)

    def forward(
        self,
        supp_x: torch.Tensor,
        supp_y: torch.Tensor,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Classify query images against prototypes built from the support set.

        Args:
            supp_x: Support images, shape ``[B, nSupp, C, H, W]``.
            supp_y: Integer support labels, shape ``[B, nSupp]``. The number
                of classes is taken as ``supp_y.max() + 1``.
            x: Query images, shape ``[B, nQry, C, H, W]``.

        Returns:
            Logits of shape ``[B, nQry, nC]``.
        """
        num_classes = int(supp_y.max()) + 1

        supp_f = self._embed(supp_x)

        # [B, nC, nSupp] membership matrix: entry (b, c, s) is 1 when support
        # sample s of episode b carries label c.
        membership = F.one_hot(supp_y, num_classes).transpose(1, 2)

        # Sum the support features per class, then divide by the class count
        # to get the mean. The count is clamped to 1 so a class index with no
        # support sample yields a zero prototype (cosine score 0) instead of
        # a 0/0 NaN that would contaminate every logit.
        prototypes = torch.bmm(membership.to(supp_f.dtype), supp_f)
        counts = membership.sum(dim=2, keepdim=True).clamp(min=1)
        prototypes = prototypes / counts

        feat = self._embed(x)

        return self.cos_classifier(prototypes, feat)
|
|
|