import math
import torch
import torchvision.transforms.functional as tvF
from einops import rearrange, repeat
from PIL import Image
from scipy import stats
from torch import Tensor

from flowdis.model import Flux
from flowdis.util import Models

def unpack(x: Tensor, height: int, width: int) -> Tensor:
    # Reassemble the flat sequence of 2x2 latent patches back into a
    # spatial latent of shape (b, c, h, w).
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )
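
# Shape walk-through (illustrative, assuming the Flux autoencoder's 16-channel,
# 8x-downsampled latents): a 1024x1024 image encodes to a (1, 16, 128, 128)
# latent, packs to (1, 4096, 64) in `prepare`, and unpack(x, 1024, 1024)
# restores (1, 16, 128, 128).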

def beta_scheduler(num_timesteps: int, alpha: float = 2.5, beta: float = 1.0) -> list[float]:
    # Non-uniform timestep schedule: map evenly spaced quantiles through the
    # inverse CDF of a Beta(alpha, beta) distribution; with alpha > beta this
    # clusters timesteps near t = 1.
    q = torch.linspace(1, 0, num_timesteps + 1)
    steps = stats.beta.ppf(q, alpha, beta).tolist()
    if steps[-1] > 0.0:
        steps.append(0.0)
    return steps
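
# Illustrative values (not part of the original file): for Beta(2.5, 1.0) the
# inverse CDF is q ** 0.4, so beta_scheduler(2) returns roughly
# [1.0, 0.758, 0.0], i.e. two Euler intervals from t = 1 down to t = 0.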

def prepare(
    img: Tensor,
    prompt: str | list[str],
    models: Models,
    device: str = "cuda",
) -> dict[str, Tensor]:
    # Encode the conditioning image and text prompt(s) into the packed
    # latent/embedding format expected by the Flux transformer.
    bs, _, _, _ = img.shape
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)
    if isinstance(prompt, str):
        prompt = [prompt]

    # Encode the image to latents with the autoencoder.
    with torch.no_grad():
        img = models.ae.encode(img.to(device=device, dtype=torch.bfloat16))

    # Positional ids for each 2x2 latent patch: row index in channel 1,
    # column index in channel 2.
    h, w = img.shape[2], img.shape[3]
    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    # Pack 2x2 latent patches into a token sequence and broadcast to the batch size.
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    # T5 text embeddings and their (all-zero) positional ids.
    txt = models.t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    # Pooled CLIP embedding used as the conditioning vector.
    vec = models.clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return {
        "img": img,
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "vec": vec.to(img.device),
    }

def solve_flowdis_ode(
    model: Flux,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    num_inference_steps: int,
) -> Tensor:
    # Integrate the learned flow with explicit Euler steps over the
    # Beta-scheduled timesteps, starting from the packed image latents and
    # conditioning each step on them via channel-wise concatenation.
    zt = img
    timesteps = beta_scheduler(num_inference_steps)
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t_vec = torch.full((zt.shape[0],), t_curr, dtype=zt.dtype, device=zt.device)
        pred = model(
            img=torch.cat((zt, img), dim=-1),
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
        )
        zt = zt + (t_prev - t_curr) * pred
    return zt

@torch.no_grad()
def flowdis_predict(
    image: Image.Image,
    prompt: str | list[str],
    models: Models,
    resolution: int = 1024,
    num_inference_steps: int = 2,
    device: str = "cuda",
) -> Image.Image:
    # Keep the original image so the predicted mask can be resized back to the
    # input resolution at the end.
    image_orig = image.convert("RGB")
    image = image_orig.resize((resolution, resolution))

    # To a (1, 3, H, W) tensor normalized to [-1, 1].
    image_t = tvF.to_tensor(image).unsqueeze(0).to(device=device)
    image_t = (image_t - 0.5) / 0.5

    # Encode the image and prompt, then solve the flow ODE in latent space.
    inp = prepare(image_t, prompt, models, device)
    pred_mask_latent_t = solve_flowdis_ode(
        models.transformer,
        **inp,
        num_inference_steps=num_inference_steps,
    )

    # Unpack the latent tokens and decode them back to pixel space.
    pred_mask_latent_t = unpack(pred_mask_latent_t.float(), resolution, resolution)
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        pred_mask_t = models.ae.decode(pred_mask_latent_t).clamp(-1, 1)

    # Average the channels to a single-channel PIL mask at the original size.
    pred_mask_t = rearrange(pred_mask_t[0], "c h w -> h w c")
    pred_mask_np = (127.5 * (pred_mask_t + 1.0)).mean(dim=-1).cpu().byte().numpy()
    pred_mask = Image.fromarray(pred_mask_np).convert("L")
    pred_mask = pred_mask.resize(image_orig.size)
    return pred_mask
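
# Minimal usage sketch (illustrative; `load_models` is a hypothetical helper
# standing in for however the app builds the `Models` container):
#
#   models = load_models(device="cuda")
#   mask = flowdis_predict(Image.open("input.jpg"), "the salient object", models)
#   mask.save("mask.png")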