"""Pruned FlexOlmo model with variable-width expert 1.

This module provides a HuggingFace-compatible model that can be loaded with:
    AutoModelForCausalLM.from_pretrained("hbfreed/flex-math-8192", trust_remote_code=True)
"""
| |
|
| | import torch |
| | import torch.nn as nn |
| | from transformers import FlexOlmoForCausalLM |
| | from transformers.models.flex_olmo.modeling_flex_olmo import FlexOlmoMLP |
| |
|
| | from .configuration_pruned_flex_olmo import PrunedFlexOlmoConfig |
| |
|
| |
|
class PrunedFlexOlmoMLP(nn.Module):
    """Drop-in replacement for FlexOlmoMLP with a configurable intermediate width.

    Computes the same gated (SwiGLU-style) feed-forward as the original:
        down_proj(act_fn(gate_proj(x)) * up_proj(x))
    Only the intermediate dimension is allowed to differ from the stock expert.
    """

    def __init__(self, intermediate_size: int, hidden_size: int, act_fn, dtype=torch.bfloat16):
        """Create the three bias-free projections of the gated MLP.

        Args:
            intermediate_size: width of the (possibly pruned) hidden expansion.
            hidden_size: model hidden dimension (input and output width).
            act_fn: activation applied to the gate branch (e.g. SiLU).
            dtype: parameter dtype; defaults to bfloat16 to match the base model.
        """
        super().__init__()

        def _proj(fan_in: int, fan_out: int) -> nn.Linear:
            # All FlexOlmo MLP projections are bias-free.
            return nn.Linear(fan_in, fan_out, bias=False, dtype=dtype)

        self.gate_proj = _proj(hidden_size, intermediate_size)
        self.up_proj = _proj(hidden_size, intermediate_size)
        self.down_proj = _proj(intermediate_size, hidden_size)
        self.act_fn = act_fn

    def forward(self, x):
        """Apply the gated feed-forward: down(act(gate(x)) * up(x))."""
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
| |
|
| |
|
class PrunedFlexOlmoForCausalLM(FlexOlmoForCausalLM):
    """FlexOlmo with pruned expert 1 for variable-width MoE.

    Expert 0 remains at full intermediate_size, while expert 1 is pruned
    to expert_1_intermediate_size specified in the config. The narrower
    expert is installed at construction time so that from_pretrained can
    load a checkpoint whose expert-1 weights already have the pruned shape.
    """

    config_class = PrunedFlexOlmoConfig

    def __init__(self, config: PrunedFlexOlmoConfig):
        """Build the base FlexOlmo model, then shrink expert 1 in every layer.

        Args:
            config: Pruned config; must provide ``expert_1_intermediate_size``
                in addition to the standard FlexOlmo fields.
        """
        super().__init__(config)

        expert_1_width = config.expert_1_intermediate_size
        hidden_size = config.hidden_size

        for layer in self.model.layers:
            # Reuse the activation of the expert being replaced so the pruned
            # MLP keeps the model's configured activation function.
            act_fn = layer.mlp.experts[1].act_fn

            # Swap in a narrower MLP; its freshly-initialized weights are
            # overwritten when a checkpoint is loaded via from_pretrained.
            layer.mlp.experts[1] = PrunedFlexOlmoMLP(
                intermediate_size=expert_1_width,
                hidden_size=hidden_size,
                act_fn=act_fn,
                dtype=self.dtype,
            )

    # NOTE(review): the previous from_pretrained override only delegated to
    # super() with unchanged arguments (and its docstring claimed handling it
    # did not perform), so it was removed — the inherited implementation is
    # behaviorally identical.
| |
|