|
|
|
|
|
|
|
|
import math
import os

import numpy as np
import torch
import torchvision.transforms as T
import yaml
from PIL import Image

import folder_paths
import node_helpers
from comfy.supported_models import FluxInpaint, models
from nodes import UNETLoader

try:
    from scepter.modules.utils.file_system import FS
    from scepter.modules.annotator.registry import ANNOTATORS
    from scepter.modules.utils.config import Config

    # Register every file-system backend scepter can fetch annotator
    # weights from (HuggingFace, ModelScope, plain HTTP, local disk).
    fs_list = [
        Config(cfg_dict={"NAME": "HuggingfaceFs", "TEMP_DIR": "./"}, load=False),
        Config(cfg_dict={"NAME": "ModelscopeFs", "TEMP_DIR": "./"}, load=False),
        Config(cfg_dict={"NAME": "HttpFs", "TEMP_DIR": "./"}, load=False),
        Config(cfg_dict={"NAME": "LocalFs", "TEMP_DIR": "./"}, load=False),
    ]
    for one_fs in fs_list:
        FS.init_fs_client(one_fs)
    SCEPTER = True
except Exception:
    SCEPTER = False
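# When SCEPTER is False, task types whose config carries a scepter ANNOTATOR
# raise an ImportError in edit_preprocess below; task types with a null
# ANNOTATOR (plain repainting) are unaffected.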
|
|
|
|
|
class ACEPlus(FluxInpaint):
    unet_config = {
        "image_model": "flux",
        "guidance_embed": True,
        "in_channels": 112,
    }
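# ComfyUI's model detection matches candidate configs (including in_channels)
# against the loaded checkpoint's state dict, so a Flux-format UNet with 112
# input channels gets routed to ACEPlus once the class is appended to
# comfy.supported_models.models (done lazily in the loader below).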
|
|
|
|
|
|
|
|
class ACEPlusFFTLoader(UNETLoader):
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"unet_name": (folder_paths.get_filename_list("diffusion_models"),),
                             "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],),
                             }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_unet"
    CATEGORY = "ComfyUI-ACE_Plus"

    def load_unet(self, unet_name, weight_dtype):
        # Register the ACE++ config before loading so model detection can
        # pick it up; guard against appending a duplicate on repeated loads.
        if ACEPlus not in models:
            models.append(ACEPlus)
        return super().load_unet(unet_name, weight_dtype)
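# Hedged usage sketch (the checkpoint filename is illustrative, not from the
# source): loading an ACE++ FFT checkpoint placed under models/diffusion_models
#   model, = ACEPlusFFTLoader().load_unet("ace_plus_fft.safetensors", "default")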
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ACEPlusFFTConditioning:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "vae": ("VAE", ),
                             "ucpixels": ("IMAGE", ),
                             "cpixels": ("IMAGE", ),
                             "mask": ("MASK", ),
                             "noise_mask": ("BOOLEAN", {"default": True, "tooltip": "Add a noise mask to the latent "
                                                                                    "so sampling will only happen "
                                                                                    "within the mask. Might improve "
                                                                                    "results or completely break "
                                                                                    "things depending on the model."}),
                             }}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    FUNCTION = "encode"

    CATEGORY = "ComfyUI-ACE_Plus"
|
|
|
|
|
    def encode(self, positive, negative, vae, ucpixels, cpixels, mask, noise_mask=True):
        # Crop spatial dims to multiples of 8 so the VAE can encode them.
        x = (ucpixels.shape[1] // 8) * 8
        y = (ucpixels.shape[2] // 8) * 8
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])),
                                               size=(ucpixels.shape[1], ucpixels.shape[2]), mode="bilinear")

        orig_pixels = ucpixels
        pixels = orig_pixels.clone()
        if pixels.shape[1] != x or pixels.shape[2] != y:
            x_offset = (pixels.shape[1] % 8) // 2
            y_offset = (pixels.shape[2] % 8) // 2
            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
            mask = mask[:, :, x_offset:x + x_offset, y_offset:y + y_offset]

        orig_c_pixels = cpixels
        c_pixels = orig_c_pixels.clone()
        if orig_c_pixels.shape[1] != x or orig_c_pixels.shape[2] != y:
            x_offset = (orig_c_pixels.shape[1] % 8) // 2
            y_offset = (orig_c_pixels.shape[2] % 8) // 2
            c_pixels = orig_c_pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]

        concat_latent = vae.encode(pixels)
        orig_latent = vae.encode(orig_pixels)
        c_concat_latent = vae.encode(c_pixels)

        out_latent = {"samples": orig_latent}
        if noise_mask:
            out_latent["noise_mask"] = mask

        # Both positive and negative conditioning get the unchanged/changed
        # latents concatenated along the channel axis, plus the inpaint mask.
        out = []
        for conditioning in [positive, negative]:
            c = node_helpers.conditioning_set_values(conditioning, {
                "concat_latent_image": torch.cat([concat_latent, c_concat_latent], dim=1),
                "concat_mask": mask})
            out.append(c)

        return (out[0], out[1], out_latent)
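# Hedged wiring sketch: in a workflow the FFT conditioning sits between the
# prompt encoders and the sampler, fed by AcePlusFFTProcessor further below:
#   uc_img, c_img, mask, out_h, out_w, slice_w = AcePlusFFTProcessor().preprocess(...)
#   pos, neg, latent = ACEPlusFFTConditioning().encode(pos, neg, vae, uc_img, c_img, mask)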
|
|
|
|
|
class ACEPlusLoraConditioning:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING",),
                             "negative": ("CONDITIONING",),
                             "vae": ("VAE",),
                             "pixels": ("IMAGE",),
                             "mask": ("MASK",),
                             "noise_mask": ("BOOLEAN", {"default": True,
                                                        "tooltip": "Add a noise mask to the latent so sampling will only happen within the mask. Might improve results or completely break things depending on the model."}),
                             }}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    FUNCTION = "encode"

    CATEGORY = "ComfyUI-ACE_Plus"

    def encode(self, positive, negative, pixels, vae, mask, noise_mask=True):
        # Same crop-to-multiple-of-8 handling as the FFT variant, but with a
        # single image input and a single concat latent.
        x = (pixels.shape[1] // 8) * 8
        y = (pixels.shape[2] // 8) * 8
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])),
                                               size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")

        orig_pixels = pixels
        pixels = orig_pixels.clone()
        if pixels.shape[1] != x or pixels.shape[2] != y:
            x_offset = (pixels.shape[1] % 8) // 2
            y_offset = (pixels.shape[2] % 8) // 2
            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
            mask = mask[:, :, x_offset:x + x_offset, y_offset:y + y_offset]

        concat_latent = vae.encode(pixels)
        orig_latent = vae.encode(orig_pixels)

        out_latent = {"samples": orig_latent}
        if noise_mask:
            out_latent["noise_mask"] = mask

        out = []
        for conditioning in [positive, negative]:
            c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
                                                                    "concat_mask": mask})
            out.append(c)
        return (out[0], out[1], out_latent)
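# Hedged usage sketch, mirroring ComfyUI's InpaintModelConditioning call shape:
#   pos, neg, latent = ACEPlusLoraConditioning().encode(pos, neg, image, vae, mask)
# Note the FFT variant above concatenates two latents along the channel axis,
# while this LoRA variant passes a single concat latent.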
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AcePlusFFTProcessor:
    def __init__(self,
                 max_aspect_ratio=4,
                 d=16,
                 max_seq_len=1024):
        self.max_aspect_ratio = max_aspect_ratio
        self.max_seq_len = max_seq_len
        # d is the patch stride used when budgeting the transformer sequence
        # length: roughly (H / d) * (W / d) tokens must fit in max_seq_len.
        self.d = d
        current_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(current_dir, 'config', 'ace_plus_fft_processor.yaml')
        self.processor_cfg = self.load_yaml(config_path)
        self.task_list = {}
        for task in self.processor_cfg['PREPROCESSOR']:
            self.task_list[task['TYPE']] = task
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0, 0, 0], std=[1.0, 1.0, 1.0])
        ])
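# Assumed shape of config/ace_plus_fft_processor.yaml, inferred from the
# TYPE/ANNOTATOR lookups below (a hedged sketch with placeholder values, not
# the shipped file):
#   PREPROCESSOR:
#     - TYPE: repainting
#       ANNOTATOR: null
#     - TYPE: <task name>
#       ANNOTATOR: {NAME: <scepter annotator config>}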
|
|
|
|
|
    CATEGORY = 'ComfyUI-ACE_Plus'

    @classmethod
    def INPUT_TYPES(s):
        return {
            'required': {
                'use_reference': ('BOOLEAN', {'default': True}),
                'height': ('INT', {
                    'default': 1024,
                    'min': 256,
                    'max': 1436,
                    'step': 16
                }),
                'width': ('INT', {
                    'default': 1024,
                    'min': 256,
                    'max': 1436,
                    'step': 16
                }),
                'task_type': (list(s().task_list.keys()),),
                'keep_pixels_rate': ('FLOAT', {
                    'default': 0.8,
                    'min': 0,
                    'max': 1,
                    'step': 0.01
                }),
                'max_seq_length': ('INT', {
                    'default': 3072,
                    'min': 1024,
                    'max': 5120,
                    'step': 1
                }),
            },
            'optional': {
                'reference_image': ('IMAGE',),
                'edit_image': ('IMAGE',),
                'edit_mask': ('MASK',),
            }
        }

    OUTPUT_NODE = True
    RETURN_TYPES = ('IMAGE', 'IMAGE', 'MASK', 'INT', 'INT', 'INT')
    RETURN_NAMES = ('UC_IMAGE', 'C_IMAGE', 'MASK', 'OUT_H', 'OUT_W', 'SLICE_W')
    FUNCTION = 'preprocess'
|
|
|
|
|
    def load_yaml(self, cfg_file):
        with open(cfg_file, 'r') as f:
            cfg = yaml.safe_load(f)
        return cfg
|
|
|
|
|
    def image_check(self, image):
        if image is None:
            return image

        # Clamp extreme aspect ratios with a center crop. Crop the whole
        # [B, C, H, W] tensor rather than assigning the smaller crop into a
        # slice, which would fail on shape mismatch.
        H, W = image.shape[1: 3]
        image = image.permute(0, 3, 1, 2)
        if H / W > self.max_aspect_ratio:
            image = T.CenterCrop([int(self.max_aspect_ratio * W), W])(image)
        elif W / H > self.max_aspect_ratio:
            image = T.CenterCrop([H, int(self.max_aspect_ratio * H)])(image)
        return image[0]
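# Worked example of the clamp above: a 1x512x4096x3 input has W/H = 8, which
# exceeds max_aspect_ratio = 4, so it is center-cropped to [H, 4*H] = 512x2048.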
|
|
|
|
|
    def trans_pil_tensor(self, pil_image):
        # PIL image -> float tensor in [0, 1], shape [C, H, W].
        transform = T.Compose([
            T.ToTensor()
        ])
        tensor_image = transform(pil_image)
        return tensor_image
|
|
|
|
|
    def edit_preprocess(self, processor, device, edit_image, edit_mask):
        # Run the task's scepter annotator over the edit image, then composite
        # the annotated pixels back in where the mask is set.
        if edit_image is None or processor is None:
            return edit_image
        if not SCEPTER:
            raise ImportError(f'Please install scepter to use edit processor {processor} by '
                              f'running "pip install scepter" in the conda env')
        processor = Config(cfg_dict=processor, load=False)
        processor = ANNOTATORS.build(processor).to(device)
        edit_image = Image.fromarray(np.array(edit_image[0] * 255).astype(np.uint8)).convert('RGB')
        new_edit_image = processor(np.asarray(edit_image))

        del processor
        new_edit_image = Image.fromarray(new_edit_image)
        if edit_mask is not None:
            edit_mask = np.where(edit_mask > 0.5, 1, 0) * 255
            edit_mask = Image.fromarray(np.array(edit_mask[0]).astype(np.uint8)).convert('L')

        # Resize the annotator output back to the original size (mirrors the
        # LoRA processor; the original assigned the result to the wrong
        # variable, leaving the size mismatch in place).
        if new_edit_image.size != edit_image.size:
            new_edit_image = T.Resize((edit_image.size[1], edit_image.size[0]),
                                      interpolation=T.InterpolationMode.BILINEAR,
                                      antialias=True)(new_edit_image)

        image = Image.composite(new_edit_image, edit_image, edit_mask)

        # Back to ComfyUI's IMAGE layout [B, H, W, C] in [0, 1].
        return self.trans_pil_tensor(image).unsqueeze(0).permute(0, 2, 3, 1)
|
|
|
|
|
    def preprocess(self,
                   reference_image=None,
                   edit_image=None,
                   edit_mask=None,
                   use_reference=True,
                   task_type=None,
                   height=1024,
                   width=1024,
                   keep_pixels_rate=0.8,
                   max_seq_length=4096):
        self.max_seq_len = max_seq_length
        if not use_reference and edit_image is not None:
            reference_image = None
        # A mask whose spatial size disagrees with the edit image is replaced
        # by an all-ones mask (repaint everything).
        if edit_mask is not None and edit_image is not None:
            iH, iW = edit_image.shape[1:3]
            mH, mW = edit_mask.shape[1:3]
            if iH != mH or iW != mW:
                edit_mask = torch.ones(edit_image.shape[:3])

        # Only the plain repainting task blanks out the masked pixels.
        if task_type != 'repainting':
            repainting_scale = 0
        else:
            repainting_scale = 1
        if task_type in self.task_list:
            edit_image = self.edit_preprocess(self.task_list[task_type]['ANNOTATOR'], 0,
                                              edit_image, edit_mask)
        # Shift pixel values from [0, 1] to [-0.5, 0.5].
        if reference_image is not None:
            reference_image = self.image_check(reference_image) - 0.5
        if edit_image is not None:
            edit_image = self.image_check(edit_image) - 0.5

        if edit_image is None:
            edit_image = torch.zeros([3, height, width])
            edit_mask = torch.ones([1, height, width])
        else:
            if edit_mask is None:
                _, eH, eW = edit_image.shape
                edit_mask = np.ones((eH, eW))
            else:
                edit_mask = np.asarray(edit_mask)[0]
                edit_mask = np.where(edit_mask > 0.5, 1, 0)
                # An empty mask degenerates to repainting the whole canvas.
                edit_mask = (edit_mask.astype(np.float32) if np.any(edit_mask)
                             else np.ones_like(edit_mask).astype(np.float32))
            edit_mask = torch.tensor(edit_mask).unsqueeze(0)

        edit_image = edit_image * (1 - edit_mask * repainting_scale)

        out_h, out_w = edit_image.shape[-2:]

        assert edit_mask is not None
        if reference_image is not None:
            _, H, W = reference_image.shape
            _, eH, eW = edit_image.shape
            # Shrink the reference so it takes at most keep_pixels_rate of the
            # edit canvas height, then zero-pad it back to full height.
            # (The original guarded the alternative full-height rescale with a
            # dead "if not True:" branch; that branch is dropped here.)
            if H >= keep_pixels_rate * eH:
                tH = int(eH * keep_pixels_rate)
                scale = tH / H
                tW = int(W * scale)
                reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(
                    reference_image)
            rH, rW = reference_image.shape[-2:]
            delta_w = 0
            delta_h = eH - rH
            padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
            reference_image = T.Pad(padding, fill=0, padding_mode="constant")(reference_image)
            # The reference sits to the left of the edit canvas; its mask
            # slice is zero so the reference is never repainted.
            edit_image = torch.cat([reference_image, edit_image], dim=-1)
            edit_mask = torch.cat([torch.zeros([1, reference_image.shape[1], reference_image.shape[2]]), edit_mask],
                                  dim=-1)
            slice_w = reference_image.shape[-1]
        else:
            slice_w = 0

        # Budget the token count: downscale until (H/d) * (W/d) <= max_seq_len,
        # snapping all sizes down to multiples of d.
        H, W = edit_image.shape[-2:]
        scale = min(1.0, math.sqrt(self.max_seq_len / ((H / self.d) * (W / self.d))))
        rH = int(H * scale) // self.d * self.d
        rW = int(W * scale) // self.d * self.d
        slice_w = int(slice_w * scale) // self.d * self.d

        edit_image = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_image)
        edit_mask = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_mask)

        # Split into the unchanged (UC) and changed (C) images and restore
        # ComfyUI's [B, H, W, C] layout in [0, 1].
        change_image = edit_image * edit_mask
        edit_image = edit_image * (1 - edit_mask)
        edit_image = edit_image.unsqueeze(0).permute(0, 2, 3, 1)
        change_image = change_image.unsqueeze(0).permute(0, 2, 3, 1)
        slice_w = slice_w if slice_w > 30 else slice_w + 30

        return edit_image + 0.5, change_image + 0.5, edit_mask, out_h, out_w, slice_w
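# Worked example of the sequence budget above (illustrative numbers): with
# d=16 and max_seq_len=3072, a 1024x2048 canvas has (1024/16)*(2048/16) = 8192
# tokens, so scale = sqrt(3072/8192) ~ 0.61 and the canvas is resized to
# 624x1248 (627 and 1254 snapped down to multiples of 16).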
|
|
|
|
|
class AcePlusLoraProcessor:
    # Shares its task configuration and most preprocessing with
    # AcePlusFFTProcessor; differences are noted inline below.
    def __init__(self,
                 max_aspect_ratio=4,
                 d=16,
                 max_seq_len=1024):
        self.max_aspect_ratio = max_aspect_ratio
        self.max_seq_len = max_seq_len
        self.d = d
        current_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(current_dir, 'config', 'ace_plus_fft_processor.yaml')
        self.processor_cfg = self.load_yaml(config_path)
        self.task_list = {}
        for task in self.processor_cfg['PREPROCESSOR']:
            self.task_list[task['TYPE']] = task
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0, 0, 0], std=[1.0, 1.0, 1.0])
        ])
|
|
|
|
|
    CATEGORY = 'ComfyUI-ACE_Plus'

    @classmethod
    def INPUT_TYPES(s):
        return {
            'required': {
                'use_reference': ('BOOLEAN', {'default': True}),
                'height': ('INT', {
                    'default': 1024,
                    'min': 256,
                    'max': 1436,
                    'step': 16
                }),
                'width': ('INT', {
                    'default': 1024,
                    'min': 256,
                    'max': 1436,
                    'step': 16
                }),
                'task_type': (list(s().task_list.keys()),),
                'max_seq_length': ('INT', {
                    'default': 3072,
                    'min': 1024,
                    'max': 5120,
                    'step': 1
                }),
            },
            'optional': {
                'reference_image': ('IMAGE',),
                'edit_image': ('IMAGE',),
                'edit_mask': ('MASK',),
            }
        }

    OUTPUT_NODE = True
    RETURN_TYPES = ('IMAGE', 'MASK', 'INT', 'INT', 'INT')
    RETURN_NAMES = ('IMAGE', 'MASK', 'OUT_H', 'OUT_W', 'SLICE_W')
    FUNCTION = 'preprocess'
|
|
|
|
|
    def load_yaml(self, cfg_file):
        with open(cfg_file, 'r') as f:
            cfg = yaml.safe_load(f)
        return cfg
|
|
|
|
|
    def image_check(self, image):
        if image is None:
            return image

        # Same aspect-ratio clamp as the FFT processor; crop the whole batch
        # tensor instead of assigning a smaller crop into a slice.
        H, W = image.shape[1: 3]
        image = image.permute(0, 3, 1, 2)
        if H / W > self.max_aspect_ratio:
            image = T.CenterCrop([int(self.max_aspect_ratio * W), W])(image)
        elif W / H > self.max_aspect_ratio:
            image = T.CenterCrop([H, int(self.max_aspect_ratio * H)])(image)
        return image[0]
|
|
|
|
|
    def trans_pil_tensor(self, pil_image):
        transform = T.Compose([
            T.ToTensor()
        ])
        tensor_image = transform(pil_image)
        return tensor_image
|
|
|
|
|
    def edit_preprocess(self, processor, device, edit_image, edit_mask):
        if edit_image is None or processor is None:
            return edit_image
        if not SCEPTER:
            raise ImportError(f'Please install scepter to use edit processor {processor} by '
                              f'running "pip install scepter" in the conda env')
        processor = Config(cfg_dict=processor, load=False)
        processor = ANNOTATORS.build(processor).to(device)
        edit_image = Image.fromarray(np.array(edit_image[0] * 255).astype(np.uint8)).convert('RGB')
        new_edit_image = processor(np.asarray(edit_image))

        del processor
        new_edit_image = Image.fromarray(new_edit_image)
        if edit_mask is not None:
            edit_mask = np.where(edit_mask > 0.5, 1, 0) * 255
            edit_mask = Image.fromarray(np.array(edit_mask[0]).astype(np.uint8)).convert('L')

        if new_edit_image.size != edit_image.size:
            new_edit_image = T.Resize((edit_image.size[1], edit_image.size[0]),
                                      interpolation=T.InterpolationMode.BILINEAR,
                                      antialias=True)(new_edit_image)

        image = Image.composite(new_edit_image, edit_image, edit_mask)

        return self.trans_pil_tensor(image).unsqueeze(0).permute(0, 2, 3, 1)
|
|
|
|
|
    def preprocess(self,
                   reference_image=None,
                   edit_image=None,
                   edit_mask=None,
                   use_reference=True,
                   task_type=None,
                   height=1024,
                   width=1024,
                   max_seq_length=4096):
        self.max_seq_len = max_seq_length
        if not use_reference and edit_image is not None:
            reference_image = None
        if edit_mask is not None and edit_image is not None:
            iH, iW = edit_image.shape[1:3]
            mH, mW = edit_mask.shape[1:3]
            if iH != mH or iW != mW:
                edit_mask = torch.ones(edit_image.shape[:3])

        if task_type != 'repainting':
            repainting_scale = 0.0
        else:
            repainting_scale = 1.0
        if task_type in self.task_list:
            edit_image = self.edit_preprocess(self.task_list[task_type]['ANNOTATOR'], 0,
                                              edit_image, edit_mask)
        if reference_image is not None:
            reference_image = self.image_check(reference_image) - 0.5
        if edit_image is not None:
            edit_image = self.image_check(edit_image) - 0.5

        if edit_image is None:
            edit_image = torch.zeros([3, height, width])
            edit_mask = torch.ones([1, height, width])
        else:
            if edit_mask is None:
                _, eH, eW = edit_image.shape
                edit_mask = np.ones((eH, eW))
            else:
                edit_mask = np.asarray(edit_mask)[0]
                edit_mask = np.where(edit_mask > 0.5, 1, 0)
                edit_mask = (edit_mask.astype(np.float32) if np.any(edit_mask)
                             else np.ones_like(edit_mask).astype(np.float32))
            edit_mask = torch.tensor(edit_mask).unsqueeze(0)

        edit_image = edit_image * (1 - edit_mask * repainting_scale)

        out_h, out_w = edit_image.shape[-2:]

        assert edit_mask is not None
        if reference_image is not None:
            _, H, W = reference_image.shape
            _, eH, eW = edit_image.shape
            # Unlike the FFT processor, the reference is always scaled to the
            # full edit-canvas height (no keep_pixels_rate shrink-and-pad).
            scale = eH / H
            tH, tW = eH, int(W * scale)
            reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(
                reference_image)
            edit_image = torch.cat([reference_image, edit_image], dim=-1)
            edit_mask = torch.cat([torch.zeros([1, reference_image.shape[1], reference_image.shape[2]]), edit_mask],
                                  dim=-1)
            slice_w = reference_image.shape[-1]
        else:
            slice_w = 0

        # Same token budget as the FFT processor, but with twice the allowance.
        H, W = edit_image.shape[-2:]
        scale = min(1.0, math.sqrt(self.max_seq_len * 2 / ((H / self.d) * (W / self.d))))
        rH = int(H * scale) // self.d * self.d
        rW = int(W * scale) // self.d * self.d
        slice_w = int(slice_w * scale) // self.d * self.d

        edit_image = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_image)
        edit_mask = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_mask)

        edit_image = edit_image.unsqueeze(0).permute(0, 2, 3, 1)
        slice_w = slice_w if slice_w < 30 else slice_w + 30

        return edit_image + 0.5, edit_mask, out_h, out_w, slice_w
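# Hedged registration sketch: ComfyUI discovers custom nodes through
# NODE_CLASS_MAPPINGS (and the optional NODE_DISPLAY_NAME_MAPPINGS) exported
# from the pack's entry module, usually __init__.py. If the actual pack
# registers these elsewhere, treat this as illustrative; the display names
# below are assumptions, not taken from the source.
NODE_CLASS_MAPPINGS = {
    "ACEPlusFFTLoader": ACEPlusFFTLoader,
    "ACEPlusFFTConditioning": ACEPlusFFTConditioning,
    "ACEPlusLoraConditioning": ACEPlusLoraConditioning,
    "AcePlusFFTProcessor": AcePlusFFTProcessor,
    "AcePlusLoraProcessor": AcePlusLoraProcessor,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "ACEPlusFFTLoader": "ACE++ FFT Loader",
    "ACEPlusFFTConditioning": "ACE++ FFT Conditioning",
    "ACEPlusLoraConditioning": "ACE++ LoRA Conditioning",
    "AcePlusFFTProcessor": "ACE++ FFT Processor",
    "AcePlusLoraProcessor": "ACE++ LoRA Processor",
}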
|
|
|