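"""Sample driving videos with the Vista world model.

Conditioning frames come either from nuScenes annotations or from a plain
image folder; generation can additionally be conditioned on actions
(trajectory, command, steering, or a goal point). Generated and ground-truth
frames are written under ``--save``.
"""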
from __future__ import annotations

import argparse
import json
import os
import random

import torch
from PIL import Image
from pytorch_lightning import seed_everything
from torchvision import transforms

from . import sample_utils
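
# Checkpoint/config paths per model version, and data sources per dataset;
# extend these mappings to register a new model version or dataset.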
VERSION2SPECS = {
    "vwm": {"config": "configs/inference/vista.yaml", "ckpt": "ckpts/vista.safetensors"}
}

DATASET2SOURCES = {
    "NUSCENES": {"data_root": "data/nuscenes", "anno_file": "annos/nuScenes_val.json"},
    "IMG": {"data_root": "image_folder"},
}


def parse_args(**parser_kwargs):
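    """Build the CLI argument parser for the sampling options."""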
    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument("--version", type=str, default="vwm", help="model version")
    parser.add_argument("--dataset", type=str, default="NUSCENES", help="dataset name")
    parser.add_argument(
        "--save", type=str, default="outputs", help="directory to save samples"
    )
    parser.add_argument(
        "--action",
        type=str,
        default="free",
        help="action mode for control: free, traj, cmd, steer, or goal",
    )
    parser.add_argument(
        "--n_rounds", type=int, default=1, help="number of sampling rounds"
    )
    parser.add_argument(
        "--n_frames", type=int, default=25, help="number of frames for each round"
    )
    parser.add_argument(
        "--n_conds",
        type=int,
        default=1,
        help="number of initial condition frames for the first round",
    )
    parser.add_argument(
        "--seed", type=int, default=23, help="random seed for seed_everything"
    )
    parser.add_argument(
        "--height", type=int, default=576, help="target height of the generated video"
    )
    parser.add_argument(
        "--width", type=int, default=1024, help="target width of the generated video"
    )
    parser.add_argument(
        "--cfg_scale",
        type=float,
        default=2.5,
        help="scale of the classifier-free guidance",
    )
    parser.add_argument(
        "--cond_aug", type=float, default=0.0, help="strength of the noise augmentation"
    )
    parser.add_argument(
        "--n_steps", type=int, default=50, help="number of sampling steps"
    )
    parser.add_argument(
        "--rand_gen",
        action="store_false",  # enabled by default; passing the flag disables it
        help="sample randomly by default; pass this flag to iterate sequentially",
    )
    parser.add_argument(
        "--low_vram", action="store_true", help="enable memory-saving (low-VRAM) mode"
    )
    return parser


def get_sample(
    selected_index=0, dataset_name="NUSCENES", num_frames=25, action_mode="free"
):
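    """Resolve per-frame image paths and optional action conditioning.

    Returns the frame paths, the wrapped sample index, the dataset length,
    and an action dict (None in free mode).
    """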
    dataset_dict = DATASET2SOURCES[dataset_name]
    action_dict = None
    if dataset_name == "IMG":
        image_list = os.listdir(dataset_dict["data_root"])
        total_length = len(image_list)
        # wrap the index around so any non-negative index is valid
        selected_index %= total_length
        image_file = image_list[selected_index]
        # a single image is repeated to form the conditioning sequence
        path_list = [os.path.join(dataset_dict["data_root"], image_file)] * num_frames
    else:
        with open(dataset_dict["anno_file"]) as anno_json:
            all_samples = json.load(anno_json)
        total_length = len(all_samples)
        selected_index %= total_length
        sample_dict = all_samples[selected_index]
        path_list = []
        if dataset_name == "NUSCENES":
            for index in range(num_frames):
                image_path = os.path.join(
                    dataset_dict["data_root"], sample_dict["frames"][index]
                )
                assert os.path.exists(image_path), image_path
                path_list.append(image_path)
            if action_mode != "free":
                action_dict = dict()
                if action_mode in ("traj", "trajectory"):
                    action_dict["trajectory"] = torch.tensor(sample_dict["traj"][2:])
                elif action_mode in ("cmd", "command"):
                    action_dict["command"] = torch.tensor(sample_dict["cmd"])
                elif action_mode == "steer":
                    # scene might be empty
                    if sample_dict["speed"]:
                        action_dict["speed"] = torch.tensor(sample_dict["speed"][1:])
                    # scene might be empty
                    if sample_dict["angle"]:
                        # normalize the steering angle to roughly [-1, 1]
                        action_dict["angle"] = torch.tensor(sample_dict["angle"][1:]) / 780
                elif action_mode == "goal":
                    # the goal point might be invalid or outside the 1600x900 image
                    if (
                        sample_dict["z"] > 0
                        and 0 < sample_dict["goal"][0] < 1600
                        and 0 < sample_dict["goal"][1] < 900
                    ):
                        # normalize pixel coordinates by the image size
                        action_dict["goal"] = torch.tensor(
                            [
                                sample_dict["goal"][0] / 1600,
                                sample_dict["goal"][1] / 900,
                            ]
                        )
                else:
                    raise ValueError(f"Unsupported action mode {action_mode}")
        else:
            raise ValueError(f"Invalid dataset {dataset_name}")
    return path_list, selected_index, total_length, action_dict


def load_img(file_name, target_height=320, target_width=576, device="cuda"):
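    """Load an image, center-crop to the target aspect ratio, resize, and
    return it as a CHW tensor normalized to [-1, 1] on ``device``."""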
    if file_name is not None:
        image = Image.open(file_name)
        if image.mode != "RGB":
            image = image.convert("RGB")
    else:
        raise ValueError(f"Invalid image file {file_name}")
    ori_w, ori_h = image.size

    # center-crop to the target aspect ratio before resizing
    if ori_w / ori_h > target_width / target_height:
        tmp_w = int(target_width / target_height * ori_h)
        left = (ori_w - tmp_w) // 2
        right = (ori_w + tmp_w) // 2
        image = image.crop((left, 0, right, ori_h))
    elif ori_w / ori_h < target_width / target_height:
        tmp_h = int(target_height / target_width * ori_w)
        top = (ori_h - tmp_h) // 2
        bottom = (ori_h + tmp_h) // 2
        image = image.crop((0, top, ori_w, bottom))
    image = image.resize((target_width, target_height), resample=Image.LANCZOS)

    # scale pixel values from [0, 1] to [-1, 1]
    image = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
    )(image)
    return image.to(device)


if __name__ == "__main__":
    parser = parse_args()
    opt, unknown = parser.parse_known_args()

    sample_utils.set_lowvram_mode(opt.low_vram)
    version_dict = VERSION2SPECS[opt.version]
    model = sample_utils.init_model(version_dict)
    unique_keys = {x.input_key for x in model.conditioner.embedders}
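
    # Iterate over the dataset; sample_index is set to -1 once every sample
    # has been visited, which terminates the loop.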
    sample_index = 0
    while sample_index >= 0:
        seed_everything(opt.seed)
        frame_list, sample_index, dataset_length, action_dict = get_sample(
            sample_index, opt.dataset, opt.n_frames, opt.action
        )

        img_seq = [load_img(each_path, opt.height, opt.width) for each_path in frame_list]
        images = torch.stack(img_seq)
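
        # Condition on the first frame: a clean copy for the conditioner and a
        # noise-augmented copy (strength opt.cond_aug) for the diffusion input.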
        value_dict = sample_utils.init_embedder_options(unique_keys)
        cond_img = img_seq[0][None]
        value_dict["cond_frames_without_noise"] = cond_img
        value_dict["cond_aug"] = opt.cond_aug
        value_dict["cond_frames"] = cond_img + opt.cond_aug * torch.randn_like(cond_img)
        if action_dict is not None:
            value_dict.update(action_dict)
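
        # Multi-round (long-horizon) generation uses the triangle prediction
        # guider; a single round falls back to vanilla classifier-free guidance.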
        guider = "TrianglePredictionGuider" if opt.n_rounds > 1 else "VanillaCFG"
        sampler = sample_utils.init_sampling(
            guider=guider,
            steps=opt.n_steps,
            cfg_scale=opt.cfg_scale,
            num_frames=opt.n_frames,
        )
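
        # Conditioning keys whose embeddings are zeroed for the unconditional
        # branch of classifier-free guidance.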
        uc_keys = [
            "cond_frames",
            "cond_frames_without_noise",
            "command",
            "trajectory",
            "speed",
            "angle",
            "goal",
        ]

        out = sample_utils.do_sample(
            images,
            model,
            sampler,
            value_dict,
            num_rounds=opt.n_rounds,
            num_frames=opt.n_frames,
            force_uc_zero_embeddings=uc_keys,
            initial_cond_indices=list(range(opt.n_conds)),
        )
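
        # do_sample returns the decoded frames, their latents (samples_z), and
        # the preprocessed inputs; save generated and ground-truth versions.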
        if isinstance(out, (tuple, list)):
            samples, samples_z, inputs = out
            virtual_path = os.path.join(opt.save, "virtual")
            real_path = os.path.join(opt.save, "real")
            for base_path, frames in ((virtual_path, samples), (real_path, inputs)):
                for save_type in ("videos", "grids", "images"):
                    sample_utils.perform_save_locally(
                        base_path, frames, save_type, opt.dataset, sample_index
                    )
        else:
            raise TypeError(f"Unexpected output type from do_sample: {type(out)}")

        # advance to the next sample, randomly or sequentially
        if opt.rand_gen:
            sample_index += random.randint(1, max(1, dataset_length - 1))
        else:
            sample_index += 1
        if sample_index >= dataset_length:
            sample_index = -1