# pipeline_server / classifiers / MixMatch / mixmatch_classification.py
# Last updated by zy984764389 in commit 12d6a57 (verified):
# "Update classifiers/MixMatch/mixmatch_classification.py"
import os
import glob
import cv2
import spaces
from torchvision import transforms,utils
import torch
import classifiers.MixMatch.models.wideresnet as wmodels
from PIL import Image
import numpy as np
import classifiers.MixMatch.dataset.waterfowl as dataset
import torch.nn as nn
import torch.nn.functional as F
import random
from collections import Counter
import pandas as pd
import matplotlib.pyplot as plt
# Lookup from the classifier's predicted class index (as a string, e.g. "9")
# to the human-readable waterfowl label. Index order matters: position i in
# the list below is class i of the model.
category_dict = {
    str(class_index): label
    for class_index, label in enumerate([
        "American Widgeon_Female",
        "American Widgeon_Male",
        "Canada Goose",
        "Canvasback_Male",
        "Coot",
        "Gadwall",
        "Green-winged teal",
        "Mallard Female",
        "Mallard Male",
        "Not a bird",
        "Pelican",
        "Pintail_Female",
        "Pintail_Male",
        "Ring-necked duck Female",
        "Ring-necked duck Male",
        "Scaup_Male",
        "Shoveler_Female",
        "Shoveler_Male",
        "Snow",
        "Unknown",
        "White-fronted Goose",
    ])
}
# CIFAR-10 per-channel statistics in [0, 1] units; normalize() multiplies
# them by 255 so raw 0-255 pixel crops can be standardised directly.
cifar10_mean = (0.4914, 0.4822, 0.4465)  # equals np.mean(train_set.train_data, axis=(0,1,2))/255
cifar10_std = (0.2471, 0.2435, 0.2616)  # equals np.std(train_set.train_data, axis=(0,1,2))/255

# Inference-time transform: a single project-local ToTensor step (presumably
# numpy batch -> torch tensor; see classifiers.MixMatch.dataset.waterfowl).
test_transform = transforms.Compose([dataset.ToTensor()])

# Target device for models and input batches (always CUDA here).
device = torch.device('cuda')
def window_jittering(box, mega_image, num_box=5):
    """Return ``box`` followed by ``num_box - 1`` randomly jittered copies.

    Each side of the box is moved independently by a random fraction, drawn
    from [-0.1, 0.3), of the box's width (x sides) or height (y sides).
    Positive fractions enlarge the box and negative ones shrink it. Jittered
    coordinates are clipped to the image bounds.

    Args:
        box: ``[x1, y1, x2, y2]`` pixel coordinates.
        mega_image: the full image. Either a numpy array of shape
            ``(H, W[, C])`` (what ``prepare_data_mixmatch`` slices) or a
            PIL-style image whose ``.size`` is ``(width, height)``.
        num_box: total number of boxes returned, including the original.

    Returns:
        List of ``[x1, y1, x2, y2]`` int boxes. The first entry is ``box``
        itself, truncated to ints.
    """
    shape = getattr(mega_image, 'shape', None)
    if shape is not None:
        # numpy arrays are (rows, cols[, channels]). Their ``.size`` is the
        # total element count, not (width, height), so it cannot be unpacked.
        h, w = shape[:2]
    else:
        w, h = mega_image.size
    x1, y1, x2, y2 = box
    box_w = x2 - x1
    box_h = y2 - y1
    jittering_box = [[int(x1), int(y1), int(x2), int(y2)]]
    for _ in range(num_box - 1):
        # Keep the draw order (x1, x2, y1, y2) so seeded runs reproduce.
        x1_random = random.random() * 0.4 - 0.1
        x2_random = random.random() * 0.4 - 0.1
        y1_random = random.random() * 0.4 - 0.1
        y2_random = random.random() * 0.4 - 0.1
        jittering_box.append([
            int(max(x1 - box_w * x1_random, 0)),
            int(max(y1 - box_h * y1_random, 0)),
            int(min(x2 + box_w * x2_random, w)),
            int(min(y2 + box_h * y2_random, h)),
        ])
    return jittering_box
