Spaces:
Sleeping
Sleeping
| # auto_split_model.py | |
| import torch.nn as nn | |
| import os | |
| import inspect | |
| import requests | |
| #url = 'https://raw.githubusercontent.com/username/repo-name/branch-name/utils.py' | |
| #code = requests.get(url).text | |
| #exec(code) # Dangerous! Only for trusted code. | |
def extract_all_layers(module):
    """Return every leaf layer of *module* in traversal order.

    A "leaf" is a module with no children (Linear, Conv2d, ReLU, ...).
    ``module.modules()`` already recurses through containers such as
    ``nn.Sequential`` and ``nn.ModuleList``, so no special-casing is needed;
    the previous container branch double-counted every leaf that lived
    inside a container (appended once via the container, once when yielded
    as a leaf itself).

    Args:
        module: Any ``nn.Module`` (the root model).

    Returns:
        list[nn.Module]: the leaf sub-modules, each appearing exactly once.
    """
    return [m for m in module.modules() if len(list(m.children())) == 0]
def split_into_chunks(layers, num_parts):
    """Partition *layers* into *num_parts* contiguous chunks of near-equal size.

    The first ``len(layers) % num_parts`` chunks receive one extra element,
    so sizes differ by at most one. Chunks may be empty when there are
    fewer layers than parts.

    Args:
        layers: Sequence to partition.
        num_parts: Number of chunks to produce.

    Returns:
        list of slices of *layers*, in order, covering every element once.
    """
    base_size, remainder = divmod(len(layers), num_parts)
    chunks = []
    start = 0
    for part in range(num_parts):
        # The first `remainder` chunks absorb one extra element each.
        size = base_size + (1 if part < remainder else 0)
        chunks.append(layers[start:start + size])
        start += size
    return chunks
def layer_to_code(layer):
    """Render *layer* as constructor-like source text via its ``repr``.

    PyTorch modules print as executable-looking constructor calls
    (e.g. ``Linear(in_features=2, out_features=3, bias=True)``), so the
    repr is used directly. Continuation lines of multi-line reprs are
    re-indented so the text nests inside the generated ``nn.Sequential``.

    Args:
        layer: An ``nn.Module`` instance.

    Returns:
        str: the (possibly multi-line) repr text, or a Python comment
        line when computing the repr raises.
    """
    name = layer.__class__.__name__
    # NOTE(review): the previous version also called
    # inspect.signature(layer.__class__) here, bound it to an unused
    # variable, and did so OUTSIDE the try — so classes without an
    # introspectable signature crashed instead of falling back. Dropped.
    try:
        repr_str = repr(layer)
        return repr_str.replace('\n', '\n ')
    except Exception:
        # Narrow catch: a bare `except:` would also swallow
        # KeyboardInterrupt/SystemExit.
        return f"# Could not serialize layer: {name}"
def auto_split_model(model, num_stages=3, output_dir="model_stage_files"):
    """Split *model* into pipeline stages and write one Python file per stage.

    Extracts all leaf layers, partitions them into *num_stages* contiguous
    chunks, and writes ``node_<i>.py`` files, each defining a
    ``get_model()`` that rebuilds its chunk as an ``nn.Sequential``.

    NOTE(review): the generated files reconstruct layers from their repr
    text; this only round-trips for layers whose repr is a valid
    constructor call — custom modules may emit non-executable text.

    Args:
        model: Source ``nn.Module`` to split.
        num_stages: Number of stage files to produce (default 3).
        output_dir: Directory for the generated files; created if missing.
    """
    print("π Extracting layers...")
    layers = extract_all_layers(model)
    print(f"β Total Layers Extracted: {len(layers)}")
    print("π¦ Splitting into stages...")
    layer_chunks = split_into_chunks(layers, num_stages)
    os.makedirs(output_dir, exist_ok=True)
    for i, chunk in enumerate(layer_chunks):
        filename = os.path.join(output_dir, f"node_{i}.py")
        with open(filename, "w") as f:
            f.write("import torch.nn as nn\n\n")
            f.write("def get_model():\n")
            f.write("    return nn.Sequential(\n")
            for layer in chunk:
                f.write(f"        {layer_to_code(layer)},\n")
            f.write("    )\n")
        # Bug fix: previously printed the literal text "(unknown)" instead
        # of the actual path the stage was written to.
        print(f"β Saved stage {i} to {filename}")
    print("\nπ All model parts saved successfully!")
# === Example Usage ===
if __name__ == "__main__":
    # Demo: pull a small pretrained transformer and carve it into 3 stages.
    from transformers import GPT2Model

    gpt2 = GPT2Model.from_pretrained("gpt2")
    auto_split_model(gpt2, num_stages=3)