Text-to-Image
English
PAINE / predictor /inference /noise_selection.py
joonghk's picture
first commit
03de09d
from typing import Optional, Tuple, Union
import torch
def generate_noise_candidates(
num_candidates: int,
latent_shape: Tuple[int, ...],
device: str = "cuda",
dtype: torch.dtype = torch.float16,
generator: Optional[Union[torch.Generator, list]] = None,
) -> torch.Tensor:
if isinstance(generator, list):
generator = generator[0] if generator else None
return torch.randn(
(num_candidates, *latent_shape),
generator=generator,
device=device,
dtype=dtype,
)
def select_top_k_noise(
predictor: torch.nn.Module,
noises: torch.Tensor,
prompt_embeds: torch.Tensor,
prompt_mask: torch.Tensor,
num_select: int = 1,
head_index: int = 0,
batch_size: int = 64,
) -> torch.Tensor:
num_candidates = noises.shape[0]
device = noises.device
prompt_embeds = prompt_embeds.to(device)
prompt_mask = prompt_mask.to(device)
all_scores = []
with torch.no_grad():
for i in range(0, num_candidates, batch_size):
batch = noises[i:i + batch_size]
bs = batch.shape[0]
pe = prompt_embeds.expand(bs, -1, -1).float()
pm = prompt_mask.expand(bs, -1)
scores = predictor.predict(pe, batch.float(), pm)
if scores.dim() == 2 and scores.shape[1] > 1:
scores = scores[:, head_index]
else:
scores = scores.squeeze(-1)
all_scores.append(scores)
all_scores = torch.cat(all_scores, dim=0)
topk = torch.topk(all_scores, min(num_select, num_candidates))
return noises[topk.indices], topk.values