| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import contextlib |
| import functools |
| import json |
| import math |
| import os |
| import uuid |
| from collections import defaultdict |
| from enum import Enum |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( |
| CheckpointImpl, checkpoint_wrapper) |
| from torch.fx.operator_schemas import normalize_function |
| from torch.nn.attention import SDPBackend, sdpa_kernel |
| from torch.utils._python_dispatch import TorchDispatchMode |
| from torch.utils._pytree import tree_map |
| from torch.utils.module_tracker import ModuleTracker |
| from xformers.ops import fmha |
|
|
|
|
@torch.library.custom_op("torchprobe::log", mutates_args=(), device_types=None)
def _log(x: torch.Tensor, name: str, uid: str) -> None:
    """No-op custom op acting as a logging marker.

    The op itself does nothing: AutoProbeD's `__torch_dispatch__` intercepts
    calls to `torch.ops.torchprobe.log` and records stats for `x` there.
    `uid` ties a forward-pass log call to its backward counterpart.
    """
    pass
|
|
|
|
@_log.register_fake
def _log_fake(x: torch.Tensor, name: str, uid: str) -> None:
    """Fake (meta) implementation of `torchprobe::log`.

    Required so the op can be traced/compiled; it has no outputs and no
    side effects, so the fake is also a no-op.
    """
    pass
