import torch
from diffusers import EDMEulerScheduler
from megatron.core import parallel_state
from torch import Tensor

from cosmos_predict1.diffusion.conditioner import BaseVideoCondition
from cosmos_predict1.diffusion.module import parallel
from cosmos_predict1.diffusion.module.blocks import FourierFeatures
from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp
from cosmos_predict1.diffusion.module.pretrained_vae import BaseVAE
from cosmos_predict1.diffusion.training.utils.layer_control.peft_control_config_parser import LayerControlConfigParser
from cosmos_predict1.diffusion.training.utils.peft.peft import add_lora_layers, setup_lora_requires_grad
from cosmos_predict1.utils import log, misc
from cosmos_predict1.utils.distributed import get_rank
from cosmos_predict1.utils.lazy_config import instantiate as lazy_instantiate


class DiffusionT2WModel(torch.nn.Module):
    """Text-to-world diffusion model that generates video frames from text descriptions.

    This model implements a diffusion-based approach for generating videos conditioned on text input.
    It handles the full pipeline including encoding/decoding through a VAE, diffusion sampling,
    and classifier-free guidance.
    """

    def __init__(self, config):
        """Initialize the diffusion model.

        Args:
            config: Configuration object containing model parameters and architecture settings
        """
        super().__init__()
        self.config = config

        self.precision = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }[config.precision]
        self.tensor_kwargs = {"device": "cuda", "dtype": self.precision}
        log.debug(f"DiffusionModel: precision {self.precision}")

        self.sigma_data = config.sigma_data
        self.state_shape = list(config.latent_shape)
        self.setup_data_key()

        self.scheduler = EDMEulerScheduler(sigma_max=80, sigma_min=0.0002, sigma_data=self.sigma_data)
        self.tokenizer = None
        self.model = None

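    # Configuration fields this class reads (summarized here for reference, derived
    # from how this file uses `config`; the authoritative schema lives in the
    # experiment configs):
    #   config.precision      -- "float32" | "float16" | "bfloat16"
    #   config.sigma_data     -- EDM sigma_data constant used for latent scaling
    #   config.latent_shape   -- shape of the latent state tensor
    #   config.input_data_key -- batch key holding the input data
    #   config.tokenizer / config.net / config.conditioner -- lazily instantiated modules
    #   config.peft_control   -- optional LoRA/PEFT layer-control config
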
    @property
    def net(self):
        return self.model.net

    @property
    def conditioner(self):
        return self.model.conditioner

    @property
    def logvar(self):
        return self.model.logvar

    def set_up_tokenizer(self, tokenizer_dir: str):
        """Instantiate the tokenizer (VAE) from the config and load its weights from tokenizer_dir."""
        self.tokenizer: BaseVAE = lazy_instantiate(self.config.tokenizer)
        self.tokenizer.load_weights(tokenizer_dir)
        if hasattr(self.tokenizer, "reset_dtype"):
            self.tokenizer.reset_dtype()

    @misc.timer("DiffusionModel: set_up_model")
    def set_up_model(self, memory_format: torch.memory_format = torch.preserve_format):
        """Initialize the core model components including network, conditioner and logvar."""
        self.model = self.build_model()
        if self.config.peft_control and self.config.peft_control.enabled:
            log.info("Setting up LoRA layers")
            peft_control_config_parser = LayerControlConfigParser(config=self.config.peft_control)
            peft_control_config = peft_control_config_parser.parse()
            add_lora_layers(self.model, peft_control_config)
            num_lora_params = setup_lora_requires_grad(self.model)
            self.model.requires_grad_(False)
            if num_lora_params == 0:
                raise ValueError("No LoRA parameters found. Please check the model configuration.")
        self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs)

    def build_model(self) -> torch.nn.ModuleDict:
        """Construct the model's neural network components.

        Returns:
            ModuleDict containing the network, conditioner and logvar components
        """
        config = self.config
        net = lazy_instantiate(config.net)
        conditioner = lazy_instantiate(config.conditioner)
        logvar = torch.nn.Sequential(
            FourierFeatures(num_channels=128, normalize=True), torch.nn.Linear(128, 1, bias=False)
        )

        return torch.nn.ModuleDict(
            {
                "net": net,
                "conditioner": conditioner,
                "logvar": logvar,
            }
        )

    @torch.no_grad()
    def encode(self, state: torch.Tensor) -> torch.Tensor:
        """Encode input state into latent representation using VAE.

        Args:
            state: Input tensor to encode

        Returns:
            Encoded latent representation scaled by sigma_data
        """
        # Scale latents to the EDM convention, where data has standard deviation sigma_data.
        return self.tokenizer.encode(state) * self.sigma_data

    @torch.no_grad()
    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        """Decode latent representation back to pixel space using VAE.

        Args:
            latent: Latent tensor to decode

        Returns:
            Decoded tensor in pixel space
        """
        # Undo the sigma_data scaling applied in encode() before decoding.
        return self.tokenizer.decode(latent / self.sigma_data)

    def setup_data_key(self) -> None:
        """Configure the key used to read input data from a data batch."""
        self.input_data_key = self.config.input_data_key

    def generate_samples_from_batch(
        self,
        data_batch: dict,
        guidance: float = 1.5,
        seed: int = 1,
        state_shape: tuple | None = None,
        n_sample: int | None = 1,
        is_negative_prompt: bool = False,
        num_steps: int = 35,
    ) -> Tensor:
        """Generate samples from a data batch using diffusion sampling.

        This function generates samples from either image or video data batches using diffusion sampling.
        It handles both conditional and unconditional generation with classifier-free guidance.

        Args:
            data_batch (dict): Raw data batch from the training data loader
            guidance (float, optional): Classifier-free guidance weight. Defaults to 1.5.
            seed (int, optional): Random seed for reproducibility. Defaults to 1.
            state_shape (tuple | None, optional): Shape of the state tensor. Uses self.state_shape if None. Defaults to None.
            n_sample (int | None, optional): Number of samples to generate. Defaults to 1.
            is_negative_prompt (bool, optional): Whether to use negative prompt for unconditional generation. Defaults to False.
            num_steps (int, optional): Number of diffusion sampling steps. Defaults to 35.

        Returns:
            Tensor: Generated samples after diffusion sampling
        """
        condition, uncondition = self._get_conditions(data_batch, is_negative_prompt)

        self.scheduler.set_timesteps(num_steps)

        if state_shape is None:
            state_shape = self.state_shape

        # Draw the initial noise from a seeded generator so runs are reproducible,
        # as the seed argument promises.
        generator = torch.Generator().manual_seed(seed)
        xt = torch.randn(size=(n_sample,) + tuple(state_shape), generator=generator) * self.scheduler.init_noise_sigma

        # Under context parallelism, each rank denoises its own temporal slice of the latent.
        to_cp = self.net.is_context_parallel_enabled
        if to_cp:
            xt = split_inputs_cp(x=xt, seq_dim=2, cp_group=self.net.cp_group)

        for t in self.scheduler.timesteps:
            xt = xt.to(**self.tensor_kwargs)
            xt_scaled = self.scheduler.scale_model_input(xt, timestep=t)

            t = t.to(**self.tensor_kwargs)
            net_output_cond = self.net(x=xt_scaled, timesteps=t, **condition.to_dict())
            net_output_uncond = self.net(x=xt_scaled, timesteps=t, **uncondition.to_dict())
            # Classifier-free guidance; see the note after this method for the arithmetic.
            net_output = net_output_cond + guidance * (net_output_cond - net_output_uncond)

            xt = self.scheduler.step(net_output, t, xt).prev_sample
        samples = xt

        # Gather the temporal slices back from all context-parallel ranks.
        if to_cp:
            samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group)

        return samples

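    # Note on the guidance combination above (illustrative arithmetic, not library
    # code): cond + g * (cond - uncond) = (1 + g) * cond - g * uncond, so the
    # default guidance=1.5 corresponds to a scale of 2.5 in the more common
    # "uncond + s * (cond - uncond)" formulation.
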
    def _get_conditions(
        self,
        data_batch: dict,
        is_negative_prompt: bool = False,
    ):
        """Get the conditions for the model.

        Args:
            data_batch: Input data dictionary
            is_negative_prompt: Whether to use negative prompting

        Returns:
            condition: Input conditions
            uncondition: Conditions removed/reduced to minimum (unconditioned)
        """
        if is_negative_prompt:
            condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
        else:
            condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)

        to_cp = self.net.is_context_parallel_enabled

        # Keep conditioning tensors identical across context-parallel ranks.
        if parallel_state.is_initialized():
            condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp)
            uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp)

        return condition, uncondition


def broadcast_condition(condition: BaseVideoCondition, to_tp: bool = True, to_cp: bool = True) -> BaseVideoCondition:
    """Broadcast every tensor field of a condition across tensor- and/or context-parallel groups."""
    condition_kwargs = {}
    for k, v in condition.to_dict().items():
        if isinstance(v, torch.Tensor):
            assert not v.requires_grad, f"{k} requires gradient; the current implementation does not support this"
            condition_kwargs[k] = parallel.broadcast(v, to_tp=to_tp, to_cp=to_cp)
    condition = type(condition)(**condition_kwargs)
    return condition
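
# A minimal end-to-end usage sketch (illustrative only; `config`, `tokenizer_dir`,
# and `data_batch` are hypothetical placeholders for an experiment config, a
# tokenizer checkpoint directory, and a batch from the data loader):
#
#   model = DiffusionT2WModel(config)
#   model.set_up_model()
#   model.set_up_tokenizer(tokenizer_dir)
#   latents = model.generate_samples_from_batch(data_batch, guidance=1.5, num_steps=35)
#   video = model.decode(latents)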