msj19 commited on
Commit
65775f0
·
verified ·
1 Parent(s): 7652cf9

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/chunk_h_bwd.py +173 -0
  2. build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/chunk_h_fwd.py +173 -0
  3. build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/chunk_o_bwd.py +428 -0
  4. build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/chunk_o_fwd.py +123 -0
  5. build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/fused_recurrent.py +273 -0
  6. build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/naive.py +96 -0
  7. build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/wy_fast_bwd.py +164 -0
  8. build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/wy_fast_fwd.py +284 -0
  9. build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/__init__.py +7 -0
  10. build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/chunk.py +500 -0
  11. build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/fused_recurrent.py +452 -0
  12. build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/naive.py +69 -0
  13. build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/wy_fast.py +300 -0
  14. docs/en/.readthedocs.yaml +17 -0
  15. docs/en/Makefile +20 -0
  16. docs/en/_static/css/readthedocs.css +62 -0
  17. docs/en/_static/image/logo.svg +79 -0
  18. docs/en/_static/image/logo_icon.svg +31 -0
  19. docs/en/_static/js/custom.js +20 -0
  20. docs/en/_templates/404.html +18 -0
  21. docs/en/_templates/autosummary/class.rst +13 -0
  22. docs/en/_templates/callable.rst +14 -0
  23. docs/en/advanced_guides/accelerator_intro.md +142 -0
  24. docs/en/advanced_guides/circular_eval.md +113 -0
  25. docs/en/advanced_guides/code_eval.md +104 -0
  26. docs/en/advanced_guides/code_eval_service.md +224 -0
  27. docs/en/advanced_guides/contamination_eval.md +124 -0
  28. docs/en/advanced_guides/custom_dataset.md +267 -0
  29. docs/en/advanced_guides/evaluation_lightllm.md +71 -0
  30. docs/en/advanced_guides/evaluation_lmdeploy.md +88 -0
  31. docs/en/advanced_guides/llm_judge.md +370 -0
  32. docs/en/advanced_guides/longeval.md +169 -0
  33. docs/en/advanced_guides/math_verify.md +190 -0
  34. docs/en/advanced_guides/needleinahaystack_eval.md +138 -0
  35. docs/en/advanced_guides/new_dataset.md +105 -0
  36. docs/en/advanced_guides/new_model.md +73 -0
  37. docs/en/advanced_guides/objective_judgelm_evaluation.md +186 -0
  38. docs/en/advanced_guides/persistence.md +65 -0
  39. docs/en/advanced_guides/prompt_attack.md +108 -0
  40. docs/en/advanced_guides/subjective_evaluation.md +171 -0
  41. docs/en/conf.py +234 -0
  42. docs/en/docutils.conf +2 -0
  43. docs/en/get_started/faq.md +128 -0
  44. docs/en/get_started/installation.md +142 -0
  45. docs/en/get_started/quick_start.md +300 -0
  46. docs/en/index.rst +99 -0
  47. docs/en/notes/academic.md +106 -0
  48. docs/en/notes/contribution_guide.md +158 -0
  49. docs/en/notes/news.md +40 -0
  50. docs/en/prompt/chain_of_thought.md +127 -0
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/chunk_h_bwd.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices, prepare_chunk_offsets
11
+ from ....ops.utils.op import exp
12
+ from ....utils import check_shared_mem, use_cuda_graph
13
+
14
+
15
@triton.heuristics({
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'BK', 'BV', "V"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_bwd_kernel_dhu(
    qg,
    bg,
    w,
    gk,
    dht,
    dh0,
    do,
    dh,
    dv,
    dv2,
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Backward recurrence over the chunked DPLR hidden state:
    # walks chunks in reverse time order, writing the state gradient seen at
    # the *start* of each chunk into `dh` and the corrected value gradients
    # into `dv2`. One program handles one (K-tile, V-tile, sequence*head).
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        # Variable-length batch: per-sequence bounds come from cu_seqlens,
        # per-sequence chunk base offset from chunk_offsets.
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV] running state gradient, accumulated in fp32.
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        # Seed with the gradient flowing in from the final state.
        p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_dh += tl.load(p_dht, boundary_check=(0, 1))

    mask_k = tl.arange(0, BK) < K
    for i_t in range(NT - 1, -1, -1):
        # The gradient carried into chunk i_t is exactly the current b_dh.
        p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
        b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
        # Sub-chunk in reverse to bound SRAM usage.
        for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
            p_qg = tl.make_block_ptr(qg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
            p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            # [BK, BT]
            b_qg = tl.load(p_qg, boundary_check=(0, 1))
            # [BT, BK]
            b_bg = tl.load(p_bg, boundary_check=(0, 1))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            # [BT, V]
            b_do = tl.load(p_do, boundary_check=(0, 1))
            b_dv = tl.load(p_dv, boundary_check=(0, 1))
            # Correct dv with the contribution routed through the state gradient.
            b_dv2 = b_dv + tl.dot(b_bg, b_dh.to(b_bg.dtype))
            tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
            # [BK, BV] — accumulate this chunk's contribution to the state gradient.
            b_dh_tmp += tl.dot(b_qg, b_do.to(b_qg.dtype))
            b_dh_tmp += tl.dot(b_w, b_dv2.to(b_qg.dtype))
        # Decay the carried gradient by the per-key gate at the chunk boundary
        # (gk is exponentiated here, so it is presumably stored in log space —
        # TODO(review): confirm against the forward pass).
        last_idx = min((i_t + 1) * BT, T) - 1
        bg_last = tl.load(gk + ((bos + last_idx) * H + i_h) * K + tl.arange(0, BK), mask=mask_k)
        b_dh *= exp(bg_last)[:, None]
        b_dh += b_dh_tmp

    if USE_INITIAL_STATE:
        # Remaining gradient is the gradient w.r.t. the initial state.
        p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
106
+
107
+
108
def chunk_dplr_bwd_dhu(
    qg: torch.Tensor,
    bg: torch.Tensor,
    w: torch.Tensor,
    gk: torch.Tensor,
    h0: torch.Tensor,
    dht: Optional[torch.Tensor],
    do: torch.Tensor,
    dv: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch the chunked DPLR state backward kernel.

    Returns the per-chunk state gradients ``dh``, the initial-state gradient
    ``dh0`` (``None`` when ``h0`` is ``None``) and the corrected value
    gradients ``dv2``.
    """
    B, T, H, K = qg.shape
    V = do.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension being larger than 256."

    # Pick tile sizes by how much shared memory the device class offers.
    if check_shared_mem('hopper', qg.device.index):      # H100
        BV, BC = 64, (64 if K <= 128 else 32)
    elif check_shared_mem('ampere', qg.device.index):    # A100
        BV, BC = 32, 32
    else:                                                # e.g. 4090
        BV, BC = 16, 16
    BC = min(BT, BC)

    # N: the actual number of sequences in the batch (equal or variable lengths).
    if cu_seqlens is None:
        chunk_indices = None
        N = B
        NT = triton.cdiv(T, BT)
        chunk_offsets = None
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        N = len(cu_seqlens) - 1
        NT = len(chunk_indices)
        chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)

    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    dh = qg.new_empty(B, NT, H, K, V)
    dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
    dv2 = torch.zeros_like(dv)

    chunk_dplr_bwd_kernel_dhu[(NK, NV, N * H)](
        qg=qg,
        bg=bg,
        w=w,
        gk=gk,
        dht=dht,
        dh0=dh0,
        do=do,
        dh=dh,
        dv=dv,
        dv2=dv2,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
    )
    return dh, dh0, dv2
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/chunk_h_fwd.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices, prepare_chunk_offsets
11
+ from ....ops.utils.op import exp
12
+ from ....utils import check_shared_mem, use_cuda_graph
13
+
14
+
15
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'BK', 'BV'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_kernel_h(
    kg,
    v,
    w,
    bg,
    u,
    v_new,
    gk,
    h,
    h0,
    ht,
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Forward recurrence over the chunked DPLR hidden state:
    # walks chunks forward in time, storing the state seen at the *start* of
    # each chunk into `h`, writing corrected values into `v_new`, and
    # optionally storing the final state into `ht`.
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        # Variable-length batch: sequence bounds from cu_seqlens, chunk base
        # offset from chunk_offsets.
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT
    o_k = i_k * BK + tl.arange(0, BK)

    # [BK, BV] running state, accumulated in fp32.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT):
        # The state at the start of chunk i_t is the current b_h.
        p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))

        b_hc = tl.zeros([BK, BV], dtype=tl.float32)
        # Since we need all DK in SRAM we face severe SRAM memory pressure;
        # sub-chunking (BC-sized tiles) alleviates that burden.
        for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
            p_kg = tl.make_block_ptr(kg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
            p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
            # [BK, BC]
            b_kg = tl.load(p_kg, boundary_check=(0, 1))
            b_v = tl.load(p_v, boundary_check=(0, 1))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_bg = tl.load(p_bg, boundary_check=(0, 1))
            # Corrected values: u plus the projection of the current state by w.
            b_v2 = tl.dot(b_w, b_h.to(b_w.dtype)) + tl.load(p_u, boundary_check=(0, 1))
            b_hc += tl.dot(b_kg, b_v)
            b_hc += tl.dot(b_bg.to(b_hc.dtype), b_v2)
            tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))

        # Decay the state by the per-key gate at the chunk boundary, then add
        # this chunk's contribution (gk exponentiated here — presumably stored
        # in log space; TODO(review) confirm against the producer of gk).
        last_idx = min((i_t + 1) * BT, T) - 1
        b_g_last = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k, mask=o_k < K).to(tl.float32)
        b_h *= exp(b_g_last[:, None])
        b_h += b_hc

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
105
+
106
+
107
def chunk_dplr_fwd_h(
    kg: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    bg: torch.Tensor,
    gk: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Launch the chunked DPLR forward state kernel.

    Fix: the return annotation previously declared a 2-tuple while the
    function has always returned three values; it now matches the actual
    ``(h, v_new, final_state)`` return.

    Args:
        kg, w, bg, gk: [B, T, H, K] tensors (K inferred from ``kg.shape``).
        v, u: value tensors; V is inferred from ``u.shape[-1]``.
        initial_state: optional initial state ``h0`` passed to the kernel.
        output_final_state: when True, allocate and fill a float32 final state.
        cu_seqlens: cumulative sequence lengths for variable-length batches.
        chunk_size: upper bound on the time-chunk tile ``BT``.

    Returns:
        h: per-chunk states of shape ``[B, NT, H, K, V]``.
        v_new: corrected values, same shape as ``u``.
        final_state: ``[N, H, K, V]`` float32 tensor, or ``None``.
    """
    B, T, H, K, V = *kg.shape, u.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension larger than 256."

    # Tile sizes by device shared-memory class; H100 can take larger blocks.
    if check_shared_mem('hopper', kg.device.index):
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem('ampere', kg.device.index):  # A100
        BV = 32
        BC = 32
    else:
        BV = 16
        BC = 16

    BC = min(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    h = kg.new_empty(B, NT, H, K, V)
    final_state = kg.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    v_new = torch.empty_like(u)
    grid = (NK, NV, N * H)
    chunk_dplr_fwd_kernel_h[grid](
        kg=kg,
        v=v,
        w=w,
        bg=bg,
        u=u,
        v_new=v_new,
        h=h,
        gk=gk,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
    )
    return h, v_new, final_state
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/chunk_o_bwd.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices
11
+ from ....ops.utils.op import exp
12
+ from ....utils import check_shared_mem, use_cuda_graph
13
+
14
+ BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32]
15
+
16
+
17
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BV', 'BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_bwd_kernel_dAu(
    v,
    do,
    v_new,
    A_qb,
    dA_qk,
    dA_qb,
    dv_new,
    cu_seqlens,
    chunk_indices,
    scale: tl.constexpr,
    T,
    H: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Backward for the intra-chunk attention scores: accumulates dA_qk and
    # dA_qb (both [BT, BT], causally masked and scaled) and the gradient
    # dv_new routed back through the causal A_qb matrix.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos

    b_dA_qk = tl.zeros([BT, BT], dtype=tl.float32)
    b_dA_qb = tl.zeros([BT, BT], dtype=tl.float32)

    p_A_qb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    b_A_qb = tl.load(p_A_qb, boundary_check=(0, 1))
    # causal mask: keep only the lower triangle (including the diagonal)
    b_A_qb = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_A_qb, 0.).to(b_A_qb.dtype)

    # Sweep over V in BV-sized tiles, accumulating both score gradients.
    for i_v in range(tl.cdiv(V, BV)):
        p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1))
        p_v_new = tl.make_block_ptr(v_new + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1))
        p_dv_new = tl.make_block_ptr(dv_new + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
        b_dA_qk += tl.dot(b_do, b_v)
        b_dA_qb += tl.dot(b_do, b_v_new)
        b_dv_new = tl.dot(tl.trans(b_A_qb), b_do)
        # for recurrent
        tl.store(p_dv_new, b_dv_new.to(p_dv_new.dtype.element_ty), boundary_check=(0, 1))

    p_dA_qk = tl.make_block_ptr(dA_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_dA_qb = tl.make_block_ptr(dA_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    # Apply scale and the causal mask to both accumulated gradients.
    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
    b_dA_qk = tl.where(m_s, b_dA_qk * scale, 0.)
    tl.store(p_dA_qk, b_dA_qk.to(p_dA_qk.dtype.element_ty), boundary_check=(0, 1))
    b_dA_qb = tl.where(m_s, b_dA_qb * scale, 0.)
    tl.store(p_dA_qb, b_dA_qb.to(p_dA_qb.dtype.element_ty), boundary_check=(0, 1))
87
+
88
+
89
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'BK', 'BV'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def chunk_dplr_bwd_o_kernel(
    v,
    v_new,
    h,
    do,
    dh,
    dk,
    db,
    w,
    dq,
    dv,
    dw,
    gk,
    dgk_last,
    k,
    b,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Backward of the output projection: for one (K-tile, chunk, batch*head)
    # program, produces dq, dk, db, dw for the chunk and the gradient of the
    # chunk-boundary gate dgk_last.
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if IS_VARLEN:
        # chunk_indices stores (sequence id, chunk id within sequence) pairs.
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # offset calculation — advance every base pointer to this (batch, head);
    # h/dh/dgk_last are per-chunk tensors and use the global chunk id i_tg.
    v += (bos * H + i_h) * V
    v_new += (bos * H + i_h) * V
    do += (bos * H + i_h) * V
    h += (i_tg * H + i_h) * K * V
    dh += (i_tg * H + i_h) * K * V
    dk += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    db += (bos * H + i_h) * K
    b += (bos * H + i_h) * K
    dw += (bos * H + i_h) * K
    dv += (bos * H + i_h) * V
    dq += (bos * H + i_h) * K
    w += (bos * H + i_h) * K

    dgk_last += (i_tg * H + i_h) * K
    gk += (bos * H + i_h) * K

    stride_qk = H*K
    stride_vo = H*V

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_dw = tl.zeros([BT, BK], dtype=tl.float32)
    b_db = tl.zeros([BT, BK], dtype=tl.float32)
    b_dgk_last = tl.zeros([BK], dtype=tl.float32)

    # Sweep over V in BV-sized tiles, accumulating all K-side gradients.
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        b_dgk_last += tl.sum((b_h * b_dh).to(tl.float32), axis=0)

        # [BT, BV] @ [BV, BK] -> [BT, BK]
        b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
        # [BT, BV] @ [BV, BK] -> [BT, BK]
        b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
        b_db += tl.dot(b_v_new, b_dh.to(b_v_new.dtype))
        p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_dv = tl.load(p_dv, boundary_check=(0, 1))
        b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))

    # Gate gradient at the chunk boundary: decay by exp(gk of the last valid
    # step), then add the direct contributions through k and b.
    m_k = (i_k*BK+tl.arange(0, BK)) < K
    last_idx = min(i_t * BT + BT, T) - 1
    b_gk_last = tl.load(gk + last_idx * stride_qk + i_k*BK + tl.arange(0, BK), mask=m_k, other=float('-inf'))
    b_dgk_last *= exp(b_gk_last)
    p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_b = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_b = tl.load(p_b, boundary_check=(0, 1))
    b_dgk_last += tl.sum(b_k * b_dk, axis=0)
    b_dgk_last += tl.sum(b_b * b_db, axis=0)
    tl.store(dgk_last + tl.arange(0, BK) + i_k * BK, b_dgk_last, mask=m_k)

    p_dw = tl.make_block_ptr(dw, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_db = tl.make_block_ptr(db, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
214
+
215
+
216
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
        for BK in BK_LIST
        for BV in BK_LIST
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def chunk_dplr_bwd_kernel_dv(
    A_qk,
    kg,
    do,
    dv,
    dh,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Value gradient: dv = kg @ dh (inter-chunk, via the state gradient)
    # plus A_qk^T @ do (intra-chunk, strictly upper-masked transpose).
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    b_dv = tl.zeros([BT, BV], dtype=tl.float32)

    # offset calculation — advance base pointers to this (batch, head);
    # dh is per-chunk and uses the global chunk id i_tg.
    A_qk += (bos * H + i_h) * BT
    do += (bos * H + i_h) * V
    dv += (bos * H + i_h) * V
    kg += (bos * H + i_h) * K
    dh += (i_tg * H + i_h) * K*V

    stride_qk = H*K
    stride_vo = H*V
    stride_A = H*BT

    # Inter-chunk term: sweep K in BK tiles, accumulate kg @ dh.
    for i_k in range(tl.cdiv(K, BK)):
        p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_kg = tl.make_block_ptr(kg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        b_kg = tl.load(p_kg, boundary_check=(0, 1))
        b_dv += tl.dot(b_kg, b_dh.to(b_kg.dtype))

    # Intra-chunk term: transposed A_qk masked to row <= col (the transpose
    # of the causal mask), applied to do.
    p_Aqk = tl.make_block_ptr(A_qk, (BT, T), (1, stride_A), (0, i_t * BT), (BT, BT), (0, 1))
    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], tl.load(p_Aqk, boundary_check=(0, 1)), 0)
    p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    b_dv += tl.dot(b_A.to(b_do.dtype), b_do)
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
288
+
289
+
290
def chunk_dplr_bwd_dv(
    A_qk: torch.Tensor,
    kg: torch.Tensor,
    do: torch.Tensor,
    dh: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> torch.Tensor:
    """Launch the value-gradient kernel and return ``dv`` (same shape as ``do``)."""
    B, T, H, K = kg.shape
    V = do.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)

    dv = torch.empty_like(do)

    # BV is picked by the autotuner, so the grid is resolved lazily from meta.
    chunk_dplr_bwd_kernel_dv[lambda meta: (triton.cdiv(V, meta['BV']), NT, B * H)](
        A_qk=A_qk,
        kg=kg,
        do=do,
        dv=dv,
        dh=dh,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return dv
322
+
323
+
324
def chunk_dplr_bwd_o(
    k: torch.Tensor,
    b: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    gk: torch.Tensor,
    do: torch.Tensor,
    h: torch.Tensor,
    dh: torch.Tensor,
    dv: torch.Tensor,
    w: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
    scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch the output-projection backward kernel.

    Fixes:
      * ``BV`` on low-shared-memory devices was computed from ``K``
        (``min(next_power_of_2(K), 32)``); it tiles the V dimension and is
        now computed from ``V``.
      * The return annotation declared a 3-tuple while the function has
        always returned five tensors.

    Note: ``scale`` is accepted but not forwarded to the kernel (the kernel
    takes no scale argument); it is kept for interface compatibility.

    Returns:
        dq, dk, dw, db: gradients shaped like their forward counterparts.
        dgk_last: per-chunk gate gradients of shape [B, NT, H, K] (float32).
    """
    B, T, H, K, V = *w.shape, v.shape[-1]

    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    # Larger tiles when the device has enough shared memory.
    if check_shared_mem():
        BK = min(triton.next_power_of_2(K), 64)
        BV = min(triton.next_power_of_2(V), 64)
    else:
        BK = min(triton.next_power_of_2(K), 32)
        # BV tiles the V dimension; previously (incorrectly) derived from K.
        BV = min(triton.next_power_of_2(V), 32)
    NK = triton.cdiv(K, BK)
    dq = torch.empty_like(k)
    dk = torch.empty_like(k)
    dw = torch.empty_like(w)
    db = torch.empty_like(b)
    grid = (NK, NT, B * H)

    dgk_last = torch.empty(B, NT, H, K, dtype=torch.float, device=w.device)

    chunk_dplr_bwd_o_kernel[grid](
        k=k,
        b=b,
        v=v,
        v_new=v_new,
        h=h,
        do=do,
        dh=dh,
        dq=dq,
        dk=dk,
        db=db,
        dgk_last=dgk_last,
        w=w,
        dv=dv,
        dw=dw,
        gk=gk,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return dq, dk, dw, db, dgk_last
384
+
385
+
386
def chunk_dplr_bwd_dAu(
    v: torch.Tensor,
    v_new: torch.Tensor,
    do: torch.Tensor,
    A_qb: torch.Tensor,
    scale: float,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch the attention-score backward kernel.

    Fix: the return annotation previously declared a single ``torch.Tensor``
    while the function has always returned a 3-tuple; it now matches the
    actual ``(dv_new, dA_qk, dA_qb)`` return.

    Returns:
        dv_new: gradient routed through A_qb, same shape as ``v_new``.
        dA_qk, dA_qb: per-chunk score gradients of shape [B, T, H, BT]
            (float32, causally masked and scaled inside the kernel).
    """
    B, T, H, V = v.shape
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    # V-tile size by device shared-memory class.
    if check_shared_mem('ampere'):  # A100
        BV = min(triton.next_power_of_2(V), 128)
    elif check_shared_mem('ada'):  # 4090
        BV = min(triton.next_power_of_2(V), 64)
    else:
        BV = min(triton.next_power_of_2(V), 32)

    grid = (NT, B * H)
    dA_qk = torch.empty(B, T, H, BT, dtype=torch.float, device=v.device)
    dA_qb = torch.empty(B, T, H, BT, dtype=torch.float, device=v.device)
    dv_new = torch.empty_like(v_new)
    chunk_dplr_bwd_kernel_dAu[grid](
        v=v,
        do=do,
        v_new=v_new,
        A_qb=A_qb,
        dA_qk=dA_qk,
        dA_qb=dA_qb,
        dv_new=dv_new,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        V=V,
        BT=BT,
        BV=BV,
    )
    return dv_new, dA_qk, dA_qb
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/chunk_o_fwd.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices
11
+ from ....utils import check_shared_mem, use_cuda_graph
12
+
13
+ BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32]
14
+
15
+
16
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BK_LIST
        for BV in BK_LIST
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_kernel_o(
    qg,
    v,
    v_new,
    A_qk,
    A_qb,
    h,
    o,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Forward output kernel. One program per (value tile i_v, chunk i_t,
    # batch*head i_bh). It combines:
    #   * the inter-chunk term  qg @ h      (h = per-chunk recurrent state),
    #   * the intra-chunk terms tril(A_qk) @ v  and  tril(A_qb) @ v_new.
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if IS_VARLEN:
        # Remap the flat chunk id to (sequence, local chunk) and re-derive the
        # per-sequence length T from cu_seqlens.
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t  # global chunk index into the state tensor h
        bos, eos = i_b * T, i_b * T + T

    # Inter-chunk contribution, accumulated in fp32 over K tiles.
    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_qg = tl.load(p_qg, boundary_check=(0, 1))
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_o += tl.dot(b_qg, b_h)

    p_Aqk = tl.make_block_ptr(A_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_Aqb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_v_new = tl.make_block_ptr(v_new + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

    # Causal mask (diagonal included): row i may attend to columns j <= i.
    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
    b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1))
    b_Aqb = tl.load(p_Aqb, boundary_check=(0, 1))
    b_Aqk = tl.where(m_s, b_Aqk, 0)
    b_Aqb = tl.where(m_s, b_Aqb, 0)
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
    b_o = b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_v_new.dtype), b_v_new)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
87
+
88
+
89
def chunk_dplr_fwd_o(
    qg: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    A_qk: torch.Tensor,
    A_qb: torch.Tensor,
    h: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> torch.Tensor:
    """Compute the chunked DPLR forward output by launching the Triton kernel.

    Args:
        qg: decayed queries, `[B, T, H, K]`.
        v: values, `[B, T, H, V]`.
        v_new: pseudo-values from the forward pass, `[B, T, H, V]`.
        A_qk: intra-chunk q/k score matrix, `[B, T, H, BT]`.
        A_qb: intra-chunk q/b score matrix, `[B, T, H, BT]`.
        h: per-chunk recurrent states.
        cu_seqlens: optional `[N + 1]` cumulative lengths for packed batches.
        chunk_size: nominal chunk length (clamped to a power of two >= 16).

    Returns:
        Output tensor with the same shape and dtype as `v`.
    """
    B, T, H, K = qg.shape
    V = v.shape[-1]
    # Effective chunk length: power of two, at least 16, capped by chunk_size.
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)

    o = torch.empty_like(v)

    # BV is chosen by the autotuner, so the grid is a function of the config.
    def grid(meta):
        return (triton.cdiv(V, meta['BV']), NT, B * H)

    chunk_dplr_fwd_kernel_o[grid](
        qg=qg,
        v=v,
        v_new=v_new,
        A_qk=A_qk,
        A_qb=A_qb,
        h=h,
        o=o,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return o
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/fused_recurrent.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils.op import exp
11
+ from ....utils import autocast_custom_bwd, autocast_custom_fwd, input_guard, use_cuda_graph
12
+
13
+
14
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BV in [16, 32, 64]
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['BK'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_dplr_delta_rule_fwd_kernel(
    q,
    k,
    v,
    a,
    b,
    gk,
    o,
    h0,
    ht,
    cu_seqlens,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    REVERSE: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Token-by-token DPLR recurrence. Per step (see the update of b_h below):
    #   S_t = diag(exp(gk_t)) S_{t-1} + b_t (a_t^T S_{t-1}) + k_t v_t^T
    #   o_t = S_t^T (q_t * scale)
    # b_h holds S transposed as [BV, BK]; one program per (value tile, seq*head).
    i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    i_n, i_h = i_nh // H, i_nh % H

    if IS_VARLEN:
        # Variable-length: sequence boundaries come from cu_seqlens.
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T

    o_k = tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    # Start at the last token when REVERSE, else at the first.
    p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_b = b + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
    p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[None, :] & mask_v[:, None]
    # State tile [BV, BK]: b_h[v, k] == S[k, v], accumulated in fp32.
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        # h0 is laid out [K, V]; the transposed indexing loads it into [BV, BK].
        p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        # tmp = a_t^T S_{t-1} (uses the pre-decay state).
        tmp = tl.sum(b_h * b_a[None, :], axis=1)
        b_h = exp(b_gk)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
        b_o = tl.sum(b_h * b_q[None, :], axis=1)

        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
        # Advance all pointers by one token (backwards when REVERSE).
        p_q += (-1 if REVERSE else 1) * H*K
        p_k += (-1 if REVERSE else 1) * H*K
        p_a += (-1 if REVERSE else 1) * H*K
        p_b += (-1 if REVERSE else 1) * H*K
        p_gk += (-1 if REVERSE else 1) * H*K
        p_v += (-1 if REVERSE else 1) * H*V
        p_o += (-1 if REVERSE else 1) * H*V

    if STORE_FINAL_STATE:
        # Store back in [K, V] layout (transpose of b_h).
        p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
106
+
107
+
108
def fused_recurrent_dplr_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: Optional[float] = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> "tuple[torch.Tensor, Optional[torch.Tensor]]":
    """Host-side launcher for the fused recurrent DPLR forward kernel.

    Returns `(o, ht)` where `o` matches `v` in shape/dtype and `ht` is the
    fp32 final state `[N, H, K, V]` (or `None` if not requested).
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    # N = number of sequences (batch size, or len(cu_seqlens)-1 when packed).
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK = triton.next_power_of_2(K)

    h0 = initial_state
    if output_final_state:
        ht = q.new_empty(N, H, K, V, dtype=torch.float32)
    else:
        ht = None
    o = torch.empty_like(v)

    # BV comes from the autotuner config, hence the grid closure over meta.
    def grid(meta): return (triton.cdiv(V, meta['BV']), N * H)
    # NOTE: the leading arguments are positional and must stay in the exact
    # order of the kernel signature (q, k, v, a, b, gk, o, h0, ht, cu_seqlens, scale).
    fused_recurrent_dplr_delta_rule_fwd_kernel[grid](
        q,
        k,
        v,
        a,
        b,
        gk,
        o,
        h0,
        ht,
        cu_seqlens,
        scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BK=BK,
        REVERSE=reverse,
    )
    return o, ht
154
+
155
+
156
class FusedRecurrentDPLRDeltaRuleFunction(torch.autograd.Function):
    """Autograd wrapper around the fused recurrent DPLR delta-rule kernel.

    Forward-only by design: `backward` raises `NotImplementedError`, since this
    recurrent kernel is intended for inference; the chunked implementation
    should be used for training (see the error message below).
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        gk: torch.Tensor,
        scale: Optional[float] = 1.0,
        initial_state: Optional[torch.Tensor] = None,
        output_final_state: bool = False,
        reverse: bool = False,
        cu_seqlens: Optional[torch.LongTensor] = None,
    ):
        # Nothing is saved on `ctx`: backward is unsupported on purpose.
        o, ht = fused_recurrent_dplr_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            gk=gk,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            reverse=reverse,
            cu_seqlens=cu_seqlens,
        )
        return o, ht

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht):
        raise NotImplementedError(
            "Backward pass for fused_recurrent_dplr_delta_rule is not implemented and will not be supported. "
            "This kernel is only for inference. "
            "For training, please use `chunk_dplr_delta_rule`."
        )
199
+
200
+
201
def fused_recurrent_dplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: Optional[float] = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Computes the DPLR (diagonal-plus-low-rank) delta-rule recurrence token by
    token. As implemented by the underlying kernel, each step performs

    .. math::
        S_t = \mathrm{diag}(\exp(gk_t))\, S_{t-1} + b_t (a_t^\top S_{t-1}) + k_t v_t^\top,
        \qquad o_t = S_t^\top (q_t \cdot \mathrm{scale}).

    This fused recurrent kernel is inference-only (no backward pass).

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        a (torch.Tensor):
            a of shape `[B, T, H, K]`.
        b (torch.Tensor):
            b of shape `[B, T, H, K]`.
        gk (torch.Tensor):
            gk of shape `[B, T, H, K]`. decay term in log space!
        scale (Optional[int]):
            Scale factor applied to the attention scores.
            If not provided (`None`), it will default to `1 / sqrt(K)`. Default: 1.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        reverse (Optional[bool]):
            If `True`, process the state passing in reverse order. Default: `False`.
        cu_seqlens (Optional[torch.Tensor]):
            Cumulative sequence lengths of shape `[N + 1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        A tuple `(o, final_state)`: the output `[B, T, H, V]` and the final
        state `[N, H, K, V]` (or `None` if `output_final_state` is `False`).
    """
    if cu_seqlens is not None:
        # Packed variable-length inputs must be pre-flattened into batch dim 1.
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = q.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    o, final_state = FusedRecurrentDPLRDeltaRuleFunction.apply(
        q,
        k,
        v,
        a,
        b,
        gk,
        scale,
        initial_state,
        output_final_state,
        reverse,
        cu_seqlens,
    )
    return o, final_state
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/naive.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+ # S_t = S_t @ (I + alpha_t beta_t^T) + v_t k_t^T
7
+ # q, k, alpha, beta [B, H, L, D_K]
8
+ # v [B, H, L, D_V]
9
+
10
+
11
def dplr_recurrence(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True):
    """Reference (non-fused) DPLR delta-rule recurrence in plain PyTorch.

    Per step: S_t = diag(exp(gk_t)) S_{t-1} + beta_t (alpha_t^T S_{t-1}) + k_t v_t^T,
    with the low-rank term computed from the *pre-decay* state, and
    o_t = q_t^T S_t (queries pre-scaled by d_k**-0.5).

    Inputs are `[B, H, L, D]`; returns `(o, S)` with `o` cast back to the
    input dtype and `S` the final `[B, H, D_K, D_V]` state (or `None`).
    """
    in_dtype = q.dtype
    batch, heads, seq_len, key_dim = q.shape
    # NOTE: alpha is intentionally left in its original dtype, matching the
    # original implementation; it promotes to fp32 inside the update anyway.
    q, k, v, beta, gk = (t.float() for t in (q, k, v, beta, gk))
    val_dim = v.shape[-1]
    out = torch.zeros_like(v)
    state = torch.zeros(batch, heads, key_dim, val_dim).to(v)
    q = q * (key_dim ** -0.5)

    if initial_state is not None:
        state = state + initial_state

    for t in range(seq_len):
        k_t = k[:, :, t]
        q_t = q[:, :, t]
        v_t = v[:, :, t]
        a_t = alpha[:, :, t].clone()
        b_t = beta[:, :, t].clone()
        # rank-1 write: k_t v_t^T, plus the low-rank correction from alpha/beta
        rank1 = k_t[..., None] * v_t[..., None, :]
        low_rank = (state.clone() * a_t[..., None]).sum(-2, keepdim=True) * b_t[..., None]
        # decay the old state, then add both increments
        state = state.clone() * gk[:, :, t].exp()[..., None] + (rank1 + low_rank)
        out[:, :, t] = torch.einsum('bhd,bhdm->bhm', q_t, state)

    final = state if output_final_state is not False else None
    return out.to(in_dtype), final
34
+
35
+
36
def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True, chunk_size=32):
    """Reference chunkwise (blocked) DPLR delta-rule computation.

    Mathematically equivalent to `dplr_recurrence`, but processes the sequence
    in chunks of `chunk_size`, carrying the state `S` across chunk boundaries.
    Inputs are `[B, H, L, D]`; `L` must be divisible by `chunk_size`.
    Returns `(o, S)` with `o` of shape `[B, H, L, D_V]`.
    """
    b, h, l, d_k = q.shape
    d_v = v.shape[-1]
    q = q * (d_k ** -0.5)
    v = v  # no-op kept from the original
    assert l % chunk_size == 0

    S = k.new_zeros(b, h, d_k, d_v).to(q)
    if initial_state is not None:
        S += initial_state

    # note that diagonal is masked.
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
    # Reshape every input into [B, H, num_chunks, chunk_size, D] in fp32.
    q, k, v, alpha, beta, gk = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d',
                                                       c=chunk_size).float(), [q, k, v, alpha, beta, gk])

    # Cumulative log-decay within each chunk.
    gk_cumsum = gk.cumsum(-2)

    # v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v
    A_ab = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
    A_qk = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
    A_ak = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
    A_qb = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)

    # Build the four intra-chunk score matrices row by row, with the relative
    # decay attn_i = exp(gk_cumsum[i] - gk_cumsum[j]) applied to row i.
    for i in range(chunk_size):
        alpha_i = alpha[:, :, :, i, None]
        q_i = q[:, :, :, i, None]
        gk_i = gk_cumsum[:, :, :, i, None]
        # q-rows may see positions j <= i (diagonal included).
        mask = (torch.arange(chunk_size) <= i).to(q.device)
        attn_i = (gk_i - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
        A_qk[:, :, :, i, :] = (q_i * k * attn_i).sum(-1).clone()
        A_qb[:, :, :, i, :] = (q_i * beta * attn_i).sum(-1).clone()
        # alpha-rows are strictly causal: positions j < i only.
        mask = (torch.arange(chunk_size) < i).to(q.device)
        # shift by one.
        attn_i = (gk_i - gk[:, :, :, i, None] - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
        A_ab[:, :, :, i, :] = (alpha_i * beta * attn_i).sum(-1).clone()
        A_ak[:, :, :, i, :] = (alpha_i * k * attn_i).sum(-1).clone()

    A_ab = A_ab  # no-op kept from the original
    # Forward substitution: turn strictly-lower A_ab into (I - A_ab)^{-1} - I ...
    for i in range(1, chunk_size):
        A_ab[..., i, :i] = A_ab[..., i, :i].clone() + (A_ab[..., i, :, None].clone() * A_ab[..., :, :i].clone()).sum(-2)

    # ... then add I to get the full inverse used below.
    A_ab = A_ab + torch.eye(chunk_size, dtype=torch.float, device=q.device)
    u = A_ab @ (A_ak @ v)
    w = A_ab @ ((gk_cumsum-gk).exp() * alpha)

    o = torch.zeros_like(v)
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
    # Per-chunk pass: combine intra-chunk terms with the carried state S.
    for i in range(0, l // chunk_size):
        q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i]
        v2_i = u_i + w_i @ S  # pseudo-values for this chunk

        o_1 = A_qk[:, :, i] @ v_i          # intra-chunk q/k term
        o_2 = A_qb[:, :, i] @ v2_i         # intra-chunk q/beta term
        o_3 = (q_i * gk_cumsum[:, :, i].exp()) @ S  # inter-chunk (state) term
        o[:, :, i] = o_1 + o_2 + o_3
        # Decay from each position to the chunk end, then fold the chunk into S.
        decay = (gk_cumsum[:, :, i, -1, None] - gk_cumsum[:, :, i]).exp()
        S = S*gk_cumsum[:, :, i, -1, :, None].exp() + (k_i * decay).transpose(-1, -2) @ v_i + \
            (beta_i * decay).transpose(-1, -2) @ v2_i
    S = None if output_final_state is False else S
    return rearrange(o, 'b h n c d -> b h (n c) d'), S
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/wy_fast_bwd.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices
11
+ from ....utils import check_shared_mem, is_intel_alchemist, use_cuda_graph
12
+
13
+ # https://github.com/intel/intel-xpu-backend-for-triton/issues/3449
14
+ triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {}
15
+
16
+
17
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config(triton_config, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'BK', 'BV'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_bwd_kernel(
    A_ab_inv,
    A_ak,
    ag,
    v,
    dw,
    du,
    dv,
    dv0,
    dag,
    dAak,
    dAab,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Backward of the WY-representation forward: given dw and du, produce
    # dv, dag, and the score-matrix gradients dAak / dAab.
    # One program per (chunk i_t, batch*head i_bh); matrices are loaded
    # transposed (suffix _t) via the (1, H*BT) stride trick.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    p_Aak_t = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
    p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
    p_dAak = tl.make_block_ptr(dAak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_dAab = tl.make_block_ptr(dAab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    # Re-apply the triangular masks: A_ak is strictly lower, A_ab_inv is lower
    # incl. diagonal; their transposes are (strictly) upper here.
    b_A_ab_inv_t = tl.load(p_Aab_inv_t, boundary_check=(0, 1))
    b_A_ak_t = tl.load(p_Aak_t, boundary_check=(0, 1))
    b_A_ak_t = tl.where(tl.arange(0, BT)[:, None] < tl.arange(0, BT)[None, :], b_A_ak_t, 0)
    b_A_ab_inv_t = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A_ab_inv_t, 0)
    b_A_tmp_t = tl.dot(b_A_ak_t, b_A_ab_inv_t).to(v.dtype.element_ty)
    b_dA_tmp = tl.zeros([BT, BT], dtype=tl.float32)

    # Accumulate du @ v^T over value tiles, and write dv = dv0 + (Aak Aab_inv)^T du.
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv0 = tl.make_block_ptr(dv0 + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_du = tl.load(p_du, boundary_check=(0, 1))
        b_dA_tmp += tl.dot(b_du.to(b_v.dtype), tl.trans(b_v))
        b_dv0 = tl.load(p_dv0, boundary_check=(0, 1))
        b_dv = b_dv0 + tl.dot(b_A_tmp_t, b_du)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

    # Strictly-lower mask for all score-matrix gradients.
    m_i = tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :]
    b_dA_tmp = tl.where(m_i, b_dA_tmp, 0)
    b_dA_ak = tl.dot(b_A_ab_inv_t, b_dA_tmp)
    b_dA_ak = tl.where(m_i, b_dA_ak, 0)
    tl.store(p_dAak, b_dA_ak, boundary_check=(0, 1))
    b_dA_ab_inv = tl.dot(b_dA_tmp, b_A_ak_t)

    # Accumulate dw @ ag^T into dA_ab_inv and write dag = Aab_inv^T dw.
    for i_k in range(tl.cdiv(K, BK)):
        p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dag = tl.make_block_ptr(dag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_ag = tl.load(p_ag, boundary_check=(0, 1))
        b_dw = tl.load(p_dw, boundary_check=(0, 1))
        b_dA_ab_inv += tl.dot(b_dw, tl.trans(b_ag))
        b_dag = tl.dot(b_A_ab_inv_t.to(b_dw.dtype), b_dw)
        tl.store(p_dag, b_dag.to(p_dag.dtype.element_ty), boundary_check=(0, 1))

    # if we know dL/dA^(-1), for dL/dA, we can use the following formula:
    # dL/dA = -(A^(-1))^T @ (dL/dA^(-1)) @ (A^(-1))^T
    # in the fwd pass we use fwd substitution to calculate (I-lower(A_ab))^-1.
    # denote A = I - lower(A_ab), B = A^-1
    # in the backward pass.
    # dL/dA = -(B)^T @ (dL/dB) @ B^T
    # dL/dA_ab = lower(B^T @ dL/dB @ B^T)
    b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
    b_dA_ab_inv = tl.dot(b_A_ab_inv_t, b_dA_ab_inv)
    b_dA_ab_inv = tl.dot(b_dA_ab_inv, b_A_ab_inv_t)
    b_dA_ab_inv = tl.where(m_i, b_dA_ab_inv, 0)
    tl.store(p_dAab, b_dA_ab_inv, boundary_check=(0, 1))
115
+
116
+
117
def chunk_dplr_bwd_wy(
    A_ab_inv: torch.Tensor,
    A_ak: torch.Tensor,
    v: torch.Tensor,
    ag: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
    dv0: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Backward of the WY representation: launch `prepare_wy_repr_bwd_kernel`.

    Given upstream gradients `dw`, `du` and the partial value gradient `dv0`,
    computes the score-matrix gradients `dA_ab` / `dA_ak` (fp32) together with
    `dv` and `dag`.

    Returns:
        `(dA_ab, dA_ak, dv, dag)`.
    """
    # The kernel assumes densely packed inputs.
    A_ab_inv = A_ab_inv.contiguous()
    A_ak = A_ak.contiguous()
    v = v.contiguous()
    ag = ag.contiguous()
    dw = dw.contiguous()
    du = du.contiguous()

    B, T, H, K = dw.shape
    V = du.shape[-1]
    # Chunk length: power of two, at least 16, capped by chunk_size.
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    BK = min(triton.next_power_of_2(K), 64)
    # Wider value tiles when enough shared memory is available.
    BV = min(triton.next_power_of_2(V), 64 if check_shared_mem() else 32)

    dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float)
    dA_ak = torch.empty_like(A_ak, dtype=torch.float)
    dv = torch.empty_like(v)
    dag = torch.empty_like(ag)

    prepare_wy_repr_bwd_kernel[(NT, B * H)](
        A_ab_inv=A_ab_inv,
        A_ak=A_ak,
        ag=ag,
        v=v,
        dw=dw,
        du=du,
        dv=dv,
        dv0=dv0,
        dag=dag,
        dAak=dA_ak,
        dAab=dA_ab,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return dA_ab, dA_ak, dv, dag
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/wy_fast_fwd.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices
11
+ from ....ops.utils.op import gather
12
+ from ....utils import is_gather_supported, use_cuda_graph
13
+
14
+
15
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk32(
    A_ab,
    A_ab_inv,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,  # placeholder, do not delete
    IS_VARLEN: tl.constexpr,
):
    # Inverts (I - strict-lower(A_ab)) per chunk by forward substitution and
    # stores the result in A_ab_inv. Variant for chunk size <= 32 (single tile).
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_Aab = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_Aab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_A_ab = tl.load(p_Aab, boundary_check=(0, 1))
    # Keep only the strictly-lower part of A_ab.
    b_A_ab = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A_ab, 0)
    # Forward substitution: row i absorbs the already-processed rows < i.
    for i in range(1, BT):
        mask = tl.arange(0, BT) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A_ab, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A_ab, 0) * (tl.arange(0, BT) < i)
        b_A_ab = tl.where(mask[:, None], b_a, b_A_ab)
    # Add the identity so the stored matrix is the full inverse.
    b_A_ab += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
    tl.store(p_Aab_inv, b_A_ab.to(p_Aab_inv.dtype.element_ty), boundary_check=(0, 1))
57
+
58
+
59
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BC'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk64(
    A_ab,
    A_ab_inv,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    GATHER_SUPPORTED: tl.constexpr = is_gather_supported
):
    # Same job as the chunk32 variant, but for BT = 2*BC: the chunk is split
    # into a 2x2 block-triangular layout [A11, 0; A21, A22], each diagonal
    # block inverted by forward substitution and the off-diagonal block
    # completed via A22^-1 @ A21 @ A11^-1 (sign handled by the masking
    # convention of the caller).
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # A1 = top-left (A11), A2 = bottom-right (A22), A3 = bottom-left (A21).
    p_A1 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
    p_A2 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
    p_A3 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
    p_A_inv1 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
    p_A_inv2 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
    p_A_inv3 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
    p_A_inv4 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))

    b_A = tl.load(p_A1, boundary_check=(0, 1))
    b_A2 = tl.load(p_A2, boundary_check=(0, 1))
    b_A3 = tl.load(p_A3, boundary_check=(0, 1))
    # Keep only the strictly-lower parts of the diagonal blocks.
    b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
    b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)

    # Forward substitution on both diagonal blocks simultaneously; `gather`
    # extracts row i without a full masked reduction where supported.
    for i in range(1, BC):
        if GATHER_SUPPORTED:
            row_idx = tl.full([1, BC], i, dtype=tl.int16)
            # [1, BK] -> [BK]
            b_a = tl.sum(gather(b_A, row_idx, axis=0), 0)
            b_a2 = tl.sum(gather(b_A2, row_idx, axis=0), 0)
        else:
            mask = tl.arange(0, BC) == i
            b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
            b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
        mask = tl.arange(0, BC) == i
        # b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        # b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
        b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
        b_A2 = tl.where(mask[:, None], b_a2, b_A2)

    # blockwise computation of lower triangular matrix's inverse
    # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
    b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A3 = tl.dot(tl.dot(b_A2, b_A3), b_A)
    # tl.debug_barrier()
    tl.store(p_A_inv1, b_A.to(p_A_inv1.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_A_inv2, b_A2.to(p_A_inv2.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_A_inv3, b_A3.to(p_A_inv3.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    # causal mask
    tl.store(p_A_inv4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A_inv4.dtype.element_ty), boundary_check=(0, 1))
136
+
137
+
138
# Compute the W/U auxiliary tensors of the WY representation for one
# (chunk, batch*head) program:
#   w[chunk] = A_ab_inv @ ag                (tiled along K)
#   u[chunk] = (A_ab_inv @ A_ak) @ v        (tiled along V)
# Grid: (num_chunks, B * H).
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def wu_fwd_kernel(
    w,  # out: same layout as ag, [B, T, H, K]
    u,  # out: same layout as v, [B, T, H, V]
    ag,
    v,
    A_ab_inv,  # per-chunk lower-triangular inverse blocks, [B, T, H, BT]
    A_ak,  # per-chunk strictly-lower-triangular blocks, [B, T, H, BT]
    cu_seqlens,  # varlen cumulative sequence lengths (or None)
    chunk_indices,  # flat chunk id -> (sequence id, chunk-within-sequence)
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,  # chunk (time-block) size
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # remap the global chunk id to (sequence, local chunk) and clip T
        # to the length of this sequence
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_s = tl.arange(0, BT)

    p_A_ab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_A_ak = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    b_Aab_inv = tl.load(p_A_ab_inv, boundary_check=(0, 1))
    b_Aak = tl.load(p_A_ak, boundary_check=(0, 1))
    # enforce causal structure: A_ab_inv is lower triangular including the
    # diagonal, A_ak strictly below it
    b_Aab_inv = tl.where(o_s[:, None] >= o_s[None, :], b_Aab_inv, 0)
    b_Aak = tl.where(o_s[:, None] > o_s[None, :], b_Aak, 0)
    # let's use tf32 here
    b_Aak = tl.dot(b_Aab_inv, b_Aak)
    # (SY 01/04) should be bf16 or tf32? To verify.
    b_Aak = b_Aak.to(v.dtype.element_ty, fp_downcast_rounding="rtne")
    b_Aab_inv = b_Aab_inv.to(ag.dtype.element_ty, fp_downcast_rounding="rtne")

    # w = A_ab_inv @ ag, one BK-wide tile of the K dimension at a time
    for i_k in range(tl.cdiv(K, BK)):
        p_ag = tl.make_block_ptr(ag + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_ag = tl.load(p_ag, boundary_check=(0, 1))
        b_w = tl.dot(b_Aab_inv, b_ag)  # both bf16 or fp16
        tl.store(p_w, b_w.to(p_w.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))

    # u = (A_ab_inv @ A_ak) @ v, one BV-wide tile of the V dimension at a time
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_u = tl.dot(b_Aak, b_v)  # both bf16 or fp16
        tl.store(p_u, b_u.to(p_u.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
205
+
206
+
207
def wu_fwd(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab_inv: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch `wu_fwd_kernel` to build the (w, u) tensors of the WY representation.

    Args:
        ag: `[B, T, H, K]` input.
        v: `[B, T, H, V]` values.
        A_ak: per-chunk `[B, T, H, BT]` blocks.
        A_ab_inv: per-chunk `[B, T, H, BT]` inverse blocks.
        cu_seqlens: optional `[N+1]` cumulative lengths for varlen batches.
        chunk_size: requested time-chunk length (clamped to a power of two >= 16).

    Returns:
        `(w, u)` with the same shapes/dtypes as `ag` and `v` respectively.
    """
    B, T, H, K = ag.shape
    V = v.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)

    w = torch.empty_like(ag)
    u = torch.empty_like(v)
    grid = (NT, B * H)
    wu_fwd_kernel[grid](
        w=w,
        u=u,
        ag=ag,
        v=v,
        A_ab_inv=A_ab_inv,
        A_ak=A_ak,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return w, u
243
+
244
+
245
def prepare_wy_repr_fwd(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward pass of the WY representation.

    First inverts the per-chunk lower-triangular `A_ab` blocks on device,
    then derives `(w, u)` from the inverse via `wu_fwd`.

    Returns:
        `(w, u, A_ab_inv)`.
    """
    B, T, H = ag.shape[:3]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    BC = min(BT, 32)
    # pick the inversion kernel matching the chunk length
    if BT == 64:
        fwd_fn = prepare_wy_repr_fwd_kernel_chunk64
    else:
        fwd_fn = prepare_wy_repr_fwd_kernel_chunk32

    A_ab_inv = torch.empty_like(A_ab)
    fwd_fn[(NT, B * H)](
        A_ab=A_ab,
        A_ab_inv=A_ab_inv,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        BC=BC,
    )
    w, u = wu_fwd(
        ag=ag,
        v=v,
        A_ak=A_ak,
        A_ab_inv=A_ab_inv,
        cu_seqlens=cu_seqlens,
        chunk_size=BT,
    )
    return w, u, A_ab_inv
280
+
281
+
282
# Backward-compatibility aliases for the legacy public names.
fwd_prepare_wy_repr = prepare_wy_repr_fwd

fwd_wu = wu_fwd
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .chunk import chunk_iplr_delta_rule
2
+ from .fused_recurrent import fused_recurrent_iplr_delta_rule
3
+
4
+ __all__ = [
5
+ 'chunk_iplr_delta_rule',
6
+ 'fused_recurrent_iplr_delta_rule'
7
+ ]
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/chunk.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import warnings
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+ from einops import rearrange
11
+
12
+ from ....ops.generalized_delta_rule.iplr.wy_fast import prepare_wy_repr_fwd
13
+ from ....ops.utils import prepare_chunk_indices, prepare_chunk_offsets
14
+ from ....utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard, use_cuda_graph
15
+
16
+ BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
17
+
18
+
19
# Recurrently materialize the per-chunk hidden state h and the updated
# values v_new for the generalized IPLR delta rule.
# Per chunk: v_new = d @ h + u, then h += k^T @ v + b^T @ v_new.
# Grid: (NK, NV, N * H); each program owns one [BK, BV] tile of the state.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [2, 4, 8, 16]
    ],
    key=['BT', 'BK', 'BV'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_generalized_iplr_delta_rule_fwd_kernel_h(
    k,  # keys [B, T, H, K]
    v,  # values [B, T, H, V]
    d,  # w tensor of the WY representation [B, T, H, K]
    b,  # b tensor [B, T, H, K]
    u,  # u tensor of the WY representation [B, T, H, V]
    v_new,  # out: updated values [B, T, H, V]
    h,  # out: per-chunk states [B, NT, H, K, V]
    h0,  # optional initial state [N, H, K, V]
    ht,  # optional final state [N, H, K, V]
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,  # sub-chunk size (SRAM-saving inner tiling)
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT):
        # store the state as seen at the *start* of chunk i_t
        p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        b_hc = tl.zeros([BK, BV], dtype=tl.float32)
        # keeping the whole K dim in SRAM is a severe memory burden;
        # sub-chunking with BC alleviates it
        for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
            p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_b = tl.make_block_ptr(b+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
            p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
            # [BK, BC]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_v = tl.load(p_v, boundary_check=(0, 1))
            b_d = tl.load(p_d, boundary_check=(0, 1))
            b_b = tl.load(p_b, boundary_check=(0, 1))
            # v_new = d @ h + u; the state delta accumulates in b_hc so the
            # sub-chunks within this chunk all read the same (stale) b_h
            b_v2 = tl.dot(b_d, b_h.to(b_d.dtype)) + tl.load(p_u, boundary_check=(0, 1))
            b_hc += tl.dot(b_k, b_v)
            b_hc += tl.dot(b_b, b_v2.to(b_k.dtype))
            tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
        b_h += b_hc

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
101
+
102
+
103
# Compute the output o for one (V-tile, chunk, batch*head) program:
#   o = (q @ h + causal(q @ k^T) @ v + causal(q @ b^T) @ u) * scale
# where h is the state at the start of the chunk.
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BKV_LIST
        for BV in BKV_LIST
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_generalized_iplr_delta_rule_fwd_kernel_o(
    q,  # queries [B, T, H, K]
    k,  # keys [B, T, H, K]
    v,  # values [B, T, H, V]
    u,  # updated values (v_new) [B, T, H, V]
    b,  # b tensor [B, T, H, K]
    h,  # per-chunk states [B, NT, H, K, V]
    o,  # out: outputs [B, T, H, V]
    cu_seqlens,
    chunk_indices,
    scale,  # query scaling factor
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    b += (bos * H + i_h) * K
    v += (bos * H + i_h) * V
    u += (bos * H + i_h) * V
    o += (bos * H + i_h) * V
    h += (i_tg * H + i_h) * K * V
    stride_qk = H*K
    stride_vo = H*V

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_Aqk = tl.zeros([BT, BT], dtype=tl.float32)
    b_Aqb = tl.zeros([BT, BT], dtype=tl.float32)

    # accumulate the inter-chunk term (q @ h) and the intra-chunk score
    # matrices, one BK-wide slice of K at a time
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_b = tl.make_block_ptr(b, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_b = tl.load(p_b, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BK] @ [BK, BV] -> [BT, BV]
        b_o += tl.dot(b_q, b_h)
        # [BT, BK] @ [BK, BT] -> [BT, BT]
        b_Aqk += tl.dot(b_q, b_k)
        # [BT, BK] @ [BK, BT] -> [BT, BT]
        b_Aqb += tl.dot(b_q, b_b)

    # causal mask within the chunk (inclusive of the diagonal)
    o_i = tl.arange(0, BT)
    m_A = o_i[:, None] >= o_i[None, :]
    b_Aqk = tl.where(m_A, b_Aqk, 0)
    b_Aqb = tl.where(m_A, b_Aqb, 0)

    p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_u = tl.make_block_ptr(u, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_u = tl.load(p_u, boundary_check=(0, 1))
    b_o = (b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_u.dtype), b_u)) * scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
198
+
199
+
200
def chunk_generalized_iplr_delta_rule_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    b: torch.Tensor,
    h: torch.Tensor,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> torch.Tensor:
    """Launch the output kernel and return `o` (same shape/dtype as `v`).

    `scale` defaults to `K ** -0.5` when not given.
    """
    B, T, H, K = q.shape
    V = v.shape[-1]
    if scale is None:
        scale = k.shape[-1] ** -0.5
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)

    o = torch.empty_like(v)

    # BV is chosen by the autotuner, so the V-axis grid size depends on meta
    def grid(meta):
        return (triton.cdiv(V, meta['BV']), NT, B * H)

    chunk_generalized_iplr_delta_rule_fwd_kernel_o[grid](
        q=q,
        k=k,
        v=v,
        u=v_new,
        b=b,
        h=h,
        o=o,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return o
244
+
245
+
246
def chunk_generalized_iplr_delta_rule_fwd_h(
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    b: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Materialize per-chunk hidden states for the generalized IPLR delta rule.

    Args:
        k, v: keys `[B, T, H, K]` and values `[B, T, H, V]`.
        w, u: WY-representation tensors from `prepare_wy_repr_fwd`.
        b: `[B, T, H, K]` tensor of the IPLR parameterization.
        initial_state: optional `[N, H, K, V]` initial state.
        output_final_state: whether to also return the final state.
        cu_seqlens: optional `[N+1]` cumulative lengths for varlen batches.
        chunk_size: requested time-chunk length.

    Returns:
        `(h, v_new, final_state)` where `h` is `[B, NT, H, K, V]`,
        `v_new` matches `u`, and `final_state` is `[N, H, K, V]` in fp32
        (or `None` when `output_final_state=False`).

    Note:
        The original return annotation declared two tensors although the
        function has always returned three values; fixed to match callers.
    """
    B, T, H, K, V = *k.shape, u.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)

    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension larger than 256."
    # H100 can have larger block size

    if check_shared_mem('hopper', k.device.index):
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem('ampere', k.device.index):  # A100
        BV = 32
        BC = 32
    else:
        BV = 16
        BC = 16

    BC = min(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)

    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    h = k.new_empty(B, NT, H, K, V)
    final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None

    v_new = torch.empty_like(u)
    grid = (NK, NV, N * H)

    chunk_generalized_iplr_delta_rule_fwd_kernel_h[grid](
        k=k,
        v=v,
        d=w,
        b=b,
        u=u,
        v_new=v_new,
        h=h,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
    )
    return h, v_new, final_state
315
+
316
+
317
def chunk_generalized_iplr_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Full forward pipeline: WY representation -> chunk states -> outputs.

    Returns `(o, final_state)`.
    """
    BT = min(chunk_size, max(triton.next_power_of_2(q.shape[1]), 16))

    # stage 1: derive the WY representation (w, u) from (a, b, k, v)
    w, u, _ = prepare_wy_repr_fwd(
        a=a,
        b=b,
        k=k,
        v=v,
        cu_seqlens=cu_seqlens,
        chunk_size=BT,
    )

    # stage 2: scan chunks to build per-chunk states and updated values
    h, v_new, final_state = chunk_generalized_iplr_delta_rule_fwd_h(
        k=k,
        v=v,
        b=b,
        w=w,
        u=u,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        chunk_size=BT,
    )

    # stage 3: combine inter- and intra-chunk contributions into the output
    o = chunk_generalized_iplr_delta_rule_fwd_o(
        q=q,
        k=k,
        v=v,
        v_new=v_new,
        b=b,
        h=h,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=BT,
    )
    return o, final_state
363
+
364
+
365
class ChunkGeneralizedIPLRDeltaRuleFunction(torch.autograd.Function):
    """Autograd entry point for the chunked generalized IPLR delta rule.

    Only the forward pass is implemented; invoking backward raises
    ``NotImplementedError``.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: Optional[torch.LongTensor] = None,
    ):
        # fixed chunk length shared by all forward kernels
        chunk_size = 64

        o, final_state = chunk_generalized_iplr_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            chunk_size=chunk_size
        )
        # cast back to the input dtype (kernels may accumulate in fp32)
        return o.to(q.dtype), final_state

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(
        ctx,
        do: torch.Tensor,
        dht: torch.Tensor
    ):
        # forward-only op: gradients are intentionally unimplemented
        raise NotImplementedError(
            "Backward pass for ChunkGeneralizedIPLRDeltaRuleFunction is not implemented yet. "
            "Stay tuned!"
        )
410
+
411
+
412
@torch.compiler.disable
def chunk_iplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
):
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        a (torch.Tensor):
            activations of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        b (torch.Tensor):
            betas of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        scale (Optional[int]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
    """
    assert q.dtype == k.dtype == v.dtype
    assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."

    if head_first:
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
        # NOTE(review): unreachable after the raise above; kept for parity
        # with the upstream source.
        q, k, v, a, b = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, a, b))
    if not head_first and q.shape[1] < q.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
        )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            # fixed: the second half of this message was garbled
            # ("...tten") in the original; it should read "flatten".
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    scale = k.shape[-1] ** -0.5 if scale is None else scale
    o, final_state = ChunkGeneralizedIPLRDeltaRuleFunction.apply(
        q,
        k,
        v,
        a,
        b,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
    )
    if head_first:
        o = rearrange(o, 'b t h ... -> b h t ...')
    return o, final_state
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/fused_recurrent.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....utils import input_guard
11
+
12
+
13
# Token-by-token recurrent forward of the IPLR delta rule.
# Per step: ha = h @ a; h += ha b^T + v k^T; o = (h @ q) * scale.
# Each program owns the full K dimension and one BV-wide slice of V.
# Grid: (NV, N * H).
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BV in [32, 64]
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=["BK"],
)
@triton.jit
def fused_recurrent_fwd_kernel(
    q,  # query [B, H, L, K]
    k,  # key [B, H, L, V]
    v,  # value [B, H, L, V].
    a,  # a [B, H, L, K]
    b,  # b [B, H, L, K]
    o,  # output [B, H, L, V]
    ha,  # tmp variable [B, H, L, V] for storing intermediate results of (h * a[None, :]).sum(0)
    h0,  # initial hidden state [B, H, K, V]
    ht,  # final hidden state [B, H, K, V]
    cu_seqlens,  # varlen cu_seqlens
    scale,  # K ** -0.5
    H,  # n_heads
    T,  # seq_len
    K: tl.constexpr,  # K
    V: tl.constexpr,  # V
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
    IS_VARLEN: tl.constexpr,
):
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H

    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T

    # element pointers for the first timestep; advanced by H*K / H*V per step
    p_q = q + (bos * H + i_h) * K + tl.arange(0, BK)
    p_k = k + (bos * H + i_h) * K + tl.arange(0, BK)
    p_a = a + (bos * H + i_h) * K + tl.arange(0, BK)
    p_b = b + (bos * H + i_h) * K + tl.arange(0, BK)
    p_ha = ha + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
    p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
    p_o = o + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)

    mask_k = tl.arange(0, BK) < K
    mask_v = (i_v * BV + tl.arange(0, BV)) < V
    mask_h = mask_k[None, :] & mask_v[:, None]

    # running state tile held V-major, i.e. transposed relative to the
    # [K, V] layout of h0/ht
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        # ha = h @ a, stored out so the backward pass can reuse it
        tmp = tl.sum(b_h * b_a[None, :], axis=1)
        b_h += (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
        b_o = b_h * b_q[None, :]
        b_o = tl.sum(b_o, axis=1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
        tl.store(p_ha, tmp.to(p_ha.dtype.element_ty), mask=mask_v)
        p_q += K*H
        p_k += K*H
        p_o += V*H
        p_v += V*H
        p_ha += V*H
        p_a += K*H
        p_b += K*H

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
101
+
102
+
103
+ @triton.heuristics({
104
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
105
+ 'USE_DHT': lambda args: args['dht'] is not None,
106
+ 'USE_DH0': lambda args: args['dh0'] is not None,
107
+ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
108
+ })
109
+ @triton.autotune(
110
+ configs=[
111
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
112
+ for num_warps in [2, 4, 8, 16]
113
+ for num_stages in [2, 3]
114
+ ],
115
+ key=["BK", "BV"],
116
+ )
117
+ @triton.jit
118
+ def fused_recurrent_bwd_kernel(
119
+ # B: batch_size, H: n_heads, T: seq_len, D: b_dhead
120
+ # NV: number of split in the V dimension. NK: number of split in the K dimension
121
+ q, # query [B, H, L, K]
122
+ k, # key [B, H, L, V]
123
+ v, # value [B, H, L, V]
124
+ a, # a [B, H, L, K]
125
+ b, # b [B, H, L, K]
126
+ ha, # ha [B, H, L, V]
127
+ dht, # gradient of final state [B, H, K, V]
128
+ dh0, # gradient of initial state [B, H, K, V]
129
+ do, # gradient of output [B, H, L, V]
130
+ dq, # gradient of query [NV, B, H, L, K]
131
+ dk, # gradient of key [NV, B, H, L, K]
132
+ dv, # gradient of value [NK, B, H, L, V]
133
+ da, # gradient of a [NV, B, H, L, K]
134
+ db, # gradient of b [NV, B, H, L, K]
135
+ dha, # gradient of ha [NK, B, H, L, V]
136
+ h0, # initial state [B, H, K, V]
137
+ scale, # K ** -0.5
138
+ cu_seqlens, # cu_seqlens
139
+ B, # batch_size
140
+ H, # n_heads
141
+ T, # seq_len
142
+ K: tl.constexpr, # K
143
+ V: tl.constexpr, # V
144
+ BK: tl.constexpr, # BLOCK SIZE along the K dimension
145
+ BV: tl.constexpr, # BLOCK SIZE along the V dimension
146
+ USE_INITIAL_STATE: tl.constexpr, # whether to use initial state h0
147
+ USE_DH0: tl.constexpr, # whether to use dh0
148
+ USE_DHT: tl.constexpr, # whether to use dht
149
+ IS_VARLEN: tl.constexpr,
150
+ ):
151
+ i_v, i_nh = tl.program_id(0), tl.program_id(1)
152
+ i_n, i_h = i_nh // H, i_nh % H
153
+ dk += i_v * B * H * K * T
154
+ db += i_v * B * H * K * T
155
+ dq += i_v * B * H * K * T
156
+ da += i_v * B * H * K * T
157
+ if IS_VARLEN:
158
+ bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
159
+ T = eos - bos
160
+ else:
161
+ bos, eos = i_n * T, i_n * T + T
162
+ mask_k = tl.arange(0, BK) < K
163
+ mask_v = (tl.arange(0, BV) + i_v * BV) < V
164
+
165
+ q += (bos * H + i_h) * K
166
+ k += (bos * H + i_h) * K
167
+ v += (bos * H + i_h) * V + i_v * BV
168
+ ha += (bos * H + i_h) * V + i_v * BV
169
+ a += (bos * H + i_h) * K
170
+ b += (bos * H + i_h) * K
171
+ do += (bos * H + i_h) * V + i_v * BV
172
+ dq += (bos * H + i_h) * K
173
+ dk += (bos * H + i_h) * K
174
+ dv += (bos * H + i_h) * V + i_v * BV
175
+ da += (bos * H + i_h) * K
176
+ db += (bos * H + i_h) * K
177
+ dha += (bos * H + i_h) * V + i_v * BV
178
+
179
+ p_q = q + tl.arange(0, BK) + (T - 1) * H*K
180
+ p_k = k + tl.arange(0, BK) + (T - 1) * H*K
181
+ p_v = v + tl.arange(0, BV) + (T - 1) * H*V
182
+ p_ha = ha + tl.arange(0, BV) + (T - 1) * H*V
183
+ p_a = a + tl.arange(0, BK) + (T - 1) * H*K
184
+ p_b = b + tl.arange(0, BK) + (T - 1) * H*K
185
+ p_do = do + tl.arange(0, BV) + (T - 1) * H*V
186
+ p_dk = dk + tl.arange(0, BK) + (T - 1) * H*K
187
+ p_dv = dv + tl.arange(0, BV) + (T - 1) * H*V
188
+ p_dha = dha + tl.arange(0, BV) + (T - 1) * H*V
189
+ p_db = db + tl.arange(0, BK) + (T - 1) * H*K
190
+ p_da = da + tl.arange(0, BK) + (T - 1) * H*K
191
+ p_dq = dq + tl.arange(0, BK) + (T - 1) * H*K
192
+
193
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
194
+ if USE_DHT:
195
+ p_ht = dht + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
196
+ b_dh += tl.load(p_ht, mask=mask_k[:, None] & mask_v[None, :], other=0).to(tl.float32)
197
+
198
+ for _ in range(T):
199
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
200
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
201
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
202
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
203
+ b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
204
+ b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
205
+ b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32)
206
+
207
+ b_dh += b_q[:, None] * b_do[None, :]
208
+ d_k = tl.sum(b_dh * b_v[None, :], axis=1)
209
+ d_v = tl.sum(b_dh * b_k[:, None], axis=0)
210
+ tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_k)
211
+ tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_v)
212
+
213
+ b_dha = tl.sum(b_dh * b_b[:, None], axis=0)
214
+ tl.store(p_dha, b_dha.to(p_dha.dtype.element_ty), mask=mask_v)
215
+ b_db = tl.sum(b_dh * b_ha[None, :], axis=1)
216
+ tl.store(p_db, b_db.to(p_db.dtype.element_ty), mask=mask_k)
217
+
218
+ b_dh += b_dha[None, :] * b_a[:, None]
219
+ p_do -= H*V
220
+ p_q -= H*K
221
+ p_k -= H*K
222
+ p_v -= H*V
223
+ p_dk -= H*K
224
+ p_dv -= H*V
225
+ p_b -= H*K
226
+ p_db -= H*K
227
+ p_a -= H*K
228
+ p_dha -= H*V
229
+ p_ha -= H*V
230
+
231
+ if USE_DH0:
232
+ p_dh0 = dh0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
233
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :])
234
+
235
+ tl.debug_barrier()
236
+
237
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
238
+
239
+ if USE_INITIAL_STATE:
240
+ mask_kv = mask_k[:, None] & mask_v[None, :]
241
+ p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
242
+ b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
243
+
244
+ p_k = k + tl.arange(0, BK)
245
+ p_v = v + tl.arange(0, BV)
246
+ p_ha = ha + tl.arange(0, BV)
247
+ p_do = do + tl.arange(0, BV)
248
+ p_dha = dha + tl.arange(0, BV)
249
+ p_da = da + tl.arange(0, BK)
250
+ p_dq = dq + tl.arange(0, BK)
251
+ p_b = b + tl.arange(0, BK)
252
+
253
+ for i in range(0, T):
254
+ b_dha = tl.load(p_dha, mask=mask_v, other=0).to(tl.float32)
255
+ d_a = tl.sum(b_dha[None, :] * b_h, axis=1)
256
+ tl.store(p_da, d_a.to(p_da.dtype.element_ty), mask=mask_k)
257
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
258
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
259
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
260
+ b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
261
+ b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32)
262
+ b_h += b_k[:, None] * b_v[None, :] + b_b[:, None] * b_ha[None, :]
263
+ _d_q = b_h * b_do[None, :]
264
+ d_q = tl.sum(_d_q, axis=1) * scale
265
+ tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_k)
266
+
267
+ p_k += H*K
268
+ p_do += H*V
269
+ p_v += H*V
270
+ p_da += H*K
271
+ p_dha += H*V
272
+ p_ha += H*V
273
+ p_dq += H*K
274
+ p_b += H*K
275
+
276
+
277
class FusedRecurrentIPLRDeltaRuleFunction(torch.autograd.Function):
    """Autograd bridge for the fused recurrent IPLR (identity-plus-low-rank) delta rule.

    Forward evaluates the recurrence S_t = S_{t-1} @ (I + a_t b_t^T) + v_t k_t^T
    with a Triton kernel and returns the per-step outputs `o` plus (optionally)
    the final state.  Backward replays the recurrence with a second kernel.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        scale: Optional[float] = None,
        initial_state: Optional[torch.Tensor] = None,
        output_final_state: bool = False,
        cu_seqlens: Optional[torch.LongTensor] = None
    ):
        # q/k/a/b share the [B, T, H, K] layout; v contributes the value dim V.
        B, T, H, K, V = *k.shape, v.shape[-1]
        # N: number of sequences — the batch size in fixed-length mode, or the
        # number of variable-length segments when `cu_seqlens` is given.
        N = B if cu_seqlens is None else len(cu_seqlens) - 1

        # The whole key dimension is handled in a single tile.
        BK = triton.next_power_of_2(K)
        if output_final_state:
            # NOTE(review): allocated with leading dim B even in varlen mode,
            # where the kernel is launched over N segments — confirm intended.
            final_state = q.new_empty(B, H, K, V, dtype=torch.float32)
        else:
            final_state = None

        # Per-step auxiliary buffer written by the forward kernel and re-consumed
        # by backward (saved below); presumably holds the S·a_t projections —
        # verify against the forward kernel.
        ha = torch.empty_like(v, dtype=torch.float32)

        # Grid: one program per (value tile, sequence*head); BV is autotuned.
        def grid(meta): return (
            triton.cdiv(V, meta['BV']),
            N * H
        )
        o = torch.empty_like(v)
        fused_recurrent_fwd_kernel[grid](
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            o=o,
            ha=ha,
            h0=initial_state,
            ht=final_state,
            scale=scale,
            cu_seqlens=cu_seqlens,
            H=H,
            T=T,
            K=K,
            V=V,
            BK=BK,
        )
        ctx.save_for_backward(q, k, v, a, b, ha, initial_state)
        ctx.scale = scale
        ctx.cu_seqlens = cu_seqlens
        return o, final_state

    @staticmethod
    @input_guard
    def backward(ctx, do, dht):
        q, k, v, a, b, ha, initial_state = ctx.saved_tensors
        B, T, H, K, V = *q.shape, v.shape[-1]
        N = B if ctx.cu_seqlens is None else len(ctx.cu_seqlens) - 1
        # Full K in one tile; V is split into tiles of at most 64 columns.
        BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 64)
        NV = triton.cdiv(V, BV)
        scale = ctx.scale

        # K-sided gradients are computed per V-tile; the extra leading NV dim
        # holds partial results that are reduced by `.sum(0)` after the launch.
        dq = q.new_empty(NV, *q.shape)
        dk = k.new_empty(NV, *k.shape)
        da = a.new_empty(NV, *a.shape)
        db = b.new_empty(NV, *b.shape)
        dv = torch.empty_like(v)
        dha = torch.empty_like(ha)
        grid = (NV, N * H)

        if initial_state is not None and initial_state.requires_grad:
            dh0 = torch.empty_like(initial_state, dtype=torch.float32)
        else:
            dh0 = None

        fused_recurrent_bwd_kernel[grid](
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            ha=ha,
            dht=dht,
            dh0=dh0,
            do=do,
            dq=dq,
            dk=dk,
            dv=dv,
            da=da,
            db=db,
            dha=dha,
            h0=initial_state,
            scale=scale,
            cu_seqlens=ctx.cu_seqlens,
            B=B,
            H=H,
            T=T,
            K=K,
            V=V,
            BK=BK,
            BV=BV,
        )
        # Reduce the per-V-tile partial gradients.
        dq = dq.sum(0)
        dk = dk.sum(0)
        da = da.sum(0)
        db = db.sum(0)
        # One gradient per forward input:
        # (q, k, v, a, b, scale, initial_state, output_final_state, cu_seqlens).
        return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), None, dh0, None, None
387
+
388
+
389
def fused_recurrent_iplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    This function computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`
        v (torch.Tensor):
            values of shape `[B, T, H, V]`
        a (torch.Tensor):
            as of shape `[B, T, H, K]`
        b (torch.Tensor):
            bs of shape `[B, T, H, K]`
        scale (Optional[float]):
            Scale factor applied to the queries.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[B, H, K, V]`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`.
    """
    # Variable-length mode requires pre-flattened inputs (batch dim folded into T).
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        # Default to the usual 1/sqrt(d_k) attention scaling.
        scale = q.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    o, final_state = FusedRecurrentIPLRDeltaRuleFunction.apply(
        q,
        k,
        v,
        a,
        b,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens
    )
    return o, final_state
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/naive.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+
7
+ # S_t = S_t @ (I + alpha_t beta_t^T) + v_t k_t^T
8
+ # q, k, alpha, beta [B, H, L, D_K]
9
+ # v [B, H, L, D_V]
10
def iplr_recurrence(q, k, v, alpha, beta, initial_state=None, output_final_state=True):
    """Naive step-by-step reference of the IPLR delta rule (for testing).

    Implements S_t = S_{t-1} + beta_t (alpha_t^T S_{t-1}) + k_t v_t^T and
    o_t = (q_t / sqrt(d_k)) @ S_t, i.e. the recurrence
    S_t = S_{t-1} @ (I + alpha_t beta_t^T) + v_t k_t^T written row-wise.

    Args:
        q, k, alpha, beta: tensors of shape `[B, H, L, D_K]`.
        v: tensor of shape `[B, H, L, D_V]`.
        initial_state: optional `[B, H, D_K, D_V]` state added to S_0.
        output_final_state: if False, the returned state is `None`.

    Returns:
        (o, S): outputs `[B, H, L, D_V]` in the input dtype, and the final
        state `[B, H, D_K, D_V]` in float32 (or `None`).
    """
    orig_dtype = q.dtype
    b, h, l, d_k = q.shape
    # Run the whole recurrence in float32 for numerical stability.
    # Bug fix: `alpha` was previously left in its input dtype while every other
    # operand was upcast, yielding mixed-precision accumulation for half inputs.
    q, k, v, alpha, beta = map(lambda x: x.float(), [q, k, v, alpha, beta])
    d_v = v.shape[-1]
    o = torch.zeros_like(v)
    S = torch.zeros(b, h, d_k, d_v).to(v)
    q = q * (d_k ** -0.5)

    if initial_state is not None:
        S += initial_state

    for i in range(l):
        _k = k[:, :, i]
        _q = q[:, :, i]
        _v = v[:, :, i]
        _alpha = alpha[:, :, i]
        _beta = beta[:, :, i]
        # Rank-1 update: k v^T plus the low-rank correction beta (alpha^T S).
        _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None]
        S = S + _kv
        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
    S = None if output_final_state is False else S
    return o.to(orig_dtype), S
33
+
34
+
35
def iplr_chunkwise(q, k, v, alpha, beta, initial_state=None, output_final_state=True, chunk_size=32):
    """Chunk-parallel reference of the IPLR delta rule (for testing).

    Splits the sequence into chunks of `chunk_size`, builds the WY-style
    transform per chunk, then carries the state `S` across chunks.  Matches
    `iplr_recurrence` on identical inputs.
    """
    num_batch, num_head, seq_len, d_k = q.shape
    d_v = v.shape[-1]
    q = q * (d_k ** -0.5)
    assert seq_len % chunk_size == 0
    num_chunks = seq_len // chunk_size

    state = k.new_zeros(num_batch, num_head, d_k, d_v)
    if initial_state is not None:
        state = state + initial_state

    # Mask covering the diagonal and everything above it.
    incl_mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
    # Reshape [B, H, L, D] -> [B, H, N, C, D] with C = chunk_size.
    q, k, v, alpha, beta = (
        x.reshape(num_batch, num_head, num_chunks, chunk_size, x.shape[-1])
        for x in (q, k, v, alpha, beta)
    )

    # Strictly-causal intra-chunk projections (diagonal masked out).
    v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(incl_mask, 0) @ v
    attn = (alpha @ beta.transpose(-1, -2)).masked_fill(incl_mask, 0)
    # Forward substitution: row i absorbs all rows above it.
    for row in range(1, chunk_size):
        attn[..., row, :row] = attn[..., row, :row] + (
            attn[..., row, :, None].clone() * attn[..., :, :row].clone()
        ).sum(-2)

    attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
    u = attn @ v2
    w = attn @ alpha
    o = torch.zeros_like(v)
    # Mask strictly above the diagonal (diagonal kept for intra-chunk terms).
    strict_mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
    for idx in range(num_chunks):
        q_c, k_c, v_c = q[:, :, idx], k[:, :, idx], v[:, :, idx]
        u_c, w_c, beta_c = u[:, :, idx], w[:, :, idx], beta[:, :, idx]
        intra = (q_c @ k_c.transpose(-1, -2)).masked_fill_(strict_mask, 0) @ v_c
        v2_c = u_c + w_c @ state
        cross = (q_c @ beta_c.transpose(-1, -2)).masked_fill_(strict_mask, 0) @ v2_c
        inter = q_c @ state
        o[:, :, idx] = intra + cross + inter
        state = state + k_c.transpose(-1, -2) @ v_c + beta_c.transpose(-1, -2) @ v2_c
    state = None if output_final_state is False else state
    return o.reshape(num_batch, num_head, seq_len, d_v), state
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/wy_fast.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
4
+
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from ....ops.utils import prepare_chunk_indices
12
+ from ....utils import check_shared_mem, is_nvidia_hopper
13
+
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
15
+
16
+
17
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
    ],
    key=['BK']
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk32(
    a,  # [B, T, H, K] first low-rank factor
    b,  # [B, T, H, K] second low-rank factor
    A,  # [B, T, H, BT] output: per-chunk BTxBT transform matrix
    cu_seqlens,  # optional [N+1] cumulative sequence lengths (varlen mode)
    chunk_indices,  # optional flat (seq idx, chunk idx) pairs for varlen
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,  # chunk length handled by this kernel (<= 32)
    BK: tl.constexpr,  # tile width along the key dimension
    BC: tl.constexpr,  # dummy placeholder
    IS_VARLEN: tl.constexpr,
):
    # One program per (chunk, batch*head) pair.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Resolve this program's (sequence, chunk) pair and its time span.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Accumulate a @ b^T for this chunk, tiled along K.
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        b_a = tl.load(p_a, boundary_check=(0, 1))
        b_b = tl.load(p_b, boundary_check=(0, 1))
        b_A += tl.dot(b_a, b_b)

    # Keep only the strictly lower-triangular part (diagonal excluded).
    b_A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
    # Row-by-row forward substitution: row i absorbs the rows above it, turning
    # b_A into the strictly-lower part of the inverse of a unit lower-triangular
    # factor (cf. the blockwise-inverse comment in the chunk64 variant).
    # NOTE(review): the sign convention of the factor is not visible here —
    # confirm against the chunk64 kernel / callers.
    for i in range(1, BT):
        mask = tl.arange(0, BT) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
    # Add the identity so the transform acts as I on the diagonal.
    b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]

    p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
69
+
70
+
71
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
    ],
    key=['BK']
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk64(
    a,  # [B, T, H, K] first low-rank factor
    b,  # [B, T, H, K] second low-rank factor
    A,  # [B, T, H, BT] output: per-chunk BTxBT transform matrix
    cu_seqlens,  # optional [N+1] cumulative sequence lengths (varlen mode)
    chunk_indices,  # optional flat (seq idx, chunk idx) pairs for varlen
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,  # full chunk length (64); processed as two BC halves
    BK: tl.constexpr,  # tile width along the key dimension
    BC: tl.constexpr,  # half-chunk size (BT // 2)
    IS_VARLEN: tl.constexpr,
):
    # One program per (chunk, batch*head) pair; the 64-wide chunk is processed
    # as a 2x2 grid of BCxBC blocks to keep register pressure down.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # b_A: upper-left diagonal block; b_A2: lower-right diagonal block;
    # b_A3: lower-left off-diagonal block of the chunk's a @ b^T.
    b_A = tl.zeros([BC, BC], dtype=tl.float32)
    b_A2 = tl.zeros([BC, BC], dtype=tl.float32)
    b_A3 = tl.zeros([BC, BC], dtype=tl.float32)

    for i_k in range(tl.cdiv(K, BK)):
        p_a1 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
        p_a2 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
        p_b1 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
        p_b2 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
        b_a1 = tl.load(p_a1, boundary_check=(0, 1))
        b_a2 = tl.load(p_a2, boundary_check=(0, 1))
        b_b1 = tl.load(p_b1, boundary_check=(0, 1))
        b_b2 = tl.load(p_b2, boundary_check=(0, 1))
        b_A += tl.dot(b_a1, b_b1, allow_tf32=False)
        b_A2 += tl.dot(b_a2, b_b2, allow_tf32=False)
        b_A3 += tl.dot(b_a2, b_b1, allow_tf32=False)

    # Strictly lower-triangular masks on the two diagonal blocks.
    b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
    b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)

    # Forward substitution on each diagonal block (same scheme as chunk32).
    for i in range(1, BC):
        mask = tl.arange(0, BC) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
        b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
        b_A2 = tl.where(mask[:, None], b_a2, b_A2)

    # blockwise computation of lower triangular matrix's inverse
    # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
    b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    # Off-diagonal block of the inverse: A22^-1 @ A21 @ A11^-1.
    # NOTE(review): the minus sign of the textbook identity is absent here —
    # presumably folded into the factor's sign convention; confirm.
    b_A3 = tl.dot(tl.dot(b_A2, b_A3, allow_tf32=False), b_A, allow_tf32=False)

    # Store the four BCxBC quadrants of the BTxBT transform.
    p_A1 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
    p_A2 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
    p_A3 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
    p_A4 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
    tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1))
    # causal mask: the upper-right quadrant is always zero
    tl.store(p_A4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1))
149
+
150
+
151
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in NUM_WARPS
    ],
    key=['BT', 'BK', 'BV']
)
@triton.jit(do_not_specialize=['T'])
def wu_fwd_kernel(
    w,  # [B, T, H, K] output: w = A @ a
    u,  # [B, T, H, V] output: u = A @ (strict_tril(a k^T) @ v)
    a,  # [B, T, H, K] low-rank factor
    k,  # [B, T, H, K] keys
    v,  # [B, T, H, V] values
    A,  # [B, T, H, BT] per-chunk transform from prepare_wy_repr_fwd
    cu_seqlens,  # optional [N+1] cumulative sequence lengths (varlen mode)
    chunk_indices,  # optional flat (seq idx, chunk idx) pairs for varlen
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,  # chunk length
    BK: tl.constexpr,  # tile width along K
    BV: tl.constexpr,  # tile width along V
    IS_VARLEN: tl.constexpr,
):
    # One program per (chunk, batch*head) pair.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    b_A = tl.load(p_A, boundary_check=(0, 1))
    # b_Aak accumulates the intra-chunk a @ k^T scores across K tiles.
    b_Aak = tl.zeros([BT, BT], dtype=tl.float32)

    # Pass 1 over K tiles: write w = A @ a and build a @ k^T.
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_a = tl.load(p_a, boundary_check=(0, 1))
        b_w = tl.dot(b_A, b_a)
        b_Aak += tl.dot(b_a, tl.trans(b_k))
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))

    # Keep only the strictly causal (below-diagonal) part of a @ k^T.
    b_Aak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_Aak, 0)
    b_Aak = b_Aak.to(k.dtype.element_ty)

    # Pass 2 over V tiles: u = A @ (b_Aak @ v).
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_v = tl.dot(b_Aak, b_v).to(v.dtype.element_ty)
        b_u = tl.dot(b_A, b_v)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
214
+
215
+
216
def prepare_wy_repr_fwd(
    a: torch.Tensor,
    b: torch.Tensor,
    v: torch.Tensor,
    k: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the per-chunk WY representation for the IPLR delta rule.

    Launches one of the two `prepare_wy_repr_fwd_kernel_chunk*` kernels to
    compute the per-chunk transform `A` from `a` and `b`, then derives
    `w` and `u` via `wu_fwd`.

    Args:
        a: `[B, T, H, K]` first low-rank factor.
        b: `[B, T, H, K]` second low-rank factor.
        v: `[B, T, H, V]` values.
        k: `[B, T, H, K]` keys.
        cu_seqlens: optional `[N+1]` cumulative sequence lengths (varlen mode).
        chunk_size: chunk length; capped below by the sequence length.

    Returns:
        (w, u, A): `w` like `a`, `u` like `v`, and `A` of shape `[B, T, H, BT]`.
    """
    B, T, H, K = a.shape
    # Effective chunk length: no larger than the (padded) sequence, at least 16.
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    # BC is the half-chunk used by the chunk64 kernel (dummy for chunk32).
    BC = min(BT, 32)
    BK = min(triton.next_power_of_2(K), 64)

    A = torch.empty(B, T, H, BT, device=a.device, dtype=a.dtype)
    # The 64-wide chunk needs the blockwise (2x2) kernel variant.
    fwd_fn = prepare_wy_repr_fwd_kernel_chunk64 if BT == 64 else prepare_wy_repr_fwd_kernel_chunk32

    fwd_fn[(NT, B * H)](
        a=a,
        b=b,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BK=BK,
        BC=BC,
    )
    w, u = wu_fwd(
        a=a,
        v=v,
        k=k,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )
    return w, u, A
257
+
258
+
259
def wu_fwd(
    a: torch.Tensor,
    v: torch.Tensor,
    k: torch.Tensor,
    A: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute `w` and `u` from the per-chunk transform `A` (see `wu_fwd_kernel`).

    Args:
        a: `[B, T, H, K]` low-rank factor.
        v: `[B, T, H, V]` values.
        k: `[B, T, H, K]` keys.
        A: `[B, T, H, BT]` per-chunk transform from `prepare_wy_repr_fwd`.
        cu_seqlens: optional `[N+1]` cumulative sequence lengths (varlen mode).
        chunk_size: chunk length used when `A` was built.

    Returns:
        (w, u): `w` with the shape/dtype of `a`, `u` with the shape/dtype of `v`.
    """
    B, T, H, K, V = *a.shape, v.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    # Larger tiles only when the GPU has enough shared memory.
    CONST_TILING = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), CONST_TILING)
    BV = min(triton.next_power_of_2(V), CONST_TILING)

    u = torch.empty_like(v)
    w = torch.empty_like(a)
    wu_fwd_kernel[(NT, B*H)](
        a=a,
        v=v,
        w=w,
        u=u,
        A=A,
        k=k,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return w, u
296
+
297
+
298
+ fwd_prepare_wy_repr = prepare_wy_repr_fwd
299
+
300
+ fwd_wu = wu_fwd
docs/en/.readthedocs.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: 2
2
+
3
+ # Set the version of Python and other tools you might need
4
+ build:
5
+ os: ubuntu-22.04
6
+ tools:
7
+ python: "3.8"
8
+
9
+ formats:
10
+ - epub
11
+
12
+ sphinx:
13
+ configuration: docs/en/conf.py
14
+
15
+ python:
16
+ install:
17
+ - requirements: requirements/docs.txt
docs/en/Makefile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Minimal makefile for Sphinx documentation
2
+ #
3
+
4
+ # You can set these variables from the command line, and also
5
+ # from the environment for the first two.
6
+ SPHINXOPTS ?=
7
+ SPHINXBUILD ?= sphinx-build
8
+ SOURCEDIR = .
9
+ BUILDDIR = _build
10
+
11
+ # Put it first so that "make" without argument is like "make help".
12
+ help:
13
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14
+
15
+ .PHONY: help Makefile
16
+
17
+ # Catch-all target: route all unknown targets to Sphinx using the new
18
+ # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19
+ %: Makefile
20
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
docs/en/_static/css/readthedocs.css ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .header-logo {
2
+ background-image: url("../image/logo.svg");
3
+ background-size: 275px 80px;
4
+ height: 80px;
5
+ width: 275px;
6
+ }
7
+
8
+ @media screen and (min-width: 1100px) {
9
+ .header-logo {
10
+ top: -25px;
11
+ }
12
+ }
13
+
14
+ pre {
15
+ white-space: pre;
16
+ }
17
+
18
+ @media screen and (min-width: 2000px) {
19
+ .pytorch-content-left {
20
+ width: 1200px;
21
+ margin-left: 30px;
22
+ }
23
+ article.pytorch-article {
24
+ max-width: 1200px;
25
+ }
26
+ .pytorch-breadcrumbs-wrapper {
27
+ width: 1200px;
28
+ }
29
+ .pytorch-right-menu.scrolling-fixed {
30
+ position: fixed;
31
+ top: 45px;
32
+ left: 1580px;
33
+ }
34
+ }
35
+
36
+
37
+ article.pytorch-article section code {
38
+ padding: .2em .4em;
39
+ background-color: #f3f4f7;
40
+ border-radius: 5px;
41
+ }
42
+
43
+ /* Disable the change in tables */
44
+ article.pytorch-article section table code {
45
+ padding: unset;
46
+ background-color: unset;
47
+ border-radius: unset;
48
+ }
49
+
50
+ table.autosummary td {
51
+ width: 50%
52
+ }
53
+
54
+ img.align-center {
55
+ display: block;
56
+ margin-left: auto;
57
+ margin-right: auto;
58
+ }
59
+
60
+ article.pytorch-article p.rubric {
61
+ font-weight: bold;
62
+ }
docs/en/_static/image/logo.svg ADDED
docs/en/_static/image/logo_icon.svg ADDED
docs/en/_static/js/custom.js ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ var collapsedSections = ['Dataset Statistics'];
2
+
3
+ $(document).ready(function () {
4
+ $('.dataset').DataTable({
5
+ "stateSave": false,
6
+ "lengthChange": false,
7
+ "pageLength": 20,
8
+ "order": [],
9
+ "language": {
10
+ "info": "Show _START_ to _END_ Items(Totally _TOTAL_ )",
11
+ "infoFiltered": "(Filtered from _MAX_ Items)",
12
+ "search": "Search:",
13
+ "zeroRecords": "Item Not Found",
14
+ "paginate": {
15
+ "next": "Next",
16
+ "previous": "Previous"
17
+ },
18
+ }
19
+ });
20
+ });
docs/en/_templates/404.html ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% extends "layout.html" %}
2
+
3
+ {% block body %}
4
+
5
+ <h1>Page Not Found</h1>
6
+ <p>
7
+ The page you are looking for cannot be found.
8
+ </p>
9
+ <p>
10
+ If you just switched documentation versions, it is likely that the page you were on is moved. You can look for it in
11
+ the content table left, or go to <a href="{{ pathto(root_doc) }}">the homepage</a>.
12
+ </p>
13
+ <!-- <p>
14
+ If you cannot find documentation you want, please <a
15
+ href="">open an issue</a> to tell us!
16
+ </p> -->
17
+
18
+ {% endblock %}
docs/en/_templates/autosummary/class.rst ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. role:: hidden
2
+ :class: hidden-section
3
+ .. currentmodule:: {{ module }}
4
+
5
+
6
+ {{ name | underline}}
7
+
8
+ .. autoclass:: {{ name }}
9
+ :members:
10
+
11
+ ..
12
+ autogenerated from _templates/autosummary/class.rst
13
+ note it does not have :inherited-members:
docs/en/_templates/callable.rst ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. role:: hidden
2
+ :class: hidden-section
3
+ .. currentmodule:: {{ module }}
4
+
5
+
6
+ {{ name | underline}}
7
+
8
+ .. autoclass:: {{ name }}
9
+ :members:
10
+ :special-members: __call__
11
+
12
+ ..
13
+ autogenerated from _templates/callable.rst
14
+ note it does not have :inherited-members:
docs/en/advanced_guides/accelerator_intro.md ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Accelerate Evaluation Inference with vLLM or LMDeploy
2
+
3
+ ## Background
4
+
5
+ During the OpenCompass evaluation process, the Huggingface transformers library is used for inference by default. While this is a very general solution, there are scenarios where more efficient inference methods are needed to speed up the process, such as leveraging VLLM or LMDeploy.
6
+
7
+ - [LMDeploy](https://github.com/InternLM/lmdeploy) is a toolkit designed for compressing, deploying, and serving large language models (LLMs), developed by the [MMRazor](https://github.com/open-mmlab/mmrazor) and [MMDeploy](https://github.com/open-mmlab/mmdeploy) teams.
8
+ - [vLLM](https://github.com/vllm-project/vllm) is a fast and user-friendly library for LLM inference and serving, featuring advanced serving throughput, efficient PagedAttention memory management, continuous batching of requests, fast model execution via CUDA/HIP graphs, quantization techniques (e.g., GPTQ, AWQ, SqueezeLLM, FP8 KV Cache), and optimized CUDA kernels.
9
+
10
+ ## Preparation for Acceleration
11
+
12
+ First, check whether the model you want to evaluate supports inference acceleration using vLLM or LMDeploy. Additionally, ensure you have installed vLLM or LMDeploy as per their official documentation. Below are the installation methods for reference:
13
+
14
+ ### LMDeploy Installation Method
15
+
16
+ Install LMDeploy using pip (Python 3.8+) or from [source](https://github.com/InternLM/lmdeploy/blob/main/docs/en/build.md):
17
+
18
+ ```bash
19
+ pip install lmdeploy
20
+ ```
21
+
22
+ ### VLLM Installation Method
23
+
24
+ Install vLLM using pip or from [source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
25
+
26
+ ```bash
27
+ pip install vllm
28
+ ```
29
+
30
+ ## Accelerated Evaluation Using VLLM or LMDeploy
31
+
32
+ ### Method 1: Using Command Line Parameters to Change the Inference Backend
33
+
34
+ OpenCompass offers one-click evaluation acceleration. During evaluation, it can automatically convert Huggingface transformer models to VLLM or LMDeploy models for use. Below is an example code for evaluating the GSM8k dataset using the default Huggingface version of the llama3-8b-instruct model:
35
+
36
+ ```python
37
+ # eval_gsm8k.py
38
+ from mmengine.config import read_base
39
+
40
+ with read_base():
41
+ # Select a dataset list
42
+ from .datasets.gsm8k.gsm8k_0shot_gen_a58960 import gsm8k_datasets as datasets
43
+ # Select an interested model
44
+ from ..models.hf_llama.hf_llama3_8b_instruct import models
45
+ ```
46
+
47
+ Here, `hf_llama3_8b_instruct` specifies the original Huggingface model configuration, as shown below:
48
+
49
+ ```python
50
+ from opencompass.models import HuggingFacewithChatTemplate
51
+
52
+ models = [
53
+ dict(
54
+ type=HuggingFacewithChatTemplate,
55
+ abbr='llama-3-8b-instruct-hf',
56
+ path='meta-llama/Meta-Llama-3-8B-Instruct',
57
+ max_out_len=1024,
58
+ batch_size=8,
59
+ run_cfg=dict(num_gpus=1),
60
+ stop_words=['<|end_of_text|>', '<|eot_id|>'],
61
+ )
62
+ ]
63
+ ```
64
+
65
+ To evaluate the GSM8k dataset using the default Huggingface version of the llama3-8b-instruct model, use:
66
+
67
+ ```bash
68
+ python run.py config/eval_gsm8k.py
69
+ ```
70
+
71
+ To accelerate the evaluation using vLLM or LMDeploy, you can use the following script:
72
+
73
+ ```bash
74
+ python run.py config/eval_gsm8k.py -a vllm
75
+ ```
76
+
77
+ or
78
+
79
+ ```bash
80
+ python run.py config/eval_gsm8k.py -a lmdeploy
81
+ ```
82
+
83
+ ### Method 2: Accelerating Evaluation via Deployed Inference Acceleration Service API
84
+
85
+ OpenCompass also supports accelerating evaluation by deploying vLLM or LMDeploy inference acceleration service APIs. Follow these steps:
86
+
87
+ 1. Install the openai package:
88
+
89
+ ```bash
90
+ pip install openai
91
+ ```
92
+
93
+ 2. Deploy the inference acceleration service API for vLLM or LMDeploy. Below is an example for LMDeploy:
94
+
95
+ ```bash
96
+ lmdeploy serve api_server meta-llama/Meta-Llama-3-8B-Instruct --model-name Meta-Llama-3-8B-Instruct --server-port 23333
97
+ ```
98
+
99
+ Parameters for starting the api_server can be checked using `lmdeploy serve api_server -h`, such as --tp for tensor parallelism, --session-len for the maximum context window length, --cache-max-entry-count for adjusting the k/v cache memory usage ratio, etc.
100
+
101
+ 3. Once the service is successfully deployed, modify the evaluation script by changing the model configuration path to the service address, as shown below:
102
+
103
+ ```python
104
+ from opencompass.models import OpenAISDK
105
+
106
+ api_meta_template = dict(
107
+ round=[
108
+ dict(role='HUMAN', api_role='HUMAN'),
109
+ dict(role='BOT', api_role='BOT', generate=True),
110
+ ],
111
+ reserved_roles=[dict(role='SYSTEM', api_role='SYSTEM')],
112
+ )
113
+
114
+ models = [
115
+ dict(
116
+ abbr='Meta-Llama-3-8B-Instruct-LMDeploy-API',
117
+ type=OpenAISDK,
118
+ key='EMPTY', # API key
119
+ openai_api_base='http://0.0.0.0:23333/v1', # Service address
120
+ path='Meta-Llama-3-8B-Instruct', # Model name for service request
121
+ tokenizer_path='meta-llama/Meta-Llama-3.1-8B-Instruct', # The tokenizer name or path, if set to `None`, uses the default `gpt-4` tokenizer
122
+ rpm_verbose=True, # Whether to print request rate
123
+ meta_template=api_meta_template, # Service request template
124
+ query_per_second=1, # Service request rate
125
+ max_out_len=1024, # Maximum output length
126
+ max_seq_len=4096, # Maximum input length
127
+ temperature=0.01, # Generation temperature
128
+ batch_size=8, # Batch size
129
+ retry=3, # Number of retries
130
+ )
131
+ ]
132
+ ```
133
+
134
+ ## Acceleration Effect and Performance Comparison
135
+
136
+ Below is a comparison table of the acceleration effect and performance when using VLLM or LMDeploy on a single A800 GPU for evaluating the Llama-3-8B-Instruct model on the GSM8k dataset:
137
+
138
+ | Inference Backend | Accuracy | Inference Time (minutes:seconds) | Speedup (relative to Huggingface) |
139
+ | ----------------- | -------- | -------------------------------- | --------------------------------- |
140
+ | Huggingface | 74.22 | 24:26 | 1.0 |
141
+ | LMDeploy | 73.69 | 11:15 | 2.2 |
142
+ | VLLM | 72.63 | 07:52 | 3.1 |
docs/en/advanced_guides/circular_eval.md ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CircularEval
2
+
3
+ ## Background
4
+
5
+ For multiple-choice questions, when a Language Model (LLM) provides the correct option, it does not necessarily imply a true understanding and reasoning of the question. It could be a guess. To differentiate these scenarios and reduce LLM bias towards options, CircularEval can be utilized. A multiple-choice question is augmented by shuffling its options, and if the LLM correctly answers all variations of the augmented question, it is considered correct under CircularEval.
6
+
7
+ ## Adding Your Own CircularEval Dataset
8
+
9
+ Generally, to evaluate a dataset using CircularEval, both its loading and evaluation methods need to be rewritten. Modifications are required in both the OpenCompass main library and configuration files. We will use C-Eval as an example for explanation.
10
+
11
+ OpenCompass main library:
12
+
13
+ ```python
14
+ from opencompass.datasets.ceval import CEvalDataset
15
+ from opencompass.datasets.circular import CircularDatasetMeta
16
+
17
+ class CircularCEvalDataset(CEvalDataset, metaclass=CircularDatasetMeta):
18
+ # The overloaded dataset class
19
+ dataset_class = CEvalDataset
20
+
21
+ # Splits of the DatasetDict that need CircularEval. For CEvalDataset, which loads [dev, val, test], we only need 'val' and 'test' for CircularEval, not 'dev'
22
+ default_circular_splits = ['val', 'test']
23
+
24
+ # List of keys to be shuffled
25
+ default_option_keys = ['A', 'B', 'C', 'D']
26
+
27
+ # If the content of 'answer_key' is one of ['A', 'B', 'C', 'D'], representing the correct answer. This field indicates how to update the correct answer after shuffling options. Choose either this or default_answer_key_switch_method
28
+ default_answer_key = 'answer'
29
+
30
+ # If 'answer_key' content is not one of ['A', 'B', 'C', 'D'], a function can be used to specify the correct answer after shuffling options. Choose either this or default_answer_key
31
+ # def default_answer_key_switch_method(item, circular_pattern):
32
+ # # 'item' is the original data item
33
+ # # 'circular_pattern' is a tuple indicating the order after shuffling options, e.g., ('D', 'A', 'B', 'C') means the original option A is now D, and so on
34
+ # item['answer'] = circular_pattern['ABCD'.index(item['answer'])]
35
+ # return item
36
+ ```
37
+
38
+ `CircularCEvalDataset` accepts the `circular_pattern` parameter with two values:
39
+
40
+ - `circular`: Indicates a single cycle. It is the default value. ABCD is expanded to ABCD, BCDA, CDAB, DABC, a total of 4 variations.
41
+ - `all_possible`: Indicates all permutations. ABCD is expanded to ABCD, ABDC, ACBD, ACDB, ADBC, ADCB, BACD, ..., a total of 24 variations.
42
+
43
+ Additionally, we provide a `CircularEvaluator` to replace `AccEvaluator`. This Evaluator also accepts `circular_pattern`, and it should be consistent with the above. It produces the following metrics:
44
+
45
+ - `acc_{origin|circular|all_possible}`: Treating each question with shuffled options as separate, calculating accuracy.
46
+ - `perf_{origin|circular|all_possible}`: Following Circular logic, a question is considered correct only if all its variations with shuffled options are answered correctly, calculating accuracy.
47
+ - `more_{num}_{origin|circular|all_possible}`: According to Circular logic, a question is deemed correct if the number of its variations answered correctly is greater than or equal to num, calculating accuracy.
48
+
49
+ OpenCompass configuration file:
50
+
51
+ ```python
52
+ from mmengine.config import read_base
53
+ from opencompass.datasets.circular import CircularCEvalDataset
54
+
55
+ with read_base():
56
+ from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
57
+
58
+ for d in ceval_datasets:
59
+ # Overloading the load method
60
+ d['type'] = CircularCEvalDataset
61
+ # Renaming for differentiation from non-circular evaluation versions
62
+ d['abbr'] = d['abbr'] + '-circular-4'
63
+ # Overloading the evaluation method
64
+ d['eval_cfg']['evaluator'] = {'type': CircularEvaluator}
65
+
66
+ # The dataset after the above operations looks like this:
67
+ # dict(
68
+ # type=CircularCEvalDataset,
69
+ # path='./data/ceval/formal_ceval', # Unchanged
70
+ # name='computer_network', # Unchanged
71
+ # abbr='ceval-computer_network-circular-4',
72
+ # reader_cfg=dict(...), # Unchanged
73
+ # infer_cfg=dict(...), # Unchanged
74
+ # eval_cfg=dict(evaluator=dict(type=CircularEvaluator), ...),
75
+ # )
76
+ ```
77
+
78
+ Additionally, for better presentation of results in CircularEval, consider using the following summarizer:
79
+
80
+ ```python
81
+
82
+
83
+ from mmengine.config import read_base
84
+ from opencompass.summarizers import CircularSummarizer
85
+
86
+ with read_base():
87
+     from ...summarizers.groups.ceval import ceval_summary_groups
88
+
89
+ new_summary_groups = []
90
+ for item in ceval_summary_groups:
91
+ new_summary_groups.append(
92
+ {
93
+ 'name': item['name'] + '-circular-4',
94
+ 'subsets': [i + '-circular-4' for i in item['subsets']],
95
+ }
96
+ )
97
+
98
+ summarizer = dict(
99
+ type=CircularSummarizer,
100
+ # Select specific metrics to view
101
+ metric_types=['acc_origin', 'perf_circular'],
102
+ dataset_abbrs = [
103
+ 'ceval-circular-4',
104
+ 'ceval-humanities-circular-4',
105
+ 'ceval-stem-circular-4',
106
+ 'ceval-social-science-circular-4',
107
+ 'ceval-other-circular-4',
108
+ ],
109
+ summary_groups=new_summary_groups,
110
+ )
111
+ ```
112
+
113
+ For more complex evaluation examples, refer to this sample code: https://github.com/open-compass/opencompass/tree/main/examples/eval_circular.py
docs/en/advanced_guides/code_eval.md ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code Evaluation Tutorial
2
+
3
+ This tutorial primarily focuses on evaluating a model's coding proficiency, using `humaneval` and `mbpp` as examples.
4
+
5
+ ## pass@1
6
+
7
+ If you only need to generate a single response to evaluate the pass@1 performance, you can directly use [configs/datasets/humaneval/humaneval_gen_8e312c.py](https://github.com/open-compass/opencompass/blob/main/configs/datasets/humaneval/humaneval_gen_8e312c.py) and [configs/datasets/mbpp/deprecated_mbpp_gen_1e1056.py](https://github.com/open-compass/opencompass/blob/main/configs/datasets/mbpp/deprecated_mbpp_gen_1e1056.py), referring to the general [quick start tutorial](../get_started/quick_start.md).
8
+
9
+ For multilingual evaluation, please refer to the [Multilingual Code Evaluation Tutorial](./code_eval_service.md).
10
+
11
+ ## pass@k
12
+
13
+ If you need to generate multiple responses for a single example to evaluate the pass@k performance, consider the following two situations. Here we take 10 responses as an example:
14
+
15
+ ### Typical Situation
16
+
17
+ For most models that support the `num_return_sequences` parameter in HF's generation, we can use it directly to obtain multiple responses. Refer to the following configuration file:
18
+
19
+ ```python
20
+ from opencompass.datasets import MBPPDatasetV2, MBPPPassKEvaluator
21
+
22
+ with read_base():
23
+ from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
24
+ from .datasets.mbpp.deprecated_mbpp_gen_1e1056 import mbpp_datasets
25
+
26
+ mbpp_datasets[0]['type'] = MBPPDatasetV2
27
+ mbpp_datasets[0]['eval_cfg']['evaluator']['type'] = MBPPPassKEvaluator
28
+ mbpp_datasets[0]['reader_cfg']['output_column'] = 'test_column'
29
+
30
+ datasets = []
31
+ datasets += humaneval_datasets
32
+ datasets += mbpp_datasets
33
+
34
+ models = [
35
+ dict(
36
+ type=HuggingFaceCausalLM,
37
+ ...,
38
+ generation_kwargs=dict(
39
+ num_return_sequences=10,
40
+ do_sample=True,
41
+ top_p=0.95,
42
+ temperature=0.8,
43
+ ),
44
+ ...,
45
+ )
46
+ ]
47
+ ```
48
+
49
+ For `mbpp`, new changes are needed in the dataset and evaluation, so we simultaneously modify the `type`, `eval_cfg.evaluator.type`, `reader_cfg.output_column` fields to accommodate these requirements.
50
+
51
+ We also need model responses with randomness, thus setting the `generation_kwargs` parameter is necessary. Note that we need to set `num_return_sequences` to get the number of responses.
52
+
53
+ Note: `num_return_sequences` must be greater than or equal to k, as pass@k itself is a probability estimate.
54
+
55
+ You can specifically refer to the following configuration file [examples/eval_code_passk.py](https://github.com/open-compass/opencompass/blob/main/examples/eval_code_passk.py)
56
+
57
+ ### For Models That Do Not Support Multiple Responses
58
+
59
+ This applies to some HF models with poorly designed APIs or missing features. In this case, we need to repeatedly construct datasets to achieve multiple response effects. Refer to the following configuration:
60
+
61
+ ```python
62
+ from opencompass.datasets import MBPPDatasetV2, MBPPPassKEvaluator
63
+
64
+ with read_base():
65
+ from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
66
+ from .datasets.mbpp.deprecated_mbpp_gen_1e1056 import mbpp_datasets
67
+
68
+ humaneval_datasets[0]['abbr'] = 'openai_humaneval_pass10'
69
+ humaneval_datasets[0]['num_repeats'] = 10
70
+ mbpp_datasets[0]['abbr'] = 'mbpp_pass10'
71
+ mbpp_datasets[0]['num_repeats'] = 10
72
+ mbpp_datasets[0]['type'] = MBPPDatasetV2
73
+ mbpp_datasets[0]['eval_cfg']['evaluator']['type'] = MBPPPassKEvaluator
74
+ mbpp_datasets[0]['reader_cfg']['output_column'] = 'test_column'
75
+
76
+ datasets = []
77
+ datasets += humaneval_datasets
78
+ datasets += mbpp_datasets
79
+
80
+ models = [
81
+ dict(
82
+ type=HuggingFaceCausalLM,
83
+ ...,
84
+ generation_kwargs=dict(
85
+ do_sample=True,
86
+ top_p=0.95,
87
+ temperature=0.8,
88
+ ),
89
+ ...,
90
+ )
91
+ ]
92
+ ```
93
+
94
+ Since the dataset's prompt has not been modified, we need to replace the corresponding fields to achieve the purpose of repeating the dataset.
95
+ You need to modify these fields:
96
+
97
+ - `num_repeats`: the number of times the dataset is repeated
98
+ - `abbr`: It's best to modify the dataset abbreviation along with the number of repetitions because the number of datasets will change, preventing potential issues arising from discrepancies with the values in `.cache/dataset_size.json`.
99
+
100
+ For `mbpp`, modify the `type`, `eval_cfg.evaluator.type`, `reader_cfg.output_column` fields as well.
101
+
102
+ We also need model responses with randomness, thus setting the `generation_kwargs` parameter is necessary.
103
+
104
+ You can specifically refer to the following configuration file [examples/eval_code_passk_repeat_dataset.py](https://github.com/open-compass/opencompass/blob/main/examples/eval_code_passk_repeat_dataset.py)
docs/en/advanced_guides/code_eval_service.md ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code Evaluation Docker Tutorial
2
+
3
+ To complete the LLM code capability evaluation, we need to build a separate evaluation environment to avoid executing erroneous code in the development environment, which would inevitably cause losses. The code evaluation service currently used by OpenCompass can refer to the [code-evaluator](https://github.com/open-compass/code-evaluator) project. The following will introduce evaluation tutorials around the code evaluation service.
4
+
5
+ 1. humaneval-x
6
+
7
+ This is a multi-programming language dataset [humaneval-x](https://huggingface.co/datasets/THUDM/humaneval-x).
8
+    You can download the dataset from this [download link](https://github.com/THUDM/CodeGeeX2/tree/main/benchmark/humanevalx). Please download the language file (xx.jsonl.gz) that needs to be evaluated and place it in the `./data/humanevalx` folder.
9
+
10
+ The currently supported languages are `python`, `cpp`, `go`, `java`, `js`.
11
+
12
+ 2. DS1000
13
+
14
+ This is a Python multi-algorithm library dataset [ds1000](https://github.com/xlang-ai/DS-1000).
15
+ You can download the dataset from this [download link](https://github.com/xlang-ai/DS-1000/blob/main/ds1000_data.zip).
16
+
17
+ The currently supported algorithm libraries are `Pandas`, `Numpy`, `Tensorflow`, `Scipy`, `Sklearn`, `Pytorch`, `Matplotlib`.
18
+
19
+ ## Launching the Code Evaluation Service
20
+
21
+ 1. Ensure you have installed Docker, please refer to [Docker installation document](https://docs.docker.com/engine/install/).
22
+ 2. Pull the source code of the code evaluation service project and build the Docker image.
23
+
24
+ Choose the dockerfile corresponding to the dataset you need, and replace `humanevalx` or `ds1000` in the command below.
25
+
26
+ ```shell
27
+ git clone https://github.com/open-compass/code-evaluator.git
28
+ docker build -t code-eval-{your-dataset}:latest -f docker/{your-dataset}/Dockerfile .
29
+ ```
30
+
31
+ 3. Create a container with the following commands:
32
+
33
+ ```shell
34
+ # Log output format
35
+ docker run -it -p 5000:5000 code-eval-{your-dataset}:latest python server.py
36
+
37
+ # Run the program in the background
38
+ # docker run -itd -p 5000:5000 code-eval-{your-dataset}:latest python server.py
39
+
40
+ # Using different ports
41
+ # docker run -itd -p 5001:5001 code-eval-{your-dataset}:latest python server.py --port 5001
42
+ ```
43
+
44
+ **Note:**
45
+
46
+ - If you encounter a timeout during the evaluation of Go, please use the following command when creating the container.
47
+
48
+ ```shell
49
+ docker run -it -p 5000:5000 -e GO111MODULE=on -e GOPROXY=https://goproxy.io code-eval-{your-dataset}:latest python server.py
50
+ ```
51
+
52
+ 4. To ensure you have access to the service, use the following command to check the inference environment and evaluation service connection status. (If both inferences and code evaluations run on the same host, skip this step.)
53
+
54
+ ```shell
55
+ ping your_service_ip_address
56
+ telnet your_service_ip_address your_service_port
57
+ ```
58
+
59
+ ## Local Code Evaluation
60
+
61
+ When the model inference and code evaluation services are running on the same host or within the same local area network, code inference and evaluation can be performed directly. **Note: DS1000 is currently not supported, please proceed with remote evaluation.**
62
+
63
+ ### Configuration File
64
+
65
+ We provide [the configuration file](https://github.com/open-compass/opencompass/blob/main/examples/eval_codegeex2.py) of using `humanevalx` for evaluation on `codegeex2` as reference.
66
+
67
+ The dataset and related post-processing configurations files can be found at this [link](https://github.com/open-compass/opencompass/tree/main/configs/datasets/humanevalx) with attention paid to the `evaluator` field in the humanevalx_eval_cfg_dict.
68
+
69
+ ```python
70
+ from opencompass.openicl.icl_prompt_template import PromptTemplate
71
+ from opencompass.openicl.icl_retriever import ZeroRetriever
72
+ from opencompass.openicl.icl_inferencer import GenInferencer
73
+ from opencompass.datasets import HumanevalXDataset, HumanevalXEvaluator
74
+
75
+ humanevalx_reader_cfg = dict(
76
+ input_columns=['prompt'], output_column='task_id', train_split='test')
77
+
78
+ humanevalx_infer_cfg = dict(
79
+ prompt_template=dict(
80
+ type=PromptTemplate,
81
+ template='{prompt}'),
82
+ retriever=dict(type=ZeroRetriever),
83
+ inferencer=dict(type=GenInferencer, max_out_len=1024))
84
+
85
+ humanevalx_eval_cfg_dict = {
86
+ lang : dict(
87
+ evaluator=dict(
88
+ type=HumanevalXEvaluator,
89
+ language=lang,
90
+ ip_address="localhost", # replace to your code_eval_server ip_address, port
91
+ port=5000), # refer to https://github.com/open-compass/code-evaluator to launch a server
92
+ pred_role='BOT')
93
+ for lang in ['python', 'cpp', 'go', 'java', 'js'] # do not support rust now
94
+ }
95
+
96
+ humanevalx_datasets = [
97
+ dict(
98
+ type=HumanevalXDataset,
99
+ abbr=f'humanevalx-{lang}',
100
+ language=lang,
101
+ path='./data/humanevalx',
102
+ reader_cfg=humanevalx_reader_cfg,
103
+ infer_cfg=humanevalx_infer_cfg,
104
+ eval_cfg=humanevalx_eval_cfg_dict[lang])
105
+ for lang in ['python', 'cpp', 'go', 'java', 'js']
106
+ ]
107
+ ```
108
+
109
+ ### Task Launch
110
+
111
+ Refer to the [Quick Start](../get_started.html)
112
+
113
+ ## Remote Code Evaluation
114
+
115
+ Model inference and code evaluation services located in different machines which cannot be accessed directly require prior model inference before collecting the code evaluation results. The configuration file and inference process can be reused from the previous tutorial.
116
+
117
+ ### Collect Inference Results(Only for Humanevalx)
118
+
119
+ In OpenCompass's tools folder, there is a script called `collect_code_preds.py` provided to process and collect the inference results after providing the task launch configuration file during startup along with specifying the working directory used corresponding to the task.
120
+ It is the same as the `-r` option in `run.py`. More details can be referred through the [documentation](https://opencompass.readthedocs.io/en/latest/get_started/quick_start.html#launching-evaluation).
121
+
122
+ ```shell
123
+ python tools/collect_code_preds.py [config] [-r latest]
124
+ ```
125
+
126
+ The collected results will be organized as following under the `-r` folder:
127
+
128
+ ```
129
+ workdir/humanevalx
130
+ ├── codegeex2-6b
131
+ │   ├── humanevalx_cpp.json
132
+ │   ├── humanevalx_go.json
133
+ │   ├── humanevalx_java.json
134
+ │   ├── humanevalx_js.json
135
+ │   └── humanevalx_python.json
136
+ ├── CodeLlama-13b
137
+ │   ├── ...
138
+ ├── CodeLlama-13b-Instruct
139
+ │   ├── ...
140
+ ├── CodeLlama-13b-Python
141
+ │   ├── ...
142
+ ├── ...
143
+ ```
144
+
145
+ For DS1000, you just need to obtain the corresponding prediction file generated by `opencompass`.
146
+
147
+ ### Code Evaluation
148
+
149
+ Make sure your code evaluation service is started, and use `curl` to request:
150
+
151
+ #### The following only supports Humanevalx
152
+
153
+ ```shell
154
+ curl -X POST -F 'file=@{result_absolute_path}' -F 'dataset={dataset/language}' {your_service_ip_address}:{your_service_port}/evaluate
155
+ ```
156
+
157
+ For example:
158
+
159
+ ```shell
160
+ curl -X POST -F 'file=@./examples/humanevalx/python.json' -F 'dataset=humanevalx/python' localhost:5000/evaluate
161
+ ```
162
+
163
+ Then we have:
164
+
165
+ ```
166
+ "{\"pass@1\": 37.19512195121951%}"
167
+ ```
168
+
169
+ Additionally, we offer an extra option named `with_prompt`(Defaults to `True`), since some models(like `WizardCoder`) generate complete codes without requiring the form of concatenating prompt and prediction. You may refer to the following commands for evaluation.
170
+
171
+ ```shell
172
+ curl -X POST -F 'file=@./examples/humanevalx/python.json' -F 'dataset=humanevalx/python' -H 'with-prompt: False' localhost:5000/evaluate
173
+ ```
174
+
175
+ #### The following only supports DS1000
176
+
177
+ Make sure the code evaluation service is started, then use `curl` to submit a request:
178
+
179
+ ```shell
180
+ curl -X POST -F 'file=@./internlm-chat-7b-hf-v11/ds1000_Numpy.json' localhost:5000/evaluate
181
+ ```
182
+
183
+ DS1000 supports additional debug parameters. Be aware that a large amount of log will be generated when it is turned on:
184
+
185
+ - `full`: Additional print out of the original prediction for each error sample, post-processing prediction, running program, and final error.
186
+ - `half`: Additional print out of the running program and final error for each error sample.
187
+ - `error`: Additional print out of the final error for each error sample.
188
+
189
+ ```shell
190
+ curl -X POST -F 'file=@./internlm-chat-7b-hf-v11/ds1000_Numpy.json' -F 'debug=error' localhost:5000/evaluate
191
+ ```
192
+
193
+ You can also modify the `num_workers` in the same way to control the degree of parallelism.
194
+
195
+ ## Advanced Tutorial
196
+
197
+ Besides evaluating the supported HumanEval-X dataset, users might also need:
198
+
199
+ ### Support New Dataset
200
+
201
+ Please refer to the [tutorial on supporting new datasets](./new_dataset.md).
202
+
203
+ ### Modify Post-Processing
204
+
205
+ 1. For local evaluation, follow the post-processing section in the tutorial on supporting new datasets to modify the post-processing method.
206
+ 2. For remote evaluation, please modify the post-processing part in the tool's `collect_code_preds.py`.
207
+ 3. Some parts of post-processing could also be modified in the code evaluation service, more information will be available in the next section.
208
+
209
+ ### Debugging Code Evaluation Service
210
+
211
+ When supporting new datasets or modifying post-processors, it is possible that modifications need to be made to the original code evaluation service. Please make changes based on the following steps:
212
+
213
+ 1. Remove the installation of the `code-evaluator` in `Dockerfile`, mount the `code-evaluator` when starting the container instead:
214
+
215
+ ```shell
216
+ docker run -it -p 5000:5000 -v /local/path/of/code-evaluator:/workspace/code-evaluator code-eval:latest bash
217
+ ```
218
+
219
+ 2. Install and start the code evaluation service locally. At this point, any necessary modifications can be made to the local copy of the `code-evaluator`.
220
+
221
+ ```shell
222
+ cd code-evaluator && pip install -r requirements.txt
223
+ python server.py
224
+ ```
docs/en/advanced_guides/contamination_eval.md ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Data Contamination Assessment
2
+
3
+ **Data Contamination** refers to the phenomenon where data intended for downstream testing tasks appear in the training data of large language models (LLMs), resulting in artificially inflated performance metrics in downstream tasks (such as summarization, natural language inference, text classification), which do not accurately reflect the model's true generalization capabilities.
4
+
5
+ Since the source of data contamination lies in the training data used by LLMs, the most direct method to detect data contamination is to collide test data with training data and then report the extent of overlap between the two. The classic GPT-3 [paper](https://arxiv.org/pdf/2005.14165.pdf) reported on this in Table C.1.
6
+
7
+ However, today's open-source community often only publishes model parameters, not training datasets. In such cases, how to determine the presence and extent of data contamination remains unsolved. OpenCompass offers two possible solutions.
8
+
9
+ ## Contamination Data Annotation Based on Self-Built Co-Distribution Data
10
+
11
+ Referencing the method mentioned in Section 5.2 of [Skywork](https://arxiv.org/pdf/2310.19341.pdf), we directly used the dataset [mock_gsm8k_test](https://huggingface.co/datasets/Skywork/mock_gsm8k_test) uploaded to HuggingFace by Skywork.
12
+
13
+ In this method, the authors used GPT-4 to synthesize data similar to the original GSM8K style, and then calculated the perplexity on the GSM8K training set (train), GSM8K test set (test), and GSM8K reference set (ref). Since the GSM8K reference set was newly generated, the authors considered it as clean, not belonging to any training set of any model. They posited:
14
+
15
+ - If the test set's perplexity is significantly lower than the reference set's, the test set might have appeared in the model's training phase;
16
+ - If the training set's perplexity is significantly lower than the test set's, the training set might have been overfitted by the model.
17
+
18
+ The following configuration file can be referenced:
19
+
20
+ ```python
21
+ from mmengine.config import read_base
22
+
23
+ with read_base():
24
+ from .datasets.gsm8k_contamination.gsm8k_contamination_ppl_ecdd22 import gsm8k_datasets # includes training, test, and reference sets
25
+ from .models.qwen.hf_qwen_7b import models as hf_qwen_7b_model # model under review
26
+ from .models.yi.hf_yi_6b import models as hf_yi_6b_model
27
+
28
+ datasets = [*gsm8k_datasets]
29
+ models = [*hf_qwen_7b_model, *hf_yi_6b_model]
30
+ ```
31
+
32
+ An example output is as follows:
33
+
34
+ ```text
35
+ dataset version metric mode internlm-7b-hf qwen-7b-hf yi-6b-hf chatglm3-6b-base-hf qwen-14b-hf baichuan2-13b-base-hf internlm-20b-hf aquila2-34b-hf ...
36
+ --------------- --------- ----------- ------- ---------------- ------------ ---------- --------------------- ------------- ----------------------- ----------------- ---------------- ...
37
+ gsm8k-train-ppl 0b8e46 average_ppl unknown 1.5 0.78 1.37 1.16 0.5 0.76 1.41 0.78 ...
38
+ gsm8k-test-ppl 0b8e46 average_ppl unknown 1.56 1.33 1.42 1.3 1.15 1.13 1.52 1.16 ...
39
+ gsm8k-ref-ppl f729ba average_ppl unknown 1.55 1.2 1.43 1.35 1.27 1.19 1.47 1.35 ...
40
+ ```
41
+
42
+ Currently, this solution only supports the GSM8K dataset. We welcome the community to contribute more datasets.
43
+
44
+ Consider cite the following paper if you find it helpful:
45
+
46
+ ```bibtex
47
+ @misc{2023opencompass,
48
+ title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
49
+ author={OpenCompass Contributors},
50
+ howpublished = {\url{https://github.com/open-compass/opencompass}},
51
+ year={2023}
52
+ }
53
+ @misc{wei2023skywork,
54
+ title={Skywork: A More Open Bilingual Foundation Model},
55
+ author={Tianwen Wei and Liang Zhao and Lichang Zhang and Bo Zhu and Lijie Wang and Haihua Yang and Biye Li and Cheng Cheng and Weiwei Lü and Rui Hu and Chenxia Li and Liu Yang and Xilin Luo and Xuejie Wu and Lunan Liu and Wenjun Cheng and Peng Cheng and Jianhao Zhang and Xiaoyu Zhang and Lei Lin and Xiaokun Wang and Yutuan Ma and Chuanhai Dong and Yanqi Sun and Yifu Chen and Yongyi Peng and Xiaojuan Liang and Shuicheng Yan and Han Fang and Yahui Zhou},
56
+ year={2023},
57
+ eprint={2310.19341},
58
+ archivePrefix={arXiv},
59
+ primaryClass={cs.CL}
60
+ }
61
+ ```
62
+
63
+ ## Contamination Data Annotation Based on Classic Pre-trained Sets
64
+
65
+ Thanks to [Contamination_Detector](https://github.com/liyucheng09/Contamination_Detector) and @liyucheng09 for providing this method.
66
+
67
+ In this method, the authors search the test datasets (such as C-Eval, ARC, HellaSwag, etc.) using the Common Crawl database and Bing search engine, then mark each test sample as clean / question contaminated / both question and answer contaminated.
68
+
69
+ During testing, OpenCompass
70
+
71
+ will report the accuracy or perplexity of ceval on subsets composed of these three labels. Generally, the accuracy ranges from low to high: clean, question contaminated, both question and answer contaminated subsets. The authors believe:
72
+
73
+ - If the performance of the three is relatively close, the contamination level of the model on that test set is light; otherwise, it is heavy.
74
+
75
+ The following configuration file can be referenced [link](https://github.com/open-compass/opencompass/blob/main/examples/eval_contamination.py):
76
+
77
+ ```python
78
+ from mmengine.config import read_base
79
+
80
+ with read_base():
81
+ from .datasets.ceval.ceval_clean_ppl import ceval_datasets # ceval dataset with contamination tags
82
+ from .models.yi.hf_yi_6b import models as hf_yi_6b_model # model under review
83
+ from .models.qwen.hf_qwen_7b import models as hf_qwen_7b_model
84
+ from .summarizers.contamination import ceval_summarizer as summarizer # output formatting
85
+
86
+ datasets = [*ceval_datasets]
87
+ models = [*hf_yi_6b_model, *hf_qwen_7b_model]
88
+ ```
89
+
90
+ An example output is as follows:
91
+
92
+ ```text
93
+ dataset version mode yi-6b-hf - - qwen-7b-hf - - ...
94
+ ---------------------------------------------- --------- ------ ---------------- ----------------------------- --------------------------------------- ---------------- ----------------------------- --------------------------------------- ...
95
+ - - - accuracy - clean accuracy - input contaminated accuracy - input-and-label contaminated accuracy - clean accuracy - input contaminated accuracy - input-and-label contaminated ...
96
+ ...
97
+ ceval-humanities - ppl 74.42 75.00 82.14 67.44 50.00 70.54 ...
98
+ ceval-stem - ppl 53.70 57.14 85.61 47.41 52.38 67.63 ...
99
+ ceval-social-science - ppl 81.60 84.62 83.09 76.00 61.54 72.79 ...
100
+ ceval-other - ppl 72.31 73.91 75.00 58.46 39.13 61.88 ...
101
+ ceval-hard - ppl 44.35 37.50 70.00 41.13 25.00 30.00 ...
102
+ ceval - ppl 67.32 71.01 81.17 58.97 49.28 67.82 ...
103
+ ```
104
+
105
+ Currently, this solution only supports the C-Eval, MMLU, HellaSwag and ARC. [Contamination_Detector](https://github.com/liyucheng09/Contamination_Detector) also includes CSQA and WinoGrande, but these have not yet been implemented in OpenCompass. We welcome the community to contribute more datasets.
106
+
107
+ Consider cite the following paper if you find it helpful:
108
+
109
+ ```bibtex
110
+ @misc{2023opencompass,
111
+ title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
112
+ author={OpenCompass Contributors},
113
+ howpublished = {\url{https://github.com/open-compass/opencompass}},
114
+ year={2023}
115
+ }
116
+ @article{Li2023AnOS,
117
+ title={An Open Source Data Contamination Report for Llama Series Models},
118
+ author={Yucheng Li},
119
+ journal={ArXiv},
120
+ year={2023},
121
+ volume={abs/2310.17589},
122
+ url={https://api.semanticscholar.org/CorpusID:264490711}
123
+ }
124
+ ```
docs/en/advanced_guides/custom_dataset.md ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dataset Quick Evaluation Tutorial
2
+
3
+ OpenCompass provides two paths for quickly evaluating the provided data, the data format protocol based on ChatMLDataset and the data format protocol based on CustomDataset.
4
+ Compared to the complete dataset integration process in [new_dataset.md](./new_dataset.md), these two evaluation paths are more convenient and efficient, being able to directly enter the evaluation process without adding new configuration files.
5
+ However, if you have specific needs for custom reading/inference/evaluation, it is recommended to follow the complete integration process to add a new dataset.
6
+
7
+ ## Data Format Protocol and Fast Evaluation Based on ChatMLDataset
8
+
9
+ OpenCompass has recently launched a dataset evaluation mode based on the ChatML dialogue template, which allows users to provide a dataset .json file that conforms to the ChatML dialogue template and simply set the dataset information config (like model configs) to start evaluating directly.
10
+
11
+ ### Format Requirements for Data Files
12
+
13
+ This evaluation method only supports data files in `.json` format, and each sample must comply with the following format:
14
+
15
+ The format of a text-only dataset with a simple structure:
16
+
17
+ ```jsonl
18
+ {
19
+ "question":[
20
+ {
21
+ "role": "system" # Omittable
22
+ "content": Str
23
+ },
24
+ {
25
+ "role": "user",
26
+ "content": Str
27
+ }
28
+ ],
29
+ "answer":[
30
+ Str
31
+ ]
32
+ }
33
+ {
34
+ ...
35
+ }
36
+ ...
37
+ ```
38
+
39
+ The format of multi-round and multi-modal datasets:
40
+
41
+ ```jsonl
42
+ {
43
+ "question":[
44
+ {
45
+ "role": "system",
46
+ "content": Str,
47
+ },
48
+ {
49
+ "role": "user",
50
+ "content": Str or List
51
+ [
52
+ {
53
+ "type": Str, # "image"
54
+ "image_url": Str,
55
+ },
56
+ ...
57
+ {
58
+ "type": Str, # "text"
59
+ "text": Str,
60
+ },
61
+ ]
62
+ },
63
+ {
64
+ "role": "assistant",
65
+ "content": Str
66
+ },
67
+ {
68
+ "role": "user",
69
+ "content": Str or List
70
+ },
71
+ ...
72
+ ],
73
+ "answer":[
74
+ Str,
75
+ Str,
76
+ ...
77
+ ]
78
+ }
79
+ {
80
+ ...
81
+ }
82
+ ...
83
+ ```
84
+
85
+ (As OpenCompass currently does not support multi-modal evaluation, the template above is for reference only.)
86
+
87
+ When ChatMLDataset reads `.json` files, it uses `pydantic` to perform simple format validation on them.
88
+ You can use `tools/chatml_format_test.py` to check your provided data file.
89
+
90
+ After format checking, please add a config dictionary named `chatml_datasets` in your running config file to convert the data file into an OpenCompass dataset at runtime.
91
+ An example is as follows:
92
+
93
+ ```python
94
+ chatml_datasets = [
95
+ dict(
96
+ abbr='YOUR_DATASET_NAME',
97
+ path='YOUR_DATASET_PATH',
98
+ evaluator=dict(
99
+ type='cascade_evaluator',
100
+ rule_evaluator=dict(
101
+ type='math_evaluator',
102
+ ),
103
+ llm_evaluator=dict(
104
+ type='llm_evaluator',
105
+ prompt="YOUR_JUDGE_PROMPT",
106
+ judge_cfg=dict(), # YOUR Judge Model Config
107
+ )
108
+ ),
109
+ n=1, # Repeat Number
110
+ ),
111
+ ]
112
+ ```
113
+
114
+ The ChatML evaluation module currently provides four preset evaluators: `mcq_rule_evaluator` (used for MCQ evaluation), `math_evaluator` (used for LaTeX mathematical formula evaluation), `llm_evaluator` (used for evaluating answers that are open-ended or difficult to extract), and `cascade_evaluator`, an evaluation mode composed of rule and LLM evaluators cascaded together.
115
+
116
+ In addition, if you have a long-term need to use datasets based on ChatML templates, you can contribute your dataset config to `opencompass/config/chatml_datasets`.
117
+ An eval example of calling these dataset configs is provided in `examples/evalchat_datasets.py`.
118
+
119
+ ## Data Format Protocol and Fast Evaluation Based on CustomDataset
120
+
121
+ (This module is no longer being updated, but it can still be used if there is a need for quick CLI-based evaluation.)
122
+
123
+ This module supports two types of tasks: multiple choice (`mcq`) and question & answer (`qa`). For `mcq`, both ppl and gen inferences are supported; for `qa`, gen inference is supported.
124
+
125
+ ### Dataset Format
126
+
127
+ We support datasets in both `.jsonl` and `.csv` formats.
128
+
129
+ #### Multiple Choice (`mcq`)
130
+
131
+ For `mcq` datasets, the default fields are as follows:
132
+
133
+ - `question`: The stem of the multiple-choice question.
134
+ - `A`, `B`, `C`, ...: Single uppercase letters representing the options, with no limit on the number. Defaults to parsing consecutive letters starting from `A` as options.
135
+ - `answer`: The correct answer to the multiple-choice question, which must be one of the options used above, such as `A`, `B`, `C`, etc.
136
+
137
+ Non-default fields will be read in but are not used by default. To use them, specify in the `.meta.json` file.
138
+
139
+ An example of the `.jsonl` format:
140
+
141
+ ```jsonl
142
+ {"question": "165+833+650+615=", "A": "2258", "B": "2263", "C": "2281", "answer": "B"}
143
+ {"question": "368+959+918+653+978=", "A": "3876", "B": "3878", "C": "3880", "answer": "A"}
144
+ {"question": "776+208+589+882+571+996+515+726=", "A": "5213", "B": "5263", "C": "5383", "answer": "B"}
145
+ {"question": "803+862+815+100+409+758+262+169=", "A": "4098", "B": "4128", "C": "4178", "answer": "C"}
146
+ ```
147
+
148
+ An example of the `.csv` format:
149
+
150
+ ```csv
151
+ question,A,B,C,answer
152
+ 127+545+588+620+556+199=,2632,2635,2645,B
153
+ 735+603+102+335+605=,2376,2380,2410,B
154
+ 506+346+920+451+910+142+659+850=,4766,4774,4784,C
155
+ 504+811+870+445=,2615,2630,2750,B
156
+ ```
157
+
158
+ #### Question & Answer (`qa`)
159
+
160
+ For `qa` datasets, the default fields are as follows:
161
+
162
+ - `question`: The stem of the question & answer question.
163
+ - `answer`: The correct answer to the question & answer question. It can be missing, indicating the dataset has no correct answer.
164
+
165
+ Non-default fields will be read in but are not used by default. To use them, specify in the `.meta.json` file.
166
+
167
+ An example of the `.jsonl` format:
168
+
169
+ ```jsonl
170
+ {"question": "752+361+181+933+235+986=", "answer": "3448"}
171
+ {"question": "712+165+223+711=", "answer": "1811"}
172
+ {"question": "921+975+888+539=", "answer": "3323"}
173
+ {"question": "752+321+388+643+568+982+468+397=", "answer": "4519"}
174
+ ```
175
+
176
+ An example of the `.csv` format:
177
+
178
+ ```csv
179
+ question,answer
180
+ 123+147+874+850+915+163+291+604=,3967
181
+ 149+646+241+898+822+386=,3142
182
+ 332+424+582+962+735+798+653+214=,4700
183
+ 649+215+412+495+220+738+989+452=,4170
184
+ ```
185
+
186
+ ### Command Line List
187
+
188
+ Custom datasets can be directly called for evaluation through the command line.
189
+
190
+ ```bash
191
+ python run.py \
192
+ --models hf_llama2_7b \
193
+ --custom-dataset-path xxx/test_mcq.csv \
194
+ --custom-dataset-data-type mcq \
195
+ --custom-dataset-infer-method ppl
196
+ ```
197
+
198
+ ```bash
199
+ python run.py \
200
+ --models hf_llama2_7b \
201
+ --custom-dataset-path xxx/test_qa.jsonl \
202
+ --custom-dataset-data-type qa \
203
+ --custom-dataset-infer-method gen
204
+ ```
205
+
206
+ In most cases, `--custom-dataset-data-type` and `--custom-dataset-infer-method` can be omitted. OpenCompass will set them based on the following logic:
207
+
208
+
209
+
210
+ - If options like `A`, `B`, `C`, etc., can be parsed from the dataset file, it is considered an `mcq` dataset; otherwise, it is considered a `qa` dataset.
211
+ - The default `infer_method` is `gen`.
212
+
213
+ ### Configuration File
214
+
215
+ In the original configuration file, simply add a new item to the `datasets` variable. Custom datasets can be mixed with regular datasets.
216
+
217
+ ```python
218
+ datasets = [
219
+ {"path": "xxx/test_mcq.csv", "data_type": "mcq", "infer_method": "ppl"},
220
+ {"path": "xxx/test_qa.jsonl", "data_type": "qa", "infer_method": "gen"},
221
+ ]
222
+ ```
223
+
224
+ ### Supplemental Information for Dataset `.meta.json`
225
+
226
+ OpenCompass will try to parse the input dataset file by default, so in most cases, the `.meta.json` file is **not necessary**. However, if the dataset field names are not the default ones, or custom prompt words are required, it should be specified in the `.meta.json` file.
227
+
228
+ The file is placed in the same directory as the dataset, with the filename followed by `.meta.json`. An example file structure is as follows:
229
+
230
+ ```tree
231
+ .
232
+ ├── test_mcq.csv
233
+ ├── test_mcq.csv.meta.json
234
+ ├── test_qa.jsonl
235
+ └── test_qa.jsonl.meta.json
236
+ ```
237
+
238
+ Possible fields in this file include:
239
+
240
+ - `abbr` (str): Abbreviation of the dataset, serving as its ID.
241
+ - `data_type` (str): Type of dataset, options are `mcq` and `qa`.
242
+ - `infer_method` (str): Inference method, options are `ppl` and `gen`.
243
+ - `human_prompt` (str): User prompt template for generating prompts. Variables in the template are enclosed in `{}`, like `{question}`, `{opt1}`, etc. If `template` exists, this field will be ignored.
244
+ - `bot_prompt` (str): Bot prompt template for generating prompts. Variables in the template are enclosed in `{}`, like `{answer}`, etc. If `template` exists, this field will be ignored.
245
+ - `template` (str or dict): Question template for generating prompts. Variables in the template are enclosed in `{}`, like `{question}`, `{opt1}`, etc. The relevant syntax is in [here](../prompt/prompt_template.md) regarding `infer_cfg['prompt_template']['template']`.
246
+ - `input_columns` (list): List of input fields for reading data.
247
+ - `output_column` (str): Output field for reading data.
248
+ - `options` (list): List of options for reading data, valid only when `data_type` is `mcq`.
249
+
250
+ For example:
251
+
252
+ ```json
253
+ {
254
+ "human_prompt": "Question: 127 + 545 + 588 + 620 + 556 + 199 =\nA. 2632\nB. 2635\nC. 2645\nAnswer: Let's think step by step, 127 + 545 + 588 + 620 + 556 + 199 = 672 + 588 + 620 + 556 + 199 = 1260 + 620 + 556 + 199 = 1880 + 556 + 199 = 2436 + 199 = 2635. So the answer is B.\nQuestion: {question}\nA. {A}\nB. {B}\nC. {C}\nAnswer: ",
255
+ "bot_prompt": "{answer}"
256
+ }
257
+ ```
258
+
259
+ or
260
+
261
+ ```json
262
+ {
263
+ "template": "Question: {my_question}\nX. {X}\nY. {Y}\nZ. {Z}\nW. {W}\nAnswer:",
264
+ "input_columns": ["my_question", "X", "Y", "Z", "W"],
265
+ "output_column": "my_answer",
266
+ }
267
+ ```
docs/en/advanced_guides/evaluation_lightllm.md ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Evaluation with Lightllm
2
+
3
+ We now support the evaluation of large language models using [Lightllm](https://github.com/ModelTC/lightllm) for inference. Developed by SenseTime, LightLLM is a Python-based LLM (Large Language Model) inference and serving framework, notable for its lightweight design, easy scalability, and high-speed performance. Lightllm provides support for various large language models, allowing users to perform model inference through Lightllm, locally deploying it as a service. During the evaluation process, OpenCompass feeds data to Lightllm through an API and processes the response. OpenCompass has been adapted for compatibility with Lightllm, and this tutorial will guide you on using OpenCompass to evaluate models with Lightllm as the inference backend.
4
+
5
+ ## Setup
6
+
7
+ ### Install OpenCompass
8
+
9
+ Please follow the [instructions](https://opencompass.readthedocs.io/en/latest/get_started/installation.html) to install the OpenCompass and prepare the evaluation datasets.
10
+
11
+ ### Install Lightllm
12
+
13
+ Please follow the [Lightllm homepage](https://github.com/ModelTC/lightllm) to install the Lightllm. Pay attention to aligning the versions of relevant dependencies, especially the version of the Transformers.
14
+
15
+ ## Evaluation
16
+
17
+ We use the evaluation of Humaneval with the llama2-7B model as an example.
18
+
19
+ ### Step-1: Deploy the model locally as a service using Lightllm.
20
+
21
+ ```shell
22
+ python -m lightllm.server.api_server --model_dir /path/llama2-7B \
23
+ --host 0.0.0.0 \
24
+ --port 1030 \
25
+ --nccl_port 2066 \
26
+ --max_req_input_len 4096 \
27
+ --max_req_total_len 6144 \
28
+ --tp 1 \
29
+ --trust_remote_code \
30
+ --max_total_token_num 120000
31
+ ```
32
+
33
+ **Note:** tp can be configured to enable TensorParallel inference on several GPUs, suitable for the inference of very large models.
34
+
35
+ **Note:** The max_total_token_num in the above command will affect the throughput performance during testing. It can be configured according to the documentation on the [Lightllm homepage](https://github.com/ModelTC/lightllm). As long as it does not run out of memory, it is often better to set it as high as possible.
36
+
37
+ **Note:** If you want to start multiple LightLLM services on the same machine, you need to reconfigure the above port and nccl_port.
38
+
39
+ You can use the following Python script to quickly test whether the current service has been successfully started.
40
+
41
+ ```python
42
+ import time
43
+ import requests
44
+ import json
45
+
46
+ url = 'http://localhost:1030/generate'
47
+ headers = {'Content-Type': 'application/json'}
48
+ data = {
49
+ 'inputs': 'What is AI?',
50
+ "parameters": {
51
+ 'do_sample': False,
52
+ 'ignore_eos': False,
53
+ 'max_new_tokens': 1024,
54
+ }
55
+ }
56
+ response = requests.post(url, headers=headers, data=json.dumps(data))
57
+ if response.status_code == 200:
58
+ print(response.json())
59
+ else:
60
+ print('Error:', response.status_code, response.text)
61
+ ```
62
+
63
+ ### Step-2: Evaluate the above model using OpenCompass.
64
+
65
+ ```shell
66
+ python run.py examples/eval_lightllm.py
67
+ ```
68
+
69
+ You are expected to get the evaluation results after the inference and evaluation.
70
+
71
+ **Note:** In `eval_lightllm.py`, please align the configured URL with the service address from the previous step.
docs/en/advanced_guides/evaluation_lmdeploy.md ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Evaluation with LMDeploy
2
+
3
+ We now support evaluation of models accelerated by the [LMDeploy](https://github.com/InternLM/lmdeploy). LMDeploy is a toolkit designed for compressing, deploying, and serving LLM. It has a remarkable inference performance. We now illustrate how to evaluate a model with the support of LMDeploy in OpenCompass.
4
+
5
+ ## Setup
6
+
7
+ ### Install OpenCompass
8
+
9
+ Please follow the [instructions](https://opencompass.readthedocs.io/en/latest/get_started/installation.html) to install the OpenCompass and prepare the evaluation datasets.
10
+
11
+ ### Install LMDeploy
12
+
13
+ Install lmdeploy via pip (python 3.8+)
14
+
15
+ ```shell
16
+ pip install lmdeploy
17
+ ```
18
+
19
+ The default prebuilt package is compiled on CUDA 12. However, if CUDA 11+ is required, you can install lmdeploy by:
20
+
21
+ ```shell
22
+ export LMDEPLOY_VERSION=0.6.0
23
+ export PYTHON_VERSION=310
24
+ pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118
25
+ ```
26
+
27
+ ## Evaluation
28
+
29
+ When evaluating a model, it is necessary to prepare an evaluation configuration that specifies information such as the evaluation dataset, the model, and inference parameters.
30
+
31
+ Taking [internlm2-chat-7b](https://huggingface.co/internlm/internlm2-chat-7b) as an example, the evaluation config is as follows:
32
+
33
+ ```python
34
+ # configure the dataset
35
+ from mmengine.config import read_base
36
+
37
+
38
+ with read_base():
39
+ # choose a list of datasets
40
+ from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
41
+ from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
42
+ from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
43
+ from opencompass.configs.datasets.gsm8k.gsm8k_0shot_v2_gen_a58960 import \
44
+ gsm8k_datasets
45
+ # and output the results in a chosen format
46
+ from .summarizers.medium import summarizer
47
+
48
+ datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
49
+
50
+ # configure lmdeploy
51
+ from opencompass.models import TurboMindModelwithChatTemplate
52
+
53
+
54
+
55
+ # configure the model
56
+ models = [
57
+ dict(
58
+ type=TurboMindModelwithChatTemplate,
59
+ abbr='internlm2-chat-7b-lmdeploy',
60
+ # model path, which can be the address of a model repository on the Hugging Face Hub or a local path
61
+ path='internlm/internlm2-chat-7b',
62
+ # inference backend of LMDeploy. It can be either 'turbomind' or 'pytorch'.
63
+ # If the model is not supported by 'turbomind', it will fallback to
64
+ # 'pytorch'
65
+ backend='turbomind',
66
+ # For the detailed engine config and generation config, please refer to
67
+ # https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/messages.py
68
+ engine_config=dict(tp=1),
69
+ gen_config=dict(do_sample=False),
70
+ # the max size of the context window
71
+ max_seq_len=7168,
72
+ # the max number of new tokens
73
+ max_out_len=1024,
74
+ # the max number of prompts that LMDeploy receives
75
+ # in `generate` function
76
+ batch_size=5000,
77
+ run_cfg=dict(num_gpus=1),
78
+ )
79
+ ]
80
+ ```
81
+
82
+ Place the aforementioned configuration in a file, such as "configs/eval_internlm2_lmdeploy.py". Then, in the home folder of OpenCompass, start evaluation by the following command:
83
+
84
+ ```shell
85
+ python run.py configs/eval_internlm2_lmdeploy.py -w outputs
86
+ ```
87
+
88
+ You are expected to get the evaluation results after the inference and evaluation.
docs/en/advanced_guides/llm_judge.md ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LLM as Judge Evaluation
2
+
3
+ ## Introduction
4
+
5
+ The GenericLLMEvaluator is particularly useful for scenarios where rule-based methods (like regular expressions) cannot perfectly judge outputs, such as:
6
+
7
+ - Cases where models output answer content without option identifiers
8
+ - Factual judgment datasets that are difficult to evaluate with rules
9
+ - Open-ended responses requiring complex understanding and reasoning
10
+ - Evaluation that requires a lot of rules to be designed
11
+
12
+ OpenCompass provides the GenericLLMEvaluator component to facilitate LLM-as-judge evaluations.
13
+
14
+ ## Dataset Format
15
+
16
+ The dataset for LLM judge evaluation should be in either JSON Lines (.jsonl) or CSV format. Each entry should contain at least:
17
+
18
+ - A problem or question
19
+ - A reference answer or gold standard
20
+ - (The model's prediction will be generated during evaluation)
21
+
22
+ Example JSONL format:
23
+
24
+ ```json
25
+ {"problem": "What is the capital of France?", "answer": "Paris"}
26
+ ```
27
+
28
+ Example CSV format:
29
+
30
+ ```csv
31
+ problem,answer
32
+ "What is the capital of France?","Paris"
33
+ ```
34
+
35
+ ## Configuration
36
+
37
+ ### Using LLM for Evaluation via Command Line
38
+
39
+ Some datasets in OpenCompass already include LLM judge configurations.
40
+ You need to use a model service (such as OpenAI or DeepSeek's official API) or start a model service locally using tools like LMDeploy, vLLM, or SGLang.
41
+
42
+ Then, you can set the environment variables for the evaluation service and evaluate models using the following commands:
43
+
44
+ ```bash
45
+ export OC_JUDGE_MODEL=Qwen/Qwen2.5-32B-Instruct
46
+ export OC_JUDGE_API_KEY=sk-1234
47
+ export OC_JUDGE_API_BASE=http://172.30.56.1:4000/v1
48
+ ```
49
+
50
+ Note that by default, OpenCompass will use these three environment variables, but if you use configuration files to configure the evaluation service, these environment variables will not take effect.
51
+
52
+ ### Using LLM for Evaluation via Configuration Files
53
+
54
+ To set up an LLM judge evaluation, you'll need to configure three main components:
55
+
56
+ 1. Dataset Reader Configuration
57
+
58
+ ```python
59
+ reader_cfg = dict(
60
+ input_columns=['problem'], # Column name for the question
61
+ output_column='answer' # Column name for the reference answer
62
+ )
63
+ ```
64
+
65
+ 2. Inference Configuration
66
+
67
+ ```python
68
+ infer_cfg = dict(
69
+ prompt_template=dict(
70
+ type=PromptTemplate,
71
+ template=dict(
72
+ round=[
73
+ dict(
74
+ role='HUMAN',
75
+ prompt='{problem}', # Template for prompting the model
76
+ ),
77
+ ]
78
+ ),
79
+ ),
80
+ retriever=dict(type=ZeroRetriever),
81
+ inferencer=dict(type=GenInferencer),
82
+ )
83
+ ```
84
+
85
+ 3. Evaluation Configuration with LLM Judge
86
+
87
+ ```python
88
+ eval_cfg = dict(
89
+ evaluator=dict(
90
+ type=GenericLLMEvaluator, # Using LLM as evaluator
91
+ prompt_template=dict(
92
+ type=PromptTemplate,
93
+ template=dict(
94
+ begin=[
95
+ dict(
96
+ role='SYSTEM',
97
+ fallback_role='HUMAN',
98
+ prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.",
99
+ )
100
+ ],
101
+ round=[
102
+ dict(role='HUMAN', prompt=YOUR_JUDGE_TEMPLATE), # Template for the judge
103
+ ],
104
+ ),
105
+ ),
106
+ dataset_cfg=dict(
107
+ type=CustomDataset,
108
+ path='path/to/your/dataset',
109
+ file_name='your_dataset.jsonl',
110
+ reader_cfg=reader_cfg,
111
+ ),
112
+ judge_cfg=YOUR_JUDGE_MODEL_CONFIG, # Configuration for the judge model
113
+ dict_postprocessor=dict(type=generic_llmjudge_postprocess), # Post-processing the judge's output
114
+ ),
115
+ )
116
+ ```
117
+
118
+ ## Using CustomDataset with GenericLLMEvaluator
119
+
120
+ Here's how to set up a complete configuration for LLM judge evaluation:
121
+
122
+ ```python
123
+ from mmengine.config import read_base
124
+ from opencompass.models import TurboMindModelwithChatTemplate
125
+ from opencompass.datasets import CustomDataset
126
+ from opencompass.evaluator import GenericLLMEvaluator
127
+ from opencompass.datasets import generic_llmjudge_postprocess
128
+ from opencompass.openicl.icl_prompt_template import PromptTemplate
129
+ from opencompass.openicl.icl_retriever import ZeroRetriever
130
+ from opencompass.openicl.icl_inferencer import GenInferencer
131
+
132
+ # Import your judge model configuration
133
+ with read_base():
134
+ from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_14b_instruct import (
135
+ models as judge_model,
136
+ )
137
+
138
+ # Define your judge template
139
+ JUDGE_TEMPLATE = """
140
+ Please evaluate whether the following response correctly answers the question.
141
+ Question: {problem}
142
+ Reference Answer: {answer}
143
+ Model Response: {prediction}
144
+
145
+ Is the model response correct? If correct, answer "A"; if incorrect, answer "B".
146
+ """.strip()
147
+
148
+ # Dataset reader configuration
149
+ reader_cfg = dict(input_columns=['problem'], output_column='answer')
150
+
151
+ # Inference configuration for the model being evaluated
152
+ infer_cfg = dict(
153
+ prompt_template=dict(
154
+ type=PromptTemplate,
155
+ template=dict(
156
+ round=[
157
+ dict(
158
+ role='HUMAN',
159
+ prompt='{problem}',
160
+ ),
161
+ ]
162
+ ),
163
+ ),
164
+ retriever=dict(type=ZeroRetriever),
165
+ inferencer=dict(type=GenInferencer),
166
+ )
167
+
168
+ # Evaluation configuration with LLM judge
169
+ eval_cfg = dict(
170
+ evaluator=dict(
171
+ type=GenericLLMEvaluator,
172
+ prompt_template=dict(
173
+ type=PromptTemplate,
174
+ template=dict(
175
+ begin=[
176
+ dict(
177
+ role='SYSTEM',
178
+ fallback_role='HUMAN',
179
+ prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.",
180
+ )
181
+ ],
182
+ round=[
183
+ dict(role='HUMAN', prompt=JUDGE_TEMPLATE),
184
+ ],
185
+ ),
186
+ ),
187
+ dataset_cfg=dict(
188
+ type=CustomDataset,
189
+ path='path/to/your/dataset',
190
+ file_name='your_dataset.jsonl',
191
+ reader_cfg=reader_cfg,
192
+ ),
193
+ judge_cfg=judge_model[0],
194
+ dict_postprocessor=dict(type=generic_llmjudge_postprocess),
195
+ ),
196
+ pred_role='BOT',
197
+ )
198
+
199
+ # Dataset configuration
200
+ datasets = [
201
+ dict(
202
+ type=CustomDataset,
203
+ abbr='my-dataset',
204
+ path='path/to/your/dataset',
205
+ file_name='your_dataset.jsonl',
206
+ reader_cfg=reader_cfg,
207
+ infer_cfg=infer_cfg,
208
+ eval_cfg=eval_cfg,
209
+ )
210
+ ]
211
+
212
+ # Model configuration for the model being evaluated
213
+ models = [
214
+ dict(
215
+ type=TurboMindModelwithChatTemplate,
216
+ abbr='model-to-evaluate',
217
+ path='path/to/your/model',
218
+ # ... other model configurations
219
+ )
220
+ ]
221
+
222
+ # Output directory
223
+ work_dir = './outputs/llm_judge_eval'
224
+ ```
225
+
226
+ ## GenericLLMEvaluator
227
+
228
+ The GenericLLMEvaluator is designed to use an LLM as a judge for evaluating model outputs. Key features include:
229
+
230
+ 1. Flexible prompt templates for instructing the judge
231
+ 2. Support for various judge models (local or API-based)
232
+ 3. Customizable evaluation criteria through prompt engineering
233
+ 4. Post-processing of judge outputs to extract structured evaluations
234
+
235
+ **Important Note**: The current generic version of the judge template only supports outputs in the format of "A" (correct) or "B" (incorrect), and does not support other output formats (like "CORRECT" or "INCORRECT"). This is because the post-processing function `generic_llmjudge_postprocess` is specifically designed to parse this format.
236
+
237
+ The evaluator works by:
238
+
239
+ 1. Taking the original problem, reference answer, and model prediction
240
+ 2. Formatting them into a prompt for the judge model
241
+ 3. Parsing the judge's response to determine the evaluation result (looking for "A" or "B")
242
+ 4. Aggregating results across the dataset
243
+
244
+ If you would like to see the full details of evaluation results, you can add `--dump-eval-details` to the command line when you start the job.
245
+ Example evaluation output:
246
+
247
+ ```python
248
+ {
249
+ 'accuracy': 75.0, # Percentage of responses judged as correct
250
+ 'details': [
251
+ {
252
+ 'origin_prompt': """
253
+ Please evaluate whether the following response correctly answers the question.
254
+ Question: What is the capital of France?
255
+ Reference Answer: Paris
256
+ Model Response: Paris
257
+ Is the model response correct? If correct, answer "A"; if incorrect, answer "B".
258
+ """,
259
+ 'gold': 'Paris',
260
+ 'prediction': 'A',
261
+ },
262
+ # ... more results
263
+ ]
264
+ }
265
+ ```
266
+
267
+ ## CascadeEvaluator
268
+
269
+ OpenCompass also provides a CascadeEvaluator that combines the strengths of rule-based evaluation and LLM-based evaluation. The cascade evaluator has two modes:
270
+
271
+ 1. **Cascade Mode (parallel=False)**: First evaluates all samples with a rule-based evaluator, then only sends samples that were deemed incorrect by the rule-based evaluation to an LLM judge for re-evaluation. This approach reduces reliance on LLM judgments while maintaining accuracy, thus lowering evaluation costs and time.
272
+
273
+ 2. **Parallel Mode (parallel=True)**: Evaluates all samples with both the rule-based evaluator and LLM judge, then considers a sample correct if either method marks it as correct. This approach can increase the leniency of evaluation but may result in higher costs since all samples require LLM evaluation.
274
+
275
+ ### Configuring CascadeEvaluator
276
+
277
+ Here's an example of how to configure the CascadeEvaluator:
278
+
279
+ ```python
280
+ # Define a rule-based evaluator
281
+ rule_evaluator = dict(type=MATHVerifyEvaluator)
282
+
283
+ # Define an LLM judge evaluator
284
+ llm_judge_evaluator = dict(
285
+ type=GenericLLMEvaluator,
286
+ prompt_template=dict(
287
+ type=PromptTemplate,
288
+ template=dict(
289
+ begin=[
290
+ dict(
291
+ role='SYSTEM',
292
+ fallback_role='HUMAN',
293
+ prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.",
294
+ )
295
+ ],
296
+ round=[
297
+ dict(role='HUMAN', prompt=YOUR_JUDGE_TEMPLATE),
298
+ ],
299
+ ),
300
+ ),
301
+ dataset_cfg=dict(
302
+ type=YourDataset,
303
+ path='path/to/your/dataset',
304
+ reader_cfg=reader_cfg,
305
+ ),
306
+ judge_cfg=dict(), # Can use environment variables to configure the judge model
307
+ )
308
+
309
+ # Configure cascade evaluator (cascade mode)
310
+ cascade_evaluator = dict(
311
+ type=CascadeEvaluator,
312
+ llm_evaluator=llm_judge_evaluator,
313
+ rule_evaluator=rule_evaluator,
314
+ parallel=False # Cascade mode
315
+ )
316
+
317
+ # For parallel mode, set parallel=True
318
+ parallel_evaluator = dict(
319
+ type=CascadeEvaluator,
320
+ llm_evaluator=llm_judge_evaluator,
321
+ rule_evaluator=rule_evaluator,
322
+ parallel=True # Parallel mode
323
+ )
324
+
325
+ # Use the cascade evaluator in your dataset evaluation config
326
+ eval_cfg = dict(evaluator=cascade_evaluator)
327
+ ```
328
+
329
+ ### Evaluation Results
330
+
331
+ The cascade evaluator outputs detailed evaluation statistics including:
332
+
333
+ - Accuracy of the rule-based evaluation
334
+ - Accuracy of the LLM evaluation (for samples that failed rule-based evaluation in cascade mode)
335
+ - Final combined accuracy
336
+
337
+ Example output:
338
+
339
+ ```python
340
+ {
341
+ 'accuracy': 85.0, # Final accuracy
342
+ 'cascade_stats': {
343
+ 'total_samples': 100,
344
+ 'rule_correct': 70, # Number of samples correct by rule evaluation
345
+ 'rule_accuracy': 70.0, # Accuracy of rule evaluation
346
+ 'llm_evaluated': 30, # Number of samples evaluated by LLM (failed samples in cascade mode)
347
+ 'llm_correct': 15, # Number of samples correct by LLM evaluation
348
+ 'llm_accuracy': 50.0, # Accuracy of LLM evaluation
349
+ 'final_correct': 85, # Total correct samples
350
+ 'final_accuracy': 85.0, # Final accuracy
351
+ 'parallel_mode': False, # Whether parallel mode was used
352
+ },
353
+ 'details': [
354
+ # Detailed evaluation results for each sample
355
+ ]
356
+ }
357
+ ```
358
+
359
+ The cascade evaluator is particularly useful for:
360
+
361
+ 1. Scenarios that require balancing evaluation cost and accuracy
362
+ 2. Cases where rule-based evaluators are available but might not be comprehensive
363
+ 3. Evaluation tasks that need more nuanced judgment for edge cases
364
+
365
+ ## Complete Example
366
+
367
+ For a complete working example using GenericLLMEvaluator, refer to the `eval_llm_judge.py` file in the examples directory, which demonstrates how to evaluate mathematical problem-solving.
368
+
369
+
370
+ For a complete working example using CascadeEvaluator, refer to the `eval_cascade_evaluator.py` file in the examples directory, which demonstrates how to evaluate mathematical problem-solving.
docs/en/advanced_guides/longeval.md ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Long Context Evaluation Guidance
2
+
3
+ ## Introduction
4
+
5
+ Although large-scale language models (LLMs) such as GPT-4 have demonstrated significant advantages in handling natural language tasks, most current open-source models can only handle texts with a length of a few thousand tokens, which limits their ability to process long contexts such as reading books and writing text summaries. To explore the performance of models in dealing with long contexts, we use the [L-Eval](https://github.com/OpenLMLab/LEval) and [LongBench](https://github.com/THUDM/LongBench) datasets to test the model's ability to handle long contexts.
6
+
7
+ ## Existing Algorithms and models
8
+
9
+ When dealing with long context inputs, the two main challenges faced by large models are the inference time cost and catastrophic forgetting. Recently, a large amount of research has been devoted to extending the model length, focusing on three improvement directions:
10
+
11
+ - Attention mechanisms. The ultimate goal of these methods is to reduce the computation cost of query-key pairs, but they may affect the performance of downstream tasks.
12
+ - Input methods. Some studies divide long context inputs into chunks or retrieve pre-existing text segments to enhance the model's ability to handle long contexts, but these methods are only effective for some tasks and are difficult to adapt to multiple downstream tasks.
13
+ - Position encoding. This research includes RoPE, ALiBi, Position Interpolation etc., which have shown good results in length extrapolation. These methods have been used to train long context models such as ChatGLM2-6B-32k and LongChat-32k.
14
+
15
+ First, we introduce some popular position encoding algorithms.
16
+
17
+ ### RoPE
18
+
19
+ RoPE is a type of positional embedding that injects the information of position in Transformer. It encodes the absolute position with a rotation matrix and meanwhile incorporates the explicit relative position dependency in self-attention formulation. A graphic illustration of RoPE is shown below.
20
+
21
+ <div align="center">
22
+ <img src=https://github.com/open-compass/opencompass/assets/75252858/08c57958-0dcb-40d7-b91b-33f20ca2d89f>
23
+ </div>
24
+
25
+ RoPE comes with valuable properties such as the flexibility of being expanded to any sequence length, decaying inter-token dependency with increasing relative distances, and the capability of equipping the linear self-attention with relative position encoding.
26
+
27
+ RoPE is adopted in many LLMs including LLaMA, LLaMA 2 and Vicuna-7b-v1.5-16k.
28
+
29
+ ### ALiBi
30
+
31
+ Though RoPE and other alternatives to the original sinusoidal position method (like T5 bias) have improved extrapolation, they are considerably slower than the sinusoidal approach and use extra memory and parameters. Therefore, Attention with Linear Biases (ALiBi) is introduced to facilitate efficient extrapolation.
32
+
33
+ For an input subsequence of length L, the attention sublayer computes the attention scores for the ith query
34
+
35
+ ```{math}
36
+ q_{i} \in R^{1 \times d}, (1 \leq i \leq L)
37
+ ```
38
+
39
+ in each head, given the first i keys
40
+
41
+ ```{math}
42
+ K \in R^{i \times d}
43
+ ```
44
+
45
+ where d is the head dimension.
46
+
47
+ ```{math}
48
+ softmax(q_{i}K^{T})
49
+ ```
50
+
51
+ ALiBi negatively biases attention scores with a linearly decreasing penalty proportional to the distance between the relevant key and query. The only modification it applies is after the query-key dot product, where it adds a static, non-learned bias.
52
+
53
+ ```{math}
54
+ softmax(q_{i}K^{T}+m\cdot[-(i-1),...,-2,-1,0])
55
+ ```
56
+
57
+ where scalar m is a head-specific slope fixed before training.
58
+
59
+ ALiBi eliminates position embeddings and it is as fast as the sinusoidal approach. It is used in LLMs including mpt-7b-storywriter, which is prepared to handle extremely long inputs.
60
+
61
+ ### Position Interpolation(PI)
62
+
63
+ Many existing pre-trained LLMs including LLaMA use positional encodings that have weak extrapolation properties(e.g. RoPE). Position Interpolation is proposed and it can easily enable very long context windows while preserving model quality relatively well for the tasks within its original context window size.
64
+
65
+ The key idea of Position Interpolation is to directly down-scale the position indices so that the maximum position index matches the previous context window limit in the pre-training stage. In other words, to accommodate more input tokens, the algorithm interpolates position encodings at neighboring integer positions, utilizing the fact that position encodings can be applied on non-integer positions, as opposed to extrapolating outside the trained positions, which may lead to catastrophic values. The algorithm requires only a very short period of fine-tuning for the model to fully adapt to greatly extended context windows.
66
+
67
+ An illustration of Position Interpolation method is shown below. Lower left illustrates Position Interpolation where it downscales the position indices (blue and green dots) themselves from \[0, 4096\] to \[0, 2048\] to force them to reside in the pretrained range.
68
+
69
+ <div align="center">
70
+ <img src=https://github.com/open-compass/opencompass/assets/75252858/406454ba-a811-4c66-abbe-3a5528947257>
71
+ </div>
72
+
73
+ Position Interpolation empowers ChatGLM2-6B-32k, a model based on ChatGLM2-6B, to deal with a 32k context window size.
74
+
75
+ Next, we introduce some long context language models we evaluate.
76
+
77
+ ### XGen-7B-8k
78
+
79
+ XGen-7B-8k is trained with standard dense attention on up to 8k sequence length for up to 1.5T tokens. To mitigate slow training, XGen-7B-8k introduces training in stages with increasing sequence length. First, 800B tokens with sequence length of 2k tokens are observed, then 400B tokens with 4k, finally, 300B tokens with 8k length.
80
+
81
+ ### Vicuna-7b-v1.5-16k
82
+
83
+ Vicuna-7b-v1.5-16k is fine-tuned from LLaMA 2 with supervised instruction fine-tuning and linear RoPE scaling. The training data is around 125K conversations collected from ShareGPT, a website where users can share their ChatGPT conversation. These conversations are packed into sequences that contain 16k tokens each.
84
+
85
+ ### LongChat-7b-v1.5-32k
86
+
87
+ LongChat-7b-v1.5-32k is fine-tuned from LLaMA 2 models, which were originally pretrained with 4k context length. The training recipe can be conceptually described in two steps. The first step is condensing RoPE. Since the LLaMA model has not observed scenarios where position_ids > 4096 during the pre-training phase, LongChat condenses position_ids > 4096 to be within 0 to 4096. The second step is fine-tuning the LongChat model on curated conversation data. In this step, the data is cleaned using the FastChat data pipeline and truncated to the maximum length of the model.
88
+
89
+ ### ChatGLM2-6B-32k
90
+
91
+ The ChatGLM2-6B-32k further strengthens the ability to understand long texts based on the ChatGLM2-6B. Based on the method of Positional Interpolation, and trained with a 32K context length during the dialogue alignment, ChatGLM2-6B-32k can better handle up to 32K context length.
92
+
93
+ ## [L-Eval](https://github.com/OpenLMLab/LEval)
94
+
95
+ L-Eval is a long context dataset built by OpenLMLab, consisting of 18 subtasks, including texts from various fields such as law, economy, and technology. The dataset consists of a total of 411 documents, over 2000 test cases, with an average document length of 7217 words. The subtasks in this dataset are divided into close-ended and open-ended categories, with 5 close-ended tasks evaluated using the exact match criterion and 13 open-ended tasks evaluated using Rouge scores.
96
+
97
+ ## [LongBench](https://github.com/THUDM/LongBench)
98
+
99
+ LongBench is a long context dataset built by THUDM, consisting of 21 subtasks with a total of 4750 test cases. This dataset is the first long context dataset that includes both English and Chinese texts, with an average English text length of 6711 words and an average Chinese text length of 13386 characters. The 21 subtasks are divided into 6 types, providing a more comprehensive evaluation of the model's capabilities in various aspects.
100
+
101
+ <div align="center">
102
+ <img src=https://github.com/open-compass/opencompass/assets/75252858/4555e937-c519-4e9c-ad8d-7370430d466a>
103
+ </div>
104
+
105
+ ## Evaluation Method
106
+
107
+ Due to the different maximum input lengths accepted by different models, in order to compare these large models more fairly, when the input length exceeds the maximum input limit of the model, we will trim the middle part of the input text to avoid missing prompt words.
108
+
109
+ ## Long Context Ability Ranking
110
+
111
+ In the LongBench and L-Eval ability rankings, we select the average ranking **(The lower the better)** of each model in the subtask as the standard. It can be seen that GPT-4 and GPT-3.5-turbo-16k still occupy a leading position in long context tasks, while models like ChatGLM2-6B-32k also show significant improvement in long context ability after position interpolation based on ChatGLM2-6B.
112
+
113
+ <div align="center">
114
+ <img src=https://github.com/open-compass/opencompass/assets/75252858/29b5ad12-d9a3-4255-be0a-f770923fe514>
115
+ <img src=https://github.com/open-compass/opencompass/assets/75252858/680b4cda-c2b1-45d1-8c33-196dee1a38f3>
116
+ </div>
117
+
118
+ The original scores are shown below.
119
+
120
+ | L-Eval | GPT-4 | GPT-3.5-turbo-16k | chatglm2-6b-32k | vicuna-7b-v1.5-16k | xgen-7b-8k | internlm-chat-7b-8k | longchat-7b-v1.5-32k | chatglm2-6b |
121
+ | ----------------- | ----- | ----------------- | --------------- | ------------------ | ---------- | ------------------- | -------------------- | ----------- |
122
+ | coursera | 61.05 | 50 | 45.35 | 26.74 | 33.72 | 40.12 | 27.91 | 38.95 |
123
+ | gsm100 | 92 | 78 | 27 | 11 | 8 | 19 | 5 | 8 |
124
+ | quality | 81.19 | 62.87 | 44.55 | 11.39 | 33.66 | 45.54 | 29.7 | 41.09 |
125
+ | tpo | 72.93 | 74.72 | 56.51 | 17.47 | 44.61 | 60.59 | 17.1 | 56.51 |
126
+ | topic_retrieval | 100 | 79.33 | 44.67 | 24.67 | 1.33 | 0 | 25.33 | 1.33 |
127
+ | | | | | | | | | |
128
+ | financialqa | 53.49 | 50.32 | 35.41 | 44.59 | 39.28 | 25.09 | 34.07 | 17.82 |
129
+ | gov_report | 50.84 | 50.48 | 42.97 | 48.17 | 38.52 | 31.29 | 36.52 | 41.88 |
130
+ | legal_contract_qa | 31.23 | 27.97 | 34.21 | 24.25 | 21.36 | 19.28 | 13.32 | 17.59 |
131
+ | meeting_summ | 31.44 | 33.54 | 29.13 | 28.52 | 27.96 | 17.56 | 22.32 | 15.98 |
132
+ | multidocqa | 37.81 | 35.84 | 28.6 | 26.88 | 24.41 | 22.43 | 21.85 | 19.66 |
133
+ | narrativeqa | 25.87 | 25.73 | 18.24 | 20.58 | 16.87 | 13.81 | 16.87 | 1.16 |
134
+ | nq | 67.36 | 66.91 | 41.06 | 36.44 | 29.43 | 16.42 | 35.02 | 0.92 |
135
+ | news_summ | 34.52 | 40.41 | 32.72 | 33.98 | 26.87 | 22.48 | 30.33 | 29.51 |
136
+ | paper_assistant | 42.26 | 41.76 | 34.59 | 35.83 | 25.39 | 28.25 | 30.42 | 30.43 |
137
+ | patent_summ | 48.61 | 50.62 | 46.04 | 48.87 | 46.53 | 30.3 | 41.6 | 41.25 |
138
+ | review_summ | 31.98 | 33.37 | 21.88 | 29.21 | 26.85 | 16.61 | 20.02 | 19.68 |
139
+ | scientificqa | 49.76 | 48.32 | 31.27 | 31 | 27.43 | 33.01 | 20.98 | 13.61 |
140
+ | tvshow_summ | 34.84 | 31.36 | 23.97 | 27.88 | 26.6 | 14.55 | 25.09 | 19.45 |
141
+
142
+ | LongBench | GPT-4 | GPT-3.5-turbo-16k | chatglm2-6b-32k | longchat-7b-v1.5-32k | vicuna-7b-v1.5-16k | internlm-chat-7b-8k | chatglm2-6b | xgen-7b-8k |
143
+ | ------------------- | ----- | ----------------- | --------------- | -------------------- | ------------------ | ------------------- | ----------- | ---------- |
144
+ | NarrativeQA | 31.2 | 25.79 | 19.27 | 19.19 | 23.65 | 12.24 | 13.09 | 18.85 |
145
+ | Qasper | 42.77 | 43.4 | 33.93 | 30.36 | 31.45 | 24.81 | 22.52 | 20.18 |
146
+ | MultiFieldQA-en | 55.1 | 54.35 | 45.58 | 44.6 | 43.38 | 25.41 | 38.09 | 37 |
147
+ | MultiFieldQA-zh | 64.4 | 61.92 | 52.94 | 32.35 | 44.65 | 36.13 | 37.67 | 14.7 |
148
+ | | | | | | | | | |
149
+ | HotpotQA | 59.85 | 52.49 | 46.41 | 34.43 | 34.17 | 27.42 | 27.35 | 28.78 |
150
+ | 2WikiMQA | 67.52 | 41.7 | 33.63 | 23.06 | 20.45 | 26.24 | 22.83 | 20.13 |
151
+ | Musique | 37.53 | 27.5 | 21.57 | 12.42 | 13.92 | 9.75 | 7.26 | 11.34 |
152
+ | DuReader (zh) | 38.65 | 29.37 | 38.53 | 20.25 | 20.42 | 11.11 | 17.18 | 8.57 |
153
+ | | | | | | | | | |
154
+ | GovReport | 32.09 | 29.92 | 32.47 | 29.83 | 29.27 | 18.38 | 22.86 | 23.37 |
155
+ | QMSum | 24.37 | 23.67 | 23.19 | 22.71 | 23.37 | 18.45 | 21.23 | 21.12 |
156
+ | Multi_news | 28.52 | 27.05 | 25.12 | 26.1 | 27.83 | 24.52 | 24.7 | 23.69 |
157
+ | VCSUM (zh) | 15.54 | 16.88 | 15.95 | 13.46 | 15.76 | 12.91 | 14.07 | 0.98 |
158
+ | | | | | | | | | |
159
+ | TREC | 78.5 | 73.5 | 30.96 | 29.23 | 32.06 | 39 | 24.46 | 29.31 |
160
+ | TriviaQA | 92.19 | 92.75 | 80.64 | 64.19 | 46.53 | 79.55 | 64.19 | 69.58 |
161
+ | SAMSum | 46.32 | 43.16 | 29.49 | 25.23 | 25.23 | 43.05 | 20.22 | 16.05 |
162
+ | LSHT (zh) | 41.5 | 34.5 | 22.75 | 20 | 24.75 | 20.5 | 16 | 18.67 |
163
+ | | | | | | | | | |
164
+ | Passage Count | 8.5 | 3 | 3 | 1 | 3 | 1.76 | 3 | 1 |
165
+ | PassageRetrieval-en | 75 | 73 | 57.5 | 20.5 | 16.5 | 7 | 5.5 | 12 |
166
+ | PassageRetrieval-zh | 96 | 82.5 | 58 | 15 | 21 | 2.29 | 5 | 3.75 |
167
+ | | | | | | | | | |
168
+ | LCC | 59.25 | 53.49 | 53.3 | 51.46 | 49.3 | 49.32 | 46.59 | 44.1 |
169
+ | RepoBench-P | 55.42 | 55.95 | 46.66 | 52.18 | 41.49 | 35.86 | 41.97 | 41.83 |
docs/en/advanced_guides/math_verify.md ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # General Math Evaluation Guidance
2
+
3
+ ## Introduction
4
+
5
+ Mathematical reasoning is a crucial capability for large language models (LLMs). To evaluate a model's mathematical abilities, we need to test its capability to solve mathematical problems step by step and provide accurate final answers. OpenCompass provides a convenient way to evaluate mathematical reasoning through the CustomDataset and MATHVerifyEvaluator components.
6
+
7
+ ## Dataset Format
8
+
9
+ The math evaluation dataset should be in either JSON Lines (.jsonl) or CSV format. Each problem should contain at least:
10
+
11
+ - A problem statement
12
+ - A solution/answer (typically in LaTeX format with the final answer in \\boxed{})
13
+
14
+ Example JSONL format:
15
+
16
+ ```json
17
+ {"problem": "Find the value of x if 2x + 3 = 7", "solution": "Let's solve step by step:\n2x + 3 = 7\n2x = 7 - 3\n2x = 4\nx = 2\nTherefore, \\boxed{2}"}
18
+ ```
19
+
20
+ Example CSV format:
21
+
22
+ ```csv
23
+ problem,solution
24
+ "Find the value of x if 2x + 3 = 7","Let's solve step by step:\n2x + 3 = 7\n2x = 7 - 3\n2x = 4\nx = 2\nTherefore, \\boxed{2}"
25
+ ```
26
+
27
+ ## Configuration
28
+
29
+ To evaluate mathematical reasoning, you'll need to set up three main components:
30
+
31
+ 1. Dataset Reader Configuration
32
+
33
+ ```python
34
+ math_reader_cfg = dict(
35
+ input_columns=['problem'], # Column name for the question
36
+ output_column='solution' # Column name for the answer
37
+ )
38
+ ```
39
+
40
+ 2. Inference Configuration
41
+
42
+ ```python
43
+ math_infer_cfg = dict(
44
+ prompt_template=dict(
45
+ type=PromptTemplate,
46
+ template=dict(
47
+ round=[
48
+ dict(
49
+ role='HUMAN',
50
+ prompt='{problem}\nPlease reason step by step, and put your final answer within \\boxed{}.',
51
+ ),
52
+ ]
53
+ ),
54
+ ),
55
+ retriever=dict(type=ZeroRetriever),
56
+ inferencer=dict(type=GenInferencer),
57
+ )
58
+ ```
59
+
60
+ 3. Evaluation Configuration
61
+
62
+ ```python
63
+ math_eval_cfg = dict(
64
+ evaluator=dict(type=MATHVerifyEvaluator),
65
+ )
66
+ ```
67
+
68
+ ## Using CustomDataset
69
+
70
+ Here's how to set up a complete configuration for math evaluation:
71
+
72
+ ```python
73
+ from mmengine.config import read_base
74
+ from opencompass.models import TurboMindModelwithChatTemplate
75
+ from opencompass.datasets import CustomDataset
76
+
77
+ math_datasets = [
78
+ dict(
79
+ type=CustomDataset,
80
+ abbr='my-math-dataset', # Dataset abbreviation
81
+ path='path/to/your/dataset', # Path to your dataset file
82
+ reader_cfg=math_reader_cfg,
83
+ infer_cfg=math_infer_cfg,
84
+ eval_cfg=math_eval_cfg,
85
+ )
86
+ ]
87
+ ```
88
+
89
+ ## MATHVerifyEvaluator
90
+
91
+ The MATHVerifyEvaluator is specifically designed to evaluate mathematical answers. It is developed based on the math_verify library, which provides mathematical expression parsing and verification capabilities, supporting extraction and equivalence verification for both LaTeX and general expressions.
92
+
93
+ The MATHVerifyEvaluator implements:
94
+
95
+ 1. Extracts answers from both predictions and references using LaTeX extraction
96
+ 2. Handles various LaTeX formats and environments
97
+ 3. Verifies mathematical equivalence between predicted and reference answers
98
+ 4. Provides detailed evaluation results including:
99
+ - Accuracy score
100
+ - Detailed comparison between predictions and references
101
+ - Parse results of both predicted and reference answers
102
+
103
+ The evaluator supports:
104
+
105
+ - Basic arithmetic operations
106
+ - Fractions and decimals
107
+ - Algebraic expressions
108
+ - Trigonometric functions
109
+ - Roots and exponents
110
+ - Mathematical symbols and operators
111
+
112
+ Example evaluation output:
113
+
114
+ ```python
115
+ {
116
+ 'accuracy': 85.0, # Percentage of correct answers
117
+ 'details': [
118
+ {
119
+ 'predictions': 'x = 2', # Parsed prediction
120
+ 'references': 'x = 2', # Parsed reference
121
+ 'correct': True # Whether they match
122
+ },
123
+ # ... more results
124
+ ]
125
+ }
126
+ ```
127
+
128
+ ## Complete Example
129
+
130
+ Here's a complete example of how to set up math evaluation:
131
+
132
+ ```python
133
+ from mmengine.config import read_base
134
+ from opencompass.models import TurboMindModelwithChatTemplate
135
+ from opencompass.datasets import CustomDataset
136
+ from opencompass.openicl.icl_evaluator.math_evaluator import MATHVerifyEvaluator
137
+ from opencompass.openicl.icl_prompt_template import PromptTemplate
138
+ from opencompass.openicl.icl_retriever import ZeroRetriever
139
+ from opencompass.openicl.icl_inferencer import GenInferencer
140
+
141
+ # Dataset reader configuration
142
+ math_reader_cfg = dict(input_columns=['problem'], output_column='solution')
143
+
144
+ # Inference configuration
145
+ math_infer_cfg = dict(
146
+ prompt_template=dict(
147
+ type=PromptTemplate,
148
+ template=dict(
149
+ round=[
150
+ dict(
151
+ role='HUMAN',
152
+ prompt='{problem}\nPlease reason step by step, and put your final answer within \\boxed{}.',
153
+ ),
154
+ ]
155
+ ),
156
+ ),
157
+ retriever=dict(type=ZeroRetriever),
158
+ inferencer=dict(type=GenInferencer),
159
+ )
160
+
161
+ # Evaluation configuration
162
+ math_eval_cfg = dict(
163
+ evaluator=dict(type=MATHVerifyEvaluator),
164
+ )
165
+
166
+ # Dataset configuration
167
+ math_datasets = [
168
+ dict(
169
+ type=CustomDataset,
170
+ abbr='my-math-dataset',
171
+ path='path/to/your/dataset.jsonl', # or .csv
172
+ reader_cfg=math_reader_cfg,
173
+ infer_cfg=math_infer_cfg,
174
+ eval_cfg=math_eval_cfg,
175
+ )
176
+ ]
177
+
178
+ # Model configuration
179
+ models = [
180
+ dict(
181
+ type=TurboMindModelwithChatTemplate,
182
+ abbr='your-model-name',
183
+ path='your/model/path',
184
+ # ... other model configurations
185
+ )
186
+ ]
187
+
188
+ # Output directory
189
+ work_dir = './outputs/math_eval'
190
+ ```
docs/en/advanced_guides/needleinahaystack_eval.md ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Needle In A Haystack Evaluation
2
+
3
+ ## Introduction to the Needle In A Haystack Test
4
+
5
+ The Needle In A Haystack test (inspired by [NeedleInAHaystack](https://github.com/gkamradt/LLMTest_NeedleInAHaystack/blob/main/LLMNeedleHaystackTester.py)) is an evaluation method where key information is randomly inserted into long texts to form the prompt for large language models (LLMs). This test aims to assess whether LLMs can extract critical information from long texts, thereby evaluating their fundamental ability to comprehend and process long-context documents.
6
+
7
+ ## Task Overview
8
+
9
+ Within the `OpenCompass` framework, under `NeedleBench`, we designed a series of progressively challenging evaluation tasks to comprehensively assess LLMs' long-text information extraction and reasoning capabilities. For a complete description, please refer to our [technical report](https://arxiv.org/abs/2407.11963).
10
+
11
+ - **Single-Needle Retrieval Task (S-RT)**: Evaluates the LLM's ability to retrieve a single piece of key information from a long text, testing precise recall of specific details within extensive narratives. This corresponds to the **original Needle In A Haystack test** setup.
12
+
13
+ - **Multi-Needle Retrieval Task (M-RT)**: Explores the LLM's ability to retrieve multiple relevant pieces of information from long texts, simulating complex queries over comprehensive documents.
14
+
15
+ - **Multi-Needle Reasoning Task (M-RS)**: Assesses LLMs' abilities to integrate multiple key pieces of information extracted from long texts for reasoning, requiring a comprehensive understanding of content.
16
+
17
+ - **Ancestral Trace Challenge (ATC)**: Tests LLMs' capabilities in handling multi-layer logical challenges within realistic long-text contexts through "kinship trace needles." In the ATC task, no irrelevant (haystack) texts are added; every piece of text is critical, and models must reason through all details for accurate answers.
18
+
19
+ > **Note:** NeedleBench (v2) includes several optimizations and adjustments in dataset construction and task details. For a detailed comparison between the old and new versions, as well as a summary of updates, please refer to [opencompass/configs/datasets/needlebench_v2/readme.md](https://github.com/open-compass/opencompass/blob/main/opencompass/configs/datasets/needlebench_v2/readme.md).
20
+
21
+ ## Evaluation Steps
22
+
23
+ > Note: In the latest `OpenCompass` codebase, the NeedleBench dataset is automatically loaded from the [Huggingface interface](https://huggingface.co/datasets/opencompass/NeedleBench), with no need for manual download or configuration.
24
+
25
+ ### `OpenCompass` Environment Setup
26
+
27
+ ```bash
28
+ conda create --name opencompass python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y
29
+ conda activate opencompass
30
+ git clone https://github.com/open-compass/opencompass opencompass
31
+ cd opencompass
32
+ pip install -e .
33
+ ```
34
+
35
+ ### Dataset Configuration
36
+
37
+ We have pre-configured various long-context settings (4k, 8k, 32k, 128k, 200k, 1000k) in `opencompass/configs/datasets/needlebench_v2`, and you can flexibly define your parameters by adjusting the configuration files.
38
+
39
+ ### Evaluation Example
40
+
41
+ #### Evaluating with `VLLM` Deployed `Qwen2-5-7B` Model
42
+
43
+ To evaluate the `Qwen2-5-7B` model deployed with `VLLM` on all tasks under NeedleBench-128K, use the following command. This leverages pre-defined model and dataset configuration files without needing additional configuration:
44
+
45
+ ##### Local Evaluation
46
+
47
+ If evaluating locally, the command will use all available GPUs. You can control GPU visibility using `CUDA_VISIBLE_DEVICES`:
48
+
49
+ ```bash
50
+ # Local evaluation
51
+ python run.py --datasets needlebench_v2_128k --models vllm_qwen2_5_7b_instruct_128k --summarizer needlebench/needlebench_v2_128k_summarizer
52
+ ```
53
+
54
+ ##### Evaluation on Slurm Cluster
55
+
56
+ For Slurm environments, you can add options like `--slurm -p partition_name -q reserved --max-num-workers 16`:
57
+
58
+ ```bash
59
+ # Slurm evaluation
60
+ python run.py --datasets needlebench_v2_128k --models vllm_qwen2_5_7b_instruct_128k --summarizer needlebench/needlebench_v2_128k_summarizer --slurm -p partition_name -q reserved --max-num-workers 16
61
+ ```
62
+
63
+ ##### Evaluating Specific Subsets
64
+
65
+ If you only want to test the original Needle In A Haystack task (e.g., single-needle 128k), adjust the dataset parameter:
66
+
67
+ ```bash
68
+ python run.py --datasets needlebench_v2_single_128k --models vllm_qwen2_5_7b_instruct_128k --summarizer needlebench/needlebench_v2_128k_summarizer --slurm -p partition_name -q reserved --max-num-workers 16
69
+ ```
70
+
71
+ To evaluate only Chinese versions, specify the subset dataset after `/`:
72
+
73
+ ```bash
74
+ python run.py --datasets needlebench_v2_single_128k/needlebench_zh_datasets --models vllm_qwen2_5_7b_instruct_128k --summarizer needlebench/needlebench_v2_128k_summarizer --slurm -p partition_name -q reserved --max-num-workers 16
75
+ ```
76
+
77
+ Ensure `VLLM` is installed beforehand:
78
+
79
+ ```bash
80
+ # Install vLLM with CUDA 12.4.
81
+ # For other CUDA versions, please refer to the [official documentation](https://docs.vllm.ai/en/latest/getting_started/installation/gpu.html)
82
+ pip install vllm
83
+ ```
84
+
85
+ #### Evaluating Other `Huggingface` Models
86
+
87
+ For other models, it is recommended to write your own config file (such as `examples/eval_needlebench_v2.py`) to adjust `max_seq_len` and `max_out_len`, so that the model can process the full context.
88
+
89
+ You can then run evaluation with:
90
+
91
+ ```bash
92
+ python run.py examples/eval_needlebench_v2.py --slurm -p partition_name -q reserved --max-num-workers 16
93
+ ```
94
+
95
+ No need to manually specify `--datasets`, `--models`, or `--summarizer` again.
96
+
97
+ ### Visualization
98
+
99
+ NeedleBench's latest version has built-in visualization integrated into the summarizer. You can find corresponding visualizations in the `plots` directory under the output folder without needing additional scripts.
100
+
101
+ ### Citation
102
+
103
+ If you use NeedleBench, please cite us:
104
+
105
+ ```bibtex
106
+ @misc{li2025needlebenchllmsretrievalreasoning,
107
+ title={NeedleBench: Can LLMs Do Retrieval and Reasoning in Information-Dense Context?},
108
+ author={Mo Li and Songyang Zhang and Taolin Zhang and Haodong Duan and Yunxin Liu and Kai Chen},
109
+ year={2025},
110
+ eprint={2407.11963},
111
+ archivePrefix={arXiv},
112
+ primaryClass={cs.CL},
113
+ url={https://arxiv.org/abs/2407.11963},
114
+ }
115
+
116
+ @misc{2023opencompass,
117
+ title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
118
+ author={OpenCompass Contributors},
119
+ howpublished={\url{https://github.com/open-compass/opencompass}},
120
+ year={2023}
121
+ }
122
+
123
+ @misc{LLMTest_NeedleInAHaystack,
124
+ title={LLMTest Needle In A Haystack - Pressure Testing LLMs},
125
+ author={gkamradt},
126
+ year={2023},
127
+ howpublished={\url{https://github.com/gkamradt/LLMTest_NeedleInAHaystack}}
128
+ }
129
+
130
+ @misc{wei2023skywork,
131
+ title={Skywork: A More Open Bilingual Foundation Model},
132
+ author={Tianwen Wei and Liang Zhao and Lichang Zhang and Bo Zhu and Lijie Wang and Haihua Yang and Biye Li and Cheng Cheng and Weiwei L\"u and Rui Hu and Chenxia Li and Liu Yang and Xilin Luo and Xuejie Wu and Lunan Liu and Wenjun Cheng and Peng Cheng and Jianhao Zhang and Xiaoyu Zhang and Lei Lin and Xiaokun Wang and Yutuan Ma and Chuanhai Dong and Yanqi Sun and Yifu Chen and Yongyi Peng and Xiaojuan Liang and Shuicheng Yan and Han Fang and Yahui Zhou},
133
+ year={2023},
134
+ eprint={2310.19341},
135
+ archivePrefix={arXiv},
136
+ primaryClass={cs.CL}
137
+ }
138
+ ```
docs/en/advanced_guides/new_dataset.md ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Add a dataset
2
+
3
+ Although OpenCompass has already included most commonly used datasets, users who want to support a new dataset need to follow the steps below:
4
+
5
+ 1. Add a dataset script `mydataset.py` to the `opencompass/datasets` folder. This script should include:
6
+
7
+ - The dataset and its loading method. Define a `MyDataset` class that implements the data loading method `load` as a static method. This method should return data of type `datasets.Dataset`. We use the Hugging Face dataset as the unified interface for datasets to avoid introducing additional logic. Here's an example:
8
+
9
+ ```python
10
+ import datasets
11
+ from .base import BaseDataset
12
+
13
+ class MyDataset(BaseDataset):
14
+
15
+ @staticmethod
16
+ def load(**kwargs) -> datasets.Dataset:
17
+ pass
18
+ ```
19
+
20
+ - (Optional) If the existing evaluators in OpenCompass do not meet your needs, you need to define a `MyDatasetEvaluator` class that implements the scoring method `score`. This method should take `predictions` and `references` as input and return the desired dictionary. Since a dataset may have multiple metrics, the method should return a dictionary containing the metrics and their corresponding scores. Here's an example:
21
+
22
+ ```python
23
+ from opencompass.openicl.icl_evaluator import BaseEvaluator
24
+
25
+ class MyDatasetEvaluator(BaseEvaluator):
26
+
27
+ def score(self, predictions: List, references: List) -> dict:
28
+ pass
29
+ ```
30
+
31
+ - (Optional) If the existing postprocessors in OpenCompass do not meet your needs, you need to define the `mydataset_postprocess` method. This method takes an input string and returns the corresponding postprocessed result string. Here's an example:
32
+
33
+ ```python
34
+ def mydataset_postprocess(text: str) -> str:
35
+ pass
36
+ ```
37
+
38
+ 2. After defining the dataset loading, data postprocessing, and evaluator methods, you need to add the following configurations to the configuration file:
39
+
40
+ ```python
41
+ from opencompass.datasets import MyDataset, MyDatasetEvaluator, mydataset_postprocess
42
+
43
+ mydataset_eval_cfg = dict(
44
+ evaluator=dict(type=MyDatasetEvaluator),
45
+ pred_postprocessor=dict(type=mydataset_postprocess))
46
+
47
+ mydataset_datasets = [
48
+ dict(
49
+ type=MyDataset,
50
+ ...,
51
+ reader_cfg=...,
52
+ infer_cfg=...,
53
+ eval_cfg=mydataset_eval_cfg)
54
+ ]
55
+ ```
56
+
57
+ - To facilitate the access of your datasets to other users, you need to specify the channels for downloading the datasets in the configuration file. Specifically, you need to first fill in a dataset name given by yourself in the `path` field in the `mydataset_datasets` configuration, and this name will be mapped to the actual download path in the `opencompass/utils/datasets_info.py` file. Here's an example:
58
+
59
+ ```python
60
+ mmlu_datasets = [
61
+ dict(
62
+ ...,
63
+ path='opencompass/mmlu',
64
+ ...,
65
+ )
66
+ ]
67
+ ```
68
+
69
+ - Next, you need to create a dictionary key in `opencompass/utils/datasets_info.py` with the same name as the one you provided above. If you have already hosted the dataset on HuggingFace or Modelscope, please add a dictionary key to the `DATASETS_MAPPING` dictionary and fill in the HuggingFace or Modelscope dataset address in the `hf_id` or `ms_id` key, respectively. You can also specify a default local address. Here's an example:
70
+
71
+ ```python
72
+ "opencompass/mmlu": {
73
+ "ms_id": "opencompass/mmlu",
74
+ "hf_id": "opencompass/mmlu",
75
+ "local": "./data/mmlu/",
76
+ }
77
+ ```
78
+
79
+ - If you wish for the provided dataset to be directly accessible from the OpenCompass OSS repository when used by others, you need to submit the dataset files in the Pull Request phase. We will then transfer the dataset to the OSS on your behalf and create a new dictionary key in the `DATASET_URL`.
80
+
81
+ - To ensure the optionality of data sources, you need to improve the method `load` in the dataset script `mydataset.py`. Specifically, you need to implement a functionality to switch among different download sources based on the setting of the environment variable `DATASET_SOURCE`. It should be noted that if the environment variable `DATASET_SOURCE` is not set, the dataset will default to being downloaded from the OSS repository. Here's an example from `opencompass/dataset/cmmlu.py`:
82
+
83
+ ```python
84
+ def load(path: str, name: str, **kwargs):
85
+ ...
86
+ if environ.get('DATASET_SOURCE') == 'ModelScope':
87
+ ...
88
+ else:
89
+ ...
90
+ return dataset
91
+ ```
92
+
93
+ 3. After completing the dataset script and config file, you need to register the information of your new dataset in the file `dataset-index.yml` at the main directory, so that it can be added to the dataset statistics list on the OpenCompass website.
94
+
95
+ - The keys that need to be filled in include `name`: the name of your dataset, `category`: the category of your dataset, `paper`: the URL of the paper or project, and `configpath`: the path to the dataset config file. Here's an example:
96
+
97
+ ```
98
+ - mydataset:
99
+ name: MyDataset
100
+ category: Understanding
101
+ paper: https://arxiv.org/pdf/xxxxxxx
102
+ configpath: opencompass/configs/datasets/MyDataset
103
+ ```
104
+
105
+ Detailed dataset configuration files and other required configuration files can be referred to in the [Configuration Files](../user_guides/config.md) tutorial. For guides on launching tasks, please refer to the [Quick Start](../get_started/quick_start.md) tutorial.
docs/en/advanced_guides/new_model.md ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Add a Model
2
+
3
+ Currently, we support HF models, some model APIs, and some third-party models.
4
+
5
+ ## Adding API Models
6
+
7
+ To add a new API-based model, you need to create a new file named `mymodel_api.py` under `opencompass/models` directory. In this file, you should inherit from `BaseAPIModel` and implement the `generate` method for inference and the `get_token_len` method to calculate the length of tokens. Once you have defined the model, you can modify the corresponding configuration file.
8
+
9
+ ```python
10
+ from ..base_api import BaseAPIModel
11
+
12
+ class MyModelAPI(BaseAPIModel):
13
+
14
+ is_api: bool = True
15
+
16
+ def __init__(self,
17
+ path: str,
18
+ max_seq_len: int = 2048,
19
+ query_per_second: int = 1,
20
+ retry: int = 2,
21
+ **kwargs):
22
+ super().__init__(path=path,
23
+ max_seq_len=max_seq_len,
24
+ meta_template=meta_template,
25
+ query_per_second=query_per_second,
26
+ retry=retry)
27
+ ...
28
+
29
+ def generate(
30
+ self,
31
+ inputs,
32
+ max_out_len: int = 512,
33
+ temperature: float = 0.7,
34
+ ) -> List[str]:
35
+ """Generate results given a list of inputs."""
36
+ pass
37
+
38
+ def get_token_len(self, prompt: str) -> int:
39
+ """Get lengths of the tokenized string."""
40
+ pass
41
+ ```
42
+
43
+ ## Adding Third-Party Models
44
+
45
+ To add a new third-party model, you need to create a new file named `mymodel.py` under `opencompass/models` directory. In this file, you should inherit from `BaseModel` and implement the `generate` method for generative inference, the `get_ppl` method for discriminative inference, and the `get_token_len` method to calculate the length of tokens. Once you have defined the model, you can modify the corresponding configuration file.
46
+
47
+ ```python
48
+ from ..base import BaseModel
49
+
50
+ class MyModel(BaseModel):
51
+
52
+ def __init__(self,
53
+ pkg_root: str,
54
+ ckpt_path: str,
55
+ tokenizer_only: bool = False,
56
+ meta_template: Optional[Dict] = None,
57
+ **kwargs):
58
+ ...
59
+
60
+ def get_token_len(self, prompt: str) -> int:
61
+ """Get lengths of the tokenized strings."""
62
+ pass
63
+
64
+ def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
65
+ """Generate results given a list of inputs. """
66
+ pass
67
+
68
+ def get_ppl(self,
69
+ inputs: List[str],
70
+ mask_length: Optional[List[int]] = None) -> List[float]:
71
+ """Get perplexity scores given a list of inputs."""
72
+ pass
73
+ ```
docs/en/advanced_guides/objective_judgelm_evaluation.md ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Using Large Models as JudgeLLM for Objective Evaluation
2
+
3
+ ## Introduction
4
+
5
+ Traditional objective evaluations often rely on standard answers for reference. However, in practical applications, the predicted results of models may vary due to differences in the model's instruction-following capabilities or imperfections in post-processing functions. This can lead to incorrect extraction of answers and comparison with standard answers, resulting in potentially inaccurate evaluation outcomes. To address this issue, we have adopted a process similar to subjective evaluations by introducing JudgeLLM post-prediction to assess the consistency between model responses and standard answers. ([LLM-as-a-Judge](https://arxiv.org/abs/2306.05685)).
6
+
7
+ Currently, all models supported by the opencompass repository can be directly used as JudgeLLM. Additionally, we are planning to support dedicated JudgeLLMs.
8
+
9
+ ## Currently Supported Objective Evaluation Datasets
10
+
11
+ 1. MATH ([https://github.com/hendrycks/math](https://github.com/hendrycks/math))
12
+
13
+ ## Custom JudgeLLM Objective Dataset Evaluation
14
+
15
+ OpenCompass currently supports most datasets that use `GenInferencer` for inference. The specific process for custom JudgeLLM objective evaluation includes:
16
+
17
+ 1. Building evaluation configurations using API models or open-source models for inference of question answers.
18
+ 2. Employing a selected evaluation model (JudgeLLM) to assess the outputs of the model.
19
+
20
+ ### Step One: Building Evaluation Configurations, Using MATH as an Example
21
+
22
+ Below is the Config for evaluating the MATH dataset with JudgeLLM, with the evaluation model being *Llama3-8b-instruct* and the JudgeLLM being *Llama3-70b-instruct*. For more detailed config settings, please refer to `examples/eval_math_llm_judge.py`. The following is a brief version of the annotations to help users understand the meaning of the configuration file.
23
+
24
+ ```python
25
+ # Most of the code in this file is copied from https://github.com/openai/simple-evals/blob/main/math_eval.py
26
+ from mmengine.config import read_base
27
+ with read_base():
28
+ from .models.hf_llama.hf_llama3_8b_instruct import models as hf_llama3_8b_instruct_model # noqa: F401, F403
29
+ from .models.hf_llama.hf_llama3_70b_instruct import models as hf_llama3_70b_instruct_model # noqa: F401, F403
30
+ from .datasets.math.math_llm_judge import math_datasets # noqa: F401, F403
31
+ from opencompass.datasets import math_judement_preprocess
32
+ from opencompass.partitioners import NaivePartitioner, SizePartitioner
33
+ from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
34
+ from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
35
+ from opencompass.runners import LocalRunner
36
+ from opencompass.runners import SlurmSequentialRunner
37
+ from opencompass.tasks import OpenICLInferTask
38
+ from opencompass.tasks.subjective_eval import SubjectiveEvalTask
39
+ from opencompass.summarizers import AllObjSummarizer
40
+ from opencompass.openicl.icl_evaluator import LMEvaluator
41
+ from opencompass.openicl.icl_prompt_template import PromptTemplate
42
+
43
+
44
+ # ------------- Prompt Settings ----------------------------------------
45
+ # Evaluation template, please modify the template as needed, JudgeLLM typically uses [Yes] or [No] as the response. For the MATH dataset, the evaluation template is as follows:
46
+ eng_obj_prompt = """
47
+ Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
48
+
49
+ Examples:
50
+
51
+ Expression 1: $2x+3$
52
+ Expression 2: $3+2x$
53
+
54
+ [Yes]
55
+
56
+ Expression 1: 3/2
57
+ Expression 2: 1.5
58
+
59
+ [Yes]
60
+
61
+ Expression 1: $x^2+2x+1$
62
+ Expression 2: $y^2+2y+1$
63
+
64
+ [No]
65
+
66
+ Expression 1: $x^2+2x+1$
67
+ Expression 2: $(x+1)^2$
68
+
69
+ [Yes]
70
+
71
+ Expression 1: 3245/5
72
+ Expression 2: 649
73
+
74
+ [No]
75
+ (these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
76
+
77
+ Expression 1: 2/(-3)
78
+ Expression 2: -2/3
79
+
80
+ [Yes]
81
+ (trivial simplifications are allowed)
82
+
83
+ Expression 1: 72 degrees
84
+ Expression 2: 72
85
+
86
+ [Yes]
87
+ (give benefit of the doubt to units)
88
+
89
+ Expression 1: 64
90
+ Expression 2: 64 square feet
91
+
92
+ [Yes]
93
+ (give benefit of the doubt to units)
94
+
95
+ Expression 1: 64
96
+ Expression 2:
97
+
98
+ [No]
99
+ (only mark as equivalent if both expressions are nonempty)
100
+
101
+ ---
102
+
103
+ YOUR TASK
104
+
105
+
106
+ Respond with only "[Yes]" or "[No]" (without quotes). Do not include a rationale.
107
+ Expression 1: {obj_gold}
108
+ Expression 2: {prediction}
109
+
110
+ """
111
+
112
+ # ------------- Inference Phase ----------------------------------------
113
+ # Models to be evaluated
114
+ models = [*hf_llama3_8b_instruct_model]
115
+ # Evaluation models
116
+ judge_models = hf_llama3_70b_instruct_model
117
+
118
+ eng_datasets = [*math_datasets]
119
+ chn_datasets = []
120
+ datasets = eng_datasets + chn_datasets
121
+
122
+
123
+ for d in eng_datasets:
124
+ d['eval_cfg']= dict(
125
+ evaluator=dict(
126
+ type=LMEvaluator,
127
+ # If you need to preprocess model predictions before judging,
128
+ # you can specify a pred_postprocessor function here
129
+ pred_postprocessor=dict(type=math_judement_preprocess),
130
+ prompt_template=dict(
131
+ type=PromptTemplate,
132
+ template=dict(round=[
133
+ dict(
134
+ role='HUMAN',
135
+ prompt = eng_obj_prompt
136
+ ),
137
+ ]),
138
+ ),
139
+ ),
140
+ pred_role="BOT",
141
+ )
142
+
143
+ infer = dict(
144
+ partitioner=dict(type=SizePartitioner, max_task_size=40000),
145
+ runner=dict(
146
+ type=LocalRunner,
147
+ max_num_workers=256,
148
+ task=dict(type=OpenICLInferTask)),
149
+ )
150
+
151
+ # ------------- Evaluation Configuration --------------------------------
152
+ eval = dict(
153
+ partitioner=dict(
154
+ type=SubjectiveSizePartitioner, max_task_size=80000, mode='singlescore', models=models, judge_models=judge_models,
155
+ ),
156
+ runner=dict(type=LocalRunner,
157
+ max_num_workers=16, task=dict(type=SubjectiveEvalTask)),
158
+ )
159
+
160
+ summarizer = dict(
161
+ type=AllObjSummarizer
162
+ )
163
+
164
+ # Output folder
165
+ work_dir = 'outputs/obj_all/'
166
+ ```
167
+
168
+ ### Step Two: Launch Evaluation and Output Results
169
+
170
+ ```shell
171
+ python run.py eval_math_llm_judge.py
172
+ ```
173
+
174
+ This will initiate two rounds of evaluation. The first round involves model inference to obtain predicted answers to questions, and the second round involves JudgeLLM evaluating the consistency between the predicted answers and the standard answers, and scoring them.
175
+
176
+ - The results of model predictions will be saved in `output/.../timestamp/predictions/xxmodel/xxx.json`
177
+ - The JudgeLLM's evaluation responses will be saved in `output/.../timestamp/results/xxmodel/xxx.json`
178
+ - The evaluation report will be output to `output/.../timestamp/summary/timestamp/xxx.csv`
179
+
180
+ ## Results
181
+
182
+ Using the Llama3-8b-instruct as the evaluation model and the Llama3-70b-instruct as the evaluator, the MATH dataset was assessed with the following results:
183
+
184
+ | Model | JudgeLLM Evaluation | Naive Evaluation |
185
+ | ------------------- | ------------------- | ---------------- |
186
+ | llama-3-8b-instruct | 27.7 | 27.8 |
docs/en/advanced_guides/persistence.md ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Evaluation Results Persistence
2
+
3
+ ## Introduction
4
+
5
+ Normally, the evaluation results of OpenCompass will be saved to your work directory. But in some cases, there may be a need for data sharing among users or quickly browsing existing public evaluation results. Therefore, we provide an interface that can quickly transfer evaluation results to external public data stations, and on this basis, provide functions such as uploading, overwriting, and reading.
6
+
7
+ ## Quick Start
8
+
9
+ ### Uploading
10
+
11
+ By adding `args` to the evaluation command or adding configuration in the Eval script, the results of evaluation can be stored in the path you specify. Here are the examples:
12
+
13
+ (Approach 1) Add an `args` option to the command and specify your public path address.
14
+
15
+ ```bash
16
+ opencompass ... -sp '/your_path'
17
+ ```
18
+
19
+ (Approach 2) Add configuration in the Eval script.
20
+
21
+ ```python
22
+ station_path = '/your_path'
23
+ ```
24
+
25
+ ### Overwriting
26
+
27
+ Before uploading data, the above storage method will first determine whether the same task result already exists in the data station, based on the `abbr` attribute in the model and dataset configuration. If the results already exist, this upload is skipped. If you need to update these results, please add the `--station-overwrite` option to the command. Here is an example:
28
+
29
+ ```bash
30
+ opencompass ... -sp '/your_path' --station-overwrite
31
+ ```
32
+
33
+ ### Reading
34
+
35
+ You can directly read existing results from the data station to avoid duplicate evaluation tasks. The read results will directly participate in the 'summarize' step. When using this configuration, only tasks that do not store results in the data station will be initiated. Here is an example:
36
+
37
+ ```bash
38
+ opencompass ... -sp '/your_path' --read-from-station
39
+ ```
40
+
41
+ ### Command Combination
42
+
43
+ 1. Only upload the results under your latest working directory to the data station, without re-running tasks whose results are missing:
44
+
45
+ ```bash
46
+ opencompass ... -sp '/your_path' -r latest -m viz
47
+ ```
48
+
49
+ ## Storage Format of the Data Station
50
+
51
+ In the data station, the evaluation results are stored as `json` files for each `model-dataset` pair. The specific directory form is `/your_path/dataset_name/model_name.json`. Each `json` file stores a dictionary corresponding to the results, including `predictions`, `results`, and `cfg`. Here is an example:
52
+
53
+ ```python
54
+ Result = {
55
+ 'predictions': List[Dict],
56
+ 'results': Dict,
57
+ 'cfg': Dict = {
58
+ 'models': Dict,
59
+ 'datasets': Dict,
60
+ (Only subjective datasets)'judge_models': Dict
61
+ }
62
+ }
63
+ ```
64
+
65
+ Among these three keys, `predictions` records the model's prediction on each item of data in the dataset, `results` records the total score of the model on the dataset, and `cfg` records the detailed configurations of the model and the dataset in this evaluation task.
docs/en/advanced_guides/prompt_attack.md ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Prompt Attack
2
+
3
+ We support prompt attack following the idea of [PromptBench](https://github.com/microsoft/promptbench). The main purpose here is to evaluate the robustness of a prompt instruction: when the prompt that instructs the task is attacked/modified, how well can the task still perform compared with the original prompt.
4
+
5
+ ## Set up environment
6
+
7
+ Some components are necessary to prompt attack experiment, therefore we need to set up environments.
8
+
9
+ ```shell
10
+ git clone https://github.com/microsoft/promptbench.git
11
+ pip install textattack==0.3.8
12
+ export PYTHONPATH=$PYTHONPATH:promptbench/
13
+ ```
14
+
15
+ ## How to attack
16
+
17
+ ### Add a dataset config
18
+
19
+ We will use GLUE-wnli dataset as example, most configuration settings can refer to [config.md](../user_guides/config.md) for help.
20
+
21
+ First we need to set up the basic dataset config. You can find the existing config files in `configs`, or add your own config according to [new-dataset](./new_dataset.md)
22
+
23
+ Take the following `infer_cfg` as example, we need to define the prompt template. `adv_prompt` is the basic prompt placeholder to be attacked in the experiment. `sentence1` and `sentence2` are the input columns of this dataset. The attack will only modify the `adv_prompt` here.
24
+
25
+ Then, we should use `AttackInferencer` with `original_prompt_list` and `adv_key` to tell the inferencer where to attack and what text to be attacked.
26
+
27
+ More details can refer to `configs/datasets/promptbench/promptbench_wnli_gen_50662f.py` config file.
28
+
29
+ ```python
30
+ original_prompt_list = [
31
+ 'Are the following two sentences entailment or not_entailment? Answer me with "A. entailment" or "B. not_entailment", just one word. ',
32
+ "Does the relationship between the given sentences represent entailment or not_entailment? Respond with 'A. entailment' or 'B. not_entailment'.",
33
+ ...,
34
+ ]
35
+
36
+ wnli_infer_cfg = dict(
37
+ prompt_template=dict(
38
+ type=PromptTemplate,
39
+ template=dict(round=[
40
+ dict(
41
+ role="HUMAN",
42
+ prompt="""{adv_prompt}
43
+ Sentence 1: {sentence1}
44
+ Sentence 2: {sentence2}
45
+ Answer:"""),
46
+ ]),
47
+ ),
48
+ retriever=dict(type=ZeroRetriever),
49
+ inferencer=dict(
50
+ type=AttackInferencer,
51
+ original_prompt_list=original_prompt_list,
52
+ adv_key='adv_prompt'))
53
+ ```
54
+
55
+ ### Add a eval config
56
+
57
+ We should use `OpenICLAttackTask` here for the attack task. Also, `NaivePartitioner` should be used because the attack experiment will run the whole dataset repeatedly, nearly hundreds of times, to search for the best attack, so we do not want to split the dataset.
58
+
59
+ ```note
60
+ Please choose a small dataset (fewer than 1000 examples) for the attack, due to the aforementioned repeated search; otherwise the time cost is enormous.
61
+ ```
62
+
63
+ There are several other options in `attack` config:
64
+
65
+ - `attack`: attack type, available options include `textfooler`, `textbugger`, `deepwordbug`, `bertattack`, `checklist`, `stresstest`;
66
+ - `query_budget`: upper bound on the number of queries, i.e. the total number of times the dataset is run;
67
+ - `prompt_topk`: number of top-k prompts to be attacked. In most cases, the original prompt list is greater than 10, and running the whole set is time-consuming.
68
+
69
+ ```python
70
+ # Please run whole dataset at a time, aka use `NaivePartitioner` only
71
+ # Please use `OpenICLAttackTask` if want to perform attack experiment
72
+ infer = dict(
73
+ partitioner=dict(type=NaivePartitioner),
74
+ runner=dict(
75
+ type=SlurmRunner,
76
+ max_num_workers=8,
77
+ task=dict(type=OpenICLAttackTask),
78
+ retry=0),
79
+ )
80
+
81
+ attack = dict(
82
+ attack='textfooler',
83
+ query_budget=100,
84
+ prompt_topk=2,
85
+ )
86
+ ```
87
+
88
+ ### Run the experiment
89
+
90
+ Please use `--mode infer` when running the attack experiment, and set the `PYTHONPATH` environment variable.
91
+
92
+ ```shell
93
+ python run.py examples/eval_attack.py --mode infer
94
+ ```
95
+
96
+ All the results will be saved in `attack` folder.
97
+ The content includes the original prompt accuracy and the attacked prompt with dropped accuracy of `topk` prompt, for instance:
98
+
99
+ ```
100
+ Prompt: Assess the connection between the following sentences and classify it as 'A. entailment' or 'B. not_entailment'., acc: 59.15%
101
+ Prompt: Does the relationship between the given sentences represent entailment or not_entailment? Respond with 'A. entailment' or 'B. not_entailment'., acc: 57.75%
102
+ Prompt: Analyze the two provided sentences and decide if their relationship is 'A. entailment' or 'B. not_entailment'., acc: 56.34%
103
+ Prompt: Identify whether the given pair of sentences demonstrates entailment or not_entailment. Answer with 'A. entailment' or 'B. not_entailment'., acc: 54.93%
104
+ ...
105
+ Original prompt: Assess the connection between the following sentences and classify it as 'A. entailment' or 'B. not_entailment'.
106
+ Attacked prompt: b"Assess the attach between the following sentences and sorted it as 'A. entailment' or 'B. not_entailment'."
107
+ Original acc: 59.15%, attacked acc: 40.85%, dropped acc: 18.31%
108
+ ```
docs/en/advanced_guides/subjective_evaluation.md ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Subjective Evaluation Guidance
2
+
3
+ ## Introduction
4
+
5
+ Subjective evaluation aims to assess the model's performance in tasks that align with human preferences. The key criterion for this evaluation is human preference, but it comes with a high cost of annotation.
6
+
7
+ To explore the model's subjective capabilities, we employ JudgeLLM as a substitute for human assessors ([LLM-as-a-Judge](https://arxiv.org/abs/2306.05685)).
8
+
9
+ A popular evaluation method involves
10
+
11
+ - Compare Mode: comparing model responses pairwise to calculate their win rate
12
+ - Score Mode: another method involves calculate scores with single model response ([Chatbot Arena](https://chat.lmsys.org/)).
13
+
14
+ We support the use of GPT-4 (or other JudgeLLM) for the subjective evaluation of models based on above methods.
15
+
16
+ ## Currently Supported Subjective Evaluation Datasets
17
+
18
+ 1. AlignBench Chinese Scoring Dataset (https://github.com/THUDM/AlignBench)
19
+ 2. MTBench English Scoring Dataset, two-turn dialogue (https://github.com/lm-sys/FastChat)
20
+ 3. MTBench101 English Scoring Dataset, multi-turn dialogue (https://github.com/mtbench101/mt-bench-101)
21
+ 4. AlpacaEvalv2 English Compare Dataset (https://github.com/tatsu-lab/alpaca_eval)
22
+ 5. ArenaHard English Compare Dataset, mainly focused on coding (https://github.com/lm-sys/arena-hard/tree/main)
23
+ 6. Fofo English Scoring Dataset (https://github.com/SalesforceAIResearch/FoFo/)
24
+ 7. Wildbench English Score and Compare Dataset(https://github.com/allenai/WildBench)
25
+
26
+ ## Initiating Subjective Evaluation
27
+
28
+ Similar to existing objective evaluation methods, you can configure related settings in `examples/eval_subjective.py`.
29
+
30
+ ### Basic Parameters: Specifying models, datasets, and judgemodels
31
+
32
+ Similar to objective evaluation, import the models and datasets that need to be evaluated, for example:
33
+
34
+ ```
35
+ with read_base():
36
+ from .datasets.subjective.alignbench.alignbench_judgeby_critiquellm import alignbench_datasets
37
+ from .datasets.subjective.alpaca_eval.alpacav2_judgeby_gpt4 import subjective_datasets as alpacav2
38
+ from .models.qwen.hf_qwen_7b import models
39
+ ```
40
+
41
+ It is worth noting that since the model setup parameters for subjective evaluation are often different from those for objective evaluation, it often requires setting up `do_sample` for inference instead of `greedy`. You can modify the relevant parameters in the configuration file as needed, for example:
42
+
43
+ ```
44
+ models = [
45
+ dict(
46
+ type=HuggingFaceChatGLM3,
47
+ abbr='chatglm3-6b-hf2',
48
+ path='THUDM/chatglm3-6b',
49
+ tokenizer_path='THUDM/chatglm3-6b',
50
+ model_kwargs=dict(
51
+ device_map='auto',
52
+ trust_remote_code=True,
53
+ ),
54
+ tokenizer_kwargs=dict(
55
+ padding_side='left',
56
+ truncation_side='left',
57
+ trust_remote_code=True,
58
+ ),
59
+ generation_kwargs=dict(
60
+ do_sample=True,
61
+ ),
62
+ meta_template=api_meta_template,
63
+ max_out_len=2048,
64
+ max_seq_len=4096,
65
+ batch_size=8,
66
+ run_cfg=dict(num_gpus=1, num_procs=1),
67
+ )
68
+ ]
69
+ ```
70
+
71
+ The judgemodel is usually set to a powerful model like GPT4, and you can directly enter your API key according to the configuration in the config file, or use a custom model as the judgemodel.
72
+
73
+ ### Specifying Other Parameters
74
+
75
+ In addition to the basic parameters, you can also modify the `infer` and `eval` fields in the config to set a more appropriate partitioning method. The currently supported partitioning methods mainly include three types: NaivePartitioner, SizePartitioner, and NumberWorkPartitioner. You can also specify your own workdir to save related files.
76
+
77
+ ## Subjective Evaluation with Custom Dataset
78
+
79
+ The specific process includes:
80
+
81
+ 1. Data preparation
82
+ 2. Model response generation
83
+ 3. Evaluate the response with a JudgeLLM
84
+ 4. Generate JudgeLLM's response and calculate the metric
85
+
86
+ ### Step-1: Data Preparation
87
+
88
+ This step requires preparing the dataset file and implementing your own dataset class under `Opencompass/datasets/subjective/`, returning the read data in the format of `list of dict`.
89
+
90
+ Actually, you can prepare the data in any format you like (csv, json, jsonl, etc.). However, to make it easier to get started, it is recommended to construct the data according to the format of the existing subjective datasets or according to the following json format.
91
+ We provide mini test-sets for **Compare Mode** and **Score Mode** as below:
92
+
93
+ ```python
94
+ ###COREV2
95
+ [
96
+ {
97
+ "question": "如果我在空中垂直抛球,球最初向哪个方向行进?",
98
+ "capability": "知识-社会常识",
99
+ "others": {
100
+ "question": "如果我在空中垂直抛球,球最初向哪个方向行进?",
101
+ "evaluating_guidance": "",
102
+ "reference_answer": "上"
103
+ }
104
+ },...]
105
+
106
+ ###CreationV0.1
107
+ [
108
+ {
109
+ "question": "请你扮演一个邮件管家,我让你给谁发送什么主题的邮件,你就帮我扩充好邮件正文,并打印在聊天框里。你需要根据我提供的邮件收件人以及邮件主题,来斟酌用词,并使用合适的敬语。现在请给导师发送邮件,询问他是否可以下周三下午15:00进行科研同步会,大约200字。",
110
+ "capability": "邮件通知",
111
+ "others": ""
112
+ },
113
+ ```
114
+
115
+ The json must include the following fields:
116
+
117
+ - 'question': Question description
118
+ - 'capability': The capability dimension of the question.
119
+ - 'others': Other needed information.
120
+
121
+ If you want to modify the prompt for each single question, you can fill the extra information into 'others' and construct the prompt from it.
122
+
123
+ ### Step-2: Evaluation Configuration(Compare Mode)
124
+
125
+ Taking Alignbench as an example, `configs/datasets/subjective/alignbench/alignbench_judgeby_critiquellm.py`:
126
+
127
+ 1. First, you need to set `subjective_reader_cfg` to receive the relevant fields returned from the custom Dataset class and specify the output fields when saving files.
128
+ 2. Then, you need to specify the root path `data_path` of the dataset and the dataset filename `subjective_all_sets`. If there are multiple sub-files, you can add them to this list.
129
+ 3. Specify `subjective_infer_cfg` and `subjective_eval_cfg` to configure the corresponding inference and evaluation prompts.
130
+ 4. Specify additional information such as `mode` at the corresponding location. Note that the fields required for different subjective datasets may vary.
131
+ 5. Define post-processing and score statistics. For example, the postprocessing function `alignbench_postprocess` located under `opencompass/opencompass/datasets/subjective/alignbench`.
132
+
133
+ ### Step-3: Launch the Evaluation
134
+
135
+ ```shell
136
+ python run.py config/eval_subjective_score.py -r
137
+ ```
138
+
139
+ The `-r` parameter allows the reuse of model inference and GPT-4 evaluation results.
140
+
141
+ The response of JudgeLLM will be output to `output/.../results/timestamp/xxmodel/xxdataset/.json`.
142
+ The evaluation report will be output to `output/.../summary/timestamp/report.csv`.
143
+
144
+ ## Multi-round Subjective Evaluation in OpenCompass
145
+
146
+ In OpenCompass, we also support subjective multi-turn dialogue evaluation. For instance, the evaluation of MT-Bench can be referred to in `configs/datasets/subjective/multiround`.
147
+
148
+ In the multi-turn dialogue evaluation, you need to organize the data format into the following dialogue structure:
149
+
150
+ ```
151
+ "dialogue": [
152
+ {
153
+ "role": "user",
154
+ "content": "Imagine you are participating in a race with a group of people. If you have just overtaken the second person, what's your current position? Where is the person you just overtook?"
155
+ },
156
+ {
157
+ "role": "assistant",
158
+ "content": ""
159
+ },
160
+ {
161
+ "role": "user",
162
+ "content": "If the \"second person\" is changed to \"last person\" in the above question, what would the answer be?"
163
+ },
164
+ {
165
+ "role": "assistant",
166
+ "content": ""
167
+ }
168
+ ],
169
+ ```
170
+
171
+ It's important to note that due to the different question types in MTBench having different temperature settings, we need to divide the original data files into three different subsets according to the temperature for separate inference. For different subsets, we can set different temperatures. For specific settings, please refer to `configs\datasets\subjective\multiround\mtbench_single_judge_diff_temp.py`.
docs/en/conf.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ # Configuration file for the Sphinx documentation builder.
3
+ #
4
+ # This file only contains a selection of the most common options. For a full
5
+ # list see the documentation:
6
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html
7
+
8
+ # -- Path setup --------------------------------------------------------------
9
+
10
+ # If extensions (or modules to document with autodoc) are in another directory,
11
+ # add these directories to sys.path here. If the directory is relative to the
12
+ # documentation root, use os.path.abspath to make it absolute, like shown here.
13
+ #
14
+ import os
15
+ import subprocess
16
+ import sys
17
+
18
+ import pytorch_sphinx_theme
19
+ from sphinx.builders.html import StandaloneHTMLBuilder
20
+
21
+ sys.path.insert(0, os.path.abspath('../../'))
22
+
23
+ # -- Project information -----------------------------------------------------
24
+
25
+ project = 'OpenCompass'
26
+ copyright = '2023, OpenCompass'
27
+ author = 'OpenCompass Authors'
28
+
29
+ # The full version, including alpha/beta/rc tags
30
+ version_file = '../../opencompass/__init__.py'
31
+
32
+
33
+ def get_version():
34
+ with open(version_file, 'r') as f:
35
+ exec(compile(f.read(), version_file, 'exec'))
36
+ return locals()['__version__']
37
+
38
+
39
+ release = get_version()
40
+
41
+ # -- General configuration ---------------------------------------------------
42
+
43
+ # Add any Sphinx extension module names here, as strings. They can be
44
+ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
45
+ # ones.
46
+ extensions = [
47
+ 'sphinx.ext.autodoc',
48
+ 'sphinx.ext.autosummary',
49
+ 'sphinx.ext.intersphinx',
50
+ 'sphinx.ext.napoleon',
51
+ 'sphinx.ext.viewcode',
52
+ 'myst_parser',
53
+ 'sphinx_copybutton',
54
+ 'sphinx_tabs.tabs',
55
+ 'notfound.extension',
56
+ 'sphinxcontrib.jquery',
57
+ 'sphinx_design',
58
+ ]
59
+
60
+ # Add any paths that contain templates here, relative to this directory.
61
+ templates_path = ['_templates']
62
+
63
+ # The suffix(es) of source filenames.
64
+ # You can specify multiple suffix as a list of string:
65
+ #
66
+ source_suffix = {
67
+ '.rst': 'restructuredtext',
68
+ '.md': 'markdown',
69
+ }
70
+
71
+ language = 'en'
72
+
73
+ # The master toctree document.
74
+ root_doc = 'index'
75
+
76
+ # List of patterns, relative to source directory, that match files and
77
+ # directories to ignore when looking for source files.
78
+ # This pattern also affects html_static_path and html_extra_path.
79
+ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
80
+
81
+ # -- Options for HTML output -------------------------------------------------
82
+
83
+ # The theme to use for HTML and HTML Help pages. See the documentation for
84
+ # a list of builtin themes.
85
+ #
86
+ html_theme = 'pytorch_sphinx_theme'
87
+ html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
88
+
89
+ # Theme options are theme-specific and customize the look and feel of a theme
90
+ # further. For a list of options available for each theme, see the
91
+ # documentation.
92
+ # yapf: disable
93
+ html_theme_options = {
94
+ 'menu': [
95
+ {
96
+ 'name': 'GitHub',
97
+ 'url': 'https://github.com/open-compass/opencompass'
98
+ },
99
+ ],
100
+ # Specify the language of shared menu
101
+ 'menu_lang': 'en',
102
+ # Disable the default edit on GitHub
103
+ 'default_edit_on_github': False,
104
+ }
105
+ # yapf: enable
106
+
107
+ # Add any paths that contain custom static files (such as style sheets) here,
108
+ # relative to this directory. They are copied after the builtin static files,
109
+ # so a file named "default.css" will overwrite the builtin "default.css".
110
+ html_static_path = ['_static']
111
+ html_css_files = [
112
+ 'https://cdn.datatables.net/v/bs4/dt-1.12.1/datatables.min.css',
113
+ 'css/readthedocs.css'
114
+ ]
115
+ html_js_files = [
116
+ 'https://cdn.datatables.net/v/bs4/dt-1.12.1/datatables.min.js',
117
+ 'js/custom.js'
118
+ ]
119
+
120
+ html_context = {
121
+ 'github_version': 'main',
122
+ }
123
+
124
+ # -- Options for HTMLHelp output ---------------------------------------------
125
+
126
+ # Output file base name for HTML help builder.
127
+ htmlhelp_basename = 'opencompassdoc'
128
+
129
+ # -- Options for LaTeX output ------------------------------------------------
130
+
131
+ latex_elements = {
132
+ # The paper size ('letterpaper' or 'a4paper').
133
+ #
134
+ # 'papersize': 'letterpaper',
135
+
136
+ # The font size ('10pt', '11pt' or '12pt').
137
+ #
138
+ # 'pointsize': '10pt',
139
+
140
+ # Additional stuff for the LaTeX preamble.
141
+ #
142
+ # 'preamble': '',
143
+ }
144
+
145
+ # Grouping the document tree into LaTeX files. List of tuples
146
+ # (source start file, target name, title,
147
+ # author, documentclass [howto, manual, or own class]).
148
+ latex_documents = [
149
+ (root_doc, 'opencompass.tex', 'OpenCompass Documentation', author,
150
+ 'manual'),
151
+ ]
152
+
153
+ # -- Options for manual page output ------------------------------------------
154
+
155
+ # One entry per manual page. List of tuples
156
+ # (source start file, name, description, authors, manual section).
157
+ man_pages = [(root_doc, 'opencompass', 'OpenCompass Documentation', [author],
158
+ 1)]
159
+
160
+ # -- Options for Texinfo output ----------------------------------------------
161
+
162
+ # Grouping the document tree into Texinfo files. List of tuples
163
+ # (source start file, target name, title, author,
164
+ # dir menu entry, description, category)
165
+ texinfo_documents = [
166
+ (root_doc, 'opencompass', 'OpenCompass Documentation', author,
167
+ 'OpenCompass Authors', 'AGI evaluation toolbox and benchmark.',
168
+ 'Miscellaneous'),
169
+ ]
170
+
171
+ # -- Options for Epub output -------------------------------------------------
172
+
173
+ # Bibliographic Dublin Core info.
174
+ epub_title = project
175
+
176
+ # The unique identifier of the text. This can be a ISBN number
177
+ # or the project homepage.
178
+ #
179
+ # epub_identifier = ''
180
+
181
+ # A unique identification for the text.
182
+ #
183
+ # epub_uid = ''
184
+
185
+ # A list of files that should not be packed into the epub file.
186
+ epub_exclude_files = ['search.html']
187
+
188
+ # set priority when building html
189
+ StandaloneHTMLBuilder.supported_image_types = [
190
+ 'image/svg+xml', 'image/gif', 'image/png', 'image/jpeg'
191
+ ]
192
+
193
+ # -- Extension configuration -------------------------------------------------
194
+ # Ignore >>> when copying code
195
+ copybutton_prompt_text = r'>>> |\.\.\. '
196
+ copybutton_prompt_is_regexp = True
197
+
198
+ # Auto-generated header anchors
199
+ myst_heading_anchors = 3
200
+ # Enable "colon_fence" extension of myst.
201
+ myst_enable_extensions = ['colon_fence', 'dollarmath']
202
+
203
+ # Configuration for intersphinx
204
+ intersphinx_mapping = {
205
+ 'python': ('https://docs.python.org/3', None),
206
+ 'numpy': ('https://numpy.org/doc/stable', None),
207
+ 'torch': ('https://pytorch.org/docs/stable/', None),
208
+ 'mmengine': ('https://mmengine.readthedocs.io/en/latest/', None),
209
+ 'transformers':
210
+ ('https://huggingface.co/docs/transformers/main/en/', None),
211
+ }
212
+ napoleon_custom_sections = [
213
+ # Custom sections for data elements.
214
+ ('Meta fields', 'params_style'),
215
+ ('Data fields', 'params_style'),
216
+ ]
217
+
218
+ # Disable docstring inheritance
219
+ autodoc_inherit_docstrings = False
220
+ # Mock some imports during generate API docs.
221
+ autodoc_mock_imports = ['rich', 'attr', 'einops']
222
+ # Disable displaying type annotations, these can be very verbose
223
+ autodoc_typehints = 'none'
224
+
225
+ # The not found page
226
+ notfound_template = '404.html'
227
+
228
+
229
+ def builder_inited_handler(app):
230
+ subprocess.run(['./statis.py'])
231
+
232
+
233
+ def setup(app):
234
+ app.connect('builder-inited', builder_inited_handler)
docs/en/docutils.conf ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [html writers]
2
+ table_style: colwidths-auto
docs/en/get_started/faq.md ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FAQ
2
+
3
+ ## General
4
+
5
+ ### What are the differences and connections between `ppl` and `gen`?
6
+
7
+ `ppl` stands for perplexity, an index used to evaluate a model's language modeling capabilities. In the context of OpenCompass, it generally refers to a method of answering multiple-choice questions: given a context, the model needs to choose the most appropriate option from multiple choices. In this case, we concatenate the n options with the context to form n sequences, then calculate the model's perplexity for these n sequences. We consider the option corresponding to the sequence with the lowest perplexity as the model's reasoning result for this question. This evaluation method is simple and direct in post-processing, with high certainty.
8
+
9
+ `gen` is an abbreviation for generate. In the context of OpenCompass, it refers to the model's continuation writing result given a context as the reasoning result for a question. Generally, the string obtained from continuation writing requires a heavier post-processing process to extract reliable answers and complete the evaluation.
10
+
11
+ In terms of usage, multiple-choice questions and some multiple-choice-like questions of the base model use `ppl`, while the base model's multiple-selection and non-multiple-choice questions use `gen`. All questions of the chat model use `gen`, as many commercial API models do not expose the `ppl` interface. However, there are exceptions, such as when we want the base model to output the problem-solving process (e.g., Let's think step by step), we will also use `gen`, but the overall usage is as shown in the following table:
12
+
13
+ | | ppl | gen |
14
+ | ---------- | -------------- | -------------------- |
15
+ | Base Model | Only MCQ Tasks | Tasks Other Than MCQ |
16
+ | Chat Model | None | All Tasks |
17
+
18
+ Similar to `ppl`, conditional log probability (`clp`) calculates the probability of the next token given a context. It is also only applicable to multiple-choice questions, and the range of probability calculation is limited to the tokens corresponding to the option numbers. The option corresponding to the token with the highest probability is considered the model's reasoning result. Compared to `ppl`, `clp` calculation is more efficient, requiring only one inference, whereas `ppl` requires n inferences. However, the drawback is that `clp` is subject to the tokenizer. For example, the presence or absence of space symbols before and after an option can change the tokenizer's encoding result, leading to unreliable test results. Therefore, `clp` is rarely used in OpenCompass.
19
+
20
+ ### How does OpenCompass control the number of shots in few-shot evaluations?
21
+
22
+ In the dataset configuration file, there is a retriever field indicating how to recall samples from the dataset as context examples. The most commonly used is `FixKRetriever`, which means using a fixed k samples, hence k-shot. There is also `ZeroRetriever`, which means not using any samples, which in most cases implies 0-shot.
23
+
24
+ On the other hand, in-context samples can also be directly specified in the dataset template. In this case, `ZeroRetriever` is also used, but the evaluation is not 0-shot and needs to be determined based on the specific template. Refer to [prompt](../prompt/prompt_template.md) for more details
25
+
26
+ ### How does OpenCompass allocate GPUs?
27
+
28
+ OpenCompass processes evaluation requests using the unit termed as "task". Each task is an independent combination of model(s) and dataset(s). The GPU resources needed for a task are determined entirely by the model being evaluated, specifically by the `num_gpus` parameter.
29
+
30
+ During evaluation, OpenCompass deploys multiple workers to execute tasks in parallel. These workers continuously try to secure GPU resources and run tasks until they succeed. As a result, OpenCompass always strives to leverage all available GPU resources to their maximum capacity.
31
+
32
+ For instance, if you're using OpenCompass on a local machine equipped with 8 GPUs, and each task demands 4 GPUs, then by default, OpenCompass will employ all 8 GPUs to concurrently run 2 tasks. However, if you adjust the `--max-num-workers` setting to 1, then only one task will be processed at a time, utilizing just 4 GPUs.
33
+
34
+ ### Why doesn't the GPU behavior of HuggingFace models align with my expectations?
35
+
36
+ This is a complex issue that needs to be explained from both the supply and demand sides:
37
+
38
+ The supply side refers to how many tasks are being run. A task is a combination of a model and a dataset, and it primarily depends on how many models and datasets need to be tested. Additionally, since OpenCompass splits a larger task into multiple smaller tasks, the number of data entries per sub-task (`--max-partition-size`) also affects the number of tasks. (The `--max-partition-size` is proportional to the actual number of data entries, but the relationship is not 1:1).
39
+
40
+ The demand side refers to how many workers are running. Since OpenCompass instantiates multiple models for inference simultaneously, we use `--hf-num-gpus` to specify how many GPUs each instance uses. Note that `--hf-num-gpus` is a parameter specific to HuggingFace models and setting this parameter for non-HuggingFace models will not have any effect. We also use `--max-num-workers` to indicate the maximum number of instances running at the same time. Lastly, due to issues like GPU memory and insufficient load, OpenCompass also supports running multiple instances on the same GPU, which is managed by the parameter `--max-num-workers-per-gpu`. Therefore, it can be generally assumed that we will use a total of `--hf-num-gpus` * `--max-num-workers` / `--max-num-workers-per-gpu` GPUs.
41
+
42
+ In summary, when tasks run slowly or the GPU load is low, we first need to check if the supply is sufficient. If not, consider reducing `--max-partition-size` to split the tasks into finer parts. Next, we need to check if the demand is sufficient. If not, consider increasing `--max-num-workers` and `--max-num-workers-per-gpu`. Generally, **we set `--hf-num-gpus` to the minimum value that meets the demand and do not adjust it further.**
43
+
44
+ ### How do I control the number of GPUs that OpenCompass occupies?
45
+
46
+ Currently, there isn't a direct method to specify the number of GPUs OpenCompass can utilize. However, the following are some indirect strategies:
47
+
48
+ **If evaluating locally:**
49
+ You can limit OpenCompass's GPU access by setting the `CUDA_VISIBLE_DEVICES` environment variable. For instance, using `CUDA_VISIBLE_DEVICES=0,1,2,3 python run.py ...` will only expose the first four GPUs to OpenCompass, ensuring it uses no more than these four GPUs simultaneously.
50
+
51
+ **If using Slurm or DLC:**
52
+ Although OpenCompass doesn't have direct access to the resource pool, you can adjust the `--max-num-workers` parameter to restrict the number of evaluation tasks being submitted simultaneously. This will indirectly manage the number of GPUs that OpenCompass employs. For instance, if each task requires 4 GPUs, and you wish to allocate a total of 8 GPUs, then you should set `--max-num-workers` to 2.
53
+
54
+ ### `libGL.so.1` not found
55
+
56
+ opencv-python depends on some dynamic libraries that are not present in the environment. The simplest solution is to uninstall opencv-python and then install opencv-python-headless.
57
+
58
+ ```bash
59
+ pip uninstall opencv-python
60
+ pip install opencv-python-headless
61
+ ```
62
+
63
+ Alternatively, you can install the corresponding dependency libraries according to the error message
64
+
65
+ ```bash
66
+ sudo apt-get update
67
+ sudo apt-get install -y libgl1 libglib2.0-0
68
+ ```
69
+
70
+ ## Network
71
+
72
+ ### My tasks failed with error: `('Connection aborted.', ConnectionResetError(104, 'Connection reset by peer'))` or `urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='cdn-lfs.huggingface.co', port=443)`
73
+
74
+ Because of HuggingFace's implementation, OpenCompass requires network (especially the connection to HuggingFace) for the first time it loads some datasets and models. Additionally, it connects to HuggingFace each time it is launched. For a successful run, you may:
75
+
76
+ - Work behind a proxy by specifying the environment variables `http_proxy` and `https_proxy`;
77
+ - Use the cache files from other machines. You may first run the experiment on a machine that has access to the Internet, and then copy the cached files to the offline one. The cached files are located at `~/.cache/huggingface/` by default ([doc](https://huggingface.co/docs/datasets/cache#cache-directory)). When the cached files are ready, you can start the evaluation in offline mode:
78
+ ```python
79
+ HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 HF_EVALUATE_OFFLINE=1 python run.py ...
80
+ ```
81
+ With which no more network connection is needed for the evaluation. However, an error will still be raised if the files of any dataset or model are missing from the cache.
82
+ - Use mirror like [hf-mirror](https://hf-mirror.com/)
83
+ ```python
84
+ HF_ENDPOINT=https://hf-mirror.com python run.py ...
85
+ ```
86
+
87
+ ### My server cannot connect to the Internet, how can I use OpenCompass?
88
+
89
+ Use the cache files from other machines, as suggested in the answer to [Network-Q1](#my-tasks-failed-with-error-connection-aborted-connectionreseterror104-connection-reset-by-peer-or-urllib3exceptionsmaxretryerror-httpsconnectionpoolhostcdn-lfshuggingfaceco-port443).
90
+
91
+ ### In evaluation phase, I'm running into an error saying that `FileNotFoundError: Couldn't find a module script at opencompass/accuracy.py. Module 'accuracy' doesn't exist on the Hugging Face Hub either.`
92
+
93
+ HuggingFace tries to load the metric (e.g. `accuracy`) as an module online, and it could fail if the network is unreachable. Please refer to [Network-Q1](#my-tasks-failed-with-error-connection-aborted-connectionreseterror104-connection-reset-by-peer-or-urllib3exceptionsmaxretryerror-httpsconnectionpoolhostcdn-lfshuggingfaceco-port443) for guidelines to fix your network issue.
94
+
95
+ The issue has been fixed in the latest version of OpenCompass, so you might also consider pull from the latest version.
96
+
97
+ ## Efficiency
98
+
99
+ ### Why does OpenCompass partition each evaluation request into tasks?
100
+
101
+ Given the extensive evaluation time and the vast quantity of datasets, conducting a comprehensive linear evaluation on LLM models can be immensely time-consuming. To address this, OpenCompass divides the evaluation request into multiple independent "tasks". These tasks are then dispatched to various GPU groups or nodes, achieving full parallelism and maximizing the efficiency of computational resources.
102
+
103
+ ### How does task partitioning work?
104
+
105
+ Each task in OpenCompass represents a combination of specific model(s) and portions of the dataset awaiting evaluation. OpenCompass offers a variety of task partitioning strategies, each tailored for different scenarios. During the inference stage, the prevalent partitioning method seeks to balance task size, or computational cost. This cost is heuristically derived from the dataset size and the type of inference.
106
+
107
+ ### Why does it take more time to evaluate LLM models on OpenCompass?
108
+
109
+ There is a tradeoff between the number of tasks and the time to load the model. For example, if we partition an request that evaluates a model against a dataset into 100 tasks, the model will be loaded 100 times in total. When resources are abundant, these 100 tasks can be executed in parallel, so the additional time spent on model loading can be ignored. However, if resources are limited, these 100 tasks will operate more sequentially, and repeated loadings can become a bottleneck in execution time.
110
+
111
+ Hence, if users find that the number of tasks greatly exceeds the available GPUs, we advise setting the `--max-partition-size` to a larger value.
112
+
113
+ ## Model
114
+
115
+ ### How to use the downloaded huggingface models?
116
+
117
+ If you have already downloaded the checkpoints of the model, you can specify the local path of the model. For example:
118
+
119
+ ```bash
120
+ python run.py --datasets siqa_gen winograd_ppl --hf-type base --hf-path /path/to/model
121
+ ```
122
+
123
+ ## Dataset
124
+
125
+ ### How to build a new dataset?
126
+
127
+ - For building new objective dataset: [new_dataset](../advanced_guides/new_dataset.md)
128
+ - For building new subjective dataset: [subjective_evaluation](../advanced_guides/subjective_evaluation.md)
docs/en/get_started/installation.md ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Installation
2
+
3
+ ## Basic Installation
4
+
5
+ 1. Prepare the OpenCompass runtime environment using Conda:
6
+
7
+ ```bash
+ conda create --name opencompass python=3.10 -y
8
+ # conda create --name opencompass_lmdeploy python=3.10 -y
9
+
10
+ conda activate opencompass
11
+ ```
12
+
13
+ If you want to customize the PyTorch version or related CUDA version, please refer to the [official documentation](https://pytorch.org/get-started/locally/) to set up the PyTorch environment. Note that OpenCompass requires `pytorch>=1.13`.
14
+
15
+ 2. Install OpenCompass:
16
+ - pip Installation
17
+ ```bash
18
+ # For support of most datasets and models
19
+ pip install -U opencompass
20
+
21
+ # Complete installation (supports more datasets)
22
+ # pip install "opencompass[full]"
23
+
24
+ # API Testing (e.g., OpenAI, Qwen)
25
+ # pip install "opencompass[api]"
26
+ ```
27
+ - Building from Source Code. If you want to use the latest features of OpenCompass:
28
+ ```bash
29
+ git clone https://github.com/open-compass/opencompass opencompass
30
+ cd opencompass
31
+ pip install -e .
32
+ ```
33
+
34
+ ## Other Installations
35
+
36
+ ### Inference Backends
37
+
38
+ ```bash
39
+ # Model inference backends. Since these backends often have dependency conflicts,
40
+ # we recommend using separate virtual environments to manage them.
41
+ pip install "opencompass[lmdeploy]"
42
+ # pip install "opencompass[vllm]"
43
+ ```
44
+
45
+ - LMDeploy
46
+
47
+ You can check if the inference backend has been installed successfully with the following command. For more information, refer to the [official documentation](https://lmdeploy.readthedocs.io/en/latest/get_started.html)
48
+
49
+ ```bash
50
+ lmdeploy chat internlm/internlm2_5-1_8b-chat --backend turbomind
51
+ ```
52
+
53
+ - vLLM
54
+
55
+ You can check if the inference backend has been installed successfully with the following command. For more information, refer to the [official documentation](https://docs.vllm.ai/en/latest/getting_started/quickstart.html)
56
+
57
+ ```bash
58
+ vllm serve facebook/opt-125m
59
+ ```
60
+
61
+ ### API
62
+
63
+ OpenCompass supports different commercial model API calls, which you can install via pip or by referring to the [API dependencies](https://github.com/open-compass/opencompass/blob/main/requirements/api.txt) for specific API model dependencies.
64
+
65
+ ```bash
66
+ pip install "opencompass[api]"
67
+
68
+ # pip install openai # GPT-3.5-Turbo / GPT-4-Turbo / GPT-4 / GPT-4o (API)
69
+ # pip install anthropic # Claude (API)
70
+ # pip install dashscope # Qwen (API)
71
+ # pip install volcengine-python-sdk # ByteDance Volcano Engine (API)
72
+ # ...
73
+ ```
74
+
75
+ ### Datasets
76
+
77
+ The basic installation supports most fundamental datasets. For certain datasets (e.g., Alpaca-eval, Longbench, etc.), additional dependencies need to be installed.
78
+
79
+ You can install these through pip or refer to the [additional dependencies](https://github.com/open-compass/opencompass/blob/main/requirements/extra.txt) for specific dependencies.
80
+
81
+ ```bash
82
+ pip install "opencompass[full]"
83
+ ```
84
+
85
+ For HumanEvalX / HumanEval+ / MBPP+, you need to manually clone the Git repository and install it.
86
+
87
+ ```bash
88
+ git clone --recurse-submodules git@github.com:open-compass/human-eval.git
89
+ cd human-eval
90
+ pip install -e .
91
+ pip install -e evalplus
92
+ ```
93
+
94
+ Some agent evaluations require installing numerous dependencies, which may conflict with existing runtime environments. We recommend creating separate conda environments to manage these.
95
+
96
+ ```bash
97
+ # T-Eval
98
+ pip install lagent==0.1.2
99
+ # CIBench
100
+ pip install -r requirements/agent.txt
101
+ ```
102
+
103
+ # Dataset Preparation
104
+
105
+ The datasets supported by OpenCompass mainly include three parts:
106
+
107
+ 1. Huggingface datasets: The [Huggingface Datasets](https://huggingface.co/datasets) provide a large number of datasets, which will **automatically download** when running with this option.
108
+
109
+
110
+ 2. ModelScope Datasets: [ModelScope OpenCompass Dataset](https://modelscope.cn/organization/opencompass) supports automatic downloading of datasets from ModelScope.
111
+
112
+ To enable this feature, set the environment variable: `export DATASET_SOURCE=ModelScope`. The available datasets include (sourced from OpenCompassData-core.zip):
113
+
114
+ ```plain
115
+ humaneval, triviaqa, commonsenseqa, tydiqa, strategyqa, cmmlu, lambada, piqa, ceval, math, LCSTS, Xsum, winogrande, openbookqa, AGIEval, gsm8k, nq, race, siqa, mbpp, mmlu, hellaswag, ARC, BBH, xstory_cloze, summedits, GAOKAO-BENCH, OCNLI, cmnli
116
+ ```
117
+
118
+ 3. Custom dataset: OpenCompass also provides some Chinese custom **self-built** datasets. Please run the following command to **manually download and extract** them.
119
+
120
+ Running the following commands to download the datasets and place them in the `${OpenCompass}/data` directory will complete the dataset preparation.
121
+
122
+ ```bash
123
+ # Run in the OpenCompass directory
124
+ wget https://github.com/open-compass/opencompass/releases/download/0.2.2.rc1/OpenCompassData-core-20240207.zip
125
+ unzip OpenCompassData-core-20240207.zip
126
+ ```
127
+
128
+ If you need to use the more comprehensive dataset (~500M) provided by OpenCompass, You can download and `unzip` it using the following command:
129
+
130
+ ```bash
131
+ # For proxy and resumable downloads, try `aria2c -x16 -s16 -k1M "http://ghfast.top/https://github.com/open-compass/opencompass/releases/download/0.2.2.rc1/OpenCompassData-complete-20240207.zip" `
132
+ wget https://github.com/open-compass/opencompass/releases/download/0.2.2.rc1/OpenCompassData-complete-20240207.zip
133
+ unzip OpenCompassData-complete-20240207.zip
134
+ cd ./data
135
+ find . -name "*.zip" -exec unzip "{}" \;
136
+ ```
137
+
138
+ The list of datasets included in both `.zip` can be found [here](https://github.com/open-compass/opencompass/releases/tag/0.2.2.rc1)
139
+
140
+ OpenCompass has supported most of the datasets commonly used for performance comparison, please refer to `configs/dataset` for the specific list of supported datasets.
141
+
142
+ For next step, please read [Quick Start](./quick_start.md).
docs/en/get_started/quick_start.md ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Quick Start
2
+
3
+ ![image](https://github.com/open-compass/opencompass/assets/22607038/d063cae0-3297-4fd2-921a-366e0a24890b)
4
+
5
+ ## Overview
6
+
7
+ OpenCompass provides a streamlined workflow for evaluating a model, which consists of the following stages: **Configure** -> **Inference** -> **Evaluation** -> **Visualization**.
8
+
9
+ **Configure**: This is your starting point. Here, you'll set up the entire evaluation process, choosing the model(s) and dataset(s) to assess. You also have the option to select an evaluation strategy, the computation backend, and define how you'd like the results displayed.
10
+
11
+ **Inference & Evaluation**: OpenCompass efficiently manages the heavy lifting, conducting parallel inference and evaluation on your chosen model(s) and dataset(s). The **Inference** phase is all about producing outputs from your datasets, whereas the **Evaluation** phase measures how well these outputs align with the gold standard answers. While this procedure is broken down into multiple "tasks" that run concurrently for greater efficiency, be aware that working with limited computational resources might introduce some unexpected overheads, resulting in generally slower evaluation. To understand this issue and know how to solve it, check out [FAQ: Efficiency](faq.md#efficiency).
12
+
13
+ **Visualization**: Once the evaluation is done, OpenCompass collates the results into an easy-to-read table and saves them as both CSV and TXT files. If you need real-time updates, you can activate lark reporting and get immediate status reports in your Lark clients.
14
+
15
+ Coming up, we'll walk you through the basics of OpenCompass, showcasing evaluations of pretrained models [OPT-125M](https://huggingface.co/facebook/opt-125m) and [OPT-350M](https://huggingface.co/facebook/opt-350m) on the [SIQA](https://huggingface.co/datasets/social_i_qa) and [Winograd](https://huggingface.co/datasets/winograd_wsc) benchmark tasks. Their configuration files can be found at [configs/eval_demo.py](https://github.com/open-compass/opencompass/blob/main/configs/eval_demo.py).
16
+
17
+ Before running this experiment, please make sure you have installed OpenCompass locally and it should run successfully under one _GTX-1660-6G_ GPU.
18
+ For larger parameterized models like Llama-7B, refer to other examples provided in the [configs directory](https://github.com/open-compass/opencompass/tree/main/configs).
19
+
20
+ ## Configuring an Evaluation Task
21
+
22
+ In OpenCompass, each evaluation task consists of the model to be evaluated and the dataset. The entry point for evaluation is `run.py`. Users can select the model and dataset to be tested either via command line or configuration files.
23
+
24
+ `````{tabs}
25
+ ````{tab} Command Line (Custom HF Model)
26
+
27
+ For HuggingFace models, users can set model parameters directly through the command line without additional configuration files. For instance, for the `facebook/opt-125m` model, you can evaluate it with the following command:
28
+
29
+ ```bash
30
+ python run.py --datasets siqa_gen winograd_ppl \
31
+ --hf-type base \
32
+ --hf-path facebook/opt-125m
33
+ ```
34
+
35
+ Note that in this way, OpenCompass only evaluates one model at a time, while other ways can evaluate multiple models at once.
36
+
37
+ ```{caution}
38
+ `--hf-num-gpus` does not stand for the actual number of GPUs to use in evaluation, but the minimum required number of GPUs for this model. [More](faq.md#how-does-opencompass-allocate-gpus)
39
+ ```
40
+
41
+ :::{dropdown} More detailed example
42
+ :animate: fade-in-slide-down
43
+ ```bash
44
+ python run.py --datasets siqa_gen winograd_ppl \
45
+ --hf-type base \ # HuggingFace model type, base or chat
46
+ --hf-path facebook/opt-125m \ # HuggingFace model path
47
+ --tokenizer-path facebook/opt-125m \ # HuggingFace tokenizer path (if the same as the model path, can be omitted)
48
+ --tokenizer-kwargs padding_side='left' truncation='left' trust_remote_code=True \ # Arguments to construct the tokenizer
49
+ --model-kwargs device_map='auto' \ # Arguments to construct the model
50
+ --max-seq-len 2048 \ # Maximum sequence length the model can accept
51
+ --max-out-len 100 \ # Maximum number of tokens to generate
52
+ --min-out-len 100 \ # Minimum number of tokens to generate
53
+ --batch-size 64 \ # Batch size
54
+ --hf-num-gpus 1 # Number of GPUs required to run the model
55
+ ```
56
+ ```{seealso}
57
+ For all HuggingFace related parameters supported by `run.py`, please read [Launching Evaluation Task](../user_guides/experimentation.md#launching-an-evaluation-task).
58
+ ```
59
+ :::
60
+
61
+ ````
62
+ ````{tab} Command Line
63
+
64
+ Users can combine the models and datasets they want to test using `--models` and `--datasets`.
65
+
66
+ ```bash
67
+ python run.py --models hf_opt_125m hf_opt_350m --datasets siqa_gen winograd_ppl
68
+ ```
69
+
70
+ The models and datasets are pre-stored in the form of configuration files in `configs/models` and `configs/datasets`. Users can view or filter the currently available model and dataset configurations using `tools/list_configs.py`.
71
+
72
+ ```bash
73
+ # List all configurations
74
+ python tools/list_configs.py
75
+ # List all configurations related to llama and mmlu
76
+ python tools/list_configs.py llama mmlu
77
+ ```
78
+
79
+ :::{dropdown} More about `list_configs`
80
+ :animate: fade-in-slide-down
81
+
82
+ Running `python tools/list_configs.py llama mmlu` gives the output like:
83
+
84
+ ```text
85
+ +-----------------+-----------------------------------+
86
+ | Model | Config Path |
87
+ |-----------------+-----------------------------------|
88
+ | hf_llama2_13b | configs/models/hf_llama2_13b.py |
89
+ | hf_llama2_70b | configs/models/hf_llama2_70b.py |
90
+ | ... | ... |
91
+ +-----------------+-----------------------------------+
92
+ +-------------------+---------------------------------------------------+
93
+ | Dataset | Config Path |
94
+ |-------------------+---------------------------------------------------|
95
+ | cmmlu_gen | configs/datasets/cmmlu/cmmlu_gen.py |
96
+ | cmmlu_gen_ffe7c0 | configs/datasets/cmmlu/cmmlu_gen_ffe7c0.py |
97
+ | ... | ... |
98
+ +-------------------+---------------------------------------------------+
99
+ ```
100
+
101
+ Users can use the names in the first column as input parameters for `--models` and `--datasets` in `python run.py`. For datasets, the same name with different suffixes generally indicates that its prompts or evaluation methods are different.
102
+ :::
103
+
104
+ :::{dropdown} Model not on the list?
105
+ :animate: fade-in-slide-down
106
+
107
+ If you want to evaluate other models, please check out the "Command Line (Custom HF Model)" tab for the way to construct a custom HF model without a configuration file, or "Configuration File" tab to learn the general way to prepare your model configurations.
108
+
109
+ :::
110
+
111
+ ````
112
+
113
+ ````{tab} Configuration File
114
+
115
+ In addition to configuring the experiment through the command line, OpenCompass also allows users to write the full configuration of the experiment in a configuration file and run it directly through `run.py`. The configuration file is organized in Python format and must include the `datasets` and `models` fields.
116
+
117
+ The test configuration for this time is [configs/eval_demo.py](https://github.com/open-compass/opencompass/blob/main/configs/eval_demo.py). This configuration introduces the required dataset and model configurations through the [inheritance mechanism](../user_guides/config.md#inheritance-mechanism) and combines the `datasets` and `models` fields in the required format.
118
+
119
+ ```python
120
+ from mmengine.config import read_base
121
+
122
+ with read_base():
123
+ from .datasets.siqa.siqa_gen import siqa_datasets
124
+ from .datasets.winograd.winograd_ppl import winograd_datasets
125
+ from .models.opt.hf_opt_125m import opt125m
126
+ from .models.opt.hf_opt_350m import opt350m
127
+
128
+ datasets = [*siqa_datasets, *winograd_datasets]
129
+ models = [opt125m, opt350m]
130
+ ```
131
+
132
+ When running tasks, we just need to pass the path of the configuration file to `run.py`:
133
+
134
+ ```bash
135
+ python run.py configs/eval_demo.py
136
+ ```
137
+
138
+ :::{dropdown} More about `models`
139
+ :animate: fade-in-slide-down
140
+
141
+ OpenCompass provides a series of pre-defined model configurations under `configs/models`. Below is the configuration snippet related to [opt-350m](https://github.com/open-compass/opencompass/blob/main/configs/models/opt/hf_opt_350m.py) (`configs/models/opt/hf_opt_350m.py`):
142
+
143
+ ```python
144
+ # Evaluate models supported by HuggingFace's `AutoModelForCausalLM` using `HuggingFaceBaseModel`
145
+ from opencompass.models import HuggingFaceBaseModel
146
+
147
+ models = [
148
+ # OPT-350M
149
+ dict(
150
+ type=HuggingFaceBaseModel,
151
+ # Initialization parameters for `HuggingFaceBaseModel`
152
+ path='facebook/opt-350m',
153
+ # Below are common parameters for all models, not specific to HuggingFaceBaseModel
154
+ abbr='opt-350m-hf', # Model abbreviation
155
+ max_out_len=1024, # Maximum number of generated tokens
156
+ batch_size=32, # Batch size
157
+ run_cfg=dict(num_gpus=1), # The required GPU numbers for this model
158
+ )
159
+ ]
160
+ ```
161
+
162
+ When using configurations, we can specify the relevant files through the command-line argument ` --models` or import the model configurations into the `models` list in the configuration file using the inheritance mechanism.
163
+
164
+ ```{seealso}
165
+ More information about model configuration can be found in [Prepare Models](../user_guides/models.md).
166
+ ```
167
+ :::
168
+
169
+ :::{dropdown} More about `datasets`
170
+ :animate: fade-in-slide-down
171
+
172
+ Similar to models, dataset configuration files are provided under `configs/datasets`. Users can use `--datasets` in the command line or import related configurations in the configuration file via inheritance.
173
+
174
+ Below is a dataset-related configuration snippet from `configs/eval_demo.py`:
175
+
176
+ ```python
177
+ from mmengine.config import read_base # Use mmengine.read_base() to read the base configuration
178
+
179
+ with read_base():
180
+ # Directly read the required dataset configurations from the preset dataset configurations
181
+ from .datasets.winograd.winograd_ppl import winograd_datasets # Read Winograd configuration, evaluated based on PPL (perplexity)
182
+ from .datasets.siqa.siqa_gen import siqa_datasets # Read SIQA configuration, evaluated based on generation
183
+
184
+ datasets = [*siqa_datasets, *winograd_datasets] # The final config needs to contain the required evaluation dataset list 'datasets'
185
+ ```
186
+
187
+ Dataset configurations are typically of two types: 'ppl' and 'gen', indicating the evaluation method used. Where `ppl` means discriminative evaluation and `gen` means generative evaluation.
188
+
189
+ Moreover, [configs/datasets/collections](https://github.com/open-compass/opencompass/blob/main/configs/datasets/collections) houses various dataset collections, making it convenient for comprehensive evaluations. OpenCompass often uses [`base_medium.py`](/configs/datasets/collections/base_medium.py) for full-scale model testing. To replicate results, simply import that file, for example:
190
+
191
+ ```bash
192
+ python run.py --models hf_llama_7b --datasets base_medium
193
+ ```
194
+
195
+ ```{seealso}
196
+ You can find more information from [Dataset Preparation](../user_guides/datasets.md).
197
+ ```
198
+ :::
199
+
200
+
201
+ ````
202
+
203
+ `````
204
+
205
+ ```{warning}
206
+ OpenCompass usually assumes network is available. If you encounter network issues or wish to run OpenCompass in an offline environment, please refer to [FAQ - Network - Q1](./faq.md#network) for solutions.
207
+ ```
208
+
209
+ The following sections will use configuration-based method as an example to explain the other features.
210
+
211
+ ## Launching Evaluation
212
+
213
+ Since OpenCompass launches evaluation processes in parallel by default, we can start the evaluation in `--debug` mode for the first run and check if there is any problem. In `--debug` mode, the tasks will be executed sequentially and output will be printed in real time.
214
+
215
+ ```bash
216
+ python run.py configs/eval_demo.py -w outputs/demo --debug
217
+ ```
218
+
219
+ The pretrained models 'facebook/opt-350m' and 'facebook/opt-125m' will be automatically downloaded from HuggingFace during the first run.
220
+ If everything is fine, you should see "Starting inference process" on screen:
221
+
222
+ ```bash
223
+ [2023-07-12 18:23:55,076] [opencompass.openicl.icl_inferencer.icl_gen_inferencer] [INFO] Starting inference process...
224
+ ```
225
+
226
+ Then you can press `ctrl+c` to interrupt the program, and run the following command in normal mode:
227
+
228
+ ```bash
229
+ python run.py configs/eval_demo.py -w outputs/demo
230
+ ```
231
+
232
+ In normal mode, the evaluation tasks will be executed in parallel in the background, and their output will be redirected to the output directory `outputs/demo/{TIMESTAMP}`. The progress bar on the frontend only indicates the number of completed tasks, regardless of their success or failure. **Any backend task failures will only trigger a warning message in the terminal.**
233
+
234
+ :::{dropdown} More parameters in `run.py`
235
+ :animate: fade-in-slide-down
236
+ Here are some parameters related to evaluation that can help you configure more efficient inference tasks based on your environment:
237
+
238
+ - `-w outputs/demo`: Work directory to save evaluation logs and results. In this case, the experiment result will be saved to `outputs/demo/{TIMESTAMP}`.
239
+ - `-r`: Reuse existing inference results, and skip the finished tasks. If followed by a timestamp, the result under that timestamp in the workspace path will be reused; otherwise, the latest result in the specified workspace path will be reused.
240
+ - `--mode all`: Specify a specific stage of the task.
241
+ - all: (Default) Perform a complete evaluation, including inference and evaluation.
242
+ - infer: Perform inference on each dataset.
243
+ - eval: Perform evaluation based on the inference results.
244
+ - viz: Display evaluation results only.
245
+ - `--max-partition-size 2000`: Dataset partition size. Some datasets may be large, and using this parameter can split them into multiple sub-tasks to efficiently utilize resources. However, if the partition is too fine, the overall speed may be slower due to longer model loading times.
246
+ - `--max-num-workers 32`: Maximum number of parallel tasks. In distributed environments such as Slurm, this parameter specifies the maximum number of submitted tasks. In a local environment, it specifies the maximum number of tasks executed in parallel. Note that the actual number of parallel tasks depends on the available GPU resources and may not be equal to this number.
247
+
248
+ If you are not performing the evaluation on your local machine but using a Slurm cluster, you can specify the following parameters:
249
+
250
+ - `--slurm`: Submit tasks using Slurm on the cluster.
251
+ - `--partition(-p) my_part`: Slurm cluster partition.
252
+ - `--retry 2`: Number of retries for failed tasks.
253
+
254
+ ```{seealso}
255
+ The entry also supports submitting tasks to Alibaba Deep Learning Center (DLC), and more customized evaluation strategies. Please refer to [Launching an Evaluation Task](../user_guides/experimentation.md#launching-an-evaluation-task) for details.
256
+ ```
257
+
258
+ :::
259
+
260
+ ## Visualizing Evaluation Results
261
+
262
+ After the evaluation is complete, the evaluation results table will be printed as follows:
263
+
264
+ ```text
265
+ dataset version metric mode opt350m opt125m
266
+ --------- --------- -------- ------ --------- ---------
267
+ siqa e78df3 accuracy gen 21.55 12.44
268
+ winograd b6c7ed accuracy ppl 51.23 49.82
269
+ ```
270
+
271
+ All run outputs will be directed to `outputs/demo/` directory with following structure:
272
+
273
+ ```text
274
+ outputs/default/
275
+ ├── 20200220_120000
276
+ ├── 20230220_183030 # one experiment per folder
277
+ │ ├── configs # Dumped config files for record. Multiple configs may be kept if different experiments have been re-run on the same experiment folder
278
+ │ ├── logs # log files for both inference and evaluation stages
279
+ │ │ ├── eval
280
+ │ │ └── infer
281
+ │   ├── predictions # Prediction results for each task
282
+ │   ├── results # Evaluation results for each task
283
+ │   └── summary # Summarized evaluation results for a single experiment
284
+ ├── ...
285
+ ```
286
+
287
+ The summarization process can be further customized in configuration and output the averaged score of some benchmarks (MMLU, C-Eval, etc.).
288
+
289
+ More information about obtaining evaluation results can be found in [Results Summary](../user_guides/summarizer.md).
290
+
291
+ ## Additional Tutorials
292
+
293
+ To learn more about using OpenCompass, explore the following tutorials:
294
+
295
+ - [Prepare Datasets](../user_guides/datasets.md)
296
+ - [Prepare Models](../user_guides/models.md)
297
+ - [Task Execution and Monitoring](../user_guides/experimentation.md)
298
+ - [Understand Prompts](../prompt/overview.md)
299
+ - [Results Summary](../user_guides/summarizer.md)
300
+ - [Learn about Config](../user_guides/config.md)
docs/en/index.rst ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Welcome to OpenCompass' documentation!
2
+ ==========================================
3
+
4
+ Getting started with OpenCompass
5
+ -------------------------------
6
+
7
+ To help you quickly get familiar with OpenCompass, we recommend you walk through the following documents in order:
8
+
9
+ - First read the GetStarted_ section to set up the environment, and run a mini experiment.
10
+
11
+ - Then learn its basic usage through the UserGuides_.
12
+
13
+ - If you want to tune the prompts, refer to the Prompt_.
14
+
15
+ - If you want to customize some modules, like adding a new dataset or model, we have provided the AdvancedGuides_.
16
+
17
+ - There are more handy tools, such as prompt viewer and lark bot reporter, all presented in Tools_.
18
+
19
+ We always welcome *PRs* and *Issues* for the betterment of OpenCompass.
20
+
21
+ .. _GetStarted:
22
+ .. toctree::
23
+ :maxdepth: 1
24
+ :caption: Get Started
25
+
26
+ get_started/installation.md
27
+ get_started/quick_start.md
28
+ get_started/faq.md
29
+
30
+ .. _UserGuides:
31
+ .. toctree::
32
+ :maxdepth: 1
33
+ :caption: User Guides
34
+
35
+ user_guides/framework_overview.md
36
+ user_guides/config.md
37
+ user_guides/datasets.md
38
+ user_guides/models.md
39
+ user_guides/evaluation.md
40
+ user_guides/experimentation.md
41
+ user_guides/metrics.md
42
+ user_guides/deepseek_r1.md
43
+ user_guides/interns1.md
44
+
45
+ .. _Prompt:
46
+ .. toctree::
47
+ :maxdepth: 1
48
+ :caption: Prompt
49
+
50
+ prompt/overview.md
51
+ prompt/prompt_template.md
52
+ prompt/meta_template.md
53
+ prompt/chain_of_thought.md
54
+
55
+
56
+ .. _AdvancedGuides:
57
+ .. toctree::
58
+ :maxdepth: 1
59
+ :caption: Advanced Guides
60
+
61
+ advanced_guides/new_dataset.md
62
+ advanced_guides/custom_dataset.md
63
+ advanced_guides/new_model.md
64
+ advanced_guides/evaluation_lmdeploy.md
65
+ advanced_guides/accelerator_intro.md
66
+ advanced_guides/math_verify.md
67
+ advanced_guides/llm_judge.md
68
+ advanced_guides/code_eval.md
69
+ advanced_guides/code_eval_service.md
70
+ advanced_guides/subjective_evaluation.md
71
+ advanced_guides/persistence.md
72
+
73
+ .. _Tools:
74
+ .. toctree::
75
+ :maxdepth: 1
76
+ :caption: Tools
77
+
78
+ tools.md
79
+
80
+ .. _Dataset List:
81
+ .. toctree::
82
+ :maxdepth: 1
83
+ :caption: Dataset List
84
+
85
+ dataset_statistics.md
86
+
87
+ .. _Notes:
88
+ .. toctree::
89
+ :maxdepth: 1
90
+ :caption: Notes
91
+
92
+ notes/contribution_guide.md
93
+ notes/academic.md
94
+
95
+ Indexes & Tables
96
+ ==================
97
+
98
+ * :ref:`genindex`
99
+ * :ref:`search`
docs/en/notes/academic.md ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Guide to Reproducing CompassAcademic Leaderboard Results
2
+
3
+ To provide users with a quick and intuitive overview of the performance of mainstream open-source and commercial models on widely-used datasets, we maintain the [CompassAcademic Leaderboard](https://rank.opencompass.org.cn/leaderboard-llm-academic/?m=REALTIME) for LLMs on our official website, updating it typically every two weeks.
4
+
5
+ Given the continuous iteration of models and datasets, along with ongoing upgrades to the OpenCompass, the configuration settings for the CompassAcademic leaderboard may evolve. Specifically, we adhere to the following update principles:
6
+
7
+ - Newly released models are promptly included, while models published six months to one year (or more) ago are removed from the leaderboard.
8
+ - New datasets are incorporated, while datasets nearing performance saturation are phased out.
9
+ - Existing evaluation results on the leaderboard are updated in sync with changes to the evaluation configuration.
10
+
11
+ To support rapid reproducibility, OpenCompass provides the real-time configuration files used in the academic leaderboard.
12
+
13
+ ## CompassAcademic Leaderboard Reproduction
14
+
15
+ [eval_academic_leaderboard_REALTIME.py](https://github.com/open-compass/opencompass/blob/main/examples/eval_academic_leaderboard_REALTIME.py) contains the configuration currently used for academic ranking evaluation. You can replicate the evaluation by following the steps as follows.
16
+
17
+ ### 1: Model Configs
18
+
19
+ Firstly, modify the Model List code block in [eval_academic_leaderboard_REALTIME.py](https://github.com/open-compass/opencompass/blob/main/examples/eval_academic_leaderboard_REALTIME.py) to include the model you wish to evaluate.
20
+
21
+ ```python
22
+ # Models (add your models here)
23
+ from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_5_7b_chat import \
24
+ models as hf_internlm2_5_7b_chat_model
25
+ ```
26
+
27
+ The original example calls an lmdeploy-based model configuration in OpenCompass.
28
+ You can also build your new model configuration based on [this document](https://opencompass.readthedocs.io/zh-cn/latest/user_guides/models.html).
29
+ An example of a configuration that calls the deployed service of Qwen3-235B-A22B based on OpenAISDK is as follows:
30
+
31
+ ```python
32
+ from opencompass.models import OpenAISDK
33
+ from opencompass.utils.text_postprocessors import extract_non_reasoning_content
34
+
35
+ qwen3_235b_a22b_model = dict(
36
+ abbr="qwen_3_235b_a22b_thinking", # Used to identify the model configuration
37
+ key="YOUR_SERVE_API_KEY",
38
+ openai_api_base="YOUR_SERVE_API_URL",
39
+ type=OpenAISDK, # The model configuration types, commonly used such as OpenAISDK, TurboMindModelwithChatTemplate, HuggingFacewithChatTemplate
40
+ path="Qwen/Qwen3-235B-A22B",
41
+ temperature=0.6,
42
+ meta_template=dict(
43
+ round=[
44
+ dict(role='HUMAN', api_role='HUMAN'),
45
+ dict(role='BOT', api_role='BOT', generate=True),
46
+ ],
47
+ ),
48
+ query_per_second=1,
49
+ max_out_len=32000,
50
+ max_seq_len=32768,
51
+ batch_size=8,
52
+ retry=10,
53
+ extra_body={
54
+ 'chat_template_kwargs': {'enable_thinking': True},
55
+ }, # Additional configurations of the model, such as the option in Qwen3 series to control whether they think or not
56
+ pred_postprocessor=dict(type=extract_non_reasoning_content), # adding this pred_postprocessor can extract the non-reasoning content from models that output with a think tag
57
+ )
58
+
59
+ models = [
60
+ qwen3_235b_a22b_model,
61
+ ]
62
+ ```
63
+
64
+ Here are the commonly used parameters for reference.
65
+
66
+ - `max_seq_len` = 65536 or 32768
67
+ - `max_out_len` = 64000 or 32000
68
+ - `temperature` = 0.6
69
+ - `top_p` = 0.95
70
+
71
+ ### 2: Verifier Configs
72
+
73
+ Complete your verifier model information in `judge_cfg`.
74
+ For detailed information about LLM verifiers, please refer to [this document](https://opencompass.readthedocs.io/zh-cn/latest/advanced_guides/llm_judge.html).
75
+ At present, CompassAcademic uses [CompassVerifier-32B](https://huggingface.co/opencompass/CompassVerifier-32B), here is the config example using OpenAISDK:
76
+
77
+ ```python
78
+ judge_cfg = dict(
79
+ abbr='CompassVerifier',
80
+ type=OpenAISDK,
81
+ path='opencompass/CompassVerifier-32B',
82
+ key='YOUR_API_KEY',
83
+ openai_api_base='YOUR_API_BASE',
84
+ meta_template=dict(
85
+ round=[
86
+ dict(role='HUMAN', api_role='HUMAN'),
87
+ dict(role='BOT', api_role='BOT', generate=True),
88
+ ]),
89
+ query_per_second=1,
90
+ batch_size=8,
91
+ temperature=0.001,
92
+ max_out_len=8192,
93
+ max_seq_len=32768,
94
+ mode='mid',
95
+ )
96
+ ```
97
+
98
+ ### 3: Execute evaluation
99
+
100
+ After completing the above configuration file, you can enter the following content in the CLI to start the evaluation:
101
+
102
+ ```bash
103
+ opencompass examples/eval_academic_leaderboard_REALTIME.py
104
+ ```
105
+
106
+ For more detailed CLI parameters, please refer to [this document](https://opencompass.readthedocs.io/zh-cn/latest/user_guides/experimentation.html).
docs/en/notes/contribution_guide.md ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to OpenCompass
2
+
3
+ - [Contributing to OpenCompass](#contributing-to-opencompass)
4
+ - [What is PR](#what-is-pr)
5
+ - [Basic Workflow](#basic-workflow)
6
+ - [Procedures in detail](#procedures-in-detail)
7
+ - [1. Get the most recent codebase](#1-get-the-most-recent-codebase)
8
+ - [2. Checkout a new branch from `main` branch](#2-checkout-a-new-branch-from-main-branch)
9
+ - [3. Commit your changes](#3-commit-your-changes)
10
+ - [4. Push your changes to the forked repository and create a PR](#4-push-your-changes-to-the-forked-repository-and-create-a-pr)
11
+ - [5. Discuss and review your code](#5-discuss-and-review-your-code)
12
+ - [6. Merge your branch to `main` branch and delete the branch](#6--merge-your-branch-to-main-branch-and-delete-the-branch)
13
+ - [Code style](#code-style)
14
+ - [Python](#python)
15
+ - [About Contributing Test Datasets](#about-contributing-test-datasets)
16
+
17
+ Thanks for your interest in contributing to OpenCompass! All kinds of contributions are welcome, including but not limited to the following.
18
+
19
+ - Fix typo or bugs
20
+ - Add documentation or translate the documentation into other languages
21
+ - Add new features and components
22
+
23
+ ## What is PR
24
+
25
+ `PR` is the abbreviation of `Pull Request`. Here's the definition of `PR` in the [official document](https://docs.github.com/en/github/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests) of Github.
26
+
27
+ ```
28
+ Pull requests let you tell others about changes you have pushed to a branch in a repository on GitHub. Once a pull request is opened, you can discuss and review the potential changes with collaborators and add follow-up commits before your changes are merged into the base branch.
29
+ ```
30
+
31
+ ## Basic Workflow
32
+
33
+ 1. Get the most recent codebase
34
+ 2. Checkout a new branch from `main` branch.
35
+ 3. Commit your changes ([Don't forget to use pre-commit hooks!](#3-commit-your-changes))
36
+ 4. Push your changes and create a PR
37
+ 5. Discuss and review your code
38
+ 6. Merge your branch to `main` branch
39
+
40
+ ## Procedures in detail
41
+
42
+ ### 1. Get the most recent codebase
43
+
44
+ - When you work on your first PR
45
+
46
+ Fork the OpenCompass repository: click the **fork** button at the top right corner of Github page
47
+ ![avatar](https://github.com/open-compass/opencompass/assets/22607038/851ed33d-02db-49c9-bf94-7c62eee89eb2)
48
+
49
+ Clone forked repository to local
50
+
51
+ ```bash
52
+ git clone git@github.com:XXX/opencompass.git
53
+ ```
54
+
55
+ Add source repository to upstream
56
+
57
+ ```bash
58
+ git remote add upstream git@github.com:InternLM/opencompass.git
59
+ ```
60
+
61
+ - After your first PR
62
+
63
+ Checkout the latest branch of the local repository and pull the latest branch of the source repository.
64
+
65
+ ```bash
66
+ git checkout main
67
+ git pull upstream main
68
+ ```
69
+
70
+ ### 2. Checkout a new branch from `main` branch
71
+
72
+ ```bash
73
+ git checkout main -b branchname
74
+ ```
75
+
76
+ ### 3. Commit your changes
77
+
78
+ - If you are a first-time contributor, please install and initialize pre-commit hooks from the repository root directory first.
79
+
80
+ ```bash
81
+ pip install -U pre-commit
82
+ pre-commit install
83
+ ```
84
+
85
+ - Commit your changes as usual. Pre-commit hooks will be triggered to stylize your code before each commit.
86
+
87
+ ```bash
88
+ # coding
89
+ git add [files]
90
+ git commit -m 'messages'
91
+ ```
92
+
93
+ ```{note}
94
+ Sometimes your code may be changed by pre-commit hooks. In this case, please remember to re-stage the modified files and commit again.
95
+ ```
96
+
97
+ ### 4. Push your changes to the forked repository and create a PR
98
+
99
+ - Push the branch to your forked remote repository
100
+
101
+ ```bash
102
+ git push origin branchname
103
+ ```
104
+
105
+ - Create a PR
106
+ ![avatar](https://github.com/open-compass/opencompass/assets/22607038/08feb221-b145-4ea8-8e20-05f143081604)
107
+
108
+ - Revise PR message template to describe your motivation and modifications made in this PR. You can also link the related issue to the PR manually in the PR message (For more information, checkout the [official guidance](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue)).
109
+
110
+ - You can also ask a specific person to review the changes you've proposed.
111
+
112
+ ### 5. Discuss and review your code
113
+
114
+ - Modify your codes according to reviewers' suggestions and then push your changes.
115
+
116
+ ### 6. Merge your branch to `main` branch and delete the branch
117
+
118
+ - After the PR is merged by the maintainer, you can delete the branch you created in your forked repository.
119
+
120
+ ```bash
121
+ git branch -d branchname # delete local branch
122
+ git push origin --delete branchname # delete remote branch
123
+ ```
124
+
125
+ ## Code style
126
+
127
+ ### Python
128
+
129
+ We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style.
130
+
131
+ We use the following tools for linting and formatting:
132
+
133
+ - [flake8](https://github.com/PyCQA/flake8): A wrapper around some linter tools.
134
+ - [isort](https://github.com/timothycrosley/isort): A Python utility to sort imports.
135
+ - [yapf](https://github.com/google/yapf): A formatter for Python files.
136
+ - [codespell](https://github.com/codespell-project/codespell): A Python utility to fix common misspellings in text files.
137
+ - [mdformat](https://github.com/executablebooks/mdformat): Mdformat is an opinionated Markdown formatter that can be used to enforce a consistent style in Markdown files.
138
+ - [docformatter](https://github.com/myint/docformatter): A formatter to format docstring.
139
+
140
+ Style configurations of yapf and isort can be found in [setup.cfg](https://github.com/open-mmlab/OpenCompass/blob/main/setup.cfg).
141
+
142
+ ## About Contributing Test Datasets
143
+
144
+ - Submitting Test Datasets
145
+ - Please implement logic for automatic dataset downloading in the code; or provide a method for obtaining the dataset in the PR. The OpenCompass maintainers will follow up accordingly. If the dataset is not yet public, please indicate so.
146
+ - Submitting Data Configuration Files
147
+ - Provide a README in the same directory as the data configuration. The README should include, but is not limited to:
148
+ - A brief description of the dataset
149
+ - The official link to the dataset
150
+ - Some test examples from the dataset
151
+ - Evaluation results of the dataset on relevant models
152
+ - Citation of the dataset
153
+ - (Optional) Summarizer of the dataset
154
+ - (Optional) If the testing process cannot be achieved simply by concatenating the dataset and model configuration files, a configuration file for conducting the test is also required.
155
+ - (Optional) If necessary, please add a description of the dataset in the relevant documentation sections. This is very necessary to help users understand the testing scheme. You can refer to the following types of documents in OpenCompass:
156
+ - [Circular Evaluation](../advanced_guides/circular_eval.md)
157
+ - [Code Evaluation](../advanced_guides/code_eval.md)
158
+ - [Contamination Assessment](../advanced_guides/contamination_eval.md)
docs/en/notes/news.md ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # News
2
+
3
+ - **\[2024.05.08\]** We supported the evaluation of 4 MoE models: [Mixtral-8x22B-v0.1](configs/models/mixtral/hf_mixtral_8x22b_v0_1.py), [Mixtral-8x22B-Instruct-v0.1](configs/models/mixtral/hf_mixtral_8x22b_instruct_v0_1.py), [Qwen1.5-MoE-A2.7B](configs/models/qwen/hf_qwen1_5_moe_a2_7b.py), [Qwen1.5-MoE-A2.7B-Chat](configs/models/qwen/hf_qwen1_5_moe_a2_7b_chat.py). Try them out now!
4
+ - **\[2024.04.30\]** We supported evaluating a model's compression efficiency by calculating its Bits per Character (BPC) metric on an [external corpora](configs/datasets/llm_compression/README.md) ([official paper](https://github.com/hkust-nlp/llm-compression-intelligence)). Check out the [llm-compression](configs/eval_llm_compression.py) evaluation config now! 🔥🔥🔥
5
+ - **\[2024.04.29\]** We report the performance of several famous LLMs on the common benchmarks, welcome to [documentation](https://opencompass.readthedocs.io/en/latest/user_guides/corebench.html) for more information! 🔥🔥🔥.
6
+ - **\[2024.04.26\]** We deprecated the multi-modality evaluating function from OpenCompass, the related implementation has moved to [VLMEvalKit](https://github.com/open-compass/VLMEvalKit), welcome to use! 🔥🔥🔥.
7
+ - **\[2024.04.26\]** We supported the evaluation of [ArenaHard](configs/eval_subjective_arena_hard.py) welcome to try!🔥🔥🔥.
8
+ - **\[2024.04.22\]** We supported the evaluation of [LLaMA3](configs/models/hf_llama/hf_llama3_8b.py) and [LLaMA3-Instruct](configs/models/hf_llama/hf_llama3_8b_instruct.py), welcome to try! 🔥🔥🔥
9
+ - **\[2024.02.29\]** We supported the MT-Bench, AlpacalEval and AlignBench, more information can be found [here](https://opencompass.readthedocs.io/en/latest/advanced_guides/subjective_evaluation.html)
10
+ - **\[2024.01.30\]** We release OpenCompass 2.0. Click [CompassKit](https://github.com/open-compass), [CompassHub](https://hub.opencompass.org.cn/home), and [CompassRank](https://rank.opencompass.org.cn/home) for more information !
11
+ - **\[2024.01.17\]** We supported the evaluation of [InternLM2](https://github.com/open-compass/opencompass/blob/main/configs/eval_internlm2_keyset.py) and [InternLM2-Chat](https://github.com/open-compass/opencompass/blob/main/configs/eval_internlm2_chat_keyset.py), InternLM2 showed extremely strong performance in these tests, welcome to try!
12
+ - **\[2024.01.17\]** We supported the needle in a haystack test with multiple needles, more information can be found [here](https://opencompass.readthedocs.io/en/latest/advanced_guides/needleinahaystack_eval.html#id8).
13
+ - **\[2023.12.28\]** We have enabled seamless evaluation of all models developed using [LLaMA2-Accessory](https://github.com/Alpha-VLLM/LLaMA2-Accessory), a powerful toolkit for comprehensive LLM development.
14
+ - **\[2023.12.22\]** We have released [T-Eval](https://github.com/open-compass/T-Eval), a step-by-step evaluation benchmark to gauge your LLMs on tool utilization. Welcome to our [Leaderboard](https://open-compass.github.io/T-Eval/leaderboard.html) for more details!
15
+ - **\[2023.12.10\]** We have released [VLMEvalKit](https://github.com/open-compass/VLMEvalKit), a toolkit for evaluating vision-language models (VLMs), currently support 20+ VLMs and 7 multi-modal benchmarks (including MMBench series).
16
+ - **\[2023.12.10\]** We have supported Mistral AI's MoE LLM: **Mixtral-8x7B-32K**. Welcome to [MixtralKit](https://github.com/open-compass/MixtralKit) for more details about inference and evaluation.
17
+ - **\[2023.11.22\]** We have supported many API-based models, include **Baidu, ByteDance, Huawei, 360**. Welcome to [Models](https://opencompass.readthedocs.io/en/latest/user_guides/models.html) section for more details.
18
+ - **\[2023.11.20\]** Thanks [helloyongyang](https://github.com/helloyongyang) for supporting the evaluation with [LightLLM](https://github.com/ModelTC/lightllm) as backend. Welcome to [Evaluation With LightLLM](https://opencompass.readthedocs.io/en/latest/advanced_guides/evaluation_lightllm.html) for more details.
19
+ - **\[2023.11.13\]** We are delighted to announce the release of OpenCompass v0.1.8. This version enables local loading of evaluation benchmarks, thereby eliminating the need for an internet connection. Please note that with this update, **you must re-download all evaluation datasets** to ensure accurate and up-to-date results.
20
+ - **\[2023.11.06\]** We have supported several API-based models, include **ChatGLM Pro@Zhipu, ABAB-Chat@MiniMax and Xunfei**. Welcome to [Models](https://opencompass.readthedocs.io/en/latest/user_guides/models.html) section for more details.
21
+ - **\[2023.10.24\]** We release a new benchmark for evaluating LLMs’ capabilities of having multi-turn dialogues. Welcome to [BotChat](https://github.com/open-compass/BotChat) for more details.
22
+ - **\[2023.09.26\]** We update the leaderboard with [Qwen](https://github.com/QwenLM/Qwen), one of the best-performing open-source models currently available, welcome to our [homepage](https://opencompass.org.cn) for more details.
23
+ - **\[2023.09.20\]** We update the leaderboard with [InternLM-20B](https://github.com/InternLM/InternLM), welcome to our [homepage](https://opencompass.org.cn) for more details.
24
+ - **\[2023.09.19\]** We update the leaderboard with WeMix-LLaMA2-70B/Phi-1.5-1.3B, welcome to our [homepage](https://opencompass.org.cn) for more details.
25
+ - **\[2023.09.18\]** We have released [long context evaluation guidance](docs/en/advanced_guides/longeval.md).
26
+ - **\[2023.09.08\]** We update the leaderboard with Baichuan-2/Tigerbot-2/Vicuna-v1.5, welcome to our [homepage](https://opencompass.org.cn) for more details.
27
+ - **\[2023.09.06\]** [**Baichuan2**](https://github.com/baichuan-inc/Baichuan2) team adopts OpenCompass to evaluate their models systematically. We deeply appreciate the community's dedication to transparency and reproducibility in LLM evaluation.
28
+ - **\[2023.09.02\]** We have supported the evaluation of [Qwen-VL](https://github.com/QwenLM/Qwen-VL) in OpenCompass.
29
+ - **\[2023.08.25\]** [**TigerBot**](https://github.com/TigerResearch/TigerBot) team adopts OpenCompass to evaluate their models systematically. We deeply appreciate the community's dedication to transparency and reproducibility in LLM evaluation.
30
+ - **\[2023.08.21\]** [**Lagent**](https://github.com/InternLM/lagent) has been released, which is a lightweight framework for building LLM-based agents. We are working with Lagent team to support the evaluation of general tool-use capability, stay tuned!
31
+ - **\[2023.08.18\]** We have supported evaluation for **multi-modality learning**, include **MMBench, SEED-Bench, COCO-Caption, Flickr-30K, OCR-VQA, ScienceQA** and so on. Leaderboard is on the road. Feel free to try multi-modality evaluation with OpenCompass !
32
+ - **\[2023.08.18\]** [Dataset card](https://opencompass.org.cn/dataset-detail/MMLU) is now online. We welcome new evaluation benchmarks to join OpenCompass!
33
+ - **\[2023.08.11\]** [Model comparison](https://opencompass.org.cn/model-compare/GPT-4,ChatGPT,LLaMA-2-70B,LLaMA-65B) is now online. We hope this feature offers deeper insights!
34
+ - **\[2023.08.11\]** We have supported [LEval](https://github.com/OpenLMLab/LEval).
35
+ - **\[2023.08.10\]** OpenCompass is compatible with [LMDeploy](https://github.com/InternLM/lmdeploy). Now you can follow this [instruction](https://opencompass.readthedocs.io/en/latest/advanced_guides/evaluation_lmdeploy.html#) to evaluate the accelerated models provide by the **Turbomind**.
36
+ - **\[2023.08.10\]** We have supported [Qwen-7B](https://github.com/QwenLM/Qwen-7B) and [XVERSE-13B](https://github.com/xverse-ai/XVERSE-13B) ! Go to our [leaderboard](https://opencompass.org.cn/leaderboard-llm) for more results! More models are welcome to join OpenCompass.
37
+ - **\[2023.08.09\]** Several new datasets(**CMMLU, TydiQA, SQuAD2.0, DROP**) are updated on our [leaderboard](https://opencompass.org.cn/leaderboard-llm)! More datasets are welcomed to join OpenCompass.
38
+ - **\[2023.08.07\]** We have added a [script](tools/eval_mmbench.py) for users to evaluate the inference results of [MMBench](https://opencompass.org.cn/MMBench)-dev.
39
+ - **\[2023.08.05\]** We have supported [GPT-4](https://openai.com/gpt-4)! Go to our [leaderboard](https://opencompass.org.cn/leaderboard-llm) for more results! More models are welcome to join OpenCompass.
40
+ - **\[2023.07.27\]** We have supported [CMMLU](https://github.com/haonan-li/CMMLU)! More datasets are welcome to join OpenCompass.
docs/en/prompt/chain_of_thought.md ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Chain of Thought
2
+
3
+ ## Background
4
+
5
+ During the process of reasoning, the CoT (Chain of Thought) method is an efficient way to help LLMs deal with complex questions, such as math problems and relational inference. In OpenCompass, we support multiple types of CoT methods.
6
+
7
+ ![image](https://github.com/open-compass/opencompass/assets/28834990/45d60e0e-02a1-49aa-b792-40a1f95f9b9e)
8
+
9
+ ## 1. Zero Shot CoT
10
+
11
+ You can change the `PromptTemplate` of the dataset config by simply adding *Let's think step by step* to create a Zero-Shot CoT prompt for your evaluation:
12
+
13
+ ```python
14
+ qa_infer_cfg = dict(
15
+ prompt_template=dict(
16
+ type=PromptTemplate,
17
+ template="Answer the question:\nQ: {question}?\nLet's think step by step:\n"
18
+ ),
19
+ retriever=dict(type=ZeroRetriever)
20
+ )
21
+ ```
22
+
23
+ ## 2. Few Shot CoT
24
+
25
+ Few-shot CoT can make it easier for LLMs to follow your instructions and produce better answers. For few-shot CoT, add your CoT template to `PromptTemplate` as in the following config to create a one-shot prompt:
26
+
27
+ ```python
28
+ qa_infer_cfg = dict(
29
+ prompt_template=dict(
30
+ type=PromptTemplate,
31
+ template=
32
+ '''Question: Mark's basketball team scores 25 2 pointers, 8 3 pointers and 10 free throws. Their opponents score double the 2 pointers but half the 3 pointers and free throws. What's the total number of points scored by both teams added together?
33
+ Let's think step by step
34
+ Answer:
35
+ Mark's team scores 25 2 pointers, meaning they scored 25*2= 50 points in 2 pointers.
36
+ His team also scores 8 3 pointers, meaning they scored 8*3= 24 points in 3 pointers
37
+ They scored 10 free throws, and free throws count as one point so they scored 10*1=10 points in free throws.
38
+ All together his team scored 50+24+10= 84 points
39
+ Mark's opponents scored double his team's number of 2 pointers, meaning they scored 50*2=100 points in 2 pointers.
40
+ His opponents scored half his team's number of 3 pointers, meaning they scored 24/2= 12 points in 3 pointers.
41
+ They also scored half Mark's team's points in free throws, meaning they scored 10/2=5 points in free throws.
42
+ All together Mark's opponents scored 100+12+5=117 points
43
+ The total score for the game is both team's scores added together, so it is 84+117=201 points
44
+ The answer is 201
45
+
46
+ Question: {question}\nLet's think step by step:\n{answer}
47
+ '''),
48
+ retriever=dict(type=ZeroRetriever)
49
+ )
50
+ ```
51
+
52
+ ## 3. Self-Consistency
53
+
54
+ The SC (Self-Consistency) method is proposed in [this paper](https://arxiv.org/abs/2203.11171), which samples multiple reasoning paths for the question and applies majority voting to the answers generated by LLMs. This method displays remarkable proficiency on reasoning tasks with high accuracy, but may consume more time and resources during inference because of the majority voting strategy. In OpenCompass, you can easily implement the SC method by replacing `GenInferencer` with `SCInferencer` in the dataset configuration and setting the corresponding parameters like:
55
+
56
+ ```python
57
+ # This SC gsm8k config can be found at: opencompass.configs.datasets.gsm8k.gsm8k_gen_a3e34a.py
58
+ gsm8k_infer_cfg = dict(
59
+ inferencer=dict(
60
+ type=SCInferencer, # Replace GenInferencer with SCInferencer.
61
+ generation_kwargs=dict(do_sample=True, temperature=0.7, top_k=40), # Set sample parameters to make sure model generate various output, only works for models load from HuggingFace now.
62
+ infer_type='SC',
63
+ sc_size = SAMPLE_SIZE
64
+ )
65
+ )
66
+ gsm8k_eval_cfg = dict(sc_size=SAMPLE_SIZE)
67
+ ```
68
+
69
+ ```{note}
70
+ OpenCompass defaults to use argmax for sampling the next token. Therefore, if the sampling parameters are not specified, the model's inference results will be completely consistent each time, and multiple rounds of evaluation will be ineffective.
71
+ ```
72
+
73
+ Where `SAMPLE_SIZE` is the number of reasoning paths in Self-Consistency; a higher value usually yields higher performance. The following figure from the original SC paper demonstrates the relation between the number of reasoning paths and performance on several reasoning tasks:
74
+
75
+ ![image](https://github.com/open-compass/opencompass/assets/28834990/05c7d850-7076-43ca-b165-e6251f9b3001)
76
+
77
+ From the figure, it can be seen that in different reasoning tasks, performance tends to improve as the number of reasoning paths increases. However, for some tasks, increasing the number of reasoning paths may reach a limit, and further increasing the number of paths may not bring significant performance improvement. Therefore, it is necessary to conduct experiments and adjustments on specific tasks to find the optimal number of reasoning paths that best suit the task.
78
+
79
+ ## 4. Tree-of-Thoughts
80
+
81
+ In contrast to the conventional CoT approach that considers only a single reasoning path, Tree-of-Thoughts (ToT) allows the language model to explore multiple diverse reasoning paths simultaneously. The model evaluates the reasoning process through self-assessment and makes global choices by conducting lookahead or backtracking when necessary. Specifically, this process is divided into the following four stages:
82
+
83
+ **1. Thought Decomposition**
84
+
85
+ Based on the nature of the problem, break down the problem into multiple intermediate steps. Each step can be a phrase, equation, or writing plan, depending on the nature of the problem.
86
+
87
+ **2. Thought Generation**
88
+
89
+ Assuming that solving the problem requires k steps, there are two methods to generate reasoning content:
90
+
91
+ - Independent sampling: For each state, the model independently extracts k reasoning contents from the CoT prompts, without relying on other reasoning contents.
92
+ - Sequential generation: Sequentially use "prompts" to guide the generation of reasoning content, where each reasoning content may depend on the previous one.
93
+
94
+ **3. Heuristic Evaluation**
95
+
96
+ Use heuristic methods to evaluate the contribution of each generated reasoning content to problem-solving. This self-evaluation is based on the model's self-feedback and involves designing prompts to have the model score multiple generated results.
97
+
98
+ **4. Search Algorithm Selection**
99
+
100
+ Based on the methods of generating and evaluating reasoning content, select an appropriate search algorithm. For example, you can use breadth-first search (BFS) or depth-first search (DFS) algorithms to systematically explore the thought tree, conducting lookahead and backtracking.
101
+
102
+ In OpenCompass, ToT parameters need to be set according to the requirements. Below is an example configuration for the 24-Point game from the [official paper](https://arxiv.org/pdf/2305.10601.pdf). Currently, ToT inference is supported only with Huggingface models:
103
+
104
+ ```python
105
+ # This ToT Game24 config can be found at: opencompass/configs/datasets/game24/game24_gen_8dfde3.py.
106
+ from opencompass.datasets import (Game24Dataset, game24_postprocess,
107
+ Game24Evaluator, Game24PromptWrapper)
108
+
109
+ generation_kwargs = dict(temperature=0.7)
110
+
111
+ game24_infer_cfg = dict(
112
+ prompt_template=dict(
113
+ type=PromptTemplate,
114
+ template='{input}'), # Directly pass the input content, as the Prompt needs to be specified in steps
115
+ retriever=dict(type=ZeroRetriever),
116
+ inferencer=dict(type=ToTInferencer, # Replace GenInferencer with ToTInferencer
117
+ generation_kwargs=generation_kwargs,
118
+ method_generate='propose', # Method for generating reasoning content, can be independent sampling (sample) or sequential generation (propose)
119
+ method_evaluate='value', # Method for evaluating reasoning content, can be voting (vote) or scoring (value)
120
+ method_select='greedy', # Method for selecting reasoning content, can be greedy (greedy) or random (sample)
121
+ n_evaluate_sample=3,
122
+ n_select_sample=5,
123
+ task_wrapper=dict(type=Game24PromptWrapper) # This Wrapper class includes the prompts for each step and methods for generating and evaluating reasoning content, needs customization according to the task
124
+ ))
125
+ ```
126
+
127
+ If you want to use the ToT method on a custom dataset, you'll need to make additional configurations in the `opencompass.datasets.YourDataConfig.py` file to set up the `YourDataPromptWrapper` class. This is required for handling the thought generation and heuristic evaluation step within the ToT framework. For reasoning tasks similar to the game 24-Point, you can refer to the implementation in `opencompass/datasets/game24.py` for guidance.