|
|
|
|
class _LogStats(torch.autograd.Function):
    """Identity autograd function that emits `torchprobe::log` marker ops.

    Forward logs the tensor under `name`; backward logs the incoming
    gradient under `f"{name}.g"`. A fresh uuid links the two calls so the
    probe can attribute the backward log to the same module path as the
    forward one.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, name: str):
        uid = str(uuid.uuid4())
        torch.ops.torchprobe.log(x, name, uid)
        ctx.name = name
        ctx.uid = uid
        return x

    @staticmethod
    def backward(ctx, grad: torch.Tensor):
        torch.ops.torchprobe.log(grad, f"{ctx.name}.g", ctx.uid)
        # Second value corresponds to the non-tensor `name` argument.
        return grad, None
|
|
|
|
| _PROBING_ENABLED = False |
|
|
|
|
def log_stats(x: torch.Tensor, name: str) -> torch.Tensor:
    """Tag `x` so its statistics (and its gradient's) are logged as `name`.

    Returns `x` unchanged; acts as the identity when probing is disabled.
    """
    if _PROBING_ENABLED:
        return _LogStats.apply(x, name)
    return x
|
|
|
|
# Quantile levels recorded for every probed tensor. Log-spaced towards both
# tails so rare outliers (down to ~1e-7 probability mass) are captured.
QUANTILES = [
    0.0000001,
    0.000001,
    0.00001,
    0.0001,
    0.001,
    0.01,
    0.05,
    0.1,
    0.3,
    0.5,
    0.7,
    0.9,
    0.95,
    0.99,
    0.999,
    0.9999,
    0.99999,
    0.999999,
    0.9999999,
]
|
|
|
|
@functools.cache
def _get_quantiles(device: torch.device, dtype) -> torch.Tensor:
    """Return QUANTILES as a tensor, cached per (device, dtype) pair."""
    return torch.tensor(QUANTILES, device=device, dtype=dtype)
|
|
|
|
def _get_stats(x_: torch.Tensor, remove_inf=False) -> Dict[str, Any]:
    """Compute summary statistics (moments, extrema, quantiles) of a tensor.

    Args:
        x_: Tensor to summarize; non-floating dtypes are skipped.
        remove_inf: If True, drop non-finite values (|x| == inf, and NaN,
            since NaN fails the `< inf` comparison) before computing stats.

    Returns:
        Dict of scalar tensors keyed by stat name; `{}` for unsupported
        dtypes; only `{"shape": ...}` when no values remain to aggregate.
    """
    if x_.dtype not in [torch.float, torch.double, torch.float16, torch.bfloat16]:
        return {}
    x = x_.flatten()
    if remove_inf:
        x = x[x.abs() < float("inf")]
    # Guard: an empty tensor (e.g. everything filtered out above) would make
    # mean/std return NaN and `torch.quantile` raise on empty input.
    if x.numel() == 0:
        return {"shape": tuple(x_.shape)}
    if x.dtype is not torch.double:
        x = x.float()
    xabs = x.abs()
    quantiles = _get_quantiles(x.device, x.dtype)
    mean = x.mean()
    std = x.std()
    return {
        "shape": tuple(x_.shape),
        "mean": mean,
        "std": std,
        # Standardized 3rd/4th moments, accumulated in double for stability.
        "skew": (((x - mean) / std) ** 3).double().mean(),
        "kurtosis": (((x - mean) / std) ** 4).double().mean(),
        "abs.mean": xabs.mean(),
        "max": x.max(),
        "min": x.min(),
        # torch.quantile has an input-size limit, so subsample huge tensors.
        "quantiles": torch.quantile(x[: 2**24], quantiles),
    }
|
|
|
|
| def _mask_attn_causal_inplace(logits: torch.Tensor, q_idx, q_len, kv_len) -> None: |
| assert logits.ndim == 4 |
| logits[:, :, :, q_idx + kv_len - q_len + 1 :] = -math.inf |
|
|
|
|
def _mask_attn_logits(
    logits: torch.Tensor,
    q_idx: List[int],
    *,
    causal: bool,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Mask recomputed attention logits in place and return them.

    `logits` has shape [B, H, len(q_idx), total_K], where `q_idx` are the
    (sorted ascending) original indices of the query subset. When
    `cu_seqlens_q`/`cu_seqlens_k` are given (varlen / block-diagonal
    layout), cross-sequence attention is masked out, and a bottom-right
    causal mask is applied per sequence when `causal`.
    """
    assert logits.dtype is torch.float32
    if cu_seqlens_q is not None:
        assert cu_seqlens_k is not None
        assert logits.ndim == 4, logits.shape
        qs = cu_seqlens_q.tolist()
        ks = cu_seqlens_k.tolist()
        q_batchid = []
        # -2 marks key slots not covered by any sequence; it never equals a
        # real batch id, so those keys are masked for every query.
        k_batchid = [-2] * logits.shape[-1]
        q_idx_i = 0
        for bid, (q0, q1, k0, k1) in enumerate(zip(qs, qs[1:], ks, ks[1:])):
            for k in range(k0, k1):
                k_batchid[k] = bid
            # Consume the query indices that fall into sequence `bid`
            # (relies on q_idx being sorted).
            while q_idx_i < len(q_idx) and q_idx[q_idx_i] < q1:
                q_batchid.append(bid)
                if causal:
                    # Bottom-right-aligned causal mask within this sequence.
                    _mask_attn_causal_inplace(
                        logits[:, :, q_idx_i : q_idx_i + 1, k0:k1],
                        q_idx[q_idx_i] - q0,
                        q1 - q0,
                        k1 - k0,
                    )
                q_idx_i += 1
        # Mask attention across sequences (query batch id != key batch id).
        mask_out = (
            torch.tensor(q_batchid, device=logits.device)[None, None, :, None]
            != torch.tensor(k_batchid, device=logits.device)[None, None, None, :]
        )
        logits[mask_out.expand_as(logits)] = -math.inf
        assert q_idx_i == len(q_idx)
    elif causal:
        # Dense case: one causal mask per subset query.
        # NOTE(review): `logits.shape[2]` is the SUBSET size, not the full
        # query count; when q_idx is a strict subset (long sequences) this
        # shifts the causal boundary — verify against the intended alignment.
        for q_idx_i in range(len(q_idx)):
            _mask_attn_causal_inplace(
                logits[:, :, q_idx_i : q_idx_i + 1, :],
                q_idx[q_idx_i],
                logits.shape[2],
                logits.shape[3],
            )
    return logits
|
|
|
|
| def _attn_queries_subset(num_queries: int) -> List[int]: |
| return list(range(0, num_queries, max(1, num_queries // 128))) |
|
|
|
|
@torch.no_grad()
def _compute_attn_stats_sdpa(
    probe,
    path: str,
    # Remaining parameters mirror the aten SDPA op schemas; the caller
    # normalizes the intercepted call to keyword arguments.
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    compute_log_sumexp=True,
    return_debug_mask=False,
    **kwargs,
):
    """Recompute logits/entropy stats for an intercepted SDPA call.

    For unsupported configurations (explicit mask/bias, dropout, unknown
    kwargs) only the call metadata is stored under `{path}::attn`; otherwise
    logits for a query subset are recomputed and logged via `probe`.
    """
    if scale is None:
        # SDPA default scaling: 1/sqrt(head_dim).
        scale = 1 / (query.shape[-1] ** 0.5)
    if attn_mask is not None or attn_bias is not None or dropout_p != 0.0 or kwargs:
        probe.store[f"{path}::attn"] = {
            "query.shape": tuple(query.shape),
            "key.shape": tuple(key.shape),
            "value.shape": tuple(value.shape),
            "attn_mask": attn_mask.shape if attn_mask is not None else None,
            "dropout_p": dropout_p,
            "is_causal": is_causal,
            "scale": scale,
            "unk_kwargs": list(kwargs.keys()),
        }
        return
    # Recompute logits only for an evenly-strided subset of queries.
    query_s = _attn_queries_subset(query.shape[-2])
    logits = query[:, :, query_s] @ key.transpose(-1, -2) * scale
    logits = _mask_attn_logits(logits.float(), query_s, causal=is_causal)
    p = logits.float().softmax(-1)
    # Replace log-probs at fully-masked (-inf) positions with 0 so the
    # entropy sum ignores them instead of producing NaNs.
    masked_logsoft = logits.log_softmax(-1).where(
        (logits > -math.inf), torch.zeros_like(logits)
    )
    entropy = -(p * masked_logsoft).sum(-1)
    probe.log_tensor(f"{path}::attn_entropy", entropy)
    probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True)
|
|
|
|
@torch.no_grad()
def _compute_attn_stats_flash(
    probe,
    path: str,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor],
    cu_seqlens_k: Optional[torch.Tensor],
    seqused_k: Optional[torch.Tensor],
    max_seqlen_q: int,
    max_seqlen_k: int,
    p: float,  # dropout probability (shadowed by softmax probs below)
    softmax_scale: float,
    is_causal: bool,
    window_left: int,
    window_right: int,
    return_softmax: bool,
    block_tables: Optional[torch.Tensor],
    unpadded_lse: bool = False,
) -> None:
    """Recompute logits/entropy stats for an intercepted flash-attention op.

    Parameters mirror the xformers flash forward operator schema. For
    unsupported features (seqused_k, dropout, sliding windows, paged KV)
    only shape metadata is stored.
    """
    if (
        seqused_k is not None
        or p != 0.0
        or window_left >= 0
        or window_right >= 0
        or block_tables is not None
    ):
        probe.store[f"{path}::attn"] = {
            "query.shape": tuple(query.shape),
            "key.shape": tuple(key.shape),
            "value.shape": tuple(value.shape),
            "op": "flash",
        }
        return

    if cu_seqlens_q is not None:
        # Varlen layout packs all sequences along dim 0; add a batch dim so
        # the masking below can treat it as [1, total_q, H, D].
        assert query.ndim == 3, query.shape
        query, key, value = query[None], key[None], value[None]
    assert query.ndim == 4, query.shape

    # Recompute logits only for an evenly-strided subset of queries.
    query_s = _attn_queries_subset(query.shape[1])
    logits = (
        query[:, query_s].transpose(1, 2)
        @ key.transpose(1, 2).transpose(-1, -2)
        * softmax_scale
    )
    logits = _mask_attn_logits(
        logits.float(),
        query_s,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        causal=is_causal,
    )
    # NOTE: `p` is reused here for softmax probabilities.
    p = logits.float().softmax(-1)
    # Replace log-probs at fully-masked (-inf) positions with 0 so the
    # entropy sum ignores them instead of producing NaNs.
    masked_logsoft = logits.log_softmax(-1).where(
        (logits > -math.inf), torch.zeros_like(logits)
    )
    entropy = -(p * masked_logsoft).sum(-1)
    probe.log_tensor(f"{path}::attn_entropy", entropy)
    probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True)
|
|
|
|
| def _tensors_to_python(x): |
| if not isinstance(x, torch.Tensor): |
| return x |
| return x.tolist() |
|
|
|
|
| |
class LinearBwType(Enum):
    """Classifies which backward GEMM of a linear layer an `mm` call is."""

    DW = 1  # weight-gradient GEMM
    DX = 2  # input-gradient GEMM
    UNKNOWN = 3  # could not be matched against recorded forward data
|
|
|
|
class AutoProbeD(TorchDispatchMode):
    """Dispatch mode that records tensor statistics while a module runs.

    Captures linear-layer inputs/weights/outputs and their gradients,
    attention logits/entropy for SDPA and flash ops, and explicit
    `log_stats` markers. Stats accumulate in `self.store` and are flushed
    as one JSON line to `write_file` on exit. torch.compile is temporarily
    disabled so every aten op is visible to `__torch_dispatch__`.
    """

    def __init__(self, module: nn.Module, write_file: Optional[str] = None) -> None:
        self.write_file = Path(write_file) if write_file is not None else None
        # When set (see `_setup_tensors_logging`), full tensors are also
        # saved under this temporary directory.
        self.write_tensors_tmpdir: Optional[Path] = None
        self.compile_disabler = TorchCompileDisabler(module)
        self.mod_tracker = ModuleTracker()
        self.count_per_path: Dict[str, int] = defaultdict(int)
        # "<module path>::<name>" -> stats dict (see `_get_stats`).
        self.store: Dict[str, Dict[str, Any]] = {}
        # path -> (in_shape, w_shape, out_shape, input_sample, weight_sample.T)
        # recorded in FW so BW GEMMs can be matched back to their layer.
        self.linear_data: Dict[str, Tuple[Any, Any, Any, Any, Any]] = {}
        # uid from torchprobe::log -> module path seen at FW time, so the
        # BW log of the same tensor reuses the FW path.
        self.uid_to_path: Dict[str, str] = {}
        self.metadata: Any = None
        self.enabled = False
        self.verbose = bool(int(os.environ.get("PROBE_VERBOSE", "0")))

    def __enter__(self):
        global _PROBING_ENABLED
        assert not self.enabled, "Entered probe twice"
        # Order matters: disable compile first so the tracker and dispatch
        # mode see eager execution.
        self.compile_disabler.__enter__()
        self.mod_tracker.__enter__()
        super().__enter__()
        self.enabled = True
        _PROBING_ENABLED = True
        return self

    def __exit__(self, *args) -> None:
        global _PROBING_ENABLED
        assert self.enabled, "Exiting probe without entering it"
        # Unwind in reverse order of __enter__, then persist results.
        super().__exit__(*args)
        self.mod_tracker.__exit__(*args)
        self.compile_disabler.__exit__(*args)
        self._flush_and_clear()
        _PROBING_ENABLED = False
        self.enabled = False

    def _setup_tensors_logging(self):
        """Create a temp directory next to `write_file` for full tensor dumps."""
        if self.write_file is not None:
            self.write_file.parent.mkdir(exist_ok=True)
            self.write_tensors_tmpdir = (
                self.write_file.parent
                / f"{self.write_file.name}-tmp-{str(uuid.uuid4())[:8]}"
            )
            self.write_tensors_tmpdir.mkdir(exist_ok=True)

    def _flush_and_clear(self) -> None:
        """Append stats as one JSON line, finalize the tensor-dump dir, reset."""
        if self.write_file is not None:
            dump_data = tree_map(_tensors_to_python, self.store)
            with self.write_file.open("a") as fd:
                json.dump(
                    {
                        "data": dump_data,
                        "meta": self.metadata,
                        "version": 2,
                        "quantiles": QUANTILES,
                    },
                    fd,
                )
                fd.write("\n")
        if self.write_tensors_tmpdir is not None:
            assert self.write_file is not None
            dump_dir = self.write_tensors_tmpdir.parent / f"{self.write_file.name}-dump"
            dump_dir.mkdir(exist_ok=True)
            dir_name = ""
            if "it" in self.metadata:
                dir_name = f"it{int(self.metadata['it']):010}"
            # No iteration number, or name collision: append a version suffix.
            if dir_name == "" or (dump_dir / dir_name).exists():
                num_files = len(list(dump_dir.glob(f"{dir_name}v*")))
                dir_name = f"{dir_name}v{num_files}"
            dump_dir = dump_dir / dir_name
            assert not dump_dir.exists()
            self.write_tensors_tmpdir.rename(dump_dir)
            self.write_tensors_tmpdir = None
        self.store.clear()
        self.count_per_path.clear()
        self.uid_to_path.clear()

    def _find_bw_path_and_type(
        self, path: str, out: torch.Tensor, args
    ) -> Tuple[str, LinearBwType]:
        """
        We are in the BW pass, and process a GEMM.
        Let's figure out:
        (1) The path for the FW pass (might differ in case of ModuleTracker bug)
        (2) The type of BW pass (eg `dw` or `dx`)
        """

        def _is_path_correct_dw(path: str) -> bool:
            # dW GEMM: output has the transposed weight shape, and the second
            # operand matches the input sample recorded at FW time.
            in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path]
            return out.shape == (w_shape[1], w_shape[0]) and torch.allclose(
                input_sm, args[1][:4, :4]
            )

        def _is_path_correct_dx(path: str) -> bool:
            # dX GEMM: output has the FW input shape, and the second operand
            # matches the (transposed) weight sample recorded at FW time.
            in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path]
            return out.shape == in_shape and torch.allclose(weight_sm, args[1][:4, :4])

        if path in self.linear_data:
            if _is_path_correct_dw(path):
                return path, LinearBwType.DW
            if _is_path_correct_dx(path):
                return path, LinearBwType.DX
        # The tracked path didn't match: scan all currently-open modules.
        for candidate_path in self.mod_tracker.parents:
            if candidate_path not in self.linear_data:
                continue
            if _is_path_correct_dw(candidate_path):
                return candidate_path, LinearBwType.DW
            if _is_path_correct_dx(candidate_path):
                return candidate_path, LinearBwType.DX
        return path, LinearBwType.UNKNOWN

    def log_tensor(self, name: str, x: torch.Tensor, **kwargs) -> None:
        """Store summary stats for `x` under `name`; optionally dump the tensor."""
        self.store[name] = _get_stats(x, **kwargs)
        if self.write_tensors_tmpdir is not None:
            name_safe = name.replace("::", "__").replace("/", "")
            torch.save(x, self.write_tensors_tmpdir / f"{name_safe}.pkl")

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        path = None
        # Use the deepest (longest) currently-open module path as the label.
        for p in self.mod_tracker.parents:
            if p == "Global":
                continue
            if path is None or len(p) > len(path):
                path = p
        if path is None:
            path = "Global"
        path = path.replace("._checkpoint_wrapped_module", "")
        out = func(*args, **kwargs)

        # Linear layers: mm/addmm in FW, mm twice (dW, dX) in BW.
        if func._overloadpacket in [torch.ops.aten.addmm, torch.ops.aten.mm]:
            weight: torch.Tensor
            input: torch.Tensor
            if not self.mod_tracker.is_bw:
                if func._overloadpacket == torch.ops.aten.addmm:
                    _bias, input, weight = args[:3]
                else:
                    assert func._overloadpacket == torch.ops.aten.mm
                    input, weight = args[:2]
                self.log_tensor(f"{path}::in", input)
                self.log_tensor(f"{path}::w", weight)
                self.log_tensor(f"{path}::out", out)
                # Keep 4x4 samples so the BW GEMMs can be matched by value.
                self.linear_data[path] = (
                    input.shape,
                    weight.shape,
                    out.shape,
                    input[:4, :4].clone(),
                    weight[:4, :4].T.clone(),
                )
            elif func._overloadpacket == torch.ops.aten.mm:
                # ModuleTracker may mis-attribute BW ops; re-derive the path.
                new_path, bwtype = self._find_bw_path_and_type(path, out, args)
                if new_path != path:
                    if self.verbose:
                        print(f"E: Fixing path `{path}` -> `{new_path}")
                    path = new_path

                if bwtype == LinearBwType.DW:
                    self.log_tensor(f"{path}::w.g", out)
                elif bwtype == LinearBwType.DX:
                    self.log_tensor(f"{path}::in.g", out)
                    self.log_tensor(f"{path}::out.g", args[0])
        elif func._overloadpacket in [
            torch.ops.aten._scaled_dot_product_flash_attention,
            torch.ops.aten._scaled_dot_product_cudnn_attention,
        ]:
            _, kwargs = normalize_function(
                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
            )
            _compute_attn_stats_sdpa(self, path, **kwargs)
        elif func._overloadpacket == fmha.flash.FwOp.OPERATOR:
            _, kwargs = normalize_function(
                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
            )
            _compute_attn_stats_flash(self, path, **kwargs)
        elif func._overloadpacket == torch.ops.torchprobe.log:
            # Marker op from `log_stats`: reuse the FW path for the BW log.
            uid = args[2]
            path = self.uid_to_path.setdefault(uid, path)
            self.log_tensor(f"{path}::{args[1]}", args[0])
        if self.verbose:
            print(f"{'[BW]' if self.mod_tracker.is_bw else '[FW]'} `{path}`: {func}")
        return out
