YashNagraj75 commited on
Commit
8a6ed33
·
1 Parent(s): 04cf02a

Add checkpoints; it's still not clear

Browse files
__pycache__/models.cpython-310.pyc ADDED
Binary file (4.31 kB). View file
 
__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.67 kB). View file
 
checkpoints/model_Epoch10.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf1f268fa95855fd879e4690b920b7820cf6e80848850dbacc0b479d69ff1953
3
+ size 6012986
checkpoints/model_Epoch20.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55ff29fd4bbe990d286d489a88c712aef197713e5e61c536664b3c3b67232e36
3
+ size 6012986
checkpoints/model_Epoch30.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a323611d36eb7c639bb5f1de6f4d913e7889be5fe4303a8ea4212f5fcc213475
3
+ size 6012986
checkpoints/model_Epoch40.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8d804cb3d2292dc9c4242c7a5914302026a65a1bf1b70fc1ac2bd1f73e174d8
3
+ size 6012986
checkpoints/model_Epoch50.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b351a9568c0eab7d0b76a090cf7827aaf9e9b54f7ffed21442619695cc4d9be1
3
+ size 6012986
model.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
models.py CHANGED
@@ -43,21 +43,27 @@ class ResidualBlock(nn.Module):
43
 
44
 
45
  class UnetUp(nn.Module):
46
- def __init__(self, in_channels, out_channels) -> None:
47
- super(UnetUp,self).__init__()
48
-
49
- self.model = nn.Sequential(
50
- nn.ConvTranspose2d(in_channels,out_channels,2,2),
51
- ResidualBlock(out_channels,out_channels),
52
- ResidualBlock(out_channels,out_channels),
53
- )
 
 
 
 
 
54
 
55
  def forward(self, x, skip):
56
- x = torch.cat([x,skip],1)
57
-
 
 
58
  x = self.model(x)
59
- return x
60
-
61
  class UnetDown(nn.Module):
62
  def __init__(self, input_channels, out_channels) -> None:
63
  super(UnetDown,self).__init__()
@@ -106,9 +112,9 @@ class ContextUnet(nn.Module):
106
  self.to_vec = nn.Sequential(nn.AvgPool2d((4)),nn.GELU())
107
 
108
  self.timeembed1 = EmbedFC(1, 2 *n_feat)
109
- self.timeembed2 = EmbedFC(1,n_feat)
110
  self.contextembed1 = EmbedFC(n_cfeat,2 * n_feat)
111
- self.contextembed2 = EmbedFC(n_cfeat,n_feat)
112
 
