import os
from typing import List
import torch
import torchvision
from omegaconf import OmegaConf
from torchvision.utils import save_image
from videogen_hub import MODEL_PATH
# Provides the SEINE sampling helpers used below: get_models, get_input,
# auto_inpainting, mask_generation_before, create_diffusion, TextEmbedder,
# AutoencoderKL, is_xformers_available.
from with_mask_sample import *
class SEINEPipeline:
    def __init__(self, seine_path: str = os.path.join(MODEL_PATH, "SEINE", "seine.pt"),
                 pretrained_model_path: str = os.path.join(MODEL_PATH, "SEINE", "stable-diffusion-v1-4"),
                 config_path: str = "src/videogen_hub/pipelines/seine/sample_i2v.yaml"):
| """ | |
| Load the configuration file and set the paths of models. | |
| Args: | |
| seine_path: The path of the downloaded seine pretrained model. | |
| pretrained_model_path: The path of the downloaded stable diffusion pretrained model. | |
| config_path: The path of the configuration file. | |
| """ | |
        self.config = OmegaConf.load(config_path)
        self.config.ckpt = seine_path
        self.config.pretrained_model_path = pretrained_model_path
    def infer_one_video(self, input_image,
                        text_prompt: List = [],
                        output_size: List = [240, 560],
                        num_frames: int = 16,
                        num_sampling_steps: int = 250,
                        seed: int = 42,
                        save_video: bool = False):
| """ | |
| Generate video based on provided input_image and text_prompt. | |
| Args: | |
| input_image: The input image to generate video. | |
| text_prompt: The text prompt to generate video. | |
| output_size: The size of the generated video. Defaults to [240, 560]. | |
| num_frames: number of frames of the generated video. Defaults to 16. | |
| num_sampling_steps: number of sampling steps to generate the video. Defaults to 250. | |
| seed: The random seed for video generation. Defaults to 42. | |
| save_video: save the video to the path in config if it is True. Not save if it is False. Defaults to False. | |
| Returns: | |
| The generated video as tensor with shape (num_frames, channels, height, width). | |
| """ | |
        self.config.image_size = output_size
        self.config.num_frames = num_frames
        self.config.num_sampling_steps = num_sampling_steps
        self.config.seed = seed
        self.config.text_prompt = text_prompt
        if isinstance(input_image, str):
            # A string is treated as the path of the input image.
            self.config.input_path = input_image
        else:
            # Otherwise expect a 3-channel (C, H, W) image tensor and stage it on disk.
            assert torch.is_tensor(input_image)
            assert len(input_image.shape) == 3
            assert input_image.shape[0] == 3
            save_image(input_image, "src/videogen_hub/pipelines/seine/input_image.png")
        args = self.config

        # Set up PyTorch:
        if args.seed:
            torch.manual_seed(args.seed)
        torch.set_grad_enabled(False)
        device = "cuda" if torch.cuda.is_available() else "cpu"

        if args.ckpt is None:
            raise ValueError("Please specify a checkpoint path using --ckpt <path>")
        # Derive pixel and latent resolutions (the VAE downsamples by a factor of 8):
        latent_h = args.image_size[0] // 8
        latent_w = args.image_size[1] // 8
        args.image_h = args.image_size[0]
        args.image_w = args.image_size[1]
        args.latent_h = latent_h
        args.latent_w = latent_w
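        # For the default output_size of [240, 560], this gives a 30x70 latent grid.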
        # Load model:
        print('loading model')
        model = get_models(args).to(device)
        if args.enable_xformers_memory_efficient_attention:
            if is_xformers_available():
                model.enable_xformers_memory_efficient_attention()
            else:
                raise ValueError("xformers is not available. Make sure it is installed correctly")
        # Load the EMA weights from the SEINE checkpoint:
        ckpt_path = args.ckpt
        state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema']
        model.load_state_dict(state_dict)
        print('loading succeeded')
        model.eval()
        pretrained_model_path = args.pretrained_model_path
        diffusion = create_diffusion(str(args.num_sampling_steps))
        vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
        text_encoder = TextEmbedder(pretrained_model_path).to(device)
        if args.use_fp16:
            print('Warning: using half precision for inference!')
            vae.to(dtype=torch.float16)
            model.to(dtype=torch.float16)
            text_encoder.to(dtype=torch.float16)
        # Prompt: fall back to the input file name when no prompt is given.
        prompt = args.text_prompt
        if prompt is None or prompt == []:
            prompt = args.input_path.split('/')[-1].split('.')[0].replace('_', ' ')
        else:
            prompt = prompt[0]
        prompt_base = prompt.replace(' ', '_')
        prompt = prompt + args.additional_prompt
        if save_video:
            os.makedirs(args.save_path, exist_ok=True)
        video_input, reserve_frames = get_input(args)  # f,c,h,w
        video_input = video_input.to(device).unsqueeze(0)  # b,f,c,h,w
        mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device)  # b,f,c,h,w
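        # The mask marks the frames to synthesize; multiplying by (mask == 0) keeps
        # the conditioning frames (for image-to-video, typically the first frame)
        # and zeroes the rest for the inpainting sampler below.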
        masked_video = video_input * (mask == 0)
        video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder,
                                     diffusion, model, device)
        # Rescale from [-1, 1] to [0, 255] uint8 and reorder to (f, h, w, c) for writing:
        video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
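        # e.g. a pixel value of -1.0 maps to 0, 0.0 to 128, and 1.0 to 255
        # (the add_(0.5) rounds to the nearest integer before the uint8 cast truncates).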
        if save_video:
            save_video_path = os.path.join(args.save_path, prompt_base + '.mp4')
            torchvision.io.write_video(save_video_path, video_, fps=8)
            print(f'saved in {save_video_path}')

        return video_.permute(0, 3, 1, 2)
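

# A minimal usage sketch, assuming the SEINE checkpoint, the Stable Diffusion
# weights, and the default config file are already in place under MODEL_PATH.
# The image path and prompt below are illustrative placeholders, not values
# from the original pipeline.
if __name__ == "__main__":
    pipeline = SEINEPipeline()
    video = pipeline.infer_one_video(
        input_image="path/to/input_image.png",  # hypothetical example path
        text_prompt=["a boat sailing on a calm sea"],  # hypothetical prompt
        num_frames=16,
        num_sampling_steps=250,
        seed=42,
        save_video=False,
    )
    print(video.shape)  # expected: (num_frames, 3, height, width)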