| import math |
| from enum import Enum |
| from collections import namedtuple |
| from types import MethodType |
|
|
| import cv2 |
| import torch |
| import numpy as np |
| from tqdm import tqdm |
|
|
| from modules import devices, shared, prompt_parser, extra_networks |
| from modules import sd_samplers_common |
| from modules.shared import state |
| from modules.processing import opt_f |
|
|
| from tile_utils.typing import * |
|
|
| state: State |
|
|
|
|
| class ComparableEnum(Enum): |
|
|
| def __eq__(self, other: Any) -> bool: |
| if isinstance(other, str): return self.value == other |
| elif isinstance(other, ComparableEnum): return self.value == other.value |
| else: raise TypeError(f'unsupported type: {type(other)}') |
|
|
| class Method(ComparableEnum): |
|
|
| MULTI_DIFF = 'MultiDiffusion' |
| MIX_DIFF = 'Mixture of Diffusers' |
|
|
| class Method_2(ComparableEnum): |
| DEMO_FU = "DemoFusion" |
|
|
| class BlendMode(Enum): |
|
|
| FOREGROUND = 'Foreground' |
| BACKGROUND = 'Background' |
|
|
| BBoxSettings = namedtuple('BBoxSettings', ['enable', 'x', 'y', 'w', 'h', 'prompt', 'neg_prompt', 'blend_mode', 'feather_ratio', 'seed']) |
| NoiseInverseCache = namedtuple('NoiseInversionCache', ['model_hash', 'x0', 'xt', 'noise_inversion_steps', 'retouch', 'prompts']) |
| DEFAULT_BBOX_SETTINGS = BBoxSettings(False, 0.4, 0.4, 0.2, 0.2, '', '', BlendMode.BACKGROUND.value, 0.2, -1) |
| NUM_BBOX_PARAMS = len(BBoxSettings._fields) |
|
|
|
|
| def build_bbox_settings(bbox_control_states:List[Any]) -> Dict[int, BBoxSettings]: |
| settings = {} |
| for index, i in enumerate(range(0, len(bbox_control_states), NUM_BBOX_PARAMS)): |
| setting = BBoxSettings(*bbox_control_states[i:i+NUM_BBOX_PARAMS]) |
| |
| setting = setting._replace( |
| x=round(setting.x, 4), |
| y=round(setting.y, 4), |
| w=round(setting.w, 4), |
| h=round(setting.h, 4), |
| feather_ratio=round(setting.feather_ratio, 4), |
| seed=int(setting.seed), |
| ) |
| |
| if not setting.enable or setting.x > 1.0 or setting.y > 1.0 or setting.w <= 0.0 or setting.h <= 0.0: continue |
| settings[index] = setting |
| return settings |
|
|
| def gr_value(value=None, visible=None): |
| return {"value": value, "visible": visible, "__type__": "update"} |
|
|
|
|
| class BBox: |
|
|
| ''' grid bbox ''' |
|
|
| def __init__(self, x:int, y:int, w:int, h:int): |
| self.x = x |
| self.y = y |
| self.w = w |
| self.h = h |
| self.box = [x, y, x+w, y+h] |
| self.slicer = slice(None), slice(None), slice(y, y+h), slice(x, x+w) |
|
|
| def __getitem__(self, idx:int) -> int: |
| return self.box[idx] |
|
|
| class CustomBBox(BBox): |
|
|
| ''' region control bbox ''' |
|
|
| def __init__(self, x:int, y:int, w:int, h:int, prompt:str, neg_prompt:str, blend_mode:str, feather_radio:float, seed:int): |
| super().__init__(x, y, w, h) |
| self.prompt = prompt |
| self.neg_prompt = neg_prompt |
| self.blend_mode = BlendMode(blend_mode) |
| self.feather_ratio = max(min(feather_radio, 1.0), 0.0) |
| self.seed = seed |
| |
| self.feather_mask = feather_mask(self.w, self.h, self.feather_ratio) if self.blend_mode == BlendMode.FOREGROUND else None |
| self.cond: MulticondLearnedConditioning = None |
| self.extra_network_data: DefaultDict[List[ExtraNetworkParams]] = None |
| self.uncond: List[List[ScheduledPromptConditioning]] = None |
|
|
|
|
| class Prompt: |
|
|
| ''' prompts helper ''' |
|
|
| @staticmethod |
| def apply_styles(prompts:List[str], styles=None) -> List[str]: |
| if not styles: return prompts |
| return [shared.prompt_styles.apply_styles_to_prompt(p, styles) for p in prompts] |
|
|
| @staticmethod |
| def append_prompt(prompts:List[str], prompt:str='') -> List[str]: |
| if not prompt: return prompts |
| return [f'{p}, {prompt}' for p in prompts] |
|
|
| class Condition: |
|
|
| ''' CLIP cond helper ''' |
|
|
| @staticmethod |
| def get_custom_cond(prompts:List[str], prompt, steps:int, styles=None) -> Tuple[Cond, ExtraNetworkData]: |
| prompt = Prompt.apply_styles([prompt], styles)[0] |
| _, extra_network_data = extra_networks.parse_prompts([prompt]) |
| prompts = Prompt.append_prompt(prompts, prompt) |
| prompts = Prompt.apply_styles(prompts, styles) |
| cond = Condition.get_cond(prompts, steps) |
| return cond, extra_network_data |
| |
| @staticmethod |
| def get_cond(prompts, steps:int): |
| prompts, _ = extra_networks.parse_prompts(prompts) |
| cond = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, steps) |
| return cond |
|
|
| @staticmethod |
| def get_uncond(neg_prompts:List[str], steps:int, styles=None) -> Uncond: |
| neg_prompts = Prompt.apply_styles(neg_prompts, styles) |
| uncond = prompt_parser.get_learned_conditioning(shared.sd_model, neg_prompts, steps) |
| return uncond |
|
|
| @staticmethod |
| def reconstruct_cond(cond:Cond, step:int) -> Tensor: |
| list_of_what, tensor = prompt_parser.reconstruct_multicond_batch(cond, step) |
| return tensor |
|
|
| def reconstruct_uncond(uncond:Uncond, step:int) -> Tensor: |
| tensor = prompt_parser.reconstruct_cond_batch(uncond, step) |
| return tensor |
|
|
|
|
| def splitable(w:int, h:int, tile_w:int, tile_h:int, overlap:int=16) -> bool: |
| w, h = w // opt_f, h // opt_f |
| min_tile_size = min(tile_w, tile_h) |
| if overlap >= min_tile_size: |
| overlap = min_tile_size - 4 |
| cols = math.ceil((w - overlap) / (tile_w - overlap)) |
| rows = math.ceil((h - overlap) / (tile_h - overlap)) |
| return cols > 1 or rows > 1 |
|
|
| def split_bboxes(w:int, h:int, tile_w:int, tile_h:int, overlap:int=16, init_weight:Union[Tensor, float]=1.0) -> Tuple[List[BBox], Tensor]: |
| cols = math.ceil((w - overlap) / (tile_w - overlap)) |
| rows = math.ceil((h - overlap) / (tile_h - overlap)) |
| dx = (w - tile_w) / (cols - 1) if cols > 1 else 0 |
| dy = (h - tile_h) / (rows - 1) if rows > 1 else 0 |
|
|
| bbox_list: List[BBox] = [] |
| weight = torch.zeros((1, 1, h, w), device=devices.device, dtype=torch.float32) |
| for row in range(rows): |
| y = min(int(row * dy), h - tile_h) |
| for col in range(cols): |
| x = min(int(col * dx), w - tile_w) |
|
|
| bbox = BBox(x, y, tile_w, tile_h) |
| bbox_list.append(bbox) |
| weight[bbox.slicer] += init_weight |
|
|
| return bbox_list, weight |
|
|
|
|
| def gaussian_weights(tile_w:int, tile_h:int) -> Tensor: |
| ''' |
| Copy from the original implementation of Mixture of Diffusers |
| https://github.com/albarji/mixture-of-diffusers/blob/master/mixdiff/tiling.py |
| This generates gaussian weights to smooth the noise of each tile. |
| This is critical for this method to work. |
| ''' |
| from numpy import pi, exp, sqrt |
| |
| f = lambda x, midpoint, var=0.01: exp(-(x-midpoint)*(x-midpoint) / (tile_w*tile_w) / (2*var)) / sqrt(2*pi*var) |
| x_probs = [f(x, (tile_w - 1) / 2) for x in range(tile_w)] |
| y_probs = [f(y, tile_h / 2) for y in range(tile_h)] |
|
|
| w = np.outer(y_probs, x_probs) |
| return torch.from_numpy(w).to(devices.device, dtype=torch.float32) |
|
|
| def feather_mask(w:int, h:int, ratio:float) -> Tensor: |
| '''Generate a feather mask for the bbox''' |
|
|
| mask = np.ones((h, w), dtype=np.float32) |
| feather_radius = int(min(w//2, h//2) * ratio) |
| |
| |
| |
| for i in range(h//2): |
| for j in range(w//2): |
| dist = min(i, j) |
| if dist >= feather_radius: continue |
| weight = (dist / feather_radius) ** 2 |
| mask[i, j] = weight |
| mask[i, w-j-1] = weight |
| mask[h-i-1, j] = weight |
| mask[h-i-1, w-j-1] = weight |
|
|
| return torch.from_numpy(mask).to(devices.device, dtype=torch.float32) |
|
|
| def get_retouch_mask(img_input: np.ndarray, kernel_size: int) -> np.ndarray: |
| ''' |
| Return the area where the image is retouched. |
| Copy from Zhihu.com |
| ''' |
| step = 1 |
| kernel = (kernel_size, kernel_size) |
| |
| img = img_input.astype(np.float32)/255.0 |
| sz = img.shape[:2] |
| sz1 = (int(round(sz[1] * step)), int(round(sz[0] * step))) |
| sz2 = (int(round(kernel[0] * step)), int(round(kernel[0] * step))) |
| sI = cv2.resize(img, sz1, interpolation=cv2.INTER_LINEAR) |
| sp = cv2.resize(img, sz1, interpolation=cv2.INTER_LINEAR) |
| msI = cv2.blur(sI, sz2) |
| msp = cv2.blur(sp, sz2) |
| msII = cv2.blur(sI*sI, sz2) |
| msIp = cv2.blur(sI*sp, sz2) |
| vsI = msII - msI*msI |
| csIp = msIp - msI*msp |
| recA = csIp/(vsI+0.01) |
| recB = msp - recA*msI |
| mA = cv2.resize(recA, (sz[1],sz[0]), interpolation=cv2.INTER_LINEAR) |
| mB = cv2.resize(recB, (sz[1],sz[0]), interpolation=cv2.INTER_LINEAR) |
| |
| gf = mA * img + mB |
| gf -= img |
| gf *= 255 |
| gf = gf.astype(np.uint8) |
| gf = gf.clip(0, 255) |
| gf = gf.astype(np.float32)/255.0 |
| return gf |
|
|
|
|
| def null_decorator(fn): |
| def wrapper(*args, **kwargs): |
| return fn(*args, **kwargs) |
| return wrapper |
|
|
| keep_signature = null_decorator |
| controlnet = null_decorator |
| stablesr = null_decorator |
| grid_bbox = null_decorator |
| custom_bbox = null_decorator |
| noise_inverse = null_decorator |
|
|