|
|
import logging
import math
import os
import pickle
from dataclasses import dataclass
from typing import Any, List, Optional

import imageio
import lpips
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader

from dataset import LD3Dataset
from utils import move_tensor_to_device, compute_distance_between_two, compute_distance_between_two_L1
|
|
|
|
|
def save_gif(snapshot_path: str):
    """Assemble the ``log_best_<iter>.png`` plots in a directory into ``gif.gif``.

    Frames are ordered by the integer that follows the last underscore in the
    file name. Files that contain ``log_best`` but are not ``.png`` files, or
    whose suffix is not an integer, are skipped instead of aborting the call.
    When no usable frame exists nothing is written.

    Args:
        snapshot_path: Directory that holds the PNG frames and receives the GIF.
    """
    indexed_frames = []
    for file_name in os.listdir(snapshot_path):
        if "log_best" not in file_name or not file_name.endswith(".png"):
            continue
        try:
            frame_index = int(file_name.split("_")[-1].replace(".png", ""))
        except ValueError:
            # e.g. "log_best_final.png": no iteration number to sort by.
            continue
        indexed_frames.append((frame_index, file_name))

    if not indexed_frames:
        # Nothing to animate; avoid handing imageio an empty frame list.
        return

    indexed_frames.sort()
    images = [
        imageio.imread(os.path.join(snapshot_path, file_name))
        for _, file_name in indexed_frames
    ]
    gif_path = os.path.join(snapshot_path, "gif.gif")
    imageio.mimsave(gif_path, images, duration=100.)
    print(f"Saved gif to {gif_path}")
|
|
|
|
|
|
|
|
def visual(input_, name="test.png", img_resolution=32, img_channels=3):
    """Save a batch of images as a single tiled image file.

    The batch is laid out on a ``gridh x gridw`` grid where ``gridh`` is the
    largest divisor of the batch size that does not exceed its square root.

    Args:
        input_: Tensor reshaped below as ``(batch, channels, height, width)``;
            values are mapped from ``[-1, 1]`` to ``[0, 255]`` and clipped.
        name: Output file path handed to ``PIL.Image.save``.
        img_resolution: Side length of each tile; must match the spatial size
            of ``input_`` for the final reshape to succeed.
        img_channels: Number of channels per tile (1, 3 or 4).

    Raises:
        ValueError: If the batch is empty or ``img_channels`` has no matching
            8-bit image mode.
    """
    batch_size = input_.shape[0]
    if batch_size == 0:
        raise ValueError("visual() needs at least one image, got an empty batch")
    if img_channels not in (1, 3, 4):
        raise ValueError(f"Unsupported img_channels={img_channels}; expected 1, 3 or 4")

    input_ = (input_ + 1.) / 2.

    # Largest divisor of batch_size that is <= sqrt(batch_size); 1 always qualifies.
    gridh = max(
        i for i in range(1, int(math.sqrt(batch_size)) + 1) if batch_size % i == 0
    )
    gridw = batch_size // gridh

    image = (input_ * 255.).clip(0, 255).to(torch.uint8)
    # (gridh, gridw, C, H, W) -> (gridh, H, gridw, W, C) so tiles sit side by side.
    image = image.reshape(gridh, gridw, *image.shape[1:]).permute(0, 3, 1, 4, 2)
    image = image.reshape(gridh * img_resolution, gridw * img_resolution, img_channels)
    image = image.cpu().numpy()
    if img_channels == 1:
        # Pillow expects a 2-D array for single-channel ("L") images.
        image = image[:, :, 0]

    # Import the submodule explicitly: a bare ``import PIL`` does not load PIL.Image.
    from PIL import Image

    # Mode is inferred from the array shape: (H, W) -> L, 3 -> RGB, 4 -> RGBA.
    Image.fromarray(image).save(name)
|
|
|
|
|
def custom_collate_fn(batch):
    """Collate a list of sample tuples field by field, tolerating ``None``.

    For each field position, if any sample holds ``None`` there, the collated
    value for that field is ``None``; otherwise the values are combined with
    PyTorch's default collate function.

    Args:
        batch: Sequence of equally long sample tuples.

    Returns:
        A list with one collated entry (or ``None``) per field.
    """
    default_collate = torch.utils.data._utils.collate.default_collate

    def _collate_field(values):
        has_missing = any(value is None for value in values)
        return None if has_missing else default_collate(values)

    return [_collate_field(values) for values in zip(*batch)]
