# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import os
import sys
import argparse
import random
from omegaconf import OmegaConf
from einops import rearrange, repeat
import torch
import torchvision
from pytorch_lightning import seed_everything
from cog import BasePredictor, Input, Path

sys.path.insert(0, "scripts/evaluation")
from funcs import (
    batch_ddim_sampling,
    load_model_checkpoint,
    load_image_batch,
    get_filelist,
)
from utils.utils import instantiate_from_config


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the models into memory to make running multiple predictions efficient"""
        ckpt_path_base = "checkpoints/base_1024_v1/model.ckpt"
        config_base = "configs/inference_t2v_1024_v1.0.yaml"
        ckpt_path_i2v = "checkpoints/i2v_512_v1/model.ckpt"
        config_i2v = "configs/inference_i2v_512_v1.0.yaml"

        config_base = OmegaConf.load(config_base)
        model_config_base = config_base.pop("model", OmegaConf.create())
        self.model_base = instantiate_from_config(model_config_base)
        self.model_base = self.model_base.cuda()
        self.model_base = load_model_checkpoint(self.model_base, ckpt_path_base)
        self.model_base.eval()

        config_i2v = OmegaConf.load(config_i2v)
        model_config_i2v = config_i2v.pop("model", OmegaConf.create())
        self.model_i2v = instantiate_from_config(model_config_i2v)
        self.model_i2v = self.model_i2v.cuda()
        self.model_i2v = load_model_checkpoint(self.model_i2v, ckpt_path_i2v)
        self.model_i2v.eval()
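
    # The two model loads above repeat the same config -> instantiate -> cuda
    # -> load checkpoint -> eval sequence. A small helper along these lines
    # could express that once. This is a sketch only: `_load_model` is not
    # part of the original file, and the explicit per-model code above is
    # what actually runs.
    #
    #     def _load_model(self, config_path, ckpt_path):
    #         config = OmegaConf.load(config_path)
    #         model_config = config.pop("model", OmegaConf.create())
    #         model = instantiate_from_config(model_config).cuda()
    #         model = load_model_checkpoint(model, ckpt_path)
    #         model.eval()
    #         return model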

    def predict(
        self,
        task: str = Input(
            description="Choose the task.",
            choices=["text2video", "image2video"],
            default="text2video",
        ),
        prompt: str = Input(
            description="Prompt for video generation.",
            default="A tiger walks in the forest, photorealistic, 4k, high definition.",
        ),
        image: Path = Input(
            description="Input image for image2video task.", default=None
        ),
        ddim_steps: int = Input(description="Number of denoising steps.", default=50),
        unconditional_guidance_scale: float = Input(
            description="Classifier-free guidance scale.", default=12.0
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed.", default=None
        ),
        save_fps: int = Input(
            description="Frames per second for the generated video.", default=10
        ),
    ) -> Path:
        width = 1024 if task == "text2video" else 512
        height = 576 if task == "text2video" else 320
        model = self.model_base if task == "text2video" else self.model_i2v

        if task == "image2video":
            assert image is not None, "Please provide an image for image2video generation."

        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")
        seed_everything(seed)

        args = argparse.Namespace(
            mode="base" if task == "text2video" else "i2v",
            savefps=save_fps,
            n_samples=1,
            ddim_steps=ddim_steps,
            ddim_eta=1.0,
            bs=1,
            height=height,
            width=width,
            frames=-1,
            fps=28 if task == "text2video" else 8,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_guidance_scale_temporal=None,
        )

        ## latent noise shape: spatial dims are 8x smaller than the output video
        h, w = args.height // 8, args.width // 8
        frames = model.temporal_length if args.frames < 0 else args.frames
        channels = model.channels
        batch_size = 1
        noise_shape = [batch_size, channels, frames, h, w]
        fps = torch.tensor([args.fps] * batch_size).to(model.device).long()

        ## conditioning: text only for t2v, text + image embeddings for i2v
        prompts = [prompt]
        text_emb = model.get_learned_conditioning(prompts)
        if args.mode == "base":
            cond = {"c_crossattn": [text_emb], "fps": fps}
        elif args.mode == "i2v":
            cond_images = load_image_batch([str(image)], (args.height, args.width))
            cond_images = cond_images.to(model.device)
            img_emb = model.get_image_embeds(cond_images)
            imtext_cond = torch.cat([text_emb, img_emb], dim=1)
            cond = {"c_crossattn": [imtext_cond], "fps": fps}
        else:
            raise NotImplementedError

        ## inference
        batch_samples = batch_ddim_sampling(
            model,
            cond,
            noise_shape,
            args.n_samples,
            args.ddim_steps,
            args.ddim_eta,
            args.unconditional_guidance_scale,
        )

        ## write the sampled frames to an mp4 file
        out_path = "/tmp/output.mp4"
        vid_tensor = batch_samples[0]
        video = vid_tensor.detach().cpu()
        video = torch.clamp(video.float(), -1.0, 1.0)
        video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w
        frame_grids = [
            torchvision.utils.make_grid(framesheet, nrow=int(args.n_samples))
            for framesheet in video
        ]  # [3, 1*h, n*w]
        grid = torch.stack(frame_grids, dim=0)  # stack in temporal dim [t, 3, n*h, w]
        grid = (grid + 1.0) / 2.0
        grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
        torchvision.io.write_video(
            out_path,
            grid,
            fps=args.savefps,
            video_codec="h264",
            options={"crf": "10"},
        )

        return Path(out_path)
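
# Local usage sketch (not part of the predictor itself). Assuming Cog is
# installed and the VideoCrafter checkpoints and configs sit at the paths
# referenced in setup(), a prediction could be run from the repository root
# with the Cog CLI, e.g.:
#
#     cog predict -i task="text2video" \
#                 -i prompt="A tiger walks in the forest, photorealistic, 4k, high definition." \
#                 -i ddim_steps=50 \
#                 -i seed=123
#
# The generated video is written to /tmp/output.mp4 inside the container and
# returned as the prediction output.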