tellurion's picture
Clean up dead code and add startup model loading
1928ea4
import cv2
import numpy as np
import PIL.Image as Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from functools import partial
# Largest width/height (pixels) visualize_heatmaps lets a reference image keep;
# bigger images are downscaled. The misspelt name is kept since other code may import it.
maxium_resolution = 4096
# Side length of the square token grid: sqrt(256) == 16.
token_length = int(pow(256, 0.5))
def exists(v):
    """Return True unless ``v`` is None (falsy values such as 0 or "" still count)."""
    return not (v is None)
# Factory for torchvision Resize transforms that always use bicubic interpolation
# with antialiasing; used as ``resize(size)(img)`` throughout this module.
resize = partial(
    transforms.Resize,
    interpolation=transforms.InterpolationMode.BICUBIC,
    antialias=True,
)
def resize_image(img, new_size, w, h):
    """Resize ``img`` so its longer side becomes ``new_size``, keeping the aspect ratio.

    Args:
        img: image accepted by the ``resize`` transform.
        new_size: target length of the longer side.
        w, h: current width and height of ``img``.

    Returns:
        The resized image. The size passed to ``resize`` is ``(height, width)``,
        as torchvision's ``Resize`` expects.
    """
    if w > h:
        target_hw = (int(h / w * new_size), new_size)
    else:
        target_hw = (new_size, int(w / h * new_size))
    return resize(target_hw)(img)
def pad_image(image: torch.Tensor, h, w):
    """Centre ``image`` on an ``h`` x ``w`` canvas filled with -1.

    Args:
        image: 4-D tensor indexed as (batch, channel, height, width).
        h, w: canvas height and width. They should be at least the image's own
            size; a smaller canvas yields negative offsets.

    Returns:
        ``(canvas, (left, top, width, height))``. The tuple locates the original
        image inside the canvas.
    """
    batch, channels, img_h, img_w = image.shape
    top = (h - img_h) // 2
    left = (w - img_w) // 2
    canvas = torch.full((batch, channels, h, w), -1.0, device=image.device)
    canvas[:, :, top:top + img_h, left:left + img_w] = image
    return canvas, (left, top, img_w, img_h)
