# Orchid_classification_AICUP / app_predict.py
# (uploaded by 52Hz — "Update app_predict.py", commit dd6626e)
import argparse
import os
import shutil  # needed by clean_folder()

from glob import glob

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from natsort import natsorted
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
# Names of the supported classifiers (each has a branch in build_model()).
model = [
    'swin',
    'beit',
    'dmnfnet',
    'ecaresnet_50',
    'efficient',
    'regnet',
    'vit',
    'convnext',
]
def main(input_model=None):
    """Classify every .jpg/.png under --input_dir and print top-5 scores.

    Args:
        input_model: optional classifier name (see build_model()) that
            overrides the --model CLI flag. When None, the CLI flag or its
            default ('convnext') is used. The original code assigned
            `args.model = input_model` unconditionally, which discarded the
            CLI value and crashed build_model() when called with None.

    Returns:
        dict mapping class index -> raw score for the LAST processed image.
        NOTE(review): each iteration overwrites `result`, so only the final
        image's prediction is returned — confirm single-image use is intended.
    """
    parser = argparse.ArgumentParser(description='Quick demo Image Classification')
    parser.add_argument('--input_dir', default='./test/', type=str, help='Input images root')
    parser.add_argument('--result_dir', default='./result/', type=str, help='Results images root')
    parser.add_argument('--weights_root', default='experiments/pretrained_models', type=str, help='Weights root')
    parser.add_argument('--model', default='convnext', type=str, help='Classifier')
    args = parser.parse_args()
    # Only override the CLI choice when a name was actually supplied.
    if input_model is not None:
        args.model = input_model

    inp_dir = args.input_dir
    out_dir = args.result_dir
    os.makedirs(out_dir, exist_ok=True)
    files = natsorted(glob(os.path.join(inp_dir, '*.jpg')) + glob(os.path.join(inp_dir, '*.png')))

    model, img_size = build_model(args.model, False, args.weights_root)
    print('Start predicting......')
    result = {}
    for file_ in tqdm(files):
        img = Image.open(file_).convert('RGB')
        input_ = transform_size(img_size)(img).unsqueeze(0)
        with torch.no_grad():
            predict_result = model(input_)
        # Top-5 class indices and their raw scores for this image.
        top5 = torch.topk(predict_result, 5).indices.tolist()
        scores = predict_result.tolist()
        result = {int(cls): scores[0][cls] for cls in top5[0]}
        print('result:', result)
    return result
def transform_size(size: int):
    """Return the torchvision eval pipeline for a model with crop size `size`.

    Pipeline: ToTensor -> Resize (short side to 480, interpolation=3, i.e.
    bicubic) -> CenterCrop(size) -> Normalize with per-size statistics.
    `size` must be one of the keys below; any other value raises KeyError.
    NOTE(review): an identical definition later in this file shadows this one.
    """
    # mean & std for different sizes
    # (presumably measured on the training set — TODO confirm)
    mean = {224: (0.5446, 0.4137, 0.3847),
    256: (0.5364, 0.4142, 0.3821),
    320: (0.5188, 0.4166, 0.3773),
    352: (0.5100, 0.4183, 0.3750),
    384: (0.5015, 0.4198, 0.3728),
    480: (0.4806, 0.4232, 0.3675)}
    std = {224: (0.2329, 0.2484, 0.2500),
    256: (0.2354, 0.2470, 0.2490),
    320: (0.2403, 0.2442, 0.2479),
    352: (0.2423, 0.2431, 0.2479),
    384: (0.2440, 0.2424, 0.2481),
    480: (0.2478, 0.2423, 0.2500)}
    transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=480, interpolation=3),
    transforms.CenterCrop(size),
    transforms.Normalize(mean=mean[size], std=std[size])
    ])
    return transform
