"""
PIT (Point-In-Time) GPT model — self-contained for trust_remote_code=True loading.

Architecture: decoder-only Transformer with RoPE, RMSNorm on Q/K, squared-ReLU
MLP, and weight-tied input/output embeddings.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_pit import PITConfig


# ---------------------------------------------------------------------------
# Architecture (mirrors models/GPT.py exactly)
# ---------------------------------------------------------------------------


class Rotary(nn.Module):
    """Generate rotary-position-embedding (RoPE) cos/sin tables.

    The tables cover positions ``0 .. seq_len - 1``, where ``seq_len`` is
    ``x.shape[1]``. They are cached and only recomputed when the sequence
    length or the input's device changes.

    Args:
        dim: Size of the last (per-head) dimension that will be rotated.
        base: RoPE frequency base.
        scaling_factor: Multiplier applied to ``base``.
    """

    def __init__(self, dim: int, base: int = 10000, scaling_factor: float = 1.0):
        super().__init__()
        self.dim = dim
        self.base = base * scaling_factor
        self.seq_len_cached: int | None = None
        self.cos_cached: torch.Tensor | None = None
        self.sin_cached: torch.Tensor | None = None

    def forward(self, x: torch.Tensor):
        """Return ``(cos, sin)`` tables broadcastable against ``x``.

        ``x`` is expected to be ``(batch, seq, heads, head_dim)``, which is how
        CausalSelfAttention calls it. Only ``x.shape[1]`` and ``x.device`` are
        read. Each returned tensor has shape ``(1, seq, 1, ceil(dim / 2))``
        and dtype bfloat16.
        """
        seq_len = x.shape[1]
        # The tables are plain attributes, not registered buffers, so
        # Module.to() does not move them. Recompute on a device change too;
        # otherwise a stale table on the old device is returned.
        if (
            seq_len != self.seq_len_cached
            or self.cos_cached is None
            or self.cos_cached.device != x.device
        ):
            self.seq_len_cached = seq_len
            # Compute inv_freq on-the-fly on the correct device. It is never
            # stored as a buffer, so device_map="auto" / meta-device loading
            # can't break it.
            inv_freq = 1.0 / (
                self.base
                ** (
                    torch.arange(0, self.dim, 2, device=x.device, dtype=torch.float32)
                    / self.dim
                )
            )
            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            freqs = torch.outer(t, inv_freq)
            # Stored in bfloat16 to match the training-time implementation.
            self.cos_cached = freqs.cos().bfloat16()
            self.sin_cached = freqs.sin().bfloat16()
        return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]


def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate ``x`` using the half-split RoPE layout.

    The last dimension is split into two halves ``(x1, x2)`` and mapped to
    ``(x1*cos + x2*sin, -x1*sin + x2*cos)``. The result is cast back to
    ``x``'s dtype. ``x`` is indexed as 4-D (the concat uses ``dim=3``).
    """
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return torch.cat([x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos], dim=3).type_as(x)


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with weightless RMSNorm on Q/K and RoPE.

    Assumes ``config.n_embd`` is divisible by ``config.n_head``; the
    ``.view`` calls in ``forward`` fail otherwise.
    """

    def __init__(self, config: PITConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.c_q = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.c_k = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.c_v = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        # Output projection starts at zero, so the residual branch is initially
        # a no-op.
        self.c_proj.weight.data.zero_()
        self.rotary = Rotary(self.head_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` of shape ``(B, T, C)``; returns the same shape."""
        B, T, C = x.size()
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_head, self.head_dim)
        cos, sin = self.rotary(q)
        # QK-norm (no learnable scale) before applying rotary embeddings.
        q = _apply_rotary_emb(F.rms_norm(q, (q.size(-1),)), cos, sin)
        k = _apply_rotary_emb(F.rms_norm(k, (k.size(-1),)), cos, sin)
        # SDPA expects (B, heads, T, head_dim). No padding mask is applied;
        # only the causal mask is used.
        y = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
        )
        return self.c_proj(y.transpose(1, 2).contiguous().view_as(x))


