import math import os import torch import torch.nn as nn import torch.nn.functional as F # ─── RPW: Relative Positional Warp (learned Fourier additive bias) ────────── class RPW(nn.Module): def __init__(self, num_heads: int, num_freqs: int = 16, max_seq_len: int = 2048): super().__init__() self.num_heads = num_heads self.num_freqs = num_freqs freqs = 1.0 / (10000.0 ** (torch.arange(num_freqs, dtype=torch.float32) / num_freqs)) self.register_buffer("_freqs", freqs) self.W_phi = nn.Parameter(torch.zeros(num_freqs * 2, num_heads)) nn.init.normal_(self.W_phi, std=0.02) def _make_bias(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: pos = torch.arange(seq_len, device=device, dtype=torch.float32) delta = pos.view(1, seq_len) - pos.view(seq_len, 1) angles = delta.unsqueeze(-1) * self._freqs.to(device=device) phi = torch.cat([angles.sin(), angles.cos()], dim=-1) bias = phi @ self.W_phi bias = bias.permute(2, 0, 1).unsqueeze(0) causal = torch.triu(torch.full((seq_len, seq_len), float('-inf'), device=device, dtype=dtype), diagonal=1) return (bias + causal.unsqueeze(0).unsqueeze(0)).to(dtype=dtype) # ─── GQA with RPW ────────────────────────────────────────────────────────── class GQAAttention(nn.Module): def __init__(self, hidden: int, num_heads: int, num_kv_heads: int): super().__init__() self.hidden = hidden self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = hidden // num_heads self.num_groups = num_heads // num_kv_heads self.q_proj = nn.Linear(hidden, hidden, bias=False) self.k_proj = nn.Linear(hidden, num_kv_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(hidden, num_kv_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(hidden, hidden, bias=False) self.rpw = RPW(num_heads, num_freqs=16) def forward(self, x): B, T, _ = x.shape q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2) if self.num_groups > 1: k = k.repeat_interleave(self.num_groups, dim=1) v = v.repeat_interleave(self.num_groups, dim=1) rpw_bias = self.rpw._make_bias(T, x.device, q.dtype) attn = F.scaled_dot_product_attention(q, k, v, attn_mask=rpw_bias, is_causal=False) out = attn.transpose(1, 2).contiguous().view(B, T, self.hidden) return self.o_proj(out) # ─── GPP: Gated Principal Projection ────────────────────────────────────── class GPP(nn.Module): def __init__(self, hidden: int, code_dim: int): super().__init__() self.down = nn.Linear(hidden, code_dim, bias=False) self.up = nn.Linear(code_dim, hidden, bias=False) def forward(self, x): return self.up(F.silu(self.down(x))) # ─── VCR: Variance-Controlled Residual ───────────────────────────────────── class VCR(nn.Module): def __init__(self): super().__init__() self.alpha = nn.Parameter(torch.tensor(1.0)) self.beta = nn.Parameter(torch.tensor(2.0)) def forward(self, x, delta): v_exp = delta.float().norm() / x.float().norm().clamp(min=1e-8) scale = torch.sigmoid(self.alpha * v_exp + self.beta) return x + scale * delta # ─── RMSNorm ─────────────────────────────────────────────────────────────── class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(dim)) self.eps = eps def forward(self, x): norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt() return (x.float() * norm).type_as(x) * self.weight # ─── Transformer Block ──────────────────────────────────────────────────── class TransformerBlock(nn.Module): def __init__(self, hidden: int, code_dim: int, num_heads: int, num_kv_heads: int): super().__init__() self.ln1 = RMSNorm(hidden) self.attn = GQAAttention(hidden, num_heads, num_kv_heads) self.vcr_attn = VCR() self.ln2 = RMSNorm(hidden) self.mlp = GPP(hidden, code_dim) self.vcr_mlp = VCR() def forward(self, x): x = self.vcr_attn(x, self.attn(self.ln1(x))) x = self.vcr_mlp(x, self.mlp(self.ln2(x))) return x # ─── TinyModel V3 ────────────────────────────────────────────────────────── class TinyModel(nn.Module): def __init__(self, vocab_size=4096, hidden=128, code_dim=96, num_layers=6, num_heads=8, num_kv_heads=4, max_seq_len=2048, tie_weights=True): super().__init__() self.hidden = hidden self.max_seq_len = max_seq_len self.token_embed = nn.Embedding(vocab_size, hidden) self.blocks = nn.ModuleList([ TransformerBlock(hidden, code_dim, num_heads, num_kv_heads) for _ in range(num_layers) ]) self.ln_f = RMSNorm(hidden) self.lm_head = nn.Linear(hidden, vocab_size, bias=False) if tie_weights: self.lm_head.weight = self.token_embed.weight def reset_weights(self): for m in self.modules(): if isinstance(m, (nn.Linear, nn.Embedding)) and m.weight.dim() >= 2: fan_in = m.weight.shape[1] if m.weight.dim() >= 2 else 1 nn.init.normal_(m.weight, std=1.0 / math.sqrt(fan_in)) elif isinstance(m, nn.Linear) and m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, RMSNorm): nn.init.ones_(m.weight) for m in self.modules(): if isinstance(m, RPW): nn.init.normal_(m.W_phi, std=0.02) def forward(self, input_ids, labels=None): x = self.token_embed(input_ids) for block in self.blocks: x = block(x) x = self.ln_f(x) logits = self.lm_head(x) loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = F.cross_entropy( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100, ) return logits, loss def generate(self, input_ids, max_new_tokens=128, temperature=0.7, top_p=0.9, stream_callback=None): self.eval() for _ in range(max_new_tokens): seq_len = input_ids.size(1) if seq_len > self.max_seq_len: input_ids = input_ids[:, -self.max_seq_len:] with torch.no_grad(): logits, _ = self.forward(input_ids) logits = logits[:, -1, :] if temperature > 0: logits = logits / temperature if top_p < 1.0: sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True) cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) cutoff = cum_probs > top_p cutoff[..., 1:] = cutoff[..., :-1].clone() cutoff[..., 0] = False logits[~cutoff] = float('-inf') probs = F.softmax(logits, dim=-1) if temperature > 0: next_token = torch.multinomial(probs, num_samples=1) else: next_token = probs.argmax(dim=-1, keepdim=True) input_ids = torch.cat([input_ids, next_token], dim=1) if stream_callback: stream_callback(next_token.item()) if next_token.item() == 2: break return input_ids # ─── Restore from V2 checkpoint ──────────────────────────────────────────── def restore_from_v2(v2_ckpt_path: str | None = None, strict: bool = False) -> TinyModel: model = TinyModel(vocab_size=4096, hidden=128, code_dim=96, num_layers=6, num_heads=8, num_kv_heads=4, max_seq_len=2048, tie_weights=True) model.reset_weights() if v2_ckpt_path is None or not os.path.exists(v2_ckpt_path): print(f"No V2 checkpoint at {v2_ckpt_path} — starting from random init") model.lm_head.weight = model.token_embed.weight return model raw = torch.load(v2_ckpt_path, map_location="cpu", weights_only=True) v2_sd = raw["model"] v2_embed = v2_sd["token_embed.weight"] restore_map = {} restore_map["token_embed.weight"] = v2_embed restore_map["ln_f.weight"] = v2_sd["ln_f.weight"] for i in range(3): restore_map[f"blocks.{i}.ln1.weight"] = v2_sd[f"blocks.{i}.ln1.weight"] restore_map[f"blocks.{i}.ln2.weight"] = v2_sd[f"blocks.{i}.ln2.weight"] restored = 0 for k, v in restore_map.items(): if k in model.state_dict(): target = model.state_dict()[k] if target.shape == v.shape: target.copy_(v) restored += 1 elif target.dim() == 2 and v.dim() == 2 and target.shape[1] == v.shape[1]: d = min(target.shape[0], v.shape[0]) target[:d].copy_(v[:d]) restored += 1 model.lm_head.weight = model.token_embed.weight total_v2 = sum(p.numel() for p in v2_sd.values()) restored_params = sum(v.numel() for v in restore_map.values()) v3_params = sum(p.numel() for p in model.parameters()) print(f"V2 checkpoint: {len(v2_sd)} keys, {total_v2:,} params") print(f"V3 model: {v3_params:,} params") print(f"Restored: {restored} tensors ({restored_params:,} params, {100*restored_params/v3_params:.1f}%)") return model # ─── QLoRA (4-bit NF4 + LoRA) ────────────────────────────────────────────── NF4_LEVELS = torch.tensor([ -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0, ], dtype=torch.float32) def _quantize_nf4_row(row: torch.Tensor) -> tuple: absmax = row.abs().max().clamp(min=1e-12) scaled = row / absmax idx = (scaled[:, None] - NF4_LEVELS[None, :].to(row.device)).abs().argmin(dim=-1) return idx.to(torch.uint8), absmax def pack_nf4(indices: torch.Tensor) -> torch.Tensor: n = indices.numel() if n % 2 != 0: indices = torch.cat([indices, indices.new_zeros(1)]) packed = (indices[0::2].to(torch.uint8) | (indices[1::2].to(torch.uint8) << 4)) return packed def unpack_nf4(packed: torch.Tensor, shape) -> torch.Tensor: n = shape[0] * shape[1] low = (packed & 0x0F).to(torch.long) high = ((packed >> 4) & 0x0F).to(torch.long) indices = torch.stack([low, high], dim=-1).reshape(n) return indices[:shape[0] * shape[1]].reshape(shape) def dequantize_nf4(packed_weight: torch.Tensor, scales: torch.Tensor, shape) -> torch.Tensor: indices = unpack_nf4(packed_weight, shape) return NF4_LEVELS[indices] * scales[:, None] class QLoRALinear(nn.Module): def __init__(self, in_features: int, out_features: int, r: int = 8, alpha: float = 16, dropout: float = 0.0): super().__init__() self.in_features = in_features self.out_features = out_features self.r = r self.alpha = alpha self.scaling = alpha / r self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() n_elements = in_features * out_features n_packed = (n_elements + 1) // 2 self.register_buffer("qweight", torch.zeros(n_packed, dtype=torch.uint8)) self.register_buffer("scales", torch.zeros(out_features, dtype=torch.float32)) self.register_buffer("bias", torch.zeros(out_features, dtype=torch.float32)) self._has_bias = False self._shape = (out_features, in_features) self.lora_A = nn.Parameter(torch.zeros(r, in_features)) self.lora_B = nn.Parameter(torch.zeros(out_features, r)) nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) def _dequantized_weight(self) -> torch.Tensor: dev = self.qweight.device indices = unpack_nf4(self.qweight, self._shape) return NF4_LEVELS.to(dev)[indices] * self.scales.to(dev)[:, None] def quantize_from(self, weight: torch.Tensor, bias: torch.Tensor | None = None): w = weight.float().detach() rows = [] scales = [] for i in range(w.shape[0]): idx, s = _quantize_nf4_row(w[i]) rows.append(idx) scales.append(s.item()) all_idx = torch.stack(rows) self.qweight.copy_(pack_nf4(all_idx.reshape(-1))) self.scales.copy_(torch.tensor(scales, dtype=torch.float32)) if bias is not None: self.bias.copy_(bias.float().detach()) self._has_bias = True def forward(self, x): w = self._dequantized_weight() b = self.bias if self._has_bias else None base = F.linear(x, w, b) return base + self.dropout(x) @ self.lora_A.T @ self.lora_B.T * self.scaling def apply_qlora(model, target_modules=None, r=8, alpha=16, dropout=0.0, freeze_embeds=True): if target_modules is None: target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "down", "up"] qlora_params = 0 for name, module in model.named_modules(): if not isinstance(module, nn.Linear): continue key = name.split(".")[-1] if key not in target_modules: continue parent = model parts = name.split(".") for p in parts[:-1]: parent = getattr(parent, p) qlora = QLoRALinear(module.in_features, module.out_features, r=r, alpha=alpha, dropout=dropout) qlora.quantize_from(module.weight, module.bias) qlora = qlora.to(module.weight.device) setattr(parent, parts[-1], qlora) qlora_params += 2 * r * module.in_features + module.out_features * r for name, param in model.named_parameters(): if "lora_" not in name: param.requires_grad = False n = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"QLoRA applied: {qlora_params:,} LoRA params | trainable: {n:,}") return model # ─── Helpers ─────────────────────────────────────────────────────────────── def count_params(model): return sum(p.numel() for p in model.parameters()) def create_model(): model = restore_from_v2("checkpoint.pt") n = count_params(model) print(f"TinyModel V3: {n:,} params ({n/1e6:.2f}M)") return model if __name__ == "__main__": m = create_model() x = torch.randint(0, 100, (1, 16)) logits, loss = m(x, labels=x) print(f"Forward OK: logits {logits.shape}, loss {loss.item():.4f}") gen = m.generate(x, max_new_tokens=10) print(f"Generate OK: {gen.shape}")