Spaces:
Running
Running
| import os | |
| import glob | |
| import numpy as np | |
| import wandb | |
| import copy | |
| import argparse | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import torchvision.transforms as transforms | |
| from torchinfo import summary | |
| from utils import StyleContentDataset, DataStore, denorm_img | |
| from loss import Loss | |
| from model import Model | |
# Hyper-parameters and logging cadence for one training run. The whole dict is
# also passed to wandb.init() as the run configuration.
config = {
    'lr': 1e-4,                # learning rate handed to the Adam optimizer
    'max_iter': 80000,         # number of training iterations
    'logging_interval': 100,   # iterations between loss printouts / wandb metric logs
    'preview_interval': 1000,  # iterations between checkpoints and preview figures
    'batch_size': 4,           # batch size of the training DataStore
    'activations': 'ReLU',     # not read by the code in this file
    'optimizer': 'Adam',       # not read by the code in this file
    'lambda': 7,               # forwarded to Loss(lamb=...)
}

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using {device} device")
def prepare_data(style_dir, content_dir, preview_dir, preview_content_index=8):
    """Build the training and preview data stores.

    Args:
        style_dir: Directory searched (non-recursively) for ``*.jpg`` style images.
        content_dir: Directory searched (non-recursively) for ``*.jpg`` content images.
        preview_dir: Directory containing ``style/*.jpg`` and ``content/*.jpg``
            preview images.
        preview_content_index: Position (in sorted file-name order) of the single
            content image that is paired with every preview style image. Values
            past the end of the list are clamped to the last image.

    Returns:
        ``(datastore, preview_datastore)``. The first yields shuffled training
        batches of ``config['batch_size']``; the second yields the whole preview
        set as one unshuffled batch.

    Raises:
        FileNotFoundError: If ``preview_dir/content`` holds no ``*.jpg`` file.
    """
    norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Training images: resize, then take a random 256-pixel crop.
    train_transform = transforms.Compose([transforms.Resize(512), transforms.RandomCrop(256)])
    style_imgs = glob.glob(os.path.join(style_dir, '*.jpg'))
    content_imgs = glob.glob(os.path.join(content_dir, '*.jpg'))
    train_dataset = StyleContentDataset(style_imgs, content_imgs, transform=train_transform, normalize=norm)
    datastore = DataStore(train_dataset, batch_size=config['batch_size'], shuffle=True)

    # Preview images: deterministic centre crop so previews are comparable
    # from one iteration to the next.
    preview_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256)])
    # glob order depends on the file system; sort so that the same content
    # image and the same style ordering are used on every machine and run.
    preview_style_imgs = sorted(glob.glob(os.path.join(preview_dir, 'style/*.jpg')))
    preview_content_imgs = sorted(glob.glob(os.path.join(preview_dir, 'content/*.jpg')))
    if not preview_content_imgs:
        raise FileNotFoundError(
            f"no preview content images (*.jpg) found in {os.path.join(preview_dir, 'content')}"
        )
    # Clamp instead of failing with a bare IndexError when the directory holds
    # fewer images than the requested index.
    chosen_idx = min(preview_content_index, len(preview_content_imgs) - 1)
    preview_content = preview_content_imgs[chosen_idx]

    # Every preview style is applied to the same content image.
    preview_dataset = StyleContentDataset(
        preview_style_imgs,
        [preview_content] * len(preview_style_imgs),
        transform=preview_transform,
        normalize=norm,
    )
    preview_datastore = DataStore(preview_dataset, batch_size=len(preview_dataset), shuffle=False)
    return datastore, preview_datastore
