# stocker / model_split.py
# Khelendramee's picture
# Update model_split.py
# 399b065 verified
# auto_split_model.py
import torch.nn as nn
import os
import inspect
import requests
#url = 'https://raw.githubusercontent.com/username/repo-name/branch-name/utils.py'
#code = requests.get(url).text
#exec(code) # Dangerous! Only for trusted code.
def extract_all_layers(module):
    """Return the leaf layers of ``module``, each exactly once.

    A leaf layer is a sub-module that has no child modules of its own.
    Layers are returned in the order in which ``module.modules()`` yields
    them.  ``nn.Sequential`` and ``nn.ModuleList`` objects are treated purely
    as containers: they are never returned themselves (not even when empty),
    only the leaves found inside them are.

    The previous implementation appended the direct children of every
    container *and* then appended the same layers again when the recursive
    traversal reached them, so most layers were listed two or more times.

    Args:
        module: The ``nn.Module`` to flatten.  If it has no children and is
            not a container, the result is ``[module]``.

    Returns:
        A list of ``nn.Module`` leaf layers without duplicates.
    """
    containers = (nn.Sequential, nn.ModuleList)
    layers = []
    # modules() already walks the whole tree (root included), so the
    # containers must not be expanded by hand as well.
    for layer in module.modules():
        if isinstance(layer, containers):
            continue
        # A layer is a leaf when its children() iterator is empty.
        if next(layer.children(), None) is None:
            layers.append(layer)
    return layers
def split_into_chunks(layers, num_parts):
    """Partition ``layers`` into ``num_parts`` contiguous slices.

    Slice lengths differ by at most one: the first
    ``len(layers) % num_parts`` slices each receive one extra element.
    When there are fewer elements than parts, the trailing slices are empty.

    Args:
        layers: A sliceable sequence.
        num_parts: Number of slices to produce.

    Returns:
        A list of ``num_parts`` slices of ``layers``, in order.
    """
    base_size, remainder = divmod(len(layers), num_parts)
    chunks = []
    start = 0
    for index in range(num_parts):
        # The first `remainder` chunks absorb the leftover elements.
        stop = start + base_size + (1 if index < remainder else 0)
        chunks.append(layers[start:stop])
        start = stop
    return chunks
def layer_to_code(layer):
    """Return a source-like text representation of ``layer``.

    The text is ``repr(layer)``.  When the layer's class is the one exposed
    under the same name in ``torch.nn``, the text is prefixed with ``nn.``:
    the stage files written by this module only do ``import torch.nn as nn``,
    so an unqualified ``Linear(...)`` would be a ``NameError`` there.
    Newlines inside the representation are followed by one space so that
    continuation lines are indented in the generated file.

    Args:
        layer: The layer object to render.

    Returns:
        The rendered string, or the comment
        ``"# Could not serialize layer: <ClassName>"`` if ``repr`` fails.
    """
    cls = layer.__class__
    name = cls.__name__
    try:
        repr_str = repr(layer)
    except Exception:
        # Best effort: a layer whose repr is broken must not abort the whole
        # export, but KeyboardInterrupt/SystemExit are no longer swallowed.
        return f"# Could not serialize layer: {name}"
    # Qualify only genuine torch.nn classes whose repr starts with the class
    # name; anything else is left exactly as repr() produced it.
    if getattr(nn, name, None) is cls and repr_str.startswith(f"{name}("):
        repr_str = f"nn.{repr_str}"
    return repr_str.replace('\n', '\n ')
def auto_split_model(model, num_stages=3, output_dir="model_stage_files"):
    """Split ``model`` into stages and write one Python file per stage.

    The layers returned by ``extract_all_layers(model)`` are divided into
    ``num_stages`` contiguous groups with ``split_into_chunks``.  For group
    ``i`` a file ``<output_dir>/node_<i>.py`` is written that defines
    ``get_model()`` returning an ``nn.Sequential`` whose entries are the
    ``layer_to_code`` text of each layer in the group.  Progress messages are
    printed to standard output.

    Args:
        model: The module to split.
        num_stages: Number of stage files to produce; must be at least 1.
        output_dir: Directory for the generated files; created if missing.

    Raises:
        ValueError: If ``num_stages`` is smaller than 1.
    """
    # Fail before doing any work: zero previously surfaced as an incidental
    # ZeroDivisionError and a negative count silently wrote no files at all.
    if num_stages < 1:
        raise ValueError(
            f"num_stages must be a positive integer, got {num_stages!r}"
        )

    print("🔍 Extracting layers...")
    layers = extract_all_layers(model)
    print(f"✅ Total Layers Extracted: {len(layers)}")

    print("📦 Splitting into stages...")
    layer_chunks = split_into_chunks(layers, num_stages)

    os.makedirs(output_dir, exist_ok=True)
    for i, chunk in enumerate(layer_chunks):
        filename = os.path.join(output_dir, f"node_{i}.py")
        # Render the whole file first so a rendering failure cannot leave a
        # half-written stage file behind.
        lines = [
            "import torch.nn as nn\n\n",
            "def get_model():\n",
            " return nn.Sequential(\n",
        ]
        lines.extend(f" {layer_to_code(layer)},\n" for layer in chunk)
        lines.append(" )\n")
        # Explicit encoding: the output must not depend on the platform's
        # default locale encoding.
        with open(filename, "w", encoding="utf-8") as f:
            f.writelines(lines)
        print(f"✅ Saved stage {i} to {filename}")

    print("\n🎉 All model parts saved successfully!")
# === Example Usage ===
if __name__ == "__main__":
    # Imported here rather than at module level, so the dependency on
    # transformers only applies when the file is run as a script.
    from transformers import GPT2Model

    # Load the pretrained "gpt2" checkpoint and split it into three stage
    # files in the default output directory.
    auto_split_model(GPT2Model.from_pretrained("gpt2"), num_stages=3)