File size: 15,763 Bytes
fdafd05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
"""OpenAI-compatible text-to-image prompt upsampling client."""

from __future__ import annotations

import json
import logging
import os
import re
import time
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from agentic_upsampling.constants import (
    DEFAULT_LLM_EXTRA_BODY,
    DEFAULT_UPSAMPLER_ENDPOINT_URL,
    DEFAULT_UPSAMPLER_MODEL,
)
from agentic_upsampling.data import validate_t2i_json

# Read once at import time from the JSON_ENSURE_ASCII environment variable
# ("1" when unset). The value is parsed with int(), so a non-integer setting
# such as "true" raises ValueError during import.
# NOTE(review): not referenced in this part of the file — presumably forwarded
# as ``ensure_ascii`` to json.dumps elsewhere; confirm.
JSON_ENSURE_ASCII = bool(int(os.environ.get("JSON_ENSURE_ASCII", "1")))
# Sent as the User-Agent header on every request made by OpenAIChatClient.
DEFAULT_USER_AGENT = "Cosmos3-Super-Text2Image-Agentic-Upsampling/1.0"
# First message of every conversation built by build_t2i_messages. The same
# dict object is reused across calls, so callers must not mutate it.
SYSTEM_MESSAGE: dict[str, Any] = {
    "role": "system",
    "content": [{"type": "text", "text": "You are a helpful assistant."}],
}
# Module-level logger.
log = logging.getLogger(__name__)

# Supported output canvases, used by apply_t2i_output_parameters.
# Outer key: resolution name as a string. Inner key: aspect ratio written as
# "width,height" (for example "16,9" is landscape, "9,16" is portrait).
# Value: the concrete canvas as {"W": width, "H": height} integers
# (presumably pixels — confirm against the image model's requirements).
RESOLUTION_RATIO_DICT: dict[str, dict[str, dict[str, int]]] = {
    "256": {
        "1,1": {"W": 256, "H": 256},
        "4,3": {"W": 320, "H": 256},
        "3,4": {"W": 256, "H": 320},
        "16,9": {"W": 320, "H": 192},
        "9,16": {"W": 192, "H": 320},
    },
    "480": {
        "1,1": {"W": 640, "H": 640},
        "4,3": {"W": 736, "H": 544},
        "3,4": {"W": 544, "H": 736},
        "16,9": {"W": 832, "H": 480},
        "9,16": {"W": 480, "H": 832},
    },
    "720": {
        "1,1": {"W": 960, "H": 960},
        "4,3": {"W": 1104, "H": 832},
        "3,4": {"W": 832, "H": 1104},
        "16,9": {"W": 1280, "H": 720},
        "9,16": {"W": 720, "H": 1280},
    },
    "768": {
        "1,1": {"W": 1024, "H": 1024},
        "4,3": {"W": 1184, "H": 880},
        "3,4": {"W": 880, "H": 1184},
        "16,9": {"W": 1360, "H": 768},
        "9,16": {"W": 768, "H": 1360},
    },
}

