| import math
|
|
|
| import torch
|
| import torchvision
|
| from einops import rearrange
|
| from PIL import Image
|
| from torch import Tensor
|
|
|
|
|
|
|
|
|
def compress_time(t_ids: Tensor) -> Tensor:
    """
    Replace every time id by its rank among the distinct ids that occur.

    The smallest id present becomes 0, the next larger one 1, and so on, so the
    result is a gap-free relabelling that keeps the relative order of the ids.

    Args:
        t_ids: 1-D tensor of time ids.

    Returns:
        Tensor with the same shape, dtype and device as `t_ids` holding the
        compressed ids.
    """
    assert t_ids.ndim == 1
    if t_ids.numel() == 0:
        # Nothing to rank; hand back an empty tensor of the same dtype/device.
        return t_ids.clone()
    # The inverse indices of a sorted `unique` are exactly the ranks we want.
    # Unlike a lookup table of size max(t_ids) + 1, this needs no memory
    # proportional to the largest id and stays correct for negative ids.
    _, ranks = torch.unique(t_ids, sorted=True, return_inverse=True)
    return ranks.to(t_ids.dtype)
|
|
|
|
|
def scatter_ids(x: Tensor, x_ids: Tensor) -> list[Tensor]:
    """
    Scatter each sample's tokens into a dense grid using its position ids.

    Args:
        x: per-sample token tensors, each of shape (n_tokens, ch); iterated
            along the first dimension.
        x_ids: matching per-sample position ids, each of shape (n_tokens, >=3),
            where columns 0, 1 and 2 are read as the t, h and w coordinates.

    Returns:
        One tensor per sample with shape (1, ch, t, h, w). `t` is the number of
        distinct time ids of that sample (time ids are compressed to a gap-free
        range first); `h` and `w` are the largest h/w id plus one. Grid cells
        that no token points at stay zero.
    """
    x_list = []
    for data, pos in zip(x, x_ids):
        _, ch = data.shape
        t_ids = pos[:, 0].to(torch.int64)
        h_ids = pos[:, 1].to(torch.int64)
        w_ids = pos[:, 2].to(torch.int64)

        t_ids_cmpr = compress_time(t_ids)

        # Plain Python ints for the grid extents, used below as shape arguments.
        t = int(t_ids_cmpr.max()) + 1
        h = int(h_ids.max()) + 1
        w = int(w_ids.max()) + 1

        # Row-major (t, h, w) linear index of every token.
        flat_ids = t_ids_cmpr * (h * w) + h_ids * w + w_ids

        out = torch.zeros((t * h * w, ch), device=data.device, dtype=data.dtype)
        # scatter_ wants an index of the same shape as the source, hence the expand.
        out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)

        x_list.append(rearrange(out, "(t h w) c -> 1 c t h w", t=t, h=h, w=w))
    return x_list
|
|
|
|
|
def encode_image_refs(ae, img_ctx: list[Image.Image]):
    """
    Encode reference images into a single token sequence with position ids.

    Args:
        ae: object exposing `encode(batch)`. It is called once per image with a
            batch of size 1 that has been moved to CUDA, and element 0 of the
            result is used.
        img_ctx: the reference images. An empty list (or None) means there are
            no references.

    Returns:
        `(None, None)` when there are no references. Otherwise
        `(ref_tokens, ref_ids)`: the tokens and position ids of all references
        concatenated along dim 0 and given a leading dimension of 1, with
        `ref_tokens` cast to bfloat16.
    """
    # Guard first: nothing below makes sense without at least one image, and
    # this keeps `len()` from being called on a missing context.
    if not img_ctx:
        return None, None

    scale = 10

    # A single reference gets a larger pixel budget than several references.
    # NOTE(review): 2024**2 (rather than 2048**2) is kept as found - confirm it is intended.
    limit_pixels = 2024**2 if len(img_ctx) == 1 else 1024**2

    img_ctx_prep = default_prep(img=img_ctx, limit_pixels=limit_pixels)
    if not isinstance(img_ctx_prep, list):
        img_ctx_prep = [img_ctx_prep]

    # Encode one image at a time; `img[None]` adds the batch dimension.
    encoded_refs = [ae.encode(img[None].cuda())[0] for img in img_ctx_prep]

    # Reference i gets the time coordinate scale * (i + 1), i.e. 10, 20, 30, ...
    # as a 1-element int64 tensor.
    t_off = [torch.tensor([scale * (i + 1)]) for i in range(len(encoded_refs))]

    ref_tokens, ref_ids = listed_prc_img(encoded_refs, t_coord=t_off)

    # Merge all references into one sequence, then add the batch dimension.
    ref_tokens = torch.cat(ref_tokens, dim=0).unsqueeze(0)
    ref_ids = torch.cat(ref_ids, dim=0).unsqueeze(0)

    return ref_tokens.to(torch.bfloat16), ref_ids
