|
|
|
|
|
from __future__ import absolute_import |
|
|
from __future__ import division |
|
|
from __future__ import print_function |
|
|
|
|
|
import json |
|
|
import time |
|
|
import torch |
|
|
import os |
|
|
import sys |
|
|
import collections |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
import torch.optim as optim |
|
|
from torch.utils.data import DataLoader |
|
|
from os.path import dirname, abspath |
|
|
|
|
|
pdvc_dir = dirname(abspath(__file__)) |
|
|
sys.path.insert(0, pdvc_dir) |
|
|
sys.path.insert(0, os.path.join(pdvc_dir, 'densevid_eval3')) |
|
|
sys.path.insert(0, os.path.join(pdvc_dir, 'densevid_eval3/SODA')) |
|
|
|
|
|
|
|
|
|
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
from eval_utils import evaluate |
|
|
import opts |
|
|
from tensorboardX import SummaryWriter |
|
|
from misc.utils import print_alert_message, build_floder, create_logger, backup_envir, print_opt, set_seed |
|
|
from data.video_dataset import PropSeqDataset, collate_fn |
|
|
from pdvc.pdvc import build |
|
|
from collections import OrderedDict |
|
|
from transformers import AutoTokenizer, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup |
|
|
import copy |
|
|
|
|
|
# Dataset feature locations on the A100 cluster and their index-aligned
# mirrors on the 3090 cluster (paths cover youcook2, Tasty, Tasty/UniVL,
# Anet, howto100m).  map_path() translates an entry of the first list into
# the entry at the same index of the second.
a100_folder = ['/cpfs01/shared/Gvlab-A100/Gvlab-A100_hdd/wuhao/youcook2', '/cpfs01/shared/Gvlab-A100/Gvlab-A100_hdd/wuhao/Tasty/features', '/cpfs01/shared/Gvlab-A100/Gvlab-A100_hdd/huabin/dataset/Tasty/UniVL_feature', '/cpfs01/shared/Gvlab-A100/Gvlab-A100_hdd/huabin/dataset/Anet', '/cpfs01/shared/Gvlab-A100/Gvlab-A100_hdd/wuhao/howto100m/features']

r3090_folder = ['/mnt/data/Gvlab/wuhao/features/yc2', '/mnt/data/Gvlab/wuhao/features/tasty', '/mnt/data/Gvlab/wuhao/features/tasty/univl', '/mnt/data/Gvlab/wuhao/features/anet', '/mnt/data/Gvlab/wuhao/features/howto100m']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_fn(worker_id): |
|
|
np.random.seed(12 + worker_id) |
|
|
|
|
|
def map_path(path):
    """Translate a feature path from the A100 cluster to the 3090 cluster.

    The first entry of ``a100_folder`` that occurs in *path* is replaced by
    the index-aligned entry of ``r3090_folder`` and the mapped path is
    returned.  If no known A100 prefix occurs in *path*, the run aborts
    (``exit(1)``) rather than silently continuing with an unmapped path.

    Note: the original kept a ``copy.deepcopy`` backup of the (immutable)
    string and compared against it after the return, leaving the failure
    check unreachable; the explicit no-match branch below preserves the
    intended behavior while making it reachable.
    """
    for i, folder in enumerate(a100_folder):
        if folder in path:
            # Replace only the first matching cluster prefix and stop.
            return path.replace(folder, r3090_folder[i])
    # No known source folder matched: fail loudly.
    print('map failed')
    exit(1)
