stevengrove's picture
Initial commit with Xet-tracked image assets
fcfea15
raw
history blame
2.43 kB
import os
import random
import re
from pathlib import Path
import numpy as np
from PIL import Image
import torch
import torch.distributed as dist
def seed_everything(seed: int | None = None) -> None:
if seed is not None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# ---------------------------------------------------------------------------
# Distributed helpers (replaces modules.distributed.parallel_states)
# ---------------------------------------------------------------------------
def maybe_init_distributed() -> bool:
    """Initialize torch.distributed when the job runs with more than one process.

    Reads ``WORLD_SIZE`` (default ``'1'``) and ``RANK`` (default ``'0'``) from
    the environment and creates the default process group with the ``nccl``
    backend. Calling this again while a process group is already active is a
    no-op that returns ``True``.

    Returns:
        ``False`` when ``WORLD_SIZE`` is 1 or less (nothing is initialized),
        ``True`` when a process group is active after the call.

    Raises:
        ValueError: If ``WORLD_SIZE`` or ``RANK`` is set to a non-integer string.
    """
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    if world_size <= 1:
        return False

    # init_process_group raises when the default group already exists, so a
    # repeated call is treated as already done rather than as an error.
    if dist.is_initialized():
        return True

    rank = int(os.environ.get('RANK', '0'))
    dist.init_process_group(backend='nccl', world_size=world_size, rank=rank)
    return True
def clean_dist_env() -> None:
    """Tear down the torch.distributed process group when one is active.

    Does nothing if no process group has been initialized.
    """
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
def _dynamic_resize_from_bucket(image: Image.Image, basesize: int = 512) -> Image.Image:
    """Resize and center-crop *image* to the resolution of its best-matching bucket.

    A bucket configuration is generated for ``basesize``, the bucket that
    ``BucketGroup.find_best_bucket`` selects for the image's height and width
    is looked up, and the image is scaled and cropped to that bucket's size.

    Args:
        image: Input PIL image.
        basesize: Base size forwarded to ``generate_video_image_bucket``.

    Returns:
        A PIL image whose size is the selected bucket's (width, height).
    """
    # Dependencies used only by this helper are imported when it runs.
    import math

    import torchvision.transforms.functional as TF

    from modules.models.bucket import BucketGroup, generate_video_image_bucket

    def resize_center_crop(img: Image.Image, target_size: tuple[int, int]) -> Image.Image:
        """Scale proportionally to cover the target size, then center-crop to it.

        ``target_size`` is ``(height, width)``; input and output are PIL images.
        """
        w, h = img.size  # PIL reports (width, height)
        bh, bw = target_size
        # The larger ratio makes both resized sides at least as big as the target.
        scale = max(bh / h, bw / w)
        resize_h, resize_w = math.ceil(h * scale), math.ceil(w * scale)
        img = TF.resize(
            img,
            (resize_h, resize_w),
            interpolation=TF.InterpolationMode.BILINEAR,
            antialias=True,
        )
        return TF.center_crop(img, target_size)

    bucket_config = generate_video_image_bucket(
        basesize=basesize,
        min_temporal=56,
        max_temporal=56,
        bs_img=4,
        bs_vid=4,
        bs_mimg=8,
        min_items=2,
        max_items=2,
    )
    bucket_group = BucketGroup(bucket_config)

    img_w, img_h = image.size
    # NOTE(review): the two leading 1s presumably stand for a single item and a
    # single frame -- confirm against BucketGroup.find_best_bucket.
    bucket = bucket_group.find_best_bucket((1, 1, img_h, img_w))
    # The last two bucket entries are read as (height, width).
    target_height, target_width = bucket[-2], bucket[-1]
    return resize_center_crop(image, (target_height, target_width))