|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import contextlib |
|
|
import functools |
|
|
import json |
|
|
import math |
|
|
import os |
|
|
import uuid |
|
|
from collections import defaultdict |
|
|
from enum import Enum |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, List, Optional, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( |
|
|
CheckpointImpl, checkpoint_wrapper) |
|
|
from torch.fx.operator_schemas import normalize_function |
|
|
from torch.nn.attention import SDPBackend, sdpa_kernel |
|
|
from torch.utils._python_dispatch import TorchDispatchMode |
|
|
from torch.utils._pytree import tree_map |
|
|
from torch.utils.module_tracker import ModuleTracker |
|
|
from xformers.ops import fmha |
|
|
|
|
|
|
|
|
@torch.library.custom_op("torchprobe::log", mutates_args=(), device_types=None)
def _log(x: torch.Tensor, name: str, uid: str) -> None:
    """No-op custom op used as a marker.

    The op itself does nothing: `AutoProbeD.__torch_dispatch__` intercepts
    calls to `torch.ops.torchprobe.log` and records stats for `x` there.
    `uid` ties a forward marker to its backward counterpart.
    """
    pass
|
|
|
|
|
|
|
|
@_log.register_fake
def _log_fake(x: torch.Tensor, name: str, uid: str) -> None:
    """Fake (meta) implementation so the marker op traces under fake tensors."""
    pass
