import os
import sys
import json
import argparse
import pathlib
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# https://github.com/PyTorchLightning/pytorch-lightning/issues/11663
import sentencepiece; import pytorch_lightning as pl
import torchmetrics.functional as MF

from load_aokvqa import load_aokvqa


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
    parser.add_argument('--vocab', type=argparse.FileType('r'), required=True)
    parser.add_argument('--log-dir', type=pathlib.Path, dest='log_dir', required=True)
    #
    parser.add_argument('--backbone', type=str, choices=['clip', 'resnet', 'bert'], required=True)
    parser.add_argument('--clip-model-type', type=str,
                        choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'],
                        dest='clip_model_type', required=('clip' in sys.argv))
    parser.add_argument('--train-features', type=pathlib.Path, required=True, dest='train_features')
    parser.add_argument('--val-features', type=pathlib.Path, required=True, dest='val_features')
    parser.add_argument('--vocab-features', type=pathlib.Path, required=('contrastive' in sys.argv), dest='vocab_features')
    #
    parser.add_argument('--objective', type=str, choices=['classifier', 'contrastive'], required=True)
    parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'], required=True)
    # Defaults
    parser.add_argument('--bs', type=int, default=128, dest='batch_size')
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--gpus', type=int, default=1)
    args = parser.parse_args()
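
    # Example invocation (script name and file paths below are illustrative placeholders):
    #   python train.py --aokvqa-dir ./datasets/aokvqa --vocab ./vocab.txt --log-dir ./logs \
    #     --backbone clip --clip-model-type ViT-B/32 \
    #     --train-features ./features/clip_train.pt --val-features ./features/clip_val.pt \
    #     --objective classifier --inputs question image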

    pl.seed_everything(1)
    vocab = args.vocab.read().splitlines()

    ## Data loading

    dm = AokvqaEmbeddingsDataModule(
        args.aokvqa_dir,
        args.train_features,
        args.val_features,
        args.objective,
        args.backbone,
        args.inputs,
        vocab,
        args.vocab_features,
        batch_size=args.batch_size,
        num_workers=16
    )

    ## Model definition

    model = LinearClassifier(
        args.objective,
        args.backbone,
        args.clip_model_type,
        args.inputs,
        len(vocab),
        args.lr
    )

    ## Training and testing loops

    logger = pl.loggers.TensorBoardLogger(
        args.log_dir,
        name=f'{args.backbone}-{args.objective}',
        version=f"inputs:{'+'.join(args.inputs)}"
    )

    trainer = pl.Trainer(
        logger=logger,
        gpus=args.gpus,
        max_epochs=args.epochs,
        callbacks=[
            pl.callbacks.ModelCheckpoint(
                monitor="val_acc",
                filename="{epoch:02d}-{val_acc:.2f}",
                mode="max"
            )
        ],
    )

    trainer.fit(model, dm)


class AokvqaEmbeddingsDataset(Dataset):

    def __init__(self, aokvqa_dir, split, input_features, objective, backbone, inputs, vocab, vocab_features):

        aokvqa_set = load_aokvqa(aokvqa_dir, split)

        assert ( backbone == 'resnet' and inputs == ['image'] and objective == 'classifier' ) \
            or ( backbone == 'bert' and inputs == ['question'] and objective == 'classifier' ) \
            or ( backbone == 'clip' )
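
        # The precomputed feature files are expected to match how they are read below:
        # input_features loads to a dict mapping question_id -> {'question': Tensor, 'image': Tensor}
        # (only the keys named in `inputs` are actually used), and vocab_features loads to a
        # single Tensor of per-answer embeddings indexed by vocab position.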
        embeddings = torch.load(input_features)
        if backbone == 'clip':
            for q in embeddings.keys():
                embeddings[q]['question'] /= embeddings[q]['question'].norm(dim=-1, keepdim=True)
                embeddings[q]['image'] /= embeddings[q]['image'].norm(dim=-1, keepdim=True)
        if objective == 'contrastive':
            vocab_embeddings = torch.load(vocab_features)
            vocab_embeddings /= vocab_embeddings.norm(dim=-1, keepdim=True)

        self.objective = objective
        self.vocab_len = len(vocab)

        self.embeddings = []
        self.answers = []

        for o in aokvqa_set:
            correct_answers = set([o['choices'][o['correct_choice_idx']]] + o['direct_answers'])
            correct_answers = [vocab.index(a) for a in correct_answers if a in vocab]
            if self.objective == 'contrastive':
                correct_answers = [vocab_embeddings[a] for a in correct_answers]
            if len(correct_answers) == 0: continue
            self.answers.append(correct_answers)

            q = o['question_id']
            if 'question' in inputs and 'image' in inputs:
                e = torch.cat((embeddings[q]['question'], embeddings[q]['image']))
            elif 'question' in inputs and 'image' not in inputs:
                e = embeddings[q]['question']
            elif 'question' not in inputs and 'image' in inputs:
                e = embeddings[q]['image']
            self.embeddings.append(e)
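
    # Per-item targets differ by objective: 'classifier' yields a multi-hot vector over the
    # answer vocabulary, while 'contrastive' yields one randomly sampled correct-answer embedding.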
    def __getitem__(self, index):
        e = self.embeddings[index]
        a = self.answers[index]
        if self.objective == 'classifier':
            a = torch.sum(F.one_hot(torch.tensor(a), num_classes=self.vocab_len), dim=0)
        elif self.objective == 'contrastive':
            a = random.sample(a, 1)[0]
        return e, a

    def __len__(self):
        return len(self.embeddings)


class AokvqaEmbeddingsDataModule(pl.LightningDataModule):

    def __init__(self, aokvqa_dir, train_features, val_features, objective, backbone, inputs, vocab, vocab_features, batch_size=1, num_workers=0):
        super().__init__()
        self.aokvqa_dir = aokvqa_dir
        self.train_features = train_features
        self.val_features = val_features
        self.objective = objective
        self.backbone = backbone
        self.inputs = inputs
        self.vocab = vocab
        self.vocab_features = vocab_features
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        self.train_dataset = AokvqaEmbeddingsDataset(
            self.aokvqa_dir, 'train', self.train_features, self.objective,
            self.backbone, self.inputs, self.vocab, self.vocab_features
        )
        self.val_dataset = AokvqaEmbeddingsDataset(
            self.aokvqa_dir, 'val', self.val_features, self.objective,
            self.backbone, self.inputs, self.vocab, self.vocab_features
        )
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True,
            num_workers=int(0.8 * self.num_workers)
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, shuffle=False,
            num_workers=int(0.2 * self.num_workers)
        )


class LinearClassifier(pl.LightningModule):
    def __init__(self, objective, backbone, clip_model_type, inputs, vocab_len, lr=0.001):
        super().__init__()
        self.save_hyperparameters(ignore=['lr'])
        self.lr = lr

        if self.hparams.backbone == 'clip':
            clip_dim = {
                'RN50' : 1024,
                'RN50x4' : 640,
                'RN50x16' : 768,
                'RN50x64' : 1024,
                'RN101' : 512,
                'ViT-B/32' : 512,
                'ViT-B/16' : 512,
                'ViT-L/14' : 768,
                'ViT-L/14@336px' : 768,
            }[clip_model_type]
            emb_dim = clip_dim * len(inputs)
        elif self.hparams.backbone == 'resnet':
            emb_dim = 2048
        elif self.hparams.backbone == 'bert':
            emb_dim = 768

        if self.hparams.objective == 'classifier':
            out_dim = vocab_len
        elif self.hparams.objective == 'contrastive':
            out_dim = clip_dim

        self.linear = nn.Linear(emb_dim, out_dim)
    def forward(self, x):
        x = self.linear(x)
        if self.hparams.objective == 'classifier':
            x = torch.sigmoid(x)
        return x
    def compute_loss(self, batch):
        x, y = batch
        y_pred = self.forward(x)

        if self.hparams.objective == 'classifier':
            # Multi-label objective: y is a multi-hot vector over the answer vocabulary.
            loss = F.binary_cross_entropy(y_pred, y.float())
        elif self.hparams.objective == 'contrastive':
            # In-batch contrastive objective: each projected input should score highest
            # against its own answer embedding among all answer embeddings in the batch.
            indices = torch.arange(0, x.shape[0], dtype=torch.int64, device=self.device)
            sim = (y_pred @ y.T).softmax(dim=-1)
            loss = F.cross_entropy(sim, indices)

        if self.hparams.objective == 'classifier':
            # Note: newer torchmetrics releases (>= 0.11) require a `task` argument here;
            # this call matches the older functional API.
            acc = MF.f1_score(y_pred, y)
        elif self.hparams.objective == 'contrastive':
            acc = torch.mean(sim[indices, indices])

        return loss, acc
    def training_step(self, batch, batch_idx):
        loss, acc = self.compute_loss(batch)
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self.compute_loss(batch)
        self.log("val_loss", loss)
        self.log("val_acc", acc)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


if __name__ == '__main__':
    main()
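
# A minimal sketch of reusing a trained checkpoint for inference (the path is illustrative;
# the real filename comes from the ModelCheckpoint callback configured above):
#
#   model = LinearClassifier.load_from_checkpoint('path/to/epoch=XX-val_acc=0.XX.ckpt')
#   model.eval()
#   with torch.no_grad():
#       scores = model(embedding)  # embedding built the same way as in AokvqaEmbeddingsDataset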