Updating SDK version and resolving compatibility issues -- image_prompter is removed, gradio_image_annotation added
Files changed:
- README.md +1 -1
- demo_gradio.py +213 -44
- requirements.txt +2 -2
README.md CHANGED
@@ -1,7 +1,7 @@
 ---
 title: GeCo2 Gradio Demo
 sdk: gradio
-sdk_version: "
+sdk_version: "5.50.0"
 python_version: "3.10.13"
 app_file: demo_gradio.py
 ---
demo_gradio.py CHANGED
@@ -1,7 +1,7 @@
 import spaces
 import torch
 import gradio as gr
-from
+from gradio_image_annotation import image_annotator
 from torch.nn import DataParallel
 from models.counter_infer import build_model
 from utils.arg_parser import get_argparser
@@ -14,10 +14,55 @@ import numpy as np
 import colorsys
 
 
+# -----------------------------
+# Minimal UI + force "Create" mode (press C a few times)
+# -----------------------------
+JS_FORCE_CREATE_MODE = r"""
+function () {
+  const pressC = () => {
+    const ev = new KeyboardEvent("keydown", {
+      key: "c",
+      code: "KeyC",
+      bubbles: true
+    });
+    document.dispatchEvent(ev);
+  };
+
+  let tries = 0;
+  const t = setInterval(() => {
+    tries++;
+    pressC();
+    if (tries > 20) clearInterval(t);
+  }, 200);
+}
+"""
+
+CSS_MINIMAL_UI = """
+/* Hide labels, instructions, help text */
+.gradio-container label,
+.gradio-container .block-label,
+.gradio-container .markdown,
+.gradio-container p {
+  display: none !important;
+}
+
+/* Reduce rounding of UI containers */
+.gradio-container [class*="rounded"] {
+  border-radius: 4px !important;
+}
+
+/* Reduce padding */
+.gradio-container [class*="p-4"] {
+  padding: 0.25rem !important;
+}
+"""
+
+
 _MODEL = None
 _ARGS = None
 _WEIGHTS_PATH = None
 
+
 def _get_args():
     global _ARGS
     if _ARGS is None:
@@ -26,6 +71,7 @@ def _get_args():
         _ARGS = args
     return _ARGS
 
+
 def _get_weights_path():
     global _WEIGHTS_PATH
     if _WEIGHTS_PATH is None:
@@ -36,6 +82,7 @@ def _get_weights_path():
         )
     return _WEIGHTS_PATH
 
+
 def get_model_on_device(device: torch.device):
     """
     Lazily build and load model, then move to the requested device.
@@ -63,22 +110,140 @@ def get_model_on_device(device: torch.device):
     return _MODEL
 
 
-#
+# -----------------------------
+# Rotation helper (in case annotator reports orientation)
+# -----------------------------
+def _rotate_image_and_boxes(image_np: np.ndarray, boxes: list[dict], angle: int):
+    """
+    angle is in 90-degree steps. The gradio_image_annotation README demonstrates:
+        np.rot90(image, k=-angle)
+    so angle=1 => rotate clockwise 90 deg.
+    """
+    if angle is None:
+        return image_np, boxes
+
+    a = int(angle) % 4
+    if a == 0:
+        return image_np, boxes
+
+    H, W = image_np.shape[:2]
+
+    # rotate image using the same convention as the component docs
+    image_rot = np.rot90(image_np, k=-a)
+
+    def clamp_box(xmin, ymin, xmax, ymax, newW, newH):
+        xmin = max(0, min(newW, xmin))
+        xmax = max(0, min(newW, xmax))
+        ymin = max(0, min(newH, ymin))
+        ymax = max(0, min(newH, ymax))
+        # ensure ordering
+        if xmax < xmin:
+            xmin, xmax = xmax, xmin
+        if ymax < ymin:
+            ymin, ymax = ymax, ymin
+        return xmin, ymin, xmax, ymax
+
+    boxes_rot = []
+    if a == 1:
+        # 90 deg clockwise: (x,y) -> (H - 1 - y, x)
+        newH, newW = W, H
+        for b in boxes:
+            xmin, ymin, xmax, ymax = b["xmin"], b["ymin"], b["xmax"], b["ymax"]
+            nxmin = H - ymax
+            nxmax = H - ymin
+            nymin = xmin
+            nymax = xmax
+            nxmin, nymin, nxmax, nymax = clamp_box(nxmin, nymin, nxmax, nymax, newW, newH)
+            bb = dict(b)
+            bb.update({"xmin": nxmin, "ymin": nymin, "xmax": nxmax, "ymax": nymax})
+            boxes_rot.append(bb)
+
+    elif a == 2:
+        # 180 deg: (x,y) -> (W - 1 - x, H - 1 - y)
+        newH, newW = H, W
+        for b in boxes:
+            xmin, ymin, xmax, ymax = b["xmin"], b["ymin"], b["xmax"], b["ymax"]
+            nxmin = W - xmax
+            nxmax = W - xmin
+            nymin = H - ymax
+            nymax = H - ymin
+            nxmin, nymin, nxmax, nymax = clamp_box(nxmin, nymin, nxmax, nymax, newW, newH)
+            bb = dict(b)
+            bb.update({"xmin": nxmin, "ymin": nymin, "xmax": nxmax, "ymax": nymax})
+            boxes_rot.append(bb)
+
+    else:  # a == 3
+        # 90 deg counter-clockwise: (x,y) -> (y, W - 1 - x)
+        newH, newW = W, H
+        for b in boxes:
+            xmin, ymin, xmax, ymax = b["xmin"], b["ymin"], b["xmax"], b["ymax"]
+            nxmin = ymin
+            nxmax = ymax
+            nymin = W - xmax
+            nymax = W - xmin
+            nxmin, nymin, nxmax, nymax = clamp_box(nxmin, nymin, nxmax, nymax, newW, newH)
+            bb = dict(b)
+            bb.update({"xmin": nxmin, "ymin": nymin, "xmax": nxmax, "ymax": nymax})
+            boxes_rot.append(bb)
+
+    return image_rot, boxes_rot
+
+
+# -----------------------------
+# Function to Process Image Once (GPU)
+# -----------------------------
 @spaces.GPU
 def process_image_once(inputs, enable_mask):
-
+    """
+    inputs is an AnnotatedImageValue-like dict from gradio_image_annotation:
+        {
+            "image": np.ndarray | PIL | str,
+            "boxes": [ {xmin, ymin, xmax, ymax, label?, color?}, ... ],
+            "orientation": int?
+        }
+    """
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = get_model_on_device(device)
 
+    if inputs is None or inputs.get("image", None) is None:
+        # keep behavior simple: return empty outputs
+        return None, [{"pred_boxes": torch.empty(0, 4), "box_v": torch.empty(0)}], [None], torch.empty(1), 1.0, []
+
     image = inputs["image"]
-
+    boxes = inputs.get("boxes", []) or []
+
+    # Ensure numpy image
+    if isinstance(image, Image.Image):
+        image = np.array(image)
+    elif isinstance(image, str):
+        # If you ever allow URL/path returns, you'd need to load it here.
+        # For now, enforce image_type="numpy" in the UI so this does not occur.
+        raise ValueError("Annotator returned image as str. Set image_type='numpy' on image_annotator.")
+
+    # Handle orientation if provided (rare but supported by the component)
+    angle = inputs.get("orientation", None)
+    if angle is not None:
+        image, boxes = _rotate_image_and_boxes(image, boxes, angle)
+
+    # Convert box dicts to the legacy list format so downstream code stays unchanged:
+    # drawn_boxes elements must support [0], [1], [3], [4] indexing.
+    # Encoded as: [x1, y1, 0, x2, y2]
+    drawn_boxes = []
+    for b in boxes:
+        drawn_boxes.append([float(b["xmin"]), float(b["ymin"]), 0.0, float(b["xmax"]), float(b["ymax"])])
+
+    # If no boxes, keep consistent behavior (the model call would likely fail)
+    if len(drawn_boxes) == 0:
+        return image, [{"pred_boxes": torch.empty(0, 4), "box_v": torch.empty(0)}], [None], torch.empty(1), 1.0, []
+
     image_tensor = torch.tensor(image).to(device)
     image_tensor = image_tensor.permute(2, 0, 1).float() / 255.0
     image_tensor = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image_tensor)
 
-    bboxes_tensor = torch.tensor(
-
-
+    bboxes_tensor = torch.tensor(
+        [[box[0], box[1], box[3], box[4]] for box in drawn_boxes],
+        dtype=torch.float32,
+    ).to(device)
 
     img, bboxes, scale = resize_and_pad(image_tensor, bboxes_tensor, size=1024.0)
     img = img.unsqueeze(0).to(device)
@@ -88,13 +253,8 @@ def process_image_once(inputs, enable_mask):
     model.module.return_masks = enable_mask
     outputs, _, _, _, masks = model(img, bboxes)
 
-    #
-    # ZeroGPU requirement: return ONLY CPU-native objects to main process.
-    # Do NOT return CUDA tensors, and avoid returning output dicts that may
-    # contain additional CUDA tensors beyond pred_boxes/box_v.
-    # ------------------------------------------------------------------
+    # Return ONLY CPU-native objects to main process.
     out0 = outputs[0]
-
     pred_boxes_cpu = out0["pred_boxes"].detach().float().cpu()
     box_v_cpu = out0["box_v"].detach().float().cpu()
 
@@ -108,7 +268,6 @@ def process_image_once(inputs, enable_mask):
     else:
         masks_cpu = [None]
 
-    # img is only used for shape in post_process, so return a CPU tensor
    img_cpu = img.detach().cpu()
 
     return image, outputs_cpu, masks_cpu, img_cpu, float(scale), drawn_boxes
@@ -123,22 +282,13 @@ def _hsv_to_rgb255(h, s, v):
 
 
 def instance_colors(i: int):
-    """
-    Pastel palette per instance.
-      - Mask: pastel fill
-      - Box: same hue, slightly more saturated (but still pastel-ish)
-    Deterministic hue stepping (golden ratio) for stable and distinct colors.
-    """
     h = (i * 0.618033988749895) % 1.0
-    mask_rgb = _hsv_to_rgb255(h, s=0.28, v=1.00)
-    box_rgb = _hsv_to_rgb255(h, s=0.42, v=0.95)
+    mask_rgb = _hsv_to_rgb255(h, s=0.28, v=1.00)
+    box_rgb = _hsv_to_rgb255(h, s=0.42, v=0.95)
     return mask_rgb, box_rgb
 
 
 def overlay_single_mask(base_rgba: Image.Image, mask_bool: np.ndarray, rgb, alpha=0.45):
-    """
-    Alpha-composite a single instance mask (boolean HxW) in given rgb onto base_rgba.
-    """
     if mask_bool.dtype != np.bool_:
         mask_bool = mask_bool.astype(bool)
 
@@ -153,12 +303,19 @@ def overlay_single_mask(base_rgba: Image.Image, mask_bool: np.ndarray, rgb, alph
     return Image.alpha_composite(base_rgba, overlay_img)
 
 
-#
+# -----------------------------
+# Post-process and Update Output
+# -----------------------------
 def post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold):
     idx = 0
     threshold = 1 / threshold
 
     score = outputs[idx]["box_v"]
+    if score.numel() == 0:
+        # no predictions
+        image_pil = Image.fromarray((image).astype(np.uint8)).convert("RGB")
+        return image_pil, 0
+
     score_mask = score > score.max() / threshold
 
     keep = ops.nms(
@@ -171,20 +328,17 @@ def post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, th
     pred_boxes = torch.clamp(pred_boxes, 0, 1)
     pred_boxes = (pred_boxes / scale * img.shape[-1]).tolist()
 
-    # Base image as RGBA for compositing
     image = Image.fromarray((image).astype(np.uint8)).convert("RGBA")
 
-    # --- Masks: per-instance, pastel, matching box hue ---
     if enable_mask and masks is not None and masks[idx] is not None:
         masks_sel = masks[idx][score_mask[0]] if score_mask.ndim > 1 else masks[idx][score_mask]
-        masks_sel = masks_sel[keep]
+        masks_sel = masks_sel[keep]
 
         target_h = int(img.shape[2] / scale)
         target_w = int(img.shape[3] / scale)
         resize_nearest = T.Resize((target_h, target_w), interpolation=T.InterpolationMode.NEAREST)
 
         W, H = image.size
-
        for i in range(masks_sel.shape[0]):
             mask_i = masks_sel[i]
             if mask_i.ndim == 3:
@@ -197,37 +351,38 @@
             mask_rgb, _ = instance_colors(i)
             image = overlay_single_mask(image, mask_bool, mask_rgb, alpha=0.45)
 
-    # --- Boxes: thin, pastel, no labels/text ---
     draw = ImageDraw.Draw(image)
-    box_width = 2
+    box_width = 2
 
     for i, box in enumerate(pred_boxes):
         _, box_rgb = instance_colors(i)
         x1, y1, x2, y2 = map(float, box)
         draw.rectangle([x1, y1, x2, y2], outline=box_rgb, width=box_width)
 
-
-
-    exemplar_inner = (0, 0, 0, 255)  # black
+    exemplar_outline = (255, 255, 255, 255)
+    exemplar_inner = (0, 0, 0, 255)
     for box in drawn_boxes:
         x1, y1, x2, y2 = box[0], box[1], box[3], box[4]
         draw.rectangle([x1, y1, x2, y2], outline=exemplar_outline, width=2)
         draw.rectangle([x1 + 1, y1 + 1, x2 - 1, y2 - 1], outline=exemplar_inner, width=1)
 
-    # Return without any text/labels on the image
     return image.convert("RGB"), len(pred_boxes)
 
 
-
+# -----------------------------
+# Gradio UI
+# -----------------------------
+iface = gr.Blocks(
+    title="GeCo2 Gradio Demo",
+    js=JS_FORCE_CREATE_MODE,
+    css=CSS_MINIMAL_UI,
+)
 
 with iface:
     gr.Markdown(
         """
 # GeCo2: Generalized-Scale Object Counting with Gradual Query Aggregation
-
-GeCo2 is a few-shot, category-agnostic detection counter. With only a small number of exemplars, GeCo2 can detect and count all instances of the target object in an image wihtout any retraining.
-
-
+GeCo2 is a few-shot, category-agnostic detection counter. With only a small number of exemplars, GeCo2 can detect and count all instances of the target object in an image without any retraining.
 1) Upload an image.
 2) Draw bounding boxes on the target object (preferably ~3 instances).
 3) Click **Count**.
@@ -244,7 +399,17 @@ GeCo2 is a few-shot, category-agnostic detection counter. With only a small numb
     drawn_boxes_state = gr.State()
 
     with gr.Row():
-
+        # New annotator component
+        annotator = image_annotator(
+            value=None,
+            image_type="numpy",  # ensures inputs["image"] is a numpy array
+            label_list=["Object"],
+            label_colors=[(0, 255, 0)],
+            use_default_label=True,
+            enable_keyboard_shortcuts=True,
+            interactive=True,
+            show_label=False,  # hide label text on boxes
+        )
         image_output = gr.Image(type="pil")
 
     with gr.Row():
@@ -256,6 +421,8 @@ GeCo2 is a few-shot, category-agnostic detection counter. With only a small numb
 
     def initial_process(inputs, enable_mask, threshold):
         image, outputs, masks, img, scale, drawn_boxes = process_image_once(inputs, enable_mask)
+        if image is None:
+            return None, 0, None, None, None, None, None, None
         return (
             *post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold),
             image,
@@ -267,11 +434,13 @@ GeCo2 is a few-shot, category-agnostic detection counter. With only a small numb
     )
 
     def update_threshold(threshold, image, outputs, masks, img, scale, drawn_boxes, enable_mask):
+        if image is None or outputs is None or img is None:
+            return None, 0
         return post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold)
 
     count_button.click(
         initial_process,
-        [
+        [annotator, enable_mask, threshold],
         [image_output, count_output, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state],
     )
@@ -288,4 +457,4 @@ GeCo2 is a few-shot, category-agnostic detection counter. With only a small numb
     )
 
 if __name__ == "__main__":
-    iface.launch()
+    iface.queue().launch()
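Migration note: below is a minimal standalone sketch of the new input contract (illustrative only, not part of the commit; the helper name to_legacy_boxes is invented here). image_annotator returns a dict whose "boxes" entries are {xmin, ymin, xmax, ymax, ...} dicts in pixel coordinates, and the demo re-encodes each one as a legacy [x1, y1, 0, x2, y2] list so downstream indexing with [0], [1], [3], [4] keeps working:

import gradio as gr
from gradio_image_annotation import image_annotator

def to_legacy_boxes(value):
    # value["boxes"] holds dicts with pixel coords; re-encode each box as
    # [x1, y1, 0, x2, y2] so the legacy pipeline can index [0], [1], [3], [4].
    boxes = (value or {}).get("boxes") or []
    return [[float(b["xmin"]), float(b["ymin"]), 0.0,
             float(b["xmax"]), float(b["ymax"])] for b in boxes]

with gr.Blocks() as demo:
    ann = image_annotator(value=None, image_type="numpy")
    out = gr.JSON()
    gr.Button("Show boxes").click(to_legacy_boxes, ann, out)

if __name__ == "__main__":
    demo.launch()

The rotation helper follows the np.rot90(image, k=-angle) convention from the component README, so for angle=1 a box (xmin, ymin, xmax, ymax) maps to (H - ymax, xmin, H - ymin, xmax) in the rotated frame.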
requirements.txt CHANGED
@@ -110,5 +110,5 @@ websockets==12.0
 zipp==3.21.0
 spaces
 gradio_client
-gradio
-
+gradio==5.50.0
+gradio_image_annotation