|
|
|
|
|
|
|
|
class _LogStats(torch.autograd.Function):
    """Identity autograd function that emits `torchprobe::log` markers.

    Forward logs the activation under `name`; backward logs the incoming
    gradient under `name + ".g"`. A fresh uid links the two events so the
    probe can attribute the backward marker to the same module path.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, name: str):
        uid = str(uuid.uuid4())
        torch.ops.torchprobe.log(x, name, uid)
        # Stash name/uid so backward can emit a matching marker.
        ctx.name = name
        ctx.uid = uid
        return x

    @staticmethod
    def backward(ctx, grad: torch.Tensor):
        torch.ops.torchprobe.log(grad, f"{ctx.name}.g", ctx.uid)
        # Second output is the gradient for `name`, which is not a tensor.
        return grad, None
|
|
|
|
|
|
|
|
# Global toggle flipped by AutoProbeD.__enter__/__exit__; `log_stats` is an
# identity function while this is False.
_PROBING_ENABLED = False
|
|
|
|
|
|
|
|
def log_stats(x: torch.Tensor, name: str) -> torch.Tensor:
    """Tag `x` so an active probe records its stats (identity when disabled)."""
    return _LogStats.apply(x, name) if _PROBING_ENABLED else x
|
|
|
|
|
|
|
|
# Quantile levels recorded for every probed tensor: log-spaced levels in both
# tails (down to 1e-7) plus a coarse grid around the median.
QUANTILES = [
    0.0000001,
    0.000001,
    0.00001,
    0.0001,
    0.001,
    0.01,
    0.05,
    0.1,
    0.3,
    0.5,
    0.7,
    0.9,
    0.95,
    0.99,
    0.999,
    0.9999,
    0.99999,
    0.999999,
    0.9999999,
]
|
|
|
|
|
|
|
|
@functools.cache
def _get_quantiles(device: torch.device, dtype) -> torch.Tensor:
    """Return (and cache per device/dtype) the QUANTILES levels as a tensor."""
    levels = torch.tensor(QUANTILES, device=device, dtype=dtype)
    return levels
|
|
|
|
|
|
|
|
def _get_stats(x_: torch.Tensor, remove_inf=False) -> Dict[str, Any]:
    """Compute summary statistics (moments, extrema, quantiles) of a tensor.

    Returns an empty dict for non-floating dtypes. With `remove_inf`, +/-inf
    values are dropped before computing the stats (NaNs are kept). Values in
    the returned dict are 0-dim tensors, except "shape" which is a tuple.
    """
    if x_.dtype not in [torch.float, torch.double, torch.float16, torch.bfloat16]:
        return {}
    x = x_.flatten()
    if remove_inf:
        x = x[x.abs() < float("inf")]
    # Upcast half-precision inputs so the reductions are accurate.
    if x.dtype is not torch.double:
        x = x.float()
    xabs = x.abs()
    quantiles = _get_quantiles(x.device, x.dtype)
    mean = x.mean()
    std = x.std()
    return {
        # Shape of the original (unflattened, unfiltered) tensor.
        "shape": tuple(x_.shape),
        "mean": mean,
        "std": std,
        # Standardized 3rd/4th moments, accumulated in double precision.
        "skew": (((x - mean) / std) ** 3).double().mean(),
        "kurtosis": (((x - mean) / std) ** 4).double().mean(),
        "abs.mean": xabs.mean(),
        "max": x.max(),
        "min": x.min(),
        # Cap the element count for torch.quantile. NOTE(review): this takes
        # the first 2**24 elements rather than a random subsample.
        "quantiles": torch.quantile(x[: 2**24], quantiles),
    }
|
|
|
|
|
|
|
|
def _mask_attn_causal_inplace(logits: torch.Tensor, q_idx, q_len, kv_len) -> None: |
|
|
assert logits.ndim == 4 |
|
|
logits[:, :, :, q_idx + kv_len - q_len + 1 :] = -math.inf |
|
|
|
|
|
|
|
|
def _mask_attn_logits(
    logits: torch.Tensor,
    q_idx: List[int],
    *,
    causal: bool,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Mask attention logits in place with -inf; returns `logits`.

    `logits` has shape [B, H, len(q_idx), total_K] and holds scores for the
    query subset `q_idx` (sorted global query positions). When cu_seqlens are
    given, cross-sequence entries of the block-diagonal layout are masked,
    and (with `causal`) bottom-right-aligned causality is applied per block.
    Without cu_seqlens, only the causal mask is applied (if requested).
    """
    assert logits.dtype is torch.float32
    if cu_seqlens_q is not None:
        assert cu_seqlens_k is not None
        assert logits.ndim == 4, logits.shape
        qs = cu_seqlens_q.tolist()
        ks = cu_seqlens_k.tolist()
        # Sequence (batch) id for every selected query / every key column;
        # -2 marks key columns not covered by any sequence.
        q_batchid = []
        k_batchid = [-2] * logits.shape[-1]
        q_idx_i = 0
        # Walk consecutive (query, key) start/end pairs, one per sequence.
        for bid, (q0, q1, k0, k1) in enumerate(zip(qs, qs[1:], ks, ks[1:])):
            for k in range(k0, k1):
                k_batchid[k] = bid
            # Consume the selected queries that fall inside this sequence.
            while q_idx_i < len(q_idx) and q_idx[q_idx_i] < q1:
                q_batchid.append(bid)
                if causal:
                    # Causal-mask this query row within its own key block,
                    # using positions relative to the sequence start.
                    _mask_attn_causal_inplace(
                        logits[:, :, q_idx_i : q_idx_i + 1, k0:k1],
                        q_idx[q_idx_i] - q0,
                        q1 - q0,
                        k1 - k0,
                    )
                q_idx_i += 1
        # Mask out every (query, key) pair that belongs to different sequences.
        mask_out = (
            torch.tensor(q_batchid, device=logits.device)[None, None, :, None]
            != torch.tensor(k_batchid, device=logits.device)[None, None, None, :]
        )
        logits[mask_out.expand_as(logits)] = -math.inf
        # All selected queries must have been assigned to some sequence.
        assert q_idx_i == len(q_idx)
    elif causal:
        for q_idx_i in range(len(q_idx)):
            _mask_attn_causal_inplace(
                logits[:, :, q_idx_i : q_idx_i + 1, :],
                q_idx[q_idx_i],
                logits.shape[2],
                logits.shape[3],
            )
    return logits
