# UniCalli_Dev / src/flux/sampling.py
# Last commit by TSXu (d3ccd4b): Add batch generation, torch.compile acceleration, fix dtype issues
import math
from typing import Callable
import torch
from einops import rearrange, repeat
from torch import Tensor
from .model import Flux
from .modules.conditioner import HFEmbedder
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    """Draw reproducible standard-normal latent noise.

    The result has shape ``(num_samples, 16, H, W)`` where ``H`` and ``W`` are
    ``2 * ceil(height / 16)`` and ``2 * ceil(width / 16)``. Both are always even,
    so the tensor can later be packed into 2x2 patches. A fresh generator
    seeded with ``seed`` on ``device`` makes the draw deterministic.
    """
    latent_h = math.ceil(height / 16) * 2
    latent_w = math.ceil(width / 16) * 2
    rng = torch.Generator(device=device).manual_seed(seed)
    shape = (num_samples, 16, latent_h, latent_w)
    return torch.randn(*shape, device=device, dtype=dtype, generator=rng)
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    """Pack a latent batch and encode prompts into the model's input dict.

    Args:
        t5: text encoder called with a list of prompts; its output's dim 1 is
            used as the text token count.
        clip: encoder called with the same prompt list; produces ``vec``.
        img: latent of shape ``(b, c, h, w)``; ``h`` and ``w`` must be even.
        prompt: a single prompt or a list of prompts. When ``img`` has batch 1
            and a list is given, the batch size becomes ``len(prompt)``.

    Returns:
        dict with ``img`` (packed 2x2 patches, ``(bs, h/2*w/2, c*4)``),
        ``img_ids`` (position grid listed twice along dim 1), ``txt``,
        ``txt_ids`` (all zeros) and ``vec``. Everything except ``img`` is
        moved to ``img``'s device and cast to its dtype.
    """
    bs, _, h, w = img.shape
    # One latent paired with several prompts: broadcast the latent per prompt.
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    # Use same dtype as img for consistency
    dtype = img.dtype
    grid_h, grid_w = h // 2, w // 2
    # Channel 0 stays zero; channels 1 and 2 carry the row and column index.
    grid = torch.zeros(grid_h, grid_w, 3, dtype=dtype)
    grid[..., 1] = grid[..., 1] + torch.arange(grid_h, dtype=dtype)[:, None]
    grid[..., 2] = grid[..., 2] + torch.arange(grid_w, dtype=dtype)[None, :]
    grid = repeat(grid, "h w c -> b (h w) c", b=bs)
    # The grid is listed twice. Presumably the second copy covers the
    # conditioning-latent tokens that `denoise` concatenates with the image tokens.
    img_ids = torch.cat((grid, grid.clone()), dim=1)

    prompts = [prompt] if isinstance(prompt, str) else prompt

    txt = t5(prompts)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3, dtype=dtype)

    vec = clip(prompts)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    target = {"device": img.device, "dtype": dtype}
    return {
        "img": img,
        "img_ids": img_ids.to(**target),
        "txt": txt.to(**target),
        "txt_ids": txt_ids.to(**target),
        "vec": vec.to(**target),
    }
def time_shift(mu: float, sigma: float, t: Tensor):
    """Warp timesteps ``t`` by ``exp(mu) / (exp(mu) + (1/t - 1) ** sigma)``.

    For ``mu > 0`` (with ``sigma == 1``) values are pushed toward 1. With a
    floating tensor ``t``, an entry of 0 maps to 0 (``1/0`` is ``inf``) and an
    entry of 1 maps to 1.
    """
    scale = math.exp(mu)
    odds = (1 / t - 1) ** sigma
    return scale / (scale + odds)
def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    """Return the straight line through ``(x1, y1)`` and ``(x2, y2)``.

    The returned callable maps ``x`` to ``slope * x + intercept``. It raises
    ``ZeroDivisionError`` at construction if ``x1 == x2``.
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x: float) -> float:
        return slope * x + intercept

    return line
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Build ``num_steps + 1`` descending timesteps from 1 down to 0.

    When ``shift`` is true, a shift ``mu`` is first estimated from
    ``image_seq_len``. The estimate interpolates linearly between
    ``(256, base_shift)`` and ``(4096, max_shift)``. The timesteps are then
    warped with ``time_shift(mu, 1.0, ...)``, which favors high timesteps for
    longer image sequences.
    """
    # num_steps intervals need num_steps + 1 endpoints; the last one is zero.
    ts = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return ts.tolist()
    # Estimate mu by linear interpolation between the two reference points.
    mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
    return time_shift(mu, 1.0, ts).tolist()
