# photon-3m / modeling_photon.py
# Photon-3M checkpoint-2000 (uploaded by Veenn)
"""
Photon-3M | Arsitektur Dual Sparse
Dikembangkan oleh Velyn (https://huggingface.co/Veenn)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias).

    Scales each feature vector by 1/RMS(x) and applies a learned per-channel
    gain ``w`` (initialized to ones).
    """
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps  # keeps rsqrt finite on an all-zero input
        self.w = nn.Parameter(torch.ones(dim))
    def forward(self, x):
        # Mean of squares over the last (feature) dimension.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms * self.w
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with a precomputed cos/sin cache.

    The cache duplicates the angle table as ``cat([f, f])``, i.e. the
    half-split layout: channel ``i`` is paired with channel ``i + dim/2``
    and both share frequency ``theta_i``. ``forward`` must therefore use the
    matching rotate-half convention.
    """
    def __init__(self, dim, max_seq=2048, base=10000):
        super().__init__()
        # theta_i = base^(-2i/dim): one frequency per channel pair.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq)
    def _build_cache(self, seq):
        # Angles t * theta_i for every position; duplicated so the cache
        # width equals the full head dim.
        t = torch.arange(seq, device=self.inv_freq.device).float()
        f = torch.outer(t, self.inv_freq)
        emb = torch.cat([f, f], dim=-1)
        # Shape [1, 1, seq, dim] so the cache broadcasts over (batch, heads).
        self.register_buffer("cos_cache", emb.cos()[None, None])
        self.register_buffer("sin_cache", emb.sin()[None, None])
    def forward(self, x, seq_len):
        # Robustness: grow the cache instead of silently truncating when a
        # sequence longer than max_seq arrives.
        if seq_len > self.cos_cache.size(2):
            self._build_cache(seq_len)
        cos = self.cos_cache[:, :, :seq_len].to(x.dtype)
        sin = self.sin_cache[:, :, :seq_len].to(x.dtype)
        # BUG FIX: the original split channels interleaved (x[..., ::2] /
        # x[..., 1::2]) but the cat([f, f]) cache uses the half-split layout,
        # so channel i ended up mixed with channel 2i+1 — not a rotation.
        # Correct rotate-half: pair channel i with channel i + dim/2.
        x1, x2 = x.chunk(2, dim=-1)
        rotated = torch.cat([-x2, x1], dim=-1)
        return x * cos + rotated * sin
class PhotonAttention(nn.Module):
    """Grouped-Query Attention (GQA) with rotary position embeddings.

    ``heads`` query heads share ``kv_heads`` key/value heads; each KV head
    serves ``heads // kv_heads`` query heads. Causal masking is delegated to
    ``F.scaled_dot_product_attention``.
    """
    def __init__(self, hidden, heads, kv_heads):
        super().__init__()
        self.heads = heads
        self.kv_heads = kv_heads
        self.head_dim = hidden // heads
        self.groups = heads // kv_heads  # query heads per KV head
        kv_width = self.head_dim * kv_heads
        self.q = nn.Linear(hidden, hidden, bias=False)
        self.k = nn.Linear(hidden, kv_width, bias=False)
        self.v = nn.Linear(hidden, kv_width, bias=False)
        self.o = nn.Linear(hidden, hidden, bias=False)
        self.rope = RotaryEmbedding(self.head_dim)
    def forward(self, x):
        bsz, seq, width = x.shape
        # Project, then reshape to (batch, heads, seq, head_dim).
        queries = self.q(x).view(bsz, seq, self.heads, self.head_dim).transpose(1, 2)
        keys = self.k(x).view(bsz, seq, self.kv_heads, self.head_dim).transpose(1, 2)
        values = self.v(x).view(bsz, seq, self.kv_heads, self.head_dim).transpose(1, 2)
        # RoPE is applied to queries and keys only.
        queries = self.rope(queries, seq)
        keys = self.rope(keys, seq)
        # Expand the KV heads so every query head has a matching K/V head.
        keys = keys.repeat_interleave(self.groups, dim=1)
        values = values.repeat_interleave(self.groups, dim=1)
        ctx = F.scaled_dot_product_attention(queries, keys, values, is_causal=True)
        ctx = ctx.transpose(1, 2).contiguous().view(bsz, seq, width)
        return self.o(ctx)
class PhotonExpert(nn.Module):
    """Single feed-forward expert using the SwiGLU activation.

    Computes ``down(silu(gate(x)) * up(x))``; all projections are bias-free.
    """
    def __init__(self, hidden, ff_dim):
        super().__init__()
        self.gate = nn.Linear(hidden, ff_dim, bias=False)
        self.up = nn.Linear(hidden, ff_dim, bias=False)
        self.down = nn.Linear(ff_dim, hidden, bias=False)
    def forward(self, x):
        # SiLU-gated linear unit, projected back to the model width.
        gated = F.silu(self.gate(x))
        return self.down(gated * self.up(x))
class PhotonMoE(nn.Module):
    """Sparse mixture-of-experts FFN.

    One shared expert is applied to every token; a linear router additionally
    selects ``num_active`` of ``num_experts`` specialist experts per token,
    whose outputs are mixed with renormalized top-k softmax weights.
    """
    def __init__(self, hidden, ff_mult, num_experts, num_active):
        super().__init__()
        ff_dim = hidden * ff_mult
        self.num_experts = num_experts
        self.num_active = num_active
        self.shared = PhotonExpert(hidden, ff_dim)
        self.specialists = nn.ModuleList(
            PhotonExpert(hidden, ff_dim) for _ in range(num_experts)
        )
        self.router = nn.Linear(hidden, num_experts, bias=False)
    def forward(self, x):
        B, T, C = x.shape
        tokens = x.view(-1, C)  # flatten to (B*T, C) for per-token routing
        base = self.shared(tokens)  # shared expert: always on
        probs = F.softmax(self.router(tokens), dim=-1)
        top_w, top_i = probs.topk(self.num_active, dim=-1)
        # Renormalize so the selected experts' weights sum to 1 per token.
        top_w = top_w / top_w.sum(dim=-1, keepdim=True)
        mix = torch.zeros_like(tokens)
        # For each top-k slot, batch the tokens routed to each expert.
        # NOTE(review): no load-balancing auxiliary loss is computed here;
        # presumably handled (or absent) in the training loop — verify.
        for slot in range(self.num_active):
            for idx in range(self.num_experts):
                sel = top_i[:, slot] == idx
                if sel.any():
                    weight = top_w[sel, slot:slot + 1]
                    mix[sel] += weight * self.specialists[idx](tokens[sel])
        return (base + mix).view(B, T, C)
class LayerSkipRouter(nn.Module):
    """Adaptive layer skipping: a tiny per-layer gate that decides
    process-vs-skip for each batch element.

    Returns a ``(batch, 1)`` float mask — 1.0 where the gate score falls
    below ``skip_prob`` (skip the layer), else 0.0.
    """
    def __init__(self, hidden, skip_prob=0.3):
        super().__init__()
        self.skip_prob = skip_prob
        self.gate = nn.Linear(hidden, 1, bias=True)
        # Bias at +2 puts sigmoid ~0.88 at init, so layers start never skipped.
        nn.init.constant_(self.gate.bias, 2.0)
    def forward(self, x, training=False):
        # One score per batch element, from the mean-pooled sequence.
        pooled = x.mean(dim=1)
        score = torch.sigmoid(self.gate(pooled))
        hard = (score < self.skip_prob).float()
        if not training:
            return hard
        # Straight-through estimator: the forward value is the hard 0/1
        # decision, the backward pass sees the gradient of the soft score.
        return hard + score - score.detach()
class PhotonLayer(nn.Module):
    """One transformer block: pre-norm GQA attention, then a pre-norm sparse
    MoE whose output can be bypassed per batch element by the skip router.
    """
    def __init__(self, hidden, heads, kv_heads, ff_mult, num_experts, num_active):
        super().__init__()
        self.norm1 = RMSNorm(hidden)
        self.attn = PhotonAttention(hidden, heads, kv_heads)
        self.norm2 = RMSNorm(hidden)
        self.moe = PhotonMoE(hidden, ff_mult, num_experts, num_active)
        self.skip_router = LayerSkipRouter(hidden)
    def forward(self, x, training=False):
        # (B, 1) -> (B, 1, 1) so the mask broadcasts over (B, T, C).
        gate = self.skip_router(x, training=training).unsqueeze(-1)
        h = x + self.attn(self.norm1(x))  # attention is never skipped
        full = h + self.moe(self.norm2(h))
        # gate == 1 drops only the MoE residual for that batch element.
        # NOTE(review): the MoE branch is still computed even when every
        # element skips it — where() only selects; no compute is saved.
        return torch.where(gate.bool(), h, full)
class PhotonModel(nn.Module):
    """Photon-3M decoder: embedding -> N PhotonLayers -> RMSNorm -> tied LM head.

    ``forward`` returns ``(loss, logits)``; ``loss`` is next-token
    cross-entropy (labels shifted left by one, ``ignore_index=-100``) and is
    ``None`` when no labels are given.

    NOTE(review): ``attention_mask`` is accepted but never used — padding is
    not masked inside attention; confirm callers only pass dense batches.
    """
    def __init__(self, vocab, hidden, layers, heads, kv_heads,
                 ff_mult, num_experts, num_active, max_seq):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.layers = nn.ModuleList(
            PhotonLayer(hidden, heads, kv_heads, ff_mult, num_experts, num_active)
            for _ in range(layers)
        )
        self.norm = RMSNorm(hidden)
        self.head = nn.Linear(hidden, vocab, bias=False)
        # Weight tying: the LM head shares the embedding matrix.
        self.head.weight = self.embed.weight
        self._init_weights()
    def _init_weights(self):
        """Normal(0, 0.02) init for Linear/Embedding weights, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
    def forward(self, input_ids, attention_mask=None, labels=None):
        hidden_states = self.embed(input_ids)
        for block in self.layers:
            hidden_states = block(hidden_states, training=self.training)
        logits = self.head(self.norm(hidden_states))
        loss = None
        if labels is not None:
            # Shift so position t predicts token t+1.
            vocab_size = logits.size(-1)
            loss = F.cross_entropy(
                logits[:, :-1].contiguous().view(-1, vocab_size),
                labels[:, 1:].contiguous().view(-1),
                ignore_index=-100
            )
        return loss, logits