| import torch
|
| from torch import nn, einsum
|
| from einops import rearrange, repeat
|
| import torch.nn.functional as F
|
| import math
|
| from comfy import model_management
|
| import types
|
| import os
|
|
|
def exists(val):
    """Return True when *val* is anything other than ``None``."""
    if val is None:
        return False
    return True
|
|
|
|
|
def abs_mean(x):
    """Return the mean absolute value of *x*, treating NaN/Inf entries as zero.

    Non-finite entries are replaced by zero before averaging, so they still
    count towards the number of elements the mean is taken over.
    """
    finite_only = torch.where(torch.isfinite(x), x, torch.zeros_like(x))
    return finite_only.abs().mean()
|
|
|
class temperature_patcher():
    """Attention implementation whose logits are divided by a temperature.

    Attributes:
        temperature: Divisor applied to the scaled attention logits before the
            softmax. A value ``<= 0`` selects an automatic divisor equal to
            the mean absolute finite logit.
        layer_name: Label stored on the instance; it is not read by the
            methods of this class.
    """

    def __init__(self, temperature, layer_name="None"):
        self.temperature = temperature
        self.layer_name = layer_name

    @staticmethod
    def _auto_temperature(sim):
        """Return the mean |logit| of *sim*, or 1 when that mean is zero.

        NaN/Inf entries are counted as zero. The fallback to 1 prevents a
        0/0 division (and therefore NaN attention weights) when every finite
        logit is zero.
        """
        finite = torch.where(torch.isfinite(sim), sim, torch.zeros_like(sim))
        mean = finite.abs().mean()
        return torch.where(mean > 0, mean, torch.ones_like(mean))

    def attention_basic_with_temperature(self, q, k, v, extra_options, mask=None, attn_precision=None):
        """Compute multi-head attention with temperature-scaled logits.

        Args:
            q, k, v: 3-D tensors whose last dimension is ``heads * dim_head``;
                the first dimension of ``q`` is taken as the batch size.
            extra_options: Either the number of heads as an ``int`` or a
                mapping providing it under ``'n_heads'``.
            mask: Optional mask. A bool mask marks the key positions to keep
                (``False`` entries are filled with the most negative finite
                value); any other dtype is added to the logits.
            attn_precision: When equal to ``torch.float32`` the logits are
                computed from float32 copies of ``q`` and ``k``.

        Returns:
            Tensor of shape ``(batch, query_tokens, heads * dim_head)``.
        """
        heads = extra_options if isinstance(extra_options, int) else extra_options['n_heads']

        b, _, inner_dim = q.shape
        dim_head = inner_dim // heads
        scale = dim_head ** -0.5

        def split_heads(t):
            # (b, tokens, heads * dim_head) -> (b * heads, tokens, dim_head)
            return (
                t.reshape(b, -1, heads, dim_head)
                .permute(0, 2, 1, 3)
                .reshape(b * heads, -1, dim_head)
                .contiguous()
            )

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        if attn_precision == torch.float32:
            sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
        else:
            sim = einsum('b i d, b j d -> b i j', q, k) * scale

        # q and k are no longer needed once the logits exist.
        del q, k

        if mask is not None:
            if mask.dtype == torch.bool:
                mask = rearrange(mask, 'b ... -> b (...)')
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = repeat(mask, 'b j -> (b h) () j', h=heads)
                # NOTE(review): this fill value is finite, so it is included in
                # the automatic temperature below and can dominate it - confirm
                # whether bool masks are expected together with temperature <= 0.
                sim.masked_fill_(~mask, max_neg_value)
            else:
                bs = 1 if len(mask.shape) == 2 else mask.shape[0]
                mask = (
                    mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1])
                    .expand(b, heads, -1, -1)
                    .reshape(-1, mask.shape[-2], mask.shape[-1])
                )
                sim.add_(mask)

        if self.temperature > 0:
            temperature = self.temperature
        else:
            temperature = self._auto_temperature(sim)
        sim = sim.div(temperature).softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
        # (b * heads, tokens, dim_head) -> (b, tokens, heads * dim_head)
        return (
            out.reshape(b, heads, -1, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, -1, heads * dim_head)
        )
|
|
|
# Block indices per level ("input" / "middle" / "output") that the SD15 node
# variant exposes as on/off toggles. Key order determines widget order.
layers_SD15 = dict(
    input=[1, 2, 4, 5, 7, 8],
    middle=[0],
    output=[3, 4, 5, 6, 7, 8, 9, 10, 11],
)

