Spaces:
Build error
Build error
| # coding:utf-8 | |
| # chenjun | |
| # date:2020-04-18 | |
| import torch.nn as nn | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| # def get_non_pad_mask(seq, PAD): | |
| # assert seq.dim() == 2 | |
| # return seq.ne(PAD).type(torch.float).unsqueeze(-1) | |
def get_pad_mask(seq, pad_idx):
    """Return a boolean mask that is True wherever ``seq`` equals ``pad_idx``.

    A singleton axis is inserted before the last dimension, so a
    ``(batch, length)`` input produces a ``(batch, 1, length)`` mask.
    """
    is_pad = seq.eq(pad_idx)
    return is_pad.unsqueeze(-2)
def get_subsequent_mask(seq):
    """Build the mask that hides later positions from earlier ones.

    ``seq`` must be 2-D ``(batch, length)``. The result is a ``uint8`` tensor
    of shape ``(batch, length, length)`` on the same device as ``seq``, with
    ones strictly above the main diagonal and zeros elsewhere.
    """
    batch_size, seq_len = seq.size()
    ones = torch.ones((seq_len, seq_len), device=seq.device, dtype=torch.uint8)
    # Keep only the strict upper triangle: position i must not see j > i.
    upper = ones.triu(diagonal=1)
    # Broadcast the single (len, len) pattern over the batch without copying.
    return upper.unsqueeze(0).expand(batch_size, -1, -1)
def get_attn_key_pad_mask(seq_k, seq_q, PAD):
    """Mask out the padded positions of the key sequence.

    seq_k: key (source) sequence of token ids, ``(batch, len_k)``.
    seq_q: query (target) sequence; only its length along dim 1 is used.
    PAD: token id that marks padding in ``seq_k``.

    Returns a ``(batch, len_q, len_k)`` mask that is True where the key token
    is padding, repeated for every query position so it matches the shape of
    the query/key attention matrix.
    """
    query_len = seq_q.size(1)
    key_is_pad = seq_k.eq(PAD)
    return key_is_pad.unsqueeze(1).expand(-1, query_len, -1)  # b x lq x lk
