# ChatDBS / modeling_chatdbs1.py
import torch
import torch.nn as nn
import math
# ----------------------------
# Config class
# ----------------------------
class ChatDBS1Config:
    """Hyper-parameter container for the ChatDBS1 model.

    Attributes mirror the constructor arguments: embedding width (n_embd),
    number of transformer layers (n_layer), attention heads (n_head),
    tokenizer vocabulary size and maximum sequence length (block_size).
    """

    def __init__(self, n_embd=512, n_layer=12, n_head=8, vocab_size=50257, block_size=1024):
        # Store every hyper-parameter as a same-named attribute.
        self.n_embd, self.n_layer, self.n_head = n_embd, n_layer, n_head
        self.vocab_size, self.block_size = vocab_size, block_size
# ----------------------------
# Rotary Positional Embeddings (RoPE)
# ----------------------------
def rotate_half(x):
    """Rotate the two halves of the last dimension: (a, b) -> (-b, a)."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
def apply_rotary_pos_emb(q, k, seq_len):
    """Apply rotary positional embeddings (RoPE) to query and key tensors.

    Args:
        q, k: tensors of shape (..., seq_len, head_dim); head_dim must be even.
        seq_len: number of positions (must equal q/k's second-to-last dim).

    Returns:
        (q_rot, k_rot): the rotated tensors, same shapes as the inputs.

    Bug fixed: the original mixed the interleaved even/odd convention
    (q[..., ::2]) with the half-split rotate_half convention, and built the
    second half of the output by rotating the half-dimensional product, so
    the missing component q_even*sin + q_odd*cos was never computed and the
    result was not a valid rotation.  This is the standard half-split
    formulation: x * cos + rotate_half(x) * sin.
    """
    head_dim = q.size(-1)
    # Inverse frequencies, one per rotation pair: 10000^(-2i/head_dim).
    freqs = 10000 ** (-torch.arange(0, head_dim, 2, device=q.device).float() / head_dim)
    t = torch.arange(seq_len, device=q.device).float()
    angles = t[:, None] * freqs[None, :]          # (seq_len, head_dim/2)
    # Duplicate angles so cos/sin span the full head_dim (half-split layout).
    emb = torch.cat((angles, angles), dim=-1)     # (seq_len, head_dim)
    cos, sin = emb.cos(), emb.sin()

    def _rotate_half(x):
        # (a, b) -> (-b, a) on the last dimension.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    q_rot = q * cos + _rotate_half(q) * sin
    k_rot = k * cos + _rotate_half(k) * sin
    return q_rot, k_rot
# ----------------------------
# GPT-style attention block with RoPE
# ----------------------------
class GPTBlock(nn.Module):
    """Pre-norm transformer block: causal self-attention with RoPE, then MLP.

    Expects input of shape (seq_len, batch, n_embd) and returns the same shape.
    Submodule names (qkv, ln1, mlp, ln2) are unchanged, so existing state
    dicts still load.
    """

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        # Fused projection producing Q, K and V in one matmul.
        self.qkv = nn.Linear(config.n_embd, config.n_embd * 3)
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd)
        )
        self.ln2 = nn.LayerNorm(config.n_embd)

    def forward(self, x):
        # x: (seq_len, batch, n_embd)
        x_norm = self.ln1(x)
        N, B, E = x_norm.shape  # seq_len, batch, embd
        qkv = self.qkv(x_norm).reshape(N, B, 3, self.n_head, self.head_dim).permute(2, 3, 1, 0, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each: (head, batch, seq_len, head_dim)
        # Apply RoPE to queries and keys.
        q, k = apply_rotary_pos_emb(q, k, N)
        # Scaled dot-product attention.
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Bug fixed: the original had no causal mask, so every position could
        # attend to future tokens -- wrong for a GPT-style autoregressive LM.
        # Mask out the strictly-upper-triangular (future) entries.
        causal_mask = torch.triu(
            torch.ones(N, N, dtype=torch.bool, device=x.device), diagonal=1
        )
        attn_scores = attn_scores.masked_fill(causal_mask, float("-inf"))
        attn_probs = torch.softmax(attn_scores, dim=-1)
        attn_out = torch.matmul(attn_probs, v)
        attn_out = attn_out.permute(2, 1, 0, 3).reshape(N, B, E)
        # Residual connections around attention and MLP.
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x
# ----------------------------
# GPT-style model with RoPE
# ----------------------------
class ChatDBS1Model(nn.Module):
    """GPT-style language model: embeddings -> GPTBlock stack -> LM head.

    NOTE(review): learned absolute position embeddings are added here even
    though the blocks also apply RoPE — likely redundant; confirm intent.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_embeddings = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embeddings = nn.Embedding(config.block_size, config.n_embd)
        self.blocks = nn.ModuleList(GPTBlock(config) for _ in range(config.n_layer))
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, input_ids):
        """Map token ids (batch, seq_len) to logits (batch, seq_len, vocab_size)."""
        seq_len = input_ids.size(1)
        pos_ids = torch.arange(seq_len, device=input_ids.device)[None, :]
        hidden = self.token_embeddings(input_ids) + self.position_embeddings(pos_ids)
        hidden = hidden.transpose(0, 1)  # blocks expect (seq_len, batch, embd)
        for layer in self.blocks:
            hidden = layer(hidden)
        hidden = self.ln_f(hidden.transpose(0, 1))  # back to (batch, seq_len, embd)
        return self.lm_head(hidden)