# auto_split_model.py
"""Split a PyTorch model's leaf layers into N pipeline-stage files.

Each generated ``node_<i>.py`` defines a ``get_model()`` that returns an
``nn.Sequential`` of that stage's layers, reconstructed from ``repr``.
"""
import os

import torch.nn as nn


def extract_all_layers(module):
    """Return the leaf (childless) sub-modules of *module*, in traversal order.

    ``nn.Module.modules()`` already walks the tree recursively, so taking
    only modules without children collects each layer exactly once.
    (The previous version also appended the members of every
    ``nn.Sequential``/``nn.ModuleList`` it encountered, which duplicated
    every layer living inside a container.)
    """
    return [m for m in module.modules() if len(list(m.children())) == 0]


def split_into_chunks(layers, num_parts):
    """Partition *layers* into *num_parts* contiguous, balanced chunks.

    The first ``len(layers) % num_parts`` chunks get one extra element;
    chunks may be empty when there are more parts than layers.
    Raises ZeroDivisionError if ``num_parts`` is 0.
    """
    k, m = divmod(len(layers), num_parts)
    return [layers[i * k + min(i, m):(i + 1) * k + min(i + 1, m)]
            for i in range(num_parts)]


def layer_to_code(layer):
    """Return a source snippet that reconstructs *layer*, or a comment.

    Relies on the torch convention that ``repr(layer)`` mirrors the
    constructor call (e.g. ``Linear(in_features=4, out_features=8,
    bias=True)``); prefixing ``nn.`` makes the snippet valid inside the
    generated file, which imports ``torch.nn as nn``. Custom layers whose
    repr is not a valid constructor call still produce a best-effort
    (possibly broken) snippet.
    """
    try:
        # Indent continuation lines of multi-line reprs so they stay
        # aligned inside the generated nn.Sequential(...) call.
        return "nn." + repr(layer).replace("\n", "\n        ")
    except Exception:  # repr() of exotic modules may fail
        return f"# Could not serialize layer: {layer.__class__.__name__}"


def auto_split_model(model, num_stages=3, output_dir="model_stage_files"):
    """Split *model* into *num_stages* stage files under *output_dir*.

    Extracts the model's leaf layers, partitions them into balanced
    chunks, and writes one ``node_<i>.py`` per chunk, each exposing a
    ``get_model()`` factory. Creates *output_dir* if needed.
    """
    print("šŸ” Extracting layers...")
    layers = extract_all_layers(model)
    print(f"āœ… Total Layers Extracted: {len(layers)}")

    print("šŸ“¦ Splitting into stages...")
    layer_chunks = split_into_chunks(layers, num_stages)

    os.makedirs(output_dir, exist_ok=True)
    for i, chunk in enumerate(layer_chunks):
        filename = os.path.join(output_dir, f"node_{i}.py")
        with open(filename, "w") as f:
            f.write("import torch.nn as nn\n\n")
            f.write("def get_model():\n")
            f.write("    return nn.Sequential(\n")
            for layer in chunk:
                f.write(f"        {layer_to_code(layer)},\n")
            f.write("    )\n")
        # Fixed: previously printed a literal placeholder instead of the path.
        print(f"āœ… Saved stage {i} to {filename}")

    print("\nšŸŽ‰ All model parts saved successfully!")


# === Example Usage ===
if __name__ == "__main__":
    from transformers import GPT2Model

    model = GPT2Model.from_pretrained("gpt2")
    auto_split_model(model, num_stages=3)