"""
SignGSD — Sign Gradient-Sign Descent optimizer.

A minimal optimizer for low-precision (ternary/binary) training.

Key property: discards all gradient magnitude information; only signs
matter. This aligns with ternary weight domains where weights are
{-1, 0, +1} and updates are direction votes rather than magnitude-scaled
steps.

Memory: zero persistent optimizer state (no momentum buffers). Only what
torch already tracks (params + grad) is kept between steps — 0 bytes/param
of optimizer state vs AdamW's 8 bytes/param (2x float32). Each step still
allocates short-lived temporaries per parameter.
"""

from typing import Iterable, Optional

import torch
from torch.optim import Optimizer


class ScaledOptum(Optimizer):
    """
    Sign Gradient-Sign Descent (SignGSD).

    Update rule (elementwise):

        p <- p - lr * (sign(grad) + wd * sign(p))

    Compared to AdamW:
      - No first/second moment estimates (no exp_avg, exp_avg_sq).
      - No adaptive per-parameter learning rate.
      - Weight decay acts on sign(p), not p itself. Because sign(p) is the
        (sub)gradient of |p|, this is L1 regularization, not L2.
      - Uniform LR across all parameters of a group.

    Why this suits ternary training:
      Continuous updates like p -= lr * grad scale with gradient magnitude.
      SignGSD only votes on direction: each element moves by lr
      (lr * (1 +/- wd) with weight decay) or not at all. This optimizer does
      NOT keep p inside {-1, 0, +1}; it updates p in place with continuous
      steps. Any discretization / flip-at-threshold logic (e.g. a T_accum
      vote counter) has to live outside this class.
    """

    def __init__(self, params, lr=1e-2, weight_decay=0.0):
        """
        Args:
            params: iterable of parameters or dicts defining param groups.
            lr: uniform learning rate (same for every element, no adaptive
                scaling). Must be >= 0.
            weight_decay: sign-based decay coefficient ``wd``. Each step
                subtracts ``lr * wd * sign(p)`` from ``p``: a constant-size
                pull toward zero for every nonzero element, regardless of
                |p|. Combined with the gradient term:
                  * sign(grad) == sign(p): the element moves toward zero by
                    lr * (1 + wd);
                  * sign(grad) == -sign(p): the net step along -sign(grad)
                    is lr * (1 - wd) (exactly 0 when wd == 1).
                Must be >= 0.

        Raises:
            ValueError: if ``lr`` or ``weight_decay`` is negative.
        """
        # Reject negative hyperparameters up front, as torch.optim's built-in
        # optimizers do. Otherwise a negative weight_decay would be silently
        # ignored by the `wd > 0` check in step(), and a negative lr would
        # silently turn descent into ascent.
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform a single optimization step.

        Flow, for every parameter ``p`` that has a gradient:
          1. update = sign(grad), elementwise in {-1, 0, +1}. All gradient
             magnitude is discarded (the core difference from AdamW, which
             uses magnitude via adaptive RMS scaling). A zero gradient
             element casts no vote.
          2. If weight_decay > 0: update += wd * sign(p). Every nonzero
             element receives the same decay magnitude regardless of |p|,
             unlike standard ``wd * p`` decay, which hits large weights harder.
          3. p -= lr * update, in place.

        Parameters whose ``.grad`` is None are skipped. Sparse gradients are
        densified before taking the sign.

        Memory: no optimizer state is created (``self.state`` is never
        written). Sign temporaries are allocated per parameter and released
        after use.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It is run with gradients enabled.

        Returns:
            The loss returned by ``closure``, or None if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            wd = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    grad = grad.to_dense()

                # === Core: sign-sign update ===
                # sign() returns a fresh tensor, so it is safe to accumulate
                # into it in place below rather than allocating more temporaries.
                update = grad.sign()
                if wd > 0:
                    # Sign-based (L1-style) decay. For ternary p in {-1, 0, +1},
                    # sign(p) == p for every value, including 0.
                    update.add_(p.sign().mul_(wd))

                # NOTE(review): earlier comments stated that p holds latent /
                # accumulator values and that the actual ternary flip happens in
                # prepare_ternary_backward / _ternary_update_memory (ARBS
                # training loop). Those are not visible here — confirm. This
                # optimizer itself simply updates p in place.
                p.sub_(update.mul_(lr))

        return loss

    @torch.no_grad()
    def get_memory_mb(self, params: Optional[Iterable[torch.Tensor]] = None) -> float:
        """
        Return the total storage of the given tensors in MiB (1024 * 1024 bytes).

        Sums ``numel() * element_size()`` over the tensors, i.e. only the
        parameter tensors themselves. This optimizer holds no per-parameter
        state to add (unlike AdamW's 8 bytes/param of float32 moment buffers).
        Gradients are not counted.

        Args:
            params: tensors to measure. Defaults to every parameter in this
                optimizer's param groups.

        Returns:
            Total size in MiB.
        """
        if params is None:
            params = (p for group in self.param_groups for p in group["params"])
        total_bytes = sum(p.numel() * p.element_size() for p in params)
        return total_bytes / (1024 * 1024)