Aze4ka committed
Commit 1272ff3 · verified · 1 parent: f6a7967

Upload spmd_util.py with huggingface_hub

Files changed (1): spmd_util.py (+97, −0)
spmd_util.py ADDED
import re

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from transformers import GPTNeoXConfig, LlamaConfig, T5Config

# Patterns end with "$" so that LoRA adapter parameters
# (e.g. "...q_proj.lora_A") are not matched and sharded.
GPTNEOX_RULES = (
    # embeddings
    (r"gpt_neox\.embed_in", ("mp", "fsdp")),
    # attention
    (r"attention\.query_key_value$", ("fsdp", "mp")),
    (r"attention\.dense$", ("mp", "fsdp")),
    # mlp
    (r"mlp\.dense_h_to_4h$", ("fsdp", "mp")),
    (r"mlp\.dense_4h_to_h$", ("mp", "fsdp")),
    # output
    ("embed_out", ("fsdp", "mp")),
)

T5_RULES = (
    # embeddings
    ("shared$", ("mp", "fsdp")),
    ("embed_tokens$", ("mp", "fsdp")),

    # attention
    ("q$", ("fsdp", "mp")),
    ("k$", ("fsdp", "mp")),
    ("v$", ("fsdp", "mp")),
    ("o$", ("mp", "fsdp")),

    # mlp
    ("w$", ("fsdp", "mp")),
    ("wi_0$", ("fsdp", "mp")),
    ("wi_1$", ("fsdp", "mp")),
    ("wo$", ("mp", "fsdp")),

    # seq2seq lm head
    ("lm_head", ("fsdp", "mp")),
)

LLAMA_RULES = (
    (r"model\.embed_tokens", ("mp", "fsdp")),
    (r"self_attn\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    (r"self_attn\.o_proj", ("mp", "fsdp")),
    (r"mlp\.gate_proj", ("fsdp", "mp")),
    (r"mlp\.down_proj", ("mp", "fsdp")),
    (r"mlp\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
)

ALL_RULES = [
    (GPTNeoXConfig, GPTNEOX_RULES),
    (T5Config, T5_RULES),
    (LlamaConfig, LLAMA_RULES),
]


def find_rule(model):
    for config, rule in ALL_RULES:
        if model.config.__class__ == config:
            return rule
    raise ValueError(
        f"unsupported model for partitioning: {model.config.__class__.__name__}"
    )


# Mesh axis names mapped to mesh dimension indices.
strkey2id = {
    "dp": 0,
    "fsdp": 1,
    "mp": 2,
}


def partition_module(model, mesh, device=None, verbose=False):
    # Resolve the device lazily instead of evaluating xm.xla_device()
    # in the function signature at import time.
    if device is None:
        device = xm.xla_device()

    partition_specs = find_rule(model)
    rule = [(k, tuple(strkey2id[x] for x in v)) for k, v in partition_specs]

    for name, module in model.named_modules():
        # Moves the model to the XLA device incrementally, module by module.
        module.to(device)
        if isinstance(module, (nn.Embedding, nn.Linear)):
            for rule_pattern, spec in rule:
                if re.search(rule_pattern, name):
                    if verbose:
                        print("match", rule_pattern, name)
                    xs.mark_sharding(module.weight, mesh, spec)
                    break


def partition_module_dp(model, mesh, device=None, verbose=False):
    # Shard every Embedding/Linear weight over the ("fsdp", "mp") mesh axes.
    spec = (1, 2)

    if device is None:
        device = xm.xla_device()

    for name, module in model.named_modules():
        module.to(device)
        if isinstance(module, (nn.Embedding, nn.Linear)):
            xs.mark_sharding(module.weight, mesh, spec)
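
For reference, a minimal usage sketch (not part of the uploaded file): build a three-axis SPMD device mesh whose axis order matches strkey2id, i.e. ("dp", "fsdp", "mp"), then hand the model and mesh to partition_module. The mesh shape and the LLaMA checkpoint below are illustrative assumptions, and depending on the torch_xla release you may first need to enable SPMD mode (e.g. xr.use_spmd() in newer versions, or the XLA_USE_SPMD=1 environment variable).

    # Usage sketch, assuming a torch_xla build with SPMD support.
    import numpy as np
    import torch_xla.runtime as xr
    import torch_xla.experimental.xla_sharding as xs
    from transformers import LlamaForCausalLM

    from spmd_util import partition_module

    num_devices = xr.global_runtime_device_count()
    # Axis order must line up with strkey2id: dp=0, fsdp=1, mp=2.
    mesh = xs.Mesh(
        np.arange(num_devices),
        (1, num_devices, 1),  # illustrative: all devices on the "fsdp" axis
        ("dp", "fsdp", "mp"),
    )

    model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # illustrative checkpoint
    partition_module(model, mesh, verbose=True)

Each rule's spec is applied per weight dimension: xs.mark_sharding shards dimension i of the tensor across the mesh axis named by spec[i], so a Linear weight of shape (out_features, in_features) with ("fsdp", "mp") has its output dimension sharded over the fsdp axis and its input dimension over the mp axis.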