# Sampling utilities for the Vista world model: model loading/offloading
# helpers, sampler/guider configuration, and autoregressive video sampling.
| from __future__ import annotations | |
| import math | |
| import os | |
| import queue | |
| from typing import Optional, Union | |
| import numpy as np | |
| import rerun as rr | |
| import torch | |
| import torchvision | |
| from einops import rearrange, repeat | |
| from omegaconf import ListConfig, OmegaConf | |
| from PIL import Image | |
| from safetensors.torch import load_file as load_safetensors | |
| from torch import autocast | |
| from tqdm import tqdm | |
| from .vwm.modules.diffusionmodules.sampling import EulerEDMSampler | |
| from .vwm.util import default, instantiate_from_config | |
| def init_model(version_dict, load_ckpt=True): | |
| config = OmegaConf.load(version_dict["config"]) | |
| model = load_model_from_config(config, version_dict["ckpt"] if load_ckpt else None) | |
| return model | |
# Module-wide low-VRAM policy flag consulted by the (un)load helpers below.
lowvram_mode = True


def set_lowvram_mode(mode):
    """Set the global low-VRAM flag (True keeps models half-precision/CPU-offloaded)."""
    global lowvram_mode
    lowvram_mode = mode
def initial_model_load(model):
    """Place ``model`` according to the global low-VRAM policy.

    Low-VRAM mode halves the inner diffusion network (``model.model``) in
    place; otherwise the whole model is moved to the GPU. Returns ``model``.
    """
    # Reading the module-level flag needs no `global` declaration.
    if lowvram_mode:
        model.model.half()
    else:
        model.cuda()
    return model
def load_model(model):
    """Move ``model`` (or a sub-module) onto the CUDA device before use."""
    model.cuda()
def unload_model(model):
    """Offload ``model`` to CPU and release cached GPU memory.

    Only acts when the global low-VRAM policy is enabled; otherwise the
    model is left where it is.
    """
    global lowvram_mode
    if lowvram_mode:
        model.cpu()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
def load_model_from_config(config, ckpt=None):
    """Instantiate the model from ``config.model`` and optionally load weights.

    Parameters
    ----------
    config : OmegaConf config whose ``model`` node is instantiable via
        ``instantiate_from_config``.
    ckpt : optional path; a PyTorch-Lightning ``*.ckpt`` file or a
        safetensors file (anything not ending in ``ckpt``).

    Returns
    -------
    The model, placed per the low-VRAM policy and switched to eval mode.
    """
    model = instantiate_from_config(config.model)
    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        if ckpt.endswith("ckpt"):
            # NOTE(review): torch.load without weights_only=True unpickles
            # arbitrary objects — only load trusted checkpoints.
            pl_svd = torch.load(ckpt, map_location="cpu")
            # Lightning checkpoints also contain:
            # "epoch", "global_step", "pytorch-lightning_version",
            # "state_dict", "loops", "callbacks", "optimizer_states",
            # "lr_schedulers" — only "state_dict" is consumed here.
            if "global_step" in pl_svd:
                print(f"Global step: {pl_svd['global_step']}")
            svd = pl_svd["state_dict"]
        else:
            svd = load_safetensors(ckpt)
        # strict=False tolerates partial checkpoints; report the mismatch.
        missing, unexpected = model.load_state_dict(svd, strict=False)
        if len(missing) > 0:
            print(f"Missing keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected keys: {unexpected}")
    model = initial_model_load(model)
    model.eval()
    return model
def init_embedder_options(keys):
    """Return hardcoded demo conditioning values for the requested embedder keys.

    Only ``fps``/``fps_id`` and ``motion_bucket_id`` are populated; any other
    key is ignored. (Demo settings — might undergo some changes in the future.)
    """
    values = dict()
    for key in keys:
        if key in ("fps_id", "fps"):
            # fps_id is the zero-based bucket for a 10 fps clip.
            values["fps"] = 10
            values["fps_id"] = 10 - 1
        elif key == "motion_bucket_id":
            values["motion_bucket_id"] = 127  # [0, 511]
    return values
