Spaces:
Sleeping
Sleeping
Upload 14 files
Browse files- auth.py +41 -0
- cldm.py +312 -0
- constants.py +5 -0
- entry_with_update 2.py +46 -0
- face_restoration_helper.py +374 -0
- inpaint_worker 2.py +264 -0
- inpaint_worker.py +264 -0
- launch_util.py +103 -0
- lora.py +152 -0
- model_loader.py +26 -0
- sdxl_styles.py +82 -0
- upscaler.py +34 -0
- util.py +177 -0
- webui.py +623 -0
auth.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import hashlib
|
| 3 |
+
import modules.constants as constants
|
| 4 |
+
|
| 5 |
+
from os.path import exists
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def auth_list_to_dict(auth_list):
    """Convert a list of credential records into a {user: sha256-hex-digest} dict.

    Each record may carry a precomputed 'hash' (used verbatim) or a plaintext
    'pass' (hashed here with SHA-256). Records lacking a 'user' key, or having
    neither 'hash' nor 'pass', are skipped. 'hash' wins when both are present.
    """
    credentials = {}
    for entry in auth_list:
        if 'user' not in entry:
            continue
        if 'hash' in entry:
            credentials[entry['user']] = entry['hash']
        elif 'pass' in entry:
            digest = hashlib.sha256(entry['pass'].encode('utf-8')).hexdigest()
            credentials[entry['user']] = digest
    return credentials
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def load_auth_data(filename=None):
    """Load a credential list from a JSON file.

    Args:
        filename: Path to a JSON file expected to contain a non-empty list of
            {'user': ..., 'hash'/'pass': ...} records.

    Returns:
        A {user: sha256-hex-digest} dict built by auth_list_to_dict, or None
        when the file is missing, unreadable, not valid JSON, or not a
        non-empty list.
    """
    auth_dict = None
    # PEP 8: compare to None with identity, not equality.
    if filename is not None and exists(filename):
        with open(filename, encoding='utf-8') as auth_file:
            try:
                auth_obj = json.load(auth_file)
                if isinstance(auth_obj, list) and len(auth_obj) > 0:
                    auth_dict = auth_list_to_dict(auth_obj)
            except Exception as e:
                # Best effort: a malformed auth file disables auth instead of
                # crashing at import time.
                print('load_auth_data, e: ' + str(e))
    return auth_dict
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Load credentials once at import time; authentication is enabled only when
# the auth file yielded a usable credential dict.
auth_dict = load_auth_data(constants.AUTH_FILENAME)

# PEP 8: identity comparison with None, not `!=`.
auth_enabled = auth_dict is not None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def check_auth(user, password):
    """Return True iff `user` is known and `password` hashes to the stored digest."""
    if user not in auth_dict:
        return False
    digest = hashlib.sha256(bytes(password, encoding='utf-8')).hexdigest()
    return digest == auth_dict[user]
