Cosmos3-Super-Text2Image NVIDIA ModelOpt NVFP4 Transformer

This repository contains a transformer-only NVIDIA ModelOpt NVFP4 quantization of nvidia/Cosmos3-Super-Text2Image.

It does not repeat the original model card. Read NVIDIA's model card, prompt-format guidance, license, and safety notes here: nvidia/Cosmos3-Super-Text2Image.

Only transformer/ is provided as a weight artifact. The VAE, scheduler, tokenizers, safety checker, and other components are loaded from the base model.

Recipe

| Setting | Value |
| --- | --- |
| Quantizer | NVIDIA ModelOpt |
| ModelOpt version | 0.44.0 |
| Quant type | NVFP4_DEFAULT_CFG |
| Weight-only | True |
| Compressed | True |
| Quantized modules inserted | 2709 |
| Quantization time | 1.98s |
| Compress time | 3.93s |
| Save time | 37.53s |
| Transformer checkpoint size | 35.62 GiB |

The checkpoint includes ModelOpt state in transformer/modelopt_state.pth.
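For readers new to the format: NVFP4 stores each weight as a 4-bit E2M1 value, and each block of 16 consecutive values shares an FP8 (E4M3) scale, with an additional per-tensor FP32 scale on top. The pure-Python sketch below fake-quantizes values block by block to illustrate the idea. It is a simplification and not ModelOpt's implementation: the block scale is kept as a full-precision float instead of being rounded to E4M3, the per-tensor scale is omitted, and ties round to the smaller grid value. The names `quantize_block` and `quantize` are hypothetical.

```python
# Magnitudes representable by FP4 E2M1, in bit-pattern order (index == 3-bit code).
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK_SIZE = 16  # NVFP4 shares one scale per 16 consecutive values


def quantize_block(block):
    """Fake-quantize one block and return (scale, codes, dequantized values).

    Simplification: the scale stays a Python float; real NVFP4 rounds it to
    FP8 E4M3 and also applies a per-tensor FP32 scale.
    """
    assert len(block) <= BLOCK_SIZE
    amax = max(abs(x) for x in block)
    # Map the largest magnitude in the block onto the largest E2M1 value (6.0).
    scale = amax / E2M1_GRID[-1] if amax > 0 else 1.0
    codes, dequant = [], []
    for x in block:
        scaled = abs(x) / scale
        # Nearest grid point; min() keeps the first (smaller) index on ties.
        idx = min(range(len(E2M1_GRID)), key=lambda i: abs(E2M1_GRID[i] - scaled))
        sign = -1.0 if x < 0 else 1.0
        codes.append((8 if x < 0 else 0) | idx)  # sign bit + 3-bit magnitude
        dequant.append(sign * E2M1_GRID[idx] * scale)
    return scale, codes, dequant


def quantize(values):
    """Quantize a flat list of floats block by block."""
    return [quantize_block(values[i:i + BLOCK_SIZE])
            for i in range(0, len(values), BLOCK_SIZE)]


weights = [0.12, -0.5, 0.03, 0.9, -1.2, 0.0, 0.33, 0.6,
           -0.07, 0.25, 1.5, -0.8, 0.44, -0.31, 0.05, 0.7]
scale, codes, dequant = quantize(weights)[0]
print(scale, codes[:4], dequant[:4])  # 0.25 [1, 12, 0, 6] [0.125, -0.5, 0.0, 1.0]
```

Each value costs 4 bits plus a shared 8-bit scale per 16 values, roughly 4.5 bits per weight before the negligible per-tensor scale.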

Assemble The Pipeline

Install ModelOpt in the same environment as Diffusers:

```shell
pip install "nvidia_modelopt[hf]"
```

With the tested versions, restoring ModelOpt QTensorWrapper parameters through Diffusers and Accelerate requires a small compatibility patch, included below. Important: load the quantized transformer without passing torch_dtype; otherwise Diffusers casts the quantized tensors back to BF16 during state-dict loading. Cast the remaining floating-point tensors to BF16 after loading instead.

```python
import json
import torch
from diffusers import Cosmos3OmniPipeline, Cosmos3OmniTransformer
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from modelopt.torch.quantization.qtensor.base_qtensor import QTensorWrapper
import modelopt.torch.opt as mto


def patch_modelopt_qtensor_loader():
    """Keep ModelOpt QTensorWrapper parameters intact during Diffusers/Accelerate loading."""
    import accelerate.utils.modeling as accelerate_modeling
    import diffusers.models.model_loading_utils as diffusers_loading

    original = accelerate_modeling.set_module_tensor_to_device
    if getattr(original, "_cosmos3_modelopt_patch", False):
        return  # already patched

    def patched(module, tensor_name, device, value=None, dtype=None, fp16_statistics=None,
                tied_params_map=None, non_blocking=False, clear_cache=True):
        # Resolve a dotted parameter name to its owning submodule and attribute name.
        leaf_module = module
        leaf_name = tensor_name
        if "." in tensor_name:
            parts = tensor_name.split(".")
            for part in parts[:-1]:
                leaf_module = getattr(leaf_module, part)
            leaf_name = parts[-1]
        old_value = getattr(leaf_module, leaf_name) if hasattr(leaf_module, leaf_name) else None
        if isinstance(old_value, QTensorWrapper) and value is not None:
            # Re-wrap the incoming quantized data with the existing metadata
            # instead of letting Accelerate assign it as a plain tensor.
            leaf_module._parameters[leaf_name] = QTensorWrapper(
                value.to(device, non_blocking=non_blocking),
                metadata=old_value.metadata,
            )
            return
        return original(module, tensor_name, device, value, dtype, fp16_statistics,
                        tied_params_map, non_blocking, clear_cache)

    patched._cosmos3_modelopt_patch = True
    # Diffusers imports the function by name, so patch both references.
    accelerate_modeling.set_module_tensor_to_device = patched
    diffusers_loading.set_module_tensor_to_device = patched


def cast_modelopt_runtime_tensors(model, dtype=torch.bfloat16):
    """Set the runtime dtype without casting the quantized weight data."""
    for module in model.modules():
        for name, param in list(module._parameters.items()):
            if isinstance(param, QTensorWrapper):
                # Quantized weights keep their storage; only the dtype in their metadata changes.
                param.metadata["dtype"] = dtype
            elif param is not None and param.is_floating_point():
                # Unquantized floating-point parameters are cast directly.
                module._parameters[name] = torch.nn.Parameter(
                    param.detach().to(dtype),
                    requires_grad=param.requires_grad,
                )
        for name, buf in list(module._buffers.items()):
            if buf is not None and buf.is_floating_point():
                module._buffers[name] = buf.to(dtype)
    return model


patch_modelopt_qtensor_loader()
# Let from_pretrained restore the ModelOpt state saved in transformer/modelopt_state.pth.
mto.enable_huggingface_checkpointing()

# Do not pass torch_dtype here; cast after loading instead.
transformer = Cosmos3OmniTransformer.from_pretrained(
    "WaveCut/Cosmos3-Super-Text2Image-ModelOpt-NVFP4-Transformer",
    subfolder="transformer",
    use_safetensors=False,
)
transformer = cast_modelopt_runtime_tensors(transformer, torch.bfloat16)

# VAE, scheduler, tokenizers, and safety checker come from the base model.
pipe = Cosmos3OmniPipeline.from_pretrained(
    "nvidia/Cosmos3-Super-Text2Image",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    enable_safety_checker=True,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=3.0)
pipe.to("cuda")

json_caption = {
    "subjects": [],
    "background_setting": "A concise scene description.",
    "comprehensive_t2i_caption": "A detailed natural-language caption.",
    "resolution": {"H": 1024, "W": 1024},
    "aspect_ratio": "1,1",
}

# The benchmarks below wrap the pipeline call in BF16 autocast.
with torch.autocast("cuda", dtype=torch.bfloat16):
    result = pipe(
        prompt=json.dumps(json_caption),
        negative_prompt="",
        num_frames=1,
        height=1024,
        width=1024,
        num_inference_steps=50,
        guidance_scale=4.0,
        generator=torch.Generator(device="cuda").manual_seed(1143),
    )
result.video[0].save("cosmos3_modelopt_nvfp4.png")
```
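When generating at a resolution other than 1024x1024, the caption's resolution and aspect_ratio fields presumably need to match the height and width passed to the pipeline. The helper below is a hypothetical sketch: the field names are copied from the example caption above, and it assumes aspect_ratio is the reduced width-to-height ratio written as "W,H". Check NVIDIA's prompt-format guidance for the exact convention and the supported resolutions before relying on it.

```python
import json
from math import gcd


def make_caption(background, caption, height, width, subjects=None):
    """Build a JSON-caption dict (hypothetical helper).

    Assumption: aspect_ratio is the reduced width:height ratio, written "W,H".
    """
    divisor = gcd(width, height)
    return {
        "subjects": subjects or [],
        "background_setting": background,
        "comprehensive_t2i_caption": caption,
        "resolution": {"H": height, "W": width},
        "aspect_ratio": f"{width // divisor},{height // divisor}",
    }


caption = make_caption(
    "A concise scene description.",
    "A detailed natural-language caption.",
    height=1024,
    width=1024,
)
prompt = json.dumps(caption)
# Pass `prompt` to pipe(...) together with the same height and width.
```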

Benchmarks

Measured on one RunPod NVIDIA B200 instance with local container storage, cached model files, PyTorch 2.9.1+cu130, 1024x1024 image generation, 50 inference steps, guidance scale 4.0, flow_shift=3.0, system prompt enabled. The NVIDIA ModelOpt NVFP4 runtime uses BF16 autocast around the pipeline forward.

Safety settings for the full stress run: BF16 safety checker n/a and per-call safety check n/a; NVIDIA ModelOpt NVFP4 safety checker enabled and per-call safety check disabled.
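The tables report "peak sampled VRAM", which suggests memory was polled during generation rather than read from one counter; the card does not publish its harness. A minimal standard-library sketch of such a sampler, assuming a zero-argument probe that returns current usage in MiB (for example, a wrapper around torch.cuda.memory_reserved() or an NVML query; which probe the card used is not stated). The PeakSampler class is hypothetical.

```python
import threading
import time


class PeakSampler:
    """Poll a memory probe on a background thread and keep the maximum.

    Sampling can miss short spikes between polls, so the result is a lower
    bound on the true peak. Each instance can be used once.
    """

    def __init__(self, probe, interval_s=0.05):
        self.probe = probe
        self.interval_s = interval_s
        self.peak = 0
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self):
        while not self._stop.is_set():
            self.peak = max(self.peak, self.probe())
            self._stop.wait(self.interval_s)

    def __enter__(self):
        self.start_time = time.perf_counter()
        self._thread.start()
        return self

    def __exit__(self, *exc):
        self._stop.set()
        self._thread.join()
        self.peak = max(self.peak, self.probe())  # one final sample
        self.elapsed_s = time.perf_counter() - self.start_time
        return False


# Hypothetical usage with PyTorch:
# probe = lambda: torch.cuda.memory_reserved() / 2**20
# with PeakSampler(probe) as sampler:
#     result = pipe(...)
# print(f"{sampler.elapsed_s:.2f}s, peak {sampler.peak:,.0f} MiB")
```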

Transformer Component Load

| Variant | Load to CUDA | VRAM after load | Torch allocated | Torch reserved | Transformer weights |
| --- | --- | --- | --- | --- | --- |
| BF16 base transformer | 22.19s | 122,758 MiB | 122,121 MiB | 122,132 MiB | 119.21 GiB |
| NVIDIA ModelOpt NVFP4 transformer | 30.51s | 41,548 MiB | 40,335 MiB | 40,922 MiB | 35.62 GiB |
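The Transformer weights column can also be read as effective bits per weight: BF16 uses 16 bits, so scaling by the size ratio gives the average storage cost of the NVFP4 checkpoint. The snippet uses only the two sizes from the table and assumes both checkpoints hold the same set of parameters. The comparison with an idealized 4.5 bits (4-bit E2M1 values plus one 8-bit scale per 16 values) is plain arithmetic; attributing the gap to tensors left in higher precision is an assumption, since the card does not list which modules stayed unquantized.

```python
BF16_GIB = 119.21   # BF16 transformer weights, from the table above
NVFP4_GIB = 35.62   # NVIDIA ModelOpt NVFP4 transformer weights

ratio = BF16_GIB / NVFP4_GIB
effective_bits = 16 * NVFP4_GIB / BF16_GIB

# Idealized NVFP4 cost: 4-bit value + one 8-bit FP8 scale shared by 16 values
# (ignores the single per-tensor FP32 scale, which is negligible).
ideal_bits = 4 + 8 / 16

print(f"size ratio:        {ratio:.2f}x")           # 3.35x
print(f"effective bits/wt: {effective_bits:.2f}")   # 4.78
print(f"ideal NVFP4 bits:  {ideal_bits:.2f}")       # 4.50
```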

Full Pipeline Generation

The stress set is ten handwritten JSON-caption prompts designed to stress Cyrillic text, reflections, multi-object composition, anatomy, small details, and scene-following.

| Variant | Full pipeline load | VRAM after load | Torch allocated after load | Avg generation time | Min / max generation time | Peak sampled VRAM | Images |
| --- | --- | --- | --- | --- | --- | --- | --- |
| BF16 base pipeline | 31.31s | 125,134 MiB | 124,386 MiB | 16.05s | 15.51s / 17.97s | 141,104 MiB | 10 |
| NVIDIA ModelOpt NVFP4 pipeline | 48.24s | 45,060 MiB | 43,845 MiB | 108.34s | 107.77s / 110.21s | 61,728 MiB | 10 |
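On the B200 this is a memory-for-speed trade-off. The snippet below only recomputes ratios from the reported averages and peaks; it adds no new measurements.

```python
# Values from the full-pipeline stress table above (B200, 1024x1024, 50 steps).
bf16 = {"vram_after_load_mib": 125_134, "peak_mib": 141_104, "avg_gen_s": 16.05}
nvfp4 = {"vram_after_load_mib": 45_060, "peak_mib": 61_728, "avg_gen_s": 108.34}

load_saving = 1 - nvfp4["vram_after_load_mib"] / bf16["vram_after_load_mib"]
peak_saving = 1 - nvfp4["peak_mib"] / bf16["peak_mib"]
slowdown = nvfp4["avg_gen_s"] / bf16["avg_gen_s"]

print(f"VRAM after load:    {load_saving:.0%} lower")   # 64% lower
print(f"Peak sampled VRAM:  {peak_saving:.0%} lower")   # 56% lower
print(f"Average generation: {slowdown:.1f}x slower")    # 6.8x slower
```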

Original NVIDIA Example Caption

The original model repository provides assets/example_caption.json. The images below are generated locally with the same JSON-caption, seed 1143, 1024x1024, 50 steps, guidance scale 4.0.

| Variant | Pipeline load | Generation time | Peak sampled VRAM |
| --- | --- | --- | --- |
| BF16 base pipeline | 35.41s | 18.01s | 141,098 MiB |
| NVIDIA ModelOpt NVFP4 pipeline | 49.10s | 138.52s | 61,112 MiB |

BF16 reference output:

[Image: BF16 output for NVIDIA example caption]

NVIDIA ModelOpt NVFP4 output:

[Image: NVIDIA ModelOpt NVFP4 output for NVIDIA example caption]
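To reproduce this comparison, the caption file from the base repository can be loaded and serialized into the prompt string. The sketch below assumes the file is a single JSON object like the example caption in "Assemble The Pipeline"; prompt_from_caption_file is a hypothetical helper, and the download step (huggingface_hub's hf_hub_download) is left as a comment because it needs network access. Passing ensure_ascii=False keeps non-ASCII text such as Cyrillic as literal characters rather than \uXXXX escapes. Whether the pipeline expects one form or the other is not stated here, so compare against the plain json.dumps call used above.

```python
import json
from pathlib import Path


def prompt_from_caption_file(path):
    """Load a JSON caption file and return it as a prompt string (hypothetical helper)."""
    caption = json.loads(Path(path).read_text(encoding="utf-8"))
    return json.dumps(caption, ensure_ascii=False)


# Fetching the file (requires network access and huggingface_hub):
# from huggingface_hub import hf_hub_download
# path = hf_hub_download("nvidia/Cosmos3-Super-Text2Image", "assets/example_caption.json")
# prompt = prompt_from_caption_file(path)
```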

Blackwell Quant-Only Runtime Check

The same published NVIDIA ModelOpt NVFP4 transformer was also tested on a one-GPU RunPod NVIDIA RTX PRO 6000 Blackwell Server Edition instance with PyTorch 2.9.1+cu130. This was a quant-only validation run: the transformer was not re-quantized, and no BF16 baseline was attempted on the RTX PRO 6000, because the BF16 transformer component alone used about 120 GiB (122,758 MiB of VRAM after load) on the B200 run, more than the card's 96 GB of memory.

The RTX PRO 6000 run used the same image settings: 1024x1024, 50 steps, guidance scale 4.0. Safety checking was disabled for this measurement because the container did not include cosmos_guardrail; the B200 ModelOpt NVFP4 stress run above ran with the safety checker enabled and per-call safety checks disabled.

| GPU | Transformer load to CUDA | Transformer VRAM after load | Full pipeline load | Pipeline VRAM after load | Avg stress generation | Min / max stress generation | Peak sampled stress VRAM | NVIDIA example generation |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| NVIDIA B200 | 30.51s | 41,548 MiB | 48.24s | 45,060 MiB | 108.34s | 107.77s / 110.21s | 61,728 MiB | 138.52s |
| NVIDIA RTX PRO 6000 Blackwell Server Edition | 12.06s | 41,481 MiB | 19.50s | 43,753 MiB | 240.36s | 239.19s / 242.35s | 59,497 MiB | 242.47s |
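Dividing the average stress-generation time by the 50 inference steps gives a rough per-step cost on each GPU. This simple average also absorbs text encoding, VAE decoding, and other per-call overhead, so it somewhat overstates the pure denoising-step time. The two runs also differ in safety settings (the checker was disabled on the RTX PRO 6000), so the comparison is not strictly like-for-like.

```python
STEPS = 50
avg_stress_s = {
    "NVIDIA B200": 108.34,
    "NVIDIA RTX PRO 6000 Blackwell Server Edition": 240.36,
}

# Average wall-clock time per inference step, overhead included.
per_step = {gpu: seconds / STEPS for gpu, seconds in avg_stress_s.items()}
for gpu, seconds in per_step.items():
    print(f"{gpu}: {seconds:.2f} s/step")

ratio = (avg_stress_s["NVIDIA RTX PRO 6000 Blackwell Server Edition"]
         / avg_stress_s["NVIDIA B200"])
print(f"RTX PRO 6000 / B200: {ratio:.2f}x")  # 2.22x
```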

Stress Prompt Outputs

| Stress prompt | NVIDIA ModelOpt NVFP4 output |
| --- | --- |
| 01 metro archive reading room | [Image: 01 metro archive reading room] |
| 02 arctic greenhouse night shift | [Image: 02 arctic greenhouse night shift] |
| 03 control room restoration | [Image: 03 control room restoration] |
| 04 rain market cross section | [Image: 04 rain market cross section] |
| 05 manuscript restoration table | [Image: 05 manuscript restoration table] |
| 06 robotic assembly line signage | [Image: 06 robotic assembly line signage] |
| 07 kitchen storm chess table | [Image: 07 kitchen storm chess table] |
| 08 orbital cockpit cyrillic ui | [Image: 08 orbital cockpit cyrillic ui] |
| 09 flood command center | [Image: 09 flood command center] |
| 10 cyrillic newspaper press | [Image: 10 cyrillic newspaper press] |

Notes

  • Treat this as an experimental NVIDIA ModelOpt NVFP4 transformer artifact. The upstream NVIDIA card documents BF16 as the tested precision.
  • Do not pass torch_dtype=torch.bfloat16 when loading this quantized transformer; cast runtime metadata after loading as shown above.
  • The safety checker is not included in this repository; load it from the base model if your use case requires it.
  • Text rendering, especially exact Cyrillic text, remains a hard case for this model family and should be evaluated visually for the target prompt distribution.