# Source snapshot: revision d382778, file size 16,254 bytes.
from typing import Optional
import torch
import os
from samplers.uni_pc import UniPC
from samplers.heun import Heun
from samplers.dpm_solverpp import DPM_SolverPP
from samplers.dpm_solver import DPM_Solver
from samplers.euler import Euler
from samplers.ipndm import iPNDM
from noise_schedulers import NoiseScheduleVE
import pickle
import argparse
import time
import yaml
import random
import numpy as np
import ast
# Pre-computed timestep schedules, indexed as PRIOR_TIMESTEPS[dataset][steps].
# Every schedule lists steps + 1 values in strictly decreasing order.
# parse_prior_timesteps() reads this table when --use_gits is set.
PRIOR_TIMESTEPS = {
    "cifar10": {
        4: [80.0, 5.1092, 1.584, 0.47, 0.002],
        5: [80.0, 5.8389, 2.1632, 0.8119, 0.2107, 0.002],
        6: [80.0, 9.7232, 3.3686, 1.3482, 0.5666, 0.1698, 0.002],
        7: [80.0, 10.9836, 3.8811, 1.8543, 0.8119, 0.3183, 0.1079, 0.002],
        8: [80.0, 10.9836, 3.8811, 1.8543, 0.9654, 0.47, 0.2107, 0.0665, 0.002],
        9: [80.0, 12.3816, 4.459, 2.1632, 1.1431, 0.5666, 0.2597, 0.1079, 0.03, 0.002],
        10: [80.0, 13.9293, 5.1092, 2.5152, 1.3482, 0.6799, 0.3183, 0.1698, 0.0665, 0.0225, 0.002],
    },
    "ffhq": {
        4: [80.0, 7.5699, 2.1632, 0.5666, 0.002],
        5: [80.0, 9.7232, 2.9152, 0.9654, 0.2597, 0.002],
        6: [80.0, 10.9836, 3.8811, 1.584, 0.5666, 0.1698, 0.002],
        7: [80.0, 12.3816, 4.459, 1.8543, 0.8119, 0.3183, 0.1079, 0.002],
        8: [80.0, 12.3816, 5.1092, 2.1632, 0.9654, 0.47, 0.2107, 0.0665, 0.002],
        9: [80.0, 13.9293, 5.8389, 2.9152, 1.3482, 0.6799, 0.3183, 0.1359, 0.0515, 0.002],
        10: [80.0, 13.9293, 5.8389, 2.9152, 1.584, 0.8119, 0.3878, 0.2107, 0.0851, 0.03, 0.002],
    },
    "afhqv2": {
        4: [80.0, 7.5699, 2.1632, 0.3878, 0.002],
        5: [80.0, 8.5888, 2.9152, 0.9654, 0.2107, 0.002],
        6: [80.0, 9.7232, 3.8811, 1.584, 0.47, 0.1359, 0.002],
        7: [80.0, 10.9836, 4.459, 1.8543, 0.6799, 0.2597, 0.0851, 0.002],
        8: [80.0, 12.3816, 5.1092, 2.5152, 1.1431, 0.47, 0.2107, 0.0665, 0.002],
        9: [80.0, 13.9293, 5.8389, 2.9152, 1.3482, 0.6799, 0.3183, 0.1359, 0.0515, 0.002],
        10: [80.0, 13.9293, 5.8389, 2.9152, 1.584, 0.8119, 0.3878, 0.2107, 0.1079, 0.0395, 0.002],
    },
    "lsun": {
        4: [83.8225, 2.1307, 0.9556, 0.425, 0.0388],
        5: [83.8225, 2.4793, 1.1629, 0.5745, 0.2411, 0.0388],
        6: [83.8225, 2.4793, 1.2928, 0.7324, 0.3678, 0.1578, 0.0388],
        7: [83.8225, 2.9282, 1.4464, 0.8717, 0.4929, 0.2586, 0.109, 0.0388],
        8: [83.8225, 3.5196, 1.854, 1.1629, 0.7324, 0.425, 0.2249, 0.1009, 0.0388],
        9: [83.8225, 3.5196, 1.854, 1.1629, 0.7324, 0.4574, 0.2773, 0.1578, 0.0731, 0.0388],
        10: [83.8225, 4.3198, 2.1307, 1.2928, 0.8717, 0.5745, 0.3678, 0.2411, 0.1365, 0.0672, 0.0388],
    },
    "sd": {
        3: [14.6146, 1.7083, 0.532, 0.0292],
        4: [14.6146, 3.1131, 1.0421, 0.3811, 0.0292],
        5: [14.6146, 4.39, 1.5286, 0.6526, 0.2667, 0.0292],
        6: [14.6146, 4.7242, 1.9132, 0.9324, 0.4557, 0.1801, 0.0292],
        7: [14.6146, 6.4477, 2.2797, 1.1629, 0.6114, 0.3058, 0.1258, 0.0292],
        8: [14.6146, 6.4477, 2.7391, 1.4467, 0.8319, 0.4936, 0.2667, 0.1258, 0.0292],
        9: [14.6146, 6.4477, 3.3251, 1.9132, 1.1629, 0.7391, 0.4557, 0.2667, 0.1258, 0.0292],
        10: [14.6146, 5.9489, 3.3251, 2.0267, 1.2969, 0.8319, 0.5712, 0.3811, 0.2255, 0.1258, 0.0292],
        11: [14.6146, 6.4477, 3.8092, 2.2797, 1.5286, 1.0421, 0.7391, 0.4936, 0.3437, 0.2255, 0.1258, 0.0292],
    },
}
def parse_prior_timesteps(args):
    """Resolve ``args.custom_ts_1`` / ``args.custom_ts_2`` in place.

    Two sources are tried in order:

    1. ``args.custom_ts_1`` is given: it is parsed with ``ast.literal_eval``.
       When parsing succeeds, ``args.custom_ts_2`` is parsed the same way (or
       set to ``custom_ts_1`` when it is ``None``) and the function returns.
    2. ``args.use_gits`` is set: the schedule is looked up in
       ``PRIOR_TIMESTEPS`` by dataset and ``args.steps`` and assigned to both
       attributes.

    Raises:
        NotImplementedError: ``args.use_gits`` is set but the table has no
            schedule for ``args.steps``.

    NOTE(review): the indentation of this block was reconstructed. The
    ``else`` below is read as the ``try``'s else clause and the ``return`` as
    its last statement -- confirm against the original file.
    """
    if args.custom_ts_1 is not None:
        try:
            args.custom_ts_1 = ast.literal_eval(args.custom_ts_1)
        except Exception:
            # Parsing failed: the value is left untouched and control falls
            # through to the use_gits lookup (presumably this covers values
            # that are already lists, e.g. loaded from YAML -- confirm).
            pass
        else:
            if args.custom_ts_2 is not None:
                try:
                    args.custom_ts_2 = ast.literal_eval(args.custom_ts_2)
                except Exception:
                    # Same best-effort handling as for custom_ts_1.
                    pass
            # With no second schedule, reuse the first one.
            if args.custom_ts_2 is None:
                args.custom_ts_2 = args.custom_ts_1
            return
    if args.use_gits:
        dataset = None
        if args.model == 'edm':
            # The dataset name is taken from the checkpoint path.
            for d in ['cifar10', 'afhqv2', 'ffhq']:
                if d in args.ckp_path:
                    dataset = d
                    break
        elif args.model == 'latent_diff':
            dataset = 'lsun'
        elif args.model == 'conditioned_latent_diff':
            dataset = 'sd'
        # NOTE(review): if no branch above matched, dataset is still None and
        # the lookup below raises KeyError(None) rather than a descriptive error.
        if args.steps in PRIOR_TIMESTEPS[dataset]:
            args.custom_ts_1 = PRIOR_TIMESTEPS[dataset][args.steps]
            args.custom_ts_2 = args.custom_ts_1
        else:
            raise NotImplementedError
