|
|
import kagglehub
|
|
|
import cv2
|
|
|
import os
|
|
|
from IPython.display import clear_output
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
import torch.nn as nn
|
|
|
import matplotlib.pyplot as plt
|
|
|
from autoencoder import Encoder, Decoder
|
|
|
from trainer import Trainer
|
|
|
from objectives import Discriminator, vgg_builder
|
|
|
|
|
|
|
|
|
# Side length in pixels; the loading loop resizes every image to
# (image_shape, image_shape).
image_shape = 256
# Not referenced in the visible part of this script -- presumably ViT
# hyper-parameters (embedding width / patch side) used elsewhere; TODO confirm.
emb_dim = 768
patch_size = 16
|
|
|
|
|
|
# Fetch the COCO 2017 dataset through kagglehub and get its local root folder.
image_path = kagglehub.dataset_download("awsaf49/coco-2017-dataset")

# Gather the JPEG paths up front so progress can be reported against the real
# file count instead of a hard-coded estimate.
jpg_paths = []
for dirpath, _, filenames in os.walk(image_path):
    for filename in filenames:
        if filename.lower().endswith(".jpg"):
            jpg_paths.append(os.path.join(dirpath, filename))

# NOTE: every image is kept in RAM as float32, i.e.
# 3 * image_shape * image_shape * 4 bytes per image.
data = []
total = len(jpg_paths)
progress_every = 500  # refreshing the notebook output per image is very slow
for count, name in enumerate(jpg_paths, start=1):
    img = cv2.imread(name)
    # cv2.imread signals an unreadable/corrupt file by returning None rather
    # than raising; skip those instead of crashing inside cv2.resize.
    if img is not None:
        img = cv2.resize(img, (image_shape, image_shape))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Map uint8 [0, 255] to float32 [-1, 1].
        img = img.astype(np.float32) / 127.5 - 1.0
        # HWC -> CHW; as_tensor wraps the array without an extra copy.
        data.append(torch.as_tensor(img).permute(2, 0, 1))
    if count % progress_every == 0 or count == total:
        clear_output(wait=True)
        print(f"{100 * count / total:.2f}%")

print(len(data))
|
|
|
|
|
|
class CustomDataset(Dataset):
    """Map-style dataset over an in-memory sequence of samples.

    A random permutation of ``range(len(data))`` is drawn once at construction
    (via ``np.random.shuffle``) and stored in ``self.indices``; item ``idx``
    is ``data[self.indices[idx]]`` converted to a float32 tensor.
    """

    def __init__(self, data):
        """Store *data* and draw the fixed random access order."""
        self.indices = np.arange(len(data))
        np.random.shuffle(self.indices)
        self.data = data

    def __len__(self):
        """Return the number of samples."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return sample *idx* (in shuffled order) as a float32 tensor.

        ``torch.as_tensor`` is used instead of ``torch.tensor`` because the
        stored samples may already be tensors: ``torch.tensor(tensor)`` emits
        a "copy construct" UserWarning and re-copies the sample on every
        access. The returned tensor may therefore share storage with the
        stored sample.
        """
        sample = self.data[self.indices[idx]]
        return torch.as_tensor(sample, dtype=torch.float32)
|
|
|
|
|
|
|
|
|
# Visual sanity check: take the first item of a freshly shuffled dataset,
# move channels last for matplotlib and rescale with x / 2 + 0.5.
_preview_chw = CustomDataset(data)[0]
plt.imshow(_preview_chw.permute(1, 2, 0) / 2 + 0.5)
|
|
|
|
|
|
# Latent size passed to both halves of the autoencoder; kept in one place so
# encoder and decoder cannot drift apart.
latent_dim = 16
encoder = Encoder(latent_dim=latent_dim)
decoder = Decoder(latent_dim=latent_dim)

# The training samples are 3 x image_shape x image_shape tensors, so derive the
# discriminator's shape argument from image_shape instead of repeating 256.
D = Discriminator((3, image_shape, image_shape))

# VGG is handed to the trainer as a frozen network: no gradients for its
# parameters, and eval mode.
vgg = vgg_builder()
for param in vgg.parameters():
    param.requires_grad = False
vgg.eval()
|
|
|
|
|
|
def _param_megabytes(module):
    """Return the total parameter storage of *module* in MiB.

    Uses each parameter's own ``element_size()`` rather than assuming
    4 bytes per value, so the figure stays correct for non-float32 weights.
    For float32 parameters this equals ``numel / 262144``.
    """
    n_bytes = sum(p.numel() * p.element_size() for p in module.parameters())
    return n_bytes / 2**20


# Report the parameter footprint of every network involved in training.
for label, module in (
    ("encoder", encoder),
    ("decoder", decoder),
    ("Discriminator", D),
    ("VGG", vgg),
):
    print(f"{label}: {_param_megabytes(module):.3f}MB")
|
|
|
|
|
|
batch_size = 16
dataset = CustomDataset(data)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

epochs = 5
# The sixth positional argument is the number of batches per epoch; `loader`
# is always defined at this point, so no existence check is needed.
trainer = Trainer(
    encoder,
    decoder,
    D,
    vgg,
    ["mse", "gan", "vgg", "KL"],
    len(loader),
    isViT=1,
)

# range(1, epochs) would run only epochs - 1 passes; include the last epoch.
for epoch in range(1, epochs + 1):
    for x in loader:
        trainer.train_step(x, freeze_disc=0, with_mse=1, freeze_ae=0)
    # NOTE(review): indentation was lost in the source; update_epoch is assumed
    # to run once per epoch, after all batches -- confirm.
    trainer.update_epoch()
|
|
|
|
|
|
# Write the weights (state dicts only) of the three networks to the current
# working directory.
torch.save(encoder.state_dict(), "encoder16.pt")
torch.save(decoder.state_dict(), "decoder16.pt")
torch.save(D.state_dict(), "discriminator16.pt")