msj19 commited on
Commit
52abc1d
·
verified ·
1 Parent(s): 4c0cd64

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. create_yaml.py +115 -0
  2. fla3/ops/path_attn/__pycache__/__init__.cpython-312.pyc +0 -0
  3. fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_bwd.cpython-310.pyc +0 -0
  4. fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_bwd_prepare.cpython-310.pyc +0 -0
  5. fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_fwd.cpython-310.pyc +0 -0
  6. fla3/ops/path_attn/__pycache__/parallel.cpython-312.pyc +0 -0
  7. fla3/ops/path_attn/__pycache__/parallel_path_bwd_inter_dkv.cpython-310.pyc +0 -0
  8. fla3/ops/path_attn/__pycache__/parallel_path_bwd_intra.cpython-310.pyc +0 -0
  9. fla3/ops/path_attn/__pycache__/prepare_k_cache.cpython-310.pyc +0 -0
  10. fla3/ops/retention/__pycache__/fused_chunk.cpython-310.pyc +0 -0
  11. fla3/ops/rwkv7/__pycache__/__init__.cpython-310.pyc +0 -0
  12. fla3/ops/rwkv7/__pycache__/__init__.cpython-312.pyc +0 -0
  13. fla3/ops/rwkv7/__pycache__/chunk.cpython-310.pyc +0 -0
  14. fla3/ops/rwkv7/__pycache__/chunk.cpython-312.pyc +0 -0
  15. fla3/ops/rwkv7/__pycache__/fused_k_update.cpython-310.pyc +0 -0
  16. fla3/ops/rwkv7/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
  17. fla3/ops/rwkv7/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  18. fla3/ops/simple_gla/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  19. fla3/ops/simple_gla/__pycache__/parallel.cpython-310.pyc +0 -0
  20. fla3/ops/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  21. fla3/ops/utils/__pycache__/asm.cpython-310.pyc +0 -0
  22. fla3/ops/utils/__pycache__/cumsum.cpython-310.pyc +0 -0
  23. fla3/ops/utils/__pycache__/cumsum.cpython-312.pyc +0 -0
  24. fla3/ops/utils/__pycache__/index.cpython-310.pyc +0 -0
  25. fla3/ops/utils/__pycache__/index.cpython-312.pyc +0 -0
  26. fla3/ops/utils/__pycache__/logcumsumexp.cpython-310.pyc +0 -0
  27. fla3/ops/utils/__pycache__/logsumexp.cpython-312.pyc +0 -0
  28. fla3/ops/utils/__pycache__/matmul.cpython-310.pyc +0 -0
  29. fla3/ops/utils/__pycache__/op.cpython-312.pyc +0 -0
  30. fla3/ops/utils/__pycache__/pack.cpython-310.pyc +0 -0
  31. fla3/ops/utils/__pycache__/pack.cpython-312.pyc +0 -0
  32. fla3/ops/utils/__pycache__/pooling.cpython-310.pyc +0 -0
  33. fla3/ops/utils/__pycache__/softmax.cpython-310.pyc +0 -0
  34. fla3/ops/utils/__pycache__/softmax.cpython-312.pyc +0 -0
  35. fla3/ops/utils/__pycache__/solve_tril.cpython-312.pyc +0 -0
  36. fla3/ops/utils/logsumexp.py +80 -0
  37. fla3/ops/utils/matmul.py +245 -0
  38. fla3/ops/utils/op.py +39 -0
  39. fla3/ops/utils/pack.py +208 -0
  40. fla3/ops/utils/pooling.py +207 -0
  41. flame/__init__.py +0 -0
  42. flame/__pycache__/__init__.cpython-310.pyc +0 -0
  43. flame/__pycache__/__init__.cpython-312.pyc +0 -0
  44. flame/__pycache__/data.cpython-310.pyc +0 -0
  45. flame/__pycache__/logging.cpython-310.pyc +0 -0
  46. flame/__pycache__/logging.cpython-312.pyc +0 -0
  47. flame/__pycache__/parser.cpython-310.pyc +0 -0
  48. flame/__pycache__/parser.cpython-312.pyc +0 -0
  49. flame/data.py +246 -0
  50. flame/parser.py +94 -0
create_yaml.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import yaml
4
+ import sys
5
+
6
def main():
    """Generate the accelerate/DeepSpeed launch config and a hostfile.

    Reads the rendezvous endpoint and cluster topology from environment
    variables (``MASTER_ADDR``/``MASTER_PORT`` and the ``SENSECORE_*``
    variables, falling back to single-node defaults), writes an accelerate
    YAML config, then extracts ``Hostname`` entries from an SSH config into a
    DeepSpeed hostfile.

    Returns:
        int: 0 on success; used as the process exit code.
    """
    # Rendezvous endpoint; defaults match a local single-node run.
    master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
    master_port = int(os.getenv("MASTER_PORT", 29500))

    def _env_int(name: str, default: int) -> int:
        # Parse an integer env var, falling back to `default` on any bad value.
        # (Replaces three copy-pasted try/except blocks in the original.)
        try:
            return int(os.getenv(name, default))
        except (ValueError, TypeError):
            return default

    num_nodes = _env_int("SENSECORE_PYTORCH_NNODES", 1)
    gpus_per_node = _env_int("SENSECORE_ACCELERATE_DEVICE_COUNT", 1)
    node_rank = _env_int("SENSECORE_PYTORCH_NODE_RANK", 0)

    # One process per GPU across all nodes.
    num_processes = num_nodes * gpus_per_node

    # Accelerate launch configuration.
    config = {
        "compute_environment": "LOCAL_MACHINE",
        "distributed_type": "DEEPSPEED",
        "deepspeed_config": {
            "deepspeed_config_file": "configs/ds_config.json",
            "zero3_init_flag": True,
            "deepspeed_multinode_launcher": "standard",
            "deepspeed_hostfile": '/mnt/jfzn/msj/flash-linear-attention/legacy/training/hostfile.txt',
        },
        "machine_rank": node_rank,
        "main_process_ip": master_addr,
        "main_process_port": master_port,
        "main_training_function": "main",
        "num_machines": num_nodes,
        "num_processes": num_processes,
        "same_network": True,
        "use_cpu": False,
        "rdzv_backend": "c10d",
        "tpu_env": [],
        "tpu_use_cluster": False,
        "tpu_use_sudo": False,
    }

    # Print the derived configuration for the launch logs.
    print("Generated Configuration:")
    print(f" Master: {master_addr}:{master_port}")
    print(f" Number of nodes: {num_nodes}")
    print(f" GPUs per node: {gpus_per_node}")
    print(f" Total processes: {num_processes}")
    print(f" Node rank: {node_rank}")

    # Write the YAML config. Fix: create the directory the file actually lives
    # in — the original created a relative "configs" dir unrelated to the
    # absolute output path below.
    yaml_path = "/mnt/jfzn/msj/flash-linear-attention/legacy/training/configs/deepspeed_sencore.yaml"
    os.makedirs(os.path.dirname(yaml_path), exist_ok=True)
    with open(yaml_path, "w") as f:
        yaml.dump(config, f, default_flow_style=False)

    print(f"\nConfiguration saved to: {yaml_path}")

    # Echo the file back for inspection.
    print("\nFile content:")
    with open(yaml_path, "r") as f:
        print(f.read())

    # Build the DeepSpeed hostfile from the SSH config's Hostname entries.
    # Distinct names fix the original's reuse of `output_file` for two
    # different files.
    ssh_config_path = '/mnt/jfzn/msj/flash-linear-attention/legacy/training/ssh_config/config'
    hostfile_path = '/mnt/jfzn/msj/flash-linear-attention/legacy/training/hostfile.txt'

    hostnames = []
    with open(ssh_config_path, "r") as f:
        for line in f:
            line = line.strip()
            if line.startswith("Hostname"):
                # Take everything after the "Hostname" keyword; guard against a
                # bare "Hostname" line (the original raised IndexError there).
                parts = line.split(None, 1)
                if len(parts) > 1:
                    hostnames.append(parts[1])

    # One host per line. NOTE(review): slots=8 assumes 8 GPUs per host, which
    # may disagree with SENSECORE_ACCELERATE_DEVICE_COUNT — confirm.
    with open(hostfile_path, "w") as f:
        for host in hostnames:
            f.write(host + " slots=8\n")

    print(f"提取了 {len(hostnames)} 个 hostname,已写入 {hostfile_path}")

    return 0
