Instructions to use iceDonkey/Cosmos3-Super-Text2Image with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Cosmos
How to use iceDonkey/Cosmos3-Super-Text2Image with Cosmos:
# No code snippets available yet for this library.
# To use this model, check the repository files and the library's documentation.
# Want to help? PRs adding snippets are welcome at:
# https://github.com/huggingface/huggingface.js
- Diffusers
How to use iceDonkey/Cosmos3-Super-Text2Image with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("iceDonkey/Cosmos3-Super-Text2Image", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| """Network clients for standalone agentic text-to-image upsampling.""" | |
| from __future__ import annotations | |
| import base64 | |
| import io | |
| import json | |
| import os | |
| import time | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any | |
| import requests | |
| from PIL import Image | |
| from requests.adapters import HTTPAdapter | |
| from urllib3.util.retry import Retry | |
| from agentic_upsampling.constants import ( | |
| DEFAULT_ASPECT_RATIO, | |
| DEFAULT_CRITIC_ENDPOINT_URL, | |
| DEFAULT_CRITIC_MODEL, | |
| DEFAULT_GENERATION_AUTH_KEY_ENV, | |
| DEFAULT_GENERATION_EXTRA_ARGS, | |
| DEFAULT_GENERATION_MODEL, | |
| DEFAULT_FLOW_SHIFT, | |
| DEFAULT_GUIDANCE, | |
| DEFAULT_IMAGE_SIZE, | |
| DEFAULT_JPEG_QUALITY, | |
| DEFAULT_LLM_EXTRA_BODY, | |
| DEFAULT_NUM_STEPS, | |
| DEFAULT_OPENAI_API_KEY_ENV, | |
| DEFAULT_RESOLUTION, | |
| DEFAULT_REWRITER_ENDPOINT_URL, | |
| DEFAULT_REWRITER_MODEL, | |
| DEFAULT_UPSAMPLER_ENDPOINT_URL, | |
| DEFAULT_UPSAMPLER_MODEL, | |
| ) | |
| from agentic_upsampling.data import PromptItem, validate_t2i_json | |
| from agentic_upsampling.io_utils import compact_json, write_json_atomic | |
| from agentic_upsampling.prompt_upsampler import ( | |
| JSON_ENSURE_ASCII, | |
| SYSTEM_MESSAGE, | |
| ChatClientConfig, | |
| OpenAIChatClient, | |
| Text2ImagePromptUpsampler, | |
| extract_json_object, | |
| ) | |
| from agentic_upsampling.rubric import ( | |
| all_category_check_text, | |
| analysis_json_text, | |
| build_judge_prompt, | |
| compact_analysis_for_rewrite, | |
| parse_analysis_response, | |
| ) | |
# HTTP timeouts in seconds. CONNECT_TIMEOUT_S and IMAGE_GENERATION_READ_TIMEOUT_S
# are passed to requests as the (connect, read) timeout pair for image generation.
CONNECT_TIMEOUT_S: int = 60
# NOTE(review): not referenced in the visible part of this module — presumably
# consumed by another module; confirm before removing.
SUBMIT_READ_TIMEOUT_S: int = 240
IMAGE_GENERATION_READ_TIMEOUT_S: int = 600

# Rubric checklist text, computed once at import time and embedded in every
# joint-rewrite user prompt.
REWRITER_APPLICATION_GUIDANCE = all_category_check_text()
@dataclass
class GenerationOutput:
    """Output from one image generation request.

    The ``@dataclass`` decorator is required: ``ImageGenerationClient.generate``
    constructs this type with keyword arguments, which a plain class carrying
    only annotations would reject with ``TypeError``.

    Attributes:
        image_path: Path of the persisted image file.
        meta_path: Path of the JSON metadata file written next to the image.
        meta: The metadata dictionary that was written to ``meta_path``.
    """

    image_path: Path
    meta_path: Path
    meta: dict[str, Any]
def read_api_token(api_key_env: str, api_key_file: Path | None = None) -> str:
    """Resolve an API token from an environment variable or explicit file.

    The environment variable wins when it holds a non-blank value; otherwise
    the key file is consulted, if one was given.

    Args:
        api_key_env: Name of the environment variable to read.
        api_key_file: Optional path to a UTF-8 text file containing the token.

    Returns:
        The token with surrounding whitespace removed.

    Raises:
        RuntimeError: If neither source yields a non-blank token.
    """
    token = os.environ.get(api_key_env, "").strip()
    if token:
        return token
    # is_file() rather than exists(): a directory path must fall through to the
    # "missing key" error below instead of failing inside read_text().
    if api_key_file is not None and api_key_file.is_file():
        token = api_key_file.read_text(encoding="utf-8").strip()
        if token:
            return token
    raise RuntimeError(f"Missing API key. Export {api_key_env} or pass the matching --*-api-key-file flag.")
def read_optional_generation_auth_key(auth_key: str, api_key_env: str = DEFAULT_GENERATION_AUTH_KEY_ENV) -> str:
    """Return the generation endpoint auth key, or an empty string if none is set.

    An explicitly supplied ``auth_key`` takes precedence; when it is blank the
    environment variable named by ``api_key_env`` is used instead. Both values
    are stripped of surrounding whitespace.
    """
    explicit_key = auth_key.strip()
    if explicit_key:
        return explicit_key
    return os.environ.get(api_key_env, "").strip()
def normalize_generation_endpoint(endpoint: str) -> str:
    """Normalize the vLLM-Omni endpoint root without the /v1 suffix.

    Surrounding whitespace and trailing slashes are removed, ``https://`` is
    prepended when no http(s) scheme is present, and one trailing
    ``/v1/images/generations`` or ``/v1`` path suffix is dropped.

    Args:
        endpoint: Endpoint as supplied by the user, with or without scheme
            and API path suffix.

    Returns:
        The endpoint root, without a trailing slash.

    Raises:
        ValueError: If the endpoint is blank or nothing remains after the
            scheme once the API suffix is removed.
    """
    normalized = endpoint.strip().rstrip("/")
    if not normalized:
        raise ValueError("generation endpoint cannot be empty.")
    # URL schemes are case-insensitive; compare in lowercase so "HTTP://host"
    # is not given a second scheme.
    if not normalized.lower().startswith(("http://", "https://")):
        normalized = f"https://{normalized}"
    # Longest suffix first; only one suffix is ever removed.
    for suffix in ("/v1/images/generations", "/v1"):
        if normalized.endswith(suffix):
            normalized = normalized[: -len(suffix)]
            break
    normalized = normalized.rstrip("/")
    # Inputs such as "/v1" collapse to a bare scheme; reject them here rather
    # than handing a host-less URL to the HTTP client.
    if not normalized.partition("://")[2]:
        raise ValueError("generation endpoint cannot be empty.")
    return normalized
def make_session(pool_size: int = 4) -> requests.Session:
    """Build a ``requests`` session whose HTTP and HTTPS traffic retries automatically.

    The retry policy allows at most two retries overall, covering connection
    errors and responses with status 429/500/502/503/504 for GET and POST.
    Read errors are not retried (``read=0``). With ``raise_on_status=False``
    the final response is returned to the caller once status retries are
    exhausted, rather than an exception being raised by urllib3.

    Args:
        pool_size: Value used for both the number of connection pools and the
            maximum connections per pool.
    """
    retry_policy = Retry(
        total=2,
        connect=2,
        read=0,
        status=2,
        status_forcelist=(429, 500, 502, 503, 504),
        allowed_methods=frozenset({"GET", "POST"}),
        backoff_factor=0.5,
        raise_on_status=False,
    )
    pooled_adapter = HTTPAdapter(
        pool_connections=pool_size,
        pool_maxsize=pool_size,
        max_retries=retry_policy,
        pool_block=False,
    )
    http = requests.Session()
    # One adapter instance serves both schemes.
    for scheme in ("https://", "http://"):
        http.mount(scheme, pooled_adapter)
    return http
def image_path_to_data_url(path: Path, *, jpeg_quality: int | None = DEFAULT_JPEG_QUALITY) -> str:
    """Encode a local image file as a data URL, optionally transcoding to JPEG.

    Args:
        path: Image file to encode.
        jpeg_quality: JPEG quality used for transcoding. ``None`` skips
            transcoding and embeds the file bytes unchanged.

    Returns:
        A ``data:<mime>;base64,...`` URL. When transcoding, the MIME type is
        ``image/jpeg``. Otherwise it is guessed from the file name and falls
        back to ``image/png`` when the extension is unknown or not an image
        type.
    """
    if jpeg_quality is None:
        # Local import keeps this block self-contained.
        import mimetypes

        # Label the raw bytes with their real type; a hard-coded "image/png"
        # mislabels the JPEG files this module itself writes.
        guessed_type, _ = mimetypes.guess_type(path.name)
        mime_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/png"
        encoded = base64.b64encode(path.read_bytes()).decode("ascii")
        return f"data:{mime_type};base64,{encoded}"
    with Image.open(path) as image:
        # Modes other than RGB and L (for example RGBA or P) are converted
        # to RGB before JPEG encoding.
        converted = image if image.mode in ("RGB", "L") else image.convert("RGB")
        buf = io.BytesIO()
        # Encode while the source file is still open, since an unconverted
        # image may not have been read into memory yet.
        converted.save(buf, format="JPEG", quality=jpeg_quality, optimize=True)
    encoded = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{encoded}"
class PromptRewriterClient:
    """GPT-based T2I JSON prompt upsampler and iterative rewriter.

    Wraps two chat clients: an upsampler that produces the initial structured
    prompt, and a rewriter that jointly revises the positive JSON prompt and
    the generator-side negative prompt using the previous critic analysis.
    """

    upsampler: Text2ImagePromptUpsampler
    rewrite_client: OpenAIChatClient
    resolution: str
    aspect_ratio: str

    def __init__(
        self,
        *,
        api_token: str,
        upsampler_endpoint_url: str = DEFAULT_UPSAMPLER_ENDPOINT_URL,
        upsampler_model: str = DEFAULT_UPSAMPLER_MODEL,
        rewriter_endpoint_url: str = DEFAULT_REWRITER_ENDPOINT_URL,
        rewriter_model: str = DEFAULT_REWRITER_MODEL,
        extra_body: dict[str, Any] | None = None,
        resolution: str = DEFAULT_RESOLUTION,
        aspect_ratio: str = DEFAULT_ASPECT_RATIO,
    ) -> None:
        """Create the upsampler and rewriter clients.

        Args:
            api_token: Token shared by both chat clients.
            upsampler_endpoint_url: Endpoint for the initial upsampling call.
            upsampler_model: Model name for the upsampler.
            rewriter_endpoint_url: Endpoint for joint rewrite calls.
            rewriter_model: Model name for the rewriter.
            extra_body: Extra request body passed to both clients;
                ``DEFAULT_LLM_EXTRA_BODY`` is used when ``None``.
            resolution: Resolution passed to the upsampler.
            aspect_ratio: Aspect ratio passed to the upsampler.
        """
        resolved_extra_body = DEFAULT_LLM_EXTRA_BODY if extra_body is None else extra_body
        self.upsampler = Text2ImagePromptUpsampler.from_defaults(
            api_token=api_token,
            endpoint_url=upsampler_endpoint_url,
            model=upsampler_model,
            extra_body=resolved_extra_body,
        )
        self.rewrite_client = OpenAIChatClient(
            ChatClientConfig(
                endpoint_url=rewriter_endpoint_url,
                model=rewriter_model,
                api_token=api_token,
                extra_body=resolved_extra_body,
                max_tokens=8192,
                max_retries=3,
            )
        )
        self.resolution = resolution
        self.aspect_ratio = aspect_ratio

    def initial_prompt(self, item: PromptItem) -> dict[str, Any]:
        """Create the initial dense structured prompt for a user prompt."""
        return self.upsampler.upsample(
            item.prompt,
            prompt_id=item.prompt_id,
            resolution=self.resolution,
            aspect_ratio=self.aspect_ratio,
        )

    def rewrite_prompt_pair(
        self,
        item: PromptItem,
        previous_prompt: dict[str, Any],
        previous_negative_prompt: str,
        previous_analysis: dict[str, Any],
        history: list[dict[str, Any]],
    ) -> tuple[dict[str, Any], str]:
        """Jointly rewrite the positive JSON prompt and generator-side negative prompt.

        Args:
            item: The original user prompt item.
            previous_prompt: Positive JSON prompt from the previous iteration;
                its top-level keys define the schema requested from the model.
            previous_negative_prompt: Negative prompt from the previous iteration.
            previous_analysis: Critic analysis of the previous image.
            history: Per-iteration records, summarized for the model.

        Returns:
            ``(positive_prompt, negative_prompt)`` where the negative prompt
            has its whitespace collapsed to single spaces.

        Raises:
            RuntimeError: If the request or response parsing fails on all
                three attempts; the last failure is chained as the cause.
        """
        schema_keys = list(previous_prompt.keys())
        messages = [
            {
                "role": "system",
                "content": (
                    "You are a precise text-to-image prompt engineer. Return valid JSON only, no markdown. "
                    "Jointly coordinate the positive structured prompt and generator-side negative prompt so they do not contradict each other."
                ),
            },
            {
                "role": "user",
                "content": self._joint_rewrite_user_prompt(
                    item=item,
                    previous_prompt=previous_prompt,
                    previous_negative_prompt=previous_negative_prompt,
                    previous_analysis=previous_analysis,
                    history=history,
                    schema_keys=schema_keys,
                ),
            },
        ]
        max_attempts = 3
        last_exc: Exception | None = None
        for attempt in range(1, max_attempts + 1):
            try:
                raw = self.rewrite_client.complete(messages, response_format_json=True)
                return self._parse_joint_rewrite_response(raw, item.prompt_id)
            except Exception as exc:
                # Deliberately broad: transport errors and malformed model
                # output are both retried; the last error is re-raised below.
                last_exc = exc
                if attempt < max_attempts:
                    time.sleep(min(20.0, 2.0 * attempt))
        raise RuntimeError(
            f"Joint prompt rewrite failed after {max_attempts} attempts for prompt {item.prompt_id}."
        ) from last_exc

    # The three helpers below take no ``self`` and are called both through an
    # instance and through the class, so they must be static methods.
    @staticmethod
    def _parse_joint_rewrite_response(raw: str, prompt_id: str) -> tuple[dict[str, Any], str]:
        """Extract and validate the positive/negative prompt pair from a raw reply.

        Raises:
            ValueError: If ``positive_prompt`` is missing or not an object, or
                ``negative_prompt`` is present but not a string. Errors raised
                by ``validate_t2i_json`` propagate unchanged.
        """
        data = extract_json_object(raw)
        positive_prompt = data.get("positive_prompt")
        if not isinstance(positive_prompt, dict):
            raise ValueError(f"Joint rewrite returned missing or non-object positive_prompt for prompt {prompt_id}.")
        validate_t2i_json(positive_prompt, prompt_id)
        negative_prompt = data.get("negative_prompt", "")
        if not isinstance(negative_prompt, str):
            raise ValueError(f"Joint rewrite returned non-string negative_prompt for prompt {prompt_id}.")
        # Collapse any run of whitespace (including newlines) to one space.
        return positive_prompt, " ".join(negative_prompt.split())

    @staticmethod
    def _joint_rewrite_user_prompt(
        *,
        item: PromptItem,
        previous_prompt: dict[str, Any],
        previous_negative_prompt: str,
        previous_analysis: dict[str, Any],
        history: list[dict[str, Any]],
        schema_keys: list[str],
    ) -> str:
        """Assemble the newline-joined user message for a joint rewrite request."""
        sections = [
            "Original user prompt:",
            item.prompt,
            "",
            "Application-specific guidance:",
            "Apply the following sections as one checklist program. Do not first classify the prompt. Apply each section only when relevant to the original user prompt, previous JSON, or VLM failures.",
            REWRITER_APPLICATION_GUIDANCE,
            "",
            "Previous generated image failed or scored according to this VLM analysis:",
            analysis_json_text(compact_analysis_for_rewrite(previous_analysis)),
            "",
            "Iteration history summary:",
            json.dumps(PromptRewriterClient._history_summary(history), ensure_ascii=JSON_ENSURE_ASCII, indent=2),
            "",
            "Previous positive JSON prompt:",
            json.dumps(previous_prompt, ensure_ascii=JSON_ENSURE_ASCII, indent=2),
            "",
            "Previous negative prompt:",
            previous_negative_prompt or "",
            "",
            "Joint rewrite task:",
            'Return a JSON object with exactly two top-level keys: "positive_prompt" and "negative_prompt".',
            '"positive_prompt" must be a complete JSON object with exactly these top-level keys, preserving their names and types:',
            json.dumps(schema_keys, ensure_ascii=JSON_ENSURE_ASCII),
            "",
            '"positive_prompt" must keep the previous "resolution" and "aspect_ratio".',
            '"negative_prompt" must be a concise generator-side negative prompt string.',
            "Coordinate both fields: strengthen required positive constraints while using the negative prompt only to suppress concrete wrong alternatives or artifacts.",
            "Do not put positive instructions in negative_prompt. Do not negate content required by the original user prompt.",
            "For exact counts, grids, text, geometry, or anatomy, explicitly block wrong alternatives when useful.",
            'The positive "comprehensive_t2i_caption" should be direct generation guidance, not an explanation of this rewrite process.',
        ]
        return "\n".join(sections)

    @staticmethod
    def _history_summary(history: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Reduce each history record to its iteration number and headline scores.

        Records whose ``analysis`` entry is missing or not a dict yield
        ``None`` for every score instead of raising.
        """
        summary: list[dict[str, Any]] = []
        for entry in history:
            analysis = entry.get("analysis")
            if not isinstance(analysis, dict):
                analysis = {}
            summary.append(
                {
                    "iteration": entry.get("iteration"),
                    "overall_score": analysis.get("overall_score"),
                    "prompt_adherence_score": analysis.get("prompt_adherence_score"),
                    "category_score": analysis.get("category_score"),
                    "threshold_cleared": analysis.get("threshold_cleared"),
                }
            )
        return summary
class ImageGenerationClient:
    """Client for a vLLM-Omni /v1/images/generations text-to-image endpoint."""

    endpoint: str
    auth_key: str
    model: str
    session: requests.Session
    size: str
    num_steps: int
    guidance: float
    flow_shift: float
    extra_args: dict[str, Any]

    def __init__(
        self,
        *,
        endpoint: str,
        auth_key: str = "",
        model: str = DEFAULT_GENERATION_MODEL,
        size: str = DEFAULT_IMAGE_SIZE,
        num_steps: int = DEFAULT_NUM_STEPS,
        guidance: float = DEFAULT_GUIDANCE,
        flow_shift: float = DEFAULT_FLOW_SHIFT,
        extra_args: dict[str, Any] | None = None,
        session: requests.Session | None = None,
    ) -> None:
        """Store generation settings and prepare the HTTP session.

        Args:
            endpoint: Endpoint root; normalized with
                ``normalize_generation_endpoint``.
            auth_key: Optional token, with or without a ``Bearer `` prefix.
            model: Model name sent in every payload.
            size: Image size string sent in every payload.
            num_steps: Sent as ``num_inference_steps``.
            guidance: Sent as ``guidance_scale``.
            flow_shift: Sent as ``flow_shift``.
            extra_args: Sent as ``extra_args``;
                ``DEFAULT_GENERATION_EXTRA_ARGS`` is used when ``None``.
                The mapping is copied.
            session: Session to reuse; a retrying session is created when
                omitted.
        """
        self.endpoint = normalize_generation_endpoint(endpoint)
        self.auth_key = auth_key
        self.model = model
        self.session = session or make_session()
        self.size = size
        self.num_steps = num_steps
        self.guidance = guidance
        self.flow_shift = flow_shift
        self.extra_args = dict(DEFAULT_GENERATION_EXTRA_ARGS if extra_args is None else extra_args)

    def build_payload(
        self,
        prompt_json: dict[str, Any],
        prompt_id: str,
        seed: int | None = None,
        negative_prompt: str = "",
    ) -> dict[str, Any]:
        """Build the vLLM-Omni image generation request payload.

        ``prompt_id`` is accepted for interface compatibility but is not part
        of the payload. ``seed`` is included only when it is not ``None``.
        """
        del prompt_id
        payload: dict[str, Any] = {
            "model": self.model,
            "prompt": compact_json(prompt_json, ensure_ascii=JSON_ENSURE_ASCII),
            "size": self.size,
            "n": 1,
            "response_format": "b64_json",
            "negative_prompt": negative_prompt.strip(),
            "num_inference_steps": self.num_steps,
            "guidance_scale": self.guidance,
            "flow_shift": self.flow_shift,
            # Copy so callers mutating the payload cannot alter client state.
            "extra_args": dict(self.extra_args),
        }
        if seed is not None:
            payload["seed"] = int(seed)
        return payload

    def generate(
        self,
        *,
        prompt_json: dict[str, Any],
        prompt_id: str,
        output_dir: Path,
        seed: int | None = None,
        negative_prompt: str = "",
        jpeg_quality: int = DEFAULT_JPEG_QUALITY,
    ) -> GenerationOutput:
        """Generate and persist one candidate image.

        Writes ``image.jpg`` and ``generation_meta.json`` into ``output_dir``.
        The stored response has inline image data replaced by a short
        placeholder.

        Raises:
            RuntimeError: If the request fails after retries or the response
                carries no usable image.
        """
        payload = self.build_payload(prompt_json, prompt_id, seed, negative_prompt=negative_prompt)
        response_json = self._generate_image(payload)
        image_bytes = self._decode_image_response(response_json)
        image_path = output_dir / "image.jpg"
        image_info = self._save_jpeg(image_bytes, image_path, jpeg_quality)
        meta = {
            "prompt_id": prompt_id,
            "status": "completed",
            "endpoint": self.endpoint,
            "image_generation_url": self._image_generation_url(),
            "payload": payload,
            "response": self._response_without_image_bytes(response_json),
            "output_image_path": str(image_path),
            "image_info": image_info,
        }
        meta_path = output_dir / "generation_meta.json"
        write_json_atomic(meta_path, meta, ensure_ascii=JSON_ENSURE_ASCII)
        return GenerationOutput(image_path=image_path, meta_path=meta_path, meta=meta)

    def _generate_image(self, payload: dict[str, Any]) -> dict[str, Any]:
        """POST the payload, making up to three attempts with a short sleep between them."""
        last_exc: Exception | None = None
        for attempt in range(1, 4):
            try:
                return self._request_json(
                    "POST",
                    self._image_generation_url(),
                    json=payload,
                    headers=self._auth_headers(),
                    timeout=(CONNECT_TIMEOUT_S, IMAGE_GENERATION_READ_TIMEOUT_S),
                )
            except Exception as exc:
                # Deliberately broad: any failure (transport, HTTP status,
                # bad JSON) is retried, and the last one is re-raised below.
                last_exc = exc
                if attempt < 3:
                    time.sleep(min(20.0, 2.0 * attempt))
        raise RuntimeError(f"/v1/images/generations failed after retries: {last_exc}") from last_exc

    def _image_generation_url(self) -> str:
        """Return the full image generation URL for the normalized endpoint."""
        return f"{self.endpoint}/v1/images/generations"

    def _auth_headers(self) -> dict[str, str] | None:
        """Return an Authorization header, or ``None`` when no key is configured."""
        token = self.auth_key.strip()
        if not token:
            return None
        # Accept keys that already carry the scheme to avoid "Bearer Bearer ...".
        if token.lower().startswith("bearer "):
            return {"Authorization": token}
        return {"Authorization": f"Bearer {token}"}

    def _request_json(self, method: str, url: str, **kwargs: Any) -> dict[str, Any]:
        """Send a request and return its JSON object body.

        Raises:
            RuntimeError: On a non-OK status (message includes up to the first
                1000 characters of the body) or when the body is not a JSON
                object.
        """
        timeout = kwargs.pop("timeout", (CONNECT_TIMEOUT_S, IMAGE_GENERATION_READ_TIMEOUT_S))
        response = self.session.request(method, url, timeout=timeout, **kwargs)
        if not response.ok:
            raise RuntimeError(f"{method} {url} HTTP {response.status_code}: {response.text[:1000]}")
        parsed = response.json()
        if not isinstance(parsed, dict):
            raise RuntimeError(f"{method} {url} returned non-object JSON: {parsed!r}")
        return parsed

    # The helpers below take no ``self`` but are invoked as ``self._helper(...)``,
    # so they must be static methods.
    @staticmethod
    def _decode_image_response(response_json: dict[str, Any]) -> bytes:
        """Return the raw bytes of the first image in the response.

        Prefers ``data[0]["b64_json"]``; falls back to a ``data:image`` URL in
        ``data[0]["url"]``.

        Raises:
            RuntimeError: If ``data[0]`` is missing or holds neither form.
        """
        data = response_json.get("data")
        if not isinstance(data, list) or not data or not isinstance(data[0], dict):
            raise RuntimeError(f"Image generation response has no data[0] object: {response_json}")
        first_image = data[0]
        b64_image = first_image.get("b64_json")
        if not isinstance(b64_image, str) or not b64_image.strip():
            image_url = first_image.get("url")
            if isinstance(image_url, str) and image_url.startswith("data:image") and "," in image_url:
                b64_image = image_url.split(",", 1)[1]
            else:
                raise RuntimeError(f"Image generation response has no b64_json image: {response_json}")
        try:
            return base64.b64decode(b64_image, validate=True)
        except ValueError:
            # Strict decoding rejects stray characters such as newlines;
            # retry with the lenient decoder.
            return base64.b64decode(b64_image)

    @staticmethod
    def _response_without_image_bytes(response_json: dict[str, Any]) -> dict[str, Any]:
        """Return a deep copy of the response with inline image data replaced by placeholders."""
        # A JSON round trip gives an independent copy of the parsed response.
        redacted = json.loads(json.dumps(response_json))
        data = redacted.get("data")
        if isinstance(data, list):
            for item in data:
                if isinstance(item, dict) and isinstance(item.get("b64_json"), str):
                    item["b64_json"] = f"<base64 image omitted: {len(item['b64_json'])} chars>"
                if isinstance(item, dict) and isinstance(item.get("url"), str) and item["url"].startswith("data:image"):
                    item["url"] = f"<data image omitted: {len(item['url'])} chars>"
        return redacted

    @staticmethod
    def _save_jpeg(image_bytes: bytes, output_path: Path, quality: int) -> dict[str, Any]:
        """Re-encode image bytes as an RGB JPEG at ``output_path``.

        The image is written to a sibling ``.tmp`` file and then renamed over
        the target. The temporary file is removed if encoding or the rename
        fails.

        Returns:
            Source format, saved format, and pixel dimensions.
        """
        output_path.parent.mkdir(parents=True, exist_ok=True)
        tmp = output_path.with_suffix(output_path.suffix + ".tmp")
        try:
            with Image.open(io.BytesIO(image_bytes)) as image:
                source_format = image.format
                rgb = image.convert("RGB")
                width, height = rgb.size
                rgb.save(tmp, format="JPEG", quality=quality, optimize=True)
            tmp.replace(output_path)
        except BaseException:
            # Best-effort cleanup of a partial file; the original error is
            # what the caller needs to see.
            try:
                tmp.unlink()
            except OSError:
                pass
            raise
        return {"source_image_format": source_format, "saved_format": "JPEG", "width": width, "height": height}
class VLMQualityJudge:
    """Gemini critic that scores generated images via an OpenAI-compatible chat endpoint."""

    chat_client: OpenAIChatClient
    image_jpeg_quality: int | None

    def __init__(
        self,
        *,
        api_token: str,
        endpoint_url: str = DEFAULT_CRITIC_ENDPOINT_URL,
        model: str = DEFAULT_CRITIC_MODEL,
        max_tokens: int = 8192,
        image_jpeg_quality: int | None = DEFAULT_JPEG_QUALITY,
    ) -> None:
        """Configure the critic chat client and the JPEG quality used for image upload."""
        critic_config = ChatClientConfig(
            endpoint_url=endpoint_url,
            model=model,
            api_token=api_token,
            max_tokens=max_tokens,
            max_retries=3,
        )
        self.chat_client = OpenAIChatClient(critic_config)
        self.image_jpeg_quality = image_jpeg_quality

    def score_image(
        self,
        *,
        item: PromptItem,
        image_path: Path,
    ) -> dict[str, Any]:
        """Score one image with the non-classifying rubric program.

        The image is sent as a data URL ahead of the judge prompt text. The
        parsed analysis is returned with the model's unparsed reply stored
        under ``"raw_response"``.
        """
        data_url = image_path_to_data_url(image_path, jpeg_quality=self.image_jpeg_quality)
        image_part = {"type": "image_url", "image_url": {"url": data_url}}
        text_part = {"type": "text", "text": build_judge_prompt(item)}
        user_message = {"role": "user", "content": [image_part, text_part]}

        raw = self.chat_client.complete([SYSTEM_MESSAGE, user_message], response_format_json=True)
        analysis = parse_analysis_response(raw)
        analysis["raw_response"] = raw
        return analysis