|
|
|
|
|
| def prc_txt(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
|
| _l, _ = x.shape
|
|
|
| coords = {
|
| "t": torch.arange(1) if t_coord is None else t_coord,
|
| "h": torch.arange(1),
|
| "w": torch.arange(1),
|
| "l": torch.arange(_l),
|
| }
|
| x_ids = torch.cartesian_prod(coords["t"], coords["h"], coords["w"], coords["l"])
|
| return x, x_ids.to(x.device)
|
|
|
|
|
| def batched_wrapper(fn):
|
| def batched_prc(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
|
| results = []
|
| for i in range(len(x)):
|
| results.append(
|
| fn(
|
| x[i],
|
| t_coord[i] if t_coord is not None else None,
|
| )
|
| )
|
| x, x_ids = zip(*results)
|
| return torch.stack(x), torch.stack(x_ids)
|
|
|
| return batched_prc
|
|
|
|
|
| def listed_wrapper(fn):
|
| def listed_prc(
|
| x: list[Tensor],
|
| t_coord: list[Tensor] | None = None,
|
| ) -> tuple[list[Tensor], list[Tensor]]:
|
| results = []
|
| for i in range(len(x)):
|
| results.append(
|
| fn(
|
| x[i],
|
| t_coord[i] if t_coord is not None else None,
|
| )
|
| )
|
| x, x_ids = zip(*results)
|
| return list(x), list(x_ids)
|
|
|
| return listed_prc
|
|
|
|
|
def prc_img(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
    """
    Flatten a (c, h, w) tensor into tokens and attach (t, h, w, l) position ids.

    The t axis is `t_coord` when given and a single 0 otherwise, h and w
    enumerate the rows and columns of `x`, and l is a single 0. The ids are the
    cartesian product of those four axes.

    Returns:
        The tokens with shape (h * w, c), and the ids moved to the tokens' device.
    """
    _, height, width = x.shape

    t_axis = torch.arange(1) if t_coord is None else t_coord
    grid_ids = torch.cartesian_prod(
        t_axis,
        torch.arange(height),
        torch.arange(width),
        torch.arange(1),
    )

    tokens = rearrange(x, "c h w -> (h w) c")
    return tokens, grid_ids.to(tokens.device)
|
|
|
|
|
# Collection-level variants of the per-sample preprocessors defined in this file:
#   listed_*  - apply the function to every element and return the results as lists
#   batched_* - apply the function to every element and torch.stack the results
listed_prc_img = listed_wrapper(prc_img)
batched_prc_img = batched_wrapper(prc_img)
batched_prc_txt = batched_wrapper(prc_txt)
|
|
|
|
|
def center_crop_to_multiple_of_x(
    img: Image.Image | list[Image.Image], x: int
) -> Image.Image | list[Image.Image]:
    """
    Center-crop an image so that both sides become multiples of `x`.

    Each side is reduced to the largest multiple of `x` that does not exceed
    it; the trimmed pixels are split as evenly as possible between the two
    edges. A side shorter than `x` therefore ends up with length 0. A list of
    images is processed element by element.
    """
    if isinstance(img, list):
        return [center_crop_to_multiple_of_x(item, x) for item in img]

    width, height = img.size
    crop_w = width - width % x
    crop_h = height - height % x

    left = (width - crop_w) // 2
    top = (height - crop_h) // 2
    return img.crop((left, top, left + crop_w, top + crop_h))
|
|
|
|
|
def cap_pixels(img: Image.Image | list[Image.Image], k):
    """
    Downscale an image so that it has at most roughly `k` pixels.

    Images with `width * height <= k` are returned untouched. Larger ones are
    resized with LANCZOS resampling; both sides are multiplied by
    sqrt(k / pixel_count) and truncated to int, which keeps the aspect ratio up
    to that truncation. A list of images is processed element by element.
    """
    if isinstance(img, list):
        return [cap_pixels(single, k) for single in img]

    width, height = img.size
    area = width * height
    if area > k:
        shrink = math.sqrt(k / area)
        target_size = (int(width * shrink), int(height * shrink))
        return img.resize(target_size, Image.Resampling.LANCZOS)
    return img
|
|
|
|
|
def cap_min_pixels(img: Image.Image | list[Image.Image], max_ar=8, min_sidelength=64):
    """
    Validate image size and aspect ratio; the image itself is never modified.

    Raises:
        ValueError: if either side is shorter than `min_sidelength`, or if
            width/height or height/width exceeds `max_ar`.

    Returns:
        The input image unchanged (or a list of the unchanged inputs).
    """
    if isinstance(img, list):
        return [
            cap_min_pixels(single, max_ar=max_ar, min_sidelength=min_sidelength)
            for single in img
        ]

    w, h = img.size

    too_small = w < min_sidelength or h < min_sidelength
    if too_small:
        raise ValueError(f"Skipping due to minimal sidelength underschritten h {h} w {w}")

    too_elongated = w / h > max_ar or h / w > max_ar
    if too_elongated:
        raise ValueError(f"Skipping due to maximal ar overschritten h {h} w {w}")

    return img
|
|
|
|
|
def to_rgb(img: Image.Image | list[Image.Image]):
    """Convert an image, or every image in a list, to RGB mode via `convert("RGB")`."""
    if not isinstance(img, list):
        return img.convert("RGB")
    return [to_rgb(entry) for entry in img]
|
|
|
|
|
def default_images_prep(
    x: Image.Image | list[Image.Image],
) -> torch.Tensor | list[torch.Tensor]:
    """
    Convert an image (or each image of a list) to a tensor and rescale it.

    The image goes through torchvision's `ToTensor` and is then mapped with
    `2 * v - 1` (i.e. [0, 1] -> [-1, 1], assuming ToTensor's usual [0, 1]
    output range).
    """
    if not isinstance(x, list):
        pixels = torchvision.transforms.ToTensor()(x)
        return 2 * pixels - 1
    return [default_images_prep(item) for item in x]
|
|
|
|
|
def default_prep(
    img: Image.Image | list[Image.Image], limit_pixels: int | None, ensure_multiple: int = 16
) -> torch.Tensor | list[torch.Tensor]:
    """
    Run the standard image preprocessing chain on an image or a list of images.

    Steps, in order: convert to RGB, validate minimum side length / aspect
    ratio, optionally cap the pixel count at `limit_pixels` (skipped when it is
    None), center-crop both sides to a multiple of `ensure_multiple`, and
    convert to a rescaled tensor.
    """
    prepared = cap_min_pixels(to_rgb(img))
    if limit_pixels is not None:
        prepared = cap_pixels(prepared, limit_pixels)
    prepared = center_crop_to_multiple_of_x(prepared, ensure_multiple)
    return default_images_prep(prepared)
|
|
|
|
|
def generalized_time_snr_shift(t: Tensor, mu: float, sigma: float) -> Tensor:
    """
    Apply the shift `exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)` elementwise to `t`.

    `mu` and `sigma` are scalars; `t` is a tensor of timesteps.
    """
    exp_mu = math.exp(mu)
    inverse_odds = (1 / t - 1) ** sigma
    return exp_mu / (exp_mu + inverse_odds)
|
|
|
|
|
def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
    """
    Build the list of `num_steps + 1` shifted timesteps for sampling.

    A linear grid from 1 down to 0 is passed through
    `generalized_time_snr_shift` with sigma = 1.0 and the mu returned by
    `compute_empirical_mu` for this sequence length and step count.
    """
    shift_mu = compute_empirical_mu(image_seq_len, num_steps)
    linear_grid = torch.linspace(1, 0, num_steps + 1)
    shifted = generalized_time_snr_shift(linear_grid, shift_mu, 1.0)
    return shifted.tolist()
|
|
|
|
|
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """
    Compute the schedule shift parameter mu from sequence length and step count.

    Two linear fits in `image_seq_len` give mu at 200 steps and at 10 steps.
    For sequences longer than 4300 tokens the 200-step fit is returned directly
    and `num_steps` is ignored; otherwise mu is taken from the straight line
    through those two points, evaluated at `num_steps`.
    """
    slope_10, intercept_10 = 8.73809524e-05, 1.89833333
    slope_200, intercept_200 = 0.00016927, 0.45666666

    mu_at_200 = slope_200 * image_seq_len + intercept_200
    if image_seq_len > 4300:
        return float(mu_at_200)

    mu_at_10 = slope_10 * image_seq_len + intercept_10

    # Line through (10, mu_at_10) and (200, mu_at_200), evaluated at num_steps.
    per_step = (mu_at_200 - mu_at_10) / 190.0
    offset = mu_at_200 - 200.0 * per_step
    return float(per_step * num_steps + offset)
|
|
|
|
|
|
|
def concatenate_images(
    images: list[Image.Image],
) -> Image.Image:
    """
    Place the given PIL images side by side on a white RGB canvas.

    The canvas is as wide as all images together and as tall as the tallest
    one; shorter images are centered vertically. Non-RGB images are converted
    to RGB first. A single image is returned as a plain copy, in its original
    mode.
    """
    if len(images) == 1:
        return images[0].copy()

    rgb_images = [im if im.mode == "RGB" else im.convert("RGB") for im in images]

    canvas_w = sum(im.width for im in rgb_images)
    canvas_h = max(im.height for im in rgb_images)
    canvas = Image.new("RGB", (canvas_w, canvas_h), (255, 255, 255))

    # Walk left to right, pasting each image at the running x position.
    cursor = 0
    for im in rgb_images:
        canvas.paste(im, (cursor, (canvas_h - im.height) // 2))
        cursor += im.width

    return canvas
|
|
|