Spaces:
Runtime error
Runtime error
| import argparse, os, sys, glob | |
| import datetime, time | |
| from omegaconf import OmegaConf | |
| import math | |
| import torch | |
| from decord import VideoReader, cpu | |
| import torchvision | |
| from pytorch_lightning import seed_everything | |
| from lvdm.samplers.ddim import DDIMSampler | |
| from lvdm.utils.common_utils import instantiate_from_config | |
| from lvdm.utils.saving_utils import tensor_to_mp4 | |
| from scripts.sample_text2video_adapter import load_model_checkpoint, adapter_guided_synthesis | |
| import torchvision.transforms._transforms_video as transforms_video | |
| from huggingface_hub import hf_hub_download | |
def load_video(filepath, frame_stride, video_size=(256,256), video_frames=16):
    """Decode `video_frames` frames from `filepath`, resized to `video_size`.

    Returns a float tensor of shape [c, t, h, w] scaled to [-1, 1] and an
    info string describing any frame-stride adjustments that were applied.
    A `frame_stride` of 0 means "auto": spread samples evenly over the clip.
    """
    info_str = ''
    reader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
    total_frames = len(reader)

    # A user-chosen stride that would overrun the clip falls back to auto mode.
    if frame_stride != 0 and frame_stride * (video_frames - 1) >= total_frames:
        info_str += "Warning: The user-set frame rate makes the current video length not enough, we will set it to an adaptive frame rate.\n"
        frame_stride = 0
    # Auto mode: spread the sampled frames evenly across the whole clip.
    if frame_stride == 0:
        frame_stride = total_frames / video_frames
    # Cap the stride so at most the first stride*video_frames frames are read.
    if frame_stride > 100:
        frame_stride = 100
        info_str += "Warning: The current input video length is longer than 1600 frames, we will process only the first 1600 frames.\n"
    info_str += f"Frame Stride is set to {frame_stride}"

    indices = [int(frame_stride * i) for i in range(video_frames)]
    batch = reader.get_batch(indices)
    # [t, h, w, c] -> [c, t, h, w]; rescale uint8 [0, 255] to float [-1, 1].
    frame_tensor = torch.tensor(batch.asnumpy()).permute(3, 0, 1, 2).float()
    frame_tensor = (frame_tensor / 255. - 0.5) * 2
    return frame_tensor, info_str