|
cldm.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#taken from: https://github.com/lllyasviel/ControlNet
|
| 2 |
+
#and modified
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch as th
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
from ldm_patched.ldm.modules.diffusionmodules.util import (
|
| 9 |
+
zero_module,
|
| 10 |
+
timestep_embedding,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
from ldm_patched.ldm.modules.attention import SpatialTransformer
|
| 14 |
+
from ldm_patched.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
|
| 15 |
+
from ldm_patched.ldm.util import exists
|
| 16 |
+
import ldm_patched.modules.ops
|
| 17 |
+
|
| 18 |
+
class ControlledUnetModel(UNetModel):
    # Alias kept for config/checkpoint compatibility: the control-injection
    # hooks are implemented in the patched ldm UNet itself, so no overrides
    # are needed here.
    pass
|
| 21 |
+
|
| 22 |
+
class ControlNet(nn.Module):
    """ControlNet conditioning network (ControlNet-style encoder).

    Mirrors the encoder half of a UNet: an input-hint encoder, a stack of
    input blocks with one zero-initialized 1x1 conv per stage, and a middle
    block. `forward` returns one zero-convolved feature map per stage for the
    host UNet to add as control residuals.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        hint_channels,
        num_res_blocks,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        dtype=torch.float32,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        adm_in_channels=None,
        transformer_depth_middle=None,
        transformer_depth_output=None,
        device=None,
        operations=ldm_patched.modules.ops.disable_weight_init,
        **kwargs,
    ):
        super().__init__()
        # This port only supports the cross-attention (spatial transformer) path.
        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            # from omegaconf.listconfig import ListConfig
            # if type(context_dim) == ListConfig:
            #     context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        # Exactly one of num_heads / num_head_channels must be provided.
        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels

        # Normalize num_res_blocks to a per-level list matching channel_mult.
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks

        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))

        # Copy so pop(0) below does not mutate the caller's list.
        # NOTE(review): this assumes transformer_depth is passed as a per-block
        # list — the int default of 1 would fail the slice here. TODO confirm callers.
        transformer_depth = transformer_depth[:]

        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = dtype
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        # Timestep embedding MLP (sinusoidal embedding is computed in forward).
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
        )

        # Optional class/ADM conditioning embedding, selected by num_classes kind.
        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                        nn.SiLU(),
                        operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                    )
                )
            else:
                raise ValueError()

        # Encoder stem plus its zero-conv output tap.
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                )
            ]
        )
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])

        # Encodes the control hint image down to model_channels resolution
        # (three stride-2 convs => 8x spatial reduction).
        self.input_hint_block = TimestepEmbedSequential(
            operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
        )

        # Build encoder levels: ResBlocks (+ optional SpatialTransformer) per
        # level, with a matching zero conv per stage, then a downsample
        # between levels (except after the last).
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        dtype=self.dtype,
                        device=device,
                        operations=operations,
                    )
                ]
                ch = mult * model_channels
                # Per-block transformer depth; 0 disables attention for this block.
                num_transformers = transformer_depth.pop(0)
                if num_transformers > 0:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            SpatialTransformer(
                                ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                            dtype=self.dtype,
                            device=device,
                            operations=operations
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
                ds *= 2
                self._feature_size += ch

        # Middle block: ResBlock [+ SpatialTransformer + ResBlock], then its
        # own zero-conv output tap.
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        mid_block = [
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                device=device,
                operations=operations
            )]
        if transformer_depth_middle >= 0:
            mid_block += [SpatialTransformer(  # always uses a self-attn
                            ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                            disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                        ),
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            dtype=self.dtype,
                            device=device,
                            operations=operations
                        )]
        self.middle_block = TimestepEmbedSequential(*mid_block)
        self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
        self._feature_size += ch

    def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
        # 1x1 conv output tap; with the default disable_weight_init ops the
        # "zero" initialization is expected to come from the loaded checkpoint.
        return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))

    def forward(self, x, hint, timesteps, context, y=None, **kwargs):
        """Run the control encoder.

        Returns a list of control feature maps: one per input block (zero-conv
        outputs), plus one for the middle block, in order.
        """
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        # Encode the hint image; it is added to features of the first stage only.
        guided_hint = self.input_hint_block(hint, emb, context)

        outs = []

        # NOTE(review): hs is never used below; kept as-is (doc-only pass).
        hs = []
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None  # consumed once, at the first block
            else:
                h = module(h, emb, context)
            outs.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        outs.append(self.middle_block_out(h, emb, context))

        return outs
|
| 312 |
+
|
constants.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Inclusive range of RNG seeds accepted by the samplers,
# as in k-diffusion (sampling.py)
MIN_SEED = 0
MAX_SEED = 2**63 - 1

# JSON credentials file loaded at import time by modules/auth.py.
AUTH_FILENAME = 'auth.json'
|
entry_with_update 2.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys


# Run from the repository root so relative imports and asset paths resolve.
root = os.path.dirname(os.path.abspath(__file__))
sys.path.append(root)
os.chdir(root)


# Best-effort self-update via git fast-forward; any failure falls through to
# launching the currently checked-out version.
try:
    import pygit2
    # Allow operating on a repo owned by a different user (e.g. in containers).
    pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)

    repo = pygit2.Repository(os.path.abspath(os.path.dirname(__file__)))

    branch_name = repo.head.shorthand

    remote_name = 'origin'
    remote = repo.remotes[remote_name]

    remote.fetch()

    local_branch_ref = f'refs/heads/{branch_name}'
    local_branch = repo.lookup_reference(local_branch_ref)

    remote_reference = f'refs/remotes/{remote_name}/{branch_name}'
    remote_commit = repo.revparse_single(remote_reference)

    merge_result, _ = repo.merge_analysis(remote_commit.id)

    if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
        print("Already up-to-date")
    elif merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
        # Fast-forward: move the local branch and working tree to the remote tip.
        local_branch.set_target(remote_commit.id)
        repo.head.set_target(remote_commit.id)
        repo.checkout_tree(repo.get(remote_commit.id))
        repo.reset(local_branch.target, pygit2.GIT_RESET_HARD)
        print("Fast-forward merge")
    elif merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
        print("Update failed - Did you modify any file?")
    # BUGFIX: this was previously outside the try/except and printed
    # "Update succeeded." unconditionally, even right after "Update failed."
    print('Update succeeded.')
except Exception as e:
    print('Update failed.')
    print(str(e))

from launch import *
|
face_restoration_helper.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
from torchvision.transforms.functional import normalize
|
| 6 |
+
|
| 7 |
+
from extras.facexlib.detection import init_detection_model
|
| 8 |
+
from extras.facexlib.parsing import init_parsing_model
|
| 9 |
+
from extras.facexlib.utils.misc import img2tensor, imwrite
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_largest_face(det_faces, h, w):
    """Pick the detection with the largest area after clamping to the image.

    Args:
        det_faces: sequence of boxes, each indexable as (x0, y0, x1, y1, ...).
        h, w: image height and width used to clamp box coordinates.

    Returns:
        (box, index) of the largest clamped box.
    """

    def clamp(val, limit):
        # Restrict a coordinate to [0, limit].
        return min(max(val, 0), limit)

    face_areas = [
        (clamp(box[2], w) - clamp(box[0], w)) * (clamp(box[3], h) - clamp(box[1], h))
        for box in det_faces
    ]
    largest_idx = face_areas.index(max(face_areas))
    return det_faces[largest_idx], largest_idx
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_center_face(det_faces, h=0, w=0, center=None):
    """Pick the detection whose center is closest to a reference point.

    Args:
        det_faces: sequence of boxes, each indexable as (x0, y0, x1, y1, ...).
        h, w: image size used to derive the default reference (image center).
        center: optional explicit (x, y) reference point.

    Returns:
        (box, index) of the box nearest to the reference point.
    """
    reference = np.array(center) if center is not None else np.array([w / 2, h / 2])
    center_dist = [
        np.linalg.norm(
            np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2]) - reference)
        for box in det_faces
    ]
    center_idx = center_dist.index(min(center_dist))
    return det_faces[center_idx], center_idx
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class FaceRestoreHelper(object):
|
| 49 |
+
"""Helper for the face restoration pipeline (base class)."""
|
| 50 |
+
|
| 51 |
+
    def __init__(self,
                 upscale_factor,
                 face_size=512,
                 crop_ratio=(1, 1),
                 det_model='retinaface_resnet50',
                 save_ext='png',
                 template_3points=False,
                 pad_blur=False,
                 use_parse=False,
                 device=None,
                 model_rootpath=None):
        """Set up templates, state lists, and the detection/parsing models.

        Args:
            upscale_factor: factor applied when mapping crops back to the image.
            face_size: side length (px) of the square aligned face crop.
            crop_ratio: (h, w) multipliers (>= 1) applied to face_size.
            det_model: name passed to init_detection_model.
            save_ext: extension used when saving cropped faces.
            template_3points: use a 3-landmark template instead of 5.
            pad_blur: enable blurred-padding preprocessing (forces 5 landmarks).
            use_parse: stored flag for parsing-based blending (used elsewhere).
            device: torch device; auto-selects cuda/cpu when None.
            model_rootpath: weights directory for the facexlib models.
        """
        self.template_3points = template_3points  # improve robustness
        self.upscale_factor = upscale_factor
        # the cropped face ratio based on the square face
        self.crop_ratio = crop_ratio  # (h, w)
        assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
        # (w, h) in pixels after applying the crop ratio.
        self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))

        if self.template_3points:
            self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
        else:
            # standard 5 landmarks for FFHQ faces with 512 x 512
            self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
                                           [201.26117, 371.41043], [313.08905, 371.15118]])
        # Scale template from its 512px reference to the requested face_size,
        # then center it inside the enlarged crop when crop_ratio > 1.
        self.face_template = self.face_template * (face_size / 512.0)
        if self.crop_ratio[0] > 1:
            self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
        if self.crop_ratio[1] > 1:
            self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
        self.save_ext = save_ext
        self.pad_blur = pad_blur
        if self.pad_blur is True:
            # pad_blur path relies on the 5-point template geometry.
            self.template_3points = False

        # Per-image accumulators; callers are expected to reset between images.
        self.all_landmarks_5 = []
        self.det_faces = []
        self.affine_matrices = []
        self.inverse_affine_matrices = []
        self.cropped_faces = []
        self.restored_faces = []
        self.pad_input_imgs = []

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        # init face detection model
        self.face_det = init_detection_model(det_model, half=False, device=self.device, model_rootpath=model_rootpath)

        # init face parsing model (loaded unconditionally, even if use_parse is False)
        self.use_parse = use_parse
        self.face_parse = init_parsing_model(model_name='parsenet', device=self.device, model_rootpath=model_rootpath)
|
| 104 |
+
|
| 105 |
+
    def set_upscale_factor(self, upscale_factor):
        """Override the upscale factor used when pasting faces back."""
        self.upscale_factor = upscale_factor
|
| 107 |
+
|
| 108 |
+
    def read_image(self, img):
        """Store the input image as a 3-channel BGR array in self.input_img.

        Args:
            img: image path (read with cv2.imread) or an already-loaded
                cv2/numpy image.
        """
        # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
        if isinstance(img, str):
            img = cv2.imread(img)

        # Values above 256 are taken to mean a 16-bit image; rescale to [0, 255].
        # NOTE(review): this rescale yields a float array, not uint8 — confirm
        # downstream consumers accept float input in that case.
        if np.max(img) > 256:  # 16-bit image
            img = img / 65535 * 255
        if len(img.shape) == 2:  # gray image
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img = img[:, :, 0:3]

        self.input_img = img
|
| 122 |
+
|
| 123 |
+
    def get_face_landmarks_5(self,
                             only_keep_largest=False,
                             only_center_face=False,
                             resize=None,
                             blur_ratio=0.01,
                             eye_dist_threshold=None):
        """Detect faces in self.input_img and collect 5-point (or 3-point) landmarks.

        Populates self.det_faces and self.all_landmarks_5, and — when
        self.pad_blur is set — self.pad_input_imgs with blur-padded copies of
        the input per detected face.

        Args:
            only_keep_largest: keep only the largest detected face.
            only_center_face: keep only the face nearest the image center
                (ignored when only_keep_largest is set).
            resize: if given, detection runs on a copy whose shorter side is
                scaled to this size; boxes are mapped back to original coords.
            blur_ratio: blur kernel size relative to the crop size (pad_blur path).
            eye_dist_threshold: drop detections whose inter-eye distance is
                below this value (side/tiny faces).

        Returns:
            Number of faces kept (0 when nothing was detected).
        """
        if resize is None:
            scale = 1
            input_img = self.input_img
        else:
            h, w = self.input_img.shape[0:2]
            scale = min(h, w) / resize
            h, w = int(h / scale), int(w / scale)
            input_img = cv2.resize(self.input_img, (w, h), interpolation=cv2.INTER_LANCZOS4)

        with torch.no_grad():
            # bbox layout per row: x0, y0, x1, y1, score, then landmark x/y pairs.
            bboxes = self.face_det.detect_faces(input_img, 0.97) * scale
        for bbox in bboxes:
            # remove faces with too small eye distance: side faces or too small faces
            eye_dist = np.linalg.norm([bbox[5] - bbox[7], bbox[6] - bbox[8]])
            if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
                continue

            if self.template_3points:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
            else:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
            self.all_landmarks_5.append(landmark)
            self.det_faces.append(bbox[0:5])
        if len(self.det_faces) == 0:
            return 0
        if only_keep_largest:
            h, w, _ = self.input_img.shape
            self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
        elif only_center_face:
            h, w, _ = self.input_img.shape
            self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]

        # pad blurry images: FFHQ-style oriented crop rectangle per face,
        # reflect-padded and feathered with a blurred border mask.
        if self.pad_blur:
            self.pad_input_imgs = []
            for landmarks in self.all_landmarks_5:
                # get landmarks
                eye_left = landmarks[0, :]
                eye_right = landmarks[1, :]
                eye_avg = (eye_left + eye_right) * 0.5
                mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
                eye_to_eye = eye_right - eye_left
                eye_to_mouth = mouth_avg - eye_avg

                # Get the oriented crop rectangle
                # x: half width of the oriented crop rectangle
                x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
                # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
                # norm with the hypotenuse: get the direction
                x /= np.hypot(*x)  # get the hypotenuse of a right triangle
                rect_scale = 1.5
                x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
                # y: half height of the oriented crop rectangle
                y = np.flipud(x) * [-1, 1]

                # c: center
                c = eye_avg + eye_to_mouth * 0.1
                # quad: (left_top, left_bottom, right_bottom, right_top)
                quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
                # qsize: side length of the square
                qsize = np.hypot(*x) * 2
                border = max(int(np.rint(qsize * 0.1)), 3)

                # get pad
                # pad: (width_left, height_top, width_right, height_bottom)
                pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
                       int(np.ceil(max(quad[:, 1]))))
                # NOTE(review): shape[0] is height and shape[1] is width, but they
                # are subtracted from the x- and y-extent respectively here —
                # looks swapped for non-square images; verify against upstream.
                pad = [
                    max(-pad[0] + border, 1),
                    max(-pad[1] + border, 1),
                    max(pad[2] - self.input_img.shape[0] + border, 1),
                    max(pad[3] - self.input_img.shape[1] + border, 1)
                ]

                if max(pad) > 1:
                    # pad image
                    pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
                    # modify landmark coords (in place, also updating all_landmarks_5)
                    landmarks[:, 0] += pad[0]
                    landmarks[:, 1] += pad[1]
                    # blur pad images
                    h, w, _ = pad_img.shape
                    y, x, _ = np.ogrid[:h, :w, :1]
                    # Mask ramps from 0 in the interior to 1 at the padded border.
                    mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
                                                       np.float32(w - 1 - x) / pad[2]),
                                      1.0 - np.minimum(np.float32(y) / pad[1],
                                                       np.float32(h - 1 - y) / pad[3]))
                    blur = int(qsize * blur_ratio)
                    if blur % 2 == 0:
                        blur += 1  # box filter kernel must be odd
                    blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
                    # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)

                    pad_img = pad_img.astype('float32')
                    # Blend toward the blurred image near the border, then toward
                    # the per-channel median at the very edge.
                    pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
                    pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
                    pad_img = np.clip(pad_img, 0, 255)  # float32, [0, 255]
                    self.pad_input_imgs.append(pad_img)
                else:
                    self.pad_input_imgs.append(np.copy(self.input_img))

        return len(self.all_landmarks_5)
|
| 233 |
+
|
| 234 |
+
    def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
        """Align and warp detected faces onto the canonical face template.

        For each set of 5 landmarks in ``self.all_landmarks_5`` this estimates a
        similarity transform to ``self.face_template``, warps the (optionally
        pad-blurred) input to ``self.face_size``, and appends the result to
        ``self.cropped_faces`` and the transform to ``self.affine_matrices``.

        Args:
            save_cropped_path: optional path stem; each crop is saved as
                ``<stem>_<idx>.<save_ext>`` when given.
            border_mode: one of 'constant', 'reflect101', 'reflect'
                (mapped to the corresponding cv2 border flag).
        """
        if self.pad_blur:
            assert len(self.pad_input_imgs) == len(
                self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
        for idx, landmark in enumerate(self.all_landmarks_5):
            # use 5 landmarks to get affine matrix
            # use cv2.LMEDS method for the equivalence to skimage transform
            # ref: https://blog.csdn.net/yichxi/article/details/115827338
            affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
            self.affine_matrices.append(affine_matrix)
            # warp and crop faces
            # NOTE(review): border_mode is rebound from the string to the cv2 int
            # constant on the first iteration; later iterations fall through the
            # chain and keep that constant — harmless, but worth hoisting out of
            # the loop in a future change.
            if border_mode == 'constant':
                border_mode = cv2.BORDER_CONSTANT
            elif border_mode == 'reflect101':
                border_mode = cv2.BORDER_REFLECT101
            elif border_mode == 'reflect':
                border_mode = cv2.BORDER_REFLECT
            if self.pad_blur:
                input_img = self.pad_input_imgs[idx]
            else:
                input_img = self.input_img
            cropped_face = cv2.warpAffine(
                input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132))  # gray
            self.cropped_faces.append(cropped_face)
            # save the cropped face
            if save_cropped_path is not None:
                path = os.path.splitext(save_cropped_path)[0]
                save_path = f'{path}_{idx:02d}.{self.save_ext}'
                imwrite(cropped_face, save_path)
|
| 265 |
+
|
| 266 |
+
    def get_inverse_affine(self, save_inverse_affine_path=None):
        """Compute and store the inverse of each face-alignment affine matrix.

        Each inverse is scaled by ``self.upscale_factor`` so it maps the
        restored face back into upscaled-image coordinates, then appended to
        ``self.inverse_affine_matrices``. When a path stem is given, each
        matrix is also serialized with ``torch.save`` as ``<stem>_<idx>.pth``.
        """
        for idx, affine_matrix in enumerate(self.affine_matrices):
            inverse_affine = cv2.invertAffineTransform(affine_matrix)
            inverse_affine *= self.upscale_factor
            self.inverse_affine_matrices.append(inverse_affine)
            # save inverse affine matrices
            if save_inverse_affine_path is not None:
                path, _ = os.path.splitext(save_inverse_affine_path)
                save_path = f'{path}_{idx:02d}.pth'
                torch.save(inverse_affine, save_path)
|
| 277 |
+
|
| 278 |
+
def add_restored_face(self, face):
|
| 279 |
+
self.restored_faces.append(face)
|
| 280 |
+
|
| 281 |
+
    def paste_faces_to_input_image(self, save_path=None, upsample_img=None):
        """Blend every restored face back into the (upscaled) input image.

        The background is either the input image resized by
        ``self.upscale_factor`` or the supplied ``upsample_img`` (resized to
        match). Each restored face is warped back with its stored inverse
        affine matrix and blended with a soft mask — either a parsed face
        mask (``self.use_parse``) or an eroded/blurred square mask.

        Args:
            save_path: optional path stem; the composite is written as
                ``<stem>.<save_ext>`` when given.
            upsample_img: optional pre-upscaled background to paste onto.

        Returns:
            The composited image as uint8 (or uint16 when pixel values
            exceed 256, i.e. a 16-bit source).
        """
        h, w, _ = self.input_img.shape
        h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)

        if upsample_img is None:
            # simply resize the background
            upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
        else:
            upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)

        assert len(self.restored_faces) == len(
            self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
        for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
            # Add an offset to inverse affine matrix, for more precise back alignment
            if self.upscale_factor > 1:
                extra_offset = 0.5 * self.upscale_factor
            else:
                extra_offset = 0
            # NOTE(review): this mutates the stored matrix in place, so calling
            # this method twice would apply the offset twice — confirm callers
            # only paste once per restoration pass.
            inverse_affine[:, 2] += extra_offset
            inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))

            if self.use_parse:
                # inference
                face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
                face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
                normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                face_input = torch.unsqueeze(face_input, 0).to(self.device)
                with torch.no_grad():
                    out = self.face_parse(face_input)[0]
                    out = out.argmax(dim=1).squeeze().cpu().numpy()

                # Per-class keep/discard map; presumably indexed by face-parsing
                # class labels (background/skin/eyes/... ) — TODO confirm against
                # the parsing network's label set.
                mask = np.zeros(out.shape)
                MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
                for idx, color in enumerate(MASK_COLORMAP):
                    mask[out == idx] = color
                # blur the mask
                mask = cv2.GaussianBlur(mask, (101, 101), 11)
                mask = cv2.GaussianBlur(mask, (101, 101), 11)
                # remove the black borders
                thres = 10
                mask[:thres, :] = 0
                mask[-thres:, :] = 0
                mask[:, :thres] = 0
                mask[:, -thres:] = 0
                mask = mask / 255.

                # NOTE(review): cv2.resize takes (width, height); shape[:2] is
                # (height, width) — equivalent only for square crops. Verify
                # restored_face is always square here.
                mask = cv2.resize(mask, restored_face.shape[:2])
                mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up), flags=3)
                inv_soft_mask = mask[:, :, None]
                pasted_face = inv_restored

            else:  # use square parse maps
                mask = np.ones(self.face_size, dtype=np.float32)
                inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
                # remove the black borders
                inv_mask_erosion = cv2.erode(
                    inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
                pasted_face = inv_mask_erosion[:, :, None] * inv_restored
                total_face_area = np.sum(inv_mask_erosion)  # // 3
                # compute the fusion edge based on the area of face
                w_edge = int(total_face_area**0.5) // 20
                erosion_radius = w_edge * 2
                inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
                blur_size = w_edge * 2
                inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
                if len(upsample_img.shape) == 2:  # upsample_img is gray image
                    upsample_img = upsample_img[:, :, None]
                inv_soft_mask = inv_soft_mask[:, :, None]

            if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:  # alpha channel
                alpha = upsample_img[:, :, 3:]
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
                upsample_img = np.concatenate((upsample_img, alpha), axis=2)
            else:
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img

        if np.max(upsample_img) > 256:  # 16-bit image
            upsample_img = upsample_img.astype(np.uint16)
        else:
            upsample_img = upsample_img.astype(np.uint8)
        if save_path is not None:
            path = os.path.splitext(save_path)[0]
            save_path = f'{path}.{self.save_ext}'
            imwrite(upsample_img, save_path)
        return upsample_img
|
| 366 |
+
|
| 367 |
+
def clean_all(self):
|
| 368 |
+
self.all_landmarks_5 = []
|
| 369 |
+
self.restored_faces = []
|
| 370 |
+
self.affine_matrices = []
|
| 371 |
+
self.cropped_faces = []
|
| 372 |
+
self.inverse_affine_matrices = []
|
| 373 |
+
self.det_faces = []
|
| 374 |
+
self.pad_input_imgs = []
|
inpaint_worker 2.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from PIL import Image, ImageFilter
|
| 5 |
+
from modules.util import resample_image, set_image_shape_ceil, get_image_shape_ceil
|
| 6 |
+
from modules.upscaler import perform_upscale
|
| 7 |
+
import cv2
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
inpaint_head_model = None
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class InpaintHead(torch.nn.Module):
    """Tiny convolutional head that projects a 5-channel latent+mask feed
    into the 320-channel space of the diffusion UNet's first input block.

    The weight is allocated uninitialized with ``torch.empty``; it is only
    meaningful after ``load_state_dict`` loads a checkpoint (done lazily by
    ``InpaintWorker.patch``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # (out_channels=320, in_channels=5, kH=3, kW=3) conv weight,
        # filled later from the inpaint-head checkpoint.
        self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device='cpu'))

    def __call__(self, x):
        # Replicate-pad by 1 so the 3x3 valid conv preserves spatial size.
        x = torch.nn.functional.pad(x, (1, 1, 1, 1), "replicate")
        return torch.nn.functional.conv2d(input=x, weight=self.head)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
