|
|
| from .misc_4ddpm import * |
| from lmk_util.lmk_extractor import lmkAll_2_lmkMain, get_lmkMain_indices |
|
|
class DDPM(pl.LightningModule):
    """Base denoising-diffusion (DDPM) LightningModule.

    Owns the UNet wrapper, the beta/alpha noise-schedule buffers, the
    optional EMA shadow weights, and the simple/VLB loss bookkeeping.
    Subclasses (e.g. LatentDiffusion below) override ``shared_step`` /
    ``p_losses``.
    """

    def __init__(self,
                 unet_config,
                 timesteps=1000,
                 beta_schedule="linear",
                 loss_type="l2",
                 ckpt_path=None,
                 ignore_keys=None,
                 load_only_unet=False,
                 monitor="val/loss",
                 use_ema=True,
                 first_stage_key="image",
                 image_size=256,
                 channels=3,
                 log_every_t=100,
                 clip_denoised=True,
                 linear_start=1e-4,
                 linear_end=2e-2,
                 cosine_s=8e-3,
                 given_betas=None,
                 original_elbo_weight=0.,
                 v_posterior=0.,
                 l_simple_weight=1.,
                 conditioning_key=None,
                 parameterization="eps",
                 scheduler_config=None,
                 learn_logvar=False,
                 logvar_init=0.,
                 u_cond_percent=0,
                 ):
        super().__init__()
        # FIX: avoid the mutable-default-argument anti-pattern; the previous
        # default was ``ignore_keys=[]``.  Behavior for all callers unchanged.
        if ignore_keys is None:
            ignore_keys = []
        assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.first_stage_key = first_stage_key
        self.image_size = image_size
        self.channels = channels
        self.u_cond_percent = u_cond_percent
        # CH14 is a module-level flag (from misc_4ddpm) selecting the
        # 14-channel UNet input variant over the 9-channel one — the config
        # value is overridden here regardless of what the yaml said.
        unet_config['params']['in_channels'] = 14 if CH14 else 9
        self.model = DiffusionWrapper(unet_config, conditioning_key)
        count_params(self.model, verbose=True)
        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.use_scheduler = scheduler_config is not None
        if self.use_scheduler:
            self.scheduler_config = scheduler_config

        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight

        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)

        # registers num_timesteps and all schedule buffers; must run before
        # logvar below, which is sized by self.num_timesteps
        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)

        self.loss_type = loss_type

        self.learn_logvar = learn_logvar
        # per-timestep log-variance; promoted to a trainable Parameter if requested
        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
        if self.learn_logvar:
            self.logvar = nn.Parameter(self.logvar, requires_grad=True)
