import os
import logging
import copy
import gradio as gr
import torch
import re
from torchvision.transforms import GaussianBlur


from einops import rearrange
from modules import shared, script_callbacks
from modules.images import get_next_sequence_number
from modules.processing import StableDiffusionProcessing
from scripts.ui_wrapper import UIWrapper, arg
from scripts.incant_utils import module_hooks, plot_tools, prompt_utils

logger = logging.getLogger(__name__)


module_field_map = {
    'savemaps': True,
    'savemaps_batch': None,
    'savemaps_step': None,
    'savemaps_save_steps': None,
}
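# The fields above are attached to every hooked attention module by hook_modules() and
# removed again by unhook_modules(); 'savemaps_batch' accumulates the saved maps.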


SUBMODULES = ['to_q', 'to_k', 'to_v']


class SaveAttentionMapsScript(UIWrapper):
    def __init__(self):
        self.infotext_fields: list = []
        self.paste_field_names: list = []

    def title(self) -> str:
        return "Save Attention Maps"
    
    def setup_ui(self, is_img2img) -> list:
        with gr.Accordion('Save Attention Maps', open = False):
            with gr.Row():
                active = gr.Checkbox(label = 'Active', value = False)
                map_types = gr.CheckboxGroup(
                    label = 'Map Types',
                    choices = ['One-Hot Map', 'Per-Token Maps'],
                    value = ['One-Hot Map'],
                    info = 'Select the type of attention maps to save.',
                )
            export_folder = gr.Textbox(visible=False, label = 'Export Folder', value = 'attention_maps', info = 'Folder to save attention maps to as a subdirectory of the outputs.')
            module_name_filter = gr.Textbox(label = 'Module Names', value = 'input_blocks_5_1_transformer_blocks_0_attn2', info = 'Module name to save attention maps for. If the substring is found in the module name, the attention maps will be saved for that module.')
            class_name_filter = gr.Textbox(label = 'Class Name Filter', value = 'CrossAttention', info = 'Filters eligible modules by the class name.')
            save_every_n_step = gr.Slider(label = 'Save Every N Step', value = 0, minimum = 0, maximum = 100, step = 1, info = 'Save attention maps every N steps. 0 to save last step.')
            print_modules = gr.Button(value = 'Print Modules To Console')
            print_modules.click(self.print_modules, inputs=[module_name_filter, class_name_filter])

        self.infotext_fields = []
        self.paste_field_names = []

        opts = [active, module_name_filter, class_name_filter, save_every_n_step, map_types]
        for opt in opts:
            opt.do_not_save_to_config = True
        return opts
    
    def before_process_batch(self, p: StableDiffusionProcessing, active, module_name_filter, class_name_filter, save_every_n_step, map_types, *args, **kwargs):
        # Always unhook the modules first
        module_list = self.get_modules_by_filter(module_name_filter, class_name_filter)
        script_callbacks.remove_current_script_callbacks()
        self.unhook_modules(module_list, copy.deepcopy(module_field_map))

        setattr(p, 'savemaps_module_list', module_list)
        setattr(p, 'savemaps_map_types', map_types)

        if not active:
            return
        
        token_count, _ = prompt_utils.get_token_count(p.prompt, p.steps, True)

        if token_count <= 0:
            logger.warning("No tokens found in prompt. Skipping saving attention maps.")
            return

        setattr(p, 'savemaps_token_count', token_count)
        setattr(p, 'savemaps_step', 0)

        token_indices = []
        # Tokenize/decode the prompts
        tokenized_prompts = []
        batch_chunks, _ = prompt_utils.tokenize_prompt(p.prompt)
        for batch in batch_chunks:
            for sub_batch in batch:
                tokenized_prompts.append(prompt_utils.decode_tokenized_prompt(sub_batch.tokens))
        for tp_prompt in tokenized_prompts:
            for tp in tp_prompt:
                token_idx, token_id, word = tp
                # skip special tokens: CLIP start/end-of-text (and padding) ids are 49406 and 49407
                if token_id < 49406:
                    token_indices.append(token_idx)
                # sanitize tokenized prompts
                tp[2] = re.escape(word)


        setattr(p, 'savemaps_tokenized_prompts', tokenized_prompts)
        setattr(p, 'savemaps_token_indices', token_indices)

                
        # Make sure the output folder exists
        outpath_samples = p.outpath_samples
        # Move this to plot tools?
        if not outpath_samples:
            logger.warning("No output path found. Skipping saving attention maps.")
            return
        output_folder_path = os.path.join(outpath_samples, 'attention_maps')
        if not os.path.exists(output_folder_path):
            logger.info(f"Creating directory: {output_folder_path}")
            os.makedirs(output_folder_path)
        
        # sequence number for saving
        seq_num = get_next_sequence_number(output_folder_path, basename='')
        setattr(p, 'savemaps_seq_num', seq_num)

        latent_shape = [p.height // p.rng.shape[1], p.width // p.rng.shape[2]] # latent downscale factors (only consumed by the commented-out 'savemaps_shape' line below)
        
        save_steps = []
        min_step = max(save_every_n_step-1, 0) 
        if save_every_n_step > 0:
            save_steps = list(range(min_step, p.steps, save_every_n_step))
        else:
            save_steps = [p.steps-1]
        # always save last step
        if p.steps-1 not in save_steps:
            save_steps.append(p.steps-1)
        setattr(p, 'savemaps_save_steps', save_steps)
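        # e.g. p.steps == 20 with save_every_n_step == 5 gives save_steps == [4, 9, 14, 19]; with save_every_n_step == 0 it is just [19]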

        # Create fields in module
        value_map = copy.deepcopy(module_field_map)
        value_map['savemaps_save_steps'] = save_steps
        value_map['savemaps_step'] = 0
        #value_map['savemaps_shape'] = torch.tensor(latent_shape).to(device=shared.device, dtype=torch.int32)
        self.hook_modules(module_list, value_map, p)
        self.create_save_hook(module_list)

        def on_cfg_denoiser(params: script_callbacks.CFGDenoiserParams):
            """ Sets the step for all modules
                the webui reports an incorrect step so we just count it ourselves
            """
            for module in module_list:
                module.savemaps_step = p.savemaps_step
            # logger.debug('Setting step to %d for %d modules', p.savemaps_step, len(module_list))
            p.savemaps_step += 1
        
        script_callbacks.on_cfg_denoiser(on_cfg_denoiser)


    def process(self, p, *args, **kwargs):
        pass

    def before_process(self, p: StableDiffusionProcessing, active, module_name_filter, class_name_filter, save_every_n_step, map_types, *args, **kwargs):
        module_list = self.get_modules_by_filter(module_name_filter, class_name_filter)
        self.unhook_modules(module_list, copy.deepcopy(module_field_map))

    def process_batch(self, p, *args, **kwargs):
        pass

    def postprocess_batch(self, p: StableDiffusionProcessing, active, module_name_filter, class_name_filter, save_every_n_step, map_types, *args, **kwargs):
        module_list = self.get_modules_by_filter(module_name_filter, class_name_filter)

        if getattr(p, 'savemaps_token_count', None) is None:
            self.unhook_modules(module_list, copy.deepcopy(module_field_map))
            return

        base_seq_num = getattr(p, 'savemaps_seq_num', None)
        map_types = getattr(p, 'savemaps_map_types', [])
        tokenized_prompts = getattr(p, 'savemaps_tokenized_prompts', None)
        token_indices = getattr(p, 'savemaps_token_indices', None)
        save_steps = getattr(p, 'savemaps_save_steps', None)
        save_image_path = os.path.join(p.outpath_samples, 'attention_maps')

        plot_is_self = False # self-attention maps are not very informative, so they are skipped below

        for module in module_list:
            network_layer_name = module.network_layer_name

            if not hasattr(module, 'savemaps_batch') or module.savemaps_batch is None:
                logger.error(f"No attention maps found for module: {network_layer_name}")
                continue

            # self attn maps are kind of useless atm
            is_self = getattr(module, 'savemaps_is_self', False)
            if is_self and not plot_is_self:
                continue

            # selfattn: seq_len = hw
            # crossattn: seq_len = # of tokens
            attn_maps = module.savemaps_batch # (attn_map num, 2 * batch_num, height * width, sequence_len)
            attn_map_num, batch_num, hw, seq_len = attn_maps.shape
            token_indices = p.savemaps_token_indices
            save_steps = p.savemaps_save_steps
            downscale_h = round((hw * (p.height / p.width)) ** 0.5)
            downscale_w = hw // downscale_h
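            # e.g. a square image with hw == 4096 gives 64 x 64; a 512x768 image with hw == 1536 gives 32 x 48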
            gaussian_blur = GaussianBlur(kernel_size=3, sigma=1)

            # Blur maps: unflatten the query positions into (height, width) so the Gaussian blur acts over the spatial dims
            attn_maps = attn_maps.view(attn_map_num * batch_num, downscale_h, downscale_w, seq_len)
            attn_maps = attn_maps.permute(0, 3, 1, 2) # (ab, seq_len, height, width)
            attn_maps = gaussian_blur(attn_maps)  # apply Gaussian smoothing
            attn_maps = attn_maps.permute(0, 2, 3, 1) # (ab, height, width, seq_len)
            if is_self:
                attn_maps = attn_maps.view(attn_map_num, 2, batch_num // 2, downscale_h * downscale_w, seq_len).mean(dim=1) # (attn_map num, batch_num, hw, hw)
                attn_maps = attn_maps.unsqueeze(2) # (attn_map num, batch_num, 1, hw, hw)
            else:
                attn_maps = rearrange(attn_maps, '(n m b) h w t -> n m b t h w', n = attn_map_num, m = 2).mean(dim=1) # (attn_map num, batch_num, token_idx, height, width)
                attn_map_num, batch_num, token_dim, h, w = attn_maps.shape
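            # NOTE: the split into 2 above assumes the module batch holds cond and uncond (CFG) halves; .mean(dim=1) averages them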

            output_dict_maps = []
            per_token_dict_maps = []
            one_hot_dict_maps = []

            if 'Per-Token Maps' in map_types:

                # write to dict
                for attn_map_idx in range(attn_maps.shape[0]):
                    for batch_idx in range(batch_num):
                        for token_idx in token_indices:

                            attnmap = attn_maps[attn_map_idx, batch_idx, token_idx]
                            _, token_id, word = tokenized_prompts[batch_idx][token_idx]

                            plot_type = f"({token_idx}, {token_id}, '{word}')"
                            filename_info = f'token{token_idx:04}'
                            plot_color = 'viridis'

                            map_info: dict = self.create_base_dict(plot_type, base_seq_num, network_layer_name, save_steps, attn_map_idx, batch_idx, attnmap, filename_info, plot_color)
                            map_info.update({
                                'token_idx': token_idx,
                                'token_id': token_id,
                                'token_word': word,
                            })
                            output_dict_maps.append(map_info)

            if 'One-Hot Map' in map_types:
                one_hot_map = attn_maps[:, :, token_indices] # (attn_map num, batch_num, token_idx, height, width)
                one_hot_map = one_hot_map.argmax(dim=2, keepdim=True)
                one_hot_map = one_hot_map.to(dtype=torch.float16)

                # quantize to a stable number of colors so that each selected token maps to a distinct level regardless of how many tokens are selected
                num_colors = max(len(token_indices), 1)
                min_val, max_val = one_hot_map.min(), one_hot_map.max()
                step = 1 / num_colors
                one_hot_map *= step
                one_hot_map = one_hot_map.sum(dim=2) # (attn_map num, batch_num, height, width)
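                # e.g. with 4 selected tokens, argmax values {0, 1, 2, 3} become levels {0.0, 0.25, 0.5, 0.75}, one per winning token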

                # write to dict
                for attn_map_idx in range(one_hot_map.shape[0]):
                    for batch_idx in range(batch_num):
                        plot_type = "One Hot"
                        plot_color = 'plasma'
                        attnmap = one_hot_map[attn_map_idx, batch_idx]
                        ohm_info: dict = self.create_base_dict(plot_type, base_seq_num, network_layer_name, save_steps, attn_map_idx, batch_idx, attnmap, 'ohm', plot_color)
                        output_dict_maps.append(ohm_info)

            # Save maps from map dict
            for md in output_dict_maps:
                base_seq_num = md['seq_num']
                network_layer_name = md['network_layer_name']
                savestep_num = md['savestep_num']
                attn_map_idx = md['attn_map_idx']
                batch_idx = md['batch_idx']

                # output filename and path
                filename_info = md['filename_info']
                if len(filename_info) > 0:
                    filename_info = f'{filename_info}_'

                out_file_name = f'{base_seq_num:04}-{network_layer_name}_{filename_info}step{savestep_num:04}_attnmap_{attn_map_idx:04}_batch{batch_idx:04}.png'
                out_save_path = os.path.join(save_image_path, out_file_name)
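                # example output name (hypothetical values): '0003-input_blocks_5_1_tr_bl_0_attn2_token0001_step0020_attnmap_0000_batch0000.png'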

                # plot title
                plot_type = md['plot_type']
                plot_color = md['plot_color']
                plot_title = f"{network_layer_name}\nStep {savestep_num}"
                if len(plot_type) > 0:
                    plot_title += f", {plot_type}"

                attn_map = md['attnmap']
                plot_tools.plot_attention_map(
                    attention_map = attn_map,
                    title = plot_title,
                    save_path = out_save_path,
                    plot_type = plot_color,
                )

                if shared.state.interrupted:
                    self.unhook_modules(module_list, copy.deepcopy(module_field_map))
                    return 
        self.unhook_modules(module_list, copy.deepcopy(module_field_map))

    def create_base_dict(self, plot_type:str, base_seq_num: int, network_layer_name: str, save_steps: list, attn_map_idx: int, batch_idx: int, attnmap: torch.Tensor, filename_info: str, plot_color: str):
        """ Create a base dictionary for saving attention maps for minimum metadata that the save function expects 
        Arguments:
                plot_type: str - name of the type of plot, used in the plot title
                base_seq_num: int - start sequence number for saving, prefixes the filename with "000xx-" where xx is the sequence number
                network_layer_name: str - the module's network layer name
                save_steps: list[int] - list of steps to save attention maps for, should be same length as the number of attention maps
                attn_map_idx: int - index of the attention map
                batch_idx: int - index of the batch
                attnmap: torch.Tensor - attention map of shape [C, H, W]
                filename_info: str - a string that goes in the middle of the filename, e.g. f"000xx-{filename_info}-000yy.png"
                plot_color: str - one of the matplotlib color maps (default is 'viridis')
        """
        network_layer_name = network_layer_name.removeprefix('diffusion_model_')
        network_layer_name = network_layer_name.replace('transformer_blocks_', 'tr_bl_')
        base_dict = {
            'plot_type': plot_type,
            'seq_num': base_seq_num + batch_idx,
            'step': save_steps[attn_map_idx] + 1,
            'network_layer_name': network_layer_name,
            'attn_map_idx': attn_map_idx,
            'savestep_num': save_steps[attn_map_idx] + 1,
            'batch_idx': batch_idx,
            'attnmap': attnmap,
            'filename_info': filename_info,
            'plot_color': plot_color,
        }
        return base_dict
    
    def unhook_callbacks(self) -> None:
        pass

    def get_xyz_axis_options(self) -> dict:
        return {}
    
    def get_infotext_fields(self) -> list:
        return self.infotext_fields
    
    def create_save_hook(self, module_list):
        pass

    def hook_modules(self, module_list: list, value_map: dict, p: StableDiffusionProcessing):
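        # The submodule hooks below stash the projected query/key/value tensors on the parent
        # attention module; the parent's own forward hook then builds the attention map from the
        # stashed query/key and accumulates it in 'savemaps_batch'.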
        def savemaps_hook(module, input, kwargs, output):
            """ Hook to save attention maps every N steps, or the last step if N is 0.
            Saves attention maps to a field named 'savemaps_batch' in the module.
            with shape (attn_map, batch_num, height * width).
            
            """
            #module.savemaps_step += 1

            if module.savemaps_step not in module.savemaps_save_steps:
                return
            reweight_crossattn = True 


            is_self = getattr(module, 'savemaps_is_self', False)
            to_q_map = getattr(module, 'savemaps_to_q_map', None)
            to_k_map = to_q_map if is_self else getattr(module, 'savemaps_to_k_map', None)

            # reweight the cross-attention scores by restricting the keys to the selected tokens (dropping the start-of-text and padding tokens)
            orig_seq_len = to_k_map.shape[1]
            # token_count = module.savemaps_token_count
            # min_token = 0
            # max_token = min(token_count+1, orig_seq_len)
            token_indices = module.savemaps_token_indices

            if not is_self and reweight_crossattn:
                to_k_map = to_k_map[:, token_indices, :]

            attn_map = get_attention_scores(to_q_map, to_k_map, dtype=to_q_map.dtype)
            b, hw, seq_len = attn_map.shape

            if not is_self and reweight_crossattn:
                #to_attn_zeros = torch.zeros([b, hw]).unsqueeze(-1).to(device=shared.device, dtype=attn_map.dtype) # (batch, h*w, 1)
                #attn_map = torch.cat([to_attn_zeros, attn_map], dim=-1) # re pad to original token dim size
                left_pad = 1
                right_pad = orig_seq_len - seq_len - 1
                attn_map = torch.nn.functional.pad(attn_map, (left_pad, right_pad), value=0) # re pad to original token dim size
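                # e.g. with orig_seq_len == 77 and 5 selected tokens: pad (1, 71) restores the 77-column layout, zero elsewhere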

            # add a leading dim so maps from each saved step can be concatenated along dim 0
            attn_map = attn_map.unsqueeze(0)

            #attn_map = attn_map.mean(dim=-1)
            if module.savemaps_batch is None:
                module.savemaps_batch = attn_map
            else:
                module.savemaps_batch = torch.cat([module.savemaps_batch, attn_map], dim=0)

        def savemaps_to_q_hook(module, input, kwargs, output):
            setattr(module.savemaps_parent_module[0], 'savemaps_to_q_map', output)

        def savemaps_to_k_hook(module, input, kwargs, output):
            if not module.savemaps_parent_module[0].savemaps_is_self:
                setattr(module.savemaps_parent_module[0], 'savemaps_to_k_map', output)

        def savemaps_to_v_hook(module, input, kwargs, output):
            setattr(module.savemaps_parent_module[0], 'savemaps_to_v_map', output)

        #for module, kv in zip(module_list, value_map.items()):
        for module in module_list:
            # logger.debug('Adding hook to %s', module.network_layer_name)
            for key_name, default_value in value_map.items():
                module_hooks.modules_add_field(module, key_name, default_value)

            module_hooks.module_add_forward_hook(module, savemaps_hook, 'forward', with_kwargs=True)
            module_hooks.modules_add_field(module, 'savemaps_token_count', p.savemaps_token_count)
            module_hooks.modules_add_field(module, 'savemaps_token_indices', p.savemaps_token_indices)

            if module.network_layer_name.endswith('attn1'): # self attn
                module_hooks.modules_add_field(module, 'savemaps_is_self', True)
            if module.network_layer_name.endswith('attn2'): # cross attn
                module_hooks.modules_add_field(module, 'savemaps_is_self', False)

            for module_name in SUBMODULES:
                if not hasattr(module, module_name):
                    logger.error(f"Submodule not found: {module_name} in module: {module.network_layer_name}")
                    continue
                submodule = getattr(module, module_name)
                hook_fn_name = f'savemaps_{module_name}_hook'
                hook_fn = locals().get(hook_fn_name, None)
                if not hook_fn:
                    logger.error(f"Hook function '{hook_fn_name}' not found for submodule: {module_name}")
                    continue

                module_hooks.modules_add_field(submodule, 'savemaps_parent_module', [module])
                module_hooks.module_add_forward_hook(submodule, hook_fn, 'forward', with_kwargs=True)
    
    def unhook_modules(self, module_list: list, value_map: dict):
        for module in module_list:
            for key_name, _ in value_map.items():
                module_hooks.modules_remove_field(module, key_name)
            module_hooks.modules_remove_field(module, 'savemaps_is_self')
            module_hooks.modules_remove_field(module, 'savemaps_token_count')
            module_hooks.modules_remove_field(module, 'savemaps_token_indices')
            module_hooks.remove_module_forward_hook(module, 'savemaps_hook')
            for module_name in SUBMODULES:
                module_hooks.modules_remove_field(module, f'savemaps_{module_name}_map')

                if hasattr(module, module_name):
                    submodule = getattr(module, module_name)
                    module_hooks.modules_remove_field(submodule, 'savemaps_parent_module')    
                    module_hooks.remove_module_forward_hook(submodule, f'savemaps_{module_name}_hook')


    def print_modules(self, module_name_filter, class_name_filter):
        logger.info("Module name filter: '%s', Class name filter: '%s'", module_name_filter, class_name_filter)
        modules = self.get_modules_by_filter(module_name_filter, class_name_filter)
        module_names = ""
        if len(modules) > 0:
            module_names = "\n".join([f"{m.network_layer_name}: {m.__class__.__name__}" for m in modules])
        logger.info("Modules found:\n----------\n%s\n----------\n", module_names)

    def get_modules_by_filter(self, module_name_filter, class_name_filter):
        if len(class_name_filter) == 0:
            class_name_filter = None
        if len(module_name_filter) == 0:
            module_name_filter = None
        found_modules = module_hooks.get_modules(module_name_filter, class_name_filter)
        if len(found_modules) == 0:
            logger.warning(f"No modules found with module name filter: {module_name_filter} and class name filter")
        return found_modules


def get_attention_scores(to_q_map, to_k_map, dtype):
    """ Calculate the attention scores for the given query and key maps
    Arguments:
        to_q_map: torch.Tensor - query map of shape (batch, q_tokens, dim)
        to_k_map: torch.Tensor - key map of shape (batch, k_tokens, dim)
        dtype: torch.dtype - data type of the returned tensor
    Returns:
        torch.Tensor - attention scores of shape (batch, q_tokens, k_tokens)
    """
    # based on diffusers models/attention.py "get_attention_scores"
    # use in-place operations to save memory: https://stackoverflow.com/questions/53732209/torch-in-place-operations-to-save-memory-softmax
    # 512x: 2.65G -> 2.47G

    attn_probs = to_q_map @ to_k_map.transpose(-1, -2)
    attn_probs = attn_probs.to(dtype=torch.float32)

    # scaled dot-product: divide by sqrt(d_k), the embedding dim of the query/key projections
    channel_dim = to_q_map.size(-1)
    attn_probs /= (channel_dim ** 0.5)

    # avoid nan/overflow by computing the softmax in float32 and subtracting the max first
    attn_probs -= attn_probs.max()
    attn_probs = attn_probs.softmax(dim=-1).to(device=shared.device, dtype=to_q_map.dtype)
    attn_probs = attn_probs.to(dtype=dtype)

    return attn_probs
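

# Shape sketch (hypothetical tensors; assumes a webui session so shared.device resolves):
#   q = torch.randn(2, 4096, 320, device=shared.device)  # (batch, h*w, dim) from to_q
#   k = torch.randn(2, 77, 320, device=shared.device)     # (batch, tokens, dim) from to_k
#   get_attention_scores(q, k, dtype=q.dtype).shape       # -> torch.Size([2, 4096, 77])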