# Same table for the SDXL node variant.
layers_SDXL = dict(
    input=[4, 5, 7, 8],
    middle=[0],
    output=[0, 1, 2, 3, 4, 5],
)
|
|
|
class ExperimentalTemperaturePatch:
    """Node that installs temperature-scaled attention on selected model blocks.

    Subclasses set ``TOGGLES`` to a mapping of level name -> block indices;
    one boolean input named ``"<level>_<index>"`` is generated per entry.
    """

    # Level name -> block indices offered as toggles (empty on the base class).
    TOGGLES = {}
    RETURN_TYPES = ("MODEL", "STRING",)
    RETURN_NAMES = ("Model", "String",)
    FUNCTION = "patch"

    CATEGORY = "model_patches/Automatic_CFG/Standalone_temperature_patches"

    @classmethod
    def INPUT_TYPES(cls):
        """Build the input specification: one toggle per block, then the fixed inputs."""
        required_inputs = {}
        for level_name, block_ids in cls.TOGGLES.items():
            for block_id in block_ids:
                required_inputs[f"{level_name}_{block_id}"] = ("BOOLEAN", {"default": False})
        required_inputs["model"] = ("MODEL",)
        required_inputs["Temperature"] = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "round": 0.01})
        required_inputs["Attention"] = (["both", "self", "cross"],)
        return {"required": required_inputs}

    def patch(self, model, Temperature, Attention, **kwargs):
        """Clone *model* and replace attention on every enabled block.

        Args:
            model: Object providing ``clone()``; the clone must provide
                ``set_model_attn1_replace`` / ``set_model_attn2_replace``.
            Temperature: Value handed to each ``temperature_patcher``.
            Attention: ``"self"``, ``"cross"`` or ``"both"`` - which
                attention replacement(s) to register.
            **kwargs: Toggle values keyed ``"<level>_<index>"``.

        Returns:
            Tuple of the patched clone and a text summary of the settings.
        """
        m = model.clone()
        levels = ["input", "middle", "output"]
        parameters_output = {level: [] for level in levels}
        replace_self = Attention in ["both", "self"]
        replace_cross = Attention in ["both", "cross"]

        for key, toggle_enabled in kwargs.items():
            key_parts = key.split("_")
            current_level = key_parts[0]
            # Skip inputs that are not block toggles, and toggles left off.
            if current_level not in levels or not toggle_enabled:
                continue

            b_number = int(key_parts[1])
            parameters_output[current_level].append(b_number)
            patcher = temperature_patcher(Temperature, key)

            if replace_self:
                m.set_model_attn1_replace(patcher.attention_basic_with_temperature, current_level, b_number)
            if replace_cross:
                m.set_model_attn2_replace(patcher.attention_basic_with_temperature, current_level, b_number)

        level_lines = []
        for level, block_numbers in parameters_output.items():
            level_lines.append(f"{level}: {','.join(str(n) for n in block_numbers)}")
        parameters_as_string = "\n".join(level_lines)
        parameters_as_string = f"Temperature: {Temperature}\n{parameters_as_string}\nAttention: {Attention}"
        return (m, parameters_as_string,)
