from PIL import Image
from tqdm import tqdm
from torchvision import transforms
import argparse
import os
import shutil
from glob import glob
from typing import Optional
from natsort import natsorted
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np

# Model names accepted by build_model(). Kept as a module-level name for
# backward compatibility; the lookup itself uses _MODEL_REGISTRY below.
model = ['swin', 'beit', 'dmnfnet', 'ecaresnet_50', 'efficient', 'regnet', 'vit', 'convnext']

# Per-input-size normalisation statistics used by transform_size().
# NOTE(review): presumably computed on the training data — confirm source.
_NORM_MEAN = {224: (0.5446, 0.4137, 0.3847), 256: (0.5364, 0.4142, 0.3821),
              320: (0.5188, 0.4166, 0.3773), 352: (0.5100, 0.4183, 0.3750),
              384: (0.5015, 0.4198, 0.3728), 480: (0.4806, 0.4232, 0.3675)}
_NORM_STD = {224: (0.2329, 0.2484, 0.2500), 256: (0.2354, 0.2470, 0.2490),
             320: (0.2403, 0.2442, 0.2479), 352: (0.2423, 0.2431, 0.2479),
             384: (0.2440, 0.2424, 0.2481), 480: (0.2478, 0.2423, 0.2500)}


# ---------------------------------------------------------------------------
# Model wrappers
# ---------------------------------------------------------------------------

class _TimmClassifier(nn.Module):
    """Wrap a timm model (named by the subclass's ``timm_name``) as ``self.model``.

    Keeping the backbone under the ``model`` attribute preserves the
    ``model.``-prefixed state-dict keys of the original per-class wrappers,
    which the saved checkpoints presumably rely on.
    """

    timm_name = None  # set by each subclass

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model(self.timm_name, pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        return self.model(x)


class beit_model(_TimmClassifier):
    timm_name = 'beit_base_patch16_384'


class convnext_model(_TimmClassifier):
    timm_name = 'convnext_base_384_in22ft1k'


class swin_model(_TimmClassifier):
    timm_name = 'swin_base_patch4_window12_384'


class vit_model(_TimmClassifier):
    timm_name = 'vit_base_patch16_384'


class resmlp_model(_TimmClassifier):
    # Previously lacked forward() entirely, so calling it raised NotImplementedError.
    timm_name = 'resmlp_big_24_224_in22ft1k'


class xcittiny_model(_TimmClassifier):
    timm_name = 'xcit_tiny_12_p8_384_dist'


class ecaresnet269_model(_TimmClassifier):
    timm_name = 'ecaresnet269d'


class dmnfnet_model(_TimmClassifier):
    timm_name = 'dm_nfnet_f0'


class ecaresnet50_model(_TimmClassifier):
    timm_name = 'ecaresnet50t'


class regnetz_model(_TimmClassifier):
    timm_name = 'regnetz_e8'


class efficientnet_model(_TimmClassifier):
    timm_name = 'tf_efficientnetv2_m_in21ft1k'


# name -> (wrapper class, checkpoint filename suffix, input size)
_MODEL_REGISTRY = {
    'vit': (vit_model, '.pth', 384),
    'beit': (beit_model, '_1.pth', 384),
    'swin': (swin_model, '.pth', 384),
    'convnext': (convnext_model, '.pth', 384),
    'ecaresnet_50': (ecaresnet50_model, '.pth', 320),
    'dmnfnet': (dmnfnet_model, '.pth', 256),
    'regnet': (regnetz_model, '.pth', 320),
    'efficient': (efficientnet_model, '.pth', 480),
}


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def transform_size(size: int):
    """Return the preprocessing pipeline for a model with ``size`` x ``size`` input.

    ToTensor -> resize the shorter side to 480 (bicubic) -> center-crop ``size``
    -> normalise with the statistics for ``size``.

    Raises:
        KeyError: if no statistics exist for ``size``
            (supported: 224, 256, 320, 352, 384, 480).
    """
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(size=480, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(size),
        transforms.Normalize(mean=_NORM_MEAN[size], std=_NORM_STD[size]),
    ])


def load_checkpoint(model, weights):
    """Load the state dict stored at path ``weights`` into ``model`` (on CPU).

    NOTE: torch.load unpickles the file — only load trusted checkpoints.
    """
    checkpoint = torch.load(weights, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)


