"""OpenThaiWilai: a small causal language model with rotary position embeddings,
optional FlashAttention, and SwiGLU mixture-of-experts feed-forward blocks,
registered with the Hugging Face transformers Auto classes."""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
)
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.generation.utils import GenerationMixin

def build_rope_cache(seq_len, head_dim, device):
    # Precompute the cos/sin tables for rotary position embeddings (RoPE).
    half_dim = head_dim // 2
    freq_seq = torch.arange(half_dim, device=device, dtype=torch.float32)
    inv_freq = 1.0 / (10000 ** (freq_seq / half_dim))
    t = torch.arange(seq_len, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)  # (seq_len, half_dim)
    cos, sin = torch.cos(freqs), torch.sin(freqs)
    return cos, sin

def apply_rope(x, cos, sin):
    # x: (batch, seq_len, num_heads, head_dim). Even/odd channel pairs are
    # rotated by the position-dependent angles from build_rope_cache.
    B, T, H, D = x.shape
    cos = cos[:T, :].unsqueeze(0).unsqueeze(2)  # (1, T, 1, head_dim // 2)
    sin = sin[:T, :].unsqueeze(0).unsqueeze(2)
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    out = torch.cat([x1 * cos - x2 * sin,
                     x1 * sin + x2 * cos], dim=-1)
    return out

class OpenThaiWilaiConfig(PretrainedConfig):
    model_type = "OpenThaiWilai"

    def __init__(
        self,
        vocab_size=50000,
        hidden_size=768,
        num_layers=6,
        num_heads=8,
        num_key_value_heads=None,
        num_experts=4,
        top_k=2,
        max_position_embeddings=2048,
        intermediate_size=3072,
        rope=True,
        use_flashattn=True,
        eos_token_id=None,
        bos_token_id=None,
        pad_token_id=None,
        **kwargs
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_hidden_layers = num_layers
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads or num_heads
        self.num_experts = num_experts
        self.top_k = top_k
        self.max_position_embeddings = max_position_embeddings
        self.intermediate_size = intermediate_size
        self.rope = rope
        self.use_flashattn = use_flashattn

class RMSNorm(nn.Module):
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x):
        # ||x|| / sqrt(d) == sqrt(mean(x ** 2)), i.e. the RMS of the last dim.
        norm = x.norm(dim=-1, keepdim=True) * (1.0 / math.sqrt(x.size(-1)))
        return self.weight * x / (norm + self.eps)

class SwiGLU(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_model, d_ff)
        # Down projection back to d_model so the expert output matches the
        # hidden size expected by the MoE mixing and the residual stream.
        self.w3 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

try:
    from flash_attn import flash_attn_func
    FLASH_AVAILABLE = True
except ImportError:
    FLASH_AVAILABLE = False

class MultiHeadAttention(nn.Module):
    def __init__(self, config: OpenThaiWilaiConfig):
        super().__init__()
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_heads

        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)

        self.rope = config.rope
        self.use_flash = config.use_flashattn

    def forward(self, x, attention_mask=None):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim)

        if self.rope:
            # The RoPE tables are rebuilt for the current sequence length on
            # every call; simple, at the cost of some redundant compute.
            cos, sin = build_rope_cache(T, self.head_dim, x.device)
            q = apply_rope(q, cos, sin)
            k = apply_rope(k, cos, sin)

        # Grouped-query attention: repeat K/V heads to match the query heads.
        if self.num_kv_heads != self.num_heads:
            k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
            v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)

        use_flash = (
            self.use_flash
            and FLASH_AVAILABLE
            and x.is_cuda
            and q.dtype in (torch.float16, torch.bfloat16)
            and torch.cuda.get_device_capability(x.device)[0] >= 8
        )
        if use_flash:
            # flash_attn_func expects (batch, seq_len, num_heads, head_dim),
            # which is already the layout of q, k and v here.
            out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True)
            out = out.reshape(B, T, C)
        else:
            q = q.transpose(1, 2)  # (B, num_heads, T, head_dim)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            # Causal mask: each position may only attend to itself and the past.
            causal = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))
            attn = attn.masked_fill(~causal, float("-inf"))
            if attention_mask is not None:
                # (B, T) padding mask with 1 for real tokens, 0 for padding.
                if attention_mask.dim() == 2:
                    attention_mask = attention_mask[:, None, None, :]
                attn = attn.masked_fill(attention_mask == 0, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            out = attn @ v
            out = out.transpose(1, 2).contiguous().view(B, T, C)

        return self.o_proj(out)

class MoE(nn.Module):
    def __init__(self, config: OpenThaiWilaiConfig):
        super().__init__()
        self.experts = nn.ModuleList([
            SwiGLU(config.hidden_size, config.intermediate_size)
            for _ in range(config.num_experts)
        ])
        self.gate = nn.Linear(config.hidden_size, config.num_experts)
        self.top_k = config.top_k
        self.num_experts = config.num_experts

    def forward(self, x):
        B, T, C = x.shape
        scores = F.softmax(self.gate(x), dim=-1)  # (B, T, num_experts)
        current_top_k = min(self.top_k, self.num_experts)
        topk_scores, topk_idx = torch.topk(scores, current_top_k, dim=-1)
        # Dense routing: every expert runs on every token, then the top-k
        # expert outputs are gathered and mixed with their gate scores.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=2)  # (B, T, E, C)
        topk_idx_expanded = topk_idx.unsqueeze(-1).expand(-1, -1, -1, C)
        selected_expert_outputs = torch.gather(expert_outputs, dim=2, index=topk_idx_expanded)
        topk_scores_expanded = topk_scores.unsqueeze(-1).expand(-1, -1, -1, C)
        weighted_expert_outputs = selected_expert_outputs * topk_scores_expanded

        # Simple load-balancing penalty: variance across experts of the
        # batch-averaged gate probabilities, collected by the model's forward.
        aux_loss = (scores.mean(0).var(dim=-1)).mean()
        self.last_aux_loss = aux_loss

        return torch.sum(weighted_expert_outputs, dim=2)

class Block(nn.Module):
    def __init__(self, config: OpenThaiWilaiConfig):
        super().__init__()
        self.ln1 = RMSNorm(config.hidden_size)
        self.attn = MultiHeadAttention(config)
        self.ln2 = RMSNorm(config.hidden_size)
        self.moe = MoE(config)

    def forward(self, x, attention_mask=None):
        # Pre-norm residual blocks: attention, then the MoE feed-forward.
        x = x + self.attn(self.ln1(x), attention_mask=attention_mask)
        x = x + self.moe(self.ln2(x))
        return x

class OpenThaiWilaiForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = OpenThaiWilaiConfig
    _keys_to_ignore_on_save = []
    # lm_head.weight is tied to embed.weight (see tie_weights below).
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: OpenThaiWilaiConfig):
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        # Learned absolute positions are added on top of the RoPE applied
        # inside each attention layer.
        self.pos_embed = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)])
        self.ln_f = RMSNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()
        self.tie_weights()

    def tie_weights(self):
        # Share the input embedding matrix with the output projection.
        self.lm_head.weight = self.embed.weight

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        # No KV cache is implemented, so the full sequence is re-encoded each step.
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
        }

    def forward(
        self,
        input_ids,
        labels=None,
        attention_mask=None,
        past_key_values=None,
        use_cache: bool = False,
        **kwargs
    ):
        B, T = input_ids.shape
        pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
        x = self.embed(input_ids) + self.pos_embed(pos)
        for block in self.blocks:
            x = block(x, attention_mask=attention_mask)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        # Collect the auxiliary load-balancing losses left by each MoE layer.
        aux_loss = 0
        for block in self.blocks:
            if hasattr(block.moe, "last_aux_loss"):
                aux_loss += block.moe.last_aux_loss

        if labels is not None:
            # Shift so that position t predicts token t + 1 (causal LM loss).
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            ce_loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )
            loss = ce_loss + 0.01 * aux_loss

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=None,
            attentions=None,
        )

AutoConfig.register("OpenThaiWilai", OpenThaiWilaiConfig)
AutoModelForCausalLM.register(OpenThaiWilaiConfig, OpenThaiWilaiForCausalLM)
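

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model definition).
# It assumes a recent transformers release where generate() comes from the
# explicitly inherited GenerationMixin; the tiny hyperparameters below are
# arbitrary placeholders so the example runs quickly on CPU, not recommended
# settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = OpenThaiWilaiConfig(
        vocab_size=1000,
        hidden_size=64,
        num_layers=2,
        num_heads=4,
        num_experts=2,
        top_k=1,
        intermediate_size=128,
        use_flashattn=False,
    )
    # The Auto registration above lets the generic factory build the model.
    model = AutoModelForCausalLM.from_config(config)
    prompt = torch.randint(0, config.vocab_size, (1, 8))
    generated = model.generate(prompt, max_new_tokens=8, do_sample=False)
    print(generated.shape)  # expected: torch.Size([1, 16])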