Marxist-Leninist
AGILLM4-DiffusionBlocks: block-wise AR+SAT+NAT denoising, fused CE, tied heads
82d098e | """DiffusionBlocks prototype for AGILLM-4 (block-wise denoising training). | |
| Faithful to SakanaAI/DiffusionBlocks: EDM (Karras) diffusion, equi-probability | |
| block sigma partitioning, and the key property -- each step trains ONE block as a | |
| denoiser for its noise range, so backprop touches only that block => training | |
| memory ~ 1/B of end-to-end. Here we reuse a transformer-block stack (stand-in for | |
| AGILLM-4's Encoder.blocks: same pre-norm residual Block API x = blk(x, mask)). | |
| """ | |
| import math, numpy as np, torch, torch.nn as nn, torch.nn.functional as F | |
| def _cdf(x): return 0.5*(1+math.erf(x/math.sqrt(2))) | |
| def _ppf(p): return float(torch.erfinv(torch.tensor(2*p-1.0))*math.sqrt(2)) | |
| # ---- EDM block partitioning (verbatim mechanism from dblock_modules.py) ---- | |
def get_block_sigmas(num_blocks, sigma_min=0.002, sigma_max=80.0, p_mean=-1.2, p_std=1.2):
    """Split the sigma range into ``num_blocks`` equal-probability intervals.

    ``log(sigma)`` is treated as normally distributed with mean ``p_mean`` and
    standard deviation ``p_std``. The boundaries are placed so that every
    interval carries the same CDF mass of that distribution, restricted to the
    range between ``sigma_min`` and ``sigma_max``.

    Args:
        num_blocks: Number of intervals; must be at least 1.
        sigma_min: First boundary; must be positive.
        sigma_max: Last boundary; must be positive.
        p_mean: Mean of ``log(sigma)``.
        p_std: Standard deviation of ``log(sigma)``.

    Returns:
        A list of ``num_blocks + 1`` floats running from ``sigma_min`` to
        ``sigma_max`` (ascending when ``sigma_min < sigma_max``). Interval
        ``i`` spans ``out[i]`` to ``out[i + 1]``.

    Raises:
        ValueError: If ``num_blocks < 1`` or either sigma bound is not positive.
    """
    if num_blocks < 1:
        raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError(
            f"sigma_min and sigma_max must be positive, got {sigma_min} and {sigma_max}"
        )

    cdf_min = _cdf((np.log(sigma_min) - p_mean) / p_std)
    cdf_max = _cdf((np.log(sigma_max) - p_mean) / p_std)
    span = cdf_max - cdf_min

    # Only the interior boundaries need the CDF -> quantile round trip.
    interior = [
        float(np.exp(p_mean + p_std * _ppf(cdf_min + span * (i / num_blocks))))
        for i in range(1, num_blocks)
    ]
    # Pin the end points to the requested values: the round trip is inexact
    # and can even saturate to 0/inf in the extreme tails.
    return [float(sigma_min), *interior, float(sigma_max)]
class DBlockLM(nn.Module):
    """Transformer layer stack split into independently trainable denoiser blocks.

    The module owns a token embedding, ``n_layers`` pre-norm
    ``nn.TransformerEncoderLayer`` layers, and a final LayerNorm + linear
    output projection. The layers are partitioned into ``num_blocks``
    contiguous groups; ``run_block`` executes exactly one group, wrapped in
    EDM (Karras et al., 2022) preconditioning.

    Attributes set in ``__init__``:
        emb: Token embedding, ``vocab`` x ``d``.
        blocks: ``nn.ModuleList`` holding all ``n_layers`` layers.
        ln, out: Final LayerNorm and vocabulary projection. They are created
            here but are not used by ``run_block``.
        num_blocks: Number of layer groups.
        assign: ``assign[b]`` is the list of layer indices in block ``b``.
        block_sigmas: ``num_blocks + 1`` sigma boundaries from
            ``get_block_sigmas``.
        sigma_data: Data standard deviation assumed by the preconditioning.
    """

    def __init__(self, vocab=2048, d=256, n_layers=8, heads=4, num_blocks=4):
        """Build the layers and assign them to blocks.

        Raises:
            ValueError: If ``num_blocks < 1`` or ``num_blocks > n_layers``.
        """
        super().__init__()
        # Validate and partition first so that a bad configuration fails
        # before any parameters are allocated.
        assign = self._partition_layers(n_layers, num_blocks)
        self.emb = nn.Embedding(vocab, d)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d, heads, 4 * d, batch_first=True, norm_first=True)
            for _ in range(n_layers)
        ])
        self.ln = nn.LayerNorm(d)
        self.out = nn.Linear(d, vocab)
        self.num_blocks = num_blocks
        self.assign = assign
        self.block_sigmas = get_block_sigmas(num_blocks)
        self.sigma_data = 0.5

    @staticmethod
    def _partition_layers(n_layers, num_blocks):
        """Split layer indices ``0..n_layers-1`` into contiguous, balanced groups.

        Every layer lands in exactly one group. When ``n_layers`` is not a
        multiple of ``num_blocks`` the first ``n_layers % num_blocks`` groups
        get one extra layer, so no layer is left unassigned (and untrained).

        Raises:
            ValueError: If ``num_blocks < 1`` or ``num_blocks > n_layers``;
                the latter would create blocks with no trainable layers.
        """
        if num_blocks < 1:
            raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")
        if n_layers < num_blocks:
            raise ValueError(
                f"need at least one layer per block: n_layers={n_layers}, "
                f"num_blocks={num_blocks}"
            )
        base, extra = divmod(n_layers, num_blocks)
        groups = []
        start = 0
        for i in range(num_blocks):
            size = base + (1 if i < extra else 0)
            groups.append(list(range(start, start + size)))
            start += size
        return groups

    def run_block(self, b, zt, sigma):
        """Denoise ``zt`` with the layers of block ``b`` only.

        Args:
            b: Index into ``self.assign`` selecting the block to run.
            zt: Noisy embeddings. The layers are built with
                ``batch_first=True``, so the batch axis comes first.
            sigma: 1-D tensor with one noise level per batch row; it is
                broadcast over the remaining two axes of ``zt``.

        Returns:
            ``c_skip * zt + c_out * F(c_in * zt)``, the preconditioned
            denoiser output that is trained to predict the clean embedding.
        """
        # Keep sigma on the same device as the activations so callers may
        # sample it on the CPU.
        s = sigma.to(zt.device)[:, None, None]
        c_skip = self.sigma_data**2 / (s**2 + self.sigma_data**2)
        c_out = s * self.sigma_data / (s**2 + self.sigma_data**2) ** 0.5
        c_in = 1 / (s**2 + self.sigma_data**2) ** 0.5
        h = zt * c_in
        # Only this block's layers run, so autograd reaches no other block.
        # NOTE(review): the layers receive no explicit sigma input; noise-level
        # information reaches them only through the c_in scaling of zt.
        for li in self.assign[b]:
            h = self.blocks[li](h)
        return c_skip * zt + c_out * h
