Spaces:
Running on Zero
Running on Zero
| import os | |
| os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") | |
| import spaces # noqa: F401 must precede torch / diffusers | |
| import tempfile | |
| from pathlib import Path | |
| import gradio as gr | |
| import torch | |
| from torch.utils._python_dispatch import is_traceable_wrapper_subclass, transform_subclass | |
# ZeroGPU's ``empty_fake`` builds the pinned-CPU mirror it streams parameters from by
# calling ``empty_like`` + ``set_`` on each parameter. Those ops are not meaningful on
# tensor-subclass wrappers (NVFP4Tensor, etc.), which hold several inner storages.
# The replacement below unpacks such wrappers with ``transform_subclass`` and fakes
# each inner tensor on its own.
import spaces.zero.torch.patching as _zg_patching

_orig_empty_fake = _zg_patching.empty_fake


def _empty_fake_subclass_aware(tensor):
    """Subclass-aware drop-in replacement for ZeroGPU's ``empty_fake``.

    Tensors that are not traceable wrapper subclasses go straight to the original
    ``empty_fake``. For a wrapper subclass, each inner tensor is faked individually
    and the wrapper is rebuilt around those fakes via ``transform_subclass``.

    NOTE(review): only one level of wrapping is unpacked; an inner tensor that is
    itself a wrapper subclass is handed to the original ``empty_fake`` unchanged.
    """
    if not is_traceable_wrapper_subclass(tensor):
        return _orig_empty_fake(tensor)

    def _fake_component(_attr_name, component):
        fake = _orig_empty_fake(component)
        # Map fake -> real inner tensor in ZeroGPU's alias table so the packer
        # actually packs each inner storage.
        _zg_patching.cuda_aliases[fake] = component
        return fake

    return transform_subclass(tensor, _fake_component)


# Rebind the attribute on the patching module. This runs at import time, ahead of the
# model loading further down. NOTE(review): assumes ZeroGPU looks ``empty_fake`` up on
# this module at call time; a copy imported by name elsewhere would not see the patch.
_zg_patching.empty_fake = _empty_fake_subclass_aware
| from diffusers import AutoModel, Cosmos3OmniPipeline, TorchAoConfig | |
| from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler | |
| from torchao.prototype.mx_formats import NVFP4WeightOnlyConfig | |
| from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor | |
# Cosmos3's time_proj emits fp32 sinusoidal embeddings, while the NVFP4 dispatch handlers
# expect ``input.dtype == weight.orig_dtype``. Wrap the matmul-family handlers so that
# floating-point, non-NVFP4 tensor operands are cast to the weight's ``orig_dtype`` first.
def _make_dtype_safe(orig_handler):
    """Wrap an NVFP4Tensor op handler so mismatched floating operands are reconciled.

    Args:
        orig_handler: handler with the ``(func, types, args, kwargs)`` signature, as
            stored in NVFP4Tensor's op tables.

    Returns:
        A handler with the same signature. When a positional argument is an
        ``NVFP4Tensor``, every other floating-point ``torch.Tensor`` in ``args`` *and*
        ``kwargs`` whose dtype differs from that weight's ``orig_dtype`` is cast to it
        before delegating. Calls without an NVFP4 operand are forwarded untouched.
    """

    def wrapped(func, types, args, kwargs):
        weight = next((a for a in args if isinstance(a, NVFP4Tensor)), None)
        if weight is None:
            return orig_handler(func, types, args, kwargs)
        target = weight.orig_dtype

        def _cast(value):
            # NVFP4 weights, integer tensors and non-tensor values pass through.
            if (
                isinstance(value, torch.Tensor)
                and not isinstance(value, NVFP4Tensor)
                and value.is_floating_point()
                and value.dtype != target
            ):
                return value.to(target)
            return value

        new_args = tuple(_cast(a) for a in args)
        # Operands can also arrive by keyword (e.g. ``F.linear(x, w, bias=b)``);
        # leave a None/empty kwargs exactly as received.
        new_kwargs = {k: _cast(v) for k, v in kwargs.items()} if kwargs else kwargs
        return orig_handler(func, types, new_args, new_kwargs)

    return wrapped
_aten = torch.ops.aten
# Aten-op handlers live in ``_ATEN_OP_TABLE``. Depending on the torchao version, handlers
# for Python-level torch functions (``torch.nn.functional.linear``) may instead be kept in
# a separate ``_TORCH_FN_TABLE``. Patch every table that exists so the dtype fix-up also
# covers the ``__torch_function__`` path; entries missing from a table are skipped.
_nvfp4_table = NVFP4Tensor._ATEN_OP_TABLE[NVFP4Tensor]
_nvfp4_tables = [_nvfp4_table]
_nvfp4_fn_table = getattr(NVFP4Tensor, "_TORCH_FN_TABLE", {}).get(NVFP4Tensor)
if _nvfp4_fn_table is not None and _nvfp4_fn_table is not _nvfp4_table:
    _nvfp4_tables.append(_nvfp4_fn_table)

for _table in _nvfp4_tables:
    for _f in (
        torch.nn.functional.linear,
        _aten.linear.default,
        _aten.addmm.default,
        _aten.mm.default,
        _aten.matmul.default,
    ):
        if _f in _table:
            _table[_f] = _make_dtype_safe(_table[_f])
MODEL_ID = "nvidia/Cosmos3-Super-Text2Image"

# Only the transformer receives a quantization config (torchao NVFP4 weight-only).
# The pipeline then loads its other components in bfloat16 and reuses this
# pre-quantized transformer instead of loading its own.
quant_config = TorchAoConfig(NVFP4WeightOnlyConfig())
transformer = AutoModel.from_pretrained(
    MODEL_ID,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    quantization_config=quant_config,
)
pipe = Cosmos3OmniPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    transformer=transformer,
    enable_safety_checker=False,
)

