Cosmos3-Nano-FP8 / load_quantized.py
wfen's picture
fix confliction
db04faa
Raw
History Blame Contribute Delete
3.02 kB
"""Load this quantized Cosmos3-Nano (FP8, safetensors). Self-contained — no project `src/` needed.
Requires: diffusers (git main / >=0.39), nvidia-modelopt, torch (cu128), safetensors.
Usage::

    from load_quantized import load
    pipe = load()  # "." = current working directory (run from this checkout), or pass a repo id / local dir
    import torch
    with torch.autocast("cuda", torch.bfloat16):
        img = pipe("a corgi astronaut", num_frames=1, height=480, width=480).video[0][0]
Format (Path B; see ../docs/reports/session_3.md): the FP8 transformer is stored as **safetensors**
(`transformer/diffusion_pytorch_model.safetensors`: 505 weight-only E4M3 weights + per-tensor
`weight_quantizer._amax`/`._scale` buffers) plus a tiny tensor-free `transformer/modelopt_state.pt`
structural sidecar (the quantizer layout). The original `transformer/modelopt_quantized.pt` is
**retained** as a fallback (loadable via `modelopt.torch.opt.restore`); this loader does NOT use it.
SECURITY: `modelopt_state.pt` is loaded with `torch.load(weights_only=False)`, which executes
arbitrary pickle. Load this checkpoint ONLY from a source you trust (a tampered sidecar = remote
code execution). The safetensors weights themselves are safe; only the structural sidecar is pickle.
"""
import glob
import os
import torch
from diffusers import Cosmos3OmniPipeline, Cosmos3OmniTransformer, UniPCMultistepScheduler
import modelopt.torch.opt as mto
from safetensors.torch import load_file
def load_transformer(local):
    """Materialize the quantized transformer from safetensors + the structural sidecar (no `.pt` weights).

    Steps:
      1. Build a bf16 ``Cosmos3OmniTransformer`` from ``transformer/config.json``, with
         ``action_gen`` forced to ``False``.
      2. Re-create the ModelOpt quantizer layout from ``transformer/modelopt_state.pt``.
      3. Load every ``transformer/*.safetensors`` shard with ``strict=True``, so any
         missing or unexpected key is an error.

    Args:
        local: Path to a local checkout/snapshot that contains a ``transformer/`` folder.

    Returns:
        The quantized transformer with its weights loaded. No device move is done here.

    Raises:
        FileNotFoundError: if ``transformer/`` holds no ``*.safetensors`` shard. This is
            checked before the (expensive) model construction.
    """
    tdir = os.path.join(local, "transformer")

    # Resolve the shards first so a wrong or incomplete directory fails fast with a clear
    # message. Without this, the failure is a strict load_state_dict error after the full
    # model build. glob.escape stops "[", "*" or "?" in the directory name from being
    # interpreted as glob patterns.
    shards = sorted(glob.glob(os.path.join(glob.escape(tdir), "*.safetensors")))
    if not shards:
        raise FileNotFoundError(f"no *.safetensors weights found in {tdir!r}")

    cfg = {**Cosmos3OmniTransformer.load_config(os.path.join(tdir, "config.json")), "action_gen": False}
    tf = Cosmos3OmniTransformer.from_config(cfg).to(torch.bfloat16)

    # SECURITY: weights_only=False unpickles arbitrary objects, so a tampered sidecar means
    # code execution. It is needed because the sidecar is a pickled ModelOpt state.
    # The sidecar is documented as tensor-free. map_location="cpu" still keeps it loadable
    # on hosts without the device it was saved from.
    state = torch.load(os.path.join(tdir, "modelopt_state.pt"), map_location="cpu", weights_only=False)
    restored = mto.restore_from_modelopt_state(tf, state)
    if restored is not None:  # defensive: keep `tf` if the restore returns None
        tf = restored

    # Merge all shards into a single state dict, then load it strictly.
    tensors = {}
    for shard in shards:
        tensors.update(load_file(shard))
    tf.load_state_dict(tensors, strict=True)
    return tf
def load(repo_or_dir=".", device="cuda", *, flow_shift=10.0):
    """Assemble the Cosmos3 Omni pipeline around the FP8-quantized transformer.

    Args:
        repo_or_dir: Either an existing local directory or a Hugging Face Hub repo id.
            A local directory may contain ``~``, which is expanded. Relative paths
            resolve against the current working directory, so the default ``"."`` is
            wherever Python was launched, not necessarily this file's folder.
            Anything that is not an existing directory is treated as a repo id and
            fetched with ``snapshot_download``.
        device: Device the finished pipeline is moved to.
        flow_shift: ``flow_shift`` given to the ``UniPCMultistepScheduler`` that replaces
            the checkpoint's scheduler. The default 10.0 matches the previously
            hard-coded value.

    Returns:
        A bf16 ``Cosmos3OmniPipeline`` with the safety checker disabled, on ``device``.
    """
    local = os.path.expanduser(repo_or_dir)
    if not os.path.isdir(local):
        # Not a directory on disk: treat it as a Hub repo id and download a snapshot.
        from huggingface_hub import snapshot_download
        local = snapshot_download(repo_or_dir)
    tf = load_transformer(local)
    pipe = Cosmos3OmniPipeline.from_pretrained(
        local, transformer=tf, torch_dtype=torch.bfloat16, enable_safety_checker=False
    )
    # Swap in UniPC, keeping the checkpoint's scheduler config but overriding flow_shift.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
    return pipe.to(device)
if __name__ == "__main__":
    # Smoke test: render a single 480x480 frame and write it into the current working directory.
    pipe = load()
    # Autocast is required: the rotary tensors are float32 while the linear layers are bf16.
    with torch.autocast("cuda", dtype=torch.bfloat16):
        output = pipe(
            "A red panda astronaut floating in a nebula, highly detailed",
            num_frames=1,
            height=480,
            width=480,
        )
        img = output.video[0][0]  # first video, first frame
    img.save("out.png")
    print("saved out.png")