| from typing import Callable, Literal |
| import torch |
| import torch.nn as nn |
| from diffusers.models.transformers.transformer_flux import FluxTransformerBlock, FluxSingleTransformerBlock |
|
|
|
|
def register_general_hook(pipe, position, hook, with_kwargs=False, is_pre_hook=False):
    """Register a forward hook on the sub-module of ``pipe`` at ``position``.

    Args:
        pipe: Pipeline (or any object) whose module tree is traversed.
        position (str): Dot-separated path to the target sub-module, e.g.
            ``"transformer.transformer_blocks.0"``; purely-numeric components
            index into sequences, all others are attribute lookups.
        hook (Callable): The hook function to register.
        with_kwargs (bool, optional): If True, the hook also receives the
            module's keyword arguments. Defaults to False.
        is_pre_hook (bool, optional): If True, register a forward *pre*-hook
            (runs before the module's forward) instead of a post-hook.
            Defaults to False.

    Returns:
        torch.utils.hooks.RemovableHandle: Handle whose ``remove()`` detaches
        the hook.
    """
    block: nn.Module = locate_block(pipe, position)

    if is_pre_hook:
        return block.register_forward_pre_hook(hook, with_kwargs=with_kwargs)
    else:
        return block.register_forward_hook(hook, with_kwargs=with_kwargs)
|
|
|
|
def locate_block(pipe, position: str) -> nn.Module:
    """Resolve a dot-separated path inside ``pipe`` and return the object there.

    Each path component that is purely digits indexes into a sequence
    (e.g. a ``ModuleList``); every other component is an attribute lookup.
    """
    current = pipe
    for part in position.split('.'):
        current = current[int(part)] if part.isdigit() else getattr(current, part)
    return current
|
|
|
|
| def _safe_clip(x: torch.Tensor): |
| if x.dtype == torch.float16: |
| x[torch.isposinf(x)] = 65504 |
| x[torch.isneginf(x)] = -65504 |
| return x |
| |
|
|
@torch.no_grad()
def fix_inf_values_hook(*args):
    """Forward hook that clips ``+/-inf`` values produced by Flux blocks.

    Accepts both the plain forward-hook signature ``(module, input, output)``
    and the ``with_kwargs=True`` signature ``(module, input, kwinput, output)``.

    Returns the clipped output in the same shape the block produced: a
    two-tuple for ``FluxTransformerBlock`` (matching its ``output[0]`` /
    ``output[1]`` pair) or a single tensor for ``FluxSingleTransformerBlock``.
    For any other module it returns ``None``, leaving the output untouched.
    """
    if len(args) == 3:
        module, input, output = args
    elif len(args) == 4:
        module, input, kwinput, output = args
    else:
        # Same guard as edit_streams_hook: fail loudly instead of hitting an
        # UnboundLocalError on `module` below.
        raise AssertionError(f'Weird len(args):{len(args)}')

    if isinstance(module, FluxTransformerBlock):
        return _safe_clip(output[0]), _safe_clip(output[1])
    elif isinstance(module, FluxSingleTransformerBlock):
        return _safe_clip(output)
| |
|
|
@torch.no_grad()
def edit_streams_hook(*args,
                      recompute_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                      stream: Literal["text", "image", "both"],
                      text_seq_len: int = 512):
    """Forward hook that rewrites a Flux block's output for one stream.

    ``recompute_fn(stream_input, stream_output) -> new_output`` receives the
    block's input tensor and output tensor for the selected stream and returns
    the modified output. Must be registered with ``with_kwargs=True`` so the
    block's keyword inputs are available.

    Args:
        recompute_fn: Callable producing the edited stream output.
        stream: Which stream to edit; ``"both"`` is only supported for
            ``FluxSingleTransformerBlock``.
        text_seq_len: Number of leading sequence positions treated as text
            tokens in the single-stream (concatenated) blocks. Defaults to 512.

    Returns:
        The (inf-clipped) edited output, shaped like the block's original
        output.
    """
    if len(args) == 4:
        module, input, kwinput, output = args
    elif len(args) == 3:
        # Registered without with_kwargs: the block inputs we need are missing.
        # The original code would crash later with a NameError on `kwinput`;
        # fail here with an actionable message instead.
        raise AssertionError(
            'edit_streams_hook needs the module kwargs; register it with with_kwargs=True')
    else:
        raise AssertionError(f'Weird len(args):{len(args)}')

    if isinstance(module, FluxTransformerBlock):
        # Dual-stream block: output[0] is edited against encoder_hidden_states
        # (text) and output[1] against hidden_states (image).
        if stream == 'text':
            output_text = recompute_fn(kwinput["encoder_hidden_states"], output[0])
            output_image = output[1]
        elif stream == 'image':
            output_image = recompute_fn(kwinput["hidden_states"], output[1])
            output_text = output[0]
        else:
            raise AssertionError("Branch not supported for this layer.")

        return _safe_clip(output_text), _safe_clip(output_image)

    elif isinstance(module, FluxSingleTransformerBlock):
        # Single-stream block: text and image tokens are concatenated along
        # dim 1; the first `text_seq_len` positions are assumed to be the text
        # tokens -- TODO confirm against the pipeline's tokenizer settings.
        if stream == 'text':
            output[:, :text_seq_len] = recompute_fn(
                kwinput["hidden_states"][:, :text_seq_len], output[:, :text_seq_len])
        elif stream == 'image':
            output[:, text_seq_len:] = recompute_fn(
                kwinput["hidden_states"][:, text_seq_len:], output[:, text_seq_len:])
        else:
            output = recompute_fn(kwinput["hidden_states"], output)

        return _safe_clip(output)
| |