"""Utilities for HiDiffusion patches."""
from __future__ import annotations

import contextlib
import importlib
import importlib.util  # explicit: "import importlib" alone does not guarantee the util submodule
import itertools
import logging
import math
import sys
from enum import Enum
from functools import partial
from typing import TYPE_CHECKING, Callable, NamedTuple

import torch.nn.functional as F

from src.Utilities import Latent, upscale
# Logger for HiDiffusion modules
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from collections.abc import Sequence
from types import ModuleType
try:
    from enum import StrEnum
except ImportError:  # Python < 3.11: provide a minimal stand-in.
    class StrEnum(str, Enum):
        """Fallback string enum; auto() values are the lowercased member names."""

        @staticmethod
        def _generate_next_value_(name, *_):
            return name.lower()

        def __str__(self):
            return str(self.value)

# Upscale modes accepted by scale_samples (may be swapped out by the bleh integration).
UPSCALE_METHODS = ("bicubic", "bislerp", "bilinear", "nearest-exact", "nearest", "area")
class TimeMode(StrEnum):
    """How user-supplied start/end time values are interpreted (see convert_time)."""
    PERCENT = "percent"    # fraction of sampling progress, fed to percent_to_sigma
    TIMESTEP = "timestep"  # raw timestep, mapped via 1 - t/999 before conversion
    SIGMA = "sigma"        # already a sigma; passed through unchanged
class ModelType(StrEnum):
    """Model families supported by guess_model_type."""
    SD15 = "SD15"
    SDXL = "SDXL"
def parse_blocks(name: str, val) -> set[tuple[str, int]]:
    """Parse block definitions into a set of ``(name, index)`` pairs.

    ``val`` may be a tuple/list of ints or a comma-separated string such as
    ``"1, 2, 5"``.  Fix: the string form now applies the same non-negative
    filter as the sequence form (previously negatives slipped through).

    Args:
        name: Block-type name attached to every index (e.g. "input").
        val: Sequence of ints, or comma-separated string of ints.

    Returns:
        Set of ``(name, index)`` tuples with ``index >= 0``.

    Raises:
        ValueError: If a non-blank string entry is not a valid integer.
    """
    if isinstance(val, (tuple, list)):
        # Silently skip non-int and negative entries, as before.
        return {(name, item) for item in val if isinstance(item, int) and item >= 0}
    result = set()
    for chunk in str(val).split(","):
        chunk = chunk.strip()
        # Skip blanks; drop negatives for consistency with the sequence branch.
        if chunk and (idx := int(chunk)) >= 0:
            result.add((name, idx))
    return result
def convert_time(ms, time_mode: TimeMode, start: float, end: float) -> tuple[float, float]:
    """Convert a (start, end) pair in the given time mode into sigmas.

    SIGMA values pass through untouched.  TIMESTEP values are first mapped
    to sampling percentages (1 - t/999) and then, like PERCENT values,
    converted through ``ms.percent_to_sigma`` and rounded to 4 places.
    """
    if time_mode == TimeMode.SIGMA:
        return start, end
    if time_mode == TimeMode.TIMESTEP:
        start = 1.0 - start / 999.0
        end = 1.0 - end / 999.0
    sigma_start = round(ms.percent_to_sigma(start), 4)
    sigma_end = round(ms.percent_to_sigma(end), 4)
    return sigma_start, sigma_end
# Tiny memoization tables for tensor -> float conversions (see get_sigma / sigma_to_pct).
_sigma_cache, _pct_cache = {}, {}
def get_sigma(options, key="sigmas"):
    """Return the maximum sigma stored under ``key`` in ``options``.

    Returns None when ``options`` is not a dict or lacks the key.  Plain
    floats are returned as-is; tensor values have their max extracted once
    and memoized by object id (the cache is flushed past 4 entries).
    NOTE(review): id()-keyed caching can alias if a tensor is collected and
    its id reused -- presumed acceptable here; confirm if results look stale.
    """
    if not isinstance(options, dict):
        return None
    sigmas = options.get(key)
    if sigmas is None:
        return None
    if isinstance(sigmas, float):
        return sigmas
    ck = id(sigmas)
    cached = _sigma_cache.get(ck)
    if cached is None:
        if len(_sigma_cache) > 4:
            _sigma_cache.clear()
        cached = sigmas.detach().cpu().max().item()
        _sigma_cache[ck] = cached
    return cached
def check_time(time_arg, start_sigma: float, end_sigma: float) -> bool:
    """Return True when the current sigma lies within [end_sigma, start_sigma].

    ``time_arg`` is either a sigma (float) or a transformer-options dict
    understood by get_sigma.  An unresolvable sigma yields False.
    """
    if isinstance(time_arg, float):
        sigma = time_arg
    else:
        sigma = get_sigma(time_arg)
    if sigma is None:
        return False
    return end_sigma <= sigma <= start_sigma
