|
|
from dataclasses import dataclass, field |
|
|
from enum import auto |
|
|
from typing import Dict, List, Optional, Union |
|
|
|
|
|
from strenum import LowercaseStrEnum |
|
|
|
|
|
from tensorrt_llm._utils import BaseEnumMeta, DictConversion |
|
|
|
|
|
from .cluster_info import ClusterInfo, cluster_infos |
|
|
|
|
|
|
|
|
class CostModel(LowercaseStrEnum, metaclass=BaseEnumMeta):
    """Selectable cost models (used as defaults by AutoParallelConfig).

    Every member's value is simply its name in lower case. The values are
    written out literally here rather than produced by ``auto()``; both
    forms yield the same strings under ``LowercaseStrEnum``.
    """

    # Member order is significant for iteration; keep it stable.
    ALPHA_BETA = "alpha_beta"
    PROFILE = "profile"
    S_CURVE = "s_curve"
    ZERO = "zero"
|
|
|
|
|
|
|
|
@dataclass
class AutoParallelConfig(DictConversion):
    """Configuration options for auto-parallelization.

    Auto-parallel is reported as enabled only when ``world_size > 1``
    (see :attr:`enabled`). The target cluster is described either directly
    through ``cluster_info`` or indirectly through ``cluster_key``, a key
    into ``cluster_infos`` (see :meth:`get_cluster_info`).
    """

    # --- Cluster description ------------------------------------------------
    # Number of ranks; a value greater than 1 makes `enabled` return True.
    world_size: int = 1
    gpus_per_node: int = 8
    # Key into `cluster_infos`; only consulted when `cluster_info` is None.
    # Annotated Optional because its default is None.
    cluster_key: Optional[str] = None
    # Explicit cluster description; takes precedence over `cluster_key`.
    cluster_info: Optional[ClusterInfo] = None

    # --- Cost models ----------------------------------------------------------
    # Annotated as str; the defaults are CostModel members.
    sharding_cost_model: str = CostModel.ALPHA_BETA
    comm_cost_model: str = CostModel.ALPHA_BETA

    # --- Feature toggles ------------------------------------------------------
    enable_pipeline_parallelism: bool = False
    enable_shard_unbalanced_shape: bool = False
    enable_shard_dynamic_shape: bool = False
    enable_reduce_scatter: bool = True

    # --- Build, debugging and I/O options ------------------------------------
    builder_flags: Optional[int] = None
    debug_mode: bool = False
    infer_shape: bool = True
    validation_mode: bool = False
    # Mutable containers use default_factory so instances never share them.
    same_buffer_io: Dict[str, str] = field(default_factory=dict)
    same_spec_io: Dict[str, str] = field(default_factory=dict)
    sharded_io_allowlist: List[str] = field(default_factory=list)
    fill_weights: bool = False

    # --- Cache / dump locations ----------------------------------------------
    # NOTE(review): these look like filesystem paths — confirm with consumers.
    parallel_config_cache: Optional[str] = None
    profile_cache: Optional[str] = None
    dump_path: Optional[str] = None
    # Accepts either a list of names or a single string.
    debug_outputs: Union[List[str], str] = field(default_factory=list)

    def get_cluster_info(self) -> ClusterInfo:
        """Return the cluster description to use.

        ``cluster_info`` takes precedence when it is set; otherwise
        ``cluster_key`` is looked up in ``cluster_infos``.

        Raises:
            KeyError: if ``cluster_info`` is None and ``cluster_key`` is
                None or not present in ``cluster_infos``.
        """
        # Explicit None check: rely on presence, not on the object's truthiness.
        if self.cluster_info is not None:
            return self.cluster_info
        try:
            return cluster_infos[self.cluster_key]
        except KeyError as err:
            # Same exception type as before, but with an actionable message
            # instead of a bare `KeyError: None`.
            raise KeyError(
                "no cluster info available: cluster_info is not set and "
                f"cluster_key={self.cluster_key!r} is not a known cluster key"
            ) from err

    @property
    def enabled(self) -> bool:
        """True when auto-parallel applies, i.e. ``world_size > 1``."""
        return self.world_size > 1
|
|
|