PDG commited on
Commit
dbdb12b
·
1 Parent(s): f14524c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +24 -1
main.py CHANGED
@@ -38,8 +38,31 @@ print(outputs["instances"].pred_classes)
38
  print(outputs["instances"].pred_boxes)
39
 
40
  # -- load Mask R-CNN model for segmentation
41
- MaskRCNN = torch.load("MaskRCNN.pt")
42
 
 
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
 
 
38
  print(outputs["instances"].pred_boxes)
39
 
40
  # -- load Mask R-CNN model for segmentation
41
+ DesignModernityModel = torch.load("DesignModernityModel.pt")
42
 
43
+ #INPUT_FEATURES = DesignModernityModel.fc.in_features
44
+ #linear = nn.linear(INPUT_FEATURES, 5)
45
 
46
+ DesignModernityModel.eval() # set state of the model to inference
47
+
48
+ LABELS = ['2003-2006', 'VW Up!']
49
+
50
+ carTransforms = transforms.Compose([
51
+ transforms.RandomResizedCrop(224),
52
+ ...
53
+ ])
54
+
55
+ def classifyCar(im):
56
+ im = Image.fromarray(im.astype('uint8'), 'RGB')
57
+ im = carTransforms(im).unsqueeze(0) # transform and add batch dimension
58
+ with torch.no_grad():
59
+ scores = torch.nn.functional.softmax(model(im)[0])
60
+ return {LABELS[i]: float(scores[i]) for i in range(2)}
61
+
62
+ examples = [[example_img.jpg], [example_img2.jpg]] # must be uploaded in repo
63
+
64
+ # create interface for model
65
+ interface = gr.Interface(classifyCar, inputs='Image', outputs='label', cache_examples=False, title='VW Up or Fiat 500', example=examples)
66
+ interface.launch()
67
 
68