import math
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutput
from .configuration_spt import SPTConfig

def repeat_kv(hidden_states, repeat_times):
    # Repeat each key/value head `repeat_times` times along the head axis so
    # grouped-query attention can share K/V projections across query heads.
    if repeat_times == 1:
        return hidden_states
    batch, n_kv_heads, seq_len, head_dim = hidden_states.shape
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, n_kv_heads, repeat_times, seq_len, head_dim)
    return hidden_states.reshape(batch, n_kv_heads * repeat_times, seq_len, head_dim)

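# Illustrative example (hypothetical shapes, not part of the original module):
# with n_kv_heads=2 and n_attn_heads=8, each K/V head is repeated 4 times.
#   kv = torch.randn(1, 2, 16, 64)   # (batch, n_kv_heads, seq_len, head_dim)
#   repeat_kv(kv, 4).shape           # -> torch.Size([1, 8, 16, 64])
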
class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Root-mean-square normalisation (no mean subtraction), computed in
        # float32 for numerical stability, then cast back to the input dtype.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.head_dim = config.hidden_size // config.n_attn_heads
        kv_size = config.n_kv_heads * self.head_dim
        self.hidden_size = config.hidden_size
        self.n_attn_heads = config.n_attn_heads
        self.n_kv_heads = config.n_kv_heads

        # Grouped-query projections: full-width queries, narrower keys/values.
        self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k = nn.Linear(config.hidden_size, kv_size, bias=False)
        self.v = nn.Linear(config.hidden_size, kv_size, bias=False)

        # Causal mask: a lower-triangular matrix broadcast over batch and heads.
        self.register_buffer(
            "tril",
            torch.tril(torch.ones(config.max_len, config.max_len)).view(1, 1, config.max_len, config.max_len),
        )

    def forward(self, x):
        batch_size, seq_len, hidden_dim = x.shape

        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        q = q.view(batch_size, seq_len, self.n_attn_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Expand K/V heads to match the number of query heads.
        k = repeat_kv(k, self.n_attn_heads // self.n_kv_heads)
        v = repeat_kv(v, self.n_attn_heads // self.n_kv_heads)

        # Scale scores by sqrt(head_dim), the standard multi-head attention
        # scaling (the per-head dimension), rather than sqrt(hidden_size).
        attention = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        attention = attention.masked_fill(self.tril[:, :, :seq_len, :seq_len] == 0, float("-inf"))
        probs = nn.functional.softmax(attention, dim=-1)
        y = probs @ v
        # Merge heads back; note there is no output projection, so the
        # concatenated heads feed the residual stream directly.
        y = y.transpose(1, 2).contiguous().reshape(batch_size, seq_len, -1)

        return y

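# Note on Attention.forward above: in PyTorch >= 2.0 the mask/softmax/matmul
# sequence is equivalent (for this dropout-free causal case) to the fused kernel
#   y = nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
# which is usually faster and more memory-efficient.
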
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.up = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.gate = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act_fn = nn.GELU()

    def forward(self, x):
        # GEGLU-style gated MLP: activate the gate branch, scale the up
        # projection elementwise, then project back down.
        return self.down(self.act_fn(self.gate(x)) * self.up(x))

class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn = Attention(config)
        self.mlp = MLP(config)
        self.residual = config.residual
        # Separate pre-norms for the attention and MLP sub-layers, so the two
        # normalisations do not share parameters.
        self.attn_norm = RMSNorm(config.hidden_size) if config.normalise else nn.Identity()
        self.mlp_norm = RMSNorm(config.hidden_size) if config.normalise else nn.Identity()

    def forward(self, x):
        if self.residual:
            x = x + self.attn(self.attn_norm(x))
            x = x + self.mlp(self.mlp_norm(x))
        else:
            x = self.attn(self.attn_norm(x))
            x = self.mlp(self.mlp_norm(x))
        return x

class SPTModel(PreTrainedModel):
    config_class = SPTConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.norm = RMSNorm(config.hidden_size) if config.normalise else nn.Identity()
        self.post_init()

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return x

class SPTForCausalLM(PreTrainedModel):
    config_class = SPTConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = SPTModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def forward(self, input_ids, labels=None, **kwargs):
        x = self.model(input_ids)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift so tokens at position i predict the token at position i + 1,
            # following the Hugging Face causal-LM convention.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # The model has no cross-attention, so the plain CausalLMOutput suffices.
        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=(x,),
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache is kept: the full sequence is re-encoded at each step.
        return {"input_ids": input_ids}

    @staticmethod
    def _reorder_cache(past, beam_idx):
        # Nothing to reorder, since no past key/values are cached.
        return past


AutoModelForCausalLM.register(SPTConfig, SPTForCausalLM)
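
# Minimal usage sketch (assumes SPTConfig exposes the fields referenced above;
# the concrete values below are illustrative, not part of this module):
#
#   config = SPTConfig(vocab_size=32000, hidden_size=512, n_layers=8,
#                      n_attn_heads=8, n_kv_heads=2, intermediate_size=2048,
#                      max_len=1024, residual=True, normalise=True)
#   model = AutoModelForCausalLM.from_config(config)
#   out = model(torch.randint(0, config.vocab_size, (1, 16)))
#   out.logits.shape  # -> torch.Size([1, 16, 32000])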