|
|
|
|
| def _find_all_submodules_compiled(out: List[nn.Module], module: nn.Module) -> None: |
| if module._compiled_call_impl is not None: |
| out.append(module) |
| for c in module.children(): |
| _find_all_submodules_compiled(out, module=c) |
|
|
|
|
class TorchCompileDisabler:
    """Context manager that temporarily strips torch.compile from a module.

    On enter it detaches every submodule's `_compiled_call_impl` (forcing
    eager execution, so a TorchDispatchMode sees all ops) and activates
    `torch.compiler.disable()`; on exit both are restored.
    """

    def __init__(self, module: nn.Module) -> None:
        self.module = module
        self.submodules_compiled: List[nn.Module] = []
        self.compiled_call_impl: List[Any] = []
        self.disable_compile = torch.compiler.disable()
        # Allow reusing the `disable()` object as a (re-)entered context
        # manager without dynamo raising.
        torch._dynamo.config.raise_on_ctx_manager_usage = False

    def __enter__(self) -> None:
        # Snapshot and detach compiled call impls; recomputed every time,
        # since submodules may have been (re)compiled in between.
        self.submodules_compiled.clear()
        _find_all_submodules_compiled(self.submodules_compiled, self.module)
        self.compiled_call_impl = [
            m._compiled_call_impl for m in self.submodules_compiled
        ]
        for m in self.submodules_compiled:
            m._compiled_call_impl = None
        self.disable_compile.__enter__()

    def __exit__(self, *args) -> None:
        self.disable_compile.__exit__(*args)
        # Restore the exact compiled impls that were detached on enter.
        for m, c_impl in zip(self.submodules_compiled, self.compiled_call_impl):
            m._compiled_call_impl = c_impl
        self.compiled_call_impl = []
