import os

# Set before torch is imported so the CUDA allocator sees it; setdefault keeps
# any value already present in the environment.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import spaces  # noqa: F401 -- must be imported before torch / diffusers

import tempfile
from pathlib import Path

import gradio as gr
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass, transform_subclass

# ZeroGPU's empty_fake calls empty_like + set_ on each parameter to build the pinned-CPU
# mirror it streams from. Those ops don't make sense on tensor-subclass wrappers (NVFP4Tensor,
# etc.), which contain multiple inner storages. Patch empty_fake to recurse into wrapper
# subclasses via transform_subclass, so that each inner (leaf) tensor gets packed individually.
import spaces.zero.torch.patching as _zg_patching

_orig_empty_fake = _zg_patching.empty_fake


def _empty_fake_subclass_aware(tensor):
    """Build ZeroGPU's fake mirror of ``tensor``, handling wrapper tensor subclasses.

    Plain tensors go straight to the original ``empty_fake``. For a traceable
    wrapper subclass, ``transform_subclass`` rebuilds the wrapper around one
    fake per inner tensor. Each leaf fake is registered in
    ``_zg_patching.cuda_aliases`` against its real inner tensor. Nested
    wrappers are recursed into, so the original ``empty_fake`` only ever
    sees plain tensors.

    Args:
        tensor: A model parameter/buffer tensor, possibly a wrapper subclass.

    Returns:
        The fake tensor (or rebuilt wrapper of fakes) used by ZeroGPU.
    """
    if not is_traceable_wrapper_subclass(tensor):
        return _orig_empty_fake(tensor)

    def _per_inner(_name, inner):
        if is_traceable_wrapper_subclass(inner):
            # Nested wrapper: recurse instead of passing a wrapper to the
            # original empty_fake; its leaves register their own aliases.
            return _empty_fake_subclass_aware(inner)
        inner_fake = _orig_empty_fake(inner)
        # Register inner-tensor aliases so the packer actually packs each storage.
        _zg_patching.cuda_aliases[inner_fake] = inner
        return inner_fake

    return transform_subclass(tensor, _per_inner)


_zg_patching.empty_fake = _empty_fake_subclass_aware

from diffusers import AutoModel, Cosmos3OmniPipeline, TorchAoConfig
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from torchao.prototype.mx_formats import NVFP4WeightOnlyConfig
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

# Cosmos3's time_proj emits fp32 sinusoidals; vanilla F.linear upcasts the weight, but the
# NVFP4 dispatch handlers expect input.dtype == weight.orig_dtype. Wrap the matmul-family
# handlers to cast non-NVFP4 tensor inputs to the weight's orig_dtype on the fly.
def _make_dtype_safe(orig_handler):
    """Wrap an NVFP4Tensor op handler so plain float inputs match the weight dtype.

    The wrapper finds the first ``NVFP4Tensor`` among the call's positional
    and keyword arguments. If one is present, every other floating-point
    ``torch.Tensor`` argument whose dtype differs from that weight's
    ``orig_dtype`` is cast to ``orig_dtype`` before delegating. Calls without
    an NVFP4 argument are forwarded unchanged.

    Args:
        orig_handler: A torchao handler with signature ``(func, types, args, kwargs)``.

    Returns:
        A handler with the same signature.
    """
    import functools

    @functools.wraps(orig_handler)
    def wrapped(func, types, args, kwargs):
        kw_values = tuple(kwargs.values()) if kwargs else ()
        weight = next(
            (a for a in (*args, *kw_values) if isinstance(a, NVFP4Tensor)),
            None,
        )
        if weight is None:
            return orig_handler(func, types, args, kwargs)

        target = weight.orig_dtype

        def _cast(a):
            # Only plain floating tensors are cast. The NVFP4 weight, integer
            # tensors and non-tensor arguments pass through untouched.
            if (
                isinstance(a, torch.Tensor)
                and not isinstance(a, NVFP4Tensor)
                and a.dtype != target
                and a.is_floating_point()
            ):
                return a.to(target)
            return a

        new_args = tuple(_cast(a) for a in args)
        new_kwargs = {k: _cast(v) for k, v in kwargs.items()} if kwargs else kwargs
        return orig_handler(func, types, new_args, new_kwargs)

    return wrapped


_aten = torch.ops.aten
_nvfp4_table = NVFP4Tensor._ATEN_OP_TABLE[NVFP4Tensor]
# torchao registers aten OpOverloads in _ATEN_OP_TABLE (used by __torch_dispatch__).
# Python-level callables such as F.linear go in _TORCH_FN_TABLE (used by
# __torch_function__). Searching only the aten table would silently skip F.linear,
# so patch both. The function table may be missing on some torchao versions.
_nvfp4_fn_table = getattr(NVFP4Tensor, "_TORCH_FN_TABLE", {}).get(NVFP4Tensor, {})
for _table in (_nvfp4_table, _nvfp4_fn_table):
    for _f in (
        torch.nn.functional.linear,
        _aten.linear.default,
        _aten.addmm.default,
        _aten.mm.default,
        _aten.matmul.default,
    ):
        if _f in _table:
            _table[_f] = _make_dtype_safe(_table[_f])

