| import sys |
| from pathlib import Path |
| from typing import List, Optional, Sequence |
|
|
| import torch |
| import torch.nn as nn |
| from PIL import Image |
| from einops import rearrange |
| from PIL import Image |
| from torch import Tensor |
| import torchvision |
| import math |
|
|
| from shared.utils.utils import convert_image_to_tensor |
|
|
|
|
|
|
|
|
|
|
def compress_time(t_ids: Tensor) -> Tensor:
    """Remap time ids to dense, order-preserving ranks 0..K-1.

    Each element is replaced by its rank among the distinct values of ``t_ids``,
    e.g. ``[10, 30, 10, 20] -> [0, 2, 0, 1]``. The output has the same shape and
    dtype as the input.

    Raises:
        ValueError: if ``t_ids`` is not 1-D.
    """
    if t_ids.ndim != 1:
        raise ValueError(
            f"compress_time expects a 1-D tensor, got shape {tuple(t_ids.shape)}"
        )
    # `return_inverse` yields, for every element, its index into the sorted unique
    # values -- exactly the dense rank. Unlike a (max + 1)-sized lookup table this
    # is correct for negative ids and does not allocate memory proportional to
    # the largest id.
    _, inverse = torch.unique(t_ids, sorted=True, return_inverse=True)
    return inverse.to(t_ids.dtype)
|
|
| from einops import rearrange |
| from torch import Tensor |
|
|
def center_crop_to_multiple_of_x(
    img: Image.Image | list[Image.Image], x: int
) -> Image.Image | list[Image.Image]:
    """Center-crop an image (or each image of a list) so both sides are multiples of ``x``.

    Each side shrinks to the largest multiple of ``x`` that does not exceed it.
    The surplus is split between the two edges. When the surplus is odd, the
    extra pixel is taken from the right/bottom edge. Lists are handled
    element-wise and a new list is returned.
    """
    if isinstance(img, list):
        return [center_crop_to_multiple_of_x(item, x) for item in img]

    width, height = img.size
    # Python's divmod identity makes `n - n % x` equal to `(n // x) * x`.
    crop_w = width - width % x
    crop_h = height - height % x

    left = (width - crop_w) // 2
    top = (height - crop_h) // 2
    box = (left, top, left + crop_w, top + crop_h)
    return img.crop(box)
|
|
def cap_pixels(img: Image.Image | list[Image.Image], k: int):
    """Downscale an image (or each image of a list) to about at most ``k`` pixels.

    An image with ``w * h <= k`` is returned unchanged (the very same object).

    A larger image is resized with LANCZOS by the uniform factor
    ``sqrt(k / (w * h))``. This preserves aspect ratio up to integer flooring.

    Each target side is clamped to at least 1 pixel. For extreme aspect ratios,
    flooring could otherwise yield a zero-sized, degenerate target. When a side
    is clamped, the result may slightly exceed ``k`` pixels.

    Lists are processed element-wise and returned as a new list.
    """
    if isinstance(img, list):
        return [cap_pixels(_img, k) for _img in img]

    w, h = img.size
    pixel_count = w * h
    if pixel_count <= k:
        return img

    scale = math.sqrt(k / pixel_count)
    # Flooring keeps the product <= k; the max(1, ...) guard prevents a 0-px side.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    return img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
|
|
|
def cap_min_pixels(img: Image.Image | list[Image.Image], max_ar=8, min_sidelength=64):
    """Validate that an image (or every image of a list) is large and not too elongated.

    Despite the name, nothing is resized. Valid images are returned unchanged.
    Lists are checked element-wise and returned as a new list.

    Raises:
        ValueError: if either side is shorter than ``min_sidelength``, or if
            width/height or height/width exceeds ``max_ar``.
    """
    if not isinstance(img, list):
        w, h = img.size
        too_small = w < min_sidelength or h < min_sidelength
        if too_small:
            raise ValueError(f"Skipping due to minimal sidelength underschritten h {h} w {w}")
        too_elongated = w / h > max_ar or h / w > max_ar
        if too_elongated:
            raise ValueError(f"Skipping due to maximal ar overschritten h {h} w {w}")
        return img

    return [
        cap_min_pixels(item, max_ar=max_ar, min_sidelength=min_sidelength)
        for item in img
    ]
