Spaces:
Configuration error
Configuration error
| import torch | |
| import torch.nn as nn | |
| from comfy.model_patcher import ModelPatcher | |
| from typing import Union | |
# Short alias for the tensor type, used in the annotations of this module.
T = torch.Tensor
def exists(val):
    """Return ``True`` when ``val`` is anything other than ``None``."""
    return not (val is None)
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    return val if exists(val) else d
class StyleAlignedArgs:
    """Flags selecting which style-alignment steps are applied to q/k/v.

    The class-level values are defaults. The three ``adain_*`` flags are
    overridden on every instance from ``share_attn``; ``share_attention`` is
    never overridden by ``__init__`` and therefore stays ``True``.
    """

    share_attention: bool = True
    adain_queries: bool = True
    adain_keys: bool = True
    adain_values: bool = True

    def __init__(self, share_attn: str) -> None:
        """Enable AdaIN for each of q/k/v whose letter occurs in ``share_attn``.

        The check is a plain substring test: ``"q+k"`` enables queries and
        keys only, and a string containing none of the three letters (such as
        ``"disabled"``) switches all three off.
        """
        self.adain_queries, self.adain_keys, self.adain_values = (
            letter in share_attn for letter in "qkv"
        )
def expand_first(
    feat: T,
    scale=1.0,
) -> T:
    """
    Broadcast the reference element of each batch half over that whole half.

    Dim 0 of ``feat`` is treated as two equal halves. Element ``0`` is the
    reference of the first half and element ``b // 2`` is the reference of
    the second half. Every entry of the result is the reference of the half
    it belongs to, so the output has the same shape as ``feat``.

    Args:
        feat: Tensor of any rank >= 1 whose leading dimension is the batch.
            The batch size must be even, otherwise the final reshape fails.
        scale: Multiplier applied to every copy except the one at the
            reference position of each half. ``1`` leaves all copies as-is.

    Returns:
        A tensor shaped like ``feat``.
    """
    b = feat.shape[0]
    half = b // 2
    # Shape (2, 1, *feat.shape[1:]): one reference per half.
    feat_style = torch.stack((feat[0], feat[half])).unsqueeze(1)
    if scale == 1:
        feat_style = feat_style.expand(2, half, *feat.shape[1:])
    else:
        # Build one repeat factor per trailing dimension. A hard-coded factor
        # count is only right for a single rank; with too many factors
        # ``repeat`` prepends a dimension and the halves get interleaved.
        trailing_ones = [1] * (feat.dim() - 1)
        feat_style = feat_style.repeat(1, half, *trailing_ones)
        # Keep the reference copy of each half unscaled, scale the others.
        feat_style = torch.cat(
            [feat_style[:, :1], scale * feat_style[:, 1:]], dim=1
        )
    return feat_style.reshape(*feat.shape)
def concat_first(feat: T, dim=2, scale=1.0) -> T:
    """
    Append the expanded style features to ``feat`` along ``dim``.

    The style features come from ``expand_first(feat, scale=scale)`` and have
    the same shape as ``feat``, so the size of ``dim`` doubles.
    """
    style_part = expand_first(feat, scale=scale)
    pieces = (feat, style_part)
    return torch.cat(pieces, dim=dim)
def calc_mean_std(feat, eps: float = 1e-5) -> "tuple[T, T]":
    """Return ``(mean, std)`` of ``feat`` taken over its second-to-last dim.

    Both results keep the reduced dimension with size 1. ``eps`` is added to
    the variance before the square root is taken.
    """
    mean = feat.mean(dim=-2, keepdim=True)
    variance = feat.var(dim=-2, keepdim=True)
    return mean, (variance + eps).sqrt()
def adain(feat: T) -> T:
    """Re-normalize ``feat`` to the statistics of each batch half's reference.

    Every element is standardized with its own mean/std (from
    ``calc_mean_std``), then rescaled and shifted with the mean/std that
    ``expand_first`` broadcasts from the reference elements.
    """
    mean, std = calc_mean_std(feat)
    style_mean, style_std = expand_first(mean), expand_first(std)
    normalized = (feat - mean) / std
    return normalized * style_std + style_mean
class SharedAttentionProcessor:
    """Callable that rewrites q/k/v according to a ``StyleAlignedArgs`` config."""

    def __init__(self, args: StyleAlignedArgs, scale: float):
        """Store the flag set and the scale later applied to the shared keys."""
        self.args = args
        self.scale = scale

    def __call__(self, q, k, v, extra_options):
        """Return the (possibly AdaIN-ed and extended) ``q, k, v`` triple.

        ``extra_options`` is accepted but not used by this implementation.
        """
        flags = self.args
        q = adain(q) if flags.adain_queries else q
        k = adain(k) if flags.adain_keys else k
        v = adain(v) if flags.adain_values else v
        if not flags.share_attention:
            return q, k, v
        # Only the keys receive the scale factor; values are appended unscaled.
        shared_k = concat_first(k, -2, scale=self.scale)
        shared_v = concat_first(v, -2)
        return q, shared_k, shared_v
