# text-driven-image-augmentation/scripts/validate_planner.py (3.79 kB)
# Latest commit 022f82c by DageBjorne:
#   "Add ONNX neural style transfer with embedding-matched keywords"
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from PIL import Image
from pipeline.embedding_planner import SIMILARITY_THRESHOLD
from pipeline.planner import plan_from_instruction, warmup
from pipeline.spatial import apply_spatial_ops
# Run the planner's warmup hook once before any case is evaluated.
# NOTE(review): presumably this preloads models/embeddings -- confirm in pipeline.planner.
warmup()
# Table of (instruction, expected_ids) pairs consumed by run_planning_tests().
# The second element is interpreted as follows:
#   None      -> the plan must be reported as unsupported.
#   []        -> the plan must be supported and contain no spatial ops.
#   [ids...]  -> every listed id must appear among the plan's score ids,
#                its tags, or the ids derived by spatial_ids().
planning_cases = [
("rotate", ["rotate"]),
("make it brighter", ["brighten"]),
("vintage warm look", ["sepia", "warmer"]),
("replace background with beach sunset", ["replace_background"]),
("put a table behind it", None),
("blur everything softly", ["blur"]),
("cartoon style", ["style_candy"]),
("mosaic painting", ["style_mosaic"]),
("impressionist look", ["style_rain_princess"]),
("asdf qwerty nonsense", None),
("more contrast", ["contrast_up"]),
("transparent cutout", ["add_cutout_transparent"]),
("cover a piece but not text", ["cover_avoid_text"]),
# Same instruction as the "blur" case above, checked a second time with the
# empty-list contract (supported, but no spatial ops).
("blur everything softly", []), # must not trigger spatial ops
]
def spatial_ids(plan):
    """Return one identifier string per entry in ``plan["spatial"]``.

    Entries whose ``"op"`` is not ``"add_cutout"`` contribute the op name
    unchanged. ``"add_cutout"`` entries are suffixed with their
    ``"cutout_style"`` (defaulting to ``"solid"``) when that style is
    ``"transparent"`` or ``"solid"``; any other style yields plain
    ``"add_cutout"``. A plan without a ``"spatial"`` key yields ``[]``.
    """
    labels = []
    for entry in plan.get("spatial", []):
        name = entry["op"]
        if name != "add_cutout":
            # Non-cutout ops are identified by their op name alone.
            labels.append(name)
            continue
        variant = entry.get("cutout_style", "solid")
        recognised = variant in ("transparent", "solid")
        labels.append(f"add_cutout_{variant}" if recognised else "add_cutout")
    return labels
def run_planning_tests():
    """Plan every instruction in ``planning_cases`` and report mismatches.

    For each case the plan's score ids, tags and spatial ids are pooled and
    compared against the expectation:

    * ``None``  -- passes when the plan is not supported.
    * ``[]``    -- passes when the plan is supported and has no spatial ops.
    * otherwise -- passes when every expected id is in the pooled ids.

    A per-case report is printed to stdout.

    Returns:
        The number of cases that did not pass.
    """
    print(f"Threshold: {SIMILARITY_THRESHOLD}\n")
    failures = 0
    for instruction, expected in planning_cases:
        plan = plan_from_instruction(instruction)
        scores = plan.get("scores", [])
        tags = plan["tags"]
        # Each score entry's first element is the matched id.
        score_ids = [entry[0] for entry in scores]
        partial = score_ids + tags
        spatial = spatial_ids(plan)
        pool = partial + spatial

        if expected is None:
            passed = not plan["supported"]
        elif expected == []:
            passed = plan["supported"] and not plan.get("spatial")
        else:
            passed = not any(eid not in pool for eid in expected)

        if passed:
            status = "OK"
        else:
            status = "FAIL"
            failures += 1

        print(f"{status}: {instruction!r}")
        print(f" supported={plan['supported']} scores={scores}")
        print(f" tags={tags} spatial={spatial}")
        if not passed:
            print(f" expected: {expected}")
        print()
    return failures
def run_cutout_execution_tests():
    """Execute the spatial ops for a set of cutout instructions and count cutouts.

    Each case is ``(instruction, min_cutouts, max_cutouts)``. The instruction
    is planned, its spatial ops are applied to a copy of a 400x300 solid RGB
    image, and the cutouts reported in the returned metadata are summed. A
    case passes when the sum lies within the inclusive range. Results (and
    the shapes of any listed cutouts) are printed to stdout.

    Returns:
        The number of cases whose cutout count fell outside its range.
    """
    print("Cutout execution tests:\n")
    canvas = Image.new("RGB", (400, 300), color=(100, 150, 200))
    execution_cases = [
        ("add cutout", 1, 4),
        ("cover and add cutout", 1, 4),
        ("cover and several cutouts avoid text", 1, 4),
        ("add cutout avoid text", 1, 4),
        ("3 cutouts", 3, 3),
        ("one cutout", 1, 1),
    ]
    failures = 0
    for instruction, lower, upper in execution_cases:
        plan = plan_from_instruction(instruction)
        _, applied = apply_spatial_ops(canvas.copy(), plan["spatial"], 1.0, instruction)

        total = 0
        for meta in applied:
            if meta.get("cutouts"):
                # An explicit cutout list takes precedence over a bare count.
                total += len(meta["cutouts"])
                continue
            is_cutout_op = meta.get("op") in ("add_cutout", "cutout")
            if is_cutout_op and meta.get("cutout_count"):
                total += meta["cutout_count"]

        passed = lower <= total <= upper
        if passed:
            status = "OK"
        else:
            status = "FAIL"
            failures += 1
        print(f"{status}: {instruction!r} -> {total} cutout(s) (expected {lower}-{upper})")

        for meta in applied:
            if not meta.get("cutouts"):
                continue
            shapes = [cutout.get("shape") for cutout in meta["cutouts"]]
            print(f" shapes: {shapes}")
        print()
    return failures
# Run both suites (planning first, then cutout execution) and exit non-zero
# with a summary message if any case failed.
failed = run_planning_tests() + run_cutout_execution_tests()
if failed:
    raise SystemExit(f"{failed} test(s) failed")
print("All validation tests passed")