# Source: CausalStyleAdv / Meta-causal / code / main_test_pacs_v13.py
# Uploaded by YuqianFu via huggingface_hub (commit 197d4ca, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os
import numpy as np
import click
import pandas as pd
from network import resnet as resnet
from network import adaptor_v2
from tools import causalaugment_v3 as causalaugment
from main_my_joint_v13_auto import evaluate,evaluate_causal,evaluate_causal_with_entropy,evaluate_mapping,evaluate_causal_with_average
import data_loader_joint_v3 as data_loader
@click.command()
# --gpu: GPU id(s) to use ("选择GPU编号" = "select GPU id").
@click.option('--gpu', type=str, default='0', help='选择GPU编号')
# --svroot: directory holding the saved checkpoints.
@click.option('--svroot', type=str, default='./saved')
# Only the four PACS domains are valid; anything else used to fail deep inside evaluation.
@click.option('--source_domain', type=click.Choice(['art_painting', 'cartoon', 'photo', 'sketch']),
              default='art_painting', help='source domain')
# --svpath: CSV output path ("保存日志的路径" = "path to save the log").
@click.option('--svpath', type=str, default=None, help='保存日志的路径')
@click.option('--factor_num', type=int, default=16)
# Only 'best' and 'last' checkpoints exist; other values used to crash with a NameError.
@click.option('--epoch', type=click.Choice(['best', 'last']), default='best')
@click.option('--stride', type=int, default=5)
# --eval_mapping: "是否查看mapping学习效果" = "whether to inspect how well the mappings were learned".
@click.option('--eval_mapping', type=bool, default=False, help='是否查看mapping学习效果')
# The previous help text here ("project file save path") was wrong; this selects the backbone.
@click.option('--network', type=click.Choice(['resnet18']), default='resnet18',
              help='classification backbone network')
def main(gpu, svroot, source_domain, svpath, factor_num, epoch, stride, eval_mapping, network):
    """Evaluate trained PACS checkpoints from the command line (thin wrapper over evaluate_pacs)."""
    evaluate_pacs(gpu, svroot, source_domain, svpath, factor_num, epoch, stride, eval_mapping, network)
def _weight_path(svroot, epoch, name):
    """Return the checkpoint path ``<svroot>/<epoch>_<name>.pkl``."""
    return os.path.join(svroot, '%s_%s.pkl' % (epoch, name))


def _load_weight(svroot, epoch, name):
    """Load the ``epoch`` ('best' or 'last') checkpoint called ``name`` from ``svroot``."""
    print("loading weight of %s" % (epoch))
    return torch.load(_weight_path(svroot, epoch, name))