|
|
|
|
# Public alias for the probe dispatch mode.
Probe = AutoProbeD

# Toy-model dimensions used by the tests below: model width, sequence
# length, and batch size.
d = 512
seqlen = 4
bs = 2
|
|
|
|
class Attention1(nn.Module):
    """Self-attention through xformers with a bottom-right causal mask."""

    def forward(self, x):
        bias = fmha.attn_bias.LowerTriangularFromBottomRightMask()
        attn_out = fmha.memory_efficient_attention(x, x, x, attn_bias=bias)
        # Merge heads back into the feature dimension.
        return attn_out.reshape([x.shape[0], seqlen, -1])
|
|
|
|
class Attention2(nn.Module):
    """Self-attention packed into one flat batch with a block-diagonal causal mask."""

    def forward(self, x):
        bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(
            [seqlen] * bs
        ).make_causal()
        # Pack both batch entries into a single sequence dimension.
        packed = x.reshape([1, 2 * seqlen, x.shape[2], x.shape[3]])
        attn_out = fmha.memory_efficient_attention(packed, packed, packed, attn_bias=bias)
        return attn_out.reshape([x.shape[0], seqlen, -1])
|
|
|
|
class AttentionSDPA(nn.Module):
    """Self-attention via F.scaled_dot_product_attention plus output projection."""

    def __init__(self):
        super().__init__()
        self.wo = nn.Linear(d, d)

    def forward(self, x):
        # SDPA expects [batch, heads, seq, head_dim].
        xt = x.transpose(1, 2)
        attn_out = F.scaled_dot_product_attention(xt, xt, xt)
        merged = attn_out.transpose(1, 2).reshape([xt.shape[0], seqlen, -1])
        return self.wo(merged)
