Spaces:
Running
on
Zero
Running
on
Zero
| import spaces | |
| import torch | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| from gradio_image_annotation import image_annotator | |
| from models.counter_infer import build_model | |
| from utils.arg_parser import get_argparser | |
| from utils.data import resize_and_pad | |
| import torchvision.ops as ops | |
| from torchvision import transforms as T | |
| from PIL import Image, ImageDraw | |
| from huggingface_hub import hf_hub_download | |
| import numpy as np | |
| import colorsys | |
# -----------------------------
# Lazily-initialised module-level caches; each stays None until its getter first runs.
_ARGS = None  # parsed CLI args, set by _get_args()
_WEIGHTS_PATH = None  # local checkpoint path, set by _get_weights_path()
_MODEL = None  # built and loaded model, set by get_model_on_device()
# -----------------------------
def _get_args():
    """
    Return the process-wide argparse namespace, parsing the command line on first use.

    ``zero_shot`` is forced to True on the freshly parsed namespace before it is
    cached in ``_ARGS``; later calls return the cached object unchanged.
    """
    global _ARGS
    if _ARGS is not None:
        return _ARGS
    parsed = get_argparser().parse_args()
    parsed.zero_shot = True
    _ARGS = parsed
    return _ARGS
def _get_weights_path():
    """
    Return the local path of the GeCo2 checkpoint, resolving it on first use.

    The first call asks ``hf_hub_download`` for the file in the
    ``jerpelhan/geco2-assets`` dataset repo; the returned path is cached in
    ``_WEIGHTS_PATH`` and reused by every later call.
    """
    global _WEIGHTS_PATH
    if _WEIGHTS_PATH is not None:
        return _WEIGHTS_PATH
    download_kwargs = dict(
        repo_id="jerpelhan/geco2-assets",
        filename="weights/CNTQG_multitrain_ca44.pth",
        repo_type="dataset",
    )
    _WEIGHTS_PATH = hf_hub_download(**download_kwargs)
    return _WEIGHTS_PATH
| def _strip_module_prefix(state_dict: dict) -> dict: | |
| """ | |
| If weights were saved from torch.nn.DataParallel, keys are often prefixed with 'module.'. | |
| When loading into a non-DataParallel model, strip that prefix. | |
| """ | |
| if not isinstance(state_dict, dict) or len(state_dict) == 0: | |
| return state_dict | |
| # Only strip if it looks like DP | |
| has_module = any(k.startswith("module.") for k in state_dict.keys()) | |
| if not has_module: | |
| return state_dict | |
| return {k[len("module.") :]: v for k, v in state_dict.items()} | |
| def _extract_state_dict(ckpt) -> dict: | |
| """ | |
| Robustly extract a state_dict from typical checkpoint formats. | |
| """ | |
| if isinstance(ckpt, dict): | |
| # Common keys | |
| if "model" in ckpt and isinstance(ckpt["model"], dict): | |
| return ckpt["model"] | |
| if "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict): | |
| return ckpt["state_dict"] | |
| # Fallback: checkpoint itself is the state_dict | |
| return ckpt | |
def get_model_on_device(device: torch.device):
    """
    Return the GeCo2 model moved to ``device``, building and loading it on first use.

    First call:
      * builds the model from the CLI args (``_get_args()``);
      * resolves the checkpoint via ``_get_weights_path()`` and loads it with
        ``map_location="cpu"``;
      * unwraps and prefix-strips the state_dict, loads it non-strictly, and
        switches the model to eval mode;
      * caches the result in ``_MODEL``.

    Every call moves the cached model to ``device`` and, for CUDA devices,
    enables cuDNN autotuning.

    IMPORTANT: model is constructed/loaded without initializing CUDA in the main process.
    This function will be called from inside the @spaces.GPU worker.

    Because loading uses ``strict=False``, missing or unexpected parameter names
    are reported with a ``RuntimeWarning`` rather than silently ignored. A
    mismatch leaves the affected parameters at their initial values.
    """
    global _MODEL
    import warnings

    if _MODEL is None:
        args = _get_args()
        # Build on CPU first to avoid CUDA init in the wrong process
        model = build_model(args)
        weights_path = _get_weights_path()
        ckpt = torch.load(weights_path, map_location="cpu")  # keep compatibility across torch versions
        state = _strip_module_prefix(_extract_state_dict(ckpt))
        incompatible = model.load_state_dict(state, strict=False)
        missing = list(getattr(incompatible, "missing_keys", []) or [])
        unexpected = list(getattr(incompatible, "unexpected_keys", []) or [])
        if missing or unexpected:
            warnings.warn(
                f"Checkpoint {weights_path} does not match the model: "
                f"{len(missing)} missing key(s) (e.g. {missing[:3]}), "
                f"{len(unexpected)} unexpected key(s) (e.g. {unexpected[:3]}).",
                RuntimeWarning,
                stacklevel=2,
            )
        model.eval()
        _MODEL = model
    _MODEL = _MODEL.to(device)
    if device.type == "cuda":
        torch.backends.cudnn.benchmark = True
    return _MODEL