def transpose(x, source='NHWC', target='NCHW'):
    """Permute the axes of ``x`` from layout ``source`` to layout ``target``.

    Each layout string names one axis per character. With the defaults a
    channels-last batch ``(N, H, W, C)`` becomes channels-first
    ``(N, C, H, W)``.
    """
    axes = tuple(source.index(axis_name) for axis_name in target)
    return x.transpose(axes)
@spaces.GPU
def predict_methods(mixmatch_model, box, mega_image, method='baseline'):
    """Classify the bird inside ``box`` and return its category name.

    Args:
        mixmatch_model: classifier returning one row of per-class scores per
            input crop. Presumably the 21-class model from
            ``get_model_mixmatch``; indices are looked up in ``category_dict``.
        box: ``[x1, y1, x2, y2]`` integer pixel box inside ``mega_image``.
        mega_image: the full image as an ``(H, W, C)`` array (it is sliced by
            ``prepare_data_mixmatch``).
        method: how the prediction is made:
            * ``'baseline'``: top-1 class of a single crop of ``box``.
            * ``'voting'``: majority vote over ``box`` plus 4 jittered crops.
            * ``'prob_sum'``: class with the largest summed top-5 softmax
              probability over ``box`` plus 9 jittered crops.

    Returns:
        A label string from ``category_dict``.

    Raises:
        ValueError: if ``method`` is not one of the names above.
    """
    # Pure inference: skip autograd bookkeeping for every forward pass.
    with torch.no_grad():
        if method == 'baseline':
            out = mixmatch_model(prepare_data_mixmatch(mega_image, box))
            _, pred = out.topk(1, 1, True, True)
            return category_dict[str(pred[0, 0].item())]

        if method == 'voting':
            predictions = []
            for jitter_box in window_jittering(box, mega_image):
                out = mixmatch_model(prepare_data_mixmatch(mega_image, jitter_box))
                _, pred_class = out.topk(1, 1, True, True)
                predictions.append(category_dict[str(pred_class[0, 0].item())])
            # most_common() is order-stable, so a tie goes to the label seen
            # first (the un-jittered box). max(set(...), key=list.count) would
            # instead break ties by set iteration order, which changes with
            # string hash randomisation between runs.
            return Counter(predictions).most_common(1)[0][0]

        if method == 'prob_sum':
            score_list = np.zeros(len(category_dict))
            for jitter_box in window_jittering(box, mega_image, 10):
                logits = mixmatch_model(prepare_data_mixmatch(mega_image, jitter_box))
                out = F.softmax(logits, dim=1)
                pred_prob, pred_class = out.topk(5, 1, True, True)
                classes = pred_class[0].cpu().numpy()
                probs = pred_prob[0].cpu().numpy()
                for cls_idx, prob in zip(classes, probs):
                    score_list[cls_idx] += prob
            return category_dict[str(np.argmax(score_list))]

    raise ValueError(
        "Unknown method {!r}; expected 'baseline', 'voting' or 'prob_sum'".format(method))
@spaces.GPU
def create_model(ema=False):
    """Build a 21-class WideResNet and move it to the GPU.

    With ``ema`` truthy, every parameter is detached in place
    (``Tensor.detach_``) so it no longer requires gradients. Such a model is
    meant to receive weights from a checkpoint (see ``get_model_mixmatch``)
    rather than be trained by backprop.
    """
    net = wmodels.WideResNet(num_classes=21).cuda()
    if not ema:
        return net
    for weight in net.parameters():
        weight.detach_()
    return net
@spaces.GPU
def get_model_mixmatch(checkpoint_dir, ema=True):
    """Load a MixMatch WideResNet from a checkpoint and return it in eval mode.

    Args:
        checkpoint_dir: path handed straight to ``torch.load`` (a checkpoint
            file, despite the name). ``torch.load`` unpickles, so only load
            trusted checkpoints.
        ema: if truthy, load the ``'ema_state_dict'`` weights into a detached
            model; otherwise load ``'state_dict'``.

    Returns:
        The model on ``device`` with ``.eval()`` applied.
    """
    checkpoint = torch.load(checkpoint_dir)
    net = create_model(ema=True) if ema else create_model()
    net.load_state_dict(checkpoint['ema_state_dict' if ema else 'state_dict'])
    net.to(device)
    print('loaded ema_model' if ema else 'loaded model')
    return net.eval()
