# Hilbert-135M-Base / modeling.py
# Commit 66c3b71 ("Add Hilbert-135M-Base") by baillietn.
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
class HilbertLMConfig(PretrainedConfig):
    """Configuration object for the Hilbert causal language model.

    Args:
        vocab_size: number of distinct token ids (embedding rows / LM-head outputs).
        hidden_size: model width.
        num_hidden_layers: number of transformer blocks in the stack.
        num_attention_heads: number of query heads per attention layer.
        num_key_value_heads: number of key/value heads shared across query groups.
        block_size: maximum sequence length; used to size the rotary tables.
        use_layernorm: use LayerNorm instead of RMSNorm for every norm layer.
        use_swiglu: use a SwiGLU MLP instead of a GELU MLP.
        tie_word_embeddings: share the token embedding with the LM head; handled
            by the base class.
        **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "HilbertLM"

    def __init__(
        self,
        vocab_size=49152,
        hidden_size=576,
        num_hidden_layers=30,
        num_attention_heads=9,
        num_key_value_heads=3,
        block_size=2048,
        use_layernorm=False,
        use_swiglu=True,
        tie_word_embeddings=False,
        **kwargs,
    ):
        # Model-specific fields are stored first; the generic options are then
        # delegated to the base class.
        self.vocab_size, self.hidden_size = vocab_size, hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads, self.num_key_value_heads = num_attention_heads, num_key_value_heads
        self.block_size = block_size
        self.use_layernorm, self.use_swiglu = use_layernorm, use_swiglu
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
class RoPE(nn.Module):
    """Rotary position embedding in the "rotate-half" formulation.

    Cos/sin tables for positions ``0 .. max_seq_len - 1`` are precomputed at
    construction. ``forward`` rotates a tensor whose sequence axis is dim 2 and
    whose last dim has size ``head_dim``, e.g. ``(batch, heads, seq, head_dim)``.

    Args:
        head_dim: size of the last dimension of the inputs; must be even.
        max_seq_len: number of positions to precompute tables for.
        base: frequency base of the rotation angles (10000 by default).

    Raises:
        ValueError: if ``head_dim`` is odd.
    """

    def __init__(self, head_dim, max_seq_len=2048, base=10000):
        super().__init__()
        if head_dim % 2 != 0:
            # An odd head_dim would build tables of width head_dim + 1 and fail in forward().
            raise ValueError(f"RoPE requires an even head_dim, got {head_dim}")
        pos = torch.arange(max_seq_len, dtype=torch.float)
        theta = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        angles = torch.outer(pos, theta)
        # Duplicate the frequencies so columns i and i + head_dim/2 share one angle.
        # This matches the rotate-half pairing used in forward().
        embedding = torch.cat((angles, angles), dim=-1)
        # The leading singleton dims broadcast over batch and heads.
        self.register_buffer('cos', embedding.cos()[None, None, :, :])
        self.register_buffer('sin', embedding.sin()[None, None, :, :])

    def forward(self, x):
        """Return ``x`` rotated by its position along dim 2 (same shape and dtype).

        Raises:
            ValueError: if the sequence is longer than the precomputed tables.
        """
        seq_len = x.shape[2]
        max_len = self.cos.shape[2]
        if seq_len > max_len:
            raise ValueError(
                f"RoPE tables cover {max_len} positions, got a sequence of length {seq_len}"
            )
        cos = self.cos[:, :, :seq_len, :].to(x.dtype)
        sin = self.sin[:, :, :seq_len, :].to(x.dtype)
        x1, x2 = x.chunk(2, dim=-1)
        x_rotated_half = torch.cat((-x2, x1), dim=-1)
        return (x * cos) + (x_rotated_half * sin)
class SwiGLU(nn.Module):
    """SiLU-gated linear unit applied to an already-projected tensor.

    The last dimension is split into two equal halves ``(a, b)``. The result is
    ``silu(a) * b``, whose last dimension is half that of the input.
    """

    def forward(self, x):
        activated, multiplier = torch.chunk(x, 2, dim=-1)
        return multiplier * F.silu(activated)
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: grouped-query causal self-attention with RoPE, then an MLP.

    Each sub-layer is wrapped in a residual connection:
    ``x = x + c_proj(attn(ln1(x)))`` followed by ``x = x + mlp(ln2(x))``.

    Args:
        hidden_size: model width; must be divisible by ``num_attention_heads``.
        num_attention_heads: number of query heads; must be divisible by
            ``num_key_value_heads``.
        max_len: number of positions the rotary tables are built for.
        num_key_value_heads: number of key/value heads shared across query groups.
        use_layernorm: use ``nn.LayerNorm`` instead of ``nn.RMSNorm``.
        use_swiglu: use a SwiGLU MLP (hidden width ``int(hidden_size * 8 / 3)``)
            instead of a GELU MLP (hidden width ``4 * hidden_size``).

    Raises:
        ValueError: if the head counts are non-positive or inconsistent, or if the
            per-head dimension is odd (rotary embeddings need an even size).
    """

    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        max_len,
        num_key_value_heads,
        use_layernorm=False,
        use_swiglu=True,
    ):
        super().__init__()
        # Validate up front. Otherwise these mismatches only surface in forward()
        # as opaque view/broadcast errors.
        if num_attention_heads <= 0 or num_key_value_heads <= 0:
            raise ValueError(
                f"num_attention_heads ({num_attention_heads}) and num_key_value_heads "
                f"({num_key_value_heads}) must be positive"
            )
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_attention_heads ({num_attention_heads})"
            )
        if num_attention_heads % num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({num_attention_heads}) must be divisible by "
                f"num_key_value_heads ({num_key_value_heads})"
            )
        head_dim = hidden_size // num_attention_heads
        if head_dim % 2 != 0:
            raise ValueError(
                f"head_dim (hidden_size // num_attention_heads = {head_dim}) must be even for RoPE"
            )

        self.n_head = num_attention_heads
        self.n_kv_head = num_key_value_heads
        self.head_dim = head_dim
        self.hidden_size = hidden_size
        self.q_size = self.n_head * self.head_dim
        self.kv_size = self.n_kv_head * self.head_dim
        total_qkv_dim = self.q_size + 2 * self.kv_size
        self.rope = RoPE(self.head_dim, max_len)
        # With 8/3, the three SwiGLU matrices hold about the same 8 * h^2 weights
        # as the two matrices of a 4x GELU MLP.
        ffn_hidden = int(hidden_size * 8 / 3) if use_swiglu else int(hidden_size * 4)
        norm_cls = nn.LayerNorm if use_layernorm else nn.RMSNorm
        # Registration order is kept stable (state-dict order and seeded init).
        self.ln1 = norm_cls(hidden_size)
        self.qkv_proj = nn.Linear(hidden_size, total_qkv_dim, bias=False)
        self.c_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.ln2 = norm_cls(hidden_size)
        if use_swiglu:
            # SwiGLU halves its input, so the up-projection is 2 * ffn_hidden wide.
            self.mlp = nn.Sequential(
                nn.Linear(hidden_size, 2 * ffn_hidden, bias=False),
                SwiGLU(),
                nn.Linear(ffn_hidden, hidden_size, bias=False),
            )
        else:
            self.mlp = nn.Sequential(
                nn.Linear(hidden_size, ffn_hidden, bias=False),
                nn.GELU(),
                nn.Linear(ffn_hidden, hidden_size, bias=False),
            )

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, seq, hidden_size)``; returns the same shape."""
        B, T, _ = x.size()
        qkv = self.qkv_proj(self.ln1(x))
        q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1)
        # (B, T, heads * head_dim) -> (B, heads, T, head_dim), the layout SDPA takes.
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        q = self.rope(q)
        k = self.rope(k)
        # enable_gqa lets SDPA share the n_kv_head key/value heads across query groups.
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, self.q_size)
        x = x + self.c_proj(attn_out)
        x = x + self.mlp(self.ln2(x))
        return x
class HilbertLM(nn.Module):
def __init__(self, vocab_size, hidden_size, num_hidden_layers, num_attention_heads, max_len, num_key_value_heads, use_layernorm=False, use_swiglu=True):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, hidden_size)
self.layers = nn.ModuleList([
TransformerBlock(hidden_size, num_attention_heads, max_len, num_key_value_heads, use_layernorm, use_swiglu)
for _ in range(num_hidden_layers)
])
self.final_norm = nn.LayerNorm(hidden_size) if use_layernorm else nn.RMSNorm(hidden_size)
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
self._init_weights()
def _init_weights(self):
nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)):
nn.init.ones_(module.weight)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, x):
x = self.token_embedding(x)
for layer in self.layers:
x = layer(x)
x = self.final_norm(x)
logits = self.lm_head(x)
return logits
class HilbertLMForCausalLM(PreTrainedModel, GenerationMixin):
    """``transformers`` causal-LM wrapper around :class:`HilbertLM`.

    Provides the embedding / LM-head accessors and the optional weight tying.
    Computes the shifted next-token cross-entropy loss when ``labels`` are given.

    Generation keeps no KV cache: ``forward`` returns no past key/values, and
    ``prepare_inputs_for_generation`` passes the full ``input_ids`` at every step.

    NOTE(review): ``attention_mask`` is accepted but ignored. Attention is purely
    causal, so left-padded batches will attend to their padding tokens.
    """

    config_class = HilbertLMConfig
    # Presumably allows checkpoints saved with tied embeddings (no separate
    # LM-head tensor) to load without a missing-key error.
    _keys_to_ignore_on_load_missing = ["model.lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = HilbertLM(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            max_len=config.block_size,
            num_key_value_heads=config.num_key_value_heads,
            use_layernorm=config.use_layernorm,
            use_swiglu=config.use_swiglu,
        )
        if config.tie_word_embeddings:
            self.all_tied_weights_keys = {"model.token_embedding.weight": "model.lm_head.weight"}
        else:
            self.all_tied_weights_keys = {}

    def tie_weights(self, missing_keys=None, recompute_mapping=True):
        """Share the embedding matrix with the LM head when ``config.tie_word_embeddings`` is set.

        ``missing_keys`` and ``recompute_mapping`` are unused here. Presumably they
        exist to match the signature ``transformers`` calls this with.
        """
        if self.config.tie_word_embeddings:
            self.model.lm_head.weight = self.model.token_embedding.weight

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.model.token_embedding

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.model.token_embedding = value

    def get_output_embeddings(self):
        """Return the LM head (a bias-free ``nn.Linear``)."""
        return self.model.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head module."""
        self.model.lm_head = new_embeddings

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the model and optionally compute the causal language-modeling loss.

        Args:
            input_ids: token ids of shape ``(batch, seq)``.
            attention_mask: ignored (see the class note).
            labels: optional target ids aligned with ``input_ids``. ``labels[..., t]``
                is scored against the logits at position ``t - 1``. Entries equal
                to -100 are skipped (``F.cross_entropy``'s default ``ignore_index``).
            **kwargs: any other keyword arguments; ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` in the model's dtype. ``loss``
            is a scalar computed in at least float32, or ``None`` without labels.
        """
        logits = self.model(input_ids)
        loss = None
        if labels is not None:
            # Half-precision cross-entropy over a large vocabulary loses accuracy,
            # so upcast fp16/bf16 to float32 while keeping float64 as-is.
            loss_dtype = torch.promote_types(logits.dtype, torch.float32)
            shift_logits = logits[..., :-1, :].to(loss_dtype)
            # Labels may live on another device when the model is sharded.
            shift_labels = labels[..., 1:].to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Feed the whole ``input_ids`` at every generation step (no KV cache is kept)."""
        return {"input_ids": input_ids}