| | """Misc functions and modules for Cosmos-Embed1.""" |
| |
|
| | import functools |
| | from logging import getLogger |
| | from typing import Callable, Optional, Protocol |
| |
|
| | import torch |
| | import torch.distributed as dist |
| | import torch.nn as nn |
| |
|
| | logger = getLogger(__file__) |
| |
|
| |
|


def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
    """Get the rank (GPU device) of the worker.

    Returns:
        rank (int): The rank of the worker.
    """
    rank = 0
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank(group)
    return rank


def barrier() -> None:
    """Barrier for all GPUs."""
    if dist.is_available() and dist.is_initialized():
        dist.barrier()


def rank0_first(func: Callable) -> Callable:
    """Run the function on rank 0 first, then on other ranks."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if get_rank() == 0:
            result = func(*args, **kwargs)
        barrier()
        if get_rank() != 0:
            result = func(*args, **kwargs)
        return result

    return wrapper
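
# Usage sketch (illustrative only; `download_checkpoint` is a hypothetical helper,
# not part of this module): decorating a download step so rank 0 populates a shared
# cache first, and the remaining ranks only run it after the barrier releases them.
#
#     @rank0_first
#     def download_checkpoint(cache_dir: str) -> str:
#         ...  # fetch weights into cache_dir and return the local path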


def add_docstring(docstring: str):
    def decorator(func):
        func.__doc__ = docstring
        return func

    return decorator


INIT_DOCSTRING = """
Constructor for encoding module.

Args:
    embed_dim: size of embedding vectors, e.g. x.shape[3].
    max_len: maximum length of temporal sequence, e.g. x.shape[1].
"""

FORWARD_DOCSTRING = """
Forward function.

Args:
    x (`torch.Tensor`): rank 4 tensor to add spatio-temporal encodings to.

Returns:
    `torch.Tensor` of rank 4.
"""


class EncodingProtocol(Protocol):
    def __init__(self, embed_dim: int, max_len: int) -> None:
        pass

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass


def interpolate_temp_pos_embed(temp_embed: torch.Tensor, num_frames: int) -> torch.Tensor:
    """Linearly interpolates the temporal encoding from `temp_embed.shape[0]` to `num_frames`."""
    temp_embed_resized = temp_embed.permute(1, 0).unsqueeze(0)
    temp_embed_resized = nn.functional.interpolate(
        temp_embed_resized,
        size=num_frames,
        mode="linear",
        align_corners=False,
    )
    return temp_embed_resized.squeeze(0).permute(1, 0)
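
# Shape sketch (values illustrative): a learned table of 8 temporal embeddings of
# width 4 can be resized to 16 frames at inference time without retraining.
#
#     temp_embed = torch.zeros(8, 4)                         # (max_len, embed_dim)
#     resized = interpolate_temp_pos_embed(temp_embed, 16)   # -> (16, 4)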


class TemporalParameterEncoding(nn.Module, EncodingProtocol):
    @add_docstring(INIT_DOCSTRING)
    def __init__(self, embed_dim: int, max_len: int) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.max_len = max_len
        self.temp_embed = nn.Parameter(torch.zeros(self.max_len, self.embed_dim))
        nn.init.trunc_normal_(self.temp_embed, std=0.02)

    @add_docstring(FORWARD_DOCSTRING)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, t, _, _ = x.shape
        if t != self.temp_embed.shape[0]:
            logger.debug(f"Interpolating temporal encodings from {self.temp_embed.shape[0]} to {t}.")
            temp_embed = interpolate_temp_pos_embed(self.temp_embed, t)
        else:
            temp_embed = self.temp_embed
        temp_embed = temp_embed.unsqueeze(0).unsqueeze(2)
        return x + temp_embed
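
# Usage sketch (shapes illustrative): one learned vector per frame is broadcast over
# the batch and per-frame token dimensions of a rank-4 input.
#
#     enc = TemporalParameterEncoding(embed_dim=64, max_len=8)
#     x = torch.randn(2, 8, 32, 64)   # (batch, frames, tokens, embed_dim)
#     y = enc(x)                      # same shape; temp_embed broadcast as (1, 8, 1, 64)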


def create_neighbor_weight_matrix(num_tokens: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    indices = torch.arange(num_tokens, dtype=dtype, device=device)
    abs_diff = torch.abs(indices.unsqueeze(0) - indices.unsqueeze(1))
    weights = 1.0 / (2.0**abs_diff)
    return weights
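
# Worked example: each entry is 1 / 2**|i - j|, so a token contributes with weight 1
# to itself and with geometrically decaying weight to its neighbors. For
# num_tokens=3 the matrix is:
#
#     [[1.00, 0.50, 0.25],
#      [0.50, 1.00, 0.50],
#      [0.25, 0.50, 1.00]]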


def compute_t_adj(x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # Contract the token axis of (b, f, n, d) with the (n, k) neighbor weights -> (b, f, k, d).
    return torch.einsum("bfnd,nk->bfkd", x, weights)


def token_propagation(x: torch.Tensor, num_tokens: int) -> torch.Tensor:
    """Apply neighboring token propagation update."""
    weights = create_neighbor_weight_matrix(num_tokens, x.device, x.dtype)
    t_adj = compute_t_adj(x, weights)
    return x + t_adj - t_adj.detach()
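
# Note on the straight-through-style update `x + t_adj - t_adj.detach()`: the two
# t_adj terms cancel in the forward pass, so the returned values equal `x`, but the
# backward pass keeps the non-detached path, so gradients also flow through the
# neighbor-weighted aggregation above.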


class NeighboringTokenPropagationEncoding(TemporalParameterEncoding):
    """
    Neighboring Token Propagation method inspired by Momentor (https://arxiv.org/abs/2402.11435).
    """

    @add_docstring(FORWARD_DOCSTRING)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, t, q, _ = x.shape
        if t != self.temp_embed.shape[0]:
            logger.debug(f"Interpolating temporal encodings from {self.temp_embed.shape[0]} to {t}.")
            temp_embed = interpolate_temp_pos_embed(self.temp_embed, t)
        else:
            temp_embed = self.temp_embed
        temp_embed = temp_embed.unsqueeze(0).unsqueeze(2)

        if self.training:
            temp_embed = token_propagation(temp_embed, q)
        return x + temp_embed
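
# Behavior note: in eval mode this class reduces to TemporalParameterEncoding; the
# training-only propagation step changes how gradients reach `temp_embed` while the
# added encoding values themselves are numerically unchanged.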


class EncodingFactory(nn.Module):
    def __init__(self, encoding_type: str, embed_dim: int, max_len: int) -> None:
        super().__init__()
        fn = {
            "temporal_parameter": TemporalParameterEncoding,
            "neighboring_token_propagation": NeighboringTokenPropagationEncoding,
        }[encoding_type]
        self.encoding = fn(embed_dim=embed_dim, max_len=max_len)

    @add_docstring(FORWARD_DOCSTRING)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoding(x)
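
# Usage sketch (keys taken from the mapping above; shapes illustrative):
#
#     factory = EncodingFactory("neighboring_token_propagation", embed_dim=64, max_len=8)
#     y = factory(torch.randn(2, 8, 32, 64))  # adds temporal encodings; output shape unchanged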