Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| import torch | |
| from .token import TokenEmbedding | |
| from .position import PositionalEmbedding | |
| from .segment import SegmentEmbedding | |
| from .time_embed import TimeEmbedding | |
class BERTEmbedding(nn.Module):
    """
    BERT input embedding: the sum of several per-position embeddings.

    Components that are summed in ``forward``:
        1. PositionalEmbedding : positional information (sin/cos based,
           per the original author's description)
        2. TokenEmbedding : normal embedding matrix, looked up from ``sequence``
        3. SegmentEmbedding : sentence segment info (sent_A:1, sent_B:2),
           added only when ``segment_label`` is supplied
        4. TimeEmbedding : embedding of ``time_info``, added only when the
           module was built with ``is_time=True``

    Dropout is applied to the summed result.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int,
        max_len: int,
        dropout: float = 0.1,
        is_logkey: bool = True,
        is_time: bool = False,
    ):
        """
        :param vocab_size: total vocab size
        :param embed_size: embedding size of token embedding
        :param max_len: forwarded to PositionalEmbedding as ``max_len``
        :param dropout: dropout rate
        :param is_logkey: stored on the instance; ``forward`` currently adds
            the token embedding regardless of this flag
        :param is_time: when true, ``forward`` also adds the time embedding
            and therefore requires ``time_info``
        """
        super().__init__()
        self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
        # The remaining embeddings are sized from the token embedding so that
        # all components can be summed element-wise.
        self.position = PositionalEmbedding(d_model=self.token.embedding_dim, max_len=max_len)
        self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)
        # Always constructed (even when is_time is False) so the set of
        # registered submodules does not depend on the flags.
        self.time_embed = TimeEmbedding(embed_size=self.token.embedding_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.embed_size = embed_size
        self.is_logkey = is_logkey
        self.is_time = is_time

    def forward(self, sequence, segment_label=None, time_info=None):
        """
        Sum the enabled embeddings for ``sequence`` and apply dropout.

        :param sequence: input passed to both the positional and the token
            embedding
        :param segment_label: optional input for the segment embedding;
            skipped when ``None``
        :param time_info: input for the time embedding; required when the
            module was built with ``is_time=True``, ignored otherwise
        :return: dropout applied to the sum of the enabled embeddings
        :raises ValueError: if ``is_time`` is true but ``time_info`` is ``None``
        """
        # Fail early with a clear message instead of an opaque error from
        # inside the time embedding when the required input is missing.
        if self.is_time and time_info is None:
            raise ValueError("time_info must be provided when is_time=True")

        x = self.position(sequence)
        # The token embedding is added unconditionally: the is_logkey switch
        # is intentionally not consulted here.
        x = x + self.token(sequence)
        if segment_label is not None:
            x = x + self.segment(segment_label)
        if self.is_time:
            x = x + self.time_embed(time_info)
        return self.dropout(x)