Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- fla/models/abc/configuration_abc.py +91 -0
- fla/models/forgetting_transformer/__init__.py +16 -0
- fla/models/gla/__init__.py +13 -0
- fla/models/nsa/__init__.py +15 -0
- fla/models/transformer_top/configuration_transformer.py +76 -0
- fla/modules/__pycache__/activations.cpython-312.pyc +0 -0
- fla/modules/__pycache__/convolution.cpython-312.pyc +0 -0
- fla/modules/__pycache__/fused_linear_cross_entropy.cpython-312.pyc +0 -0
- fla/modules/__pycache__/layernorm.cpython-312.pyc +0 -0
- fla/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-312.pyc +0 -0
- fla/ops/delta_rule/__pycache__/__init__.cpython-312.pyc +0 -0
- fla/ops/delta_rule/__pycache__/chunk.cpython-312.pyc +0 -0
- fla/ops/delta_rule/__pycache__/fused_chunk.cpython-312.pyc +0 -0
- fla/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc +0 -0
- fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-312.pyc +0 -0
- fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-312.pyc +0 -0
- fla/ops/generalized_delta_rule/dplr/chunk.py +388 -0
- fla/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
- fla/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-312.pyc +0 -0
- fla/ops/gsa/__pycache__/chunk.cpython-312.pyc +0 -0
- fla/ops/hgrn/__pycache__/chunk.cpython-312.pyc +0 -0
- fla/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
- fla/ops/ttt/__pycache__/chunk.cpython-312.pyc +0 -0
- fla/ops/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- logs/none_enyj3lod/attempt_0/3/stderr.log +0 -0
- profile_trace/iteration_17408/rank4_trace.json +0 -0
- profile_trace/iteration_18944/rank2_trace.json +0 -0
- profile_trace/iteration_25088/rank3_trace.json +0 -0
- profile_trace/iteration_25088/rank7_trace.json +0 -0
- profile_trace/iteration_33280/rank6_trace.json +0 -0
- profile_trace/iteration_34816/rank5_trace.json +0 -0
- profile_trace/iteration_38912/rank1_trace.json +0 -0
- profile_trace/iteration_38912/rank2_trace.json +0 -0
- profile_trace/iteration_7680/rank0_trace.json +0 -0
- profile_trace/iteration_7680/rank4_trace.json +0 -0
- torchtitan/components/dataloader.py +92 -0
- torchtitan/components/float8.py +150 -0
- torchtitan/components/optimizer.py +303 -0
- torchtitan/datasets/__pycache__/hf_datasets.cpython-312.pyc +0 -0
- torchtitan/datasets/hf_datasets.py +173 -0
- torchtitan/datasets/tokenizer/__pycache__/tiktoken.cpython-312.pyc +0 -0
- torchtitan/datasets/tokenizer/tiktoken.py +190 -0
- torchtitan/distributed/__pycache__/__init__.cpython-312.pyc +0 -0
- torchtitan/distributed/__pycache__/utils.cpython-312.pyc +0 -0
- torchtitan/experiments/deepseek_v3/inference.sh +15 -0
- torchtitan/experiments/deepseek_v3/model_config.py +204 -0
- torchtitan/experiments/flux/README.md +23 -0
- torchtitan/experiments/flux/__pycache__/parallelize_flux.cpython-312.pyc +0 -0
- torchtitan/experiments/flux/flux_argparser.py +42 -0
- torchtitan/experiments/flux/loss.py +27 -0
fla/models/abc/configuration_abc.py
ADDED
@@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-

from typing import Dict, Optional

from transformers.configuration_utils import PretrainedConfig


class ABCConfig(PretrainedConfig):

    model_type = 'abc'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        gate_low_rank_dim: int = 16,
        clamp_min: float = -32,
        clamp_max: float = 32,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_slots: Optional[int] = 64,
        use_short_conv: bool = False,
        conv_size: int = 4,
        expand_k: float = 0.5,
        expand_v: float = 1,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_rope: bool = True,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.hidden_size = hidden_size
        self.gate_low_rank_dim = gate_low_rank_dim
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_slots = num_slots
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_rope = use_rope
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
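
Aside: a minimal usage sketch for the config above, assuming the `fla` package is importable; the sizes and the hybrid-attention dict below are illustrative values, not part of this commit.

from fla.models.abc.configuration_abc import ABCConfig

# Illustrative small config; `attn` must carry `layers` and `num_heads`,
# and the constructor backfills `num_kv_heads`, `qkv_bias`, `window_size`
# and `rope_theta` with defaults.
config = ABCConfig(
    hidden_size=512,
    num_hidden_layers=4,
    num_heads=4,
    attn={'layers': [1, 3], 'num_heads': 8},
)
assert config.attn['num_kv_heads'] == 8  # defaulted from attn['num_heads']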
fla/models/forgetting_transformer/__init__.py
ADDED
@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.forgetting_transformer.configuration_forgetting_transformer import ForgettingTransformerConfig
from fla.models.forgetting_transformer.modeling_forgetting_transformer import (
    ForgettingTransformerForCausalLM,
    ForgettingTransformerModel
)

AutoConfig.register(ForgettingTransformerConfig.model_type, ForgettingTransformerConfig)
AutoModel.register(ForgettingTransformerConfig, ForgettingTransformerModel)
AutoModelForCausalLM.register(ForgettingTransformerConfig, ForgettingTransformerForCausalLM)


__all__ = ['ForgettingTransformerConfig', 'ForgettingTransformerForCausalLM', 'ForgettingTransformerModel']
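
These three `register` calls are what let the generic `transformers` Auto classes resolve the model by its `model_type` string (the same pattern is used by the `gla` and `nsa` packages below). A hedged sketch of the effect, with tiny illustrative sizes:

from transformers import AutoConfig, AutoModelForCausalLM

import fla.models.forgetting_transformer  # noqa: F401 -- runs the registrations above

# After registration, the Auto classes can build the model from its
# model_type alone.
config = AutoConfig.for_model('forgetting_transformer', num_hidden_layers=2, hidden_size=256)
model = AutoModelForCausalLM.from_config(config)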
fla/models/gla/__init__.py
ADDED
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.gla.configuration_gla import GLAConfig
from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel

AutoConfig.register(GLAConfig.model_type, GLAConfig)
AutoModel.register(GLAConfig, GLAModel)
AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM)


__all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel']
fla/models/nsa/__init__.py
ADDED
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.nsa.configuration_nsa import NSAConfig
from fla.models.nsa.modeling_nsa import NSAForCausalLM, NSAModel

AutoConfig.register(NSAConfig.model_type, NSAConfig)
AutoModel.register(NSAConfig, NSAModel)
AutoModelForCausalLM.register(NSAConfig, NSAForCausalLM)


__all__ = [
    'NSAConfig', 'NSAModel', 'NSAForCausalLM',
]
fla/models/transformer_top/configuration_transformer.py
ADDED
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class TOPTransformerConfig(PretrainedConfig):

    model_type = 'top_transformer'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        num_heads: int = 32,
        num_kv_heads: int = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        initializer_range: float = 0.006,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        use_top_loss: bool = False,
        top_window_size: Optional[int] = None,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm
        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act

        self.initializer_range = initializer_range
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        self.use_top_loss = use_top_loss
        self.top_window_size = top_window_size if top_window_size is not None else max_position_embeddings

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
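
One behavior worth noting in the config above: `top_window_size` falls back to `max_position_embeddings` when omitted. A short sketch (values illustrative, not part of this commit):

from fla.models.transformer_top.configuration_transformer import TOPTransformerConfig

config = TOPTransformerConfig(max_position_embeddings=4096, use_top_loss=True)
assert config.top_window_size == 4096  # defaulted from max_position_embeddings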
fla/modules/__pycache__/activations.cpython-312.pyc
ADDED
Binary file (23 kB)

fla/modules/__pycache__/convolution.cpython-312.pyc
ADDED
Binary file (21 kB)

fla/modules/__pycache__/fused_linear_cross_entropy.cpython-312.pyc
ADDED
Binary file (20.6 kB)

fla/modules/__pycache__/layernorm.cpython-312.pyc
ADDED
Binary file (43.4 kB)

fla/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-312.pyc
ADDED
Binary file (6.74 kB)

fla/ops/delta_rule/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (361 Bytes)

fla/ops/delta_rule/__pycache__/chunk.cpython-312.pyc
ADDED
Binary file (13.3 kB)

fla/ops/delta_rule/__pycache__/fused_chunk.cpython-312.pyc
ADDED
Binary file (392 Bytes)

fla/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc
ADDED
Binary file (14.4 kB)

fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-312.pyc
ADDED
Binary file (30.6 kB)

fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-312.pyc
ADDED
Binary file (28 kB)
fla/ops/generalized_delta_rule/dplr/chunk.py
ADDED
@@ -0,0 +1,388 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional

import torch
import triton

from fla.ops.common.utils import prepare_chunk_indices
from fla.ops.generalized_delta_rule.dplr.chunk_A_bwd import chunk_dplr_bwd_dqk_intra
from fla.ops.generalized_delta_rule.dplr.chunk_A_fwd import chunk_fwd_intra_dplr_fn
from fla.ops.generalized_delta_rule.dplr.chunk_h_bwd import chunk_dplr_bwd_dhu
from fla.ops.generalized_delta_rule.dplr.chunk_h_fwd import chunk_dplr_fwd_h
from fla.ops.generalized_delta_rule.dplr.chunk_o_bwd import chunk_dplr_bwd_dAu, chunk_dplr_bwd_dv, chunk_dplr_bwd_o
from fla.ops.generalized_delta_rule.dplr.chunk_o_fwd import chunk_dplr_fwd_o
from fla.ops.generalized_delta_rule.dplr.wy_fast_bwd import chunk_dplr_bwd_wy
from fla.ops.generalized_delta_rule.dplr.wy_fast_fwd import fwd_prepare_wy_repr
from fla.ops.rwkv6.chunk import chunk_rwkv6_fwd_cumsum
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard


def chunk_dplr_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
):
    T = q.shape[2] if head_first else q.shape[1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT, offsets=offsets, indices=indices, head_first=head_first)

    A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_fwd_intra_dplr_fn(
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        scale=scale,
        offsets=offsets,
        indices=indices,
        chunk_size=BT,
        head_first=head_first
    )
    del ge

    # A_ab, A_ak, gi, ge: torch.float32
    # A_qk, A_qb, qg, kg, ag, bg: q.dtype, e.g. bf16
    w, u, _ = fwd_prepare_wy_repr(
        ag=ag,
        A_ab=A_ab,
        A_ak=A_ak,
        v=v,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    del A_ab, A_ak
    h, v_new, final_state = chunk_dplr_fwd_h(
        kg=kg,
        bg=bg,
        v=v,
        w=w,
        u=u,
        gk=gi,
        initial_state=initial_state,
        output_final_state=output_final_state,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    del u, kg, bg, gi

    o = chunk_dplr_fwd_o(
        qg=qg,
        v=v,
        v_new=v_new,
        A_qk=A_qk,
        A_qb=A_qb,
        h=h,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    del v_new, h, A_qk, A_qb

    return o, final_state


class ChunkDPLRDeltaRuleFunction(torch.autograd.Function):

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        gk: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        offsets: Optional[torch.LongTensor] = None,
        head_first: bool = True
    ):
        chunk_size = 16

        # 2-d indices denoting the offsets of chunks in each sequence
        # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
        # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None

        o, final_state = chunk_dplr_fwd(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            gk=gk,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            offsets=offsets,
            indices=indices,
            head_first=head_first,
            chunk_size=chunk_size
        )
        ctx.save_for_backward(q, k, v, a, b, gk, initial_state)
        ctx.head_first = head_first
        ctx.offsets = offsets
        ctx.indices = indices
        ctx.scale = scale
        ctx.chunk_size = chunk_size
        return o.to(q.dtype), final_state

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(
        ctx,
        do: torch.Tensor,
        dht: torch.Tensor
    ):
        q, k, v, a, b, gk, initial_state = ctx.saved_tensors
        BT = ctx.chunk_size
        head_first = ctx.head_first
        offsets = ctx.offsets
        indices = ctx.indices
        scale = ctx.scale

        # ******* recompute the forward intermediates here; caching them for backward would exhaust GPU memory *******
        gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT, offsets=offsets, indices=indices, head_first=head_first)

        A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_fwd_intra_dplr_fn(
            q=q,
            k=k,
            a=a,
            b=b,
            gi=gi,
            ge=ge,
            scale=scale,
            offsets=offsets,
            indices=indices,
            chunk_size=BT,
            head_first=head_first
        )
        w, u, A_ab_inv = fwd_prepare_wy_repr(
            ag=ag,
            A_ab=A_ab,
            A_ak=A_ak,
            v=v,
            offsets=offsets,
            indices=indices,
            head_first=head_first,
            chunk_size=BT
        )
        del A_ab
        h, v_new, _ = chunk_dplr_fwd_h(
            kg=kg,
            bg=bg,
            v=v,
            w=w,
            u=u,
            gk=gi,
            initial_state=initial_state,
            offsets=offsets,
            indices=indices,
            head_first=head_first,
            chunk_size=BT
        )
        del u
        # ******* end of recomputation *******
        # A_ak, A_ab_inv, gi, ge: torch.float32
        # A_qk, A_qb, qg, kg, ag, bg, v_new: q.dtype, e.g. bf16

        dv_new_intra, dA_qk, dA_qb = chunk_dplr_bwd_dAu(
            v=v,
            v_new=v_new,
            do=do,
            A_qb=A_qb,
            scale=scale,
            offsets=offsets,
            indices=indices,
            head_first=head_first,
            chunk_size=BT
        )

        dh, dh0, dv_new = chunk_dplr_bwd_dhu(
            qg=qg,
            bg=bg,
            w=w,
            gk=gi,
            h0=initial_state,
            dht=dht,
            do=do,
            dv=dv_new_intra,
            offsets=offsets,
            indices=indices,
            head_first=head_first,
            chunk_size=BT
        )

        dv = chunk_dplr_bwd_dv(
            A_qk=A_qk,
            kg=kg,
            do=do,
            dh=dh,
            offsets=offsets,
            indices=indices,
            head_first=head_first,
            chunk_size=BT
        )
        del A_qk

        dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o(
            k=kg,
            b=bg,
            v=v,
            v_new=v_new,
            do=do,
            h=h,
            dh=dh,
            dv=dv_new,
            w=w,
            gk=gi,
            offsets=offsets,
            indices=indices,
            chunk_size=BT,
            scale=scale,
            head_first=head_first,
        )
        del v_new

        dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy(
            A_ab_inv=A_ab_inv,
            A_ak=A_ak,
            v=v,
            ag=ag,
            dw=dw,
            du=dv_new,
            dv0=dv,
            offsets=offsets,
            indices=indices,
            head_first=head_first,
            chunk_size=BT
        )
        del A_ak

        dq, dk, da, db, dgk = chunk_dplr_bwd_dqk_intra(
            q=q,
            k=k,
            a=a,
            b=b,
            gi=gi,
            ge=ge,
            dAqk=dA_qk,
            dAqb=dA_qb,
            dAak=dA_ak,
            dAab=dA_ab,
            dgk_last=dgk_last,
            dqg=dqg,
            dkg=dkg,
            dag=dag,
            dbg=dbg,
            chunk_size=BT,
            scale=scale,
            head_first=head_first,
            offsets=offsets,
            indices=indices
        )

        return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), dgk.to(gk), None, dh0, None, None, None


@torch.compiler.disable
def chunk_dplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
):
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        a (torch.Tensor):
            activations of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        b (torch.Tensor):
            betas of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        gk (torch.Tensor):
            gk of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. decay term in log space!
        scale (Optional[float]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
    """
    assert q.dtype == k.dtype == v.dtype
    # assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
    # gk = gk.float()

    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                             f"Please flatten variable-length inputs before processing.")
        if head_first:
            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
                             f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
    scale = k.shape[-1] ** -0.5 if scale is None else scale
    o, final_state = ChunkDPLRDeltaRuleFunction.apply(
        q,
        k,
        v,
        a,
        b,
        gk,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
        head_first
    )
    return o, final_state
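
A hedged usage sketch for the public entry point above, following the shapes documented in its docstring; the dtype, sizes, and CUDA device are illustrative assumptions (the Triton kernels require a GPU), not part of this commit.

import torch

from fla.ops.generalized_delta_rule.dplr.chunk import chunk_dplr_delta_rule

B, T, H, K, V = 2, 128, 4, 64, 64
dtype, device = torch.bfloat16, 'cuda'

q = torch.randn(B, T, H, K, dtype=dtype, device=device)
k = torch.randn(B, T, H, K, dtype=dtype, device=device)
v = torch.randn(B, T, H, V, dtype=dtype, device=device)
a = torch.randn(B, T, H, K, dtype=dtype, device=device)
b = torch.randn(B, T, H, K, dtype=dtype, device=device)
# decay term must be in log space, i.e. log of values in (0, 1)
gk = torch.randn(B, T, H, K, dtype=dtype, device=device).sigmoid().log()

o, final_state = chunk_dplr_delta_rule(
    q, k, v, a, b, gk,
    output_final_state=True,
    head_first=False,  # inputs are laid out as [B, T, H, ...]
)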
fla/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc
ADDED
Binary file (27.4 kB)

fla/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-312.pyc
ADDED
Binary file (23.1 kB)

fla/ops/gsa/__pycache__/chunk.cpython-312.pyc
ADDED
Binary file (69.4 kB)

fla/ops/hgrn/__pycache__/chunk.cpython-312.pyc
ADDED
Binary file (16.2 kB)

fla/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc
ADDED
Binary file (14.3 kB)

fla/ops/ttt/__pycache__/chunk.cpython-312.pyc
ADDED
Binary file (88.1 kB)

fla/ops/utils/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (1.12 kB)

logs/none_enyj3lod/attempt_0/3/stderr.log
ADDED
The diff for this file is too large to render.

profile_trace/iteration_17408/rank4_trace.json
ADDED
The diff for this file is too large to render.

profile_trace/iteration_18944/rank2_trace.json
ADDED
The diff for this file is too large to render.

profile_trace/iteration_25088/rank3_trace.json
ADDED
The diff for this file is too large to render.

profile_trace/iteration_25088/rank7_trace.json
ADDED
The diff for this file is too large to render.

profile_trace/iteration_33280/rank6_trace.json
ADDED
The diff for this file is too large to render.

profile_trace/iteration_34816/rank5_trace.json
ADDED
The diff for this file is too large to render.

profile_trace/iteration_38912/rank1_trace.json
ADDED
The diff for this file is too large to render.

profile_trace/iteration_38912/rank2_trace.json
ADDED
The diff for this file is too large to render.

profile_trace/iteration_7680/rank0_trace.json
ADDED
The diff for this file is too large to render.

profile_trace/iteration_7680/rank4_trace.json
ADDED
The diff for this file is too large to render.
torchtitan/components/dataloader.py
ADDED
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import pickle
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any

from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader
from torchtitan.tools.logging import logger


class BaseDataLoader(Stateful, ABC):
    """Base class for all dataloaders.

    This is used to enforce that all dataloaders have the methods defined in ``Stateful``,
    ``state_dict()`` and ``load_state_dict()``.
    """

    @abstractmethod
    def __iter__(self):
        ...


class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader):
    """Dataloader that is aware of distributed data parallelism.

    This dataloader is used to load data in a distributed data parallel fashion. It also
    utilizes ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary
    methods such as ``__iter__``.

    Args:
        dataset (IterableDataset): The dataset to iterate over.
        dp_rank: Data parallelism rank for this dataloader.
        dp_world_size: The world size of the data parallelism.
        batch_size: The batch size to use for each iteration.
        collate_fn: Optional function to collate samples in a batch.
    """

    dp_rank: int
    dp_world_size: int
    batch_size: int

    def __init__(
        self,
        dataset: IterableDataset,
        dp_rank: int,
        dp_world_size: int,
        batch_size: int,
        collate_fn: Callable | None = None,
    ):
        self.dp_world_size = dp_world_size
        self.dp_rank = dp_rank
        self.batch_size = batch_size
        super().__init__(dataset, batch_size, collate_fn=collate_fn)
        self._rank_id = f"dp_rank_{dp_rank}"

    def state_dict(self) -> dict[str, Any]:
        # Store state only for dp rank to avoid replicating the same state across other dimensions.
        return {
            # We don't have to use pickle as DCP will serialize the state_dict. However,
            # we have to keep this for backward compatibility.
            self._rank_id: pickle.dumps(super().state_dict()),
            "world_size": self.dp_world_size,
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # State being empty is valid.
        if not state_dict:
            return

        if self._rank_id not in state_dict:
            logger.warning(
                f"DataLoader state is empty for dp rank {self.dp_rank}, "
                f"expected key {self._rank_id}"
            )
            return

        assert self.dp_world_size == state_dict["world_size"], (
            "dp_degree is inconsistent before and after checkpoint, "
            "dataloader resharding is not supported yet."
        )
        # We don't have to use pickle as DCP will serialize the state_dict. However, we have to
        # keep this for backward compatibility.
        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
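
A minimal sketch of constructing and checkpointing the dataloader above; the toy dataset and rank values are illustrative assumptions, not part of this commit.

from torch.utils.data import IterableDataset

from torchtitan.components.dataloader import ParallelAwareDataloader


class ToyDataset(IterableDataset):  # hypothetical stand-in for a real dataset
    def __iter__(self):
        yield from range(8)


loader = ParallelAwareDataloader(ToyDataset(), dp_rank=0, dp_world_size=2, batch_size=4)
# The state is keyed by dp rank ("dp_rank_0") plus the dp world size, so a
# restore under a different dp degree trips the assertion above.
state = loader.state_dict()
loader.load_state_dict(state)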
torchtitan/components/float8.py
ADDED
@@ -0,0 +1,150 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# [Note] Getting the 'torchao' package:
# This script requires the 'torchao' package to function correctly.
# Please ensure you have this package installed from the appropriate repository.
# You can obtain it from https://github.com/pytorch/ao by following the
# installation instructions.

# Note: Performance
# Float8 experimental is intended to be run under `torch.compile` for competitive performance

import torch
import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.protocols.model_converter import (
    ModelConverter,
    register_model_converter,
)
from torchtitan.tools.logging import logger


def _is_sm89_or_later():
    # Float8 is only supported on SM89 or later (H100+ GPUs)
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


class Float8Converter(ModelConverter):
    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        self.enabled = False

        float8_config = job_config.float8
        if not _is_sm89_or_later():
            logger.warning(
                "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
            )
            return
        try:
            from torchao.float8 import Float8LinearConfig
        except ImportError as e:
            raise ImportError(
                "torchao is not installed. Please install it to use float8 linear layers."
            ) from e

        if float8_config.recipe_name is not None and not hasattr(
            Float8LinearConfig, "from_recipe_name"
        ):
            logger.warning(
                "Failed to swap to Float8Linear with recipe lookup because the torchao version "
                "is too old, please install torchao v0.9.0 or later and try again",
            )
            return

        self.enabled = True
        self.filter_fqns = float8_config.filter_fqns

        if float8_config.recipe_name is not None:
            assert (
                not float8_config.enable_fsdp_float8_all_gather
            ), "using `float8_config.enable_fsdp_float8_all_gather` together with `float8_config.recipe_name` is not supported"
            assert (
                not float8_config.force_recompute_fp8_weight_in_bwd
            ), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported"
            self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)
            self.precompute_scale = False
            logger.info(
                f"Float8 training active with recipe {float8_config.recipe_name}"
            )

        else:
            # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
            enable_fsdp_float8_all_gather = (
                parallel_dims.dp_shard_enabled
                and float8_config.enable_fsdp_float8_all_gather
            )
            self.config = Float8LinearConfig(
                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
                force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd,
            )
            # for precompute_float8_dynamic_scale_for_fsdp
            self.precompute_scale = (
                enable_fsdp_float8_all_gather
                and float8_config.precompute_float8_dynamic_scale_for_fsdp
            )
            logger.info("Float8 tensorwise scaled training active")

    def convert(self, model: nn.Module):
        return self.convert_to_float8_training(model)

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
        return self.precompute_float8_dynamic_scale_for_fsdp(model)

    def convert_to_float8_training(self, model: nn.Module):
        """
        This function converts the linear layers of `model` to `Float8Linear`.
        Note that today, only dynamic tensor scaling (the default) is supported.
        This will mutate the model inplace.
        """
        if not self.enabled:
            return

        from torchao.float8 import convert_to_float8_training

        # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
        convert_to_float8_training(
            model,
            config=self.config,
            module_filter_fn=self._module_filter_fn,
        )
        logger.info(
            "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
            f"{self.config.enable_fsdp_float8_all_gather}"
        )

    def _module_filter_fn(self, mod: nn.Module, fqn: str) -> bool:
        if not isinstance(mod, nn.Linear):
            return False

        # All dims must be divisible by 16 due to float8 tensorcore hardware requirements.
        dims_multiples_of_16 = (
            mod.weight.shape[0] % 16 == 0 and mod.weight.shape[1] % 16 == 0
        )

        # If the fqn matches any filtered fqn, then we should not convert this module.
        is_filtered_fqn = any(filtered_fqn in fqn for filtered_fqn in self.filter_fqns)

        return dims_multiples_of_16 and not is_filtered_fqn

    def precompute_float8_dynamic_scale_for_fsdp(
        self, model: nn.Module | list[nn.Module]
    ):
        if not self.enabled:
            return

        if not self.precompute_scale:
            return

        from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

        models = [model] if isinstance(model, nn.Module) else model
        for m in models:
            precompute_float8_dynamic_scale_for_fsdp(m)


register_model_converter(Float8Converter, "float8")
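
The divisible-by-16 rule in `_module_filter_fn` reflects float8 tensorcore shape constraints. Below is a standalone restatement of that predicate for illustration; `should_convert` is a hypothetical helper, not part of the commit.

import torch.nn as nn


def should_convert(mod: nn.Module, fqn: str, filter_fqns: list[str]) -> bool:
    # Convert only Linear layers whose weight dims are multiples of 16
    # and whose fully-qualified name is not explicitly filtered out.
    if not isinstance(mod, nn.Linear):
        return False
    dims_ok = mod.weight.shape[0] % 16 == 0 and mod.weight.shape[1] % 16 == 0
    return dims_ok and not any(f in fqn for f in filter_fqns)


assert should_convert(nn.Linear(4096, 1024), "layers.0.ffn.w1", ["output"])
assert not should_convert(nn.Linear(4096, 1000), "lm_head", [])  # 1000 % 16 != 0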
torchtitan/components/optimizer.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import functools
|
| 8 |
+
from typing import Any, Generic, Iterator, TypeVar
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from torch.distributed.checkpoint.state_dict import (
|
| 13 |
+
get_optimizer_state_dict,
|
| 14 |
+
set_optimizer_state_dict,
|
| 15 |
+
StateDictOptions,
|
| 16 |
+
)
|
| 17 |
+
from torch.distributed.checkpoint.stateful import Stateful
|
| 18 |
+
from torch.optim import Optimizer
|
| 19 |
+
|
| 20 |
+
from torchtitan.components.ft import FTManager, has_torchft
|
| 21 |
+
from torchtitan.config_manager import JobConfig
|
| 22 |
+
|
| 23 |
+
__all__ = [
|
| 24 |
+
"OptimizersContainer",
|
| 25 |
+
"build_optimizers",
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
if has_torchft:
|
| 30 |
+
import torchft as ft
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
T = TypeVar("T", bound=Optimizer)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class OptimizersContainer(Optimizer, Stateful, Generic[T]):
|
| 37 |
+
"""A container for multiple optimizers.
|
| 38 |
+
|
| 39 |
+
This class is used to wrap multiple optimizers into a single object that can be
|
| 40 |
+
used to reduce the complexity of the training loop. This mimics the behavior of
|
| 41 |
+
``torch.optim.Optimizer``. This class currently only supports ``Adam`` and ``AdamW``.
|
| 42 |
+
|
| 43 |
+
**Note**
|
| 44 |
+
Users who want to customize the optimizer behavior can inherit from this class and
|
| 45 |
+
extend the functionality as needed. The following methods must follow the same signature
|
| 46 |
+
as ``torch.optim.Optimizer`` class: ``step()``, ``zero_grad()``, ``state_dict()``,
|
| 47 |
+
``load_state_dict()``.
|
| 48 |
+
|
| 49 |
+
**Limitations**
|
| 50 |
+
This class assumes that all the optimizers are the same type and have the same
|
| 51 |
+
configurations. With this assumption, TorchTitan can support lr scheduler resharding
|
| 52 |
+
(e.g., loading a checkpoint with a different number of GPUs and/or different
|
| 53 |
+
parallelization strategy). Note that ``get_optimizer_state_dict`` already enables the
|
| 54 |
+
resharding for the optimizer state but not for the lr scheduler state, hence the limitation.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
model_parts (List[nn.Module]): List of model parts to be optimized.
|
| 58 |
+
optimizer_kwargs (Dict[str, Any]): Keyword arguments for the optimizers.
|
| 59 |
+
name (str): Name of the optimizers.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
optimizers: list[T]
|
| 63 |
+
model_parts: list[nn.Module]
|
| 64 |
+
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
model_parts: list[nn.Module],
|
| 68 |
+
optimizer_cls: type[T],
|
| 69 |
+
optimizer_kwargs: dict[str, Any],
|
| 70 |
+
) -> None:
|
| 71 |
+
all_params = []
|
| 72 |
+
self.optimizers = []
|
| 73 |
+
self.model_parts = model_parts
|
| 74 |
+
for model in self.model_parts:
|
| 75 |
+
params = [p for p in model.parameters() if p.requires_grad]
|
| 76 |
+
self.optimizers.append(optimizer_cls(params, **optimizer_kwargs))
|
| 77 |
+
all_params.extend(params)
|
| 78 |
+
self._validate_length(len(self.model_parts))
|
| 79 |
+
self._post_init(all_params, optimizer_kwargs)
|
| 80 |
+
|
| 81 |
+
def __iter__(self) -> Iterator[T]:
|
| 82 |
+
return iter(self.optimizers)
|
| 83 |
+
|
| 84 |
+
def __len__(self) -> int:
|
| 85 |
+
return len(self.optimizers)
|
| 86 |
+
|
| 87 |
+
def step(self, *args, **kwargs) -> None:
|
| 88 |
+
for optimizer in self.optimizers:
|
| 89 |
+
optimizer.step(*args, **kwargs)
|
| 90 |
+
|
| 91 |
+
def zero_grad(self, *args, **kwargs) -> None:
|
| 92 |
+
for optimizer in self.optimizers:
|
| 93 |
+
optimizer.zero_grad(*args, **kwargs)
|
| 94 |
+
|
| 95 |
+
def state_dict(self) -> dict[str, Any]:
|
| 96 |
+
func = functools.partial(
|
| 97 |
+
get_optimizer_state_dict,
|
| 98 |
+
options=StateDictOptions(flatten_optimizer_state_dict=True),
|
| 99 |
+
)
|
| 100 |
+
return {
|
| 101 |
+
k: v
|
| 102 |
+
for sd in map(func, self.model_parts, self.optimizers)
|
| 103 |
+
for k, v in sd.items()
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
| 107 |
+
func = functools.partial(
|
| 108 |
+
set_optimizer_state_dict,
|
| 109 |
+
optim_state_dict=state_dict,
|
| 110 |
+
options=StateDictOptions(flatten_optimizer_state_dict=True),
|
| 111 |
+
)
|
| 112 |
+
list(map(func, self.model_parts, self.optimizers))
|
| 113 |
+
|
| 114 |
+
def _validate_length(self, expected_length: int) -> None:
|
| 115 |
+
assert expected_length == len(self.optimizers), (
|
| 116 |
+
"Must pass one optimizer per model part or per param if "
|
| 117 |
+
"using OptimizersInBackwardContainer."
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
def _post_init(
|
| 121 |
+
self, all_params: list[nn.Parameter], optimizer_kwargs: dict[str, Any]
|
| 122 |
+
) -> None:
|
| 123 |
+
# We need to call Optimizer.__init__() to initialize some necessary optimizer
|
| 124 |
+
# functionality such as hooks.
|
| 125 |
+
Optimizer.__init__(self, all_params, optimizer_kwargs)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class OptimizersInBackwardContainer(OptimizersContainer):
|
| 129 |
+
"""OptimizersContainer for executing ``optim.step()`` in backward pass.
|
| 130 |
+
|
| 131 |
+
This class extend ``OptimizersContainer`` to support optimizer step in
|
| 132 |
+
backward pass. ``step()`` and ``zero_grad()`` are no-op in this class.
|
| 133 |
+
Instead, ``register_post_accumulate_grad_hook`` is used to register a hook to
|
| 134 |
+
execute these methods when the gradient is accumulated.
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
def __init__(
|
| 138 |
+
self,
|
| 139 |
+
model_parts: list[nn.Module],
|
| 140 |
+
optimizer_cls: type[T],
|
| 141 |
+
optimizer_kwargs: dict[str, Any],
|
| 142 |
+
) -> None:
|
| 143 |
+
all_params = []
|
| 144 |
+
self.model_parts = model_parts
|
| 145 |
+
|
| 146 |
+
optim_dict = {}
|
| 147 |
+
for model in self.model_parts:
|
| 148 |
+
for p in model.parameters():
|
| 149 |
+
if p.requires_grad:
|
| 150 |
+
optim_dict[p] = optimizer_cls([p], **optimizer_kwargs)
|
| 151 |
+
all_params.append(p)
|
| 152 |
+
|
| 153 |
+
def optim_hook(param) -> None:
|
| 154 |
+
optim_dict[param].step()
|
| 155 |
+
optim_dict[param].zero_grad()
|
| 156 |
+
|
| 157 |
+
for model in self.model_parts:
|
| 158 |
+
for param in model.parameters():
|
| 159 |
+
if param.requires_grad:
|
| 160 |
+
param.register_post_accumulate_grad_hook(optim_hook)
|
| 161 |
+
|
| 162 |
+
self.optimizers = list(optim_dict.values())
|
| 163 |
+
|
| 164 |
+
self._validate_length(
|
| 165 |
+
sum(len(list(model.parameters())) for model in self.model_parts)
|
| 166 |
+
)
|
| 167 |
+
self._post_init(all_params, optimizer_kwargs)
|
| 168 |
+
|
| 169 |
+
def step(self) -> None:
|
| 170 |
+
pass
|
| 171 |
+
|
| 172 |
+
def zero_grad(self) -> None:
|
| 173 |
+
pass
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class FTOptimizersContainer(OptimizersContainer):
|
| 177 |
+
def __init__(
|
| 178 |
+
self,
|
| 179 |
+
model_parts: list[nn.Module],
|
| 180 |
+
optimizer_cls: type[T],
|
| 181 |
+
optimizer_kwargs: dict[str, Any],
|
| 182 |
+
ft_manager: "ft.Manager",
|
| 183 |
+
) -> None:
|
| 184 |
+
super().__init__(model_parts, optimizer_cls, optimizer_kwargs)
|
| 185 |
+
|
| 186 |
+
# Force to initialize the optimizer state so that `optim.step()`
|
| 187 |
+
        # won't be called by state_dict() and load_state_dict().
        _ = {
            k: v
            for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
            for k, v in sd.items()
        }
        self.cache_state_dict: dict[str, Any] = {}
        self._ft_optimizer = ft.Optimizer(ft_manager, self)
        self._call_from_ft: bool = False

    def init_cache_state_dict(self) -> None:
        self.cache_state_dict = super().state_dict()

    def state_dict(self) -> dict[str, Any]:
        return self.cache_state_dict

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # We have to invalidate the `cache_state_dict` because the optimizer uses
        # assign instead of copy when doing `load_state_dict()`. Without
        # invalidating the `cache_state_dict`, there will be memory leakage.
        self.cache_state_dict = {}
        super().load_state_dict(state_dict)
        self.init_cache_state_dict()

    def step(self, *args, **kwargs) -> None:
        """Call the correct step() depending on the caller.

        TorchFT's OptimizerWrapper.step() is designed to be called only once
        per train step per ft.Manager, regardless of how many optimizers are used.
        Hence we need to dispatch the call appropriately.
        """
        if self._call_from_ft:
            super().step(*args, **kwargs)
        else:
            self._call_from_ft = True
            self._ft_optimizer.step(*args, **kwargs)
            self._call_from_ft = False

    def zero_grad(self, *args, **kwargs) -> None:
        """Call the correct zero_grad() depending on the caller.

        See the comment in ``step()``.
        """
        if self._call_from_ft:
            super().zero_grad(*args, **kwargs)
        else:
            self._call_from_ft = True
            self._ft_optimizer.zero_grad(*args, **kwargs)
            self._call_from_ft = False


def build_optimizers(
    model_parts: list[nn.Module],
    job_config: JobConfig,
    ft_manager: FTManager,
) -> OptimizersContainer:
    """Create an OptimizersContainer for the given model parts and job config.

    This function creates an ``OptimizersContainer`` for the given model parts.
    ``job_config`` should define the correct optimizer name and parameters.
    This function currently supports creating ``OptimizersContainer`` and
    ``OptimizersInBackwardContainer``.

    **Note**
    Users who want to customize the optimizer behavior can create their own
    ``OptimizersContainer`` subclass and ``build_optimizers``. Passing the
    customized ``build_optimizers`` to ``TrainSpec`` will create the customized
    ``OptimizersContainer``.

    Args:
        model_parts (List[nn.Module]): List of model parts to be optimized.
        job_config (JobConfig): Job config containing the optimizer name and parameters.
    """
    optim_in_bwd = job_config.optimizer.early_step_in_backward
    if optim_in_bwd and job_config.parallelism.pipeline_parallel_degree > 1:
        raise NotImplementedError(
            "Optimizers in backward is not supported with pipeline parallelism."
        )
    name = job_config.optimizer.name
    lr = job_config.optimizer.lr
    eps = job_config.optimizer.eps

    optim_implementation = job_config.optimizer.implementation
    assert optim_implementation in ["fused", "foreach", "for-loop"]

    fused = optim_implementation == "fused"
    foreach = optim_implementation == "foreach"

    optimizer_kwargs = {
        "lr": lr,
        "eps": eps,
        "betas": (0.9, 0.95),
        "weight_decay": 0.1,
        "fused": fused,
        "foreach": foreach,
    }

    optimizer_classes = {
        "Adam": torch.optim.Adam,
        "AdamW": torch.optim.AdamW,
    }
    if name not in optimizer_classes:
        raise NotImplementedError(f"Optimizer {name} not added.")
    optimizer_cls = optimizer_classes[name]

    if optim_in_bwd and ft_manager.enabled:
        raise ValueError("TorchFT is not supported with optimizers in backward.")
    elif optim_in_bwd:
        return OptimizersInBackwardContainer(
            model_parts, optimizer_cls, optimizer_kwargs
        )
    elif ft_manager.enabled:
        return FTOptimizersContainer(
            model_parts, optimizer_cls, optimizer_kwargs, ft_manager.manager
        )
    else:
        return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
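The re-entrant dispatch above is the subtle part of `FTOptimizersContainer`: `step()` hands control to TorchFT's wrapper, which calls back into the same container with `_call_from_ft` set so the second entry runs the real optimizer step. A minimal self-contained sketch of that round trip, with a hypothetical `FTStub` standing in for TorchFT's `ft.Optimizer`:

```python
# Sketch of the re-entrant step() dispatch pattern used by FTOptimizersContainer.
# FTStub is a hypothetical stand-in for ft.Optimizer: it does its bookkeeping,
# then calls back into the wrapped container exactly once.
class FTStub:
    def __init__(self, container):
        self._container = container

    def step(self):
        print("ft wrapper: pre-step bookkeeping")
        self._container.step()  # re-enters the container with _call_from_ft=True


class Container:
    def __init__(self):
        self._ft_optimizer = FTStub(self)
        self._call_from_ft = False

    def step(self):
        if self._call_from_ft:
            print("container: running the real optimizer step")
        else:
            self._call_from_ft = True
            self._ft_optimizer.step()
            self._call_from_ft = False


Container().step()
# ft wrapper: pre-step bookkeeping
# container: running the real optimizer step
```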
torchtitan/datasets/__pycache__/hf_datasets.cpython-312.pyc
ADDED
Binary file (7.04 kB).

torchtitan/datasets/hf_datasets.py
ADDED
@@ -0,0 +1,173 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Callable

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger


def _load_c4_dataset(dataset_path: str):
    """Load the C4 dataset with the default configuration."""
    return load_dataset(dataset_path, name="en", split="train", streaming=True)


def _process_c4_text(sample: dict[str, Any]) -> str:
    """Process a C4 dataset sample into text."""
    return sample["text"]


@dataclass
class DatasetConfig:
    path: str
    loader: Callable
    text_processor: Callable


# Add your dataset here - more information at docs/datasets.md
DATASETS = {
    "c4": DatasetConfig(
        path="allenai/c4",
        loader=_load_c4_dataset,
        text_processor=_process_c4_text,
    ),
    "c4_test": DatasetConfig(
        path="tests/assets/c4_test",
        loader=lambda path: load_dataset(path, split="train"),
        text_processor=_process_c4_text,
    ),
}


def _validate_dataset(
    dataset_name: str, dataset_path: str | None = None
) -> tuple[str, Callable, Callable]:
    """Validate the dataset name and path."""
    if dataset_name not in DATASETS:
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {list(DATASETS.keys())}"
        )

    config = DATASETS[dataset_name]
    path = dataset_path or config.path
    logger.info(f"Preparing {dataset_name} dataset from {path}")
    return path, config.loader, config.text_processor


class HuggingFaceDataset(IterableDataset, Stateful):
    def __init__(
        self,
        dataset_name: str,
        dataset_path: str | None,
        tokenizer: Tokenizer,
        seq_len: int = 2048,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:
        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, text_processor = _validate_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        self.dataset_name = dataset_name
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
        self._tokenizer = tokenizer
        self.seq_len = seq_len
        self.infinite = infinite
        self._text_processor = text_processor

        # Variables for checkpointing
        self._sample_idx = 0
        self._all_tokens: list[int] = []

    def _get_data_iter(self):
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        it = iter(self._data)
        for _ in range(self._sample_idx):
            next(it)
        return it

    def __iter__(self):
        max_buffer_token_len = 1 + self.seq_len

        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific text processor
                sample_text = self._text_processor(sample)
                sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
                self._all_tokens.extend(sample_tokens)
                self._sample_idx += 1

                while len(self._all_tokens) >= max_buffer_token_len:
                    x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
                    # update tokens to the remaining tokens
                    self._all_tokens = self._all_tokens[max_buffer_token_len:]
                    input = x[:-1]
                    label = x[1:]
                    yield {"input": input}, label

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
        self._sample_idx = state_dict["sample_idx"]
        self._all_tokens = state_dict["token_buffer"]

    def state_dict(self):
        return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


def build_hf_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer: Tokenizer,
    job_config: JobConfig,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets."""
    dataset_name = job_config.training.dataset
    dataset_path = job_config.training.dataset_path
    batch_size = job_config.training.batch_size
    seq_len = job_config.training.seq_len

    hf_ds = HuggingFaceDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=hf_ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )
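The `__iter__` loop above packs tokenized samples into fixed windows of `seq_len + 1` tokens, then shifts by one to form input/label pairs; leftover tokens stay buffered (and are checkpointed via `state_dict`). A minimal sketch of that packing logic, using a made-up token stream:

```python
import torch

seq_len = 4  # toy value; torchtitan reads this from job_config.training.seq_len
max_buffer_token_len = 1 + seq_len
buffer: list[int] = []

# A made-up token stream standing in for per-sample tokenizer.encode() output.
for sample_tokens in [[1, 2, 3], [4, 5, 6, 7], [8, 9]]:
    buffer.extend(sample_tokens)
    while len(buffer) >= max_buffer_token_len:
        x = torch.LongTensor(buffer[:max_buffer_token_len])
        buffer = buffer[max_buffer_token_len:]  # keep the remainder for the next window
        print("input:", x[:-1].tolist(), "label:", x[1:].tolist())
# input: [1, 2, 3, 4] label: [2, 3, 4, 5]  (tokens 6..9 remain buffered)
```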
torchtitan/datasets/tokenizer/__pycache__/tiktoken.cpython-312.pyc
ADDED
Binary file (7.73 kB).

torchtitan/datasets/tokenizer/tiktoken.py
ADDED
@@ -0,0 +1,190 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import os
from collections.abc import Collection, Iterator, Sequence, Set as AbstractSet
from pathlib import Path
from typing import cast, Literal

import tiktoken
from tiktoken.load import load_tiktoken_bpe

from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger


class TikTokenizer(Tokenizer):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.

    Args:
        model_path (str): The path to the Tiktoken model file.
    """

    special_tokens: dict[str, int]

    num_reserved_special_tokens = 256

    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501, B950

    def __init__(self, model_path: str):
        super().__init__()
        assert os.path.exists(
            model_path
        ), f"The tokenizer path does not exist: {model_path}"
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        self._n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.pad_id: int = -1
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }
        logger.info(
            f"TikTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}"
        )

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Literal["all"] | AbstractSet[str] | None = None,
        disallowed_special: Literal["all"] | Collection[str] | None = None,
    ) -> list[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): special tokens allowed in the string
            disallowed_special ("all"|Collection[str]): special tokens that raise an error when found in the string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will cause all text corresponding
          to special tokens to be encoded as special tokens.
        """
        assert type(s) is str
        allowed_special = allowed_special or set()
        disallowed_special = disallowed_special or ()

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: list[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (Sequence[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(list[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than
        `max_consecutive_slice_len` consecutive whitespace or consecutive
        non-whitespace characters.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]


def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer:
    return TikTokenizer(job_config.model.tokenizer_path)
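Assuming a local tiktoken BPE file compatible with the Llama 3 special-token layout above (the `tokenizer.model` path below is hypothetical), a round trip through `encode`/`decode` looks like this:

```python
# Hedged usage sketch; "tokenizer.model" is a hypothetical local path to a
# tiktoken BPE file matching the special-token layout defined in TikTokenizer.
tokenizer = TikTokenizer("tokenizer.model")

ids = tokenizer.encode("hello world", bos=True, eos=True)
assert ids[0] == tokenizer.bos_id and ids[-1] == tokenizer.eos_id
print(tokenizer.decode(ids))  # special tokens are rendered back as text
```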
torchtitan/distributed/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (255 Bytes).

torchtitan/distributed/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (14.9 kB).

torchtitan/experiments/deepseek_v3/inference.sh
ADDED
@@ -0,0 +1,15 @@
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

NGPU=${NGPU:-"4"}

# Get the prompt from command line argument or use a default
prompt="${1:-What is 2+2?}"

# Run the model with the prompt
torchrun --standalone --nproc-per-node ${NGPU} generate.py "$prompt"
torchtitan/experiments/deepseek_v3/model_config.py
ADDED
@@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field


@dataclass
class ModelArgs:
    r"""
    This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate a DeepSeek
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a configuration similar to that of DeepSeek-V3.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 129280):
            Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`DeepseekV3Model`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        moe_intermediate_size (`int`, *optional*, defaults to 1407):
            Dimension of the MoE representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_nextn_predict_layers (`int`, *optional*, defaults to 1):
            Number of next-n predict layers in the DeepSeekV3 model.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        n_shared_experts (`int`, *optional*, defaults to None):
            Number of shared experts; None means a dense model.
        n_routed_experts (`int`, *optional*, defaults to None):
            Number of routed experts; None means a dense model.
        routed_scaling_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for routed experts.
        topk_method (`str`, *optional*, defaults to `greedy`):
            Top-k method used in the routed gate.
        n_group (`int`, *optional*, defaults to None):
            Number of groups for routed experts.
        topk_group (`int`, *optional*, defaults to None):
            Number of selected groups for each token (for each token, the selected experts are guaranteed to be within
            `topk_group` groups).
        num_experts_per_tok (`int`, *optional*, defaults to None):
            Number of selected experts; None means a dense model.
        moe_layer_freq (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
        first_k_dense_replace (`int`, *optional*, defaults to 0):
            Number of dense layers in the shallow layers (embed -> dense -> dense -> ... -> dense -> moe -> moe -> ... -> lm_head;
            the first `first_k_dense_replace` layers are the dense ones).
        norm_topk_prob (`bool`, *optional*, defaults to False):
            Whether to normalize the weights of the routed experts.
        scoring_func (`str`, *optional*, defaults to 'softmax'):
            Method of computing expert weights.
        aux_loss_alpha (`float`, *optional*, defaults to 0.001):
            Auxiliary loss weight coefficient.
        seq_aux (`bool`, *optional*, defaults to True):
            Whether to compute the auxiliary loss for each individual sample.
        num_key_value_heads (`int`, *optional*):
            This is the number of key/value heads used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be constructed
            by mean-pooling all the original heads within that group. For more details, check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, defaults to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the RMS normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/value attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
    """

    vocab_size: int = 129280
    hidden_size: int = 7168
    intermediate_size: int = 18432
    moe_intermediate_size: int = 2048
    num_hidden_layers: int = 61
    num_nextn_predict_layers: int = 1
    num_attention_heads: int = 128
    num_key_value_heads: int = 128
    n_shared_experts: int = 1
    n_routed_experts: int = 256
    ep_size: int = 1
    routed_scaling_factor: float = 2.5
    kv_lora_rank: int = 512
    q_lora_rank: int = 1536
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    qk_nope_head_dim: int = 128
    topk_method: str = "noaux_tc"
    n_group: int = 8
    topk_group: int = 4
    num_experts_per_tok: int = 8
    moe_layer_freq: int = 1
    first_k_dense_replace: int = 3
    norm_topk_prob: bool = True
    scoring_func: str = "sigmoid"
    aux_loss_alpha: float = 0.001
    seq_aux: bool = True
    hidden_act: str = "silu"
    max_position_embeddings: int = 163840
    initializer_range: float = 0.02
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    rope_scaling: dict = field(
        default_factory=lambda: {
            "beta_fast": 32,
            "beta_slow": 1,
            "factor": 40,
            "mscale": 1.0,
            "mscale_all_dim": 1.0,
            "original_max_position_embeddings": 4096,
            "type": "yarn",
        }
    )
    attention_bias: bool = False
    attention_dropout: float = 0.0
    pad_token_id = None
    # Added for symmetric memory
    max_seq_len: int = 4096
    dtype: str = "bfloat16"
    # Added for pipeline parallel
    num_stages: int = 1
    stage_idx: int = 0


# This is the configuration for deepseek-ai/DeepSeek-V2-Lite.
deepseek_v2_lite_config = ModelArgs(
    vocab_size=102400,
    hidden_size=2048,
    intermediate_size=10944,
    moe_intermediate_size=1408,
    num_hidden_layers=27,
    num_attention_heads=16,
    num_key_value_heads=16,
    n_shared_experts=2,
    n_routed_experts=64,
    routed_scaling_factor=1.0,
    kv_lora_rank=512,
    q_lora_rank=None,
    qk_rope_head_dim=64,
    v_head_dim=128,
    qk_nope_head_dim=128,
    topk_method="greedy",
    n_group=1,
    topk_group=1,
    num_experts_per_tok=6,
    first_k_dense_replace=1,
    norm_topk_prob=False,
    scoring_func="softmax",
    max_position_embeddings=4096,
    rope_scaling={
        "beta_fast": 32,
        "beta_slow": 1,
        "factor": 40,
        "mscale": 0.707,
        "mscale_all_dim": 0.707,
        "original_max_position_embeddings": 4096,
        "type": "yarn",
    },
)


# Model configuration registry
# Key is the model distribution ID on HuggingFace Hub
deepseek_config_registry = {
    "deepseek-ai/DeepSeek-V2-Lite": deepseek_v2_lite_config,
    "deepseek-ai/DeepSeek-V2-Lite-Chat": deepseek_v2_lite_config,
    "deepseek-ai/deepseek-v3": ModelArgs(),
}
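Since `deepseek_config_registry` keys on the HuggingFace distribution ID, selecting a configuration is a plain dictionary lookup. A minimal sketch:

```python
# Resolve a config from the registry defined above; the key is the
# HuggingFace Hub distribution ID.
args = deepseek_config_registry["deepseek-ai/DeepSeek-V2-Lite"]
print(args.hidden_size, args.n_routed_experts)  # 2048 64
```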
torchtitan/experiments/flux/README.md
ADDED
@@ -0,0 +1,23 @@
# FLUX model in torchtitan

## Overview

## Usage
First, download the autoencoder model from HuggingFace with your own access token:
```bash
python torchtitan/experiments/flux/scripts/download_autoencoder.py --repo_id black-forest-labs/FLUX.1-dev --ae_path ae.safetensors --hf_token <your_access_token>
```
This step downloads the autoencoder model from HuggingFace and saves it to `torchtitan/experiments/flux/assets/autoencoder/ae.safetensors`.

Run the following command to train the model on a single GPU:
```bash
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=1 torchtitan/experiments/flux/train.py --job.config_file torchtitan/experiments/flux/train_configs/debug_model.toml
```

## TODO
- [ ] Support for multiple GPUs is coming soon (FSDP, etc.)
- [ ] Implement test cases in CI for the FLUX model. Add more unit tests (e.g., a unit test for the preprocessor)
- [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc.)
- [ ] Support for distributed checkpointing and loading
- [ ] Implement an init_weights() function to initialize the model weights
- [ ] Implement the num_flops_per_token calculation in the get_nparams_and_flops() function
torchtitan/experiments/flux/__pycache__/parallelize_flux.cpython-312.pyc
ADDED
Binary file (648 Bytes).

torchtitan/experiments/flux/flux_argparser.py
ADDED
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import torch


def extend_parser(parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        "--training.guidance",
        type=float,
        default=3.5,
        help="guidance value used for guidance distillation",
    )
    parser.add_argument(
        "--encoder.t5_encoder",
        type=str,
        default="google/t5-v1_1-small",
        help="T5 encoder to use, HuggingFace model name.",
    )
    parser.add_argument(
        "--encoder.clip_encoder",
        type=str,
        default="openai/clip-vit-large-patch14",
        help="Clip encoder to use, HuggingFace model name.",
    )
    parser.add_argument(
        "--encoder.encoder_dtype",
        type=torch.dtype,
        default=torch.bfloat16,
        help="Which dtype to load for autoencoder.",
    )
    parser.add_argument(
        "--encoder.max_t5_encoding_len",
        type=int,
        default=512,
        help="Maximum length of the T5 encoding.",
    )
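On its own, `extend_parser` just registers the FLUX-specific flags on an existing parser; the dotted flag names become attribute names that contain dots. An illustrative standalone use (not how `JobConfig` wires it in):

```python
import argparse

parser = argparse.ArgumentParser()
extend_parser(parser)  # registers the FLUX flags defined above
args = parser.parse_args(["--training.guidance", "4.0"])
# The dest contains a dot, so plain attribute access won't work; use getattr:
print(getattr(args, "training.guidance"))  # 4.0
```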
torchtitan/experiments/flux/loss.py
ADDED
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, TypeAlias

import torch

from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger

LossFunction: TypeAlias = Callable[..., torch.Tensor]


def mse_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Common MSE loss function for training Transformer models."""
    return torch.nn.functional.mse_loss(pred.float(), labels.float().detach())


def build_mse_loss(job_config: JobConfig):
    loss_fn = mse_loss
    if job_config.training.compile:
        logger.info("Compiling the loss function with torch.compile")
        loss_fn = torch.compile(loss_fn)
    return loss_fn
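A quick sketch of calling `mse_loss` directly, with placeholder tensors:

```python
import torch

pred = torch.randn(2, 8)
labels = torch.randn(2, 8)
# labels are detached inside mse_loss, so no gradient flows to them
print(mse_loss(pred, labels).item())
```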