|
|
|
|
|
|
|
|
def _attn_queries_subset(num_queries: int) -> List[int]: |
|
|
return list(range(0, num_queries, max(1, num_queries // 128))) |
|
|
|
|
|
|
|
|
@torch.no_grad()
def _compute_attn_stats_sdpa(
    probe,
    path: str,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    compute_log_sumexp=True,
    return_debug_mask=False,
    **kwargs,
):
    """Recompute SDPA logits for a subset of queries and log their stats.

    Called from `AutoProbeD.__torch_dispatch__` with the normalized kwargs of
    an aten SDPA op ([B, H, Sq, D] layout). Logs attention entropy and logits
    under `path`. Configurations this re-implementation does not cover
    (masks, biases, dropout, unknown kwargs) only record argument shapes.
    """
    if scale is None:
        scale = 1 / (query.shape[-1] ** 0.5)
    # Unsupported configuration: record a description instead of stats.
    if attn_mask is not None or attn_bias is not None or dropout_p != 0.0 or kwargs:
        probe.store[f"{path}::attn"] = {
            "query.shape": tuple(query.shape),
            "key.shape": tuple(key.shape),
            "value.shape": tuple(value.shape),
            "attn_mask": attn_mask.shape if attn_mask is not None else None,
            "dropout_p": dropout_p,
            "is_causal": is_causal,
            "scale": scale,
            "unk_kwargs": list(kwargs.keys()),
        }
        return
    # Recompute logits only for an evenly spaced subset of query positions.
    query_s = _attn_queries_subset(query.shape[-2])
    logits = query[:, :, query_s] @ key.transpose(-1, -2) * scale
    logits = _mask_attn_logits(logits.float(), query_s, causal=is_causal)
    p = logits.float().softmax(-1)
    # Zero out the -inf (masked) positions so they don't poison the entropy.
    masked_logsoft = logits.log_softmax(-1).where(
        (logits > -math.inf), torch.zeros_like(logits)
    )
    entropy = -(p * masked_logsoft).sum(-1)
    probe.log_tensor(f"{path}::attn_entropy", entropy)
    probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True)
|
|
|
|
|
|
|
|
@torch.no_grad()
def _compute_attn_stats_flash(
    probe,
    path: str,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor],
    cu_seqlens_k: Optional[torch.Tensor],
    seqused_k: Optional[torch.Tensor],
    max_seqlen_q: int,
    max_seqlen_k: int,
    p: float,
    softmax_scale: float,
    is_causal: bool,
    window_left: int,
    window_right: int,
    return_softmax: bool,
    block_tables: Optional[torch.Tensor],
    unpadded_lse: bool = False,
) -> None:
    """Recompute xformers-flash attention logits for a query subset; log stats.

    Mirrors `_compute_attn_stats_sdpa` for the [B, S, H, D] flash layout
    (possibly packed/varlen via cu_seqlens). Unsupported features (seqused_k,
    dropout, sliding windows, paged KV) only record argument shapes.
    """
    if (
        seqused_k is not None
        or p != 0.0
        or window_left >= 0
        or window_right >= 0
        or block_tables is not None
    ):
        probe.store[f"{path}::attn"] = {
            "query.shape": tuple(query.shape),
            "key.shape": tuple(key.shape),
            "value.shape": tuple(value.shape),
            "op": "flash",
        }
        return
    # Varlen packed input has no batch dim; add a unit batch so the math
    # below is uniform ([1, total_S, H, D]).
    if cu_seqlens_q is not None:
        assert query.ndim == 3, query.shape
        query, key, value = query[None], key[None], value[None]
    assert query.ndim == 4, query.shape
    query_s = _attn_queries_subset(query.shape[1])
    # [B, S, H, D] -> [B, H, S_subset, S_k] scaled logits.
    logits = (
        query[:, query_s].transpose(1, 2)
        @ key.transpose(1, 2).transpose(-1, -2)
        * softmax_scale
    )
    logits = _mask_attn_logits(
        logits.float(),
        query_s,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        causal=is_causal,
    )
    # NOTE(review): `p` (dropout prob parameter, already checked to be 0.0)
    # is shadowed here by the softmax probabilities.
    p = logits.float().softmax(-1)
    # Zero out the -inf (masked) positions so they don't poison the entropy.
    masked_logsoft = logits.log_softmax(-1).where(
        (logits > -math.inf), torch.zeros_like(logits)
    )
    entropy = -(p * masked_logsoft).sum(-1)
    probe.log_tensor(f"{path}::attn_entropy", entropy)
    probe.log_tensor(f"{path}::attn_logits", logits, remove_inf=True)
|
|
|
|
|
|
|
|
def _tensors_to_python(x): |
|
|
if not isinstance(x, torch.Tensor): |
|
|
return x |
|
|
return x.tolist() |
|
|
|
|
|
|
|
|
|
|
|
class LinearBwType(Enum):
    """Which backward-pass GEMM of a linear layer a `mm` call corresponds to."""

    DW = 1  # gradient w.r.t. the weight
    DX = 2  # gradient w.r.t. the input
    UNKNOWN = 3  # could not be classified
|
|
|
|
|
|
|
|
class AutoProbeD(TorchDispatchMode):
    """Dispatch-mode probe that records statistics of activations and gradients.

    While entered, every dispatched op is inspected:
    * linear-layer GEMMs (`aten.mm`/`aten.addmm`) log input/weight/output
      stats in forward, and the matching gradients in backward;
    * SDPA and xformers-flash attention ops log attention logits/entropy;
    * `torchprobe::log` markers (inserted via `log_stats`) log the tagged
      tensor under the module path of their forward call.
    Stats accumulate in `self.store` and are appended to `write_file` as one
    JSON line per probed step when the mode exits. torch.compile is disabled
    while probing so the aten ops are actually dispatched eagerly.
    """

    def __init__(self, module: nn.Module, write_file: Optional[str] = None) -> None:
        self.write_file = Path(write_file) if write_file is not None else None
        # When set, full tensors are also torch.save'd into this directory.
        self.write_tensors_tmpdir: Optional[Path] = None
        self.compile_disabler = TorchCompileDisabler(module)
        # Tracks which nn.Module is currently executing, and fw vs bw.
        self.mod_tracker = ModuleTracker()
        self.count_per_path: Dict[str, int] = defaultdict(int)
        # Maps "path::tag" -> stats dict (see _get_stats).
        self.store: Dict[str, Dict[str, Any]] = {}
        # Per-path (in_shape, w_shape, out_shape, input_sample, weight_sample)
        # captured in FW; used to classify BW GEMMs in _find_bw_path_and_type.
        self.linear_data: Dict[str, Tuple[Any, Any, Any, Any, Any]] = {}
        # Maps a log_stats uid to the module path of its forward call.
        self.uid_to_path: Dict[str, str] = {}
        # Arbitrary caller-provided metadata written alongside the stats.
        self.metadata: Any = None
        self.enabled = False
        self.verbose = bool(int(os.environ.get("PROBE_VERBOSE", "0")))

    def __enter__(self):
        global _PROBING_ENABLED
        assert not self.enabled, "Entered probe twice"
        self.compile_disabler.__enter__()
        self.mod_tracker.__enter__()
        super().__enter__()
        self.enabled = True
        _PROBING_ENABLED = True
        return self

    def __exit__(self, *args) -> None:
        global _PROBING_ENABLED
        assert self.enabled, "Exiting probe without entering it"
        # Tear down in reverse order of __enter__.
        super().__exit__(*args)
        self.mod_tracker.__exit__(*args)
        self.compile_disabler.__exit__(*args)
        # Persist everything gathered during this probed step.
        self._flush_and_clear()
        _PROBING_ENABLED = False
        self.enabled = False

    def _setup_tensors_logging(self):
        # Create a uniquely named temp dir next to `write_file` for raw tensor
        # dumps. NOTE(review): not called anywhere in this file — presumably
        # an opt-in entry point for callers; confirm before removing.
        if self.write_file is not None:
            self.write_file.parent.mkdir(exist_ok=True)
            self.write_tensors_tmpdir = (
                self.write_file.parent
                / f"{self.write_file.name}-tmp-{str(uuid.uuid4())[:8]}"
            )
            self.write_tensors_tmpdir.mkdir(exist_ok=True)

    def _flush_and_clear(self) -> None:
        """Append stored stats to `write_file` (JSON lines) and reset state."""
        if self.write_file is not None:
            # Convert tensor leaves to plain Python so json can serialize.
            dump_data = tree_map(_tensors_to_python, self.store)
            with self.write_file.open("a") as fd:
                json.dump(
                    {
                        "data": dump_data,
                        "meta": self.metadata,
                        "version": 2,
                        "quantiles": QUANTILES,
                    },
                    fd,
                )
                fd.write("\n")
        if self.write_tensors_tmpdir is not None:
            assert self.write_file is not None
            # Move the temp dump dir to a stable name under `<file>-dump/`,
            # e.g. `it0000000001`, adding a `v<N>` suffix on collisions.
            dump_dir = self.write_tensors_tmpdir.parent / f"{self.write_file.name}-dump"
            dump_dir.mkdir(exist_ok=True)
            dir_name = ""
            if "it" in self.metadata:
                dir_name = f"it{int(self.metadata['it']):010}"
            if dir_name == "" or (dump_dir / dir_name).exists():
                num_files = len(list(dump_dir.glob(f"{dir_name}v*")))
                dir_name = f"{dir_name}v{num_files}"
            dump_dir = dump_dir / dir_name
            assert not dump_dir.exists()
            self.write_tensors_tmpdir.rename(dump_dir)
            self.write_tensors_tmpdir = None
        self.store.clear()
        self.count_per_path.clear()
        self.uid_to_path.clear()

    def _find_bw_path_and_type(
        self, path: str, out: torch.Tensor, args
    ) -> Tuple[str, LinearBwType]:
        """
        We are in the BW pass, and process a GEMM.
        Let's figure out:
        (1) The path for the FW pass (might differ in case of ModuleTracker bug)
        (2) The type of BW pass (eg `dw` or `dx`)
        """

        def _is_path_correct_dw(path: str) -> bool:
            # dW GEMM: output has the transposed weight shape, and the second
            # operand matches the input sample saved during FW.
            in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path]
            return out.shape == (w_shape[1], w_shape[0]) and torch.allclose(
                input_sm, args[1][:4, :4]
            )

        def _is_path_correct_dx(path: str) -> bool:
            # dX GEMM: output has the input shape, and the second operand
            # matches the transposed weight sample saved during FW.
            in_shape, w_shape, out_shape, input_sm, weight_sm = self.linear_data[path]
            return out.shape == in_shape and torch.allclose(weight_sm, args[1][:4, :4])

        if path in self.linear_data:
            if _is_path_correct_dw(path):
                return path, LinearBwType.DW
            if _is_path_correct_dx(path):
                return path, LinearBwType.DX
        # The tracked path doesn't match this GEMM: try every module currently
        # on the tracker stack and take the first whose FW data matches.
        for candidate_path in self.mod_tracker.parents:
            if candidate_path not in self.linear_data:
                continue
            if _is_path_correct_dw(candidate_path):
                return candidate_path, LinearBwType.DW
            if _is_path_correct_dx(candidate_path):
                return candidate_path, LinearBwType.DX
        return path, LinearBwType.UNKNOWN

    def log_tensor(self, name: str, x: torch.Tensor, **kwargs) -> None:
        """Record stats for `x` under `name`; optionally dump the raw tensor."""
        self.store[name] = _get_stats(x, **kwargs)
        if self.write_tensors_tmpdir is not None:
            name_safe = name.replace("::", "__").replace("/", "")
            torch.save(x, self.write_tensors_tmpdir / f"{name_safe}.pkl")

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        path = None
        # Attribute the op to the deepest (longest) module path on the stack.
        for p in self.mod_tracker.parents:
            if p == "Global":
                continue
            if path is None or len(p) > len(path):
                path = p
        if path is None:
            path = "Global"
        path = path.replace("._checkpoint_wrapped_module", "")
        out = func(*args, **kwargs)
        if func._overloadpacket in [torch.ops.aten.addmm, torch.ops.aten.mm]:
            weight: torch.Tensor
            input: torch.Tensor
            if not self.mod_tracker.is_bw:
                # Forward GEMM of a linear layer: log operands and remember
                # small samples so BW GEMMs can be matched back to this path.
                if func._overloadpacket == torch.ops.aten.addmm:
                    _bias, input, weight = args[:3]
                else:
                    assert func._overloadpacket == torch.ops.aten.mm
                    input, weight = args[:2]
                self.log_tensor(f"{path}::in", input)
                self.log_tensor(f"{path}::w", weight)
                self.log_tensor(f"{path}::out", out)
                self.linear_data[path] = (
                    input.shape,
                    weight.shape,
                    out.shape,
                    input[:4, :4].clone(),
                    weight[:4, :4].T.clone(),
                )
            elif func._overloadpacket == torch.ops.aten.mm:
                # Backward GEMM: classify as dW or dX, fixing the path if the
                # module tracker attributed it to the wrong module.
                new_path, bwtype = self._find_bw_path_and_type(path, out, args)
                if new_path != path:
                    if self.verbose:
                        # NOTE(review): missing closing backtick in message.
                        print(f"E: Fixing path `{path}` -> `{new_path}")
                    path = new_path
                if bwtype == LinearBwType.DW:
                    self.log_tensor(f"{path}::w.g", out)
                elif bwtype == LinearBwType.DX:
                    self.log_tensor(f"{path}::in.g", out)
                    self.log_tensor(f"{path}::out.g", args[0])
        elif func._overloadpacket in [
            torch.ops.aten._scaled_dot_product_flash_attention,
            torch.ops.aten._scaled_dot_product_cudnn_attention,
        ]:
            _, kwargs = normalize_function(
                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
            )
            _compute_attn_stats_sdpa(self, path, **kwargs)
        elif func._overloadpacket == fmha.flash.FwOp.OPERATOR:
            _, kwargs = normalize_function(
                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
            )
            _compute_attn_stats_flash(self, path, **kwargs)
        elif func._overloadpacket == torch.ops.torchprobe.log:
            # Marker inserted by log_stats: reuse the FW path for this uid so
            # the BW gradient marker lands under the same name.
            uid = args[2]
            path = self.uid_to_path.setdefault(uid, path)
            self.log_tensor(f"{path}::{args[1]}", args[0])
        if self.verbose:
            print(f"{'[BW]' if self.mod_tracker.is_bw else '[FW]'} `{path}`: {func}")
        return out
