# lumia-tiny / model_tiny.py
# Uploaded by samcheng0 with huggingface_hub (commit 248968d, verified; 16.2 kB).
import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
# ─── RPW: Relative Positional Warp (learned Fourier additive bias) ──────────
class RPW(nn.Module):
    """Learned relative-position attention bias built from Fourier features.

    For every query/key pair ``(i, j)`` the signed offset ``j - i`` is expanded
    into ``num_freqs`` sine and ``num_freqs`` cosine features. The learned
    matrix ``W_phi`` maps these to one additive bias value per attention head.

    Args:
        num_heads: number of attention heads (one bias channel per head).
        num_freqs: number of geometric frequencies in the Fourier expansion.
        max_seq_len: accepted for interface compatibility; currently unused.
    """

    def __init__(self, num_heads: int, num_freqs: int = 16, max_seq_len: int = 2048):
        super().__init__()
        self.num_heads = num_heads
        self.num_freqs = num_freqs
        # Geometric frequency ladder 1 / 10000^(k / num_freqs), k = 0..num_freqs-1.
        freqs = 1.0 / (10000.0 ** (torch.arange(num_freqs, dtype=torch.float32) / num_freqs))
        self.register_buffer("_freqs", freqs)
        # Rows [0, num_freqs) weight the sine features, the remaining rows the cosines.
        self.W_phi = nn.Parameter(torch.zeros(num_freqs * 2, num_heads))
        nn.init.normal_(self.W_phi, std=0.02)

    def _make_bias(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Return the causal relative-position bias for a length-``seq_len`` window.

        The result has shape ``(1, num_heads, seq_len, seq_len)`` and dtype
        ``dtype``. Entries above the diagonal are ``-inf``, so the tensor can be
        passed directly as an additive ``attn_mask``.
        """
        pos = torch.arange(seq_len, device=device, dtype=torch.float32)
        # delta[i, j] = j - i (key position minus query position).
        delta = pos.view(1, seq_len) - pos.view(seq_len, 1)
        # Features and projection are computed in float32 whatever dtype the
        # module was cast to. After ``.half()`` / ``.to(torch.bfloat16)`` the
        # float32 features would otherwise meet a half-precision ``W_phi`` in a
        # matmul, which rejects mixed dtypes. ``.to(...)`` keeps W_phi's gradient path.
        freqs = self._freqs.to(device=device, dtype=torch.float32)
        angles = delta.unsqueeze(-1) * freqs                          # (T, T, F)
        phi = torch.cat([angles.sin(), angles.cos()], dim=-1)         # (T, T, 2F)
        weight = self.W_phi.to(device=device, dtype=torch.float32)
        bias = (phi @ weight).permute(2, 0, 1).unsqueeze(0)           # (1, H, T, T)
        # Strictly upper-triangular -inf mask: query i may not attend to key j > i.
        causal = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=torch.float32),
            diagonal=1,
        )
        return (bias + causal).to(dtype=dtype)
# ─── GQA with RPW ──────────────────────────────────────────────────────────
class GQAAttention(nn.Module):
    """Grouped-query self-attention with a learned RPW causal bias.

    ``num_heads`` query heads share ``num_kv_heads`` key/value heads. Each
    key/value head is repeated ``num_heads // num_kv_heads`` times before
    attention. Causality comes entirely from the RPW bias mask.

    Raises:
        ValueError: if the head configuration cannot be laid out evenly.
    """

    def __init__(self, hidden: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        # Fail fast on layouts that would otherwise only break inside forward()
        # (an invalid ``view`` or mismatched head counts in SDPA).
        if num_heads <= 0 or num_kv_heads <= 0:
            raise ValueError(
                f"num_heads ({num_heads}) and num_kv_heads ({num_kv_heads}) must be positive"
            )
        if hidden % num_heads != 0:
            raise ValueError(f"hidden ({hidden}) must be divisible by num_heads ({num_heads})")
        if num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
            )
        self.hidden = hidden
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = hidden // num_heads
        self.num_groups = num_heads // num_kv_heads
        self.q_proj = nn.Linear(hidden, hidden, bias=False)
        self.k_proj = nn.Linear(hidden, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden, num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(hidden, hidden, bias=False)
        self.rpw = RPW(num_heads, num_freqs=16)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq, hidden); returns the same shape."""
        B, T, _ = x.shape
        # (B, T, hidden) -> (B, heads, T, head_dim)
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Share each key/value head across its group of query heads.
        if self.num_groups > 1:
            k = k.repeat_interleave(self.num_groups, dim=1)
            v = v.repeat_interleave(self.num_groups, dim=1)
        # The RPW bias already carries a -inf upper triangle, hence is_causal=False.
        rpw_bias = self.rpw._make_bias(T, x.device, q.dtype)
        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=rpw_bias, is_causal=False)
        out = attn.transpose(1, 2).contiguous().view(B, T, self.hidden)
        return self.o_proj(out)
