File size: 4,376 Bytes
06839ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python3
"""Prompt-conditioned canvas selector used for PortraitCraft inference.

For challenge reproduction, the selector first checks a compact learned
policy manifest keyed by image name / prompt hash. For unseen prompts it falls
back to a deterministic prompt-only rule policy.
"""

from __future__ import annotations

import argparse
import hashlib
import json
import re
from pathlib import Path
from typing import Any


# Default length of the longer canvas side (presumably pixels -- confirm with
# the inference pipeline). Used as the default for ``longest_side`` in
# ``fallback_select`` / ``select_canvas`` and for the ``--longest-side`` CLI
# option. 1584 is a multiple of 16, and so is its 2/3 short side (1056).
LONGEST_SIDE = 1584

# Weighted prompt cues for a landscape canvas. ``_score_terms`` adds a term's
# weight once when the term occurs in the lower-cased prompt. Terms containing
# a space or hyphen are matched as plain substrings; single words are matched
# on word boundaries with an optional trailing "s".
LANDSCAPE_TERMS = {
    "landscape": 2.0,
    "panoramic": 2.0,
    "wide": 1.6,
    "horizon": 1.5,
    "road": 1.4,
    "street": 1.1,
    "beach": 1.4,
    "ocean": 1.4,
    "sea": 1.2,
    "mountain": 1.4,
    "valley": 1.2,
    "field": 1.1,
    "cityscape": 1.5,
    "environmental portrait": 1.5,
    "large negative space": 1.2,
    "leading lines": 1.1,
}

# Weighted prompt cues for a portrait (taller than wide) canvas; scored the
# same way as LANDSCAPE_TERMS.
PORTRAIT_TERMS = {
    "full-body": 1.8,
    "full body": 1.8,
    "head-to-toe": 1.8,
    "standing": 1.2,
    "vertical": 1.6,
    "tall": 1.3,
    "narrow": 1.2,
    "alley": 1.2,
    "staircase": 1.1,
    "towering": 1.1,
    "walking": 0.8,
}

# Weighted prompt cues for a square canvas; scored the same way as
# LANDSCAPE_TERMS. Square is also what ``fallback_select`` returns when
# neither the landscape nor the portrait score wins by its margin.
SQUARE_TERMS = {
    "close-up": 1.5,
    "close up": 1.5,
    "headshot": 1.6,
    "centered": 1.4,
    "symmetrical": 1.3,
    "symmetry": 1.3,
    "bust": 1.1,
    "face": 0.8,
    "portrait": 0.6,
}


def prompt_hash(prompt: str) -> str:
    """Return the hexadecimal SHA-1 digest of *prompt* encoded as UTF-8."""
    hasher = hashlib.sha1()
    hasher.update(prompt.encode("utf-8"))
    return hasher.hexdigest()


def load_manifest(path: str | Path | None) -> dict[str, Any]:
    """Load the policy manifest JSON from *path*.

    A falsy *path* (``None`` or an empty string) yields an empty manifest,
    ``{"entries": {}}``, instead of touching the filesystem. Otherwise the
    file is read as UTF-8 and the parsed JSON document is returned as-is.
    """
    if path:
        with open(path, encoding="utf-8") as handle:
            raw = handle.read()
        return json.loads(raw)
    return {"entries": {}}


def _score_terms(text: str, terms: dict[str, float]) -> float:
    """Sum the weights of every entry in *terms* that occurs in *text*.

    Terms containing a space or a hyphen are matched as plain substrings.
    Any other term is matched as a whole word, allowing an optional
    trailing "s". Each term contributes its weight at most once.
    """

    def _occurs(term: str) -> bool:
        if " " in term or "-" in term:
            return term in text
        return re.search(rf"\b{re.escape(term)}s?\b", text) is not None

    total = 0.0
    # Accumulate in dict order with plain ``+=`` so float rounding stays stable.
    for term, weight in terms.items():
        if _occurs(term):
            total += weight
    return total


def round_to_16(value: float) -> int:
    """Snap *value* to the nearest multiple of 16, never returning less than 16.

    Ties are resolved by the built-in ``round`` applied to ``value / 16``.
    """
    steps = int(round(value / 16.0))
    snapped = steps * 16
    if snapped < 16:
        return 16
    return snapped


