# PIT-4B-202012 / modeling_pit.py
# Diamegs's picture
# Add 4B base snapshot (2020-12)
# fccb6f3 verified
"""
PIT (Point-In-Time) GPT model — self-contained for trust_remote_code=True loading.
Architecture: decoder-only Transformer with RoPE, RMSNorm on Q/K, squared-ReLU
MLP, and weight-tied input/output embeddings.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_pit import PITConfig
# ---------------------------------------------------------------------------
# Architecture (mirrors models/GPT.py exactly)
# ---------------------------------------------------------------------------
class Rotary(nn.Module):
    """Lazily computed, cached cos/sin tables for rotary position embeddings.

    The tables are rebuilt whenever the sequence length *or the device* of the
    input differs from what is cached, so moving the model to another device
    after a forward pass cannot hand back tensors on the old device.

    Args:
        dim: Feature size being rotated. One frequency is generated per entry
            of ``torch.arange(0, dim, 2)`` (``dim // 2`` for even ``dim``).
        base: Frequency base of the rotary embedding.
        scaling_factor: Multiplier applied to ``base``.
    """

    def __init__(self, dim: int, base: int = 10000, scaling_factor: float = 1.0):
        super().__init__()
        self.dim = dim
        self.base = base * scaling_factor
        self.seq_len_cached: int | None = None
        self.cos_cached: torch.Tensor | None = None
        self.sin_cached: torch.Tensor | None = None

    def forward(self, x: torch.Tensor):
        """Return ``(cos, sin)`` tables for the positions of ``x``.

        Args:
            x: Tensor whose dimension 1 is the sequence length. Only its shape
                and device are read.

        Returns:
            Two bfloat16 tensors of shape ``(1, seq_len, 1, n_freq)`` located
            on ``x.device``.
        """
        seq_len = x.shape[1]
        cache_is_stale = (
            self.cos_cached is None
            or self.sin_cached is None
            or seq_len != self.seq_len_cached
            # Same length but a different device (e.g. after model.to(...)):
            # the cached tables would live on the wrong device.
            or self.cos_cached.device != x.device
        )
        if cache_is_stale:
            self.seq_len_cached = seq_len
            # Compute inv_freq on-the-fly on the correct device — never stored
            # as a buffer so device_map="auto" / meta-device loading can't break it.
            inv_freq = 1.0 / (self.base ** (
                torch.arange(0, self.dim, 2, device=x.device, dtype=torch.float32) / self.dim
            ))
            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            freqs = torch.outer(t, inv_freq)
            self.cos_cached = freqs.cos().bfloat16()
            self.sin_cached = freqs.sin().bfloat16()
        # Insert broadcast axes for batch (dim 0) and heads (dim 2).
        return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate the two halves of the last axis of a 4-D tensor by ``(cos, sin)``.

    The last axis of ``x`` is split at ``x.shape[3] // 2``; the halves are mixed
    as ``(lo*cos + hi*sin, -lo*sin + hi*cos)`` and concatenated back along
    dimension 3. The result is cast to the dtype of ``x``.
    """
    half = x.shape[3] // 2
    lower = x[..., :half]
    upper = x[..., half:]
    rotated_lower = lower * cos + upper * sin
    rotated_upper = lower * (-sin) + upper * cos
    rotated = torch.cat([rotated_lower, rotated_upper], dim=3)
    return rotated.type_as(x)
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with RMS-normalised, rotary-embedded Q/K.

    Queries, keys and values come from three separate bias-free projections.
    The output projection weight is zeroed at construction.
    """

    def __init__(self, config: PITConfig):
        super().__init__()
        width, heads = config.n_embd, config.n_head
        self.n_head = heads
        self.n_embd = width
        self.head_dim = width // heads
        self.c_q = nn.Linear(width, width, bias=False)
        self.c_k = nn.Linear(width, width, bias=False)
        self.c_v = nn.Linear(width, width, bias=False)
        self.c_proj = nn.Linear(width, width, bias=False)
        # Zero the output projection so the freshly built layer contributes nothing.
        self.c_proj.weight.data.zero_()
        self.rotary = Rotary(self.head_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend causally over ``x`` of shape ``(batch, seq, n_embd)``; same shape out."""
        batch, seq_len, _ = x.size()
        per_head = (batch, seq_len, self.n_head, self.head_dim)
        queries = self.c_q(x).view(per_head)
        keys = self.c_k(x).view(per_head)
        values = self.c_v(x).view(per_head)

        # Normalise each head's Q/K over its feature axis, then apply RoPE.
        cos, sin = self.rotary(queries)
        queries = _apply_rotary_emb(F.rms_norm(queries, (queries.size(-1),)), cos, sin)
        keys = _apply_rotary_emb(F.rms_norm(keys, (keys.size(-1),)), cos, sin)

        # SDPA expects (batch, heads, seq, head_dim).
        attended = F.scaled_dot_product_attention(
            queries.transpose(1, 2),
            keys.transpose(1, 2),
            values.transpose(1, 2),
            is_causal=True,
        )
        merged = attended.transpose(1, 2).contiguous().view_as(x)
        return self.c_proj(merged)
class MLP(nn.Module):
    """Bias-free feed-forward block: expand 4x, squared ReLU, project back.

    The down-projection weight is zeroed at construction.
    """

    def __init__(self, config: PITConfig):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, width, bias=False)
        self.c_proj.weight.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``c_proj(relu(c_fc(x)) ** 2)``."""
        expanded = self.c_fc(x)
        activated = F.relu(expanded).square()
        return self.c_proj(activated)
class Block(nn.Module):
    """Pre-norm Transformer layer: attention then MLP, each on a residual branch.

    Both branches normalise their input with a weightless RMS norm over the
    last axis.
    """

    def __init__(self, config: PITConfig):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the attention and MLP residual updates."""
        attn_out = self.attn(F.rms_norm(x, (x.size(-1),)))
        x = x + attn_out
        mlp_out = self.mlp(F.rms_norm(x, (x.size(-1),)))
        return x + mlp_out
