import math

import torch
import torchvision.transforms.functional as tvF
from einops import rearrange, repeat
from PIL import Image
from scipy import stats
from torch import Tensor

from flowdis.model import Flux
from flowdis.util import Models

def unpack(x: Tensor, height: int, width: int) -> Tensor:
    # Unpack the (b, seq, packed) token sequence back into a spatial latent
    # (b, c, h, w); 16 = autoencoder downsample factor (8) x patch size (2).
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )
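
# Shape sketch (assuming a packed feature dim of 64, i.e. 16 latent channels in
# 2x2 patches -- an assumption about the autoencoder, not read from this snippet):
# unpack(x, 1024, 1024) with x of shape (1, 4096, 64) returns (1, 16, 128, 128).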

def beta_scheduler(num_timesteps: int, alpha: float = 2.5, beta: float = 1.0) -> list[float]:
    # Map uniformly spaced quantiles through the inverse Beta CDF to get a
    # decreasing list of timesteps from 1.0 down to 0.0.
    q = torch.linspace(1, 0, num_timesteps + 1)
    steps = stats.beta.ppf(q, alpha, beta).tolist()
    if steps[-1] > 0.0:
        steps.append(0.0)
    return steps
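
# Worked example: for beta == 1.0 the Beta(alpha, 1) CDF is x**alpha, so
# stats.beta.ppf(q, alpha, 1.0) == q ** (1 / alpha). beta_scheduler(4) therefore
# returns roughly [1.0, 0.891, 0.758, 0.574, 0.0]: steps are densely spaced near
# t = 1 and take a larger final jump down to t = 0.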

def prepare(
    img: Tensor,
    prompt: str | list[str],
    models: Models,
    device: str = "cuda",
) -> dict[str, Tensor]:
    bs, _, _, _ = img.shape
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)
    if isinstance(prompt, str):
        prompt = [prompt]

    # encode the conditioning image into the autoencoder latent space
    with torch.no_grad():
        img = models.ae.encode(img.to(device=device, dtype=torch.bfloat16))

    # positional ids for the 2x2 latent patches
    h, w = img.shape[2], img.shape[3]
    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    # pack the latent into a sequence of 2x2 patch tokens
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    # text conditioning: T5 token features and a pooled CLIP embedding
    txt = models.t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = models.clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return_dict = {
        "img": img,
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "vec": vec.to(img.device),
    }
    return return_dict
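
# Shape sketch for a single 1024x1024 image and one prompt (the text widths are
# assumptions about the T5/CLIP encoders used by flowdis, not read from this snippet):
#   img:     (1, 4096, 64)   -- 64x64 grid of packed 2x2 latent patches
#   img_ids: (1, 4096, 3)
#   txt:     (1, L, 4096)    -- L = T5 sequence length, e.g. 512
#   txt_ids: (1, L, 3)
#   vec:     (1, 768)        -- pooled CLIP embedding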

def solve_flowdis_ode(
    model: Flux,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    num_inference_steps: int,
):
    # Integrate the flow ODE with explicit Euler steps over the beta-scheduled
    # timesteps (1 -> 0), starting from the encoded image and conditioning on it
    # by concatenating it to the current state along the feature dimension.
    zt = img
    timesteps = beta_scheduler(num_inference_steps)
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t_vec = torch.full((zt.shape[0],), t_curr, dtype=zt.dtype, device=zt.device)
        pred = model(
            img=torch.cat((zt, img), dim=-1),
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
        )
        zt = zt + (t_prev - t_curr) * pred
    return zt
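
# Example trajectory with num_inference_steps=2: beta_scheduler(2) yields
# approximately [1.0, 0.758, 0.0], so the loop performs two Euler updates,
# t = 1.0 -> 0.758 and t = 0.758 -> 0.0.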

def flowdis_predict(
    image: Image.Image,
    prompt: str | list[str],
    models: Models,
    resolution: int = 1024,
    num_inference_steps: int = 2,
    device: str = "cuda",
) -> Image.Image:
    # resize the input to the working resolution and normalize to [-1, 1]
    image_orig = image.convert("RGB")
    image = image_orig.resize((resolution, resolution))
    image_t = tvF.to_tensor(image).unsqueeze(0).to(device=device)
    image_t = (image_t - 0.5) / 0.5

    # encode conditioning, solve the ODE, then unpack and decode the mask latent
    inp = prepare(image_t, prompt, models, device)
    pred_mask_latent_t = solve_flowdis_ode(
        models.transformer,
        **inp,
        num_inference_steps=num_inference_steps,
    )
    pred_mask_latent_t = unpack(pred_mask_latent_t.float(), resolution, resolution)
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        pred_mask_t = models.ae.decode(pred_mask_latent_t).clamp(-1, 1)

    # map [-1, 1] to [0, 255], average the channels, and resize back to the input size
    pred_mask_t = rearrange(pred_mask_t[0], "c h w -> h w c")
    pred_mask_np = (127.5 * (pred_mask_t + 1.0)).mean(dim=-1).cpu().byte().numpy()
    pred_mask = Image.fromarray(pred_mask_np).convert("L")
    pred_mask = pred_mask.resize(image_orig.size)
    return pred_mask
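
# Minimal usage sketch (how a `Models` instance is constructed/loaded by
# flowdis.util is not shown in this snippet, so model loading is left as a
# placeholder):
#
#     models = ...  # build/load a Models instance (ae, t5, clip, transformer)
#     image = Image.open("example.jpg")
#     mask = flowdis_predict(image, "the foreground object", models, num_inference_steps=2)
#     mask.save("predicted_mask.png")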