# Default user-message template for the upsampling request. It is rendered with
# str.format and has a single placeholder, {caption_dense}, which receives the
# stripped user prompt (see build_t2i_messages). Every literal brace of the
# embedded JSON skeleton is doubled ({{ / }}) so that str.format emits it as-is.
# The "resolution" and "aspect_ratio" values produced by the model are replaced
# afterwards by apply_t2i_output_parameters.
T2I_JSON_TEMPLATE = """Given the user's natural-language request below, generate a dense structured JSON that fully describes the image to be produced. The JSON must strictly follow the template provided after the request, including every top-level key and every nested sub-field.

The output is always dense. Even when the request is brief, infer plausible, scene-consistent details for every field. Do not leave fields empty merely because the request did not mention them. Be creative but stay grounded: additions must be physically plausible and internally consistent with the request.

Requirements:
- Extract visual intent from the user request into the visual fields.
- For every visual field, write rich, specific content inferred from the request's scene, subjects, mood, and context.
- Empty values ("", 0, [], {{}}) are permitted only for truly inapplicable fields.
- Do not add keys beyond the template. Do not omit keys required by the template.
- Return only the JSON object. Do not include markdown fences or prose outside JSON.

USER VISUAL REQUEST:
{caption_dense}

Lists may contain zero or more items of the shape shown. All top-level keys must always be present in the output; fill unused fields with "", 0, {{}}, or [] as appropriate.

{{
  "subjects": [
    {{
      "description": "full visual description of the subject",
      "appearance_details": "additional visual details such as accessories, texture, and distinguishing features",
      "relationship": "how this subject relates to others or to the scene",
      "location": "where in frame, for example center foreground or top right",
      "relative_size": "size within frame",
      "orientation": "direction subject faces relative to camera",
      "pose": "body position and posture",
      "clothing": "clothing and accessories; empty string if non-human or not applicable",
      "expression": "facial expression; empty string if non-human or not applicable",
      "gender": "Male, Female, Unknown, or N/A",
      "age": "age category",
      "skin_tone_and_texture": "skin tone description; empty string if non-human",
      "facial_features": "notable facial features; empty string if non-human or not visible",
      "number_of_subjects": "int; total in this subject group, 0 if not applicable",
      "number_of_arms": "int; 2 for humans, 0 if non-human",
      "number_of_legs": "int; 2 for humans, 0 if non-human",
      "number_of_hands": "int; 2 for humans, 0 if non-human",
      "number_of_fingers": "int; 10 for humans, 0 if non-human"
    }}
  ],
  "subject_details": {{
    "key_name_1": "free-form image-specific attribute; empty object if not applicable"
  }},
  "background_setting": "full prose description of the environment and setting",
  "lighting": {{
    "conditions": "type and quality of light",
    "direction": "where light comes from; None for flat digital images",
    "shadows": "shadow description; None for flat digital images",
    "illumination_effect": "overall effect of the lighting"
  }},
  "aesthetics": {{
    "composition": "framing and compositional choices",
    "color_scheme": "dominant colors and palette",
    "mood_atmosphere": "emotional atmosphere in short phrases",
    "patterns": "notable repeating visual patterns; None if none"
  }},
  "cinematography": {{
    "framing": "shot type",
    "camera_angle": "angle such as Eye-level, Low angle, or High angle",
    "depth_of_field": "Shallow, Deep, Uniform focus, or N/A",
    "focus": "what is in sharp focus",
    "lens_focal_length": "descriptive focal length"
  }},
  "style_medium": "visual medium, for example Photography, Digital illustration, or Screenshot",
  "artistic_style": "genre or approach",
  "context": "scene context or use case",
  "text_and_signage_elements": [
    {{
      "text": "the visible text content",
      "category": "physical_in_scene, ui_text, body_text, scene_sign, logo, or label",
      "appearance": "font, color, size, style",
      "spatial": "position in image",
      "context": "purpose or meaning of the text"
    }}
  ],
  "quadrant_scan": {{
    "top_left": "description of what appears in the top-left region",
    "top_right": "description of what appears in the top-right region",
    "bottom_left": "description of what appears in the bottom-left region",
    "bottom_right": "description of what appears in the bottom-right region",
    "absolute_center": "description of what appears at the center"
  }},
  "comprehensive_t2i_caption": "a comprehensive, full-scene natural-language prose description of the image",
  "resolution": {{
    "H": "will be overwritten by the selected resolution and aspect ratio",
    "W": "will be overwritten by the selected resolution and aspect ratio"
  }},
  "aspect_ratio": "will be overwritten by the selected aspect ratio"
}}"""


@dataclass(slots=True)
class ChatClientConfig:
    """Configuration for an OpenAI-compatible chat-completions endpoint.

    Read by ``OpenAIChatClient`` for request and retry settings and by
    ``_make_session`` for connection-level settings.
    """

    # Endpoint root or full chat-completions URL; passed through
    # normalize_openai_base_url when the client is constructed.
    endpoint_url: str
    # Value sent as the "model" field of each request.
    model: str
    # Sent as "Authorization: Bearer <token>"; the header is omitted when empty.
    api_token: str
    # Passed as the ``timeout`` argument of each HTTP request.
    timeout_s: float = 300.0
    # Completion token limit, sent as "max_tokens" or "max_completion_tokens".
    max_tokens: int = 8192
    # Total number of attempts per completion call; must be >= 1.
    max_retries: int = 3
    # Delay before the first retry; doubled after every further failed attempt.
    retry_base_delay_s: float = 1.0
    # Extra top-level request fields, merged last so they override built-in ones.
    extra_body: dict[str, Any] | None = None
    # Used as both the ``total`` and ``connect`` budget of the urllib3 Retry policy.
    connection_max_retries: int = 2
    # Used as both ``pool_connections`` and ``pool_maxsize`` of the HTTP adapter.
    connection_pool_size: int = 4


