Update app.py
Browse files
app.py
CHANGED
|
@@ -104,18 +104,42 @@ def analyze_image(image):
|
|
| 104 |
|
| 105 |
|
| 106 |
def show_mask(mask, ax, random_color=False):
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
-
|
| 116 |
-
|
| 117 |
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
def process_image_detection(image, target_label, surprise_rating):
|
|
|
|
| 104 |
|
| 105 |
|
| 106 |
def show_mask(mask, ax, random_color=False):
|
| 107 |
+
try:
|
| 108 |
+
# Debug print to understand mask type
|
| 109 |
+
print(f"show_mask input type: {type(mask)}")
|
| 110 |
+
|
| 111 |
+
# Convert mask if it's a tuple
|
| 112 |
+
if isinstance(mask, tuple):
|
| 113 |
+
if len(mask) > 0 and mask[0] is not None:
|
| 114 |
+
mask = mask[0]
|
| 115 |
+
else:
|
| 116 |
+
raise ValueError("Invalid mask tuple")
|
| 117 |
+
|
| 118 |
+
# Convert torch tensor to numpy if needed
|
| 119 |
+
if torch.is_tensor(mask):
|
| 120 |
+
mask = mask.cpu().numpy()
|
| 121 |
+
|
| 122 |
+
# Handle 4D tensor/array case
|
| 123 |
+
if len(mask.shape) == 4:
|
| 124 |
+
mask = mask[0, 0]
|
| 125 |
+
# Handle 3D tensor/array case
|
| 126 |
+
elif len(mask.shape) == 3:
|
| 127 |
+
mask = mask[0]
|
| 128 |
+
|
| 129 |
+
if random_color:
|
| 130 |
+
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
| 131 |
+
else:
|
| 132 |
+
color = np.array([1.0, 0.0, 0.0, 0.5])
|
| 133 |
|
| 134 |
+
mask_image = np.zeros((*mask.shape, 4), dtype=np.float32)
|
| 135 |
+
mask_image[mask > 0] = color
|
| 136 |
|
| 137 |
+
ax.imshow(mask_image)
|
| 138 |
+
|
| 139 |
+
except Exception as e:
|
| 140 |
+
print(f"show_mask error: {str(e)}")
|
| 141 |
+
print(f"mask shape: {getattr(mask, 'shape', 'no shape')}")
|
| 142 |
+
raise
|
| 143 |
|
| 144 |
|
| 145 |
def process_image_detection(image, target_label, surprise_rating):
|