import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin from transformers.modeling_outputs import CausalLMOutputWithPast class HilbertLMConfig(PretrainedConfig): model_type = "HilbertLM" def __init__( self, vocab_size=49152, hidden_size=576, num_hidden_layers=30, num_attention_heads=9, num_key_value_heads=3, block_size=2048, use_layernorm=False, use_swiglu=True, tie_word_embeddings=False, **kwargs, ): self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.block_size = block_size self.use_layernorm = use_layernorm self.use_swiglu = use_swiglu super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) class RoPE(nn.Module): def __init__(self, head_dim, max_seq_len=2048): super().__init__() pos = torch.arange(max_seq_len, dtype=torch.float) theta = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim)) angles = torch.outer(pos, theta) embedding = torch.cat((angles, angles), dim=-1) self.register_buffer('cos', embedding.cos()[None, None, :, :]) self.register_buffer('sin', embedding.sin()[None, None, :, :]) def forward(self, x): seq_len = x.shape[2] cos = self.cos[:, :, :seq_len, :].to(x.dtype) sin = self.sin[:, :, :seq_len, :].to(x.dtype) x1, x2 = x.chunk(2, dim=-1) x_rotated_half = torch.cat((-x2, x1), dim=-1) return (x * cos) + (x_rotated_half * sin) class SwiGLU(nn.Module): def forward(self, x): x, gate = x.chunk(2, dim=-1) return F.silu(x) * gate class TransformerBlock(nn.Module): def __init__(self, hidden_size, num_attention_heads, max_len, num_key_value_heads, use_layernorm=False, use_swiglu=True): super().__init__() self.n_head = num_attention_heads self.n_kv_head = num_key_value_heads self.head_dim = hidden_size // num_attention_heads self.hidden_size = hidden_size self.q_size = self.n_head * self.head_dim self.kv_size = self.n_kv_head * self.head_dim total_qkv_dim = self.q_size + 2 * self.kv_size self.rope = RoPE(self.head_dim, max_len) ffn_hidden = int(hidden_size * 8/3) if use_swiglu else int(hidden_size * 4) self.ln1 = nn.LayerNorm(hidden_size) if use_layernorm else nn.RMSNorm(hidden_size) self.qkv_proj = nn.Linear(hidden_size, total_qkv_dim, bias=False) self.c_proj = nn.Linear(hidden_size, hidden_size, bias=False) self.ln2 = nn.LayerNorm(hidden_size) if use_layernorm else nn.RMSNorm(hidden_size) if use_swiglu: self.mlp = nn.Sequential( nn.Linear(hidden_size, 2 * ffn_hidden, bias=False), SwiGLU(), nn.Linear(ffn_hidden, hidden_size, bias=False) ) else: self.mlp = nn.Sequential( nn.Linear(hidden_size, ffn_hidden, bias=False), nn.GELU(), nn.Linear(ffn_hidden, hidden_size, bias=False) ) def forward(self, x): residual = x x_norm = self.ln1(x) qkv = self.qkv_proj(x_norm) q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=2) B, T, _ = q.size() q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) k = k.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2) v = v.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2) q = self.rope(q) k = self.rope(k) attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True) attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, self.hidden_size) x = residual + self.c_proj(attn_out) x = x + self.mlp(self.ln2(x)) return x class HilbertLM(nn.Module): def __init__(self, vocab_size, hidden_size, num_hidden_layers, num_attention_heads, max_len, num_key_value_heads, use_layernorm=False, use_swiglu=True): super().__init__() self.token_embedding = nn.Embedding(vocab_size, hidden_size) self.layers = nn.ModuleList([ TransformerBlock(hidden_size, num_attention_heads, max_len, num_key_value_heads, use_layernorm, use_swiglu) for _ in range(num_hidden_layers) ]) self.final_norm = nn.LayerNorm(hidden_size) if use_layernorm else nn.RMSNorm(hidden_size) self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False) self._init_weights() def _init_weights(self): nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02) for module in self.modules(): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)): nn.init.ones_(module.weight) if hasattr(module, 'bias') and module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward(self, x): x = self.token_embedding(x) for layer in self.layers: x = layer(x) x = self.final_norm(x) logits = self.lm_head(x) return logits class HilbertLMForCausalLM(PreTrainedModel, GenerationMixin): config_class = HilbertLMConfig _keys_to_ignore_on_load_missing = ["model.lm_head.weight"] def __init__(self, config): super().__init__(config) self.config = config self.model = HilbertLM( vocab_size=config.vocab_size, hidden_size=config.hidden_size, num_hidden_layers=config.num_hidden_layers, num_attention_heads=config.num_attention_heads, max_len=config.block_size, num_key_value_heads=config.num_key_value_heads, use_layernorm=config.use_layernorm, use_swiglu=config.use_swiglu ) if config.tie_word_embeddings: self.all_tied_weights_keys = {"model.token_embedding.weight": "model.lm_head.weight"} else: self.all_tied_weights_keys = {} def tie_weights(self, missing_keys=None, recompute_mapping=True): if self.config.tie_word_embeddings: self.model.lm_head.weight = self.model.token_embedding.weight def get_input_embeddings(self): return self.model.token_embedding def set_input_embeddings(self, value): self.model.token_embedding = value def get_output_embeddings(self): return self.model.lm_head def set_output_embeddings(self, new_embeddings): self.model.lm_head = new_embeddings def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): logits = self.model(input_ids) loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) return CausalLMOutputWithPast(loss=loss, logits=logits) def prepare_inputs_for_generation(self, input_ids, **kwargs): return {"input_ids": input_ids}