from comfy.ldm.modules.attention import default, optimized_attention, optimized_attention_masked

from .style_functions import adain, concat_first


class VisualStyleProcessor:
    """Attention forward pass with optional style processing of q/k/v.

    Wraps an attention module (``module_self``) and reuses its ``to_q``,
    ``to_k``, ``to_v`` and ``to_out`` projections and its ``heads`` attribute.
    When ``enabled`` is true, the projected queries, keys and values are
    optionally passed through ``adain`` and the keys and values are then
    passed through ``concat_first`` before attention is computed.
    """

    def __init__(
        self,
        module_self,
        keys_scale: float = 1.0,
        enabled: bool = True,
        adain_queries: bool = True,
        adain_keys: bool = True,
        adain_values: bool = False,
    ) -> None:
        """Store the wrapped attention module and the style settings.

        Args:
            module_self: Attention module; ``visual_style_forward`` reads its
                ``to_q``, ``to_k``, ``to_v``, ``to_out`` and ``heads``.
            keys_scale: Passed as the third argument to ``concat_first`` when
                processing the keys.
            enabled: When false, ``visual_style_forward`` skips every
                ``adain`` / ``concat_first`` step.
            adain_queries: Apply ``adain`` to the projected queries.
            adain_keys: Apply ``adain`` to the projected keys.
            adain_values: Apply ``adain`` to the projected values.
        """
        self.module_self = module_self
        self.keys_scale = keys_scale
        self.enabled = enabled
        self.adain_queries = adain_queries
        self.adain_keys = adain_keys
        self.adain_values = adain_values

    def visual_style_forward(self, x, context, value, mask=None):
        """Run the wrapped module's attention with optional style processing.

        Args:
            x: Input to the query projection.
            context: Input to the key projection (and to the value projection
                when ``value`` is None). It is first passed through
                ``default(context, x)`` -- presumably falling back to ``x``
                when no context is given; confirm against ``default``.
            value: Optional separate input for the value projection.
            mask: Optional attention mask. When it is not None the masked
                attention function is used.

        Returns:
            The attention output passed through ``module_self.to_out``.
        """
        attn = self.module_self

        q = attn.to_q(x)
        context = default(context, x)
        k = attn.to_k(context)
        if value is not None:
            v = attn.to_v(value)
            # Drop the local reference once the projection has been taken.
            del value
        else:
            v = attn.to_v(context)

        if self.enabled:
            # `adain` is presumably adaptive instance normalization -- confirm
            # in style_functions.
            if self.adain_queries:
                q = adain(q)
            if self.adain_keys:
                k = adain(k)
            if self.adain_values:
                v = adain(v)

            # NOTE(review): these two calls are assumed to be gated by
            # `enabled` only (not by `adain_values`); confirm the nesting.
            # -2 is the axis argument; keys additionally receive `keys_scale`.
            k = concat_first(k, -2, self.keys_scale)
            v = concat_first(v, -2)

        if mask is None:
            out = optimized_attention(q, k, v, attn.heads)
        else:
            out = optimized_attention_masked(q, k, v, attn.heads, mask)
        return attn.to_out(out)