| import sys |
| from pathlib import Path |
|
|
# Put the directory two levels above this script's resolved location at the
# front of the import path so the local ``pipeline`` package imported below
# can be found when the script is run directly.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
| from PIL import Image |
|
|
| from pipeline.embedding_planner import SIMILARITY_THRESHOLD |
| from pipeline.planner import plan_from_instruction, warmup |
| from pipeline.spatial import apply_spatial_ops |
|
|
# Call the planner's warmup hook once, at import time, before any instruction
# is planned below.
# NOTE(review): what warmup() actually prepares (models, embeddings, caches)
# lives in pipeline.planner and is not visible here — confirm.
warmup()
|
|
# Planner expectations as (instruction, expected) pairs, evaluated in order by
# run_planning_tests():
#   non-empty list -> every listed id must appear among the plan's score ids,
#                     tags, or spatial ids
#   None           -> the plan must NOT be supported
#   []             -> the plan must be supported and carry no spatial ops
planning_cases = [
    # Single global adjustments.
    ("rotate", ["rotate"]),
    ("make it brighter", ["brighten"]),
    # One phrase expanding into several adjustments.
    ("vintage warm look", ["sepia", "warmer"]),
    # Background replacement is expected to plan; adding a new object is not.
    ("replace background with beach sunset", ["replace_background"]),
    ("put a table behind it", None),
    ("blur everything softly", ["blur"]),
    # Style presets.
    ("cartoon style", ["style_candy"]),
    ("mosaic painting", ["style_mosaic"]),
    ("impressionist look", ["style_rain_princess"]),
    # Gibberish must be rejected.
    ("asdf qwerty nonsense", None),
    ("more contrast", ["contrast_up"]),
    # Cutout / cover operations.
    ("transparent cutout", ["add_cutout_transparent"]),
    ("cover a piece but not text", ["cover_avoid_text"]),
    # Repeats an earlier instruction as a second, independent check: the plan
    # must be supported and must not contain any spatial op.
    ("blur everything softly", []),
]
|
|
|
|
def spatial_ids(plan):
    """Return the expectation ids for every spatial step in *plan*, in order.

    Steps are read from ``plan["spatial"]`` (absent key -> no steps). A step's
    id is its ``"op"`` value, except that ``"add_cutout"`` steps are qualified
    by their ``"cutout_style"`` (default ``"solid"``): ``"transparent"`` and
    ``"solid"`` give ``"add_cutout_<style>"``, any other style gives plain
    ``"add_cutout"``.
    """

    def _label(step):
        # Non-cutout ops are reported verbatim.
        name = step["op"]
        if name != "add_cutout":
            return name
        flavour = step.get("cutout_style", "solid")
        if flavour in ("transparent", "solid"):
            return f"add_cutout_{flavour}"
        return "add_cutout"

    return [_label(step) for step in plan.get("spatial", [])]
|
|
|
|
def run_planning_tests():
    """Plan every entry of ``planning_cases`` and report whether it met expectations.

    Each ``(instruction, expected_ids)`` case is planned with
    ``plan_from_instruction`` and judged as follows:

    * ``None`` -- passes when the plan is not supported.
    * ``[]`` -- passes when the plan is supported and has no spatial ops.
    * non-empty list -- passes when every expected id appears among the plan's
      score ids (first element of each ``scores`` entry), tags, or spatial ids.

    A status line and the plan details are printed for every case; the
    expectation is printed too for failures.

    Returns:
        The number of failed cases.
    """
    print(f"Threshold: {SIMILARITY_THRESHOLD}\n")
    failed = 0
    for instr, expected_ids in planning_cases:
        plan = plan_from_instruction(instr)
        scores = plan.get("scores", [])
        # Read tags as tolerantly as scores/spatial so a plan without a "tags"
        # key is reported as a FAIL instead of aborting the whole run.
        tags = plan.get("tags", [])
        spatial = spatial_ids(plan)  # computed once, reused for the report
        all_matched = [entry[0] for entry in scores] + list(tags) + spatial

        # "supported" is deliberately read strictly: defaulting it would let a
        # malformed plan satisfy the "expected unsupported" (None) cases.
        if expected_ids is None:
            ok = not plan["supported"]
        elif not expected_ids:
            ok = plan["supported"] and not plan.get("spatial")
        else:
            # NOTE(review): positive cases do not also require
            # plan["supported"]; if "scores" can list below-threshold
            # candidates, a rejected instruction could still pass — confirm
            # against plan_from_instruction.
            ok = all(eid in all_matched for eid in expected_ids)

        status = "OK" if ok else "FAIL"
        if not ok:
            failed += 1
        print(f"{status}: {instr!r}")
        print(f"  supported={plan['supported']} scores={scores}")
        print(f"  tags={tags} spatial={spatial}")
        if not ok:
            print(f"  expected: {expected_ids}")
        print()

    return failed
|
|
|
|
def run_cutout_execution_tests():
    """Plan and execute cutout instructions, checking how many cutouts get applied.

    Each case is ``(instruction, min_cutouts, max_cutouts)``. The instruction
    is planned, its spatial ops are applied to a fresh copy of a solid 400x300
    RGB test image with ``apply_spatial_ops``, and the cutouts reported in the
    returned per-op metadata are counted. A case passes when that count lies
    in ``[min_cutouts, max_cutouts]`` (inclusive). A plan without spatial ops
    counts as zero cutouts rather than raising.

    Returns:
        The number of failed cases.
    """
    print("Cutout execution tests:\n")
    failed = 0
    img = Image.new("RGB", (400, 300), color=(100, 150, 200))

    execution_cases = [
        ("add cutout", 1, 4),
        ("cover and add cutout", 1, 4),
        ("cover and several cutouts avoid text", 1, 4),
        ("add cutout avoid text", 1, 4),
        ("3 cutouts", 3, 3),
        ("one cutout", 1, 1),
    ]

    for instr, min_cutouts, max_cutouts in execution_cases:
        plan = plan_from_instruction(instr)
        # The rest of this file treats "spatial" as optional (plan.get); a plan
        # without it used to raise KeyError here and skip every later case.
        # Treat it as "nothing applied" so the case is reported as a FAIL.
        spatial_ops = plan.get("spatial")
        if spatial_ops is None:
            applied = []
        else:
            _, applied = apply_spatial_ops(img.copy(), spatial_ops, 1.0, instr)

        cutout_count = 0
        for meta in applied:
            # Prefer the explicit list of cutouts; otherwise fall back to a
            # numeric "cutout_count" reported by a cutout op.
            if meta.get("cutouts"):
                cutout_count += len(meta["cutouts"])
            elif meta.get("op") in ("add_cutout", "cutout") and meta.get("cutout_count"):
                cutout_count += meta["cutout_count"]

        ok = min_cutouts <= cutout_count <= max_cutouts
        status = "OK" if ok else "FAIL"
        if not ok:
            failed += 1
        print(f"{status}: {instr!r} -> {cutout_count} cutout(s) (expected {min_cutouts}-{max_cutouts})")
        for meta in applied:
            if meta.get("cutouts"):
                shapes = [c.get("shape") for c in meta["cutouts"]]
                print(f"  shapes: {shapes}")
        print()

    return failed
|
|
|
|
# Run both suites (planning first, then execution) and fail the process with a
# summary message if any case failed.
failed = run_planning_tests() + run_cutout_execution_tests()

if not failed:
    print("All validation tests passed")
else:
    raise SystemExit(f"{failed} test(s) failed")
|
|