current_task = None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def box_blur(x, k):
    """Box-blur a numpy image with radius `k` and return it as a numpy array."""
    blurred = Image.fromarray(x).filter(ImageFilter.BoxBlur(k))
    return np.array(blurred)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def max_filter_opencv(x, ksize=3):
    """Moving-maximum (grayscale dilation) with a ksize x ksize square kernel.

    Input is expected as int16 so the caller's offset arithmetic cannot
    overflow (see morphological_open).
    """
    kernel = np.ones((ksize, ksize), dtype=np.int16)
    return cv2.dilate(x, kernel)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def morphological_open(x):
    """Feather a hard 0..255 mask into a soft uint8 mask with decaying edges.

    Pixels above 127 are set to 256 (int16 headroom), then 32 rounds of
    3x3 max-filter-minus-8 grow a linear ramp (~8 gray levels per pixel)
    outward from the mask; the result is clipped back to uint8.
    """
    # Convert array to int16 type via threshold operation
    x_int16 = np.zeros_like(x, dtype=np.int16)
    x_int16[x > 127] = 256

    for i in range(32):
        # Use int16 type to avoid overflow
        maxed = max_filter_opencv(x_int16, ksize=3) - 8
        x_int16 = np.maximum(maxed, x_int16)

    # Clip negative values to 0 and convert back to uint8 type
    x_uint8 = np.clip(x_int16, 0, 255).astype(np.uint8)
    return x_uint8
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def up255(x, t=0):
    """Binarize an array: values strictly above threshold `t` become 255, all others 0 (uint8)."""
    hot = x > t
    out = np.zeros(x.shape, dtype=np.uint8)
    out[hot] = 255
    return out
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def imsave(x, path):
    """Write a numpy image array to `path` via PIL (format inferred from extension)."""
    Image.fromarray(x).save(path)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def regulate_abcd(x, a, b, c, d):
    """Clamp box rows (a, b) into [0, H] and columns (c, d) into [0, W] for array `x`.

    Returns the clamped coordinates as plain ints.
    """
    H, W = x.shape[:2]
    a = min(max(a, 0), H)
    b = min(max(b, 0), H)
    c = min(max(c, 0), W)
    d = min(max(d, 0), W)
    return int(a), int(b), int(c), int(d)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def compute_initial_abcd(x):
    """Return a roughly square box (a, b, c, d) around the nonzero region of `x`.

    The box is centered on the tight bounding box, expanded ~15% beyond the
    larger half-extent, then clamped to the array bounds via regulate_abcd.
    Assumes `x` has at least one truthy element.
    """
    rows, cols = np.where(x)[:2]
    a, b = np.min(rows), np.max(rows)
    c, d = np.min(cols), np.max(cols)
    # centers and half-extents of the tight bounding box
    row_mid, row_half = (b + a) // 2, (b - a) // 2
    col_mid, col_half = (d + c) // 2, (d - c) // 2
    half = int(max(row_half, col_half) * 1.15)
    box = (row_mid - half, row_mid + half + 1, col_mid - half, col_mid + half + 1)
    return regulate_abcd(x, *box)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def solve_abcd(x, a, b, c, d, k):
    """Grow box (a, b, c, d) until both sides cover at least fraction `k` of `x`.

    The box is expanded symmetrically one pixel per step on its shorter axis
    (and forced onto the other axis once a side already spans the full image),
    re-clamped each iteration. k == 1.0 short-circuits to the whole image.
    """
    k = float(k)
    assert 0.0 <= k <= 1.0

    H, W = x.shape[:2]
    if k == 1.0:
        return 0, H, 0, W
    while True:
        # stop once both dimensions reach the requested coverage
        if b - a >= H * k and d - c >= W * k:
            break

        # by default grow only the currently shorter side
        add_h = (b - a) < (d - c)
        add_w = not add_h

        # if a side already spans the image, growing must continue on the other
        if b - a == H:
            add_w = True

        if d - c == W:
            add_h = True

        if add_h:
            a -= 1
            b += 1

        if add_w:
            c -= 1
            d += 1

        # keep the box inside the image after each expansion
        a, b, c, d = regulate_abcd(x, a, b, c, d)
    return a, b, c, d
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def fooocus_fill(image, mask):
    """Fill the masked (mask >= 127) region by diffusing surrounding colors.

    Repeated box blurs with shrinking radii (coarse 512 down to fine 3) smear
    color into the hole, while unmasked pixels are restored to their original
    values after every pass so only the hole is ever modified.
    """
    current_image = image.copy()
    raw_image = image.copy()
    # pixels outside the inpaint region — pinned to their original values
    area = np.where(mask < 127)
    store = raw_image[area]

    for k, repeats in [(512, 2), (256, 2), (128, 4), (64, 4), (33, 8), (15, 8), (5, 16), (3, 16)]:
        for _ in range(repeats):
            current_image = box_blur(current_image, k)
            current_image[area] = store

    return current_image
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class InpaintWorker:
    """Prepare an image + mask pair for diffusion inpainting and paste results back.

    On construction it crops a square "interested" region around the mask,
    upscales/resizes it for diffusion (ceil 1024), builds a pre-filled version
    of the hole, and feathers the full-resolution mask for final blending.
    Assumes `image` is an HxWx3 uint8 array and `mask` an HxW 0..255 array —
    TODO confirm against callers.
    """

    def __init__(self, image, mask, use_fill=True, k=0.618):
        # square region covering the mask, expanded to at least fraction k of the image
        a, b, c, d = compute_initial_abcd(mask > 0)
        a, b, c, d = solve_abcd(mask, a, b, c, d, k=k)

        # interested area
        self.interested_area = (a, b, c, d)
        self.interested_mask = mask[a:b, c:d]
        self.interested_image = image[a:b, c:d]

        # super resolution
        if get_image_shape_ceil(self.interested_image) < 1024:
            self.interested_image = perform_upscale(self.interested_image)

        # resize to make images ready for diffusion
        self.interested_image = set_image_shape_ceil(self.interested_image, 1024)
        self.interested_fill = self.interested_image.copy()
        H, W, C = self.interested_image.shape

        # process mask: resample to the diffusion resolution, then re-binarize
        self.interested_mask = up255(resample_image(self.interested_mask, W, H), t=127)

        # compute filling
        if use_fill:
            self.interested_fill = fooocus_fill(self.interested_image, self.interested_mask)

        # soft pixels: feathered full-resolution mask used for final blending
        self.mask = morphological_open(mask)
        self.image = image

        # ending: latent slots populated later via load_latent()
        self.latent = None
        self.latent_after_swap = None
        self.swapped = False
        self.latent_mask = None
        self.inpaint_head_feature = None
        return

    def load_latent(self, latent_fill, latent_mask, latent_swap=None):
        """Store the encoded latents (fill, mask, optional swap counterpart)."""
        self.latent = latent_fill
        self.latent_mask = latent_mask
        self.latent_after_swap = latent_swap
        return

    def patch(self, inpaint_head_model_path, inpaint_latent, inpaint_latent_mask, model):
        """Return a clone of `model` whose first UNet input block adds the inpaint-head feature.

        Lazily loads the shared InpaintHead weights from
        `inpaint_head_model_path` on first use (cached in the module global).
        """
        global inpaint_head_model

        if inpaint_head_model is None:
            inpaint_head_model = InpaintHead()
            sd = torch.load(inpaint_head_model_path, map_location='cpu')
            inpaint_head_model.load_state_dict(sd)

        # 5-channel feed: 1 mask channel + 4 processed latent channels
        feed = torch.cat([
            inpaint_latent_mask,
            model.model.process_latent_in(inpaint_latent)
        ], dim=1)

        inpaint_head_model.to(device=feed.device, dtype=feed.dtype)
        inpaint_head_feature = inpaint_head_model(feed)

        def input_block_patch(h, transformer_options):
            # inject only at the very first input block
            if transformer_options["block"][1] == 0:
                h = h + inpaint_head_feature.to(h)
            return h

        m = model.clone()
        m.set_model_input_block_patch(input_block_patch)
        return m

    def swap(self):
        """Swap to the alternate latent (no-op unless both latents are loaded)."""
        if self.swapped:
            return

        if self.latent is None:
            return

        if self.latent_after_swap is None:
            return

        self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
        self.swapped = True
        return

    def unswap(self):
        """Undo a previous swap() (no-op unless currently swapped)."""
        if not self.swapped:
            return

        if self.latent is None:
            return

        if self.latent_after_swap is None:
            return

        self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
        self.swapped = False
        return

    def color_correction(self, img):
        """Blend `img` over the original image using the feathered mask as alpha."""
        fg = img.astype(np.float32)
        bg = self.image.copy().astype(np.float32)
        w = self.mask[:, :, None].astype(np.float32) / 255.0
        y = fg * w + bg * (1 - w)
        return y.clip(0, 255).astype(np.uint8)

    def post_process(self, img):
        """Resample the diffusion output back into the interested area and blend."""
        a, b, c, d = self.interested_area
        content = resample_image(img, d - c, b - a)
        result = self.image.copy()
        result[a:b, c:d] = content
        result = self.color_correction(result)
        return result

    def visualize_mask_processing(self):
        """Return [fill, mask, image] crops for debugging/visualization."""
        return [self.interested_fill, self.interested_mask, self.interested_image]
