| import math |
| import os |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
|
|
class RPW(nn.Module):
    """Learned relative-position attention bias merged with a causal mask.

    For each (query i, key j) pair the signed offset ``j - i`` is expanded into
    ``num_freqs`` sine and ``num_freqs`` cosine features at frequencies
    ``1 / 10000 ** (k / num_freqs)``.  A learned ``(2 * num_freqs, num_heads)``
    matrix ``W_phi`` maps those features to one additive bias per head.

    Args:
        num_heads: number of attention heads (one bias map per head).
        num_freqs: number of sinusoidal frequencies.
        max_seq_len: stored for introspection only; the bias is computed for
            whatever length is requested.
    """

    def __init__(self, num_heads: int, num_freqs: int = 16, max_seq_len: int = 2048):
        super().__init__()
        self.num_heads = num_heads
        self.num_freqs = num_freqs
        self.max_seq_len = max_seq_len
        freqs = 1.0 / (10000.0 ** (torch.arange(num_freqs, dtype=torch.float32) / num_freqs))
        self.register_buffer("_freqs", freqs)
        self.W_phi = nn.Parameter(torch.zeros(num_freqs * 2, num_heads))
        nn.init.normal_(self.W_phi, std=0.02)

    def _make_bias(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Return the additive attention mask of shape (1, num_heads, seq_len, seq_len).

        Entries above the diagonal (key after query) are ``-inf``; all others
        hold the learned relative-position bias.  The result has ``dtype``.
        """
        pos = torch.arange(seq_len, device=device, dtype=torch.float32)
        delta = pos.view(1, seq_len) - pos.view(seq_len, 1)  # delta[i, j] = j - i
        # Always compute in float32.  If the module was cast to half/bf16,
        # mixing float32 features with a low-precision W_phi would make the
        # matmul raise a dtype-mismatch error.
        freqs = self._freqs.to(device=device, dtype=torch.float32)
        angles = delta.unsqueeze(-1) * freqs                           # (T, T, F)
        phi = torch.cat([angles.sin(), angles.cos()], dim=-1)          # (T, T, 2F)
        bias = phi @ self.W_phi.to(device=device, dtype=torch.float32)  # (T, T, H)
        bias = bias.permute(2, 0, 1).unsqueeze(0)                      # (1, H, T, T)
        causal = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=device, dtype=torch.float32),
            diagonal=1,
        )
        return (bias + causal).to(dtype=dtype)
|
|
|
|
| |
|
|
class GQAAttention(nn.Module):
    """Causal grouped-query self-attention with an RPW relative-position bias.

    ``num_heads`` query heads share ``num_kv_heads`` key/value heads.  Each K/V
    head is repeated ``num_heads // num_kv_heads`` times before attention.  The
    causal mask is already part of the RPW bias, so SDPA runs with
    ``is_causal=False``.

    Raises:
        ValueError: if ``hidden`` is not divisible by a positive ``num_heads``,
            or ``num_heads`` is not divisible by a positive ``num_kv_heads``.
    """

    def __init__(self, hidden: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        # Validate up front.  Otherwise these configs only fail later, inside
        # view() or SDPA, with an unhelpful shape error.
        if num_heads <= 0 or hidden % num_heads != 0:
            raise ValueError(
                f"hidden ({hidden}) must be divisible by a positive num_heads ({num_heads})"
            )
        if num_kv_heads <= 0 or num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by a positive "
                f"num_kv_heads ({num_kv_heads})"
            )
        self.hidden = hidden
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = hidden // num_heads
        self.num_groups = num_heads // num_kv_heads

        self.q_proj = nn.Linear(hidden, hidden, bias=False)
        self.k_proj = nn.Linear(hidden, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden, num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(hidden, hidden, bias=False)
        self.rpw = RPW(num_heads, num_freqs=16)

    def forward(self, x):
        """Map x of shape (B, T, hidden) to an output of the same shape."""
        B, T, _ = x.shape
        # Project, then split heads: (B, T, heads * head_dim) -> (B, heads, T, head_dim).
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Expand the shared K/V heads so that each query head has a partner.
        if self.num_groups > 1:
            k = k.repeat_interleave(self.num_groups, dim=1)
            v = v.repeat_interleave(self.num_groups, dim=1)

        # The bias already contains the causal -inf mask.
        rpw_bias = self.rpw._make_bias(T, x.device, q.dtype)
        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=rpw_bias, is_causal=False)
        out = attn.transpose(1, 2).contiguous().view(B, T, self.hidden)
        return self.o_proj(out)
|
|
|
|
| |
|
|
class GPP(nn.Module):
    """Two-layer feed-forward: hidden -> code_dim -> SiLU -> hidden, without biases."""

    def __init__(self, hidden: int, code_dim: int):
        super().__init__()
        # Creation order (down, then up) determines the random-init sequence.
        self.down = nn.Linear(hidden, code_dim, bias=False)
        self.up = nn.Linear(code_dim, hidden, bias=False)

    def forward(self, x):
        """Apply down-projection, SiLU, then up-projection on the last dimension."""
        code = F.silu(self.down(x))
        projected = self.up(code)
        return projected
|
|
|
|
| |
|
|
class VCR(nn.Module):
    """Gated residual merge: ``x + sigmoid(alpha * r + beta) * delta``.

    ``r`` is the L2 norm of ``delta`` divided by the L2 norm of ``x``.  Both
    norms are taken in float32 over all elements of the tensors, so one scalar
    gate is shared by the whole batch.  The denominator is clamped at 1e-8.
    ``alpha`` starts at 1.0 and ``beta`` at 2.0.
    """

    def __init__(self):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(1.0))
        self.beta = nn.Parameter(torch.tensor(2.0))

    def forward(self, x, delta):
        """Return x plus the gated delta (delta must broadcast against x)."""
        delta_norm = delta.float().norm()
        base_norm = torch.clamp(x.float().norm(), min=1e-8)
        ratio = delta_norm / base_norm
        gate = torch.sigmoid(self.beta + self.alpha * ratio)
        return x + gate * delta
|
|
|
|
| |
|
|
class RMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned per-feature gain.

    Normalisation runs in float32.  The result is cast back to the input dtype
    and then multiplied by ``weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Return ``x / sqrt(mean(x**2) + eps) * weight`` along the last dim."""
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        normalised = (x32 * inv_rms).type_as(x)
        return normalised * self.weight
|
|
|
|
| |
|
|
class TransformerBlock(nn.Module):
    """Pre-norm decoder block.

    Runs ``RMSNorm -> GQAAttention`` and then ``RMSNorm -> GPP``.  Each sublayer
    output is merged into the residual stream through its own VCR gate.
    """

    def __init__(self, hidden: int, code_dim: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        # Submodule creation order affects random init and state_dict key order.
        self.ln1 = RMSNorm(hidden)
        self.attn = GQAAttention(hidden, num_heads, num_kv_heads)
        self.vcr_attn = VCR()
        self.ln2 = RMSNorm(hidden)
        self.mlp = GPP(hidden, code_dim)
        self.vcr_mlp = VCR()

    def forward(self, x):
        """Apply the attention and feed-forward sublayers in sequence."""
        attn_delta = self.attn(self.ln1(x))
        h = self.vcr_attn(x, attn_delta)
        mlp_delta = self.mlp(self.ln2(h))
        return self.vcr_mlp(h, mlp_delta)
|
|
|
|
| |
|
|
class TinyModel(nn.Module):
    """Decoder-only LM: embedding -> TransformerBlocks -> RMSNorm -> LM head.

    Args:
        vocab_size: number of token ids.
        hidden: model width.
        code_dim: inner width of each block's GPP feed-forward.
        num_layers: number of TransformerBlocks.
        num_heads: query heads per block.
        num_kv_heads: shared key/value heads per block.
        max_seq_len: context window; ``generate`` crops the model input to it.
        tie_weights: if True, ``lm_head`` reuses the embedding's weight Parameter.
    """

    def __init__(self, vocab_size=4096, hidden=128, code_dim=96,
                 num_layers=6, num_heads=8, num_kv_heads=4, max_seq_len=2048,
                 tie_weights=True):
        super().__init__()
        self.hidden = hidden
        self.max_seq_len = max_seq_len
        self.token_embed = nn.Embedding(vocab_size, hidden)
        self.blocks = nn.ModuleList([
            TransformerBlock(hidden, code_dim, num_heads, num_kv_heads)
            for _ in range(num_layers)
        ])
        self.ln_f = RMSNorm(hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        if tie_weights:
            self.lm_head.weight = self.token_embed.weight

    def reset_weights(self):
        """Re-initialise all weights in place.

        - Linear and Embedding weights are drawn from N(0, 1/fan_in), where
          fan_in = weight.shape[1].
        - Linear biases are zeroed.
        - RMSNorm gains are set to 1.
        - RPW ``W_phi`` is drawn from N(0, 0.02^2), in a second pass after all
          other weights.
        """
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Embedding)):
                fan_in = m.weight.shape[1]
                nn.init.normal_(m.weight, std=1.0 / math.sqrt(fan_in))
                # Handled here because Linear weights are always 2-D: a
                # separate `elif` branch for biases would never be reached.
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, RMSNorm):
                nn.init.ones_(m.weight)
        for m in self.modules():
            if isinstance(m, RPW):
                nn.init.normal_(m.W_phi, std=0.02)

    def forward(self, input_ids, labels=None):
        """Compute logits and, optionally, the next-token loss.

        Args:
            input_ids: (B, T) token ids.
            labels: optional (B, T) targets aligned with ``input_ids``.  Position
                t is predicted from tokens < t; entries equal to -100 are ignored.

        Returns:
            ``(logits, loss)``: logits are (B, T, vocab_size); loss is None
            without labels.
        """
        x = self.token_embed(input_ids)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )
        return logits, loss

    @staticmethod
    def _apply_top_p(logits, top_p):
        """Return ``logits`` with tokens outside the top-p nucleus set to -inf.

        Keeps the smallest set of highest-probability tokens whose cumulative
        probability reaches ``top_p``; the most likely token is always kept.
        Works on the last dimension and does not modify the input.
        """
        if top_p >= 1.0:
            return logits
        sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
        cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        remove_sorted = cum_probs > top_p
        # Shift right so that the token crossing the threshold is kept.
        remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
        remove_sorted[..., 0] = False
        # The mask is in sorted order; scatter it back to vocabulary order
        # before applying it to the unsorted logits.
        remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)
        return logits.masked_fill(remove, float('-inf'))

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=128, temperature=0.7, top_p=0.9,
                 stream_callback=None, eos_token_id=2):
        """Autoregressively extend ``input_ids`` by up to ``max_new_tokens`` tokens.

        Args:
            input_ids: (B, T) prompt token ids.
            max_new_tokens: maximum number of tokens to append.
            temperature: > 0 samples from the tempered distribution; <= 0 is greedy.
            top_p: nucleus-sampling threshold; >= 1.0 disables filtering.
            stream_callback: called after each step with the new token id (an
                int for batch size 1, otherwise a list of ints).
            eos_token_id: stop once every sequence emitted this id in the same
                step; None disables early stopping.

        Returns:
            The prompt followed by all generated tokens.  Only the model input
            is cropped to ``max_seq_len``, never the returned sequence.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                context = input_ids[:, -self.max_seq_len:]
                logits, _ = self.forward(context)
                logits = logits[:, -1, :]
                if temperature > 0:
                    logits = logits / temperature
                if top_p < 1.0:
                    logits = self._apply_top_p(logits, top_p)
                probs = F.softmax(logits, dim=-1)
                if temperature > 0:
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = probs.argmax(dim=-1, keepdim=True)
                input_ids = torch.cat([input_ids, next_token], dim=1)
                if stream_callback is not None:
                    if next_token.numel() == 1:
                        stream_callback(next_token.item())
                    else:
                        stream_callback(next_token.squeeze(-1).tolist())
                if eos_token_id is not None and bool((next_token == eos_token_id).all()):
                    break
        finally:
            # Restore whatever mode the caller had the model in.
            self.train(was_training)
        return input_ids
|
|
|
|
| |
|
|
def restore_from_v2(v2_ckpt_path: str | None = None, strict: bool = False) -> TinyModel:
    """Build a V3 ``TinyModel`` and warm-start it from a V2 checkpoint.

    A fresh model is created and re-initialised with ``reset_weights``.  If
    ``v2_ckpt_path`` exists, the following V2 tensors are copied in:
    ``token_embed.weight``, ``ln_f.weight``, and the ln1/ln2 weights of blocks
    0-2.

    - Tensors with identical shapes are copied whole.
    - 2-D tensors whose second dimension matches are copied on their
      overlapping leading rows.
    - Anything else keeps its fresh initialisation.

    Args:
        v2_ckpt_path: path to a ``torch.save`` checkpoint.  Accepts either a
            dict holding the state dict under ``"model"`` or a bare state dict.
        strict: if True, raise instead of skipping when the path does not
            exist, a tensor is missing, or a tensor's shape cannot be copied.

    Returns:
        The initialised model, with ``lm_head`` tied to the token embedding.

    Raises:
        FileNotFoundError, KeyError, ValueError: only when ``strict`` is True.
    """
    model = TinyModel(vocab_size=4096, hidden=128, code_dim=96,
                      num_layers=6, num_heads=8, num_kv_heads=4,
                      max_seq_len=2048, tie_weights=True)
    model.reset_weights()

    if v2_ckpt_path is None or not os.path.exists(v2_ckpt_path):
        if strict and v2_ckpt_path is not None:
            raise FileNotFoundError(f"No V2 checkpoint at {v2_ckpt_path}")
        print(f"No V2 checkpoint at {v2_ckpt_path} - starting from random init")
        model.lm_head.weight = model.token_embed.weight
        return model

    raw = torch.load(v2_ckpt_path, map_location="cpu", weights_only=True)
    v2_sd = raw["model"] if isinstance(raw, dict) and "model" in raw else raw

    # V2 and V3 use the same key names for the transferred tensors.
    keys = ["token_embed.weight", "ln_f.weight"]
    for i in range(3):
        keys += [f"blocks.{i}.ln1.weight", f"blocks.{i}.ln2.weight"]

    # Build the state dict once.  Its tensors share storage with the live
    # parameters, so copy_ into them updates the model.
    v3_sd = model.state_dict()
    restored = 0
    restored_params = 0
    with torch.no_grad():
        for key in keys:
            if key not in v2_sd or key not in v3_sd:
                if strict:
                    raise KeyError(f"Cannot restore {key!r}: missing from checkpoint or model")
                print(f"Skipping {key}: missing from checkpoint or model")
                continue
            src = v2_sd[key]
            target = v3_sd[key]
            if target.shape == src.shape:
                target.copy_(src)
                copied = src.numel()
            elif target.dim() == 2 and src.dim() == 2 and target.shape[1] == src.shape[1]:
                rows = min(target.shape[0], src.shape[0])
                target[:rows].copy_(src[:rows])
                copied = rows * target.shape[1]
            else:
                if strict:
                    raise ValueError(
                        f"Cannot restore {key!r}: checkpoint shape {tuple(src.shape)} "
                        f"vs model shape {tuple(target.shape)}"
                    )
                continue
            restored += 1
            # Count only the elements actually written into the model.
            restored_params += copied

    model.lm_head.weight = model.token_embed.weight

    total_v2 = sum(t.numel() for t in v2_sd.values() if torch.is_tensor(t))
    v3_params = sum(p.numel() for p in model.parameters())
    print(f"V2 checkpoint: {len(v2_sd)} keys, {total_v2:,} params")
    print(f"V3 model: {v3_params:,} params")
    print(f"Restored: {restored} tensors ({restored_params:,} params, {100*restored_params/v3_params:.1f}%)")
    return model
|
|
|
|
| |
|
|
# 16-entry NF4 codebook: ascending values in [-1, 1] that include an exact 0.0.
# A 4-bit index selects a level, which is multiplied by a per-row absmax scale
# (see _quantize_nf4_row / dequantize_nf4).
NF4_LEVELS = torch.tensor(
    (
        -1.0,
        -0.6961928009986877,
        -0.5250730514526367,
        -0.39491748809814453,
        -0.28444138169288635,
        -0.18477343022823334,
        -0.09105003625154495,
        0.0,
        0.07958029955625534,
        0.16093020141124725,
        0.24611230194568634,
        0.33791524171829224,
        0.44070982933044434,
        0.5626170039176941,
        0.7229568362236023,
        1.0,
    ),
    dtype=torch.float32,
)
|
|
|
|
def _quantize_nf4_row(row: torch.Tensor) -> tuple:
    """Quantise one row to NF4 codebook indices.

    The row is divided by its absolute maximum (clamped at 1e-12), and each
    element is mapped to the index of the nearest ``NF4_LEVELS`` entry.

    Returns:
        ``(indices, absmax)``: indices are uint8 in [0, 15]; absmax is a 0-d tensor.
    """
    scale = row.abs().max().clamp(min=1e-12)
    normalised = row / scale
    levels = NF4_LEVELS.to(row.device)
    distance = (normalised[:, None] - levels[None, :]).abs()
    nearest = distance.argmin(dim=-1)
    return nearest.to(torch.uint8), scale
|
|
|
|
def pack_nf4(indices: torch.Tensor) -> torch.Tensor:
    """Pack 4-bit indices two per byte.

    The input is flattened in row-major order, so any shape is accepted.  Flat
    element 2k goes into the low nibble of byte k and element 2k+1 into the
    high nibble.  An odd-length input is padded with one trailing 0.  Only the
    low 4 bits of each index are kept, so a value above 15 cannot spill into
    its neighbour's nibble.

    Returns:
        1-D uint8 tensor with ceil(numel / 2) bytes.
    """
    # Flatten first: slicing a 2-D tensor with [0::2] would pair whole rows.
    flat = indices.reshape(-1).to(torch.uint8) & 0x0F
    if flat.numel() % 2 != 0:
        flat = torch.cat([flat, flat.new_zeros(1)])
    return flat[0::2] | (flat[1::2] << 4)
|
|
|
|
def unpack_nf4(packed: torch.Tensor, shape) -> torch.Tensor:
    """Inverse of ``pack_nf4``: unpack uint8 bytes into int64 indices of ``shape``.

    Byte k yields flat elements 2k (low nibble) and 2k+1 (high nibble).  When
    the element count is odd, the trailing padding nibble is dropped.  ``shape``
    may have any rank.
    """
    n = math.prod(shape)
    low = (packed & 0x0F).to(torch.long)
    high = ((packed >> 4) & 0x0F).to(torch.long)
    # Interleave to low0, high0, low1, high1, ...  Flatten before trimming:
    # with an odd element count there is one more nibble than n, so
    # reshaping straight to n would fail.
    flat = torch.stack([low, high], dim=-1).reshape(-1)
    return flat[:n].reshape(shape)
|
|
|
|
def dequantize_nf4(packed_weight: torch.Tensor, scales: torch.Tensor, shape) -> torch.Tensor:
    """Rebuild a 2-D weight from packed NF4 indices and per-row scales.

    Args:
        packed_weight: uint8 bytes as produced by ``pack_nf4``.
        scales: one absmax scale per row (``shape[0]`` entries).
        shape: (rows, cols) of the original weight.

    Returns:
        Tensor of ``shape`` on the same device as ``packed_weight``.
    """
    indices = unpack_nf4(packed_weight, shape)
    # NF4_LEVELS lives on the CPU.  Move it (and the scales) to the index
    # device so GPU-resident packed weights can be dequantized.
    levels = NF4_LEVELS.to(indices.device)
    return levels[indices] * scales.to(indices.device)[:, None]
|
|
|
|
| class QLoRALinear(nn.Module): |
| def __init__(self, in_features: int, out_features: int, r: int = 8, alpha: float = 16, dropout: float = 0.0): |
| super().__init__() |
| self.in_features = in_features |
| self.out_features = out_features |
| self.r = r |
| self.alpha = alpha |
| self.scaling = alpha / r |
| self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() |
|
|
| n_elements = in_features * out_features |
| n_packed = (n_elements + 1) // 2 |
| self.register_buffer("qweight", torch.zeros(n_packed, dtype=torch.uint8)) |
| self.register_buffer("scales", torch.zeros(out_features, dtype=torch.float32)) |
| self.register_buffer("bias", torch.zeros(out_features, dtype=torch.float32)) |
| self._has_bias = False |
| self._shape = (out_features, in_features) |
|
|
| self.lora_A = nn.Parameter(torch.zeros(r, in_features)) |
| self.lora_B = nn.Parameter(torch.zeros(out_features, r)) |
| nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) |
|
|
| def _dequantized_weight(self) -> torch.Tensor: |
| dev = self.qweight.device |
| indices = unpack_nf4(self.qweight, self._shape) |
| return NF4_LEVELS.to(dev)[indices] * self.scales.to(dev)[:, None] |
|
|
| def quantize_from(self, weight: torch.Tensor, bias: torch.Tensor | None = None): |
| w = weight.float().detach() |
| rows = [] |
| scales = [] |
| for i in range(w.shape[0]): |
| idx, s = _quantize_nf4_row(w[i]) |
| rows.append(idx) |
| scales.append(s.item()) |
| all_idx = torch.stack(rows) |
| self.qweight.copy_(pack_nf4(all_idx.reshape(-1))) |
| self.scales.copy_(torch.tensor(scales, dtype=torch.float32)) |
| if bias is not None: |
| self.bias.copy_(bias.float().detach()) |
| self._has_bias = True |
|
|
| def forward(self, x): |
| w = self._dequantized_weight() |
| b = self.bias if self._has_bias else None |
| base = F.linear(x, w, b) |
| return base + self.dropout(x) @ self.lora_A.T @ self.lora_B.T * self.scaling |
|
|
|
|
def apply_qlora(model, target_modules=None, r=8, alpha=16, dropout=0.0, freeze_embeds=True):
    """Swap matching ``nn.Linear`` layers for ``QLoRALinear`` and freeze everything else.

    A Linear is replaced when the last component of its dotted module name is
    in ``target_modules``.  By default these are the attention projections and
    GPP ``down``/``up``.  Afterwards every parameter whose name lacks "lora_"
    is frozen.  With ``freeze_embeds=False``, ``nn.Embedding`` parameters stay
    trainable, including any head weight tied to them.

    Returns:
        The same ``model`` object, modified in place.
    """
    if target_modules is None:
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "down", "up"]
    targets = set(target_modules)
    qlora_params = 0
    # Snapshot the module list first: the loop replaces submodules, so the
    # tree must not change while we are walking it.
    for name, module in list(model.named_modules()):
        if not isinstance(module, nn.Linear):
            continue
        parent_name, _, attr = name.rpartition(".")
        if attr not in targets:
            continue
        parent = model.get_submodule(parent_name) if parent_name else model
        qlora = QLoRALinear(module.in_features, module.out_features, r=r, alpha=alpha, dropout=dropout)
        qlora.quantize_from(module.weight, module.bias)
        qlora = qlora.to(module.weight.device)
        setattr(parent, attr, qlora)
        # LoRA adds A (r x in_features) and B (out_features x r).
        qlora_params += r * (module.in_features + module.out_features)

    keep_trainable = set()
    if not freeze_embeds:
        keep_trainable = {
            id(p)
            for m in model.modules()
            if isinstance(m, nn.Embedding)
            for p in m.parameters()
        }
    for name, param in model.named_parameters():
        if "lora_" not in name and id(param) not in keep_trainable:
            param.requires_grad = False
    n = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"QLoRA applied: {qlora_params:,} LoRA params | trainable: {n:,}")
    return model
|
|
|
|
| |
|
|
def count_params(model):
    """Return the total element count over ``model.parameters()``.

    Tied parameters are yielded once by ``parameters()``, so they are counted once.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|
|
|
|
def create_model(ckpt_path: str | None = "checkpoint.pt"):
    """Build the V3 model, warm-started from ``ckpt_path``, and report its size.

    Args:
        ckpt_path: V2 checkpoint passed to ``restore_from_v2``.  Defaults to
            "checkpoint.pt"; a missing file or None means random init.
    """
    model = restore_from_v2(ckpt_path)
    n = count_params(model)
    print(f"TinyModel V3: {n:,} params ({n/1e6:.2f}M)")
    return model
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the model, run one forward pass with a loss, then sample.
    model = create_model()
    tokens = torch.randint(0, 100, (1, 16))
    logits, loss = model(tokens, labels=tokens)
    print(f"Forward OK: logits {logits.shape}, loss {loss.item():.4f}")
    generated = model.generate(tokens, max_new_tokens=10)
    print(f"Generate OK: {generated.shape}")
|
|