File size: 5,950 Bytes
197d4ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import os
import numpy as np
import click
import pandas as pd

from network import resnet as resnet
from network import adaptor_v2
from tools import causalaugment_v3 as causalaugment
from main_my_joint_v13_auto import evaluate,evaluate_causal,evaluate_causal_with_entropy,evaluate_mapping,evaluate_causal_with_average
import data_loader_joint_v3 as data_loader

@click.command()
@click.option('--gpu', type=str, default='0', help='选择GPU编号')
@click.option('--svroot', type=str, default='./saved', help='directory containing the saved checkpoints')
@click.option('--source_domain', type=click.Choice(['art_painting', 'cartoon', 'photo', 'sketch']),
              default='art_painting', help='source domain')
@click.option('--svpath', type=str, default=None, help='保存日志的路径')
@click.option('--factor_num', type=int, default=16)
@click.option('--epoch', type=click.Choice(['best', 'last']), default='best',
              help='which checkpoint to load (best_*.pkl or last_*.pkl)')
@click.option('--stride', type=int, default=5)
@click.option('--eval_mapping', type=bool, default=False, help='是否查看mapping学习效果')
@click.option('--network', type=click.Choice(['resnet18']), default='resnet18',
              help='backbone network of the classifier')
def main(gpu, svroot, source_domain, svpath, factor_num, epoch, stride, eval_mapping, network):
    """Evaluate saved checkpoints on every PACS domain (CLI entry point)."""
    # click.Choice rejects unsupported values at parse time instead of letting
    # evaluate_pacs fail later with an undefined-variable error.
    evaluate_pacs(
        gpu=gpu,
        svroot=svroot,
        source_domain=source_domain,
        svpath=svpath,
        factor_num=factor_num,
        epoch=epoch,
        stride=stride,
        eval_mapping=eval_mapping,
        network=network,
    )
    
def _load_weight(svroot, epoch, name):
    """Load the checkpoint file ``<epoch>_<name>.pkl`` from *svroot*.

    Args:
        svroot: directory holding the saved checkpoints.
        epoch: checkpoint tag, ``'best'`` or ``'last'`` (validated by the caller).
        name: checkpoint stem, e.g. ``'cls_net'``, ``'mapping_3'`` or ``'E_to_W'``.

    Returns:
        The object deserialized by ``torch.load`` (used here as a state dict).
    """
    print("loading weight of %s" % (epoch))
    return torch.load(os.path.join(svroot, '%s_%s.pkl' % (epoch, name)))


def evaluate_pacs(gpu, svroot, source_domain, svpath, factor_num=16, epoch='best', stride=5, eval_mapping=False, network='resnet18'):
    """Evaluate a trained classifier plus causal adaptors on all four PACS domains.

    The source domain is evaluated first; the remaining three domains are the
    targets, and their results are averaged into an ``'Avg'`` column.

    Args:
        gpu: value written to ``CUDA_VISIBLE_DEVICES``.
        svroot: directory containing the ``best_*``/``last_*`` checkpoints.
        source_domain: one of ``'art_painting'``, ``'cartoon'``, ``'photo'``, ``'sketch'``.
        svpath: optional CSV path; the baseline table is written, then the
            causal ("do") table is appended to the same file.
        factor_num: number of mapping networks to load (one per factor); also
            passed to the factual/counterfactual augmenters.
        epoch: ``'best'`` or ``'last'`` — which checkpoint set to load.
        stride: forwarded to ``MultiCounterfactualAugment``.
        eval_mapping: if True, also report per-factor mapping results via
            ``evaluate_mapping``; otherwise report plain ``evaluate`` accuracy.
        network: classifier backbone; only ``'resnet18'`` is supported.

    Returns:
        ``(df, df_ours)``: the baseline/mapping results table and the causal
        ("do") results table.

    Raises:
        ValueError: if ``source_domain``, ``epoch`` or ``network`` is unsupported.
    """
    settings = locals().copy()
    print(settings)

    domains = ['art_painting', 'cartoon', 'photo', 'sketch']
    if source_domain not in domains:
        raise ValueError("source_domain must be one of %s, got %r" % (domains, source_domain))
    if epoch not in ('best', 'last'):
        raise ValueError("epoch must be 'best' or 'last', got %r" % (epoch,))
    if network != 'resnet18':
        raise ValueError("unsupported network %r (only 'resnet18' is available)" % (network,))

    os.environ['CUDA_VISIBLE_DEVICES'] = gpu

    # Load the classification model.
    cls_net = resnet.resnet18(classes=7, c_dim=2048).cuda()
    input_dim = 2048
    cls_net.load_state_dict(_load_weight(svroot, epoch, 'cls_net'))
    cls_net.eval()

    # Load the adaptation models: one mapping network per factor plus E_to_W.
    FA = causalaugment.FactualAugment(m=4, factor_num=factor_num)
    CA = causalaugment.MultiCounterfactualAugment(factor_num, stride)
    AdaptNet = []
    for i in range(factor_num):
        mapping = adaptor_v2.mapping(input_dim, 1024, input_dim, 4).cuda()
        mapping.load_state_dict(_load_weight(svroot, epoch, 'mapping_%d' % i))
        # Evaluation only: disable any dropout / batch-norm updates, as for cls_net.
        mapping.eval()
        AdaptNet.append(mapping)
    E_to_W = adaptor_v2.effect_to_weight(7, 70, 1).cuda()
    E_to_W.load_state_dict(_load_weight(svroot, epoch, 'E_to_W'))
    E_to_W.eval()

    # Evaluation: source domain first, then the three target domains.
    target = [d for d in domains if d != source_domain]
    columns = [source_domain] + target
    print("columns:", columns)
    if eval_mapping:
        # Copy instead of aliasing FA.factor_list, so building the row labels
        # does not mutate the augmenter's own factor list.
        index = list(FA.factor_list) + ['w/o do (original x)']
    else:
        index = ['w/o do (original x)']
    index_ours = ['do']
    data_result = {}
    data_result_ours = {}

    for data in columns:
        teset = data_loader.load_pacs(data, 'test')
        teloader = DataLoader(teset, batch_size=4, num_workers=0)
        is_source = data == source_domain
        # Compute the evaluation metrics.
        acc_CA = evaluate_causal(cls_net, teloader, CA, AdaptNet, E_to_W)
        data_result_ours[data] = acc_CA
        if eval_mapping:
            # The last entry of the result corresponds to the original (un-intervened) data.
            teacc, acc_FA = evaluate_mapping(cls_net, teloader, FA, AdaptNet, source=is_source)
            data_result[data] = teacc
            data_result[data + '_FA'] = acc_FA
        else:
            teacc = evaluate(cls_net, teloader)
            data_result[data] = teacc
        if is_source:
            # The source domain comes first and is excluded from the target average.
            acc_avg = np.zeros(teacc.shape)
            acc_avg_CA = np.zeros(acc_CA.shape)
        else:
            acc_avg = acc_avg + teacc
            acc_avg_CA = acc_avg_CA + acc_CA

    data_result['Avg'] = acc_avg / float(len(target))
    data_result_ours['Avg'] = acc_avg_CA / float(len(target))

    df = pd.DataFrame(data_result, index=index)
    df_ours = pd.DataFrame(data_result_ours, index=index_ours)
    print(df)
    print(df_ours)
    if svpath is not None:
        df.to_csv(svpath)
        df_ours.to_csv(svpath, mode='a')
    return df, df_ours
if __name__ == '__main__':
    # Script entry point: click parses the command-line options and runs the evaluation.
    main()