# HebrewGPT-296M / modeling_hebrewgpt.py
"""HebrewGPT model implementation compatible with HuggingFace AutoModel."""
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_hebrewgpt import HebrewGPTConfig


class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean centering, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize in float32 for numerical stability, then cast back.
        norm = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (x.float() * norm).type_as(x) * self.weight
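

# --- Illustrative check (not part of the original module): RMSNorm rescales each
# vector to unit RMS before the learned gain, i.e. y ~= x / rms(x) * weight.
def _rmsnorm_check():
    x = torch.randn(2, 16)
    y = RMSNorm(16)(x)
    rms = x.pow(2).mean(-1, keepdim=True).sqrt()
    assert torch.allclose(y, x / rms, atol=1e-5)  # weight is initialized to ones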


def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """Precompute RoPE cos/sin tables of shape (seq_len, dim // 2)."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    return torch.cos(freqs), torch.sin(freqs)


def apply_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
    """Apply RoPE with an interleaved pattern: x[..., ::2], x[..., 1::2]."""
    x_r = x[..., ::2]
    x_i = x[..., 1::2]
    # Reshape freqs for broadcasting: (seq_len, head_dim//2) -> (1, seq_len, 1, head_dim//2)
    cos = freqs_cos.unsqueeze(0).unsqueeze(2)
    sin = freqs_sin.unsqueeze(0).unsqueeze(2)
    # Rotate each (even, odd) coordinate pair by the position-dependent angle.
    o_r = x_r * cos - x_i * sin
    o_i = x_r * sin + x_i * cos
    # Interleave back into the original layout.
    out = torch.stack((o_r, o_i), dim=-1).flatten(-2)
    return out
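

# --- Illustrative sanity check (not part of the original module). RoPE rotates
# (even, odd) coordinate pairs by a position-dependent angle, so it must preserve
# both the tensor shape and every vector's norm. Sizes below are hypothetical.
def _rope_sanity_check():
    B, T, H, D = 2, 8, 4, 64  # (batch, seq_len, n_heads, head_dim)
    x = torch.randn(B, T, H, D)
    cos, sin = precompute_freqs_cis(D, T)
    out = apply_rotary_emb(x, cos, sin)
    assert out.shape == x.shape
    assert torch.allclose(out.norm(dim=-1), x.norm(dim=-1), atol=1e-4)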


class HebrewGPTAttention(nn.Module):
    def __init__(self, config: HebrewGPTConfig):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size
        self.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        # RoPE tables are recomputed from the config and kept as non-persistent
        # buffers, so they are never stored in (or loaded from) checkpoints.
        freqs_cos, freqs_sin = precompute_freqs_cis(
            config.head_dim, config.max_position_embeddings, config.rope_theta
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.n_heads, self.head_dim)
        k = k.view(B, T, self.n_heads, self.head_dim)
        v = v.view(B, T, self.n_heads, self.head_dim)
        # Apply RoPE while heads are still in the (B, T, n_heads, head_dim) layout.
        q = apply_rotary_emb(q, self.freqs_cos[:T], self.freqs_sin[:T])
        k = apply_rotary_emb(k, self.freqs_cos[:T], self.freqs_sin[:T])
        # Transpose for attention: (B, n_heads, T, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Scaled dot-product attention; a caller-provided mask already encodes
        # causality, so is_causal is only used when no mask is given.
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask, is_causal=(attention_mask is None)
        )
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y)
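

# --- Illustrative check (not part of the original module): with no padding, a
# boolean causal mask (True means "may attend") must reproduce SDPA's built-in
# is_causal=True path, which is what HebrewGPTForCausalLM.forward relies on.
def _sdpa_mask_check():
    B, H, T, D = 2, 4, 8, 16
    q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
    out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
    out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    assert torch.allclose(out_mask, out_flag, atol=1e-5)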


class HebrewGPTMLP(nn.Module):
    """SwiGLU feed-forward block: down(silu(gate(x)) * up(x))."""

    def __init__(self, config: HebrewGPTConfig):
        super().__init__()
        self.gate = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))


class HebrewGPTBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual connection."""

    def __init__(self, config: HebrewGPTConfig):
        super().__init__()
        self.ln1 = RMSNorm(config.hidden_size)
        self.attn = HebrewGPTAttention(config)
        self.ln2 = RMSNorm(config.hidden_size)
        self.mlp = HebrewGPTMLP(config)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        x = x + self.attn(self.ln1(x), attention_mask)
        x = x + self.mlp(self.ln2(x))
        return x


class HebrewGPTPreTrainedModel(PreTrainedModel):
    config_class = HebrewGPTConfig
    base_model_prefix = ""
    supports_gradient_checkpointing = True
    _no_split_modules = ["HebrewGPTBlock"]
    _keys_to_ignore_on_load_missing = [
        r"blocks\.\d+\.attn\.freqs_cos",
        r"blocks\.\d+\.attn\.freqs_sin",
    ]

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)


class HebrewGPTForCausalLM(HebrewGPTPreTrainedModel):
    _tied_weights_keys = {"head.weight": "tok_emb.weight"}
    # Class attributes override (rather than extend) the parent's list, so the
    # RoPE-buffer patterns are repeated here alongside the tied head weight.
    _keys_to_ignore_on_load_missing = [
        "head.weight",
        r"blocks\.\d+\.attn\.freqs_cos",
        r"blocks\.\d+\.attn\.freqs_sin",
    ]

    def __init__(self, config: HebrewGPTConfig):
        super().__init__(config)
        self.tok_emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList(
            [HebrewGPTBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.ln_f = RMSNorm(config.hidden_size)
        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Tie the LM head to the input embeddings.
        self.head.weight = self.tok_emb.weight
        self.post_init()

    def get_input_embeddings(self):
        return self.tok_emb

    def set_input_embeddings(self, value):
        self.tok_emb = value
        self.head.weight = self.tok_emb.weight

    def get_output_embeddings(self):
        return self.head

    def set_output_embeddings(self, new_embeddings):
        self.head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        # use_cache / output_attentions / output_hidden_states are accepted for
        # API compatibility but are not implemented by this model.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if inputs_embeds is None:
            x = self.tok_emb(input_ids)
        else:
            x = inputs_embeds
        # Convert attention_mask to the boolean format SDPA expects, if provided.
        attn_mask = None
        if attention_mask is not None:
            # attention_mask: (B, T) with 1s for real tokens, 0s for padding.
            B, T = attention_mask.shape
            # Combine causal and padding constraints; True means "may attend".
            causal = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))
            pad_mask = attention_mask[:, None, None, :].bool()  # (B, 1, 1, T)
            attn_mask = causal[None, None, :, :] & pad_mask  # (B, 1, T, T)
        for block in self.blocks:
            x = block(x, attn_mask)
        x = self.ln_f(x)
        logits = self.head(x)
        loss = None
        if labels is not None:
            # Standard next-token objective: predict token t+1 from positions <= t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        if not return_dict:
            output = (logits,)
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # No KV cache is implemented, so the full sequence is re-encoded each step.
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
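

# --- Example usage (a minimal smoke test, not part of the original file). The
# hyperparameters below are illustrative, not the released 296M configuration, and
# assume HebrewGPTConfig accepts these field names as keyword arguments (they are
# exactly the attributes this module reads from the config). Because of the
# relative import above, run this in a package context, e.g.
# `python -m <package>.modeling_hebrewgpt`, not as a standalone script.
if __name__ == "__main__":
    _rmsnorm_check()
    _rope_sanity_check()
    _sdpa_mask_check()
    config = HebrewGPTConfig(
        vocab_size=256,
        hidden_size=64,  # must equal num_attention_heads * head_dim
        num_hidden_layers=2,
        num_attention_heads=4,
        head_dim=16,
        intermediate_size=128,
        max_position_embeddings=32,
        rope_theta=10000.0,
    )
    model = HebrewGPTForCausalLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    attention_mask[1, -4:] = 0  # simulate right-padding on the second sequence
    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    print("loss:", out.loss.item(), "logits shape:", tuple(out.logits.shape))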