File size: 3,786 Bytes
44a3896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
022f82c
 
 
44a3896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from PIL import Image

from pipeline.embedding_planner import SIMILARITY_THRESHOLD
from pipeline.planner import plan_from_instruction, warmup
from pipeline.spatial import apply_spatial_ops

# Invoke the planner's warmup hook once before any test case runs
# (presumably pre-loads models/embeddings — confirm in pipeline.planner).
warmup()

# Planning test table: (instruction, expected_ids) pairs consumed by
# run_planning_tests(). How expected_ids is interpreted:
#   None      -> the plan must be reported as unsupported.
#   []        -> the plan must be supported and contain no spatial ops.
#   [ids...]  -> every listed id must appear among the plan's score ids,
#                its tags, or its spatial-op ids (as produced by spatial_ids()).
planning_cases = [
    ("rotate", ["rotate"]),
    ("make it brighter", ["brighten"]),
    ("vintage warm look", ["sepia", "warmer"]),
    ("replace background with beach sunset", ["replace_background"]),
    ("put a table behind it", None),
    ("blur everything softly", ["blur"]),
    ("cartoon style", ["style_candy"]),
    ("mosaic painting", ["style_mosaic"]),
    ("impressionist look", ["style_rain_princess"]),
    ("asdf qwerty nonsense", None),
    ("more contrast", ["contrast_up"]),
    ("transparent cutout", ["add_cutout_transparent"]),
    ("cover a piece but not text", ["cover_avoid_text"]),
    ("blur everything softly", []),  # intentional repeat of the blur case: must not trigger spatial ops
]


def spatial_ids(plan):
    """Flatten a plan's spatial ops into comparable id strings.

    Every entry of ``plan["spatial"]`` contributes its ``"op"`` value, except
    ``add_cutout`` entries, which are qualified by their ``cutout_style``
    (treated as ``"solid"`` when absent):
    ``add_cutout_transparent`` / ``add_cutout_solid`` for those two styles,
    plain ``add_cutout`` for any other style. A plan without a ``"spatial"``
    key yields an empty list.
    """

    def _label(step):
        name = step["op"]
        if name != "add_cutout":
            return name
        style = step.get("cutout_style", "solid")
        # Only the two known styles get a qualified id; anything else stays generic.
        return f"add_cutout_{style}" if style in ("transparent", "solid") else "add_cutout"

    return [_label(step) for step in plan.get("spatial", [])]


def run_planning_tests():
    """Check every ``planning_cases`` entry against the instruction planner.

    Each case is an ``(instruction, expected_ids)`` pair, judged as follows:

    * ``expected_ids is None`` -- the plan's ``supported`` flag must be falsy.
    * ``expected_ids == []``   -- the plan must be supported and carry no
      ``"spatial"`` ops.
    * otherwise                -- every expected id must appear among the
      plan's score ids (first element of each ``scores`` entry), its tags,
      or its spatial-op ids as produced by ``spatial_ids``.

    A short report is printed for every case.

    Returns:
        The number of failing cases.
    """
    print(f"Threshold: {SIMILARITY_THRESHOLD}\n")
    failed = 0
    for instr, expected_ids in planning_cases:
        plan = plan_from_instruction(instr)
        got_ids = [s[0] for s in plan.get("scores", [])]
        spatial = spatial_ids(plan)
        all_matched = got_ids + plan["tags"] + spatial

        # The verdict must be computed inside the loop so that every case is
        # judged and reported, not only the last one.
        if expected_ids is None:
            ok = not plan["supported"]
        elif expected_ids == []:
            ok = plan["supported"] and not plan.get("spatial")
        else:
            ok = all(eid in all_matched for eid in expected_ids)

        if ok:
            status = "OK"
        else:
            status = "FAIL"
            failed += 1
        print(f"{status}: {instr!r}")
        print(f"  supported={plan['supported']} scores={plan.get('scores', [])}")
        print(f"  tags={plan['tags']} spatial={spatial}")
        if not ok:
            print(f"  expected: {expected_ids}")
        print()

    return failed


def run_cutout_execution_tests():
    """Plan and execute cutout instructions, checking how many cutouts result.

    Each case is ``(instruction, min_cutouts, max_cutouts)``. The instruction
    is planned with ``plan_from_instruction``. Its ``"spatial"`` ops are then
    applied to a fresh copy of a 400x300 solid-colour RGB canvas via
    ``apply_spatial_ops``, with ``1.0`` and the instruction as the remaining
    arguments; their meaning is defined by ``apply_spatial_ops``.

    Cutouts are counted from the returned metadata list. An entry with a
    non-empty ``"cutouts"`` list contributes its length. Otherwise, an
    ``add_cutout``/``cutout`` entry contributes its ``"cutout_count"``.

    Returns:
        The number of cases whose count fell outside ``[min, max]``.
    """
    print("Cutout execution tests:\n")
    canvas = Image.new("RGB", (400, 300), color=(100, 150, 200))

    cases = (
        ("add cutout", 1, 4),
        ("cover and add cutout", 1, 4),
        ("cover and several cutouts avoid text", 1, 4),
        ("add cutout avoid text", 1, 4),
        ("3 cutouts", 3, 3),
        ("one cutout", 1, 1),
    )

    def _count_cutouts(applied_meta):
        total = 0
        for meta in applied_meta:
            listed = meta.get("cutouts")
            if listed:
                total += len(listed)
            elif meta.get("op") in ("add_cutout", "cutout") and meta.get("cutout_count"):
                total += meta["cutout_count"]
        return total

    failures = 0
    for instr, low, high in cases:
        plan = plan_from_instruction(instr)
        _, applied = apply_spatial_ops(canvas.copy(), plan["spatial"], 1.0, instr)
        count = _count_cutouts(applied)

        passed = low <= count <= high
        if not passed:
            failures += 1
        print(f"{'OK' if passed else 'FAIL'}: {instr!r} -> {count} cutout(s) (expected {low}-{high})")
        for meta in applied:
            if meta.get("cutouts"):
                print(f"  shapes: {[c.get('shape') for c in meta['cutouts']]}")
        print()

    return failures


# Run the planning suite first, then the cutout-execution suite, and total
# their failure counts.
failed = run_planning_tests() + run_cutout_execution_tests()

if not failed:
    print("All validation tests passed")
else:
    # Exit with a message (non-zero status) so CI notices the failures.
    raise SystemExit(f"{failed} test(s) failed")