# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field
from typing import List, Optional, Union


class LLMModelArch:
    """Unique architecture names for text-only LLMs."""
    qwen = 'qwen'
    llama = 'llama'
    internlm2 = 'internlm2'
    chatglm = 'chatglm'
    deepseek_v2 = 'deepseek_v2'
    baichuan = 'baichuan'
    yuan = 'yuan'
    codefuse = 'codefuse'
    phi2 = 'phi2'
    phi3 = 'phi3'
    phi3_small = 'phi3_small'
    telechat = 'telechat'
    dbrx = 'dbrx'


class MLLMModelArch:
    """Unique architecture names for multimodal LLMs."""
    qwen_vl = 'qwen_vl'
    qwen_audio = 'qwen_audio'
    qwen2_vl = 'qwen2_vl'
    qwen2_audio = 'qwen2_audio'
    qwen2_5_omni = 'qwen2_5_omni'

    cogvlm = 'cogvlm'
    glm4v = 'glm4v'
    glm_edge_v = 'glm_edge_v'

    llama3_1_omni = 'llama3_1_omni'
    llama3_2_vision = 'llama3_2_vision'
    llama4 = 'llama4'

    llava_hf = 'llava_hf'
    llava_next_video_hf = 'llava_next_video_hf'
    llava_llama = 'llava_llama'
    llava_mistral = 'llava_mistral'

    xcomposer = 'xcomposer'
    internvl = 'internvl'
    minicpmv = 'minicpmv'
    deepseek_vl = 'deepseek_vl'
    deepseek_vl2 = 'deepseek_vl2'
    deepseek_janus = 'deepseek_janus'

    mplug_owl2 = 'mplug_owl2'
    mplug_owl2_1 = 'mplug_owl2_1'
    mplug_owl3 = 'mplug_owl3'
    doc_owl2 = 'doc_owl2'

    phi3_vision = 'phi3_vision'
    phi4_multimodal = 'phi4_multimodal'
    florence = 'florence'

    idefics3 = 'idefics3'

    got_ocr2 = 'got_ocr2'
    got_ocr2_hf = 'got_ocr2_hf'

    ovis1_6 = 'ovis1_6'
    molmo = 'molmo'
    emu3_chat = 'emu3_chat'
    megrez_omni = 'megrez_omni'
    valley = 'valley'
    gemma3_vision = 'gemma3_vision'
    mistral_2503 = 'mistral_2503'


class ModelArch(LLMModelArch, MLLMModelArch):
    """Union namespace exposing both LLM and MLLM architecture names."""
    pass


@dataclass
class ModelKeys:
    """Dotted module paths of the key submodules of an architecture.

    Paths containing ``{}`` are templates to be formatted with a layer index
    (e.g. ``'model.layers.{}.mlp'``). Fields that do not apply to a given
    architecture stay ``None``.
    """
    arch_name: Optional[str] = None

    embedding: Optional[str] = None
    module_list: Optional[str] = None
    lm_head: Optional[str] = None

    # Attention projections; architectures use either the separate
    # q/k/v projections or one of the fused variants below.
    q_proj: Optional[str] = None
    k_proj: Optional[str] = None
    v_proj: Optional[str] = None
    o_proj: Optional[str] = None
    attention: Optional[str] = None

    mlp: Optional[str] = None
    down_proj: Optional[str] = None

    qkv_proj: Optional[str] = None
    qk_proj: Optional[str] = None
    qa_proj: Optional[str] = None
    qb_proj: Optional[str] = None
    kv_proj: Optional[str] = None
    kva_proj: Optional[str] = None
    kvb_proj: Optional[str] = None


@dataclass
class MultiModelKeys(ModelKeys):
    """ModelKeys extended with the components of a multimodal model.

    Each component field accepts a single dotted path or a list of paths;
    ``__post_init__`` normalizes every field to a list.
    """
    language_model: Union[str, List[str]] = field(default_factory=list)
    aligner: Union[str, List[str]] = field(default_factory=list)
    vision_tower: Union[str, List[str]] = field(default_factory=list)
    generator: Union[str, List[str]] = field(default_factory=list)

    def __post_init__(self):
        # Normalize each component to a list: str -> [str], None -> [].
        for key in ['language_model', 'aligner', 'vision_tower', 'generator']:
            v = getattr(self, key)
            if isinstance(v, str):
                setattr(self, key, [v])
            elif v is None:
                setattr(self, key, [])