|
|
|
|
|
@dataclass
class TrainingConfig:
    """Hyper-parameters and data handed to ``LD3Trainer``."""

    # Datasets wrapped in DataLoaders by the trainer (any Dataset-like object).
    train_data: Any
    valid_data: Any
    train_batch_size: int
    valid_batch_size: int
    # Initial learning rates of the RMSprop (params1) and SGD (params2) optimizers.
    lr_time_1: float
    lr_time_2: float
    # Learning rate of the per-batch SGD step applied to the latents.
    shift_lr: float
    # Factor applied to shift_lr after a round in which no validation latent changed.
    shift_lr_decay: float = 0.5
    # Lower bounds for the two time learning rates when they are decayed.
    min_lr_time_1: float = 5e-5
    min_lr_time_2: float = 1e-6
    # Forwarded as ``window_rate`` to ``discretize_model_wrapper``.
    win_rate: float = 0.5
    # Non-improving validations tolerated before reloading the best checkpoint
    # and decaying the learning rates.
    patient: int = 5
    # Training stops once the second learning rate has hit its minimum this many times.
    lr2_patient: int = 5
    # Multiplicative decay applied to both time learning rates.
    lr_time_decay: float = 0.8
    # RMSprop momentum / weight decay for params1.
    momentum_time_1: float = 0.9
    weight_decay_time_1: float = 0.0
    # One of "LPIPS", "L2", "L1".
    loss_type: str = "LPIPS"
    # When True, image grids of outputs vs. targets are written during validation.
    visualize: bool = False
    # When True, params1 is frozen and only params2 is trained.
    no_v1: bool = False
    # Optional reference timesteps used when match_prior is enabled.
    prior_timesteps: Optional[List[float]] = None
    match_prior: bool = False
|
|
|
|
|
@dataclass
class ModelConfig:
    """Model, solver and sampling settings handed to ``LD3Trainer``."""

    # Passed as ``model_fn`` to ``solver.sample_simple``.
    net: Any
    # Applied to the solver output before the loss is computed.
    decoding_fn: Any
    # Must provide lambda_min, lambda_max, inverse_lambda, marginal_lambda
    # and prior_transformation.
    noise_schedule: Any
    # Must provide sample_simple and get_time_steps.
    solver: Any
    solver_name: str
    order: int
    # Number of solver steps (NFEs); steps + 1 timesteps are learned.
    steps: int
    # Maximum L2 distance a latent may move away from its original value;
    # 0 disables the latent update on the validation set.
    prior_bound: float
    # Spatial size and channel count used to reshape the latents.
    resolution: int
    channels: int
    # "time" selects the time parameterization; any other value the log-SNR one.
    time_mode: str
    # Extra keyword arguments forwarded to ``solver.sample_simple``.
    solver_extra_params: Optional[dict] = None
    # Directory that receives checkpoints, plots and pickled datasets.
    snapshot_path: str = "logs"
    # Falls back to CUDA when available, otherwise CPU.
    device: Optional[str] = None