|
|
|
|
|
|
|
|
def _find_all_submodules_compiled(out: List[nn.Module], module: nn.Module) -> None: |
|
|
if module._compiled_call_impl is not None: |
|
|
out.append(module) |
|
|
for c in module.children(): |
|
|
_find_all_submodules_compiled(out, module=c) |
|
|
|
|
|
|
|
|
class TorchCompileDisabler:
    """Context manager that temporarily disables torch.compile on a module.

    On enter, every submodule with a `_compiled_call_impl` has it stashed and
    set to None (so calls fall back to the eager path), and dynamo is
    disabled; on exit everything is restored.
    """

    def __init__(self, module: nn.Module) -> None:
        self.module = module
        self.submodules_compiled: List[nn.Module] = []
        # Saved `_compiled_call_impl`s, parallel to `submodules_compiled`.
        self.compiled_call_impl: List[Any] = []
        self.disable_compile = torch.compiler.disable()
        # NOTE(review): global dynamo-config mutation at construction time,
        # never restored; needed so `torch.compiler.disable()` can be used as
        # a context manager — confirm this is intentional.
        torch._dynamo.config.raise_on_ctx_manager_usage = False

    def __enter__(self) -> None:
        # Snapshot and strip the compiled entry points so forward runs eagerly
        # (and hence hits __torch_dispatch__ op-by-op).
        self.submodules_compiled.clear()
        _find_all_submodules_compiled(self.submodules_compiled, self.module)
        self.compiled_call_impl = [
            m._compiled_call_impl for m in self.submodules_compiled
        ]
        for m in self.submodules_compiled:
            m._compiled_call_impl = None
        self.disable_compile.__enter__()

    def __exit__(self, *args) -> None:
        self.disable_compile.__exit__(*args)
        # Restore the compiled entry points captured in __enter__.
        for m, c_impl in zip(self.submodules_compiled, self.compiled_call_impl):
            m._compiled_call_impl = c_impl
        self.compiled_call_impl = []