MODEL_ARCH_MAPPING = {}


def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
    """Register `model_arch` in MODEL_ARCH_MAPPING under its `arch_name`.

    Args:
        model_arch: A ModelKeys/MultiModelKeys instance describing the module
            paths of the architecture; its `arch_name` is the registry key.
        exist_ok: If False (default), registering an `arch_name` that already
            exists raises ValueError; if True, the old entry is overwritten.
    """
    arch_name = model_arch.arch_name
    if not exist_ok and arch_name in MODEL_ARCH_MAPPING:
        raise ValueError(f'The `{arch_name}` has already been registered in the MODEL_ARCH_MAPPING.')
    MODEL_ARCH_MAPPING[arch_name] = model_arch


register_model_arch(
    ModelKeys(
        LLMModelArch.llama,
        module_list='model.layers',
        mlp='model.layers.{}.mlp',
        down_proj='model.layers.{}.mlp.down_proj',
        attention='model.layers.{}.self_attn',
        o_proj='model.layers.{}.self_attn.o_proj',
        q_proj='model.layers.{}.self_attn.q_proj',
        k_proj='model.layers.{}.self_attn.k_proj',
        v_proj='model.layers.{}.self_attn.v_proj',
        embedding='model.embed_tokens',
        lm_head='lm_head',
    ))

register_model_arch(
    ModelKeys(
        LLMModelArch.internlm2,
        module_list='model.layers',
        mlp='model.layers.{}.feed_forward',
        down_proj='model.layers.{}.feed_forward.w2',
        attention='model.layers.{}.attention',
        o_proj='model.layers.{}.attention.wo',
        qkv_proj='model.layers.{}.attention.wqkv',
        embedding='model.tok_embeddings',
        lm_head='output',
    ))

register_model_arch(
    ModelKeys(
        LLMModelArch.chatglm,
        module_list='transformer.encoder.layers',
        mlp='transformer.encoder.layers.{}.mlp',
        down_proj='transformer.encoder.layers.{}.mlp.dense_4h_to_h',
        attention='transformer.encoder.layers.{}.self_attention',
        o_proj='transformer.encoder.layers.{}.self_attention.dense',
        qkv_proj='transformer.encoder.layers.{}.self_attention.query_key_value',
        embedding='transformer.embedding',
        lm_head='transformer.output_layer'))

register_model_arch(
    ModelKeys(
        LLMModelArch.telechat,
        module_list='transformer.h',
        mlp='transformer.h.{}.mlp',
        down_proj='transformer.h.{}.mlp.down_proj',
        attention='transformer.h.{}.self_attention',
        o_proj='transformer.h.{}.self_attention.dense',
        q_proj='transformer.h.{}.self_attention.query',
        kv_proj='transformer.h.{}.self_attention.key_value',
        embedding='transformer.word_embeddings',
        lm_head='lm_head'))

register_model_arch(
    ModelKeys(
        LLMModelArch.baichuan,
        module_list='model.layers',
        mlp='model.layers.{}.mlp',
        down_proj='model.layers.{}.mlp.down_proj',
        attention='model.layers.{}.self_attn',
        qkv_proj='model.layers.{}.self_attn.W_pack',
        embedding='model.embed_tokens',
        lm_head='lm_head',
    ))

register_model_arch(
    ModelKeys(
        LLMModelArch.yuan,
        module_list='model.layers',
        mlp='model.layers.{}.mlp',
        down_proj='model.layers.{}.mlp.down_proj',
        attention='model.layers.{}.self_attn',
        qk_proj='model.layers.{}.self_attn.qk_proj',
        o_proj='model.layers.{}.self_attn.o_proj',
        q_proj='model.layers.{}.self_attn.q_proj',
        k_proj='model.layers.{}.self_attn.k_proj',
        v_proj='model.layers.{}.self_attn.v_proj',
        embedding='model.embed_tokens',
        lm_head='lm_head',
    ))

register_model_arch(
    ModelKeys(
        LLMModelArch.codefuse,
        module_list='gpt_neox.layers',
        mlp='gpt_neox.layers.{}.mlp',
        down_proj='gpt_neox.layers.{}.mlp.dense_4h_to_h',
        attention='gpt_neox.layers.{}.attention',
        o_proj='gpt_neox.layers.{}.attention.dense',
        qkv_proj='gpt_neox.layers.{}.attention.query_key_value',
        embedding='gpt_neox.embed_in',
        lm_head='gpt_neox.embed_out',
    ))

