Spaces:
Sleeping
Sleeping
Update model_split.py
Browse files- model_split.py +9 -2
model_split.py
CHANGED
|
@@ -2,6 +2,13 @@
|
|
| 2 |
import torch.nn as nn
|
| 3 |
import os
|
| 4 |
import inspect
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
def extract_all_layers(module):
|
| 7 |
layers = []
|
|
@@ -53,6 +60,6 @@ def auto_split_model(model, num_stages=3, output_dir="model_stage_files"):
|
|
| 53 |
|
| 54 |
# === Example Usage ===
|
| 55 |
if __name__ == "__main__":
|
| 56 |
-
from transformers import GPT2Model
|
| 57 |
-
model = GPT2Model.from_pretrained("gpt2")
|
| 58 |
auto_split_model(model, num_stages=3)
|
|
|
|
| 2 |
import torch.nn as nn
|
| 3 |
import os
|
| 4 |
import inspect
|
| 5 |
+
import requests
|
| 6 |
+
|
| 7 |
+
url = 'https://raw.githubusercontent.com/username/repo-name/branch-name/utils.py'
|
| 8 |
+
|
| 9 |
+
code = requests.get(url).text
|
| 10 |
+
|
| 11 |
+
exec(code) # Dangerous! Only for trusted code.
|
| 12 |
|
| 13 |
def extract_all_layers(module):
|
| 14 |
layers = []
|
|
|
|
| 60 |
|
| 61 |
# === Example Usage ===
|
| 62 |
if __name__ == "__main__":
|
| 63 |
+
#from transformers import GPT2Model
|
| 64 |
+
#model = GPT2Model.from_pretrained("gpt2")
|
| 65 |
auto_split_model(model, num_stages=3)
|