|
|
|
|
|
'''
Train the base model (joint training of the classifier, the per-factor
mapping networks, and the effect-to-weight network).
'''
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import itertools |
|
|
from torch import optim |
|
|
from torch.utils.data import DataLoader, RandomSampler |
|
|
from torchvision import models |
|
|
from torchvision.datasets import CIFAR10 |
|
|
from torchvision.utils import make_grid |
|
|
import torchvision.transforms as transforms |
|
|
from tensorboardX import SummaryWriter |
|
|
from torch.cuda.amp import autocast,GradScaler |
|
|
|
|
|
import os |
|
|
import click |
|
|
import time |
|
|
import numpy as np |
|
|
|
|
|
from network import mnist_net_my as mnist_net |
|
|
from network import wideresnet as wideresnet |
|
|
from network import resnet as resnet |
|
|
from network import adaptor_v2 |
|
|
|
|
|
from tools import causalaugment_v3 as causalaugment |
|
|
import data_loader_joint_v3 as data_loader |
|
|
|
|
|
|
|
|
HOME = os.environ['HOME'] |
|
|
|
|
|
@click.command()
@click.option('--gpu', type=str, default='0', help='选择gpu')
@click.option('--data', type=str, default='mnist', help='数据集名称')
@click.option('--ntr', type=int, default=None, help='选择训练集前ntr个样本')
@click.option('--translate', type=float, default=None, help='随机平移数据增强')
@click.option('--autoaug', type=str, default=None, help='AA FastAA RA')
@click.option('--n', type=int, default=3, help='选择多少个factor生成RA')
@click.option('--stride', type=int, default=5, help='if autoaug==CA_multiple, stride is used')
@click.option('--factor_num', type=int, default=16, help='the first n factors')
@click.option('--epochs', type=int, default=100)
@click.option('--nbatch', type=int, default=100, help='每个epoch中batch的数量')
@click.option('--batchsize', type=int, default=128, help='每个batch中样本的数量')
@click.option('--lr', type=float, default=1e-3)
@click.option('--lr_scheduler', type=str, default='none', help='是否选择学习率衰减策略')
@click.option('--svroot', type=str, default='./saved', help='项目文件保存路径')
@click.option('--clsadapt', type=bool, default=True, help='映射后是否用分类损失')
@click.option('--lambda_causal', type=float, default=1, help='the weight of reconstruction during mapping and causal ')
@click.option('--lambda_re', type=float, default=1, help='the weight of reconstruction during mapping and causal ')
@click.option('--randm', type=bool, default=True, help='m取值是否randm')
@click.option('--randn', type=bool, default=False, help='原始特征是否detach')
@click.option('--network', type=str, default='resnet18', help='项目文件保存路径')
def experiment(gpu, data, ntr, translate, autoaug,n,stride, factor_num, epochs, nbatch, batchsize, lr, lr_scheduler, svroot, clsadapt, lambda_causal,lambda_re,randm,randn,network):
    """Jointly train the base classifier (cls_net), one mapping network per
    augmentation factor (AdaptNet), and the effect-to-weight network (E_to_W)
    with mixed precision, then checkpoint the best and last models to svroot.

    The data loader is expected to yield (x_four, y) where x_four holds the
    original image, a RandAugment view (x_RA), one view per factor (x_FA),
    and factor_num*var_num counterfactual views (x_CA).
    """
    # Snapshot the CLI arguments for logging before any locals are added.
    settings = locals().copy()
    print(settings)

    # Pin the visible GPU and prepare output directory, text log, and
    # TensorBoard writer.
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    if not os.path.exists(svroot):
        os.makedirs(svroot)
    log_file = open(svroot+os.sep+'log.log',"w")
    log_file.write(str(settings)+'\n')
    writer = SummaryWriter(svroot)
    # NOTE(review): CA is constructed but never used inside this function
    # (evalute_pacs/evaluate_causal take their own CA) — confirm intent.
    CA = causalaugment.MultiCounterfactualAugment(factor_num,stride)

    # ---- dataset / model / optimizer selection -----------------------------
    if data in ['mnist', 'mnist_t']:
        if data == 'mnist':
            trset = data_loader.load_mnist('train', translate=translate,twox=True, ntr=ntr, factor_num=factor_num,autoaug=autoaug,randm=randm,randn=randn,n=n,stride=stride)
        elif data == 'mnist_t':
            trset = data_loader.load_mnist_t('train', translate=translate, ntr=ntr)
        teset = data_loader.load_mnist('test')
        # RandomSampler with replacement fixes the number of samples per epoch
        # to nbatch*batchsize.
        trloader = DataLoader(trset, batch_size=batchsize, num_workers=0, \
            sampler=RandomSampler(trset, True, nbatch*batchsize))
        teloader = DataLoader(teset, batch_size=batchsize, num_workers=0, shuffle=False)
        cls_net = mnist_net.ConvNet().cuda()
        # One mapping network per augmentation factor; each gets its own
        # parameter group so all are optimized jointly.
        AdaptNet = []
        parameter_list = []
        for i in range(factor_num):
            mapping = adaptor_v2.mapping(1024,512,1024,2).cuda()
            AdaptNet.append(mapping)
            parameter_list.append({'params':mapping.parameters(),'lr':lr})
        if autoaug == 'CA_multiple':
            # Number of counterfactual variants per factor (magnitudes 0..30
            # stepped by stride).
            var_num = len(list(range(0, 31, stride)))
            E_to_W = adaptor_v2.effect_to_weight(10,100,1).cuda()
        else:
            # NOTE(review): var_num is NOT defined on this path but is used
            # unconditionally in the training loop below — likely a latent
            # NameError when autoaug != 'CA_multiple'; confirm.
            E_to_W = adaptor_v2.effect_to_weight(10,100,1).cuda()
        parameter_list.append({'params':cls_net.parameters(),'lr':lr})
        parameter_list.append({'params':E_to_W.parameters(),'lr':lr})

        opt = optim.Adam(parameter_list, lr=lr)
        if lr_scheduler == 'cosine':
            scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, epochs)
        elif lr_scheduler == 'Exp':
            scheduler = optim.lr_scheduler.ExponentialLR(opt, gamma=0.95)
        elif lr_scheduler == 'Step':
            scheduler = optim.lr_scheduler.StepLR(opt, step_size=int(epochs*0.8))

    elif data == 'cifar10':
        trset = data_loader.load_cifar10(split='train',twox=True, factor_num=factor_num,autoaug=autoaug,randm=randm,randn=randn,n=n,stride=stride)
        teset = data_loader.load_cifar10(split='test')
        trloader = DataLoader(trset, batch_size=batchsize, num_workers=4, shuffle=True, drop_last=True)
        teloader = DataLoader(teset, batch_size=batchsize, num_workers=4, shuffle=False)
        cls_net = wideresnet.WideResNet(16, 10, 4).cuda()

        AdaptNet = []
        parameter_list = []
        for i in range(factor_num):
            mapping = adaptor_v2.mapping(256,512,256,4).cuda()
            AdaptNet.append(mapping)
            parameter_list.append({'params':mapping.parameters(),'lr':lr})
        if autoaug == 'CA_multiple':
            var_num = len(list(range(0, 31, stride)))
            E_to_W = adaptor_v2.effect_to_weight(10,100,1).cuda()
        else:
            # NOTE(review): same var_num concern as in the mnist branch.
            E_to_W = adaptor_v2.effect_to_weight(10,100,1).cuda()
        parameter_list.append({'params':cls_net.parameters(),'lr':lr})
        parameter_list.append({'params':E_to_W.parameters(),'lr':lr})

        opt = optim.SGD(parameter_list, lr=lr, momentum=0.9, nesterov=True, weight_decay=5e-4)
        if lr_scheduler == 'cosine':
            scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, epochs)
        elif lr_scheduler == 'Exp':
            scheduler = optim.lr_scheduler.ExponentialLR(opt, gamma=0.95)
        elif lr_scheduler == 'Step':
            scheduler = optim.lr_scheduler.StepLR(opt, step_size=int(epochs*0.8))

    elif data in ['art_painting', 'cartoon', 'photo', 'sketch']:
        # PACS single-source setting: train on one domain, validate on its
        # own 'val' split (cross-domain testing is done elsewhere).
        trset = data_loader.load_pacs(domain=data, split='train', twox=True, factor_num=factor_num,autoaug=autoaug,randm=randm,randn=randn,n=n,stride=stride)
        teset = data_loader.load_pacs(domain=data, split='val')
        trloader = DataLoader(trset, batch_size=batchsize, num_workers=4, shuffle=True, drop_last=True)
        teloader = DataLoader(teset, batch_size=batchsize, num_workers=4, shuffle=False)
        if network == 'resnet18':
            cls_net = resnet.resnet18(classes=7,c_dim=2048).cuda()
            input_dim = 2048

        # Split parameters: backbone trains at 0.01*lr, classifier head at lr.
        classifier_param = list(map(id, cls_net.class_classifier.parameters()))
        backbone_param = filter(lambda p: id(p) not in classifier_param and p.requires_grad, cls_net.parameters())
        AdaptNet = []
        parameter_list = []
        for i in range(factor_num):
            mapping = adaptor_v2.mapping(input_dim,1024,input_dim,4).cuda()
            AdaptNet.append(mapping)
            parameter_list.append({'params':mapping.parameters(),'lr':lr})
        if autoaug == 'CA_multiple':
            var_num = len(list(range(0, 31, stride)))
            E_to_W = adaptor_v2.effect_to_weight(7,70,1).cuda()
        else:
            # NOTE(review): same var_num concern as in the mnist branch.
            E_to_W = adaptor_v2.effect_to_weight(7,70,1).cuda()
        parameter_list.append({'params':backbone_param,'lr':0.01*lr})
        parameter_list.append({'params':cls_net.class_classifier.parameters(),'lr':lr})
        parameter_list.append({'params':E_to_W.parameters(),'lr':lr})

        # Per-group lr values above are authoritative; no global lr passed.
        opt = optim.SGD(parameter_list, momentum=0.9, nesterov=True, weight_decay=5e-4)
        if lr_scheduler == 'cosine':
            scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, epochs)
        elif lr_scheduler == 'Exp':
            scheduler = optim.lr_scheduler.ExponentialLR(opt, gamma=0.99999)
        elif lr_scheduler == 'Step':
            scheduler = optim.lr_scheduler.StepLR(opt, step_size=15)

    elif 'synthia' in data:
        # Segmentation setting; note train and "test" loaders share trset.
        branch = data.split('_')[1]
        trset = data_loader.load_synthia(branch)
        trloader = DataLoader(trset, batch_size=batchsize, num_workers=8, shuffle=True)
        teloader = DataLoader(trset, batch_size=batchsize, num_workers=8, shuffle=True)
        imsize = [192, 320]
        nclass = 14

        # NOTE(review): `fcn` is not imported in this file, and this branch
        # builds `cls_opt` while the training loop below steps `opt` —
        # this path looks broken/unused; confirm before relying on it.
        cls_net = fcn.FCN_resnet50(nclass=nclass).cuda()
        cls_opt = optim.Adam(cls_net.parameters(), lr=lr)
        if lr_scheduler == 'cosine':
            scheduler = optim.lr_scheduler.CosineAnnealingLR(cls_opt, epochs*len(trloader))

    cls_criterion = nn.CrossEntropyLoss()
    adapt_criterion = nn.MSELoss()

    # ---- training loop (mixed precision) -----------------------------------
    best_acc = 0
    best_acc_t = 0
    scaler = GradScaler()
    for epoch in range(epochs):
        t1 = time.time()
        loss_list = []
        cls_net.train()

        print(len(trloader))
        for i, (x_four,y) in enumerate(trloader):
            b_sample_num = y.size(0)
            # x: original, x_RA: RandAugment view, x_FA: one view per factor,
            # x_CA: factor_num*var_num counterfactual views per sample.
            x, x_RA, x_FA, x_CA, y = x_four[0].cuda(), x_four[1].cuda(), x_four[2].cuda(), x_four[3].cuda(), y.cuda()
            b, c, h, w = x.shape

            # Flatten the per-factor / per-counterfactual dimensions into the
            # batch dimension so one forward pass covers all views.
            x_FA = x_FA.reshape(b*factor_num, c, h, w)
            x_CA = x_CA.reshape(b*factor_num*var_num, c, h, w)

            # Labels replicated factor_num times to align with x_FA rows.
            y_repeat = y.unsqueeze(0).reshape(b_sample_num,1).repeat((1,factor_num)).reshape(1,b_sample_num*factor_num).squeeze()

            with autocast():
                p,f = cls_net(x)

                _,f_FA = cls_net(x_FA)
                p_RA,f_RA = cls_net(x_RA)
                p_CA,_ = cls_net(x_CA)

                # Each factor's mapping net should invert its augmentation:
                # map f_FA back toward the clean features f.
                f_repeat = f.repeat((1,factor_num)).reshape(f_FA.shape)
                f_adapt = torch.zeros(f_FA.shape).cuda()
                # NOTE(review): loop variable b shadows the batch size b
                # unpacked from x.shape above; harmless after the reshapes
                # but worth renaming.
                for b in range(b_sample_num):
                    for j in range(factor_num):
                        f_adapt[b*factor_num+j] = AdaptNet[j](f_FA[b*factor_num+j])
                p_adapt = cls_net(f_adapt, mode='c')

                # Causal effect of each factor = prediction shift between the
                # RA view and its counterfactuals (averaged over magnitudes).
                if autoaug == 'CA_multiple':
                    p_RA_repeat = p_RA.repeat((1,factor_num*var_num)).reshape(p_CA.shape)
                    effect_context = p_RA_repeat - p_CA
                    effect_context = effect_context.reshape(b_sample_num,factor_num,var_num,-1)
                    effect_context = effect_context.mean(axis=2).reshape(b_sample_num*factor_num,-1)

                else:
                    p_RA_repeat = p_RA.repeat((1,factor_num)).reshape(p_CA.shape)
                    effect_context = p_RA_repeat - p_CA
                # Turn per-factor effects into mixture weights over the
                # mapping networks.
                weight = E_to_W(effect_context)

                weight = weight.reshape(b_sample_num,factor_num)
                alphas = F.softmax(weight,dim=1)

                # Weighted combination of all mappings applied to the RA
                # features; should reconstruct the clean features f.
                f_adapt_RA = torch.zeros(f_RA.shape).cuda()
                for b in range(b_sample_num):
                    for j in range(factor_num):
                        f_adapt_RA[b] = f_adapt_RA[b]+ alphas[b,j]*AdaptNet[j](f_RA[b])
                p_adapt_RA = cls_net(f_adapt_RA, mode='c')

                # Losses: clean classification, per-factor mapping (cls + MSE
                # reconstruction), and causal-combination (cls + MSE).
                cls_loss = cls_criterion(p, y)
                re_mapping = adapt_criterion(f_adapt,f_repeat)
                re_causal = adapt_criterion(f_adapt_RA,f)
                cls_loss_mapping = cls_criterion(p_adapt, y_repeat)
                cls_loss_causal = cls_criterion(p_adapt_RA, y)

                loss = cls_loss + cls_loss_mapping + lambda_re*re_mapping + lambda_causal*(lambda_re*re_causal + cls_loss_causal)

            # Standard AMP update: scale, backward, step, update scale.
            opt.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            loss_list.append([cls_loss.item(), cls_loss_mapping.item(),cls_loss_causal.item(), re_mapping.item(), re_causal.item()])

        # Epoch-level LR scheduling and logging.
        if lr_scheduler in ['cosine', 'Exp', 'Step']:
            writer.add_scalar('scalar/lr', opt.param_groups[0]["lr"], epoch)
            print(opt.param_groups[0]["lr"])
            print("changing lr")
            scheduler.step()
        # Rebind the loss names to their epoch means for logging below.
        cls_loss, cls_loss_mapping, cls_loss_causal, re_mapping, re_causal = np.mean(loss_list, 0)

        # ---- evaluation and checkpointing ----------------------------------
        cls_net.eval()
        if data in ['mnist', 'mnist_t', 'cifar10', 'mnistvis', 'art_painting', 'cartoon', 'photo', 'sketch']:
            teacc = evaluate(cls_net, teloader)

        elif 'synthia' in data:
            # NOTE(review): evaluate_seg is not defined/imported in this file.
            teacc = evaluate_seg(cls_net, teloader, nclass)

        if best_acc < teacc:
            print(f'---------------------saving model at epoch {epoch}----------------------------------------------------')
            log_file.write(f'saving model at epoch {epoch}\n')

            best_acc = teacc
            torch.save(cls_net.state_dict(),os.path.join(svroot, 'best_cls_net.pkl'))
            for j in range(factor_num):
                torch.save(AdaptNet[j].state_dict(),os.path.join(svroot, 'best_mapping_'+str(j)+'.pkl'))
            torch.save(E_to_W.state_dict(), os.path.join(svroot, 'best_E_to_W.pkl'))

        t2 = time.time()
        print(f'epoch {epoch}, time {t2-t1:.2f}, cls_loss {cls_loss:.4f} cls_loss_mapping {cls_loss_mapping:.4f} cls_loss_causal {cls_loss_causal:.4f} re_mapping {re_mapping:.4f} re_causal {re_causal:.4f} /// teacc {teacc:2.2f} lr {opt.param_groups[0]["lr"]:.8f}')
        log_file.write(f'epoch {epoch}, time {t2-t1:.2f}, cls_loss {cls_loss:.4f} cls_loss_mapping {cls_loss_mapping:.4f} cls_loss_causal {cls_loss_causal:.4f} re_mapping {re_mapping:.4f} re_causal {re_causal:.4f} /// teacc {teacc:2.2f} lr {opt.param_groups[0]["lr"]:.8f} \n')
        writer.add_scalar('scalar/cls_loss', cls_loss, epoch)
        writer.add_scalar('scalar/cls_loss_mapping', cls_loss_mapping, epoch)
        writer.add_scalar('scalar/cls_loss_causal', cls_loss_causal, epoch)
        writer.add_scalar('scalar/re_mapping', re_mapping, epoch)
        writer.add_scalar('scalar/re_causal', re_causal, epoch)
        writer.add_scalar('scalar/teacc', teacc, epoch)

    # Always save the final-epoch weights alongside the best ones.
    print(f'---------------------saving last model at epoch {epoch}----------------------------------------------------')
    log_file.write(f'saving last model at epoch {epoch}\n')
    torch.save(cls_net.state_dict(),os.path.join(svroot, 'last_cls_net.pkl'))
    for j in range(factor_num):
        torch.save(AdaptNet[j].state_dict(),os.path.join(svroot, 'last_mapping_'+str(j)+'.pkl'))
    torch.save(E_to_W.state_dict(), os.path.join(svroot, 'last_E_to_W.pkl'))

    writer.close()
