LD3 / gen_data.py
vinhtong97's picture
Upload folder using huggingface_hub
d382778 verified
import torch
import os
from tqdm import tqdm
from utils import (
get_solvers,
parse_arguments,
prepare_paths,
adjust_hyper,
)
from models import prepare_stuff, prepare_condition_loader
import time
import numpy as np
import PIL.Image
def get_data_inverse_scaler(centered=True):
"""Inverse data normalizer."""
if centered:
# Rescale [-1, 1] to [0, 1]
return lambda x: (x + 1.0) / 2.0
else:
return lambda x: x
class Generator:
def __init__(
self,
noise_schedule,
solver,
order,
skip_type=None,
load_from=None,
timesteps_1=None,
timesteps_2=None,
steps=35,
solver_extra_params=None,
device=None,
) -> None:
self.device = device
self.noise_schedule = noise_schedule
self.solver = solver
self.order = order
self.skip_type = skip_type
self.load_from = load_from
self.timesteps_1 = timesteps_1
self.timesteps_2 = timesteps_2
self.steps = steps
self.solver_extra_params = solver_extra_params
self._precompute_timesteps()
def _precompute_timesteps(self):
if self.load_from is None and type(self.timesteps_1) == list and type(self.timesteps_1[0]) == float \
and type(self.timesteps_2) == list and type(self.timesteps_2[0]) == float:
self.timesteps = self.noise_schedule.inverse_lambda(-np.log(self.timesteps_1)).to(self.device).float()
self.timesteps2 = self.noise_schedule.inverse_lambda(-np.log(self.timesteps_2)).to(self.device).float()
else:
self.timesteps, self.timesteps2 = self.solver.prepare_timesteps(
steps=self.steps,
t_start=self.noise_schedule.T,
t_end=self.noise_schedule.eps,
skip_type=self.skip_type,
device=self.device,
load_from=self.load_from,
)
def _sample(self, net, decoding_fn, latents, condition=None, unconditional_condition=None):
x_next_ = self.noise_schedule.prior_transformation(latents)
x_next_ = self.solver.sample_simple(
model_fn=net,
x=x_next_,
timesteps=self.timesteps,
timesteps2=self.timesteps2,
order=self.order,
NFEs=self.steps,
condition=condition,
unconditional_condition=unconditional_condition,
**self.solver_extra_params,
)
x_next_ = decoding_fn(x_next_)
return x_next_
def sample(self, net, decoding_fn, latents, condition=None, unconditional_condition=None, no_grad=True):
if no_grad:
with torch.no_grad():
return self._sample(net, decoding_fn, latents, condition, unconditional_condition)
else:
return self._sample(net, decoding_fn, latents, condition, unconditional_condition)
def main(args):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
wrapped_model, model, decoding_fn, noise_schedule, latent_resolution, latent_channel, img_resolution, img_channel = prepare_stuff(args)
condition_loader = prepare_condition_loader(model_type=args.model,
model=model,
scale=args.scale if hasattr(args, "scale") else None,
condition=args.prompt_path or "random",
sampling_batch_size=args.sampling_batch_size,
num_prompt=args.num_prompts,
num_samples_per_prompt=args.num_samples_per_prompt,
)
adjust_hyper(args, latent_resolution, latent_channel)
desc, _, skip_type = prepare_paths(args)
data_dir = os.path.join(args.data_dir, desc)
os.makedirs(data_dir, exist_ok=True)
solver, steps, solver_extra_params = get_solvers(
args.solver_name,
NFEs=args.steps,
order=args.order,
noise_schedule=noise_schedule,
unipc_variant=args.unipc_variant,
)
generator = Generator(
noise_schedule=noise_schedule,
solver=solver,
order=args.order,
skip_type=skip_type,
load_from=args.load_from,
timesteps_1=args.custom_ts_1,
timesteps_2=args.custom_ts_2,
steps=steps,
solver_extra_params=solver_extra_params,
device=device,
)
print(generator.timesteps, generator.timesteps2)
inverse_scalar = get_data_inverse_scaler(centered=True)
start = time.time()
count = 0
batch_size = args.sampling_batch_size
if args.prompt_path is not None:
args.total_samples = min(args.total_samples, len(condition_loader.prompts))
num_batches = (args.total_samples + batch_size - 1) // batch_size
for i in tqdm(range(num_batches)):
current_batch_size = min(batch_size, args.total_samples - i * batch_size)
sampling_shape = (current_batch_size, latent_channel, latent_resolution, latent_resolution)
latents = torch.randn(sampling_shape, device=device)
if condition_loader is not None:
conditioning, conditioned_unconditioning = next(condition_loader)
else:
conditioning = None
conditioned_unconditioning = None
img_teacher = generator.sample(wrapped_model, decoding_fn, latents, conditioning, conditioned_unconditioning)
img_teacher = img_teacher.detach().cpu().view(current_batch_size, img_channel, img_resolution, img_resolution)
latents = latents.detach().cpu()
if args.save_pt:
for i in range(current_batch_size):
latent = latents[i]
img = img_teacher[i]
c = conditioning[i] if conditioning is not None else None
uc = conditioned_unconditioning[i] if conditioned_unconditioning is not None else None
data = dict(latent=latent, img=img, c=c, uc=uc)
torch.save(data, os.path.join(data_dir, f"latent_{(count + i):06d}.pt"))
if args.save_png:
samples_raw = inverse_scalar(img_teacher)
samples = np.clip(
samples_raw.permute(0, 2, 3, 1).cpu().numpy() * 255.0, 0, 255
).astype(np.uint8)
images_np = samples.reshape((-1, img_resolution, img_resolution, img_channel))
for i in range(current_batch_size):
image_np = images_np[i]
if args.prompt_path is not None and args.prompt_path.startswith('hpsv2'):
image_path = os.path.join(data_dir, f"{(count + i):05d}.jpg")
else:
image_path = os.path.join(data_dir, f"{(count + i):06d}.png")
if image_np.shape[2] == 1:
PIL.Image.fromarray(image_np[:, :, 0], "L").save(image_path)
else:
PIL.Image.fromarray(image_np, "RGB").save(image_path)
count += batch_size
end = time.time()
print(f"Generation time: {end - start}")
if __name__ == "__main__":
args = parse_arguments()
main(args)