Spaces:
Sleeping
Sleeping
File size: 2,189 Bytes
bdba16a 751368d 399b065 751368d 399b065 751368d 399b065 bdba16a fa34ca2 bdba16a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 | # auto_split_model.py
import torch.nn as nn
import os
import inspect
import requests
#url = 'https://raw.githubusercontent.com/username/repo-name/branch-name/utils.py'
#code = requests.get(url).text
#exec(code) # Dangerous! Only for trusted code.
def extract_all_layers(module):
layers = []
for name, layer in module.named_modules():
# Skip the parent module itself
if isinstance(layer, (nn.Sequential, nn.ModuleList)):
for sublayer in layer:
if isinstance(sublayer, nn.Module):
layers.append(sublayer)
elif len(list(layer.children())) == 0:
layers.append(layer)
return layers
def split_into_chunks(layers, num_parts):
k, m = divmod(len(layers), num_parts)
return [layers[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(num_parts)]
def layer_to_code(layer):
name = layer.__class__.__name__
args = inspect.signature(layer.__class__).parameters
try:
repr_str = repr(layer)
return repr_str.replace('\n', '\n ')
except:
return f"# Could not serialize layer: {name}"
def auto_split_model(model, num_stages=3, output_dir="model_stage_files"):
print("🔍 Extracting layers...")
layers = extract_all_layers(model)
print(f"✅ Total Layers Extracted: {len(layers)}")
print("📦 Splitting into stages...")
layer_chunks = split_into_chunks(layers, num_stages)
os.makedirs(output_dir, exist_ok=True)
for i, chunk in enumerate(layer_chunks):
filename = os.path.join(output_dir, f"node_{i}.py")
with open(filename, "w") as f:
f.write("import torch.nn as nn\n\n")
f.write("def get_model():\n")
f.write(" return nn.Sequential(\n")
for layer in chunk:
f.write(f" {layer_to_code(layer)},\n")
f.write(" )\n")
print(f"✅ Saved stage {i} to {filename}")
print("\n🎉 All model parts saved successfully!")
# === Example Usage ===
if __name__ == "__main__":
from transformers import GPT2Model
model = GPT2Model.from_pretrained("gpt2")
auto_split_model(model, num_stages=3)
|