| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import torch |
| from transformers import AutoProcessor, MllamaConfig |
| from transformers.models.mllama.configuration_mllama import MllamaTextConfig, MllamaVisionConfig |
|
|
| from nemo import lightning as nl |
| from nemo.collections import vlm |
|
|
|
|
def split_qkv_weight(qkv_weight, model_config):
    """Split a fused NeMo/Megatron attention qkv weight into HF q/k/v projections.

    NeMo stores the fused weight as per-KV-group interleaved blocks:
    ``[q_0 .. q_{heads_per_group-1}, k, v]`` repeated for every query group.

    Args:
        qkv_weight: fused weight of shape
            ``((head_num + 2 * num_query_groups) * head_size, hidden_size)``.
        model_config: config exposing ``hidden_size``, ``num_attention_heads``,
            ``num_query_groups`` and ``kv_channels``.

    Returns:
        ``[('q_proj', q), ('k_proj', k), ('v_proj', v)]`` where ``q`` has shape
        ``(head_num, head_size, hidden_size)`` and ``k``/``v`` have shape
        ``(num_query_groups, head_size, hidden_size)``.
    """
    hidden_size = model_config.hidden_size
    head_num = model_config.num_attention_heads
    num_query_groups = model_config.num_query_groups or head_num
    head_size = model_config.kv_channels or (hidden_size // head_num)
    heads_per_group = head_num // num_query_groups

    # View as (groups, q-heads + k + v, head_size, hidden) and slice. Unlike
    # filling pre-allocated torch.empty buffers (which used the *default*
    # dtype and upcast bf16 checkpoints to fp32), slicing preserves the
    # source dtype and device and avoids the per-group copy loop.
    qkv_weight = qkv_weight.reshape(num_query_groups, heads_per_group + 2, head_size, hidden_size)
    q_weight = qkv_weight[:, :heads_per_group].reshape(head_num, head_size, hidden_size)
    k_weight = qkv_weight[:, heads_per_group]
    v_weight = qkv_weight[:, heads_per_group + 1]

    return [('q_proj', q_weight), ('k_proj', k_weight), ('v_proj', v_weight)]
|
|
|
|
def split_kv_weight(kv_weight, model_config):
    """Split a fused NeMo cross-attention kv weight into HF k/v projections.

    NeMo stores the fused weight as ``[k, v]`` pairs repeated per query group.

    Args:
        kv_weight: fused weight of shape
            ``(2 * num_query_groups * head_size, hidden_size)``.
        model_config: config exposing ``hidden_size``, ``num_attention_heads``,
            ``num_query_groups`` and ``kv_channels``.

    Returns:
        ``[('k_proj', k), ('v_proj', v)]``, each of shape
        ``(num_query_groups, head_size, hidden_size)``.
    """
    hidden_size = model_config.hidden_size
    head_num = model_config.num_attention_heads
    num_query_groups = model_config.num_query_groups or head_num
    head_size = model_config.kv_channels or (hidden_size // head_num)

    # View as (groups, [k, v], head_size, hidden) and slice. This preserves
    # the source dtype/device; the previous torch.empty buffers were created
    # with the default dtype and upcast bf16 checkpoint weights to fp32.
    kv_weight = kv_weight.reshape(num_query_groups, 2, head_size, hidden_size)
    k_weight = kv_weight[:, 0]
    v_weight = kv_weight[:, 1]

    return [('k_proj', k_weight), ('v_proj', v_weight)]
|
|
|
|
def split_gate_weight(gate_weight):
    """Split a fused NeMo ``linear_fc1`` weight into HF gate/up projections.

    NeMo fuses the SwiGLU gate and up projections row-wise into one matrix;
    the first half of the rows is ``gate_proj``, the second half ``up_proj``.

    Args:
        gate_weight: fused weight of shape ``(2 * intermediate, hidden)``.

    Returns:
        ``[('gate_proj', first_half), ('up_proj', second_half)]``.
    """
    # `dim` is torch.chunk's documented keyword; `axis` is only a
    # NumPy-compat alias.
    gate_proj, up_proj = torch.chunk(gate_weight, 2, dim=0)

    return [('gate_proj', gate_proj), ('up_proj', up_proj)]
|
|
|
|
def convert_mllama_config(source_vision, source_text):
    """Build an HF ``MllamaConfig`` from the NeMo vision and text configs."""
    hf_vision = MllamaVisionConfig(
        num_hidden_layers=source_vision.num_layers,
        hidden_size=source_vision.hidden_size,
        attention_heads=source_vision.num_attention_heads,
        image_size=source_vision.vision_chunk_size,
        max_num_tiles=source_vision.vision_max_num_chunks,
        torch_dtype="bfloat16",
    )

    # HF indexes cross-attention layers within the flat (self + cross) layer
    # list, so each fusion position is shifted by the number of cross layers
    # inserted before it.
    fusion_schedule = source_text._init_fusion_schedule(source_text.num_cross_attention_layers)
    cross_attention_layers = [layer + shift for shift, layer in enumerate(fusion_schedule)]

    llama3_rope_scaling = {
        "factor": 8.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    }
    hf_text = MllamaTextConfig(
        rope_theta=source_text.rotary_base,
        num_hidden_layers=source_text.num_layers + source_text.num_cross_attention_layers,
        cross_attention_layers=cross_attention_layers,
        hidden_size=source_text.hidden_size,
        intermediate_size=source_text.ffn_hidden_size,
        num_attention_heads=source_text.num_attention_heads,
        num_key_value_heads=source_text.num_query_groups,
        vocab_size=source_text.vocab_size,
        rope_scaling=llama3_rope_scaling,
        eos_token_id=[128001, 128008, 128009],
        torch_dtype="bfloat16",
    )

    return MllamaConfig(hf_vision, hf_text, torch_dtype="bfloat16")
|
|
|
|
def convert_mllama_nemo_to_hf(checkpoint_path, processor_name):
    """Convert nemo mllama to hf state dict and config.

    Loads a NeMo MLlama-11B-Instruct checkpoint on a single device, then
    renames every parameter (and re-splits the fused qkv / kv / gate
    weights) into the HuggingFace Mllama layout.

    Args:
        checkpoint_path: path to the NeMo checkpoint to load.
        processor_name: HF id/path of the Mllama processor; only its
            tokenizer is used to construct the NeMo model.

    Returns:
        Tuple ``(state_dict, config)`` in HF key naming plus the matching
        ``transformers.MllamaConfig``.
    """
    processor = AutoProcessor.from_pretrained(processor_name)

    # Minimal single-GPU trainer: only used so Fabric can materialize and
    # load the Megatron-sharded checkpoint (no training happens here, so
    # max_steps / val settings are inert).
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        ckpt_load_optimizer=False,
        ckpt_save_optimizer=False,
    )
    trainer = nl.Trainer(
        devices=1,
        max_steps=1000,
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
        val_check_interval=1000,
        limit_val_batches=50,
    )

    fabric = trainer.to_fabric()

    tokenizer = processor.tokenizer
    # NOTE(review): architecture is hard-coded to the 11B instruct config;
    # other Mllama sizes would need a different vlm config class here.
    model = vlm.MLlamaModel(vlm.MLlamaConfig11BInstruct(), tokenizer=tokenizer)
    config = model.config
    vision_model_config = config.vision_model_config
    language_model_config = config.language_model_config
    model = fabric.load_model(checkpoint_path, model)
    # Unwrap the nested Fabric/strategy module wrappers down to the bare
    # model (four levels deep for this setup).
    model = model.module.module.module.module

    state_dict = model.state_dict()
    del model  # free the model; only the state dict is needed below

    # Direct 1:1 renames as (hf_key, nemo_key) pairs; fused weights that
    # need splitting are handled by the helpers further down.
    v = "vision_model.vision_encoder"
    key_map = [
        ("vision_model.class_embedding", f"{v}.class_embedding"),
        ("vision_model.gated_positional_embedding.embedding", f"{v}.positional_embedding"),
        (
            "vision_model.gated_positional_embedding.tile_embedding.weight",
            f"{v}.gated_tile_positional_embedding.weight",
        ),
        ("vision_model.gated_positional_embedding.gate", f"{v}.gated_positional_embedding_gate"),
        ("vision_model.layernorm_post.bias", f"{v}.ln_post.bias"),
        ("vision_model.layernorm_post.weight", f"{v}.ln_post.weight"),
        ("vision_model.layernorm_pre.bias", f"{v}.ln_pre.bias"),
        ("vision_model.layernorm_pre.weight", f"{v}.ln_pre.weight"),
        ("vision_model.post_tile_positional_embedding.embedding.weight", f"{v}.post_tile_pos_embed.embedding.weight"),
        ("vision_model.post_tile_positional_embedding.gate", f"{v}.post_tile_pos_embed.gate"),
        ("vision_model.pre_tile_positional_embedding.embedding.weight", f"{v}.pre_tile_pos_embed.embedding.weight"),
        ("vision_model.pre_tile_positional_embedding.gate", f"{v}.pre_tile_pos_embed.gate"),
        ("multi_modal_projector.bias", "vision_model.vision_projection.encoder.bias"),
        ("multi_modal_projector.weight", "vision_model.vision_projection.encoder.weight"),
        ("language_model.model.norm.weight", "language_model.decoder.final_layernorm.weight"),
        ("language_model.lm_head.weight", "language_model.output_layer.weight"),
    ]

    # Per-layer renames for the local vision transformer.
    for i in range(vision_model_config.num_layers):
        key_map.extend(
            [
                (
                    f"vision_model.transformer.layers.{i}.self_attn.o_proj.weight",
                    f"{v}.transformer.layers.{i}.self_attention.linear_proj.weight",
                ),
                (
                    f"vision_model.transformer.layers.{i}.input_layernorm.bias",
                    f"{v}.transformer.layers.{i}.input_layernorm.bias",
                ),
                (
                    f"vision_model.transformer.layers.{i}.input_layernorm.weight",
                    f"{v}.transformer.layers.{i}.input_layernorm.weight",
                ),
                (
                    f"vision_model.transformer.layers.{i}.post_attention_layernorm.bias",
                    f"{v}.transformer.layers.{i}.pre_mlp_layernorm.bias",
                ),
                (
                    f"vision_model.transformer.layers.{i}.post_attention_layernorm.weight",
                    f"{v}.transformer.layers.{i}.pre_mlp_layernorm.weight",
                ),
                (
                    f"vision_model.transformer.layers.{i}.mlp.fc1.bias",
                    f"{v}.transformer.layers.{i}.mlp.linear_fc1.bias",
                ),
                (
                    f"vision_model.transformer.layers.{i}.mlp.fc1.weight",
                    f"{v}.transformer.layers.{i}.mlp.linear_fc1.weight",
                ),
                (
                    f"vision_model.transformer.layers.{i}.mlp.fc2.bias",
                    f"{v}.transformer.layers.{i}.mlp.linear_fc2.bias",
                ),
                (
                    f"vision_model.transformer.layers.{i}.mlp.fc2.weight",
                    f"{v}.transformer.layers.{i}.mlp.linear_fc2.weight",
                ),
            ]
        )

    # Per-layer renames for the global vision transformer (adds the
    # gate_attn / gate_ffn gating scalars the local layers don't have).
    for i in range(vision_model_config.num_global_layers):
        key_map.extend(
            [
                (
                    f"vision_model.global_transformer.layers.{i}.self_attn.o_proj.weight",
                    f"{v}.global_transformer.layers.{i}.self_attention.linear_proj.weight",
                ),
                (
                    f"vision_model.global_transformer.layers.{i}.gate_attn",
                    f"{v}.global_transformer.layers.{i}.gate_attn",
                ),
                (
                    f"vision_model.global_transformer.layers.{i}.gate_ffn",
                    f"{v}.global_transformer.layers.{i}.gate_ffn",
                ),
                (
                    f"vision_model.global_transformer.layers.{i}.input_layernorm.bias",
                    f"{v}.global_transformer.layers.{i}.input_layernorm.bias",
                ),
                (
                    f"vision_model.global_transformer.layers.{i}.input_layernorm.weight",
                    f"{v}.global_transformer.layers.{i}.input_layernorm.weight",
                ),
                (
                    f"vision_model.global_transformer.layers.{i}.post_attention_layernorm.bias",
                    f"{v}.global_transformer.layers.{i}.pre_mlp_layernorm.bias",
                ),
                (
                    f"vision_model.global_transformer.layers.{i}.post_attention_layernorm.weight",
                    f"{v}.global_transformer.layers.{i}.pre_mlp_layernorm.weight",
                ),
                (
                    f"vision_model.global_transformer.layers.{i}.mlp.fc1.bias",
                    f"{v}.global_transformer.layers.{i}.mlp.linear_fc1.bias",
                ),
                (
                    f"vision_model.global_transformer.layers.{i}.mlp.fc1.weight",
                    f"{v}.global_transformer.layers.{i}.mlp.linear_fc1.weight",
                ),
                (
                    f"vision_model.global_transformer.layers.{i}.mlp.fc2.bias",
                    f"{v}.global_transformer.layers.{i}.mlp.linear_fc2.bias",
                ),
                (
                    f"vision_model.global_transformer.layers.{i}.mlp.fc2.weight",
                    f"{v}.global_transformer.layers.{i}.mlp.linear_fc2.weight",
                ),
            ]
        )

    # HF uses one flat decoder layer list with cross-attention layers
    # spliced in; NeMo keeps self-attention layers (`layers`) and
    # cross-attention layers (`xattn_layers`) separately. In the HF list a
    # cross-attention layer sits at index 3 and then every
    # (cross_attention_frequency + 1) layers after that.
    cross_attention_frequency = language_model_config.num_layers // language_model_config.num_cross_attention_layers
    toal_num_layer = language_model_config.num_layers + language_model_config.num_cross_attention_layers
    prefix = "language_model.decoder"
    for i in range(toal_num_layer):
        # cross_num + 1 == number of cross-attention layers at HF index < i
        # (floor division makes it -1 + 1 == 0 for i < 3).
        cross_num = (i - 3) // (cross_attention_frequency + 1)
        if (i - 3) % (cross_attention_frequency + 1) == 0:
            # Cross-attention layer. NOTE(review): xattn_index appears to
            # follow NeMo's fusion-schedule numbering — confirm against
            # MLlama's _init_fusion_schedule before changing.
            xattn_index = cross_num * cross_attention_frequency + 3
            key_map.extend(
                [
                    (
                        f"language_model.model.layers.{i}.cross_attn.o_proj.weight",
                        f"{prefix}.xattn_layers.{xattn_index}.cross_attention.linear_proj.weight",
                    ),
                    (
                        f"language_model.model.layers.{i}.cross_attn.q_proj.weight",
                        f"{prefix}.xattn_layers.{xattn_index}.cross_attention.linear_q.weight",
                    ),
                    (
                        f"language_model.model.layers.{i}.cross_attn.k_norm.weight",
                        f"{prefix}.xattn_layers.{xattn_index}.cross_attention.k_layernorm.weight",
                    ),
                    (
                        f"language_model.model.layers.{i}.input_layernorm.weight",
                        f"{prefix}.xattn_layers.{xattn_index}.cross_attention.linear_q.layer_norm_weight",
                    ),
                    (
                        f"language_model.model.layers.{i}.cross_attn.q_norm.weight",
                        f"{prefix}.xattn_layers.{xattn_index}.cross_attention.q_layernorm.weight",
                    ),
                    (
                        f"language_model.model.layers.{i}.post_attention_layernorm.weight",
                        f"{prefix}.xattn_layers.{xattn_index}.mlp.linear_fc1.layer_norm_weight",
                    ),
                    (
                        f"language_model.model.layers.{i}.mlp.down_proj.weight",
                        f"{prefix}.xattn_layers.{xattn_index}.mlp.linear_fc2.weight",
                    ),
                    (
                        f"language_model.model.layers.{i}.cross_attn_attn_gate",
                        f"{prefix}.xattn_layers.{xattn_index}.gate_attn",
                    ),
                    (
                        f"language_model.model.layers.{i}.cross_attn_mlp_gate",
                        f"{prefix}.xattn_layers.{xattn_index}.gate_ffn",
                    ),
                ]
            )
        else:
            # Self-attention layer: its NeMo index is the HF index minus the
            # number of cross-attention layers inserted before it.
            attn_index = i - cross_num - 1
            key_map.extend(
                [
                    (
                        f"language_model.model.layers.{i}.self_attn.o_proj.weight",
                        f"{prefix}.layers.{attn_index}.self_attention.linear_proj.weight",
                    ),
                    (
                        f"language_model.model.layers.{i}.post_attention_layernorm.weight",
                        f"{prefix}.layers.{attn_index}.mlp.linear_fc1.layer_norm_weight",
                    ),
                    (
                        f"language_model.model.layers.{i}.mlp.down_proj.weight",
                        f"{prefix}.layers.{attn_index}.mlp.linear_fc2.weight",
                    ),
                    (
                        f"language_model.model.layers.{i}.input_layernorm.weight",
                        f"{prefix}.layers.{attn_index}.self_attention.linear_qkv.layer_norm_weight",
                    ),
                ]
            )

    # Apply all 1:1 renames (raises KeyError if a NeMo key is missing).
    new_state_dict = {}
    for new_key, old_key in key_map:
        new_state_dict[new_key] = state_dict[old_key]

    def convert_vision_qkv_weight(state_dict, vision_model_config):
        """Split the fused vision qkv weights into HF q/k/v entries."""
        hidden_size = vision_model_config.hidden_size

        new_state_dict = {}
        for i in range(vision_model_config.num_layers):
            qkv_weights = state_dict[
                f"vision_model.vision_encoder.transformer.layers.{i}.self_attention.linear_qkv.weight"
            ]

            for name, weight in split_qkv_weight(qkv_weights, vision_model_config):
                new_key = f'vision_model.transformer.layers.{i}.self_attn.{name}.weight'
                # HF stores projections as 2-D (out_features, hidden).
                new_state_dict[new_key] = weight.reshape(-1, hidden_size)

        for i in range(vision_model_config.num_global_layers):
            qkv_weights = state_dict[
                f"vision_model.vision_encoder.global_transformer.layers.{i}.self_attention.linear_qkv.weight"
            ]

            for name, weight in split_qkv_weight(qkv_weights, vision_model_config):
                new_key = f'vision_model.global_transformer.layers.{i}.self_attn.{name}.weight'
                new_state_dict[new_key] = weight.reshape(-1, hidden_size)

        return new_state_dict

    def convert_patch_embeding(state_dict):
        """Reshape NeMo's linear patch embed back to HF's conv-style weight."""
        conv1_weight = state_dict["vision_model.vision_encoder.conv1._linear.weight"]
        # NOTE(review): assumes 3 input channels and a 14x14 patch — matches
        # the 11B config this script targets; confirm for other variants.
        return {"vision_model.patch_embedding.weight": conv1_weight.reshape(conv1_weight.shape[0], 3, 14, 14)}

    def convert_language_qkv_weight(state_dict, language_model_config):
        """Split language fused qkv (self-attn) / kv (cross-attn) weights.

        Uses `toal_num_layer`, `cross_attention_frequency` and `prefix` from
        the enclosing scope; the layer-index arithmetic mirrors the key_map
        loop above.
        """
        hidden_size = language_model_config.hidden_size
        new_state_dict = {}
        for i in range(toal_num_layer):
            cross_num = (i - 3) // (cross_attention_frequency + 1)
            if (i - 3) % (cross_attention_frequency + 1) == 0:
                # Cross-attention layers fuse only k and v (q is separate).
                xattn_index = cross_num * cross_attention_frequency + 3
                kv_weights = state_dict[f"{prefix}.xattn_layers.{xattn_index}.cross_attention.linear_kv.weight"]
                for name, weight in split_kv_weight(kv_weights, language_model_config):
                    new_key = f"language_model.model.layers.{i}.cross_attn.{name}.weight"
                    new_state_dict[new_key] = weight.reshape(-1, hidden_size)
            else:
                attn_index = i - cross_num - 1
                qkv_weights = state_dict[f"{prefix}.layers.{attn_index}.self_attention.linear_qkv.weight"]
                for name, weight in split_qkv_weight(qkv_weights, language_model_config):
                    new_key = f"language_model.model.layers.{i}.self_attn.{name}.weight"
                    new_state_dict[new_key] = weight.reshape(-1, hidden_size)

        return new_state_dict

    def convert_gate(state_dict):
        """Split every fused MLP linear_fc1 into gate_proj / up_proj.

        Same closure-captured layer-index arithmetic as the loops above.
        """
        new_state_dict = {}
        for i in range(toal_num_layer):
            cross_num = (i - 3) // (cross_attention_frequency + 1)
            if (i - 3) % (cross_attention_frequency + 1) == 0:
                xattn_index = cross_num * cross_attention_frequency + 3
                gate_weight = state_dict[f"{prefix}.xattn_layers.{xattn_index}.mlp.linear_fc1.weight"]
            else:
                attn_index = i - cross_num - 1
                gate_weight = state_dict[f"{prefix}.layers.{attn_index}.mlp.linear_fc1.weight"]

            for name, weight in split_gate_weight(gate_weight):
                new_key = f"language_model.model.layers.{i}.mlp.{name}.weight"
                new_state_dict[new_key] = weight

        return new_state_dict

    def convert_embedding(state_dict):
        """Concatenate word + learnable embeddings into HF embed_tokens."""
        word_embeddings = state_dict["language_model.embedding.word_embeddings.weight"]
        learnable_embedding = state_dict["language_model.learnable_embedding.weight"]

        # HF appends the special-token rows after the vocab embedding.
        return {"language_model.model.embed_tokens.weight": torch.cat((word_embeddings, learnable_embedding), dim=0)}

    new_state_dict.update(convert_vision_qkv_weight(state_dict, vision_model_config))
    new_state_dict.update(convert_patch_embeding(state_dict))
    new_state_dict.update(convert_language_qkv_weight(state_dict, language_model_config))
    new_state_dict.update(convert_gate(state_dict))
    new_state_dict.update(convert_embedding(state_dict))

    return new_state_dict, convert_mllama_config(vision_model_config, language_model_config)
|
|