hyx21 committed (verified)
Commit 93cfccb · 1 Parent(s): 521ee68

Upload 9 files
added_tokens.json ADDED
@@ -0,0 +1,10 @@
+ {
+ "<|execute_end|>": 73444,
+ "<|execute_start|>": 73443,
+ "<|fim_middle|>": 73446,
+ "<|fim_prefix|>": 73445,
+ "<|fim_suffix|>": 73447,
+ "<|im_end|>": 73440,
+ "<|im_start|>": 73441,
+ "<|tool_call|>": 73442
+ }
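A quick sketch checking that a loaded tokenizer resolves these entries to the IDs above (the repo id is an assumption taken from "_name_or_path" in config.json below; any local path to these files works too):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("openbmb/CPM-2B")  # assumption: repo id from config.json
assert tok.convert_tokens_to_ids("<|im_start|>") == 73441
assert tok.convert_tokens_to_ids("<|im_end|>") == 73440  # also listed as an eos_token_id in config.json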
cis_pooling.py ADDED
@@ -0,0 +1,118 @@
+ import triton
+ import triton.language as tl
+ import torch
+
+ MAX_LEN = 32768
+
+ @triton.jit
+ def nosa_mean_pool_kernel(
+     cis_ptr,          # [N, H]
+     cu_seqlens_ptr,   # int32 [B+1]
+     result_ptr,       # [H, N, M]
+     cis_stride_n,     # int32
+     cis_stride_h,     # int32
+     result_stride_h,  # int32
+     result_stride_n,  # int32
+     result_stride_m,  # int32
+     N,
+     H,
+     M,
+     kernel_size: tl.constexpr,
+     stride,
+     MAX_LEN: tl.constexpr,
+ ):
+     # grid: (H, B, M)
+     tidx_h = tl.program_id(0)  # head
+     tidx_b = tl.program_id(1)  # batch idx
+     tidx_m = tl.program_id(2)  # window idx
+
+
+     batch_start = tl.load(cu_seqlens_ptr + tidx_b)
+     batch_end = tl.load(cu_seqlens_ptr + tidx_b + 1)
+
+     block_idx = tl.arange(0, kernel_size)
+
+     beg_pos = cis_ptr + tidx_h * cis_stride_h + (batch_start + tidx_m * stride) * cis_stride_n
+
+     block_cis_ptrs = beg_pos + block_idx * cis_stride_n
+     mask = (block_idx + tidx_m * stride) < (batch_end - batch_start)
+     block_scores = tl.load(
+         block_cis_ptrs,
+         mask=mask,
+         other=0.0,
+     )
+
+     # Average block_scores; the mask must be correct: the denominator is the number of valid (unmasked) elements
+     val_cnt = tl.sum(mask.to(tl.int32), axis=0)
+     acc = tl.sum(block_scores, axis=0) / val_cnt
+
+     if tidx_m * stride + kernel_size <= batch_end - batch_start:
+         write_pos = result_ptr + tidx_h * result_stride_h + batch_start * result_stride_n + tidx_m * result_stride_m
+         write_idx = tl.arange(0, MAX_LEN)
+         write_ptrs = write_pos + write_idx * result_stride_n
+         tl.store(write_ptrs, acc, mask=write_idx < batch_end - batch_start)
+
+ def nosa_mean_pooling(cis_score, cu_seqlens, max_seqlen, kernel_size=32, stride=16):
+     """
+     cis_score: [N, H] (torch.Tensor; float32/bfloat16/float16 all work, but use float32 inside Triton first)
+     cu_seqlens: [B+1] (torch.int32)
+     """
+     assert kernel_size == 32 and stride == 16
+
+     N, H = cis_score.shape
+     B = cu_seqlens.numel() - 1
+     M = max_seqlen // stride - 1  # maximum number of windows per batch
+     M = max(M, 0)  # bug fix
+     assert max_seqlen < MAX_LEN, f"Please increase MAX_LEN, MAX_LEN: {MAX_LEN}, max_seqlen: {max_seqlen}"
+
+     result = torch.zeros((H, N, M), dtype=cis_score.dtype, device=cis_score.device)
+
+     grid = (H, B, M)
+     nosa_mean_pool_kernel[grid](
+         cis_score,
+         cu_seqlens,
+         result,
+         cis_score.stride(0),
+         cis_score.stride(1),
+         result.stride(0),
+         result.stride(1),
+         result.stride(2),
+         N, H, M, kernel_size, stride, MAX_LEN
+     )
+
+     return result
+
+
+ def main():
+     torch.manual_seed(0)
+     device = "cuda"
+
+     # Synthetic data
+     B = 2
+     H = 4
+     lens = [67, 1432]  # length of each batch
+     cu_seqlens = torch.tensor([0] + torch.cumsum(torch.tensor(lens), dim=0).tolist(), dtype=torch.int32, device=device)
+     N = cu_seqlens[-1].item()
+     max_seqlen = max(lens)
+
+     cis_score = torch.randn(N, H, device=device, dtype=torch.bfloat16)
+
+     # Triton version
+     result = nosa_mean_pooling(cis_score, cu_seqlens, max_seqlen, kernel_size=32, stride=16)
+
+     # PyTorch baseline: pool each batch, then broadcast
+     M = max_seqlen // 16 - 1
+     baseline = torch.zeros((H, N, M), device=device, dtype=torch.bfloat16)
+     for b in range(B):
+         start, end = cu_seqlens[b].item(), cu_seqlens[b+1].item()
+         seq = cis_score[start:end].T.unsqueeze(0)  # [1, H, L]
+         pooled = torch.nn.functional.avg_pool1d(seq, kernel_size=32, stride=16)  # [1, H, m]
+         pooled = pooled.squeeze(0)  # [H, m]
+         baseline[:, start:end, :pooled.size(-1)] = pooled.unsqueeze(1).expand(H, end-start, pooled.size(-1))
+
+     # Check the difference
+     max_diff = (result - baseline).abs().max()
+     print("Triton vs PyTorch max diff:", max_diff.item())
+
+ if __name__ == "__main__":
+     main()
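Note on the size cap: tl.arange(0, MAX_LEN) requires a power-of-two extent, which is why this file pins MAX_LEN = 32768 and asserts max_seqlen < MAX_LEN. If longer sequences were needed, a minimal sketch of deriving a per-call bound instead (the pick_max_len helper is hypothetical, not part of this repo):

import triton

def pick_max_len(max_seqlen: int) -> int:
    # hypothetical helper: tl.arange needs a power-of-two extent,
    # so round the bound up instead of hard-coding 32768
    return triton.next_power_of_2(max_seqlen + 1)

The returned value could then be passed as the MAX_LEN constexpr in place of the module-level constant.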
config.json ADDED
@@ -0,0 +1,35 @@
+ {
+ "_name_or_path": "openbmb/CPM-2B",
+ "architectures": [
+ "SparseLlamaForCausalLM"
+ ],
+ "auto_map": {
+ "AutoModelForCausalLM": "modeling_llama_long_infllmv2.SparseLlamaForCausalLM"
+ },
+ "bos_token_id": 1,
+ "eos_token_id": [2,73440],
+ "pad_token_id": 2,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.1,
+ "intermediate_size": 16384,
+ "head_dim": 128,
+ "max_position_embeddings": 32768,
+ "num_attention_heads": 32,
+ "num_hidden_layers": 32,
+ "model_type": "llama",
+ "num_key_value_heads": 2,
+ "rms_norm_eps": 1e-06,
+ "rope_scaling": {
+ "rope_type": "longrope",
+ "attention_factor": 1.0,
+ "long_factor": [0.9977997200264581, 1.014658295992452, 1.0349680404997148, 1.059429246056193, 1.0888815016813513, 1.1243301355211495, 1.166977103606075, 1.2182568066927284, 1.2798772354275727, 1.3538666751582975, 1.4426259039919596, 1.5489853358570191, 1.6762658237220625, 1.8283407612492941, 2.0096956085876183, 2.225478927469756, 2.481536379650452, 2.784415934557119, 3.1413289096347365, 3.560047844772632, 4.048719380066383, 4.615569542115128, 5.2684819496549835, 6.014438591970396, 6.858830049237097, 7.804668263503327, 8.851768731513417, 9.99600492938444, 11.228766118181639, 12.536757560834843, 13.902257701387796, 15.303885189125953, 16.717837610115794, 18.119465097853947, 19.484965238406907, 20.792956681060105, 22.02571786985731, 23.16995406772833, 24.217054535738416, 25.16289275000465, 26.007284207271347, 26.753240849586767, 27.40615325712662, 27.973003419175363, 28.461674954469114, 28.880393889607006, 29.237306864684626, 29.540186419591297, 29.79624387177199, 30.01202719065413, 30.193382037992453, 30.34545697551969, 30.47273746338473, 30.579096895249787, 30.66785612408345, 30.741845563814174, 30.80346599254902, 30.85474569563567, 30.897392663720595, 30.932841297560394, 30.962293553185553, 30.986754758742034, 31.007064503249293, 31.02392307921529],
+ "short_factor": [0.9977997200264581, 1.014658295992452, 1.0349680404997148, 1.059429246056193, 1.0888815016813513, 1.1243301355211495, 1.166977103606075, 1.2182568066927284, 1.2798772354275727, 1.3538666751582975, 1.4426259039919596, 1.5489853358570191, 1.6762658237220625, 1.8283407612492941, 2.0096956085876183, 2.225478927469756, 2.481536379650452, 2.784415934557119, 3.1413289096347365, 3.560047844772632, 4.048719380066383, 4.615569542115128, 5.2684819496549835, 6.014438591970396, 6.858830049237097, 7.804668263503327, 8.851768731513417, 9.99600492938444, 11.228766118181639, 12.536757560834843, 13.902257701387796, 15.303885189125953, 16.717837610115794, 18.119465097853947, 19.484965238406907, 20.792956681060105, 22.02571786985731, 23.16995406772833, 24.217054535738416, 25.16289275000465, 26.007284207271347, 26.753240849586767, 27.40615325712662, 27.973003419175363, 28.461674954469114, 28.880393889607006, 29.237306864684626, 29.540186419591297, 29.79624387177199, 30.01202719065413, 30.193382037992453, 30.34545697551969, 30.47273746338473, 30.579096895249787, 30.66785612408345, 30.741845563814174, 30.80346599254902, 30.85474569563567, 30.897392663720595, 30.932841297560394, 30.962293553185553, 30.986754758742034, 31.007064503249293, 31.02392307921529],
+ "original_max_position_embeddings": 32768
+ },
+ "rope_theta": 10000.0,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.36.0",
+ "use_cache": true,
+ "vocab_size": 73448
+ }
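config.json routes AutoModelForCausalLM through auto_map to the custom SparseLlamaForCausalLM class in modeling_llama_long_infllmv2.py, so loading needs trust_remote_code=True. A minimal loading sketch; the repo id is assumed from "_name_or_path", substitute the actual Hub path:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "openbmb/CPM-2B"  # assumption: taken from "_name_or_path"; use the real repo id
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo,
    torch_dtype=torch.bfloat16,  # matches "torch_dtype" in config.json
    trust_remote_code=True,      # required to resolve the auto_map custom class
)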
generation_config.json ADDED
@@ -0,0 +1,12 @@
+ {
+ "bos_token_id": 1,
+ "do_sample": true,
+ "eos_token_id": [
+ 2,
+ 73440
+ ],
+ "pad_token_id": 2,
+ "temperature": 0.8,
+ "top_p": 0.8,
+ "transformers_version": "4.46.1"
+ }
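These defaults (sampling on, temperature 0.8, top_p 0.8, stopping on token 2 or 73440, i.e. </s> or <|im_end|>) are picked up automatically by generate(). A short sketch reusing the model and tokenizer from the loading sketch above:

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=64)  # do_sample/temperature/top_p come from generation_config.json
print(tokenizer.decode(out[0], skip_special_tokens=True))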
modeling_llama_long_infllmv2.py ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
+ {
+ "additional_special_tokens": [
+ "<|im_end|>",
+ "<|im_start|>",
+ "<|tool_call|>",
+ "<|execute_start|>",
+ "<|execute_end|>",
+ "<|fim_prefix|>",
+ "<|fim_middle|>",
+ "<|fim_suffix|>"
+ ],
+ "bos_token": {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "<|im_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "unk_token": {
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bb74d51116831c3bf65db812c553f94ab0c88dcf97a5bbb37e3504f6d359c530
+ size 1181204
tokenizer_config.json ADDED
@@ -0,0 +1,117 @@
+ {
+ "add_bos_token": true,
+ "add_eos_token": false,
+ "add_prefix_space": null,
+ "added_tokens_decoder": {
+ "0": {
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "73440": {
+ "content": "<|im_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "73441": {
+ "content": "<|im_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "73442": {
+ "content": "<|tool_call|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "73443": {
+ "content": "<|execute_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "73444": {
+ "content": "<|execute_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "73445": {
+ "content": "<|fim_prefix|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "73446": {
+ "content": "<|fim_middle|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "73447": {
+ "content": "<|fim_suffix|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "additional_special_tokens": [
+ "<|im_end|>",
+ "<|im_start|>",
+ "<|tool_call|>",
+ "<|execute_start|>",
+ "<|execute_end|>",
+ "<|fim_prefix|>",
+ "<|fim_middle|>",
+ "<|fim_suffix|>"
+ ],
+ "bos_token": "<s>",
+ "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% if enable_thinking is defined and enable_thinking is false %}{{ '<think>\n\n</think>\n' }}{% endif %}{% endif %}",
+ "clean_up_tokenization_spaces": false,
+ "eos_token": "<|im_end|>",
+ "legacy": true,
+ "model_max_length": 1000000000000000019884624838656,
+ "pad_token": null,
+ "sp_model_kwargs": {},
+ "spaces_between_special_tokens": false,
+ "tokenizer_class": "LlamaTokenizer",
+ "unk_token": "<unk>",
+ "use_default_system_prompt": false
+ }
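The chat_template above wraps every turn as <|im_start|>role\n...<|im_end|>\n and, when add_generation_prompt is set, appends an assistant header (plus an empty <think> block when enable_thinking is explicitly false). A rendering sketch, reusing the tokenizer from the earlier loading sketch:

messages = [{"role": "user", "content": "Hi"}]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,  # appends "<|im_start|>assistant\n"
)
print(prompt)  # "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n"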