register_model_arch(
    ModelKeys(
        LLMModelArch.phi2,
        module_list='model.layers',
        mlp='model.layers.{}.mlp',
        down_proj='model.layers.{}.mlp.fc2',
        attention='model.layers.{}.self_attn',
        o_proj='model.layers.{}.self_attn.dense',
        q_proj='model.layers.{}.self_attn.q_proj',
        k_proj='model.layers.{}.self_attn.k_proj',
        v_proj='model.layers.{}.self_attn.v_proj',
        embedding='model.embed_tokens',
        lm_head='lm_head',
    ))

register_model_arch(
    ModelKeys(
        LLMModelArch.qwen,
        module_list='transformer.h',
        mlp='transformer.h.{}.mlp',
        down_proj='transformer.h.{}.mlp.c_proj',
        attention='transformer.h.{}.attn',
        o_proj='transformer.h.{}.attn.c_proj',
        qkv_proj='transformer.h.{}.attn.c_attn',
        embedding='transformer.wte',
        lm_head='lm_head',
    ))

register_model_arch(
    ModelKeys(
        LLMModelArch.dbrx,
        module_list='transformer.blocks',
        mlp='transformer.blocks.{}.ffn',
        attention='transformer.blocks.{}.norm_attn_norm.attn',
        o_proj='transformer.blocks.{}.norm_attn_norm.attn.out_proj',
        qkv_proj='transformer.blocks.{}.norm_attn_norm.attn.Wqkv',
        embedding='transformer.wte',
        lm_head='lm_head',
    ))

register_model_arch(
    ModelKeys(
        LLMModelArch.phi3,
        module_list='model.layers',
        mlp='model.layers.{}.mlp',
        down_proj='model.layers.{}.mlp.down_proj',
        attention='model.layers.{}.self_attn',
        o_proj='model.layers.{}.self_attn.o_proj',
        qkv_proj='model.layers.{}.self_attn.qkv_proj',
        embedding='model.embed_tokens',
        lm_head='lm_head',
    ))

register_model_arch(
    ModelKeys(
        LLMModelArch.phi3_small,
        module_list='model.layers',
        mlp='model.layers.{}.mlp',
        down_proj='model.layers.{}.mlp.down_proj',
        attention='model.layers.{}.self_attn',
        o_proj='model.layers.{}.self_attn.dense',
        qkv_proj='model.layers.{}.self_attn.query_key_value',
        embedding='model.embed_tokens',
        lm_head='lm_head',
    ))

register_model_arch(
    ModelKeys(
        LLMModelArch.deepseek_v2,
        module_list='model.layers',
        mlp='model.layers.{}.mlp',
        down_proj='model.layers.{}.mlp.down_proj',
        attention='model.layers.{}.self_attn',
        o_proj='model.layers.{}.self_attn.o_proj',
        qa_proj='model.layers.{}.self_attn.q_a_proj',
        qb_proj='model.layers.{}.self_attn.q_b_proj',
        kva_proj='model.layers.{}.self_attn.kv_a_proj_with_mqa',
        kvb_proj='model.layers.{}.self_attn.kv_b_proj',
        embedding='model.embed_tokens',
        lm_head='lm_head',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.llava_hf,
        language_model='language_model',
        aligner='multi_modal_projector',
        vision_tower='vision_tower',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.llava_mistral,
        language_model='model.layers',
        aligner='model.mm_projector',
        vision_tower='model.vision_tower',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.llava_next_video_hf,
        language_model='language_model',
        aligner=['multi_modal_projector'],
        vision_tower='vision_tower'))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.llava_llama,
        language_model='model.layers',
        aligner='model.mm_projector',
        vision_tower='model.vision_tower',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.xcomposer,
        language_model='model',
        aligner='vision_proj',
        vision_tower='vit',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.internvl,
        language_model='language_model',
        aligner='mlp1',
        vision_tower='vision_model',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.mplug_owl3,
        language_model='language_model',
        aligner='vision2text_model',
        vision_tower='vision_model',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.doc_owl2,
        language_model='model.layers',
        aligner=['model.vision2text', 'model.hr_compressor'],
        vision_tower='model.vision_model',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.deepseek_vl,
        language_model='language_model',
        aligner='aligner',
        vision_tower='vision_model',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.deepseek_janus,
        language_model='language_model',
        vision_tower='vision_model',
        aligner='aligner',
        generator=['gen_vision_model', 'gen_aligner', 'gen_head', 'gen_embed']))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.deepseek_vl2,
        language_model='language',
        vision_tower='vision',
        aligner='projector',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.minicpmv,
        language_model='llm',
        aligner='resampler',
        vision_tower='vpm',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.phi3_vision,
        language_model='model.layers',
        aligner='model.vision_embed_tokens.img_projection',
        vision_tower='model.vision_embed_tokens.img_processor',
    ))