def build_model(model: str, pretrained: bool, pretrained_path: Optional[str] = None):
    """Instantiate classifier ``model`` and return ``(classifier.eval(), input_size)``.

    Args:
        model: one of the keys of ``_MODEL_REGISTRY``.
        pretrained: forwarded to ``timm.create_model``.
        pretrained_path: directory holding ``<model><suffix>`` checkpoints; when
            None no checkpoint is loaded (the caller loads its own).

    Raises:
        ValueError: for an unknown model name (subclass of the Exception
            previously raised, so existing handlers still catch it).
    """
    try:
        wrapper_cls, ckpt_suffix, img_size = _MODEL_REGISTRY[model]
    except KeyError:
        raise ValueError(
            "\nNo corresponding model! \nPlease enter the supported model: \n\n{}".format(
                '\n'.join(_MODEL_REGISTRY))) from None
    classifier = wrapper_cls(pretrained=pretrained)
    if pretrained_path is not None:
        load_checkpoint(classifier, os.path.join(pretrained_path, model + ckpt_suffix))
    return classifier.eval(), img_size


def build_ensemble_model(model: dict, pretrained: bool):
    """
    Args:
        model: {
            CLASSIFIER1: ['vit', 384, 'pretrained/vit_testmodel.pth'],
            CLASSIFIER2: ['beit', 384, 'pretrained/beit_testmodel.pth'],
            CLASSIFIER3: ['swin', 384, 'pretrained/swin_testmodel.pth'],
            CLASSIFIER4: ['convnext', 384, 'pretrained/convnext_fold1_best_acc.pth']
        }
        pretrained: forwarded to ``timm.create_model``.

    Returns:
        [[loaded model in eval mode, corresponding transform, input size], ...]
    """
    print('==> Build and load the ensemble models')
    ensemble_model = []
    for key in tqdm(model):
        value = model[key]
        # Build without the default checkpoint; the explicit path in value[2] is loaded instead.
        classifier, _ = build_model(value[0], pretrained)
        load_checkpoint(classifier, value[2])
        classifier.eval()
        ensemble_model.append([classifier, transform_size(value[1]), value[1]])
    return ensemble_model


def save_img(filepath, img):
    """Write an RGB image array to ``filepath``.

    Uses PIL (already a dependency) instead of the never-imported cv2; PIL takes
    RGB directly, so no channel swap is needed.
    NOTE(review): assumes ``img`` is an (H, W, 3) uint8 array — confirm with callers.
    """
    Image.fromarray(img).save(filepath)


def clean_folder(folder):
    """Delete every file, symlink and sub-directory inside ``folder`` (best effort)."""
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            # Best effort: report and continue with the remaining entries.
            print('Failed to delete %s. Reason: %s' % (file_path, e))


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main(input_model=None):
    """Classify every *.jpg / *.png in ``--input_dir`` with a single model.

    Args:
        input_model: model name overriding ``--model``; when None the
            command-line value is used.

    Returns:
        dict mapping class index -> raw (un-normalised) model score for the
        top-5 classes of the last image processed; empty if no images exist.
    """
    parser = argparse.ArgumentParser(description='Quick demo Image Classification')
    parser.add_argument('--input_dir', default='./test/', type=str, help='Input images root')
    parser.add_argument('--result_dir', default='./result/', type=str, help='Results images root')
    parser.add_argument('--weights_root', default='experiments/pretrained_models', type=str, help='Weights root')
    parser.add_argument('--model', default='convnext', type=str, help='Classifier')
    args = parser.parse_args()
    # Only override the CLI choice when a caller explicitly supplies a model.
    if input_model is not None:
        args.model = input_model

    os.makedirs(args.result_dir, exist_ok=True)
    files = natsorted(glob(os.path.join(args.input_dir, '*.jpg'))
                      + glob(os.path.join(args.input_dir, '*.png')))

    classifier, img_size = build_model(args.model, False, args.weights_root)
    transform = transform_size(img_size)  # identical for every image

    print('Start predicting......')
    result = {}
    for file_ in tqdm(files):
        with Image.open(file_) as im:
            img = im.convert('RGB')
        input_ = transform(img).unsqueeze(0)
        with torch.no_grad():
            logits = classifier(input_)
        top5 = torch.topk(logits, 5).indices[0].tolist()
        scores = logits[0].tolist()
        result = {int(label): scores[label] for label in top5}
        print('result:', result)
    return result


if __name__ == '__main__':
    main()