def preview(model: Model, datastore: DataStore, iteration, save=False, use_wandb=False):
    """Render a style / content / output comparison figure for one preview batch.

    The model is switched to eval mode and run without gradient tracking on a
    batch drawn from ``datastore``. Each sample becomes three adjacent panels
    (style, content, output), two samples per row.

    Args:
        model: Network called as ``model(content, style)``.
        datastore: Source of the preview batch; ``get()`` must return
            ``(style, content)``.
        iteration: Training iteration, used in the output file name and as the
            wandb step.
        save: When true, write the figure to ``outputs/{iteration}_preview.png``.
        use_wandb: When true, log the figure to wandb under the key ``preview``.
    """
    model.eval()
    with torch.no_grad():
        style, content = datastore.get()
        style, content = style.to(device), content.to(device)
        out = model(content, style)

    # Size the grid from the batch instead of assuming at most 16 samples.
    # Keeping 8 rows as the minimum preserves the original 8x6 layout (and
    # figure size) whenever the batch fits in it.
    n_samples = min(len(style), len(content), len(out))
    n_rows = max(8, (n_samples + 1) // 2)
    fig, axs = plt.subplots(n_rows, 6, figsize=(20, 26 * n_rows / 8), squeeze=False)
    axs = axs.flatten()

    try:
        titles = ('style', 'content', 'output')
        for sample_idx, triplet in enumerate(zip(style, content, out)):
            for offset, (img, title) in enumerate(zip(triplet, titles)):
                ax = axs[3 * sample_idx + offset]
                # permute moves the first axis last, the layout imshow expects
                # for colour images (presumably CHW -> HWC here).
                ax.imshow(denorm_img(img.cpu()).permute(1, 2, 0))
                ax.axis('off')
                ax.set_title(title)
        # Hide the panels that received no image so they do not show empty frames.
        for ax in axs[3 * n_samples:]:
            ax.axis('off')

        path = f'outputs/{iteration}_preview.png'
        if save:
            fig.savefig(path)
        if use_wandb:
            # The PNG only exists when it was just saved; otherwise hand the
            # figure itself to wandb rather than pointing at a missing file.
            image = wandb.Image(path) if save else wandb.Image(fig)
            wandb.log({'preview': image}, step=iteration)
    finally:
        # Always release the figure; previews are produced repeatedly during
        # training and unclosed figures accumulate in pyplot's registry.
        plt.close(fig)
def train_one_iter(datastore: DataStore, model: Model, optimizer: torch.optim.Adam, loss_fn: Loss):
    """Perform one optimisation step on a batch drawn from ``datastore``.

    Args:
        datastore: Source of the batch; ``get()`` must return ``(style, content)``.
        model: Network called as ``model(content, style)``; must expose
            ``encoder``, ``activations`` and ``t``.
        optimizer: Optimizer whose gradients are reset and then stepped.
        loss_fn: Loss callable; after the call it exposes ``loss_c`` and ``loss_s``.

    Returns:
        ``(loss, content_loss, style_loss)`` as Python floats.
    """
    model.train()

    style_batch, content_batch = datastore.get()
    style_batch = style_batch.to(device)
    content_batch = content_batch.to(device)

    optimizer.zero_grad()

    # Forward pass through the full model.
    stylized = model(content_batch, style_batch)

    # Snapshot the activations recorded during the forward pass. They are deep
    # copied because the encoder call below presumably repopulates
    # model.activations (it is read again right after that call).
    reference_activations = copy.deepcopy(model.activations)

    # Re-encode the generated image; model.activations is then read again for
    # the output side of the loss.
    stylized_features = model.encoder(stylized)

    total = loss_fn(stylized_features, model.t, model.activations, reference_activations)

    # Backward pass and parameter update.
    total.backward()
    optimizer.step()

    return total.item(), loss_fn.loss_c.item(), loss_fn.loss_s.item()
def train(datastore, preview_datastore, model: Model, optimizer: torch.optim.Adam, use_wandb=False):
    """Run the training loop for ``config['max_iter']`` iterations.

    Every ``config['logging_interval']`` iterations the losses are printed (and
    logged to wandb when enabled). Every ``config['preview_interval']``
    iterations a checkpoint is written to ``outputs/checkpoint.pt`` and a
    preview figure is rendered.

    Args:
        datastore: Training data store passed to ``train_one_iter``.
        preview_datastore: Data store passed to ``preview``.
        model: Model being trained.
        optimizer: Optimizer stepping the trainable parameters.
        use_wandb: Whether to log metrics and previews to wandb.

    Returns:
        Dict with the per-iteration lists ``'loss'``, ``'style_loss'`` and
        ``'content_loss'``.
    """
    train_history = {'style_loss': [], 'content_loss': [], 'loss': []}
    loss_fn = Loss(lamb=config['lambda'])

    for i in range(config['max_iter']):
        loss, content_loss, style_loss = train_one_iter(datastore, model, optimizer, loss_fn)
        train_history['loss'].append(loss)
        train_history['style_loss'].append(style_loss)
        train_history['content_loss'].append(content_loss)

        if i % config['logging_interval'] == 0:
            print(f'iter: {i}')
            print(f'loss: {loss:>5f}, style loss: {style_loss:>5f}, content loss: {content_loss:>5f}')
            print('-------------------------------')
            if use_wandb:
                # Log against the training iteration explicitly. preview() logs
                # with step=iteration; without a matching step here wandb's
                # auto-incremented counter runs ahead of iteration 0 (so the
                # first preview is rejected as a non-monotonic step) and later
                # metrics land on steps unrelated to the real iteration.
                wandb.log({
                    'iter': i, 'loss': loss, 'style_loss': style_loss, 'content_loss': content_loss
                }, step=i)

        if i % config['preview_interval'] == 0:
            torch.save({
                'iter': i, 'model_state': model.state_dict(), 'optimizer_state': optimizer.state_dict()
            }, 'outputs/checkpoint.pt')
            preview(model, preview_datastore, i, save=True, use_wandb=use_wandb)

    return train_history
def main():
    """Command-line entry point: parse arguments, build data and model, train, save.

    ``--content_path``, ``--style_path`` and ``--preview_path`` are all needed;
    if any is missing a message is printed and the function returns.
    ``--wandb`` (an API key) enables wandb logging. ``--model_path`` points at
    either a training checkpoint (a dict with ``model_state``,
    ``optimizer_state`` and ``iter``) or a bare model state dict.

    The trained weights are written to ``outputs/model.pt`` and, when wandb is
    enabled, uploaded as a ``model`` artifact.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--content_path', type=str, help='path to content dataset')
    parser.add_argument('--style_path', type=str, help='path to style dataset')
    parser.add_argument('--preview_path', type=str, help='path to preview dataset')
    parser.add_argument('--wandb', type=str, help='wandb id')
    parser.add_argument('--model_path', type=str, help='path to model')
    args = parser.parse_args()

    use_wandb = bool(args.wandb)
    if use_wandb:
        wandb.login(key=args.wandb)
        wandb.init(project="assignment-3", name="", reinit=True, config=config)

    if not (args.content_path and args.style_path and args.preview_path):
        print('You didnt specify the data path >:(')
        return

    os.makedirs('outputs', exist_ok=True)

    datastore, preview_datastore = prepare_data(args.style_path, args.content_path, args.preview_path)

    model = Model()
    # Move the model to its final device BEFORE building the optimizer and
    # loading its state: Optimizer.load_state_dict casts the restored state to
    # the device the parameters are on at that moment, so doing this later
    # would leave Adam's buffers on the CPU while the parameters sit on the GPU.
    model.to(device)
    optimizer = torch.optim.Adam(model.decoder.parameters(), lr=config['lr'])

    if args.model_path:
        # map_location lets a checkpoint written on a GPU machine be loaded on
        # a CPU-only one (and vice versa).
        checkpoint = torch.load(args.model_path, map_location=device)
        if isinstance(checkpoint, dict) and 'model_state' in checkpoint:
            # Training checkpoint: restore weights and optimizer, and train
            # only for the iterations that are left.
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            # NOTE: train() still counts from 0, so iteration numbers in logs,
            # preview file names and new checkpoints restart after a resume.
            config['max_iter'] -= checkpoint['iter']
        else:
            # Bare state dict, e.g. the final model saved at the end of a run.
            model.load_state_dict(checkpoint)

    train(datastore, preview_datastore, model, optimizer, use_wandb)

    torch.save(model.state_dict(), 'outputs/model.pt')
    if use_wandb:
        artifact = wandb.Artifact('model', type='model')
        artifact.add_file('outputs/model.pt')
        wandb.log_artifact(artifact)
        wandb.finish()


if __name__ == '__main__':
    main()