# NOTE(review): stray platform text ("Spaces:" / "Runtime error" x2) from the
# page this file was pasted from — not Python code.
# Standard library
import argparse
import os
import shutil  # added: used by clean_folder() but was never imported
from glob import glob

# Third-party
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from natsort import natsorted
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
# Supported classifier names, as accepted by build_model() and the --model flag.
model = [
    'swin',
    'beit',
    'dmnfnet',
    'ecaresnet_50',
    'efficient',
    'regnet',
    'vit',
    'convnext',
]
def main(input_model=None):
    """Classify every image in --input_dir and print its top-5 classes.

    Args:
        input_model: optional model name overriding the --model CLI flag.
            When None, the CLI value (default 'convnext') is used.

    Returns:
        dict mapping class index -> raw logit score (NOT softmax probability)
        for the last processed image; empty dict if no images were found.
    """
    parser = argparse.ArgumentParser(description='Quick demo Image Classification')
    parser.add_argument('--input_dir', default='./test/', type=str, help='Input images root')
    parser.add_argument('--result_dir', default='./result/', type=str, help='Results images root')
    parser.add_argument('--weights_root', default='experiments/pretrained_models', type=str, help='Weights root')
    parser.add_argument('--model', default='convnext', type=str, help='Classifier')
    args = parser.parse_args()
    # BUGFIX: only override the CLI choice when a model name was actually
    # passed in; the old unconditional `args.model = input_model` clobbered
    # --model with None and made build_model() raise.
    if input_model is not None:
        args.model = input_model
    inp_dir = args.input_dir
    out_dir = args.result_dir
    os.makedirs(out_dir, exist_ok=True)
    files = natsorted(glob(os.path.join(inp_dir, '*.jpg')) + glob(os.path.join(inp_dir, '*.png')))
    # Renamed from `model` to avoid shadowing the module-level name list.
    classifier, img_size = build_model(args.model, False, args.weights_root)
    transform = transform_size(img_size)  # hoisted: loop-invariant
    print('Start predicting......')
    result = {}
    for file_ in tqdm(files):
        img = Image.open(file_).convert('RGB')
        input_ = transform(img).unsqueeze(0)
        with torch.no_grad():
            scores = classifier(input_)
        # Top-5 class indices for this image; values are raw logits.
        top5 = torch.topk(scores, 5).indices.tolist()
        scores = scores.tolist()
        # BUGFIX: inner loop no longer shadows the file-loop variable `i`.
        result = {int(idx): scores[0][idx] for idx in top5[0]}
        print('result:', result)
    return result
def transform_size(size: int):
    """Build the eval preprocessing pipeline for crop size `size`.

    Converts to tensor, resizes the shorter edge to 480 (interpolation=3,
    i.e. bicubic), center-crops to `size`, then normalizes with per-size
    channel statistics. Raises KeyError for sizes without statistics.
    """
    # Per-crop-size (mean, std) channel statistics.
    stats = {
        224: ((0.5446, 0.4137, 0.3847), (0.2329, 0.2484, 0.2500)),
        256: ((0.5364, 0.4142, 0.3821), (0.2354, 0.2470, 0.2490)),
        320: ((0.5188, 0.4166, 0.3773), (0.2403, 0.2442, 0.2479)),
        352: ((0.5100, 0.4183, 0.3750), (0.2423, 0.2431, 0.2479)),
        384: ((0.5015, 0.4198, 0.3728), (0.2440, 0.2424, 0.2481)),
        480: ((0.4806, 0.4232, 0.3675), (0.2478, 0.2423, 0.2500)),
    }
    mean, std = stats[size]  # KeyError for unsupported sizes, as before
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(size=480, interpolation=3),
        transforms.CenterCrop(size),
        transforms.Normalize(mean=mean, std=std),
    ])
def load_checkpoint(model, weights):
    """Load the state_dict stored at `weights` into `model`, mapped to CPU.

    NOTE(review): torch.load unpickles data — only load checkpoint files
    from trusted sources.
    """
    state = torch.load(weights, map_location=torch.device('cpu'))
    model.load_state_dict(state)
