stevengrove's picture
Initial commit with Xet-tracked image assets
fcfea15
import os
import random
import re
from pathlib import Path
import numpy as np
from PIL import Image
import torch
import torch.distributed as dist
def seed_everything(seed: int | None = None) -> None:
if seed is not None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# ---------------------------------------------------------------------------
# Distributed helpers (replaces modules.distributed.parallel_states)
# ---------------------------------------------------------------------------
def maybe_init_distributed() -> bool:
    """Initialize ``torch.distributed`` when running with more than one process.

    The world size and rank are read from the ``WORLD_SIZE`` (default ``'1'``)
    and ``RANK`` (default ``'0'``) environment variables.

    Returns:
        ``False`` when ``WORLD_SIZE <= 1`` (nothing is initialized); ``True``
        when a default process group is active on return, whether it was
        created by this call or already existed.
    """
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    if world_size <= 1:
        return False

    # init_process_group() raises if the default group already exists, so a
    # repeated call (or a caller that set the group up itself) must be a no-op.
    if dist.is_initialized():
        return True

    rank = int(os.environ.get('RANK', '0'))
    # NCCL needs CUDA devices; fall back to Gloo so that multi-process runs on
    # CPU-only hosts can still initialize.
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    dist.init_process_group(backend=backend, world_size=world_size, rank=rank)
    return True
def clean_dist_env() -> None:
    """Destroy the default distributed process group if one is active.

    Safe to call unconditionally: it does nothing when no process group has
    been initialized.
    """
    # torch builds without distributed support expose only is_available(), so
    # check it first to keep this teardown helper from raising AttributeError.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def _dynamic_resize_from_bucket(image: Image.Image, basesize: int = 512) -> Image.Image:
    """Resize and center-crop ``image`` to its best-matching bucket resolution.

    A bucket configuration is generated for ``basesize``, the bucket that best
    matches the image's height and width is looked up, and the image is scaled
    (aspect ratio preserved) until it covers that bucket, then center-cropped
    to the bucket's exact size.

    Args:
        image: Input PIL image.
        basesize: Base resolution forwarded to ``generate_video_image_bucket``.

    Returns:
        A PIL image whose (height, width) equals the selected bucket's.
    """
    import torchvision.transforms.functional as TF
    from modules.models.bucket import BucketGroup, generate_video_image_bucket

    def resize_center_crop(img: Image.Image, target_size: tuple[int, int]) -> Image.Image:
        """Scale ``img`` to cover ``target_size`` (height, width), then center-crop to it."""
        w, h = img.size  # PIL reports (width, height)
        bh, bw = target_size
        # Cover size is ceil(dim * max(bh / h, bw / w)). Use exact integer
        # ceiling division instead of a float scale: float rounding can push
        # an exact result such as 512.0 to 512.0000000001, which ceil() turns
        # into 513 and costs an unnecessary one-pixel crop.
        resize_h = int(max(bh, -((-h * bw) // w)))
        resize_w = int(max(bw, -((-w * bh) // h)))
        img = TF.resize(
            img,
            (resize_h, resize_w),
            interpolation=TF.InterpolationMode.BILINEAR,
            antialias=True,
        )
        return TF.center_crop(img, target_size)

    # NOTE(review): the bucket table is rebuilt on every call; if
    # generate_video_image_bucket is deterministic for a given basesize,
    # caching it may be worthwhile -- confirm before changing.
    bucket_config = generate_video_image_bucket(
        basesize=basesize,
        min_temporal=56,
        max_temporal=56,
        bs_img=4,
        bs_vid=4,
        bs_mimg=8,
        min_items=2,
        max_items=2,
    )
    bucket_group = BucketGroup(bucket_config)

    img_w, img_h = image.size
    # The query's last two entries are (height, width); the two leading 1s are
    # presumably item/frame counts -- confirm against BucketGroup.
    bucket = bucket_group.find_best_bucket((1, 1, img_h, img_w))
    target_height, target_width = bucket[-2], bucket[-1]  # (height, width)
    return resize_center_crop(image, (target_height, target_width))