zaydzuhri committed
Commit bd301da · verified · 1 Parent(s): 268966a

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. fla/models/abc/configuration_abc.py +91 -0
  2. fla/models/forgetting_transformer/__init__.py +16 -0
  3. fla/models/gla/__init__.py +13 -0
  4. fla/models/nsa/__init__.py +15 -0
  5. fla/models/transformer_top/configuration_transformer.py +76 -0
  6. fla/modules/__pycache__/activations.cpython-312.pyc +0 -0
  7. fla/modules/__pycache__/convolution.cpython-312.pyc +0 -0
  8. fla/modules/__pycache__/fused_linear_cross_entropy.cpython-312.pyc +0 -0
  9. fla/modules/__pycache__/layernorm.cpython-312.pyc +0 -0
  10. fla/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-312.pyc +0 -0
  11. fla/ops/delta_rule/__pycache__/__init__.cpython-312.pyc +0 -0
  12. fla/ops/delta_rule/__pycache__/chunk.cpython-312.pyc +0 -0
  13. fla/ops/delta_rule/__pycache__/fused_chunk.cpython-312.pyc +0 -0
  14. fla/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc +0 -0
  15. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-312.pyc +0 -0
  16. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-312.pyc +0 -0
  17. fla/ops/generalized_delta_rule/dplr/chunk.py +388 -0
  18. fla/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  19. fla/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-312.pyc +0 -0
  20. fla/ops/gsa/__pycache__/chunk.cpython-312.pyc +0 -0
  21. fla/ops/hgrn/__pycache__/chunk.cpython-312.pyc +0 -0
  22. fla/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  23. fla/ops/ttt/__pycache__/chunk.cpython-312.pyc +0 -0
  24. fla/ops/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  25. logs/none_enyj3lod/attempt_0/3/stderr.log +0 -0
  26. profile_trace/iteration_17408/rank4_trace.json +0 -0
  27. profile_trace/iteration_18944/rank2_trace.json +0 -0
  28. profile_trace/iteration_25088/rank3_trace.json +0 -0
  29. profile_trace/iteration_25088/rank7_trace.json +0 -0
  30. profile_trace/iteration_33280/rank6_trace.json +0 -0
  31. profile_trace/iteration_34816/rank5_trace.json +0 -0
  32. profile_trace/iteration_38912/rank1_trace.json +0 -0
  33. profile_trace/iteration_38912/rank2_trace.json +0 -0
  34. profile_trace/iteration_7680/rank0_trace.json +0 -0
  35. profile_trace/iteration_7680/rank4_trace.json +0 -0
  36. torchtitan/components/dataloader.py +92 -0
  37. torchtitan/components/float8.py +150 -0
  38. torchtitan/components/optimizer.py +303 -0
  39. torchtitan/datasets/__pycache__/hf_datasets.cpython-312.pyc +0 -0
  40. torchtitan/datasets/hf_datasets.py +173 -0
  41. torchtitan/datasets/tokenizer/__pycache__/tiktoken.cpython-312.pyc +0 -0
  42. torchtitan/datasets/tokenizer/tiktoken.py +190 -0
  43. torchtitan/distributed/__pycache__/__init__.cpython-312.pyc +0 -0
  44. torchtitan/distributed/__pycache__/utils.cpython-312.pyc +0 -0
  45. torchtitan/experiments/deepseek_v3/inference.sh +15 -0
  46. torchtitan/experiments/deepseek_v3/model_config.py +204 -0
  47. torchtitan/experiments/flux/README.md +23 -0
  48. torchtitan/experiments/flux/__pycache__/parallelize_flux.cpython-312.pyc +0 -0
  49. torchtitan/experiments/flux/flux_argparser.py +42 -0
  50. torchtitan/experiments/flux/loss.py +27 -0
fla/models/abc/configuration_abc.py ADDED
@@ -0,0 +1,91 @@
+ # -*- coding: utf-8 -*-
+
+ from typing import Dict, Optional
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class ABCConfig(PretrainedConfig):
+
+     model_type = 'abc'
+     keys_to_ignore_at_inference = ['past_key_values']
+
+     def __init__(
+         self,
+         hidden_size: int = 2048,
+         gate_low_rank_dim: int = 16,
+         clamp_min: float = -32,
+         clamp_max: float = 32,
+         hidden_ratio: Optional[int] = 4,
+         intermediate_size: Optional[int] = None,
+         num_hidden_layers: int = 24,
+         num_heads: int = 4,
+         num_slots: Optional[int] = 64,
+         use_short_conv: bool = False,
+         conv_size: int = 4,
+         expand_k: float = 0.5,
+         expand_v: float = 1,
+         hidden_act: str = "swish",
+         max_position_embeddings: int = 2048,
+         elementwise_affine: Optional[bool] = True,
+         norm_eps: float = 1e-6,
+         use_rope: bool = True,
+         attn: Optional[Dict] = None,
+         use_cache: bool = True,
+         pad_token_id: Optional[int] = None,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         tie_word_embeddings: bool = False,
+         initializer_range: float = 0.006,
+         fuse_norm: bool = True,
+         fuse_swiglu: bool = True,
+         fuse_cross_entropy: bool = True,
+         vocab_size: int = 32000,
+         **kwargs
+     ):
+         self.hidden_size = hidden_size
+         self.gate_low_rank_dim = gate_low_rank_dim
+         self.clamp_min = clamp_min
+         self.clamp_max = clamp_max
+         self.hidden_ratio = hidden_ratio
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_heads = num_heads
+         self.num_slots = num_slots
+         self.use_short_conv = use_short_conv
+         self.conv_size = conv_size
+         self.expand_k = expand_k
+         self.expand_v = expand_v
+         self.hidden_act = hidden_act
+         self.max_position_embeddings = max_position_embeddings
+         self.elementwise_affine = elementwise_affine
+         self.norm_eps = norm_eps
+         self.use_rope = use_rope
+         self.attn = attn
+         self.use_cache = use_cache
+         self.initializer_range = initializer_range
+
+         self.fuse_norm = fuse_norm
+         self.fuse_swiglu = fuse_swiglu
+         self.fuse_cross_entropy = fuse_cross_entropy
+         self.vocab_size = vocab_size
+
+         if attn is not None:
+             if not isinstance(attn, Dict):
+                 raise ValueError("attn must be a dictionary")
+             if 'layers' not in attn:
+                 raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
+             if 'num_heads' not in attn:
+                 raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
+             attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
+             attn['qkv_bias'] = attn.get('qkv_bias', False)
+             attn['window_size'] = attn.get('window_size', None)
+             attn['rope_theta'] = attn.get('rope_theta', 10000.)
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
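
A minimal usage sketch of the config above; the hybrid-attention defaults are filled in by the validation block, and the import path assumes the repository layout shown in this diff:

from fla.models.abc.configuration_abc import ABCConfig

# plain config
config = ABCConfig(hidden_size=1024, num_hidden_layers=12, num_heads=4)

# hybrid config: softmax attention in layers 0 and 6; missing keys get defaults
hybrid = ABCConfig(attn={'layers': [0, 6], 'num_heads': 8})
assert hybrid.attn['num_kv_heads'] == 8     # defaulted to num_heads
assert hybrid.attn['rope_theta'] == 10000.  # defaulted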
fla/models/forgetting_transformer/__init__.py ADDED
@@ -0,0 +1,16 @@
+ # -*- coding: utf-8 -*-
+
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+ from fla.models.forgetting_transformer.configuration_forgetting_transformer import ForgettingTransformerConfig
+ from fla.models.forgetting_transformer.modeling_forgetting_transformer import (
+     ForgettingTransformerForCausalLM,
+     ForgettingTransformerModel
+ )
+
+ AutoConfig.register(ForgettingTransformerConfig.model_type, ForgettingTransformerConfig)
+ AutoModel.register(ForgettingTransformerConfig, ForgettingTransformerModel)
+ AutoModelForCausalLM.register(ForgettingTransformerConfig, ForgettingTransformerForCausalLM)
+
+
+ __all__ = ['ForgettingTransformerConfig', 'ForgettingTransformerForCausalLM', 'ForgettingTransformerModel']
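
The registration pattern above (repeated for the gla and nsa packages below) lets the transformers Auto classes resolve these architectures by `model_type`. A minimal sketch, assuming the package is importable and that `model_type` is the string 'forgetting_transformer':

import fla.models.forgetting_transformer  # noqa: F401 -- runs the register() calls above
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.for_model('forgetting_transformer')  # resolved via the registry
model = AutoModelForCausalLM.from_config(config)          # randomly initialized model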
fla/models/gla/__init__.py ADDED
@@ -0,0 +1,13 @@
+ # -*- coding: utf-8 -*-
+
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+ from fla.models.gla.configuration_gla import GLAConfig
+ from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel
+
+ AutoConfig.register(GLAConfig.model_type, GLAConfig)
+ AutoModel.register(GLAConfig, GLAModel)
+ AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM)
+
+
+ __all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel']
fla/models/nsa/__init__.py ADDED
@@ -0,0 +1,15 @@
+ # -*- coding: utf-8 -*-
+
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+ from fla.models.nsa.configuration_nsa import NSAConfig
+ from fla.models.nsa.modeling_nsa import NSAForCausalLM, NSAModel
+
+ AutoConfig.register(NSAConfig.model_type, NSAConfig)
+ AutoModel.register(NSAConfig, NSAModel)
+ AutoModelForCausalLM.register(NSAConfig, NSAForCausalLM)
+
+
+ __all__ = [
+     'NSAConfig', 'NSAModel', 'NSAForCausalLM',
+ ]
fla/models/transformer_top/configuration_transformer.py ADDED
@@ -0,0 +1,76 @@
+ # -*- coding: utf-8 -*-
+
+ from typing import Optional
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class TOPTransformerConfig(PretrainedConfig):
+
+     model_type = 'top_transformer'
+     keys_to_ignore_at_inference = ['past_key_values']
+
+     def __init__(
+         self,
+         hidden_size: int = 2048,
+         num_hidden_layers: int = 24,
+         num_heads: int = 32,
+         num_kv_heads: Optional[int] = None,
+         qkv_bias: bool = False,
+         qk_norm: bool = False,
+         window_size: Optional[int] = None,
+         rope_theta: Optional[float] = 10000.,
+         max_position_embeddings: int = 2048,
+         hidden_ratio: Optional[int] = 4,
+         intermediate_size: Optional[int] = None,
+         hidden_act: str = "swish",
+         initializer_range: float = 0.006,
+         elementwise_affine: Optional[bool] = True,
+         norm_eps: float = 1e-6,
+         use_cache: bool = True,
+         pad_token_id: Optional[int] = None,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         tie_word_embeddings: bool = False,
+         fuse_norm: bool = True,
+         fuse_swiglu: bool = True,
+         fuse_cross_entropy: bool = True,
+         vocab_size: int = 32000,
+         use_top_loss: bool = False,
+         top_window_size: Optional[int] = None,
+         **kwargs,
+     ):
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_heads = num_heads
+         self.num_kv_heads = num_kv_heads
+         self.qkv_bias = qkv_bias
+         self.qk_norm = qk_norm
+         self.window_size = window_size
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+
+         self.hidden_ratio = hidden_ratio
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+
+         self.initializer_range = initializer_range
+         self.elementwise_affine = elementwise_affine
+         self.norm_eps = norm_eps
+         self.use_cache = use_cache
+
+         self.fuse_norm = fuse_norm
+         self.fuse_swiglu = fuse_swiglu
+         self.fuse_cross_entropy = fuse_cross_entropy
+         self.vocab_size = vocab_size
+
+         self.use_top_loss = use_top_loss
+         self.top_window_size = top_window_size if top_window_size is not None else max_position_embeddings
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
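
A quick sketch of the `top_window_size` fallback above (import path assumes the repository layout shown in this diff):

from fla.models.transformer_top.configuration_transformer import TOPTransformerConfig

config = TOPTransformerConfig(max_position_embeddings=4096, use_top_loss=True)
assert config.top_window_size == 4096  # falls back to max_position_embeddings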
fla/modules/__pycache__/activations.cpython-312.pyc ADDED
Binary file (23 kB).
 
fla/modules/__pycache__/convolution.cpython-312.pyc ADDED
Binary file (21 kB).
 
fla/modules/__pycache__/fused_linear_cross_entropy.cpython-312.pyc ADDED
Binary file (20.6 kB).
 
fla/modules/__pycache__/layernorm.cpython-312.pyc ADDED
Binary file (43.4 kB).
 
fla/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-312.pyc ADDED
Binary file (6.74 kB).
 
fla/ops/delta_rule/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (361 Bytes).
 
fla/ops/delta_rule/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (13.3 kB).
 
fla/ops/delta_rule/__pycache__/fused_chunk.cpython-312.pyc ADDED
Binary file (392 Bytes).
 
fla/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (14.4 kB).
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-312.pyc ADDED
Binary file (30.6 kB).
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-312.pyc ADDED
Binary file (28 kB).
 
fla/ops/generalized_delta_rule/dplr/chunk.py ADDED
@@ -0,0 +1,388 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional
+
+ import torch
+ import triton
+
+ from fla.ops.common.utils import prepare_chunk_indices
+ from fla.ops.generalized_delta_rule.dplr.chunk_A_bwd import chunk_dplr_bwd_dqk_intra
+ from fla.ops.generalized_delta_rule.dplr.chunk_A_fwd import chunk_fwd_intra_dplr_fn
+ from fla.ops.generalized_delta_rule.dplr.chunk_h_bwd import chunk_dplr_bwd_dhu
+ from fla.ops.generalized_delta_rule.dplr.chunk_h_fwd import chunk_dplr_fwd_h
+ from fla.ops.generalized_delta_rule.dplr.chunk_o_bwd import chunk_dplr_bwd_dAu, chunk_dplr_bwd_dv, chunk_dplr_bwd_o
+ from fla.ops.generalized_delta_rule.dplr.chunk_o_fwd import chunk_dplr_fwd_o
+ from fla.ops.generalized_delta_rule.dplr.wy_fast_bwd import chunk_dplr_bwd_wy
+ from fla.ops.generalized_delta_rule.dplr.wy_fast_fwd import fwd_prepare_wy_repr
+ from fla.ops.rwkv6.chunk import chunk_rwkv6_fwd_cumsum
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
+
+
+ def chunk_dplr_fwd(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     a: torch.Tensor,
+     b: torch.Tensor,
+     gk: torch.Tensor,
+     scale: float,
+     initial_state: torch.Tensor,
+     output_final_state: bool,
+     offsets: Optional[torch.LongTensor] = None,
+     indices: Optional[torch.LongTensor] = None,
+     head_first: bool = True,
+     chunk_size: int = 64
+ ):
+     T = q.shape[2] if head_first else q.shape[1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+     gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT, offsets=offsets, indices=indices, head_first=head_first)
+
+     A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_fwd_intra_dplr_fn(
+         q=q,
+         k=k,
+         a=a,
+         b=b,
+         gi=gi,
+         ge=ge,
+         scale=scale,
+         offsets=offsets,
+         indices=indices,
+         chunk_size=BT,
+         head_first=head_first
+     )
+     del ge
+
+     # A_ab, A_ak, gi, ge: torch.float32
+     # A_qk, A_qb, qg, kg, ag, bg: dtype=q.dtype, e.g. bf16
+     w, u, _ = fwd_prepare_wy_repr(
+         ag=ag,
+         A_ab=A_ab,
+         A_ak=A_ak,
+         v=v,
+         offsets=offsets,
+         indices=indices,
+         head_first=head_first,
+         chunk_size=BT
+     )
+     del A_ab, A_ak
+     h, v_new, final_state = chunk_dplr_fwd_h(
+         kg=kg,
+         bg=bg,
+         v=v,
+         w=w,
+         u=u,
+         gk=gi,
+         initial_state=initial_state,
+         output_final_state=output_final_state,
+         offsets=offsets,
+         indices=indices,
+         head_first=head_first,
+         chunk_size=BT
+     )
+     del u, kg, bg, gi
+
+     o = chunk_dplr_fwd_o(
+         qg=qg,
+         v=v,
+         v_new=v_new,
+         A_qk=A_qk,
+         A_qb=A_qb,
+         h=h,
+         offsets=offsets,
+         indices=indices,
+         head_first=head_first,
+         chunk_size=BT
+     )
+     del v_new, h, A_qk, A_qb
+
+     return o, final_state
+
+
+ class ChunkDPLRDeltaRuleFunction(torch.autograd.Function):
+
+     @staticmethod
+     @input_guard
+     @autocast_custom_fwd
+     def forward(
+         ctx,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         a: torch.Tensor,
+         b: torch.Tensor,
+         gk: torch.Tensor,
+         scale: float,
+         initial_state: torch.Tensor,
+         output_final_state: bool,
+         offsets: Optional[torch.LongTensor] = None,
+         head_first: bool = True
+     ):
+         chunk_size = 16
+
+         # 2-d indices denoting the offsets of chunks in each sequence
+         # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
+         # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
+         # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
+         indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
+
+         o, final_state = chunk_dplr_fwd(
+             q=q,
+             k=k,
+             v=v,
+             a=a,
+             b=b,
+             gk=gk,
+             scale=scale,
+             initial_state=initial_state,
+             output_final_state=output_final_state,
+             offsets=offsets,
+             indices=indices,
+             head_first=head_first,
+             chunk_size=chunk_size
+         )
+         ctx.save_for_backward(q, k, v, a, b, gk, initial_state)
+         ctx.head_first = head_first
+         ctx.offsets = offsets
+         ctx.indices = indices
+         ctx.scale = scale
+         ctx.chunk_size = chunk_size
+         return o.to(q.dtype), final_state
+
+     @staticmethod
+     @input_guard
+     @autocast_custom_bwd
+     def backward(
+         ctx,
+         do: torch.Tensor,
+         dht: torch.Tensor
+     ):
+         q, k, v, a, b, gk, initial_state = ctx.saved_tensors
+         BT = ctx.chunk_size
+         head_first = ctx.head_first
+         offsets = ctx.offsets
+         indices = ctx.indices
+         scale = ctx.scale
+
+         # ******* start of recomputation: caching all forward intermediates would exhaust GPU memory *******
+         gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT, offsets=offsets, indices=indices, head_first=head_first)
+
+         A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_fwd_intra_dplr_fn(
+             q=q,
+             k=k,
+             a=a,
+             b=b,
+             gi=gi,
+             ge=ge,
+             scale=scale,
+             offsets=offsets,
+             indices=indices,
+             chunk_size=BT,
+             head_first=head_first
+         )
+         w, u, A_ab_inv = fwd_prepare_wy_repr(
+             ag=ag,
+             A_ab=A_ab,
+             A_ak=A_ak,
+             v=v,
+             offsets=offsets,
+             indices=indices,
+             head_first=head_first,
+             chunk_size=BT
+         )
+         del A_ab
+         h, v_new, _ = chunk_dplr_fwd_h(
+             kg=kg,
+             bg=bg,
+             v=v,
+             w=w,
+             u=u,
+             gk=gi,
+             initial_state=initial_state,
+             offsets=offsets,
+             indices=indices,
+             head_first=head_first,
+             chunk_size=BT
+         )
+         del u
+         # ******* end of recomputation *******
+         # A_ak, A_ab_inv, gi, ge: torch.float32
+         # A_qk, A_qb, qg, kg, ag, bg, v_new: dtype=q.dtype, e.g. bf16
+
+         dv_new_intra, dA_qk, dA_qb = chunk_dplr_bwd_dAu(
+             v=v,
+             v_new=v_new,
+             do=do,
+             A_qb=A_qb,
+             scale=scale,
+             offsets=offsets,
+             indices=indices,
+             head_first=head_first,
+             chunk_size=BT
+         )
+
+         dh, dh0, dv_new = chunk_dplr_bwd_dhu(
+             qg=qg,
+             bg=bg,
+             w=w,
+             gk=gi,
+             h0=initial_state,
+             dht=dht,
+             do=do,
+             dv=dv_new_intra,
+             offsets=offsets,
+             indices=indices,
+             head_first=head_first,
+             chunk_size=BT
+         )
+
+         dv = chunk_dplr_bwd_dv(
+             A_qk=A_qk,
+             kg=kg,
+             do=do,
+             dh=dh,
+             offsets=offsets,
+             indices=indices,
+             head_first=head_first,
+             chunk_size=BT
+         )
+         del A_qk
+
+         dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o(
+             k=kg,
+             b=bg,
+             v=v,
+             v_new=v_new,
+             do=do,
+             h=h,
+             dh=dh,
+             dv=dv_new,
+             w=w,
+             gk=gi,
+             offsets=offsets,
+             indices=indices,
+             chunk_size=BT,
+             scale=scale,
+             head_first=head_first,
+         )
+         del v_new
+
+         dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy(
+             A_ab_inv=A_ab_inv,
+             A_ak=A_ak,
+             v=v,
+             ag=ag,
+             dw=dw,
+             du=dv_new,
+             dv0=dv,
+             offsets=offsets,
+             indices=indices,
+             head_first=head_first,
+             chunk_size=BT
+         )
+         del A_ak
+
+         dq, dk, da, db, dgk = chunk_dplr_bwd_dqk_intra(
+             q=q,
+             k=k,
+             a=a,
+             b=b,
+             gi=gi,
+             ge=ge,
+             dAqk=dA_qk,
+             dAqb=dA_qb,
+             dAak=dA_ak,
+             dAab=dA_ab,
+             dgk_last=dgk_last,
+             dqg=dqg,
+             dkg=dkg,
+             dag=dag,
+             dbg=dbg,
+             chunk_size=BT,
+             scale=scale,
+             head_first=head_first,
+             offsets=offsets,
+             indices=indices
+         )
+
+         return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), dgk.to(gk), None, dh0, None, None, None
+
+
+ @torch.compiler.disable
+ def chunk_dplr_delta_rule(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     a: torch.Tensor,
+     b: torch.Tensor,
+     gk: torch.Tensor,
+     scale: Optional[float] = None,
+     initial_state: Optional[torch.Tensor] = None,
+     output_final_state: bool = False,
+     cu_seqlens: Optional[torch.LongTensor] = None,
+     head_first: bool = False
+ ):
+     r"""
+     Args:
+         q (torch.Tensor):
+             queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
+         k (torch.Tensor):
+             keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
+         v (torch.Tensor):
+             values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
+         a (torch.Tensor):
+             activations of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
+         b (torch.Tensor):
+             betas of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
+         gk (torch.Tensor):
+             gk of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. Decay term in log space!
+         scale (Optional[float]):
+             Scale factor for the attention scores.
+             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+         initial_state (Optional[torch.Tensor]):
+             Initial state of shape `[N, H, K, V]` for `N` input sequences.
+             For equal-length input sequences, `N` equals the batch size `B`.
+             Default: `None`.
+         output_final_state (Optional[bool]):
+             Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+         cu_seqlens (torch.LongTensor):
+             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+             consistent with the FlashAttention API.
+         head_first (Optional[bool]):
+             Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
+             Default: `False`.
+
+     Returns:
+         o (torch.Tensor):
+             Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
+         final_state (torch.Tensor):
+             Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+     """
+     assert q.dtype == k.dtype == v.dtype
+     # assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
+     # gk = gk.float()
+
+     if cu_seqlens is not None:
+         if q.shape[0] != 1:
+             raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+                              "Please flatten variable-length inputs before processing.")
+         if head_first:
+             raise RuntimeError("Sequences with variable lengths are not supported in head-first mode.")
+         if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+             raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
+                              f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
+     scale = k.shape[-1] ** -0.5 if scale is None else scale
+     o, final_state = ChunkDPLRDeltaRuleFunction.apply(
+         q,
+         k,
+         v,
+         a,
+         b,
+         gk,
+         scale,
+         initial_state,
+         output_final_state,
+         cu_seqlens,
+         head_first
+     )
+     return o, final_state
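
A minimal invocation sketch for the kernel above. Shapes are illustrative, a CUDA device with Triton is assumed, and `gk` must hold log-space decays as the docstring notes:

import torch
import torch.nn.functional as F
from fla.ops.generalized_delta_rule.dplr.chunk import chunk_dplr_delta_rule

B, T, H, K, V = 2, 128, 4, 64, 64
q, k, a, b = (torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16) for _ in range(4))
v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16)
gk = F.logsigmoid(torch.randn(B, T, H, K, device='cuda'))  # decay term in log space

o, final_state = chunk_dplr_delta_rule(q, k, v, a, b, gk, output_final_state=True)
# o: [B, T, H, V] with the default head_first=False; final_state: [B, H, K, V]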
fla/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (27.4 kB).
 
fla/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-312.pyc ADDED
Binary file (23.1 kB).
 
fla/ops/gsa/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (69.4 kB).
 
fla/ops/hgrn/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (16.2 kB).
 
fla/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (14.3 kB).
 
fla/ops/ttt/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (88.1 kB).
 
fla/ops/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.12 kB).
 
logs/none_enyj3lod/attempt_0/3/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_17408/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_18944/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_25088/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_25088/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_33280/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_34816/rank5_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_38912/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_38912/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_7680/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_7680/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
torchtitan/components/dataloader.py ADDED
@@ -0,0 +1,92 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+ #
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+ import pickle
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable
+ from typing import Any
+
+ from torch.distributed.checkpoint.stateful import Stateful
+ from torch.utils.data import IterableDataset
+ from torchdata.stateful_dataloader import StatefulDataLoader
+ from torchtitan.tools.logging import logger
+
+
+ class BaseDataLoader(Stateful, ABC):
+     """Base class for all dataloaders.
+
+     This is used to enforce that all dataloaders have the methods defined in ``Stateful``,
+     ``state_dict()`` and ``load_state_dict()``.
+     """
+
+     @abstractmethod
+     def __iter__(self):
+         ...
+
+
+ class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader):
+     """Dataloader that is aware of distributed data parallelism.
+
+     This dataloader is used to load data in a distributed data parallel fashion. It also
+     utilizes ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary
+     methods such as ``__iter__``.
+
+     Args:
+         dataset (IterableDataset): The dataset to iterate over.
+         dp_rank: Data parallelism rank for this dataloader.
+         dp_world_size: The world size of the data parallelism.
+         batch_size: The batch size to use for each iteration.
+         collate_fn: Optional function to collate samples in a batch.
+     """
+
+     dp_rank: int
+     dp_world_size: int
+     batch_size: int
+
+     def __init__(
+         self,
+         dataset: IterableDataset,
+         dp_rank: int,
+         dp_world_size: int,
+         batch_size: int,
+         collate_fn: Callable | None = None,
+     ):
+         self.dp_world_size = dp_world_size
+         self.dp_rank = dp_rank
+         self.batch_size = batch_size
+         super().__init__(dataset, batch_size, collate_fn=collate_fn)
+         self._rank_id = f"dp_rank_{dp_rank}"
+
+     def state_dict(self) -> dict[str, Any]:
+         # Store state only for dp rank to avoid replicating the same state across other dimensions.
+         return {
+             # We don't have to use pickle as DCP will serialize the state_dict. However,
+             # we have to keep this for backward compatibility.
+             self._rank_id: pickle.dumps(super().state_dict()),
+             "world_size": self.dp_world_size,
+         }
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         # State being empty is valid.
+         if not state_dict:
+             return
+
+         if self._rank_id not in state_dict:
+             logger.warning(
+                 f"DataLoader state is empty for dp rank {self.dp_rank}, "
+                 f"expected key {self._rank_id}"
+             )
+             return
+
+         assert self.dp_world_size == state_dict["world_size"], (
+             "dp_degree is inconsistent before and after checkpoint, "
+             "dataloader resharding is not supported yet."
+         )
+         # We don't have to use pickle as DCP will serialize the state_dict. However, we have to
+         # keep this for backward compatibility.
+         super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
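
A checkpoint round-trip sketch for the class above, using a toy IterableDataset (hypothetical) in place of a real tokenized stream:

from torch.utils.data import IterableDataset

class ToyDataset(IterableDataset):
    def __iter__(self):
        yield from range(100)

loader = ParallelAwareDataloader(ToyDataset(), dp_rank=0, dp_world_size=2, batch_size=4)
_ = next(iter(loader))             # consume one batch
snapshot = loader.state_dict()     # {'dp_rank_0': <pickled loader state>, 'world_size': 2}
loader.load_state_dict(snapshot)   # resumes from the captured position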
torchtitan/components/float8.py ADDED
@@ -0,0 +1,150 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # [Note] Getting the 'torchao' package:
+ # This script requires the 'torchao' package to function correctly.
+ # Please ensure you have this package installed from the appropriate repository.
+ # You can obtain it from https://github.com/pytorch/ao by following the
+ # installation instructions.
+
+ # Note: Performance
+ # Float8 experimental is intended to be run under `torch.compile` for competitive performance
+
+ import torch
+ import torch.nn as nn
+
+ from torchtitan.config_manager import JobConfig
+ from torchtitan.distributed import ParallelDims
+ from torchtitan.protocols.model_converter import (
+     ModelConverter,
+     register_model_converter,
+ )
+ from torchtitan.tools.logging import logger
+
+
+ def _is_sm89_or_later():
+     # Float8 is only supported on SM89 or later (Ada/Hopper and newer GPUs)
+     return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+
+
+ class Float8Converter(ModelConverter):
+     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
+         self.enabled = False
+
+         float8_config = job_config.float8
+         if not _is_sm89_or_later():
+             logger.warning(
+                 "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
+             )
+             return
+         try:
+             from torchao.float8 import Float8LinearConfig
+         except ImportError as e:
+             raise ImportError(
+                 "torchao is not installed. Please install it to use float8 linear layers."
+             ) from e
+
+         if float8_config.recipe_name is not None and not hasattr(
+             Float8LinearConfig, "from_recipe_name"
+         ):
+             logger.warning(
+                 "Failed to swap to Float8Linear with recipe lookup because the torchao version "
+                 "is too old; please install torchao v0.9.0 or later and try again",
+             )
+             return
+
+         self.enabled = True
+         self.filter_fqns = float8_config.filter_fqns
+
+         if float8_config.recipe_name is not None:
+             assert (
+                 not float8_config.enable_fsdp_float8_all_gather
+             ), "using `float8_config.enable_fsdp_float8_all_gather` together with `float8_config.recipe_name` is not supported"
+             assert (
+                 not float8_config.force_recompute_fp8_weight_in_bwd
+             ), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported"
+             self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)
+             self.precompute_scale = False
+             logger.info(
+                 f"Float8 training active with recipe {float8_config.recipe_name}"
+             )
+
+         else:
+             # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+             enable_fsdp_float8_all_gather = (
+                 parallel_dims.dp_shard_enabled
+                 and float8_config.enable_fsdp_float8_all_gather
+             )
+             self.config = Float8LinearConfig(
+                 enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+                 force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd,
+             )
+             # for precompute_float8_dynamic_scale_for_fsdp
+             self.precompute_scale = (
+                 enable_fsdp_float8_all_gather
+                 and float8_config.precompute_float8_dynamic_scale_for_fsdp
+             )
+             logger.info("Float8 tensorwise scaled training active")
+
+     def convert(self, model: nn.Module):
+         return self.convert_to_float8_training(model)
+
+     def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
+         return self.precompute_float8_dynamic_scale_for_fsdp(model)
+
+     def convert_to_float8_training(self, model: nn.Module):
+         """
+         This function converts the linear layers of `model` to `Float8Linear`.
+         Note that today, only dynamic tensor scaling (the default) is supported.
+         This will mutate the model inplace.
+         """
+         if not self.enabled:
+             return
+
+         from torchao.float8 import convert_to_float8_training
+
+         # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
+         convert_to_float8_training(
+             model,
+             config=self.config,
+             module_filter_fn=self._module_filter_fn,
+         )
+         logger.info(
+             "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
+             f"{self.config.enable_fsdp_float8_all_gather}"
+         )
+
+     def _module_filter_fn(self, mod: nn.Module, fqn: str) -> bool:
+         if not isinstance(mod, nn.Linear):
+             return False
+
+         # All dims must be divisible by 16 due to float8 tensorcore hardware requirements.
+         dims_multiples_of_16 = (
+             mod.weight.shape[0] % 16 == 0 and mod.weight.shape[1] % 16 == 0
+         )
+
+         # If the fqn matches any filtered fqn, then we should not convert this module.
+         is_filtered_fqn = any(filtered_fqn in fqn for filtered_fqn in self.filter_fqns)
+
+         return dims_multiples_of_16 and not is_filtered_fqn
+
+     def precompute_float8_dynamic_scale_for_fsdp(
+         self, model: nn.Module | list[nn.Module]
+     ):
+         if not self.enabled:
+             return
+
+         if not self.precompute_scale:
+             return
+
+         from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
+
+         models = [model] if isinstance(model, nn.Module) else model
+         for m in models:
+             precompute_float8_dynamic_scale_for_fsdp(m)
+
+
+ register_model_converter(Float8Converter, "float8")
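
For intuition about `_module_filter_fn` above: only `nn.Linear` modules whose weight dims are both divisible by 16 (a float8 tensor-core requirement) and whose FQN is not filtered get swapped. A small sketch with a hypothetical helper that mirrors just the shape check:

import torch.nn as nn

def eligible_for_float8(mod: nn.Module) -> bool:
    # mirrors the shape check in Float8Converter._module_filter_fn
    return (
        isinstance(mod, nn.Linear)
        and mod.weight.shape[0] % 16 == 0
        and mod.weight.shape[1] % 16 == 0
    )

assert eligible_for_float8(nn.Linear(1024, 4096))      # both dims divisible by 16
assert not eligible_for_float8(nn.Linear(1000, 4096))  # 1000 % 16 != 0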
torchtitan/components/optimizer.py ADDED
@@ -0,0 +1,303 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import functools
+ from typing import Any, Generic, Iterator, TypeVar
+
+ import torch
+ import torch.nn as nn
+ from torch.distributed.checkpoint.state_dict import (
+     get_optimizer_state_dict,
+     set_optimizer_state_dict,
+     StateDictOptions,
+ )
+ from torch.distributed.checkpoint.stateful import Stateful
+ from torch.optim import Optimizer
+
+ from torchtitan.components.ft import FTManager, has_torchft
+ from torchtitan.config_manager import JobConfig
+
+ __all__ = [
+     "OptimizersContainer",
+     "build_optimizers",
+ ]
+
+
+ if has_torchft:
+     import torchft as ft
+
+
+ T = TypeVar("T", bound=Optimizer)
+
+
+ class OptimizersContainer(Optimizer, Stateful, Generic[T]):
+     """A container for multiple optimizers.
+
+     This class is used to wrap multiple optimizers into a single object that can be
+     used to reduce the complexity of the training loop. This mimics the behavior of
+     ``torch.optim.Optimizer``. This class currently only supports ``Adam`` and ``AdamW``.
+
+     **Note**
+     Users who want to customize the optimizer behavior can inherit from this class and
+     extend the functionality as needed. The following methods must follow the same signature
+     as the ``torch.optim.Optimizer`` class: ``step()``, ``zero_grad()``, ``state_dict()``,
+     ``load_state_dict()``.
+
+     **Limitations**
+     This class assumes that all the optimizers are the same type and have the same
+     configurations. With this assumption, TorchTitan can support lr scheduler resharding
+     (e.g., loading a checkpoint with a different number of GPUs and/or different
+     parallelization strategy). Note that ``get_optimizer_state_dict`` already enables the
+     resharding for the optimizer state but not for the lr scheduler state, hence the limitation.
+
+     Args:
+         model_parts (List[nn.Module]): List of model parts to be optimized.
+         optimizer_kwargs (Dict[str, Any]): Keyword arguments for the optimizers.
+         name (str): Name of the optimizers.
+     """
+
+     optimizers: list[T]
+     model_parts: list[nn.Module]
+
+     def __init__(
+         self,
+         model_parts: list[nn.Module],
+         optimizer_cls: type[T],
+         optimizer_kwargs: dict[str, Any],
+     ) -> None:
+         all_params = []
+         self.optimizers = []
+         self.model_parts = model_parts
+         for model in self.model_parts:
+             params = [p for p in model.parameters() if p.requires_grad]
+             self.optimizers.append(optimizer_cls(params, **optimizer_kwargs))
+             all_params.extend(params)
+         self._validate_length(len(self.model_parts))
+         self._post_init(all_params, optimizer_kwargs)
+
+     def __iter__(self) -> Iterator[T]:
+         return iter(self.optimizers)
+
+     def __len__(self) -> int:
+         return len(self.optimizers)
+
+     def step(self, *args, **kwargs) -> None:
+         for optimizer in self.optimizers:
+             optimizer.step(*args, **kwargs)
+
+     def zero_grad(self, *args, **kwargs) -> None:
+         for optimizer in self.optimizers:
+             optimizer.zero_grad(*args, **kwargs)
+
+     def state_dict(self) -> dict[str, Any]:
+         func = functools.partial(
+             get_optimizer_state_dict,
+             options=StateDictOptions(flatten_optimizer_state_dict=True),
+         )
+         return {
+             k: v
+             for sd in map(func, self.model_parts, self.optimizers)
+             for k, v in sd.items()
+         }
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         func = functools.partial(
+             set_optimizer_state_dict,
+             optim_state_dict=state_dict,
+             options=StateDictOptions(flatten_optimizer_state_dict=True),
+         )
+         list(map(func, self.model_parts, self.optimizers))
+
+     def _validate_length(self, expected_length: int) -> None:
+         assert expected_length == len(self.optimizers), (
+             "Must pass one optimizer per model part or per param if "
+             "using OptimizersInBackwardContainer."
+         )
+
+     def _post_init(
+         self, all_params: list[nn.Parameter], optimizer_kwargs: dict[str, Any]
+     ) -> None:
+         # We need to call Optimizer.__init__() to initialize some necessary optimizer
+         # functionality such as hooks.
+         Optimizer.__init__(self, all_params, optimizer_kwargs)
+
+
+ class OptimizersInBackwardContainer(OptimizersContainer):
+     """OptimizersContainer for executing ``optim.step()`` in the backward pass.
+
+     This class extends ``OptimizersContainer`` to support optimizer steps in the
+     backward pass. ``step()`` and ``zero_grad()`` are no-ops in this class.
+     Instead, ``register_post_accumulate_grad_hook`` is used to register a hook that
+     executes these methods when the gradient is accumulated.
+     """
+
+     def __init__(
+         self,
+         model_parts: list[nn.Module],
+         optimizer_cls: type[T],
+         optimizer_kwargs: dict[str, Any],
+     ) -> None:
+         all_params = []
+         self.model_parts = model_parts
+
+         optim_dict = {}
+         for model in self.model_parts:
+             for p in model.parameters():
+                 if p.requires_grad:
+                     optim_dict[p] = optimizer_cls([p], **optimizer_kwargs)
+                     all_params.append(p)
+
+         def optim_hook(param) -> None:
+             optim_dict[param].step()
+             optim_dict[param].zero_grad()
+
+         for model in self.model_parts:
+             for param in model.parameters():
+                 if param.requires_grad:
+                     param.register_post_accumulate_grad_hook(optim_hook)
+
+         self.optimizers = list(optim_dict.values())
+
+         self._validate_length(
+             sum(len(list(model.parameters())) for model in self.model_parts)
+         )
+         self._post_init(all_params, optimizer_kwargs)
+
+     def step(self) -> None:
+         pass
+
+     def zero_grad(self) -> None:
+         pass
+
+
+ class FTOptimizersContainer(OptimizersContainer):
+     def __init__(
+         self,
+         model_parts: list[nn.Module],
+         optimizer_cls: type[T],
+         optimizer_kwargs: dict[str, Any],
+         ft_manager: "ft.Manager",
+     ) -> None:
+         super().__init__(model_parts, optimizer_cls, optimizer_kwargs)
+
+         # Force to initialize the optimizer state so that `optim.step()`
+         # won't be called by state_dict() and load_state_dict().
+         _ = {
+             k: v
+             for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
+             for k, v in sd.items()
+         }
+         self.cache_state_dict: dict[str, Any] = {}
+         self._ft_optimizer = ft.Optimizer(ft_manager, self)
+         self._call_from_ft: bool = False
+
+     def init_cache_state_dict(self) -> None:
+         self.cache_state_dict = super().state_dict()
+
+     def state_dict(self) -> dict[str, Any]:
+         return self.cache_state_dict
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         # We have to invalidate the `cache_state_dict` because the optimizer uses
+         # assignment instead of copy when doing `load_state_dict()`. Without
+         # invalidating the `cache_state_dict`, there will be memory leakage.
+         self.cache_state_dict = {}
+         super().load_state_dict(state_dict)
+         self.init_cache_state_dict()
+
+     def step(self, *args, **kwargs) -> None:
+         """Calling the correct step() depending on the caller.
+
+         TorchFT's OptimizerWrapper.step() is designed to be called only once
+         per train step per ft.Manager, regardless of how many optimizers are used.
+         Hence we need to dispatch the call appropriately.
+         """
+         if self._call_from_ft:
+             super().step(*args, **kwargs)
+         else:
+             self._call_from_ft = True
+             self._ft_optimizer.step(*args, **kwargs)
+             self._call_from_ft = False
+
+     def zero_grad(self, *args, **kwargs) -> None:
+         """Calling the correct zero_grad() depending on the caller.
+
+         Check the comment in ``step()``.
+         """
+         if self._call_from_ft:
+             super().zero_grad(*args, **kwargs)
+         else:
+             self._call_from_ft = True
+             self._ft_optimizer.zero_grad(*args, **kwargs)
+             self._call_from_ft = False
+
+
+ def build_optimizers(
+     model_parts: list[nn.Module],
+     job_config: JobConfig,
+     ft_manager: FTManager,
+ ) -> OptimizersContainer:
+     """Create an ``OptimizersContainer`` for the given model parts and job config.
+
+     This function creates an ``OptimizersContainer`` for the given model parts.
+     ``job_config`` should define the correct optimizer name and parameters.
+     This function currently supports creating ``OptimizersContainer`` and
+     ``OptimizersInBackwardContainer``.
+
+     **Note**
+     Users who want to customize the optimizer behavior can create their own
+     ``OptimizersContainer`` subclass and ``build_optimizers``. Passing the
+     customized ``build_optimizers`` to ``TrainSpec`` will create the customized
+     ``OptimizersContainer``.
+
+     Args:
+         model_parts (List[nn.Module]): List of model parts to be optimized.
+         job_config (JobConfig): Job config containing the optimizer name and parameters.
+     """
+     optim_in_bwd = job_config.optimizer.early_step_in_backward
+     if optim_in_bwd and job_config.parallelism.pipeline_parallel_degree > 1:
+         raise NotImplementedError(
+             "Optimizers in backward is not supported with pipeline parallelism."
+         )
+     name = job_config.optimizer.name
+     lr = job_config.optimizer.lr
+     eps = job_config.optimizer.eps
+
+     optim_implementation = job_config.optimizer.implementation
+     assert optim_implementation in ["fused", "foreach", "for-loop"]
+
+     fused = optim_implementation == "fused"
+     foreach = optim_implementation == "foreach"
+
+     optimizer_kwargs = {
+         "lr": lr,
+         "eps": eps,
+         "betas": (0.9, 0.95),
+         "weight_decay": 0.1,
+         "fused": fused,
+         "foreach": foreach,
+     }
+
+     optimizer_classes = {
+         "Adam": torch.optim.Adam,
+         "AdamW": torch.optim.AdamW,
+     }
+     if name not in optimizer_classes:
+         raise NotImplementedError(f"Optimizer {name} not added.")
+     optimizer_cls = optimizer_classes[name]
+
+     if optim_in_bwd and ft_manager.enabled:
+         raise ValueError("TorchFT is not supported with optimizers in backward.")
+     elif optim_in_bwd:
+         return OptimizersInBackwardContainer(
+             model_parts, optimizer_cls, optimizer_kwargs
+         )
+     elif ft_manager.enabled:
+         return FTOptimizersContainer(
+             model_parts, optimizer_cls, optimizer_kwargs, ft_manager.manager
+         )
+     else:
+         return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
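
A minimal sketch of the container in isolation, with toy modules standing in for model parts and kwargs mirroring those assembled in `build_optimizers`:

import torch
import torch.nn as nn

parts = [nn.Linear(16, 16), nn.Linear(16, 16)]  # stand-ins for pipeline model parts
optimizers = OptimizersContainer(
    parts, torch.optim.AdamW, {"lr": 3e-4, "betas": (0.9, 0.95), "weight_decay": 0.1}
)

loss = sum(part(torch.randn(4, 16)).sum() for part in parts)
loss.backward()
optimizers.step()                   # steps every wrapped optimizer
optimizers.zero_grad()              # clears grads on every wrapped optimizer
snapshot = optimizers.state_dict()  # flattened, resharding-friendly state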
torchtitan/datasets/__pycache__/hf_datasets.cpython-312.pyc ADDED
Binary file (7.04 kB).
 
torchtitan/datasets/hf_datasets.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Any, Callable
9
+
10
+ import torch
11
+
12
+ from datasets import Dataset, load_dataset
13
+ from datasets.distributed import split_dataset_by_node
14
+ from torch.distributed.checkpoint.stateful import Stateful
15
+ from torch.utils.data import IterableDataset
16
+
17
+ from torchtitan.components.dataloader import ParallelAwareDataloader
18
+ from torchtitan.components.tokenizer import Tokenizer
19
+ from torchtitan.config_manager import JobConfig
20
+ from torchtitan.tools.logging import logger
21
+
22
+
23
+ def _load_c4_dataset(dataset_path: str):
24
+ """Load C4 dataset with default configuration."""
25
+ return load_dataset(dataset_path, name="en", split="train", streaming=True)
26
+
27
+
28
+ def _process_c4_text(sample: dict[str, Any]) -> str:
29
+ """Process C4 dataset sample text."""
30
+ return sample["text"]
31
+
32
+
33
+ @dataclass
34
+ class DatasetConfig:
35
+ path: str
36
+ loader: Callable
37
+ text_processor: Callable
38
+
39
+
40
+ # Add your dataset here here - more information at docs/datasets.md
41
+ DATASETS = {
42
+ "c4": DatasetConfig(
43
+ path="allenai/c4",
44
+ loader=_load_c4_dataset,
45
+ text_processor=_process_c4_text,
46
+ ),
47
+ "c4_test": DatasetConfig(
48
+ path="tests/assets/c4_test",
49
+ loader=lambda path: load_dataset(path, split="train"),
50
+ text_processor=_process_c4_text,
51
+ ),
52
+ }
53
+
54
+
55
+ def _validate_dataset(
56
+ dataset_name: str, dataset_path: str | None = None
57
+ ) -> tuple[str, Callable, Callable]:
58
+ """Validate dataset name and path."""
59
+ if dataset_name not in DATASETS:
60
+ raise ValueError(
61
+ f"Dataset {dataset_name} is not supported. "
62
+ f"Supported datasets are: {list(DATASETS.keys())}"
63
+ )
64
+
65
+ config = DATASETS[dataset_name]
66
+ path = dataset_path or config.path
67
+ logger.info(f"Preparing {dataset_name} dataset from {path}")
68
+ return path, config.loader, config.text_processor
69
+
70
+
71
+ class HuggingFaceDataset(IterableDataset, Stateful):
72
+ def __init__(
73
+ self,
74
+ dataset_name: str,
75
+ dataset_path: str | None,
76
+ tokenizer: Tokenizer,
77
+ seq_len: int = 2048,
78
+ dp_rank: int = 0,
79
+ dp_world_size: int = 1,
80
+ infinite: bool = False,
81
+ ) -> None:
82
+ # Force lowercase for consistent comparison
83
+ dataset_name = dataset_name.lower()
84
+
85
+ path, dataset_loader, text_processor = _validate_dataset(
86
+ dataset_name, dataset_path
87
+ )
88
+ ds = dataset_loader(path)
89
+
90
+ self.dataset_name = dataset_name
91
+ self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
92
+ self._tokenizer = tokenizer
93
+ self.seq_len = seq_len
94
+ self.infinite = infinite
95
+ self._text_processor = text_processor
96
+
97
+ # Variables for checkpointing
98
+ self._sample_idx = 0
99
+ self._all_tokens: list[int] = []
100
+
101
+ def _get_data_iter(self):
102
+ if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
103
+ return iter([])
104
+
105
+ it = iter(self._data)
106
+ for _ in range(self._sample_idx):
107
+ next(it)
108
+ return it
109
+
110
+ def __iter__(self):
111
+ max_buffer_token_len = 1 + self.seq_len
112
+
113
+ while True:
114
+ for sample in self._get_data_iter():
115
+ # Use the dataset-specific text processor
116
+ sample_text = self._text_processor(sample)
117
+ sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
118
+ self._all_tokens.extend(sample_tokens)
119
+ self._sample_idx += 1
120
+
121
+ while len(self._all_tokens) >= max_buffer_token_len:
122
+ x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
123
+ # update tokens to the remaining tokens
124
+ self._all_tokens = self._all_tokens[max_buffer_token_len:]
125
+ input = x[:-1]
126
+ label = x[1:]
127
+ yield {"input": input}, label
128
+
129
+ if not self.infinite:
130
+ logger.warning(f"Dataset {self.dataset_name} has run out of data")
131
+ break
132
+ else:
133
+ # Reset offset for the next iteration
134
+ self._sample_idx = 0
135
+ logger.warning(f"Dataset {self.dataset_name} is being re-looped")
136
+
137
+ def load_state_dict(self, state_dict):
138
+ self._sample_idx = state_dict["sample_idx"]
139
+ self._all_tokens = state_dict["token_buffer"]
140
+
141
+ def state_dict(self):
142
+ return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}
143
+
144
+
145
+ def build_hf_dataloader(
146
+ dp_world_size: int,
147
+ dp_rank: int,
148
+ tokenizer: Tokenizer,
149
+ job_config: JobConfig,
150
+ infinite: bool = True,
151
+ ) -> ParallelAwareDataloader:
152
+ """Build a data loader for HuggingFace datasets."""
153
+ dataset_name = job_config.training.dataset
154
+ dataset_path = job_config.training.dataset_path
155
+ batch_size = job_config.training.batch_size
156
+ seq_len = job_config.training.seq_len
157
+
158
+ hf_ds = HuggingFaceDataset(
159
+ dataset_name=dataset_name,
160
+ dataset_path=dataset_path,
161
+ tokenizer=tokenizer,
162
+ seq_len=seq_len,
163
+ dp_rank=dp_rank,
164
+ dp_world_size=dp_world_size,
165
+ infinite=infinite,
166
+ )
167
+
168
+ return ParallelAwareDataloader(
169
+ dataset=hf_ds,
170
+ dp_rank=dp_rank,
171
+ dp_world_size=dp_world_size,
172
+ batch_size=batch_size,
173
+ )
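
Adding a dataset is a matter of extending the `DATASETS` registry above. A hypothetical entry (the path and config name here are illustrative, not part of this commit):

DATASETS["wikitext"] = DatasetConfig(
    path="Salesforce/wikitext",  # hypothetical registry entry
    loader=lambda path: load_dataset(path, name="wikitext-103-v1", split="train", streaming=True),
    text_processor=lambda sample: sample["text"],
)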
torchtitan/datasets/tokenizer/__pycache__/tiktoken.cpython-312.pyc ADDED
Binary file (7.73 kB).
 
torchtitan/datasets/tokenizer/tiktoken.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
+
+ import os
+ from collections.abc import Collection, Iterator, Sequence, Set as AbstractSet
+ from pathlib import Path
+ from typing import cast, Literal
+
+ import tiktoken
+ from tiktoken.load import load_tiktoken_bpe
+
+ from torchtitan.components.tokenizer import Tokenizer
+ from torchtitan.config_manager import JobConfig
+ from torchtitan.tools.logging import logger
+
+
+ class TikTokenizer(Tokenizer):
+     """
+     Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
+
+     Args:
+         model_path (str): The path to the Tiktoken model file.
+     """
+
+     special_tokens: dict[str, int]
+
+     num_reserved_special_tokens = 256
+
+     pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501, B950
+
+     def __init__(self, model_path: str):
+         super().__init__()
+         assert os.path.exists(
+             model_path
+         ), f"The tokenizer path does not exist: {model_path}"
+         assert os.path.isfile(model_path), model_path
+
+         mergeable_ranks = load_tiktoken_bpe(model_path)
+         num_base_tokens = len(mergeable_ranks)
+         special_tokens = [
+             "<|begin_of_text|>",
+             "<|end_of_text|>",
+             "<|reserved_special_token_0|>",
+             "<|reserved_special_token_1|>",
+             "<|reserved_special_token_2|>",
+             "<|reserved_special_token_3|>",
+             "<|start_header_id|>",
+             "<|end_header_id|>",
+             "<|reserved_special_token_4|>",
+             "<|eot_id|>",  # end of turn
+         ] + [
+             f"<|reserved_special_token_{i}|>"
+             for i in range(5, self.num_reserved_special_tokens - 5)
+         ]
+         self.special_tokens = {
+             token: num_base_tokens + i for i, token in enumerate(special_tokens)
+         }
+         self.model = tiktoken.Encoding(
+             name=Path(model_path).name,
+             pat_str=self.pat_str,
+             mergeable_ranks=mergeable_ranks,
+             special_tokens=self.special_tokens,
+         )
+
+         self._n_words: int = self.model.n_vocab
+         # BOS / EOS token IDs
+         self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
+         self.eos_id: int = self.special_tokens["<|end_of_text|>"]
+         self.pad_id: int = -1
+         self.stop_tokens = {
+             self.special_tokens["<|end_of_text|>"],
+             self.special_tokens["<|eot_id|>"],
+         }
+         logger.info(
+             f"TikTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}"
+         )
+
+     def encode(
+         self,
+         s: str,
+         *,
+         bos: bool,
+         eos: bool,
+         allowed_special: Literal["all"] | AbstractSet[str] | None = None,
+         disallowed_special: Literal["all"] | Collection[str] | None = None,
+     ) -> list[int]:
+         """
+         Encodes a string into a list of token IDs.
+
+         Args:
+             s (str): The input string to be encoded.
+             bos (bool): Whether to prepend the beginning-of-sequence token.
+             eos (bool): Whether to append the end-of-sequence token.
+             allowed_special ("all" | set[str]): special tokens allowed to appear in the string.
+             disallowed_special ("all" | Collection[str]): special tokens that raise an error when found in the string.
+
+         Returns:
+             list[int]: A list of token IDs.
+
+         By default, setting disallowed_special=() encodes a string by ignoring
+         special tokens. Specifically:
+         - Setting `disallowed_special` to () will cause all text corresponding
+           to special tokens to be encoded as natural text (instead of raising
+           an error).
+         - Setting `allowed_special` to "all" will cause all text corresponding
+           to special tokens to be encoded as special tokens.
+         """
+         assert type(s) is str
+         allowed_special = allowed_special or set()
+         disallowed_special = disallowed_special or ()
+
+         # The tiktoken tokenizer can handle <=400k chars without
+         # pyo3_runtime.PanicException.
+         TIKTOKEN_MAX_ENCODE_CHARS = 400_000
+
+         # https://github.com/openai/tiktoken/issues/195
+         # Here we iterate over subsequences and split if we exceed the limit
+         # of max consecutive non-whitespace or whitespace characters.
+         MAX_NO_WHITESPACES_CHARS = 25_000
+
+         substrs = (
+             substr
+             for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
+             for substr in self._split_whitespaces_or_nonwhitespaces(
+                 s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
+             )
+         )
+         t: list[int] = []
+         for substr in substrs:
+             t.extend(
+                 self.model.encode(
+                     substr,
+                     allowed_special=allowed_special,
+                     disallowed_special=disallowed_special,
+                 )
+             )
+         if bos:
+             t.insert(0, self.bos_id)
+         if eos:
+             t.append(self.eos_id)
+         return t
+
+     def decode(self, t: Sequence[int]) -> str:
+         """
+         Decodes a list of token IDs into a string.
+
+         Args:
+             t (Sequence[int]): The list of token IDs to be decoded.
+
+         Returns:
+             str: The decoded string.
+         """
+         # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
+         return self.model.decode(cast(list[int], t))
+
+     @staticmethod
+     def _split_whitespaces_or_nonwhitespaces(
+         s: str, max_consecutive_slice_len: int
+     ) -> Iterator[str]:
+         """
+         Splits the string `s` so that each substring contains no more than
+         `max_consecutive_slice_len` consecutive whitespace or non-whitespace characters.
+         """
+         current_slice_len = 0
+         current_slice_is_space = s[0].isspace() if len(s) > 0 else False
+         slice_start = 0
+
+         for i in range(len(s)):
+             is_now_space = s[i].isspace()
+
+             if current_slice_is_space ^ is_now_space:
+                 current_slice_len = 1
+                 current_slice_is_space = is_now_space
+             else:
+                 current_slice_len += 1
+                 if current_slice_len > max_consecutive_slice_len:
+                     yield s[slice_start:i]
+                     slice_start = i
+                     current_slice_len = 1
+         yield s[slice_start:]
+
+
+ def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer:
+     return TikTokenizer(job_config.model.tokenizer_path)
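
For orientation when reading the diff, here is a minimal usage sketch for the tokenizer above. It is illustrative only: the model path is a hypothetical placeholder for a Llama-3-style tiktoken BPE file, which this commit does not ship.

```python
# Minimal sketch, assuming a Llama-3-style tiktoken BPE file exists at the
# placeholder path below (hypothetical; not included in this commit).
from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer

tokenizer = TikTokenizer("assets/tokenizer/tokenizer.model")

# bos/eos are keyword-only, matching the signature above.
ids = tokenizer.encode("Hello, world!", bos=True, eos=True)
assert ids[0] == tokenizer.bos_id and ids[-1] == tokenizer.eos_id

# decode() accepts any Sequence[int].
print(tokenizer.decode(ids))
```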
torchtitan/distributed/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (255 Bytes). View file
 
torchtitan/distributed/__pycache__/utils.cpython-312.pyc ADDED
Binary file (14.9 kB). View file
 
torchtitan/experiments/deepseek_v3/inference.sh ADDED
@@ -0,0 +1,15 @@
+ #!/usr/bin/bash
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ NGPU=${NGPU:-"4"}
+
+ # Get the prompt from the command line argument or use a default
+ prompt="${1:-What is 2+2?}"
+
+ # Run the model with the prompt
+ torchrun --standalone --nproc-per-node ${NGPU} generate.py "$prompt"
torchtitan/experiments/deepseek_v3/model_config.py ADDED
@@ -0,0 +1,204 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from dataclasses import dataclass, field
+
+
+ @dataclass
+ class ModelArgs:
+     r"""
+     This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate a
+     DeepSeek model according to the specified arguments, defining the model architecture. Instantiating a configuration
+     with the defaults will yield a configuration similar to that of DeepSeek-V3.
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+     Args:
+         vocab_size (`int`, *optional*, defaults to 129280):
+             Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by
+             the `inputs_ids` passed when calling [`DeepseekV3Model`].
+         hidden_size (`int`, *optional*, defaults to 7168):
+             Dimension of the hidden representations.
+         intermediate_size (`int`, *optional*, defaults to 18432):
+             Dimension of the MLP representations.
+         moe_intermediate_size (`int`, *optional*, defaults to 2048):
+             Dimension of the MoE representations.
+         num_hidden_layers (`int`, *optional*, defaults to 61):
+             Number of hidden layers in the Transformer decoder.
+         num_nextn_predict_layers (`int`, *optional*, defaults to 1):
+             Number of next-n predict layers in the DeepSeekV3 model.
+         num_attention_heads (`int`, *optional*, defaults to 128):
+             Number of attention heads for each attention layer in the Transformer decoder.
+         n_shared_experts (`int`, *optional*, defaults to 1):
+             Number of shared experts; None means a dense model.
+         n_routed_experts (`int`, *optional*, defaults to 256):
+             Number of routed experts; None means a dense model.
+         routed_scaling_factor (`float`, *optional*, defaults to 2.5):
+             Scaling factor for routed experts.
+         topk_method (`str`, *optional*, defaults to `"noaux_tc"`):
+             Top-k method used in the routed gate.
+         n_group (`int`, *optional*, defaults to 8):
+             Number of groups for routed experts.
+         topk_group (`int`, *optional*, defaults to 4):
+             Number of selected groups for each token (for each token, the selected experts are guaranteed to be
+             within `topk_group` groups).
+         num_experts_per_tok (`int`, *optional*, defaults to 8):
+             Number of selected experts; None means a dense model.
+         moe_layer_freq (`int`, *optional*, defaults to 1):
+             The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
+         first_k_dense_replace (`int`, *optional*, defaults to 3):
+             Number of dense layers in shallow layers (embed->dense->dense->...->dense->moe->moe...->lm_head).
+                                                              \--k dense layers--/
+         norm_topk_prob (`bool`, *optional*, defaults to True):
+             Whether to normalize the weights of the routed experts.
+         scoring_func (`str`, *optional*, defaults to `"sigmoid"`):
+             Method of computing expert weights.
+         aux_loss_alpha (`float`, *optional*, defaults to 0.001):
+             Auxiliary loss weight coefficient.
+         seq_aux (`bool`, *optional*, defaults to True):
+             Whether to compute the auxiliary loss for each individual sample.
+         num_key_value_heads (`int`, *optional*):
+             The number of key/value heads used to implement Grouped Query Attention. If
+             `num_key_value_heads=num_attention_heads`, the model uses Multi-Head Attention (MHA); if
+             `num_key_value_heads=1`, the model uses Multi-Query Attention (MQA); otherwise GQA is used. When
+             converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be
+             constructed by mean-pooling all the original heads within that group. For more details, check out [this
+             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it defaults to
+             `num_attention_heads`.
+         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+             The non-linear activation function (function or string) in the decoder.
+         max_position_embeddings (`int`, *optional*, defaults to 163840):
+             The maximum sequence length that this model might ever be used with.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+             The epsilon used by the RMS normalization layers.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/value attentions (not used by all models). Only
+             relevant if `config.is_decoder=True`.
+         pad_token_id (`int`, *optional*):
+             Padding token id.
+         bos_token_id (`int`, *optional*, defaults to 1):
+             Beginning-of-stream token id.
+         eos_token_id (`int`, *optional*, defaults to 2):
+             End-of-stream token id.
+         pretraining_tp (`int`, *optional*, defaults to 1):
+             Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+             document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
+             necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
+             issue](https://github.com/pytorch/pytorch/issues/76232).
+         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+             Whether to tie the word embeddings.
+         rope_theta (`float`, *optional*, defaults to 10000.0):
+             The base period of the RoPE embeddings.
+         rope_scaling (`Dict`, *optional*):
+             Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+             strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
+             is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+             `max_position_embeddings` to the expected new maximum.
+         attention_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use a bias in the query, key, value and output projection layers during self-attention.
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+     """
+
+     vocab_size: int = 129280
+     hidden_size: int = 7168
+     intermediate_size: int = 18432
+     moe_intermediate_size: int = 2048
+     num_hidden_layers: int = 61
+     num_nextn_predict_layers: int = 1
+     num_attention_heads: int = 128
+     num_key_value_heads: int = 128
+     n_shared_experts: int = 1
+     n_routed_experts: int = 256
+     ep_size: int = 1
+     routed_scaling_factor: float = 2.5
+     kv_lora_rank: int = 512
+     # None disables the low-rank query projection (e.g., DeepSeek-V2-Lite below)
+     q_lora_rank: int | None = 1536
+     qk_rope_head_dim: int = 64
+     v_head_dim: int = 128
+     qk_nope_head_dim: int = 128
+     topk_method: str = "noaux_tc"
+     n_group: int = 8
+     topk_group: int = 4
+     num_experts_per_tok: int = 8
+     moe_layer_freq: int = 1
+     first_k_dense_replace: int = 3
+     norm_topk_prob: bool = True
+     scoring_func: str = "sigmoid"
+     aux_loss_alpha: float = 0.001
+     seq_aux: bool = True
+     hidden_act: str = "silu"
+     max_position_embeddings: int = 163840
+     initializer_range: float = 0.02
+     rms_norm_eps: float = 1e-6
+     rope_theta: float = 10000.0
+     rope_scaling: dict = field(
+         default_factory=lambda: {
+             "beta_fast": 32,
+             "beta_slow": 1,
+             "factor": 40,
+             "mscale": 1.0,
+             "mscale_all_dim": 1.0,
+             "original_max_position_embeddings": 4096,
+             "type": "yarn",
+         }
+     )
+     attention_bias: bool = False
+     attention_dropout: float = 0.0
+     pad_token_id = None
+     # Added for symmetric memory
+     max_seq_len: int = 4096
+     dtype: str = "bfloat16"
+     # Added for pipeline parallel
+     num_stages: int = 1
+     stage_idx: int = 0
+
+
+ # This is the configuration for deepseek-ai/DeepSeek-V2-Lite.
+ deepseek_v2_lite_config = ModelArgs(
+     vocab_size=102400,
+     hidden_size=2048,
+     intermediate_size=10944,
+     moe_intermediate_size=1408,
+     num_hidden_layers=27,
+     num_attention_heads=16,
+     num_key_value_heads=16,
+     n_shared_experts=2,
+     n_routed_experts=64,
+     routed_scaling_factor=1.0,
+     kv_lora_rank=512,
+     q_lora_rank=None,
+     qk_rope_head_dim=64,
+     v_head_dim=128,
+     qk_nope_head_dim=128,
+     topk_method="greedy",
+     n_group=1,
+     topk_group=1,
+     num_experts_per_tok=6,
+     first_k_dense_replace=1,
+     norm_topk_prob=False,
+     scoring_func="softmax",
+     max_position_embeddings=4096,
+     rope_scaling={
+         "beta_fast": 32,
+         "beta_slow": 1,
+         "factor": 40,
+         "mscale": 0.707,
+         "mscale_all_dim": 0.707,
+         "original_max_position_embeddings": 4096,
+         "type": "yarn",
+     },
+ )
+
+
+ # Model configuration registry
+ # Key is the model distribution ID on HuggingFace Hub
+ deepseek_config_registry = {
+     "deepseek-ai/DeepSeek-V2-Lite": deepseek_v2_lite_config,
+     "deepseek-ai/DeepSeek-V2-Lite-Chat": deepseek_v2_lite_config,
+     "deepseek-ai/deepseek-v3": ModelArgs(),
+ }
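
As a quick orientation for how this registry is meant to be consumed, the sketch below looks up a config by its HuggingFace model ID and copies it with the pipeline-parallel fields overridden. The override values are illustrative, not defaults taken from this commit.

```python
# Minimal sketch of consuming deepseek_config_registry (override values illustrative).
from dataclasses import replace

from torchtitan.experiments.deepseek_v3.model_config import deepseek_config_registry

# Keys are HuggingFace Hub distribution IDs, as noted above.
args = deepseek_config_registry["deepseek-ai/DeepSeek-V2-Lite"]

# dataclasses.replace returns a modified copy, leaving the shared registry
# entry untouched; num_stages/stage_idx are the pipeline-parallel fields.
stage_args = replace(args, num_stages=4, stage_idx=0)
print(stage_args.num_hidden_layers, stage_args.num_stages)  # 27 4
```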
torchtitan/experiments/flux/README.md ADDED
@@ -0,0 +1,23 @@
+ # FLUX model in torchtitan
+
+ ## Overview
+
+ ## Usage
+ First, download the autoencoder model from HuggingFace with your own access token:
+ ```bash
+ python torchtitan/experiments/flux/scripts/download_autoencoder.py --repo_id black-forest-labs/FLUX.1-dev --ae_path ae.safetensors --hf_token <your_access_token>
+ ```
+ This step downloads the autoencoder model from HuggingFace and saves it to `torchtitan/experiments/flux/assets/autoencoder/ae.safetensors`.
+
+ Run the following command to train the model on a single GPU:
+ ```bash
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=1 torchtitan/experiments/flux/train.py --job.config_file torchtitan/experiments/flux/train_configs/debug_model.toml
+ ```
+
+ ## TODO
+ - [ ] Support for multiple GPUs is coming soon (FSDP, etc.)
+ - [ ] Implement CI test cases for the FLUX model and add more unit tests (e.g., for the preprocessor)
+ - [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc.)
+ - [ ] Support distributed checkpointing and loading
+ - [ ] Implement an init_weights() function to initialize the model weights
+ - [ ] Implement the num_flops_per_token calculation in the get_nparams_and_flops() function
torchtitan/experiments/flux/__pycache__/parallelize_flux.cpython-312.pyc ADDED
Binary file (648 Bytes). View file
 
torchtitan/experiments/flux/flux_argparser.py ADDED
@@ -0,0 +1,42 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import argparse
+
+ import torch
+
+
+ def extend_parser(parser: argparse.ArgumentParser) -> None:
+     parser.add_argument(
+         "--training.guidance",
+         type=float,
+         default=3.5,
+         help="Guidance value used for guidance distillation.",
+     )
+     parser.add_argument(
+         "--encoder.t5_encoder",
+         type=str,
+         default="google/t5-v1_1-small",
+         help="T5 encoder to use, HuggingFace model name.",
+     )
+     parser.add_argument(
+         "--encoder.clip_encoder",
+         type=str,
+         default="openai/clip-vit-large-patch14",
+         help="CLIP encoder to use, HuggingFace model name.",
+     )
+     parser.add_argument(
+         "--encoder.encoder_dtype",
+         # argparse calls `type` on the raw string value, so map dtype names
+         # like "bfloat16" to the corresponding torch attribute; torch.dtype
+         # itself is not callable with a string.
+         type=lambda name: getattr(torch, name),
+         default=torch.bfloat16,
+         help="Which dtype to load the autoencoder in.",
+     )
+     parser.add_argument(
+         "--encoder.max_t5_encoding_len",
+         type=int,
+         default=512,
+         help="Maximum length of the T5 encoding.",
+     )
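
A small sketch of how `extend_parser` composes with a plain `argparse.ArgumentParser`. Because the option names contain dots, argparse stores the values under dotted `dest` names, so they come back via `getattr`/`vars` rather than plain attribute access:

```python
# Minimal sketch: wiring extend_parser into a standalone parser.
import argparse

from torchtitan.experiments.flux.flux_argparser import extend_parser

parser = argparse.ArgumentParser()
extend_parser(parser)

args = parser.parse_args(["--training.guidance", "4.0"])

# Dotted flags keep their dots in the dest name, so args.training.guidance
# would fail; read them with getattr() or vars() instead.
print(getattr(args, "training.guidance"))         # 4.0
print(vars(args)["encoder.max_t5_encoding_len"])  # 512
```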
torchtitan/experiments/flux/loss.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Callable, TypeAlias
+
+ import torch
+
+ from torchtitan.config_manager import JobConfig
+ from torchtitan.tools.logging import logger
+
+ LossFunction: TypeAlias = Callable[..., torch.Tensor]
+
+
+ def mse_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+     """Common MSE loss function for Transformer model training."""
+     return torch.nn.functional.mse_loss(pred.float(), labels.float().detach())
+
+
+ def build_mse_loss(job_config: JobConfig) -> LossFunction:
+     loss_fn = mse_loss
+     if job_config.training.compile:
+         logger.info("Compiling the loss function with torch.compile")
+         loss_fn = torch.compile(loss_fn)
+     return loss_fn
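
A minimal sketch exercising the loss in isolation (random tensors, no `JobConfig`), to show the expected call shape and that gradients flow only to the prediction, since the target is detached:

```python
# Minimal sketch: mse_loss on random tensors (shapes are illustrative).
import torch

from torchtitan.experiments.flux.loss import mse_loss

pred = torch.randn(2, 16, 64, requires_grad=True)
labels = torch.randn(2, 16, 64)

loss = mse_loss(pred, labels)  # scalar tensor, computed in float32
loss.backward()                # grads reach pred only; labels are detached
print(loss.item())
```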