|
| 264 |
+
|
inpaint_worker.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from PIL import Image, ImageFilter
|
| 5 |
+
from modules.util import resample_image, set_image_shape_ceil, get_image_shape_ceil
|
| 6 |
+
from modules.upscaler import perform_upscale
|
| 7 |
+
import cv2
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
inpaint_head_model = None
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class InpaintHead(torch.nn.Module):
|
| 14 |
+
def __init__(self, *args, **kwargs):
|
| 15 |
+
super().__init__(*args, **kwargs)
|
| 16 |
+
self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device='cpu'))
|
| 17 |
+
|
| 18 |
+
def __call__(self, x):
|
| 19 |
+
x = torch.nn.functional.pad(x, (1, 1, 1, 1), "replicate")
|
| 20 |
+
return torch.nn.functional.conv2d(input=x, weight=self.head)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
current_task = None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def box_blur(x, k):
|
| 27 |
+
x = Image.fromarray(x)
|
| 28 |
+
x = x.filter(ImageFilter.BoxBlur(k))
|
| 29 |
+
return np.array(x)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def max_filter_opencv(x, ksize=3):
|
| 33 |
+
# Use OpenCV maximum filter
|
| 34 |
+
# Make sure the input type is int16
|
| 35 |
+
return cv2.dilate(x, np.ones((ksize, ksize), dtype=np.int16))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def morphological_open(x):
|
| 39 |
+
# Convert array to int16 type via threshold operation
|
| 40 |
+
x_int16 = np.zeros_like(x, dtype=np.int16)
|
| 41 |
+
x_int16[x > 127] = 256
|
| 42 |
+
|
| 43 |
+
for i in range(32):
|
| 44 |
+
# Use int16 type to avoid overflow
|
| 45 |
+
maxed = max_filter_opencv(x_int16, ksize=3) - 8
|
| 46 |
+
x_int16 = np.maximum(maxed, x_int16)
|
| 47 |
+
|
| 48 |
+
# Clip negative values to 0 and convert back to uint8 type
|
| 49 |
+
x_uint8 = np.clip(x_int16, 0, 255).astype(np.uint8)
|
| 50 |
+
return x_uint8
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def up255(x, t=0):
    """Threshold an array to uint8 0/255: strictly-above-`t` maps to 255."""
    result = np.where(x > t, 255, 0).astype(np.uint8)
    return result
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def imsave(x, path):
|
| 60 |
+
x = Image.fromarray(x)
|
| 61 |
+
x.save(path)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def regulate_abcd(x, a, b, c, d):
    """Clamp row bounds (a, b) to [0, H] and column bounds (c, d) to [0, W]."""
    H, W = x.shape[:2]

    def clamp(v, hi):
        return int(min(max(v, 0), hi))

    return clamp(a, H), clamp(b, H), clamp(c, W), clamp(d, W)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def compute_initial_abcd(x):
    """Box (a, b, c, d) around the nonzero region of `x`, made roughly square
    by using ~1.15x the larger half-extent on both axes, then clamped."""
    ys, xs = np.where(x)[:2]
    top, bottom = np.min(ys), np.max(ys)
    left, right = np.min(xs), np.max(xs)
    cy, cx = (bottom + top) // 2, (right + left) // 2
    half = int(max((bottom - top) // 2, (right - left) // 2) * 1.15)
    return regulate_abcd(x, cy - half, cy + half + 1, cx - half, cx + half + 1)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def solve_abcd(x, a, b, c, d, k):
|
| 105 |
+
k = float(k)
|
| 106 |
+
assert 0.0 <= k <= 1.0
|
| 107 |
+
|
| 108 |
+
H, W = x.shape[:2]
|
| 109 |
+
if k == 1.0:
|
| 110 |
+
return 0, H, 0, W
|
| 111 |
+
while True:
|
| 112 |
+
if b - a >= H * k and d - c >= W * k:
|
| 113 |
+
break
|
| 114 |
+
|
| 115 |
+
add_h = (b - a) < (d - c)
|
| 116 |
+
add_w = not add_h
|
| 117 |
+
|
| 118 |
+
if b - a == H:
|
| 119 |
+
add_w = True
|
| 120 |
+
|
| 121 |
+
if d - c == W:
|
| 122 |
+
add_h = True
|
| 123 |
+
|
| 124 |
+
if add_h:
|
| 125 |
+
a -= 1
|
| 126 |
+
b += 1
|
| 127 |
+
|
| 128 |
+
if add_w:
|
| 129 |
+
c -= 1
|
| 130 |
+
d += 1
|
| 131 |
+
|
| 132 |
+
a, b, c, d = regulate_abcd(x, a, b, c, d)
|
| 133 |
+
return a, b, c, d
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def fooocus_fill(image, mask):
|
| 137 |
+
current_image = image.copy()
|
| 138 |
+
raw_image = image.copy()
|
| 139 |
+
area = np.where(mask < 127)
|
| 140 |
+
store = raw_image[area]
|
| 141 |
+
|
| 142 |
+
for k, repeats in [(512, 2), (256, 2), (128, 4), (64, 4), (33, 8), (15, 8), (5, 16), (3, 16)]:
|
| 143 |
+
for _ in range(repeats):
|
| 144 |
+
current_image = box_blur(current_image, k)
|
| 145 |
+
current_image[area] = store
|
| 146 |
+
|
| 147 |
+
return current_image
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class InpaintWorker:
|
| 151 |
+
def __init__(self, image, mask, use_fill=True, k=0.618):
|
| 152 |
+
a, b, c, d = compute_initial_abcd(mask > 0)
|
| 153 |
+
a, b, c, d = solve_abcd(mask, a, b, c, d, k=k)
|
| 154 |
+
|
| 155 |
+
# interested area
|
| 156 |
+
self.interested_area = (a, b, c, d)
|
| 157 |
+
self.interested_mask = mask[a:b, c:d]
|
| 158 |
+
self.interested_image = image[a:b, c:d]
|
| 159 |
+
|
| 160 |
+
# super resolution
|
| 161 |
+
if get_image_shape_ceil(self.interested_image) < 1024:
|
| 162 |
+
self.interested_image = perform_upscale(self.interested_image)
|
| 163 |
+
|
| 164 |
+
# resize to make images ready for diffusion
|
| 165 |
+
self.interested_image = set_image_shape_ceil(self.interested_image, 1024)
|
| 166 |
+
self.interested_fill = self.interested_image.copy()
|
| 167 |
+
H, W, C = self.interested_image.shape
|
| 168 |
+
|
| 169 |
+
# process mask
|
| 170 |
+
self.interested_mask = up255(resample_image(self.interested_mask, W, H), t=127)
|
| 171 |
+
|
| 172 |
+
# compute filling
|
| 173 |
+
if use_fill:
|
| 174 |
+
self.interested_fill = fooocus_fill(self.interested_image, self.interested_mask)
|
| 175 |
+
|
| 176 |
+
# soft pixels
|
| 177 |
+
self.mask = morphological_open(mask)
|
| 178 |
+
self.image = image
|
| 179 |
+
|
| 180 |
+
# ending
|
| 181 |
+
self.latent = None
|
| 182 |
+
self.latent_after_swap = None
|
| 183 |
+
self.swapped = False
|
| 184 |
+
self.latent_mask = None
|
| 185 |
+
self.inpaint_head_feature = None
|
| 186 |
+
return
|
| 187 |
+
|
| 188 |
+
def load_latent(self, latent_fill, latent_mask, latent_swap=None):
|
| 189 |
+
self.latent = latent_fill
|
| 190 |
+
self.latent_mask = latent_mask
|
| 191 |
+
self.latent_after_swap = latent_swap
|
| 192 |
+
return
|
| 193 |
+
|
| 194 |
+
def patch(self, inpaint_head_model_path, inpaint_latent, inpaint_latent_mask, model):
|
| 195 |
+
global inpaint_head_model
|
| 196 |
+
|
| 197 |
+
if inpaint_head_model is None:
|
| 198 |
+
inpaint_head_model = InpaintHead()
|
| 199 |
+
sd = torch.load(inpaint_head_model_path, map_location='cpu')
|
| 200 |
+
inpaint_head_model.load_state_dict(sd)
|
| 201 |
+
|
| 202 |
+
feed = torch.cat([
|
| 203 |
+
inpaint_latent_mask,
|
| 204 |
+
model.model.process_latent_in(inpaint_latent)
|
| 205 |
+
], dim=1)
|
| 206 |
+
|
| 207 |
+
inpaint_head_model.to(device=feed.device, dtype=feed.dtype)
|
| 208 |
+
inpaint_head_feature = inpaint_head_model(feed)
|
| 209 |
+
|
| 210 |
+
def input_block_patch(h, transformer_options):
|
| 211 |
+
if transformer_options["block"][1] == 0:
|
| 212 |
+
h = h + inpaint_head_feature.to(h)
|
| 213 |
+
return h
|
| 214 |
+
|
| 215 |
+
m = model.clone()
|
| 216 |
+
m.set_model_input_block_patch(input_block_patch)
|
| 217 |
+
return m
|
| 218 |
+
|
| 219 |
+
def swap(self):
|
| 220 |
+
if self.swapped:
|
| 221 |
+
return
|
| 222 |
+
|
| 223 |
+
if self.latent is None:
|
| 224 |
+
return
|
| 225 |
+
|
| 226 |
+
if self.latent_after_swap is None:
|
| 227 |
+
return
|
| 228 |
+
|
| 229 |
+
self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
|
| 230 |
+
self.swapped = True
|
| 231 |
+
return
|
| 232 |
+
|
| 233 |
+
def unswap(self):
|
| 234 |
+
if not self.swapped:
|
| 235 |
+
return
|
| 236 |
+
|
| 237 |
+
if self.latent is None:
|
| 238 |
+
return
|
| 239 |
+
|
| 240 |
+
if self.latent_after_swap is None:
|
| 241 |
+
return
|
| 242 |
+
|
| 243 |
+
self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
|
| 244 |
+
self.swapped = False
|
| 245 |
+
return
|
| 246 |
+
|
| 247 |
+
def color_correction(self, img):
|
| 248 |
+
fg = img.astype(np.float32)
|
| 249 |
+
bg = self.image.copy().astype(np.float32)
|
| 250 |
+
w = self.mask[:, :, None].astype(np.float32) / 255.0
|
| 251 |
+
y = fg * w + bg * (1 - w)
|
| 252 |
+
return y.clip(0, 255).astype(np.uint8)
|
| 253 |
+
|
| 254 |
+
def post_process(self, img):
|
| 255 |
+
a, b, c, d = self.interested_area
|
| 256 |
+
content = resample_image(img, d - c, b - a)
|
| 257 |
+
result = self.image.copy()
|
| 258 |
+
result[a:b, c:d] = content
|
| 259 |
+
result = self.color_correction(result)
|
| 260 |
+
return result
|
| 261 |
+
|
| 262 |
+
def visualize_mask_processing(self):
|
| 263 |
+
return [self.interested_fill, self.interested_mask, self.interested_image]
|
| 264 |
+
|
launch_util.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import importlib
|
| 3 |
+
import importlib.util
|
| 4 |
+
import subprocess
|
| 5 |
+
import sys
|
| 6 |
+
import re
|
| 7 |
+
import logging
|
| 8 |
+
import importlib.metadata
|
| 9 |
+
import packaging.version
|
| 10 |
+
from packaging.requirements import Requirement
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
|
| 16 |
+
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
| 17 |
+
|
| 18 |
+
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
|
| 19 |
+
|
| 20 |
+
python = sys.executable
|
| 21 |
+
default_command_live = (os.environ.get('LAUNCH_LIVE_OUTPUT') == "1")
|
| 22 |
+
index_url = os.environ.get('INDEX_URL', "")
|
| 23 |
+
|
| 24 |
+
modules_path = os.path.dirname(os.path.realpath(__file__))
|
| 25 |
+
script_path = os.path.dirname(modules_path)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def is_installed(package):
    """Return True if `package` resolves to an importable module spec."""
    try:
        return importlib.util.find_spec(package) is not None
    except ModuleNotFoundError:
        # find_spec raises when a parent package of a dotted name is absent
        return False
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
    """Run a shell command and return its captured stdout (empty string when live).

    Args:
        command: full command line, executed with shell=True.
            SECURITY NOTE: the string is interpreted by the shell — never pass
            untrusted input here.
        desc: optional message printed before running.
        errdesc: prefix for the RuntimeError raised on a nonzero exit code.
        custom_env: environment mapping to use instead of os.environ.
        live: when True, output streams directly to the console instead of
            being captured (so the return value is '').

    Raises:
        RuntimeError: if the command exits with a nonzero return code; the
            message includes the command, exit code, and any captured output.
    """
    if desc is not None:
        print(desc)

    run_kwargs = {
        "args": command,
        "shell": True,
        "env": os.environ if custom_env is None else custom_env,
        "encoding": 'utf8',
        "errors": 'ignore',
    }

    # capture output only when not streaming live
    if not live:
        run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE

    result = subprocess.run(**run_kwargs)

    if result.returncode != 0:
        error_bits = [
            f"{errdesc or 'Error running command'}.",
            f"Command: {command}",
            f"Error code: {result.returncode}",
        ]
        if result.stdout:
            error_bits.append(f"stdout: {result.stdout}")
        if result.stderr:
            error_bits.append(f"stderr: {result.stderr}")
        raise RuntimeError("\n".join(error_bits))

    # result.stdout is None in live mode
    return (result.stdout or "")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def run_pip(command, desc=None, live=default_command_live):
    """Run `python -m pip <command>` (honoring INDEX_URL) and return its stdout.

    Failures are swallowed: the exception and command are printed and None is
    returned, so a missing optional dependency does not abort the launch.
    """
    try:
        index_url_line = f' --index-url {index_url}' if index_url != '' else ''
        return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}",
                   errdesc=f"Couldn't install {desc}", live=live)
    except Exception as e:
        print(e)
        print(f'CMD Failed {desc}: {command}')
        return None
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def requirements_met(requirements_file):
    """Check whether every requirement in `requirements_file` is installed and satisfied.

    Each non-empty, non-comment line is parsed as a PEP 508 requirement; the
    installed version (via importlib.metadata) must fall inside the line's
    version specifier. Returns False on the first mismatch or lookup error,
    True when all lines pass. Lines with environment markers or extras are
    not specially handled — assumes simple `name==version` style lines.
    """
    with open(requirements_file, "r", encoding="utf8") as file:
        for line in file:
            line = line.strip()
            if line == "" or line.startswith('#'):
                continue

            requirement = Requirement(line)
            package = requirement.name

            try:
                version_installed = importlib.metadata.version(package)
                installed_version = packaging.version.parse(version_installed)

                # Check if the installed version satisfies the requirement
                if installed_version not in requirement.specifier:
                    print(f"Version mismatch for {package}: Installed version {version_installed} does not meet requirement {requirement}")
                    return False
            except Exception as e:
                # importlib.metadata raises PackageNotFoundError for missing
                # packages; treated the same as any other failure here
                print(f"Error checking version for {package}: {e}")
                return False

    return True
|
| 103 |
+
|
lora.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def match_lora(lora, to_load):
    """Match LoRA-style tensors in `lora` against the model keys in `to_load`.

    Supports several on-disk layouts: direct ('fooocus') full weights,
    regular / diffusers / transformers LoRA up+down pairs, LoHa, LoKr,
    GLoRA, and plain 'diff' weight/bias deltas.

    Args:
        lora: dict of tensor name -> tensor loaded from the LoRA file.
        to_load: dict mapping a lookup prefix -> target model key.

    Returns:
        (patch_dict, remaining_dict): patches keyed by model key as
        (patch_type, payload) tuples, plus every tensor that was not consumed.
    """
    patch_dict = {}
    loaded_keys = set()
    for x in to_load:
        real_load_key = to_load[x]
        # Full-weight replacement stored directly under the model key.
        if real_load_key in lora:
            patch_dict[real_load_key] = ('fooocus', lora[real_load_key])
            loaded_keys.add(real_load_key)
            continue

        # Optional per-key scaling factor shared by the layouts below.
        alpha_name = "{}.alpha".format(x)
        alpha = None
        if alpha_name in lora.keys():
            alpha = lora[alpha_name].item()
            loaded_keys.add(alpha_name)

        # Regular LoRA comes in three naming conventions.
        regular_lora = "{}.lora_up.weight".format(x)
        diffusers_lora = "{}_lora.up.weight".format(x)
        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
        A_name = None

        if regular_lora in lora.keys():
            A_name = regular_lora
            B_name = "{}.lora_down.weight".format(x)
            mid_name = "{}.lora_mid.weight".format(x)
        elif diffusers_lora in lora.keys():
            A_name = diffusers_lora
            B_name = "{}_lora.down.weight".format(x)
            mid_name = None
        elif transformers_lora in lora.keys():
            A_name = transformers_lora
            B_name ="{}.lora_linear_layer.down.weight".format(x)
            mid_name = None

        if A_name is not None:
            mid = None
            # 'mid' only exists for conv-style (LoCon) LoRAs in the regular layout.
            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)


        ######## loha
        hada_w1_a_name = "{}.hada_w1_a".format(x)
        hada_w1_b_name = "{}.hada_w1_b".format(x)
        hada_w2_a_name = "{}.hada_w2_a".format(x)
        hada_w2_b_name = "{}.hada_w2_b".format(x)
        hada_t1_name = "{}.hada_t1".format(x)
        hada_t2_name = "{}.hada_t2".format(x)
        if hada_w1_a_name in lora.keys():
            hada_t1 = None
            hada_t2 = None
            # t1/t2 are only present for the Tucker-decomposed (conv) variant.
            if hada_t1_name in lora.keys():
                hada_t1 = lora[hada_t1_name]
                hada_t2 = lora[hada_t2_name]
                loaded_keys.add(hada_t1_name)
                loaded_keys.add(hada_t2_name)

            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
            loaded_keys.add(hada_w1_a_name)
            loaded_keys.add(hada_w1_b_name)
            loaded_keys.add(hada_w2_a_name)
            loaded_keys.add(hada_w2_b_name)


        ######## lokr
        lokr_w1_name = "{}.lokr_w1".format(x)
        lokr_w2_name = "{}.lokr_w2".format(x)
        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
        lokr_t2_name = "{}.lokr_t2".format(x)
        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
        lokr_w2_b_name = "{}.lokr_w2_b".format(x)

        lokr_w1 = None
        if lokr_w1_name in lora.keys():
            lokr_w1 = lora[lokr_w1_name]
            loaded_keys.add(lokr_w1_name)

        lokr_w2 = None
        if lokr_w2_name in lora.keys():
            lokr_w2 = lora[lokr_w2_name]
            loaded_keys.add(lokr_w2_name)

        lokr_w1_a = None
        if lokr_w1_a_name in lora.keys():
            lokr_w1_a = lora[lokr_w1_a_name]
            loaded_keys.add(lokr_w1_a_name)

        lokr_w1_b = None
        if lokr_w1_b_name in lora.keys():
            lokr_w1_b = lora[lokr_w1_b_name]
            loaded_keys.add(lokr_w1_b_name)

        lokr_w2_a = None
        if lokr_w2_a_name in lora.keys():
            lokr_w2_a = lora[lokr_w2_a_name]
            loaded_keys.add(lokr_w2_a_name)

        lokr_w2_b = None
        if lokr_w2_b_name in lora.keys():
            lokr_w2_b = lora[lokr_w2_b_name]
            loaded_keys.add(lokr_w2_b_name)

        lokr_t2 = None
        if lokr_t2_name in lora.keys():
            lokr_t2 = lora[lokr_t2_name]
            loaded_keys.add(lokr_t2_name)

        # Either the full Kronecker factors (w1/w2) or their low-rank splits
        # (w?_a/w?_b) may be present; any of them triggers a lokr patch.
        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))

        #glora
        a1_name = "{}.a1.weight".format(x)
        a2_name = "{}.a2.weight".format(x)
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)
        if a1_name in lora:
            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
            loaded_keys.add(a1_name)
            loaded_keys.add(a2_name)
            loaded_keys.add(b1_name)
            loaded_keys.add(b2_name)

        # Norm-layer deltas, applied as plain additive 'diff' patches.
        w_norm_name = "{}.w_norm".format(x)
        b_norm_name = "{}.b_norm".format(x)
        w_norm = lora.get(w_norm_name, None)
        b_norm = lora.get(b_norm_name, None)

        if w_norm is not None:
            loaded_keys.add(w_norm_name)
            patch_dict[to_load[x]] = ("diff", (w_norm,))
            if b_norm is not None:
                loaded_keys.add(b_norm_name)
                patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))

        diff_name = "{}.diff".format(x)
        diff_weight = lora.get(diff_name, None)
        if diff_weight is not None:
            patch_dict[to_load[x]] = ("diff", (diff_weight,))
            loaded_keys.add(diff_name)

        diff_bias_name = "{}.diff_b".format(x)
        diff_bias = lora.get(diff_bias_name, None)
        if diff_bias is not None:
            patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
            loaded_keys.add(diff_bias_name)

    # Everything not consumed above is handed back so the caller can warn.
    remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
    return patch_dict, remaining_dict