def get_norm_layers(
    layer: nn.Module,
    norm_layers_: "dict[str, list[Union[nn.GroupNorm, nn.LayerNorm]]]",
    share_layer_norm: bool,
    share_group_norm: bool,
):
    """Recursively collect the norm layers below ``layer`` into ``norm_layers_``.

    ``nn.LayerNorm`` modules are appended to ``norm_layers_["layer"]`` when
    ``share_layer_norm`` is set, ``nn.GroupNorm`` modules to
    ``norm_layers_["group"]`` when ``share_group_norm`` is set. The dict is
    mutated in place; nothing is returned.
    """
    if share_layer_norm and isinstance(layer, nn.LayerNorm):
        norm_layers_["layer"].append(layer)
    if share_group_norm and isinstance(layer, nn.GroupNorm):
        norm_layers_["group"].append(layer)
        # A collected group norm is a leaf for this walk.
        return
    for child in layer.children():
        get_norm_layers(child, norm_layers_, share_layer_norm, share_group_norm)
def register_norm_forward(
    norm_layer: Union[nn.GroupNorm, nn.LayerNorm],
) -> Union[nn.GroupNorm, nn.LayerNorm]:
    """Swap ``norm_layer.forward`` for a wrapper that normalizes with style rows.

    The wrapper appends the expanded reference features along dim ``-2``
    (via ``concat_first``), runs the layer's original forward on the extended
    tensor, and slices the result back to the input's length along dim ``-2``.
    The layer is modified in place and also returned.
    """
    # Stash the untouched forward only once, so registering the same layer
    # again wraps the original rather than an earlier wrapper.
    if hasattr(norm_layer, "orig_forward"):
        wrapped = norm_layer.orig_forward
    else:
        wrapped = norm_layer.forward
        norm_layer.orig_forward = wrapped

    def forward_(hidden_states: T) -> T:
        orig_len = hidden_states.shape[-2]
        extended = concat_first(hidden_states, dim=-2)
        normalized = wrapped(extended)  # type: ignore
        # Drop the appended rows so the output matches the input's shape.
        return normalized[..., :orig_len, :]

    norm_layer.forward = forward_  # type: ignore
    return norm_layer
def register_shared_norm(
    model: ModelPatcher,
    share_group_norm: bool = True,
    share_layer_norm: bool = True,
):
    """Patch the selected norm layers found under ``model.model``.

    Collects group and/or layer norms with ``get_norm_layers``, prints how
    many of each were found, and passes every one to
    ``register_norm_forward``. Returns the patched layers as one list, group
    norms first, then layer norms.
    """
    norm_layers = {"group": [], "layer": []}
    get_norm_layers(model.model, norm_layers, share_layer_norm, share_group_norm)
    print(
        f"Patching {len(norm_layers['group'])} group norms, {len(norm_layers['layer'])} layer norms."
    )
    patched = []
    for kind in ("group", "layer"):
        patched.extend(register_norm_forward(layer) for layer in norm_layers[kind])
    return patched
# Presumably the choices offered for the ``share_norm`` argument of
# ``styleAlignBatch`` -- not referenced in the visible code, confirm with caller.
SHARE_NORM_OPTIONS = [
    "both",
    "group",
    "layer",
    "disabled",
]
# Presumably the choices offered for the ``share_attn`` argument of
# ``styleAlignBatch`` -- not referenced in the visible code, confirm with caller.
SHARE_ATTN_OPTIONS = [
    "q+k",
    "q+k+v",
    "disabled",
]
def styleAlignBatch(model, share_norm, share_attn, scale=1.0):
    """Return a clone of ``model`` with the shared-attention patch installed.

    Args:
        model: Object providing ``clone()`` and ``.model``; the clone must
            provide ``set_model_attn1_patch``.
        share_norm: ``"group"``, ``"layer"`` or ``"both"`` selects which norm
            layers are patched; any other value selects none.
        share_attn: String handed to ``StyleAlignedArgs`` to pick the AdaIN
            targets.
        scale: Scale forwarded to ``SharedAttentionProcessor``.

    Returns:
        The clone carrying the attention patch.
    """
    patched = model.clone()
    group_shared = share_norm in ("group", "both")
    layer_shared = share_norm in ("layer", "both")
    # The norm layers are registered through the ``model`` that was passed in,
    # not through the clone.
    # NOTE(review): whether the clone shares the same underlying module is
    # not visible here -- confirm.
    register_shared_norm(model, group_shared, layer_shared)
    processor = SharedAttentionProcessor(StyleAlignedArgs(share_attn), scale)
    patched.set_model_attn1_patch(processor)
    return patched