|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
class ConvNet(nn.Module):
    """LeNet-style CNN whose architecture matches the M-ADA method (CVPR 2020).

    Two 5x5 conv layers (each followed by ReLU and 2x2 max-pooling), two
    1024-unit fully connected layers, and a 10-way linear classifier head.
    ``fc1`` expects ``128 * 5 * 5`` flattened inputs, which is what a 32x32
    image produces (32 -> 28 -> 14 -> 10 -> 5).

    Args:
        imdim: Number of input image channels.
    """

    def __init__(self, imdim=3):
        super(ConvNet, self).__init__()

        self.conv1 = nn.Conv2d(imdim, 64, kernel_size=5, stride=1, padding=0)
        self.mp = nn.MaxPool2d(2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=0)
        self.relu2 = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(128 * 5 * 5, 1024)
        self.relu3 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(1024, 1024)
        # Deliberately NOT in-place: forward(mode='fc') returns the fc2 output
        # *before* this ReLU, and an in-place ReLU would overwrite that tensor.
        # In mode='c' an in-place ReLU would also mutate the caller's input
        # (and fail outright on a leaf tensor that requires grad).
        self.relu4 = nn.ReLU(inplace=False)

        self.cls_head_src = nn.Linear(1024, 10)

    def forward(self, x, mode='fc'):
        """Run the network end-to-end, or only the classifier part.

        Args:
            x: For ``mode='fc'``, an image batch of shape
                ``(N, imdim, H, W)`` (32x32 matches ``fc1``). For
                ``mode='c'``, pre-activation features of shape ``(N, 1024)``,
                e.g. the second value returned by ``mode='fc'``.
            mode: ``'fc'`` runs the whole network; ``'c'`` applies only the
                final ReLU and the classifier head.

        Returns:
            ``mode='c'``: logits of shape ``(N, 10)``.
            ``mode='fc'``: tuple ``(logits, features)`` where ``features`` is
            the fc2 output before the final ReLU, shape ``(N, 1024)``.

        Raises:
            ValueError: If ``mode`` is neither ``'fc'`` nor ``'c'``.
        """
        if mode == 'c':
            return self.cls_head_src(self.relu4(x))

        if mode == 'fc':
            out1 = self.mp(self.relu1(self.conv1(x)))
            out2 = self.mp(self.relu2(self.conv2(out1)))
            # flatten(1) keeps the batch dim and, unlike view(), also copes
            # with non-contiguous inputs and empty batches.
            out2 = out2.flatten(1)
            out3 = self.relu3(self.fc1(out2))
            out4_worelu = self.fc2(out3)
            out4 = self.relu4(out4_worelu)  # new tensor; out4_worelu intact
            p = self.cls_head_src(out4)
            return p, out4_worelu

        raise ValueError(f"unknown mode {mode!r}; expected 'fc' or 'c'")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConvNetVis(nn.Module):
    """CNN with a 2-d feature bottleneck, convenient for visualization.

    Two 5x5 conv layers (each followed by ReLU and 2x2 max-pooling), a
    1024-unit fully connected layer, and ``fc2`` projecting to 2 dimensions
    so that the (ReLU-ed) features can be plotted directly. ``fc1`` expects
    ``128 * 5 * 5`` flattened inputs, which a 32x32 image produces.

    Heads on top of the 2-d features:
        * ``cls_head_src``: 10-way classifier used by ``forward``.
        * ``cls_head_tgt``: 10-way classifier; not used by ``forward``.
        * ``pro_head``: 128-d projection head, L2-normalized in
          ``mode='train'``.

    Args:
        imdim: Number of input image channels.
    """

    _MODES = ('test', 'train', 'p_f')

    def __init__(self, imdim=3):
        super(ConvNetVis, self).__init__()

        self.conv1 = nn.Conv2d(imdim, 64, kernel_size=5, stride=1, padding=0)
        self.mp = nn.MaxPool2d(2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=0)
        self.relu2 = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(128 * 5 * 5, 1024)
        self.relu3 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(1024, 2)
        self.relu4 = nn.ReLU(inplace=True)

        self.cls_head_src = nn.Linear(2, 10)
        # Not referenced by forward(); kept so existing state_dicts still load.
        self.cls_head_tgt = nn.Linear(2, 10)
        self.pro_head = nn.Linear(2, 128)

    def forward(self, x, mode='test'):
        """Compute logits (and optionally projections or raw 2-d features).

        Args:
            x: Image batch of shape ``(N, imdim, H, W)``; 32x32 matches
                ``fc1``.
            mode: One of
                ``'test'``  -> returns logits ``(N, 10)``;
                ``'train'`` -> returns ``(logits, z)`` where ``z`` is the
                L2-normalized ``pro_head`` output ``(N, 128)``;
                ``'p_f'``   -> returns ``(logits, features)`` with the
                ReLU-ed 2-d features ``(N, 2)``.

        Raises:
            ValueError: If ``mode`` is not one of the values above.
        """
        # Validate before running the trunk so a typo fails fast instead of
        # silently returning None.
        if mode not in self._MODES:
            raise ValueError(
                f"unknown mode {mode!r}; expected one of {self._MODES}")

        out1 = self.mp(self.relu1(self.conv1(x)))
        out2 = self.mp(self.relu2(self.conv2(out1)))
        out2 = out2.flatten(1)
        out3 = self.relu3(self.fc1(out2))
        out4 = self.relu4(self.fc2(out3))

        p = self.cls_head_src(out4)

        if mode == 'test':
            return p
        if mode == 'train':
            z = F.normalize(self.pro_head(out4))  # L2 over dim=1
            return p, z
        return p, out4  # mode == 'p_f'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|