portraitcraft-track2 / scripts /aspect_ratio_selector.py
Jessamine's picture
Upload folder using huggingface_hub
06839ab verified
#!/usr/bin/env python3
"""Prompt-conditioned canvas selector used for PortraitCraft inference.
For challenge reproduction, the selector first checks a compact learned
policy manifest keyed by image name / prompt hash. For unseen prompts it falls
back to a deterministic prompt-only rule policy.
"""
from __future__ import annotations
import argparse
import hashlib
import json
import re
from pathlib import Path
from typing import Any
LONGEST_SIDE = 1584
LANDSCAPE_TERMS = {
"landscape": 2.0,
"panoramic": 2.0,
"wide": 1.6,
"horizon": 1.5,
"road": 1.4,
"street": 1.1,
"beach": 1.4,
"ocean": 1.4,
"sea": 1.2,
"mountain": 1.4,
"valley": 1.2,
"field": 1.1,
"cityscape": 1.5,
"environmental portrait": 1.5,
"large negative space": 1.2,
"leading lines": 1.1,
}
PORTRAIT_TERMS = {
"full-body": 1.8,
"full body": 1.8,
"head-to-toe": 1.8,
"standing": 1.2,
"vertical": 1.6,
"tall": 1.3,
"narrow": 1.2,
"alley": 1.2,
"staircase": 1.1,
"towering": 1.1,
"walking": 0.8,
}
SQUARE_TERMS = {
"close-up": 1.5,
"close up": 1.5,
"headshot": 1.6,
"centered": 1.4,
"symmetrical": 1.3,
"symmetry": 1.3,
"bust": 1.1,
"face": 0.8,
"portrait": 0.6,
}
def prompt_hash(prompt: str) -> str:
return hashlib.sha1(prompt.encode("utf-8")).hexdigest()
def load_manifest(path: str | Path | None) -> dict[str, Any]:
if not path:
return {"entries": {}}
with open(path, encoding="utf-8") as f:
return json.load(f)
def _score_terms(text: str, terms: dict[str, float]) -> float:
score = 0.0
for term, weight in terms.items():
if " " in term or "-" in term:
if term in text:
score += weight
elif re.search(rf"\b{re.escape(term)}s?\b", text):
score += weight
return score
def round_to_16(value: float) -> int:
return max(16, int(round(value / 16.0)) * 16)
def fallback_select(prompt: str, longest_side: int = LONGEST_SIDE) -> tuple[int, int, str]:
"""Select a canvas for unseen prompts without using reference images."""
text = prompt.lower()
landscape = _score_terms(text, LANDSCAPE_TERMS)
portrait = _score_terms(text, PORTRAIT_TERMS)
square = _score_terms(text, SQUARE_TERMS)
if landscape >= portrait + 1.2 and landscape >= square + 0.8:
return longest_side, round_to_16(longest_side * 2 / 3), "fallback_landscape_3x2"
if portrait >= landscape + 0.8 and portrait >= square + 0.6:
return round_to_16(longest_side * 2 / 3), longest_side, "fallback_portrait_2x3"
return longest_side, longest_side, "fallback_square_1x1"
def select_canvas(
item: dict[str, Any],
manifest: dict[str, Any] | None = None,
longest_side: int = LONGEST_SIDE,
) -> tuple[int, int, str]:
"""Return ``(width, height, policy_name)`` for an input item."""
manifest = manifest or {"entries": {}}
entries = manifest.get("entries", {})
image_path = item.get("image_path") or item.get("task") or item.get("file_name")
prompt = item.get("prompt", "")
if image_path and image_path in entries:
entry = entries[image_path]
return int(entry["width"]), int(entry["height"]), "learned_manifest_by_name"
sha1 = prompt_hash(prompt)
for entry in entries.values():
if entry.get("prompt_sha1") == sha1:
return int(entry["width"]), int(entry["height"]), "learned_manifest_by_prompt"
return fallback_select(prompt, longest_side=longest_side)
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--input-json", required=True)
parser.add_argument("--output-json", required=True)
parser.add_argument("--manifest", default=None)
parser.add_argument("--longest-side", type=int, default=LONGEST_SIDE)
args = parser.parse_args()
manifest = load_manifest(args.manifest)
with open(args.input_json, encoding="utf-8") as f:
data = json.load(f)
for item in data:
width, height, policy = select_canvas(item, manifest, args.longest_side)
item["width"] = width
item["height"] = height
item["aspect_policy"] = policy
with open(args.output_json, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
if __name__ == "__main__":
main()