113
+
114
# Script entry point: exit with main()'s return code.
if __name__ == "__main__":
    raise SystemExit(main())
fla3/ops/path_attn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (273 Bytes). View file
 
fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_bwd.cpython-310.pyc ADDED
Binary file (5.02 kB). View file
 
fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_bwd_prepare.cpython-310.pyc ADDED
Binary file (5.13 kB). View file
 
fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_fwd.cpython-310.pyc ADDED
Binary file (4.65 kB). View file
 
fla3/ops/path_attn/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (9.9 kB). View file
 
fla3/ops/path_attn/__pycache__/parallel_path_bwd_inter_dkv.cpython-310.pyc ADDED
Binary file (5.47 kB). View file
 
fla3/ops/path_attn/__pycache__/parallel_path_bwd_intra.cpython-310.pyc ADDED
Binary file (5.1 kB). View file
 
fla3/ops/path_attn/__pycache__/prepare_k_cache.cpython-310.pyc ADDED
Binary file (2.26 kB). View file
 
fla3/ops/retention/__pycache__/fused_chunk.cpython-310.pyc ADDED
Binary file (9.26 kB). View file
 
fla3/ops/rwkv7/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (350 Bytes). View file
 
fla3/ops/rwkv7/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (361 Bytes). View file
 
fla3/ops/rwkv7/__pycache__/chunk.cpython-310.pyc ADDED
Binary file (2.26 kB). View file
 
fla3/ops/rwkv7/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (2.56 kB). View file
 
fla3/ops/rwkv7/__pycache__/fused_k_update.cpython-310.pyc ADDED
Binary file (3.93 kB). View file
 
fla3/ops/rwkv7/__pycache__/fused_recurrent.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
fla3/ops/rwkv7/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (16.5 kB). View file
 
fla3/ops/simple_gla/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (4.78 kB). View file
 
fla3/ops/simple_gla/__pycache__/parallel.cpython-310.pyc ADDED
Binary file (17 kB). View file
 
fla3/ops/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.16 kB). View file
 
fla3/ops/utils/__pycache__/asm.cpython-310.pyc ADDED
Binary file (482 Bytes). View file
 
fla3/ops/utils/__pycache__/cumsum.cpython-310.pyc ADDED
Binary file (10.3 kB). View file
 
fla3/ops/utils/__pycache__/cumsum.cpython-312.pyc ADDED
Binary file (21.4 kB). View file
 
fla3/ops/utils/__pycache__/index.cpython-310.pyc ADDED
Binary file (3.12 kB). View file
 
fla3/ops/utils/__pycache__/index.cpython-312.pyc ADDED
Binary file (5.48 kB). View file
 
fla3/ops/utils/__pycache__/logcumsumexp.cpython-310.pyc ADDED
Binary file (1.54 kB). View file
 
fla3/ops/utils/__pycache__/logsumexp.cpython-312.pyc ADDED
Binary file (3.66 kB). View file
 
fla3/ops/utils/__pycache__/matmul.cpython-310.pyc ADDED
Binary file (5.29 kB). View file
 
fla3/ops/utils/__pycache__/op.cpython-312.pyc ADDED
Binary file (1.56 kB). View file
 
fla3/ops/utils/__pycache__/pack.cpython-310.pyc ADDED
Binary file (4.56 kB). View file
 
fla3/ops/utils/__pycache__/pack.cpython-312.pyc ADDED
Binary file (8.01 kB). View file
 
fla3/ops/utils/__pycache__/pooling.cpython-310.pyc ADDED
Binary file (5.61 kB). View file
 
fla3/ops/utils/__pycache__/softmax.cpython-310.pyc ADDED
Binary file (2.35 kB). View file
 
fla3/ops/utils/__pycache__/softmax.cpython-312.pyc ADDED
Binary file (4.92 kB). View file
 
fla3/ops/utils/__pycache__/solve_tril.cpython-312.pyc ADDED
Binary file (18.9 kB). View file
 
