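"""VanFast: a decoder-only causal language model packaged for Hugging Face Transformers.

The module defines `VanFastConfig` (a `PretrainedConfig`) and `VanFastForCausalLM`
(a `PreTrainedModel` with `GenerationMixin`), built from RMSNorm, rotary position
embeddings, grouped-query attention, and SwiGLU feed-forward blocks. Activations are
routed through `safe_tensor` throughout as a defensive guard against NaN/Inf
propagation and overflow in low-precision inference.
"""
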
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PretrainedConfig, PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast


def safe_tensor(x, clamp=30.0):
    """Replace NaN/Inf entries with finite values and clamp x to [-clamp, clamp]."""
    x = torch.nan_to_num(
        x,
        nan=0.0,
        posinf=clamp,
        neginf=-clamp,
    )
    x = torch.clamp(x, min=-clamp, max=clamp)
    return x


class VanFastConfig(PretrainedConfig):
    """Configuration for the VanFast decoder-only transformer."""

    model_type = "van_fast_transformer"

    def __init__(
        self,
        vocab_size=50257,
        block_size=1024,
        d_model=1024,
        n_layer=18,
        n_head=16,
        n_kv_head=4,
        d_ff=4096,
        dropout=0.0,
        use_qk_norm=True,
        initializer_range=0.02,
        pad_token_id=None,
        eos_token_id=None,
        bos_token_id=None,
        use_cache=True,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            bos_token_id=bos_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.block_size = block_size
        self.d_model = d_model
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_kv_head = n_kv_head
        self.d_ff = d_ff
        self.dropout = dropout
        self.use_qk_norm = use_qk_norm
        self.initializer_range = initializer_range

        self.is_decoder = True
        self.is_encoder_decoder = False
        self.tie_word_embeddings = False
        self.use_cache = use_cache


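# With the defaults above, each attention head has head_dim = d_model // n_head
# = 1024 // 16 = 64 features, and grouped-query attention shares each of the
# n_kv_head = 4 key/value heads across n_head // n_kv_head = 4 query heads.

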
class HFRMSNorm(nn.Module):
    """RMSNorm computed in float32, with defensive clamping of inputs and outputs."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x = safe_tensor(x, clamp=30.0)

        # Compute the mean-square statistic in float32 for stability.
        x_float = x.float()
        var = x_float.pow(2).mean(dim=-1, keepdim=True)
        var = torch.nan_to_num(var, nan=1.0, posinf=1.0, neginf=1.0)
        var = torch.clamp(var, min=0.0, max=1e6)

        y = x_float * torch.rsqrt(var + self.eps)
        y = y.to(dtype=x.dtype) * self.weight.to(dtype=x.dtype)
        y = safe_tensor(y, clamp=30.0)

        return y


class HFRotaryEmbedding(nn.Module):
    """Precomputed rotary position embedding (RoPE) tables for one attention head."""

    def __init__(self, dim: int, max_seq_len: int, base: float = 10000.0):
        super().__init__()

        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
        )

        t = torch.arange(max_seq_len, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)

        cos = freqs.cos()
        sin = freqs.sin()

        # Cache as (1, 1, max_seq_len, dim // 2) so the tables broadcast over batch and heads.
        self.register_buffer("cos_cached", cos[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", sin[None, None, :, :], persistent=False)

    def forward(self, x, seq_len: int, offset: int = 0):
        end = offset + seq_len

        max_len = self.cos_cached.shape[2]
        if end > max_len:
            # The requested positions run past the precomputed table; fall back to
            # the last seq_len cached positions rather than indexing out of range.
            offset = max(0, max_len - seq_len)
            end = offset + seq_len

        cos = self.cos_cached[:, :, offset:end, :].to(device=x.device, dtype=x.dtype)
        sin = self.sin_cached[:, :, offset:end, :].to(device=x.device, dtype=x.dtype)

        return cos, sin


def hf_apply_rope(q, k, cos, sin):
    """Rotate the (even, odd) feature pairs of q and k by the cached cos/sin angles."""
    q1 = q[..., ::2]
    q2 = q[..., 1::2]

    k1 = k[..., ::2]
    k2 = k[..., 1::2]

    q_rot = torch.stack(
        [
            q1 * cos - q2 * sin,
            q1 * sin + q2 * cos,
        ],
        dim=-1,
    ).flatten(-2)

    k_rot = torch.stack(
        [
            k1 * cos - k2 * sin,
            k1 * sin + k2 * cos,
        ],
        dim=-1,
    ).flatten(-2)

    q_rot = safe_tensor(q_rot, clamp=10.0)
    k_rot = safe_tensor(k_rot, clamp=10.0)

    return q_rot, k_rot


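# For position p and pair index i with head dimension d, the rotation applied above is
#     theta_i      = p / base ** (2 * i / d)
#     x_rot[2i]    = x[2i] * cos(theta_i) - x[2i + 1] * sin(theta_i)
#     x_rot[2i+1]  = x[2i] * sin(theta_i) + x[2i + 1] * cos(theta_i)
# i.e. the interleaved (even/odd) RoPE convention, rather than the rotate-half
# layout used by some other implementations.

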
class HFGQAAttention(nn.Module):
    """Grouped-query attention with RoPE, optional QK RMSNorm, and a KV cache."""

    def __init__(self, config: VanFastConfig):
        super().__init__()

        d_model = config.d_model
        n_head = config.n_head
        n_kv_head = config.n_kv_head

        assert d_model % n_head == 0
        assert n_head % n_kv_head == 0

        self.d_model = d_model
        self.n_head = n_head
        self.n_kv_head = n_kv_head
        self.head_dim = d_model // n_head
        self.num_groups = n_head // n_kv_head
        self.dropout = config.dropout
        self.block_size = config.block_size

        # RoPE rotates feature pairs, so the head dimension must be even.
        assert self.head_dim % 2 == 0

        self.q_proj = nn.Linear(d_model, n_head * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, n_kv_head * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, n_kv_head * self.head_dim, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

        if config.use_qk_norm:
            self.q_norm = HFRMSNorm(self.head_dim)
            self.k_norm = HFRMSNorm(self.head_dim)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

        self.rope = HFRotaryEmbedding(
            dim=self.head_dim,
            max_seq_len=config.block_size,
        )

    def forward(
        self,
        x,
        past_key_value=None,
        use_cache=False,
    ):
        x = safe_tensor(x, clamp=30.0)

        B, T, C = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        q = safe_tensor(q, clamp=30.0)
        k = safe_tensor(k, clamp=30.0)
        v = safe_tensor(v, clamp=30.0)

        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)

        q = self.q_norm(q)
        k = self.k_norm(k)

        q = safe_tensor(q, clamp=10.0)
        k = safe_tensor(k, clamp=10.0)
        v = safe_tensor(v, clamp=30.0)

        # Offset the rotary positions by the number of tokens already in the cache.
        past_len = 0
        if past_key_value is not None:
            past_len = past_key_value[0].shape[2]

        cos, sin = self.rope(q, T, offset=past_len)
        q, k = hf_apply_rope(q, k, cos, sin)

        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        # Keep at most block_size cached positions so the cache cannot grow unbounded.
        if k.shape[2] > self.block_size:
            k = k[:, :, -self.block_size:, :].contiguous()
            v = v[:, :, -self.block_size:, :].contiguous()

        present_key_value = (k, v) if use_cache else None

        # Expand the KV heads so every query head has a matching key/value head.
        k_attn = k
        v_attn = v
        if self.num_groups > 1:
            k_attn = k_attn.repeat_interleave(self.num_groups, dim=1)
            v_attn = v_attn.repeat_interleave(self.num_groups, dim=1)

        # Without a cache this is ordinary causal self-attention; with a cache the
        # single new query may attend to every cached position, so no mask is needed.
        is_causal = past_key_value is None

        y = F.scaled_dot_product_attention(
            q,
            k_attn,
            v_attn,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        y = safe_tensor(y, clamp=30.0)

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.o_proj(y)
        y = safe_tensor(y, clamp=30.0)

        return y, present_key_value


class HFSwiGLU(nn.Module):
    """SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x))."""

    def __init__(self, config: VanFastConfig):
        super().__init__()

        self.w1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.w2 = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.w3 = nn.Linear(config.d_model, config.d_ff, bias=False)

    def forward(self, x):
        x = safe_tensor(x, clamp=30.0)

        a = self.w1(x)
        b = self.w3(x)

        a = safe_tensor(a, clamp=30.0)
        b = safe_tensor(b, clamp=30.0)

        y = F.silu(a) * b
        y = safe_tensor(y, clamp=30.0)

        y = self.w2(y)
        y = safe_tensor(y, clamp=30.0)

        return y


class HFDecoderBlock(nn.Module):
    def __init__(self, config: VanFastConfig):
        super().__init__()

        self.attn_norm = HFRMSNorm(config.d_model)
        self.attn = HFGQAAttention(config)

        self.ffn_norm = HFRMSNorm(config.d_model)
        self.ffn = HFSwiGLU(config)

    def forward(
        self,
        x,
        past_key_value=None,
        use_cache=False,
    ):
        x = safe_tensor(x, clamp=30.0)

        a, present_key_value = self.attn(
            self.attn_norm(x),
            past_key_value=past_key_value,
            use_cache=use_cache,
        )

        a = safe_tensor(a, clamp=30.0)
        x = safe_tensor(x + a, clamp=30.0)

        f = self.ffn(self.ffn_norm(x))
        f = safe_tensor(f, clamp=30.0)
        x = safe_tensor(x + f, clamp=30.0)

        return x, present_key_value


class VanFastForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = VanFastConfig
    base_model_prefix = "van_fast"
    supports_gradient_checkpointing = False
    _supports_cache_class = False

    def __init__(self, config: VanFastConfig):
        super().__init__(config)

        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.drop = nn.Dropout(config.dropout)

        self.blocks = nn.ModuleList([
            HFDecoderBlock(config)
            for _ in range(config.n_layer)
        ])

        self.norm = HFRMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.post_init()

    def _init_weights(self, module):
        std = getattr(self.config, "initializer_range", 0.02)

        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)

    def get_input_embeddings(self):
        return self.token_emb

    def set_input_embeddings(self, value):
        self.token_emb = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _normalize_past(self, past_key_values):
        """Return a per-layer list of (k, v) tuples (or None), one entry per block."""
        if past_key_values is None:
            return [None] * len(self.blocks)

        if isinstance(past_key_values, tuple):
            past_key_values = list(past_key_values)

        if len(past_key_values) < len(self.blocks):
            past_key_values = past_key_values + [None] * (
                len(self.blocks) - len(past_key_values)
            )

        return past_key_values

    def forward(
        self,
        input_ids=None,
        labels=None,
        attention_mask=None,
        past_key_values=None,
        use_cache=None,
        return_dict=True,
        **kwargs,
    ):
        # attention_mask is accepted for API compatibility but is not applied;
        # the model assumes unpadded, left-to-right causal inputs.
        if input_ids is None:
            raise ValueError("input_ids is required")

        if use_cache is None:
            use_cache = getattr(self.config, "use_cache", True)

        has_past = past_key_values is not None

        # With a cache, only the newest token needs to be processed.
        if has_past and input_ids.shape[1] > 1:
            input_ids = input_ids[:, -1:]

        # Without a cache, crop inputs that exceed the context window.
        if not has_past and input_ids.shape[1] > self.config.block_size:
            input_ids = input_ids[:, -self.config.block_size:]
            if labels is not None:
                labels = labels[:, -self.config.block_size:]

        past_key_values = self._normalize_past(past_key_values)

        x = self.token_emb(input_ids)
        x = safe_tensor(x, clamp=30.0)

        x = self.drop(x)

        presents = [] if use_cache else None

        for i, block in enumerate(self.blocks):
            layer_past = past_key_values[i]

            x, present = block(
                x,
                past_key_value=layer_past,
                use_cache=use_cache,
            )

            if use_cache:
                presents.append(present)

        x = self.norm(x)
        x = safe_tensor(x, clamp=30.0)

        logits = self.lm_head(x)

        # Compute logits in float32 and clamp them so downstream softmax/sampling
        # never sees NaN or Inf.
        logits = logits.float()
        logits = torch.nan_to_num(
            logits,
            nan=0.0,
            posinf=80.0,
            neginf=-80.0,
        )
        logits = torch.clamp(logits, min=-80.0, max=80.0)

        loss = None

        if labels is not None:
            # Standard next-token objective: position t predicts token t + 1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            if shift_logits.numel() > 0:
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=-100,
                )

        past_out = tuple(presents) if use_cache else None

        if not return_dict:
            if loss is None:
                return (logits, past_out)
            return (loss, logits, past_out)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_out,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=True,
        **kwargs,
    ):
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        elif input_ids.shape[1] > self.config.block_size:
            input_ids = input_ids[:, -self.config.block_size:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        if past_key_values is None:
            return None

        reordered = []

        for layer_past in past_key_values:
            if layer_past is None:
                reordered.append(None)
                continue

            k, v = layer_past
            reordered.append(
                (
                    k.index_select(0, beam_idx.to(k.device)),
                    v.index_select(0, beam_idx.to(v.device)),
                )
            )

        return tuple(reordered)
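

# Minimal usage sketch: the tiny hyperparameters below are illustrative only and
# are not values prescribed anywhere in this module.
if __name__ == "__main__":
    config = VanFastConfig(
        vocab_size=1000,
        block_size=128,
        d_model=64,
        n_layer=2,
        n_head=4,
        n_kv_head=2,
        d_ff=256,
        eos_token_id=0,
    )
    model = VanFastForCausalLM(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 16))

    # Forward pass with labels returns a loss alongside the logits.
    out = model(input_ids=input_ids, labels=input_ids)
    print("loss:", out.loss.item(), "logits:", tuple(out.logits.shape))

    # Greedy decoding through GenerationMixin exercises the KV-cache path.
    generated = model.generate(input_ids, max_new_tokens=8, do_sample=False)
    print("generated:", tuple(generated.shape))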