|
|
|
|
|
|
|
|
def train(opt):
    """Run the full training/evaluation loop described by *opt*.

    Builds train/val dataloaders, the PDVC model, optimizer and LR
    scheduler; optionally resumes from a checkpoint (``opt.start_from``) or
    loads pre-trained weights (``opt.pretrain``); then alternates training
    epochs with periodic checkpointing + evaluation until ``opt.epoch``
    epochs are reached, finishing with a final evaluation of the best
    checkpoint.  Returns the ``saved_info`` dict, which is also persisted
    to ``info.json`` in the run folder.
    """
    set_seed(opt.seed)
    save_folder = build_floder(opt)
    logger = create_logger(save_folder, 'train.log')
    tf_writer = SummaryWriter(os.path.join(save_folder, 'tf_summary'))

    # Fresh run: snapshot the source environment for reproducibility.
    if not opt.start_from:
        backup_envir(save_folder)
        logger.info('backup evironment completed !')

    saved_info = {'best': {}, 'last': {}, 'history': {}, 'eval_history': {}}

    # Resume: restore the options saved with the chosen checkpoint
    # ('best'/'last' selected via the first 4 chars of start_from_mode),
    # keeping the resume flags themselves, and log any values that differ.
    if opt.start_from:
        opt.pretrain = False
        infos_path = os.path.join(save_folder, 'info.json')
        with open(infos_path) as f:
            logger.info('Load info from {}'.format(infos_path))
            saved_info = json.load(f)
            prev_opt = saved_info[opt.start_from_mode[:4]]['opt']

            exclude_opt = ['start_from', 'start_from_mode', 'pretrain']
            for opt_name in prev_opt.keys():
                if opt_name not in exclude_opt:
                    vars(opt).update({opt_name: prev_opt.get(opt_name)})
                if prev_opt.get(opt_name) != vars(opt).get(opt_name):
                    logger.info('Change opt {} : {} --> {}'.format(opt_name, prev_opt.get(opt_name),
                                                                   vars(opt).get(opt_name)))

    # Two feature folders => a (pretrain-source, target) dataset pair that
    # is iterated in sequence each epoch; otherwise a single target loader.
    if len(opt.visual_feature_folder) == 2:
        train_dataset_pretrain = PropSeqDataset(opt.train_caption_file[0],
                                                [opt.visual_feature_folder[0]],
                                                [opt.text_feature_folder[0]],
                                                opt.dict_file, True, 'gt',
                                                opt)
        train_dataset_target = PropSeqDataset(opt.train_caption_file[1],
                                              [opt.visual_feature_folder[1]],
                                              [opt.text_feature_folder[1]],
                                              opt.dict_file, True, 'gt',
                                              opt)
        train_loader_pretrain = DataLoader(train_dataset_pretrain, batch_size=opt.batch_size,
                                           shuffle=True, num_workers=opt.nthreads, collate_fn=collate_fn, worker_init_fn=_init_fn)
        train_loader_target = DataLoader(train_dataset_target, batch_size=opt.batch_size,
                                         shuffle=True, num_workers=opt.nthreads, collate_fn=collate_fn, worker_init_fn=_init_fn)

        train_dataloaders = [train_loader_pretrain, train_loader_target]

    else:
        train_dataset_target = PropSeqDataset(opt.train_caption_file,
                                              opt.visual_feature_folder,
                                              opt.text_feature_folder,
                                              opt.dict_file, True, 'gt',
                                              opt)
        train_loader_target = DataLoader(train_dataset_target, batch_size=opt.batch_size,
                                         shuffle=True, num_workers=opt.nthreads, collate_fn=collate_fn, worker_init_fn=_init_fn)
        train_dataloaders = [train_loader_target]

    # Default the validation vocabulary to the training one when not given.
    if not hasattr(opt, 'dict_file_val'):
        opt.dict_file_val = opt.dict_file
        opt.vocab_size_val = opt.vocab_size

    # NOTE(review): passes opt.dict_file here rather than opt.dict_file_val
    # set just above — confirm this is intended.
    val_dataset = PropSeqDataset(opt.val_caption_file,
                                 opt.visual_feature_folder_val,
                                 opt.text_feature_folder_val,
                                 opt.dict_file, False, 'gt',
                                 opt)

    val_loader = DataLoader(val_dataset, batch_size=opt.batch_size_for_eval,
                            shuffle=False, num_workers=opt.nthreads, collate_fn=collate_fn, worker_init_fn=_init_fn)

    # Resume counters/history from the selected checkpoint's info; all
    # .get() defaults make a fresh run start from zero.
    epoch = saved_info[opt.start_from_mode[:4]].get('epoch', 0)
    iteration = saved_info[opt.start_from_mode[:4]].get('iter', 0)
    best_val_score = saved_info[opt.start_from_mode[:4]].get('best_val_score', -1e5)
    val_result_history = saved_info['history'].get('val_result_history', {})
    loss_history = saved_info['history'].get('loss_history', {})
    lr_history = saved_info['history'].get('lr_history', {})
    opt.current_lr = vars(opt).get('current_lr', opt.lr)

    # Build model + losses + output post-processing from the options.
    model, criterion, contrastive_criterion, postprocessors = build(opt)
    model.translator = train_dataset_target.translator
    model.train()

    # Resuming: load model weights from the chosen checkpoint.
    if opt.start_from and (not opt.pretrain):
        if opt.start_from_mode == 'best':
            model_pth = torch.load(os.path.join(save_folder, 'model-best.pth'))
        elif opt.start_from_mode == 'last':
            model_pth = torch.load(os.path.join(save_folder, 'model-last.pth'))
        logger.info('Loading pth from {}, iteration:{}'.format(save_folder, iteration))
        model.load_state_dict(model_pth['model'])

    # Fresh run with pre-trained weights: load encoder-only, decoder-only,
    # or the full state dict depending on opt.pretrain.
    if opt.pretrain and (not opt.start_from):
        logger.info('Load pre-trained parameters from {}'.format(opt.pretrain_path))
        model_pth = torch.load(opt.pretrain_path, map_location=torch.device(opt.device))

        if opt.pretrain == 'encoder':
            encoder_filter = model.get_filter_rule_for_encoder()
            encoder_pth = {k:v for k,v in model_pth['model'].items() if encoder_filter(k)}
            # NOTE(review): strict=True with a filtered (partial) state dict
            # will raise on missing keys — confirm the filter covers the
            # whole model, or that strict=False was intended.
            model.load_state_dict(encoder_pth, strict=True)
        elif opt.pretrain == 'decoder':
            encoder_filter = model.get_filter_rule_for_encoder()
            decoder_pth = {k:v for k,v in model_pth['model'].items() if not encoder_filter(k)}
            model.load_state_dict(decoder_pth, strict=True)
            pass
        elif opt.pretrain == 'full':
            model.load_state_dict(model_pth['model'], strict=True)
        else:
            raise ValueError("wrong value of opt.pretrain")

    model.to(opt.device)

    # Single parameter group at the base learning rate.
    other_params = model.parameters()

    training_params = [{'params': other_params, 'lr': opt.lr}]

    if opt.optimizer_type == 'adam':
        optimizer = optim.Adam(training_params, weight_decay=opt.weight_decay)

    elif opt.optimizer_type == 'adamw':
        optimizer = optim.AdamW(training_params, weight_decay=opt.weight_decay)

    # Step-decay milestones every learning_rate_decay_every epochs starting
    # at learning_rate_decay_start, up to opt.epoch.
    milestone = [opt.learning_rate_decay_start + opt.learning_rate_decay_every * _ for _ in range(int((opt.epoch - opt.learning_rate_decay_start) / opt.learning_rate_decay_every))]
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestone, gamma=opt.learning_rate_decay_rate)

    # Resuming: restore optimizer state and fast-forward the scheduler.
    if opt.start_from:
        optimizer.load_state_dict(model_pth['optimizer'])
        lr_scheduler.step(epoch-1)

    print_opt(opt, model, logger)
    print_alert_message('Strat training !', logger)

    # Running loss accumulators, flushed every losses_log_every iterations.
    loss_sum = OrderedDict()
    bad_video_num = 0

    start = time.time()

    weight_dict = criterion.weight_dict
    logger.info('loss type: {}'.format(weight_dict.keys()))
    logger.info('loss weights: {}'.format(weight_dict.values()))

    # ===================== main epoch loop =====================
    while True:
        # NOTE(review): 'if True:' appears to be a vestigial wrapper from an
        # earlier conditional — kept as-is.
        if True:

            # Scheduled sampling: ramp the caption head's sampling
            # probability as epochs progress past scheduled_sampling_start.
            if epoch > opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.basic_ss_prob + opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.caption_head.ss_prob = opt.ss_prob

            print('lr:{}'.format(float(opt.current_lr)))
            pass

        # One pass over each dataloader (pretrain-source first when present).
        for train_loader in train_dataloaders:
            trained_samples = 0  # NOTE(review): assigned but never used.
            for dt in tqdm(train_loader, disable=opt.disable_tqdm):

                if opt.device=='cuda':
                    torch.cuda.synchronize(opt.device)
                # Debug mode: cut each epoch short after ~5 iterations.
                if opt.debug:

                    if (iteration + 1) % 5 == 0:
                        iteration += 1
                        break
                iteration += 1

                optimizer.zero_grad()
                # Move every tensor in the batch (and per-video targets) to
                # the training device.
                dt = {key: _.to(opt.device) if isinstance(_, torch.Tensor) else _ for key, _ in dt.items()}
                dt['video_target'] = [
                    {key: _.to(opt.device) if isinstance(_, torch.Tensor) else _ for key, _ in vid_info.items()} for vid_info in
                    dt['video_target']]

                # Forward pass; total loss is the weighted sum of the
                # criterion's individual loss terms.
                output, loss = model(dt, criterion, contrastive_criterion)
                final_loss = sum(loss[k] * weight_dict[k] for k in loss.keys() if k in weight_dict)

                final_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)

                optimizer.step()

                # Accumulate per-term and total losses for periodic logging.
                for loss_k,loss_v in loss.items():
                    loss_sum[loss_k] = loss_sum.get(loss_k, 0)+ loss_v.item()
                loss_sum['total_loss'] = loss_sum.get('total_loss', 0) + final_loss.item()

                if opt.device=='cuda':
                    torch.cuda.synchronize()

                # NOTE(review): if len(train_loader) < 10 this is 0 and the
                # modulo below raises ZeroDivisionError — confirm loaders
                # always have >= 10 batches.
                losses_log_every = int(len(train_loader) / 10)

                if opt.debug:
                    losses_log_every = 6

                # Periodic logging: average the accumulated losses, write
                # them to the logger + TensorBoard, then reset accumulators.
                if iteration % losses_log_every == 0:
                    end = time.time()
                    for k in loss_sum.keys():
                        loss_sum[k] = np.round(loss_sum[k] /losses_log_every, 3).item()

                    logger.info(
                        "ID {} iter {} (epoch {}), \nloss = {}, \ntime/iter = {:.3f}, bad_vid = {:.3f}"
                            .format(opt.id, iteration, epoch, loss_sum,
                                    (end - start) / losses_log_every, bad_video_num))

                    tf_writer.add_scalar('lr', opt.current_lr, iteration)
                    for loss_type in loss_sum.keys():
                        tf_writer.add_scalar(loss_type, loss_sum[loss_type], iteration)
                    loss_history[iteration] = loss_sum
                    lr_history[iteration] = opt.current_lr
                    loss_sum = OrderedDict()
                    start = time.time()
                    bad_video_num = 0
                    torch.cuda.empty_cache()

        # ============ periodic checkpoint + evaluation ============
        if (epoch % opt.save_checkpoint_every == 0) and (epoch >= opt.min_epoch_when_save):

            saved_pth = {'epoch': epoch,
                         'model': model.state_dict(),
                         'optimizer': optimizer.state_dict()}

            if opt.save_all_checkpoint:
                checkpoint_path = os.path.join(save_folder, 'model_iter_{}.pth'.format(iteration))
            else:
                checkpoint_path = os.path.join(save_folder, 'model-last.pth')

            torch.save(saved_pth, checkpoint_path)

            model.eval()
            result_json_path = os.path.join(save_folder, 'prediction',
                                            'num{}_epoch{}.json'.format(
                                                len(val_dataset), epoch))

            eval_score, _ = evaluate(model, criterion, postprocessors, val_loader, result_json_path, logger=logger, args=opt, alpha=opt.ec_alpha, device=opt.device, debug=opt.debug)
            # Model-selection score: localization F1 when there is no
            # captioner, otherwise a captioning-metric combination chosen by
            # criteria_for_best_ckpt.
            if opt.caption_decoder_type == 'none':
                current_score = 2./(1./eval_score['Precision'] + 1./eval_score['Recall'])
            else:
                if opt.criteria_for_best_ckpt == 'dvc':
                    current_score = np.array(eval_score['METEOR']).mean() + np.array(eval_score['soda_c']).mean()
                else:
                    current_score = np.array(eval_score['para_METEOR']).mean() + np.array(eval_score['para_CIDEr']).mean() + np.array(eval_score['para_Bleu_4']).mean()

            for key in eval_score.keys():
                tf_writer.add_scalar(key, np.array(eval_score[key]).mean(), iteration)

            # Append each list-valued metric's mean to its own list (mutates
            # eval_score in place) so the printed report ends with the mean.
            _ = [item.append(np.array(item).mean()) for item in eval_score.values() if isinstance(item, list)]
            print_info = '\n'.join([key + ":" + str(eval_score[key]) for key in eval_score.keys()])
            logger.info('\nValidation results of iter {}:\n'.format(iteration) + print_info)
            logger.info('\noverall score of iter {}: {}\n'.format(iteration, current_score))
            val_result_history[epoch] = {'eval_score': eval_score}
            logger.info('Save model at iter {} to {}.'.format(iteration, checkpoint_path))

            # New best: record its metadata and save model-best.pth.
            if current_score >= best_val_score:
                best_val_score = current_score
                best_epoch = epoch
                saved_info['best'] = {'opt': vars(opt),
                                      'iter': iteration,
                                      'epoch': best_epoch,
                                      'best_val_score': best_val_score,
                                      'result_json_path': result_json_path,
                                      'avg_proposal_num': eval_score['avg_proposal_number'],
                                      'Precision': eval_score['Precision'],
                                      'Recall': eval_score['Recall']
                                      }

                torch.save(saved_pth, os.path.join(save_folder, 'model-best.pth'))
                logger.info('Save Best-model at iter {} to checkpoint file.'.format(iteration))

            # Always refresh 'last' + history and persist info.json so the
            # run can be resumed.
            saved_info['last'] = {'opt': vars(opt),
                                  'iter': iteration,
                                  'epoch': epoch,
                                  'best_val_score': best_val_score,
                                  }
            saved_info['history'] = {'val_result_history': val_result_history,
                                     'loss_history': loss_history,
                                     'lr_history': lr_history,

                                     }
            with open(os.path.join(save_folder, 'info.json'), 'w') as f:
                json.dump(saved_info, f)
                logger.info('Save info to info.json')

            model.train()

        # End of epoch: advance counters, step the LR schedule, and record
        # the new learning rate.
        epoch += 1
        lr_scheduler.step()
        opt.current_lr = optimizer.param_groups[0]['lr']
        torch.cuda.empty_cache()

        # ============ final evaluation of the best checkpoint ============
        if epoch >= opt.epoch:

            print('====== Conduct the Final Evaluation to test Best Checkpoint ======')
            val_logger = create_logger(save_folder, 'val.log')
            # NOTE(review): map_location is hard-coded to 'cuda' here,
            # unlike opt.device elsewhere — confirm intended.
            loaded_pth = torch.load(os.path.join(save_folder, 'model-best.pth'), map_location='cuda')
            model.load_state_dict(loaded_pth['model'], strict=True)
            model.eval()
            result_json_path = saved_info['best']['result_json_path']
            eval_score, _ = evaluate(model, criterion, postprocessors, val_loader, result_json_path, logger=logger, args=opt, alpha=opt.ec_alpha, device=opt.device, debug=opt.debug)
            # Same scoring rule as the periodic evaluation above.
            if opt.caption_decoder_type == 'none':
                current_score = 2./(1./eval_score['Precision'] + 1./eval_score['Recall'])
            else:
                if opt.criteria_for_best_ckpt == 'dvc':
                    current_score = np.array(eval_score['METEOR']).mean() + np.array(eval_score['soda_c']).mean()
                else:
                    current_score = np.array(eval_score['para_METEOR']).mean() + np.array(eval_score['para_CIDEr']).mean() + np.array(eval_score['para_Bleu_4']).mean()

            _ = [item.append(np.array(item).mean()) for item in eval_score.values() if isinstance(item, list)]
            print_info = '\n'.join([key + ":" + str(eval_score[key]) for key in eval_score.keys()])
            val_logger.info('Best-model is saved at iter {}.\n'.format(saved_info['best']['iter']))
            val_logger.info('\nBest Model Performance:\n' + print_info)
            val_logger.info('\nBest Overall Score {}: {}\n'.format(iteration, current_score))

            tf_writer.close()
            break

    return saved_info
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    opt = opts.parse_opts()
    # Tag this run so its logs/checkpoints are identifiable.
    opt.id = 'seq-train'

    # Fall back to the training feature folders when no dedicated
    # validation folders were supplied.
    if not hasattr(opt, 'visual_feature_folder_val'):
        opt.visual_feature_folder_val = opt.visual_feature_folder
        opt.text_feature_folder_val = opt.text_feature_folder

    # Translate A100-cluster feature paths to their 3090-cluster mirrors;
    # map_path exits the process if a path cannot be mapped.
    if opt.map:
        opt.visual_feature_folder = [map_path(path) for path in opt.visual_feature_folder]
        opt.text_feature_folder = [map_path(path) for path in opt.text_feature_folder]
        opt.visual_feature_folder_val = [map_path(path) for path in opt.visual_feature_folder_val]
        opt.text_feature_folder_val = [map_path(path) for path in opt.text_feature_folder_val]

    # Restrict visible GPUs before any CUDA context is created.
    if opt.gpu_id:
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in opt.gpu_id])
    if opt.disable_cudnn:
        torch.backends.cudnn.enabled = False

    # Tolerate duplicate OpenMP runtimes (common MKL/PyTorch conflict).
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

    train(opt)
|
|
|
|
|
|