File size: 4,066 Bytes
d403233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Text-to-Video training pipeline for NOVA."""

from typing import Dict

from diffusers.pipelines.pipeline_utils import DiffusionPipeline
import torch

from diffnext.engine import engine_utils
from diffnext.pipelines.builder import build_diffusion_scheduler
from diffnext.pipelines.pipeline_utils import PipelineMixin


class NOVATrainT2VPipeline(DiffusionPipeline, PipelineMixin):
    """Pipeline for training NOVA T2V models.

    Bundles the transformer with its scheduler, VAE, text encoder and tokenizer.
    ``configure_model`` installs ``preprocess`` on the transformer as its
    ``pipeline_preprocess`` attribute, so a raw batch dict can be turned into
    latents (``inputs["x"]``) and text conditions (``inputs["c"]``).
    """

    _optional_components = ["transformer", "scheduler", "vae", "text_encoder", "tokenizer"]

    def __init__(
        self,
        transformer=None,
        scheduler=None,
        vae=None,
        text_encoder=None,
        tokenizer=None,
        trust_remote_code=True,
    ):
        """Register the components and attach the schedulers to the transformer.

        Args:
            transformer: The model to train. It receives the ``noise_scheduler``
                and ``sample_scheduler`` attributes below.
            scheduler: Scheduler used as the sampling scheduler and as the source
                for ``build_diffusion_scheduler``.
            vae: Autoencoder used by ``prepare_latents``.
            text_encoder: Text encoder handed to the transformer's text embedding
                in ``configure_model``.
            tokenizer: Tokenizer handed to the transformer's text embedding
                in ``configure_model``.
            trust_remote_code: Accepted but not referenced in this constructor.
        """
        super().__init__()
        self.vae = self.register_module(vae, "vae")
        self.text_encoder = self.register_module(text_encoder, "text_encoder")
        self.tokenizer = self.register_module(tokenizer, "tokenizer")
        self.transformer = self.register_module(transformer, "transformer")
        self.scheduler = self.register_module(scheduler, "scheduler")
        # NOTE(review): the transformer is dereferenced unconditionally here even though
        # it is listed in ``_optional_components``; this assumes ``register_module``
        # always yields a usable module for it - confirm against PipelineMixin.
        self.transformer.noise_scheduler = build_diffusion_scheduler(self.scheduler)
        self.transformer.sample_scheduler = self.scheduler
        self.guidance_scale = 5.0

    @property
    def model(self) -> torch.nn.Module:
        """Return the trainable model (the transformer)."""
        return self.transformer

    def configure_model(self, config, accelerator=None, logger=None) -> torch.nn.Module:
        """Configure the trainable model.

        Args:
            config: Configuration object whose ``model`` entry supports ``.get``.
                ``gradient_checkpointing`` (default ``0``) and ``loss_repeat``
                (default ``4``) are read from it.
            accelerator: Not used by this method.
            logger: Not used by this method.

        Returns:
            The result of ``self.model.train()``.
        """
        model = self.model
        ckpt_lvl = config.model.get("gradient_checkpointing", 0)
        model.loss_repeat = config.model.get("loss_repeat", 4)
        # The video encoder receives the raw level; the image encoder and the image
        # decoder only enable MLP checkpointing above levels 1 and 2 respectively.
        for blk in model.video_encoder.blocks:
            setattr(blk, "mlp_checkpointing", ckpt_lvl)
        if hasattr(model, "image_encoder"):
            for blk in model.image_encoder.blocks:
                setattr(blk, "mlp_checkpointing", ckpt_lvl > 1)
        for blk in model.image_decoder.blocks:
            setattr(blk, "mlp_checkpointing", ckpt_lvl > 2)
        engine_utils.freeze_module(model.text_embed.norm)  # We always use frozen LN.
        if model.motion_embed:
            engine_utils.freeze_module(model.motion_embed)
        model.pipeline_preprocess = self.preprocess
        model.text_embed.encoders = [self.tokenizer, self.text_encoder]
        return model.train()

    def prepare_latents(self, inputs: Dict) -> None:
        """Prepare the video latents.

        When ``inputs`` holds ``"latents"``, the entry is popped, moved to the
        pipeline device, passed through ``vae.latent_dist(...).sample()`` and
        ``vae.scale_``, and the result is stored as ``inputs["x"]``. When neither
        ``"images"`` nor ``"latents"`` is present, ``inputs`` is left untouched.

        Args:
            inputs: Batch dict, modified in place.

        Raises:
            NotImplementedError: If ``inputs`` contains ``"images"``.
        """
        if "images" in inputs:
            raise NotImplementedError(
                "Preparing latents from 'images' is not implemented; provide 'latents'."
            )
        if "latents" in inputs:
            latents = torch.as_tensor(inputs.pop("latents"), device=self.device)
            # Floating-point latents follow the pipeline dtype; anything else becomes int64.
            target_dtype = self.dtype if latents.is_floating_point() else torch.int64
            latents = latents.to(dtype=target_dtype)
            inputs["x"] = self.vae.scale_(self.vae.latent_dist(latents).sample())

    def encode_prompt(self, inputs: Dict) -> None:
        """Encode text prompts.

        Ensures ``inputs["c"]`` exists (an empty list when absent). If a non-``None``
        ``"prompt"`` is present and the transformer has a truthy ``text_embed``,
        the prompt is popped, embedded and appended to ``inputs["c"]``.

        Args:
            inputs: Batch dict, modified in place.
        """
        conditions = inputs.setdefault("c", [])
        if inputs.get("prompt", None) is not None and self.transformer.text_embed:
            conditions.append(self.transformer.text_embed(inputs.pop("prompt")))

    def preprocess(self, inputs: Dict) -> Dict:
        """Define the pipeline preprocess at every call.

        Args:
            inputs: Batch dict, modified in place by ``prepare_latents`` and
                ``encode_prompt``.

        Returns:
            The same ``inputs`` dict, after in-place preprocessing.

        Raises:
            RuntimeError: If the model is not in training mode.
        """
        if not self.model.training:
            raise RuntimeError("Expected a trainable model.")
        self.prepare_latents(inputs)
        self.encode_prompt(inputs)
        # Return the mutated dict so the declared ``-> Dict`` contract holds;
        # callers that ignore the return value are unaffected.
        return inputs