# Numeric codes for the three UNet block categories.
_block_map = {"input": 0, "middle": 1, "output": 2}
def block_to_num(block_type: str, block_id: int) -> tuple[int, int]:
    """Map a block-type name plus index to its ``(type_code, index)`` pair.

    Raises:
        ValueError: If ``block_type`` is not one of input/middle/output.
    """
    type_code = _block_map.get(block_type)
    if type_code is None:
        raise ValueError(f"Unexpected block type {block_type}")
    return type_code, block_id
def rescale_size(width: int, height: int, target_res: int, tolerance=1) -> tuple[int, int]:
    """Rescale (width, height) to an integer pair whose product is exactly target_res.

    Searches integer candidates within ``tolerance`` of the ideal (fractional)
    scaled dimensions, preferring the candidates closest to the exact aspect
    ratio.

    Raises:
        ValueError: If no candidate within tolerance divides target_res evenly.
    """
    tolerance = min(target_res, tolerance)
    # Ideal fractional dimensions: hs * ws == target_res with the original aspect ratio.
    scale = math.sqrt(height * width / target_res)
    hs, ws = height / scale, width / scale
    def neighbors(n):
        # Integer candidates around n, ordered closest-first (sorted by |offset|).
        ni = int(n)
        return [ni + adj for adj in sorted(range(-min(ni-1, tolerance), tolerance+1+math.ceil(n-ni)), key=abs)]
    # Alternate height/width candidates; accept the first whose complementary
    # dimension (target_res divided by it) comes out as a whole number.
    for h, w in itertools.zip_longest(neighbors(hs), neighbors(ws)):
        if w and (ha := target_res / w) % 1 == 0: return w, int(ha)
        if h and (wa := target_res / h) % 1 == 0: return int(wa), h
    raise ValueError(f"Can't rescale {width}x{height} to {target_res}")
def guess_model_type(model) -> ModelType | None:
    """Guess whether ``model`` is SD15 or SDXL from its latent format.

    Returns None when the format is absent or not a target of HiDiffusion
    (e.g. 16/32-channel Flux or SD3 formats).
    """
    lf = model.get_model_object("latent_format")
    if lf is None:
        return None
    # 1. Explicit latent-format class checks (most reliable).
    try:
        if isinstance(lf, (Latent.SDXL, Latent.SDXL_Playground_2_5)):
            return ModelType.SDXL
        if isinstance(lf, Latent.SD15):
            return ModelType.SD15
    except Exception:
        # Latent classes unavailable/mismatched -- fall through to heuristics.
        pass
    # 2. Channel-count heuristic as fallback.
    channels = getattr(lf, "latent_channels", None)
    if channels == 4:
        # Default 4-channel formats to SD15 unless explicitly SDXL above.
        return ModelType.SD15
    if channels == 8:
        # Some SDXL implementations/VAEs use 8 channels.
        return ModelType.SDXL
    # 3. Anything else (16/32 channels: Flux, SD3) is unsupported here.
    return None
def sigma_to_pct(ms, sigma):
    """Convert a sigma to a sampling-progress percentage clamped to [0, 1].

    ``ms`` is a model-sampling object providing ``timestep``.  Tensor inputs
    are reduced to a plain float and memoized by object id (cache flushed
    past 4 entries).
    NOTE(review): the float branch calls ``.clamp`` on the result without
    ``.item()`` -- it assumes ``ms.timestep`` returns a tensor even for a
    float sigma and then returns a tensor rather than a float; confirm that
    callers of the float path expect a tensor.
    """
    if isinstance(sigma, float):
        return (1.0 - ms.timestep(sigma) / 999.0).clamp(0.0, 1.0)
    cache_key = id(sigma)
    if cache_key not in _pct_cache:
        # Tiny bound on the memo table; clear wholesale rather than evict.
        if len(_pct_cache) > 4: _pct_cache.clear()
        _pct_cache[cache_key] = (1.0 - ms.timestep(sigma).detach().cpu() / 999.0).clamp(0.0, 1.0).item()
    return _pct_cache[cache_key]
def fade_scale(pct, start_pct=0.0, end_pct=1.0, fade_start=1.0, fade_cap=0.0):
    """Scale factor in [0, 1] with a linear fade-out.

    Returns 0.0 when ``pct`` is outside [start_pct, end_pct] (or the range is
    inverted), 1.0 before ``fade_start``, then fades linearly from 1.0 down
    to ``fade_cap`` between ``fade_start`` and ``end_pct``.

    Bug fix: a zero-width fade window (fade_start >= end_pct, e.g.
    ``fade_scale(1.0)`` with the defaults) previously raised
    ZeroDivisionError; it now returns the formula's limit value of 1.0
    (still respecting fade_cap).
    """
    if start_pct > end_pct or not (start_pct <= pct <= end_pct):
        return 0.0
    if pct < fade_start:
        return 1.0
    span = end_pct - fade_start
    if span <= 0.0:
        # Degenerate window: pct == fade_start == end_pct; the linear
        # formula's limit as span -> 0 with zero numerator is 1.0.
        return max(fade_cap, 1.0)
    return max(fade_cap, 1.0 - (pct - fade_start) / span)
