Create tutor/visual_grounding.py
Browse files- tutor/visual_grounding.py +50 -0
tutor/visual_grounding.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""tutor/visual_grounding.py — renders a counting stimulus as an RGB numpy array."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
try:
|
| 8 |
+
from PIL import Image, ImageDraw, ImageFont # type: ignore
|
| 9 |
+
_PIL = True
|
| 10 |
+
except ImportError:
|
| 11 |
+
_PIL = False
|
| 12 |
+
|
| 13 |
+
_PALETTE = ["#e74c3c","#3498db","#2ecc71","#f39c12",
|
| 14 |
+
"#9b59b6","#1abc9c","#e91e63","#f1c40f"]
|
| 15 |
+
|
| 16 |
+
def _hex(h):
|
| 17 |
+
h = h.lstrip("#")
|
| 18 |
+
return int(h[:2],16), int(h[2:4],16), int(h[4:],16)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def render_counting_stimulus(count: int, label: str = "●",
|
| 22 |
+
canvas_w: int = 400, canvas_h: int = 300
|
| 23 |
+
) -> Optional[np.ndarray]:
|
| 24 |
+
if not _PIL or count <= 0 or count > 20:
|
| 25 |
+
return None
|
| 26 |
+
img = Image.new("RGB", (canvas_w, canvas_h), (255,255,255))
|
| 27 |
+
draw = ImageDraw.Draw(img)
|
| 28 |
+
try:
|
| 29 |
+
font = ImageFont.truetype(
|
| 30 |
+
"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 22)
|
| 31 |
+
except Exception:
|
| 32 |
+
font = ImageFont.load_default()
|
| 33 |
+
|
| 34 |
+
draw.text((canvas_w//2, 14), f"Count the {label}s!", fill=(30,30,30),
|
| 35 |
+
font=font, anchor="mt")
|
| 36 |
+
|
| 37 |
+
cols = min(count, 5)
|
| 38 |
+
rows = math.ceil(count / cols)
|
| 39 |
+
dot_r = min(26, (canvas_w-40)//(cols*2), (canvas_h-70)//(rows*2))
|
| 40 |
+
xstep = (canvas_w-40)//cols
|
| 41 |
+
ystep = (canvas_h-70)//rows
|
| 42 |
+
x0, y0 = 20+xstep//2, 55+ystep//2
|
| 43 |
+
|
| 44 |
+
for i in range(count):
|
| 45 |
+
cx = x0 + (i % cols)*xstep
|
| 46 |
+
cy = y0 + (i // cols)*ystep
|
| 47 |
+
c = _hex(_PALETTE[i % len(_PALETTE)])
|
| 48 |
+
draw.ellipse([cx-dot_r, cy-dot_r, cx+dot_r, cy+dot_r], fill=c)
|
| 49 |
+
|
| 50 |
+
return np.array(img)
|