FlowDIS / flowdis /sampling.py
AndranikSargsyan
Add FlowDIS inference and demo
a8a9bce
import math
import torch
import torchvision.transforms.functional as tvF
from einops import rearrange, repeat
from PIL import Image
from scipy import stats
from torch import Tensor
from flowdis.model import Flux
from flowdis.util import Models
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Fold a sequence of 2x2 patch tokens back into a spatial latent map.

    This undoes the ``"b c (h ph) (w pw) -> b (h w) (c ph pw)"`` packing used
    in ``prepare``: every token's ``c * 2 * 2`` features are scattered back
    into a 2x2 spatial patch.

    Args:
        x: Packed tokens of shape ``(b, h * w, c * 4)``, where
            ``h = ceil(height / 16)`` and ``w = ceil(width / 16)``.
        height: Pixel height of the image the latent corresponds to.
        width: Pixel width of the image the latent corresponds to.

    Returns:
        Tensor of shape ``(b, c, 2 * h, 2 * w)``.
    """
    patch_size = 2
    # Factor 16 presumably = 8x autoencoder downsampling * 2x2 patch packing
    # (TODO confirm against the autoencoder config).
    grid_rows = math.ceil(height / 16)
    grid_cols = math.ceil(width / 16)
    pattern = "b (h w) (c ph pw) -> b c (h ph) (w pw)"
    return rearrange(
        x,
        pattern,
        h=grid_rows,
        w=grid_cols,
        ph=patch_size,
        pw=patch_size,
    )
def beta_scheduler(num_timesteps: int, alpha: float = 2.5, beta: float = 1.0) -> list[float]:
    """Return a descending time schedule from 1 to 0 shaped by a Beta distribution.

    ``num_timesteps + 1`` evenly spaced quantiles from 1 down to 0 are mapped
    through the inverse CDF (``ppf``) of ``Beta(alpha, beta)``. With the
    defaults the map is ``q ** (1 / 2.5)``, so the times are denser near 1.

    Args:
        num_timesteps: Number of intervals (the schedule has one more point).
        alpha: First shape parameter of the Beta distribution.
        beta: Second shape parameter of the Beta distribution.

    Returns:
        List of Python floats starting at 1.0 and ending at 0.0. A trailing
        0.0 is appended if the last value is positive.
    """
    quantiles = torch.linspace(1, 0, num_timesteps + 1)
    schedule = stats.beta.ppf(quantiles, alpha, beta).tolist()
    # ppf(0) is already 0 for a Beta distribution; this only guarantees that
    # the integration always reaches t = 0.
    ends_above_zero = schedule[-1] > 0.0
    return schedule + [0.0] if ends_above_zero else schedule
def _broadcast_batch(t: Tensor, batch_size: int) -> Tensor:
    """Repeat a singleton leading (batch) dimension up to ``batch_size``; otherwise return ``t`` unchanged."""
    if t.shape[0] != 1 or batch_size <= 1:
        return t
    return repeat(t, "1 ... -> bs ...", bs=batch_size)


def prepare(
    img: Tensor,
    prompt: str | list[str],
    models: Models,
    device: str = "cuda"
) -> dict[str, Tensor]:
    """Build the transformer inputs for FlowDIS sampling from an image and prompt(s).

    The image is encoded with ``models.ae``. The latent is packed into 2x2
    patch tokens with matching position ids. The prompt(s) are embedded with
    ``models.t5`` and ``models.clip``.

    Args:
        img: 4-D image batch (unpacked as ``b, _, _, _``). Presumably
            ``(b, 3, H, W)`` scaled to ``[-1, 1]`` as produced by
            ``flowdis_predict`` (confirm for other callers).
        prompt: One prompt or a list of prompts. If ``img`` has batch size 1
            and a list is given, the batch size becomes ``len(prompt)``.
        models: Container exposing ``ae``, ``t5`` and ``clip``.
        device: Device the image is moved to (as bfloat16) before encoding.

    Returns:
        Dict with keys ``img`` (packed latent tokens), ``img_ids``, ``txt``,
        ``txt_ids`` and ``vec``, all on the latent's device. Any of
        ``img``/``txt``/``vec`` with batch size 1 is repeated to the batch size.
    """
    batch_size, _, _, _ = img.shape
    if isinstance(prompt, str):
        prompt = [prompt]
    elif batch_size == 1:
        batch_size = len(prompt)

    # Only the autoencoder call is explicitly gradient-free; the text encoders
    # follow the caller's grad mode (flowdis_predict is decorated with no_grad).
    with torch.no_grad():
        latent = models.ae.encode(img.to(device=device, dtype=torch.bfloat16))

    # One position id per 2x2 patch: channel 0 stays zero, channel 1 holds the
    # row index and channel 2 the column index.
    grid_h = latent.shape[2] // 2
    grid_w = latent.shape[3] // 2
    img_ids = torch.zeros(grid_h, grid_w, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(grid_h)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(grid_w)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)

    tokens = rearrange(latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    tokens = _broadcast_batch(tokens, batch_size)

    txt = _broadcast_batch(models.t5(prompt), batch_size)
    txt_ids = torch.zeros(batch_size, txt.shape[1], 3)
    vec = _broadcast_batch(models.clip(prompt), batch_size)

    target = tokens.device
    return {
        "img": tokens,
        "img_ids": img_ids.to(target),
        "txt": txt.to(target),
        "txt_ids": txt_ids.to(target),
        "vec": vec.to(target),
    }
def solve_flowdis_ode(
    model: Flux,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    num_inference_steps: int,
):
    """Integrate the FlowDIS flow with explicit Euler steps.

    The state starts as the packed image latent. At every step the state is
    concatenated with the fixed image latent along the last dimension. The
    model predicts a velocity at the current time, and the state is advanced
    to the next time of the ``beta_scheduler`` schedule, which runs from 1
    down to 0.

    Args:
        model: Transformer called with ``img``, ``img_ids``, ``txt``,
            ``txt_ids``, ``y`` and ``timesteps`` keyword arguments.
        img: Packed image latent; used both as the initial state and as the
            per-step conditioning.
        img_ids: Image position ids forwarded to ``model``.
        txt: Text embeddings forwarded to ``model``.
        txt_ids: Text position ids forwarded to ``model``.
        vec: Pooled embedding forwarded to ``model`` as ``y``.
        num_inference_steps: Number of intervals requested from ``beta_scheduler``.

    Returns:
        The final state (presumably the packed mask latent; same shape as ``img``
        assuming ``model`` returns a velocity of that shape).
    """
    schedule = beta_scheduler(num_inference_steps)
    state = img
    for t_now, t_next in zip(schedule, schedule[1:]):
        # One copy of the current time per batch element, in the state's dtype/device.
        t_batch = state.new_full((state.shape[0],), t_now)
        velocity = model(
            img=torch.cat((state, img), dim=-1),
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_batch,
        )
        state = state + (t_next - t_now) * velocity
    return state
@torch.no_grad()
def flowdis_predict(
    image: Image.Image,
    prompt: str | list[str],
    models: Models,
    resolution: int = 1024,
    num_inference_steps: int = 2,
    device: str = "cuda",
):
    """Predict a segmentation mask for a PIL image, guided by a text prompt.

    The image is converted to RGB and resized to ``resolution x resolution``.
    It is scaled to ``[-1, 1]``, encoded via ``prepare`` and integrated with
    ``solve_flowdis_ode``. The resulting latent is decoded by ``models.ae``,
    and the decoded image is averaged over channels into a grayscale mask.

    Args:
        image: Input PIL image in any mode (it is converted to RGB).
        prompt: Text prompt, or a list of prompts. Only the first batch
            element of the prediction is returned.
        models: Container exposing ``ae``, ``t5``, ``clip`` and ``transformer``.
        resolution: Square side length the image is resized to for inference.
            Must be compatible with the 2x2 latent packing in ``prepare``.
        num_inference_steps: Number of Euler steps for the ODE solve.
        device: Torch device string, e.g. ``"cuda"``, ``"cuda:1"`` or ``"cpu"``.

    Returns:
        PIL image of mode ``"L"`` with the same size as ``image``.
    """
    # Feed the RGB copy to the model. Inputs in other modes (RGBA, L, P, ...)
    # would otherwise reach the autoencoder with the wrong channel count.
    image_rgb = image.convert("RGB")
    orig_size = image_rgb.size
    model_input = image_rgb.resize((resolution, resolution))
    image_t = tvF.to_tensor(model_input).unsqueeze(0).to(device=device)
    image_t = (image_t - 0.5) / 0.5  # [0, 1] -> [-1, 1]

    inp = prepare(image_t, prompt, models, device)
    pred_mask_latent_t = solve_flowdis_ode(
        models.transformer,
        **inp,
        num_inference_steps=num_inference_steps,
    )
    pred_mask_latent_t = unpack(pred_mask_latent_t.float(), resolution, resolution)

    # autocast expects a bare device type ("cuda"), not an indexed device
    # string such as "cuda:0".
    device_type = torch.device(device).type
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        pred_mask_t = models.ae.decode(pred_mask_latent_t).clamp(-1, 1)

    # Work in float32 so the [-1, 1] -> [0, 255] mapping is not quantised by
    # bfloat16. Round instead of truncating when casting to uint8.
    pred_mask_t = rearrange(pred_mask_t[0].float(), "c h w -> h w c")
    pred_mask_np = (127.5 * (pred_mask_t + 1.0)).mean(dim=-1).round().cpu().byte().numpy()
    pred_mask = Image.fromarray(pred_mask_np).convert("L")
    return pred_mask.resize(orig_size)