# Hugging Face file-page residue (not part of the module): uploader "vidfom".
# Commit message: "Upload folder using huggingface_hub".
# Commit 31112ad (verified).
import math
import torch
import torchvision
from einops import rearrange
from PIL import Image
from torch import Tensor
# from .model import Flux2
def compress_time(t_ids: Tensor) -> Tensor:
    """Remap time ids to dense, order-preserving ranks ``0..K-1``.

    Every distinct value in ``t_ids`` is replaced by its index among the
    sorted distinct values, e.g. ``[10, 10, 20, 30] -> [0, 0, 1, 2]``.

    Args:
        t_ids: 1-D tensor of time ids.

    Returns:
        A tensor with the same shape, dtype and device as ``t_ids`` holding
        the rank of each id.

    Raises:
        ValueError: If ``t_ids`` is not 1-D.
    """
    if t_ids.ndim != 1:
        raise ValueError(f"t_ids must be 1-D, got {t_ids.ndim} dimensions")
    if t_ids.numel() == 0:
        # Nothing to rank; return an independent empty tensor.
        return t_ids.clone()
    # The inverse indices of a sorted unique are exactly the dense ranks, so
    # no lookup table sized by the largest id is required (such a table costs
    # O(max id) memory and cannot be indexed sensibly with negative ids).
    _, ranks = torch.unique(t_ids, sorted=True, return_inverse=True)
    return ranks.to(t_ids.dtype)
def scatter_ids(x: Tensor, x_ids: Tensor) -> list[Tensor]:
    """Scatter flat token sequences back onto dense ``(t, h, w)`` grids.

    ``x`` and ``x_ids`` are iterated together, one sample at a time. For each
    sample the time ids are compressed to consecutive ranks (via
    ``compress_time``), a zero-filled grid large enough for the largest
    t/h/w id is allocated, and every token is written to the cell addressed
    by its ids.

    Args:
        x: Per-sample token matrices; each sample must be 2-D
            ``(n_tokens, channels)``.
        x_ids: Per-sample position ids; columns 0, 1 and 2 are read as the
            t, h and w ids of each token. Any further columns are ignored.

    Returns:
        One tensor per sample with shape ``(1, channels, t, h, w)``. Grid
        cells that receive no token stay zero.
    """
    grids = []
    for tokens, pos in zip(x, x_ids):
        _, ch = tokens.shape
        t_ids = compress_time(pos[:, 0].to(torch.int64))
        h_ids = pos[:, 1].to(torch.int64)
        w_ids = pos[:, 2].to(torch.int64)

        # Use plain Python ints for the grid extents so that shape arithmetic
        # and the einops axis sizes do not depend on 0-dim tensors.
        t = int(t_ids.max()) + 1
        h = int(h_ids.max()) + 1
        w = int(w_ids.max()) + 1

        # Row-major (t, h, w) linear index of every token.
        flat_ids = t_ids * w * h + h_ids * w + w_ids

        out = torch.zeros((t * h * w, ch), device=tokens.device, dtype=tokens.dtype)
        # scatter_ needs an index of the same shape as the source, hence the
        # per-channel expansion of the flat ids.
        out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), tokens)

        grids.append(rearrange(out, "(t h w) c -> 1 c t h w", t=t, h=h, w=w))
    return grids
