import torch
import torch.nn as nn
import math

# ----------------------------
# Config class
# ----------------------------
class ChatDBS1Config:
    def __init__(self, n_embd=512, n_layer=12, n_head=8, vocab_size=50257, block_size=1024):
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.vocab_size = vocab_size
        self.block_size = block_size

# ----------------------------
# Rotary Positional Embeddings (RoPE)
# ----------------------------
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, seq_len):
    # q, k: head x batch x seq_len x head_dim
    head_dim = q.size(-1)
    # Inverse frequencies: 10000^(-2i/d) for i in [0, d/2)
    freqs = 10000 ** (-torch.arange(0, head_dim, 2, device=q.device).float() / head_dim)
    t = torch.arange(seq_len, device=q.device).float()
    angles = t[:, None] * freqs[None, :]  # seq_len x head_dim/2
    # Duplicate the angles so cos/sin span the full head_dim, matching the
    # (x1, x2) halves that rotate_half swaps.
    cos = torch.cat([angles.cos(), angles.cos()], dim=-1)[None, :, :]
    sin = torch.cat([angles.sin(), angles.sin()], dim=-1)[None, :, :]
    # Standard rotate_half formulation: x' = x*cos + rotate_half(x)*sin
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot

# ----------------------------
# GPT-style attention block with RoPE
# ----------------------------
class GPTBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.qkv = nn.Linear(config.n_embd, config.n_embd * 3)
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd)
        )
        self.ln2 = nn.LayerNorm(config.n_embd)

    def forward(self, x):
        # x: seq_len x batch x embd
        x_norm = self.ln1(x)
        N, B, E = x_norm.shape  # seq_len, batch, embd
        qkv = self.qkv(x_norm).reshape(N, B, 3, self.n_head, self.head_dim).permute(2, 3, 1, 0, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each: head x batch x seq_len x head_dim

        # Apply RoPE
        seq_len = x.size(0)
        q, k = apply_rotary_pos_emb(q, k, seq_len)

        # Scaled dot-product attention with a causal mask (GPT-style decoder):
        # each position may only attend to itself and earlier positions
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1
        )
        attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))
        attn_probs = torch.softmax(attn_scores, dim=-1)
        attn_out = torch.matmul(attn_probs, v)
        attn_out = attn_out.permute(2, 1, 0, 3).reshape(N, B, E)
        x = x + attn_out
        x_norm = self.ln2(x)
        x = x + self.mlp(x_norm)
        return x

# ----------------------------
# GPT-style model with RoPE
# ----------------------------
class ChatDBS1Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_embeddings = nn.Embedding(config.vocab_size, config.n_embd)
        # Note: learned absolute position embeddings are added here even though
        # the blocks already apply RoPE; models that rely on RoPE alone omit this.
        self.position_embeddings = nn.Embedding(config.block_size, config.n_embd)
        self.blocks = nn.ModuleList([GPTBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, input_ids):
        batch_size, seq_len = input_ids.size()
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.token_embeddings(input_ids) + self.position_embeddings(positions)
        x = x.transpose(0,1)  # seq_len x batch x embd
        for block in self.blocks:
            x = block(x)
        x = x.transpose(0,1)  # batch x seq_len x embd
        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits
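
# ----------------------------
# Usage sketch (illustrative only)
# ----------------------------
# A minimal smoke test of the model above. The small config values here are
# arbitrary choices for a quick check, not tuned hyperparameters; the expected
# output shape is batch x seq_len x vocab_size.
if __name__ == "__main__":
    config = ChatDBS1Config(n_embd=64, n_layer=2, n_head=4, vocab_size=100, block_size=32)
    model = ChatDBS1Model(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))  # batch=2, seq_len=16
    logits = model(input_ids)
    print(logits.shape)  # torch.Size([2, 16, 100])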