"""
Photon-3M | Dual Sparse Architecture
Developed by Velyn (https://huggingface.co/Veenn)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Root-mean-square layer norm: no mean-centering, learned scale only."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.w = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # x / sqrt(mean(x^2) + eps), scaled by the learned per-channel weight.
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.w


class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) in the half-split (LLaMA) layout.

    The cos/sin caches are laid out as ``[f | f]`` (frequencies repeated over
    the two halves of the head dimension), so the companion rotation must
    split the input at ``dim // 2``.
    """

    def __init__(self, dim, max_seq=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq = 0  # set by _build_cache
        self._build_cache(max_seq)

    def _build_cache(self, seq):
        # Precompute cos/sin tables for positions [0, seq).
        t = torch.arange(seq, device=self.inv_freq.device).float()
        f = torch.outer(t, self.inv_freq)
        emb = torch.cat([f, f], dim=-1)  # half-split layout: [f | f]
        self.register_buffer("cos_cache", emb.cos()[None, None])
        self.register_buffer("sin_cache", emb.sin()[None, None])
        self.max_seq = seq

    def forward(self, x, seq_len):
        """Apply the rotary rotation to ``x`` (..., seq_len, dim)."""
        # Robustness fix: grow the cache on demand instead of letting a
        # too-short cache slice fail with a broadcast error for long inputs.
        if seq_len > self.max_seq:
            self._build_cache(seq_len)
        cos = self.cos_cache[:, :, :seq_len].to(x.dtype)
        sin = self.sin_cache[:, :, :seq_len].to(x.dtype)
        # BUGFIX: the cache uses the half-split layout, so rotate-half must
        # split at dim//2. The original interleaved slicing (::2 / 1::2)
        # mixed two incompatible RoPE conventions and was not a valid
        # rotation (it did not preserve pairwise norms).
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        rotated = torch.cat([-x2, x1], dim=-1)
        return x * cos + rotated * sin


class PhotonAttention(nn.Module):
    """Grouped Query Attention + RoPE.

    ``heads`` query heads share ``kv_heads`` key/value heads; K/V are
    repeated ``heads // kv_heads`` times before the causal SDPA call.
    """

    def __init__(self, hidden, heads, kv_heads):
        super().__init__()
        self.heads = heads
        self.kv_heads = kv_heads
        self.head_dim = hidden // heads
        self.groups = heads // kv_heads
        self.q = nn.Linear(hidden, hidden, bias=False)
        self.k = nn.Linear(hidden, self.head_dim * kv_heads, bias=False)
        self.v = nn.Linear(hidden, self.head_dim * kv_heads, bias=False)
        self.o = nn.Linear(hidden, hidden, bias=False)
        self.rope = RotaryEmbedding(self.head_dim)

    def forward(self, x):
        B, T, C = x.shape
        # (B, T, C) -> (B, n_heads, T, head_dim)
        q = self.q(x).view(B, T, self.heads, self.head_dim).transpose(1, 2)
        k = self.k(x).view(B, T, self.kv_heads, self.head_dim).transpose(1, 2)
        v = self.v(x).view(B, T, self.kv_heads, self.head_dim).transpose(1, 2)
        q = self.rope(q, T)
        k = self.rope(k, T)
        # Expand K/V so each group of query heads sees its shared KV head.
        k = k.repeat_interleave(self.groups, dim=1)
        v = v.repeat_interleave(self.groups, dim=1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o(out.transpose(1, 2).contiguous().view(B, T, C))


class PhotonExpert(nn.Module):
    """Single FFN expert with SwiGLU activation."""

    def __init__(self, hidden, ff_dim):
        super().__init__()
        self.gate = nn.Linear(hidden, ff_dim, bias=False)
        self.up = nn.Linear(hidden, ff_dim, bias=False)
        self.down = nn.Linear(ff_dim, hidden, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))


class PhotonMoE(nn.Module):
    """Sparse MoE:
    - 1 shared expert (always active)
    - N specialist experts (router selects ``num_active`` per token)

    NOTE(review): no load-balancing auxiliary loss is computed here — if
    training collapses onto a few experts, that is the first thing to check.
    """

    def __init__(self, hidden, ff_mult, num_experts, num_active):
        super().__init__()
        ff_dim = hidden * ff_mult
        self.num_experts = num_experts
        self.num_active = num_active
        self.shared = PhotonExpert(hidden, ff_dim)
        self.specialists = nn.ModuleList([
            PhotonExpert(hidden, ff_dim) for _ in range(num_experts)
        ])
        self.router = nn.Linear(hidden, num_experts, bias=False)

    def forward(self, x):
        B, T, C = x.shape
        flat = x.view(-1, C)  # (B*T, C): route per token
        shared_out = self.shared(flat)
        weights = F.softmax(self.router(flat), dim=-1)
        topk_w, topk_i = weights.topk(self.num_active, dim=-1)
        # Renormalize so the selected experts' weights sum to 1 per token.
        topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)
        spec_out = torch.zeros_like(flat)
        # Dense loop over (slot, expert); each expert only processes the
        # tokens routed to it.
        for slot in range(self.num_active):
            for e in range(self.num_experts):
                mask = (topk_i[:, slot] == e)
                if mask.any():
                    spec_out[mask] += (
                        topk_w[mask, slot:slot + 1]
                        * self.specialists[e](flat[mask])
                    )
        return (shared_out + spec_out).view(B, T, C)


