| | import torch |
| | import torch.nn as nn |
| | import math |
| |
|
| | |
| | |
| | |
class ChatDBS1Config:
    """Hyperparameters for the ChatDBS1 transformer.

    Args:
        n_embd: embedding width; must be divisible by ``n_head``.
        n_layer: number of transformer blocks.
        n_head: number of attention heads.
        vocab_size: size of the token vocabulary.
        block_size: maximum sequence length supported by the learned
            position-embedding table.
    """

    def __init__(self, n_embd=512, n_layer=12, n_head=8, vocab_size=50257, block_size=1024):
        # Guard against a silent shape mismatch downstream: GPTBlock computes
        # head_dim = n_embd // n_head with floor division, so a non-divisible
        # pair would truncate and break the fused qkv reshape at runtime.
        if n_embd % n_head != 0:
            raise ValueError(
                f"n_embd ({n_embd}) must be divisible by n_head ({n_head})"
            )
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.vocab_size = vocab_size
        self.block_size = block_size
| |
|
| | |
| | |
| | |
def rotate_half(x):
    """Return *x* with the two halves of its last dim swapped and the
    second half negated: (a, b) -> (-b, a).

    Mirrors ``torch.chunk`` semantics: for an odd last dimension the
    front half is the larger one.
    """
    front_size = (x.size(-1) + 1) // 2
    front = x[..., :front_size]
    back = x[..., front_size:]
    return torch.cat((-back, front), dim=-1)
| |
|
def apply_rotary_pos_emb(q, k, seq_len):
    """Apply rotary position embeddings (RoPE) to query and key tensors.

    Each consecutive (even, odd) pair along the last dimension is treated as
    a 2-D point and rotated by a position- and frequency-dependent angle, so
    that the q.k dot product depends only on relative positions.

    Bug fixed: the previous version computed only the real component
    ``even*cos - odd*sin`` and then synthesized the second half by applying
    ``rotate_half`` to the *already rotated* tensor. That is not a rotation:
    it drops the ``even*sin + odd*cos`` component, changes vector norms, and
    position 0 is not the identity.

    Args:
        q, k: tensors with an even-sized last (head) dimension and the
            sequence dimension second-to-last, of length ``seq_len``.
        seq_len: number of positions to build rotation angles for.

    Returns:
        (q_rot, k_rot): rotated tensors, same shapes as the inputs.
    """
    dim = q.size(-1)
    # Standard RoPE inverse frequencies: 10000^(-2i/dim) for pair index i.
    freqs = 10000 ** (-torch.arange(0, dim, 2, device=q.device).float() / dim)
    t = torch.arange(seq_len, device=q.device).float()
    angles = t[:, None] * freqs[None, :]   # (seq_len, dim // 2)
    cos = angles.cos()[None, :, :]         # broadcast over leading dims
    sin = angles.sin()[None, :, :]

    def _rotate(x):
        # Proper 2-D rotation of each (even, odd) pair.
        x_even, x_odd = x[..., ::2], x[..., 1::2]
        rot_even = x_even * cos - x_odd * sin
        rot_odd = x_even * sin + x_odd * cos
        # Re-interleave the rotated pairs back into the last dimension.
        return torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)

    return _rotate(q), _rotate(k)
| |
|
| | |
| | |
| | |
class GPTBlock(nn.Module):
    """Pre-norm transformer decoder block: causal multi-head self-attention
    followed by a 4x GELU MLP, each wrapped in a residual connection.

    Operates sequence-first: inputs and outputs are (seq_len, batch, n_embd).
    """

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        # Fused projection producing queries, keys and values in one matmul.
        self.qkv = nn.Linear(config.n_embd, config.n_embd * 3)
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
        )
        self.ln2 = nn.LayerNorm(config.n_embd)

    def forward(self, x):
        """Apply the block to x of shape (seq_len, batch, n_embd)."""
        x_norm = self.ln1(x)
        # Input is sequence-first: (N=seq, B=batch, E=embd).
        N, B, E = x_norm.shape
        qkv = self.qkv(x_norm).reshape(N, B, 3, self.n_head, self.head_dim).permute(2, 3, 1, 0, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (n_head, B, N, head_dim)

        seq_len = x.size(0)
        q, k = apply_rotary_pos_emb(q, k, seq_len)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Causal mask: position i may attend only to positions <= i. Bug
        # fixed: the previous version had no mask, so an autoregressive
        # (GPT-style) model leaked information from future tokens.
        future = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        attn_scores = attn_scores.masked_fill(future, float("-inf"))
        attn_probs = torch.softmax(attn_scores, dim=-1)
        attn_out = torch.matmul(attn_probs, v)
        # (n_head, B, N, head_dim) -> (N, B, E)
        attn_out = attn_out.permute(2, 1, 0, 3).reshape(N, B, E)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x
| |
|
| | |
| | |
| | |
class ChatDBS1Model(nn.Module):
    """GPT-style decoder-only language model.

    Token + learned position embeddings feed a stack of ``GPTBlock``s
    (which run sequence-first), then a final LayerNorm and an untied
    linear head over the vocabulary.

    NOTE(review): learned absolute position embeddings are combined with
    RoPE applied inside the blocks — likely redundant; confirm intent.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_embeddings = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embeddings = nn.Embedding(config.block_size, config.n_embd)
        self.blocks = nn.ModuleList([GPTBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, input_ids):
        """Map token ids (batch, seq_len) to logits (batch, seq_len, vocab_size)."""
        batch_size, seq_len = input_ids.size()
        # Fail with a clear message instead of an opaque embedding IndexError
        # when the prompt is longer than the position table supports.
        if seq_len > self.config.block_size:
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {self.config.block_size}"
            )
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.token_embeddings(input_ids) + self.position_embeddings(positions)
        # Blocks operate sequence-first: (seq_len, batch, n_embd).
        x = x.transpose(0, 1)
        for block in self.blocks:
            x = block(x)
        x = x.transpose(0, 1)
        x = self.ln_f(x)
        return self.lm_head(x)