Commit
fdafd05
·
0 Parent(s):

Super-squash branch 'main' using huggingface_hub

Browse files

Co-authored-by: CYChenv <CYChenv@users.noreply.huggingface.co>
Co-authored-by: tangyue0820 <tangyue0820@users.noreply.huggingface.co>
Co-authored-by: fferroni <fferroni@users.noreply.huggingface.co>
Co-authored-by: harrim-nv <harrim-nv@users.noreply.huggingface.co>
Co-authored-by: liang1225 <liang1225@users.noreply.huggingface.co>
Co-authored-by: shilinzhu-nvidia <shilinzhu-nvidia@users.noreply.huggingface.co>
Co-authored-by: mli0603 <mli0603@users.noreply.huggingface.co>
Co-authored-by: mbalaNV <mbalaNV@users.noreply.huggingface.co>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +39 -0
  2. .gitignore +2 -0
  3. AGENTIC_UPSAMPLING.md +174 -0
  4. BIAS.md +11 -0
  5. EXPLAINABILITY.md +16 -0
  6. PRIVACY.md +6 -0
  7. README.md +463 -0
  8. SAFETY.md +11 -0
  9. agentic_upsampling/__init__.py +6 -0
  10. agentic_upsampling/__main__.py +7 -0
  11. agentic_upsampling/clients.py +521 -0
  12. agentic_upsampling/constants.py +35 -0
  13. agentic_upsampling/data.py +167 -0
  14. agentic_upsampling/extract_best.py +155 -0
  15. agentic_upsampling/io_utils.py +46 -0
  16. agentic_upsampling/prompt_upsampler.py +388 -0
  17. agentic_upsampling/rubric.py +220 -0
  18. agentic_upsampling/run.py +187 -0
  19. agentic_upsampling/runner.py +474 -0
  20. assets/benchmark-text2image-leaderboard-all-models.jpg +3 -0
  21. assets/benchmark-text2image-leaderboard.png +3 -0
  22. assets/benchmark-text2image.png +3 -0
  23. assets/example_caption.json +88 -0
  24. assets/example_image.png +3 -0
  25. assets/more_images.jpg +3 -0
  26. assets/original_prompt.txt +1 -0
  27. chat_template.json +3 -0
  28. checkpoint.json +1 -0
  29. config.json +258 -0
  30. generation_config.json +14 -0
  31. merges.txt +0 -0
  32. model.safetensors.index.json +0 -0
  33. model_index.json +28 -0
  34. preprocessor_config.json +21 -0
  35. pytest.ini +4 -0
  36. scheduler/scheduler_config.json +33 -0
  37. sound_tokenizer/config.json +64 -0
  38. sound_tokenizer/diffusion_pytorch_model.safetensors +3 -0
  39. tests/test_agentic_upsampling.py +496 -0
  40. text_tokenizer/added_tokens.json +28 -0
  41. text_tokenizer/chat_template.jinja +120 -0
  42. text_tokenizer/merges.txt +0 -0
  43. text_tokenizer/special_tokens_map.json +31 -0
  44. text_tokenizer/tokenizer.json +3 -0
  45. text_tokenizer/tokenizer_config.json +239 -0
  46. text_tokenizer/vocab.json +0 -0
  47. tokenizer.json +0 -0
  48. tokenizer_config.json +239 -0
  49. transformer/config.json +54 -0
  50. transformer/diffusion_pytorch_model-00001-of-00027.safetensors +3 -0