class OpenAIChatClient:
    """Small synchronous OpenAI-compatible chat-completions client."""

    config: ChatClientConfig
    base_url: str
    session: requests.Session
    sleep: Callable[[float], None]

    def __init__(
        self,
        config: ChatClientConfig,
        *,
        session: requests.Session | None = None,
        sleep: Callable[[float], None] = time.sleep,
    ) -> None:
        self.config = config
        self.base_url = normalize_openai_base_url(config.endpoint_url)
        self.session = _make_session(config) if session is None else session
        self.sleep = sleep

    def complete(self, messages: list[dict[str, Any]], *, response_format_json: bool = False) -> str:
        """Request one chat completion and return assistant text."""

        def _call() -> str:
            payload: dict[str, Any] = {
                "model": self.config.model,
                "messages": messages,
                self._max_tokens_key(): self.config.max_tokens,
            }
            if response_format_json:
                payload["response_format"] = {"type": "json_object"}
            if self.config.extra_body:
                payload.update(self.config.extra_body)
            parsed = self._request_json("POST", f"{self.base_url}/chat/completions", payload=payload)
            choices = parsed.get("choices")
            if not isinstance(choices, list) or not choices:
                raise ValueError("Chat completion response missing choices.")
            first_choice = choices[0]
            if not isinstance(first_choice, dict):
                raise ValueError("Chat completion choice must be an object.")
            message = first_choice.get("message")
            if not isinstance(message, dict):
                raise ValueError("Chat completion choice missing message.")
            return _message_content_to_text(message.get("content"))

        return self._with_retries("complete chat request", _call)

    def _request_json(self, method: str, url: str, *, payload: dict[str, Any] | None = None) -> dict[str, Any]:
        headers = {"Accept": "application/json", "User-Agent": DEFAULT_USER_AGENT}
        if payload is not None:
            headers["Content-Type"] = "application/json"
        if self.config.api_token:
            headers["Authorization"] = f"Bearer {self.config.api_token}"
        try:
            response = self.session.request(method, url, json=payload, headers=headers, timeout=self.config.timeout_s)
        except requests.RequestException as exc:
            raise RuntimeError(f"Failed to reach {url}: {exc}") from exc
        if not response.ok:
            raise RuntimeError(f"HTTP {response.status_code} from {url}: {response.text[:1000]}")
        parsed = response.json()
        if not isinstance(parsed, dict):
            raise RuntimeError(f"Response from {url} must be a JSON object.")
        return parsed

    def _with_retries(self, operation: str, fn: Callable[[], str]) -> str:
        if self.config.max_retries < 1:
            raise ValueError("max_retries must be >= 1.")
        last_exc: Exception | None = None
        for attempt in range(self.config.max_retries):
            try:
                return fn()
            except Exception as exc:
                last_exc = exc
                if attempt == self.config.max_retries - 1:
                    break
                self.sleep(self.config.retry_base_delay_s * (2**attempt))
        raise RuntimeError(f"Failed to {operation} after {self.config.max_retries} attempts: {last_exc}") from last_exc

    def _max_tokens_key(self) -> str:
        if "api.openai.com" in self.base_url:
            return "max_completion_tokens"
        return "max_tokens"


