import math from typing import Callable, Optional, Union, List, Dict, Any import os from PIL import Image import torch from einops import rearrange, repeat from torch import Tensor from .model import Flux from .modules.conditioner import HFEmbedder from .modules.autoencoder import AutoEncoder def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: bs, c, h, w = img.shape if bs == 1 and not isinstance(prompt, str): bs = len(prompt) img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) if img.shape[0] == 1 and bs > 1: img = repeat(img, "1 ... -> bs ...", bs=bs) img_ids = torch.zeros(h // 2, w // 2, 3) img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) if isinstance(prompt, str): prompt = [prompt] txt = t5(prompt) if txt.shape[0] == 1 and bs > 1: txt = repeat(txt, "1 ... -> bs ...", bs=bs) txt_ids = torch.zeros(bs, txt.shape[1], 3) vec = clip(prompt) if vec.shape[0] == 1 and bs > 1: vec = repeat(vec, "1 ... -> bs ...", bs=bs) print(f"prepare t5 embedding: {txt}") print(f"prepare clip embedding: {vec}") return { "img": img, "img_ids": img_ids.to(img.device), "txt": txt.to(img.device), "txt_ids": txt_ids.to(img.device), "vec": vec.to(img.device), } def prepare_image(img: Tensor): bs, c, h, w = img.shape img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) if img.shape[0] == 1 and bs > 1: img = repeat(img, "1 ... -> bs ...", bs=bs) return img def time_shift(mu: float, sigma: float, t: Tensor): return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) def get_lin_function( x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 ) -> Callable[[float], float]: m = (y2 - y1) / (x2 - x1) b = y1 - m * x1 return lambda x: m * x + b def get_noise( num_samples: int, height: int, width: int, device: torch.device, dtype: torch.dtype, seed: int, ): return torch.randn( num_samples, 16, # allow for packing 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), device=device, dtype=dtype, generator=torch.Generator(device=device).manual_seed(seed), ) def get_schedule( num_steps: int, image_seq_len: int, base_shift: float = 0.5, max_shift: float = 1.15, shift: bool = True, ) -> list[float]: # extra step for zero timesteps = torch.linspace(1, 0, num_steps + 1) # shifting the schedule to favor high timesteps for higher signal images if shift: # estimate mu based on linear estimation between two points mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) timesteps = time_shift(mu, 1.0, timesteps) return timesteps.tolist() def denoise_rf( model: Flux, # model input img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Tensor, vec: Tensor, # sampling parameters timesteps: list[float], inverse, info, guidance: float = 4.0 ): # this is ignored for schnell inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step']) if inverse: timesteps = timesteps[::-1] inject_list = inject_list[::-1] guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) step_list = [] for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) info['t'] = t_prev if inverse else t_curr info['inverse'] = inverse info['second_order'] = False info['inject'] = inject_list[i] pred, info = model( img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec, info=info, cur_step = i ) img = img + (t_prev - t_curr) * pred return img, info def denoise_rf_solver( model: Flux, # model input img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Tensor, vec: Tensor, # sampling parameters timesteps: list[float], inverse, info, guidance: float = 4.0, img_ori: Optional[Tensor] = None ): # this is ignored for schnell inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step']) if inverse: timesteps = timesteps[::-1] inject_list = inject_list[::-1] guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) step_list = [] for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) info['t'] = t_prev if inverse else t_curr info['inverse'] = inverse info['second_order'] = False info['inject'] = inject_list[i] pred, info = model( img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec, info=info, cur_step = i ) img_mid = img + (t_prev - t_curr) / 2 * pred t_vec_mid = torch.full((img.shape[0],), (t_curr + (t_prev - t_curr) / 2), dtype=img.dtype, device=img.device) info['second_order'] = True pred_mid, info = model( img=img_mid, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec_mid, guidance=guidance_vec, info=info, cur_step = i ) first_order = (pred_mid - pred) / ((t_prev - t_curr) / 2) img = img + (t_prev - t_curr) * pred + 0.5 * (t_prev - t_curr) ** 2 * first_order return img, info def denoise_fireflow( model: Flux, # model input img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Tensor, vec: Tensor, # sampling parameters timesteps: list[float], inverse, info, guidance: float = 4.0, img_ori: Optional[Tensor] = None, ae: Optional[AutoEncoder] = None, # Optional AutoEncoder for decoding device: Optional[Union[str, torch.device]] = None # Optional device specification ): # this is ignored for schnell inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step']) if inverse: timesteps = timesteps[::-1] inject_list = inject_list[::-1] guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) step_list = [] next_step_velocity = None for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) info['t'] = t_prev if inverse else t_curr info['inverse'] = inverse info['second_order'] = False info['inject'] = inject_list[i] if next_step_velocity is None: pred, info = model( img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec, info=info, cur_step=i ) else: pred = next_step_velocity img_mid = img + (t_prev - t_curr) / 2 * pred t_vec_mid = torch.full((img.shape[0],), t_curr + (t_prev - t_curr) / 2, dtype=img.dtype, device=img.device) info['second_order'] = True pred_mid, info = model( img=img_mid, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec_mid, guidance=guidance_vec, info=info, cur_step=i ) next_step_velocity = pred_mid img = img + (t_prev - t_curr) * pred_mid ########################### save generating steps ############################## #idx = len(timesteps) - 1 #fn = f'result/intermediate_{idx}steps' #if not os.path.exists(fn): #os.makedirs(fn) #fn += f'/fireflow_{t_prev}.jpg' #if inverse: #fn = f'result/intermediate_{idx}steps/inverse_fireflow_{t_prev}.jpg' # decode latents to pixel space #x = unpack(img.float(), img.shape[1] ** 0.5 * 16, img.shape[1] ** 0.5 * 16) #with torch.autocast(device_type=device.type, dtype=torch.bfloat16): #x = ae.decode(x) # bring into PIL format and save #x = x.clamp(-1, 1) #x = rearrange(x[0], "c h w -> h w c") #x = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) #x.save(fn) ########################### save generating steps ############################## return img, info def denoise_midpoint( model: Flux, # model input img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Tensor, vec: Tensor, # sampling parameters timesteps: list[float], inverse, info, guidance: float = 4.0 ): # this is ignored for schnell inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step']) if inverse: timesteps = timesteps[::-1] inject_list = inject_list[::-1] guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) step_list = [] for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) info['t'] = t_prev if inverse else t_curr info['inverse'] = inverse info['second_order'] = False info['inject'] = inject_list[i] pred, info = model( img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec, info=info ) img_mid = img + (t_prev - t_curr) / 2 * pred t_vec_mid = torch.full((img.shape[0],), t_curr + (t_prev - t_curr) / 2, dtype=img.dtype, device=img.device) info['second_order'] = True pred_mid, info = model( img=img_mid, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec_mid, guidance=guidance_vec, info=info ) next_step_velocity = pred_mid img = img + (t_prev - t_curr) * pred_mid return img, info def unpack(x: Tensor, height: int, width: int) -> Tensor: return rearrange( x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=math.ceil(height / 16), w=math.ceil(width / 16), ph=2, pw=2, ) def denoise_rf_inversion( model: Flux, # model input img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Tensor, vec: Tensor, # sampling parameters timesteps: list[float], inverse, info, guidance: float = 4.0, stop_timestep: float = 0.35, img_LQR: Dict = {"source img": None, "prev img": None} ): # this is ignored for schnell inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step']) gamma_steps = int(stop_timestep * len(timesteps[:-1])) #gamma_steps = 9 gamma = [0.9] * gamma_steps + [0] * (len(timesteps[:-1]) - gamma_steps) # γ ∈ [0, 1] the controller guidance, γ can be time-varying if inverse: # todo if inverse, text prompt is φ timesteps = timesteps[::-1] inject_list = inject_list[::-1] gamma = [0.5] * len(timesteps[:-1]) # γ ∈ [0, 1] the controller guidance guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) step_list = [] y1 = torch.randn(img.shape, device=img.device, dtype=img.dtype) y0, y_prev = None, None if img_LQR['source img'] is not None: y0 = img_LQR['source img'].to(img.device) if img_LQR['prev img'] is not None: y_prev = img_LQR['prev img'].to(img.device) for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) info['t'] = t_prev if inverse else t_curr info['inverse'] = inverse info['second_order'] = False info['inject'] = inject_list[i] pred, info = model( img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec, info=info, cur_step=i ) # 6. Unconditional Vector field uti(Yti) = u(Yti, ti, Φ(“”); φ) unconditional_vector_field = pred if not inverse: unconditional_vector_field = -unconditional_vector_field if inverse: # 7.Conditional Vector field uti(Yti|y1) = (y1−Yti)/1−ti conditional_vector_field = (y1 - img) / (1 - t_curr) else: # 7.Conditional Vector field uti(Xti|y0) = (y0−Xti)/(1−ti) t_i = i / len(timesteps[:-1]) # Empiracally better results #conditional_vector_field = (y0 - img) / t_curr if y_prev is None: conditional_vector_field = (y0 - img) / (1 - t_i) else: #conditional_vector_field = (y_prev - img) / (1 - t_i) conditional_vector_field = (y0 - img) / (1 - t_i) + 0.7 * ((y_prev - img) / (1 - t_i) - (y0 - img) / (1 - t_i)) # 8. Controlled Vector field ti(Yti) = uti(Yti) + γ (uti(Yti|y1) − uti(Yti)) controlled_vector_field = unconditional_vector_field + gamma[i] * (conditional_vector_field - unconditional_vector_field) # 9. Next state Yti+1 = Yti + ˆuti(Yti) (σ(ti+1) − σ(ti)) delta_t = t_prev - t_curr if delta_t < 0: delta_t = t_curr - t_prev img = img + delta_t * controlled_vector_field return img, info def denoise_multi_turn_consistent( model: Flux, # model input img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Tensor, vec: Tensor, # sampling parameters timesteps: list[float], inverse, info, guidance: float = 4.0, #img_ori: Optional[Tensor] = None img_LQR: Dict = {"source img": None, "prev img": None} ): # this is ignored for schnell inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step']) gamma_steps = int(info['lqr_stop'] * len(timesteps[:-1])) #gamma_steps = 9 gamma = [0.9] * gamma_steps + [0] * (len(timesteps[:-1]) - gamma_steps) # γ ∈ [0, 1] the controller guidance, γ can be time-varying if inverse: # todo if inverse, text prompt is φ timesteps = timesteps[::-1] inject_list = inject_list[::-1] gamma = [0.5] * len(timesteps[:-1]) # γ ∈ [0, 1] the controller guidance guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) step_list = [] y1 = torch.randn(img.shape, device=img.device, dtype=img.dtype) y0, y_prev = None, None if img_LQR['source img'] is not None: y0 = img_LQR['source img'].to(img.device) if img_LQR['prev img'] is not None: y_prev = img_LQR['prev img'].to(img.device) next_step_velocity = None for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) info['t'] = t_prev if inverse else t_curr info['inverse'] = inverse info['second_order'] = False info['inject'] = inject_list[i] if next_step_velocity is None: pred, info = model( img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec, info=info, cur_step=i ) else: pred = next_step_velocity img_mid = img + (t_prev - t_curr) / 2 * pred t_vec_mid = torch.full((img.shape[0],), t_curr + (t_prev - t_curr) / 2, dtype=img.dtype, device=img.device) info['second_order'] = True pred_mid, info = model( img=img_mid, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec_mid, guidance=guidance_vec, info=info, cur_step=i ) next_step_velocity = pred_mid # 6. Unconditional Vector field uti(Yti) = u(Yti, ti, Φ(“”); φ) unconditional_vector_field = pred_mid if not inverse: unconditional_vector_field = -unconditional_vector_field if inverse: # 7.Conditional Vector field uti(Yti|y1) = (y1−Yti)/(1−ti) conditional_vector_field = (y1 - img) / (1 - t_curr + (t_prev - t_curr) / 2) else: # 7.Conditional Vector field uti(Xti|y0) = (y0−Xti)/(1−ti) t_i = i / len(timesteps[:-1]) # Empiracally better results #conditional_vector_field = (y0 - img) / t_curr if y_prev is None: conditional_vector_field = (y0 - img) / (1 - t_i) else: conditional_vector_field = (y0 - img) / (1 - t_i) + 0.7 * ((y_prev - img) / (1 - t_i) - (y0 - img) / (1 - t_i)) #conditional_vector_field = (y_prev - img) / (1 - t_i) # 8. Controlled Vector field ti(Yti) = uti(Yti) + γ (uti(Yti|y1) − uti(Yti)) controlled_vector_field = unconditional_vector_field + gamma[i] * (conditional_vector_field - unconditional_vector_field) # 9. Next state Yti+1 = Yti + ˆuti(Yti) (σ(ti+1) − σ(ti)) delta_t = t_prev - t_curr if delta_t < 0: delta_t = t_curr - t_prev img = img + delta_t * controlled_vector_field return img, info