def edm_weight(sigma, sigma_data=0.5):
    """Return the EDM loss weight ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``.

    Works element-wise on tensors as well as on plain scalars.
    """
    numerator = sigma**2 + sigma_data**2
    denominator = (sigma * sigma_data) ** 2
    return numerator / denominator
def train_step(model, opt, ids):
    """Run one optimisation step that trains a single randomly chosen block.

    Args:
        model: Object exposing ``emb``, ``num_blocks``, ``block_sigmas`` and
            ``run_block(b, zt, sigma)`` (e.g. ``DBlockLM``).
        opt: Optimizer over the model's parameters.
        ids: Integer token ids; the first axis is the batch.

    Returns:
        ``(b, loss)``: the index of the block that was trained and the scalar
        loss value as a Python float.
    """
    # Clean targets: L2-normalised token embeddings. They are detached, so
    # this loss sends no gradient into the embedding table.
    z0 = F.normalize(model.emb(ids), p=2, dim=-1).detach()

    # Pick one block and its sigma interval; sort so that lo <= hi whatever
    # order the boundaries are stored in.
    b = int(np.random.randint(model.num_blocks))
    lo, hi = sorted((model.block_sigmas[b], model.block_sigmas[b + 1]))

    # One noise level per sequence, log-uniform inside the block's interval.
    log_sigma = np.random.uniform(np.log(lo), np.log(hi), size=ids.shape[0])
    sigma = torch.from_numpy(np.exp(log_sigma).astype('float32'))
    # from_numpy always yields a CPU tensor; move it next to the activations
    # so the step also works when the model lives on an accelerator.
    sigma = sigma.to(z0.device)

    zt = z0 + sigma[:, None, None] * torch.randn_like(z0)
    D = model.run_block(b, zt, sigma)  # backprop only reaches block b
    loss = (edm_weight(sigma)[:, None, None] * (D - z0) ** 2).mean()

    # set_to_none=True resets every other parameter's grad to None before the
    # backward pass, so only block b carries gradients into opt.step().
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()
    return b, loss.item()
def main():
    """Smoke test: train one block and check that only its layers got gradients."""
    torch.manual_seed(0)
    # train_step draws the block index and the sigmas from NumPy's global RNG,
    # which torch.manual_seed does not touch; seed it too for repeatable runs.
    np.random.seed(0)

    NB = 4
    m = DBlockLM(num_blocks=NB)
    opt = torch.optim.AdamW(m.parameters(), 1e-3)
    ids = torch.randint(0, 2048, (8, 64))
    print("layers:",len(m.blocks),"| blocks:",NB,"| layer assignment:",m.assign)
    print("block sigma boundaries:",[round(s,3) for s in m.block_sigmas])

    # One step, then verify ONLY the chosen block's layer params got gradients.
    b, loss = train_step(m, opt, ids)

    def has_grad(mod):
        return any(p.grad is not None and p.grad.abs().sum() > 0 for p in mod.parameters())

    grad_blocks = [i for i in range(NB) if any(has_grad(m.blocks[li]) for li in m.assign[i])]
    blk_params = sum(p.numel() for li in m.assign[b] for p in m.blocks[li].parameters())
    layer_total = sum(p.numel() for blk in m.blocks for p in blk.parameters())

    print(f"\nstep trained block {b} loss={loss:.4f}")
    print("blocks whose layers received gradients:",grad_blocks,"(expect just",[b],")")
    print(f"layer-params with grad this step: {blk_params:,} / {layer_total:,}"
          f" ({100*blk_params/layer_total:.0f}% = ~1/{NB})")
    print("\n--- a few more steps (each trains one block independently) ---")
    for _ in range(6):
        b, loss = train_step(m, opt, ids)
        print(f" block {b}: loss={loss:.4f}")


if __name__ == "__main__":
    main()