def evaluate_pacs(gpu, svroot, source_domain, svpath, factor_num=16, epoch='best', stride=5, eval_mapping=False, network='resnet18'):
    """Evaluate saved PACS checkpoints on the source domain and the three target domains.

    Loads ``<epoch>_cls_net.pkl``, ``<epoch>_mapping_<i>.pkl`` (for i in
    ``range(factor_num)``) and ``<epoch>_E_to_W.pkl`` from ``svroot``, then runs
    the baseline evaluation (``evaluate`` or, with ``eval_mapping``,
    ``evaluate_mapping``) and the causal evaluation (``evaluate_causal``) on each
    domain.

    Args:
        gpu: value exported as ``CUDA_VISIBLE_DEVICES``.
        svroot: directory containing the checkpoints.
        source_domain: PACS domain used for training; reported first and
            excluded from the ``Avg`` column.
        svpath: if not None, the baseline table is written to this CSV path and
            the causal table is appended to the same file.
        factor_num: number of augmentation factors / mapping networks.
        epoch: which checkpoint set to load, ``'best'`` or ``'last'``.
        stride: forwarded to ``MultiCounterfactualAugment``.
        eval_mapping: also evaluate each factual augmentation with its mapping.
        network: backbone name; only ``'resnet18'`` is supported.

    Returns:
        ``(df, df_ours)``: the baseline table and the causal ('do') table.

    Raises:
        ValueError: for an unsupported ``network``, ``epoch`` or ``source_domain``.
    """
    settings = locals().copy()
    print(settings)

    domains = ['art_painting', 'cartoon', 'photo', 'sketch']
    # Fail fast: previously these cases surfaced later as an opaque NameError
    # (cls_net / saved_weight never assigned) or as a failed dataset load.
    if network != 'resnet18':
        raise ValueError("unsupported network %r; only 'resnet18' is supported" % (network,))
    if epoch not in ('best', 'last'):
        raise ValueError("epoch must be 'best' or 'last', got %r" % (epoch,))
    if source_domain not in domains:
        raise ValueError("unknown source_domain %r; expected one of %s" % (source_domain, domains))

    os.environ['CUDA_VISIBLE_DEVICES'] = gpu

    # Load the classification network.
    cls_net = resnet.resnet18(classes=7, c_dim=2048).cuda()
    input_dim = 2048
    cls_net.load_state_dict(_load_weight(svroot, epoch, 'cls_net'))
    cls_net.eval()

    # Load the adaptation modules: one mapping per factor plus the effect-to-weight net.
    FA = causalaugment.FactualAugment(m=4, factor_num=factor_num)
    CA = causalaugment.MultiCounterfactualAugment(factor_num, stride)
    AdaptNet = []
    for i in range(factor_num):
        mapping = adaptor_v2.mapping(input_dim, 1024, input_dim, 4).cuda()
        mapping.load_state_dict(_load_weight(svroot, epoch, 'mapping_%d' % i))
        # Test-time: switch to inference mode like cls_net (matters if the
        # module contains dropout/batch-norm layers).
        mapping.eval()
        AdaptNet.append(mapping)
    E_to_W = adaptor_v2.effect_to_weight(7, 70, 1).cuda()
    E_to_W.load_state_dict(_load_weight(svroot, epoch, 'E_to_W'))
    E_to_W.eval()

    # Evaluate the source domain first, then the held-out target domains.
    target = [d for d in domains if d != source_domain]
    columns = [source_domain] + target
    print("columns:", columns)

    if eval_mapping:
        # Build a new list: appending to FA.factor_list directly would mutate
        # the augmenter that is passed to evaluate_mapping below.
        index = list(FA.factor_list) + ['w/o do (original x)']
    else:
        index = ['w/o do (original x)']
    index_ours = ['do']

    data_result = {}
    data_result_ours = {}
    for data in columns:
        is_source = data == source_domain
        teset = data_loader.load_pacs(data, 'test')
        teloader = DataLoader(teset, batch_size=4, num_workers=0)
        acc_CA = evaluate_causal(cls_net, teloader, CA, AdaptNet, E_to_W)
        data_result_ours[data] = acc_CA
        # The last entry of each baseline result corresponds to the original
        # (un-intervened) data, matching the last row of ``index``.
        if eval_mapping:
            teacc, acc_FA = evaluate_mapping(cls_net, teloader, FA, AdaptNet, source=is_source)
            data_result[data] = teacc
            data_result[data + '_FA'] = acc_FA
        else:
            teacc = evaluate(cls_net, teloader)
            data_result[data] = teacc
        if is_source:
            # The source domain is reported but not included in the target average.
            acc_avg = np.zeros(teacc.shape)
            acc_avg_CA = np.zeros(acc_CA.shape)
        else:
            acc_avg = acc_avg + teacc
            acc_avg_CA = acc_avg_CA + acc_CA

    data_result['Avg'] = acc_avg / float(len(target))
    data_result_ours['Avg'] = acc_avg_CA / float(len(target))
    df = pd.DataFrame(data_result, index=index)
    df_ours = pd.DataFrame(data_result_ours, index=index_ours)
    print(df)
    print(df_ours)
    if svpath is not None:
        df.to_csv(svpath)
        df_ours.to_csv(svpath, mode='a')
    return df, df_ours
if __name__ == '__main__':
    # Script entry point: hand control to the click command-line interface.
    main()