Zai commited on
Commit
2ca3b3b
·
1 Parent(s): c6d21a3

Start data loading and training process

Browse files
scripts/train.py CHANGED
@@ -1,7 +1,23 @@
 
 
 
 
 
1
  from umi.config import Config
2
  from umi.models.unet import create_model
 
 
3
 
4
  if __name__ == '__main__':
5
- model_config = Config()
6
 
7
- model = create_model(model_config)
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+ import torch
3
+ from accelerate import Accelerator
4
+ from tqdm import tqdm
5
+ from datasets import load_dataset
6
  from umi.config import Config
7
  from umi.models.unet import create_model
8
+ from umi.datasets import CIFAR10Dataset
9
+ from diffusers import DDPMPipeline, DDPMScheduler
10
 
11
  if __name__ == '__main__':
12
+ config = Config()
13
 
14
+ model = create_model(config)
15
+ dataset = load_dataset("cifar10", split="train")
16
+ dataset = CIFAR10Dataset(dataset, transform=config.transform)
17
+ dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
18
+
19
+ noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
20
+
21
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
22
+ accelerator = Accelerator()
23
+ model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
umi/config.py CHANGED
@@ -1,6 +1,8 @@
1
  from torchvision import transforms
2
 
3
  class Config:
 
 
4
  transform = transforms.Compose([
5
  transforms.Resize(32),
6
  transforms.CenterCrop(32),
 
1
  from torchvision import transforms
2
 
3
  class Config:
4
+ epochs :int = 4
5
+
6
  transform = transforms.Compose([
7
  transforms.Resize(32),
8
  transforms.CenterCrop(32),
umi/datasets/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+ from cifar10 import CIFAR10Dataset
umi/datasets/cifar10.py CHANGED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+
3
+ class CIFAR10Dataset(Dataset):
4
+ def __init__(self, dataset, transform=None):
5
+ self.dataset = dataset
6
+ self.transform = transform
7
+
8
+ def __len__(self):
9
+ return len(self.dataset)
10
+
11
+ def __getitem__(self, idx):
12
+ img = self.dataset[idx]["img"]
13
+ if self.transform:
14
+ img = self.transform(img)
15
+ return {"image": img}