| """ |
| PIT (Point-In-Time) GPT model — self-contained for trust_remote_code=True loading. |
| |
| Architecture: decoder-only Transformer with RoPE, RMSNorm on Q/K, squared-ReLU |
| MLP, and weight-tied input/output embeddings. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
| from .configuration_pit import PITConfig |
|
|
|
|
| |
| |
| |
|
|
class Rotary(nn.Module):
    """Build and cache rotary-position-embedding (RoPE) cos/sin tables.

    The tables are rebuilt whenever the requested sequence length or the
    input's device changes; otherwise the cached tensors are reused.

    Args:
        dim: Size of the dimension the tables are built for (the attention
            head dimension when used by ``CausalSelfAttention``). The tables
            have ``len(range(0, dim, 2))`` frequency columns.
        base: RoPE frequency base.
        scaling_factor: Multiplier applied to ``base`` before the inverse
            frequencies are computed.
    """

    def __init__(self, dim: int, base: int = 10000, scaling_factor: float = 1.0):
        super().__init__()
        self.dim = dim
        self.base = base * scaling_factor
        # Plain attributes, not registered buffers: they are not part of the
        # state dict and are NOT moved by ``Module.to()``. ``forward`` therefore
        # also compares devices before reusing them.
        self.seq_len_cached: int | None = None
        self.cos_cached: torch.Tensor | None = None
        self.sin_cached: torch.Tensor | None = None

    def forward(self, x: torch.Tensor):
        """Return ``(cos, sin)`` tables matching the sequence length of ``x``.

        Only ``x.shape[1]`` (taken as the sequence length) and ``x.device`` are
        read. ``CausalSelfAttention`` passes a (batch, seq, heads, head_dim)
        query tensor.

        Returns:
            Two bfloat16 tensors of shape ``(1, seq_len, 1, n_freq)``, ready to
            broadcast against (batch, seq, heads, n_freq) halves.
        """
        seq_len = x.shape[1]
        stale = (
            self.cos_cached is None
            or seq_len != self.seq_len_cached
            # Cached tables do not follow Module.to(); rebuild on a device change
            # instead of returning tensors that live on the old device.
            or self.cos_cached.device != x.device
        )
        if stale:
            inv_freq = 1.0 / (self.base ** (
                torch.arange(0, self.dim, 2, device=x.device, dtype=torch.float32) / self.dim
            ))
            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            freqs = torch.outer(t, inv_freq)
            # Angles are computed in float32 but always stored as bfloat16,
            # whatever x.dtype is. This is kept unchanged on purpose; it
            # presumably matches training precision (TODO confirm before changing).
            self.cos_cached = freqs.cos().bfloat16()
            self.sin_cached = freqs.sin().bfloat16()
            self.seq_len_cached = seq_len
        return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
|
|
|
|
| def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: |
| d = x.shape[3] // 2 |
| x1, x2 = x[..., :d], x[..., d:] |
| return torch.cat([x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos], dim=3).type_as(x) |
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with RoPE and parameter-free QK RMS-norm.

    Queries and keys are RMS-normalised over the head dimension (no learned
    scale) and then rotated with rotary embeddings; values are used as-is.
    Attention uses ``F.scaled_dot_product_attention(..., is_causal=True)``.
    There is no padding mask and no KV cache.

    Raises:
        ValueError: If ``config.n_embd`` is not divisible by ``config.n_head``.
    """

    def __init__(self, config: PITConfig):
        super().__init__()
        # Fail fast here instead of with an opaque ``view`` error in forward().
        if config.n_embd % config.n_head != 0:
            raise ValueError(
                f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"
            )
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.c_q = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.c_k = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.c_v = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        # Zeroed at construction (before any later weight init or checkpoint
        # load), so a fresh attention branch adds nothing to the residual.
        self.c_proj.weight.data.zero_()
        self.rotary = Rotary(self.head_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` of shape (B, T, C) and return a tensor of the same shape."""
        B, T, C = x.size()
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_head, self.head_dim)
        cos, sin = self.rotary(q)
        q = _apply_rotary_emb(F.rms_norm(q, (q.size(-1),)), cos, sin)
        k = _apply_rotary_emb(F.rms_norm(k, (k.size(-1),)), cos, sin)
        # SDPA wants (B, n_head, T, head_dim); transpose in, then back out.
        y = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
        )
        return self.c_proj(y.transpose(1, 2).contiguous().view_as(x))
|
|
|
|
class MLP(nn.Module):
    """Position-wise feed-forward layer with a squared-ReLU activation.

    Projects ``n_embd`` up to ``4 * n_embd``, applies ``relu(h) ** 2``, then
    projects back down to ``n_embd``. Neither projection has a bias.
    """

    def __init__(self, config: PITConfig):
        super().__init__()
        width = config.n_embd
        hidden_width = 4 * width
        self.c_fc = nn.Linear(width, hidden_width, bias=False)
        self.c_proj = nn.Linear(hidden_width, width, bias=False)
        # Output projection is zeroed at construction (before any later weight
        # init or checkpoint load), so a fresh MLP branch outputs zeros.
        self.c_proj.weight.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.relu(self.c_fc(x))
        return self.c_proj(hidden.square())
|
|
|
|
class Block(nn.Module):
    """Pre-norm Transformer layer: causal attention, then the MLP.

    Each sub-layer receives a parameter-free RMS-normalised copy of the
    residual stream (``F.rms_norm`` without a weight), and its output is added
    back onto the stream.
    """

    def __init__(self, config: PITConfig):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention runs first, then the MLP, each on its own residual branch.
        for sublayer in (self.attn, self.mlp):
            normed = F.rms_norm(x, (x.size(-1),))
            x = x + sublayer(normed)
        return x
|
|
|
|
| |
| |
| |
|
|
class PITForCausalLM(PreTrainedModel):
    """
    Point-In-Time GPT wrapped as a HuggingFace CausalLM.

    Supports AutoModelForCausalLM, generate(), and pipeline("text-generation").

    Limitations of this implementation:
    - No KV cache: ``prepare_inputs_for_generation`` forwards only
      ``input_ids``, so every generation step re-runs the full prefix.
    - ``attention_mask`` is accepted for API compatibility but ignored, so
      padding tokens are not masked out. Padded batches may not match their
      unpadded results.

    Loading
    -------
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
    >>> tokenizer = AutoTokenizer.from_pretrained("Diamegs/PIT-4B-FT-202012")
    >>> model = AutoModelForCausalLM.from_pretrained(
    ...     "Diamegs/PIT-4B-FT-202012",
    ...     trust_remote_code=True,
    ...     torch_dtype=torch.bfloat16,
    ...     device_map="auto",
    ... )
    """

    config_class = PITConfig
    _no_split_modules = ["Block"]
    _supports_cache_class = False

    _tied_weights_keys = ["lm_head.weight", "transformer.wte.weight"]

    def __init__(self, config: PITConfig):
        super().__init__(config)
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the input embedding and the output projection share one
        # Parameter object.
        self.transformer["wte"].weight = self.lm_head.weight
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.transformer["wte"]

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.transformer["wte"] = value

    def get_output_embeddings(self) -> nn.Linear:
        return self.lm_head

    def set_output_embeddings(self, value: nn.Linear) -> None:
        self.lm_head = value

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the decoder and optionally compute the next-token loss.

        Args:
            input_ids: Token ids, embedded by ``wte``. The blocks treat dim 1 as
                the sequence (assumed (batch, seq); TODO confirm).
            attention_mask: Ignored (see the class docstring).
            labels: Targets aligned with ``input_ids``, following the Hugging Face
                convention (e.g. ``labels = input_ids``). The one-position shift
                is applied here. Positions equal to -100 are ignored.
            **kwargs: Extra arguments passed by HF callers; ignored.

        Returns:
            ``CausalLMOutputWithPast`` with float32 ``logits`` and ``loss``,
            which is None when no labels are given.

        Raises:
            ValueError: If ``input_ids`` is None.
        """
        if input_ids is None:
            raise ValueError("PITForCausalLM.forward requires input_ids")
        x = self.transformer["wte"](input_ids)
        for block in self.transformer["h"]:
            x = block(x)
        x = F.rms_norm(x, (x.size(-1),))
        # Upcast so the cross-entropy / sampling math runs in float32 even when
        # the weights are bfloat16.
        logits = self.lm_head(x).float()

        loss = None
        if labels is not None:
            # Causal LM: position t predicts token t+1. Drop the last logit and
            # the first label. Without this shift the loss would train the model
            # to copy its current input token.
            shift_logits = logits[..., :-1, :]
            # With device_map="auto", labels may sit on a different device than
            # the final logits.
            shift_labels = labels[..., 1:].to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(
        self, input_ids: torch.Tensor, **kwargs
    ) -> dict:
        """Feed the full ``input_ids`` each step; no cache or mask is forwarded."""
        return {"input_ids": input_ids}
|
|