Spaces:
Runtime error
Runtime error
| from .misc_4ddpm import * | |
| from lmk_util.lmk_extractor import lmkAll_2_lmkMain, get_lmkMain_indices | |
class DDPM(pl.LightningModule):
    """Classic DDPM with Gaussian diffusion, in image space.

    Only the "eps" and "x0" parameterizations are supported. Several generic
    entry points of this class (``init_from_ckpt``, ``p_losses``,
    ``shared_step``) are deliberately disabled with ``assert 0`` and must be
    overridden or bypassed by subclasses. ``training_step`` drives the
    optimizer itself (``manual_backward`` / ``opt.step``), so it requires
    Lightning manual optimization.
    """

    def __init__(self,
                 unet_config,
                 timesteps=1000,
                 beta_schedule="linear",
                 loss_type="l2",
                 ckpt_path=None,
                 ignore_keys=None,
                 load_only_unet=False,
                 monitor="val/loss",
                 use_ema=True,
                 first_stage_key="image",
                 image_size=256,
                 channels=3,
                 log_every_t=100,
                 clip_denoised=True,
                 linear_start=1e-4,
                 linear_end=2e-2,
                 cosine_s=8e-3,
                 given_betas=None,
                 original_elbo_weight=0.,
                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
                 l_simple_weight=1.,
                 conditioning_key=None,
                 parameterization="eps",  # all assuming fixed variance schedules
                 scheduler_config=None,
                 learn_logvar=False,
                 logvar_init=0.,
                 u_cond_percent=0,
                 ):
        """Build the denoiser (plus optional EMA copy) and register the noise schedule.

        Args:
            unet_config: denoiser config dict. NOTE: its
                ``['params']['in_channels']`` entry is overwritten in place
                (14 when the module-level ``CH14`` flag is truthy, else 9).
            timesteps, beta_schedule, linear_start, linear_end, cosine_s,
                given_betas: forwarded to ``register_schedule``.
            loss_type: 'l1' or 'l2' (see ``get_loss``).
            ckpt_path: optional checkpoint loaded via ``init_from_ckpt``.
            ignore_keys: state-dict key prefixes to drop when loading
                ``ckpt_path``; ``None`` means none.
            load_only_unet: load the checkpoint into ``self.model`` only.
            use_ema: keep an EMA copy of ``self.model`` (``LitEma``).
            conditioning_key: forwarded to ``DiffusionWrapper``.
            parameterization: "eps" or "x0".
            learn_logvar: make ``self.logvar`` a trainable ``nn.Parameter``.

        Raises:
            AssertionError: if ``parameterization`` is not "eps" or "x0".
        """
        super().__init__()
        ignore_keys = [] if ignore_keys is None else ignore_keys
        assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.first_stage_key = first_stage_key
        self.image_size = image_size
        self.channels = channels
        self.u_cond_percent = u_cond_percent
        unet_config['params']['in_channels'] = 14 if CH14 else 9
        self.model = DiffusionWrapper(unet_config, conditioning_key)
        count_params(self.model, verbose=True)
        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
        self.use_scheduler = scheduler_config is not None
        if self.use_scheduler:
            self.scheduler_config = scheduler_config
        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
        # Must run after self.v_posterior is set: register_schedule reads it.
        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
        self.loss_type = loss_type
        self.learn_logvar = learn_logvar
        # Plain tensor (not a registered buffer) unless learn_logvar is set.
        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
        if self.learn_logvar:
            self.logvar = nn.Parameter(self.logvar, requires_grad=True)

    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Compute the beta schedule and register all derived diffusion constants as buffers.

        ``given_betas`` (if not None) overrides the generated schedule, and its
        length then defines ``self.num_timesteps``.
        """
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                       cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            # NOTE(review): `2. * 1 - x` parses as `2 - x`; `2. * (1 - x)` may
            # have been intended. Kept as-is to preserve existing behavior.
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
        else:
            raise NotImplementedError("mu not supported")
        # TODO how to choose this term
        # posterior_variance[0] is 0, so the eps weight at t=0 is not finite;
        # borrow the t=1 value instead.
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).all()

    def ema_scope(self, context=None):
        """Return a context manager that temporarily swaps in the EMA weights.

        Use it as ``with self.ema_scope("label"): ...``. On entry (only when
        ``use_ema``) the live weights are stored and replaced by the EMA copy.
        On exit they are restored, even if the body raised. ``context``, when
        given, is used as a prefix for the two status prints.

        The method itself is not decorated. The generator is therefore wrapped
        with ``contextlib.contextmanager`` here, because a bare generator has
        no ``__enter__`` and cannot be used in a ``with`` statement.
        """
        from contextlib import contextmanager

        @contextmanager
        def _swap_to_ema():
            if self.use_ema:
                self.model_ema.store(self.model.parameters())
                self.model_ema.copy_to(self.model)
                if context is not None:
                    print(f"{context}: Switched to EMA weights")
            try:
                yield None
            finally:
                if self.use_ema:
                    self.model_ema.restore(self.model.parameters())
                    if context is not None:
                        print(f"{context}: Restored training weights")

        return _swap_to_ema()

    def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
        """Load a checkpoint state dict, dropping keys that start with any of ``ignore_keys``.

        Deliberately disabled in this copy: the leading ``assert 0`` fires
        before anything is loaded (the rest only runs under ``python -O``).

        Args:
            path: checkpoint path, loaded with ``torch.load`` onto CPU. A
                top-level ``"state_dict"`` entry is unwrapped if present.
            ignore_keys: key prefixes to drop; ``None`` means none.
            only_model: load into ``self.model`` instead of the whole module.
        """
        assert 0
        ignore_keys = [] if ignore_keys is None else ignore_keys
        print("[init_from_ckpt]")
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
                    # A key matching several prefixes must be deleted once only;
                    # a second `del` would raise KeyError.
                    break
        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def q_sample(self, x_start, t, noise=None):
        """Sample x_t ~ q(x_t | x_0) as sqrt(acp_t) * x_start + sqrt(1 - acp_t) * noise.

        ``noise`` defaults to fresh standard-normal noise shaped like ``x_start``.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def get_loss(self, pred, target, mean=True):
        """Return the L1 or L2 loss between ``pred`` and ``target``.

        Args:
            pred: model output.
            target: regression target, broadcast-compatible with ``pred``.
            mean: reduce to a scalar when True; otherwise return the
                elementwise loss.

        Raises:
            NotImplementedError: if ``self.loss_type`` is neither 'l1' nor 'l2'.
        """
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == 'l2':
            if mean:
                loss = torch.nn.functional.mse_loss(target, pred)
            else:
                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
        else:
            # Was a plain string referencing an undefined name, so the message
            # printed a literal "{loss_type}" instead of the offending value.
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
        return loss

    def p_losses(self, x_start, t, noise=None):
        """Single-step denoising loss (disabled here; subclasses override it).

        The code after the assertion is the generic reference implementation.
        Metric keys are prefixed 'train/' or 'val/' depending on ``self.training``.
        """
        assert 0, 'This should not be called; subclasses override this method'
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, t)

        loss_dict = {}
        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        else:
            raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")

        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])

        # metrics.csv entries like 'train/...' and 'val/...' originate here
        log_prefix = 'train' if self.training else 'val'

        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
        loss_simple = loss.mean() * self.l_simple_weight

        loss_vlb = (self.lvlb_weights[t] * loss).mean()
        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})

        loss = loss_simple + self.original_elbo_weight * loss_vlb

        loss_dict.update({f'{log_prefix}/loss': loss})

        return loss, loss_dict

    def forward(self, x, *args, **kwargs):
        """Draw one uniform timestep per sample and return ``self.p_losses(x, t, ...)``."""
        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        return self.p_losses(x, t, *args, **kwargs)

    def shared_step(self, batch):
        """Disabled in this class; subclasses must provide their own shared step."""
        assert 0

    def set_task(self, batch):
        """Activate the task of a single-task ``batch`` and set the per-task flags.

        Reads ``batch['task']`` (every entry must equal the first one) and
        publishes it on both ``global_.task`` and ``self.task``. It also sets:
          * ``Landmark_cond``: False when ``USE_pts`` is falsy or task == 1.
          * ``Landmarks_weight``: 0.05 for tasks 0, 2 and 3, else 0.
          * ``STACK_feat``: always True.

        Returns:
            The task id (Python scalar from ``.item()``).
        """
        task = batch['task'][0].item()
        printC('task', f"{task=}")
        global_.task = task
        assert all(batch['task'] == task), batch['task']
        self.task = task
        self.Landmark_cond = bool(USE_pts) and task != 1
        self.Landmarks_weight = 0.05 if task in (0, 2, 3) else 0
        self.STACK_feat = True
        return task

    def unset_task(self):
        """Clear the task state set by ``set_task`` (raises AttributeError if none was set)."""
        global_.task = None
        global_.lmk_ = None
        del self.task

    def training_step(self, batch, batch_idx):
        """One manually-optimized training step for a single-task batch.

        The loss comes from ``shared_step`` or, when ``Reconstruct_initial`` is
        set, from ``shared_step_face``. That step is assumed to call
        ``set_task``, because ``unset_task`` is called here — TODO confirm.

        The optimizer only steps when ``task == 3`` or ``TP_enable`` is set.
        For other tasks the backward pass runs under the DDP ``no_sync``
        context, so gradients accumulate locally until a stepping batch
        arrives. Before stepping, gradients of parameters selected by
        ``tp_param_need_sync`` are all-reduced and averaged across ranks,
        clipped to norm 1.0, and LR schedulers are stepped.

        Returns:
            The loss tensor.
        """
        task = batch['task'][0].item()
        opt = self.optimizers()

        if not self.Reconstruct_initial:  # only MSE loss(orig diffusion). -> shared_step -> forward -> p_losses
            loss, loss_dict = self.shared_step(batch)  # original
        else:  # added Multistep (DDIM) loss -> shared_step_face -> forward_face -> p_losses_face
            loss, loss_dict = self.shared_step_face(batch)  # changed by sanoojan : to add ID loss after reconstructing through DDIM

        step_or_accumulate = (task == 3 or TP_enable)
        # step_or_accumulate already covers TP_enable, so only non-stepping
        # batches skip the DDP gradient all-reduce.
        _ctx = nullcontext
        if not step_or_accumulate:
            _ctx = self.trainer.model.no_sync  # https://github.com/Lightning-AI/pytorch-lightning/discussions/10792
        with _ctx():  # https://zhuanlan.zhihu.com/p/250471767
            self.manual_backward(loss)

        if REFNET.ENABLE and REFNET.task2layerNum[task] > 0:
            self.model.bank.clear()
        self.unset_task()

        if step_or_accumulate:
            # Average grads of shared params across ranks (TaskParallel)
            if dist.is_available() and dist.is_initialized():
                ws = dist.get_world_size()
                shared_sync_cnt = 0
                task_skip_cnt = 0
                for name, p in self.named_parameters():
                    need_sync, is_task_specific_skip = tp_param_need_sync(name, p)
                    if not need_sync:
                        if is_task_specific_skip:
                            task_skip_cnt += 1
                        continue
                    if p.grad is None:
                        p.grad = torch.zeros_like(p)  # ensure collective call sequence remains consistent
                    dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
                    p.grad.div_(ws)
                    shared_sync_cnt += 1
                if gate_('[TP] shared sync counts'):
                    print(f"synced={shared_sync_cnt} skipped(task)={task_skip_cnt}")
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
            opt.step()
            opt.zero_grad()

            if self.use_scheduler:  # handle LR schedulers
                sch = self.lr_schedulers()
                if isinstance(sch, list) and len(sch) > 0:  # schedulers expressed as a list
                    for scheduler_config in sch:
                        if isinstance(scheduler_config, dict) and 'scheduler' in scheduler_config:
                            scheduler_config['scheduler'].step()
                        else:
                            scheduler_config.step()
                elif hasattr(sch, 'step'):
                    sch.step()

        self.log_dict(loss_dict, prog_bar=True,
                      logger=True, on_step=True, on_epoch=True)
        self.log("global_step", self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)
        if self.use_scheduler:
            lr = self.optimizers().param_groups[0]['lr']
            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
        return loss

    # Manual optimization already calls backward inside training_step, so
    # no custom `backward` override is defined.

    def validation_step(self, batch, batch_idx):
        """Log validation losses with the live weights and with the EMA weights ('_ema' suffix)."""
        _, loss_dict_no_ema = self.shared_step(batch)
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.unset_task()

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA copy of ``self.model`` after every training batch."""
        if self.use_ema:
            self.model_ema(self.model)
| class LatentDiffusion(DDPM): | |
| """main class""" | |
    def __init__(self,
                 first_stage_config,
                 cond_stage_config,
                 num_timesteps_cond=None,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,
                 scale_factor=1.0,
                 scale_by_std=False,
                 *args, **kwargs):
        """Build the latent diffusion model.

        Sets up the DDPM core, the frozen first stage, the CLIP condition
        encoders and the task-specific projection heads.

        Args:
            first_stage_config: passed to ``instantiate_first_stage``; if it has
                ``params.ddconfig.ch_mult``, ``num_downs`` = len(ch_mult) - 1,
                otherwise 0.
            cond_stage_config: must have ``other_params`` (else ``assert 0``).
                Its ``target`` must be
                ``ldm.modules.encoders.modules.FrozenCLIPEmbedder`` (asserted
                below), and ``other_params.Additional_config`` must enable both
                ``Source_CLIP_feat`` and ``Target_CLIP_feat``.
            num_timesteps_cond: conditioning-schedule length (default 1); must
                not exceed ``kwargs['timesteps']``.
            conditioning_key: defaults to 'concat' or 'crossattn' depending on
                ``concat_mode``; forced to None for an unconditional config.
            scale_factor: latent scale; registered as a buffer when
                ``scale_by_std`` is True, else a plain attribute.
            *args, **kwargs: forwarded to ``DDPM.__init__``. ``ckpt_path`` and
                ``ignore_keys`` are popped first and only applied at the very
                end, once every submodule exists.

        Raises:
            KeyError: if ``kwargs`` lacks ``'timesteps'``.
            AssertionError / AttributeError: for configurations this copy does
                not support (see the inline notes).
        """
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        assert self.num_timesteps_cond <= kwargs['timesteps']
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        if cond_stage_config == '__is_unconditional__':
            conditioning_key = None
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", [])
        # num_timesteps_cond must already be set: the overridden
        # register_schedule (run inside DDPM.__init__) reads it.
        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
        self.automatic_optimization = False # disable automatic optimization to manage parameter updates manually
        # self.learnable_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=True)
        # breakpoint()
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        #check if other_params is present in cond_stage_config
        if hasattr(cond_stage_config, 'other_params'):
            self.clip_weight=cond_stage_config.other_params.clip_weight
            # those three weights: 0 skips module init, >0 enables it and acts as weight when !STACK_feat
            if set(TASKS) & {0,2,3}: self.ID_weight = 10.0
            else: self.ID_weight = 0
            if (not USE_pts) and TASKS==(1,): self.Landmark_cond=False
            else: self.Landmark_cond=True
            self.Landmarks_weight=0.05
            if hasattr(cond_stage_config.other_params, 'Additional_config'):
                self.Reconstruct_initial=cond_stage_config.other_params.Additional_config.Reconstruct_initial
                self.Reconstruct_DDIM_steps=cond_stage_config.other_params.Additional_config.Reconstruct_DDIM_steps
                self.sampler=DDIMSampler(self)
                if hasattr(cond_stage_config.other_params, 'multi_scale_ID'):
                    self.multi_scale_ID=cond_stage_config.other_params.multi_scale_ID # True has an issue
                else:
                    self.multi_scale_ID=True #this has an issue obtaining earlier layer from ID
                if hasattr(cond_stage_config.other_params, 'normalize'):
                    self.normalize=cond_stage_config.other_params.normalize # normalizes the combintaion of ID and LPIPS loss
                else:
                    self.normalize=False
                if 1:
                    self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
                if hasattr(cond_stage_config.other_params, 'partial_training'):
                    self.partial_training=cond_stage_config.other_params.partial_training
                    self.trainable_keys=cond_stage_config.other_params.trainable_keys
                else:
                    self.partial_training=False
                if hasattr(cond_stage_config.other_params.Additional_config, 'Same_image_reconstruct'):
                    self.Same_image_reconstruct=cond_stage_config.other_params.Additional_config.Same_image_reconstruct
                else:
                    self.Same_image_reconstruct=False
                if hasattr(cond_stage_config.other_params.Additional_config, 'Target_CLIP_feat'):
                    self.Target_CLIP_feat=cond_stage_config.other_params.Additional_config.Target_CLIP_feat
                else:
                    self.Target_CLIP_feat=False
                if hasattr(cond_stage_config.other_params.Additional_config, 'Source_CLIP_feat'):
                    self.Source_CLIP_feat=cond_stage_config.other_params.Additional_config.Source_CLIP_feat
                else:
                    self.Source_CLIP_feat=False
                if hasattr(cond_stage_config.other_params.Additional_config, 'use_3dmm'):
                    self.use_3dmm=cond_stage_config.other_params.Additional_config.use_3dmm
                else:
                    self.use_3dmm=False
            else:
                # NOTE(review): multi_scale_ID / Source_CLIP_feat / Target_CLIP_feat
                # are not set on this path, so the ID head and the CLIP assert
                # below raise AttributeError without an Additional_config.
                self.Reconstruct_initial=False
                self.Reconstruct_DDIM_steps=0
            self.update_weight=False
        else:
            assert 0
        if 1:
            # One learnable token sequence per task slot (index 1 has 257 tokens, the rest 259).
            self.learnable_vector = nn.ParameterList([
                nn.Parameter(torch.randn((1,259,768)), requires_grad=True),
                nn.Parameter(torch.randn((1,257,768)), requires_grad=True),
                nn.Parameter(torch.randn((1,259,768)), requires_grad=True),
                nn.Parameter(torch.randn((1,259,768)), requires_grad=True),
            ])
        if self.ID_weight>0:
            # Project face-ID features to the 768-d conditioning width
            # (200704-d multi-scale features, or a 512-d embedding).
            if self.multi_scale_ID:
                self.ID_proj_out=nn.Linear(200704, 768)
            else:
                self.ID_proj_out=nn.Linear(512, 768) # yes
            self.instantiate_IDLoss(cond_stage_config)
        if self.Landmark_cond:
            if USE_pts:
                # The landmark detector is not built here; get_landmarks needs
                # the cached MediaPipe path unless this is replaced elsewhere.
                self.ptsM_Generator = None
            else:
                # NOTE(review): bare `raise` with no active exception -> RuntimeError.
                raise
            if self.Landmarks_weight>0:
                self.landmark_proj_out=nn.Linear(NUM_pts*2, 768)
        self.total_steps_in_epoch=0 # will be calculated inside training_step. Not known for now
        if 1:
            assert cond_stage_config.target=="ldm.modules.encoders.modules.FrozenCLIPEmbedder" and self.Source_CLIP_feat and self.Target_CLIP_feat
            self.USE_proj_out_source = 1
            # One 768->768 projection of the CLIP source features per task family.
            if set(TASKS) & {0,}:
                self.proj_out_source__face=nn.Linear(768, 768)
            if set(TASKS) & {1,}:
                self.proj_out_source__hair=nn.Linear(768, 768)
            if set(TASKS) & {2,3,}:
                self.proj_out_source__head=nn.Linear(768, 768)
            if 0: # dummy, just for compa
                self.proj_out_target=nn.Linear(768, 768)
        self.proj_out=nn.Identity()
        try:
            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except:
            self.num_downs = 0
        if not scale_by_std:
            self.scale_factor = scale_factor
        else:
            self.register_buffer('scale_factor', torch.tensor(scale_factor))
        self.instantiate_first_stage(first_stage_config)
        self.instantiate_cond_stage(cond_stage_config)
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None
        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            # NOTE(review): DDPM.init_from_ckpt starts with `assert 0` in this copy.
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True
| def get_lmk_for_router(self, batch: dict, x_tensor: torch.Tensor): | |
| """ | |
| Prepare global_.lmk_ (BS, L, 2) normalized to [0,1] for gating/router. | |
| - Prefer cached Mediapipe landmarks if present in batch | |
| - Convert 468/478 to main landmarks with face oval using get_lmkMain_indices(True) | |
| - Fallback to zeros if not available | |
| """ | |
| b, _, H, W = x_tensor.shape | |
| if READ_mediapipe_result_from_cache and ('mediapipe_lmkAll' in batch): | |
| data_all = batch['mediapipe_lmkAll'] # tensor or ndarray | |
| if isinstance(data_all, torch.Tensor): | |
| lmks_all = data_all.to(x_tensor.device).to(x_tensor.dtype) | |
| else: | |
| lmks_all = torch.from_numpy(data_all).to(x_tensor.device).to(x_tensor.dtype) | |
| # map to main indices with face oval (cached tensor indices on device) | |
| idxs = getattr(global_, 'lmk_main_idx_tensor', None) | |
| if (idxs is None) or (idxs.device != x_tensor.device): | |
| idx_list = get_lmkMain_indices(include_face_oval=True) | |
| idxs = torch.as_tensor(list(idx_list), dtype=torch.long, device=x_tensor.device) | |
| global_.lmk_main_idx_tensor = idxs | |
| lmk = torch.index_select(lmks_all, dim=1, index=idxs) | |
| # normalize by current spatial size | |
| if lmk.numel() > 0: | |
| # print(f"0 {lmk[:,:5]=}") | |
| lmk[..., 0] = lmk[..., 0] / float(W) | |
| lmk[..., 1] = lmk[..., 1] / float(H) | |
| # print(f"1 {lmk[:,:5]=}") | |
| else: | |
| assert 0 | |
| lmk = torch.zeros((b, 0, 2), device=x_tensor.device, dtype=x_tensor.dtype) | |
| return lmk | |
| def make_cond_schedule(self, ): | |
| self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) | |
| ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() | |
| self.cond_ids[:self.num_timesteps_cond] = ids | |
| def on_train_batch_start(self, batch, batch_idx, dataloader_idx): | |
| # only for very first batch | |
| if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: | |
| assert 0 | |
| def register_schedule(self, | |
| given_betas=None, beta_schedule="linear", timesteps=1000, | |
| linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): | |
| super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) | |
| self.shorten_cond_schedule = self.num_timesteps_cond > 1 | |
| if self.shorten_cond_schedule: | |
| self.make_cond_schedule() | |
| def instantiate_first_stage(self, config): | |
| model = instantiate_from_config(config) | |
| self.first_stage_model = model.eval() | |
| self.first_stage_model.train = disabled_train | |
| for param in self.first_stage_model.parameters(): | |
| param.requires_grad = False | |
| def instantiate_IDLoss(self, config): | |
| # Need to modify @sanoojan | |
| # if not self.cond_stage_trainable: | |
| model = IDLoss(config,multiscale=self.multi_scale_ID) | |
| self.face_ID_model = model.eval() | |
| self.face_ID_model.train = disabled_train | |
| for param in self.face_ID_model.parameters(): | |
| param.requires_grad = False | |
| def instantiate_cond_stage(self, config): | |
| if 1: | |
| assert config != '__is_first_stage__' | |
| assert config != '__is_unconditional__' | |
| model: FrozenCLIPEmbedder = instantiate_from_config(config) #ldm.modules.encoders.modules.FrozenCLIPEmbedder | |
| if 0 in TASKS: | |
| self.encoder_clip_face :FrozenCLIPEmbedder = model | |
| if 1 in TASKS: | |
| self.encoder_clip_hair :FrozenCLIPEmbedder = copy.deepcopy(model) | |
| del self.encoder_clip_hair.model | |
| del self.encoder_clip_hair.tokenizer | |
| if set(TASKS) & {2,}: | |
| self.encoder_clip_head_t2 :FrozenCLIPEmbedder = copy.deepcopy(model) | |
| del self.encoder_clip_head_t2.model | |
| del self.encoder_clip_head_t2.tokenizer | |
| if set(TASKS) & {3,}: | |
| self.encoder_clip_head_t3 :FrozenCLIPEmbedder = copy.deepcopy(model) | |
| del self.encoder_clip_head_t3.model | |
| del self.encoder_clip_head_t3.tokenizer | |
| def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): | |
| denoise_row = [] | |
| for zd in tqdm(samples, desc=desc): | |
| denoise_row.append(self.decode_first_stage(zd.to(self.device), | |
| force_not_quantize=force_no_decoder_quantization)) | |
| n_imgs_per_row = len(denoise_row) | |
| denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W | |
| denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') | |
| denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') | |
| denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) | |
| return denoise_grid | |
| def get_first_stage_encoding(self, encoder_posterior): | |
| if isinstance(encoder_posterior, DiagonalGaussianDistribution): | |
| z = encoder_posterior.sample() | |
| elif isinstance(encoder_posterior, torch.Tensor): | |
| z = encoder_posterior | |
| else: | |
| raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") | |
| return self.scale_factor * z | |
| def get_learned_conditioning(self, c): | |
| raise Exception | |
    def conditioning_with_feat(self,x,landmarks=None,enInputs:dict=None):
        """Encode the per-task conditioning inputs into one conditioning tensor.

        Which encoders run depends on ``self.task`` (set by ``set_task``):
          * 0: face ID + face CLIP
          * 1: hair CLIP, fed with ``encoder_clip_face.forward_vit`` features
          * 2, 3: face ID + head CLIP, fed with ``encoder_clip_face.forward_vit``
            features
        When ``Landmarks_weight > 0``, ``landmarks`` is appended as an extra
        entry.

        Args:
            x: deleted immediately; only used by the optional debug visualization.
            landmarks: landmark features. A tensor that is not 3-D gets a new
                axis at dim 1.
            enInputs: encoder inputs keyed 'face_ID-in', 'face-clip-in',
                'hair-clip-in' and 'head-clip-in'.

        Returns:
            If ``self.STACK_feat`` is set (``set_task`` sets it to True), all
            entries concatenated along dim -2. Otherwise their weighted
            average, or 0 if the weights sum to 0.

        NOTE(review): tasks 1-3 call ``self.encoder_clip_face``, which
        ``instantiate_cond_stage`` only creates when 0 is in ``TASKS``.
        """
        if gate_('vis LatentDiffusion.conditioning_with_feat'):
            debug_dir = Path(f"4debug/conditioning_with_feat/{ID}"); debug_dir.mkdir(parents=0, exist_ok=True)
            all_images = [ ('x', x), ]
            for _name, _enInput in enInputs.items():
                all_images.append((_name, _enInput))
            vis_tensors_A(all_images, debug_dir / f"all-{str_t_pid()}.jpg", vis_batch_size= min(5, landmarks.shape[0]) )
        del x # (x is GT during training, ref_imgs during inference)
        task = self.task
        ID_weight = self.ID_weight
        Landmarks_weight = self.Landmarks_weight
        # Only the weight for the current task is bound; the encode_* closures
        # below read it when they run.
        if self.task==0:
            face_clip_weight = self.clip_weight
        elif self.task==1:
            hair_clip_weight = self.clip_weight
        elif self.task==2:
            head_clip_weight = self.clip_weight
        elif self.task==3:
            head_clip_weight = self.clip_weight
        if 1:
            cs = [] # conditionings
            ws = [] # weights corresponding one-to-one with cs
            def encode_face_ID():
                # Face-ID features projected to 768-d, as a single token.
                _c = enInputs['face_ID-in']
                _c=self.face_ID_model.extract_feats(_c)[0]
                _c = self.ID_proj_out(_c) #-->c:[4,768]
                _c = _c.unsqueeze(1) #-->c:[4,1,768]
                if self.normalize: #normalize c2
                    # NOTE(review): `norm_coeff` is not defined in this method and
                    # must come from module scope. Also, x / F.normalize(x) equals
                    # ||x|| elementwise, so this yields a constant vector;
                    # `F.normalize(_c, ...) * norm_coeff` may have been intended.
                    _c = _c*norm_coeff/F.normalize(_c, p=2, dim=2)
                cs.append(_c); ws.append(ID_weight)
            def encode_face_clip(_z=None):# _z: result of ViT forward pass
                if _z is None:
                    _c = enInputs['face-clip-in']
                    _c = self.encoder_clip_face.encode(_c) #b,3,224,224 --> b,1,768
                else:
                    assert 0
                    _c = self.encoder_clip_face.encode_B(_z)
                if hasattr(self,'USE_proj_out_source') and self.USE_proj_out_source:
                    _c = self.proj_out_source__face(_c)
                cs.append(_c); ws.append(face_clip_weight)
            def encode_hair_clip(_z=None):
                if _z is None:
                    _c = enInputs['hair-clip-in']
                    _c = self.encoder_clip_hair.encode(_c) #b,3,224,224 --> b,1,768
                else:
                    # encode_B consumes precomputed ViT features (the stripped
                    # copy has no .model of its own).
                    _c = self.encoder_clip_hair.encode_B(_z)
                if hasattr(self,'USE_proj_out_source') and self.USE_proj_out_source:
                    _c = self.proj_out_source__hair(_c)
                printC("hair _c.shape:",f"{_c.shape}")
                cs.append(_c); ws.append(hair_clip_weight)
            def encode_head_clip(_z=None):
                # Tasks 2 and 3 use separate head encoders; selected via global_.task.
                if global_.task == 2:
                    encoder_clip_head = self.encoder_clip_head_t2
                elif global_.task == 3:
                    encoder_clip_head = self.encoder_clip_head_t3
                else:
                    raise ValueError(f"Task {global_.task} does not have encoder_clip_head")
                if _z is None:
                    _c = enInputs['head-clip-in']
                    _c = encoder_clip_head.encode(_c) #b,3,224,224 --> b,1,768
                else:
                    _c = encoder_clip_head.encode_B(_z)
                if hasattr(self,'USE_proj_out_source') and self.USE_proj_out_source:
                    _c = self.proj_out_source__head(_c)
                printC("head _c.shape:",f"{_c.shape}")
                cs.append(_c); ws.append(head_clip_weight)
            if task==0:
                encode_face_ID()
                encode_face_clip()
            elif task==1:
                _z = enInputs['hair-clip-in']
                _z = self.encoder_clip_face.forward_vit(_z)
                encode_hair_clip(_z)
            elif task==2:
                encode_face_ID()
                _z = enInputs['head-clip-in']
                _z = self.encoder_clip_face.forward_vit(_z)
                encode_head_clip(_z)
            elif task==3:
                encode_face_ID()
                _z = enInputs['head-clip-in']
                _z = self.encoder_clip_face.forward_vit(_z)
                encode_head_clip(_z)
            c=0
            if Landmarks_weight > 0:
                landmarks=landmarks.unsqueeze(1) if len(landmarks.shape)!=3 else landmarks
                cs.append(landmarks); ws.append(Landmarks_weight)
            if self.STACK_feat: # _Cc
                # stack all features
                # (torch.cat raises if cs is empty, e.g. for an unknown task.)
                conc=torch.cat(cs, dim=-2)
                c = conc
            else:
                total_weight = sum(ws)
                weighted_sum = sum(c * w for c, w in zip(cs, ws))
                c = weighted_sum / total_weight if total_weight > 0 else 0
        printC("[conditioning_with_feat return]",f"{custom_repr_v3(c)}")
        # assert c.shape[1]==NUM_token, c.shape
        return c
    def get_landmarks(self,x, batch:dict):
        """Extract main facial landmarks for every image in ``x``.

        Runs only when ``self.Landmark_cond`` is set and ``x`` is not None;
        otherwise the method falls through and returns None.

        ``x`` is converted to uint8 HWC images via ``un_norm`` (assumed to map
        the model's input range to [0, 1] — TODO confirm). With ``USE_pts``,
        the full landmark set per image comes either from
        ``batch['mediapipe_lmkAll']`` (when ``READ_mediapipe_result_from_cache``)
        or from ``self.ptsM_Generator.extract_single``. A missing detection
        becomes a (478, 2) zero array. The full set is then reduced with
        ``lmkAll_2_lmkMain`` to NUM_pts points.

        Returns:
            (Landmarks_all, pts68, lmk_aux) where:
              * Landmarks_all: (B, NUM_pts*2). When ``Landmarks_weight > 0`` it
                is a float tensor on ``self.device`` passed through
                ``landmark_proj_out``; otherwise it is the raw numpy array.
              * pts68: the raw landmarks reshaped to (B, NUM_pts, 2).
              * lmk_aux: holds 'l_lmkAll', the per-image full landmark arrays,
                when ``USE_pts``.

        NOTE(review): ``__init__`` sets ``ptsM_Generator = None``, so the
        non-cached path fails unless it is assigned elsewhere. Without
        ``USE_pts``, ``np.concatenate`` gets an empty list and raises.
        """
        if (self.Landmark_cond) and x is not None:
            # pass
            # # Detect faces in an image
            #convert to 8bit image
            x=255.0*un_norm(x).permute(0,2,3,1).cpu().numpy()
            x=x.astype(np.uint8) # B,512,512,3
            Landmarks_all=[]
            if USE_pts:
                l_lmkAll=[]
                if READ_mediapipe_result_from_cache:
                    _l_lmkAll :np.ndarray = batch['mediapipe_lmkAll'].cpu().numpy()
            bs = len(x)
            for i in range(len(x)):
                if USE_pts:
                    if READ_mediapipe_result_from_cache:
                        lmkAll :np.ndarray = _l_lmkAll[i]
                    else:
                        lmkAll :np.ndarray = self.ptsM_Generator.extract_single(x[i], only_main_lmk=False)
                    if lmkAll is None: lmkAll = np.zeros((478,2))
                    l_lmkAll.append(lmkAll)
                    lm = lmkAll_2_lmkMain(lmkAll) # NUM_pts,2
                    lm = lm.reshape(1, NUM_pts*2) # num of points * 2 coordinates
                    Landmarks_all.append(lm)
                    if 0:
                        from util_vis import visualize_landmarks
                        starter_stem = Path(sys.argv[0]).stem
                        path_vis_lmk = f'4debug/vis_lmk/{starter_stem}-{i}.png'
                        visualize_landmarks(x[i], lm[0], path_vis_lmk)
                        print(f"{path_vis_lmk=}")
            Landmarks_all=np.concatenate(Landmarks_all,axis=0)
            pts68 = Landmarks_all.reshape(bs, NUM_pts, 2, )
            if self.Landmarks_weight>0:
                Landmarks_all=torch.tensor(Landmarks_all).float().to(self.device)
                # Unreachable: the enclosing branch already requires Landmark_cond.
                if self.Landmark_cond == False:
                    return Landmarks_all
                # Keep the projection differentiable even if the caller runs under no_grad.
                with torch.enable_grad():
                    Landmarks_all=self.landmark_proj_out(Landmarks_all)
            # normalize Landmarks_all
            lmk_aux={}
            if USE_pts: lmk_aux['l_lmkAll'] = l_lmkAll
            return Landmarks_all,pts68,lmk_aux
| def meshgrid(self, h, w): | |
| y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) | |
| x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) | |
| arr = torch.cat([y, x], dim=-1) | |
| return arr | |
| def delta_border(self, h, w): | |
| """ | |
| :param h: height | |
| :param w: width | |
| :return: normalized distance to image border, | |
| wtith min distance = 0 at border and max dist = 0.5 at image center | |
| """ | |
| lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) | |
| arr = self.meshgrid(h, w) / lower_right_corner | |
| dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] | |
| dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] | |
| edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] | |
| return edge_dist | |
| def get_weighting(self, h, w, Ly, Lx, device): | |
| weighting = self.delta_border(h, w) | |
| weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], | |
| self.split_input_params["clip_max_weight"], ) | |
| weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) | |
| if self.split_input_params["tie_braker"]: | |
| L_weighting = self.delta_border(Ly, Lx) | |
| L_weighting = torch.clip(L_weighting, | |
| self.split_input_params["clip_min_tie_weight"], | |
| self.split_input_params["clip_max_tie_weight"]) | |
| L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) | |
| weighting = weighting * L_weighting | |
| return weighting | |
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time
    """Build the crop/stitch operators used for patch-wise first-stage (de)coding.

    ``unfold`` cuts ``x`` into crops of ``kernel_size`` taken every ``stride``
    pixels. ``fold`` stitches crops back together on an output grid that is
    ``uf`` times larger (``uf > 1``) or ``df`` times smaller (``df > 1``) than
    ``x``. Crop size and stride are scaled the same way, independently per
    axis.

    :param x: tensor of size (bs, c, h, w)
    :param kernel_size: (kh, kw) crop size at ``x``'s resolution
    :param stride: (sh, sw) crop stride at ``x``'s resolution
    :param uf: upsampling factor of the folded output
    :param df: downsampling factor of the folded output
    :return: ``(fold, unfold, normalization, weighting)``.
        ``normalization`` is ``fold(weighting)`` reshaped to
        (1, 1, H_out, W_out); callers divide by it to undo overlap.
        ``weighting`` has shape (1, 1, kh_out, kw_out, Ly * Lx).
    :raises NotImplementedError: unless ``uf == df == 1`` or exactly one of
        them is greater than 1 (the other being 1).
    """
    _, _, h, w = x.shape
    # number of crops in image (unfold uses padding=0, dilation=1)
    Ly = (h - kernel_size[0]) // stride[0] + 1
    Lx = (w - kernel_size[1]) // stride[1] + 1

    # Crop geometry on the *output* grid. Each axis is scaled with its own
    # kernel size; the previous version used kernel_size[0] for both axes
    # when uf/df != 1, which broke non-square kernels.
    if uf == 1 and df == 1:
        out_kernel, out_stride, out_size = kernel_size, stride, (h, w)
    elif uf > 1 and df == 1:
        out_kernel = (kernel_size[0] * uf, kernel_size[1] * uf)
        out_stride = (stride[0] * uf, stride[1] * uf)
        out_size = (h * uf, w * uf)
    elif df > 1 and uf == 1:
        out_kernel = (kernel_size[0] // df, kernel_size[1] // df)
        out_stride = (stride[0] // df, stride[1] // df)
        out_size = (h // df, w // df)
    else:
        raise NotImplementedError

    unfold = torch.nn.Unfold(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
    fold = torch.nn.Fold(output_size=out_size, kernel_size=out_kernel,
                         dilation=1, padding=0, stride=out_stride)

    weighting = self.get_weighting(out_kernel[0], out_kernel[1], Ly, Lx, x.device).to(x.dtype)
    normalization = fold(weighting).view(1, 1, out_size[0], out_size[1])  # normalizes the overlap
    weighting = weighting.view((1, 1, out_kernel[0], out_kernel[1], Ly * Lx))
    return fold, unfold, normalization, weighting
# The returned dict's 'z9' entry is the channel-wise concatenation
# (z4_gt, z4_inpaint, tgt_mask_64[, z_ref, ref_mask_64]); with 4-channel latents,
# e.g. "x_start[:,8,:,:]" in p_losses extracts the target mask.
def get_input_(self, batch, k, return_first_stage_outputs=False,
               cond_key=None, bs=None,
               get_referenceZ=False, # reference image latent tensor, dims B,4,64,64
               ):
    """Turn a raw training batch into the latent tensors consumed by the diffusion model.

    Only ``k == "inpaint"`` is supported; any other key, a 3-D ``GT`` tensor or a
    non-None ``bs`` hits ``assert 0``.

    Side effects:
      * every tensor in ``batch['enInputs']`` is replaced in place by a
        contiguous float copy;
      * ``global_.lmk_`` is set from ``self.get_lmk_for_router``;
      * for ``self.task`` in (0, 2, 3) with ``USE_pts``, landmarks are drawn
        onto the cloned ``inpaint`` image before it is encoded.

    :param batch: dict read for 'GT', 'inpaint_mask', 'inpaint_image',
        'enInputs', and optionally 'ref_imgs_4unet' / 'ref_mask_512'.
    :param k: must be "inpaint".
    :param return_first_stage_outputs: not used in this method.
    :param cond_key: not used in this method.
    :param bs: must be None.
    :param get_referenceZ: if True, encode ``batch['ref_imgs_4unet']`` into 'z_ref'.
    :return: a new dict with all entries of ``batch`` plus 'z9', 'z4_gt',
        'z4_inpaint', 'tgt_mask_64', 'ref_mask_64', 'z_ref', 'landmarks'.
    """
    if k == "inpaint": # yes
        x = batch['GT']
        mask = batch['inpaint_mask'].clone() # b,1,512,512
        inpaint = batch['inpaint_image'].clone() # .clone so that batch['inpaint_image'] remains the original image without landmarks
        # reference = batch['ref_imgs']
        reference = None
    else:
        assert 0
    if len(x.shape) == 3:
        assert 0
        x = x[..., None]
    if 1:
        enInputs = batch['enInputs'] # encoder inputs (each self.encoder receives these raw tensors without preprocessing)
        # NOTE(review): this loop rebinds the parameter name `k`; it is not read afterwards.
        for k,v in enInputs.items():
            enInputs[k] = v.to(memory_format=torch.contiguous_format).float()
    #--------------------------------------------------------------------------------
    ref_imgs_4unet = batch.get('ref_imgs_4unet', None) if get_referenceZ else None
    #x : Original Image
    #inpaint : Masked original image
    #mask: mask
    #reference: Transformed(Masked(original image))
    if bs is not None:
        assert 0
    x = x.to(self.device)
    global_.lmk_ = self.get_lmk_for_router(batch, x) # for router/gate
    if self.Landmark_cond:
        landmarks, pts68, lmk_aux=self.get_landmarks(x,batch)
    else:
        landmarks=None
    # NOTE(review): `lmk_aux` is only bound when self.Landmark_cond is truthy; this branch
    # would raise NameError with USE_pts enabled and Landmark_cond disabled.
    if self.task in (0,2,3,) and USE_pts:
        # NOTE(review): mask_np, x_unnorm and VIS_pts are computed but not used in this method.
        mask_np = mask.detach().cpu().numpy()
        if 1:
            #convert to 8bit image
            x_unnorm=255.0*un_norm(x).permute(0,2,3,1).cpu().numpy()
            x_unnorm=x_unnorm.astype(np.uint8) # B,512,512,3
            batch_size = x.shape[0]
            VIS_pts= 0
            for b in range(batch_size):
                lmkAll = lmk_aux['l_lmkAll'][b]
                # Draw this sample's landmarks onto its (cloned) inpaint image, HWC numpy -> CHW tensor.
                # NOTE(review): a new LandmarkExtractor is constructed for every sample.
                inpaint[b] = torch.Tensor(LandmarkExtractor(include_visualizer=True,include_lmk_extractor=0,img_256_mode=False).visualizer.visualize_landmarks(inpaint[b].permute(1,2,0).detach().cpu().numpy(), lmkAll, ) ).permute(2,0,1)
                del lmkAll
    if self.training and gate_('vis LatentDiffusion.get_input'):
        debug_dir = Path(f"4debug/LatentDiffusion.get_input/{ID}"); debug_dir.mkdir(parents=0, exist_ok=True)
        vis_batch_size = min(5, x.shape[0]) # Show at most 5 samples
        all_images = [ ('x', x), ('inpaint', inpaint), ('mask', mask), ('reference', reference), ('ref_imgs_4unet', ref_imgs_4unet) ]
        for _name, _enInput in enInputs.items():
            all_images.append((_name, _enInput))
        all_path = debug_dir / f"all--after-pts-{str_t_pid()}.jpg"
        vis_tensors_A(all_images, all_path, vis_batch_size)
    # Latents of the ground truth and of the (possibly landmark-annotated) inpaint image.
    encoder_posterior = self.encode_first_stage(x)
    z = self.get_first_stage_encoding(encoder_posterior).detach()
    encoder_posterior_inpaint = self.encode_first_stage(inpaint)
    z_inpaint = self.get_first_stage_encoding(encoder_posterior_inpaint).detach()
    # tgt/ref_mask_64: masks resized to the latent resolution (square, side = z.shape[-1])
    mask_resize = Resize([z.shape[-1],z.shape[-1]])(mask)
    ref_mask_64 = Resize([z.shape[-1],z.shape[-1]])(batch['ref_mask_512']) if 'ref_mask_512' in batch else None
    # z9 & z_ref
    if not CH14:
        z_new = torch.cat((z,z_inpaint,mask_resize),dim=1) # shape:[4,9,64,64] 9:4+4+1
    if get_referenceZ:
        encoder_posterior_ref = self.encode_first_stage(ref_imgs_4unet)
        z_ref = self.get_first_stage_encoding(encoder_posterior_ref).detach() # shape:[4,4,64,64]
    else:
        z_ref = None
    if CH14:
        # 14 = 4 (gt) + 4 (inpaint) + 1 (tgt mask) + 4 (ref) + 1 (ref mask); needs z_ref and ref_mask_64 to be tensors.
        z_new = torch.cat((z,z_inpaint,mask_resize, z_ref,ref_mask_64),dim=1)
    assert z.shape[1:]==(4,64,64,)
    if gate_(f'vis LatentDiffusion.get_input-before_return {self.training}'):
        debug_dir = Path(f"4debug/LatentDiffusion.get_input-before_return/{ID}"); debug_dir.mkdir(parents=0, exist_ok=True)
        vis_batch_size = min(5, x.shape[0])
        all_images = [ ('x', x), ('inpaint', inpaint), ('mask', mask), ('reference', reference), ('ref_imgs_4unet', ref_imgs_4unet),
                       ('z4_gt',z[:,:3]),('z4_inpaint', z_inpaint[:,:3]),('tgt_mask_64', mask_resize),('z_ref',None if z_ref is None else z_ref[:,:3]),('ref_mask_64',ref_mask_64),]
        all_path = debug_dir / f"{str_t_pid()}.jpg"
        vis_tensors_A(all_images, all_path, vis_batch_size)
    if 1:
        assert self.model.conditioning_key is not None
        assert self.first_stage_key=='inpaint'
        assert self.cond_stage_key=='image'
    return {
        **batch,
        'z9': z_new,# b,9/14,...
        'z4_gt': z,
        'z4_inpaint': z_inpaint,
        #
        'tgt_mask_64': mask_resize,
        'ref_mask_64': ref_mask_64,
        #
        'z_ref': z_ref, # latent of batch['ref_imgs_4unet'] (None unless get_referenceZ); name kept for legacy callers
        #
        'landmarks': landmarks, # projected features, not raw coordinates
    }
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
    """Decode latents ``z`` back to image space with the first-stage model.

    :param z: latent tensor; multiplied by ``1 / self.scale_factor`` first.
    :param predict_cids: if True, ``z`` is treated as code predictions: a 4-D
        ``z`` is reduced with ``argmax`` over dim 1, then looked up through
        ``first_stage_model.quantize.get_codebook_entry``.
    :param force_not_quantize: forwarded to VQ decoders (OR-ed with ``predict_cids``).
    :return: decoded images.

    If ``split_input_params["patch_distributed_vq"]`` is set, the latent is
    decoded crop by crop and blended back with overlap weights. Only on the
    plain path (no ``split_input_params``, non-VQ decoder) with
    ``first_stage_key == 'inpaint'`` is ``z`` sliced to its first 4 channels.
    """
    if predict_cids:
        if z.dim() == 4:
            z = torch.argmax(z.exp(), dim=1).long()
        z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
        z = rearrange(z, 'b h w c -> b c h w').contiguous()
    z = 1. / self.scale_factor * z

    has_split = hasattr(self, "split_input_params")
    patched = has_split and self.split_input_params["patch_distributed_vq"]
    is_vq = isinstance(self.first_stage_model, VQModelInterface)
    skip_quant = predict_cids or force_not_quantize

    if not patched:
        if is_vq:
            return self.first_stage_model.decode(z, force_not_quantize=skip_quant)
        if not has_split and self.first_stage_key == 'inpaint':
            # drop the extra conditioning channels; decode only the 4-channel latent
            return self.first_stage_model.decode(z[:, :4, :, :])
        return self.first_stage_model.decode(z)

    ks = self.split_input_params["ks"]  # eg. (128, 128)
    stride = self.split_input_params["stride"]  # eg. (64, 64)
    uf = self.split_input_params["vqf"]
    _, _, h, w = z.shape
    if ks[0] > h or ks[1] > w:
        ks = (min(ks[0], h), min(ks[1], w))
        print("reducing Kernel")
    if stride[0] > h or stride[1] > w:
        stride = (min(stride[0], h), min(stride[1], w))
        print("reducing stride")
    fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)

    crops = unfold(z)  # (bn, nc * prod(**ks), L)
    crops = crops.view((crops.shape[0], -1, ks[0], ks[1], crops.shape[-1]))  # (bn, nc, ks[0], ks[1], L)
    decode_kwargs = {"force_not_quantize": skip_quant} if is_vq else {}
    decoded = torch.stack(
        [self.first_stage_model.decode(crops[..., i], **decode_kwargs) for i in range(crops.shape[-1])],
        dim=-1,
    )  # (bn, nc, ks[0], ks[1], L) at output resolution
    decoded = decoded * weighting
    decoded = decoded.view((decoded.shape[0], -1, decoded.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
    # stitch crops together; normalization has shape (1, 1, h, w)
    return fold(decoded) / normalization
def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
    """Decode latents ``z`` with the first-stage model (variant used for image-space losses).

    The original note described this as "same as above but without
    decorator", i.e. intended to run with gradients. Unlike
    ``decode_first_stage``, it never slices ``z`` to its first 4 channels.

    :param z: latent tensor; multiplied by ``1 / self.scale_factor`` first.
    :param predict_cids: if True, a 4-D ``z`` is reduced with ``argmax`` over
        dim 1 and mapped through ``quantize.get_codebook_entry``.
    :param force_not_quantize: forwarded to VQ decoders (OR-ed with ``predict_cids``).
    :return: decoded images, assembled crop by crop when
        ``split_input_params["patch_distributed_vq"]`` is set.
    """
    model = self.first_stage_model
    if predict_cids:
        if z.dim() == 4:
            z = torch.argmax(z.exp(), dim=1).long()
        z = model.quantize.get_codebook_entry(z, shape=None)
        z = rearrange(z, 'b h w c -> b c h w').contiguous()
    z = 1. / self.scale_factor * z

    def run_decoder(latent):
        # VQ decoders take the extra quantization switch; others only the latent.
        if isinstance(model, VQModelInterface):
            return model.decode(latent, force_not_quantize=predict_cids or force_not_quantize)
        return model.decode(latent)

    use_patches = hasattr(self, "split_input_params") and self.split_input_params["patch_distributed_vq"]
    if not use_patches:
        return run_decoder(z)

    cfg = self.split_input_params
    ks, stride, uf = cfg["ks"], cfg["stride"], cfg["vqf"]  # eg. (128, 128), (64, 64)
    _, _, h, w = z.shape
    if ks[0] > h or ks[1] > w:
        ks = (min(ks[0], h), min(ks[1], w))
        print("reducing Kernel")
    if stride[0] > h or stride[1] > w:
        stride = (min(stride[0], h), min(stride[1], w))
        print("reducing stride")
    fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)

    patches = unfold(z)  # (bn, nc * prod(**ks), L)
    n_patches = patches.shape[-1]
    patches = patches.view((patches.shape[0], -1, ks[0], ks[1], n_patches))
    out = torch.stack([run_decoder(patches[:, :, :, :, i]) for i in range(n_patches)], dim=-1)
    out = out * weighting
    out = out.view((out.shape[0], -1, out.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
    return fold(out) / normalization
def encode_first_stage(self, x):
    """Encode images ``x`` with the first-stage model.

    Without ``split_input_params["patch_distributed_vq"]`` this is simply
    ``self.first_stage_model.encode(x)``. Otherwise ``x`` is cut into crops
    (kernel/stride shrunk to fit ``x`` if needed), and each crop is encoded
    separately. The results are blended back with the overlap weights from
    ``get_fold_unfold(..., df=vqf)``. That path also stores ``x``'s spatial
    size in ``split_input_params['original_image_size']``.
    """
    if not (hasattr(self, "split_input_params") and self.split_input_params["patch_distributed_vq"]):
        return self.first_stage_model.encode(x)

    params = self.split_input_params
    ks = params["ks"]  # eg. (128, 128)
    stride = params["stride"]  # eg. (64, 64)
    df = params["vqf"]
    params['original_image_size'] = x.shape[-2:]
    _, _, h, w = x.shape
    if ks[0] > h or ks[1] > w:
        ks = (min(ks[0], h), min(ks[1], w))
        print("reducing Kernel")
    if stride[0] > h or stride[1] > w:
        stride = (min(stride[0], h), min(stride[1], w))
        print("reducing stride")
    fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)

    crops = unfold(x)  # (bn, nc * prod(**ks), L)
    crops = crops.view((crops.shape[0], -1, ks[0], ks[1], crops.shape[-1]))  # (bn, nc, ks[0], ks[1], L)
    encoded = torch.stack([self.first_stage_model.encode(crops[..., i]) for i in range(crops.shape[-1])], dim=-1)
    encoded = encoded * weighting
    encoded = encoded.view((encoded.shape[0], -1, encoded.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
    return fold(encoded) / normalization
def get_input_and_conditioning(self, batch, device=None):
    """Prepare latents and conditioning for one batch.

    Moves ``batch`` to ``device`` via ``recursive_to`` when a device is given,
    then runs ``get_input_``. The reference latent is requested when the
    reference net has layers for ``global_.task`` or when ``CH14`` is set.
    Finally the conditioning is built with ``conditioning_with_feat``.

    :return: ``(batch, c)`` - the extended batch dict and the conditioning.
    """
    if device is not None:
        batch = recursive_to(batch, device)
    # ---- formerly part of shared_step ----
    needs_ref_latent = (REFNET.ENABLE and REFNET.task2layerNum[global_.task] > 0) or CH14
    batch = self.get_input_(batch, self.first_stage_key, get_referenceZ=needs_ref_latent)
    # ---- formerly part of shared_step -> forward ----
    assert (self.model.conditioning_key is not None) and self.cond_stage_trainable
    cond = self.conditioning_with_feat(
        batch['ref_imgs'],
        landmarks=batch['landmarks'],
        enInputs=batch['enInputs'],
    )
    return batch, cond
def shared_step(self, batch, **kwargs):
    """Run one training/validation step on ``batch`` and return ``self(...)``'s result.

    Extra keyword arguments are accepted for API compatibility and ignored.
    """
    task = self.set_task(batch)
    # Clear the model's feature bank first when the reference net is active for this task.
    if REFNET.ENABLE and REFNET.task2layerNum[task] > 0:
        self.model.bank.clear()
    batch, c = self.get_input_and_conditioning(batch)
    return self(
        batch['z9'],
        c,
        z_ref=batch['z_ref'],
        gt512=batch['GT'],
        gt256=batch.get('GT256', None),
        task=task,
        batch=batch,
    )
def forward(self, x, c, *args, **kwargs):
    """Draw random timesteps and compute the diffusion training loss.

    :param x: model input latent (e.g. the 9/14-channel ``z9`` stack).
    :param c: conditioning forwarded to ``p_losses``.
    :param kwargs: forwarded to ``p_losses``; must contain ``task``, which also
        indexes ``self.learnable_vector`` for the conditioning-dropout branch.
    :return: the result of ``self.p_losses``.
    :raises NotImplementedError: if a conditioning key is set together with
        ``self.shorten_cond_schedule`` (that path is not supported).
    """
    task = kwargs['task']
    t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
    self.u_cond_prop = random.uniform(0, 1)
    if self.model.conditioning_key is not None and self.shorten_cond_schedule:
        # Previously a bare `raise Exception` followed by unreachable q_sample
        # noising of `c`; fail with an explicit, descriptive error instead.
        raise NotImplementedError("shorten_cond_schedule is not supported")
    # During training, with probability u_cond_percent, replace the conditioning by
    # the task's learnable vector repeated over the batch.
    if self.u_cond_prop < self.u_cond_percent and self.training:
        return self.p_losses(x, self.learnable_vector[task].repeat(x.shape[0], 1, 1), t, *args, **kwargs)
    # e.g. x: [4,9,64,64] (img, inpaint img, mask after first stage), c: [4,1,768] embedding
    return self.p_losses(x, c, t, *args, **kwargs)
def apply_model(self, x_noisy, t, cond, return_ids=False, return_features=False,
                z_ref=None,
                ):
    """Run the denoising network on ``x_noisy`` at timesteps ``t``.

    :param cond: conditioning. A dict is passed through as keyword arguments.
        Anything else is wrapped in a list (unless it already is one) and
        stored under ``'c_concat'`` when
        ``self.model.conditioning_key == 'concat'``, else ``'c_crossattn'``.
    :param return_ids: if False and the model returns a tuple, only its first
        element is returned.
    :param return_features: if True, the raw model output is returned unchanged.
    :param z_ref: reference latent forwarded to the model.
    :return: the model prediction (see ``return_ids`` / ``return_features``).
    :raises AssertionError: if ``split_input_params`` is set; patch-wise
        application is not supported here.
    """
    if not isinstance(cond, dict):
        if not isinstance(cond, list):
            cond = [cond]
        key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'  # -->c_crossattn
        cond = {key: cond}
    if hasattr(self, "split_input_params"):
        # Explicit raise (instead of `assert 0`) so the guard also holds under
        # `python -O`; the unreachable patch-wise implementation that used to
        # follow it has been removed.
        raise AssertionError('This branch should not execute in practice')
    x_recon = self.model(x_noisy, t, **cond, return_features=return_features, z_ref=z_ref,
                         task=self.task, _trainer=self.trainer,
                         )
    if return_features:
        return x_recon
    if isinstance(x_recon, tuple) and not return_ids:
        return x_recon[0]
    return x_recon
def p_losses(self, x_start, cond, t, noise=None, z_ref=None, gt512=None, gt256=None, task=None,
             batch :dict = None,
             ):
    """Diffusion training loss for one batch, plus optional sampler-based image losses.

    :param x_start: clean model input. When ``first_stage_key == 'inpaint'``, only
        channels ``[:4]`` are noised; channels ``[4:]`` are concatenated back
        unchanged as conditioning.
    :param cond: conditioning forwarded to ``apply_model`` (and to the sampler).
    :param t: per-sample timestep indices.
    :param noise: optional noise; drawn with ``torch.randn_like`` when None.
    :param z_ref: optional reference latent. It is noised at the same ``t`` and,
        when ``REFNET.CH9`` is set, replaced by ``cat([noisy, clean, zeros(1 ch)])``.
    :param gt512: ground-truth images for the ID / CLIP / LPIPS / reconstruction losses.
    :param gt256: optional lower-resolution ground truth. When given, LPIPS uses it
        with base size 256 instead of ``gt512`` with base size 512.
    :param task: index into the 4-entry per-task auxiliary loss weight lists.
    :param batch: accepted for interface compatibility; not read in this method.
    :return: ``(loss, loss_dict)``.

    The main loss is ``l_simple_weight * mean(loss_simple / exp(logvar_t) + logvar_t)
    + original_elbo_weight * loss_vlb``.

    When the module-level flag ``DDIM_losses`` is set, ``self.sampler.sample_train``
    is also run from the last timestep, each intermediate ``pred_x0`` is decoded,
    and weighted ID / CLIP / LPIPS / reconstruction losses are added.

    ``global_.moe_aux_loss`` is reset to 0 at the start (the model call may
    accumulate into it) and, if it is a tensor at the end, added to the loss.
    """
    # def p_losses_face(self, x_start, cond, t, reference=None,noise=None,GT_tar=None,landmarks=None):
    # initialize MoE auxiliary loss to 0 to allow unconditional accumulation later
    global_.moe_aux_loss = torch.tensor(0.0, device=self.device)
    if self.first_stage_key == 'inpaint':
        # x_start=x_start[:,:4,:,:]
        # Noise only the 4-channel target latent; the remaining channels are conditioning.
        noise = default(noise, lambda: torch.randn_like(x_start[:,:4,:,:]))
        if 1:
            x_noisy = self.q_sample(x_start=x_start[:,:4,:,:], t=t, noise=noise)
            x_noisy = torch.cat((x_noisy,x_start[:,4:,:,:]),dim=1)
    else:
        noise = default(noise, lambda: torch.randn_like(x_start))
        if 1:
            x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
    if z_ref is not None:
        assert self.first_stage_key == 'inpaint', 'Expected first_stage_key to be "inpaint"'
        """
        z_ref: b,4,...
        z_ref = concat [z_ref_noisy, z_ref, tensor_1c]
        tensor_1c is temporarily set to all zeros
        """
        z_ref_noisy = self.q_sample(x_start=z_ref, t=t, noise=torch.randn_like(z_ref))
        tensor_1c = torch.zeros((z_ref.shape[0], 1, z_ref.shape[2], z_ref.shape[3]), device=z_ref.device)
        if REFNET.CH9:
            z_ref = torch.cat([z_ref_noisy, z_ref, tensor_1c], dim=1)
    if 1:
        model_output = self.apply_model(x_noisy, t, cond, z_ref=z_ref, )
    loss_dict = {}
    prefix = 'train' if self.training else 'val'
    if DDIM_losses:
        ########################
        # randint(num_timesteps-1, num_timesteps) always yields num_timesteps-1,
        # i.e. reconstruction starts from the last timestep.
        t_new = torch.randint(self.num_timesteps-1, self.num_timesteps, (x_start.shape[0],), device=self.device).long().to(self.device)
        # t_new=torch.tensor(t_new).to(self.device)
        # noise_rec = default(noise, lambda: torch.randn_like(x_start[:,:4,:,:]))
        # Reuses the same `noise` as the main loss.
        x_noisy_rec = self.q_sample(x_start=x_start[:,:4,:,:], t=t_new, noise=noise)
        x_noisy_rec = torch.cat((x_noisy_rec,x_start[:,4:,:,:]),dim=1)
        ddim_steps=self.Reconstruct_DDIM_steps
        n_samples=x_noisy_rec.shape[0]
        shape=(4,64,64)
        scale=5
        ddim_eta=0.0
        start_code=x_noisy_rec
        test_model_kwargs=None
        # t=t
        samples_ddim, sample_intermediates = self.sampler.sample_train(S=ddim_steps, # 4 (from Reconstruct_DDIM_steps in trian.yaml)
                                                                       conditioning=cond,
                                                                       batch_size=n_samples,
                                                                       shape=shape,
                                                                       verbose=False,
                                                                       unconditional_guidance_scale=scale,
                                                                       unconditional_conditioning=None,
                                                                       eta=ddim_eta,
                                                                       x_T=start_code,
                                                                       t=t_new,
                                                                       z_ref=z_ref,
                                                                       test_model_kwargs=test_model_kwargs)
        # x_samples_ddim= self.differentiable_decode_first_stage(samples_ddim)
        # Decode every intermediate x0 prediction to image space (list is updated in place).
        other_pred_x_0=sample_intermediates['pred_x0']
        len_inter = len(other_pred_x_0)
        printC("len_inter", len_inter )
        for i in range(len(other_pred_x_0)):
            other_pred_x_0[i]=self.differentiable_decode_first_stage(other_pred_x_0[i])
        # x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        # x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
        ###########################################
        ID_loss=0
        clip_loss=0
        loss_lpips=0
        loss_rec=0
        loss_landmark=0
        # model_output=samples_ddim
        if 1:
            # x_samples_ddim=TF.resize(x_samples_ddim,(256,256))
            if 0:
                # Disabled: masking the decoded predictions by the inpaint mask.
                inpaint_mask_64 = x_start[:,8,:,:] # inpaint region is 1, background is 0; shape b,64,64
                masks=TF.resize(inpaint_mask_64,(other_pred_x_0[0].shape[2],other_pred_x_0[0].shape[3])) # b,512,512
                if not 1:
                    masks = 1 - masks
                #mask x_samples_ddim
                x_samples_ddim_masked=[x_samples_ddim_preds*masks.unsqueeze(1) for x_samples_ddim_preds in other_pred_x_0]
                # x_samples_ddim_masked=un_norm_clip(x_samples_ddim_masked)
                # x_samples_ddim_masked = TF.normalize(x_samples_ddim_masked, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            else:
                x_samples_ddim_masked = other_pred_x_0
            # Per-task auxiliary loss weights (lists are indexed by `task`).
            Landmark_loss_weight = 0
            ID_loss_weight = [0.3, 0, 0.1, 0.2, ][task]
            if ID_loss_weight > 0 :
                ID_Losses=[]
                for step,x_samples_ddim_preds in enumerate(x_samples_ddim_masked):
                    ID_loss,sim_imp,_=self.face_ID_model(x_samples_ddim_preds,gt512,clip_img=False)
                    ID_Losses.append(ID_loss)
                    loss_dict.update({f'{prefix}/ID_loss_{step}': ID_loss})
                ID_loss=torch.mean(torch.stack(ID_Losses))
                loss_dict.update({f'{prefix}/ID_loss': ID_loss})
                # NOTE: only the last step's sim_imp is logged.
                loss_dict.update({f'{prefix}/sim_imp': sim_imp})
            CLIP_loss_weight = [1.5/4, 0.8, 1, 0.5, ][task]
            if CLIP_loss_weight > 0 :
                # MSE between encoder_clip_face ViT embeddings of prediction and ground truth.
                def _loss(_img1,_img2):
                    _e1 = self.encoder_clip_face.forward_vit(_img1,resize=True)
                    _e2 = self.encoder_clip_face.forward_vit(_img2,resize=True)
                    return torch.nn.functional.mse_loss( _e1, _e2 )
                clip_Losses=[]
                for step,x_samples_ddim_preds in enumerate(x_samples_ddim_masked):
                    clip_loss = _loss(x_samples_ddim_preds,gt512)
                    clip_Losses.append(clip_loss)
                    loss_dict.update({f'{prefix}/clip_loss_{step}': clip_loss})
                clip_loss=torch.mean(torch.stack(clip_Losses))
                loss_dict.update({f'{prefix}/clip_loss': clip_loss})
            LPIPS_loss_weight = [0.05, 0.015, 0.015, 0.015, ][task]
            if LPIPS_loss_weight>0:
                if gt256 is not None:
                    _lpips_base_size = 256
                    _gt_for_lpips = gt256
                else:
                    _lpips_base_size = 512
                    _gt_for_lpips = gt512
                # LPIPS summed (not averaged) over every decoded step and 3 scales
                # (base, base/2, base/4), both sides average-pooled to the same size.
                for j in range(len(other_pred_x_0)):
                    for i in range(3):
                        _size = _lpips_base_size//2**i
                        _pred_for_lpips = F.adaptive_avg_pool2d(other_pred_x_0[j],(_size,_size))
                        _gt_for_lpips_resized = F.adaptive_avg_pool2d(_gt_for_lpips,(_size,_size))
                        loss_lpips_1 = self.lpips_loss(
                            _pred_for_lpips,
                            _gt_for_lpips_resized,
                        )
                        loss_dict.update({f'{prefix}/loss_lpips_{j}_{i}': loss_lpips_1})
                        printC(f"loss_lpips_1 at {j} {i} :", loss_lpips_1)
                        loss_lpips += loss_lpips_1
                loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
            REC_loss_weight = [0.05, 0.01, 0.01, 0.01, ][task]
            if REC_loss_weight > 0 : # rec loss
                # Pixel MSE against gt512, summed over decoded steps.
                for j in range(len(other_pred_x_0)):
                    loss_rec_1 = torch.nn.functional.mse_loss( other_pred_x_0[j], gt512)
                    loss_dict.update({f'{prefix}/loss_rec_{j}': loss_rec_1})
                    printC(f"loss_rec_1 at {j} :", loss_rec_1)
                    loss_rec += loss_rec_1
                loss_dict.update({f'{prefix}/loss_rec': loss_rec})
    if 1:
        if self.parameterization == "x0":
            target = x_start
        elif self.parameterization == "eps":
            target = noise
        else:
            raise NotImplementedError()
        # this should be an MSE loss
        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
        loss_dict.update({f'{prefix}/loss_simple-t{task}': loss_simple.mean()})
        self.logvar = self.logvar.to(self.device)
        logvar_t = self.logvar[t].to(self.device)
        loss = loss_simple / torch.exp(logvar_t) + logvar_t
        # loss = loss_simple / torch.exp(self.logvar) + self.logvar
        if self.learn_logvar:
            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
            loss_dict.update({'logvar': self.logvar.data.mean()})
        loss = self.l_simple_weight * loss.mean()
        # NOTE(review): this recomputes exactly the same per-sample quantity as loss_simple
        # (same get_loss call averaged over dims 1..3), then weights it by lvlb_weights[t].
        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) #??
        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
        loss_dict.update({f'{prefix}/loss_vlb-t{task}': loss_vlb})
        loss += (self.original_elbo_weight * loss_vlb)
    else:
        loss = 0
    if DDIM_losses:
        _item = lambda _a: _a.detach().cpu().item() if isinstance(_a,torch.Tensor) else _a
        printC("orig, ID clip, lpips rec lmk:",
               f"{_item(loss):.4f}, {_item(ID_loss):.4f} {_item(clip_loss):.4f}, {_item(loss_lpips):.4f} {_item(loss_rec):.4f} {_item(loss_landmark):.4f}",
               f"{ID_Losses=}" if ID_loss_weight>0 else "",
               f"{clip_Losses=}" if CLIP_loss_weight>0 else "",
               )
        loss+=ID_loss_weight*ID_loss+LPIPS_loss_weight*loss_lpips+Landmark_loss_weight*loss_landmark+REC_loss_weight*loss_rec+CLIP_loss_weight*clip_loss
    # incorporate MoE auxiliary loss
    moe_aux = global_.moe_aux_loss
    if isinstance(moe_aux, torch.Tensor):
        loss = loss + moe_aux
        loss_dict.update({f'{prefix}/moe_aux_loss': moe_aux})
    loss_dict.update({f'{prefix}/loss': loss})
    loss_dict.update({f'{prefix}/loss-t{task}': loss})
    return loss, loss_dict
| def configure_optimizers(self): | |
| lr = self.learning_rate | |
| params = list(self.model.parameters()) | |
| if self.partial_training:# no | |
| # if True: | |
| print("Partial training.............................") | |
| train_names=self.trainable_keys | |
| train_names=[ 'attn2','norm2'] | |
| params_train=[] | |
| for name,param in self.model.named_parameters(): | |
| if "diffusion_model" not in name and param.requires_grad: | |
| print(name) | |
| params_train.append(param) | |
| elif "diffusion_model" in name and any(train_name in name for train_name in train_names): | |
| print(name) | |
| params_train.append(param) | |
| params=params_train | |
| print("Setting up Adam optimizer.......................") | |
| if self.cond_stage_trainable:# yes | |
| print(f"{self.__class__.__name__}: Also optimizing conditioner params!") | |
| if hasattr(self,'encoder_clip_face'): | |
| params += list(self.encoder_clip_face.final_ln2.parameters())+list(self.encoder_clip_face.mapper2.parameters()) | |
| if self.USE_proj_out_source: | |
| params += list(self.proj_out_source__face.parameters()) | |
| if hasattr(self,'encoder_clip_hair'): | |
| params += list(self.encoder_clip_hair.final_ln2.parameters())+list(self.encoder_clip_hair.mapper2.parameters()) | |
| if self.USE_proj_out_source: | |
| params += list(self.proj_out_source__hair.parameters()) | |
| if hasattr(self,'encoder_clip_head_t2'): | |
| params += list(self.encoder_clip_head_t2.final_ln2.parameters())+list(self.encoder_clip_head_t2.mapper2.parameters()) | |
| if hasattr(self,'encoder_clip_head_t3'): | |
| params += list(self.encoder_clip_head_t3.final_ln2.parameters())+list(self.encoder_clip_head_t3.mapper2.parameters()) | |
| if hasattr(self,'encoder_clip_head_t2') or hasattr(self,'encoder_clip_head_t3'): | |
| if self.USE_proj_out_source: | |
| params += list(self.proj_out_source__head.parameters()) | |
| if hasattr(self,'ID_proj_out'): | |
| params += list(self.ID_proj_out.parameters()) | |
| if hasattr(self,'landmark_proj_out'): # fixLmkProj | |
| params += list(self.landmark_proj_out.parameters()) | |
| if self.learn_logvar: | |
| print('Diffusion model optimizing logvar') | |
| params.append(self.logvar) | |
| params.extend(self.learnable_vector) | |
| params = [p for p in params if p.requires_grad] | |
| # Build param groups: MoE gate/expert use larger LR. | |
| # Also apply per-task LR factor to all task-specific params. | |
| # only match MoE-related parameter names generated by the UNet wrappers | |
| moe_gate_ids = set() | |
| moe_ep_ids = set() | |
| for name, p in self.model.named_parameters(): | |
| if not p.requires_grad: | |
| continue | |
| if ".moe_gate_mlp." in name: | |
| moe_gate_ids.add(id(p)) | |
| elif ".moe_experts_" in name: | |
| moe_ep_ids.add(id(p)) | |
| params_ids = set(id(p) for p in params) | |
| task_specific_ids = set() | |
| for name, p in self.named_parameters(): | |
| if not p.requires_grad: | |
| continue | |
| if id(p) not in params_ids: | |
| continue | |
| is_task_specific = is_task_specific_(name) | |
| if rank_==0: print(f"{is_task_specific=} {name}") | |
| if is_task_specific: | |
| task_specific_ids.add(id(p)) | |
| base_params = [] | |
| task_specific_params = [] | |
| moe_gate_params = [] | |
| moe_ep_params = [] | |
| for p in params: | |
| pid = id(p) | |
| if pid in task_specific_ids: | |
| task_specific_params.append(p) | |
| elif pid in moe_gate_ids: | |
| moe_gate_params.append(p) | |
| elif pid in moe_ep_ids: | |
| moe_ep_params.append(p) | |
| else: | |
| base_params.append(p) | |
| param_groups = [] | |
| if base_params: | |
| param_groups.append({"params": base_params, "lr": lr}) | |
| if task_specific_params: | |
| param_groups.append({"params": task_specific_params, "lr": lr * LR_factor}) | |
| if moe_gate_params: | |
| param_groups.append({"params": moe_gate_params, "lr": lr * MOE_GATE_LR_MULT}) | |
| if moe_ep_params: | |
| param_groups.append({"params": moe_ep_params, "lr": lr * MOE_EP_LR_MULT}) | |
| if ZERO1_ENABLE: | |
| zero_pg = None | |
| if 1: | |
| if dist.is_available() and dist.is_initialized(): | |
| zero_pg = dist.new_group(backend='gloo') | |
| opt = ZeroRedundancyOptimizer( | |
| param_groups if (task_specific_params or moe_gate_params or moe_ep_params) else params, | |
| optimizer_class=torch.optim.AdamW if ADAM_or_SGD else torch.optim.SGD, | |
| lr=lr, | |
| process_group=zero_pg, | |
| ) | |
| else: | |
| if ADAM_or_SGD: | |
| opt = torch.optim.AdamW(param_groups if (task_specific_params or moe_gate_params or moe_ep_params) else params, lr=lr) | |
| else: | |
| opt = torch.optim.SGD(param_groups if (task_specific_params or moe_gate_params or moe_ep_params) else params, lr=lr, momentum=0.9) | |
| if gate_('LatentDiffusion.configure_optimizers params:'): | |
| if (task_specific_params or moe_gate_params or moe_ep_params): | |
| print(f"base/task_specific/ep/gate lens: {len(base_params)=} {len(task_specific_params)=} {len(moe_ep_params)=} {len(moe_gate_params)=}") | |
| print(f"sum of .numel(): base={sum(p.numel() for p in base_params)} task_specific={sum(p.numel() for p in task_specific_params)} ep={sum(p.numel() for p in moe_ep_params)} gate={sum(p.numel() for p in moe_gate_params)}") | |
| else: | |
| print(f"{len(params)=}") | |
| print(f"sum of .numel(): {sum(param.numel() for param in params)}") | |
| if self.use_scheduler:# yes | |
| assert 'target' in self.scheduler_config | |
| scheduler = instantiate_from_config(self.scheduler_config) | |
| print("Setting up LambdaLR scheduler...") | |
| scheduler = [ | |
| { | |
| 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), | |
| 'interval': 'step', | |
| 'frequency': 1 | |
| }] | |
| return [opt], scheduler | |
| return opt | |
def on_train_epoch_start(self):
    """Lightning hook: optionally alternate freezing of the shared UNet weights per epoch.

    Does nothing unless ``self.alternate_freeze_shared`` is truthy (read with
    ``getattr``, default ``False``).

    When enabled, every parameter of ``self.model.diffusion_model`` whose name
    contains ``.shared_ffn.`` or ``.shared.`` gets ``requires_grad = True`` on
    even epochs and ``requires_grad = False`` on odd epochs. The number of
    parameters whose flag actually changed is printed.

    NOTE(review): configure_optimizers keeps only params with
    ``requires_grad=True`` at the time it runs; a shared param frozen at that
    point would not be in the optimizer even after being unfrozen here.
    """
    if not getattr(self, "alternate_freeze_shared", False):
        return
    # alternating freezing: shared weights train on even epochs only
    train_now = self.current_epoch % 2 == 0
    ct_shared = 0
    for name, p in self.model.diffusion_model.named_parameters():
        # target only the shared weights inside Shared+LoRA wrappers: FFN.shared_ffn.* and Conv.shared.*
        is_shared = ".shared_ffn." in name or ".shared." in name
        if is_shared and p.requires_grad != train_now:
            p.requires_grad = train_now
            ct_shared += 1
    print(f"[freeze@epoch]{self.current_epoch=} {train_now=} {ct_shared=}")
def to_rgb(self, x):
    """Project a multi-channel tensor to 3 channels and min-max scale it to [-1, 1].

    A random 1x1 projection ``self.colorize`` of shape ``(3, x.shape[1], 1, 1)``
    is created on the first call and reused afterwards. Later inputs must
    therefore have the same ``x.shape[1]``.

    Args:
        x: tensor passed to ``conv2d``; ``x.shape[1]`` is used as the channel
            count, so presumably shaped (N, C, H, W) — confirm at call sites.

    Returns:
        float tensor with 3 channels scaled so its global min is -1 and max is 1.
        If the projected tensor is constant, all values are -1 (the plain
        min-max formula would give 0/0 = NaN).
    """
    x = x.float()
    if not hasattr(self, "colorize"):
        self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
    x = nn.functional.conv2d(x, weight=self.colorize)
    lo, hi = x.min(), x.max()
    span = hi - lo
    if span == 0:
        # constant image: avoid dividing by a zero range
        return torch.full_like(x, -1.)
    return 2. * (x - lo) / span - 1.
def __repr__(self):
    """Return a fixed placeholder when ``DEBUG`` is set, otherwise the parent class's repr.

    Presumably used to avoid printing the full module tree while debugging.
    """
    return 'LatentDiffusion.__repr__' if DEBUG else super().__repr__()
def model_size(self):
    """Return the parent class's ``model_size``, or ``-1`` when ``DEBUG`` is set.

    NOTE(review): this is a plain method, but the parent's ``model_size`` is
    read here without being called (attribute/property access), so callers of
    this override must write ``self.model_size()``. Confirm this is intended.
    """
    if not DEBUG:
        return super().model_size
    return -1
| from .bank import Bank | |
| class DiffusionWrapper(pl.LightningModule): | |
def __init__(self, diff_model_config, conditioning_key):
    """Build the main denoising UNet and, if ``REFNET.ENABLE``, a truncated reference UNet.

    Args:
        diff_model_config: config passed to ``instantiate_from_config``. Its
            ``['params']['is_refNet']`` is set to ``False`` in place before the
            main model is instantiated.
        conditioning_key: one of ``None``, ``'concat'``, ``'crossattn'``,
            ``'hybrid'``, ``'adm'`` (asserted).

    The reference UNet is built from a deep copy of the config with
    ``in_channels`` set to 9 (when ``REFNET.CH9``) or 4 and ``is_refNet=True``.
    Only its first 9 ``input_blocks`` are kept; ``middle_block``,
    ``output_blocks`` and ``out`` are deleted.
    """
    import copy  # local import: keeps this block self-contained

    super().__init__()
    diff_model_config['params']['is_refNet'] = False
    self.diffusion_model = instantiate_from_config(diff_model_config)
    self.conditioning_key = conditioning_key
    assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
    if REFNET.ENABLE:
        # Deep copy: assigning the same dict would write the refNet overrides
        # (in_channels, is_refNet=True) back into the caller's config.
        diff_model_config_refNet = copy.deepcopy(diff_model_config)
        print('instantiate / deepcopy diffusion_model_refNet ing...')
        diff_model_config_refNet['params']['in_channels'] = 9 if REFNET.CH9 else 4
        diff_model_config_refNet['params']['is_refNet'] = True
        self.diffusion_model_refNet: UNetModel = instantiate_from_config(diff_model_config_refNet)
        # Truncate the reference UNet: keep the first 9 input blocks, drop the rest.
        self.diffusion_model_refNet.input_blocks = self.diffusion_model_refNet.input_blocks[:9]
        del self.diffusion_model_refNet.middle_block
        del self.diffusion_model_refNet.output_blocks
        del self.diffusion_model_refNet.out
        print('over.')
    # Keep only a single diffusion_model_refNet; no t-suffixed clones
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, return_features=False,
            z_ref=None,
            task=None,
            _trainer: pl.Trainer = None,
            ):
    """Run the denoising UNet on ``x`` at timestep ``t`` with the configured conditioning.

    Only ``conditioning_key == 'crossattn'`` is accepted (asserted). The other
    branches of the generic wrapper are kept below but cannot be reached.

    Cross-attention path: ``c_crossattn`` is concatenated along dim 1 into the
    context. When the reference net is active for ``task`` (``REFNET.ENABLE``
    and ``REFNET.task2layerNum[task] > 0``), the reference UNet is first run on
    ``z_ref``. It receives the context minus its last token for tasks 0, 2 and 3.
    Its return value is discarded; presumably it communicates with the main
    UNet through ``self.bank`` (not created in ``__init__``; verify where it is
    attached). Then the main UNet runs. Outside training, validation and the
    sanity check, ``self.bank`` is cleared afterwards.

    Returns whatever ``self.diffusion_model`` returns.
    """
    # True only while the trainer is validating or running its sanity check;
    # training mode itself is detected through ``self.training`` below.
    in_eval_loop = _trainer is not None and (_trainer.validating or _trainer.sanity_checking)
    assert self.conditioning_key == 'crossattn'
    mode = self.conditioning_key

    if mode == 'crossattn':
        context = torch.cat(c_crossattn, 1)
        ref_active = REFNET.ENABLE and REFNET.task2layerNum[task] > 0
        if ref_active:
            ref_context = context[:, :-1, :] if task in (0, 2, 3) else context
            printC("c for refNet", f"{custom_repr_v3(ref_context)}")
            self.diffusion_model_refNet(z_ref, t, context=ref_context, return_features=False)
        out = self.diffusion_model(x, t, context=context, return_features=return_features)
        if ref_active and not (self.training or in_eval_loop):
            self.bank.clear()
        return out

    # Unreachable while the assert above is in place.
    if mode is None:
        return self.diffusion_model(x, t)
    if mode == 'concat':
        return self.diffusion_model(torch.cat([x] + c_concat, dim=1), t)
    if mode == 'hybrid':
        return self.diffusion_model(torch.cat([x] + c_concat, dim=1), t, context=torch.cat(c_crossattn, 1))
    if mode == 'adm':
        return self.diffusion_model(x, t, y=c_crossattn[0])
    raise NotImplementedError()