Commit
·
4c2ce48
1
Parent(s):
4744be5
- modules/adm_patch.py +33 -0
- modules/core.py +3 -0
modules/adm_patch.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import comfy.model_base
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def sdxl_encode_adm_patched(self, **kwargs):
|
| 6 |
+
clip_pooled = kwargs["pooled_output"]
|
| 7 |
+
width = kwargs.get("width", 768)
|
| 8 |
+
height = kwargs.get("height", 768)
|
| 9 |
+
crop_w = kwargs.get("crop_w", 0)
|
| 10 |
+
crop_h = kwargs.get("crop_h", 0)
|
| 11 |
+
target_width = kwargs.get("target_width", width)
|
| 12 |
+
target_height = kwargs.get("target_height", height)
|
| 13 |
+
|
| 14 |
+
if kwargs.get("prompt_type", "") == "negative":
|
| 15 |
+
admk = 0.8
|
| 16 |
+
width *= admk
|
| 17 |
+
height *= admk
|
| 18 |
+
target_width *= admk
|
| 19 |
+
target_height *= admk
|
| 20 |
+
|
| 21 |
+
out = []
|
| 22 |
+
out.append(self.embedder(torch.Tensor([height])))
|
| 23 |
+
out.append(self.embedder(torch.Tensor([width])))
|
| 24 |
+
out.append(self.embedder(torch.Tensor([crop_h])))
|
| 25 |
+
out.append(self.embedder(torch.Tensor([crop_w])))
|
| 26 |
+
out.append(self.embedder(torch.Tensor([target_height])))
|
| 27 |
+
out.append(self.embedder(torch.Tensor([target_width])))
|
| 28 |
+
flat = torch.flatten(torch.cat(out))[None, ]
|
| 29 |
+
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def patch_negative_adm():
|
| 33 |
+
comfy.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
|
modules/core.py
CHANGED
|
@@ -12,7 +12,10 @@ from comfy.sd import load_checkpoint_guess_config
|
|
| 12 |
from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode
|
| 13 |
from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
|
| 14 |
from modules.samplers_advanced import KSampler, KSamplerWithRefiner
|
|
|
|
| 15 |
|
|
|
|
|
|
|
| 16 |
opCLIPTextEncode = CLIPTextEncode()
|
| 17 |
opEmptyLatentImage = EmptyLatentImage()
|
| 18 |
opVAEDecode = VAEDecode()
|
|
|
|
| 12 |
from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode
|
| 13 |
from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
|
| 14 |
from modules.samplers_advanced import KSampler, KSamplerWithRefiner
|
| 15 |
+
from modules.adm_patch import patch_negative_adm
|
| 16 |
|
| 17 |
+
|
| 18 |
+
patch_negative_adm()
|
| 19 |
opCLIPTextEncode = CLIPTextEncode()
|
| 20 |
opEmptyLatentImage = EmptyLatentImage()
|
| 21 |
opVAEDecode = VAEDecode()
|