# ViT_Autoencoder / train.py
# Uploaded by detectivejoewest ("Upload 7 files"), commit 582b238 (verified).
import kagglehub
import cv2
import os
from IPython.display import clear_output
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import matplotlib.pyplot as plt
from autoencoder import Encoder, Decoder
from trainer import Trainer
from objectives import Discriminator, vgg_builder
# Global Parameters
image_shape = 256  # every image is resized to image_shape x image_shape
emb_dim = 768      # not referenced in this script
patch_size = 16    # not referenced in this script

# Used only to turn the running count into a progress percentage. The original
# formula len(data) / 1670 equals 100 * len(data) / 167_000, i.e. it assumed
# roughly 167k .jpg files in the dataset (approximate; not verified here).
_EXPECTED_IMAGES = 167_000
# Refresh the notebook progress line every N images instead of after every
# image; clearing output per image dominated the loading time.
_PROGRESS_EVERY = 500

image_path = kagglehub.dataset_download("awsaf49/coco-2017-dataset")

# Load every JPEG into memory as a float32 CHW tensor scaled to [-1, 1].
# NOTE(review): each tensor is 3 * 256 * 256 * 4 bytes (~768 KiB), so the full
# list needs on the order of 100+ GB of RAM for a COCO-sized dataset. Consider
# lazy loading if this does not fit.
data = []
for dirpath, _, filenames in os.walk(image_path):
    for filename in filenames:
        if not filename.lower().endswith((".jpg", ".jpeg")):
            continue
        name = os.path.join(dirpath, filename)
        img = cv2.imread(name)
        if img is None:
            # cv2.imread signals unreadable/corrupt files by returning None
            # rather than raising; skip them instead of crashing in resize.
            print(f"Skipping unreadable image: {name}")
            continue
        img = cv2.resize(img, (image_shape, image_shape))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        img = img.astype(np.float32) / 127.5 - 1.0   # [0, 255] -> [-1, 1]
        img = torch.tensor(img).permute(2, 0, 1)     # HWC -> CHW
        data.append(img)
        if len(data) % _PROGRESS_EVERY == 0:
            clear_output(wait=True)
            print(f"{100 * len(data) / _EXPECTED_IMAGES:.2f}%")
print(len(data))
class CustomDataset(Dataset):
    """Map-style dataset over an in-memory sequence of images.

    Items are served through an index permutation fixed at construction time,
    and each item is returned as a float32 tensor.

    Args:
        data: indexable sequence of tensors or array-likes (in this script,
            the CHW float tensors built during loading).
        shuffle: if True (default, the original behavior), permute the index
            order once using NumPy's global RNG. Pass False to serve items in
            their original order, e.g. when the DataLoader already shuffles.
    """

    def __init__(self, data, shuffle=True):
        self.indices = np.arange(len(data))
        if shuffle:
            np.random.shuffle(self.indices)
        self.data = data

    def __len__(self):
        """Return the number of items (same as len(data))."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return item ``idx`` of the (possibly permuted) order as float32."""
        # torch.as_tensor avoids the extra copy and the "copy construct from a
        # tensor" UserWarning that torch.tensor() emits for tensor inputs.
        # The result may share memory with self.data, so do not modify it in place.
        return torch.as_tensor(self.data[self.indices[idx]], dtype=torch.float32)
# Sanity Check :)
# Display one sample from a freshly constructed (index-shuffled) dataset.
# Pixels are stored in [-1, 1]; x / 2 + 0.5 maps them back to [0, 1] for imshow,
# and permute(1, 2, 0) turns CHW into the HWC layout matplotlib expects.
preview = CustomDataset(data)[0]
preview_hwc = preview.permute(1, 2, 0)
plt.imshow(preview_hwc / 2 + 0.5)
# Model construction. Encoder and Decoder (from autoencoder.py) must agree on latent_dim.
latent_dim = 16
encoder = Encoder(latent_dim=latent_dim)
decoder = Decoder(latent_dim=latent_dim)
# Input shape matches the preprocessing: 3 (RGB) channels, image_shape x image_shape.
D = Discriminator((3, image_shape, image_shape))

# Frozen feature network, kept in eval mode; presumably it backs the "vgg" loss
# term passed to Trainer below.
vgg = vgg_builder()
for param in vgg.parameters():
    param.requires_grad = False
vgg.eval()


def _param_mib(module):
    """Return the float32 storage size of ``module``'s parameters in MiB.

    Each float32 parameter takes 4 bytes, and 1 MiB = 2**20 bytes, so this
    equals numel / 262144 (the constant the original prints used).
    """
    return sum(p.numel() for p in module.parameters()) * 4 / 2**20


for label, module in (
    ("encoder", encoder),
    ("decoder", decoder),
    ("Discriminator", D),
    ("VGG", vgg),
):
    print(f"{label}: {_param_mib(module):.3f}MB")
# Batching. CustomDataset already permutes its indices once at construction;
# shuffle=True below additionally reshuffles on every epoch.
# NOTE(review): with num_workers=8 and a very large in-memory Python list, forked
# workers can grow memory through copy-on-write, and spawn-based platforms pickle
# the whole dataset into each worker. Confirm this is acceptable on the target machine.
batch_size = 16
dataset = CustomDataset(data)
_loader_options = {
    "batch_size": batch_size,
    "shuffle": True,
    "num_workers": 8,
    "pin_memory": True,
}
loader = DataLoader(dataset, **_loader_options)
epochs = 5
# `loader` is always defined above, so its batch count is passed directly
# (the previous `"loader" in locals()` fallback could never trigger).
trainer = Trainer(encoder, decoder, D, vgg, ["mse", "gan", "vgg", "KL"], len(loader), isViT=1)

# Run exactly `epochs` epochs; range(1, epochs) previously ran only epochs - 1.
for epoch in range(1, epochs + 1):
    for x in loader:
        trainer.train_step(x, freeze_disc=0, with_mse=1, freeze_ae=0)
    trainer.update_epoch()
    # Checkpoint after every epoch so an interrupted run keeps its progress.
    # The same files are overwritten each time, so after the last epoch they
    # hold the final weights.
    torch.save(encoder.state_dict(), "encoder16.pt")
    torch.save(decoder.state_dict(), "decoder16.pt")
    torch.save(D.state_dict(), "discriminator16.pt")