Khelendramee commited on
Commit
bdba16a
·
verified ·
1 Parent(s): cd67fd0

Create model_split.py

Browse files
Files changed (1) hide show
  1. model_split.py +58 -0
model_split.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # auto_split_model.py
2
+ import torch.nn as nn
3
+ import os
4
+ import inspect
5
+
6
+ def extract_all_layers(module):
7
+ layers = []
8
+ for name, layer in module.named_modules():
9
+ # Skip the parent module itself
10
+ if isinstance(layer, (nn.Sequential, nn.ModuleList)):
11
+ for sublayer in layer:
12
+ if isinstance(sublayer, nn.Module):
13
+ layers.append(sublayer)
14
+ elif len(list(layer.children())) == 0:
15
+ layers.append(layer)
16
+ return layers
17
+
18
+ def split_into_chunks(layers, num_parts):
19
+ k, m = divmod(len(layers), num_parts)
20
+ return [layers[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(num_parts)]
21
+
22
+ def layer_to_code(layer):
23
+ name = layer.__class__.__name__
24
+ args = inspect.signature(layer.__class__).parameters
25
+ try:
26
+ repr_str = repr(layer)
27
+ return repr_str.replace('\n', '\n ')
28
+ except:
29
+ return f"# Could not serialize layer: {name}"
30
+
31
+ def auto_split_model(model, num_stages=3, output_dir="model_stage_files"):
32
+ print("🔍 Extracting layers...")
33
+ layers = extract_all_layers(model)
34
+ print(f"✅ Total Layers Extracted: {len(layers)}")
35
+
36
+ print("📦 Splitting into stages...")
37
+ layer_chunks = split_into_chunks(layers, num_stages)
38
+
39
+ os.makedirs(output_dir, exist_ok=True)
40
+
41
+ for i, chunk in enumerate(layer_chunks):
42
+ filename = os.path.join(output_dir, f"node_{i}.py")
43
+ with open(filename, "w") as f:
44
+ f.write("import torch.nn as nn\n\n")
45
+ f.write("def get_model():\n")
46
+ f.write(" return nn.Sequential(\n")
47
+ for layer in chunk:
48
+ f.write(f" {layer_to_code(layer)},\n")
49
+ f.write(" )\n")
50
+ print(f"✅ Saved stage {i} to {filename}")
51
+
52
+ print("\n🎉 All model parts saved successfully!")
53
+
54
+ # === Example Usage ===
55
+ if __name__ == "__main__":
56
+ from transformers import GPT2Model
57
+ model = GPT2Model.from_pretrained("gpt2")
58
+ auto_split_model(model, num_stages=3)