File size: 3,786 Bytes
44a3896 022f82c 44a3896 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 | import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from PIL import Image
from pipeline.embedding_planner import SIMILARITY_THRESHOLD
from pipeline.planner import plan_from_instruction, warmup
from pipeline.spatial import apply_spatial_ops
# Call the planner's warmup hook once, at module import/run time, before any
# of the planning or cutout cases below are executed.
# NOTE(review): presumably this preloads models/embeddings so the first case
# does not absorb start-up cost — confirm in pipeline.planner.warmup.
warmup()
# Planning expectations, consumed by run_planning_tests(). Each entry is
# (instruction, expectation), where the expectation is one of:
#   None      -> the planner must report the instruction as unsupported;
#   []        -> the plan must be supported AND contain no spatial ops;
#   [ids...]  -> every id must appear among the plan's score ids, tags, or
#                spatial ids (order and extra matches do not matter).
planning_cases = [
    # Simple adjustments / geometry.
    ("rotate", ["rotate"]),
    ("make it brighter", ["brighten"]),
    ("vintage warm look", ["sepia", "warmer"]),
    # Background replacement, plus a request expected to be rejected.
    ("replace background with beach sunset", ["replace_background"]),
    ("put a table behind it", None),
    ("blur everything softly", ["blur"]),
    # Artistic style presets (style_* ids).
    ("cartoon style", ["style_candy"]),
    ("mosaic painting", ["style_mosaic"]),
    ("impressionist look", ["style_rain_princess"]),
    # Gibberish must not be matched to anything.
    ("asdf qwerty nonsense", None),
    ("more contrast", ["contrast_up"]),
    # Spatial operations (ids as produced by spatial_ids()).
    ("transparent cutout", ["add_cutout_transparent"]),
    ("cover a piece but not text", ["cover_avoid_text"]),
    # Same instruction as an earlier case, re-checked for a different
    # property: it must be supported but must NOT trigger spatial ops.
    ("blur everything softly", []),
]
def spatial_ids(plan):
    """Return the comparable ids of the spatial operations in *plan*.

    Each entry of ``plan["spatial"]`` must be a dict with an ``"op"`` key.
    Ops are reported by their ``op`` name, except ``add_cutout``, which is
    suffixed with its ``cutout_style`` (default ``"solid"``) so callers can
    tell ``add_cutout_solid`` from ``add_cutout_transparent``. Any other
    cutout style falls back to the bare ``"add_cutout"`` id.

    A missing, ``None`` or empty ``"spatial"`` entry yields ``[]``.
    """
    ids = []
    # `or []` also covers an explicit None value, which `.get(key, [])`
    # would pass through and then fail to iterate.
    for step in plan.get("spatial") or []:
        op = step["op"]
        if op != "add_cutout":
            ids.append(op)
            continue
        style = step.get("cutout_style", "solid")
        known_style = style in ("transparent", "solid")
        ids.append(f"add_cutout_{style}" if known_style else "add_cutout")
    return ids
def _planning_case_ok(plan, expected, matched):
    """Decide whether *plan* satisfies the expectation *expected*.

    ``None`` requires an unsupported plan; ``[]`` requires a supported plan
    with no (truthy) ``"spatial"`` entry; any other list requires every id
    in it to be present in *matched*.
    """
    if expected is None:
        return not plan["supported"]
    if expected == []:
        return plan["supported"] and not plan.get("spatial")
    return all(eid in matched for eid in expected)


def run_planning_tests():
    """Run every entry of ``planning_cases`` through the planner.

    For each case, prints an OK/FAIL line followed by the plan's supported
    flag, scores, tags and spatial ids (plus the expectation on failure),
    then a blank separator line.

    Returns:
        The number of failing cases.
    """
    print(f"Threshold: {SIMILARITY_THRESHOLD}\n")
    failures = 0
    for instruction, expected in planning_cases:
        plan = plan_from_instruction(instruction)
        # Ids the expectation may be satisfied by: the first element of each
        # score entry, the plan's tags, and the normalised spatial op ids.
        score_ids = [entry[0] for entry in plan.get("scores", [])]
        matched = score_ids + plan["tags"] + spatial_ids(plan)
        ok = _planning_case_ok(plan, expected, matched)

        label = "OK" if ok else "FAIL"
        print(f"{label}: {instruction!r}")
        print(f" supported={plan['supported']} scores={plan.get('scores', [])}")
        print(f" tags={plan['tags']} spatial={spatial_ids(plan)}")
        if not ok:
            failures += 1
            print(f" expected: {expected}")
        print()
    return failures
def _count_cutouts(applied):
    """Return the total number of cutouts recorded in *applied*.

    *applied* is the metadata sequence returned by ``apply_spatial_ops``.
    An entry with a non-empty ``"cutouts"`` list contributes its length;
    otherwise an ``add_cutout``/``cutout`` entry with a truthy
    ``"cutout_count"`` contributes that value. Other entries add nothing.
    """
    total = 0
    for meta in applied:
        listed = meta.get("cutouts")
        if listed:
            total += len(listed)
        elif meta.get("op") in ("add_cutout", "cutout") and meta.get("cutout_count"):
            total += meta["cutout_count"]
    return total


def run_cutout_execution_tests():
    """Plan and execute cutout instructions on a blank image, checking counts.

    Each case is ``(instruction, min_cutouts, max_cutouts)``; it passes when
    the number of cutouts reported by ``apply_spatial_ops`` lies within the
    inclusive range. Prints one result line per case, any cutout shapes,
    and a blank separator line.

    Returns:
        The number of failing cases.
    """
    print("Cutout execution tests:\n")
    # Solid-colour 400x300 RGB canvas; every case works on its own copy.
    canvas = Image.new("RGB", (400, 300), color=(100, 150, 200))
    cases = (
        ("add cutout", 1, 4),
        ("cover and add cutout", 1, 4),
        ("cover and several cutouts avoid text", 1, 4),
        ("add cutout avoid text", 1, 4),
        ("3 cutouts", 3, 3),
        ("one cutout", 1, 1),
    )
    failures = 0
    for instruction, low, high in cases:
        plan = plan_from_instruction(instruction)
        # NOTE(review): the meaning of the literal 1.0 is defined by
        # apply_spatial_ops; it is passed through unchanged here.
        _, applied = apply_spatial_ops(canvas.copy(), plan["spatial"], 1.0, instruction)
        count = _count_cutouts(applied)

        ok = low <= count <= high
        if not ok:
            failures += 1
        label = "OK" if ok else "FAIL"
        print(f"{label}: {instruction!r} -> {count} cutout(s) (expected {low}-{high})")
        for cutouts in (meta["cutouts"] for meta in applied if meta.get("cutouts")):
            shape_list = [item.get("shape") for item in cutouts]
            print(f" shapes: {shape_list}")
        print()
    return failures
# Run the planning suite, then the cutout-execution suite (left-to-right
# evaluation keeps that order), and convert any failures into SystemExit.
failed = run_planning_tests() + run_cutout_execution_tests()
if not failed:
    print("All validation tests passed")
else:
    raise SystemExit(f"{failed} test(s) failed")
|