|
|
|
|
|
class LD3Trainer:
    """Optimizes the solver's time discretization against target images.

    Two parameter vectors of length ``steps + 1`` are learned: ``params1`` is
    turned into the main timestep grid by ``discretize_model_wrapper`` and
    ``params2`` into bounded per-step offsets. Training runs in rounds; rounds
    below ``training_rounds_v1`` ("Ver1") train only ``params1``, later rounds
    ("Ver2") train both. When ``prior_bound > 0`` the input latents are also
    nudged by gradient steps and kept only where the loss improved.
    """

    def __init__(
        self, model_config: "ModelConfig", training_config: "TrainingConfig"
    ) -> None:
        # Model and solver configuration.
        self.net = model_config.net
        self.decoding_fn = model_config.decoding_fn
        self.noise_schedule = model_config.noise_schedule
        self.solver = model_config.solver
        self.solver_name = model_config.solver_name
        self.order = model_config.order
        self.steps = model_config.steps
        self.prior_bound = model_config.prior_bound
        self.resolution = model_config.resolution
        self.channels = model_config.channels
        self.time_mode = model_config.time_mode

        # Optimization hyper-parameters.
        self.lr_time_1 = training_config.lr_time_1
        self.lr_time_2 = training_config.lr_time_2
        self.shift_lr = training_config.shift_lr
        self.shift_lr_decay = training_config.shift_lr_decay
        self.min_lr_time_1 = training_config.min_lr_time_1
        self.min_lr_time_2 = training_config.min_lr_time_2
        self.lr_time_decay = training_config.lr_time_decay
        self.momentum_time_1 = training_config.momentum_time_1
        self.weight_decay_time_1 = training_config.weight_decay_time_1

        # Data.
        self.train_data = training_config.train_data
        self.valid_data = training_config.valid_data
        self.train_batch_size = training_config.train_batch_size
        self.valid_batch_size = training_config.valid_batch_size
        self._create_valid_loaders()
        self._create_train_loader()

        # Training progress.
        self.cur_iter = 0
        self.cur_round = 0
        self.count_worse = 0
        self.count_min_lr1_hit = 0
        self.count_min_lr2_hit = 0
        self.best_loss = float("inf")
        # Overwritten by train(); defined here so version checks (used for
        # logging in _solve_ode) do not fail when called before train().
        self.training_rounds_v1 = 0

        # Schedules, logging and output location.
        self.patient = training_config.patient
        self.lr2_patient = training_config.lr2_patient
        self.no_v1 = training_config.no_v1
        self.win_rate = training_config.win_rate
        self.snapshot_path = model_config.snapshot_path
        os.makedirs(self.snapshot_path, exist_ok=True)
        self.visualize = training_config.visualize

        # Learnable discretization parameters and their optimizers.
        self._set_device(model_config.device)
        self.params1, self.params2 = self._initialize_params()
        self.optimizer_lamb1 = torch.optim.RMSprop(
            [self.params1],
            lr=training_config.lr_time_1,
            momentum=training_config.momentum_time_1,
            weight_decay=training_config.weight_decay_time_1,
        )
        self.optimizer_lamb2 = torch.optim.SGD(
            [self.params2], lr=training_config.lr_time_2
        )
        self.prior_timesteps = training_config.prior_timesteps
        self.match_prior = training_config.match_prior

        # Solver extras and the lambda / time range of the noise schedule.
        self.solver_extra_params = model_config.solver_extra_params or {}
        self.lambda_min = self.noise_schedule.lambda_min
        self.lambda_max = self.noise_schedule.lambda_max
        self.time_max = self.noise_schedule.inverse_lambda(self.lambda_min)
        self.time_min = self.noise_schedule.inverse_lambda(self.lambda_max)

        # Baseline discretizations (only used for plotting).
        self._compute_baseline()

        # Loss.
        self.loss_type = training_config.loss_type
        self.loss_fn = self._initialize_loss_fn()
        self.loss_vector = None

    def _train_to_match_prior(self, prior_timesteps=None):
        """Fit ``params1`` so the learned timesteps reproduce a given grid.

        Runs optimizer steps on the mean squared error between the first
        output of the discretization model and the prior timesteps until it
        drops to 1e-3 or below. Does nothing when no prior timesteps are set.

        Args:
            prior_timesteps: Values converted with ``-log`` and
                ``noise_schedule.inverse_lambda``; defaults to
                ``self.prior_timesteps``.
        """
        if prior_timesteps is None:
            prior_timesteps = self.prior_timesteps
        if prior_timesteps is None:
            return

        logging.info("Matching prior timesteps")
        prior_timesteps = self.noise_schedule.inverse_lambda(-np.log(prior_timesteps)).to(self.device).float()

        dis_model = discretize_model_wrapper(
            self.params1,
            self.params2,
            self.lambda_max,
            self.lambda_min,
            self.noise_schedule,
            self.time_mode,
            self.win_rate,
        )

        self.params1.requires_grad = True
        self.params2.requires_grad = False

        # NOTE(review): there is no iteration cap; this loops until the
        # threshold is met.
        loss_time = float("inf")
        while loss_time > 1e-3:
            self.optimizer_lamb1.zero_grad()
            self.optimizer_lamb2.zero_grad()
            times1, times2 = dis_model()
            loss_time = (times1 - prior_timesteps).pow(2).mean()
            logging.info(f"Loss time: {loss_time}")
            loss_time.backward()
            self.optimizer_lamb1.step()

    def _initialize_loss_fn(self):
        """Return the per-sample loss callable selected by ``self.loss_type``.

        Raises:
            NotImplementedError: For a loss type other than LPIPS, L2 or L1.
        """
        if self.loss_type == 'LPIPS':
            return lpips.LPIPS(net='vgg').to(self.device)
        if self.loss_type == 'L2':
            return lambda x, y: compute_distance_between_two(x, y, self.channels, self.resolution)
        if self.loss_type == 'L1':
            return lambda x, y: compute_distance_between_two_L1(x, y, self.channels, self.resolution)
        raise NotImplementedError(f"Unsupported loss_type: {self.loss_type!r}")

    def _initialize_params(self):
        """Create ``params1`` (ones) and ``params2`` (zeros) on ``self.device``."""
        # Allocate on the configured device rather than unconditionally on CUDA
        # so the trainer also works on CPU-only hosts.
        params1 = torch.nn.Parameter(
            torch.ones(self.steps + 1, dtype=torch.float32, device=self.device),
            requires_grad=True,
        )
        params2 = torch.nn.Parameter(
            torch.zeros(self.steps + 1, dtype=torch.float32, device=self.device),
            requires_grad=True,
        )
        return params1, params2

    def _set_device(self, device):
        """Store ``device``, defaulting to CUDA when available and CPU otherwise."""
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _create_valid_loaders(self):
        """Build the two non-shuffled loaders over the validation data."""
        # valid_loader feeds the latent-update pass of a training round and
        # uses the training batch size; valid_only_loader is used for scoring.
        self.valid_loader = DataLoader(self.valid_data, batch_size=self.train_batch_size, shuffle=False, collate_fn=custom_collate_fn)
        self.valid_only_loader = DataLoader(self.valid_data, batch_size=self.valid_batch_size, shuffle=False, collate_fn=custom_collate_fn)

    def _create_train_loader(self):
        """Build the shuffled loader over the training data."""
        self.train_loader = DataLoader(self.train_data, batch_size=self.train_batch_size, shuffle=True, collate_fn=custom_collate_fn)

    def _solve_ode(self, timesteps=None, img=None, latent=None, condition=None, uncondition=None, valid=False):
        """Sample from ``latent`` with the current timesteps and score against ``img``.

        Args:
            timesteps: Optional fixed grid used for both timestep sequences;
                when ``None`` the learned discretization is used.
            img: Target images.
            latent: Latents, reshaped to ``(batch, channels, resolution, resolution)``.
            condition: Optional conditioning forwarded to the solver.
            uncondition: Optional unconditional conditioning forwarded to the solver.
            valid: When False and the learned grid is used, the timesteps are
                also written to ``t_steps.pt``.

        Returns:
            ``(mean loss, decoded output as float, img as float)``. The
            per-sample losses are kept in ``self.loss_vector``.
        """
        batch_size = latent.shape[0]
        latent = latent.reshape(batch_size, self.channels, self.resolution, self.resolution)

        use_learned_steps = timesteps is None
        if use_learned_steps:
            dis_model = discretize_model_wrapper(
                self.params1,
                self.params2,
                self.lambda_max,
                self.lambda_min,
                self.noise_schedule,
                self.time_mode,
                self.win_rate,
            )
            timesteps1, timesteps2 = dis_model()
        else:
            timesteps1 = timesteps
            timesteps2 = timesteps

        if not valid and use_learned_steps:
            tst = torch.cat([timesteps1, timesteps2], dim=0).detach().cpu()
            torch.save(tst, os.path.join(self.snapshot_path, "t_steps.pt"))

        # Cached for checkpointing and plotting.
        self.t_steps1 = timesteps1.detach()
        self.t_steps2 = timesteps2.detach()
        lamb1 = self.noise_schedule.marginal_lambda(timesteps1)
        lamb2 = self.noise_schedule.marginal_lambda(timesteps2)
        self.logSNR1 = lamb1.detach().cpu()
        self.logSNR2 = lamb2.detach().cpu()

        x_next_ = self.noise_schedule.prior_transformation(latent)
        x_next_ = self.solver.sample_simple(
            model_fn=self.net,
            x=x_next_,
            timesteps=timesteps1,
            timesteps2=timesteps2,
            order=self.order,
            NFEs=self.steps,
            condition=condition,
            unconditional_condition=uncondition,
            **self.solver_extra_params,
        )
        x_next_ = self.decoding_fn(x_next_)
        self.loss_vector = self.loss_fn(img.float(), x_next_.float()).squeeze()
        loss = self.loss_vector.mean()
        logging.info(f"{self._current_version} Loss: {loss.item()}")

        return loss, x_next_.float(), img.float()

    @property
    def _current_version(self):
        """Label of the current training stage: ``'Ver1'`` or ``'Ver2'``."""
        return 'Ver1' if self._is_in_version_1() else 'Ver2'

    def _is_in_version_1(self):
        """Return True while the current round is below ``training_rounds_v1``."""
        return self.cur_round < self.training_rounds_v1

    def _compute_baseline(self):
        """Precompute the baseline discretizations drawn by ``_visual_times``."""
        # Uniform in log-SNR.
        self.straight_line = torch.linspace(self.lambda_min, self.lambda_max, self.steps + 1)
        self.time_logSNR = self.noise_schedule.inverse_lambda(self.straight_line).to(self.device)

        time_max = self.noise_schedule.inverse_lambda(self.lambda_min)
        time_min = self.noise_schedule.inverse_lambda(self.lambda_max)

        # Uniform in time.
        self.time_s = torch.linspace(time_max.item(), time_min.item(), 1000)
        self.time_straight = torch.linspace(time_max.item(), time_min.item(), self.steps + 1)
        self.time_straight = self.time_straight.to(self.device)
        self.straight_time = self.noise_schedule.marginal_lambda(self.time_s)

        # Quadratic in time: uniform in t**(1/2), then squared.
        t_order = 2
        q_start = (time_max**(1/t_order)).item()
        q_end = (time_min**(1/t_order)).item()
        self.time_q = torch.linspace(q_start, q_end, 1000)**t_order
        self.quadratic_time = torch.linspace(q_start, q_end, self.steps + 1)**t_order
        self.quadratic_time = self.quadratic_time.to(self.device)
        self.time_quadratic = self.noise_schedule.marginal_lambda(self.time_q)

        # 'edm' grid as produced by the solver.
        self.time_edm = self.solver.get_time_steps('edm', time_max.item(), time_min.item(), 999, self.device)
        self.lambda_edm = self.noise_schedule.marginal_lambda(self.time_edm)

    def _run_validation(self):
        """Evaluate the current discretization on ``valid_only_loader``.

        Returns:
            ``(mean of the per-batch losses, all outputs, all targets)`` with
            outputs and targets concatenated along the batch dimension.
        """
        total_loss = 0.
        count = 0
        outputs = []
        targets = []
        with torch.no_grad():
            # The original latent (third field) is not needed for scoring.
            for img, latent, _ori_latent, condition, uncondition in self.valid_only_loader:
                img = img.to(self.device)
                latent = latent.to(self.device).reshape(latent.shape[0], -1)
                if condition is not None:
                    condition = condition.to(self.device)
                if uncondition is not None:
                    uncondition = uncondition.to(self.device)
                loss, output, target = self._solve_ode(img=img, latent=latent, condition=condition, uncondition=uncondition, valid=True)

                total_loss += loss.item()
                count += 1
                outputs.append(output)
                targets.append(target)

        output = torch.cat(outputs, dim=0)
        target = torch.cat(targets, dim=0)
        return total_loss / count, output, target

    def _visual_times(self) -> None:
        """Plot the learned log-SNR steps against the baseline discretizations.

        The figure is saved as ``log_best_<cur_iter>.png`` in the snapshot
        directory.
        """
        log_path = os.path.join(self.snapshot_path, f"log_best_{self.cur_iter}.png")

        plt.plot(self.logSNR1.cpu().numpy(), 'o', label="Our discretization1")
        plt.plot(self.logSNR2.cpu().numpy(), 'x', label="Our discretization2")
        x_axis = np.linspace(0, self.steps, self.steps + 1)
        plt.plot(x_axis, self.straight_line.cpu().numpy(), label="Baseline logSNR")
        # The dense baselines are stretched over the same step axis.
        x_axis = np.linspace(0, self.steps, 1000)
        plt.plot(x_axis, self.straight_time.cpu().numpy(), label="Baseline time uniform")
        plt.plot(x_axis, self.time_quadratic.cpu().numpy(), label="Baseline time quadratic")
        plt.plot(x_axis, self.lambda_edm.cpu().numpy(), label="Baseline time edm")

        plt.xlabel("Reverse step i")
        plt.ylabel("LogSNR(t_i)")
        plt.legend()
        plt.tight_layout()
        plt.savefig(log_path)
        plt.close()

    def _save_checkpoint(self):
        """Write the current parameters, timesteps and datasets to disk."""
        snapshot = {
            "params1": self.params1.data,
            "params2": self.params2.data,
            "best_t_steps": torch.cat([self.t_steps1, self.t_steps2], dim=0),
        }

        if self._is_in_version_1():
            torch.save(snapshot, os.path.join(self.snapshot_path, "best_v1.pt"))
        # Written in both stages, so best_v2.pt exists when Ver2 first loads it.
        torch.save(snapshot, os.path.join(self.snapshot_path, "best_v2.pt"))
        torch.save(snapshot, os.path.join(self.snapshot_path, f"best_t_steps_{self.cur_iter}.pt"))

        # Context managers make sure the files are flushed and closed.
        with open(os.path.join(self.snapshot_path, "train_data.pkl"), "wb") as handle:
            pickle.dump(self.train_data, handle)
        with open(os.path.join(self.snapshot_path, "valid_data.pkl"), "wb") as handle:
            pickle.dump(self.valid_data, handle)

    def _load_checkpoint(self, reload_data: bool):
        """Restore the best parameters of the current stage.

        Args:
            reload_data: When True, the pickled datasets are restored as well
                and the data loaders are rebuilt.
        """
        file_name = "best_v1.pt" if self._is_in_version_1() else "best_v2.pt"
        snapshot = torch.load(os.path.join(self.snapshot_path, file_name), map_location=self.device)

        # Restore onto the configured device instead of assuming CUDA.
        self.params1.data = snapshot["params1"].to(self.device)
        self.params2.data = snapshot["params2"].to(self.device)

        if reload_data:
            # NOTE: pickle executes arbitrary code on load; only snapshot
            # directories written by this trainer should be read here.
            with open(os.path.join(self.snapshot_path, "train_data.pkl"), "rb") as handle:
                self.train_data = pickle.load(handle)
            with open(os.path.join(self.snapshot_path, "valid_data.pkl"), "rb") as handle:
                self.valid_data = pickle.load(handle)
            self._create_train_loader()
            self._create_valid_loaders()

    def _examine_checkpoint(self, iter: int) -> None:
        """Validate, keep the best checkpoint and apply the patience schedule.

        On improvement the checkpoint, the discretization plot and the GIF are
        refreshed. After ``patient`` non-improving calls the best checkpoint
        (including data) is reloaded and the learning rates are decayed.

        Args:
            iter: Iteration number used for logging and file names.
        """
        logging.info(f"{self._current_version} Saving snapshot at iter {iter}")
        total_loss, output, target = self._run_validation()
        improved = total_loss < self.best_loss

        if (iter % 5 == 0 or improved) and self.visualize:
            visual(torch.cat([output[:8], target[:8]], dim=0), os.path.join(self.snapshot_path, f"learned_newnoise_ep{iter}.png"), img_resolution=self.resolution)

        if improved:
            self.best_loss = total_loss
            self.count_worse = 0
            self._save_checkpoint()
            self._visual_times()
            save_gif(self.snapshot_path)
        else:
            self.count_worse += 1
            logging.info(f"{self._current_version} Count worse: {self.count_worse}")

        logging.info(f"{self._current_version} Validation loss: {total_loss}, best loss: {self.best_loss}")
        logging.info(f"{self._current_version} Iter {iter} snapshot saved!")

        if self.count_worse < self.patient:
            return

        logging.info(f"{self._current_version} Loading best model")
        self._load_checkpoint(reload_data=True)
        self.count_worse = 0

        group1 = self.optimizer_lamb1.param_groups[0]
        group1['lr'] = max(self.lr_time_decay * group1['lr'], self.min_lr_time_1)
        logging.info(f"{self._current_version} Decay time1 lr to {group1['lr']}")

        if self._is_in_version_1():
            if group1['lr'] <= self.min_lr_time_1:
                self.count_min_lr1_hit += 1
        else:
            group2 = self.optimizer_lamb2.param_groups[0]
            group2['lr'] = max(self.lr_time_decay * group2['lr'], self.min_lr_time_2)
            logging.info(f"{self._current_version} Decay time2 lr to {group2['lr']}")
            if group2['lr'] <= self.min_lr_time_2:
                self.count_min_lr2_hit += 1

    def _set_trainable_params(self, is_train: bool, is_no_v1: bool) -> None:
        """Toggle ``requires_grad`` on the two parameter vectors.

        Args:
            is_train: False freezes both vectors (latent-only pass).
            is_no_v1: When training, True freezes ``params1`` and trains only
                ``params2``; otherwise ``params1`` is trained and ``params2``
                joins once the Ver1 rounds are over.
        """
        if not is_train:
            self.params1.requires_grad = False
            self.params2.requires_grad = False
            return

        if is_no_v1:
            self.params1.requires_grad = False
            self.params2.requires_grad = True
        else:
            self.params1.requires_grad = True
            self.params2.requires_grad = not self._is_in_version_1()

    def _log_valid_distance(self, ori_latent: torch.Tensor, latent: torch.Tensor):
        """Log the per-sample L2 distance between current and original latents."""
        assert ori_latent.shape == latent.shape, "Shape of ori_latent and latent mismatched"
        sq = (latent.reshape(latent.shape[0], -1) - ori_latent.reshape(latent.shape[0], -1)).pow(2)
        distances = sq.sum(dim=1).sqrt().detach().cpu().numpy()
        logging.info(f"{self._current_version} Distance: {distances}")

    def _update_dataloader(self, ori_latents: List[torch.Tensor],
                           latents: List[torch.Tensor],
                           targets: List[torch.Tensor],
                           conditions: List[Optional[torch.Tensor]],
                           unconditions: List[Optional[torch.Tensor]],
                           is_train: bool):
        """Replace the train or validation dataset with the updated samples."""
        dataset = LD3Dataset(ori_latents, latents, targets, conditions, unconditions)
        if is_train:
            self.train_data = dataset
            self._create_train_loader()
        else:
            self.valid_data = dataset
            self._create_valid_loaders()

    def _update_latents(self, latent, condition, uncondition, ori_latent, img, latent_params, loss_vector_ref, prior_bound):
        """Project stepped latents into the prior ball and keep the improved ones.

        Latents farther than ``prior_bound`` (L2) from ``ori_latent`` are
        pulled back onto the ball's surface. The ODE is solved again with the
        projected latents, and samples whose loss is below ``loss_vector_ref``
        overwrite the corresponding entries of ``latent``.

        Returns:
            ``(latent, to_update_mask)``: the updated latents and the boolean
            mask of replaced samples.
        """
        # detach() shares storage with latent_params, so the projection below
        # also modifies the parameter in place.
        parameter_data_detached = latent_params.detach()
        cloned_ori_latent = ori_latent.clone()
        diff = parameter_data_detached.data - cloned_ori_latent
        diff_norm = diff.norm(dim=1, keepdim=True)
        pass_bound = (diff_norm > prior_bound).flatten()
        parameter_data_detached.data[pass_bound] = cloned_ori_latent[pass_bound] + prior_bound * diff[pass_bound] / diff_norm[pass_bound]

        # Only the side effect on self.loss_vector is needed here.
        self._solve_ode(img=img, latent=parameter_data_detached.data, condition=condition, uncondition=uncondition, valid=False)

        to_update_mask = self.loss_vector < loss_vector_ref
        parameter_data_detached.data = parameter_data_detached.data.reshape(-1, self.channels, self.resolution, self.resolution)
        latent[to_update_mask] = parameter_data_detached.data[to_update_mask]
        return latent, to_update_mask

    def _train_one_round(self):
        """Run one pass over the training data and, if enabled, the validation data.

        The training pass updates the discretization parameters and latents;
        the validation pass (skipped when ``prior_bound == 0``) updates only
        latents.

        Returns:
            ``(no_change, should_stop)``: whether no validation latent was
            replaced, and whether the lr2 patience limit was reached.
        """
        no_change = True
        logging.info(f"{self._current_version} Round {self.cur_round}")

        if self.cur_round > 0:
            self._load_checkpoint(reload_data=False)
            self.count_worse = 0

        self._examine_checkpoint(self.cur_iter)

        for loader_idx, loader in enumerate([self.train_loader, self.valid_loader]):
            is_train_pass = loader_idx == 0
            if not is_train_pass and self.prior_bound == 0.0:
                continue

            self._set_trainable_params(is_train=is_train_pass, is_no_v1=self.no_v1)

            ori_latents, latents, targets, conditions, unconditions = [], [], [], [], []
            for img, latent, ori_latent, condition, uncondition in loader:
                img, latent, ori_latent, condition, uncondition = move_tensor_to_device(img, latent, ori_latent, condition, uncondition, device=self.device)
                if not is_train_pass:
                    self._log_valid_distance(ori_latent, latent)

                batch_size = ori_latent.shape[0]
                ori_latent = ori_latent.reshape(batch_size, -1)
                latent_to_update = latent.clone().detach().reshape(batch_size, -1).to(self.device)
                latent_params = torch.nn.Parameter(latent_to_update, requires_grad=True)

                latent_optimizer = torch.optim.SGD([latent_params], lr=self.shift_lr)
                if img.device != latent_params.device:
                    # Recover instead of dropping into an interactive debugger,
                    # which would hang an unattended training job.
                    logging.warning(f"{self._current_version} Moving target images from {img.device} to {latent_params.device}")
                    img = img.to(latent_params.device)
                loss, _, _ = self._solve_ode(img=img, latent=latent_params, condition=condition, uncondition=uncondition, valid=False)
                loss_vector_ref = self.loss_vector.clone().detach()
                loss.backward()
                logging.info(f"{self._current_version} Iter {self.cur_iter} {'Train' if is_train_pass else 'Val'} Loss: {loss.item()}")

                latent_optimizer.step()
                latent_optimizer.zero_grad()

                if is_train_pass:
                    torch.nn.utils.clip_grad_norm_(self.params1, 1.0)
                    torch.nn.utils.clip_grad_norm_(self.params2, 1.0)
                    self.optimizer_lamb1.step()
                    self.optimizer_lamb1.zero_grad()
                    self.optimizer_lamb2.step()
                    self.optimizer_lamb2.zero_grad()

                    self.cur_iter += 1
                    self._examine_checkpoint(self.cur_iter)
                    if self.count_min_lr2_hit >= self.lr2_patient:
                        logging.info(f"{self._current_version} Reach min lr2 {self.lr2_patient} times. Stop training.")
                        return no_change, True

                with torch.no_grad():
                    latent, to_update_mask = self._update_latents(latent, condition, uncondition, ori_latent, img, latent_params, loss_vector_ref, self.prior_bound)
                    if not is_train_pass and to_update_mask.sum().item() > 0:
                        # At least one validation latent was replaced.
                        no_change = False

                ori_latent = ori_latent.reshape(-1, self.channels, self.resolution, self.resolution).detach().cpu()
                latent = latent.reshape(-1, self.channels, self.resolution, self.resolution).detach().cpu()
                img = img.detach().cpu()
                condition = condition.detach().cpu() if condition is not None else None
                uncondition = uncondition.detach().cpu() if uncondition is not None else None

                # Collect per-sample entries for the rebuilt dataset.
                for j in range(latent.shape[0]):
                    ori_latents.append(ori_latent[j])
                    targets.append(img[j])
                    latents.append(latent[j])
                    conditions.append(condition[j] if condition is not None else None)
                    unconditions.append(uncondition[j] if uncondition is not None else None)

            if self.prior_bound > 0:
                self._update_dataloader(ori_latents, latents, targets, conditions, unconditions, is_train=is_train_pass)

        return no_change, False

    def train(self, training_rounds_v1: int, training_rounds_v2: int) -> None:
        """Run the two-stage training loop.

        Args:
            training_rounds_v1: Number of rounds in which only ``params1`` is
                trained (unless ``no_v1`` is set).
            training_rounds_v2: Number of additional rounds in which
                ``params2`` is trained as well.
        """
        total_round = training_rounds_v1 + training_rounds_v2
        self.training_rounds_v1 = training_rounds_v1

        if self.match_prior:
            self._train_to_match_prior()

        while self.cur_round < total_round:
            no_latent_change, should_stop = self._train_one_round()
            if should_stop:
                return
            self.cur_round += 1

            # No validation latent moved this round: take smaller latent steps.
            if no_latent_change and self.prior_bound > 0:
                self.shift_lr *= self.shift_lr_decay

        logging.info(f"{self._current_version} Max round reached, stopping")