|
|
|
|
def to_rgb(img: Image.Image | list[Image.Image]):
    """Convert an image, or recursively each image of a list, to ``"RGB"`` mode."""
    if not isinstance(img, list):
        return img.convert("RGB")
    return list(map(to_rgb, img))
|
|
|
|
def default_images_prep(
    x: Image.Image | list[Image.Image],
) -> torch.Tensor | list[torch.Tensor]:
    """Turn an image (or each image of a list) into a tensor rescaled to [-1, 1].

    Per the torchvision docs, ``ToTensor`` maps an 8-bit PIL image to a float
    (C, H, W) tensor in [0, 1]. The affine map ``t * 2 - 1`` then moves that
    range to [-1, 1]. Lists are handled recursively.
    """
    if not isinstance(x, list):
        to_tensor = torchvision.transforms.ToTensor()
        unit_range = to_tensor(x)
        return unit_range * 2 - 1
    return [default_images_prep(item) for item in x]
|
|
def default_prep(
    img: Image.Image | list[Image.Image], limit_pixels: int | None, ensure_multiple: int = 16
) -> torch.Tensor | list[torch.Tensor]:
    """Run the standard image preprocessing pipeline on an image or list of images.

    Steps, in order:
      1. convert to RGB;
      2. validate minimum side length / aspect ratio (raises ValueError);
      3. if ``limit_pixels`` is given, downscale to about that many pixels;
      4. center-crop both sides to a multiple of ``ensure_multiple``;
      5. convert to a tensor in [-1, 1].
    """
    prepared = cap_min_pixels(to_rgb(img))
    if limit_pixels is not None:
        prepared = cap_pixels(prepared, limit_pixels)
    prepared = center_crop_to_multiple_of_x(prepared, ensure_multiple)
    return default_images_prep(prepared)
|
|
|
|
def encode_image_refs(ae, img_ctx: list[Image.Image], device: torch.device | str = "cuda"):
    """Encode reference images into one token sequence plus matching position ids.

    Args:
        ae: encoder object. Only ``ae.encode(batch)`` is called, and element
            ``[0]`` of its result is kept per image. That element is unpacked as
            (C, H, W) by ``prc_img``, so it must be 3-D.
        img_ctx: reference images. Several refs get a pixel budget of 1024**2
            each; a single ref gets 2024**2.
        device: device each preprocessed image is moved to before encoding.
            Defaults to "cuda", matching the previously hard-coded ``.cuda()``.

    Returns:
        ``(tokens, ids)``:
          - ``tokens`` is the concatenation of every ref's tokens with a
            leading batch dim of 1, cast to bfloat16.
          - ``ids`` holds the matching 4-column (t, h, w, l) position ids,
            also with a leading batch dim of 1.
        Returns ``(None, None)`` when ``img_ctx`` is empty.
    """
    if not img_ctx:
        return None, None

    # Tighter per-image budget when several references share the context.
    # NOTE(review): 2024**2 (not 2048**2) -- kept as-is; confirm it is intentional.
    limit_pixels = 1024**2 if len(img_ctx) > 1 else 2024**2

    img_ctx_prep = default_prep(img=img_ctx, limit_pixels=limit_pixels)
    if not isinstance(img_ctx_prep, list):
        img_ctx_prep = [img_ctx_prep]

    # Encode one image at a time: after per-image capping/cropping the refs can
    # have different spatial sizes, so they cannot be stacked into one batch.
    encoded_refs = [ae.encode(img[None].to(device))[0] for img in img_ctx_prep]

    # Ref i gets the single time coordinate scale * (i + 1), i.e. 10, 20, 30, ...
    scale = 10
    t_off = [torch.tensor([scale * (i + 1)]) for i in range(len(encoded_refs))]

    ref_tokens, ref_ids = listed_prc_img(encoded_refs, t_coord=t_off)

    # Join all refs along the token axis, then add a batch dim of 1.
    ref_tokens = torch.cat(ref_tokens, dim=0).unsqueeze(0)
    ref_ids = torch.cat(ref_ids, dim=0).unsqueeze(0)

    return ref_tokens.to(torch.bfloat16), ref_ids
