Spaces:
Runtime error
Runtime error
File size: 5,476 Bytes
a62b942 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 | """
Visual grounding: count objects in a rendered image.
Two backends, selected per call via ``count_objects(backend=...)``:
1. Blob counter — dependency-light baseline: grayscale threshold plus
   flood-fill connected components in pure NumPy
2. OWL-ViT (google/owlvit-base-patch32 via transformers) — open-vocabulary
   detector for richer category recognition
With backend="auto", OWL-ViT is tried first and the blob counter is used if
it fails. Both paths return the same (count, backend_used) tuple, so the rest
of the system is backend-agnostic.
"""
from __future__ import annotations
import io
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
# ------------------------------------------------------------------
# Blob counter baseline (no heavy deps)
# ------------------------------------------------------------------
def _blob_count(image: np.ndarray, min_area: int = 30) -> int:
"""
Count distinct objects in a rendered stimulus image.
Objects are assumed darker than the background (coloured shapes on white).
"""
# Convert to grayscale
if image.ndim == 3:
gray = (0.299 * image[:, :, 0] +
0.587 * image[:, :, 1] +
0.114 * image[:, :, 2]).astype(np.float32)
else:
gray = image.astype(np.float32)
# Detect objects: pixels significantly darker than background.
# Use mean - 0.5*std as threshold so sparse circles on white are captured.
thresh = gray.mean() - 0.5 * gray.std()
binary = (gray < thresh).astype(np.uint8)
# Connected components via simple flood-fill BFS
visited = np.zeros_like(binary, dtype=bool)
count = 0
rows, cols = binary.shape
def bfs(r0, c0):
area = 0
stack = [(r0, c0)]
while stack:
r, c = stack.pop()
if r < 0 or r >= rows or c < 0 or c >= cols:
continue
if visited[r, c] or binary[r, c] == 0:
continue
visited[r, c] = True
area += 1
stack.extend([(r+1,c),(r-1,c),(r,c+1),(r,c-1)])
return area
for r in range(rows):
for c in range(cols):
if binary[r, c] == 1 and not visited[r, c]:
area = bfs(r, c)
if area >= min_area:
count += 1
return count
# ------------------------------------------------------------------
# OWL-ViT backend (google/owlvit-base-patch32 via transformers)
# ------------------------------------------------------------------
_owlvit_pipeline = None
def _load_owlvit():
global _owlvit_pipeline
if _owlvit_pipeline is None:
from transformers import pipeline
_owlvit_pipeline = pipeline(
"zero-shot-object-detection",
model="google/owlvit-base-patch32",
device=-1, # CPU
)
return _owlvit_pipeline
def _owlvit_count(
    image: np.ndarray,
    query: str,
    score_threshold: float = 0.10,
) -> int:
    """
    Return how many detections OWL-ViT reports for *query* in *image*.

    Args:
        image: array handed to ``PIL.Image.fromarray``; presumably uint8
            (H, W) or (H, W, 3) — TODO confirm other layouts callers pass.
        query: single text label the detector searches for.
        score_threshold: minimum detection score forwarded to the pipeline
            as ``threshold``.

    Returns:
        The number of detections the pipeline returns. This function does no
        de-duplication of its own; whether overlapping boxes are merged
        depends on the pipeline — NOTE(review): confirm.
    """
    from PIL import Image as PILImage

    detector = _load_owlvit()
    detections = detector(
        PILImage.fromarray(image),
        candidate_labels=[query],
        threshold=score_threshold,
    )
    return len(detections)
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
# Set once OWL-ViT fails to load (transformers missing, model download error,
# ...) so later calls skip the costly retry and go straight to the blob counter.
_owlvit_load_failed = False