class Text2ImagePromptUpsampler:
    """Turn free-form user text into a structured text-to-image JSON prompt.

    The heavy lifting is delegated to a chat-completions client; this class
    builds the request, parses the reply, and validates the result.
    """

    chat_client: OpenAIChatClient

    def __init__(self, chat_client: OpenAIChatClient) -> None:
        """Wrap an already configured chat client."""
        self.chat_client = chat_client

    @classmethod
    def from_defaults(
        cls,
        *,
        api_token: str,
        endpoint_url: str = DEFAULT_UPSAMPLER_ENDPOINT_URL,
        model: str = DEFAULT_UPSAMPLER_MODEL,
        extra_body: dict[str, Any] | None = None,
    ) -> Text2ImagePromptUpsampler:
        """Build an upsampler from the package's default endpoint and model.

        Args:
            api_token: Bearer token for the endpoint.
            endpoint_url: Endpoint to call.
            model: Model name to request.
            extra_body: Extra request fields; ``None`` selects
                ``DEFAULT_LLM_EXTRA_BODY``.
        """
        if extra_body is None:
            extra_body = DEFAULT_LLM_EXTRA_BODY
        client_config = ChatClientConfig(
            endpoint_url=endpoint_url,
            model=model,
            api_token=api_token,
            extra_body=extra_body,
        )
        return cls(OpenAIChatClient(client_config))

    def upsample(
        self,
        prompt: str,
        *,
        prompt_id: str,
        resolution: str,
        aspect_ratio: str,
        user_prompt: str | None = None,
    ) -> dict[str, Any]:
        """Produce a validated structured T2I JSON prompt for ``prompt``.

        Args:
            prompt: The user's natural-language request.
            prompt_id: Identifier handed to ``validate_t2i_json``.
            resolution: Key into ``RESOLUTION_RATIO_DICT``.
            aspect_ratio: Aspect-ratio key valid for ``resolution``.
            user_prompt: Optional complete user message that replaces the
                built-in template.

        Returns:
            The parsed JSON object with ``resolution`` and ``aspect_ratio``
            overwritten by the selected canvas.
        """
        chat_messages = build_t2i_messages(prompt, user_prompt=user_prompt)
        reply_text = self.chat_client.complete(chat_messages, response_format_json=True)
        parsed_reply = extract_json_object(reply_text)
        data = apply_t2i_output_parameters(
            parsed_reply,
            resolution=resolution,
            aspect_ratio=aspect_ratio,
        )
        validate_t2i_json(data, prompt_id)
        return data


def build_t2i_messages(prompt: str, *, user_prompt: str | None = None) -> list[dict[str, Any]]:
    """Assemble the system and user messages for the first upsampling request.

    A non-empty ``user_prompt`` is sent verbatim as the user message and
    ``prompt`` is then ignored. Otherwise the stripped ``prompt`` is inserted
    into ``T2I_JSON_TEMPLATE``.
    """
    if user_prompt:
        message_text = user_prompt
    else:
        message_text = T2I_JSON_TEMPLATE.format(caption_dense=prompt.strip())
    user_message = {
        "role": "user",
        "content": [{"type": "text", "text": message_text}],
    }
    return [SYSTEM_MESSAGE, user_message]


def apply_t2i_output_parameters(data: dict[str, Any], *, resolution: str, aspect_ratio: str) -> dict[str, Any]:
    """Stamp the selected canvas size and aspect ratio onto ``data``.

    ``data`` is modified in place and also returned.

    Raises:
        ValueError: If ``resolution`` is unknown, or ``aspect_ratio`` is not
            available for that resolution.
    """
    ratios_for_resolution = RESOLUTION_RATIO_DICT.get(resolution)
    if ratios_for_resolution is None:
        raise ValueError(f"Unsupported resolution {resolution!r}.")
    canvas = ratios_for_resolution.get(aspect_ratio)
    if canvas is None:
        raise ValueError(f"Unsupported aspect_ratio {aspect_ratio!r} for resolution {resolution!r}.")
    # Build a fresh dict so the shared lookup table entry is never aliased.
    data["resolution"] = {axis: canvas[axis] for axis in ("H", "W")}
    data["aspect_ratio"] = aspect_ratio
    return data


