Spaces:
Runtime error
Runtime error
| import os | |
| import random | |
| import re | |
| from pathlib import Path | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import torch.distributed as dist | |
| def seed_everything(seed: int | None = None) -> None: | |
| if seed is not None: | |
| random.seed(seed) | |
| np.random.seed(seed) | |
| torch.manual_seed(seed) | |
| torch.cuda.manual_seed_all(seed) | |
| # --------------------------------------------------------------------------- | |
| # Distributed helpers (replaces modules.distributed.parallel_states) | |
| # --------------------------------------------------------------------------- | |
| def maybe_init_distributed() -> bool: | |
| """Initialize torch distributed if WORLD_SIZE > 1. Returns True if initialized.""" | |
| world_size = int(os.environ.get('WORLD_SIZE', '1')) | |
| if world_size <= 1: | |
| return False | |
| rank = int(os.environ.get('RANK', '0')) | |
| dist.init_process_group(backend='nccl', world_size=world_size, rank=rank) | |
| return True | |
def clean_dist_env() -> None:
    """Tear down the default torch.distributed process group, if one exists.

    Safe to call unconditionally: when no group has been initialized the
    function returns without doing anything.
    """
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
def _resize_center_crop(img: Image.Image, target_size: tuple[int, int]) -> Image.Image:
    """Scale ``img`` (keeping aspect ratio) to cover ``target_size``, then center-crop.

    Args:
        img: Input PIL image.
        target_size: Output size as ``(height, width)``.

    Returns:
        The resized image center-cropped to ``target_size`` (PIL in, PIL out).
    """
    # Imported lazily, matching the original function-local imports.
    import math

    import torchvision.transforms.functional as TF

    width, height = img.size  # PIL reports (width, height)
    target_h, target_w = target_size
    # Taking the larger ratio makes both resized sides >= the target before cropping.
    scale = max(target_h / height, target_w / width)
    resized_hw = (math.ceil(height * scale), math.ceil(width * scale))
    img = TF.resize(
        img,
        resized_hw,
        interpolation=TF.InterpolationMode.BILINEAR,
        antialias=True,
    )
    return TF.center_crop(img, target_size)


def _dynamic_resize_from_bucket(image: Image.Image, basesize: int = 512) -> Image.Image:
    """Resize ``image`` to the resolution bucket chosen by the project's BucketGroup.

    A bucket set is generated for ``basesize``. ``BucketGroup.find_best_bucket`` is
    queried with the image size, and the image is resized and center-cropped to
    the last two entries of the returned bucket, taken as (height, width).

    Args:
        image: Input PIL image.
        basesize: Base resolution forwarded to ``generate_video_image_bucket``.

    Returns:
        A PIL image whose size matches the selected bucket.
    """
    # Project-local module, imported lazily as in the original.
    from modules.models.bucket import BucketGroup, generate_video_image_bucket

    # NOTE(review): the temporal/batch/item arguments are fixed here; their meaning
    # is defined by generate_video_image_bucket — confirm before changing them.
    bucket_config = generate_video_image_bucket(
        basesize=basesize, min_temporal=56, max_temporal=56, bs_img=4, bs_vid=4, bs_mimg=8, min_items=2, max_items=2
    )
    bucket_group = BucketGroup(bucket_config)
    img_w, img_h = image.size
    # Query shape ends in (H, W); the leading (1, 1) presumably encodes a single
    # frame/item — TODO confirm against BucketGroup.find_best_bucket.
    bucket = bucket_group.find_best_bucket((1, 1, img_h, img_w))
    target_height, target_width = bucket[-2], bucket[-1]  # (height, width)
    return _resize_center_crop(image, (target_height, target_width))