|
|
|
    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Build and register every diffusion-schedule buffer.

        Uses ``given_betas`` verbatim when provided, otherwise builds a
        schedule via ``make_beta_schedule``.  Registers betas, cumulative
        alphas, posterior coefficients and per-timestep VLB weights as
        float32 buffers so they follow the module's device/dtype.
        """
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                       cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # alpha-bar_{t-1}, with the convention alpha-bar_{-1} = 1
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        # the authoritative number of timesteps is the length of betas
        # (may differ from the `timesteps` argument when given_betas is used)
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # quantities needed for q(x_t | x_0) and for recovering x_0 from eps
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # posterior q(x_{t-1} | x_t, x_0); v_posterior interpolates between
        # the analytic variance (v=0) and beta_t itself (v=1)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # clipped because the posterior variance is 0 at t=0
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            # NOTE(review): mixes np.sqrt with torch.Tensor (works: np.sqrt on a
            # tensor returns a tensor) — inherited from upstream LDM; verify
            # before relying on the x0 branch.
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
        else:
            raise NotImplementedError("mu not supported")
        # weight at t=0 is infinite/NaN for eps-parameterization; copy t=1's
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        # NOTE(review): `.all()` only rejects an all-NaN tensor; `.any()` would
        # be the stricter check — same quirk exists upstream, confirm intent.
        assert not torch.isnan(self.lvlb_weights).all()
|
|
| @contextmanager |
| def ema_scope(self, context=None): |
| if self.use_ema: |
| self.model_ema.store(self.model.parameters()) |
| self.model_ema.copy_to(self.model) |
| if context is not None: |
| print(f"{context}: Switched to EMA weights") |
| try: |
| yield None |
| finally: |
| if self.use_ema: |
| self.model_ema.restore(self.model.parameters()) |
| if context is not None: |
| print(f"{context}: Restored training weights") |
|
|
    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        """Load weights from ``path``, dropping keys that match ``ignore_keys``.

        NOTE(review): the leading ``assert 0`` deliberately disables this
        method in this fork — everything below it is currently dead code.
        Remove the assert to re-enable checkpoint restoration.
        """
        assert 0
        print("[init_from_ckpt]")
        sd = torch.load(path, map_location="cpu")
        # Lightning checkpoints nest the weights under "state_dict"
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        # only_model=True restores just the UNet wrapper, not the full module
        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")
|
|
| def q_sample(self, x_start, t, noise=None): |
| noise = default(noise, lambda: torch.randn_like(x_start)) |
| return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + |
| extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) |
|
|
| def get_loss(self, pred, target, mean=True): |
| if self.loss_type == 'l1': |
| loss = (target - pred).abs() |
| if mean: |
| loss = loss.mean() |
| elif self.loss_type == 'l2': |
| if mean: |
| loss = torch.nn.functional.mse_loss(target, pred) |
| else: |
| loss = torch.nn.functional.mse_loss(target, pred, reduction='none') |
| else: |
| raise NotImplementedError("unknown loss type '{loss_type}'") |
|
|
| return loss |
|
|
    def p_losses(self, x_start, t, noise=None):
        """Compute the diffusion training loss for a batch at timesteps ``t``.

        NOTE(review): the leading assert deliberately disables this base
        implementation — subclasses are expected to override it, so the code
        below is currently dead but kept as the reference recipe.
        """
        assert 0, 'This should not be called; subclasses override this method'
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, t)

        loss_dict = {}
        # target depends on what the network is trained to predict
        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        else:
            raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")

        # per-sample loss: mean over channel/spatial dims only
        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])

        log_prefix = 'train' if self.training else 'val'

        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
        loss_simple = loss.mean() * self.l_simple_weight

        # variational-lower-bound term, weighted per timestep
        loss_vlb = (self.lvlb_weights[t] * loss).mean()
        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})

        loss = loss_simple + self.original_elbo_weight * loss_vlb

        loss_dict.update({f'{log_prefix}/loss': loss})

        return loss, loss_dict
|
|
| def forward(self, x, *args, **kwargs): |
| |
| |
| t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() |
| return self.p_losses(x, t, *args, **kwargs) |
|
|
|
|
    def shared_step(self, batch):
        # Abstract hook: concrete subclasses (e.g. LatentDiffusion) must
        # override this with the real per-batch loss computation.
        assert 0
|
|
| |
| def set_task(self, batch): |
| task = batch['task'][0].item() |
| printC('task',f"{task=}") |
| global_.task = task |
| assert all(batch['task'] == task), batch['task'] |
| self.task = task |
| if 1: |
| if (not USE_pts) or task==1: self.Landmark_cond=False |
| else: self.Landmark_cond=True |
| if 1: |
| if task in (0,2,3,): |
| self.Landmarks_weight=0.05 |
| else: |
| self.Landmarks_weight=0 |
| self.STACK_feat=True |
| return task |
| def unset_task(self): |
| global_.task = None |
| global_.lmk_ = None |
| del self.task |
    def training_step(self, batch, batch_idx):
        """One manual-optimization training step with gradient accumulation.

        Backward runs every step; the optimizer only steps when
        ``step_or_accumulate`` is true (task 3, or tensor-parallel mode).
        Under DDP, accumulation steps skip gradient sync (``no_sync``); in
        TP mode shared parameters are all-reduced by hand at step time.
        """
        task = batch['task'][0].item()
        opt = self.optimizers()

        if not self.Reconstruct_initial:
            loss, loss_dict = self.shared_step(batch)
        else:
            loss, loss_dict = self.shared_step_face(batch)

        # step now, or keep accumulating gradients for later micro-batches
        step_or_accumulate = ( task==3 or TP_enable)
        _ctx = nullcontext
        if not step_or_accumulate and not TP_enable:
            # accumulation step under DDP: suppress the gradient all-reduce
            _ctx = self.trainer.model.no_sync
        with _ctx():
            self.manual_backward(loss)

        if (REFNET.ENABLE and REFNET.task2layerNum[task]>0):
            # drop reference features cached during this forward pass
            self.model.bank.clear()
        self.unset_task()

        total_step = len(self.trainer.train_dataloader)
        if step_or_accumulate:
            # TP mode: manually average gradients of shared params across ranks
            if dist.is_available() and dist.is_initialized():
                ws = dist.get_world_size()
                shared_sync_cnt = 0; task_skip_cnt = 0
                for name, p in self.named_parameters():
                    need_sync, is_task_specific_skip = tp_param_need_sync(name, p)
                    if not need_sync:
                        if is_task_specific_skip:
                            task_skip_cnt += 1
                        continue
                    if p.grad is None:
                        # param untouched this step: contribute zeros so all
                        # ranks participate in the collective
                        p.grad = torch.zeros_like(p)
                    dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
                    p.grad.div_(ws)
                    shared_sync_cnt += 1
                if gate_('[TP] shared sync counts'):
                    print(f"synced={shared_sync_cnt} skipped(task)={task_skip_cnt}")
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
            opt.step()
            opt.zero_grad()
            if self.use_scheduler:
                # manual optimization: step LR scheduler(s) ourselves;
                # handles single scheduler, list, and Lightning dict configs
                sch = self.lr_schedulers()
                if isinstance(sch, list) and len(sch) > 0:
                    for scheduler_config in sch:
                        if isinstance(scheduler_config, dict) and 'scheduler' in scheduler_config:
                            scheduler_config['scheduler'].step()
                        else:
                            scheduler_config.step()
                elif hasattr(sch, 'step'):
                    sch.step()
        self.log_dict(loss_dict, prog_bar=True,
                      logger=True, on_step=True, on_epoch=True)

        self.log("global_step", self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)

        if self.use_scheduler:
            lr = self.optimizers().param_groups[0]['lr']
            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)

        return loss
| |
| |
|
|
| @torch.no_grad() |
| def validation_step(self, batch, batch_idx): |
| _, loss_dict_no_ema = self.shared_step(batch) |
| with self.ema_scope(): |
| _, loss_dict_ema = self.shared_step(batch) |
| loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} |
| self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) |
| self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) |
| self.unset_task() |
| def on_train_batch_end(self, *args, **kwargs): |
| if self.use_ema: |
| self.model_ema(self.model) |
|
|
|
|
|
|
|
|
class LatentDiffusion(DDPM):
    """main class

    Latent-space diffusion model: diffuses in the first-stage (VAE) latent
    space and is conditioned through task-specific CLIP / face-ID / landmark
    encoders.  Uses manual optimization (see ``automatic_optimization``).
    """
    def __init__(self,
                 first_stage_config,
                 cond_stage_config,
                 num_timesteps_cond=None,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,
                 scale_factor=1.0,
                 scale_by_std=False,
                 *args, **kwargs):
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        assert self.num_timesteps_cond <= kwargs['timesteps']
        # default conditioning mode follows concat_mode unless given explicitly
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        if cond_stage_config == '__is_unconditional__':
            conditioning_key = None
        # pop before forwarding kwargs: the base class must not auto-restore;
        # we restore ourselves at the end of __init__
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
        # manual optimization: training_step calls backward/step itself
        self.automatic_optimization = False
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        # ---- task/loss configuration pulled from cond_stage_config ----
        if hasattr(cond_stage_config, 'other_params'):
            self.clip_weight=cond_stage_config.other_params.clip_weight
            # ID loss only applies when a face-identity task (0/2/3) is enabled
            if set(TASKS) & {0,2,3}: self.ID_weight = 10.0
            else: self.ID_weight = 0
            # landmark conditioning off only for pts-less hair-only training
            if (not USE_pts) and TASKS==(1,): self.Landmark_cond=False
            else: self.Landmark_cond=True
            self.Landmarks_weight=0.05
            if hasattr(cond_stage_config.other_params, 'Additional_config'):
                self.Reconstruct_initial=cond_stage_config.other_params.Additional_config.Reconstruct_initial
                self.Reconstruct_DDIM_steps=cond_stage_config.other_params.Additional_config.Reconstruct_DDIM_steps
                self.sampler=DDIMSampler(self)
                if hasattr(cond_stage_config.other_params, 'multi_scale_ID'):
                    self.multi_scale_ID=cond_stage_config.other_params.multi_scale_ID
                else:
                    self.multi_scale_ID=True
                if hasattr(cond_stage_config.other_params, 'normalize'):
                    self.normalize=cond_stage_config.other_params.normalize
                else:
                    self.normalize=False
                if 1:
                    # frozen perceptual metric; kept in eval mode
                    self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
                if hasattr(cond_stage_config.other_params, 'partial_training'):
                    self.partial_training=cond_stage_config.other_params.partial_training
                    self.trainable_keys=cond_stage_config.other_params.trainable_keys
                else:
                    self.partial_training=False
                if hasattr(cond_stage_config.other_params.Additional_config, 'Same_image_reconstruct'):
                    self.Same_image_reconstruct=cond_stage_config.other_params.Additional_config.Same_image_reconstruct
                else:
                    self.Same_image_reconstruct=False
                if hasattr(cond_stage_config.other_params.Additional_config, 'Target_CLIP_feat'):
                    self.Target_CLIP_feat=cond_stage_config.other_params.Additional_config.Target_CLIP_feat
                else:
                    self.Target_CLIP_feat=False
                if hasattr(cond_stage_config.other_params.Additional_config, 'Source_CLIP_feat'):
                    self.Source_CLIP_feat=cond_stage_config.other_params.Additional_config.Source_CLIP_feat
                else:
                    self.Source_CLIP_feat=False
                if hasattr(cond_stage_config.other_params.Additional_config, 'use_3dmm'):
                    self.use_3dmm=cond_stage_config.other_params.Additional_config.use_3dmm
                else:
                    self.use_3dmm=False

            else:
                self.Reconstruct_initial=False
                self.Reconstruct_DDIM_steps=0

            self.update_weight=False

        else:
            # this fork requires other_params to be present
            assert 0
        # ---- learnable conditioning modules ----
        if 1:
            # one learnable unconditional token sequence per task
            # (shapes differ: 257 tokens for task 1, 259 for the others)
            self.learnable_vector = nn.ParameterList([
                nn.Parameter(torch.randn((1,259,768)), requires_grad=True),
                nn.Parameter(torch.randn((1,257,768)), requires_grad=True),
                nn.Parameter(torch.randn((1,259,768)), requires_grad=True),
                nn.Parameter(torch.randn((1,259,768)), requires_grad=True),
            ])
            if self.ID_weight>0:
                # projection from ID-feature dim to the 768-d token space
                if self.multi_scale_ID:
                    self.ID_proj_out=nn.Linear(200704, 768)
                else:
                    self.ID_proj_out=nn.Linear(512, 768)
                self.instantiate_IDLoss(cond_stage_config)

            if self.Landmark_cond:
                if USE_pts:
                    self.ptsM_Generator = LandmarkExtractor(include_visualizer=True,img_256_mode=False)
                else:
                    # dlib fallback landmark pipeline
                    self.detector = dlib.get_frontal_face_detector()
                    self.predictor = dlib.shape_predictor("Other_dependencies/DLIB_landmark_det/shape_predictor_68_face_landmarks.dat")

            if self.Landmarks_weight>0:
                self.landmark_proj_out=nn.Linear(NUM_pts*2, 768)
        self.total_steps_in_epoch=0
        if 1:
            assert cond_stage_config.target=="ldm.modules.encoders.modules.FrozenCLIPEmbedder" and self.Source_CLIP_feat and self.Target_CLIP_feat
            self.USE_proj_out_source = 1
            # one source-feature projection head per enabled task family
            if set(TASKS) & {0,}:
                self.proj_out_source__face=nn.Linear(768, 768)
            if set(TASKS) & {1,}:
                self.proj_out_source__hair=nn.Linear(768, 768)
            if set(TASKS) & {2,3,}:
                self.proj_out_source__head=nn.Linear(768, 768)
        if 0:
            self.proj_out_target=nn.Linear(768, 768)
        self.proj_out=nn.Identity()

        # number of downsampling levels in the first-stage autoencoder
        try:
            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except:
            self.num_downs = 0
        if not scale_by_std:
            self.scale_factor = scale_factor
        else:
            self.register_buffer('scale_factor', torch.tensor(scale_factor))
        self.instantiate_first_stage(first_stage_config)
        self.instantiate_cond_stage(cond_stage_config)

        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None

        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True
|
|
    def get_lmk_for_router(self, batch: dict, x_tensor: torch.Tensor):
        """
        Prepare global_.lmk_ (BS, L, 2) normalized to [0,1] for gating/router.
        - Prefer cached Mediapipe landmarks if present in batch
        - Convert 468/478 to main landmarks with face oval using get_lmkMain_indices(True)
        - Fallback to zeros if not available

        NOTE(review): the fallback branch is currently unreachable — it sits
        behind an ``assert 0`` and the zeros line after it is dead code.
        """
        b, _, H, W = x_tensor.shape
        if READ_mediapipe_result_from_cache and ('mediapipe_lmkAll' in batch):
            data_all = batch['mediapipe_lmkAll']
            # accept either torch or numpy cached landmarks
            if isinstance(data_all, torch.Tensor):
                lmks_all = data_all.to(x_tensor.device).to(x_tensor.dtype)
            else:
                lmks_all = torch.from_numpy(data_all).to(x_tensor.device).to(x_tensor.dtype)

            # cache the main-landmark index tensor per device on global_
            idxs = getattr(global_, 'lmk_main_idx_tensor', None)
            if (idxs is None) or (idxs.device != x_tensor.device):
                idx_list = get_lmkMain_indices(include_face_oval=True)
                idxs = torch.as_tensor(list(idx_list), dtype=torch.long, device=x_tensor.device)
                global_.lmk_main_idx_tensor = idxs
            lmk = torch.index_select(lmks_all, dim=1, index=idxs)

            if lmk.numel() > 0:
                # normalize pixel coordinates to [0,1]
                # (index_select returned a fresh tensor, so in-place is safe)
                lmk[..., 0] = lmk[..., 0] / float(W)
                lmk[..., 1] = lmk[..., 1] / float(H)

        else:
            assert 0
            lmk = torch.zeros((b, 0, 2), device=x_tensor.device, dtype=x_tensor.dtype)
        return lmk
|
|
| def make_cond_schedule(self, ): |
| self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) |
| ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() |
| self.cond_ids[:self.num_timesteps_cond] = ids |
|
|
    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        # Standard-deviation rescaling of the latent space would happen here
        # (only on the very first batch of a fresh run).  Deliberately
        # disabled in this fork via assert.
        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
            assert 0
|
|
| def register_schedule(self, |
| given_betas=None, beta_schedule="linear", timesteps=1000, |
| linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): |
| super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) |
|
|
| self.shorten_cond_schedule = self.num_timesteps_cond > 1 |
| if self.shorten_cond_schedule: |
| self.make_cond_schedule() |
|
|
| def instantiate_first_stage(self, config): |
| model = instantiate_from_config(config) |
| self.first_stage_model = model.eval() |
| self.first_stage_model.train = disabled_train |
| for param in self.first_stage_model.parameters(): |
| param.requires_grad = False |
|
|
| def instantiate_IDLoss(self, config): |
| |
| |
| model = IDLoss(config,multiscale=self.multi_scale_ID) |
| self.face_ID_model = model.eval() |
| self.face_ID_model.train = disabled_train |
| for param in self.face_ID_model.parameters(): |
| param.requires_grad = False |
| |
| |
| |
| def instantiate_cond_stage(self, config): |
| if 1: |
| assert config != '__is_first_stage__' |
| assert config != '__is_unconditional__' |
| model: FrozenCLIPEmbedder = instantiate_from_config(config) |
| if 0 in TASKS: |
| self.encoder_clip_face :FrozenCLIPEmbedder = model |
| if 1 in TASKS: |
| self.encoder_clip_hair :FrozenCLIPEmbedder = copy.deepcopy(model) |
| del self.encoder_clip_hair.model |
| del self.encoder_clip_hair.tokenizer |
| if set(TASKS) & {2,}: |
| self.encoder_clip_head_t2 :FrozenCLIPEmbedder = copy.deepcopy(model) |
| del self.encoder_clip_head_t2.model |
| del self.encoder_clip_head_t2.tokenizer |
| if set(TASKS) & {3,}: |
| self.encoder_clip_head_t3 :FrozenCLIPEmbedder = copy.deepcopy(model) |
| del self.encoder_clip_head_t3.model |
| del self.encoder_clip_head_t3.tokenizer |
|
|
|
|
| def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): |
| denoise_row = [] |
| for zd in tqdm(samples, desc=desc): |
| denoise_row.append(self.decode_first_stage(zd.to(self.device), |
| force_not_quantize=force_no_decoder_quantization)) |
| n_imgs_per_row = len(denoise_row) |
| denoise_row = torch.stack(denoise_row) |
| denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') |
| denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') |
| denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) |
| return denoise_grid |
|
|
| def get_first_stage_encoding(self, encoder_posterior): |
| if isinstance(encoder_posterior, DiagonalGaussianDistribution): |
| z = encoder_posterior.sample() |
| elif isinstance(encoder_posterior, torch.Tensor): |
| z = encoder_posterior |
| else: |
| raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") |
| return self.scale_factor * z |
|
|
| |
    def get_learned_conditioning(self, c):
        # Intentionally disabled in this fork: conditioning is produced by
        # conditioning_with_feat() instead, so any call here is a bug.
        raise Exception
    def conditioning_with_feat(self,x,landmarks=None,enInputs:dict=None):
        """Assemble the conditioning tensor for the current task.

        Encodes the task-relevant entries of ``enInputs`` (face-ID features,
        face/hair/head CLIP features, optionally projected landmarks) into a
        list of token tensors plus weights, then either concatenates them
        along the token axis (``self.STACK_feat``) or returns their weighted
        average.  Requires ``set_task`` to have been called (reads
        ``self.task``).
        """
        # optional debug dump of all conditioning inputs
        if gate_('vis LatentDiffusion.conditioning_with_feat'):
            debug_dir = Path(f"4debug/conditioning_with_feat/{ID}"); debug_dir.mkdir(parents=0, exist_ok=True)
            all_images = [ ('x', x), ]
            for _name, _enInput in enInputs.items():
                all_images.append((_name, _enInput))
            vis_tensors_A(all_images, debug_dir / f"all-{str_t_pid()}.jpg", vis_batch_size= min(5, landmarks.shape[0]) )
        del x
        task = self.task
        ID_weight = self.ID_weight
        Landmarks_weight = self.Landmarks_weight
        # only the weight variable for the current task's branch is bound;
        # the encode_* closures below rely on that pairing
        if self.task==0:
            face_clip_weight = self.clip_weight
        elif self.task==1:
            hair_clip_weight = self.clip_weight
        elif self.task==2:
            head_clip_weight = self.clip_weight
        elif self.task==3:
            head_clip_weight = self.clip_weight
        if 1:
            # each encode_* closure appends its tokens to cs and weight to ws
            cs = []
            ws = []
            def encode_face_ID():
                _c = enInputs['face_ID-in']
                _c=self.face_ID_model.extract_feats(_c)[0]
                _c = self.ID_proj_out(_c)
                _c = _c.unsqueeze(1)
                if self.normalize:
                    # rescale to a fixed L2 norm (norm_coeff is module-level)
                    _c = _c*norm_coeff/F.normalize(_c, p=2, dim=2)
                cs.append(_c); ws.append(ID_weight)
            def encode_face_clip(_z=None):
                if _z is None:
                    _c = enInputs['face-clip-in']
                    _c = self.encoder_clip_face.encode(_c)
                else:
                    # pre-computed ViT feature path — currently disabled
                    assert 0
                    _c = self.encoder_clip_face.encode_B(_z)
                if hasattr(self,'USE_proj_out_source') and self.USE_proj_out_source:
                    _c = self.proj_out_source__face(_c)
                cs.append(_c); ws.append(face_clip_weight)
            def encode_hair_clip(_z=None):
                if _z is None:
                    _c = enInputs['hair-clip-in']
                    _c = self.encoder_clip_hair.encode(_c)
                else:
                    _c = self.encoder_clip_hair.encode_B(_z)
                if hasattr(self,'USE_proj_out_source') and self.USE_proj_out_source:
                    _c = self.proj_out_source__hair(_c)
                printC("hair _c.shape:",f"{_c.shape}")
                cs.append(_c); ws.append(hair_clip_weight)
            def encode_head_clip(_z=None):
                # tasks 2 and 3 use separate head encoders
                if global_.task == 2:
                    encoder_clip_head = self.encoder_clip_head_t2
                elif global_.task == 3:
                    encoder_clip_head = self.encoder_clip_head_t3
                else:
                    raise ValueError(f"Task {global_.task} does not have encoder_clip_head")
                if _z is None:
                    _c = enInputs['head-clip-in']
                    _c = encoder_clip_head.encode(_c)
                else:
                    _c = encoder_clip_head.encode_B(_z)
                if hasattr(self,'USE_proj_out_source') and self.USE_proj_out_source:
                    _c = self.proj_out_source__head(_c)
                printC("head _c.shape:",f"{_c.shape}")
                cs.append(_c); ws.append(head_clip_weight)
            # task dispatch: tasks 1-3 share the face encoder's ViT trunk
            if task==0:
                encode_face_ID()
                encode_face_clip()
            elif task==1:
                _z = enInputs['hair-clip-in']
                _z = self.encoder_clip_face.forward_vit(_z)
                encode_hair_clip(_z)
            elif task==2:
                encode_face_ID()
                _z = enInputs['head-clip-in']
                _z = self.encoder_clip_face.forward_vit(_z)
                encode_head_clip(_z)
            elif task==3:
                encode_face_ID()
                _z = enInputs['head-clip-in']
                _z = self.encoder_clip_face.forward_vit(_z)
                encode_head_clip(_z)
            c=0

            if Landmarks_weight > 0:
                # landmarks expected already projected to the 768-d token
                # space (see get_landmarks); ensure a 3-D (B, T, C) shape
                landmarks=landmarks.unsqueeze(1) if len(landmarks.shape)!=3 else landmarks
                cs.append(landmarks); ws.append(Landmarks_weight)
            if self.STACK_feat:
                # concatenate all conditioning tokens along the token axis
                conc=torch.cat(cs, dim=-2)
                c = conc
            else:
                # weighted average of the conditioning tensors
                total_weight = sum(ws)
                weighted_sum = sum(c * w for c, w in zip(cs, ws))
                c = weighted_sum / total_weight if total_weight > 0 else 0
            printC("[conditioning_with_feat return]",f"{custom_repr_v3(c)}")

            return c
| |
|
|
    def get_landmarks(self,x, batch:dict):
        """Extract per-image facial landmarks for the batch.

        Returns ``(Landmarks_all, pts68, lmk_aux)`` where Landmarks_all is
        either the raw flattened coordinates or (when Landmarks_weight > 0)
        their 768-d projection through ``landmark_proj_out``; pts68 is the
        (bs, NUM_pts, 2) numpy array; lmk_aux carries the full landmark sets.

        NOTE(review): when ``Landmark_cond`` is falsy or ``x`` is None the
        function implicitly returns None — callers must guard (get_input_
        only calls this when Landmark_cond is set).
        """
        if (self.Landmark_cond) and x is not None:
            # to HWC uint8 numpy for the landmark extractor
            x=255.0*un_norm(x).permute(0,2,3,1).cpu().numpy()
            x=x.astype(np.uint8)
            Landmarks_all=[]
            if USE_pts:
                l_lmkAll=[]
                if READ_mediapipe_result_from_cache:
                    # cached Mediapipe results: (bs, 468/478, 2)
                    _l_lmkAll :np.ndarray = batch['mediapipe_lmkAll'].cpu().numpy()
            bs = len(x)
            for i in range(len(x)):
                if USE_pts:
                    if READ_mediapipe_result_from_cache:
                        lmkAll :np.ndarray = _l_lmkAll[i]
                    else:
                        lmkAll :np.ndarray = self.ptsM_Generator.extract_single(x[i], only_main_lmk=False)
                    # detection failure -> zero landmarks placeholder
                    if lmkAll is None: lmkAll = np.zeros((478,2))
                    l_lmkAll.append(lmkAll)
                    lm = lmkAll_2_lmkMain(lmkAll)
                    lm = lm.reshape(1, NUM_pts*2)
                    Landmarks_all.append(lm)
                    if 0:
                        # disabled debug visualization
                        from util_vis import visualize_landmarks
                        starter_stem = Path(sys.argv[0]).stem
                        path_vis_lmk = f'4debug/vis_lmk/{starter_stem}-{i}.png'
                        visualize_landmarks(x[i], lm[0], path_vis_lmk)
                        print(f"{path_vis_lmk=}")
            Landmarks_all=np.concatenate(Landmarks_all,axis=0)
            pts68 = Landmarks_all.reshape(bs, NUM_pts, 2, )
            if self.Landmarks_weight>0:
                Landmarks_all=torch.tensor(Landmarks_all).float().to(self.device)
                # NOTE(review): dead branch — the enclosing `if` already
                # requires self.Landmark_cond to be truthy
                if self.Landmark_cond == False:
                    return Landmarks_all
                # projection must stay differentiable even under no_grad callers
                with torch.enable_grad():
                    Landmarks_all=self.landmark_proj_out(Landmarks_all)

            lmk_aux={}
            if USE_pts: lmk_aux['l_lmkAll'] = l_lmkAll
            return Landmarks_all,pts68,lmk_aux
|
|
| def meshgrid(self, h, w): |
| y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) |
| x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) |
|
|
| arr = torch.cat([y, x], dim=-1) |
| return arr |
|
|
| def delta_border(self, h, w): |
| """ |
| :param h: height |
| :param w: width |
| :return: normalized distance to image border, |
| wtith min distance = 0 at border and max dist = 0.5 at image center |
| """ |
| lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) |
| arr = self.meshgrid(h, w) / lower_right_corner |
| dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] |
| dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] |
| edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] |
| return edge_dist |
|
|
| def get_weighting(self, h, w, Ly, Lx, device): |
| weighting = self.delta_border(h, w) |
| weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], |
| self.split_input_params["clip_max_weight"], ) |
| weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) |
|
|
| if self.split_input_params["tie_braker"]: |
| L_weighting = self.delta_border(Ly, Lx) |
| L_weighting = torch.clip(L_weighting, |
| self.split_input_params["clip_min_tie_weight"], |
| self.split_input_params["clip_max_tie_weight"]) |
|
|
| L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) |
| weighting = weighting * L_weighting |
| return weighting |
|
|
    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):
        """
        :param x: img of size (bs, c, h, w)
        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])

        Builds matching (fold, unfold, normalization, weighting) for patch-wise
        processing, optionally with upsampling (uf>1) or downsampling (df>1)
        between the unfold and fold resolutions.
        """
        bs, nc, h, w = x.shape

        # number of patch positions along each axis
        Ly = (h - kernel_size[0]) // stride[0] + 1
        Lx = (w - kernel_size[1]) // stride[1] + 1

        if uf == 1 and df == 1:
            # same-resolution case: fold/unfold share one parameter set
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)

            # normalization counts overlapping contributions per output pixel
            weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h, w)
            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))

        elif uf > 1 and df == 1:
            # unfold at input resolution, fold at uf-times larger resolution
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            # NOTE(review): kernel_size[0] is used for BOTH axes here (same as
            # upstream LDM) — only correct for square kernels; confirm if
            # non-square kernels are ever passed.
            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
                                dilation=1, padding=0,
                                stride=(stride[0] * uf, stride[1] * uf))
            fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)

            weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h * uf, w * uf)
            weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))

        elif df > 1 and uf == 1:
            # unfold at input resolution, fold at df-times smaller resolution
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            # NOTE(review): same square-kernel assumption as the uf>1 branch
            fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
                                dilation=1, padding=0,
                                stride=(stride[0] // df, stride[1] // df))
            fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)

            weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h // df, w // df)
            weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))

        else:
            raise NotImplementedError

        return fold, unfold, normalization, weighting
|
|
| |
    @torch.no_grad()
    def get_input_(self, batch, k, return_first_stage_outputs=False,
                   cond_key=None, bs=None,
                   get_referenceZ=False,
                   ):
        """Turn a raw dataloader batch into the latent-space inputs for the diffusion model.

        Args:
            batch: dict with at least 'GT', 'inpaint_mask', 'inpaint_image', 'enInputs';
                optionally 'ref_imgs_4unet' and 'ref_mask_512'.
            k: first-stage key; only "inpaint" is supported (anything else asserts).
            return_first_stage_outputs, cond_key, bs: legacy parameters; `bs` must stay None.
            get_referenceZ: when True, also encode 'ref_imgs_4unet' into `z_ref`.

        Returns:
            The input `batch` dict extended with latent tensors: 'z9' (concatenated
            UNet input), 'z4_gt', 'z4_inpaint', 'tgt_mask_64', 'ref_mask_64',
            'z_ref' and 'landmarks'.
        """
        if k == "inpaint":
            x = batch['GT']
            # Clone: `inpaint` is mutated in place below (landmark drawing).
            mask = batch['inpaint_mask'].clone()
            inpaint = batch['inpaint_image'].clone()

            reference = None
        else:
            assert 0  # only the "inpaint" first-stage key is supported
        if len(x.shape) == 3:
            assert 0  # 3-D input (missing batch/channel dim) is not expected here
            x = x[..., None]
        if 1:
            # Normalize the extra encoder inputs to contiguous float tensors, in place.
            enInputs = batch['enInputs']
            for k,v in enInputs.items():  # NOTE: rebinds the parameter `k`
                enInputs[k] = v.to(memory_format=torch.contiguous_format).float()

        ref_imgs_4unet = batch.get('ref_imgs_4unet', None) if get_referenceZ else None

        if bs is not None:
            assert 0  # sub-batching is not supported
        x = x.to(self.device)

        # Stash landmarks globally for the MoE router (read elsewhere via global_).
        global_.lmk_ = self.get_lmk_for_router(batch, x)
        if self.Landmark_cond:
            landmarks, pts68, lmk_aux=self.get_landmarks(x,batch)
        else:
            landmarks=None

        # For tasks 0/2/3, rasterize the detected landmarks onto the inpaint image.
        if self.task in (0,2,3,) and USE_pts:
            mask_np = mask.detach().cpu().numpy()
            if 1:

                x_unnorm=255.0*un_norm(x).permute(0,2,3,1).cpu().numpy()
                x_unnorm=x_unnorm.astype(np.uint8)

                batch_size = x.shape[0]

                VIS_pts= 0

                for b in range(batch_size):
                    lmkAll = lmk_aux['l_lmkAll'][b]
                    # Draw landmarks on the (H,W,C) numpy image, then back to (C,H,W) tensor.
                    inpaint[b] = torch.Tensor(self.ptsM_Generator.visualizer.visualize_landmarks(inpaint[b].permute(1,2,0).detach().cpu().numpy(), lmkAll, ) ).permute(2,0,1)
                    del lmkAll

        # Optional debug dump of all inputs after landmark drawing.
        if self.training and gate_('vis LatentDiffusion.get_input'):
            debug_dir = Path(f"4debug/LatentDiffusion.get_input/{ID}"); debug_dir.mkdir(parents=0, exist_ok=True)
            vis_batch_size = min(5, x.shape[0])
            all_images = [ ('x', x), ('inpaint', inpaint), ('mask', mask), ('reference', reference), ('ref_imgs_4unet', ref_imgs_4unet) ]
            for _name, _enInput in enInputs.items():
                all_images.append((_name, _enInput))
            all_path = debug_dir / f"all--after-pts-{str_t_pid()}.jpg"
            vis_tensors_A(all_images, all_path, vis_batch_size)

        # Encode GT and inpaint images to first-stage latents (detached: VAE is frozen).
        encoder_posterior = self.encode_first_stage(x)
        z = self.get_first_stage_encoding(encoder_posterior).detach()
        encoder_posterior_inpaint = self.encode_first_stage(inpaint)
        z_inpaint = self.get_first_stage_encoding(encoder_posterior_inpaint).detach()

        # Resize masks to the (square) latent resolution.
        mask_resize = Resize([z.shape[-1],z.shape[-1]])(mask)
        ref_mask_64 = Resize([z.shape[-1],z.shape[-1]])(batch['ref_mask_512']) if 'ref_mask_512' in batch else None

        if not CH14:
            # 9-channel UNet input: GT latent (4) + inpaint latent (4) + mask (1).
            z_new = torch.cat((z,z_inpaint,mask_resize),dim=1)
        if get_referenceZ:
            encoder_posterior_ref = self.encode_first_stage(ref_imgs_4unet)
            z_ref = self.get_first_stage_encoding(encoder_posterior_ref).detach()
        else:
            z_ref = None
        if CH14:
            # 14-channel variant additionally packs the reference latent (4) + its mask (1).
            z_new = torch.cat((z,z_inpaint,mask_resize, z_ref,ref_mask_64),dim=1)
        assert z.shape[1:]==(4,64,64,)
        if gate_(f'vis LatentDiffusion.get_input-before_return {self.training}'):
            debug_dir = Path(f"4debug/LatentDiffusion.get_input-before_return/{ID}"); debug_dir.mkdir(parents=0, exist_ok=True)
            vis_batch_size = min(5, x.shape[0])
            all_images = [ ('x', x), ('inpaint', inpaint), ('mask', mask), ('reference', reference), ('ref_imgs_4unet', ref_imgs_4unet),
                           ('z4_gt',z[:,:3]),('z4_inpaint', z_inpaint[:,:3]),('tgt_mask_64', mask_resize),('z_ref',None if z_ref is None else z_ref[:,:3]),('ref_mask_64',ref_mask_64),]
            all_path = debug_dir / f"{str_t_pid()}.jpg"
            vis_tensors_A(all_images, all_path, vis_batch_size)

        if 1:
            assert self.model.conditioning_key is not None
            assert self.first_stage_key=='inpaint'
            assert self.cond_stage_key=='image'
            return {
                **batch,
                'z9': z_new,
                'z4_gt': z,
                'z4_inpaint': z_inpaint,

                'tgt_mask_64': mask_resize,
                'ref_mask_64': ref_mask_64,

                'z_ref': z_ref,

                'landmarks': landmarks,
            }
| |
    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        """Decode latent `z` back to image space with the (frozen) first-stage model.

        Args:
            z: latent tensor (b, c, h, w). When `self.first_stage_key == 'inpaint'`
               only the first 4 channels are decoded in the plain (non-tiled) path.
            predict_cids: treat `z` as codebook logits — argmax to ids and look up
                the codebook embeddings first.
            force_not_quantize: forwarded to the VQ decoder to skip re-quantization.

        Returns:
            Decoded image tensor.
        """
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, 'b h w c -> b c h w').contiguous()

        # Undo the latent scaling applied at encode time.
        z = 1. / self.scale_factor * z

        if hasattr(self, "split_input_params"):
            # Tiled decoding: split the latent into overlapping tiles, decode each,
            # then fold them back using a weighting map and per-pixel normalization.
            if self.split_input_params["patch_distributed_vq"]:
                ks = self.split_input_params["ks"]  # tile (kernel) size
                stride = self.split_input_params["stride"]
                uf = self.split_input_params["vqf"]  # decoder upsampling factor
                bs, nc, h, w = z.shape
                # Clamp kernel/stride so tiles never exceed the latent size.
                if ks[0] > h or ks[1] > w:
                    ks = (min(ks[0], h), min(ks[1], w))
                    print("reducing Kernel")

                if stride[0] > h or stride[1] > w:
                    stride = (min(stride[0], h), min(stride[1], w))
                    print("reducing stride")

                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)

                z = unfold(z)  # (b, c * ks[0] * ks[1], L)
                # Reshape tiles to image-like shape: (b, c, ks[0], ks[1], L).
                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))

                # Decode every tile independently.
                if isinstance(self.first_stage_model, VQModelInterface):
                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
                                                                 force_not_quantize=predict_cids or force_not_quantize)
                                   for i in range(z.shape[-1])]
                else:

                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
                                   for i in range(z.shape[-1])]

                o = torch.stack(output_list, axis=-1)
                o = o * weighting
                # Flatten tiles back for fold: (b, c * ks[0] * ks[1], L).
                o = o.view((o.shape[0], -1, o.shape[-1]))
                # Stitch tiles together and normalize the overlap regions.
                decoded = fold(o)
                decoded = decoded / normalization
                return decoded
            else:
                if isinstance(self.first_stage_model, VQModelInterface):
                    return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
                else:
                    return self.first_stage_model.decode(z)

        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
            else:
                if self.first_stage_key=='inpaint':
                    # The 9/14-channel inpaint latent carries extra conditioning planes;
                    # only channels [0:4] form an actual image latent.
                    return self.first_stage_model.decode(z[:,:4,:,:])
                else:
                    return self.first_stage_model.decode(z)
|
|
| |
|
|
| |
    def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        """Decode latent `z` like `decode_first_stage` but WITHOUT `@torch.no_grad`,
        so gradients flow back through the first-stage decoder into `z` (used for
        image-space losses on DDIM reconstructions in `p_losses`).

        NOTE(review): unlike `decode_first_stage`, the plain path here does NOT
        slice `z[:, :4]` for the 'inpaint' key — callers are expected to pass a
        4-channel latent already.
        """
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, 'b h w c -> b c h w').contiguous()

        # Undo the latent scaling applied at encode time.
        z = 1. / self.scale_factor * z

        if hasattr(self, "split_input_params"):
            # Tiled decoding path (same scheme as decode_first_stage).
            if self.split_input_params["patch_distributed_vq"]:
                ks = self.split_input_params["ks"]  # tile (kernel) size
                stride = self.split_input_params["stride"]
                uf = self.split_input_params["vqf"]  # decoder upsampling factor
                bs, nc, h, w = z.shape
                # Clamp kernel/stride so tiles never exceed the latent size.
                if ks[0] > h or ks[1] > w:
                    ks = (min(ks[0], h), min(ks[1], w))
                    print("reducing Kernel")

                if stride[0] > h or stride[1] > w:
                    stride = (min(stride[0], h), min(stride[1], w))
                    print("reducing stride")

                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)

                z = unfold(z)  # (b, c * ks[0] * ks[1], L)
                # Reshape tiles to image-like shape: (b, c, ks[0], ks[1], L).
                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))

                # Decode every tile independently.
                if isinstance(self.first_stage_model, VQModelInterface):
                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
                                                                 force_not_quantize=predict_cids or force_not_quantize)
                                   for i in range(z.shape[-1])]
                else:

                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
                                   for i in range(z.shape[-1])]

                o = torch.stack(output_list, axis=-1)
                o = o * weighting
                # Flatten tiles back for fold: (b, c * ks[0] * ks[1], L).
                o = o.view((o.shape[0], -1, o.shape[-1]))
                # Stitch tiles together and normalize the overlap regions.
                decoded = fold(o)
                decoded = decoded / normalization
                return decoded
            else:
                if isinstance(self.first_stage_model, VQModelInterface):
                    return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
                else:
                    return self.first_stage_model.decode(z)

        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
            else:
                return self.first_stage_model.decode(z)
|
|
    @torch.no_grad()
    def encode_first_stage(self, x):
        """Encode image `x` with the (frozen) first-stage model.

        Uses the tiled fold/unfold path when `split_input_params` is configured,
        otherwise a single encoder pass. Returns the raw encoder output
        (the latent scaling is applied later by `get_first_stage_encoding`).
        """
        if hasattr(self, "split_input_params"):
            if self.split_input_params["patch_distributed_vq"]:
                ks = self.split_input_params["ks"]  # tile (kernel) size
                stride = self.split_input_params["stride"]
                df = self.split_input_params["vqf"]  # encoder downsampling factor
                # Remember the input size for bbox-conditioning rescaling in apply_model.
                self.split_input_params['original_image_size'] = x.shape[-2:]
                bs, nc, h, w = x.shape
                # Clamp kernel/stride so tiles never exceed the image size.
                if ks[0] > h or ks[1] > w:
                    ks = (min(ks[0], h), min(ks[1], w))
                    print("reducing Kernel")

                if stride[0] > h or stride[1] > w:
                    stride = (min(stride[0], h), min(stride[1], w))
                    print("reducing stride")

                fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
                z = unfold(x)  # (b, c * ks[0] * ks[1], L)
                # Reshape tiles to image-like shape: (b, c, ks[0], ks[1], L).
                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))

                # Encode every tile independently.
                output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
                               for i in range(z.shape[-1])]

                o = torch.stack(output_list, axis=-1)
                o = o * weighting

                # Flatten tiles back for fold, then stitch and normalize overlaps.
                o = o.view((o.shape[0], -1, o.shape[-1]))

                decoded = fold(o)
                decoded = decoded / normalization
                return decoded

            else:
                return self.first_stage_model.encode(x)
        else:
            return self.first_stage_model.encode(x)
|
|
| def get_input_and_conditioning(self,batch, device=None): |
| if device is not None: batch = recursive_to(batch, device) |
| |
| get_referenceZ=(REFNET.ENABLE and REFNET.task2layerNum[global_.task]>0) or CH14 |
| batch = self.get_input_(batch, self.first_stage_key,get_referenceZ=get_referenceZ) |
| |
| assert ( self.model.conditioning_key is not None ) and self.cond_stage_trainable |
| c=self.conditioning_with_feat(batch['ref_imgs'],landmarks=batch['landmarks'],enInputs=batch['enInputs']) |
| return batch,c |
| def shared_step(self, batch, **kwargs): |
| task = self.set_task(batch) |
| if (REFNET.ENABLE and REFNET.task2layerNum[task]>0): |
| self.model.bank.clear() |
| batch, c = self.get_input_and_conditioning(batch) |
| z9 = batch['z9'] |
| z_ref = batch['z_ref'] |
| gt512 = batch['GT'] |
| gt256 = batch.get('GT256',None) |
| |
| loss = self(z9, c,z_ref=z_ref,gt512=gt512,gt256=gt256,task=task,batch=batch,) |
| return loss |
|
|
| def forward(self, x, c, *args, **kwargs): |
| task = kwargs['task'] |
| |
| t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() |
| self.u_cond_prop=random.uniform(0, 1) |
| if self.model.conditioning_key is not None: |
| |
| if self.cond_stage_trainable: |
| pass |
| |
| if self.shorten_cond_schedule: |
| raise Exception |
| tc = self.cond_ids[t].to(self.device) |
| c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) |
|
|
| if self.u_cond_prop<self.u_cond_percent and self.training : |
| return self.p_losses(x, self.learnable_vector[task].repeat(x.shape[0],1,1), t, *args, **kwargs) |
| else: |
| return self.p_losses(x, c, t, *args, **kwargs) |
| |
| |
|
|
    def apply_model(self, x_noisy, t, cond, return_ids=False,return_features=False,
                    z_ref=None,
                    ):
        """Run the wrapped diffusion model on `x_noisy` at timesteps `t`.

        `cond` may be a dict of conditioning lists or a bare tensor/list, which
        gets wrapped into ``{'c_concat'|'c_crossattn': [...]}`` depending on the
        conditioning key. The tiled `split_input_params` path is retained from
        upstream LDM but explicitly disabled via ``assert 0``.
        """

        if isinstance(cond, dict):
            # Hybrid case: cond is already a dict of conditioning lists.
            pass
        else:
            if not isinstance(cond, list):
                cond = [cond]
            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
            cond = {key: cond}

        if hasattr(self, "split_input_params"):
            # --- tiled inference path (disabled; kept from upstream LDM) ---
            assert 0,'This branch should not execute in practice'
            assert len(cond) == 1
            assert not return_ids
            ks = self.split_input_params["ks"]  # tile (kernel) size
            stride = self.split_input_params["stride"]

            h, w = x_noisy.shape[-2:]

            fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)

            z = unfold(x_noisy)
            # Reshape tiles to image-like shape: (b, c, ks[0], ks[1], L).
            z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))
            z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]

            if self.cond_stage_key in ["image", "LR_image", "segmentation",
                                       'bbox_img'] and self.model.conditioning_key:
                # Image-like conditioning: tile it the same way as the input.
                c_key = next(iter(cond.keys()))
                c = next(iter(cond.values()))
                assert (len(c) == 1)
                c = c[0]

                c = unfold(c)
                c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1]))

                cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]

            elif self.cond_stage_key == 'coordinates_bbox':
                assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'

                # Re-tokenize a bounding box per patch, expressed in full-image coordinates.
                n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
                full_img_h, full_img_w = self.split_input_params['original_image_size']

                num_downs = self.first_stage_model.encoder.num_resolutions - 1
                rescale_latent = 2 ** (num_downs)

                # Top-left corner of each patch, normalized to the full image size.
                tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
                                         rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
                                        for patch_nr in range(z.shape[-1])]

                # Patch limits as (x_tl, y_tl, width, height), normalized.
                patch_limits = [(x_tl, y_tl,
                                 rescale_latent * ks[0] / full_img_w,
                                 rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]

                # Tokenize each per-patch bbox for the conditioning sequence.
                patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
                                      for bbox in patch_limits]
                print(patch_limits_tknzd[0].shape)

                assert isinstance(cond, dict), 'cond must be dict to be fed into model'
                # Strip the previous (global) bbox tokens, then append the per-patch ones.
                cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)

                adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
                adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
                adapted_cond = self.get_learned_conditioning(adapted_cond)
                adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])

                cond_list = [{'c_crossattn': [e]} for e in adapted_cond]

            else:
                # Non-spatial conditioning: reuse the same cond for every tile.
                cond_list = [cond for i in range(z.shape[-1])]

            # Apply the model to each tile, then fold the tiles back together.
            output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
            assert not isinstance(output_list[0],
                                  tuple)

            o = torch.stack(output_list, axis=-1)
            o = o * weighting
            # Flatten tiles back for fold: (b, c * ks[0] * ks[1], L).
            o = o.view((o.shape[0], -1, o.shape[-1]))
            # Stitch tiles together and normalize the overlap regions.
            x_recon = fold(o) / normalization

        else:
            # --- the only live path: a single forward through the wrapper ---
            x_recon = self.model(x_noisy, t, **cond, return_features=return_features, z_ref=z_ref,
                                 task=self.task, _trainer=self.trainer,
                                 )
        if return_features:
            return x_recon
        if isinstance(x_recon, tuple) and not return_ids:
            return x_recon[0]
        else:
            return x_recon
|
|
|
|
    def p_losses(self, x_start, cond, t, noise=None, z_ref=None, gt512=None, gt256=None, task=None,
                 batch :dict = None,
                 ):
        """Compute the training loss at timesteps `t`.

        Combines the standard simple/VLB diffusion losses on the model's eps/x0
        prediction with optional image-space losses (ID, CLIP, LPIPS, pixel
        reconstruction), computed on differentiable decodes of short DDIM
        reconstructions when `DDIM_losses` is enabled.

        Args:
            x_start: latent input; for the 'inpaint' key, channels [0:4] hold the
                image latent and channels [4:] are conditioning planes.
            cond: cross-attention conditioning.
            t: per-sample diffusion timesteps.
            noise: optional pre-drawn noise; defaults to fresh Gaussian noise.
            z_ref: optional reference latent for the RefNet branch.
            gt512 / gt256: ground-truth images for the image-space losses.
            task: integer task id selecting per-task loss weights.
            batch: full batch dict (accepted for interface parity).

        Returns:
            (loss, loss_dict) — total loss plus per-term logging scalars.
        """

        # Reset the MoE auxiliary-loss accumulator; model layers add to it in forward.
        global_.moe_aux_loss = torch.tensor(0.0, device=self.device)
        if self.first_stage_key == 'inpaint':

            # Only the 4 image-latent channels get noised; the conditioning
            # channels [4:] are concatenated back unchanged.
            noise = default(noise, lambda: torch.randn_like(x_start[:,:4,:,:]))
            if 1:
                x_noisy = self.q_sample(x_start=x_start[:,:4,:,:], t=t, noise=noise)
                x_noisy = torch.cat((x_noisy,x_start[:,4:,:,:]),dim=1)
        else:
            noise = default(noise, lambda: torch.randn_like(x_start))
            if 1:
                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        if z_ref is not None:
            assert self.first_stage_key == 'inpaint', 'Expected first_stage_key to be "inpaint"'
            """
            z_ref: b,4,...
            z_ref = concat [z_ref_noisy, z_ref, tensor_1c]
            tensor_1c is temporarily set to all zeros
            """
            z_ref_noisy = self.q_sample(x_start=z_ref, t=t, noise=torch.randn_like(z_ref))
            tensor_1c = torch.zeros((z_ref.shape[0], 1, z_ref.shape[2], z_ref.shape[3]), device=z_ref.device)
            if REFNET.CH9:
                # 9-channel RefNet input: [noisy ref latent, clean ref latent, zero plane].
                z_ref = torch.cat([z_ref_noisy, z_ref, tensor_1c], dim=1)
        if 1:
            model_output = self.apply_model(x_noisy, t, cond, z_ref=z_ref, )

        loss_dict = {}
        prefix = 'train' if self.training else 'val'
        if DDIM_losses:
            # --- image-space losses via a short, differentiable DDIM reconstruction ---
            # NOTE(review): this randint range is 1 wide, so t_new is always
            # num_timesteps-1 (reconstruction starts from near-pure noise) — confirm intended.
            t_new = torch.randint(self.num_timesteps-1, self.num_timesteps, (x_start.shape[0],), device=self.device).long().to(self.device)

            x_noisy_rec = self.q_sample(x_start=x_start[:,:4,:,:], t=t_new, noise=noise)
            x_noisy_rec = torch.cat((x_noisy_rec,x_start[:,4:,:,:]),dim=1)

            # DDIM sampling hyper-parameters for the reconstruction pass.
            ddim_steps=self.Reconstruct_DDIM_steps
            n_samples=x_noisy_rec.shape[0]
            shape=(4,64,64)
            scale=5
            ddim_eta=0.0
            start_code=x_noisy_rec
            test_model_kwargs=None

            samples_ddim, sample_intermediates = self.sampler.sample_train(S=ddim_steps,
                                                             conditioning=cond,
                                                             batch_size=n_samples,
                                                             shape=shape,
                                                             verbose=False,
                                                             unconditional_guidance_scale=scale,
                                                             unconditional_conditioning=None,
                                                             eta=ddim_eta,
                                                             x_T=start_code,
                                                             t=t_new,
                                                             z_ref=z_ref,
                                                             test_model_kwargs=test_model_kwargs)

            # Decode every intermediate x0 prediction WITH gradients so the
            # image-space losses below can backprop into the diffusion model.
            other_pred_x_0=sample_intermediates['pred_x0']
            len_inter = len(other_pred_x_0)
            printC("len_inter", len_inter )
            for i in range(len(other_pred_x_0)):
                other_pred_x_0[i]=self.differentiable_decode_first_stage(other_pred_x_0[i])

            # Image-space loss accumulators (stay 0 when their weight is 0).
            ID_loss=0
            clip_loss=0
            loss_lpips=0
            loss_rec=0
            loss_landmark=0

            if 1:

                if 0:
                    # (disabled) restrict image losses to the inpaint-mask region
                    inpaint_mask_64 = x_start[:,8,:,:]
                    masks=TF.resize(inpaint_mask_64,(other_pred_x_0[0].shape[2],other_pred_x_0[0].shape[3]))
                    if not 1:
                        masks = 1 - masks

                    x_samples_ddim_masked=[x_samples_ddim_preds*masks.unsqueeze(1) for x_samples_ddim_preds in other_pred_x_0]

                else:
                    x_samples_ddim_masked = other_pred_x_0
                # Per-task loss weights, indexed by `task`.
                Landmark_loss_weight = 0
                ID_loss_weight = [0.3, 0, 0.1, 0.2, ][task]
                if ID_loss_weight > 0 :
                    # Face-identity loss against gt512, averaged over DDIM intermediates.
                    ID_Losses=[]
                    for step,x_samples_ddim_preds in enumerate(x_samples_ddim_masked):
                        ID_loss,sim_imp,_=self.face_ID_model(x_samples_ddim_preds,gt512,clip_img=False)
                        ID_Losses.append(ID_loss)
                        loss_dict.update({f'{prefix}/ID_loss_{step}': ID_loss})

                    ID_loss=torch.mean(torch.stack(ID_Losses))
                    loss_dict.update({f'{prefix}/ID_loss': ID_loss})
                    loss_dict.update({f'{prefix}/sim_imp': sim_imp})

                CLIP_loss_weight = [1.5/4, 0.8, 1, 0.5, ][task]
                if CLIP_loss_weight > 0 :
                    def _loss(_img1,_img2):
                        # MSE between CLIP-ViT face embeddings of the two images.
                        _e1 = self.encoder_clip_face.forward_vit(_img1,resize=True)
                        _e2 = self.encoder_clip_face.forward_vit(_img2,resize=True)
                        return torch.nn.functional.mse_loss( _e1, _e2 )
                    clip_Losses=[]
                    for step,x_samples_ddim_preds in enumerate(x_samples_ddim_masked):
                        clip_loss = _loss(x_samples_ddim_preds,gt512)
                        clip_Losses.append(clip_loss)
                        loss_dict.update({f'{prefix}/clip_loss_{step}': clip_loss})
                    clip_loss=torch.mean(torch.stack(clip_Losses))
                    loss_dict.update({f'{prefix}/clip_loss': clip_loss})

                LPIPS_loss_weight = [0.05, 0.015, 0.015, 0.015, ][task]
                if LPIPS_loss_weight>0:
                    # Prefer the 256px GT when available (cheaper LPIPS pyramid).
                    if gt256 is not None:
                        _lpips_base_size = 256
                        _gt_for_lpips = gt256
                    else:
                        _lpips_base_size = 512
                        _gt_for_lpips = gt512

                    # Multi-scale LPIPS: full, 1/2 and 1/4 resolution, summed.
                    for j in range(len(other_pred_x_0)):
                        for i in range(3):
                            _size = _lpips_base_size//2**i
                            _pred_for_lpips = F.adaptive_avg_pool2d(other_pred_x_0[j],(_size,_size))
                            _gt_for_lpips_resized = F.adaptive_avg_pool2d(_gt_for_lpips,(_size,_size))
                            loss_lpips_1 = self.lpips_loss(
                                _pred_for_lpips,
                                _gt_for_lpips_resized,
                            )
                            loss_dict.update({f'{prefix}/loss_lpips_{j}_{i}': loss_lpips_1})
                            printC(f"loss_lpips_1 at {j} {i} :", loss_lpips_1)
                            loss_lpips += loss_lpips_1
                    loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})

                REC_loss_weight = [0.05, 0.01, 0.01, 0.01, ][task]
                if REC_loss_weight > 0 :
                    # Plain pixel MSE of each decoded intermediate against gt512, summed.
                    for j in range(len(other_pred_x_0)):
                        loss_rec_1 = torch.nn.functional.mse_loss( other_pred_x_0[j], gt512)
                        loss_dict.update({f'{prefix}/loss_rec_{j}': loss_rec_1})
                        printC(f"loss_rec_1 at {j} :", loss_rec_1)
                        loss_rec += loss_rec_1
                    loss_dict.update({f'{prefix}/loss_rec': loss_rec})
        if 1:
            # --- standard diffusion losses on the model prediction ---
            if self.parameterization == "x0":
                target = x_start
            elif self.parameterization == "eps":
                target = noise
            else:
                raise NotImplementedError()

            # Per-sample simple loss with (optionally learned) log-variance weighting.
            loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
            loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
            loss_dict.update({f'{prefix}/loss_simple-t{task}': loss_simple.mean()})

            self.logvar = self.logvar.to(self.device)
            logvar_t = self.logvar[t].to(self.device)
            loss = loss_simple / torch.exp(logvar_t) + logvar_t

            if self.learn_logvar:
                loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
                loss_dict.update({'logvar': self.logvar.data.mean()})

            loss = self.l_simple_weight * loss.mean()

            # Variational lower-bound term, weighted per timestep.
            loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
            loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
            loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
            loss_dict.update({f'{prefix}/loss_vlb-t{task}': loss_vlb})
            loss += (self.original_elbo_weight * loss_vlb)
        else:
            loss = 0  # unreachable (`if 1`), kept from the original toggle
        if DDIM_losses:
            # Log the individual image-space terms, then fold them into the total
            # loss with their per-task weights.
            _item = lambda _a: _a.detach().cpu().item() if isinstance(_a,torch.Tensor) else _a
            printC("orig, ID clip, lpips rec lmk:",
                   f"{_item(loss):.4f}, {_item(ID_loss):.4f} {_item(clip_loss):.4f}, {_item(loss_lpips):.4f} {_item(loss_rec):.4f} {_item(loss_landmark):.4f}",
                   f"{ID_Losses=}" if ID_loss_weight>0 else "",
                   f"{clip_Losses=}" if CLIP_loss_weight>0 else "",
                   )
            loss+=ID_loss_weight*ID_loss+LPIPS_loss_weight*loss_lpips+Landmark_loss_weight*loss_landmark+REC_loss_weight*loss_rec+CLIP_loss_weight*clip_loss

        # Add the MoE routing auxiliary loss accumulated during apply_model.
        moe_aux = global_.moe_aux_loss
        if isinstance(moe_aux, torch.Tensor):
            loss = loss + moe_aux
            loss_dict.update({f'{prefix}/moe_aux_loss': moe_aux})
        loss_dict.update({f'{prefix}/loss': loss})
        loss_dict.update({f'{prefix}/loss-t{task}': loss})
        return loss, loss_dict
|
|
|
|
|
|
    def configure_optimizers(self):
        """Build the optimizer (and optional LR scheduler) for Lightning.

        Collects trainable parameters (optionally restricted under partial
        training), adds the trainable conditioner/projection heads, then splits
        everything into parameter groups with distinct learning rates:
        base / task-specific (* LR_factor) / MoE gate (* MOE_GATE_LR_MULT) /
        MoE experts (* MOE_EP_LR_MULT). Wraps the optimizer in
        ZeroRedundancyOptimizer when ZERO1_ENABLE is set.
        """
        lr = self.learning_rate
        params = list(self.model.parameters())

        if self.partial_training:

            print("Partial training.............................")
            train_names=self.trainable_keys
            # NOTE(review): the next line overwrites `trainable_keys` with a
            # hard-coded list — confirm this is intentional.
            train_names=[ 'attn2','norm2']
            params_train=[]
            for name,param in self.model.named_parameters():
                # Keep everything outside the diffusion model, plus diffusion-model
                # params whose name matches one of train_names.
                if "diffusion_model" not in name and param.requires_grad:
                    print(name)
                    params_train.append(param)

                elif "diffusion_model" in name and any(train_name in name for train_name in train_names):
                    print(name)
                    params_train.append(param)
            params=params_train
        print("Setting up Adam optimizer.......................")

        if self.cond_stage_trainable:
            # Also optimize the trainable parts of the conditioning encoders
            # and their projection heads (only the ones present on this config).
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            if hasattr(self,'encoder_clip_face'):
                params += list(self.encoder_clip_face.final_ln2.parameters())+list(self.encoder_clip_face.mapper2.parameters())
                if self.USE_proj_out_source:
                    params += list(self.proj_out_source__face.parameters())
            if hasattr(self,'encoder_clip_hair'):
                params += list(self.encoder_clip_hair.final_ln2.parameters())+list(self.encoder_clip_hair.mapper2.parameters())
                if self.USE_proj_out_source:
                    params += list(self.proj_out_source__hair.parameters())
            if hasattr(self,'encoder_clip_head_t2'):
                params += list(self.encoder_clip_head_t2.final_ln2.parameters())+list(self.encoder_clip_head_t2.mapper2.parameters())
            if hasattr(self,'encoder_clip_head_t3'):
                params += list(self.encoder_clip_head_t3.final_ln2.parameters())+list(self.encoder_clip_head_t3.mapper2.parameters())
            if hasattr(self,'encoder_clip_head_t2') or hasattr(self,'encoder_clip_head_t3'):
                if self.USE_proj_out_source:
                    params += list(self.proj_out_source__head.parameters())
            if hasattr(self,'ID_proj_out'):
                params += list(self.ID_proj_out.parameters())
            if hasattr(self,'landmark_proj_out'):
                params += list(self.landmark_proj_out.parameters())
        if self.learn_logvar:
            print('Diffusion model optimizing logvar')
            params.append(self.logvar)
        params.extend(self.learnable_vector)
        params = [p for p in params if p.requires_grad]

        # --- tag MoE gate / expert parameters by identity ---
        moe_gate_ids = set()
        moe_ep_ids = set()
        for name, p in self.model.named_parameters():
            if not p.requires_grad:
                continue
            if ".moe_gate_mlp." in name:
                moe_gate_ids.add(id(p))
            elif ".moe_experts_" in name:
                moe_ep_ids.add(id(p))

        # --- tag task-specific parameters (only among already-selected params) ---
        params_ids = set(id(p) for p in params)
        task_specific_ids = set()
        for name, p in self.named_parameters():
            if not p.requires_grad:
                continue
            if id(p) not in params_ids:
                continue
            is_task_specific = is_task_specific_(name)
            if rank_==0: print(f"{is_task_specific=} {name}")
            if is_task_specific:
                task_specific_ids.add(id(p))

        # --- split into disjoint parameter groups ---
        base_params = []
        task_specific_params = []
        moe_gate_params = []
        moe_ep_params = []
        for p in params:
            pid = id(p)
            if pid in task_specific_ids:
                task_specific_params.append(p)
            elif pid in moe_gate_ids:
                moe_gate_params.append(p)
            elif pid in moe_ep_ids:
                moe_ep_params.append(p)
            else:
                base_params.append(p)

        param_groups = []
        if base_params:
            param_groups.append({"params": base_params, "lr": lr})
        if task_specific_params:
            param_groups.append({"params": task_specific_params, "lr": lr * LR_factor})
        if moe_gate_params:
            param_groups.append({"params": moe_gate_params, "lr": lr * MOE_GATE_LR_MULT})
        if moe_ep_params:
            param_groups.append({"params": moe_ep_params, "lr": lr * MOE_EP_LR_MULT})
        if ZERO1_ENABLE:
            # ZeRO stage-1: shard optimizer state across ranks.
            zero_pg = None
            if 1:
                if dist.is_available() and dist.is_initialized():
                    # Separate gloo group for the optimizer's collective ops.
                    zero_pg = dist.new_group(backend='gloo')
            opt = ZeroRedundancyOptimizer(
                param_groups if (task_specific_params or moe_gate_params or moe_ep_params) else params,
                optimizer_class=torch.optim.AdamW if ADAM_or_SGD else torch.optim.SGD,
                lr=lr,
                process_group=zero_pg,
            )
        else:
            if ADAM_or_SGD:
                opt = torch.optim.AdamW(param_groups if (task_specific_params or moe_gate_params or moe_ep_params) else params, lr=lr)
            else:
                opt = torch.optim.SGD(param_groups if (task_specific_params or moe_gate_params or moe_ep_params) else params, lr=lr, momentum=0.9)
        if gate_('LatentDiffusion.configure_optimizers params:'):
            # One-shot debug print of group sizes / parameter counts.
            if (task_specific_params or moe_gate_params or moe_ep_params):
                print(f"base/task_specific/ep/gate lens: {len(base_params)=} {len(task_specific_params)=} {len(moe_ep_params)=} {len(moe_gate_params)=}")
                print(f"sum of .numel(): base={sum(p.numel() for p in base_params)} task_specific={sum(p.numel() for p in task_specific_params)} ep={sum(p.numel() for p in moe_ep_params)} gate={sum(p.numel() for p in moe_gate_params)}")
            else:
                print(f"{len(params)=}")
                print(f"sum of .numel(): {sum(param.numel() for param in params)}")
        if self.use_scheduler:
            assert 'target' in self.scheduler_config
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            # Step-wise LambdaLR driven by the instantiated schedule object.
            scheduler = [
                {
                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [opt], scheduler
        return opt
|
|
    def on_train_epoch_start(self):
        """Epoch hook that can alternately freeze/unfreeze the shared FFN weights.

        Currently a no-op: the bare ``return`` below exits before any of the
        freezing logic runs. The dead code is kept for quick re-enabling.
        """
        def _set_req_grad(p, flag):
            # Flip requires_grad only when it changes; return 1 so toggles can be counted.
            if p.requires_grad != flag:
                p.requires_grad = flag
                return 1
            return 0
        return
        # --- dead code below this point (early return above) ---
        if 0:
            train_now = self.current_epoch < N_EPOCHS_TRAIN_REF_AND_MID
        else:
            # Alternate epochs: train shared weights on even epochs only.
            train_now = (self.current_epoch % 2 == 0)
        ct_toggled = 0

        ct_shared = 0
        for name, p in self.model.diffusion_model.named_parameters():

            is_shared = ('.shared_ffn.' in name) or ('.shared.' in name)
            if is_shared:
                ct_shared += _set_req_grad(p, train_now)
        print(f"[freeze@epoch]{self.current_epoch=} {train_now=} {ct_toggled=} {ct_shared=}")
|
|
| @torch.no_grad() |
| def to_rgb(self, x): |
| x = x.float() |
| if not hasattr(self, "colorize"): |
| self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) |
| x = nn.functional.conv2d(x, weight=self.colorize) |
| x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. |
| return x |
| def __repr__(self): |
| if DEBUG: return 'LatentDiffusion.__repr__' |
| return super().__repr__() |
| @property |
| def model_size(self): |
| if DEBUG: return -1 |
| return super().model_size |
|
|
|
|
| from .bank import Bank |
class DiffusionWrapper(pl.LightningModule):
    """Routes conditioning into the main UNet and, when REFNET is enabled, runs a
    truncated reference UNet ("RefNet") beforehand whose features are exchanged
    through `self.bank`."""
    def __init__(self, diff_model_config, conditioning_key):
        """Instantiate the main UNet (and, when REFNET.ENABLE, a truncated RefNet).

        Args:
            diff_model_config: instantiation config for the UNet.
                NOTE(review): this dict is mutated in place (is_refNet,
                in_channels), and the RefNet "config" below is an alias of it,
                not a copy — a caller re-reading the config afterwards sees the
                RefNet values. Confirm this is intentional.
            conditioning_key: one of None/'concat'/'crossattn'/'hybrid'/'adm'.
        """
        super().__init__()
        diff_model_config['params']['is_refNet'] = False
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
        if REFNET.ENABLE:
            # Alias, not a copy: mutations below also hit diff_model_config itself.
            diff_model_config_refNet = diff_model_config
            print('instantiate / deepcopy diffusion_model_refNet ing...')
            if 1:
                diff_model_config_refNet['params']['in_channels'] = 9 if REFNET.CH9 else 4
                diff_model_config_refNet['params']['is_refNet'] = True
                self.diffusion_model_refNet :UNetModel = instantiate_from_config(diff_model_config_refNet)
            else:
                # (dead branch) alternative: clone the main UNet instead of re-instantiating.
                self.diffusion_model_refNet :UNetModel = copy.deepcopy(self.diffusion_model)
                self.diffusion_model_refNet.is_refNet = True
            if 1:

                if 1:
                    # The RefNet only needs its encoder side: keep the first 9
                    # input blocks and drop the middle/output stages entirely.
                    self.diffusion_model_refNet.input_blocks = self.diffusion_model_refNet.input_blocks[:9]
                    del self.diffusion_model_refNet.middle_block
                    del self.diffusion_model_refNet.output_blocks
                    del self.diffusion_model_refNet.out
            print('over.')

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None,return_features=False,
                z_ref=None,
                task = None,
                _trainer :pl.Trainer = None,
                ):
        """Dispatch on `conditioning_key` and run the UNet (RefNet first, if active).

        Only 'crossattn' is reachable — the assert below pins it; the other
        dispatch branches are retained from the original implementation.
        """
        _in_train_or_val = ( _trainer is not None ) and ( _trainer.validating or _trainer.sanity_checking )
        assert self.conditioning_key == 'crossattn'
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            if (REFNET.ENABLE and REFNET.task2layerNum[task]>0):
                if task in (0,2,3,):
                    # These tasks drop the last context token for the RefNet pass.
                    cc_ref = cc[:,:-1, :]
                else:
                    cc_ref = cc
                printC("c for refNet",f"{custom_repr_v3(cc_ref)}")
                # RefNet forward; presumably fills self.bank with reference
                # features consumed by the main UNet — confirm in attention layers.
                self.diffusion_model_refNet(z_ref, t, context=cc_ref,return_features=False)
            out = self.diffusion_model(x, t, context=cc,return_features=return_features)
            if (REFNET.ENABLE and REFNET.task2layerNum[task]>0) and not (self.training or _in_train_or_val):
                # Pure inference: clear the reference features after each call.
                self.bank.clear()
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        else:
            raise NotImplementedError()

        return out
|
|