def extract_json_object(text: str) -> dict[str, Any]:
    """Extract a JSON object from raw model text.

    The text is tried in three ways, in order:

    1. The whole stripped text as one JSON object. This runs first so that an
       object whose string values contain a Markdown fence is not mistaken
       for fenced output.
    2. The contents of the first Markdown code fence, if any, sliced from the
       first ``{`` to the last ``}``.
    3. If that slice is not valid JSON (for example because surrounding prose
       contains braces), the first complete object that decodes at any ``{``.

    Args:
        text: Raw assistant text.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: If no JSON object can be found or decoded
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    cleaned = text.strip()
    if cleaned.startswith("{"):
        try:
            whole = json.loads(cleaned)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(whole, dict):
                return whole
    fence_match = re.search(r"```(?:json)?\s*(.*?)\s*```", cleaned, flags=re.DOTALL)
    if fence_match:
        cleaned = fence_match.group(1).strip()
    start = cleaned.find("{")
    end = cleaned.rfind("}")
    if start < 0 or end < start:
        raise ValueError("Model response did not contain a JSON object.")
    try:
        parsed = json.loads(cleaned[start : end + 1])
    except json.JSONDecodeError:
        # The outermost-brace slice can span prose. Fall back to decoding the
        # first complete object that starts at some "{".
        decoder = json.JSONDecoder()
        position = start
        while position >= 0:
            try:
                parsed, _ = decoder.raw_decode(cleaned, position)
            except json.JSONDecodeError:
                position = cleaned.find("{", position + 1)
            else:
                break
        else:
            # Nothing decoded: surface the original slice error.
            raise
    if not isinstance(parsed, dict):
        raise ValueError("Model response JSON must be an object.")
    return parsed


def normalize_openai_base_url(url: str) -> str:
    """Normalize an OpenAI-compatible endpoint root.

    Surrounding whitespace and trailing slashes are removed, ``https://`` is
    prepended when no HTTP scheme is present, a trailing ``/chat/completions``
    is dropped, and ``/v1`` is appended unless the result already ends with
    ``/v1`` or ``/openai``.

    Args:
        url: Endpoint root or full chat-completions URL.

    Returns:
        The base URL to which ``/chat/completions`` can be appended.

    Raises:
        ValueError: If ``url`` is empty after trimming.
    """
    normalized = url.strip().rstrip("/")
    if not normalized:
        raise ValueError("endpoint_url cannot be empty.")
    # URL schemes are case-insensitive; compare in lowercase so that
    # "HTTPS://host" does not get a second scheme prepended.
    if not normalized.lower().startswith(("http://", "https://")):
        normalized = f"https://{normalized}"
    if normalized.endswith("/chat/completions"):
        # Strip again so a doubled slash before the suffix cannot leave "//v1".
        normalized = normalized[: -len("/chat/completions")].rstrip("/")
    if normalized.endswith(("/v1", "/openai")):
        return normalized
    return f"{normalized}/v1"


def _make_session(config: ChatClientConfig) -> requests.Session:
    """Build a ``requests.Session`` with connection pooling and transport retries.

    One adapter, sized by ``config.connection_pool_size``, is mounted for both
    HTTPS and HTTP. Its retry policy uses ``config.connection_max_retries``
    for the total and connect budgets, never retries reads, and allows up to
    two retries on 429/500/502/503/504 responses for GET and POST. Once those
    are exhausted the last response is returned, not raised.
    """
    retry_budget = config.connection_max_retries
    pool_size = config.connection_pool_size
    retry_policy = Retry(
        total=retry_budget,
        connect=retry_budget,
        read=0,
        status=2,
        status_forcelist=(429, 500, 502, 503, 504),
        allowed_methods=frozenset({"GET", "POST"}),
        backoff_factor=0.5,
        raise_on_status=False,
    )
    pooled_adapter = HTTPAdapter(
        pool_connections=pool_size,
        pool_maxsize=pool_size,
        max_retries=retry_policy,
    )
    http_session = requests.Session()
    for scheme_prefix in ("https://", "http://"):
        http_session.mount(scheme_prefix, pooled_adapter)
    return http_session


def _message_content_to_text(content: Any) -> str:
    """Flatten a chat message's ``content`` field into plain text.

    A non-blank string is returned unchanged. A list is reduced to the
    concatenation of its ``{"type": "text", "text": <str>}`` items, stripped
    of surrounding whitespace. Anything else is rejected, as is content that
    yields no text.

    Raises:
        ValueError: If no non-blank text can be obtained from ``content``.
    """
    if isinstance(content, str):
        if content.strip():
            return content
    elif isinstance(content, list):
        joined = "".join(
            piece["text"]
            for piece in content
            if isinstance(piece, dict) and piece.get("type") == "text" and isinstance(piece.get("text"), str)
        ).strip()
        if joined:
            return joined
    raise ValueError("Chat completion message content is empty or unsupported.")