""" VicAI Model Architecture A 5B parameter decoder-only transformer language model. """ import math from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F class RMSNorm(nn.Module): """Root Mean Square Layer Normalization.""" def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight class RotaryPositionalEmbedding(nn.Module): """Rotary Position Embedding (RoPE).""" def __init__(self, dim: int, max_seq_len: int = 8192, base: float = 10000.0): super().__init__() self.dim = dim self.max_seq_len = max_seq_len self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq) t = torch.arange(max_seq_len) freqs = torch.einsum("i,j->ij", t, inv_freq) emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos()[None, None, :, :]) self.register_buffer("sin_cached", emb.sin()[None, None, :, :]) def rotate_half(self, x): x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(self, q, k, cos, sin): q_embed = (q * cos) + (self.rotate_half(q) * sin) k_embed = (k * cos) + (self.rotate_half(k) * sin) return q_embed, k_embed def forward(self, q, k, seq_len: int): cos = self.cos_cached[:, :, :seq_len, :] sin = self.sin_cached[:, :, :seq_len, :] return self.apply_rotary_pos_emb(q, k, cos, sin) class GroupedQueryAttention(nn.Module): """Grouped Query Attention (GQA) for efficient inference.""" def __init__( self, dim: int, n_heads: int, n_kv_heads: int, dropout: float = 0.0, ): super().__init__() self.dim = dim self.n_heads = n_heads self.n_kv_heads = n_kv_heads self.head_dim = dim // n_heads self.n_rep = n_heads // n_kv_heads self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False) self.wk = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False) self.wv = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False) self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False) self.attn_dropout = nn.Dropout(dropout) self.resid_dropout = nn.Dropout(dropout) self.rope = RotaryPositionalEmbedding(self.head_dim) def forward( self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): bsz, seq_len, _ = x.shape q = self.wq(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2) k = self.wk(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2) v = self.wv(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2) q, k = self.rope(q, k, seq_len) if past_key_value is not None: past_k, past_v = past_key_value k = torch.cat([past_k, k], dim=2) v = torch.cat([past_v, v], dim=2) past_key_value = (k, v) # Repeat k/v for grouped query attention k = k.repeat_interleave(self.n_rep, dim=1) v = v.repeat_interleave(self.n_rep, dim=1) scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) if mask is not None: scores = scores + mask attn = F.softmax(scores, dim=-1) attn = self.attn_dropout(attn) out = torch.matmul(attn, v) out = out.transpose(1, 2).contiguous().view(bsz, seq_len, self.dim) out = self.wo(out) out = self.resid_dropout(out) return out, past_key_value class FeedForward(nn.Module): """SwiGLU Feed-Forward Network.""" def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0): super().__init__() self.w1 = nn.Linear(dim, hidden_dim, bias=False) self.w2 
class FeedForward(nn.Module):
    """SwiGLU Feed-Forward Network."""

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # SwiGLU: silu(w1 x) gated by w3 x, projected back down by w2
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))


class TransformerBlock(nn.Module):
    """Single transformer block with pre-normalization."""

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        hidden_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.attention_norm = RMSNorm(dim)
        self.attention = GroupedQueryAttention(dim, n_heads, n_kv_heads, dropout)
        self.ffn_norm = RMSNorm(dim)
        self.feed_forward = FeedForward(dim, hidden_dim, dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        # Attention with residual
        attn_out, past_key_value = self.attention(
            self.attention_norm(x), mask, past_key_value
        )
        x = x + attn_out
        # FFN with residual
        x = x + self.feed_forward(self.ffn_norm(x))
        return x, past_key_value


class VicAIConfig:
    """Configuration for VicAI model."""

    def __init__(
        self,
        vocab_size: int = 32000,
        dim: int = 4096,
        n_layers: int = 32,
        n_heads: int = 32,
        n_kv_heads: int = 8,
        hidden_dim: int = 14336,
        max_seq_len: int = 8192,
        dropout: float = 0.0,
        tie_weights: bool = False,
    ):
        self.vocab_size = vocab_size
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.hidden_dim = hidden_dim
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        self.tie_weights = tie_weights

    @property
    def num_parameters(self):
        """Calculate approximate parameter count."""
        head_dim = self.dim // self.n_heads
        kv_dim = self.n_kv_heads * head_dim
        # Embedding
        params = self.vocab_size * self.dim
        # Attention per layer: q/o projections are dim x dim, k/v are dim x kv_dim (GQA)
        attn_params = 2 * self.dim * self.dim + 2 * self.dim * kv_dim
        # FFN per layer
        ffn_params = 3 * self.dim * self.hidden_dim  # w1, w2, w3
        # Layers
        params += self.n_layers * (attn_params + ffn_params)
        # Output head
        params += self.vocab_size * self.dim
        return params
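
# Illustrative sketch: a deliberately tiny configuration that is convenient for
# unit tests or CPU-only smoke tests. The sizes are arbitrary example values,
# not part of the released model configuration.
def _tiny_config(vocab_size: int = 1000) -> VicAIConfig:
    return VicAIConfig(
        vocab_size=vocab_size,
        dim=128,
        n_layers=2,
        n_heads=4,
        n_kv_heads=2,
        hidden_dim=256,
        max_seq_len=512,
        dropout=0.0,
    )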
class VicAIModel(nn.Module):
    """
    VicAI: A 5B parameter decoder-only transformer language model.

    Architecture details:
    - 32 layers
    - 4096 model dimension
    - 32 attention heads (8 key-value heads for GQA)
    - SwiGLU FFN with 14336 hidden dimension
    - RoPE positional embeddings
    - RMSNorm pre-normalization
    - ~5.1B total parameters
    """

    def __init__(self, config: VicAIConfig):
        super().__init__()
        self.config = config

        self.token_embedding = nn.Embedding(config.vocab_size, config.dim)
        self.dropout = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList([
            TransformerBlock(
                config.dim,
                config.n_heads,
                config.n_kv_heads,
                config.hidden_dim,
                config.dropout,
            )
            for _ in range(config.n_layers)
        ])

        self.norm = RMSNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        if config.tie_weights:
            self.lm_head.weight = self.token_embedding.weight

        self.apply(self._init_weights)

        # Print model info
        total_params = self.get_num_params()
        print(f"VicAI Model initialized with {total_params / 1e9:.2f}B parameters")

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_num_params(self, non_embedding=True):
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.token_embedding.weight.numel()
        return n_params

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        past_key_values: Optional[list] = None,
    ):
        bsz, seq_len = input_ids.shape

        # Causal mask: positions above the diagonal are set to -inf. During
        # incremental decoding (seq_len == 1) this is all zeros, so the single
        # query token attends to the full cached context.
        mask = torch.full((seq_len, seq_len), float("-inf"), device=input_ids.device)
        mask = torch.triu(mask, diagonal=1)
        mask = mask[None, None, :, :]

        x = self.token_embedding(input_ids)
        x = self.dropout(x)

        new_key_values = []
        for i, layer in enumerate(self.layers):
            past_kv = past_key_values[i] if past_key_values is not None else None
            x, kv = layer(x, mask, past_kv)
            new_key_values.append(kv)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-100,
            )

        return {
            'logits': logits,
            'loss': loss,
            'past_key_values': new_key_values,
        }

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
        repetition_penalty: float = 1.0,
        eos_token_id: Optional[int] = None,
    ):
        """Generate text autoregressively."""
        self.eval()
        batch_size = input_ids.shape[0]

        past_key_values = None
        for _ in range(max_new_tokens):
            # With a KV cache, only the newest token needs to go through the model.
            model_input = input_ids if past_key_values is None else input_ids[:, -1:]
            outputs = self(model_input, past_key_values=past_key_values)
            logits = outputs['logits']
            past_key_values = outputs['past_key_values']

            # Get logits for last token
            logits = logits[:, -1, :] / temperature

            # Apply repetition penalty (divide positive logits, multiply negative ones,
            # so repeated tokens always become less likely)
            if repetition_penalty != 1.0:
                for i in range(batch_size):
                    for token_id in set(input_ids[i].tolist()):
                        if logits[i, token_id] > 0:
                            logits[i, token_id] /= repetition_penalty
                        else:
                            logits[i, token_id] *= repetition_penalty

            # Top-k filtering
            if top_k > 0:
                k = min(top_k, logits.size(-1))
                indices_to_remove = logits < torch.topk(logits, k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')

            # Top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)

            # Early stopping if EOS token generated
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return input_ids
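
# Illustrative sketch: one training step against the loss computed in
# VicAIModel.forward. The forward pass compares logits at position t with
# targets at position t, so the caller is expected to shift labels by one for
# next-token prediction and mark padding with -100 (the ignore_index above).
# The clipping value is a placeholder, not a recommended recipe from this model.
def _example_training_step(
    model: VicAIModel,
    batch: torch.Tensor,
    optimizer: torch.optim.Optimizer,
) -> float:
    input_ids = batch[:, :-1]
    targets = batch[:, 1:].clone()  # next-token labels, shifted by one position
    outputs = model(input_ids, targets=targets)
    loss = outputs["loss"]
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    return loss.item()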
def create_vicai_5b(vocab_size: int = 32000) -> VicAIModel:
    """Create a 5B parameter VicAI model."""
    config = VicAIConfig(
        vocab_size=vocab_size,
        dim=4096,
        n_layers=32,
        n_heads=32,
        n_kv_heads=8,
        hidden_dim=14336,
        max_seq_len=8192,
        dropout=0.0,
    )
    return VicAIModel(config)


if __name__ == "__main__":
    # Test model creation
    model = create_vicai_5b()
    print(f"Total parameters: {model.get_num_params() / 1e9:.2f}B")

    # Test forward pass
    x = torch.randint(0, 32000, (2, 128))
    outputs = model(x)
    print(f"Output shape: {outputs['logits'].shape}")
    print(f"Loss: {outputs['loss']}")
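
    # Illustrative extra check: exercise generate() on a deliberately small,
    # randomly initialized configuration so the smoke test stays cheap. The
    # sizes and sampling settings are arbitrary example values.
    tiny_model = VicAIModel(VicAIConfig(
        vocab_size=32000, dim=128, n_layers=2, n_heads=4, n_kv_heads=2,
        hidden_dim=256, max_seq_len=512,
    ))
    prompt = torch.randint(0, 32000, (1, 8))
    generated = tiny_model.generate(prompt, max_new_tokens=8, temperature=1.0, top_k=20)
    print(f"Generated sequence shape: {generated.shape}")  # expected: (1, 16)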