|
|
import math |
|
|
|
|
|
import torch |
|
|
import torchvision |
|
|
from einops import rearrange |
|
|
from PIL import Image |
|
|
from torch import Tensor |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compress_time(t_ids: Tensor) -> Tensor:
    """Remap time ids onto a dense 0..k-1 range, preserving relative order.

    E.g. [5, 5, 9, 2] -> [1, 1, 2, 0]: the unique sorted values [2, 5, 9]
    are renumbered 0, 1, 2 and every entry is replaced by its new index.

    Args:
        t_ids: 1-D integer tensor of (possibly repeated, possibly sparse)
            time ids.

    Returns:
        Tensor of the same shape/dtype/device with compressed ids.

    Raises:
        ValueError: if ``t_ids`` is not 1-dimensional.
    """
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if t_ids.ndim != 1:
        raise ValueError(f"t_ids must be 1-D, got ndim={t_ids.ndim}")

    # Empty input: nothing to remap (torch.max would raise on an empty tensor).
    if t_ids.numel() == 0:
        return t_ids.clone()

    # Dense lookup table covering 0..max(t_ids); the positions that actually
    # occur in t_ids receive their rank in the sorted unique order.
    t_ids_max = torch.max(t_ids)
    t_remap = torch.zeros((t_ids_max + 1,), device=t_ids.device, dtype=t_ids.dtype)
    t_unique_sorted_ids = torch.unique(t_ids, sorted=True)
    t_remap[t_unique_sorted_ids] = torch.arange(
        len(t_unique_sorted_ids), device=t_ids.device, dtype=t_ids.dtype
    )
    return t_remap[t_ids]
|
|
|
|
|
|
|
|
def scatter_ids(x: Tensor, x_ids: Tensor) -> list[Tensor]:
    """Scatter flat token sequences back into dense (1, c, t, h, w) volumes
    using their position ids.

    Args:
        x: iterable of token tensors, each of shape (n_tokens, channels).
        x_ids: matching iterable of position tensors, each (n_tokens, >=3)
            with columns (t, h, w).

    Returns:
        One dense tensor of shape (1, channels, t, h, w) per input item;
        grid cells no token maps to are left zero.
    """
    # NOTE: an earlier revision also accumulated per-item unique time ids
    # here, but never returned them; that dead code has been removed.
    x_list = []
    for data, pos in zip(x, x_ids):
        _, ch = data.shape
        t_ids = pos[:, 0].to(torch.int64)
        h_ids = pos[:, 1].to(torch.int64)
        w_ids = pos[:, 2].to(torch.int64)

        # Time ids may be sparse (e.g. 10, 20, ...); compress them to
        # consecutive indices so the dense grid has no empty frames.
        t_ids_cmpr = compress_time(t_ids)

        # Grid extents as plain ints so they can size tensors and shapes.
        t = int(torch.max(t_ids_cmpr)) + 1
        h = int(torch.max(h_ids)) + 1
        w = int(torch.max(w_ids)) + 1

        # Row-major flattening of the (t, h, w) coordinates.
        flat_ids = t_ids_cmpr * w * h + h_ids * w + w_ids

        out = torch.zeros((t * h * w, ch), device=data.device, dtype=data.dtype)
        out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)

        # (t*h*w, c) -> (1, c, t, h, w); pure torch, equivalent to the
        # einops pattern "(t h w) c -> 1 c t h w".
        vol = out.view(t, h, w, ch).permute(3, 0, 1, 2).unsqueeze(0)
        x_list.append(vol)
    return x_list
