| """ |
| Photon-3M | Arsitektur Dual Sparse |
| Dikembangkan oleh Velyn (https://huggingface.co/Veenn) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, no bias).

    Scales each feature vector by the reciprocal RMS of its last dimension,
    then applies a learned per-channel gain.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps  # numerical floor so rsqrt never sees zero
        self.w = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps) * self.w
|
|
|
|
class RotaryEmbedding(nn.Module):
    """
    Rotary position embedding (RoPE), half-split ("rotate_half") convention.

    Caches cos/sin tables for up to `max_seq` positions and rebuilds them
    on demand if a longer sequence is seen.
    """

    def __init__(self, dim, max_seq=2048, base=10000):
        super().__init__()
        # Per-pair frequencies: base^(-2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq)

    def _build_cache(self, seq):
        """(Re)build cos/sin caches of shape (1, 1, seq, dim)."""
        t = torch.arange(seq, device=self.inv_freq.device).float()
        f = torch.outer(t, self.inv_freq)
        # Duplicate so the first and second halves of the head dim share angles.
        emb = torch.cat([f, f], dim=-1)
        self.register_buffer("cos_cache", emb.cos()[None, None])
        self.register_buffer("sin_cache", emb.sin()[None, None])

    def forward(self, x, seq_len):
        """
        Apply RoPE to `x` of shape (batch, heads, seq_len, head_dim).

        BUG FIX: the original split x into interleaved even/odd halves
        (x[..., ::2], x[..., 1::2]) while the cos/sin cache uses the
        concatenated-halves layout (cat([f, f])); mixing the two
        conventions is not a rotation (it is not norm-preserving) and
        corrupts relative-position information. Both the split and the
        cache now consistently use the half-split convention.
        """
        if seq_len > self.cos_cache.shape[2]:
            # Robustness: grow the cache instead of failing on sequences
            # longer than the initial max_seq.
            self._build_cache(seq_len)
        cos = self.cos_cache[:, :, :seq_len]
        sin = self.sin_cache[:, :, :seq_len]
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        return x * cos + torch.cat([-x2, x1], dim=-1) * sin
|
|
|
|
class PhotonAttention(nn.Module):
    """Causal grouped-query attention (GQA) with rotary position embeddings.

    `kv_heads` key/value heads are shared across `heads` query heads,
    each KV head serving a group of `heads // kv_heads` query heads.
    """

    def __init__(self, hidden, heads, kv_heads):
        super().__init__()
        self.heads = heads
        self.kv_heads = kv_heads
        self.head_dim = hidden // heads
        self.groups = heads // kv_heads

        self.q = nn.Linear(hidden, hidden, bias=False)
        self.k = nn.Linear(hidden, self.head_dim * kv_heads, bias=False)
        self.v = nn.Linear(hidden, self.head_dim * kv_heads, bias=False)
        self.o = nn.Linear(hidden, hidden, bias=False)
        self.rope = RotaryEmbedding(self.head_dim)

    def forward(self, x):
        batch, seq, width = x.shape

        def split_heads(t, n_heads):
            # (B, T, n*hd) -> (B, n, T, hd)
            return t.view(batch, seq, n_heads, self.head_dim).transpose(1, 2)

        queries = split_heads(self.q(x), self.heads)
        keys = split_heads(self.k(x), self.kv_heads)
        values = split_heads(self.v(x), self.kv_heads)

        # Rotate queries and keys with the same position tables.
        queries = self.rope(queries, seq)
        keys = self.rope(keys, seq)

        # Replicate each KV head across its query group.
        keys = keys.repeat_interleave(self.groups, dim=1)
        values = values.repeat_interleave(self.groups, dim=1)

        ctx = F.scaled_dot_product_attention(queries, keys, values, is_causal=True)
        ctx = ctx.transpose(1, 2).contiguous().view(batch, seq, width)
        return self.o(ctx)
|
|
|
|
class PhotonExpert(nn.Module):
    """One SwiGLU feed-forward expert: down(silu(gate(x)) * up(x))."""

    def __init__(self, hidden, ff_dim):
        super().__init__()
        self.gate = nn.Linear(hidden, ff_dim, bias=False)
        self.up = nn.Linear(hidden, ff_dim, bias=False)
        self.down = nn.Linear(ff_dim, hidden, bias=False)

    def forward(self, x):
        activated = F.silu(self.gate(x))
        return self.down(activated * self.up(x))
|
|
|
|
class PhotonMoE(nn.Module):
    """
    Sparse mixture-of-experts FFN:
      - one shared expert applied to every token
      - `num_active` of `num_experts` specialists chosen per token by a
        softmax router; the selected weights are renormalized to sum to 1
    """

    def __init__(self, hidden, ff_mult, num_experts, num_active):
        super().__init__()
        ff_dim = hidden * ff_mult
        self.num_experts = num_experts
        self.num_active = num_active

        self.shared = PhotonExpert(hidden, ff_dim)
        self.specialists = nn.ModuleList([
            PhotonExpert(hidden, ff_dim) for _ in range(num_experts)
        ])
        self.router = nn.Linear(hidden, num_experts, bias=False)

    def forward(self, x):
        batch, seq, width = x.shape
        tokens = x.reshape(-1, width)

        # Shared expert runs on every token unconditionally.
        always_on = self.shared(tokens)

        probs = F.softmax(self.router(tokens), dim=-1)
        top_w, top_i = probs.topk(self.num_active, dim=-1)
        top_w = top_w / top_w.sum(dim=-1, keepdim=True)

        # Dispatch: for each active slot, route matching tokens to their
        # chosen specialist and accumulate the weighted outputs.
        routed = torch.zeros_like(tokens)
        for slot in range(self.num_active):
            chosen = top_i[:, slot]
            for expert_id, expert in enumerate(self.specialists):
                hit = chosen == expert_id
                if hit.any():
                    routed[hit] += top_w[hit, slot:slot + 1] * expert(tokens[hit])

        return (always_on + routed).view(batch, seq, width)
|
|
|
|
class LayerSkipRouter(nn.Module):
    """
    Per-layer adaptive skip gate.

    Mean-pools the sequence, scores it with a one-unit linear head, and
    emits a hard 0/1 "skip" decision per batch element (skip when the
    sigmoid score falls below `skip_prob`). During training a
    straight-through term keeps the gate differentiable.
    """

    def __init__(self, hidden, skip_prob=0.3):
        super().__init__()
        self.skip_prob = skip_prob
        self.gate = nn.Linear(hidden, 1, bias=True)
        # Bias starts high (sigmoid(2) ~ 0.88 > skip_prob) so no layer is
        # skipped early in training.
        nn.init.constant_(self.gate.bias, 2.0)

    def forward(self, x, training=False):
        pooled = x.mean(dim=1)
        score = torch.sigmoid(self.gate(pooled))
        hard = (score < self.skip_prob).float()
        if not training:
            return hard
        # Straight-through estimator: forward value stays hard 0/1 while
        # the gradient flows through `score`.
        return hard + score - score.detach()
|
|
|
|
class PhotonLayer(nn.Module):
    """
    One transformer block: pre-norm GQA attention sublayer, then a
    pre-norm MoE sublayer whose output can be skipped per batch element
    by a LayerSkipRouter.
    """

    def __init__(self, hidden, heads, kv_heads, ff_mult, num_experts, num_active):
        super().__init__()
        self.norm1 = RMSNorm(hidden)
        self.attn = PhotonAttention(hidden, heads, kv_heads)
        self.norm2 = RMSNorm(hidden)
        self.moe = PhotonMoE(hidden, ff_mult, num_experts, num_active)
        self.skip_router = LayerSkipRouter(hidden)

    def forward(self, x, training=False):
        # skip: (B, 1, 1); exactly 0.0 or 1.0 in value, and during training
        # it carries a straight-through gradient term from the router.
        skip = self.skip_router(x, training=training).unsqueeze(-1)
        attn_out = x + self.attn(self.norm1(x))
        moe_out = attn_out + self.moe(self.norm2(attn_out))
        # BUG FIX: the original used torch.where(skip.bool(), ...), which
        # casts the gate to bool and discards the straight-through gradient
        # term — the skip router could never receive gradient and was
        # untrainable. The arithmetic mix below yields identical forward
        # values (skip is exactly 0 or 1) while letting gradients reach
        # the router during training.
        return skip * attn_out + (1.0 - skip) * moe_out
|
|
|
|
class PhotonModel(nn.Module):
    """
    Photon-3M language model: token embedding -> stack of PhotonLayers ->
    RMSNorm -> LM head tied to the embedding table.

    forward returns (loss, logits); loss is None unless `labels` is given.
    NOTE(review): `attention_mask` is accepted but never used (padding is
    not masked), and `max_seq` is stored nowhere — confirm whether callers
    rely on either.
    """

    def __init__(self, vocab, hidden, layers, heads, kv_heads,
                 ff_mult, num_experts, num_active, max_seq):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.layers = nn.ModuleList([
            PhotonLayer(hidden, heads, kv_heads, ff_mult, num_experts, num_active)
            for _ in range(layers)
        ])
        self.norm = RMSNorm(hidden)
        self.head = nn.Linear(hidden, vocab, bias=False)
        # Weight tying: output projection shares the embedding matrix.
        self.head.weight = self.embed.weight

        self._init_weights()

    def _init_weights(self):
        """Normal(0, 0.02) init for linear/embedding weights, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, attention_mask=None, labels=None):
        hidden_states = self.embed(input_ids)
        for block in self.layers:
            hidden_states = block(hidden_states, training=self.training)
        logits = self.head(self.norm(hidden_states))

        loss = None
        if labels is not None:
            # Next-token prediction: logits at t predict labels at t+1.
            shifted_logits = logits[:, :-1].contiguous()
            shifted_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shifted_logits.view(-1, shifted_logits.size(-1)),
                shifted_labels.view(-1),
                ignore_index=-100,
            )
        return loss, logits
|
|