| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| from dataclasses import dataclass |
| from functools import lru_cache |
| from pathlib import Path |
| from typing import TYPE_CHECKING, Annotated, Callable, Optional |
|
|
| import torch |
| import torch.nn.functional as F |
| from einops import rearrange, repeat |
| from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding |
| from nemo.collections.llm.gpt.model.llama import Llama3Config, LlamaModel |
| from nemo.collections.llm.utils import Config |
| from nemo.lightning import OptimizerModule, io |
| from nemo.lightning.base import teardown |
| from torch import Tensor, nn |
|
|
| from .log import log |
|
|
|
|
class RotaryEmbedding3D(RotaryEmbedding):
    """Rotary Embedding3D for Cosmos Language model.

    Splits the head dimension into two equal spatial bands (height, width) and
    one temporal band, builds an inverse-frequency table per band, and
    broadcasts the three tables over the (T, H, W) latent grid. Optionally
    applies YaRN frequency scaling for context extension.

    Args:
        seq_len (int): Sequence length; when YaRN is enabled and this exceeds
            the largest latent extent, the cached frequency table is extended.
        kv_channels (int): Projection weights dimension in multi-head attention. Obtained
            from transformer config
        training_type (Optional[str]): Accepted for interface compatibility; unused here.
        rotary_base (int, optional): Base period for rotary position embeddings. Defaults to
            10000.
        use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly
            on the GPU. Defaults to False
        latent_shape: The (T, H, W) shape of the latents produced by the video after being
            tokenized. Defaults to (5, 40, 64).
        apply_yarn (bool): Whether to apply YaRN context-extension scaling to the
            inverse frequencies.
        original_latent_shape: Latent shape at the original training resolution
            (required when ``apply_yarn`` is True).
        beta_fast (int): YaRN high-frequency cutoff parameter.
        beta_slow (int): YaRN low-frequency cutoff parameter.
        scale: YaRN context-extension scale (required when ``apply_yarn`` is True).
        max_position_embeddings: Stored for reference; not used in the frequency
            computation here.
        original_max_position_embeddings: Stored for reference; not used in the
            frequency computation here.
        extrapolation_factor: Accepted for interface compatibility; currently unused.
        attn_factor: Multiplier applied to the YaRN magnitude scale (mscale).
    """

    def __init__(
        self,
        seq_len: int,
        kv_channels: int,
        training_type: Optional[str] = None,
        rotary_base: int = 10000,
        use_cpu_initialization: bool = False,
        latent_shape=(5, 40, 64),  # immutable default: avoids the shared-mutable-default pitfall
        apply_yarn=False,
        original_latent_shape=None,
        beta_fast=32,
        beta_slow=1,
        scale=None,
        max_position_embeddings=None,
        original_max_position_embeddings=None,
        extrapolation_factor=1,
        attn_factor=1,
    ) -> None:
        super().__init__(
            kv_channels=kv_channels,
            rotary_base=rotary_base,
            rotary_percent=1.0,
            use_cpu_initialization=use_cpu_initialization,
        )
        self.latent_shape = latent_shape
        self.device = "cpu" if use_cpu_initialization else torch.cuda.current_device()
        self.dim = kv_channels
        self.rope_theta = rotary_base
        self.apply_yarn = apply_yarn
        self.original_latent_shape = original_latent_shape
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.scale = scale
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.attn_factor = attn_factor
        # Two equal spatial bands (H and W share dim_h each); the remainder of
        # the head dim goes to the temporal band.
        dim_h = self.dim // 6 * 2
        dim_t = self.dim - 2 * dim_h
        self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.device) / dim_h
        spatial_inv_freq = 1.0 / (self.rope_theta**self.dim_spatial_range)
        self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.device) / dim_t
        temporal_inv_freq = 1.0 / (self.rope_theta**self.dim_temporal_range)
        if self.apply_yarn:
            assert self.original_latent_shape is not None, "Original latent shape required."
            assert self.beta_slow is not None, "Beta slow value required."
            assert self.beta_fast is not None, "Beta fast value required."
            # Fail early with a clear message instead of a TypeError inside
            # get_scale_factors when scale is None.
            assert self.scale is not None, "Scale value required."
            scale_factors_spatial = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[1])
            spatial_inv_freq = spatial_inv_freq * scale_factors_spatial
            scale_factors_temporal = self.get_scale_factors(temporal_inv_freq, self.original_latent_shape[0])
            temporal_inv_freq = temporal_inv_freq * scale_factors_temporal
            self.mscale = float(self.get_mscale(self.scale) * self.attn_factor)
        self.spatial_inv_freq = spatial_inv_freq
        self.temporal_inv_freq = temporal_inv_freq
        max_seq_len_cached = max(self.latent_shape)
        if self.apply_yarn and seq_len > max_seq_len_cached:
            max_seq_len_cached = seq_len
        self.max_seq_len_cached = max_seq_len_cached
        # Per-instance cache for forward(). Replaces @lru_cache on the bound
        # method, which keys on `self` and keeps the instance (and its CUDA
        # tensors) alive for the lifetime of the cache.
        self._emb_cache: dict = {}
        self.freqs = self.get_freqs_non_repeated(self.max_seq_len_cached)

    def get_mscale(self, scale: float = 1.0) -> float:
        """Get the magnitude scaling factor for YaRN."""
        if scale <= 1:
            return 1.0
        return 0.1 * math.log(scale) + 1.0

    def get_scale_factors(self, inv_freq: torch.Tensor, original_seq_len: int) -> torch.Tensor:
        """Get the per-frequency scale factors for YaRN.

        Frequencies below the low cutoff are fully rescaled by 1/scale,
        frequencies above the high cutoff are left unchanged, and the band in
        between is linearly interpolated.
        """
        high_freq_cutoff = 2 * math.pi * self.beta_fast / original_seq_len
        low_freq_cutoff = 2 * math.pi * self.beta_slow / original_seq_len

        # 0 at/below the low cutoff, 1 at/above the high cutoff, linear ramp between.
        smooth_mask = torch.clamp((inv_freq - low_freq_cutoff) / (high_freq_cutoff - low_freq_cutoff), min=0, max=1)

        scale_factors = (1 - smooth_mask) / self.scale + smooth_mask
        return scale_factors

    def get_freqs_non_repeated(self, max_seq_len_cached: int, offset: int = 0) -> Tensor:
        """Build the (T*H*W, 1, 1, dim) rotary frequency grid for the current latent shape."""
        dtype = self.spatial_inv_freq.dtype
        device = self.spatial_inv_freq.device

        self.seq = (torch.arange(max_seq_len_cached, device=device, dtype=dtype) + offset).cuda()

        assert hasattr(
            self, "latent_shape"
        ), "Latent shape is not set. Please run set_latent_shape() method on rope embedding. "
        T, H, W = self.latent_shape
        half_emb_t = torch.outer(self.seq[:T], self.temporal_inv_freq.cuda())
        half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq.cuda())
        half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq.cuda())
        # Broadcast each per-axis table over the full (T, H, W) grid; the
        # `* 2` duplicates the concatenated halves so the cos/sin rotation
        # pairs downstream line up.
        emb = torch.cat(
            [
                repeat(half_emb_t, "t d -> t h w d", h=H, w=W),
                repeat(half_emb_h, "h d -> t h w d", t=T, w=W),
                repeat(half_emb_w, "w d -> t h w d", t=T, h=H),
            ]
            * 2,
            dim=-1,
        )
        emb = rearrange(emb, "t h w d -> (t h w) 1 1 d").float()
        return emb

    def forward(self, seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor:
        """Return the rotary frequency grid, extending the cached length under YaRN.

        Results are memoized per (seq_len, offset, packed_seq) on this instance.
        """
        if self.spatial_inv_freq.device.type == "cpu":
            # Move CPU-initialized frequencies to the GPU lazily on first use.
            self.spatial_inv_freq = self.spatial_inv_freq.to(device=torch.cuda.current_device())

        cache_key = (seq_len, offset, packed_seq)
        cached = self._emb_cache.get(cache_key)
        if cached is not None:
            return cached

        max_seq_len_cached = self.max_seq_len_cached
        if self.apply_yarn and seq_len > max_seq_len_cached:
            max_seq_len_cached = seq_len
        self.max_seq_len_cached = max_seq_len_cached
        emb = self.get_freqs_non_repeated(self.max_seq_len_cached)
        self._emb_cache[cache_key] = emb
        return emb
|
|
|
|
| if TYPE_CHECKING: |
| from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec |
|
|
|
|
@dataclass
class CosmosConfig(Llama3Config):
    """Base configuration for Cosmos language models (Llama3-style backbone
    with QK layernorm and an optional 3D rotary embedding for video latents)."""

    qk_layernorm: bool = True
    rope_dim: str = "3D"
    vocab_size: int = 64000
    # Intentionally unannotated: stays a plain class attribute rather than a
    # dataclass field (callables cannot be plain dataclass defaults).
    activation_func = F.silu

    def configure_model(self, tokenizer) -> "MCoreGPTModel":
        """Build the GPT model, swapping in the 3D rotary embedding when rope_dim == "3D".

        YaRN-related attributes are optional and only defined on some
        subclasses (e.g. CosmosConfig12B), hence the getattr defaults.
        """
        model = super().configure_model(tokenizer)
        if self.rope_dim == "3D":
            model.rotary_pos_emb = RotaryEmbedding3D(
                seq_len=self.seq_length,
                training_type=None,
                kv_channels=self.kv_channels,
                max_position_embeddings=self.seq_length,
                original_max_position_embeddings=getattr(self, "original_seq_len", None),
                rotary_base=self.rotary_base,
                # getattr reads the attribute's actual value; the previous
                # `True if hasattr(...)` would report True even when a
                # subclass set apply_yarn = False.
                apply_yarn=getattr(self, "apply_yarn", False),
                scale=getattr(self, "yarn_scale", None),
                extrapolation_factor=1,
                attn_factor=1,
                beta_fast=getattr(self, "yarn_beta_fast", 32),
                beta_slow=getattr(self, "yarn_beta_slow", 1),
                latent_shape=[5, 40, 64],
                original_latent_shape=getattr(self, "original_latent_shape", None),
            )
        return model
|
|
|
|
@dataclass
class CosmosConfig4B(CosmosConfig):
    """Hyperparameters for the 4B-parameter Cosmos model variant."""

    rotary_base: int = 500_000
    seq_length: int = 15360
    num_layers: int = 16
    hidden_size: int = 4096
    ffn_hidden_size: int = 14336
    num_attention_heads: int = 32
    num_query_groups: int = 8  # grouped-query attention: 8 KV groups for 32 heads
    layernorm_epsilon: float = 1e-5
    use_cpu_initialization: bool = True
    make_vocab_size_divisible_by: int = 128
    kv_channels: int = 128  # per-head dim; set explicitly (hidden_size / heads would also be 128)
|
|
|
|
@dataclass
class CosmosConfig12B(CosmosConfig):
    """Hyperparameters for the 12B-parameter Cosmos model variant.

    Unlike the 4B config, this variant enables YaRN context extension; the
    yarn_* / original_* attributes are probed via hasattr/getattr in
    CosmosConfig.configure_model.
    """

    rotary_base: int = 500_000
    seq_length: int = 15360
    num_layers: int = 40
    hidden_size: int = 5120
    ffn_hidden_size: int = 14336
    num_attention_heads: int = 32
    num_query_groups: int = 8  # grouped-query attention: 8 KV groups for 32 heads
    layernorm_epsilon: float = 1e-5
    use_cpu_initialization: bool = True
    make_vocab_size_divisible_by: int = 128
    kv_channels: int = 128  # per-head dim differs from hidden_size // heads (5120 // 32 = 160)
    # Deliberately unannotated: a plain class attribute, NOT a dataclass field
    # (annotating it would add it to __init__ and change the signature).
    original_latent_shape = [3, 40, 64]
    apply_yarn: bool = True
    yarn_beta_fast: int = 4
    yarn_beta_slow: int = 1
    yarn_scale: int = 2
    # Deliberately unannotated for the same reason as original_latent_shape.
    original_seq_len = 8192
|
|
|
|
class CosmosModel(LlamaModel):
    """LlamaModel wrapper that defaults to the CosmosConfig4B configuration."""

    def __init__(
        self,
        config: Annotated[Optional[CosmosConfig], Config[CosmosConfig]] = None,
        optim: Optional[OptimizerModule] = None,
        tokenizer: Optional["TokenizerSpec"] = None,
        model_transform: Optional[Callable[[nn.Module], nn.Module]] = None,
    ):
        # Resolve the default BEFORE calling super(): the previous code passed
        # `config or CosmosConfig4B()` to super() but then reassigned
        # self.config = config, clobbering it back to None when no config was given.
        config = config or CosmosConfig4B()
        super().__init__(config, optim=optim, tokenizer=tokenizer, model_transform=model_transform)
        self.config = config
|
|
|
|
@io.state_transform(
    source_key=(
        "model.layers.*.feed_forward.w1.weight",
        "model.layers.*.feed_forward.w3.weight",
    ),
    target_key="decoder.layers.*.mlp.linear_fc1.weight",
)
def _mlp_glu(ctx: io.TransformCTX, w1, w3):
    """Concatenate the GLU gate (w1) and up (w3) projections into Megatron's fused linear_fc1 weight."""
    # `dim=` is torch's native keyword; `axis=` is only a numpy-compat alias.
    return torch.cat((w1, w3), dim=0)
|
|
|
|
@io.state_transform(
    source_key=(
        "model.layers.*.attention.wq.weight",
        "model.layers.*.attention.wk.weight",
        "model.layers.*.attention.wv.weight",
    ),
    target_key="decoder.layers.*.self_attention.linear_qkv.weight",
)
def _import_qkv_cosmos(ctx: io.TransformCTX, q, k, v):
    """Interleave separate Q/K/V projection weights into Megatron's fused QKV layout.

    Megatron groups the fused weight by query group: for each of the
    num_query_groups groups, that group's query heads come first, followed by
    its single key head and single value head.
    """
    cfg = ctx.target.config

    n_heads = cfg.num_attention_heads
    n_groups = cfg.num_query_groups
    q_per_group = n_heads // n_groups
    hidden = cfg.hidden_size
    head_dim = cfg.kv_channels

    # Split the leading output dim into (heads, head_dim); the remaining
    # dims (the input features) are untouched.
    tail_shape = q.size()[1:]
    q = q.view(n_heads, head_dim, *tail_shape)
    k = k.view(n_groups, head_dim, *tail_shape)
    v = v.view(n_groups, head_dim, *tail_shape)

    interleaved = []
    for g in range(n_groups):
        interleaved.append(q[g * q_per_group : (g + 1) * q_per_group, :, :])
        interleaved.append(k[g : g + 1, :, :])
        interleaved.append(v[g : g + 1, :, :])
    fused = torch.cat(interleaved)

    # Sanity-check the fused layout before flattening.
    assert fused.ndim == 3, fused.shape
    assert fused.shape[0] == (q_per_group + 2) * n_groups, fused.shape
    assert fused.shape[1] == head_dim, fused.shape
    assert fused.shape[2] == tail_shape[0], fused.shape

    return fused.reshape([head_dim * (n_heads + 2 * n_groups), hidden])
|
|
|
|
@io.model_importer(CosmosModel, "pt")
class PTCosmosImporter(io.ModelConnector["PTCosmosModel", CosmosModel]):
    """Imports a raw PyTorch (.pt) Cosmos checkpoint into NeMo format.

    The checkpoint path itself (str(self)) is used both to locate the file and
    to infer the model size (see the `config` property).
    """

    def init(self) -> CosmosModel:
        """Create the target NeMo model using the inferred config."""
        return CosmosModel(self.config, tokenizer=self.tokenizer)

    def apply(self, output_path: Path) -> Path:
        """Load the .pt state dict, convert it, and save a NeMo checkpoint at output_path."""
        pt_model_path = str(self)
        # NOTE(review): torch.load unpickles arbitrary objects — only run this
        # on checkpoints from trusted sources.
        cosmos_model_state_dict = torch.load(pt_model_path, map_location="cpu")
        # Upcast every tensor to float32 — presumably the checkpoint is stored
        # in bf16/fp16; confirm against the exporter.
        for k, v in cosmos_model_state_dict.items():
            cosmos_model_state_dict[k] = v.float()

        # Minimal adapter exposing state_dict(), which is the only interface
        # io.apply_transforms needs from a source model.
        class WrapperCosmos:
            def __init__(self, model_state_dict):
                self.model_state_dict = model_state_dict

            def state_dict(self):
                return self.model_state_dict

        source = WrapperCosmos(cosmos_model_state_dict)
        target = self.init()
        trainer = self.nemo_setup(target)
        self.convert_state(source, target)
        self.nemo_save(output_path, trainer)

        log.info(f"Converted PT Cosmos model to Nemo, model saved to {output_path}")

        # Free model/trainer resources before returning.
        teardown(trainer, target)
        del trainer, target

        return output_path

    def convert_state(self, source, target):
        """Map Cosmos/Llama-style parameter names onto Megatron names.

        Fused weights (QKV, gated MLP) are handled by the module-level
        transforms passed in `transforms`; everything else is a direct rename.
        """
        mapping = {
            "model.tok_embeddings.weight": "embedding.word_embeddings.weight",
            "model.layers.*.attention.wo.weight": "decoder.layers.*.self_attention.linear_proj.weight",
            "model.layers.*.attention.q_norm.weight": "decoder.layers.*.self_attention.q_layernorm.weight",
            "model.layers.*.attention.k_norm.weight": "decoder.layers.*.self_attention.k_layernorm.weight",
            "model.layers.*.attention_norm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight",
            "model.layers.*.feed_forward.w2.weight": "decoder.layers.*.mlp.linear_fc2.weight",
            "model.layers.*.ffn_norm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight",
            "model.norm.weight": "decoder.final_layernorm.weight",
            "model.output.weight": "output_layer.weight",
        }

        return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv_cosmos, _mlp_glu])

    @property
    def tokenizer(self):
        """Raw .pt checkpoints ship no tokenizer."""
        return None

    @property
    def config(self):
        """Infer the model size from a "4B"/"12B" substring in the checkpoint path."""
        if "4B" in str(self) or "4b" in str(self):
            return CosmosConfig4B()
        elif "12B" in str(self) or "12b" in str(self):
            return CosmosConfig12B()
        else:
            raise ValueError("Unable to infer model size from checkpoint")
|
|