import torch
import torch.nn as nn
from model.backbone import get_visual_backbone
from model.encoder import get_encoder, TransformerEncoder
from model.decoder import get_decoder
from utils.utils import *
import numpy as np


class MLMTransformerPretrain(nn.Module):
    def __init__(self, cfg, src_lang):
        super(MLMTransformerPretrain, self).__init__()
        self.cfg = cfg
        self.transformer_en = TransformerEncoder(cfg.encoder_embedding_size)
        # source-token embedding, optionally initialised from pretrained vectors
        self.text_embedding_src = self.get_text_embedding_src(
            vocab_size=src_lang.n_words,
            embedding_dim=cfg.encoder_embedding_size,
            padding_idx=0,
            pretrain_emb_path=cfg.pretrain_emb_path
        )
        # tag embeddings share the encoder embedding size so they can be summed
        # with the token embeddings in forward()
        self.class_tag_embedding = nn.Embedding(
            len(src_lang.class_tag),
            cfg.encoder_embedding_size,
            padding_idx=0
        )
        self.sect_tag_embedding = nn.Embedding(
            len(src_lang.sect_tag),
            cfg.encoder_embedding_size,
            padding_idx=0
        )

    def forward(self, text_dict):
        '''
        text_dict = {'token', 'sect_tag', 'class_tag', 'len'}
        '''
        # text feature
        token_emb = self.text_embedding_src(text_dict['token'])
        class_tag_emb = self.class_tag_embedding(text_dict['class_tag'])
        sect_tag_emb = self.sect_tag_embedding(text_dict['sect_tag'])
        text_emb_src = token_emb.sum(dim=1) + sect_tag_emb + class_tag_emb
        transformer_outputs = self.transformer_en(text_dict['len'], text_emb_src)
        return transformer_outputs
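
    # Note on forward(): text_dict['len'] is passed to the TransformerEncoder
    # together with the embeddings; presumably it carries per-sample sequence
    # lengths used to build a padding mask (the exact handling lives in
    # model/encoder.py, which is not shown here).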

    def load_model(self, model_path):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pretrain_dict = torch.load(model_path, map_location=device)
        pretrain_dict_model = pretrain_dict['state_dict'] \
            if 'state_dict' in pretrain_dict else pretrain_dict
        model_dict = self.state_dict()
        from collections import OrderedDict
        new_dict = OrderedDict()
        for k, v in pretrain_dict_model.items():
            # checkpoints saved from nn.DataParallel prefix every key with
            # 'module.'; strip it before matching against this model's keys
            if k.startswith("module."):
                k = k[7:]
            if k in model_dict:
                new_dict[k] = v
        model_dict.update(new_dict)
        self.load_state_dict(model_dict)

    def get_text_embedding_src(self, vocab_size, embedding_dim, padding_idx, pretrain_emb_path):
        embedding_src = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        if pretrain_emb_path != '':
            # read pretrained vectors: drop the leading token on each line and
            # keep only the float values
            emb_content = []
            with open(pretrain_emb_path, 'r') as f:
                for line in f:
                    emb_content.append(line.split()[1:])
            vector = np.asarray(emb_content, "float32")
            embedding_src.weight.data[-len(emb_content):].copy_(torch.from_numpy(vector))
        return embedding_src
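
# The pretrained-embedding file read by get_text_embedding_src is assumed to be
# a GloVe-style plain-text file, e.g. (hypothetical contents):
#     triangle 0.12 -0.07 0.31 ...
#     angle    0.05  0.44 -0.20 ...
# Each row must contain exactly embedding_dim values after the token, and the
# parsed vectors are copied into the tail rows of the embedding matrix in file
# order.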


class Network(nn.Module):
    def __init__(self, cfg, src_lang, tgt_lang):
        super(Network, self).__init__()
        self.cfg = cfg
        # define the encoder and decoder
        self.visual_extractor = get_visual_backbone(cfg)
        self.encoder = get_encoder(cfg)
        self.decoder = get_decoder(cfg, tgt_lang)
        # small MLP projecting visual features to the encoder embedding size
        self.visual_emb_unify = nn.Sequential(
            nn.Linear(self.visual_extractor.final_feat_dim, cfg.encoder_embedding_size),
            nn.ReLU(),
            nn.Linear(cfg.encoder_embedding_size, cfg.encoder_embedding_size)
        )
        if cfg.use_MLM_pretrain:
            # reuse the (optionally pretrained) MLM text encoder
            self.mlm_pretrain = MLMTransformerPretrain(cfg, src_lang)
            if cfg.MLM_pretrain_path != '':
                self.mlm_pretrain.load_model(cfg.MLM_pretrain_path)
        else:
            # otherwise build the token/tag embeddings directly on this module
            self.text_embedding_src = self.get_text_embedding_src(
                vocab_size=src_lang.n_words,
                embedding_dim=cfg.encoder_embedding_size,
                padding_idx=0,
                pretrain_emb_path=cfg.pretrain_emb_path
            )
            self.class_tag_embedding = nn.Embedding(
                len(src_lang.class_tag),
                cfg.encoder_embedding_size,
                padding_idx=0
            )
            self.sect_tag_embedding = nn.Embedding(
                len(src_lang.sect_tag),
                cfg.encoder_embedding_size,
                padding_idx=0
            )
        self.src_lang = src_lang

    def forward(self, diagram_src, text_dict, var_dict, exp_dict, is_train=False):
        '''
        diagram_src: B x C x W x H
        text_dict = {'token', 'sect_tag', 'class_tag', 'len'} /
                    {'token', 'sect_tag', 'class_tag', 'subseq_len', 'item_len', 'item_quant'}
        var_dict = {'pos', 'len', 'var_value', 'arg_value'}
        exp_dict = {'exp', 'len', 'answer'}
        '''
        # text feature
        if self.cfg.use_MLM_pretrain:
            text_emb_src = self.mlm_pretrain(text_dict)
        else:
            token_emb = self.text_embedding_src(text_dict['token'])
            class_tag_emb = self.class_tag_embedding(text_dict['class_tag'])
            sect_tag_emb = self.sect_tag_embedding(text_dict['sect_tag'])
            text_emb_src = token_emb.sum(dim=1) + sect_tag_emb + class_tag_emb
        # diagram feature: extract a global visual feature and project it to the
        # encoder embedding size, giving one extra "token" per sample
        diagram_emb_src = self.visual_extractor(diagram_src)
        diagram_emb_src = self.visual_emb_unify(diagram_emb_src).unsqueeze(dim=1)
        # prepend the diagram token to the text sequence; lengths and variable
        # positions shift by one accordingly
        all_emb_src = torch.cat([diagram_emb_src, text_emb_src], dim=1)
        text_dict['len'] += 1
        var_dict['pos'] += 1
        # encoder
        encoder_outputs, encode_hidden = self.encoder(all_emb_src, text_dict['len'])
        problem_output = encode_hidden[-1:, :, :].repeat(self.cfg.decoder_layers, 1, 1)
        # decoder
        outputs = self.decoder(encoder_outputs, problem_output,
                               text_dict['len'],
                               var_dict['pos'], var_dict['len'],
                               exp_dict['exp'],
                               is_train)
        return outputs
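
    # Shape sketch for the concatenation above (standalone note with symbolic
    # sizes, not values from this file):
    #   diagram_emb_src: (B, 1, D) after unsqueeze(dim=1)
    #   text_emb_src:    (B, L, D)
    #   all_emb_src:     (B, L + 1, D) after torch.cat(..., dim=1)
    # which is why text_dict['len'] and var_dict['pos'] are incremented by one.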

    def freeze_module(self, module):
        self.cfg.logger.info("Freezing module " + module.__class__.__name__ + " .......")
        for p in module.parameters():
            p.requires_grad = False

    def load_model(self, model_path):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pretrain_dict = torch.load(model_path, map_location=device)
        pretrain_dict_model = pretrain_dict['state_dict'] \
            if 'state_dict' in pretrain_dict else pretrain_dict
        model_dict = self.state_dict()
        from collections import OrderedDict
        new_dict = OrderedDict()
        for k, v in pretrain_dict_model.items():
            # strip the 'module.' prefix added by nn.DataParallel, if present
            if k.startswith("module."):
                new_dict[k[7:]] = v
            else:
                new_dict[k] = v
        model_dict.update(new_dict)
        self.load_state_dict(model_dict)
        # the full checkpoint dict is returned so callers can also read any
        # extra entries stored alongside the weights (e.g. training metadata)
        return pretrain_dict

    def get_text_embedding_src(self, vocab_size, embedding_dim, padding_idx, pretrain_emb_path):
        embedding_src = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        if pretrain_emb_path != '':
            # same loading scheme as MLMTransformerPretrain.get_text_embedding_src:
            # drop the leading token on every line and copy the vectors into the
            # tail of the embedding matrix
            emb_content = []
            with open(pretrain_emb_path, 'r') as f:
                for line in f:
                    emb_content.append(line.split()[1:])
            vector = np.asarray(emb_content, "float32")
            embedding_src.weight.data[-len(emb_content):].copy_(torch.from_numpy(vector))
        return embedding_src


def get_model(args, src_lang, tgt_lang):
    model = Network(args, src_lang, tgt_lang)
    args.logger.info(str(model))
    return model
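
# Minimal usage sketch (hypothetical; the concrete cfg fields, logger, and the
# src_lang/tgt_lang vocabulary objects come from this project's config and
# dataloader code, not from this file):
#
#   model = get_model(args, src_lang, tgt_lang)
#   outputs = model(diagram_src, text_dict, var_dict, exp_dict, is_train=True)
#
# diagram_src is a B x C x W x H image batch and the three dicts follow the key
# layouts documented in Network.forward.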