# summerV2 / modeling_van_fast.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
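# Decoder-only causal language model for the Hugging Face transformers
# interface, built from RMSNorm, rotary position embeddings (RoPE),
# grouped-query attention (GQA), and a SwiGLU feed-forward block.
# Intermediate activations are aggressively sanitized (NaN/Inf replacement
# plus clamping) throughout to keep training and generation numerically stable.


# Replace NaN/Inf values and clamp all elements to [-clamp, clamp].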
def safe_tensor(x, clamp=30.0):
x = torch.nan_to_num(
x,
nan=0.0,
posinf=clamp,
neginf=-clamp,
)
x = torch.clamp(x, min=-clamp, max=clamp)
return x
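# Model configuration. d_model must be divisible by n_head, n_head by
# n_kv_head (the GQA grouping factor), and the resulting head_dim must be
# even so the rotary embedding can rotate dimension pairs.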
class VanFastConfig(PretrainedConfig):
model_type = "van_fast_transformer"
def __init__(
self,
vocab_size=50257,
block_size=1024,
d_model=1024,
n_layer=18,
n_head=16,
n_kv_head=4,
d_ff=4096,
dropout=0.0,
use_qk_norm=True,
initializer_range=0.02,
pad_token_id=None,
eos_token_id=None,
bos_token_id=None,
use_cache=True,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
bos_token_id=bos_token_id,
**kwargs,
)
self.vocab_size = vocab_size
self.block_size = block_size
self.d_model = d_model
self.n_layer = n_layer
self.n_head = n_head
self.n_kv_head = n_kv_head
self.d_ff = d_ff
self.dropout = dropout
self.use_qk_norm = use_qk_norm
self.initializer_range = initializer_range
self.is_decoder = True
self.is_encoder_decoder = False
self.tie_word_embeddings = False
self.use_cache = use_cache
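# RMSNorm computed in float32; the variance is sanitized and clamped before
# the rsqrt, and the normalized result is cast back to the input dtype.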
class HFRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
x = safe_tensor(x, clamp=30.0)
x_float = x.float()
var = x_float.pow(2).mean(dim=-1, keepdim=True)
var = torch.nan_to_num(var, nan=1.0, posinf=1.0, neginf=1.0)
var = torch.clamp(var, min=0.0, max=1e6)
y = x_float * torch.rsqrt(var + self.eps)
y = y.to(dtype=x.dtype) * self.weight.to(dtype=x.dtype)
y = safe_tensor(y, clamp=30.0)
return y
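# Rotary position embedding with cos/sin tables precomputed for block_size
# positions and stored as non-persistent buffers shaped (1, 1, seq, head_dim / 2).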
class HFRotaryEmbedding(nn.Module):
def __init__(self, dim: int, max_seq_len: int, base: float = 10000.0):
super().__init__()
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
)
t = torch.arange(max_seq_len, dtype=torch.float32)
freqs = torch.einsum("i,j->ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
self.register_buffer("cos_cached", cos[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", sin[None, None, :, :], persistent=False)
def forward(self, x, seq_len: int, offset: int = 0):
end = offset + seq_len
max_len = self.cos_cached.shape[2]
if end > max_len:
# If the range exceeds block_size, clamp it to the last window of cached positions
offset = max(0, max_len - seq_len)
end = offset + seq_len
cos = self.cos_cached[:, :, offset:end, :].to(device=x.device, dtype=x.dtype)
sin = self.sin_cached[:, :, offset:end, :].to(device=x.device, dtype=x.dtype)
return cos, sin
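# Apply RoPE to q and k using the interleaved (even/odd) pairing:
#   (x_even, x_odd) -> (x_even * cos - x_odd * sin, x_even * sin + x_odd * cos)
# The rotated pairs are stacked on a new trailing axis, flattened back to
# head_dim, and clamped for numerical safety.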
def hf_apply_rope(q, k, cos, sin):
q1 = q[..., ::2]
q2 = q[..., 1::2]
k1 = k[..., ::2]
k2 = k[..., 1::2]
q_rot = torch.stack(
[
q1 * cos - q2 * sin,
q1 * sin + q2 * cos,
],
dim=-1,
).flatten(-2)
k_rot = torch.stack(
[
k1 * cos - k2 * sin,
k1 * sin + k2 * cos,
],
dim=-1,
).flatten(-2)
q_rot = safe_tensor(q_rot, clamp=10.0)
k_rot = safe_tensor(k_rot, clamp=10.0)
return q_rot, k_rot
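# Grouped-query attention: n_head query heads share n_kv_head key/value heads,
# so K/V are expanded with repeat_interleave before scaled_dot_product_attention.
# The KV cache is kept in the legacy (key, value) tuple format and trimmed to
# the most recent block_size positions.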
class HFGQAAttention(nn.Module):
def __init__(self, config: VanFastConfig):
super().__init__()
d_model = config.d_model
n_head = config.n_head
n_kv_head = config.n_kv_head
assert d_model % n_head == 0
assert n_head % n_kv_head == 0
self.d_model = d_model
self.n_head = n_head
self.n_kv_head = n_kv_head
self.head_dim = d_model // n_head
self.num_groups = n_head // n_kv_head
self.dropout = config.dropout
self.block_size = config.block_size
assert self.head_dim % 2 == 0
self.q_proj = nn.Linear(d_model, n_head * self.head_dim, bias=False)
self.k_proj = nn.Linear(d_model, n_kv_head * self.head_dim, bias=False)
self.v_proj = nn.Linear(d_model, n_kv_head * self.head_dim, bias=False)
self.o_proj = nn.Linear(d_model, d_model, bias=False)
if config.use_qk_norm:
self.q_norm = HFRMSNorm(self.head_dim)
self.k_norm = HFRMSNorm(self.head_dim)
else:
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
self.rope = HFRotaryEmbedding(
dim=self.head_dim,
max_seq_len=config.block_size,
)
def forward(
self,
x,
past_key_value=None,
use_cache=False,
):
x = safe_tensor(x, clamp=30.0)
B, T, C = x.shape
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
q = safe_tensor(q, clamp=30.0)
k = safe_tensor(k, clamp=30.0)
v = safe_tensor(v, clamp=30.0)
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
k = k.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
v = v.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
q = self.q_norm(q)
k = self.k_norm(k)
q = safe_tensor(q, clamp=10.0)
k = safe_tensor(k, clamp=10.0)
v = safe_tensor(v, clamp=30.0)
past_len = 0
if past_key_value is not None:
past_k, past_v = past_key_value
past_len = past_k.shape[2]
cos, sin = self.rope(q, T, offset=past_len)
q, k = hf_apply_rope(q, k, cos, sin)
if past_key_value is not None:
past_k, past_v = past_key_value
k = torch.cat([past_k, k], dim=2)
v = torch.cat([past_v, v], dim=2)
# Limit the cache length to at most block_size
if k.shape[2] > self.block_size:
k = k[:, :, -self.block_size:, :].contiguous()
v = v[:, :, -self.block_size:, :].contiguous()
present_key_value = (k, v) if use_cache else None
k_attn = k
v_attn = v
if self.num_groups > 1:
k_attn = k_attn.repeat_interleave(self.num_groups, dim=1)
v_attn = v_attn.repeat_interleave(self.num_groups, dim=1)
# Causal masking during prefill; during decode the query is only the newest token, so it can attend to the whole cache
is_causal = past_key_value is None
y = F.scaled_dot_product_attention(
q,
k_attn,
v_attn,
attn_mask=None,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)
y = safe_tensor(y, clamp=30.0)
y = y.transpose(1, 2).contiguous().view(B, T, C)
y = self.o_proj(y)
y = safe_tensor(y, clamp=30.0)
return y, present_key_value
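# SwiGLU feed-forward network: w2(silu(w1(x)) * w3(x)), all projections bias-free.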
class HFSwiGLU(nn.Module):
def __init__(self, config: VanFastConfig):
super().__init__()
self.w1 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.w2 = nn.Linear(config.d_ff, config.d_model, bias=False)
self.w3 = nn.Linear(config.d_model, config.d_ff, bias=False)
def forward(self, x):
x = safe_tensor(x, clamp=30.0)
a = self.w1(x)
b = self.w3(x)
a = safe_tensor(a, clamp=30.0)
b = safe_tensor(b, clamp=30.0)
y = F.silu(a) * b
y = safe_tensor(y, clamp=30.0)
y = self.w2(y)
y = safe_tensor(y, clamp=30.0)
return y
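# Pre-norm residual decoder block: x = x + attn(norm(x)), then x = x + ffn(norm(x)).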
class HFDecoderBlock(nn.Module):
def __init__(self, config: VanFastConfig):
super().__init__()
self.attn_norm = HFRMSNorm(config.d_model)
self.attn = HFGQAAttention(config)
self.ffn_norm = HFRMSNorm(config.d_model)
self.ffn = HFSwiGLU(config)
def forward(
self,
x,
past_key_value=None,
use_cache=False,
):
x = safe_tensor(x, clamp=30.0)
a, present_key_value = self.attn(
self.attn_norm(x),
past_key_value=past_key_value,
use_cache=use_cache,
)
a = safe_tensor(a, clamp=30.0)
x = safe_tensor(x + a, clamp=30.0)
f = self.ffn(self.ffn_norm(x))
f = safe_tensor(f, clamp=30.0)
x = safe_tensor(x + f, clamp=30.0)
return x, present_key_value
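# Causal LM wrapper exposing the standard transformers API (forward with
# labels/past_key_values, prepare_inputs_for_generation, _reorder_cache).
# attention_mask is accepted for interface compatibility but not applied;
# masking relies entirely on the causal flag inside the attention layer.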
class VanFastForCausalLM(PreTrainedModel, GenerationMixin):
config_class = VanFastConfig
base_model_prefix = "van_fast"
supports_gradient_checkpointing = False
_supports_cache_class = False
def __init__(self, config: VanFastConfig):
super().__init__(config)
self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
self.drop = nn.Dropout(config.dropout)
self.blocks = nn.ModuleList([
HFDecoderBlock(config)
for _ in range(config.n_layer)
])
self.norm = HFRMSNorm(config.d_model)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
self.post_init()
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", 0.02)
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=std)
def get_input_embeddings(self):
return self.token_emb
def set_input_embeddings(self, value):
self.token_emb = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def _normalize_past(self, past_key_values):
if past_key_values is None:
return [None] * len(self.blocks)
if isinstance(past_key_values, tuple):
past_key_values = list(past_key_values)
if len(past_key_values) < len(self.blocks):
past_key_values = past_key_values + [None] * (
len(self.blocks) - len(past_key_values)
)
return past_key_values
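# forward trims input_ids to the newest token when a cache is supplied, trims
# cache-free prefill inputs to the last block_size tokens, and returns float32
# logits sanitized and clamped to [-80, 80] plus an optional shifted
# cross-entropy loss when labels are given.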
def forward(
self,
input_ids=None,
labels=None,
attention_mask=None,
past_key_values=None,
use_cache=None,
return_dict=True,
**kwargs,
):
if input_ids is None:
raise ValueError("input_ids is required")
if use_cache is None:
use_cache = getattr(self.config, "use_cache", True)
has_past = past_key_values is not None
# When a cache is supplied, process only the newly added token
if has_past and input_ids.shape[1] > 1:
input_ids = input_ids[:, -1:]
# Only during cache-free prefill, truncate to the last block_size tokens
if not has_past and input_ids.shape[1] > self.config.block_size:
input_ids = input_ids[:, -self.config.block_size:]
if labels is not None:
labels = labels[:, -self.config.block_size:]
past_key_values = self._normalize_past(past_key_values)
x = self.token_emb(input_ids)
x = safe_tensor(x, clamp=30.0)
x = self.drop(x)
presents = [] if use_cache else None
for i, block in enumerate(self.blocks):
layer_past = past_key_values[i]
x, present = block(
x,
past_key_value=layer_past,
use_cache=use_cache,
)
if use_cache:
presents.append(present)
x = self.norm(x)
x = safe_tensor(x, clamp=30.0)
logits = self.lm_head(x)
logits = logits.float()
logits = torch.nan_to_num(
logits,
nan=0.0,
posinf=80.0,
neginf=-80.0,
)
logits = torch.clamp(logits, min=-80.0, max=80.0)
loss = None
if labels is not None:
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous()
if shift_logits.numel() > 0:
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100,
)
past_out = tuple(presents) if use_cache else None
if not return_dict:
if loss is None:
return (logits, past_out)
return (loss, logits, past_out)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=past_out,
hidden_states=None,
attentions=None,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
use_cache=True,
**kwargs,
):
if past_key_values is not None:
input_ids = input_ids[:, -1:]
else:
if input_ids.shape[1] > self.config.block_size:
input_ids = input_ids[:, -self.config.block_size:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"use_cache": use_cache,
}
def _reorder_cache(self, past_key_values, beam_idx):
if past_key_values is None:
return None
reordered = []
for layer_past in past_key_values:
if layer_past is None:
reordered.append(None)
continue
k, v = layer_past
reordered.append(
(
k.index_select(0, beam_idx.to(k.device)),
v.index_select(0, beam_idx.to(v.device)),
)
)
return tuple(reordered)
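

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the published checkpoint): the
    # reduced sizes below are illustrative assumptions chosen so the script
    # runs quickly on CPU, not the defaults shipped in VanFastConfig.
    config = VanFastConfig(
        vocab_size=256,
        block_size=64,
        d_model=64,
        n_layer=2,
        n_head=4,
        n_kv_head=2,
        d_ff=128,
    )
    model = VanFastForCausalLM(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids=input_ids)
        print("logits:", tuple(out.logits.shape))  # (1, 8, 256)
        generated = model.generate(input_ids, max_new_tokens=4, do_sample=False)
        print("generated:", tuple(generated.shape))  # (1, 12)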