def load_checkpoint(model, weights):
    """Load the state_dict stored at `weights` into `model` in place (CPU-mapped)."""
    state_dict = torch.load(weights, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
def build_model(model: str, pretrained: bool, pretrained_path: str):
    """Instantiate the named classifier and load its checkpoint.

    Args:
        model: classifier name; must be one of `models` below.
        pretrained: forwarded to the timm wrapper (load ImageNet weights
            before the checkpoint overwrite).
        pretrained_path: directory containing '<model>.pth' (beit uses
            'beit_1.pth').

    Returns:
        (classifier in eval mode, expected input crop size).

    Raises:
        Exception: when `model` is not a supported name.
    """
    # Fixed: the list said 'ecaresnet50' but the branch (and the module-level
    # `model` list) use 'ecaresnet_50', so the error message advertised a
    # name this function rejects.
    models = ['vit', 'beit', 'swin', 'convnext', 'ecaresnet_50', 'dmnfnet', 'regnet', 'efficient']
    if model == 'vit':
        classifier = vit_model(pretrained=pretrained)
        load_checkpoint(classifier, os.path.join(pretrained_path, model + '.pth'))
        return classifier.eval(), 384
    elif model == 'beit':
        classifier = beit_model(pretrained=pretrained)
        load_checkpoint(classifier, os.path.join(pretrained_path, model + '_1.pth'))
        return classifier.eval(), 384
    elif model == 'swin':
        classifier = swin_model(pretrained=pretrained)
        load_checkpoint(classifier, os.path.join(pretrained_path, model + '.pth'))
        return classifier.eval(), 384
    elif model == 'convnext':
        classifier = convnext_model(pretrained=pretrained)
        load_checkpoint(classifier, os.path.join(pretrained_path, model + '.pth'))
        return classifier.eval(), 384
    elif model == 'ecaresnet_50':
        classifier = ecaresnet50_model(pretrained=pretrained)
        load_checkpoint(classifier, os.path.join(pretrained_path, model + '.pth'))
        return classifier.eval(), 320
    elif model == 'dmnfnet':
        classifier = dmnfnet_model(pretrained=pretrained)
        load_checkpoint(classifier, os.path.join(pretrained_path, model + '.pth'))
        return classifier.eval(), 256
    elif model == 'regnet':
        classifier = regnetz_model(pretrained=pretrained)
        load_checkpoint(classifier, os.path.join(pretrained_path, model + '.pth'))
        return classifier.eval(), 320
    elif model == 'efficient':
        classifier = efficientnet_model(pretrained=pretrained)
        load_checkpoint(classifier, os.path.join(pretrained_path, model + '.pth'))
        return classifier.eval(), 480
    else:
        raise Exception(
            "\nNo corresponding model! \nPlease enter the supported model: \n\n{}".format('\n'.join(models)))
def save_img(filepath, img):
    """Save an H x W x 3 uint8 RGB array `img` to `filepath`.

    Fixed: the original called cv2, which is never imported anywhere in this
    file, so every call raised NameError. PIL is already a dependency and
    writes RGB arrays directly (no BGR channel swap needed).
    """
    Image.fromarray(img).save(filepath)
def clean_folder(folder):
    """Remove every file, symlink and subdirectory inside `folder` (best effort).

    Entries that cannot be deleted are reported and skipped rather than
    aborting the sweep. Requires the module-level `shutil` import (the
    original file used shutil without importing it, a NameError).
    """
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            # Deliberately broad: a single stubborn entry must not stop the sweep.
            print('Failed to delete %s. Reason: %s' % (file_path, e))
class beit_model(nn.Module):
    """BEiT base (patch16, 384) wrapper with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model('beit_base_patch16_384', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        return self.model(x)
class convnext_model(nn.Module):
    """ConvNeXt base (384, in22ft1k) wrapper with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model('convnext_base_384_in22ft1k', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        return self.model(x)
class swin_model(nn.Module):
    """Swin base (patch4, window12, 384) wrapper with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model('swin_base_patch4_window12_384', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        return self.model(x)
class vit_model(nn.Module):
    """ViT base (patch16, 384) wrapper with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model('vit_base_patch16_384', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        return self.model(x)
class resmlp_model(nn.Module):
    """ResMLP big-24 (224, in22ft1k) wrapper with a `classes`-way head.

    Fixed: unlike every sibling wrapper in this file, this class was missing
    `forward`, so calling an instance raised NotImplementedError.
    """

    def __init__(self, classes=219, pretrained=True):
        super(resmlp_model, self).__init__()
        self.model = timm.create_model('resmlp_big_24_224_in22ft1k', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        return self.model(x)
class xcittiny_model(nn.Module):
    """XCiT tiny-12 (p8, 384, distilled) wrapper with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model('xcit_tiny_12_p8_384_dist', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        return self.model(x)
class ecaresnet269_model(nn.Module):
    """ECA-ResNet-269d wrapper with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model('ecaresnet269d', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        return self.model(x)
class dmnfnet_model(nn.Module):
    """DeepMind NFNet-F0 wrapper with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model('dm_nfnet_f0', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        return self.model(x)
class ecaresnet50_model(nn.Module):
    """ECA-ResNet-50t wrapper with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model('ecaresnet50t', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        return self.model(x)
class regnetz_model(nn.Module):
    """RegNetZ-E8 wrapper with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model('regnetz_e8', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        return self.model(x)
class efficientnet_model(nn.Module):
    """EfficientNetV2-M (in21ft1k) wrapper with a `classes`-way head."""

    def __init__(self, classes=219, pretrained=True):
        super().__init__()
        self.model = timm.create_model('tf_efficientnetv2_m_in21ft1k', pretrained=pretrained, num_classes=classes)

    def forward(self, x):
        return self.model(x)
def transform_size(size: int):
    """Return the torchvision eval pipeline for a model with crop size `size`.

    Pipeline: ToTensor -> Resize (short side to 480, interpolation=3, i.e.
    bicubic) -> CenterCrop(size) -> Normalize with per-size statistics.
    `size` must be one of the keys below; any other value raises KeyError.
    """
    # size -> (channel means, channel stds)
    # (presumably measured on the training set — TODO confirm)
    stats = {
        224: ((0.5446, 0.4137, 0.3847), (0.2329, 0.2484, 0.2500)),
        256: ((0.5364, 0.4142, 0.3821), (0.2354, 0.2470, 0.2490)),
        320: ((0.5188, 0.4166, 0.3773), (0.2403, 0.2442, 0.2479)),
        352: ((0.5100, 0.4183, 0.3750), (0.2423, 0.2431, 0.2479)),
        384: ((0.5015, 0.4198, 0.3728), (0.2440, 0.2424, 0.2481)),
        480: ((0.4806, 0.4232, 0.3675), (0.2478, 0.2423, 0.2500)),
    }
    mean, std = stats[size]
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(size=480, interpolation=3),
        transforms.CenterCrop(size),
        transforms.Normalize(mean=mean, std=std),
    ])
def build_ensemble_model(model: dict, pretrained: bool):
    """Build and load every member of an ensemble.

    Args:
        model: mapping of arbitrary keys to [name, input_size, weights_path], e.g.
            {
            CLASSIFIER1: ['vit', 384, 'pretrained/vit_testmodel.pth'],
            CLASSIFIER2: ['beit', 384, 'pretrained/beit_testmodel.pth'],
            CLASSIFIER3: ['swin', 384, 'pretrained/swin_testmodel.pth'],
            CLASSIFIER4: ['convnext', 384, 'pretrained/convnext_fold1_best_acc.pth']
            }
        pretrained: forwarded to each timm wrapper.

    Returns: [[loaded model in eval mode, corresponding transform, input_size], ...]
    """
    # Fixed: the original called build_model(name, pretrained) — a TypeError,
    # since build_model also requires a weights directory — and then treated
    # its (model, size) tuple return as a bare model. Each ensemble entry
    # carries its own checkpoint path, so construct the wrappers directly.
    constructors = {
        'vit': vit_model,
        'beit': beit_model,
        'swin': swin_model,
        'convnext': convnext_model,
        'ecaresnet_50': ecaresnet50_model,
        'dmnfnet': dmnfnet_model,
        'regnet': regnetz_model,
        'efficient': efficientnet_model,
    }
    print('==> Build and load the ensemble models')
    ensemble_model = []
    for key in tqdm(model):
        name, size, weights = model[key]
        each_model = constructors[name](pretrained=pretrained)
        load_checkpoint(each_model, weights)
        each_model.eval()
        ensemble_model.append([each_model, transform_size(size), size])
    return ensemble_model
if __name__ == '__main__':
    # Fixed: main() has no parameter named `model` (it is `input_model`),
    # so the original call main(model=None) raised TypeError. Calling with
    # no argument lets the --model CLI flag / default take effect.
    main()