| def listed_wrapper(fn): |
| def listed_prc( |
| x: list[Tensor], |
| t_coord: list[Tensor] | None = None, |
| ) -> tuple[list[Tensor], list[Tensor]]: |
| results = [] |
| for i in range(len(x)): |
| results.append( |
| fn( |
| x[i], |
| t_coord[i] if t_coord is not None else None, |
| ) |
| ) |
| x, x_ids = zip(*results) |
| return list(x), list(x_ids) |
|
|
| return listed_prc |
|
|
| def prc_txt(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]: |
| _l, _ = x.shape |
|
|
| coords = { |
| "t": torch.arange(1) if t_coord is None else t_coord, |
| "h": torch.arange(1), |
| "w": torch.arange(1), |
| "l": torch.arange(_l), |
| } |
| x_ids = torch.cartesian_prod(coords["t"], coords["h"], coords["w"], coords["l"]) |
| return x, x_ids.to(x.device) |
|
|
| def prc_img(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]: |
| _, h, w = x.shape |
| x_coords = { |
| "t": torch.arange(1) if t_coord is None else t_coord, |
| "h": torch.arange(h), |
| "w": torch.arange(w), |
| "l": torch.arange(1), |
| } |
| x_ids = torch.cartesian_prod(x_coords["t"], x_coords["h"], x_coords["w"], x_coords["l"]) |
| x = rearrange(x, "c h w -> (h w) c") |
| return x, x_ids.to(x.device) |
|
|
def batched_wrapper(fn):
    """Lift a per-sample ``fn(x_i, t_i) -> (tokens, ids)`` to a stacked, batched version.

    The returned ``batched_prc(x, t_coord=None)`` behaves as follows:
      - It calls ``fn(x[i], t_coord[i])`` for each index of ``x``, passing
        ``None`` as the coordinate when ``t_coord`` is None.
      - It stacks the tokens and the ids into two tensors along a new leading
        dimension.
      - Every sample must therefore produce equal-shape outputs.
    """
    def batched_prc(x, t_coord = None):
        per_sample = [
            fn(x[idx], None if t_coord is None else t_coord[idx])
            for idx in range(len(x))
        ]
        tokens, ids = zip(*per_sample)
        return torch.stack(tokens), torch.stack(ids)

    return batched_prc
|
|
|
|
# Per-sample token/id builders lifted to two input forms:
#   - listed_*: lists of samples (sizes may differ; returns lists);
#   - batched_*: indexable batches (outputs are torch.stack-ed).
listed_prc_img, batched_prc_img, batched_prc_txt = (
    listed_wrapper(prc_img),
    batched_wrapper(prc_img),
    batched_wrapper(prc_txt),
)
|
|
|
|
|
|
def scatter_ids(x: Tensor, x_ids: Tensor) -> list[Tensor]:
    """Scatter flat per-sample token sequences back into dense grids using their position ids.

    Args:
        x: per-sample tokens. Each item is 2-D (L, C).
        x_ids: matching per-sample position ids. Each item is (L, >=3), and
            columns 0, 1, 2 hold the t, h and w coordinates.

    Returns:
        One tensor per sample, shaped (1, C, T, H, W):
          - T is the number of distinct t values; t ids are densified with
            ``compress_time``.
          - H and W are the max h / w id + 1.
          - Grid cells not covered by any token are zero.
          - If two tokens share a position, which one survives is unspecified
            (``scatter_`` semantics).
    """
    x_list = []
    for data, pos in zip(x, x_ids):
        _, ch = data.shape
        t_ids = compress_time(pos[:, 0].to(torch.int64))
        h_ids = pos[:, 1].to(torch.int64)
        w_ids = pos[:, 2].to(torch.int64)

        # Convert the grid extents to Python ints once (one device sync each)
        # instead of carrying 0-d tensors into allocation and reshaping.
        t = int(t_ids.max()) + 1
        h = int(h_ids.max()) + 1
        w = int(w_ids.max()) + 1

        # Row-major flat index into a (t, h, w) grid.
        flat_ids = (t_ids * h + h_ids) * w + w_ids

        out = torch.zeros((t * h * w, ch), device=data.device, dtype=data.dtype)
        out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)

        x_list.append(rearrange(out, "(t h w) c -> 1 c t h w", t=t, h=h, w=w))
    return x_list
|
|
|
|