Khelendramee commited on
Commit
751368d
·
verified ·
1 Parent(s): f77f8b3

Update model_split.py

Browse files
Files changed (1) hide show
  1. model_split.py +9 -2
model_split.py CHANGED
@@ -2,6 +2,13 @@
2
  import torch.nn as nn
3
  import os
4
  import inspect
 
 
 
 
 
 
 
5
 
6
  def extract_all_layers(module):
7
  layers = []
@@ -53,6 +60,6 @@ def auto_split_model(model, num_stages=3, output_dir="model_stage_files"):
53
 
54
  # === Example Usage ===
55
  if __name__ == "__main__":
56
- from transformers import GPT2Model
57
- model = GPT2Model.from_pretrained("gpt2")
58
  auto_split_model(model, num_stages=3)
 
2
  import torch.nn as nn
3
  import os
4
  import inspect
5
+ import requests
6
+
7
+ url = 'https://raw.githubusercontent.com/username/repo-name/branch-name/utils.py'
8
+
9
+ code = requests.get(url).text
10
+
11
+ exec(code) # Dangerous! Only for trusted code.
12
 
13
  def extract_all_layers(module):
14
  layers = []
 
60
 
61
  # === Example Usage ===
62
  if __name__ == "__main__":
63
+ #from transformers import GPT2Model
64
+ #model = GPT2Model.from_pretrained("gpt2")
65
  auto_split_model(model, num_stages=3)