# CausalStyleAdv/Meta-causal/code/main_my_joint_v13_auto.py
'''
Train the base model.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
from torch import optim
from torch.utils.data import DataLoader, RandomSampler
from torchvision import models
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter
from torch.cuda.amp import autocast,GradScaler
import os
import click
import time
import numpy as np
from network import mnist_net_my as mnist_net
from network import wideresnet as wideresnet
from network import resnet as resnet
from network import adaptor_v2
from tools import causalaugment_v3 as causalaugment
import data_loader_joint_v3 as data_loader
# from utils import set_requires_grad
HOME = os.environ['HOME']
@click.command()
@click.option('--gpu', type=str, default='0', help='GPU id(s) to use')
@click.option('--data', type=str, default='mnist', help='dataset name')
@click.option('--ntr', type=int, default=None, help='use only the first ntr training samples')
@click.option('--translate', type=float, default=None, help='random-translation data augmentation')
@click.option('--autoaug', type=str, default=None, help='augmentation policy: AA, FastAA, RA, or CA_multiple')
@click.option('--n', type=int, default=3, help='number of factors used to generate RA')
@click.option('--stride', type=int, default=5, help='magnitude stride, used when autoaug == CA_multiple')
@click.option('--factor_num', type=int, default=16, help='use the first factor_num augmentation factors')
@click.option('--epochs', type=int, default=100)
@click.option('--nbatch', type=int, default=100, help='number of batches per epoch')
@click.option('--batchsize', type=int, default=128, help='number of samples per batch')
@click.option('--lr', type=float, default=1e-3)
@click.option('--lr_scheduler', type=str, default='none', help='learning-rate decay schedule')
@click.option('--svroot', type=str, default='./saved', help='directory for saved models and logs')
@click.option('--clsadapt', type=bool, default=True, help='whether to apply the classification loss after mapping')
@click.option('--lambda_causal', type=float, default=1, help='weight of the causal-branch loss')
@click.option('--lambda_re', type=float, default=1, help='weight of the reconstruction losses in the mapping and causal branches')
@click.option('--randm', type=bool, default=True, help='whether the magnitude m is randomized')
@click.option('--randn', type=bool, default=False, help='whether n is randomized')
@click.option('--network', type=str, default='resnet18', help='backbone architecture')
def experiment(gpu, data, ntr, translate, autoaug, n, stride, factor_num, epochs, nbatch, batchsize, lr, lr_scheduler, svroot, clsadapt, lambda_causal, lambda_re, randm, randn, network):
settings = locals().copy()
print(settings)
    # global settings
os.environ['CUDA_VISIBLE_DEVICES'] = gpu
if not os.path.exists(svroot):
os.makedirs(svroot)
log_file = open(svroot+os.sep+'log.log',"w")
log_file.write(str(settings)+'\n')
writer = SummaryWriter(svroot)
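    # counterfactual augmentation: intervene on each of the first factor_num style factors,
    # with intervention magnitudes sampled at the given stride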
CA = causalaugment.MultiCounterfactualAugment(factor_num,stride)
# FA = causalaugment.FactualAugment(m=4, factor_num=factor_num, randm=True)
    # load the dataset and model
if data in ['mnist', 'mnist_t']:
if data == 'mnist':
trset = data_loader.load_mnist('train', translate=translate,twox=True, ntr=ntr, factor_num=factor_num,autoaug=autoaug,randm=randm,randn=randn,n=n,stride=stride)
elif data == 'mnist_t':
trset = data_loader.load_mnist_t('train', translate=translate, ntr=ntr)
teset = data_loader.load_mnist('test')
trloader = DataLoader(trset, batch_size=batchsize, num_workers=0, \
sampler=RandomSampler(trset, True, nbatch*batchsize))
teloader = DataLoader(teset, batch_size=batchsize, num_workers=0, shuffle=False)
cls_net = mnist_net.ConvNet().cuda()
AdaptNet = []
parameter_list = []
for i in range(factor_num):
mapping = adaptor_v2.mapping(1024,512,1024,2).cuda()
AdaptNet.append(mapping)
parameter_list.append({'params':mapping.parameters(),'lr':lr})
        if autoaug == 'CA_multiple':
            var_num = len(range(0, 31, stride))
            E_to_W = adaptor_v2.effect_to_weight(10, 100, 1).cuda()
        else:
            var_num = 1  # assumed: the loader yields a single counterfactual per factor in this case
            E_to_W = adaptor_v2.effect_to_weight(10, 100, 1).cuda()
parameter_list.append({'params':cls_net.parameters(),'lr':lr})
parameter_list.append({'params':E_to_W.parameters(),'lr':lr})
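        # a single optimizer jointly updates the classifier, the per-factor mappings, and E_to_W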
#print("---------------------------------------------------------------------------------------")
opt = optim.Adam(parameter_list, lr=lr)
if lr_scheduler == 'cosine':
scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, epochs)
elif lr_scheduler == 'Exp':
scheduler = optim.lr_scheduler.ExponentialLR(opt, gamma=0.95)
elif lr_scheduler == 'Step':
scheduler = optim.lr_scheduler.StepLR(opt, step_size=int(epochs*0.8))
# print("------------------------------------opt_mapping---------------------------------------------------")
# for param_group in opt_mapping.param_groups:
# print(param_group.keys())
# # print(type(param_group))
# print([type(value) for value in param_group.values()])
# print('lr: ',param_group['lr'])
# print("------------------------------------opt_causal---------------------------------------------------")
# for param_group in opt_causal.param_groups:
# print(param_group.keys())
# # print(type(param_group))
# print([type(value) for value in param_group.values()])
# print('lr: ',param_group['lr'])
elif data == 'cifar10':
        # load the dataset
trset = data_loader.load_cifar10(split='train',twox=True, factor_num=factor_num,autoaug=autoaug,randm=randm,randn=randn,n=n,stride=stride)
teset = data_loader.load_cifar10(split='test')
trloader = DataLoader(trset, batch_size=batchsize, num_workers=4, shuffle=True, drop_last=True)
teloader = DataLoader(teset, batch_size=batchsize, num_workers=4, shuffle=False)
cls_net = wideresnet.WideResNet(16, 10, 4).cuda()
# cls_opt = optim.SGD(cls_net.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=5e-4)
AdaptNet = []
parameter_list = []
for i in range(factor_num):
mapping = adaptor_v2.mapping(256,512,256,4).cuda()
AdaptNet.append(mapping)
parameter_list.append({'params':mapping.parameters(),'lr':lr})
        if autoaug == 'CA_multiple':
            var_num = len(range(0, 31, stride))
            E_to_W = adaptor_v2.effect_to_weight(10, 100, 1).cuda()
        else:
            var_num = 1  # assumed: the loader yields a single counterfactual per factor in this case
            E_to_W = adaptor_v2.effect_to_weight(10, 100, 1).cuda()
parameter_list.append({'params':cls_net.parameters(),'lr':lr})
parameter_list.append({'params':E_to_W.parameters(),'lr':lr})
#print("---------------------------------------------------------------------------------------")
# opt = optim.Adam(parameter_list)
opt = optim.SGD(parameter_list, lr=lr, momentum=0.9, nesterov=True, weight_decay=5e-4)
if lr_scheduler == 'cosine':
scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, epochs)
elif lr_scheduler == 'Exp':
scheduler = optim.lr_scheduler.ExponentialLR(opt, gamma=0.95)
elif lr_scheduler == 'Step':
scheduler = optim.lr_scheduler.StepLR(opt, step_size=int(epochs*0.8))
elif data in ['art_painting', 'cartoon', 'photo', 'sketch']:
        # load the dataset
trset = data_loader.load_pacs(domain=data, split='train', twox=True, factor_num=factor_num,autoaug=autoaug,randm=randm,randn=randn,n=n,stride=stride)
teset = data_loader.load_pacs(domain=data, split='val')
trloader = DataLoader(trset, batch_size=batchsize, num_workers=4, shuffle=True, drop_last=True)
teloader = DataLoader(teset, batch_size=batchsize, num_workers=4, shuffle=False)
        if network == 'resnet18':
            cls_net = resnet.resnet18(classes=7, c_dim=2048).cuda()
            input_dim = 2048
        else:
            raise ValueError(f'unsupported network: {network}')
# for param in cls_net.features.parameters():
# param.requires_grad = False
# for name, parms in cls_net.named_parameters():
# print('-->name:', name)
# print('-->grad_requirs:',parms.requires_grad)
# cls_opt = optim.SGD(cls_net.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=5e-4)
# print(cls_net.state_dict())
classifier_param = list(map(id, cls_net.class_classifier.parameters()))
backbone_param = filter(lambda p: id(p) not in classifier_param and p.requires_grad, cls_net.parameters())
AdaptNet = []
parameter_list = []
for i in range(factor_num):
mapping = adaptor_v2.mapping(input_dim,1024,input_dim,4).cuda()
AdaptNet.append(mapping)
parameter_list.append({'params':mapping.parameters(),'lr':lr})
        if autoaug == 'CA_multiple':
            var_num = len(range(0, 31, stride))
            E_to_W = adaptor_v2.effect_to_weight(7, 70, 1).cuda()
        else:
            var_num = 1  # assumed: the loader yields a single counterfactual per factor in this case
            E_to_W = adaptor_v2.effect_to_weight(7, 70, 1).cuda()
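        # the pretrained backbone gets a 100x smaller learning rate than the new classifier head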
parameter_list.append({'params':backbone_param,'lr':0.01*lr})
parameter_list.append({'params':cls_net.class_classifier.parameters(),'lr':lr})
parameter_list.append({'params':E_to_W.parameters(),'lr':lr})
#print("---------------------------------------------------------------------------------------")
# opt = optim.Adam(parameter_list)
opt = optim.SGD(parameter_list, momentum=0.9, nesterov=True, weight_decay=5e-4)
if lr_scheduler == 'cosine':
scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, epochs)
elif lr_scheduler == 'Exp':
scheduler = optim.lr_scheduler.ExponentialLR(opt, gamma=0.99999)
elif lr_scheduler == 'Step':
scheduler = optim.lr_scheduler.StepLR(opt, step_size=15)
    elif 'synthia' in data:
        # load the dataset (note: the train and "test" loaders both use the training split)
        branch = data.split('_')[1]
        trset = data_loader.load_synthia(branch)
        trloader = DataLoader(trset, batch_size=batchsize, num_workers=8, shuffle=True)
        teloader = DataLoader(trset, batch_size=batchsize, num_workers=8, shuffle=True)
        imsize = [192, 320]
        nclass = 14
        # load the model; `fcn` is not imported in this file and is assumed to come from the
        # project's network package. Note this branch only builds cls_opt, while the joint
        # training loop below expects opt, AdaptNet, and E_to_W.
        cls_net = fcn.FCN_resnet50(nclass=nclass).cuda()
        cls_opt = optim.Adam(cls_net.parameters(), lr=lr)  # for SYNTHIA, adding weight_decay drops mIoU by 1-2 points
        if lr_scheduler == 'cosine':
            scheduler = optim.lr_scheduler.CosineAnnealingLR(cls_opt, epochs*len(trloader))
cls_criterion = nn.CrossEntropyLoss()
adapt_criterion = nn.MSELoss()
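    # cross-entropy supervises every classification head; MSE ties adapted features to the originals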
    # start training
best_acc = 0
best_acc_t = 0
scaler = GradScaler()
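    # mixed-precision training: GradScaler rescales the loss so fp16 gradients do not underflow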
for epoch in range(epochs):
t1 = time.time()
loss_list = []
cls_net.train()
# unloader = transforms.ToPILImage()
print(len(trloader))
for i, (x_four,y) in enumerate(trloader):
b_sample_num = y.size(0)
x, x_RA, x_FA, x_CA, y = x_four[0].cuda(), x_four[1].cuda(), x_four[2].cuda(), x_four[3].cuda(), y.cuda()
b, c, h, w = x.shape
# x_FA_ = x_FA.transpose(1,2)
x_FA = x_FA.reshape(b*factor_num, c, h, w)
x_CA = x_CA.reshape(b*factor_num*var_num, c, h, w)
            # learn the factor-wise mappings
            y_repeat = y.repeat_interleave(factor_num)  # each label repeated factor_num times
# x_FA = FA(x).cuda().detach()
# x_CA = CA(x_RA).cuda().detach()
with autocast():
p,f = cls_net(x)
# print("x.shape:",x.shape)
# print("x_FA.shape:",x_FA.shape)
_,f_FA = cls_net(x_FA)
p_RA,f_RA = cls_net(x_RA)
p_CA,_ = cls_net(x_CA)
# print("f.shape:",f.shape)
# print("f_FA.shape:",f_FA.shape)
                # mapping branch: map each factual-augmented feature back toward its original feature
f_repeat = f.repeat((1,factor_num)).reshape(f_FA.shape)
f_adapt = torch.zeros(f_FA.shape).cuda()
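                # Fill f_adapt with the j-th factor's mapping applied to the j-th factual
                # variant of every sample. A vectorized sketch (assumed equivalent to the
                # double loop below, untested against the original):
                #   for j in range(factor_num):
                #       f_adapt[j::factor_num] = AdaptNet[j](f_FA[j::factor_num])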
for b in range(b_sample_num):
for j in range(factor_num):
f_adapt[b*factor_num+j] = AdaptNet[j](f_FA[b*factor_num+j])
p_adapt = cls_net(f_adapt, mode='c')
                # learn causality: a factor's effect is the prediction shift under its counterfactual intervention
if autoaug == 'CA_multiple':
p_RA_repeat = p_RA.repeat((1,factor_num*var_num)).reshape(p_CA.shape)
effect_context = p_RA_repeat - p_CA
effect_context = effect_context.reshape(b_sample_num,factor_num,var_num,-1)
effect_context = effect_context.mean(axis=2).reshape(b_sample_num*factor_num,-1)
# print("effect_context.shape:",effect_context.shape)
else:
p_RA_repeat = p_RA.repeat((1,factor_num)).reshape(p_CA.shape)
effect_context = p_RA_repeat - p_CA
weight = E_to_W(effect_context)
# weight = E_to_W(effect_context.detach())
weight = weight.reshape(b_sample_num,factor_num)
alphas = F.softmax(weight,dim=1)
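                # Combine the per-factor mappings of f_RA with the causal attention weights.
                # A vectorized sketch (assumed equivalent to the loop below, untested):
                #   f_stack = torch.stack([AdaptNet[j](f_RA) for j in range(factor_num)], dim=1)
                #   f_adapt_RA = (alphas.unsqueeze(-1) * f_stack).sum(dim=1)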
f_adapt_RA = torch.zeros(f_RA.shape).cuda()
for b in range(b_sample_num):
for j in range(factor_num):
f_adapt_RA[b] = f_adapt_RA[b]+ alphas[b,j]*AdaptNet[j](f_RA[b])
p_adapt_RA = cls_net(f_adapt_RA, mode='c')
cls_loss = cls_criterion(p, y)
re_mapping = adapt_criterion(f_adapt,f_repeat)
re_causal = adapt_criterion(f_adapt_RA,f)
cls_loss_mapping = cls_criterion(p_adapt, y_repeat)
cls_loss_causal = cls_criterion(p_adapt_RA, y)
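                # total objective: cls_loss + mapping cls loss + lambda_re * mapping reconstruction
                # + lambda_causal * (lambda_re * causal reconstruction + causal cls loss)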
loss = cls_loss + cls_loss_mapping + lambda_re*re_mapping + lambda_causal*(lambda_re*re_causal + cls_loss_causal)
opt.zero_grad()
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
loss_list.append([cls_loss.item(), cls_loss_mapping.item(),cls_loss_causal.item(), re_mapping.item(), re_causal.item()])
        # adjust the learning rate
if lr_scheduler in ['cosine', 'Exp', 'Step']:
writer.add_scalar('scalar/lr', opt.param_groups[0]["lr"], epoch)
print(opt.param_groups[0]["lr"])
print("changing lr")
scheduler.step()
cls_loss, cls_loss_mapping, cls_loss_causal, re_mapping, re_causal = np.mean(loss_list, 0)
        # evaluate, and keep the best model
cls_net.eval()
if data in ['mnist', 'mnist_t', 'cifar10', 'mnistvis', 'art_painting', 'cartoon', 'photo', 'sketch']:
teacc = evaluate(cls_net, teloader)
        elif 'synthia' in data:
            teacc = evaluate_seg(cls_net, teloader, nclass)  # actually computes mIoU; evaluate_seg is assumed to be defined elsewhere in the project
if best_acc < teacc:
print(f'---------------------saving model at epoch {epoch}----------------------------------------------------')
log_file.write(f'saving model at epoch {epoch}\n')
best_acc = teacc
torch.save(cls_net.state_dict(),os.path.join(svroot, 'best_cls_net.pkl'))
for j in range(factor_num):
torch.save(AdaptNet[j].state_dict(),os.path.join(svroot, 'best_mapping_'+str(j)+'.pkl'))
torch.save(E_to_W.state_dict(), os.path.join(svroot, 'best_E_to_W.pkl'))
        # write logs
t2 = time.time()
print(f'epoch {epoch}, time {t2-t1:.2f}, cls_loss {cls_loss:.4f} cls_loss_mapping {cls_loss_mapping:.4f} cls_loss_causal {cls_loss_causal:.4f} re_mapping {re_mapping:.4f} re_causal {re_causal:.4f} /// teacc {teacc:2.2f} lr {opt.param_groups[0]["lr"]:.8f}')
log_file.write(f'epoch {epoch}, time {t2-t1:.2f}, cls_loss {cls_loss:.4f} cls_loss_mapping {cls_loss_mapping:.4f} cls_loss_causal {cls_loss_causal:.4f} re_mapping {re_mapping:.4f} re_causal {re_causal:.4f} /// teacc {teacc:2.2f} lr {opt.param_groups[0]["lr"]:.8f} \n')
writer.add_scalar('scalar/cls_loss', cls_loss, epoch)
writer.add_scalar('scalar/cls_loss_mapping', cls_loss_mapping, epoch)
writer.add_scalar('scalar/cls_loss_causal', cls_loss_causal, epoch)
writer.add_scalar('scalar/re_mapping', re_mapping, epoch)
writer.add_scalar('scalar/re_causal', re_causal, epoch)
writer.add_scalar('scalar/teacc', teacc, epoch)
print(f'---------------------saving last model at epoch {epoch}----------------------------------------------------')
log_file.write(f'saving last model at epoch {epoch}\n')
torch.save(cls_net.state_dict(),os.path.join(svroot, 'last_cls_net.pkl'))
for j in range(factor_num):
torch.save(AdaptNet[j].state_dict(),os.path.join(svroot, 'last_mapping_'+str(j)+'.pkl'))
torch.save(E_to_W.state_dict(), os.path.join(svroot, 'last_E_to_W.pkl'))
writer.close()
def evaluate_pacs(source_domain, cls_net, CA, AdaptNet, E_to_W):
cls_net.eval()
data_total = ['art_painting', 'cartoon', 'photo', 'sketch']
target = [i for i in data_total if i!=source_domain]
acc_CA = np.zeros(len(target))
for idx, data in enumerate(target):
teset = data_loader.load_pacs(data, 'test')
teloader = DataLoader(teset, batch_size=6, num_workers=0)
        # compute the evaluation metric
acc_CA[idx] = evaluate_causal(cls_net, teloader, CA, AdaptNet, E_to_W)
acc_avg_CA = sum(acc_CA)/len(target)
return acc_avg_CA,acc_CA
def evaluate(net, teloader):
ps = []
ys = []
for i,(x1, y1) in enumerate(teloader):
with torch.no_grad():
x1 = x1.cuda()
p1,_ = net(x1, mode='fc')
p1 = p1.argmax(dim=1)
ps.append(p1.detach().cpu().numpy())
ys.append(y1.numpy())
    # compute the evaluation metric
ps = np.concatenate(ps)
ys = np.concatenate(ys)
acc = np.mean(ys==ps)*100
return acc
def extract_feature(net, teloader, savedir):
ps = []
ys = []
for i,(x1, y1) in enumerate(teloader):
img_class = y1[0].cpu().numpy()
save_path = os.path.join(savedir,str(img_class))
if not os.path.exists(save_path):
os.makedirs(save_path)
with torch.no_grad():
x1 = x1.cuda()
p1,f1 = net(x1, mode='fc')
save_name = save_path+os.sep+str(i)+'.npy'
np.save(save_name,f1.cpu())
p1 = p1.argmax(dim=1)
ps.append(p1.detach().cpu().numpy())
ys.append(y1.numpy())
    # compute the evaluation metric
ps = np.concatenate(ps)
ys = np.concatenate(ys)
acc = np.mean(ys==ps)*100
return acc
def evaluate_causal(net, teloader, CA, AdaptNet, E_to_W):
ps = []
ys = []
p_orig = []
y_orig = []
for i,(x1, y1) in enumerate(teloader):
b_sample_num = x1.size(0)
with torch.no_grad():
x1 = x1.cuda()
p1,f_x1 = net(x1, mode='fc')
x1_CA = CA(x1).cuda()
p1_CA,_ = net(x1_CA, mode='fc')
p1_repeat = p1.repeat((1,CA.factor_num*CA.var_num)).reshape(p1_CA.shape)
effect_context = p1_repeat - p1_CA
effect_context = effect_context.reshape(b_sample_num,CA.factor_num,CA.var_num,-1)
effect_context = effect_context.mean(axis=2).reshape(b_sample_num*CA.factor_num,-1)
weight = E_to_W(effect_context)
weight = weight.reshape(b_sample_num,CA.factor_num)
alphas = F.softmax(weight,dim=1)
f_adapt = torch.zeros(f_x1.shape).cuda()
for b in range(b_sample_num):
for j in range(CA.factor_num):
f_adapt[b] = f_adapt[b]+ alphas[b,j]*AdaptNet[j](f_x1[b])
p_adapt = net(f_adapt, mode='c')
p_adapt = p_adapt.argmax(dim=1)
ps.append(p_adapt.detach().cpu().numpy())
ys.append(y1.numpy())
    # compute the evaluation metric
ps = np.concatenate(ps)
ys = np.concatenate(ys)
acc = np.mean(ys==ps)*100
return acc
def extract_feature_do(net, teloader, CA, AdaptNet, E_to_W, savedir_base, savedir,source_flag):
ps = []
ys = []
for i,(x1, y1) in enumerate(teloader):
img_class = y1[0].cpu().numpy()
save_path_base = os.path.join(savedir_base,str(img_class))
save_path = os.path.join(savedir,str(img_class))
if not os.path.exists(save_path_base):
os.makedirs(save_path_base)
if not os.path.exists(save_path):
os.makedirs(save_path)
b_sample_num = x1.size(0)
with torch.no_grad():
x1 = x1.cuda()
p1,f_x1 = net(x1, mode='fc')
save_name_base = save_path_base+os.sep+str(i)+'_base.npy'
print(save_name_base)
np.save(save_name_base,f_x1.cpu())
x1_CA = CA(x1).cuda()
p1_CA,_ = net(x1_CA, mode='fc')
p1_repeat = p1.repeat((1,CA.factor_num*CA.var_num)).reshape(p1_CA.shape)
effect_context = p1_repeat - p1_CA
effect_context = effect_context.reshape(b_sample_num,CA.factor_num,CA.var_num,-1)
effect_context = effect_context.mean(axis=2).reshape(b_sample_num*CA.factor_num,-1)
weight = E_to_W(effect_context)
weight = weight.reshape(b_sample_num,CA.factor_num)
alphas = F.softmax(weight,dim=1)
f_adapt = torch.zeros(f_x1.shape).cuda()
for b in range(b_sample_num):
for j in range(CA.factor_num):
f_adapt[b] = f_adapt[b]+ alphas[b,j]*AdaptNet[j](f_x1[b])
if not source_flag:
save_name = save_path+os.sep+str(i)+'.npy'
print(save_name)
np.save(save_name,f_adapt.cpu())
p_adapt = net(f_adapt, mode='c')
p_adapt = p_adapt.argmax(dim=1)
ps.append(p_adapt.detach().cpu().numpy())
ys.append(y1.numpy())
    # compute the evaluation metric
ps = np.concatenate(ps)
ys = np.concatenate(ys)
acc = np.mean(ys==ps)*100
return acc
def evaluate_mapping(net, teloader, FA, AdaptNet, source=False):
correct, count = 0, 0
ps = []
ys = []
pt = []
yt = []
factor_num = FA.factor_num
for j in range(factor_num):
ps.append([])
ys.append([])
pt.append([])
yt.append([])
ps.append([])
ys.append([])
# print(len(ps),len(ys))
for i,(x1, y1) in enumerate(teloader):
with torch.no_grad():
x1 = x1.cuda()
b = x1.size(0)
if source:
x_FA = FA(x1).cuda()
_, f = net(x_FA, mode='fc')
p,_ = net(x1, mode='fc')
p = p.argmax(dim=1)
ps[-1].append(p.detach().cpu().numpy())
ys[-1].append(y1.numpy())
else:
p, f = net(x1, mode='fc')
f = f.repeat((1,factor_num)).reshape((-1,f.size(1)))
p = p.argmax(dim=1)
ps[-1].append(p.detach().cpu().numpy())
ys[-1].append(y1.numpy())
for b_ in range(b):
for j in range(factor_num):
f_adapt = AdaptNet[j](f[b_*factor_num+j])
#f_adapt = torch.mm(AdaptNet[j].W1,f_FA[b_*factor_num+j].unsqueeze(1)).squeeze()
p1 = net(f_adapt, mode='c')
p1 = p1.argmax(dim=0)
ps[j].append(p1.detach().cpu())
ys[j].append(y1[b_])
p1_t = net(f[b_*factor_num+j], mode='c')
# print("p1_t.shape:",p1_t.shape)
p1_t = p1_t.argmax(dim=0)
pt[j].append(p1_t.detach().cpu())
yt[j].append(y1[b_])
    # compute the evaluation metric
acc = np.zeros(factor_num+1)
acc_t = np.zeros(factor_num+1)
for j in range(factor_num):
pred = torch.stack(ps[j])
label = torch.stack(ys[j])
acc[j] = (pred==label).sum()/float(len(ys[j]))*100
predt = torch.stack(pt[j])
labelt = torch.stack(yt[j])
acc_t[j] = (predt==labelt).sum()/float(len(yt[j]))*100
pred = np.concatenate(ps[-1])
label = np.concatenate(ys[-1])
acc[-1] = np.mean(pred==label)*100
# print("acc:",acc)
return acc, acc_t
def evaluate_causal_with_entropy(net, teloader, CA, AdaptNet):
ps = []
ys = []
for i,(x1, y1) in enumerate(teloader):
b_sample_num = x1.size(0)
with torch.no_grad():
x1 = x1.cuda()
p1,f_x1 = net(x1, mode='fc')
x1_CA = CA(x1).cuda()
p1_CA, _ = net(x1_CA, mode='fc')
p1_repeat = p1.repeat((1,CA.factor_num*CA.var_num)).reshape(p1_CA.shape)
effect_context = p1_repeat - p1_CA
effect_context = effect_context.reshape(b_sample_num,CA.factor_num,CA.var_num,-1)
effect_context = effect_context.mean(axis=2).reshape(b_sample_num*CA.factor_num,-1)
effect_context = F.softmax(effect_context,dim=1)
            # Shannon entropy of the normalized effect: lower entropy means a more peaked
            # (confident) causal effect, so it gets a larger weight via softmax(-entropy)
            weight = torch.sum(-effect_context * torch.log2(effect_context), dim=1)
            weight = weight.reshape(b_sample_num, CA.factor_num)
            alphas = F.softmax(-weight, dim=1)
f_adapt = torch.zeros(f_x1.shape).cuda()
for b in range(b_sample_num):
for j in range(CA.factor_num):
f_adapt[b] = f_adapt[b]+ alphas[b,j]*AdaptNet[j](f_x1[b])
p_adapt = net(f_adapt, mode='c')
p_adapt = p_adapt.argmax(dim=1)
ps.append(p_adapt.detach().cpu().numpy())
ys.append(y1.numpy())
    # compute the evaluation metric
ps = np.concatenate(ps)
ys = np.concatenate(ys)
acc = np.mean(ys==ps)*100
return acc
def evaluate_causal_with_average(net, teloader, factor_num, AdaptNet):
ps = []
ys = []
for i,(x1, y1) in enumerate(teloader):
b_sample_num = x1.size(0)
with torch.no_grad():
x1 = x1.cuda()
p1,f_x1 = net(x1, mode='fc')
f_adapt = torch.zeros(f_x1.shape).cuda()
for b in range(b_sample_num):
for j in range(factor_num):
f_adapt[b] = f_adapt[b]+ float(1/factor_num)*AdaptNet[j](f_x1[b])
p_adapt = net(f_adapt, mode='c')
p_adapt = p_adapt.argmax(dim=1)
ps.append(p_adapt.detach().cpu().numpy())
ys.append(y1.numpy())
    # compute the evaluation metric
ps = np.concatenate(ps)
ys = np.concatenate(ys)
acc = np.mean(ys==ps)*100
return acc
if __name__=='__main__':
experiment()
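
# Example invocation (illustrative; flags as defined above):
#   python main_my_joint_v13_auto.py --gpu 0 --data mnist --epochs 100 \
#       --batchsize 128 --lr 1e-3 --autoaug CA_multiple --svroot ./saved/mnist_run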