|
model_loader.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from urllib.parse import urlparse
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def load_file_from_url(
        url: str,
        *,
        model_dir: str,
        progress: bool = True,
        file_name: Optional[str] = None,
) -> str:
    """Download a file from `url` into `model_dir`, using the file present if possible.

    When `file_name` is omitted, the name is taken from the URL path.
    Returns the absolute path to the (possibly pre-existing) file.
    """
    os.makedirs(model_dir, exist_ok=True)
    target_name = file_name if file_name else os.path.basename(urlparse(url).path)
    destination = os.path.abspath(os.path.join(model_dir, target_name))
    if not os.path.exists(destination):
        print(f'Downloading: "{url}" to {destination}\n')
        # Imported lazily so torch is only required when a download happens.
        from torch.hub import download_url_to_file
        download_url_to_file(url, destination, progress=progress)
    return destination
|
sdxl_styles.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
from modules.util import get_files_from_folder
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# cannot use modules.config - validators causing circular imports
|
| 9 |
+
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
|
| 10 |
+
wildcards_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../wildcards/'))
|
| 11 |
+
wildcards_max_bfs_depth = 64
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def normalize_key(k):
    """Normalize a style name: dashes become spaces, each word is Title-Cased,
    then fixed spellings are restored (3D, SAI, MRE, '(S' after a paren)."""
    title_cased = ' '.join(
        word[:1].upper() + word[1:].lower()
        for word in k.replace('-', ' ').split(' ')
    )
    for before, after in (('3d', '3D'), ('Sai', 'SAI'), ('Mre', 'MRE'), ('(s', '(S')):
        title_cased = title_cased.replace(before, after)
    return title_cased
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Load every style JSON into `styles`, keyed by normalized name. Later files
# overwrite earlier entries, so the bundled Fooocus/SAI/MRE/TWRI/DIVA/K3nt3l
# files are moved to the end of the list to take precedence over user files.
styles = {}

styles_files = get_files_from_folder(styles_path, ['.json'])

for bundled_name in ['sdxl_styles_fooocus.json',
                     'sdxl_styles_sai.json',
                     'sdxl_styles_mre.json',
                     'sdxl_styles_twri.json',
                     'sdxl_styles_diva.json',
                     'sdxl_styles_marc_k3nt3l.json']:
    if bundled_name in styles_files:
        styles_files.remove(bundled_name)
        styles_files.append(bundled_name)

for styles_file in styles_files:
    try:
        with open(os.path.join(styles_path, styles_file), encoding='utf-8') as f:
            for entry in json.load(f):
                style_key = normalize_key(entry['name'])
                # Missing prompt fields fall back to the empty string.
                styles[style_key] = (entry.get('prompt', ''), entry.get('negative_prompt', ''))
    except Exception as e:
        print(str(e))
        print(f'Failed to load style file {styles_file}')

style_keys = list(styles.keys())
fooocus_expansion = "Fooocus V2"
legal_style_names = [fooocus_expansion] + style_keys
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def apply_style(style, positive):
    """Expand a named style into (positive_lines, negative_lines), substituting
    the user prompt for the '{prompt}' placeholder in the positive template."""
    positive_template, negative_template = styles[style]
    expanded = positive_template.replace('{prompt}', positive)
    return expanded.splitlines(), negative_template.splitlines()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def apply_wildcards(wildcard_text, rng, directory=wildcards_path):
    """Expand `__name__` placeholders using `<directory>/<name>.txt` word lists.

    Each placeholder is replaced by a random non-empty line chosen with `rng`;
    expansion repeats (wildcards may reference wildcards) up to
    `wildcards_max_bfs_depth` rounds. A missing or empty wildcard file
    degrades to the bare placeholder word.

    Args:
        wildcard_text: Prompt text possibly containing `__name__` tokens.
        rng: A random.Random-like object providing `choice`.
        directory: Folder holding the wildcard .txt files.

    Returns:
        The expanded text (best-effort if the depth limit is hit).
    """
    for _ in range(wildcards_max_bfs_depth):
        placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
        if len(placeholders) == 0:
            return wildcard_text

        print(f'[Wildcards] processing: {wildcard_text}')
        for placeholder in placeholders:
            try:
                # Context manager fixes the original's leaked file handle.
                with open(os.path.join(directory, f'{placeholder}.txt'), encoding='utf-8') as f:
                    words = f.read().splitlines()
                words = [x for x in words if x != '']
                if len(words) == 0:
                    # Raise (instead of assert, which -O strips) so empty
                    # files take the same fallback path as missing ones.
                    raise ValueError(f'{placeholder}.txt is empty')
                wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
            except Exception:
                print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
                      f'Using "{placeholder}" as a normal word.')
                wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
        print(f'[Wildcards] {wildcard_text}')

    print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
    return wildcard_text
|
upscaler.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import modules.core as core
|
| 4 |
+
|
| 5 |
+
from ldm_patched.pfn.architecture.RRDB import RRDBNet as ESRGAN
|
| 6 |
+
from ldm_patched.contrib.external_upscale_model import ImageUpscaleWithModel
|
| 7 |
+
from collections import OrderedDict
|
| 8 |
+
from modules.config import path_upscale_models
|
| 9 |
+
|
| 10 |
+
model_filename = os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
|
| 11 |
+
opImageUpscaleWithModel = ImageUpscaleWithModel()
|
| 12 |
+
model = None
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def perform_upscale(img):
    """Upscale a numpy HWC image with the bundled ESRGAN (RRDBNet) model.

    The checkpoint is loaded lazily on first call and cached in the
    module-level `model` global; inference runs through the patched
    ImageUpscaleWithModel op.
    """
    global model

    print(f'Upscaling image with shape {str(img.shape)} ...')

    if model is None:
        # map_location='cpu' lets CUDA-saved checkpoints load on CPU-only
        # hosts; the model is kept on CPU here regardless.
        sd = torch.load(model_filename, map_location='cpu')
        sdo = OrderedDict()
        for k, v in sd.items():
            # Translate checkpoint key naming to the RRDBNet convention.
            sdo[k.replace('residual_block_', 'RDB')] = v
        del sd
        model = ESRGAN(sdo)
        model.cpu()
        model.eval()

    img = core.numpy_to_pytorch(img)
    img = opImageUpscaleWithModel.upscale(model, img)[0]
    img = core.pytorch_to_numpy(img)[0]

    return img
|
util.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import datetime
|
| 3 |
+
import random
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import cv2
|
| 7 |
+
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def erode_or_dilate(x, k):
    """Morphologically grow or shrink a mask with a fixed 3x3 uint8 kernel:
    dilate by k iterations when k > 0, erode by -k when k < 0, else no-op."""
    iterations = int(k)
    if iterations == 0:
        return x
    kernel = np.ones(shape=(3, 3), dtype=np.uint8)
    if iterations > 0:
        return cv2.dilate(x, kernel=kernel, iterations=iterations)
    return cv2.erode(x, kernel=kernel, iterations=-iterations)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def resample_image(im, width, height):
    """Resize a numpy image to (width, height) using Lanczos resampling."""
    resized = Image.fromarray(im).resize((int(width), int(height)), resample=LANCZOS)
    return np.array(resized)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def resize_image(im, width, height, resize_mode=1):
    """
    Resizes an image with the specified resize_mode, width, and height.

    Args:
        resize_mode: The mode to use when resizing the image.
            0: Resize the image to the specified width and height.
            1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
            2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
        im: The image to resize (numpy array).
        width: The width to resize the image to.
        height: The height to resize the image to.

    Returns:
        The resized image as a numpy RGB array of shape (height, width, 3).
    """

    im = Image.fromarray(im)

    def resize(im, w, h):
        # Lanczos filter for best quality in both up- and down-scaling.
        return im.resize((w, h), resample=LANCZOS)

    if resize_mode == 0:
        # Mode 0: direct resize; aspect ratio is NOT preserved.
        res = resize(im, width, height)

    elif resize_mode == 1:
        # Mode 1 (crop-to-fill): scale so the image covers the target box,
        # then center it; the overflow is cropped by the paste.
        ratio = width / height
        src_ratio = im.width / im.height

        src_w = width if ratio > src_ratio else im.width * height // im.height
        src_h = height if ratio <= src_ratio else im.height * width // im.width

        resized = resize(im, src_w, src_h)
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

    else:
        # Mode 2 (fit-and-fill): scale so the image fits inside the target box,
        # center it, then fill the empty bands by stretching the image's own
        # edge row/column (Pillow's `box=` selects a zero-height/width strip),
        # which avoids plain black bars.
        ratio = width / height
        src_ratio = im.width / im.height

        src_w = width if ratio < src_ratio else im.width * height // im.height
        src_h = height if ratio >= src_ratio else im.height * width // im.width

        resized = resize(im, src_w, src_h)
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

        if ratio < src_ratio:
            # Letterbox: stretch the top edge into the upper band and the
            # bottom edge into the lower band.
            fill_height = height // 2 - src_h // 2
            if fill_height > 0:
                res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
                res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
        elif ratio > src_ratio:
            # Pillarbox: stretch the left edge into the left band and the
            # right edge into the right band.
            fill_width = width // 2 - src_w // 2
            if fill_width > 0:
                res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
                res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))

    return np.array(res)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_shape_ceil(h, w):
    """Return sqrt(h * w) rounded up to the next multiple of 64, as a float."""
    area_side = (h * w) ** 0.5
    return math.ceil(area_side / 64.0) * 64.0
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_image_shape_ceil(im):
    """Shape ceiling of an image array, computed from its first two axes."""
    height, width = im.shape[:2]
    return get_shape_ceil(height, width)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def set_image_shape_ceil(im, shape_ceil):
    """Resize `im` so its shape ceiling approximates `shape_ceil`, keeping both
    dimensions multiples of 64. Returns `im` unchanged when nothing changes."""
    shape_ceil = float(shape_ceil)

    H_origin, W_origin, _ = im.shape
    H, W = H_origin, W_origin

    # Iterate: rounding each side to a multiple of 64 can over/undershoot,
    # so re-scale until the ceiling converges (bounded to 256 passes).
    for _ in range(256):
        current = get_shape_ceil(H, W)
        if abs(current - shape_ceil) < 0.1:
            break
        scale = shape_ceil / current
        H = int(round(float(H) * scale / 64.0) * 64)
        W = int(round(float(W) * scale / 64.0) * 64)

    if (H, W) == (H_origin, W_origin):
        return im

    return resample_image(im, width=W, height=H)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def HWC3(x):
    """Coerce a uint8 image to 3-channel HxWx3: grayscale is replicated,
    RGB passes through untouched, RGBA is composited onto white."""
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels == 1 or channels == 3 or channels == 4
    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate([x, x, x], axis=2)
    # RGBA: alpha-blend over a white (255) background.
    rgb = x[:, :, 0:3].astype(np.float32)
    alpha = x[:, :, 3:4].astype(np.float32) / 255.0
    blended = rgb * alpha + 255.0 * (1.0 - alpha)
    return blended.clip(0, 255).astype(np.uint8)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def remove_empty_str(items, default=None):
    """Drop empty strings from `items`; if nothing remains and `default` is
    provided, return a one-element list containing `default`."""
    kept = [item for item in items if item != ""]
    if not kept and default is not None:
        return [default]
    return kept
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def join_prompts(*args, **kwargs):
    """Join the non-empty string forms of `args` with ', '.

    Keyword arguments are accepted for call-site compatibility but unused.
    """
    fragments = [str(arg) for arg in args if str(arg) != ""]
    if not fragments:
        return ""
    if len(fragments) == 1:
        return fragments[0]
    return ', '.join(fragments)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def generate_temp_filename(folder='./outputs/', extension='png'):
    """Build a dated, randomized output path without creating any file.

    Returns:
        (date_string, absolute_path, filename) where the path is
        `<folder>/<YYYY-MM-DD>/<timestamp>_<rand>.<extension>`.
    """
    now = datetime.datetime.now()
    date_string = now.strftime("%Y-%m-%d")
    time_string = now.strftime("%Y-%m-%d_%H-%M-%S")
    # Random suffix avoids collisions between images saved in the same second.
    random_number = random.randint(1000, 9999)
    filename = f"{time_string}_{random_number}.{extension}"
    full_path = os.path.join(folder, date_string, filename)
    return date_string, os.path.abspath(os.path.realpath(full_path)), filename
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def get_files_from_folder(folder_path, exensions=None, name_filter=None):
    """Recursively list files under `folder_path` as paths relative to it.

    Args:
        folder_path: Directory to walk.
        exensions: Optional collection of lowercase extensions (with leading
            dot) to keep; None keeps all files. (Parameter name kept
            misspelled for backward compatibility with existing callers.)
        name_filter: Optional substring that must appear in the file stem.

    Returns:
        Relative paths, sorted so files inside subdirectories come first.

    Raises:
        ValueError: If `folder_path` is not an existing directory.
    """
    if not os.path.isdir(folder_path):
        raise ValueError("Folder path is not a valid directory.")

    filenames = []

    for root, dirs, files in os.walk(folder_path):
        relative_path = os.path.relpath(root, folder_path)
        if relative_path == ".":
            relative_path = ""
        for filename in files:
            stem, file_extension = os.path.splitext(filename)
            # `is None` instead of `== None` (PEP 8); behavior unchanged.
            if ((exensions is None or file_extension.lower() in exensions)
                    and (name_filter is None or name_filter in stem)):
                filenames.append(os.path.join(relative_path, filename))

    # Paths containing a separator (subdirectory files) sort ahead of
    # top-level files; order within each group follows os.walk order.
    return sorted(filenames, key=lambda x: -1 if os.sep in x else 1)
