JohnJoelMota commited on
Commit
976162b
·
verified ·
1 Parent(s): d74b66f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -2
app.py CHANGED
@@ -4,6 +4,7 @@ from PIL import Image
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
  import gradio as gr
 
7
 
8
  # Load pretrained Mask R-CNN model
9
  model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
@@ -26,8 +27,24 @@ COCO_INSTANCE_CATEGORY_NAMES = [
26
  'hair drier', 'toothbrush'
27
  ]
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  # Detection and segmentation function
30
- def segment_objects(image, threshold=0.5):
 
 
 
31
  transform = torchvision.transforms.ToTensor()
32
  img_tensor = transform(image).unsqueeze(0)
33
 
@@ -77,11 +94,16 @@ interface = gr.Interface(
77
  fn=segment_objects,
78
  inputs=[
79
  gr.Image(type="pil", label="Upload Image"),
 
 
 
 
 
80
  gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
81
  ],
82
  outputs=gr.Image(type="filepath", label="Segmented Output"),
83
  title="Mask R-CNN Instance Segmentation",
84
- description="Upload an image to detect and segment objects using a pretrained Mask R-CNN model (TorchVision)."
85
  )
86
 
87
  if __name__ == "__main__":
 
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
  import gradio as gr
7
+ import os
8
 
9
  # Load pretrained Mask R-CNN model
10
  model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
 
27
  'hair drier', 'toothbrush'
28
  ]
29
 
30
+ # Function to load example image or use uploaded image
31
+ def load_image(uploaded_image, example_image):
32
+ if uploaded_image is not None:
33
+ return uploaded_image
34
+ elif example_image and example_image != "None":
35
+ example_path = os.path.join("examples", example_image)
36
+ if os.path.exists(example_path):
37
+ return Image.open(example_path).convert("RGB")
38
+ else:
39
+ raise FileNotFoundError(f"Example image {example_path} not found.")
40
+ else:
41
+ raise ValueError("Please upload an image or select an example image.")
42
+
43
  # Detection and segmentation function
44
+ def segment_objects(uploaded_image, example_image, threshold=0.5):
45
+ # Load the image (either uploaded or example)
46
+ image = load_image(uploaded_image, example_image)
47
+
48
  transform = torchvision.transforms.ToTensor()
49
  img_tensor = transform(image).unsqueeze(0)
50
 
 
94
  fn=segment_objects,
95
  inputs=[
96
  gr.Image(type="pil", label="Upload Image"),
97
+ gr.Dropdown(
98
+ choices=["None", "example1.jpg", "example2.jpg", "example3.jpg", "example4.jpg"],
99
+ value="None",
100
+ label="Select Example Image"
101
+ ),
102
  gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
103
  ],
104
  outputs=gr.Image(type="filepath", label="Segmented Output"),
105
  title="Mask R-CNN Instance Segmentation",
106
+ description="Upload an image or select an example image to detect and segment objects using a pretrained Mask R-CNN model (TorchVision)."
107
  )
108
 
109
  if __name__ == "__main__":