# World_Model/URSA/diffnext/pipelines/nova/pipeline_train_t2v.py
# Uploaded by BryanW ("Add files using upload-large-folder tool"), commit d403233 (verified).
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Text-to-Video training pipeline for NOVA."""
from typing import Dict
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
import torch
from diffnext.engine import engine_utils
from diffnext.pipelines.builder import build_diffusion_scheduler
from diffnext.pipelines.pipeline_utils import PipelineMixin
class NOVATrainT2VPipeline(DiffusionPipeline, PipelineMixin):
    """Pipeline for training NOVA T2V models.

    Registers the transformer, scheduler, VAE, text encoder and tokenizer, and installs
    ``preprocess`` on the transformer so raw batch dicts (latents + prompts) can be turned
    into model inputs during training.
    """

    _optional_components = ["transformer", "scheduler", "vae", "text_encoder", "tokenizer"]

    def __init__(
        self,
        transformer=None,
        scheduler=None,
        vae=None,
        text_encoder=None,
        tokenizer=None,
        trust_remote_code=True,
    ):
        """Register the pipeline components.

        Args:
            transformer: The NOVA transformer to train. Optional; scheduler wiring is
                skipped when it is ``None``.
            scheduler: Scheduler set as ``transformer.sample_scheduler``; a
                ``transformer.noise_scheduler`` is also built from it.
            vae: VAE used by ``prepare_latents``.
            text_encoder: Text encoder handed to ``transformer.text_embed`` in ``configure_model``.
            tokenizer: Tokenizer handed to ``transformer.text_embed`` in ``configure_model``.
            trust_remote_code: Accepted for loader compatibility; not used here.
        """
        super().__init__()
        self.vae = self.register_module(vae, "vae")
        self.text_encoder = self.register_module(text_encoder, "text_encoder")
        self.tokenizer = self.register_module(tokenizer, "tokenizer")
        self.transformer = self.register_module(transformer, "transformer")
        self.scheduler = self.register_module(scheduler, "scheduler")
        self.guidance_scale = 5.0
        # "transformer" is declared optional above; previously a missing transformer crashed
        # here with AttributeError, so only wire the schedulers when it is present.
        if self.transformer is not None:
            self.transformer.noise_scheduler = build_diffusion_scheduler(self.scheduler)
            self.transformer.sample_scheduler = self.scheduler

    @property
    def model(self) -> torch.nn.Module:
        """Return the trainable model (the transformer)."""
        return self.transformer

    def configure_model(self, config, accelerator=None, logger=None) -> torch.nn.Module:
        """Configure the trainable model.

        Args:
            config: Training config. Reads ``config.model.gradient_checkpointing``
                (default 0) and ``config.model.loss_repeat`` (default 4).
            accelerator: Unused; kept for interface compatibility.
            logger: Unused; kept for interface compatibility.

        Returns:
            The transformer, switched to training mode via ``nn.Module.train()``.
        """
        ckpt_lvl = config.model.get("gradient_checkpointing", 0)
        self.model.loss_repeat = config.model.get("loss_repeat", 4)

        # MLP checkpointing is enabled stage by stage as the level grows:
        # video encoder gets the raw level, image encoder ``level > 1``,
        # image decoder ``level > 2``.
        for blk in self.model.video_encoder.blocks:
            blk.mlp_checkpointing = ckpt_lvl
        image_encoder = getattr(self.model, "image_encoder", None)
        if image_encoder is not None:
            for blk in image_encoder.blocks:
                blk.mlp_checkpointing = ckpt_lvl > 1
        for blk in self.model.image_decoder.blocks:
            blk.mlp_checkpointing = ckpt_lvl > 2

        engine_utils.freeze_module(self.model.text_embed.norm)  # We always use frozen LN.
        motion_embed = getattr(self.model, "motion_embed", None)
        if motion_embed:
            engine_utils.freeze_module(motion_embed)

        # NOTE(review): presumably the transformer calls ``pipeline_preprocess`` on each
        # batch during its forward pass -- confirm against the transformer implementation.
        self.model.pipeline_preprocess = self.preprocess
        # A plain list (not nn.ModuleList) is not registered as submodules by
        # nn.Module.__setattr__, so these encoders stay out of the model's parameters.
        self.model.text_embed.encoders = [self.tokenizer, self.text_encoder]
        return self.model.train()

    def prepare_latents(self, inputs: Dict):
        """Prepare the video latents in ``inputs`` in place.

        Pops ``inputs["latents"]`` and moves it to the pipeline device. It is then passed
        through ``vae.latent_dist(...).sample()`` and ``vae.scale_``, and the result is
        stored in ``inputs["x"]``. If neither ``"images"`` nor ``"latents"`` is present,
        ``inputs`` is left unchanged.

        Raises:
            NotImplementedError: If ``inputs`` carries raw ``"images"``.
        """
        if "images" in inputs:
            raise NotImplementedError(
                "Encoding raw 'images' is not supported by this pipeline; provide 'latents'."
            )
        if "latents" in inputs:
            x = torch.as_tensor(inputs.pop("latents"), device=self.device)
            # Floating-point latents follow the pipeline dtype; integer ones become int64.
            x = x.to(dtype=self.dtype if x.is_floating_point() else torch.int64)
            inputs["x"] = self.vae.scale_(self.vae.latent_dist(x).sample())

    def encode_prompt(self, inputs: Dict):
        """Encode text prompts in ``inputs`` in place.

        ``inputs["c"]`` defaults to an empty list. When a non-None ``"prompt"`` is given and
        the transformer has a truthy ``text_embed``, the prompt is popped and its embedding
        is appended to ``inputs["c"]``.
        """
        inputs.setdefault("c", [])
        if inputs.get("prompt", None) is not None and self.transformer.text_embed:
            inputs["c"].append(self.transformer.text_embed(inputs.pop("prompt")))

    def preprocess(self, inputs: Dict) -> Dict:
        """Define the pipeline preprocess at every call.

        Runs ``prepare_latents`` and then ``encode_prompt`` on ``inputs``, which is mutated
        in place, and returns the same dict.

        Raises:
            RuntimeError: If the transformer is not in training mode.
        """
        if not self.model.training:
            raise RuntimeError("Expected a trainable model.")
        self.prepare_latents(inputs)
        self.encode_prompt(inputs)
        return inputs