import os
import time
from datetime import datetime

from tqdm import tqdm
from tensorboardX import SummaryWriter
import torch
import torchinfo
import numpy as np

import options
from validate import validate, calculate_acc
from datasetss import *
from utilss.logger import create_logger
from utilss.earlystop import EarlyStopping
from networks.trainer import Trainer

if __name__ == '__main__':
    train_opt = options.TrainOptions().parse()
    # val_opt = options.TestOptions().parse()

    # Logger
    logger = create_logger(output_dir=train_opt.checkpoints_dir, name="FeatureTransformer")
    logger.info(f"working dir: {train_opt.checkpoints_dir}")

    model = Trainer(train_opt)
    # logger.info(opt.gpu_ids[0])
    logger.info(model.device)
    # extract_feature_model = model.extract_feature_model

    # Dataloaders: inputs are preprocessed with the CLIP model's transform; 80/20 train/val split
    train_loader, val_loader = create_train_val_dataloader(train_opt, clip_model=None,
                                                           transform=model.clip_model.preprocess, k_split=0.8)
    logger.info(f"train {len(train_loader)}")
    logger.info(f"validate {len(val_loader)}")

    # TensorBoard writers for training and validation metrics
    train_writer = SummaryWriter(os.path.join(train_opt.checkpoints_dir, train_opt.name, "train"))
    val_writer = SummaryWriter(os.path.join(train_opt.checkpoints_dir, train_opt.name, "val"))

    early_stopping = EarlyStopping(save_path=train_opt.checkpoints_dir, patience=train_opt.earlystop_epoch,
                                   delta=-0.001, verbose=True)

    start_time = time.time()
    logger.info(torchinfo.summary(model.model, input_size=(train_opt.batch_size, 16, 768), col_width=20,
                                  col_names=['input_size', 'output_size', 'num_params', 'trainable'],
                                  row_settings=['var_names'], verbose=0))
    logger.info("Length of train loader: %d" % (len(train_loader)))
    for epoch in range(train_opt.niter):
        y_true, y_pred = [], []
        pbar = tqdm(train_loader)
        for i, data in enumerate(pbar):
            pbar.set_description(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
            model.total_steps += 1

            model.set_input(data)
            model.optimize_parameters()

            # Collect predictions (sigmoid of logits) and ground-truth labels
            y_pred.extend(model.output.sigmoid().flatten().tolist())
            y_true.extend(data[1].flatten().tolist())

            if model.total_steps % train_opt.loss_freq == 0:
                logger.info("Train loss: {} at step: {}".format(model.loss, model.total_steps))
                train_writer.add_scalar('loss', model.loss, model.total_steps)
                logger.info("Iter time: {}".format((time.time() - start_time) / model.total_steps))

            if model.total_steps in [10, 30, 50, 100, 1000, 5000, 10000] and False:  # save models at these iters (disabled)
                model.save_networks('model_iters_%s.pth' % model.total_steps)

            # logger.info("trained one batch")
            pbar.set_postfix_str(f"loss: {model.loss}, ")
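
        # Per-epoch training accuracy: accuracy on the real class, on the fake class,
        # and overall, at a 0.5 decision threshold.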
        r_acc0, f_acc0, acc0 = calculate_acc(np.array(y_true), np.array(y_pred), 0.5)
        logger.info(f"TrainSet r_acc: {r_acc0}, f_acc: {f_acc0}, acc: {acc0}")

        if epoch % train_opt.save_epoch_freq == 0:
            logger.info('saving the model at the end of epoch %d' % epoch)
            model.save_networks('model_epoch_%s.pth' % epoch)
        # Validation
        model.eval()
        ap, r_acc, f_acc, acc = validate(model, val_loader, logger=logger)
        val_writer.add_scalar('accuracy', acc, model.total_steps)
        val_writer.add_scalar('ap', ap, model.total_steps)
        logger.info("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap))

        # Early stopping on validation accuracy; when triggered, drop the learning rate
        # and reset the early-stopping counter, or stop training entirely.
        early_stopping(acc, model.model)
        if early_stopping.early_stop:
            cont_train = model.adjust_learning_rate()
            if cont_train:
                logger.info("Learning rate dropped by 10, continue training...")
                early_stopping = EarlyStopping(save_path=train_opt.checkpoints_dir, patience=train_opt.earlystop_epoch,
                                               delta=-0.002, verbose=True)
            else:
                logger.info("Early stopping.")
                break

        model.train()