def normalize(x, mean=cifar10_mean, std=cifar10_std):
    """Standardise 0-255 pixel data with per-channel ``mean`` / ``std``.

    ``mean`` and ``std`` are given in [0, 1] units and are scaled by 255
    here. The input is first copied to float32, so the caller's array is
    never modified.

    Returns:
        float32 array ``(x - 255*mean) * (1 / (255*std))``, where mean and std
        broadcast against the last axis of ``x``.
    """
    pixels = np.array(x, np.float32)
    mean_arr = np.array(mean, np.float32)
    std_arr = np.array(std, np.float32)
    return (pixels - mean_arr * 255) * (1.0 / (255 * std_arr))
def prepare_data_mixmatch(mega_image, box):
    """Crop ``box`` from ``mega_image`` and turn it into a batch of one for the model.

    The box is clipped to the image, the crop is resized to 32x32 with
    ``cv2.resize``, normalised with the CIFAR-10 statistics, reordered from
    NHWC to NCHW, passed through ``test_transform`` and moved to ``device``.

    Args:
        mega_image: the full image as a 3-D ``(H, W, C)`` array.
        box: ``[x1, y1, x2, y2]`` pixel coordinates (may extend past the image).
    """
    left, top, right, bottom = box
    img_h, img_w, _ = np.shape(mega_image)
    rows = slice(max(top, 0), min(bottom, img_h))
    cols = slice(max(left, 0), min(right, img_w))
    patch = cv2.resize(mega_image[rows, cols, :], (32, 32))
    nchw_batch = transpose(normalize(np.array([patch])))
    return test_transform(nchw_batch).to(device)
def plt2arr(plt):
    """Rasterise the current pyplot figure into an RGB image array.

    Args:
        plt: the ``matplotlib.pyplot`` module (or any object exposing
            ``gca()``). Its current figure is drawn on an Agg canvas.

    Returns:
        A ``(height, width, 3)`` uint8 numpy array that owns its data.
    """
    canvas = plt.gca().figure.canvas
    canvas.draw()
    # FigureCanvasAgg.tostring_rgb() was deprecated in Matplotlib 3.8 and
    # removed in 3.10. buffer_rgba() is the supported replacement and already
    # has shape (height, width, 4). Drop the alpha channel and copy, so the
    # result does not alias the renderer's buffer after the figure is closed.
    rgba = np.asarray(canvas.buffer_rgba())
    return np.ascontiguousarray(rgba[..., :3])
def draw_bar_chart(num_dict):
    """Render a bar chart of per-category counts and return it as an image.

    Args:
        num_dict: mapping of category name -> count (e.g. a ``Counter`` of
            predicted labels). Bars are drawn in the mapping's iteration order,
            and the title shows the total of all counts.

    Returns:
        The rendered chart as the uint8 numpy array produced by ``plt2arr``.
    """
    cates = list(num_dict.keys())
    nums = list(num_dict.values())
    fig = plt.figure(figsize=(10, 5))
    try:
        plt.bar(cates, nums, color='maroon', width=0.4)
        plt.xlabel("waterfowl categories")
        plt.ylabel("No. of waterfowl per category")
        plt.title("The number of waterfowls of different categories(Total number: {})".format(str(sum(nums))))
        return plt2arr(plt)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        # Without this, each call would leak one figure for the life of the
        # process.
        plt.close(fig)
@spaces.GPU
def mixmatch_classifier_inference(model_dir, mega_image, bbox_list):
    """Classify every detection in an image and chart the category counts.

    Args:
        model_dir: checkpoint path passed to ``get_model_mixmatch`` (so the
            EMA weights are loaded by default).
        mega_image: the full image the boxes refer to, as an ``(H, W, C)``
            array.
        bbox_list: iterable of ``[x1, y1, x2, y2, conf]`` detections.
            ``conf`` is ignored, and coordinates are converted with ``int()``.

    Returns:
        The bar-chart image returned by ``draw_bar_chart``.
    """
    classifier = get_model_mixmatch(model_dir)
    labels = [
        predict_methods(classifier, [int(x1), int(y1), int(x2), int(y2)], mega_image)
        for x1, y1, x2, y2, _conf in bbox_list
    ]
    return draw_bar_chart(Counter(labels))