from ast import parse

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

import time
import os
from collections import defaultdict

# import captioning.utils.opts as opts
# import captioning.models as models
# from captioning.data.pth_loader import CaptionDataset
# import captioning.utils.eval_utils as eval_utils
# import captioning.utils.misc as utils
# from captioning.utils.rewards import init_scorer, get_self_critical_reward
# from captioning.modules.loss_wrapper import LossWrapper

from clip_model import CLIPScore
from caption_data import COCORetrievalDataset

import pytorch_lightning as pl

import detectron2.utils.comm as d2comm
from detectron2.utils.env import seed_all_rng

seed_all_rng(1234)
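
# This script fine-tunes a CLIPScore model on COCO caption retrieval with
# PyTorch Lightning. The CLIP vision tower is frozen in LitModel.__init__, so
# only the text side (and, unless opt.joint_out is set, a grammar head) is
# trained. Validation reports the contrastive/grammar losses, text<->image
# recall@{1,5,10}, and grammar classification metrics.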

class LitModel(pl.LightningModule):

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        # `opt` is the argparse Namespace returned by param.parse_args() below;
        # keep a second handle under the name the optimizer code expects.
        self.args = opt

        # Initialize dataset
        # self.dataset = CaptionDataset(opt)
        # self.dataset =
        # opt.vocab_size = self.dataset.vocab_size
        # opt.seq_length = self.dataset.seq_length
        # self.batch_size = opt.batch_size

        # Build model
        # opt.vocab = self.dataset.get_vocab()
        # model = models.setup(opt)
        # print(model)
        # del opt.vocab

        # wrapper with loss in it.
        # lw_model = LossWrapper(model, opt)

        self.model = CLIPScore(use_grammar=opt.use_grammar, joint_out=opt.joint_out)
        # self.lw_model = lw_model

        # Freeze the CLIP vision tower; only the text side is fine-tuned.
        for p in self.model.clip_model.vision_model.parameters():
            p.requires_grad = False
        for p in self.model.clip_model.visual_projection.parameters():
            p.requires_grad = False

        # self.struc_flag = None
        # self.sc_flag = None

    def forward(self, *args, **kwargs):
        """
        Not meant to be called directly: this LightningModule is only a
        training wrapper, not a plain nn.Module. Use the step methods instead.
        """
        raise NotImplementedError

    def train_dataloader(self):
        # train_dataset = torch.utils.data.Subset(
        #     self.dataset,
        #     self.dataset.split_ix['train']
        # )
        # train_loader = torch.utils.data.DataLoader(
        #     dataset=train_dataset,
        #     batch_size=self.batch_size,
        #     shuffle=True,
        #     num_workers=4,
        #     collate_fn=self.dataset.collate_func
        # )

        train_dataset = COCORetrievalDataset(
            split='karpathy_train', mode='train',
            args=self.opt,
            verbose=verbose  # module-level flag set below (False on non-zero ranks)
        )

        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=self.opt.batch_size,
            shuffle=True,
            num_workers=4,
            collate_fn=train_dataset.collate_fn
        )

        return train_loader

    def val_dataloader(self, split='karpathy_val'):
        # val_dataset = torch.utils.data.Subset(
        #     self.dataset,
        #     self.dataset.split_ix[split]
        # )
        # val_loader = torch.utils.data.DataLoader(
        #     val_dataset,
        #     batch_size=self.batch_size,
        #     shuffle=False,
        #     num_workers=4,
        #     drop_last=False,
        #     collate_fn=self.dataset.collate_func
        # )

        val_dataset = COCORetrievalDataset(
            split=split, mode='val',
            args=self.opt,
            verbose=verbose
        )

        val_loader = torch.utils.data.DataLoader(
            dataset=val_dataset,
            batch_size=self.opt.valid_batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            collate_fn=val_dataset.collate_fn
        )

        return val_loader

    def test_dataloader(self):
        return self.val_dataloader('karpathy_test')
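
    # Each batch produced by COCORetrievalDataset.collate_fn is expected to
    # contain at least 'img_feats', 'text', and 'neg_text' (negative captions
    # for the grammar head); these are the only keys the steps below consume.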

    def training_step(self, data, batch_idx):
        batch = data
        self.model.train()

        model_out = self.model.train_step(
            img_feat=batch['img_feats'],
            text=batch['text'],
            neg_text=batch['neg_text'],
        )

        clip_loss = model_out['clip_loss']

        if self.opt.joint_out:
            loss = clip_loss
        else:
            grammar_loss = model_out['grammar_loss']
            loss = clip_loss + grammar_loss

        data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1]
        data_time = torch.tensor(data_time)

        # print('batch_idx', batch_idx)
        # print('loss:', loss)

        # logger_logs = model_out.copy()
        logger_logs = {}

        logger_logs['loss'] = loss.detach()
        logger_logs['clip_loss'] = clip_loss.detach()
        if not self.opt.joint_out:
            logger_logs['grammar_loss'] = grammar_loss.detach()
        logger_logs['data_time'] = data_time.detach()

        # The {log: dict} / {progress_bar: dict} return keywords were deprecated
        # in PL 0.9.1 and removed in 1.0.0; use self.log(...) instead.
        # output = {
        #     'loss': loss,
        #     'log': logger_logs,
        #     'progress_bar': {'data_time': data_time}
        # }

        for k, v in logger_logs.items():
            if k in ['data_time', 'clip_loss', 'grammar_loss']:
                self.log('train/' + k, v, prog_bar=True)
            else:
                self.log('train/' + k, v)

        # print('training step logged')

        return loss
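
    # The data_time metric above reads the Lightning profiler's recorded
    # durations, so it relies on the Trainer being constructed with
    # profiler=True (as done at the bottom of this file).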

    def validation_step(self, data, batch_idx):
        batch = data
        self.model.eval()

        with torch.no_grad():
            model_out = self.model.train_step(
                img_feat=batch['img_feats'],
                text=batch['text'],
                neg_text=batch['neg_text'],
            )

        if self.opt.joint_out:
            clip_loss = model_out['clip_loss']
            loss = clip_loss

            output = {
                # 'val_loss': loss,
                'loss': loss.detach(),
                'clip_loss': clip_loss.detach(),
                # 'grammar_loss': grammar_loss.detach(),
                'img_feat': model_out['img_feat'].detach(),
                'text_feat': model_out['text_feat'].detach(),
                # 'neg_text_feat': model_out['neg_text_feat'].detach(),
                # 'grammar_pos_pred': model_out['grammar_pos_pred'].detach(),
                # 'grammar_neg_pred': model_out['grammar_neg_pred'].detach(),
                # 'predictions': predictions,
                # 'n_predictions': n_predictions,
            }
        else:
            clip_loss = model_out['clip_loss']
            grammar_loss = model_out['grammar_loss']
            loss = clip_loss + grammar_loss

            output = {
                # 'val_loss': loss,
                'loss': loss.detach(),
                'clip_loss': clip_loss.detach(),
                'grammar_loss': grammar_loss.detach(),
                'img_feat': model_out['img_feat'].detach(),
                'text_feat': model_out['text_feat'].detach(),
                # 'neg_text_feat': model_out['neg_text_feat'].detach(),
                'grammar_pos_pred': model_out['grammar_pos_pred'].detach(),
                'grammar_neg_pred': model_out['grammar_neg_pred'].detach(),
                # 'predictions': predictions,
                # 'n_predictions': n_predictions,
            }

        return output

    def test_step(self, *args, **kwargs):
        return self.validation_step(*args, **kwargs)
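
    # validation_epoch_end gathers the per-step outputs from all DDP ranks on
    # the main process, averages the losses, and computes retrieval recall@k
    # plus (when the grammar head is enabled) precision/recall/accuracy/F1 of
    # the grammaticality classifier.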

    def validation_epoch_end(self, outputs, split='val'):
        outputs = d2comm.gather(outputs)
        # master node
        if d2comm.is_main_process():
            assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0

            outputs = sum(outputs, [])

            out = {}

            val_loss_mean = sum([_['loss'].cpu() for _ in outputs]) / len(outputs)
            val_clip_loss_mean = sum([_['clip_loss'].cpu() for _ in outputs]) / len(outputs)
            if not self.opt.joint_out:
                val_grammar_loss_mean = sum([_['grammar_loss'].cpu() for _ in outputs]) / len(outputs)

            print('loss', val_loss_mean.item())
            print('clip_loss', val_clip_loss_mean.item())
            if not self.opt.joint_out:
                print('grammar_loss', val_grammar_loss_mean.item())

            logit_scale = self.model.clip_model.logit_scale.exp().cpu()

            text_feats = torch.cat([_['text_feat'].cpu() for _ in outputs], dim=0)
            img_feats = torch.cat([_['img_feat'].cpu() for _ in outputs], dim=0)

            assert text_feats.size() == (5000, 512), text_feats.size()
            assert img_feats.size() == (5000, 512), img_feats.size()

            logits_per_text = torch.matmul(text_feats, img_feats.t()) * logit_scale
            logits_per_image = logits_per_text.T
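
            # Recall@k below assumes row i of text_feats and row i of img_feats
            # come from the same example, so the correct match for query i is
            # index i (hence the arange labels).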

            # text-to-image retrieval
            print('Text-to-Image retrieval')
            for k in [1, 5, 10]:
                text_to_image_topk = logits_per_text.topk(k, dim=1).indices

                n_text = len(text_to_image_topk)
                labels = torch.arange(0, n_text).view(-1, 1)

                n_retrieved = ((text_to_image_topk == labels).sum(dim=1) > 0).sum()
                recall_k = n_retrieved / n_text * 100

                out[f'text_to_image_recall_{k}'] = recall_k.item()
                print(f'R@{k}: {recall_k.item():.2f}%')

            # image-to-text retrieval
            print('Image-to-Text retrieval')
            for k in [1, 5, 10]:
                image_to_text_topk = logits_per_image.topk(k, dim=1).indices

                n_image = len(image_to_text_topk)
                labels = torch.arange(0, n_image).view(-1, 1)

                n_retrieved = ((image_to_text_topk == labels).sum(dim=1) > 0).sum()
                recall_k = n_retrieved / n_image * 100

                out[f'image_to_text_recall_{k}'] = recall_k.item()
                print(f'R@{k}: {recall_k.item():.2f}%')

            out.update({
                'loss': val_loss_mean.item(),
                'clip_loss': val_clip_loss_mean.item()
            })

            if not self.opt.joint_out:
                # grammar scoring (positive class = grammatical caption;
                # grammar_pos_pred holds predictions on real captions,
                # grammar_neg_pred on corrupted ones)
                grammar_pos_pred = torch.cat([_['grammar_pos_pred'].cpu() for _ in outputs], dim=0)
                grammar_neg_pred = torch.cat([_['grammar_neg_pred'].cpu() for _ in outputs], dim=0)

                TP = (grammar_pos_pred == 1).sum().item()
                FN = (grammar_pos_pred == 0).sum().item()
                FP = (grammar_neg_pred == 1).sum().item()
                TN = (grammar_neg_pred == 0).sum().item()
                print('Grammar check')
                print(f'TP: {TP} FP: {FP} FN: {FN} TN: {TN}')

                precision = TP / (TP + FP) * 100
                recall = TP / (TP + FN) * 100
                accuracy = (TP + TN) / (TP + FP + FN + TN) * 100
                f1 = 2 * precision * recall / (precision + recall)
                print(f'Precision: {precision:.2f}%')
                print(f'Recall: {recall:.2f}%')
                print(f'Accuracy: {accuracy:.2f}%')
                print(f'F1: {f1:.2f}%')
                print('Total: {}'.format(len(grammar_pos_pred)))

                out.update({
                    'grammar_loss': val_grammar_loss_mean,
                    'grammar_precision': precision,
                    'grammar_recall': recall,
                    'grammar_accuracy': accuracy,
                    'grammar_f1': f1,
                })
        else:
            out = {}

        out = d2comm.all_gather(out)[0]  # Only the one from master node
        assert len(out) > 0  # make sure the head has index 0

        # must all be tensors
        out = {k: torch.tensor(v) if not torch.is_tensor(v) else v
               for k, v in out.items()}

        for k, v in out.items():
            self.log(f'{split}/{k}', v)

    def test_epoch_end(self, outputs):
        self.validation_epoch_end(outputs, 'test')

    def configure_optimizers(self):
        # opt = self.opt
        # model = self.model

        # parameters = [p for p in model.parameters() if p.requires_grad]
        # if opt.noamopt:
        #     # assert opt.caption_model in ['transformer', 'bert', 'm2transformer'], 'noamopt can only work with transformer'
        #     optimizer = utils.get_std_opt(
        #         model, optim_func=opt.optim, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        # elif opt.reduce_on_plateau:
        #     # optimizer = utils.build_optimizer(model.parameters(), opt)
        #     optimizer = utils.build_optimizer(parameters, opt)
        #     optimizer = utils.ReduceLROnPlateau(optimizer,
        #                                         factor=opt.reduce_on_plateau_factor,
        #                                         patience=opt.reduce_on_plateau_patience)
        # else:
        #     # optimizer = utils.build_optimizer(model.parameters(), opt)
        #     optimizer = utils.build_optimizer(parameters, opt)

        # from transformers.optimization import AdamW, get_linear_schedule_with_warmup
        # batch_per_epoch = len(self.train_loader)
        # t_total = batch_per_epoch // self.args.gradient_accumulation_steps * self.args.epochs
        # warmup_ratio = self.args.warmup_ratio
        # warmup_iters = int(t_total * warmup_ratio)
        # if self.verbose:
        #     print("Batch per epoch: %d" % batch_per_epoch)
        #     print("Total Iters: %d" % t_total)
        #     print('Warmup ratio:', warmup_ratio)
        #     print("Warm up Iters: %d" % warmup_iters)

        if self.args.optim == 'adamw':
            # Standard AdamW grouping: no weight decay on biases and LayerNorm weights.
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            # Drop frozen parameters (e.g. the CLIP vision tower) from the optimizer.
            for group in optimizer_grouped_parameters:
                group['params'] = [p for p in group['params'] if p.requires_grad]

            from transformers.optimization import AdamW
            optim = AdamW(optimizer_grouped_parameters,
                          lr=self.args.lr, eps=self.args.adam_eps)
            # lr_scheduler = get_linear_schedule_with_warmup(
            #     optim, warmup_iters, t_total)
        else:
            raise NotImplementedError(f"Unsupported optimizer: {self.args.optim}")

        # optimizers = []
        optimizers = [optim]
        lr_schedulers = []

        return optimizers, lr_schedulers
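
    # Returning two lists (optimizers, lr_schedulers) is one of the return
    # formats Lightning accepts from configure_optimizers; the scheduler list
    # is intentionally left empty here.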

    def optimizer_step(self, epoch, batch_idx, optimizer,
                       optimizer_idx, *args, **kwargs):
        # # warm up lr
        # opt = self.opt
        # iteration = self.trainer.global_step
        # if opt.use_warmup and (iteration < opt.noamopt_warmup):
        #     opt.current_lr = opt.learning_rate * \
        #         (iteration + 1) / opt.noamopt_warmup
        #     utils.set_lr(optimizer, opt.current_lr)

        super().optimizer_step(epoch, batch_idx, optimizer,
                               optimizer_idx, *args, **kwargs)

        # print('optimizer step')

    def state_dict(self):
        """
        Save only the inner model's state dict (serializing opt and vocab
        alongside it is currently disabled).
        """
        state_dict = self.model.state_dict()
        device = next(iter(state_dict.values())).device
        assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
        # state_dict.update({
        #     '_vocab': utils.serialize_to_tensor(self.model.vocab).to(device),
        #     '_opt': utils.serialize_to_tensor(self.opt).to(device)
        # })
        return state_dict
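
    # Because state_dict()/load_state_dict() delegate to self.model, the saved
    # checkpoints hold only CLIPScore weights and can also be loaded directly
    # into a bare CLIPScore instance.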

    def load_state_dict(self, state_dict=None, strict=True):
        # if '_vocab' in state_dict:
        #     self.model.vocab = utils.deserialize(state_dict['_vocab'])
        #     del state_dict['_vocab']
        # elif strict:
        #     raise KeyError
        # if '_opt' in state_dict:
        #     saved_model_opt = utils.deserialize(state_dict['_opt'])
        #     del state_dict['_opt']
        #     opt = self.opt
        #     # Make sure the saved opt is compatible with the current opt
        #     need_be_same = ["caption_model",
        #                     "rnn_type", "rnn_size", "num_layers"]
        #     for checkme in need_be_same:
        #         if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
        #                 getattr(opt, checkme) in ['updown', 'topdown']:
        #             continue
        #         assert getattr(saved_model_opt, checkme) == getattr(
        #             opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
        # elif strict:
        #     raise KeyError
        self.model.load_state_dict(state_dict, strict)


class OnEpochStartCallback(pl.Callback):

    def on_epoch_start(self, trainer, pl_module):
        # Update lr/training stage/scheduled sampling prob etc.
        opt = pl_module.opt
        model = pl_module.model
        epoch = trainer.current_epoch
        optimizer = trainer.optimizers[0]

        # if not opt.noamopt and not opt.reduce_on_plateau:
        #     # Assign the learning rate
        #     if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
        #         frac = (
        #             epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
        #         decay_factor = opt.learning_rate_decay_rate ** frac
        #         opt.current_lr = opt.learning_rate * decay_factor
        #     else:
        #         opt.current_lr = opt.learning_rate
        #     utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate

        # # Assign the scheduled sampling prob
        # if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
        #     frac = (
        #         epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
        #     opt.ss_prob = min(opt.scheduled_sampling_increase_prob *
        #                       frac, opt.scheduled_sampling_max_prob)
        #     model.ss_prob = opt.ss_prob

        # # If start self critical training
        # if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
        #     sc_flag = True
        #     init_scorer(opt.cached_tokens)
        # else:
        #     sc_flag = False

        # # If start structure loss training
        # if opt.structure_after != -1 and epoch >= opt.structure_after:
        #     struc_flag = True
        #     init_scorer(opt.cached_tokens)
        # else:
        #     struc_flag = False

        # pl_module.struc_flag = struc_flag
        # pl_module.sc_flag = sc_flag


class ModelCheckpoint(pl.callbacks.ModelCheckpoint):

    def on_keyboard_interrupt(self, trainer, pl_module):
        # Save model when keyboard interrupt
        filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
        self._save_model(filepath)


from param import parse_args

# opt = opts.parse_opt()
args = parse_args()
opt = args

checkpoint_callback = ModelCheckpoint(
    filepath=opt.checkpoint_dir + '{epoch:02d}',
    # dirpath=opt.checkpoint_path,
    save_last=True,
    save_top_k=1,
    verbose=True,
    # monitor='to_monitor',
    # monitor='val/to_monitor',
    # monitor='val/CIDEr',
    monitor='val/loss',
    mode='min',
    # prefix=opt.id+'_',
    prefix=opt.id,
    # filename=f'{opt.id}_',
)
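
# Note: this file targets an older pytorch_lightning API
# (ModelCheckpoint(filepath=..., prefix=...), Trainer(distributed_backend=...,
# checkpoint_callback=...)); newer Lightning releases rename or remove these
# arguments (e.g. dirpath/filename), so pin an older version to run it as-is.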

verbose = True
# import torch
# if torch.cuda.current_device() in [0, -1]:
if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
    verbose = False

# if verbose:
#     print(opt)
#     print("""
#     val_image_use,
#     save_checkpoint_every,
#     save_every_epoch,
#     save_history_ckpt will be ignored.
#     """)

# Lightning defines batch size as batch size per gpu
assert opt.batch_size % torch.cuda.device_count() == 0
opt.batch_size = opt.batch_size // torch.cuda.device_count()
opt.valid_batch_size = opt.valid_batch_size // torch.cuda.device_count()

# If resuming from the last checkpoint
# if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, f'{opt.id}_last.ckpt')):
#     resume_from = os.path.join(opt.start_from, f'{opt.id}_last.ckpt')
if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, f'{opt.id}-last.ckpt')):
    resume_from = os.path.join(opt.start_from, f'{opt.id}-last.ckpt')
    if verbose:
        print('resume from', resume_from)
else:
    resume_from = None

from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(
    # project='CLIP-ViL-COCOCaption',
    project='CLIP-Finetune-COCO',
    name=opt.id,
)

if verbose:
    wandb_logger.experiment.config.update(opt)

    from pathlib import Path
    import glob
    import wandb
    # src_dir = Path(__file__).resolve().parent.parent

    glob_str = "*.py"
    base_path = './'
    wandb.save(glob_str=glob_str, base_path=base_path)

    glob_str = "**/*.yaml"
    base_path = './'
    wandb.save(glob_str=glob_str, base_path=base_path)

    # code = wandb.Artifact('project-source', type='code')
    # for path in glob.glob('**/*.py', recursive=True):
    #     code.add_file(path, name='source/' + path)
    #     print(path)
    # wandb.run.use_artifact(code)

lit = LitModel(opt)
# warning: grad_clip_mode is ignored.

trainer = pl.Trainer(
    callbacks=[
        OnEpochStartCallback(),
        # pl.callbacks.lr_logger.LearningRateLogger()
        pl.callbacks.LearningRateMonitor()
    ],
    default_root_dir=opt.checkpoint_dir,
    resume_from_checkpoint=resume_from,
    distributed_backend='ddp',
    gpus=torch.cuda.device_count(),
    # gpus=1,
    check_val_every_n_epoch=1,
    # max_epochs=opt.max_epochs,
    max_epochs=opt.epochs,
    # gradient_clip_val=opt.grad_clip_value,
    gradient_clip_val=opt.clip_grad_norm,
    checkpoint_callback=checkpoint_callback,
    log_gpu_memory='min_max',
    # log_save_interval=opt.losses_log_every,
    log_every_n_steps=opt.losses_log_every,
    profiler=True,
    # profiler='simple',
    # row_log_interval=10,
    flush_logs_every_n_steps=10,
    num_sanity_val_steps=0,
    # val_check_interval=0.01,
    # limit_train_batches=500,
    # progress_bar_refresh_rate=0,
    # fast_dev_run=True,
    precision=opt.precision,
    logger=wandb_logger
)

if os.getenv('EVALUATE', '0') == '1':
    trainer.test(lit)
else:
    trainer.fit(lit)
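
# Typical usage (flag names are assumptions based on the attributes read above,
# e.g. --id, --checkpoint_dir, --start_from; see param.parse_args for the real list):
#   training:    python <this_script>.py --id clip_ft --checkpoint_dir ckpts/ ...
#   evaluation:  EVALUATE=1 python <this_script>.py --id clip_ft --start_from ckpts/ ...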