def perform_save_locally(save_path, samples, mode, dataset_name, sample_index):
    """Save sampled frames under ``save_path/<mode>`` as PNGs, a grid, or an MP4.

    ``mode`` selects the layout: "images" (one PNG per frame), "grids"
    (a single tiled PNG) or "videos" (an MP4 at 10 fps). Pixel scaling
    depends on the path: paths containing "real" are assumed to hold data
    in [-1, 1], everything else in [0, 1] — presumably matching the two
    data sources upstream (TODO confirm with callers).
    """
    assert mode in ["images", "grids", "videos"]
    out_dir = os.path.join(save_path, mode)
    os.makedirs(out_dir, exist_ok=True)
    samples = samples.cpu()

    def _to_uint8(pixels):
        # "real" data is in [-1, 1]; generated output in [0, 1].
        if "real" in save_path:
            pixels = 255.0 * (pixels + 1.0) / 2.0
        else:
            pixels = 255.0 * pixels
        return pixels.astype(np.uint8)

    if mode == "images":
        for frame_idx, frame in enumerate(samples):
            # c h w -> h w c for PIL.
            pixels = _to_uint8(frame.numpy().transpose(1, 2, 0))
            name = f"{dataset_name}_{sample_index:06}_{frame_idx:04}.png"
            Image.fromarray(pixels).save(os.path.join(out_dir, name))
    elif mode == "grids":
        grid = torchvision.utils.make_grid(samples, nrow=int(samples.shape[0] ** 0.5))
        # c h w -> h w c; squeeze drops a singleton channel axis.
        pixels = _to_uint8(grid.permute(1, 2, 0).squeeze(-1).numpy())
        name = f"{dataset_name}_{sample_index:06}.png"
        Image.fromarray(pixels).save(os.path.join(out_dir, name))
    elif mode == "videos":
        # t c h w -> t h w c for the video writer.
        frames = _to_uint8(samples.numpy().transpose(0, 2, 3, 1))
        name = f"{dataset_name}_{sample_index:06}.mp4"
        save_img_seq_to_video(os.path.join(out_dir, name), frames, 10)
    else:
        raise NotImplementedError
def init_sampling(
    sampler="EulerEDMSampler",
    guider="VanillaCFG",
    discretization="EDMDiscretization",
    steps=50,
    cfg_scale=2.5,
    num_frames=25,
):
    """Assemble a ready-to-use sampler from its named components.

    The discretization and guider names are resolved to config dicts and
    handed to ``get_sampler`` together with the step count.
    """
    return get_sampler(
        sampler,
        steps,
        get_discretization(discretization),
        get_guider(guider, cfg_scale, num_frames),
    )
def get_discretization(discretization):
    """Map a discretization name to its instantiate_from_config dict.

    Raises NotImplementedError for unknown names. A fresh dict is built per
    call so callers may mutate the result safely.
    """
    known = {
        "LegacyDDPMDiscretization": {
            "target": "vista.vwm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"
        },
        "EDMDiscretization": {
            "target": "vista.vwm.modules.diffusionmodules.discretizer.EDMDiscretization",
            "params": {"sigma_min": 0.002, "sigma_max": 700.0, "rho": 7.0},
        },
    }
    if discretization not in known:
        raise NotImplementedError
    return known[discretization]
def get_guider(guider="LinearPredictionGuider", cfg_scale=2.5, num_frames=25):
    """Map a guider name to its instantiate_from_config dict.

    ``cfg_scale`` becomes ``scale`` for VanillaCFG and ``max_scale`` for the
    prediction guiders (whose ``min_scale`` is fixed at 1.0). Raises
    NotImplementedError for unknown names.
    """
    prefix = "vista.vwm.modules.diffusionmodules.guiders."
    if guider == "IdentityGuider":
        return {"target": prefix + "IdentityGuider"}
    if guider == "VanillaCFG":
        return {
            "target": prefix + "VanillaCFG",
            "params": {"scale": cfg_scale},
        }
    if guider in ("LinearPredictionGuider", "TrianglePredictionGuider"):
        # Both guiders share the same parameter set; only the target differs.
        return {
            "target": prefix + guider,
            "params": {
                "max_scale": cfg_scale,
                "min_scale": 1.0,
                "num_frames": num_frames,
            },
        }
    raise NotImplementedError
def get_sampler(sampler, steps, discretization_config, guider_config):
    """Construct the named sampler; only EulerEDMSampler is supported.

    Raises ValueError for any other name.
    """
    if sampler != "EulerEDMSampler":
        raise ValueError(f"Unknown sampler {sampler}")
    # Churn/noise parameters are fixed to the defaults used across the demo.
    return EulerEDMSampler(
        num_steps=steps,
        discretization_config=discretization_config,
        guider_config=guider_config,
        s_churn=0.0,
        s_tmin=0.0,
        s_tmax=999.0,
        s_noise=1.0,
        verbose=False,
    )