def encode_image_refs(ae, img_ctx: list[Image.Image], device: torch.device | str = "cuda"):
    """Encode reference images into one token sequence with position ids.

    The images are preprocessed with ``default_prep``, encoded one by one
    with ``ae.encode``, turned into tokens and ids with ``listed_prc_img``
    and concatenated along the sequence dimension.

    Args:
        ae: Object exposing ``encode(batch)`` whose result can be indexed
            with ``[0]`` (presumably the autoencoder — confirm with caller).
        img_ctx: Reference images. May be empty.
        device: Device each image tensor is moved to before encoding.
            Defaults to ``"cuda"``, matching the previous hard-coded
            behaviour.

    Returns:
        ``(ref_tokens, ref_ids)`` where ``ref_tokens`` is cast to bfloat16
        and both carry a leading batch dimension of size 1, or
        ``(None, None)`` when ``img_ctx`` is empty.
    """
    if not img_ctx:
        return None, None

    scale = 10

    # A single reference gets a larger pixel budget than each of several.
    # NOTE(review): 2024**2 may be a typo for 2048**2; kept unchanged so the
    # behaviour stays identical — confirm against upstream.
    limit_pixels = 2024**2 if len(img_ctx) == 1 else 1024**2

    img_ctx_prep = default_prep(img=img_ctx, limit_pixels=limit_pixels)
    if not isinstance(img_ctx_prep, list):
        img_ctx_prep = [img_ctx_prep]

    # Encode each reference image as a batch of one and drop the batch axis.
    encoded_refs = [ae.encode(img[None].to(device))[0] for img in img_ctx_prep]

    # Reference i gets the time coordinate scale * (i + 1), stored as a
    # 1-element tensor so it can serve as the "t" axis of the position ids.
    t_off = [(scale + scale * t).view(-1) for t in torch.arange(0, len(encoded_refs))]

    ref_tokens, ref_ids = listed_prc_img(encoded_refs, t_coord=t_off)

    # Concatenate all references along the sequence dimension, then add a
    # batch dimension: (1, total_ref_tokens, C) and (1, total_ref_tokens, 4).
    ref_tokens = torch.cat(ref_tokens, dim=0).unsqueeze(0)
    ref_ids = torch.cat(ref_ids, dim=0).unsqueeze(0)

    return ref_tokens.to(torch.bfloat16), ref_ids
