| import argparse | |
| import random | |
| import os | |
| from datetime import datetime | |
| import numpy as np | |
| import torch | |
| import json | |
| from torch.optim import AdamW | |
| from torchvision.transforms import v2 | |
| from torch.utils.tensorboard import SummaryWriter | |
| from tqdm import tqdm | |
| from transformers import get_constant_schedule_with_warmup | |
| from Datasets import CCDataset, Batcher | |
| from model import ICCModel | |
| from utils import get_vocabulary | |
| from Loss import InfoNCELoss | |
| from eval import captioning, retrieve, plot | |
| from huggingface_hub import hf_hub_download | |
| import open_clip | |
def train(args, model, train_loader, valid_loader, device, infonce, optim, scheduler, writer):
    """Train the model with a combined captioning + InfoNCE objective.

    After every epoch a checkpoint ``model_<step>.pt`` is written and the
    validation loss is computed; at the end, the checkpoint with the lowest
    validation loss is copied to ``model_best.pt``.

    Args:
        args: parsed CLI namespace (uses epochs, lamb, max_grad_norm, output_path).
        model: ICCModel returning (cap_loss, vis_emb, text_emb, ...) — project type.
        train_loader / valid_loader: Batcher instances yielding dict batches.
        device: torch device all batch tensors are moved to.
        infonce: InfoNCELoss callable returning (loss, num_positives).
        optim: optimizer over the trainable parameters.
        scheduler: LR scheduler stepped once per optimizer step.
        writer: tensorboard SummaryWriter for scalar logging.
    """
    step = 0
    best_score = float("inf")
    best_model = None  # step index of the best checkpoint seen so far
    for epoch in range(args.epochs):
        model.train()
        for batch in tqdm(train_loader, desc='Epoch ' + str(epoch)):
            imgs1 = batch['images_before'].to(device)
            imgs2 = batch['images_after'].to(device)
            toks = batch['input_ids'].to(device)
            labs = batch['labels'].to(device)
            flags = batch['flags'].to(device)
            attention_mask = batch['pad_mask'].to(device)
            embs = batch['embs'].to(device)
            cap_loss, vis_emb, text_emb, _, _, _ = model(imgs1, imgs2, toks, labs, attention_mask)
            con_loss, _ = infonce(vis_emb, text_emb, flags, embs)
            loss = cap_loss + args.lamb * con_loss
            loss.backward()
            if args.max_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            # Total gradient norm (after optional clipping), logged for diagnostics.
            grad = torch.norm(torch.stack(
                [torch.norm(p.grad.detach()).to(device) for p in model.parameters() if p.grad is not None]))
            optim.step()
            scheduler.step()
            optim.zero_grad()
            writer.add_scalar('train_loss', loss.item(), step)
            writer.add_scalar('grad', grad, step)
            writer.add_scalar('lr', scheduler.get_last_lr()[0], step)
            step += 1
        # os.path.join (not string concatenation) so a missing trailing slash
        # in output_path cannot silently mangle the checkpoint filename.
        torch.save(model.state_dict(), os.path.join(args.output_path, 'model_{}.pt'.format(step)))
        model.eval()
        with torch.no_grad():
            # Collect per-batch losses in a list and reduce once, instead of
            # re-concatenating a growing tensor each iteration (O(n^2)).
            eval_losses = []
            for batch in tqdm(valid_loader, desc='Validation ' + str(epoch)):
                imgs1 = batch['images_before'].to(device)
                imgs2 = batch['images_after'].to(device)
                toks = batch['input_ids'].to(device)
                labs = batch['labels'].to(device)
                flags = batch['flags'].to(device)
                attention_mask = batch['pad_mask'].to(device)
                embs = batch['embs'].to(device)
                cap_loss, vis_emb, text_emb, _, _, _ = model(imgs1, imgs2, toks, labs, attention_mask)
                con_loss, _ = infonce(vis_emb, text_emb, flags, embs)
                loss = cap_loss + args.lamb * con_loss
                eval_losses.append(loss.cpu())
            # Guard against an empty validation set (mean of empty tensor is NaN).
            eval_score = torch.stack(eval_losses).mean().item() if eval_losses else float("inf")
        writer.add_scalar('eval_score', eval_score, step)
        if eval_score < best_score:
            best_score = eval_score
            best_model = step
    if best_model is not None:
        state_dict = torch.load(os.path.join(args.output_path, 'model_{}.pt'.format(best_model)),
                                map_location=device)
        torch.save(state_dict, os.path.join(args.output_path, 'model_best.pt'))
