Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# Copyright (c)
|
| 2 |
|
| 3 |
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 4 |
# of this software and associated documentation files (the "Software"), to deal
|
|
@@ -55,7 +55,7 @@ from typing import Tuple, List, Literal, Optional, Union
|
|
| 55 |
from tqdm import tqdm
|
| 56 |
from PIL import Image
|
| 57 |
|
| 58 |
-
from util import gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
|
| 59 |
|
| 60 |
|
| 61 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
|
@@ -73,7 +73,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
|
| 73 |
return noise_cfg
|
| 74 |
|
| 75 |
|
| 76 |
-
class
|
| 77 |
def __init__(
|
| 78 |
self,
|
| 79 |
device: torch.device,
|
|
@@ -93,7 +93,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 93 |
has_i2t: bool = True,
|
| 94 |
lora_weight: float = 1.0,
|
| 95 |
) -> None:
|
| 96 |
-
r"""Stabilized
|
| 97 |
|
| 98 |
Accelerated region-based text-to-image synthesis with Latent Consistency
|
| 99 |
Model while preserving mask fidelity and quality.
|
|
@@ -131,7 +131,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 131 |
default_preprocess_mask_cover_alpha (float): Optional preprocessing
|
| 132 |
where each mask covered by other masks is reduced in its alpha
|
| 133 |
value by this specified factor.
|
| 134 |
-
t_index_list (List[int]): The default scheduling for
|
| 135 |
mask_type (Literal['discrete', 'semi-continuous', 'continuous']):
|
| 136 |
defines the mask quantization modes. Details in the codes of
|
| 137 |
`self.process_mask`. Basically, this (subtly) controls the
|
|
@@ -170,10 +170,10 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 170 |
model_key = hf_key
|
| 171 |
lora_ckpt = 'sdxl_lightning_4step_lora.safetensors'
|
| 172 |
|
| 173 |
-
self.pipe =
|
| 174 |
self.pipe.load_lora_weights(hf_hub_download(lightning_repo, lora_ckpt), adapter_name='lightning')
|
| 175 |
self.pipe.set_adapters(["lightning"], adapter_weights=[lora_weight])
|
| 176 |
-
|
| 177 |
else:
|
| 178 |
model_key = 'stabilityai/stable-diffusion-xl-base-1.0'
|
| 179 |
variant = 'fp16'
|
|
@@ -212,7 +212,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 212 |
self.vae_scale_factor = self.pipe.vae_scale_factor
|
| 213 |
|
| 214 |
# Prepare white background for bootstrapping.
|
| 215 |
-
|
| 216 |
|
| 217 |
print(f'[INFO] Model is loaded!')
|
| 218 |
|
|
@@ -691,7 +691,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 691 |
25, 37], the masks are split into binary masks whose values are
|
| 692 |
greater than these levels. This results in a gradual increase of mask
|
| 693 |
region as the timesteps increase. Details are described in our
|
| 694 |
-
paper
|
| 695 |
|
| 696 |
On the Three Modes of `mask_type`:
|
| 697 |
`self.mask_type` is predefined at the initialization stage of this
|
|
@@ -949,6 +949,9 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 949 |
boostrap_mix_steps: Optional[float] = None,
|
| 950 |
bootstrap_leak_sensitivity: Optional[float] = None,
|
| 951 |
preprocess_mask_cover_alpha: Optional[float] = None,
|
|
|
|
|
|
|
|
|
|
| 952 |
) -> Image.Image:
|
| 953 |
r"""Arbitrary-size image generation from multiple pairs of (regional)
|
| 954 |
text prompt-mask pairs.
|
|
@@ -957,7 +960,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 957 |
|
| 958 |
Example:
|
| 959 |
>>> device = torch.device('cuda:0')
|
| 960 |
-
>>> smd =
|
| 961 |
>>> prompts = {... specify prompts}
|
| 962 |
>>> masks = {... specify mask tensors}
|
| 963 |
>>> height, width = masks.shape[-2:]
|
|
@@ -1046,7 +1049,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 1046 |
|
| 1047 |
# prompts is None: return background.
|
| 1048 |
# masks is None but prompts is not None: return prompts
|
| 1049 |
-
# masks is not None and prompts is not None: Do
|
| 1050 |
|
| 1051 |
if prompts is None or (isinstance(prompts, (list, tuple, str)) and len(prompts) == 0):
|
| 1052 |
if background is None and background_prompt is not None:
|
|
@@ -1157,27 +1160,22 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 1157 |
|
| 1158 |
# SDXL pipeline settings.
|
| 1159 |
batch_size = 1
|
| 1160 |
-
output_type = 'pil'
|
| 1161 |
-
|
| 1162 |
-
guidance_rescale = 0.7
|
| 1163 |
-
|
| 1164 |
-
prompt_2 = None
|
| 1165 |
-
device = self.device
|
| 1166 |
num_images_per_prompt = 1
|
| 1167 |
-
negative_prompt_2 = None
|
| 1168 |
|
| 1169 |
original_size = (height, width)
|
| 1170 |
target_size = (height, width)
|
| 1171 |
crops_coords_top_left = (0, 0)
|
| 1172 |
-
negative_crops_coords_top_left = (0, 0)
|
| 1173 |
negative_original_size = None
|
| 1174 |
negative_target_size = None
|
| 1175 |
-
|
| 1176 |
-
negative_pooled_prompt_embeds = None
|
| 1177 |
-
text_encoder_lora_scale = None
|
| 1178 |
|
|
|
|
|
|
|
| 1179 |
prompt_embeds = None
|
| 1180 |
negative_prompt_embeds = None
|
|
|
|
|
|
|
|
|
|
| 1181 |
|
| 1182 |
(
|
| 1183 |
prompt_embeds,
|
|
@@ -1187,7 +1185,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 1187 |
) = self.encode_prompt(
|
| 1188 |
prompt=prompts,
|
| 1189 |
prompt_2=prompt_2,
|
| 1190 |
-
device=device,
|
| 1191 |
num_images_per_prompt=num_images_per_prompt,
|
| 1192 |
do_classifier_free_guidance=do_classifier_free_guidance,
|
| 1193 |
negative_prompt=negative_prompts,
|
|
@@ -1199,30 +1197,6 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 1199 |
lora_scale=text_encoder_lora_scale,
|
| 1200 |
)
|
| 1201 |
|
| 1202 |
-
add_text_embeds = pooled_prompt_embeds
|
| 1203 |
-
if self.text_encoder_2 is None:
|
| 1204 |
-
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
| 1205 |
-
else:
|
| 1206 |
-
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
| 1207 |
-
|
| 1208 |
-
add_time_ids = self._get_add_time_ids(
|
| 1209 |
-
original_size,
|
| 1210 |
-
crops_coords_top_left,
|
| 1211 |
-
target_size,
|
| 1212 |
-
dtype=prompt_embeds.dtype,
|
| 1213 |
-
text_encoder_projection_dim=text_encoder_projection_dim,
|
| 1214 |
-
)
|
| 1215 |
-
if negative_original_size is not None and negative_target_size is not None:
|
| 1216 |
-
negative_add_time_ids = self._get_add_time_ids(
|
| 1217 |
-
negative_original_size,
|
| 1218 |
-
negative_crops_coords_top_left,
|
| 1219 |
-
negative_target_size,
|
| 1220 |
-
dtype=prompt_embeds.dtype,
|
| 1221 |
-
text_encoder_projection_dim=text_encoder_projection_dim,
|
| 1222 |
-
)
|
| 1223 |
-
else:
|
| 1224 |
-
negative_add_time_ids = add_time_ids
|
| 1225 |
-
|
| 1226 |
if has_background:
|
| 1227 |
# First channel is background prompt text embeds. Background prompt itself is not used for generation.
|
| 1228 |
s = prompt_strengths
|
|
@@ -1248,10 +1222,26 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 1248 |
assert fu.shape[0] == 1 and fe.shape == num_prompts
|
| 1249 |
fu = fu.repeat(num_prompts, 1, 1)
|
| 1250 |
negative_prompt_embeds = torch.lerp(bu, fu, s) # (n, 77, 1024)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1251 |
elif negative_prompt_embeds is not None and num_prompts > num_nprompts:
|
| 1252 |
# Number of negative prompts = 1; number of prompts > 1.
|
| 1253 |
assert negative_prompt_embeds.shape[0] == 1 and prompt_embeds.shape[0] == num_prompts
|
| 1254 |
negative_prompt_embeds = negative_prompt_embeds.repeat(num_prompts, 1, 1)
|
|
|
|
|
|
|
|
|
|
| 1255 |
# assert negative_prompt_embeds.shape[0] == prompt_embeds.shape[0] == num_prompts
|
| 1256 |
if num_masks > num_prompts:
|
| 1257 |
assert masks.shape[0] == num_masks and num_prompts == 1
|
|
@@ -1259,6 +1249,34 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 1259 |
if negative_prompt_embeds is not None:
|
| 1260 |
negative_prompt_embeds = negative_prompt_embeds.repeat(num_masks, 1, 1)
|
| 1261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1262 |
# SDXL pipeline settings.
|
| 1263 |
if do_classifier_free_guidance:
|
| 1264 |
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
|
@@ -1266,19 +1284,25 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 1266 |
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
| 1267 |
del negative_prompt_embeds, negative_pooled_prompt_embeds, negative_add_time_ids
|
| 1268 |
|
| 1269 |
-
prompt_embeds = prompt_embeds.to(device)
|
| 1270 |
-
add_text_embeds = add_text_embeds.to(device)
|
| 1271 |
-
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
| 1272 |
|
| 1273 |
|
| 1274 |
### Run
|
| 1275 |
|
| 1276 |
# Latent initialization.
|
|
|
|
| 1277 |
if self.timesteps[0] < 999 and has_background:
|
| 1278 |
-
|
| 1279 |
else:
|
| 1280 |
-
|
| 1281 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1282 |
|
| 1283 |
# Tiling (if needed).
|
| 1284 |
if height > tile_size or width > tile_size:
|
|
@@ -1287,9 +1311,9 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 1287 |
tile_masks = tile_masks.to(self.device)
|
| 1288 |
else:
|
| 1289 |
views = [(0, h, 0, w)]
|
| 1290 |
-
tile_masks =
|
| 1291 |
-
value = torch.zeros_like(
|
| 1292 |
-
count_all = torch.zeros_like(
|
| 1293 |
|
| 1294 |
with torch.autocast('cuda'):
|
| 1295 |
for i, t in enumerate(tqdm(self.timesteps)):
|
|
@@ -1300,7 +1324,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 1300 |
count_all.zero_()
|
| 1301 |
for j, (h_start, h_end, w_start, w_end) in enumerate(views):
|
| 1302 |
fg_mask_ = fg_mask[..., h_start:h_end, w_start:w_end]
|
| 1303 |
-
|
| 1304 |
|
| 1305 |
# Additional arguments for the SDXL pipeline.
|
| 1306 |
add_time_ids_input = add_time_ids.clone()
|
|
@@ -1312,16 +1336,16 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 1312 |
if i < bootstrap_steps:
|
| 1313 |
mix_ratio = min(1, max(0, boostrap_mix_steps - i))
|
| 1314 |
# Treat the first foreground latent as the background latent if one does not exist.
|
| 1315 |
-
|
| 1316 |
white_ = white[..., h_start:h_end, w_start:w_end]
|
| 1317 |
-
white_ = self.scheduler_add_noise(white_,
|
| 1318 |
-
|
| 1319 |
-
|
| 1320 |
|
| 1321 |
# Centering.
|
| 1322 |
-
|
| 1323 |
|
| 1324 |
-
latent_model_input = torch.cat([
|
| 1325 |
latent_model_input = self.scheduler_scale_model_input(latent_model_input, i)
|
| 1326 |
|
| 1327 |
# Perform one step of the reverse diffusion.
|
|
@@ -1341,33 +1365,32 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 1341 |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 1342 |
|
| 1343 |
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
| 1344 |
-
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
| 1345 |
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_cond, guidance_rescale=guidance_rescale)
|
| 1346 |
|
| 1347 |
-
|
| 1348 |
|
| 1349 |
if i < bootstrap_steps:
|
| 1350 |
# Uncentering.
|
| 1351 |
-
|
| 1352 |
|
| 1353 |
# Remove leakage (optional).
|
| 1354 |
-
leak = (
|
| 1355 |
leak_sigmoid = torch.sigmoid(leak / bootstrap_leak_sensitivity) * 2 - 1
|
| 1356 |
fg_mask_ = fg_mask_ * leak_sigmoid
|
| 1357 |
|
| 1358 |
# Mix the latents.
|
| 1359 |
fg_mask_ = fg_mask_ * tile_masks[:, j:j + 1, h_start:h_end, w_start:w_end]
|
| 1360 |
-
value[..., h_start:h_end, w_start:w_end] += (fg_mask_ *
|
| 1361 |
count_all[..., h_start:h_end, w_start:w_end] += fg_mask_.sum(dim=0, keepdim=True)
|
| 1362 |
|
| 1363 |
-
|
| 1364 |
bg_mask = (1 - count_all).clip_(0, 1) # (T, 1, h, w)
|
| 1365 |
if has_background:
|
| 1366 |
-
|
| 1367 |
|
| 1368 |
# Noise is added after mixing.
|
| 1369 |
if i < len(self.timesteps) - 1:
|
| 1370 |
-
|
| 1371 |
|
| 1372 |
if not output_type == "latent":
|
| 1373 |
# make sure the VAE is in float32 mode, as it overflows in float16
|
|
@@ -1375,7 +1398,7 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 1375 |
|
| 1376 |
if needs_upcasting:
|
| 1377 |
self.upcast_vae()
|
| 1378 |
-
|
| 1379 |
|
| 1380 |
# unscale/denormalize the latents
|
| 1381 |
# denormalize with the mean and std if available and not None
|
|
@@ -1383,22 +1406,22 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 1383 |
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
| 1384 |
if has_latents_mean and has_latents_std:
|
| 1385 |
latents_mean = (
|
| 1386 |
-
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(
|
| 1387 |
)
|
| 1388 |
latents_std = (
|
| 1389 |
-
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(
|
| 1390 |
)
|
| 1391 |
-
|
| 1392 |
else:
|
| 1393 |
-
|
| 1394 |
|
| 1395 |
-
image = self.vae.decode(
|
| 1396 |
|
| 1397 |
# cast back to fp16 if needed
|
| 1398 |
if needs_upcasting:
|
| 1399 |
self.vae.to(dtype=torch.float16)
|
| 1400 |
else:
|
| 1401 |
-
image =
|
| 1402 |
|
| 1403 |
# Return PIL Image.
|
| 1404 |
image = image[0].clip_(-1, 1) * 0.5 + 0.5
|
|
@@ -1407,4 +1430,4 @@ class StableMultiDiffusionSDXLPipeline(nn.Module):
|
|
| 1407 |
image = blend(image, background[0], fg_mask)
|
| 1408 |
else:
|
| 1409 |
image = T.ToPILImage()(image)
|
| 1410 |
-
return image
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Jaerin Lee
|
| 2 |
|
| 3 |
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 4 |
# of this software and associated documentation files (the "Software"), to deal
|
|
|
|
| 55 |
from tqdm import tqdm
|
| 56 |
from PIL import Image
|
| 57 |
|
| 58 |
+
from util import load_model, gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
|
| 59 |
|
| 60 |
|
| 61 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
|
|
|
| 73 |
return noise_cfg
|
| 74 |
|
| 75 |
|
| 76 |
+
class SemanticDrawSDXLPipeline(nn.Module):
|
| 77 |
def __init__(
|
| 78 |
self,
|
| 79 |
device: torch.device,
|
|
|
|
| 93 |
has_i2t: bool = True,
|
| 94 |
lora_weight: float = 1.0,
|
| 95 |
) -> None:
|
| 96 |
+
r"""Stabilized regionally assigned texts-to-image generation for fast sampling.
|
| 97 |
|
| 98 |
Accelerated region-based text-to-image synthesis with Latent Consistency
|
| 99 |
Model while preserving mask fidelity and quality.
|
|
|
|
| 131 |
default_preprocess_mask_cover_alpha (float): Optional preprocessing
|
| 132 |
where each mask covered by other masks is reduced in its alpha
|
| 133 |
value by this specified factor.
|
| 134 |
+
t_index_list (List[int]): The default scheduling for scheduler.
|
| 135 |
mask_type (Literal['discrete', 'semi-continuous', 'continuous']):
|
| 136 |
defines the mask quantization modes. Details in the codes of
|
| 137 |
`self.process_mask`. Basically, this (subtly) controls the
|
|
|
|
| 170 |
model_key = hf_key
|
| 171 |
lora_ckpt = 'sdxl_lightning_4step_lora.safetensors'
|
| 172 |
|
| 173 |
+
self.pipe = load_model(model_key, 'xl', self.device, self.dtype)
|
| 174 |
self.pipe.load_lora_weights(hf_hub_download(lightning_repo, lora_ckpt), adapter_name='lightning')
|
| 175 |
self.pipe.set_adapters(["lightning"], adapter_weights=[lora_weight])
|
| 176 |
+
self.pipe.fuse_lora()
|
| 177 |
else:
|
| 178 |
model_key = 'stabilityai/stable-diffusion-xl-base-1.0'
|
| 179 |
variant = 'fp16'
|
|
|
|
| 212 |
self.vae_scale_factor = self.pipe.vae_scale_factor
|
| 213 |
|
| 214 |
# Prepare white background for bootstrapping.
|
| 215 |
+
self.get_white_background(1024, 1024)
|
| 216 |
|
| 217 |
print(f'[INFO] Model is loaded!')
|
| 218 |
|
|
|
|
| 691 |
25, 37], the masks are split into binary masks whose values are
|
| 692 |
greater than these levels. This results in a gradual increase of mask
|
| 693 |
region as the timesteps increase. Details are described in our
|
| 694 |
+
paper.
|
| 695 |
|
| 696 |
On the Three Modes of `mask_type`:
|
| 697 |
`self.mask_type` is predefined at the initialization stage of this
|
|
|
|
| 949 |
boostrap_mix_steps: Optional[float] = None,
|
| 950 |
bootstrap_leak_sensitivity: Optional[float] = None,
|
| 951 |
preprocess_mask_cover_alpha: Optional[float] = None,
|
| 952 |
+
# SDXL Pipeline setting.
|
| 953 |
+
guidance_rescale: float = 0.7,
|
| 954 |
+
output_type = 'pil',
|
| 955 |
) -> Image.Image:
|
| 956 |
r"""Arbitrary-size image generation from multiple pairs of (regional)
|
| 957 |
text prompt-mask pairs.
|
|
|
|
| 960 |
|
| 961 |
Example:
|
| 962 |
>>> device = torch.device('cuda:0')
|
| 963 |
+
>>> smd = SemanticDrawSDXLPipeline(device)
|
| 964 |
>>> prompts = {... specify prompts}
|
| 965 |
>>> masks = {... specify mask tensors}
|
| 966 |
>>> height, width = masks.shape[-2:]
|
|
|
|
| 1049 |
|
| 1050 |
# prompts is None: return background.
|
| 1051 |
# masks is None but prompts is not None: return prompts
|
| 1052 |
+
# masks is not None and prompts is not None: Do SemanticDraw.
|
| 1053 |
|
| 1054 |
if prompts is None or (isinstance(prompts, (list, tuple, str)) and len(prompts) == 0):
|
| 1055 |
if background is None and background_prompt is not None:
|
|
|
|
| 1160 |
|
| 1161 |
# SDXL pipeline settings.
|
| 1162 |
batch_size = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1163 |
num_images_per_prompt = 1
|
|
|
|
| 1164 |
|
| 1165 |
original_size = (height, width)
|
| 1166 |
target_size = (height, width)
|
| 1167 |
crops_coords_top_left = (0, 0)
|
|
|
|
| 1168 |
negative_original_size = None
|
| 1169 |
negative_target_size = None
|
| 1170 |
+
negative_crops_coords_top_left = (0, 0)
|
|
|
|
|
|
|
| 1171 |
|
| 1172 |
+
prompt_2 = None
|
| 1173 |
+
negative_prompt_2 = None
|
| 1174 |
prompt_embeds = None
|
| 1175 |
negative_prompt_embeds = None
|
| 1176 |
+
pooled_prompt_embeds = None
|
| 1177 |
+
negative_pooled_prompt_embeds = None
|
| 1178 |
+
text_encoder_lora_scale = None
|
| 1179 |
|
| 1180 |
(
|
| 1181 |
prompt_embeds,
|
|
|
|
| 1185 |
) = self.encode_prompt(
|
| 1186 |
prompt=prompts,
|
| 1187 |
prompt_2=prompt_2,
|
| 1188 |
+
device=self.device,
|
| 1189 |
num_images_per_prompt=num_images_per_prompt,
|
| 1190 |
do_classifier_free_guidance=do_classifier_free_guidance,
|
| 1191 |
negative_prompt=negative_prompts,
|
|
|
|
| 1197 |
lora_scale=text_encoder_lora_scale,
|
| 1198 |
)
|
| 1199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1200 |
if has_background:
|
| 1201 |
# First channel is background prompt text embeds. Background prompt itself is not used for generation.
|
| 1202 |
s = prompt_strengths
|
|
|
|
| 1222 |
assert fu.shape[0] == 1 and fe.shape == num_prompts
|
| 1223 |
fu = fu.repeat(num_prompts, 1, 1)
|
| 1224 |
negative_prompt_embeds = torch.lerp(bu, fu, s) # (n, 77, 1024)
|
| 1225 |
+
|
| 1226 |
+
be = pooled_prompt_embeds[:1]
|
| 1227 |
+
fe = pooled_prompt_embeds[1:]
|
| 1228 |
+
pooled_prompt_embeds = torch.lerp(be, fe, s[..., 0]) # (p, 1280)
|
| 1229 |
+
|
| 1230 |
+
if negative_pooled_prompt_embeds is not None:
|
| 1231 |
+
bu = negative_pooled_prompt_embeds[:1]
|
| 1232 |
+
fu = negative_pooled_prompt_embeds[1:]
|
| 1233 |
+
if num_prompts > num_nprompts:
|
| 1234 |
+
# Number of negative prompts = 1; number of prompts > 1.
|
| 1235 |
+
assert fu.shape[0] == 1 and fe.shape == num_prompts
|
| 1236 |
+
fu = fu.repeat(num_prompts, 1)
|
| 1237 |
+
negative_pooled_prompt_embeds = torch.lerp(bu, fu, s[..., 0]) # (n, 1280)
|
| 1238 |
elif negative_prompt_embeds is not None and num_prompts > num_nprompts:
|
| 1239 |
# Number of negative prompts = 1; number of prompts > 1.
|
| 1240 |
assert negative_prompt_embeds.shape[0] == 1 and prompt_embeds.shape[0] == num_prompts
|
| 1241 |
negative_prompt_embeds = negative_prompt_embeds.repeat(num_prompts, 1, 1)
|
| 1242 |
+
|
| 1243 |
+
assert negative_pooled_prompt_embeds.shape[0] == 1 and pooled_prompt_embeds.shape[0] == num_prompts
|
| 1244 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(num_prompts, 1)
|
| 1245 |
# assert negative_prompt_embeds.shape[0] == prompt_embeds.shape[0] == num_prompts
|
| 1246 |
if num_masks > num_prompts:
|
| 1247 |
assert masks.shape[0] == num_masks and num_prompts == 1
|
|
|
|
| 1249 |
if negative_prompt_embeds is not None:
|
| 1250 |
negative_prompt_embeds = negative_prompt_embeds.repeat(num_masks, 1, 1)
|
| 1251 |
|
| 1252 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(num_masks, 1)
|
| 1253 |
+
if negative_pooled_prompt_embeds is not None:
|
| 1254 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(num_masks, 1)
|
| 1255 |
+
|
| 1256 |
+
add_text_embeds = pooled_prompt_embeds
|
| 1257 |
+
if self.text_encoder_2 is None:
|
| 1258 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
| 1259 |
+
else:
|
| 1260 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
| 1261 |
+
|
| 1262 |
+
add_time_ids = self._get_add_time_ids(
|
| 1263 |
+
original_size,
|
| 1264 |
+
crops_coords_top_left,
|
| 1265 |
+
target_size,
|
| 1266 |
+
dtype=prompt_embeds.dtype,
|
| 1267 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
| 1268 |
+
)
|
| 1269 |
+
if negative_original_size is not None and negative_target_size is not None:
|
| 1270 |
+
negative_add_time_ids = self._get_add_time_ids(
|
| 1271 |
+
negative_original_size,
|
| 1272 |
+
negative_crops_coords_top_left,
|
| 1273 |
+
negative_target_size,
|
| 1274 |
+
dtype=prompt_embeds.dtype,
|
| 1275 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
| 1276 |
+
)
|
| 1277 |
+
else:
|
| 1278 |
+
negative_add_time_ids = add_time_ids
|
| 1279 |
+
|
| 1280 |
# SDXL pipeline settings.
|
| 1281 |
if do_classifier_free_guidance:
|
| 1282 |
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
|
|
|
| 1284 |
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
| 1285 |
del negative_prompt_embeds, negative_pooled_prompt_embeds, negative_add_time_ids
|
| 1286 |
|
| 1287 |
+
prompt_embeds = prompt_embeds.to(self.device)
|
| 1288 |
+
add_text_embeds = add_text_embeds.to(self.device)
|
| 1289 |
+
add_time_ids = add_time_ids.to(self.device).repeat(batch_size * num_images_per_prompt, 1)
|
| 1290 |
|
| 1291 |
|
| 1292 |
### Run
|
| 1293 |
|
| 1294 |
# Latent initialization.
|
| 1295 |
+
noise = torch.randn((1, self.unet.config.in_channels, h, w), dtype=self.dtype, device=self.device)
|
| 1296 |
if self.timesteps[0] < 999 and has_background:
|
| 1297 |
+
latent = self.scheduler_add_noise(bg_latent, noise, 0, initial=True)
|
| 1298 |
else:
|
| 1299 |
+
noise = torch.randn((1, self.unet.config.in_channels, h, w), dtype=self.dtype, device=self.device)
|
| 1300 |
+
latent = noise * self.scheduler.init_noise_sigma
|
| 1301 |
+
|
| 1302 |
+
if has_background:
|
| 1303 |
+
noise_bg_latents = [
|
| 1304 |
+
self.scheduler_add_noise(bg_latent, noise, i, initial=True) for i in range(len(self.timesteps))
|
| 1305 |
+
] + [bg_latent]
|
| 1306 |
|
| 1307 |
# Tiling (if needed).
|
| 1308 |
if height > tile_size or width > tile_size:
|
|
|
|
| 1311 |
tile_masks = tile_masks.to(self.device)
|
| 1312 |
else:
|
| 1313 |
views = [(0, h, 0, w)]
|
| 1314 |
+
tile_masks = latent.new_ones((1, 1, h, w))
|
| 1315 |
+
value = torch.zeros_like(latent)
|
| 1316 |
+
count_all = torch.zeros_like(latent)
|
| 1317 |
|
| 1318 |
with torch.autocast('cuda'):
|
| 1319 |
for i, t in enumerate(tqdm(self.timesteps)):
|
|
|
|
| 1324 |
count_all.zero_()
|
| 1325 |
for j, (h_start, h_end, w_start, w_end) in enumerate(views):
|
| 1326 |
fg_mask_ = fg_mask[..., h_start:h_end, w_start:w_end]
|
| 1327 |
+
latent_ = latent[..., h_start:h_end, w_start:w_end].repeat(num_masks, 1, 1, 1)
|
| 1328 |
|
| 1329 |
# Additional arguments for the SDXL pipeline.
|
| 1330 |
add_time_ids_input = add_time_ids.clone()
|
|
|
|
| 1336 |
if i < bootstrap_steps:
|
| 1337 |
mix_ratio = min(1, max(0, boostrap_mix_steps - i))
|
| 1338 |
# Treat the first foreground latent as the background latent if one does not exist.
|
| 1339 |
+
bg_latent_ = noise_bg_latents[i][..., h_start:h_end, w_start:w_end] if has_background else latent_[:1]
|
| 1340 |
white_ = white[..., h_start:h_end, w_start:w_end]
|
| 1341 |
+
white_ = self.scheduler_add_noise(white_, noise[..., h_start:h_end, w_start:w_end], i, initial=True)
|
| 1342 |
+
bg_latent_ = mix_ratio * white_ + (1.0 - mix_ratio) * bg_latent_
|
| 1343 |
+
latent_ = (1.0 - fg_mask_) * bg_latent_ + fg_mask_ * latent_
|
| 1344 |
|
| 1345 |
# Centering.
|
| 1346 |
+
latent_ = shift_to_mask_bbox_center(latent_, fg_mask_, reverse=True)
|
| 1347 |
|
| 1348 |
+
latent_model_input = torch.cat([latent_] * 2) if do_classifier_free_guidance else latent_
|
| 1349 |
latent_model_input = self.scheduler_scale_model_input(latent_model_input, i)
|
| 1350 |
|
| 1351 |
# Perform one step of the reverse diffusion.
|
|
|
|
| 1365 |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 1366 |
|
| 1367 |
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
|
|
|
| 1368 |
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_cond, guidance_rescale=guidance_rescale)
|
| 1369 |
|
| 1370 |
+
latent_ = self.scheduler_step(noise_pred, i, latent_)
|
| 1371 |
|
| 1372 |
if i < bootstrap_steps:
|
| 1373 |
# Uncentering.
|
| 1374 |
+
latent_ = shift_to_mask_bbox_center(latent_, fg_mask_)
|
| 1375 |
|
| 1376 |
# Remove leakage (optional).
|
| 1377 |
+
leak = (latent_ - bg_latent_).pow(2).mean(dim=1, keepdim=True)
|
| 1378 |
leak_sigmoid = torch.sigmoid(leak / bootstrap_leak_sensitivity) * 2 - 1
|
| 1379 |
fg_mask_ = fg_mask_ * leak_sigmoid
|
| 1380 |
|
| 1381 |
# Mix the latents.
|
| 1382 |
fg_mask_ = fg_mask_ * tile_masks[:, j:j + 1, h_start:h_end, w_start:w_end]
|
| 1383 |
+
value[..., h_start:h_end, w_start:w_end] += (fg_mask_ * latent_).sum(dim=0, keepdim=True)
|
| 1384 |
count_all[..., h_start:h_end, w_start:w_end] += fg_mask_.sum(dim=0, keepdim=True)
|
| 1385 |
|
| 1386 |
+
latent = torch.where(count_all > 0, value / count_all, value)
|
| 1387 |
bg_mask = (1 - count_all).clip_(0, 1) # (T, 1, h, w)
|
| 1388 |
if has_background:
|
| 1389 |
+
latent = (1 - bg_mask) * latent + bg_mask * noise_bg_latents[i + 1] # bg_latent
|
| 1390 |
|
| 1391 |
# Noise is added after mixing.
|
| 1392 |
if i < len(self.timesteps) - 1:
|
| 1393 |
+
latent = self.scheduler_add_noise(latent, None, i + 1)
|
| 1394 |
|
| 1395 |
if not output_type == "latent":
|
| 1396 |
# make sure the VAE is in float32 mode, as it overflows in float16
|
|
|
|
| 1398 |
|
| 1399 |
if needs_upcasting:
|
| 1400 |
self.upcast_vae()
|
| 1401 |
+
latent = latent.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
| 1402 |
|
| 1403 |
# unscale/denormalize the latents
|
| 1404 |
# denormalize with the mean and std if available and not None
|
|
|
|
| 1406 |
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
| 1407 |
if has_latents_mean and has_latents_std:
|
| 1408 |
latents_mean = (
|
| 1409 |
+
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latent.device, latent.dtype)
|
| 1410 |
)
|
| 1411 |
latents_std = (
|
| 1412 |
+
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latent.device, latent.dtype)
|
| 1413 |
)
|
| 1414 |
+
latent = latent * latents_std / self.vae.config.scaling_factor + latents_mean
|
| 1415 |
else:
|
| 1416 |
+
latent = latent / self.vae.config.scaling_factor
|
| 1417 |
|
| 1418 |
+
image = self.vae.decode(latent, return_dict=False)[0]
|
| 1419 |
|
| 1420 |
# cast back to fp16 if needed
|
| 1421 |
if needs_upcasting:
|
| 1422 |
self.vae.to(dtype=torch.float16)
|
| 1423 |
else:
|
| 1424 |
+
image = latent
|
| 1425 |
|
| 1426 |
# Return PIL Image.
|
| 1427 |
image = image[0].clip_(-1, 1) * 0.5 + 0.5
|
|
|
|
| 1430 |
image = blend(image, background[0], fg_mask)
|
| 1431 |
else:
|
| 1432 |
image = T.ToPILImage()(image)
|
| 1433 |
+
return image
|