AMontiB
Your original commit message (now includes LFS pointer)
9c4b1c4
raw
history blame
2.59 kB
import os
import sys
import time
import torch
import torch.nn
import argparse
from PIL import Image
import numpy as np
from validate import validate
from data import create_dataloader
from networks.trainer import Trainer
from options.train_options import TrainOptions
from options.test_options import TestOptions
from util import Logger
from tqdm import tqdm
import random
def seed_torch(seed=1029):
    """Seed the Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    Args:
        seed: Value handed to every seeding call and written (as a string)
            to the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seed for every generator; the last two cover the current CUDA
    # device and all CUDA devices (multi-GPU) respectively.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN: no autotuning, deterministic mode on, and cuDNN itself disabled.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = False
def get_val_opt():
    """Build the options object used for the validation dataloader.

    Parses ``TrainOptions`` without printing them, then overrides the
    fields below on the parsed result.

    Returns:
        The parsed options with ``isTrain``, ``no_resize`` and ``no_crop``
        set to ``False`` and ``serial_batches`` set to ``True``.
    """
    options = TrainOptions().parse(print_options=False)

    validation_overrides = (
        ('isTrain', False),
        ('no_resize', False),
        ('no_crop', False),
        ('serial_batches', True),
    )
    for field, value in validation_overrides:
        setattr(options, field, value)

    return options
if __name__ == '__main__':
    # Script entry point: parse options, seed the RNGs, build the train/val
    # dataloaders and the Trainer, then loop over epochs with a validation
    # pass and a 'best' checkpoint decision after each one.
    #
    # NOTE(review): this copy of the file had lost its indentation, so the
    # nesting below is reconstructed. The training pass is assumed to sit
    # under `if epoch > 0` (i.e. epoch 0 only validates) -- confirm against
    # the original file.
    opt_train = TrainOptions().parse()
    seed_torch(100)
    print(' '.join(sys.argv))
    opt_val = get_val_opt()

    train_loader = create_dataloader(opt_train, split='train')
    val_loader = create_dataloader(opt_val, split='val')

    model = Trainer(opt_train)
    model.train()
    print(f'cwd: {os.getcwd()}')

    for epoch in range(opt_train.niter):
        if epoch > 0:
            # Bookkeeping values; nothing in this block reads them back.
            epoch_start_time = time.time()
            iter_data_time = time.time()
            epoch_iter = 0

            with tqdm(train_loader, unit='batch', mininterval=0.5) as progress:
                progress.set_description(f'Epoch {epoch}', refresh=False)
                for batch in progress:
                    model.total_steps += 1
                    epoch_iter += opt_train.batch_size

                    model.set_input(batch)
                    model.optimize_parameters()
                    progress.set_postfix(loss=model.loss.item())

        # Learning-rate adjustment every `delr_freq` epochs, never at epoch 0.
        if epoch % opt_train.delr_freq == 0 and epoch != 0:
            print('changing lr at the end of epoch %d, iters %d' % (epoch, model.total_steps))
            model.adjust_learning_rate()

        # Validation: switch to eval mode, keep the first two values returned
        # by validate(), then switch back to train mode.
        model.eval()
        val_metrics = validate(model.model, val_loader)
        acc, ap = val_metrics[:2]
        print("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap))
        model.train()

        # Epoch 0 always writes the 'best' checkpoint; later epochs only when
        # acc is at least model.best_acc. The `or` short-circuits, so
        # model.best_acc is not read at epoch 0.
        # NOTE(review): model.best_acc is never assigned in this block --
        # presumably the Trainer maintains it; verify.
        if epoch == 0 or acc >= model.best_acc:
            model.save_networks('best')