|
|
|
|
|
def discretize_model_wrapper(input1, input2, lambda_max, lambda_min, noise_schedule, mode, window_rate=0.5):
    """Build a callable that maps learnable parameters to two timestep grids.

    ``input1`` is passed through a softmax and a cumulative sum to obtain a
    monotone grid spanning the full range; ``input2`` provides per-step
    offsets that are clamped to ``window_rate`` times the smallest gap of that
    grid and zeroed at both end points.

    Args:
        input1: 1-D tensor of logits defining the main grid.
        input2: 1-D tensor of offsets, same length as ``input1``.
        lambda_max: Upper end of the log-SNR range.
        lambda_min: Lower end of the log-SNR range.
        noise_schedule: Object providing ``inverse_lambda``.
        mode: ``'time'`` builds the grid in time; any other value builds it
            in log-SNR and converts it with ``inverse_lambda``.
        window_rate: Fraction of the smallest grid gap an offset may span.

    Returns:
        A zero-argument function returning ``(timesteps1, timesteps2)``: the
        main grid and the grid shifted by the bounded offsets.
    """

    def _bounded_offsets(steps, offsets):
        # The window is derived from detached values, so no gradient flows
        # through the clamp bounds.
        gaps = (steps[1:] - steps[:-1]).detach().abs()
        max_move = gaps.min().item() * window_rate
        clipped = torch.clamp(offsets, min=-max_move, max=max_move)
        # The first and last steps stay pinned to the range end points.
        mask = torch.ones_like(steps)
        mask[0] = 0.
        mask[-1] = 0.
        return clipped * mask

    def model_time_fn():
        t_max = noise_schedule.inverse_lambda(lambda_min).to(input1.device)
        t_min = noise_schedule.inverse_lambda(lambda_max).to(input1.device)
        fractions = torch.nn.functional.softmax(input1, dim=0)
        # Flipped cumulative sum: decreasing from 1 to the first fraction.
        time_md = torch.cumsum(fractions, dim=0).flip(0)
        normed = (time_md - time_md[-1]) / (time_md[0] - time_md[-1])
        time_steps = normed * (t_max - t_min) + t_min
        return time_steps, time_steps + _bounded_offsets(time_steps, input2)

    def model_lambda_fn():
        fractions = torch.nn.functional.softmax(input1, dim=0)
        lamb_md = torch.cumsum(fractions, dim=0)
        normed = (lamb_md - lamb_md.min()) / (lamb_md.max() - lamb_md.min())
        lamb_steps1 = normed * (lambda_max - lambda_min) + lambda_min
        # The window is measured on the log-SNR grid itself (as in the time
        # parameterization), not on the raw logits, whose gaps are unrelated
        # to the step spacing and are all zero at initialization.
        lamb_steps2 = lamb_steps1 + _bounded_offsets(lamb_steps1, input2)
        time1 = noise_schedule.inverse_lambda(lamb_steps1)
        time2 = noise_schedule.inverse_lambda(lamb_steps2)
        return time1, time2

    return model_time_fn if mode == 'time' else model_lambda_fn
|
|
|