|
|
|
|
|
|
|
|
# Public alias for the probe implementation.
Probe = AutoProbeD
|
|
|
|
|
|
|
|
# Toy-model hyperparameters used by the modules and tests below.
d = 512  # model width
seqlen = 4  # sequence length
bs = 2  # batch size
|
|
|
|
|
|
|
|
class Attention1(nn.Module):
    """Self-attention with a bottom-right-aligned lower-triangular mask (xformers)."""

    def forward(self, x):
        bias = fmha.attn_bias.LowerTriangularFromBottomRightMask()
        out = fmha.memory_efficient_attention(x, x, x, attn_bias=bias)
        # Merge heads back: [B, S, H, D] -> [B, S, H * D].
        return out.reshape([x.shape[0], seqlen, -1])
|
|
|
|
|
|
|
|
class Attention2(nn.Module):
    """Causal self-attention over a block-diagonal (batched) layout (xformers)."""

    def forward(self, x):
        # One causal block per batch element; this mask layout requires the
        # batch to be flattened into the sequence axis.
        attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(
            [seqlen] * bs
        ).make_causal()
        # Was hard-coded `2 * seqlen`; use `bs * seqlen` so the reshape stays
        # consistent with the `[seqlen] * bs` mask above (bs == 2 today, so
        # behavior is unchanged).
        xr = x.reshape([1, bs * seqlen, x.shape[2], x.shape[3]])
        return fmha.memory_efficient_attention(xr, xr, xr, attn_bias=attn_bias).reshape(
            [x.shape[0], seqlen, -1]
        )
