Add files using upload-large-folder tool
Browse files- create_yaml.py +115 -0
- fla3/ops/path_attn/__pycache__/__init__.cpython-312.pyc +0 -0
- fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_bwd.cpython-310.pyc +0 -0
- fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_bwd_prepare.cpython-310.pyc +0 -0
- fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_fwd.cpython-310.pyc +0 -0
- fla3/ops/path_attn/__pycache__/parallel.cpython-312.pyc +0 -0
- fla3/ops/path_attn/__pycache__/parallel_path_bwd_inter_dkv.cpython-310.pyc +0 -0
- fla3/ops/path_attn/__pycache__/parallel_path_bwd_intra.cpython-310.pyc +0 -0
- fla3/ops/path_attn/__pycache__/prepare_k_cache.cpython-310.pyc +0 -0
- fla3/ops/retention/__pycache__/fused_chunk.cpython-310.pyc +0 -0
- fla3/ops/rwkv7/__pycache__/__init__.cpython-310.pyc +0 -0
- fla3/ops/rwkv7/__pycache__/__init__.cpython-312.pyc +0 -0
- fla3/ops/rwkv7/__pycache__/chunk.cpython-310.pyc +0 -0
- fla3/ops/rwkv7/__pycache__/chunk.cpython-312.pyc +0 -0
- fla3/ops/rwkv7/__pycache__/fused_k_update.cpython-310.pyc +0 -0
- fla3/ops/rwkv7/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
- fla3/ops/rwkv7/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
- fla3/ops/simple_gla/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
- fla3/ops/simple_gla/__pycache__/parallel.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/asm.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/cumsum.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/cumsum.cpython-312.pyc +0 -0
- fla3/ops/utils/__pycache__/index.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/index.cpython-312.pyc +0 -0
- fla3/ops/utils/__pycache__/logcumsumexp.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/logsumexp.cpython-312.pyc +0 -0
- fla3/ops/utils/__pycache__/matmul.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/op.cpython-312.pyc +0 -0
- fla3/ops/utils/__pycache__/pack.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/pack.cpython-312.pyc +0 -0
- fla3/ops/utils/__pycache__/pooling.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/softmax.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/softmax.cpython-312.pyc +0 -0
- fla3/ops/utils/__pycache__/solve_tril.cpython-312.pyc +0 -0
- fla3/ops/utils/logsumexp.py +80 -0
- fla3/ops/utils/matmul.py +245 -0
- fla3/ops/utils/op.py +39 -0
- fla3/ops/utils/pack.py +208 -0
- fla3/ops/utils/pooling.py +207 -0
- flame/__init__.py +0 -0
- flame/__pycache__/__init__.cpython-310.pyc +0 -0
- flame/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/__pycache__/data.cpython-310.pyc +0 -0
- flame/__pycache__/logging.cpython-310.pyc +0 -0
- flame/__pycache__/logging.cpython-312.pyc +0 -0
- flame/__pycache__/parser.cpython-310.pyc +0 -0
- flame/__pycache__/parser.cpython-312.pyc +0 -0
- flame/data.py +246 -0
- flame/parser.py +94 -0
create_yaml.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python3
"""Generate an `accelerate` DeepSpeed YAML config and a DeepSpeed hostfile.

Reads the distributed-training topology (master address/port, node count,
GPUs per node, node rank) from environment variables, writes the resulting
accelerate configuration to a YAML file, then extracts worker hostnames
from an SSH config file into a DeepSpeed hostfile.
"""
import os
import sys

import yaml

# Hard-coded deployment paths for this training cluster.
CONFIG_PATH = "/mnt/jfzn/msj/flash-linear-attention/legacy/training/configs/deepspeed_sencore.yaml"
SSH_CONFIG_PATH = "/mnt/jfzn/msj/flash-linear-attention/legacy/training/ssh_config/config"
HOSTFILE_PATH = "/mnt/jfzn/msj/flash-linear-attention/legacy/training/hostfile.txt"


def _env_int(name, default):
    """Read an integer environment variable.

    Falls back to ``default`` when the variable is unset or not a valid
    integer (the original code duplicated this try/except three times).
    """
    try:
        return int(os.getenv(name, default))
    except (ValueError, TypeError):
        return default