def scale_samples(samples, width, height, mode="bicubic", sigma=None):
    """Resize an NCHW sample tensor to (height, width) using ``mode``.

    ``sigma`` is accepted for interface compatibility with the bleh
    integration's scaler (see init_integrations) and is unused here.
    """
    if mode != "bislerp":
        return F.interpolate(samples, size=(height, width), mode=mode)
    return upscale.bislerp(samples, width, height)
class Integrations:
"""Integration manager."""
class Integration(NamedTuple):
key: str
module_name: str
handler: Callable | None = None
def __init__(self):
self.initialized, self.modules, self.init_handlers, self.handlers = False, {}, [], []
def __getitem__(self, key): return self.modules[key]
def __contains__(self, key): return key in self.modules
def __getattr__(self, key): return self.modules.get(key)
@staticmethod
def get_custom_node(name: str):
module_key = f"custom_nodes.{name}"
with contextlib.suppress(StopIteration):
spec = importlib.util.find_spec(module_key)
if spec:
return next((v for v in sys.modules.copy().values()
if hasattr(v, "__spec__") and v.__spec__ and v.__spec__.origin == spec.origin), None)
return None
def register_init_handler(self, h): self.init_handlers.append(h)
def register_integration(self, key, module_name, handler=None):
if self.initialized: raise ValueError("Cannot register after init")
self.handlers.append(self.Integration(key, module_name, handler))
def initialize(self):
if self.initialized: return
self.initialized = True
for ih in self.handlers:
if (mod := self.get_custom_node(ih.module_name)):
mod = ih.handler(mod) if ih.handler else mod
if mod: self.modules[ih.key] = mod
for h in self.init_handlers: h(self)
class JHDIntegrations(Integrations):
    """Integration registry preconfigured with the nodes this package can use."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # bleh gets a version-gate handler; FreeU_Advanced is taken as-is.
        self.register_integration("bleh", "ComfyUI-bleh", self.bleh_integration)
        self.register_integration("freeu_advanced", "FreeU_Advanced")

    @classmethod
    def bleh_integration(cls, bleh):
        """Accept the bleh module only when it reports a usable BLEH_VERSION."""
        if getattr(bleh, "BLEH_VERSION", -1) >= 0:
            return bleh
        return None


# Shared registry instance used by the node machinery below.
MODULES = JHDIntegrations()
class IntegratedNode(type):
    """Metaclass that initializes MODULES right before a node's INPUT_TYPES runs."""

    @staticmethod
    def wrap_INPUT_TYPES(orig, *args, **kwargs):
        # Resolve optional integrations lazily, at the moment ComfyUI first
        # asks the node for its inputs.
        MODULES.initialize()
        return orig(*args, **kwargs)

    def __new__(cls, name, bases, attrs):
        node_class = type.__new__(cls, name, bases, attrs)
        if hasattr(node_class, "INPUT_TYPES"):
            node_class.INPUT_TYPES = partial(cls.wrap_INPUT_TYPES, node_class.INPUT_TYPES)
        return node_class
def init_integrations(integrations):
    """Adopt bleh's latent-upscale helpers when that integration is present.

    Rebinds the module-level ``scale_samples`` and ``UPSCALE_METHODS``
    globals to bleh's versions.
    """
    global scale_samples, UPSCALE_METHODS
    bleh = integrations.bleh
    if not bleh:
        return
    latent_utils = getattr(bleh.py, "latent_utils", None)
    if not latent_utils:
        return
    UPSCALE_METHODS = latent_utils.UPSCALE_METHODS
    if getattr(bleh, "BLEH_VERSION", -1) >= 0:
        # Recent bleh understands the sigma keyword directly.
        scale_samples = latent_utils.scale_samples
    else:
        # Older bleh: strip the sigma keyword before delegating.
        def scale_samples(*args, sigma=None, **kwargs):
            return latent_utils.scale_samples(*args, **kwargs)


MODULES.register_init_handler(init_integrations)
# Public API for star imports.
# NOTE(review): several public-looking names (TimeMode, ModelType, fade_scale,
# sigma_to_pct, block_to_num, MODULES) are absent -- confirm this is intended.
__all__ = ("UPSCALE_METHODS", "check_time", "convert_time", "get_sigma", "guess_model_type",
           "logger", "parse_blocks", "rescale_size", "scale_samples")