def get_batch(keys, value_dict, N: Union[list, ListConfig], device="cuda"):
    """Assemble conditioning batches for the requested embedder keys.

    Scalar keys (fps, fps_id, motion_bucket_id, cond_aug) are tiled to
    ``math.prod(N)`` entries on ``device``; action keys gain a batch axis and
    are tiled to ``N[0]``; latent-frame keys (already shaped ``(1, ...)``)
    are tiled to ``N[0]`` on their current device. Keys missing from
    ``value_dict`` are skipped; any other present key raises
    NotImplementedError. Returns ``(batch, batch_uc)`` where ``batch_uc``
    holds clones of every tensor entry for unconditional conditioning.

    Uses ``torch.Tensor.repeat`` (a tiled copy, like einops ``repeat``)
    instead of the einops dependency. (Hardcoded demo setups — might undergo
    some changes in the future.)
    """
    batch = dict()
    batch_uc = dict()
    for key in keys:
        if key not in value_dict:
            continue
        if key in ["fps", "fps_id", "motion_bucket_id", "cond_aug"]:
            scalar = torch.tensor([value_dict[key]]).to(device)
            batch[key] = scalar.repeat(math.prod(N))  # "1 -> b", b = prod(N)
        elif key in ["command", "trajectory", "speed", "angle", "goal"]:
            stacked = value_dict[key][None].to(device)
            # "1 ... -> b ...", b = N[0]
            batch[key] = stacked.repeat(N[0], *[1] * (stacked.dim() - 1))
        elif key in ["cond_frames", "cond_frames_without_noise"]:
            frames = value_dict[key]  # already shaped (1, ...); device kept
            batch[key] = frames.repeat(N[0], *[1] * (frames.dim() - 1))
        else:
            raise NotImplementedError
    for key, val in batch.items():
        if key not in batch_uc and isinstance(val, torch.Tensor):
            batch_uc[key] = torch.clone(val)
    return batch, batch_uc
def get_condition(model, value_dict, num_samples, force_uc_zero_embeddings, device):
    """Compute conditional/unconditional embeddings for sampling.

    The conditioner is loaded onto the GPU only for the embedding call, then
    offloaded. Tensor entries are truncated to ``num_samples`` and moved to
    ``device``; entries whose batch dim came out smaller than ``num_samples``
    are collapsed to their first row (kept as a 1-sized batch).
    """
    load_model(model.conditioner)
    embed_keys = list({e.input_key for e in model.conditioner.embedders})
    batch, batch_uc = get_batch(embed_keys, value_dict, [num_samples])
    c, uc = model.conditioner.get_unconditional_conditioning(
        batch, batch_uc=batch_uc, force_uc_zero_embeddings=force_uc_zero_embeddings
    )
    unload_model(model.conditioner)
    for key in c:
        if not isinstance(c[key], torch.Tensor):
            continue
        c[key] = c[key][:num_samples].to(device)
        uc[key] = uc[key][:num_samples].to(device)
        if c[key].shape[0] < num_samples:
            c[key] = c[key][[0]]
        if uc[key].shape[0] < num_samples:
            uc[key] = uc[key][[0]]
    return c, uc
def fill_latent(cond, length, cond_indices, device):
    """Scatter conditioning latents into an otherwise-zero sequence.

    Returns a tensor of shape ``(length, *cond.shape[1:])`` on ``device``
    whose rows at ``cond_indices`` are taken from ``cond``.
    """
    blank = torch.zeros((length, *cond.shape[1:])).to(device)
    blank[cond_indices] = cond
    return blank