|
|
def evalute_pacs(source_domain,cls_net,CA,AdaptNet,E_to_W):
    """Evaluate causal-adapted accuracy on every PACS domain except the source.

    Returns (mean accuracy over held-out domains, per-domain accuracy array).
    """
    cls_net.eval()
    all_domains = ['art_painting', 'cartoon', 'photo', 'sketch']
    held_out = [name for name in all_domains if name != source_domain]
    acc_CA = np.zeros(len(held_out))
    for idx, domain in enumerate(held_out):
        test_set = data_loader.load_pacs(domain, 'test')
        test_loader = DataLoader(test_set, batch_size=6, num_workers=0)
        acc_CA[idx] = evaluate_causal(cls_net, test_loader, CA, AdaptNet, E_to_W)
    mean_acc = sum(acc_CA) / len(held_out)
    return mean_acc, acc_CA
|
|
|
|
|
|
|
|
def evaluate(net, teloader):
    """Return top-1 accuracy (%) of `net` over all batches in `teloader`.

    `net` is called with mode='fc' and must return (logits, features);
    inputs are moved to the GPU, labels stay on CPU.
    """
    pred_batches = []
    label_batches = []
    with torch.no_grad():
        for inputs, labels in teloader:
            logits, _ = net(inputs.cuda(), mode='fc')
            pred_batches.append(logits.argmax(dim=1).detach().cpu().numpy())
            label_batches.append(labels.numpy())

    preds = np.concatenate(pred_batches)
    targets = np.concatenate(label_batches)
    return np.mean(targets == preds) * 100
