# Uploaded by YuqianFu using huggingface_hub (commit 197d4ca, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvNet(nn.Module):
    """Small CNN classifier; the architecture matches the CVPR 2020 M-ADA method.

    Two 5x5 convolutions without padding, each followed by ReLU and 2x2 max
    pooling, feed two fully connected layers and a 10-way classification
    head. ``fc1`` expects ``128 * 5 * 5`` flattened features, which is what
    the convolutional stack produces for 32x32 inputs.

    Args:
        imdim: Number of channels of the input images.
    """

    def __init__(self, imdim=3):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(imdim, 64, kernel_size=5, stride=1, padding=0)
        self.mp = nn.MaxPool2d(2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=0)
        self.relu2 = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(128 * 5 * 5, 1024)
        self.relu3 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(1024, 1024)
        # Deliberately NOT in-place: the input of this activation is returned
        # to the caller as the pre-ReLU feature (mode 'fc') or is owned by the
        # caller (mode 'c'). An in-place ReLU would overwrite it.
        self.relu4 = nn.ReLU(inplace=False)
        self.cls_head_src = nn.Linear(1024, 10)

    def forward(self, x, mode='fc'):
        """Run the network, either end to end or from the fc2 features on.

        Args:
            x: For ``mode='fc'`` a batch of images; for ``mode='c'`` a batch
                of 1024-d pre-ReLU features (as returned by ``mode='fc'``).
            mode: ``'fc'`` runs the full network, ``'c'`` only applies the
                last ReLU and the classification head to ``x``.

        Returns:
            ``mode='c'``: the class logits.
            ``mode='fc'``: a tuple ``(logits, fc2_output_before_relu)``.

        Raises:
            ValueError: If ``mode`` is neither ``'fc'`` nor ``'c'``.
        """
        if mode == 'c':
            return self.cls_head_src(self.relu4(x))

        if mode == 'fc':
            feat = self.mp(self.relu1(self.conv1(x)))
            feat = self.mp(self.relu2(self.conv2(feat)))
            feat = feat.flatten(1)  # (batch, C*H*W)
            feat = self.relu3(self.fc1(feat))
            out4_worelu = self.fc2(feat)
            logits = self.cls_head_src(self.relu4(out4_worelu))
            return logits, out4_worelu

        raise ValueError(
            "Unknown mode {!r}; expected 'fc' or 'c'.".format(mode))
class ConvNetVis(nn.Module):
    """Variant of the CNN for visualization: the feature extractor emits 2-d features.

    Same convolutional stack as the 1024-d network, but ``fc2`` projects to
    only 2 dimensions so the features can be plotted directly. ``fc1``
    expects ``128 * 5 * 5`` flattened features, which is what the
    convolutional stack produces for 32x32 inputs.

    Args:
        imdim: Number of channels of the input images.
    """

    _MODES = ('test', 'train', 'p_f')

    def __init__(self, imdim=3):
        super(ConvNetVis, self).__init__()
        self.conv1 = nn.Conv2d(imdim, 64, kernel_size=5, stride=1, padding=0)
        self.mp = nn.MaxPool2d(2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=0)
        self.relu2 = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(128 * 5 * 5, 1024)
        self.relu3 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(1024, 2)
        self.relu4 = nn.ReLU(inplace=True)
        self.cls_head_src = nn.Linear(2, 10)
        # Not used by forward(); kept so the module's parameter set (and
        # therefore its state_dict) stays unchanged.
        self.cls_head_tgt = nn.Linear(2, 10)
        self.pro_head = nn.Linear(2, 128)

    def forward(self, x, mode='test'):
        """Classify a batch of images and optionally return extra features.

        Args:
            x: Batch of input images.
            mode: ``'test'``, ``'train'`` or ``'p_f'`` (see Returns).

        Returns:
            ``'test'``: the class logits.
            ``'train'``: ``(logits, z)`` where ``z`` is the L2-normalized
                output of the projection head.
            ``'p_f'``: ``(logits, feat)`` where ``feat`` is the 2-d
                post-ReLU feature.

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        # Fail early instead of running the whole network and returning None.
        if mode not in self._MODES:
            raise ValueError(
                "Unknown mode {!r}; expected one of {}.".format(
                    mode, self._MODES))

        feat = self.mp(self.relu1(self.conv1(x)))
        feat = self.mp(self.relu2(self.conv2(feat)))
        feat = feat.flatten(1)  # (batch, C*H*W)
        feat = self.relu3(self.fc1(feat))
        feat = self.relu4(self.fc2(feat))

        # Every mode uses the same source classification head.
        logits = self.cls_head_src(feat)
        if mode == 'test':
            return logits
        if mode == 'train':
            return logits, F.normalize(self.pro_head(feat))
        return logits, feat  # mode == 'p_f'