# Caption evaluation utilities: language-metric evaluation (COCO-style) and
# caption sampling helpers for a PyTorch captioning model.
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import json | |
| from json import encoder | |
| import random | |
| import string | |
| import time | |
| import os | |
| import sys | |
| from . import misc as utils | |
# load coco-caption if available
try:
    sys.path.append("coco-caption")
    from pycocotools.coco import COCO
    from pycocoevalcap.eval import COCOEvalCap
except ImportError:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are not
    # swallowed while importing; evaluation degrades gracefully without the
    # coco-caption toolkit.
    print('Warning: coco-caption not available')
# Words that should not terminate a generated caption (articles, prepositions,
# auxiliaries). Note: the original code appended 'the' a second time; the
# duplicate is redundant since only `in` membership is ever tested.
bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']

def count_bad(sen):
    """Return 1 if the sentence ends with a "bad ending" word, else 0.

    Args:
        sen: caption string with words separated by single spaces.

    Returns:
        int: 1 when the final whitespace-separated token is in
        `bad_endings`, 0 otherwise.
    """
    words = sen.split(' ')
    # An empty string splits to [''], which is never in bad_endings -> 0.
    return 1 if words[-1] in bad_endings else 0
def getCOCO(dataset):
    """Load the COCO-API annotation object for a dataset name.

    Args:
        dataset: dataset identifier; any name containing 'coco' uses the
            MSCOCO val2014 captions, names containing 'flickr30k'/'f30k'
            use the Flickr30k eval annotations.

    Returns:
        A pycocotools `COCO` object over the matching annotation file.

    Raises:
        ValueError: if the dataset name matches neither family (the
            original code would die with an opaque UnboundLocalError).
    """
    if 'coco' in dataset:
        annFile = 'coco-caption/annotations/captions_val2014.json'
    elif 'flickr30k' in dataset or 'f30k' in dataset:
        annFile = 'data/f30k_captions4eval.json'
    else:
        raise ValueError('unrecognized dataset for caption eval: %s' % dataset)
    return COCO(annFile)
