| """HebrewGPT model implementation compatible with HuggingFace AutoModel.""" |
|
|
| import math |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
| from .configuration_hebrewgpt import HebrewGPTConfig |
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization: rescale by RMS, no mean-centering.

    Applies a learned per-channel gain (initialized to 1) after normalizing.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps  # numerical floor inside the rsqrt
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Compute the normalization in float32 for stability, then cast back
        # to the input dtype before applying the learned scale.
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x32 * inv_rms).type_as(x) * self.weight
|
|
|
|
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """Build RoPE rotation tables.

    Returns (cos, sin), each of shape (seq_len, dim // 2): entry [t, j] is the
    cosine/sine of position t times the j-th inverse frequency.
    """
    exponents = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(seq_len, dtype=torch.float32)
    # Outer product via broadcasting: (seq_len, 1) * (1, dim // 2).
    angles = positions[:, None] * inv_freq[None, :]
    return torch.cos(angles), torch.sin(angles)
|
|
|
|
def apply_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
    """Apply RoPE with interleaved pattern: x[..., ::2], x[..., 1::2].

    Treats consecutive channel pairs (even, odd) as 2-D points and rotates
    each by a position-dependent angle. Expects x shaped
    (batch, seq, heads, head_dim) and cos/sin tables shaped
    (seq, head_dim // 2) — TODO confirm against callers.
    """
    even = x[..., 0::2]
    odd = x[..., 1::2]

    # Insert batch and head axes so the (seq, head_dim/2) tables broadcast.
    cos = freqs_cos[None, :, None, :]
    sin = freqs_sin[None, :, None, :]

    rot_even = even * cos - odd * sin
    rot_odd = even * sin + odd * cos

    # Re-interleave the rotated pairs back into the original channel layout.
    return torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)