| # ----------------------------- | |
| # Rotation helper (in case annotator reports orientation) | |
| # ----------------------------- | |
| def _rotate_image_and_boxes(image_np: np.ndarray, boxes: list[dict], angle: int): | |
| if angle is None: | |
| return image_np, boxes | |
| a = int(angle) % 4 | |
| if a == 0: | |
| return image_np, boxes | |
| H, W = image_np.shape[:2] | |
| # rotate image using the same convention as the component docs | |
| image_rot = np.rot90(image_np, k=-a) | |
| def clamp_box(xmin, ymin, xmax, ymax, newW, newH): | |
| xmin = max(0, min(newW, xmin)) | |
| xmax = max(0, min(newW, xmax)) | |
| ymin = max(0, min(newH, ymin)) | |
| ymax = max(0, min(newH, ymax)) | |
| if xmax < xmin: | |
| xmin, xmax = xmax, xmin | |
| if ymax < ymin: | |
| ymin, ymax = ymax, ymin | |
| return xmin, ymin, xmax, ymax | |
| boxes_rot = [] | |
| if a == 1: | |
| # 90 deg clockwise: (x,y) -> (H - 1 - y, x) | |
| newH, newW = W, H | |
| for b in boxes: | |
| xmin, ymin, xmax, ymax = b["xmin"], b["ymin"], b["xmax"], b["ymax"] | |
| nxmin = H - ymax | |
| nxmax = H - ymin | |
| nymin = xmin | |
| nymax = xmax | |
| nxmin, nymin, nxmax, nymax = clamp_box(nxmin, nymin, nxmax, nymax, newW, newH) | |
| bb = dict(b) | |
| bb.update({"xmin": nxmin, "ymin": nymin, "xmax": nxmax, "ymax": nymax}) | |
| boxes_rot.append(bb) | |
| elif a == 2: | |
| # 180 deg: (x,y) -> (W - 1 - x, H - 1 - y) | |
| newH, newW = H, W | |
| for b in boxes: | |
| xmin, ymin, xmax, ymax = b["xmin"], b["ymin"], b["xmax"], b["ymax"] | |
| nxmin = W - xmax | |
| nxmax = W - xmin | |
| nymin = H - ymax | |
| nymax = H - ymin | |
| nxmin, nymin, nxmax, nymax = clamp_box(nxmin, nymin, nxmax, nymax, newW, newH) | |
| bb = dict(b) | |
| bb.update({"xmin": nxmin, "ymin": nymin, "xmax": nxmax, "ymax": nymax}) | |
| boxes_rot.append(bb) | |
| else: # a == 3 | |
| # 90 deg counter-clockwise: (x,y) -> (y, W - 1 - x) | |
| newH, newW = W, H | |
| for b in boxes: | |
| xmin, ymin, xmax, ymax = b["xmin"], b["ymin"], b["xmax"], b["ymax"] | |
| nxmin = ymin | |
| nxmax = ymax | |
| nymin = W - xmax | |
| nymax = W - xmin | |
| nxmin, nymin, nxmax, nymax = clamp_box(nxmin, nymin, nxmax, nymax, newW, newH) | |
| bb = dict(b) | |
| bb.update({"xmin": nxmin, "ymin": nymin, "xmax": nxmax, "ymax": nymax}) | |
| boxes_rot.append(bb) | |
| return image_rot, boxes_rot | |
# -----------------------------
# Function to Process Image Once (GPU)
# -----------------------------
def _to_rgb_array(image) -> np.ndarray:
    """
    Return the annotator image as a C-contiguous HxWx3 numpy array.

    Accepts a numpy array, a PIL image, or a local file path. Grayscale arrays
    (HxW or HxWx1) are replicated to three channels and an alpha channel is
    dropped, because the model input is normalised with 3-channel RGB statistics.

    Raises:
        ValueError: for any other input type.
    """
    if isinstance(image, Image.Image):
        arr = np.array(image.convert("RGB"))
    elif isinstance(image, str):
        # Context manager closes the file handle once the pixels are copied out.
        with Image.open(image) as pil_image:
            arr = np.array(pil_image.convert("RGB"))
    elif isinstance(image, np.ndarray):
        arr = image
        if arr.ndim == 2:
            arr = arr[..., None]
        if arr.ndim == 3 and arr.shape[2] == 1:
            arr = np.repeat(arr, 3, axis=2)
        elif arr.ndim == 3 and arr.shape[2] == 4:
            arr = arr[..., :3]
    else:
        raise ValueError(f"Unsupported image type from annotator: {type(image)}")
    return np.ascontiguousarray(arr)


