import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
from utils import TrainSet
from AdaIN import AdaINNet
from tqdm import tqdm


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--content_dir', type=str, required=True, help='content images folder path')
    parser.add_argument('--style_dir', type=str, required=True, help='style images folder path')
    parser.add_argument('--epochs', type=int, default=10, help='number of training epochs')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size')
    parser.add_argument('--resume', type=int, default=0, help='epoch checkpoint to resume training from')
    parser.add_argument('--cuda', action='store_true', help='use CUDA if available')
    args = parser.parse_args()

    # Train on the GPU only when --cuda is passed and CUDA is actually available.
    device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')

    # Output locations for per-epoch checkpoints and final weights.
    check_point_dir = './check_point/'
    weights_dir = './weights/'
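
    # Not shown in the original script: make sure the output directories exist
    # before anything is written to them. This is an assumed convenience step,
    # with the import kept local so the sketch stays self-contained.
    import os
    os.makedirs(check_point_dir, exist_ok=True)
    os.makedirs(weights_dir, exist_ok=True)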

    # Paired content/style dataset and loader.
    train_set = TrainSet(args.content_dir, args.style_dir)
    train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True)

    # Build the AdaIN network from the pretrained, normalized VGG encoder weights;
    # only the decoder is optimized.
    vgg_model = torch.load('vgg_normalized.pth')
    model = AdaINNet(vgg_model).to(device)
    decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-6)

    total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
    losses = []
    iteration = 0

    # Optionally resume: restore the decoder weights, optimizer state, loss history
    # and global iteration count from the saved checkpoint.
    if args.resume > 0:
        states = torch.load(check_point_dir + 'epoch_' + str(args.resume) + '.pth')
        model.decoder.load_state_dict(states['decoder'])
        decoder_optimizer.load_state_dict(states['decoder_optimizer'])
        losses = states['losses']
        iteration = states['iteration']
|
| |
|
| | for epoch in range(args.resume + 1, args.epochs + 1): |
| | print("Begin epoch: %i/%i" % (epoch, int(args.epochs))) |
| | train_tqdm = tqdm(train_loader) |
| | train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss)) |
| | losses.append((iteration, total_loss, content_loss, style_loss)) |
| | total_num = 0 |

        for content_batch, style_batch in train_tqdm:
            decoder_optimizer.zero_grad()

            content_batch = content_batch.to(device)
            style_batch = style_batch.to(device)

            # The network returns the content and style losses; the style term is weighted by 10.
            loss_content, loss_style = model(content_batch, style_batch)
            loss_scaled = loss_content + 10 * loss_style

            loss_scaled.backward()
            decoder_optimizer.step()

            total_loss = loss_scaled.item()
            content_loss = loss_content.item()
            style_loss = loss_style.item()

            train_tqdm.set_description('Loss: %.4f, Content loss: %.4f, Style loss: %.4f' % (total_loss, content_loss, style_loss))
            iteration += 1
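
        # The per-epoch checkpoint save is not shown above; this sketch mirrors the
        # keys restored in the --resume branch ('decoder', 'decoder_optimizer',
        # 'losses', 'iteration') and the 'epoch_<n>.pth' naming it loads from.
        state = {
            'decoder': model.decoder.state_dict(),
            'decoder_optimizer': decoder_optimizer.state_dict(),
            'losses': losses,
            'iteration': iteration,
        }
        torch.save(state, check_point_dir + 'epoch_' + str(epoch) + '.pth')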
        print('Finished epoch: %i/%i' % (epoch, int(args.epochs)))
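
    # Assumed final export: weights_dir is defined above but never used in the shown
    # code, so this save of the trained decoder (the filename is a placeholder) is a
    # sketch rather than the original script's behavior.
    torch.save(model.decoder.state_dict(), weights_dir + 'decoder.pth')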


if __name__ == '__main__':
    main()
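
# Example invocation (the script name and data paths below are placeholders):
#   python train.py --content_dir ./data/content --style_dir ./data/style \
#       --epochs 10 --batch_size 8 --cuda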