def set_seed_everything(seed):
    """Seed the Python, PyTorch and NumPy random number generators.

    When CUDA is available, ``torch.cuda.manual_seed`` is called as well and
    cuDNN is switched to deterministic mode with benchmarking turned off.
    """
    for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def parse_arguments():
    """Parse command-line options and merge in values from a YAML config.

    Options are read from ``sys.argv``. When ``--all_config`` names an
    existing file, it is loaded with ``yaml.safe_load`` and each of its keys
    is copied onto the result if the attribute is missing or still ``None``.
    Values given on the command line therefore take precedence. An
    ``--all_config`` path that is not a file is ignored.

    Returns:
        argparse.Namespace: the merged arguments.

    Raises:
        ValueError: the config file's top-level value is not a mapping.
    """
    parser = argparse.ArgumentParser(description="Description of your program")
    parser.add_argument('--all_config')
    parser.add_argument('--model', help="edm/latent_diff")

    model_group = parser.add_argument_group('Model Parameters')
    model_group.add_argument("--ckp_path", type=str, help="Path to the checkpoint file.")
    model_group.add_argument("--solver_name", type=str, help="Method for solving: heun/dpm_solver++/uni_pc.")
    model_group.add_argument("--unipc_variant", type=str, choices=["bh1", "bh2"], help="Variant of UniPC: bh1/bh2.")
    model_group.add_argument("--steps", type=int, help="Number of sampling steps.")
    model_group.add_argument("--order", type=int, help="Order for sampling.")
    model_group.add_argument("--time_mode", type=str, help="Time model: time or lambda.")

    training_group = parser.add_argument_group('Training Parameters')
    training_group.add_argument("--seed", type=int, help="seed")
    training_group.add_argument("--use_ema", action="store_true", help="If we use ema for LSUN latent diff")
    training_group.add_argument("--log_path", type=str, help="Folder name for storing evaluation results.")
    training_group.add_argument("--old_log_path", type=str, help="Folder name for storing old evaluation results.")
    training_group.add_argument("--data_dir", type=str, help="Path to data dir.")
    training_group.add_argument("--num_train", type=int, help="Number of training sample.")
    training_group.add_argument("--num_valid", type=int, help="Number of validation sample.")
    training_group.add_argument("--main_train_batch_size", type=int, help="Batch size for training.")
    training_group.add_argument("--main_valid_batch_size", type=int, help="Batch size for validation.")
    training_group.add_argument("--win_rate", type=float, help="Win rate, should be in (0, 0.5]")
    training_group.add_argument("--prior_bound", type=float, help="Prior bound.")
    training_group.add_argument("--fix_bound", action="store_true", help="fix bound or not")
    training_group.add_argument("--loss_type", type=str, choices=["L1", "L2", "LPIPS"], help="Type of loss: L1, L2 or LPIPS.")
    training_group.add_argument("--training_rounds_v1", type=int, help="Number of training rounds for phase 1.")
    training_group.add_argument("--training_rounds_v2", type=int, help="Number of training rounds for phase 2.")
    training_group.add_argument("--lr_time_1", type=float, help="Learning rate for the first phase.")
    training_group.add_argument("--lr_time_2", type=float, help="Learning rate for the second phase.")
    training_group.add_argument("--min_lr_time_1", type=float, help="Minimum learning rate for the first phase.")
    training_group.add_argument("--min_lr_time_2", type=float, help="Minimum learning rate for the second phase.")
    training_group.add_argument("--momentum_time_1", type=float, help="Momentum for the first phase.")
    training_group.add_argument("--weight_decay_time_1", type=float, help="Weight decay for the first phase.")
    training_group.add_argument("--shift_lr", type=float, help="Learning rate for moving latents.")
    training_group.add_argument("--shift_lr_decay", type=float, help="Learning rate decay for the shift phase.")
    training_group.add_argument("--lr_time_decay", type=float, help="Learning rate decay for the time phase.")
    training_group.add_argument("--patient", type=int, help="Patient for the time phase.")
    training_group.add_argument("--lr2_patient", type=int, help="Patient for the second phase.")
    training_group.add_argument("--no_v1", action="store_true", help="Skip the first phase.")
    training_group.add_argument("--visualize", action="store_true", help="Visualize.")
    training_group.add_argument("--low_gpu", action="store_true", help="If we using low-mem gpu, we need to use checkpoint.")
    training_group.add_argument("--scale", type=int, help="Guidance scale")
    training_group.add_argument("--match_prior", action="store_true", help="Whether to initial params by prior timesteps")

    testing_group = parser.add_argument_group('Testing Parameters')
    testing_group.add_argument("--load_from_version", type=int, default=2, help="Load from whihc version, default=2")
    testing_group.add_argument("--custom_ts_1", type=str, help="Custom timesteps 1")
    testing_group.add_argument("--custom_ts_2", type=str, help="Custom timesteps 2")
    testing_group.add_argument("--use_gits", action="store_true", help="Use pre-computed gits timesteps")
    testing_group.add_argument("--learn", action="store_true", help="Load from learned timesteps.")
    testing_group.add_argument("--load_from", type=str, help="Ckpt path")
    testing_group.add_argument("--skip_type", type=str, help="Type of skip.")
    testing_group.add_argument("--num_multi_steps_fid", type=int, help="num_multi_steps_fid")
    testing_group.add_argument("--fid_folder", type=str, default=None, help="FID path")
    testing_group.add_argument("--sampling_batch_size", type=int, help="Batch size for FID calculation.")
    testing_group.add_argument("--sampling_seed", type=int, help="Sampling seed for FID calculation")
    testing_group.add_argument("--ref_path", type=str, help="Path to dataset reference statistics.")
    testing_group.add_argument("--total_samples", type=int, help="Total number of sample for FID calculation.")
    testing_group.add_argument("--save_png", action="store_true", help="Save generated img in png.")
    testing_group.add_argument("--save_pt", action="store_true", help="Save generated img and latent in pt files.")

    other_group = parser.add_argument_group('Other Parameters')
    other_group.add_argument("--prompt_path", type=str, help="Prompt json path for stable diff")
    other_group.add_argument("--num_prompts", type=int, default=5, help="Number of prompts we want to use, default 5")
    other_group.add_argument("--num_samples_per_prompt", type=int, default=1, help="Number of samplers per prompt, default 1")

    args = parser.parse_args()

    # Load the config file if specified
    if args.all_config and os.path.isfile(args.all_config):
        with open(args.all_config, 'r') as f:
            config = yaml.safe_load(f)
        # An empty YAML document loads as None; treat it as "no overrides"
        # instead of failing on None.items().
        if config is None:
            config = {}
        if not isinstance(config, dict):
            raise ValueError(
                f"Config file {args.all_config} must contain a mapping at the top level, "
                f"got {type(config).__name__}"
            )
        # Fill in arguments from the config only where they are still None.
        # NOTE(review): options with a non-None default (the store_true flags,
        # --load_from_version, --num_prompts, --num_samples_per_prompt) are never
        # None, so a config value for them is ignored -- confirm this is intended.
        for key, value in config.items():
            if getattr(args, key, None) is None:
                setattr(args, key, value)
    return args
