OffGridSchedula / training /render_screenshots.py
ParetoOptimal's picture
Initial Commit
0366d65
Raw
History Blame Contribute Delete
4.66 kB
"""Render eval threads as iMessage-style screenshots for the vision A/B arm.
Each record in an eval jsonl becomes training/data/screenshots/<stem>/<id>.png.
Chat-style threads ("Name: text" lines with a repeated sender or "Me") get
gray/blue bubbles; everything else (flyers, forwarded notices, appointment
cards) renders as a single notice card, line breaks preserved. The PNGs ship
to Modal with the repo (training/modal_eval.py --vision) and are fed to the
model INSTEAD of the thread text by `VISION=1 training/eval.py`.
python training/render_screenshots.py # both eval files
python training/render_screenshots.py training/data/eval.jsonl
"""
from __future__ import annotations
import json
import re
import sys
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
# Two levels above this file; the paths below assume the script sits in
# <ROOT>/training/.
ROOT = Path(__file__).resolve().parent.parent

# Eval sets rendered when no paths are passed on the command line.
DEFAULT_FILES = [
    ROOT / "training" / "data" / "eval.jsonl",
    ROOT / "training" / "data" / "eval_unstructured.jsonl",
]
# One sub-directory per input file (named after its stem) is created here.
OUT_ROOT = ROOT / "training" / "data" / "screenshots"

# Layout, in pixels.
W = 390  # iPhone-ish logical width
PAD = 12  # outer margin of the canvas
BUBBLE_PAD = 10  # inner margin between a bubble's edge and its text
MAX_BUBBLE = 270  # widest a bubble may get, padding included

# RGB colours.
GRAY = (229, 229, 234)
BLUE = (10, 132, 255)
INK = (20, 20, 22)
WHITE = (255, 255, 255)

# "Name: text" line. Group 1 is the sender (starts with a letter, at most 41
# characters), group 2 is the message text after the colon and whitespace.
SENDER_RE = re.compile(r"^([A-Za-z][\w .'’()-]{0,40}?):\s+(.*)$")

# Tried in order by _font(); the first one that loads wins.
_FONT_CANDIDATES = [
    "segoeui.ttf",
    "C:/Windows/Fonts/segoeui.ttf",
    "DejaVuSans.ttf",
    "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
    "Arial.ttf",
]
def _font(size: int) -> ImageFont.FreeTypeFont:
    """Load the first available candidate font at ``size``.

    Each entry of ``_FONT_CANDIDATES`` is tried in order; entries that cannot
    be opened (``OSError``) are skipped. If none loads, Pillow's built-in
    default font is returned instead.

    The fallback honours ``size`` where Pillow supports it, so body and label
    text keep distinct sizes even without any candidate font installed. On
    older Pillow releases ``load_default`` takes no argument and returns a
    fixed-size font, which may not be a ``FreeTypeFont``.
    """
    for name in _FONT_CANDIDATES:
        try:
            return ImageFont.truetype(name, size)
        except OSError:
            continue
    try:
        # Pillow >= 10.1 accepts a size for the bundled default font.
        return ImageFont.load_default(size)
    except TypeError:
        # Older Pillow: load_default() has no size parameter.
        return ImageFont.load_default()
# Fonts are loaded once at import time and shared by every render() call.
BODY = _font(15)  # message / notice text inside bubbles
LABEL = _font(11)  # sender name drawn above a bubble
def _wrap(text: str, font: ImageFont.FreeTypeFont, width: int) -> list[str]:
    """Greedily word-wrap ``text`` so each line measures at most ``width``.

    Explicit newlines in ``text`` are preserved: every source line yields at
    least one output line, so blank lines stay blank. Within a source line,
    words are split on single spaces and packed left to right; runs of spaces
    collapse.

    A single word wider than ``width`` (a long URL, for example) is
    hard-broken at character boundaries so it cannot spill out of its bubble.
    Each piece holds at least one character, which guarantees progress even
    when ``width`` is smaller than one glyph.

    ``font`` only needs a ``getlength(str)`` method returning the rendered
    width of a string.
    """
    out: list[str] = []
    for raw in text.split("\n"):
        line = ""
        for word in raw.split(" "):
            cand = f"{line} {word}".strip()
            if font.getlength(cand) <= width:
                line = cand
                continue
            if line:
                # Current line is full: flush it and start a new one.
                out.append(line)
                line = word
            else:
                line = cand
            # The word alone may still be too wide; slice off the longest
            # prefix that fits until the remainder does.
            while len(line) > 1 and font.getlength(line) > width:
                cut = 1
                while cut < len(line) and font.getlength(line[: cut + 1]) <= width:
                    cut += 1
                out.append(line[:cut])
                line = line[cut:]
        out.append(line)
    return out