# ─── GPP: Gated Principal Projection ──────────────────────────────────────
class GPP(nn.Module):
    """Bottleneck MLP: project to ``code_dim``, apply SiLU, project back to ``hidden``."""

    def __init__(self, hidden: int, code_dim: int):
        super().__init__()
        self.down = nn.Linear(hidden, code_dim, bias=False)
        self.up = nn.Linear(code_dim, hidden, bias=False)

    def forward(self, x):
        code = self.down(x)
        activated = F.silu(code)
        return self.up(activated)
# ─── VCR: Variance-Controlled Residual ─────────────────────────────────────
class VCR(nn.Module):
    """Gated residual: ``x + sigmoid(alpha * r + beta) * delta``.

    ``r`` is the norm of ``delta`` divided by the norm of ``x``. Both norms are
    taken over the whole tensor (every batch element and position together),
    so a single scalar gate is shared by the entire batch. The denominator is
    clamped to 1e-8 to avoid division by zero.
    """

    def __init__(self):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(1.0))
        self.beta = nn.Parameter(torch.tensor(2.0))

    def forward(self, x, delta):
        delta_norm = delta.float().norm()
        x_norm = x.float().norm().clamp(min=1e-8)
        ratio = delta_norm / x_norm
        gate = torch.sigmoid(self.beta + self.alpha * ratio)
        return x + gate * delta
# ─── RMSNorm ───────────────────────────────────────────────────────────────
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    The statistic is computed in float32. The normalised value is cast back to
    the input dtype before being multiplied by ``weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        x32 = x.float()
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normed = (x32 * inv_rms).type_as(x)
        return normed * self.weight
