"""Measure peak CUDA memory for one SDXL UNet forward/backward pass.

Loads the SDXL base UNet in fp16 with gradient checkpointing enabled, runs a
single forward pass on random inputs, back-propagates a simple MSE loss
against the input sample, and prints the peak memory allocated on the target
device.
"""

import torch

# Benchmark settings (previously hard-coded inline).
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
DEFAULT_DEVICE = "cuda:1"  # second GPU, as in the original script
DEFAULT_BATCH_SIZE = 8


def make_inputs(batch_size, device, dtype=torch.float16):
    """Build the random inputs for one UNet call.

    Each tensor is created once with a leading dimension of 1 on the CPU,
    cast and moved to ``device``/``dtype``, then repeated ``batch_size`` times
    along dim 0. Every batch element is therefore identical.

    Args:
        batch_size: Number of (identical) samples in the batch.
        device: Target device for the tensors (e.g. ``"cuda:1"`` or ``"cpu"``).
        dtype: Floating-point dtype of the returned tensors.

    Returns:
        ``(sample, time_ids, encoder_hidden_states, text_embeds)`` with shapes
        ``(B, 4, 128, 128)``, ``(B, 6)``, ``(B, 77, 2048)`` and ``(B, 1280)``,
        where ``B == batch_size``.
    """

    def _batched(t):
        # Cast + move first, then repeat on the target device. The leading
        # dim is tiled and all other dims are kept as-is.
        t = t.to(device=device, dtype=dtype)
        return t.repeat(batch_size, *([1] * (t.dim() - 1)))

    # RNG draws happen in the same order as before: sample, hidden states,
    # text embeds.
    sample = _batched(torch.randn((1, 4, 128, 128)))
    # Placeholder values 0, 1/6, ..., 5/6 for the 6-element ``time_ids``
    # conditioning.
    time_ids = _batched((torch.arange(6) / 6)[None, :])
    encoder_hidden_states = _batched(torch.randn((1, 77, 2048)))
    text_embeds = _batched(torch.randn((1, 1280)))
    return sample, time_ids, encoder_hidden_states, text_embeds


def main(device=DEFAULT_DEVICE, batch_size=DEFAULT_BATCH_SIZE):
    """Run one fp16 forward/backward step of the SDXL UNet and report memory.

    Args:
        device: CUDA device to run on.
        batch_size: Number of samples per forward/backward pass.

    Returns:
        Peak bytes allocated on ``device`` during the step (also printed).
    """
    # Imported lazily so ``make_inputs`` is usable without diffusers installed.
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        MODEL_ID, subfolder="unet", variant="fp16", torch_dtype=torch.float16
    )
    unet.train()
    unet.enable_gradient_checkpointing()
    unet = unet.to(device)

    sample, time_ids, encoder_hidden_states, text_embeds = make_inputs(
        batch_size, unet.device, dtype=torch.float16
    )

    # Reset so the reported peak covers the forward/backward step (with the
    # resident weights) rather than anything that happened during loading.
    torch.cuda.reset_peak_memory_stats(unet.device)

    out = unet(
        sample,
        1.0,
        encoder_hidden_states=encoder_hidden_states,
        added_cond_kwargs={"time_ids": time_ids, "text_embeds": text_embeds},
    ).sample

    # NOTE: weights (and hence gradients) are fp16 and no optimizer/autocast is
    # used, so this measures a bare forward + backward pass only.
    loss = torch.nn.functional.mse_loss(out, sample)
    loss.backward()

    peak = torch.cuda.max_memory_allocated(device=unet.device)
    print(
        f"peak memory allocated on {unet.device}: "
        f"{peak} bytes ({peak / 2**30:.2f} GiB)"
    )
    return peak


if __name__ == "__main__":
    main()