# Wind-Edge-1.6-Instruct / modeling_wind_edge.py
# Author: arthu1
# Commit 1ddaf2d (verified): "Replace with 20M-token corrected instruct build"
"""Wind Edge causal LM — RMSNorm + RoPE + GQA + SwiGLU dense transformer."""
from __future__ import annotations
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.generation import GenerationMixin
from .configuration_wind_edge import WindEdgeConfig
class WindEdgeRMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-channel gain.

    The statistics are computed in float32 regardless of the input dtype, and
    the scaled result is cast back to the caller's dtype.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        # Gain starts at 1, so a fresh layer is a pure normalisation.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        original_dtype = x.dtype
        hidden = x.to(torch.float32)
        mean_square = hidden.pow(2).mean(-1, keepdim=True)
        normed = hidden * torch.rsqrt(mean_square + self.variance_epsilon)
        scaled = self.weight * normed
        return scaled.to(original_dtype)
def _build_rope_cache(seq_len: int, head_dim: int, theta: float, device, dtype):
    """Return the RoPE ``(cos, sin)`` tables for positions ``0 .. seq_len - 1``.

    Angles are computed in float32. For each position the frequency vector is
    repeated across both halves of the last axis. That layout pairs with
    ``_rotate_half``, which splits the last axis into two halves. Each table
    has shape ``(seq_len, head_dim)`` for an even ``head_dim``. Both tables are
    cast to ``dtype`` only at the end.
    """
    exponents = torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(seq_len, device=device, dtype=torch.float32)
    half_angles = torch.outer(positions, inv_freq)
    angles = half_angles.repeat(1, 2)
    return angles.cos().to(dtype), angles.sin().to(dtype)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last axis and negate the new front: ``[a, b] -> [-b, a]``."""
    front, back = x.chunk(2, dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
def _apply_rope(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Apply the rotary position embedding to ``q`` and ``k``.

    ``cos``/``sin`` get two leading singleton axes. WindEdgeAttention calls this
    with ``(B, heads, T, head_dim)`` inputs and ``(T, head_dim)`` tables, so the
    tables broadcast over batch and heads.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]

    def rotate(t: torch.Tensor) -> torch.Tensor:
        return (t * cos_b) + (_rotate_half(t) * sin_b)

    return rotate(q), rotate(k)
def _padding_bias(attention_mask: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Turn a ``(B, T)`` keep-mask (1 = attend, 0 = ignore) into an additive bias.

    Kept positions contribute 0 and ignored positions contribute the most
    negative finite value of ``ref.dtype``. The result has shape
    ``(B, 1, 1, T)`` and ``ref``'s dtype.
    """
    most_negative = torch.finfo(ref.dtype).min
    keep = attention_mask.to(ref.dtype)
    bias = (1.0 - keep) * most_negative
    return bias[:, None, None, :]
class WindEdgeAttention(nn.Module):
    """Grouped-query self-attention with per-head QK RMSNorm and RoPE.

    ``num_key_value_heads`` K/V heads are shared across ``num_attention_heads``
    query heads. On CUDA the fused ``F.scaled_dot_product_attention`` kernel is
    tried first. Otherwise, or when the installed torch rejects
    ``enable_gqa``, an explicit matmul/softmax path runs.

    ``attention_mask`` is either an additive float mask that broadcasts against
    the ``(B, num_heads, T, T)`` score tensor, or ``None``. ``None`` means plain
    causal attention on BOTH paths.
    """

    def __init__(self, config: WindEdgeConfig):
        super().__init__()
        self.config = config
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size
        self.scale = self.head_dim ** -0.5
        q_out = self.num_heads * self.head_dim
        kv_out = self.num_kv_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, q_out, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, kv_out, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, kv_out, bias=config.attention_bias)
        self.o_proj = nn.Linear(q_out, self.hidden_size, bias=config.attention_bias)
        self.q_norm = WindEdgeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = WindEdgeRMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(self, x, cos, sin, attention_mask=None):
        """Run self-attention over ``x``.

        Args:
            x: hidden states of shape ``(B, T, hidden_size)``.
            cos, sin: RoPE tables of shape ``(T, head_dim)``.
            attention_mask: additive mask, or ``None`` for causal-only attention.

        Returns:
            Tensor of shape ``(B, T, hidden_size)``.
        """
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim)
        # QK-norm normalises each head over head_dim, before RoPE is applied.
        q = self.q_norm(q).transpose(1, 2)
        k = self.k_norm(k).transpose(1, 2)
        v = v.transpose(1, 2)
        q, k = _apply_rope(q, k, cos, sin)

        out = None
        if x.is_cuda and hasattr(F, "scaled_dot_product_attention"):
            out = self._sdpa_attention(q, k, v, attention_mask)
        if out is None:
            out = self._eager_attention(q, k, v, attention_mask)
        out = out.transpose(1, 2).contiguous().view(B, T, self.num_heads * self.head_dim)
        return self.o_proj(out)

    def _sdpa_attention(self, q, k, v, attention_mask):
        """Fused attention via SDPA; returns ``None`` if this torch lacks ``enable_gqa``."""
        try:
            return F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attention_mask,
                dropout_p=0.0,
                is_causal=attention_mask is None,
                enable_gqa=self.num_kv_heads != self.num_heads,
            )
        except TypeError:
            # Older torch builds may not support enable_gqa; the caller falls back
            # to the manual path.
            return None

    def _eager_attention(self, q, k, v, attention_mask):
        """Explicit attention; applies causal masking itself when no mask is given."""
        if self.num_kv_heads != self.num_heads:
            repeats = self.num_heads // self.num_kv_heads
            k = k.repeat_interleave(repeats, dim=1)
            v = v.repeat_interleave(repeats, dim=1)
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if attention_mask is None:
            # The SDPA path treats a missing mask as is_causal=True. Without this,
            # the fallback (e.g. CUDA on a torch without enable_gqa) would attend
            # to future tokens.
            seq_q, seq_k = attn.shape[-2], attn.shape[-1]
            future = torch.ones(seq_q, seq_k, dtype=torch.bool, device=attn.device).triu(1)
            attn = attn.masked_fill(future, float("-inf"))
        else:
            attn = attn + attention_mask
        # Softmax in float32 for stability, then back to the activation dtype.
        attn = F.softmax(attn.float(), dim=-1).to(q.dtype)
        return torch.matmul(attn, v)
