Buckets:
| # Copyright (c) 2025 SandAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import pytest | |
| import torch | |
| import torch.nn as nn | |
| from tests.model_definition import MLP, MLPConfig, create_mlp_model_with_initial_params | |
| from tests.utils import CleanupCacheContext, enable_remote_debug | |
def train_mlp_model(
    model: MLP,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    num_tokens: int,
    hidden_size: int,
    num_epochs: int,
    batches_per_epoch: int,
    gradient_accumulation_steps: int = 1,
) -> list[float]:
    """Run the MLP training loop, optionally with gradient accumulation.

    Each batch is synthetic: bfloat16 random inputs regressed (MSE) onto an
    all-ones float32 target. Gradients are zeroed at the start of every
    accumulation cycle and the optimizer steps once per completed cycle.

    Args:
        model: MLP model to train.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the synthetic batches are created on.
        num_tokens: Number of tokens (rows) per batch.
        hidden_size: Feature dimension of each batch.
        num_epochs: Number of training epochs.
        batches_per_epoch: Number of batches per epoch.
        gradient_accumulation_steps: Accumulation steps; 1 means no accumulation.

    Returns:
        Average loss per epoch, one entry per epoch.
    """
    epoch_losses: list[float] = []
    print(f"Starting training: {num_epochs} epochs, {batches_per_epoch} batches per epoch")
    if gradient_accumulation_steps > 1:
        print(f"Using gradient accumulation, accumulation steps: {gradient_accumulation_steps}")
    for epoch in range(num_epochs):
        loss_total = 0.0
        for step in range(batches_per_epoch):
            # A new accumulation cycle starts here: clear stale gradients.
            if step % gradient_accumulation_steps == 0:
                optimizer.zero_grad()
            inputs = torch.randn(num_tokens, hidden_size, device=device, dtype=torch.bfloat16)
            targets = torch.ones(num_tokens, hidden_size, device=device, dtype=torch.float32)
            # Scale the loss so the accumulated gradient matches the effective batch.
            scaled_loss = nn.functional.mse_loss(model(inputs), targets) / gradient_accumulation_steps
            scaled_loss.backward()
            # Undo the scaling when logging so the reported loss is per-batch.
            loss_total += scaled_loss.item() * gradient_accumulation_steps
            # Step once per completed accumulation cycle.
            if (step + 1) % gradient_accumulation_steps == 0:
                optimizer.step()
        # Flush a trailing partial cycle so its gradients are not lost.
        if batches_per_epoch % gradient_accumulation_steps != 0:
            optimizer.step()
            optimizer.zero_grad()
        avg_loss = loss_total / batches_per_epoch
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.6f}")
    print("Training completed!")
    return epoch_losses
def verify_model_parameters_updated(
    initial_params: list[torch.Tensor], current_params: list[torch.Tensor], tolerance: float = 1e-6
) -> bool:
    """Check whether any model parameter changed between two snapshots.

    Args:
        initial_params: Parameter snapshot taken before training.
        current_params: Parameters read after training.
        tolerance: Absolute tolerance used when comparing tensor pairs.

    Returns:
        True if at least one corresponding pair differs beyond ``tolerance``,
        False if every pair is (close to) identical.
    """
    # Pairs are compared positionally; any single mismatch means "updated".
    return any(
        not torch.allclose(before, after, atol=tolerance)
        for before, after in zip(initial_params, current_params)
    )
def test_mlp_training_with_magi_compiler():
    """Test MLP training with magi_compiler in training scenario."""
    cuda_device = torch.device("cuda")
    # Tiny model: 8 -> 32 -> 8 in bfloat16 keeps the run fast.
    config = MLPConfig(hidden_size=8, intermediate_size=32, params_dtype=torch.bfloat16)
    # The factory returns the model plus a snapshot of its initial parameters.
    model, params_before = create_mlp_model_with_initial_params(config, cuda_device)
    adamw = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # Minimal workload: 2 epochs x 2 batches of 8 tokens each.
    train_mlp_model(
        model=model,
        optimizer=adamw,
        device=cuda_device,
        num_tokens=8,
        hidden_size=config.hidden_size,
        num_epochs=2,
        batches_per_epoch=2,
    )
    # Training must have moved the weights away from their initial values.
    assert verify_model_parameters_updated(
        initial_params=params_before, current_params=list(model.parameters())
    ), "Model parameters should change after training"
    print("Test passed: Model successfully completed multiple training epochs, parameters have been updated")
if __name__ == "__main__":
    # Usage:
    # ENABLE_REMOTE_DEBUG=true MAGI_ENABLE_FX_GRAPH_VIZ=true TORCH_LOGS=aot CUDA_VISIBLE_DEVICES=1 python pkgs/MagiCompiler/tests/test_mlp_training.py
    # Run the test inside the cache-cleanup context so compiler artifacts
    # from a previous run do not leak into this one.
    with CleanupCacheContext():
        # Presumably gated by the ENABLE_REMOTE_DEBUG env var shown in the
        # usage line above — confirm against tests.utils.
        enable_remote_debug()
        test_mlp_training_with_magi_compiler()
Xet Storage Details
- Size:
- 5.92 kB
- Xet hash:
- 16862bdd85dfd3fe470ccc477240e824a1f7b2531b7d61955ed023dce504989c
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.