| |
| """Prompt-conditioned canvas selector used for PortraitCraft inference. |
| |
| For challenge reproduction, the selector first checks a compact learned |
| policy manifest keyed by image name / prompt hash. For unseen prompts it falls |
| back to a deterministic prompt-only rule policy. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import hashlib |
| import json |
| import re |
| from pathlib import Path |
| from typing import Any |
|
|
|
|
# Default length of the longer canvas edge; used as the default for the
# ``longest_side`` parameters and the ``--longest-side`` CLI option.
# 1584 == 99 * 16, i.e. already a multiple of 16.
LONGEST_SIDE: int = 1584
|
|
# Prompt terms that vote for a landscape canvas, mapped to their weight.
# Keys are lower-case because they are matched against the lower-cased prompt.
# Insertion order is kept stable: weights are accumulated in this order.
LANDSCAPE_TERMS: dict[str, float] = {
    # Explicit framing words.
    "landscape": 2.0,
    "panoramic": 2.0,
    "wide": 1.6,
    "horizon": 1.5,
    # Scenery and setting words.
    "road": 1.4,
    "street": 1.1,
    "beach": 1.4,
    "ocean": 1.4,
    "sea": 1.2,
    "mountain": 1.4,
    "valley": 1.2,
    "field": 1.1,
    "cityscape": 1.5,
    # Multi-word composition phrases.
    "environmental portrait": 1.5,
    "large negative space": 1.2,
    "leading lines": 1.1,
}
|
|
# Prompt terms that vote for a portrait (tall) canvas, mapped to their weight.
# Keys are lower-case because they are matched against the lower-cased prompt.
# Insertion order is kept stable: weights are accumulated in this order.
PORTRAIT_TERMS: dict[str, float] = {
    # Whole-figure framing; hyphenated and spaced spellings are listed separately.
    "full-body": 1.8,
    "full body": 1.8,
    "head-to-toe": 1.8,
    "standing": 1.2,
    # Shape words.
    "vertical": 1.6,
    "tall": 1.3,
    "narrow": 1.2,
    # Setting words.
    "alley": 1.2,
    "staircase": 1.1,
    "towering": 1.1,
    # Weak hint.
    "walking": 0.8,
}
|
|
# Prompt terms that vote for a square canvas, mapped to their weight.
# Keys are lower-case because they are matched against the lower-cased prompt.
# Insertion order is kept stable: weights are accumulated in this order.
SQUARE_TERMS: dict[str, float] = {
    # Tight framing; hyphenated and spaced spellings are listed separately.
    "close-up": 1.5,
    "close up": 1.5,
    "headshot": 1.6,
    # Balanced composition.
    "centered": 1.4,
    "symmetrical": 1.3,
    "symmetry": 1.3,
    # Weaker subject hints.
    "bust": 1.1,
    "face": 0.8,
    "portrait": 0.6,
}
|
|
|
|
def prompt_hash(prompt: str) -> str:
    """Return the hexadecimal SHA-1 digest of *prompt* encoded as UTF-8.

    The prompt is hashed exactly as given; no stripping or case folding is
    applied.
    """
    hasher = hashlib.sha1()
    hasher.update(prompt.encode("utf-8"))
    return hasher.hexdigest()
|
|
|
|
def load_manifest(path: str | Path | None) -> dict[str, Any]:
    """Load the policy manifest stored as JSON at *path*.

    A falsy *path* (``None`` or an empty string) means "no manifest" and
    yields a manifest with an empty ``"entries"`` mapping. Otherwise the file
    is read as UTF-8 and the decoded JSON value is returned unchanged.
    """
    if path:
        with open(path, encoding="utf-8") as handle:
            return json.load(handle)
    return {"entries": {}}
|
|
|
|
| def _score_terms(text: str, terms: dict[str, float]) -> float: |
| score = 0.0 |
| for term, weight in terms.items(): |
| if " " in term or "-" in term: |
| if term in text: |
| score += weight |
| elif re.search(rf"\b{re.escape(term)}s?\b", text): |
| score += weight |
| return score |
|
|
|
|
def round_to_16(value: float) -> int:
    """Snap *value* to the nearest multiple of 16, never going below 16.

    Rounding uses the built-in ``round``, so exact halves go to the even
    multiple count (24 -> 32, 40 -> 32).
    """
    multiples = int(round(value / 16.0))
    snapped = multiples * 16
    return snapped if snapped > 16 else 16
|
|
|
|
def fallback_select(prompt: str, longest_side: int = LONGEST_SIDE) -> tuple[int, int, str]:
    """Select a canvas for unseen prompts without using reference images.

    The lower-cased prompt is scored against the landscape, portrait and
    square term tables. Landscape is tested first, then portrait, and each
    must beat both other scores by a fixed margin to be chosen. Anything
    else falls back to a square canvas.

    Returns:
        ``(width, height, policy_name)``. Non-square canvases use
        *longest_side* for the long edge and two thirds of it, snapped with
        ``round_to_16``, for the short edge.
    """
    lowered = prompt.lower()
    landscape_score = _score_terms(lowered, LANDSCAPE_TERMS)
    portrait_score = _score_terms(lowered, PORTRAIT_TERMS)
    square_score = _score_terms(lowered, SQUARE_TERMS)

    landscape_wins = (
        landscape_score >= portrait_score + 1.2
        and landscape_score >= square_score + 0.8
    )
    # Portrait is only considered once landscape has been ruled out.
    portrait_wins = (
        not landscape_wins
        and portrait_score >= landscape_score + 0.8
        and portrait_score >= square_score + 0.6
    )

    if not (landscape_wins or portrait_wins):
        return longest_side, longest_side, "fallback_square_1x1"

    # Computed only for non-square canvases.
    short_side = round_to_16(longest_side * 2 / 3)
    if landscape_wins:
        return longest_side, short_side, "fallback_landscape_3x2"
    return short_side, longest_side, "fallback_portrait_2x3"
|
|
|
|
def select_canvas(
    item: dict[str, Any],
    manifest: dict[str, Any] | None = None,
    longest_side: int = LONGEST_SIDE,
) -> tuple[int, int, str]:
    """Return ``(width, height, policy_name)`` for an input item.

    Lookup order:

    1. ``manifest["entries"]`` keyed by the item's identifier, which is the
       first truthy value among ``item["image_path"]``, ``item["task"]`` and
       ``item["file_name"]`` (policy ``"learned_manifest_by_name"``).
    2. The first manifest entry whose ``"prompt_sha1"`` equals the SHA-1 of
       the item's prompt (policy ``"learned_manifest_by_prompt"``).
    3. ``fallback_select`` applied to the prompt text.

    A missing or ``None`` manifest, a missing or null ``"entries"`` value and
    a missing or null ``"prompt"`` are all treated as empty.

    Raises:
        KeyError: if a matched manifest entry lacks ``"width"`` or
            ``"height"``.
    """
    # ``or {}`` also covers a manifest whose "entries" value is JSON null.
    entries = (manifest or {}).get("entries") or {}
    identifier = item.get("image_path") or item.get("task") or item.get("file_name")
    # ``.get("prompt", "")`` would return None for ``"prompt": null`` and then
    # crash on ``.encode``; treat null like a missing prompt.
    prompt = item.get("prompt") or ""

    if identifier and identifier in entries:
        entry = entries[identifier]
        return int(entry["width"]), int(entry["height"]), "learned_manifest_by_name"

    digest = prompt_hash(prompt)
    for entry in entries.values():
        if entry.get("prompt_sha1") == digest:
            return int(entry["width"]), int(entry["height"]), "learned_manifest_by_prompt"

    return fallback_select(prompt, longest_side=longest_side)
|
|
|
|
def main() -> None:
    """Annotate every item of the input JSON with a canvas and save the result.

    Reads the iterable of items from ``--input-json``, stores ``width``,
    ``height`` and ``aspect_policy`` on each item using ``select_canvas``
    (with the optional ``--manifest`` and ``--longest-side``), and writes the
    updated data to ``--output-json`` as indented UTF-8 JSON.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--input-json", required=True)
    cli.add_argument("--output-json", required=True)
    cli.add_argument("--manifest", default=None)
    cli.add_argument("--longest-side", type=int, default=LONGEST_SIDE)
    options = cli.parse_args()

    # The manifest is loaded before the input file is opened.
    manifest = load_manifest(options.manifest)
    with open(options.input_json, encoding="utf-8") as source:
        records = json.load(source)

    for record in records:
        canvas = select_canvas(record, manifest, options.longest_side)
        # Keys are assigned left to right: width, height, aspect_policy.
        record["width"], record["height"], record["aspect_policy"] = canvas

    with open(options.output_json, "w", encoding="utf-8") as sink:
        json.dump(records, sink, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    main()
|
|