| | import sys |
| | from typing import Any, Callable, Union |
| |
|
| | from torch import nn |
| | from torch.utils.hooks import RemovableHandle |
| |
|
| | from ldm.modules.diffusionmodules.openaimodel import ( |
| | TimestepEmbedSequential, |
| | ) |
| | from ldm.modules.attention import ( |
| | SpatialTransformer, |
| | BasicTransformerBlock, |
| | CrossAttention, |
| | MemoryEfficientCrossAttention, |
| | ) |
| | from ldm.modules.diffusionmodules.openaimodel import ( |
| | ResBlock, |
| | ) |
| | from modules.processing import StableDiffusionProcessing |
| | from modules import shared |
| |
|
class ForwardHook:
    """Replace ``module.forward`` with a wrapper until :meth:`remove` is called.

    While installed, calling ``module.forward(*args, **kwargs)`` invokes
    ``fn(module, original_forward, *args, **kwargs)`` and returns its result.
    """

    def __init__(self, module: nn.Module, fn: Callable[[nn.Module, Callable[..., Any], Any], Any]):
        # Remember the current forward so it can be handed to ``fn`` and
        # put back later, then route calls through this object instead.
        self.module = module
        self.fn = fn
        self.o = module.forward
        module.forward = self.forward

    def remove(self):
        """Restore the saved forward and drop all references; no-op if already removed."""
        target, original = self.module, self.o
        if target is None or original is None:
            return
        target.forward = original
        self.module = self.o = self.fn = None

    def forward(self, *args, **kwargs):
        """Call ``fn`` with the module, the saved forward and the call arguments.

        Returns ``None`` when the hook has been removed or no ``fn`` is set.
        """
        if self.module is None or self.o is None or self.fn is None:
            return None
        return self.fn(self.module, self.o, *args, **kwargs)