|
webui.py
ADDED
|
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import random
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import time
|
| 6 |
+
import shared
|
| 7 |
+
import modules.config
|
| 8 |
+
import fooocus_version
|
| 9 |
+
import modules.html
|
| 10 |
+
import modules.async_worker as worker
|
| 11 |
+
import modules.constants as constants
|
| 12 |
+
import modules.flags as flags
|
| 13 |
+
import modules.gradio_hijack as grh
|
| 14 |
+
import modules.advanced_parameters as advanced_parameters
|
| 15 |
+
import modules.style_sorter as style_sorter
|
| 16 |
+
import modules.meta_parser
|
| 17 |
+
import args_manager
|
| 18 |
+
import copy
|
| 19 |
+
|
| 20 |
+
from modules.sdxl_styles import legal_style_names
|
| 21 |
+
from modules.private_logger import get_current_html_path
|
| 22 |
+
from modules.ui_gradio_extensions import reload_javascript
|
| 23 |
+
from modules.auth import auth_enabled, check_auth
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def generate_clicked(*args):
    """Gradio generator callback for the Generate button.

    Queues an AsyncTask for the worker thread, then polls its yield queue and
    streams UI updates until the task reports 'finish'. Each yield is a
    4-tuple of gr.update() values for
    [progress_html, progress_window, progress_gallery, gallery].
    """
    import ldm_patched.modules.model_management as model_management

    # Clear any interrupt left over from a previous (stopped/skipped) run.
    with model_management.interrupt_processing_mutex:
        model_management.interrupt_processing = False

    # outputs=[progress_html, progress_window, progress_gallery, gallery]

    execution_start_time = time.perf_counter()
    task = worker.AsyncTask(args=list(args))
    finished = False

    # Show the progress bar and preview window immediately, before the
    # worker has picked up the task.
    yield gr.update(visible=True, value=modules.html.make_progress_html(1, 'Waiting for task to start ...')), \
        gr.update(visible=True, value=None), \
        gr.update(visible=False, value=None), \
        gr.update(visible=False)

    worker.async_tasks.append(task)

    # Poll the task's queue; the worker pushes
    # ('preview' | 'results' | 'finish', payload) tuples as it progresses.
    while not finished:
        time.sleep(0.01)
        if len(task.yields) > 0:
            flag, product = task.yields.pop(0)
            if flag == 'preview':

                # help bad internet connection by skipping duplicated preview
                if len(task.yields) > 0:  # if we have the next item
                    if task.yields[0][0] == 'preview':  # if the next item is also a preview
                        # print('Skipped one preview for better internet connection.')
                        continue

                percentage, title, image = product
                yield gr.update(visible=True, value=modules.html.make_progress_html(percentage, title)), \
                    gr.update(visible=True, value=image) if image is not None else gr.update(), \
                    gr.update(), \
                    gr.update(visible=False)
            if flag == 'results':
                # Intermediate results: show the in-progress gallery.
                yield gr.update(visible=True), \
                    gr.update(visible=True), \
                    gr.update(visible=True, value=product), \
                    gr.update(visible=False)
            if flag == 'finish':
                # Final results: hide progress UI, reveal the main gallery.
                yield gr.update(visible=False), \
                    gr.update(visible=False), \
                    gr.update(visible=False), \
                    gr.update(visible=True, value=product)
                finished = True

    execution_time = time.perf_counter() - execution_start_time
    print(f'Total time: {execution_time:.2f} seconds')
    return
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
reload_javascript()
|
| 80 |
+
|
| 81 |
+
title = f'Fooocus {fooocus_version.version}'
|
| 82 |
+
|
| 83 |
+
if isinstance(args_manager.args.preset, str):
|
| 84 |
+
title += ' ' + args_manager.args.preset
|
| 85 |
+
|
| 86 |
+
shared.gradio_root = gr.Blocks(
|
| 87 |
+
title=title,
|
| 88 |
+
css=modules.html.css).queue()
|
| 89 |
+
|
| 90 |
+
with shared.gradio_root:
|
| 91 |
+
with gr.Row():
|
| 92 |
+
with gr.Column(scale=2):
|
| 93 |
+
with gr.Row():
|
| 94 |
+
progress_window = grh.Image(label='Preview', show_label=True, visible=False, height=768,
|
| 95 |
+
elem_classes=['main_view'])
|
| 96 |
+
progress_gallery = gr.Gallery(label='Finished Images', show_label=True, object_fit='contain',
|
| 97 |
+
height=768, visible=False, elem_classes=['main_view', 'image_gallery'])
|
| 98 |
+
progress_html = gr.HTML(value=modules.html.make_progress_html(32, 'Progress 32%'), visible=False,
|
| 99 |
+
elem_id='progress-bar', elem_classes='progress-bar')
|
| 100 |
+
gallery = gr.Gallery(label='Gallery', show_label=False, object_fit='contain', visible=True, height=768,
|
| 101 |
+
elem_classes=['resizable_area', 'main_view', 'final_gallery', 'image_gallery'],
|
| 102 |
+
elem_id='final_gallery')
|
| 103 |
+
with gr.Row(elem_classes='type_row'):
|
| 104 |
+
with gr.Column(scale=17):
|
| 105 |
+
prompt = gr.Textbox(show_label=False, placeholder="Type prompt here or paste parameters.", elem_id='positive_prompt',
|
| 106 |
+
container=False, autofocus=True, elem_classes='type_row', lines=1024)
|
| 107 |
+
|
| 108 |
+
default_prompt = modules.config.default_prompt
|
| 109 |
+
if isinstance(default_prompt, str) and default_prompt != '':
|
| 110 |
+
shared.gradio_root.load(lambda: default_prompt, outputs=prompt)
|
| 111 |
+
|
| 112 |
+
with gr.Column(scale=3, min_width=0):
|
| 113 |
+
generate_button = gr.Button(label="Generate", value="Generate", elem_classes='type_row', elem_id='generate_button', visible=True)
|
| 114 |
+
load_parameter_button = gr.Button(label="Load Parameters", value="Load Parameters", elem_classes='type_row', elem_id='load_parameter_button', visible=False)
|
| 115 |
+
skip_button = gr.Button(label="Skip", value="Skip", elem_classes='type_row_half', visible=False)
|
| 116 |
+
stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row_half', elem_id='stop_button', visible=False)
|
| 117 |
+
|
| 118 |
+
def stop_clicked():
    """Handle the Stop button: abort the whole generation run.

    Records 'stop' in shared state so the worker treats the interruption
    as a full stop (not a single-image skip), interrupts the current
    sampling, and disables both the Skip and Stop buttons — one update
    per output wired in stop_button.click.
    """
    import ldm_patched.modules.model_management as model_management
    shared.last_stop = 'stop'
    model_management.interrupt_current_processing()
    disabled = gr.update(interactive=False)
    return [disabled, disabled]
|
| 123 |
+
|
| 124 |
+
def skip_clicked():
    """Handle the Skip button: interrupt only the current image.

    Records 'skip' in shared state so the worker moves on to the next
    image in the batch after the current sampling is interrupted.
    No UI updates are returned.
    """
    from ldm_patched.modules import model_management
    shared.last_stop = 'skip'
    model_management.interrupt_current_processing()
|
| 129 |
+
|
| 130 |
+
stop_button.click(stop_clicked, outputs=[skip_button, stop_button],
|
| 131 |
+
queue=False, show_progress=False, _js='cancelGenerateForever')
|
| 132 |
+
skip_button.click(skip_clicked, queue=False, show_progress=False)
|
| 133 |
+
with gr.Row(elem_classes='advanced_check_row'):
|
| 134 |
+
input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check')
|
| 135 |
+
advanced_checkbox = gr.Checkbox(label='Advanced', value=modules.config.default_advanced_checkbox, container=False, elem_classes='min_check')
|
| 136 |
+
with gr.Row(visible=False) as image_input_panel:
|
| 137 |
+
with gr.Tabs():
|
| 138 |
+
with gr.TabItem(label='Upscale or Variation') as uov_tab:
|
| 139 |
+
with gr.Row():
|
| 140 |
+
with gr.Column():
|
| 141 |
+
uov_input_image = grh.Image(label='Drag above image to here', source='upload', type='numpy')
|
| 142 |
+
with gr.Column():
|
| 143 |
+
uov_method = gr.Radio(label='Upscale or Variation:', choices=flags.uov_list, value=flags.disabled)
|
| 144 |
+
gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/390" target="_blank">\U0001F4D4 Document</a>')
|
| 145 |
+
with gr.TabItem(label='Image Prompt') as ip_tab:
|
| 146 |
+
with gr.Row():
|
| 147 |
+
ip_images = []
|
| 148 |
+
ip_types = []
|
| 149 |
+
ip_stops = []
|
| 150 |
+
ip_weights = []
|
| 151 |
+
ip_ctrls = []
|
| 152 |
+
ip_ad_cols = []
|
| 153 |
+
for _ in range(4):
|
| 154 |
+
with gr.Column():
|
| 155 |
+
ip_image = grh.Image(label='Image', source='upload', type='numpy', show_label=False, height=300)
|
| 156 |
+
ip_images.append(ip_image)
|
| 157 |
+
ip_ctrls.append(ip_image)
|
| 158 |
+
with gr.Column(visible=False) as ad_col:
|
| 159 |
+
with gr.Row():
|
| 160 |
+
default_end, default_weight = flags.default_parameters[flags.default_ip]
|
| 161 |
+
|
| 162 |
+
ip_stop = gr.Slider(label='Stop At', minimum=0.0, maximum=1.0, step=0.001, value=default_end)
|
| 163 |
+
ip_stops.append(ip_stop)
|
| 164 |
+
ip_ctrls.append(ip_stop)
|
| 165 |
+
|
| 166 |
+
ip_weight = gr.Slider(label='Weight', minimum=0.0, maximum=2.0, step=0.001, value=default_weight)
|
| 167 |
+
ip_weights.append(ip_weight)
|
| 168 |
+
ip_ctrls.append(ip_weight)
|
| 169 |
+
|
| 170 |
+
ip_type = gr.Radio(label='Type', choices=flags.ip_list, value=flags.default_ip, container=False)
|
| 171 |
+
ip_types.append(ip_type)
|
| 172 |
+
ip_ctrls.append(ip_type)
|
| 173 |
+
|
| 174 |
+
ip_type.change(lambda x: flags.default_parameters[x], inputs=[ip_type], outputs=[ip_stop, ip_weight], queue=False, show_progress=False)
|
| 175 |
+
ip_ad_cols.append(ad_col)
|
| 176 |
+
ip_advanced = gr.Checkbox(label='Advanced', value=False, container=False)
|
| 177 |
+
gr.HTML('* \"Image Prompt\" is powered by Fooocus Image Mixture Engine (v1.0.1). <a href="https://github.com/lllyasviel/Fooocus/discussions/557" target="_blank">\U0001F4D4 Document</a>')
|
| 178 |
+
|
| 179 |
+
def ip_advance_checked(x):
    """React to the Image Prompt 'Advanced' checkbox.

    Sets the visibility of every per-image advanced column to the
    checkbox state, then resets each type radio and its stop/weight
    sliders back to the defaults of the default image-prompt type.
    The returned list matches the outputs wired in ip_advanced.change:
    ip_ad_cols + ip_types + ip_stops + ip_weights.
    """
    default_stop, default_weight = flags.default_parameters[flags.default_ip]
    updates = [gr.update(visible=x) for _ in ip_ad_cols]
    updates += [flags.default_ip for _ in ip_types]
    updates += [default_stop for _ in ip_stops]
    updates += [default_weight for _ in ip_weights]
    return updates