|
|
|
|
|
|
|
|
def encode_image_refs(ae, img_ctx: list[Image.Image]):
    """Encode reference images into one token sequence plus position ids.

    Each image is preprocessed, encoded by ``ae``, and tokenized; every
    reference gets its own widely separated time slot so references never
    collide along the t axis.

    Args:
        ae: autoencoder exposing ``encode``; inputs are moved to CUDA.
        img_ctx: reference images; may be empty.

    Returns:
        ``(tokens, ids)`` where ``tokens`` has a leading batch dim and is cast
        to bfloat16, or ``(None, None)`` when there are no references.
    """
    # Gap between consecutive reference time slots.
    scale = 10

    # Nothing to encode: bail out before any preprocessing decisions.
    if not img_ctx:
        return None, None

    # Smaller per-image pixel budget when several references share the context.
    # NOTE(review): 2024**2 looks like a typo for 2048**2 — confirm intent.
    if len(img_ctx) > 1:
        limit_pixels = 1024**2
    else:
        limit_pixels = 2024**2

    img_ctx_prep = default_prep(img=img_ctx, limit_pixels=limit_pixels)
    if not isinstance(img_ctx_prep, list):
        img_ctx_prep = [img_ctx_prep]

    encoded_refs = []
    for img in img_ctx_prep:
        encoded = ae.encode(img[None].cuda())[0]
        encoded_refs.append(encoded)

    # One time coordinate per reference: scale, 2*scale, 3*scale, ...
    t_off = [scale + scale * t for t in torch.arange(0, len(encoded_refs))]
    t_off = [t.view(-1) for t in t_off]

    ref_tokens, ref_ids = listed_prc_img(encoded_refs, t_coord=t_off)

    # Merge all references into a single flat token sequence.
    ref_tokens = torch.cat(ref_tokens, dim=0)
    ref_ids = torch.cat(ref_ids, dim=0)

    # Add a leading batch dimension.
    ref_tokens = ref_tokens.unsqueeze(0)
    ref_ids = ref_ids.unsqueeze(0)

    return ref_tokens.to(torch.bfloat16), ref_ids
