from typing import Optional, Tuple, Callable, List
import math

import torch
import torch.nn.functional as F
from tqdm import tqdm

from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.vq_model import VQModel
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from diffusers.pipelines.pipeline_utils import DiffusionPipeline

from src.smc.transformer import Transformer2DModel
from src.smc.scheduler import BaseScheduler
from src.smc.resampling import compute_ess_from_log_w, normalize_weights
from src.smc.lora_pipeline import MeissonicLoraLoaderMixin


def logmeanexp(x, dim=None, keepdim=False):
    """Numerically stable log-mean-exp using torch.logsumexp."""
    if dim is None:
        x = x.view(-1)
        dim = 0
    # log-sum-exp with or without keeping the reduced dim
    lse = torch.logsumexp(x, dim=dim, keepdim=keepdim)
    # subtract log(N) to convert the sum into a mean (broadcasts correctly)
    return lse - math.log(x.size(dim))


def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
    """
    Build positional IDs for latent-image tokens.

    Each latent token corresponds to a downsampled image "pixel" in a 2D grid.
    This function creates a (H//2, W//2, 3) grid where:
      - channel 0 is reserved (all zeros)
      - channel 1 stores the row index (0 .. H//2 - 1)
      - channel 2 stores the column index (0 .. W//2 - 1)

    Args:
        batch_size (int): Number of images in the batch (unused here, but kept for API consistency).
        height (int): Input image height (pre-VAE) or latent height, depending on the call site.
        width (int): Input image width (pre-VAE) or latent width, depending on the call site.
        device (torch.device): Device on which to place the returned tensor.
        dtype (torch.dtype): Desired data type of the returned tensor.

    Returns:
        torch.Tensor of shape ((H//2 * W//2), 3) with the specified dtype and device.
        Each row is [0, row_index, col_index], flattened in row-major order.
    """
    latent_image_ids = torch.zeros(height // 2, width // 2, 3)
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
    latent_image_ids = latent_image_ids.reshape(
        latent_image_id_height * latent_image_id_width, latent_image_id_channels
    )
    return latent_image_ids.to(device=device, dtype=dtype)


class Pipeline(
    DiffusionPipeline,
    MeissonicLoraLoaderMixin,
):
    image_processor: VaeImageProcessor
    vqvae: VQModel
    tokenizer: CLIPTokenizer
    text_encoder: CLIPTextModelWithProjection
    transformer: Transformer2DModel
    scheduler: BaseScheduler

    model_cpu_offload_seq = "text_encoder->transformer->vqvae"

    def __init__(
        self,
        vqvae: VQModel,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModelWithProjection,
        transformer: Transformer2DModel,
        scheduler: BaseScheduler,
    ):
        super().__init__()
        self.register_modules(
            vqvae=vqvae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)  # type: ignore
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
        self.model_dtype = torch.bfloat16
        self.mask_index = self.scheduler.mask_token_id  # type: ignore
        self.vocab_size = self.transformer.config.vocab_size  # type: ignore
        self.codebook_size = self.transformer.config.codebook_size  # type: ignore

    def __call__(
        self,
        prompt: str | List[str],
        reward_fn: Callable,
        resample_fn: Callable,
        resample_frequency: int = 1,
        kl_weight: float = 1.0,
        lambdas: Optional[torch.Tensor] = None,
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        num_inference_steps: int = 48,
        guidance_scale: float = 9.0,
        negative_prompt=None,
        batches: int = 1,  # number of independent SMC runs
        num_particles: int = 1,  # number of particles per SMC run
        batch_p: int = 1,  # number of particles processed in parallel
        phi: int = 1,  # number of samples for the reward approximation
        tau: float = 1.0,  # temperature for drawing x0 samples
        output_type="pil",
        micro_conditioning_aesthetic_score: int = 6,
        micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
        proposal_type: str = "locally_optimal",
        ft_model_pipe=None,  # must be supplied if proposal_type is "ft_model"
        use_ft_model_for_expected_reward: bool = False,  # whether to use the fine-tuned model for the expected reward
        use_continuous_formulation: bool = False,  # whether to use a continuous formulation of carry-over unmasking
        disable_progress_bar: bool = False,
        final_strategy="argmax_rewards",
        verbose=True,
    ):
        # 0. Set default lambdas
        if lambdas is None:
            lambdas = torch.ones(num_inference_steps + 1)
        assert len(lambdas) == num_inference_steps + 1, f"lambdas must be of length {num_inference_steps + 1}"
        lambdas = lambdas.clamp_min(0.001).to(self._execution_device)

        # 1. n_particles, batch_size, etc.
        total_particles = batches * num_particles
        batch_p = min(batch_p, total_particles)
        H, W = height // self.vae_scale_factor, width // self.vae_scale_factor

        # 2.1. Calculate prompt (and negative prompt) embeddings
        if isinstance(prompt, str):
            prompt = [prompt]
        input_ids = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=77,
        ).input_ids.to(self._execution_device)
        outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
        prompt_embeds = outputs.text_embeds
        encoder_hidden_states = outputs.hidden_states[-2]
        prompt_embeds = prompt_embeds.repeat(batch_p, 1)
        encoder_hidden_states = encoder_hidden_states.repeat(batch_p, 1, 1)

        if guidance_scale > 1.0:
            if negative_prompt is None:
                negative_prompt = [""]
            else:
                negative_prompt = [negative_prompt]
            input_ids = self.tokenizer(
                negative_prompt,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=77,
            ).input_ids.to(self._execution_device)
            outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
            negative_prompt_embeds = outputs.text_embeds
            negative_encoder_hidden_states = outputs.hidden_states[-2]
            negative_prompt_embeds = negative_prompt_embeds.repeat(batch_p, 1)
            negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(batch_p, 1, 1)
            prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
            encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])

        # 2.2. Prepare micro-conditions
        micro_conds = torch.tensor(
            [
                width,
                height,
                micro_conditioning_crop_coord[0],
                micro_conditioning_crop_coord[1],
                micro_conditioning_aesthetic_score,
            ],
            device=self._execution_device,
            dtype=encoder_hidden_states.dtype,
        )
        micro_conds = micro_conds.unsqueeze(0)
        micro_conds = micro_conds.expand(2 * batch_p if guidance_scale > 1.0 else batch_p, -1)

        # 3. Initialize latents
        latents = torch.full(
            (total_particles, H, W), self.mask_index, dtype=torch.long, device=self._execution_device  # type: ignore
        )
        # Set some constant vectors
        ONE = torch.ones(self.vocab_size, device=self._execution_device).float()
        MASK = F.one_hot(torch.tensor(self.mask_index), num_classes=self.vocab_size).float().to(self._execution_device)  # type: ignore

        # 4. Set scheduler timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # 5. Set SMC variables
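        # Per-particle SMC state, flattened over (batches * num_particles):
        #   logits / logits_ft_model: denoiser logits of shape (particles, H, W, vocab_size)
        #   rewards, log_twist: scalar reward and tempered twist value per particle
        #   log_prob_proposal / log_prob_diffusion: transition log-probs of the last move
        #   log_w: accumulated (unnormalized) importance log-weights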
        logits = torch.zeros((*latents.shape, self.vocab_size), device=self._execution_device)
        logits_ft_model = torch.zeros((*latents.shape, self.vocab_size), device=self._execution_device)
        rewards = torch.zeros((total_particles,), device=self._execution_device)
        rewards_grad = torch.zeros((*latents.shape, self.vocab_size), device=self._execution_device)
        log_twist = torch.zeros((total_particles,), device=self._execution_device)
        log_prob_proposal = torch.zeros((total_particles,), device=self._execution_device)
        log_prob_diffusion = torch.zeros((total_particles,), device=self._execution_device)
        log_w = torch.zeros((total_particles,), device=self._execution_device)

        def propagate():
            if proposal_type == "locally_optimal":
                propagate_locally_optimal()
            # elif proposal_type == "straight_through_gradients":
            #     propagate_straight_through_gradients()
            elif proposal_type == "reverse":
                propagate_reverse()
            elif proposal_type == "without_SMC":
                propagate_without_SMC()
            elif proposal_type == "ft_model":
                propagate_ft_model()
            else:
                raise NotImplementedError(f"Proposal type {proposal_type} is not implemented.")

        def propagate_locally_optimal():
            nonlocal log_w, latents, log_prob_proposal, log_prob_diffusion, logits, rewards, rewards_grad, log_twist
            log_twist_prev = log_twist.clone()
            for j in range(0, total_particles, batch_p):
                latents_batch = latents[j:j + batch_p]
                with torch.enable_grad():
                    latents_one_hot = F.one_hot(latents_batch, num_classes=self.vocab_size).to(dtype=self.model_dtype).requires_grad_(True)
                    tmp_logits = self.get_logits(latents_one_hot, guidance_scale, height, encoder_hidden_states, micro_conds, prompt_embeds, timestep)
                    tmp_rewards = torch.zeros(latents_batch.size(0), phi, device=self._execution_device)
                    gamma = 1 - ((ONE - MASK) * latents_one_hot).sum(dim=-1, keepdim=True)
                    for phi_i in range(phi):
                        sample = F.gumbel_softmax(tmp_logits, tau=tau, hard=True)
                        if use_continuous_formulation:
                            sample = gamma * sample + (ONE - MASK) * latents_one_hot
                        sample = self._decode_one_hot_latents(sample, height, width, "pt")
                        tmp_rewards[:, phi_i] = reward_fn(sample)
                    tmp_rewards = logmeanexp(tmp_rewards * scale_cur, dim=-1) / scale_cur
                    tmp_rewards_grad = torch.autograd.grad(
                        outputs=tmp_rewards,
                        inputs=latents_one_hot,
                        grad_outputs=torch.ones_like(tmp_rewards),
                    )[0].detach()
                logits[j:j + batch_p] = tmp_logits.detach()
                rewards[j:j + batch_p] = tmp_rewards.detach()
                rewards_grad[j:j + batch_p] = tmp_rewards_grad.detach()
                log_twist[j:j + batch_p] = rewards[j:j + batch_p] * scale_cur
            if verbose:
                print("Rewards: ", rewards)
            # Calculate weights
            incremental_log_w = (log_prob_diffusion - log_prob_proposal) + (log_twist - log_twist_prev)
            log_w += incremental_log_w
            # Now reshape log_w to (batches, num_particles)
            log_w = log_w.reshape(batches, num_particles)
            if verbose:
                print("log_prob_diffusion - log_prob_proposal: ", log_prob_diffusion - log_prob_proposal)
                print("log_twist - log_twist_prev:", log_twist - log_twist_prev)
                print("Incremental log weights: ", incremental_log_w)
                print("Log weights: ", log_w)
                print("Normalized weights: ", normalize_weights(log_w, dim=-1))
            # Resample particles
            if verbose:
                print("ESS: ", compute_ess_from_log_w(log_w, dim=-1))
            if resample_condition:
                resample_indices = []
                log_w_new = []
                is_resampled = False
                for batch in range(batches):
                    resample_indices_batch, is_resampled_batch, log_w_batch = resample_fn(log_w[batch])
                    resample_indices.append(resample_indices_batch + batch * num_particles)
                    log_w_new.append(log_w_batch)
                    is_resampled = is_resampled or is_resampled_batch
                resample_indices = torch.cat(resample_indices, dim=0)
                log_w = torch.cat(log_w_new, dim=0)
                if is_resampled:
                    latents = latents[resample_indices]
                    logits = logits[resample_indices]
                    rewards = rewards[resample_indices]
                    rewards_grad = rewards_grad[resample_indices]
                    log_twist = log_twist[resample_indices]
                    if verbose:
                        print("Resample indices: ", resample_indices)
            if log_w.ndim == 2:
                log_w = log_w.reshape(total_particles)
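
            # Locally-optimal proposal: the scheduler adds the reward gradient (scaled
            # by scale_next) to the denoiser logits as approximate guidance, and returns
            # the transition log-probabilities under both the proposal and the unguided
            # diffusion model for the next weight update.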
            # Propose new particles
            sched_out = self.scheduler.step_with_approx_guidance(
                latents=latents,
                logits=logits,
                approx_guidance=rewards_grad * scale_next,
                step=i,
            )
            if verbose:
                print("Approx guidance norm: ", ((rewards_grad * scale_next) ** 2).sum(dim=(1, 2)).sqrt())
            latents, log_prob_proposal, log_prob_diffusion = (
                sched_out.new_latents,
                sched_out.log_prob_proposal,
                sched_out.log_prob_diffusion,
            )

        def propagate_reverse():
            nonlocal log_w, latents, logits, rewards, log_twist
            log_twist_prev = log_twist.clone()
            for j in range(0, total_particles, batch_p):
                latents_batch = latents[j:j + batch_p]
                with torch.no_grad():
                    tmp_logits = self.get_logits(latents_batch, guidance_scale, height, encoder_hidden_states, micro_conds, prompt_embeds, timestep)
                    tmp_rewards = torch.zeros(latents_batch.size(0), phi, device=self._execution_device)
                    tmp_logp_x0 = self._subs_parameterization(tmp_logits, latents_batch)
                    for phi_i in range(phi):
                        sample = F.gumbel_softmax(tmp_logp_x0, tau=tau, hard=True).argmax(dim=-1)
                        sample = self._decode_latents(sample, height, width, "pt")
                        tmp_rewards[:, phi_i] = reward_fn(sample)
                    tmp_rewards = logmeanexp(tmp_rewards * scale_cur, dim=-1) / scale_cur
                logits[j:j + batch_p] = tmp_logits.detach()
                rewards[j:j + batch_p] = tmp_rewards.detach()
                log_twist[j:j + batch_p] = rewards[j:j + batch_p] * scale_cur
            if verbose:
                print("Rewards: ", rewards)
            # Calculate weights
            incremental_log_w = (log_twist - log_twist_prev)
            log_w += incremental_log_w
            # Now reshape log_w to (batches, num_particles)
            log_w = log_w.reshape(batches, num_particles)
            if verbose:
                print("log_twist - log_twist_prev:", log_twist - log_twist_prev)
                print("Incremental log weights: ", incremental_log_w)
                print("Log weights: ", log_w)
                print("Normalized weights: ", normalize_weights(log_w, dim=-1))

            # Resample particles
            if verbose:
                print("ESS: ", compute_ess_from_log_w(log_w, dim=-1))
            if resample_condition:
                resample_indices = []
                log_w_new = []
                is_resampled = False
                for batch in range(batches):
                    resample_indices_batch, is_resampled_batch, log_w_batch = resample_fn(log_w[batch])
                    resample_indices.append(resample_indices_batch + batch * num_particles)
                    log_w_new.append(log_w_batch)
                    is_resampled = is_resampled or is_resampled_batch
                resample_indices = torch.cat(resample_indices, dim=0)
                log_w = torch.cat(log_w_new, dim=0)
                if is_resampled:
                    latents = latents[resample_indices]
                    logits = logits[resample_indices]
                    rewards = rewards[resample_indices]
                    log_twist = log_twist[resample_indices]
                    if verbose:
                        print("Resample indices: ", resample_indices)
            if log_w.ndim == 2:
                log_w = log_w.reshape(total_particles)

            # Propose new particles
            sched_out = self.scheduler.step(
                latents=latents,
                logits=logits,
                step=i,
            )
            latents = sched_out.new_latents

        def propagate_without_SMC():
            nonlocal latents, logits
            for j in range(0, total_particles, batch_p):
                latents_batch = latents[j:j + batch_p]
                with torch.no_grad():
                    tmp_logits = self.get_logits(latents_batch, guidance_scale, height, encoder_hidden_states, micro_conds, prompt_embeds, timestep)
                logits[j:j + batch_p] = tmp_logits.detach()

            # Propose new particles
            sched_out = self.scheduler.step(
                latents=latents,
                logits=logits,
                step=i,
            )
            latents = sched_out.new_latents

        def propagate_ft_model():
            assert ft_model_pipe is not None, f"ft_model_pipe must be provided for proposal_type={proposal_type}."
            nonlocal log_w, latents, log_prob_proposal, log_prob_diffusion, logits, logits_ft_model, rewards, log_twist
            log_twist_prev = log_twist.clone()
            for j in range(0, total_particles, batch_p):
                latents_batch = latents[j:j + batch_p]
                with torch.no_grad():
                    tmp_logits = self.get_logits(latents_batch, guidance_scale, height, encoder_hidden_states, micro_conds, prompt_embeds, timestep)
                    tmp_logits_ft_model = ft_model_pipe.get_logits(latents_batch, guidance_scale, height, encoder_hidden_states, micro_conds, prompt_embeds, timestep)
                    tmp_rewards = torch.zeros(latents_batch.size(0), phi, device=self._execution_device)
                    if use_ft_model_for_expected_reward:
                        tmp_logp_x0 = ft_model_pipe._subs_parameterization(tmp_logits_ft_model, latents_batch)
                    else:
                        tmp_logp_x0 = self._subs_parameterization(tmp_logits, latents_batch)
                    for phi_i in range(phi):
                        sample = F.gumbel_softmax(tmp_logp_x0, tau=tau, hard=True).argmax(dim=-1)
                        sample = self._decode_latents(sample, height, width, "pt")
                        tmp_rewards[:, phi_i] = reward_fn(sample)
                    tmp_rewards = logmeanexp(tmp_rewards * scale_cur, dim=-1) / scale_cur
                logits[j:j + batch_p] = tmp_logits.detach()
                logits_ft_model[j:j + batch_p] = tmp_logits_ft_model.detach()
                rewards[j:j + batch_p] = tmp_rewards.detach()
                log_twist[j:j + batch_p] = rewards[j:j + batch_p] * scale_cur
            if verbose:
                print("Rewards: ", rewards)

            # Calculate weights
            incremental_log_w = (log_prob_diffusion - log_prob_proposal) + (log_twist - log_twist_prev)
            log_w += incremental_log_w
            # Now reshape log_w to (batches, num_particles)
            log_w = log_w.reshape(batches, num_particles)
            if verbose:
                print("log_prob_diffusion - log_prob_proposal: ", log_prob_diffusion - log_prob_proposal)
                print("log_twist - log_twist_prev:", log_twist - log_twist_prev)
                print("Incremental log weights: ", incremental_log_w)
                print("Log weights: ", log_w)
                print("Normalized weights: ", normalize_weights(log_w, dim=-1))

            # Resample particles
            if verbose:
                print("ESS: ", compute_ess_from_log_w(log_w, dim=-1))
            if resample_condition:
                resample_indices = []
                log_w_new = []
                is_resampled = False
                for batch in range(batches):
                    resample_indices_batch, is_resampled_batch, log_w_batch = resample_fn(log_w[batch])
                    resample_indices.append(resample_indices_batch + batch * num_particles)
                    log_w_new.append(log_w_batch)
                    is_resampled = is_resampled or is_resampled_batch
                resample_indices = torch.cat(resample_indices, dim=0)
                log_w = torch.cat(log_w_new, dim=0)
                if is_resampled:
                    latents = latents[resample_indices]
                    logits = logits[resample_indices]
                    logits_ft_model = logits_ft_model[resample_indices]
                    rewards = rewards[resample_indices]
                    log_twist = log_twist[resample_indices]
                    if verbose:
                        print("Resample indices: ", resample_indices)
            if log_w.ndim == 2:
                log_w = log_w.reshape(total_particles)
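
            # The guidance term below cancels `logits` inside the scheduler, so sampling
            # follows the fine-tuned model while `logits` still defines the diffusion
            # (target) transition used for the weights. Positions beyond the codebook are
            # -inf in both, so their difference is zeroed to avoid inf - inf = NaN.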
            # Propose new particles
            approx_guidance = logits_ft_model - logits  # this effectively makes logits_ft_model the proposal distribution
            approx_guidance[..., self.codebook_size:] = 0.0  # avoid NaN due to (inf - inf)
            sched_out = self.scheduler.step_with_approx_guidance(
                latents=latents,
                logits=logits,
                approx_guidance=approx_guidance,
                step=i,
            )
            latents, log_prob_proposal, log_prob_diffusion = (
                sched_out.new_latents,
                sched_out.log_prob_proposal,
                sched_out.log_prob_diffusion,
            )
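
        # Main SMC loop over the reverse-diffusion timesteps: at each step the selected
        # proposal computes rewards/twists, updates the importance weights, optionally
        # resamples, and then propagates the particles one step.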
        bar = enumerate(reversed(range(num_inference_steps)))
        if not disable_progress_bar:
            bar = tqdm(bar, leave=False)
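        # lambdas[i] / kl_weight tempers the reward in the current twist; lambdas[i + 1]
        # scales the approximate guidance used to propose the next state.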
        for i, timestep in bar:
            resample_condition = (i + 1) % resample_frequency == 0
            scale_cur = lambdas[i] / kl_weight
            scale_next = lambdas[i + 1] / kl_weight
            if verbose:
                print(f"scale_cur: {scale_cur}, scale_next: {scale_next}")
            propagate()
            if verbose:
                print("\n\n")

        # Final SMC weights
        scale_cur = lambdas[-1] / kl_weight
        log_twist_prev = log_twist.clone()
        for j in range(0, total_particles, batch_p):
            latents_batch = latents[j:j + batch_p]
            with torch.no_grad():
                sample = self._decode_latents(latents_batch, height, width, "pt")
                tmp_rewards = reward_fn(sample)
            rewards[j:j + batch_p] = tmp_rewards
            log_twist[j:j + batch_p] = tmp_rewards * scale_cur
        if verbose:
            print("Rewards: ", rewards)

        # Calculate weights
        incremental_log_w = (log_prob_diffusion - log_prob_proposal) + (log_twist - log_twist_prev)
        log_w += incremental_log_w

        # Now reshape everything to (batches, num_particles) for final strategy
        log_w = log_w.reshape(batches, num_particles)
        latents = latents.reshape(batches, num_particles, H, W)
        rewards = rewards.reshape(batches, num_particles)
        if verbose:
            print("log_prob_diffusion - log_prob_proposal: ", log_prob_diffusion - log_prob_proposal)
            print("log_twist - log_twist_prev:", log_twist - log_twist_prev)
            print("Incremental log weights: ", incremental_log_w)
            print("Log weights: ", log_w)
            print("Normalized weights: ", normalize_weights(log_w, dim=-1))
        if final_strategy == "multinomial":
            final_indices = torch.multinomial(normalize_weights(log_w, dim=-1), num_samples=1).squeeze(-1)
        elif final_strategy == "argmax_rewards":
            final_indices = rewards.argmax(dim=-1)
        elif final_strategy == "argmax_weights":
            final_indices = log_w.argmax(dim=-1)
        else:
            raise NotImplementedError(f"Final strategy {final_strategy} is not implemented.")
        if verbose:
            print("Final selected indices: ", final_indices)
        latents = latents[
            torch.arange(batches, device=latents.device),
            final_indices,
        ]

        # Decode latents
        outputs = []
        for j in range(0, batches, batch_p):
            latents_batch = latents[j:j + batch_p]
            outputs.extend(
                self._decode_latents(latents_batch, height, width, output_type)  # type: ignore
            )
        if output_type == "pt":
            outputs = torch.stack(outputs, dim=0)
        return outputs

    def get_logits(self, latents, guidance_scale, resolution, encoder_hidden_states, micro_conds, prompt_embeds, timestep):
        if guidance_scale > 1.0:
            # Latents are duplicated to get both unconditional and conditional logits
            model_input = torch.cat([latents] * 2)  # type: ignore
        else:
            model_input = latents

        # img_ids and txt_ids are used for positional embeddings
        if resolution == 1024:
            img_ids = _prepare_latent_image_ids(model_input.shape[0], model_input.shape[1], model_input.shape[2], model_input.device, model_input.dtype)
        else:
            img_ids = _prepare_latent_image_ids(model_input.shape[0], 2 * model_input.shape[1], 2 * model_input.shape[2], model_input.device, model_input.dtype)
        txt_ids = torch.zeros(encoder_hidden_states.shape[1], 3).to(device=encoder_hidden_states.device, dtype=encoder_hidden_states.dtype)

        if prompt_embeds.shape[0] != model_input.shape[0]:
            # This can happen for the last batch (when total_particles is not divisible by batch_p)
            if guidance_scale > 1.0:
                batch_p = prompt_embeds.shape[0] // 2
                last_batch_size = model_input.shape[0] // 2
                prompt_embeds = torch.cat([prompt_embeds[:last_batch_size], prompt_embeds[batch_p:batch_p + last_batch_size]])
                encoder_hidden_states = torch.cat([encoder_hidden_states[:last_batch_size], encoder_hidden_states[batch_p:batch_p + last_batch_size]])
                micro_conds = torch.cat([micro_conds[:last_batch_size], micro_conds[batch_p:batch_p + last_batch_size]])
            else:
                last_batch_size = model_input.shape[0]
                prompt_embeds = prompt_embeds[:last_batch_size]
                encoder_hidden_states = encoder_hidden_states[:last_batch_size]
                micro_conds = micro_conds[:last_batch_size]

        model_output = self.transformer(
            hidden_states=model_input,
            micro_conds=micro_conds,
            pooled_projections=prompt_embeds,
            encoder_hidden_states=encoder_hidden_states,
            img_ids=img_ids,
            txt_ids=txt_ids,
            timestep=torch.tensor([timestep], device=model_input.device, dtype=torch.long),
        )
        if guidance_scale > 1.0:
            uncond_logits, cond_logits = model_output.chunk(2)
            model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
        tmp_logits = torch.permute(model_output, (0, 2, 3, 1)).float()
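        # Pad the logits from codebook_size up to vocab_size with -inf so that token ids
        # outside the codebook (such as the mask token) can never be sampled.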
        pad_logits = torch.full(
            (*tmp_logits.shape[:3], self.vocab_size - self.codebook_size),
            -torch.inf,
            device=tmp_logits.device, dtype=tmp_logits.dtype,
        )
        tmp_logits = torch.cat([tmp_logits, pad_logits], dim=-1)
        return tmp_logits

    def _decode_latents(self, latents, height, width, output_type):
        batch_size = latents.shape[0]
        if output_type == "latent":
            output = latents
        else:
            needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast  # type: ignore
            if needs_upcasting:
                self.vqvae.float()
            output = self.vqvae.decode(
                latents,
                force_not_quantize=True,
                shape=(
                    batch_size,
                    height // self.vae_scale_factor,
                    width // self.vae_scale_factor,
                    self.vqvae.config.latent_channels,  # type: ignore
                ),
            ).sample.clip(0, 1)  # type: ignore
            output = self.image_processor.postprocess(output, output_type)
            if needs_upcasting:
                self.vqvae.half()
        return output

    def _decode_one_hot_latents(self, latents_one_hot, height, width, output_type):
        batch_size = latents_one_hot.shape[0]
        shape = (
            batch_size,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
            self.vqvae.config.latent_channels,  # type: ignore
        )
        codebook_size = self.transformer.config.codebook_size  # type: ignore
        needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast  # type: ignore
        if needs_upcasting:
            self.vqvae.float()

        # get quantized latent vectors
        embedding = self.vqvae.quantize.embedding.weight
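        # Multiplying a (soft) one-hot distribution by the codebook matrix is a
        # differentiable embedding lookup; for hard one-hot vectors it reduces to the
        # usual codebook indexing.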
        h: torch.Tensor = latents_one_hot[..., :codebook_size].to(embedding.dtype) @ embedding
        h = h.view(shape)
        # reshape back to match original input shape
        h = h.permute(0, 3, 1, 2).contiguous()

        # Setting lookup_from_codebook to False, as we already have the codebook embeddings in h
        self.vqvae.config.lookup_from_codebook = False  # type: ignore
        output = self.vqvae.decode(
            h,  # type: ignore
            force_not_quantize=True,
        ).sample.clip(0, 1)  # type: ignore
        self.vqvae.config.lookup_from_codebook = True  # type: ignore

        output = self.image_processor.postprocess(output, output_type)
        if needs_upcasting:
            self.vqvae.half()
        return output

    def _subs_parameterization(self, logits, latents):
        B, H, W, C = logits.shape
        logits = logits.view(B, H * W, C)
        assert latents.shape == (B, H, W)
        latents = latents.view(B, H * W)
        logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
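        # Carry-over unmasking (SUBS parameterization): positions that are already
        # unmasked keep their current token with probability one (log-prob 0), so only
        # masked positions remain stochastic in the x0 prediction.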
        unmasked_indices = (latents != self.mask_index)
        logits[unmasked_indices] = -torch.inf
        logits[unmasked_indices, latents[unmasked_indices]] = 0
        logits = logits.view(B, H, W, C)
        return logits
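

# Example usage (a minimal sketch, not part of the original file: the checkpoint path
# and the reward/resampling functions below are assumptions):
#
#   pipe = Pipeline.from_pretrained("<path-to-meissonic-checkpoint>", torch_dtype=torch.bfloat16).to("cuda")
#   images = pipe(
#       prompt="a photo of a corgi wearing sunglasses",
#       reward_fn=my_reward_fn,        # maps a batch of decoded images to per-image scalar rewards
#       resample_fn=my_resample_fn,    # maps log-weights to (indices, is_resampled, new_log_w)
#       num_particles=4,
#       batch_p=2,
#       proposal_type="locally_optimal",
#   )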