def fallback_select(prompt: str, longest_side: int = LONGEST_SIDE) -> tuple[int, int, str]:
    """Pick a canvas from the prompt text alone, without reference images.

    The lower-cased prompt is scored against the landscape, portrait and
    square term tables. Landscape wins when it leads portrait by at least
    1.2 and square by at least 0.8. Otherwise portrait wins when it leads
    landscape by at least 0.8 and square by at least 0.6. Anything else
    falls through to a square canvas.

    Returns ``(width, height, policy_name)``. The non-square canvases use
    *longest_side* for the long edge and two thirds of it, snapped with
    ``round_to_16``, for the short edge.
    """
    lowered = prompt.lower()
    wide_score = _score_terms(lowered, LANDSCAPE_TERMS)
    tall_score = _score_terms(lowered, PORTRAIT_TERMS)
    square_score = _score_terms(lowered, SQUARE_TERMS)

    wants_landscape = wide_score >= tall_score + 1.2 and wide_score >= square_score + 0.8
    wants_portrait = tall_score >= wide_score + 0.8 and tall_score >= square_score + 0.6

    # Landscape is tested first, so it takes precedence over portrait.
    if wants_landscape:
        width, height = longest_side, round_to_16(longest_side * 2 / 3)
        policy = "fallback_landscape_3x2"
    elif wants_portrait:
        width, height = round_to_16(longest_side * 2 / 3), longest_side
        policy = "fallback_portrait_2x3"
    else:
        width, height = longest_side, longest_side
        policy = "fallback_square_1x1"
    return width, height, policy


def select_canvas(
    item: dict[str, Any],
    manifest: dict[str, Any] | None = None,
    longest_side: int = LONGEST_SIDE,
) -> tuple[int, int, str]:
    """Return ``(width, height, policy_name)`` for an input item.

    Lookup order:

    1. A manifest entry keyed by the item's name, taken from the first
       non-empty of ``image_path``, ``task`` and ``file_name``
       (policy ``"learned_manifest_by_name"``).
    2. The first manifest entry whose ``prompt_sha1`` equals the SHA-1 of
       the item's prompt (policy ``"learned_manifest_by_prompt"``).
    3. The prompt-only rules in ``fallback_select``.

    A missing or ``None`` manifest, or one whose ``"entries"`` is missing or
    ``None``, is treated as having no entries. A missing or ``None`` prompt
    is treated as the empty string.

    Raises:
        KeyError: a matched manifest entry lacks ``width`` or ``height``.
        ValueError, TypeError: a matched entry's ``width`` or ``height``
            cannot be converted with ``int``.
    """
    entries = (manifest or {}).get("entries") or {}
    image_path = item.get("image_path") or item.get("task") or item.get("file_name")
    # ``dict.get`` only applies its default when the key is absent. A JSON
    # ``null`` prompt arrives as None and would crash ``prompt_hash`` on
    # ``.encode``, so normalise it to the same value as a missing key.
    prompt = item.get("prompt") or ""

    if image_path and image_path in entries:
        entry = entries[image_path]
        return int(entry["width"]), int(entry["height"]), "learned_manifest_by_name"

    sha1 = prompt_hash(prompt)
    for entry in entries.values():
        if entry.get("prompt_sha1") == sha1:
            return int(entry["width"]), int(entry["height"]), "learned_manifest_by_prompt"

    return fallback_select(prompt, longest_side=longest_side)


def main() -> None:
    """Command-line entry point.

    Reads the JSON document at ``--input-json`` and iterates over it. Each
    item is annotated in place with ``width``, ``height`` and
    ``aspect_policy`` as chosen by ``select_canvas``, and the result is
    written to ``--output-json``. ``--manifest`` optionally names a policy
    manifest, and ``--longest-side`` overrides the default long edge.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--input-json", required=True)
    cli.add_argument("--output-json", required=True)
    cli.add_argument("--manifest", default=None)
    cli.add_argument("--longest-side", type=int, default=LONGEST_SIDE)
    options = cli.parse_args()

    policy_manifest = load_manifest(options.manifest)
    with open(options.input_json, encoding="utf-8") as source:
        records = json.load(source)

    for record in records:
        # Targets are assigned left to right, so keys are inserted in the
        # order width, height, aspect_policy.
        record["width"], record["height"], record["aspect_policy"] = select_canvas(
            record, policy_manifest, options.longest_side
        )

    with open(options.output_json, "w", encoding="utf-8") as sink:
        json.dump(records, sink, ensure_ascii=False, indent=2)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()