| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import math |
| | from transformers import PreTrainedModel |
| | from .configuration_gator import GatorConfig |
| |
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Each vector along the last axis is divided by its RMS (plus ``eps``) and
    then scaled element-wise by a learned ``weight`` initialised to ones.

    Args:
        dim: Size of the last dimension (length of ``weight``).
        eps: Added to the RMS itself (not to the mean square) to avoid
            division by zero.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        features = x.shape[-1]
        # ||x||_2 / sqrt(n) == sqrt(mean(x ** 2)), i.e. the root mean square.
        rms = torch.norm(x, 2, dim=-1, keepdim=True) / math.sqrt(features)
        normalised = x / (rms + self.eps)
        return self.weight * normalised
| |
|
class Rope(nn.Module):
    """Rotary position embedding over the last dimension of its input.

    Adjacent feature pairs ``(x[..., 2i], x[..., 2i + 1])`` at sequence
    position ``p`` are rotated by the angle ``p * inv_freq[i]``, where
    ``inv_freq[i] = 10000 ** (-2i / d_model)``.

    Args:
        d_model: Size of the rotated (last) dimension; must be even.
        max_len: Number of positions held in the ``pos`` buffer; sequences
            longer than this are not supported.
    """

    def __init__(self, d_model, max_len=1024):
        super().__init__()
        assert d_model % 2 == 0
        # Both buffers are persistent (part of the state dict); names and
        # construction are kept so existing checkpoints still load.
        self.register_buffer("pos", torch.arange(max_len).unsqueeze(1))
        self.register_buffer("inv_freq", torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)))

    def forward(self, x):
        """Return ``x`` with position-dependent rotations applied.

        Assumes dim 1 of ``x`` is the sequence axis (its size is used as the
        sequence length) and the last dim has size ``d_model``. The result has
        the same shape and dtype as ``x``.
        """
        t = x.size(1)
        # Angles reach (max_len - 1) * inv_freq[0] = max_len - 1 radians.
        # Computing them in fp16/bf16 rounds them by up to whole radians at
        # long positions, so use at least float32 here (float64 stays float64).
        angle_dtype = torch.promote_types(self.inv_freq.dtype, torch.float32)
        angles = self.pos[:t].to(angle_dtype) * self.inv_freq.to(angle_dtype)
        cos, sin = torch.cos(angles), torch.sin(angles)
        pairs = x.view(*x.shape[:-1], -1, 2)
        x1, x2 = pairs[..., 0], pairs[..., 1]
        rotated = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
        # Hand back the caller's dtype so downstream matmuls see matching types.
        return rotated.view(*x.shape).to(x.dtype)
| |
|
class GQA(nn.Module):
    """Grouped-query self-attention with rotary position embeddings.

    ``n_heads`` query heads share ``n_kv = n_heads // gqa_groups`` key/value
    heads; each key/value head is repeated ``n_heads // n_kv`` times along the
    head axis so every query head has a key/value partner.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of query heads.
        gqa_groups: Divisor of ``n_heads`` giving the key/value head count
            (``n_kv = n_heads // gqa_groups``).
        max_len: Longest sequence the rotary embeddings are precomputed for.

    Raises:
        ValueError: If the head layout is inconsistent; such configurations
            could not run ``forward`` (its reshapes would fail).
    """

    def __init__(self, d_model, n_heads, gqa_groups, max_len):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        n_kv = n_heads // gqa_groups
        if n_kv < 1 or n_heads % n_kv != 0:
            raise ValueError(
                f"n_heads ({n_heads}) // gqa_groups ({gqa_groups}) = {n_kv} "
                f"must be a positive divisor of n_heads")
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.n_kv = n_kv
        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, self.n_kv * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, self.n_kv * self.head_dim, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        # NOTE(review): RoPE is applied to the whole projection *before* the
        # head split, using separate Rope modules of width n_heads*head_dim
        # (queries) and n_kv*head_dim (keys). Rope's frequencies depend on its
        # width, so each head sees a different slice of the spectrum and, when
        # n_kv != n_heads, a query head and the key head it is paired with are
        # rotated at different frequencies; scores are then not a pure function
        # of relative position as in standard RoPE. Left as-is on the assumption
        # that existing checkpoints were trained this way -- confirm first.
        self.rope_q = Rope(n_heads * self.head_dim, max_len)
        self.rope_k = Rope(self.n_kv * self.head_dim, max_len)

    def forward(self, x, is_causal=True):
        """Self-attend over ``x`` of shape (batch, seq, d_model); same shape out.

        Args:
            x: Input activations.
            is_causal: When True (default), position ``i`` attends only to
                positions ``<= i``, as autoregressive language modelling
                requires. Pass False for unmasked (bidirectional) attention.
        """
        B, T, C = x.shape
        q = self.rope_q(self.q_proj(x)).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.rope_k(self.k_proj(x)).view(B, T, self.n_kv, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv, self.head_dim).transpose(1, 2)
        # Share each key/value head across n_heads // n_kv consecutive query heads.
        expand = self.n_heads // self.n_kv
        k = k.repeat_interleave(expand, dim=1)
        v = v.repeat_interleave(expand, dim=1)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if is_causal:
            # Hide future positions. The diagonal stays visible, so no row is
            # fully masked and softmax never sees an all -inf row.
            future = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(diagonal=1)
            scores = scores.masked_fill(future, float("-inf"))
        attn = torch.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(out)