|
|
|
|
class AttentionSDPAFlash(AttentionSDPA):
    """AttentionSDPA, but forcing the flash-attention SDPA backend."""

    def forward(self, x):
        xt = x.transpose(1, 2)
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            attn_out = F.scaled_dot_product_attention(xt, xt, xt)
            merged = attn_out.transpose(1, 2).reshape([xt.shape[0], seqlen, -1])
            return self.wo(merged)
|
|
|
|
class Model(nn.Module):
    """Toy model exercising a compiled trunk and four attention flavours."""

    def __init__(self) -> None:
        super().__init__()
        self.head = nn.Linear(d, 16)
        self.trunk = nn.Sequential(
            nn.Linear(d, d),
            nn.Linear(d, d),
        )
        self.q_proj = nn.Linear(d, d, bias=False)
        # Compile only the trunk; the probe must still see its ops.
        self.trunk.compile()
        self.attn1 = Attention1()
        self.attn2 = Attention2()
        self.attnSDPA = AttentionSDPA()
        self.attnSDPAflash = AttentionSDPAFlash()

    def forward(self, x):
        batch, n_heads, head_dim = x.shape[0], d // 64, 64
        q = self.q_proj(x).reshape([batch, seqlen, n_heads, head_dim])
        # Sum the outputs of all attention variants.
        mixed = self.attn1(q) + self.attn2(q) + self.attnSDPA(q) + self.attnSDPAflash(q)
        mixed = log_stats(mixed, "attns_out")
        return self.head(self.trunk(mixed))
