# Wan2GP/models/flux/flux2_adapter.py
import math

import torch
import torchvision
from einops import rearrange
from PIL import Image
from torch import Tensor

def compress_time(t_ids: Tensor) -> Tensor:
    """Remap sparse time ids onto a dense 0..N-1 range, preserving order."""
    assert t_ids.ndim == 1
    t_ids_max = torch.max(t_ids)
    t_remap = torch.zeros((t_ids_max + 1,), device=t_ids.device, dtype=t_ids.dtype)
    t_unique_sorted_ids = torch.unique(t_ids, sorted=True)
    # Each unique time id is replaced by its rank in sorted order.
    t_remap[t_unique_sorted_ids] = torch.arange(
        len(t_unique_sorted_ids), device=t_ids.device, dtype=t_ids.dtype
    )
    t_ids_compressed = t_remap[t_ids]
    return t_ids_compressed
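
# Usage sketch: compress_time maps the sparse ids {10, 20, 30} to their dense
# ranks {0, 1, 2}, preserving relative order.
def _demo_compress_time() -> None:
    t_ids = torch.tensor([10, 10, 30, 20])
    assert compress_time(t_ids).tolist() == [0, 0, 2, 1]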
def center_crop_to_multiple_of_x(
    img: Image.Image | list[Image.Image], x: int
) -> Image.Image | list[Image.Image]:
    """Center-crop so that both width and height are multiples of x."""
    if isinstance(img, list):
        return [center_crop_to_multiple_of_x(_img, x) for _img in img]  # type: ignore
    w, h = img.size
    new_w = (w // x) * x
    new_h = (h // x) * x
    left = (w - new_w) // 2
    top = (h - new_h) // 2
    cropped = img.crop((left, top, left + new_w, top + new_h))
    return cropped
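
# Illustrative example: a 100x70 image center-cropped to the nearest multiple
# of 16 becomes 96x64, trimming 2 px from each side and 3 px top and bottom.
def _demo_center_crop() -> None:
    img = Image.new("RGB", (100, 70))
    assert center_crop_to_multiple_of_x(img, 16).size == (96, 64)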
def cap_pixels(img: Image.Image | list[Image.Image], k: int):
    """Downscale img so its total pixel count does not exceed k."""
    if isinstance(img, list):
        return [cap_pixels(_img, k) for _img in img]
    w, h = img.size
    pixel_count = w * h
    if pixel_count <= k:
        return img
    # Uniform scale factor that brings the total pixel count down to at most k.
    scale = math.sqrt(k / pixel_count)
    new_w = int(w * scale)
    new_h = int(h * scale)
    return img.resize((new_w, new_h), Image.Resampling.LANCZOS)
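
# Illustrative example: a 640x480 image (307_200 px) capped at 100_000 px is
# scaled by sqrt(100_000 / 307_200) ~= 0.57, giving roughly 365x273.
def _demo_cap_pixels() -> None:
    img = cap_pixels(Image.new("RGB", (640, 480)), 100_000)
    w, h = img.size
    assert w * h <= 100_000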
def cap_min_pixels(img: Image.Image | list[Image.Image], max_ar=8, min_sidelength=64):
    """Reject images below the minimum side length or above the maximum aspect ratio."""
    if isinstance(img, list):
        return [cap_min_pixels(_img, max_ar=max_ar, min_sidelength=min_sidelength) for _img in img]
    w, h = img.size
    if w < min_sidelength or h < min_sidelength:
        raise ValueError(f"Skipping: side length below minimum (h {h}, w {w})")
    if w / h > max_ar or h / w > max_ar:
        raise ValueError(f"Skipping: aspect ratio above maximum (h {h}, w {w})")
    return img
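
# Illustrative example: images below the 64 px minimum side length (or above
# the 8:1 aspect-ratio limit) are rejected with a ValueError.
def _demo_cap_min_pixels() -> None:
    try:
        cap_min_pixels(Image.new("RGB", (32, 128)))
    except ValueError:
        pass  # expected: width 32 is below min_sidelength=64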
def to_rgb(img: Image.Image | list[Image.Image]):
    if isinstance(img, list):
        return [to_rgb(_img) for _img in img]
    return img.convert("RGB")
def default_images_prep(
    x: Image.Image | list[Image.Image],
) -> torch.Tensor | list[torch.Tensor]:
    if isinstance(x, list):
        return [default_images_prep(e) for e in x]  # type: ignore
    # ToTensor yields values in [0, 1]; rescale to [-1, 1].
    x_tensor = torchvision.transforms.ToTensor()(x)
    return 2 * x_tensor - 1
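
# Illustrative example: a pure white image maps to all ones after the
# [0, 1] -> [-1, 1] rescale, and the output layout is (C, H, W).
def _demo_default_images_prep() -> None:
    t = default_images_prep(Image.new("RGB", (4, 4), "white"))
    assert t.shape == (3, 4, 4) and torch.all(t == 1.0)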
def default_prep(
    img: Image.Image | list[Image.Image], limit_pixels: int | None, ensure_multiple: int = 16
) -> torch.Tensor | list[torch.Tensor]:
    """Full prep pipeline: RGB convert, size checks, optional pixel cap, crop, tensor."""
    img_rgb = to_rgb(img)
    img_min = cap_min_pixels(img_rgb)  # type: ignore
    if limit_pixels is not None:
        img_cap = cap_pixels(img_min, limit_pixels)  # type: ignore
    else:
        img_cap = img_min
    img_crop = center_crop_to_multiple_of_x(img_cap, ensure_multiple)  # type: ignore
    img_tensor = default_images_prep(img_crop)
    return img_tensor
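
# Illustrative example: a 100x70 RGB image passes the size checks, skips the
# pixel cap (limit_pixels=None), and is center-cropped to 96x64 before being
# converted to a (3, 64, 96) tensor in [-1, 1].
def _demo_default_prep() -> None:
    t = default_prep(Image.new("RGB", (100, 70)), limit_pixels=None)
    assert t.shape == (3, 64, 96)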
def encode_image_refs(ae, img_ctx: list[Image.Image]):
    """Encode reference images into latent tokens plus (t, h, w, l) position ids."""
    if not img_ctx:
        return None, None
    # Time offset between consecutive references on the position-id time axis.
    scale = 10
    # Multiple references share a tighter per-image pixel budget than a single one.
    if len(img_ctx) > 1:
        limit_pixels = 1024**2
    else:
        limit_pixels = 2024**2
    img_ctx_prep = default_prep(img=img_ctx, limit_pixels=limit_pixels)
    if not isinstance(img_ctx_prep, list):
        img_ctx_prep = [img_ctx_prep]
    # Encode each reference image to a (C, h, w) latent.
    encoded_refs = []
    for img in img_ctx_prep:
        encoded = ae.encode(img[None].cuda())[0]
        encoded_refs.append(encoded)
    # Give each reference its own time slot (scale, 2*scale, ...) so references
    # stay separated along the time axis of the position ids.
    t_off = [scale + scale * t for t in torch.arange(0, len(encoded_refs))]
    t_off = [t.view(-1) for t in t_off]
    # Tokenize each latent and attach its position ids.
    ref_tokens, ref_ids = listed_prc_img(encoded_refs, t_coord=t_off)
    # Concatenate all references along the sequence dimension.
    ref_tokens = torch.cat(ref_tokens, dim=0)  # (total_ref_tokens, C)
    ref_ids = torch.cat(ref_ids, dim=0)  # (total_ref_tokens, 4)
    # Add a batch dimension.
    ref_tokens = ref_tokens.unsqueeze(0)  # (1, total_ref_tokens, C)
    ref_ids = ref_ids.unsqueeze(0)  # (1, total_ref_tokens, 4)
    return ref_tokens.to(torch.bfloat16), ref_ids
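
# Usage sketch for encode_image_refs. _StubAE is a hypothetical stand-in for
# the real autoencoder (here assumed: 16x spatial downsampling, 16 latent
# channels); any object with a compatible .encode works. Requires a CUDA
# device, since encode_image_refs moves inputs to .cuda().
class _StubAE:
    def encode(self, x: Tensor) -> Tensor:
        b, _, h, w = x.shape
        return torch.zeros(b, 16, h // 16, w // 16, device=x.device)

def _demo_encode_image_refs() -> None:
    refs = [Image.new("RGB", (128, 128)), Image.new("RGB", (128, 128))]
    tokens, ids = encode_image_refs(_StubAE(), refs)
    # Two 8x8 latents -> 128 tokens total, each with a (t, h, w, l) position id.
    assert tokens.shape == (1, 128, 16) and ids.shape == (1, 128, 4)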
def listed_wrapper(fn):
    """Lift fn(x, t_coord) to operate element-wise over lists of inputs."""
    def listed_prc(
        x: list[Tensor],
        t_coord: list[Tensor] | None = None,
    ) -> tuple[list[Tensor], list[Tensor]]:
        results = []
        for i in range(len(x)):
            results.append(
                fn(
                    x[i],
                    t_coord[i] if t_coord is not None else None,
                )
            )
        x, x_ids = zip(*results)
        return list(x), list(x_ids)
    return listed_prc
def prc_txt(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
    """Build (t, h, w, l) position ids for a text sequence of shape (L, C)."""
    _l, _ = x.shape  # noqa: F841
    coords = {
        "t": torch.arange(1) if t_coord is None else t_coord,
        "h": torch.arange(1),  # dummy dimension
        "w": torch.arange(1),  # dummy dimension
        "l": torch.arange(_l),
    }
    x_ids = torch.cartesian_prod(coords["t"], coords["h"], coords["w"], coords["l"])
    return x, x_ids.to(x.device)
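
# Illustrative example: text tokens keep their features untouched and get
# (t, h, w, l) ids where h and w are dummy axes and l indexes the sequence.
def _demo_prc_txt() -> None:
    txt = torch.randn(5, 8)
    tokens, ids = prc_txt(txt)
    assert tokens.shape == (5, 8) and ids.shape == (5, 4)
    assert ids[:, 3].tolist() == [0, 1, 2, 3, 4]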
def prc_img(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
    """Flatten a (C, H, W) latent into (H*W, C) tokens with (t, h, w, l) ids."""
    _, h, w = x.shape  # noqa: F841
    x_coords = {
        "t": torch.arange(1) if t_coord is None else t_coord,
        "h": torch.arange(h),
        "w": torch.arange(w),
        "l": torch.arange(1),  # dummy dimension
    }
    x_ids = torch.cartesian_prod(x_coords["t"], x_coords["h"], x_coords["w"], x_coords["l"])
    x = rearrange(x, "c h w -> (h w) c")
    return x, x_ids.to(x.device)
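
# Illustrative example: a (C, H, W) latent is flattened into H*W tokens, each
# paired with a (t, h, w, l) position id; t defaults to 0 and l is a dummy axis.
def _demo_prc_img() -> None:
    lat = torch.randn(8, 4, 6)
    tokens, ids = prc_img(lat)
    assert tokens.shape == (24, 8) and ids.shape == (24, 4)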
def batched_wrapper(fn):
    """Lift fn(x, t_coord) over a batch, stacking the per-sample results."""
    def batched_prc(x, t_coord=None):
        results = []
        for i in range(len(x)):
            results.append(
                fn(
                    x[i],
                    t_coord[i] if t_coord is not None else None,
                )
            )
        x, x_ids = zip(*results)
        return torch.stack(x), torch.stack(x_ids)
    return batched_prc
listed_prc_img = listed_wrapper(prc_img)
batched_prc_img = batched_wrapper(prc_img)
batched_prc_txt = batched_wrapper(prc_txt)
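
# Illustrative example: the batched wrapper applies prc_img per sample and
# stacks the results, so token and id shapes line up across the batch.
def _demo_batched_prc_img() -> None:
    batch = torch.randn(2, 8, 4, 6)
    tokens, ids = batched_prc_img(batch)
    assert tokens.shape == (2, 24, 8) and ids.shape == (2, 24, 4)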
def scatter_ids(x: Tensor, x_ids: Tensor) -> list[Tensor]:
    """Use position ids to scatter tokens back into dense (1, c, t, h, w) grids."""
    x_list = []
    for data, pos in zip(x, x_ids):
        _, ch = data.shape  # noqa: F841
        t_ids = pos[:, 0].to(torch.int64)
        h_ids = pos[:, 1].to(torch.int64)
        w_ids = pos[:, 2].to(torch.int64)
        # Compress sparse time ids to a dense range so the output grid is minimal.
        t_ids_cmpr = compress_time(t_ids)
        t = int(torch.max(t_ids_cmpr)) + 1
        h = int(torch.max(h_ids)) + 1
        w = int(torch.max(w_ids)) + 1
        # Flatten (t, h, w) coordinates into a single scatter index per token.
        flat_ids = t_ids_cmpr * w * h + h_ids * w + w_ids
        out = torch.zeros((t * h * w, ch), device=data.device, dtype=data.dtype)
        out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
        x_list.append(rearrange(out, "(t h w) c -> 1 c t h w", t=t, h=h, w=w))
    return x_list
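
# Round-trip sketch: tokenizing a latent with prc_img and scattering it back
# with the same ids recovers the original grid (t = 1 here, since prc_img
# assigns a single default time id).
def _demo_scatter_ids() -> None:
    lat = torch.randn(8, 4, 6)
    tokens, ids = prc_img(lat)
    recovered = scatter_ids(tokens[None], ids[None])[0]
    assert recovered.shape == (1, 8, 1, 4, 6)
    assert torch.equal(recovered[0, :, 0], lat)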