fla3/ops/utils/logsumexp.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ...ops.utils.op import exp, log
11
+
12
+
13
@triton.heuristics({
    'HAS_SCALE': lambda args: args['scale'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def logsumexp_fwd_kernel(
    x,  # input viewed as (N, D)
    z,  # output buffer of per-block partial results, shape (N, cdiv(D, B))
    scale,  # optional multiplicative scale applied before the reduction
    D: tl.constexpr,  # size of the reduced (last) dimension
    B: tl.constexpr,  # block size along D
    HAS_SCALE: tl.constexpr  # derived by the heuristic above
):
    """Compute a numerically stable partial logsumexp over one block of D.

    Each program reduces one (row, D-block) tile; the per-block results are
    combined on the torch side by the caller.
    """
    # int64 program ids so row/column offsets cannot overflow for large inputs.
    i_n, i_d = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    o_d = i_d * B + tl.arange(0, B)
    m_d = o_d < D

    # Out-of-bounds lanes read -inf so they contribute exp(-inf) = 0.
    b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf'))
    if HAS_SCALE:
        b_x = b_x * scale
    # Subtract the block max before exponentiating for numerical stability.
    b_m = tl.max(b_x, 0)
    b_z = log(tl.sum(exp(b_x - b_m), 0)) + b_m
    tl.store(z + i_n * tl.cdiv(D, B) + i_d, b_z)
42
+
43
+
44
def logsumexp_fwd(
    x,
    scale: Optional[float] = None,
    dtype: Optional[torch.dtype] = None
):
    r"""
    Compute the logsumexp of the input tensor over the last dimension.

    Args:
        x (Tensor):
            The input tensor of any shape.
        scale (Optional[float]):
            The scale applied to the input tensor. Default: `None`.
        dtype (Optional[torch.dtype]):
            The data type of the output tensor. Default: `None`.
    Returns:
        Tensor: The logsumexp of the input tensor.
    """
    batch_shape = x.shape[:-1]
    flat = x.view(-1, x.shape[-1])
    n_rows, dim = flat.shape
    # Cap the block size at 64K lanes; wider rows are split into blocks.
    block = min(triton.next_power_of_2(dim), 64 * 1024)
    n_blocks = triton.cdiv(dim, block)

    # Per-block partial logsumexps, reduced again on the torch side.
    partial = flat.new_empty(n_rows, n_blocks, dtype=torch.float)
    logsumexp_fwd_kernel[(n_rows, n_blocks)](
        x=flat,
        z=partial,
        scale=scale,
        D=dim,
        B=block
    )
    out = partial.logsumexp(-1).view(*batch_shape)
    if dtype is not None and dtype != torch.float:
        out = out.to(dtype)
    return out
fla3/ops/utils/matmul.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # code adapted from
5
+ # https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from ...ops.utils.op import exp
14
+ from ...utils import input_guard
15
+
16
+
17
+ # `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
18
+ # - A list of `triton.Config` objects that define different configurations of
19
+ # meta-parameters (e.g., `BM`) and compilation options (e.g., `num_warps`) to try
20
+ # - An auto-tuning *key* whose change in values will trigger evaluation of all the
21
+ # provided configs
22
@triton.heuristics({
    'HAS_ALPHA': lambda args: args['alpha'] is not None,
    'HAS_BETA': lambda args: args['beta'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BM': 128, 'BK': 64, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8),
        triton.Config({'BM': 64, 'BK': 32, 'BN': 256, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 64, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 64, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=5, num_warps=2),
        triton.Config({'BM': 32, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=5, num_warps=2),
        # Good config for fp8 inputs.
        # triton.Config({'BM': 128, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8),
        # triton.Config({'BM': 256, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=3, num_warps=8),
        # triton.Config({'BM': 256, 'BK': 128, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 64, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 64, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 64, 'BK': 64, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 64, 'BN': 32, 'G': 4}, num_stages=4, num_warps=4)
    ],
    key=['M', 'N', 'K']
)
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a,
    b,
    c,
    input,  # optional residual added after the matmul (addmm path)
    alpha,  # optional scalar multiplier for the matmul result
    beta,   # optional scalar multiplier for the residual
    # Matrix dimensions
    M,
    N,
    K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `s_am` is how much to increase `a`
    # by to get the element one row down (A has M rows).
    stride_ab, stride_am, stride_ak,  # a: batch, M, K
    stride_bk, stride_bn,  # b: K, N
    stride_cb, stride_cm, stride_cn,  # c: batch, M, N
    # Meta-parameters
    BM: tl.constexpr,  # tile size along M
    BK: tl.constexpr,  # tile size along K
    BN: tl.constexpr,  # tile size along N
    G: tl.constexpr,   # group size for swizzled program ordering
    ACTIVATION: tl.constexpr,   # name of the fused activation ('' for none)
    HAS_INPUT: tl.constexpr,    # whether the residual `input` is present
    HAS_ALPHA: tl.constexpr,
    HAS_BETA: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
    X_DIM: tl.constexpr = 1,    # rank of `input`: 1 = row vector broadcast, 2 = full matrix
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    i_b, i_m, i_n = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    NM, NN = tl.num_programs(1), tl.num_programs(2)
    i_m, i_n = tl.swizzle2d(i_m, i_n, NM, NN, G)

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `p_a` is a block of [BM, BK] pointers
    # `p_b` is a block of [BK, BN] pointers
    # See above `Pointer Arithmetic` section for details
    a_batch_ptr = a + i_b * stride_ab
    # The % M / % N wrap keeps out-of-range rows/cols at valid addresses; the
    # final store mask discards their results.
    o_am = (i_m * BM + tl.arange(0, BM)) % M
    o_bn = (i_n * BN + tl.arange(0, BN)) % N
    o_k = tl.arange(0, BK)

    p_a = a_batch_ptr + (o_am[:, None] * stride_am + o_k[None, :] * stride_ak)
    p_b = b + (o_k[:, None] * stride_bk + o_bn[None, :] * stride_bn)

    # fp32 accumulator regardless of input dtype.
    b_acc = tl.zeros((BM, BN), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BK)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        b_a = tl.load(p_a, mask=o_k[None, :] < K - k * BK, other=0.0)
        b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0)
        # We accumulate along the K dimension.
        b_acc = tl.dot(b_a, b_b, acc=b_acc, allow_tf32=ALLOW_TF32)
        # Advance the ptrs to the next K block.
        p_a += BK * stride_ak
        p_b += BK * stride_bk

    o_cm = i_m * BM + tl.arange(0, BM)
    o_cn = i_n * BN + tl.arange(0, BN)
    mask = (o_cm[:, None] < M) & (o_cn[None, :] < N)

    b_c = b_acc
    # You can fuse arbitrary activation functions here
    # while the b_acc is still in FP32!
    if ACTIVATION == "leaky_relu":
        b_c = leaky_relu(b_c)
    elif ACTIVATION == "relu":
        b_c = relu(b_c)
    elif ACTIVATION == "sigmoid":
        b_c = sigmoid(b_c)
    elif ACTIVATION == "tanh":
        b_c = tanh(b_c)

    if HAS_ALPHA:
        b_c *= tl.load(alpha)

    if HAS_INPUT:
        # NOTE(review): `input` is addressed with c's strides, i.e. it is
        # assumed to share c's layout (a (N,) row vector when X_DIM == 1, an
        # (M, N) matrix when X_DIM == 2) — confirm against addmm's callers.
        p_i = input + (stride_cm * o_cm[:, None] if X_DIM == 2 else 0) + stride_cn * o_cn[None, :]
        mask_p = (o_cn[None, :] < N) if X_DIM == 1 else mask
        b_i = tl.load(p_i, mask=mask_p, other=0.0).to(tl.float32)
        if HAS_BETA:
            b_i *= tl.load(beta)
        b_c += b_i

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    c_batch_ptr = c + i_b * stride_cb
    p_c = c_batch_ptr + stride_cm * o_cm[:, None] + stride_cn * o_cn[None, :]
    tl.store(p_c, b_c.to(c.dtype.element_ty), mask=mask)
150
+
151
+
152
# Fusable into `matmul_kernel` through the `ACTIVATION` meta-parameter.
@triton.jit
def leaky_relu(v):
    """LeakyReLU with a fixed negative slope of 0.01."""
    return tl.where(v >= 0, v, 0.01 * v)
156
+
157
+
158
@triton.jit
def sigmoid(v):
    """Logistic sigmoid: σ(v) = 1 / (1 + exp(-v))."""
    return 1.0 / (1.0 + exp(-v))
162
+
163
+
164
@triton.jit
def tanh(v):
    """Hyperbolic tangent via its exponential definition:
    tanh(v) = (exp(v) - exp(-v)) / (exp(v) + exp(-v)).
    """
    return (exp(v) - exp(-v)) / (exp(v) + exp(-v))
169
+
170
+
171
@triton.jit
def relu(v):
    """ReLU: elementwise max(v, 0)."""
    return tl.maximum(v, 0.0)
175
+
176
+
177
@input_guard
def matmul(a, b, activation=''):
    """Compute `a @ b` with an optional fused activation.

    Args:
        a: (M, K) or batched (B, M, K) tensor.
        b: (K, N) tensor, shared across the batch.
        activation: one of '', 'leaky_relu', 'relu', 'sigmoid', 'tanh'.

    Returns:
        (M, N) for 2D `a`, otherwise (B, M, N).
    """
    assert a.dim() in [2, 3], "a must be 2D or 3D"
    assert b.dim() == 2, "b must be 2D"
    assert a.shape[-1] == b.shape[0], f"Incompatible dimensions: A {a.shape}, B {b.shape}"

    a_dim = a.dim()
    if a_dim == 2:
        # Promote to a batch of one so the kernel sees a uniform 3D layout.
        a = a.unsqueeze(0).contiguous()
    # TF32 is only meaningful when inputs are not already fp32.
    allow_tf32 = a.dtype != torch.float32

    B, M, K = a.shape
    K_b, N = b.shape
    assert K == K_b, f"Incompatible K dimension: A {K} vs B {K_b}"
    c = a.new_empty(B, M, N)

    def grid(meta):
        return (B, triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))

    matmul_kernel[grid](
        a, b, c, None, None, None,
        M, N, K,
        *a.stride(),  # stride_ab, stride_am, stride_ak
        *b.stride(),  # stride_bk, stride_bn
        *c.stride(),  # stride_cb, stride_cm, stride_cn
        ACTIVATION=activation,
        ALLOW_TF32=allow_tf32,
        HAS_INPUT=False,
    )
    return c.squeeze(0) if a_dim == 2 else c
207
+
208
+
209
@input_guard
def addmm(
    x: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    alpha: Optional[float] = None,
    beta: Optional[float] = None,
) -> torch.Tensor:
    """Fused `alpha * (a @ b) + beta * x` in a single kernel launch.

    Args:
        x: residual to add; a row vector or a matrix matching the output.
        a: (M, K) or batched (B, M, K) tensor.
        b: (K, N) tensor, shared across the batch.
        alpha: optional scalar applied to the matmul result.
        beta: optional scalar applied to `x`.

    Returns:
        (M, N) for 2D `a`, otherwise (B, M, N).
    """
    assert a.dim() in [2, 3], "a must be 2D or 3D"
    assert b.dim() == 2, "b must be 2D"
    assert a.shape[-1] == b.shape[0], f"Incompatible dimensions: A {a.shape}, B {b.shape}"

    a_dim = a.dim()
    if a_dim == 2:
        # Promote to a batch of one so the kernel sees a uniform 3D layout.
        a = a.unsqueeze(0).contiguous()
    # TF32 is only meaningful when inputs are not already fp32.
    allow_tf32 = a.dtype != torch.float32

    B, M, K = a.shape
    K_b, N = b.shape
    assert K == K_b, f"Incompatible K dimension: A {K} vs B {K_b}"
    c = a.new_empty(B, M, N)

    def grid(meta):
        return (B, triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))

    matmul_kernel[grid](
        a, b, c, x, alpha, beta,
        M, N, K,
        *a.stride(),  # stride_ab, stride_am, stride_ak
        *b.stride(),  # stride_bk, stride_bn
        *c.stride(),  # stride_cb, stride_cm, stride_cn
        ACTIVATION=None,
        ALLOW_TF32=allow_tf32,
        HAS_INPUT=True,
        X_DIM=x.dim(),
    )
    return c.squeeze(0) if a_dim == 2 else c
fla3/ops/utils/op.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ import os
5
+
6
+ import triton
7
+ import triton.language as tl
8
+ import triton.language.extra.libdevice as tldevice
9
+
10
+ from ...utils import is_gather_supported
11
+
12
# Select fast (approximate) or standard math ops for Triton kernels.
# Setting FLA_USE_FAST_OPS=1 swaps in libdevice's fast-math variants, which
# trade accuracy for speed; the default binds the exact tl.* implementations.
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
    div = tldevice.fast_dividef
    exp = tldevice.fast_expf
    log = tldevice.fast_logf
    log2 = tldevice.fast_log2f
else:
    @triton.jit
    def div_normal(x, y):
        # Plain division wrapped so `div` is a @triton.jit callable in both branches.
        return x / y
    div = div_normal
    exp = tl.exp
    log = tl.log
    log2 = tl.log2
25
+
26
+
27
@triton.jit
def safe_exp(v):
    """Overflow-safe exponential for non-positive domains.

    Positive inputs are clamped to -inf before exponentiation, so they map to
    exactly 0 instead of overflowing.
    """
    return exp(tl.where(v <= 0, v, float('-inf')))
30
+
31
+
32
# Provide a `gather` symbol regardless of the installed Triton version.
if not is_gather_supported:
    @triton.jit
    def gather(src, index, axis, _builder=None):
        # This is a fallback implementation when tl.gather is not supported
        # In order to pass triton compiler, there is no actual gather operation
        # NOTE: it simply returns `src` unchanged, so code paths that reach it
        # must not depend on the gathered values.
        return src
else:
    gather = tl.gather
fla3/ops/utils/pack.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # Code adapted from https://github.com/mayank31398/cute-kernels
5
+
6
+ from typing import Optional
7
+
8
+ import torch
9
+ import triton
10
+ import triton.language as tl
11
+
12
+ from ...ops.utils.index import prepare_lens
13
+ from ...utils import input_guard
14
+
15
+
16
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [4, 8, 16, 32]
    ],
    key=['D', 'PADDING_SIDE', 'PACK']
)
@triton.jit
def packunpack_sequence_kernel(
    x,  # source tensor
    y,  # destination tensor
    cu_seqlens,  # cumulative sequence lengths, shape (B + 1,)
    S,  # padded sequence length of the (B, S, ...) side
    D,  # flattened feature size per token
    BD: tl.constexpr,  # feature block size
    PADDING_SIDE: tl.constexpr,  # 'left' or 'right'
    PACK: tl.constexpr,  # True: padded -> packed; False: packed -> padded
):
    """Copy one (sequence position, feature block) tile between the padded
    (B, S, D) layout and the packed (total_len, D) layout.

    Grid: (cdiv(D, BD), S, B); programs on padding positions return early.
    """
    i_d, i_s, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    bos, eos = tl.load(cu_seqlens + i_b), tl.load(cu_seqlens + i_b + 1)

    # Actual (unpadded) length of sequence i_b.
    T = eos - bos
    if PADDING_SIDE == 'left':
        # Leading positions are padding; skip them.
        NP = S - T
        if i_s < NP:
            return
        i_t = bos + (i_s - NP)
    else:
        # Trailing positions are padding; skip them.
        if i_s >= T:
            return
        i_t = bos + i_s

    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    if PACK:
        # padded (i_b, i_s) -> packed row i_t
        b_x = tl.load(x + (i_b * S + i_s) * D + o_d, mask=mask)
        tl.store(y + i_t * D + o_d, b_x, mask=mask)
    else:
        # packed row i_t -> padded (i_b, i_s)
        b_x = tl.load(x + i_t * D + o_d, mask=mask)
        tl.store(y + (i_b * S + i_s) * D + o_d, b_x, mask=mask)
57
+
58
+
59
+ def pack_sequence_fwdbwd(
60
+ x: torch.Tensor,
61
+ cu_seqlens: torch.Tensor,
62
+ padding_side: str,
63
+ ) -> torch.Tensor:
64
+ B, S = x.shape[:2]
65
+ D = x.numel() // (B * S)
66
+ BD = min(triton.next_power_of_2(D), 4096)
67
+ ND = triton.cdiv(D, BD)
68
+
69
+ y = torch.empty(cu_seqlens[-1].item(), *x.shape[2:], device=x.device, dtype=x.dtype)
70
+ packunpack_sequence_kernel[ND, S, B](
71
+ x=x,
72
+ y=y,
73
+ cu_seqlens=cu_seqlens,
74
+ S=S,
75
+ D=D,
76
+ BD=BD,
77
+ PADDING_SIDE=padding_side,
78
+ PACK=True,
79
+ )
80
+ return y
81
+
82
+
83
+ def unpack_sequence_fwdbwd(
84
+ x: torch.Tensor,
85
+ cu_seqlens: torch.Tensor,
86
+ padding_side: str,
87
+ desired_shape: torch.Size,
88
+ ) -> torch.Tensor:
89
+ if desired_shape is None:
90
+ desired_shape = (len(cu_seqlens) - 1, prepare_lens(cu_seqlens).max().item(), *x.shape[1:])
91
+ y = torch.zeros(desired_shape, device=x.device, dtype=x.dtype)
92
+ B, S = y.shape[:2]
93
+ D = y.numel() // (B * S)
94
+ BD = min(triton.next_power_of_2(D), 4096)
95
+ ND = triton.cdiv(D, BD)
96
+
97
+ packunpack_sequence_kernel[ND, S, B](
98
+ x=x,
99
+ y=y,
100
+ cu_seqlens=cu_seqlens,
101
+ S=S,
102
+ D=D,
103
+ BD=BD,
104
+ PADDING_SIDE=padding_side,
105
+ PACK=False,
106
+ )
107
+ return y
108
+
109
+
110
+ class PackSequenceFunction(torch.autograd.Function):
111
+
112
+ @staticmethod
113
+ @input_guard
114
+ def forward(
115
+ ctx,
116
+ x: torch.Tensor,
117
+ cu_seqlens: torch.Tensor,
118
+ padding_side: str,
119
+ ) -> torch.Tensor:
120
+ assert padding_side in ['left', 'right']
121
+ assert x.ndim >= 2
122
+
123
+ ctx.cu_seqlens = cu_seqlens
124
+ ctx.padding_side = padding_side
125
+ ctx.desired_shape = x.shape
126
+
127
+ y = pack_sequence_fwdbwd(
128
+ x=x,
129
+ cu_seqlens=cu_seqlens,
130
+ padding_side=padding_side,
131
+ )
132
+ return y
133
+
134
+ @staticmethod
135
+ @input_guard
136
+ def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor | None]:
137
+ dx = unpack_sequence_fwdbwd(
138
+ x=dy,
139
+ cu_seqlens=ctx.cu_seqlens,
140
+ padding_side=ctx.padding_side,
141
+ desired_shape=ctx.desired_shape,
142
+ )
143
+ return dx, *[None] * 10
144
+
145
+
146
+ class UnpackSequenceFunction(torch.autograd.Function):
147
+
148
+ @staticmethod
149
+ @input_guard
150
+ def forward(
151
+ ctx,
152
+ x: torch.Tensor,
153
+ cu_seqlens: torch.Tensor,
154
+ padding_side: str,
155
+ desired_shape: Optional[torch.Size] = None,
156
+ ) -> torch.Tensor:
157
+ assert padding_side in ['left', 'right']
158
+ assert x.ndim >= 2
159
+ if desired_shape is not None:
160
+ assert desired_shape[0] == cu_seqlens.shape[0] - 1
161
+ assert desired_shape[2:] == x.shape[1:]
162
+
163
+ ctx.cu_seqlens = cu_seqlens
164
+ ctx.padding_side = padding_side
165
+
166
+ y = unpack_sequence_fwdbwd(
167
+ x=x,
168
+ cu_seqlens=cu_seqlens,
169
+ padding_side=padding_side,
170
+ desired_shape=desired_shape,
171
+ )
172
+ return y
173
+
174
+ @staticmethod
175
+ @input_guard
176
+ def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor | None]:
177
+ dx = pack_sequence_fwdbwd(
178
+ x=dy,
179
+ cu_seqlens=ctx.cu_seqlens,
180
+ padding_side=ctx.padding_side,
181
+ )
182
+ return dx, None, None, None
183
+
184
+
185
+ def pack_sequence(
186
+ x: torch.Tensor,
187
+ cu_seqlens: torch.Tensor,
188
+ padding_side: str = 'left'
189
+ ) -> torch.Tensor:
190
+ return PackSequenceFunction.apply(
191
+ x,
192
+ cu_seqlens,
193
+ padding_side,
194
+ )
195
+
196
+
197
+ def unpack_sequence(
198
+ x: torch.Tensor,
199
+ cu_seqlens: torch.Tensor,
200
+ padding_side: str = 'left',
201
+ desired_shape: Optional[torch.Size] = None,
202
+ ) -> torch.Tensor:
203
+ return UnpackSequenceFunction.apply(
204
+ x,
205
+ cu_seqlens,
206
+ padding_side,
207
+ desired_shape,
208
+ )
fla3/ops/utils/pooling.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ...ops.utils.index import prepare_chunk_indices
11
+ from ...utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
12
+
13
+
14
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BD': BD}, num_warps=num_warps)
        for BD in [16, 32, 64, 128]
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def mean_pooling_fwd_kernel(
    x,  # input features
    o,  # output of per-chunk means
    cu_seqlens,  # cumulative sequence lengths (varlen mode) or None
    chunk_indices,  # flat chunk id -> (sequence, in-sequence chunk) pairs, varlen only
    T,  # sequence length (per-sequence in varlen mode)
    H: tl.constexpr,  # number of heads
    D: tl.constexpr,  # feature size per head
    BT: tl.constexpr,  # chunk (pooling window) size
    BD: tl.constexpr,  # feature block size
    IS_VARLEN: tl.constexpr
):
    """Mean-pool each length-BT chunk of the time dimension for one head.

    Grid: (cdiv(D, BD), number of chunks, B * H).
    """
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Map the flat chunk id to (sequence, chunk-within-sequence) and
        # recover this sequence's bounds from cu_seqlens.
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    p_x = tl.make_block_ptr(x + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
    p_o = tl.make_block_ptr(o + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
    # [BT, BD]
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    # [BD]; divide by the true chunk length so a partial tail chunk still
    # averages only over its valid rows.
    b_o = tl.sum(b_x, axis=0) / min(BT, T - i_t * BT)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
58
+
59
+
60
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BD': BD}, num_warps=num_warps)
        for BD in [16, 32, 64, 128]
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def mean_pooling_bwd_kernel(
    do,  # gradient w.r.t. the pooled output
    dx,  # gradient w.r.t. the input, written here
    cu_seqlens,  # cumulative sequence lengths (varlen mode) or None
    chunk_indices,  # flat chunk id -> (sequence, in-sequence chunk) pairs, varlen only
    T,  # sequence length (per-sequence in varlen mode)
    H: tl.constexpr,  # number of heads
    D: tl.constexpr,  # feature size per head
    BT: tl.constexpr,  # chunk (pooling window) size
    BD: tl.constexpr,  # feature block size
    IS_VARLEN: tl.constexpr
):
    """Backward of mean pooling: broadcast each chunk's output gradient to its
    rows, scaled by 1 / chunk length."""
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Same chunk-id mapping as the forward kernel.
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    p_dx = tl.make_block_ptr(dx + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
    p_do = tl.make_block_ptr(do + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
    # [BD]
    b_do = tl.load(p_do, boundary_check=(0,)).to(tl.float32)
    # [BT, BD]; each row of the chunk receives do / chunk_length, mirroring the
    # forward's division by the true (possibly partial) chunk length.
    b_dx = b_do / tl.full((BT,), min(BT, T - i_t * BT), dtype=tl.float32)[:, None]
    tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
104
+
105
+
106
def mean_pooling_fwd(
    x: torch.Tensor,
    chunk_size: int,
    cu_seqlens: Optional[torch.LongTensor] = None
) -> torch.Tensor:
    """Launch the chunk-wise mean-pooling forward kernel.

    Args:
        x: Input of shape `[B, T, H, D]`.
        chunk_size: Number of time steps averaged into each output position.
        cu_seqlens: Optional cumulative sequence lengths for variable-length
            (flattened) inputs; when given, chunks are computed per sequence.

    Returns:
        Pooled tensor of shape `[B, NT, H, D]`, where `NT` is the total
        number of chunks.
    """
    batch_size, seq_len, num_heads, head_dim = x.shape
    if cu_seqlens is None:
        chunk_indices = None
        num_chunks = triton.cdiv(seq_len, chunk_size)
    else:
        # Per-chunk (sequence id, local chunk id) mapping for varlen inputs.
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        num_chunks = len(chunk_indices)

    pooled = x.new_empty(batch_size, num_chunks, num_heads, head_dim)

    def grid(meta):
        # (feature tiles, chunks, batch * heads)
        return (triton.cdiv(head_dim, meta['BD']), num_chunks, batch_size * num_heads)

    mean_pooling_fwd_kernel[grid](
        x,
        pooled,
        cu_seqlens,
        chunk_indices,
        T=seq_len,
        H=num_heads,
        D=head_dim,
        BT=chunk_size,
    )
    return pooled
129
+
130
+
131
def mean_pooling_bwd(
    do: torch.Tensor,
    batch_size: int,
    seq_len: int,
    chunk_size: int,
    cu_seqlens: Optional[torch.LongTensor] = None
) -> torch.Tensor:
    """Launch the chunk-wise mean-pooling backward kernel.

    Args:
        do: Gradient w.r.t. the pooled output, shape `[B, NT, H, D]`.
        batch_size: Original batch size `B` of the forward input.
        seq_len: Original sequence length `T` of the forward input.
        chunk_size: Pooling window size used in the forward pass.
        cu_seqlens: Optional cumulative sequence lengths for varlen inputs.

    Returns:
        Gradient w.r.t. the input, shape `[B, T, H, D]`.
    """
    num_heads, head_dim = do.shape[-2:]
    if cu_seqlens is None:
        chunk_indices = None
        num_chunks = triton.cdiv(seq_len, chunk_size)
    else:
        # Per-chunk (sequence id, local chunk id) mapping for varlen inputs.
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        num_chunks = len(chunk_indices)

    dx = do.new_empty(batch_size, seq_len, num_heads, head_dim)

    def grid(meta):
        # (feature tiles, chunks, batch * heads)
        return (triton.cdiv(head_dim, meta['BD']), num_chunks, batch_size * num_heads)

    mean_pooling_bwd_kernel[grid](
        do,
        dx,
        cu_seqlens,
        chunk_indices,
        T=seq_len,
        H=num_heads,
        D=head_dim,
        BT=chunk_size,
    )
    return dx
156
+
157
+
158
class MeanPoolingFunction(torch.autograd.Function):
    """Differentiable chunk-wise mean pooling backed by Triton kernels."""

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        x: torch.Tensor,
        chunk_size: int,
        cu_seqlens: Optional[torch.LongTensor] = None
    ) -> torch.Tensor:
        # Stash the shape/config info needed to rebuild the input gradient;
        # no tensors have to be saved since the backward only redistributes do.
        ctx.batch_size, ctx.seq_len = x.shape[0], x.shape[1]
        ctx.chunk_size = chunk_size
        ctx.cu_seqlens = cu_seqlens
        return mean_pooling_fwd(x, chunk_size, cu_seqlens)

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(
        ctx, do
    ) -> Tuple[torch.Tensor, None, None]:
        dx = mean_pooling_bwd(do, ctx.batch_size, ctx.seq_len, ctx.chunk_size, ctx.cu_seqlens)
        # No gradients flow to `chunk_size` or `cu_seqlens`.
        return dx, None, None
188
+
189
+
190
def mean_pooling(
    x: torch.Tensor,
    chunk_size: int,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
) -> torch.Tensor:
    """Chunk-wise mean pooling over the time dimension.

    Args:
        x: Input of shape `[B, T, H, D]` (or `[B, H, T, D]` if `head_first`).
        chunk_size: Number of time steps averaged into each output position.
        cu_seqlens: Optional cumulative sequence lengths for variable-length
            (flattened) inputs; requires batch size 1.
        head_first: Whether the head dimension precedes the time dimension.

    Returns:
        Pooled tensor with the time dimension reduced to the number of chunks,
        in the same layout as the input.

    Raises:
        ValueError: If `cu_seqlens` is given but the batch size is not 1.
    """
    if head_first:
        x = x.transpose(1, 2)
    if cu_seqlens is not None:
        if x.shape[0] != 1:
            # Fixed the garbled message ("..tten" -> "flatten") and the
            # missing space between the two concatenated sentences.
            raise ValueError(
                f"The batch size is expected to be 1 rather than {x.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
    o = MeanPoolingFunction.apply(x, chunk_size, cu_seqlens)
    if head_first:
        o = o.transpose(1, 2)
    return o
flame/__init__.py ADDED
File without changes
flame/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (167 Bytes). View file
 
flame/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (167 Bytes). View file
 
flame/__pycache__/data.cpython-310.pyc ADDED
Binary file (8.17 kB). View file
 
flame/__pycache__/logging.cpython-310.pyc ADDED
Binary file (3.56 kB). View file
 
flame/__pycache__/logging.cpython-312.pyc ADDED
Binary file (6.44 kB). View file
 
flame/__pycache__/parser.cpython-310.pyc ADDED
Binary file (2.89 kB). View file
 
flame/__pycache__/parser.cpython-312.pyc ADDED
Binary file (4.07 kB). View file
 
flame/data.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ from copy import deepcopy
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, Iterable, List, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+ from datasets import Dataset, IterableDataset
12
+ from flame.logging import get_logger
13
+ from transformers import PreTrainedTokenizer
14
+
15
+ logger = get_logger(__name__)
16
+
17
+
18
class HuggingfaceDataset(IterableDataset):
    """An infinite, shuffled stream of fixed-length token chunks drawn from a
    Huggingface dataset.

    Texts are tokenized in batches, concatenated into one flat token stream,
    cut into `context_len`-sized chunks, and served through a shuffle buffer
    of `buffer_size` chunks.  The underlying dataset is sharded across
    `world_size` ranks, and the full iteration state (buffer, RNG, dataset
    cursor) can be checkpointed via `state_dict`/`load_state_dict` for exact
    resumption.
    """

    def __init__(
        self,
        dataset: Dataset,
        tokenizer: PreTrainedTokenizer,
        context_len: int = 2048,
        rank: int = 0,
        world_size: int = 1,
        buffer_size: int = 1024
    ) -> None:

        self.dataset = dataset
        self.tokenizer = tokenizer

        # Each rank iterates only its own shard of the dataset.
        self.data = dataset.shard(world_size, rank)
        self.context_len = context_len
        self.rank = rank
        self.world_size = world_size
        self.buffer_size = buffer_size

        # Store buffered token ids in the narrowest dtype that can hold the
        # vocabulary, to reduce buffer memory.
        # NOTE(review): ids of added/special tokens can exceed `vocab_size`;
        # confirm they still fit the chosen dtype.
        if tokenizer.vocab_size < torch.iinfo(torch.int16).max:
            self.dtype = torch.int16
        elif tokenizer.vocab_size < torch.iinfo(torch.int32).max:
            self.dtype = torch.int32
        else:
            self.dtype = torch.int64
        self.states = None  # underlying-dataset state of the last consumed sample
        self.buffer = torch.tensor([], dtype=self.dtype)  # [buffer_size, context_len] shuffle buffer
        self.tokens = []  # overflow token ids not yet placed in the buffer
        self.rand_id = 0  # cursor into the pre-drawn random indices (see `randint`)
        self.token_id = 0  # cursor into `tokens` while sampling
        self.rng_state = None  # generator state captured for exact resumption
        self._epoch = 0

    def __iter__(self):
        """Yield `{'input_ids': LongTensor[context_len]}` dicts forever,
        shuffled through the chunk buffer."""
        # Seed per (epoch, rank) so ranks draw different permutations, then
        # restore any checkpointed generator state for exact resumption.
        g = torch.Generator()
        g.manual_seed(self._epoch + self.rank)
        if self.rng_state is not None:
            g.set_state(self.rng_state)

        rand_it = self.randint(0, self.buffer_size, g=g)
        if self.states is not None:
            self.data.load_state_dict(self.states)

        # max number of tokens allowed in the chunk buffer
        n_tokens = self.buffer_size * self.context_len

        while True:
            for sample in self.tokenize(self.data):
                # keep appending the samples to the token buffer
                self.tokens += sample
                # if the token buffer is full, start sampling
                # NOTE: we first convert the token ids to a tensor of shape [n_chunks, context_len] for efficiency
                if len(self.buffer) == 0 and len(self.tokens) >= n_tokens:
                    self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=self.dtype).view(self.buffer_size, -1)
                    self.tokens = self.tokens[n_tokens:]
                if len(self.buffer) == self.buffer_size:
                    yield from self.sample(rand_it)

            n_chunks = len(self.tokens) // self.context_len
            # handle the left tokens in the buffer
            if n_chunks > 0:
                n_tokens = n_chunks * self.context_len
                indices = torch.randperm(n_chunks, generator=g).tolist()
                self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=torch.long).view(n_chunks, -1)
                self.tokens = self.tokens[n_tokens:]
                for i in indices:
                    yield {'input_ids': self.buffer[i]}

    def tokenize(self, data, batch_size: int = 64):
        """Tokenize `data` in batches of `batch_size` texts, yielding one token-id
        list per sample.  `self.states` tracks the dataset state matching the
        last yielded sample, so iteration can resume right after it."""
        texts, states = [], []
        for sample in data:
            texts.append(sample['text'])
            states.append(self.data.state_dict())
            if len(texts) == batch_size:
                for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
                    self.states = s
                    yield tokenized
                texts, states = [], []
        # Flush the trailing partial batch.
        if len(texts) > 0:
            for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
                self.states = s
                yield tokenized

    def sample(self, indices):
        """Yield chunks from random buffer slots, refilling each served slot
        with the next fresh chunk from `self.tokens` (shuffle-buffer style)."""
        # Only whole chunks can be used as replacements.
        n_tokens = (len(self.tokens) // self.context_len) * self.context_len
        while self.token_id < n_tokens:
            i = next(indices)
            start, end = self.token_id, self.token_id + self.context_len
            self.token_id += self.context_len
            yield {'input_ids': self.buffer[i].to(torch.long)}
            # Replace the slot just served with unseen tokens.
            self.buffer[i] = torch.tensor(self.tokens[start:end], dtype=self.dtype)
        self.token_id = 0
        self.tokens = self.tokens[n_tokens:]

    def randint(
        self,
        low: int,
        high: int,
        batch_size: int = 1024,
        g: torch.Generator = torch.Generator()
    ) -> Iterable[int]:
        """Endlessly yield random ints in `[low, high)`, drawn `batch_size` at a
        time; `self.rand_id` tracks the position inside the current batch so a
        checkpoint can resume mid-batch.

        NOTE(review): the default `torch.Generator()` is created once at
        definition time and shared across calls; callers here always pass `g`
        explicitly, but a `None` default would be safer.
        """
        indices = torch.empty(batch_size, dtype=torch.long)
        while True:
            # record the generator states before sampling
            self.rng_state = g.get_state()
            indices = torch.randint(low, high, (batch_size,), out=indices, generator=g)
            for i in indices[self.rand_id:].tolist():
                self.rand_id += 1
                yield i
            self.rand_id = 0

    def set_epoch(self, epoch):
        # Propagate the epoch to the underlying dataset when it supports
        # epoch-based reshuffling.
        self._epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def state_dict(self):
        """Capture the full iteration state for checkpointing."""
        return {
            'states': self.states,
            'buffer': self.buffer.clone(),
            'tokens': deepcopy(self.tokens),
            'rand_id': self.rand_id,
            'token_id': self.token_id,
            'rng_state': self.rng_state,
            'epoch': self._epoch
        }

    def load_state_dict(self, state_dict):
        """Restore the iteration state captured by `state_dict`."""
        self.states = state_dict['states']
        self.buffer = state_dict['buffer'].clone()
        self.tokens = deepcopy(state_dict['tokens'])
        self.rand_id = state_dict['rand_id']
        self.token_id = state_dict['token_id']
        self.rng_state = state_dict['rng_state'].clone() if state_dict['rng_state'] is not None else None
        self._epoch = state_dict['epoch']
155
+
156
+
157
@dataclass
class DataCollatorForLanguageModeling:
    """
    Data collator used for language modeling.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        varlen (`bool`):
            Whether to return sequences with variable lengths.
            If `True`, the offsets indicating the start and end of each sequence will be returned.
            For example, if the sequence lengths are `[4, 8, 12]`,
            the returned `input_ids` will be a long flattened tensor of shape `[1, 24]`, with `offsets` being `[0, 4, 12, 24]`.
            If `False`, the `input_ids` with shape `[batch_size, seq_len]` will be returned directly.
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "pt".
    """

    tokenizer: PreTrainedTokenizer
    varlen: bool = False
    return_tensors: str = "pt"

    def __call__(
        self,
        examples: List[Union[List[int], Dict[str, Any]]]
    ) -> Dict[str, Any]:
        """Collate `examples` into a batch dict with `input_ids`, `labels`,
        and — in varlen mode — `offsets`."""
        # Normalize raw token-id lists into feature dicts.
        if not isinstance(examples[0], Dict):
            examples = [{'input_ids': example} for example in examples]

        def tensorize(example: Dict[str, Any]) -> Dict[str, Any]:
            # Convert list / numpy fields to long tensors; pass tensors through.
            tensorized = {}
            for key in ['input_ids', 'offsets']:
                if key not in example:
                    continue
                if isinstance(example[key], List):
                    tensorized[key] = torch.tensor(example[key], dtype=torch.long)
                elif isinstance(example[key], np.ndarray):
                    tensorized[key] = torch.from_numpy(example[key])
                else:
                    tensorized[key] = example[key]
            return tensorized

        examples = list(map(tensorize, examples))

        if not self.varlen:
            length_of_first = examples[0]['input_ids'].size(0)
            # Check if padding is necessary.
            if all(example['input_ids'].size(0) == length_of_first for example in examples):
                batch = {
                    'input_ids': torch.stack([example['input_ids'] for example in examples], dim=0),
                }
            else:
                # If yes, check if we have a `pad_token`.
                # NOTE(review): `_pad_token` is a private tokenizer attribute;
                # the public `pad_token` would be the safer check.
                if self.tokenizer._pad_token is None:
                    raise ValueError(
                        f"You are attempting to pad samples but the tokenizer you are using "
                        f"({self.tokenizer.__class__.__name__}) does not have a pad token."
                    )
                batch = self.tokenizer.pad(examples, return_tensors=self.return_tensors, return_attention_mask=False)
        else:
            # Varlen mode: a single flattened sequence of shape [1, total_len].
            if len(examples) > 1:
                raise ValueError("The batch size must be 1 for variable length inputs.")
            batch = {
                'input_ids': torch.cat([example['input_ids'] for example in examples], dim=0).unsqueeze(0)
            }
            if 'offsets' in examples[0]:
                batch['offsets'] = torch.cat([example['offsets'] for example in examples], dim=0).unsqueeze(0)
            else:
                # determine boundaries by bos/eos positions
                if self.tokenizer.add_bos_token:
                    # Sequence starts are bos positions; prepend 0 if the first
                    # token is not bos, and always append the total length.
                    offsets = []
                    if batch['input_ids'][0, 0] != self.tokenizer.bos_token_id:
                        offsets.append(torch.tensor([0], dtype=torch.long))
                    offsets.append(torch.where(batch['input_ids'].eq(self.tokenizer.bos_token_id))[1])
                    offsets.append(torch.tensor([len(batch['input_ids'][0])], dtype=torch.long))
                    batch['offsets'] = torch.cat(offsets, dim=0)
                elif self.tokenizer.add_eos_token:
                    # Sequence ends are one past each eos position; append the
                    # total length if the last token is not eos.
                    offsets = [torch.tensor([0], dtype=torch.long)]
                    offsets.append(torch.where(batch['input_ids'].eq(self.tokenizer.eos_token_id))[1] + 1)
                    if batch['input_ids'][0, -1] != self.tokenizer.eos_token_id:
                        offsets.append(torch.tensor([len(batch['input_ids'][0])], dtype=torch.long))
                    batch['offsets'] = torch.cat(offsets, dim=0)
                else:
                    raise ValueError("You must allow the tokenizer to add either a bos or eos token as separators.")

        # Labels mirror the inputs; pad positions are masked out with -100
        # (the loss-ignore index).
        labels = batch['input_ids'].clone()
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch
flame/parser.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Optional
7
+
8
+ import transformers
9
+ from transformers import HfArgumentParser, TrainingArguments
10
+
11
+ from flame.logging import get_logger
12
+
13
+ logger = get_logger(__name__)
14
+
15
+
16
+ @dataclass
17
+ class TrainingArguments(TrainingArguments):
18
+
19
+ model_name_or_path: str = field(
20
+ default=None,
21
+ metadata={
22
+ "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
23
+ },
24
+ )
25
+ tokenizer: str = field(
26
+ default="fla-hub/gla-1.3B-100B",
27
+ metadata={"help": "Name of the tokenizer to use."}
28
+ )
29
+ use_fast_tokenizer: bool = field(
30
+ default=False,
31
+ metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
32
+ )
33
+ from_config: bool = field(
34
+ default=True,
35
+ metadata={"help": "Whether to initialize models from scratch."},
36
+ )
37
+ dataset: Optional[str] = field(
38
+ default=None,
39
+ metadata={"help": "The dataset(s) to use. Use commas to separate multiple datasets."},
40
+ )
41
+ dataset_name: Optional[str] = field(
42
+ default=None,
43
+ metadata={"help": "The name of provided dataset(s) to use."},
44
+ )
45
+ cache_dir: str = field(
46
+ default=None,
47
+ metadata={"help": "Path to the cached tokenized dataset."},
48
+ )
49
+ split: str = field(
50
+ default="train",
51
+ metadata={"help": "Which dataset split to use for training and evaluation."},
52
+ )
53
+ streaming: bool = field(
54
+ default=False,
55
+ metadata={"help": "Enable dataset streaming."},
56
+ )
57
+ hf_hub_token: Optional[str] = field(
58
+ default=None,
59
+ metadata={"help": "Auth token to log in with Hugging Face Hub."},
60
+ )
61
+ preprocessing_num_workers: Optional[int] = field(
62
+ default=None,
63
+ metadata={"help": "The number of processes to use for the pre-processing."},
64
+ )
65
+ buffer_size: int = field(
66
+ default=2048,
67
+ metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
68
+ )
69
+ context_length: int = field(
70
+ default=2048,
71
+ metadata={"help": "The context length of the tokenized inputs in the dataset."},
72
+ )
73
+ varlen: bool = field(
74
+ default=False,
75
+ metadata={"help": "Enable training with variable length inputs."},
76
+ )
77
+
78
+
79
+ def get_train_args():
80
+ parser = HfArgumentParser(TrainingArguments)
81
+ args, unknown_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
82
+
83
+ if unknown_args:
84
+ print(parser.format_help())
85
+ print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
86
+ raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
87
+
88
+ if args.should_log:
89
+ transformers.utils.logging.set_verbosity(args.get_process_log_level())
90
+ transformers.utils.logging.enable_default_handler()
91
+ transformers.utils.logging.enable_explicit_format()
92
+ # set seeds manually
93
+ transformers.set_seed(args.seed)
94
+ return args