def process_image_once(inputs, enable_mask):
    """
    Run the counter once on an annotated image and return CPU-only results.

    ``inputs`` is the AnnotatedImageValue-like dict from gradio_image_annotation::

        {
            "image": np.ndarray | PIL.Image | str (local path),
            "boxes": [ {xmin, ymin, xmax, ymax, label?, color?}, ... ],
            "orientation": int?   # image is rotated clockwise by this many quarter-turns
        }

    Returns a 6-tuple ``(image, outputs, masks, img, scale, drawn_boxes)``:
        image: the (possibly rotated) HxWx3 input array, or None when no image was given.
        outputs: ``[{"pred_boxes": Tensor, "box_v": Tensor}]`` on CPU; the tensors
            are empty when the model was not run.
        masks: ``[Tensor]`` when masks were requested and produced, else ``[None]``.
        img: the normalised model-input batch after ``resize_and_pad``, on CPU
            (a placeholder tensor when the model was not run).
        scale: resize factor from ``resize_and_pad`` (1.0 when the model was not run).
        drawn_boxes: exemplar boxes as ``[xmin, ymin, 0.0, xmax, ymax]``.

    The model is only loaded/run when an image and at least one box are present.

    NOTE(review): ``spaces`` is imported and get_model_on_device() expects to run
    inside a @spaces.GPU worker, yet this function has no ``@spaces.GPU``
    decorator. On ZeroGPU hardware that likely means inference runs on CPU;
    confirm whether the decorator was intended here.
    """
    empty_outputs = [{"pred_boxes": torch.empty(0, 4), "box_v": torch.empty(0)}]
    if inputs is None or inputs.get("image", None) is None:
        # keep behavior simple: return empty outputs
        return None, empty_outputs, [None], torch.empty(1), 1.0, []

    image = _to_rgb_array(inputs["image"])
    boxes = inputs.get("boxes", []) or []

    angle = inputs.get("orientation", None)
    if angle is not None:
        image, boxes = _rotate_image_and_boxes(image, boxes, angle)
        # np.rot90 yields a negatively-strided view, which torch.tensor() rejects.
        image = np.ascontiguousarray(image)

    # Layout [xmin, ymin, 0.0, xmax, ymax]; post_process reads indices 0, 1, 3 and 4.
    drawn_boxes = [
        [float(b["xmin"]), float(b["ymin"]), 0.0, float(b["xmax"]), float(b["ymax"])]
        for b in boxes
    ]
    # If no boxes, do not call model (caller will handle warning)
    if not drawn_boxes:
        return image, empty_outputs, [None], torch.empty(1), 1.0, []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model_on_device(device)

    image_tensor = torch.tensor(image).to(device)
    image_tensor = image_tensor.permute(2, 0, 1).float() / 255.0
    image_tensor = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image_tensor)
    bboxes_tensor = torch.tensor(
        [[box[0], box[1], box[3], box[4]] for box in drawn_boxes],
        dtype=torch.float32,
    ).to(device)

    img, bboxes, scale = resize_and_pad(image_tensor, bboxes_tensor, size=1024.0)
    img = img.unsqueeze(0).to(device)
    bboxes = bboxes.unsqueeze(0).to(device)

    # Mixed precision only on CUDA.
    use_amp = device.type == "cuda"
    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
        model.return_masks = enable_mask
        outputs, _, _, _, masks = model(img, bboxes)

    # Return ONLY CPU-native objects to main process.
    out0 = outputs[0]
    outputs_cpu = [
        {
            "pred_boxes": out0["pred_boxes"].detach().float().cpu(),
            "box_v": out0["box_v"].detach().float().cpu(),
        }
    ]
    if enable_mask and masks is not None and masks[0] is not None:
        masks_cpu = [masks[0].detach().float().cpu()]
    else:
        masks_cpu = [None]
    return image, outputs_cpu, masks_cpu, img.detach().cpu(), float(scale), drawn_boxes
