File size: 5,629 Bytes
6bd3c24 9f3f5bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
from transformers import PreTrainedModel
from .configuration import MBZTestConfig
from transformers.modeling_outputs import CausalLMOutput
class RotaryPositionalEncoding(nn.Module):
    """
    Rotary Position Embeddings (RoPE) - efficient implementation.

    Uses the interleaved convention: each feature pair
    ``(x[..., 2i], x[..., 2i + 1])`` is rotated by the angle ``pos * inv_freq[i]``
    with ``inv_freq[i] = base ** (-2i / d_head)``. cos/sin tables are
    precomputed for ``max_seq_len`` positions and grown automatically when a
    forward call needs positions beyond the current table.

    Args:
        d_head: per-head feature size; must be even (features are rotated in pairs).
        max_seq_len: number of positions to precompute up front.
        base: RoPE frequency base.

    Raises:
        ValueError: if ``d_head`` is odd.
    """

    def __init__(self, d_head, max_seq_len=8192, base=10000.0):
        super().__init__()
        if d_head % 2 != 0:
            # With an odd d_head the pair-repeated cos/sin tables would be one
            # column wider than q/k and fail to broadcast at forward time.
            raise ValueError(f"d_head must be even for rotary embeddings, got {d_head}")
        self.d_head = d_head
        self.max_seq_len = max_seq_len
        self.base = base
        # One inverse frequency per feature pair: shape (d_head / 2,)
        inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
        # Non-persistent: derived from hyperparameters, not stored in state_dict.
        self.register_buffer('inv_freq', inv_freq, persistent=False)
        # Precompute cos and sin for the initial table length
        self._precompute_freqs(max_seq_len)

    def _precompute_freqs(self, seq_len):
        """(Re)build the ``freqs_cos`` / ``freqs_sin`` buffers for positions 0..seq_len-1.

        Each buffer has shape (seq_len, d_head): every per-pair value is repeated
        twice so it lines up with the interleaved feature pairs.
        """
        inv_freq = self.inv_freq.float()
        # Build positions in float32 explicitly: if the module was cast to a
        # reduced-precision dtype, large integer positions would not be exact.
        t = torch.arange(seq_len, dtype=torch.float32, device=inv_freq.device)
        freqs = torch.outer(t, inv_freq)  # (seq_len, d_head / 2)
        # register_buffer replaces an existing buffer of the same name, which is
        # what lets forward() grow the tables later.
        self.register_buffer('freqs_cos', torch.cos(freqs).repeat_interleave(2, dim=-1), persistent=False)
        self.register_buffer('freqs_sin', torch.sin(freqs).repeat_interleave(2, dim=-1), persistent=False)

    def rotate_half(self, x):
        """Map each interleaved pair ``(a, b)`` in the last dim to ``(-b, a)``."""
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        return torch.stack([-x2, x1], dim=-1).flatten(-2)

    def forward(self, q, k, start_pos=0):
        """
        Apply rotary embeddings to query and key tensors.

        Args:
            q: (batch_size, n_heads, seq_len, d_head)
            k: (batch_size, n_heads, seq_len, d_head); uses q's seq_len.
            start_pos: absolute position of the first token (for caching scenarios).
        Returns:
            (q_rot, k_rot) with rotary embeddings applied, in the dtypes of q and k.
        """
        seq_len = q.shape[2]
        end = start_pos + seq_len
        table_len = self.freqs_cos.shape[0]
        if end > table_len:
            # Slicing past the table would silently return fewer rows and fail
            # later with a broadcast error; grow it (at least doubling) instead.
            new_len = max(end, 2 * table_len)
            self._precompute_freqs(new_len)
            self.max_seq_len = new_len
        freqs_cos = self.freqs_cos[start_pos:end]
        freqs_sin = self.freqs_sin[start_pos:end]
        q_rot = q * freqs_cos + self.rotate_half(q) * freqs_sin
        k_rot = k * freqs_cos + self.rotate_half(k) * freqs_sin
        # The tables are float32, which promotes bf16/fp16 inputs; cast back so
        # q and k keep the input dtype (fused attention kernels expect q, k, v
        # to share a dtype).
        return q_rot.type_as(q), k_rot.type_as(k)
