# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import requests
from PIL import Image

from ..masking_utils import create_causal_mask
from ..models.auto.auto_factory import _get_model_class
from ..models.auto.configuration_auto import AutoConfig
from ..models.auto.modeling_auto import MODEL_FOR_PRETRAINING_MAPPING, MODEL_MAPPING
from ..models.auto.processing_auto import PROCESSOR_MAPPING_NAMES, AutoProcessor
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES, AutoTokenizer
from .import_utils import is_torch_available


if is_torch_available():
    import torch
    import torch.nn as nn


# ANSI color codes and glyphs used to print the matrix with words as row labels
GREEN = "\033[92m"
YELLOW = "\033[93m"
RESET = "\033[0m"
BLACK_SQUARE = "■"
WHITE_SQUARE = "⬚"


def generate_attention_matrix_from_mask(
    words, mask, img_token="<img>", sliding_window=None, token_type_ids=None, image_seq_length=None
):
    """
    Renders a printable attention-matrix diagram from a given attention mask.

    Optionally applies a sliding window mask (e.g., for Gemma2/3) and marks regions where image tokens occur based
    on the specified `img_token`.
    """
    mask = mask.int()
    if mask.ndim == 3:
        mask = mask[0, :, :]
    if mask.ndim == 4:
        mask = mask[0, 0, :, :]
    n = len(words)
    max_word_length = max(len(repr(word)) for word in words)

    first_img_idx = 0
    output = []
    for i, k in enumerate(words):
        if k == img_token and not first_img_idx:
            first_img_idx = i
            mask[i, i] = 2  # Mark yellow regions
        if first_img_idx > 0 and (k != img_token or i == n - 1):
            if i == n - 1:
                i += 1  # Include the final token in the image block
            mask[first_img_idx:i, first_img_idx:i] = 2  # Mark yellow regions
            first_img_idx = 0
    # Generate the sliding window mask (width = `sliding_window`), excluding img_token
    sliding_window_mask = None
    if sliding_window is not None:
        sliding_window_mask = [[1 if (0 <= i - j < sliding_window) else 0 for j in range(n)] for i in range(n)]

    # Dummy row, used only to measure the rendered width of one matrix row
    row_dummy = " ".join(
        f"{YELLOW}{BLACK_SQUARE}{RESET}"
        if mask[0, j]
        else f"{GREEN}{BLACK_SQUARE}{RESET}"
        if j == 0
        else BLACK_SQUARE
        if mask[0, j]
        else WHITE_SQUARE
        for j in range(n)
    )
    # Bucketize token types so image blocks can be matched up in the sliding window view;
    # initialize to None so the render loop below can check whether buckets exist
    token_type_buckets = None
    if token_type_ids is not None:
        is_special = token_type_ids == 1
        token_type_buckets = torch.where(
            (token_type_ids.cumsum(-1) % 5 + is_special).bool(), token_type_ids.cumsum(-1), 0
        )
        boundaries = torch.arange(0, image_seq_length + 1, image_seq_length)
        token_type_buckets = torch.bucketize(token_type_buckets, boundaries=boundaries)

    # Print headers
    legend = f"{GREEN}{BLACK_SQUARE}{RESET}: i == j (diagonal) {YELLOW}{BLACK_SQUARE}{RESET}: token_type_ids"
    output.append(" " + legend)
    f_string = " " * (max_word_length + 5) + "Attention Matrix".ljust(len(row_dummy) // 2)
    if sliding_window is not None:
        f_string += "Sliding Window Mask"
    output.append(f_string)

    # Column indices, rendered vertically one digit per line
    vertical_header = []
    for idx in range(n):
        if mask[idx, idx] == 2:
            vertical_header.append([f"{YELLOW}{k}{RESET}" for k in list(str(idx).rjust(len(str(n))))])
        else:
            vertical_header.append(list(str(idx).rjust(len(str(n)))))
    vertical_header = list(map(list, zip(*vertical_header)))  # Transpose

    for row in vertical_header:
        # Always print the column indices; duplicate them over the sliding window view when present
        header_line = (max_word_length + 5) * " " + " ".join(row)
        if sliding_window is not None:
            header_line += " | " + " ".join(row)
        output.append(header_line)
    for i, word in enumerate(words):
        word_repr = repr(word).ljust(max_word_length)
        colored_word = f"{YELLOW}{word_repr}{RESET}" if img_token in word else word_repr
        row_display = " ".join(
            f"{YELLOW}{BLACK_SQUARE}{RESET}"
            if img_token in words[j] and mask[i, j] and img_token in word
            else f"{GREEN}{BLACK_SQUARE}{RESET}"
            if i == j
            else BLACK_SQUARE
            if mask[i, j]
            else WHITE_SQUARE
            for j in range(n)
        )

        sliding_window_row = ""
        if sliding_window is not None:
            sliding_window_row = " ".join(
                f"{YELLOW}{BLACK_SQUARE}{RESET}"
                if token_type_buckets is not None
                and img_token in words[j]
                and img_token in word
                and token_type_buckets[0, i] == token_type_buckets[0, j]
                else f"{GREEN}{BLACK_SQUARE}{RESET}"
                if i == j
                else BLACK_SQUARE
                if sliding_window_mask[i][j]
                else WHITE_SQUARE
                for j in range(n)
            )

        line = f"{colored_word}: {str(i).rjust(2)} {row_display}"
        if sliding_window is not None:
            line += f" | {sliding_window_row}"
        output.append(line)

    return "\n".join(output)
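

# Illustrative usage of `generate_attention_matrix_from_mask` on a toy causal
# mask (a sketch; the token strings are made up, not produced by a real
# tokenizer):
#
#     import torch
#
#     words = ["<s>", "Hello", "world", "!"]
#     mask = torch.tril(torch.ones(len(words), len(words), dtype=torch.int))
#     print(generate_attention_matrix_from_mask(words, mask))
#
# Green squares mark the diagonal (i == j), black squares attended positions,
# and dotted squares masked-out positions.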


class AttentionMaskVisualizer:
    def __init__(self, model_name: str):
        config = AutoConfig.from_pretrained(model_name)
        self.image_token = "<img>"
        # `getattr` with a default also covers text configs without a sliding window
        self.sliding_window = getattr(config.get_text_config(), "sliding_window", None)

        try:
            mapped_cls = _get_model_class(config, MODEL_MAPPING)
        except Exception:
            mapped_cls = _get_model_class(config, MODEL_FOR_PRETRAINING_MAPPING)
        if mapped_cls is None:
            raise ValueError(f"Model name {model_name} is not supported for attention visualization")
        self.mapped_cls = mapped_cls

        # Lightweight wrapper: only the config and mask-building logic are needed,
        # so the expensive weight-initializing `mapped_cls.__init__` is skipped
        class _ModelWrapper(mapped_cls, nn.Module):
            def __init__(self, config, model_name):
                nn.Module.__init__(self)
                self.dummy_module = nn.Linear(1, 1)
                self.config = config

        self.model = _ModelWrapper(config, model_name)
        self.model.to(config.dtype)
        self.repo_id = model_name
        self.config = config
    def __call__(self, input_sentence: str, suffix=""):
        self.visualize_attention_mask(input_sentence, suffix=suffix)
    def visualize_attention_mask(self, input_sentence: str, suffix=""):
        model = self.model
        kwargs = {}
        image_seq_length = None
        if self.config.model_type in PROCESSOR_MAPPING_NAMES:
            img = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
            img = Image.open(requests.get(img, stream=True).raw)
            image_seq_length = 5
            processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=image_seq_length)
            if hasattr(processor, "image_token"):
                image_token = processor.image_token
            else:
                image_token = processor.tokenizer.convert_ids_to_tokens([processor.image_token_id])[0]
            if image_token:
                input_sentence = input_sentence.replace("<img>", image_token)

            inputs = processor(images=img, text=input_sentence, suffix=suffix, return_tensors="pt")

            self.image_token = processor.tokenizer.convert_ids_to_tokens([processor.image_token_id])[0]
            attention_mask = inputs["attention_mask"]
            if "token_type_ids" in inputs:  # TODO inspect signature of update causal mask
                kwargs["token_type_ids"] = inputs["token_type_ids"]
            tokens = processor.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        elif self.config.model_type in TOKENIZER_MAPPING_NAMES:
            tokenizer = AutoTokenizer.from_pretrained(self.repo_id)
            tokens = tokenizer.tokenize(input_sentence)
            attention_mask = tokenizer(input_sentence, return_tensors="pt")["attention_mask"]
        else:
            raise ValueError(f"Model type {self.config.model_type} does not support attention visualization")
        model.config._attn_implementation = "eager"
        model.train()

        batch_size, seq_length = attention_mask.shape
        input_embeds = torch.zeros((batch_size, seq_length, model.config.hidden_size), dtype=self.model.dtype)
        cache_position = torch.arange(seq_length)
        causal_mask = create_causal_mask(
            config=model.config,
            input_embeds=input_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=None,
        )
        if causal_mask is not None:
            attention_mask = ~causal_mask.bool()
        else:
            # No 4D mask was materialized (e.g. a fully causal fast path): expand the 2D padding mask instead
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, 1, seq_length, seq_length)

        top_bottom_border = "##" * (
            len(f"Attention visualization for {self.config.model_type} | {self.mapped_cls}") + 4
        )  # Box width adjusted to the title length
        side_border = "##"
        print(f"\n{top_bottom_border}")
        print(
            "##"
            + f" Attention visualization for \033[1m{self.config.model_type}:{self.repo_id}\033[0m {self.mapped_cls.__name__}".center(
                len(top_bottom_border)
            )
            + " "
            + side_border,
        )
        print(f"{top_bottom_border}")
        f_string = generate_attention_matrix_from_mask(
            tokens,
            attention_mask,
            img_token=self.image_token,
            sliding_window=self.sliding_window,  # computed from the text config in __init__
            token_type_ids=kwargs.get("token_type_ids"),
            image_seq_length=image_seq_length,
        )
        print(f_string)
        print(f"{top_bottom_border}")
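

# A minimal usage sketch (the checkpoint names are illustrative; any checkpoint
# whose model type has an Auto tokenizer or processor mapping should work):
#
#     from transformers.utils.attention_visualizer import AttentionMaskVisualizer
#
#     visualizer = AttentionMaskVisualizer("google/gemma-2-2b")
#     visualizer("You are an assistant. Make sure you print me")
#
# For an image-text model, include the `<img>` placeholder so the image-token
# block is highlighted:
#
#     visualizer = AttentionMaskVisualizer("google/paligemma2-3b-mix-224")
#     visualizer("<img> What is in this image?")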