| | |
| | |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
# Names exported by ``from <module> import *``.
__all__ = [
    'XLMRoberta',
    'xlm_roberta_large',
]
| |
|
| |
|
class SelfAttention(nn.Module):
    """
    Multi-head self-attention on top of ``F.scaled_dot_product_attention``.

    Args:
        dim: input/output width C; must be divisible by ``num_heads``.
        num_heads: number of attention heads; each head has width
            ``dim // num_heads``.
        dropout: probability used both for the attention weights (passed to
            SDPA only while ``self.training``) and for the ``nn.Dropout``
            applied after the output projection.
        eps: stored as ``self.eps``; not used inside this module.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
        # A real exception rather than ``assert``: asserts vanish under
        # ``python -O`` and the mismatch would then only surface later as an
        # opaque reshape error in ``forward``.
        if dim % num_heads != 0:
            raise ValueError(
                f'dim ({dim}) must be divisible by num_heads ({num_heads})')
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        x: [B, L, C].
        mask: attention mask forwarded unchanged to
            ``F.scaled_dot_product_attention`` (boolean, or additive float,
            broadcastable to [B, num_heads, L, L]); may be None.

        Returns a tensor of shape [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # project and split heads: [B, L, C] -> [B, N, L, D]
        q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)

        # SDPA is a plain function and does not know about self.training, so
        # attention-weight dropout has to be switched off explicitly in eval.
        p = self.dropout.p if self.training else 0.0
        x = F.scaled_dot_product_attention(q, k, v, mask, p)
        # merge heads: [B, N, L, D] -> [B, L, C]
        x = x.permute(0, 2, 1, 3).reshape(b, s, c)

        # output projection
        x = self.o(x)
        x = self.dropout(x)
        return x
| |
|
| |
|
class AttentionBlock(nn.Module):
    """
    Transformer encoder layer: self-attention followed by a 4x-wide GELU MLP.

    With ``post_norm`` truthy, each sub-layer's residual sum is normalized,
    i.e. ``norm(x + f(x))``. Otherwise the sub-layer input is normalized and
    the residual stream is left as is, i.e. ``x + f(norm(x))``.
    """

    def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.eps = eps

        # Sub-layers are registered in a fixed order (attn, norm1, ffn, norm2);
        # the attribute names form the state_dict keys.
        self.attn = SelfAttention(dim, num_heads, dropout, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps)
        hidden = 4 * dim
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(dim, eps=eps)

    def forward(self, x, mask):
        """Apply attention then the MLP, each with a residual connection."""
        if not self.post_norm:
            # pre-norm: normalize sub-layer inputs, plain residual adds
            x = x + self.attn(self.norm1(x), mask)
            return x + self.ffn(self.norm2(x))
        # post-norm: normalize after each residual add
        h = self.norm1(x + self.attn(x, mask))
        return self.norm2(h + self.ffn(h))
| |
|
| |
|
class XLMRoberta(nn.Module):
    """
    XLMRobertaModel with no pooler and no LM head.

    Token, token-type (always type 0) and learned position embeddings are
    summed and passed through ``num_layers`` ``AttentionBlock``s. When
    ``post_norm`` is truthy, ``self.norm`` is applied to the embedding sum
    before the blocks; otherwise it is applied to the output of the last block.
    """

    def __init__(self,
                 vocab_size=250002,
                 max_seq_len=514,
                 type_size=1,
                 pad_id=1,
                 dim=1024,
                 num_heads=16,
                 num_layers=24,
                 post_norm=True,
                 dropout=0.1,
                 eps=1e-5):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.post_norm = post_norm
        self.eps = eps

        # embeddings
        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
        self.type_embedding = nn.Embedding(type_size, dim)
        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
        self.dropout = nn.Dropout(dropout)

        # blocks
        self.blocks = nn.ModuleList([
            AttentionBlock(dim, num_heads, post_norm, dropout, eps)
            for _ in range(num_layers)
        ])

        # norm layer
        self.norm = nn.LayerNorm(dim, eps=eps)

    @staticmethod
    def _attention_bias(mask, dtype):
        """
        Turn a [B, L] keep-mask (>0 = real token) into an additive [B, 1, 1, L]
        attention bias: 0 for kept keys, ``torch.finfo(dtype).min`` for pads.

        The bias is created directly in ``dtype``. Building it from Python
        scalars would always give float32: for float64 inputs the float64
        minimum then overflows to -inf (NaN softmax on fully padded rows),
        and for half precision the mask dtype would not match the query.
        """
        b, s = mask.shape
        keep = mask.view(b, 1, 1, s).gt(0)
        bias = torch.zeros(keep.shape, dtype=dtype, device=mask.device)
        return bias.masked_fill(~keep, torch.finfo(dtype).min)

    def forward(self, ids):
        """
        ids: [B, L] of torch.LongTensor.

        Returns hidden states of shape [B, L, dim]. Tokens equal to ``pad_id``
        are excluded as attention keys; their own output rows are still
        computed.
        """
        b, s = ids.shape
        mask = ids.ne(self.pad_id).long()

        # Position ids: non-pad tokens are numbered pad_id + 1, pad_id + 2, ...
        # in order (pads are skipped in the count); pad tokens get pad_id.
        x = self.token_embedding(ids) + \
            self.type_embedding(torch.zeros_like(ids)) + \
            self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
        if self.post_norm:
            x = self.norm(x)
        x = self.dropout(x)

        # blocks
        bias = self._attention_bias(mask, x.dtype)
        for block in self.blocks:
            x = block(x, bias)

        # output
        if not self.post_norm:
            x = self.norm(x)
        return x
| |
|
| |
|
def xlm_roberta_large(pretrained=False,
                      return_tokenizer=False,
                      device='cpu',
                      **kwargs):
    """
    XLMRobertaLarge adapted from Huggingface.

    Builds an ``XLMRoberta`` with the large configuration (24 layers, width
    1024, 16 heads, post-norm). It is constructed inside
    ``with torch.device(device)``, so its parameters are allocated on that
    device. Any keyword passed in ``kwargs`` overrides the matching default.

    Note: ``pretrained`` and ``return_tokenizer`` are accepted but not used
    here. The returned model is freshly initialized and only the model is
    returned.
    """
    large_defaults = {
        'vocab_size': 250002,
        'max_seq_len': 514,
        'type_size': 1,
        'pad_id': 1,
        'dim': 1024,
        'num_heads': 16,
        'num_layers': 24,
        'post_norm': True,
        'dropout': 0.1,
        'eps': 1e-5,
    }
    config = {**large_defaults, **kwargs}

    with torch.device(device):
        return XLMRoberta(**config)