# Wilai / modeling_openthaiwilai.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
PreTrainedModel,
PretrainedConfig,
AutoConfig,
AutoModelForCausalLM
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerationMixin
# ------------------------------------------------------------
# 🧩 Rotary Positional Embedding (RoPE)
# ------------------------------------------------------------
def build_rope_cache(seq_len, head_dim, device):
    # Precompute the cos/sin tables for rotary embeddings. Note: this is
    # rebuilt on every attention call below; caching per (seq_len, device)
    # would avoid redundant work.
    half_dim = head_dim // 2
    freq_seq = torch.arange(half_dim, device=device, dtype=torch.float32)
    inv_freq = 1.0 / (10000 ** (freq_seq / half_dim))
    t = torch.arange(seq_len, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)  # (seq_len, half_dim)
    cos, sin = torch.cos(freqs), torch.sin(freqs)
    return cos, sin
def apply_rope(x, cos, sin):
    # x: (B, T, H, D). Rotates interleaved channel pairs (x[..., 2i], x[..., 2i+1])
    # by position-dependent angles. The rotated pairs are written back as two
    # half-blocks rather than re-interleaved; because q and k receive the same
    # fixed channel permutation, attention dot products are unaffected.
    B, T, H, D = x.shape
    cos = cos[:T, :].unsqueeze(0).unsqueeze(2)  # (1, T, 1, D/2)
    sin = sin[:T, :].unsqueeze(0).unsqueeze(2)
    x1 = x[..., ::2]   # even channels
    x2 = x[..., 1::2]  # odd channels
    out = torch.cat([x1 * cos - x2 * sin,
                     x1 * sin + x2 * cos], dim=-1)
    return out
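# A minimal sanity check (illustrative sketch only; not used by the model):
# rotation preserves the norm of each (even, odd) channel pair, so the overall
# vector norm must be unchanged after apply_rope.
def _rope_sanity_check():
    B, T, H, D = 2, 16, 4, 64
    q = torch.randn(B, T, H, D)
    cos, sin = build_rope_cache(T, D, q.device)
    q_rot = apply_rope(q, cos, sin)
    assert q_rot.shape == (B, T, H, D)
    assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)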
# ------------------------------------------------------------
# 🧩 Config
# ------------------------------------------------------------
class OpenThaiWilaiConfig(PretrainedConfig):
model_type = "OpenThaiWilai"
def __init__(
self,
vocab_size=50000,
hidden_size=768,
num_layers=6,
num_heads=8,
num_key_value_heads=None,
num_experts=4,
top_k=2,
max_position_embeddings=2048,
intermediate_size=3072,
rope=True,
use_flashattn=True,
eos_token_id=None,
bos_token_id=None,
pad_token_id=None,
**kwargs
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs
)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.num_hidden_layers = num_layers
self.num_heads = num_heads
self.num_key_value_heads = num_key_value_heads or num_heads
self.num_experts = num_experts
self.top_k = top_k
self.max_position_embeddings = max_position_embeddings
self.intermediate_size = intermediate_size
self.rope = rope
self.use_flashattn = use_flashattn
# ------------------------------------------------------------
# 🧩 Custom Components
# ------------------------------------------------------------
class RMSNorm(nn.Module):
def __init__(self, d, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(d))
self.eps = eps
    def forward(self, x):
        # RMS normalization: divide by the root-mean-square of the features,
        # then apply a learned per-channel gain. Equivalent (up to the eps
        # placement) to x * rsqrt(mean(x**2) + eps) * weight.
        norm = x.norm(dim=-1, keepdim=True) * (1.0 / math.sqrt(x.size(-1)))
        return self.weight * x / (norm + self.eps)
class SwiGLU(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_model, d_ff)
        # Down-projection back to d_model; without it the expert output would
        # be d_ff-dimensional and the residual add in Block would fail.
        self.w3 = nn.Linear(d_ff, d_model)
    def forward(self, x):
        # Gated feed-forward: silu(w1 x) * w2 x, projected back to d_model.
        return self.w3(F.silu(self.w1(x)) * self.w2(x))
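# Illustrative check (a sketch, not called anywhere in the model): with eps=0
# and the default unit weight, the RMSNorm above should agree with the
# textbook x * rsqrt(mean(x^2)) formulation.
def _rmsnorm_sanity_check():
    x = torch.randn(3, 5, 8)
    ref = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True))
    assert torch.allclose(RMSNorm(8, eps=0.0)(x), ref, atol=1e-5)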
# ------------------------------------------------------------
# 🧩 Multi-Head Attention with RoPE + FlashAttention + GQA
# ------------------------------------------------------------
try:
from flash_attn import flash_attn_func
FLASH_AVAILABLE = True
except ImportError:
FLASH_AVAILABLE = False
class MultiHeadAttention(nn.Module):
def __init__(self, config: OpenThaiWilaiConfig):
super().__init__()
self.num_heads = config.num_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.hidden_size // config.num_heads
self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim)
self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim)
self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)
self.rope = config.rope
self.use_flash = config.use_flashattn
def forward(self, x, attention_mask=None):
B, T, C = x.shape
q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim)
k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim)
v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim)
# RoPE
if self.rope:
cos, sin = build_rope_cache(T, self.head_dim, x.device)
q = apply_rope(q, cos, sin)
k = apply_rope(k, cos, sin)
# GQA
if self.num_kv_heads != self.num_heads:
k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
        # FlashAttention path (needs CUDA compute capability >= 8.0 and
        # fp16/bf16 tensors); otherwise fall back to the eager implementation.
        if self.use_flash and FLASH_AVAILABLE and x.is_cuda and torch.cuda.get_device_capability(x.device)[0] >= 8:
            # flash_attn_func expects (B, T, H, D) layout, which q/k/v already have.
            out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True)
            out = out.reshape(B, T, C)
        else:
            q = q.transpose(1, 2)  # (B, H, T, D)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            # Causal mask, matching the flash path's causal=True: no position
            # may attend to a later one.
            causal = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))
            attn = attn.masked_fill(~causal, float("-inf"))
            if attention_mask is not None:
                # attention_mask: (B, T), 0 marks padding; broadcast over heads
                # and query positions.
                attn = attn.masked_fill(attention_mask[:, None, None, :] == 0, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            out = attn @ v
            out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.o_proj(out)
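# A hedged alternative for the eager branch (sketch only, not wired into the
# model): PyTorch >= 2.0 provides F.scaled_dot_product_attention, which fuses
# the manual softmax attention above and applies the causal mask internally.
def _sdpa_reference(q, k, v):
    # q, k, v: (B, H, T, D), already RoPE-rotated and GQA-expanded.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)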
# ------------------------------------------------------------
# 🧩 MoE with load balancing
# ------------------------------------------------------------
class MoE(nn.Module):
def __init__(self, config: OpenThaiWilaiConfig):
super().__init__()
self.experts = nn.ModuleList([
SwiGLU(config.hidden_size, config.intermediate_size) for _ in range(config.num_experts)
])
self.gate = nn.Linear(config.hidden_size, config.num_experts)
self.top_k = config.top_k
self.num_experts = config.num_experts
    def forward(self, x):
        B, T, C = x.shape
        scores = F.softmax(self.gate(x), dim=-1)  # (B, T, num_experts)
        current_top_k = min(self.top_k, self.num_experts)
        topk_scores, topk_idx = torch.topk(scores, current_top_k, dim=-1)
        # Dense routing: every expert runs on every token, then the top-k
        # outputs are gathered and mixed. Simple, but costs num_experts full
        # expert passes per token (no sparse dispatch).
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=2)  # (B, T, E, C)
        topk_idx_expanded = topk_idx.unsqueeze(-1).expand(-1, -1, -1, C)
        selected_expert_outputs = torch.gather(expert_outputs, dim=2, index=topk_idx_expanded)
        topk_scores_expanded = topk_scores.unsqueeze(-1).expand(-1, -1, -1, C)
        weighted_expert_outputs = selected_expert_outputs * topk_scores_expanded
        # Load-balancing penalty: variance of the per-position mean gate
        # probabilities across experts (zero when routing is uniform); stored
        # so the LM forward pass can fold it into the training loss.
        aux_loss = (scores.mean(0).var(dim=-1)).mean()
        self.last_aux_loss = aux_loss
        return torch.sum(weighted_expert_outputs, dim=2)
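# A hedged alternative to the variance penalty above (sketch only; the model
# keeps the variance form): the Switch Transformer style load-balancing loss
# num_experts * sum_i f_i * P_i, where f_i is the fraction of tokens routed to
# expert i and P_i is its mean gate probability.
def switch_style_aux_loss(scores, topk_idx, num_experts):
    # scores: (B, T, E) gate probabilities; topk_idx: (B, T, k) routed experts.
    one_hot = F.one_hot(topk_idx, num_experts).float()       # (B, T, k, E)
    f = one_hot.sum(dim=2).reshape(-1, num_experts).mean(0)  # f_i per expert
    p = scores.reshape(-1, num_experts).mean(0)              # P_i per expert
    return num_experts * torch.sum(f * p)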
# ------------------------------------------------------------
# 🧩 Transformer Block
# ------------------------------------------------------------
class Block(nn.Module):
def __init__(self, config: OpenThaiWilaiConfig):
super().__init__()
self.ln1 = RMSNorm(config.hidden_size)
self.attn = MultiHeadAttention(config)
self.ln2 = RMSNorm(config.hidden_size)
self.moe = MoE(config)
    def forward(self, x, attention_mask=None):
        # Pre-norm residual layout: attention block, then MoE feed-forward.
        x = x + self.attn(self.ln1(x), attention_mask=attention_mask)
        x = x + self.moe(self.ln2(x))
        return x
# ------------------------------------------------------------
# 🧩 OpenThaiWilai For Causal LM
# ------------------------------------------------------------
class OpenThaiWilaiForCausalLM(PreTrainedModel, GenerationMixin):
config_class = OpenThaiWilaiConfig
_keys_to_ignore_on_save = []
_dynamic_tied_weights_keys = {"lm_head.weight", "embed.weight"}
def __init__(self, config: OpenThaiWilaiConfig):
super().__init__(config)
self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        # Learned absolute position embeddings, used in addition to RoPE inside
        # attention; either alone would normally suffice, but both are kept here.
        self.pos_embed = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)])
self.ln_f = RMSNorm(config.hidden_size)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
self.tie_weights()
def tie_weights(self):
self.lm_head.weight = self.embed.weight
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        # No KV cache is implemented, so the full sequence is re-fed each step.
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
        }
def forward(
self,
input_ids,
labels=None,
attention_mask=None,
past_key_values=None,
use_cache: bool = False,
**kwargs
):
B, T = input_ids.shape
pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
x = self.embed(input_ids) + self.pos_embed(pos)
        for block in self.blocks:
            x = block(x, attention_mask=attention_mask)
x = self.ln_f(x)
logits = self.lm_head(x)
        loss = None
        # Sum the per-layer MoE load-balancing penalties from this pass.
        aux_loss = 0
        for block in self.blocks:
            if hasattr(block.moe, "last_aux_loss"):
                aux_loss += block.moe.last_aux_loss
        if labels is not None:
            # Standard causal-LM shift: position t predicts token t + 1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            ce_loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
            loss = ce_loss + 0.01 * aux_loss
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=None,
            attentions=None,
        )
# ------------------------------------------------------------
# 🧩 Register model for Auto classes
# ------------------------------------------------------------
AutoConfig.register("OpenThaiWilai", OpenThaiWilaiConfig)
AutoModelForCausalLM.register(OpenThaiWilaiConfig, OpenThaiWilaiForCausalLM)
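# ------------------------------------------------------------
# 🧩 Smoke test (illustrative usage; assumes a tiny random-init config,
# not the released model's hyperparameters)
# ------------------------------------------------------------
if __name__ == "__main__":
    config = OpenThaiWilaiConfig(
        vocab_size=1000,
        hidden_size=64,
        num_layers=2,
        num_heads=4,
        num_key_value_heads=2,
        num_experts=2,
        top_k=1,
        max_position_embeddings=128,
        intermediate_size=128,
        use_flashattn=False,
    )
    model = OpenThaiWilaiForCausalLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids, labels=input_ids.clone())
    print("logits:", tuple(out.logits.shape), "loss:", float(out.loss))
    generated = model.generate(input_ids[:, :4], max_new_tokens=8, do_sample=False)
    print("generated:", tuple(generated.shape))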