class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position encoding to a batch-first sequence."""

    def __init__(self, d_hid, n_position=200):
        super().__init__()
        table = self._get_sinusoid_encoding_table(n_position, d_hid)
        # Registered as a buffer, not a parameter: it follows the module
        # across devices but is never updated by the optimizer.
        self.register_buffer('pos_table', table)

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Return the ``(1, n_position, d_hid)`` sinusoid table as a float tensor."""
        # TODO: make it with torch instead of numpy
        # The divisor depends only on the feature index, so compute it once
        # per dimension instead of once per (position, dimension) pair.
        divisors = [np.power(10000, 2 * (dim // 2) / d_hid) for dim in range(d_hid)]
        angles = np.array(
            [[pos / div for div in divisors] for pos in range(n_position)])
        angles[:, 0::2] = np.sin(angles[:, 0::2])  # even dims (2i): sine
        angles[:, 1::2] = np.cos(angles[:, 1::2])  # odd dims (2i+1): cosine
        return torch.FloatTensor(angles).unsqueeze(0)

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions."""
        encoding = self.pos_table[:, :x.size(1)].clone().detach()
        return x + encoding
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention over batched 3-D tensors."""

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        """Return ``(output, attn)`` for queries ``q``, keys ``k``, values ``v``.

        ``mask``, when given, marks the score entries to suppress; they are
        filled with -1e9 before the softmax so they receive ~zero weight.
        """
        scores = torch.bmm(q, k.transpose(1, 2)) / self.temperature
        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)
        # Softmax runs over the last (key) axis, then dropout on the weights.
        weights = self.dropout(self.softmax(scores))
        return torch.bmm(weights, v), weights
class MultiHeadAttention(nn.Module):
    """Multi-head attention with a residual connection and layer norm."""

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # One linear projection per role, each producing all heads at once.
        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        qk_std = np.sqrt(2.0 / (d_model + d_k))
        v_std = np.sqrt(2.0 / (d_model + d_v))
        nn.init.normal_(self.w_qs.weight, mean=0, std=qk_std)
        nn.init.normal_(self.w_ks.weight, mean=0, std=qk_std)
        nn.init.normal_(self.w_vs.weight, mean=0, std=v_std)

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm(d_model)

        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``; return ``(output, attn)``.

        ``q``, ``k`` and ``v`` are 3-D, batch-first. ``mask`` (optional) is
        tiled once per head before being handed to the attention core.
        ``attn`` keeps the stacked ``(n_head * batch, len_q, len_k)`` layout.
        """
        n_head, d_k, d_v = self.n_head, self.d_k, self.d_v

        _, len_q, _ = q.size()
        _, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()

        skip = q

        def split_heads(projected, length, d_head):
            # (b, len, n*d) -> (n*b, len, d): heads are folded into the batch
            # axis so a single bmm handles every head.
            projected = projected.view(sz_b, length, n_head, d_head)
            return projected.permute(2, 0, 1, 3).contiguous().view(-1, length, d_head)

        q = split_heads(self.w_qs(q), len_q, d_k)
        k = split_heads(self.w_ks(k), len_k, d_k)
        v = split_heads(self.w_vs(v), len_v, d_v)

        if mask is not None:
            mask = mask.repeat(n_head, 1, 1)  # (n*b) x .. x ..

        heads, attn = self.attention(q, k, v, mask=mask)

        # Undo the head folding: (n*b, lq, dv) -> (b, lq, n*dv).
        heads = heads.view(n_head, sz_b, len_q, d_v)
        merged = heads.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)

        projected = self.dropout(self.fc(merged))
        return self.layer_norm(projected + skip), attn
class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise feed-forward block with residual and layer norm."""

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        # Kernel-size-1 convolutions act on each position independently.
        self.w_1 = nn.Conv1d(d_in, d_hid, 1)
        self.w_2 = nn.Conv1d(d_hid, d_in, 1)
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to a batch-first ``(batch, length, d_in)`` tensor."""
        # Conv1d wants channels on dim 1, so swap length/feature axes around it.
        hidden = F.relu(self.w_1(x.transpose(1, 2)))
        projected = self.w_2(hidden).transpose(1, 2)
        return self.layer_norm(self.dropout(projected) + x)
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward block."""

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super().__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        """Return ``(enc_output, enc_slf_attn)`` for ``enc_input``.

        The input serves as query, key and value of the self-attention;
        ``slf_attn_mask`` is forwarded to it unchanged.
        """
        attended, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        return self.pos_ffn(attended), enc_slf_attn
class Torch_transformer_encoder(nn.Module):
    """Sequence encoder built on PyTorch's ``nn.TransformerEncoder``.

    Input and output are batch-first ``(batch, length, d_model)``, matching
    the positional encoding, which indexes positions along dim 1.
    """

    def __init__(self, d_word_vec=512, n_layers=2, n_head=8, d_model=512, dim_feedforward=1024, n_position=256):
        super(Torch_transformer_encoder, self).__init__()
        self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_head, dim_feedforward=dim_feedforward)
        self.layer_norm = nn.LayerNorm(d_model)
        # The same LayerNorm instance is also the encoder's final norm.
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers, norm=self.layer_norm)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, cnn_feature, src_mask=None, return_attns=False):
        """Encode ``cnn_feature`` and return a tuple whose first item is the output.

        ``src_mask`` is accepted for signature parity with
        ``Transforme_Encoder`` but is not used. With ``return_attns=True`` the
        second item is an empty list: ``nn.TransformerEncoder`` does not
        expose per-layer attention maps.
        """
        enc_output = self.dropout(self.position_enc(cnn_feature))  # position embedding

        # The encoder layers are built without ``batch_first``, so they expect
        # (length, batch, dim). Feeding the batch-first tensor directly would
        # make attention run across the batch axis; swap axes around the call.
        enc_output = self.encoder(enc_output.transpose(0, 1)).transpose(0, 1)

        # NOTE(review): self.layer_norm already ran as the encoder's final
        # norm, so it is applied a second time here. Kept to avoid changing
        # what existing checkpoints compute beyond the layout fix.
        enc_output = self.layer_norm(enc_output)

        if return_attns:
            return enc_output, []
        return enc_output,
class Transforme_Encoder(nn.Module):
    """Transformer encoder used to capture the global spatial dependencies.

    d_word_vec: feature dimension used by the positional encoding
    n_layers: number of encoder layers
    n_head: number of attention heads
    d_k: per-head key/query dimension (default 64)
    d_v: per-head value dimension (default 64)
    d_model: model dimension (default 512)
    d_inner: hidden size of the feed-forward block (default 1024)
    n_position: maximum number of positions the encoding table covers
    """

    def __init__(
            self, d_word_vec=512, n_layers=2, n_head=8, d_k=64, d_v=64,
            d_model=512, d_inner=1024, dropout=0.1, n_position=256):
        super().__init__()
        self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)
        layers = []
        for _ in range(n_layers):
            layers.append(
                EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout))
        self.layer_stack = nn.ModuleList(layers)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, cnn_feature, src_mask, return_attns=False):
        """Encode ``cnn_feature``; always returns a tuple.

        The tuple is ``(enc_output,)`` by default, or
        ``(enc_output, attention_maps)`` with one map per layer when
        ``return_attns`` is true. ``src_mask`` is passed to every layer as its
        self-attention mask.
        """
        attn_maps = []

        hidden = self.dropout(self.position_enc(cnn_feature))  # position embedding
        for layer in self.layer_stack:
            hidden, layer_attn = layer(hidden, slf_attn_mask=src_mask)
            if return_attns:
                attn_maps.append(layer_attn)
        hidden = self.layer_norm(hidden)

        if return_attns:
            return hidden, attn_maps
        return (hidden,)
