| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Cosmos-Embed1 text+video embedder.""" |
| |
|
| | import math |
| | from copy import deepcopy |
| |
|
| | import torch |
| | from einops import rearrange |
| | from torch import nn |
| | from torch.nn import functional as F |
| | from transformers import AutoModel, PreTrainedModel |
| |
|
| | from .configuration_embed1 import CosmosEmbed1Config |
| | from .modeling_outputs import TextEmbedderOutput, TextVideoEmbedderOutput, VideoEmbedderOutput |
| | from .modeling_qformer import BertLMHeadModel, load_qformer |
| | from .modeling_utils import EncodingFactory, rank0_first |
| | from .modeling_vit import EvaViTG |
| |
|
| |
|
class CosmosEmbed1(PreTrainedModel):
    """Cosmos-Embed1 text+video embedder.

    Video frames are encoded independently by the `EvaViTG` visual encoder and
    then given a temporal encoding. They are summarized by a QFormer that
    receives the visual features as `encoder_hidden_states` together with a set
    of learned query tokens. Text is encoded by the same QFormer's BERT module.
    Both modalities are projected to `config.embed_dim` and L2-normalized.
    """

    config_class = CosmosEmbed1Config

    def __init__(self, config: CosmosEmbed1Config) -> None:
        """Cosmos-Embed1 video embedder constructor.

        Args:
            config (CosmosEmbed1Config): Model configuration.
        """
        super().__init__(config)

        self.embed_dim = config.embed_dim
        self.num_query_tokens = config.num_query_tokens
        self.num_video_frames = config.num_video_frames
        self.temporal_encoding_type = config.temporal_encoding_type
        self.resolution = config.resolution
        self.vocab_size = config.vocab_size
        self.transformer_engine = config.transformer_engine
        self.use_fp8 = config.use_fp8

        # Per-channel RGB normalization statistics (the values match the common
        # ImageNet mean/std). They are shaped (1, 1, 3, 1, 1) to broadcast over
        # (batch, time, channel, height, width). They are non-persistent, so they
        # are not written to checkpoints.
        self.register_buffer(
            "normalization_mean",
            torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "normalization_std",
            torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1),
            persistent=False,
        )

        # Per-frame visual encoder, followed by a LayerNorm over its features.
        self.visual_encoder = EvaViTG(
            img_size=self.resolution,
            transformer_engine=self.transformer_engine,
            use_fp8=self.use_fp8,
        )
        self.ln_vision = nn.LayerNorm(self.visual_encoder.embed_dim)

        # QFormer and its learned query tokens.
        self.qformer, self.query_tokens = self._init_qformer(
            num_query_tokens=self.num_query_tokens,
            encoder_width=self.visual_encoder.embed_dim,
            vocab_size=self.vocab_size,
        )

        # Initialize every QFormer parameter whose name contains "_query" from the
        # parameter with the same name minus "_query".
        state_dict = self.qformer.state_dict()
        for name, param in self.qformer.named_parameters():
            if "_query" in name:
                key_orig = name.replace("_query", "")
                param.data.copy_(state_dict[key_orig])

        # Temporal encoding applied to (batch, time, tokens, dim) visual features.
        self.temporal_encoding = EncodingFactory(
            self.temporal_encoding_type,
            embed_dim=self.visual_encoder.embed_dim,
            max_len=self.num_video_frames,
        )

        # Output heads on top of the QFormer hidden size. itm_proj has 2 outputs;
        # presumably a video-text matching classifier (not used in this file).
        self.vision_proj = nn.Linear(self.qformer.config.hidden_size, self.embed_dim)
        self.text_proj = nn.Linear(self.qformer.config.hidden_size, self.embed_dim)
        self.itm_proj = nn.Linear(self.qformer.config.hidden_size, 2)

        # Learnable logit scale (initialized to log(10)) and bias (initialized to -10),
        # presumably for a contrastive similarity loss defined elsewhere.
        self.logit_scale = nn.Parameter(torch.tensor(math.log(10.0)))
        self.logit_bias = nn.Parameter(torch.tensor(-10.0))

    @property
    def hidden_dim(self) -> int:
        """Feature width of the visual encoder (and of `ln_vision`)."""
        return self.visual_encoder.embed_dim

    @torch.jit.ignore
    def no_weight_decay(self) -> set:
        """Names of parameters that should be excluded from weight decay."""
        return {"logit_scale", "logit_bias"}

    def forward(
        self,
        videos: torch.FloatTensor,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor,
    ) -> TextVideoEmbedderOutput:
        """Forward function for `CosmosEmbed1`.

        Args:
            videos (`torch.Tensor` of shape `(batch_size, num_frames, RGB, height, width)`):
                batched videos with fixed number of RGB frames.
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
                Indices can be obtained by using [`AutoTokenizer`, `CosmosEmbed1Tokenizer`].
            attention_mask: (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                Mask to avoid performing attention on padding token indices.
                Mask values select in `[0, 1]`.
                - 1 for tokens that are **not masked**.
                - 0 for tokens that are **masked**.

        Returns:
            TextVideoEmbedderOutput: the fields of `get_video_embeddings` and
            `get_text_embeddings` merged into one output.
        """
        video_output = self.get_video_embeddings(videos)
        text_output = self.get_text_embeddings(input_ids, attention_mask)
        return TextVideoEmbedderOutput(**video_output, **text_output)

    def get_video_embeddings(self, videos: torch.Tensor) -> VideoEmbedderOutput:
        """Embed a batch of videos.

        Args:
            videos (`torch.Tensor` of shape `(batch_size, num_frames, RGB, height, width)`):
                Batched RGB frames. Mean/std normalization is applied here, so the
                frames presumably arrive unnormalized in `[0, 1]` (TODO confirm
                against the preprocessor).

        Returns:
            VideoEmbedderOutput with:
              - visual_proj: L2-normalized `(batch_size, embed_dim)` video embedding.
              - visual_embs: temporally encoded patch features shaped
                `(batch_size, num_frames, h, w, hidden_dim)`, where
                `h = height // patch_size` and `w = width // patch_size`. The first
                (frame CLS) token is removed.
              - visual_query_output: the QFormer output for the query tokens.
              - visual_cls_tokens: QFormer query outputs averaged over the query axis.
              - frame_cls_tokens: the first token of every frame,
                `(batch_size, num_frames, 1, hidden_dim)`.
        """
        videos = (videos - self.normalization_mean) / self.normalization_std
        batch_size, num_frames, _, H, W = videos.shape

        # Encode all frames in one pass by folding time into the batch axis.
        frame_batch = rearrange(videos, "b t c h w -> (b t) c h w")
        visual_embs = self.visual_encoder(frame_batch)
        visual_embs = self.ln_vision(visual_embs)
        visual_embs = rearrange(
            visual_embs,
            "(b t) k d -> b t k d",
            b=batch_size,
            t=num_frames,
            k=visual_embs.size(1),
            d=visual_embs.size(2),
        )

        visual_embs = self.temporal_encoding(visual_embs)

        # The QFormer attends over every token of every frame, with nothing masked.
        encoder_hidden_states = rearrange(visual_embs, "b t k d -> b (t k) d")
        encoder_attention_mask = torch.ones(
            encoder_hidden_states.size()[:-1], dtype=torch.long, device=videos.device
        )
        query_tokens = self.query_tokens.expand(encoder_hidden_states.size(0), -1, -1)
        visual_query_output = self.qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
            return_dict=True,
        )

        # Pool the query tokens by averaging, then project and L2-normalize.
        visual_cls_tokens = visual_query_output.last_hidden_state.mean(dim=1, keepdim=False)
        visual_proj = self.vision_proj(visual_cls_tokens)
        visual_proj = F.normalize(visual_proj, dim=-1)

        # Split off each frame's first token and restore the spatial patch grid.
        frame_cls_tokens, visual_embs = visual_embs[:, :, 0:1], visual_embs[:, :, 1:]
        h = H // self.visual_encoder.patch_size
        w = W // self.visual_encoder.patch_size
        visual_embs = rearrange(visual_embs, "b t (h w) d -> b t h w d", h=h, w=w)

        return VideoEmbedderOutput(
            visual_proj=visual_proj,
            visual_embs=visual_embs,
            visual_query_output=visual_query_output,
            visual_cls_tokens=visual_cls_tokens,
            frame_cls_tokens=frame_cls_tokens,
        )

    def get_text_embeddings(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor,
    ) -> TextEmbedderOutput:
        """Embed a batch of tokenized texts with the QFormer's BERT encoder.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Token indices.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                1 for tokens to attend to, 0 for padding. Cast to the query-token
                dtype before use.

        Returns:
            TextEmbedderOutput with:
              - text_proj: the L2-normalized projection of the first token's hidden
                state, `(batch_size, embed_dim)`.
              - text_embs: all final hidden states.
              - text_query_output: the raw encoder output.
        """
        text_query_output = self.qformer.bert(
            input_ids=input_ids,
            attention_mask=attention_mask.to(dtype=self.query_tokens.dtype),
            return_dict=True,
        )
        # The first position (conventionally [CLS] for BERT-style tokenizers)
        # is the sentence summary.
        text_proj = text_query_output.last_hidden_state[:, 0, :]
        text_proj = self.text_proj(text_proj)
        text_proj = F.normalize(text_proj, dim=-1)

        return TextEmbedderOutput(
            text_proj=text_proj,
            text_embs=text_query_output.last_hidden_state,
            text_query_output=text_query_output,
        )

    @classmethod
    @rank0_first
    def _init_qformer(
        cls: "CosmosEmbed1",
        num_query_tokens: int,
        encoder_width: int,
        vocab_size: int,
        hidden_size: int = 768,
    ) -> tuple[BertLMHeadModel, nn.Parameter]:
        """Convenience function for initializing QFormer module.

        Returns:
            The QFormer from `load_qformer` and a `(1, num_query_tokens, hidden_size)`
            query-token parameter initialized from N(0, 0.02^2).
        """
        qformer = load_qformer(
            num_query_tokens=num_query_tokens,
            encoder_width=encoder_width,
            hidden_size=hidden_size,
            vocab_size=vocab_size,
        )
        query_tokens = nn.Parameter(torch.zeros(1, num_query_tokens, hidden_size))
        query_tokens.data.normal_(mean=0.0, std=0.02)
        return qformer, query_tokens

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load pretrained weights, with special handling for TransformerEngine (TE).

        Without TE this defers to `PreTrainedModel.from_pretrained`. When
        `config.transformer_engine` is set, the checkpoint is first loaded into a
        non-TE copy of the model (TE and fp8 disabled). Its state dict is then
        copied into a freshly constructed TE model with `strict=False`, and any
        missing or unexpected keys are logged as warnings. The TE model is cast to
        the loaded weights' dtype and put in eval mode, matching what the standard
        Hugging Face loader returns.

        Args:
            pretrained_model_name_or_path (str): Checkpoint name or path.
            **kwargs: Forwarded to `PreTrainedModel.from_pretrained`. An explicit
                `config` is honoured; otherwise it is loaded from the checkpoint.

        Returns:
            CosmosEmbed1: The loaded model.
        """
        config = kwargs.get("config", None)
        if config is None:
            config = CosmosEmbed1Config.from_pretrained(pretrained_model_name_or_path)

        if not config.transformer_engine:
            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)

        import logging

        logger = logging.getLogger(__name__)

        config_no_te = deepcopy(config)
        config_no_te.transformer_engine = False
        config_no_te.use_fp8 = False

        # A shallow copy suffices because only "config" is replaced. Deep-copying
        # could duplicate large pass-through objects (e.g. an explicit state_dict).
        kwargs_no_te = {**kwargs, "config": config_no_te}

        # Load weights with the standard loader into the non-TE layout.
        base_model = super().from_pretrained(pretrained_model_name_or_path, **kwargs_no_te)
        base_state_dict = base_model.state_dict()

        # Build the TE model and copy the weights across. strict=False because
        # TE modules may differ in their state-dict keys.
        model_with_te = cls(config=config)
        missing, unexpected = model_with_te.load_state_dict(base_state_dict, strict=False)
        if missing:
            logger.warning("[TransformerEngine] Missing keys: %s", missing)
        if unexpected:
            logger.warning("[TransformerEngine] Unexpected keys: %s", unexpected)

        # cls(config) builds a train-mode model in the default dtype. Match the
        # standard loader instead: keep the loaded dtype (so a requested torch_dtype
        # is honoured) and disable dropout.
        model_with_te.to(dtype=base_model.dtype)
        model_with_te.eval()
        return model_with_te
| |
|
| |
|
# Map CosmosEmbed1Config -> CosmosEmbed1 in the AutoModel registry so the auto
# classes can resolve this custom model from its config.
AutoModel.register(config_class=CosmosEmbed1Config, model_class=CosmosEmbed1)
| |
|