def build_model(model: str, pretrained: bool, pretrained_path: str):
    """Instantiate a classifier by name, load its checkpoint, return it in eval mode.

    Args:
        model: one of the registry keys below.
        pretrained: forwarded to the timm backbone constructor.
        pretrained_path: directory containing the '<model>.pth' checkpoints.

    Returns:
        (classifier set to eval mode, expected square input image size)

    Raises:
        Exception: when `model` is not a supported key.
    """
    # name -> (wrapper class, input size, checkpoint filename)
    registry = {
        'vit': (vit_model, 384, 'vit.pth'),
        'beit': (beit_model, 384, 'beit_1.pth'),  # beit checkpoint uses '_1' suffix
        'swin': (swin_model, 384, 'swin.pth'),
        'convnext': (convnext_model, 384, 'convnext.pth'),
        # BUGFIX: the old supported-model listing advertised 'ecaresnet50'
        # although the key actually accepted was 'ecaresnet_50'.
        'ecaresnet_50': (ecaresnet50_model, 320, 'ecaresnet_50.pth'),
        'dmnfnet': (dmnfnet_model, 256, 'dmnfnet.pth'),
        'regnet': (regnetz_model, 320, 'regnet.pth'),
        'efficient': (efficientnet_model, 480, 'efficient.pth'),
    }
    if model not in registry:
        raise Exception(
            "\nNo corresponding model! \nPlease enter the supported model: \n\n{}".format(
                '\n'.join(registry)))
    cls, img_size, ckpt = registry[model]
    classifier = cls(pretrained=pretrained)
    load_checkpoint(classifier, os.path.join(pretrained_path, ckpt))
    return classifier.eval(), img_size
def save_img(filepath, img):
    """Save an RGB image array to `filepath`.

    BUGFIX: the original used cv2 (cvtColor RGB->BGR + imwrite), but cv2 is
    never imported anywhere in this file, so any call raised NameError.
    PIL is already imported and produces the same correctly colored file
    from an RGB array. Assumes `img` is an HxWx3 uint8 RGB ndarray —
    TODO confirm against callers.
    """
    Image.fromarray(img).save(filepath)