|
|
def extract_feature(net, teloader, savedir):
    """Dump per-batch features to `savedir/<class>/<batch>.npy` and return accuracy (%).

    The class sub-directory is chosen from the FIRST label of each batch, so
    the loader is expected to group batches by class.
    """
    pred_batches = []
    label_batches = []
    for batch_idx, (inputs, labels) in enumerate(teloader):
        # Directory keyed by the batch's leading label.
        img_class = labels[0].cpu().numpy()
        class_dir = os.path.join(savedir, str(img_class))
        if not os.path.exists(class_dir):
            os.makedirs(class_dir)

        with torch.no_grad():
            inputs = inputs.cuda()
            logits, feats = net(inputs, mode='fc')
            np.save(class_dir + os.sep + str(batch_idx) + '.npy', feats.cpu())
            pred_batches.append(logits.argmax(dim=1).detach().cpu().numpy())
            label_batches.append(labels.numpy())

    preds = np.concatenate(pred_batches)
    targets = np.concatenate(label_batches)
    return np.mean(targets == preds) * 100
|
|
|
|
|
def evaluate_causal(net, teloader, CA, AdaptNet, E_to_W):
    """Accuracy (%) of causally-adapted predictions over `teloader`.

    For each batch: generate counterfactual views with `CA`, measure how much
    each factor shifts the prediction (averaged over `CA.var_num` magnitudes),
    convert those effects into per-factor mixture weights via `E_to_W`, and
    classify the weighted sum of adapted features with net(..., mode='c').
    """
    pred_batches = []
    label_batches = []
    for inputs, labels in teloader:
        batch_size = inputs.size(0)
        with torch.no_grad():
            inputs = inputs.cuda()
            base_logits, base_feat = net(inputs, mode='fc')

            # factor_num * var_num counterfactual views per sample.
            cf_logits, _ = net(CA(inputs).cuda(), mode='fc')
            tiled = base_logits.repeat((1, CA.factor_num * CA.var_num)).reshape(cf_logits.shape)
            effect = tiled - cf_logits
            effect = effect.reshape(batch_size, CA.factor_num, CA.var_num, -1)
            # Average over magnitudes, one effect row per (sample, factor).
            effect = effect.mean(axis=2).reshape(batch_size * CA.factor_num, -1)

            alphas = F.softmax(E_to_W(effect).reshape(batch_size, CA.factor_num), dim=1)

            # Mixture of per-factor adapted features, weighted by alphas.
            fused = torch.zeros(base_feat.shape).cuda()
            for s in range(batch_size):
                for k in range(CA.factor_num):
                    fused[s] = fused[s] + alphas[s, k] * AdaptNet[k](base_feat[s])
            adapted_pred = net(fused, mode='c').argmax(dim=1)
            pred_batches.append(adapted_pred.detach().cpu().numpy())
            label_batches.append(labels.numpy())

    preds = np.concatenate(pred_batches)
    targets = np.concatenate(label_batches)
    return np.mean(targets == preds) * 100