def language_eval(dataset, preds, preds_n, eval_kwargs, split):
    """Score generated captions with the COCO caption metrics.

    Args:
        dataset: dataset name; routed through `getCOCO` and, for diversity
            stats, used to pick the raw dataset json file.
        preds: list of dicts with at least 'image_id', 'caption',
            'perplexity' and 'entropy' keys (one best caption per image).
        preds_n: list of dicts with 'caption' (multiple samples per image);
            may be empty, which disables the diversity/oracle metrics.
        eval_kwargs: options dict; reads 'id' (required) and 'eval_oracle'.
        split: split name, used only to build cache/output file names.

    Returns:
        dict `out` mapping metric names to scalar scores.

    Side effects: writes cache and result json files under 'eval_results/'
    (the directory is assumed to exist — `eval_split` creates it).
    """
    model_id = eval_kwargs['id']
    eval_oracle = eval_kwargs.get('eval_oracle', 0)

    # create output dictionary
    out = {}

    if len(preds_n) > 0:
        # vocab size and novel sentences
        if 'coco' in dataset:
            dataset_file = 'data/dataset_coco.json'
        elif 'flickr30k' in dataset or 'f30k' in dataset:
            dataset_file = 'data/dataset_flickr30k.json'
        # Set of all training-split reference sentences (joined token lists);
        # `_` is an image record, `__` one of its sentences.
        training_sentences = set([' '.join(__['tokens']) for _ in json.load(open(dataset_file))['images'] if not _['split'] in ['val', 'test'] for __ in _['sentences']])
        generated_sentences = set([_['caption'] for _ in preds_n])
        # Fraction of distinct generated captions never seen in training.
        novels = generated_sentences - training_sentences
        out['novel_sentences'] = float(len(novels)) / len(preds_n)
        tmp = [_.split() for _ in generated_sentences]
        words = []
        for _ in tmp:
            words += _
        out['vocab_size'] = len(set(words))

    # encoder.FLOAT_REPR = lambda o: format(o, '.3f')

    cache_path = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '.json')

    coco = getCOCO(dataset)
    valids = coco.getImgIds()

    # filter results to only those in MSCOCO validation set
    preds_filt = [p for p in preds if p['image_id'] in valids]
    # NOTE(review): raises ZeroDivisionError if no prediction survives the
    # filter — assumes at least one valid image id.
    mean_perplexity = sum([_['perplexity'] for _ in preds_filt]) / len(preds_filt)
    mean_entropy = sum([_['entropy'] for _ in preds_filt]) / len(preds_filt)
    print('using %d/%d predictions' % (len(preds_filt), len(preds)))
    json.dump(preds_filt, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API...

    cocoRes = coco.loadRes(cache_path)
    cocoEval = COCOEvalCap(coco, cocoRes)
    cocoEval.params['image_id'] = cocoRes.getImgIds()
    cocoEval.evaluate()

    # Copy overall scores (BLEU/METEOR/CIDEr/... as produced by COCOEvalCap).
    for metric, score in cocoEval.eval.items():
        out[metric] = score
    # Add mean perplexity
    out['perplexity'] = mean_perplexity
    out['entropy'] = mean_entropy

    imgToEval = cocoEval.imgToEval
    # Break SPICE down into its sub-categories; assumes cocoEval computed a
    # per-image 'SPICE' entry — TODO confirm when SPICE is disabled.
    for k in list(imgToEval.values())[0]['SPICE'].keys():
        if k != 'All':
            out['SPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEval.values()])
            # x==x filters out NaN f-scores before averaging.
            out['SPICE_'+k] = (out['SPICE_'+k][out['SPICE_'+k]==out['SPICE_'+k]]).mean()
    # Attach each predicted caption to its per-image metric record.
    for p in preds_filt:
        image_id, caption = p['image_id'], p['caption']
        imgToEval[image_id]['caption'] = caption

    if len(preds_n) > 0:
        # Multi-sample metrics (diversity, oracle, self-CIDEr).
        from . import eval_multi
        cache_path_n = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '_n.json')
        allspice = eval_multi.eval_allspice(dataset, preds_n, model_id, split)
        out.update(allspice['overall'])
        div_stats = eval_multi.eval_div_stats(dataset, preds_n, model_id, split)
        out.update(div_stats['overall'])
        if eval_oracle:
            oracle = eval_multi.eval_oracle(dataset, preds_n, model_id, split)
            out.update(oracle['overall'])
        else:
            oracle = None
        self_cider = eval_multi.eval_self_cider(dataset, preds_n, model_id, split)
        out.update(self_cider['overall'])
        with open(cache_path_n, 'w') as outfile:
            json.dump({'allspice': allspice, 'div_stats': div_stats, 'oracle': oracle, 'self_cider': self_cider}, outfile)

    # Fraction of captions ending on an article/preposition (see count_bad).
    out['bad_count_rate'] = sum([count_bad(_['caption']) for _ in preds_filt]) / float(len(preds_filt))
    outfile_path = os.path.join('eval_results/', model_id + '_' + split + '.json')
    with open(outfile_path, 'w') as outfile:
        json.dump({'overall': out, 'imgToEval': imgToEval}, outfile)

    return out
def eval_split(model, crit, loader, eval_kwargs={}):
    """Evaluate a captioning model on one split of the data loader.

    Runs the model over the split, accumulating language-model loss (when
    labels are available) and greedy/beam-sampled captions, optionally
    followed by full language-metric evaluation.

    Args:
        model: captioning model; called in 'sample' mode and for loss.
        crit: loss criterion applied to (logits, labels, masks).
        loader: data loader exposing `reset_iterator` and `get_batch`.
        eval_kwargs: options dict (verbosity, split, beam_size, sample_n,
            num_images, language_eval, device, ...). NOTE(review): mutable
            default argument — safe here only because it is never mutated.

    Returns:
        (mean_loss, predictions, lang_stats) where lang_stats is None
        unless `language_eval == 1`.
    """
    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 0)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)
    num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 0)
    dataset = eval_kwargs.get('dataset', 'coco')
    beam_size = eval_kwargs.get('beam_size', 1)
    sample_n = eval_kwargs.get('sample_n', 1)
    remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
    os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # Use this nasty way to make other code clean since it's a global configuration
    device = eval_kwargs.get('device', 'cuda')

    # Make sure in the evaluation mode
    model.eval()

    loader.reset_iterator(split)

    n = 0
    loss = 0
    loss_sum = 0
    loss_evals = 1e-8
    predictions = []
    n_predictions = [] # when sample_n > 1
    while True:
        data = loader.get_batch(split)
        n = n + len(data['infos'])

        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
        # Move tensors to the target device; entries may be None (e.g. no labels).
        tmp = [_.to(device) if _ is not None else _ for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        if labels is not None and verbose_loss:
            # forward the model to get loss
            with torch.no_grad():
                # Teacher-forced loss: input tokens labels[...,:-1], targets shifted by one.
                loss = crit(model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]).item()
            loss_sum = loss_sum + loss
            loss_evals = loss_evals + 1

        # forward the model to also get generated samples for each image
        with torch.no_grad():
            tmp_eval_kwargs = eval_kwargs.copy()
            tmp_eval_kwargs.update({'sample_n': 1})
            seq, seq_logprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
            seq = seq.data
            # Per-sequence stats, length-normalized over non-pad (>0) tokens
            # (+1 in the denominator counts the EOS step / avoids div-by-zero).
            entropy = - (F.softmax(seq_logprobs, dim=2) * seq_logprobs).sum(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)
            perplexity = - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)

        # Print beam search
        if beam_size > 1 and verbose_beam:
            for i in range(fc_feats.shape[0]):
                print('\n'.join([utils.decode_sequence(model.vocab, _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
                print('--' * 10)
        sents = utils.decode_sequence(model.vocab, seq)

        for k, sent in enumerate(sents):
            entry = {'image_id': data['infos'][k]['id'], 'caption': sent, 'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()}
            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to vis/ folder
                cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross
                print(cmd)
                os.system(cmd)

            if verbose:
                print('image %s: %s' %(entry['image_id'], entry['caption']))

        if sample_n > 1:
            # Collect additional samples per image into n_predictions.
            eval_split_n(model, n_predictions, [fc_feats, att_feats, att_masks, data], eval_kwargs)

        # ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']
        if num_images != -1:
            ix1 = min(ix1, num_images)
        else:
            num_images = ix1
        # Drop predictions past the evaluation budget (the loader may wrap
        # around, so the last batch can overshoot ix1).
        for i in range(n - ix1):
            predictions.pop()

        if verbose:
            print('evaluating validation preformance... %d/%d (%f)' %(n, ix1, loss))

        if num_images >= 0 and n >= num_images:
            break

    lang_stats = None
    if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
        n_predictions = sorted(n_predictions, key=lambda x: x['perplexity'])
    if not os.path.isdir('eval_results'):
        os.mkdir('eval_results')
    torch.save((predictions, n_predictions), os.path.join('eval_results/', '.saved_pred_'+ eval_kwargs['id'] + '_' + split + '.pth'))
    if lang_eval == 1:
        lang_stats = language_eval(dataset, predictions, n_predictions, eval_kwargs, split)

    # Switch back to training mode
    model.train()
    return loss_sum/loss_evals, predictions, lang_stats
# Only run when sample_n > 0
def eval_split_n(model, n_predictions, input_data, eval_kwargs={}):
    """Draw `sample_n` diverse captions per image, appending to n_predictions.

    Dispatches on `sample_n_method`: 'bs' (beam search, one caption per
    beam), sampling methods ('sample'/'gumbel'/'top*'), 'dbs' (diverse
    beam search), or group-based sampling otherwise.

    Args:
        model: captioning model; exposes 'sample' mode and, for beam
            methods, `done_beams` (project type — not constructed here).
        n_predictions: output list, extended in place with dicts of
            {'image_id', 'caption'[, 'perplexity']}.
        input_data: [fc_feats, att_feats, att_masks, data] as assembled
            by `eval_split`.
        eval_kwargs: options; read-only (copied before mutation).
    """
    verbose = eval_kwargs.get('verbose', True)
    beam_size = eval_kwargs.get('beam_size', 1)
    sample_n = eval_kwargs.get('sample_n', 1)
    sample_n_method = eval_kwargs.get('sample_n_method', 'sample')

    fc_feats, att_feats, att_masks, data = input_data

    tmp_eval_kwargs = eval_kwargs.copy()
    if sample_n_method == 'bs':
        # case 1 sample_n == beam size
        tmp_eval_kwargs.update({'sample_n': 1, 'beam_size': sample_n, 'group_size': 1}) # randomness from softmax
        with torch.no_grad():
            model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
        for k in range(fc_feats.shape[0]):
            _sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(sample_n)]))
            for sent in _sents:
                entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
                n_predictions.append(entry)
    # case 2 sample / gumbel / topk sampling/ nucleus sampling
    elif sample_n_method == 'sample' or \
            sample_n_method == 'gumbel' or \
            sample_n_method.startswith('top'):
        tmp_eval_kwargs.update({'sample_n': sample_n, 'sample_method': sample_n_method, 'beam_size': 1}) # randomness from sample
        with torch.no_grad():
            _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
        _sents = utils.decode_sequence(model.vocab, _seq)
        # Length-normalized negative log-prob over non-pad tokens (+1 for EOS).
        _perplexity = - _sampleLogprobs.gather(2, _seq.unsqueeze(2)).squeeze(2).sum(1) / ((_seq>0).to(_sampleLogprobs).sum(1)+1)
        for k, sent in enumerate(_sents):
            # Samples are laid out image-major: k // sample_n is the image index.
            entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent, 'perplexity': _perplexity[k].item()}
            n_predictions.append(entry)
    elif sample_n_method == 'dbs':
        # Use diverse beam search
        tmp_eval_kwargs.update({'beam_size': sample_n * beam_size, 'group_size': sample_n}) # randomness from softmax
        with torch.no_grad():
            model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
        # BUGFIX: original looped over `loader.batch_size`, but no `loader`
        # exists in this scope (NameError at runtime). Use the true batch
        # size, mirroring the 'bs' branch above.
        for k in range(fc_feats.shape[0]):
            # Take the first beam of each diversity group.
            _sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(0, sample_n*beam_size, beam_size)]))
            for sent in _sents:
                entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
                n_predictions.append(entry)
    else:
        tmp_eval_kwargs.update({'sample_method': sample_n_method[1:], 'group_size': sample_n, 'beam_size':1}) # randomness from softmax
        with torch.no_grad():
            _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
        _sents = utils.decode_sequence(model.vocab, _seq)
        for k, sent in enumerate(_sents):
            entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent}
            n_predictions.append(entry)
    if verbose:
        # Print only this batch's freshly appended samples, grouped by image.
        for entry in sorted(n_predictions[-fc_feats.shape[0] * sample_n:], key=lambda x: x['image_id']):
            print('image %s: %s' %(entry['image_id'], entry['caption']))