statmlben commited on
Commit
4522963
·
verified ·
1 Parent(s): 1a4a84d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -6
app.py CHANGED
@@ -24,12 +24,18 @@ def load_model(model_name):
24
  model = models.segmentation.deeplabv3_resnet101(weights=weights)
25
  except:
26
  model = models.segmentation.deeplabv3_resnet101(pretrained=True)
27
- elif model_name == "FCN (ResNet50)":
28
  try:
29
- weights = models.segmentation.FCN_ResNet50_Weights.DEFAULT
30
- model = models.segmentation.fcn_resnet50(weights=weights)
31
  except:
32
- model = models.segmentation.fcn_resnet50(pretrained=True)
 
 
 
 
 
 
33
 
34
  model.eval()
35
  if torch.cuda.is_available():
@@ -130,7 +136,7 @@ article = """
130
  # Example images
131
  examples = [
132
  ["demo1.jpg", "DeepLabV3+ (ResNet50)"],
133
- ["demo2.png", "FCN (ResNet50)"]
134
  ]
135
 
136
  demo = gr.Interface(
@@ -138,7 +144,7 @@ demo = gr.Interface(
138
  inputs=[
139
  gr.Image(type="pil", label="Input Image"),
140
  gr.Dropdown(
141
- choices=["DeepLabV3+ (ResNet50)", "DeepLabV3+ (ResNet101)", "FCN (ResNet50)"],
142
  value="DeepLabV3+ (ResNet50)",
143
  label="Select Pre-trained Model"
144
  )
 
24
  model = models.segmentation.deeplabv3_resnet101(weights=weights)
25
  except:
26
  model = models.segmentation.deeplabv3_resnet101(pretrained=True)
27
+ elif model_name == "DeepLabV3+ (MobileNetV3)":
28
  try:
29
+ weights = models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT
30
+ model = models.segmentation.deeplabv3_mobilenet_v3_large(weights=weights)
31
  except:
32
+ model = models.segmentation.deeplabv3_mobilenet_v3_large(pretrained=True)
33
+ elif model_name == "LRASPP (MobileNetV3)":
34
+ try:
35
+ weights = models.segmentation.LRASPP_MobileNet_V3_Large_Weights.DEFAULT
36
+ model = models.segmentation.lraspp_mobilenet_v3_large(weights=weights)
37
+ except:
38
+ model = models.segmentation.lraspp_mobilenet_v3_large(pretrained=True)
39
 
40
  model.eval()
41
  if torch.cuda.is_available():
 
136
  # Example images
137
  examples = [
138
  ["demo1.jpg", "DeepLabV3+ (ResNet50)"],
139
+ ["demo2.png", "LRASPP (MobileNetV3)"]
140
  ]
141
 
142
  demo = gr.Interface(
 
144
  inputs=[
145
  gr.Image(type="pil", label="Input Image"),
146
  gr.Dropdown(
147
+ choices=["DeepLabV3+ (ResNet50)", "DeepLabV3+ (ResNet101)", "DeepLabV3+ (MobileNetV3)", "LRASPP (MobileNetV3)"],
148
  value="DeepLabV3+ (ResNet50)",
149
  label="Select Pre-trained Model"
150
  )