import torch
import torch.nn as nn
import torch.nn.functional as F
from easydict import EasyDict as edict
from xml.model_components import BertAttention, TrainablePositionalEncoding


class TextEncoder(nn.Module):
    def __init__(self, hidden_size, drop, input_drop, nheads, max_position_embeddings):
        super().__init__()
        # Single self-attention layer over the token features.
        self.transformer_encoder = BertAttention(edict(
            hidden_size=hidden_size,
            intermediate_size=hidden_size,
            hidden_dropout_prob=drop,
            attention_probs_dropout_prob=drop,
            num_attention_heads=nheads,
        ))
        # Learned positional embeddings added to the input features.
        self.pos_embed = TrainablePositionalEncoding(
            max_position_embeddings=max_position_embeddings,
            hidden_size=hidden_size,
            dropout=input_drop,
        )
        # Maps each token to a single attention logit used for pooling.
        self.modular_vector_mapping = nn.Linear(hidden_size, 1, bias=False)
|
    def forward(self, feat, mask):
        """
        Args:
            feat: (N, L, D=hidden_size)
            mask: (N, L), with 1 indicating a valid position

        Returns:
            (N, D) pooled text feature
        """
        feat = self.pos_embed(feat)  # (N, L, D)
        feat = self.transformer_encoder(feat, mask.unsqueeze(1))  # (N, L, D)
        att_scores = self.modular_vector_mapping(feat)  # (N, L, 1)
        att_scores = F.softmax(mask_logits(att_scores, mask.unsqueeze(2)), dim=1)  # (N, L, 1)
        pooled_feat = torch.einsum("blm,bld->bmd", att_scores, feat)  # (N, 1, D)
        return pooled_feat.squeeze(1)  # (N, D)


def mask_logits(target, mask):
    # Push logits at padded positions to a large negative value so they
    # receive ~zero weight after softmax.
    return target * mask + (1 - mask) * (-1e10)


def build_text_encoder(args):
    return TextEncoder(
        hidden_size=args.hidden_dim,
        drop=args.dropout,
        input_drop=args.input_dropout,
        nheads=args.nheads,
        max_position_embeddings=args.max_q_l,
    )
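

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): builds the encoder
    # with hypothetical hyper-parameter values and runs a forward pass on random
    # features. The argument names (hidden_dim, dropout, input_dropout, nheads,
    # max_q_l) follow build_text_encoder above; the concrete values and the dummy
    # inputs are illustrative assumptions only.
    args = edict(hidden_dim=256, dropout=0.1, input_dropout=0.5, nheads=8, max_q_l=32)
    encoder = build_text_encoder(args)
    feat = torch.randn(4, 32, 256)   # (N, L, D) token features
    mask = torch.ones(4, 32)         # (N, L), 1 = valid token
    pooled = encoder(feat, mask)
    print(pooled.shape)              # expected: torch.Size([4, 256])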