| import math |
| import os |
| import torch |
| import argparse |
| import torchvision |
|
|
| from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler, |
| EulerDiscreteScheduler, DPMSolverMultistepScheduler, |
| HeunDiscreteScheduler, EulerAncestralDiscreteScheduler, |
| DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler) |
| from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler |
| from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder |
| from omegaconf import OmegaConf |
| from torchvision.utils import save_image |
| from transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer |
|
|
| import os, sys |
|
|
| from opensora.models.ae import ae_stride_config, getae, getae_wrapper |
| from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper |
| from opensora.models.diffusion.latte.modeling_latte import LatteT2V |
| from opensora.models.text_encoder import get_text_enc |
| from opensora.utils.utils import save_video_grid |
|
|
| sys.path.append(os.path.split(sys.path[0])[0]) |
| from pipeline_videogen import VideoGenPipeline |
|
|
| import imageio |
|
|
|
|
def _load_prompts(text_prompt):
    """Normalize the ``--text_prompt`` value into a non-empty list of prompts.

    Args:
        text_prompt: A single prompt string, a list of prompt strings, or a
            one-element list whose entry ends with ``txt``; in the last case
            the entry is opened as a file and read one prompt per line.

    Returns:
        A non-empty list of prompt strings.

    Raises:
        ValueError: If no prompt was supplied or the prompt file has no
            non-blank lines.
    """
    if text_prompt is None:
        raise ValueError('No prompt given: pass --text_prompt with one or more '
                         'prompts, or the path to a .txt file of prompts.')
    prompts = text_prompt if isinstance(text_prompt, list) else [text_prompt]
    if len(prompts) == 1 and prompts[0].endswith('txt'):
        # `with` guarantees the file handle is closed even if reading fails.
        with open(prompts[0], 'r') as prompt_file:
            stripped = [line.strip() for line in prompt_file]
        # Skip blank lines: an empty prompt would produce an output file whose
        # name collides with the final grid file written by main().
        prompts = [line for line in stripped if line]
    if not prompts:
        raise ValueError('No prompts to process: the prompt list is empty.')
    return prompts


def _build_scheduler(sample_method):
    """Instantiate the diffusers scheduler selected by ``--sample_method``.

    Args:
        sample_method: Name of the sampling method (e.g. ``'PNDM'``).

    Returns:
        A scheduler instance constructed with its default arguments.

    Raises:
        ValueError: If ``sample_method`` is not a supported name.
    """
    # Built inside the function so the scheduler classes are only looked up
    # when a scheduler is actually requested.
    scheduler_classes = {
        'DDIM': DDIMScheduler,
        'EulerDiscrete': EulerDiscreteScheduler,
        'DDPM': DDPMScheduler,
        'DPMSolverMultistep': DPMSolverMultistepScheduler,
        'DPMSolverSinglestep': DPMSolverSinglestepScheduler,
        'PNDM': PNDMScheduler,
        'HeunDiscrete': HeunDiscreteScheduler,
        'EulerAncestralDiscrete': EulerAncestralDiscreteScheduler,
        'DEISMultistep': DEISMultistepScheduler,
        'KDPM2AncestralDiscrete': KDPM2AncestralDiscreteScheduler,
    }
    if sample_method not in scheduler_classes:
        raise ValueError('Unknown sample_method {!r}; expected one of: {}'.format(
            sample_method, ', '.join(scheduler_classes)))
    return scheduler_classes[sample_method]()