def compute_distance_between_two(x, y, n_channels=3, resolution=256):
    '''
    Per-sample squared error between x and y: the squared differences are
    summed over dimensions 1-3 and divided by
    n_channels * resolution * resolution.

    x: bs x 3 x 256 x 256
    y: bs x 3 x 256 x 256
    '''
    n_elements = n_channels * resolution * resolution
    per_sample_sum = (x - y).pow(2).sum(dim=(1, 2, 3))
    return per_sample_sum / n_elements
def compute_distance_between_two_L1(x, y, n_channels=3, resolution=256):
    '''
    Per-sample absolute error between x and y: the absolute differences are
    summed over dimensions 1-3 and divided by
    n_channels * resolution * resolution.

    x: bs x 3 x 256 x 256
    y: bs x 3 x 256 x 256
    '''
    n_elements = n_channels * resolution * resolution
    per_sample_sum = (x - y).abs().sum(dim=(1, 2, 3))
    return per_sample_sum / n_elements
def get_solvers(solver_name: str, NFEs: int, order:int, noise_schedule: NoiseScheduleVE, unipc_variant: Optional[str] = None):
    """Build the solver named ``solver_name`` and its step count.

    Args:
        solver_name: one of 'euler', 'heun', 'dpm_solver', 'dpm_solver++',
            'uni_pc' or 'ipndm'.
        NFEs: evaluation budget; used directly as the step count except for
            'heun' (``NFEs // 2``) and 'dpm_solver' (taken from
            ``compute_K_and_order``).
        order: only used by 'dpm_solver'.
        noise_schedule: passed to the solver constructor.
        unipc_variant: only used by 'uni_pc'.

    Returns:
        ``(solver, steps, solver_extra_params)``. The extras dict is empty
        except for 'dpm_solver', where it holds 'dpm_orders', 'NFEs' and
        'dpm_steps'.

    Raises:
        NotImplementedError: ``solver_name`` is not one of the names above.
    """
    extras = {}

    if solver_name == 'dpm_solver':
        dpm = DPM_Solver(noise_schedule)
        dpm_steps, dpm_orders = dpm.compute_K_and_order(NFEs, order=order)
        extras['dpm_orders'] = dpm_orders
        extras['NFEs'] = NFEs
        extras['dpm_steps'] = dpm_steps
        return dpm, dpm_steps, extras

    if solver_name == 'heun':
        return Heun(noise_schedule), NFEs // 2, extras

    # Every remaining solver takes one step per function evaluation.
    if solver_name == 'euler':
        chosen = Euler(noise_schedule)
    elif solver_name == 'dpm_solver++':
        chosen = DPM_SolverPP(noise_schedule)
    elif solver_name == 'uni_pc':
        chosen = UniPC(noise_schedule, variant=unipc_variant)
    elif solver_name == 'ipndm':
        chosen = iPNDM(noise_schedule)
    else:
        raise NotImplementedError
    return chosen, NFEs, extras
