World_Model / URSA /diffnext /pipelines /nova /pipeline_nova.py
BryanW's picture
Add files using upload-large-folder tool
d403233 verified
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Video generation pipeline for NOVA."""
from typing import List
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
import numpy as np
import torch
from diffnext.image_processor import VaeImageProcessor
from diffnext.pipelines.pipeline_utils import NOVAPipelineOutput, PipelineMixin
class NOVAPipeline(DiffusionPipeline, PipelineMixin):
"""NOVA video generation pipeline."""
_optional_components = ["transformer", "scheduler", "vae", "text_encoder", "tokenizer"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
def __init__(
self,
transformer=None,
scheduler=None,
vae=None,
text_encoder=None,
tokenizer=None,
trust_remote_code=True,
):
super(NOVAPipeline, self).__init__()
self.vae = self.register_module(vae, "vae")
self.text_encoder = self.register_module(text_encoder, "text_encoder")
self.tokenizer = self.register_module(tokenizer, "tokenizer")
self.transformer = self.register_module(transformer, "transformer")
self.scheduler = self.register_module(scheduler, "scheduler")
self.transformer.sample_scheduler, self.guidance_scale = self.scheduler, 5.0
if self.transformer.text_embed:
self.tokenizer_max_length = self.transformer.text_embed.num_tokens
self.transformer.text_embed.encoders = [self.tokenizer, self.text_encoder]
self.image_processor = VaeImageProcessor()
@torch.no_grad()
def __call__(
self,
prompt=None,
num_inference_steps=64,
num_diffusion_steps=25,
max_latent_length=1,
guidance_scale=5,
guidance_trunc=0,
guidance_renorm=1,
image_guidance_scale=0,
spatiotemporal_guidance_scale=0,
flow_shift=None,
motion_score=5,
negative_prompt=None,
image=None,
num_images_per_prompt=1,
generator=None,
latents=None,
prompt_embeds=None,
negative_prompt_embeds=None,
disable_progress_bar=False,
output_type="pil",
**kwargs,
) -> NOVAPipelineOutput:
"""The call function to the pipeline for generation.
Args:
prompt (str or List[str], *optional*):
The prompt to be encoded.
num_inference_steps (int, *optional*, defaults to 64):
The number of autoregressive steps.
num_diffusion_steps (int, *optional*, defaults to 25):
The number of denoising steps.
max_latent_length (int, *optional*, defaults to 1):
The maximum number of latents to generate. ``1`` for image generation.
guidance_scale (float, *optional*, defaults to 5):
The classifier guidance scale.
guidance_trunc (float, *optional*, defaults to 0):
The truncation threshold to classifier guidance.
guidance_renorm (float, *optional*, defaults to 1):
The minimal renorm scale to classifier guidance.
image_guidance_scale (float, *optional*, defaults to 0):
The image guidance scale.
spatiotemporal_guidance_scale (float, *optional*, defaults to 0):
The spatiotemporal guidance scale.
flow_shift (float, *optional*)
The shift value for the timestep schedule.
motion_score (float, *optional*, defaults to 5):
The motion score value for video generation.
negative_prompt (str or List[str], *optional*):
The prompt or prompts to guide what to not include in image generation.
image (numpy.ndarray, *optional*):
The image to be encoded.
num_images_per_prompt (int, *optional*, defaults to 1):
The number of images that should be generated per prompt.
generator (torch.Generator, *optional*):
The random generator.
latents (List[torch.Tensor], *optional*)
A list of prefilled VAE latents.
prompt_embeds (List[torch.Tensor], *optional*)
A list of precomputed prompt embeddings.
negative_prompt_embeds (List[torch.Tensor], *optional*)
A list of precomputed negative prompt embeddings.
disable_progress_bar (bool, *optional*)
Whether to disable all progress bars.
output_type (str, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
Returns:
NOVAPipelineOutput: The pipeline output.
"""
self.guidance_scale = guidance_scale
self.scheduler.set_shift(flow_shift) if flow_shift else None
inputs = {"generator": generator, **locals()}
num_patches = int(np.prod(self.transformer.config.image_base_size))
mask_ratios = np.cos(0.5 * np.pi * np.arange(num_inference_steps + 1) / num_inference_steps)
mask_length = np.round(mask_ratios * num_patches).astype("int64")
inputs["num_preds"] = mask_length[:-1] - mask_length[1:]
inputs["tqdm1"] = max_latent_length > 1 and not disable_progress_bar
inputs["tqdm2"] = max_latent_length == 1 and not disable_progress_bar
inputs["prompt"] = self.encode_prompt(**dict(_ for _ in inputs.items() if "prompt" in _[0]))
inputs["latents"] = self.prepare_latents(image, num_images_per_prompt, generator, latents)
inputs["batch_size"] = len(inputs["prompt"]) // (2 if guidance_scale > 1 else 1)
inputs["motion"] = [motion_score] * inputs["batch_size"]
_, outputs = inputs.pop("self"), self.transformer(inputs)
if output_type != "latent":
outputs["x"] = self.image_processor.decode_latents(self.vae, outputs["x"])
output_name = {4: "images", 5: "frames"}[len(outputs["x"].shape)]
outputs["x"] = self.image_processor.postprocess(outputs["x"], output_type)
return NOVAPipelineOutput(**{output_name: outputs["x"]})
def prepare_latents(
self,
image=None,
num_images_per_prompt=1,
generator=None,
latents=None,
) -> List[torch.Tensor]:
"""Prepare the video latents.
Args:
image (numpy.ndarray, *optional*):
The image to be encoded.
num_images_per_prompt (int, *optional*, defaults to 1):
The number of images that should be generated per prompt.
generator (torch.Generator, *optional*):
The random generator.
latents (List[torch.Tensor], *optional*)
A list of prefilled VAE latents.
Returns:
List[torch.Tensor]: The encoded latents.
"""
if latents is not None:
return latents
latents = []
if image is not None:
latents.append(self.encode_image(image, num_images_per_prompt, generator))
return latents
def encode_prompt(
self,
prompt,
num_images_per_prompt=1,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
) -> torch.Tensor:
"""Encode text prompts.
Args:
prompt (str or List[str], *optional*):
The prompt to be encoded.
num_images_per_prompt (int, *optional*, defaults to 1):
The number of images that should be generated per prompt.
negative_prompt (str or List[str], *optional*):
The prompt or prompts not to guide the image generation.
prompt_embeds (List[torch.Tensor], *optional*)
A list of precomputed prompt embeddings.
negative_prompt_embeds (List[torch.Tensor], *optional*)
A list of precomputed negative prompt embeddings.
Returns:
torch.Tensor: The prompt embedding.
"""
def select_or_pad(a, b, n=1):
return [a or b] * n if isinstance(a or b, str) else (a or b)
embedder = self.transformer.text_embed
if prompt_embeds is not None:
prompt_embeds = embedder.encode_prompts(prompt_embeds)
if negative_prompt_embeds is not None:
negative_prompt_embeds = embedder.encode_prompts(negative_prompt_embeds)
if prompt_embeds is not None:
if negative_prompt_embeds is None and self.guidance_scale > 1:
bs, seqlen = prompt_embeds.shape[:2]
negative_prompt_embeds = embedder.weight[:seqlen].expand(bs, -1, -1)
if self.guidance_scale > 1:
c = torch.cat([prompt_embeds, negative_prompt_embeds])
return c.repeat_interleave(num_images_per_prompt, dim=0)
prompt = [prompt] if isinstance(prompt, str) else prompt
negative_prompt = select_or_pad(negative_prompt, "", len(prompt))
prompts = prompt + (negative_prompt if self.guidance_scale > 1 else [])
c = embedder.encode_prompts(prompts)
return c.repeat_interleave(num_images_per_prompt, dim=0)
def encode_image(self, image, num_images_per_prompt=1, generator=None) -> torch.Tensor:
"""Encode image prompt.
Args:
image (numpy.ndarray):
The image to be encoded.
num_images_per_prompt (int):
The number of images that should be generated per prompt.
generator (torch.Generator, *optional*):
The random generator.
Returns:
torch.Tensor: The image embedding.
"""
x = torch.as_tensor(image, device=self.device).to(dtype=self.dtype)
x = x.sub(127.5).div_(127.5).permute(2, 0, 1).unsqueeze_(0)
x = self.vae.scale_(self.vae.encode(x).latent_dist.sample(generator))
return x.expand(num_images_per_prompt, -1, -1, -1)