NSamson1 commited on
Commit
d023a51
·
verified ·
1 Parent(s): 0bc4913

Create tutor/visual_grounding.py

Browse files
Files changed (1) hide show
  1. tutor/visual_grounding.py +50 -0
tutor/visual_grounding.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """tutor/visual_grounding.py — renders a counting stimulus as an RGB numpy array."""
2
+ from __future__ import annotations
3
+ import math
4
+ import numpy as np
5
+ from typing import Optional
6
+
7
+ try:
8
+ from PIL import Image, ImageDraw, ImageFont # type: ignore
9
+ _PIL = True
10
+ except ImportError:
11
+ _PIL = False
12
+
13
+ _PALETTE = ["#e74c3c","#3498db","#2ecc71","#f39c12",
14
+ "#9b59b6","#1abc9c","#e91e63","#f1c40f"]
15
+
16
+ def _hex(h):
17
+ h = h.lstrip("#")
18
+ return int(h[:2],16), int(h[2:4],16), int(h[4:],16)
19
+
20
+
21
+ def render_counting_stimulus(count: int, label: str = "●",
22
+ canvas_w: int = 400, canvas_h: int = 300
23
+ ) -> Optional[np.ndarray]:
24
+ if not _PIL or count <= 0 or count > 20:
25
+ return None
26
+ img = Image.new("RGB", (canvas_w, canvas_h), (255,255,255))
27
+ draw = ImageDraw.Draw(img)
28
+ try:
29
+ font = ImageFont.truetype(
30
+ "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 22)
31
+ except Exception:
32
+ font = ImageFont.load_default()
33
+
34
+ draw.text((canvas_w//2, 14), f"Count the {label}s!", fill=(30,30,30),
35
+ font=font, anchor="mt")
36
+
37
+ cols = min(count, 5)
38
+ rows = math.ceil(count / cols)
39
+ dot_r = min(26, (canvas_w-40)//(cols*2), (canvas_h-70)//(rows*2))
40
+ xstep = (canvas_w-40)//cols
41
+ ystep = (canvas_h-70)//rows
42
+ x0, y0 = 20+xstep//2, 55+ystep//2
43
+
44
+ for i in range(count):
45
+ cx = x0 + (i % cols)*xstep
46
+ cy = y0 + (i // cols)*ystep
47
+ c = _hex(_PALETTE[i % len(_PALETTE)])
48
+ draw.ellipse([cx-dot_r, cy-dot_r, cx+dot_r, cy+dot_r], fill=c)
49
+
50
+ return np.array(img)