| """ |
| VicAI Model Architecture |
A decoder-only transformer language model (~7B parameters with the default configuration).
| """ |
|
|
| import math |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| class RMSNorm(nn.Module): |
| """Root Mean Square Layer Normalization.""" |
| |
| def __init__(self, dim: int, eps: float = 1e-6): |
| super().__init__() |
| self.eps = eps |
| self.weight = nn.Parameter(torch.ones(dim)) |
| |
| def forward(self, x): |
| return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight |
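
# A minimal usage sketch (the sizes below are illustrative assumptions, not part
# of the model): RMSNorm rescales each feature vector to unit root-mean-square
# and applies a learned per-feature gain, so the output shape matches the input.
#
#     norm = RMSNorm(dim=8)
#     x = torch.randn(2, 4, 8)
#     y = norm(x)                  # y.shape == (2, 4, 8)
#     # (y / norm.weight) has RMS ~ 1 along the last dimension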
|
|
|
|
| class RotaryPositionalEmbedding(nn.Module): |
| """Rotary Position Embedding (RoPE).""" |
| |
| def __init__(self, dim: int, max_seq_len: int = 8192, base: float = 10000.0): |
| super().__init__() |
| self.dim = dim |
| self.max_seq_len = max_seq_len |
| self.base = base |
| |
| inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float() / dim)) |
| self.register_buffer("inv_freq", inv_freq) |
| |
        # Precompute the cos/sin tables for all positions (float positions so the
        # einsum with the float inverse frequencies is well-typed).
        t = torch.arange(max_seq_len, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :])
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :])
| |
| def rotate_half(self, x): |
| x1, x2 = x.chunk(2, dim=-1) |
| return torch.cat((-x2, x1), dim=-1) |
| |
| def apply_rotary_pos_emb(self, q, k, cos, sin): |
| q_embed = (q * cos) + (self.rotate_half(q) * sin) |
| k_embed = (k * cos) + (self.rotate_half(k) * sin) |
| return q_embed, k_embed |
| |
    def forward(self, q, k, seq_len: int, offset: int = 0):
        # Slice the cached tables at the absolute positions of the current tokens
        # (offset > 0 during incremental decoding with a KV cache).
        cos = self.cos_cached[:, :, offset:offset + seq_len, :]
        sin = self.sin_cached[:, :, offset:offset + seq_len, :]
        return self.apply_rotary_pos_emb(q, k, cos, sin)
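
# A minimal usage sketch (shapes are illustrative assumptions): RoPE rotates query
# and key vectors by position-dependent angles, so relative position ends up
# encoded in the q.k dot products; no parameters are learned here.
#
#     rope = RotaryPositionalEmbedding(dim=64, max_seq_len=128)
#     q = torch.randn(1, 4, 16, 64)    # (batch, heads, seq, head_dim)
#     k = torch.randn(1, 4, 16, 64)
#     q_rot, k_rot = rope(q, k, seq_len=16)   # prefill: positions 0..15
#     # During cached decoding, pass offset=<number of cached tokens> so a new
#     # token is rotated at its absolute position instead of position 0.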
|
|
|
|
| class GroupedQueryAttention(nn.Module): |
| """Grouped Query Attention (GQA) for efficient inference.""" |
| |
    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float = 0.0,
        max_seq_len: int = 8192,
    ):
| super().__init__() |
| self.dim = dim |
| self.n_heads = n_heads |
| self.n_kv_heads = n_kv_heads |
| self.head_dim = dim // n_heads |
| self.n_rep = n_heads // n_kv_heads |
| |
| self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False) |
| self.wk = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False) |
| self.wv = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False) |
| self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False) |
| |
| self.attn_dropout = nn.Dropout(dropout) |
| self.resid_dropout = nn.Dropout(dropout) |
| |
        self.rope = RotaryPositionalEmbedding(self.head_dim, max_seq_len=max_seq_len)
| |
| def forward( |
| self, |
| x: torch.Tensor, |
| mask: Optional[torch.Tensor] = None, |
| past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| ): |
| bsz, seq_len, _ = x.shape |
| |
| q = self.wq(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2) |
| k = self.wk(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2) |
| v = self.wv(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2) |
| |
        # Rotate q/k at their absolute positions: offset by the number of cached
        # tokens so incremental decoding does not restart positions at zero.
        past_len = past_key_value[0].shape[2] if past_key_value is not None else 0
        q, k = self.rope(q, k, seq_len, offset=past_len)

        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        past_key_value = (k, v)

        # Expand the shared key/value heads so each query head has a matching head.
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
| |
| scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) |
| |
| if mask is not None: |
| scores = scores + mask |
| |
| attn = F.softmax(scores, dim=-1) |
| attn = self.attn_dropout(attn) |
| |
| out = torch.matmul(attn, v) |
| out = out.transpose(1, 2).contiguous().view(bsz, seq_len, self.dim) |
| out = self.wo(out) |
| out = self.resid_dropout(out) |
| |
| return out, past_key_value |
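
# A minimal usage sketch (sizes are illustrative assumptions): with n_heads=8 and
# n_kv_heads=2, each group of 4 query heads shares one key/value head, shrinking
# the KV cache 4x while keeping full query resolution. The causal mask normally
# supplied by the model is omitted here for brevity.
#
#     attn = GroupedQueryAttention(dim=256, n_heads=8, n_kv_heads=2)
#     x = torch.randn(2, 16, 256)
#     out, kv = attn(x)                        # out.shape == (2, 16, 256)
#     # kv[0].shape == (2, 2, 16, 32): (batch, n_kv_heads, seq, head_dim)
#     x_next = torch.randn(2, 1, 256)          # one new token per sequence
#     out_next, kv = attn(x_next, past_key_value=kv)   # cached decoding step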
|
|
|
|
| class FeedForward(nn.Module): |
| """SwiGLU Feed-Forward Network.""" |
| |
| def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0): |
| super().__init__() |
| self.w1 = nn.Linear(dim, hidden_dim, bias=False) |
| self.w2 = nn.Linear(hidden_dim, dim, bias=False) |
| self.w3 = nn.Linear(dim, hidden_dim, bias=False) |
| self.dropout = nn.Dropout(dropout) |
| |
    def forward(self, x):
        # SwiGLU: gate the w1 projection with SiLU, multiply by the w3 projection,
        # project back down with w2, then apply dropout.
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
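
# The block above computes the SwiGLU form FFN(x) = W2(SiLU(W1 x) * W3 x): W1/W3
# expand dim -> hidden_dim and W2 projects back down. A tiny sketch (sizes are
# illustrative assumptions):
#
#     ffn = FeedForward(dim=256, hidden_dim=1024)
#     y = ffn(torch.randn(2, 16, 256))         # y.shape == (2, 16, 256)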
|
|
|
|
| class TransformerBlock(nn.Module): |
| """Single transformer block with pre-normalization.""" |
| |
    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        hidden_dim: int,
        dropout: float = 0.0,
        max_seq_len: int = 8192,
    ):
        super().__init__()
        self.attention_norm = RMSNorm(dim)
        self.attention = GroupedQueryAttention(
            dim, n_heads, n_kv_heads, dropout, max_seq_len=max_seq_len
        )
| self.ffn_norm = RMSNorm(dim) |
| self.feed_forward = FeedForward(dim, hidden_dim, dropout) |
| |
| def forward( |
| self, |
| x: torch.Tensor, |
| mask: Optional[torch.Tensor] = None, |
| past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| ): |
| |
| attn_out, past_key_value = self.attention( |
| self.attention_norm(x), mask, past_key_value |
| ) |
| x = x + attn_out |
| |
| |
| x = x + self.feed_forward(self.ffn_norm(x)) |
| |
| return x, past_key_value |
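
# Dataflow sketch for one block (pre-norm residual form), for an input x of shape
# (batch, seq, dim):
#
#     h = x + Attention(RMSNorm(x))            # attention sub-layer + residual
#     y = h + FeedForward(RMSNorm(h))          # FFN sub-layer + residual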
|
|
|
|
| class VicAIConfig: |
| """Configuration for VicAI model.""" |
| |
| def __init__( |
| self, |
| vocab_size: int = 32000, |
| dim: int = 4096, |
| n_layers: int = 32, |
| n_heads: int = 32, |
| n_kv_heads: int = 8, |
| hidden_dim: int = 14336, |
| max_seq_len: int = 8192, |
| dropout: float = 0.0, |
| tie_weights: bool = False, |
| ): |
| self.vocab_size = vocab_size |
| self.dim = dim |
| self.n_layers = n_layers |
| self.n_heads = n_heads |
| self.n_kv_heads = n_kv_heads |
| self.hidden_dim = hidden_dim |
| self.max_seq_len = max_seq_len |
| self.dropout = dropout |
| self.tie_weights = tie_weights |
| |
    @property
    def num_parameters(self):
        """Approximate parameter count for this configuration (norm weights ignored)."""
        head_dim = self.dim // self.n_heads
        kv_dim = self.n_kv_heads * head_dim

        # Token embedding.
        params = self.vocab_size * self.dim

        # Attention: wq/wo map to/from all query heads; wk/wv only to the smaller
        # set of key-value heads under GQA.
        attn_params = 2 * self.dim * self.dim + 2 * self.dim * kv_dim

        # SwiGLU feed-forward: w1, w2, w3.
        ffn_params = 3 * self.dim * self.hidden_dim

        params += self.n_layers * (attn_params + ffn_params)

        # Output head (shared with the embedding when weights are tied).
        if not self.tie_weights:
            params += self.vocab_size * self.dim
        return params
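
# Rough parameter budget for the default configuration above (RMSNorm weights
# ignored; all figures approximate):
#
#     embeddings:   32000 * 4096                      ~= 0.13B
#     attention:    2*4096*4096 + 2*4096*1024         ~= 41.9M per layer
#     SwiGLU FFN:   3 * 4096 * 14336                  ~= 176.2M per layer
#     32 layers:    32 * (41.9M + 176.2M)             ~= 6.98B
#     lm_head:      32000 * 4096                      ~= 0.13B
#     total:                                          ~= 7.2B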
|
|
|
|
| class VicAIModel(nn.Module): |
| """ |
| VicAI: A 5B parameter decoder-only transformer language model. |
| |
| Architecture details: |
| - 32 layers |
| - 4096 model dimension |
| - 32 attention heads (8 key-value heads for GQA) |
| - SwiGLU FFN with 14336 hidden dimension |
| - RoPE positional embeddings |
| - RMSNorm pre-normalization |
| - ~5.1B total parameters |
| """ |
| |
| def __init__(self, config: VicAIConfig): |
| super().__init__() |
| self.config = config |
| |
| self.token_embedding = nn.Embedding(config.vocab_size, config.dim) |
| self.dropout = nn.Dropout(config.dropout) |
| |
        self.layers = nn.ModuleList([
            TransformerBlock(
                config.dim,
                config.n_heads,
                config.n_kv_heads,
                config.hidden_dim,
                config.dropout,
                config.max_seq_len,
            )
            for _ in range(config.n_layers)
        ])
| |
| self.norm = RMSNorm(config.dim) |
| self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False) |
| |
| if config.tie_weights: |
| self.lm_head.weight = self.token_embedding.weight |
| |
| self.apply(self._init_weights) |
| |
| |
| total_params = self.get_num_params() |
| print(f"VicAI Model initialized with {total_params / 1e9:.2f}B parameters") |
| |
| def _init_weights(self, module): |
| if isinstance(module, nn.Linear): |
| torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) |
| if module.bias is not None: |
| torch.nn.init.zeros_(module.bias) |
| elif isinstance(module, nn.Embedding): |
| torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) |
| |
| def get_num_params(self, non_embedding=True): |
| n_params = sum(p.numel() for p in self.parameters()) |
| if non_embedding: |
| n_params -= self.token_embedding.weight.numel() |
| return n_params |
| |
| def forward( |
| self, |
| input_ids: torch.Tensor, |
| targets: Optional[torch.Tensor] = None, |
| past_key_values: Optional[list] = None, |
| ): |
        bsz, seq_len = input_ids.shape
        past_len = past_key_values[0][0].shape[2] if past_key_values is not None else 0
        assert past_len + seq_len <= self.config.max_seq_len, \
            "sequence length exceeds config.max_seq_len"

        # Additive causal mask over all keys (cached + current): query position i
        # may attend to key positions 0 .. past_len + i (0 = keep, -inf = masked).
        mask = torch.full(
            (seq_len, past_len + seq_len), float('-inf'), device=input_ids.device
        )
        mask = torch.triu(mask, diagonal=past_len + 1)
        mask = mask[None, None, :, :]
| |
| x = self.token_embedding(input_ids) |
| x = self.dropout(x) |
| |
| new_key_values = [] |
| for i, layer in enumerate(self.layers): |
| past_kv = past_key_values[i] if past_key_values is not None else None |
| x, kv = layer(x, mask, past_kv) |
| new_key_values.append(kv) |
| |
| x = self.norm(x) |
| logits = self.lm_head(x) |
| |
| loss = None |
| if targets is not None: |
| loss = F.cross_entropy( |
| logits.view(-1, logits.size(-1)), |
| targets.view(-1), |
| ignore_index=-100 |
| ) |
| |
| return { |
| 'logits': logits, |
| 'loss': loss, |
| 'past_key_values': new_key_values, |
| } |
| |
| @torch.no_grad() |
| def generate( |
| self, |
| input_ids: torch.Tensor, |
| max_new_tokens: int = 100, |
| temperature: float = 1.0, |
| top_k: int = 50, |
| top_p: float = 0.9, |
| repetition_penalty: float = 1.0, |
| eos_token_id: Optional[int] = None, |
| ): |
| """Generate text autoregressively.""" |
| self.eval() |
| |
| batch_size = input_ids.shape[0] |
| device = input_ids.device |
| past_key_values = None |
| |
        for _ in range(max_new_tokens):
            # Once a KV cache exists, only the newest token needs a forward pass;
            # earlier positions are already cached.
            model_inputs = input_ids if past_key_values is None else input_ids[:, -1:]
            outputs = self(model_inputs, past_key_values=past_key_values)
            logits = outputs['logits']
            past_key_values = outputs['past_key_values']

            # Keep only the last position's logits and apply temperature scaling.
            logits = logits[:, -1, :] / temperature
| |
| |
            # Penalize already-generated tokens: shrink positive logits, amplify negative ones.
            if repetition_penalty != 1.0:
                for i in range(batch_size):
                    for token_id in set(input_ids[i].tolist()):
                        if logits[i, token_id] > 0:
                            logits[i, token_id] /= repetition_penalty
                        else:
                            logits[i, token_id] *= repetition_penalty
| |
| |
            # Top-k filtering: keep only the k highest-probability candidates.
            if top_k > 0:
                k_eff = min(top_k, logits.size(-1))
                indices_to_remove = logits < torch.topk(logits, k_eff)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')
| |
| |
            # Nucleus (top-p) filtering: keep the smallest set of tokens whose
            # cumulative probability exceeds top_p.
            if top_p < 1.0:
| sorted_logits, sorted_indices = torch.sort(logits, descending=True) |
| cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) |
| sorted_indices_to_remove = cumulative_probs > top_p |
| sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() |
| sorted_indices_to_remove[..., 0] = 0 |
| indices_to_remove = sorted_indices_to_remove.scatter( |
| 1, sorted_indices, sorted_indices_to_remove |
| ) |
| logits[indices_to_remove] = float('-inf') |
| |
| probs = F.softmax(logits, dim=-1) |
| next_token = torch.multinomial(probs, num_samples=1) |
| |
| input_ids = torch.cat([input_ids, next_token], dim=1) |
| |
| |
            # Stop early once every sequence in the batch has emitted EOS.
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break
| |
| return input_ids |
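
# Generation sketch (illustrative; the token ids below are arbitrary examples and
# `model` is an instance of VicAIModel):
#
#     prompt_ids = torch.tensor([[1, 3293, 1781]])   # (batch=1, prompt_len)
#     out_ids = model.generate(prompt_ids, max_new_tokens=32,
#                              temperature=0.8, top_k=50, top_p=0.9)
#     # out_ids: (1, prompt_len + up to 32 generated ids), ready for detokenization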
|
|
|
|
| def create_vicai_5b(vocab_size: int = 32000) -> VicAIModel: |
| """Create a 5B parameter VicAI model.""" |
| config = VicAIConfig( |
| vocab_size=vocab_size, |
| dim=4096, |
| n_layers=32, |
| n_heads=32, |
| n_kv_heads=8, |
| hidden_dim=14336, |
| max_seq_len=8192, |
| dropout=0.0, |
| ) |
| return VicAIModel(config) |
|
|
|
|
| if __name__ == "__main__": |
| |
    # Build the default model and run a quick forward-pass smoke test.
    model = create_vicai_5b()
    print(f"Total parameters: {model.get_num_params(non_embedding=False) / 1e9:.2f}B")
| |
| |
    # Random token ids: a batch of 2 sequences of length 128.
    x = torch.randint(0, 32000, (2, 128))
| outputs = model(x) |
| print(f"Output shape: {outputs['logits'].shape}") |
| print(f"Loss: {outputs['loss']}") |
|
|