Upload 7 files
Browse files- RoPE.py +22 -0
- attention.py +63 -0
- autoencoder.py +48 -0
- autoencoder_test.py +31 -0
- objectives.py +55 -0
- train.py +81 -0
- trainer.py +98 -0
RoPE.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from einops import rearrange
|
| 3 |
+
|
| 4 |
+
def generate_angles_2d(H, W, D, freq=None):
    """Build a (H, W, D//2) grid of rotation angles for 2D RoPE.

    Positions are the outer product of two linspaces over [-1, 1] (rows x
    cols), and the default frequencies follow the standard RoPE schedule
    10000^(-2i/D). A custom `freq` tensor of length D//2 may be supplied.
    """
    if freq is None:
        freq = torch.tensor([10000 ** (-2 * i / D) for i in range(D // 2)])
    row_pos = torch.linspace(-1, 1, steps=H)
    col_pos = torch.linspace(-1, 1, steps=W)
    pos = torch.outer(row_pos, col_pos)
    # angle[i, j, k] = pos[i, j] * freq[k]
    return pos.unsqueeze(-1) * freq
|
| 9 |
+
|
| 10 |
+
def apply_angles_2d(x, f):
    """Rotate feature pairs of `x` by the angles in `f` (2D RoPE).

    x: (..., H, W, D) with D even; consecutive pairs (x[..., 2i], x[..., 2i+1])
       are treated as complex numbers and rotated by angle f[..., i].
    f: angle tensor broadcastable to (H, W, D//2), e.g. from generate_angles_2d.
    Returns a tensor of the same shape as x.

    Generalized from the original fixed (B, h, H, W, D) layout: torch-native
    unflatten/flatten accepts any number of leading batch dimensions and avoids
    the einops dependency for this hot path.
    """
    pairs = x.unflatten(-1, (-1, 2))  # (..., D//2, 2)
    real = pairs[..., 0]
    imag = pairs[..., 1]
    cosines, sines = f.cos(), f.sin()
    # complex multiply: (r + i*j) * e^{j*f} -> (r*cos - i*sin, r*sin + i*cos)
    rot_real = real * cosines - imag * sines
    rot_imag = real * sines + imag * cosines
    rot_full = torch.stack((rot_real, rot_imag), dim=-1)
    return rot_full.flatten(-2)
|
| 20 |
+
|
| 21 |
+
# Sanity check — guarded so that importing RoPE (e.g. from attention.py) does
# not trigger a throwaway forward pass and print at import time.
if __name__ == "__main__":
    print(apply_angles_2d(torch.randn(1, 8, 64, 64, 768), generate_angles_2d(64, 64, 768)).shape)
|
attention.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from RoPE import apply_angles_2d, generate_angles_2d
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Attention(nn.Module):
    """Multi-head self-attention over a 2D token grid with rotary position
    embeddings (2D RoPE) applied to queries and keys.

    H, W: spatial grid size; the input sequence length must be H*W.
    emb_dim: token embedding dimension, must be divisible by n_heads.
    """

    def __init__(self, H, W, emb_dim, n_heads=8):
        super().__init__()
        assert emb_dim % n_heads == 0, "emb_dim must be divisible by n_heads"
        self.H = H
        self.W = W
        self.n_heads = n_heads
        head_dim = emb_dim // n_heads
        self.qkv = nn.Linear(emb_dim, 3 * emb_dim, bias=False)
        self.proj = nn.Linear(emb_dim, emb_dim)
        # Fixed rotation angles; persistent=False keeps them out of the state_dict.
        self.register_buffer("freq", generate_angles_2d(H, W, head_dim), persistent=False)

    def forward(self, x):
        """x: (B, H*W, emb_dim) -> (B, H*W, emb_dim)."""
        B, N, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        # split heads and restore the 2D grid
        q = rearrange(q, "B (H W) (h D) -> B h H W D", H=self.H, W=self.W, h=self.n_heads)
        k = rearrange(k, "B (H W) (h D) -> B h H W D", H=self.H, W=self.W, h=self.n_heads)
        v = rearrange(v, "B (H W) (h D) -> B h H W D", H=self.H, W=self.W, h=self.n_heads)

        # BUG FIX: RoPE must rotate only q and k — the rotary identity preserves
        # relative position solely in q·k dot products; rotating v distorted the
        # attention output. NOTE(review): checkpoints trained with the old
        # v-rotation will produce different outputs after this fix.
        q = apply_angles_2d(q, self.freq)
        k = apply_angles_2d(k, self.freq)

        # back to flat sequences for attention
        q = rearrange(q, "B h H W D -> B h (H W) D")
        k = rearrange(k, "B h H W D -> B h (H W) D")
        v = rearrange(v, "B h H W D -> B h (H W) D")

        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "B h N D -> B N (h D)")
        return self.proj(x)
|
| 42 |
+
|
| 43 |
+
class ViTBlock(nn.Module):
    """Pre-norm transformer block: x + Attn(LN(x)) followed by x + MLP(LN(x))."""

    def __init__(self, H, W, emb_dim, n_heads=8, dropout=0.1):
        # BUG FIX: super().__init__() must run before any attribute assignment
        # so nn.Module's bookkeeping is initialized first.
        super().__init__()
        self.H, self.W, self.emb_dim = H, W, emb_dim
        self.attn = nn.Sequential(nn.LayerNorm(emb_dim),
                                  Attention(H, W, emb_dim, n_heads=n_heads))
        # 4x MLP expansion, GELU, dropout after activation and projection.
        self.MLP = nn.Sequential(nn.LayerNorm(emb_dim),
                                 nn.Linear(emb_dim, emb_dim * 4, bias=True),
                                 nn.GELU(),
                                 nn.Dropout(dropout),
                                 nn.Linear(emb_dim * 4, emb_dim, bias=True),
                                 nn.Dropout(dropout))

    def forward(self, x):
        """x: (B, H*W, emb_dim) -> same shape, with two residual sub-layers."""
        assert x.ndim == 3, f"Expected shape [B, N, D], but got shape {x.shape}. You probably passed [B, H, W, D] instead."
        assert x.shape == torch.Size([x.shape[0], self.H * self.W, self.emb_dim]), f"Expected shape [B, N, D] -> {torch.Size([x.shape[0], self.H * self.W, self.emb_dim])}, got {x.shape}"
        x = x + self.attn(x)
        x = x + self.MLP(x)
        return x
|
| 61 |
+
|
| 62 |
+
# Sanity check — guarded so that importing attention.py (e.g. from
# autoencoder.py) does not build and run a ViTBlock at import time.
if __name__ == "__main__":
    print(ViTBlock(64, 64, 384)(torch.randn(1, 64 ** 2, 384)).shape)
|
autoencoder.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
from attention import ViTBlock
|
| 6 |
+
|
| 7 |
+
# Global Parameters
# Module-level defaults; Encoder and Decoder accept these as overridable
# constructor arguments.
image_shape = 256  # input image height/width in pixels (square images)
emb_dim = 768  # transformer token embedding dimension
patch_size = 16  # side length of each square patch
|
| 11 |
+
|
| 12 |
+
class Encoder(nn.Module):
    """ViT encoder: patchify -> transformer blocks -> LayerNorm -> per-token
    linear projection down to latent_dim.

    `gaussian` is accepted for interface compatibility but currently unused —
    presumably reserved for a VAE-style reparameterized latent (TODO confirm).
    """

    def __init__(self, latent_dim, image_shape=image_shape, emb_dim=emb_dim, patch_size=patch_size, n_heads=8, dropout=0.1, layers=6, gaussian=False):
        super().__init__()
        grid = image_shape // patch_size  # tokens per spatial side
        self.patchifier = nn.Conv2d(3, emb_dim, kernel_size=patch_size, stride=patch_size)
        # BUG FIX: the n_heads argument was previously ignored (hard-coded 8).
        self.Blocks = nn.ModuleList([ViTBlock(grid, grid, emb_dim, n_heads=n_heads, dropout=dropout) for _ in range(layers)])
        self.ln = nn.LayerNorm(emb_dim)
        self.compress_latent = nn.Linear(emb_dim, latent_dim)

    def forward(self, x):
        """x: (B, 3, H, W) image -> (B, (H//p)*(W//p), latent_dim) latent tokens."""
        x = self.patchifier(x)
        x = rearrange(x, "B D H W -> B (H W) D")  # flatten grid to token sequence
        for vit_block in self.Blocks:
            x = vit_block(x)
        x = self.ln(x)
        return self.compress_latent(x)
|
| 28 |
+
|
| 29 |
+
class Decoder(nn.Module):
    """ViT decoder: latent tokens -> transformer blocks -> LayerNorm ->
    per-patch pixels -> image in [-1, 1] (tanh).

    `gaussian` is accepted for interface compatibility but currently unused.
    """

    def __init__(self, latent_dim, image_shape=image_shape, emb_dim=emb_dim, patch_size=patch_size, n_heads=8, dropout=0.1, layers=6, gaussian=False):
        super().__init__()
        self.hw = image_shape // patch_size  # tokens per spatial side
        self.patch_size = patch_size
        self.decompress_latent = nn.Linear(latent_dim, emb_dim)
        self.ln = nn.LayerNorm(emb_dim)
        self.emb_to_patch = nn.Linear(emb_dim, 3 * (patch_size ** 2))
        # BUG FIX: the n_heads argument was previously ignored (hard-coded 8).
        self.Blocks = nn.ModuleList([ViTBlock(self.hw, self.hw, emb_dim, n_heads=n_heads, dropout=dropout) for _ in range(layers)])

    def forward(self, x):
        """x: (B, N, latent_dim) latent tokens -> (B, 3, H, W) image in [-1, 1]."""
        x = self.decompress_latent(x)
        for vit_block in self.Blocks:
            x = vit_block(x)
        # BUG FIX: the LayerNorm result was previously discarded (`self.ln(x)`
        # without assignment), so the norm layer's parameters trained but were
        # never applied.
        x = self.ln(x)
        # shape is [B, (H/p)*(W/p), 3*p*p]
        x = self.emb_to_patch(x)
        assert x.shape == torch.Size([x.shape[0], self.hw**2, 3*(self.patch_size**2)]), f"Expected shape {torch.Size([x.shape[0], self.hw**2, 3*(self.patch_size**2)])} got {x.shape}"
        # unfold each token back into a p x p RGB patch at its grid position
        x = rearrange(x, "B (H W) (D p1 p2) -> B D (H p1) (W p2)", H=self.hw, W=self.hw, p1=self.patch_size, p2=self.patch_size)
        return F.tanh(x)
|
autoencoder_test.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from autoencoder import Encoder, Decoder
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2

# Must match the values used during training (see train.py).
image_shape = 256
emb_dim = 768
patch_size = 16

encoder = Encoder(latent_dim=16,
                  image_shape=image_shape,
                  emb_dim=emb_dim,
                  patch_size=patch_size)
encoder.load_state_dict(torch.load("encoder16.pt", map_location=torch.device('cpu')))
encoder.eval()  # BUG FIX: disable dropout for inference

decoder = Decoder(latent_dim=16,
                  image_shape=image_shape,
                  emb_dim=emb_dim,
                  patch_size=patch_size)
decoder.load_state_dict(torch.load("decoder16.pt", map_location=torch.device('cpu')))
decoder.eval()  # BUG FIX: disable dropout for inference

image = cv2.imread("test_image.jpg")
image = cv2.resize(image, (image_shape, image_shape))
# BUG FIX: cv2 decodes BGR but training (train.py) converts to RGB — keep the
# channel order consistent with what the model was trained on.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = torch.tensor(image, dtype=torch.float32, device='cpu').permute(2, 0, 1) / 127.5 - 1.0
image = image.unsqueeze(0)
with torch.no_grad():
    z = encoder(image)
    x = decoder(z)
# Map the tanh output [-1, 1] back to [0, 1] and clip for imshow.
plt.imshow(np.clip(x[0].permute(1, 2, 0).numpy() * 0.5 + 0.5, 0.0, 1.0))
plt.show()
|
objectives.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from torchvision.models import vgg16, VGG16_Weights
|
| 3 |
+
|
| 4 |
+
class Discriminator(nn.Module):
    """PatchGAN-style discriminator: a stack of stride-2 conv/BN/LeakyReLU
    stages followed by a 1x1 conv producing a per-patch realness logit map.

    img_shape: (C, H, W) of the input; only C is used here.
    filters: output channels of each stride-2 stage.
    """

    # BUG FIX: the default was a mutable list ([256, 512]); a tuple avoids the
    # shared-mutable-default pitfall while remaining backward-compatible.
    def __init__(self, img_shape, filters=(256, 512)):
        super().__init__()
        module_list = []
        in_channels = img_shape[0]
        for out_channels in filters:
            module_list += [nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
                            nn.BatchNorm2d(out_channels),
                            nn.LeakyReLU(0.2)]
            in_channels = out_channels

        self.convs = nn.Sequential(*module_list)
        # 1x1 conv head; kept inside nn.Sequential (and named "mlp") so
        # existing checkpoints ("discriminator16.pt") still load.
        self.mlp = nn.Sequential(nn.Conv2d(filters[-1], 1, kernel_size=1, stride=1, padding=0))

    def forward(self, x):
        """x: (B, C, H, W) -> (B, 1, H/2^k, W/2^k) logit map, k = len(filters)."""
        x = self.convs(x)
        x = self.mlp(x)
        return x
|
| 22 |
+
|
| 23 |
+
class vgg_builder(nn.Module):
    """Frozen VGG16 feature extractor returning 5 intermediate feature maps
    (for a perceptual loss). Input is expected in [-1, 1]."""

    # Layer boundaries inside torchvision's vgg16().features corresponding to
    # relu1_2, relu2_2, relu3_3, relu4_3, relu5_3.
    _SLICE_BOUNDS = (0, 4, 9, 16, 23, 30)

    def __init__(self):
        import torch  # local import: this module otherwise only imports torch.nn
        super(vgg_builder, self).__init__()
        convs = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
        self.N_slices = 5
        self.slices = nn.ModuleList(list(nn.Sequential() for _ in range(self.N_slices)))
        # Deduplicated: one loop over the boundary pairs replaces five copies.
        for s, (start, stop) in enumerate(zip(self._SLICE_BOUNDS, self._SLICE_BOUNDS[1:])):
            for x in range(start, stop):
                self.slices[s].add_module(str(x), convs[x])
        # BUG FIX: torchvision's pretrained weights expect ImageNet-normalized
        # input, not raw [0, 1] values (see torchvision pretrained-model docs).
        # persistent=False keeps these constants out of any saved state_dict.
        self.register_buffer("imnet_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer("imnet_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """x: (B, 3, H, W) in [-1, 1] -> list of 5 feature maps."""
        feat_map = []
        x = (x + 1) / 2                          # [-1, 1] -> [0, 1]
        x = (x - self.imnet_mean) / self.imnet_std  # ImageNet normalization
        for vgg_slice in self.slices:
            x = vgg_slice(x)
            feat_map.append(x)
        return feat_map
|
train.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import kagglehub
|
| 2 |
+
import cv2
|
| 3 |
+
import os
|
| 4 |
+
from IPython.display import clear_output
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from autoencoder import Encoder, Decoder
|
| 11 |
+
from trainer import Trainer
|
| 12 |
+
from objectives import Discriminator, vgg_builder
|
| 13 |
+
|
| 14 |
+
# Global Parameters
image_shape = 256  # must match the autoencoder's default input size
emb_dim = 768      # NOTE(review): unused here — Encoder/Decoder defaults apply
patch_size = 16    # NOTE(review): unused here — Encoder/Decoder defaults apply

image_path = kagglehub.dataset_download("awsaf49/coco-2017-dataset")

# NOTE(review): this eagerly loads every resized image into RAM; fine for a
# subset, but consider decoding lazily inside the Dataset for full COCO.
data = []
for dirpath, _, filenames in os.walk(image_path):
    for filename in filenames:
        if not filename.endswith("jpg"):
            continue
        name = os.path.join(dirpath, filename)
        img = cv2.imread(name)
        if img is None:  # BUG FIX: skip unreadable/corrupt files instead of crashing in resize
            continue
        img = cv2.resize(img, (image_shape, image_shape))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # cv2 decodes BGR; train on RGB
        img = img.astype(np.float32) / 127.5 - 1.0  # scale uint8 [0, 255] -> [-1, 1]
        img = torch.tensor(img).permute(2, 0, 1)    # HWC -> CHW
        data.append(img)
        clear_output(wait=True)  # wait expects a bool
        # presumably total_images / 100 == 1670, making this a percentage — TODO confirm
        print(f"{len(data)/1670:.2f}%")
print(len(data))
|
| 34 |
+
|
| 35 |
+
class CustomDataset(Dataset):
    """In-memory dataset over a list of pre-processed CHW float tensors."""

    def __init__(self, data):
        # NOTE(review): DataLoader(shuffle=True) already shuffles; this extra
        # index permutation is kept for standalone (non-loader) indexing.
        self.indices = np.arange(len(data))
        np.random.shuffle(self.indices)
        self.data = data

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # BUG FIX: torch.tensor(existing_tensor) triggers a "copy construct"
        # warning and an unnecessary copy; items are already tensors, so just
        # ensure the dtype.
        return self.data[self.indices[idx]].to(torch.float32)
|
| 46 |
+
|
| 47 |
+
# Sanity Check :)
plt.imshow(CustomDataset(data)[0].permute(1, 2, 0) / 2 + 0.5)

encoder = Encoder(latent_dim=16)
decoder = Decoder(latent_dim=16)
D = Discriminator((3, 256, 256))

# Frozen, eval-mode VGG for the perceptual loss.
vgg = vgg_builder()
for param in vgg.parameters():
    param.requires_grad = False
vgg.eval()

# numel / 262144 == numel * 4 bytes / 2**20: parameter size in MB at float32.
print(f"encoder: {sum(p.numel() for p in encoder.parameters())/(262144):.3f}MB")
print(f"decoder: {sum(p.numel() for p in decoder.parameters())/(262144):.3f}MB")
print(f"Discriminator: {sum(p.numel() for p in D.parameters())/(262144):.3f}MB")
print(f"VGG: {sum(p.numel() for p in vgg.parameters())/(262144):.3f}MB")

batch_size = 16
dataset = CustomDataset(data)
loader = DataLoader(dataset,
                    batch_size=batch_size,
                    shuffle=True,
                    num_workers=8,
                    pin_memory=True)
epochs = 5
# `loader` is always defined at this point, so the old `"loader" in locals()`
# guard was dead code.
trainer = Trainer(encoder, decoder, D, vgg, ["mse", "gan", "vgg", "KL"], len(loader), isViT=1)
# BUG FIX: `range(1, epochs)` ran only epochs-1 epochs.
for epoch in range(epochs):
    for x in loader:
        trainer.train_step(x, freeze_disc=0, with_mse=1, freeze_ae=0)
    trainer.update_epoch()

torch.save(encoder.state_dict(), "encoder16.pt")
torch.save(decoder.state_dict(), "decoder16.pt")
torch.save(D.state_dict(), "discriminator16.pt")
|
trainer.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from IPython.display import clear_output
|
| 4 |
+
|
| 5 |
+
# @title Trainer
|
| 6 |
+
class Trainer():
    """Trains an adversarial autoencoder: hinge-loss discriminator updates
    alternated with encoder/decoder updates driven by MSE reconstruction,
    VGG perceptual, adversarial and a mean-squared latent ("KL") penalty.

    losses: metric names to track for display (expects "mse", "gan", "vgg", "KL").
    data_len: batches per epoch (used only for the progress bar).
    ema: window for exponential smoothing of the displayed losses.
    isViT: True -> autoencoder consumes channel-first (B, C, H, W) directly;
           False -> tensors are permuted to channel-last around the autoencoder.
    """

    def __init__(self, encoder, decoder, D, vgg, losses, data_len, ema=3, a_disc=1, a_vae=1, a_KL=0.1, isViT=True):
        self.vgg_schedule = None
        self.ema = 2 / (ema + 1)  # EMA smoothing factor for displayed losses
        self.a_disc = a_disc
        self.a_vae = a_vae
        self.a_KL = a_KL

        self.isViT = isViT
        self.encoder = encoder
        self.decoder = decoder
        self.D = D
        self.vgg = vgg
        # NOTE(review): schedulers are stepped once per *batch* with T_max=50,
        # so each cosine cycle completes after 50 batches — confirm intended.
        self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-5)
        self.encoder_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.encoder_optimizer, T_max=50)
        self.decoder_optimizer = torch.optim.Adam(self.decoder.parameters(), lr=1e-5)
        self.decoder_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.decoder_optimizer, T_max=50)
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), lr=4e-5)
        self.D_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.D_optimizer, T_max=50)
        self.losses = losses
        self.loss_vals = {loss: 0 for loss in losses}
        self.data_len = data_len
        self.loss_record = []
        self.epoch = 1
        self.index = 1
        # BUG FIX: fall back to CPU when CUDA is unavailable instead of crashing.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.encoder.to(self.device)
        self.decoder.to(self.device)
        self.D.to(self.device)
        self.vgg.to(self.device)

    def train_step(self, x, with_mse=False, freeze_ae=False, freeze_disc=False):
        """One optimization step: update D on a detached reconstruction
        (unless freeze_disc), then update the autoencoder (unless freeze_ae)."""
        self.index += 1
        x = x.to(self.device)
        # Detached reconstruction for the discriminator update.
        with torch.no_grad():
            x_hat = self.decoder(self.encoder(x.permute(0, 2, 3, 1))).permute(0, 3, 1, 2) if not self.isViT else self.decoder(self.encoder(x))
        if not freeze_disc:
            # Hinge loss: push D(real) above +1 and D(fake) below -1.
            disc_loss = F.relu(1. - self.D(x)).mean() + F.relu(1. + self.D(x_hat)).mean()
            self.D_optimizer.zero_grad()
            disc_loss.backward()
            self.D_optimizer.step()
            self.D_scheduler.step()

        if not freeze_ae:
            z = self.encoder(x.permute(0, 2, 3, 1)) if not self.isViT else self.encoder(x)
            x_hat = self.decoder(z).permute(0, 3, 1, 2) if not self.isViT else self.decoder(z)
            mse = F.mse_loss(x_hat, x)
            KL = 0.5 * (z.mean() ** 2)  # crude latent regularizer, not a true KL
            vgg_real = self.vgg(x)
            vgg_fake = self.vgg(x_hat)
            vgg_loss = 0
            for i in range(len(vgg_real)):
                vgg_loss += F.mse_loss(vgg_real[i], vgg_fake[i])

            adv_loss = 0
            if not freeze_disc:
                # BUG FIX: reuse the grad-enabled x_hat instead of re-running
                # encoder->decoder — the old extra pass doubled the work and
                # ignored the isViT layout (crashing for non-ViT autoencoders).
                adv_loss = -(self.D(x_hat).mean())

            loss = mse * with_mse + self.a_KL * KL + vgg_loss + self.a_vae * adv_loss
            self.encoder_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()
            loss.backward()
            self.encoder_optimizer.step()
            self.decoder_optimizer.step()
            self.encoder_scheduler.step()
            self.decoder_scheduler.step()

        # BUG FIX: report the KL term as a detached float — the old code stored
        # the live z.mean() tensor, mislabeling the metric and retaining the
        # autograd graph across batches.
        self.update_batch({"mse": mse.item() if not freeze_ae else 0,
                           "gan": disc_loss.item() if not freeze_disc else 0,
                           "vgg": vgg_loss.item() if not freeze_ae else 0,
                           "KL": KL.item() if not freeze_ae else 0})

    def update_batch(self, loss_vals):
        """EMA-smooth the latest per-batch losses and redraw the progress line."""
        clear_output(wait=True)
        for record in self.loss_record:
            print(record)
        self.loss_vals = {loss: (1 - self.ema) * self.loss_vals[loss] + self.ema * loss_vals[loss] for loss in self.losses}
        print(f"epoch:{self.epoch} ", end="")
        for loss in self.losses:
            print(f"{loss}: {self.loss_vals[loss]:.3f} ", end="")
        # 20-character "====----" style progress bar.
        for _ in range(int(self.index * 20 / self.data_len)):
            print("=", end="")
        for _ in range(int(self.index * 20 / self.data_len), 20):
            print("-", end="")

    def update_epoch(self):
        """Archive this epoch's smoothed losses and reset the batch counter."""
        self.index = 0
        record = f"epoch:{self.epoch} "
        for loss in self.losses:
            record += f"{loss}: {self.loss_vals[loss]:.3f} "
        self.loss_record.append(record)
        self.epoch += 1
|