Spaces:
Build error
Build error
File size: 1,912 Bytes
2ca3b3b c6d21a3 2ca3b3b c6d21a3 7d9bd81 2ca3b3b c6d21a3 2ca3b3b a58a9ba 2ca3b3b bdbcf89 7d9bd81 bdbcf89 7d9bd81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
from torch.utils.data import DataLoader
import torch
from accelerate import Accelerator
from tqdm import tqdm
from datasets import load_dataset
from umi.config import Config
from umi.models.unet import create_model
from umi.datasets import CIFAR10Dataset
from diffusers import DDPMPipeline, DDPMScheduler
def main():
    """Train a DDPM noise-prediction model on the CIFAR-10 train split.

    The model comes from ``create_model(config)``. Each step corrupts a batch
    with scheduler noise at random timesteps and regresses the model output
    onto that noise with an MSE loss. After the last epoch, the unwrapped
    model and the noise scheduler are saved to ``output_dir``, which defaults
    to ``"ddpm-cifar10"``.
    """
    config = Config()
    # Optional overrides read from the config; the fallbacks are the values
    # this script originally hard-coded, so existing configs behave the same.
    batch_size = getattr(config, "train_batch_size", 64)
    learning_rate = getattr(config, "learning_rate", 1e-4)
    output_dir = getattr(config, "output_dir", "ddpm-cifar10")

    model = create_model(config)
    dataset = load_dataset("cifar10", split="train")
    dataset = CIFAR10Dataset(dataset, transform=config.transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    noise_scheduler = DDPMScheduler(num_train_timesteps=config.num_train_steps)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    accelerator = Accelerator()
    # prepare() places objects on the right device and may wrap the model
    # (e.g. for distributed training), so the bare model must be recovered
    # with unwrap_model() before calling save_pretrained().
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    model.train()
    for epoch in range(config.epochs):
        # Show a single progress bar per machine instead of one per process.
        progress_bar = tqdm(
            dataloader,
            desc=f"Epoch {epoch + 1}",
            disable=not accelerator.is_local_main_process,
        )
        for batch in progress_bar:
            # NOTE(review): assumes CIFAR10Dataset yields dicts holding an
            # "image" tensor already transformed by config.transform — confirm.
            clean_images = batch["image"]
            noise = torch.randn_like(clean_images)
            # Draw one diffusion timestep per image in the batch.
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (clean_images.shape[0],),
                device=clean_images.device,
            ).long()
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
            # The network is trained to predict the noise that was injected.
            model_output = model(noisy_images, timesteps).sample
            loss = torch.nn.functional.mse_loss(model_output, noise)

            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()
            progress_bar.set_postfix(loss=loss.item())

    # Let every process finish training, then write the checkpoint exactly once.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.unwrap_model(model).save_pretrained(output_dir)
        noise_scheduler.save_pretrained(output_dir)
        # Push to huggingface
        # accelerator.unwrap_model(model).push_to_hub("zaibutcooler/umi")
        # noise_scheduler.push_to_hub("zaibutcooler/umi")


if __name__ == "__main__":
    main()
|