Spaces:
Sleeping
Sleeping
| import math | |
| import torch | |
| from torch import Tensor | |
| # Module-level counters for attention visualization; not read by the function visible below - presumably updated by code elsewhere (TODO confirm) | |
| step = 0 | |
| global_timestep = 0 | |
def scaled_dot_product_average_attention_map(query, key, attn_mask=None, is_causal=False, scale=None) -> torch.Tensor:
    """Return scaled, masked query-key scores averaged over dims 1 and 2.

    Computes ``query @ key.transpose(-2, -1) * scale``, adds an additive
    bias built from ``is_causal`` / ``attn_mask``, and reduces the result
    with ``mean(dim=(1, 2))``.

    NOTE(review): no softmax is applied, so the returned values are raw
    (biased) scores, and any ``-inf`` introduced by a mask propagates into
    the mean of the affected key column. Confirm this is intended.

    Args:
        query: Tensor whose last two dims are (L, E). The final reduction
            needs at least three dims; presumably (batch, heads, L, E),
            which yields a (batch, S) result -- TODO confirm with callers.
        key: Tensor whose last two dims are (S, E).
        attn_mask: Optional mask broadcastable to the score tensor. A
            boolean mask keeps positions that are ``True`` and sets the
            others to ``-inf``; any other dtype is added to the scores.
            The mask itself is never modified.
        is_causal: If true, positions above the main diagonal of the
            (L, S) score matrix are set to ``-inf``. Must not be combined
            with ``attn_mask``.
        scale: Multiplier for the scores; defaults to ``1 / sqrt(E)``.

    Returns:
        The biased score tensor averaged over dimensions 1 and 2.

    Raises:
        AssertionError: If ``is_causal`` is set together with ``attn_mask``.
    """
    num_queries, num_keys = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

    # Build the bias on the query's device and dtype so it can be combined
    # with the scores (and with a user mask) without a device mismatch.
    attn_bias = torch.zeros(num_queries, num_keys, dtype=query.dtype, device=query.device)

    if is_causal:
        assert attn_mask is None
        causal_keep = torch.ones(
            num_queries, num_keys, dtype=torch.bool, device=query.device
        ).tril(diagonal=0)
        attn_bias.masked_fill_(causal_keep.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Turn the boolean mask into an additive bias instead of writing
            # into the caller's tensor; out-of-place addition also lets masks
            # with leading broadcast dimensions work.
            keep = attn_mask.to(query.device)
            mask_bias = torch.zeros(keep.shape, dtype=query.dtype, device=query.device)
            mask_bias.masked_fill_(keep.logical_not(), float("-inf"))
            attn_bias = attn_bias + mask_bias
        else:
            attn_bias = attn_bias + attn_mask.to(device=query.device, dtype=query.dtype)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias
    return attn_weight.mean(dim=(1, 2))