class Attention(nn.Module):
    """
    Multi-head causal self-attention with rotary position embeddings.

    Projects ``x`` to per-head queries, keys and values (bias-free linears),
    applies RoPE to q and k, runs causal scaled-dot-product attention and
    projects the concatenated heads back to ``d_model``.

    Args:
        d_model: input/output feature size.
        n_heads: number of attention heads.
        d_head: per-head feature size.
    """

    def __init__(self, d_model, n_heads, d_head):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_head
        self.Wq = nn.Linear(d_model, n_heads * d_head, bias=False)
        self.Wk = nn.Linear(d_model, n_heads * d_head, bias=False)
        self.Wv = nn.Linear(d_model, n_heads * d_head, bias=False)
        self.Wo = nn.Linear(n_heads * d_head, d_model, bias=False)
        # Initialize RoPE
        self.rope = RotaryPositionalEncoding(d_head)

    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len, d_model)
        Returns:
            (batch_size, seq_len, d_model)
        """
        batch_size, seq_len, _ = x.shape
        q = self.Wq(x)  # (batch_size, seq_len, n_heads * d_head)
        k = self.Wk(x)
        v = self.Wv(x)
        # reshape to (batch_size, n_heads, seq_len, d_head)
        q = q.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        q, k = self.rope(q, k)
        # Flash attention stays enabled and is tried first when the inputs
        # qualify. Forcing it alone made the layer raise on CPU, in float32, or
        # on GPUs/shapes flash does not support, so the memory-efficient and
        # math kernels are allowed as fallbacks.
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
            a = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
        # a: (batch_size, n_heads, seq_len, d_head) -> (batch_size, seq_len, n_heads * d_head)
        a = a.transpose(1, 2).reshape(batch_size, seq_len, self.n_heads * self.d_head)
        return self.Wo(a)  # (batch_size, seq_len, d_model)
class TransformerBlock(nn.Module):
    """
    Pre-norm transformer layer.

    Computes ``h = x + attn(norm1(x))`` and then ``h + mlp(norm2(h))``. The MLP
    widens to ``4 * d_model`` with a ReLU in between.

    Args:
        d_model: residual stream width.
        n_heads: number of attention heads.
        d_head: per-head feature size.
    """

    def __init__(self, d_model, n_heads, d_head):
        super().__init__()
        self.d_model, self.n_heads, self.d_head = d_model, n_heads, d_head
        hidden_size = 4 * d_model
        # Submodules are created in this exact order: it fixes parameter
        # registration order and the order in which random init consumes the RNG.
        self.attn = Attention(d_model, n_heads, d_head)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, d_model),
        )
        self.norm1 = nn.RMSNorm(d_model)
        self.norm2 = nn.RMSNorm(d_model)

    def forward(self, x):
        # Attention sub-layer with residual connection
        attn_out = self.attn(self.norm1(x))
        hidden = attn_out + x
        # Feed-forward sub-layer with residual connection
        ffn_out = self.mlp(self.norm2(hidden))
        return ffn_out + hidden
class MBZTestModelForCausalLM(PreTrainedModel):
    """
    Decoder-only causal language model.

    Token embedding -> ``config.n_layers`` pre-norm ``TransformerBlock``s ->
    final RMSNorm -> linear projection to ``config.n_vocab`` logits.
    """

    config_class = MBZTestConfig

    def __init__(self, config):
        super().__init__(config)
        d_model = config.d_model
        n_heads = config.n_heads
        d_head = config.d_head
        n_vocab = config.n_vocab
        n_layers = config.n_layers
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_head
        self.n_vocab = n_vocab
        # Creation order below fixes parameter order and random-init order.
        self.embed = nn.Embedding(n_vocab, d_model)
        self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, d_head) for _ in range(n_layers)])
        self.norm = nn.RMSNorm(d_model)
        self.out_head = nn.Linear(d_model, n_vocab)

    def forward(self, x=None, labels=None, input_ids=None, attention_mask=None):
        """
        Compute next-token logits and, when ``labels`` is given, the LM loss.

        Args:
            x: (batch_size, seq_len) token ids; still the first positional
                argument, so existing ``model(ids)`` calls keep working.
            labels: optional (batch_size, seq_len) targets aligned with the
                inputs. They are shifted internally so position t is scored
                against token t + 1; entries equal to -100 are ignored.
            input_ids: keyword alias for ``x``, the name passed by
                ``model(**batch)``-style callers.
            attention_mask: accepted so ``model(**batch)`` does not fail, but it
                is NOT applied. Attention uses only a causal mask, so batches
                must be right-padded (and pad labels set to -100).
        Returns:
            CausalLMOutput with ``logits`` (batch_size, seq_len, n_vocab) and
            ``loss`` (scalar tensor) when ``labels`` is provided, otherwise None.
        Raises:
            ValueError: if neither or both of ``x`` and ``input_ids`` are given.
        """
        if x is None and input_ids is None:
            raise ValueError("forward() needs token ids via `x` or `input_ids`")
        if x is not None and input_ids is not None:
            raise ValueError("pass token ids as `x` or `input_ids`, not both")
        tokens = x if x is not None else input_ids

        with torch.autocast('cuda', dtype=torch.bfloat16):
            hidden = self.embed(tokens)
            for block in self.blocks:
                hidden = block(hidden)
            logits = self.out_head(self.norm(hidden))

        loss = None
        if labels is not None:
            # Upcast before the loss: under CUDA autocast the logits come out of
            # the bf16 linear layer in reduced precision.
            shift_logits = logits[:, :-1, :].float()
            shift_labels = labels[:, 1:].to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutput(loss=loss, logits=logits)
|