vorkna committed on
Commit
074fce9
·
verified ·
1 Parent(s): 49df311

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -42
app.py CHANGED
@@ -29,15 +29,14 @@ def get_model_instance_segmentation(num_classes):
29
  hidden_layer = 256
30
  # Replace the mask predictor with a new one
31
  model.roi_heads.mask_predictor = MaskRCNNPredictor(
32
- in_features_mask,
33
- hidden_layer,
34
- num_classes
35
  )
36
  return model
37
 
38
- center_x = torch.tensor([-.3, .3])
39
- center_y = torch.tensor([-.3, .3])
40
- gamma = torch.tensor([.9, 1.])
 
41
 
42
 
43
  # Define fisheye augmentation with given parameters
@@ -52,36 +51,37 @@ fisheye_transform = K.RandomFisheye(
52
 
53
  # --- Setup ---
54
  # Check for model file and data directory
55
- if not os.path.exists('maskrcnn_pennfudan.pth'):
56
- raise FileNotFoundError("Model file 'maskrcnn_pennfudan.pth' not found. Please place it in the root directory.")
 
 
57
 
58
- image_dir = 'data/PennFudanPed/PNGImages'
59
  if not os.path.isdir(image_dir):
60
- raise FileNotFoundError(f"Image directory '{image_dir}' not found. Please ensure the data is structured correctly.")
 
 
61
 
62
  # Device and model loading
63
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
64
  # PennFudanPed has 2 classes: background and person
65
  num_classes = 2
66
  model = get_model_instance_segmentation(num_classes)
67
- model.load_state_dict(torch.load('maskrcnn_pennfudan.pth', map_location=device))
68
  model.to(device)
69
  model.eval()
70
 
71
  # Load image paths
72
- image_files = sorted([os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')])
 
 
73
 
74
- def predict_and_draw(image_index):
 
75
  """
76
- Runs prediction on an image and returns the image with masks and boxes drawn.
77
  """
78
- if not image_files:
79
- return None, "No images found in data/PennFudanPed/PNGImages", 0
80
-
81
- image_index = image_index % len(image_files)
82
- image_path = image_files[image_index]
83
-
84
- img = Image.open(image_path).convert("RGB")
85
  img_tensor = F.to_tensor(img)
86
  # image = image[:3, ...].to(torch.float32) / 255.0
87
  img_tensor = fisheye_transform(img_tensor.unsqueeze(0)).squeeze(0)
@@ -90,43 +90,66 @@ def predict_and_draw(image_index):
90
  prediction = model([img_tensor.to(device)])
91
 
92
  pred = prediction[0]
93
-
94
  # Filter predictions by a confidence score
95
- score_threshold = 0.8
96
- high_conf_indices = pred['scores'] > score_threshold
97
- boxes = pred['boxes'][high_conf_indices]
98
- labels = [f"person: {score:.2f}" for score in pred['scores'][high_conf_indices]]
99
- masks = pred['masks'][high_conf_indices]
100
 
101
  # Convert image tensor back to uint8 for drawing functions
102
  img_to_draw = (img_tensor * 255).to(torch.uint8)
103
 
104
  # Draw bounding boxes
105
  if len(boxes) > 0:
106
- img_with_boxes = draw_bounding_boxes(img_to_draw, boxes=boxes, labels=labels, colors="red", width=2)
 
 
107
  else:
108
  img_with_boxes = img_to_draw
109
 
110
  # Draw segmentation masks
111
  if len(masks) > 0:
112
  masks_bool = masks.squeeze(1) > 0.5
113
- img_with_masks = draw_segmentation_masks(img_with_boxes, masks=masks_bool, alpha=0.5, colors="blue")
 
 
114
  else:
115
  img_with_masks = img_with_boxes
116
 
117
  # Convert tensor to PIL Image for Gradio display
118
  final_image = F.to_pil_image(img_with_masks.cpu())
119
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  info_text = f"Displaying image {image_index + 1}/{len(image_files)}: {os.path.basename(image_path)}"
121
  return final_image, info_text, image_index
122
 
 
123
  # --- Gradio App ---
124
  with gr.Blocks() as demo:
125
- gr.Markdown("# Mask R-CNN Pedestrian Detection on PennFudanPed with Fish Eye Augmentation")
126
-
 
 
 
127
  # State to keep track of the current image index
128
  current_index = gr.State(value=-1)
129
-
130
  with gr.Row():
131
  prev_btn = gr.Button("Previous")
132
  next_btn = gr.Button("Next")
@@ -134,7 +157,7 @@ with gr.Blocks() as demo:
134
 
135
  output_image = gr.Image(label="Image with Predictions")
136
  info_text = gr.Textbox(label="Image Info")
137
-
138
  def next_image(index):
139
  new_index = index + 1
140
  return predict_and_draw(new_index)
@@ -142,19 +165,50 @@ with gr.Blocks() as demo:
142
  def prev_image(index):
143
  new_index = index - 1
144
  if new_index < 0:
145
- new_index = len(image_files) - 1 # Wrap around
146
  return predict_and_draw(new_index)
147
-
148
  def random_image():
149
  new_index = random.randint(0, len(image_files) - 1)
150
  return predict_and_draw(new_index)
151
 
152
- next_btn.click(next_image, inputs=current_index, outputs=[output_image, info_text, current_index])
153
- prev_btn.click(prev_image, inputs=current_index, outputs=[output_image, info_text, current_index])
154
- random_btn.click(random_image, inputs=None, outputs=[output_image, info_text, current_index])
155
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  # Load the first image on startup
157
- demo.load(lambda: next_image(-1), inputs=None, outputs=[output_image, info_text, current_index])
 
 
 
 
158
 
159
  if __name__ == "__main__":
160
  demo.launch()
 
29
  hidden_layer = 256
30
  # Replace the mask predictor with a new one
31
  model.roi_heads.mask_predictor = MaskRCNNPredictor(
32
+ in_features_mask, hidden_layer, num_classes
 
 
33
  )
34
  return model
35
 
36
+
37
+ center_x = torch.tensor([-0.3, 0.3])
38
+ center_y = torch.tensor([-0.3, 0.3])
39
+ gamma = torch.tensor([0.9, 1.0])
40
 
41
 
42
  # Define fisheye augmentation with given parameters
 
51
 
52
  # --- Setup ---
53
  # Check for model file and data directory
54
+ if not os.path.exists("maskrcnn_pennfudan.pth"):
55
+ raise FileNotFoundError(
56
+ "Model file 'maskrcnn_pennfudan.pth' not found. Please place it in the root directory."
57
+ )
58
 
59
+ image_dir = "data/PennFudanPed/PNGImages"
60
  if not os.path.isdir(image_dir):
61
+ raise FileNotFoundError(
62
+ f"Image directory '{image_dir}' not found. Please ensure the data is structured correctly."
63
+ )
64
 
65
  # Device and model loading
66
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
67
  # PennFudanPed has 2 classes: background and person
68
  num_classes = 2
69
  model = get_model_instance_segmentation(num_classes)
70
+ model.load_state_dict(torch.load("maskrcnn_pennfudan.pth", map_location=device))
71
  model.to(device)
72
  model.eval()
73
 
74
  # Load image paths
75
+ image_files = sorted(
76
+ [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(".png")]
77
+ )
78
 
79
+
80
+ def predict_on_image(img):
81
  """
82
+ Runs prediction on a PIL image and returns the image with masks and boxes drawn.
83
  """
84
+ img = img.convert("RGB")
 
 
 
 
 
 
85
  img_tensor = F.to_tensor(img)
86
  # image = image[:3, ...].to(torch.float32) / 255.0
87
  img_tensor = fisheye_transform(img_tensor.unsqueeze(0)).squeeze(0)
 
90
  prediction = model([img_tensor.to(device)])
91
 
92
  pred = prediction[0]
93
+
94
  # Filter predictions by a confidence score
95
+ score_threshold = 0.7
96
+ high_conf_indices = pred["scores"] > score_threshold
97
+ boxes = pred["boxes"][high_conf_indices]
98
+ labels = [f"person: {score:.2f}" for score in pred["scores"][high_conf_indices]]
99
+ masks = pred["masks"][high_conf_indices]
100
 
101
  # Convert image tensor back to uint8 for drawing functions
102
  img_to_draw = (img_tensor * 255).to(torch.uint8)
103
 
104
  # Draw bounding boxes
105
  if len(boxes) > 0:
106
+ img_with_boxes = draw_bounding_boxes(
107
+ img_to_draw, boxes=boxes, labels=labels, colors="red", width=2
108
+ )
109
  else:
110
  img_with_boxes = img_to_draw
111
 
112
  # Draw segmentation masks
113
  if len(masks) > 0:
114
  masks_bool = masks.squeeze(1) > 0.5
115
+ img_with_masks = draw_segmentation_masks(
116
+ img_with_boxes, masks=masks_bool, alpha=0.5, colors="blue"
117
+ )
118
  else:
119
  img_with_masks = img_with_boxes
120
 
121
  # Convert tensor to PIL Image for Gradio display
122
  final_image = F.to_pil_image(img_with_masks.cpu())
123
+ return final_image
124
+
125
+
126
+ def predict_and_draw(image_index):
127
+ """
128
+ Runs prediction on an image from the dataset and returns the image with masks and boxes drawn.
129
+ """
130
+ if not image_files:
131
+ return None, "No images found in data/PennFudanPed/PNGImages", 0
132
+
133
+ image_index = image_index % len(image_files)
134
+ image_path = image_files[image_index]
135
+
136
+ img = Image.open(image_path)
137
+ final_image = predict_on_image(img)
138
+
139
  info_text = f"Displaying image {image_index + 1}/{len(image_files)}: {os.path.basename(image_path)}"
140
  return final_image, info_text, image_index
141
 
142
+
143
  # --- Gradio App ---
144
  with gr.Blocks() as demo:
145
+ gr.Markdown(
146
+ "# Mask R-CNN Pedestrian Detection on PennFudanPed with Fish Eye Augmentation"
147
+ )
148
+
149
+ gr.Markdown("### Browse Dataset Images")
150
  # State to keep track of the current image index
151
  current_index = gr.State(value=-1)
152
+
153
  with gr.Row():
154
  prev_btn = gr.Button("Previous")
155
  next_btn = gr.Button("Next")
 
157
 
158
  output_image = gr.Image(label="Image with Predictions")
159
  info_text = gr.Textbox(label="Image Info")
160
+
161
  def next_image(index):
162
  new_index = index + 1
163
  return predict_and_draw(new_index)
 
165
  def prev_image(index):
166
  new_index = index - 1
167
  if new_index < 0:
168
+ new_index = len(image_files) - 1 # Wrap around
169
  return predict_and_draw(new_index)
170
+
171
  def random_image():
172
  new_index = random.randint(0, len(image_files) - 1)
173
  return predict_and_draw(new_index)
174
 
175
+ next_btn.click(
176
+ next_image,
177
+ inputs=current_index,
178
+ outputs=[output_image, info_text, current_index],
179
+ )
180
+ prev_btn.click(
181
+ prev_image,
182
+ inputs=current_index,
183
+ outputs=[output_image, info_text, current_index],
184
+ )
185
+ random_btn.click(
186
+ random_image, inputs=None, outputs=[output_image, info_text, current_index]
187
+ )
188
+
189
+ gr.Markdown("---")
190
+ gr.Markdown("### Or upload your own image")
191
+ input_image = gr.Image(type="pil", label="Upload Image")
192
+ upload_btn = gr.Button("Predict on Uploaded Image")
193
+
194
+ def handle_upload(img):
195
+ if img is None:
196
+ return None, "Please upload an image.", -1
197
+ result = predict_on_image(img)
198
+ return result, "Prediction for uploaded image.", -1
199
+
200
+ upload_btn.click(
201
+ handle_upload,
202
+ inputs=input_image,
203
+ outputs=[output_image, info_text, current_index],
204
+ )
205
+
206
  # Load the first image on startup
207
+ demo.load(
208
+ lambda: next_image(-1),
209
+ inputs=None,
210
+ outputs=[output_image, info_text, current_index],
211
+ )
212
 
213
  if __name__ == "__main__":
214
  demo.launch()