def prc_txt(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
_l, _ = x.shape # noqa: F841
coords = {
"t": torch.arange(1) if t_coord is None else t_coord,
"h": torch.arange(1), # dummy dimension
"w": torch.arange(1), # dummy dimension
"l": torch.arange(_l),
}
x_ids = torch.cartesian_prod(coords["t"], coords["h"], coords["w"], coords["l"])
return x, x_ids.to(x.device)
def batched_wrapper(fn):
def batched_prc(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
results = []
for i in range(len(x)):
results.append(
fn(
x[i],
t_coord[i] if t_coord is not None else None,
)
)
x, x_ids = zip(*results)
return torch.stack(x), torch.stack(x_ids)
return batched_prc
def listed_wrapper(fn):
def listed_prc(
x: list[Tensor],
t_coord: list[Tensor] | None = None,
) -> tuple[list[Tensor], list[Tensor]]:
results = []
for i in range(len(x)):
results.append(
fn(
x[i],
t_coord[i] if t_coord is not None else None,
)
)
x, x_ids = zip(*results)
return list(x), list(x_ids)
return listed_prc
def prc_img(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
    """Flatten a ``(c, h, w)`` tensor into tokens with ``(t, h, w, l)`` ids.

    Args:
        x: 3-D tensor laid out as ``(channels, height, width)``.
        t_coord: Optional 1-D tensor used as the "t" axis. When omitted, a
            single time index ``0`` is used.

    Returns:
        ``(tokens, ids)`` where ``tokens`` has shape ``(h * w, c)`` in
        row-major spatial order and ``ids`` is the cartesian product of the
        t axis, ``arange(h)``, ``arange(w)`` and a singleton l axis, moved to
        the tokens' device.
    """
    _channels, height, width = x.shape
    t_axis = t_coord if t_coord is not None else torch.arange(1)
    ids = torch.cartesian_prod(
        t_axis,
        torch.arange(height),
        torch.arange(width),
        torch.arange(1),  # images carry no sequence position
    )
    tokens = rearrange(x, "c h w -> (h w) c")
    return tokens, ids.to(tokens.device)
# List variant of prc_img: takes a list of image tensors (and optionally a
# list of t coordinates) and returns lists of tokens and ids, unstacked.
listed_prc_img = listed_wrapper(prc_img)
# Batch variant of prc_img: iterates over the first dimension and stacks the
# per-sample tokens and ids.
batched_prc_img = batched_wrapper(prc_img)
# Batch variant of prc_txt: iterates over the first dimension and stacks the
# per-sample tokens and ids.
batched_prc_txt = batched_wrapper(prc_txt)
def center_crop_to_multiple_of_x(
    img: Image.Image | list[Image.Image], x: int
) -> Image.Image | list[Image.Image]:
    """Center-crop so that width and height become multiples of ``x``.

    Args:
        img: A single image or a list of images; lists are processed
            element-wise.
        x: The divisor both side lengths must be a multiple of.

    Returns:
        The cropped image, or a list of cropped images for list input. Each
        side is reduced to the largest multiple of ``x`` not exceeding it,
        with the removed margin split between both edges.
    """
    if isinstance(img, list):
        return [center_crop_to_multiple_of_x(item, x) for item in img]  # type: ignore

    width, height = img.size
    target_w = (width // x) * x
    target_h = (height // x) * x

    # Any odd leftover pixel ends up on the right/bottom edge.
    x0 = (width - target_w) // 2
    y0 = (height - target_h) // 2
    return img.crop((x0, y0, x0 + target_w, y0 + target_h))
def cap_pixels(img: Image.Image | list[Image.Image], k):
    """Downscale an image so that its pixel count does not exceed ``k``.

    Args:
        img: A single image or a list of images; lists are processed
            element-wise.
        k: Maximum allowed number of pixels (width * height).

    Returns:
        The input image itself when it already fits the budget, otherwise a
        LANCZOS-resampled copy with the same aspect ratio (up to rounding).
        For list input, a list of such results.
    """
    if isinstance(img, list):
        return [cap_pixels(_img, k) for _img in img]

    w, h = img.size
    pixel_count = w * h
    if pixel_count <= k:
        return img

    # Uniform scale that brings the area down to k; int() floors both sides,
    # so the resulting area stays within the budget.
    scale = math.sqrt(k / pixel_count)
    # Clamp to 1 so that a very elongated image never yields a zero-sized
    # target (e.g. a 1 x 10_000_000 image with a small k floors w to 0).
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    return img.resize((new_w, new_h), Image.Resampling.LANCZOS)
def cap_min_pixels(img: Image.Image | list[Image.Image], max_ar=8, min_sidelength=64):
    """Validate minimum side length and maximum aspect ratio.

    Despite its name this function never modifies the image; it either
    returns it unchanged or raises.

    Args:
        img: A single image or a list of images; lists are checked
            element-wise.
        max_ar: Largest allowed ratio between the longer and shorter side.
        min_sidelength: Smallest allowed width and height in pixels.

    Returns:
        The input image (or a list of the input images) unchanged.

    Raises:
        ValueError: If a side is shorter than ``min_sidelength`` or the
            aspect ratio exceeds ``max_ar``.
    """
    if isinstance(img, list):
        return [
            cap_min_pixels(item, max_ar=max_ar, min_sidelength=min_sidelength)
            for item in img
        ]

    width, height = img.size

    too_small = width < min_sidelength or height < min_sidelength
    if too_small:
        raise ValueError(f"Skipping due to minimal sidelength underschritten h {height} w {width}")

    too_elongated = width / height > max_ar or height / width > max_ar
    if too_elongated:
        raise ValueError(f"Skipping due to maximal ar overschritten h {height} w {width}")

    return img
def to_rgb(img: Image.Image | list[Image.Image]):
    """Convert an image, or every image in a list, to RGB mode.

    Args:
        img: A single image or a list of images.

    Returns:
        The result of ``convert("RGB")`` for a single image, or a list of
        such results for list input.
    """
    if not isinstance(img, list):
        return img.convert("RGB")
    return [to_rgb(item) for item in img]
def default_images_prep(
    x: Image.Image | list[Image.Image],
) -> torch.Tensor | list[torch.Tensor]:
    """Convert an image (or list of images) to tensors mapped by ``2 * v - 1``.

    Args:
        x: A single image or a list of images; lists are processed
            element-wise.

    Returns:
        The torchvision ``ToTensor`` output of each image, rescaled with
        ``2 * v - 1`` (which maps a ``[0, 1]`` range to ``[-1, 1]``).
    """
    if not isinstance(x, list):
        to_tensor = torchvision.transforms.ToTensor()
        return 2 * to_tensor(x) - 1
    return [default_images_prep(item) for item in x]  # type: ignore
def default_prep(
    img: Image.Image | list[Image.Image], limit_pixels: int | None, ensure_multiple: int = 16
) -> torch.Tensor | list[torch.Tensor]:
    """Run the standard preprocessing chain on an image or list of images.

    Steps, in order: convert to RGB, validate size and aspect ratio,
    optionally cap the pixel count, center-crop both sides to a multiple of
    ``ensure_multiple``, and convert to tensors.

    Args:
        img: A single image or a list of images.
        limit_pixels: Maximum pixel count per image, or ``None`` to skip the
            downscaling step.
        ensure_multiple: Divisor that width and height are cropped to.

    Returns:
        A tensor for a single image, or a list of tensors for list input.
    """
    prepared = cap_min_pixels(to_rgb(img))  # type: ignore
    if limit_pixels is not None:
        prepared = cap_pixels(prepared, limit_pixels)  # type: ignore
    prepared = center_crop_to_multiple_of_x(prepared, ensure_multiple)  # type: ignore
    return default_images_prep(prepared)
def generalized_time_snr_shift(t: Tensor, mu: float, sigma: float) -> Tensor:
    """Apply the shift ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`` to ``t``.

    Args:
        t: Timesteps to shift.
        mu: Shift strength; enters the formula as ``exp(mu)``.
        sigma: Exponent applied to ``1 / t - 1``.

    Returns:
        The shifted timesteps, computed element-wise.
    """
    exp_mu = math.exp(mu)
    ratio = (1 / t - 1) ** sigma
    return exp_mu / (exp_mu + ratio)
def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
    """Build the shifted timestep schedule for sampling.

    Args:
        num_steps: Number of steps; the schedule has ``num_steps + 1``
            entries.
        image_seq_len: Image token count, used to pick the shift ``mu``.

    Returns:
        A list of floats obtained by shifting a linear grid from 1 to 0 with
        ``generalized_time_snr_shift`` (``sigma = 1.0``).
    """
    shift = compute_empirical_mu(image_seq_len, num_steps)
    linear_grid = torch.linspace(1, 0, num_steps + 1)
    return generalized_time_snr_shift(linear_grid, shift, 1.0).tolist()
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """Compute the schedule shift ``mu`` from sequence length and step count.

    Two linear fits in ``image_seq_len`` are used. Above 4300 tokens the
    second fit is returned directly. Otherwise the result lies on the line
    through the first fit at 10 steps and the second fit at 200 steps,
    evaluated at ``num_steps``.

    Args:
        image_seq_len: Number of image tokens.
        num_steps: Number of sampling steps.

    Returns:
        The shift value as a Python float.
    """
    slope_10, offset_10 = 8.73809524e-05, 1.89833333
    slope_200, offset_200 = 0.00016927, 0.45666666

    mu_at_200 = slope_200 * image_seq_len + offset_200
    if image_seq_len > 4300:
        # Long sequences ignore the step count entirely.
        return float(mu_at_200)

    mu_at_10 = slope_10 * image_seq_len + offset_10
    # Line through (10, mu_at_10) and (200, mu_at_200), evaluated at num_steps.
    step_slope = (mu_at_200 - mu_at_10) / 190.0
    step_offset = mu_at_200 - 200.0 * step_slope
    return float(step_slope * num_steps + step_offset)
def concatenate_images(
    images: list[Image.Image],
) -> Image.Image:
    """Concatenate images side by side on a white background.

    Images are placed left to right in list order and centered vertically
    within the height of the tallest image.

    Args:
        images: Images to concatenate; must contain at least one image.

    Returns:
        For a single image, a copy of it (its mode is left untouched). For
        several images, a new RGB image whose width is the sum of all widths
        and whose height is the maximum height.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        # Fail with a clear message instead of the opaque error max() raises
        # on an empty sequence.
        raise ValueError("concatenate_images() requires at least one image")

    # A single image needs no canvas; return a copy so callers can mutate it.
    if len(images) == 1:
        return images[0].copy()

    # Normalize modes so every paste onto the RGB canvas is well-defined.
    rgb_images = [img if img.mode == "RGB" else img.convert("RGB") for img in images]

    total_width = sum(img.width for img in rgb_images)
    max_height = max(img.height for img in rgb_images)

    background_color = (255, 255, 255)
    canvas = Image.new("RGB", (total_width, max_height), background_color)

    x_offset = 0
    for img in rgb_images:
        # Center shorter images vertically.
        y_offset = (max_height - img.height) // 2
        canvas.paste(img, (x_offset, y_offset))
        x_offset += img.width

    return canvas