File size: 3,867 Bytes
cf6c0e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""Generate assets/reward_grid.png and assets/reward_grid.jsonl.

Layout: 3 rows (easy/medium/hard) × 4 columns (reference + 3 randomly sampled
variants sorted descending by reward score). No column headers.
"""
import json
import random
import pathlib
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
from PIL import Image

# Two directory levels above this file; every data/ and assets/ path in this
# script is resolved against it.
ROOT = pathlib.Path(__file__).parent.parent

# One grid row per task: display label plus the task id under data/tests/.
TASKS = [
    {"difficulty": label, "task_id": task_id}
    for label, task_id in (("Easy", 0), ("Medium", 5), ("Hard", 10))
]

# Variant names that may be sampled for the columns beside the reference.
POOL = [
    "perfect",
    "minor_diff",
    "bad_colors",
    "half_styled",
    "no_layout",
    "no_style",
]

# Dark-theme palette: figure background, tick grey, annotation text.
# NOTE(review): TICK is not referenced anywhere in this file's visible code.
BG, TICK, ANNO = "#0d1117", "#8b949e", "#e6edf3"

def score_color(score):
    """Map a reward score to an RGB colour tuple for the score badge.

    Piecewise mapping with a knee at 0.7; blue is fixed at 60/255:

    * ``score < 0.7``: red fixed at 220/255, green rises linearly from 0
      toward 160/255 (red through orange to amber).
    * ``score >= 0.7``: green fixed at 180/255, red falls linearly from
      153/255 at 0.7 to 0 at 1.0 (yellow-green to green).

    The mapping is not continuous at 0.7; it jumps from amber to yellow-green.

    Scores outside [0, 1] are clamped first. Without the clamp, the red
    channel (score > 1) or the green channel (score < 0) would go negative,
    and matplotlib rejects RGB values outside the 0-1 range.

    Args:
        score: Reward value, nominally in [0, 1].

    Returns:
        ``(r, g, b)`` floats in [0, 1], usable directly as a matplotlib colour.
    """
    score = min(max(score, 0.0), 1.0)
    if score >= 0.7:
        r, g = int((1 - score) * 2 * 255), 180
    else:
        r, g = 220, int(score / 0.7 * 160)
    return (r / 255, g / 255, 60 / 255)


def _sample_records(rng, n_variants):
    """Build one record per TASKS entry, each with *n_variants* scored variants.

    For each task, reads ``ROOT/data/tests/<task_id>/expected_scores.json``,
    draws *n_variants* distinct names from POOL with *rng*, and attaches each
    name's score rounded to 4 decimals. Variants are sorted by score, highest
    first.

    Raises:
        FileNotFoundError: if a task's ``expected_scores.json`` is missing.
        KeyError: if a sampled variant has no entry in that file.
        ValueError: if *n_variants* is negative or exceeds ``len(POOL)``
            (raised by ``random.Random.sample``).
    """
    records = []
    for task in TASKS:
        tid = task["task_id"]
        scores_path = ROOT / f"data/tests/{tid}/expected_scores.json"
        scores = json.loads(scores_path.read_text(encoding="utf-8"))
        # Exactly one rng.sample call per task, in TASKS order, so a given
        # seed always reproduces the same selection.
        sampled = rng.sample(POOL, n_variants)
        variants = sorted(
            ({"name": v, "score": round(scores[v], 4)} for v in sampled),
            key=lambda x: x["score"],
            reverse=True,  # descending: best first
        )
        records.append({
            "difficulty": task["difficulty"],
            "task_id": tid,
            "variants": variants,
        })
    return records


def _write_jsonl(records, path):
    """Write *records* to *path* as JSON Lines, creating the parent directory."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(
        "\n".join(json.dumps(r) for r in records) + "\n", encoding="utf-8"
    )
    print(f"Wrote {path}")


def _show_image(ax, png):
    """Draw the image file *png* on *ax* as RGB and hide the axis decorations."""
    # ``with`` closes the underlying file deterministically instead of
    # leaving it to the garbage collector.
    with Image.open(png) as img:
        ax.imshow(np.array(img.convert("RGB")))
    ax.set_facecolor(BG)
    ax.axis("off")


def _plot_grid(records, path):
    """Render the reference/variant grid for *records* and save it to *path*.

    Each record becomes one row. Column 0 shows ``renders/reference.png`` with
    the difficulty label to its left. It is followed by one column per variant
    (``renders/<name>.png``), each with a colour-coded reward caption
    underneath. The figure is always closed, even if rendering or saving fails.
    """
    n_rows = len(records)
    n_cols = 1 + max(len(r["variants"]) for r in records)
    # squeeze=False keeps ``axes`` 2-D even for a single row or column.
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(3.5 * n_cols, 3 * n_rows),  # (14, 9) for the default 3x4 grid
        facecolor=BG,
        squeeze=False,
    )
    try:
        fig.subplots_adjust(hspace=0.08, wspace=0.04)
        for row_idx, record in enumerate(records):
            renders_dir = ROOT / f"data/tests/{record['task_id']}/renders"

            # Column 0 — reference render. The row label sits to its left,
            # vertically centred and rotated 90 degrees.
            ax = axes[row_idx, 0]
            _show_image(ax, renders_dir / "reference.png")
            ax.text(-0.18, 0.5, record["difficulty"], transform=ax.transAxes,
                    color=ANNO, fontsize=13, fontweight="bold",
                    va="center", ha="center", rotation=90)

            # Remaining columns — sampled variants, already sorted best first.
            for col_idx, variant in enumerate(record["variants"], start=1):
                ax = axes[row_idx, col_idx]
                _show_image(ax, renders_dir / f"{variant['name']}.png")
                # Score badge centred just below the image.
                ax.text(0.5, -0.04, f"reward: {variant['score']:.2f}",
                        transform=ax.transAxes,
                        color=score_color(variant["score"]),
                        fontsize=10, fontweight="bold",
                        va="top", ha="center")

            # Blank out any cells left over when this row has fewer variants.
            for ax in axes[row_idx, 1 + len(record["variants"]):]:
                ax.axis("off")

        fig.suptitle("Reward Function Evaluation — Reference vs Sampled Quality Levels",
                     color=ANNO, fontsize=12, y=0.998)

        path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(path, dpi=130, bbox_inches="tight",
                    facecolor=BG, edgecolor="none")
    finally:
        # pyplot keeps every open figure alive until it is closed, which
        # leaks memory when run() is called more than once.
        plt.close(fig)
    print(f"Saved {path}")


def run(seed: int = 42, n_variants: int = 3):
    """Regenerate ``assets/reward_grid.jsonl`` and ``assets/reward_grid.png``.

    Args:
        seed: Seed for the ``random.Random`` that chooses which POOL variants
            are shown for each task. The same seed reproduces the same grid.
        n_variants: Number of distinct variants sampled per task, i.e. the
            number of columns to the right of the reference. It must not
            exceed ``len(POOL)``.

    Raises:
        FileNotFoundError: if a score file or render image is missing.
        KeyError: if a sampled variant has no score in its task's file.
        ValueError: if *n_variants* is out of range for ``rng.sample``.
    """
    rng = random.Random(seed)
    records = _sample_records(rng, n_variants)
    _write_jsonl(records, ROOT / "assets/reward_grid.jsonl")
    _plot_grid(records, ROOT / "assets/reward_grid.png")


if __name__ == "__main__":
    # Script entry point: regenerate both assets. run() returns None, so the
    # process exits with status 0 on success.
    raise SystemExit(run())