|
|
|
|
class HebrewGPTAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    Uses a single fused QKV projection (no biases) and torch's
    scaled_dot_product_attention. RoPE cos/sin tables are precomputed up to
    ``max_position_embeddings`` and stored as non-persistent buffers so they
    move with the module's device but never land in checkpoints.
    """

    def __init__(self, config: HebrewGPTConfig):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size

        # Fused projection for Q, K and V in one matmul.
        self.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

        cos_table, sin_table = precompute_freqs_cis(
            config.head_dim, config.max_position_embeddings, config.rope_theta
        )
        self.register_buffer("freqs_cos", cos_table, persistent=False)
        self.register_buffer("freqs_sin", sin_table, persistent=False)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        batch, seq_len, width = x.shape

        # Project once, then split into the three attention operands.
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        per_head = (batch, seq_len, self.n_heads, self.head_dim)
        q, k, v = (t.view(per_head) for t in (q, k, v))

        # Rotary embeddings rotate queries and keys only; values are untouched.
        cos = self.freqs_cos[:seq_len]
        sin = self.freqs_sin[:seq_len]
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        # SDPA expects (batch, heads, seq, head_dim).
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

        # With no explicit mask, use SDPA's fused causal path; otherwise the
        # caller is expected to supply a combined causal+padding mask.
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask, is_causal=(attention_mask is None)
        )

        y = y.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.proj(y)
|
|
|
|
class HebrewGPTMLP(nn.Module):
    """SwiGLU feed-forward block: down(silu(gate(x)) * up(x)), no biases."""

    def __init__(self, config: HebrewGPTConfig):
        super().__init__()
        self.gate = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x):
        # Gated activation: the silu-activated gate modulates the up projection.
        gated = F.silu(self.gate(x))
        return self.down(gated * self.up(x))
|
|
|
|
class HebrewGPTBlock(nn.Module):
    """Pre-norm transformer layer: attention then MLP, each with a residual."""

    def __init__(self, config: HebrewGPTConfig):
        super().__init__()
        self.ln1 = RMSNorm(config.hidden_size)
        self.attn = HebrewGPTAttention(config)
        self.ln2 = RMSNorm(config.hidden_size)
        self.mlp = HebrewGPTMLP(config)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        # Normalize before each sub-layer (pre-norm); add the residual after.
        h = x + self.attn(self.ln1(x), attention_mask)
        return h + self.mlp(self.ln2(h))
|
|
|
|
class HebrewGPTPreTrainedModel(PreTrainedModel):
    """Base class wiring HebrewGPT into the HuggingFace PreTrainedModel API."""

    config_class = HebrewGPTConfig
    base_model_prefix = ""
    supports_gradient_checkpointing = True
    _no_split_modules = ["HebrewGPTBlock"]
    # The RoPE tables are registered as non-persistent buffers, so they never
    # appear in checkpoints; suppress the missing-key warnings for them.
    _keys_to_ignore_on_load_missing = [r"blocks\.\d+\.attn\.freqs_cos", r"blocks\.\d+\.attn\.freqs_sin"]

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights as N(0, 0.02); zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)
|
|
|
|
class HebrewGPTForCausalLM(HebrewGPTPreTrainedModel):
    """Decoder-only causal LM with tied input/output embeddings and no KV cache.

    ``past_key_values``, ``use_cache``, ``output_attentions`` and
    ``output_hidden_states`` are accepted for API compatibility but ignored:
    every forward pass recomputes attention over the full sequence.
    """

    # NOTE(review): most transformers releases expect ``_tied_weights_keys``
    # to be a list of output-weight names (e.g. ["head.weight"]); the dict
    # form is only understood by newer versions — confirm against the pinned
    # transformers version before relying on load-time tying.
    _tied_weights_keys = {"head.weight": "tok_emb.weight"}
    _keys_to_ignore_on_load_missing = ["head.weight"]

    def __init__(self, config: HebrewGPTConfig):
        super().__init__(config)
        self.tok_emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList([HebrewGPTBlock(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = RMSNorm(config.hidden_size)
        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying: the LM head shares the token-embedding matrix.
        self.head.weight = self.tok_emb.weight

        self.post_init()

    def get_input_embeddings(self):
        return self.tok_emb

    def set_input_embeddings(self, value):
        self.tok_emb = value
        # Re-tie the head after the embedding table is swapped.
        self.head.weight = self.tok_emb.weight

    def get_output_embeddings(self):
        return self.head

    def set_output_embeddings(self, new_embeddings):
        self.head = new_embeddings

    def _build_attention_mask(self, attention_mask, seq_len, device):
        """Combine causal and padding masks into a (B, 1, T, T) bool mask.

        Returns None when no padding mask is given, which lets attention use
        SDPA's fused causal path. The diagonal is always kept visible: a
        fully-padded query row would otherwise mask out every key, making the
        softmax produce NaN that poisons logits (and the loss). Real,
        unpadded positions are unaffected by this.
        """
        if attention_mask is None:
            return None
        causal = torch.tril(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool))
        pad = attention_mask[:, None, None, :].bool()
        mask = causal[None, None, :, :] & pad
        diag = torch.eye(seq_len, device=device, dtype=torch.bool)
        return mask | diag[None, None, :, :]

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """Run the decoder stack and optionally compute the LM loss.

        Args:
            input_ids: (B, T) token ids; ignored when ``inputs_embeds`` is given.
            attention_mask: (B, T) with 1 for real tokens, 0 for padding.
            labels: (B, T) targets aligned with the inputs; positions set to
                -100 are ignored. The one-token shift is applied internally.
            inputs_embeds: (B, T, hidden_size) precomputed embeddings.

        Returns:
            ``CausalLMOutputWithPast`` (or a plain tuple when
            ``return_dict=False``) with ``logits`` of shape (B, T, vocab_size)
            and ``loss`` when labels were provided.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None and inputs_embeds is None:
            raise ValueError("Either input_ids or inputs_embeds must be provided.")

        x = self.tok_emb(input_ids) if inputs_embeds is None else inputs_embeds

        # Size the causal mask from the hidden states so it stays correct
        # even when inputs_embeds is the source of the sequence length.
        attn_mask = self._build_attention_mask(attention_mask, x.size(1), x.device)

        for block in self.blocks:
            x = block(x, attn_mask)

        x = self.ln_f(x)
        logits = self.head(x)

        loss = None
        if labels is not None:
            # Next-token objective: logits at position t predict label t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        if not return_dict:
            output = (logits,)
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # No KV cache is kept, so generation re-feeds the full sequence
        # (plus its padding mask) at every decoding step.
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
|
|