def do_sample(
    images,
    model,
    sampler,
    value_dict,
    num_rounds,
    num_frames,
    force_uc_zero_embeddings: Optional[list] = None,
    initial_cond_indices: Optional[list] = None,
    device="cuda",
    log_queue: Optional[queue.SimpleQueue] = None,
):
    """Autoregressively sample ``num_rounds`` chunks of ``num_frames`` frames.

    Round 0 is conditioned on the encoded input ``images`` at
    ``initial_cond_indices``; every later round is conditioned on the last 3
    latent frames of the previous round, placed at positions 0..2. Rounds
    therefore overlap by 3 frames and the full latent buffer holds
    ``num_rounds * (num_frames - 3) + 3`` unique frames.

    Returns ``(generated_images, samples_z, images)``: decoded frames in
    [0, 1], the latent buffer, and the original inputs.

    NOTE(review): ``log_queue`` defaults to None but is used unconditionally
    via ``log_queue.put(...)`` — callers appear expected to always pass a
    queue; confirm before calling with the default.
    """
    if initial_cond_indices is None:
        initial_cond_indices = [0]
    force_uc_zero_embeddings = default(force_uc_zero_embeddings, list())
    precision_scope = autocast
    with torch.no_grad(), precision_scope(device), model.ema_scope("Sampling"):
        # Conditional / unconditional embeddings for classifier-free guidance.
        c, uc = get_condition(
            model, value_dict, num_frames, force_uc_zero_embeddings, device
        )
        # Encode the context images into first-stage latents.
        load_model(model.first_stage_model)
        z = model.encode_first_stage(images)
        unload_model(model.first_stage_model)
        # Buffer for every unique latent frame across rounds (3-frame overlap).
        samples_z = torch.zeros((num_rounds * (num_frames - 3) + 3, *z.shape[1:])).to(
            device
        )
        sampling_progress = tqdm(total=num_rounds, desc="Compute sequences")
        def denoiser(x, sigma, cond, cond_mask):
            # Bind the network into the model's denoiser wrapper.
            return model.denoiser(model.model, x, sigma, cond, cond_mask)
        load_model(model.denoiser)
        load_model(model.model)
        # Masks mark which frame slots are clamped to conditioning latents:
        # round 0 uses `initial_cond_indices`; later rounds reuse the previous
        # round's last 3 frames at slots 0..2.
        initial_cond_mask = torch.zeros(num_frames).to(device)
        prediction_cond_mask = torch.zeros(num_frames).to(device)
        initial_cond_mask[initial_cond_indices] = 1
        prediction_cond_mask[[0, 1, 2]] = 1
        generated_images = []
        noise = torch.randn_like(z)
        # --- Round 0: conditioned directly on the encoded input frames. ---
        sample = sampler(
            denoiser,
            noise,
            cond=c,
            uc=uc,
            cond_frame=z,  # cond_frame will be rescaled when calling the sampler
            cond_mask=initial_cond_mask,
            num_sequence=0,
            log_queue=log_queue,
        )
        sampling_progress.update(1)
        # Pin the first frame to the ground-truth latent.
        sample[0] = z[0]
        samples_z[:num_frames] = sample
        generated_images.append(decode_samples(sample[:num_frames], model))
        for i, generated_image in enumerate(generated_images[-1]):
            log_queue.put(
                (
                    "generated_image",
                    rr.Image(generated_image.cpu().permute(1, 2, 0)),
                    [
                        ("frame_id", i),
                        ("diffusion", 0),
                        (
                            # Scalar timeline: 1 + 2*round + frame fraction.
                            "combined",
                            1 + 2 * 0 + (i * 1.0 / len(generated_images[-1])),
                        ),
                    ],
                )
            )
        # --- Rounds 1..num_rounds-1: condition on the previous tail. ---
        for n in range(num_rounds - 1):
            load_model(model.first_stage_model)
            # Decode the tail of the previous round for image-space guidance
            # (only frame [-3] is actually consumed below).
            samples_x_for_guidance = model.decode_first_stage(sample[-14:])
            unload_model(model.first_stage_model)
            value_dict["cond_frames_without_noise"] = samples_x_for_guidance[[-3]]
            # Latent conditioning frame, undoing the model's scale factor.
            value_dict["cond_frames"] = sample[[-3]] / model.scale_factor
            # The cond frames are already encoded — skip re-encoding them.
            for embedder in model.conditioner.embedders:
                if hasattr(embedder, "skip_encode"):
                    embedder.skip_encode = True
            c, uc = get_condition(
                model, value_dict, num_frames, force_uc_zero_embeddings, device
            )
            for embedder in model.conditioner.embedders:
                if hasattr(embedder, "skip_encode"):
                    embedder.skip_encode = False
            # Seed this chunk with the previous round's last 3 latents.
            filled_latent = fill_latent(sample[-3:], num_frames, [0, 1, 2], device)
            noise = torch.randn_like(filled_latent)
            sample = sampler(
                denoiser,
                noise,
                cond=c,
                uc=uc,
                cond_frame=filled_latent,  # cond_frame will be rescaled when calling the sampler
                cond_mask=prediction_cond_mask,
                num_sequence=n + 1,
                log_queue=log_queue,
            )
            sampling_progress.update(1)
            # Only the non-overlapping tail (frames 3..) is new this round.
            first_frame_id = (n + 1) * (num_frames - 3) + 3
            last_frame_id = (n + 1) * (num_frames - 3) + num_frames
            samples_z[first_frame_id:last_frame_id] = sample[3:]
            generated_images.append(decode_samples(sample[3:], model))
            for i, generated_image in enumerate(generated_images[-1]):
                log_queue.put(
                    (
                        "generated_image",
                        rr.Image(generated_image.cpu().permute(1, 2, 0)),
                        [
                            ("frame_id", first_frame_id + i),
                            ("diffusion", 0),
                            (
                                "combined",
                                1 + 2 * (n + 1) + (i * 1.0 / len(generated_images[-1])),
                            ),
                        ],
                    )
                )
        unload_model(model.model)
        unload_model(model.denoiser)
        generated_images = torch.concat(generated_images, dim=0)
        return generated_images, samples_z, images
def decode_samples(samples, model):
    """Decode first-stage latents to images clamped to [0, 1].

    The decoder is loaded onto the GPU only for this call, then offloaded.
    """
    load_model(model.first_stage_model)
    decoded = model.decode_first_stage(samples)
    unload_model(model.first_stage_model)
    # Map from [-1, 1] to [0, 1] and clamp out-of-range values.
    return torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)