| | |
| |
|
class SDHook:
    """Base class that installs temporary hooks on a Stable Diffusion model.

    Subclasses override ``hook_clip`` / ``hook_unet`` / ``hook_vae`` (and
    optionally ``on_setup`` / ``dispose``) and register callbacks through
    ``hook_layer``, ``hook_layer_pre`` or ``hook_forward``. Every handle
    registered that way is tracked and removed when the ``with`` block exits.
    """

    def __init__(self, enabled: bool):
        self._enabled = enabled
        self._handles: list[Union[RemovableHandle, ForwardHook]] = []

    @property
    def enabled(self):
        """Whether this hook is active; when false, setup and hook registration are no-ops."""
        return self._enabled

    @enabled.setter
    def enabled(self, v: bool):
        self._enabled = bool(v)

    @property
    def batch_num(self):
        """Current value of ``shared.state.job_no``."""
        return shared.state.job_no

    @property
    def step_num(self):
        """Current value of ``shared.state.current_image_sampling_step``."""
        return shared.state.current_image_sampling_step

    def __enter__(self):
        # Return the instance so ``with SDHook(...) as hook:`` binds the hook
        # object instead of ``None``.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Remove in reverse registration order: when two ForwardHooks wrap the
        # same module, the later one saved the earlier one's wrapper as its
        # "original", so only LIFO removal restores the real forward.
        # Removal is unconditional so hooks installed before ``enabled`` was
        # switched off are not left behind on the model.
        while self._handles:
            self._handles.pop().remove()
        if self.enabled:
            self.dispose()

    def dispose(self):
        """Cleanup callback run at the end of the ``with`` block while enabled."""
        pass

    def setup(
        self,
        p: "StableDiffusionProcessing"
    ):
        """Locate the U-Net, VAE and CLIP modules on ``p.sd_model`` and hook them.

        Does nothing when the hook is disabled. Raises ``AssertionError`` when
        ``p.sd_model.model.diffusion_model`` cannot be found.
        """
        if not self.enabled:
            return

        sd_model = p.sd_model
        wrapper = getattr(sd_model, "model", None)

        # getattr with a default already yields None for a missing wrapper.
        unet: Union[nn.Module, None] = getattr(wrapper, "diffusion_model", None)
        vae: Union[nn.Module, None] = getattr(sd_model, "first_stage_model", None)
        clip: Union[nn.Module, None] = getattr(sd_model, "cond_stage_model", None)

        assert unet is not None, "p.sd_model.diffusion_model is not found. broken model???"
        self._do_hook(p, sd_model, unet=unet, vae=vae, clip=clip)
        self.on_setup()

    def on_setup(self):
        """Callback run after all ``hook_*`` methods were invoked by :meth:`setup`."""
        pass

    def _do_hook(
        self,
        p: "StableDiffusionProcessing",
        model: Any,
        unet: Union[nn.Module, None],
        vae: Union[nn.Module, None],
        clip: Union[nn.Module, None]
    ):
        """Dispatch to ``hook_clip``, ``hook_unet`` and ``hook_vae`` (in that order) for each module present."""
        assert model is not None, "empty model???"

        if clip is not None:
            self.hook_clip(p, clip)

        if unet is not None:
            self.hook_unet(p, unet)

        if vae is not None:
            self.hook_vae(p, vae)

    def hook_vae(
        self,
        p: "StableDiffusionProcessing",
        vae: nn.Module
    ):
        """Override to install hooks on the VAE module; the base version does nothing."""
        pass

    def hook_unet(
        self,
        p: "StableDiffusionProcessing",
        unet: nn.Module
    ):
        """Override to install hooks on the U-Net module; the base version does nothing."""
        pass

    def hook_clip(
        self,
        p: "StableDiffusionProcessing",
        clip: nn.Module
    ):
        """Override to install hooks on the CLIP module; the base version does nothing."""
        pass

    def hook_layer(
        self,
        module: Union[nn.Module, Any],
        fn: Callable[[nn.Module, tuple, Any], Any]
    ):
        """Register ``fn`` via ``module.register_forward_hook`` and track the handle."""
        if not self.enabled:
            return

        assert module is not None
        assert isinstance(module, nn.Module)
        self._handles.append(module.register_forward_hook(fn))

    def hook_layer_pre(
        self,
        module: Union[nn.Module, Any],
        fn: Callable[[nn.Module, tuple], Any]
    ):
        """Register ``fn`` via ``module.register_forward_pre_hook`` and track the handle."""
        if not self.enabled:
            return

        assert module is not None
        assert isinstance(module, nn.Module)
        self._handles.append(module.register_forward_pre_hook(fn))

    def hook_forward(
        self,
        module: Union[nn.Module, Any],
        fn: Callable[[nn.Module, Callable[..., Any], Any], Any]
    ):
        """Wrap ``module.forward`` with a ``ForwardHook`` calling ``fn`` and track it."""
        # Same guard as hook_layer / hook_layer_pre: a disabled hook must not
        # modify the model.
        if not self.enabled:
            return

        assert module is not None
        assert isinstance(module, nn.Module)
        self._handles.append(ForwardHook(module, fn))

    def log(self, msg: str):
        """Print ``msg`` to standard error."""
        print(msg, file=sys.stderr)
| |
|
| |
|
| | |
def each_transformer(unet_block: TimestepEmbedSequential):
    """Yield, in order, each direct child of ``unet_block`` that is a ``SpatialTransformer``."""
    yield from (
        child
        for child in unet_block.children()
        if isinstance(child, SpatialTransformer)
    )
| |
|
| | |
def each_basic_block(trans: SpatialTransformer):
    """Yield, in order, each child of ``trans.transformer_blocks`` that is a ``BasicTransformerBlock``."""
    yield from (
        child
        for child in trans.transformer_blocks.children()
        if isinstance(child, BasicTransformerBlock)
    )
| |
|
| | |
| | |
def each_attns(unet_block: TimestepEmbedSequential):
    """Yield ``(transformer_index, block_index, attn1, attn2)`` for every basic block.

    Indices count the ``SpatialTransformer`` children of ``unet_block`` and the
    ``BasicTransformerBlock`` children inside each of them. Both attention
    modules are asserted to be ``CrossAttention`` or
    ``MemoryEfficientCrossAttention`` instances.
    """
    for trans_no, transformer in enumerate(each_transformer(unet_block)):
        for block_no, basic in enumerate(each_basic_block(transformer)):
            first = basic.attn1
            second = basic.attn2
            assert isinstance(first, (CrossAttention, MemoryEfficientCrossAttention))
            assert isinstance(second, (CrossAttention, MemoryEfficientCrossAttention))

            yield trans_no, block_no, first, second
