Update app.py
app.py
CHANGED
@@ -60,6 +60,29 @@ def get_scores(crops: List[PIL.Image.Image], query: str) -> torch.Tensor:
     return similarity
 
 
+def crop_image(image: np.ndarray, mask: Dict[str, Any]) -> PIL.Image.Image:
+    x, y, w, h = mask["bbox"]
+    masked = image * np.expand_dims(mask["segmentation"], -1)
+    crop = masked[y : y + h, x : x + w]
+    if h > w:
+        top, bottom, left, right = 0, 0, (h - w) // 2, (h - w) // 2
+    else:
+        top, bottom, left, right = (w - h) // 2, (w - h) // 2, 0, 0
+    # padding
+    crop = cv2.copyMakeBorder(
+        crop,
+        top,
+        bottom,
+        left,
+        right,
+        cv2.BORDER_CONSTANT,
+        value=(0, 0, 0),
+    )
+    crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
+    crop = PIL.Image.fromarray(crop)
+    return crop
+
+
 def filter_masks(
     image: np.ndarray,
     masks: List[Dict[str, Any]],
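The new crop_image helper turns one mask record into a scoring-ready crop: it zeroes out every pixel outside the segmentation, crops to the bounding box, pads the shorter dimension with black borders toward a square, and converts BGR to an RGB PIL image. The dict keys it reads ("segmentation" as a boolean HxW array, "bbox" as XYWH) match what segment-anything's automatic mask generator produces. A minimal sketch of exercising it outside the app; the toy image and mask record below are made up for illustration:

import numpy as np

# Toy 100x100 BGR image with a tall 20x60 rectangle lit up.
image = np.zeros((100, 100, 3), dtype=np.uint8)
image[20:80, 40:60] = (0, 128, 255)

# SAM-style mask record: boolean HxW "segmentation" plus an XYWH "bbox".
segmentation = np.zeros((100, 100), dtype=bool)
segmentation[20:80, 40:60] = True
mask = {"segmentation": segmentation, "bbox": (40, 20, 20, 60)}

crop = crop_image(image, mask)  # crop_image as added above
print(crop.size)  # (60, 60): the 20x60 crop is padded left/right to a square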
@@ -77,15 +100,8 @@ def filter_masks(
             or mask["stability_score"] < stability_score_threshold
         ):
             continue
-
         filtered_masks.append(mask)
-
-        x, y, w, h = mask["bbox"]
-        masked = image * np.expand_dims(mask["segmentation"], -1)
-        crop = masked[y: y + h, x: x + w]
-        crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
-        crop = PIL.Image.fromarray(crop)
-        cropped_masks.append(crop)
+        cropped_masks.append(crop_image(image, mask))
 
     if query and filtered_masks:
        scores = get_scores(cropped_masks, query)
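After the refactor, the loop appends the kept mask and its crop in one pass, so cropped_masks stays index-aligned with filtered_masks for the scoring step; behaviour is unchanged apart from the new square padding. Only the tail of get_scores (`return similarity`) is visible in this diff; as a hedged sketch, under the assumption of a CLIP-based scorer (the checkpoint name and preprocessing are guesses, not this Space's confirmed code), a function with that signature typically looks like:

from typing import List

import PIL.Image
import torch
from transformers import CLIPModel, CLIPProcessor

# Assumed checkpoint; the Space's real model choice is not visible here.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

@torch.inference_mode()
def get_scores(crops: List[PIL.Image.Image], query: str) -> torch.Tensor:
    inputs = processor(text=[query], images=crops, return_tensors="pt", padding=True)
    outputs = model(**inputs)
    # Cosine similarity between each crop embedding and the query embedding.
    image_embeds = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
    text_embeds = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
    similarity = (image_embeds @ text_embeds.T).squeeze(-1)  # shape: (num_crops,)
    return similarity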
@@ -167,9 +183,9 @@ demo = gr.Interface(
         [
             0.9,
             0.8,
-            0.
+            0.001,
             os.path.join(os.path.dirname(__file__), "examples/city.jpg"),
-            "
+            "building",
         ],
         [
             0.9,
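The last hunk completes the first example row for examples/city.jpg with a third numeric value, 0.001, and the query "building". The old values appear truncated here (a bare `0.` and a lone `"`); if the file genuinely contained them, app.py would fail to parse, which would account for the Space's runtime-error status before this commit. A hedged sketch of how such a row must line up one-to-one with the gr.Interface inputs; the component list below is inferred from the example values and is not shown in the diff:

import os

import gradio as gr

# Hypothetical reconstruction: the Space's real component list is not in the
# diff; the point is that each examples row must match inputs one-to-one.
demo = gr.Interface(
    fn=lambda *args: None,  # stand-in for the Space's actual inference function
    inputs=[
        gr.Slider(0, 1, value=0.9, label="predicted_iou_threshold"),
        gr.Slider(0, 1, value=0.8, label="stability_score_threshold"),
        gr.Slider(0, 1, value=0.001, label="assumed third threshold"),
        gr.Image(type="filepath"),
        gr.Textbox(label="query"),
    ],
    outputs=gr.Image(),
    examples=[
        [
            0.9,
            0.8,
            0.001,
            os.path.join(os.path.dirname(__file__), "examples/city.jpg"),
            "building",
        ],
    ],
)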