| # ----------------------------- | |
| # Pastel visualization helpers | |
| # ----------------------------- | |
| def _hsv_to_rgb255(h, s, v): | |
| r, g, b = colorsys.hsv_to_rgb(h, s, v) | |
| return (int(255 * r), int(255 * g), int(255 * b)) | |
def instance_colors(i: int):
    """
    Return the ``(mask_rgb, box_rgb)`` pastel colour pair for instance ``i``.

    The hue advances by 0.618033988749895 (the golden-ratio conjugate) per
    instance, wrapping at 1.0, so consecutive ids land far apart on the colour
    wheel. For the same hue, the mask colour is lighter (lower saturation) than
    the box outline.
    """
    hue = (i * 0.618033988749895) % 1.0
    return (
        _hsv_to_rgb255(hue, s=0.28, v=1.00),  # mask fill
        _hsv_to_rgb255(hue, s=0.42, v=0.95),  # box outline
    )
def overlay_single_mask(base_rgba: Image.Image, mask_bool: np.ndarray, rgb, alpha=0.45):
    """
    Alpha-blend one solid-colour mask onto an RGBA image.

    Args:
        base_rgba: RGBA PIL image to draw on. It is not modified; a new image is returned.
        mask_bool: 2-D mask; truthy pixels receive the colour. If its shape differs
            from the image's (height, width), it is cropped or zero-padded at the
            bottom/right to match, because ``Image.alpha_composite`` raises on a
            size mismatch.
        rgb: colour as an (r, g, b) sequence of 0-255 ints.
        alpha: opacity of the colour over masked pixels, in [0, 1].

    Returns:
        A new RGBA image with the mask composited on top.
    """
    mask = np.asarray(mask_bool).astype(bool, copy=False)
    width, height = base_rgba.size
    if mask.shape != (height, width):
        fitted = np.zeros((height, width), dtype=bool)
        rows = min(height, mask.shape[0])
        cols = min(width, mask.shape[1])
        fitted[:rows, :cols] = mask[:rows, :cols]
        mask = fitted

    overlay = np.zeros((height, width, 4), dtype=np.uint8)
    overlay[..., 0] = rgb[0]
    overlay[..., 1] = rgb[1]
    overlay[..., 2] = rgb[2]
    overlay[..., 3] = mask.astype(np.uint8) * int(255 * alpha)
    # An (H, W, 4) uint8 array is inferred as RGBA; the explicit `mode=` argument
    # of Image.fromarray is deprecated in recent Pillow releases.
    overlay_img = Image.fromarray(overlay)
    return Image.alpha_composite(base_rgba, overlay_img)
