|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
import os |
|
|
import numpy as np |
|
|
import click |
|
|
import pandas as pd |
|
|
|
|
|
from network import mnist_net_my as mnist_net |
|
|
from network import adaptor_v2 |
|
|
from tools import causalaugment_v3 as causalaugment |
|
|
from main_my_joint_v13_auto import evaluate,evaluate_causal,evaluate_causal_with_entropy,evaluate_mapping,evaluate_causal_with_average |
|
|
import data_loader_joint_v3 as data_loader |
|
|
|
|
|
@click.command()
# help below reads "select GPU id"
@click.option('--gpu', type=str, default='0', help='选择GPU编号')
@click.option('--svroot', type=str, default='./saved')
# help below reads "path to save the log"; the value is passed to evaluate_digit as svpath
@click.option('--svpath', type=str, default=None, help='保存日志的路径')
@click.option('--channels', type=int, default=3)
@click.option('--factor_num', type=int, default=16)
@click.option('--stride', type=int, default=16)
# Restrict to the two checkpoint sets evaluate_digit knows how to load, so a
# typo fails at argument parsing instead of as a NameError mid-evaluation.
@click.option('--epoch', type=click.Choice(['best', 'last']), default='best')
# help below reads "whether to inspect how well the mappings were learned"
@click.option('--eval_mapping', type=bool, default=True, help='是否查看mapping学习效果')
def main(gpu, svroot, svpath, channels, factor_num, stride, epoch, eval_mapping):
    """Evaluate saved digit models; thin CLI wrapper around evaluate_digit."""
    evaluate_digit(gpu, svroot, svpath, channels, factor_num, stride, epoch, eval_mapping)
|
|
|
|
|
# Checkpoint sets saved by training: files are named '<epoch>_<name>.pkl'.
_EPOCH_CHOICES = ('best', 'last')


def _load_checkpoint(svroot, epoch, name):
    """Load and return ``torch.load`` of ``<svroot>/<epoch>_<name>.pkl``.

    Args:
        svroot: directory holding the saved ``.pkl`` checkpoints.
        epoch: which checkpoint set to load, ``'best'`` or ``'last'``.
        name: checkpoint stem, e.g. ``'cls_net'``, ``'mapping_0'``, ``'E_to_W'``.

    Returns:
        The object stored in the file (callers here pass it to
        ``load_state_dict``).

    Raises:
        ValueError: if ``epoch`` is not one of ``_EPOCH_CHOICES``.
    """
    if epoch not in _EPOCH_CHOICES:
        raise ValueError("epoch must be one of %s, got %r" % (_EPOCH_CHOICES, epoch))
    print("loading weight of %s" % (epoch))
    return torch.load(os.path.join(svroot, '%s_%s.pkl' % (epoch, name)))


def _mean_over(results, keys):
    """Return the element-wise mean of ``results[k]`` over ``k`` in ``keys``.

    The sum starts from a float64 zero array shaped like the first entry,
    then is divided by ``len(keys)`` (running sum, then divide).

    Raises:
        ValueError: if ``keys`` is empty.
    """
    if not keys:
        raise ValueError("cannot average over an empty list of keys")
    total = np.zeros(np.shape(results[keys[0]]))
    for key in keys:
        total = total + results[key]
    return total / float(len(keys))