| |
|
def each_unet_attn_layers(unet: nn.Module):
    """Yield ``(name, attention_module)`` pairs for the whole U-Net.

    Input, middle and output blocks are visited in that order. For every basic
    transformer block two pairs are produced: ``attn1`` named ``..._sattn`` and
    ``attn2`` named ``..._xattn``. The layer index is the position of the
    ``TimestepEmbedSequential`` among all children of its block list (always 0
    for the middle block).
    """
    def named_attns(layer_no: int, seq: TimestepEmbedSequential, template: str):
        for trans_no, block_no, attn1, attn2 in each_attns(seq):
            for tag, attn in (('sattn', attn1), ('xattn', attn2)):
                label = template.format(
                    attn_name=tag,
                    layer_index=layer_no,
                    trans_index=trans_no,
                    block_index=block_no,
                )
                yield label, attn

    def walk(stack: nn.ModuleList, template: str):
        for layer_no, child in enumerate(stack.children()):
            if not isinstance(child, TimestepEmbedSequential):
                continue
            yield from named_attns(layer_no, child, template)

    # Resolve all three sections up front, before anything is yielded.
    down, mid, up = unet.input_blocks, unet.middle_block, unet.output_blocks

    yield from walk(down, 'IN{layer_index:02}_{trans_index:02}_{block_index:02}_{attn_name}')
    yield from named_attns(0, mid, 'M{layer_index:02}_{trans_index:02}_{block_index:02}_{attn_name}')
    yield from walk(up, 'OUT{layer_index:02}_{trans_index:02}_{block_index:02}_{attn_name}')
| |
|
| |
|
def each_unet_transformers(unet: nn.Module):
    """Yield ``(name, SpatialTransformer)`` pairs for the whole U-Net.

    Input, middle and output blocks are visited in that order; names look like
    ``IN<layer>_<n>_trans``, ``M00_<n>_trans`` and ``OUT<layer>_<n>_trans``
    with two-digit, zero-padded indices.
    """
    def named_transformers(layer_no: int, seq: TimestepEmbedSequential, template: str):
        for trans_no, transformer in enumerate(each_transformer(seq)):
            label = template.format(
                layer_index=layer_no,
                block_index=trans_no,
                block_name='trans',
            )
            yield label, transformer

    def walk(stack: nn.ModuleList, template: str):
        for layer_no, child in enumerate(stack.children()):
            if not isinstance(child, TimestepEmbedSequential):
                continue
            yield from named_transformers(layer_no, child, template)

    # Resolve all three sections up front, before anything is yielded.
    down, mid, up = unet.input_blocks, unet.middle_block, unet.output_blocks

    yield from walk(down, 'IN{layer_index:02}_{block_index:02}_{block_name}')
    yield from named_transformers(0, mid, 'M{layer_index:02}_{block_index:02}_{block_name}')
    yield from walk(up, 'OUT{layer_index:02}_{block_index:02}_{block_name}')
| |
|
| |
|
def each_resblock(unet_block: TimestepEmbedSequential):
    """Yield, in order, each direct child of ``unet_block`` that is a ``ResBlock``."""
    yield from (
        child
        for child in unet_block.children()
        if isinstance(child, ResBlock)
    )
| |
|
def each_unet_resblock(unet: nn.Module):
    """Yield ``(name, ResBlock)`` pairs for the whole U-Net.

    Input, middle and output blocks are visited in that order; names look like
    ``IN<layer>_<n>_resblock``, ``M00_<n>_resblock`` and
    ``OUT<layer>_<n>_resblock`` with two-digit, zero-padded indices.
    """
    def named_resblocks(layer_no: int, seq: TimestepEmbedSequential, template: str):
        for res_no, resblock in enumerate(each_resblock(seq)):
            label = template.format(
                layer_index=layer_no,
                block_index=res_no,
                block_name='resblock',
            )
            yield label, resblock

    def walk(stack: nn.ModuleList, template: str):
        for layer_no, child in enumerate(stack.children()):
            if not isinstance(child, TimestepEmbedSequential):
                continue
            yield from named_resblocks(layer_no, child, template)

    # Resolve all three sections up front, before anything is yielded.
    down, mid, up = unet.input_blocks, unet.middle_block, unet.output_blocks

    yield from walk(down, 'IN{layer_index:02}_{block_index:02}_{block_name}')
    yield from named_resblocks(0, mid, 'M{layer_index:02}_{block_index:02}_{block_name}')
    yield from walk(up, 'OUT{layer_index:02}_{block_index:02}_{block_name}')
| |
|
| |
|