# Hugging Face Space file: app.py
# Author: multimodalart (HF Staff)
# Commit fb0c49e (verified): "Upload app.py with huggingface_hub"
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import spaces # noqa: F401 must precede torch / diffusers
import tempfile
from pathlib import Path
import gradio as gr
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass, transform_subclass
# ZeroGPU's empty_fake calls empty_like + set_ on each parameter to build the pinned-CPU
# mirror it streams from. Those ops don't make sense on tensor-subclass wrappers (NVFP4Tensor,
# etc.) which contain multiple inner storages. Patch empty_fake to recurse into wrapper
# subclasses via transform_subclass so each inner tensor gets packed individually.
import spaces.zero.torch.patching as _zg_patching

# Keep a handle on the stock implementation; the replacement below delegates to it.
_orig_empty_fake = _zg_patching.empty_fake


def _empty_fake_subclass_aware(tensor):
    """Build the ZeroGPU fake for ``tensor``, descending into wrapper subclasses.

    Plain tensors go straight to the original ``empty_fake``. For a traceable
    wrapper subclass, each inner tensor is faked individually, recorded in
    ``cuda_aliases``, and handed back to ``transform_subclass`` so the result
    is a wrapper built from the per-inner fakes.
    """
    if not is_traceable_wrapper_subclass(tensor):
        return _orig_empty_fake(tensor)

    def _fake_one(_attr_name, inner_tensor):
        fake = _orig_empty_fake(inner_tensor)
        # Record fake -> real inner tensor so the packer packs every storage.
        _zg_patching.cuda_aliases[fake] = inner_tensor
        return fake

    return transform_subclass(tensor, _fake_one)


# Install the subclass-aware version in place of the stock one.
_zg_patching.empty_fake = _empty_fake_subclass_aware
from diffusers import AutoModel, Cosmos3OmniPipeline, TorchAoConfig
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from torchao.prototype.mx_formats import NVFP4WeightOnlyConfig
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
# Cosmos3's time_proj emits fp32 sinusoidals; vanilla F.linear upcasts the weight, but the
# NVFP4 dispatch handlers expect input.dtype == weight.orig_dtype. Wrap the matmul-family
# handlers to cast non-NVFP4 tensor inputs to the weight's orig_dtype on the fly.
def _make_dtype_safe(orig_handler):
def wrapped(func, types, args, kwargs):
weight = next((a for a in args if isinstance(a, NVFP4Tensor)), None)
if weight is not None:
target = weight.orig_dtype
new_args = tuple(
a.to(target) if isinstance(a, torch.Tensor)
and not isinstance(a, NVFP4Tensor)
and a.dtype != target
and a.is_floating_point()
else a
for a in args
)
return orig_handler(func, types, new_args, kwargs)
return orig_handler(func, types, args, kwargs)
return wrapped
_aten = torch.ops.aten
_nvfp4_table = NVFP4Tensor._ATEN_OP_TABLE[NVFP4Tensor]
# Re-register the matmul-family handlers through the dtype-safe wrapper. Ops that
# have no NVFP4 handler in the table are skipped rather than added.
for _f in (
    torch.nn.functional.linear,
    _aten.linear.default,
    _aten.addmm.default,
    _aten.mm.default,
    _aten.matmul.default,
):
    if _f not in _nvfp4_table:
        continue
    _nvfp4_table[_f] = _make_dtype_safe(_nvfp4_table[_f])
MODEL_ID = "nvidia/Cosmos3-Super-Text2Image"

# Load the transformer on its own so the NVFP4 weight-only quantization config
# is applied to it; it is requested in bfloat16.
quant_config = TorchAoConfig(NVFP4WeightOnlyConfig())
transformer = AutoModel.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    quantization_config=quant_config,
    subfolder="transformer",
)

# Build the pipeline around the already-loaded transformer, with the safety
# checker turned off.
pipe = Cosmos3OmniPipeline.from_pretrained(
    MODEL_ID,
    enable_safety_checker=False,
    torch_dtype=torch.bfloat16,
    transformer=transformer,
)

