File size: 1,117 Bytes
0ff8d3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os, random, torch
import numpy as np
import math

def set_seed(seed: int):
    """Seed every random number generator this module relies on.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's global generator, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all``.

    Args:
        seed: Value passed unchanged to each seeding call.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def device_of(model):
    """Return the device on which ``model``'s tensors live.

    The first parameter's device is used when the model has parameters.
    A model without parameters falls back to its first buffer, and a model
    with neither is reported as CPU, instead of letting a bare
    ``StopIteration`` escape from ``next()``.

    Args:
        model: Object exposing ``parameters()`` (and, for the fallback,
            ``buffers()``), such as a ``torch.nn.Module``.

    Returns:
        The ``device`` attribute of the first parameter or buffer found, or
        ``torch.device("cpu")`` when the model holds no tensors.
    """
    first_param = next(iter(model.parameters()), None)
    if first_param is not None:
        return first_param.device

    # Only consult buffers once we know there are no parameters, so objects
    # that merely expose ``parameters()`` keep working as before.
    first_buffer = next(iter(model.buffers()), None)
    if first_buffer is not None:
        return first_buffer.device

    return torch.device("cpu")

def autocast_dtype(use_bf16=True):
    """Choose between ``torch.bfloat16`` and ``torch.float16``.

    Args:
        use_bf16: When truthy, bfloat16 is requested.

    Returns:
        ``torch.bfloat16`` if ``use_bf16`` is truthy and
        ``torch.cuda.is_available()`` reports True; ``torch.float16``
        in every other case.
    """
    # CUDA availability is only queried when bf16 was actually requested.
    if use_bf16 and torch.cuda.is_available():
        return torch.bfloat16
    return torch.float16


def compute_tgt_ratio(ori_hight, ori_width):
    """Pick the preset aspect-ratio bucket closest to the given size.

    The short-side / long-side ratio is compared against cut-offs that are
    the geometric means of neighbouring presets (1/1, 4/5, 3/4, 2/3, 9/16,
    1/2, 2/5); the first cut-off the ratio reaches decides the bucket, and
    anything below the last cut-off lands in "2/5".

    Args:
        ori_hight: Original height; must be non-zero because it is used as
            a divisor.
        ori_width: Original width; must be non-zero because it is used as
            a divisor.

    Returns:
        A string of the form "a/b". When ``ori_hight > ori_width`` the
        smaller number comes first (e.g. "3/4"); otherwise the two numbers
        are swapped (e.g. "4/3").
    """
    short_over_long = min(ori_hight / ori_width, ori_width / ori_hight)

    # (lower bound, bucket) pairs ordered from squarest to most elongated.
    # Each bound is the geometric mean of the bucket and its next neighbour.
    cutoffs = (
        (math.sqrt(1 / 1 * 4 / 5), "1/1"),
        (math.sqrt(4 / 5 * 3 / 4), "4/5"),
        (math.sqrt(3 / 4 * 2 / 3), "3/4"),
        (math.sqrt(2 / 3 * 9 / 16), "2/3"),
        (math.sqrt(9 / 16 * 1 / 2), "9/16"),
        (math.sqrt(1 / 2 * 2 / 5), "1/2"),
    )
    bucket = next(
        (label for bound, label in cutoffs if short_over_long >= bound),
        "2/5",
    )

    if ori_hight > ori_width:
        return bucket

    # Height does not exceed width: flip "a/b" into "b/a".
    smaller, larger = bucket.split("/")
    return f"{larger}/{smaller}"