hbfreed committed on
Commit
d68f90f
·
verified ·
1 Parent(s): 17c2d94

Upload modeling_pruned_flex_olmo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_pruned_flex_olmo.py +63 -0
modeling_pruned_flex_olmo.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pruned FlexOlmo model with variable-width expert 1.
2
+
3
+ This module provides a HuggingFace-compatible model that can be loaded with:
4
+ AutoModelForCausalLM.from_pretrained("hbfreed/flex-math-8192", trust_remote_code=True)
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from transformers import FlexOlmoForCausalLM
10
+ from transformers.models.flex_olmo.modeling_flex_olmo import FlexOlmoMLP
11
+
12
+ from configuration_pruned_flex_olmo import PrunedFlexOlmoConfig
13
+
14
+
15
class PrunedFlexOlmoMLP(nn.Module):
    """Gated feed-forward block whose intermediate width is set by the caller.

    Exposes ``gate_proj``, ``up_proj``, ``down_proj`` and ``act_fn`` attributes
    (intended as a stand-in for ``FlexOlmoMLP`` with a different width). All
    three projections are bias-free ``nn.Linear`` layers.
    """

    def __init__(self, intermediate_size: int, hidden_size: int, act_fn, dtype=torch.bfloat16):
        """Build the three projection layers.

        Args:
            intermediate_size: Output width of ``gate_proj`` / ``up_proj`` and
                input width of ``down_proj``.
            hidden_size: Input width of ``gate_proj`` / ``up_proj`` and output
                width of ``down_proj``.
            act_fn: Callable applied to the output of ``gate_proj``.
            dtype: Parameter dtype handed to every ``nn.Linear``.
        """
        super().__init__()
        linear_kwargs = {"bias": False, "dtype": dtype}
        # Layers are created in gate, up, down order so that random
        # initialisation consumes the RNG in a fixed sequence.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, **linear_kwargs)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, **linear_kwargs)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, **linear_kwargs)
        self.act_fn = act_fn

    def forward(self, x):
        """Return ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""
        gated = self.act_fn(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gated * lifted)
27
+
28
+
29
class PrunedFlexOlmoForCausalLM(FlexOlmoForCausalLM):
    """FlexOlmo causal LM whose expert 1 is swapped for a narrower MLP.

    After the parent class builds the model, the module at index 1 of every
    layer's ``mlp.experts`` is replaced by a ``PrunedFlexOlmoMLP`` sized by
    ``config.expert_1_intermediate_size``. Other experts are left exactly as
    the parent constructed them.
    """

    config_class = PrunedFlexOlmoConfig

    def __init__(self, config: PrunedFlexOlmoConfig):
        """Construct the parent model, then replace expert 1 in each layer.

        Args:
            config: Configuration providing ``hidden_size`` and
                ``expert_1_intermediate_size``, plus whatever the parent
                constructor reads.
        """
        # The parent constructor builds the unmodified architecture first.
        super().__init__(config)

        pruned_width = config.expert_1_intermediate_size
        model_width = config.hidden_size

        for decoder_layer in self.model.layers:
            experts = decoder_layer.mlp.experts
            # Reuse the activation of the expert being replaced.
            activation = experts[1].act_fn
            experts[1] = PrunedFlexOlmoMLP(
                intermediate_size=pruned_width,
                hidden_size=model_width,
                act_fn=activation,
                dtype=self.dtype,
            )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load pruned model, handling both local and hub paths."""
        # Pure pass-through: loading is delegated to the parent implementation,
        # which instantiates this class (and therefore runs __init__ above).
        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)