|
|
|
|
|
def extract_feature_do(net, teloader, CA, AdaptNet, E_to_W, savedir_base, savedir,source_flag):
    """Dump base and causally-adapted features per batch and return accuracy (%).

    Base features go to `savedir_base/<class>/<batch>_base.npy`; adapted
    features go to `savedir/<class>/<batch>.npy` unless `source_flag` is set
    (source-domain features are not re-saved). The class directory is keyed
    by the FIRST label of each batch, so batches are assumed class-grouped.
    Accuracy is computed from the adapted predictions (net(..., mode='c')).
    """
    ps = []
    ys = []
    for i,(x1, y1) in enumerate(teloader):
        # Per-batch output directories keyed by the batch's leading label.
        img_class = y1[0].cpu().numpy()
        save_path_base = os.path.join(savedir_base,str(img_class))
        save_path = os.path.join(savedir,str(img_class))
        if not os.path.exists(save_path_base):
            os.makedirs(save_path_base)
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        b_sample_num = x1.size(0)
        with torch.no_grad():
            x1 = x1.cuda()
            p1,f_x1 = net(x1, mode='fc')
            # Save the un-adapted (base) features first.
            save_name_base = save_path_base+os.sep+str(i)+'_base.npy'
            print(save_name_base)
            np.save(save_name_base,f_x1.cpu())
            # Counterfactual views: factor_num * var_num per sample.
            x1_CA = CA(x1).cuda()
            p1_CA,_ = net(x1_CA, mode='fc')
            # Effect of each factor = prediction shift vs. counterfactuals,
            # averaged over the var_num magnitudes.
            p1_repeat = p1.repeat((1,CA.factor_num*CA.var_num)).reshape(p1_CA.shape)
            effect_context = p1_repeat - p1_CA
            effect_context = effect_context.reshape(b_sample_num,CA.factor_num,CA.var_num,-1)
            effect_context = effect_context.mean(axis=2).reshape(b_sample_num*CA.factor_num,-1)
            # Effects -> per-factor mixture weights.
            weight = E_to_W(effect_context)
            weight = weight.reshape(b_sample_num,CA.factor_num)
            alphas = F.softmax(weight,dim=1)
            # Weighted combination of per-factor adapted features.
            f_adapt = torch.zeros(f_x1.shape).cuda()
            for b in range(b_sample_num):
                for j in range(CA.factor_num):
                    f_adapt[b] = f_adapt[b]+ alphas[b,j]*AdaptNet[j](f_x1[b])
            # Only save adapted features for non-source domains.
            if not source_flag:
                save_name = save_path+os.sep+str(i)+'.npy'
                print(save_name)
                np.save(save_name,f_adapt.cpu())
            p_adapt = net(f_adapt, mode='c')
            p_adapt = p_adapt.argmax(dim=1)
            ps.append(p_adapt.detach().cpu().numpy())
            ys.append(y1.numpy())

    ps = np.concatenate(ps)
    ys = np.concatenate(ys)
    acc = np.mean(ys==ps)*100
    return acc
