Spaces:
Sleeping
Sleeping
Create model_split.py
Browse files- model_split.py +58 -0
model_split.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# auto_split_model.py
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import os
|
| 4 |
+
import inspect
|
| 5 |
+
|
| 6 |
+
def extract_all_layers(module):
    """Collect the pipeline-splittable layers of *module*, in traversal order.

    Children of any ``nn.Sequential`` / ``nn.ModuleList`` are kept whole
    (e.g. a transformer block stays one unit); any other module with no
    children is treated as a leaf layer.

    Bug fix: ``named_modules()`` yields both a container AND every module
    inside it, so the original appended each leaf twice — once via its
    enclosing Sequential/ModuleList and once via the leaf branch. A set of
    already-covered module ids now prevents duplicates.
    """
    layers = []
    covered = set()  # id()s of modules already captured (incl. their subtree)
    for name, layer in module.named_modules():
        if id(layer) in covered:
            # Already included as part of a previously captured sub-module.
            continue
        if isinstance(layer, (nn.Sequential, nn.ModuleList)):
            for sublayer in layer:
                if isinstance(sublayer, nn.Module) and id(sublayer) not in covered:
                    layers.append(sublayer)
                    # Mark the whole subtree so descendants are not re-added.
                    covered.update(id(m) for m in sublayer.modules())
        elif len(list(layer.children())) == 0:
            layers.append(layer)
            covered.add(id(layer))
    return layers
|
| 17 |
+
|
| 18 |
+
def split_into_chunks(layers, num_parts):
    """Split *layers* into *num_parts* contiguous chunks of near-equal size.

    When the length is not evenly divisible, the first ``len(layers) %
    num_parts`` chunks receive one extra element; trailing chunks may be
    empty if there are more parts than layers.
    """
    base, extra = divmod(len(layers), num_parts)
    chunks = []
    start = 0
    for part in range(num_parts):
        size = base + (1 if part < extra else 0)
        chunks.append(layers[start:start + size])
        start += size
    return chunks
|
| 21 |
+
|
| 22 |
+
def layer_to_code(layer):
    """Return a source-code string for *layer*, for embedding in a generated
    ``nn.Sequential(...)`` call.

    Relies on the fact that ``repr()`` of the built-in ``nn`` modules
    (e.g. ``Linear(in_features=4, out_features=2, bias=True)``) is valid
    constructor syntax. Falls back to a comment line when serialization
    fails.

    Fixes: the original computed ``inspect.signature(layer.__class__)``
    into an unused variable *outside* the try block, so classes without an
    introspectable signature raised instead of reaching the fallback; the
    bare ``except:`` is narrowed to ``except Exception:``.
    """
    name = layer.__class__.__name__
    try:
        repr_str = repr(layer)
        # Re-indent continuation lines of multi-line reprs so they stay
        # aligned with the 8-space body of the generated Sequential call.
        return repr_str.replace('\n', '\n        ')
    except Exception:
        return f"# Could not serialize layer: {name}"
|
| 30 |
+
|
| 31 |
+
def auto_split_model(model, num_stages=3, output_dir="model_stage_files"):
    """Split *model* into *num_stages* sequential stages and write each stage
    as a standalone module ``node_<i>.py`` inside *output_dir*.

    Each generated file defines ``get_model()`` returning an
    ``nn.Sequential`` rebuilt from repr()-serialized layers (see
    ``layer_to_code``); layers whose repr is not constructor syntax will
    need manual fix-up in the generated file.

    Fixes: the final per-stage print emitted a literal placeholder instead
    of the actual output path — it now reports *filename*.
    """
    print("🔍 Extracting layers...")
    layers = extract_all_layers(model)
    print(f"✅ Total Layers Extracted: {len(layers)}")

    print("📦 Splitting into stages...")
    layer_chunks = split_into_chunks(layers, num_stages)

    os.makedirs(output_dir, exist_ok=True)

    for i, chunk in enumerate(layer_chunks):
        filename = os.path.join(output_dir, f"node_{i}.py")
        with open(filename, "w") as f:
            f.write("import torch.nn as nn\n\n")
            f.write("def get_model():\n")
            f.write("    return nn.Sequential(\n")
            for layer in chunk:
                f.write(f"        {layer_to_code(layer)},\n")
            f.write("    )\n")
        # Report the real path (was a hard-coded placeholder string).
        print(f"✅ Saved stage {i} to {filename}")

    print("\n🎉 All model parts saved successfully!")
|
| 53 |
+
|
| 54 |
+
# === Example Usage ===
if __name__ == "__main__":
    # Demo: split a pretrained GPT-2 into three pipeline-stage files.
    from transformers import GPT2Model

    gpt2 = GPT2Model.from_pretrained("gpt2")
    auto_split_model(gpt2, num_stages=3)
|