File size: 2,429 Bytes
fcfea15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import random
import re
from pathlib import Path

import numpy as np
from PIL import Image
import torch
import torch.distributed as dist


def seed_everything(seed: int | None = None) -> None:
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


# ---------------------------------------------------------------------------
#  Distributed helpers (replaces modules.distributed.parallel_states)
# ---------------------------------------------------------------------------

def maybe_init_distributed() -> bool:
    """Initialize torch distributed if WORLD_SIZE > 1. Returns True if initialized."""
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    if world_size <= 1:
        return False
    rank = int(os.environ.get('RANK', '0'))
    dist.init_process_group(backend='nccl', world_size=world_size, rank=rank)
    return True


def clean_dist_env() -> None:
    """Destroy the distributed process group if it was initialized."""
    if dist.is_initialized():
        dist.destroy_process_group()


def _dynamic_resize_from_bucket(image: Image, basesize: int = 512):
    from modules.models.bucket import BucketGroup, generate_video_image_bucket
    from typing import Tuple
    import math
    import torchvision.transforms.functional as TF

    def resize_center_crop(img: Image.Image, target_size: Tuple[int, int]) -> Image.Image:
        """等比缩放到 >= 目标尺寸,再中心裁剪到目标尺寸。(PIL输入/输出)"""
        w, h = img.size  # PIL: (width, height)
        bh, bw = target_size
        scale = max(bh / h, bw / w)
        resize_h, resize_w = math.ceil(h * scale), math.ceil(w * scale)
        img = TF.resize(img, (resize_h, resize_w),
                        interpolation=TF.InterpolationMode.BILINEAR, antialias=True)
        img = TF.center_crop(img, target_size)
        return img

    bucket_config = generate_video_image_bucket(
        basesize=basesize, min_temporal=56, max_temporal=56, bs_img=4, bs_vid=4, bs_mimg=8, min_items=2, max_items=2
    )
    bucket_group = BucketGroup(bucket_config)
    img_w, img_h = image.size
    bucket = bucket_group.find_best_bucket((1, 1, img_h, img_w))
    target_height, target_width = bucket[-2], bucket[-1]  # (height, width)
    img_proc = resize_center_crop(image, (target_height, target_width))
    return img_proc