113
  self.up0 = nn.Sequential(
114
  nn.ConvTranspose2d(2 * n_feat,2*n_feat,self.h // 4,self.h // 4),
@@ -137,15 +143,13 @@ class ContextUnet(nn.Module):
137
  if c is None:
138
  c = torch.zeros(x.shape[0],self.n_cfeat).to(x)
139
 
140
- cemb1 = self.contextembed1(c).view(-1,self.n_cfeat*2,1,1)
141
- temb1 = self.timeembed1(t).view(-1,self.n_cfeat * 2,1,1)
142
- cemb2 = self.contextembed2(c).view(-1,self.n_cfeat,1,1)
143
- temb2 = self.timeembed2(t).view(-1,self.n_cfeat,1,1)
144
-
145
- up0 = self.up0(hidden_vec)
146
- up1 =self.up1(up0*cemb1 + temb1,down2)
147
- up2 = self.up2(up1*cemb2+temb2,down1)
148
-
149
- out = self.out(torch.cat((up2,x),1))
150
-
151
  return out
 
43
 
44
 
45
class UnetUp(nn.Module):
    """Upsampling block for the decoder path of the U-Net.

    Doubles the spatial resolution with a strided transposed convolution,
    then refines the upsampled features with two ResidualBlock layers.
    The forward pass first concatenates the input with its matching
    encoder skip connection, so ``in_channels`` must equal the channel
    count of ``x`` plus that of ``skip``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # kernel_size=2, stride=2 exactly doubles H and W; the two
        # residual blocks then process features at the new resolution.
        self.model = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualBlock(out_channels, out_channels),
            ResidualBlock(out_channels, out_channels),
        )

    def forward(self, x, skip):
        # Join decoder features with the encoder skip tensor along the
        # channel dimension, then upsample and refine in one pass.
        joined = torch.cat((x, skip), 1)
        return self.model(joined)
 
67
  class UnetDown(nn.Module):
68
  def __init__(self, input_channels, out_channels) -> None:
69
  super(UnetDown,self).__init__()
 
112
  self.to_vec = nn.Sequential(nn.AvgPool2d((4)),nn.GELU())
113
 
114
  self.timeembed1 = EmbedFC(1, 2 *n_feat)
115
+ self.timeembed2 = EmbedFC(1,embed_dm=1*n_feat)
116
  self.contextembed1 = EmbedFC(n_cfeat,2 * n_feat)
117
+ self.contextembed2 = EmbedFC(n_cfeat,1*n_feat)
118
 
119
  self.up0 = nn.Sequential(
120
  nn.ConvTranspose2d(2 * n_feat,2*n_feat,self.h // 4,self.h // 4),
 
143
  if c is None:
144
  c = torch.zeros(x.shape[0],self.n_cfeat).to(x)
145
 
146
+ cemb1 = self.contextembed1(c).view(-1,self.n_feat*2,1,1)
147
+ temb1 = self.timeembed1(t).view(-1,self.n_feat * 2,1,1)
148
+ cemb2 = self.contextembed2(c).view(-1,self.n_feat,1,1)
149
+ temb2 = self.timeembed2(t).view(-1,self.n_feat,1,1)
150
+
151
+ up1 = self.up0(hidden_vec)
152
+ up2 = self.up1(cemb1*up1 + temb1, down2) # add and multiply embeddings
153
+ up3 = self.up2(cemb2*up2 + temb2, down1)
154
+ out = self.out(torch.cat((up3, x), 1))
 
 
155
  return out
train.py CHANGED
@@ -3,6 +3,8 @@ from utils import *
3
  from torch.utils.data import DataLoader
4
  from models import *
5
  from tqdm.auto import tqdm
 
 
6
 
7
  timesteps = 500
8
  beta1 = 1e-4
@@ -15,7 +17,7 @@ height = 16
15
  save_dir="./checkpoints"
16
 
17
  batch_size = 100
18
- n_epoch = 40
19
  lrate = 1e-3
20
 
21
 
@@ -25,11 +27,17 @@ a_bt = torch.cumsum(a_t.log(),0).exp()
25
  a_bt[0] = 1
26
 
27
 
 
 
 
 
 
 
28
  dataset = CustomDataset("./sprites_1788_16x16.npy", "./sprite_labels_nc_1788_16x16.npy", transform, null_context=False)
29
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)
30
 
31
 
32
- nn_model = ContextUnet(3,n_feat,n_cfeat,height)
33
  optim = torch.optim.Adam(nn_model.parameters(),lrate)
34
 
35
  def perturb_input(x, t, noise):
@@ -46,7 +54,7 @@ for epoch in range(n_epoch):
46
 
47
  x = x.to(device)
48
 
49
- t = torch.randint(1,timesteps+1,x.shape[0]).to(device)
50
  noise = torch.randn_like(x)
51
  x_pert = perturb_input(x,t,noise)
52
 
@@ -56,8 +64,9 @@ for epoch in range(n_epoch):
56
  loss.backward()
57
  optim.step()
58
 
59
- if epoch % 1 == 0 and epoch >0:
 
60
  if not os.path.exists(save_dir):
61
  os.mkdir(save_dir)
62
- torch.save(nn_model,save_dir + f"model_Epoch{epoch}.pth")
63
  print("Saved model")
 
3
  from torch.utils.data import DataLoader
4
  from models import *
5
  from tqdm.auto import tqdm
6
+ import os
7
+ import torch.nn.functional as F
8
 
9
  timesteps = 500
10
  beta1 = 1e-4
 
17
  save_dir="./checkpoints"
18
 
19
  batch_size = 100
20
+ n_epoch = 60
21
  lrate = 1e-3
22
 
23
 
 
27
  a_bt[0] = 1
28
 
29
 
30
+
31
+ transform = transforms.Compose([
32
+ transforms.ToTensor(), # from [0,255] to range [0.0,1.0]
33
+ transforms.Normalize((0.5,), (0.5,)) # range [-1,1]
34
+
35
+ ])
36
  dataset = CustomDataset("./sprites_1788_16x16.npy", "./sprite_labels_nc_1788_16x16.npy", transform, null_context=False)
37
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)
38
 
39
 
40
+ nn_model = ContextUnet(3,n_feat,n_cfeat,height).to(device)
41
  optim = torch.optim.Adam(nn_model.parameters(),lrate)
42
 
43
  def perturb_input(x, t, noise):
 
54
 
55
  x = x.to(device)
56
 
57
+ t = torch.randint(1,timesteps+1,(x.shape[0],)).to(device)
58
  noise = torch.randn_like(x)
59
  x_pert = perturb_input(x,t,noise)
60
 
 
64
  loss.backward()
65
  optim.step()
66
 
67
+ if epoch % 10 == 0 and epoch > 0:
68
+ print(f"Epoch: {epoch} | Loss: {loss}")
69
  if not os.path.exists(save_dir):
70
  os.mkdir(save_dir)
71
+ torch.save(nn_model,save_dir + f"/model_Epoch{epoch}.pth")
72
  print("Saved model")
utils.py CHANGED
@@ -66,16 +66,10 @@ def plot_sample(x_gen_store,n_sample,nrows,save_dir, fn, w, save=False):
66
  return ani
67
 
68
 
69
- transform = transforms.Compose([
70
- transforms.ToTensor(), # from [0,255] to range [0.0,1.0]
71
- transforms.Normalize((0.5,), (0.5,)) # range [-1,1]
72
-
73
- ])
74
-
75
  class CustomDataset(Dataset):
76
  def __init__(self, sfilename, lfilename, transform, null_context=False):
77
- self.sprites = np.load(sfilename)
78
- self.slabels = np.load(lfilename)
79
  print(f"sprite shape: {self.sprites.shape}")
80
  print(f"labels shape: {self.slabels.shape}")
81
  self.transform = transform
@@ -83,10 +77,13 @@ class CustomDataset(Dataset):
83
  self.sprites_shape = self.sprites.shape
84
  self.slabel_shape = self.slabels.shape
85
 
 
86
  def __len__(self):
87
  return len(self.sprites)
88
 
 
89
  def __getitem__(self, idx):
 
90
  if self.transform:
91
  image = self.transform(self.sprites[idx])
92
  if self.null_context:
@@ -94,3 +91,7 @@ class CustomDataset(Dataset):
94
  else:
95
  label = torch.tensor(self.slabels[idx]).to(torch.int64)
96
  return (image, label)
 
 
 
 
 
66
  return ani
67
 
68
 
 
 
 
 
 
 
69
  class CustomDataset(Dataset):
70
  def __init__(self, sfilename, lfilename, transform, null_context=False):
71
+ self.sprites = np.load(sfilename,allow_pickle=True,fix_imports=True,encoding='latin1')
72
+ self.slabels = np.load(lfilename,allow_pickle=True,fix_imports=True,encoding='latin1')
73
  print(f"sprite shape: {self.sprites.shape}")
74
  print(f"labels shape: {self.slabels.shape}")
75
  self.transform = transform
 
77
  self.sprites_shape = self.sprites.shape
78
  self.slabel_shape = self.slabels.shape
79
 
80
+ # Return the number of images in the dataset
81
  def __len__(self):
82
  return len(self.sprites)
83
 
84
+ # Get the image and label at a given index
85
  def __getitem__(self, idx):
86
+ # Return the image and label as a tuple
87
  if self.transform:
88
  image = self.transform(self.sprites[idx])
89
  if self.null_context:
 
91
  else:
92
  label = torch.tensor(self.slabels[idx]).to(torch.int64)
93
  return (image, label)
94
+
95
+ def getshapes(self):
96
+ # return shapes of data and labels
97
+ return self.sprites_shape, self.slabel_shape