class LayerSkipRouter(nn.Module):
    """Adaptive layer skipping.

    A tiny per-layer router that decides, per sequence, whether to process
    or skip the layer's MoE half. Returns a hard 0/1 gate; during training
    the gate carries a straight-through gradient so the router is trainable.
    """

    def __init__(self, hidden, skip_prob=0.3):
        super().__init__()
        self.skip_prob = skip_prob
        self.gate = nn.Linear(hidden, 1, bias=True)
        # Bias +2.0 => sigmoid ~0.88 at init, well above skip_prob, so
        # layers start out never skipping.
        nn.init.constant_(self.gate.bias, 2.0)

    def forward(self, x, training=False):
        # One score per sequence (mean-pooled over time).
        score = torch.sigmoid(self.gate(x.mean(dim=1)))
        skip = (score < self.skip_prob).float()
        if training:
            # Straight-through estimator: value stays hard 0/1
            # (score - score.detach() == 0), gradient flows through score.
            skip = skip + score - score.detach()
        return skip


class PhotonLayer(nn.Module):
    """One transformer block: GQA attention + MoE FFN with adaptive skip."""

    def __init__(self, hidden, heads, kv_heads, ff_mult,
                 num_experts, num_active):
        super().__init__()
        self.norm1 = RMSNorm(hidden)
        self.attn = PhotonAttention(hidden, heads, kv_heads)
        self.norm2 = RMSNorm(hidden)
        self.moe = PhotonMoE(hidden, ff_mult, num_experts, num_active)
        self.skip_router = LayerSkipRouter(hidden)

    def forward(self, x, training=False):
        skip = self.skip_router(x, training=training).unsqueeze(-1)
        attn_out = x + self.attn(self.norm1(x))
        moe_out = attn_out + self.moe(self.norm2(attn_out))
        # BUGFIX: the original `torch.where(skip.bool(), ...)` discarded the
        # straight-through gradient built in LayerSkipRouter, so the router
        # never trained. Arithmetic mixing produces identical forward values
        # (skip is exactly 0 or 1 in value) while letting gradients reach
        # the skip gate. Note the MoE is still computed even when skipped.
        return skip * attn_out + (1.0 - skip) * moe_out


class PhotonModel(nn.Module):
    """Photon-3M language model: embedding, N PhotonLayers, tied LM head."""

    def __init__(self, vocab, hidden, layers, heads, kv_heads,
                 ff_mult, num_experts, num_active, max_seq):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.layers = nn.ModuleList([
            PhotonLayer(hidden, heads, kv_heads, ff_mult,
                        num_experts, num_active)
            for _ in range(layers)
        ])
        self.norm = RMSNorm(hidden)
        self.head = nn.Linear(hidden, vocab, bias=False)
        # Weight tying: head and embedding share one Parameter, so the
        # double normal_() in _init_weights hits the same tensor (harmless).
        self.head.weight = self.embed.weight
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Return ``(loss, logits)``; loss is None when labels are absent.

        NOTE(review): ``attention_mask`` is accepted but never used — padding
        tokens attend normally and contribute to the loss unless the caller
        masks labels with -100. Confirm this is intended.
        """
        x = self.embed(input_ids)
        for layer in self.layers:
            x = layer(x, training=self.training)
        x = self.norm(x)
        logits = self.head(x)
        loss = None
        if labels is not None:
            # Standard causal LM shift: predict token t+1 from position t.
            loss = F.cross_entropy(
                logits[:, :-1].contiguous().view(-1, logits.size(-1)),
                labels[:, 1:].contiguous().view(-1),
                ignore_index=-100,
            )
        return loss, logits