# Seed-0.5B / modeling_seed.py
# (Hugging Face upload header kept as comments so the file parses:
#  merterbak's picture — "Upload folder using huggingface_hub" — d68c7eb verified)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.cache_utils import DynamicCache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_seed import SeedConfig
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias).

    Normalizes over the last dimension by the RMS of the activations and
    applies a learned per-channel scale.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.epsilon = eps  # added inside the rsqrt for numerical stability
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Compute the statistics in float32: in fp16/bf16 the mean of squares
        # can overflow or lose precision (the classic RMSNorm pitfall), so we
        # upcast, normalize, then cast back to the caller's dtype.
        input_dtype = x.dtype
        h = x.to(torch.float32)
        h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + self.epsilon)
        return self.weight * h.to(input_dtype)
class RoPEEmbedding(nn.Module):
    """Rotary position embeddings with optional YaRN / NTK base rescaling.

    Precomputes inverse frequencies for half the head dimension; ``forward``
    returns cos/sin tables of shape (batch, seq, head_dim) in the input dtype.
    """

    def __init__(self, config, device=None):
        super().__init__()
        self.config = config
        assert config.n_embd % config.n_head == 0
        self.head_dim = config.head_dim
        self.rope_scaling_type = str(getattr(config, "rope_scaling_type", "none"))
        self.rope_scaling_factor = float(getattr(config, "rope_scaling_factor", 1.0))
        self.position_scale = 1.0
        self.attention_scaling = 1.0
        base = float(config.rope_theta)
        # Scaling is a no-op when disabled or when the factor is exactly 1.
        if self.rope_scaling_type != "none" and self.rope_scaling_factor != 1.0:
            if self.rope_scaling_type in ("yarn", "ntk"):
                # Both schemes stretch the base by factor^(d / (d - 2)).
                base = base * (self.rope_scaling_factor ** (self.head_dim / (self.head_dim - 2.0)))
                if self.rope_scaling_type == "yarn":
                    # YaRN additionally damps attention magnitudes.
                    self.attention_scaling = 0.1 * math.log(self.rope_scaling_factor) + 1.0
            else:
                raise ValueError(f"Unknown rope_scaling_type={self.rope_scaling_type!r}")
        self.base = base
        exponents = torch.arange(0, self.head_dim, 2, dtype=torch.float32, device=device) / float(self.head_dim)
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)

    def forward(self, x, position_ids):
        """Return (cos, sin) tables for *position_ids*, cast to x's dtype."""
        out_dtype = x.dtype
        scaled_pos = position_ids.float().unsqueeze(-1) * self.position_scale
        angles = scaled_pos * self.inv_freq.unsqueeze(0).unsqueeze(0)
        table = torch.cat((angles, angles), dim=-1)
        cos = (table.cos() * self.attention_scaling).to(out_dtype)
        sin = (table.sin() * self.attention_scaling).to(out_dtype)
        return cos, sin
