Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from collections.abc import Iterable | |
| from typing import Any | |
| import torch | |
| from app.core.model_support import get_decoder_layers | |
# Module-level store filled by the forward hooks built in make_attn_hook():
# maps a layer index to the 4D tensor captured from that layer's attention
# module on the most recent forward pass.
_ATTENTION_STORE: dict[int, torch.Tensor] = {}
def clear_stored_attentions() -> None:
    """Drop every captured attention tensor from the module-level store.

    The store is emptied in place, so any reference to the store object
    itself observes the change.
    """
    _ATTENTION_STORE.clear()
def get_stored_attentions() -> dict[int, torch.Tensor]:
    """Return a shallow copy of the captured attentions, keyed by layer index.

    The returned dict is a new object, so adding or removing keys on it does
    not affect the module-level store; the tensors themselves are shared.
    """
    snapshot = _ATTENTION_STORE.copy()
    return snapshot
def _extract_attention_tensor(output: Any) -> torch.Tensor | None:
    """Return the first 4D tensor found in a module output, or ``None``.

    Lookup order:
      1. ``output`` itself, when it is a tensor (non-4D tensors yield ``None``).
      2. The values of ``output`` when it is a dict.
      3. The items of ``output`` when it is any other non-string iterable.
         A dict whose values held no match is also scanned here, which
         iterates its keys.

    Only the top level is inspected; nested containers are not searched.
    """

    def first_4d(candidates: Any) -> torch.Tensor | None:
        # Lazily scan and stop at the first match, like an early-return loop.
        matches = (
            candidate
            for candidate in candidates
            if isinstance(candidate, torch.Tensor) and candidate.dim() == 4
        )
        return next(matches, None)

    if isinstance(output, torch.Tensor):
        return first_4d((output,))

    if isinstance(output, dict):
        from_values = first_4d(output.values())
        if from_values is not None:
            return from_values

    is_text = isinstance(output, (str, bytes))
    if not is_text and isinstance(output, Iterable):
        return first_4d(output)

    return None
def _get_attention_impl(model: Any) -> str | None:
    """Look up the attention implementation recorded on ``model.config``.

    The private ``_attn_implementation`` attribute is preferred; when it is
    missing or falsy, the public ``attn_implementation`` attribute is
    returned instead. ``None`` is returned when the model has no config or
    neither attribute is set.
    """
    cfg = getattr(model, "config", None)
    if cfg is None:
        return None

    private_impl = getattr(cfg, "_attn_implementation", None)
    if private_impl:
        return private_impl

    return getattr(cfg, "attn_implementation", None)
def make_attn_hook(layer_idx: int):
    """Build a forward hook that records one layer's attention tensor.

    Args:
        layer_idx: Key under which the captured tensor is stored in
            ``_ATTENTION_STORE``.

    Returns:
        A callable with the ``(module, inputs, output)`` forward-hook
        signature. It always returns ``None``.

    The hook raises ``RuntimeError`` if the extracted tensor is not 4D.
    """

    def hook(_module: Any, _inputs: Any, output: Any) -> None:
        attn = _extract_attention_tensor(output)
        if attn is None:
            # No 4D tensor in the module output; nothing to record.
            return
        if attn.dim() != 4:
            # Defensive check in case the extractor's contract changes.
            raise RuntimeError(f"Expected 4D attention tensor at layer {layer_idx}, got {attn.shape}.")
        # retain_grad() raises on tensors that do not require grad (e.g. a
        # forward pass under torch.no_grad()), which would abort the model's
        # forward from inside the hook. Still store the tensor in that case.
        if attn.requires_grad:
            attn.retain_grad()
        _ATTENTION_STORE[layer_idx] = attn

    return hook
def register_hooks(model: Any) -> list[Any]:
    """Attach an attention-capturing forward hook to every decoder layer.

    The attention store is cleared first. The layers and the name of their
    attention attribute come from ``get_decoder_layers(model)``.

    Args:
        model: Model whose decoder layers should be instrumented.

    Returns:
        The handles returned by ``register_forward_hook``, in layer order.
        Pass them to ``remove_hooks`` to detach the hooks.

    Raises:
        RuntimeError: If a layer does not expose the attention attribute.
            Hooks registered on earlier layers are removed before the error
            propagates, so a failed call leaves no hooks on the model.
    """
    clear_stored_attentions()
    layers, _layer_path, attention_attr = get_decoder_layers(model)
    handles: list[Any] = []
    try:
        for layer_idx, layer in enumerate(layers):
            self_attn = getattr(layer, attention_attr, None)
            if self_attn is None:
                raise RuntimeError(f"Layer {layer_idx} does not expose {attention_attr}.")
            handles.append(self_attn.register_forward_hook(make_attn_hook(layer_idx)))
    except Exception:
        # The caller never receives these handles, so detach them here
        # instead of leaking the hooks on the model.
        for handle in handles:
            handle.remove()
        raise
    return handles
def remove_hooks(handles: list[Any]) -> None:
    """Detach previously registered hooks and empty the attention store.

    Args:
        handles: Objects exposing ``remove()``, such as the list returned
            by ``register_hooks``. They are removed in list order.
    """
    for hook_handle in handles:
        hook_handle.remove()

    # Captured tensors are dropped once the hooks are detached.
    clear_stored_attentions()