Update main.py
Browse files
main.py
CHANGED
|
@@ -38,8 +38,31 @@ print(outputs["instances"].pred_classes)
|
|
| 38 |
print(outputs["instances"].pred_boxes)
|
| 39 |
|
| 40 |
# -- load Mask R-CNN model for segmentation
|
| 41 |
-
|
| 42 |
|
|
|
|
|
|
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
|
|
|
|
| 38 |
print(outputs["instances"].pred_boxes)
|
| 39 |
|
| 40 |
# -- load Mask R-CNN model for segmentation
|
| 41 |
+
DesignModernityModel = torch.load("DesignModernityModel.pt")
|
| 42 |
|
| 43 |
+
#INPUT_FEATURES = DesignModernityModel.fc.in_features
|
| 44 |
+
#linear = nn.linear(INPUT_FEATURES, 5)
|
| 45 |
|
| 46 |
+
DesignModernityModel.eval() # set state of the model to inference
|
| 47 |
+
|
| 48 |
+
LABELS = ['2003-2006', 'VW Up!']
|
| 49 |
+
|
| 50 |
+
carTransforms = transforms.Compose([
|
| 51 |
+
transforms.RandomResizedCrop(224),
|
| 52 |
+
...
|
| 53 |
+
])
|
| 54 |
+
|
| 55 |
+
def classifyCar(im):
|
| 56 |
+
im = Image.fromarray(im.astype('uint8'), 'RGB')
|
| 57 |
+
im = carTransforms(im).unsqueeze(0) # transform and add batch dimension
|
| 58 |
+
with torch.no_grad():
|
| 59 |
+
scores = torch.nn.functional.softmax(model(im)[0])
|
| 60 |
+
return {LABELS[i]: float(scores[i]) for i in range(2)}
|
| 61 |
+
|
| 62 |
+
examples = [[example_img.jpg], [example_img2.jpg]] # must be uploaded in repo
|
| 63 |
+
|
| 64 |
+
# create interface for model
|
| 65 |
+
interface = gr.Interface(classifyCar, inputs='Image', outputs='label', cache_examples=False, title='VW Up or Fiat 500', example=examples)
|
| 66 |
+
interface.launch()
|
| 67 |
|
| 68 |
|