| |
| """Latency proxy models and a tiny LUT for hardware correction. |
| |
| This file defines a family-agnostic interface plus concrete proxies (ViT, ResNet, LLM) |
| that estimate latency from *soft structure* (gates) and input size. All proxies accept |
| the trainer's `(model, batch) -> ms` call signature directly (batches may be dict/tuple/tensor). |
| A small, in-memory LUT can be populated from real measurements during training to correct |
| analytic estimates. |
| """ |
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import Any, Dict, Optional, Tuple, Union, List |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from .gates import iter_gates, _as_like |
|
|
|
|
| |
| |
| |
|
|
# Anything the trainer may hand us as a "batch": a bare tensor, a (possibly
# nested) tuple/list, or a dict of named fields.
TensorOrBatch = Union[torch.Tensor, Tuple, List, Dict[str, Any]]
|
|
| def _first_tensor(batch: TensorOrBatch) -> torch.Tensor: |
| """Find the first tensor inside a batch-like structure.""" |
| if torch.is_tensor(batch): |
| return batch |
| if isinstance(batch, dict): |
| |
| for k in ("input_ids", "pixel_values", "images", "x"): |
| v = batch.get(k, None) |
| if torch.is_tensor(v): |
| return v |
| |
| for v in batch.values(): |
| if torch.is_tensor(v): |
| return v |
| raise ValueError("Batch dict has no tensor field I recognize.") |
| if isinstance(batch, (list, tuple)): |
| for v in batch: |
| if torch.is_tensor(v): |
| return v |
| |
| if len(batch) and isinstance(batch[0], (list, tuple)): |
| for v in batch[0]: |
| if torch.is_tensor(v): |
| return v |
| raise ValueError("Cannot find a tensor in the provided batch.") |
|
|
| def _ids_from_batch(batch: TensorOrBatch) -> torch.Tensor: |
| """Return a 2D [B,S] tensor representing token ids for LLMs.""" |
| if isinstance(batch, dict) and "input_ids" in batch and torch.is_tensor(batch["input_ids"]): |
| return batch["input_ids"] |
| t = _first_tensor(batch) |
| if t.dim() >= 2: |
| return t |
| raise ValueError("Cannot infer [B,S] from batch; need 'input_ids' or a 2D tensor.") |
|
|
| def _nchw_from_batch(batch: TensorOrBatch) -> Tuple[int, int, int, int]: |
| """Return NCHW shape from a batch or an explicit (N,C,H,W) tuple/list/tensor.""" |
| if isinstance(batch, (tuple, list)) and len(batch) == 4 and all(isinstance(x, int) for x in batch): |
| return tuple(batch) |
| x = _first_tensor(batch) |
| if x.dim() != 4: |
| raise ValueError(f"Expected NCHW tensor for CNN proxy; got tensor with shape {tuple(x.shape)}") |
| N, C, H, W = map(int, x.shape) |
| return (N, C, H, W) |
|
|
|
|
| |
| |
| |
|
|
class LatencyProxy(nn.Module):
    """Base class for differentiable latency estimators (ms-like scalars).

    Concrete proxies implement `_predict_raw`; `signature` yields a hashable
    workload key that a LUT can use to swap in real measurements. Samples may
    be batch-like objects (dict/tuple/tensor) or explicit shape tuples.
    """

    def __init__(self):
        super().__init__()

    def predict(
        self,
        model: nn.Module,
        sample: TensorOrBatch,
        *,
        policy=None,
        step: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Batch-friendly entry point; delegates straight to `_predict_raw`."""
        return self._predict_raw(model, sample, policy=policy, step=step, **kwargs)

    def _predict_raw(
        self,
        model: nn.Module,
        sample: TensorOrBatch,
        *,
        policy=None,
        step: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Subclass hook computing the actual latency-like estimate."""
        raise NotImplementedError

    def signature(
        self,
        model: nn.Module,
        sample: TensorOrBatch,
        *,
        policy=None,
        step: Optional[int] = None
    ) -> Tuple:
        """Return a hashable signature describing the workload shape."""
        if torch.is_tensor(sample):
            shape_key = tuple(sample.shape)
        elif isinstance(sample, dict):
            # Only tensor-valued fields contribute to the key.
            shape_key = tuple(
                (name, tuple(value.shape))
                for name, value in sample.items()
                if torch.is_tensor(value)
            )
        elif isinstance(sample, (tuple, list)):
            shape_key = tuple(sample)
        else:
            shape_key = (str(type(sample)),)
        return (type(self).__name__, shape_key)
|
|
|
|
class LatencyLUT:
    """Tiny in-memory lookup table mapping workload signatures to measured ms."""

    def __init__(self):
        # signature -> measured latency (ms)
        self._table: Dict[Tuple[Any, ...], float] = {}

    def update(self, signature: Tuple[Any, ...], measured_ms: float) -> None:
        """Record (or overwrite) the measurement for `signature`."""
        self._table[signature] = float(measured_ms)

    def get(self, signature: Tuple[Any, ...]) -> Optional[float]:
        """Return the stored measurement, or None when unseen."""
        return self._table.get(signature)

    def blend(self, raw_estimate: torch.Tensor, signature: Tuple[Any, ...]) -> torch.Tensor:
        """Replace the analytic estimate with a stored measurement when available."""
        measured = self.get(signature)
        if measured is None:
            return raw_estimate
        # Lift the scalar onto the estimate's device/dtype so callers see a tensor.
        return _as_like(raw_estimate, measured)
|
|
|
|
| |
| |
| |
|
|
@dataclass
class ViTProxyConfig:
    """Weights/scale for the ViT latency proxy's FLOPs-style cost model."""

    # Global multiplier converting the FLOPs-like sum into ms
    # (set by ViTLatencyProxy.calibrate).
    scale_ms: float = 1.0
    # Relative weights of the four per-block cost terms.
    alpha_qkv: float = 1.0
    alpha_scores: float = 1.0
    alpha_out: float = 1.0
    alpha_mlp: float = 1.0
|
|
| def _vit_layers(m): |
| enc = getattr(m, "encoder", None) |
| if enc is not None and hasattr(enc, "layer"): |
| return enc.layer |
| vit = getattr(m, "vit", None) |
| if vit is not None and hasattr(vit, "encoder") and hasattr(vit.encoder, "layer"): |
| return vit.encoder.layer |
| raise TypeError("Expected a HF ViT with *.encoder.layer (ViTModel or ViTForImageClassification).") |
|
|
|
|
class ViTLatencyProxy(LatencyProxy):
    """Latency proxy for ViT models. Accepts batches or (N,C,H,W) tuples.

    Estimates a differentiable, ms-like scalar as a weighted FLOPs-style sum
    over encoder blocks, using the gates' *expected* kept attention heads and
    FFN hidden units. A LatencyLUT entry for the exact workload signature,
    when present, overrides the analytic estimate (see `_predict_raw`).
    """

    def __init__(self, cfg: Optional[ViTProxyConfig] = None, lut: Optional[LatencyLUT] = None):
        super().__init__()
        self.cfg = cfg or ViTProxyConfig()
        self.lut = lut or LatencyLUT()

    # ---- input/shape helpers ------------------------------------------------

    @staticmethod
    def _input_spec(sample: TensorOrBatch) -> Tuple[int, int, int]:
        """Return (batch, image_height, image_width) from a batch or 4-int shape."""
        # Explicit (B, C, H, W) tuple of plain ints.
        if isinstance(sample, (tuple, list)) and len(sample) == 4 and all(isinstance(x, int) for x in sample):
            B, C, H, W = sample
            return int(B), int(H), int(W)
        x = _first_tensor(sample)
        if x.dim() != 4:
            raise ValueError("ViTLatencyProxy expects a tensor [B,3,H,W] or a 4-tuple (B,3,H,W)")
        B, C, H, W = x.shape
        return int(B), int(H), int(W)

    @staticmethod
    def _patch_hw(cfg) -> Tuple[int, int]:
        """Patch (height, width) from the HF config; defaults to 16x16."""
        patch = getattr(cfg, "patch_size", 16)
        if isinstance(patch, (tuple, list)):
            return int(patch[0]), int(patch[1])
        return int(patch), int(patch)

    # ---- gate discovery -----------------------------------------------------

    @staticmethod
    def _soft_heads_from_block(blk) -> Optional[torch.Tensor]:
        """Expected kept heads for one block, or None when no head gate exists."""
        # HF ViT nests the attention module at blk.attention.attention.
        attn = getattr(getattr(blk, "attention", None), "attention", None)
        if attn is not None and hasattr(attn, "kept_heads_soft"):
            return attn.kept_heads_soft()
        return None

    @staticmethod
    def _find_ffn_gate(blk):
        """Locate an FFN width gate (anything exposing .logits and .tau)."""
        inter = getattr(blk, "intermediate", None)
        if inter is None:
            return None
        # Common attribute names first.
        for nm in ("neuron_gate", "gate", "ffn_gate"):
            g = getattr(inter, nm, None)
            if g is not None and hasattr(g, "logits") and hasattr(g, "tau"):
                return g
        # Fallback: scan the whole block for any gate-shaped module.
        for m in blk.modules():
            if hasattr(m, "logits") and hasattr(m, "tau"):
                return m
        return None

    # ---- prediction ---------------------------------------------------------

    def _predict_raw(
        self,
        model: nn.Module,
        sample: TensorOrBatch,
        *,
        policy=None,
        step: Optional[int] = None
    ) -> torch.Tensor:
        """Differentiable ms-like estimate; LUT hit overrides the analytic sum."""
        # Any parameter serves as a device/dtype anchor for constants.
        anchor = next((p for p in model.parameters()), torch.tensor(0.0))

        B, H_img, W_img = self._input_spec(sample)
        cfg = getattr(model, "config", None)
        if cfg is None:
            raise ValueError("Model must expose a HuggingFace-like .config for ViT proxy")
        ph, pw = self._patch_hw(cfg)

        # Sequence length: one CLS token plus all image patches.
        S = _as_like(anchor, 1 + (H_img // ph) * (W_img // pw))
        D = _as_like(anchor, int(getattr(cfg, "hidden_size", 768)))
        Hh = _as_like(anchor, int(getattr(cfg, "num_attention_heads", 12)))
        Dh = D // Hh  # per-head dim

        # During warmup, cost the full (keep-all) model and ignore gates.
        warm = False
        if policy is not None and step is not None:
            warm = (step < int(getattr(policy, "warmup_steps", 0)))

        total_qkv = _as_like(anchor, 0.0)
        total_scores = _as_like(anchor, 0.0)
        total_out = _as_like(anchor, 0.0)
        total_mlp = _as_like(anchor, 0.0)

        default_hidden = _as_like(anchor, int(getattr(cfg, "intermediate_size", 4 * int(D))))

        layers = _vit_layers(model)
        for blk in layers:
            # `or` falls back to full heads when no gate is found
            # (NOTE(review): also when the soft head count is exactly 0 — confirm).
            heads_soft = Hh if warm else (self._soft_heads_from_block(blk) or Hh)

            # Expected FFN hidden width from the gate, else the config default.
            if warm:
                hidden_soft = default_hidden
            else:
                g = self._find_ffn_gate(blk)
                if g is None:
                    hidden_soft = default_hidden
                else:
                    # Expected kept units = sum of keep-probabilities * group size.
                    probs = torch.sigmoid(g.logits / g.tau)
                    group = int(getattr(g, "group", getattr(g, "group_size", 16)))
                    hidden_soft = probs.sum() * _as_like(anchor, group)

            D_kept = heads_soft * Dh

            # FLOPs-style per-block terms. NOTE(review): batch size B is parsed
            # above but not folded into these costs — presumably absorbed into
            # scale_ms at calibration time; confirm.
            total_qkv += 3 * S * D * D_kept
            total_scores += (S * S) * heads_soft * Dh
            total_out += S * D_kept * D
            total_mlp += 2 * S * D * hidden_soft

        raw = (
            self.cfg.alpha_qkv * total_qkv
            + self.cfg.alpha_scores * total_scores
            + self.cfg.alpha_out * total_out
            + self.cfg.alpha_mlp * total_mlp
        )
        raw_ms = raw * _as_like(anchor, float(self.cfg.scale_ms))

        # Prefer a real measurement for this exact workload when available.
        sig = self.signature(model, sample, policy=policy, step=step)
        return self.lut.blend(raw_ms, sig)

    def signature(self, model: nn.Module, sample, *, policy=None, step: Optional[int] = None) -> Tuple:
        """Hashable workload key: sample shape plus the model's width settings."""
        if torch.is_tensor(sample):
            shp = tuple(sample.shape)
        elif isinstance(sample, (tuple, list)):
            shp = tuple(sample)
        elif isinstance(sample, dict):
            shp = tuple((k, tuple(v.shape)) for k, v in sample.items() if torch.is_tensor(v))
        else:
            shp = (str(type(sample)),)
        cfg = getattr(model, "config", None)
        heads = int(getattr(cfg, "num_attention_heads", 12))
        hidden = int(getattr(cfg, "hidden_size", 768))
        inter = int(getattr(cfg, "intermediate_size", 3072))
        return ("ViT", shp, heads, hidden, inter)

    @torch.no_grad()
    def calibrate(self, model: nn.Module, shape: tuple, measure_fn, *, device: str = "cuda") -> float:
        """Set proxy scale so that keep-all student matches measured ms.

        `measure_fn(model, shape, device=...)` should return `(mean_ms, p95_ms)`.
        Returns the new `cfg.scale_ms`.
        """
        sample_t = torch.randn(shape, device=device)
        # Redundant move kept for safety if `device` strings differ in form.
        sample_t = sample_t.to(device)
        model = model.to(device).eval()
        mean_ms, _ = measure_fn(model, shape, device=device)
        soft_ms = self.predict(model, sample_t).item()
        # NOTE(review): `soft_ms` already includes the current scale_ms, so
        # calibrating twice compounds the previous scale — confirm callers
        # only calibrate once from scale_ms == 1.0.
        self.cfg.scale_ms = float(mean_ms / max(soft_ms, 1e-9))
        return self.cfg.scale_ms
|
|
| |
|
|
@dataclass
class ResNetProxyConfig:
    """Weights/scale for the ResNet latency proxy's conv-FLOPs cost model."""

    # Global multiplier converting the FLOPs-like sum into ms
    # (set by ResNetLatencyProxy.calibrate).
    scale_ms: float = 1.0
    # Weight applied to every conv's FLOPs term.
    alpha_conv: float = 1.0
|
|
|
|
| def _as_const_like_resnet(x_like: torch.Tensor, val): |
| return torch.as_tensor(val, device=x_like.device, dtype=x_like.dtype) |
|
|
|
|
| def _find_anchor_param(model: nn.Module) -> torch.Tensor: |
| |
| for m in model.modules(): |
| for nm in ("logits", "head_gate"): |
| t = getattr(m, nm, None) |
| if isinstance(t, torch.Tensor): |
| return t |
| for p in model.parameters(): |
| return p |
| return torch.tensor(0.0) |
|
|
|
|
| def _kept_from_gate(module, anchor: torch.Tensor) -> Optional[torch.Tensor]: |
| """Return expected kept channels for a BN gate: probs.sum() * group_size. |
| If no gate is found, return None. |
| """ |
| g = None |
| for nm in ("gate", "neuron_gate", "channel_gate", "bn_gate"): |
| if hasattr(module, nm): |
| g = getattr(module, nm) |
| break |
| if g is None and hasattr(module, "logits") and hasattr(module, "tau"): |
| g = module |
|
|
| if g is None or not hasattr(g, "logits"): |
| return None |
| logits = g.logits |
| tau = float(getattr(g, "tau", 1.5)) |
| group = int(getattr(g, "group", getattr(g, "group_size", 1))) |
| if group <= 0: group = 1 |
| probs = torch.sigmoid(logits / tau) |
| return probs.sum() * _as_const_like_resnet(anchor, group) |
|
|
|
|
class ResNetLatencyProxy(LatencyProxy):
    """Latency proxy for ResNet-like backbones with BN gates.

    Approximates latency with a FLOPs-style sum over convs, using the *expected*
    kept channels after each BN gate (probs.sum()*group_size). Falls back to the
    full channel count when a gate is not found.

    Accepts a batch or an explicit (N,C,H,W) shape.
    """

    def __init__(self, cfg: Optional[ResNetProxyConfig] = None):
        super().__init__()
        self.cfg = cfg or ResNetProxyConfig()

    def _add_cost(self, cost_like: torch.Tensor, oc, ic, k, stride, H, W):
        """Accumulate one conv's weighted FLOPs; returns (new_cost, H_out, W_out)."""
        alpha = _as_const_like_resnet(cost_like, self.cfg.alpha_conv)
        # Ceil-divide the spatial dims by the stride.
        # NOTE(review): this assumes "same"-style padding for every conv — confirm.
        H = (H + stride - 1) // stride
        W = (W + stride - 1) // stride
        flops = _as_const_like_resnet(cost_like, oc) * _as_const_like_resnet(cost_like, ic) * (k * k) * _as_const_like_resnet(cost_like, H) * _as_const_like_resnet(cost_like, W)
        return cost_like + alpha * flops, H, W

    def _predict_raw(self, model: nn.Module, sample: TensorOrBatch, **_) -> torch.Tensor:
        """Differentiable cost over the stem conv plus layer1..layer4 blocks."""
        N, C_in, H0, W0 = _nchw_from_batch(sample)
        anchor = _find_anchor_param(model)
        cost = _as_const_like_resnet(anchor, 0.0)
        H = _as_const_like_resnet(anchor, int(H0))
        W = _as_const_like_resnet(anchor, int(W0))

        # Stem conv (+ optionally gated bn1); fall back to full out_channels.
        conv1 = getattr(model, "conv1")
        bn1 = getattr(model, "bn1", None)
        k = conv1.kernel_size[0]
        s = conv1.stride[0]
        kept_out = None
        if bn1 is not None:
            kept = _kept_from_gate(bn1, anchor)
            if kept is not None:
                kept_out = kept
        oc_eff = kept_out if kept_out is not None else _as_const_like_resnet(anchor, conv1.out_channels)
        cost, H, W = self._add_cost(cost, oc_eff, _as_const_like_resnet(anchor, C_in), k, s, H, W)
        in_ch = oc_eff

        def _block_cost(block, in_ch, H, W, cost):
            """Cost of a BasicBlock's two convs; returns (out_ch, H, W, cost)."""
            c1 = block.conv1
            b1 = block.bn1 if hasattr(block, "bn1") else None
            k1, s1 = c1.kernel_size[0], c1.stride[0]
            # `or` falls back to full channels when no gate is found
            # (NOTE(review): also when the soft count is exactly 0 — confirm).
            oc1_eff = _kept_from_gate(b1, anchor) or _as_const_like_resnet(anchor, c1.out_channels)
            cost, H, W = self._add_cost(cost, oc1_eff, in_ch, k1, s1, H, W)

            c2 = block.conv2
            b2 = block.bn2 if hasattr(block, "bn2") else None
            k2, s2 = c2.kernel_size[0], c2.stride[0]
            oc2_eff = _kept_from_gate(b2, anchor) or _as_const_like_resnet(anchor, c2.out_channels)
            cost, H, W = self._add_cost(cost, oc2_eff, oc1_eff, k2, s2, H, W)

            # NOTE(review): downsample/shortcut convs are not costed here.
            return oc2_eff, H, W, cost

        for lname in ("layer1", "layer2", "layer3", "layer4"):
            layer = getattr(model, lname, None)
            if layer is None:
                continue
            for blk in layer:
                in_ch, H, W, cost = _block_cost(blk, in_ch, H, W, cost)

        scale = _as_const_like_resnet(anchor, self.cfg.scale_ms)
        return cost * scale

    @torch.no_grad()
    def calibrate(self, model: nn.Module, keepall_export_fn, profiler_fn, sample: TensorOrBatch, device: str = "cuda") -> float:
        """Calibrate `scale_ms` so proxy(model_keepall) ~= real latency in ms.

        `keepall_export_fn(model)` must return a keep-all copy of the model;
        `profiler_fn(model, nchw_shape, device=...)` must return (mean_ms, p95_ms).
        Returns the measured mean latency (ms), not the new scale.
        """
        keep = keepall_export_fn(model)
        sample_shape = _nchw_from_batch(sample)
        mean_ms, _ = profiler_fn(keep, sample_shape, device=device)
        soft = float(self.predict(model, sample).detach().cpu())
        self.cfg.scale_ms = mean_ms / max(soft, 1e-9)
        return mean_ms
|
|
|
|
| |
| |
| |
|
|
| """ |
| LatencyProxyLLM |
| --------------- |
| A lightweight latency proxy for decoder-only HF LLMs (LLaMA/Mistral style). |
| |
| - Estimates end-to-end latency (ms-like scalar) for a given (B, S, T): |
| * Prefill on S tokens (build KV cache) |
| * Cached decode for T steps |
| - Uses soft gate expectations: |
| * Attention heads (HeadGate on GatedSelfAttentionLLM) |
| * FFN hidden (SwiGLUWidthGate via .mlp.neuron_gate) |
| - Calibrate .scale_ms so proxy ≈ real latency of a keep-all model. |
| |
| Public API |
| ---------- |
| - LatencyProxyLLM(...).predict(model, batch_or_shape) # trainer entry |
| - LatencyProxyLLM(...).predict(model, B=?, S=?, T=?) # explicit entry |
| - LatencyProxyLLM(...).debug_layer_view(...) |
| - calibrate_proxy_llm(...), calibrate_proxy_llm_from_batch(...) |
| """ |
|
|
| |
| |
| |
| def _find_gate_param_or_fallback(model: nn.Module) -> torch.Tensor: |
| """ |
| Return a tensor to anchor device/dtype for proxy constants. |
| Prefer gate logits; else any parameter; else CPU fp32 scalar. |
| """ |
| for m in model.modules(): |
| if hasattr(m, "head_gate") and hasattr(getattr(m, "head_gate"), "logits"): |
| return m.head_gate.logits |
| if hasattr(m, "neuron_gate") and hasattr(m.neuron_gate, "logits"): |
| return m.neuron_gate.logits |
| if hasattr(m, "logits") and isinstance(getattr(m, "logits"), torch.Tensor): |
| return m.logits |
| for p in model.parameters(): |
| return p |
| return torch.tensor(0.0) |
|
|
| def _as_const_like(x_like: torch.Tensor, val): |
| return torch.as_tensor(val, device=x_like.device, dtype=x_like.dtype) |
|
|
|
|
| |
| |
| |
@dataclass
class _WarmupOnlyPolicy:
    """Tiny policy shim so you can pass warmup_steps to .predict()."""

    # Steps during which gates are ignored (keep-all costing); 0 disables warmup.
    warmup_steps: int = 0
|
|
class LatencyProxyLLM(LatencyProxy):
    """
    LLM latency proxy (ms ~ weighted FLOPs/bandwidth terms) for prefill + cached decode.
    Accepts either a batch or explicit B,S,T.

    The estimate covers a prefill over S tokens plus T cached decode steps,
    using soft gate expectations for attention heads and FFN hidden width.
    `scale_ms` converts the FLOPs-like sum into milliseconds and is set by
    `calibrate_proxy_llm(...)`.
    """

    def __init__(
        self,
        *,
        scale_ms: float = 1.0,
        alpha_qkv: float = 1.0,
        alpha_scores: float = 1.0,
        alpha_out: float = 1.0,
        alpha_mlp: float = 1.0,
        gate_kv_in_proxy: bool = False,
        default_T: int = 128,
    ):
        # scale_ms: FLOPs->ms multiplier; alpha_*: relative term weights;
        # gate_kv_in_proxy: also shrink K/V projections with the head gate;
        # default_T: decode steps assumed when T is not supplied.
        super().__init__()
        self.scale_ms = float(scale_ms)
        self.alpha_qkv = float(alpha_qkv)
        self.alpha_scores = float(alpha_scores)
        self.alpha_out = float(alpha_out)
        self.alpha_mlp = float(alpha_mlp)
        self.gate_kv_in_proxy = bool(gate_kv_in_proxy)
        self.default_T = int(default_T)

    # ---- gate discovery -----------------------------------------------------

    @staticmethod
    def _soft_heads_from_block_llm(blk) -> Optional[torch.Tensor]:
        """Expected kept heads for one decoder block, or None without a head gate."""
        attn = getattr(blk, "self_attn", None)
        if attn is None:
            return None
        # Preferred: the attention module reports its own soft head count.
        if hasattr(attn, "kept_heads_soft") and callable(attn.kept_heads_soft):
            return attn.kept_heads_soft()
        # Fallback: read gate logits/tau directly.
        logits, tau = None, None
        if hasattr(attn, "head_gate") and hasattr(attn.head_gate, "logits"):
            logits = attn.head_gate.logits
            tau = float(getattr(attn.head_gate, "tau", getattr(attn, "tau", 1.5)))
        elif hasattr(attn, "logits"):
            logits = attn.logits
            tau = float(getattr(attn, "tau", 1.5))
        if logits is None:
            return None
        return torch.sigmoid(logits / tau).sum()

    @staticmethod
    def _find_ffn_gate_llm(blk):
        """Return blk.mlp.neuron_gate when it looks like a gate, else None."""
        mlp = getattr(blk, "mlp", None)
        g = getattr(mlp, "neuron_gate", None) if mlp is not None else None
        if g is not None and hasattr(g, "logits") and hasattr(g, "tau"):
            return g
        return None

    def _soft_hidden_from_block_llm(self, blk, default_hidden, anchor, warm=False):
        """Expected kept FFN hidden units; falls back to `default_hidden`."""
        if warm:
            return default_hidden
        g = self._find_ffn_gate_llm(blk)
        if g is None:
            return default_hidden
        # Expected kept units = sum of keep-probabilities * group size.
        probs = torch.sigmoid(g.logits / float(g.tau))
        group = int(getattr(g, "group", getattr(g, "group_size", 128)))
        kept_hidden = probs.sum() * _as_const_like(anchor, group)
        return kept_hidden

    # ---- prediction ---------------------------------------------------------

    def predict(
        self,
        model: nn.Module,
        sample: Optional[TensorOrBatch] = None,
        *,
        B: Optional[int] = None,
        S: Optional[int] = None,
        T: Optional[int] = None,
        policy: Optional[object] = None,
        step: Optional[int] = None,
        return_terms: bool = False,
    ):
        """Estimate ms-like latency for a (B,S) prefill plus T decode steps.

        `sample` may be a batch dict/tensor, an int tuple (B,S) or (B,S,T),
        or omitted entirely when B and S are given explicitly. When
        `return_terms` is True, also returns a dict of the four cost terms.
        """
        # Resolve (B, S, T) from explicit args, an int tuple, or a batch.
        if B is not None and S is not None:
            ids_B, ids_S = int(B), int(S)
            ids_T = int(T) if T is not None else int(self.default_T)
        else:
            if sample is None:
                raise ValueError("LatencyProxyLLM.predict needs either a batch sample or explicit B,S.")
            if isinstance(sample, (tuple, list)) and len(sample) in (2, 3) and all(isinstance(x, int) for x in sample):
                # Explicit (B, S) or (B, S, T). NOTE(review): a 3-tuple's T
                # takes precedence over the T keyword — confirm intended.
                ids_B, ids_S = int(sample[0]), int(sample[1])
                ids_T = int(sample[2]) if len(sample) == 3 else int(self.default_T)
            else:
                ids = _ids_from_batch(sample)
                ids_B, ids_S = int(ids.size(0)), int(ids.size(1))
                ids_T = int(self.default_T) if T is None else int(T)

        # Anchor tensor fixes device/dtype for all constants below.
        anchor = _find_gate_param_or_fallback(model)

        B_t = _as_const_like(anchor, ids_B)
        S_t = _as_const_like(anchor, ids_S)
        T_t = _as_const_like(anchor, ids_T)

        cfg = model.config
        D = _as_const_like(anchor, int(cfg.hidden_size))
        Hh = _as_const_like(anchor, int(cfg.num_attention_heads))
        Hkv = _as_const_like(anchor, int(getattr(cfg, "num_key_value_heads", int(Hh))))
        Dh = D // Hh  # per-head dim

        # Keep-all costing during warmup.
        warmup_steps = int(getattr(policy, "warmup_steps", 0)) if policy is not None else 0
        warm = bool(step is not None and step < warmup_steps)

        total_qkv = anchor.new_zeros(())
        total_scores = anchor.new_zeros(())
        total_out = anchor.new_zeros(())
        total_mlp = anchor.new_zeros(())

        default_hidden = _as_const_like(anchor, int(getattr(cfg, "intermediate_size", 4 * int(D))))

        # HF decoder stacks live at model.model.layers (or model.layers).
        layers = getattr(getattr(model, "model", model), "layers", [])
        for blk in layers:
            # `or` falls back to full heads when no gate is found
            # (NOTE(review): also when the soft head count is exactly 0 — confirm).
            heads_soft = Hh if warm else (self._soft_heads_from_block_llm(blk) or Hh)
            Dq = heads_soft * Dh
            # K/V width: gated only when explicitly enabled, else full KV heads (GQA).
            if self.gate_kv_in_proxy:
                Dkv = heads_soft * Dh
            else:
                Dkv = Hkv * Dh
            hidden_soft = self._soft_hidden_from_block_llm(blk, default_hidden, anchor, warm=warm)

            # Total tokens processed across prefill + decode.
            Seff = S_t + T_t

            # Q projection over all tokens, plus K and V projections.
            total_qkv = total_qkv + (
                B_t * Seff * D * Dq +
                2 * B_t * Seff * D * Dkv
            )
            # Attention scores: S^2 for prefill, then a growing KV during decode
            # (T*S plus the triangular sum T*(T+1)/2).
            total_scores = total_scores + (
                B_t * (S_t * S_t) * heads_soft * Dh +
                B_t * heads_soft * Dh * (T_t * S_t + (T_t * (T_t + 1)) // 2)
            )
            # Output projection back to model width.
            total_out = total_out + B_t * Seff * Dq * D
            # SwiGLU-style MLP: two D x hidden matmuls per token.
            total_mlp = total_mlp + B_t * Seff * 2 * D * hidden_soft

        flops_like = (
            self.alpha_qkv * total_qkv
            + self.alpha_scores * total_scores
            + self.alpha_out * total_out
            + self.alpha_mlp * total_mlp
        )

        ms = flops_like * _as_const_like(anchor, self.scale_ms)
        if return_terms:
            return ms, {
                "qkv": float((self.alpha_qkv * total_qkv).detach().cpu()),
                "scores": float((self.alpha_scores * total_scores).detach().cpu()),
                "out": float((self.alpha_out * total_out).detach().cpu()),
                "mlp": float((self.alpha_mlp * total_mlp).detach().cpu()),
            }
        return ms

    # ---- introspection ------------------------------------------------------

    @torch.no_grad()
    def debug_layer_view(
        self,
        model: nn.Module,
        *,
        B: int,
        S: int,
        T: int,
        policy: Optional[object] = None,
        step: Optional[int] = None,
    ) -> list:
        """Return one dict per layer with the soft widths the proxy would use."""
        anchor = _find_gate_param_or_fallback(model)
        cfg = getattr(model, "config", None)
        D = _as_const_like(anchor, int(getattr(cfg, "hidden_size", 0)))
        Hq = _as_const_like(anchor, int(getattr(cfg, "num_attention_heads", 0)))
        Hkv = _as_const_like(anchor, int(getattr(cfg, "num_key_value_heads", int(Hq))))
        Dh = D // Hq

        warm = False
        if policy is not None and step is not None:
            warm = (int(step) < int(getattr(policy, "warmup_steps", 0)))

        rows = []
        layers = getattr(getattr(model, "model", model), "layers", None) or []
        for i, blk in enumerate(layers):
            heads_soft = Hq if warm else (self._soft_heads_from_block_llm(blk) or Hq)
            Dq = heads_soft * Dh
            Dkv = (heads_soft * Dh) if self.gate_kv_in_proxy else (Hkv * Dh)
            hidden_soft = self._soft_hidden_from_block_llm(
                blk, _as_const_like(anchor, int(getattr(cfg, "intermediate_size", 4 * int(D)))), anchor, warm=warm
            )
            rows.append({
                "layer": i,
                "heads_soft": float(heads_soft.detach().cpu()),
                "Dq≈heads*Dh": float(Dq.detach().cpu()),
                "Dkv_used": float(Dkv.detach().cpu()),
                "ffn_hidden_soft": float(hidden_soft.detach().cpu()),
            })
        return rows
|
|
|
|
| |
| |
| |
@torch.inference_mode()
def calibrate_proxy_llm(
    proxy: LatencyProxyLLM,
    model: nn.Module,
    *,
    B: int,
    S: int,
    T: int,
    export_keepall_fn,
    device: str = "cuda",
    warmup: int = 10,
    iters: int = 30,
) -> float:
    """
    Calibrate proxy.scale_ms so proxy.predict(...) matches real keep-all latency for (B,S,T).
    Returns the measured real mean latency in ms.
    """
    # Export an ungated (keep-all) copy and measure its true latency.
    keepall = export_keepall_fn(model).to(device).eval()

    from core.measure import measure_latency_text_ms as _measure
    real_ms, _ = _measure(keepall, B=B, S=S, T=T, warmup=warmup, iters=iters, device=device)

    # Analytic estimate for the gated model under the current scale.
    predicted = proxy.predict(model, B=B, S=S, T=T)
    if torch.is_tensor(predicted):
        soft_ms = float(predicted.detach().item())
    else:
        soft_ms = float(predicted)

    # Rescale so predicted matches measured; guard against a zero estimate.
    proxy.scale_ms = float(real_ms / max(soft_ms, 1e-9))
    return real_ms
|
|
|
|
@torch.inference_mode()
def calibrate_proxy_llm_from_batch(
    proxy: LatencyProxyLLM,
    model: nn.Module,
    batch: Dict[str, torch.Tensor],
    *,
    T: int,
    export_keepall_fn,
    device: str = "cuda",
    warmup: int = 10,
    iters: int = 30,
) -> Tuple[int, int, int, float]:
    """
    Infers (B,S) from a batch like {'input_ids': [B,S], ...},
    calibrates for (B,S,T), and returns (B,S,T, real_ms).
    """
    ids = batch["input_ids"]
    B = int(ids.size(0))
    S = int(ids.size(1))
    real_ms = calibrate_proxy_llm(
        proxy,
        model,
        B=B,
        S=S,
        T=T,
        export_keepall_fn=export_keepall_fn,
        device=device,
        warmup=warmup,
        iters=iters,
    )
    return B, S, T, real_ms
|
|