File size: 4,376 Bytes
06839ab | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | #!/usr/bin/env python3
"""Prompt-conditioned canvas selector used for PortraitCraft inference.
For challenge reproduction, the selector first checks a compact learned
policy manifest keyed by image name / prompt hash. For unseen prompts it falls
back to a deterministic prompt-only rule policy.
"""
from __future__ import annotations
import argparse
import hashlib
import json
import re
from pathlib import Path
from typing import Any
# Default size of the canvas's longest side. It is the default value of the
# ``longest_side`` parameters of ``fallback_select`` / ``select_canvas`` and of
# the ``--longest-side`` command-line option.
LONGEST_SIDE = 1584
# Prompt terms that vote for a landscape canvas, mapped to the weight each one
# adds to the landscape score in ``fallback_select``. Terms containing a space
# or hyphen are matched as plain substrings by ``_score_terms``; single words
# are matched on word boundaries with an optional trailing "s".
LANDSCAPE_TERMS = {
    # Explicit framing vocabulary.
    "landscape": 2.0,
    "panoramic": 2.0,
    "wide": 1.6,
    "horizon": 1.5,
    # Scene types.
    "road": 1.4,
    "street": 1.1,
    "beach": 1.4,
    "ocean": 1.4,
    "sea": 1.2,
    "mountain": 1.4,
    "valley": 1.2,
    "field": 1.1,
    "cityscape": 1.5,
    # Multi-word composition phrases.
    "environmental portrait": 1.5,
    "large negative space": 1.2,
    "leading lines": 1.1,
}
# Prompt terms that vote for a portrait (tall) canvas, mapped to the weight
# each one adds to the portrait score in ``fallback_select``.
PORTRAIT_TERMS = {
    # Whole-figure framing, in its hyphenated and spaced spellings.
    "full-body": 1.8,
    "full body": 1.8,
    "head-to-toe": 1.8,
    "standing": 1.2,
    # Vertical shape vocabulary.
    "vertical": 1.6,
    "tall": 1.3,
    "narrow": 1.2,
    # Scene and motion terms.
    "alley": 1.2,
    "staircase": 1.1,
    "towering": 1.1,
    "walking": 0.8,
}
# Prompt terms that vote for a square canvas, mapped to the weight each one
# adds to the square score in ``fallback_select``.
SQUARE_TERMS = {
    # Tight framing, in its hyphenated and spaced spellings.
    "close-up": 1.5,
    "close up": 1.5,
    "headshot": 1.6,
    # Balanced composition vocabulary.
    "centered": 1.4,
    "symmetrical": 1.3,
    "symmetry": 1.3,
    # Lower-weight subject terms.
    "bust": 1.1,
    "face": 0.8,
    "portrait": 0.6,
}
def prompt_hash(prompt: str) -> str:
    """Return the hexadecimal SHA-1 digest of ``prompt`` encoded as UTF-8.

    The digest is compared against the ``prompt_sha1`` field of manifest
    entries in ``select_canvas``.
    """
    hasher = hashlib.sha1()
    hasher.update(prompt.encode("utf-8"))
    return hasher.hexdigest()
def load_manifest(path: str | Path | None) -> dict[str, Any]:
    """Load the policy manifest from a JSON file.

    Args:
        path: Location of the manifest file. Any falsy value (``None`` or an
            empty string) means "no manifest".

    Returns:
        The parsed JSON document, or a fresh ``{"entries": {}}`` placeholder
        when no path was given.
    """
    if path:
        with open(path, encoding="utf-8") as handle:
            return json.load(handle)
    # No manifest supplied: hand back an empty one.
    return {"entries": {}}
def _score_terms(text: str, terms: dict[str, float]) -> float:
    """Sum the weights of every term in ``terms`` that occurs in ``text``.

    Terms containing a space or a hyphen are matched as plain substrings.
    Every other term must appear as a whole word, optionally followed by a
    plural "s". Each term contributes its weight at most once, and matching
    is case-sensitive (callers pass already-lowercased text).
    """
    total = 0.0
    for phrase, weight in terms.items():
        is_compound = " " in phrase or "-" in phrase
        if is_compound:
            found = phrase in text
        else:
            pattern = rf"\b{re.escape(phrase)}s?\b"
            found = re.search(pattern, text) is not None
        if found:
            total += weight
    return total
def round_to_16(value: float) -> int:
    """Snap ``value`` to the nearest multiple of 16, never going below 16.

    Rounding uses the built-in ``round``, so exact halves go to the even
    multiple (for example 40 becomes 32).
    """
    multiples = int(round(value / 16.0))
    snapped = multiples * 16
    return snapped if snapped > 16 else 16
