"""
VicAI Model Architecture
A 5B parameter decoder-only transformer language model.
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization."""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
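    # Scales activations by the reciprocal root-mean-square of the features
    # (no mean-centering, unlike LayerNorm) and a learned per-dimension weight.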
def forward(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
class RotaryPositionalEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE)."""
def __init__(self, dim: int, max_seq_len: int = 8192, base: float = 10000.0):
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
        # Precompute angles for every position up to max_seq_len (float to match
        # inv_freq, since einsum does not mix integer and floating dtypes).
        t = torch.arange(max_seq_len, dtype=inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :])
self.register_buffer("sin_cached", emb.sin()[None, None, :, :])
def rotate_half(self, x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
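    # q * cos + rotate_half(q) * sin applies a 2D rotation by the position-dependent
    # angle m * theta_i to each (x_i, x_{i + dim/2}) pair -- the standard RoPE
    # formulation in the "rotate half" layout used above.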
def apply_rotary_pos_emb(self, q, k, cos, sin):
q_embed = (q * cos) + (self.rotate_half(q) * sin)
k_embed = (k * cos) + (self.rotate_half(k) * sin)
return q_embed, k_embed
    def forward(self, q, k, seq_len: int, offset: int = 0):
        # Slice the cached tables at the absolute positions of the current tokens,
        # so tokens decoded with a KV cache keep their true rotary phases.
        cos = self.cos_cached[:, :, offset:offset + seq_len, :]
        sin = self.sin_cached[:, :, offset:offset + seq_len, :]
        return self.apply_rotary_pos_emb(q, k, cos, sin)
class GroupedQueryAttention(nn.Module):
"""Grouped Query Attention (GQA) for efficient inference."""
def __init__(
self,
dim: int,
n_heads: int,
n_kv_heads: int,
        dropout: float = 0.0,
        max_seq_len: int = 8192,
):
super().__init__()
self.dim = dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = dim // n_heads
self.n_rep = n_heads // n_kv_heads
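        # Each key/value head is shared by n_rep query heads (n_heads must be a
        # multiple of n_kv_heads).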
self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)
self.attn_dropout = nn.Dropout(dropout)
self.resid_dropout = nn.Dropout(dropout)
        self.rope = RotaryPositionalEmbedding(self.head_dim, max_seq_len=max_seq_len)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
bsz, seq_len, _ = x.shape
q = self.wq(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
k = self.wk(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
v = self.wv(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Rotate q/k at their absolute positions: with a KV cache, the new tokens
        # start right after the cached ones.
        past_len = past_key_value[0].shape[2] if past_key_value is not None else 0
        q, k = self.rope(q, k, seq_len, offset=past_len)
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        past_key_value = (k, v)
# Repeat k/v for grouped query attention
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
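        # k and v now have shape (bsz, n_heads, kv_len, head_dim), matching q.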
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask
attn = F.softmax(scores, dim=-1)
attn = self.attn_dropout(attn)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).contiguous().view(bsz, seq_len, self.dim)
out = self.wo(out)
out = self.resid_dropout(out)
return out, past_key_value
class FeedForward(nn.Module):
"""SwiGLU Feed-Forward Network."""
def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0):
super().__init__()
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.dropout = nn.Dropout(dropout)
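    # SwiGLU: the SiLU-activated w1 branch gates the w3 branch elementwise
    # before the w2 down-projection.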
    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class TransformerBlock(nn.Module):
"""Single transformer block with pre-normalization."""
def __init__(
self,
dim: int,
n_heads: int,
n_kv_heads: int,
hidden_dim: int,
        dropout: float = 0.0,
        max_seq_len: int = 8192,
):
super().__init__()
self.attention_norm = RMSNorm(dim)
        self.attention = GroupedQueryAttention(dim, n_heads, n_kv_heads, dropout, max_seq_len)
self.ffn_norm = RMSNorm(dim)
self.feed_forward = FeedForward(dim, hidden_dim, dropout)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
# Attention with residual
attn_out, past_key_value = self.attention(
self.attention_norm(x), mask, past_key_value
)
x = x + attn_out
# FFN with residual
x = x + self.feed_forward(self.ffn_norm(x))
return x, past_key_value
class VicAIConfig:
"""Configuration for VicAI model."""
def __init__(
self,
vocab_size: int = 32000,
dim: int = 4096,
n_layers: int = 32,
n_heads: int = 32,
n_kv_heads: int = 8,
hidden_dim: int = 14336,
max_seq_len: int = 8192,
dropout: float = 0.0,
tie_weights: bool = False,
):
self.vocab_size = vocab_size
self.dim = dim
self.n_layers = n_layers
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.hidden_dim = hidden_dim
self.max_seq_len = max_seq_len
self.dropout = dropout
self.tie_weights = tie_weights
@property
def num_parameters(self):
"""Calculate approximate parameter count."""
        head_dim = self.dim // self.n_heads
        # Token embedding
        params = self.vocab_size * self.dim
        # Attention per layer: q/o project dim -> dim, k/v project dim -> n_kv_heads * head_dim (GQA)
        attn_params = 2 * self.dim * self.dim + 2 * self.dim * self.n_kv_heads * head_dim
        # FFN per layer (w1, w2, w3)
        ffn_params = 3 * self.dim * self.hidden_dim
        # Layers
        params += self.n_layers * (attn_params + ffn_params)
        # Output head (shares the embedding matrix when tie_weights is set)
        if not self.tie_weights:
            params += self.vocab_size * self.dim
return params
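# Back-of-the-envelope check for the default configuration (rough estimate; norm
# weights are ignored):
#   embeddings:          32000 * 4096                          ~= 0.13B
#   attention / layer:   2*4096*4096 + 2*4096*(8*128)          ~= 0.04B
#   FFN / layer:         3 * 4096 * 14336                      ~= 0.18B
#   total:               0.13B + 32*(0.04B + 0.18B) + 0.13B    ~= 7.2B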
class VicAIModel(nn.Module):
"""
VicAI: A 5B parameter decoder-only transformer language model.
Architecture details:
- 32 layers
- 4096 model dimension
- 32 attention heads (8 key-value heads for GQA)
- SwiGLU FFN with 14336 hidden dimension
- RoPE positional embeddings
- RMSNorm pre-normalization
    - ~7.2B total parameters with the default configuration
"""
def __init__(self, config: VicAIConfig):
super().__init__()
self.config = config
self.token_embedding = nn.Embedding(config.vocab_size, config.dim)
self.dropout = nn.Dropout(config.dropout)
self.layers = nn.ModuleList([
TransformerBlock(
config.dim,
config.n_heads,
config.n_kv_heads,
config.hidden_dim,
                config.dropout,
                max_seq_len=config.max_seq_len,
)
for _ in range(config.n_layers)
])
self.norm = RMSNorm(config.dim)
self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
if config.tie_weights:
self.lm_head.weight = self.token_embedding.weight
self.apply(self._init_weights)
# Print model info
total_params = self.get_num_params()
print(f"VicAI Model initialized with {total_params / 1e9:.2f}B parameters")
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def get_num_params(self, non_embedding=True):
n_params = sum(p.numel() for p in self.parameters())
if non_embedding:
n_params -= self.token_embedding.weight.numel()
return n_params
def forward(
self,
input_ids: torch.Tensor,
targets: Optional[torch.Tensor] = None,
past_key_values: Optional[list] = None,
):
bsz, seq_len = input_ids.shape
        # Additive causal mask over (query, key) positions: cached positions stay
        # visible, and positions after the query (within the new chunk) get -inf.
        past_len = past_key_values[0][0].shape[2] if past_key_values is not None else 0
        mask = torch.full(
            (seq_len, past_len + seq_len), float("-inf"), device=input_ids.device
        )
        mask = torch.triu(mask, diagonal=past_len + 1)
        mask = mask[None, None, :, :]
x = self.token_embedding(input_ids)
x = self.dropout(x)
new_key_values = []
for i, layer in enumerate(self.layers):
past_kv = past_key_values[i] if past_key_values is not None else None
x, kv = layer(x, mask, past_kv)
new_key_values.append(kv)
x = self.norm(x)
logits = self.lm_head(x)
loss = None
if targets is not None:
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
ignore_index=-100
)
return {
'logits': logits,
'loss': loss,
'past_key_values': new_key_values,
}
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 100,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 0.9,
repetition_penalty: float = 1.0,
eos_token_id: Optional[int] = None,
):
"""Generate text autoregressively."""
self.eval()
batch_size = input_ids.shape[0]
device = input_ids.device
past_key_values = None
        for _ in range(max_new_tokens):
            # After the first step the KV cache holds the prompt, so only the newest
            # token needs to be fed through the model.
            if past_key_values is None:
                model_input = input_ids
            else:
                model_input = input_ids[:, -1:]
            outputs = self(model_input, past_key_values=past_key_values)
            logits = outputs['logits']
            past_key_values = outputs['past_key_values']
# Get logits for last token
logits = logits[:, -1, :] / temperature
            # Apply repetition penalty: divide positive logits and multiply negative
            # ones so that previously seen tokens always become less likely.
            if repetition_penalty != 1.0:
                for i in range(batch_size):
                    for token_id in set(input_ids[i].tolist()):
                        if logits[i, token_id] > 0:
                            logits[i, token_id] /= repetition_penalty
                        else:
                            logits[i, token_id] *= repetition_penalty
# Top-k filtering
            if top_k > 0:
                top_k = min(top_k, logits.size(-1))
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = float('-inf')
# Top-p (nucleus) filtering
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
logits[indices_to_remove] = float('-inf')
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
input_ids = torch.cat([input_ids, next_token], dim=1)
# Early stopping if EOS token generated
if eos_token_id is not None and (next_token == eos_token_id).all():
break
return input_ids
def create_vicai_5b(vocab_size: int = 32000) -> VicAIModel:
"""Create a 5B parameter VicAI model."""
config = VicAIConfig(
vocab_size=vocab_size,
dim=4096,
n_layers=32,
n_heads=32,
n_kv_heads=8,
hidden_dim=14336,
max_seq_len=8192,
dropout=0.0,
)
return VicAIModel(config)
if __name__ == "__main__":
# Test model creation
model = create_vicai_5b()
print(f"Total parameters: {model.get_num_params() / 1e9:.2f}B")
# Test forward pass
x = torch.randint(0, 32000, (2, 128))
outputs = model(x)
print(f"Output shape: {outputs['logits'].shape}")
print(f"Loss: {outputs['loss']}")