Cosmos3-Nano-FP8 / load_quantized.py
Reza2kn's picture
Upload folder using huggingface_hub
489a947 verified
"""Load this quantized Cosmos3-Nano. Requires: diffusers (git main / >=0.39), nvidia-modelopt, torch (cu128).
Usage:
    from load_quantized import load
    pipe = load()  # loads from "." (the current directory); or pass a Hub repo id / local dir
    import torch
    with torch.autocast("cuda", torch.bfloat16):
        img = pipe("a corgi astronaut", num_frames=1, height=480, width=480).video[0][0]
"""
import os, torch
from diffusers import Cosmos3OmniPipeline, Cosmos3OmniTransformer
import modelopt.torch.opt as mto
def load(repo_or_dir=".", device="cuda"):
    """Build a Cosmos3 Omni pipeline around the ModelOpt-quantized transformer.

    Args:
        repo_or_dir: Local directory holding the checkpoint, or a Hugging Face
            Hub repo id. Anything that is not an existing directory is passed to
            ``huggingface_hub.snapshot_download``. The default ``"."`` is the
            current working directory, not necessarily the directory that
            contains this file.
        device: Device the assembled pipeline is moved to.

    Returns:
        The ``Cosmos3OmniPipeline`` loaded with ``torch_dtype=torch.bfloat16``
        and moved to ``device``.
    """
    if os.path.isdir(repo_or_dir):
        local = repo_or_dir
    else:
        # Imported lazily so purely local use does not need the Hub client.
        from huggingface_hub import snapshot_download
        local = snapshot_download(repo_or_dir)

    transformer_dir = os.path.join(local, "transformer")

    # Instantiate the transformer architecture from its config only; the
    # quantized weights are restored into this module on the next step.
    config = Cosmos3OmniTransformer.load_config(os.path.join(transformer_dir, "config.json"))
    tf = Cosmos3OmniTransformer.from_config(config).to(torch.bfloat16)

    # Restore the ModelOpt quantization state and weights into `tf`.
    # NOTE(review): an earlier comment said "4-bit weights" while the repo is
    # named FP8 -- confirm the actual quantization format.
    # NOTE(review): mto.restore presumably deserializes with torch.load (pickle);
    # only load checkpoints from sources you trust, especially Hub downloads.
    mto.restore(tf, os.path.join(transformer_dir, "modelopt_quantized.pt"))

    # Pass the pre-built quantized transformer so from_pretrained does not load
    # the transformer weights itself.
    pipe = Cosmos3OmniPipeline.from_pretrained(
        local,
        transformer=tf,
        torch_dtype=torch.bfloat16,
        enable_safety_checker=False,
    )
    return pipe.to(device)
if __name__ == "__main__":
    # Smoke test: render one 480x480 frame and save it as out.png.
    pipeline = load()
    # Autocast is required here: the rotary tensors are float32 while the
    # linear layers run in bf16.
    with torch.autocast("cuda", dtype=torch.bfloat16):
        output = pipeline(
            "A red panda astronaut floating in a nebula, highly detailed",
            num_frames=1,
            height=480,
            width=480,
        )
        image = output.video[0][0]
    image.save("out.png")
    print("saved out.png")