from .misc_4ddpm import * from lmk_util.lmk_extractor import lmkAll_2_lmkMain, get_lmkMain_indices class DDPM(pl.LightningModule): # classic DDPM with Gaussian diffusion, in image space def __init__(self, unet_config, timesteps=1000, beta_schedule="linear", loss_type="l2", ckpt_path=None, ignore_keys=[], load_only_unet=False, monitor="val/loss", use_ema=True, first_stage_key="image", image_size=256, channels=3, log_every_t=100, clip_denoised=True, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, given_betas=None, original_elbo_weight=0., v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta l_simple_weight=1., conditioning_key=None, parameterization="eps", # all assuming fixed variance schedules scheduler_config=None, learn_logvar=False, logvar_init=0., u_cond_percent=0, ): super().__init__() assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' self.parameterization = parameterization print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") self.cond_stage_model = None self.clip_denoised = clip_denoised self.log_every_t = log_every_t self.first_stage_key = first_stage_key self.image_size = image_size self.channels = channels self.u_cond_percent=u_cond_percent unet_config['params']['in_channels'] = 14 if CH14 else 9 self.model = DiffusionWrapper(unet_config, conditioning_key) count_params(self.model, verbose=True) self.use_ema = use_ema if self.use_ema: self.model_ema = LitEma(self.model) print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") self.use_scheduler = scheduler_config is not None if self.use_scheduler: self.scheduler_config = scheduler_config self.v_posterior = v_posterior self.original_elbo_weight = original_elbo_weight self.l_simple_weight = l_simple_weight if monitor is not None: self.monitor = monitor if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) self.loss_type = loss_type self.learn_logvar = learn_logvar self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) if self.learn_logvar: self.logvar = nn.Parameter(self.logvar, requires_grad=True) def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): if exists(given_betas): betas = given_betas else: betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) alphas = 1. - betas alphas_cumprod = np.cumprod(alphas, axis=0) alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) timesteps, = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' to_torch = partial(torch.tensor, dtype=torch.float32) self.register_buffer('betas', to_torch(betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( 1. - alphas_cumprod) + self.v_posterior * betas # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) self.register_buffer('posterior_variance', to_torch(posterior_variance)) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) self.register_buffer('posterior_mean_coef1', to_torch( betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) self.register_buffer('posterior_mean_coef2', to_torch( (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) if self.parameterization == "eps": lvlb_weights = self.betas ** 2 / ( 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) elif self.parameterization == "x0": lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) else: raise NotImplementedError("mu not supported") # TODO how to choose this term lvlb_weights[0] = lvlb_weights[1] self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) assert not torch.isnan(self.lvlb_weights).all() @contextmanager def ema_scope(self, context=None): if self.use_ema: self.model_ema.store(self.model.parameters()) self.model_ema.copy_to(self.model) if context is not None: print(f"{context}: Switched to EMA weights") try: yield None finally: if self.use_ema: self.model_ema.restore(self.model.parameters()) if context is not None: print(f"{context}: Restored training weights") def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): assert 0 print("[init_from_ckpt]") sd = torch.load(path, map_location="cpu") if "state_dict" in list(sd.keys()): sd = sd["state_dict"] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if len(missing) > 0: print(f"Missing Keys: {missing}") if len(unexpected) > 0: print(f"Unexpected Keys: {unexpected}") def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) def get_loss(self, pred, target, mean=True): if self.loss_type == 'l1': loss = (target - pred).abs() if mean: loss = loss.mean() elif self.loss_type == 'l2': if mean: loss = torch.nn.functional.mse_loss(target, pred) else: loss = torch.nn.functional.mse_loss(target, pred, reduction='none') #--> else: raise NotImplementedError("unknown loss type '{loss_type}'") return loss def p_losses(self, x_start, t, noise=None): assert 0, 'This should not be called; subclasses override this method' noise = default(noise, lambda: torch.randn_like(x_start)) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) model_out = self.model(x_noisy, t) loss_dict = {} if self.parameterization == "eps": target = noise elif self.parameterization == "x0": target = x_start else: raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) # metrics.csv entries like 'train/...' and 'val/...' originate here log_prefix = 'train' if self.training else 'val' loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) loss_simple = loss.mean() * self.l_simple_weight loss_vlb = (self.lvlb_weights[t] * loss).mean() loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) loss = loss_simple + self.original_elbo_weight * loss_vlb loss_dict.update({f'{log_prefix}/loss': loss}) return loss, loss_dict def forward(self, x, *args, **kwargs): # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() return self.p_losses(x, t, *args, **kwargs) def shared_step(self, batch): assert 0 def set_task(self, batch): task = batch['task'][0].item() printC('task',f"{task=}") global_.task = task assert all(batch['task'] == task), batch['task'] self.task = task if 1: if (not USE_pts) or task==1: self.Landmark_cond=False else: self.Landmark_cond=True if 1: if task in (0,2,3,): self.Landmarks_weight=0.05 else: self.Landmarks_weight=0 self.STACK_feat=True return task def unset_task(self): global_.task = None global_.lmk_ = None del self.task def training_step(self, batch, batch_idx): task = batch['task'][0].item() opt = self.optimizers() if not self.Reconstruct_initial:# only MSE loss(orig diffusion). -> shared_step -> forward -> p_losses loss, loss_dict = self.shared_step(batch) # original else: # added Multistep (DDIM) loss -> shared_step_face -> forward_face -> p_losses_face loss, loss_dict = self.shared_step_face(batch) # changed by sanoojan : to add ID loss after reconstructing through DDIM step_or_accumulate = ( task==3 or TP_enable) _ctx = nullcontext if not step_or_accumulate and not TP_enable: _ctx = self.trainer.model.no_sync # https://github.com/Lightning-AI/pytorch-lightning/discussions/10792 with _ctx(): # https://zhuanlan.zhihu.com/p/250471767 self.manual_backward(loss) if (REFNET.ENABLE and REFNET.task2layerNum[task]>0): self.model.bank.clear() self.unset_task() total_step = len(self.trainer.train_dataloader) if step_or_accumulate: # Average grads of shared params across ranks (TaskParallel) if dist.is_available() and dist.is_initialized(): ws = dist.get_world_size() shared_sync_cnt = 0; task_skip_cnt = 0 for name, p in self.named_parameters(): need_sync, is_task_specific_skip = tp_param_need_sync(name, p) if not need_sync: if is_task_specific_skip: task_skip_cnt += 1 continue if p.grad is None: p.grad = torch.zeros_like(p) # ensure collective call sequence remains consistent dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) p.grad.div_(ws) shared_sync_cnt += 1 if gate_('[TP] shared sync counts'): print(f"synced={shared_sync_cnt} skipped(task)={task_skip_cnt}") torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0) opt.step() opt.zero_grad() if self.use_scheduler: # handle LR schedulers sch = self.lr_schedulers() if isinstance(sch, list) and len(sch) > 0: # schedulers expressed as a list for scheduler_config in sch: if isinstance(scheduler_config, dict) and 'scheduler' in scheduler_config: scheduler_config['scheduler'].step() else: scheduler_config.step() elif hasattr(sch, 'step'): sch.step() self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False) if self.use_scheduler: lr = self.optimizers().param_groups[0]['lr'] self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) return loss # manual optimization calls backward in training_step already, so this is skipped here # def backward( @torch.no_grad() def validation_step(self, batch, batch_idx): _, loss_dict_no_ema = self.shared_step(batch) with self.ema_scope(): _, loss_dict_ema = self.shared_step(batch) loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) self.unset_task() def on_train_batch_end(self, *args, **kwargs): if self.use_ema: self.model_ema(self.model) class LatentDiffusion(DDPM): """main class""" def __init__(self, first_stage_config, cond_stage_config, num_timesteps_cond=None, cond_stage_key="image", cond_stage_trainable=False, concat_mode=True, cond_stage_forward=None, conditioning_key=None, scale_factor=1.0, scale_by_std=False, *args, **kwargs): self.num_timesteps_cond = default(num_timesteps_cond, 1) self.scale_by_std = scale_by_std assert self.num_timesteps_cond <= kwargs['timesteps'] # for backwards compatibility after implementation of DiffusionWrapper if conditioning_key is None: conditioning_key = 'concat' if concat_mode else 'crossattn' if cond_stage_config == '__is_unconditional__': conditioning_key = None ckpt_path = kwargs.pop("ckpt_path", None) ignore_keys = kwargs.pop("ignore_keys", []) super().__init__(conditioning_key=conditioning_key, *args, **kwargs) self.automatic_optimization = False # disable automatic optimization to manage parameter updates manually # self.learnable_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=True) # breakpoint() self.concat_mode = concat_mode self.cond_stage_trainable = cond_stage_trainable self.cond_stage_key = cond_stage_key #check if other_params is present in cond_stage_config if hasattr(cond_stage_config, 'other_params'): self.clip_weight=cond_stage_config.other_params.clip_weight # those three weights: 0 skips module init, >0 enables it and acts as weight when !STACK_feat if set(TASKS) & {0,2,3}: self.ID_weight = 10.0 else: self.ID_weight = 0 if (not USE_pts) and TASKS==(1,): self.Landmark_cond=False else: self.Landmark_cond=True self.Landmarks_weight=0.05 if hasattr(cond_stage_config.other_params, 'Additional_config'): self.Reconstruct_initial=cond_stage_config.other_params.Additional_config.Reconstruct_initial self.Reconstruct_DDIM_steps=cond_stage_config.other_params.Additional_config.Reconstruct_DDIM_steps self.sampler=DDIMSampler(self) if hasattr(cond_stage_config.other_params, 'multi_scale_ID'): self.multi_scale_ID=cond_stage_config.other_params.multi_scale_ID # True has an issue else: self.multi_scale_ID=True #this has an issue obtaining earlier layer from ID if hasattr(cond_stage_config.other_params, 'normalize'): self.normalize=cond_stage_config.other_params.normalize # normalizes the combintaion of ID and LPIPS loss else: self.normalize=False if 1: self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval() if hasattr(cond_stage_config.other_params, 'partial_training'): self.partial_training=cond_stage_config.other_params.partial_training self.trainable_keys=cond_stage_config.other_params.trainable_keys else: self.partial_training=False if hasattr(cond_stage_config.other_params.Additional_config, 'Same_image_reconstruct'): self.Same_image_reconstruct=cond_stage_config.other_params.Additional_config.Same_image_reconstruct else: self.Same_image_reconstruct=False if hasattr(cond_stage_config.other_params.Additional_config, 'Target_CLIP_feat'): self.Target_CLIP_feat=cond_stage_config.other_params.Additional_config.Target_CLIP_feat else: self.Target_CLIP_feat=False if hasattr(cond_stage_config.other_params.Additional_config, 'Source_CLIP_feat'): self.Source_CLIP_feat=cond_stage_config.other_params.Additional_config.Source_CLIP_feat else: self.Source_CLIP_feat=False if hasattr(cond_stage_config.other_params.Additional_config, 'use_3dmm'): self.use_3dmm=cond_stage_config.other_params.Additional_config.use_3dmm else: self.use_3dmm=False else: self.Reconstruct_initial=False self.Reconstruct_DDIM_steps=0 self.update_weight=False else: assert 0 if 1: self.learnable_vector = nn.ParameterList([ nn.Parameter(torch.randn((1,259,768)), requires_grad=True), nn.Parameter(torch.randn((1,257,768)), requires_grad=True), nn.Parameter(torch.randn((1,259,768)), requires_grad=True), nn.Parameter(torch.randn((1,259,768)), requires_grad=True), ]) if self.ID_weight>0: if self.multi_scale_ID: self.ID_proj_out=nn.Linear(200704, 768) else: self.ID_proj_out=nn.Linear(512, 768) # yes self.instantiate_IDLoss(cond_stage_config) if self.Landmark_cond: if USE_pts: self.ptsM_Generator = LandmarkExtractor(include_visualizer=True,img_256_mode=False) else: self.detector = dlib.get_frontal_face_detector() self.predictor = dlib.shape_predictor("Other_dependencies/DLIB_landmark_det/shape_predictor_68_face_landmarks.dat") if self.Landmarks_weight>0: self.landmark_proj_out=nn.Linear(NUM_pts*2, 768) self.total_steps_in_epoch=0 # will be calculated inside training_step. Not known for now if 1: assert cond_stage_config.target=="ldm.modules.encoders.modules.FrozenCLIPEmbedder" and self.Source_CLIP_feat and self.Target_CLIP_feat self.USE_proj_out_source = 1 if set(TASKS) & {0,}: self.proj_out_source__face=nn.Linear(768, 768) if set(TASKS) & {1,}: self.proj_out_source__hair=nn.Linear(768, 768) if set(TASKS) & {2,3,}: self.proj_out_source__head=nn.Linear(768, 768) if 0: # dummy, just for compa self.proj_out_target=nn.Linear(768, 768) self.proj_out=nn.Identity() try: self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 except: self.num_downs = 0 if not scale_by_std: self.scale_factor = scale_factor else: self.register_buffer('scale_factor', torch.tensor(scale_factor)) self.instantiate_first_stage(first_stage_config) self.instantiate_cond_stage(cond_stage_config) self.cond_stage_forward = cond_stage_forward self.clip_denoised = False self.bbox_tokenizer = None self.restarted_from_ckpt = False if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys) self.restarted_from_ckpt = True def get_lmk_for_router(self, batch: dict, x_tensor: torch.Tensor): """ Prepare global_.lmk_ (BS, L, 2) normalized to [0,1] for gating/router. - Prefer cached Mediapipe landmarks if present in batch - Convert 468/478 to main landmarks with face oval using get_lmkMain_indices(True) - Fallback to zeros if not available """ b, _, H, W = x_tensor.shape if READ_mediapipe_result_from_cache and ('mediapipe_lmkAll' in batch): data_all = batch['mediapipe_lmkAll'] # tensor or ndarray if isinstance(data_all, torch.Tensor): lmks_all = data_all.to(x_tensor.device).to(x_tensor.dtype) else: lmks_all = torch.from_numpy(data_all).to(x_tensor.device).to(x_tensor.dtype) # map to main indices with face oval (cached tensor indices on device) idxs = getattr(global_, 'lmk_main_idx_tensor', None) if (idxs is None) or (idxs.device != x_tensor.device): idx_list = get_lmkMain_indices(include_face_oval=True) idxs = torch.as_tensor(list(idx_list), dtype=torch.long, device=x_tensor.device) global_.lmk_main_idx_tensor = idxs lmk = torch.index_select(lmks_all, dim=1, index=idxs) # normalize by current spatial size if lmk.numel() > 0: # print(f"0 {lmk[:,:5]=}") lmk[..., 0] = lmk[..., 0] / float(W) lmk[..., 1] = lmk[..., 1] / float(H) # print(f"1 {lmk[:,:5]=}") else: assert 0 lmk = torch.zeros((b, 0, 2), device=x_tensor.device, dtype=x_tensor.dtype) return lmk def make_cond_schedule(self, ): self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() self.cond_ids[:self.num_timesteps_cond] = ids @rank_zero_only @torch.no_grad() def on_train_batch_start(self, batch, batch_idx, dataloader_idx): # only for very first batch if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: assert 0 def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) self.shorten_cond_schedule = self.num_timesteps_cond > 1 if self.shorten_cond_schedule: self.make_cond_schedule() def instantiate_first_stage(self, config): model = instantiate_from_config(config) self.first_stage_model = model.eval() self.first_stage_model.train = disabled_train for param in self.first_stage_model.parameters(): param.requires_grad = False def instantiate_IDLoss(self, config): # Need to modify @sanoojan # if not self.cond_stage_trainable: model = IDLoss(config,multiscale=self.multi_scale_ID) self.face_ID_model = model.eval() self.face_ID_model.train = disabled_train for param in self.face_ID_model.parameters(): param.requires_grad = False def instantiate_cond_stage(self, config): if 1: assert config != '__is_first_stage__' assert config != '__is_unconditional__' model: FrozenCLIPEmbedder = instantiate_from_config(config) #ldm.modules.encoders.modules.FrozenCLIPEmbedder if 0 in TASKS: self.encoder_clip_face :FrozenCLIPEmbedder = model if 1 in TASKS: self.encoder_clip_hair :FrozenCLIPEmbedder = copy.deepcopy(model) del self.encoder_clip_hair.model del self.encoder_clip_hair.tokenizer if set(TASKS) & {2,}: self.encoder_clip_head_t2 :FrozenCLIPEmbedder = copy.deepcopy(model) del self.encoder_clip_head_t2.model del self.encoder_clip_head_t2.tokenizer if set(TASKS) & {3,}: self.encoder_clip_head_t3 :FrozenCLIPEmbedder = copy.deepcopy(model) del self.encoder_clip_head_t3.model del self.encoder_clip_head_t3.tokenizer def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): denoise_row = [] for zd in tqdm(samples, desc=desc): denoise_row.append(self.decode_first_stage(zd.to(self.device), force_not_quantize=force_no_decoder_quantization)) n_imgs_per_row = len(denoise_row) denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid def get_first_stage_encoding(self, encoder_posterior): if isinstance(encoder_posterior, DiagonalGaussianDistribution): z = encoder_posterior.sample() elif isinstance(encoder_posterior, torch.Tensor): z = encoder_posterior else: raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") return self.scale_factor * z def get_learned_conditioning(self, c): raise Exception def conditioning_with_feat(self,x,landmarks=None,enInputs:dict=None): if gate_('vis LatentDiffusion.conditioning_with_feat'): debug_dir = Path(f"4debug/conditioning_with_feat/{ID}"); debug_dir.mkdir(parents=0, exist_ok=True) all_images = [ ('x', x), ] for _name, _enInput in enInputs.items(): all_images.append((_name, _enInput)) vis_tensors_A(all_images, debug_dir / f"all-{str_t_pid()}.jpg", vis_batch_size= min(5, landmarks.shape[0]) ) del x # (x is GT during training, ref_imgs during inference) task = self.task ID_weight = self.ID_weight Landmarks_weight = self.Landmarks_weight if self.task==0: face_clip_weight = self.clip_weight elif self.task==1: hair_clip_weight = self.clip_weight elif self.task==2: head_clip_weight = self.clip_weight elif self.task==3: head_clip_weight = self.clip_weight if 1: cs = [] # conditionings ws = [] # weights corresponding one-to-one with cs def encode_face_ID(): _c = enInputs['face_ID-in'] _c=self.face_ID_model.extract_feats(_c)[0] _c = self.ID_proj_out(_c) #-->c:[4,768] _c = _c.unsqueeze(1) #-->c:[4,1,768] if self.normalize: #normalize c2 _c = _c*norm_coeff/F.normalize(_c, p=2, dim=2) cs.append(_c); ws.append(ID_weight) def encode_face_clip(_z=None):# _z: result of ViT forward pass if _z is None: _c = enInputs['face-clip-in'] _c = self.encoder_clip_face.encode(_c) #b,3,224,224 --> b,1,768 else: assert 0 _c = self.encoder_clip_face.encode_B(_z) if hasattr(self,'USE_proj_out_source') and self.USE_proj_out_source: _c = self.proj_out_source__face(_c) cs.append(_c); ws.append(face_clip_weight) def encode_hair_clip(_z=None): if _z is None: _c = enInputs['hair-clip-in'] _c = self.encoder_clip_hair.encode(_c) #b,3,224,224 --> b,1,768 else: _c = self.encoder_clip_hair.encode_B(_z) if hasattr(self,'USE_proj_out_source') and self.USE_proj_out_source: _c = self.proj_out_source__hair(_c) printC("hair _c.shape:",f"{_c.shape}") cs.append(_c); ws.append(hair_clip_weight) def encode_head_clip(_z=None): if global_.task == 2: encoder_clip_head = self.encoder_clip_head_t2 elif global_.task == 3: encoder_clip_head = self.encoder_clip_head_t3 else: raise ValueError(f"Task {global_.task} does not have encoder_clip_head") if _z is None: _c = enInputs['head-clip-in'] _c = encoder_clip_head.encode(_c) #b,3,224,224 --> b,1,768 else: _c = encoder_clip_head.encode_B(_z) if hasattr(self,'USE_proj_out_source') and self.USE_proj_out_source: _c = self.proj_out_source__head(_c) printC("head _c.shape:",f"{_c.shape}") cs.append(_c); ws.append(head_clip_weight) if task==0: encode_face_ID() encode_face_clip() elif task==1: _z = enInputs['hair-clip-in'] _z = self.encoder_clip_face.forward_vit(_z) encode_hair_clip(_z) elif task==2: encode_face_ID() _z = enInputs['head-clip-in'] _z = self.encoder_clip_face.forward_vit(_z) encode_head_clip(_z) elif task==3: encode_face_ID() _z = enInputs['head-clip-in'] _z = self.encoder_clip_face.forward_vit(_z) encode_head_clip(_z) c=0 if Landmarks_weight > 0: landmarks=landmarks.unsqueeze(1) if len(landmarks.shape)!=3 else landmarks cs.append(landmarks); ws.append(Landmarks_weight) if self.STACK_feat: # _Cc # stack all features conc=torch.cat(cs, dim=-2) c = conc else: total_weight = sum(ws) weighted_sum = sum(c * w for c, w in zip(cs, ws)) c = weighted_sum / total_weight if total_weight > 0 else 0 printC("[conditioning_with_feat return]",f"{custom_repr_v3(c)}") # assert c.shape[1]==NUM_token, c.shape return c def get_landmarks(self,x, batch:dict): if (self.Landmark_cond) and x is not None: # pass # # Detect faces in an image #convert to 8bit image x=255.0*un_norm(x).permute(0,2,3,1).cpu().numpy() x=x.astype(np.uint8) # B,512,512,3 Landmarks_all=[] if USE_pts: l_lmkAll=[] if READ_mediapipe_result_from_cache: _l_lmkAll :np.ndarray = batch['mediapipe_lmkAll'].cpu().numpy() bs = len(x) for i in range(len(x)): if USE_pts: if READ_mediapipe_result_from_cache: lmkAll :np.ndarray = _l_lmkAll[i] else: lmkAll :np.ndarray = self.ptsM_Generator.extract_single(x[i], only_main_lmk=False) if lmkAll is None: lmkAll = np.zeros((478,2)) l_lmkAll.append(lmkAll) lm = lmkAll_2_lmkMain(lmkAll) # NUM_pts,2 lm = lm.reshape(1, NUM_pts*2) # num of points * 2 coordinates Landmarks_all.append(lm) if 0: from util_vis import visualize_landmarks starter_stem = Path(sys.argv[0]).stem path_vis_lmk = f'4debug/vis_lmk/{starter_stem}-{i}.png' visualize_landmarks(x[i], lm[0], path_vis_lmk) print(f"{path_vis_lmk=}") Landmarks_all=np.concatenate(Landmarks_all,axis=0) pts68 = Landmarks_all.reshape(bs, NUM_pts, 2, ) if self.Landmarks_weight>0: Landmarks_all=torch.tensor(Landmarks_all).float().to(self.device) if self.Landmark_cond == False: return Landmarks_all with torch.enable_grad(): Landmarks_all=self.landmark_proj_out(Landmarks_all) # normalize Landmarks_all lmk_aux={} if USE_pts: lmk_aux['l_lmkAll'] = l_lmkAll return Landmarks_all,pts68,lmk_aux def meshgrid(self, h, w): y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) arr = torch.cat([y, x], dim=-1) return arr def delta_border(self, h, w): """ :param h: height :param w: width :return: normalized distance to image border, wtith min distance = 0 at border and max dist = 0.5 at image center """ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) arr = self.meshgrid(h, w) / lower_right_corner dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] return edge_dist def get_weighting(self, h, w, Ly, Lx, device): weighting = self.delta_border(h, w) weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], self.split_input_params["clip_max_weight"], ) weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) if self.split_input_params["tie_braker"]: L_weighting = self.delta_border(Ly, Lx) L_weighting = torch.clip(L_weighting, self.split_input_params["clip_min_tie_weight"], self.split_input_params["clip_max_tie_weight"]) L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) weighting = weighting * L_weighting return weighting def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code """ :param x: img of size (bs, c, h, w) :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) """ bs, nc, h, w = x.shape # number of crops in image Ly = (h - kernel_size[0]) // stride[0] + 1 Lx = (w - kernel_size[1]) // stride[1] + 1 if uf == 1 and df == 1: fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) unfold = torch.nn.Unfold(**fold_params) fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) elif uf > 1 and df == 1: fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) unfold = torch.nn.Unfold(**fold_params) fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), dilation=1, padding=0, stride=(stride[0] * uf, stride[1] * uf)) fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) elif df > 1 and uf == 1: fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) unfold = torch.nn.Unfold(**fold_params) fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), dilation=1, padding=0, stride=(stride[0] // df, stride[1] // df)) fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) else: raise NotImplementedError return fold, unfold, normalization, weighting # returned x is the concatenated multi-channel tensor (mask, ref, lmk, ...); e.g. "x_start[:,8,:,:]" extracts the mask @torch.no_grad() def get_input_(self, batch, k, return_first_stage_outputs=False, cond_key=None, bs=None, get_referenceZ=False, # reference image latent tensor, dims B,4,64,64 ): if k == "inpaint": # yes x = batch['GT'] mask = batch['inpaint_mask'].clone() # b,1,512,512 inpaint = batch['inpaint_image'].clone() # .clone so that batch['inpaint_image'] remains the original image without landmarks # reference = batch['ref_imgs'] reference = None else: assert 0 if len(x.shape) == 3: assert 0 x = x[..., None] if 1: enInputs = batch['enInputs'] # encoder inputs (each self.encoder receives these raw tensors without preprocessing) for k,v in enInputs.items(): enInputs[k] = v.to(memory_format=torch.contiguous_format).float() #-------------------------------------------------------------------------------- ref_imgs_4unet = batch.get('ref_imgs_4unet', None) if get_referenceZ else None #x : Original Image #inpaint : Masked original image #mask: mask #reference: Transformed(Masked(original image)) if bs is not None: assert 0 x = x.to(self.device) global_.lmk_ = self.get_lmk_for_router(batch, x) # for router/gate if self.Landmark_cond: landmarks, pts68, lmk_aux=self.get_landmarks(x,batch) else: landmarks=None if self.task in (0,2,3,) and USE_pts: mask_np = mask.detach().cpu().numpy() if 1: #convert to 8bit image x_unnorm=255.0*un_norm(x).permute(0,2,3,1).cpu().numpy() x_unnorm=x_unnorm.astype(np.uint8) # B,512,512,3 batch_size = x.shape[0] VIS_pts= 0 for b in range(batch_size): lmkAll = lmk_aux['l_lmkAll'][b] inpaint[b] = torch.Tensor(self.ptsM_Generator.visualizer.visualize_landmarks(inpaint[b].permute(1,2,0).detach().cpu().numpy(), lmkAll, ) ).permute(2,0,1) del lmkAll if self.training and gate_('vis LatentDiffusion.get_input'): debug_dir = Path(f"4debug/LatentDiffusion.get_input/{ID}"); debug_dir.mkdir(parents=0, exist_ok=True) vis_batch_size = min(5, x.shape[0]) # Show at most 4 samples all_images = [ ('x', x), ('inpaint', inpaint), ('mask', mask), ('reference', reference), ('ref_imgs_4unet', ref_imgs_4unet) ] for _name, _enInput in enInputs.items(): all_images.append((_name, _enInput)) all_path = debug_dir / f"all--after-pts-{str_t_pid()}.jpg" vis_tensors_A(all_images, all_path, vis_batch_size) encoder_posterior = self.encode_first_stage(x) z = self.get_first_stage_encoding(encoder_posterior).detach() encoder_posterior_inpaint = self.encode_first_stage(inpaint) z_inpaint = self.get_first_stage_encoding(encoder_posterior_inpaint).detach() # tgt/ref_mask_64 mask_resize = Resize([z.shape[-1],z.shape[-1]])(mask) ref_mask_64 = Resize([z.shape[-1],z.shape[-1]])(batch['ref_mask_512']) if 'ref_mask_512' in batch else None # z9 & z_ref if not CH14: z_new = torch.cat((z,z_inpaint,mask_resize),dim=1) # shape:[4,9,64,64] 9:4+4+1 if get_referenceZ: encoder_posterior_ref = self.encode_first_stage(ref_imgs_4unet) z_ref = self.get_first_stage_encoding(encoder_posterior_ref).detach() # shape:[4,4,64,64] else: z_ref = None if CH14: z_new = torch.cat((z,z_inpaint,mask_resize, z_ref,ref_mask_64),dim=1) assert z.shape[1:]==(4,64,64,) if gate_(f'vis LatentDiffusion.get_input-before_return {self.training}'): debug_dir = Path(f"4debug/LatentDiffusion.get_input-before_return/{ID}"); debug_dir.mkdir(parents=0, exist_ok=True) vis_batch_size = min(5, x.shape[0]) all_images = [ ('x', x), ('inpaint', inpaint), ('mask', mask), ('reference', reference), ('ref_imgs_4unet', ref_imgs_4unet), ('z4_gt',z[:,:3]),('z4_inpaint', z_inpaint[:,:3]),('tgt_mask_64', mask_resize),('z_ref',None if z_ref is None else z_ref[:,:3]),('ref_mask_64',ref_mask_64),] all_path = debug_dir / f"{str_t_pid()}.jpg" vis_tensors_A(all_images, all_path, vis_batch_size) if 1: assert self.model.conditioning_key is not None assert self.first_stage_key=='inpaint' assert self.cond_stage_key=='image' return { **batch, 'z9': z_new,# b,9/14,... 'z4_gt': z, 'z4_inpaint': z_inpaint, # 'tgt_mask_64': mask_resize, 'ref_mask_64': ref_mask_64, # 'z_ref': z_ref, # 'z_ref' is ambiguous but kept for legacy usage; hard-code the intended meaning # 'landmarks': landmarks, # projected features, not raw coordinates } @torch.no_grad() def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): if predict_cids: if z.dim() == 4: z = torch.argmax(z.exp(), dim=1).long() z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) z = rearrange(z, 'b h w c -> b c h w').contiguous() z = 1. / self.scale_factor * z if hasattr(self, "split_input_params"): if self.split_input_params["patch_distributed_vq"]: ks = self.split_input_params["ks"] # eg. (128, 128) stride = self.split_input_params["stride"] # eg. (64, 64) uf = self.split_input_params["vqf"] bs, nc, h, w = z.shape if ks[0] > h or ks[1] > w: ks = (min(ks[0], h), min(ks[1], w)) print("reducing Kernel") if stride[0] > h or stride[1] > w: stride = (min(stride[0], h), min(stride[1], w)) print("reducing stride") fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) z = unfold(z) # (bn, nc * prod(**ks), L) # 1. Reshape to img shape z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) # 2. apply model loop over last dim if isinstance(self.first_stage_model, VQModelInterface): output_list = [self.first_stage_model.decode(z[:, :, :, :, i], force_not_quantize=predict_cids or force_not_quantize) for i in range(z.shape[-1])] else: output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) for i in range(z.shape[-1])] o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) o = o * weighting # Reverse 1. reshape to img shape o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) # stitch crops together decoded = fold(o) decoded = decoded / normalization # norm is shape (1, 1, h, w) return decoded else: if isinstance(self.first_stage_model, VQModelInterface): return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) else: return self.first_stage_model.decode(z) else: if isinstance(self.first_stage_model, VQModelInterface): return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) else: if self.first_stage_key=='inpaint': return self.first_stage_model.decode(z[:,:4,:,:]) else: return self.first_stage_model.decode(z) # same as above but without decorator def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): if predict_cids: if z.dim() == 4: z = torch.argmax(z.exp(), dim=1).long() z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) z = rearrange(z, 'b h w c -> b c h w').contiguous() z = 1. / self.scale_factor * z if hasattr(self, "split_input_params"): if self.split_input_params["patch_distributed_vq"]: ks = self.split_input_params["ks"] # eg. (128, 128) stride = self.split_input_params["stride"] # eg. (64, 64) uf = self.split_input_params["vqf"] bs, nc, h, w = z.shape if ks[0] > h or ks[1] > w: ks = (min(ks[0], h), min(ks[1], w)) print("reducing Kernel") if stride[0] > h or stride[1] > w: stride = (min(stride[0], h), min(stride[1], w)) print("reducing stride") fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) z = unfold(z) # (bn, nc * prod(**ks), L) # 1. Reshape to img shape z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) # 2. apply model loop over last dim if isinstance(self.first_stage_model, VQModelInterface): output_list = [self.first_stage_model.decode(z[:, :, :, :, i], force_not_quantize=predict_cids or force_not_quantize) for i in range(z.shape[-1])] else: output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) for i in range(z.shape[-1])] o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) o = o * weighting # Reverse 1. reshape to img shape o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) # stitch crops together decoded = fold(o) decoded = decoded / normalization # norm is shape (1, 1, h, w) return decoded else: if isinstance(self.first_stage_model, VQModelInterface): return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) else: return self.first_stage_model.decode(z) else: if isinstance(self.first_stage_model, VQModelInterface): return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) else: return self.first_stage_model.decode(z) @torch.no_grad() def encode_first_stage(self, x): if hasattr(self, "split_input_params"): if self.split_input_params["patch_distributed_vq"]: ks = self.split_input_params["ks"] # eg. (128, 128) stride = self.split_input_params["stride"] # eg. (64, 64) df = self.split_input_params["vqf"] self.split_input_params['original_image_size'] = x.shape[-2:] bs, nc, h, w = x.shape if ks[0] > h or ks[1] > w: ks = (min(ks[0], h), min(ks[1], w)) print("reducing Kernel") if stride[0] > h or stride[1] > w: stride = (min(stride[0], h), min(stride[1], w)) print("reducing stride") fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) z = unfold(x) # (bn, nc * prod(**ks), L) # Reshape to img shape z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) for i in range(z.shape[-1])] o = torch.stack(output_list, axis=-1) o = o * weighting # Reverse reshape to img shape o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) # stitch crops together decoded = fold(o) decoded = decoded / normalization return decoded else: return self.first_stage_model.encode(x) else: return self.first_stage_model.encode(x) def get_input_and_conditioning(self,batch, device=None): if device is not None: batch = recursive_to(batch, device) #------------------------from shared_step------------------------- get_referenceZ=(REFNET.ENABLE and REFNET.task2layerNum[global_.task]>0) or CH14 batch = self.get_input_(batch, self.first_stage_key,get_referenceZ=get_referenceZ) #------------------------from shared_step -> forward------------------------- assert ( self.model.conditioning_key is not None ) and self.cond_stage_trainable c=self.conditioning_with_feat(batch['ref_imgs'],landmarks=batch['landmarks'],enInputs=batch['enInputs']) return batch,c def shared_step(self, batch, **kwargs): task = self.set_task(batch) if (REFNET.ENABLE and REFNET.task2layerNum[task]>0): self.model.bank.clear() batch, c = self.get_input_and_conditioning(batch) z9 = batch['z9'] z_ref = batch['z_ref'] gt512 = batch['GT'] gt256 = batch.get('GT256',None) # del batch loss = self(z9, c,z_ref=z_ref,gt512=gt512,gt256=gt256,task=task,batch=batch,) return loss def forward(self, x, c, *args, **kwargs): task = kwargs['task'] # c is the reference tensor; target shares the same shape t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() self.u_cond_prop=random.uniform(0, 1) if self.model.conditioning_key is not None: # assert c is not None if self.cond_stage_trainable: # yes pass if self.shorten_cond_schedule: # TODO: drop this option raise Exception tc = self.cond_ids[t].to(self.device) c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) if self.u_cond_propc_crossattn cond = {key: cond} if hasattr(self, "split_input_params"): assert 0,'This branch should not execute in practice' assert len(cond) == 1 # todo can only deal with one conditioning atm assert not return_ids ks = self.split_input_params["ks"] # eg. (128, 128) stride = self.split_input_params["stride"] # eg. (64, 64) h, w = x_noisy.shape[-2:] fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) z = unfold(x_noisy) # (bn, nc * prod(**ks), L) # Reshape to img shape z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] if self.cond_stage_key in ["image", "LR_image", "segmentation", 'bbox_img'] and self.model.conditioning_key: # todo check for completeness c_key = next(iter(cond.keys())) # get key c = next(iter(cond.values())) # get value assert (len(c) == 1) # todo extend to list with more than one elem c = c[0] # get element c = unfold(c) c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] elif self.cond_stage_key == 'coordinates_bbox': assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' # assuming padding of unfold is always 0 and its dilation is always 1 n_patches_per_row = int((w - ks[0]) / stride[0] + 1) full_img_h, full_img_w = self.split_input_params['original_image_size'] # as we are operating on latents, we need the factor from the original image size to the # spatial latent size to properly rescale the crops for regenerating the bbox annotations num_downs = self.first_stage_model.encoder.num_resolutions - 1 rescale_latent = 2 ** (num_downs) # get top left positions of patches as conforming for the bbbox tokenizer, therefore we # need to rescale the tl patch coordinates to be in between (0,1) tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) for patch_nr in range(z.shape[-1])] # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) patch_limits = [(x_tl, y_tl, rescale_latent * ks[0] / full_img_w, rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] # tokenize crop coordinates for the bounding boxes of the respective patches patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) for bbox in patch_limits] # list of length l with tensors of shape (1, 2) print(patch_limits_tknzd[0].shape) # cut tknzd crop position from conditioning assert isinstance(cond, dict), 'cond must be dict to be fed into model' cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') adapted_cond = self.get_learned_conditioning(adapted_cond) adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) cond_list = [{'c_crossattn': [e]} for e in adapted_cond] else: cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient # apply model by loop over crops output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] assert not isinstance(output_list[0], tuple) # todo cant deal with multiple model outputs check this never happens o = torch.stack(output_list, axis=-1) o = o * weighting # Reverse reshape to img shape o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) # stitch crops together x_recon = fold(o) / normalization else: x_recon = self.model(x_noisy, t, **cond, return_features=return_features, z_ref=z_ref, task=self.task, _trainer=self.trainer, ) if return_features: return x_recon if isinstance(x_recon, tuple) and not return_ids: return x_recon[0] else: return x_recon def p_losses(self, x_start, cond, t, noise=None, z_ref=None, gt512=None, gt256=None, task=None, batch :dict = None, ): # def p_losses_face(self, x_start, cond, t, reference=None,noise=None,GT_tar=None,landmarks=None): # initialize MoE auxiliary loss to 0 to allow unconditional accumulation later global_.moe_aux_loss = torch.tensor(0.0, device=self.device) if self.first_stage_key == 'inpaint': # x_start=x_start[:,:4,:,:] noise = default(noise, lambda: torch.randn_like(x_start[:,:4,:,:])) if 1: x_noisy = self.q_sample(x_start=x_start[:,:4,:,:], t=t, noise=noise) x_noisy = torch.cat((x_noisy,x_start[:,4:,:,:]),dim=1) else: noise = default(noise, lambda: torch.randn_like(x_start)) if 1: x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) if z_ref is not None: assert self.first_stage_key == 'inpaint', 'Expected first_stage_key to be "inpaint"' """ z_ref: b,4,... z_ref = concat [z_ref_noisy, z_ref, tensor_1c] tensor_1c is temporarily set to all zeros """ z_ref_noisy = self.q_sample(x_start=z_ref, t=t, noise=torch.randn_like(z_ref)) tensor_1c = torch.zeros((z_ref.shape[0], 1, z_ref.shape[2], z_ref.shape[3]), device=z_ref.device) if REFNET.CH9: z_ref = torch.cat([z_ref_noisy, z_ref, tensor_1c], dim=1) if 1: model_output = self.apply_model(x_noisy, t, cond, z_ref=z_ref, ) loss_dict = {} prefix = 'train' if self.training else 'val' if DDIM_losses: ######################## t_new = torch.randint(self.num_timesteps-1, self.num_timesteps, (x_start.shape[0],), device=self.device).long().to(self.device) # t_new=torch.tensor(t_new).to(self.device) # noise_rec = default(noise, lambda: torch.randn_like(x_start[:,:4,:,:])) x_noisy_rec = self.q_sample(x_start=x_start[:,:4,:,:], t=t_new, noise=noise) x_noisy_rec = torch.cat((x_noisy_rec,x_start[:,4:,:,:]),dim=1) ddim_steps=self.Reconstruct_DDIM_steps n_samples=x_noisy_rec.shape[0] shape=(4,64,64) scale=5 ddim_eta=0.0 start_code=x_noisy_rec test_model_kwargs=None # t=t samples_ddim, sample_intermediates = self.sampler.sample_train(S=ddim_steps, # 4 (from Reconstruct_DDIM_steps in trian.yaml) conditioning=cond, batch_size=n_samples, shape=shape, verbose=False, unconditional_guidance_scale=scale, unconditional_conditioning=None, eta=ddim_eta, x_T=start_code, t=t_new, z_ref=z_ref, test_model_kwargs=test_model_kwargs) # x_samples_ddim= self.differentiable_decode_first_stage(samples_ddim) other_pred_x_0=sample_intermediates['pred_x0'] len_inter = len(other_pred_x_0) printC("len_inter", len_inter ) for i in range(len(other_pred_x_0)): other_pred_x_0[i]=self.differentiable_decode_first_stage(other_pred_x_0[i]) # x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) # x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() ########################################### ID_loss=0 clip_loss=0 loss_lpips=0 loss_rec=0 loss_landmark=0 # model_output=samples_ddim if 1: # x_samples_ddim=TF.resize(x_samples_ddim,(256,256)) if 0: inpaint_mask_64 = x_start[:,8,:,:] # inpaint region is 1, background is 0; shape b,64,64 masks=TF.resize(inpaint_mask_64,(other_pred_x_0[0].shape[2],other_pred_x_0[0].shape[3])) # b,512,512 if not 1: masks = 1 - masks #mask x_samples_ddim x_samples_ddim_masked=[x_samples_ddim_preds*masks.unsqueeze(1) for x_samples_ddim_preds in other_pred_x_0] # x_samples_ddim_masked=un_norm_clip(x_samples_ddim_masked) # x_samples_ddim_masked = TF.normalize(x_samples_ddim_masked, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) else: x_samples_ddim_masked = other_pred_x_0 Landmark_loss_weight = 0 ID_loss_weight = [0.3, 0, 0.1, 0.2, ][task] if ID_loss_weight > 0 : ID_Losses=[] for step,x_samples_ddim_preds in enumerate(x_samples_ddim_masked): ID_loss,sim_imp,_=self.face_ID_model(x_samples_ddim_preds,gt512,clip_img=False) ID_Losses.append(ID_loss) loss_dict.update({f'{prefix}/ID_loss_{step}': ID_loss}) ID_loss=torch.mean(torch.stack(ID_Losses)) loss_dict.update({f'{prefix}/ID_loss': ID_loss}) loss_dict.update({f'{prefix}/sim_imp': sim_imp}) CLIP_loss_weight = [1.5/4, 0.8, 1, 0.5, ][task] if CLIP_loss_weight > 0 : def _loss(_img1,_img2): _e1 = self.encoder_clip_face.forward_vit(_img1,resize=True) _e2 = self.encoder_clip_face.forward_vit(_img2,resize=True) return torch.nn.functional.mse_loss( _e1, _e2 ) clip_Losses=[] for step,x_samples_ddim_preds in enumerate(x_samples_ddim_masked): clip_loss = _loss(x_samples_ddim_preds,gt512) clip_Losses.append(clip_loss) loss_dict.update({f'{prefix}/clip_loss_{step}': clip_loss}) clip_loss=torch.mean(torch.stack(clip_Losses)) loss_dict.update({f'{prefix}/clip_loss': clip_loss}) LPIPS_loss_weight = [0.05, 0.015, 0.015, 0.015, ][task] if LPIPS_loss_weight>0: if gt256 is not None: _lpips_base_size = 256 _gt_for_lpips = gt256 else: _lpips_base_size = 512 _gt_for_lpips = gt512 for j in range(len(other_pred_x_0)): for i in range(3): _size = _lpips_base_size//2**i _pred_for_lpips = F.adaptive_avg_pool2d(other_pred_x_0[j],(_size,_size)) _gt_for_lpips_resized = F.adaptive_avg_pool2d(_gt_for_lpips,(_size,_size)) loss_lpips_1 = self.lpips_loss( _pred_for_lpips, _gt_for_lpips_resized, ) loss_dict.update({f'{prefix}/loss_lpips_{j}_{i}': loss_lpips_1}) printC(f"loss_lpips_1 at {j} {i} :", loss_lpips_1) loss_lpips += loss_lpips_1 loss_dict.update({f'{prefix}/loss_lpips': loss_lpips}) REC_loss_weight = [0.05, 0.01, 0.01, 0.01, ][task] if REC_loss_weight > 0 : # rec loss for j in range(len(other_pred_x_0)): loss_rec_1 = torch.nn.functional.mse_loss( other_pred_x_0[j], gt512) loss_dict.update({f'{prefix}/loss_rec_{j}': loss_rec_1}) printC(f"loss_rec_1 at {j} :", loss_rec_1) loss_rec += loss_rec_1 loss_dict.update({f'{prefix}/loss_rec': loss_rec}) if 1: if self.parameterization == "x0": target = x_start elif self.parameterization == "eps": target = noise else: raise NotImplementedError() # this should be an MSE loss loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) loss_dict.update({f'{prefix}/loss_simple-t{task}': loss_simple.mean()}) self.logvar = self.logvar.to(self.device) logvar_t = self.logvar[t].to(self.device) loss = loss_simple / torch.exp(logvar_t) + logvar_t # loss = loss_simple / torch.exp(self.logvar) + self.logvar if self.learn_logvar: loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) loss_dict.update({'logvar': self.logvar.data.mean()}) loss = self.l_simple_weight * loss.mean() loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) #?? loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) loss_dict.update({f'{prefix}/loss_vlb-t{task}': loss_vlb}) loss += (self.original_elbo_weight * loss_vlb) else: loss = 0 if DDIM_losses: _item = lambda _a: _a.detach().cpu().item() if isinstance(_a,torch.Tensor) else _a printC("orig, ID clip, lpips rec lmk:", f"{_item(loss):.4f}, {_item(ID_loss):.4f} {_item(clip_loss):.4f}, {_item(loss_lpips):.4f} {_item(loss_rec):.4f} {_item(loss_landmark):.4f}", f"{ID_Losses=}" if ID_loss_weight>0 else "", f"{clip_Losses=}" if CLIP_loss_weight>0 else "", ) loss+=ID_loss_weight*ID_loss+LPIPS_loss_weight*loss_lpips+Landmark_loss_weight*loss_landmark+REC_loss_weight*loss_rec+CLIP_loss_weight*clip_loss # incorporate MoE auxiliary loss moe_aux = global_.moe_aux_loss if isinstance(moe_aux, torch.Tensor): loss = loss + moe_aux loss_dict.update({f'{prefix}/moe_aux_loss': moe_aux}) loss_dict.update({f'{prefix}/loss': loss}) loss_dict.update({f'{prefix}/loss-t{task}': loss}) return loss, loss_dict def configure_optimizers(self): lr = self.learning_rate params = list(self.model.parameters()) if self.partial_training:# no # if True: print("Partial training.............................") train_names=self.trainable_keys train_names=[ 'attn2','norm2'] params_train=[] for name,param in self.model.named_parameters(): if "diffusion_model" not in name and param.requires_grad: print(name) params_train.append(param) elif "diffusion_model" in name and any(train_name in name for train_name in train_names): print(name) params_train.append(param) params=params_train print("Setting up Adam optimizer.......................") if self.cond_stage_trainable:# yes print(f"{self.__class__.__name__}: Also optimizing conditioner params!") if hasattr(self,'encoder_clip_face'): params += list(self.encoder_clip_face.final_ln2.parameters())+list(self.encoder_clip_face.mapper2.parameters()) if self.USE_proj_out_source: params += list(self.proj_out_source__face.parameters()) if hasattr(self,'encoder_clip_hair'): params += list(self.encoder_clip_hair.final_ln2.parameters())+list(self.encoder_clip_hair.mapper2.parameters()) if self.USE_proj_out_source: params += list(self.proj_out_source__hair.parameters()) if hasattr(self,'encoder_clip_head_t2'): params += list(self.encoder_clip_head_t2.final_ln2.parameters())+list(self.encoder_clip_head_t2.mapper2.parameters()) if hasattr(self,'encoder_clip_head_t3'): params += list(self.encoder_clip_head_t3.final_ln2.parameters())+list(self.encoder_clip_head_t3.mapper2.parameters()) if hasattr(self,'encoder_clip_head_t2') or hasattr(self,'encoder_clip_head_t3'): if self.USE_proj_out_source: params += list(self.proj_out_source__head.parameters()) if hasattr(self,'ID_proj_out'): params += list(self.ID_proj_out.parameters()) if hasattr(self,'landmark_proj_out'): # fixLmkProj params += list(self.landmark_proj_out.parameters()) if self.learn_logvar: print('Diffusion model optimizing logvar') params.append(self.logvar) params.extend(self.learnable_vector) params = [p for p in params if p.requires_grad] # Build param groups: MoE gate/expert use larger LR. # Also apply per-task LR factor to all task-specific params. # only match MoE-related parameter names generated by the UNet wrappers moe_gate_ids = set() moe_ep_ids = set() for name, p in self.model.named_parameters(): if not p.requires_grad: continue if ".moe_gate_mlp." in name: moe_gate_ids.add(id(p)) elif ".moe_experts_" in name: moe_ep_ids.add(id(p)) params_ids = set(id(p) for p in params) task_specific_ids = set() for name, p in self.named_parameters(): if not p.requires_grad: continue if id(p) not in params_ids: continue is_task_specific = is_task_specific_(name) if rank_==0: print(f"{is_task_specific=} {name}") if is_task_specific: task_specific_ids.add(id(p)) base_params = [] task_specific_params = [] moe_gate_params = [] moe_ep_params = [] for p in params: pid = id(p) if pid in task_specific_ids: task_specific_params.append(p) elif pid in moe_gate_ids: moe_gate_params.append(p) elif pid in moe_ep_ids: moe_ep_params.append(p) else: base_params.append(p) param_groups = [] if base_params: param_groups.append({"params": base_params, "lr": lr}) if task_specific_params: param_groups.append({"params": task_specific_params, "lr": lr * LR_factor}) if moe_gate_params: param_groups.append({"params": moe_gate_params, "lr": lr * MOE_GATE_LR_MULT}) if moe_ep_params: param_groups.append({"params": moe_ep_params, "lr": lr * MOE_EP_LR_MULT}) if ZERO1_ENABLE: zero_pg = None if 1: if dist.is_available() and dist.is_initialized(): zero_pg = dist.new_group(backend='gloo') opt = ZeroRedundancyOptimizer( param_groups if (task_specific_params or moe_gate_params or moe_ep_params) else params, optimizer_class=torch.optim.AdamW if ADAM_or_SGD else torch.optim.SGD, lr=lr, process_group=zero_pg, ) else: if ADAM_or_SGD: opt = torch.optim.AdamW(param_groups if (task_specific_params or moe_gate_params or moe_ep_params) else params, lr=lr) else: opt = torch.optim.SGD(param_groups if (task_specific_params or moe_gate_params or moe_ep_params) else params, lr=lr, momentum=0.9) if gate_('LatentDiffusion.configure_optimizers params:'): if (task_specific_params or moe_gate_params or moe_ep_params): print(f"base/task_specific/ep/gate lens: {len(base_params)=} {len(task_specific_params)=} {len(moe_ep_params)=} {len(moe_gate_params)=}") print(f"sum of .numel(): base={sum(p.numel() for p in base_params)} task_specific={sum(p.numel() for p in task_specific_params)} ep={sum(p.numel() for p in moe_ep_params)} gate={sum(p.numel() for p in moe_gate_params)}") else: print(f"{len(params)=}") print(f"sum of .numel(): {sum(param.numel() for param in params)}") if self.use_scheduler:# yes assert 'target' in self.scheduler_config scheduler = instantiate_from_config(self.scheduler_config) print("Setting up LambdaLR scheduler...") scheduler = [ { 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1 }] return [opt], scheduler return opt def on_train_epoch_start(self): def _set_req_grad(p, flag): if p.requires_grad != flag: p.requires_grad = flag return 1 return 0 return if 0: train_now = self.current_epoch < N_EPOCHS_TRAIN_REF_AND_MID else: # alternating freezing train_now = (self.current_epoch % 2 == 0) ct_toggled = 0 # 1) freeze all shared if not train_now; unfreeze when train_now ct_shared = 0 for name, p in self.model.diffusion_model.named_parameters(): # target only the shared weights inside Shared+LoRA wrappers: FFN.shared_ffn.* and Conv.shared.* is_shared = ('.shared_ffn.' in name) or ('.shared.' in name) if is_shared: ct_shared += _set_req_grad(p, train_now) print(f"[freeze@epoch]{self.current_epoch=} {train_now=} {ct_toggled=} {ct_shared=}") @torch.no_grad() def to_rgb(self, x): x = x.float() if not hasattr(self, "colorize"): self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) x = nn.functional.conv2d(x, weight=self.colorize) x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. return x def __repr__(self): if DEBUG: return 'LatentDiffusion.__repr__' return super().__repr__() @property def model_size(self): if DEBUG: return -1 return super().model_size from .bank import Bank class DiffusionWrapper(pl.LightningModule): def __init__(self, diff_model_config, conditioning_key): super().__init__() diff_model_config['params']['is_refNet'] = False self.diffusion_model = instantiate_from_config(diff_model_config) self.conditioning_key = conditioning_key assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] if REFNET.ENABLE: diff_model_config_refNet = diff_model_config print('instantiate / deepcopy diffusion_model_refNet ing...') if 1: diff_model_config_refNet['params']['in_channels'] = 9 if REFNET.CH9 else 4 diff_model_config_refNet['params']['is_refNet'] = True self.diffusion_model_refNet :UNetModel = instantiate_from_config(diff_model_config_refNet) else: self.diffusion_model_refNet :UNetModel = copy.deepcopy(self.diffusion_model) # faster than re-instantiating self.diffusion_model_refNet.is_refNet = True if 1: # print(f"before del: {len(self.diffusion_model_refNet.input_blocks)=}") if 1: self.diffusion_model_refNet.input_blocks = self.diffusion_model_refNet.input_blocks[:9] del self.diffusion_model_refNet.middle_block del self.diffusion_model_refNet.output_blocks del self.diffusion_model_refNet.out print('over.') # Keep only a single diffusion_model_refNet; no t-suffixed clones def forward(self, x, t, c_concat: list = None, c_crossattn: list = None,return_features=False, z_ref=None, task = None, _trainer :pl.Trainer = None, ): _in_train_or_val = ( _trainer is not None ) and ( _trainer.validating or _trainer.sanity_checking ) # indicates train or validation state assert self.conditioning_key == 'crossattn' if self.conditioning_key is None: out = self.diffusion_model(x, t) elif self.conditioning_key == 'concat': xc = torch.cat([x] + c_concat, dim=1) out = self.diffusion_model(xc, t) elif self.conditioning_key == 'crossattn': cc = torch.cat(c_crossattn, 1) #-->cc.shape = (bs, 1, 768) ## adding return_features here only for testing if (REFNET.ENABLE and REFNET.task2layerNum[task]>0): if task in (0,2,3,): cc_ref = cc[:,:-1, :] else: cc_ref = cc printC("c for refNet",f"{custom_repr_v3(cc_ref)}") self.diffusion_model_refNet(z_ref, t, context=cc_ref,return_features=False) out = self.diffusion_model(x, t, context=cc,return_features=return_features) if (REFNET.ENABLE and REFNET.task2layerNum[task]>0) and not (self.training or _in_train_or_val): # if 1: self.bank.clear() elif self.conditioning_key == 'hybrid': xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(xc, t, context=cc) elif self.conditioning_key == 'adm': cc = c_crossattn[0] out = self.diffusion_model(x, t, y=cc) else: raise NotImplementedError() return out #-->out.shape = (bs, 4,64,64)