| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torchtune.modules import RotaryPositionalEmbeddings |
| | from transformers import PreTrainedModel |
| | from .config import model_config |
| | from typing import Mapping |
| | from transformers.tokenization_utils_base import BatchEncoding |
| |
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward layer: a gated SiLU projection with output dropout.

    A single linear layer produces both the gate and the value halves
    (hidden_dim each); the SiLU-activated gate multiplies the value before
    projecting back to input_dim.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # One fused projection emits gate and value halves side by side.
        self.linear1 = nn.Linear(input_dim, hidden_dim * 2, bias=True)
        self.linear2 = nn.Linear(hidden_dim, input_dim, bias=True)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        gate, value = self.linear1(x).chunk(2, dim=-1)
        projected = self.linear2(F.silu(gate) * value)
        return self.dropout(projected)
| |
|
| |
|
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings (no KV cache).

    Args:
        embed_dim: model width; must be divisible by num_heads.
        num_heads: number of attention heads.
        max_seq_len: maximum sequence length for the rotary embedding table.
    """

    def __init__(self, embed_dim, num_heads, max_seq_len):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Fused QKV projection; rotary operates per-head on head_dim.
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.rotary = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len=max_seq_len)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, input_pos=None, mask=None):
        """Apply self-attention.

        Args:
            x: (B, T, C) input activations.
            input_pos: optional position indices forwarded to the rotary embedding.
            mask: optional (B, T) attention mask with 1 = attend, 0 = ignore
                (HF-style attention_mask). -- assumed 2-D; TODO confirm callers.

        Returns:
            (B, T, C) attention output after the output projection and dropout.
        """
        B, T, C = x.shape

        # (B, T, 3, H, Dh) -> three (B, T, H, Dh) tensors.
        q, k, v = self.qkv_proj(x).view(B, T, 3, self.num_heads, self.head_dim).unbind(2)

        # torchtune's rotary expects (B, T, H, Dh), i.e. heads NOT yet transposed.
        q = self.rotary(q, input_pos=input_pos)
        k = self.rotary(k, input_pos=input_pos)

        # SDPA expects (B, H, T, Dh).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn_mask = None
        if mask is not None:
            # Boolean mask (True = attend) instead of the previous float32
            # additive -1e9 mask: the float mask dtype-mismatched q/k/v under
            # fp16/bf16 autocast, and -1e9 overflows fp16. SDPA applies -inf
            # to False positions natively, for any compute dtype.
            attn_mask = mask.bool().view(B, 1, 1, T).expand(B, 1, T, T)

        attn_output = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_mask)
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        attn_output = self.out_proj(attn_output)
        return self.dropout(attn_output)
| |
|
| |
|
class UnifiedTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention then SwiGLU feed-forward,
    each wrapped in a residual connection."""

    def __init__(self, embed_dim, num_heads, ffn_hidden_dim, max_seq_len):
        super().__init__()
        self.attn_norm = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, max_seq_len)
        self.ffn_norm = nn.LayerNorm(embed_dim)
        self.ffn = SwiGLU(embed_dim, ffn_hidden_dim)

    def forward(self, x, input_pos=None, mask=None):
        # Residual attention sub-layer (normalize first: pre-norm).
        attn_out = self.attn(self.attn_norm(x), input_pos=input_pos, mask=mask)
        x = x + attn_out
        # Residual feed-forward sub-layer.
        ffn_out = self.ffn(self.ffn_norm(x))
        return x + ffn_out
| |
|
class TransformerStack(nn.Module):
    """A stack of identical transformer blocks followed by a final LayerNorm."""

    def __init__(self, num_blocks, embed_dim, num_heads, ffn_hidden_dim, max_seq_len):
        super().__init__()
        layers = (
            UnifiedTransformerBlock(embed_dim, num_heads, ffn_hidden_dim, max_seq_len)
            for _ in range(num_blocks)
        )
        self.blocks = nn.ModuleList(layers)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, input_pos=None, mask=None):
        # Thread activations through every block, then normalize the output.
        for layer in self.blocks:
            x = layer(x, input_pos=input_pos, mask=mask)
        return self.norm(x)