def rotate_half(x):
    """Rotate the two halves of the last dimension: (a, b) -> (-b, a)."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat((-second, first), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply rotary embedding to query and key tensors.

    cos/sin are broadcast across the head axis by unsqueezing at
    *unsqueeze_dim* (default 1, the head dimension of (B, H, T, D) tensors).
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
class GQA(nn.Module):
    """Grouped-query attention with per-head RMSNorm on Q/K ("QK-norm") and RoPE.

    Uses ``n_head`` query heads and ``n_kv_head`` key/value heads; KV heads
    are repeated to match the query heads before attention. Supports an
    optional transformers-style KV cache for incremental decoding.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        # layer_idx identifies this layer's slot inside the shared KV cache.
        self.layer_idx = int(layer_idx)
        self.n_head = config.n_head
        # Fall back to full multi-head attention when n_kv_head is absent.
        self.n_kv_head = int(getattr(config, "n_kv_head", config.n_head))
        self.n_embd = config.n_embd
        self.block_size = int(config.block_size)
        assert 1 <= self.n_kv_head <= self.n_head
        assert self.n_head % self.n_kv_head == 0
        self.head_dim = config.head_dim
        q_dim = self.n_head * self.head_dim
        kv_dim = self.n_kv_head * self.head_dim
        self.q_proj = nn.Linear(self.n_embd, q_dim, bias=False)
        self.k_proj = nn.Linear(self.n_embd, kv_dim, bias=False)
        self.v_proj = nn.Linear(self.n_embd, kv_dim, bias=False)
        self.o_proj = nn.Linear(q_dim, self.n_embd, bias=False)
        # QK-norm: normalize each head's query/key vectors before RoPE.
        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)

    def forward(self, x, cos, sin, past_key_values=None):
        """Attend over x (plus any cached KV); returns a (B, T, n_embd) tensor.

        cos/sin are the rotary tables for this step's positions.
        NOTE(review): no padding attention_mask is applied anywhere here —
        only causal masking; confirm callers never rely on padded batches.
        """
        B, T, C = x.shape
        # Project and reshape to (B, heads, T, head_dim).
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        past_len = 0
        if past_key_values is not None:
            # Cache length BEFORE this step's keys are appended — used below
            # to decide between the prefill and decode mask paths.
            past_len = past_key_values.get_seq_length(self.layer_idx)
            k, v = past_key_values.update(k, v, self.layer_idx)
        if self.n_kv_head != self.n_head:
            # GQA: replicate each KV head to serve its group of query heads.
            repeat_factor = self.n_head // self.n_kv_head
            k = k.repeat_interleave(repeat_factor, dim=1)
            v = v.repeat_interleave(repeat_factor, dim=1)
        if past_len == 0:
            # Prefill: square causal mask handled by SDPA's is_causal flag.
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
        else:
            # Decode with cache: queries start at absolute position past_len
            # while keys start at 0, so build the rectangular causal mask
            # explicitly (additive float mask with -inf on disallowed keys).
            Tk = int(k.size(2))
            query_pos = past_len + torch.arange(T, device=x.device)
            key_pos = torch.arange(Tk, device=x.device)
            causal_mask = key_pos.unsqueeze(0) <= query_pos.unsqueeze(1)
            attn_mask = torch.zeros((1, 1, T, Tk), device=x.device, dtype=q.dtype)
            attn_mask = attn_mask.masked_fill(~causal_mask.view(1, 1, T, Tk), torch.finfo(q.dtype).min)
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)
        # Merge heads back to (B, T, n_head * head_dim) and project out.
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.o_proj(y)
        return y
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: down_proj(silu(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        width = getattr(config, "mlp_hidden_dim", None)
        if width is None:
            # Default width: ~2/3 of the usual 4x expansion, rounded up to a
            # multiple of 256 for hardware-friendly matmul shapes.
            width = (int(4 * self.n_embd * 2 / 3) + 255) // 256 * 256
        self.gate_proj = nn.Linear(self.n_embd, width, bias=config.bias)
        self.up_proj = nn.Linear(self.n_embd, width, bias=config.bias)
        self.down_proj = nn.Linear(width, self.n_embd, bias=config.bias)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
class DecoderLayer(nn.Module):
    """Pre-norm transformer block: attention then MLP, each wrapped in a
    residual connection with its own RMSNorm."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.input_norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.post_attn_norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.attn = GQA(config, layer_idx=layer_idx)
        self.mlp = SwiGLU(config)

    def forward(self, x, cos, sin, past_key_values=None):
        # Attention sub-block: normalize, attend, add back the residual.
        x = x + self.attn(self.input_norm(x), cos, sin, past_key_values=past_key_values)
        # Feed-forward sub-block: normalize, transform, add back the residual.
        x = x + self.mlp(self.post_attn_norm(x))
        return x
class SeedPreTrainedModel(PreTrainedModel):
    """Base Hugging Face PreTrainedModel wiring for Seed models.

    Declares the config class and the loading/placement hints consumed by
    ``from_pretrained`` and the accelerate device-sharding machinery.
    """

    config_class = SeedConfig
    base_model_prefix = "model"
    # Keep each DecoderLayer whole on one device when sharding the model.
    _no_split_modules = ["DecoderLayer"]
    # Do not let device-placement logic move the KV cache tensors.
    _skip_keys_device_placement = ["past_key_values"]
    # Attention is implemented with torch SDPA.
    _supports_sdpa = True
class SeedForCausalLM(SeedPreTrainedModel, GenerationMixin):
    """Decoder-only causal language model: token embedding -> stacked
    DecoderLayers -> final RMSNorm -> LM head (weight-tied to the embedding).
    """

    # lm_head.weight is tied to the input embedding via post_init().
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.layers = nn.ModuleList([DecoderLayer(config, layer_idx=i) for i in range(config.n_layer)])
        self.norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # One shared rotary module; cos/sin are computed once per forward and
        # passed down to every layer.
        self.rope = RoPEEmbedding(config)
        # HF hook: initializes weights and applies _tied_weights_keys tying.
        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, value):
        self.wte = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        token_type_ids=None,
        **kwargs
    ):
        """Run the decoder stack; optionally compute the shifted LM loss.

        NOTE(review): attention_mask and token_type_ids are accepted for API
        compatibility but never used — padding is not masked out (see GQA);
        confirm callers only send unpadded or left-to-right batches.
        Returns a CausalLMOutputWithPast(loss, logits, past_key_values).
        """
        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        B, T = inputs_embeds.shape[:2]
        # Lazily create a cache only when caching was requested.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        if position_ids is None:
            # Continue position numbering after whatever is already cached.
            past_seen = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(past_seen, past_seen + T, device=inputs_embeds.device).unsqueeze(0).expand(B, T)
        cos, sin = self.rope(inputs_embeds, position_ids)
        x = inputs_embeds
        for layer in self.layers:
            x = layer(x, cos, sin, past_key_values=past_key_values)
        x = self.norm(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Next-token objective: logits at step t predict labels at t+1.
            # Relies on F.cross_entropy's default ignore_index=-100 for any
            # masked label positions. Assumes T >= 2 when labels are given.
            loss = F.cross_entropy(
                logits[:, :-1].contiguous().view(-1, logits.size(-1)),
                labels[:, 1:].contiguous().view(-1)
            )
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values if use_cache else None
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Trim inputs for incremental decoding with a KV cache.

        NOTE(review): keeps only the last token once the cache is non-empty,
        which assumes exactly one new token per generation step — fine for
        standard generate(), but would drop tokens for multi-token steps
        (e.g. assisted decoding); ``input_ids[:, past_length:]`` would be the
        more general slice. Confirm against supported generation modes.
        """
        past_length = 0
        if past_key_values is not None:
            past_length = past_key_values.get_seq_length()
        if past_length > 0:
            input_ids = input_ids[:, -1:]
        # inputs_embeds are only usable on the first (prefill) step; after
        # that, generation works from token ids.
        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs.update({
            "past_key_values": past_key_values,
            "use_cache": True,
            "attention_mask": attention_mask,
        })
        return model_inputs