# pgps-demo/core/network.py
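"""Network definitions for the PGPS demo.

Defines:
  * MLMTransformerPretrain -- a text-only transformer encoder over token,
    section-tag, and class-tag embeddings, used for MLM pre-training.
  * Network -- the full model: a visual backbone for the diagram, an encoder
    over the fused diagram/text features, and a decoder that produces the
    expression sequence.
  * get_model -- factory that builds a Network from the config and vocabularies.
"""
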
import torch
import torch.nn as nn
from collections import OrderedDict
from model.backbone import get_visual_backbone
from model.encoder import get_encoder, TransformerEncoder
from model.decoder import get_decoder
from utils.utils import *
import numpy as np


class MLMTransformerPretrain(nn.Module):
    """Text-only transformer encoder used for MLM pre-training."""

    def __init__(self, cfg, src_lang):
        super(MLMTransformerPretrain, self).__init__()
        self.cfg = cfg
        self.transformer_en = TransformerEncoder(cfg.encoder_embedding_size)
        # token embedding, optionally initialised from a pretrained embedding file
        self.text_embedding_src = self.get_text_embedding_src(
            vocab_size=src_lang.n_words,
            embedding_dim=cfg.encoder_embedding_size,
            padding_idx=0,
            pretrain_emb_path=cfg.pretrain_emb_path
        )
        # structural tag embeddings share the encoder embedding size
        self.class_tag_embedding = nn.Embedding(
            len(src_lang.class_tag),
            cfg.encoder_embedding_size,
            padding_idx=0
        )
        self.sect_tag_embedding = nn.Embedding(
            len(src_lang.sect_tag),
            cfg.encoder_embedding_size,
            padding_idx=0
        )

    def forward(self, text_dict):
        '''
        text_dict = {'token', 'sect_tag', 'class_tag', 'len'}
        '''
        # text feature
        token_emb = self.text_embedding_src(text_dict['token'])
        class_tag_emb = self.class_tag_embedding(text_dict['class_tag'])
        sect_tag_emb = self.sect_tag_embedding(text_dict['sect_tag'])
        text_emb_src = token_emb.sum(dim=1) + sect_tag_emb + class_tag_emb
        transformer_outputs = self.transformer_en(text_dict['len'], text_emb_src)
        return transformer_outputs

    def load_model(self, model_path):
        # map the checkpoint to whichever device is available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pretrain_dict = torch.load(model_path, map_location=device)
        pretrain_dict_model = pretrain_dict['state_dict'] \
            if 'state_dict' in pretrain_dict else pretrain_dict
        model_dict = self.state_dict()
        new_dict = OrderedDict()
        for k, v in pretrain_dict_model.items():
            # strip the "module." prefix added by DataParallel before matching keys
            key = k[7:] if k.startswith("module.") else k
            if key in model_dict:
                new_dict[key] = v
        model_dict.update(new_dict)
        self.load_state_dict(model_dict)

    def get_text_embedding_src(self, vocab_size, embedding_dim, padding_idx, pretrain_emb_path):
        embedding_src = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        if pretrain_emb_path != '':
            # each line of the embedding file is a word followed by its vector;
            # drop the word and keep only the float values
            emb_content = []
            with open(pretrain_emb_path, 'r') as f:
                for line in f:
                    emb_content.append(line.split()[1:])
            vector = np.asarray(emb_content, "float32")
            # initialise the last len(emb_content) rows of the embedding table
            # with the pretrained vectors
            embedding_src.weight.data[-len(emb_content):].copy_(torch.from_numpy(vector))
        return embedding_src
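
# Note: get_text_embedding_src above expects a GloVe/word2vec-style text file
# (an assumption based on how the lines are parsed), e.g. one word per line
# followed by its vector:
#
#     triangle 0.12 -0.03 0.88 ...
#     angle    0.45  0.21 -0.07 ...
#
# Only the numeric part is kept, and the pretrained vectors fill the *last*
# len(emb_content) rows of the embedding table, so the vocabulary is assumed
# to place those words at the end.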


class Network(nn.Module):
    """Encoder-decoder network that fuses diagram and text features."""

    def __init__(self, cfg, src_lang, tgt_lang):
        super(Network, self).__init__()
        self.cfg = cfg
        # define the visual backbone, the encoder, and the decoder
        self.visual_extractor = get_visual_backbone(cfg)
        self.encoder = get_encoder(cfg)
        self.decoder = get_decoder(cfg, tgt_lang)
        # project the visual feature into the encoder embedding space
        self.visual_emb_unify = nn.Sequential(
            nn.Linear(self.visual_extractor.final_feat_dim, cfg.encoder_embedding_size),
            nn.ReLU(),
            nn.Linear(cfg.encoder_embedding_size, cfg.encoder_embedding_size)
        )
        if cfg.use_MLM_pretrain:
            # reuse the MLM-pretrained text encoder, optionally loading its weights
            self.mlm_pretrain = MLMTransformerPretrain(cfg, src_lang)
            if cfg.MLM_pretrain_path != '':
                self.mlm_pretrain.load_model(cfg.MLM_pretrain_path)
        else:
            # otherwise build the text embeddings directly on this module
            self.text_embedding_src = self.get_text_embedding_src(
                vocab_size=src_lang.n_words,
                embedding_dim=cfg.encoder_embedding_size,
                padding_idx=0,
                pretrain_emb_path=cfg.pretrain_emb_path
            )
            self.class_tag_embedding = nn.Embedding(
                len(src_lang.class_tag),
                cfg.encoder_embedding_size,
                padding_idx=0
            )
            self.sect_tag_embedding = nn.Embedding(
                len(src_lang.sect_tag),
                cfg.encoder_embedding_size,
                padding_idx=0
            )
        self.src_lang = src_lang

    def forward(self, diagram_src, text_dict, var_dict, exp_dict, is_train=False):
        '''
        diagram_src: B x C x W x H
        text_dict = {'token', 'sect_tag', 'class_tag', 'len'} /
                    {'token', 'sect_tag', 'class_tag', 'subseq_len', 'item_len', 'item_quant'}
        var_dict = {'pos', 'len', 'var_value', 'arg_value'}
        exp_dict = {'exp', 'len', 'answer'}
        '''
        if self.cfg.use_MLM_pretrain:
            text_emb_src = self.mlm_pretrain(text_dict)
        else:
            # text feature
            token_emb = self.text_embedding_src(text_dict['token'])
            class_tag_emb = self.class_tag_embedding(text_dict['class_tag'])
            sect_tag_emb = self.sect_tag_embedding(text_dict['sect_tag'])
            text_emb_src = token_emb.sum(dim=1) + sect_tag_emb + class_tag_emb
        # diagram feature, projected and prepended as an extra "token"
        diagram_emb_src = self.visual_extractor(diagram_src)
        diagram_emb_src = self.visual_emb_unify(diagram_emb_src).unsqueeze(dim=1)
        # concatenate diagram and text features; lengths and variable positions
        # shift by one because of the prepended diagram feature
        all_emb_src = torch.cat([diagram_emb_src, text_emb_src], dim=1)
        text_dict['len'] += 1
        var_dict['pos'] += 1
        # encoder
        encoder_outputs, encode_hidden = self.encoder(all_emb_src, text_dict['len'])
        problem_output = encode_hidden[-1:, :, :].repeat(self.cfg.decoder_layers, 1, 1)
        # decoder
        outputs = self.decoder(encoder_outputs, problem_output,
                               text_dict['len'],
                               var_dict['pos'], var_dict['len'],
                               exp_dict['exp'],
                               is_train)
        return outputs

    def freeze_module(self, module):
        self.cfg.logger.info("Freezing module " + module.__class__.__name__ + " .......")
        for p in module.parameters():
            p.requires_grad = False
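
    # Example (illustrative only): freezing the visual backbone during fine-tuning
    # would look like `model.freeze_module(model.visual_extractor)`; any
    # nn.Module attribute of the network can be passed in.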

    def load_model(self, model_path):
        # map the checkpoint to whichever device is available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pretrain_dict = torch.load(model_path, map_location=device)
        pretrain_dict_model = pretrain_dict['state_dict'] \
            if 'state_dict' in pretrain_dict else pretrain_dict
        model_dict = self.state_dict()
        new_dict = OrderedDict()
        for k, v in pretrain_dict_model.items():
            # strip the "module." prefix added by DataParallel before loading
            if k.startswith("module."):
                new_dict[k[7:]] = v
            else:
                new_dict[k] = v
        model_dict.update(new_dict)
        self.load_state_dict(model_dict)
        return pretrain_dict

    def get_text_embedding_src(self, vocab_size, embedding_dim, padding_idx, pretrain_emb_path):
        embedding_src = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        if pretrain_emb_path != '':
            # each line of the embedding file is a word followed by its vector;
            # drop the word and keep only the float values
            emb_content = []
            with open(pretrain_emb_path, 'r') as f:
                for line in f:
                    emb_content.append(line.split()[1:])
            vector = np.asarray(emb_content, "float32")
            # initialise the last len(emb_content) rows of the embedding table
            # with the pretrained vectors
            embedding_src.weight.data[-len(emb_content):].copy_(torch.from_numpy(vector))
        return embedding_src


def get_model(args, src_lang, tgt_lang):
    model = Network(args, src_lang, tgt_lang)
    args.logger.info(str(model))
    return model
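

# Minimal usage sketch (illustrative only): `cfg`, `src_lang`, `tgt_lang`, and the
# batch dictionaries are assumed to come from the project's own config, vocabulary,
# and data-loading code; the attribute names are the ones this module actually reads.
#
#     model = get_model(cfg, src_lang, tgt_lang)
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = model.to(device)
#     # MLM-pretrained weights, if configured, are loaded inside Network.__init__
#     outputs = model(diagram_src, text_dict, var_dict, exp_dict, is_train=True)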