def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    neg_txt: Tensor,
    neg_txt_ids: Tensor,
    neg_vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    cond_latent: Tensor=None,
    cond_txt_latent: Tensor=None,
    true_gs = 1,
    timestep_to_start_cfg=0,
    # ip-adapter parameters
    image_proj: Tensor=None,
    neg_image_proj: Tensor=None,
    ip_scale: Tensor | float = 1.0,
    neg_ip_scale: Tensor | float = 1.0,
    is_generation: bool = True,
):
    """Run Euler steps over ``timesteps``, optionally with a conditioning latent.

    Each step calls ``model`` (and, from step ``timestep_to_start_cfg`` on, a
    second time with the negative text/vec/image_proj). It combines the two as
    ``neg + true_gs * (pos - neg)`` and updates
    ``img += (t_prev - t_curr) * pred``.

    Args:
        img: packed image tokens; presumably ``(batch, tokens, c*4)``, matching
            ``prepare``'s output — TODO confirm.
        cond_latent: optional unpacked latent ``(b, c, h, w)``. It is packed into
            2x2 patches once and concatenated with ``img`` along dim 1 at every
            step. It must produce exactly ``img.shape[1]`` tokens of width
            ``img.shape[2]``.
        cond_txt_latent: required; forwarded unchanged to every model call.
        is_generation: selects the token order and timestep assignment.
            If true: ``[img, cond]``, with ``timesteps=t_curr`` and
            ``timesteps2=timesteps[-1]``.
            If false: ``[cond, img]``, with ``timesteps=timesteps[-1]`` and
            ``timesteps2=t_curr``.
        true_gs: classifier-free guidance scale applied to the negative pass.

    Returns:
        The updated ``img`` tokens (conditioning tokens stripped).
    """
    assert cond_txt_latent is not None
    batch = img.shape[0]
    # this is ignored for schnell
    guidance_vec = torch.full((batch,), guidance, device=img.device, dtype=img.dtype)
    t_0_vec = torch.full((batch,), timesteps[-1], dtype=img.dtype, device=img.device)

    # Packing the conditioning latent is loop-invariant, so do it once.
    cond = None
    if cond_latent is not None:
        _, c, h, w = cond_latent.shape
        # Both halves of the concatenated sequence must be the same size so the
        # prediction can be split back with ``chunk(2)``.
        assert h * w // 4 == img.shape[1] and c * 4 == img.shape[2]
        cond = rearrange(cond_latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

    for step, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
        t_vec = torch.full((batch,), t_curr, dtype=img.dtype, device=img.device)
        # Assign the timesteps unconditionally. Previously they were only set
        # when ``cond_latent`` was given, so calling without it raised
        # UnboundLocalError.
        if is_generation:
            t1, t2 = t_vec, t_0_vec
        else:
            t1, t2 = t_0_vec, t_vec

        if cond is None:
            x = img
        else:
            cond_tokens = cond.to(img.dtype)
            pair = (img, cond_tokens) if is_generation else (cond_tokens, img)
            x = torch.cat(pair, dim=1)

        pred = model(
            img=x,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t1,
            timesteps2=t2,
            cond_txt_latent=cond_txt_latent,
            guidance=guidance_vec,
            image_proj=image_proj,
            ip_scale=ip_scale,
        )
        if step >= timestep_to_start_cfg:
            neg_pred = model(
                img=x,
                img_ids=img_ids,
                txt=neg_txt,
                txt_ids=neg_txt_ids,
                y=neg_vec,
                timesteps=t1,
                timesteps2=t2,
                cond_txt_latent=cond_txt_latent,
                guidance=guidance_vec,
                image_proj=neg_image_proj,
                ip_scale=neg_ip_scale,
            )
            pred = neg_pred + true_gs * (pred - neg_pred)

        if cond is not None:
            # Keep only the prediction for the tokens being denoised.
            pred = pred.chunk(2, dim=1)[0 if is_generation else 1]
        img = img + (t_prev - t_curr) * pred
    return img
def denoise_controlnet(
    model: Flux,
    controlnet:None,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    neg_txt: Tensor,
    neg_txt_ids: Tensor,
    neg_vec: Tensor,
    controlnet_cond,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    true_gs = 1,
    controlnet_gs=0.7,
    timestep_to_start_cfg=0,
    # ip-adapter parameters
    image_proj: Tensor=None,
    neg_image_proj: Tensor=None,
    ip_scale: Tensor | float = 1,
    neg_ip_scale: Tensor | float = 1,
):
    """Run Euler steps over ``timesteps`` with ControlNet residuals injected.

    At every step ``controlnet`` is called on the current ``img`` and
    ``controlnet_cond``. Each returned residual is multiplied by
    ``controlnet_gs`` and passed to ``model`` as
    ``block_controlnet_hidden_states``.

    From step index ``timestep_to_start_cfg`` on, the same pair of calls is
    repeated with the negative text/vec/image_proj. The predictions are then
    combined as ``neg + true_gs * (pos - neg)``. The image is updated with
    ``img += (t_prev - t_curr) * pred``, and the final ``img`` is returned.
    """
    # this is ignored for schnell
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    def _scaled(residuals):
        # Scale each ControlNet residual by the ControlNet guidance strength.
        return [res * controlnet_gs for res in residuals]

    for step, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

        pos_residuals = controlnet(
            img=img,
            img_ids=img_ids,
            controlnet_cond=controlnet_cond,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            block_controlnet_hidden_states=_scaled(pos_residuals),
            image_proj=image_proj,
            ip_scale=ip_scale,
        )

        if step >= timestep_to_start_cfg:
            neg_residuals = controlnet(
                img=img,
                img_ids=img_ids,
                controlnet_cond=controlnet_cond,
                txt=neg_txt,
                txt_ids=neg_txt_ids,
                y=neg_vec,
                timesteps=t_vec,
                guidance=guidance_vec,
            )
            neg_pred = model(
                img=img,
                img_ids=img_ids,
                txt=neg_txt,
                txt_ids=neg_txt_ids,
                y=neg_vec,
                timesteps=t_vec,
                guidance=guidance_vec,
                block_controlnet_hidden_states=_scaled(neg_residuals),
                image_proj=neg_image_proj,
                ip_scale=neg_ip_scale,
            )
            pred = neg_pred + true_gs * (pred - neg_pred)

        img = img + (t_prev - t_curr) * pred
    return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Undo 2x2 patch packing: ``(b, h*w, c*4)`` -> ``(b, c, 2h, 2w)``.

    The patch-grid size is ``(ceil(height / 16), ceil(width / 16))``, i.e. the
    same rounding ``get_noise`` uses for the latent's spatial size.
    """
    grid_h = math.ceil(height / 16)
    grid_w = math.ceil(width / 16)
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=grid_h,
        w=grid_w,
        ph=2,
        pw=2,
    )