class PVAM(nn.Module):
    """Parallel Visual Attention Module (parallel decoding).

    n_dim: dimension of the reading-order embedding (default 512)
    N_max_character: maximum number of characters per image (default 25)
    n_position: sequence length of the features coming out of the CNN
    """

    def __init__(self, n_dim=512, N_max_character=25, n_position=256):
        super().__init__()
        self.character_len = N_max_character
        self.f0_embedding = nn.Embedding(N_max_character, n_dim)
        self.w0 = nn.Linear(N_max_character, n_position)
        self.wv = nn.Linear(n_dim, n_dim)
        # Scores each feature position against every character slot.
        self.we = nn.Linear(n_dim, N_max_character)
        self.active = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)

    def forward(self, enc_output):
        """Return one attended feature vector per character slot.

        ``enc_output`` is ``(batch, n_position, n_dim)``; the result is
        ``(batch, N_max_character, n_dim)``.
        """
        batch_size = enc_output.size(0)
        order = torch.arange(
            self.character_len, dtype=torch.long, device=enc_output.device)
        order = order.unsqueeze(0).expand(batch_size, -1)  # (S,) -> (B, S)
        order_embed = self.f0_embedding(order)  # b, chars, dim

        # Map the character axis onto the feature-position axis.
        order_on_positions = self.w0(order_embed.transpose(1, 2)).transpose(1, 2)
        t = self.active(order_on_positions + self.wv(enc_output))  # b, pos, dim

        scores = self.we(t)  # b, pos, chars
        attn = self.softmax(scores.transpose(1, 2))  # b, chars, pos
        return torch.bmm(attn, enc_output)  # b, chars, dim
class GSRM(nn.Module):
    """Global Semantic Reasoning Module.

    n_dim: feature dimension of the character embedding and of the encoder
    n_class: number of character classes (size of the embedding table)
    PAD: padding class index; stored on the module as ``self.PAD``
    n_layers: number of encoder layers
    n_position: maximum sequence length for the encoder's positional table
    """

    def __init__(self, n_dim=512, n_class=37, PAD=37-1, n_layers=4, n_position=25):
        super(GSRM, self).__init__()

        self.PAD = PAD
        self.argmax_embed = nn.Embedding(n_class, n_dim)

        # Size the encoder from n_dim too: previously only the embedding
        # honoured n_dim while the encoder stayed at its 512 default, so any
        # other n_dim failed with a shape mismatch. The 512 default is unchanged.
        # Torch_transformer_encoder(n_layers=..., n_position=...) is an alternative.
        self.transformer_units = Transforme_Encoder(
            d_word_vec=n_dim, d_model=n_dim,
            n_layers=n_layers, n_position=n_position)  # for global context information

    def forward(self, e_out):
        """Reason over the arg-max characters of ``e_out``.

        e_out: per-position class scores (b,25,37 per the original shape
            note), the output from PVAM.

        Returns whatever the encoder returns: a tuple whose first item is the
        semantic feature tensor. The encoder is called with no attention
        mask, so padding positions are attended to like any other.
        """
        e_argmax = e_out.argmax(dim=-1)  # b, 25
        e = self.argmax_embed(e_argmax)  # b, 25, n_dim
        return self.transformer_units(e, None)
