# LORE: flux/sampling_lore.py
import math
import random
import copy
from typing import Callable
import torch
import numpy as np
from einops import rearrange, repeat
from torch import Tensor
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from .model_lore import Flux
from .modules.conditioner_lore import HFEmbedder
def prepare_tokens(t5, source_prompt, target_prompt, replacements, show_tokens=False):
    _, _, src_dif_ids, tgt_dif_ids = t5.get_text_embeddings_with_diff(
        source_prompt, target_prompt, replacements, show_tokens=show_tokens
    )
    return src_dif_ids, tgt_dif_ids
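# Usage sketch (the exact `replacements` format is an assumption here; it is whatever
# HFEmbedder.get_text_embeddings_with_diff expects, e.g. source->target word pairs):
#   src_ids, tgt_ids = prepare_tokens(t5, "a cat on the grass", "a dog on the grass",
#                                     replacements, show_tokens=True)
# src_ids / tgt_ids index the T5 text tokens that differ between the two prompts.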
transform = transforms.ToTensor()
def get_mask_one_tensor(mask_dirs,width,height,device):
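    """Load each mask image as grayscale, resize it to the 16x-downsampled token grid,
    flatten it, and sum all masks into a single (1, 1, num_tokens, 1) bfloat16 tensor
    on `device`."""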
res = []
for mask_dir in mask_dirs:
mask_image = Image.open(mask_dir).convert('L')
# resize
mask_image = mask_image.resize((math.ceil(height/16), math.ceil(width/16)), Image.Resampling.LANCZOS)
mask_tensor = transform(mask_image)
mask_tensor = mask_tensor.squeeze(0)
# to one dim
mask_tensor = mask_tensor.flatten()
mask_tensor = mask_tensor.to(device)
res.append(mask_tensor)
res = sum(res)
res = res.view(1, 1, -1, 1)
res = res.to(torch.bfloat16)
return res
def get_v_mask(mask_dirs,width,height,device,txt_length=512):
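    """Like get_mask_one_tensor, but prefixed with `txt_length` ones so the mask also
    covers the text tokens; returns shape (1, 1, txt_length + num_image_tokens, 1)."""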
res = []
for mask_dir in mask_dirs:
mask_image = Image.open(mask_dir).convert('L')
# resize
mask_image = mask_image.resize((math.ceil(height/16), math.ceil(width/16)), Image.Resampling.LANCZOS)
mask_tensor = transform(mask_image)
mask_tensor = mask_tensor.squeeze(0)
# to one dim
mask_tensor = mask_tensor.flatten()
mask_tensor = mask_tensor.to(device)
res.append(mask_tensor)
res = sum(res)
res = torch.cat([torch.ones(txt_length).to(device),res])
res = res.view(1, 1, -1, 1)
res = res.to(torch.bfloat16)
return res
def add_masked_noise_to_z(z,mask,width,height,seed=42,noise_scale=0.1):
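    """Blend fresh Gaussian noise into z inside `mask`: outside the mask z is kept,
    inside it becomes (1 - noise_scale) * z + noise_scale * noise. noise_scale == 0
    returns z unchanged and noise_scale > 10 returns pure noise. `width` and `height`
    are currently unused."""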
if noise_scale == 0:
return z
noise = torch.randn(z.shape,device=z.device,dtype=z.dtype,generator=torch.Generator(device=z.device).manual_seed(seed))
if noise_scale > 10:
return noise
    # keep z outside the mask; inside the mask, blend z and noise according to noise_scale
z = z*(1-mask[0])+noise_scale*noise*mask[0]+(1-noise_scale)*z*mask[0]
return z
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
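    """Pack the latent image into 2x2 patches, build 3D positional ids for image and
    text tokens, embed the prompt with T5 (per-token) and CLIP (pooled), broadcast to
    the batch size, and return the model-input dict (img, img_ids, txt, txt_ids, vec)."""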
bs, c, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if isinstance(prompt, str):
prompt = [prompt]
txt = t5(prompt)
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = torch.zeros(bs, txt.shape[1], 3)
vec = clip(prompt)
if vec.shape[0] == 1 and bs > 1:
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"txt": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
"vec": vec.to(img.device),
}
def time_shift(mu: float, sigma: float, t: Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
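    """Return the linear function through (x1, y1) and (x2, y2); used below to
    interpolate the schedule shift mu from the image token sequence length."""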
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
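    """Build num_steps + 1 timesteps from 1.0 down to 0.0. With `shift`, the schedule
    is warped by time_shift so that longer image sequences spend more of the trajectory
    at high noise levels. For example, get_schedule(4, 4096) gives mu = 1.15 and roughly
    [1.0, 0.905, 0.760, 0.513, 0.0] instead of the unshifted [1.0, 0.75, 0.5, 0.25, 0.0]."""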
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# estimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def denoise(
model: Flux,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
# sampling parameters
timesteps: list[float],
inverse,
info,
guidance: float = 4.0,
trainable_noise_list=None,
):
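    """Integrate the flow ODE over `timesteps` with a second-order (midpoint) update.
    With `inverse=True` the schedule and injection flags are reversed, so the same loop
    performs inversion. When `trainable_noise_list` is given (editing), the latent at
    steps 1..len(trainable_noise_list)-1 is replaced by the pre-optimized latents from
    denoise_with_noise_optim. Returns (img, info, step_list, None)."""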
    # guidance below is ignored for schnell; enable feature injection for the first info['inject_step'] steps
    inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step'])
if inverse:
timesteps = timesteps[::-1]
inject_list = inject_list[::-1]
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
step_list = []
for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
info['t'] = t_prev if inverse else t_curr
info['inverse'] = inverse
info['second_order'] = False
info['inject'] = inject_list[i]
        # when editing, substitute the pre-optimized latent for the early steps (step 0 keeps the starting latent)
        if trainable_noise_list and i != 0 and i < len(trainable_noise_list):
# smask = info['source_mask'].squeeze(0)
# img = trainable_noise_list[i]*smask+img*(1-smask)
img = trainable_noise_list[i]
pred, info, attn_maps_mid = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
info=info
)
img_mid = img + (t_prev - t_curr) / 2 * pred
t_vec_mid = torch.full((img.shape[0],), (t_curr + (t_prev - t_curr) / 2), dtype=img.dtype, device=img.device)
info['second_order'] = True
pred_mid, info, attn_maps = model(
img=img_mid,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec_mid,
guidance=guidance_vec,
info=info
)
first_order = (pred_mid - pred) / ((t_prev - t_curr) / 2)
img = img + (t_prev - t_curr) * pred + 0.5 * (t_prev - t_curr) ** 2 * first_order
        # attention maps from the model calls are available here but are not returned by this function
step_list.append(t_curr)
return img, info, step_list, None
# transformer layers whose attention maps contribute to the attention losses below
selected_layers = range(8, 44)
def gaussian_smooth(attnmap,wh,kernel_size=3,sigma=0.5):
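    """Reshape a flattened per-token map back to the 2D token grid, blur it with a small
    Gaussian kernel, and return it flattened again."""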
# to 2d
attnmap = rearrange(
attnmap,
"b (w h) -> b (w) (h)",
w=math.ceil(wh[0]/16),
h=math.ceil(wh[1]/16),
)
attnmap = attnmap.unsqueeze(1)
# prepare kernel
ax = torch.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1., device=attnmap.device)
xx, yy = torch.meshgrid(ax, ax, indexing='ij')
kernel = torch.exp(-(xx**2 + yy**2) / (2. * sigma**2))
kernel = kernel / kernel.sum()
kernel = kernel.view(1, 1, kernel_size, kernel_size)
kernel = kernel.to(dtype=attnmap.dtype)
# gaussian smooth
attnmap_smoothed = F.conv2d(attnmap, kernel, padding=kernel_size // 2)
return attnmap_smoothed.view(attnmap_smoothed.shape[0], -1)
def compute_attn_max_loss(attnmaps,source_mask,wh):
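    """Encourage the selected text tokens to attend somewhere inside the source mask:
    average attention over the token group, restrict it to the mask, smooth it, take the
    spatial max per layer, and penalize 1 - max."""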
# attnmaps L,1,N,k
attnmaps = attnmaps[selected_layers,0,:,:]
attnmaps = attnmaps.mean(dim=-1)
src_mask = source_mask.view(-1).unsqueeze(0)
p = attnmaps*src_mask
p = gaussian_smooth(p, wh, kernel_size=3, sigma=0.5)
p = p.max(dim=1).values
loss = (1 - p).mean()
return loss
def compute_attn_min_loss(attnmaps,source_mask,wh):
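    """Discourage the selected text tokens from attending inside the source mask: same
    pipeline as compute_attn_max_loss, but the smoothed spatial max itself is the loss."""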
# attnmaps L,1,N,k
attnmaps = attnmaps[selected_layers,0,:,:]
attnmaps = attnmaps.mean(dim=-1)
src_mask = source_mask.view(-1).unsqueeze(0)
p = attnmaps*src_mask
p = gaussian_smooth(p, wh, kernel_size=3, sigma=0.5)
p = p.max(dim=1).values
loss = p.mean()
return loss
def denoise_with_noise_optim(
model: Flux,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
# loss cal
token_ids: list[list[int]],
source_mask: Tensor,
training_steps: int,
training_epochs: int,
learning_rate: float,
seed: int,
noise_scale: float,
# sampling parameters
timesteps: list[float],
info,
guidance: float = 4.0
):
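    """For the first `training_steps` timesteps, optimize the latent so that attention
    from the target tokens concentrates inside `source_mask`. At step 0 the latent is
    optionally re-noised inside the mask (same blend as add_masked_noise_to_z); at each
    step it is wrapped in an nn.Parameter and updated for `training_epochs` Adam steps
    using compute_attn_max_loss / compute_attn_min_loss on the averaged attention maps,
    with gradients restricted to the masked region. Returns
    (img, info, step_list, None, trainable_noise_list); the optimized latents are meant
    to be passed back into denoise() via `trainable_noise_list`."""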
    # guidance below is ignored for schnell
#print(f'training the noise in last {training_steps} steps and {training_epochs} epochs')
#timesteps = timesteps[::-1]
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
step_list = []
trainable_noise_list = []
for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
if i >= training_steps:
break
# prepare ori parameters
ori_txt = txt.clone()
ori_img = img.clone()
ori_vec = vec.clone()
# prepare trainable noise
if i == 0:
if noise_scale == 0:
trainable_noise = torch.nn.Parameter(img.clone().detach(), requires_grad=True)
else:
noise = torch.randn(img.shape,device=img.device,dtype=img.dtype,generator=torch.Generator(device=img.device).manual_seed(seed))
noise = img*(1-source_mask[0])+ noise_scale*noise*source_mask[0] + (1-noise_scale)*img*source_mask[0]
trainable_noise = torch.nn.Parameter(noise.clone().detach(), requires_grad=True)
else:
trainable_noise = torch.nn.Parameter(img.clone().detach(), requires_grad=True)
optimizer = optim.Adam([trainable_noise], lr=learning_rate)
        # optimize the latent for `training_epochs` steps at this timestep
for j in range(training_epochs):
optimizer.zero_grad()
txt = ori_txt.clone().detach()
vec = ori_vec.clone().detach()
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
info['t'] = t_prev
info['inverse'] = False
info['second_order'] = False
info['inject'] = False # tried True, seems not necessary
pred, info, attn_maps_mid = model(
img=trainable_noise,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
info=info
)
img_mid = trainable_noise + (t_prev - t_curr) / 2 * pred
t_vec_mid = torch.full((img.shape[0],), (t_curr + (t_prev - t_curr) / 2), dtype=img.dtype, device=img.device)
info['second_order'] = True
pred_mid, info, attn_maps = model(
img=img_mid,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec_mid,
guidance=guidance_vec,
info=info
)
first_order = (pred_mid - pred) / ((t_prev - t_curr) / 2)
img = img + (t_prev - t_curr) * pred + 0.5 * (t_prev - t_curr) ** 2 * first_order
            # average the mid-step and full-step attention maps (shape L,1,N,512) for the loss
            attn_maps = (attn_maps_mid + attn_maps) / 2
total_loss = 0.0
for indices,change,ratio in token_ids:
if change:
total_loss += compute_attn_max_loss(attn_maps[:,:,:,indices], source_mask, info['wh'])
else:
if ratio != 0:
total_loss += ratio*compute_attn_min_loss(attn_maps[:,:,:,indices], source_mask, info['wh'])
total_loss.backward()
with torch.no_grad():
trainable_noise.grad *= source_mask[0]
optimizer.step()
print(f"Time {t_curr:.4f} Step {j+1}/{training_epochs}, Loss: {total_loss.item():.6f}")
del total_loss,attn_maps
torch.cuda.empty_cache()
torch.cuda.synchronize()
step_list.append(t_curr)
trainable_noise = trainable_noise.detach()
trainable_noise_list.append(trainable_noise.clone())
return img, info, step_list, None, trainable_noise_list
def unpack(x: Tensor, height: int, width: int) -> Tensor:
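    """Inverse of the packing in prepare(): reshape (b, h*w, c*ph*pw) tokens back into
    the spatial latent on the 16x-downsampled, 2x2-patched grid."""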
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)
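# Rough sketch of how these pieces appear to compose for mask-guided editing. The outer
# driver (model loading, VAE encode/decode, how the token-id triples are assembled) lives
# outside this module, so the steps below are an assumption, not a verified API:
#   1. encode the source image to a latent and pack it with prepare(t5, clip, z, source_prompt)
#   2. timesteps = get_schedule(num_steps, image_seq_len)
#   3. invert: z_inv, info, *_ = denoise(model, ..., timesteps, inverse=True, info=info)
#   4. src_ids, tgt_ids = prepare_tokens(t5, source_prompt, target_prompt, replacements);
#      build (indices, change_flag, ratio) triples from them, and load the edit region with
#      get_mask_one_tensor / get_v_mask
#   5. *_, noise_list = denoise_with_noise_optim(model, z_inv, ..., token_ids, source_mask, ...)
#   6. z_edit, *_ = denoise(model, z_inv, ..., inverse=False, trainable_noise_list=noise_list)
#   7. img = unpack(z_edit, height, width), then decode with the autoencoder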