|
|
|
|
|
|
|
|
def evaluate_mapping(net, teloader, FA, AdaptNet, source=False):
    """Per-factor accuracy of the mapping networks.

    Returns (acc, acc_t), each of length factor_num+1:
      - acc[j]  : accuracy when factor j's features are passed through
                  AdaptNet[j] before the classifier head;
      - acc_t[j]: accuracy of the same features WITHOUT the mapping;
      - acc[-1] : plain accuracy of net on the raw inputs.
    With source=True the inputs are factor-augmented by FA first, so
    AdaptNet is tested on its "own" augmented source features.
    NOTE(review): acc_t[-1] is never filled and stays 0.
    """
    correct, count = 0, 0
    # ps/ys: mapped predictions/labels per factor; pt/yt: unmapped.
    # One extra slot at the end of ps/ys for the raw-input baseline.
    ps = []
    ys = []
    pt = []
    yt = []
    factor_num = FA.factor_num
    for j in range(factor_num):
        ps.append([])
        ys.append([])
        pt.append([])
        yt.append([])
    ps.append([])
    ys.append([])

    for i,(x1, y1) in enumerate(teloader):
        with torch.no_grad():
            x1 = x1.cuda()
            b = x1.size(0)
            if source:
                # Source mode: features come from the FA-augmented views
                # (b*factor_num rows), baseline prediction from the clean x1.
                x_FA = FA(x1).cuda()
                _, f = net(x_FA, mode='fc')
                p,_ = net(x1, mode='fc')
                p = p.argmax(dim=1)
                ps[-1].append(p.detach().cpu().numpy())
                ys[-1].append(y1.numpy())
            else:
                # Target mode: tile each clean feature factor_num times so the
                # indexing below matches the source layout.
                p, f = net(x1, mode='fc')
                f = f.repeat((1,factor_num)).reshape((-1,f.size(1)))
                p = p.argmax(dim=1)
                ps[-1].append(p.detach().cpu().numpy())
                ys[-1].append(y1.numpy())
            # Classify each sample's row for factor j, with and without
            # the j-th mapping network (single-sample forward, hence dim=0).
            for b_ in range(b):
                for j in range(factor_num):
                    f_adapt = AdaptNet[j](f[b_*factor_num+j])

                    p1 = net(f_adapt, mode='c')
                    p1 = p1.argmax(dim=0)
                    ps[j].append(p1.detach().cpu())
                    ys[j].append(y1[b_])
                    p1_t = net(f[b_*factor_num+j], mode='c')

                    p1_t = p1_t.argmax(dim=0)
                    pt[j].append(p1_t.detach().cpu())
                    yt[j].append(y1[b_])

    # Aggregate per-factor accuracies (mapped vs. unmapped).
    acc = np.zeros(factor_num+1)
    acc_t = np.zeros(factor_num+1)
    for j in range(factor_num):
        pred = torch.stack(ps[j])
        label = torch.stack(ys[j])
        acc[j] = (pred==label).sum()/float(len(ys[j]))*100
        predt = torch.stack(pt[j])
        labelt = torch.stack(yt[j])
        acc_t[j] = (predt==labelt).sum()/float(len(yt[j]))*100
    # Raw-input baseline accuracy in the last slot of acc.
    pred = np.concatenate(ps[-1])
    label = np.concatenate(ys[-1])
    acc[-1] = np.mean(pred==label)*100

    return acc, acc_t
