Continual-Mega committed on
Commit
e5ffdfc
·
verified ·
1 Parent(s): b24ca95

Upload eval_continual.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. eval_continual.py +281 -0
eval_continual.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ from torch.nn import functional as F
7
+ from tqdm import tqdm
8
+ from CLIP.clip import create_model
9
+ from CLIP.adapter import CLIPAD
10
+ from sklearn.metrics import roc_auc_score, average_precision_score
11
+ from dataset.continual import ImageDataset
12
+ import csv
13
+ import logging
14
+ from CoOp import PromptMaker
15
+ import json
16
+ from safetensors.torch import load_file
17
+
18
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
19
+
20
+ import warnings
21
+ warnings.filterwarnings("ignore")
22
+
23
def setup_seed(seed):
    """Seed every relevant RNG (hash, python, numpy, torch CPU/CUDA) for reproducibility."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
29
+
30
def get_logger(output_dir):
    """Configure and return the root logger, writing to ``<output_dir>/log.log``
    and echoing to the console.

    Args:
        output_dir: existing directory in which the log file is created.

    Returns:
        The root ``logging.Logger`` set to level INFO.
    """
    # set log file
    log_file = f"{output_dir}/log.log"
    head = '%(asctime)-15s %(message)s'
    # basicConfig attaches the FileHandler only on the first call
    # (it is a no-op once the root logger has handlers).
    logging.basicConfig(filename=log_file,
                        format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Echo to the console. Guard against stacking a duplicate StreamHandler
    # on repeated calls, which doubled every console line in the original.
    # (FileHandler subclasses StreamHandler, hence the extra isinstance check.)
    has_console = any(
        isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )
    if not has_console:
        logger.addHandler(logging.StreamHandler())

    return logger
42
+
43
def main():
    """Evaluate continual anomaly-detection adapters on the current task and all
    previous tasks.

    Loads a frozen CLIP backbone, a learned prompt maker, and the adapter
    checkpoint for ``--task_id``, then reports per-class image-level AUROC and
    pixel-level AP for every task seen so far, accumulating the averages into
    two CSV matrices (rows = model state after task t, cols = task evaluated).
    """
    parser = argparse.ArgumentParser(description='Evaluation')
    parser.add_argument('--model_name', type=str, default='ViT-L-14-336', help="ViT-B-16-plus-240, ViT-L-14-336")
    parser.add_argument('--pretrain', type=str, default='openai', help="laion400m, openai")
    parser.add_argument('--img_size', type=int, default=336)
    parser.add_argument("--features_list", type=int, nargs="+", default=[6, 12, 18, 24], help="features used")
    parser.add_argument('--seed', type=int, default=111)
    parser.add_argument('--gpu', type=str, default="0")
    parser.add_argument("--meta_file", type=str, default="meta_files/scenario1_5classes_tasks.json")
    parser.add_argument("--base_meta_file", type=str, default="meta_files/scenario1_base.json")
    parser.add_argument("--num_tasks", type=int, default=12, help="number of tasks")
    parser.add_argument("--n_learnable_token", type=int, default=8, help="number of learnable token")
    parser.add_argument("--checkpoints", type=str, default="scenario2/30classes", help="folder path to checkpoints")
    parser.add_argument("--checkpoint_base", type=str, default="scenario2/adapters_base.safetensors", help="checkpoint base path")
    # NOTE(review): flag keeps the historical "makder" typo so existing launch
    # scripts keep working; renaming it would break the CLI.
    parser.add_argument("--prompt_makder_ckpt", type=str, default="scenario2/prompt_maker.safetensors", help="prompt maker checkpoint path")
    parser.add_argument("--task_id", type=int, default=1, help="test task id") # 0 - base classes
    parser.add_argument("--save_path", type=str, default="results")
    parser.add_argument("--data_root", type=str, default="data")

    args = parser.parse_args()

    setup_seed(args.seed)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:{}".format(args.gpu) if use_cuda else "cpu")

    save_path = args.save_path
    if not os.path.isdir(save_path):
        os.makedirs(save_path)

    # for logging
    logger = get_logger(save_path)
    logger.info(args)

    # fixed feature extractor
    clip_model = create_model(model_name=args.model_name, img_size=args.img_size, device=device, pretrained=args.pretrain, require_pretrained=True)

    # Hand-written normal/abnormal prompt ensembles that initialize the
    # learnable prompt maker. Keep text byte-identical: the learned prompt
    # checkpoints were trained against these exact strings.
    prompts = {
        "normal": [
            "This is an example of a normal object",
            "This is a typical appearance of the object",
            "This is what a normal object looks like",
            "A photo of a normal object",
            "This is not an anomaly",
            "This is an example of a standard object.",
            "This is the standard appearance of the object.",
            "This is what a standard object looks like.",
            "A photo of a standard object.",
            "This object meets standard characteristics."
        ],
        "abnormal": [
            "This is an example of an anomalous object",
            "This is not the typical appearance of the object",
            "This is what an anomaly looks like",
            "A photo of an anomalous object",
            "An anomaly detected in this object",
            "This is an example of an abnormal object.",
            "This is not the usual appearance of the object.",
            "This is what an abnormal object looks like.",
            "A photo of an abnormal object.",
            "An abnormality detected in this object."
        ]
    }

    clip_model.device = device
    clip_model.to(device)

    prompt_maker = PromptMaker(
        prompts=prompts,
        clip_model=clip_model,
        n_ctx=args.n_learnable_token,
        CSC=True,
        class_token_position=['end'],
    ).to(device)

    model = CLIPAD(clip_model=clip_model, features=args.features_list)
    model.to(device)
    model.eval()

    # load checkpoint: task 0 is the base-classes adapter, otherwise the
    # adapter trained through the requested task.
    if args.task_id == 0:
        checkpoint_path = args.checkpoint_base
    else:
        checkpoint_path = f"{args.checkpoints}/adapters_task{args.task_id}.safetensors"

    checkpoint = load_file(checkpoint_path)
    model.adapters.load_state_dict(checkpoint)
    logger.info(f"load adapter from {checkpoint_path}")
    prompt_state_dict = load_file(args.prompt_makder_ckpt)
    prompt_maker.prompt_learner.load_state_dict(prompt_state_dict)
    logger.info(f"load prompt maker from {args.prompt_makder_ckpt}")

    kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}

    # Result matrices: row = model state (base + each task), col = task evaluated.
    # BUG FIX: np.full(..., 0) created *integer* arrays, so every AUROC/AP score
    # (all < 1.0) stored below was silently truncated to 0, and re-loading the
    # string rows from a previous CSV into an int array failed. Use floats.
    num_tasks = args.num_tasks + 1
    results_image = np.full((num_tasks, num_tasks), 0.0)  # save for csv image-level
    results_pixel = np.full((num_tasks, num_tasks), 0.0)  # save for csv pixel-level

    # load saved_results
    csv_image = f"{save_path}/results_image.csv"
    csv_pixel = f"{save_path}/results_pixel.csv"

    if os.path.exists(csv_image):
        with open(csv_image, mode="r") as file:
            reader = csv.reader(file)
            for i, row in enumerate(reader):
                if not i == 0:  # row 0 is the header
                    results_image[i-1] = [float(v) for v in row]
        logger.info(f"load previous results from {csv_image}")

    if os.path.exists(csv_pixel):
        with open(csv_pixel, mode="r") as file:
            reader = csv.reader(file)
            for i, row in enumerate(reader):
                if not i == 0:  # row 0 is the header
                    results_pixel[i-1] = [float(v) for v in row]
        logger.info(f"load previous results from {csv_pixel}")

    prompt_maker.eval()
    model.eval()

    task_all_meta_info = json.load(open(args.meta_file, 'r'))

    # test all previous tasks
    for i in range(args.task_id + 1):
        if i == 0: # base classes
            task_meta = json.load(open(args.base_meta_file, 'r'))
            logger.info(f"start base task test")
        else:
            task_meta = task_all_meta_info[f"task_{i}"]
            logger.info(f"start task_{i} test")

        class_name_list = list(task_meta["test"].keys())
        test_dataset_list = [ImageDataset(data_root=args.data_root, meta_file=task_meta, resize=args.img_size, mode="test", test_class=class_name) for class_name in class_name_list]
        test_loader_list = [torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, **kwargs) for test_dataset in test_dataset_list]

        with torch.cuda.amp.autocast(), torch.no_grad():
            # test all class
            seg_ap_list = []
            img_auc_list = []
            prompt_maker.eval()
            model.eval()
            text_features = prompt_maker()

            for test_loader, class_name in zip(test_loader_list, class_name_list):
                logger.info(f"start test {class_name}")
                roc_auc_im, seg_ap = test(args, model, test_loader, text_features, device)
                logger.info(f'{class_name} P-AP : {round(seg_ap,4)}')
                logger.info(f'{class_name} I-AUC : {round(roc_auc_im, 4)}')
                seg_ap_list.append(seg_ap)
                img_auc_list.append(roc_auc_im)

            seg_ap_mean = np.mean(seg_ap_list)
            img_auc_mean = np.mean(img_auc_list)

            logger.info(f'Average P-AP : {round(seg_ap_mean,4)}')
            logger.info(f'Average I-AUC : {round(img_auc_mean, 4)}')

            # save results csv (i task)
            seg_ap_mean = round(seg_ap_mean,4)
            img_auc_mean = round(img_auc_mean, 4)
            results_image[args.task_id, i] = img_auc_mean
            results_pixel[args.task_id, i] = seg_ap_mean

        logger.info(f"save results csv task {i}")

        # save results csv (rewritten whole after every task so a crash loses
        # at most the task in flight)
        with open(csv_image, mode="w", newline="") as file:
            writer = csv.writer(file)
            writer.writerow(["Base"] + ["Task " + str(i + 1) for i in range(num_tasks-1)])
            for row in results_image:
                writer.writerow(row)
        with open(csv_pixel, mode="w", newline="") as file:
            writer = csv.writer(file)
            writer.writerow(["Base"] + ["Task " + str(i + 1) for i in range(num_tasks-1)])
            for row in results_pixel:
                writer.writerow(row)
222
+
223
def test(args, model, test_loader, text_features, device):
    """Run anomaly-detection inference over one class's test loader.

    Args:
        args: parsed CLI namespace; only ``args.img_size`` is read here.
        model: adapter-wrapped CLIP model; ``model(image)`` returns
            ``(_, ada_patch_tokens)`` — a list of per-layer patch-token tensors.
        test_loader: DataLoader yielding dicts with keys 'image', 'mask',
            'cls_name', 'anomaly' (batch_size is assumed to be 1: see the
            ``p[0, 1:, :]`` and ``squeeze(0)`` below — TODO confirm at caller).
        text_features: normal/abnormal text embeddings from the prompt maker,
            matmul-compatible with the patch tokens (2 output classes).
        device: torch device to run inference on.

    Returns:
        (roc_auc_im, seg_pr): image-level ROC-AUC and pixel-level average
        precision over the whole loader.
    """
    gt_list = []        # per-image ground-truth anomaly labels
    gt_mask_list = []   # per-image ground-truth segmentation masks

    seg_score_map_zero = []  # per-image pixel anomaly score maps
    image_scores = []        # per-image anomaly scores
    for data in tqdm(test_loader):
        image, mask, cls_name, label = data['image'], data['mask'], data['cls_name'], data['anomaly']
        image = image.to(device)
        # Binarize the mask in place. Order matters: values set to 1 by the
        # first assignment are > 0.5, so the second assignment leaves them alone.
        mask[mask > 0.5], mask[mask <= 0.5] = 1, 0

        with torch.no_grad(), torch.cuda.amp.autocast():
            _, ada_patch_tokens = model(image)
            # Take batch element 0 and drop the CLS token, keeping patch tokens only.
            ada_patch_tokens = [p[0, 1:, :] for p in ada_patch_tokens]

            anomaly_maps = []
            image_score = 0
            for layer in range(len(ada_patch_tokens)):
                # L2-normalize patch tokens (in place), then score each patch
                # against the normal/abnormal text features (CLIP-style, scaled by 100).
                ada_patch_tokens[layer] /= ada_patch_tokens[layer].norm(dim=-1, keepdim=True)
                anomaly_map = (100.0 * ada_patch_tokens[layer] @ text_features).unsqueeze(0)
                B, L, C = anomaly_map.shape
                # Patch grid is assumed square: H*H == L — TODO confirm upstream.
                H = int(np.sqrt(L))

                # image-level score: max abnormal-class probability over patches
                anomaly_score = torch.softmax(anomaly_map, dim=-1)[:, :, 1]
                image_score += anomaly_score.max()

                anomaly_maps.append(anomaly_map)

            # Average the per-layer maps, reshape to the patch grid, upsample to
            # the input resolution, and keep the abnormal-class probability plane.
            score_map = torch.mean(torch.stack(anomaly_maps, dim=1), dim=1)
            score_map = F.interpolate(score_map.permute(0, 2, 1).view(B, 2, H, H),
                        size=args.img_size, mode='bilinear', align_corners=True)
            score_map = torch.softmax(score_map, dim=1)[:, 1, :, :]
            score_map = score_map.squeeze(0).cpu().numpy()
            seg_score_map_zero.append(score_map)
            # Average the per-layer max scores into a single image score.
            image_scores.append(image_score.cpu() / len(ada_patch_tokens))

            gt_mask_list.append(mask.squeeze().cpu().detach().numpy())
            gt_list.extend(label.cpu().detach().numpy())


    gt_list = np.array(gt_list)
    gt_mask_list = np.asarray(gt_mask_list)
    gt_mask_list = (gt_mask_list>0).astype(np.int_)

    segment_scores = np.array(seg_score_map_zero)
    image_scores = np.array(image_scores)

    # Image-level ROC-AUC over all images in the loader.
    roc_auc_im = roc_auc_score(gt_list, image_scores)

    # Pixel-level average precision over all pixels of all images.
    seg_pr = average_precision_score(gt_mask_list.flatten(), segment_scores.flatten())

    return roc_auc_im, seg_pr
276
+
277
+
278
# Script entry point: run the continual evaluation CLI.
if __name__ == '__main__':
    main()
280
+
281
+