def _messages(thread: str) -> list[tuple[str | None, str]]:
    """Split ``thread`` into ``[(sender_or_None, text)]`` bubbles.

    A thread counts as chat-style when its "Name: text" lines (matched by
    ``SENDER_RE``) include the sender "Me" or any sender more than once.
    Anything else is returned as a single ``(None, thread)`` notice card with
    the text untouched.

    For chat-style threads each matching line starts a new bubble. A
    non-blank line that does not match is appended, stripped, to the previous
    bubble on a new line. If there is no previous bubble yet (text before the
    first sender line), it starts a sender-less bubble so that content is
    still rendered rather than dropped. Blank lines are skipped.
    """
    lines = thread.splitlines()
    parsed = [SENDER_RE.match(ln) for ln in lines]
    senders = [m.group(1) for m in parsed if m]
    # Some sender repeats  <=>  the list has more entries than distinct names.
    is_chat = "Me" in senders or len(set(senders)) < len(senders)
    if not is_chat:
        return [(None, thread)]

    msgs: list[tuple[str | None, str]] = []
    for ln, m in zip(lines, parsed):
        if m:
            msgs.append((m.group(1), m.group(2)))
            continue
        text = ln.strip()
        if not text:
            continue
        if msgs:
            # Continuation of the previous bubble.
            prev_sender, prev_text = msgs[-1]
            msgs[-1] = (prev_sender, prev_text + "\n" + text)
        else:
            # Preamble before the first sender line: keep it as its own
            # sender-less bubble.
            msgs.append((None, text))
    return msgs
def render(thread: str, out_path: Path) -> None:
    """Draw ``thread`` as a column of chat bubbles and save it to ``out_path``.

    The thread is split with ``_messages`` and each message is wrapped with
    ``_wrap`` to fit inside ``MAX_BUBBLE``. Bubbles from "Me" are blue with
    white text and right-aligned. All others are gray with dark text and
    left-aligned, with the sender's name above the bubble when there is one.
    The canvas is ``W`` wide and as tall as the stacked bubbles plus padding,
    but never shorter than 80. The parent directory of ``out_path`` is not
    created here.
    """
    line_h = BODY.getbbox("Ag")[3] + 4
    label_h = LABEL.getbbox("Ag")[3] + 2
    text_w = MAX_BUBBLE - 2 * BUBBLE_PAD

    bubbles = []
    for who, body in _messages(thread):
        bubbles.append((who, _wrap(body, BODY, text_w)))

    # First pass: total canvas height.
    canvas_h = PAD
    for who, rows in bubbles:
        if who is not None and who != "Me":
            canvas_h += label_h
        canvas_h += len(rows) * line_h + 2 * BUBBLE_PAD + 8
    canvas_h += PAD

    img = Image.new("RGB", (W, max(canvas_h, 80)), WHITE)
    draw = ImageDraw.Draw(img)

    # Second pass: draw each bubble top to bottom.
    top = PAD
    for who, rows in bubbles:
        is_mine = who == "Me"
        widest = max((BODY.getlength(row) for row in rows), default=0)
        bubble_w = min(MAX_BUBBLE, widest + 2 * BUBBLE_PAD)
        bubble_h = len(rows) * line_h + 2 * BUBBLE_PAD
        if is_mine:
            left = W - PAD - bubble_w
            bubble_fill, text_fill = BLUE, WHITE
        else:
            left = PAD
            bubble_fill, text_fill = GRAY, INK

        if who is not None and not is_mine:
            draw.text((left + 4, top), who, font=LABEL, fill=(120, 120, 128))
            top += label_h

        draw.rounded_rectangle(
            [left, top, left + bubble_w, top + bubble_h],
            radius=14,
            fill=bubble_fill,
        )
        text_y = top + BUBBLE_PAD
        for row in rows:
            draw.text((left + BUBBLE_PAD, text_y), row, font=BODY, fill=text_fill)
            text_y += line_h
        top += bubble_h + 8

    img.save(out_path)
def main() -> None:
    """Render every record of each eval jsonl file to a PNG.

    Input files come from the command-line arguments, or ``DEFAULT_FILES``
    when none are given. For each file, ``OUT_ROOT/<file stem>/`` is created
    if needed and every non-blank line is parsed as JSON. The record's
    ``"thread"`` is rendered to ``<id>.png`` in that directory, and a
    one-line summary is printed per file.
    """
    cli_paths = sys.argv[1:]
    sources = [Path(arg) for arg in cli_paths] if cli_paths else DEFAULT_FILES
    for src in sources:
        dest = OUT_ROOT / src.stem
        dest.mkdir(parents=True, exist_ok=True)
        rendered = 0
        for raw in src.read_text(encoding="utf-8").splitlines():
            if raw.strip():
                record = json.loads(raw)
                render(record["thread"], dest / f"{record['id']}.png")
                rendered += 1
        print(f"{src.name}: rendered {rendered} screenshot(s) -> {dest}")


if __name__ == "__main__":
    main()