| | import os |
| | import argparse |
| | import random |
| | import numpy as np |
| | import torch |
| | from torch.nn import functional as F |
| | from tqdm import tqdm |
| | from CLIP.clip import create_model |
| | from CLIP.adapter import CLIPAD |
| | from sklearn.metrics import roc_auc_score, average_precision_score |
| | from dataset.continual import ImageDataset |
| | import csv |
| | import logging |
| | from CoOp import PromptMaker |
| | import json |
| | from safetensors.torch import load_file |
| |
|
# Silence HuggingFace tokenizers fork warnings when DataLoader workers are used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import warnings
# Suppress all warnings (e.g. torch/CLIP deprecation noise) for cleaner logs.
warnings.filterwarnings("ignore")
| |
|
def setup_seed(seed):
    """Seed every RNG source (hash, Python, NumPy, PyTorch) for reproducibility.

    Args:
        seed: integer seed applied to all random number generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # No-op on CPU-only machines; seeds every visible GPU otherwise.
    torch.cuda.manual_seed_all(seed)
| |
|
def get_logger(output_dir):
    """Configure and return the root logger at INFO level.

    Writes formatted records both to ``<output_dir>/log.log`` and to the
    console. The original implementation added a fresh, format-less
    StreamHandler on every call (duplicating console lines when called
    twice); this version attaches a formatter to both handlers and is
    idempotent for a given output directory.

    Args:
        output_dir: existing directory where ``log.log`` is created.

    Returns:
        The root ``logging.Logger`` (so bare ``logging.info`` calls made
        elsewhere in the file also reach these handlers).
    """
    log_file = os.path.join(output_dir, "log.log")
    formatter = logging.Formatter('%(asctime)-15s %(message)s')

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Only attach handlers if we have not already configured this log file;
    # prevents duplicated log lines when get_logger() is called repeatedly.
    target = os.path.abspath(log_file)
    already_configured = any(
        isinstance(h, logging.FileHandler)
        and getattr(h, "baseFilename", None) == target
        for h in logger.handlers
    )
    if not already_configured:
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

        console = logging.StreamHandler()
        console.setFormatter(formatter)
        logger.addHandler(console)

    return logger
| |
|
def main():
    """Zero-shot anomaly-detection evaluation entry point.

    Builds a CLIP backbone, restores trained adapter and prompt-learner
    weights from safetensors checkpoints, then evaluates every test class
    listed in the meta file, logging per-class and average image-level
    AUROC (I-AUC) and pixel-level average precision (P-AP).
    """
    parser = argparse.ArgumentParser(description='Evaluation')
    parser.add_argument('--model_name', type=str, default='ViT-L-14-336', help="ViT-B-16-plus-240, ViT-L-14-336")
    parser.add_argument('--pretrain', type=str, default='openai', help="laion400m, openai")
    parser.add_argument('--img_size', type=int, default=336)
    parser.add_argument("--features_list", type=int, nargs="+", default=[6, 12, 18, 24], help="features used")
    parser.add_argument('--seed', type=int, default=111)
    parser.add_argument('--gpu', type=str, default="0")
    parser.add_argument("--meta_file", type=str, default="meta_files/meta_mvtec.json")
    parser.add_argument("--n_learnable_token", type=int, default=8, help="number of learnable token")
    parser.add_argument("--adapter_ckpt", type=str, default="scenario2/30classes/adapters_sc2_task2.safetensors", help="adapter checkpoint path")
    # NOTE(review): flag keeps the historical "makder" spelling; renaming it
    # would break existing invocation scripts.
    parser.add_argument("--prompt_makder_ckpt", type=str, default="scenario2/30classes/prompt_maker_sc2.safetensors", help="prompt maker checkpoint path")
    parser.add_argument("--save_path", type=str, default="results_zero")
    parser.add_argument("--data_root", type=str, default="data/mvtec_anomaly_detection")

    args = parser.parse_args()

    setup_seed(args.seed)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:{}".format(args.gpu) if use_cuda else "cpu")

    save_path = args.save_path
    os.makedirs(save_path, exist_ok=True)

    logger = get_logger(save_path)
    logger.info(args)

    clip_model = create_model(model_name=args.model_name, img_size=args.img_size, device=device, pretrained=args.pretrain, require_pretrained=True)

    # Hand-written normal/abnormal prompt ensembles; the prompt learner
    # prepends learnable context tokens to these templates.
    prompts = {
        "normal": [
            "This is an example of a normal object",
            "This is a typical appearance of the object",
            "This is what a normal object looks like",
            "A photo of a normal object",
            "This is not an anomaly",
            "This is an example of a standard object.",
            "This is the standard appearance of the object.",
            "This is what a standard object looks like.",
            "A photo of a standard object.",
            "This object meets standard characteristics."
        ],
        "abnormal": [
            "This is an example of an anomalous object",
            "This is not the typical appearance of the object",
            "This is what an anomaly looks like",
            "A photo of an anomalous object",
            "An anomaly detected in this object",
            "This is an example of an abnormal object.",
            "This is not the usual appearance of the object.",
            "This is what an abnormal object looks like.",
            "A photo of an abnormal object.",
            "An abnormality detected in this object."
        ]
    }

    clip_model.device = device
    clip_model.to(device)

    prompt_maker = PromptMaker(
        prompts=prompts,
        clip_model=clip_model,
        n_ctx=args.n_learnable_token,
        CSC=True,
        class_token_position=['end'],
    ).to(device)

    model = CLIPAD(clip_model=clip_model, features=args.features_list)
    model.to(device)
    model.eval()

    # Restore the trained adapter and prompt-learner weights.
    adapter_state_dict = load_file(args.adapter_ckpt)
    model.adapters.load_state_dict(adapter_state_dict)
    logger.info(f"load adapter from {args.adapter_ckpt}")
    prompt_state_dict = load_file(args.prompt_makder_ckpt)
    prompt_maker.prompt_learner.load_state_dict(prompt_state_dict)
    logger.info(f"load prompt maker from {args.prompt_makder_ckpt}")

    kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}

    prompt_maker.eval()
    model.eval()

    # Fix: was `logging.info(...)` -- use the configured logger consistently.
    logger.info(f"start zero shot {args.meta_file} test")
    # Fix: close the meta file instead of leaking the handle.
    with open(args.meta_file, 'r') as f:
        task_meta = json.load(f)

    # One dataset/loader per test class; batch_size=1 because test()
    # indexes the single batch element directly.
    class_name_list = list(task_meta["test"].keys())
    test_dataset_list = [ImageDataset(data_root=args.data_root, meta_file=task_meta, resize=args.img_size, mode="test", test_class=class_name) for class_name in class_name_list]
    test_loader_list = [torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, **kwargs) for test_dataset in test_dataset_list]

    with torch.cuda.amp.autocast(), torch.no_grad():
        seg_ap_list = []
        img_auc_list = []
        prompt_maker.eval()
        model.eval()
        # Text features are input-independent, so compute them once up front.
        text_features = prompt_maker()

        for test_loader, class_name in zip(test_loader_list, class_name_list):
            logger.info(f"start test {class_name}")
            roc_auc_im, seg_ap = test(args, model, test_loader, text_features, device)
            logger.info(f'{class_name} P-AP : {round(seg_ap,4)}')
            logger.info(f'{class_name} I-AUC : {round(roc_auc_im, 4)}')
            seg_ap_list.append(seg_ap)
            img_auc_list.append(roc_auc_im)

        seg_ap_mean = np.mean(seg_ap_list)
        img_auc_mean = np.mean(img_auc_list)

        logger.info(f'Average P-AP : {round(seg_ap_mean,4)}')
        logger.info(f'Average I-AUC : {round(img_auc_mean, 4)}')
|
| |
|
def test(args, model, test_loader, text_features, device):
    """Evaluate one test class; return (image-level ROC-AUC, pixel-level AP).

    Args:
        args: needs ``args.img_size`` for upsampling score maps.
        model: adapted CLIP; ``model(image)`` returns ``(_, patch_tokens)``
            where ``patch_tokens`` is a list of per-layer token tensors.
        test_loader: yields dicts with 'image', 'mask', 'cls_name', 'anomaly'.
            Assumed batch_size == 1 (code indexes ``p[0]`` below) -- confirm
            against the caller's DataLoader.
        text_features: text embedding matrix; last dimension indexes the two
            classes, with index 1 treated as "abnormal".
        device: torch device the image tensors are moved to.

    Returns:
        (roc_auc_im, seg_pr): image-level AUROC and pixel-level average
        precision over the whole loader.
    """
    gt_list = []       # image-level anomaly labels, one per image
    gt_mask_list = []  # pixel-level ground-truth masks, one per image

    seg_score_map_zero = []  # per-image pixel anomaly score maps
    image_scores = []        # per-image scalar anomaly scores
    for data in tqdm(test_loader):
        image, mask, cls_name, label = data['image'], data['mask'], data['cls_name'], data['anomaly']
        image = image.to(device)
        # Binarize the mask in place. The tuple assignment runs left to
        # right: first >0.5 pixels become 1, then the rest become 0.
        mask[mask > 0.5], mask[mask <= 0.5] = 1, 0

        with torch.no_grad(), torch.cuda.amp.autocast():
            _, ada_patch_tokens = model(image)
            # Keep batch item 0 and drop the leading CLS token per layer.
            ada_patch_tokens = [p[0, 1:, :] for p in ada_patch_tokens]

            anomaly_maps = []
            image_score = 0
            for layer in range(len(ada_patch_tokens)):
                # L2-normalize tokens, then scaled similarity logits against
                # the text features (CLIP's 100x logit scale).
                ada_patch_tokens[layer] /= ada_patch_tokens[layer].norm(dim=-1, keepdim=True)
                anomaly_map = (100.0 * ada_patch_tokens[layer] @ text_features).unsqueeze(0)
                B, L, C = anomaly_map.shape
                # Assumes a square patch grid (L = H*H) -- TODO confirm.
                H = int(np.sqrt(L))

                # Softmax over the 2 classes; channel 1 is "abnormal".
                # The image score accumulates each layer's max patch score.
                anomaly_score = torch.softmax(anomaly_map, dim=-1)[:, :, 1]
                image_score += anomaly_score.max()

                # Note: raw logits (pre-softmax) are stacked for the map.
                anomaly_maps.append(anomaly_map)

            # Average logits across layers, upsample to image resolution,
            # and only then softmax to get per-pixel abnormal probability.
            score_map = torch.mean(torch.stack(anomaly_maps, dim=1), dim=1)
            score_map = F.interpolate(score_map.permute(0, 2, 1).view(B, 2, H, H),
                                      size=args.img_size, mode='bilinear', align_corners=True)
            score_map = torch.softmax(score_map, dim=1)[:, 1, :, :]
            score_map = score_map.squeeze(0).cpu().numpy()
            seg_score_map_zero.append(score_map)
            # Mean of the per-layer max scores.
            image_scores.append(image_score.cpu() / len(ada_patch_tokens))

            gt_mask_list.append(mask.squeeze().cpu().detach().numpy())
            gt_list.extend(label.cpu().detach().numpy())

    gt_list = np.array(gt_list)
    gt_mask_list = np.asarray(gt_mask_list)
    gt_mask_list = (gt_mask_list>0).astype(np.int_)

    segment_scores = np.array(seg_score_map_zero)
    image_scores = np.array(image_scores)

    # Image-level AUROC over all images of this class.
    roc_auc_im = roc_auc_score(gt_list, image_scores)

    # Pixel-level average precision over all pixels of all images.
    seg_pr = average_precision_score(gt_mask_list.flatten(), segment_scores.flatten())

    return roc_auc_im, seg_pr
| |
|
| |
|
# Script entry point.
if __name__ == '__main__':
    main()
| |
|
| |
|
| |
|