# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| from dataclasses import dataclass | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from magi_compiler import magi_compile | |
@dataclass
class MLPConfig:
    """Configuration for the MLP module.

    Bug fix: the ``@dataclass`` decorator was missing even though
    ``dataclass`` is imported at the top of the file — without it the
    annotated fields are never turned into an ``__init__``, so
    ``MLPConfig(hidden_size=..., intermediate_size=...)`` raises TypeError.
    """

    # Width of the residual stream (input/output feature count of the MLP).
    hidden_size: int
    # Width of the up-projected intermediate activation.
    intermediate_size: int
    # Dtype used for the linear-layer parameters.
    params_dtype: torch.dtype = torch.bfloat16
@dataclass
class RMSNormConfig:
    """Configuration for the RMSNorm module.

    Bug fix: the ``@dataclass`` decorator was missing even though
    ``dataclass`` is imported at the top of the file — without it the
    annotated fields are never turned into an ``__init__``, so
    ``RMSNormConfig(hidden_size=...)`` raises TypeError.
    """

    # Size of the normalized (last) dimension.
    hidden_size: int
    # Small constant added to the variance for numerical stability.
    eps: float = 1e-6
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learnable per-channel scale.

    Normalizes the last dimension by the RMS of the input (statistics are
    computed in float32 for stability), multiplies by the learned weight,
    and casts the result back to the caller's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Per-channel scale, initialized to the identity transform.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        # Mean of squares over the last dim, in float32 for numerical stability.
        mean_sq = x.to(torch.float32).square().mean(dim=-1, keepdim=True)
        normed = x * torch.rsqrt(mean_sq + self.eps)
        scaled = normed.to(self.weight.dtype) * self.weight
        return scaled.to(orig_dtype)
class MLP(torch.nn.Module):
    """Feed-forward block: RMSNorm pre-norm, up-projection, SiLU, down-projection.

    Both linear layers are bias-free and stored in ``config.params_dtype``;
    the forward pass alternates bf16/fp32 casts around each stage.
    """

    config: MLPConfig

    def __init__(self, config: MLPConfig):
        super().__init__()
        self.config = config
        # Normalization applied before any projection.
        self.pre_norm = RMSNorm(config.hidden_size)
        # hidden -> intermediate, bias-free, in the configured parameter dtype.
        self.up_proj = nn.Linear(
            config.hidden_size, config.intermediate_size, bias=False, dtype=config.params_dtype
        )
        # intermediate -> hidden, bias-free, in the configured parameter dtype.
        self.down_proj = nn.Linear(
            config.intermediate_size, config.hidden_size, bias=False, dtype=config.params_dtype
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the MLP module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            output (torch.Tensor): Output tensor.

        Shape:
            - x: (num_tokens, hidden_size)
            - output: (num_tokens, hidden_size)
        """
        normed = self.pre_norm(x).to(torch.bfloat16)
        up = self.up_proj(normed).to(torch.float32)
        activated = F.silu(up).to(torch.bfloat16)
        return self.down_proj(activated).to(torch.float32)
class RMSNormModule(torch.nn.Module):
    """Thin nn.Module wrapper around RMSNorm, used as a compilation test target."""

    config: RMSNormConfig

    def __init__(self, config: RMSNormConfig):
        super().__init__()
        self.config = config
        # The wrapped normalization layer; all forward work is delegated to it.
        self.norm = RMSNorm(config.hidden_size, eps=config.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the RMSNorm module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            output (torch.Tensor): Normalized output tensor.

        Shape:
            - x: (num_tokens, hidden_size)
            - output: (num_tokens, hidden_size)
        """
        return self.norm(x)
def create_rms_norm_model(config: RMSNormConfig, device: torch.device) -> RMSNormModule:
    """Instantiate an RMSNormModule and move it to the target device.

    Args:
        config: RMSNorm configuration.
        device: Target device.

    Returns:
        model: The freshly constructed RMSNorm model on ``device``.
    """
    return RMSNormModule(config).to(device)
def create_mlp_model(config: MLPConfig, device: torch.device) -> MLP:
    """Instantiate an MLP and move it to the target device.

    Args:
        config: MLP configuration.
        device: Target device.

    Returns:
        model: The freshly constructed MLP model on ``device``.
    """
    return MLP(config).to(device)
def create_mlp_model_with_initial_params(config: MLPConfig, device: torch.device) -> tuple[MLP, list[torch.Tensor]]:
    """Build an MLP on ``device`` and snapshot its initial parameters.

    Args:
        config: MLP configuration.
        device: Target device.

    Returns:
        model: Created MLP model.
        initial_params: Detached copies of the parameters taken at construction
            time, useful for verifying that optimization actually updates them.
    """
    model = MLP(config).to(device)
    snapshot = [param.detach().clone() for param in model.parameters()]
    return model, snapshot