register_model_arch(
    MultiModelKeys(
        MLLMModelArch.phi4_multimodal,
        language_model='model.layers',
        aligner=[
            'model.embed_tokens_extend.image_embed.img_projection',
            'model.embed_tokens_extend.audio_embed.audio_projection'
        ],
        vision_tower=[
            'model.embed_tokens_extend.image_embed.img_processor',
            'model.embed_tokens_extend.audio_embed.encoder'
        ],
    ))

register_model_arch(MultiModelKeys(
    MLLMModelArch.cogvlm,
    language_model='model.layers',
    vision_tower='model.vision',
))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.florence,
        language_model='language_model',
        vision_tower='vision_tower',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.qwen_vl,
        language_model='transformer.h',
        vision_tower='transformer.visual',
    ))

# TODO: check lm_head, ALL
register_model_arch(
    MultiModelKeys(
        MLLMModelArch.qwen_audio,
        language_model='transformer.h',
        vision_tower='transformer.audio',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.qwen2_audio,
        language_model='language_model',
        aligner='multi_modal_projector',
        vision_tower='audio_tower',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.qwen2_vl,
        language_model='model',
        aligner='visual.merger',
        vision_tower='visual',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.qwen2_5_omni,
        language_model='thinker.model',
        vision_tower=['thinker.audio_tower', 'thinker.visual'],
        aligner=['thinker.audio_tower.proj', 'thinker.visual.merger'],
        generator=['talker', 'token2wav'],
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.glm4v,
        language_model='transformer.encoder',
        vision_tower='transformer.vision',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.idefics3,
        language_model='model.text_model',
        aligner='model.connector',
        vision_tower='model.vision_model',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.llama3_1_omni,
        language_model='model.layers',
        aligner='model.speech_projector',
        vision_tower='model.speech_encoder',
        generator='speech_generator',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.got_ocr2,
        language_model='model.layers',
        aligner='model.mm_projector_vary',
        vision_tower='model.vision_tower_high',
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.llama3_2_vision,
        language_model='language_model',
        aligner='multi_modal_projector',
        vision_tower='vision_model',
    ))

register_model_arch(MultiModelKeys(
    MLLMModelArch.ovis1_6,
    language_model='llm',
    vision_tower='visual_tokenizer',
))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.molmo,
        language_model='model.transformer',
        vision_tower='model.vision_backbone',
        aligner='model.vision_backbone.image_projector'))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.megrez_omni,
        language_model='llm',
        vision_tower=['vision', 'audio'],
    ))

register_model_arch(MultiModelKeys(MLLMModelArch.emu3_chat, language_model='model'))

register_model_arch(
    MultiModelKeys(MLLMModelArch.glm_edge_v, language_model='model.layers', vision_tower='model.vision'))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.valley,
        language_model='model',
        vision_tower=['model.vision_tower', 'model.qwen2vl_vision_tower'],
    ))

register_model_arch(
    MultiModelKeys(
        MLLMModelArch.gemma3_vision,
        language_model='language_model',
        aligner='multi_modal_projector',
        vision_tower='vision_tower',
    ))


def get_model_arch(arch_name: Optional[str]) -> Optional[ModelKeys]:
    """Look up the registered architecture description for `arch_name`.

    Returns the ModelKeys (for LLM arches) or MultiModelKeys (for MLLM
    arches) instance registered in MODEL_ARCH_MAPPING, or None when
    `arch_name` is None or unregistered.
    """
    return MODEL_ARCH_MAPPING.get(arch_name)