"""MixMatch-based waterfowl classification for detected bounding boxes.

Each detected box is cropped out of the full ("mega") image, classified with a
WideResNet loaded from a MixMatch checkpoint, and the per-category counts are
rendered as a bar-chart image.
"""
# NOTE(review): import order kept as in the original; `spaces` stays ahead of
# torch/torchvision, presumably because ZeroGPU must hook CUDA before torch
# initialises it — confirm before reordering.
import os
import glob
import cv2
import spaces
from torchvision import transforms, utils
import torch
import classifiers.MixMatch.models.wideresnet as wmodels
from PIL import Image
import numpy as np
import classifiers.MixMatch.dataset.waterfowl as dataset
import torch.nn as nn
import torch.nn.functional as F
import random
from collections import Counter
import pandas as pd
import matplotlib.pyplot as plt

# Model output index (as a string) -> human-readable category name.
category_dict = {
    "0": "American Widgeon_Female",
    "1": "American Widgeon_Male",
    "2": "Canada Goose",
    "3": "Canvasback_Male",
    "4": "Coot",
    "5": "Gadwall",
    "6": "Green-winged teal",
    "7": "Mallard Female",
    "8": "Mallard Male",
    "9": "Not a bird",
    "10": "Pelican",
    "11": "Pintail_Female",
    "12": "Pintail_Male",
    "13": "Ring-necked duck Female",
    "14": "Ring-necked duck Male",
    "15": "Scaup_Male",
    "16": "Shoveler_Female",
    "17": "Shoveler_Male",
    "18": "Snow",
    "19": "Unknown",
    "20": "White-fronted Goose",
}

# Number of classifier outputs; one per category above.
NUM_CLASSES = len(category_dict)

test_transform = transforms.Compose([
    dataset.ToTensor(),
])

device = torch.device('cuda')

cifar10_mean = (0.4914, 0.4822, 0.4465)  # equals np.mean(train_set.train_data, axis=(0,1,2))/255
cifar10_std = (0.2471, 0.2435, 0.2616)  # equals np.std(train_set.train_data, axis=(0,1,2))/255


def window_jittering(box, mega_image, num_box=5):
    """Return ``num_box`` boxes: ``box`` itself followed by randomly jittered copies.

    Each side of a jittered box is pushed outward by a random fraction in
    [-0.1, 0.3) of the box width (x sides) or height (y sides); a negative
    fraction moves that side inward. Jittered boxes are clamped to the image.

    Args:
        box: (x1, y1, x2, y2) pixel coordinates.
        mega_image: the full image, as an HxWxC array (a PIL image also works).
        num_box: total number of boxes returned, including the original.

    Returns:
        List of [x1, y1, x2, y2] integer boxes; the first is ``box`` cast to int.
    """
    # np.shape handles both ndarrays and PIL images; `.size` on an ndarray is
    # the element count, not (width, height).
    h, w = np.shape(mega_image)[:2]
    x1, y1, x2, y2 = box
    box_w, box_h = x2 - x1, y2 - y1
    boxes = [[int(x1), int(y1), int(x2), int(y2)]]
    for _ in range(num_box - 1):
        # Random draws taken in the order x1, x2, y1, y2.
        fx1, fx2, fy1, fy2 = (random.random() * 0.4 - 0.1 for _ in range(4))
        boxes.append([
            int(max(x1 - box_w * fx1, 0)),
            int(max(y1 - box_h * fy1, 0)),
            int(min(x2 + box_w * fx2, w)),
            int(min(y2 + box_h * fy2, h)),
        ])
    return boxes


def transpose(x, source='NHWC', target='NCHW'):
    """Permute the axes of ``x`` from the ``source`` layout to the ``target`` layout."""
    return x.transpose([source.index(d) for d in target])


def _top1_category(model, crop):
    """Return the category name of the highest-scoring class for a single-crop batch."""
    _, pred = model(crop).topk(1, 1, True, True)
    return category_dict[str(pred[0, 0].item())]


@spaces.GPU
def predict_methods(mixmatch_model, box, mega_image, method='baseline'):
    """Classify the region ``box`` of ``mega_image`` and return its category name.

    Args:
        mixmatch_model: classifier mapping a prepared crop to per-class scores
            (assumed shape (1, NUM_CLASSES) — TODO confirm).
        box: (x1, y1, x2, y2) integer pixel coordinates.
        mega_image: the full image as an HxWxC numpy array.
        method: 'baseline' classifies ``box`` once; 'voting' takes the majority
            top-1 label over 5 jittered boxes (ties go to the label seen first,
            starting with the un-jittered box); 'prob_sum' adds up the top-5
            softmax probabilities over 10 jittered boxes and returns the
            highest-scoring class.

    Raises:
        ValueError: if ``method`` is not one of the names above.
    """
    # Pure inference: skip building autograd graphs for every crop.
    with torch.no_grad():
        if method == 'baseline':
            return _top1_category(mixmatch_model, prepare_data_mixmatch(mega_image, box))

        if method == 'voting':
            votes = [
                _top1_category(mixmatch_model, prepare_data_mixmatch(mega_image, jitter_box))
                for jitter_box in window_jittering(box, mega_image)
            ]
            # most_common breaks ties by first occurrence (deterministic),
            # unlike iterating a set of strings, whose order varies per process.
            return Counter(votes).most_common(1)[0][0]

        if method == 'prob_sum':
            scores = np.zeros(NUM_CLASSES)
            for jitter_box in window_jittering(box, mega_image, 10):
                probs = F.softmax(mixmatch_model(prepare_data_mixmatch(mega_image, jitter_box)), dim=1)
                top_probs, top_classes = probs.topk(min(5, probs.shape[1]), 1, True, True)
                for cls, prob in zip(top_classes[0].cpu().numpy(), top_probs[0].cpu().numpy()):
                    scores[cls] += prob
            return category_dict[str(int(np.argmax(scores)))]

    raise ValueError(
        "unknown method {!r}; expected 'baseline', 'voting' or 'prob_sum'".format(method)
    )


