# Lyra: src/models/utils/model.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import einops

from src.models.utils.attention import Block
from src.models.utils.mamba2 import Mamba2Block
from src.models.utils.cosmos_1_tokenizer import load_cosmos_1_tokenizer
def load_vae(vae_backbone, vae_path):
    if vae_backbone == 'cosmos1':
        vae = load_cosmos_1_tokenizer(vae_path, load_decoder=True, load_jit=True)
    else:
        raise ValueError(f"Unsupported VAE backbone: {vae_backbone}")
    return vae
def encode_cosmos1(vae, video):
    # cosmos1 encode returns a tuple; the first element is the latent sample.
    sample = vae.encode(video)[0]
    return sample
def encode_video_model(vae, video, vae_backbone):
    if vae_backbone == 'cosmos1':
        encode_func = encode_cosmos1
    else:
        raise ValueError(f"Unsupported VAE backbone: {vae_backbone}")
    return encode_func(vae, video)
def encode_video(vae, video, vae_backbone):
    # video: [B, F, C, H, W]; encoded in chunks along the batch dim to bound memory.
    chunk_size = get_encoder_chunk_size(video)
    with torch.no_grad():
        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
samples = []
for chunk_idx in range(0, video.shape[0], chunk_size):
video_batch = video[chunk_idx: chunk_idx + chunk_size]
sample = encode_video_model(vae, video_batch, vae_backbone)
samples.append(sample)
samples = torch.cat(samples, 0)
samples = samples.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
return samples
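# Illustrative usage of the encode pipeline (the path and shapes below are
# assumptions for the example, not part of this file):
#   vae = load_vae('cosmos1', '/path/to/cosmos1_tokenizer')
#   video = torch.randn(2, 49, 3, 256, 256)        # [B, F, C, H, W], values in [-1, 1]
#   latents = encode_video(vae, video, 'cosmos1')  # [B, F', C_lat, H', W'], where the
#                                                  # primed sizes follow the tokenizer's
#                                                  # compression factors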
def get_encoder_chunk_size(video, encoder=True):
    # Per-(num-frames, height) batch chunk sizes tuned to fit the VAE in memory.
    # Decoder chunk sizes currently mirror the encoder ones, so the `encoder`
    # flag is kept only as a hook for diverging them later.
    chunk_sizes = {49: {480: 4, 256: 10, 128: 20}, 121: {704: 1}}
    B, T, C, H, W = video.shape
    chunk_size = B
    if T in chunk_sizes and H in chunk_sizes[T]:
        chunk_size = chunk_sizes[T][H]
    return chunk_size
def encode_multi_view_video(vae, video, num_input_multi_views, vae_backbone):
    # Fold the view dimension into the batch so each view is encoded independently.
    if num_input_multi_views != 1:
        video = einops.rearrange(video, 'b (v t) c h w -> (b v) t c h w', v=num_input_multi_views)
model_input = encode_video(vae, video, vae_backbone)
if num_input_multi_views != 1:
model_input = einops.rearrange(model_input, '(b v) t c h w -> b (v t) c h w', v=num_input_multi_views)
return model_input
def decode_multi_view_latents(vae, latents: torch.Tensor, num_input_multi_views: int, vae_backbone: str):
    # Fold the view dimension into the batch, then decode in memory-bounded chunks.
    if num_input_multi_views != 1:
        latents = einops.rearrange(latents, 'b (v t) c h w -> (b v) t c h w', v=num_input_multi_views)
chunk_size = get_encoder_chunk_size(latents, encoder=False)
video = []
for chunk_idx in range(0, latents.shape[0], chunk_size):
latents_batch = latents[chunk_idx: chunk_idx + chunk_size]
video_batch = decode_video_model(vae, latents_batch, vae_backbone)
video.append(video_batch)
video = torch.cat(video, 0)
    if num_input_multi_views != 1:
        video = einops.rearrange(video, '(b v) c t h w -> b c (v t) h w', v=num_input_multi_views)
    video = video.transpose(1, 2)  # [B, (V T), C, H, W]
    video = video.clip(-1, 1)
    return video
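# Illustrative multi-view round trip (values are assumptions for the example):
#   latents = encode_multi_view_video(vae, video, num_input_multi_views=2, vae_backbone='cosmos1')
#   recon = decode_multi_view_latents(vae, latents, num_input_multi_views=2, vae_backbone='cosmos1')
#   # recon: [B, (V T), C, H, W], clipped to [-1, 1]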
def decode_cosmos1(vae, latents):
    latents = einops.rearrange(latents, 'b t c h w -> b c t h w')
    sample = vae.decode(latents)
    return sample
def decode_video_model(vae, latents, vae_backbone):
    if vae_backbone == 'cosmos1':
        decode_func = decode_cosmos1
    else:
        raise ValueError(f"Unsupported VAE backbone: {vae_backbone}")
    return decode_func(vae, latents)
def encode_plucker_vae(batch, encode_video, plucker_key='plucker_embedding'):
    # The VAE expects 3-channel input, so encode the 6-channel Plucker embedding
    # as two 3-channel halves and concatenate the latents along the channel dim.
    batch[plucker_key] = torch.cat((
        encode_video(batch[plucker_key][:, :, :3]),
        encode_video(batch[plucker_key][:, :, 3:])),
        2)
    return batch
def encode_latent_time_vae(batch, encode_video, img_size, time_keys=('time_embeddings', 'time_embeddings_target')):
    # Broadcast per-frame time embeddings over the image plane, map them from
    # [0, 1] to [-1, 1], encode with the VAE, and stack source/target latents
    # along the frame dimension under the first key.
    time_embeddings_out = []
    for k in time_keys:
        batch[k] = repeat_time_spatially(batch[k], img_size)
        time_embeddings_out.append(encode_video(batch[k] * 2 - 1))
    batch[time_keys[0]] = torch.cat(time_embeddings_out, 1)
    del batch[time_keys[1]]
    return batch
def repeat_time_spatially(time_embeddings: torch.Tensor, img_size: Tuple[int, int]):
return einops.repeat(time_embeddings, 'b t c -> b t c h w', h=img_size[0], w=img_size[1])
def get_model_blocks(enc_embed_dim, enc_depth, enc_num_heads, mlp_ratio, use_mamba, llrm_7m1t, norm_layer, use_qk_norm, index_transformer_block=8):
    if use_mamba:
        if llrm_7m1t:
            # Mix of Mamba2 and transformer blocks: every
            # index_transformer_block-th block is a transformer Block,
            # the rest are Mamba2 blocks.
            enc_blocks = []
            for i in range(enc_depth):
                if (i + 1) % index_transformer_block == 0:
                    block = Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, use_qk_norm=use_qk_norm)
                else:
                    block = Mamba2Block(enc_embed_dim)
                enc_blocks.append(block)
            enc_blocks = nn.ModuleList(enc_blocks)
        else:
            # Mamba2 only
            enc_blocks = nn.ModuleList([
                Mamba2Block(enc_embed_dim)
                for _ in range(enc_depth)])
    else:
        # Transformer only
        enc_blocks = nn.ModuleList(
            [Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, use_qk_norm=use_qk_norm)
             for _ in range(enc_depth)]
        )
    return enc_blocks
def forward_checkpointing(layer, *args, gradient_checkpoint=False):
    if not gradient_checkpoint:
        return layer(*args)
    # torch.utils.checkpoint tracks gradients only through tensor inputs, so
    # split args into tensors (passed through checkpoint) and non-tensors
    # (captured by the closure below).
    tensor_positions = [(i, arg) for i, arg in enumerate(args) if isinstance(arg, torch.Tensor)]
    tensor_indices, tensor_args = zip(*tensor_positions) if tensor_positions else ([], [])
    def wrapped(*tensors_in):
        args_copy = list(args)
        for i, t in zip(tensor_indices, tensors_in):
            args_copy[i] = t
        return layer(*args_copy)
    return torch.utils.checkpoint.checkpoint(wrapped, *tensor_args, use_reentrant=False)
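# Illustrative usage (`block` and `x` are placeholders, not defined here):
#   out = forward_checkpointing(block, x, gradient_checkpoint=True)
# Non-tensor args pass through unchanged; only tensors are recomputed on backward.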
def timestep_embedding(timesteps, dim, max_period=10000, use_orig=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param use_orig: if True, prepend the raw timesteps as the first column
                     (the output still has `dim` columns).
    :return: an [N x dim] Tensor of positional embeddings.
    """
if use_orig:
dim -= 1
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=timesteps.device)
args = timesteps[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
if use_orig:
embedding = torch.cat([timesteps[:, None], embedding], dim=-1)
return embedding
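# Illustrative usage:
#   t = torch.tensor([0.0, 10.0, 500.0])
#   emb = timestep_embedding(t, dim=256)  # -> [3, 256]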
class ConvTranspose3dFactorized(nn.Module):
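    """Memory-friendlier stand-in for nn.ConvTranspose3d: optional 1x1x1 channel
    reduction, then trilinear upsampling by `stride`, then a regular Conv3d."""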
def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
upsample_mode='trilinear', use_channel_reduction=True, gradient_checkpoint=True):
super().__init__()
self.scale_factor = stride
self.upsample_mode = upsample_mode
self.use_channel_reduction = use_channel_reduction
self.gradient_checkpoint = gradient_checkpoint
if self.use_channel_reduction:
self.channel_reducer = nn.Conv3d(in_channels, out_channels, kernel_size=1)
conv_in_channels = out_channels
else:
conv_in_channels = in_channels
self.conv = nn.Conv3d(
conv_in_channels,
out_channels,
kernel_size=kernel_size,
padding=padding
)
def _interpolate(self, x):
return F.interpolate(
x,
scale_factor=self.scale_factor,
mode=self.upsample_mode,
align_corners=False if self.upsample_mode in ['linear', 'bilinear', 'trilinear'] else None
)
def _conv(self, x):
return self.conv(x)
def _channel_reduce(self, x):
return self.channel_reducer(x)
def forward(self, x):
if self.use_channel_reduction:
x = forward_checkpointing(self._channel_reduce, x, gradient_checkpoint=self.gradient_checkpoint)
x = forward_checkpointing(self._interpolate, x, gradient_checkpoint=self.gradient_checkpoint)
x = forward_checkpointing(self._conv, x, gradient_checkpoint=self.gradient_checkpoint)
return x
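# Illustrative usage (channel/stride values are assumptions for the example):
#   up = ConvTranspose3dFactorized(512, 256, kernel_size=3, stride=(1, 2, 2), padding=1)
#   y = up(torch.randn(1, 512, 8, 16, 16))  # -> [1, 256, 8, 32, 32]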
class ConvTranspose3dReduced(nn.Module):
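    """Variant that keeps a true nn.ConvTranspose3d for upsampling and only
    factors out an optional 1x1x1 channel reduction in front of it."""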
def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
hidden_channels: int = None, use_channel_reduction: bool = True):
super().__init__()
if hidden_channels is None:
hidden_channels = out_channels
self.use_channel_reduction = use_channel_reduction
if self.use_channel_reduction:
self.channel_reducer = nn.Conv3d(in_channels, hidden_channels, kernel_size=1)
conv_in_channels = hidden_channels
else:
conv_in_channels = in_channels
self.conv = nn.ConvTranspose3d(
conv_in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
)
def _conv(self, x):
return self.conv(x)
def _channel_reduce(self, x):
return self.channel_reducer(x)
def forward(self, x: torch.Tensor, gradient_checkpoint: bool = False):
if self.use_channel_reduction:
x = forward_checkpointing(self._channel_reduce, x, gradient_checkpoint=gradient_checkpoint)
x = forward_checkpointing(self._conv, x, gradient_checkpoint=gradient_checkpoint)
return x
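# Illustrative usage (values are assumptions for the example):
#   up = ConvTranspose3dReduced(512, 128, kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=(1, 1, 1))
#   y = up(torch.randn(1, 512, 8, 16, 16), gradient_checkpoint=False)  # -> [1, 128, 8, 32, 32]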
class MultiStageConvTranspose3d(nn.Module):
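    """Decomposes one strided nn.ConvTranspose3d into a cascade of x2 temporal
    stages followed by x2 spatial stages (when multi_stage=True), pre-padding
    the input so the final spatial size matches the single-stage target."""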
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
multi_stage=True,
norm_layer=None,
activation=nn.ReLU(inplace=True),
gradient_checkpoint=False,
):
super().__init__()
self.multi_stage = multi_stage
self.gradient_checkpoint = gradient_checkpoint
if not multi_stage:
self.upsampler = nn.ConvTranspose3d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
)
self.pre_pad = None
else:
self.target_kernel_size = kernel_size
self.target_stride = stride
self.target_padding = padding
            sD, sH, sW = stride
            assert all(s == sH for s in stride[1:]), "Only symmetric spatial strides supported."
            # Strides are assumed to be powers of two; each stage upsamples by x2.
            temporal_stages = int(math.log2(sD)) if sD > 1 else 0
            spatial_stages = int(math.log2(sH))
self.temporal_blocks = nn.ModuleList()
for i in range(temporal_stages):
in_ch = in_channels if i == 0 else out_channels
self.temporal_blocks.append(nn.Sequential(
nn.ConvTranspose3d(
in_ch,
out_channels,
kernel_size=(3, 1, 1),
stride=(2, 1, 1),
padding=(1, 0, 0),
output_padding=(1, 0, 0),
),
                    *([norm_layer(out_channels)] if norm_layer else []),
                    *([activation] if activation else [])
))
self.spatial_blocks = nn.ModuleList()
for i in range(spatial_stages):
in_ch = in_channels if (i == 0 and temporal_stages == 0) else out_channels
self.spatial_blocks.append(nn.Sequential(
nn.ConvTranspose3d(
in_ch,
out_channels,
kernel_size=(1, 3, 3),
stride=(1, 2, 2),
padding=(0, 1, 1),
output_padding=(0, 1, 1),
),
                    *([norm_layer(out_channels)] if norm_layer else []),
                    *([activation] if activation else [])
))
self.pre_pad = None
    def _compute_input_padding(self, input_shape):
        D, H, W = input_shape[-3:]
        sD, sH, sW = self.target_stride
        kD, kH, kW = self.target_kernel_size
        pD, pH, pW = self.target_padding
        # Spatial output of the equivalent single-stage ConvTranspose3d.
        spatial_out_H = (H - 1) * sH - 2 * pH + kH
        spatial_out_W = (W - 1) * sW - 2 * pW + kW
        # Output of the cascade of x2 stages; pad the input so the two match.
        # (No temporal padding is needed: each temporal stage doubles D exactly.)
        approx_H = H * (2 ** len(self.spatial_blocks))
        approx_W = W * (2 ** len(self.spatial_blocks))
        pad_H = spatial_out_H - approx_H
        pad_W = spatial_out_W - approx_W
        pad = [0, pad_W, 0, pad_H, 0, 0]  # (W_left, W_right, H_top, H_bottom, D_front, D_back)
        return nn.ConstantPad3d(pad, 0.0)
    def forward(self, x):
        if not self.multi_stage:
            return forward_checkpointing(self.upsampler, x, gradient_checkpoint=self.gradient_checkpoint)
        if self.pre_pad is None:
            # Lazily build the padding from the first input shape; it is cached,
            # so later inputs are assumed to share that spatial size.
            self.pre_pad = self._compute_input_padding(x.shape)
        x = self.pre_pad(x)
        # Temporal upsampling stages (each doubles the frame dimension)
        for block in self.temporal_blocks:
            x = forward_checkpointing(block, x, gradient_checkpoint=self.gradient_checkpoint)
        # Spatial upsampling stages (each doubles H and W)
        for block in self.spatial_blocks:
            x = forward_checkpointing(block, x, gradient_checkpoint=self.gradient_checkpoint)
        return x
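# Illustrative usage (values are assumptions): a 4x temporal / 4x spatial upsampler.
#   up = MultiStageConvTranspose3d(256, 64, kernel_size=(4, 4, 4), stride=(4, 4, 4), padding=(0, 0, 0))
#   y = up(torch.randn(1, 256, 4, 8, 8))  # -> [1, 64, 16, 32, 32]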
class PositionalEmbedding(nn.Module):
def __init__(self, dim: int, max_seq_length: int):
super().__init__()
self.pos_embed = nn.Embedding(max_seq_length, dim)
def forward(self, x):
# x: (batch_size, seq_len, dim)
seq_len = x.size(1)
positions = torch.arange(seq_len, device=x.device).unsqueeze(0) # (1, seq_len)
pos_embedding = self.pos_embed(positions) # (1, seq_len, dim)
return x + pos_embedding
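# Illustrative usage:
#   pe = PositionalEmbedding(dim=512, max_seq_length=1024)
#   y = pe(torch.randn(2, 100, 512))  # adds a learned embedding per position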