class VideoControl:
    """Depth-adapter-guided video-to-video synthesis wrapper (VideoCrafter).

    Loads the base text2video checkpoint plus the depth adapter once at
    construction, then `get_video` runs depth-conditioned synthesis on an
    uploaded clip and writes the original / depth / sampled videos as mp4.
    """

    def __init__(self, result_dir='./tmp/') -> None:
        self.savedir = result_dir
        self.download_model()
        config_path = "models/adapter_t2v_depth/model_config.yaml"
        ckpt_path = "models/base_t2v/model_rm_wtm.ckpt"
        adapter_ckpt = "models/adapter_t2v_depth/adapter_t2v_depth_rm_wtm.pth"
        # Prefer a checkpoint copy staged in shared memory when available
        # (faster load, e.g. on a Spaces deployment).
        if os.path.exists('/dev/shm/model_rm_wtm.ckpt'):
            ckpt_path = '/dev/shm/model_rm_wtm.ckpt'
        config = OmegaConf.load(config_path)
        model_config = config.pop("model", OmegaConf.create())
        model = instantiate_from_config(model_config)
        model = model.to('cuda')
        assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
        model = load_model_checkpoint(model, ckpt_path, adapter_ckpt)
        model.eval()
        self.model = model

    def get_video(self, input_video, input_prompt, frame_stride=0, vc_steps=50, vc_cfg_scale=15.0, vc_eta=1.0, video_frames=16, resolution=256):
        """Synthesize a video guided by the depth of `input_video` and `input_prompt`.

        Returns a 4-tuple (info_str, origin_path, depth_path, video_path);
        on failure the three paths are None and info_str carries the error.
        """
        torch.cuda.empty_cache()
        # Clamp user-controlled knobs to safe ranges.
        if resolution > 512:
            resolution = 512
        if resolution < 64:
            resolution = 64
        if video_frames > 64:
            video_frames = 64
        resolution = (resolution // 64) * 64  # model needs a multiple of 64
        if vc_steps > 60:
            vc_steps = 60

        ## load video
        print("input video", input_video)
        info_str = ''
        try:
            h, w, c = VideoReader(input_video, ctx=cpu(0))[0].shape
        except Exception:  # unreadable or missing upload (was a bare except)
            os.remove(input_video)
            return 'please input video', None, None, None

        # Scale so the longer side equals `resolution`, preserving aspect ratio.
        if h > w:
            scale = h / resolution
        else:
            scale = w / resolution
        h = math.ceil(h / scale)
        w = math.ceil(w / scale)
        try:
            video, info_str = load_video(input_video, frame_stride, video_size=(h, w), video_frames=video_frames)
        except Exception:  # decode failure (was a bare except)
            os.remove(input_video)
            return 'load video error', None, None, None

        # Center-crop the shorter side down to a multiple of 64.
        if h > w:
            w = (w // 64) * 64
        else:
            h = (h // 64) * 64
        spatial_transform = transforms_video.CenterCropVideo((h, w))
        video = spatial_transform(video)
        print('video shape', video.shape)

        rh, rw = h // 8, w // 8  # latent resolution; presumably VAE downsamples 8x — TODO confirm
        bs = 1
        channels = self.model.channels
        frames = video_frames
        noise_shape = [bs, channels, frames, rh, rw]

        ## inference
        start = time.time()
        prompt = input_prompt
        video = video.unsqueeze(0).to("cuda")
        try:
            with torch.no_grad():
                batch_samples, batch_conds = adapter_guided_synthesis(self.model, prompt, video, noise_shape, n_samples=1, ddim_steps=vc_steps, ddim_eta=vc_eta, unconditional_guidance_scale=vc_cfg_scale)
        except Exception:  # most likely CUDA OOM (was a bare except)
            torch.cuda.empty_cache()
            info_str = "OOM, please enter a smaller resolution or smaller frame num"
            return info_str, None, None, None

        batch_samples = batch_samples[0]
        os.makedirs(self.savedir, exist_ok=True)
        # Derive a filesystem-safe filename stem from the prompt.
        filename = prompt.replace("/", "_slash_").replace(" ", "_")
        if len(filename) > 200:
            filename = filename[:200]
        # BUGFIX: the output paths previously ignored `filename` entirely
        # (it was computed but unused), so every run wrote to the same
        # files; embed the sanitized prompt in each output path.
        video_path = os.path.join(self.savedir, f'{filename}_sample.mp4')
        depth_path = os.path.join(self.savedir, f'{filename}_depth.mp4')
        origin_path = os.path.join(self.savedir, f'{filename}.mp4')
        tensor_to_mp4(video=video.detach().cpu(), savepath=origin_path, fps=8)
        tensor_to_mp4(video=batch_conds.detach().cpu(), savepath=depth_path, fps=8)
        tensor_to_mp4(video=batch_samples.detach().cpu(), savepath=video_path, fps=8)
        print(f"Saved in {video_path}. Time used: {(time.time() - start):.2f} seconds")

        # Delete the uploaded video, but keep the bundled demo clip.
        path, input_filename = os.path.split(input_video)
        if input_filename != 'flamingo.mp4':
            os.remove(input_video)
            print('delete input video')
        return info_str, origin_path, depth_path, video_path

    def download_model(self):
        """Fetch required checkpoints from the Hugging Face Hub if absent on disk."""
        REPO_ID = 'VideoCrafter/t2v-version-1-1'
        filename_list = ['models/base_t2v/model_rm_wtm.ckpt',
                         "models/adapter_t2v_depth/adapter_t2v_depth_rm_wtm.pth",
                         "models/adapter_t2v_depth/dpt_hybrid-midas.pt"]
        for filename in filename_list:
            if not os.path.exists(filename):
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./', local_dir_use_symlinks=False)
if __name__ == "__main__":
    vc = VideoControl('./result')
    # BUGFIX: get_video returns a 4-tuple
    # (info_str, origin_path, depth_path, video_path); unpacking only two
    # values raised "ValueError: too many values to unpack" at runtime.
    info_str, origin_path, depth_path, video_path = vc.get_video(
        'input/flamingo.mp4',
        "An ostrich walking in the desert, photorealistic, 4k")