# ---------------------------------------------------------------------------
# HuggingFace PreTrainedModel wrapper
# ---------------------------------------------------------------------------
class PITForCausalLM(PreTrainedModel):
    """
    Point-In-Time GPT wrapped as a HuggingFace CausalLM.
    Supports AutoModelForCausalLM, generate(), and pipeline("text-generation").

    Loading
    -------
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
    >>> tokenizer = AutoTokenizer.from_pretrained("Diamegs/PIT-4B-FT-202012")
    >>> model = AutoModelForCausalLM.from_pretrained(
    ...     "Diamegs/PIT-4B-FT-202012",
    ...     trust_remote_code=True,
    ...     torch_dtype=torch.bfloat16,
    ...     device_map="auto",
    ... )

    Notes
    -----
    * ``forward`` accepts ``attention_mask`` but never reads it: attention is
      purely causal, so padded positions are NOT masked out.
    * No key/value cache is produced or consumed;
      ``prepare_inputs_for_generation`` passes the full ``input_ids`` each step.
    """

    config_class = PITConfig
    _no_split_modules = ["Block"]
    _supports_cache_class = False
    # Weight tying: lm_head and transformer.wte share parameters.
    _tied_weights_keys = ["lm_head.weight", "transformer.wte.weight"]

    def __init__(self, config: PITConfig):
        super().__init__(config)
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Tie weights (re-tied after load_state_dict via tie_weights())
        self.transformer["wte"].weight = self.lm_head.weight
        self.post_init()

    # -- weight tying hooks required by PreTrainedModel ----------------------
    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding module (``transformer["wte"]``)."""
        return self.transformer["wte"]

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        """Replace the token embedding module."""
        self.transformer["wte"] = value

    def get_output_embeddings(self) -> nn.Linear:
        """Return the output projection (``lm_head``)."""
        return self.lm_head

    def set_output_embeddings(self, value: nn.Linear) -> None:
        """Replace the output projection."""
        self.lm_head = value

    # -- forward -------------------------------------------------------------
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the decoder and optionally compute the next-token loss.

        Args:
            input_ids: Token ids; the blocks expect shape ``(batch, seq)``.
            attention_mask: Accepted for API compatibility; not used.
            labels: Optional targets aligned with ``input_ids`` (HuggingFace
                convention, e.g. ``labels=input_ids``). They are shifted inside
                this method so position ``t`` is scored against token ``t + 1``;
                entries equal to ``-100`` are ignored.
            **kwargs: Extra arguments passed by HuggingFace utilities; ignored.

        Returns:
            ``CausalLMOutputWithPast`` with float32 ``logits`` and, when
            ``labels`` is given, the mean cross-entropy ``loss``.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("PITForCausalLM.forward requires `input_ids`.")

        x = self.transformer["wte"](input_ids)
        for block in self.transformer["h"]:
            x = block(x)
        x = F.rms_norm(x, (x.size(-1),))
        logits = self.lm_head(x).float()

        loss = None
        if labels is not None:
            # Next-token objective: drop the last logit and the first label so
            # that logits[t] is compared with labels[t + 1].
            shift_logits = logits[..., :-1, :]
            # Labels may sit on a different device under device_map="auto".
            shift_labels = labels[..., 1:].to(shift_logits.device)
            # reshape (not view): the shifted slices are non-contiguous.
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(
        self, input_ids: torch.Tensor, **kwargs
    ) -> dict:
        """Build the ``forward`` kwargs for one ``generate()`` step.

        Only ``input_ids`` is forwarded; every other keyword is dropped, so the
        whole sequence is re-encoded at each step.
        """
        return {"input_ids": input_ids}