# Replace the shipped scheduler with UniPC multistep, built from the same config
# but with ``flow_shift`` overridden to 3.0.
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config,
    flow_shift=3.0,
)

# Module-level device move; presumably intercepted by ZeroGPU's torch patching
# (``spaces`` is imported before torch) rather than needing a GPU at import time.
pipe.to("cuda")
# Dropdown label -> (width, height). Callers unpack the tuple in that order, and the
# insertion order is the order in which the choices appear in the UI dropdown.
RESOLUTIONS = dict(
    [
        ("1024×1024 (1:1)", (1024, 1024)),
        ("1280×720 (16:9)", (1280, 720)),
        ("720×1280 (9:16)", (720, 1280)),
        ("1024×768 (4:3)", (1024, 768)),
        ("768×1024 (3:4)", (768, 1024)),
    ]
)
def _duration(prompt, resolution, steps, *_):
    """Return the GPU time budget, in whole seconds, for one ``generate`` call.

    The positional parameters mirror the leading parameters of ``generate``. Any
    further arguments are swallowed by ``*_``, which matches how a dynamic
    ``spaces.GPU(duration=...)`` callable is invoked.

    Estimate: a flat 60 s plus 18 s per step at 1024×1024 (≈14 s measured with NVFP4
    dequant, plus a 25 % margin). The per-step cost is scaled linearly by pixel count,
    and the result is capped at 1500 s.
    """
    width, height = RESOLUTIONS[resolution]
    secs_per_step = 18 * (width * height) / (1024 * 1024)
    budget = 60 + secs_per_step * int(steps)
    return min(1500, int(budget))
@spaces.GPU(duration=_duration)
def generate(
    prompt,
    resolution,
    steps,
    guidance,
    negative_prompt,
    seed,
    randomize_seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Render one image from ``prompt`` and return it with the seed that was used.

    Decorated with ``spaces.GPU`` so a ZeroGPU device is attached for the call. The
    time budget is computed by ``_duration`` from the same arguments.

    Args:
        prompt: text prompt; must contain non-whitespace characters.
        resolution: a key of ``RESOLUTIONS``.
        steps: number of inference steps (coerced with ``int``).
        guidance: guidance scale passed to the pipeline (coerced with ``float``).
        negative_prompt: optional negative prompt; an empty value means none.
        seed: RNG seed; ignored when ``randomize_seed`` is true or the field is empty.
        randomize_seed: draw a fresh seed instead of using ``seed``.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio follow
            tqdm progress bars created during the call.

    Returns:
        ``(path, seed)``: the path of the saved PNG and the integer seed actually
        used, so the UI's seed field shows the value that produced the image.

    Raises:
        gr.Error: if the prompt is empty or whitespace-only.
    """
    if not (prompt and prompt.strip()):
        raise gr.Error("Please enter a prompt.")
    width, height = RESOLUTIONS[resolution]
    # A cleared gr.Number arrives as None; draw a random seed instead of crashing on int(None).
    if randomize_seed or seed is None:
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
    seed = int(seed)
    generator = torch.Generator(device="cuda").manual_seed(seed)
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        num_frames=1,
        height=height,
        width=width,
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        generator=generator,
        output_type="pil",
    )
    # With num_frames=1, the first entry of ``result.video`` is the generated image.
    img = result.video[0]
    out_dir = Path(tempfile.mkdtemp(prefix="cosmos3_"))
    path = out_dir / "image.png"
    img.save(path)
    return str(path), seed
CSS = """
.gradio-container { max-width: 1100px !important; margin: auto !important; }
"""

# Layout: a prompt row with the Generate button, the output image below it, and the
# generation knobs collapsed inside an "Advanced settings" accordion.
with gr.Blocks(
    title="Cosmos3-Super · NVFP4",
    theme=gr.themes.Soft(),
    css=CSS,
) as demo:
    gr.Markdown(
        "# Cosmos3-Super-Text2Image · NVFP4\n"
        "[nvidia/Cosmos3-Super-Text2Image](https://huggingface.co/nvidia/Cosmos3-Super-Text2Image) "
        "(64B) with NVFP4 quantization."
    )
    with gr.Row():
        prompt = gr.Textbox(
            placeholder="A photo of a robot reading a book under a cherry tree…",
            show_label=False,
            container=False,
            scale=4,
        )
        run = gr.Button("Generate", variant="primary", scale=1)
    out = gr.Image(label="Output", type="filepath", format="png", height=640)

    with gr.Accordion("Advanced settings", open=False):
        negative_prompt = gr.Textbox(label="Negative prompt", value="")
        resolution = gr.Dropdown(
            label="Resolution",
            choices=list(RESOLUTIONS),
            value="1024×1024 (1:1)",
        )
        steps = gr.Slider(
            label="Inference steps", minimum=10, maximum=50, value=35, step=1
        )
        guidance = gr.Slider(
            label="Guidance scale", minimum=1.0, maximum=8.0, value=4.0, step=0.1
        )
        with gr.Row():
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            seed = gr.Number(label="Seed", value=0, precision=0)

    # Component order must match generate's positional parameters; the seed
    # field is also an output so it shows the seed that was actually used.
    inputs = [prompt, resolution, steps, guidance, negative_prompt, seed, randomize_seed]
    outputs = [out, seed]
    # Clicking the button and pressing Enter in the prompt box run the same handler.
    run.click(fn=generate, inputs=inputs, outputs=outputs)
    prompt.submit(fn=generate, inputs=inputs, outputs=outputs)

demo.queue().launch()