| """Cosmos-Embed1 text+video embedder.""" |
|
|
| import math |
| from copy import deepcopy |
|
|
| import torch |
| from einops import rearrange |
| from torch import nn |
| from torch.nn import functional as F |
| from transformers import AutoModel, PreTrainedModel |
|
|
| from .configuration_embed1 import CosmosEmbed1Config |
| from .modeling_outputs import TextEmbedderOutput, TextVideoEmbedderOutput, VideoEmbedderOutput |
| from .modeling_qformer import BertLMHeadModel, load_qformer |
| from .modeling_utils import EncodingFactory, rank0_first |
| from .modeling_vit import EvaViTG |
|
|
|
|
class CosmosEmbed1(PreTrainedModel):
    config_class = CosmosEmbed1Config

    def __init__(self, config: CosmosEmbed1Config) -> None:
        """Cosmos-Embed1 video embedder constructor.

        Args:
            config (CosmosEmbed1Config): Model configuration.
        """
        super().__init__(config)

        self.embed_dim = config.embed_dim
        self.num_query_tokens = config.num_query_tokens
        self.num_video_frames = config.num_video_frames
        self.temporal_encoding_type = config.temporal_encoding_type
        self.resolution = config.resolution
        self.vocab_size = config.vocab_size
        self.transformer_engine = config.transformer_engine
        self.use_fp8 = config.use_fp8

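        # ImageNet channel mean/std, kept as non-persistent buffers so they follow the model's device.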
| self.register_buffer( |
| "normalization_mean", |
| torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1), |
| persistent=False, |
| ) |
| self.register_buffer( |
| "normalization_std", |
| torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1), |
| persistent=False, |
| ) |
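        # Per-frame visual encoder (EVA ViT-g) followed by a LayerNorm over its outputs.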
        self.visual_encoder = EvaViTG(
            img_size=self.resolution,
            transformer_engine=self.transformer_engine,
            use_fp8=self.use_fp8,
        )
        self.ln_vision = nn.LayerNorm(self.visual_encoder.embed_dim)

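        # Q-Former with learnable query tokens that cross-attend to the visual features.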
        self.qformer, self.query_tokens = self._init_qformer(
            num_query_tokens=self.num_query_tokens,
            encoder_width=self.visual_encoder.embed_dim,
            vocab_size=self.vocab_size,
        )
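        # Mirror the shared Q-Former weights into their "_query" counterparts so both branches start identical.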
        state_dict = self.qformer.state_dict()
        for name, param in self.qformer.named_parameters():
            if "_query" in name:
                key_orig = name.replace("_query", "")
                param.data.copy_(state_dict[key_orig])

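        # Temporal encoding applied to per-frame ViT features before they reach the Q-Former.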
        self.temporal_encoding = EncodingFactory(
            self.temporal_encoding_type,
            embed_dim=self.visual_encoder.embed_dim,
            max_len=self.num_video_frames,
        )

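        # Projection heads into the shared text-video embedding space, plus a binary ITM head.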
        self.vision_proj = nn.Linear(self.qformer.config.hidden_size, self.embed_dim)
        self.text_proj = nn.Linear(self.qformer.config.hidden_size, self.embed_dim)
        self.itm_proj = nn.Linear(self.qformer.config.hidden_size, 2)
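        # Learnable temperature and bias for the contrastive logits (SigLIP-style initialization).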
        self.logit_scale = nn.Parameter(torch.tensor(math.log(10.0)))
        self.logit_bias = nn.Parameter(torch.tensor(-10.0))

    @property
    def hidden_dim(self) -> int:
        return self.visual_encoder.embed_dim

    @torch.jit.ignore
    def no_weight_decay(self) -> set:
        return {"logit_scale", "logit_bias"}

    def forward(
        self,
        videos: torch.FloatTensor,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor,
    ) -> TextVideoEmbedderOutput:
        """Forward function for `CosmosEmbed1`.

        Args:
            videos (`torch.Tensor` of shape `(batch_size, num_frames, RGB, height, width)`):
                Batched videos with a fixed number of RGB frames.
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
                Indices can be obtained using [`AutoTokenizer`, `CosmosEmbed1Tokenizer`].
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                Mask to avoid performing attention on padding token indices.
                Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
        """
        video_output = self.get_video_embeddings(videos)
        text_output = self.get_text_embeddings(input_ids, attention_mask)
        return TextVideoEmbedderOutput(**video_output, **text_output)

    def get_video_embeddings(self, videos: torch.Tensor) -> VideoEmbedderOutput:
        videos = (videos - self.normalization_mean) / self.normalization_std
        batch_size, num_frames, _, H, W = videos.shape
        frame_batch = rearrange(videos, "b t c h w -> (b t) c h w")

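        # Encode all frames in a single batched ViT pass, then restore the (batch, time) layout.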
        visual_embs = self.visual_encoder(frame_batch)
        visual_embs = self.ln_vision(visual_embs)
        visual_embs = rearrange(
            visual_embs,
            "(b t) k d -> b t k d",
            b=batch_size,
            t=num_frames,
            k=visual_embs.size(1),
            d=visual_embs.size(2),
        )

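        # Add temporal position information along the frame axis.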
        visual_embs = self.temporal_encoding(visual_embs)

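        # Cross-attend the query tokens over the flattened spatio-temporal token sequence.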
        encoder_hidden_states = rearrange(visual_embs, "b t k d -> b (t k) d")
        encoder_attention_mask = torch.ones(encoder_hidden_states.size()[:-1], dtype=torch.long, device=videos.device)
        query_tokens = self.query_tokens.expand(encoder_hidden_states.size(0), -1, -1)
        visual_query_output = self.qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
            return_dict=True,
        )

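        # Mean-pool the query outputs and project into the joint embedding space.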
        visual_cls_tokens = visual_query_output.last_hidden_state.mean(dim=1, keepdim=False)
        visual_proj = self.vision_proj(visual_cls_tokens)
        visual_proj = F.normalize(visual_proj, dim=-1)

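        # Split per-frame [CLS] tokens from patch tokens and restore the spatial patch grid.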
        frame_cls_tokens, visual_embs = visual_embs[:, :, 0:1], visual_embs[:, :, 1:]
        h = H // self.visual_encoder.patch_size
        w = W // self.visual_encoder.patch_size
        visual_embs = rearrange(visual_embs, "b t (h w) d -> b t h w d", h=h, w=w)

        return VideoEmbedderOutput(
            visual_proj=visual_proj,
            visual_embs=visual_embs,
            visual_query_output=visual_query_output,
            visual_cls_tokens=visual_cls_tokens,
            frame_cls_tokens=frame_cls_tokens,
        )

    def get_text_embeddings(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor,
    ) -> TextEmbedderOutput:
        text_query_output = self.qformer.bert(
            input_ids=input_ids,
            attention_mask=attention_mask.to(dtype=self.query_tokens.dtype),
            return_dict=True,
        )
        text_proj = text_query_output.last_hidden_state[:, 0, :]
        text_proj = self.text_proj(text_proj)
        text_proj = F.normalize(text_proj, dim=-1)

        return TextEmbedderOutput(
            text_proj=text_proj,
            text_embs=text_query_output.last_hidden_state,
            text_query_output=text_query_output,
        )

    @classmethod
    @rank0_first
    def _init_qformer(
        cls: type["CosmosEmbed1"],
        num_query_tokens: int,
        encoder_width: int,
        vocab_size: int,
        hidden_size: int = 768,
    ) -> tuple[BertLMHeadModel, nn.Parameter]:
        """Convenience function for initializing the Q-Former module."""
        qformer = load_qformer(
            num_query_tokens=num_query_tokens,
            encoder_width=encoder_width,
            hidden_size=hidden_size,
            vocab_size=vocab_size,
        )
        query_tokens = nn.Parameter(torch.zeros(1, num_query_tokens, hidden_size))
        query_tokens.data.normal_(mean=0.0, std=0.02)
        return qformer, query_tokens

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
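        # Resolve the config first so TransformerEngine checkpoints can be special-cased below.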
        config = kwargs.get("config", None)
        if config is None:
            config = CosmosEmbed1Config.from_pretrained(pretrained_model_name_or_path)

        if config.transformer_engine:
            config_no_te = deepcopy(config)
            config_no_te.transformer_engine = False
            config_no_te.use_fp8 = False

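            # Propagate the TE-free config through the loading kwargs.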
            kwargs_no_te = deepcopy(kwargs)
            kwargs_no_te["config"] = config_no_te

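            # Load a plain (non-TE) model first so the checkpoint weights map onto standard modules.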
            base_model = super().from_pretrained(pretrained_model_name_or_path, **kwargs_no_te)
            base_state_dict = base_model.state_dict()

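            # Re-instantiate with TransformerEngine enabled and transfer the weights over.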
            model_with_te = cls(config=config)

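            # strict=False: the TE modules may expose extra state (e.g. FP8 scaling buffers).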
            missing, unexpected = model_with_te.load_state_dict(base_state_dict, strict=False)

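            # Surface any mismatched keys for debugging.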
            if missing:
                print(f"[TransformerEngine] Missing keys: {missing}")
            if unexpected:
                print(f"[TransformerEngine] Unexpected keys: {unexpected}")

            return model_with_te
        else:
            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)


AutoModel.register(CosmosEmbed1Config, CosmosEmbed1)
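
# A minimal smoke-test sketch, not part of the published API. It assumes
# CosmosEmbed1Config is constructible from its defaults (and accepts
# `transformer_engine`/`use_fp8` overrides); real checkpoints should instead be
# loaded through AutoModel.from_pretrained(..., trust_remote_code=True).
if __name__ == "__main__":
    config = CosmosEmbed1Config(transformer_engine=False, use_fp8=False)
    model = CosmosEmbed1(config).eval()

    # Random inputs shaped to the forward() contract: (B, T, C, H, W) videos in [0, 1],
    # plus a tokenized caption with its attention mask.
    videos = torch.rand(1, config.num_video_frames, 3, config.resolution, config.resolution)
    input_ids = torch.randint(0, config.vocab_size, (1, 32))
    attention_mask = torch.ones(1, 32, dtype=torch.long)

    with torch.no_grad():
        out = model(videos=videos, input_ids=input_ids, attention_mask=attention_mask)

    # Both projections land in the shared embedding space of size config.embed_dim.
    print(out.visual_proj.shape, out.text_proj.shape)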