Instructions to use nvidia/Cosmos3-Super-Text2Image with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Cosmos
How to use nvidia/Cosmos3-Super-Text2Image with Cosmos:
```python
# No code snippets available yet for this library.
# To use this model, check the repository files and the library's documentation.
# Want to help? PRs adding snippets are welcome at:
# https://github.com/huggingface/huggingface.js
```
- Diffusers
How to use nvidia/Cosmos3-Super-Text2Image with Diffusers:
```bash
pip install -U diffusers transformers accelerate
```

```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("nvidia/Cosmos3-Super-Text2Image", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| """OpenAI-compatible text-to-image prompt upsampling client.""" | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import time | |
| from collections.abc import Callable | |
| from dataclasses import dataclass | |
| from typing import Any | |
| import requests | |
| from requests.adapters import HTTPAdapter | |
| from urllib3.util.retry import Retry | |
| from agentic_upsampling.constants import ( | |
| DEFAULT_LLM_EXTRA_BODY, | |
| DEFAULT_UPSAMPLER_ENDPOINT_URL, | |
| DEFAULT_UPSAMPLER_MODEL, | |
| ) | |
| from agentic_upsampling.data import validate_t2i_json | |
# Read once at import time from the JSON_ENSURE_ASCII environment variable
# (default "1"); any non-zero integer enables it, and a non-integer value raises
# ValueError on import. Not used in this part of the module; presumably passed
# as ``ensure_ascii=`` to json.dumps elsewhere -- confirm against callers.
JSON_ENSURE_ASCII = int(os.environ.get("JSON_ENSURE_ASCII", "1")) != 0

# Sent as the User-Agent header on every HTTP request made by the chat client.
DEFAULT_USER_AGENT = "Cosmos3-Super-Text2Image-Agentic-Upsampling/1.0"

# Generic system turn placed first in every upsampling conversation. It uses the
# list-of-text-parts content form, the same shape as the user message.
SYSTEM_MESSAGE: dict[str, Any] = dict(
    role="system",
    content=[dict(type="text", text="You are a helpful assistant.")],
)

log = logging.getLogger(__name__)
# Output canvas sizes, keyed first by resolution tier (the ``resolution`` argument
# of apply_t2i_output_parameters) and then by a "<width>,<height>" aspect-ratio
# label. The pixel sizes only approximate the labelled ratios (e.g. "4,3" at the
# "256" tier is 320x256). Every portrait entry is the transpose of its landscape
# counterpart, so each tier is generated from three base sizes:
# (square side, 4:3 (W, H), 16:9 (W, H)).
RESOLUTION_RATIO_DICT: dict[str, dict[str, dict[str, int]]] = {
    tier: {
        "1,1": {"W": side, "H": side},
        "4,3": {"W": std_w, "H": std_h},
        "3,4": {"W": std_h, "H": std_w},
        "16,9": {"W": wide_w, "H": wide_h},
        "9,16": {"W": wide_h, "H": wide_w},
    }
    for tier, (side, (std_w, std_h), (wide_w, wide_h)) in (
        ("256", (256, (320, 256), (320, 192))),
        ("480", (640, (736, 544), (832, 480))),
        ("720", (960, (1104, 832), (1280, 720))),
        ("768", (1024, (1184, 880), (1360, 768))),
    )
}
# User-message prompt for the structured T2I upsampling request.
# build_t2i_messages renders it with ``str.format(caption_dense=...)``, so
# ``{caption_dense}`` is the only substitution field. Every literal brace in the
# instructions and the JSON skeleton is doubled ("{{" / "}}") and must stay
# that way. The "resolution" and "aspect_ratio" placeholder values are
# overwritten after generation by apply_t2i_output_parameters.
# The exact text is model-facing runtime content; edit it deliberately.
T2I_JSON_TEMPLATE = """Given the user's natural-language request below, generate a dense structured JSON that fully describes the image to be produced. The JSON must strictly follow the template provided after the request, including every top-level key and every nested sub-field.
The output is always dense. Even when the request is brief, infer plausible, scene-consistent details for every field. Do not leave fields empty merely because the request did not mention them. Be creative but stay grounded: additions must be physically plausible and internally consistent with the request.
Requirements:
- Extract visual intent from the user request into the visual fields.
- For every visual field, write rich, specific content inferred from the request's scene, subjects, mood, and context.
- Empty values ("", 0, [], {{}}) are permitted only for truly inapplicable fields.
- Do not add keys beyond the template. Do not omit keys required by the template.
- Return only the JSON object. Do not include markdown fences or prose outside JSON.
USER VISUAL REQUEST:
{caption_dense}
Lists may contain zero or more items of the shape shown. All top-level keys must always be present in the output; fill unused fields with "", 0, {{}}, or [] as appropriate.
{{
"subjects": [
{{
"description": "full visual description of the subject",
"appearance_details": "additional visual details such as accessories, texture, and distinguishing features",
"relationship": "how this subject relates to others or to the scene",
"location": "where in frame, for example center foreground or top right",
"relative_size": "size within frame",
"orientation": "direction subject faces relative to camera",
"pose": "body position and posture",
"clothing": "clothing and accessories; empty string if non-human or not applicable",
"expression": "facial expression; empty string if non-human or not applicable",
"gender": "Male, Female, Unknown, or N/A",
"age": "age category",
"skin_tone_and_texture": "skin tone description; empty string if non-human",
"facial_features": "notable facial features; empty string if non-human or not visible",
"number_of_subjects": "int; total in this subject group, 0 if not applicable",
"number_of_arms": "int; 2 for humans, 0 if non-human",
"number_of_legs": "int; 2 for humans, 0 if non-human",
"number_of_hands": "int; 2 for humans, 0 if non-human",
"number_of_fingers": "int; 10 for humans, 0 if non-human"
}}
],
"subject_details": {{
"key_name_1": "free-form image-specific attribute; empty object if not applicable"
}},
"background_setting": "full prose description of the environment and setting",
"lighting": {{
"conditions": "type and quality of light",
"direction": "where light comes from; None for flat digital images",
"shadows": "shadow description; None for flat digital images",
"illumination_effect": "overall effect of the lighting"
}},
"aesthetics": {{
"composition": "framing and compositional choices",
"color_scheme": "dominant colors and palette",
"mood_atmosphere": "emotional atmosphere in short phrases",
"patterns": "notable repeating visual patterns; None if none"
}},
"cinematography": {{
"framing": "shot type",
"camera_angle": "angle such as Eye-level, Low angle, or High angle",
"depth_of_field": "Shallow, Deep, Uniform focus, or N/A",
"focus": "what is in sharp focus",
"lens_focal_length": "descriptive focal length"
}},
"style_medium": "visual medium, for example Photography, Digital illustration, or Screenshot",
"artistic_style": "genre or approach",
"context": "scene context or use case",
"text_and_signage_elements": [
{{
"text": "the visible text content",
"category": "physical_in_scene, ui_text, body_text, scene_sign, logo, or label",
"appearance": "font, color, size, style",
"spatial": "position in image",
"context": "purpose or meaning of the text"
}}
],
"quadrant_scan": {{
"top_left": "description of what appears in the top-left region",
"top_right": "description of what appears in the top-right region",
"bottom_left": "description of what appears in the bottom-left region",
"bottom_right": "description of what appears in the bottom-right region",
"absolute_center": "description of what appears at the center"
}},
"comprehensive_t2i_caption": "a comprehensive, full-scene natural-language prose description of the image",
"resolution": {{
"H": "will be overwritten by the selected resolution and aspect ratio",
"W": "will be overwritten by the selected resolution and aspect ratio"
}},
"aspect_ratio": "will be overwritten by the selected aspect ratio"
}}"""
@dataclass
class ChatClientConfig:
    """Configuration for an OpenAI-compatible chat-completions endpoint.

    Declared as a dataclass so it can be built with keyword arguments, e.g.
    ``ChatClientConfig(endpoint_url=..., model=..., api_token=...)``. Without
    the decorator the class has no field-accepting ``__init__``, and such a
    call raises ``TypeError``.
    """

    # Endpoint root or full ``.../chat/completions`` URL; the client normalizes it.
    endpoint_url: str
    # Model name sent as the "model" field of each request payload.
    model: str
    # Bearer token; an empty string sends no Authorization header.
    api_token: str
    # Per-request timeout in seconds, passed to ``requests``.
    timeout_s: float = 300.0
    # Completion-token limit ("max_tokens" or "max_completion_tokens" in the payload).
    max_tokens: int = 8192
    # Total attempts per ``complete()`` call; must be >= 1.
    max_retries: int = 3
    # Base delay for exponential backoff between attempts (base * 2**attempt).
    retry_base_delay_s: float = 1.0
    # Extra payload fields merged last, so they override the generated ones.
    extra_body: dict[str, Any] | None = None
    # urllib3 ``total``/``connect`` retry budget for the pooled HTTP session.
    connection_max_retries: int = 2
    # Connection-pool size (pool_connections and pool_maxsize) of the HTTP adapter.
    connection_pool_size: int = 4
class OpenAIChatClient:
    """Small synchronous OpenAI-compatible chat-completions client.

    Requests are POSTed to ``{base_url}/chat/completions``, where ``base_url``
    is ``config.endpoint_url`` normalized by ``normalize_openai_base_url``.
    Each ``complete`` call makes up to ``config.max_retries`` attempts with
    exponential backoff between them. A session built by ``_make_session``
    also retries at the transport level.

    The client can be used as a context manager. On exit it closes the HTTP
    session only if it created that session itself.
    """

    config: ChatClientConfig
    base_url: str
    session: requests.Session
    sleep: Callable[[float], None]
    # True when the session was created by this client, so close() may close it.
    _owns_session: bool

    def __init__(
        self,
        config: ChatClientConfig,
        *,
        session: requests.Session | None = None,
        sleep: Callable[[float], None] = time.sleep,
    ) -> None:
        """Create a client.

        Args:
            config: Endpoint, model, auth, timeout and retry settings.
            session: Optional pre-configured ``requests.Session``. If omitted,
                a pooled session with transport retries is created and owned
                by this client.
            sleep: Function used to wait between retry attempts (injectable
                for tests).

        Raises:
            ValueError: if ``config.endpoint_url`` is empty.
        """
        self.config = config
        self.base_url = normalize_openai_base_url(config.endpoint_url)
        # A caller-supplied session remains the caller's responsibility to close.
        self._owns_session = session is None
        self.session = _make_session(config) if session is None else session
        self.sleep = sleep

    def close(self) -> None:
        """Close the underlying HTTP session if this client created it."""
        if self._owns_session:
            self.session.close()

    def __enter__(self) -> OpenAIChatClient:
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    def complete(self, messages: list[dict[str, Any]], *, response_format_json: bool = False) -> str:
        """Request one chat completion and return assistant text.

        Args:
            messages: OpenAI-style chat messages, sent as-is.
            response_format_json: If true, request
                ``{"type": "json_object"}`` output. ``config.extra_body`` is
                merged afterwards and can override this and any other field.

        Returns:
            Text of the first choice's message content.

        Raises:
            ValueError: if ``config.max_retries`` < 1.
            RuntimeError: if every attempt fails; the last error is chained
                as ``__cause__``.
        """

        def _call() -> str:
            payload: dict[str, Any] = {
                "model": self.config.model,
                "messages": messages,
                self._max_tokens_key(): self.config.max_tokens,
            }
            if response_format_json:
                payload["response_format"] = {"type": "json_object"}
            if self.config.extra_body:
                payload.update(self.config.extra_body)
            parsed = self._request_json("POST", f"{self.base_url}/chat/completions", payload=payload)
            choices = parsed.get("choices")
            if not isinstance(choices, list) or not choices:
                raise ValueError("Chat completion response missing choices.")
            first_choice = choices[0]
            if not isinstance(first_choice, dict):
                raise ValueError("Chat completion choice must be an object.")
            message = first_choice.get("message")
            if not isinstance(message, dict):
                raise ValueError("Chat completion choice missing message.")
            return _message_content_to_text(message.get("content"))

        return self._with_retries("complete chat request", _call)

    def _request_json(self, method: str, url: str, *, payload: dict[str, Any] | None = None) -> dict[str, Any]:
        """Send one HTTP request and return the decoded JSON-object body.

        Raises:
            RuntimeError: on a transport failure, when ``response.ok`` is false
                (status >= 400), when the body is not valid JSON, or when the
                JSON is not an object.
        """
        headers = {"Accept": "application/json", "User-Agent": DEFAULT_USER_AGENT}
        if payload is not None:
            headers["Content-Type"] = "application/json"
        if self.config.api_token:
            headers["Authorization"] = f"Bearer {self.config.api_token}"
        try:
            response = self.session.request(method, url, json=payload, headers=headers, timeout=self.config.timeout_s)
        except requests.RequestException as exc:
            raise RuntimeError(f"Failed to reach {url}: {exc}") from exc
        if not response.ok:
            raise RuntimeError(f"HTTP {response.status_code} from {url}: {response.text[:1000]}")
        try:
            parsed = response.json()
        except ValueError as exc:
            # requests raises a ValueError subclass for non-JSON bodies (e.g. an
            # HTML page from a proxy). Report the URL and a body excerpt instead
            # of a bare decode error.
            raise RuntimeError(f"Invalid JSON in response from {url}: {response.text[:1000]}") from exc
        if not isinstance(parsed, dict):
            raise RuntimeError(f"Response from {url} must be a JSON object.")
        return parsed

    def _with_retries(self, operation: str, fn: Callable[[], str]) -> str:
        """Call ``fn`` up to ``config.max_retries`` times and return its first result.

        Any ``Exception`` from ``fn`` triggers another attempt after sleeping
        ``retry_base_delay_s * 2**attempt`` seconds. There is no sleep after
        the final attempt.

        Raises:
            ValueError: if ``config.max_retries`` < 1.
            RuntimeError: after the last failed attempt, chained to its error.
        """
        if self.config.max_retries < 1:
            raise ValueError("max_retries must be >= 1.")
        last_exc: Exception | None = None
        for attempt in range(self.config.max_retries):
            try:
                return fn()
            except Exception as exc:
                last_exc = exc
                if attempt == self.config.max_retries - 1:
                    break
                self.sleep(self.config.retry_base_delay_s * (2**attempt))
        raise RuntimeError(f"Failed to {operation} after {self.config.max_retries} attempts: {last_exc}") from last_exc

    def _max_tokens_key(self) -> str:
        """Return the payload key used for the completion-token limit.

        The hosted OpenAI API (``api.openai.com``) gets
        ``max_completion_tokens``. Every other OpenAI-compatible server gets
        the classic ``max_tokens``.
        """
        if "api.openai.com" in self.base_url:
            return "max_completion_tokens"
        return "max_tokens"
class Text2ImagePromptUpsampler:
    """Create structured Cosmos3 text-to-image JSON prompts from user text.

    Wraps an ``OpenAIChatClient``. The user prompt is embedded in the
    structured-JSON instruction and sent in JSON mode. The reply is then
    parsed, stamped with the requested canvas size, and validated.
    """

    chat_client: OpenAIChatClient

    def __init__(self, chat_client: OpenAIChatClient) -> None:
        self.chat_client = chat_client

    @classmethod
    def from_defaults(
        cls,
        *,
        api_token: str,
        endpoint_url: str = DEFAULT_UPSAMPLER_ENDPOINT_URL,
        model: str = DEFAULT_UPSAMPLER_MODEL,
        extra_body: dict[str, Any] | None = None,
    ) -> Text2ImagePromptUpsampler:
        """Build the default GPT-5.5 based T2I prompt upsampler.

        Must be a classmethod: it is called on the class, e.g.
        ``Text2ImagePromptUpsampler.from_defaults(api_token=...)``.

        Args:
            api_token: Bearer token for the chat endpoint.
            endpoint_url: Endpoint root or full chat-completions URL.
            model: Model name sent with each request.
            extra_body: Extra payload fields. ``None`` selects
                ``DEFAULT_LLM_EXTRA_BODY``; pass ``{}`` to send none.
        """
        return cls(
            OpenAIChatClient(
                ChatClientConfig(
                    endpoint_url=endpoint_url,
                    model=model,
                    api_token=api_token,
                    extra_body=DEFAULT_LLM_EXTRA_BODY if extra_body is None else extra_body,
                )
            )
        )

    def upsample(
        self,
        prompt: str,
        *,
        prompt_id: str,
        resolution: str,
        aspect_ratio: str,
        user_prompt: str | None = None,
    ) -> dict[str, Any]:
        """Return a validated structured T2I JSON prompt.

        Args:
            prompt: Natural-language image request.
            prompt_id: Forwarded to ``validate_t2i_json``; presumably used to
                identify the prompt in validation errors (confirm).
            resolution: Resolution-tier key of ``RESOLUTION_RATIO_DICT``
                (e.g. ``"720"``).
            aspect_ratio: Aspect-ratio label for that tier (e.g. ``"16,9"``).
            user_prompt: If truthy, sent verbatim instead of the templated
                request (``prompt`` is then unused).

        Raises:
            ValueError: for an unsupported ``resolution``/``aspect_ratio``
                (raised before any network call) or an unparseable model reply.
            RuntimeError: if the chat request fails after all retries.
        """
        # Validate the canvas parameters before the slow, possibly billed LLM
        # call, so a bad resolution/aspect_ratio fails fast. The throwaway dict
        # is the only thing mutated here.
        apply_t2i_output_parameters({}, resolution=resolution, aspect_ratio=aspect_ratio)
        messages = build_t2i_messages(prompt, user_prompt=user_prompt)
        raw = self.chat_client.complete(messages, response_format_json=True)
        data = apply_t2i_output_parameters(extract_json_object(raw), resolution=resolution, aspect_ratio=aspect_ratio)
        validate_t2i_json(data, prompt_id)
        return data
def build_t2i_messages(prompt: str, *, user_prompt: str | None = None) -> list[dict[str, Any]]:
    """Build chat messages for the initial structured prompt upsampling request.

    Args:
        prompt: Natural-language image request. Surrounding whitespace is
            stripped before it is substituted into ``T2I_JSON_TEMPLATE``.
        user_prompt: Optional complete user-message text. If truthy, it is
            sent verbatim and ``prompt`` is not used at all.

    Returns:
        A new two-message list: a private copy of ``SYSTEM_MESSAGE`` followed
        by the user message, both in list-of-text-parts content form.
    """
    import copy

    message_text = user_prompt or T2I_JSON_TEMPLATE.format(caption_dense=prompt.strip())
    return [
        # Deep copy, so a caller that edits the returned message dicts in place
        # cannot corrupt the module-level SYSTEM_MESSAGE used by every request.
        copy.deepcopy(SYSTEM_MESSAGE),
        {
            "role": "user",
            "content": [{"type": "text", "text": message_text}],
        },
    ]
def apply_t2i_output_parameters(data: dict[str, Any], *, resolution: str, aspect_ratio: str) -> dict[str, Any]:
    """Stamp the selected canvas size onto a structured T2I prompt.

    ``data`` is modified in place and also returned. Its "resolution" entry is
    replaced with a fresh ``{"H": ..., "W": ...}`` dict taken from
    ``RESOLUTION_RATIO_DICT``, and "aspect_ratio" is set to ``aspect_ratio``.

    Raises:
        ValueError: if ``resolution`` is not a tier in ``RESOLUTION_RATIO_DICT``
            or ``aspect_ratio`` is not defined for that tier.
    """
    ratios_for_tier = RESOLUTION_RATIO_DICT.get(resolution)
    if ratios_for_tier is None:
        raise ValueError(f"Unsupported resolution {resolution!r}.")
    canvas = ratios_for_tier.get(aspect_ratio)
    if canvas is None:
        raise ValueError(f"Unsupported aspect_ratio {aspect_ratio!r} for resolution {resolution!r}.")
    data["resolution"] = {"H": canvas["H"], "W": canvas["W"]}
    data["aspect_ratio"] = aspect_ratio
    return data
def extract_json_object(text: str) -> dict[str, Any]:
    """Extract a JSON object from raw model text.

    Strategy:

    1. Parse the whole stripped text as JSON, and return it if it is an object.
    2. Otherwise, unwrap the first markdown code fence if there is one, then
       decode the JSON value that starts at the first ``{``. Any text after
       that object is ignored.

    Args:
        text: Assistant message text.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: if no ``{ ... }`` span is present, the located text is not
            valid JSON (``json.JSONDecodeError`` subclasses ``ValueError``), or
            it does not decode to an object.
    """
    cleaned = text.strip()

    # Fast path: in JSON mode the reply is normally a bare JSON document.
    # Parsing it first keeps ``` sequences inside string values (e.g.
    # transcribed code in a screenshot) from being mistaken for a markdown fence.
    try:
        direct = json.loads(cleaned)
    except ValueError:
        direct = None
    if isinstance(direct, dict):
        return direct

    fence_match = re.search(r"```(?:json)?\s*(.*?)\s*```", cleaned, flags=re.DOTALL)
    if fence_match:
        cleaned = fence_match.group(1).strip()
    start = cleaned.find("{")
    end = cleaned.rfind("}")
    if start < 0 or end < start:
        raise ValueError("Model response did not contain a JSON object.")
    # raw_decode stops at the end of the first complete JSON value. Trailing
    # prose, even prose containing "}", therefore does not break parsing.
    parsed, _ = json.JSONDecoder().raw_decode(cleaned, start)
    if not isinstance(parsed, dict):
        raise ValueError("Model response JSON must be an object.")
    return parsed
def normalize_openai_base_url(url: str) -> str:
    """Normalize an OpenAI-compatible endpoint root.

    Accepts a bare host, an API root, or a full ``.../chat/completions`` URL,
    and returns the root that ``/chat/completions`` is appended to:

    * surrounding whitespace and trailing slashes are removed;
    * ``https://`` is prepended when no http/https scheme is present (the
      check is case-insensitive, because URI schemes are case-insensitive);
    * a trailing ``/chat/completions`` is dropped, along with any slash it
      leaves behind;
    * ``/v1`` is appended unless the path already ends in ``/v1`` or
      ``/openai``, which are returned unchanged.

    Raises:
        ValueError: if ``url`` is empty or whitespace-only.
    """
    normalized = url.strip().rstrip("/")
    if not normalized:
        raise ValueError("endpoint_url cannot be empty.")
    # "HTTPS://host" already has a scheme; a case-sensitive check would turn it
    # into "https://HTTPS://host".
    if not normalized.lower().startswith(("http://", "https://")):
        normalized = f"https://{normalized}"
    suffix = "/chat/completions"
    if normalized.endswith(suffix):
        # rstrip so ".../v1//chat/completions" yields ".../v1", not ".../v1/".
        normalized = normalized[: -len(suffix)].rstrip("/")
    if normalized.endswith("/v1") or normalized.endswith("/openai"):
        return normalized
    return f"{normalized}/v1"
def _make_session(config: ChatClientConfig) -> requests.Session:
    """Build a pooled ``requests.Session`` with transport-level retries.

    The retry policy is installed on an adapter mounted for both http and https:

    * ``config.connection_max_retries`` is both the overall ``total`` budget
      and the connection-error budget. Read errors are never retried.
    * Responses with status 429/500/502/503/504 are retried up to twice
      (within ``total``) for GET and POST, with urllib3 backoff factor 0.5.
      POST is included, so a retried 5xx can resubmit a chat request.
    * ``raise_on_status=False`` returns the final error response to the
      caller instead of raising, so callers see a normal non-ok response.

    The adapter's pool size comes from ``config.connection_pool_size``.
    """
    attempts = config.connection_max_retries
    retry_policy = Retry(
        total=attempts,
        connect=attempts,
        read=0,
        status=2,
        status_forcelist=(429, 500, 502, 503, 504),
        allowed_methods=frozenset({"GET", "POST"}),
        backoff_factor=0.5,
        raise_on_status=False,
    )
    pool_size = config.connection_pool_size
    adapter = HTTPAdapter(
        pool_connections=pool_size,
        pool_maxsize=pool_size,
        max_retries=retry_policy,
    )
    session = requests.Session()
    for scheme in ("https://", "http://"):
        session.mount(scheme, adapter)
    return session
| def _message_content_to_text(content: Any) -> str: | |
| if isinstance(content, str) and content.strip(): | |
| return content | |
| if isinstance(content, list): | |
| parts: list[str] = [] | |
| for item in content: | |
| if isinstance(item, dict) and item.get("type") == "text" and isinstance(item.get("text"), str): | |
| parts.append(item["text"]) | |
| text = "".join(parts).strip() | |
| if text: | |
| return text | |
| raise ValueError("Chat completion message content is empty or unsupported.") | |