Spaces:
Runtime error
Runtime error
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import os | |
| import copy | |
| import numpy as np | |
| import torch | |
| from .ShowTellModel import ShowTellModel | |
| from .FCModel import FCModel | |
| from .AttModel import * | |
| from .TransformerModel import TransformerModel | |
| from .cachedTransformer import TransformerModel as cachedTransformer | |
| from .BertCapModel import BertCapModel | |
| from .M2Transformer import M2TransformerModel | |
| from .AoAModel import AoAModel | |
| def setup(opt): | |
| if opt.caption_model in ['fc', 'show_tell']: | |
| print('Warning: %s model is mostly deprecated; many new features are not supported.' %opt.caption_model) | |
| if opt.caption_model == 'fc': | |
| print('Use newfc instead of fc') | |
| if opt.caption_model == 'fc': | |
| model = FCModel(opt) | |
| elif opt.caption_model == 'language_model': | |
| model = LMModel(opt) | |
| elif opt.caption_model == 'newfc': | |
| model = NewFCModel(opt) | |
| elif opt.caption_model == 'show_tell': | |
| model = ShowTellModel(opt) | |
| # Att2in model in self-critical | |
| elif opt.caption_model == 'att2in': | |
| model = Att2inModel(opt) | |
| # Att2in model with two-layer MLP img embedding and word embedding | |
| elif opt.caption_model == 'att2in2': | |
| model = Att2in2Model(opt) | |
| elif opt.caption_model == 'att2all2': | |
| print('Warning: this is not a correct implementation of the att2all model in the original paper.') | |
| model = Att2all2Model(opt) | |
| # Adaptive Attention model from Knowing when to look | |
| elif opt.caption_model == 'adaatt': | |
| model = AdaAttModel(opt) | |
| # Adaptive Attention with maxout lstm | |
| elif opt.caption_model == 'adaattmo': | |
| model = AdaAttMOModel(opt) | |
| # Top-down attention model | |
| elif opt.caption_model in ['topdown', 'updown']: | |
| model = UpDownModel(opt) | |
| # StackAtt | |
| elif opt.caption_model == 'stackatt': | |
| model = StackAttModel(opt) | |
| # DenseAtt | |
| elif opt.caption_model == 'denseatt': | |
| model = DenseAttModel(opt) | |
| # Transformer | |
| elif opt.caption_model == 'transformer': | |
| if getattr(opt, 'cached_transformer', False): | |
| model = cachedTransformer(opt) | |
| else: | |
| model = TransformerModel(opt) | |
| # AoANet | |
| elif opt.caption_model == 'aoa': | |
| model = AoAModel(opt) | |
| elif opt.caption_model == 'bert': | |
| model = BertCapModel(opt) | |
| elif opt.caption_model == 'm2transformer': | |
| model = M2TransformerModel(opt) | |
| else: | |
| raise Exception("Caption model not supported: {}".format(opt.caption_model)) | |
| return model | |