| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Text-to-Video training pipeline for NOVA.""" |
|
|
| from typing import Dict |
|
|
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline |
| import torch |
|
|
| from diffnext.engine import engine_utils |
| from diffnext.pipelines.builder import build_diffusion_scheduler |
| from diffnext.pipelines.pipeline_utils import PipelineMixin |
|
|
|
|
class NOVATrainT2VPipeline(DiffusionPipeline, PipelineMixin):
    """Pipeline for training NOVA T2V models.

    The pipeline registers the transformer (the trainable model), the scheduler,
    the VAE and the text encoder/tokenizer, and provides ``preprocess``, which
    ``configure_model`` installs on the model as ``pipeline_preprocess``.
    """

    _optional_components = ["transformer", "scheduler", "vae", "text_encoder", "tokenizer"]

    def __init__(
        self,
        transformer=None,
        scheduler=None,
        vae=None,
        text_encoder=None,
        tokenizer=None,
        trust_remote_code=True,
    ):
        """Register the components and attach the schedulers to the transformer.

        Args:
            transformer: The trainable model; receives ``noise_scheduler`` and
                ``sample_scheduler`` attributes.
            scheduler: Scheduler stored as the sample scheduler and passed to
                ``build_diffusion_scheduler`` to create the noise scheduler.
            vae: Autoencoder used by ``prepare_latents``.
            text_encoder: Text encoder later handed to ``text_embed.encoders``.
            tokenizer: Tokenizer later handed to ``text_embed.encoders``.
            trust_remote_code: Accepted for signature compatibility; not read here.
        """
        super().__init__()
        self.vae = self.register_module(vae, "vae")
        self.text_encoder = self.register_module(text_encoder, "text_encoder")
        self.tokenizer = self.register_module(tokenizer, "tokenizer")
        self.transformer = self.register_module(transformer, "transformer")
        self.scheduler = self.register_module(scheduler, "scheduler")
        self.transformer.noise_scheduler = build_diffusion_scheduler(self.scheduler)
        self.transformer.sample_scheduler = self.scheduler
        self.guidance_scale = 5.0

    @property
    def model(self) -> torch.nn.Module:
        """Return the trainable model."""
        return self.transformer

    def configure_model(self, config, accelerator=None, logger=None) -> torch.nn.Module:
        """Configure the trainable model.

        Args:
            config: Training config; ``config.model.get`` is queried for
                ``"gradient_checkpointing"`` (default 0) and ``"loss_repeat"``
                (default 4).
            accelerator: Not used by this method.
            logger: Not used by this method.

        Returns:
            The result of ``self.model.train()``.
        """
        ckpt_lvl = config.model.get("gradient_checkpointing", 0)
        model = self.model
        model.loss_repeat = config.model.get("loss_repeat", 4)
        # The video encoder blocks get the raw level; the image encoder and
        # decoder blocks get a flag that turns on above level 1 and 2.
        for blk in model.video_encoder.blocks:
            blk.mlp_checkpointing = ckpt_lvl
        # The image encoder is optional on the model, so only touch it if present.
        if hasattr(model, "image_encoder"):
            for blk in model.image_encoder.blocks:
                blk.mlp_checkpointing = ckpt_lvl > 1
        for blk in model.image_decoder.blocks:
            blk.mlp_checkpointing = ckpt_lvl > 2
        engine_utils.freeze_module(model.text_embed.norm)
        if model.motion_embed:
            engine_utils.freeze_module(model.motion_embed)
        model.pipeline_preprocess = self.preprocess
        model.text_embed.encoders = [self.tokenizer, self.text_encoder]
        return model.train()

    def prepare_latents(self, inputs: Dict):
        """Prepare the video latents.

        When ``inputs`` holds ``"latents"``, the entry is popped, moved to the
        pipeline device, cast (pipeline dtype for floating-point data, int64
        otherwise), passed through ``vae.latent_dist(...).sample()`` and
        ``vae.scale_``, and stored as ``inputs["x"]``. ``inputs`` is modified
        in place and left untouched when ``"latents"`` is absent.

        Args:
            inputs: The batch dictionary.

        Raises:
            NotImplementedError: If ``inputs`` contains ``"images"``.
        """
        if "images" in inputs:
            raise NotImplementedError("'images' inputs are not supported; pass 'latents'.")
        if "latents" not in inputs:
            return
        x = torch.as_tensor(inputs.pop("latents"), device=self.device)
        x = x.to(dtype=self.dtype if x.is_floating_point() else torch.int64)
        inputs["x"] = self.vae.scale_(self.vae.latent_dist(x).sample())

    def encode_prompt(self, inputs: Dict):
        """Encode text prompts.

        Ensures ``inputs["c"]`` exists (an empty list by default). If
        ``inputs["prompt"]`` is present and not ``None`` and the transformer has
        a truthy ``text_embed``, the prompt is popped and its embedding is
        appended to ``inputs["c"]``. ``inputs`` is modified in place.

        Args:
            inputs: The batch dictionary.
        """
        conditions = inputs.setdefault("c", [])
        if inputs.get("prompt", None) is not None and self.transformer.text_embed:
            conditions.append(self.transformer.text_embed(inputs.pop("prompt")))

    def preprocess(self, inputs: Dict) -> Dict:
        """Define the pipeline preprocess at every call.

        Args:
            inputs: The batch dictionary; modified in place.

        Returns:
            The same ``inputs`` dictionary after latents and prompts are prepared.

        Raises:
            RuntimeError: If the model is not in training mode.
        """
        if not self.model.training:
            raise RuntimeError("Expected a trainable model.")
        self.prepare_latents(inputs)
        self.encode_prompt(inputs)
        # Return the mutated dict so the declared ``-> Dict`` contract holds;
        # callers relying only on the in-place update are unaffected.
        return inputs
|
|