File size: 1,223 Bytes
8a70e8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from causvid.models.sdxl.sdxl_wrapper import SDXLWrapper, SDXLTextEncoder, SDXLVAE
from huggingface_hub import hf_hub_download
from PIL import Image
import torch

# Inference-only script: autograd is switched off globally.
torch.set_grad_enabled(False)

# Single place to choose the device used by all three modules and by the
# checkpoint load below.
device = "cuda"

# Instantiate the generator wrapper, the text encoder and the VAE, and move
# each of them to the target device.
model = SDXLWrapper().to(device)
encoder = SDXLTextEncoder().to(device)
vae = SDXLVAE().to(device)

# Download the checkpoint file from the Hugging Face Hub and get its local
# path.
repo_name = "tianweiy/DMD2"
ckpt_name = "dmd2_sdxl_1step_unet_fp16.bin"
ckpt_path = hf_hub_download(repo_name, ckpt_name)

# NOTE(security): the checkpoint is a pickle file fetched from the network.
# weights_only=True restricts unpickling to tensors and plain containers, so
# a tampered file cannot execute arbitrary code while being loaded.
# The loaded dict is kept as a temporary (not bound to a name) so it does not
# hold a second copy of the weights on the device after loading.
model.model.load_state_dict(
    torch.load(ckpt_path, map_location=device, weights_only=True)
)

# Presumably turns off gradient tracking for the "model" submodule; the
# helper is defined in SDXLWrapper -- TODO confirm its exact semantics.
model.set_module_grad(
    {
        "model": False
    }
)

# Number of images produced by one forward pass. The prompt list, the noise
# tensor and the timestep tensor must all share this leading dimension, so it
# is defined once instead of being repeated as a literal.
batch_size = 4

# Encode the same text prompt once per sample in the batch.
conditional_dict = encoder(
    text_prompts=["a photo of a cat"] * batch_size
)

# Noise is drawn on the CPU from a generator seeded with 42 and only then
# moved to the GPU.
# Presumably laid out as [B, F, C, H, W] like the model output -- TODO confirm.
noise = torch.randn(
    batch_size, 1, 4, 128, 128,
    generator=torch.Generator().manual_seed(42)
).to("cuda")

# Every sample is run at the same timestep value (399); shape [batch_size, 1].
timestep = torch.full(
    [batch_size, 1], 399, device="cuda", dtype=torch.long
)
output = model(noise, conditional_dict, timestep)

# Decode the model output to pixels and drop the singleton frame axis:
# [B, F, C, H, W] -> [B, C, H, W]
images = vae.decode_to_pixel(output).squeeze(1)

# Rescale to the 0..255 range (assumes the decoder output lies roughly in
# [-1, 1] -- TODO confirm), clip, then quantize to bytes and move channels
# last: [B, C, H, W] -> [B, H, W, C].
images = images.add(1.0).mul(127.5).clamp(min=0, max=255)
images = images.to(torch.uint8).permute(0, 2, 3, 1)

# Convert each sample to a PIL image, then write them out as numbered JPEGs.
output_image_list = [
    Image.fromarray(sample.cpu().numpy()) for sample in images
]
for index, pil_image in enumerate(output_image_list):
    pil_image.save(f"output_image_{index}.jpg")