import math

import torch
import torch.nn as nn


# ----------------------------
# Config class
# ----------------------------
class ChatDBS1Config:
    """Hyperparameters for the ChatDBS1 GPT-style model."""

    def __init__(self, n_embd=512, n_layer=12, n_head=8, vocab_size=50257, block_size=1024):
        self.n_embd = n_embd          # model width / embedding size
        self.n_layer = n_layer        # number of transformer blocks
        self.n_head = n_head          # attention heads per block
        self.vocab_size = vocab_size  # token vocabulary size
        self.block_size = block_size  # maximum sequence length


# ----------------------------
# Rotary Positional Embeddings (RoPE)
# ----------------------------
def rotate_half(x):
    """Rotate the last dim by half: (x1, x2) -> (-x2, x1), half-split convention."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, seq_len):
    """Apply rotary positional embeddings to queries and keys.

    q, k: tensors of shape (..., seq_len, head_dim); head_dim must be even.
    Returns (q_rot, k_rot) with the same shapes.

    BUG FIX: the original computed only the "even" rotated component
    (q_even*cos - q_odd*sin) and then concatenated rotate_half() of that
    partial result, which scrambles the even/odd channel pairing. The
    correct GPT-NeoX-style form is x*cos + rotate_half(x)*sin with the
    angle table repeated across both halves of head_dim. A consequence of
    the correct form: rotation at position 0 is the identity.
    """
    head_dim = q.size(-1)
    inv_freq = 10000 ** (-torch.arange(0, head_dim, 2, device=q.device).float() / head_dim)
    t = torch.arange(seq_len, device=q.device).float()
    angles = t[:, None] * inv_freq[None, :]    # (seq_len, head_dim/2)
    emb = torch.cat((angles, angles), dim=-1)  # (seq_len, head_dim), matches rotate_half's split
    cos = emb.cos()
    sin = emb.sin()
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot


# ----------------------------
# GPT-style attention block with RoPE
# ----------------------------
class GPTBlock(nn.Module):
    """Pre-norm transformer block: causal self-attention (with RoPE) + MLP.

    NOTE(review): there is no output projection after attention, which
    deviates from the standard GPT block; kept as-is to preserve the
    existing parameter set / state_dict layout.
    """

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.qkv = nn.Linear(config.n_embd, config.n_embd * 3)  # fused Q/K/V projection
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
        )
        self.ln2 = nn.LayerNorm(config.n_embd)

    def forward(self, x):
        # x: (seq_len, batch, n_embd)
        x_norm = self.ln1(x)
        N, B, E = x_norm.shape  # N = seq_len, B = batch
        # (N, B, 3, H, Dh) -> (3, H, B, N, Dh)
        qkv = self.qkv(x_norm).reshape(N, B, 3, self.n_head, self.head_dim).permute(2, 3, 1, 0, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each: (H, B, N, Dh)

        # Apply RoPE to queries and keys.
        q, k = apply_rotary_pos_emb(q, k, N)

        # Scaled dot-product attention.
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # BUG FIX: the original applied no causal mask, letting every
        # position attend to future tokens -- wrong for an autoregressive
        # decoder-only LM. Mask strictly-upper-triangular entries with -inf
        # so softmax assigns them zero probability.
        causal_mask = torch.triu(
            torch.full((N, N), float("-inf"), device=x.device), diagonal=1
        )
        attn_scores = attn_scores + causal_mask
        attn_probs = torch.softmax(attn_scores, dim=-1)
        attn_out = torch.matmul(attn_probs, v)                   # (H, B, N, Dh)
        attn_out = attn_out.permute(2, 1, 0, 3).reshape(N, B, E)  # back to (N, B, E)

        x = x + attn_out               # residual around attention
        x = x + self.mlp(self.ln2(x))  # residual around MLP
        return x


# ----------------------------
# GPT-style model with RoPE
# ----------------------------
class ChatDBS1Model(nn.Module):
    """Decoder-only LM: embeddings -> n_layer GPTBlocks -> final LN -> LM head.

    NOTE(review): learned absolute position embeddings are kept alongside
    RoPE. This is redundant (RoPE already encodes position) but preserved
    for interface and state_dict compatibility.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_embeddings = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embeddings = nn.Embedding(config.block_size, config.n_embd)
        self.blocks = nn.ModuleList([GPTBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, input_ids):
        """Map (batch, seq_len) int token ids to (batch, seq_len, vocab_size) logits."""
        batch_size, seq_len = input_ids.size()
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.token_embeddings(input_ids) + self.position_embeddings(positions)
        x = x.transpose(0, 1)  # blocks expect (seq_len, batch, n_embd)
        for block in self.blocks:
            x = block(x)
        x = x.transpose(0, 1)  # back to (batch, seq_len, n_embd)
        x = self.ln_f(x)
        return self.lm_head(x)