def save_arguments_to_yaml(args, filename):
    """Write the attributes of ``args`` (as returned by ``vars``) to ``filename`` as YAML.

    The file is opened in write mode, so existing content is replaced.
    """
    with open(filename, 'w') as out_stream:
        yaml.dump(vars(args), stream=out_stream)
def adjust_hyper(args, resolution=64, channel=3):
    """Derive step-dependent hyper-parameters on ``args`` in place.

    Resolves the timestep schedules, fills in ``shift_lr`` when it is unset,
    recomputes ``prior_bound`` unless ``fix_bound`` is set, divides
    ``lr_time_2`` by the number of steps, and rounds all three values to
    8 decimals.

    Args:
        args: argument namespace; mutated and also returned.
        resolution: side length used in the ``prior_bound`` formula.
        channel: channel count used in the ``prior_bound`` formula.

    Returns:
        The same ``args`` object.

    NOTE(review): ``lr_time_2`` must be a number here, and ``prior_bound``
    must be set when ``fix_bound`` is true; otherwise the arithmetic or
    ``round`` calls below raise TypeError.
    """
    # Resolve custom / pre-computed timesteps first (mutates args).
    parse_prior_timesteps(args)
    # Default shift learning rate: 3.0 * 4 / steps.
    if args.shift_lr is None:
        args.shift_lr = 3.0 * 4 / args.steps
    # Unless the bound is fixed, derive it from the sample size and step count.
    if not args.fix_bound:
        args.prior_bound = 0.001 * resolution * resolution * channel / (args.steps ** 2)
    # NOTE(review): indentation was reconstructed; the lr_time_2 scaling is read
    # as unconditional (outside the fix_bound branch) -- confirm.
    args.lr_time_2 = args.lr_time_2 / args.steps
    args.lr_time_2 = round(args.lr_time_2, 8)
    # round prior_bound
    args.prior_bound = round(args.prior_bound, 8)
    # round shift_lr
    args.shift_lr = round(args.shift_lr, 8)
    return args