# -----------------------------
# Post-process and Update Output
# -----------------------------
def post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold):
    """
    Filter cached predictions by score and NMS, then render them over the image.

    Args:
        image: HxWx3 array of the (possibly rotated) input image, as returned by
            ``process_image_once``.
        outputs: one-element list ``[{"pred_boxes": Tensor, "box_v": Tensor}]``;
            ``box_v`` holds per-detection scores.
        masks: one-element list holding a mask tensor or ``None``.
        img: the model-input batch; only its shape is used here.
        scale: resize factor reported by ``resize_and_pad``.
        drawn_boxes: exemplar boxes as ``[xmin, ymin, _, xmax, ymax]``.
        enable_mask: overlay instance masks when True and masks are available.
        threshold: keep detections scoring above ``threshold * max(score)``.

    Returns:
        ``(rgb_pil_image, count)``, where ``count`` is the number of detections kept.
    """
    idx = 0
    score = outputs[idx]["box_v"]
    if score.numel() == 0:
        # no predictions
        image_pil = Image.fromarray(image.astype(np.uint8)).convert("RGB")
        return image_pil, 0

    # Same cut-off as the former `score > score.max() / (1 / threshold)`, written
    # without the double reciprocal (which also raised on threshold == 0).
    score_mask = score > score.max() * threshold
    boxes_all = outputs[idx]["pred_boxes"]
    keep = ops.nms(boxes_all[score_mask], score[score_mask], 0.5)
    pred_boxes = torch.clamp(boxes_all[score_mask][keep], 0, 1)
    # Assumes pred_boxes are normalised to the padded model input; this maps them
    # back to original-image pixels. TODO confirm against resize_and_pad.
    pred_boxes = (pred_boxes / scale * img.shape[-1]).tolist()

    canvas = Image.fromarray(image.astype(np.uint8)).convert("RGBA")
    W, H = canvas.size
    if enable_mask and masks is not None and masks[idx] is not None:
        masks_sel = masks[idx][score_mask[0]] if score_mask.ndim > 1 else masks[idx][score_mask]
        masks_sel = masks_sel[keep]
        target_h = int(img.shape[2] / scale)
        target_w = int(img.shape[3] / scale)
        resize_nearest = T.Resize((target_h, target_w), interpolation=T.InterpolationMode.NEAREST)
        for i in range(masks_sel.shape[0]):
            mask_i = masks_sel[i]
            if mask_i.ndim == 3:
                mask_i = mask_i[0]
            mask_rs = resize_nearest(mask_i.unsqueeze(0))[0][:H, :W]
            mask_np = (mask_rs > 0.0).cpu().numpy()
            # int() truncation of img.shape / scale can leave the resized mask one
            # pixel short of the image. Zero-pad to exactly (H, W), because
            # alpha_composite raises on a size mismatch.
            mask_bool = np.zeros((H, W), dtype=bool)
            mask_bool[: mask_np.shape[0], : mask_np.shape[1]] = mask_np
            mask_rgb, _ = instance_colors(i)
            canvas = overlay_single_mask(canvas, mask_bool, mask_rgb, alpha=0.45)

    draw = ImageDraw.Draw(canvas)
    box_width = 2
    for i, box in enumerate(pred_boxes):
        _, box_rgb = instance_colors(i)
        x1, y1, x2, y2 = map(float, box)
        # Recent Pillow versions raise on rectangles with x2 < x1 or y2 < y1.
        draw.rectangle(
            [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)],
            outline=box_rgb,
            width=box_width,
        )

    exemplar_outline = (255, 255, 255, 255)
    exemplar_inner = (0, 0, 0, 255)
    for box in drawn_boxes:
        x1, x2 = min(box[0], box[3]), max(box[0], box[3])
        y1, y2 = min(box[1], box[4]), max(box[1], box[4])
        draw.rectangle([x1, y1, x2, y2], outline=exemplar_outline, width=2)
        # The 1-px inset would invert (and make Pillow raise) for boxes under 2 px.
        if x2 - x1 >= 2 and y2 - y1 >= 2:
            draw.rectangle([x1 + 1, y1 + 1, x2 - 1, y2 - 1], outline=exemplar_inner, width=1)

    return canvas.convert("RGB"), len(pred_boxes)
# -----------------------------
# Examples: gallery click -> set annotator value
# -----------------------------
# Example images shown in the gallery; a gallery click index selects from this list.
EXAMPLE_PATHS = [
    "material/01.jpg",
    "material/00.jpg",
    "material/02.jpg",
    "material/03.jpg",
    "material/05.jpg",
    "material/04.jpg",
    "material/06.jpg",
]
def load_example_from_gallery(evt: gr.SelectData):
    """
    Gallery ``select`` handler: load the clicked example into the annotator.

    Returns an annotator value whose image is the clicked example's path and
    whose box list is empty.
    """
    selected_path = EXAMPLE_PATHS[int(evt.index)]
    return {"image": selected_path, "boxes": []}
# -----------------------------
# Gradio UI
# -----------------------------
iface = gr.Blocks(title="GeCo2 Gradio Demo")

