# Hugging Face Spaces app: detect the largest car in an uploaded image and
# predict its "modernity" score with a fine-tuned ResNet18.
# (Removed page-scrape residue: "Spaces:" / "Runtime error" status lines.)
import torch
import os
### Installations ###
#####################
# NOTE(review): detectron2 is installed at import time via pip — presumably a
# hosted-demo (Spaces) pattern. Requires network access on startup; confirm
# this is intended before reusing elsewhere.
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
### Import Libraries ###
#########################
# general
import gradio as gr
import numpy as np
import cv2
from PIL import Image
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog, DatasetCatalog
# import torchvision utilities
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.functional as F
### Detectron Model ###
#######################
# Build a CPU-only detectron2 Mask R-CNN predictor for instance segmentation.
cfg = get_cfg()
cfg.MODEL.DEVICE='cpu'  # force CPU — no GPU assumed in the deployment environment
# Load the COCO-pretrained Mask R-CNN (ResNet-50 FPN, 3x schedule) config
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # keep only detections with score >= 0.5
# Load pretrained weights matching the config above
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
predictor = DefaultPredictor(cfg)
# COCO class names, used later to pick out 'car' detections by label
metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
class_catalog = metadata.thing_classes
### ResNet18 Model ###
######################
# ImageNet-pretrained ResNet18 backbone with its classifier head replaced by a
# fresh 5-way linear layer (one output per modernity group).
pretrained_model = models.resnet18(pretrained=True)
NUM_GROUPS = 5
pretrained_model.fc = torch.nn.Linear(pretrained_model.fc.in_features, NUM_GROUPS)
# Restore the fine-tuned weights (CPU tensors) and switch to inference mode.
pretrained_model.load_state_dict(torch.load('model_modernity_advanced.pt', map_location='cpu'))
pretrained_model.eval()
### Test Transforms ###
#######################
# Inference-time preprocessing: resize to the 224px input ResNet18 expects,
# convert to a tensor, and normalize with the standard ImageNet statistics.
test_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
### Car Modernity Function ###
##############################
def modernity_pred(logits):
    """Convert class logits into a scalar modernity score per sample.

    Treats the classes as ordinal groups 0..C-1 and returns the
    softmax-weighted expected group index.

    Parameters
    ----------
    logits : torch.Tensor
        Raw model outputs of shape (batch, num_classes).

    Returns
    -------
    torch.Tensor
        Expected-group score, shape (batch,), values in [0, num_classes - 1].
    """
    p = F.softmax(logits, dim=1)
    # Generalized from the hard-coded [[0, 1, 2, 3, 4]] so any class count
    # works; identical result for the original 5-class case.
    groups = torch.arange(logits.shape[1], dtype=p.dtype, device=p.device)
    return (p * groups).sum(dim=1)
### Image Classification function ###
#####################################
def image_classifier(inp):
    """Detect the largest car in `inp`, mask out its background, and score it.

    Parameters
    ----------
    inp : numpy.ndarray
        RGB image as an HxWx3 array (the format gradio's "image" input
        provides — TODO confirm channel order matches the predictor's config).

    Returns
    -------
    (PIL.Image.Image, str)
        The cropped, white-background car image and a modernity-score message,
        or a white placeholder and an explanatory message when no car is found.
    """
    # Run detectron2 instance segmentation on the full image.
    instances = predictor(inp)['instances']
    # Map predicted class ids to human-readable COCO labels.
    classes = [class_catalog[i] for i in instances.pred_classes.detach().cpu()]
    is_car = np.array(classes) == 'car'
    if not is_car.any():  # guard clause (was `if is_car.any() == True:` nesting)
        message = 'no car was detected in image'
        # White image as place holder
        placeholder = Image.fromarray(
            np.ones(shape=(100, 150, 3), dtype=np.uint8) * 255, mode='RGB')
        return placeholder, message

    # Among the detected cars, pick the one with the largest mask pixel count.
    pred_masks = instances.pred_masks[is_car].detach().cpu()
    idx_largest_car = pred_masks.reshape(pred_masks.shape[0], -1).sum(axis=1).argmax()

    ### crop image by according region of interest
    ##############################################
    pred_boxes = instances.pred_boxes[is_car][int(idx_largest_car)]
    box = list(pred_boxes)[0].detach().cpu().numpy()
    x_min = int(box[0])
    y_min = int(box[1])
    x_max = int(box[2])
    y_max = int(box[3])
    crop_img = inp[y_min:y_max, x_min:x_max, :]

    ### Change Background to White ###
    ##################################
    # convert the crop to PIL format
    cropped = Image.fromarray(crop_img.astype('uint8'), 'RGB')
    # binary mask of the selected car, scaled to 0/255 for an 'L' image
    pred_mask_crop = Image.fromarray(
        (pred_masks[idx_largest_car].numpy() * 255).astype('uint8'))
    # crop the mask to the same bounding box as the image crop
    pred_mask_crop = pred_mask_crop.crop((x_min, y_min, x_max, y_max))
    # white background matching the crop size
    s = np.array(pred_mask_crop).shape
    background = Image.fromarray(
        np.ones(shape=(s[0], s[1], 3), dtype=np.uint8) * 255, mode='RGB')
    # alpha mask: car pixels opaque, everything else transparent (white)
    new_alpha_mask = Image.new('L', background.size, color=0)
    new_alpha_mask.paste(pred_mask_crop)
    composite = Image.composite(cropped, background, new_alpha_mask)

    ### Predict modernity with the fine-tuned ResNet18
    img_trans = test_transforms(composite).unsqueeze(0)
    with torch.no_grad():
        out = pretrained_model(img_trans)
    mod_score = modernity_pred(out)
    return composite, f'Modernity score: {round(float(mod_score), 5)}'
### Gradio App ###
##################
# Wire the classifier into a simple image-in / (image, label)-out interface
# and start the web server.
description = "Upload image of car to get prediction of the modernity score. If image includes multiple cars, car with largest pixel count is extracted"
example_files = [[f'test_img_{i}.jpg'] for i in range(1, 6)] + [['test_img_6.jpeg'], ['test_img_7.jpeg']]
demo = gr.Interface(
    fn=image_classifier,
    inputs="image",
    outputs=["image", "label"],
    title="Prediction of Car Modernity Score",
    description=description,
    examples=example_files,
)
demo.launch()