# Replace the scheduler with UniPC, reusing the existing scheduler config but
# overriding flow_shift to 3.0, then move the whole pipeline to CUDA.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=3.0)
pipe.to("cuda")
# Maps each dropdown label to its (width, height) pair. The keys are used
# verbatim as the "Resolution" dropdown choices, in this order, and both
# _duration and generate look sizes up by label.
RESOLUTIONS = {
"1024×1024 (1:1)": (1024, 1024),
"1280×720 (16:9)": (1280, 720),
"720×1280 (9:16)": (720, 1280),
"1024×768 (4:3)": (1024, 768),
"768×1024 (3:4)": (768, 1024),
}
def _duration(prompt, resolution, steps, *_):
    """Estimate the GPU time, in seconds, to request for one generation call.

    Passed as ``duration=`` to ``spaces.GPU`` on ``generate``; only
    ``resolution`` and ``steps`` affect the result, the remaining arguments are
    accepted and ignored.

    Returns:
        ``60 + per-step cost * steps`` truncated to an int, capped at 1500.
    """
    width, height = RESOLUTIONS[resolution]
    # ~14 s/step was measured at 1024×1024 with NVFP4 dequant; 18 s includes a
    # roughly 25% margin and is scaled by pixel count relative to 1024×1024.
    seconds_per_step = 18 * (width * height) / (1024 * 1024)
    estimate = int(60 + seconds_per_step * int(steps))
    # Never request more than 1500 s.
    return estimate if estimate < 1500 else 1500
@spaces.GPU(duration=_duration, size="xlarge")
def generate(
    prompt,
    resolution,
    steps,
    guidance,
    negative_prompt,
    seed,
    randomize_seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image for ``prompt`` and return ``(png_path, seed_used)``.

    Args:
        prompt: Text prompt; must contain non-whitespace characters.
        resolution: A key of ``RESOLUTIONS`` selecting ``(width, height)``.
        steps: Number of inference steps (converted with ``int``).
        guidance: Guidance scale (converted with ``float``).
        negative_prompt: Optional negative prompt; an empty value is sent to
            the pipeline as ``None``.
        seed: Seed to use when ``randomize_seed`` is false. ``None`` (an empty
            Seed field) is treated like a request for a random seed.
        randomize_seed: When true, ``seed`` is ignored and a fresh one is drawn.
        progress: Not referenced in the body; declared with
            ``track_tqdm=True`` as its default for Gradio's progress tracking.

    Returns:
        A tuple of the saved PNG's path (``str``) and the seed that was used.

    Raises:
        gr.Error: If ``prompt`` is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")
    width, height = RESOLUTIONS[resolution]
    # An emptied Seed box arrives as None; fall back to a random seed instead
    # of failing in int(None) below.
    if randomize_seed or seed is None:
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
    generator = torch.Generator(device="cuda").manual_seed(int(seed))
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        num_frames=1,
        height=height,
        width=width,
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        generator=generator,
        output_type="pil",
    )
    # A single frame was requested; the first entry of the `video` output is
    # the image that gets saved.
    img = result.video[0]
    # Each call writes into its own fresh temporary directory.
    out_dir = Path(tempfile.mkdtemp(prefix="cosmos3_"))
    p = out_dir / "image.png"
    img.save(p)
    return str(p), seed
CSS = """
.gradio-container { max-width: 1100px !important; margin: auto !important; }
"""

with gr.Blocks(title="Cosmos3-Super · NVFP4", css=CSS, theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.Markdown(
        "# Cosmos3-Super-Text2Image · NVFP4\n"
        "[nvidia/Cosmos3-Super-Text2Image](https://huggingface.co/nvidia/Cosmos3-Super-Text2Image) "
        "(64B) with NVFP4 quantization."
    )

    # Prompt box and the generate button share one row.
    with gr.Row():
        prompt = gr.Textbox(
            placeholder="A photo of a robot reading a book under a cherry tree…",
            show_label=False,
            container=False,
            scale=4,
        )
        run = gr.Button("Generate", scale=1, variant="primary")

    # The result is shown from a file path, matching what `generate` returns.
    out = gr.Image(label="Output", height=640, format="png", type="filepath")

    with gr.Accordion("Advanced settings", open=False):
        negative_prompt = gr.Textbox(value="", label="Negative prompt")
        resolution = gr.Dropdown(
            choices=list(RESOLUTIONS),
            value="1024×1024 (1:1)",
            label="Resolution",
        )
        steps = gr.Slider(minimum=10, maximum=50, step=1, value=35, label="Inference steps")
        guidance = gr.Slider(minimum=1.0, maximum=8.0, step=0.1, value=4.0, label="Guidance scale")
        with gr.Row():
            randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
            seed = gr.Number(value=0, precision=0, label="Seed")

    # Order must match the positional parameters of `generate`; the seed that
    # was actually used is written back into the Seed field.
    inputs = [prompt, resolution, steps, guidance, negative_prompt, seed, randomize_seed]
    outputs = [out, seed]

    # Clicking the button and submitting the prompt box run the same handler.
    run.click(fn=generate, inputs=inputs, outputs=outputs)
    prompt.submit(fn=generate, inputs=inputs, outputs=outputs)

demo.queue().launch()