def count_objects(
    image: np.ndarray,
    query: str = "",
    backend: str = "auto",
) -> Tuple[int, str]:
    """
    Count objects in *image*.

    Args:
        image: uint8 numpy array (H, W) or (H, W, 3)
        query: text description of what to count (used by OWL-ViT only);
            an empty string is sent to OWL-ViT as "object".
        backend: "blob" | "owlvit" | "auto"
            "blob" always uses the blob counter.
            "owlvit" and "auto" both try OWL-ViT first and fall back to the
            blob counter if it cannot be loaded or inference raises.
            Any other value falls back to the blob counter with a warning.

    Returns:
        (count, backend_used) where backend_used is "owlvit" or "blob".

    Note:
        A failure to *load* OWL-ViT is remembered for the life of the process.
        This means a transient error (e.g. a network hiccup during the model
        download) keeps the process on the blob backend until restart.
    """
    global _owlvit_load_failed
    import logging

    logger = logging.getLogger(__name__)

    if backend == "blob":
        return _blob_count(image), "blob"

    if backend in ("owlvit", "auto"):
        if not _owlvit_load_failed:
            try:
                _load_owlvit()
            except Exception:
                # Backend boundary: any load failure means "OWL-ViT unavailable".
                _owlvit_load_failed = True
                logger.warning(
                    "OWL-ViT backend unavailable; using blob counter from now on",
                    exc_info=True,
                )
        if not _owlvit_load_failed:
            try:
                return _owlvit_count(image, query or "object"), "owlvit"
            except Exception:
                # Inference errors may be input-specific (e.g. an array PIL
                # cannot convert), so fall back for this call only.
                logger.warning(
                    "OWL-ViT inference failed; falling back to blob counter",
                    exc_info=True,
                )
    else:
        logger.warning("Unknown backend %r; falling back to blob counter", backend)

    return _blob_count(image), "blob"
def load_image(path: str | Path) -> np.ndarray:
    """
    Load an image file into a uint8 RGB numpy array.

    Args:
        path: filesystem path to an image Pillow can open.

    Returns:
        Array of shape (H, W, 3), dtype uint8; the image is converted to
        RGB mode with Pillow's ``convert("RGB")`` regardless of its source mode.
    """
    from PIL import Image as PILImage

    # The context manager closes the underlying file as soon as the pixels are
    # decoded. Without it the handle stays open until the Image is garbage-collected.
    with PILImage.open(str(path)) as img:
        return np.array(img.convert("RGB"), dtype=np.uint8)
def render_counting_stimulus(
    n: int,
    label: str = "●",
    grid_size: int = 128,
) -> np.ndarray:
    """
    Render a simple counting stimulus: *n* circles on a white background.

    Used when no pre-rendered image asset is available. Circles are placed
    row-major on a grid of at most 5 columns inside a 10-pixel margin, each
    centred in its cell with radius one third of the smaller cell side.

    Args:
        n: number of circles to draw; n <= 0 yields a plain white image.
        label: currently unused; kept for API compatibility.
        grid_size: side length in pixels of the square output image.

    Returns:
        A (grid_size, grid_size, 3) uint8 array. If Pillow is not installed,
        a uniform light-gray (240) array with no circles is returned instead.
    """
    # Keep the try scoped to the import: only a missing Pillow should trigger
    # the fallback, not errors raised while drawing.
    try:
        from PIL import Image as PILImage, ImageDraw
    except ImportError:
        # Best-effort fallback: nothing can be drawn without Pillow.
        return np.full((grid_size, grid_size, 3), 240, dtype=np.uint8)

    if n <= 0:
        # Negative n would draw nothing anyway; return the blank white canvas.
        return np.full((grid_size, grid_size, 3), 255, dtype=np.uint8)

    img = PILImage.new("RGB", (grid_size, grid_size), (255, 255, 255))
    draw = ImageDraw.Draw(img)
    margin = 10
    cols = min(n, 5)
    rows = (n + cols - 1) // cols  # ceil(n / cols)
    cell_w = (grid_size - 2 * margin) // cols
    cell_h = (grid_size - 2 * margin) // rows
    r = min(cell_w, cell_h) // 3
    for i in range(n):
        row, col = divmod(i, cols)
        cx = margin + col * cell_w + cell_w // 2
        cy = margin + row * cell_h + cell_h // 2
        draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=(60, 120, 220))
    return np.array(img, dtype=np.uint8)
|