def create_desc(args):
    """Build the run description string from the training settings on ``args``.

    The string combines solver name, step count, prior bound, loss type,
    second-phase learning rate, the round counts and the seed, followed by
    optional ``-no_v1_only_v2`` / ``-match_prior`` suffixes.
    """
    # The format has no separator between the lr2 value and "rv1"; it is kept
    # as is because prepare_paths() uses this string as a directory name.
    pieces = [
        f"{args.solver_name}-N{args.steps}-b{args.prior_bound}-{args.loss_type}-lr2{args.lr_time_2}",
        f"rv1{args.training_rounds_v1}-rv2{args.training_rounds_v2}-seed{args.seed}",
    ]
    if args.no_v1:
        pieces.append("-no_v1_only_v2")
    if args.match_prior:
        pieces.append("-match_prior")
    return "".join(pieces)
def prepare_paths(args):
    """Work out the run description and FID log path; may update ``args``.

    * ``args.learn`` with no ``args.load_from``: the description comes from
      ``create_desc``; ``args.log_path`` gets it appended and
      ``args.load_from`` is pointed at ``best_v{load_from_version}.pt`` inside.
    * ``args.learn`` with ``args.load_from``: ``args.log_path`` becomes the
      checkpoint's directory, whose name is the description.
    * otherwise: the description is built from solver name, steps, skip type
      and seed.

    When ``args.fid_folder`` is set, the folder is created if needed.

    Returns:
        ``(desc, fid_log_path, skip_type)``. ``fid_log_path`` is ``None``
        without a FID folder; ``skip_type`` is ``""`` in the learn branches.
    """
    skip_type = ""
    if not args.learn:
        skip_type = args.skip_type
        desc = f"{args.solver_name}_NFE{args.steps}_{skip_type}_seed{args.seed}"
    elif args.load_from is not None:
        args.log_path = os.path.dirname(args.load_from)
        desc = os.path.basename(args.log_path)
    else:
        desc = create_desc(args)
        args.log_path = os.path.join(args.log_path, desc)
        args.load_from = os.path.join(args.log_path, f'best_v{args.load_from_version}.pt')
    # NOTE: no check that the run actually finished training is made here
    # (the is_trained() guard is disabled).

    # create fid folder
    fid_log_path = None
    if args.fid_folder:
        os.makedirs(args.fid_folder, exist_ok=True)
        fid_log_path = os.path.join(args.fid_folder, f"{desc}.txt")
    return desc, fid_log_path, skip_type
