PDG commited on
Commit
50861f1
·
1 Parent(s): 17cf5c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -6
app.py CHANGED
@@ -66,6 +66,43 @@ carTransforms = transforms.Compose([transforms.Resize(224),
66
  transforms.ToTensor(),
67
  transforms.Normalize(mean=MEAN, std=STD)])
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  def classifyCar(im):
70
  try:
71
  im = cv2.imread(im)
@@ -80,12 +117,12 @@ def classifyCar(im):
80
  boxes = list(outputs["instances"].pred_boxes[car_class_true])
81
 
82
  if len(boxes) != 0:
83
- max_idx = torch.tensor([(x[2] - x[0])*(x[3] - x[1]) for x in boxes]).argmax().item()
84
-
85
- im2 = Image.fromarray(np.uint8(im)).convert('RGB').crop(boxes[max_idx].to(torch.int64).numpy())
86
 
87
- carResize = transforms.Compose([transforms.Resize((224, 224))])
88
- im2 = carResize(im2)
 
 
89
 
90
  with torch.no_grad():
91
  scores = torch.nn.functional.softmax(DesignModernityModel(carTransforms(im2).unsqueeze(0))[0])
@@ -106,7 +143,7 @@ def classifyCar(im):
106
  #examples = [[example_img.jpg], [example_img2.jpg]] # must be uploaded in repo
107
 
108
  # create interface for model
109
- interface = gr.Interface(classifyCar, inputs='image', outputs=['image','label'], cache_examples=False, title='VW Up or Fiat 500')
110
  interface.launch()
111
 
112
 
 
66
  transforms.ToTensor(),
67
  transforms.Normalize(mean=MEAN, std=STD)])
68
 
69
+
70
+ def cropImage(outputs, im, boxes, car_class_true):
71
+ # Get the masks
72
+ #car_class_true = outputs["instances"].pred_classes == 2
73
+ #boxes = list(outputs["instances"].pred_boxes[car_class_true])
74
+ masks = list(np.array(outputs["instances"].pred_masks[car_class_true]))
75
+ max_idx = torch.tensor([(x[2] - x[0])*(x[3] - x[1]) for x in boxes]).argmax().item()
76
+
77
+ # Pick an item to mask
78
+ item_mask = masks[max_idx]
79
+ # Get the true bounding box of the mask
80
+ segmentation = np.where(item_mask == True) # return a list of different position in the bow, which are the actual detected object
81
+ x_min = int(np.min(segmentation[1])) # minimum x position
82
+ x_max = int(np.max(segmentation[1]))
83
+ y_min = int(np.min(segmentation[0]))
84
+ y_max = int(np.max(segmentation[0]))
85
+ # Create cropped image from the just portion of the image we want
86
+ cropped = Image.fromarray(im[y_min:y_max, x_min:x_max, :], mode = 'RGB')
87
+ # Create a PIL Image out of the mask
88
+ mask = Image.fromarray((item_mask * 255).astype('uint8')) ###### change 255
89
+ # Crop the mask to match the cropped image
90
+ cropped_mask = mask.crop((x_min, y_min, x_max, y_max))
91
+ # Load in a background image and choose a paste position
92
+ height = y_max-y_min
93
+ width = x_max-x_min
94
+ background = Image.new(mode='RGB', size=(width, height), color=(255, 255, 255, 0))
95
+ # Create a new foreground image as large as the composite and paste the cropped image on top
96
+ new_fg_image = Image.new('RGB', background.size)
97
+ new_fg_image.paste(cropped)
98
+ # Create a new alpha mask as large as the composite and paste the cropped mask
99
+ new_alpha_mask = Image.new('L', background.size, color=0)
100
+ new_alpha_mask.paste(cropped_mask)
101
+ #composite the foreground and background using the alpha mask
102
+ composite = Image.composite(new_fg_image, background, new_alpha_mask)
103
+ return composite
104
+
105
+
106
  def classifyCar(im):
107
  try:
108
  im = cv2.imread(im)
 
117
  boxes = list(outputs["instances"].pred_boxes[car_class_true])
118
 
119
  if len(boxes) != 0:
120
+ #max_idx = torch.tensor([(x[2] - x[0])*(x[3] - x[1]) for x in boxes]).argmax().item()
 
 
121
 
122
+ #im2 = Image.fromarray(np.uint8(im)).convert('RGB').crop(boxes[max_idx].to(torch.int64).numpy())
123
+ im2 = cropImage(outputs, im, boxes, car_class_true)
124
+ #carResize = transforms.Compose([transforms.Resize((224, 224))])
125
+ #im2 = carResize(im2)
126
 
127
  with torch.no_grad():
128
  scores = torch.nn.functional.softmax(DesignModernityModel(carTransforms(im2).unsqueeze(0))[0])
 
143
  #examples = [[example_img.jpg], [example_img2.jpg]] # must be uploaded in repo
144
 
145
  # create interface for model
146
+ interface = gr.Interface(classifyCar, inputs='image', outputs=['image','label'], cache_examples=False, title='Modernity car classification')
147
  interface.launch()
148
 
149