# GECO2-demo / demo_gradio.py
import spaces
import torch
import gradio as gr
from gradio_image_annotation import image_annotator
from models.counter_infer import build_model
from utils.arg_parser import get_argparser
from utils.data import resize_and_pad
import torchvision.ops as ops
from torchvision import transforms as T
from PIL import Image, ImageDraw
from huggingface_hub import hf_hub_download
import numpy as np
import colorsys
# -----------------------------
_MODEL = None
_ARGS = None
_WEIGHTS_PATH = None
# -----------------------------
def _get_args():
global _ARGS
if _ARGS is None:
args = get_argparser().parse_args()
args.zero_shot = True
_ARGS = args
return _ARGS
def _get_weights_path():
global _WEIGHTS_PATH
if _WEIGHTS_PATH is None:
_WEIGHTS_PATH = hf_hub_download(
repo_id="jerpelhan/geco2-assets",
filename="weights/CNTQG_multitrain_ca44.pth",
repo_type="dataset",
)
return _WEIGHTS_PATH
def _strip_module_prefix(state_dict: dict) -> dict:
"""
If weights were saved from torch.nn.DataParallel, keys are often prefixed with 'module.'.
When loading into a non-DataParallel model, strip that prefix.
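
    Illustrative toy example:
        >>> _strip_module_prefix({"module.layer.weight": 0})
        {'layer.weight': 0}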
"""
if not isinstance(state_dict, dict) or len(state_dict) == 0:
return state_dict
    # Only strip when the keys actually carry the DataParallel 'module.' prefix
has_module = any(k.startswith("module.") for k in state_dict.keys())
if not has_module:
return state_dict
return {k[len("module.") :]: v for k, v in state_dict.items()}
def _extract_state_dict(ckpt) -> dict:
"""
Robustly extract a state_dict from typical checkpoint formats.
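
    Illustrative toy examples:
        >>> _extract_state_dict({"model": {"w": 1}})
        {'w': 1}
        >>> _extract_state_dict({"w": 1})
        {'w': 1}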
"""
    if isinstance(ckpt, dict):
        # Common wrapper keys
        if "model" in ckpt and isinstance(ckpt["model"], dict):
            return ckpt["model"]
        if "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
            return ckpt["state_dict"]
    # Fallback: the checkpoint itself is (or is treated as) the state_dict
    return ckpt
def get_model_on_device(device: torch.device):
"""
Lazily build and load model, then move to the requested device.
IMPORTANT: model is constructed/loaded without initializing CUDA in the main process.
This function will be called from inside the @spaces.GPU worker.
"""
global _MODEL
if _MODEL is None:
args = _get_args()
# Build on CPU first to avoid CUDA init in the wrong process
model = build_model(args)
weights_path = _get_weights_path()
        ckpt = torch.load(weights_path, map_location="cpu")  # load on CPU; CUDA must not be initialized here
state = _extract_state_dict(ckpt)
state = _strip_module_prefix(state)
model.load_state_dict(state, strict=False)
model.eval()
_MODEL = model
_MODEL = _MODEL.to(device)
if device.type == "cuda":
torch.backends.cudnn.benchmark = True
return _MODEL
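# Local usage sketch (illustrative; not part of the Space's request path):
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = get_model_on_device(device)  # builds and loads once, then cached in _MODEL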
# -----------------------------
# Rotation helper (in case annotator reports orientation)
# -----------------------------
def _rotate_image_and_boxes(image_np: np.ndarray, boxes: list[dict], angle: int):
if angle is None:
return image_np, boxes
a = int(angle) % 4
if a == 0:
return image_np, boxes
H, W = image_np.shape[:2]
# rotate image using the same convention as the component docs
image_rot = np.rot90(image_np, k=-a)
def clamp_box(xmin, ymin, xmax, ymax, newW, newH):
xmin = max(0, min(newW, xmin))
xmax = max(0, min(newW, xmax))
ymin = max(0, min(newH, ymin))
ymax = max(0, min(newH, ymax))
if xmax < xmin:
xmin, xmax = xmax, xmin
if ymax < ymin:
ymin, ymax = ymax, ymin
return xmin, ymin, xmax, ymax
boxes_rot = []
if a == 1:
        # 90 deg clockwise: box corners map as (x, y) -> (H - y, x)
newH, newW = W, H
for b in boxes:
xmin, ymin, xmax, ymax = b["xmin"], b["ymin"], b["xmax"], b["ymax"]
nxmin = H - ymax
nxmax = H - ymin
nymin = xmin
nymax = xmax
nxmin, nymin, nxmax, nymax = clamp_box(nxmin, nymin, nxmax, nymax, newW, newH)
bb = dict(b)
bb.update({"xmin": nxmin, "ymin": nymin, "xmax": nxmax, "ymax": nymax})
boxes_rot.append(bb)
elif a == 2:
        # 180 deg: (x, y) -> (W - x, H - y)
newH, newW = H, W
for b in boxes:
xmin, ymin, xmax, ymax = b["xmin"], b["ymin"], b["xmax"], b["ymax"]
nxmin = W - xmax
nxmax = W - xmin
nymin = H - ymax
nymax = H - ymin
nxmin, nymin, nxmax, nymax = clamp_box(nxmin, nymin, nxmax, nymax, newW, newH)
bb = dict(b)
bb.update({"xmin": nxmin, "ymin": nymin, "xmax": nxmax, "ymax": nymax})
boxes_rot.append(bb)
else: # a == 3
        # 90 deg counter-clockwise: (x, y) -> (y, W - x)
newH, newW = W, H
for b in boxes:
xmin, ymin, xmax, ymax = b["xmin"], b["ymin"], b["xmax"], b["ymax"]
nxmin = ymin
nxmax = ymax
nymin = W - xmax
nymax = W - xmin
nxmin, nymin, nxmax, nymax = clamp_box(nxmin, nymin, nxmax, nymax, newW, newH)
bb = dict(b)
bb.update({"xmin": nxmin, "ymin": nymin, "xmax": nxmax, "ymax": nymax})
boxes_rot.append(bb)
return image_rot, boxes_rot
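# Worked example for _rotate_image_and_boxes (illustrative): rotating a W=100, H=50
# image by 90 deg clockwise maps the box (xmin=10, ymin=5, xmax=20, ymax=15) to
# (xmin=H-ymax=35, ymin=xmin=10, xmax=H-ymin=45, ymax=xmax=20) on the new 50x100 canvas.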
# -----------------------------
# Function to Process Image Once (GPU)
# -----------------------------
@spaces.GPU
def process_image_once(inputs, enable_mask):
"""
    inputs is an AnnotatedImageValue-like dict from gradio_image_annotation:
{
"image": np.ndarray | PIL | str,
"boxes": [ {xmin,ymin,xmax,ymax,label?,color?}, ... ],
"orientation": int?
}
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_model_on_device(device)
if inputs is None or inputs.get("image", None) is None:
# keep behavior simple: return empty outputs
return None, [{"pred_boxes": torch.empty(0, 4), "box_v": torch.empty(0)}], [None], torch.empty(1), 1.0, []
image = inputs["image"]
boxes = inputs.get("boxes", []) or []
# Ensure numpy image (support numpy, PIL, OR local path string)
if isinstance(image, Image.Image):
image = np.array(image.convert("RGB"))
elif isinstance(image, str):
image = np.array(Image.open(image).convert("RGB"))
elif isinstance(image, np.ndarray):
pass
else:
raise ValueError(f"Unsupported image type from annotator: {type(image)}")
angle = inputs.get("orientation", None)
if angle is not None:
image, boxes = _rotate_image_and_boxes(image, boxes, angle)
    # Boxes are stored as 5-element rows [xmin, ymin, 0.0, xmax, ymax]; the 0.0 at
    # index 2 is a placeholder kept for the downstream (0, 1, 3, 4) indexing convention.
    drawn_boxes = []
for b in boxes:
drawn_boxes.append([float(b["xmin"]), float(b["ymin"]), 0.0, float(b["xmax"]), float(b["ymax"])])
# If no boxes, do not call model (caller will handle warning)
if len(drawn_boxes) == 0:
return image, [{"pred_boxes": torch.empty(0, 4), "box_v": torch.empty(0)}], [None], torch.empty(1), 1.0, []
image_tensor = torch.tensor(image).to(device)
image_tensor = image_tensor.permute(2, 0, 1).float() / 255.0
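    # Standard ImageNet mean/std normalization (assumed to match the backbone's pretraining).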
image_tensor = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image_tensor)
bboxes_tensor = torch.tensor(
[[box[0], box[1], box[3], box[4]] for box in drawn_boxes],
dtype=torch.float32,
).to(device)
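    # resize_and_pad (from utils.data) is assumed to resize onto a 1024-square padded
    # canvas, rescale the exemplar boxes to match, and return the resize factor.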
img, bboxes, scale = resize_and_pad(image_tensor, bboxes_tensor, size=1024.0)
img = img.unsqueeze(0).to(device)
bboxes = bboxes.unsqueeze(0).to(device)
    # Faster inference: inference_mode plus fp16 autocast when running on CUDA
use_amp = (device.type == "cuda")
with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
model.return_masks = enable_mask
outputs, _, _, _, masks = model(img, bboxes)
# Return ONLY CPU-native objects to main process.
out0 = outputs[0]
pred_boxes_cpu = out0["pred_boxes"].detach().float().cpu()
box_v_cpu = out0["box_v"].detach().float().cpu()
outputs_cpu = [{"pred_boxes": pred_boxes_cpu, "box_v": box_v_cpu}]
if enable_mask and masks is not None and masks[0] is not None:
masks_cpu = [masks[0].detach().float().cpu()]
else:
masks_cpu = [None]
img_cpu = img.detach().cpu()
return image, outputs_cpu, masks_cpu, img_cpu, float(scale), drawn_boxes
# -----------------------------
# Pastel visualization helpers
# -----------------------------
def _hsv_to_rgb255(h, s, v):
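    """Convert HSV components in [0, 1] to an 8-bit RGB tuple.

    >>> _hsv_to_rgb255(0.0, 0.0, 1.0)
    (255, 255, 255)
    """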
r, g, b = colorsys.hsv_to_rgb(h, s, v)
return (int(255 * r), int(255 * g), int(255 * b))
def instance_colors(i: int):
    """Deterministic pastel colors for instance index i.

    Hues are stepped by the golden-ratio conjugate (~0.618) so consecutive
    instances receive well-separated hues; low saturation keeps them pastel.
    """
    h = (i * 0.618033988749895) % 1.0
    mask_rgb = _hsv_to_rgb255(h, s=0.28, v=1.00)
    box_rgb = _hsv_to_rgb255(h, s=0.42, v=0.95)
    return mask_rgb, box_rgb
def overlay_single_mask(base_rgba: Image.Image, mask_bool: np.ndarray, rgb, alpha=0.45):
if mask_bool.dtype != np.bool_:
mask_bool = mask_bool.astype(bool)
h, w = mask_bool.shape
overlay = np.zeros((h, w, 4), dtype=np.uint8)
overlay[..., 0] = rgb[0]
overlay[..., 1] = rgb[1]
overlay[..., 2] = rgb[2]
overlay[..., 3] = (mask_bool.astype(np.uint8) * int(255 * alpha))
overlay_img = Image.fromarray(overlay, mode="RGBA")
return Image.alpha_composite(base_rgba, overlay_img)
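# Usage sketch for overlay_single_mask (illustrative): composite one boolean mask
# onto an RGBA base image.
#   base = Image.fromarray(np_img).convert("RGBA")          # np_img: HxWx3 uint8
#   base = overlay_single_mask(base, mask_bool, (255, 183, 183), alpha=0.45)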
# -----------------------------
# Post-process and Update Output
# -----------------------------
def post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold):
    idx = 0
    # The slider value t acts as a relative cutoff: keep scores above t * max(score).
    threshold = 1 / threshold
score = outputs[idx]["box_v"]
if score.numel() == 0:
# no predictions
        image_pil = Image.fromarray(image.astype(np.uint8)).convert("RGB")
return image_pil, 0
score_mask = score > score.max() / threshold
keep = ops.nms(
outputs[idx]["pred_boxes"][score_mask],
score[score_mask],
0.5,
)
    pred_boxes = outputs[idx]["pred_boxes"][score_mask][keep]
    pred_boxes = torch.clamp(pred_boxes, 0, 1)
    # Boxes are normalized w.r.t. the padded square input; map them back to original pixels.
    pred_boxes = (pred_boxes / scale * img.shape[-1]).tolist()
    image = Image.fromarray(image.astype(np.uint8)).convert("RGBA")
if enable_mask and masks is not None and masks[idx] is not None:
masks_sel = masks[idx][score_mask[0]] if score_mask.ndim > 1 else masks[idx][score_mask]
masks_sel = masks_sel[keep]
target_h = int(img.shape[2] / scale)
target_w = int(img.shape[3] / scale)
resize_nearest = T.Resize((target_h, target_w), interpolation=T.InterpolationMode.NEAREST)
W, H = image.size
for i in range(masks_sel.shape[0]):
mask_i = masks_sel[i]
if mask_i.ndim == 3:
mask_i = mask_i[0]
mask_rs = resize_nearest(mask_i.unsqueeze(0))[0]
mask_rs = mask_rs[:H, :W]
mask_bool = (mask_rs > 0.0).cpu().numpy().astype(bool)
mask_rgb, _ = instance_colors(i)
image = overlay_single_mask(image, mask_bool, mask_rgb, alpha=0.45)
draw = ImageDraw.Draw(image)
box_width = 2
for i, box in enumerate(pred_boxes):
_, box_rgb = instance_colors(i)
x1, y1, x2, y2 = map(float, box)
draw.rectangle([x1, y1, x2, y2], outline=box_rgb, width=box_width)
exemplar_outline = (255, 255, 255, 255)
exemplar_inner = (0, 0, 0, 255)
for box in drawn_boxes:
x1, y1, x2, y2 = box[0], box[1], box[3], box[4]
draw.rectangle([x1, y1, x2, y2], outline=exemplar_outline, width=2)
draw.rectangle([x1 + 1, y1 + 1, x2 - 1, y2 - 1], outline=exemplar_inner, width=1)
return image.convert("RGB"), len(pred_boxes)
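# Worked example (illustrative): with the slider at 0.33 and a maximum score of 0.9,
# post_process keeps predictions scoring above 0.9 * 0.33 = 0.297, then applies NMS
# at IoU 0.5 before drawing.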
# -----------------------------
# Examples: gallery click -> set annotator value
# -----------------------------
EXAMPLE_PATHS = [
    "material/01.jpg", "material/00.jpg", "material/02.jpg", "material/03.jpg",
    "material/05.jpg", "material/04.jpg", "material/06.jpg",
]
def load_example_from_gallery(evt: gr.SelectData):
"""
When user clicks a thumbnail in the gallery, load that image into the annotator.
"""
idx = int(evt.index)
path = EXAMPLE_PATHS[idx]
return {"image": path, "boxes": []}
# -----------------------------
# Gradio UI
# -----------------------------
iface = gr.Blocks(
title="GeCo2 Gradio Demo",
)
with iface:
gr.Markdown(
"""
# GeCo2: Generalized-Scale Object Counting with Gradual Query Aggregation
        GeCo2 is a few-shot, category-agnostic detection-based counter. Given only a few exemplar boxes, it detects and counts all instances of the target object in an image without any retraining.
        1) Upload an image or click an example below.
        2) Draw bounding boxes around the target object (about 3 instances works well).
3) Click **Count**.
4) If needed, adjust the threshold.
"""
)
# Store intermediate states
image_input = gr.State()
outputs_state = gr.State()
masks_state = gr.State()
img_state = gr.State()
scale_state = gr.State()
drawn_boxes_state = gr.State()
with gr.Row():
annotator = image_annotator(
value=None,
image_type="numpy", # ensures inputs["image"] is a numpy array
label_list=["Object"],
label_colors=[(0, 255, 0)],
use_default_label=True,
enable_keyboard_shortcuts=True,
interactive=True,
show_label=False,
box_min_size=3,
box_thickness=1,
)
image_output = gr.Image(type="pil")
with gr.Row():
count_output = gr.Number(label="Total Count")
enable_mask = gr.Checkbox(label="Predict masks", value=True)
threshold = gr.Slider(0.05, 0.95, value=0.33, step=0.01, label="Threshold")
count_button = gr.Button("Count")
gallery = gr.Gallery(
value=EXAMPLE_PATHS,
columns=7,
height=300,
label="Examples (click an image to load it into the annotator)",
show_label=True,
allow_preview=False,
)
gallery.select(
fn=load_example_from_gallery,
inputs=None,
outputs=annotator,
)
def initial_process(inputs, enable_mask, threshold):
# Validate: must have at least one box
if inputs is None or inputs.get("image", None) is None:
gr.Warning("please delineate at least one target category object")
return None, 0, None, None, None, None, None, None
img_val = inputs.get("image", None)
boxes = inputs.get("boxes", []) or []
if len(boxes) == 0:
# Try to show current image in the output even if no boxes
if isinstance(img_val, str):
preview = Image.open(img_val).convert("RGB")
elif isinstance(img_val, Image.Image):
preview = img_val.convert("RGB")
elif isinstance(img_val, np.ndarray):
preview = Image.fromarray(img_val.astype(np.uint8)).convert("RGB")
else:
preview = None
gr.Warning("please delineate at least one target category object")
return preview, 0, None, None, None, None, None, None
image, outputs, masks, img, scale, drawn_boxes = process_image_once(inputs, enable_mask)
if image is None:
return None, 0, None, None, None, None, None, None
out_img, cnt = post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold)
return (
out_img,
cnt,
image,
outputs,
masks,
img,
scale,
drawn_boxes,
)
def update_threshold(threshold, image, outputs, masks, img, scale, drawn_boxes, enable_mask):
if image is None or outputs is None or img is None:
return None, 0
return post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold)
count_button.click(
initial_process,
[annotator, enable_mask, threshold],
[image_output, count_output, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state],
)
threshold.change(
update_threshold,
[threshold, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state, enable_mask],
[image_output, count_output],
)
enable_mask.change(
update_threshold,
[threshold, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state, enable_mask],
[image_output, count_output],
)
if __name__ == "__main__":
iface.queue().launch(ssr_mode=False)