|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import PreTrainedModel |
|
|
from .configuration_tinygpt import TinyGPTConfig |
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    Args:
        dim: size of the last (feature) dimension; also the length of ``weight``.
        eps: small constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to ones (identity scaling).
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension, then by ``weight``."""
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = (mean_sq + self.eps).rsqrt()
        return x * inv_rms * self.weight
|
|
|
|
|
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Linear (both with bias).

    Args:
        config: object exposing ``hidden_size`` and ``intermediate_size``.
    """

    def __init__(self, config):
        super().__init__()
        # Expand hidden_size -> intermediate_size.
        self.fc_in = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
        self.act = nn.GELU()
        # Project back intermediate_size -> hidden_size.
        self.fc_out = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)

    def forward(self, x):
        """Apply the expand / activate / project pipeline to the last dimension of ``x``."""
        hidden = self.fc_in(x)
        hidden = self.act(hidden)
        return self.fc_out(hidden)
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention with separate Q/K/V projections.

    Args:
        config: object exposing ``hidden_size`` and ``num_attention_heads``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            # Without this check the failure surfaces later as an opaque ``view`` error.
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )
        self.n_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)

    def forward(self, x, mask=None):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: tensor of shape ``(B, T, C)`` with ``C == hidden_size``.
            mask: optional mask where nonzero/True means "may attend" and 0/False
                means "blocked". Accepted shapes: ``(T, T)``, ``(B, T, T)``, or any
                4-D shape broadcastable to ``(B, n_heads, T, T)``.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.shape
        # Split channels into heads: (B, T, C) -> (B, n_heads, T, head_dim).
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * self.scale

        if mask is not None:
            if mask.dim() == 2:
                # (T, T) -> (1, 1, T, T): shared by every batch element and head.
                mask = mask.unsqueeze(0).unsqueeze(0)
            elif mask.dim() == 3:
                # (B, T, T) -> (B, 1, T, T): insert the head axis so the batch axis
                # lines up with ``att``'s batch axis instead of its head axis.
                mask = mask.unsqueeze(1)
            # Use the most negative finite value rather than -inf. For rows with at
            # least one allowed key the softmax result is identical (exp underflows
            # to exactly 0), but a row with every key blocked yields uniform weights
            # instead of NaN, which would otherwise poison later layers.
            att = att.masked_fill(mask == 0, torch.finfo(att.dtype).min)

        att = torch.softmax(att, dim=-1)
        # Merge heads back: (B, n_heads, T, head_dim) -> (B, T, C).
        out = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)
|
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer layer: residual attention followed by residual MLP.

    Args:
        config: object exposing ``hidden_size`` and ``rms_norm_eps`` plus whatever
            ``Attention`` and ``MLP`` read from it.
    """

    def __init__(self, config):
        super().__init__()
        # Submodules are created in this order so parameter registration (and the
        # resulting state_dict layout) is stable.
        self.norm_1 = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.attn = Attention(config)
        self.norm_2 = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.mlp = MLP(config)

    def forward(self, x, mask=None):
        """Apply ``x + attn(norm_1(x))`` and then ``x + mlp(norm_2(x))``.

        ``mask`` is forwarded unchanged to the attention sub-layer.
        """
        attn_out = self.attn(self.norm_1(x), mask)
        x = x + attn_out
        return x + self.mlp(self.norm_2(x))
|
|
|
|
|
class TinyGPTPreTrainedModel(PreTrainedModel):
    """Shared Hugging Face base for TinyGPT models: config binding and weight init."""

    config_class = TinyGPTConfig
    base_model_prefix = "transformer"

    def _init_weights(self, module):
        """Initialise one submodule in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from N(0, 0.02^2), and
        ``nn.Linear`` biases (when present) are zeroed. Other module types are left
        untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)
|
|
|
|
|
class TinyGPTModel(TinyGPTPreTrainedModel):
    """Decoder-only trunk: token + learned position embeddings, stacked blocks, final RMSNorm."""

    def __init__(self, config):
        super().__init__(config)
        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.h = nn.ModuleList([Block(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = RMSNorm(config.hidden_size, config.rms_norm_eps)
        # PreTrainedModel convention: applies _init_weights to every submodule.
        # Without this call, models built from a config keep PyTorch's default init.
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding table (used by ``resize_token_embeddings``)."""
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding table."""
        self.wte = new_embeddings

    def forward(self, input_ids, attention_mask=None):
        """Encode ``input_ids`` into final hidden states.

        Args:
            input_ids: integer tensor of shape ``(B, T)``.
            attention_mask: optional ``(B, T)`` tensor, nonzero for real tokens and 0
                for padding. Padding keys are hidden from every query.

        Returns:
            Tensor of shape ``(B, T, hidden_size)``.

        Raises:
            ValueError: if ``T`` exceeds ``config.max_position_embeddings``.
        """
        B, T = input_ids.shape
        device = input_ids.device
        if T > self.config.max_position_embeddings:
            raise ValueError(
                f"Sequence length {T} exceeds max_position_embeddings "
                f"({self.config.max_position_embeddings})"
            )

        pos = torch.arange(0, T, dtype=torch.long, device=device)
        x = self.wte(input_ids) + self.wpe(pos)

        # Causal mask: query i may attend to keys 0..i.
        mask = torch.tril(torch.ones((T, T), device=device)).view(1, 1, T, T)

        if attention_mask is not None:
            # Block padding keys for every query: (B, T) -> (B, 1, 1, T).
            key_mask = attention_mask.to(device=device, dtype=mask.dtype).view(B, 1, 1, T)
            mask = mask * key_mask
            # Always let a position see itself. Otherwise a pad query with no real key
            # in its causal window would have a fully blocked row. Real tokens are
            # unaffected because their diagonal entry is already allowed.
            eye = torch.eye(T, device=device, dtype=mask.dtype).view(1, 1, T, T)
            mask = torch.maximum(mask, eye)

        for layer in self.h:
            x = layer(x, mask)

        return self.ln_f(x)
|
|
|
|
|
class TinyGPTForCausalLM(TinyGPTPreTrainedModel):
    """TinyGPT trunk plus a bias-free linear head that produces vocabulary logits."""

    def __init__(self, config):
        super().__init__(config)
        self.transformer = TinyGPTModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # PreTrainedModel convention: applies _init_weights (including to lm_head).
        # Without it, a model built from a config never gets the std=0.02 init.
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Compute next-token logits and, optionally, the shifted LM loss.

        Args:
            input_ids: integer tensor of shape ``(B, T)``.
            attention_mask: optional padding mask, forwarded to the trunk.
            labels: optional ``(B, T)`` targets aligned with ``input_ids``. They are
                shifted left by one inside this method. Entries equal to -100 (the
                ``CrossEntropyLoss`` default ``ignore_index``) are excluded.
            **kwargs: accepted and ignored, so extra keys from ``generate`` or
                ``Trainer`` do not raise.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` of shape ``(B, T, vocab_size)``
            and ``loss`` (scalar tensor or ``None``).
        """
        hidden = self.transformer(input_ids, attention_mask)
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            # Position t predicts token t+1: drop the last logit and the first label.
            shift_logits = logits[..., :-1, :].contiguous()
            # Labels may live on another device (e.g. when the model is split
            # across devices); align them with the logits.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build ``forward`` kwargs for each generation step.

        No KV cache is used, so the full ``input_ids`` are re-fed every step. The
        attention mask supplied to ``generate`` is passed through so that padded
        prompts stay masked.
        """
        return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask")}
|
|
|