LD3 / utils.py
vinhtong97's picture
Upload folder using huggingface_hub
d382778 verified
from typing import Optional
import torch
import os
from samplers.uni_pc import UniPC
from samplers.heun import Heun
from samplers.dpm_solverpp import DPM_SolverPP
from samplers.dpm_solver import DPM_Solver
from samplers.euler import Euler
from samplers.ipndm import iPNDM
from noise_schedulers import NoiseScheduleVE
import pickle
import argparse
import time
import yaml
import random
import numpy as np
import ast
PRIOR_TIMESTEPS = {
"cifar10": {
4: [80.0, 5.1092, 1.584, 0.47, 0.002],
5: [80.0, 5.8389, 2.1632, 0.8119, 0.2107, 0.002],
6: [80.0, 9.7232, 3.3686, 1.3482, 0.5666, 0.1698, 0.002],
7: [80.0, 10.9836, 3.8811, 1.8543, 0.8119, 0.3183, 0.1079, 0.002],
8: [80.0, 10.9836, 3.8811, 1.8543, 0.9654, 0.47, 0.2107, 0.0665, 0.002],
9: [80.0, 12.3816, 4.459, 2.1632, 1.1431, 0.5666, 0.2597, 0.1079, 0.03, 0.002],
10: [80.0, 13.9293, 5.1092, 2.5152, 1.3482, 0.6799, 0.3183, 0.1698, 0.0665, 0.0225, 0.002],
},
"ffhq": {
4 :[80.0, 7.5699, 2.1632, 0.5666, 0.002],
5 : [80.0, 9.7232, 2.9152, 0.9654, 0.2597, 0.002],
6 : [80.0, 10.9836, 3.8811, 1.584, 0.5666, 0.1698, 0.002],
7 : [80.0, 12.3816, 4.459, 1.8543, 0.8119, 0.3183, 0.1079, 0.002],
8: [80.0, 12.3816, 5.1092, 2.1632, 0.9654, 0.47, 0.2107, 0.0665, 0.002],
9: [80.0, 13.9293, 5.8389, 2.9152, 1.3482, 0.6799, 0.3183, 0.1359, 0.0515, 0.002],
10: [80.0, 13.9293, 5.8389, 2.9152, 1.584, 0.8119, 0.3878, 0.2107, 0.0851, 0.03, 0.002],
},
"afhqv2": {
4 : [80.0, 7.5699, 2.1632, 0.3878, 0.002],
5 : [80.0, 8.5888, 2.9152, 0.9654, 0.2107, 0.002],
6 : [80.0, 9.7232, 3.8811, 1.584, 0.47, 0.1359, 0.002],
7 : [80.0, 10.9836, 4.459, 1.8543, 0.6799, 0.2597, 0.0851, 0.002],
8: [80.0, 12.3816, 5.1092, 2.5152, 1.1431, 0.47, 0.2107, 0.0665, 0.002],
9: [80.0, 13.9293, 5.8389, 2.9152, 1.3482, 0.6799, 0.3183, 0.1359, 0.0515, 0.002],
10: [80.0, 13.9293, 5.8389, 2.9152, 1.584, 0.8119, 0.3878, 0.2107, 0.1079, 0.0395, 0.002],
},
'lsun': {
4: [83.8225, 2.1307, 0.9556, 0.425, 0.0388],
5:[83.8225, 2.4793, 1.1629, 0.5745, 0.2411, 0.0388],
6: [83.8225, 2.4793, 1.2928, 0.7324, 0.3678, 0.1578, 0.0388],
7: [83.8225, 2.9282, 1.4464, 0.8717, 0.4929, 0.2586, 0.109, 0.0388],
8: [83.8225, 3.5196, 1.854, 1.1629, 0.7324, 0.425, 0.2249, 0.1009, 0.0388],
9: [83.8225, 3.5196, 1.854, 1.1629, 0.7324, 0.4574, 0.2773, 0.1578, 0.0731, 0.0388],
10:[83.8225, 4.3198, 2.1307, 1.2928, 0.8717, 0.5745, 0.3678, 0.2411, 0.1365, 0.0672, 0.0388],
},
'sd': {
3: [14.6146, 1.7083, 0.532, 0.0292],
4: [14.6146, 3.1131, 1.0421, 0.3811, 0.0292],
5: [14.6146, 4.39, 1.5286, 0.6526, 0.2667, 0.0292],
6: [14.6146, 4.7242, 1.9132, 0.9324, 0.4557, 0.1801, 0.0292],
7: [14.6146, 6.4477, 2.2797, 1.1629, 0.6114, 0.3058, 0.1258, 0.0292],
8: [14.6146, 6.4477, 2.7391, 1.4467, 0.8319, 0.4936, 0.2667, 0.1258, 0.0292],
9: [14.6146, 6.4477, 3.3251, 1.9132, 1.1629, 0.7391, 0.4557, 0.2667, 0.1258, 0.0292],
10: [14.6146, 5.9489, 3.3251, 2.0267, 1.2969, 0.8319, 0.5712, 0.3811, 0.2255, 0.1258, 0.0292],
11: [14.6146, 6.4477, 3.8092, 2.2797, 1.5286, 1.0421, 0.7391, 0.4936, 0.3437, 0.2255, 0.1258, 0.0292]
}
}
def parse_prior_timesteps(args):
if args.custom_ts_1 is not None:
try:
args.custom_ts_1 = ast.literal_eval(args.custom_ts_1)
except Exception:
pass
else:
if args.custom_ts_2 is not None:
try:
args.custom_ts_2 = ast.literal_eval(args.custom_ts_2)
except Exception:
pass
if args.custom_ts_2 is None:
args.custom_ts_2 = args.custom_ts_1
return
if args.use_gits:
dataset = None
if args.model == 'edm':
for d in ['cifar10', 'afhqv2', 'ffhq']:
if d in args.ckp_path:
dataset = d
break
elif args.model == 'latent_diff':
dataset = 'lsun'
elif args.model == 'conditioned_latent_diff':
dataset = 'sd'
if args.steps in PRIOR_TIMESTEPS[dataset]:
args.custom_ts_1 = PRIOR_TIMESTEPS[dataset][args.steps]
args.custom_ts_2 = args.custom_ts_1
else:
raise NotImplementedError
def set_seed_everything(seed):
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def parse_arguments():
parser = argparse.ArgumentParser(description="Description of your program")
parser.add_argument('--all_config')
parser.add_argument('--model', help="edm/latent_diff")
model_group = parser.add_argument_group('Model Parameters')
model_group.add_argument("--ckp_path", type=str, help="Path to the checkpoint file.")
model_group.add_argument("--solver_name", type=str, help="Method for solving: heun/dpm_solver++/uni_pc.")
model_group.add_argument("--unipc_variant", type=str, choices=["bh1", "bh2"], help="Variant of UniPC: bh1/bh2.")
model_group.add_argument("--steps", type=int, help="Number of sampling steps.")
model_group.add_argument("--order", type=int, help="Order for sampling.")
model_group.add_argument("--time_mode", type=str, help="Time model: time or lambda.")
training_group = parser.add_argument_group('Training Parameters')
training_group.add_argument("--seed", type=int, help="seed")
training_group.add_argument("--use_ema", action="store_true", help="If we use ema for LSUN latent diff")
training_group.add_argument("--log_path", type=str, help="Folder name for storing evaluation results.")
training_group.add_argument("--old_log_path", type=str, help="Folder name for storing old evaluation results.")
training_group.add_argument("--data_dir", type=str, help="Path to data dir.")
training_group.add_argument("--num_train", type=int, help="Number of training sample.")
training_group.add_argument("--num_valid", type=int, help="Number of validation sample.")
training_group.add_argument("--main_train_batch_size", type=int, help="Batch size for training.")
training_group.add_argument("--main_valid_batch_size", type=int, help="Batch size for validation.")
training_group.add_argument("--win_rate", type=float, help="Win rate, should be in (0, 0.5]")
training_group.add_argument("--prior_bound", type=float, help="Prior bound.")
training_group.add_argument("--fix_bound", action="store_true", help="fix bound or not")
training_group.add_argument("--loss_type", type=str, choices=["L1", "L2", "LPIPS"], help="Type of loss: L1, L2 or LPIPS.")
training_group.add_argument("--training_rounds_v1", type=int, help="Number of training rounds for phase 1.")
training_group.add_argument("--training_rounds_v2", type=int, help="Number of training rounds for phase 2.")
training_group.add_argument("--lr_time_1", type=float, help="Learning rate for the first phase.")
training_group.add_argument("--lr_time_2", type=float, help="Learning rate for the second phase.")
training_group.add_argument("--min_lr_time_1", type=float, help="Minimum learning rate for the first phase.")
training_group.add_argument("--min_lr_time_2", type=float, help="Minimum learning rate for the second phase.")
training_group.add_argument("--momentum_time_1", type=float, help="Momentum for the first phase.")
training_group.add_argument("--weight_decay_time_1", type=float, help="Weight decay for the first phase.")
training_group.add_argument("--shift_lr", type=float, help="Learning rate for moving latents.")
training_group.add_argument("--shift_lr_decay", type=float, help="Learning rate decay for the shift phase.")
training_group.add_argument("--lr_time_decay", type=float, help="Learning rate decay for the time phase.")
training_group.add_argument("--patient", type=int, help="Patient for the time phase.")
training_group.add_argument("--lr2_patient", type=int, help="Patient for the second phase.")
training_group.add_argument("--no_v1", action="store_true", help="Skip the first phase.")
training_group.add_argument("--visualize", action="store_true", help="Visualize.")
training_group.add_argument("--low_gpu", action="store_true", help="If we using low-mem gpu, we need to use checkpoint.")
training_group.add_argument("--scale", type=int, help="Guidance scale")
training_group.add_argument("--match_prior", action="store_true", help="Whether to initial params by prior timesteps")
testing_group = parser.add_argument_group('Testing Parameters')
testing_group.add_argument("--load_from_version", type=int, default=2, help="Load from whihc version, default=2")
testing_group.add_argument("--custom_ts_1", type=str, help="Custom timesteps 1")
testing_group.add_argument("--custom_ts_2", type=str, help="Custom timesteps 2")
testing_group.add_argument("--use_gits", action="store_true", help="Use pre-computed gits timesteps")
testing_group.add_argument("--learn", action="store_true", help="Load from learned timesteps.")
testing_group.add_argument("--load_from", type=str, help="Ckpt path")
testing_group.add_argument("--skip_type", type=str, help="Type of skip.")
testing_group.add_argument("--num_multi_steps_fid", type=int, help="num_multi_steps_fid")
testing_group.add_argument("--fid_folder", type=str, default=None, help="FID path")
testing_group.add_argument("--sampling_batch_size", type=int, help="Batch size for FID calculation.")
testing_group.add_argument("--sampling_seed", type=int, help="Sampling seed for FID calculation")
testing_group.add_argument("--ref_path", type=str, help="Path to dataset reference statistics.")
testing_group.add_argument("--total_samples", type=int, help="Total number of sample for FID calculation.")
testing_group.add_argument("--save_png", action="store_true", help="Save generated img in png.")
testing_group.add_argument("--save_pt", action="store_true", help="Save generated img and latent in pt files.")
other_group = parser.add_argument_group('Other Parameters')
other_group.add_argument("--prompt_path", type=str, help="Prompt json path for stable diff")
other_group.add_argument("--num_prompts", type=int, default=5, help="Number of prompts we want to use, default 5")
other_group.add_argument("--num_samples_per_prompt", type=int, default=1, help="Number of samplers per prompt, default 1")
args = parser.parse_args()
# Load the config file if specified
if args.all_config and os.path.isfile(args.all_config):
with open(args.all_config, 'r') as f:
config = yaml.safe_load(f)
# Override the arguments with config values if they are None
for key, value in config.items():
if not hasattr(args, key) or getattr(args, key) is None:
setattr(args, key, value)
return args
def compute_distance_between_two(x, y, n_channels=3, resolution=256):
'''
x: bs x 3 x 256 x 256
y: bs x 3 x 256 x 256
'''
square_distance = (x - y) ** 2
distance = square_distance.sum(dim=(1, 2, 3)) / (n_channels * resolution * resolution)
return distance
def compute_distance_between_two_L1(x, y, n_channels=3, resolution=256):
'''
x: bs x 3 x 256 x 256
y: bs x 3 x 256 x 256
'''
square_distance = torch.abs(x - y)
distance = square_distance.sum(dim=(1, 2, 3)) / (n_channels * resolution * resolution)
return distance
def get_solvers(solver_name: str, NFEs: int, order:int, noise_schedule: NoiseScheduleVE, unipc_variant: Optional[str] = None):
solver_extra_params = dict()
if solver_name == 'euler':
steps = NFEs
solver = Euler(noise_schedule)
elif solver_name == 'heun':
steps = NFEs // 2
solver = Heun(noise_schedule)
elif solver_name == 'dpm_solver':
solver = DPM_Solver(noise_schedule)
dpm_steps, dpm_orders = solver.compute_K_and_order(NFEs, order=order)
solver_extra_params['dpm_orders'] = dpm_orders
solver_extra_params['NFEs'] = NFEs
solver_extra_params['dpm_steps'] = dpm_steps
steps = dpm_steps
elif solver_name == 'dpm_solver++':
steps = NFEs
solver = DPM_SolverPP(noise_schedule)
elif solver_name == 'uni_pc':
steps = NFEs
solver = UniPC(noise_schedule, variant=unipc_variant)
elif solver_name == 'ipndm':
steps = NFEs
solver = iPNDM(noise_schedule)
else:
raise NotImplementedError
return solver, steps, solver_extra_params
def save_arguments_to_yaml(args, filename):
with open(filename, 'w') as file:
yaml.dump(vars(args), file)
def adjust_hyper(args, resolution=64, channel=3):
parse_prior_timesteps(args)
if args.shift_lr is None:
args.shift_lr = 3.0 * 4 / args.steps
if not args.fix_bound:
args.prior_bound = 0.001 * resolution * resolution * channel / (args.steps ** 2)
args.lr_time_2 = args.lr_time_2 / args.steps
args.lr_time_2 = round(args.lr_time_2, 8)
# round prior_bound
args.prior_bound = round(args.prior_bound, 8)
# round shift_lr
args.shift_lr = round(args.shift_lr, 8)
return args
def create_desc(args):
NFEs = args.steps
method_full = args.solver_name
desc = f"{method_full}-N{NFEs}-b{args.prior_bound}-{args.loss_type}-lr2{args.lr_time_2}"
desc += f"rv1{args.training_rounds_v1}-rv2{args.training_rounds_v2}-seed{args.seed}"
if args.no_v1:
desc += "-no_v1_only_v2"
if args.match_prior:
desc += "-match_prior"
return desc
def prepare_paths(args):
skip_type=""
if args.learn:
if args.load_from is None:
desc = create_desc(args)
args.log_path = os.path.join(args.log_path, desc)
args.load_from = os.path.join(args.log_path, f'best_v{args.load_from_version}.pt')
else:
args.log_path = os.path.dirname(args.load_from)
desc = os.path.basename(args.log_path)
# if not is_trained(args.log_path):
# raise ValueError("Model not trained!")
else:
NFEs = args.steps
solver_name = args.solver_name
skip_type = args.skip_type
desc = f"{solver_name}_NFE{NFEs}_{skip_type}_seed{args.seed}"
# create fid folder
if args.fid_folder:
os.makedirs(args.fid_folder, exist_ok=True)
fid_log_path = os.path.join(args.fid_folder, f"{desc}.txt")
else:
fid_log_path = None
return desc, fid_log_path, skip_type
def check_fid_file(fid_log_path):
if os.path.exists(fid_log_path):
# check if FID has been computed
with open(fid_log_path, "r") as f:
scores = f.read()
# check if fid is a number
try:
scores = [float(_) for _ in scores.strip().split()]
if len(scores) == 1:
print(f"FID: {scores[0]}")
elif len(scores) == 2:
print(f"FID: {scores[0]}")
print(f"IS: {scores[1]}")
else:
return False
return True
except ValueError:
return False
return False
def is_trained(path):
log_path = os.path.join(path, 'log.txt')
print(log_path)
if not os.path.isfile(log_path):
print("log.txt not exist")
return False
last_line = ""
# Open the file in read mode
with open(log_path, 'r') as f:
# Read each line in the file
for line in f:
# Strip any leading or trailing whitespace
stripped_line = line.strip()
# Check if the line is not empty
if stripped_line:
last_line = stripped_line # Update last non-empty line
return "Training time" in last_line
def move_tensor_to_device(*args, device):
return [arg.to(device) if arg is not None else arg for arg in args]