File size: 1,477 Bytes
489a947
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
"""Load this quantized Cosmos3-Nano checkpoint (ModelOpt 4-bit transformer).

Requires: diffusers (git main / >=0.39), nvidia-modelopt, torch (cu128).

Usage:
    from load_quantized import load
    pipe = load()                       # loads from the current directory (run inside this repo), or pass a repo id / local dir
    import torch
    with torch.autocast("cuda", torch.bfloat16):   # autocast is required for inference
        img = pipe("a corgi astronaut", num_frames=1, height=480, width=480).video[0][0]
"""
import os, torch
from diffusers import Cosmos3OmniPipeline, Cosmos3OmniTransformer
import modelopt.torch.opt as mto

def load(repo_or_dir=".", device="cuda"):
    """Build the Cosmos3-Omni pipeline around the ModelOpt-quantized transformer.

    Args:
        repo_or_dir: Either a local directory containing the checkpoint
            (default ``"."``, the current working directory) or a Hugging Face
            Hub repo id. A repo id is fetched with ``snapshot_download``.
            ``os.PathLike`` objects are accepted as well as strings.
        device: Device the assembled pipeline is moved to (default ``"cuda"``).

    Returns:
        The ``Cosmos3OmniPipeline``, loaded in bfloat16 with the safety
        checker disabled, moved to ``device``.

    Note:
        ``mto.restore`` reads ``transformer/modelopt_quantized.pt``, a torch
        checkpoint. Torch checkpoints are pickle-based, so only load repos
        you trust.
    """
    # Normalize pathlib.Path and other PathLike inputs. This is a no-op for str.
    repo_or_dir = os.fspath(repo_or_dir)
    if os.path.isdir(repo_or_dir):
        local = repo_or_dir
    else:
        # Not an existing local directory: treat it as a Hub repo id and download it.
        from huggingface_hub import snapshot_download
        local = snapshot_download(repo_or_dir)
    transformer_dir = os.path.join(local, "transformer")

    # Build the transformer from its config alone; no checkpoint weights are loaded here.
    # Cast it to bf16, then let ModelOpt restore the quantized state on top of it.
    config = Cosmos3OmniTransformer.load_config(os.path.join(transformer_dir, "config.json"))
    tf = Cosmos3OmniTransformer.from_config(config).to(torch.bfloat16)
    mto.restore(tf, os.path.join(transformer_dir, "modelopt_quantized.pt"))  # restores 4-bit weights

    # Load the remaining pipeline components normally, injecting the quantized transformer.
    pipe = Cosmos3OmniPipeline.from_pretrained(
        local, transformer=tf, torch_dtype=torch.bfloat16, enable_safety_checker=False)
    return pipe.to(device)

if __name__ == "__main__":
    # Demo: generate a single 480x480 frame and write it to out.png.
    prompt = "A red panda astronaut floating in a nebula, highly detailed"
    pipe = load()
    # Autocast is required: the rotary tensors are float32 while the linears are bf16.
    with torch.autocast("cuda", dtype=torch.bfloat16):
        output = pipe(prompt, num_frames=1, height=480, width=480)
        img = output.video[0][0]  # first frame of the first generated video
    img.save("out.png")
    print("saved out.png")