|
|
|
|
|
|
|
|
class AttentionSDPA(nn.Module):
    """SDPA-based self-attention followed by an output projection."""

    def __init__(self):
        super().__init__()
        self.wo = nn.Linear(d, d)

    def forward(self, x):
        # [B, S, H, D] -> [B, H, S, D] for scaled_dot_product_attention.
        xh = x.transpose(1, 2)
        attn = F.scaled_dot_product_attention(xh, xh, xh)
        merged = attn.transpose(1, 2).reshape([xh.shape[0], seqlen, -1])
        return self.wo(merged)
|
|
|
|
|
|
|
|
class AttentionSDPAFlash(AttentionSDPA):
    """Same as AttentionSDPA, but forces the FLASH_ATTENTION SDPA backend."""

    def forward(self, x):
        xh = x.transpose(1, 2)
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            attn = F.scaled_dot_product_attention(xh, xh, xh)
            merged = attn.transpose(1, 2).reshape([xh.shape[0], seqlen, -1])
            return self.wo(merged)
|
|
|
|
|
|
|
|
class Model(nn.Module):
    """Toy model mixing linears, several attention variants, and torch.compile."""

    def __init__(self) -> None:
        super().__init__()
        # NOTE: submodule registration order is significant for parameter
        # ordering; keep it as-is.
        self.head = nn.Linear(d, 16)
        self.trunk = nn.Sequential(
            nn.Linear(d, d),
            nn.Linear(d, d),
        )
        self.q_proj = nn.Linear(d, d, bias=False)
        self.trunk.compile()
        self.attn1 = Attention1()
        self.attn2 = Attention2()
        self.attnSDPA = AttentionSDPA()
        self.attnSDPAflash = AttentionSDPAFlash()

    def forward(self, x):
        batch, n_heads, head_dim = x.shape[0], d // 64, 64
        # Project and split into heads: [B, S, H, D].
        q = self.q_proj(x).reshape([batch, seqlen, n_heads, head_dim])
        mixed = self.attn1(q) + self.attn2(q) + self.attnSDPA(q) + self.attnSDPAflash(q)
        # Probe marker: records stats of the summed attention outputs.
        mixed = log_stats(mixed, "attns_out")
        return self.head(self.trunk(mixed))