def main():
    """Write the accelerate YAML config and the DeepSpeed hostfile; return 0."""
    # Values from the environment, with single-node defaults.
    master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
    master_port = _env_int("MASTER_PORT", 29500)
    num_nodes = _env_int("SENSECORE_PYTORCH_NNODES", 1)
    gpus_per_node = _env_int("SENSECORE_ACCELERATE_DEVICE_COUNT", 1)
    node_rank = _env_int("SENSECORE_PYTORCH_NODE_RANK", 0)

    # Total number of worker processes across all nodes.
    num_processes = num_nodes * gpus_per_node

    # accelerate launcher configuration.
    config = {
        "compute_environment": "LOCAL_MACHINE",
        "distributed_type": "DEEPSPEED",
        "deepspeed_config": {
            "deepspeed_config_file": "configs/ds_config.json",
            "zero3_init_flag": True,
            "deepspeed_multinode_launcher": "standard",
            "deepspeed_hostfile": '/mnt/jfzn/msj/flash-linear-attention/legacy/training/hostfile.txt',
        },
        "machine_rank": node_rank,
        "main_process_ip": master_addr,
        "main_process_port": master_port,
        "main_training_function": "main",
        "num_machines": num_nodes,
        "num_processes": num_processes,
        "same_network": True,
        "use_cpu": False,
        "rdzv_backend": "c10d",
        "tpu_env": [],
        "tpu_use_cluster": False,
        "tpu_use_sudo": False,
    }

    # Print the resolved configuration for the launch logs.
    print("Generated Configuration:")
    print(f"  Master: {master_addr}:{master_port}")
    print(f"  Number of nodes: {num_nodes}")
    print(f"  GPUs per node: {gpus_per_node}")
    print(f"  Total processes: {num_processes}")
    print(f"  Node rank: {node_rank}")

    # The original only created the relative "configs" dir (referenced by
    # ds_config.json) but not the directory the YAML is actually written to.
    # Create both.
    os.makedirs("configs", exist_ok=True)
    os.makedirs(os.path.dirname(CONFIG_PATH), exist_ok=True)

    with open(CONFIG_PATH, "w") as f:
        yaml.dump(config, f, default_flow_style=False)

    print(f"\nConfiguration saved to: {CONFIG_PATH}")

    # Echo the file back so it can be inspected in the job logs.
    print("\nFile content:")
    with open(CONFIG_PATH, "r") as f:
        print(f.read())

    # Extract worker hostnames from the SSH config: one per "Hostname <value>"
    # line. (The original reused the `output_file` variable for this second,
    # unrelated file — a shadowing bug waiting to happen.)
    hostnames = []
    with open(SSH_CONFIG_PATH, "r") as f:
        for line in f:
            line = line.strip()
            if line.startswith("Hostname"):
                parts = line.split(None, 1)
                # Skip a bare "Hostname" line with no value; the original
                # indexed [1] unconditionally and would raise IndexError.
                if len(parts) < 2:
                    continue
                hostnames.append(parts[1])

    # One host per line in DeepSpeed hostfile format.
    # NOTE(review): slots is hard-coded to 8 rather than `gpus_per_node` —
    # confirm every worker really has 8 GPUs.
    with open(HOSTFILE_PATH, "w") as f:
        for host in hostnames:
            f.write(host + " slots=8\n")

    print(f"提取了 {len(hostnames)} 个 hostname,已写入 {HOSTFILE_PATH}")
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
fla3/ops/path_attn/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (273 Bytes). View file
|
|
|
fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_bwd.cpython-310.pyc
ADDED
|
Binary file (5.02 kB). View file
|
|
|
fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_bwd_prepare.cpython-310.pyc
ADDED
|
Binary file (5.13 kB). View file
|
|
|
fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_fwd.cpython-310.pyc
ADDED
|
Binary file (4.65 kB). View file
|
|
|
fla3/ops/path_attn/__pycache__/parallel.cpython-312.pyc
ADDED
|
Binary file (9.9 kB). View file
|
|
|
fla3/ops/path_attn/__pycache__/parallel_path_bwd_inter_dkv.cpython-310.pyc
ADDED
|
Binary file (5.47 kB). View file
|
|
|
fla3/ops/path_attn/__pycache__/parallel_path_bwd_intra.cpython-310.pyc
ADDED
|
Binary file (5.1 kB). View file
|
|
|
fla3/ops/path_attn/__pycache__/prepare_k_cache.cpython-310.pyc
ADDED
|
Binary file (2.26 kB). View file
|
|
|
fla3/ops/retention/__pycache__/fused_chunk.cpython-310.pyc
ADDED
|
Binary file (9.26 kB). View file
|
|
|
fla3/ops/rwkv7/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (350 Bytes). View file
|
|
|
fla3/ops/rwkv7/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (361 Bytes). View file
|
|
|
fla3/ops/rwkv7/__pycache__/chunk.cpython-310.pyc
ADDED
|
Binary file (2.26 kB). View file
|
|
|
fla3/ops/rwkv7/__pycache__/chunk.cpython-312.pyc
ADDED
|
Binary file (2.56 kB). View file
|
|
|
fla3/ops/rwkv7/__pycache__/fused_k_update.cpython-310.pyc
ADDED
|
Binary file (3.93 kB). View file
|
|
|
fla3/ops/rwkv7/__pycache__/fused_recurrent.cpython-310.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
fla3/ops/rwkv7/__pycache__/fused_recurrent.cpython-312.pyc
ADDED
|
Binary file (16.5 kB). View file
|
|
|
fla3/ops/simple_gla/__pycache__/fused_recurrent.cpython-312.pyc
ADDED
|
Binary file (4.78 kB). View file
|
|
|
fla3/ops/simple_gla/__pycache__/parallel.cpython-310.pyc
ADDED
|
Binary file (17 kB). View file
|
|
|
fla3/ops/utils/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.16 kB). View file
|
|
|
fla3/ops/utils/__pycache__/asm.cpython-310.pyc
ADDED
|
Binary file (482 Bytes). View file
|
|
|
fla3/ops/utils/__pycache__/cumsum.cpython-310.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
fla3/ops/utils/__pycache__/cumsum.cpython-312.pyc
ADDED
|
Binary file (21.4 kB). View file
|
|
|
fla3/ops/utils/__pycache__/index.cpython-310.pyc
ADDED
|
Binary file (3.12 kB). View file
|
|
|
fla3/ops/utils/__pycache__/index.cpython-312.pyc
ADDED
|
Binary file (5.48 kB). View file
|
|
|
fla3/ops/utils/__pycache__/logcumsumexp.cpython-310.pyc
ADDED
|
Binary file (1.54 kB). View file
|
|
|
fla3/ops/utils/__pycache__/logsumexp.cpython-312.pyc
ADDED
|
Binary file (3.66 kB). View file
|
|
|
fla3/ops/utils/__pycache__/matmul.cpython-310.pyc
ADDED
|
Binary file (5.29 kB). View file
|
|
|
fla3/ops/utils/__pycache__/op.cpython-312.pyc
ADDED
|
Binary file (1.56 kB). View file
|
|
|
fla3/ops/utils/__pycache__/pack.cpython-310.pyc
ADDED
|
Binary file (4.56 kB). View file
|
|
|
fla3/ops/utils/__pycache__/pack.cpython-312.pyc
ADDED
|
Binary file (8.01 kB). View file
|
|
|
fla3/ops/utils/__pycache__/pooling.cpython-310.pyc
ADDED
|
Binary file (5.61 kB). View file
|
|
|
fla3/ops/utils/__pycache__/softmax.cpython-310.pyc
ADDED
|
Binary file (2.35 kB). View file
|
|
|
fla3/ops/utils/__pycache__/softmax.cpython-312.pyc
ADDED
|
Binary file (4.92 kB). View file
|
|
|
fla3/ops/utils/__pycache__/solve_tril.cpython-312.pyc
ADDED
|
Binary file (18.9 kB). View file
|
|
|
fla3/ops/utils/logsumexp.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2024, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import triton
|
| 8 |
+
import triton.language as tl
|
| 9 |
+
|
| 10 |
+
from ...ops.utils.op import exp, log
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@triton.heuristics({
    'HAS_SCALE': lambda args: args['scale'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def logsumexp_fwd_kernel(
    x,  # input pointer, logically viewed as (N, D)
    z,  # output pointer, (N, ceil(D / B)) per-block partial logsumexp values
    scale,  # optional scalar multiplier applied to x before the reduction
    D: tl.constexpr,  # size of the reduced (last) dimension
    B: tl.constexpr,  # number of elements of D handled by one program
    HAS_SCALE: tl.constexpr  # filled in by the heuristic above
):
    # One program per (row, D-block) pair. Program ids are widened to int64 so
    # the flattened pointer arithmetic cannot overflow on very large inputs.
    i_n, i_d = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    o_d = i_d * B + tl.arange(0, B)
    m_d = o_d < D

    # Out-of-range lanes read -inf, which contributes exp(-inf) = 0 to the sum.
    b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf'))
    if HAS_SCALE:
        b_x = b_x * scale
    # Max-shift for numerical stability: log(sum(exp(x - m))) + m.
    b_m = tl.max(b_x, 0)
    b_z = log(tl.sum(exp(b_x - b_m), 0)) + b_m
    # Each program writes one partial result; the host combines the ND
    # partials per row with a second logsumexp.
    tl.store(z + i_n * tl.cdiv(D, B) + i_d, b_z)
|
| 44 |
+
def logsumexp_fwd(
    x,
    scale: Optional[float] = None,
    dtype: Optional[torch.dtype] = None
):
    r"""
    Compute the logsumexp of the input tensor over the last dimension.

    The input is flattened to 2D, reduced blockwise on device into per-block
    partial results, and the partials are combined with a final
    ``logsumexp`` on the host side.

    Args:
        x (Tensor):
            The input tensor of any shape.
        scale (Optional[float]):
            The scale applied to the input tensor. Default: `None`.
        dtype (Optional[torch.dtype]):
            The data type of the output tensor. Default: `None`.
    Returns:
        Tensor: The logsumexp of the input tensor.
    """
    orig_shape = x.shape
    x2d = x.view(-1, orig_shape[-1])
    N, D = x2d.shape

    # Block size along D, capped at 64K lanes per program.
    B = min(triton.next_power_of_2(D), 64 * 1024)
    ND = triton.cdiv(D, B)

    # Partial results are always accumulated in fp32.
    partial = x2d.new_empty(N, ND, dtype=torch.float)
    logsumexp_fwd_kernel[(N, ND)](
        x=x2d,
        z=partial,
        scale=scale,
        D=D,
        B=B
    )
    out = partial.logsumexp(-1).view(*orig_shape[:-1])
    if dtype is not None and dtype != torch.float:
        out = out.to(dtype)
    return out
|
fla3/ops/utils/matmul.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
# code adapted from
|
| 5 |
+
# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
|
| 6 |
+
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import triton
|
| 11 |
+
import triton.language as tl
|
| 12 |
+
|
| 13 |
+
from ...ops.utils.op import exp
|
| 14 |
+
from ...utils import input_guard
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
|
| 18 |
+
# - A list of `triton.Config` objects that define different configurations of
|
| 19 |
+
# meta-parameters (e.g., `BM`) and compilation options (e.g., `num_warps`) to try
|
| 20 |
+
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
|
| 21 |
+
# provided configs
|
| 22 |
+
@triton.heuristics({
    'HAS_ALPHA': lambda args: args['alpha'] is not None,
    'HAS_BETA': lambda args: args['beta'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BM': 128, 'BK': 64, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8),
        triton.Config({'BM': 64, 'BK': 32, 'BN': 256, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 64, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 64, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=5, num_warps=2),
        triton.Config({'BM': 32, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=5, num_warps=2),
        # Good config for fp8 inputs.
        # triton.Config({'BM': 128, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8),
        # triton.Config({'BM': 256, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=3, num_warps=8),
        # triton.Config({'BM': 256, 'BK': 128, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 64, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 64, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 64, 'BK': 64, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 64, 'BN': 32, 'G': 4}, num_stages=4, num_warps=4)
    ],
    key=['M', 'N', 'K']
)
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a,
    b,
    c,
    input,  # optional residual/bias pointer, added when HAS_INPUT (shadows the Python builtin)
    alpha,  # optional scalar pointer: multiplier for the matmul result
    beta,  # optional scalar pointer: multiplier for the residual `input`
    # Matrix dimensions
    M,
    N,
    K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `s_am` is how much to increase `a`
    # by to get the element one row down (A has M rows).
    stride_ab, stride_am, stride_ak,  # a: batch, M, K
    stride_bk, stride_bn,  # b: K, N
    stride_cb, stride_cm, stride_cn,  # c: batch, M, N
    # Meta-parameters
    BM: tl.constexpr,
    BK: tl.constexpr,
    BN: tl.constexpr,
    G: tl.constexpr,  # group size for the swizzled (L2-friendly) program order
    ACTIVATION: tl.constexpr,  # one of '', 'leaky_relu', 'relu', 'sigmoid', 'tanh'
    HAS_INPUT: tl.constexpr,
    HAS_ALPHA: tl.constexpr,
    HAS_BETA: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
    X_DIM: tl.constexpr = 1,  # 1: `input` is a length-N vector broadcast over rows; 2: full (M, N)
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    i_b, i_m, i_n = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    NM, NN = tl.num_programs(1), tl.num_programs(2)
    i_m, i_n = tl.swizzle2d(i_m, i_n, NM, NN, G)

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `p_a` is a block of [BM, BK] pointers
    # `p_b` is a block of [BK, BN] pointers
    # See above `Pointer Arithmetic` section for details
    a_batch_ptr = a + i_b * stride_ab
    # Row/col offsets are wrapped modulo M/N so the loads stay in bounds for
    # ragged tiles; the final store is masked, so the wrapped rows/cols never
    # leak into C.
    o_am = (i_m * BM + tl.arange(0, BM)) % M
    o_bn = (i_n * BN + tl.arange(0, BN)) % N
    o_k = tl.arange(0, BK)

    p_a = a_batch_ptr + (o_am[:, None] * stride_am + o_k[None, :] * stride_ak)
    p_b = b + (o_k[:, None] * stride_bk + o_bn[None, :] * stride_bn)

    # Accumulate in fp32 regardless of the input dtype.
    b_acc = tl.zeros((BM, BN), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BK)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        b_a = tl.load(p_a, mask=o_k[None, :] < K - k * BK, other=0.0)
        b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0)
        # We accumulate along the K dimension.
        b_acc = tl.dot(b_a, b_b, acc=b_acc, allow_tf32=ALLOW_TF32)
        # Advance the ptrs to the next K block.
        p_a += BK * stride_ak
        p_b += BK * stride_bk

    o_cm = i_m * BM + tl.arange(0, BM)
    o_cn = i_n * BN + tl.arange(0, BN)
    mask = (o_cm[:, None] < M) & (o_cn[None, :] < N)

    b_c = b_acc
    # You can fuse arbitrary activation functions here
    # while the b_acc is still in FP32!
    if ACTIVATION == "leaky_relu":
        b_c = leaky_relu(b_c)
    elif ACTIVATION == "relu":
        b_c = relu(b_c)
    elif ACTIVATION == "sigmoid":
        b_c = sigmoid(b_c)
    elif ACTIVATION == "tanh":
        b_c = tanh(b_c)

    if HAS_ALPHA:
        b_c *= tl.load(alpha)

    if HAS_INPUT:
        # X_DIM == 1: `input` is a length-N vector broadcast across rows;
        # X_DIM == 2: `input` is a full (M, N) matrix.
        # NOTE(review): `input` is addressed with C's strides — assumes
        # `input` has the same layout as C; confirm against the host wrapper.
        p_i = input + (stride_cm * o_cm[:, None] if X_DIM == 2 else 0) + stride_cn * o_cn[None, :]
        mask_p = (o_cn[None, :] < N) if X_DIM == 1 else mask
        b_i = tl.load(p_i, mask=mask_p, other=0.0).to(tl.float32)
        if HAS_BETA:
            b_i *= tl.load(beta)
        b_c += b_i

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    c_batch_ptr = c + i_b * stride_cb
    p_c = c_batch_ptr + stride_cm * o_cm[:, None] + stride_cn * o_cn[None, :]
    tl.store(p_c, b_c.to(c.dtype.element_ty), mask=mask)
|
| 151 |
+
|
| 152 |
+
# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
|
| 153 |
+
# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
@triton.jit
def leaky_relu(x):
    # LeakyReLU with a fixed negative slope of 0.01.
    return tl.where(x < 0, x * 0.01, x)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@triton.jit
def sigmoid(x):
    # σ(x) = 1 / (1 + exp(-x))
    neg_exp = exp(-x)
    return 1.0 / (1.0 + neg_exp)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@triton.jit
def tanh(x):
    # tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
    #         = sign(x) * (1 - exp(-2|x|)) / (1 + exp(-2|x|))
    # The naive two-sided form overflows: exp(x) -> inf for large x
    # (~> 88.7 in fp32), yielding inf/inf = NaN. Rewriting in terms of
    # exp(-2|x|) keeps the exponent non-positive, so exp never overflows and
    # the result saturates cleanly to ±1.
    q = exp(-2.0 * tl.abs(x))
    t = (1.0 - q) / (1.0 + q)
    return tl.where(x >= 0, t, -t)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@triton.jit
def relu(x):
    # ReLU(x) = max(0, x)
    return tl.maximum(x, 0.0)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
@input_guard
def matmul(a, b, activation=''):
    """Compute ``c = activation(a @ b)`` with the Triton kernel above.

    ``a`` may be ``(M, K)`` or batched ``(B, M, K)``; ``b`` is always
    ``(K, N)``. A 2D ``a`` is treated as a batch of one and the batch axis
    is squeezed away from the result.
    """
    assert a.dim() in [2, 3], "a must be 2D or 3D"
    assert b.dim() == 2, "b must be 2D"
    assert a.shape[-1] == b.shape[0], f"Incompatible dimensions: A {a.shape}, B {b.shape}"

    # Remember whether the caller handed us a 2D `a` so the batch axis added
    # here can be stripped from the output again.
    was_2d = a.dim() == 2
    if was_2d:
        a = a.unsqueeze(0).contiguous()  # (1, M, K)
    # TF32 only for reduced-precision inputs; full fp32 stays exact.
    allow_tf32 = a.dtype != torch.float32

    B, M, K = a.shape
    K_b, N = b.shape
    assert K == K_b, f"Incompatible K dimension: A {K} vs B {K_b}"
    c = a.new_empty(B, M, N)

    grid = lambda meta: (B, triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
    matmul_kernel[grid](
        a, b, c, None, None, None,
        M, N, K,
        a.stride(0), a.stride(1), a.stride(2),  # stride_ab, stride_am, stride_ak
        b.stride(0), b.stride(1),  # stride_bk, stride_bn (b.dim() == 2)
        c.stride(0), c.stride(1), c.stride(2),  # stride_cb, stride_cm, stride_cn
        ACTIVATION=activation,
        ALLOW_TF32=allow_tf32,
        HAS_INPUT=False,
    )
    return c.squeeze(0) if was_2d else c
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
@input_guard
def addmm(
    x: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    alpha: Optional[float] = None,
    beta: Optional[float] = None,
) -> torch.Tensor:
    """Fused ``c = alpha * (a @ b) + beta * x``.

    ``a`` may be ``(M, K)`` or batched ``(B, M, K)``; ``b`` is ``(K, N)``;
    ``x`` is either a length-N vector (broadcast over rows) or a full
    ``(M, N)`` matrix. ``alpha``/``beta`` default to 1 when ``None``.
    """
    assert a.dim() in [2, 3], "a must be 2D or 3D"
    assert b.dim() == 2, "b must be 2D"
    assert a.shape[-1] == b.shape[0], f"Incompatible dimensions: A {a.shape}, B {b.shape}"

    # Promote a 2D `a` to a single-element batch; remember to squeeze later.
    was_2d = a.dim() == 2
    if was_2d:
        a = a.unsqueeze(0).contiguous()  # (1, M, K)
    # TF32 only for reduced-precision inputs; full fp32 stays exact.
    allow_tf32 = a.dtype != torch.float32

    B, M, K = a.shape
    K_b, N = b.shape
    assert K == K_b, f"Incompatible K dimension: A {K} vs B {K_b}"
    c = a.new_empty(B, M, N)

    grid = lambda meta: (B, triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
    matmul_kernel[grid](
        a, b, c, x, alpha, beta,
        M, N, K,
        a.stride(0), a.stride(1), a.stride(2),  # stride_ab, stride_am, stride_ak
        b.stride(0), b.stride(1),  # stride_bk, stride_bn (b.dim() == 2)
        c.stride(0), c.stride(1), c.stride(2),  # stride_cb, stride_cm, stride_cn
        ACTIVATION=None,
        ALLOW_TF32=allow_tf32,
        HAS_INPUT=True,
        X_DIM=x.dim(),
    )
    return c.squeeze(0) if was_2d else c
|
fla3/ops/utils/op.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2024, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import triton
|
| 7 |
+
import triton.language as tl
|
| 8 |
+
import triton.language.extra.libdevice as tldevice
|
| 9 |
+
|
| 10 |
+
from ...utils import is_gather_supported
|
| 11 |
+
|
| 12 |
+
# When FLA_USE_FAST_OPS=1, bind the elementwise helpers to the fast
# libdevice intrinsics (hardware-approximated, lower precision); otherwise
# use the exact triton.language implementations.
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
    div = tldevice.fast_dividef
    exp = tldevice.fast_expf
    log = tldevice.fast_logf
    log2 = tldevice.fast_log2f
else:
    @triton.jit
    def div_normal(x, y):
        # Plain division, wrapped so `div` is a @triton.jit callable in
        # both branches.
        return x / y
    div = div_normal
    exp = tl.exp
    log = tl.log
    log2 = tl.log2


@triton.jit
def safe_exp(x):
    # exp() that cannot overflow: positive inputs are replaced by -inf first,
    # so the result is exp(x) for x <= 0 and exp(-inf) = 0.0 for x > 0.
    return exp(tl.where(x <= 0, x, float('-inf')))


if not is_gather_supported:
    @triton.jit
    def gather(src, index, axis, _builder=None):
        # This is a fallback implementation when tl.gather is not supported
        # In order to pass triton compiler, there is no actual gather operation
        # NOTE(review): returns `src` unchanged — any kernel that actually
        # relies on gather semantics will be silently wrong on backends
        # without tl.gather; confirm callers guard on is_gather_supported.
        return src
else:
    gather = tl.gather
|
fla3/ops/utils/pack.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
# Code adapted from https://github.com/mayank31398/cute-kernels
|
| 5 |
+
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import triton
|
| 10 |
+
import triton.language as tl
|
| 11 |
+
|
| 12 |
+
from ...ops.utils.index import prepare_lens
|
| 13 |
+
from ...utils import input_guard
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [4, 8, 16, 32]
    ],
    key=['D', 'PADDING_SIDE', 'PACK']
)
@triton.jit
def packunpack_sequence_kernel(
    x,  # source pointer: padded (B, S, D) when PACK, flat (total, D) otherwise
    y,  # destination pointer: flat (total, D) when PACK, padded (B, S, D) otherwise
    cu_seqlens,  # (B + 1,) cumulative sequence lengths; cu_seqlens[-1] = total tokens
    S,  # padded sequence length
    D,  # flattened feature size per token
    BD: tl.constexpr,  # features handled per program
    PADDING_SIDE: tl.constexpr,  # 'left' or 'right'
    PACK: tl.constexpr,  # True: padded -> flat; False: flat -> padded
):
    # Grid is (ceil(D / BD), S, B): one program per (feature block, slot, batch).
    i_d, i_s, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    bos, eos = tl.load(cu_seqlens + i_b), tl.load(cu_seqlens + i_b + 1)

    # T = actual (unpadded) length of this sequence.
    T = eos - bos
    if PADDING_SIDE == 'left':
        # The first NP slots of this row are padding; skip them.
        NP = S - T
        if i_s < NP:
            return
        i_t = bos + (i_s - NP)
    else:
        # Right padding: slots beyond T are padding; skip them.
        if i_s >= T:
            return
        i_t = bos + i_s

    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    if PACK:
        # Copy padded[i_b, i_s] -> flat[i_t].
        b_x = tl.load(x + (i_b * S + i_s) * D + o_d, mask=mask)
        tl.store(y + i_t * D + o_d, b_x, mask=mask)
    else:
        # Copy flat[i_t] -> padded[i_b, i_s].
        b_x = tl.load(x + i_t * D + o_d, mask=mask)
        tl.store(y + (i_b * S + i_s) * D + o_d, b_x, mask=mask)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def pack_sequence_fwdbwd(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    padding_side: str,
) -> torch.Tensor:
    """Copy the valid (non-padding) rows of a padded batch into one flat
    tensor of shape ``(cu_seqlens[-1], *x.shape[2:])``.

    ``padding_side`` tells the kernel whether each sequence is left- or
    right-aligned inside its padded slot. Used as both the pack forward and
    the unpack backward.
    """
    batch, padded_len = x.shape[0], x.shape[1]
    feat = x.numel() // (batch * padded_len)
    block_d = min(triton.next_power_of_2(feat), 4096)
    grid_d = triton.cdiv(feat, block_d)

    total = cu_seqlens[-1].item()
    packed = torch.empty(total, *x.shape[2:], device=x.device, dtype=x.dtype)
    packunpack_sequence_kernel[(grid_d, padded_len, batch)](
        x=x,
        y=packed,
        cu_seqlens=cu_seqlens,
        S=padded_len,
        D=feat,
        BD=block_d,
        PADDING_SIDE=padding_side,
        PACK=True,
    )
    return packed
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def unpack_sequence_fwdbwd(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    padding_side: str,
    desired_shape: torch.Size,
) -> torch.Tensor:
    """Scatter a packed ``[total_tokens, ...]`` tensor back to a padded ``[B, S, ...]`` batch.

    Padding positions are left as zeros (the output is pre-zeroed and the
    kernel skips them). If `desired_shape` is None it is inferred from
    `cu_seqlens` using the longest sequence as the padded length.
    """
    if desired_shape is None:
        max_len = prepare_lens(cu_seqlens).max().item()
        desired_shape = (len(cu_seqlens) - 1, max_len, *x.shape[1:])
    y = torch.zeros(desired_shape, device=x.device, dtype=x.dtype)

    batch_size, padded_len = y.shape[:2]
    feat_dim = y.numel() // (batch_size * padded_len)
    block_d = min(triton.next_power_of_2(feat_dim), 4096)
    grid = (triton.cdiv(feat_dim, block_d), padded_len, batch_size)

    packunpack_sequence_kernel[grid](
        x=x,
        y=y,
        cu_seqlens=cu_seqlens,
        S=padded_len,
        D=feat_dim,
        BD=block_d,
        PADDING_SIDE=padding_side,
        PACK=False,
    )
    return y
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class PackSequenceFunction(torch.autograd.Function):
    """Autograd wrapper for packing a padded batch into a flat token tensor.

    Forward packs ``[B, S, ...]`` into ``[total_tokens, ...]``; backward
    unpacks the incoming gradient back to the padded layout.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        padding_side: str,
    ) -> torch.Tensor:
        assert padding_side in ['left', 'right']
        assert x.ndim >= 2

        # Saved for the backward pass: layout info needed to unpack gradients.
        ctx.cu_seqlens = cu_seqlens
        ctx.padding_side = padding_side
        ctx.desired_shape = x.shape

        y = pack_sequence_fwdbwd(
            x=x,
            cu_seqlens=cu_seqlens,
            padding_side=padding_side,
        )
        return y

    @staticmethod
    @input_guard
    def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor | None]:
        dx = unpack_sequence_fwdbwd(
            x=dy,
            cu_seqlens=ctx.cu_seqlens,
            padding_side=ctx.padding_side,
            desired_shape=ctx.desired_shape,
        )
        # BUGFIX: forward takes exactly 3 inputs (x, cu_seqlens, padding_side),
        # so backward must return exactly 3 gradients. The previous
        # `return dx, *[None] * 10` returned 11 values, which makes autograd
        # raise "returned an incorrect number of gradients".
        return dx, None, None
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class UnpackSequenceFunction(torch.autograd.Function):
    """Autograd wrapper for unpacking a flat token tensor into a padded batch.

    Forward scatters ``[total_tokens, ...]`` into ``[B, S, ...]`` with zeros
    in the padding region; backward re-packs the incoming gradient.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        padding_side: str,
        desired_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        assert padding_side in ['left', 'right']
        assert x.ndim >= 2
        if desired_shape is not None:
            # Batch dim must match cu_seqlens; feature dims must match input.
            assert desired_shape[0] == cu_seqlens.shape[0] - 1
            assert desired_shape[2:] == x.shape[1:]

        ctx.cu_seqlens = cu_seqlens
        ctx.padding_side = padding_side

        return unpack_sequence_fwdbwd(
            x=x,
            cu_seqlens=cu_seqlens,
            padding_side=padding_side,
            desired_shape=desired_shape,
        )

    @staticmethod
    @input_guard
    def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor | None]:
        # Re-pack the padded gradient; one grad per forward input
        # (x, cu_seqlens, padding_side, desired_shape).
        dx = pack_sequence_fwdbwd(
            x=dy,
            cu_seqlens=ctx.cu_seqlens,
            padding_side=ctx.padding_side,
        )
        return dx, None, None, None
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def pack_sequence(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    padding_side: str = 'left'
) -> torch.Tensor:
    """Pack a padded batch ``[B, S, ...]`` into a flat ``[total_tokens, ...]`` tensor.

    Differentiable: gradients flow back into the padded layout.
    """
    return PackSequenceFunction.apply(x, cu_seqlens, padding_side)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def unpack_sequence(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    padding_side: str = 'left',
    desired_shape: Optional[torch.Size] = None,
) -> torch.Tensor:
    """Unpack a flat ``[total_tokens, ...]`` tensor into a padded ``[B, S, ...]`` batch.

    Differentiable inverse of :func:`pack_sequence`; padding is zero-filled.
    """
    return UnpackSequenceFunction.apply(x, cu_seqlens, padding_side, desired_shape)
|
fla3/ops/utils/pooling.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
from typing import Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import triton
|
| 8 |
+
import triton.language as tl
|
| 9 |
+
|
| 10 |
+
from ...ops.utils.index import prepare_chunk_indices
|
| 11 |
+
from ...utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BD': BD}, num_warps=num_warps)
        for BD in [16, 32, 64, 128]
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def mean_pooling_fwd_kernel(
    x,  # input [B, T, H, D] (or varlen-flattened [1, total_T, H, D])
    o,  # output chunk means, [B, NT, H, D] laid out contiguously
    cu_seqlens,  # [N+1] cumulative lengths, or None for fixed-length batches
    chunk_indices,  # [NT_total, 2] (seq_id, chunk_id) pairs in varlen mode
    T,
    H: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,  # chunk (pooling window) length
    BD: tl.constexpr,  # feature-block size per program
    IS_VARLEN: tl.constexpr
):
    """Mean-pool each length-BT chunk of the sequence over the time axis.

    Grid: (cdiv(D, BD), NT, B * H); one program reduces one [BT, BD] tile
    to a [BD] mean, dividing by the actual (possibly truncated) chunk length.
    """
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # global chunk index before remapping to (sequence, local chunk)
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    p_x = tl.make_block_ptr(x + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
    p_o = tl.make_block_ptr(o + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
    # [BT, BD]; boundary_check zero-pads rows past the end of the sequence
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    # [BD]; divide by the true chunk length (last chunk may be shorter than BT)
    b_o = tl.sum(b_x, axis=0) / min(BT, T - i_t * BT)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BD': BD}, num_warps=num_warps)
        for BD in [16, 32, 64, 128]
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def mean_pooling_bwd_kernel(
    do,  # upstream gradient of the pooled output, [B, NT, H, D]
    dx,  # gradient of the input, [B, T, H, D]
    cu_seqlens,  # [N+1] cumulative lengths, or None for fixed-length batches
    chunk_indices,  # [NT_total, 2] (seq_id, chunk_id) pairs in varlen mode
    T,
    H: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,  # chunk (pooling window) length
    BD: tl.constexpr,  # feature-block size per program
    IS_VARLEN: tl.constexpr
):
    """Backward of chunk-wise mean pooling: broadcast do / chunk_len over the chunk.

    Mirrors the forward's index mapping; each program fills one [BT, BD] tile
    of dx with the (scaled) pooled gradient.
    """
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # global chunk index before remapping to (sequence, local chunk)
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    p_dx = tl.make_block_ptr(dx + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
    p_do = tl.make_block_ptr(do + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
    # [BD]
    b_do = tl.load(p_do, boundary_check=(0,)).to(tl.float32)
    # [BT, BD]; every position in the chunk gets grad / true_chunk_len
    b_dx = b_do / tl.full((BT,), min(BT, T - i_t * BT), dtype=tl.float32)[:, None]
    tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def mean_pooling_fwd(
    x: torch.Tensor,
    chunk_size: int,
    cu_seqlens: Optional[torch.LongTensor] = None
) -> torch.Tensor:
    """Launch the forward mean-pooling kernel.

    Pools `x` of shape ``[B, T, H, D]`` over non-overlapping chunks of
    `chunk_size` along the time axis, returning ``[B, NT, H, D]``.
    """
    B, T, H, D = x.shape
    BT = chunk_size
    if cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        NT = len(chunk_indices)
    else:
        chunk_indices = None
        NT = triton.cdiv(T, BT)

    o = x.new_empty(B, NT, H, D)

    # BD is chosen by the autotuner, so the grid is a function of the config.
    def grid(meta):
        return (triton.cdiv(D, meta['BD']), NT, B * H)

    mean_pooling_fwd_kernel[grid](
        x,
        o,
        cu_seqlens,
        chunk_indices,
        T=T,
        H=H,
        D=D,
        BT=BT,
    )
    return o
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def mean_pooling_bwd(
    do: torch.Tensor,
    batch_size: int,
    seq_len: int,
    chunk_size: int,
    cu_seqlens: Optional[torch.LongTensor] = None
) -> torch.Tensor:
    """Launch the backward mean-pooling kernel.

    Expands the pooled gradient `do` of shape ``[B, NT, H, D]`` back to the
    input gradient of shape ``[B, T, H, D]``.
    """
    B, T = batch_size, seq_len
    H, D = do.shape[-2:]
    BT = chunk_size
    if cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        NT = len(chunk_indices)
    else:
        chunk_indices = None
        NT = triton.cdiv(T, BT)

    dx = do.new_empty(B, T, H, D)

    # BD is chosen by the autotuner, so the grid is a function of the config.
    def grid(meta):
        return (triton.cdiv(D, meta['BD']), NT, B * H)

    mean_pooling_bwd_kernel[grid](
        do,
        dx,
        cu_seqlens,
        chunk_indices,
        T=T,
        H=H,
        D=D,
        BT=BT,
    )
    return dx
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class MeanPoolingFunction(torch.autograd.Function):
    """Autograd wrapper around the chunk-wise mean-pooling kernels."""

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        x: torch.Tensor,
        chunk_size: int,
        cu_seqlens: Optional[torch.LongTensor] = None
    ) -> torch.Tensor:
        o = mean_pooling_fwd(x, chunk_size, cu_seqlens)
        # Remember the input geometry; backward only receives the pooled grad.
        ctx.batch_size, ctx.seq_len = x.shape[0], x.shape[1]
        ctx.chunk_size = chunk_size
        ctx.cu_seqlens = cu_seqlens
        return o

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(
        ctx, do
    ) -> Tuple[torch.Tensor, None, None]:
        dx = mean_pooling_bwd(
            do,
            ctx.batch_size,
            ctx.seq_len,
            ctx.chunk_size,
            ctx.cu_seqlens,
        )
        # One grad per forward input: (x, chunk_size, cu_seqlens).
        return dx, None, None
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def mean_pooling(
    x: torch.Tensor,
    chunk_size: int,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
) -> torch.Tensor:
    """Mean-pool `x` over non-overlapping time chunks of `chunk_size`.

    Args:
        x: Input of shape ``[B, T, H, D]`` (or ``[B, H, T, D]`` if `head_first`).
        chunk_size: Pooling window along the time axis.
        cu_seqlens: Optional cumulative sequence lengths for variable-length
            inputs; requires the batch dimension to be 1.
        head_first: If True, `x` is ``[B, H, T, D]`` and the output is
            transposed back to that layout.

    Returns:
        Pooled tensor of shape ``[B, NT, H, D]`` (or head-first equivalent).

    Raises:
        ValueError: If `cu_seqlens` is given but the batch size is not 1.
    """
    if head_first:
        x = x.transpose(1, 2)
    if cu_seqlens is not None:
        if x.shape[0] != 1:
            # BUGFIX: the second sentence was garbled ("Please ..tten ...")
            # and ran into the first without a space; restored wording.
            raise ValueError(
                f"The batch size is expected to be 1 rather than {x.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
    o = MeanPoolingFunction.apply(x, chunk_size, cu_seqlens)
    if head_first:
        o = o.transpose(1, 2)
    return o
|
flame/__init__.py
ADDED
|
File without changes
|
flame/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (167 Bytes). View file
|
|
|
flame/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (167 Bytes). View file
|
|
|
flame/__pycache__/data.cpython-310.pyc
ADDED
|
Binary file (8.17 kB). View file
|
|
|
flame/__pycache__/logging.cpython-310.pyc
ADDED
|
Binary file (3.56 kB). View file
|
|
|
flame/__pycache__/logging.cpython-312.pyc
ADDED
|
Binary file (6.44 kB). View file
|
|
|
flame/__pycache__/parser.cpython-310.pyc
ADDED
|
Binary file (2.89 kB). View file
|
|
|
flame/__pycache__/parser.cpython-312.pyc
ADDED
|
Binary file (4.07 kB). View file
|
|
|
flame/data.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from copy import deepcopy
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Dict, Iterable, List, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from datasets import Dataset, IterableDataset
|
| 12 |
+
from flame.logging import get_logger
|
| 13 |
+
from transformers import PreTrainedTokenizer
|
| 14 |
+
|
| 15 |
+
logger = get_logger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class HuggingfaceDataset(IterableDataset):
    """Resumable, shuffling, fixed-context-length stream over a HF dataset.

    Shards the dataset across ranks, tokenizes on the fly, slices the token
    stream into ``context_len``-sized chunks, and yields them in a randomized
    order via a reservoir-style chunk buffer. ``state_dict``/``load_state_dict``
    allow exact resumption mid-stream.
    """

    def __init__(
        self,
        dataset: Dataset,
        tokenizer: PreTrainedTokenizer,
        context_len: int = 2048,
        rank: int = 0,
        world_size: int = 1,
        buffer_size: int = 1024
    ) -> None:

        self.dataset = dataset
        self.tokenizer = tokenizer

        # each rank only iterates its own shard
        self.data = dataset.shard(world_size, rank)
        self.context_len = context_len
        self.rank = rank
        self.world_size = world_size
        self.buffer_size = buffer_size

        # store token ids in the narrowest integer dtype that fits the vocab
        # to reduce buffer memory; widened to long only when yielded
        if tokenizer.vocab_size < torch.iinfo(torch.int16).max:
            self.dtype = torch.int16
        elif tokenizer.vocab_size < torch.iinfo(torch.int32).max:
            self.dtype = torch.int32
        else:
            self.dtype = torch.int64
        # resumption state: HF dataset iterator state, chunk buffer, pending
        # token ids, positions into the random-index stream, and RNG state
        self.states = None
        self.buffer = torch.tensor([], dtype=self.dtype)
        self.tokens = []
        self.rand_id = 0
        self.token_id = 0
        self.rng_state = None
        self._epoch = 0

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self._epoch + self.rank)
        if self.rng_state is not None:
            # resume the RNG exactly where the checkpoint left off
            g.set_state(self.rng_state)

        rand_it = self.randint(0, self.buffer_size, g=g)
        if self.states is not None:
            self.data.load_state_dict(self.states)

        # max number of tokens allowed in the chunk buffer
        n_tokens = self.buffer_size * self.context_len

        while True:
            for sample in self.tokenize(self.data):
                # keep appending the samples to the token buffer
                self.tokens += sample
                # if the token buffer is full, start sampling
                # NOTE: we first convert the token ids to a tensor of shape [n_chunks, context_len] for efficiency
                if len(self.buffer) == 0 and len(self.tokens) >= n_tokens:
                    self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=self.dtype).view(self.buffer_size, -1)
                    self.tokens = self.tokens[n_tokens:]
                if len(self.buffer) == self.buffer_size:
                    yield from self.sample(rand_it)

            n_chunks = len(self.tokens) // self.context_len
            # handle the left tokens in the buffer
            if n_chunks > 0:
                n_tokens = n_chunks * self.context_len
                indices = torch.randperm(n_chunks, generator=g).tolist()
                self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=torch.long).view(n_chunks, -1)
                self.tokens = self.tokens[n_tokens:]
                for i in indices:
                    yield {'input_ids': self.buffer[i]}

    def tokenize(self, data, batch_size: int = 64):
        """Yield tokenized samples, batching tokenizer calls for throughput.

        `self.states` is advanced alongside each yielded sample so a
        checkpoint taken at any point can resume from the right raw record.
        """
        texts, states = [], []
        for sample in data:
            texts.append(sample['text'])
            states.append(self.data.state_dict())
            if len(texts) == batch_size:
                for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
                    self.states = s
                    yield tokenized
                texts, states = [], []
        # flush the final partial batch
        if len(texts) > 0:
            for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
                self.states = s
                yield tokenized

    def sample(self, indices):
        """Yield random buffer chunks, refilling each yielded slot from `self.tokens`.

        NOTE(review): the buffer slot is overwritten *after* the yield so that
        a checkpoint taken by the consumer still sees the pre-replacement state.
        """
        n_tokens = (len(self.tokens) // self.context_len) * self.context_len
        while self.token_id < n_tokens:
            i = next(indices)
            start, end = self.token_id, self.token_id + self.context_len
            self.token_id += self.context_len
            yield {'input_ids': self.buffer[i].to(torch.long)}
            self.buffer[i] = torch.tensor(self.tokens[start:end], dtype=self.dtype)
        self.token_id = 0
        self.tokens = self.tokens[n_tokens:]

    def randint(
        self,
        low: int,
        high: int,
        batch_size: int = 1024,
        g: torch.Generator = torch.Generator()
    ) -> Iterable[int]:
        """Infinite stream of random ints in [low, high), drawn in batches.

        `self.rand_id` tracks the position within the current batch so a
        resumed run (same `g` state) skips already-consumed draws.
        NOTE(review): the default `g` is a shared mutable default; callers
        here always pass an explicit generator.
        """
        indices = torch.empty(batch_size, dtype=torch.long)
        while True:
            # record the generator states before sampling
            self.rng_state = g.get_state()
            indices = torch.randint(low, high, (batch_size,), out=indices, generator=g)
            for i in indices[self.rand_id:].tolist():
                self.rand_id += 1
                yield i
            self.rand_id = 0

    def set_epoch(self, epoch):
        """Reseed shuffling for a new epoch (propagated to the HF dataset if supported)."""
        self._epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def state_dict(self):
        """Snapshot everything needed to resume iteration exactly."""
        return {
            'states': self.states,
            'buffer': self.buffer.clone(),
            'tokens': deepcopy(self.tokens),
            'rand_id': self.rand_id,
            'token_id': self.token_id,
            'rng_state': self.rng_state,
            'epoch': self._epoch
        }

    def load_state_dict(self, state_dict):
        """Restore a snapshot produced by :meth:`state_dict`."""
        self.states = state_dict['states']
        self.buffer = state_dict['buffer'].clone()
        self.tokens = deepcopy(state_dict['tokens'])
        self.rand_id = state_dict['rand_id']
        self.token_id = state_dict['token_id']
        self.rng_state = state_dict['rng_state'].clone() if state_dict['rng_state'] is not None else None
        self._epoch = state_dict['epoch']
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
@dataclass
class DataCollatorForLanguageModeling:
    """
    Data collator used for language modeling.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        varlen (`bool`):
            Whether to return sequences with variable lengths.
            If `True`, the offsets indicating the start and end of each sequence will be returned.
            For example, if the sequence lengths are `[4, 8, 12]`,
            the returned `input_ids` will be a long flattened tensor of shape `[1, 24]`, with `offsets` being `[0, 4, 12, 24]`.
            If `False`, the `input_ids` with shape `[batch_size, seq_len]` will be returned directly.
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "pt".
    """

    tokenizer: PreTrainedTokenizer
    varlen: bool = False
    return_tensors: str = "pt"

    def __call__(
        self,
        examples: List[Union[List[int], Dict[str, Any]]]
    ) -> Dict[str, Any]:
        # accept bare token-id lists as well as dict-shaped examples
        if not isinstance(examples[0], Dict):
            examples = [{'input_ids': example} for example in examples]

        def tensorize(example: Dict[str, Any]) -> Dict[str, Any]:
            # normalize lists / numpy arrays to long tensors; pass tensors through
            tensorized = {}
            for key in ['input_ids', 'offsets']:
                if key not in example:
                    continue
                if isinstance(example[key], List):
                    tensorized[key] = torch.tensor(example[key], dtype=torch.long)
                elif isinstance(example[key], np.ndarray):
                    tensorized[key] = torch.from_numpy(example[key])
                else:
                    tensorized[key] = example[key]
            return tensorized

        examples = list(map(tensorize, examples))

        if not self.varlen:
            length_of_first = examples[0]['input_ids'].size(0)
            # Check if padding is necessary.
            if all(example['input_ids'].size(0) == length_of_first for example in examples):
                # all equal lengths: a plain stack suffices
                batch = {
                    'input_ids': torch.stack([example['input_ids'] for example in examples], dim=0),
                }
            else:
                # If yes, check if we have a `pad_token`.
                if self.tokenizer._pad_token is None:
                    raise ValueError(
                        f"You are attempting to pad samples but the tokenizer you are using "
                        f"({self.tokenizer.__class__.__name__}) does not have a pad token."
                    )
                batch = self.tokenizer.pad(examples, return_tensors=self.return_tensors, return_attention_mask=False)
        else:
            # varlen mode: one flattened [1, total_len] sequence with offsets
            if len(examples) > 1:
                raise ValueError("The batch size must be 1 for variable length inputs.")
            batch = {
                'input_ids': torch.cat([example['input_ids'] for example in examples], dim=0).unsqueeze(0)
            }
            if 'offsets' in examples[0]:
                batch['offsets'] = torch.cat([example['offsets'] for example in examples], dim=0).unsqueeze(0)
            else:
                # determine boundaries by bos/eos positions
                if self.tokenizer.add_bos_token:
                    offsets = []
                    # prepend 0 unless the sequence already starts at a bos
                    if batch['input_ids'][0, 0] != self.tokenizer.bos_token_id:
                        offsets.append(torch.tensor([0], dtype=torch.long))
                    offsets.append(torch.where(batch['input_ids'].eq(self.tokenizer.bos_token_id))[1])
                    offsets.append(torch.tensor([len(batch['input_ids'][0])], dtype=torch.long))
                    batch['offsets'] = torch.cat(offsets, dim=0)
                elif self.tokenizer.add_eos_token:
                    offsets = [torch.tensor([0], dtype=torch.long)]
                    # each eos ends a sequence, so boundaries sit one past it
                    offsets.append(torch.where(batch['input_ids'].eq(self.tokenizer.eos_token_id))[1] + 1)
                    if batch['input_ids'][0, -1] != self.tokenizer.eos_token_id:
                        offsets.append(torch.tensor([len(batch['input_ids'][0])], dtype=torch.long))
                    batch['offsets'] = torch.cat(offsets, dim=0)
                else:
                    raise ValueError("You must allow the tokenizer to add either a bos or eos token as separators.")

        # standard LM objective: labels are the inputs, with pad masked to -100
        labels = batch['input_ids'].clone()
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch
|
flame/parser.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
import transformers
|
| 9 |
+
from transformers import HfArgumentParser, TrainingArguments
|
| 10 |
+
|
| 11 |
+
from flame.logging import get_logger
|
| 12 |
+
|
| 13 |
+
logger = get_logger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
class TrainingArguments(TrainingArguments):
    """Training arguments extending HF `transformers.TrainingArguments`.

    NOTE: intentionally shadows the imported `TrainingArguments` so callers
    get the extended version under the familiar name. Added fields cover
    model/tokenizer selection and dataset/streaming configuration.
    """

    model_name_or_path: str = field(
        default=None,
        metadata={
            "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
        },
    )
    tokenizer: str = field(
        default="fla-hub/gla-1.3B-100B",
        metadata={"help": "Name of the tokenizer to use."}
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
    )
    from_config: bool = field(
        default=True,
        metadata={"help": "Whether to initialize models from scratch."},
    )
    dataset: Optional[str] = field(
        default=None,
        metadata={"help": "The dataset(s) to use. Use commas to separate multiple datasets."},
    )
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of provided dataset(s) to use."},
    )
    cache_dir: str = field(
        default=None,
        metadata={"help": "Path to the cached tokenized dataset."},
    )
    split: str = field(
        default="train",
        metadata={"help": "Which dataset split to use for training and evaluation."},
    )
    streaming: bool = field(
        default=False,
        metadata={"help": "Enable dataset streaming."},
    )
    hf_hub_token: Optional[str] = field(
        default=None,
        metadata={"help": "Auth token to log in with Hugging Face Hub."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the pre-processing."},
    )
    buffer_size: int = field(
        default=2048,
        metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
    )
    context_length: int = field(
        default=2048,
        metadata={"help": "The context length of the tokenized inputs in the dataset."},
    )
    varlen: bool = field(
        default=False,
        metadata={"help": "Enable training with variable length inputs."},
    )
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_train_args():
    """Parse CLI flags into the extended `TrainingArguments`.

    Rejects unrecognized flags (printing help first), configures transformers
    logging for the main process, and seeds all RNGs before returning.
    """
    parser = HfArgumentParser(TrainingArguments)
    parsed, leftover = parser.parse_args_into_dataclasses(return_remaining_strings=True)

    if leftover:
        # unknown flags usually mean a renamed/removed option: fail loudly
        print(parser.format_help())
        print("Got unknown args, potentially deprecated arguments: {}".format(leftover))
        raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(leftover))

    if parsed.should_log:
        transformers.utils.logging.set_verbosity(parsed.get_process_log_level())
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    # set seeds manually
    transformers.set_seed(parsed.seed)
    return parsed
|