from typing import Optional

import numpy as np
import torch
from einops import rearrange, repeat
from torch import nn
from torch.distributed import ProcessGroup, get_process_group_ranks

from cosmos_predict1.diffusion.module.attention import normalize
from cosmos_predict1.diffusion.module.parallel import split_inputs_cp
from cosmos_predict1.diffusion.module.timm import trunc_normal_

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Compute 1D sine-cosine positional embeddings.

    Args:
        embed_dim: output dimension D for each position; must be even.
        pos: positions to be encoded, of shape (M,).

    Returns:
        (M, D) array whose first D/2 columns are sine terms and last D/2
        columns are cosine terms.
    """
    assert embed_dim % 2 == 0
    # Geometric ladder of frequencies from 1.0 down to 1/10000, as in the
    # original Transformer sinusoidal embedding.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

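# Example (illustrative only): encoding 4 positions into 8 channels yields a
# (4, 8) array, sine terms in the first 4 columns, cosine terms in the last 4.
#
#   emb = get_1d_sincos_pos_embed_from_grid(8, np.arange(4))
#   assert emb.shape == (4, 8)
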
class VideoPositionEmb(nn.Module):
    def __init__(self):
        super().__init__()
        self.cp_group = None

    def enable_context_parallel(self, cp_group: ProcessGroup):
        self.cp_group = cp_group

    def disable_context_parallel(self):
        self.cp_group = None

    def forward(self, x_B_T_H_W_C: torch.Tensor, fps: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate embedding generation to generate_embeddings.

        With context parallelism (CP) enabled, the input is assumed to be
        already split along T, so the full temporal size is recovered before
        generating embeddings, and the result is split again across the CP
        group.
        """
        B_T_H_W_C = x_B_T_H_W_C.shape
        if self.cp_group is not None:
            cp_ranks = get_process_group_ranks(self.cp_group)
            cp_size = len(cp_ranks)
            B, T, H, W, C = B_T_H_W_C
            B_T_H_W_C = (B, T * cp_size, H, W, C)
        embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps)

        if self.cp_group is not None:
            if isinstance(self, VideoRopePosition3DEmb):
                # RoPE embeddings are flattened to ((T H W), 1, 1, D), so the
                # sequence dimension is 0; other embeddings keep (B, T, H, W, C).
                seq_dim = 0
            else:
                seq_dim = 1
            embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group)
        return embeddings

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None):
        raise NotImplementedError

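# Minimal subclass sketch (hypothetical, for illustration only): subclasses
# implement generate_embeddings, and the base class handles the
# context-parallel bookkeeping in forward().
#
#   class ZeroPosEmb(VideoPositionEmb):
#       def generate_embeddings(self, B_T_H_W_C, fps=None):
#           B, T, H, W, C = B_T_H_W_C
#           return torch.zeros(B, T, H, W, C)
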
class VideoRopePosition3DEmb(VideoPositionEmb):
    def __init__(
        self,
        *,
        head_dim: int,
        len_h: int,
        len_w: int,
        len_t: int,
        base_fps: int = 24,
        h_extrapolation_ratio: float = 1.0,
        w_extrapolation_ratio: float = 1.0,
        t_extrapolation_ratio: float = 1.0,
        **kwargs,
    ):
        del kwargs
        super().__init__()
        self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float))
        self.base_fps = base_fps
        self.max_h = len_h
        self.max_w = len_w

        # Split the head dimension across the three axes: H and W each get
        # (head_dim // 6) * 2 channels, and T takes the remainder.
        dim = head_dim
        dim_h = dim // 6 * 2
        dim_w = dim_h
        dim_t = dim - 2 * dim_h
        assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
        self.register_buffer(
            "dim_spatial_range",
            torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().cuda() / dim_h,
            persistent=False,
        )
        self.register_buffer(
            "dim_temporal_range",
            torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().cuda() / dim_t,
            persistent=False,
        )

        # NTK-aware rescaling of the RoPE base: raising the extrapolation
        # ratio to dim / (dim - 2) stretches the longest wavelength by exactly
        # that ratio.
        self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
        self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
        self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))

    def generate_embeddings(
        self,
        B_T_H_W_C: torch.Size,
        fps: Optional[torch.Tensor] = None,
        h_ntk_factor: Optional[float] = None,
        w_ntk_factor: Optional[float] = None,
        t_ntk_factor: Optional[float] = None,
    ):
        """
        Generate embeddings for the given input size.

        Args:
            B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
            fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.
            h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
            w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
            t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.

        Returns:
            torch.Tensor: RoPE angles of shape (T * H * W, 1, 1, head_dim), in float32.
        """
        h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
        w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
        t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor

        h_theta = 10000.0 * h_ntk_factor
        w_theta = 10000.0 * w_ntk_factor
        t_theta = 10000.0 * t_ntk_factor

        h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range)
        w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range)
        temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range)

        B, T, H, W, _ = B_T_H_W_C
        uniform_fps = (fps is None) or (fps.min() == fps.max())
        assert (
            uniform_fps or B == 1 or T == 1
        ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
        assert (
            H <= self.max_h and W <= self.max_w
        ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
        half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs)
        half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs)

        # Temporal positions are normalized by fps / base_fps so that wall-clock
        # time, not frame index, drives the phase.
        if fps is None:  # image case
            assert T == 1, "T should be 1 for image batch."
            half_emb_t = torch.outer(self.seq[:T], temporal_freqs)
        else:
            half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs)

        em_T_H_W_D = torch.cat(
            [
                repeat(half_emb_t, "t d -> t h w d", h=H, w=W),
                repeat(half_emb_h, "h d -> t h w d", t=T, w=W),
                repeat(half_emb_w, "w d -> t h w d", t=T, h=H),
            ]
            * 2,  # duplicate the halves: (sin-half, cos-half) layout expected by RoPE
            dim=-1,
        )

        return rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float()

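# Usage sketch (hypothetical sizes; a CUDA device is assumed because the
# frequency buffers above are created with .cuda()):
#
#   rope = VideoRopePosition3DEmb(head_dim=128, len_h=40, len_w=40, len_t=24).cuda()
#   x = torch.empty(1, 8, 30, 30, 128, device="cuda")
#   emb = rope(x, fps=torch.tensor([24.0], device="cuda"))
#   assert emb.shape == (8 * 30 * 30, 1, 1, 128)
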
class LearnablePosEmbAxis(VideoPositionEmb):
    def __init__(
        self,
        *,
        interpolation: str,
        model_channels: int,
        len_h: int,
        len_w: int,
        len_t: int,
        **kwargs,
    ):
        """
        Args:
            interpolation (str): only "crop" is currently supported. When extrapolation
                capacity is needed, frequency adjustment or other more advanced methods
                would be the right tool; they are not implemented yet.
        """
        del kwargs
        super().__init__()
        self.interpolation = interpolation
        assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"

        self.pos_emb_h = nn.Parameter(torch.zeros(len_h, model_channels))
        self.pos_emb_w = nn.Parameter(torch.zeros(len_w, model_channels))
        self.pos_emb_t = nn.Parameter(torch.zeros(len_t, model_channels))

        trunc_normal_(self.pos_emb_h, std=0.02)
        trunc_normal_(self.pos_emb_w, std=0.02)
        trunc_normal_(self.pos_emb_t, std=0.02)

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T, H, W, _ = B_T_H_W_C
        if self.interpolation == "crop":
            # Crop each axis table to the requested size, then broadcast and sum.
            emb_h_H = self.pos_emb_h[:H]
            emb_w_W = self.pos_emb_w[:W]
            emb_t_T = self.pos_emb_t[:T]
            emb = (
                repeat(emb_t_T, "t d -> b t h w d", b=B, h=H, w=W)
                + repeat(emb_h_H, "h d -> b t h w d", b=B, t=T, w=W)
                + repeat(emb_w_W, "w d -> b t h w d", b=B, t=T, h=H)
            )
            assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
        else:
            raise ValueError(f"Unknown interpolation method {self.interpolation}")

        return normalize(emb, dim=-1, eps=1e-6)

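# Note on the factorized design: one learnable table per axis, summed over the
# grid, costs O(len_t + len_h + len_w) parameters instead of the
# O(len_t * len_h * len_w) of a full 3D table, at the price of not modeling
# joint interactions between axes.
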
class MultiviewVideoPositionEmb(nn.Module):
    def __init__(self):
        super().__init__()
        self.cp_group = None

    def enable_context_parallel(self, cp_group: ProcessGroup):
        self.cp_group = cp_group

    def disable_context_parallel(self):
        self.cp_group = None

    def forward(self, x_B_T_H_W_C: torch.Tensor, fps: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate embedding generation to generate_embeddings.

        With context parallelism (CP), the input tensor is assumed to be
        already split along T; embeddings are generated at full size and then
        split across the CP group, view by view, to match.
        """
        B_T_H_W_C = x_B_T_H_W_C.shape
        if self.cp_group is not None:
            cp_ranks = get_process_group_ranks(self.cp_group)
            cp_size = len(cp_ranks)
            B, T, H, W, C = B_T_H_W_C
            B_T_H_W_C = (B, T * cp_size, H, W, C)
        embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps)

        if self.cp_group is not None:
            if isinstance(self, MultiviewVideoRopePosition3DEmb):
                # Split each view's flattened (T H W) sequence across the CP
                # group, then merge the view axis back into the sequence.
                seq_dim = 1
                embeddings = rearrange(embeddings, "(V T) H W D -> V (T H W) 1 1 D", V=self.n_views).float()
                embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group)
                embeddings = rearrange(embeddings, "V T 1 1 D -> (V T) 1 1 D", V=self.n_views).float()
            else:
                seq_dim = 1
                embeddings = rearrange(embeddings, "B (V T) H W C -> (B V) T H W C", V=self.n_views)
                embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group)
                embeddings = rearrange(embeddings, "(B V) T H W C -> B (V T) H W C", V=self.n_views)
        else:
            if isinstance(self, MultiviewVideoRopePosition3DEmb):
                embeddings = rearrange(embeddings, "t h w d -> (t h w) 1 1 d").float()

        return embeddings

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None):
        raise NotImplementedError

class MultiviewVideoRopePosition3DEmb(MultiviewVideoPositionEmb):
    def __init__(
        self,
        *,
        head_dim: int,
        len_h: int,
        len_w: int,
        len_t: int,
        base_fps: int = 24,
        h_extrapolation_ratio: float = 1.0,
        w_extrapolation_ratio: float = 1.0,
        t_extrapolation_ratio: float = 1.0,
        n_views: int = 4,
        **kwargs,
    ):
        del kwargs
        super().__init__()
        self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float))
        self.base_fps = base_fps
        self.max_h = len_h
        self.max_w = len_w
        self.n_views = n_views

        # Same head-dim split and NTK rescaling as VideoRopePosition3DEmb.
        dim = head_dim
        dim_h = dim // 6 * 2
        dim_w = dim_h
        dim_t = dim - 2 * dim_h
        assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
        self.register_buffer(
            "dim_spatial_range",
            torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().cuda() / dim_h,
            persistent=False,
        )
        self.register_buffer(
            "dim_temporal_range",
            torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().cuda() / dim_t,
            persistent=False,
        )

        self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
        self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
        self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))

    def generate_embedding_for_batch(
        self,
        B_T_H_W_C: torch.Size,
        fps: Optional[torch.Tensor] = None,
        h_ntk_factor: Optional[float] = None,
        w_ntk_factor: Optional[float] = None,
        t_ntk_factor: Optional[float] = None,
    ):
        """
        Generate embeddings for a single view of the given size.

        Args:
            B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
            fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.
            h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. Defaults to None.
            w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. Defaults to None.
            t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. Defaults to None.

        Returns:
            torch.Tensor: RoPE angles of shape (T, H, W, head_dim).
        """
        h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
        w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
        t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor

        h_theta = 10000.0 * h_ntk_factor
        w_theta = 10000.0 * w_ntk_factor
        t_theta = 10000.0 * t_ntk_factor

        h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range)
        w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range)
        temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range)

        B, T, H, W, _ = B_T_H_W_C
        uniform_fps = (fps is None) or (fps.min() == fps.max())
        assert uniform_fps, "Multiview RoPE requires uniform fps across the batch (or fps=None for an image batch)."
        assert (
            H <= self.max_h and W <= self.max_w
        ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w}) configured for positional embedding. Please adjust the input size or increase the maximum dimensions in the model configuration."
        half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs)
        half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs)

        if fps is None:  # image case
            assert T == 1, "T should be 1 for image batch."
            half_emb_t = torch.outer(self.seq[:T], temporal_freqs)
        else:
            half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs)

        em_T_H_W_D = torch.cat(
            [
                repeat(half_emb_t, "t d -> t h w d", h=H, w=W),
                repeat(half_emb_h, "h d -> t h w d", t=T, w=W),
                repeat(half_emb_w, "w d -> t h w d", t=T, h=H),
            ]
            * 2,  # (sin-half, cos-half) duplication expected by RoPE
            dim=-1,
        )

        return em_T_H_W_D

    def generate_embeddings(
        self,
        B_T_H_W_C: torch.Size,
        fps: Optional[torch.Tensor] = None,
        h_ntk_factor: Optional[float] = None,
        w_ntk_factor: Optional[float] = None,
        t_ntk_factor: Optional[float] = None,
    ):
        """
        Generate embeddings for the given input size. The camera view dimension is merged into the T dimension.

        Args:
            B_T_H_W_C (torch.Size): Input tensor size (Batch, Time * Views, Height, Width, Channels).
            fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.
            h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. Defaults to None.
            w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. Defaults to None.
            t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. Defaults to None.

        Returns:
            torch.Tensor: RoPE angles of shape (n_views * T, H, W, head_dim), one identical
            single-view block per view, concatenated along the time axis.
        """
        B, T, H, W, C = B_T_H_W_C

        single_view_B_T_H_W_C = (B, T // self.n_views, H, W, C)
        em_T_H_W_D = torch.cat(
            [
                self.generate_embedding_for_batch(
                    single_view_B_T_H_W_C,
                    fps=fps,
                    h_ntk_factor=h_ntk_factor,
                    w_ntk_factor=w_ntk_factor,
                    t_ntk_factor=t_ntk_factor,
                )
                for _ in range(self.n_views)
            ],
            dim=0,
        )
        return em_T_H_W_D

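# Note: every view receives an identical copy of the single-view RoPE grid, so
# the rotary phases alone do not distinguish views; views must be told apart by
# other mechanisms in the network.
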
class MultiviewSinCosPosEmbAxis(MultiviewVideoPositionEmb):
    def __init__(
        self,
        *,
        interpolation: str,
        model_channels: int,
        len_h: int,
        len_w: int,
        len_t: int,
        h_extrapolation_ratio: float = 1.0,
        w_extrapolation_ratio: float = 1.0,
        t_extrapolation_ratio: float = 1.0,
        n_views: int = 4,
        **kwargs,
    ):
        """
        Args:
            interpolation (str): only "crop" is currently supported. When extrapolation
                capacity is needed, frequency adjustment or other more advanced methods
                would be the right tool; they are not implemented yet.
        """
        del kwargs
        super().__init__()
        self.n_views = n_views
        self.interpolation = interpolation
        assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"

        # Same per-axis channel split as the RoPE variants.
        dim = model_channels
        dim_h = dim // 6 * 2
        dim_w = dim_h
        dim_t = dim - 2 * dim_h
        assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"

        # Extrapolation is handled by slowing down the positions fed to the
        # sin-cos tables by the corresponding extrapolation ratio.
        emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, pos=np.arange(len_h) * 1.0 / h_extrapolation_ratio)
        emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, pos=np.arange(len_w) * 1.0 / w_extrapolation_ratio)
        emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, pos=np.arange(len_t) * 1.0 / t_extrapolation_ratio)

        self.register_buffer("pos_emb_h", torch.from_numpy(emb_h).float(), persistent=False)
        self.register_buffer("pos_emb_w", torch.from_numpy(emb_w).float(), persistent=False)
        self.register_buffer("pos_emb_t", torch.from_numpy(emb_t).float(), persistent=False)

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T, H, W, C = B_T_H_W_C

        # T carries all views stacked along the time axis.
        single_view_T = T // self.n_views

        if self.interpolation == "crop":
            emb_h_H = self.pos_emb_h[:H]
            emb_w_W = self.pos_emb_w[:W]
            emb_t_T = self.pos_emb_t[:single_view_T]
            # Per view: concatenate the cropped axis embeddings along the channel
            # dimension, then stack the identical per-view blocks along T.
            emb = torch.cat(
                [
                    torch.cat(
                        [
                            repeat(emb_t_T, "t d -> b t h w d", b=B, h=H, w=W),
                            repeat(emb_h_H, "h d -> b t h w d", b=B, t=single_view_T, w=W),
                            repeat(emb_w_W, "w d -> b t h w d", b=B, t=single_view_T, h=H),
                        ],
                        dim=-1,
                    )
                    for _ in range(self.n_views)
                ],
                1,
            )
            assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
            return emb

        raise ValueError(f"Unknown interpolation method {self.interpolation}")

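# Usage sketch (hypothetical sizes): with n_views=4, the T axis of the input is
# read as 4 stacked views of T // 4 frames each.
#
#   pe = MultiviewSinCosPosEmbAxis(
#       interpolation="crop", model_channels=1024, len_h=40, len_w=40, len_t=24, n_views=4
#   )
#   emb = pe(torch.empty(1, 4 * 6, 30, 30, 1024))  # -> (1, 24, 30, 30, 1024)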