import argparse
import ast
import os
import pickle
import random
import time
from typing import Optional

import numpy as np
import torch
import yaml

from noise_schedulers import NoiseScheduleVE
from samplers.dpm_solver import DPM_Solver
from samplers.dpm_solverpp import DPM_SolverPP
from samplers.euler import Euler
from samplers.heun import Heun
from samplers.ipndm import iPNDM
from samplers.uni_pc import UniPC

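# Precomputed timestep schedules used when --use_gits is set: the outer key is the
# dataset, the inner key is the number of sampling steps, and each list holds the
# steps + 1 timestep boundaries ordered from the largest (initial) to the smallest
# (final) value.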
PRIOR_TIMESTEPS = {
    "cifar10": {
        4: [80.0, 5.1092, 1.584, 0.47, 0.002],
        5: [80.0, 5.8389, 2.1632, 0.8119, 0.2107, 0.002],
        6: [80.0, 9.7232, 3.3686, 1.3482, 0.5666, 0.1698, 0.002],
        7: [80.0, 10.9836, 3.8811, 1.8543, 0.8119, 0.3183, 0.1079, 0.002],
        8: [80.0, 10.9836, 3.8811, 1.8543, 0.9654, 0.47, 0.2107, 0.0665, 0.002],
        9: [80.0, 12.3816, 4.459, 2.1632, 1.1431, 0.5666, 0.2597, 0.1079, 0.03, 0.002],
        10: [80.0, 13.9293, 5.1092, 2.5152, 1.3482, 0.6799, 0.3183, 0.1698, 0.0665, 0.0225, 0.002],
    },
    "ffhq": {
        4: [80.0, 7.5699, 2.1632, 0.5666, 0.002],
        5: [80.0, 9.7232, 2.9152, 0.9654, 0.2597, 0.002],
        6: [80.0, 10.9836, 3.8811, 1.584, 0.5666, 0.1698, 0.002],
        7: [80.0, 12.3816, 4.459, 1.8543, 0.8119, 0.3183, 0.1079, 0.002],
        8: [80.0, 12.3816, 5.1092, 2.1632, 0.9654, 0.47, 0.2107, 0.0665, 0.002],
        9: [80.0, 13.9293, 5.8389, 2.9152, 1.3482, 0.6799, 0.3183, 0.1359, 0.0515, 0.002],
        10: [80.0, 13.9293, 5.8389, 2.9152, 1.584, 0.8119, 0.3878, 0.2107, 0.0851, 0.03, 0.002],
    },
    "afhqv2": {
        4: [80.0, 7.5699, 2.1632, 0.3878, 0.002],
        5: [80.0, 8.5888, 2.9152, 0.9654, 0.2107, 0.002],
        6: [80.0, 9.7232, 3.8811, 1.584, 0.47, 0.1359, 0.002],
        7: [80.0, 10.9836, 4.459, 1.8543, 0.6799, 0.2597, 0.0851, 0.002],
        8: [80.0, 12.3816, 5.1092, 2.5152, 1.1431, 0.47, 0.2107, 0.0665, 0.002],
        9: [80.0, 13.9293, 5.8389, 2.9152, 1.3482, 0.6799, 0.3183, 0.1359, 0.0515, 0.002],
        10: [80.0, 13.9293, 5.8389, 2.9152, 1.584, 0.8119, 0.3878, 0.2107, 0.1079, 0.0395, 0.002],
    },
    "lsun": {
        4: [83.8225, 2.1307, 0.9556, 0.425, 0.0388],
        5: [83.8225, 2.4793, 1.1629, 0.5745, 0.2411, 0.0388],
        6: [83.8225, 2.4793, 1.2928, 0.7324, 0.3678, 0.1578, 0.0388],
        7: [83.8225, 2.9282, 1.4464, 0.8717, 0.4929, 0.2586, 0.109, 0.0388],
        8: [83.8225, 3.5196, 1.854, 1.1629, 0.7324, 0.425, 0.2249, 0.1009, 0.0388],
        9: [83.8225, 3.5196, 1.854, 1.1629, 0.7324, 0.4574, 0.2773, 0.1578, 0.0731, 0.0388],
        10: [83.8225, 4.3198, 2.1307, 1.2928, 0.8717, 0.5745, 0.3678, 0.2411, 0.1365, 0.0672, 0.0388],
    },
    "sd": {
        3: [14.6146, 1.7083, 0.532, 0.0292],
        4: [14.6146, 3.1131, 1.0421, 0.3811, 0.0292],
        5: [14.6146, 4.39, 1.5286, 0.6526, 0.2667, 0.0292],
        6: [14.6146, 4.7242, 1.9132, 0.9324, 0.4557, 0.1801, 0.0292],
        7: [14.6146, 6.4477, 2.2797, 1.1629, 0.6114, 0.3058, 0.1258, 0.0292],
        8: [14.6146, 6.4477, 2.7391, 1.4467, 0.8319, 0.4936, 0.2667, 0.1258, 0.0292],
        9: [14.6146, 6.4477, 3.3251, 1.9132, 1.1629, 0.7391, 0.4557, 0.2667, 0.1258, 0.0292],
        10: [14.6146, 5.9489, 3.3251, 2.0267, 1.2969, 0.8319, 0.5712, 0.3811, 0.2255, 0.1258, 0.0292],
        11: [14.6146, 6.4477, 3.8092, 2.2797, 1.5286, 1.0421, 0.7391, 0.4936, 0.3437, 0.2255, 0.1258, 0.0292],
    },
}

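# Example for parse_prior_timesteps below (hypothetical invocation): passing
#   --custom_ts_1 "[80.0, 5.1092, 1.584, 0.47, 0.002]"
# on the command line arrives here as a string; ast.literal_eval converts it into a
# Python list, and custom_ts_2 then defaults to the same schedule when unset.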
def parse_prior_timesteps(args):
    """Parse user-supplied timestep strings, or fall back to precomputed GITS schedules."""
    if args.custom_ts_1 is not None:
        try:
            args.custom_ts_1 = ast.literal_eval(args.custom_ts_1)
        except Exception:
            pass
        if args.custom_ts_2 is not None:
            try:
                args.custom_ts_2 = ast.literal_eval(args.custom_ts_2)
            except Exception:
                pass
        if args.custom_ts_2 is None:
            # Reuse the first schedule when a second one is not given.
            args.custom_ts_2 = args.custom_ts_1
        return

    if args.use_gits:
        # Infer the dataset from the model type and checkpoint path.
        dataset = None
        if args.model == 'edm':
            for d in ['cifar10', 'afhqv2', 'ffhq']:
                if d in args.ckp_path:
                    dataset = d
                    break
        elif args.model == 'latent_diff':
            dataset = 'lsun'
        elif args.model == 'conditioned_latent_diff':
            dataset = 'sd'

        if dataset is not None and args.steps in PRIOR_TIMESTEPS[dataset]:
            args.custom_ts_1 = PRIOR_TIMESTEPS[dataset][args.steps]
            args.custom_ts_2 = args.custom_ts_1
        else:
            raise NotImplementedError(
                f"No precomputed GITS timesteps for dataset {dataset!r} with {args.steps} steps."
            )

def set_seed_everything(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

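# Note on configuration precedence in parse_arguments below: values passed on the
# command line win over the YAML file given via --all_config, because only missing
# or None attributes are filled in from the config. Hypothetical example:
#   python sample.py --all_config config.yml --steps 10
# keeps steps=10 even if config.yml also defines "steps" ("sample.py" and
# "config.yml" are placeholder names).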
def parse_arguments():
    parser = argparse.ArgumentParser(description="Configuration for diffusion sampling and timestep optimization.")

    parser.add_argument('--all_config', type=str, help="Path to a YAML config file; its values fill in arguments not set on the command line.")
    parser.add_argument('--model', help="Model type: edm/latent_diff/conditioned_latent_diff.")

    model_group = parser.add_argument_group('Model Parameters')
    model_group.add_argument("--ckp_path", type=str, help="Path to the checkpoint file.")
    model_group.add_argument("--solver_name", type=str, help="Solver: euler/heun/dpm_solver/dpm_solver++/uni_pc/ipndm.")
    model_group.add_argument("--unipc_variant", type=str, choices=["bh1", "bh2"], help="Variant of UniPC: bh1/bh2.")
    model_group.add_argument("--steps", type=int, help="Number of sampling steps.")
    model_group.add_argument("--order", type=int, help="Order for sampling.")
    model_group.add_argument("--time_mode", type=str, help="Time mode: time or lambda.")

    training_group = parser.add_argument_group('Training Parameters')
    training_group.add_argument("--seed", type=int, help="Random seed.")
    training_group.add_argument("--use_ema", action="store_true", help="Use the EMA weights for the LSUN latent-diffusion model.")
    training_group.add_argument("--log_path", type=str, help="Folder name for storing evaluation results.")
    training_group.add_argument("--old_log_path", type=str, help="Folder name for storing old evaluation results.")
    training_group.add_argument("--data_dir", type=str, help="Path to the data directory.")
    training_group.add_argument("--num_train", type=int, help="Number of training samples.")
    training_group.add_argument("--num_valid", type=int, help="Number of validation samples.")
    training_group.add_argument("--main_train_batch_size", type=int, help="Batch size for training.")
    training_group.add_argument("--main_valid_batch_size", type=int, help="Batch size for validation.")
    training_group.add_argument("--win_rate", type=float, help="Win rate; should be in (0, 0.5].")
    training_group.add_argument("--prior_bound", type=float, help="Prior bound.")
    training_group.add_argument("--fix_bound", action="store_true", help="Keep the prior bound fixed instead of rescaling it by resolution and step count.")
    training_group.add_argument("--loss_type", type=str, choices=["L1", "L2", "LPIPS"], help="Type of loss: L1, L2 or LPIPS.")
    training_group.add_argument("--training_rounds_v1", type=int, help="Number of training rounds for phase 1.")
    training_group.add_argument("--training_rounds_v2", type=int, help="Number of training rounds for phase 2.")
    training_group.add_argument("--lr_time_1", type=float, help="Learning rate for the first phase.")
    training_group.add_argument("--lr_time_2", type=float, help="Learning rate for the second phase.")
    training_group.add_argument("--min_lr_time_1", type=float, help="Minimum learning rate for the first phase.")
    training_group.add_argument("--min_lr_time_2", type=float, help="Minimum learning rate for the second phase.")
    training_group.add_argument("--momentum_time_1", type=float, help="Momentum for the first phase.")
    training_group.add_argument("--weight_decay_time_1", type=float, help="Weight decay for the first phase.")
    training_group.add_argument("--shift_lr", type=float, help="Learning rate for moving latents.")
    training_group.add_argument("--shift_lr_decay", type=float, help="Learning rate decay for the shift phase.")
    training_group.add_argument("--lr_time_decay", type=float, help="Learning rate decay for the time phase.")
    training_group.add_argument("--patient", type=int, help="Patience for the time phase.")
    training_group.add_argument("--lr2_patient", type=int, help="Patience for the second phase.")
    training_group.add_argument("--no_v1", action="store_true", help="Skip the first phase.")
    training_group.add_argument("--visualize", action="store_true", help="Visualize results.")
    training_group.add_argument("--low_gpu", action="store_true", help="Use gradient checkpointing for low-memory GPUs.")
    training_group.add_argument("--scale", type=int, help="Guidance scale.")
    training_group.add_argument("--match_prior", action="store_true", help="Initialize parameters from the prior timesteps.")

    testing_group = parser.add_argument_group('Testing Parameters')
    testing_group.add_argument("--load_from_version", type=int, default=2, help="Which checkpoint version to load from (default: 2).")
    testing_group.add_argument("--custom_ts_1", type=str, help="Custom timesteps 1 (a Python-style list literal).")
    testing_group.add_argument("--custom_ts_2", type=str, help="Custom timesteps 2 (a Python-style list literal).")
    testing_group.add_argument("--use_gits", action="store_true", help="Use precomputed GITS timesteps.")
    testing_group.add_argument("--learn", action="store_true", help="Load from learned timesteps.")
    testing_group.add_argument("--load_from", type=str, help="Checkpoint path.")
    testing_group.add_argument("--skip_type", type=str, help="Type of timestep skip.")
    testing_group.add_argument("--num_multi_steps_fid", type=int, help="num_multi_steps_fid")
    testing_group.add_argument("--fid_folder", type=str, default=None, help="Folder for FID logs.")
    testing_group.add_argument("--sampling_batch_size", type=int, help="Batch size for FID calculation.")
    testing_group.add_argument("--sampling_seed", type=int, help="Sampling seed for FID calculation.")
    testing_group.add_argument("--ref_path", type=str, help="Path to dataset reference statistics.")
    testing_group.add_argument("--total_samples", type=int, help="Total number of samples for FID calculation.")
    testing_group.add_argument("--save_png", action="store_true", help="Save generated images as PNG files.")
    testing_group.add_argument("--save_pt", action="store_true", help="Save generated images and latents in .pt files.")

    other_group = parser.add_argument_group('Other Parameters')
    other_group.add_argument("--prompt_path", type=str, help="Path to the prompt JSON file for Stable Diffusion.")
    other_group.add_argument("--num_prompts", type=int, default=5, help="Number of prompts to use (default: 5).")
    other_group.add_argument("--num_samples_per_prompt", type=int, default=1, help="Number of samples per prompt (default: 1).")

    args = parser.parse_args()

    # Fill in unset arguments from the YAML config; command-line values take precedence.
    if args.all_config and os.path.isfile(args.all_config):
        with open(args.all_config, 'r') as f:
            config = yaml.safe_load(f)
        for key, value in config.items():
            if not hasattr(args, key) or getattr(args, key) is None:
                setattr(args, key, value)

    return args

def compute_distance_between_two(x, y, n_channels=3, resolution=256):
    '''
    Per-sample mean squared error.
    x: bs x n_channels x resolution x resolution
    y: bs x n_channels x resolution x resolution
    '''
    square_distance = (x - y) ** 2
    distance = square_distance.sum(dim=(1, 2, 3)) / (n_channels * resolution * resolution)
    return distance

def compute_distance_between_two_L1(x, y, n_channels=3, resolution=256):
    '''
    Per-sample mean absolute error.
    x: bs x n_channels x resolution x resolution
    y: bs x n_channels x resolution x resolution
    '''
    abs_distance = torch.abs(x - y)
    distance = abs_distance.sum(dim=(1, 2, 3)) / (n_channels * resolution * resolution)
    return distance

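# Minimal usage sketch for get_solvers below (assumes `ns` is an already-constructed
# NoiseScheduleVE instance; the names here are illustrative, not from the original code):
#   solver, steps, extra = get_solvers("uni_pc", NFEs=10, order=3, noise_schedule=ns, unipc_variant="bh2")
# `steps` is the number of solver iterations actually run: NFEs for the single-call
# solvers, NFEs // 2 for Heun (two model calls per step), and the step count returned
# by DPM_Solver.compute_K_and_order for dpm_solver.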
def get_solvers(solver_name: str, NFEs: int, order: int, noise_schedule: NoiseScheduleVE, unipc_variant: Optional[str] = None):
    solver_extra_params = dict()
    if solver_name == 'euler':
        steps = NFEs
        solver = Euler(noise_schedule)
    elif solver_name == 'heun':
        # Heun makes two model calls per step, so halve the step count.
        steps = NFEs // 2
        solver = Heun(noise_schedule)
    elif solver_name == 'dpm_solver':
        solver = DPM_Solver(noise_schedule)
        dpm_steps, dpm_orders = solver.compute_K_and_order(NFEs, order=order)
        solver_extra_params['dpm_orders'] = dpm_orders
        solver_extra_params['NFEs'] = NFEs
        solver_extra_params['dpm_steps'] = dpm_steps
        steps = dpm_steps
    elif solver_name == 'dpm_solver++':
        steps = NFEs
        solver = DPM_SolverPP(noise_schedule)
    elif solver_name == 'uni_pc':
        steps = NFEs
        solver = UniPC(noise_schedule, variant=unipc_variant)
    elif solver_name == 'ipndm':
        steps = NFEs
        solver = iPNDM(noise_schedule)
    else:
        raise NotImplementedError(f"Unknown solver: {solver_name}")
    return solver, steps, solver_extra_params

def save_arguments_to_yaml(args, filename):
    with open(filename, 'w') as file:
        yaml.dump(vars(args), file)

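# Worked example of the hyper-parameter scaling done in adjust_hyper below, with
# illustrative inputs (lr_time_2=0.01 is made up; the constants come from the code):
#   steps=10, resolution=64, channel=3
#   shift_lr    = 3.0 * 4 / 10                = 1.2        (if --shift_lr is unset)
#   prior_bound = 0.001 * 64 * 64 * 3 / 10**2 = 0.12288    (if --fix_bound is not set)
#   lr_time_2   = 0.01 / 10                   = 0.001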
def adjust_hyper(args, resolution=64, channel=3):
    parse_prior_timesteps(args)
    if args.shift_lr is None:
        args.shift_lr = 3.0 * 4 / args.steps
    if not args.fix_bound:
        args.prior_bound = 0.001 * resolution * resolution * channel / (args.steps ** 2)
    args.lr_time_2 = args.lr_time_2 / args.steps

    args.lr_time_2 = round(args.lr_time_2, 8)
    args.prior_bound = round(args.prior_bound, 8)
    args.shift_lr = round(args.shift_lr, 8)
    return args

def create_desc(args):
    NFEs = args.steps
    method_full = args.solver_name
    desc = f"{method_full}-N{NFEs}-b{args.prior_bound}-{args.loss_type}-lr2{args.lr_time_2}"
    desc += f"rv1{args.training_rounds_v1}-rv2{args.training_rounds_v2}-seed{args.seed}"
    if args.no_v1:
        desc += "-no_v1_only_v2"
    if args.match_prior:
        desc += "-match_prior"
    return desc

def prepare_paths(args):
    """Build the run description, log path and FID log path for either learned or fixed timesteps."""
    skip_type = ""
    if args.learn:
        if args.load_from is None:
            desc = create_desc(args)
            args.log_path = os.path.join(args.log_path, desc)
            args.load_from = os.path.join(args.log_path, f'best_v{args.load_from_version}.pt')
        else:
            args.log_path = os.path.dirname(args.load_from)
            desc = os.path.basename(args.log_path)
    else:
        NFEs = args.steps
        solver_name = args.solver_name
        skip_type = args.skip_type
        desc = f"{solver_name}_NFE{NFEs}_{skip_type}_seed{args.seed}"

    if args.fid_folder:
        os.makedirs(args.fid_folder, exist_ok=True)
        fid_log_path = os.path.join(args.fid_folder, f"{desc}.txt")
    else:
        fid_log_path = None
    return desc, fid_log_path, skip_type

def check_fid_file(fid_log_path):
    """Return True if the FID log already exists and holds one (FID) or two (FID, IS) numbers."""
    if os.path.exists(fid_log_path):
        with open(fid_log_path, "r") as f:
            scores = f.read()
        try:
            scores = [float(_) for _ in scores.strip().split()]
            if len(scores) == 1:
                print(f"FID: {scores[0]}")
            elif len(scores) == 2:
                print(f"FID: {scores[0]}")
                print(f"IS: {scores[1]}")
            else:
                return False
            return True
        except ValueError:
            return False
    return False

def is_trained(path):
    """Check whether training finished by looking for 'Training time' in the last non-empty line of log.txt."""
    log_path = os.path.join(path, 'log.txt')
    print(log_path)
    if not os.path.isfile(log_path):
        print("log.txt does not exist")
        return False

    last_line = ""
    with open(log_path, 'r') as f:
        for line in f:
            stripped_line = line.strip()
            if stripped_line:
                last_line = stripped_line
    return "Training time" in last_line

def move_tensor_to_device(*args, device):
    return [arg.to(device) if arg is not None else arg for arg in args]
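# Usage sketch for move_tensor_to_device: None entries pass through unchanged, so
# optional tensors can be moved in one call, e.g.
#   x, mask = move_tensor_to_device(x, mask, device=torch.device("cuda"))


if __name__ == "__main__":
    # Minimal, self-contained smoke test of the pure helpers in this module
    # (illustrative example only; it needs no checkpoint or dataset).
    set_seed_everything(0)
    a = torch.rand(2, 3, 256, 256)
    b = torch.rand(2, 3, 256, 256)
    print("L2 distance:", compute_distance_between_two(a, b))
    print("L1 distance:", compute_distance_between_two_L1(a, b))
    print(move_tensor_to_device(a, None, device=torch.device("cpu"))[1])  # -> None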