# app.py — Car Modernity Score demo (Hugging Face Space)
# Provenance note: uploaded by sebastianM, "Update app.py", commit bc5f7a7
# (the three lines above were web-page residue and are preserved here as a comment)
import torch
import os
### Installations ###
#####################
# HACK: detectron2 ships no pre-built wheel for this environment, so it is
# installed from source at app start-up (common Hugging Face Spaces pattern).
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
### Import Libraries ###
#########################
# general
import gradio as gr
import numpy as np
import cv2
from PIL import Image
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog, DatasetCatalog
# import torchvision utilities
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.functional as F
### Detectron Model ###
#######################
# Mask R-CNN (ResNet50-FPN, 3x schedule) from the detectron2 model zoo,
# used below to locate and segment cars in the uploaded image.
cfg = get_cfg()
# run inference on CPU (no GPU assumed on the hosting instance)
cfg.MODEL.DEVICE='cpu'
# Load model architecture/config (must happen before overriding fields below)
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # keep detections with score >= 0.5
# Load pretrained COCO weights matching the config above
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
predictor = DefaultPredictor(cfg)
# COCO "thing" class names of the training dataset; used to filter 'car' detections
metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
class_catalog = metadata.thing_classes
### ResNet18 Model ###
######################
# ImageNet-pretrained ResNet18 whose final FC layer is replaced by a 5-way
# head (ordinal modernity classes 0-4), then loaded with fine-tuned weights.
pretrained_model = models.resnet18(pretrained=True)
IN_FEATURES = pretrained_model.fc.in_features
OUTPUT_DIM = 5  # number of modernity classes
final_fc = torch.nn.Linear(IN_FEATURES, OUTPUT_DIM)
pretrained_model.fc = final_fc
# Load fine tuned weights (checkpoint file must ship alongside this script);
# map to CPU so loading works regardless of where the weights were saved
pretrained_model.load_state_dict(torch.load('model_modernity_advanced.pt', map_location = 'cpu'))
# inference mode: fixes batch-norm statistics and disables dropout
pretrained_model.eval()
### Test Transforms ###
#######################
# Standard ImageNet preprocessing matching the ResNet18 pretraining:
# resize (shorter side to 224), convert to tensor, normalize per channel.
pretrained_size = 224
pretrained_means = [0.485, 0.456, 0.406]
pretrained_stds = [0.229, 0.224, 0.225]
test_transforms = transforms.Compose([
transforms.Resize(pretrained_size),
transforms.ToTensor(),
transforms.Normalize(mean = pretrained_means,
std = pretrained_stds)
])
### Car Modernity Function ###
##############################
def modernity_pred(logits):
    """Convert class logits into a continuous modernity score.

    Treats the classifier output as an ordinal distribution: softmax
    probabilities are weighted by their class indices (0..C-1) and summed,
    yielding the expected class index per sample.

    Args:
        logits: tensor of shape (batch, num_classes) with raw model outputs.

    Returns:
        Tensor of shape (batch,) with the expected-class ("modernity") score.
    """
    p = F.softmax(logits, dim = 1)
    # derive the index weights from the logits themselves instead of the
    # previous hard-coded [[0,1,2,3,4]], so any number of ordinal classes works
    groups = torch.arange(logits.shape[-1], dtype = p.dtype, device = p.device)
    score = (p * groups).sum(dim = 1)
    return score
### Image Classification function ###
#####################################
def image_classifier(inp):
    """Predict the modernity score of the largest car in an image.

    Pipeline: run Mask R-CNN over the full image, keep only 'car'
    detections, pick the one with the largest mask pixel count, crop it,
    composite it onto a white background using its instance mask, then
    score the composite with the fine-tuned ResNet18.

    Args:
        inp: HxWx3 uint8 numpy array (RGB), as supplied by the gradio
            "image" input — TODO confirm channel order matches training.

    Returns:
        (PIL.Image, str): the composited car crop and a score message, or
        a white placeholder image plus an explanatory message when no car
        is detected.
    """
    ### Detect in full image ###
    ############################
    output = predictor(inp)
    instances = output['instances']
    # map predicted class indices to their COCO names
    classes = [class_catalog[i] for i in instances.pred_classes.detach().cpu()]
    is_car = np.array(classes) == 'car'
    # guard clause: nothing to score when no car was detected
    if not is_car.any():
        message = 'no car was detected in image'
        # White image as place holder
        placeholder = Image.fromarray(np.ones(shape = (100, 150, 3), dtype = np.uint8) * 255, mode = 'RGB')
        return placeholder, message
    ### Select largest car by pixel count of its predicted mask ###
    ###############################################################
    pred_masks = instances.pred_masks[is_car].detach().cpu()
    # convert the argmax tensor to a plain int once, up front
    idx_largest_car = int(pred_masks.reshape(pred_masks.shape[0], -1).sum(axis = 1).argmax())
    ### Crop image to that car's bounding box ###
    #############################################
    pred_boxes = instances.pred_boxes[is_car][idx_largest_car]
    box = list(pred_boxes)[0].detach().cpu().numpy()
    x_min, y_min, x_max, y_max = (int(v) for v in box[:4])
    crop_img = inp[y_min:y_max, x_min:x_max, :]
    ### Change Background to White ###
    ##################################
    # crop as PIL image
    cropped = Image.fromarray(crop_img.astype('uint8'), 'RGB')
    # instance mask of the selected car, cropped to the same box
    pred_mask_crop = Image.fromarray((pred_masks[idx_largest_car].numpy() * 255).astype('uint8'))
    pred_mask_crop = pred_mask_crop.crop((x_min, y_min, x_max, y_max))
    # white background sized like the cropped mask
    s = np.array(pred_mask_crop).shape
    background = Image.fromarray(np.ones(shape = (s[0], s[1], 3), dtype = np.uint8) * 255, mode = 'RGB')
    # alpha mask: car pixels opaque, everything else transparent
    new_alpha_mask = Image.new('L', background.size, color = 0)
    new_alpha_mask.paste(pred_mask_crop)
    composite = Image.composite(cropped, background, new_alpha_mask)
    ### Predict modernity ###
    #########################
    img_trans = test_transforms(composite).unsqueeze(0)
    with torch.no_grad():
        out = pretrained_model(img_trans)
    mod_score = modernity_pred(out)
    return composite, f'Modernity score: {round(float(mod_score), 5)}'
### Gradio App ###
##################
# Gradio UI: one image input -> (composited car image, modernity score text).
title = "Prediction of Car Modernity Score"
description = "Upload image of car to get prediction of the modernity score. If image includes multiple cars, car with largest pixel count is extracted"
# example images expected to ship alongside this script
examples = [['test_img_1.jpg'], ['test_img_2.jpg'], ['test_img_3.jpg'], ['test_img_4.jpg'], ['test_img_5.jpg'], ['test_img_6.jpeg'], ['test_img_7.jpeg']]
classif_app = gr.Interface(fn=image_classifier,
inputs="image",
outputs=["image", "label"],
title = title,
description = description,
examples = examples)
# starts the web server (blocking call)
classif_app.launch()