Spaces:
Running on Zero
Running on Zero
| import os, random, torch | |
| import numpy as np | |
| import math | |
| def set_seed(seed: int): | |
| random.seed(seed); np.random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed) | |
| def device_of(model): | |
| return next(model.parameters()).device | |
| def autocast_dtype(use_bf16=True): | |
| return torch.bfloat16 if use_bf16 and torch.cuda.is_available() else torch.float16 | |
| def compute_tgt_ratio(ori_hight, ori_width): | |
| """ | |
| 计算的得到一个最接近的目标height, width | |
| """ | |
| ratio = min(ori_hight / ori_width, ori_width / ori_hight) | |
| if ratio >= math.sqrt(1 / 1 * 4 / 5): | |
| key = "1/1" | |
| elif ratio >= math.sqrt(4 / 5 * 3 / 4): | |
| key = "4/5" | |
| elif ratio >= math.sqrt(3 / 4 * 2 / 3): | |
| key = "3/4" | |
| elif ratio >= math.sqrt(2 / 3 * 9 / 16): | |
| key = "2/3" | |
| elif ratio >= math.sqrt(9 / 16 * 1 / 2): | |
| key = "9/16" | |
| elif ratio >= math.sqrt(1 / 2 * 2 / 5): | |
| key = "1/2" | |
| else: | |
| key = "2/5" | |
| if ori_hight > ori_width: | |
| return key | |
| else: | |
| return "/".join(key.split("/")[::-1]) |