class MLP(nn.Module):
    """Feed-forward block: n_embd -> 4*n_embd -> n_embd with squared-ReLU."""

    def __init__(self, config: PITConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
        # Zero-initialised output projection, as in CausalSelfAttention.
        self.c_proj.weight.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.c_proj(F.relu(self.c_fc(x)).square())


class Block(nn.Module):
    """Pre-norm Transformer block: x + attn(norm(x)), then x + mlp(norm(x))."""

    def __init__(self, config: PITConfig):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMSNorm here has no learnable weight (F.rms_norm without `weight`).
        x = x + self.attn(F.rms_norm(x, (x.size(-1),)))
        x = x + self.mlp(F.rms_norm(x, (x.size(-1),)))
        return x


# ---------------------------------------------------------------------------
# HuggingFace PreTrainedModel wrapper
# ---------------------------------------------------------------------------


class PITForCausalLM(PreTrainedModel):
    """
    Point-In-Time GPT wrapped as a HuggingFace CausalLM.

    Supports AutoModelForCausalLM, generate(), and pipeline("text-generation").

    Limitations
    -----------
    * ``attention_mask`` is accepted but ignored: only the causal mask is
      applied, so padded batches (e.g. left-padded batched generation) attend
      to pad tokens.
    * There is no KV cache. Every generation step re-runs the full prefix.

    Loading
    -------
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
    >>> tokenizer = AutoTokenizer.from_pretrained("Diamegs/PIT-4B-FT-202012")
    >>> model = AutoModelForCausalLM.from_pretrained(
    ...     "Diamegs/PIT-4B-FT-202012",
    ...     trust_remote_code=True,
    ...     torch_dtype=torch.bfloat16,
    ...     device_map="auto",
    ... )
    """

    config_class = PITConfig
    _no_split_modules = ["Block"]
    _supports_cache_class = False

    # Weight tying: lm_head and transformer.wte share parameters.
    _tied_weights_keys = ["lm_head.weight", "transformer.wte.weight"]

    def __init__(self, config: PITConfig):
        super().__init__(config)
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Tie weights (re-tied after load_state_dict via tie_weights())
        self.transformer["wte"].weight = self.lm_head.weight

        self.post_init()

    # -- weight tying hooks required by PreTrainedModel ----------------------

    def get_input_embeddings(self) -> nn.Embedding:
        return self.transformer["wte"]

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.transformer["wte"] = value

    def get_output_embeddings(self) -> nn.Linear:
        return self.lm_head

    def set_output_embeddings(self, value: nn.Linear) -> None:
        self.lm_head = value

    # -- forward -------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the model and optionally compute the next-token loss.

        Args:
            input_ids: Token ids of shape ``(batch, seq)``. Required.
            attention_mask: Accepted for API compatibility and ignored.
            labels: Optional ``(batch, seq)`` targets aligned with
                ``input_ids``, following the HuggingFace causal-LM convention
                (``labels = input_ids`` is valid). They are shifted inside the
                model: position ``t`` is scored against ``labels[:, t + 1]``.
                Positions equal to ``-100`` are ignored.
            **kwargs: Extra generation/trainer kwargs; ignored.

        Returns:
            ``CausalLMOutputWithPast`` with float32 ``logits`` of shape
            ``(batch, seq, vocab_size)`` and ``loss`` (``None`` if no labels).

        Raises:
            ValueError: if ``input_ids`` is not provided.
        """
        if input_ids is None:
            raise ValueError("PITForCausalLM.forward requires `input_ids`.")

        x = self.transformer["wte"](input_ids)
        for block in self.transformer["h"]:
            x = block(x)
        x = F.rms_norm(x, (x.size(-1),))
        # Upcast so the softmax / loss are computed in float32.
        logits = self.lm_head(x).float()

        loss = None
        if labels is not None:
            # Shift so logits at position t predict the token at t + 1.
            shift_logits = logits[:, :-1, :]
            # Under device_map="auto", lm_head may not sit on the labels'
            # device.
            shift_labels = labels[:, 1:].to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(
        self, input_ids: torch.Tensor, **kwargs
    ) -> dict:
        # No KV cache: forward the whole prefix each step. Other kwargs
        # (including attention_mask) are dropped because forward ignores them.
        return {"input_ids": input_ids}