|
|
def evaluate_causal_with_entropy(net, teloader, CA, AdaptNet):
    """Accuracy (%) when factor weights come from effect entropy, not E_to_W.

    Each factor's counterfactual effect is softmax-normalized and its Shannon
    entropy computed; lower entropy (a more decisive effect) receives a larger
    mixture weight through softmax(-entropy).
    """
    pred_batches = []
    label_batches = []
    for inputs, labels in teloader:
        batch_size = inputs.size(0)
        with torch.no_grad():
            inputs = inputs.cuda()
            base_logits, base_feat = net(inputs, mode='fc')

            cf_logits, _ = net(CA(inputs).cuda(), mode='fc')
            tiled = base_logits.repeat((1, CA.factor_num * CA.var_num)).reshape(cf_logits.shape)
            effect = tiled - cf_logits
            effect = effect.reshape(batch_size, CA.factor_num, CA.var_num, -1)
            effect = effect.mean(axis=2).reshape(batch_size * CA.factor_num, -1)
            effect = F.softmax(effect, dim=1)

            # Entropy per (sample, factor); negate so confident factors win.
            entropy = torch.sum(-effect * (torch.log2(effect)), dim=1)
            entropy = entropy.reshape(batch_size, CA.factor_num)
            alphas = F.softmax(-entropy, dim=1)

            fused = torch.zeros(base_feat.shape).cuda()
            for s in range(batch_size):
                for k in range(CA.factor_num):
                    fused[s] = fused[s] + alphas[s, k] * AdaptNet[k](base_feat[s])
            adapted_pred = net(fused, mode='c').argmax(dim=1)
            pred_batches.append(adapted_pred.detach().cpu().numpy())
            label_batches.append(labels.numpy())

    preds = np.concatenate(pred_batches)
    targets = np.concatenate(label_batches)
    return np.mean(targets == preds) * 100
|
|
def evaluate_causal_with_average(net, teloader, factor_num, AdaptNet):
    """Accuracy (%) when every factor mapping gets the uniform weight 1/factor_num."""
    pred_batches = []
    label_batches = []
    uniform_w = float(1/factor_num)
    for inputs, labels in teloader:
        batch_size = inputs.size(0)
        with torch.no_grad():
            inputs = inputs.cuda()
            _, base_feat = net(inputs, mode='fc')
            # Plain average of all per-factor adapted features.
            fused = torch.zeros(base_feat.shape).cuda()
            for s in range(batch_size):
                for k in range(factor_num):
                    fused[s] = fused[s] + uniform_w * AdaptNet[k](base_feat[s])
            adapted_pred = net(fused, mode='c').argmax(dim=1)
            pred_batches.append(adapted_pred.detach().cpu().numpy())
            label_batches.append(labels.numpy())

    preds = np.concatenate(pred_batches)
    targets = np.concatenate(label_batches)
    return np.mean(targets == preds) * 100
|
|
if __name__=='__main__':
    # click parses the CLI options and invokes experiment().
    experiment()