|
|
|
|
|
|
|
|
def test_masking() -> None:
    """_mask_attn_logits must agree with xformers' materialized attention bias."""
    q_lens = [1, 1, 14, 12]
    k_lens = [2, 2, 14, 18]
    bias_def = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
        q_lens, k_lens
    ).make_causal_from_bottomright()
    scores = torch.randn(
        [1, 1, sum(q_lens), sum(k_lens)], dtype=torch.float32, device="cuda"
    )
    dense_bias = bias_def.materialize(
        scores.shape, dtype=scores.dtype, device=scores.device
    )
    masked = scores.clone()
    _mask_attn_logits(
        masked,
        list(range(scores.shape[2])),
        causal=True,
        cu_seqlens_q=bias_def.q_seqinfo.seqstart,
        cu_seqlens_k=bias_def.k_seqinfo.seqstart,
    )
    # Adding the dense bias (-inf off-block/future) must equal our masking.
    assert (scores + dense_bias == masked).all().item()
|
|
|
|
|
|
|
|
def test_toy_model() -> None:
    """End-to-end probe test: run a few steps, probe one, check logged keys/stats."""
    kw = dict(device="cuda", dtype=torch.float16)
    x = torch.randn([bs, seqlen, d], **kw)
    m = Model()
    # Wrap the head in activation checkpointing to exercise the
    # `_checkpoint_wrapped_module` path-cleanup in the probe.
    m.head = checkpoint_wrapper(
        m.head, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
    )
    m.to(**kw)
    m.compile()
    # lr=0 so parameters don't change between probed/unprobed steps.
    optim = torch.optim.SGD(m.parameters(), lr=0.0)
    probe = AutoProbeD(m, "./probe.json")

    for i in range(4):
        with contextlib.ExitStack() as stack:
            print(f"########### STEP {i}")
            # Only step 1 runs under the probe.
            if i % 4 == 1:
                stack.enter_context(probe)
                probe.metadata = {"it": i}
            y = m(x)
            g = torch.randn_like(y)
            y.backward(g)
            if i % 4 == 1:
                assert probe.enabled
                print(list(probe.store.keys()))
                # Every expected activation/gradient entry must be present.
                for key in [
                    "Model::attns_out",
                    "Model::attns_out.g",
                    "Model.attn1::attn_logits",
                    "Model.attn2::attn_logits",
                    "Model.attnSDPA::attn_logits",
                    "Model.attnSDPAflash::attn_logits",
                    "Model.head::w",
                    "Model.head::w.g",
                    "Model.head::in",
                    "Model.head::in.g",
                    "Model.head::out",
                    "Model.head::out.g",
                    "Model.trunk.0::in",
                    "Model.trunk.1::in",
                ]:
                    assert key in probe.store, f"Missing key: '{key}'"
                # Spot-check logged stats against the actual tensors.
                for key, tensor in [
                    ("Model.head::w", m.head.weight),
                    ("Model.head::w.g", m.head.weight.grad),
                    ("Model.q_proj::in", x),
                    ("Model.q_proj::w.g", m.q_proj.weight.grad),
                    ("Model.head::out", y),
                    ("Model.head::out.g", g),
                ]:
                    assert key in probe.store, f"Missing key: '{key}'"
                    assert torch.allclose(
                        probe.store[key]["abs.mean"], tensor.float().abs().mean()
                    ), f"'{key}' mismatches"
                # Nothing logged should have overflowed to inf/nan.
                for key, value in probe.store.items():
                    if "abs.mean" in value:
                        assert math.isfinite(
                            value["abs.mean"].item()
                        ), f"Inf/Nan for {key}"
            optim.step()
            optim.zero_grad()
|
|
|