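# NOTE(review): the following lines look like file-viewer page residue, not
# source code; kept verbatim as comments so the module parses.
# rubenaghayan's picture
# fsdp + bugfixes
# f45427d
# raw
# history blame contribute delete
# 717 Bytes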
from dataclasses import dataclass
from dtypes import DType
@dataclass
class Model:
    """Plain record of model-level settings.

    Every field is required (no defaults), so the generated ``__init__``
    accepts them positionally in the order declared below. No validation
    is performed on construction.
    """

    # Integer size/count settings.
    # NOTE(review): names suggest vocabulary size, layer count and layer
    # widths — confirm exact meaning against the code that consumes them.
    vocab_size: int
    num_layers: int
    hidden_dim: int
    intermediate_size: int

    # Boolean flag; presumably whether input and output embeddings share
    # weights — TODO confirm with the consumer.
    weight_tied_embeddings: bool

    # Expert counts plus an explicit MoE switch.
    # NOTE(review): nothing here enforces active_experts <= total_experts or
    # ties these counts to ``is_moe``; callers must keep them consistent.
    active_experts: int
    total_experts: int
    is_moe: bool
@dataclass
class Parallelism:
    """Plain record of parallelism settings.

    Every field is required (no defaults), so the generated ``__init__``
    accepts them positionally in the order declared below. No validation
    is performed on construction.
    """

    # One integer per parallelism axis.
    # NOTE(review): presumably the degree (number of ranks) along each axis,
    # with 1 meaning "disabled" — confirm against the consumer.
    tensor_parallelism: int
    pipeline_parallelism: int
    context_parallelism: int
    expert_parallelism: int

    # FSDP settings: an on/off flag, an integer degree and a strategy name.
    # NOTE(review): ``fsdp_strategy`` is a free-form string and the set of
    # accepted values is not visible here; nothing ties ``fsdp_parallelism``
    # to ``fsdp_enabled`` either.
    fsdp_enabled: bool
    fsdp_parallelism: int
    fsdp_strategy: str
@dataclass
class Training:
    """Plain record of training-run settings.

    Every field is required (no defaults), so the generated ``__init__``
    accepts them positionally in the order declared below. No validation
    is performed on construction. The dtype fields use ``DType`` imported
    from the ``dtypes`` module.
    """

    # Integer batch-shape settings.
    # NOTE(review): units (tokens per sequence, per-device vs. global batch)
    # are not visible here — confirm against the consumer.
    sequence_length: int
    batch_size: int

    # Boolean feature switches.
    # NOTE(review): ``grad_accumulation`` is typed ``bool``; if it is meant to
    # carry a step count rather than an on/off flag, ``int`` would be expected
    # — confirm before changing.
    gradient_checkpointing: bool
    grad_accumulation: bool

    # Numeric-precision settings.
    # NOTE(review): presumably ``param_dtype``/``reduce_dtype`` only matter
    # when ``mixed_precision`` is true — confirm with the consumer.
    precision: DType
    mixed_precision: bool
    param_dtype: DType
    reduce_dtype: DType