@spaces.GPU
def create_model(ema=False):
    """Build a WideResNet with one output per category, moved to the GPU.

    With ``ema=True`` every parameter is detached in place so autograd does
    not track it.
    """
    model = wmodels.WideResNet(num_classes=NUM_CLASSES)
    model = model.cuda()
    if ema:
        for param in model.parameters():
            param.detach_()
    return model


@spaces.GPU
def get_model_mixmatch(checkpoint_dir, ema=True):
    """Load a MixMatch checkpoint and return the model in eval mode on ``device``.

    Args:
        checkpoint_dir: path to a checkpoint file saved with ``torch.save``.
        ema: load ``'ema_state_dict'`` (True) or ``'state_dict'`` (False).
    """
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_dir, map_location=device)
    model = create_model(ema=ema)
    model.load_state_dict(checkpoint['ema_state_dict' if ema else 'state_dict'])
    model.to(device)
    print('loaded ema_model' if ema else 'loaded model')
    return model.eval()


def normalize(x, mean=cifar10_mean, std=cifar10_std):
    """Return a float32 copy of ``x`` (0-255 pixel values) standardised per channel.

    ``mean`` and ``std`` are given on the 0-1 scale, so they are rescaled by 255.
    The input array is not modified (``np.array`` copies it).
    """
    x, mean, std = [np.array(a, np.float32) for a in (x, mean, std)]
    x -= mean * 255
    x *= 1.0 / (255 * std)
    return x


def prepare_data_mixmatch(mega_image, box):
    """Crop ``box`` from ``mega_image``, resize to 32x32 and return a model-ready batch.

    Args:
        mega_image: HxWxC numpy array.
        box: (x1, y1, x2, y2) integer pixel coordinates; clipped to the image.

    Returns:
        The single-crop batch in NCHW layout, after ``test_transform``, on ``device``.

    Raises:
        ValueError: if the clipped box is empty (no overlap with the image).
    """
    x1, y1, x2, y2 = box
    h, w = np.shape(mega_image)[:2]
    # Clip both ends of every coordinate: a negative end index would otherwise
    # wrap around and silently select the wrong region.
    cx1, cx2 = (int(np.clip(v, 0, w)) for v in (x1, x2))
    cy1, cy2 = (int(np.clip(v, 0, h)) for v in (y1, y2))
    crop = mega_image[cy1:cy2, cx1:cx2, :]
    if crop.size == 0:
        raise ValueError('box {} does not overlap the {}x{} image'.format(box, w, h))
    crop = cv2.resize(crop, (32, 32))
    batch = transpose(normalize(np.array([crop])))
    return test_transform(batch).to(device)


def plt2arr(plt):
    """Render the current pyplot figure and return it as an (H, W, 3) uint8 array.

    ``plt`` is the ``matplotlib.pyplot`` module (or anything with ``gca()``).
    Uses ``buffer_rgba`` because ``tostring_rgb`` is deprecated/removed in
    recent Matplotlib. The result is a copy, so the figure may be closed
    afterwards.
    """
    canvas = plt.gca().figure.canvas
    canvas.draw()
    rgba = np.asarray(canvas.buffer_rgba())
    return rgba[:, :, :3].copy()


def draw_bar_chart(num_dict):
    """Draw a bar chart of ``{category: count}`` and return it as an RGB array."""
    cates = list(num_dict.keys())
    nums = list(num_dict.values())
    fig = plt.figure(figsize=(10, 5))
    try:
        plt.bar(cates, nums, color='maroon', width=0.4)
        plt.xlabel("waterfowl categories")
        plt.ylabel("No. of waterfowl per category")
        plt.title("The number of waterfowls of different categories(Total number: {})".format(str(sum(nums))))
        return plt2arr(plt)
    finally:
        # Release the figure; otherwise each call leaks one into pyplot's registry.
        plt.close(fig)


@spaces.GPU
def mixmatch_classifier_inference(model_dir, mega_image, bbox_list):
    """Classify every box in ``bbox_list`` and return a bar chart of category counts.

    Args:
        model_dir: checkpoint path passed to ``get_model_mixmatch``.
        mega_image: the full image as an HxWxC numpy array.
        bbox_list: iterable of boxes whose first four entries are
            (x1, y1, x2, y2); any further entries (e.g. a confidence) are ignored.

    Returns:
        RGB image array of the per-category bar chart.
    """
    mixmatch_model = get_model_mixmatch(model_dir)
    pred_data = []
    for bbox in bbox_list:
        x1, y1, x2, y2 = bbox[:4]
        box = [int(x1), int(y1), int(x2), int(y2)]
        pred_data.append(predict_methods(mixmatch_model, box, mega_image))
    return draw_bar_chart(Counter(pred_data))