with iface:
    gr.Markdown(
        """
# GeCo2: Generalized-Scale Object Counting with Gradual Query Aggregation
GeCo2 is a few-shot, category-agnostic detection counter. With only a small number of exemplars, GeCo2 can detect and count all instances of the target object in an image without any retraining.
1) Upload an image or click an example below.
2) Draw bounding boxes on the target object (preferably ~3 instances).
3) Click **Count**.
4) If needed, adjust the threshold.
"""
    )

    # Per-session copies of the last model run, so threshold / mask toggles can
    # re-render without calling the model again.
    (
        image_input,
        outputs_state,
        masks_state,
        img_state,
        scale_state,
        drawn_boxes_state,
    ) = (gr.State() for _ in range(6))

    # NOTE(review): the source's indentation was lost; Row membership below
    # (count_button inside the second row, gallery outside) is reconstructed.
    with gr.Row():
        annotator = image_annotator(
            value=None,
            image_type="numpy",  # ensures inputs["image"] is a numpy array
            label_list=["Object"],
            label_colors=[(0, 255, 0)],
            use_default_label=True,
            enable_keyboard_shortcuts=True,
            interactive=True,
            show_label=False,
            box_min_size=3,
            box_thickness=1,
        )
        image_output = gr.Image(type="pil")

    with gr.Row():
        count_output = gr.Number(label="Total Count")
        enable_mask = gr.Checkbox(label="Predict masks", value=True)
        threshold = gr.Slider(0.05, 0.95, value=0.33, step=0.01, label="Threshold")
        count_button = gr.Button("Count")

    gallery = gr.Gallery(
        value=EXAMPLE_PATHS,
        columns=7,
        height=300,
        label="Examples (click an image to load it into the annotator)",
        show_label=True,
        allow_preview=False,
    )
    gallery.select(fn=load_example_from_gallery, inputs=None, outputs=annotator)

    def _empty_result(preview=None):
        """Eight-slot result for runs without detections: preview image, zero count, cleared states."""
        return (preview, 0) + (None,) * 6

    def _preview_from_value(img_val):
        """Best-effort RGB PIL preview of the annotator image, or None for unknown types."""
        if isinstance(img_val, str):
            return Image.open(img_val).convert("RGB")
        if isinstance(img_val, Image.Image):
            return img_val.convert("RGB")
        if isinstance(img_val, np.ndarray):
            return Image.fromarray(img_val.astype(np.uint8)).convert("RGB")
        return None

    def initial_process(inputs, enable_mask, threshold):
        """
        Validate the annotation, run the model once, and render the first result.

        Returns 8 values:
            (rendered image, count, image, outputs, masks, img, scale, drawn_boxes)
        The last six are stored in the gr.State slots for later re-rendering.
        """
        warning_text = "please delineate at least one target category object"
        # Validate: must have at least one box
        if inputs is None or inputs.get("image", None) is None:
            gr.Warning(warning_text)
            return _empty_result()

        boxes = inputs.get("boxes", []) or []
        if len(boxes) == 0:
            # Still show the current image in the output even without boxes.
            preview = _preview_from_value(inputs.get("image", None))
            gr.Warning(warning_text)
            return _empty_result(preview)

        image, outputs, masks, img, scale, drawn_boxes = process_image_once(inputs, enable_mask)
        if image is None:
            return _empty_result()
        rendered, count = post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold)
        return rendered, count, image, outputs, masks, img, scale, drawn_boxes

    def update_threshold(threshold, image, outputs, masks, img, scale, drawn_boxes, enable_mask):
        """Re-render the cached predictions after a threshold or mask-toggle change, without calling the model."""
        if any(value is None for value in (image, outputs, img)):
            return None, 0
        return post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold)

    count_button.click(
        fn=initial_process,
        inputs=[annotator, enable_mask, threshold],
        outputs=[
            image_output,
            count_output,
            image_input,
            outputs_state,
            masks_state,
            img_state,
            scale_state,
            drawn_boxes_state,
        ],
    )

    # Both controls re-render from the cached state with the same handler.
    for _rerender_trigger in (threshold, enable_mask):
        _rerender_trigger.change(
            fn=update_threshold,
            inputs=[
                threshold,
                image_input,
                outputs_state,
                masks_state,
                img_state,
                scale_state,
                drawn_boxes_state,
                enable_mask,
            ],
            outputs=[image_output, count_output],
        )

if __name__ == "__main__":
    queued_app = iface.queue()
    queued_app.launch(ssr_mode=False)