# BaseTrainer
## 🚀 Trained With [EasyDeL](https://github.com/erfanzar/EasyDeL)
EasyDeL is an open-source framework for streamlining the training of machine learning models.
Built primarily on JAX, it aims to provide convenient and effective tools for
training and serving Flax/JAX models on TPUs and GPUs.
## 📦 Installation & Usage
```python
from easydel import AutoEasyDeLModelForCausalLM
from jax import numpy as jnp, lax

model = AutoEasyDeLModelForCausalLM.from_pretrained(
    "REPO_ID/BaseTrainer",  # replace REPO_ID with the actual repository owner
    dtype=...,  # fill in; this run trained with jnp.bfloat16
    param_dtype=...,  # fill in; this run trained with jnp.bfloat16
    precision=lax.Precision("fastest"),
    auto_shard_model=True,
)
```
## 🔧 Training Configuration
### Model Details
- **Architecture**: gemma3_text
- **Platform**: TPU
- **Number of Devices**: 16
### Training Parameters
- **Learning Rate**: 4e-05 → 4e-06
- **Optimizer**: adamw
- **Scheduler**: cosine
- **Warmup Steps**: 50
- **Weight Decay**: 0.02
- **Loss Config**:
  - `ignore_index`: -100
  - `label_smoothing`: 0.0
  - `z_loss`: 0.0
  - `loss_normalizing_factor`: NUM_REAL_TARGET_TOKENS
  - `num_labels`: None
  - `problem_type`: None
  - `divide_weight_sum`: False
  - `shift_tokens`: True
  - `break_on_nan`: True
  - `reduction`: None
  - `num_classification_labels`: None
  - `classification_problem_type`: None
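To make the loss settings above concrete, here is a minimal pure-Python sketch of a next-token cross-entropy that shifts tokens, skips positions labelled `-100`, and divides by the number of real target tokens. It is an illustration of what those settings conventionally mean, not EasyDeL's implementation; the function name `masked_shifted_cross_entropy` and the toy inputs are hypothetical.

```python
import math

IGNORE_INDEX = -100  # matches ignore_index in the loss config above


def masked_shifted_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean next-token cross-entropy over non-ignored targets.

    logits: list of per-position score lists, shape [seq_len][vocab_size]
    labels: list of token ids, length seq_len
    Returns (mean_loss, number_of_real_targets).
    """
    # shift_tokens=True: position t is scored against the token at t + 1.
    shifted_logits = logits[:-1]
    shifted_labels = labels[1:]

    total_loss = 0.0
    num_real_targets = 0
    for scores, target in zip(shifted_logits, shifted_labels):
        if target == ignore_index:
            continue  # masked positions add neither loss nor count
        # Numerically stable log-sum-exp over the vocabulary.
        max_score = max(scores)
        log_sum_exp = max_score + math.log(
            sum(math.exp(s - max_score) for s in scores)
        )
        total_loss += log_sum_exp - scores[target]
        num_real_targets += 1

    # Normalize by the real target count; guard against an all-masked input.
    return total_loss / max(num_real_targets, 1), num_real_targets


# Toy example: vocabulary of 3 tokens, sequence of 4 positions.
logits = [[0.0, 0.0, 0.0] for _ in range(4)]  # uniform predictions
labels = [0, 1, -100, 2]                      # third token is masked
loss, count = masked_shifted_cross_entropy(logits, labels)
```

With uniform logits every scored position costs `log(vocab_size)`, so the masked position changes the count of targets but not the mean.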
### Training Setup
- **Epochs**: 3
- **Batch Size**: 8
- **Sequence Length**: 8192
- **Dtype**: `jax.numpy.bfloat16`
- **Params Dtype**: `jax.numpy.bfloat16`
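The learning-rate range, cosine scheduler, and warmup steps listed under Training Parameters, together with the epochs and batch size above, can be sketched as a single schedule function. This is a rough illustration under stated assumptions: warmup is taken to be linear from zero, the batch size is treated as the global batch, and `NUM_EXAMPLES` is a hypothetical dataset size used only to derive a step count. The card does not report the real dataset size or the exact schedule shape.

```python
import math

PEAK_LR = 4e-05     # start of the "Learning Rate" range in this card
END_LR = 4e-06      # end of the "Learning Rate" range in this card
WARMUP_STEPS = 50   # "Warmup Steps" in this card


def learning_rate(step, total_steps, peak_lr=PEAK_LR, end_lr=END_LR,
                  warmup_steps=WARMUP_STEPS):
    """Linear warmup from 0 to peak_lr, then cosine decay to end_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    decay_steps = max(total_steps - warmup_steps, 1)
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_lr + (peak_lr - end_lr) * cosine


# Hypothetical dataset size (assumption, not taken from this card).
NUM_EXAMPLES = 8_000
EPOCHS, BATCH_SIZE, SEQ_LEN = 3, 8, 8192

steps_per_epoch = NUM_EXAMPLES // BATCH_SIZE
total_steps = steps_per_epoch * EPOCHS
tokens_per_step = BATCH_SIZE * SEQ_LEN  # upper bound; padding is not excluded

for step in (0, 25, 50, total_steps // 2, total_steps):
    print(f"step {step:>5}: lr = {learning_rate(step, total_steps):.2e}")
```

If the reported batch size is per device rather than global, `tokens_per_step` and the step count scale accordingly.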
### Advanced Configuration
- **Gradient Checkpointing**:
- **Gradient Accumulation Steps**: 1
- **Max Training Steps**: None
- **Max Evaluation Steps**: None
- **Training Duration**: 7H
### Sharding Configuration
```python
from jax.sharding import PartitionSpec

# Partition Rules: (parameter-path pattern, PartitionSpec) pairs
partition_rules = (
    ('model/embed_tokens/embedding', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('self_attn/q_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
    ('self_attn/k_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
    ('self_attn/v_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
    ('self_attn/o_proj/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('mlp/gate_proj/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('mlp/up_proj/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('mlp/down_proj/kernel', PartitionSpec('tp', ('fsdp', 'sp'))),
    ('input_layernorm/kernel', PartitionSpec(None,)),
    ('post_attention_layernorm/kernel', PartitionSpec(None,)),
    ('pre_feedforward_layernorm/kernel', PartitionSpec(None,)),
    ('post_feedforward_layernorm/kernel', PartitionSpec(None,)),
    ('model/norm/kernel', PartitionSpec(None,)),
    ('lm_head/kernel', PartitionSpec(('fsdp', 'sp'), 'tp')),
    ('.*', PartitionSpec(None,)),
)
```
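The partition rules above read as (pattern, spec) pairs with a catch-all `.*` at the end, which suggests first-match-wins regex lookup against each parameter's path. The sketch below illustrates that lookup under that assumption; it is not EasyDeL's own matcher. The helper `match_spec` and the example parameter paths are hypothetical, and the specs are plain tuples standing in for `PartitionSpec` so the snippet runs without JAX.

```python
import re

# Abbreviated copy of the rules above. Specs are plain tuples here
# (stand-ins for PartitionSpec) so the example has no JAX dependency.
FSDP_SP = ("fsdp", "sp")
RULES = (
    ("model/embed_tokens/embedding", (FSDP_SP, "tp")),
    ("self_attn/q_proj/kernel", ("tp", FSDP_SP)),
    ("mlp/down_proj/kernel", ("tp", FSDP_SP)),
    ("input_layernorm/kernel", (None,)),
    (".*", (None,)),  # catch-all: anything unmatched is replicated
)


def match_spec(param_path, rules=RULES):
    """Return the spec of the first rule whose pattern occurs in param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


# Hypothetical parameter paths, for illustration only.
q_proj_spec = match_spec("model/layers/0/self_attn/q_proj/kernel")
fallback_spec = match_spec("model/layers/0/self_attn/q_norm/kernel")
```

Under this reading, rule order matters: moving the `.*` entry to the top would replicate every parameter, since it matches any path.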
---
*Generated with EasyDeL v0.1.3*