# Hugging Face Space app — car design-modernity classification.
# (Upload-page residue removed: author "PDG", commit 2b37f77, file size 5.44 kB.)
import os
import numpy as np
import gradio as gr
import torch
from torchvision import models, transforms
from PIL import Image
# -- install detectron2 from source ------------------------------------------------------------------------------
# HACK: runtime pip install is a Hugging Face Spaces workaround — detectron2 has
# no prebuilt wheel, so it is installed when the app boots, then imported below.
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
os.system('pip install pyyaml==5.1')  # detectron2 configs require this pyyaml version
import detectron2
from detectron2.utils.logger import setup_logger
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
import cv2
setup_logger()  # enable detectron2's default logging
# -- load rcnn model ---------------------------------------------------------------------------------------------
cfg = get_cfg()
# load the Mask R-CNN R50-FPN (3x schedule) instance-segmentation config from the model zoo
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # keep detections scoring >= 0.5
# set model weights to the pretrained COCO checkpoint
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.DEVICE = 'cpu'  # CPU-only host
predictor = DefaultPredictor(cfg)  # create segmentation model
# -- load design modernity model for classification --------------------------------------------------------------
# FIX: map_location='cpu' so a checkpoint saved on a GPU still loads on this
# CPU-only host (without it torch.load raises when deserializing CUDA tensors).
DesignModernityModel = torch.load("DesignModernityModel.pt", map_location=torch.device('cpu'))
DesignModernityModel.eval()  # set state of the model to inference
# Set class labels (model-year buckets the classifier predicts)
LABELS = ['2000-2003', '2006-2008', '2009-2011', '2012-2014', '2015-2018']
n_labels = len(LABELS)
# define mean and std dev for normalization (standard ImageNet statistics)
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
# define image transformation steps applied to the cropped car before classification
carTransforms = transforms.Compose([transforms.Resize(224),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=MEAN, std=STD)])
# -- define a function for extraction of the detected car ---------------------------------------------------------
def cropImage(outputs, im, boxes, car_class_true):
    """Cut the largest detected car out of `im` and composite it onto white.

    Args:
        outputs: detectron2 predictor output dict; must contain "instances"
            with `pred_masks`.
        im: source image as an HxWx3 numpy array (RGB assumed — confirm with caller).
        boxes: list of box tensors (x0, y0, x1, y1) for the car detections.
        car_class_true: boolean tensor selecting the car-class instances.

    Returns:
        PIL.Image.Image: the cropped car on a white background.
    """
    # Per-instance boolean masks for the car detections only
    masks = list(np.array(outputs["instances"].pred_masks[car_class_true]))
    # Index of the detection with the largest bounding-box area
    max_idx = torch.tensor([(x[2] - x[0]) * (x[3] - x[1]) for x in boxes]).argmax().item()
    item_mask = masks[max_idx]
    # Tight bounding box of the mask: np.where yields (rows, cols) of True pixels
    segmentation = np.where(item_mask)
    x_min = int(np.min(segmentation[1]))  # minimum x position
    x_max = int(np.max(segmentation[1]))
    y_min = int(np.min(segmentation[0]))
    y_max = int(np.max(segmentation[0]))
    # Crop the image and the mask down to that box
    cropped = Image.fromarray(im[y_min:y_max, x_min:x_max, :], mode='RGB')
    mask = Image.fromarray((item_mask * 255).astype('uint8'))  # bool mask -> 0/255 alpha
    cropped_mask = mask.crop((x_min, y_min, x_max, y_max))
    # White background the same size as the crop.
    # FIX: 'RGB' takes a 3-tuple color — the original passed a 4-tuple (RGBA),
    # which recent Pillow versions reject; the alpha value was ignored anyway.
    height = y_max - y_min
    width = x_max - x_min
    background = Image.new(mode='RGB', size=(width, height), color=(255, 255, 255))
    # Paste the crop and its alpha mask onto canvases matching the background size
    new_fg_image = Image.new('RGB', background.size)
    new_fg_image.paste(cropped)
    new_alpha_mask = Image.new('L', background.size, color=0)
    new_alpha_mask.paste(cropped_mask)
    # Keep car pixels where the mask is set, white elsewhere
    composite = Image.composite(new_fg_image, background, new_alpha_mask)
    return composite
# -- define function for image segmentation and classification --------------------------------------------------------
def classifyCar(im):
    """Segment `im`, crop the largest detected car, and score its design modernity.

    Args:
        im: input image as an HxWx3 numpy array (gradio 'image' input).

    Returns:
        tuple: (PIL image to display, dict of {period label: probability}
        when a car was found, otherwise the string "No car detected").
    """
    # perform segmentation
    outputs = predictor(im)
    # FIX: removed the Visualizer/draw_instance_predictions calls — their
    # result was never used or returned (dead code).
    # COCO class id 2 == "car"; keep only car detections
    car_class_true = outputs["instances"].pred_classes == 2
    boxes = list(outputs["instances"].pred_boxes[car_class_true])
    # if a car was detected, extract it and score its modernity period
    if len(boxes) != 0:
        im2 = cropImage(outputs, im, boxes, car_class_true)
        with torch.no_grad():
            logits = DesignModernityModel(carTransforms(im2).unsqueeze(0))[0]
            # dim=0: normalize over the class axis of the 1-D logits vector
            # (explicit dim avoids torch's implicit-dim deprecation warning)
            scores = torch.nn.functional.softmax(logits, dim=0)
        label = {LABELS[i]: float(scores[i]) for i in range(n_labels)}
    # if no car was detected, show the original image and report it
    else:
        im2 = Image.fromarray(np.uint8(im)).convert('RGB')
        label = "No car detected"
    return im2, label
# -- create interface for model ----------------------------------------------------------------------------------------
# Gradio UI: one image in, (annotated image, label scores) out.
interface = gr.Interface(
    fn=classifyCar,
    inputs='image',
    outputs=['image', 'label'],
    cache_examples=False,
    title='Modernity car classification',
)
interface.launch()