Spaces:
Runtime error
Runtime error
| """ | |
| Visual grounding: count objects in a rendered image. | |
| Two backends (OWL-ViT is tried first; the blob counter is the fallback): | |
| 1. BlobCounter — fast baseline with no heavy dependencies (NumPy-only flood fill) | |
| 2. OWL-ViT (google/owlvit-base-patch32) — open-vocabulary detector for richer category recognition | |
| The backend is selected on each call to count_objects(); the public API is identical | |
| for both so the rest of the system is backend-agnostic. | |
| """ | |
| from __future__ import annotations | |
| import io | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple | |
| import numpy as np | |
| # ------------------------------------------------------------------ | |
| # Blob counter baseline (no heavy deps) | |
| # ------------------------------------------------------------------ | |
def _blob_count(image: np.ndarray, min_area: int = 30) -> int:
    """
    Count connected dark regions in a rendered stimulus image.

    The image is reduced to luminance, pixels darker than
    ``mean - 0.5 * std`` are treated as foreground, and 4-connected
    foreground regions are counted.

    Args:
        image: 2-D grayscale array, or 3-D array whose first three
               channels are weighted as R, G, B.
        min_area: regions with fewer pixels than this are ignored.
    Returns:
        Number of foreground regions with at least *min_area* pixels.
    """
    # Luminance conversion (ITU-R 601 weights) for colour input.
    if image.ndim != 3:
        luma = image.astype(np.float32)
    else:
        luma = (0.299 * image[:, :, 0] +
                0.587 * image[:, :, 1] +
                0.114 * image[:, :, 2]).astype(np.float32)

    # Anything noticeably darker than the average pixel is foreground;
    # the half-sigma margin keeps sparse shapes on white detectable.
    cutoff = luma.mean() - 0.5 * luma.std()
    foreground = (luma < cutoff).astype(np.uint8)

    seen = np.zeros_like(foreground, dtype=bool)
    height, width = foreground.shape
    neighbour_steps = ((1, 0), (-1, 0), (0, 1), (0, -1))

    def component_area(seed_row, seed_col):
        """Flood-fill the region containing the seed and return its size."""
        pending = [(seed_row, seed_col)]
        size = 0
        while pending:
            row, col = pending.pop()
            if not (0 <= row < height and 0 <= col < width):
                continue
            if seen[row, col] or foreground[row, col] == 0:
                continue
            seen[row, col] = True
            size += 1
            pending.extend((row + dr, col + dc) for dr, dc in neighbour_steps)
        return size

    blobs = 0
    # argwhere yields foreground coordinates in row-major order.
    for row, col in np.argwhere(foreground).tolist():
        if seen[row, col]:
            continue
        if component_area(row, col) >= min_area:
            blobs += 1
    return blobs
| # ------------------------------------------------------------------ | |
| # OWLViT-tiny backend | |
| # ------------------------------------------------------------------ | |
# Module-level cache so the detector is only constructed once per process.
_owlvit_pipeline = None


def _load_owlvit():
    """Return the cached zero-shot detection pipeline, creating it on first use."""
    global _owlvit_pipeline
    if _owlvit_pipeline is not None:
        return _owlvit_pipeline

    # Imported here so the module can be loaded without transformers installed.
    from transformers import pipeline as build_pipeline

    _owlvit_pipeline = build_pipeline(
        "zero-shot-object-detection",
        model="google/owlvit-base-patch32",
        device=-1,  # CPU
    )
    return _owlvit_pipeline
def _owlvit_count(
    image: np.ndarray,
    query: str,
    score_threshold: float = 0.10,
) -> int:
    """
    Count detections of *query* in *image* using the OWL-ViT pipeline.

    Args:
        image: array accepted by ``PIL.Image.fromarray``.
        query: the single candidate label passed to the detector.
        score_threshold: forwarded to the pipeline as ``threshold``.
    Returns:
        Number of results returned by the pipeline.
    """
    from PIL import Image as PILImage

    detector = _load_owlvit()
    detections = detector(
        PILImage.fromarray(image),
        candidate_labels=[query],
        threshold=score_threshold,
    )
    return len(detections)