def evaluate_digit(gpu, svroot, svpath, channels=3, factor_num=16, stride=5, epoch='best', eval_mapping=True):
    """Evaluate saved models on the mnist source and four digit target domains.

    Loads ``<epoch>_cls_net.pkl``, ``<epoch>_mapping_<i>.pkl`` for
    ``i in range(factor_num)`` and ``<epoch>_E_to_W.pkl`` from ``svroot``
    (all models are moved to CUDA). Then, for each domain's ``'test'`` split,
    it runs ``evaluate_causal`` and either ``evaluate_mapping``
    (``eval_mapping=True``) or ``evaluate``.

    Args:
        gpu: value written to ``CUDA_VISIBLE_DEVICES``.
        svroot: directory holding the saved checkpoints.
        svpath: optional CSV path. The main table is written to it, then the
            ``'do'`` table is appended to the same file.
        channels: number of input image channels, 1 or 3.
        factor_num: number of factors; also the number of mapping networks.
        stride: forwarded to ``MultiCounterfactualAugment``.
        epoch: checkpoint set to load, ``'best'`` or ``'last'``.
        eval_mapping: when True, the main table has one row per entry of
            ``FA.factor_list`` plus a ``'w/o do (original x)'`` row, and
            extra ``<domain>_FA`` columns. Otherwise it has only the
            ``'w/o do (original x)'`` row.

    Returns:
        ``(df, df_ours)``: the two DataFrames that are also printed. Each has
        an ``'Avg'`` column averaged over the target domains only.

    Raises:
        ValueError: if ``epoch`` or ``channels`` is unsupported.
    """
    settings = locals().copy()
    print(settings)

    # Validate before touching the environment: unsupported values used to
    # fall through the if/elif chains and crash later on an unbound name.
    if epoch not in _EPOCH_CHOICES:
        raise ValueError("epoch must be one of %s, got %r" % (_EPOCH_CHOICES, epoch))
    if channels not in (1, 3):
        raise ValueError("channels must be 1 or 3, got %r" % (channels,))

    os.environ['CUDA_VISIBLE_DEVICES'] = gpu

    if channels == 3:
        cls_net = mnist_net.ConvNet().cuda()
    else:
        cls_net = mnist_net.ConvNet(imdim=channels).cuda()
    cls_net.load_state_dict(_load_checkpoint(svroot, epoch, 'cls_net'))

    FA = causalaugment.FactualAugment(m=4, factor_num=factor_num)
    CA = causalaugment.MultiCounterfactualAugment(factor_num, stride)

    # One mapping network per factor.
    AdaptNet = []
    for i in range(factor_num):
        saved_weight = _load_checkpoint(svroot, epoch, 'mapping_%d' % i)
        mapping = adaptor_v2.mapping(1024, 512, 1024, 2).cuda()
        mapping.load_state_dict(saved_weight)
        AdaptNet.append(mapping)

    saved_weight = _load_checkpoint(svroot, epoch, 'E_to_W')
    E_to_W = adaptor_v2.effect_to_weight(10, 100, 1).cuda()
    E_to_W.load_state_dict(saved_weight)

    str2fun = {
        'mnist': data_loader.load_mnist,
        'mnist_m': data_loader.load_mnist_m,
        'usps': data_loader.load_usps,
        'svhn': data_loader.load_svhn,
        'syndigit': data_loader.load_syndigit,
    }
    source_domain = 'mnist'
    target = ['svhn', 'mnist_m', 'syndigit', 'usps']
    columns = [source_domain] + target

    if eval_mapping:
        # Build a new list: appending to FA.factor_list in place would also
        # change the FA object that evaluate_mapping receives below.
        index = list(FA.factor_list) + ['w/o do (original x)']
    else:
        index = ['w/o do (original x)']
    index_ours = ['do']

    cls_net.eval()
    # NOTE(review): previously only cls_net was switched to eval mode. The
    # loaded mapping / E_to_W modules are evaluated too, so put them in eval
    # mode as well (a no-op if they have no dropout/batch-norm layers).
    for mapping in AdaptNet:
        mapping.eval()
    E_to_W.eval()

    data_result = {}
    data_result_ours = {}
    for data in columns:
        teset = str2fun[data]('test', channels=channels)
        teloader = DataLoader(teset, batch_size=8, num_workers=0)

        acc_CA = evaluate_causal(cls_net, teloader, CA, AdaptNet, E_to_W)
        data_result_ours[data] = acc_CA

        if eval_mapping:
            teacc_FA_aftermapping, acc_FA = evaluate_mapping(
                cls_net, teloader, FA, AdaptNet, source=(data == source_domain))
            data_result[data] = teacc_FA_aftermapping
            data_result[data + '_FA'] = acc_FA
        else:
            data_result[data] = evaluate(cls_net, teloader)

    # 'Avg' covers the target domains only; the source domain is excluded.
    data_result['Avg'] = _mean_over(data_result, target)
    data_result_ours['Avg'] = _mean_over(data_result_ours, target)

    df = pd.DataFrame(data_result, index=index)
    df_ours = pd.DataFrame(data_result_ours, index=index_ours)
    print(df)
    print(df_ours)
    if svpath is not None:
        df.to_csv(svpath)
        df_ours.to_csv(svpath, mode='a')
    return df, df_ours
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: the click command parses sys.argv itself.
    main()
|
|
|
|
|
|