.gitattributes ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ text_tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
38
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
39
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.pyc
2
+ outputs/
AGENTIC_UPSAMPLING.md ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Agentic Prompt Upsampling
2
+
3
+ This repository includes a standalone text-to-image agentic prompt upsampler for Cosmos3-Super-Text2Image.
4
+
5
+ The loop:
6
+
7
+ 1. Upsamples the user prompt into a structured Cosmos3 T2I JSON prompt.
8
+ 2. Generates an image through a vLLM-Omni `/v1/images/generations` endpoint.
9
+ 3. Scores the image with a VLM critic.
10
+ 4. Rewrites both the positive JSON prompt and generator-side negative prompt from the critic feedback.
11
+ 5. Repeats up to the configured iteration limit and returns the best scored image.
12
+
13
+ ## Install
14
+
15
+ From the repository root:
16
+
17
+ ```bash
18
+ python -m pip install requests pillow
19
+ ```
20
+
21
+ Recommended vLLM-Omni serving configuration for `nvidia/Cosmos3-Super-Text2Image` on 4xH200 is:
22
+
23
+ ```bash
24
+ vllm serve nvidia/Cosmos3-Super-Text2Image \
25
+ --omni \
26
+ --cfg-parallel-size 2 \
27
+ --ulysses-degree 2 \
28
+ --tensor-parallel-size 1
29
+ ```
30
+
31
+ With the no-offload configuration above, 1024x1024 image generation with 50 steps is expected to take roughly 5 seconds server-side per request.
32
+
33
+ ## Default Models
34
+
35
+ The default prompt upsampler and rewriter are OpenAI GPT-5.5 through the public OpenAI chat completions API:
36
+
37
+ ```text
38
+ endpoint: https://api.openai.com/v1
39
+ model: gpt-5.5
40
+ extra body: {"reasoning_effort": "low"}
41
+ env var: OPENAI_API_KEY
42
+ ```
43
+
44
+ The default critic is Gemini 3.1 Pro Preview through Google's OpenAI-compatible chat completions endpoint:
45
+
46
+ ```text
47
+ endpoint: https://generativelanguage.googleapis.com/v1beta/openai/
48
+ model: gemini-3.1-pro-preview
49
+ env var: GEMINI_API_KEY
50
+ ```
51
+
52
+ Set credentials:
53
+
54
+ ```bash
55
+ export OPENAI_API_KEY=...
56
+ export GEMINI_API_KEY=...
57
+ ```
58
+
59
+ If your vLLM-Omni generation endpoint requires auth:
60
+
61
+ ```bash
62
+ export AGENTIC_UPSAMPLING_GENERATION_AUTH_KEY=...
63
+ ```
64
+
65
+ ## Run One Prompt
66
+
67
+ ```bash
68
+ python -m agentic_upsampling.run \
69
+ --prompt "a cinematic photo of a glass greenhouse at sunrise" \
70
+ --output-dir outputs/agentic_greenhouse \
71
+ --generation-endpoint https://YOUR_VLLM_OMNI_ENDPOINT
72
+ ```
73
+
74
+ The generation call is a standard vLLM-Omni image request:
75
+
76
+ ```text
77
+ POST /v1/images/generations
78
+ model: nvidia/Cosmos3-Super-Text2Image
79
+ size: 1024x1024
80
+ response_format: b64_json
81
+ num_inference_steps: 50
82
+ guidance_scale: 4.0
83
+ flow_shift: 3.0
84
+ negative_prompt: ""
85
+ extra_args: {"guardrails": false, "use_resolution_template": false}
86
+ ```
87
+
88
+ ## Run A Batch
89
+
90
+ Text file, one prompt per non-empty line:
91
+
92
+ ```bash
93
+ python -m agentic_upsampling.run \
94
+ --prompts prompts.txt \
95
+ --output-dir outputs/agentic_batch \
96
+ --generation-endpoint https://YOUR_VLLM_OMNI_ENDPOINT
97
+ ```
98
+
99
+ JSONL rows can be strings or objects with `prompt` and optional `id`:
100
+
101
+ ```json
102
+ {"id": "greenhouse", "prompt": "a glass greenhouse at sunrise"}
103
+ {"id": "city", "prompt": "a clean futuristic city plaza after rain"}
104
+ ```
105
+
106
+ CSV files must include a `prompt` or `Prompt` column and may include an `id` column.
107
+
108
+ ## Useful Options
109
+
110
+ ```bash
111
+ python -m agentic_upsampling.run \
112
+ --prompt "a precise product photo of a transparent mechanical keyboard" \
113
+ --output-dir outputs/keyboard \
114
+ --generation-endpoint https://YOUR_VLLM_OMNI_ENDPOINT \
115
+ --max-iterations 2 \
116
+ --samples-per-iteration 3 \
117
+ --seed-base 42 \
118
+ --size 1024x1024 \
119
+ --guidance 4.0 \
120
+ --flow-shift 3.0
121
+ ```
122
+
123
+ - `--max-iterations` controls total prompt stages. The default is `2`, meaning the initial upsample plus up to two rewrites.
124
+ - `--samples-per-iteration` runs a best-of-N seed search for each prompt stage. Generation requests for those seeds are submitted concurrently within the iteration.
125
+ - `--seed-base` makes seeds deterministic. Sample seeds are `seed_base + sample_index`.
126
+ - `--size` is the vLLM-Omni image size in `WIDTHxHEIGHT` format.
127
+ - `--guidance` sets `guidance_scale`; the default is `4.0`.
128
+ - `--flow-shift` sets `flow_shift`; the default is `3.0`.
129
+ - `--generation-extra-args` overrides the default vLLM-Omni generation `extra_args` JSON object.
130
+ - Early stopping is enabled by default when the critic score clears the strict threshold. Use `--disable-early-stop` to always run every iteration.
131
+ - Reruns resume from completed artifacts by default. Use `--overwrite` to regenerate them.
132
+
133
+ ## Output Layout
134
+
135
+ ```text
136
+ output_dir/
137
+ run_config.json
138
+ summary.json
139
+ manifest.jsonl
140
+ failures.jsonl
141
+ 0001/
142
+ best.json
143
+ iter_00/
144
+ prompt.json
145
+ negative_prompt.json
146
+ image.jpg
147
+ generation_meta.json
148
+ analysis.json
149
+ samples.json
150
+ meta.json
151
+ iter_01/
152
+ ...
153
+ ```
154
+
155
+ For `--samples-per-iteration N`, each iteration contains `sample_00/`, `sample_01/`, and so on.
156
+
157
+ ## Export Best Images
158
+
159
+ Copy the selected best image for every completed prompt into one folder:
160
+
161
+ ```bash
162
+ python -m agentic_upsampling.extract_best \
163
+ --output-dir outputs/agentic_batch \
164
+ --export-dir outputs/agentic_batch_best \
165
+ --overwrite
166
+ ```
167
+
168
+ The exporter writes:
169
+
170
+ ```text
171
+ best_generations.jsonl
172
+ best_generations.csv
173
+ images/
174
+ ```
BIAS.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Bias
2
+
3
+ | Field | Response |
4
+ | :---- | :---- |
5
+ | Participation considerations from adversely impacted groups [protected classes](https://www.senate.ca.gov/content/protected-classes) in model design and testing | None. |
6
+ | Measures taken to mitigate against unwanted bias | Training, evaluation, and testing data are curated before release to filter restricted content, including content relating to protected classes. Model behavior is evaluated across Physical AI domains — robotics, autonomous vehicles, human-centric scenes, common scenes, industry, miscellaneous, and physics-oriented benchmarks — with attention to coverage across diverse demographic and contextual characteristics that affect protected-class outcomes. |
7
+ | Which characteristic (feature) show(s) the greatest difference in performance?: | Greatest performance differences are observed in tasks requiring long-horizon temporal consistency, fine-grained physical interactions, and embodiment-specific action generation. Performance is generally stronger on common visual reasoning and world-generation tasks than on complex multi-agent, robotics-control, or tightly synchronized multimodal generation scenarios. |
8
+ | Which feature(s) have the worst performance overall? | Performance is generally weakest in tasks requiring long-horizon temporal consistency, precise physical interactions, embodiment-specific action control, and strict audio-visual synchronization. |
9
+ | If using internal data, description of methods implemented in data acquisition or processing, if any, to address the prevalence of identifiable biases in the training, testing, and validation data: | Bias-specific methods applied during data processing include person-presence screening, demographic-taxonomy classification (age, gender, ethnicity), embedding-based diversity analysis, and dataset balancing across sources. Internal analysis surfaced: non-person scenes are more prevalent than person-centric content; demographic-taxonomy outputs on person-present samples are most frequently "uncertain" across age, gender, and ethnicity dimensions; and source-type variation, with people-centric image and video datasets showing higher demographic signal than document-, object-, robotics-, or scene-focused datasets. *(Quantitative details in the row below.)* Downstream deployments should add bias audits, fairness evaluation, red-teaming, demographically balanced fine-tuning, or counterfactual augmentation as mitigations. |
10
+ | Tools used to assess statistical imbalances and highlight patterns that may introduce bias into AI models: | Dataset analytics pipelines, metadata distribution analysis, heuristic quality checks, embedding-based clustering, model-assisted filtering systems, and benchmark evaluation suites are used to assess statistical imbalances and identify patterns that may introduce bias into model behavior. |
11
+ | Tools used to assess statistical imbalances and highlight patterns that may introduce bias into AI models: | These datasets, such as OpenImages-derived detection-to-NLP datasets, visual grounding and VQA datasets, document/image understanding datasets, video/action understanding datasets, and NVIDIA-created or curated visual datasets, do not collectively or exhaustively represent all demographic groups (and proportionally therein). For instance, automated person-presence screening did not identify a person in approximately 58% of visual samples analyzed across approximately 400 datasets, while person-present signals were identified in approximately 42% of analyzed samples. In the subset where person-present signals were identified, these datasets contain uneven representation splits across the measured visual taxonomies: age outputs were most frequently uncertain, followed by child and adult; gender outputs were most frequently uncertain, followed by male and female; and ethnicity outputs were most frequently uncertain, followed by Hispanic and White as the most frequent identified categories. Dataset-level results vary by source type, with people-centric image and video datasets containing higher person-present and demographic-taxonomy signals than document-, object-, robotics-, or scene-focused datasets. To mitigate these imbalances, we recommend considering evaluation techniques such as bias audits, task-specific fairness evaluation, and red-teaming, along with fine-tuning with demographically balanced datasets and counterfactual data augmentation to align with the desired model behavior. This evaluation used a baseline of 200 samples across all datasets, with larger subsets of up to 3,000 samples utilized for certain in-depth analyses, identified as optimal thresholds for maximizing embedder accuracy. |
EXPLAINABILITY.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Explainability
2
+
3
+ | Field | Response |
4
+ | :---- | :---- |
5
+ | Intended Application & Domain | World reasoning and generation for Physical AI. |
6
+ | Model Type | Mixture-of-Transformers architecture with two towers. One is an autoregressive model for Physical AI reasoning; the other is a diffusion model for Physical AI generation. |
7
+ | Intended Users | Physical AI developers, researchers, and practitioners building or evaluating autonomous vehicle, robotics, and world-generation workflows. |
8
+ | Output | Images, videos, audio, and action commands. |
9
+ | Tools used to evaluate datasets to identify synthetic data and ensure data authenticity. | Dataset provenance analysis, metadata validation, watermark and artifact detection, embedding-based clustering, heuristic quality checks, and model-assisted data validation pipelines are used to identify synthetic content patterns, assess dataset authenticity, and improve data quality during dataset curation. |
10
+ | Describe how the model works | Cosmos3 is an Omni world foundation model that generates texts, images, videos, audio, and action commands from combinations of text, images, videos, and action trajectory inputs. Input tokens from multiple modalities are packed into a shared sequence and processed by our mixture-of-transformer backbone with modality-specific output heads. |
11
+ | Name the adversely impacted groups this has been tested to deliver comparable outcomes regardless of: | None. |
12
+ | Technical Limitations | The model may not follow text, image, video, audio, or action trajectory inputs accurately in challenging cases, especially where the input contains complex scene composition, unusual camera motion, multiple interacting agents, low lighting, high motion blur, or fine-grained physical interactions. Generated outputs may contain temporal inconsistency, object morphing, inaccurate 3D structure, or implausible physical dynamics. Generated audio may not accurately render intelligible speech, or maintain strict temporal and semantic alignment with the visual context. |
13
+ | Verified to have met prescribed NVIDIA quality standards | Yes. |
14
+ | Performance Metrics | Video generation is measured using PAIBench-G, RBench, PhysicsIQ, and Artifical Analysis Image2Video benchmark. Image generation uses UniGenBench and Artifical Analysis Text2Image benchmark. For transfer evaluation, we use PAIBench-C and AVBench-C. Audio generation uses internal benchmarks. Action prediction uses metrics such as action MSE, Absolute Translation Error, Relative Translation Error, Relative Rotation Error, PSNR, and robotic task completion success rate. |
15
+ | Potential Known Risks | This model can generate synthetic media and may produce content that is offensive, unsafe, misleading, indecent, or unsuitable for a target deployment. Users should implement robust safety guardrails — including content filtering, abuse monitoring, and access controls — to reduce the risk of harmful outputs. Users are responsible for ensuring that their use of the model complies with all applicable laws and regulations, and for regularly reviewing and updating their guardrails as risks evolve. |
16
+ | Licensing | [OpenMDW1.1](https://openmdw.ai/) |
PRIVACY.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ ## Privacy
2
+ | Privacy Information |
3
+ |---|
4
+ | The model was trained on large-scale publicly available data that may contain images, audio-video, and text relating to people. NVIDIA collected and used this data in compliance with applicable data protection and privacy laws. This model was not designed to derive insights or otherwise learn from any personal data contained in the datasets. |
5
+ | NVIDIA uses a combination of filters, data minimization techniques, and other guardrails to help prevent personal data from being recited by our models. We employ automated tools and data processing techniques during pre-training or training to identify and filter certain categories of personal data. For example, for text-bearing source and document components, our automated tools identified potential personal data such as person names, locations, and possible business or public-facing contact information such as email addresses and phone numbers. We reviewed and removed any verified instances of personal data through a combination of automated filtering and human-in-the-loop validation. |
6
+ | Please review NVIDIA's [Privacy Policy](https://www.nvidia.com/en-us/about-nvidia/privacy-policy/) for more information. |
README.md ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: other
3
+ license_name: openmdw1.1-license
4
+ license_link: >-
5
+ https://openmdw.ai/license/1-1/
6
+ library_name: cosmos
7
+ tags:
8
+ - nvidia
9
+ - cosmos
10
+ - cosmos3
11
+ - vllm-omni
12
+ - diffusers
13
+ - text-to-image
14
+ - image-generation
15
+ ---
16
+
17
+ # **Cosmos 3: Omnimodal World Models for Physical AI**
18
+ **[Model Collection](https://huggingface.co/collections/nvidia/cosmos3)** | **[Code](https://github.com/nvidia/cosmos)** | **[White Paper](https://research.nvidia.com/labs/cosmos-lab/cosmos3/technical-report.pdf)** | **[Website](https://research.nvidia.com/labs/cosmos-lab/cosmos3/)**
19
+
20
+ [NVIDIA Cosmos™](https://github.com/nvidia/cosmos) is a world foundation model platform designed to accelerate the development of Physical AI by enabling machines to understand, simulate, and interact with the physical world across robotics, autonomous driving, and smart space environments, including industrial and factory-scale applications.
21
+
22
+ # Model Overview: Cosmos3-Super-Text2Image
23
+
24
+ ## Description
25
+
26
+ Cosmos3 is a collection of Omnimodal world models capable of generating dynamic, high-quality video, image, audio, and action commands from combinations of text, image, video, and action trajectory inputs. It serves as a foundational building block for a broad range of Physical AI applications and research spanning world understanding, world generation, simulation, and embodied policy learning.
27
+
28
+ This model is ready for commercial and non-commercial use.
29
+
30
+ **Model Developer:** NVIDIA
31
+
32
+ ### Model Versions
33
+ - Cosmos3-Nano:
34
+ - Given multimodal inputs including text, images, video, audio, and action trajectories, generate coherent text, images, video, audio, and action outputs for multimodal understanding, world simulation, future prediction, action reasoning, and Physical AI applications.
35
+
36
+ - Cosmos3-Super:
37
+ - Given multimodal inputs including text, images, video, audio, and action trajectories, generate coherent text, images, video, audio, and action outputs for multimodal understanding, world simulation, future prediction, action reasoning, and Physical AI applications.
38
+
39
+ - Cosmos3-Nano-Policy-DROID:
40
+ - Given language instructions and visual observations from the DROID robot platform, generate robot action trajectories for manipulation and control tasks.
41
+
42
+ - Cosmos3-Super-Image2Video:
43
+ - Given one input image and text instructions, generate temporally coherent video sequences that are consistent with the provided visual content.
44
+
45
+ - Cosmos3-Super-Text2Image:
46
+ - Given text input, generate high-fidelity images that are consistent with the provided description.
47
+
48
+ ### License
49
+
50
+ This model is released under the [OpenMDW1.1](https://openmdw.ai/license/1-1/)
51
+
52
+ ### Deployment Geography
53
+
54
+ Global
55
+
56
+ ### Use Case
57
+
58
+ Physical AI: Encompassing robotics, autonomous vehicles (AV), and smart space environments, including industrial and factory-scale applications.
59
+
60
+ ### Release Date
61
+
62
+ Hugging Face 05/31/2026 via [https://huggingface.co/collections/nvidia/cosmos3](https://huggingface.co/collections/nvidia/cosmos3)
63
+ GitHub 05/31/2026 via [https://github.com/nvidia/cosmos](https://github.com/nvidia/cosmos)
64
+
65
+ ## Model Architecture
66
+
67
+ **Architecture Type:** Transformer
68
+
69
+ **Network Architecture:** Mixture-of-Transformers (MoT)
70
+
71
+ Cosmos3 is an Omni-modal foundation model built on a Mixture-of-Transformers (MoT) architecture consisting of two complementary transformer towers: an autoregressive transformer for discrete token generation and a diffusion transformer for continuous multimodal generation. During inference, text is generated through standard next-token autoregressive decoding, while non-text modalities, such as images, video, audio, and actions, are synthesized through iterative denoising. This unified architecture enables Cosmos3 to model heterogeneous modalities within a single framework while preserving generation mechanisms best suited to each modality.
72
+
73
+ **This model was developed based on:** [Cosmos Framework](https://github.com/nvidia/cosmos-framework)
74
+
75
+ **Number of trainable model parameters:**
76
+
77
+ - Cosmos3-Nano: 16B
78
+ - Cosmos3-Super: 64B
79
+ - Cosmos3-Nano-Policy-DROID: 16B
80
+ - Cosmos3-Super-Image2Video: 64B
81
+ - Cosmos3-Super-Text2Image: 64B
82
+
83
+ ## Input/Output Specifications
84
+
85
+ - **Generator Input**
86
+ - **Input Type(s)**: Text, Image, Video (with audio or without audio), Action Trajectory
87
+ - **Input Format(s)**:
88
+ - Text: String
89
+ - Image: jpg, png, jpeg, webp
90
+ - Video (with or without audio): mp4
91
+ - Action: json (1D list)
92
+ - **Input Parameters**:
93
+ - Text: One-dimensional (1D)
94
+ - Image: Two-dimensional (2D)
95
+ - Video: Three-dimensional (3D)
96
+ - Audio: One-dimensional (1D)
97
+ - Action trajectory: One-dimensional (1D)
98
+ - **Other Properties Related to Input**:
99
+ - For video inputs, we accept various resolutions, including 720p, 480p, and 256p.
100
+ - When using input video with audio muxed into the video MP4 file, the audio should have 2 channels (stereo) and a 48 kHz sample rate.
101
+ - Image and video inputs are RGB color (8 bits per channel, sRGB color space); grayscale inputs are not supported.
102
+ - Action input is a per-frame sequence of robot/agent state or control values (e.g., joint positions, gripper state, camera pose). The full input is a 2D array shaped (T, D), where T is the number of frames and D is the embodiment-specific dimensionality listed below.
103
+ - Input action is only supported for compatible embodiments, including general camera motion (9D), autonomous vehicle (9D), egocentric motion (57D), single Franka Panda arm with RobotiQ gripper (10D), dual Franka Panda arm with RobotiQ gripper (20D), Agibot (29D), UR (10D), Google robot (10D), WidowX 250 (10D), UMI (9D).
104
+ - **Input Size and Length limits:**
105
+ - **Text:** 4096 tokens
106
+ - **Image:** 256p, 480p, and 720p resolution at one of these aspect ratios (16:9, 4:3, 1:1, 3:4, 9:16)
107
+ - **Video:** 256p, 480p, and 720p resolution at one of these aspect ratios (16:9, 4:3, 1:1, 3:4, 9:16). Max number of frames = 5.
108
+ - **Audio:** Max 0.5 second
109
+ - **Action:** 16 – 400 video frames
110
+ - **Generator Output**
111
+ - **Output Type(s)**: Image, video, audio, action, text
112
+ - **Output Format(s)**:
113
+ - Image: JPG
114
+ - Video: MP4
115
+ - Audio: Advanced Audio Coding (AAC) stream (muxed within the MP4)
116
+ - Action: 1D list (.json)
117
+ - Text: string
118
+ - **Output Parameters**:
119
+ - Image: Two-dimensional (2D)
120
+ - Video: Three-dimensional (3D)
121
+ - Audio: One-dimensional (1D)
122
+ - Action: One-dimensional (1D)
123
+ - Text: One-dimensional (1D)
124
+ - **Other Properties Related to Output**:
125
+ - The generated video is an MP4 file, with the resolution, frame rate, and duration specified in the input. The generated audio is encoded in AAC format, muxed into the video MP4 file with 2 channels (stereo) and a 48 kHz sample rate.
126
+ - Video generation supports durations from 5 to 400 frames, with 189 frames as the default generation duration.
127
+ - The generated action is only supported for compatible embodiments, including general camera motion (9D), autonomous vehicle (9D), egocentric motion (57D), single Franka Panda arm with RobotiQ gripper (10D), dual Franka Panda arm with RobotiQ gripper (20D), Agibot (29D), UR (10D), Google robot (10D), WidowX 250 (10D), UMI (9D).
128
+ - Audio: 48 kHz stereo AAC stream muxed into video mp4
129
+ - Video: mp4 at the FPS specified in input
130
+ - Image: JPEG
131
+ - **Reasoner Input**
132
+ - **Input Type(s)**: Text, Text+Image, Text+Video
133
+ - **Input Format(s)**:
134
+ - Text: String
135
+ - Image: jpg, png, jpeg, webp
136
+ - Video: mp4
137
+ - **Input Parameters**:
138
+ - Text: One-dimensional (1D)
139
+ - Image: Two-dimensional (2D)
140
+ - Video: Three-dimensional (3D)
141
+ - **Other Properties Related to Input**:
142
+ - Video inputs are recommended at a frame rate of 4 fps.
143
+ - Long-context inputs supported up to 256K tokens.
144
+ - **Input Size and Length limits:**
145
+ - **Text:** Up to 256K tokens (context window).
146
+ - **Image:** Standard input image formats; passed as file or URL.
147
+ - **Video:** mp4 at the recommended 4 fps.
148
+ - **Reasoner Output**
149
+ - **Output Type(s)**: Text
150
+ - **Output Format(s)**:
151
+ - Text: string
152
+ - **Output Parameters**:
153
+ - Text: One-dimensional (1D)
154
+ - **Other Properties Related to Output**:
155
+ - Default `max_tokens=4096+` is recommended for reasoning outputs; longer outputs may be requested.
156
+ - Reasoning outputs may include structured chain-of-thought, 2D/3D point localization, and bounding-box coordinates for vision-based tasks.
157
+
158
+ The video content visualizes the input text description as a short animated scene, capturing key elements within the specified time constraints.
159
+
160
+ Our AI models are designed and/or optimized to run on NVIDIA GPU-accelerated systems. By leveraging NVIDIA's hardware (e.g., GPU cores) and software frameworks (e.g., CUDA libraries), the model achieves faster training and inference times compared to CPU-only solutions.
161
+
162
+ ## Software Integration
163
+
164
+ **Runtime Engine(s):**
165
+
166
+ - [PyTorch](https://github.com/nvidia/cosmos3)
167
+ - [vLLM-Omni](https://github.com/vllm-project/vllm-omni)
168
+ - [Hugging Face Diffusers](https://huggingface.co/docs/diffusers/en/index)
169
+
170
+ **Supported Hardware Microarchitecture Compatibility:**
171
+
172
+ - NVIDIA Ampere
173
+ - NVIDIA Blackwell
174
+ - NVIDIA Hopper
175
+
176
+ **Operating System(s):**
177
+
178
+ - Linux (We have not tested on other operating systems.)
179
+
180
+ **Note:** Only BF16 precision is tested. Other precisions like FP4, FP8, and FP16 are not officially supported.
181
+
182
+ The integration of foundation and fine-tuned models into AI systems requires additional testing using use-case-specific data to ensure safe and effective deployment. Following the V-model methodology, iterative testing and validation at both unit and system levels are essential to mitigate risks, meet technical and functional requirements, and ensure compliance with safety and ethical standards before deployment.
183
+
184
+ ## Training, Testing, and Evaluation Datasets
185
+
186
+ ### Dataset Overview
187
+
188
+ - **Total Size:** 1.3B data points
189
+ - **Total Number of Datasets:** 393 dataset entries
190
+ - **Dataset partition:** Training [100%], Testing [N/A ��� evaluation benchmarks used separately], Validation [N/A — evaluation benchmarks used separately]
191
+ - **Time period for training data collection:** 2024–2026
192
+ - **Time period for testing data collection:** N/A (standard public benchmarks)
193
+ - **Time period for validation data collection:** N/A (standard public benchmarks)
194
+
195
+ Raw data from internal and external sources is transformed into training-ready data through multiple stages of curation, filtering, and quality review. Data acquisition spans diverse multimodal sources — robotics, autonomous driving, industrial environments, indoor and outdoor scenes, varied lighting and weather conditions, camera viewpoints, object categories, and human activities — to broaden coverage across Physical AI operating environments. Automated filtering pipelines remove corrupted, duplicate, low-quality, and restricted content. Metadata analysis, heuristic rules, and model-assisted classifiers are applied during preprocessing to flag anomalous distributions and low-diversity subsets. Human review supplements automated filtering for selected datasets, benchmark construction, and targeted quality analysis. Datasets are balanced across modalities and task categories — visual reasoning, text-to-image, text-to-video, image-to-video, audio generation, video transfer, action-conditioned generation, and action command generation — to reduce overrepresentation of narrow domains. Synthetic and simulation-based augmentation supplements coverage of rare physical interactions and edge-case scenarios. Deduplication and provenance tracking are applied across the corpus. The resulting processed data is converted into model-ready tokenized or encoded representations through modality-specific preprocessors before training begins.
196
+
197
+ Training datasets passed through multiple layers of automated and manual safeguards designed to reduce the presence of harmful or policy-violating content across categories including weapons and weapons-related instructional content, criminal planning, child sexual abuse material (CSAM), non-consensual intimate imagery (NCII), sexual content involving minors, harassment, hate speech, profanity, threats and incitement to violence, self-harm or suicide-related content, and graphic violence. Data sources are reviewed for licensing compatibility, provenance, and alignment with internal data governance and safety policies before admission into training corpora. Automated filtering pipelines combine multiple detection strategies: hash-matching against known CSAM and NCII reference databases; classifier-based moderation models trained for explicit sexual content, hate speech, violence, weapons imagery, and other restricted categories; keyword and regex-based screening for criminal-planning, threats, and self-harm phrases in text data; metadata and provenance heuristics for source-level risk signals; and embedding-based anomaly detection to surface samples that fall outside expected distributions. Human review and targeted audits supplement automated filtering for selected datasets, benchmark construction, and safety-sensitive evaluation. For multimodal Physical AI data (robotics, autonomous driving, industrial scenes), additional filtering targets invalid action trajectories, physically implausible interactions, and unsafe control sequences. Synthetic and simulation-generated data are evaluated through internal validation before inclusion. Benchmark evaluations and red-team testing are applied post-training to surface remaining safety gaps across world generation, reasoning, audio, and action tasks. No large-scale data-filtering process can guarantee complete removal of all harmful content; residual risks may remain, particularly in rare edge cases or open-world deployment settings. Ongoing monitoring and dataset review continue post-release.
198
+
199
+ **Data Modality and Training Data Size**
200
+
201
+ | Modality | Reasoning Data Sample Count | Generation Data Sample Count |
202
+ | -------- | ------------------- | -------------------- |
203
+ | Text | 22M | Not Applicable |
204
+ | Image | 19M | 767M |
205
+ | Video | 1M | 348M |
206
+ | Audio | Not Applicable | 139M |
207
+ | Action | Not Applicable | 8M |
208
+
209
+ **Data Collection Method by dataset**
210
+
211
+ - Hybrid: Automatic/Sensors, Synthetic, Automated
212
+
213
+ **Labeling Method by dataset**
214
+
215
+ - Hybrid: Human, Automated
216
+
217
+ **Properties:** The training, testing, and evaluation datasets consist of diverse multimodal video, image, audio, action, synthetic, and sensor-conditioned data sourced from NVIDIA-owned data and publicly available, commercially permissive datasets. These datasets are curated to exclude known restricted content and to support building an Omni model that learns to generate and reason about dynamic physical environments across world reasoning and generation tasks.
218
+
219
+ ### Public Datasets
220
+
221
+ | Dataset&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; | Samples&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; |
222
+ |---|---|
223
+ | OpenImage | 1.2M |
224
+ | Coyo700M | 100M |
225
+ | YouTube Video | 340M |
226
+ | UMI | 4.5M |
227
+
228
+ ### Private Datasets
229
+
230
+ | Dataset&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; | Samples&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; |
231
+ |---|---|
232
+ | Egocentric | 7M |
233
+ | Nexar | 0.6M |
234
+ | AgiBot | 0.2M |
235
+ | HOI | 0.3M |
236
+
237
+ ### Synthetic Datasets
238
+
239
+ | Dataset | Samples |
240
+ |---|---|
241
+ | synthetic images generated using HiDream-I1 | 15M |
242
+ | synthetic images generated using Qwen-Image-2512 | 14M |
243
+ | synthetic captions generated using Qwen3-VL | 1115M |
244
+
245
+ ## Evaluation Datasets
246
+
247
+ **Data Collection Method by dataset**
248
+
249
+ - Hybrid: Automatic/Sensors, Synthetic, Automated
250
+
251
+ **Labeling Method by dataset**
252
+
253
+ - Hybrid: Human, Automated
254
+
255
+ **Properties:** The training, testing, and evaluation datasets consist of diverse multimodal video, image, audio, action, synthetic, and sensor-conditioned data sourced from NVIDIA-owned data and publicly available, commercially permissive datasets. These datasets are curated to exclude known restricted content and to support building an Omni model that learns to generate and reason about dynamic physical environments across world reasoning and generation tasks.
256
+
257
+ ## Benchmarks
258
+
259
+ Please see our [technical paper](https://research.nvidia.com/labs/cosmos-lab/cosmos3/technical-report.pdf) for detailed evaluations of the base model.
260
+
261
+ ### Text-to-image benchmark results
262
+
263
+ ![benchmark results](assets/benchmark-text2image.png)
264
+
265
+ ### Artificial Analysis Leaderboard
266
+
267
+ #### Open-Source Models [2026/05/28/]
268
+
269
+ ![Artificial Analysis Text-to-Image leaderboard — open-source models](assets/benchmark-text2image-leaderboard.png)
270
+
271
+ #### All Models [2026/05/28/] (Including Closed-Source)
272
+
273
+ ![Artificial Analysis Text-to-Image leaderboard — all models including closed-source](assets/benchmark-text2image-leaderboard-all-models.jpg)
274
+
275
+ ## Qualitative examples
276
+
277
+ ![Qualitative examples](assets/more_images.jpg)
278
+
279
+ ## Usage
280
+
281
+ - See [Cosmos](https://github.com/nvidia/cosmos) for details.
282
+
283
+ ### Prompt upsampling
284
+
285
+ For optimal quality, text prompts should be upsampled into a specific JSON structure. Description and code can be found [here](https://github.com/nvidia/cosmos-framework/blob/main/docs/prompt_upsampling.md).
286
+
287
+ For example, for text-to-image upsampling using Opus-4.7:
288
+
289
+ ```bash
290
+ git clone https://github.com/NVIDIA/cosmos-framework.git packages/cosmos-framework
291
+ pip install -e packages/cosmos-framework
292
+
293
+ export PROMPT_UPSAMPLER_ENDPOINT_URL="https://api.anthropic.com/v1/"
294
+ export PROMPT_UPSAMPLER_MODEL_NAME="claude-opus-4-7"
295
+ export PROMPT_UPSAMPLER_API_TOKEN="<your_token>"
296
+
297
+ python -m cosmos_framework.inference.prompt_upsampling \
298
+ --input assets/original_prompt.txt \
299
+ --output /tmp/upsampled_t2i_opus/ \
300
+ --mode text2image \
301
+ --endpoint-url "${PROMPT_UPSAMPLER_ENDPOINT_URL}" \
302
+ --model "${PROMPT_UPSAMPLER_MODEL_NAME}" \
303
+ --api-token "${PROMPT_UPSAMPLER_API_TOKEN}" \
304
+ --resolution 768 \
305
+ --aspect-ratio "1,1"
306
+ ```
307
+
308
+ The JSON-upsampled version of `assets/original_prompt.txt` is saved in `assets/example_caption.json` for convenience, and is used for the image generation examples below.
309
+
310
+ ### vLLM-Omni
311
+
312
+ #### Container
313
+
314
+ ```
315
+ docker pull vllm/vllm-omni:cosmos3
316
+ ```
317
+
318
+ #### General Invocation
319
+
320
+ You can use the release-tested `vllm-omni` package for deploying an OpenAI-compatible API inference endpoint.
321
+ The recommended vLLM-Omni serving configuration for `nvidia/Cosmos3-Super-Text2Image` on a 8xH100 node is:
322
+
323
+ ```bash
324
+ vllm serve nvidia/Cosmos3-Super-Text2Image \
325
+ --omni \
326
+ --host 0.0.0.0 \
327
+ --port 8000 \
328
+ --cfg-parallel-size 2 \
329
+ --ulysses-degree 4 \
330
+ --tensor-parallel-size 1 \
331
+ --use-hsdp \
332
+ --hsdp-shard-size 8 \
333
+ --init-timeout 1800
334
+ ```
335
+
336
+ Setting `--enable-layerwise-offload` can help with memory usage on GPUs with less available memory; however, please note that for text2image generation, this may incur a significant performance penalty. For 4xH200 or 4xGB200 one can simply use `--cfg-parallel-size 2 --ulysses-degree 2 --tensor-parallel-size 1`.
337
+
338
+ #### Examples
339
+
340
+ ##### Text to image generation
341
+ ```python
342
+ import base64
343
+ import json
344
+ import requests
345
+
346
+ # 1. Read JSON-upsampled prompt
347
+ json_prompt = json.load(open("assets/example_caption.json"))
348
+
349
+ # 2. Build your API payload
350
+ payload = {
351
+ "prompt": json.dumps(json_prompt),
352
+ "size": "1024x1024",
353
+ "n": 1, # single frame generation
354
+ "num_inference_steps": 50,
355
+ "guidance_scale": 4.0,
356
+ "flow_shift": 3.0,
357
+ "negative_prompt": "",
358
+ "seed": 1143,
359
+ "extra_args": {
360
+ "use_resolution_template": False,
361
+ "guardrails": True,
362
+ },
363
+ }
364
+
365
+ # 3. Send the POST request
366
+ url = "http://localhost:8000/v1/images/generations"
367
+ print("Sending request to server...")
368
+ response = requests.post(url, json=payload, headers={"Content-Type": "application/json"})
369
+ response.raise_for_status()
370
+
371
+ # 4. Extract the base64 data and decode it into an image
372
+ response_json = response.json()
373
+ b64_data = response_json["data"][0]["b64_json"]
374
+ image_bytes = base64.b64decode(b64_data)
375
+
376
+ # 5. Save the final PNG file
377
+ with open("/tmp/cosmos3_t2i.png", "wb") as image_file:
378
+ image_file.write(image_bytes)
379
+ print("Saved image to /tmp/cosmos3_t2i.png")
380
+ ```
381
+
382
+ ![example_image](assets/example_image.png)
383
+
384
+ ### Diffusers
385
+
386
+ Cosmos3 is fully supported within the popular HuggingFace Diffusers package. This integration makes it a supported inference backend, allowing developers to easily incorporate Cosmos3's capabilities - such as text-to-image generation - into their pipelines using the Cosmos3OmniPipeline class, as demonstrated by the provided code examples (see examples for other modalities on the HuggingFace Cosmos3 page).
387
+
388
+ **Note:** This example is tested on GB200. For H100, use the [vLLM-Omni serving recipe](#vllm-omni) above, which supports multi-GPU deployment via HSDP.
389
+
390
+ #### Installation
391
+
392
+ To install diffusers with Cosmos3OmniPipeline:
393
+ ```
394
+ uv venv --python 3.13 --seed --managed-python
395
+ source .venv/bin/activate
396
+ uv pip install \
397
+ "diffusers @ git+https://github.com/huggingface/diffusers.git" \
398
+ accelerate \
399
+ av \
400
+ cosmos_guardrail \
401
+ huggingface_hub \
402
+ imageio \
403
+ imageio-ffmpeg \
404
+ torch \
405
+ torchvision \
406
+ transformers
407
+ ```
408
+
409
+ #### Examples
410
+
411
+ ##### Text to image generation
412
+ ```python
413
+ import json
414
+ import torch
415
+ from diffusers import Cosmos3OmniPipeline
416
+ from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
417
+
418
+ json_prompt = json.load(open("assets/example_caption.json"))
419
+
420
+ pipe = Cosmos3OmniPipeline.from_pretrained(
421
+ "nvidia/Cosmos3-Super-Text2Image",
422
+ torch_dtype=torch.bfloat16,
423
+ device_map="cuda",
424
+ enable_safety_checker=True,
425
+ )
426
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=3.0)
427
+
428
+ result = pipe(
429
+ prompt=json.dumps(json_prompt),
430
+ negative_prompt="",
431
+ num_frames=1,
432
+ height=1024,
433
+ width=1024,
434
+ num_inference_steps=50,
435
+ guidance_scale=4.0,
436
+ generator=torch.Generator(device="cuda").manual_seed(1143),
437
+ )
438
+
439
+ result.video[0].save("/tmp/cosmos3_t2i.png")
440
+ print("Saved image to /tmp/cosmos3_t2i.png")
441
+ ```
442
+
443
+ ## Limitations
444
+
445
+ Cosmos3 may produce imperfect outputs in challenging scenarios. Generation artifacts include temporal inconsistency, unstable camera or object motion, imprecise physical interactions, inaccurate audio-video synchronization, and action-state drift — especially in long-horizon or high-resolution outputs. Reasoning may also be incorrect: object states, causal relationships, spatial geometry, temporal ordering, agent intent, and future outcomes can be misinferred, and complex or long-context inputs may yield hallucinated entities, inconsistent interpretations, or implausible predictions. Because the model lacks an explicit physics simulator, 3D geometry, 4D space-time evolution, object permanence, contact dynamics, and physical laws are only approximated — producing artifacts such as disappearing or morphing objects, unrealistic collisions, and physically implausible motions. Quality further degrades in out-of-distribution environments, safety-critical edge cases, and domains underrepresented in training.
446
+
447
+ Cosmos3 outputs should not be treated as physically accurate simulation, reliable ground-truth reasoning, or safety-certified decision making. Applications involving robotics control, autonomous systems, scientific simulation, or safety-critical planning require additional validation, external constraints, system-level safety analysis, and domain-specific guardrails before deployment.
448
+
449
+ ## Inference
450
+
451
+ **Acceleration Engine:** [PyTorch](https://pytorch.org/), [vLLM](https://github.com/vllm-project/vllm), [vLLM-Omni](https://github.com/vllm-project/vllm-omni), [Hugging Face Diffusers](https://github.com/huggingface/diffusers)
452
+
453
+ **Test Hardware:** GB200 and H100
454
+
455
+ ## Ethical Considerations
456
+
457
+ NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. Developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse.
458
+
459
+ Please make sure you have proper rights and permissions for all input image and video content; if image or video includes people, personal health information, or intellectual property, the image or video generated will not blur or maintain proportions of image subjects included.
460
+
461
+ Users are responsible for model inputs and outputs. Users are responsible for ensuring safe integration of this model, including implementing guardrails as well as other safety mechanisms, prior to deployment.
462
+
463
+ For more detailed information on ethical considerations for this model, please see the Model Card++ [Explainability](EXPLAINABILITY.md), [Bias](BIAS.md), [Safety & Security](SAFETY.md), and [Privacy](PRIVACY.md) subcards. Please report model quality, risk, security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/).
SAFETY.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Safety & Security
2
+
3
+ | Field | Response |
4
+ | :---- | :---- |
5
+ | Model Application(s) | World reasoning and generation for Physical AI. |
6
+ | Describe the life critical impact: | This model is not a safety-certified component and must not be used as the sole basis for life-critical decisions or control without additional system-level validation, safety analysis, and safeguards. The model is not designed or tested by NVIDIA for use in any system or application where the use of or failure of such system or application developed with the model could result in injury, death, or catastrophic damage. NVIDIA is not liable to any party, in whole or in part, for any claims or damages arising from those uses. Any system or application developed with the model must include sufficient safety and redundancy features and comply with applicable legal and regulatory standards and requirements. |
7
+ | Description of methods implemented in data acquisition or processing, if any, to address other types of potentially harmful data in the training, testing, and validation data: | Training, evaluation, and validation datasets pass through multi-stage automated and manual filtering to reduce harmful, unsafe, restricted, or policy-violating content. Pipelines include source-licensing review, deduplication, metadata-based and classifier-based moderation, embedding-based anomaly detection, and human audits on selected datasets. For Physical AI data (robotics, autonomous driving, industrial scenes), filtering also targets invalid action trajectories, physically implausible interactions, and unsafe control sequences. Synthetic and simulation-generated data are evaluated through internal validation before inclusion. Benchmark and red-team testing surface remaining safety gaps across world generation, reasoning, audio, and action tasks. No data-filtering process can guarantee complete removal; developers are responsible for application-specific safeguards and validation before deployment. |
8
+ | Description of any methods implemented in data acquisition or processing, if any, to address illegal or harmful content in the training data, including, but not limited to, child sexual abuse material (CSAM) and non-consensual intimate imagery (NCII) | In addition to the general unsafe-content filtering described above, training data acquisition and preprocessing apply CSAM- and NCII-specific safeguards: hash-matching systems against known CSAM databases, classifier-based moderation models trained specifically for explicit content and NCII detection, and provenance and licensing review for sources containing human imagery. Identified content is removed at ingest, with human review and targeted audits supplementing automated filtering for selected datasets. Despite these safeguards, no large-scale data-filtering system can guarantee complete detection. Ongoing monitoring and dataset review continue post-release. |
9
+ | Use Case Restrictions | Use is governed by the [OpenMDW1.1](https://openmdw.ai/) |
10
+ | Model and dataset restrictions | The Principle of least privilege (PoLP) is applied limiting access for dataset generation and model development. Restrictions enforce dataset access during training, and dataset license constraints adhered to. |
11
+ | Responsible Data Handling | This AI model was developed based on our policies to ensure responsible data handling and risk mitigation. The datasets used for training have been scanned for harmful content and illegal content, consistent with our policies including scanning for Child Sexual Abuse Material (CSAM). Ongoing review and monitoring mechanisms are in place based on our policies and to maintain data integrity. |
agentic_upsampling/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """Standalone agentic prompt upsampling for Cosmos3 text-to-image."""
2
+
3
+ from agentic_upsampling.data import PromptItem
4
+ from agentic_upsampling.runner import AgenticUpsamplerRunner, RunnerConfig
5
+
6
+ __all__ = ["AgenticUpsamplerRunner", "PromptItem", "RunnerConfig"]
agentic_upsampling/__main__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """Run the agentic prompt upsampling CLI."""
2
+
3
+ from agentic_upsampling.run import main
4
+
5
+
6
+ if __name__ == "__main__":
7
+ raise SystemExit(main())
agentic_upsampling/clients.py ADDED
@@ -0,0 +1,521 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Network clients for standalone agentic text-to-image upsampling."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import base64
6
+ import io
7
+ import json
8
+ import os
9
+ import time
10
+ from dataclasses import dataclass
11
+ from pathlib import Path
12
+ from typing import Any
13
+ import requests
14
+ from PIL import Image
15
+ from requests.adapters import HTTPAdapter
16
+ from urllib3.util.retry import Retry
17
+
18
+ from agentic_upsampling.constants import (
19
+ DEFAULT_ASPECT_RATIO,
20
+ DEFAULT_CRITIC_ENDPOINT_URL,
21
+ DEFAULT_CRITIC_MODEL,
22
+ DEFAULT_GENERATION_AUTH_KEY_ENV,
23
+ DEFAULT_GENERATION_EXTRA_ARGS,
24
+ DEFAULT_GENERATION_MODEL,
25
+ DEFAULT_FLOW_SHIFT,
26
+ DEFAULT_GUIDANCE,
27
+ DEFAULT_IMAGE_SIZE,
28
+ DEFAULT_JPEG_QUALITY,
29
+ DEFAULT_LLM_EXTRA_BODY,
30
+ DEFAULT_NUM_STEPS,
31
+ DEFAULT_OPENAI_API_KEY_ENV,
32
+ DEFAULT_RESOLUTION,
33
+ DEFAULT_REWRITER_ENDPOINT_URL,
34
+ DEFAULT_REWRITER_MODEL,
35
+ DEFAULT_UPSAMPLER_ENDPOINT_URL,
36
+ DEFAULT_UPSAMPLER_MODEL,
37
+ )
38
+ from agentic_upsampling.data import PromptItem, validate_t2i_json
39
+ from agentic_upsampling.io_utils import compact_json, write_json_atomic
40
+ from agentic_upsampling.prompt_upsampler import (
41
+ JSON_ENSURE_ASCII,
42
+ SYSTEM_MESSAGE,
43
+ ChatClientConfig,
44
+ OpenAIChatClient,
45
+ Text2ImagePromptUpsampler,
46
+ extract_json_object,
47
+ )
48
+ from agentic_upsampling.rubric import (
49
+ all_category_check_text,
50
+ analysis_json_text,
51
+ build_judge_prompt,
52
+ compact_analysis_for_rewrite,
53
+ parse_analysis_response,
54
+ )
55
+
56
+ CONNECT_TIMEOUT_S = 60
57
+ SUBMIT_READ_TIMEOUT_S = 240
58
+ IMAGE_GENERATION_READ_TIMEOUT_S = 600
59
+ REWRITER_APPLICATION_GUIDANCE = all_category_check_text()
60
+
61
+
62
+ @dataclass(frozen=True, slots=True)
63
+ class GenerationOutput:
64
+ """Output from one image generation request."""
65
+
66
+ image_path: Path
67
+ meta_path: Path
68
+ meta: dict[str, Any]
69
+
70
+
71
+ def read_api_token(api_key_env: str, api_key_file: Path | None = None) -> str:
72
+ """Resolve an API token from an environment variable or explicit file."""
73
+ token = os.environ.get(api_key_env, "").strip()
74
+ if token:
75
+ return token
76
+ if api_key_file is not None and api_key_file.exists():
77
+ token = api_key_file.read_text(encoding="utf-8").strip()
78
+ if token:
79
+ return token
80
+ raise RuntimeError(f"Missing API key. Export {api_key_env} or pass the matching --*-api-key-file flag.")
81
+
82
+
83
+ def read_optional_generation_auth_key(auth_key: str, api_key_env: str = DEFAULT_GENERATION_AUTH_KEY_ENV) -> str:
84
+ """Resolve the optional generation endpoint auth key."""
85
+ return auth_key.strip() or os.environ.get(api_key_env, "").strip()
86
+
87
+
88
+ def normalize_generation_endpoint(endpoint: str) -> str:
89
+ """Normalize the vLLM-Omni endpoint root without the /v1 suffix."""
90
+ normalized = endpoint.strip().rstrip("/")
91
+ if not normalized:
92
+ raise ValueError("generation endpoint cannot be empty.")
93
+ if not normalized.startswith(("http://", "https://")):
94
+ normalized = f"https://{normalized}"
95
+ if normalized.endswith("/v1/images/generations"):
96
+ normalized = normalized[: -len("/v1/images/generations")]
97
+ elif normalized.endswith("/v1"):
98
+ normalized = normalized[: -len("/v1")]
99
+ return normalized.rstrip("/")
100
+
101
+
102
+ def make_session(pool_size: int = 4) -> requests.Session:
103
+ """Create a retrying HTTP session."""
104
+ session = requests.Session()
105
+ retry = Retry(
106
+ total=2,
107
+ connect=2,
108
+ read=0,
109
+ status=2,
110
+ status_forcelist=(429, 500, 502, 503, 504),
111
+ allowed_methods=frozenset({"GET", "POST"}),
112
+ backoff_factor=0.5,
113
+ raise_on_status=False,
114
+ )
115
+ adapter = HTTPAdapter(pool_connections=pool_size, pool_maxsize=pool_size, max_retries=retry, pool_block=False)
116
+ session.mount("https://", adapter)
117
+ session.mount("http://", adapter)
118
+ return session
119
+
120
+
121
+ def image_path_to_data_url(path: Path, *, jpeg_quality: int | None = DEFAULT_JPEG_QUALITY) -> str:
122
+ """Encode a local image file as a data URL, optionally transcoding to JPEG."""
123
+ if jpeg_quality is None:
124
+ encoded = base64.b64encode(path.read_bytes()).decode("ascii")
125
+ return f"data:image/png;base64,{encoded}"
126
+
127
+ with Image.open(path) as image:
128
+ if image.mode not in ("RGB", "L"):
129
+ image = image.convert("RGB")
130
+ buf = io.BytesIO()
131
+ image.save(buf, format="JPEG", quality=jpeg_quality, optimize=True)
132
+ encoded = base64.b64encode(buf.getvalue()).decode("ascii")
133
+ return f"data:image/jpeg;base64,{encoded}"
134
+
135
+
136
+ class PromptRewriterClient:
137
+ """GPT-based T2I JSON prompt upsampler and iterative rewriter."""
138
+
139
+ upsampler: Text2ImagePromptUpsampler
140
+ rewrite_client: OpenAIChatClient
141
+ resolution: str
142
+ aspect_ratio: str
143
+
144
+ def __init__(
145
+ self,
146
+ *,
147
+ api_token: str,
148
+ upsampler_endpoint_url: str = DEFAULT_UPSAMPLER_ENDPOINT_URL,
149
+ upsampler_model: str = DEFAULT_UPSAMPLER_MODEL,
150
+ rewriter_endpoint_url: str = DEFAULT_REWRITER_ENDPOINT_URL,
151
+ rewriter_model: str = DEFAULT_REWRITER_MODEL,
152
+ extra_body: dict[str, Any] | None = None,
153
+ resolution: str = DEFAULT_RESOLUTION,
154
+ aspect_ratio: str = DEFAULT_ASPECT_RATIO,
155
+ ) -> None:
156
+ resolved_extra_body = DEFAULT_LLM_EXTRA_BODY if extra_body is None else extra_body
157
+ self.upsampler = Text2ImagePromptUpsampler.from_defaults(
158
+ api_token=api_token,
159
+ endpoint_url=upsampler_endpoint_url,
160
+ model=upsampler_model,
161
+ extra_body=resolved_extra_body,
162
+ )
163
+ self.rewrite_client = OpenAIChatClient(
164
+ ChatClientConfig(
165
+ endpoint_url=rewriter_endpoint_url,
166
+ model=rewriter_model,
167
+ api_token=api_token,
168
+ extra_body=resolved_extra_body,
169
+ max_tokens=8192,
170
+ max_retries=3,
171
+ )
172
+ )
173
+ self.resolution = resolution
174
+ self.aspect_ratio = aspect_ratio
175
+
176
+ def initial_prompt(self, item: PromptItem) -> dict[str, Any]:
177
+ """Create the initial dense structured prompt for a user prompt."""
178
+ return self.upsampler.upsample(
179
+ item.prompt,
180
+ prompt_id=item.prompt_id,
181
+ resolution=self.resolution,
182
+ aspect_ratio=self.aspect_ratio,
183
+ )
184
+
185
+ def rewrite_prompt_pair(
186
+ self,
187
+ item: PromptItem,
188
+ previous_prompt: dict[str, Any],
189
+ previous_negative_prompt: str,
190
+ previous_analysis: dict[str, Any],
191
+ history: list[dict[str, Any]],
192
+ ) -> tuple[dict[str, Any], str]:
193
+ """Jointly rewrite the positive JSON prompt and generator-side negative prompt."""
194
+ schema_keys = list(previous_prompt.keys())
195
+ messages = [
196
+ {
197
+ "role": "system",
198
+ "content": (
199
+ "You are a precise text-to-image prompt engineer. Return valid JSON only, no markdown. "
200
+ "Jointly coordinate the positive structured prompt and generator-side negative prompt so they do not contradict each other."
201
+ ),
202
+ },
203
+ {
204
+ "role": "user",
205
+ "content": self._joint_rewrite_user_prompt(
206
+ item=item,
207
+ previous_prompt=previous_prompt,
208
+ previous_negative_prompt=previous_negative_prompt,
209
+ previous_analysis=previous_analysis,
210
+ history=history,
211
+ schema_keys=schema_keys,
212
+ ),
213
+ },
214
+ ]
215
+ last_exc: Exception | None = None
216
+ for attempt in range(1, 4):
217
+ try:
218
+ raw = self.rewrite_client.complete(messages, response_format_json=True)
219
+ return self._parse_joint_rewrite_response(raw, item.prompt_id)
220
+ except Exception as exc:
221
+ last_exc = exc
222
+ if attempt < 3:
223
+ time.sleep(min(20.0, 2.0 * attempt))
224
+ raise RuntimeError(f"Joint prompt rewrite failed after 3 attempts for prompt {item.prompt_id}.") from last_exc
225
+
226
+ @staticmethod
227
+ def _parse_joint_rewrite_response(raw: str, prompt_id: str) -> tuple[dict[str, Any], str]:
228
+ data = extract_json_object(raw)
229
+ positive_prompt = data.get("positive_prompt")
230
+ if not isinstance(positive_prompt, dict):
231
+ raise ValueError(f"Joint rewrite returned missing or non-object positive_prompt for prompt {prompt_id}.")
232
+ validate_t2i_json(positive_prompt, prompt_id)
233
+ negative_prompt = data.get("negative_prompt", "")
234
+ if not isinstance(negative_prompt, str):
235
+ raise ValueError(f"Joint rewrite returned non-string negative_prompt for prompt {prompt_id}.")
236
+ return positive_prompt, " ".join(negative_prompt.split())
237
+
238
+ @staticmethod
239
+ def _joint_rewrite_user_prompt(
240
+ *,
241
+ item: PromptItem,
242
+ previous_prompt: dict[str, Any],
243
+ previous_negative_prompt: str,
244
+ previous_analysis: dict[str, Any],
245
+ history: list[dict[str, Any]],
246
+ schema_keys: list[str],
247
+ ) -> str:
248
+ sections = [
249
+ "Original user prompt:",
250
+ item.prompt,
251
+ "",
252
+ "Application-specific guidance:",
253
+ "Apply the following sections as one checklist program. Do not first classify the prompt. Apply each section only when relevant to the original user prompt, previous JSON, or VLM failures.",
254
+ REWRITER_APPLICATION_GUIDANCE,
255
+ "",
256
+ "Previous generated image failed or scored according to this VLM analysis:",
257
+ analysis_json_text(compact_analysis_for_rewrite(previous_analysis)),
258
+ "",
259
+ "Iteration history summary:",
260
+ json.dumps(PromptRewriterClient._history_summary(history), ensure_ascii=JSON_ENSURE_ASCII, indent=2),
261
+ "",
262
+ "Previous positive JSON prompt:",
263
+ json.dumps(previous_prompt, ensure_ascii=JSON_ENSURE_ASCII, indent=2),
264
+ "",
265
+ "Previous negative prompt:",
266
+ previous_negative_prompt or "",
267
+ "",
268
+ "Joint rewrite task:",
269
+ 'Return a JSON object with exactly two top-level keys: "positive_prompt" and "negative_prompt".',
270
+ '"positive_prompt" must be a complete JSON object with exactly these top-level keys, preserving their names and types:',
271
+ json.dumps(schema_keys, ensure_ascii=JSON_ENSURE_ASCII),
272
+ "",
273
+ '"positive_prompt" must keep the previous "resolution" and "aspect_ratio".',
274
+ '"negative_prompt" must be a concise generator-side negative prompt string.',
275
+ "Coordinate both fields: strengthen required positive constraints while using the negative prompt only to suppress concrete wrong alternatives or artifacts.",
276
+ "Do not put positive instructions in negative_prompt. Do not negate content required by the original user prompt.",
277
+ "For exact counts, grids, text, geometry, or anatomy, explicitly block wrong alternatives when useful.",
278
+ 'The positive "comprehensive_t2i_caption" should be direct generation guidance, not an explanation of this rewrite process.',
279
+ ]
280
+ return "\n".join(sections)
281
+
282
+ @staticmethod
283
+ def _history_summary(history: list[dict[str, Any]]) -> list[dict[str, Any]]:
284
+ return [
285
+ {
286
+ "iteration": item.get("iteration"),
287
+ "overall_score": item.get("analysis", {}).get("overall_score"),
288
+ "prompt_adherence_score": item.get("analysis", {}).get("prompt_adherence_score"),
289
+ "category_score": item.get("analysis", {}).get("category_score"),
290
+ "threshold_cleared": item.get("analysis", {}).get("threshold_cleared"),
291
+ }
292
+ for item in history
293
+ ]
294
+
295
+
296
+ class ImageGenerationClient:
297
+ """Client for a vLLM-Omni /v1/images/generations text-to-image endpoint."""
298
+
299
+ endpoint: str
300
+ auth_key: str
301
+ model: str
302
+ session: requests.Session
303
+ size: str
304
+ num_steps: int
305
+ guidance: float
306
+ flow_shift: float
307
+ extra_args: dict[str, Any]
308
+
309
+ def __init__(
310
+ self,
311
+ *,
312
+ endpoint: str,
313
+ auth_key: str = "",
314
+ model: str = DEFAULT_GENERATION_MODEL,
315
+ size: str = DEFAULT_IMAGE_SIZE,
316
+ num_steps: int = DEFAULT_NUM_STEPS,
317
+ guidance: float = DEFAULT_GUIDANCE,
318
+ flow_shift: float = DEFAULT_FLOW_SHIFT,
319
+ extra_args: dict[str, Any] | None = None,
320
+ session: requests.Session | None = None,
321
+ ) -> None:
322
+ self.endpoint = normalize_generation_endpoint(endpoint)
323
+ self.auth_key = auth_key
324
+ self.model = model
325
+ self.session = session or make_session()
326
+ self.size = size
327
+ self.num_steps = num_steps
328
+ self.guidance = guidance
329
+ self.flow_shift = flow_shift
330
+ self.extra_args = dict(DEFAULT_GENERATION_EXTRA_ARGS if extra_args is None else extra_args)
331
+
332
+ def build_payload(
333
+ self,
334
+ prompt_json: dict[str, Any],
335
+ prompt_id: str,
336
+ seed: int | None = None,
337
+ negative_prompt: str = "",
338
+ ) -> dict[str, Any]:
339
+ """Build the vLLM-Omni image generation request payload."""
340
+ del prompt_id
341
+ payload: dict[str, Any] = {
342
+ "model": self.model,
343
+ "prompt": compact_json(prompt_json, ensure_ascii=JSON_ENSURE_ASCII),
344
+ "size": self.size,
345
+ "n": 1,
346
+ "response_format": "b64_json",
347
+ "negative_prompt": negative_prompt.strip(),
348
+ "num_inference_steps": self.num_steps,
349
+ "guidance_scale": self.guidance,
350
+ "flow_shift": self.flow_shift,
351
+ "extra_args": dict(self.extra_args),
352
+ }
353
+ if seed is not None:
354
+ payload["seed"] = int(seed)
355
+ return payload
356
+
357
+ def generate(
358
+ self,
359
+ *,
360
+ prompt_json: dict[str, Any],
361
+ prompt_id: str,
362
+ output_dir: Path,
363
+ seed: int | None = None,
364
+ negative_prompt: str = "",
365
+ jpeg_quality: int = DEFAULT_JPEG_QUALITY,
366
+ ) -> GenerationOutput:
367
+ """Generate and persist one candidate image."""
368
+ payload = self.build_payload(prompt_json, prompt_id, seed, negative_prompt=negative_prompt)
369
+ response_json = self._generate_image(payload)
370
+ image_bytes = self._decode_image_response(response_json)
371
+ image_path = output_dir / "image.jpg"
372
+ image_info = self._save_jpeg(image_bytes, image_path, jpeg_quality)
373
+ meta = {
374
+ "prompt_id": prompt_id,
375
+ "status": "completed",
376
+ "endpoint": self.endpoint,
377
+ "image_generation_url": self._image_generation_url(),
378
+ "payload": payload,
379
+ "response": self._response_without_image_bytes(response_json),
380
+ "output_image_path": str(image_path),
381
+ "image_info": image_info,
382
+ }
383
+ meta_path = output_dir / "generation_meta.json"
384
+ write_json_atomic(meta_path, meta, ensure_ascii=JSON_ENSURE_ASCII)
385
+ return GenerationOutput(image_path=image_path, meta_path=meta_path, meta=meta)
386
+
387
+ def _generate_image(self, payload: dict[str, Any]) -> dict[str, Any]:
388
+ last_exc: Exception | None = None
389
+ for attempt in range(1, 4):
390
+ try:
391
+ return self._request_json(
392
+ "POST",
393
+ self._image_generation_url(),
394
+ json=payload,
395
+ headers=self._auth_headers(),
396
+ timeout=(CONNECT_TIMEOUT_S, IMAGE_GENERATION_READ_TIMEOUT_S),
397
+ )
398
+ except Exception as exc:
399
+ last_exc = exc
400
+ if attempt < 3:
401
+ time.sleep(min(20.0, 2.0 * attempt))
402
+ raise RuntimeError(f"/v1/images/generations failed after retries: {last_exc}") from last_exc
403
+
404
+ def _image_generation_url(self) -> str:
405
+ return f"{self.endpoint}/v1/images/generations"
406
+
407
+ def _auth_headers(self) -> dict[str, str] | None:
408
+ token = self.auth_key.strip()
409
+ if not token:
410
+ return None
411
+ if token.lower().startswith("bearer "):
412
+ return {"Authorization": token}
413
+ return {"Authorization": f"Bearer {token}"}
414
+
415
+ def _request_json(self, method: str, url: str, **kwargs: Any) -> dict[str, Any]:
416
+ timeout = kwargs.pop("timeout", (CONNECT_TIMEOUT_S, IMAGE_GENERATION_READ_TIMEOUT_S))
417
+ response = self.session.request(method, url, timeout=timeout, **kwargs)
418
+ if not response.ok:
419
+ raise RuntimeError(f"{method} {url} HTTP {response.status_code}: {response.text[:1000]}")
420
+ parsed = response.json()
421
+ if not isinstance(parsed, dict):
422
+ raise RuntimeError(f"{method} {url} returned non-object JSON: {parsed!r}")
423
+ return parsed
424
+
425
+ @staticmethod
426
+ def _decode_image_response(response_json: dict[str, Any]) -> bytes:
427
+ data = response_json.get("data")
428
+ if not isinstance(data, list) or not data or not isinstance(data[0], dict):
429
+ raise RuntimeError(f"Image generation response has no data[0] object: {response_json}")
430
+ first_image = data[0]
431
+ b64_image = first_image.get("b64_json")
432
+ if not isinstance(b64_image, str) or not b64_image.strip():
433
+ image_url = first_image.get("url")
434
+ if isinstance(image_url, str) and image_url.startswith("data:image") and "," in image_url:
435
+ b64_image = image_url.split(",", 1)[1]
436
+ else:
437
+ raise RuntimeError(f"Image generation response has no b64_json image: {response_json}")
438
+ try:
439
+ return base64.b64decode(b64_image, validate=True)
440
+ except ValueError:
441
+ return base64.b64decode(b64_image)
442
+
443
+ @staticmethod
444
+ def _response_without_image_bytes(response_json: dict[str, Any]) -> dict[str, Any]:
445
+ redacted = json.loads(json.dumps(response_json))
446
+ data = redacted.get("data")
447
+ if isinstance(data, list):
448
+ for item in data:
449
+ if isinstance(item, dict) and isinstance(item.get("b64_json"), str):
450
+ item["b64_json"] = f"<base64 image omitted: {len(item['b64_json'])} chars>"
451
+ if isinstance(item, dict) and isinstance(item.get("url"), str) and item["url"].startswith("data:image"):
452
+ item["url"] = f"<data image omitted: {len(item['url'])} chars>"
453
+ return redacted
454
+
455
+ @staticmethod
456
+ def _save_jpeg(image_bytes: bytes, output_path: Path, quality: int) -> dict[str, Any]:
457
+ output_path.parent.mkdir(parents=True, exist_ok=True)
458
+ tmp = output_path.with_suffix(output_path.suffix + ".tmp")
459
+ with Image.open(io.BytesIO(image_bytes)) as image:
460
+ source_format = image.format
461
+ rgb = image.convert("RGB")
462
+ width, height = rgb.size
463
+ rgb.save(tmp, format="JPEG", quality=quality, optimize=True)
464
+ tmp.replace(output_path)
465
+ return {"source_image_format": source_format, "saved_format": "JPEG", "width": width, "height": height}
466
+
467
+
468
+ class VLMQualityJudge:
469
+ """Gemini critic for generated images through an OpenAI-compatible endpoint."""
470
+
471
+ chat_client: OpenAIChatClient
472
+ image_jpeg_quality: int | None
473
+
474
+ def __init__(
475
+ self,
476
+ *,
477
+ api_token: str,
478
+ endpoint_url: str = DEFAULT_CRITIC_ENDPOINT_URL,
479
+ model: str = DEFAULT_CRITIC_MODEL,
480
+ max_tokens: int = 8192,
481
+ image_jpeg_quality: int | None = DEFAULT_JPEG_QUALITY,
482
+ ) -> None:
483
+ self.chat_client = OpenAIChatClient(
484
+ ChatClientConfig(
485
+ endpoint_url=endpoint_url,
486
+ model=model,
487
+ api_token=api_token,
488
+ max_tokens=max_tokens,
489
+ max_retries=3,
490
+ )
491
+ )
492
+ self.image_jpeg_quality = image_jpeg_quality
493
+
494
+ def score_image(
495
+ self,
496
+ *,
497
+ item: PromptItem,
498
+ image_path: Path,
499
+ ) -> dict[str, Any]:
500
+ """Score one image with the non-classifying rubric program."""
501
+ messages = [
502
+ SYSTEM_MESSAGE,
503
+ {
504
+ "role": "user",
505
+ "content": [
506
+ {
507
+ "type": "image_url",
508
+ "image_url": {"url": image_path_to_data_url(image_path, jpeg_quality=self.image_jpeg_quality)},
509
+ },
510
+ {
511
+ "type": "text",
512
+ "text": build_judge_prompt(item),
513
+ },
514
+ ],
515
+ },
516
+ ]
517
+ raw = self.chat_client.complete(messages, response_format_json=True)
518
+ analysis = parse_analysis_response(raw)
519
+ analysis["raw_response"] = raw
520
+ return analysis
521
+
agentic_upsampling/constants.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Public defaults for the standalone agentic text-to-image upsampler."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ DEFAULT_OPENAI_ENDPOINT_URL = "https://api.openai.com/v1"
8
+ DEFAULT_UPSAMPLER_ENDPOINT_URL = DEFAULT_OPENAI_ENDPOINT_URL
9
+ DEFAULT_REWRITER_ENDPOINT_URL = DEFAULT_OPENAI_ENDPOINT_URL
10
+ DEFAULT_UPSAMPLER_MODEL = "gpt-5.5"
11
+ DEFAULT_REWRITER_MODEL = "gpt-5.5"
12
+ DEFAULT_OPENAI_API_KEY_ENV = "OPENAI_API_KEY"
13
+ DEFAULT_LLM_EXTRA_BODY: dict[str, Any] = {"reasoning_effort": "low"}
14
+
15
+ DEFAULT_CRITIC_ENDPOINT_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
16
+ DEFAULT_CRITIC_MODEL = "gemini-3.1-pro-preview"
17
+ DEFAULT_GEMINI_API_KEY_ENV = "GEMINI_API_KEY"
18
+
19
+ DEFAULT_GENERATION_AUTH_KEY_ENV = "AGENTIC_UPSAMPLING_GENERATION_AUTH_KEY"
20
+ DEFAULT_GENERATION_MODEL = "nvidia/Cosmos3-Super-Text2Image"
21
+ DEFAULT_IMAGE_SIZE = "1024x1024"
22
+ DEFAULT_GENERATION_EXTRA_ARGS: dict[str, Any] = {"guardrails": False, "use_resolution_template": False}
23
+
24
+ DEFAULT_RESOLUTION = "768"
25
+ DEFAULT_ASPECT_RATIO = "1,1"
26
+ DEFAULT_NUM_STEPS = 50
27
+ DEFAULT_GUIDANCE = 4.0
28
+ DEFAULT_FLOW_SHIFT = 3.0
29
+ DEFAULT_MAX_ITERATIONS = 2
30
+ DEFAULT_SAMPLES_PER_ITERATION = 3
31
+ DEFAULT_JPEG_QUALITY = 99
32
+
33
+ STRICT_OVERALL_THRESHOLD = 9.0
34
+ STRICT_PROMPT_THRESHOLD = 9.0
35
+
agentic_upsampling/data.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generic prompt loading and text-to-image JSON validation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import csv
6
+ import json
7
+ import re
8
+ from dataclasses import dataclass, field
9
+ from pathlib import Path
10
+ from typing import Any
11
+
12
+
13
+ @dataclass(frozen=True, slots=True)
14
+ class PromptItem:
15
+ """One text-to-image prompt to process."""
16
+
17
+ prompt_id: str
18
+ row_number: int
19
+ prompt: str
20
+ metadata: dict[str, Any] = field(default_factory=dict)
21
+
22
+
23
+ REQUIRED_T2I_KEYS = {
24
+ "subjects",
25
+ "subject_details",
26
+ "background_setting",
27
+ "lighting",
28
+ "text_and_signage_elements",
29
+ "resolution",
30
+ "aspect_ratio",
31
+ "comprehensive_t2i_caption",
32
+ }
33
+
34
+ PROMPT_COLUMNS = ("prompt", "Prompt")
35
+ ID_COLUMNS = ("id", "ID", "prompt_id", "Prompt ID")
36
+ _SAFE_ID_RE = re.compile(r"[^A-Za-z0-9_.-]+")
37
+
38
+
39
+ def prompt_dir_name(item: PromptItem) -> str:
40
+ """Return the deterministic output directory name for a prompt."""
41
+ raw_id = item.prompt_id.strip()
42
+ if raw_id.isdigit():
43
+ return f"{int(raw_id):04d}"
44
+ cleaned = _SAFE_ID_RE.sub("_", raw_id).strip("._-")
45
+ return cleaned or f"row_{item.row_number + 1:04d}"
46
+
47
+
48
+ def load_prompt_items(
49
+ *,
50
+ prompt: str | None = None,
51
+ prompts_path: Path | None = None,
52
+ limit: int | None = None,
53
+ ) -> list[PromptItem]:
54
+ """Load prompts from a literal prompt or a txt/jsonl/csv file."""
55
+ if bool(prompt) == bool(prompts_path):
56
+ raise ValueError("Provide exactly one of --prompt or --prompts.")
57
+ if prompt:
58
+ items = [PromptItem(prompt_id="1", row_number=0, prompt=prompt.strip())]
59
+ elif prompts_path is not None:
60
+ items = _load_prompts_path(prompts_path)
61
+ else:
62
+ items = []
63
+
64
+ items = [item for item in items if item.prompt.strip()]
65
+ if limit is not None and limit >= 0:
66
+ items = items[:limit]
67
+ _validate_unique_output_dirs(items)
68
+ return items
69
+
70
+
71
+ def _load_prompts_path(path: Path) -> list[PromptItem]:
72
+ suffix = path.suffix.lower()
73
+ if suffix == ".txt":
74
+ return _load_txt_prompts(path)
75
+ if suffix == ".jsonl":
76
+ return _load_jsonl_prompts(path)
77
+ if suffix == ".csv":
78
+ return _load_csv_prompts(path)
79
+ raise ValueError(f"Unsupported prompt file extension {suffix!r}. Use .txt, .jsonl, or .csv.")
80
+
81
+
82
+ def _load_txt_prompts(path: Path) -> list[PromptItem]:
83
+ items: list[PromptItem] = []
84
+ for row_number, line in enumerate(path.read_text(encoding="utf-8").splitlines()):
85
+ prompt = line.strip()
86
+ if not prompt:
87
+ continue
88
+ items.append(PromptItem(prompt_id=str(len(items) + 1), row_number=row_number, prompt=prompt))
89
+ return items
90
+
91
+
92
+ def _load_jsonl_prompts(path: Path) -> list[PromptItem]:
93
+ items: list[PromptItem] = []
94
+ with path.open(encoding="utf-8") as f:
95
+ for row_number, line in enumerate(f):
96
+ raw = line.strip()
97
+ if not raw:
98
+ continue
99
+ parsed = json.loads(raw)
100
+ if isinstance(parsed, str):
101
+ prompt = parsed.strip()
102
+ prompt_id = str(len(items) + 1)
103
+ metadata: dict[str, Any] = {}
104
+ elif isinstance(parsed, dict):
105
+ prompt = str(parsed.get("prompt") or parsed.get("Prompt") or "").strip()
106
+ prompt_id = str(parsed.get("id") or parsed.get("prompt_id") or len(items) + 1)
107
+ metadata = {key: value for key, value in parsed.items() if key not in {"prompt", "Prompt"}}
108
+ else:
109
+ raise ValueError(f"JSONL row {row_number + 1} must be a string or object.")
110
+ if prompt:
111
+ items.append(PromptItem(prompt_id=prompt_id, row_number=row_number, prompt=prompt, metadata=metadata))
112
+ return items
113
+
114
+
115
+ def _load_csv_prompts(path: Path) -> list[PromptItem]:
116
+ items: list[PromptItem] = []
117
+ with path.open(newline="", encoding="utf-8") as f:
118
+ reader = csv.DictReader(f)
119
+ for row_number, row in enumerate(reader):
120
+ prompt_key = _first_present_key(row, PROMPT_COLUMNS)
121
+ if prompt_key is None:
122
+ raise ValueError(f"CSV must include one of these prompt columns: {', '.join(PROMPT_COLUMNS)}.")
123
+ prompt = str(row.get(prompt_key) or "").strip()
124
+ if not prompt:
125
+ continue
126
+ id_key = _first_present_key(row, ID_COLUMNS)
127
+ prompt_id = str(row.get(id_key) or len(items) + 1) if id_key is not None else str(len(items) + 1)
128
+ items.append(PromptItem(prompt_id=prompt_id, row_number=row_number, prompt=prompt, metadata=dict(row)))
129
+ return items
130
+
131
+
132
+ def _first_present_key(row: dict[str, Any], keys: tuple[str, ...]) -> str | None:
133
+ for key in keys:
134
+ if key in row:
135
+ return key
136
+ return None
137
+
138
+
139
+ def _validate_unique_output_dirs(items: list[PromptItem]) -> None:
140
+ seen: dict[str, str] = {}
141
+ for item in items:
142
+ dirname = prompt_dir_name(item)
143
+ previous = seen.get(dirname)
144
+ if previous is not None:
145
+ raise ValueError(f"Prompt ids {previous!r} and {item.prompt_id!r} map to the same output dir {dirname!r}.")
146
+ seen[dirname] = item.prompt_id
147
+
148
+
149
+ def validate_t2i_json(data: dict[str, Any], prompt_id: str) -> None:
150
+ """Validate the minimum structured T2I JSON shape expected by Cosmos3."""
151
+ missing = sorted(REQUIRED_T2I_KEYS - set(data))
152
+ if missing:
153
+ raise ValueError(f"Prompt JSON for {prompt_id} is missing required keys: {missing}")
154
+ if not isinstance(data.get("subjects"), list):
155
+ raise ValueError(f"Prompt JSON for {prompt_id}: subjects must be a list.")
156
+ if not isinstance(data.get("text_and_signage_elements"), list):
157
+ raise ValueError(f"Prompt JSON for {prompt_id}: text_and_signage_elements must be a list.")
158
+ caption = data.get("comprehensive_t2i_caption")
159
+ if not isinstance(caption, str) or not caption.strip():
160
+ raise ValueError(f"Prompt JSON for {prompt_id}: comprehensive_t2i_caption is empty.")
161
+ resolution = data.get("resolution")
162
+ if not isinstance(resolution, dict) or not {"H", "W"}.issubset(resolution):
163
+ raise ValueError(f"Prompt JSON for {prompt_id}: resolution must contain H and W.")
164
+ aspect_ratio = data.get("aspect_ratio")
165
+ if not isinstance(aspect_ratio, str) or not aspect_ratio.strip():
166
+ raise ValueError(f"Prompt JSON for {prompt_id}: aspect_ratio must be a non-empty string.")
167
+
agentic_upsampling/extract_best.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Extract best agentic upsampling images from an output directory."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import csv
7
+ import json
8
+ import shutil
9
+ from pathlib import Path
10
+ from typing import Any
11
+
12
+ from agentic_upsampling.io_utils import append_jsonl
13
+
14
+ IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp"}
15
+
16
+
17
+ def parse_args() -> argparse.Namespace:
18
+ parser = argparse.ArgumentParser(description=__doc__)
19
+ parser.add_argument("--output-dir", type=Path, required=True, help="Agentic upsampler run output directory.")
20
+ parser.add_argument(
21
+ "--export-dir",
22
+ type=Path,
23
+ default=None,
24
+ help="Directory for copied best images and manifests. Defaults to OUTPUT_DIR/best_generations.",
25
+ )
26
+ parser.add_argument("--overwrite", action="store_true", help="Replace existing copied images/manifests.")
27
+ return parser.parse_args()
28
+
29
+
30
+ def iter_best_jsons(output_dir: Path) -> list[Path]:
31
+ """Return per-prompt best.json files in deterministic order."""
32
+ return sorted(path for path in output_dir.glob("*/best.json") if path.parent.name != "best_generations")
33
+
34
+
35
+ def resolve_image_path(raw_path: str, *, output_dir: Path, best_json_path: Path) -> Path:
36
+ """Resolve image paths written by runs launched with relative or absolute output dirs."""
37
+ image_path = Path(raw_path)
38
+ candidates = [image_path]
39
+ if not image_path.is_absolute():
40
+ candidates.extend(
41
+ [
42
+ output_dir / image_path,
43
+ output_dir.parent / image_path,
44
+ best_json_path.parent / image_path.name,
45
+ ]
46
+ )
47
+ for candidate in candidates:
48
+ if candidate.exists():
49
+ return candidate
50
+ raise FileNotFoundError(f"Best image does not exist: {raw_path}")
51
+
52
+
53
+ def copied_image_name(record: dict[str, Any], image_path: Path) -> str:
54
+ """Build a simple copied image filename."""
55
+ prompt_id = str(record["prompt_id"])
56
+ suffix = image_path.suffix.lower()
57
+ if suffix not in IMAGE_SUFFIXES:
58
+ suffix = ".jpg"
59
+ return f"{prompt_id}{suffix}"
60
+
61
+
62
+ def extract_record(best_json_path: Path, *, output_dir: Path, images_dir: Path, overwrite: bool) -> dict[str, Any]:
63
+ """Copy one best image and return its export manifest record."""
64
+ best_data = json.loads(best_json_path.read_text(encoding="utf-8"))
65
+ if not isinstance(best_data, dict):
66
+ raise ValueError(f"{best_json_path} must contain a JSON object.")
67
+ best = best_data.get("best")
68
+ if not isinstance(best, dict):
69
+ raise ValueError(f"{best_json_path} is missing best candidate metadata.")
70
+ raw_image_path = str(best.get("image_path") or "")
71
+ if not raw_image_path:
72
+ raise ValueError(f"{best_json_path} best candidate is missing image_path.")
73
+ image_path = resolve_image_path(raw_image_path, output_dir=output_dir, best_json_path=best_json_path)
74
+ record = {
75
+ "prompt_id": str(best_data["prompt_id"]),
76
+ "prompt": str(best_data.get("prompt") or ""),
77
+ "best_score": best_data.get("best_score"),
78
+ "best_iteration": best_data.get("best_iteration"),
79
+ "selected_sample_index": best.get("selected_sample_index", best.get("sample_index")),
80
+ "threshold_cleared_any": bool(best_data.get("threshold_cleared_any")),
81
+ "source_image_path": str(image_path),
82
+ "best_json_path": str(best_json_path),
83
+ "analysis_path": str(best.get("analysis_path") or ""),
84
+ }
85
+ dest_path = images_dir / copied_image_name(record, image_path)
86
+ if dest_path.exists() and not overwrite:
87
+ raise FileExistsError(f"Refusing to overwrite existing image: {dest_path}")
88
+ images_dir.mkdir(parents=True, exist_ok=True)
89
+ shutil.copy2(image_path, dest_path)
90
+ record["copied_image_path"] = str(dest_path)
91
+ return record
92
+
93
+
94
+ def write_csv(path: Path, records: list[dict[str, Any]]) -> None:
95
+ """Write a flat CSV summary for quick spreadsheet inspection."""
96
+ fieldnames = [
97
+ "prompt_id",
98
+ "best_score",
99
+ "best_iteration",
100
+ "selected_sample_index",
101
+ "threshold_cleared_any",
102
+ "copied_image_path",
103
+ "source_image_path",
104
+ "best_json_path",
105
+ "analysis_path",
106
+ "prompt",
107
+ ]
108
+ with path.open("w", newline="", encoding="utf-8") as f:
109
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
110
+ writer.writeheader()
111
+ for record in records:
112
+ writer.writerow({key: record.get(key, "") for key in fieldnames})
113
+
114
+
115
+ def extract_best_images(output_dir: Path, export_dir: Path, *, overwrite: bool = False) -> list[dict[str, Any]]:
116
+ """Copy best images from a run and write JSONL/CSV manifests."""
117
+ output_dir = output_dir.expanduser()
118
+ export_dir = export_dir.expanduser()
119
+ if not output_dir.exists():
120
+ raise FileNotFoundError(f"Missing output directory: {output_dir}")
121
+ best_jsons = iter_best_jsons(output_dir)
122
+ if not best_jsons:
123
+ raise RuntimeError(f"No per-prompt best.json files found under {output_dir}")
124
+
125
+ images_dir = export_dir / "images"
126
+ manifest_path = export_dir / "best_generations.jsonl"
127
+ csv_path = export_dir / "best_generations.csv"
128
+ if overwrite:
129
+ manifest_path.unlink(missing_ok=True)
130
+ csv_path.unlink(missing_ok=True)
131
+ elif manifest_path.exists() or csv_path.exists():
132
+ raise FileExistsError(f"Export manifests already exist in {export_dir}; pass --overwrite to replace them.")
133
+
134
+ records: list[dict[str, Any]] = []
135
+ for best_json_path in best_jsons:
136
+ record = extract_record(best_json_path, output_dir=output_dir, images_dir=images_dir, overwrite=overwrite)
137
+ records.append(record)
138
+ append_jsonl(manifest_path, record)
139
+ write_csv(csv_path, records)
140
+ return records
141
+
142
+
143
+ def main() -> int:
144
+ args = parse_args()
145
+ export_dir = args.export_dir or (args.output_dir / "best_generations")
146
+ records = extract_best_images(args.output_dir, export_dir, overwrite=args.overwrite)
147
+ print(f"Exported {len(records)} best images to {export_dir}", flush=True)
148
+ print(f"Images: {export_dir / 'images'}", flush=True)
149
+ print(f"JSONL: {export_dir / 'best_generations.jsonl'}", flush=True)
150
+ print(f"CSV: {export_dir / 'best_generations.csv'}", flush=True)
151
+ return 0
152
+
153
+
154
+ if __name__ == "__main__":
155
+ raise SystemExit(main())
agentic_upsampling/io_utils.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Small JSON and file helpers for agentic upsampling runs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ import tempfile
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+
12
+ def write_json_atomic(path: Path, data: Any, *, ensure_ascii: bool = True) -> None:
13
+ """Write JSON through a temporary file and atomically replace the destination."""
14
+ path.parent.mkdir(parents=True, exist_ok=True)
15
+ fd, tmp_name = tempfile.mkstemp(prefix=f".{path.name}.", suffix=".tmp", dir=path.parent)
16
+ try:
17
+ with os.fdopen(fd, "w", encoding="utf-8") as f:
18
+ json.dump(data, f, ensure_ascii=ensure_ascii, indent=2)
19
+ f.write("\n")
20
+ Path(tmp_name).replace(path)
21
+ except Exception:
22
+ try:
23
+ Path(tmp_name).unlink(missing_ok=True)
24
+ finally:
25
+ raise
26
+
27
+
28
+ def append_jsonl(path: Path, data: Any, *, ensure_ascii: bool = True) -> None:
29
+ """Append one compact JSON record to a JSONL file."""
30
+ path.parent.mkdir(parents=True, exist_ok=True)
31
+ with path.open("a", encoding="utf-8") as f:
32
+ f.write(json.dumps(data, ensure_ascii=ensure_ascii, separators=(",", ":")) + "\n")
33
+
34
+
35
+ def read_json(path: Path) -> dict[str, Any]:
36
+ """Read a JSON object from disk."""
37
+ data = json.loads(path.read_text(encoding="utf-8"))
38
+ if not isinstance(data, dict):
39
+ raise ValueError(f"{path} must contain a JSON object.")
40
+ return data
41
+
42
+
43
+ def compact_json(data: dict[str, Any], *, ensure_ascii: bool = True) -> str:
44
+ """Serialize JSON using the compact prompt format expected by the generation endpoint."""
45
+ return json.dumps(data, ensure_ascii=ensure_ascii, separators=(",", ":"))
46
+
agentic_upsampling/prompt_upsampler.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenAI-compatible text-to-image prompt upsampling client."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import logging
7
+ import os
8
+ import re
9
+ import time
10
+ from collections.abc import Callable
11
+ from dataclasses import dataclass
12
+ from typing import Any
13
+
14
+ import requests
15
+ from requests.adapters import HTTPAdapter
16
+ from urllib3.util.retry import Retry
17
+
18
+ from agentic_upsampling.constants import (
19
+ DEFAULT_LLM_EXTRA_BODY,
20
+ DEFAULT_UPSAMPLER_ENDPOINT_URL,
21
+ DEFAULT_UPSAMPLER_MODEL,
22
+ )
23
+ from agentic_upsampling.data import validate_t2i_json
24
+
25
+ JSON_ENSURE_ASCII = bool(int(os.environ.get("JSON_ENSURE_ASCII", "1")))
26
+ DEFAULT_USER_AGENT = "Cosmos3-Super-Text2Image-Agentic-Upsampling/1.0"
27
+ SYSTEM_MESSAGE: dict[str, Any] = {
28
+ "role": "system",
29
+ "content": [{"type": "text", "text": "You are a helpful assistant."}],
30
+ }
31
+ log = logging.getLogger(__name__)
32
+
33
+ RESOLUTION_RATIO_DICT: dict[str, dict[str, dict[str, int]]] = {
34
+ "256": {
35
+ "1,1": {"W": 256, "H": 256},
36
+ "4,3": {"W": 320, "H": 256},
37
+ "3,4": {"W": 256, "H": 320},
38
+ "16,9": {"W": 320, "H": 192},
39
+ "9,16": {"W": 192, "H": 320},
40
+ },
41
+ "480": {
42
+ "1,1": {"W": 640, "H": 640},
43
+ "4,3": {"W": 736, "H": 544},
44
+ "3,4": {"W": 544, "H": 736},
45
+ "16,9": {"W": 832, "H": 480},
46
+ "9,16": {"W": 480, "H": 832},
47
+ },
48
+ "720": {
49
+ "1,1": {"W": 960, "H": 960},
50
+ "4,3": {"W": 1104, "H": 832},
51
+ "3,4": {"W": 832, "H": 1104},
52
+ "16,9": {"W": 1280, "H": 720},
53
+ "9,16": {"W": 720, "H": 1280},
54
+ },
55
+ "768": {
56
+ "1,1": {"W": 1024, "H": 1024},
57
+ "4,3": {"W": 1184, "H": 880},
58
+ "3,4": {"W": 880, "H": 1184},
59
+ "16,9": {"W": 1360, "H": 768},
60
+ "9,16": {"W": 768, "H": 1360},
61
+ },
62
+ }
63
+
64
+ T2I_JSON_TEMPLATE = """Given the user's natural-language request below, generate a dense structured JSON that fully describes the image to be produced. The JSON must strictly follow the template provided after the request, including every top-level key and every nested sub-field.
65
+
66
+ The output is always dense. Even when the request is brief, infer plausible, scene-consistent details for every field. Do not leave fields empty merely because the request did not mention them. Be creative but stay grounded: additions must be physically plausible and internally consistent with the request.
67
+
68
+ Requirements:
69
+ - Extract visual intent from the user request into the visual fields.
70
+ - For every visual field, write rich, specific content inferred from the request's scene, subjects, mood, and context.
71
+ - Empty values ("", 0, [], {{}}) are permitted only for truly inapplicable fields.
72
+ - Do not add keys beyond the template. Do not omit keys required by the template.
73
+ - Return only the JSON object. Do not include markdown fences or prose outside JSON.
74
+
75
+ USER VISUAL REQUEST:
76
+ {caption_dense}
77
+
78
+ Lists may contain zero or more items of the shape shown. All top-level keys must always be present in the output; fill unused fields with "", 0, {{}}, or [] as appropriate.
79
+
80
+ {{
81
+ "subjects": [
82
+ {{
83
+ "description": "full visual description of the subject",
84
+ "appearance_details": "additional visual details such as accessories, texture, and distinguishing features",
85
+ "relationship": "how this subject relates to others or to the scene",
86
+ "location": "where in frame, for example center foreground or top right",
87
+ "relative_size": "size within frame",
88
+ "orientation": "direction subject faces relative to camera",
89
+ "pose": "body position and posture",
90
+ "clothing": "clothing and accessories; empty string if non-human or not applicable",
91
+ "expression": "facial expression; empty string if non-human or not applicable",
92
+ "gender": "Male, Female, Unknown, or N/A",
93
+ "age": "age category",
94
+ "skin_tone_and_texture": "skin tone description; empty string if non-human",
95
+ "facial_features": "notable facial features; empty string if non-human or not visible",
96
+ "number_of_subjects": "int; total in this subject group, 0 if not applicable",
97
+ "number_of_arms": "int; 2 for humans, 0 if non-human",
98
+ "number_of_legs": "int; 2 for humans, 0 if non-human",
99
+ "number_of_hands": "int; 2 for humans, 0 if non-human",
100
+ "number_of_fingers": "int; 10 for humans, 0 if non-human"
101
+ }}
102
+ ],
103
+ "subject_details": {{
104
+ "key_name_1": "free-form image-specific attribute; empty object if not applicable"
105
+ }},
106
+ "background_setting": "full prose description of the environment and setting",
107
+ "lighting": {{
108
+ "conditions": "type and quality of light",
109
+ "direction": "where light comes from; None for flat digital images",
110
+ "shadows": "shadow description; None for flat digital images",
111
+ "illumination_effect": "overall effect of the lighting"
112
+ }},
113
+ "aesthetics": {{
114
+ "composition": "framing and compositional choices",
115
+ "color_scheme": "dominant colors and palette",
116
+ "mood_atmosphere": "emotional atmosphere in short phrases",
117
+ "patterns": "notable repeating visual patterns; None if none"
118
+ }},
119
+ "cinematography": {{
120
+ "framing": "shot type",
121
+ "camera_angle": "angle such as Eye-level, Low angle, or High angle",
122
+ "depth_of_field": "Shallow, Deep, Uniform focus, or N/A",
123
+ "focus": "what is in sharp focus",
124
+ "lens_focal_length": "descriptive focal length"
125
+ }},
126
+ "style_medium": "visual medium, for example Photography, Digital illustration, or Screenshot",
127
+ "artistic_style": "genre or approach",
128
+ "context": "scene context or use case",
129
+ "text_and_signage_elements": [
130
+ {{
131
+ "text": "the visible text content",
132
+ "category": "physical_in_scene, ui_text, body_text, scene_sign, logo, or label",
133
+ "appearance": "font, color, size, style",
134
+ "spatial": "position in image",
135
+ "context": "purpose or meaning of the text"
136
+ }}
137
+ ],
138
+ "quadrant_scan": {{
139
+ "top_left": "description of what appears in the top-left region",
140
+ "top_right": "description of what appears in the top-right region",
141
+ "bottom_left": "description of what appears in the bottom-left region",
142
+ "bottom_right": "description of what appears in the bottom-right region",
143
+ "absolute_center": "description of what appears at the center"
144
+ }},
145
+ "comprehensive_t2i_caption": "a comprehensive, full-scene natural-language prose description of the image",
146
+ "resolution": {{
147
+ "H": "will be overwritten by the selected resolution and aspect ratio",
148
+ "W": "will be overwritten by the selected resolution and aspect ratio"
149
+ }},
150
+ "aspect_ratio": "will be overwritten by the selected aspect ratio"
151
+ }}"""
152
+
153
+
154
+ @dataclass(slots=True)
155
+ class ChatClientConfig:
156
+ """Configuration for an OpenAI-compatible chat-completions endpoint."""
157
+
158
+ endpoint_url: str
159
+ model: str
160
+ api_token: str
161
+ timeout_s: float = 300.0
162
+ max_tokens: int = 8192
163
+ max_retries: int = 3
164
+ retry_base_delay_s: float = 1.0
165
+ extra_body: dict[str, Any] | None = None
166
+ connection_max_retries: int = 2
167
+ connection_pool_size: int = 4
168
+
169
+
170
+ class OpenAIChatClient:
171
+ """Small synchronous OpenAI-compatible chat-completions client."""
172
+
173
+ config: ChatClientConfig
174
+ base_url: str
175
+ session: requests.Session
176
+ sleep: Callable[[float], None]
177
+
178
+ def __init__(
179
+ self,
180
+ config: ChatClientConfig,
181
+ *,
182
+ session: requests.Session | None = None,
183
+ sleep: Callable[[float], None] = time.sleep,
184
+ ) -> None:
185
+ self.config = config
186
+ self.base_url = normalize_openai_base_url(config.endpoint_url)
187
+ self.session = _make_session(config) if session is None else session
188
+ self.sleep = sleep
189
+
190
+ def complete(self, messages: list[dict[str, Any]], *, response_format_json: bool = False) -> str:
191
+ """Request one chat completion and return assistant text."""
192
+
193
+ def _call() -> str:
194
+ payload: dict[str, Any] = {
195
+ "model": self.config.model,
196
+ "messages": messages,
197
+ self._max_tokens_key(): self.config.max_tokens,
198
+ }
199
+ if response_format_json:
200
+ payload["response_format"] = {"type": "json_object"}
201
+ if self.config.extra_body:
202
+ payload.update(self.config.extra_body)
203
+ parsed = self._request_json("POST", f"{self.base_url}/chat/completions", payload=payload)
204
+ choices = parsed.get("choices")
205
+ if not isinstance(choices, list) or not choices:
206
+ raise ValueError("Chat completion response missing choices.")
207
+ first_choice = choices[0]
208
+ if not isinstance(first_choice, dict):
209
+ raise ValueError("Chat completion choice must be an object.")
210
+ message = first_choice.get("message")
211
+ if not isinstance(message, dict):
212
+ raise ValueError("Chat completion choice missing message.")
213
+ return _message_content_to_text(message.get("content"))
214
+
215
+ return self._with_retries("complete chat request", _call)
216
+
217
+ def _request_json(self, method: str, url: str, *, payload: dict[str, Any] | None = None) -> dict[str, Any]:
218
+ headers = {"Accept": "application/json", "User-Agent": DEFAULT_USER_AGENT}
219
+ if payload is not None:
220
+ headers["Content-Type"] = "application/json"
221
+ if self.config.api_token:
222
+ headers["Authorization"] = f"Bearer {self.config.api_token}"
223
+ try:
224
+ response = self.session.request(method, url, json=payload, headers=headers, timeout=self.config.timeout_s)
225
+ except requests.RequestException as exc:
226
+ raise RuntimeError(f"Failed to reach {url}: {exc}") from exc
227
+ if not response.ok:
228
+ raise RuntimeError(f"HTTP {response.status_code} from {url}: {response.text[:1000]}")
229
+ parsed = response.json()
230
+ if not isinstance(parsed, dict):
231
+ raise RuntimeError(f"Response from {url} must be a JSON object.")
232
+ return parsed
233
+
234
+ def _with_retries(self, operation: str, fn: Callable[[], str]) -> str:
235
+ if self.config.max_retries < 1:
236
+ raise ValueError("max_retries must be >= 1.")
237
+ last_exc: Exception | None = None
238
+ for attempt in range(self.config.max_retries):
239
+ try:
240
+ return fn()
241
+ except Exception as exc:
242
+ last_exc = exc
243
+ if attempt == self.config.max_retries - 1:
244
+ break
245
+ self.sleep(self.config.retry_base_delay_s * (2**attempt))
246
+ raise RuntimeError(f"Failed to {operation} after {self.config.max_retries} attempts: {last_exc}") from last_exc
247
+
248
+ def _max_tokens_key(self) -> str:
249
+ if "api.openai.com" in self.base_url:
250
+ return "max_completion_tokens"
251
+ return "max_tokens"
252
+
253
+
254
+ class Text2ImagePromptUpsampler:
255
+ """Create structured Cosmos3 text-to-image JSON prompts from user text."""
256
+
257
+ chat_client: OpenAIChatClient
258
+
259
+ def __init__(self, chat_client: OpenAIChatClient) -> None:
260
+ self.chat_client = chat_client
261
+
262
+ @classmethod
263
+ def from_defaults(
264
+ cls,
265
+ *,
266
+ api_token: str,
267
+ endpoint_url: str = DEFAULT_UPSAMPLER_ENDPOINT_URL,
268
+ model: str = DEFAULT_UPSAMPLER_MODEL,
269
+ extra_body: dict[str, Any] | None = None,
270
+ ) -> Text2ImagePromptUpsampler:
271
+ """Build the default GPT-5.5 based T2I prompt upsampler."""
272
+ return cls(
273
+ OpenAIChatClient(
274
+ ChatClientConfig(
275
+ endpoint_url=endpoint_url,
276
+ model=model,
277
+ api_token=api_token,
278
+ extra_body=DEFAULT_LLM_EXTRA_BODY if extra_body is None else extra_body,
279
+ )
280
+ )
281
+ )
282
+
283
+ def upsample(
284
+ self,
285
+ prompt: str,
286
+ *,
287
+ prompt_id: str,
288
+ resolution: str,
289
+ aspect_ratio: str,
290
+ user_prompt: str | None = None,
291
+ ) -> dict[str, Any]:
292
+ """Return a validated structured T2I JSON prompt."""
293
+ messages = build_t2i_messages(prompt, user_prompt=user_prompt)
294
+ raw = self.chat_client.complete(messages, response_format_json=True)
295
+ data = apply_t2i_output_parameters(extract_json_object(raw), resolution=resolution, aspect_ratio=aspect_ratio)
296
+ validate_t2i_json(data, prompt_id)
297
+ return data
298
+
299
+
300
+ def build_t2i_messages(prompt: str, *, user_prompt: str | None = None) -> list[dict[str, Any]]:
301
+ """Build chat messages for the initial structured prompt upsampling request."""
302
+ message_text = user_prompt or T2I_JSON_TEMPLATE.format(caption_dense=prompt.strip())
303
+ return [
304
+ SYSTEM_MESSAGE,
305
+ {
306
+ "role": "user",
307
+ "content": [{"type": "text", "text": message_text}],
308
+ },
309
+ ]
310
+
311
+
312
+ def apply_t2i_output_parameters(data: dict[str, Any], *, resolution: str, aspect_ratio: str) -> dict[str, Any]:
313
+ """Overwrite output metadata with the selected T2I canvas parameters."""
314
+ if resolution not in RESOLUTION_RATIO_DICT:
315
+ raise ValueError(f"Unsupported resolution {resolution!r}.")
316
+ if aspect_ratio not in RESOLUTION_RATIO_DICT[resolution]:
317
+ raise ValueError(f"Unsupported aspect_ratio {aspect_ratio!r} for resolution {resolution!r}.")
318
+ resolution_pair = RESOLUTION_RATIO_DICT[resolution][aspect_ratio]
319
+ data["resolution"] = {"H": resolution_pair["H"], "W": resolution_pair["W"]}
320
+ data["aspect_ratio"] = aspect_ratio
321
+ return data
322
+
323
+
324
+ def extract_json_object(text: str) -> dict[str, Any]:
325
+ """Extract a JSON object from raw model text."""
326
+ cleaned = text.strip()
327
+ fence_match = re.search(r"```(?:json)?\s*(.*?)\s*```", cleaned, flags=re.DOTALL)
328
+ if fence_match:
329
+ cleaned = fence_match.group(1).strip()
330
+ start = cleaned.find("{")
331
+ end = cleaned.rfind("}")
332
+ if start < 0 or end < start:
333
+ raise ValueError("Model response did not contain a JSON object.")
334
+ parsed = json.loads(cleaned[start : end + 1])
335
+ if not isinstance(parsed, dict):
336
+ raise ValueError("Model response JSON must be an object.")
337
+ return parsed
338
+
339
+
340
+ def normalize_openai_base_url(url: str) -> str:
341
+ """Normalize an OpenAI-compatible endpoint root."""
342
+ normalized = url.strip().rstrip("/")
343
+ if not normalized:
344
+ raise ValueError("endpoint_url cannot be empty.")
345
+ if not normalized.startswith(("http://", "https://")):
346
+ normalized = f"https://{normalized}"
347
+ if normalized.endswith("/chat/completions"):
348
+ normalized = normalized[: -len("/chat/completions")]
349
+ if normalized.endswith("/v1") or normalized.endswith("/openai"):
350
+ return normalized
351
+ return f"{normalized}/v1"
352
+
353
+
354
+ def _make_session(config: ChatClientConfig) -> requests.Session:
355
+ session = requests.Session()
356
+ retry = Retry(
357
+ total=config.connection_max_retries,
358
+ connect=config.connection_max_retries,
359
+ read=0,
360
+ status=2,
361
+ status_forcelist=(429, 500, 502, 503, 504),
362
+ allowed_methods=frozenset({"GET", "POST"}),
363
+ backoff_factor=0.5,
364
+ raise_on_status=False,
365
+ )
366
+ adapter = HTTPAdapter(
367
+ pool_connections=config.connection_pool_size,
368
+ pool_maxsize=config.connection_pool_size,
369
+ max_retries=retry,
370
+ )
371
+ session.mount("https://", adapter)
372
+ session.mount("http://", adapter)
373
+ return session
374
+
375
+
376
+ def _message_content_to_text(content: Any) -> str:
377
+ if isinstance(content, str) and content.strip():
378
+ return content
379
+ if isinstance(content, list):
380
+ parts: list[str] = []
381
+ for item in content:
382
+ if isinstance(item, dict) and item.get("type") == "text" and isinstance(item.get("text"), str):
383
+ parts.append(item["text"])
384
+ text = "".join(parts).strip()
385
+ if text:
386
+ return text
387
+ raise ValueError("Chat completion message content is empty or unsupported.")
388
+
agentic_upsampling/rubric.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """VLM critic prompt and score normalization for agentic T2I upsampling."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from typing import Any
7
+
8
+ from agentic_upsampling.constants import STRICT_OVERALL_THRESHOLD, STRICT_PROMPT_THRESHOLD
9
+ from agentic_upsampling.data import PromptItem
10
+ from agentic_upsampling.prompt_upsampler import extract_json_object
11
+
12
+ CATEGORY_SECTIONS = {
13
+ "text_commercial_ui": (
14
+ "Text/commercial/UI/logo checks: readable text for logos, labels, posters, "
15
+ "billboards, product packaging, or UI. Verify exact quoted strings, spelling, legibility, typography, "
16
+ "placement, layout, and whether commercial/UI intent is visually clear."
17
+ ),
18
+ "people_anatomy": (
19
+ "People/anatomy checks: if humans, human-like characters, body parts, portraits, or poses are present or "
20
+ "required by the prompt, inspect faces, eyes, hands, fingers, limbs, pose, proportions, expression, "
21
+ "clothing coherence, and physically possible interactions."
22
+ ),
23
+ "fantasy_cartoon_vector": (
24
+ "Fantasy/cartoon/vector/pixel-art checks: if a stylized medium is requested, judge whether stylization is "
25
+ "intentional and clean. Penalize messy geometry, inconsistent line language, broken vector shapes, muddy "
26
+ "palettes, and unwanted photorealistic texture."
27
+ ),
28
+ "photorealistic_physical": (
29
+ "Photorealistic/physical checks: if realism, physical objects, geometry, camera behavior, reflections, "
30
+ "transparent materials, shadows, perspective, scale, or contact matter, judge material realism, lighting "
31
+ "physics, lens plausibility, and whether objects obey real-world physical constraints."
32
+ ),
33
+ "general_scene": (
34
+ "General scene checks: always judge object completeness, layout clarity, subject relationships, background "
35
+ "coherence, visual appeal, and absence of obvious AI artifacts."
36
+ ),
37
+ }
38
+
39
+ SCORE_KEYS = (
40
+ "prompt_adherence_score",
41
+ "visual_quality_score",
42
+ "aesthetics_score",
43
+ "physical_plausibility_score",
44
+ "category_score",
45
+ "overall_score",
46
+ )
47
+ ISSUE_SEVERITIES = {"minor", "moderate", "severe"}
48
+
49
+
50
+ def all_category_check_text() -> str:
51
+ """Return the full non-classifying category checklist."""
52
+ return "\n".join(f"- {text}" for text in CATEGORY_SECTIONS.values())
53
+
54
+
55
+ def build_judge_prompt(item: PromptItem) -> str:
56
+ """Build the VLM critic prompt using the original user prompt as task context."""
57
+ return f"""You are an expert image quality analyst specializing in AI-generated image evaluation.
58
+ Your job is to produce an exhaustive defect report. Be meticulous: go beyond obvious problems and look carefully for subtle or background issues too.
59
+
60
+ The attached image was generated by an AI image model.
61
+
62
+ Analyze this image carefully and list every quality issue you observe.
63
+ For each issue give an approximate location and name the specific object or region involved. Report each distinct occurrence separately.
64
+
65
+ Before finalizing, check these areas, but only report issues you actually see:
66
+ - Physics: gravity violations, impossible collisions, implausible trajectories.
67
+ - Object deformation: morphing, melting, stretching of solid objects.
68
+ - Anatomy: distorted hands, faces, fingers, limbs, or wrong body proportions.
69
+ - Lighting and shadows: missing shadows or inconsistent illumination.
70
+ - Depth and scale: wrong spatial relationships, perspective issues, or scale inconsistencies.
71
+ - Text and numbers: garbled, floating, or incorrect text and digits.
72
+ - Visual quality: blur patches, noise, compression blocking, visual artifacts, or low-resolution regions.
73
+ - Color: inconsistent coloration, bleeding, or banding.
74
+ - Action correctness: prompted actions are correctly displayed.
75
+ - Prompt following: missing subjects, wrong objects, wrong setting, or wrong action.
76
+
77
+ Depending on the prompt, also apply the relevant checks below:
78
+ {all_category_check_text()}
79
+
80
+ The attached image was generated from this prompt:
81
+ {item.prompt}
82
+
83
+ Return exactly one JSON object, no markdown fences and no prose outside JSON:
84
+ {{
85
+ "prompt_adherence_score": <number 0-10>,
86
+ "visual_quality_score": <number 0-10>,
87
+ "aesthetics_score": <number 0-10>,
88
+ "physical_plausibility_score": <number 0-10>,
89
+ "category_score": <number 0-10>,
90
+ "text_rendering_score": <number 0-10 or null>,
91
+ "photorealism_score": <number 0-10 or null>,
92
+ "overall_score": <number 0-10>,
93
+ "issues": [
94
+ {{
95
+ "category": "<concise label>",
96
+ "description": "<what failed and where in the image>",
97
+ "severity": "minor" | "moderate" | "severe"
98
+ }}
99
+ ],
100
+ "prompt_elements": {{
101
+ "<key noun or action from the prompt>": "present" | "absent" | "partial"
102
+ }},
103
+ "category_findings": {{"<check area>": "<concise finding>"}},
104
+ "improvement_directives": ["<specific prompt rewrite instruction>"],
105
+ "rationale": "<2-4 concise sentences>"
106
+ }}
107
+ """
108
+
109
+
110
+ def parse_analysis_response(text: str) -> dict[str, Any]:
111
+ """Parse and normalize a raw VLM scoring response."""
112
+ return normalize_analysis(extract_json_object(text))
113
+
114
+
115
+ def normalize_analysis(data: dict[str, Any]) -> dict[str, Any]:
116
+ """Normalize VLM analysis into the schema used by selection and reporting."""
117
+ normalized = dict(data)
118
+ for key in SCORE_KEYS:
119
+ normalized[key] = _score(normalized.get(key))
120
+ for optional_key in ("text_rendering_score", "photorealism_score"):
121
+ if normalized.get(optional_key) is not None:
122
+ normalized[optional_key] = _score(normalized.get(optional_key))
123
+
124
+ normalized["issues"] = _normalize_issues(normalized.get("issues"))
125
+
126
+ directives = normalized.get("improvement_directives")
127
+ if isinstance(directives, list):
128
+ normalized["improvement_directives"] = [str(item) for item in directives if str(item).strip()]
129
+ else:
130
+ normalized["improvement_directives"] = []
131
+
132
+ findings = normalized.get("category_findings")
133
+ normalized["category_findings"] = findings if isinstance(findings, dict) else {}
134
+ normalized["threshold_cleared"] = clears_strict_threshold(normalized)
135
+ return normalized
136
+
137
+
138
+ def clears_strict_threshold(analysis: dict[str, Any]) -> bool:
139
+ """Return whether a candidate clears the strict quality milestone."""
140
+ if _score(analysis.get("overall_score")) < STRICT_OVERALL_THRESHOLD:
141
+ return False
142
+ if _score(analysis.get("prompt_adherence_score")) < STRICT_PROMPT_THRESHOLD:
143
+ return False
144
+ if _has_severe_issue(analysis.get("issues")):
145
+ return False
146
+ if analysis.get("text_rendering_score") is not None:
147
+ return _score(analysis.get("text_rendering_score")) >= STRICT_PROMPT_THRESHOLD
148
+ return True
149
+
150
+
151
+ def candidate_sort_key(candidate: dict[str, Any]) -> tuple[float, float, float, float, float, int]:
152
+ """Sort key for picking the best candidate."""
153
+ analysis = candidate.get("analysis", {})
154
+ iteration = int(candidate.get("iteration", 0))
155
+ return (
156
+ _score(analysis.get("overall_score")),
157
+ _score(analysis.get("prompt_adherence_score")),
158
+ _score(analysis.get("category_score")),
159
+ _score(analysis.get("visual_quality_score")),
160
+ _score(analysis.get("aesthetics_score")),
161
+ -iteration,
162
+ )
163
+
164
+
165
+ def compact_analysis_for_rewrite(analysis: dict[str, Any]) -> dict[str, Any]:
166
+ """Return the VLM fields most useful for the next prompt rewrite."""
167
+ keys = (
168
+ "overall_score",
169
+ "prompt_adherence_score",
170
+ "visual_quality_score",
171
+ "aesthetics_score",
172
+ "physical_plausibility_score",
173
+ "category_score",
174
+ "text_rendering_score",
175
+ "photorealism_score",
176
+ "issues",
177
+ "prompt_elements",
178
+ "category_findings",
179
+ "improvement_directives",
180
+ "rationale",
181
+ )
182
+ return {key: analysis.get(key) for key in keys if key in analysis}
183
+
184
+
185
+ def analysis_json_text(data: dict[str, Any]) -> str:
186
+ """Serialize compact analysis for prompt inclusion."""
187
+ return json.dumps(data, ensure_ascii=True, indent=2)
188
+
189
+
190
+ def _score(value: Any) -> float:
191
+ if value is None:
192
+ return 0.0
193
+ try:
194
+ number = float(value)
195
+ except (TypeError, ValueError):
196
+ return 0.0
197
+ return max(0.0, min(10.0, number))
198
+
199
+
200
+ def _normalize_issues(value: Any) -> list[dict[str, str]]:
201
+ if not isinstance(value, list):
202
+ return []
203
+ issues: list[dict[str, str]] = []
204
+ for item in value:
205
+ if not isinstance(item, dict):
206
+ continue
207
+ description = str(item.get("description") or "").strip()
208
+ if not description:
209
+ continue
210
+ category = str(item.get("category") or "unspecified").strip() or "unspecified"
211
+ severity = str(item.get("severity") or "moderate").strip().lower()
212
+ if severity not in ISSUE_SEVERITIES:
213
+ severity = "moderate"
214
+ issues.append({"category": category, "description": description, "severity": severity})
215
+ return issues
216
+
217
+
218
+ def _has_severe_issue(issues: Any) -> bool:
219
+ return any(isinstance(item, dict) and item.get("severity") == "severe" for item in issues or [])
220
+
agentic_upsampling/run.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CLI for standalone agentic Cosmos3 text-to-image prompt upsampling."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ from pathlib import Path
8
+
9
+ from agentic_upsampling.clients import (
10
+ ImageGenerationClient,
11
+ PromptRewriterClient,
12
+ VLMQualityJudge,
13
+ read_api_token,
14
+ read_optional_generation_auth_key,
15
+ )
16
+ from agentic_upsampling.constants import (
17
+ DEFAULT_ASPECT_RATIO,
18
+ DEFAULT_CRITIC_ENDPOINT_URL,
19
+ DEFAULT_CRITIC_MODEL,
20
+ DEFAULT_FLOW_SHIFT,
21
+ DEFAULT_GENERATION_AUTH_KEY_ENV,
22
+ DEFAULT_GENERATION_EXTRA_ARGS,
23
+ DEFAULT_GENERATION_MODEL,
24
+ DEFAULT_GEMINI_API_KEY_ENV,
25
+ DEFAULT_GUIDANCE,
26
+ DEFAULT_IMAGE_SIZE,
27
+ DEFAULT_LLM_EXTRA_BODY,
28
+ DEFAULT_MAX_ITERATIONS,
29
+ DEFAULT_NUM_STEPS,
30
+ DEFAULT_OPENAI_API_KEY_ENV,
31
+ DEFAULT_RESOLUTION,
32
+ DEFAULT_REWRITER_ENDPOINT_URL,
33
+ DEFAULT_REWRITER_MODEL,
34
+ DEFAULT_SAMPLES_PER_ITERATION,
35
+ DEFAULT_UPSAMPLER_ENDPOINT_URL,
36
+ DEFAULT_UPSAMPLER_MODEL,
37
+ )
38
+ from agentic_upsampling.data import load_prompt_items
39
+ from agentic_upsampling.extract_best import extract_best_images
40
+ from agentic_upsampling.io_utils import write_json_atomic
41
+ from agentic_upsampling.runner import AgenticUpsamplerRunner, RunnerConfig, write_run_manifest
42
+
43
+
44
+ def parse_args() -> argparse.Namespace:
45
+ parser = argparse.ArgumentParser(description=__doc__)
46
+ input_group = parser.add_mutually_exclusive_group(required=True)
47
+ input_group.add_argument("--prompt", default=None, help="Single text prompt to run.")
48
+ input_group.add_argument("--prompts", type=Path, default=None, help="Path to .txt, .jsonl, or .csv prompts.")
49
+ parser.add_argument("--limit", type=int, default=None, help="Optional maximum number of prompts to run.")
50
+ parser.add_argument("--output-dir", type=Path, required=True)
51
+ parser.add_argument("--overwrite", action="store_true")
52
+ parser.add_argument("--max-iterations", type=int, default=DEFAULT_MAX_ITERATIONS)
53
+ parser.add_argument("--samples-per-iteration", type=int, default=DEFAULT_SAMPLES_PER_ITERATION)
54
+ parser.add_argument("--seed-base", type=int, default=None)
55
+ parser.add_argument("--disable-early-stop", action="store_true")
56
+ parser.add_argument("--quiet", action="store_true")
57
+ parser.add_argument("--extract-best", action="store_true", help="Copy best images after the run finishes.")
58
+
59
+ parser.add_argument("--generation-endpoint", required=True)
60
+ parser.add_argument("--generation-model", default=DEFAULT_GENERATION_MODEL)
61
+ parser.add_argument("--size", default=DEFAULT_IMAGE_SIZE, help="vLLM-Omni image size in WIDTHxHEIGHT format.")
62
+ parser.add_argument("--generation-auth-key", default="")
63
+ parser.add_argument("--generation-auth-key-env", default=DEFAULT_GENERATION_AUTH_KEY_ENV)
64
+ parser.add_argument("--resolution", default=DEFAULT_RESOLUTION)
65
+ parser.add_argument("--aspect-ratio", default=DEFAULT_ASPECT_RATIO)
66
+ parser.add_argument("--num-steps", type=int, default=DEFAULT_NUM_STEPS)
67
+ parser.add_argument("--guidance", type=float, default=DEFAULT_GUIDANCE)
68
+ parser.add_argument("--flow-shift", type=float, default=DEFAULT_FLOW_SHIFT)
69
+ parser.add_argument("--generation-extra-args", type=json.loads, default=DEFAULT_GENERATION_EXTRA_ARGS)
70
+
71
+ parser.add_argument("--upsampler-endpoint-url", default=DEFAULT_UPSAMPLER_ENDPOINT_URL)
72
+ parser.add_argument("--upsampler-model", default=DEFAULT_UPSAMPLER_MODEL)
73
+ parser.add_argument("--rewriter-endpoint-url", default=DEFAULT_REWRITER_ENDPOINT_URL)
74
+ parser.add_argument("--rewriter-model", default=DEFAULT_REWRITER_MODEL)
75
+ parser.add_argument("--openai-api-key-env", default=DEFAULT_OPENAI_API_KEY_ENV)
76
+ parser.add_argument("--openai-api-key-file", type=Path, default=None)
77
+ parser.add_argument("--llm-extra-body", type=json.loads, default=DEFAULT_LLM_EXTRA_BODY)
78
+ parser.add_argument("--initial-negative-prompt", default="")
79
+
80
+ parser.add_argument("--critic-endpoint-url", default=DEFAULT_CRITIC_ENDPOINT_URL)
81
+ parser.add_argument("--critic-model", default=DEFAULT_CRITIC_MODEL)
82
+ parser.add_argument("--gemini-api-key-env", default=DEFAULT_GEMINI_API_KEY_ENV)
83
+ parser.add_argument("--gemini-api-key-file", type=Path, default=None)
84
+ return parser.parse_args()
85
+
86
+
87
+ def main() -> int:
88
+ args = parse_args()
89
+ args.output_dir.mkdir(parents=True, exist_ok=True)
90
+
91
+ items = load_prompt_items(prompt=args.prompt, prompts_path=args.prompts, limit=args.limit)
92
+ if not items:
93
+ raise RuntimeError("No prompts selected.")
94
+ if args.samples_per_iteration < 1:
95
+ raise ValueError("--samples-per-iteration must be >= 1.")
96
+ if not isinstance(args.generation_extra_args, dict):
97
+ raise ValueError("--generation-extra-args must decode to a JSON object.")
98
+
99
+ openai_token = read_api_token(args.openai_api_key_env, args.openai_api_key_file)
100
+ gemini_token = read_api_token(args.gemini_api_key_env, args.gemini_api_key_file)
101
+ generation_auth_key = read_optional_generation_auth_key(args.generation_auth_key, args.generation_auth_key_env)
102
+
103
+ write_json_atomic(
104
+ args.output_dir / "run_config.json",
105
+ {
106
+ "selected_prompts": len(items),
107
+ "max_iterations": args.max_iterations,
108
+ "samples_per_iteration": args.samples_per_iteration,
109
+ "early_stop": not args.disable_early_stop,
110
+ "generation_endpoint": args.generation_endpoint,
111
+ "generation_model": args.generation_model,
112
+ "size": args.size,
113
+ "resolution": args.resolution,
114
+ "aspect_ratio": args.aspect_ratio,
115
+ "num_steps": args.num_steps,
116
+ "guidance": args.guidance,
117
+ "flow_shift": args.flow_shift,
118
+ "generation_extra_args": args.generation_extra_args,
119
+ "upsampler_endpoint_url": args.upsampler_endpoint_url,
120
+ "upsampler_model": args.upsampler_model,
121
+ "rewriter_endpoint_url": args.rewriter_endpoint_url,
122
+ "rewriter_model": args.rewriter_model,
123
+ "llm_extra_body": args.llm_extra_body,
124
+ "critic_endpoint_url": args.critic_endpoint_url,
125
+ "critic_model": args.critic_model,
126
+ "initial_negative_prompt": args.initial_negative_prompt,
127
+ },
128
+ )
129
+
130
+ rewriter = PromptRewriterClient(
131
+ api_token=openai_token,
132
+ upsampler_endpoint_url=args.upsampler_endpoint_url,
133
+ upsampler_model=args.upsampler_model,
134
+ rewriter_endpoint_url=args.rewriter_endpoint_url,
135
+ rewriter_model=args.rewriter_model,
136
+ extra_body=args.llm_extra_body,
137
+ resolution=args.resolution,
138
+ aspect_ratio=args.aspect_ratio,
139
+ )
140
+ generator = ImageGenerationClient(
141
+ endpoint=args.generation_endpoint,
142
+ auth_key=generation_auth_key,
143
+ model=args.generation_model,
144
+ size=args.size,
145
+ num_steps=args.num_steps,
146
+ guidance=args.guidance,
147
+ flow_shift=args.flow_shift,
148
+ extra_args=args.generation_extra_args,
149
+ )
150
+ judge = VLMQualityJudge(
151
+ api_token=gemini_token,
152
+ endpoint_url=args.critic_endpoint_url,
153
+ model=args.critic_model,
154
+ )
155
+ runner = AgenticUpsamplerRunner(
156
+ rewriter=rewriter,
157
+ generator=generator,
158
+ judge=judge,
159
+ config=RunnerConfig(
160
+ output_dir=args.output_dir,
161
+ max_iterations=args.max_iterations,
162
+ samples_per_iteration=args.samples_per_iteration,
163
+ overwrite=args.overwrite,
164
+ seed_base=args.seed_base,
165
+ initial_negative_prompt=args.initial_negative_prompt,
166
+ early_stop=not args.disable_early_stop,
167
+ verbose=not args.quiet,
168
+ ),
169
+ )
170
+
171
+ results = [runner.run_item_safely(item) for item in items]
172
+ write_run_manifest(args.output_dir, results)
173
+ failures = sum(1 for item in results if item.get("error"))
174
+ summary = {"selected_prompts": len(items), "completed": len(items) - failures, "failures": failures}
175
+ write_json_atomic(args.output_dir / "summary.json", summary)
176
+ print(json.dumps(summary, indent=2), flush=True)
177
+
178
+ if args.extract_best and not failures:
179
+ export_dir = args.output_dir / "best_generations"
180
+ extract_best_images(args.output_dir, export_dir, overwrite=args.overwrite)
181
+ print(f"Exported best images to {export_dir}", flush=True)
182
+ return 1 if failures else 0
183
+
184
+
185
+ if __name__ == "__main__":
186
+ raise SystemExit(main())
187
+
agentic_upsampling/runner.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Agentic text-to-image prompt upsampling orchestration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import traceback
7
+ from concurrent.futures import ThreadPoolExecutor, as_completed
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Any, Protocol
11
+
12
+ from agentic_upsampling.clients import GenerationOutput
13
+ from agentic_upsampling.constants import DEFAULT_JPEG_QUALITY, DEFAULT_MAX_ITERATIONS, DEFAULT_SAMPLES_PER_ITERATION
14
+ from agentic_upsampling.data import PromptItem, prompt_dir_name
15
+ from agentic_upsampling.io_utils import read_json, write_json_atomic
16
+ from agentic_upsampling.rubric import candidate_sort_key
17
+
18
+
19
+ class RewriterLike(Protocol):
20
+ def initial_prompt(self, item: PromptItem) -> dict[str, Any]:
21
+ """Create an initial prompt."""
22
+
23
+ def rewrite_prompt_pair(
24
+ self,
25
+ item: PromptItem,
26
+ previous_prompt: dict[str, Any],
27
+ previous_negative_prompt: str,
28
+ previous_analysis: dict[str, Any],
29
+ history: list[dict[str, Any]],
30
+ ) -> tuple[dict[str, Any], str]:
31
+ """Jointly rewrite a positive prompt and negative prompt."""
32
+
33
+
34
+ class GeneratorLike(Protocol):
35
+ def generate(
36
+ self,
37
+ *,
38
+ prompt_json: dict[str, Any],
39
+ prompt_id: str,
40
+ output_dir: Path,
41
+ seed: int | None = None,
42
+ negative_prompt: str = "",
43
+ jpeg_quality: int = DEFAULT_JPEG_QUALITY,
44
+ ) -> GenerationOutput:
45
+ """Generate one image."""
46
+
47
+
48
+ class JudgeLike(Protocol):
49
+ def score_image(
50
+ self,
51
+ *,
52
+ item: PromptItem,
53
+ image_path: Path,
54
+ ) -> dict[str, Any]:
55
+ """Score one image."""
56
+
57
+
58
+ @dataclass(frozen=True, slots=True)
59
+ class RunnerConfig:
60
+ """Runtime settings for the agentic loop."""
61
+
62
+ output_dir: Path
63
+ max_iterations: int = DEFAULT_MAX_ITERATIONS
64
+ samples_per_iteration: int = DEFAULT_SAMPLES_PER_ITERATION
65
+ overwrite: bool = False
66
+ seed_base: int | None = None
67
+ jpeg_quality: int = DEFAULT_JPEG_QUALITY
68
+ initial_negative_prompt: str = ""
69
+ early_stop: bool = True
70
+ verbose: bool = True
71
+
72
+ def __post_init__(self) -> None:
73
+ if self.max_iterations < 1:
74
+ raise ValueError("max_iterations must be >= 1.")
75
+ if self.samples_per_iteration < 1:
76
+ raise ValueError("samples_per_iteration must be >= 1.")
77
+
78
+
79
+ @dataclass(frozen=True, slots=True)
80
+ class IterationPrompt:
81
+ """Positive and negative prompts prepared for one iteration."""
82
+
83
+ prompt_json: dict[str, Any]
84
+ negative_prompt: str
85
+
86
+
87
+ class AgenticUpsamplerRunner:
88
+ """Run the iterative prompt rewrite, generate, and judge loop."""
89
+
90
+ rewriter: RewriterLike
91
+ generator: GeneratorLike
92
+ judge: JudgeLike
93
+ config: RunnerConfig
94
+
95
+ def __init__(
96
+ self,
97
+ *,
98
+ rewriter: RewriterLike,
99
+ generator: GeneratorLike,
100
+ judge: JudgeLike,
101
+ config: RunnerConfig,
102
+ ) -> None:
103
+ self.rewriter = rewriter
104
+ self.generator = generator
105
+ self.judge = judge
106
+ self.config = config
107
+
108
+ def run_item(self, item: PromptItem) -> dict[str, Any]:
109
+ """Run all iterations for one prompt item and persist the best candidate."""
110
+ item_dir = self.config.output_dir / prompt_dir_name(item)
111
+ item_dir.mkdir(parents=True, exist_ok=True)
112
+ (item_dir / "failure.json").unlink(missing_ok=True)
113
+ (item_dir / "incomplete.json").unlink(missing_ok=True)
114
+ self._log(f"[prompt {item.prompt_id}] start")
115
+ candidates: list[dict[str, Any]] = []
116
+ previous_prompt: dict[str, Any] | None = None
117
+ previous_analysis: dict[str, Any] | None = None
118
+ previous_negative_prompt = self.config.initial_negative_prompt.strip()
119
+ incomplete_error: dict[str, Any] | None = None
120
+
121
+ for iteration in range(self.config.max_iterations):
122
+ iteration_dir = item_dir / f"iter_{iteration:02d}"
123
+ candidate = None if self.config.overwrite else self._load_iteration(iteration_dir, iteration)
124
+ if candidate is None:
125
+ try:
126
+ candidate = self._run_iteration(
127
+ item,
128
+ iteration_dir,
129
+ iteration,
130
+ previous_prompt,
131
+ previous_analysis,
132
+ previous_negative_prompt,
133
+ candidates,
134
+ )
135
+ except Exception as exc:
136
+ if not candidates:
137
+ raise
138
+ incomplete_error = {
139
+ "iteration": iteration,
140
+ "error": repr(exc),
141
+ "traceback": traceback.format_exc(),
142
+ }
143
+ write_json_atomic(item_dir / "incomplete.json", incomplete_error)
144
+ self._log(f"[prompt {item.prompt_id}] incomplete at iter={iteration}: {exc!r}")
145
+ break
146
+
147
+ candidates.append(candidate)
148
+ previous_prompt = candidate["prompt_json"]
149
+ previous_analysis = candidate["analysis"]
150
+ previous_negative_prompt = str(candidate.get("negative_prompt") or "")
151
+ if self.config.early_stop and bool(candidate["analysis"].get("threshold_cleared")):
152
+ self._log(f"[prompt {item.prompt_id}] early stop at iter={iteration}")
153
+ break
154
+
155
+ return self.finalize_item(item, candidates, incomplete_error=incomplete_error)
156
+
157
+ def run_item_safely(self, item: PromptItem) -> dict[str, Any]:
158
+ """Run one item and convert failures into structured records."""
159
+ try:
160
+ return self.run_item(item)
161
+ except Exception as exc:
162
+ self._log(f"[prompt {item.prompt_id}] failed: {exc!r}")
163
+ failure = {
164
+ "prompt_id": item.prompt_id,
165
+ "prompt": item.prompt,
166
+ "error": repr(exc),
167
+ "traceback": traceback.format_exc(),
168
+ }
169
+ failure_path = self.config.output_dir / prompt_dir_name(item) / "failure.json"
170
+ write_json_atomic(failure_path, failure)
171
+ return {"prompt_id": item.prompt_id, "error": repr(exc), "failure_path": str(failure_path)}
172
+
173
+ def _run_iteration(
174
+ self,
175
+ item: PromptItem,
176
+ iteration_dir: Path,
177
+ iteration: int,
178
+ previous_prompt: dict[str, Any] | None,
179
+ previous_analysis: dict[str, Any] | None,
180
+ previous_negative_prompt: str,
181
+ candidates: list[dict[str, Any]],
182
+ ) -> dict[str, Any]:
183
+ prepared = self.prepare_iteration_prompt(
184
+ item,
185
+ iteration_dir,
186
+ iteration,
187
+ previous_prompt,
188
+ previous_analysis,
189
+ previous_negative_prompt,
190
+ candidates,
191
+ )
192
+ sample_candidates, sample_errors = self._run_iteration_samples(
193
+ item,
194
+ iteration_dir,
195
+ iteration,
196
+ prepared.prompt_json,
197
+ prepared.negative_prompt,
198
+ )
199
+ return self.finalize_iteration(item, iteration_dir, iteration, sample_candidates, sample_errors)
200
+
201
+ def _run_iteration_samples(
202
+ self,
203
+ item: PromptItem,
204
+ iteration_dir: Path,
205
+ iteration: int,
206
+ prompt_json: dict[str, Any],
207
+ negative_prompt: str,
208
+ ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
209
+ """Generate seed samples concurrently, then judge successful images in sample order."""
210
+ generation_outputs: dict[int, GenerationOutput] = {}
211
+ sample_errors: list[dict[str, Any]] = []
212
+ with ThreadPoolExecutor(max_workers=self.config.samples_per_iteration) as executor:
213
+ future_to_sample_index = {
214
+ executor.submit(
215
+ self.run_generation_sample,
216
+ item,
217
+ iteration_dir,
218
+ sample_index,
219
+ prompt_json,
220
+ negative_prompt,
221
+ ): sample_index
222
+ for sample_index in range(self.config.samples_per_iteration)
223
+ }
224
+ for future in as_completed(future_to_sample_index):
225
+ sample_index = future_to_sample_index[future]
226
+ try:
227
+ generation_outputs[sample_index] = future.result()
228
+ except Exception as exc:
229
+ sample_errors.append(self._record_sample_error(item, iteration_dir, iteration, sample_index, exc))
230
+
231
+ sample_candidates: list[dict[str, Any]] = []
232
+ for sample_index in range(self.config.samples_per_iteration):
233
+ generation = generation_outputs.get(sample_index)
234
+ if generation is None:
235
+ continue
236
+ try:
237
+ sample_candidates.append(
238
+ self.judge_iteration_sample(
239
+ item,
240
+ iteration_dir,
241
+ iteration,
242
+ sample_index,
243
+ prompt_json,
244
+ negative_prompt,
245
+ generation,
246
+ )
247
+ )
248
+ except Exception as exc:
249
+ sample_errors.append(self._record_sample_error(item, iteration_dir, iteration, sample_index, exc))
250
+ return sample_candidates, sample_errors
251
+
252
+ def _record_sample_error(
253
+ self,
254
+ item: PromptItem,
255
+ iteration_dir: Path,
256
+ iteration: int,
257
+ sample_index: int,
258
+ exc: Exception,
259
+ ) -> dict[str, Any]:
260
+ """Persist one per-sample failure record."""
261
+ error = {"sample_index": sample_index, "error": repr(exc), "traceback": traceback.format_exc()}
262
+ write_json_atomic(self._sample_dir(iteration_dir, sample_index) / "failure.json", error)
263
+ self._log(f"[prompt {item.prompt_id}] iter={iteration} sample={sample_index} failed: {exc!r}")
264
+ return error
265
+
266
+ def prepare_iteration_prompt(
267
+ self,
268
+ item: PromptItem,
269
+ iteration_dir: Path,
270
+ iteration: int,
271
+ previous_prompt: dict[str, Any] | None,
272
+ previous_analysis: dict[str, Any] | None,
273
+ previous_negative_prompt: str,
274
+ candidates: list[dict[str, Any]],
275
+ ) -> IterationPrompt:
276
+ """Prepare and persist the positive/negative prompt pair for one iteration."""
277
+ iteration_dir.mkdir(parents=True, exist_ok=True)
278
+ self._log(f"[prompt {item.prompt_id}] iter={iteration} start")
279
+ if iteration == 0 or previous_prompt is None or previous_analysis is None:
280
+ prompt_json = self.rewriter.initial_prompt(item)
281
+ negative_prompt = self.config.initial_negative_prompt.strip()
282
+ else:
283
+ prompt_json, negative_prompt = self.rewriter.rewrite_prompt_pair(
284
+ item,
285
+ previous_prompt,
286
+ previous_negative_prompt,
287
+ previous_analysis,
288
+ candidates,
289
+ )
290
+ negative_prompt = negative_prompt.strip()
291
+ write_json_atomic(iteration_dir / "prompt.json", prompt_json)
292
+ write_json_atomic(iteration_dir / "negative_prompt.json", {"negative_prompt": negative_prompt})
293
+ return IterationPrompt(prompt_json=prompt_json, negative_prompt=negative_prompt)
294
+
295
+ def _run_iteration_sample(
296
+ self,
297
+ item: PromptItem,
298
+ iteration_dir: Path,
299
+ iteration: int,
300
+ sample_index: int,
301
+ prompt_json: dict[str, Any],
302
+ negative_prompt: str,
303
+ ) -> dict[str, Any]:
304
+ generation = self.run_generation_sample(item, iteration_dir, sample_index, prompt_json, negative_prompt)
305
+ return self.judge_iteration_sample(
306
+ item,
307
+ iteration_dir,
308
+ iteration,
309
+ sample_index,
310
+ prompt_json,
311
+ negative_prompt,
312
+ generation,
313
+ )
314
+
315
+ def run_generation_sample(
316
+ self,
317
+ item: PromptItem,
318
+ iteration_dir: Path,
319
+ sample_index: int,
320
+ prompt_json: dict[str, Any],
321
+ negative_prompt: str,
322
+ ) -> GenerationOutput:
323
+ """Generate one sample image for an iteration."""
324
+ sample_dir = self._sample_dir(iteration_dir, sample_index)
325
+ sample_dir.mkdir(parents=True, exist_ok=True)
326
+ self._log(f"[prompt {item.prompt_id}] sample={sample_index} generate")
327
+ return self.generator.generate(
328
+ prompt_json=prompt_json,
329
+ prompt_id=item.prompt_id,
330
+ output_dir=sample_dir,
331
+ seed=self._sample_seed(sample_index),
332
+ negative_prompt=negative_prompt,
333
+ jpeg_quality=self.config.jpeg_quality,
334
+ )
335
+
336
+ def judge_iteration_sample(
337
+ self,
338
+ item: PromptItem,
339
+ iteration_dir: Path,
340
+ iteration: int,
341
+ sample_index: int,
342
+ prompt_json: dict[str, Any],
343
+ negative_prompt: str,
344
+ generation: GenerationOutput,
345
+ ) -> dict[str, Any]:
346
+ """Judge one generated sample and persist its candidate metadata."""
347
+ sample_dir = self._sample_dir(iteration_dir, sample_index)
348
+ analysis = self.judge.score_image(item=item, image_path=generation.image_path)
349
+ self._log(f"[prompt {item.prompt_id}] iter={iteration} sample={sample_index} score={analysis.get('overall_score')}")
350
+ analysis_path = sample_dir / "analysis.json"
351
+ write_json_atomic(analysis_path, analysis)
352
+ candidate = {
353
+ "prompt_id": item.prompt_id,
354
+ "iteration": iteration,
355
+ "sample_index": sample_index,
356
+ "prompt_path": str(iteration_dir / "prompt.json"),
357
+ "image_path": str(generation.image_path),
358
+ "analysis_path": str(analysis_path),
359
+ "generation_meta_path": str(generation.meta_path),
360
+ "negative_prompt_path": str(iteration_dir / "negative_prompt.json"),
361
+ "negative_prompt": negative_prompt,
362
+ "prompt_json": prompt_json,
363
+ "analysis": analysis,
364
+ }
365
+ write_json_atomic(sample_dir / "meta.json", candidate)
366
+ return candidate
367
+
368
+ def finalize_iteration(
369
+ self,
370
+ item: PromptItem,
371
+ iteration_dir: Path,
372
+ iteration: int,
373
+ sample_candidates: list[dict[str, Any]],
374
+ sample_errors: list[dict[str, Any]],
375
+ ) -> dict[str, Any]:
376
+ """Select and persist the best sample candidate for one iteration."""
377
+ if not sample_candidates:
378
+ raise RuntimeError(f"All {self.config.samples_per_iteration} samples failed for iteration {iteration}.")
379
+ write_json_atomic(iteration_dir / "samples.json", sample_candidates)
380
+ candidate = dict(max(sample_candidates, key=candidate_sort_key))
381
+ candidate["samples"] = sample_candidates
382
+ candidate["sample_count"] = len(sample_candidates)
383
+ candidate["selected_sample_index"] = candidate["sample_index"]
384
+ if sample_errors:
385
+ candidate["sample_errors"] = sample_errors
386
+ write_json_atomic(iteration_dir / "sample_failures.json", sample_errors)
387
+ write_json_atomic(iteration_dir / "meta.json", candidate)
388
+ self._log(
389
+ f"[prompt {item.prompt_id}] iter={iteration} best_sample={candidate['selected_sample_index']} "
390
+ f"score={candidate['analysis'].get('overall_score')} samples={len(sample_candidates)}"
391
+ )
392
+ return candidate
393
+
394
+ def finalize_item(
395
+ self,
396
+ item: PromptItem,
397
+ candidates: list[dict[str, Any]],
398
+ *,
399
+ incomplete_error: dict[str, Any] | None = None,
400
+ ) -> dict[str, Any]:
401
+ """Persist and return the best candidate summary for a completed or incomplete item."""
402
+ if not candidates:
403
+ raise RuntimeError(f"No candidates produced for prompt {item.prompt_id}.")
404
+ item_dir = self.config.output_dir / prompt_dir_name(item)
405
+ best = max(candidates, key=candidate_sort_key)
406
+ summary = {
407
+ "prompt_id": item.prompt_id,
408
+ "prompt": item.prompt,
409
+ "best_iteration": best["iteration"],
410
+ "best_score": best["analysis"].get("overall_score"),
411
+ "threshold_cleared_any": any(bool(candidate["analysis"].get("threshold_cleared")) for candidate in candidates),
412
+ "best": best,
413
+ "iterations": candidates,
414
+ }
415
+ if incomplete_error is not None:
416
+ summary["incomplete_error"] = incomplete_error
417
+ write_json_atomic(item_dir / "best.json", summary)
418
+ self._log(f"[prompt {item.prompt_id}] done best_iter={summary['best_iteration']} best_score={summary['best_score']}")
419
+ return summary
420
+
421
+ def _log(self, message: str) -> None:
422
+ if self.config.verbose:
423
+ print(message, flush=True)
424
+
425
+ def _sample_seed(self, sample_index: int) -> int | None:
426
+ if self.config.seed_base is None:
427
+ return None
428
+ return self.config.seed_base + sample_index
429
+
430
+ def _sample_dir(self, iteration_dir: Path, sample_index: int) -> Path:
431
+ if self.config.samples_per_iteration == 1:
432
+ return iteration_dir
433
+ return iteration_dir / f"sample_{sample_index:02d}"
434
+
435
+ @staticmethod
436
+ def _load_iteration(iteration_dir: Path, iteration: int) -> dict[str, Any] | None:
437
+ meta_path = iteration_dir / "meta.json"
438
+ prompt_path = iteration_dir / "prompt.json"
439
+ if not (meta_path.exists() and prompt_path.exists()):
440
+ return None
441
+ meta = read_json(meta_path)
442
+ analysis_path = Path(str(meta.get("analysis_path") or iteration_dir / "analysis.json"))
443
+ image_path = Path(str(meta.get("image_path") or iteration_dir / "image.jpg"))
444
+ if not (analysis_path.exists() and image_path.exists()):
445
+ return None
446
+ meta["iteration"] = iteration
447
+ meta["prompt_json"] = read_json(prompt_path)
448
+ meta["analysis"] = read_json(analysis_path)
449
+ negative_prompt_path = iteration_dir / "negative_prompt.json"
450
+ if "negative_prompt" not in meta and negative_prompt_path.exists():
451
+ negative_prompt_data = read_json(negative_prompt_path)
452
+ meta["negative_prompt"] = str(negative_prompt_data.get("negative_prompt") or "")
453
+ meta["negative_prompt_path"] = str(negative_prompt_path)
454
+ meta.setdefault("negative_prompt", "")
455
+ samples_path = iteration_dir / "samples.json"
456
+ if samples_path.exists():
457
+ samples = json.loads(samples_path.read_text(encoding="utf-8"))
458
+ if isinstance(samples, list):
459
+ meta["samples"] = samples
460
+ meta["sample_count"] = len(samples)
461
+ return meta
462
+
463
+
464
+ def write_run_manifest(output_dir: Path, results: list[dict[str, Any]]) -> None:
465
+ """Write compact run-level manifest files."""
466
+ manifest_path = output_dir / "manifest.jsonl"
467
+ failures_path = output_dir / "failures.jsonl"
468
+ manifest_path.unlink(missing_ok=True)
469
+ failures_path.unlink(missing_ok=True)
470
+ for result in results:
471
+ target = failures_path if result.get("error") else manifest_path
472
+ with target.open("a", encoding="utf-8") as f:
473
+ f.write(json.dumps(result, ensure_ascii=True, separators=(",", ":")) + "\n")
474
+
assets/benchmark-text2image-leaderboard-all-models.jpg ADDED

Git LFS Details

  • SHA256: 380c72e4df9a1b95d7929d7af082ad0af8bd885f160a81ae3b66661040923c9e
  • Pointer size: 131 Bytes
  • Size of remote file: 432 kB
assets/benchmark-text2image-leaderboard.png ADDED

Git LFS Details

  • SHA256: 10458182a0c5dfe07ae295f20fabf60a4e4d2a59633a27ff15c50ac1f9baae84
  • Pointer size: 132 Bytes
  • Size of remote file: 4.12 MB
assets/benchmark-text2image.png ADDED

Git LFS Details

  • SHA256: 55bdd6bc617832086be44c3d63f03cb28426dc352a97e8d3c10ae7967b94c4e9
  • Pointer size: 131 Bytes
  • Size of remote file: 147 kB
assets/example_caption.json ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "subjects": [
3
+ {
4
+ "description": "Two damp human hands working a spinning cylinder of wet gray clay on a pottery wheel, fingers gently pinching and pulling the walls upward to form a narrow neck and rounded belly",
5
+ "appearance_details": "Hands coated with a thin sheen of slip, glistening with water; fingertips leave subtle ridges and fingerprints in the clay; knuckles slightly creased, nails short and rimmed with gray clay; small splatters of clay on the back of the wrists",
6
+ "relationship": "The primary actor shaping the central spinning clay form on the wheel",
7
+ "location": "Center foreground, framing the clay column",
8
+ "relative_size": "Hands and clay occupy roughly the central 60 percent of the frame",
9
+ "orientation": "Hands angled inward from left and right, palms cupping the clay symmetrically",
10
+ "pose": "Both hands curled around the clay, thumbs and index fingers pinched at the upper neck while remaining fingers support the rounded belly below",
11
+ "clothing": "",
12
+ "expression": "",
13
+ "gender": "Unknown",
14
+ "age": "Adult",
15
+ "skin_tone_and_texture": "Light to medium skin tone, slick with watery slip giving a glossy sheen, fine pores and small wrinkles visible at knuckles",
16
+ "facial_features": "",
17
+ "number_of_subjects": 1,
18
+ "number_of_arms": 2,
19
+ "number_of_legs": 0,
20
+ "number_of_hands": 2,
21
+ "number_of_fingers": 10
22
+ },
23
+ {
24
+ "description": "A spinning cylinder of wet gray stoneware clay being formed into a vase shape with a narrow neck and rounded belly",
25
+ "appearance_details": "Concentric horizontal rings spiral around the surface from the rotation; glistening film of water; subtle fingerprint impressions; small drips of slip running down the lower belly onto the wheel head",
26
+ "relationship": "The object being shaped by the hands; central focal point of the scene",
27
+ "location": "Absolute center of the frame, rising vertically from the wheel head",
28
+ "relative_size": "Approximately one third of the frame height, dominant central form",
29
+ "orientation": "Vertical axis, rotating",
30
+ "pose": "Upright cylindrical-to-vase form mid-throw, neck tapering, belly bulging",
31
+ "clothing": "",
32
+ "expression": "",
33
+ "gender": "N/A",
34
+ "age": "N/A",
35
+ "skin_tone_and_texture": "",
36
+ "facial_features": "",
37
+ "number_of_subjects": 1,
38
+ "number_of_arms": 0,
39
+ "number_of_legs": 0,
40
+ "number_of_hands": 0,
41
+ "number_of_fingers": 0
42
+ }
43
+ ],
44
+ "subject_details": {
45
+ "clay_state": "Wet gray stoneware, plastic and pliable, glistening with slip",
46
+ "wheel_motion": "Visible motion blur in concentric ring patterns indicating rotation",
47
+ "forming_stage": "Mid-throw, transitioning from cylinder to vase form with narrow neck and rounded belly",
48
+ "slip_splatter": "Splattered clay droplets dot the black apron and the wheel tray",
49
+ "tools_in_background": "Wooden ribs, wire cutter, sponge, and metal trimming tools softly blurred behind"
50
+ },
51
+ "background_setting": "A dim pottery studio with a matte black wheel tray and splash pan surrounding the spinning wheel head. Behind the potter, softly blurred, sit a row of throwing tools: wooden ribs, a metal kidney, a wire cutter coiled on a small shelf, and a sponge in a shallow water bowl. The black apron worn by the potter forms part of the lower background, dotted with dried clay flecks. The studio walls are deep charcoal, allowing the warm directional light to sculpt the central action.",
52
+ "lighting": {
53
+ "conditions": "Warm, controlled studio lighting with a single key source",
54
+ "direction": "From the right side of the frame, slightly elevated",
55
+ "shadows": "Soft, defined shadows falling to the left of the clay form and beneath the hands; subtle contact shadow on the wheel head",
56
+ "illumination_effect": "Highlights the wet sheen on the clay, accentuates concentric ring textures and fingerprint detail, and creates a moody chiaroscuro that isolates the hands and clay from the darker background"
57
+ },
58
+ "aesthetics": {
59
+ "composition": "Centered symmetrical composition with the clay form on the vertical axis, hands framing it from both sides; rule-of-thirds intersections align with the neck and belly of the vase",
60
+ "color_scheme": "Muted earth tones dominated by cool grays of the clay, warm amber highlights from the studio light, deep blacks of the apron and wheel tray, and natural skin tones",
61
+ "mood_atmosphere": "Meditative, tactile, focused craftsmanship",
62
+ "patterns": "Concentric horizontal rings spiraling around the clay surface from the wheel's rotation"
63
+ },
64
+ "cinematography": {
65
+ "framing": "Medium close-up",
66
+ "camera_angle": "Slightly above rim height, looking gently down onto the hands and clay",
67
+ "depth_of_field": "Shallow",
68
+ "focus": "Crisp focus on the hands and the wet clay surface",
69
+ "lens_focal_length": "35mm"
70
+ },
71
+ "style_medium": "Photography",
72
+ "artistic_style": "Photorealistic studio product/process photography",
73
+ "context": "Editorial or artisanal documentation of a hand-thrown ceramic vase being formed on a pottery wheel",
74
+ "text_and_signage_elements": [],
75
+ "quadrant_scan": {
76
+ "top_left": "Softly blurred dark background with faint warm rim light catching the edge of a wooden rib tool on a shelf",
77
+ "top_right": "Warm directional studio light source area; brightest highlights spill across the upper right, illuminating the narrow neck of the vase",
78
+ "bottom_left": "Edge of the black apron speckled with dried clay flecks and the curved rim of the splash pan in shadow",
79
+ "bottom_right": "Wheel tray dotted with splattered clay droplets and trickles of slip; partial view of the rotating wheel head",
80
+ "absolute_center": "The wet gray clay vase form mid-throw, cradled by two slip-coated hands pinching the neck and supporting the rounded belly, concentric rings glistening under warm light"
81
+ },
82
+ "comprehensive_t2i_caption": "A photorealistic studio photograph captured at slightly above rim height with a 35mm lens shows a pottery wheel in motion at the center of the frame. A cylinder of wet gray stoneware clay spins, its surface scored with fine concentric rings from rotation and glistening with a thin film of water and slip. Two damp human hands, coated in pale gray slip, gently pinch and pull the walls upward, the thumbs and index fingers narrowing the neck while the lower fingers cradle and shape a rounded belly, transforming the cylinder into a vase. Warm directional studio light from the right rakes across the scene, accentuating the sheen of water, the subtle ridges of fingerprints, and the soft modeling of the clay's curves, while casting gentle shadows to the left. Splattered droplets of clay dot a black apron worn by the potter and the matte black wheel tray and splash pan beneath. In the softly blurred background, a row of pottery tools — wooden ribs, a metal kidney, a wire cutter, and a damp sponge in a water bowl — sit on a charcoal-toned shelf. The depth of field is shallow, holding crisp focus on the hands and clay while the surroundings dissolve into a moody, warm-toned haze. The atmosphere is meditative and tactile, celebrating the intimate craftsmanship of hand-thrown ceramics.",
83
+ "resolution": {
84
+ "H": 1024,
85
+ "W": 1024
86
+ },
87
+ "aspect_ratio": "1,1"
88
+ }
assets/example_image.png ADDED

Git LFS Details

  • SHA256: 478903c6adf090f6dbf8c584e073061caf0451053360cae02b9fea3b84c132eb
  • Pointer size: 132 Bytes
  • Size of remote file: 1.25 MB
assets/more_images.jpg ADDED

Git LFS Details

  • SHA256: 48efdeacbf9824941fff4348c5780a87fa760ebd3fa1b22d87fa9b51fba72120
  • Pointer size: 132 Bytes
  • Size of remote file: 4.25 MB
assets/original_prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Photorealistic studio photograph of a pottery wheel in motion, a cylinder of wet gray clay spinning with concentric rings. Two hands, damp and coated with slip, gently pinch and pull the walls upward to form a narrow neck and rounded belly like a vase. Directional warm studio light from the right highlights the sheen of water on clay and the texture of fingerprints. Splattered clay dots the black apron and wheel tray. Camera slightly above rim height, 35mm, crisp focus on hands and clay, background tools softly blurred.
chat_template.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {%- if messages[0].content is string %}\n {{- messages[0].content }}\n {%- else %}\n {%- for content in messages[0].content %}\n {%- if 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].content is string %}\n {{- messages[0].content }}\n {%- else %}\n {%- for content in messages[0].content %}\n {%- if 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set image_count = namespace(value=0) %}\n{%- set video_count = namespace(value=0) %}\n{%- for message in messages %}\n {%- if message.role == \"user\" %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content in message.content %}\n {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}\n {%- set image_count.value = image_count.value + 1 %}\n {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}\n <|vision_start|><|image_pad|><|vision_end|>\n {%- elif content.type == 'video' or 'video' in content %}\n {%- set video_count.value = video_count.value + 1 %}\n {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}\n <|vision_start|><|video_pad|><|vision_end|>\n {%- elif 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content_item in message.content %}\n {%- if 'text' in content_item %}\n {{- content_item.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and message.content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content in message.content %}\n {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}\n {%- set image_count.value = image_count.value + 1 %}\n {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}\n <|vision_start|><|image_pad|><|vision_end|>\n {%- elif content.type == 'video' or 'video' in content %}\n {%- set video_count.value = video_count.value + 1 %}\n {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}\n <|vision_start|><|video_pad|><|vision_end|>\n {%- elif 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n"
3
+ }
checkpoint.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
config.json ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "allow_patterns_overrides": [
3
+ "*/*.safetensors"
4
+ ],
5
+ "architectures": [
6
+ "Cosmos3ForConditionalGeneration"
7
+ ],
8
+ "image_token_id": 151655,
9
+ "model": {
10
+ "_recursive_": false,
11
+ "_target": "omni_mot_model",
12
+ "config": {
13
+ "_type": "omni_mot_model_config",
14
+ "action_gen": false,
15
+ "activation_checkpointing": {
16
+ "_type": "activation_checkpointing_config",
17
+ "determinism_check": "default",
18
+ "mode": "full",
19
+ "preserve_rng_state": true,
20
+ "save_ops_regex": [
21
+ "fmha"
22
+ ]
23
+ },
24
+ "causal_training_strategy": "none",
25
+ "compile": {
26
+ "_type": "compile_config",
27
+ "compile_dynamic": true,
28
+ "compiled_region": "language",
29
+ "coordinate_descent_tuning": false,
30
+ "enabled": true,
31
+ "max_autotune_pointwise": false,
32
+ "use_cuda_graphs": false
33
+ },
34
+ "diffusion_expert_config": {
35
+ "_type": "diffusion_expert_config",
36
+ "base_fps": 24,
37
+ "enable_fps_modulation": true,
38
+ "load_weights_from_pretrained": false,
39
+ "max_vae_latent_side_after_patchify": 20,
40
+ "patch_spatial": 2,
41
+ "position_embedding_type": "unified_3d_mrope",
42
+ "rope_h_extrapolation_ratio": 1.0,
43
+ "rope_t_extrapolation_ratio": 1.0,
44
+ "rope_w_extrapolation_ratio": 1.0,
45
+ "timestep_range": 1.0,
46
+ "unified_3d_mrope_reset_spatial_ids": true,
47
+ "unified_3d_mrope_temporal_modality_margin": 15000
48
+ },
49
+ "ema": {
50
+ "_type": "ema_config",
51
+ "enabled": false,
52
+ "iteration_shift": 0,
53
+ "rate": 0.1
54
+ },
55
+ "fixed_step_sampler_config": null,
56
+ "input_caption_key": "ai_caption",
57
+ "input_image_key": "images",
58
+ "input_video_key": "video",
59
+ "joint_attn_implementation": "two_way",
60
+ "latent_downsample_factor": 16,
61
+ "lbl": {
62
+ "_type": "lbl_config",
63
+ "coeff_gen": null,
64
+ "coeff_und": null,
65
+ "method": "local"
66
+ },
67
+ "log_enc_time_every_n": 100,
68
+ "lora_alpha": 32,
69
+ "lora_enabled": false,
70
+ "lora_rank": 16,
71
+ "lora_target_modules": "q_proj_moe_gen,k_proj_moe_gen,v_proj_moe_gen,o_proj_moe_gen",
72
+ "max_action_dim": 32,
73
+ "max_num_tokens_after_packing": 69632,
74
+ "natten_parameter_list": null,
75
+ "net": null,
76
+ "num_embodiment_domains": 32,
77
+ "parallelism": {
78
+ "_type": "parallelism_config",
79
+ "cfg_parallel_shard_degree": 1,
80
+ "context_parallel_shard_degree": 1,
81
+ "data_parallel_replicate_degree": 1,
82
+ "data_parallel_shard_degree": 16,
83
+ "enable_inference_mode": false,
84
+ "fsdp_master_dtype": "float32"
85
+ },
86
+ "precision": "bfloat16",
87
+ "rectified_flow_inference_config": {
88
+ "_type": "rectified_flow_inference_config",
89
+ "num_train_timesteps": 1000,
90
+ "scheduler_type": "unipc",
91
+ "shift": 3,
92
+ "use_dynamic_shifting": false
93
+ },
94
+ "rectified_flow_training_config": {
95
+ "_type": "rectified_flow_training_config",
96
+ "action_loss_weight": 10.0,
97
+ "high_sigma_ratio": 0.05,
98
+ "high_sigma_timesteps_max": 1000,
99
+ "high_sigma_timesteps_min": 995,
100
+ "image_loss_scale": null,
101
+ "independent_action_schedule": false,
102
+ "independent_sound_schedule": false,
103
+ "loss_scale": 10.0,
104
+ "normalize_loss_by_active": false,
105
+ "shift": {
106
+ "720": 5,
107
+ "768": 5
108
+ },
109
+ "shift_action": null,
110
+ "shift_sound": null,
111
+ "sound_loss_scale": 2.0,
112
+ "train_time_action_distribution": "logitnormal",
113
+ "train_time_image_distribution": "logitnormal",
114
+ "train_time_sound_distribution": "logitnormal",
115
+ "train_time_video_distribution": "waver",
116
+ "train_time_weight": "uniform",
117
+ "use_discrete_rf": false,
118
+ "use_dynamic_shift": false,
119
+ "use_high_sigma_strategy": false,
120
+ "use_high_sigma_strategy_action": false,
121
+ "use_high_sigma_strategy_sound": false
122
+ },
123
+ "resolution": "768",
124
+ "sound_dim": 64,
125
+ "sound_gen": true,
126
+ "sound_latent_fps": 25,
127
+ "sound_tokenizer": {
128
+ "_target": "avae_interface",
129
+ "audio_channels": 2,
130
+ "avae_config_path": "",
131
+ "avae_path": "pretrained/tokenizers/audio/avae/avae_48k_noncausal_25hz_64ch.ckpt",
132
+ "bucket_name": "bucket",
133
+ "hop_size": 1920,
134
+ "io_channels": 64,
135
+ "latent_mean": null,
136
+ "latent_std": null,
137
+ "normalization_type": "none",
138
+ "normalize_latents": false,
139
+ "object_store_credential_path_pretrained": "credentials/gcp_training.secret",
140
+ "sample_rate": 48000,
141
+ "tanh_clamp": 0.995,
142
+ "tanh_input_scale": 1.5,
143
+ "tanh_output_scale": 3.5
144
+ },
145
+ "state_ch": 48,
146
+ "state_t": 300,
147
+ "tokenizer": {
148
+ "_target": "wan2pt2_vae_interface",
149
+ "bucket_name": "bucket",
150
+ "causal": true,
151
+ "chunk_duration": 93,
152
+ "encode_bucket_multiple": null,
153
+ "encode_chunk_frames": {
154
+ "720": 12,
155
+ "768": 12
156
+ },
157
+ "encode_exact_durations": null,
158
+ "keep_decoder_cache": false,
159
+ "object_store_credential_path_pretrained": "credentials/gcp_training.secret",
160
+ "spatial_compression_factor": 16,
161
+ "temporal_compression_factor": 4,
162
+ "temporal_window": null,
163
+ "use_streaming_encode": false,
164
+ "vae_path": "pretrained/tokenizers/video/wan2pt2/Wan2.2_VAE.pth"
165
+ },
166
+ "video_temporal_causal": false,
167
+ "vision_gen": true,
168
+ "vlm_config": {
169
+ "_type": "vlm_config",
170
+ "layer_module": null,
171
+ "model_instance": {
172
+ "_target": "qwen3_vl_text_for_causal_lm",
173
+ "config": {
174
+ "_target": "create_vlm_config",
175
+ "base_config": {
176
+ "_target": "qwen3_vl_mot_config_from_json_file",
177
+ "json_file": "cosmos3://vfm/models/vlm/qwen3_vl/configs/Qwen3-VL-32B-Instruct.json"
178
+ },
179
+ "qk_norm_for_text": true
180
+ }
181
+ },
182
+ "model_name": "nvidia/Cosmos3-Super-Reasoner",
183
+ "pretrained_weights": {
184
+ "_type": "pretrained_weights_config",
185
+ "backbone_path": "s3://bucket/cosmos3/pretrained/huggingface/Cosmos-Reason/Cosmos3-Super-Reasoner-b6df0d1/",
186
+ "checkpoint_format": null,
187
+ "credentials_path": "credentials/gcp_checkpoint.secret",
188
+ "enable_gcs_patch_in_boto3": true,
189
+ "enabled": false
190
+ },
191
+ "qk_norm": false,
192
+ "tie_word_embeddings": false,
193
+ "tokenizer": {
194
+ "_target": "create_qwen2_tokenizer_with_download",
195
+ "config_variant": "gcp",
196
+ "pretrained_model_name": "Qwen/Qwen3-VL-32B-Instruct"
197
+ },
198
+ "use_system_prompt": false
199
+ }
200
+ }
201
+ },
202
+ "model_type": "cosmos3_omni",
203
+ "text_config": {
204
+ "attention_bias": false,
205
+ "attention_dropout": 0.0,
206
+ "bos_token_id": 151643,
207
+ "dtype": "bfloat16",
208
+ "eos_token_id": 151645,
209
+ "head_dim": 128,
210
+ "hidden_act": "silu",
211
+ "hidden_size": 5120,
212
+ "initializer_range": 0.02,
213
+ "intermediate_size": 25600,
214
+ "max_position_embeddings": 262144,
215
+ "model_type": "qwen3_vl_text",
216
+ "num_attention_heads": 64,
217
+ "num_hidden_layers": 64,
218
+ "num_key_value_heads": 8,
219
+ "rms_norm_eps": 1e-06,
220
+ "rope_scaling": {
221
+ "mrope_interleaved": true,
222
+ "mrope_section": [
223
+ 24,
224
+ 20,
225
+ 20
226
+ ],
227
+ "rope_type": "default"
228
+ },
229
+ "rope_theta": 5000000,
230
+ "use_cache": true,
231
+ "vocab_size": 151936
232
+ },
233
+ "tie_word_embeddings": false,
234
+ "transformers_version": "4.57.0.dev0",
235
+ "video_token_id": 151656,
236
+ "vision_config": {
237
+ "deepstack_visual_indexes": [
238
+ 8,
239
+ 16,
240
+ 24
241
+ ],
242
+ "depth": 27,
243
+ "hidden_act": "gelu_pytorch_tanh",
244
+ "hidden_size": 1152,
245
+ "in_channels": 3,
246
+ "initializer_range": 0.02,
247
+ "intermediate_size": 4304,
248
+ "model_type": "qwen3_vl",
249
+ "num_heads": 16,
250
+ "num_position_embeddings": 2304,
251
+ "out_hidden_size": 5120,
252
+ "patch_size": 16,
253
+ "spatial_merge_size": 2,
254
+ "temporal_patch_size": 2
255
+ },
256
+ "vision_end_token_id": 151653,
257
+ "vision_start_token_id": 151652
258
+ }
generation_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "pad_token_id": 151643,
4
+ "do_sample": true,
5
+ "eos_token_id": [
6
+ 151645,
7
+ 151643
8
+ ],
9
+ "top_p": 0.8,
10
+ "top_k": 20,
11
+ "temperature": 0.7,
12
+ "repetition_penalty": 1.0,
13
+ "transformers_version": "4.56.0"
14
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
model_index.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "Cosmos3OmniDiffusersPipeline",
3
+ "_diffusers_version": "0.37.1",
4
+ "scheduler": [
5
+ "diffusers",
6
+ "UniPCMultistepScheduler"
7
+ ],
8
+ "text_tokenizer": [
9
+ "transformers",
10
+ "Qwen2TokenizerFast"
11
+ ],
12
+ "transformer": [
13
+ "diffusers",
14
+ "Cosmos3OmniTransformer"
15
+ ],
16
+ "vae": [
17
+ "diffusers",
18
+ "AutoencoderKLWan"
19
+ ],
20
+ "vision_encoder": [
21
+ "transformers",
22
+ "Qwen3VLVisionModel"
23
+ ],
24
+ "sound_tokenizer": [
25
+ "diffusers",
26
+ "Cosmos3AVAEAudioTokenizer"
27
+ ]
28
+ }
preprocessor_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "size": {
3
+ "longest_edge": 16777216,
4
+ "shortest_edge": 65536
5
+ },
6
+ "patch_size": 16,
7
+ "temporal_patch_size": 2,
8
+ "merge_size": 2,
9
+ "image_mean": [
10
+ 0.5,
11
+ 0.5,
12
+ 0.5
13
+ ],
14
+ "image_std": [
15
+ 0.5,
16
+ 0.5,
17
+ 0.5
18
+ ],
19
+ "processor_class": "Qwen3VLProcessor",
20
+ "image_processor_type": "Qwen2VLImageProcessorFast"
21
+ }
pytest.ini ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [pytest]
2
+ testpaths = tests
3
+ pythonpath = .
4
+ addopts = --confcutdir=.
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UniPCMultistepScheduler",
3
+ "_diffusers_version": "0.37.1",
4
+ "beta_end": 0.02,
5
+ "beta_schedule": "linear",
6
+ "beta_start": 0.0001,
7
+ "disable_corrector": [],
8
+ "dynamic_thresholding_ratio": 0.995,
9
+ "final_sigmas_type": "zero",
10
+ "flow_shift": 3.0,
11
+ "lower_order_final": true,
12
+ "num_train_timesteps": 1000,
13
+ "predict_x0": true,
14
+ "prediction_type": "flow_prediction",
15
+ "rescale_betas_zero_snr": false,
16
+ "sample_max_value": 1.0,
17
+ "shift_terminal": null,
18
+ "sigma_max": 200.0,
19
+ "sigma_min": 0.147,
20
+ "solver_order": 2,
21
+ "solver_p": null,
22
+ "solver_type": "bh2",
23
+ "steps_offset": 0,
24
+ "thresholding": false,
25
+ "time_shift_type": "exponential",
26
+ "timestep_spacing": "linspace",
27
+ "trained_betas": null,
28
+ "use_beta_sigmas": false,
29
+ "use_dynamic_shifting": false,
30
+ "use_exponential_sigmas": false,
31
+ "use_flow_sigmas": true,
32
+ "use_karras_sigmas": true
33
+ }
sound_tokenizer/config.json ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "autoencoder_v2",
3
+ "sampling_rate": 48000,
4
+ "stereo": true,
5
+ "use_wav_as_input": true,
6
+ "normalize_volume": true,
7
+ "hop_size": 1920,
8
+ "input_channels": 1,
9
+ "enc_type": "spec_convnext",
10
+ "enc_dim": 192,
11
+ "enc_intermediate_dim": 768,
12
+ "enc_num_layers": 12,
13
+ "enc_num_blocks": 2,
14
+ "enc_n_fft": 64,
15
+ "enc_hop_length": 16,
16
+ "enc_latent_dim": 128,
17
+ "enc_c_mults": [
18
+ 1,
19
+ 2,
20
+ 4
21
+ ],
22
+ "enc_strides": [
23
+ 4,
24
+ 5,
25
+ 6
26
+ ],
27
+ "enc_identity_init": false,
28
+ "enc_use_snake": true,
29
+ "dec_type": "oobleck",
30
+ "dec_dim": 320,
31
+ "dec_c_mults": [
32
+ 1,
33
+ 2,
34
+ 4,
35
+ 8,
36
+ 16
37
+ ],
38
+ "dec_strides": [
39
+ 2,
40
+ 4,
41
+ 5,
42
+ 6,
43
+ 8
44
+ ],
45
+ "dec_use_snake": true,
46
+ "dec_final_tanh": false,
47
+ "dec_out_channels": 2,
48
+ "dec_anti_aliasing": false,
49
+ "dec_use_nearest_upsample": false,
50
+ "dec_use_tanh_at_final": false,
51
+ "bottleneck_type": "vae",
52
+ "bottleneck": {
53
+ "type": "vae"
54
+ },
55
+ "activation": "snakebeta",
56
+ "snake_logscale": true,
57
+ "anti_aliasing": false,
58
+ "use_cuda_kernel": false,
59
+ "causal": false,
60
+ "padding_mode": "zeros",
61
+ "vocoder_input_dim": 64,
62
+ "latent_mean": null,
63
+ "latent_std": null
64
+ }
sound_tokenizer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d4c61cde38acfb0cad9048a140c3533750277a8462b19dc08450d9fe1ad9879
3
+ size 1892409600
tests/test_agentic_upsampling.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import io
5
+ import json
6
+ import threading
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ from PIL import Image
12
+
13
+ from agentic_upsampling.clients import ImageGenerationClient, PromptRewriterClient
14
+ from agentic_upsampling.constants import (
15
+ DEFAULT_CRITIC_ENDPOINT_URL,
16
+ DEFAULT_CRITIC_MODEL,
17
+ DEFAULT_FLOW_SHIFT,
18
+ DEFAULT_GENERATION_EXTRA_ARGS,
19
+ DEFAULT_GENERATION_MODEL,
20
+ DEFAULT_LLM_EXTRA_BODY,
21
+ DEFAULT_REWRITER_MODEL,
22
+ )
23
+ from agentic_upsampling.data import PromptItem, load_prompt_items, prompt_dir_name
24
+ from agentic_upsampling.extract_best import extract_best_images
25
+ from agentic_upsampling.prompt_upsampler import (
26
+ Text2ImagePromptUpsampler,
27
+ apply_t2i_output_parameters,
28
+ normalize_openai_base_url,
29
+ )
30
+ from agentic_upsampling.rubric import parse_analysis_response
31
+ from agentic_upsampling.runner import AgenticUpsamplerRunner, RunnerConfig
32
+
33
+
34
+ def _item(prompt_id: str = "1", prompt: str = "a red cube") -> PromptItem:
35
+ return PromptItem(prompt_id=prompt_id, row_number=0, prompt=prompt)
36
+
37
+
38
+ def _valid_t2i_prompt(caption: str) -> dict[str, Any]:
39
+ return {
40
+ "subjects": [],
41
+ "subject_details": {},
42
+ "background_setting": "plain studio",
43
+ "lighting": {"conditions": "soft", "direction": "front", "shadows": "soft", "illumination_effect": "clear"},
44
+ "aesthetics": {
45
+ "composition": "centered",
46
+ "color_scheme": "balanced",
47
+ "mood_atmosphere": "precise",
48
+ "patterns": "",
49
+ },
50
+ "cinematography": {
51
+ "framing": "centered",
52
+ "camera_angle": "eye-level",
53
+ "depth_of_field": "deep",
54
+ "focus": "sharp",
55
+ "lens_focal_length": "standard",
56
+ },
57
+ "style_medium": "digital render",
58
+ "artistic_style": "clean realistic render",
59
+ "context": "test prompt",
60
+ "text_and_signage_elements": [],
61
+ "quadrant_scan": {
62
+ "top_left": "",
63
+ "top_right": "",
64
+ "bottom_left": "",
65
+ "bottom_right": "",
66
+ "absolute_center": "",
67
+ },
68
+ "comprehensive_t2i_caption": caption,
69
+ "resolution": {"H": 960, "W": 960},
70
+ "aspect_ratio": "1,1",
71
+ }
72
+
73
+
74
+ class FakeChatClient:
75
+ messages: list[dict[str, Any]]
76
+ response_format_json: bool
77
+
78
+ def __init__(self, response: dict[str, Any]) -> None:
79
+ self.response = response
80
+ self.messages = []
81
+ self.response_format_json = False
82
+
83
+ def complete(self, messages: list[dict[str, Any]], *, response_format_json: bool = False) -> str:
84
+ self.messages = messages
85
+ self.response_format_json = response_format_json
86
+ return json.dumps(self.response)
87
+
88
+
89
+ def test_defaults_are_public_provider_defaults() -> None:
90
+ assert DEFAULT_REWRITER_MODEL == "gpt-5.5"
91
+ assert DEFAULT_LLM_EXTRA_BODY == {"reasoning_effort": "low"}
92
+ assert DEFAULT_CRITIC_MODEL == "gemini-3.1-pro-preview"
93
+ assert DEFAULT_CRITIC_ENDPOINT_URL == "https://generativelanguage.googleapis.com/v1beta/openai/"
94
+
95
+
96
+ def test_gemini_openai_compatible_base_url_is_not_modified() -> None:
97
+ assert (
98
+ normalize_openai_base_url("https://generativelanguage.googleapis.com/v1beta/openai/")
99
+ == "https://generativelanguage.googleapis.com/v1beta/openai"
100
+ )
101
+ assert (
102
+ normalize_openai_base_url("https://generativelanguage.googleapis.com/v1beta/openai/chat/completions")
103
+ == "https://generativelanguage.googleapis.com/v1beta/openai"
104
+ )
105
+
106
+
107
+ def test_prompt_loaders_support_text_jsonl_and_csv(tmp_path: Path) -> None:
108
+ txt_path = tmp_path / "prompts.txt"
109
+ txt_path.write_text("one\n\ntwo\n", encoding="utf-8")
110
+ assert [item.prompt for item in load_prompt_items(prompts_path=txt_path)] == ["one", "two"]
111
+
112
+ jsonl_path = tmp_path / "prompts.jsonl"
113
+ jsonl_path.write_text('{"id":"custom id","prompt":"three"}\n"four"\n', encoding="utf-8")
114
+ jsonl_items = load_prompt_items(prompts_path=jsonl_path)
115
+ assert [item.prompt for item in jsonl_items] == ["three", "four"]
116
+ assert prompt_dir_name(jsonl_items[0]) == "custom_id"
117
+
118
+ csv_path = tmp_path / "prompts.csv"
119
+ csv_path.write_text("id,prompt\nfive_id,five\n", encoding="utf-8")
120
+ csv_items = load_prompt_items(prompts_path=csv_path)
121
+ assert csv_items[0].prompt_id == "five_id"
122
+ assert csv_items[0].prompt == "five"
123
+
124
+
125
+ def test_prompt_upsampler_applies_resolution_and_requests_json() -> None:
126
+ prompt_json = _valid_t2i_prompt("initial cube prompt")
127
+ fake_client = FakeChatClient(prompt_json)
128
+ upsampler = Text2ImagePromptUpsampler(fake_client) # type: ignore[arg-type]
129
+
130
+ result = upsampler.upsample("a cube", prompt_id="cube", resolution="720", aspect_ratio="16,9")
131
+
132
+ assert result["resolution"] == {"H": 720, "W": 1280}
133
+ assert result["aspect_ratio"] == "16,9"
134
+ assert fake_client.response_format_json is True
135
+
136
+
137
+ def test_apply_t2i_output_parameters_rejects_bad_canvas() -> None:
138
+ try:
139
+ apply_t2i_output_parameters(_valid_t2i_prompt("x"), resolution="999", aspect_ratio="1,1")
140
+ except ValueError as exc:
141
+ assert "Unsupported resolution" in str(exc)
142
+ else:
143
+ raise AssertionError("Expected unsupported resolution error.")
144
+
145
+
146
+ def test_prompt_rewriter_joint_rewrite_uses_vlm_feedback() -> None:
147
+ previous_prompt = _valid_t2i_prompt("old cube prompt")
148
+ rewritten_prompt = _valid_t2i_prompt("new cube prompt with no 4x4 grid")
149
+ analysis = {
150
+ "overall_score": 2.0,
151
+ "prompt_adherence_score": 3.0,
152
+ "category_score": 3.0,
153
+ "issues": [
154
+ {
155
+ "category": "geometry",
156
+ "description": "Generated a 4x4 grid instead of a 3x3 cube.",
157
+ "severity": "severe",
158
+ }
159
+ ],
160
+ "improvement_directives": ["Strictly enforce 3x3x3 geometry."],
161
+ "raw_response": "large omitted blob",
162
+ }
163
+ rewriter = PromptRewriterClient(api_token="unused")
164
+ fake_client = FakeChatClient({"positive_prompt": rewritten_prompt, "negative_prompt": "4x4 grid"})
165
+ rewriter.rewrite_client = fake_client # type: ignore[assignment]
166
+
167
+ positive_prompt, negative_prompt = rewriter.rewrite_prompt_pair(
168
+ _item("39", "A Rubik's cube mid twist with the top layer rotated exactly 45 degrees"),
169
+ previous_prompt,
170
+ "",
171
+ analysis,
172
+ [{"iteration": 0, "analysis": analysis}],
173
+ )
174
+
175
+ assert positive_prompt["comprehensive_t2i_caption"] == "new cube prompt with no 4x4 grid"
176
+ assert negative_prompt == "4x4 grid"
177
+ assert fake_client.response_format_json is True
178
+ user_message = str(fake_client.messages[1]["content"])
179
+ assert "Generated a 4x4 grid" in user_message
180
+ assert "Strictly enforce 3x3x3 geometry" in user_message
181
+ assert "raw_response" not in user_message
182
+
183
+
184
+ def test_generation_payload_uses_vllm_omni_images_api() -> None:
185
+ client = ImageGenerationClient(endpoint="https://example.test/v1", model="test/model")
186
+ payload = client.build_payload({"comprehensive_t2i_caption": "x"}, prompt_id="3", seed=100, negative_prompt="blur")
187
+
188
+ assert client.endpoint == "https://example.test"
189
+ assert payload["model"] == "test/model"
190
+ assert payload["prompt"] == '{"comprehensive_t2i_caption":"x"}'
191
+ assert payload["size"] == "1024x1024"
192
+ assert payload["n"] == 1
193
+ assert payload["response_format"] == "b64_json"
194
+ assert payload["negative_prompt"] == "blur"
195
+ assert payload["num_inference_steps"] == 50
196
+ assert payload["guidance_scale"] == 4.0
197
+ assert payload["flow_shift"] == DEFAULT_FLOW_SHIFT
198
+ assert payload["extra_args"] == DEFAULT_GENERATION_EXTRA_ARGS
199
+ assert payload["seed"] == 100
200
+ assert "model_mode" not in payload
201
+ assert "prompt_upsampling" not in payload
202
+
203
+
204
+ def test_generation_payload_allows_custom_extra_args() -> None:
205
+ client = ImageGenerationClient(endpoint="https://example.test", extra_args={"guardrails": True})
206
+ payload = client.build_payload({"comprehensive_t2i_caption": "x"}, prompt_id="3")
207
+
208
+ assert payload["extra_args"] == {"guardrails": True}
209
+
210
+
211
+ class FakeImageResponse:
212
+ ok: bool = True
213
+ status_code: int = 200
214
+ text: str = "ok"
215
+
216
+ def __init__(self, payload: dict[str, Any]) -> None:
217
+ self.payload = payload
218
+
219
+ def json(self) -> dict[str, Any]:
220
+ return self.payload
221
+
222
+
223
+ class FakeImageSession:
224
+ calls: list[dict[str, Any]]
225
+
226
+ def __init__(self, response_payload: dict[str, Any]) -> None:
227
+ self.response_payload = response_payload
228
+ self.calls = []
229
+
230
+ def request(self, method: str, url: str, **kwargs: Any) -> FakeImageResponse:
231
+ self.calls.append({"method": method, "url": url, "kwargs": kwargs})
232
+ return FakeImageResponse(self.response_payload)
233
+
234
+
235
+ def _tiny_png_b64() -> str:
236
+ buf = io.BytesIO()
237
+ Image.new("RGB", (4, 4), (0, 255, 0)).save(buf, format="PNG")
238
+ return base64.b64encode(buf.getvalue()).decode("ascii")
239
+
240
+
241
+ def test_generation_client_decodes_vllm_omni_b64_response(tmp_path: Path) -> None:
242
+ session = FakeImageSession({"created": 1, "data": [{"b64_json": _tiny_png_b64(), "revised_prompt": None}]})
243
+ client = ImageGenerationClient(endpoint="example.test", auth_key="secret-token", session=session) # type: ignore[arg-type]
244
+
245
+ result = client.generate(prompt_json=_valid_t2i_prompt("x"), prompt_id="3", output_dir=tmp_path, seed=5)
246
+
247
+ assert result.image_path.exists()
248
+ assert session.calls[0]["method"] == "POST"
249
+ assert session.calls[0]["url"] == "https://example.test/v1/images/generations"
250
+ assert session.calls[0]["kwargs"]["headers"] == {"Authorization": "Bearer secret-token"}
251
+ assert session.calls[0]["kwargs"]["json"]["model"] == DEFAULT_GENERATION_MODEL
252
+ meta = json.loads(result.meta_path.read_text(encoding="utf-8"))
253
+ assert meta["status"] == "completed"
254
+ assert meta["response"]["data"][0]["b64_json"].startswith("<base64 image omitted:")
255
+
256
+
257
+ def test_parse_analysis_response_sets_threshold_flag() -> None:
258
+ analysis = parse_analysis_response(
259
+ """
260
+ {
261
+ "prompt_adherence_score": 9,
262
+ "visual_quality_score": 9,
263
+ "aesthetics_score": 8.5,
264
+ "physical_plausibility_score": 8,
265
+ "category_score": 9,
266
+ "text_rendering_score": 9,
267
+ "photorealism_score": null,
268
+ "overall_score": 9.1,
269
+ "issues": [],
270
+ "category_findings": {},
271
+ "improvement_directives": [],
272
+ "rationale": "Strong."
273
+ }
274
+ """,
275
+ )
276
+ assert analysis["threshold_cleared"] is True
277
+
278
+
279
+ class FakeRewriter:
280
+ initial_calls: int
281
+ joint_rewrite_calls: int
282
+ previous_scores: list[float]
283
+
284
+ def __init__(self) -> None:
285
+ self.initial_calls = 0
286
+ self.joint_rewrite_calls = 0
287
+ self.previous_scores = []
288
+
289
+ def initial_prompt(self, item: PromptItem) -> dict[str, Any]:
290
+ self.initial_calls += 1
291
+ return _valid_t2i_prompt(f"initial {item.prompt_id}")
292
+
293
+ def rewrite_prompt_pair(
294
+ self,
295
+ item: PromptItem,
296
+ previous_prompt: dict[str, Any],
297
+ previous_negative_prompt: str,
298
+ previous_analysis: dict[str, Any],
299
+ history: list[dict[str, Any]],
300
+ ) -> tuple[dict[str, Any], str]:
301
+ self.joint_rewrite_calls += 1
302
+ self.previous_scores.append(float(previous_analysis["overall_score"]))
303
+ return _valid_t2i_prompt(f"rewrite {len(history)}"), f"negative {len(history)}"
304
+
305
+
306
+ @dataclass(frozen=True, slots=True)
307
+ class FakeGeneration:
308
+ image_path: Path
309
+ meta_path: Path
310
+ meta: dict[str, Any]
311
+
312
+
313
+ class FakeGenerator:
314
+ seeds: list[int | None]
315
+ negative_prompts: list[str]
316
+
317
+ def __init__(self) -> None:
318
+ self.seeds = []
319
+ self.negative_prompts = []
320
+
321
+ def generate(
322
+ self,
323
+ *,
324
+ prompt_json: dict[str, Any],
325
+ prompt_id: str,
326
+ output_dir: Path,
327
+ seed: int | None = None,
328
+ negative_prompt: str = "",
329
+ jpeg_quality: int = 95,
330
+ ) -> FakeGeneration:
331
+ self.seeds.append(seed)
332
+ self.negative_prompts.append(negative_prompt)
333
+ output_dir.mkdir(parents=True, exist_ok=True)
334
+ image_path = output_dir / "image.jpg"
335
+ Image.new("RGB", (8, 8), (255, 0, 0)).save(image_path)
336
+ meta_path = output_dir / "generation_meta.json"
337
+ meta_path.write_text('{"status":"completed"}\n', encoding="utf-8")
338
+ return FakeGeneration(image_path=image_path, meta_path=meta_path, meta={"status": "completed"})
339
+
340
+
341
+ class BarrierGenerator(FakeGenerator):
342
+ barrier: threading.Barrier
343
+ lock: threading.Lock
344
+
345
+ def __init__(self, parties: int) -> None:
346
+ super().__init__()
347
+ self.barrier = threading.Barrier(parties)
348
+ self.lock = threading.Lock()
349
+
350
+ def generate(
351
+ self,
352
+ *,
353
+ prompt_json: dict[str, Any],
354
+ prompt_id: str,
355
+ output_dir: Path,
356
+ seed: int | None = None,
357
+ negative_prompt: str = "",
358
+ jpeg_quality: int = 95,
359
+ ) -> FakeGeneration:
360
+ with self.lock:
361
+ self.seeds.append(seed)
362
+ self.negative_prompts.append(negative_prompt)
363
+ self.barrier.wait(timeout=2.0)
364
+ output_dir.mkdir(parents=True, exist_ok=True)
365
+ image_path = output_dir / "image.jpg"
366
+ Image.new("RGB", (8, 8), (255, 0, 0)).save(image_path)
367
+ meta_path = output_dir / "generation_meta.json"
368
+ meta_path.write_text('{"status":"completed"}\n', encoding="utf-8")
369
+ return FakeGeneration(image_path=image_path, meta_path=meta_path, meta={"status": "completed"})
370
+
371
+
372
+ class FakeJudge:
373
+ calls: int
374
+ scores: list[float]
375
+
376
+ def __init__(self, scores: list[float]) -> None:
377
+ self.calls = 0
378
+ self.scores = scores
379
+
380
+ def score_image(
381
+ self,
382
+ *,
383
+ item: PromptItem,
384
+ image_path: Path,
385
+ ) -> dict[str, Any]:
386
+ score = self.scores[self.calls]
387
+ self.calls += 1
388
+ return {
389
+ "overall_score": score,
390
+ "prompt_adherence_score": score,
391
+ "visual_quality_score": score,
392
+ "aesthetics_score": score,
393
+ "physical_plausibility_score": score,
394
+ "category_score": score,
395
+ "issues": [],
396
+ "improvement_directives": [],
397
+ "threshold_cleared": score >= 9,
398
+ }
399
+
400
+
401
+ def test_runner_early_stops_by_default(tmp_path: Path) -> None:
402
+ rewriter = FakeRewriter()
403
+ generator = FakeGenerator()
404
+ runner = AgenticUpsamplerRunner(
405
+ rewriter=rewriter,
406
+ generator=generator, # type: ignore[arg-type]
407
+ judge=FakeJudge([9.1, 8.0]),
408
+ config=RunnerConfig(output_dir=tmp_path, max_iterations=3, samples_per_iteration=1),
409
+ )
410
+
411
+ result = runner.run_item(_item())
412
+
413
+ assert result["best_iteration"] == 0
414
+ assert rewriter.initial_calls == 1
415
+ assert rewriter.joint_rewrite_calls == 0
416
+ assert generator.seeds == [None]
417
+
418
+
419
+ def test_runner_can_disable_early_stop_and_select_best_sample(tmp_path: Path) -> None:
420
+ rewriter = FakeRewriter()
421
+ generator = FakeGenerator()
422
+ runner = AgenticUpsamplerRunner(
423
+ rewriter=rewriter,
424
+ generator=generator, # type: ignore[arg-type]
425
+ judge=FakeJudge([5.0, 9.0, 7.0, 6.0, 10.0, 8.0]),
426
+ config=RunnerConfig(
427
+ output_dir=tmp_path,
428
+ max_iterations=2,
429
+ samples_per_iteration=3,
430
+ seed_base=1000,
431
+ early_stop=False,
432
+ ),
433
+ )
434
+
435
+ result = runner.run_item(_item("8", "exactly 12 balloons with exact color counts"))
436
+
437
+ assert generator.seeds == [1000, 1001, 1002, 1000, 1001, 1002]
438
+ assert rewriter.previous_scores == [9.0]
439
+ assert result["best_iteration"] == 1
440
+ assert result["best"]["selected_sample_index"] == 1
441
+ assert result["iterations"][0]["selected_sample_index"] == 1
442
+
443
+
444
+ def test_runner_generates_seed_samples_in_parallel(tmp_path: Path) -> None:
445
+ rewriter = FakeRewriter()
446
+ generator = BarrierGenerator(parties=3)
447
+ runner = AgenticUpsamplerRunner(
448
+ rewriter=rewriter,
449
+ generator=generator, # type: ignore[arg-type]
450
+ judge=FakeJudge([5.0, 6.0, 7.0]),
451
+ config=RunnerConfig(
452
+ output_dir=tmp_path,
453
+ max_iterations=1,
454
+ samples_per_iteration=3,
455
+ seed_base=2000,
456
+ early_stop=False,
457
+ ),
458
+ )
459
+
460
+ result = runner.run_item(_item("parallel", "a parallel seed test"))
461
+
462
+ assert sorted(generator.seeds) == [2000, 2001, 2002]
463
+ assert result["best"]["selected_sample_index"] == 2
464
+ assert result["iterations"][0]["sample_count"] == 3
465
+
466
+
467
+ def test_extract_best_images_copies_images_and_writes_manifests(tmp_path: Path) -> None:
468
+ output_dir = tmp_path / "run"
469
+ image_dir = output_dir / "0001" / "iter_00"
470
+ image_dir.mkdir(parents=True)
471
+ image_path = image_dir / "image.jpg"
472
+ Image.new("RGB", (8, 8), (255, 0, 0)).save(image_path)
473
+ best_json = {
474
+ "prompt_id": "1",
475
+ "prompt": "a red square",
476
+ "best_iteration": 0,
477
+ "best_score": 9.25,
478
+ "threshold_cleared_any": True,
479
+ "best": {
480
+ "selected_sample_index": 0,
481
+ "image_path": str(image_path),
482
+ "analysis_path": str(image_dir / "analysis.json"),
483
+ },
484
+ "iterations": [],
485
+ }
486
+ (output_dir / "0001" / "best.json").write_text(json.dumps(best_json), encoding="utf-8")
487
+
488
+ records = extract_best_images(output_dir, tmp_path / "export")
489
+
490
+ assert len(records) == 1
491
+ copied_path = Path(records[0]["copied_image_path"])
492
+ assert copied_path.exists()
493
+ assert copied_path.name == "1.jpg"
494
+ assert (tmp_path / "export" / "best_generations.jsonl").exists()
495
+ assert (tmp_path / "export" / "best_generations.csv").exists()
496
+
text_tokenizer/added_tokens.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<think>": 151667,
6
+ "<tool_call>": 151657,
7
+ "<tool_response>": 151665,
8
+ "<|box_end|>": 151649,
9
+ "<|box_start|>": 151648,
10
+ "<|endoftext|>": 151643,
11
+ "<|file_sep|>": 151664,
12
+ "<|fim_middle|>": 151660,
13
+ "<|fim_pad|>": 151662,
14
+ "<|fim_prefix|>": 151659,
15
+ "<|fim_suffix|>": 151661,
16
+ "<|im_end|>": 151645,
17
+ "<|im_start|>": 151644,
18
+ "<|image_pad|>": 151655,
19
+ "<|object_ref_end|>": 151647,
20
+ "<|object_ref_start|>": 151646,
21
+ "<|quad_end|>": 151651,
22
+ "<|quad_start|>": 151650,
23
+ "<|repo_name|>": 151663,
24
+ "<|video_pad|>": 151656,
25
+ "<|vision_end|>": 151653,
26
+ "<|vision_pad|>": 151654,
27
+ "<|vision_start|>": 151652
28
+ }
text_tokenizer/chat_template.jinja ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {%- if messages[0].content is string %}
5
+ {{- messages[0].content }}
6
+ {%- else %}
7
+ {%- for content in messages[0].content %}
8
+ {%- if 'text' in content %}
9
+ {{- content.text }}
10
+ {%- endif %}
11
+ {%- endfor %}
12
+ {%- endif %}
13
+ {{- '\n\n' }}
14
+ {%- endif %}
15
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
16
+ {%- for tool in tools %}
17
+ {{- "\n" }}
18
+ {{- tool | tojson }}
19
+ {%- endfor %}
20
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
21
+ {%- else %}
22
+ {%- if messages[0].role == 'system' %}
23
+ {{- '<|im_start|>system\n' }}
24
+ {%- if messages[0].content is string %}
25
+ {{- messages[0].content }}
26
+ {%- else %}
27
+ {%- for content in messages[0].content %}
28
+ {%- if 'text' in content %}
29
+ {{- content.text }}
30
+ {%- endif %}
31
+ {%- endfor %}
32
+ {%- endif %}
33
+ {{- '<|im_end|>\n' }}
34
+ {%- endif %}
35
+ {%- endif %}
36
+ {%- set image_count = namespace(value=0) %}
37
+ {%- set video_count = namespace(value=0) %}
38
+ {%- for message in messages %}
39
+ {%- if message.role == "user" %}
40
+ {{- '<|im_start|>' + message.role + '\n' }}
41
+ {%- if message.content is string %}
42
+ {{- message.content }}
43
+ {%- else %}
44
+ {%- for content in message.content %}
45
+ {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}
46
+ {%- set image_count.value = image_count.value + 1 %}
47
+ {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}
48
+ <|vision_start|><|image_pad|><|vision_end|>
49
+ {%- elif content.type == 'video' or 'video' in content %}
50
+ {%- set video_count.value = video_count.value + 1 %}
51
+ {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}
52
+ <|vision_start|><|video_pad|><|vision_end|>
53
+ {%- elif 'text' in content %}
54
+ {{- content.text }}
55
+ {%- endif %}
56
+ {%- endfor %}
57
+ {%- endif %}
58
+ {{- '<|im_end|>\n' }}
59
+ {%- elif message.role == "assistant" %}
60
+ {{- '<|im_start|>' + message.role + '\n' }}
61
+ {%- if message.content is string %}
62
+ {{- message.content }}
63
+ {%- else %}
64
+ {%- for content_item in message.content %}
65
+ {%- if 'text' in content_item %}
66
+ {{- content_item.text }}
67
+ {%- endif %}
68
+ {%- endfor %}
69
+ {%- endif %}
70
+ {%- if message.tool_calls %}
71
+ {%- for tool_call in message.tool_calls %}
72
+ {%- if (loop.first and message.content) or (not loop.first) %}
73
+ {{- '\n' }}
74
+ {%- endif %}
75
+ {%- if tool_call.function %}
76
+ {%- set tool_call = tool_call.function %}
77
+ {%- endif %}
78
+ {{- '<tool_call>\n{"name": "' }}
79
+ {{- tool_call.name }}
80
+ {{- '", "arguments": ' }}
81
+ {%- if tool_call.arguments is string %}
82
+ {{- tool_call.arguments }}
83
+ {%- else %}
84
+ {{- tool_call.arguments | tojson }}
85
+ {%- endif %}
86
+ {{- '}\n</tool_call>' }}
87
+ {%- endfor %}
88
+ {%- endif %}
89
+ {{- '<|im_end|>\n' }}
90
+ {%- elif message.role == "tool" %}
91
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
92
+ {{- '<|im_start|>user' }}
93
+ {%- endif %}
94
+ {{- '\n<tool_response>\n' }}
95
+ {%- if message.content is string %}
96
+ {{- message.content }}
97
+ {%- else %}
98
+ {%- for content in message.content %}
99
+ {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}
100
+ {%- set image_count.value = image_count.value + 1 %}
101
+ {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}
102
+ <|vision_start|><|image_pad|><|vision_end|>
103
+ {%- elif content.type == 'video' or 'video' in content %}
104
+ {%- set video_count.value = video_count.value + 1 %}
105
+ {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}
106
+ <|vision_start|><|video_pad|><|vision_end|>
107
+ {%- elif 'text' in content %}
108
+ {{- content.text }}
109
+ {%- endif %}
110
+ {%- endfor %}
111
+ {%- endif %}
112
+ {{- '\n</tool_response>' }}
113
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
114
+ {{- '<|im_end|>\n' }}
115
+ {%- endif %}
116
+ {%- endif %}
117
+ {%- endfor %}
118
+ {%- if add_generation_prompt %}
119
+ {{- '<|im_start|>assistant\n' }}
120
+ {%- endif %}
text_tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
text_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
text_tokenizer/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
3
+ size 11422654
text_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "bos_token": null,
230
+ "clean_up_tokenization_spaces": false,
231
+ "eos_token": "<|im_end|>",
232
+ "errors": "replace",
233
+ "extra_special_tokens": {},
234
+ "model_max_length": 262144,
235
+ "pad_token": "<|endoftext|>",
236
+ "split_special_tokens": false,
237
+ "tokenizer_class": "Qwen2Tokenizer",
238
+ "unk_token": null
239
+ }
text_tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "bos_token": null,
230
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {%- if messages[0].content is string %}\n {{- messages[0].content }}\n {%- else %}\n {%- for content in messages[0].content %}\n {%- if 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].content is string %}\n {{- messages[0].content }}\n {%- else %}\n {%- for content in messages[0].content %}\n {%- if 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set image_count = namespace(value=0) %}\n{%- set video_count = namespace(value=0) %}\n{%- for message in messages %}\n {%- if message.role == \"user\" %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content in message.content %}\n {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}\n {%- set image_count.value = image_count.value + 1 %}\n {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}\n <|vision_start|><|image_pad|><|vision_end|>\n {%- elif content.type == 'video' or 'video' in content %}\n {%- set video_count.value = video_count.value + 1 %}\n {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}\n <|vision_start|><|video_pad|><|vision_end|>\n {%- elif 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content_item in message.content %}\n {%- if 'text' in content_item %}\n {{- content_item.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and message.content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content in message.content %}\n {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}\n {%- set image_count.value = image_count.value + 1 %}\n {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}\n <|vision_start|><|image_pad|><|vision_end|>\n {%- elif content.type == 'video' or 'video' in content %}\n {%- set video_count.value = video_count.value + 1 %}\n {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}\n <|vision_start|><|video_pad|><|vision_end|>\n {%- elif 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
231
+ "clean_up_tokenization_spaces": false,
232
+ "eos_token": "<|im_end|>",
233
+ "errors": "replace",
234
+ "model_max_length": 262144,
235
+ "pad_token": "<|endoftext|>",
236
+ "split_special_tokens": false,
237
+ "tokenizer_class": "Qwen2Tokenizer",
238
+ "unk_token": null
239
+ }
transformer/config.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "Cosmos3OmniTransformer",
3
+ "_diffusers_version": "0.37.1",
4
+ "action_dim": 32,
5
+ "action_gen": false,
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "base_fps": 24,
9
+ "dtype": "bfloat16",
10
+ "enable_fps_modulation": true,
11
+ "freeze_und": false,
12
+ "head_dim": 128,
13
+ "hidden_act": "silu",
14
+ "hidden_size": 5120,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 25600,
17
+ "joint_attn_implementation": "two_way",
18
+ "latent_channel": 48,
19
+ "latent_patch_size": 2,
20
+ "max_action_dim": 32,
21
+ "max_position_embeddings": 262144,
22
+ "model_type": "qwen3_vl_text",
23
+ "num_attention_heads": 64,
24
+ "num_embodiment_domains": 32,
25
+ "num_hidden_layers": 64,
26
+ "num_key_value_heads": 8,
27
+ "patch_latent_dim": 192,
28
+ "position_embedding_type": "unified_3d_mrope",
29
+ "qk_norm": false,
30
+ "qk_norm_for_diffusion": true,
31
+ "qk_norm_for_text": true,
32
+ "rms_norm_eps": 1e-06,
33
+ "rope_scaling": {
34
+ "mrope_interleaved": true,
35
+ "mrope_section": [
36
+ 24,
37
+ 20,
38
+ 20
39
+ ],
40
+ "rope_type": "default"
41
+ },
42
+ "rope_theta": 5000000,
43
+ "sound_dim": 64,
44
+ "sound_gen": true,
45
+ "sound_latent_fps": 25,
46
+ "temporal_compression_factor_sound": 1,
47
+ "timestep_scale": 0.001,
48
+ "unified_3d_mrope_reset_spatial_ids": true,
49
+ "unified_3d_mrope_temporal_modality_margin": 15000,
50
+ "use_cache": true,
51
+ "use_moe": true,
52
+ "video_temporal_causal": false,
53
+ "vocab_size": 151936
54
+ }
transformer/diffusion_pytorch_model-00001-of-00027.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28fe2fcd84de5c3e1a26a5224fda8da81b13ba6e58cf6073460f9b04403a33d6
3
+ size 4932297056