def pad_image_with_margin(image: Image.Image, scale) -> Image.Image:
    """Centre ``image`` horizontally on a wider white RGB canvas.

    The canvas width is ``int(width * scale)`` and its height is unchanged.
    Callers in this module only use ``scale > 1``. With ``scale < 1`` the paste
    offset goes negative and the image is cropped on both sides.

    Args:
        image: PIL image to pad.
        scale: horizontal widening factor.

    Returns:
        A new RGB image with white margins left and right.
    """
    w, h = image.size
    nw = int(w * scale)
    bg = Image.new('RGB', (nw, h), (255, 255, 255))
    bg.paste(image, ((nw - w) // 2, 0))
    return bg
def crop_image_from_square(square_image, original_dim):
    """Cut the original content back out of a padded image.

    Args:
        square_image: object with a PIL-style ``crop(box)`` method.
        original_dim: ``(left, top, width, height)`` as returned by ``pad_image``.

    Returns:
        ``square_image.crop((left, top, right, bottom))``.
    """
    x0, y0, box_w, box_h = original_dim
    box = (x0, y0, x0 + box_w, y0 + box_h)
    return square_image.crop(box)
def to_tensor(x, inverse=False):
    """Convert an image to a normalised, batched CUDA tensor.

    ``ToTensor`` produces a (C, H, W) tensor. A batch axis is added in front, and
    ``Normalize`` with mean/std 0.5 maps [0, 1] to [-1, 1]. With ``inverse=True``
    the result is negated.
    """
    half = (0.5, 0.5, 0.5)
    batch = transforms.ToTensor()(x)[None]
    batch = transforms.Normalize(half, half)(batch).cuda()
    return -batch if inverse else batch
def to_numpy(x, denormalize=True):
    """Convert a tensor to uint8 NumPy data for display.

    Args:
        x: 4-D tensor indexed as (batch, channel, height, width).
        denormalize: if True, ``x`` is clamped to [-1, 1], mapped to [0, 255] and
            returned as (batch, height, width, channel). If False, ``x`` is clamped
            to [0, 1], scaled to [0, 255], and only the first batch item's first
            channel is returned, as (height, width).

    Returns:
        A ``np.uint8`` array. The cast truncates rather than rounds.
    """
    if denormalize:
        y = ((x.clamp(-1, 1) + 1.) * 127.5).permute(0, 2, 3, 1)
    else:
        y = (x.clamp(0, 1) * 255)[0][0]
    # Tensor.numpy() refuses tensors that require grad, so detach first.
    y = y.detach().cpu()
    if y.dtype == torch.bfloat16:
        # NumPy has no bfloat16 dtype. Upcast only here, after the arithmetic,
        # so fp16/fp32 inputs produce exactly the same bytes as before.
        y = y.float()
    return y.numpy().astype(np.uint8)
def lineart_standard(x: Image.Image):
    """Extract line art by comparing each pixel with its Gaussian-blurred surroundings.

    A pixel is treated as stroke when it is darker than the blur (sigma 6) in every
    channel. The per-pixel darkness is normalised by the median strength of
    pixels above 8 (with a floor of 16), scaled by 127, clipped to [0, 255], and
    repeated to 3 channels.

    Args:
        x: input image. PIL images in a mode other than RGB are converted first.

    Returns:
        The line-art map as produced by ``to_tensor`` (normalised to [-1, 1], on CUDA).
    """
    if isinstance(x, Image.Image) and x.mode != "RGB":
        # Grayscale/palette images have no channel axis, so np.min(axis=2) would
        # fail. An opaque alpha channel gives g - x == 0, which after the min and
        # clip would zero every pixel.
        x = x.convert("RGB")
    x = np.array(x).astype(np.float32)
    g = cv2.GaussianBlur(x, (0, 0), 6.0)
    intensity = np.min(g - x, axis=2).clip(0, 255)
    strong = intensity[intensity > 8]
    # np.median of an empty array warns and returns NaN; fall back to the floor of 16.
    scale = max(16, np.median(strong)) if strong.size else 16
    intensity /= scale
    intensity *= 127
    intensity = np.repeat(np.expand_dims(intensity, 2), 3, axis=2)
    return to_tensor(intensity.clip(0, 255).astype(np.uint8))
def preprocess_sketch(sketch, resolution, preprocess="none", extractor=None):
    """Turn a PIL sketch into a padded tensor that fits inside ``resolution``.

    Preprocessing modes:
        "none": ``to_tensor(sketch)``.
        "invert": ``to_tensor(sketch, inverse=True)``.
        "invert-webui": ``lineart_standard(sketch)``.
        anything else: ``extractor.proceed`` on a 768x768 resize of the sketch,
            with the channel axis repeated to 3. This presumably yields a
            single-channel line map; TODO confirm.

    The tensor is then resized to fit ``resolution = (height, width)`` with the
    aspect ratio kept, and centred by ``pad_image``.

    Returns:
        ``(sketch_tensor, (left, top, width, height), -sketch_tensor)``.
    """
    src_w, src_h = sketch.size
    target_h, target_w = resolution
    ratio = min(target_h / src_h, target_w / src_w)

    if preprocess in ("none", "invert"):
        tensor = to_tensor(sketch, inverse=(preprocess == "invert"))
    elif preprocess == "invert-webui":
        tensor = lineart_standard(sketch)
    else:
        tensor = extractor.proceed(resize((768, 768))(sketch)).repeat(1, 3, 1, 1)

    scaled = resize((int(src_h * ratio), int(src_w * ratio)))(tensor)
    padded, original_shape = pad_image(scaled, target_h, target_w)
    return padded, original_shape, -padded
@torch.no_grad()
def preprocessing_inputs(
    sketch: Image.Image,
    reference: Image.Image,
    background: Image.Image,
    preprocess: str,
    hook: bool,
    resolution: tuple[int, int],
    extractor: nn.Module,
    pad_scale: float = 1.,
):
    """Convert the raw PIL inputs into normalised CUDA tensors for sampling.

    Args:
        sketch: line-art image, or None to start from a blank (-1 filled) canvas.
        reference: colour reference image, or None.
        background: background image, or None.
        preprocess: sketch preprocessing mode, forwarded to ``preprocess_sketch``.
        hook: when True, also build ``inject_xr`` from ``reference`` resized to
            ``resolution``. Requires ``reference`` and ``extractor``.
        resolution: target ``(height, width)``.
        extractor: line extractor forwarded to ``preprocess_sketch``. It is moved to
            the GPU for the sketch/hook stage and moved back to the CPU afterwards,
            even if that stage raises.
        pad_scale: if > 1, reference and background are widened with white margins
            by this factor before conversion.

    Returns:
        ``(sketch, reference, background, original_shape, inject_xr, inject_xs,
        white_sketch)``. ``original_shape`` is ``(left, top, width, height)`` of the
        sketch inside the canvas. ``inject_xs`` is currently always None.
    """
    h, w = resolution
    if exists(extractor):
        extractor = extractor.cuda()
    try:
        if exists(sketch):
            sketch, original_shape, white_sketch = preprocess_sketch(sketch, resolution, preprocess, extractor)
        else:
            sketch = -torch.ones([1, 3, h, w], device="cuda")
            white_sketch = None
            # Same (left, top, width, height) layout that pad_image returns;
            # a blank canvas covers the whole w x h frame.
            original_shape = (0, 0, w, h)

        # Extracting line art from the reference (inject_xs) is currently disabled.
        inject_xs = None
        if hook:
            assert exists(reference) and exists(extractor)
            inject_xr = to_tensor(resize((h, w))(reference))
        else:
            inject_xr = None
    finally:
        # Offload the extractor again so it does not hold GPU memory between calls.
        if exists(extractor):
            extractor.cpu()

    if exists(reference):
        if pad_scale > 1.:
            reference = pad_image_with_margin(reference, pad_scale)
        reference = to_tensor(reference)
    if exists(background):
        if pad_scale > 1.:
            background = pad_image_with_margin(background, pad_scale)
        background = to_tensor(background)
    return sketch, reference, background, original_shape, inject_xr, inject_xs, white_sketch
def _mask_to_image(mask):
    """Render the first batch item / first channel of a [0, 1] mask tensor as a PIL "L" image."""
    return Image.fromarray(to_numpy(mask, denormalize=False), mode="L")


def _split_by_mask(image, mask_image):
    """Return [image where the mask is bright, image where the mask is dark], each on white."""
    white = Image.new("RGB", image.size, (255, 255, 255))
    return [
        Image.composite(image, white, mask_image),
        Image.composite(white, image, mask_image),
    ]


def postprocess(results, sketch, reference, background, crop, original_shape,
                mask_guided, smask, rmask, bgmask, mask_ts, mask_ss):
    """Flatten sampler outputs and intermediate tensors into a list of displayable images.

    The returned list contains, in order:
        - each result as a PIL image, cropped to ``original_shape``
          (``(left, top, width, height)``) when ``crop`` is set;
        - the sketch (uint8 array);
        - the reference (uint8 array), if given;
        - the background (uint8 array), if given, followed by its two ``bgmask``
          composites when ``bgmask`` is given;
        - when ``mask_guided``: the sketch mask with values below ``mask_ss`` zeroed,
          and, if ``rmask`` is given, the reference mask with values below
          ``mask_ts`` zeroed plus its two composites of the reference.

    The mask tensors passed in are not modified.
    """
    outputs = []
    for result in to_numpy(results):
        image = Image.fromarray(result)
        if crop:
            image = crop_image_from_square(image, original_shape)
        outputs.append(image)
    outputs.append(to_numpy(sketch, True)[0])

    if exists(reference):
        reference = to_numpy(reference)[0]
        outputs.append(reference)

    if exists(background):
        background = to_numpy(background)[0]
        outputs.append(background)
        if exists(bgmask):
            outputs += _split_by_mask(Image.fromarray(background), _mask_to_image(bgmask))

    if mask_guided:
        # Threshold out of place so the caller's mask tensors are left untouched.
        smask = smask.masked_fill(smask < mask_ss, 0)
        outputs.append(_mask_to_image(smask))
        if exists(rmask):
            rmask = rmask.masked_fill(rmask < mask_ts, 0)
            rmask_image = _mask_to_image(rmask)
            outputs.append(rmask_image)
            outputs += _split_by_mask(Image.fromarray(reference), rmask_image)
    return outputs
def parse_prompts(
    prompts: str,
    target: str = None,
    anchor: str = None,
    control: str = None,
    target_scale: float = None,
    ts0: float = None,
    ts1: float = None,
    ts2: float = None,
    ts3: float = None,
    enhance: bool = None
):
    """Parse serialized manipulation prompts plus an optional extra one.

    Each non-blank line of ``prompts`` has the form
    ``[target] T; [anchor] A; [control] C; [scale]S; [enhanced]E; [ts0]a; [ts1]b; [ts2]c; [ts3]d``.
    The scale and thresholds are parsed as floats. The enhanced flag is true only
    when its text is "true", "1", "yes" or "on" (case-insensitive).

    If ``target`` is non-empty, the explicitly passed values are appended as one
    more entry, unchanged.

    Returns:
        A dict with lists under "targets", "anchors", "controls", "target_scales",
        "enhances" and "thresholds_list" (each threshold entry is ``[ts0..ts3]``).
    """
    def parse_flag(text):
        # bool("False") is True, so interpret the text instead of calling bool().
        return text.strip().lower() in ("true", "1", "yes", "on")

    separators = ["; [anchor] ", "; [control] ", "; [scale]", "; [enhanced]", "; [ts0]", "; [ts1]", "; [ts2]", "; [ts3]"]
    targets = []
    anchors = []
    controls = []
    scales = []
    enhances = []
    thresholds_list = []

    for line in (prompts or "").split("\n"):
        if not line.strip():
            # Blank lines (e.g. a trailing newline) carry no prompt.
            continue
        line = line.replace("[target] ", "")
        for sep in separators:
            line = line.replace(sep, "||||")
        fields = line.split("||||")
        targets.append(fields[0])
        anchors.append(fields[1])
        controls.append(fields[2])
        scales.append(float(fields[3]))
        enhances.append(parse_flag(fields[4]))
        thresholds_list.append([float(fields[5]), float(fields[6]), float(fields[7]), float(fields[8])])

    if target is not None and target != "":
        targets.append(target)
        anchors.append(anchor)
        controls.append(control)
        scales.append(target_scale)
        enhances.append(enhance)
        thresholds_list.append([ts0, ts1, ts2, ts3])

    return {
        "targets": targets,
        "anchors": anchors,
        "controls": controls,
        "target_scales": scales,
        "enhances": enhances,
        "thresholds_list": thresholds_list
    }
from refnet.sampling.manipulation import get_heatmaps
def visualize_heatmaps(model, reference, manipulation_params, control, ts0, ts1, ts2, ts3):
    """Overlay the manipulation heatmaps from ``get_heatmaps`` on the reference image.

    The reference is downscaled (bicubic, aspect ratio kept) so that neither side
    exceeds ``maxium_resolution``. It is then blended 30/70 with a JET-coloured
    heatmap and overlaid with black grid lines every ``1/token_length`` of the
    image height and width.

    Returns:
        ``[]`` when ``reference`` is None, otherwise a one-element list holding the
        blended image as a NumPy array.
    """
    if reference is None:
        return []

    width, height = reference.size
    if width > maxium_resolution or height > maxium_resolution:
        if width > height:
            width, height = maxium_resolution, int(float(maxium_resolution) / width * height)
        else:
            width, height = int(float(maxium_resolution) / height * width), maxium_resolution
        reference = reference.resize((width, height), Image.BICUBIC)
    reference = np.array(reference)

    scale_maps = get_heatmaps(model, to_tensor(reference), height, width,
                              control, ts0, ts1, ts2, ts3, **manipulation_params)
    # NOTE(review): if get_heatmaps returns uint8 maps, this sum wraps around on
    # overflow; confirm the four maps are disjoint or pre-scaled.
    scale_map = scale_maps[0] + scale_maps[1] + scale_maps[2] + scale_maps[3]
    heatmap = cv2.cvtColor(cv2.applyColorMap(scale_map, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    result = cv2.addWeighted(reference, 0.3, heatmap, 0.7, 0)

    # Draw token_length grid lines per axis. Use the same constant that defines the
    # cell size below instead of a hard-coded 16, so the two cannot drift apart.
    row_step = height // token_length
    col_step = width // token_length
    for i in range(token_length):
        result[i * row_step, :] = (0, 0, 0)
        result[:, i * col_step] = (0, 0, 0)
    return [result]