# ─── Transformer Block ────────────────────────────────────────────────────
class TransformerBlock(nn.Module):
    """Pre-norm block: RMSNorm -> GQA attention and RMSNorm -> GPP MLP.

    Each sub-layer is merged into the stream through its own VCR gated residual.
    """

    def __init__(self, hidden: int, code_dim: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        # Registration order is kept as-is: it fixes both state_dict key order
        # and the order in which parameter initialisers consume the RNG.
        self.ln1 = RMSNorm(hidden)
        self.attn = GQAAttention(hidden, num_heads, num_kv_heads)
        self.vcr_attn = VCR()
        self.ln2 = RMSNorm(hidden)
        self.mlp = GPP(hidden, code_dim)
        self.vcr_mlp = VCR()

    def forward(self, x):
        attn_delta = self.attn(self.ln1(x))
        x = self.vcr_attn(x, attn_delta)
        mlp_delta = self.mlp(self.ln2(x))
        return self.vcr_mlp(x, mlp_delta)
# ─── TinyModel V3 ──────────────────────────────────────────────────────────
class TinyModel(nn.Module):
    """Decoder-only LM: token embedding -> TransformerBlocks -> RMSNorm -> LM head.

    Args:
        vocab_size: number of token ids.
        hidden: model width.
        code_dim: bottleneck width of each block's GPP MLP.
        num_layers: number of ``TransformerBlock`` layers.
        num_heads: query heads per attention layer.
        num_kv_heads: key/value heads per attention layer.
        max_seq_len: longest context ``generate`` feeds to the model.
        tie_weights: share one weight tensor between ``token_embed`` and ``lm_head``.
    """

    def __init__(self, vocab_size=4096, hidden=128, code_dim=96,
                 num_layers=6, num_heads=8, num_kv_heads=4, max_seq_len=2048,
                 tie_weights=True):
        super().__init__()
        self.hidden = hidden
        self.max_seq_len = max_seq_len
        self.token_embed = nn.Embedding(vocab_size, hidden)
        self.blocks = nn.ModuleList([
            TransformerBlock(hidden, code_dim, num_heads, num_kv_heads)
            for _ in range(num_layers)
        ])
        self.ln_f = RMSNorm(hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        if tie_weights:
            self.lm_head.weight = self.token_embed.weight

    def reset_weights(self):
        """Re-initialise every parameter in place.

        * ``nn.Linear`` / ``nn.Embedding`` weights are drawn from a normal with
          std ``1/sqrt(weight.shape[1])``.
        * Linear biases, if any, are zeroed.
        * RMSNorm gains are set to one.
        * RPW projections are drawn from a normal with std 0.02.
        """
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Embedding)):
                # Second weight dim: in_features for Linear, embedding dim for Embedding.
                fan_in = m.weight.shape[1]
                nn.init.normal_(m.weight, std=1.0 / math.sqrt(fan_in))
                # Handled here because a Linear never reaches a later branch.
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, RMSNorm):
                nn.init.ones_(m.weight)
        for m in self.modules():
            if isinstance(m, RPW):
                nn.init.normal_(m.W_phi, std=0.02)

    def forward(self, input_ids, labels=None):
        """Compute next-token logits and, optionally, the language-modelling loss.

        Args:
            input_ids: integer token ids of shape (batch, seq). The attention
                layers unpack a 3-D (batch, seq, hidden) activation.
            labels: optional ids aligned with ``input_ids``. The logits at
                position t are scored against ``labels[:, t + 1]``; the value
                -100 is ignored.

        Returns:
            ``(logits, loss)``: logits of shape (batch, seq, vocab_size), and a
            scalar ``loss``, or None when ``labels`` is None.
        """
        x = self.token_embed(input_ids)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return logits, loss

    @staticmethod
    def _top_p_filter(logits, top_p):
        """Return ``logits`` with every token outside the top-p nucleus set to -inf."""
        sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
        cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Drop a token once the mass *before* it already exceeds top_p. This keeps
        # the token that crosses the threshold and always keeps the top-1 token.
        remove_sorted = cum_probs > top_p
        remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
        remove_sorted[..., 0] = False
        # The mask is in sorted order; scatter it back to vocabulary order.
        remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)
        return logits.masked_fill(remove, float("-inf"))

    def generate(self, input_ids, max_new_tokens=128, temperature=0.7, top_p=0.9,
                 stream_callback=None, eos_token_id=2):
        """Autoregressively extend ``input_ids`` by up to ``max_new_tokens`` tokens.

        Sampling:
        * ``temperature > 0``: sample from the temperature-scaled distribution.
        * ``temperature <= 0``: pick greedily.
        * ``top_p < 1.0``: restrict to the nucleus of cumulative probability ``top_p``.

        Context and stopping:
        * Only the last ``max_seq_len`` tokens are fed to the model.
        * The returned tensor always contains the full prompt plus every new token.
        * Generation stops early once every row emits ``eos_token_id``; pass
          None to disable early stopping.

        The model runs in eval mode without gradients, and the caller's
        train/eval mode is restored afterwards.

        Args:
            input_ids: prompt ids of shape (batch, seq).
            stream_callback: optional callable that receives each new token as
                a Python int via ``.item()``, so batch size must be 1 when set.

        Returns:
            ``input_ids`` with the generated tokens appended along dim 1.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    # Crop only what the model sees; the output keeps everything.
                    context = input_ids[:, -self.max_seq_len:]
                    logits, _ = self.forward(context)
                    logits = logits[:, -1, :]
                    if temperature > 0:
                        logits = logits / temperature
                    if top_p < 1.0:
                        logits = self._top_p_filter(logits, top_p)
                    if temperature > 0:
                        probs = F.softmax(logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                    else:
                        next_token = logits.argmax(dim=-1, keepdim=True)
                    input_ids = torch.cat([input_ids, next_token], dim=1)
                    if stream_callback is not None:
                        stream_callback(next_token.item())
                    if eos_token_id is not None and bool((next_token == eos_token_id).all()):
                        break
        finally:
            self.train(was_training)
        return input_ids
# ─── Restore from V2 checkpoint ────────────────────────────────────────────
def _copy_compatible(target: torch.Tensor, source: torch.Tensor) -> int | None:
    """Copy ``source`` into ``target`` in place when their shapes allow it.

    * Identical shapes: ``source`` is copied whole.
    * Two 2-D tensors with the same column count but different row counts:
      their leading rows are copied.

    Returns the number of elements copied, or None if the shapes are incompatible.
    """
    if target.shape == source.shape:
        target.copy_(source)
        return source.numel()
    if target.dim() == 2 and source.dim() == 2 and target.shape[1] == source.shape[1]:
        rows = min(target.shape[0], source.shape[0])
        target[:rows].copy_(source[:rows])
        return rows * target.shape[1]
    return None


def restore_from_v2(v2_ckpt_path: str | None = None, strict: bool = False) -> TinyModel:
    """Build a fresh V3 ``TinyModel`` and warm-start it from a V2 checkpoint.

    Carried over: the token embedding, the final norm, and both RMSNorm weights
    of the first three blocks. Everything else keeps its ``reset_weights``
    initialisation. Shape handling follows ``_copy_compatible`` (whole copy on
    an exact match, leading rows when only the row count differs).

    Args:
        v2_ckpt_path: path to a checkpoint whose ``"model"`` entry is a state
            dict. If None or the file does not exist, the randomly initialised
            model is returned.
        strict: if True, raise when an expected tensor is missing or cannot be
            copied. If False, such tensors are skipped.

    Returns:
        The (partially) restored model; its reported restore statistics count
        only elements that were actually copied.

    Raises:
        KeyError: ``strict`` is True and a tensor is missing from either side.
        ValueError: ``strict`` is True and a V2 tensor's shape is incompatible.
    """
    model = TinyModel(vocab_size=4096, hidden=128, code_dim=96,
                      num_layers=6, num_heads=8, num_kv_heads=4,
                      max_seq_len=2048, tie_weights=True)
    model.reset_weights()
    if v2_ckpt_path is None or not os.path.exists(v2_ckpt_path):
        print(f"No V2 checkpoint at {v2_ckpt_path} — starting from random init")
        model.lm_head.weight = model.token_embed.weight
        return model

    # weights_only=True refuses to unpickle arbitrary objects from the file.
    raw = torch.load(v2_ckpt_path, map_location="cpu", weights_only=True)
    v2_sd = raw["model"]

    keys = ["token_embed.weight", "ln_f.weight"]
    for i in range(3):
        keys.append(f"blocks.{i}.ln1.weight")
        keys.append(f"blocks.{i}.ln2.weight")

    # state_dict() yields detached tensors that share storage with the live
    # parameters, so copying into them updates the model in place.
    v3_sd = model.state_dict()
    restored = 0
    restored_params = 0
    for key in keys:
        if key not in v2_sd:
            if strict:
                raise KeyError(f"V2 checkpoint has no tensor '{key}'")
            continue
        if key not in v3_sd:
            if strict:
                raise KeyError(f"V3 model has no tensor '{key}'")
            continue
        source = v2_sd[key]
        copied = _copy_compatible(v3_sd[key], source)
        if copied is None:
            if strict:
                raise ValueError(
                    f"Cannot restore '{key}': V2 shape {tuple(source.shape)} "
                    f"is incompatible with V3 shape {tuple(v3_sd[key].shape)}"
                )
            continue
        restored += 1
        restored_params += copied

    # Copies were in place, so the tie set up by TinyModel is intact; re-assert it anyway.
    model.lm_head.weight = model.token_embed.weight
    total_v2 = sum(p.numel() for p in v2_sd.values())
    v3_params = sum(p.numel() for p in model.parameters())
    print(f"V2 checkpoint: {len(v2_sd)} keys, {total_v2:,} params")
    print(f"V3 model: {v3_params:,} params")
    print(f"Restored: {restored} tensors ({restored_params:,} params, {100*restored_params/v3_params:.1f}%)")
    return model
# ─── QLoRA (4-bit NF4 + LoRA) ──────────────────────────────────────────────
# Lookup table for 4-bit NF4 codes. A code k in [0, 15] dequantizes to
# NF4_LEVELS[k] * scale, where scale is the absmax of the weight row the code
# belongs to. The 16 values are sorted ascending, span [-1.0, 1.0] and include
# an exact 0.0. Their spacing is non-uniform: steps are ~0.08-0.09 near zero
# and ~0.28-0.30 near +/-1.
# NOTE(review): these appear to be the NormalFloat4 levels used by QLoRA /
# bitsandbytes — confirm against that source before editing any value.
NF4_LEVELS = torch.tensor([
-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
-0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
0.7229568362236023, 1.0,
], dtype=torch.float32)
def _quantize_nf4_row(row: torch.Tensor) -> tuple:
    """Quantize one weight row to NF4 code indices using absmax scaling.

    ``row`` is expected to be 1-D (callers pass one weight row at a time).
    Returns ``(indices, absmax)``:

    * ``indices``: uint8 tensor of values in [0, 15], each selecting the
      ``NF4_LEVELS`` entry nearest to the element of ``row / absmax``.
    * ``absmax``: 0-dim scale, clamped to at least 1e-12 so an all-zero row
      does not divide by zero.
    """
    levels = NF4_LEVELS[None, :].to(row.device)
    scale = torch.clamp(row.abs().max(), min=1e-12)
    normalized = (row / scale)[:, None]
    nearest = torch.argmin(torch.abs(normalized - levels), dim=-1)
    return nearest.to(torch.uint8), scale
def pack_nf4(indices: torch.Tensor) -> torch.Tensor:
    """Pack a flat sequence of 4-bit codes two per byte, low nibble first.

    Element 2i goes into the low nibble of byte i and element 2i+1 into the
    high nibble. An odd-length input is padded with a trailing 0 code.
    Returns a uint8 tensor.
    """
    if indices.numel() % 2:
        indices = torch.cat([indices, indices.new_zeros(1)])
    low_nibbles = indices[0::2].to(torch.uint8)
    high_nibbles = indices[1::2].to(torch.uint8) << 4
    return low_nibbles | high_nibbles
def unpack_nf4(packed: torch.Tensor, shape) -> torch.Tensor:
    """Inverse of ``pack_nf4``: expand packed bytes into 4-bit codes of ``shape``.

    Each byte yields its low nibble, then its high nibble. When the code count
    ``shape[0] * shape[1]`` is odd, the packed data ends with one padding
    nibble, which is dropped here.

    Returns an int64 tensor of shape ``shape``, suitable for indexing a lookup table.
    """
    n = shape[0] * shape[1]
    low = (packed & 0x0F).to(torch.long)
    high = ((packed >> 4) & 0x0F).to(torch.long)
    # Interleave to [low0, high0, low1, high1, ...]. This holds 2 * packed.numel()
    # codes, which is n + 1 when n is odd, so flatten first and then trim.
    indices = torch.stack([low, high], dim=-1).reshape(-1)
    return indices[:n].reshape(shape)
def dequantize_nf4(packed_weight: torch.Tensor, scales: torch.Tensor, shape) -> torch.Tensor:
    """Rebuild a dense weight of ``shape`` from packed NF4 codes and row scales.

    Element ``[i, j]`` is ``NF4_LEVELS[code[i, j]] * scales[i]``. ``scales`` is
    assumed to hold one value per row (length ``shape[0]``) — confirm with caller.
    The result lives on ``packed_weight``'s device.
    """
    indices = unpack_nf4(packed_weight, shape)
    # The lookup table is a CPU tensor. Move it (and the scales) next to the
    # codes so dequantization also works when the packed data is on an accelerator.
    levels = NF4_LEVELS.to(indices.device)
    return levels[indices] * scales.to(indices.device)[:, None]
class QLoRALinear(nn.Module):
    """Linear layer: frozen NF4-quantized base weight plus a trainable LoRA update.

    ``forward(x) = x @ W_q.T + bias + scaling * (dropout(x) @ lora_A.T @ lora_B.T)``

    * ``W_q``: the dequantized (out_features, in_features) weight, stored as
      packed 4-bit codes (``qweight``) with one absmax scale per output row
      (``scales``).
    * ``lora_A``: (r, in_features); ``lora_B``: (out_features, r).
    * ``scaling`` is ``alpha / r``.
    * ``lora_B`` starts at zero, so the layer initially matches the quantized
      base layer.
    """

    def __init__(self, in_features: int, out_features: int, r: int = 8, alpha: float = 16, dropout: float = 0.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        n_elements = in_features * out_features
        n_packed = (n_elements + 1) // 2  # two 4-bit codes per byte
        self.register_buffer("qweight", torch.zeros(n_packed, dtype=torch.uint8))
        self.register_buffer("scales", torch.zeros(out_features, dtype=torch.float32))
        # Always applied in forward (all zeros when the source layer had no
        # bias). The value therefore survives a state_dict round-trip, unlike
        # the plain `_has_bias` attribute below.
        self.register_buffer("bias", torch.zeros(out_features, dtype=torch.float32))
        self._has_bias = False  # informational; forward does not depend on it
        self._shape = (out_features, in_features)
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

    def _dequantized_weight(self) -> torch.Tensor:
        """Rebuild the (out_features, in_features) base weight from the NF4 codes."""
        dev = self.qweight.device
        indices = unpack_nf4(self.qweight, self._shape)
        return NF4_LEVELS.to(dev)[indices] * self.scales.to(dev)[:, None]

    def quantize_from(self, weight: torch.Tensor, bias: torch.Tensor | None = None):
        """Load ``weight`` (and optional ``bias``) as the frozen NF4 base layer.

        Each output row is scaled by its own absmax and rounded to the nearest
        NF4 level. A None ``bias`` resets the stored bias to zero.

        Raises:
            ValueError: if ``weight`` is not (out_features, in_features).
        """
        if tuple(weight.shape) != self._shape:
            raise ValueError(
                f"expected weight of shape {self._shape}, got {tuple(weight.shape)}"
            )
        w = weight.float().detach()
        codes = []
        row_scales = []
        for row in w:
            idx, s = _quantize_nf4_row(row)
            codes.append(idx)
            row_scales.append(s)
        self.qweight.copy_(pack_nf4(torch.stack(codes).reshape(-1)))
        self.scales.copy_(torch.stack(row_scales))
        if bias is not None:
            self.bias.copy_(bias.float().detach())
            self._has_bias = True
        else:
            # Clear any bias left over from an earlier quantize_from call.
            self.bias.zero_()
            self._has_bias = False

    def forward(self, x):
        # Cast the float32 base weight, bias and LoRA factors to the activation
        # dtype so half/bfloat16 inputs work (F.linear and @ reject mixed dtypes).
        w = self._dequantized_weight().to(x.dtype)
        b = self.bias.to(x.dtype)
        base = F.linear(x, w, b)
        lora = self.dropout(x) @ self.lora_A.to(x.dtype).T @ self.lora_B.to(x.dtype).T
        return base + lora * self.scaling
def apply_qlora(model, target_modules=None, r=8, alpha=16, dropout=0.0, freeze_embeds=True):
    """Replace targeted ``nn.Linear`` layers with ``QLoRALinear`` and freeze the rest.

    Every ``nn.Linear`` whose attribute name (the last component of its module
    path) appears in ``target_modules`` is swapped in place for a
    ``QLoRALinear`` built from its weight and bias. Afterwards only LoRA
    factors stay trainable, plus ``nn.Embedding`` weights (and anything tied to
    them) when ``freeze_embeds`` is False.

    Args:
        model: module to modify in place.
        target_modules: attribute names to replace; defaults to the attention
            projections and the GPP ``down`` / ``up`` layers.
        r: LoRA rank.
        alpha: LoRA scaling numerator.
        dropout: LoRA input dropout.
        freeze_embeds: when False, embedding weights remain trainable.

    Returns:
        ``model`` itself.
    """
    if target_modules is None:
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "down", "up"]

    # Snapshot the targets before swapping anything, rather than mutating the
    # module tree while named_modules() is still walking it.
    to_replace = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and name.split(".")[-1] in target_modules
    ]

    qlora_params = 0
    for name, module in to_replace:
        *parent_path, attr = name.split(".")
        parent = model
        for part in parent_path:
            parent = getattr(parent, part)
        qlora = QLoRALinear(module.in_features, module.out_features, r=r, alpha=alpha, dropout=dropout)
        qlora.quantize_from(module.weight, module.bias)
        qlora = qlora.to(module.weight.device)
        setattr(parent, attr, qlora)
        # lora_A is (r, in_features) and lora_B is (out_features, r).
        qlora_params += r * (module.in_features + module.out_features)

    keep_trainable = set()
    if not freeze_embeds:
        keep_trainable = {
            id(p)
            for m in model.modules()
            if isinstance(m, nn.Embedding)
            for p in m.parameters()
        }
    for name, param in model.named_parameters():
        if "lora_" not in name and id(param) not in keep_trainable:
            param.requires_grad = False

    n = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"QLoRA applied: {qlora_params:,} LoRA params | trainable: {n:,}")
    return model
# ─── Helpers ───────────────────────────────────────────────────────────────
def count_params(model):
    """Return the total number of elements across ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
def create_model(ckpt_path: str | None = "checkpoint.pt"):
    """Build the V3 model via ``restore_from_v2`` and print its parameter count.

    Args:
        ckpt_path: V2 checkpoint to warm-start from. ``restore_from_v2`` falls
            back to random initialisation when it is None or the file is absent.

    Returns:
        The constructed ``TinyModel``.
    """
    model = restore_from_v2(ckpt_path)
    n = count_params(model)
    print(f"TinyModel V3: {n:,} params ({n/1e6:.2f}M)")
    return model
if __name__ == "__main__":
    # Smoke test: build the model, run one forward pass with a loss, then sample.
    model = create_model()
    tokens = torch.randint(0, 100, (1, 16))
    logits, loss = model(tokens, labels=tokens)
    print(f"Forward OK: logits {logits.shape}, loss {loss.item():.4f}")
    gen = model.generate(tokens, max_new_tokens=10)
    print(f"Generate OK: {gen.shape}")