def clean_folder(folder):
    """Delete every file, symlink and subdirectory inside `folder` (best effort).

    Failures are printed and skipped rather than raised, preserving the
    original best-effort semantics. BUGFIX: `shutil` was used here but never
    imported at file level; the import is added to the file's import block.
    """
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print('Failed to delete %s. Reason: %s' % (file_path, e))
class beit_model(nn.Module):
    """Wrapper around timm 'beit_base_patch16_384' with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model(
            'beit_base_patch16_384', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        logits = self.model(x)
        return logits
class convnext_model(nn.Module):
    """Wrapper around timm 'convnext_base_384_in22ft1k' with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model(
            'convnext_base_384_in22ft1k', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        logits = self.model(x)
        return logits
class swin_model(nn.Module):
    """Wrapper around timm 'swin_base_patch4_window12_384' with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model(
            'swin_base_patch4_window12_384', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        logits = self.model(x)
        return logits
class vit_model(nn.Module):
    """Wrapper around timm 'vit_base_patch16_384' with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model(
            'vit_base_patch16_384', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        logits = self.model(x)
        return logits
class resmlp_model(nn.Module):
    """Wrapper around timm 'resmlp_big_24_224_in22ft1k' with a `classes`-way head.

    BUGFIX: unlike every sibling wrapper in this file, this class was missing
    a forward() method, so calling an instance raised NotImplementedError.
    """

    def __init__(self, classes=219, pretrained=True):
        super(resmlp_model, self).__init__()
        self.model = timm.create_model('resmlp_big_24_224_in22ft1k', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        return self.model(x)
class xcittiny_model(nn.Module):
    """Wrapper around timm 'xcit_tiny_12_p8_384_dist' with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model(
            'xcit_tiny_12_p8_384_dist', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        logits = self.model(x)
        return logits
class ecaresnet269_model(nn.Module):
    """Wrapper around timm 'ecaresnet269d' with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model(
            'ecaresnet269d', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        logits = self.model(x)
        return logits
class dmnfnet_model(nn.Module):
    """Wrapper around timm 'dm_nfnet_f0' with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model(
            'dm_nfnet_f0', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        logits = self.model(x)
        return logits
class ecaresnet50_model(nn.Module):
    """Wrapper around timm 'ecaresnet50t' with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model(
            'ecaresnet50t', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        logits = self.model(x)
        return logits
class regnetz_model(nn.Module):
    """Wrapper around timm 'regnetz_e8' with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model(
            'regnetz_e8', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        logits = self.model(x)
        return logits
class efficientnet_model(nn.Module):
    """Wrapper around timm 'tf_efficientnetv2_m_in21ft1k' with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model(
            'tf_efficientnetv2_m_in21ft1k', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        logits = self.model(x)
        return logits
def transform_size(size: int):
    """Build the eval preprocessing pipeline for crop size `size`.

    NOTE(review): this is a byte-identical duplicate of the transform_size
    defined earlier in this file; being later, this definition is the one
    in effect at import time. Behavior is unchanged here.

    Converts to tensor, resizes the shorter edge to 480 (interpolation=3,
    i.e. bicubic), center-crops to `size`, then normalizes with per-size
    channel statistics. Raises KeyError for sizes without statistics.
    """
    # Per-crop-size (mean, std) channel statistics.
    stats = {
        224: ((0.5446, 0.4137, 0.3847), (0.2329, 0.2484, 0.2500)),
        256: ((0.5364, 0.4142, 0.3821), (0.2354, 0.2470, 0.2490)),
        320: ((0.5188, 0.4166, 0.3773), (0.2403, 0.2442, 0.2479)),
        352: ((0.5100, 0.4183, 0.3750), (0.2423, 0.2431, 0.2479)),
        384: ((0.5015, 0.4198, 0.3728), (0.2440, 0.2424, 0.2481)),
        480: ((0.4806, 0.4232, 0.3675), (0.2478, 0.2423, 0.2500)),
    }
    mean, std = stats[size]  # KeyError for unsupported sizes, as before
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(size=480, interpolation=3),
        transforms.CenterCrop(size),
        transforms.Normalize(mean=mean, std=std),
    ])
def build_ensemble_model(model: dict, pretrained: bool):
    """Build several classifiers and pair each with its eval transform.

    Args:
        model: mapping of arbitrary keys to [arch_name, input_size, ckpt_path], e.g.
            {
                CLASSIFIER1: ['vit', 384, 'pretrained/vit_testmodel.pth'],
                CLASSIFIER2: ['beit', 384, 'pretrained/beit_testmodel.pth'],
                CLASSIFIER3: ['swin', 384, 'pretrained/swin_testmodel.pth'],
                CLASSIFIER4: ['convnext', 384, 'pretrained/convnext_fold1_best_acc.pth']
            }
        pretrained: forwarded to each timm backbone constructor.

    Returns:
        list of [model in eval mode with checkpoint loaded, matching transform, input size].
    """
    # BUGFIX: the original called build_model(model=..., pretrained=...), which
    # raised TypeError (build_model also requires pretrained_path) and would
    # have returned a (model, size) tuple, not a bare model. Since each entry
    # carries its own checkpoint path, build the bare architecture here and
    # load that checkpoint directly.
    archs = {
        'vit': vit_model,
        'beit': beit_model,
        'swin': swin_model,
        'convnext': convnext_model,
        'ecaresnet_50': ecaresnet50_model,
        'dmnfnet': dmnfnet_model,
        'regnet': regnetz_model,
        'efficient': efficientnet_model,
    }
    print('==> Build and load the ensemble models')
    ensemble_model = []
    for key in tqdm(model):
        arch_name, img_size, ckpt_path = model[key]
        each_model = archs[arch_name](pretrained=pretrained)
        load_checkpoint(each_model, ckpt_path)
        each_model.eval()
        ensemble_model.append([each_model, transform_size(img_size), img_size])
    return ensemble_model
if __name__ == '__main__':
    # BUGFIX: main() has no 'model' keyword parameter; the old call
    # main(model=None) raised TypeError before doing any work. main()
    # already defaults input_model to None.
    main()