class SRN_Decoder(nn.Module):
    """Wrapper around the decoder stages; produces three classification outputs.

    n_dim: feature dimension
    n_class: number of character classes
    N_max_character: maximum number of characters per image
    n_position: sequence length of the CNN output features
    GSRM_layer: number of encoder layers inside the GSRM
    """

    def __init__(self, n_dim=512, n_class=37, N_max_character=25, n_position=256, GSRM_layer=4 ):
        super(SRN_Decoder, self).__init__()

        # n_dim must reach PVAM as well; it used to fall back to PVAM's 512
        # default, which broke every other n_dim.
        self.pvam = PVAM(n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
        self.w_e = nn.Linear(n_dim, n_class)  # output layer

        self.GSRM = GSRM(n_class=n_class, PAD=n_class-1, n_dim=n_dim, n_position=N_max_character, n_layers=GSRM_layer)
        self.w_s = nn.Linear(n_dim, n_class)  # output layer

        self.w_f = nn.Linear(n_dim, n_class)  # output layer

    def forward(self, cnn_feature ):
        """Return ``(e_out, s_out, f_out)`` class scores.

        cnn_feature: output of the CNN (b,256,512 per the original shape note).

        e_out comes from the visual branch (PVAM), s_out from the semantic
        branch (GSRM) and f_out from the sum of both feature tensors.
        """
        g_output = self.pvam(cnn_feature)  # b, chars, n_dim
        e_out = self.w_e(g_output)  # first output, fed to the cross-entropy loss

        s = self.GSRM(e_out)[0]  # b, chars, n_dim
        s_out = self.w_s(s)

        # TODO:change the add to gated unit
        f = g_output + s
        f_out = self.w_f(f)

        return e_out, s_out, f_out
def cal_performance(preds, gold, mask=None, smoothing='1'):
    """Combine the weighted losses of the prediction heads and count hits.

    preds: iterable of score tensors whose last axis is the class axis; the
        heads are weighted 1.0, 0.15 and 2.0 in order.
    gold: target class indices; flattened internally.
    mask: optional tensor, non-zero where a position counts. Flattened
        internally. ``None`` means no masking.
    smoothing: mode string forwarded to ``cal_loss``.

    Returns ``(loss, n_correct)``. ``n_correct`` is the number of correct
    masked positions for the last head whose loss was not NaN; it is ``None``
    when no mask is given, and stays 0 if every head produced NaN.
    """
    loss = 0.
    n_correct = 0
    weights = [1.0, 0.15, 2.0]

    # Check for None *before* touching the mask: the previous code called
    # mask.view(-1) first, so the default mask=None raised AttributeError.
    non_pad_mask = mask.contiguous().view(-1).ne(0) if mask is not None else None
    gold = gold.contiguous().view(-1)

    for ori_pred, weight in zip(preds, weights):
        pred = ori_pred.view(-1, ori_pred.shape[-1])

        tloss = cal_loss(pred, gold, non_pad_mask, smoothing)
        if torch.isnan(tloss):
            # Skip this head so one NaN does not poison the total loss.
            print('have nan loss')
            continue
        loss += tloss * weight

        hits = pred.max(1)[1].eq(gold)
        if non_pad_mask is not None:
            n_correct = hits.masked_select(non_pad_mask).sum().item()
        else:
            n_correct = None

    return loss, n_correct
def cal_loss(pred, gold, mask, smoothing):
    """Cross-entropy loss, with optional label smoothing.

    pred: ``(N, n_class)`` class scores.
    gold: target class indices; flattened to ``(N,)``.
    mask: optional boolean ``(N,)`` tensor, True where a position counts.
    smoothing: mode string.
        ``'0'``  label-smoothed loss (eps = 0.1), *summed* over the kept
                 positions (the caller averages later).
        ``'1'``  plain cross-entropy averaged over the masked positions, or
                 over all positions when ``mask`` is None.
        other    plain cross-entropy averaged over all positions.

    Returns a scalar tensor. In mode ``'1'`` an all-False mask gives NaN
    (division by zero positions).
    """
    gold = gold.contiguous().view(-1)

    if smoothing == '0':
        eps = 0.1
        n_class = pred.size(1)

        # Smoothed target: 1 - eps on the gold class, eps spread over the rest.
        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)

        # Honour the caller's mask, as mode '1' does. Only without a mask fall
        # back to the legacy assumption that class index 0 is padding.
        keep = mask if mask is not None else gold.ne(0)
        per_position = -(one_hot * log_prb).sum(dim=1)
        return per_position.masked_select(keep).sum()  # average later

    if smoothing == '1' and mask is not None:
        per_position = F.cross_entropy(pred, gold, reduction='none')
        return per_position.masked_select(mask).sum() / mask.sum()

    return F.cross_entropy(pred, gold)
