File size: 5,950 Bytes
197d4ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os
import numpy as np
import click
import pandas as pd
from network import resnet as resnet
from network import adaptor_v2
from tools import causalaugment_v3 as causalaugment
from main_my_joint_v13_auto import evaluate,evaluate_causal,evaluate_causal_with_entropy,evaluate_mapping,evaluate_causal_with_average
import data_loader_joint_v3 as data_loader
@click.command()
@click.option('--gpu', type=str, default='0', help='选择GPU编号')
@click.option('--svroot', type=str, default='./saved')
@click.option('--source_domain', type=str, default='art_painting', help='source domain')
@click.option('--svpath', type=str, default=None, help='保存日志的路径')
@click.option('--factor_num', type=int, default=16)
@click.option('--epoch', type=str, default='best')
@click.option('--stride', type=int, default=5)
@click.option('--eval_mapping', type=bool, default=False, help='是否查看mapping学习效果')
@click.option('--network', type=str, default='resnet18', help='主干网络结构')
def main(gpu, svroot, source_domain, svpath, factor_num, epoch, stride, eval_mapping, network):
    """CLI entry point: forwards every option unchanged to evaluate_pacs.

    NOTE: the --network help text previously read '项目文件保存路径' (a save
    path description), which was a copy-paste error; it now describes the
    backbone choice.
    """
    evaluate_pacs(gpu, svroot, source_domain, svpath, factor_num, epoch, stride, eval_mapping, network)
def evaluate_pacs(gpu, svroot, source_domain, svpath, factor_num=16, epoch='best', stride=5, eval_mapping=False, network='resnet18'):
    """Evaluate a trained PACS model (classifier + causal adaptation modules).

    Loads the saved classifier, the ``factor_num`` per-factor mapping
    networks and the effect-to-weight module from ``svroot``, then evaluates
    on every PACS domain (source domain first).  Results are printed as
    pandas DataFrames and, if ``svpath`` is given, written to CSV (the
    counterfactual "do" results are appended below the plain results).

    Args:
        gpu: value assigned to CUDA_VISIBLE_DEVICES.
        svroot: directory containing the saved ``*.pkl`` checkpoints.
        source_domain: one of 'art_painting', 'cartoon', 'photo', 'sketch'.
        svpath: CSV output path, or None to skip saving.
        factor_num: number of augmentation factors / mapping networks.
        epoch: which checkpoint to load, 'best' or 'last'.
        stride: stride passed to MultiCounterfactualAugment.
        eval_mapping: if True, also evaluate per-factor mapping quality.
        network: backbone name; only 'resnet18' is supported.

    Raises:
        ValueError: if ``network`` or ``epoch`` has an unsupported value
            (previously these cases crashed later with a NameError).
    """
    settings = locals().copy()  # snapshot of call arguments, for logging
    print(settings)
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu

    if epoch not in ('best', 'last'):
        # Previously an unknown epoch left `saved_weight` undefined -> NameError.
        raise ValueError("epoch must be 'best' or 'last', got %r" % (epoch,))

    def _load_weight(name):
        # Load '<epoch>_<name>.pkl' from svroot (e.g. 'best_cls_net.pkl').
        print("loading weight of %s" % (epoch))
        return torch.load(os.path.join(svroot, '%s_%s.pkl' % (epoch, name)))

    # Load the classification backbone.
    if network == 'resnet18':
        cls_net = resnet.resnet18(classes=7, c_dim=2048).cuda()
        input_dim = 2048
    else:
        # Previously this fell through and crashed later with a NameError.
        raise ValueError("unsupported network: %r" % (network,))
    cls_net.load_state_dict(_load_weight('cls_net'))
    cls_net.eval()

    # Load the adaptation modules: factual / counterfactual augmentation
    # plus one mapping network per factor.
    FA = causalaugment.FactualAugment(m=4, factor_num=factor_num)
    CA = causalaugment.MultiCounterfactualAugment(factor_num, stride)
    AdaptNet = []
    for i in range(factor_num):
        mapping = adaptor_v2.mapping(input_dim, 1024, input_dim, 4).cuda()
        mapping.load_state_dict(_load_weight('mapping_' + str(i)))
        AdaptNet.append(mapping)

    E_to_W = adaptor_v2.effect_to_weight(7, 70, 1).cuda()
    E_to_W.load_state_dict(_load_weight('E_to_W'))

    # Evaluation: source domain first, then the three target domains.
    columns = ['art_painting', 'cartoon', 'photo', 'sketch']
    target = [i for i in columns if i != source_domain]
    columns = [source_domain] + target
    print("columns:", columns)
    if eval_mapping:
        # Copy before extending: the original appended directly to
        # FA.factor_list, mutating FA's own state.
        index = list(FA.factor_list) + ['w/o do (original x)']
    else:
        index = ['w/o do (original x)']
    index_ours = ['do']
    data_result = {}
    data_result_ours = {}
    for idx, data in enumerate(columns):
        teset = data_loader.load_pacs(data, 'test')
        teloader = DataLoader(teset, batch_size=4, num_workers=0)
        # Accuracy under the counterfactual "do" intervention.
        acc_CA = evaluate_causal(cls_net, teloader, CA, AdaptNet, E_to_W)
        data_result_ours[data] = acc_CA
        # The last dimension corresponds to the original (unaugmented) data.
        if eval_mapping:
            if data == source_domain:
                teacc_FA_aftermapping, acc_FA = evaluate_mapping(cls_net, teloader, FA, AdaptNet, source=True)
                # Source domain is always first, so the accumulators are
                # (re)initialized here with the right shapes.
                acc_avg = np.zeros(teacc_FA_aftermapping.shape)
                acc_avg_CA = np.zeros(acc_CA.shape)
            else:
                teacc_FA_aftermapping, acc_FA = evaluate_mapping(cls_net, teloader, FA, AdaptNet, source=False)
                acc_avg = acc_avg + teacc_FA_aftermapping
                acc_avg_CA = acc_avg_CA + acc_CA
            data_result[data] = teacc_FA_aftermapping
            data_result[data + '_FA'] = acc_FA
        else:
            teacc = evaluate(cls_net, teloader)
            if data == source_domain:
                acc_avg = np.zeros(teacc.shape)
                acc_avg_CA = np.zeros(acc_CA.shape)
            else:
                acc_avg = acc_avg + teacc
                acc_avg_CA = acc_avg_CA + acc_CA
            data_result[data] = teacc

    # Averages are over the target domains only (source is excluded).
    acc_avg = acc_avg / float(len(target))
    acc_avg_CA = acc_avg_CA / float(len(target))
    data_result['Avg'] = acc_avg
    data_result_ours['Avg'] = acc_avg_CA
    df = pd.DataFrame(data_result, index=index)
    df_ours = pd.DataFrame(data_result_ours, index=index_ours)
    print(df)
    print(df_ours)
    if svpath is not None:
        df.to_csv(svpath)
        df_ours.to_csv(svpath, mode='a')  # append "do" results below
# Script entry point: dispatch to the click command-line interface.
if __name__ == '__main__':
    main()
|