|
|
|
|
|
|
|
|
def prc_txt(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]: |
|
|
_l, _ = x.shape |
|
|
|
|
|
coords = { |
|
|
"t": torch.arange(1) if t_coord is None else t_coord, |
|
|
"h": torch.arange(1), |
|
|
"w": torch.arange(1), |
|
|
"l": torch.arange(_l), |
|
|
} |
|
|
x_ids = torch.cartesian_prod(coords["t"], coords["h"], coords["w"], coords["l"]) |
|
|
return x, x_ids.to(x.device) |
|
|
|
|
|
|
|
|
def batched_wrapper(fn): |
|
|
def batched_prc(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]: |
|
|
results = [] |
|
|
for i in range(len(x)): |
|
|
results.append( |
|
|
fn( |
|
|
x[i], |
|
|
t_coord[i] if t_coord is not None else None, |
|
|
) |
|
|
) |
|
|
x, x_ids = zip(*results) |
|
|
return torch.stack(x), torch.stack(x_ids) |
|
|
|
|
|
return batched_prc |
|
|
|
|
|
|
|
|
def listed_wrapper(fn): |
|
|
def listed_prc( |
|
|
x: list[Tensor], |
|
|
t_coord: list[Tensor] | None = None, |
|
|
) -> tuple[list[Tensor], list[Tensor]]: |
|
|
results = [] |
|
|
for i in range(len(x)): |
|
|
results.append( |
|
|
fn( |
|
|
x[i], |
|
|
t_coord[i] if t_coord is not None else None, |
|
|
) |
|
|
) |
|
|
x, x_ids = zip(*results) |
|
|
return list(x), list(x_ids) |
|
|
|
|
|
return listed_prc |
|
|
|
|
|
|
|
|
def prc_img(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
    """Flatten a (c, h, w) image into (h*w, c) tokens with (t, h, w, l) ids.

    The t axis defaults to [0] (overridable via ``t_coord``) and l is fixed
    to [0]; h and w enumerate the spatial grid in row-major order.
    """
    _ch, height, width = x.shape
    axes = (
        torch.arange(1) if t_coord is None else t_coord,
        torch.arange(height),
        torch.arange(width),
        torch.arange(1),
    )
    x_ids = torch.cartesian_prod(*axes)
    # "c h w -> (h w) c" without einops: channels-last, then flatten space.
    tokens = x.permute(1, 2, 0).reshape(height * width, -1)
    return tokens, x_ids.to(tokens.device)
|
|
|
|
|
|
|
|
# Ready-made processors: the per-sample image/text tokenizers lifted to
# operate over lists (returning lists) or batches (returning stacked tensors).
listed_prc_img = listed_wrapper(prc_img)


batched_prc_img = batched_wrapper(prc_img)


batched_prc_txt = batched_wrapper(prc_txt)
|
|
|
|
|
|
|
|
def center_crop_to_multiple_of_x(
    img: Image.Image | list[Image.Image], x: int
) -> Image.Image | list[Image.Image]:
    """Center-crop an image (or every image in a list) so that both side
    lengths are divisible by ``x``."""
    if isinstance(img, list):
        return [center_crop_to_multiple_of_x(item, x) for item in img]

    width, height = img.size
    target_w = width - width % x
    target_h = height - height % x

    # Center the crop window; integer division biases by at most one pixel
    # when the trimmed margin is odd.
    left = (width - target_w) // 2
    top = (height - target_h) // 2
    box = (left, top, left + target_w, top + target_h)

    return img.crop(box)
|
|
|
|
|
|
|
|
def cap_pixels(img: Image.Image | list[Image.Image], k):
    """Downscale ``img`` (or each list element) so its pixel count is at most
    ``k``, preserving aspect ratio; images already within budget pass through
    unchanged."""
    if isinstance(img, list):
        return [cap_pixels(item, k) for item in img]

    width, height = img.size
    if width * height <= k:
        return img

    # Uniform scale factor that brings the area down to (at most) k.
    factor = math.sqrt(k / (width * height))
    new_size = (int(width * factor), int(height * factor))
    return img.resize(new_size, Image.Resampling.LANCZOS)
|
|
|
|
|
|
|
|
def cap_min_pixels(img: Image.Image | list[Image.Image], max_ar=8, min_sidelength=64):
    """Validate that an image (or each image in a list) meets minimum size and
    aspect-ratio constraints; returns the input unchanged on success.

    Args:
        img: image or list of images to validate.
        max_ar: maximum allowed aspect ratio (long side / short side).
        min_sidelength: minimum allowed width and height in pixels.

    Raises:
        ValueError: if either side is below ``min_sidelength`` or the aspect
            ratio exceeds ``max_ar``.
    """
    if isinstance(img, list):
        return [cap_min_pixels(_img, max_ar=max_ar, min_sidelength=min_sidelength) for _img in img]
    w, h = img.size
    # Messages reworded from mixed German/English ("underschritten" /
    # "overschritten") to plain English.
    if w < min_sidelength or h < min_sidelength:
        raise ValueError(f"Skipping: side length below minimum, h {h} w {w}")
    if w / h > max_ar or h / w > max_ar:
        raise ValueError(f"Skipping: aspect ratio above maximum, h {h} w {w}")
    return img
|
|
|
|
|
|
|
|
def to_rgb(img: Image.Image | list[Image.Image]):
    """Convert an image, or every image in a list, to RGB mode."""
    if isinstance(img, list):
        return [to_rgb(item) for item in img]
    return img.convert("RGB")
|
|
|
|
|
|
|
|
def default_images_prep(
    x: Image.Image | list[Image.Image],
) -> torch.Tensor | list[torch.Tensor]:
    """Convert PIL image(s) to float tensors rescaled from [0, 1] to [-1, 1]."""
    if isinstance(x, list):
        return [default_images_prep(item) for item in x]
    tensor = torchvision.transforms.ToTensor()(x)
    # Map ToTensor's [0, 1] range onto the model's [-1, 1] convention.
    return tensor * 2 - 1
|
|
|
|
|
|
|
|
def default_prep(
    img: Image.Image | list[Image.Image], limit_pixels: int | None, ensure_multiple: int = 16
) -> torch.Tensor | list[torch.Tensor]:
    """Run the full image preprocessing pipeline.

    Steps: convert to RGB, validate minimum size / aspect ratio, optionally
    cap total pixels at ``limit_pixels``, center-crop both sides to a
    multiple of ``ensure_multiple``, and tensorize into [-1, 1].
    """
    prepped = to_rgb(img)
    prepped = cap_min_pixels(prepped)
    if limit_pixels is not None:
        prepped = cap_pixels(prepped, limit_pixels)
    prepped = center_crop_to_multiple_of_x(prepped, ensure_multiple)
    return default_images_prep(prepped)
|
|
|
|
|
|
|
|
def generalized_time_snr_shift(t: Tensor, mu: float, sigma: float) -> Tensor:
    """Warp timesteps ``t`` in (0, 1] with a logistic SNR shift controlled by
    ``mu`` (location) and ``sigma`` (slope)."""
    e_mu = math.exp(mu)
    odds = (1 / t - 1) ** sigma
    return e_mu / (e_mu + odds)
|
|
|
|
|
|
|
|
def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
    """Return ``num_steps + 1`` timesteps descending from 1 to 0, warped by
    the empirically fitted mu for this sequence length and step count."""
    mu = compute_empirical_mu(image_seq_len, num_steps)
    grid = torch.linspace(1, 0, num_steps + 1)
    shifted = generalized_time_snr_shift(grid, mu, 1.0)
    return shifted.tolist()
|
|
|
|
|
|
|
|
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """Empirical mu(image_seq_len, num_steps) from two linear fits.

    (a1, b1) is a fit of mu vs. sequence length at ~10 sampling steps and
    (a2, b2) the same at ~200 steps. For short sequences mu is linearly
    interpolated in ``num_steps`` between those two anchors; for long
    sequences (> 4300 tokens) the 200-step fit is used directly.
    """
    a1, b1 = 8.73809524e-05, 1.89833333
    a2, b2 = 0.00016927, 0.45666666

    if image_seq_len > 4300:
        return float(a2 * image_seq_len + b2)

    # Anchor values of mu at 200 and 10 steps for this sequence length.
    mu_at_200 = a2 * image_seq_len + b2
    mu_at_10 = a1 * image_seq_len + b1

    # Line through (10, mu_at_10) and (200, mu_at_200), evaluated at num_steps.
    slope = (mu_at_200 - mu_at_10) / 190.0
    intercept = mu_at_200 - 200.0 * slope
    return float(slope * num_steps + intercept)
|
|
|
|
|
|
|
|
|
|
|
def concatenate_images(
    images: list[Image.Image],
) -> Image.Image:
    """
    Concatenate a list of PIL images horizontally with center alignment and white background.

    Args:
        images: non-empty list of images; non-RGB inputs are converted.

    Returns:
        A new RGB image of width sum(widths) and height max(heights); each
        input is pasted left-to-right, vertically centered on white.

    Raises:
        ValueError: if ``images`` is empty.
    """
    # Explicit guard: an empty list previously surfaced further down as an
    # opaque "max() arg is an empty sequence" ValueError.
    if not images:
        raise ValueError("concatenate_images requires at least one image")

    # Single image: return a copy so the caller never mutates the input.
    if len(images) == 1:
        return images[0].copy()

    images = [img.convert("RGB") if img.mode != "RGB" else img for img in images]

    total_width = sum(img.width for img in images)
    max_height = max(img.height for img in images)

    background_color = (255, 255, 255)
    new_img = Image.new("RGB", (total_width, max_height), background_color)

    x_offset = 0
    for img in images:
        # Vertically center shorter images within the tallest one.
        y_offset = (max_height - img.height) // 2
        new_img.paste(img, (x_offset, y_offset))
        x_offset += img.width

    return new_img
|
|
|