# smc_meissonic / src/smc/pipeline.py
# Last commit f1f0a8f by cp524: "Add support for non-divisible batch sizes"
from typing import Optional, Tuple, Callable, List
import math
import torch
import torch.nn.functional as F
from tqdm import tqdm
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.vq_model import VQModel
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from src.smc.transformer import Transformer2DModel
from src.smc.scheduler import BaseScheduler
from src.smc.resampling import compute_ess_from_log_w, normalize_weights
from src.smc.lora_pipeline import MeissonicLoraLoaderMixin
def logmeanexp(x, dim=None, keepdim=False):
    """Numerically stable ``log(mean(exp(x)))`` built on ``torch.logsumexp``.

    Args:
        x: Input tensor.
        dim: Dimension (int) or dimensions (tuple/list of ints) to reduce over.
            ``None`` reduces over every element.
        keepdim: Keep the reduced dimension(s) with size 1. With ``dim=None`` the
            input is flattened first, so the result then has shape ``(1,)``.

    Returns:
        ``logsumexp(x, dim) - log(N)``, where ``N`` is the number of reduced elements.
    """
    if dim is None:
        # reshape (not view) so non-contiguous inputs such as transposes are accepted.
        x = x.reshape(-1)
        dim = 0
    dims = tuple(dim) if isinstance(dim, (tuple, list)) else (dim,)
    # log-sum-exp with or without keeping the reduced dim(s)
    lse = torch.logsumexp(x, dim=dims, keepdim=keepdim)
    # subtract log(N) to turn the sum into a mean (broadcasts correctly)
    count = math.prod(x.size(d) for d in dims)
    return lse - math.log(count)
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
"""
Build positional IDs for latent-image tokens.
Each latent token corresponds to a downsampled image “pixel” in a 2D grid.
This function creates a (H//2, W//2, 3) grid where:
- channel 0 is reserved (all zeros)
- channel 1 stores the row index (0 .. H//2-1)
- channel 2 stores the column index (0 .. W//2-1)
Args:
batch_size (int): Number of images in the batch (unused here, but kept for API consistency).
height (int): Input image height (pre-VAE) or latent height depending on call site.
width (int): Input image width (pre-VAE) or latent width depending on call site.
device (torch.device): Device on which to place the returned tensor.
dtype (torch.dtype): Desired data type of the returned tensor.
Returns:
torch.Tensor of shape ((H//2 * W//2), 3) with dtype and device as specified.
Each row is [0, row_index, col_index], flattened in row-major order.
"""
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
class Pipeline(
    DiffusionPipeline,
    MeissonicLoraLoaderMixin,
):
    """Sequential Monte Carlo (SMC) sampler for a Meissonic-style masked-token model.

    Runs ``batches`` independent SMC chains of ``num_particles`` particles each.
    Particles are reweighted (and periodically resampled) using a user-supplied
    reward function. At the end, one particle per chain is selected and decoded
    with the VQ-VAE.
    """

    image_processor: VaeImageProcessor
    vqvae: VQModel
    tokenizer: CLIPTokenizer
    text_encoder: CLIPTextModelWithProjection
    transformer: Transformer2DModel
    scheduler: BaseScheduler

    model_cpu_offload_seq = "text_encoder->transformer->vqvae"

    def __init__(
        self,
        vqvae: VQModel,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModelWithProjection,
        transformer: Transformer2DModel,
        scheduler: BaseScheduler,
    ):
        super().__init__()

        self.register_modules(
            vqvae=vqvae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)  # type: ignore
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
        self.model_dtype = torch.bfloat16
        self.mask_index = self.scheduler.mask_token_id  # type: ignore
        self.vocab_size = self.transformer.config.vocab_size  # type: ignore
        self.codebook_size = self.transformer.config.codebook_size  # type: ignore

    @torch.no_grad()
    def __call__(
        self,
        prompt: str | List[str],
        reward_fn: Callable,
        resample_fn: Callable,
        resample_frequency: int = 1,
        kl_weight: float = 1.0,
        lambdas: Optional[torch.Tensor] = None,
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        num_inference_steps: int = 48,
        guidance_scale: float = 9.0,
        negative_prompt=None,
        batches: int = 1,  # Number of independent SMCs
        num_particles: int = 1,  # Number of particles per SMC
        batch_p: int = 1,  # Number of parallel particles
        phi: int = 1,  # number of samples for reward approximation
        tau: float = 1.0,  # temperature for taking x0 samples
        output_type="pil",
        micro_conditioning_aesthetic_score: int = 6,
        micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
        proposal_type: str = "locally_optimal",
        ft_model_pipe=None,  # needs to supplied if proposal_type is ft_model
        use_ft_model_for_expected_reward: bool = False,  # Whether to use the forward model for expected reward
        use_continuous_formulation: bool = False,  # Whether to use a continuous formulation of carry over unmasking
        disable_progress_bar: bool = False,
        final_strategy="argmax_rewards",
        verbose=True,
    ):
        """Sample images with reward-guided SMC.

        Args:
            prompt: Text prompt. A ``str`` is wrapped into a one-element list.
                NOTE(review): conditioning is tiled ``batch_p`` times per prompt, so
                more than one prompt does not line up with the particle chunks.
            reward_fn: Called on a batch of decoded images (``output_type="pt"``).
                Its result is stored as one reward per image.
            resample_fn: Called on one chain's log-weights (``log_w[chain]``). Must
                return ``(indices, is_resampled, new_log_w)``; ``indices`` are
                chain-local particle indices.
            resample_frequency: Resample on steps ``i`` where ``(i + 1) % resample_frequency == 0``.
            kl_weight: Every ``lambdas`` entry is divided by this to get the reward scale.
            lambdas: Per-step reward scales, length ``num_inference_steps + 1``;
                clamped to at least 0.001. Defaults to all ones.
            height, width: Output image size in pixels.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale; CFG is used when ``> 1``.
            negative_prompt: Negative prompt for CFG (``str``, list, or ``None`` -> ``""``).
            batches: Number of independent SMC chains.
            num_particles: Particles per chain.
            batch_p: Maximum number of particles per transformer / decoder call.
            phi: Number of x0 samples averaged (log-mean-exp) per reward estimate.
            tau: Gumbel-softmax temperature for drawing x0 samples.
            output_type: Forwarded to the image post-processor; ``"pt"`` results are stacked.
            micro_conditioning_aesthetic_score, micro_conditioning_crop_coord:
                Micro-conditioning values passed to the transformer.
            proposal_type: ``"locally_optimal"``, ``"reverse"``, ``"without_SMC"`` or ``"ft_model"``.
            ft_model_pipe: Second pipeline whose logits define the proposal when
                ``proposal_type="ft_model"``.
            use_ft_model_for_expected_reward: With ``"ft_model"``, draw reward x0
                samples from the ft model instead of the base model.
            use_continuous_formulation: With ``"locally_optimal"``, keep already
                unmasked tokens in the relaxed x0 sample.
            disable_progress_bar: Hide the tqdm bar.
            final_strategy: ``"multinomial"``, ``"argmax_rewards"`` or ``"argmax_weights"``;
                how one particle per chain is chosen.
            verbose: Print SMC diagnostics every step.

        Returns:
            A list with one decoded image per chain, or a stacked tensor when
            ``output_type == "pt"``.

        Raises:
            NotImplementedError: Unknown ``proposal_type`` or ``final_strategy``.
        """
        # 0. Validate options up front so a typo fails before any expensive work.
        if proposal_type not in ("locally_optimal", "reverse", "without_SMC", "ft_model"):
            raise NotImplementedError(f"Proposal type {proposal_type} is not implemented.")
        if final_strategy not in ("multinomial", "argmax_rewards", "argmax_weights"):
            raise NotImplementedError(f"Final strategy {final_strategy} is not implemented.")
        if proposal_type == "ft_model":
            assert ft_model_pipe is not None, f"ft_model must be provided for proposal_type={proposal_type}."

        device = self._execution_device

        # Default lambdas
        if lambdas is None:
            lambdas = torch.ones(num_inference_steps + 1)
        assert len(lambdas) == num_inference_steps + 1, f"lambdas must of length {num_inference_steps + 1}"
        lambdas = lambdas.clamp_min(0.001).to(device)

        # 1. Particle bookkeeping
        total_particles = batches * num_particles
        batch_p = min(batch_p, total_particles)
        H, W = height // self.vae_scale_factor, width // self.vae_scale_factor

        # 2.1. Prompt (and negative prompt) embeddings, tiled to one row per particle in a chunk.
        if isinstance(prompt, str):
            prompt = [prompt]
        prompt_embeds, encoder_hidden_states = self._encode_text(prompt)
        prompt_embeds = prompt_embeds.repeat(batch_p, 1)
        encoder_hidden_states = encoder_hidden_states.repeat(batch_p, 1, 1)

        if guidance_scale > 1.0:
            if negative_prompt is None:
                negative_prompt = [""]
            elif isinstance(negative_prompt, str):
                negative_prompt = [negative_prompt]
            negative_prompt_embeds, negative_encoder_hidden_states = self._encode_text(negative_prompt)
            negative_prompt_embeds = negative_prompt_embeds.repeat(batch_p, 1)
            negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(batch_p, 1, 1)
            # Layout [unconditional; conditional]; get_logits splits the output in this order.
            prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
            encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])

        # 2.2. Micro-conditions, one row per transformer input row.
        micro_conds = torch.tensor(
            [
                width,
                height,
                micro_conditioning_crop_coord[0],
                micro_conditioning_crop_coord[1],
                micro_conditioning_aesthetic_score,
            ],
            device=device,
            dtype=encoder_hidden_states.dtype,
        )
        micro_conds = micro_conds.unsqueeze(0).expand(2 * batch_p if guidance_scale > 1.0 else batch_p, -1)

        # 3. Every particle starts fully masked.
        latents = torch.full(
            (total_particles, H, W), self.mask_index, dtype=torch.long, device=device  # type: ignore
        )

        # Constant vectors over the vocabulary.
        ONE = torch.ones(self.vocab_size, device=device).float()
        MASK = F.one_hot(torch.tensor(self.mask_index), num_classes=self.vocab_size).float().to(device)  # type: ignore

        # 4. Scheduler timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # 5. SMC state. The two optional buffers are (total_particles, H, W, vocab_size)
        # floats each, so only allocate them for the proposal that uses them.
        logits = torch.zeros((*latents.shape, self.vocab_size), device=device)
        logits_ft_model = (
            torch.zeros((*latents.shape, self.vocab_size), device=device)
            if proposal_type == "ft_model"
            else None
        )
        rewards_grad = (
            torch.zeros((*latents.shape, self.vocab_size), device=device)
            if proposal_type == "locally_optimal"
            else None
        )
        rewards = torch.zeros((total_particles,), device=device)
        log_twist = torch.zeros((total_particles,), device=device)
        log_prob_proposal = torch.zeros((total_particles,), device=device)
        log_prob_diffusion = torch.zeros((total_particles,), device=device)
        log_w = torch.zeros((total_particles,), device=device)

        def chunk_logits(latents_batch):
            """Guided logits of ``self`` for one particle chunk at the current ``timestep``."""
            return self.get_logits(
                latents_batch, guidance_scale, height, encoder_hidden_states, micro_conds, prompt_embeds, timestep
            )

        def estimate_expected_reward(logp_x0):
            """Log-mean-exp over ``phi`` sampled x0 decodes of ``reward * scale_cur``, divided by ``scale_cur``."""
            sample_rewards = torch.zeros(logp_x0.size(0), phi, device=device)
            for phi_i in range(phi):
                # Hard Gumbel-softmax followed by argmax draws token ids from logp_x0.
                x0 = F.gumbel_softmax(logp_x0, tau=tau, hard=True).argmax(dim=-1)
                sample_rewards[:, phi_i] = reward_fn(self._decode_latents(x0, height, width, "pt"))
            return logmeanexp(sample_rewards * scale_cur, dim=-1) / scale_cur

        def reweight_and_resample(incremental_log_w):
            """Add ``incremental_log_w`` to ``log_w`` and resample each chain if due.

            Returns flat particle indices to gather every per-particle tensor with
            when ``resample_fn`` reported a resample for any chain, else ``None``.
            ``log_w`` is left flat, shape ``(total_particles,)``.
            """
            nonlocal log_w
            log_w += incremental_log_w
            log_w = log_w.reshape(batches, num_particles)
            if verbose:
                print("Incremental log weights: ", incremental_log_w)
                print("Log weights: ", log_w)
                print("Normalized weights: ", normalize_weights(log_w, dim=-1))
                print("ESS: ", compute_ess_from_log_w(log_w, dim=-1))
            gather_indices = None
            if resample_condition:
                indices_per_chain, new_log_w = [], []
                any_resampled = False
                for chain in range(batches):
                    chain_indices, chain_resampled, chain_log_w = resample_fn(log_w[chain])
                    # Shift chain-local indices into the flat particle index space.
                    indices_per_chain.append(chain_indices + chain * num_particles)
                    new_log_w.append(chain_log_w)
                    any_resampled = any_resampled or chain_resampled
                resample_indices = torch.cat(indices_per_chain, dim=0)
                log_w = torch.cat(new_log_w, dim=0)
                if any_resampled:
                    gather_indices = resample_indices
                if verbose:
                    print("Resample indices: ", resample_indices)
            log_w = log_w.reshape(total_particles)
            return gather_indices

        def propagate_locally_optimal():
            """Twist with a differentiable reward estimate; propose with its gradient as guidance."""
            nonlocal latents, log_prob_proposal, log_prob_diffusion, logits, rewards, rewards_grad, log_twist
            log_twist_prev = log_twist.clone()
            for j in range(0, total_particles, batch_p):
                latents_batch = latents[j:j + batch_p]
                with torch.enable_grad():
                    latents_one_hot = (
                        F.one_hot(latents_batch, num_classes=self.vocab_size)
                        .to(dtype=self.model_dtype)
                        .requires_grad_(True)
                    )
                    tmp_logits = chunk_logits(latents_one_hot)
                    tmp_rewards = torch.zeros(latents_batch.size(0), phi, device=device)
                    # gamma is 1 where the token is still masked and 0 where it is already unmasked.
                    gamma = 1 - ((ONE - MASK) * latents_one_hot).sum(dim=-1, keepdim=True)
                    for phi_i in range(phi):
                        sample = F.gumbel_softmax(tmp_logits, tau=tau, hard=True)
                        if use_continuous_formulation:
                            sample = gamma * sample + (ONE - MASK) * latents_one_hot
                        sample = self._decode_one_hot_latents(sample, height, width, "pt")
                        tmp_rewards[:, phi_i] = reward_fn(sample)
                    tmp_rewards = logmeanexp(tmp_rewards * scale_cur, dim=-1) / scale_cur
                    tmp_rewards_grad = torch.autograd.grad(
                        outputs=tmp_rewards,
                        inputs=latents_one_hot,
                        grad_outputs=torch.ones_like(tmp_rewards),
                    )[0].detach()
                logits[j:j + batch_p] = tmp_logits.detach()
                rewards[j:j + batch_p] = tmp_rewards.detach()
                rewards_grad[j:j + batch_p] = tmp_rewards_grad
                log_twist[j:j + batch_p] = rewards[j:j + batch_p] * scale_cur

            if verbose:
                print("Rewards: ", rewards)
                print("log_prob_diffusion - log_prob_proposal: ", log_prob_diffusion - log_prob_proposal)
                print("log_twist - log_twist_prev:", log_twist - log_twist_prev)
            indices = reweight_and_resample(
                (log_prob_diffusion - log_prob_proposal) + (log_twist - log_twist_prev)
            )
            if indices is not None:
                latents = latents[indices]
                logits = logits[indices]
                rewards = rewards[indices]
                rewards_grad = rewards_grad[indices]
                log_twist = log_twist[indices]

            # Propose new particles
            sched_out = self.scheduler.step_with_approx_guidance(
                latents=latents,
                logits=logits,
                approx_guidance=rewards_grad * scale_next,
                step=i,
            )
            if verbose:
                print("Approx guidance norm: ", ((rewards_grad * scale_next) ** 2).sum(dim=(1, 2)).sqrt())
            latents = sched_out.new_latents
            log_prob_proposal = sched_out.log_prob_proposal
            log_prob_diffusion = sched_out.log_prob_diffusion

        def propagate_reverse():
            """Propose with the plain reverse step; weights only change through the twist."""
            nonlocal latents, logits, rewards, log_twist
            log_twist_prev = log_twist.clone()
            for j in range(0, total_particles, batch_p):
                latents_batch = latents[j:j + batch_p]
                tmp_logits = chunk_logits(latents_batch)
                tmp_rewards = estimate_expected_reward(self._subs_parameterization(tmp_logits, latents_batch))
                logits[j:j + batch_p] = tmp_logits.detach()
                rewards[j:j + batch_p] = tmp_rewards.detach()
                log_twist[j:j + batch_p] = rewards[j:j + batch_p] * scale_cur

            if verbose:
                print("Rewards: ", rewards)
                print("log_twist - log_twist_prev:", log_twist - log_twist_prev)
            indices = reweight_and_resample(log_twist - log_twist_prev)
            if indices is not None:
                latents = latents[indices]
                logits = logits[indices]
                rewards = rewards[indices]
                log_twist = log_twist[indices]

            # Propose new particles
            latents = self.scheduler.step(latents=latents, logits=logits, step=i).new_latents

        def propagate_without_SMC():
            """Plain reverse diffusion: no rewards, weights or resampling."""
            nonlocal latents, logits
            for j in range(0, total_particles, batch_p):
                logits[j:j + batch_p] = chunk_logits(latents[j:j + batch_p]).detach()
            latents = self.scheduler.step(latents=latents, logits=logits, step=i).new_latents

        def propagate_ft_model():
            """Use ``ft_model_pipe``'s logits as the proposal, correcting toward the base model."""
            nonlocal latents, log_prob_proposal, log_prob_diffusion, logits, logits_ft_model, rewards, log_twist
            log_twist_prev = log_twist.clone()
            for j in range(0, total_particles, batch_p):
                latents_batch = latents[j:j + batch_p]
                tmp_logits = chunk_logits(latents_batch)
                tmp_logits_ft_model = ft_model_pipe.get_logits(
                    latents_batch, guidance_scale, height, encoder_hidden_states, micro_conds, prompt_embeds, timestep
                )
                if use_ft_model_for_expected_reward:
                    tmp_logp_x0 = ft_model_pipe._subs_parameterization(tmp_logits_ft_model, latents_batch)
                else:
                    tmp_logp_x0 = self._subs_parameterization(tmp_logits, latents_batch)
                tmp_rewards = estimate_expected_reward(tmp_logp_x0)
                logits[j:j + batch_p] = tmp_logits.detach()
                logits_ft_model[j:j + batch_p] = tmp_logits_ft_model.detach()
                rewards[j:j + batch_p] = tmp_rewards.detach()
                log_twist[j:j + batch_p] = rewards[j:j + batch_p] * scale_cur

            if verbose:
                print("Rewards: ", rewards)
                print("log_prob_diffusion - log_prob_proposal: ", log_prob_diffusion - log_prob_proposal)
                print("log_twist - log_twist_prev:", log_twist - log_twist_prev)
            indices = reweight_and_resample(
                (log_prob_diffusion - log_prob_proposal) + (log_twist - log_twist_prev)
            )
            if indices is not None:
                latents = latents[indices]
                logits = logits[indices]
                logits_ft_model = logits_ft_model[indices]
                rewards = rewards[indices]
                log_twist = log_twist[indices]

            # Adding (ft - base) to the base logits makes logits_ft_model the proposal.
            approx_guidance = logits_ft_model - logits
            approx_guidance[..., self.codebook_size:] = 0.0  # avoid nan due to (inf - inf) in padded entries
            sched_out = self.scheduler.step_with_approx_guidance(
                latents=latents,
                logits=logits,
                approx_guidance=approx_guidance,
                step=i,
            )
            latents = sched_out.new_latents
            log_prob_proposal = sched_out.log_prob_proposal
            log_prob_diffusion = sched_out.log_prob_diffusion

        propagate = {
            "locally_optimal": propagate_locally_optimal,
            "reverse": propagate_reverse,
            "without_SMC": propagate_without_SMC,
            "ft_model": propagate_ft_model,
        }[proposal_type]

        # The closures above read i, timestep, resample_condition, scale_cur and
        # scale_next from this loop at call time.
        steps = enumerate(reversed(range(num_inference_steps)))
        if not disable_progress_bar:
            steps = tqdm(steps, total=num_inference_steps, leave=False)
        for i, timestep in steps:
            resample_condition = (i + 1) % resample_frequency == 0
            scale_cur = lambdas[i] / kl_weight
            scale_next = lambdas[i + 1] / kl_weight
            if verbose:
                print(f"scale_cur: {scale_cur}, scale_next: {scale_next}")
            propagate()
            if verbose:
                print('\n\n')

        # Final SMC weights: exact rewards of the fully unmasked particles.
        scale_cur = lambdas[-1] / kl_weight
        log_twist_prev = log_twist.clone()
        for j in range(0, total_particles, batch_p):
            sample = self._decode_latents(latents[j:j + batch_p], height, width, "pt")
            tmp_rewards = reward_fn(sample)
            rewards[j:j + batch_p] = tmp_rewards
            log_twist[j:j + batch_p] = tmp_rewards * scale_cur
        if verbose:
            print("Rewards: ", rewards)

        incremental_log_w = (log_prob_diffusion - log_prob_proposal) + (log_twist - log_twist_prev)
        log_w += incremental_log_w

        # Reshape to (batches, num_particles) so one particle can be chosen per chain.
        log_w = log_w.reshape(batches, num_particles)
        latents = latents.reshape(batches, num_particles, H, W)
        rewards = rewards.reshape(batches, num_particles)
        if verbose:
            print("log_prob_diffusion - log_prob_proposal: ", log_prob_diffusion - log_prob_proposal)
            print("log_twist - log_twist_prev:", log_twist - log_twist_prev)
            print("Incremental log weights: ", incremental_log_w)
            print("Log weights: ", log_w)
            print("Normalized weights: ", normalize_weights(log_w, dim=-1))

        if final_strategy == "multinomial":
            final_indices = torch.multinomial(normalize_weights(log_w, dim=-1), num_samples=1).squeeze(-1)
        elif final_strategy == "argmax_rewards":
            final_indices = rewards.argmax(dim=-1)
        else:  # "argmax_weights" (validated at the top)
            final_indices = log_w.argmax(dim=-1)
        if verbose:
            print("Final selected indices: ", final_indices)
        latents = latents[torch.arange(batches, device=latents.device), final_indices]

        # Decode the selected particles in chunks of batch_p.
        outputs = []
        for j in range(0, batches, batch_p):
            outputs.extend(
                self._decode_latents(latents[j:j + batch_p], height, width, output_type)  # type: ignore
            )
        if output_type == "pt":
            outputs = torch.stack(outputs, dim=0)
        return outputs

    def _encode_text(self, prompts):
        """Tokenize ``prompts`` (max 77 tokens) and run the CLIP text encoder.

        Returns:
            ``(text_embeds, hidden_states[-2])`` from the text encoder output.
        """
        input_ids = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=77,
        ).input_ids.to(self._execution_device)
        outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
        return outputs.text_embeds, outputs.hidden_states[-2]

    def get_logits(self, latents, guidance_scale, resolution, encoder_hidden_states, micro_conds, prompt_embeds, timestep):
        """Run the transformer on ``latents`` and return (CFG-combined) logits.

        Args:
            latents: Token ids ``(B, H, W)`` or a one-hot encoding ``(B, H, W, vocab_size)``.
            guidance_scale: CFG scale; when ``> 1`` the inputs are duplicated and
                the outputs combined as ``uncond + s * (cond - uncond)``.
            resolution: Image resolution; selects how positional ids are built.
            encoder_hidden_states, micro_conds, prompt_embeds: Conditioning with
                ``batch_p`` rows (two halves of ``batch_p``, unconditional first,
                under CFG). They are trimmed when ``latents`` is a smaller final chunk.
            timestep: Current scheduler timestep.

        Returns:
            float32 logits ``(B, H, W, C + vocab_size - codebook_size)``, where ``C``
            is the transformer's output channel count (presumably ``codebook_size``).
            The padded trailing entries are ``-inf``.
        """
        if guidance_scale > 1.0:
            # Duplicate the inputs so one forward pass yields unconditional and conditional logits.
            model_input = torch.cat([latents] * 2)  # type: ignore
        else:
            model_input = latents

        # img_ids, txt_ids are used for positional embeddings
        if resolution == 1024:
            img_ids = _prepare_latent_image_ids(
                model_input.shape[0], model_input.shape[1], model_input.shape[2], model_input.device, model_input.dtype
            )
        else:
            img_ids = _prepare_latent_image_ids(
                model_input.shape[0], 2 * model_input.shape[1], 2 * model_input.shape[2], model_input.device, model_input.dtype
            )
        txt_ids = torch.zeros(encoder_hidden_states.shape[1], 3).to(
            device=encoder_hidden_states.device, dtype=encoder_hidden_states.dtype
        )

        if prompt_embeds.shape[0] != model_input.shape[0]:
            # The last chunk is smaller when total_particles is not a multiple of
            # batch_p: keep only as many conditioning rows as there are inputs.
            if guidance_scale > 1.0:
                batch_p = prompt_embeds.shape[0] // 2
                last_batch_size = model_input.shape[0] // 2
                # Take the leading rows of both the unconditional and conditional halves.
                prompt_embeds = torch.cat(
                    [prompt_embeds[:last_batch_size], prompt_embeds[batch_p:batch_p + last_batch_size]]
                )
                encoder_hidden_states = torch.cat(
                    [encoder_hidden_states[:last_batch_size], encoder_hidden_states[batch_p:batch_p + last_batch_size]]
                )
                micro_conds = torch.cat(
                    [micro_conds[:last_batch_size], micro_conds[batch_p:batch_p + last_batch_size]]
                )
            else:
                last_batch_size = model_input.shape[0]
                prompt_embeds = prompt_embeds[:last_batch_size]
                encoder_hidden_states = encoder_hidden_states[:last_batch_size]
                micro_conds = micro_conds[:last_batch_size]

        model_output = self.transformer(
            hidden_states=model_input,
            micro_conds=micro_conds,
            pooled_projections=prompt_embeds,
            encoder_hidden_states=encoder_hidden_states,
            img_ids=img_ids,
            txt_ids=txt_ids,
            timestep=torch.tensor([timestep], device=model_input.device, dtype=torch.long),
        )
        if guidance_scale > 1.0:
            uncond_logits, cond_logits = model_output.chunk(2)
            model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)

        # Move the channel dim last and pad the non-codebook vocabulary entries with -inf.
        tmp_logits = torch.permute(model_output, (0, 2, 3, 1)).float()
        pad_logits = torch.full(
            (*tmp_logits.shape[:3], self.vocab_size - self.codebook_size),
            -torch.inf,
            device=tmp_logits.device,
            dtype=tmp_logits.dtype,
        )
        return torch.cat([tmp_logits, pad_logits], dim=-1)

    def _decode_latents(self, latents, height, width, output_type):
        """Decode token-id latents with the VQ-VAE and post-process them.

        With ``output_type == "latent"`` the latents are returned unchanged. The
        VQ-VAE is temporarily upcast to float32 when it is float16 with
        ``force_upcast``; it is cast back even if decoding raises.
        """
        if output_type == "latent":
            return latents

        batch_size = latents.shape[0]
        needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast  # type: ignore
        if needs_upcasting:
            self.vqvae.float()
        try:
            output = self.vqvae.decode(
                latents,
                force_not_quantize=True,
                shape=(
                    batch_size,
                    height // self.vae_scale_factor,
                    width // self.vae_scale_factor,
                    self.vqvae.config.latent_channels,  # type: ignore
                ),
            ).sample.clip(0, 1)  # type: ignore
            output = self.image_processor.postprocess(output, output_type)
        finally:
            if needs_upcasting:
                self.vqvae.half()
        return output

    def _decode_one_hot_latents(self, latents_one_hot, height, width, output_type):
        """Differentiably decode (relaxed) one-hot latents into images.

        The one-hot vectors are multiplied with the VQ codebook instead of doing an
        index lookup, so gradients can flow back into ``latents_one_hot``. The
        VQ-VAE's ``lookup_from_codebook`` flag and dtype are restored afterwards,
        even if decoding raises.
        """
        batch_size = latents_one_hot.shape[0]
        shape = (
            batch_size,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
            self.vqvae.config.latent_channels,  # type: ignore
        )
        codebook_size = self.transformer.config.codebook_size  # type: ignore
        needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast  # type: ignore
        if needs_upcasting:
            self.vqvae.float()
        try:
            # Get quantized latent vectors. Entries at index >= codebook_size
            # (presumably the mask token) have no codebook vector and are dropped.
            embedding = self.vqvae.quantize.embedding.weight
            h: torch.Tensor = latents_one_hot[..., :codebook_size].to(embedding.dtype) @ embedding
            # Reshape back to (B, channels, H, W).
            h = h.view(shape).permute(0, 3, 1, 2).contiguous()

            # h already holds codebook embeddings, so the decoder must not look them up again.
            previous_lookup = getattr(self.vqvae.config, "lookup_from_codebook", True)
            self.vqvae.config.lookup_from_codebook = False  # type: ignore
            try:
                output = self.vqvae.decode(
                    h,  # type: ignore
                    force_not_quantize=True,
                ).sample.clip(0, 1)  # type: ignore
            finally:
                self.vqvae.config.lookup_from_codebook = previous_lookup  # type: ignore
            output = self.image_processor.postprocess(output, output_type)
        finally:
            if needs_upcasting:
                self.vqvae.half()
        return output

    def _subs_parameterization(self, logits, latents):
        """Turn logits into x0 log-probabilities that keep already-unmasked tokens fixed.

        ``logits`` ``(B, H, W, C)`` are log-softmax normalized. At every position of
        ``latents`` ``(B, H, W)`` that is not the mask token, the distribution becomes
        a point mass on the current token (log-prob 0, everything else ``-inf``).
        """
        B, H, W, C = logits.shape
        logits = logits.view(B, H * W, C)
        assert latents.shape == (B, H, W)
        latents = latents.view(B, H * W)

        logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
        unmasked_indices = latents != self.mask_index
        logits[unmasked_indices] = -torch.inf
        logits[unmasked_indices, latents[unmasked_indices]] = 0
        return logits.view(B, H, W, C)