| |
|
class MLM_core(nn.Module):
    """Core masked-language-model network: embedding, transformer stack, and a
    per-token output head, plus a pad-aware mean-pooled sequence embedding.

    Args:
        vocab_size: token vocabulary size.
        embed_dim: transformer width.
        num_blocks: number of transformer blocks.
        num_heads: attention heads per block.
        ffn_hidden_dim: hidden width of the SwiGLU feed-forward.
        output_dim: per-token output head dimension (e.g. vocab logits).
        max_seq_len: maximum sequence length for rotary embeddings.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        num_blocks: int,
        num_heads: int,
        ffn_hidden_dim: int,
        output_dim: int,
        max_seq_len: int,
    ):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.transformer = TransformerStack(
            num_blocks, embed_dim, num_heads, ffn_hidden_dim, max_seq_len
        )
        self.sequence_head = nn.Linear(embed_dim, output_dim, bias=True)

    def forward(self, ids, mask=None, pad_token_id=0, input_pos=None):
        """Run the model.

        Args:
            ids: (B, T) integer token ids.
            mask: optional attention mask forwarded to the transformer.
            pad_token_id: id treated as padding when mean-pooling (default 0).
            input_pos: optional rotary position indices.

        Returns:
            dict with 'logits' (B, T, output_dim), 'last_layer' (B, T, C)
            hidden states, and 'mean_pool' (B, C) pad-aware pooled embedding.
        """
        x = self.embed(ids)
        x = self.transformer(x, mask=mask, input_pos=input_pos)

        logits = self.sequence_head(x)

        # Masked mean pooling. The previous version zeroed pad positions but
        # still divided by the full length T via .mean(dim=1), biasing pooled
        # vectors toward zero for padded sequences. Divide by the actual
        # non-pad token count instead (clamped to 1 to avoid 0/0 on all-pad rows).
        non_pad = (ids != pad_token_id).unsqueeze(-1)          # (B, T, 1)
        summed = (x * non_pad).sum(dim=1)                      # (B, C)
        counts = non_pad.sum(dim=1).clamp(min=1)               # (B, 1)
        mean_pool = summed / counts

        outputs = {
            'logits': logits,
            'last_layer': x,
            'mean_pool': mean_pool
        }

        return outputs
| |
|
class MLM_model(PreTrainedModel):
    """Hugging Face wrapper around MLM_core.

    forward() accepts a tokenizer output (BatchEncoding / mapping), HF-style
    keyword arguments (input_ids / attention_mask), or a bare tensor of ids.
    """

    config_class = model_config

    def __init__(self, config):
        super().__init__(config)
        self.model = MLM_core(
            vocab_size=config.vocab_size,
            embed_dim=config.embed_dim,
            num_blocks=config.num_blocks,
            num_heads=config.num_heads,
            ffn_hidden_dim=config.ffn_hidden_dim,
            output_dim=config.output_dim,
            max_seq_len=config.max_seq_len,
        )
        # HF weight init / tying hook.
        self.post_init()

    def forward(self, x=None, **kwargs):
        # Case 1: tokenizer output passed positionally (dict-like).
        if isinstance(x, (BatchEncoding, Mapping)):
            ids = x.get("input_ids")
            attn = x.get("attention_mask")
            return self.model(ids, mask=attn)

        # Case 2: HF keyword-style call, e.g. model(input_ids=..., attention_mask=...).
        if "input_ids" in kwargs or "attention_mask" in kwargs:
            ids = kwargs.get("input_ids")
            attn = kwargs.get("attention_mask")
            return self.model(ids, mask=attn)

        # Case 3: bare tensor of token ids.
        return self.model(x)
| |
|