class WindEdgeMLP(nn.Module):
    """Bias-free SwiGLU feed-forward: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, config: WindEdgeConfig):
        super().__init__()
        d_model, d_ff = config.hidden_size, config.intermediate_size
        # Creation order (gate, up, down) is kept so parameter order and RNG
        # consumption during init are unchanged.
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class WindEdgeBlock(nn.Module):
    """Pre-norm decoder layer: residual attention followed by residual SwiGLU MLP."""

    def __init__(self, config: WindEdgeConfig):
        super().__init__()
        width, eps = config.hidden_size, config.rms_norm_eps
        # Submodules are registered in the original order to keep state-dict keys
        # and init RNG consumption identical.
        self.input_layernorm = WindEdgeRMSNorm(width, eps=eps)
        self.self_attn = WindEdgeAttention(config)
        self.post_attention_layernorm = WindEdgeRMSNorm(width, eps=eps)
        self.mlp = WindEdgeMLP(config)

    def forward(self, x, cos, sin, attention_mask=None):
        attn_out = self.self_attn(self.input_layernorm(x), cos, sin, attention_mask)
        hidden = x + attn_out
        mlp_out = self.mlp(self.post_attention_layernorm(hidden))
        return hidden + mlp_out
class WindEdgePreTrainedModel(PreTrainedModel):
    """Shared transformers base for Wind Edge models: config binding and weight init."""

    config_class = WindEdgeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["WindEdgeBlock"]

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        - Linear and Embedding weights are drawn from ``N(0, initializer_range)``.
        - Linear biases are set to zero.
        - RMSNorm gains are set to 1.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(0.0, std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(0.0, std)
        elif isinstance(module, WindEdgeRMSNorm):
            # Norm gains are meant to start at 1 (WindEdgeRMSNorm.__init__ uses
            # torch.ones). Resetting them here keeps that true when transformers
            # re-initialises a norm whose parameters were not produced by __init__,
            # e.g. weights missing from a checkpoint. NOTE(review): confirm for the
            # transformers version in use.
            module.weight.data.fill_(1.0)
class WindEdgeModel(WindEdgePreTrainedModel):
    """Decoder stack: token embedding -> ``num_hidden_layers`` blocks -> final RMSNorm."""

    def __init__(self, config: WindEdgeConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([WindEdgeBlock(config) for _ in range(config.num_hidden_layers)])
        self.norm = WindEdgeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        self.post_init()

    @staticmethod
    def _make_additive_mask(attention_mask: Optional[torch.Tensor], x: torch.Tensor):
        """Build the additive attention mask passed to every layer.

        Returns ``None`` when there is no padding mask and the attention layers will
        take the SDPA path. In that case they pass ``is_causal=True`` and a
        materialised ``(T, T)`` mask would go unused. Otherwise it returns the causal
        mask, plus a padding bias when ``attention_mask`` is given.
        """
        if attention_mask is None and x.is_cuda and hasattr(F, "scaled_dot_product_attention"):
            return None
        seq_len = x.shape[1]
        causal = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=x.device, dtype=x.dtype),
            diagonal=1,
        )[None, None, :, :]
        if attention_mask is None:
            return causal
        return causal + _padding_bias(attention_mask, x)

    def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None):
        """Return final normalised hidden states of shape ``(B, T, hidden_size)``.

        Args:
            input_ids: ``(B, T)`` token ids.
            attention_mask: optional ``(B, T)`` mask, where 1 means real token and
                0 means padding.
        """
        _, seq_len = input_ids.shape
        x = self.embed_tokens(input_ids)
        cos, sin = _build_rope_cache(seq_len, self.config.head_dim, self.config.rope_theta, x.device, x.dtype)
        mask = self._make_additive_mask(attention_mask, x)

        use_checkpointing = self.gradient_checkpointing and self.training
        if use_checkpointing:
            # Explicit import: do not rely on some other module having loaded
            # the torch.utils.checkpoint submodule.
            from torch.utils.checkpoint import checkpoint
        for layer in self.layers:
            if use_checkpointing:
                x = checkpoint(layer, x, cos, sin, mask, use_reentrant=False)
            else:
                x = layer(x, cos, sin, mask)
        return self.norm(x)
class WindEdgeForCausalLM(WindEdgePreTrainedModel, GenerationMixin):
    """Wind Edge decoder plus a vocabulary projection head for causal language modelling.

    The LM head is declared tied to ``model.embed_tokens``. ``forward`` returns
    logits. When ``labels`` are supplied it also returns the mean next-token
    cross-entropy, ignoring positions labelled ``-100``. No ``past_key_values``
    are returned because the model keeps no KV cache.
    """

    # transformers 5.x requires the dict form for `_tied_weights_keys`, but the default
    # `from_pretrained` then silently fails to copy disk weights into the in-RAM params
    # for this model — they end up at the freshly-initialised values (~N(0, 0.02)).
    # We override `from_pretrained` below to manually re-apply the safetensors after load.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: WindEdgeConfig):
        super().__init__(config)
        self.model = WindEdgeModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load via the parent, then re-apply local safetensors weights.

        This works around a transformers 5.x bug where saved weights are not applied
        to in-RAM params when ``_tied_weights_keys`` is a dict.
        ``pretrained_model_name_or_path`` must be a local directory. The files are
        loaded shard by shard into the model with ``strict=False``. ``lm_head`` is
        re-tied to ``embed_tokens`` only when the checkpoint carries no
        ``lm_head.weight``. The workaround is best effort: on failure a warning is
        logged and the parent's result is returned unchanged. The return value has
        the same shape as the parent's, including the ``(model, info)`` tuple when
        ``output_loading_info=True``.
        """
        import logging
        import os

        logger = logging.getLogger(__name__)
        loaded = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        # With output_loading_info=True the parent returns (model, loading_info).
        model = loaded[0] if isinstance(loaded, tuple) else loaded

        path = pretrained_model_name_or_path
        if path is None or not os.path.isdir(path):
            logger.warning(
                "WindEdge: %r is not a local directory; skipping the safetensors "
                "re-application workaround, weights come from the standard loader only.",
                path,
            )
            return loaded
        shards = sorted(name for name in os.listdir(path) if name.endswith(".safetensors"))
        if not shards:
            return loaded

        has_lm_head = False
        try:
            from safetensors.torch import safe_open

            # Apply one shard at a time so only a single shard is held in RAM.
            for shard in shards:
                with safe_open(os.path.join(path, shard), framework="pt") as handle:
                    shard_state = {key: handle.get_tensor(key) for key in handle.keys()}
                has_lm_head = has_lm_head or "lm_head.weight" in shard_state
                model.load_state_dict(shard_state, strict=False)
        except Exception:
            # Best effort: keep what the standard loader produced, but report it.
            logger.warning(
                "WindEdge: failed to re-apply safetensors weights from %r", path, exc_info=True
            )
            return loaded

        # Tied checkpoints omit lm_head.weight; re-tie only then, so an explicitly
        # saved (untied) head is not thrown away.
        if not has_lm_head:
            model.lm_head.weight = model.model.embed_tokens.weight
        return loaded

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, value):
        self.lm_head = value

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Compute next-token logits and, optionally, the LM loss.

        Args:
            input_ids: ``(B, T)`` token ids.
            attention_mask: optional ``(B, T)`` mask, where 1 means real token and
                0 means padding.
            labels: optional ``(B, T)`` targets aligned with ``input_ids``. The loss
                compares ``logits[:, t]`` with ``labels[:, t + 1]``, and ``-100``
                marks ignored positions.
            **kwargs: extra keyword arguments are accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` of shape ``(B, T, vocab_size)``
            and ``loss``, which is ``None`` when no labels are given.
        """
        hidden = self.model(input_ids, attention_mask=attention_mask)
        logits = self.lm_head(hidden)
        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            # Labels may live on a different device than the logits (e.g. a
            # model split across devices); align them before the loss.
            shift_labels = labels[:, 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)).float(),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)