def cal_performance2(preds, gold, PAD, smoothing='1'):
    """Combine the weighted losses of the prediction heads and count hits.

    preds: iterable of score tensors whose last axis is the class axis; the
        heads are weighted 1.0, 0.15 and 2.0 in order.
    gold: target class indices; flattened internally.
    PAD: class index that marks padding; such positions are excluded from
        the hit count and passed to ``cal_loss2`` for the loss.
    smoothing: mode string forwarded to ``cal_loss2``.

    Returns ``(loss, n_correct)``. ``n_correct`` is the number of correct
    non-padding positions for the last head whose loss was not NaN (0 if
    every head produced NaN).
    """
    loss = 0.
    n_correct = 0
    weights = [1.0, 0.15, 2.0]

    # Loop-invariant: flatten the targets and derive the padding mask once.
    # The unused debug tensors the loop used to build have been removed.
    gold = gold.contiguous().view(-1)
    non_pad_mask = gold.ne(PAD)

    for ori_pred, weight in zip(preds, weights):
        pred = ori_pred.view(-1, ori_pred.shape[-1])

        tloss = cal_loss2(pred, gold, PAD, smoothing=smoothing)
        if torch.isnan(tloss):
            # Skip this head so one NaN does not poison the total loss.
            print('have nan loss')
            continue
        loss += tloss * weight

        hits = pred.max(1)[1].eq(gold)
        n_correct = hits.masked_select(non_pad_mask).sum().item()

    return loss, n_correct
def cal_loss2(pred, gold, PAD, smoothing='1'):
    """Cross-entropy loss that ignores padding, with optional label smoothing.

    pred: ``(N, n_class)`` class scores.
    gold: target class indices; flattened to ``(N,)``.
    PAD: class index that marks padding.
    smoothing: mode string.
        ``'0'``  label-smoothed loss (eps = 0.1), *summed* over the
                 non-padding positions (the caller averages later).
        ``'1'``  plain cross-entropy averaged over non-padding positions.
        other    plain cross-entropy averaged over all positions.

    Returns a scalar tensor.
    """
    gold = gold.contiguous().view(-1)

    if smoothing == '0':
        eps = 0.1
        n_class = pred.size(1)

        # Smoothed target: 1 - eps on the gold class, eps spread over the rest.
        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)

        # Exclude the PAD class, consistent with mode '1'. The previous code
        # hard-coded index 0, dropping a real class whenever PAD != 0.
        non_pad_mask = gold.ne(PAD)
        per_position = -(one_hot * log_prb).sum(dim=1)
        return per_position.masked_select(non_pad_mask).sum()  # average later

    if smoothing == '1':
        return F.cross_entropy(pred, gold, ignore_index=PAD)

    return F.cross_entropy(pred, gold)
if __name__ == '__main__':
    # Smoke test: push a random feature map through the encoder and the
    # decoder, then print the shape of each of the three decoder outputs.
    dummy_feature = torch.rand((2, 256, 512))
    encoder = Transforme_Encoder()
    encoded = encoder(dummy_feature, src_mask=None)[0]
    decoder = SRN_Decoder(N_max_character=30)
    for head_output in decoder(encoded):
        print(head_output.shape)
| # image = torch.rand((4,3,32,60)) | |
| # tgt_seq = torch.tensor([[ 2, 24, 2176, 882, 2480, 612, 1525, 480, 875, 147, 1700, 715, | |
| # 1465, 3], | |
| # [ 2, 369, 1781, 882, 703, 879, 2855, 2415, 502, 1154, 833, 1465, | |
| # 3, 0], | |
| # [ 2, 2943, 334, 328, 480, 330, 1644, 1449, 163, 147, 1823, 1184, | |
| # 1465, 3], | |
| # [ 2, 24, 396, 480, 703, 1646, 897, 1711, 1508, 703, 2321, 147, | |
| # 642, 1465]], device='cuda:0') | |
| # tgt_pos = torch.tensor([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], | |
| # [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0], | |
| # [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], | |
| # [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]], | |
| # device='cuda:0') | |
| # src_seq = torch.tensor([[ 2, 598, 2088, 822, 2802, 1156, 157, 1099, 1000, 598, 1707, 1345, | |
| # 3, 0, 0, 0], | |
| # [ 2, 598, 2348, 822, 598, 1222, 471, 948, 986, 423, 1345, 3, | |
| # 0, 0, 0, 0], | |
| # [ 2, 2437, 2470, 901, 2473, 598, 1735, 84, 1, 2277, 1979, 499, | |
| # 962, 1345, 3, 0], | |
| # [ 2, 598, 186, 1904, 598, 868, 1339, 1604, 84, 598, 608, 1728, | |
| # 1345, 3, 0, 0]], device='cuda:0') | |
| # device = torch.device('cuda') | |
| # image = image.cuda() | |
| # transformer = Transformer() | |
| # transformer = transformer.to(device) | |
| # transformer.train() | |
| # out = transformer(image, tgt_seq, tgt_pos, src_seq) | |
| # gold = tgt_seq[:, 1:] # 从第二列开始 | |
| # # backward | |
| # loss, n_correct = cal_performance(out, gold, smoothing=True) | |
| # print(loss, n_correct) | |
| # a = 1 |