File size: 3,312 Bytes
6b2d891
9a30e62
2e4ea99
 
 
adca53b
1094aba
2e4ea99
458c5d4
 
6b2d891
2e4ea99
fa10257
a64c8d1
958d113
2e4ea99
d01979a
2e4ea99
4616018
2e4ea99
 
 
 
b8316ce
 
d54d7ce
2e4ea99
 
 
 
 
 
 
 
 
 
919da0d
2e4ea99
d01979a
cacc64e
2e4ea99
 
 
 
 
 
 
dbe5043
e069f67
2e4ea99
 
 
 
 
 
 
1094aba
146be9b
2e4ea99
6f792de
 
 
 
 
 
2e4ea99
 
 
e8d20c7
919da0d
 
 
ce9caa8
e8d20c7
ce9caa8
919da0d
ce9caa8
 
 
919da0d
 
ce9caa8
e069f67
359c4ce
e069f67
d01979a
2e4ea99
612a23d
2e4ea99
 
d01979a
2e4ea99
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95

import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import models, transforms

# -- torch/CUDA version strings (kept for reference; unused on this CPU-only host)
#TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
#CUDA_VERSION = torch.__version__.split("+")[-1]

# -- install detectron2 at startup: it is not pip-installable from PyPI, so this
# Spaces-style script installs it (and the pyyaml version it needs) via shell
# BEFORE the `import detectron2` below. Do not reorder these lines.
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
os.system('pip install pyyaml==5.1')

import detectron2

from detectron2.utils.logger import setup_logger  # configures detectron2's logging handlers
# from google.colab.patches import cv2_imshow

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog

# Initialize detectron2's logger so model-zoo download/inference messages are visible.
setup_logger()

# -- build the COCO-pretrained Mask R-CNN instance-segmentation predictor.
# Weights are fetched from the detectron2 model zoo on first run.
cfg = get_cfg()
# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
# Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.DEVICE = 'cpu'  # Spaces hardware is CPU-only; keep inference off CUDA
predictor = DefaultPredictor(cfg)
# (Removed a dead, unassigned triple-quoted string holding a broken smoke-test:
# its os.system(wget ...) call was missing quotes and would not run if enabled.)
# -- load the design-modernity classification model.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
# map_location="cpu" is required so a checkpoint saved on a GPU machine still loads
# on this CPU-only host (without it, torch.load raises when CUDA is unavailable).
DesignModernityModel = torch.load("DesignModernityModel.pt", map_location="cpu")
DesignModernityModel.eval()  # inference mode: disables dropout, freezes batch-norm stats

# Class labels in the index order the model emits logits — TODO confirm against training.
LABELS = ['2000-2003', '2006-2008', '2009-2011', '2012-2014', '2015-2018']
n_labels = len(LABELS)

# Standard ImageNet normalization constants used by torchvision-pretrained backbones.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Classifier preprocessing: resize, tensorize, normalize.
# NOTE(review): Resize(224) scales only the SHORTER edge (output is not square).
# If the model expects exactly 224x224, use Resize((224, 224)) or add
# CenterCrop(224) — confirm against how the model was trained.
carTransforms = transforms.Compose([transforms.Resize(224),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=MEAN, std=STD)])

def classifyCar(im):
  """Detect objects in a car photo and classify its design-modernity era.

  Args:
    im: RGB image as a numpy uint8 array (gradio's default 'image' input).

  Returns:
    A 2-tuple of (annotated PIL image, {label: score}) on success; on any
    failure, the unmodified input plus a single-key dict naming the stage
    that failed ("error1" = detection, "error2"/"error3" = visualization).
  """
  # detectron2's DefaultPredictor expects a BGR ndarray, so flip channels for
  # detection and keep the original RGB array for visualization/classification.
  # (The old code converted to PIL first and called transforms.ToTensor(im),
  # which is a TypeError — ToTensor's constructor takes no arguments.)
  try:
    with torch.no_grad():
      outputs = predictor(im[:, :, ::-1])
  except Exception:
    return im, {"error1": 0.5}
  try:
    # Visualizer wants RGB input; pull class metadata from the training set cfg.
    v = Visualizer(im, MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
  except Exception:
    return im, {"error2": 0.5}
  try:
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
  except Exception:
    return im, {"error3": 0.5}

  # Classification branch: PIL -> resize/normalize -> add batch dimension.
  pil_im = Image.fromarray(im.astype('uint8'), 'RGB')
  batch = carTransforms(pil_im).unsqueeze(0)
  with torch.no_grad():
    # dim=0: softmax over the class logits of the single (unbatched) output row.
    scores = torch.nn.functional.softmax(DesignModernityModel(batch)[0], dim=0)
  annotated = Image.fromarray(np.uint8(out.get_image())).convert('RGB')
  return annotated, {LABELS[i]: float(scores[i]) for i in range(n_labels)}

#examples = [[example_img.jpg], [example_img2.jpg]]  # must be uploaded in repo
# NOTE(review): if re-enabled, the example paths above need to be quoted strings.

# -- wire the classifier into a gradio UI: one image input, an annotated image
# plus a label-score breakdown as outputs; launch() starts the web server.
interface = gr.Interface(classifyCar, inputs='image', outputs=['image','label'], cache_examples=False, title='VW Up or Fiat 500')
interface.launch()