|
|
|
# Concrete node variants: same behaviour as ExperimentalTemperaturePatch, each
# with its own TOGGLES table. Built with type() so the class __name__ keeps
# the underscore-separated suffix.
ExperimentalTemperaturePatchSDXL = type(
    "ExperimentalTemperaturePatch_SDXL",
    (ExperimentalTemperaturePatch,),
    {"TOGGLES": layers_SDXL},
)
ExperimentalTemperaturePatchSD15 = type(
    "ExperimentalTemperaturePatch_SD15",
    (ExperimentalTemperaturePatch,),
    {"TOGGLES": layers_SD15},
)
|
|
|
class CLIPTemperaturePatch:
    """Node that swaps the CLIP text encoders' attention for a temperature-scaled one.

    Both the ``clip_l`` and ``clip_g`` encoders are patched when present on
    the clone's ``cond_stage_model``; a missing encoder is skipped.
    """

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "patch"
    CATEGORY = "model_patches/Automatic_CFG/Standalone_temperature_patches"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification: a CLIP object and a temperature."""
        return {"required": {"clip": ("CLIP",),
                             "Temperature": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                             }}

    def patch(self, clip, Temperature):
        """Clone *clip* and replace ``forward`` on its text-model encoders.

        Args:
            clip: Object providing ``clone()``; the clone exposes the text
                encoders as ``cond_stage_model.clip_l`` / ``clip_g``.
            Temperature: Value handed to ``temperature_patcher``.

        Returns:
            One-element tuple holding the patched clone.
        """
        def custom_optimized_attention(device, mask=None, small_input=True):
            # The arguments are accepted but unused: the temperature-aware
            # attention is returned unconditionally.
            return temperature_patcher(Temperature).attention_basic_with_temperature

        def new_forward(self, x, mask=None, intermediate_output=None):
            # Replacement encoder forward: runs every layer with the
            # temperature attention and optionally captures one layer's output.
            optimized_attention = custom_optimized_attention(x.device, mask=mask is not None, small_input=True)

            # A negative index counts from the last layer.
            if intermediate_output is not None and intermediate_output < 0:
                intermediate_output = len(self.layers) + intermediate_output

            intermediate = None
            for i, layer in enumerate(self.layers):
                x = layer(x, mask, optimized_attention)
                if i == intermediate_output:
                    intermediate = x.clone()
            return x, intermediate

        m = clip.clone()

        # NOTE(review): forward is rebound on the encoder object reached
        # through the clone; whether the original clip shares that object
        # depends on clone() - confirm.
        for encoder_name in ("clip_l", "clip_g"):
            # Guard both encoders the same way: some CLIP stacks lack clip_l,
            # and dereferencing it blindly raised AttributeError.
            text_encoder = getattr(m.cond_stage_model, encoder_name, None)
            if text_encoder is None:
                continue
            encoder = text_encoder.transformer.text_model.encoder
            encoder.forward = types.MethodType(new_forward, encoder)

        return (m,)
|
|
|
class CLIPTemperaturePatchDual:
    """Node that applies temperature-scaled attention to a chosen CLIP encoder.

    ``CLIP_Model`` selects ``clip_l``, ``clip_g`` or both; a selected encoder
    that is absent from the clone's ``cond_stage_model`` is skipped.
    """

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "patch"
    CATEGORY = "model_patches/Automatic_CFG/Standalone_temperature_patches"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification: CLIP object, temperature, encoder choice."""
        return {"required": {"clip": ("CLIP",),
                             "Temperature": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                             "CLIP_Model": (["clip_g", "clip_l", "both"],),
                             }}

    def patch(self, clip, Temperature, CLIP_Model):
        """Clone *clip* and replace ``forward`` on the selected encoders.

        Args:
            clip: Object providing ``clone()``; the clone exposes the text
                encoders as ``cond_stage_model.clip_l`` / ``clip_g``.
            Temperature: Value handed to ``temperature_patcher``.
            CLIP_Model: ``"clip_l"``, ``"clip_g"`` or ``"both"``; any other
                value leaves the clone unpatched.

        Returns:
            One-element tuple holding the (possibly patched) clone.
        """
        def custom_optimized_attention(device, mask=None, small_input=True):
            # The arguments are accepted but unused: the temperature-aware
            # attention is returned unconditionally.
            return temperature_patcher(Temperature, "CLIP").attention_basic_with_temperature

        def new_forward(self, x, mask=None, intermediate_output=None):
            # Replacement encoder forward: runs every layer with the
            # temperature attention and optionally captures one layer's output.
            optimized_attention = custom_optimized_attention(x.device, mask=mask is not None, small_input=True)

            # A negative index counts from the last layer.
            if intermediate_output is not None and intermediate_output < 0:
                intermediate_output = len(self.layers) + intermediate_output

            intermediate = None
            for i, layer in enumerate(self.layers):
                x = layer(x, mask, optimized_attention)
                if i == intermediate_output:
                    intermediate = x.clone()
            return x, intermediate

        m = clip.clone()

        # NOTE(review): forward is rebound on the encoder object reached
        # through the clone; whether the original clip shares that object
        # depends on clone() - confirm.
        for encoder_name in ("clip_l", "clip_g"):
            if CLIP_Model not in (encoder_name, "both"):
                continue
            # Guard both encoders the same way: some CLIP stacks lack clip_l,
            # and dereferencing it blindly raised AttributeError.
            text_encoder = getattr(m.cond_stage_model, encoder_name, None)
            if text_encoder is None:
                continue
            encoder = text_encoder.transformer.text_model.encoder
            encoder.forward = types.MethodType(new_forward, encoder)

        return (m,)
|
|
|