detectivejoewest committed on
Commit
582b238
·
verified ·
1 Parent(s): 5623993

Upload 7 files

Browse files
Files changed (7) hide show
  1. RoPE.py +22 -0
  2. attention.py +63 -0
  3. autoencoder.py +48 -0
  4. autoencoder_test.py +31 -0
  5. objectives.py +55 -0
  6. train.py +81 -0
  7. trainer.py +98 -0
RoPE.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import rearrange
3
+
def generate_angles_2d(H, W, D, freq=None):
    """Build an (H, W, D//2) grid of rotation angles for 2-D RoPE.

    Positions come from the outer product of two [-1, 1] linspaces, scaled by
    the inverse-frequency spectrum 10000**(-2i/D). Pass ``freq`` (a 1-D tensor
    of length D//2) to override the default spectrum.

    NOTE(review): the outer product mixes the two axes multiplicatively
    (grid[i, j] = y_i * x_j), which is unusual for 2-D RoPE — confirm intended.
    """
    if freq is None:
        freq = torch.tensor([10000 ** (-2 * i / D) for i in range(D // 2)])
    grid = torch.outer(torch.linspace(-1, 1, steps=H), torch.linspace(-1, 1, steps=W))
    # angles[i, j, k] = grid[i, j] * freq[k]
    return torch.einsum("ij,k->ijk", grid, freq)
+
def apply_angles_2d(x, f):
    """Rotate consecutive feature pairs of ``x`` by the angle grid ``f`` (2-D RoPE).

    x: (..., H, W, D) with D even; consecutive pairs along the last dim are
       treated as (real, imag) components of a complex number.
    f: (H, W, D//2) angles, broadcast over the leading dims of x.

    Improvements over the original: uses native torch ops instead of einops
    (no third-party dependency inside this function), and accepts any number
    of leading dims (the original required exactly a rank-5 "B h H W (D p)").
    """
    # (..., H, W, D) -> (..., H, W, D//2, 2): last dim holds (real, imag).
    pairs = x.unflatten(-1, (x.shape[-1] // 2, 2))
    real = pairs[..., 0]
    imag = pairs[..., 1]
    cosines, sines = f.cos(), f.sin()
    # Complex multiply by e^{i*f}: r, i -> r*cos - i*sin, r*sin + i*cos
    rot_real = real * cosines - imag * sines
    rot_imag = real * sines + imag * cosines
    # Interleave back: (..., H, W, D//2, 2) -> (..., H, W, D).
    return torch.stack((rot_real, rot_imag), dim=-1).flatten(-2)
+
# Sanity Check :)
# Guarded so importing this module (e.g. from attention.py) does not run a
# heavy 1x8x64x64x768 forward pass as an import-time side effect.
if __name__ == "__main__":
    print(apply_angles_2d(torch.randn(1, 8, 64, 64, 768), generate_angles_2d(64, 64, 768)).shape)
attention.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from RoPE import apply_angles_2d, generate_angles_2d
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+
7
+
class Attention(nn.Module):
    """Multi-head self-attention over an H*W token grid with 2-D rotary
    position embeddings (RoPE) applied per head.

    Input and output shape: (B, H*W, emb_dim).
    """

    def __init__(self, H,W, emb_dim, n_heads=8):
        super().__init__()
        self.H = H
        self.W = W
        self.n_heads = n_heads
        # Per-head channel width; the RoPE angle grid is built at this width.
        head_dim = emb_dim // n_heads
        self.qkv = nn.Linear(emb_dim, 3*emb_dim, bias=False)
        # NOTE(review): stored but unused — forward() calls the module-level
        # apply_angles_2d directly.
        self.apply_angles_2d = apply_angles_2d
        self.proj = nn.Linear(emb_dim, emb_dim)
        # Precomputed angle grid; persistent=False keeps it out of state_dict.
        self.register_buffer("freq", generate_angles_2d(H, W, head_dim), persistent=False)

    def forward(self, x):
        # x: (B, N, D) with N == H*W
        B, N, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        # to 2D: split heads and restore the H x W grid so RoPE can rotate
        # each position by its 2-D angle.
        q = rearrange(q, "B (H W) (h D) -> B h H W D", H=self.H, W=self.W, h=self.n_heads)
        k = rearrange(k, "B (H W) (h D) -> B h H W D", H=self.H, W=self.W, h=self.n_heads)
        v = rearrange(v, "B (H W) (h D) -> B h H W D", H=self.H, W=self.W, h=self.n_heads)

        # NOTE(review): standard RoPE rotates only q and k; v is rotated here
        # too — confirm this is intended (saved checkpoints depend on it).
        q = apply_angles_2d(q, self.freq)
        k = apply_angles_2d(k, self.freq)
        v = apply_angles_2d(v, self.freq)

        # to 1D: flatten the grid back to a token sequence for attention.
        q = rearrange(q, "B h H W D -> B h (H W) D", H=self.H, W=self.W, h=self.n_heads)
        k = rearrange(k, "B h H W D -> B h (H W) D", H=self.H, W=self.W, h=self.n_heads)
        v = rearrange(v, "B h H W D -> B h (H W) D", H=self.H, W=self.W, h=self.n_heads)

        x = F.scaled_dot_product_attention(q, k, v)
        # Merge heads back into the embedding dim, then output projection.
        x = rearrange(x, "B h N D -> B N (h D)")
        x = self.proj(x)
        return x
+
class ViTBlock(nn.Module):
    """Pre-norm transformer block: LayerNorm -> Attention and LayerNorm -> MLP,
    each wrapped in a residual connection. Operates on (B, H*W, emb_dim)."""

    def __init__(self, H, W, emb_dim, n_heads=8, dropout=0.1):
        super().__init__()
        self.H, self.W, self.emb_dim = H, W, emb_dim
        self.attn = nn.Sequential(
            nn.LayerNorm(emb_dim),
            Attention(H, W, emb_dim, n_heads=n_heads),
        )
        hidden = emb_dim * 4
        self.MLP = nn.Sequential(
            nn.LayerNorm(emb_dim),
            nn.Linear(emb_dim, hidden, bias=True),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, emb_dim, bias=True),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Validate the flattened [B, N, D] layout before attending.
        assert x.ndim == 3, f"Expected shape [B, N, D], but got shape {x.shape}. You probably passed [B, H, W, D] instead."
        expected = torch.Size([x.shape[0], self.H * self.W, self.emb_dim])
        assert x.shape == expected, f"Expected shape [B, N, D] -> {expected}, got {x.shape}"
        x = x + self.attn(x)
        x = x + self.MLP(x)
        return x
+
# Sanity Check :)
# Guarded so importing this module (e.g. from autoencoder.py) does not build
# a ViTBlock and run a forward pass as an import-time side effect.
if __name__ == "__main__":
    print(ViTBlock(64,64,384)(torch.randn(1, 64**2, 384)).shape)
autoencoder.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+ from attention import ViTBlock
6
+
7
+ # Global Parameters
8
+ image_shape = 256
9
+ emb_dim = 768
10
+ patch_size = 16
11
+
class Encoder(nn.Module):
    """ViT encoder: patchify -> transformer blocks -> LayerNorm -> linear
    projection to the latent width.

    Input:  (B, 3, image_shape, image_shape)
    Output: (B, (image_shape // patch_size) ** 2, latent_dim)
    ``gaussian`` is accepted for interface compatibility but currently unused.
    """

    def __init__(self, latent_dim, image_shape=image_shape, emb_dim=emb_dim, patch_size=patch_size, n_heads=8, dropout=0.1, layers=6, gaussian=False):
        super().__init__()
        # Conv with kernel == stride == patch_size splits the image into
        # non-overlapping patches, one emb_dim token each.
        self.patchifier = nn.Conv2d(3, emb_dim, kernel_size=patch_size, stride=patch_size)
        grid = image_shape // patch_size
        # BUGFIX: forward the caller's n_heads — it was hard-coded to 8 before,
        # silently ignoring the parameter.
        self.Blocks = nn.ModuleList([ViTBlock(grid, grid, emb_dim, n_heads=n_heads, dropout=dropout) for _ in range(layers)])
        self.ln = nn.LayerNorm(emb_dim)
        self.compress_latent = nn.Linear(emb_dim, latent_dim)

    def forward(self, x):
        x = self.patchifier(x)                    # (B, emb_dim, H/p, W/p)
        x = rearrange(x, "B D H W -> B (H W) D")  # Flatten to B, N, D
        for vitBlock in self.Blocks:
            x = vitBlock(x)
        x = self.ln(x)
        x = self.compress_latent(x)
        return x
+
class Decoder(nn.Module):
    """ViT decoder: latent tokens -> transformer blocks -> per-token patch
    pixels -> full image in [-1, 1] (tanh).

    Input:  (B, (image_shape // patch_size) ** 2, latent_dim)
    Output: (B, 3, image_shape, image_shape)
    ``gaussian`` is accepted for interface compatibility but currently unused.
    """

    def __init__(self, latent_dim, image_shape=image_shape, emb_dim=emb_dim, patch_size=patch_size, n_heads=8, dropout=0.1, layers=6, gaussian=False):
        super().__init__()
        self.hw = image_shape // patch_size
        self.patch_size = patch_size
        self.decompress_latent = nn.Linear(latent_dim, emb_dim)
        self.ln = nn.LayerNorm(emb_dim)
        # Each token predicts one 3 x p x p patch of pixels.
        self.emb_to_patch = nn.Linear(emb_dim, 3*(patch_size**2))
        # BUGFIX: forward the caller's n_heads — it was hard-coded to 8 before.
        self.Blocks = nn.ModuleList([ViTBlock(self.hw, self.hw, emb_dim, n_heads=n_heads, dropout=dropout) for _ in range(layers)])

    def forward(self, x):
        x = self.decompress_latent(x)
        for vitBlock in self.Blocks:
            x = vitBlock(x)
        # BUGFIX: the LayerNorm result was discarded before (bare `self.ln(x)`
        # statement), so the final norm was never applied.
        x = self.ln(x)
        # shape is [B HW/p**2 (3 p p)]
        x = self.emb_to_patch(x)
        assert x.shape == torch.Size([x.shape[0], self.hw**2, 3*(self.patch_size**2)]), f"Expected shape {torch.Size([x.shape[0], self.hw**2, 3*(self.patch_size**2)])} got {x.shape}"
        x = rearrange(x, "B (H W) (D p1 p2) -> B D (H p1) (W p2)", H=self.hw, W=self.hw, p1=self.patch_size, p2=self.patch_size) # Expand to B, H, W, D
        return F.tanh(x)
autoencoder_test.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from autoencoder import Encoder, Decoder
2
+ import torch
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import cv2
6
+
7
+ image_shape = 256
8
+ emb_dim = 768
9
+ patch_size = 16
10
+
11
+ encoder = Encoder(latent_dim=16,
12
+ image_shape=image_shape,
13
+ emb_dim=emb_dim,
14
+ patch_size=patch_size)
15
+ encoder.load_state_dict(torch.load("encoder16.pt", map_location=torch.device('cpu')))
16
+
17
+ decoder = Decoder(latent_dim=16,
18
+ image_shape=image_shape,
19
+ emb_dim=emb_dim,
20
+ patch_size=patch_size)
21
+ decoder.load_state_dict(torch.load("decoder16.pt", map_location=torch.device('cpu')))
22
+
23
+ image = cv2.imread("test_image.jpg")
24
+ image = cv2.resize(image, (image_shape, image_shape))
25
+ image = torch.tensor(image, dtype=torch.float32, device='cpu').permute(2, 0, 1) / 127.5 - 1.0
26
+ image = image.unsqueeze(0)
27
+ with torch.no_grad():
28
+ z = encoder(image)
29
+ x = decoder(z)
30
+ plt.imshow(x[0].permute(1, 2, 0).numpy()*0.5 + 0.5)
31
+ plt.show()
objectives.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from torchvision.models import vgg16, VGG16_Weights
3
+
class Discriminator(nn.Module):
    """PatchGAN-style discriminator: a stack of stride-2 convs followed by a
    1x1 conv head that emits one real/fake logit per spatial patch.

    img_shape: (C, H, W) of the input images.
    filters:   channel widths of the successive conv stages.
    """

    # BUGFIX: tuple default instead of the mutable list [256, 512] — a shared
    # mutable default is a classic Python pitfall.
    def __init__(self, img_shape, filters=(256, 512)):
        super().__init__()
        module_list = [nn.Conv2d(img_shape[0], filters[0], kernel_size=3, stride=2, padding=1),
                       nn.BatchNorm2d(filters[0]),
                       nn.LeakyReLU(0.2)]
        for i in range(1, len(filters)):
            module_list += [nn.Conv2d(filters[i-1], filters[i], kernel_size=3, stride=2, padding=1),
                            nn.BatchNorm2d(filters[i]),
                            nn.LeakyReLU(0.2)]

        self.convs = nn.Sequential(*module_list)
        # 1x1 conv head: one logit per remaining spatial location.
        self.mlp = nn.Sequential(nn.Conv2d(filters[-1], 1, kernel_size=1, stride=1, padding=0))

    def forward(self, x):
        """x: (B, C, H, W) -> (B, 1, H / 2**len(filters), W / 2**len(filters)) logits."""
        x = self.convs(x)
        x = self.mlp(x)
        return x
+
class vgg_builder(nn.Module):
    """Frozen VGG16 feature extractor that returns activations from five
    depth slices of ``vgg16.features`` — used as a perceptual loss backbone."""

    def __init__(self):
        super(vgg_builder, self).__init__()
        convs = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
        self.N_slices = 5
        self.slices = nn.ModuleList(list(nn.Sequential() for _ in range(self.N_slices)))
        # Cut points into vgg16.features; slice s covers layers [cuts[s], cuts[s+1]).
        cuts = (0, 4, 9, 16, 23, 30)
        for s in range(self.N_slices):
            for layer_idx in range(cuts[s], cuts[s + 1]):
                self.slices[s].add_module(str(layer_idx), convs[layer_idx])
        # The network is a fixed feature extractor: freeze every parameter.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """x in [-1, 1] -> list of 5 feature maps, shallow to deep.

        NOTE(review): only rescales to [0, 1]; no ImageNet mean/std
        normalization is applied — confirm intended.
        """
        feat_map = []
        out = (x+1)/2
        for block in self.slices:
            out = block(out)
            feat_map.append(out)
        return feat_map
train.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import kagglehub
2
+ import cv2
3
+ import os
4
+ from IPython.display import clear_output
5
+ import torch
6
+ import numpy as np
7
+ from torch.utils.data import Dataset, DataLoader
8
+ import torch.nn as nn
9
+ import matplotlib.pyplot as plt
10
+ from autoencoder import Encoder, Decoder
11
+ from trainer import Trainer
12
+ from objectives import Discriminator, vgg_builder
13
+
# Global Parameters
image_shape = 256  # training resolution (square)
emb_dim = 768      # transformer token width
patch_size = 16    # ViT patch side length

# Download COCO 2017 via kagglehub and eagerly load every .jpg into memory as
# a (3, 256, 256) float32 tensor scaled to [-1, 1].
# NOTE(review): this holds the whole dataset in RAM — confirm the host has
# enough memory for COCO at this resolution.
image_path = kagglehub.dataset_download("awsaf49/coco-2017-dataset")
data = []
for dirpath, _, filenames in os.walk(image_path):
    for filename in filenames:
        if filename.endswith("jpg"):
            name = os.path.join(dirpath, filename)
            img = cv2.imread(name)
            img = cv2.resize(img, (image_shape,image_shape))
            # OpenCV loads BGR; the models are trained on RGB.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # [0, 255] -> [-1, 1]
            img = img.astype(np.float32) / 127.5 - 1.0
            # HWC -> CHW
            img = torch.tensor(img).permute(2,0,1)
            data.append(img)
            clear_output(wait=1)
            # NOTE(review): 1670 looks like dataset_size/100 hard-coded as a
            # progress divisor — confirm against the actual image count.
            print(f"{len(data)/1670:.2f}%")
print(len(data))
+
class CustomDataset(Dataset):
    """Wraps a pre-loaded list of image tensors with a fixed extra shuffle.

    The shuffle is redundant with DataLoader(shuffle=True) but harmless;
    kept for compatibility with the existing pipeline.
    """

    def __init__(self, data):
        self.indices = np.arange(len(data))
        np.random.shuffle(self.indices)
        self.data = data

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # BUGFIX: torch.tensor(existing_tensor) emits a copy-construct
        # UserWarning; detach().clone() is the supported way to copy a tensor.
        item = self.data[self.indices[idx]]
        if isinstance(item, torch.Tensor):
            return item.detach().clone().to(torch.float32)
        return torch.tensor(item, dtype=torch.float32)
46
+
47
+ # Sanity Check :)
48
+ plt.imshow(CustomDataset(data)[0].permute(1,2,0)/2+0.5)
49
+
50
+ encoder = Encoder(latent_dim=16)
51
+ decoder = Decoder(latent_dim=16)
52
+ D = Discriminator((3,256,256))
53
+
54
+ vgg = vgg_builder()
55
+ for param in vgg.parameters():
56
+ param.requires_grad = False
57
+ vgg.eval()
58
+
59
+ print(f"encoder: {sum(p.numel() for p in encoder.parameters())/(262144):.3f}MB")
60
+ print(f"decoder: {sum(p.numel() for p in decoder.parameters())/(262144):.3f}MB")
61
+ print(f"Discriminator: {sum(p.numel() for p in D.parameters())/(262144):.3f}MB")
62
+ print(f"VGG: {sum(p.numel() for p in vgg.parameters())/(262144):.3f}MB")
63
# Training driver: build the loader and Trainer, run the epoch loop, save weights.
batch_size = 16
dataset = CustomDataset(data)
loader = DataLoader(dataset,
                    batch_size=batch_size,
                    shuffle=True,
                    num_workers=8,
                    pin_memory=True)
epochs = 5
# `loader` is always defined at this point, so the old
# `len(loader) if "loader" in locals() else 0` guard was dead code.
trainer = Trainer(encoder, decoder, D, vgg, ["mse", "gan", "vgg", "KL"], len(loader), isViT=1)
# BUGFIX: range(1, epochs) ran only epochs-1 passes (4 of the intended 5);
# iterate the full epoch count.
for epoch in range(1, epochs + 1):
    for x in loader:
        trainer.train_step(x, freeze_disc=0, with_mse=1, freeze_ae=0)
    trainer.update_epoch()

torch.save(encoder.state_dict(), "encoder16.pt")
torch.save(decoder.state_dict(), "decoder16.pt")
torch.save(D.state_dict(), "discriminator16.pt")
trainer.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from IPython.display import clear_output
4
+
# @title Trainer
class Trainer():
    """Adversarial autoencoder trainer.

    Each train_step optionally runs a hinge-loss discriminator update followed
    by an autoencoder update (MSE + latent mean penalty + VGG perceptual loss
    + adversarial term), then logs EMA-smoothed losses with a progress bar.
    """

    def __init__(self, encoder, decoder, D, vgg, losses, data_len, ema=3, a_disc=1, a_vae=1, a_KL=0.1, isViT=True):
        self.vgg_schedule = None
        # EMA smoothing factor via the usual span -> alpha conversion 2/(span+1).
        self.ema = 2/(ema+1)
        self.a_disc = a_disc
        self.a_vae = a_vae
        self.a_KL = a_KL

        self.isViT = isViT
        self.encoder = encoder
        self.decoder = decoder
        self.D = D
        self.vgg = vgg
        self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-5)
        self.encoder_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.encoder_optimizer, T_max=50)
        self.decoder_optimizer = torch.optim.Adam(self.decoder.parameters(), lr=1e-5)
        self.decoder_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.decoder_optimizer, T_max=50)
        # Discriminator learns faster (TTUR-style higher lr).
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), lr=4e-5)
        self.D_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.D_optimizer, T_max=50)
        self.losses = losses
        self.loss_vals = {loss:0 for loss in losses}
        self.data_len = data_len
        self.loss_record = []
        self.epoch = 1
        self.index = 1
        # BUGFIX: fall back to CPU so construction does not crash on machines
        # without CUDA (torch.device("cuda") alone fails at .to() time).
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.encoder.to(self.device)
        self.decoder.to(self.device)
        self.D.to(self.device)
        self.vgg.to(self.device)

    def train_step(self, x, with_mse=False, freeze_ae=False, freeze_disc=False):
        """One optimization step on batch x: discriminator update (unless
        freeze_disc), then autoencoder update (unless freeze_ae)."""
        self.index += 1
        x = x.to(self.device)
        # Reconstruction for the discriminator step only — no AE gradients
        # needed. Non-ViT (conv) models expect channels-last, hence the permutes.
        with torch.no_grad():
            x_hat = self.decoder(self.encoder(x.permute(0,2,3,1))).permute(0,3,1,2) if not self.isViT else self.decoder(self.encoder(x))
        if not freeze_disc:
            disc_loss = F.relu(1. - self.D(x)).mean() + F.relu(1. + self.D(x_hat)).mean() # Hinge
            self.D_optimizer.zero_grad()
            disc_loss.backward()
            self.D_optimizer.step()
            self.D_scheduler.step()

        if not freeze_ae:
            z = self.encoder(x.permute(0,2,3,1)) if not self.isViT else self.encoder(x)
            x_hat = self.decoder(z).permute(0,3,1,2) if not self.isViT else self.decoder(z)
            mse = F.mse_loss(x_hat, x)
            # Penalizes only the squared mean of z (no variance term, so not a
            # full VAE KL divergence).
            KL = 0.5 * (z.mean() ** 2)
            vgg_real = self.vgg(x)
            vgg_fake = self.vgg(x_hat)
            vgg_loss = 0
            for real_feat, fake_feat in zip(vgg_real, vgg_fake):
                vgg_loss += F.mse_loss(real_feat, fake_feat)

            adv_loss = 0
            if not freeze_disc:
                # BUGFIX: reuse the grad-enabled x_hat instead of re-running
                # encoder+decoder (the old re-run doubled the forward cost and
                # also skipped the non-ViT channel permutes).
                adv_loss = -(self.D(x_hat).mean())

            loss = mse * with_mse + self.a_KL* KL + vgg_loss + self.a_vae * adv_loss
            self.encoder_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()
            loss.backward()
            self.encoder_optimizer.step()
            self.decoder_optimizer.step()
            self.encoder_scheduler.step()
            self.decoder_scheduler.step()

        self.update_batch({"mse":mse.item() if not freeze_ae else 0,
                           "gan":disc_loss.item() if not freeze_disc else 0,
                           "vgg":vgg_loss.item() if not freeze_ae else 0,
                           # BUGFIX: .item() so the EMA log stores a float, not
                           # a graph-holding tensor (memory leak over training).
                           "KL":z.mean().item() if not freeze_ae else 0})

    def update_batch(self, loss_vals):
        """EMA-update the running losses and redraw the progress display."""
        clear_output(wait=True)
        for record in self.loss_record:
            print(record)
        self.loss_vals = {loss:(1-self.ema)*self.loss_vals[loss] + self.ema*loss_vals[loss] for loss in self.losses}
        print(f"epoch:{self.epoch} ", end="")
        for loss in self.losses:
            print(f"{loss}: {self.loss_vals[loss]:.3f} ", end="")
        # 20-character ==--- progress bar for the current epoch.
        done = int(self.index * 20 / self.data_len)
        for _ in range(done):
            print("=", end="")
        for _ in range(done, 20):
            print("-", end="")

    def update_epoch(self):
        """Archive this epoch's smoothed losses and reset the batch counter."""
        self.index = 0
        record = f"epoch:{self.epoch} "
        for loss in self.losses:
            record += f"{loss}: {self.loss_vals[loss]:.3f} "
        self.loss_record.append(record)
        self.epoch += 1