| # BaseTrainer | |
| ## Trained With [EasyDeL](https://github.com/erfanzar/EasyDeL) | |
| EasyDeL is an open-source framework for training and serving machine learning | |
| models. It focuses on JAX and provides tools for running Flax/JAX models | |
| on TPUs and GPUs. | |
| ## Installation & Usage | |
| ```python | |
| from easydel import AutoEasyDeLModelForCausalLM | |
| from jax import numpy as jnp, lax | |
| model = AutoEasyDeLModelForCausalLM.from_pretrained( | |
| "REPO_ID/BaseTrainer",  # placeholder repository id | |
| dtype=..., | |
| param_dtype=..., | |
| precision=lax.Precision("fastest"), | |
| auto_shard_model=True, | |
| ) | |
| ``` | |
| ## Training Configuration | |
| ### Model Details | |
| - **Architecture**: qwen2 | |
| - **Platform**: TPU | |
| - **Number of Devices**: 16 | |
| ### Training Parameters | |
| - **Learning Rate**: 5e-05 → 5e-06 | |
| - **Optimizer**: adamw | |
| - **Scheduler**: cosine | |
| - **Warmup Steps**: 160 | |
| - **Weight Decay**: 0.02 | |
| - **Loss Config**: LossConfig( | |
| ignore_index: -100 | |
| label_smoothing: 0.0 | |
| z_loss: 0.0 | |
| loss_normalizing_factor: 'NUM_REAL_TARGET_TOKENS' | |
| num_labels: None | |
| problem_type: None | |
| divide_weight_sum: False | |
| shift_tokens: True | |
| break_on_nan: True | |
| reduction: None | |
| num_classification_labels: None | |
| classification_problem_type: None | |
| ) | |
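The learning-rate entries above (5e-05 → 5e-06, cosine scheduler, 160 warmup steps) can be sketched as a plain function of the step index. This is a minimal sketch, not EasyDeL's implementation: it assumes a linear warmup starting from 0 and a cosine decay down to the end value, and `TOTAL_STEPS` is a hypothetical number because the card does not state the total step count.

```python
import math

PEAK_LR = 5e-05      # from the card: starting (peak) learning rate
END_LR = 5e-06       # from the card: final learning rate
WARMUP_STEPS = 160   # from the card
TOTAL_STEPS = 1000   # hypothetical: the card does not give the total step count


def lr_at(step, total_steps=TOTAL_STEPS):
    """Learning rate at `step`, assuming linear warmup then cosine decay."""
    if step < WARMUP_STEPS:
        # Assumed warmup shape: linear ramp from 0 up to the peak value.
        return PEAK_LR * step / WARMUP_STEPS
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - WARMUP_STEPS) / max(1, total_steps - WARMUP_STEPS))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return END_LR + (PEAK_LR - END_LR) * cosine


if __name__ == "__main__":
    for step in (0, 80, 160, 580, 1000):
        print(step, lr_at(step))
```

With Optax, `optax.warmup_cosine_decay_schedule` combined with `optax.adamw(..., weight_decay=0.02)` would express a schedule of the same shape; whether that matches the exact configuration used for this run is not confirmed by the card.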
| ### Training Setup | |
| - **Epochs**: 5 | |
| - **Batch Size**: 16 | |
| - **Sequence Length**: 4096 | |
| - **Dtype**: jax.numpy.bfloat16 | |
| - **Params Dtype**: jax.numpy.bfloat16 | |
| ### Advanced Configuration | |
| - **Gradient Checkpointing**: | |
| - **Gradient Accumulation Steps**: 1 | |
| - **Max Training Steps**: None | |
| - **Max Evaluation Steps**: None | |
| - **Training Duration**: 7 hours | |
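As a quick piece of arithmetic on the setup above, the batch size, sequence length, and gradient accumulation steps bound the number of tokens seen per optimizer update. The sketch assumes the listed batch size of 16 is the global batch (the card does not say whether it is global or per device) and that every sequence is filled to the full 4096 tokens, so the result is an upper bound.

```python
BATCH_SIZE = 16              # from the card; assumed to be the global batch size
SEQUENCE_LENGTH = 4096       # from the card
GRAD_ACCUM_STEPS = 1         # from the card

# Sequences contributing to one optimizer update.
sequences_per_update = BATCH_SIZE * GRAD_ACCUM_STEPS

# Upper bound on tokens per update (padding or ignored labels reduce it).
max_tokens_per_update = sequences_per_update * SEQUENCE_LENGTH

print(sequences_per_update, max_tokens_per_update)  # 16 65536
```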
| ### Sharding Configuration | |
| ```python | |
| # Partition rules: (regex over the parameter path, PartitionSpec) pairs | |
| from jax.sharding import PartitionSpec | |
| partition_rules = ( | |
| ('model/embed_tokens/embedding', PartitionSpec('tp', ('fsdp', 'sp'))), | |
| ('self_attn/(q_proj|k_proj|v_proj)/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')), | |
| ('self_attn/o_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))), | |
| ('mlp/gate_proj/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')), | |
| ('mlp/down_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))), | |
| ('mlp/up_proj/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')), | |
| ('input_layernorm/kernel', PartitionSpec(None,)), | |
| ('post_attention_layernorm/kernel', PartitionSpec(None,)), | |
| ('model/norm/kernel', PartitionSpec(None,)), | |
| ('lm_head/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')), | |
| ('.*', PartitionSpec(None,)), | |
| ) | |
| ``` | |
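To illustrate how a rule list like this one is typically read, the sketch below resolves a parameter path to its spec by taking the first rule whose regex matches. Several things here are assumptions: the first-match-wins behavior, the helper name `spec_for`, and the example parameter paths are all hypothetical, and the specs are written as plain tuples (standing in for `PartitionSpec`) so the sketch runs without JAX installed.

```python
import re

# Rules copied from the card. Each spec is a plain tuple standing in for
# jax.sharding.PartitionSpec so that this sketch has no JAX dependency.
RULES = (
    ("model/embed_tokens/embedding", ("tp", ("fsdp", "sp"))),
    ("self_attn/(q_proj|k_proj|v_proj)/kernel", (("fsdp", "sp"), "tp")),
    ("self_attn/o_proj/kernel", ("tp", ("fsdp", "sp"))),
    ("mlp/gate_proj/kernel", (("fsdp", "sp"), "tp")),
    ("mlp/down_proj/kernel", ("tp", ("fsdp", "sp"))),
    ("mlp/up_proj/kernel", (("fsdp", "sp"), "tp")),
    ("input_layernorm/kernel", (None,)),
    ("post_attention_layernorm/kernel", (None,)),
    ("model/norm/kernel", (None,)),
    ("lm_head/kernel", (("fsdp", "sp"), "tp")),
    (".*", (None,)),
)


def spec_for(param_path, rules=RULES):
    """Return the spec of the first rule whose regex matches `param_path`.

    Hypothetical helper: first-match-wins is an assumption about how such
    rule lists are applied, not a confirmed description of EasyDeL.
    """
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


if __name__ == "__main__":
    # Hypothetical parameter paths, for illustration only.
    for path in (
        "model/layers/0/self_attn/q_proj/kernel",
        "model/layers/0/self_attn/q_proj/bias",
        "model/layers/3/mlp/down_proj/kernel",
    ):
        print(path, "->", spec_for(path))
```

Under this reading, the final `'.*'` rule acts as a catch-all: any parameter not named by an earlier rule (biases, for example) gets `PartitionSpec(None,)` and is replicated rather than sharded.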
| --- | |
| *Generated with EasyDeL v0.1.2* | |