| |
|
class MLP(nn.Module):
    """Gated feed-forward layer (SwiGLU style, no biases).

    ``fc1`` produces an "up" half and a "gate" half of width ``d_ff`` each;
    their product ``up * silu(gate)`` is projected back by ``fc2``.

    Args:
        d_model: Input and output width.
        d_ff: Width of the gated hidden layer.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff * 2, bias=False)
        self.fc2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        hidden_width = self.fc2.in_features
        # First half of fc1's output is "up", second half is the gate.
        up, gate = self.fc1(x).split(hidden_width, dim=-1)
        gated = up * F.silu(gate)
        return self.fc2(gated)
| |
|
class Block(nn.Module):
    """Pre-norm transformer layer: residual attention, then residual MLP.

    Args:
        cfg: Config object; reads ``hidden_size``, ``num_attention_heads`` and
            ``max_position_embeddings``.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg.hidden_size
        # NOTE(review): the GQA group count (2) and the MLP hidden width
        # (2 * hidden_size) are fixed here rather than read from cfg.
        gqa_groups = 2
        ffn_width = 2 * width
        # Submodules are created in the same order as before so seeded
        # initialisation is unchanged.
        self.rms1 = RMSNorm(width)
        self.rms2 = RMSNorm(width)
        self.attn = GQA(width, cfg.num_attention_heads, gqa_groups, cfg.max_position_embeddings)
        self.mlp = MLP(width, ffn_width)

    def forward(self, x):
        attended = x + self.attn(self.rms1(x))
        return attended + self.mlp(self.rms2(attended))
| |
|
class GatorModel(PreTrainedModel):
    """Decoder stack: token embedding, transformer blocks, final RMSNorm, LM head.

    The LM head shares its weight tensor with the token embedding.
    ``forward`` returns a dict with a single ``"logits"`` entry.
    """

    config_class = GatorConfig

    def __init__(self, config):
        super().__init__(config)
        width, vocab = config.hidden_size, config.vocab_size
        self.embed = nn.Embedding(vocab, width)
        self.blocks = nn.ModuleList(Block(config) for _ in range(config.num_hidden_layers))
        self.norm = RMSNorm(width)
        self.lm_head = nn.Linear(width, vocab, bias=False)
        # Tie the output projection to the input embedding; both weights are
        # (vocab, width), so the same Parameter serves both modules.
        self.lm_head.weight = self.embed.weight

    def forward(self, input_ids):
        hidden = self.embed(input_ids)
        for block in self.blocks:
            hidden = block(hidden)
        return {"logits": self.lm_head(self.norm(hidden))}
| |
|
class GatorForCausalLM(PreTrainedModel):
    """Wraps ``GatorModel`` and samples one next token per call.

    ``forward`` runs without gradient tracking and returns token ids, not
    logits.
    """

    config_class = GatorConfig
    base_model_prefix = "gator"

    def __init__(self, config):
        super().__init__(config)
        self.gator = GatorModel(config)
        self.post_init()

    @torch.no_grad()
    def forward(self, input_ids, temperature=0.8, top_k=5):
        """Sample the next token id for each sequence in ``input_ids``.

        Takes the logits at the last position, divides them by
        ``temperature``, keeps the ``top_k`` highest, and draws one token from
        the softmax over those candidates.

        Args:
            input_ids: Token ids shaped (batch, seq); a 1-D (seq,) tensor is
                treated as a batch of one.
            temperature: Logit divisor. ``0`` selects the highest-scoring
                token (greedy decoding) instead of dividing by zero.
            top_k: Number of candidates to sample from; values larger than
                the vocabulary are clamped to the vocabulary size.

        Returns:
            The sampled token id as an ``int`` when the batch has one row,
            otherwise a list with one ``int`` per row.
        """
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        logits = self.gator(input_ids)["logits"][:, -1, :]
        if temperature == 0:
            next_token = logits.argmax(dim=-1, keepdim=True)
        else:
            k = min(top_k, logits.size(-1))
            topk = torch.topk(logits / temperature, k=k, dim=-1)
            probs = torch.softmax(topk.values, dim=-1)
            # multinomial returns a position inside the top-k slice; map it
            # back to the corresponding vocabulary id.
            next_token = topk.indices.gather(-1, torch.multinomial(probs, 1))
        ids = next_token.squeeze(-1).tolist()
        # Batch of one keeps the original scalar return type.
        return ids[0] if len(ids) == 1 else ids