File size: 2,189 Bytes
bdba16a
 
 
 
751368d
 
399b065
751368d
399b065
751368d
399b065
bdba16a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa34ca2
 
bdba16a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# auto_split_model.py
import torch.nn as nn
import os
import inspect
import requests

#url = 'https://raw.githubusercontent.com/username/repo-name/branch-name/utils.py'

#code = requests.get(url).text

#exec(code)  # Dangerous! Only for trusted code.

def extract_all_layers(module):
    """Flatten *module* into a flat list of its leaf layers.

    Walks the whole module tree and keeps only modules that have no
    children (e.g. Linear, ReLU), so containers such as nn.Sequential
    and nn.ModuleList are traversed rather than returned.

    BUG FIX: the previous version appended a Sequential's sublayers
    *and* let named_modules() visit those same sublayers again, so
    every leaf inside a Sequential/ModuleList appeared twice. Each
    leaf now appears exactly once, in traversal order.
    """
    # .modules() yields the module itself plus all descendants;
    # "no children" identifies the leaves.
    return [m for m in module.modules() if not list(m.children())]

def split_into_chunks(layers, num_parts):
    """Partition *layers* into *num_parts* contiguous chunks.

    Chunk sizes differ by at most one; the earlier chunks receive the
    extra elements when the split is uneven. Always returns exactly
    num_parts lists (possibly empty ones at the end).
    """
    base, extra = divmod(len(layers), num_parts)
    chunks = []
    start = 0
    for part in range(num_parts):
        # The first `extra` chunks each take one additional element.
        size = base + (1 if part < extra else 0)
        chunks.append(layers[start:start + size])
        start += size
    return chunks

def layer_to_code(layer):
    """Render *layer* as constructor-style source text via its repr().

    nn.Module reprs read like constructor calls (e.g.
    "Linear(in_features=2, out_features=3, bias=True)"), so the repr is
    usable directly inside a generated nn.Sequential(...) literal.
    Multi-line reprs are re-indented to sit inside the generated call.

    BUG FIXES: the old version computed inspect.signature(...) into an
    unused variable *outside* the try block — classes whose signature
    cannot be introspected raised ValueError uncaught — and used a
    bare `except:` that also swallowed KeyboardInterrupt/SystemExit.
    """
    try:
        # 8 spaces matches the indentation of the generated Sequential body.
        return repr(layer).replace('\n', '\n        ')
    except Exception:
        return f"# Could not serialize layer: {layer.__class__.__name__}"

def auto_split_model(model, num_stages=3, output_dir="model_stage_files"):
    """Split *model* into sequential pipeline stages written as Python files.

    Flattens the model's leaf layers, partitions them into *num_stages*
    roughly equal chunks, and writes each chunk to
    ``<output_dir>/node_<i>.py`` — a standalone module exposing
    ``get_model()`` that rebuilds that stage as an nn.Sequential.

    Args:
        model: any nn.Module to flatten and partition.
        num_stages: number of stage files to produce.
        output_dir: destination directory (created if missing).

    NOTE(review): the generated files rely on each layer's repr() being a
    valid constructor expression; custom layers with non-standard reprs
    may produce files that do not import — verify for exotic modules.
    """
    print("🔍 Extracting layers...")
    layers = extract_all_layers(model)
    print(f"✅ Total Layers Extracted: {len(layers)}")

    print("📦 Splitting into stages...")
    layer_chunks = split_into_chunks(layers, num_stages)

    os.makedirs(output_dir, exist_ok=True)

    for i, chunk in enumerate(layer_chunks):
        filename = os.path.join(output_dir, f"node_{i}.py")
        with open(filename, "w") as f:
            f.write("import torch.nn as nn\n\n")
            f.write("def get_model():\n")
            f.write("    return nn.Sequential(\n")
            for layer in chunk:
                f.write(f"        {layer_to_code(layer)},\n")
            f.write("    )\n")
        # BUG FIX: previously printed a literal placeholder instead of the
        # actual output path.
        print(f"✅ Saved stage {i} to {filename}")

    print("\n🎉 All model parts saved successfully!")

# === Example Usage ===
# Downloads GPT-2 weights and splits the model into three stage files.
if __name__ == "__main__":
    from transformers import GPT2Model

    pretrained = GPT2Model.from_pretrained("gpt2")
    auto_split_model(pretrained, num_stages=3)