from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

import torch.nn as nn
from transformers import PretrainedConfig

from swift.llm import MODEL_MAPPING, ModelGroup

# Registry of Megatron model metas, keyed by megatron_model_type.
MEGATRON_MODEL_MAPPING = {}


@dataclass
class MegatronModelMeta:
    """Metadata tying a Megatron model type to the HF model types it supports.

    Attributes:
        megatron_model_type: Key under which this meta is registered.
        model_types: HF/swift model types handled by this Megatron model type.
        model_provider: Factory that builds and returns the Megatron model.
        convert_hf_config: Maps an HF ``PretrainedConfig`` to Megatron config kwargs.
        convert_mcore2hf: Copies weights from a Megatron-Core model into an HF model.
        convert_hf2mcore: Copies weights from an HF model into a Megatron-Core model.
    """
    megatron_model_type: str
    model_types: List[str]

    model_provider: Callable[[], nn.Module]
    convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]]
    convert_mcore2hf: Callable[[nn.Module, nn.Module], None]
    convert_hf2mcore: Callable[[nn.Module, nn.Module], None]


def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: bool = False):
    """Register a MegatronModelMeta and mark its HF model types as Megatron-capable."""
    megatron_model_type = megatron_model_meta.megatron_model_type
    # Flag every covered HF model type as supporting Megatron training.
    for model_type in megatron_model_meta.model_types:
        model_meta = MODEL_MAPPING[model_type]
        model_meta.support_megatron = True
    if not exist_ok and megatron_model_type in MEGATRON_MODEL_MAPPING:
        raise ValueError(f'The `{megatron_model_type}` has already been registered in the MEGATRON_MODEL_MAPPING.')

    MEGATRON_MODEL_MAPPING[megatron_model_type] = megatron_model_meta
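
# A rough usage sketch, kept as a comment. The provider/converter names and the
# 'my-model' model_type below are hypothetical placeholders, not part of swift:
#
#     register_megatron_model(
#         MegatronModelMeta(
#             megatron_model_type='my-megatron-model',
#             model_types=['my-model'],  # must already exist in MODEL_MAPPING
#             model_provider=my_model_provider,
#             convert_hf_config=my_convert_hf_config,
#             convert_mcore2hf=my_convert_mcore2hf,
#             convert_hf2mcore=my_convert_hf2mcore,
#         ))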


# Lazily built cache mapping model_type -> megatron_model_type; filled on first lookup.
_MODEL_META_MAPPING = None


def get_megatron_model_meta(model_type: str) -> Optional[MegatronModelMeta]:
    """Return the MegatronModelMeta covering `model_type`, or None if unsupported."""
    global _MODEL_META_MAPPING
    if _MODEL_META_MAPPING is None:
        # Build the reverse index model_type -> megatron_model_type once.
        _MODEL_META_MAPPING = {}
        for k, megatron_model_meta in MEGATRON_MODEL_MAPPING.items():
            for _model_type in megatron_model_meta.model_types:
                _MODEL_META_MAPPING[_model_type] = k
    if model_type not in _MODEL_META_MAPPING:
        return None
    return MEGATRON_MODEL_MAPPING[_MODEL_META_MAPPING[model_type]]
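
# Hypothetical lookup example (as a comment); 'my-model' is a placeholder model_type:
#
#     meta = get_megatron_model_meta('my-model')
#     if meta is not None:
#         mcore_kwargs = meta.convert_hf_config(hf_config)  # hf_config: a PretrainedConfig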