def run(args, config):
    """Set up seeds, data, model, and optimizer, then train and run final evaluation.

    Args:
        args: parsed CLI namespace (paths, batch size, LR, seed, ...).
        config: dict loaded from the JSON config file (architecture settings).
    """
    print('Initializing...')
    # Seed every RNG in use so runs are reproducible; cudnn.deterministic below
    # shows that reproducibility is the intent, so the CUDA generators must be
    # seeded too (no-op on CPU-only machines).
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    dt_str = datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
    # os.path.join keeps the log directory correct even without a trailing
    # slash on output_path (the rest of the file already joins paths this way).
    writer = SummaryWriter(os.path.join(args.output_path, dt_str))
    # Reuse a cached vocabulary when present; otherwise build it from the annotations.
    if os.path.exists(args.vocab):
        with open(args.vocab, 'r') as infile:
            vocab = json.load(infile)
    else:
        vocab = get_vocabulary(args.annotation_json, args.vocab)
    clip = None
    preprocess = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Non-ResNet backbones load pretrained RemoteCLIP weights and use the
    # CLIP-specific preprocessing pipeline instead of the default one above.
    if 'resnet' not in config['backbone']:
        checkpoint_path = hf_hub_download("chendelong/RemoteCLIP",
                                          f"RemoteCLIP-{config['backbone']}.pt",
                                          cache_dir=args.pretrained)
        clip, _, preprocess = open_clip.create_model_and_transforms(config['backbone'])
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        clip.load_state_dict(ckpt)
    model = ICCModel(device, clip, config['backbone'], config['d_model'],
                     len(vocab), config['max_len'], config['num_heads'], config['h_dim'], config['a_dim'],
                     config['encoder_layers'], config['decoder_layers'], config['dropout'],
                     learnable=config['learnable'], fine_tune=config['fine_tune'],
                     tie_embeddings=config['tie_embeddings'], prenorm=config['prenorm'])
    model = model.to(device)
    del clip  # drop the local handle; ICCModel received it in its constructor
    print('Loading...')
    training_set = CCDataset(args.annotation_json, args.image_dir, vocab, preprocess, 'train', config['max_len'],
                             config['s-transformers'], device)
    valid_set = CCDataset(args.annotation_json, args.image_dir, vocab, preprocess, 'val', config['max_len'],
                          config['s-transformers'], device)
    test_set = CCDataset(args.annotation_json, args.image_dir, vocab, preprocess, 'test', config['max_len'],
                         config['s-transformers'], device)
    train_loader = Batcher(training_set, args.batch_size, config['max_len'], device, args.hd, model=model,
                           shuffle=True)
    valid_loader = Batcher(valid_set, args.batch_size, config['max_len'], device)
    test_loader = Batcher(test_set, 1, config['max_len'], device)
    print('Training...')
    infonce = InfoNCELoss(device, k=args.k, temperature=args.temperature, threshold=config['s-threshold'],
                          fna=config['fna'])
    optim = AdamW([x for x in model.parameters() if x.requires_grad], lr=args.learning_rate, eps=args.adam_epsilon)
    # warmup_steps is a fraction of the total number of optimizer steps; the
    # scheduler API documents an integer step count, so cast explicitly.
    warmup = int(args.warmup_steps * len(train_loader) * args.epochs)
    scheduler = get_constant_schedule_with_warmup(optim, num_warmup_steps=warmup)
    train(args, model, train_loader, valid_loader, device, infonce, optim, scheduler, writer)
    print('Final evaluation...')
    # Evaluate the checkpoint that scored best on validation, not the last one.
    model.load_state_dict(torch.load(os.path.join(args.output_path, 'model_best.pt'), map_location=device))
    results = captioning(args, config, model, test_loader, vocab, device)
    retrieve(args, config, model, test_loader, device)
    plot(args, model.encoder.encoder.feat_size, results)
def main():
    """CLI entry point: parse options, load the JSON experiment config, start a run."""
    cli = argparse.ArgumentParser()
    cli.add_argument('--annotation_json', type=str, default='../input/Levir_CC/LevirCCcaptions.json')
    cli.add_argument('--image_dir', type=str, default='../input/Levir_CC/images/')
    cli.add_argument('--vocab', type=str, default='../input/levir_vocab.json')
    cli.add_argument('--pretrained', type=str, default='../../input/checkpoints')
    cli.add_argument('--config', type=str, default='../config.json')
    cli.add_argument('--output_path', type=str, default='../output/')
    cli.add_argument('--epochs', type=int, default=50)
    cli.add_argument('--batch_size', type=int, default=4)
    cli.add_argument('--k', type=int, default=-1)
    cli.add_argument('--hd', type=int, default=-1)
    cli.add_argument('--learning_rate', type=float, default=1e-4)
    cli.add_argument('--warmup_steps', type=float, default=0.025)
    cli.add_argument('--lr_decay', type=float, default=0.7)
    cli.add_argument('--adam_epsilon', type=float, default=1e-8)
    cli.add_argument('--max_grad_norm', type=float, default=None)
    cli.add_argument('--temperature', type=float, default=0.01)
    cli.add_argument('--lamb', type=float, default=0.5)
    cli.add_argument('--seed', type=int, default=42)
    opts = cli.parse_args()
    with open(opts.config, 'r') as fh:
        cfg = json.load(fh)
    run(opts, cfg)


if __name__ == '__main__':
    main()