def fallback_select(
    prompt: str, longest_side: int = LONGEST_SIDE
) -> tuple[int, int, str]:
    """Choose a canvas from the prompt text alone, without reference images.

    The lowercased prompt is scored against the landscape, portrait and
    square term tables. A non-square orientation is chosen only when its
    score beats both of the others by a fixed margin; otherwise the canvas
    is square.

    Args:
        prompt: Text prompt describing the image.
        longest_side: Size of the canvas's longest side.

    Returns:
        ``(width, height, policy_name)``. For the 3:2 and 2:3 canvases the
        short side is two thirds of ``longest_side``, snapped by
        ``round_to_16``.
    """
    lowered = prompt.lower()
    landscape_score = _score_terms(lowered, LANDSCAPE_TERMS)
    portrait_score = _score_terms(lowered, PORTRAIT_TERMS)
    square_score = _score_terms(lowered, SQUARE_TERMS)

    # Landscape is tested first, so it takes precedence over portrait.
    prefers_landscape = (
        landscape_score >= portrait_score + 1.2
        and landscape_score >= square_score + 0.8
    )
    if prefers_landscape:
        short_side = round_to_16(longest_side * 2 / 3)
        return longest_side, short_side, "fallback_landscape_3x2"

    prefers_portrait = (
        portrait_score >= landscape_score + 0.8
        and portrait_score >= square_score + 0.6
    )
    if prefers_portrait:
        short_side = round_to_16(longest_side * 2 / 3)
        return short_side, longest_side, "fallback_portrait_2x3"

    return longest_side, longest_side, "fallback_square_1x1"
def select_canvas(
    item: dict[str, Any],
    manifest: dict[str, Any] | None = None,
    longest_side: int = LONGEST_SIDE,
) -> tuple[int, int, str]:
    """Return ``(width, height, policy_name)`` for an input item.

    Lookup order:

    1. A manifest entry keyed by the item's ``image_path``, ``task`` or
       ``file_name`` (the first truthy one), reported as
       ``"learned_manifest_by_name"``.
    2. The first manifest entry whose ``prompt_sha1`` equals the SHA-1 of the
       item's prompt, reported as ``"learned_manifest_by_prompt"``.
    3. The prompt-only rules in ``fallback_select``.

    Args:
        item: Input record; reads ``image_path`` / ``task`` / ``file_name``
            and ``prompt``.
        manifest: Parsed manifest with an ``entries`` mapping, or ``None``.
        longest_side: Longest canvas side, used only by the fallback policy.

    Raises:
        KeyError: If a matching manifest entry lacks ``width`` or ``height``.
    """
    # JSON ``null`` is decoded to None, so "entries": null or "prompt": null
    # must be treated like a missing key instead of crashing further down.
    entries = (manifest or {}).get("entries") or {}
    image_path = item.get("image_path") or item.get("task") or item.get("file_name")
    prompt = item.get("prompt") or ""

    if image_path and image_path in entries:
        entry = entries[image_path]
        return int(entry["width"]), int(entry["height"]), "learned_manifest_by_name"

    sha1 = prompt_hash(prompt)
    for entry in entries.values():
        if entry.get("prompt_sha1") == sha1:
            return int(entry["width"]), int(entry["height"]), "learned_manifest_by_prompt"

    return fallback_select(prompt, longest_side=longest_side)
def main() -> None:
    """Annotate every item of an input JSON file with its selected canvas.

    Reads the items from ``--input-json``, adds ``width``, ``height`` and
    ``aspect_policy`` to each one using ``select_canvas``, and writes the
    result to ``--output-json``. ``--manifest`` optionally names a policy
    manifest; ``--longest-side`` overrides the default longest canvas side.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--input-json", required=True)
    cli.add_argument("--output-json", required=True)
    cli.add_argument("--manifest", default=None)
    cli.add_argument("--longest-side", type=int, default=LONGEST_SIDE)
    options = cli.parse_args()

    manifest = load_manifest(options.manifest)
    with open(options.input_json, encoding="utf-8") as source:
        records = json.load(source)

    for record in records:
        width, height, policy = select_canvas(record, manifest, options.longest_side)
        # Keyword order keeps the original key insertion order.
        record.update(width=width, height=height, aspect_policy=policy)

    with open(options.output_json, "w", encoding="utf-8") as sink:
        json.dump(records, sink, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    main()
|