def main(args):
    """Generate one video (or image) per text prompt and save them to disk.

    Loads the VAE, the LatteT2V transformer and the T5 text encoder in
    float16, runs ``VideoGenPipeline`` once per prompt, writes each result to
    ``args.save_img_path`` and finally writes a combined grid of all results.

    Args:
        args: Parsed command-line namespace. Reads ``model_path``, ``version``,
            ``ae``, ``text_encoder_name``, ``save_img_path``,
            ``guidance_scale``, ``sample_method``, ``num_sampling_steps``,
            ``fps``, ``text_prompt``, ``force_images``, ``tile_overlap_factor``
            and ``enable_tiling``. ``args.text_prompt`` is replaced in place by
            the resolved list of prompts.

    Raises:
        ValueError: If no prompts are available or ``args.sample_method`` is
            not a supported scheduler name.
    """
    # Validate the cheap inputs first so a bad invocation fails before the
    # model weights are loaded.
    args.text_prompt = _load_prompts(args.text_prompt)
    scheduler = _build_scheduler(args.sample_method)

    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir').to(device, dtype=torch.float16)
    if args.enable_tiling:
        vae.vae.enable_tiling()
        vae.vae.tile_overlap_factor = args.tile_overlap_factor

    transformer_model = LatteT2V.from_pretrained(args.model_path, subfolder=args.version, cache_dir="cache_dir", torch_dtype=torch.float16).to(device)
    transformer_model.force_images = args.force_images
    tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name, cache_dir="cache_dir")
    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir="cache_dir", torch_dtype=torch.float16).to(device)

    # The version string looks like '65x512x512'; its second field is used as
    # the (square) output resolution.
    video_length = transformer_model.config.video_length
    image_size = int(args.version.split('x')[1])
    latent_size = (image_size // ae_stride_config[args.ae][1], image_size // ae_stride_config[args.ae][2])
    vae.latent_size = latent_size
    if args.force_images:
        video_length = 1
        ext = 'jpg'
    else:
        ext = 'mp4'

    transformer_model.eval()
    vae.eval()
    text_encoder.eval()

    print('videogen_pipeline', device)
    videogen_pipeline = VideoGenPipeline(vae=vae,
                                         text_encoder=text_encoder,
                                         tokenizer=tokenizer,
                                         scheduler=scheduler,
                                         transformer=transformer_model).to(device=device)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.save_img_path, exist_ok=True)

    # Shared filename tail for every output written by this run.
    name_suffix = f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'

    video_grids = []
    for prompt in args.text_prompt:
        print('Processing the ({}) prompt'.format(prompt))
        videos = videogen_pipeline(prompt,
                                   video_length=video_length,
                                   height=image_size,
                                   width=image_size,
                                   num_inference_steps=args.num_sampling_steps,
                                   guidance_scale=args.guidance_scale,
                                   enable_temporal_attentions=not args.force_images,
                                   num_images_per_prompt=1,
                                   mask_feature=True,
                                   ).video
        out_path = os.path.join(args.save_img_path, prompt.replace(' ', '_')[:100] + name_suffix)
        # Saving is best-effort: a failure for one prompt is reported and the
        # remaining prompts are still processed. Only Exception is caught so
        # Ctrl-C / SystemExit still abort the run.
        try:
            if args.force_images:
                # Keep index 0 of dim 1 and move the last axis to position 1
                # (presumably (B, T, H, W, C) -> (B, C, H, W); confirm against
                # the pipeline's output layout).
                videos = videos[:, 0].permute(0, 3, 1, 2)
                save_image(videos / 255.0, out_path,
                           nrow=1, normalize=True, value_range=(0, 1))
            else:
                imageio.mimwrite(out_path, videos[0], fps=args.fps, quality=9)
        except Exception as exc:
            print('Error when saving {}: {!r}'.format(prompt, exc))
        video_grids.append(videos)
    video_grids = torch.cat(video_grids, dim=0)

    grid_path = os.path.join(args.save_img_path, name_suffix)
    if args.force_images:
        save_image(video_grids / 255.0, grid_path,
                   nrow=math.ceil(math.sqrt(len(video_grids))), normalize=True, value_range=(0, 1))
    else:
        video_grids = save_video_grid(video_grids)
        imageio.mimwrite(grid_path, video_grids, fps=args.fps, quality=9)

    print('save path {}'.format(args.save_img_path))
|
|
| |
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the sampling options and run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default='LanguageBind/Open-Sora-Plan-v1.0.0',
                        help="Model path or hub id the VAE and transformer are loaded from.")
    parser.add_argument("--version", type=str, default='65x512x512', choices=['65x512x512', '65x256x256', '17x256x256'],
                        help="Transformer subfolder; its second 'x'-separated field is used as the image size.")
    parser.add_argument("--ae", type=str, default='CausalVAEModel_4x8x8',
                        help="Autoencoder name used to pick the wrapper and stride config.")
    parser.add_argument("--text_encoder_name", type=str, default='DeepFloyd/t5-v1_1-xxl',
                        help="T5 tokenizer / encoder name or path.")
    parser.add_argument("--save_img_path", type=str, default="./sample_videos/t2v",
                        help="Directory the outputs are written to (created if missing).")
    parser.add_argument("--guidance_scale", type=float, default=7.5,
                        help="Guidance scale passed to the pipeline.")
    # Restricting the choices makes a typo fail at parse time rather than
    # after the model weights have been loaded.
    parser.add_argument("--sample_method", type=str, default="PNDM",
                        choices=['DDIM', 'EulerDiscrete', 'DDPM', 'DPMSolverMultistep',
                                 'DPMSolverSinglestep', 'PNDM', 'HeunDiscrete',
                                 'EulerAncestralDiscrete', 'DEISMultistep',
                                 'KDPM2AncestralDiscrete'],
                        help="Scheduler used for sampling.")
    parser.add_argument("--num_sampling_steps", type=int, default=50,
                        help="Number of inference steps passed to the pipeline.")
    parser.add_argument("--fps", type=int, default=24,
                        help="Frame rate of the written videos.")
    # NOTE(review): --run_time is not read by main() in this file.
    parser.add_argument("--run_time", type=int, default=0)
    # Required: main() cannot do anything without at least one prompt.
    parser.add_argument("--text_prompt", nargs='+', required=True,
                        help="One or more prompts, or a single path ending in 'txt' with one prompt per line.")
    parser.add_argument('--force_images', action='store_true',
                        help="Generate single images (saved as jpg) instead of videos.")
    parser.add_argument('--tile_overlap_factor', type=float, default=0.25,
                        help="Tile overlap factor set on the VAE when --enable_tiling is given.")
    parser.add_argument('--enable_tiling', action='store_true',
                        help="Enable tiling on the VAE.")
    args = parser.parse_args()

    main(args)