|
|
|
|
def test_masking() -> None:
    """Check `_mask_attn_logits` against xformers' materialized causal bias.

    Builds a varlen block-diagonal bottom-right causal bias, applies our
    in-place masking to the same logits (all queries, no subsampling), and
    verifies both paths produce identical masked logits. Requires CUDA.
    """
    q_seqlen = [1, 1, 14, 12]
    kv_seqlen = [2, 2, 14, 18]
    attn_bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
        q_seqlen, kv_seqlen
    ).make_causal_from_bottomright()
    logits = torch.randn(
        [1, 1, sum(q_seqlen), sum(kv_seqlen)], dtype=torch.float32, device="cuda"
    )
    bias = attn_bias.materialize(logits.shape, dtype=logits.dtype, device=logits.device)
    logits_masked = logits.clone()
    _mask_attn_logits(
        logits_masked,
        list(range(logits.shape[2])),
        causal=True,
        cu_seqlens_q=attn_bias.q_seqinfo.seqstart,
        cu_seqlens_k=attn_bias.k_seqinfo.seqstart,
    )
    # Adding the (0 / -inf) bias must equal our in-place masking exactly.
    assert (logits + bias == logits_masked).all().item()
|
|
|
|
def test_toy_model() -> None:
    """End-to-end probe test on the toy model.

    Runs 4 train steps; the probe is active only on step 1, which checks
    that entering/exiting mid-training works and that subsequent steps run
    unprobed. Verifies expected store keys, exact `abs.mean` values against
    live tensors, and that all stats are finite. Requires CUDA.
    """
    kw = dict(device="cuda", dtype=torch.float16)
    x = torch.randn([bs, seqlen, d], **kw)
    m = Model()
    # Wrap the head in activation checkpointing to check path un-mangling.
    m.head = checkpoint_wrapper(
        m.head, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
    )
    m.to(**kw)
    m.compile()
    optim = torch.optim.SGD(m.parameters(), lr=0.0)
    probe = AutoProbeD(m, "./probe.json")

    for i in range(4):
        with contextlib.ExitStack() as stack:
            print(f"########### STEP {i}")
            if i % 4 == 1:
                # Probe only this step.
                stack.enter_context(probe)
                probe.metadata = {"it": i}
            y = m(x)
            g = torch.randn_like(y)
            y.backward(g)
            if i % 4 == 1:
                assert probe.enabled
                # All expected keys were recorded during FW+BW.
                print(list(probe.store.keys()))
                for key in [
                    "Model::attns_out",
                    "Model::attns_out.g",
                    "Model.attn1::attn_logits",
                    "Model.attn2::attn_logits",
                    "Model.attnSDPA::attn_logits",
                    "Model.attnSDPAflash::attn_logits",
                    "Model.head::w",
                    "Model.head::w.g",
                    "Model.head::in",
                    "Model.head::in.g",
                    "Model.head::out",
                    "Model.head::out.g",
                    "Model.trunk.0::in",
                    "Model.trunk.1::in",
                ]:
                    assert key in probe.store, f"Missing key: '{key}'"
                # Stats must match the live tensors exactly.
                for key, tensor in [
                    ("Model.head::w", m.head.weight),
                    ("Model.head::w.g", m.head.weight.grad),
                    ("Model.q_proj::in", x),
                    ("Model.q_proj::w.g", m.q_proj.weight.grad),
                    ("Model.head::out", y),
                    ("Model.head::out.g", g),
                ]:
                    assert key in probe.store, f"Missing key: '{key}'"
                    assert torch.allclose(
                        probe.store[key]["abs.mean"], tensor.float().abs().mean()
                    ), f"'{key}' mismatches"
                # Nothing recorded should be Inf/NaN.
                for key, value in probe.store.items():
                    if "abs.mean" in value:
                        assert math.isfinite(
                            value["abs.mean"].item()
                        ), f"Inf/Nan for {key}"
            optim.step()
            optim.zero_grad()
|
|