| # ------------------------------------------------------------------ | |
| # Public API | |
| # ------------------------------------------------------------------ | |
def count_objects(
    image: np.ndarray,
    query: str = "",
    backend: str = "auto",
) -> Tuple[int, str]:
    """
    Count objects in *image*.

    Args:
        image: uint8 numpy array (H, W) or (H, W, 3)
        query: text description of what to count (used by OWLViT only);
               an empty query is replaced by "object"
        backend: "blob" | "owlvit" | "auto"
                 "owlvit" and "auto" both try OWLViT first and fall back
                 to the blob counter if it fails for any reason.
                 Any other value uses the blob counter.
    Returns:
        (count, backend_used) — backend_used is "owlvit" or "blob", so the
        caller can tell when a fallback happened.
    """
    if backend in ("owlvit", "auto"):
        try:
            return _owlvit_count(image, query or "object"), "owlvit"
        except Exception:
            # Best-effort backend: a missing package, a failed model load or
            # an inference error all degrade to the baseline. Record the
            # cause instead of discarding it so failures can be diagnosed.
            import logging

            # An explicit request that fails is more surprising than "auto".
            level = logging.WARNING if backend == "owlvit" else logging.DEBUG
            logging.getLogger(__name__).log(
                level,
                "OWL-ViT backend failed; falling back to blob counter",
                exc_info=True,
            )
    return _blob_count(image), "blob"
def load_image(path: str | Path) -> np.ndarray:
    """
    Load an image file to a uint8 numpy array.

    Args:
        path: location of the image file.
    Returns:
        The image converted to RGB, as a uint8 array.
    """
    from PIL import Image as PILImage

    # Use the context manager so the underlying file is closed as soon as
    # the pixel data has been converted, rather than whenever the image
    # object happens to be garbage-collected.
    with PILImage.open(str(path)) as source:
        rgb = source.convert("RGB")
    return np.array(rgb, dtype=np.uint8)
def _stimulus_layout(
    n: int,
    grid_size: int,
    margin: int = 10,
) -> Tuple[List[Tuple[int, int]], int]:
    """Return the (cx, cy) centres and shared radius for *n* >= 1 circles."""
    cols = min(n, 5)  # at most five circles per row
    rows = (n + cols - 1) // cols
    cell_w = (grid_size - 2 * margin) // cols
    cell_h = (grid_size - 2 * margin) // rows
    radius = min(cell_w, cell_h) // 3
    centres = [
        (
            margin + (i % cols) * cell_w + cell_w // 2,
            margin + (i // cols) * cell_h + cell_h // 2,
        )
        for i in range(n)
    ]
    return centres, radius


def render_counting_stimulus(
    n: int,
    label: str = "●",
    grid_size: int = 128,
) -> np.ndarray:
    """
    Render a simple counting stimulus: *n* circles on a white background.

    Used when no pre-rendered image asset is available. Circles are laid out
    row by row, up to five per row, inside a 10-pixel margin.

    Args:
        n: number of circles; zero or a negative value gives a blank image.
        label: accepted for interface compatibility; not used when drawing.
        grid_size: width and height of the square image in pixels.
    Returns:
        A (grid_size, grid_size, 3) uint8 array.
    """
    fill = (60, 120, 220)

    if n <= 0:
        return np.full((grid_size, grid_size, 3), 255, dtype=np.uint8)

    centres, radius = _stimulus_layout(n, grid_size)

    try:
        from PIL import Image as PILImage, ImageDraw
    except ImportError:
        # Pillow is unavailable: draw the same layout with NumPy so the
        # stimulus still contains n objects instead of being blank.
        # Edge pixels may differ slightly from Pillow's rasterisation.
        canvas = np.full((grid_size, grid_size, 3), 255, dtype=np.uint8)
        yy, xx = np.ogrid[:grid_size, :grid_size]
        for cx, cy in centres:
            inside = (xx - cx) ** 2 + (yy - cy) ** 2 <= radius ** 2
            canvas[inside] = fill
        return canvas

    img = PILImage.new("RGB", (grid_size, grid_size), (255, 255, 255))
    draw = ImageDraw.Draw(img)
    for cx, cy in centres:
        draw.ellipse(
            [cx - radius, cy - radius, cx + radius, cy + radius],
            fill=fill,
        )
    return np.array(img, dtype=np.uint8)