File size: 29,570 Bytes
31112ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 |
import math
from typing import Callable, Literal
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from PIL import Image
from torch import Tensor
from .model import Flux
from .modules.autoencoder import AutoEncoder
from .modules.conditioner import HFEmbedder
from .modules.image_embedders import CannyImageEncoder, DepthImageEncoder, ReduxImageEncoder
from .util import PREFERED_KONTEXT_RESOLUTIONS
import torchvision.transforms.functional as TVF
def get_noise(
num_samples: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
*,
channels: int = 16,
patch_size: int = 2,
):
generator = torch.Generator(device=device).manual_seed(seed)
if channels == 16 and patch_size == 2:
return torch.randn(
num_samples,
channels,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
dtype=dtype,
device=device,
generator=generator,
)
return torch.randn(
num_samples,
channels,
height,
width,
dtype=dtype,
device=device,
generator=generator,
)
def prepare_prompt(t5: HFEmbedder, clip: HFEmbedder, bs: int, prompt: str | list[str], neg: bool = False, device: str = "cuda") -> dict[str, Tensor]:
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
if isinstance(prompt, str):
prompt = [prompt]
txt = t5(prompt)
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = torch.zeros(bs, txt.shape[1], 3)
vec = clip(prompt)
if vec.shape[0] == 1 and bs > 1:
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"neg_txt" if neg else "txt": txt.to(device),
"neg_txt_ids" if neg else "txt_ids": txt_ids.to(device),
"neg_vec" if neg else "vec": vec.to(device),
}
def prepare_img(img: Tensor, patch_size: int = 2) -> dict[str, Tensor]:
bs, c, h, w = img.shape
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // patch_size, w // patch_size, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // patch_size)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // patch_size)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
}
def prepare_redux(
t5: HFEmbedder,
clip: HFEmbedder,
img: Tensor,
prompt: str | list[str],
encoder: ReduxImageEncoder,
img_cond_path: str,
) -> dict[str, Tensor]:
bs, _, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img_cond = Image.open(img_cond_path).convert("RGB")
with torch.no_grad():
img_cond = encoder(img_cond)
img_cond = img_cond.to(torch.bfloat16)
if img_cond.shape[0] == 1 and bs > 1:
img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if isinstance(prompt, str):
prompt = [prompt]
txt = t5(prompt)
txt = torch.cat((txt, img_cond.to(txt)), dim=-2)
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = torch.zeros(bs, txt.shape[1], 3)
vec = clip(prompt)
if vec.shape[0] == 1 and bs > 1:
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"txt": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
"vec": vec.to(img.device),
}
def resizeinput(img):
multiple_of = 16
image_height, image_width = img.height, img.width
aspect_ratio = image_width / image_height
_, image_width, image_height = min(
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS
)
image_width = image_width // multiple_of * multiple_of
image_height = image_height // multiple_of * multiple_of
if (image_width, image_height) != img.size:
img = img.resize((image_width, image_height), Image.LANCZOS)
return img
def build_mask(target_width, target_height, img_mask, device):
from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image
# image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
image_mask_latents = convert_image_to_tensor(img_mask.resize((target_width // 16, target_height // 16), resample=Image.Resampling.LANCZOS))
image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0)
# convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device)
return {
"img_msk_latents": image_mask_latents,
"img_msk_rebuilt": image_mask_rebuilt,
}
def prepare_kontext(
ae: AutoEncoder | None,
img_cond_list: list,
seed: int,
device: torch.device,
target_width: int | None = None,
target_height: int | None = None,
bs: int = 1,
img_mask = None,
*,
patch_size: int = 2,
noise_channels: int = 16,
) -> tuple[dict[str, Tensor], int, int]:
# load and encode the conditioning image
res_match_output = img_mask is not None
img_cond_seq = None
img_cond_seq_ids = None
if img_cond_list == None: img_cond_list = []
height_offset = 0
width_offset = 0
for cond_no, img_cond in enumerate(img_cond_list):
if res_match_output:
if img_cond.size != (target_width, target_height):
img_cond = img_cond.resize((target_width, target_height), Image.Resampling.LANCZOS)
else:
img_cond = resizeinput(img_cond)
width, height = img_cond.size
width, height = width // 8, height // 8
img_cond = np.array(img_cond)
img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
img_cond = rearrange(img_cond, "h w c -> 1 c h w")
if ae is None:
raise ValueError("Image conditioning is not supported for this model.")
with torch.no_grad():
img_cond_latents = ae.encode(img_cond.to(device))
img_cond_latents = img_cond_latents.to(torch.bfloat16)
img_cond_latents = rearrange(img_cond_latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img_cond.shape[0] == 1 and bs > 1:
img_cond_latents = repeat(img_cond_latents, "1 ... -> bs ...", bs=bs)
img_cond = None
# image ids are the same as base image with the first dimension set to 1
# instead of 0
img_cond_ids = torch.zeros(height // 2, width // 2, 3)
img_cond_ids[..., 0] = 1
img_cond_ids[..., 1] = img_cond_ids[..., 1] + torch.arange(height // 2)[:, None] + height_offset
img_cond_ids[..., 2] = img_cond_ids[..., 2] + torch.arange(width // 2)[None, :] + width_offset
img_cond_ids = repeat(img_cond_ids, "h w c -> b (h w) c", b=bs)
height_offset += height // 2
width_offset += width // 2
if target_width is None:
target_width = 8 * width
if target_height is None:
target_height = 8 * height
img_cond_ids = img_cond_ids.to(device)
if cond_no == 0:
img_cond_seq, img_cond_seq_ids = img_cond_latents, img_cond_ids
else:
img_cond_seq, img_cond_seq_ids = torch.cat([img_cond_seq, img_cond_latents], dim=1), torch.cat([img_cond_seq_ids, img_cond_ids], dim=1)
return_dict = {
"img_cond_seq": img_cond_seq,
"img_cond_seq_ids": img_cond_seq_ids,
}
if img_mask is not None:
return_dict.update(build_mask(target_width, target_height, img_mask, device))
img = get_noise(
bs,
target_height,
target_width,
device=device,
dtype=torch.bfloat16,
seed=seed,
channels=noise_channels,
patch_size=patch_size,
)
return_dict.update(prepare_img(img, patch_size=patch_size))
return return_dict, target_height, target_width
def time_shift(mu: float, sigma: float, t: Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def generalized_time_snr_shift(t: Tensor, mu: float, sigma: float) -> Tensor:
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_schedule_flux2(num_steps: int, image_seq_len: int) -> list[float]:
mu = compute_empirical_mu(image_seq_len, num_steps)
timesteps = torch.linspace(1, 0, num_steps + 1)
timesteps = generalized_time_snr_shift(timesteps, mu, 1.0)
return timesteps.tolist()
def get_schedule_piflux2(num_steps: int, image_seq_len: int) -> list[float]:
"""
pi-FLUX.2 FlowMapSDE schedule with shift=3.2 and final_step_size_scale=0.5.
"""
if num_steps <= 0:
return [0.0]
shift = 3.2
final_step_size_scale = 0.5
end = (final_step_size_scale - 1.0) / (num_steps + final_step_size_scale - 1.0)
step = (end - 1.0) / num_steps
raw_timesteps = 1.0 + step * torch.arange(num_steps, dtype=torch.float32)
raw_timesteps = raw_timesteps.clamp(min=0)
sigmas = shift * raw_timesteps / (1 + (shift - 1) * raw_timesteps)
sigmas = torch.cat([sigmas, torch.zeros(1, dtype=sigmas.dtype)])
return sigmas.tolist()
def _flow_map_sde_warp_t(raw_t: Tensor, shift: float) -> Tensor:
return shift * raw_t / (1 + (shift - 1) * raw_t)
def _flow_map_sde_unwarp_t(sigma_t: Tensor, shift: float) -> Tensor:
return sigma_t / (shift + (1 - shift) * sigma_t)
def _flow_map_sde_calculate_sigmas_dst(
sigmas: Tensor, h: float = 0.0, eps: float = 1e-6
) -> tuple[Tensor, Tensor]:
sigmas_src = sigmas[:-1]
sigmas_to = sigmas[1:]
alphas_src = 1 - sigmas_src
alphas_to = 1 - sigmas_to
if h <= 0.0:
m_vals = torch.ones_like(sigmas_src)
else:
h2 = h * h
m_vals = (sigmas_to * alphas_src / (sigmas_src * alphas_to).clamp(min=eps)) ** h2
sigmas_to_mul_m = sigmas_to * m_vals
sigmas_dst = sigmas_to_mul_m / (alphas_to + sigmas_to_mul_m).clamp(min=eps)
return sigmas_dst, m_vals
def _gmflow_posterior_mean(
sigma_t_src: Tensor,
sigma_t: Tensor,
x_t_src: Tensor,
x_t: Tensor,
gm_means: Tensor,
gm_vars: Tensor,
gm_logweights: Tensor,
*,
eps: float = 1e-6,
gm_dim: int = -4,
channel_dim: int = -3,
) -> Tensor:
sigma_t_src = sigma_t_src.clamp(min=eps)
sigma_t = sigma_t.clamp(min=eps)
alpha_t_src = 1 - sigma_t_src
alpha_t = 1 - sigma_t
alpha_over_sigma_t_src = alpha_t_src / sigma_t_src
alpha_over_sigma_t = alpha_t / sigma_t
zeta = alpha_over_sigma_t.square() - alpha_over_sigma_t_src.square()
nu = alpha_over_sigma_t * x_t / sigma_t - alpha_over_sigma_t_src * x_t_src / sigma_t_src
nu = nu.unsqueeze(gm_dim)
zeta = zeta.unsqueeze(gm_dim)
denom = (gm_vars * zeta + 1).clamp(min=eps)
out_means = (gm_vars * nu + gm_means) / denom
logweights_delta = (gm_means * (nu - 0.5 * zeta * gm_means)).sum(dim=channel_dim, keepdim=True) / denom
out_weights = (gm_logweights + logweights_delta).softmax(dim=gm_dim)
return (out_means * out_weights).sum(dim=gm_dim)
class _GMFlowPolicy:
def __init__(
self,
denoising_output: dict[str, Tensor],
x_t_src: Tensor,
sigma_t_src: Tensor,
eps: float = 1e-4,
) -> None:
self.x_t_src = x_t_src
self.ndim = x_t_src.dim()
self.eps = eps
self.sigma_t_src = sigma_t_src.reshape(
*sigma_t_src.size(), *((self.ndim - sigma_t_src.dim()) * [1])
)
self.denoising_output_x_0 = self._u_to_x_0(denoising_output, self.x_t_src, self.sigma_t_src)
@staticmethod
def _u_to_x_0(denoising_output: dict[str, Tensor], x_t: Tensor, sigma_t: Tensor) -> dict[str, Tensor]:
x_t = x_t.unsqueeze(1)
sigma_t = sigma_t.unsqueeze(1)
means_x_0 = x_t - sigma_t * denoising_output["means"]
gm_vars = (denoising_output["logstds"] * 2).exp() * sigma_t.square()
return {"means": means_x_0, "gm_vars": gm_vars, "logweights": denoising_output["logweights"]}
def pi(self, x_t: Tensor, sigma_t: Tensor) -> Tensor:
sigma_t = sigma_t.reshape(*sigma_t.size(), *((self.ndim - sigma_t.dim()) * [1]))
means = self.denoising_output_x_0["means"]
gm_vars = self.denoising_output_x_0["gm_vars"]
logweights = self.denoising_output_x_0["logweights"]
if (sigma_t == self.sigma_t_src).all() and (x_t == self.x_t_src).all():
x_0 = (logweights.softmax(dim=1) * means).sum(dim=1)
else:
x_0 = _gmflow_posterior_mean(
self.sigma_t_src,
sigma_t,
self.x_t_src,
x_t,
means,
gm_vars,
logweights,
eps=self.eps,
)
return (x_t - x_0) / sigma_t.clamp(min=self.eps)
def temperature_(self, temperature: float) -> None:
if temperature >= 1.0:
return
temperature = max(temperature, self.eps)
gm = self.denoising_output_x_0
gm["logweights"] = (gm["logweights"] / temperature).log_softmax(dim=1)
gm["gm_vars"] = gm["gm_vars"] * temperature
def _policy_rollout(
x_t_start: Tensor,
sigma_t_start: Tensor,
sigma_t_end: Tensor,
total_substeps: int,
policy: _GMFlowPolicy,
*,
shift: float = 3.2,
) -> Tensor:
num_batches = x_t_start.size(0)
ndim = x_t_start.dim()
sigma_t_start = sigma_t_start.reshape(num_batches, *((ndim - 1) * [1]))
sigma_t_end = sigma_t_end.reshape(num_batches, *((ndim - 1) * [1]))
raw_t_start = _flow_map_sde_unwarp_t(sigma_t_start, shift)
raw_t_end = _flow_map_sde_unwarp_t(sigma_t_end, shift)
delta_raw_t = raw_t_start - raw_t_end
num_substeps = (delta_raw_t * total_substeps).round().to(torch.long).clamp(min=1)
substep_size = delta_raw_t / num_substeps
max_num_substeps = num_substeps.max()
raw_t = raw_t_start
sigma_t = sigma_t_start
x_t = x_t_start
for substep_id in range(max_num_substeps.item()):
u = policy.pi(x_t, sigma_t)
raw_t_minus = (raw_t - substep_size).clamp(min=0)
sigma_t_minus = _flow_map_sde_warp_t(raw_t_minus, shift)
x_t_minus = x_t + u * (sigma_t_minus - sigma_t)
active_mask = num_substeps > substep_id
x_t = torch.where(active_mask, x_t_minus, x_t)
sigma_t = torch.where(active_mask, sigma_t_minus, sigma_t)
raw_t = torch.where(active_mask, raw_t_minus, raw_t)
return x_t
def _unpack_latent_piflux2(x: Tensor, patch_size: int = 2) -> Tensor:
bsz, packed_channels, h, w = x.shape
channels = packed_channels // (patch_size * patch_size)
x = x.view(bsz, channels, patch_size, patch_size, h, w)
x = x.permute(0, 1, 4, 2, 5, 3).reshape(bsz, channels, h * patch_size, w * patch_size)
return x
def _pack_latent_piflux2(x: Tensor, patch_size: int = 2) -> Tensor:
bsz, channels, h, w = x.shape
h_packed = h // patch_size
w_packed = w // patch_size
x = x.view(bsz, channels, h_packed, patch_size, w_packed, patch_size)
x = x.permute(0, 1, 3, 5, 2, 4).reshape(
bsz, channels * patch_size * patch_size, h_packed, w_packed
)
return x
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
if image_seq_len > 4300:
mu = a2 * image_seq_len + b2
return float(mu)
m_200 = a2 * image_seq_len + b2
m_10 = a1 * image_seq_len + b1
a = (m_200 - m_10) / 190.0
b = m_200 - 200.0 * a
mu = a * num_steps + b
return float(mu)
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# estimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def denoise(
model: Flux,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
# sampling parameters
timesteps: list[float],
guidance: float = 4.0,
real_guidance_scale = None,
final_step_size_scale: float | None = None,
# extra img tokens (channel-wise)
neg_txt: Tensor = None,
neg_txt_ids: Tensor= None,
neg_vec: Tensor = None,
img_cond: Tensor | None = None,
# extra img tokens (sequence-wise)
img_cond_seq: Tensor | None = None,
img_cond_seq_ids: Tensor | None = None,
siglip_embedding = None,
siglip_embedding_ids = None,
callback=None,
pipeline=None,
loras_slists=None,
unpack_latent = None,
joint_pass= False,
img_msk_latents = None,
img_msk_rebuilt = None,
denoising_strength = 1,
masking_strength = 1,
preview_meta = None,
original_image_latents = None,
# pi-Flow settings
piflow_substeps: int = 128,
piflow_generator: torch.Generator | None = None,
piflow_gm_temperature: float | None = None,
):
kwargs = {
'pipeline': pipeline,
'callback': callback,
"img_len": img.shape[1],
"siglip_embedding": siglip_embedding,
"siglip_embedding_ids": siglip_embedding_ids,
}
if callback != None:
callback(-1, None, True)
original_timesteps = timesteps
is_piflow = getattr(model, "piflow", False)
piflow_spatial_h = piflow_spatial_w = None
piflow_sigmas = piflow_sigmas_dst = piflow_m_vals = None
morph, first_step = False, 0
if img_msk_latents is not None:
if original_image_latents is None: original_image_latents= img_cond_seq.clone()
randn = torch.randn_like(original_image_latents)
if denoising_strength < 1.:
first_step = int(len(timesteps[:-1]) * (1. - denoising_strength))
masked_steps = math.ceil(len(timesteps[:-1]) * masking_strength)
if not morph:
latent_noise_factor = timesteps[first_step]
latents = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor
img = latents.to(img)
latents = None
timesteps = timesteps[first_step:]
if is_piflow:
base_img_ids = img_ids[:, :img.shape[1]]
piflow_spatial_h = int(base_img_ids[..., 1].max().item() + 1)
piflow_spatial_w = int(base_img_ids[..., 2].max().item() + 1)
updated_num_steps= len(timesteps) -1
if callback != None:
from shared.utils.loras_mutipliers import update_loras_slists
update_loras_slists(model, loras_slists, len(original_timesteps))
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
from mmgp import offload
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
if is_piflow and len(timesteps) > 1:
piflow_sigmas = torch.tensor(timesteps, device=img.device, dtype=torch.float32)
piflow_sigmas_dst, piflow_m_vals = _flow_map_sde_calculate_sigmas_dst(
piflow_sigmas, h=0.0
)
if piflow_gm_temperature is None:
nfe = len(timesteps) - 1
piflow_gm_temperature = min(max(0.1 * (nfe - 1), 0.0), 1.0)
for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
offload.set_step_no_for_lora(model, first_step + i)
if pipeline._interrupt:
return None
if img_msk_latents is not None and denoising_strength <1. and i == first_step and morph:
latent_noise_factor = t_curr/1000
img = original_image_latents * (1.0 - latent_noise_factor) + img * latent_noise_factor
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
img_input = img
img_input_ids = img_ids
if img_cond is not None:
img_input = torch.cat((img, img_cond), dim=-1)
if img_cond_seq is not None:
img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
if not joint_pass or real_guidance_scale == 1:
pred = model(
img=img_input,
img_ids=img_input_ids,
txt_list=[txt],
txt_ids_list=[txt_ids],
y_list=[vec],
timesteps=t_vec,
guidance=guidance_vec,
**kwargs
)[0]
if pred == None: return None
if real_guidance_scale> 1:
neg_pred = model(
img=img_input,
img_ids=img_input_ids,
txt_list=[neg_txt],
txt_ids_list=[neg_txt_ids],
y_list=[neg_vec],
timesteps=t_vec,
guidance=guidance_vec,
**kwargs
)[0]
if neg_pred == None: return None
else:
pred, neg_pred = model(
img=img_input,
img_ids=img_input_ids,
txt_list=[txt, neg_txt],
txt_ids_list=[txt_ids, neg_txt_ids],
y_list=[vec, neg_vec],
timesteps=t_vec,
guidance=guidance_vec,
**kwargs
)
if pred == None: return None
if is_piflow and isinstance(pred, dict):
if real_guidance_scale > 1:
pred = {k: neg_pred[k] + real_guidance_scale * (pred[k] - neg_pred[k]) for k in pred}
patch_size = getattr(model, "piflow_patch_size", 2)
img_packed = rearrange(
img, "b (h w) c -> b c h w", h=piflow_spatial_h, w=piflow_spatial_w
)
img_unpacked = _unpack_latent_piflux2(img_packed, patch_size=patch_size).float()
pred = {k: v.float() for k, v in pred.items()}
sigma_t_src = piflow_sigmas[i].expand(img.shape[0])
sigma_t_dst = piflow_sigmas_dst[i].expand(img.shape[0])
policy = _GMFlowPolicy(pred, img_unpacked, sigma_t_src)
if (
piflow_gm_temperature is not None
and i != len(timesteps) - 2
and piflow_gm_temperature < 1.0
):
policy.temperature_(piflow_gm_temperature)
img_unpacked = _policy_rollout(
img_unpacked,
sigma_t_src,
sigma_t_dst,
total_substeps=piflow_substeps,
policy=policy,
shift=3.2,
)
sigma_t_to = piflow_sigmas[i + 1]
m = piflow_m_vals[i]
alpha_t_to = 1 - sigma_t_to
if not torch.allclose(m, torch.ones_like(m)):
if piflow_generator is not None and img_unpacked.device.type == "cpu":
noise = torch.randn_like(img_unpacked, generator=piflow_generator)
else:
noise = torch.randn(img_unpacked.shape, device=img_unpacked.device, dtype=img_unpacked.dtype)
img_unpacked = (alpha_t_to + sigma_t_to * m) * img_unpacked + sigma_t_to * (
1 - m.square()
).clamp(min=0).sqrt() * noise
img_packed = _pack_latent_piflux2(img_unpacked, patch_size=patch_size)
img = rearrange(img_packed, "b c h w -> b (h w) c").to(img.dtype)
else:
if real_guidance_scale > 1:
pred = neg_pred + real_guidance_scale * (pred - neg_pred)
step_size = t_prev - t_curr
if final_step_size_scale is not None and i == len(timesteps) - 2:
step_size = step_size * final_step_size_scale
img += step_size * pred
if img_msk_latents is not None and i < masked_steps:
latent_noise_factor = t_prev
# noisy_image = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
noisy_image = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor
img = noisy_image * (1-img_msk_latents) + img_msk_latents * img
noisy_image = None
if callback is not None:
preview = unpack_latent(img).transpose(0,1)
callback(i, preview, False, preview_meta=preview_meta)
return img
def prepare_multi_ip(
ae: AutoEncoder,
img_cond_list: list,
seed: int,
device: torch.device,
target_width: int | None = None,
target_height: int | None = None,
bs: int = 1,
pe: Literal["d", "h", "w", "o"] = "d",
conditions_zero_start = False,
set_cond_index = False,
res_match_output = True,
patch_size: int = 2
) -> dict[str, Tensor]:
assert pe in ["d", "h", "w", "o"]
if img_cond_list == None: img_cond_list = []
if not res_match_output:
for i, img_cond in enumerate(img_cond_list):
img_cond_list[i]= resizeinput(img_cond)
ref_imgs = [
ae.encode(
(TVF.to_tensor(ref_img) * 2.0 - 1.0)
.unsqueeze(0)
.to(device, torch.float32)
).to(torch.bfloat16)
for ref_img in img_cond_list
]
img = get_noise( bs, target_height, target_width, device=device, dtype=torch.bfloat16, seed=seed)
bs, c, h, w = img.shape
# tgt img
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // patch_size, w // patch_size, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // patch_size)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // patch_size)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
img_cond_seq = img_cond_seq_ids = None
if conditions_zero_start:
pe_shift_w = pe_shift_h = 0
else:
pe_shift_w, pe_shift_h = w // patch_size, h // patch_size
for cond_no, ref_img in enumerate(ref_imgs):
_, _, ref_h1, ref_w1 = ref_img.shape
ref_img = rearrange(
ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2
)
if ref_img.shape[0] == 1 and bs > 1:
ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3)
if set_cond_index:
ref_img_ids1[..., 0] = cond_no + 1
h_offset = pe_shift_h if pe in {"d", "h"} else 0
w_offset = pe_shift_w if pe in {"d", "w"} else 0
ref_img_ids1[..., 1] = (
ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset
)
ref_img_ids1[..., 2] = (
ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset
)
ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs)
if target_width is None:
target_width = 8 * ref_w1
if target_height is None:
target_height = 8 * ref_h1
ref_img_ids1 = ref_img_ids1.to(device)
if cond_no == 0:
img_cond_seq, img_cond_seq_ids = ref_img, ref_img_ids1
else:
img_cond_seq, img_cond_seq_ids = torch.cat([img_cond_seq, ref_img], dim=1), torch.cat([img_cond_seq_ids, ref_img_ids1], dim=1)
# 更新pe shift
pe_shift_h += ref_h1 // 2
pe_shift_w += ref_w1 // 2
return {
"img": img,
"img_ids": img_ids.to(img.device),
"img_cond_seq": img_cond_seq,
"img_cond_seq_ids": img_cond_seq_ids,
}, target_height, target_width
def unpack(x: Tensor, height: int, width: int) -> Tensor:
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)
def patches_to_image(x: Tensor, height: int, width: int, patch_size: int) -> Tensor:
tokens = x.transpose(1, 2)
return F.fold(
tokens,
output_size=(height, width),
kernel_size=patch_size,
stride=patch_size,
)
|