import math

import torch
from torch import Tensor

# Global variables for attention visualization
step = 0
global_timestep = 0


def scaled_dot_product_average_attention_map(
    query, key, attn_mask=None, is_causal=False, scale=None
) -> torch.Tensor:
    """Return scaled, masked attention logits averaged over dims 1 and 2.

    Computes ``query @ key^T * scale``, adds an additive bias built from
    either the causal mask or ``attn_mask``, and reduces the result with
    ``mean(dim=(1, 2))``.

    NOTE(review): no softmax is applied, so the returned values are averaged
    raw logits; positions masked with ``-inf`` in any averaged row therefore
    average to ``-inf``. Confirm this is what the visualization expects.

    Args:
        query: Tensor whose last two dims are ``(L, E)``. Presumably laid out
            as ``(batch, heads, L, E)`` so that dims 1 and 2 are heads and
            query positions -- TODO confirm against callers.
        key: Tensor whose last two dims are ``(S, E)``.
        attn_mask: Optional mask broadcastable to the logits. A boolean mask
            keeps positions that are ``True`` and sets the others to
            ``-inf``; any other dtype is added to the logits. The tensor
            passed in is never modified.
        is_causal: If true, mask out positions above the main diagonal
            (key index greater than query index). Must not be combined with
            ``attn_mask``.
        scale: Multiplier for the logits; defaults to ``1 / sqrt(E)``.

    Returns:
        The logits with dims 1 and 2 averaged out.

    Raises:
        AssertionError: If ``is_causal`` is set together with ``attn_mask``.
    """
    num_queries, num_keys = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

    # Build the bias next to the data so masks living on an accelerator can
    # be combined with it without a cross-device error.
    attn_bias = torch.zeros(
        num_queries, num_keys, dtype=query.dtype, device=query.device
    )

    if is_causal:
        assert attn_mask is None, "attn_mask must be None when is_causal is set"
        keep = torch.ones(
            num_queries, num_keys, dtype=torch.bool, device=query.device
        ).tril(diagonal=0)
        attn_bias.masked_fill_(keep.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Translate the boolean mask into an additive bias held in a fresh
            # tensor: the caller's mask stays untouched, and masks carrying
            # extra leading (batch/head) dims are supported.
            mask_bias = torch.zeros(
                attn_mask.shape, dtype=query.dtype, device=query.device
            )
            mask_bias.masked_fill_(
                attn_mask.to(query.device).logical_not(), float("-inf")
            )
        else:
            mask_bias = attn_mask.to(dtype=query.dtype, device=query.device)
        # Out-of-place add so the (L, S) bias may broadcast up to the mask's
        # shape; an in-place add cannot grow its left-hand side.
        attn_bias = attn_bias + mask_bias

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias.to(
        dtype=attn_weight.dtype, device=attn_weight.device
    )
    return attn_weight.mean(dim=(1, 2))