import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PretrainedConfig, PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast


def safe_tensor(x, clamp=30.0):
    """Replace NaN/Inf values and clamp the tensor to [-clamp, clamp]."""
    x = torch.nan_to_num(x, nan=0.0, posinf=clamp, neginf=-clamp)
    x = torch.clamp(x, min=-clamp, max=clamp)
    return x


class VanFastConfig(PretrainedConfig):
    model_type = "van_fast_transformer"

    def __init__(
        self,
        vocab_size=50257,
        block_size=1024,
        d_model=1024,
        n_layer=18,
        n_head=16,
        n_kv_head=4,
        d_ff=4096,
        dropout=0.0,
        use_qk_norm=True,
        initializer_range=0.02,
        pad_token_id=None,
        eos_token_id=None,
        bos_token_id=None,
        use_cache=True,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            bos_token_id=bos_token_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.d_model = d_model
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_kv_head = n_kv_head
        self.d_ff = d_ff
        self.dropout = dropout
        self.use_qk_norm = use_qk_norm
        self.initializer_range = initializer_range
        self.is_decoder = True
        self.is_encoder_decoder = False
        self.tie_word_embeddings = False
        self.use_cache = use_cache


class HFRMSNorm(nn.Module):
    """RMSNorm computed in float32 with NaN/Inf guards."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x = safe_tensor(x, clamp=30.0)
        x_float = x.float()
        var = x_float.pow(2).mean(dim=-1, keepdim=True)
        var = torch.nan_to_num(var, nan=1.0, posinf=1.0, neginf=1.0)
        var = torch.clamp(var, min=0.0, max=1e6)
        y = x_float * torch.rsqrt(var + self.eps)
        y = y.to(dtype=x.dtype) * self.weight.to(dtype=x.dtype)
        y = safe_tensor(y, clamp=30.0)
        return y


class HFRotaryEmbedding(nn.Module):
    """Rotary position embeddings precomputed for max_seq_len positions."""

    def __init__(self, dim: int, max_seq_len: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
        )
        t = torch.arange(max_seq_len, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        self.register_buffer("cos_cached", cos[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", sin[None, None, :, :], persistent=False)

    def forward(self, x, seq_len: int, offset: int = 0):
        end = offset + seq_len
        max_len = self.cos_cached.shape[2]
        if end > max_len:
            # If the requested positions run past block_size, clamp to the last valid window.
            offset = max(0, max_len - seq_len)
            end = offset + seq_len
        cos = self.cos_cached[:, :, offset:end, :].to(device=x.device, dtype=x.dtype)
        sin = self.sin_cached[:, :, offset:end, :].to(device=x.device, dtype=x.dtype)
        return cos, sin


def hf_apply_rope(q, k, cos, sin):
    """Apply rotary embeddings to the interleaved even/odd channels of q and k."""
    q1 = q[..., ::2]
    q2 = q[..., 1::2]
    k1 = k[..., ::2]
    k2 = k[..., 1::2]
    q_rot = torch.stack(
        [
            q1 * cos - q2 * sin,
            q1 * sin + q2 * cos,
        ],
        dim=-1,
    ).flatten(-2)
    k_rot = torch.stack(
        [
            k1 * cos - k2 * sin,
            k1 * sin + k2 * cos,
        ],
        dim=-1,
    ).flatten(-2)
    q_rot = safe_tensor(q_rot, clamp=10.0)
    k_rot = safe_tensor(k_rot, clamp=10.0)
    return q_rot, k_rot


class HFGQAAttention(nn.Module):
    """Grouped-query attention with RoPE, optional QK-norm, and a bounded KV cache."""

    def __init__(self, config: VanFastConfig):
        super().__init__()
        d_model = config.d_model
        n_head = config.n_head
        n_kv_head = config.n_kv_head
        assert d_model % n_head == 0
        assert n_head % n_kv_head == 0
        self.d_model = d_model
        self.n_head = n_head
        self.n_kv_head = n_kv_head
        self.head_dim = d_model // n_head
        self.num_groups = n_head // n_kv_head
        self.dropout = config.dropout
        self.block_size = config.block_size
        assert self.head_dim % 2 == 0

        self.q_proj = nn.Linear(d_model, n_head * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, n_kv_head * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, n_kv_head * self.head_dim, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

        if config.use_qk_norm:
            self.q_norm = HFRMSNorm(self.head_dim)
            self.k_norm = HFRMSNorm(self.head_dim)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

        self.rope = HFRotaryEmbedding(
            dim=self.head_dim,
            max_seq_len=config.block_size,
        )

    def forward(
        self,
        x,
        past_key_value=None,
        use_cache=False,
    ):
        x = safe_tensor(x, clamp=30.0)
        B, T, C = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q = safe_tensor(q, clamp=30.0)
        k = safe_tensor(k, clamp=30.0)
        v = safe_tensor(v, clamp=30.0)

        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)

        q = self.q_norm(q)
        k = self.k_norm(k)
        q = safe_tensor(q, clamp=10.0)
        k = safe_tensor(k, clamp=10.0)
        v = safe_tensor(v, clamp=30.0)

        past_len = 0
        if past_key_value is not None:
            past_k, past_v = past_key_value
            past_len = past_k.shape[2]

        cos, sin = self.rope(q, T, offset=past_len)
        q, k = hf_apply_rope(q, k, cos, sin)

        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
            # Keep the cache length within block_size.
            if k.shape[2] > self.block_size:
                k = k[:, :, -self.block_size:, :].contiguous()
                v = v[:, :, -self.block_size:, :].contiguous()

        present_key_value = (k, v) if use_cache else None

        k_attn = k
        v_attn = v
        if self.num_groups > 1:
            k_attn = k_attn.repeat_interleave(self.num_groups, dim=1)
            v_attn = v_attn.repeat_interleave(self.num_groups, dim=1)

        # Causal masking during prefill; during decode the query is only the
        # newest token, so it may attend to the entire cache.
        is_causal = past_key_value is None
        y = F.scaled_dot_product_attention(
            q,
            k_attn,
            v_attn,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        y = safe_tensor(y, clamp=30.0)

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.o_proj(y)
        y = safe_tensor(y, clamp=30.0)
        return y, present_key_value


class HFSwiGLU(nn.Module):
    """SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x))."""

    def __init__(self, config: VanFastConfig):
        super().__init__()
        self.w1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.w2 = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.w3 = nn.Linear(config.d_model, config.d_ff, bias=False)

    def forward(self, x):
        x = safe_tensor(x, clamp=30.0)
        a = self.w1(x)
        b = self.w3(x)
        a = safe_tensor(a, clamp=30.0)
        b = safe_tensor(b, clamp=30.0)
        y = F.silu(a) * b
        y = safe_tensor(y, clamp=30.0)
        y = self.w2(y)
        y = safe_tensor(y, clamp=30.0)
        return y


class HFDecoderBlock(nn.Module):
    """Pre-norm decoder block: RMSNorm -> attention -> residual, then RMSNorm -> FFN -> residual."""

    def __init__(self, config: VanFastConfig):
        super().__init__()
        self.attn_norm = HFRMSNorm(config.d_model)
        self.attn = HFGQAAttention(config)
        self.ffn_norm = HFRMSNorm(config.d_model)
        self.ffn = HFSwiGLU(config)

    def forward(
        self,
        x,
        past_key_value=None,
        use_cache=False,
    ):
        x = safe_tensor(x, clamp=30.0)
        a, present_key_value = self.attn(
            self.attn_norm(x),
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        a = safe_tensor(a, clamp=30.0)
        x = safe_tensor(x + a, clamp=30.0)
        f = self.ffn(self.ffn_norm(x))
        f = safe_tensor(f, clamp=30.0)
        x = safe_tensor(x + f, clamp=30.0)
        return x, present_key_value


class VanFastForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = VanFastConfig
    base_model_prefix = "van_fast"
    supports_gradient_checkpointing = False
    _supports_cache_class = False

    def __init__(self, config: VanFastConfig):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList(
            [HFDecoderBlock(config) for _ in range(config.n_layer)]
        )
        self.norm = HFRMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    def _init_weights(self, module):
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)

    def get_input_embeddings(self):
        return self.token_emb

    def set_input_embeddings(self, value):
        self.token_emb = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _normalize_past(self, past_key_values):
        # Accept None, a tuple, or a short list and pad it to one entry per layer.
        if past_key_values is None:
            return [None] * len(self.blocks)
        if isinstance(past_key_values, tuple):
            past_key_values = list(past_key_values)
        if len(past_key_values) < len(self.blocks):
            past_key_values = past_key_values + [None] * (
                len(self.blocks) - len(past_key_values)
            )
        return past_key_values

    def forward(
        self,
        input_ids=None,
        labels=None,
        attention_mask=None,
        past_key_values=None,
        use_cache=None,
        return_dict=True,
        **kwargs,
    ):
        if input_ids is None:
            raise ValueError("input_ids is required")
        if use_cache is None:
            use_cache = getattr(self.config, "use_cache", True)

        # attention_mask is accepted for generate() compatibility but is not applied;
        # causal masking is handled inside the attention blocks.
        has_past = past_key_values is not None
        # When a cache is present, process only the newly added token.
        if has_past and input_ids.shape[1] > 1:
            input_ids = input_ids[:, -1:]
        # Only during cache-less prefill, truncate to the last block_size tokens.
        if not has_past and input_ids.shape[1] > self.config.block_size:
            input_ids = input_ids[:, -self.config.block_size:]
            if labels is not None:
                labels = labels[:, -self.config.block_size:]

        past_key_values = self._normalize_past(past_key_values)

        x = self.token_emb(input_ids)
        x = safe_tensor(x, clamp=30.0)
        x = self.drop(x)

        presents = [] if use_cache else None
        for i, block in enumerate(self.blocks):
            layer_past = past_key_values[i]
            x, present = block(
                x,
                past_key_value=layer_past,
                use_cache=use_cache,
            )
            if use_cache:
                presents.append(present)

        x = self.norm(x)
        x = safe_tensor(x, clamp=30.0)

        logits = self.lm_head(x)
        logits = logits.float()
        logits = torch.nan_to_num(logits, nan=0.0, posinf=80.0, neginf=-80.0)
        logits = torch.clamp(logits, min=-80.0, max=80.0)

        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            if shift_logits.numel() > 0:
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=-100,
                )

        past_out = tuple(presents) if use_cache else None

        if not return_dict:
            if loss is None:
                return (logits, past_out)
            return (loss, logits, past_out)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_out,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=True,
        **kwargs,
    ):
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        else:
            if input_ids.shape[1] > self.config.block_size:
                input_ids = input_ids[:, -self.config.block_size:]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        if past_key_values is None:
            return None
        reordered = []
        for layer_past in past_key_values:
            if layer_past is None:
                reordered.append(None)
                continue
            k, v = layer_past
            reordered.append(
                (
                    k.index_select(0, beam_idx.to(k.device)),
                    v.index_select(0, beam_idx.to(v.device)),
                )
            )
        return tuple(reordered)
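

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): registers the custom
# classes with the Auto* factories and runs a small smoke test. The Auto*
# registration calls are standard transformers APIs; the tiny config values,
# random token ids, and prints below are illustrative assumptions only (no
# tokenizer or real checkpoint involved).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModelForCausalLM

    # Make the model discoverable via the Auto* factories.
    AutoConfig.register("van_fast_transformer", VanFastConfig)
    AutoModelForCausalLM.register(VanFastConfig, VanFastForCausalLM)

    # Deliberately tiny config for a CPU smoke test; sizes are arbitrary but
    # must keep d_model % n_head == 0, n_head % n_kv_head == 0, and head_dim even.
    config = VanFastConfig(
        vocab_size=1000,
        block_size=64,
        d_model=64,
        n_layer=2,
        n_head=4,
        n_kv_head=2,
        d_ff=128,
    )
    model = VanFastForCausalLM(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 16))

    # Forward pass with labels exercises the shifted cross-entropy loss path.
    with torch.no_grad():
        out = model(input_ids=input_ids, labels=input_ids)
    print("loss:", out.loss.item(), "logits shape:", tuple(out.logits.shape))

    # Greedy decoding exercises the KV cache and prepare_inputs_for_generation.
    with torch.no_grad():
        generated = model.generate(input_ids, max_new_tokens=8, do_sample=False)
    print("generated shape:", tuple(generated.shape))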