Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -42,22 +42,25 @@ def draw_heatmap(image, mask):
|
|
| 42 |
# Define callable method for the demo
|
| 43 |
def get_mask(image):
|
| 44 |
if image is None:
|
| 45 |
-
return None
|
| 46 |
|
| 47 |
image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
|
| 48 |
dm_image = feature_extractor(image).unsqueeze(0)
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
masked_img = draw_mask(image, mask)
|
| 52 |
heatmap = draw_heatmap(image, mask)
|
| 53 |
-
return np.hstack((masked_img, heatmap))
|
| 54 |
|
| 55 |
|
| 56 |
# Launch demo interface
|
| 57 |
gr.Interface(
|
| 58 |
get_mask,
|
| 59 |
inputs=gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
|
| 60 |
-
outputs=[gr.outputs.Image(label="Output")],
|
| 61 |
title="Vision DiffMask Demo",
|
| 62 |
live=True,
|
| 63 |
).launch()
|
|
|
|
| 42 |
# Define callable method for the demo
|
| 43 |
def get_mask(image):
|
| 44 |
if image is None:
|
| 45 |
+
return None, None
|
| 46 |
|
| 47 |
image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
|
| 48 |
dm_image = feature_extractor(image).unsqueeze(0)
|
| 49 |
+
dm_out = diffmask.get_mask(dm_image)
|
| 50 |
+
mask = dm_out["mask"][0].detach()
|
| 51 |
+
pred = dm_out["pred_class"][0].detach()
|
| 52 |
+
pred = diffmask.model.config.id2label[pred.item()]
|
| 53 |
|
| 54 |
masked_img = draw_mask(image, mask)
|
| 55 |
heatmap = draw_heatmap(image, mask)
|
| 56 |
+
return np.hstack((masked_img, heatmap)), pred
|
| 57 |
|
| 58 |
|
| 59 |
# Launch demo interface
|
| 60 |
gr.Interface(
|
| 61 |
get_mask,
|
| 62 |
inputs=gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
|
| 63 |
+
outputs=[gr.outputs.Image(label="Output"), gr.outputs.Label(label="Prediction")],
|
| 64 |
title="Vision DiffMask Demo",
|
| 65 |
live=True,
|
| 66 |
).launch()
|