def check_fid_file(fid_log_path):
    """Report whether ``fid_log_path`` already holds computed scores.

    The file counts as complete when it contains exactly one number (FID) or
    two numbers (FID then IS) separated by whitespace; the scores are printed
    in that case.

    Args:
        fid_log_path: path of the score file, or ``None``.

    Returns:
        ``True`` when valid scores were found and printed. ``False`` when the
        path is ``None``, is not an existing file, or the content is not one
        or two numbers.
    """
    # prepare_paths() returns None as the log path when no FID folder is set;
    # os.path.exists(None) would raise TypeError. A directory is rejected too
    # rather than failing in open().
    if fid_log_path is None or not os.path.isfile(fid_log_path):
        return False

    with open(fid_log_path, "r") as f:
        tokens = f.read().split()

    # check if every entry is a number
    try:
        scores = [float(token) for token in tokens]
    except ValueError:
        return False

    if len(scores) not in (1, 2):
        return False

    print(f"FID: {scores[0]}")
    if len(scores) == 2:
        print(f"IS: {scores[1]}")
    return True
def is_trained(path):
    """Tell whether the run in ``path`` finished training.

    Looks at ``<path>/log.txt`` and returns ``True`` only if its last
    non-blank line contains the text "Training time". Returns ``False`` when
    the log file is missing. The log path is printed on every call.
    """
    log_path = os.path.join(path, 'log.txt')
    print(log_path)
    if os.path.isfile(log_path):
        final_line = ""
        with open(log_path, 'r') as f:
            # Keep only the most recent line that has visible content.
            for stripped in map(str.strip, f):
                if stripped:
                    final_line = stripped
        return "Training time" in final_line
    print("log.txt not exist")
    return False
def move_tensor_to_device(*args, device):
    """Move every non-None positional argument to ``device``.

    Args:
        *args: objects exposing ``.to(device)``; ``None`` entries are allowed.
        device: keyword-only target passed to each ``.to`` call.

    Returns:
        A list with one entry per argument, in the same order; ``None``
        entries are passed through unchanged.
    """
    return [arg if arg is None else arg.to(device) for arg in args]