|
| 184 |
+
|
| 185 |
+
ip_advanced.change(ip_advance_checked, inputs=ip_advanced,
|
| 186 |
+
outputs=ip_ad_cols + ip_types + ip_stops + ip_weights,
|
| 187 |
+
queue=False, show_progress=False)
|
| 188 |
+
with gr.TabItem(label='Inpaint or Outpaint') as inpaint_tab:
|
| 189 |
+
with gr.Row():
|
| 190 |
+
inpaint_input_image = grh.Image(label='Drag inpaint or outpaint image to here', source='upload', type='numpy', tool='sketch', height=500, brush_color="#FFFFFF", elem_id='inpaint_canvas')
|
| 191 |
+
inpaint_mask_image = grh.Image(label='Mask Upload', source='upload', type='numpy', height=500, visible=False)
|
| 192 |
+
|
| 193 |
+
with gr.Row():
|
| 194 |
+
inpaint_additional_prompt = gr.Textbox(placeholder="Describe what you want to inpaint.", elem_id='inpaint_additional_prompt', label='Inpaint Additional Prompt', visible=False)
|
| 195 |
+
outpaint_selections = gr.CheckboxGroup(choices=['Left', 'Right', 'Top', 'Bottom'], value=[], label='Outpaint Direction')
|
| 196 |
+
inpaint_mode = gr.Dropdown(choices=modules.flags.inpaint_options, value=modules.flags.inpaint_option_default, label='Method')
|
| 197 |
+
example_inpaint_prompts = gr.Dataset(samples=modules.config.example_inpaint_prompts, label='Additional Prompt Quick List', components=[inpaint_additional_prompt], visible=False)
|
| 198 |
+
gr.HTML('* Powered by Fooocus Inpaint Engine <a href="https://github.com/lllyasviel/Fooocus/discussions/414" target="_blank">\U0001F4D4 Document</a>')
|
| 199 |
+
example_inpaint_prompts.click(lambda x: x[0], inputs=example_inpaint_prompts, outputs=inpaint_additional_prompt, show_progress=False, queue=False)
|
| 200 |
+
with gr.TabItem(label='Describe') as desc_tab:
|
| 201 |
+
with gr.Row():
|
| 202 |
+
with gr.Column():
|
| 203 |
+
desc_input_image = grh.Image(label='Drag any image to here', source='upload', type='numpy')
|
| 204 |
+
with gr.Column():
|
| 205 |
+
desc_method = gr.Radio(
|
| 206 |
+
label='Content Type',
|
| 207 |
+
choices=[flags.desc_type_photo, flags.desc_type_anime],
|
| 208 |
+
value=flags.desc_type_photo)
|
| 209 |
+
desc_btn = gr.Button(value='Describe this Image into Prompt')
|
| 210 |
+
gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/1363" target="_blank">\U0001F4D4 Document</a>')
|
| 211 |
+
switch_js = "(x) => {if(x){viewer_to_bottom(100);viewer_to_bottom(500);}else{viewer_to_top();} return x;}"
|
| 212 |
+
down_js = "() => {viewer_to_bottom();}"
|
| 213 |
+
|
| 214 |
+
input_image_checkbox.change(lambda x: gr.update(visible=x), inputs=input_image_checkbox,
|
| 215 |
+
outputs=image_input_panel, queue=False, show_progress=False, _js=switch_js)
|
| 216 |
+
ip_advanced.change(lambda: None, queue=False, show_progress=False, _js=down_js)
|
| 217 |
+
|
| 218 |
+
current_tab = gr.Textbox(value='uov', visible=False)
|
| 219 |
+
uov_tab.select(lambda: 'uov', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
|
| 220 |
+
inpaint_tab.select(lambda: 'inpaint', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
|
| 221 |
+
ip_tab.select(lambda: 'ip', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
|
| 222 |
+
desc_tab.select(lambda: 'desc', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
|
| 223 |
+
|
| 224 |
+
with gr.Column(scale=1, visible=modules.config.default_advanced_checkbox) as advanced_column:
|
| 225 |
+
with gr.Tab(label='Setting'):
|
| 226 |
+
performance_selection = gr.Radio(label='Performance',
|
| 227 |
+
choices=modules.flags.performance_selections,
|
| 228 |
+
value=modules.config.default_performance)
|
| 229 |
+
aspect_ratios_selection = gr.Radio(label='Aspect Ratios', choices=modules.config.available_aspect_ratios,
|
| 230 |
+
value=modules.config.default_aspect_ratio, info='width × height',
|
| 231 |
+
elem_classes='aspect_ratios')
|
| 232 |
+
image_number = gr.Slider(label='Image Number', minimum=1, maximum=modules.config.default_max_image_number, step=1, value=modules.config.default_image_number)
|
| 233 |
+
negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.",
|
| 234 |
+
info='Describing what you do not want to see.', lines=2,
|
| 235 |
+
elem_id='negative_prompt',
|
| 236 |
+
value=modules.config.default_prompt_negative)
|
| 237 |
+
seed_random = gr.Checkbox(label='Random', value=True)
|
| 238 |
+
image_seed = gr.Textbox(label='Seed', value=0, max_lines=1, visible=False) # workaround for https://github.com/gradio-app/gradio/issues/5354
|
| 239 |
+
|
| 240 |
+
def random_checked(r):
    """Hide the seed textbox while 'Random' is checked; show it otherwise."""
    return gr.update(visible=not r)
|
| 242 |
+
|
| 243 |
+
def refresh_seed(r, seed_string):
    """Resolve the seed to use for this generation run.

    When 'Random' is checked (r is True) a fresh seed is drawn.
    Otherwise the textbox content is parsed as an integer; a value that
    is unparsable or outside [MIN_SEED, MAX_SEED] also falls back to a
    fresh random seed.
    """
    if not r:
        try:
            parsed = int(seed_string)
        except ValueError:
            parsed = None
        if parsed is not None and constants.MIN_SEED <= parsed <= constants.MAX_SEED:
            return parsed
    return random.randint(constants.MIN_SEED, constants.MAX_SEED)
|
| 254 |
+
|
| 255 |
+
seed_random.change(random_checked, inputs=[seed_random], outputs=[image_seed],
|
| 256 |
+
queue=False, show_progress=False)
|
| 257 |
+
|
| 258 |
+
if not args_manager.args.disable_image_log:
|
| 259 |
+
gr.HTML(f'<a href="file={get_current_html_path()}" target="_blank">\U0001F4DA History Log</a>')
|
| 260 |
+
|
| 261 |
+
with gr.Tab(label='Style'):
|
| 262 |
+
style_sorter.try_load_sorted_styles(
|
| 263 |
+
style_names=legal_style_names,
|
| 264 |
+
default_selected=modules.config.default_styles)
|
| 265 |
+
|
| 266 |
+
style_search_bar = gr.Textbox(show_label=False, container=False,
|
| 267 |
+
placeholder="\U0001F50E Type here to search styles ...",
|
| 268 |
+
value="",
|
| 269 |
+
label='Search Styles')
|
| 270 |
+
style_selections = gr.CheckboxGroup(show_label=False, container=False,
|
| 271 |
+
choices=copy.deepcopy(style_sorter.all_styles),
|
| 272 |
+
value=copy.deepcopy(modules.config.default_styles),
|
| 273 |
+
label='Selected Styles',
|
| 274 |
+
elem_classes=['style_selections'])
|
| 275 |
+
gradio_receiver_style_selections = gr.Textbox(elem_id='gradio_receiver_style_selections', visible=False)
|
| 276 |
+
|
| 277 |
+
shared.gradio_root.load(lambda: gr.update(choices=copy.deepcopy(style_sorter.all_styles)),
|
| 278 |
+
outputs=style_selections)
|
| 279 |
+
|
| 280 |
+
style_search_bar.change(style_sorter.search_styles,
|
| 281 |
+
inputs=[style_selections, style_search_bar],
|
| 282 |
+
outputs=style_selections,
|
| 283 |
+
queue=False,
|
| 284 |
+
show_progress=False).then(
|
| 285 |
+
lambda: None, _js='()=>{refresh_style_localization();}')
|
| 286 |
+
|
| 287 |
+
gradio_receiver_style_selections.input(style_sorter.sort_styles,
|
| 288 |
+
inputs=style_selections,
|
| 289 |
+
outputs=style_selections,
|
| 290 |
+
queue=False,
|
| 291 |
+
show_progress=False).then(
|
| 292 |
+
lambda: None, _js='()=>{refresh_style_localization();}')
|
| 293 |
+
|
| 294 |
+
with gr.Tab(label='Model'):
|
| 295 |
+
with gr.Group():
|
| 296 |
+
with gr.Row():
|
| 297 |
+
base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.config.model_filenames, value=modules.config.default_base_model_name, show_label=True)
|
| 298 |
+
refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.config.model_filenames, value=modules.config.default_refiner_model_name, show_label=True)
|
| 299 |
+
|
| 300 |
+
refiner_switch = gr.Slider(label='Refiner Switch At', minimum=0.1, maximum=1.0, step=0.0001,
|
| 301 |
+
info='Use 0.4 for SD1.5 realistic models; '
|
| 302 |
+
'or 0.667 for SD1.5 anime models; '
|
| 303 |
+
'or 0.8 for XL-refiners; '
|
| 304 |
+
'or any value for switching two SDXL models.',
|
| 305 |
+
value=modules.config.default_refiner_switch,
|
| 306 |
+
visible=modules.config.default_refiner_model_name != 'None')
|
| 307 |
+
|
| 308 |
+
refiner_model.change(lambda x: gr.update(visible=x != 'None'),
|
| 309 |
+
inputs=refiner_model, outputs=refiner_switch, show_progress=False, queue=False)
|
| 310 |
+
|
| 311 |
+
with gr.Group():
|
| 312 |
+
lora_ctrls = []
|
| 313 |
+
|
| 314 |
+
for i, (n, v) in enumerate(modules.config.default_loras):
|
| 315 |
+
with gr.Row():
|
| 316 |
+
lora_model = gr.Dropdown(label=f'LoRA {i + 1}',
|
| 317 |
+
choices=['None'] + modules.config.lora_filenames, value=n)
|
| 318 |
+
lora_weight = gr.Slider(label='Weight', minimum=-2, maximum=2, step=0.01, value=v,
|
| 319 |
+
elem_classes='lora_weight')
|
| 320 |
+
lora_ctrls += [lora_model, lora_weight]
|
| 321 |
+
|
| 322 |
+
with gr.Row():
|
| 323 |
+
model_refresh = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button')
|
| 324 |
+
with gr.Tab(label='Advanced'):
|
| 325 |
+
guidance_scale = gr.Slider(label='Guidance Scale', minimum=1.0, maximum=30.0, step=0.01,
|
| 326 |
+
value=modules.config.default_cfg_scale,
|
| 327 |
+
info='Higher value means style is cleaner, vivider, and more artistic.')
|
| 328 |
+
sharpness = gr.Slider(label='Image Sharpness', minimum=0.0, maximum=30.0, step=0.001,
|
| 329 |
+
value=modules.config.default_sample_sharpness,
|
| 330 |
+
info='Higher value means image and texture are sharper.')
|
| 331 |
+
gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/117" target="_blank">\U0001F4D4 Document</a>')
|
| 332 |
+
dev_mode = gr.Checkbox(label='Developer Debug Mode', value=False, container=False)
|
| 333 |
+
|
| 334 |
+
with gr.Column(visible=False) as dev_tools:
|
| 335 |
+
with gr.Tab(label='Debug Tools'):
|
| 336 |
+
adm_scaler_positive = gr.Slider(label='Positive ADM Guidance Scaler', minimum=0.1, maximum=3.0,
|
| 337 |
+
step=0.001, value=1.5, info='The scaler multiplied to positive ADM (use 1.0 to disable). ')
|
| 338 |
+
adm_scaler_negative = gr.Slider(label='Negative ADM Guidance Scaler', minimum=0.1, maximum=3.0,
|
| 339 |
+
step=0.001, value=0.8, info='The scaler multiplied to negative ADM (use 1.0 to disable). ')
|
| 340 |
+
adm_scaler_end = gr.Slider(label='ADM Guidance End At Step', minimum=0.0, maximum=1.0,
|
| 341 |
+
step=0.001, value=0.3,
|
| 342 |
+
info='When to end the guidance from positive/negative ADM. ')
|
| 343 |
+
|
| 344 |
+
refiner_swap_method = gr.Dropdown(label='Refiner swap method', value='joint',
|
| 345 |
+
choices=['joint', 'separate', 'vae'])
|
| 346 |
+
|
| 347 |
+
adaptive_cfg = gr.Slider(label='CFG Mimicking from TSNR', minimum=1.0, maximum=30.0, step=0.01,
|
| 348 |
+
value=modules.config.default_cfg_tsnr,
|
| 349 |
+
info='Enabling Fooocus\'s implementation of CFG mimicking for TSNR '
|
| 350 |
+
'(effective when real CFG > mimicked CFG).')
|
| 351 |
+
sampler_name = gr.Dropdown(label='Sampler', choices=flags.sampler_list,
|
| 352 |
+
value=modules.config.default_sampler)
|
| 353 |
+
scheduler_name = gr.Dropdown(label='Scheduler', choices=flags.scheduler_list,
|
| 354 |
+
value=modules.config.default_scheduler)
|
| 355 |
+
|
| 356 |
+
generate_image_grid = gr.Checkbox(label='Generate Image Grid for Each Batch',
|
| 357 |
+
info='(Experimental) This may cause performance problems on some computers and certain internet conditions.',
|
| 358 |
+
value=False)
|
| 359 |
+
|
| 360 |
+
overwrite_step = gr.Slider(label='Forced Overwrite of Sampling Step',
|
| 361 |
+
minimum=-1, maximum=200, step=1,
|
| 362 |
+
value=modules.config.default_overwrite_step,
|
| 363 |
+
info='Set as -1 to disable. For developer debugging.')
|
| 364 |
+
overwrite_switch = gr.Slider(label='Forced Overwrite of Refiner Switch Step',
|
| 365 |
+
minimum=-1, maximum=200, step=1,
|
| 366 |
+
value=modules.config.default_overwrite_switch,
|
| 367 |
+
info='Set as -1 to disable. For developer debugging.')
|
| 368 |
+
overwrite_width = gr.Slider(label='Forced Overwrite of Generating Width',
|
| 369 |
+
minimum=-1, maximum=2048, step=1, value=-1,
|
| 370 |
+
info='Set as -1 to disable. For developer debugging. '
|
| 371 |
+
'Results will be worse for non-standard numbers that SDXL is not trained on.')
|
| 372 |
+
overwrite_height = gr.Slider(label='Forced Overwrite of Generating Height',
|
| 373 |
+
minimum=-1, maximum=2048, step=1, value=-1,
|
| 374 |
+
info='Set as -1 to disable. For developer debugging. '
|
| 375 |
+
'Results will be worse for non-standard numbers that SDXL is not trained on.')
|
| 376 |
+
overwrite_vary_strength = gr.Slider(label='Forced Overwrite of Denoising Strength of "Vary"',
|
| 377 |
+
minimum=-1, maximum=1.0, step=0.001, value=-1,
|
| 378 |
+
info='Set as negative number to disable. For developer debugging.')
|
| 379 |
+
overwrite_upscale_strength = gr.Slider(label='Forced Overwrite of Denoising Strength of "Upscale"',
|
| 380 |
+
minimum=-1, maximum=1.0, step=0.001, value=-1,
|
| 381 |
+
info='Set as negative number to disable. For developer debugging.')
|
| 382 |
+
disable_preview = gr.Checkbox(label='Disable Preview', value=False,
|
| 383 |
+
info='Disable preview during generation.')
|
| 384 |
+
|
| 385 |
+
with gr.Tab(label='Control'):
|
| 386 |
+
debugging_cn_preprocessor = gr.Checkbox(label='Debug Preprocessors', value=False,
|
| 387 |
+
info='See the results from preprocessors.')
|
| 388 |
+
skipping_cn_preprocessor = gr.Checkbox(label='Skip Preprocessors', value=False,
|
| 389 |
+
info='Do not preprocess images. (Inputs are already canny/depth/cropped-face/etc.)')
|
| 390 |
+
|
| 391 |
+
mixing_image_prompt_and_vary_upscale = gr.Checkbox(label='Mixing Image Prompt and Vary/Upscale',
|
| 392 |
+
value=False)
|
| 393 |
+
mixing_image_prompt_and_inpaint = gr.Checkbox(label='Mixing Image Prompt and Inpaint',
|
| 394 |
+
value=False)
|
| 395 |
+
|
| 396 |
+
controlnet_softness = gr.Slider(label='Softness of ControlNet', minimum=0.0, maximum=1.0,
|
| 397 |
+
step=0.001, value=0.25,
|
| 398 |
+
info='Similar to the Control Mode in A1111 (use 0.0 to disable). ')
|
| 399 |
+
|
| 400 |
+
with gr.Tab(label='Canny'):
|
| 401 |
+
canny_low_threshold = gr.Slider(label='Canny Low Threshold', minimum=1, maximum=255,
|
| 402 |
+
step=1, value=64)
|
| 403 |
+
canny_high_threshold = gr.Slider(label='Canny High Threshold', minimum=1, maximum=255,
|
| 404 |
+
step=1, value=128)
|
| 405 |
+
|
| 406 |
+
with gr.Tab(label='Inpaint'):
|
| 407 |
+
debugging_inpaint_preprocessor = gr.Checkbox(label='Debug Inpaint Preprocessing', value=False)
|
| 408 |
+
inpaint_disable_initial_latent = gr.Checkbox(label='Disable initial latent in inpaint', value=False)
|
| 409 |
+
inpaint_engine = gr.Dropdown(label='Inpaint Engine',
|
| 410 |
+
value=modules.config.default_inpaint_engine_version,
|
| 411 |
+
choices=flags.inpaint_engine_versions,
|
| 412 |
+
info='Version of Fooocus inpaint model')
|
| 413 |
+
inpaint_strength = gr.Slider(label='Inpaint Denoising Strength',
|
| 414 |
+
minimum=0.0, maximum=1.0, step=0.001, value=1.0,
|
| 415 |
+
info='Same as the denoising strength in A1111 inpaint. '
|
| 416 |
+
'Only used in inpaint, not used in outpaint. '
|
| 417 |
+
'(Outpaint always use 1.0)')
|
| 418 |
+
inpaint_respective_field = gr.Slider(label='Inpaint Respective Field',
|
| 419 |
+
minimum=0.0, maximum=1.0, step=0.001, value=0.618,
|
| 420 |
+
info='The area to inpaint. '
|
| 421 |
+
'Value 0 is same as "Only Masked" in A1111. '
|
| 422 |
+
'Value 1 is same as "Whole Image" in A1111. '
|
| 423 |
+
'Only used in inpaint, not used in outpaint. '
|
| 424 |
+
'(Outpaint always use 1.0)')
|
| 425 |
+
inpaint_erode_or_dilate = gr.Slider(label='Mask Erode or Dilate',
|
| 426 |
+
minimum=-64, maximum=64, step=1, value=0,
|
| 427 |
+
info='Positive value will make white area in the mask larger, '
|
| 428 |
+
'negative value will make white area smaller.'
|
| 429 |
+
'(default is 0, always process before any mask invert)')
|
| 430 |
+
inpaint_mask_upload_checkbox = gr.Checkbox(label='Enable Mask Upload', value=False)
|
| 431 |
+
invert_mask_checkbox = gr.Checkbox(label='Invert Mask', value=False)
|
| 432 |
+
|
| 433 |
+
inpaint_ctrls = [debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine,
|
| 434 |
+
inpaint_strength, inpaint_respective_field,
|
| 435 |
+
inpaint_mask_upload_checkbox, invert_mask_checkbox, inpaint_erode_or_dilate]
|
| 436 |
+
|
| 437 |
+
inpaint_mask_upload_checkbox.change(lambda x: gr.update(visible=x),
|
| 438 |
+
inputs=inpaint_mask_upload_checkbox,
|
| 439 |
+
outputs=inpaint_mask_image, queue=False, show_progress=False)
|
| 440 |
+
|
| 441 |
+
with gr.Tab(label='FreeU'):
|
| 442 |
+
freeu_enabled = gr.Checkbox(label='Enabled', value=False)
|
| 443 |
+
freeu_b1 = gr.Slider(label='B1', minimum=0, maximum=2, step=0.01, value=1.01)
|
| 444 |
+
freeu_b2 = gr.Slider(label='B2', minimum=0, maximum=2, step=0.01, value=1.02)
|
| 445 |
+
freeu_s1 = gr.Slider(label='S1', minimum=0, maximum=4, step=0.01, value=0.99)
|
| 446 |
+
freeu_s2 = gr.Slider(label='S2', minimum=0, maximum=4, step=0.01, value=0.95)
|
| 447 |
+
freeu_ctrls = [freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2]
|
| 448 |
+
|
| 449 |
+
adps = [disable_preview, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name,
|
| 450 |
+
scheduler_name, generate_image_grid, overwrite_step, overwrite_switch, overwrite_width, overwrite_height,
|
| 451 |
+
overwrite_vary_strength, overwrite_upscale_strength,
|
| 452 |
+
mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint,
|
| 453 |
+
debugging_cn_preprocessor, skipping_cn_preprocessor, controlnet_softness,
|
| 454 |
+
canny_low_threshold, canny_high_threshold, refiner_swap_method]
|
| 455 |
+
adps += freeu_ctrls
|
| 456 |
+
adps += inpaint_ctrls
|
| 457 |
+
|
| 458 |
+
def dev_mode_checked(r):
    """Mirror the Developer Debug Mode checkbox onto the dev-tools column."""
    return gr.update(visible=r)
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
dev_mode.change(dev_mode_checked, inputs=[dev_mode], outputs=[dev_tools],
|
| 463 |
+
queue=False, show_progress=False)
|
| 464 |
+
|
| 465 |
+
def model_refresh_clicked():
    """Rescan the model folders and refresh every model dropdown.

    Returns updates for [base_model, refiner_model] followed by one
    (dropdown, weight-slider) update pair per LoRA row, matching the
    outputs wired in model_refresh.click.
    """
    modules.config.update_all_model_names()
    results = [
        gr.update(choices=modules.config.model_filenames),
        gr.update(choices=['None'] + modules.config.model_filenames),
    ]
    # One pair per LoRA row. The original hard-coded range(5), which
    # silently desyncs from lora_ctrls (built by enumerating
    # modules.config.default_loras) whenever the config defines a
    # different number of LoRA rows.
    for _ in range(len(modules.config.default_loras)):
        results += [gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
    return results
|
| 472 |
+
|
| 473 |
+
model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls,
|
| 474 |
+
queue=False, show_progress=False)
|
| 475 |
+
|
| 476 |
+
performance_selection.change(lambda x: [gr.update(interactive=x != 'Extreme Speed')] * 11 +
|
| 477 |
+
[gr.update(visible=x != 'Extreme Speed')] * 1,
|
| 478 |
+
inputs=performance_selection,
|
| 479 |
+
outputs=[
|
| 480 |
+
guidance_scale, sharpness, adm_scaler_end, adm_scaler_positive,
|
| 481 |
+
adm_scaler_negative, refiner_switch, refiner_model, sampler_name,
|
| 482 |
+
scheduler_name, adaptive_cfg, refiner_swap_method, negative_prompt
|
| 483 |
+
], queue=False, show_progress=False)
|
| 484 |
+
|
| 485 |
+
advanced_checkbox.change(lambda x: gr.update(visible=x), advanced_checkbox, advanced_column,
|
| 486 |
+
queue=False, show_progress=False) \
|
| 487 |
+
.then(fn=lambda: None, _js='refresh_grid_delayed', queue=False, show_progress=False)
|
| 488 |
+
|
| 489 |
+
def inpaint_mode_change(mode):
    """Reconfigure the inpaint panel when the Method dropdown changes.

    Returns updates for, in order:
    inpaint_additional_prompt, outpaint_selections, example_inpaint_prompts,
    inpaint_disable_initial_latent, inpaint_engine,
    inpaint_strength, inpaint_respective_field.
    """
    assert mode in modules.flags.inpaint_options

    samples = modules.config.example_inpaint_prompts

    if mode == modules.flags.inpaint_option_detail:
        # Detail mode: show the additional prompt plus its quick list,
        # no inpaint engine, gentle strength, narrow respective field.
        return [
            gr.update(visible=True),
            gr.update(visible=False, value=[]),
            gr.Dataset.update(visible=True, samples=samples),
            False, 'None', 0.5, 0.0,
        ]

    if mode == modules.flags.inpaint_option_modify:
        # Modify mode: additional prompt visible, default inpaint engine,
        # full denoise, initial latent disabled.
        return [
            gr.update(visible=True),
            gr.update(visible=False, value=[]),
            gr.Dataset.update(visible=False, samples=samples),
            True, modules.config.default_inpaint_engine_version, 1.0, 0.0,
        ]

    # Default mode: hide the additional prompt, show the outpaint
    # direction checkboxes, standard engine settings.
    return [
        gr.update(visible=False, value=''),
        gr.update(visible=True),
        gr.Dataset.update(visible=False, samples=samples),
        False, modules.config.default_inpaint_engine_version, 1.0, 0.618,
    ]
|
| 515 |
+
|
| 516 |
+
inpaint_mode.input(inpaint_mode_change, inputs=inpaint_mode, outputs=[
|
| 517 |
+
inpaint_additional_prompt, outpaint_selections, example_inpaint_prompts,
|
| 518 |
+
inpaint_disable_initial_latent, inpaint_engine,
|
| 519 |
+
inpaint_strength, inpaint_respective_field
|
| 520 |
+
], show_progress=False, queue=False)
|
| 521 |
+
|
| 522 |
+
ctrls = [
|
| 523 |
+
prompt, negative_prompt, style_selections,
|
| 524 |
+
performance_selection, aspect_ratios_selection, image_number, image_seed, sharpness, guidance_scale
|
| 525 |
+
]
|
| 526 |
+
|
| 527 |
+
ctrls += [base_model, refiner_model, refiner_switch] + lora_ctrls
|
| 528 |
+
ctrls += [input_image_checkbox, current_tab]
|
| 529 |
+
ctrls += [uov_method, uov_input_image]
|
| 530 |
+
ctrls += [outpaint_selections, inpaint_input_image, inpaint_additional_prompt, inpaint_mask_image]
|
| 531 |
+
ctrls += ip_ctrls
|
| 532 |
+
|
| 533 |
+
state_is_generating = gr.State(False)
|
| 534 |
+
|
| 535 |
+
def parse_meta(raw_prompt_txt, is_generating):
    """Detect pasted generation-parameter JSON in the prompt box.

    If the prompt text parses as a JSON object, it is normalized via
    json.dumps, the Generate button is hidden and Load Parameters is
    shown. Otherwise the prompt is left untouched; while a generation
    is running the buttons are not changed at all.

    Returns updates for (prompt, generate_button, load_parameter_button).
    """
    loaded_json = None
    # Cheap pre-check before attempting a full JSON parse.
    if all(token in raw_prompt_txt for token in ('{', '}', ':')):
        try:
            candidate = json.loads(raw_prompt_txt)
        except ValueError:
            # Not valid JSON; treat as a plain prompt. The original code
            # used a bare `except:`, which would also swallow
            # KeyboardInterrupt and SystemExit.
            candidate = None
        if isinstance(candidate, dict):
            loaded_json = candidate

    if loaded_json is None:
        if is_generating:
            return gr.update(), gr.update(), gr.update()
        return gr.update(), gr.update(visible=True), gr.update(visible=False)

    return json.dumps(loaded_json), gr.update(visible=False), gr.update(visible=True)
|
| 553 |
+
|
| 554 |
+
prompt.input(parse_meta, inputs=[prompt, state_is_generating], outputs=[prompt, generate_button, load_parameter_button], queue=False, show_progress=False)
|
| 555 |
+
|
| 556 |
+
load_parameter_button.click(modules.meta_parser.load_parameter_button_click, inputs=[prompt, state_is_generating], outputs=[
|
| 557 |
+
advanced_checkbox,
|
| 558 |
+
image_number,
|
| 559 |
+
prompt,
|
| 560 |
+
negative_prompt,
|
| 561 |
+
style_selections,
|
| 562 |
+
performance_selection,
|
| 563 |
+
aspect_ratios_selection,
|
| 564 |
+
overwrite_width,
|
| 565 |
+
overwrite_height,
|
| 566 |
+
sharpness,
|
| 567 |
+
guidance_scale,
|
| 568 |
+
adm_scaler_positive,
|
| 569 |
+
adm_scaler_negative,
|
| 570 |
+
adm_scaler_end,
|
| 571 |
+
base_model,
|
| 572 |
+
refiner_model,
|
| 573 |
+
refiner_switch,
|
| 574 |
+
sampler_name,
|
| 575 |
+
scheduler_name,
|
| 576 |
+
seed_random,
|
| 577 |
+
image_seed,
|
| 578 |
+
generate_button,
|
| 579 |
+
load_parameter_button
|
| 580 |
+
] + lora_ctrls, queue=False, show_progress=False)
|
| 581 |
+
|
| 582 |
+
generate_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True), gr.update(visible=False, interactive=False), [], True),
|
| 583 |
+
outputs=[stop_button, skip_button, generate_button, gallery, state_is_generating]) \
|
| 584 |
+
.then(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed) \
|
| 585 |
+
.then(advanced_parameters.set_all_advanced_parameters, inputs=adps) \
|
| 586 |
+
.then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, progress_gallery, gallery]) \
|
| 587 |
+
.then(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=False, interactive=False), gr.update(visible=False, interactive=False), False),
|
| 588 |
+
outputs=[generate_button, stop_button, skip_button, state_is_generating]) \
|
| 589 |
+
.then(fn=lambda: None, _js='playNotification').then(fn=lambda: None, _js='refresh_grid_delayed')
|
| 590 |
+
|
| 591 |
+
for notification_file in ['notification.ogg', 'notification.mp3']:
|
| 592 |
+
if os.path.exists(notification_file):
|
| 593 |
+
gr.Audio(interactive=False, value=notification_file, elem_id='audio_notification', visible=False)
|
| 594 |
+
break
|
| 595 |
+
|
| 596 |
+
def trigger_describe(mode, img):
    """Run the interrogator matching the selected content type.

    Returns (prompt_text, style_selections). An unrecognized mode is
    echoed back unchanged with the base style list.
    """
    if mode == flags.desc_type_photo:
        from extras.interrogate import default_interrogator as photo_interrogator
        return photo_interrogator(img), ["Fooocus V2", "Fooocus Enhance", "Fooocus Sharp"]
    if mode == flags.desc_type_anime:
        from extras.wd14tagger import default_interrogator as anime_interrogator
        return anime_interrogator(img), ["Fooocus V2", "Fooocus Masterpiece"]
    return mode, ["Fooocus V2"]
|
| 604 |
+
|
| 605 |
+
desc_btn.click(trigger_describe, inputs=[desc_method, desc_input_image],
|
| 606 |
+
outputs=[prompt, style_selections], show_progress=True, queue=True)
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
def dump_default_english_config():
    """Export every registered UI component label to the English localization file."""
    from modules.localization import dump_english_config
    dump_english_config(grh.all_components)
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
# dump_default_english_config()
|
| 615 |
+
|
| 616 |
+
shared.gradio_root.launch(
|
| 617 |
+
inbrowser=args_manager.args.in_browser,
|
| 618 |
+
server_name=args_manager.args.listen,
|
| 619 |
+
server_port=args_manager.args.port,
|
| 620 |
+
share=args_manager.args.share,
|
| 621 |
+
auth=check_auth if args_manager.args.share and auth_enabled else None,
|
| 622 |
+
blocked_paths=[constants.AUTH_FILENAME]
|
| 623 |
+
)
|