MODEL_ID = "nvidia/Cosmos3-Super-Text2Image"

quant_config = TorchAoConfig(NVFP4WeightOnlyConfig())
transformer = AutoModel.from_pretrained(
    MODEL_ID,
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipe = Cosmos3OmniPipeline.from_pretrained(
    MODEL_ID,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
    enable_safety_checker=False,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=3.0)
pipe.to("cuda")

# Dropdown label -> (width, height) in pixels.
RESOLUTIONS = {
    "1024×1024 (1:1)": (1024, 1024),
    "1280×720 (16:9)": (1280, 720),
    "720×1280 (9:16)": (720, 1280),
    "1024×768 (4:3)": (1024, 768),
    "768×1024 (3:4)": (768, 1024),
}


def _duration(prompt, resolution, steps, *_args, **_kwargs):
    """Return the GPU time budget, in seconds, for one ``generate`` call.

    Passed as ``duration=`` to ``spaces.GPU`` below. The parameter list
    mirrors ``generate``, but only ``resolution`` and ``steps`` are used.
    Remaining arguments are ignored, including any that arrive by keyword
    (e.g. ``progress``). The result is capped at 1500.
    """
    w, h = RESOLUTIONS[resolution]
    # Measured: ~14s per step at 1024×1024 NVFP4 dequant; scale by pixel count + 25% margin.
    per_step = 18 * (w * h) / (1024 * 1024)
    return min(1500, int(60 + per_step * int(steps)))


@spaces.GPU(duration=_duration, size="xlarge")
def generate(
    prompt,
    resolution,
    steps,
    guidance,
    negative_prompt,
    seed,
    randomize_seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image and return ``(png_path, seed_used)``.

    Args:
        prompt: Text prompt. Empty or whitespace-only prompts raise ``gr.Error``.
        resolution: Key of ``RESOLUTIONS`` selecting ``(width, height)``.
        steps: Passed to the pipeline as ``num_inference_steps`` (cast to int).
        guidance: Passed to the pipeline as ``guidance_scale`` (cast to float).
        negative_prompt: Optional; an empty string is sent as None.
        seed: Seed used when ``randomize_seed`` is false. If it is None
            (e.g. the Number field was cleared), a random seed is drawn instead.
        randomize_seed: When true, draw a fresh seed in ``[0, 2**31 - 2]``.
        progress: Not read in the body. Declared so Gradio can attach a
            progress tracker (``track_tqdm=True``).

    Returns:
        The path of the saved PNG (in a new temp directory that is not
        cleaned up here) and the integer seed actually used.

    Raises:
        gr.Error: If ``prompt`` is empty.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")
    width, height = RESOLUTIONS[resolution]

    # Fall back to a random seed rather than crashing in int(None).
    if randomize_seed or seed is None:
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
    seed = int(seed)
    generator = torch.Generator(device="cuda").manual_seed(seed)

    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        num_frames=1,
        height=height,
        width=width,
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        generator=generator,
        output_type="pil",
    )
    # NOTE(review): assumes that with num_frames=1 and output_type="pil",
    # result.video[0] is a single PIL image -- confirm against Cosmos3OmniPipeline.
    img = result.video[0]

    out_dir = Path(tempfile.mkdtemp(prefix="cosmos3_"))
    p = out_dir / "image.png"
    img.save(p)
    return str(p), seed


CSS = """
.gradio-container { max-width: 1100px !important; margin: auto !important; }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=CSS, title="Cosmos3-Super · NVFP4") as demo:
    gr.Markdown(
        "# Cosmos3-Super-Text2Image · NVFP4\n"
        "[nvidia/Cosmos3-Super-Text2Image](https://huggingface.co/nvidia/Cosmos3-Super-Text2Image) "
        "(64B) with NVFP4 quantization."
    )
    with gr.Row():
        prompt = gr.Textbox(
            show_label=False,
            placeholder="A photo of a robot reading a book under a cherry tree…",
            container=False,
            scale=4,
        )
        run = gr.Button("Generate", variant="primary", scale=1)
    out = gr.Image(label="Output", type="filepath", format="png", height=640)
    with gr.Accordion("Advanced settings", open=False):
        negative_prompt = gr.Textbox(label="Negative prompt", value="")
        resolution = gr.Dropdown(
            label="Resolution",
            choices=list(RESOLUTIONS),
            value="1024×1024 (1:1)",
        )
        steps = gr.Slider(label="Inference steps", minimum=10, maximum=50, value=35, step=1)
        guidance = gr.Slider(label="Guidance scale", minimum=1.0, maximum=8.0, value=4.0, step=0.1)
        with gr.Row():
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            seed = gr.Number(label="Seed", value=0, precision=0)

    # Order must match generate()'s positional parameters.
    inputs = [prompt, resolution, steps, guidance, negative_prompt, seed, randomize_seed]
    # The seed actually used is written back to the Seed field.
    outputs = [out, seed]
    run.click(generate, inputs, outputs)
    prompt.submit(generate, inputs, outputs)

demo.queue().launch()