Add files using upload-large-folder tool
- build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/chunk_h_bwd.py +173 -0
- build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/chunk_h_fwd.py +173 -0
- build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/chunk_o_bwd.py +428 -0
- build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/chunk_o_fwd.py +123 -0
- build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/fused_recurrent.py +273 -0
- build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/naive.py +96 -0
- build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/wy_fast_bwd.py +164 -0
- build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/wy_fast_fwd.py +284 -0
- build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/__init__.py +7 -0
- build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/chunk.py +500 -0
- build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/fused_recurrent.py +452 -0
- build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/naive.py +69 -0
- build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/wy_fast.py +300 -0
- docs/en/.readthedocs.yaml +17 -0
- docs/en/Makefile +20 -0
- docs/en/_static/css/readthedocs.css +62 -0
- docs/en/_static/image/logo.svg +79 -0
- docs/en/_static/image/logo_icon.svg +31 -0
- docs/en/_static/js/custom.js +20 -0
- docs/en/_templates/404.html +18 -0
- docs/en/_templates/autosummary/class.rst +13 -0
- docs/en/_templates/callable.rst +14 -0
- docs/en/advanced_guides/accelerator_intro.md +142 -0
- docs/en/advanced_guides/circular_eval.md +113 -0
- docs/en/advanced_guides/code_eval.md +104 -0
- docs/en/advanced_guides/code_eval_service.md +224 -0
- docs/en/advanced_guides/contamination_eval.md +124 -0
- docs/en/advanced_guides/custom_dataset.md +267 -0
- docs/en/advanced_guides/evaluation_lightllm.md +71 -0
- docs/en/advanced_guides/evaluation_lmdeploy.md +88 -0
- docs/en/advanced_guides/llm_judge.md +370 -0
- docs/en/advanced_guides/longeval.md +169 -0
- docs/en/advanced_guides/math_verify.md +190 -0
- docs/en/advanced_guides/needleinahaystack_eval.md +138 -0
- docs/en/advanced_guides/new_dataset.md +105 -0
- docs/en/advanced_guides/new_model.md +73 -0
- docs/en/advanced_guides/objective_judgelm_evaluation.md +186 -0
- docs/en/advanced_guides/persistence.md +65 -0
- docs/en/advanced_guides/prompt_attack.md +108 -0
- docs/en/advanced_guides/subjective_evaluation.md +171 -0
- docs/en/conf.py +234 -0
- docs/en/docutils.conf +2 -0
- docs/en/get_started/faq.md +128 -0
- docs/en/get_started/installation.md +142 -0
- docs/en/get_started/quick_start.md +300 -0
- docs/en/index.rst +99 -0
- docs/en/notes/academic.md +106 -0
- docs/en/notes/contribution_guide.md +158 -0
- docs/en/notes/news.md +40 -0
- docs/en/prompt/chain_of_thought.md +127 -0
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/chunk_h_bwd.py
ADDED
@@ -0,0 +1,173 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from ....ops.utils import prepare_chunk_indices, prepare_chunk_offsets
from ....ops.utils.op import exp
from ....utils import check_shared_mem, use_cuda_graph


@triton.heuristics({
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'BK', 'BV', "V"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_bwd_kernel_dhu(
    qg,
    bg,
    w,
    gk,
    dht,
    dh0,
    do,
    dh,
    dv,
    dv2,
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV]
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_dh += tl.load(p_dht, boundary_check=(0, 1))

    mask_k = tl.arange(0, BK) < K
    for i_t in range(NT - 1, -1, -1):
        p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
        b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
        for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
            p_qg = tl.make_block_ptr(qg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
            p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            # [BK, BT]
            b_qg = tl.load(p_qg, boundary_check=(0, 1))
            # [BT, BK]
            b_bg = tl.load(p_bg, boundary_check=(0, 1))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            # [BT, V]
            b_do = tl.load(p_do, boundary_check=(0, 1))
            b_dv = tl.load(p_dv, boundary_check=(0, 1))
            b_dv2 = b_dv + tl.dot(b_bg, b_dh.to(b_bg.dtype))
            tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
            # [BK, BV]
            b_dh_tmp += tl.dot(b_qg, b_do.to(b_qg.dtype))
            b_dh_tmp += tl.dot(b_w, b_dv2.to(b_qg.dtype))
        last_idx = min((i_t + 1) * BT, T) - 1
        bg_last = tl.load(gk + ((bos + last_idx) * H + i_h) * K + tl.arange(0, BK), mask=mask_k)
        b_dh *= exp(bg_last)[:, None]
        b_dh += b_dh_tmp

    if USE_INITIAL_STATE:
        p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))


def chunk_dplr_bwd_dhu(
    qg: torch.Tensor,
    bg: torch.Tensor,
    w: torch.Tensor,
    gk: torch.Tensor,
    h0: torch.Tensor,
    dht: Optional[torch.Tensor],
    do: torch.Tensor,
    dv: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *qg.shape, do.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension larger than 256."
    # H100
    if check_shared_mem('hopper', qg.device.index):
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem('ampere', qg.device.index):  # A100
        BV = 32
        BC = 32
    else:  # others, e.g. RTX 4090
        BV = 16
        BC = 16

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)

    BC = min(BT, BC)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    dh = qg.new_empty(B, NT, H, K, V)
    dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
    dv2 = torch.zeros_like(dv)

    grid = (NK, NV, N * H)
    chunk_dplr_bwd_kernel_dhu[grid](
        qg=qg,
        bg=bg,
        w=w,
        gk=gk,
        dht=dht,
        dh0=dh0,
        do=do,
        dh=dh,
        dv=dv,
        dv2=dv2,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
    )
    return dh, dh0, dv2
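For a quick sanity check, here is a minimal smoke-test sketch for `chunk_dplr_bwd_dhu` (an illustration, not part of the diff: it assumes a CUDA device, that the package is importable under the `opencompass.tasks.fla2` prefix shown in the file path, and arbitrary example sizes):

import torch
from opencompass.tasks.fla2.ops.generalized_delta_rule.dplr.chunk_h_bwd import chunk_dplr_bwd_dhu

B, T, H, K, V = 2, 128, 4, 64, 64
qg, bg, w = (torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16) for _ in range(3))
gk = torch.randn(B, T, H, K, device='cuda', dtype=torch.float32)   # decay in log space
h0 = torch.randn(B, H, K, V, device='cuda', dtype=torch.float32)   # initial state
dht = torch.randn(B, H, K, V, device='cuda', dtype=torch.float32)  # gradient w.r.t. final state
do = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16)  # gradient w.r.t. output
dv = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16)

dh, dh0, dv2 = chunk_dplr_bwd_dhu(qg, bg, w, gk, h0, dht, do, dv, chunk_size=64)
# dh: [B, NT, H, K, V] per-chunk state gradients; dh0: [B, H, K, V]; dv2: [B, T, H, V]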
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/chunk_h_fwd.py
ADDED
@@ -0,0 +1,173 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from ....ops.utils import prepare_chunk_indices, prepare_chunk_offsets
from ....ops.utils.op import exp
from ....utils import check_shared_mem, use_cuda_graph


@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'BK', 'BV'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_kernel_h(
    kg,
    v,
    w,
    bg,
    u,
    v_new,
    gk,
    h,
    h0,
    ht,
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT
    o_k = i_k * BK + tl.arange(0, BK)

    # [BK, BV]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT):
        p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))

        b_hc = tl.zeros([BK, BV], dtype=tl.float32)
        # keeping the whole K dimension in SRAM imposes a severe shared-memory burden; sub-chunking alleviates it
        for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
            p_kg = tl.make_block_ptr(kg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
            p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
            # [BK, BC]
            b_kg = tl.load(p_kg, boundary_check=(0, 1))
            b_v = tl.load(p_v, boundary_check=(0, 1))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_bg = tl.load(p_bg, boundary_check=(0, 1))
            b_v2 = tl.dot(b_w, b_h.to(b_w.dtype)) + tl.load(p_u, boundary_check=(0, 1))
            b_hc += tl.dot(b_kg, b_v)
            b_hc += tl.dot(b_bg.to(b_hc.dtype), b_v2)
            tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))

        last_idx = min((i_t + 1) * BT, T) - 1
        b_g_last = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k, mask=o_k < K).to(tl.float32)
        b_h *= exp(b_g_last[:, None])
        b_h += b_hc

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))


def chunk_dplr_fwd_h(
    kg: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    bg: torch.Tensor,
    gk: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *kg.shape, u.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension larger than 256."
    # H100 can have larger block size

    if check_shared_mem('hopper', kg.device.index):
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem('ampere', kg.device.index):  # A100
        BV = 32
        BC = 32
    else:
        BV = 16
        BC = 16

    BC = min(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    h = kg.new_empty(B, NT, H, K, V)
    final_state = kg.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    v_new = torch.empty_like(u)
    grid = (NK, NV, N * H)
    chunk_dplr_fwd_kernel_h[grid](
        kg=kg,
        v=v,
        w=w,
        bg=bg,
        u=u,
        v_new=v_new,
        h=h,
        gk=gk,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
    )
    return h, v_new, final_state
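Stated in plain PyTorch, the per-chunk update that `chunk_dplr_fwd_kernel_h` performs for a single head reduces to the sketch below (a reference restatement of the kernel's arithmetic under illustrative names; note the kernel stores the pre-update state for each chunk before applying the update):

import torch

def chunk_state_update_reference(h, kg_c, v_c, w_c, bg_c, u_c, gk_last):
    """One chunk of the recurrence in chunk_dplr_fwd_kernel_h, single head.

    h: [K, V] running state; kg_c, bg_c, w_c: [C, K]; v_c, u_c: [C, V];
    gk_last: [K] log-space decay at the last position of the chunk.
    """
    v_new_c = w_c @ h + u_c                    # b_v2 = w @ h + u
    h_c = kg_c.T @ v_c + bg_c.T @ v_new_c      # b_hc accumulation over sub-chunks
    h = torch.exp(gk_last)[:, None] * h + h_c  # decay the old state, then add
    return h, v_new_c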
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/chunk_o_bwd.py
ADDED
@@ -0,0 +1,428 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from ....ops.utils import prepare_chunk_indices
from ....ops.utils.op import exp
from ....utils import check_shared_mem, use_cuda_graph

BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32]


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BV', 'BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_bwd_kernel_dAu(
    v,
    do,
    v_new,
    A_qb,
    dA_qk,
    dA_qb,
    dv_new,
    cu_seqlens,
    chunk_indices,
    scale: tl.constexpr,
    T,
    H: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos

    b_dA_qk = tl.zeros([BT, BT], dtype=tl.float32)
    b_dA_qb = tl.zeros([BT, BT], dtype=tl.float32)

    p_A_qb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    b_A_qb = tl.load(p_A_qb, boundary_check=(0, 1))
    # causal mask
    b_A_qb = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_A_qb, 0.).to(b_A_qb.dtype)

    for i_v in range(tl.cdiv(V, BV)):
        p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1))
        p_v_new = tl.make_block_ptr(v_new + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1))
        p_dv_new = tl.make_block_ptr(dv_new + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
        b_dA_qk += tl.dot(b_do, b_v)
        b_dA_qb += tl.dot(b_do, b_v_new)
        b_dv_new = tl.dot(tl.trans(b_A_qb), b_do)
        # for recurrent
        tl.store(p_dv_new, b_dv_new.to(p_dv_new.dtype.element_ty), boundary_check=(0, 1))

    p_dA_qk = tl.make_block_ptr(dA_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_dA_qb = tl.make_block_ptr(dA_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
    b_dA_qk = tl.where(m_s, b_dA_qk * scale, 0.)
    tl.store(p_dA_qk, b_dA_qk.to(p_dA_qk.dtype.element_ty), boundary_check=(0, 1))
    b_dA_qb = tl.where(m_s, b_dA_qb * scale, 0.)
    tl.store(p_dA_qb, b_dA_qb.to(p_dA_qb.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'BK', 'BV'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def chunk_dplr_bwd_o_kernel(
    v,
    v_new,
    h,
    do,
    dh,
    dk,
    db,
    w,
    dq,
    dv,
    dw,
    gk,
    dgk_last,
    k,
    b,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    v += (bos * H + i_h) * V
    v_new += (bos * H + i_h) * V
    do += (bos * H + i_h) * V
    h += (i_tg * H + i_h) * K * V
    dh += (i_tg * H + i_h) * K * V
    dk += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    db += (bos * H + i_h) * K
    b += (bos * H + i_h) * K
    dw += (bos * H + i_h) * K
    dv += (bos * H + i_h) * V
    dq += (bos * H + i_h) * K
    w += (bos * H + i_h) * K

    dgk_last += (i_tg * H + i_h) * K
    gk += (bos * H + i_h) * K

    stride_qk = H*K
    stride_vo = H*V

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_dw = tl.zeros([BT, BK], dtype=tl.float32)
    b_db = tl.zeros([BT, BK], dtype=tl.float32)
    b_dgk_last = tl.zeros([BK], dtype=tl.float32)

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        b_dgk_last += tl.sum((b_h * b_dh).to(tl.float32), axis=0)

        # [BT, BV] @ [BV, BK] -> [BT, BK]
        b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
        # [BT, BV] @ [BV, BK] -> [BT, BK]
        b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
        b_db += tl.dot(b_v_new, b_dh.to(b_v_new.dtype))
        p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_dv = tl.load(p_dv, boundary_check=(0, 1))
        b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))

    m_k = (i_k*BK+tl.arange(0, BK)) < K
    last_idx = min(i_t * BT + BT, T) - 1
    b_gk_last = tl.load(gk + last_idx * stride_qk + i_k*BK + tl.arange(0, BK), mask=m_k, other=float('-inf'))
    b_dgk_last *= exp(b_gk_last)
    p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_b = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_b = tl.load(p_b, boundary_check=(0, 1))
    b_dgk_last += tl.sum(b_k * b_dk, axis=0)
    b_dgk_last += tl.sum(b_b * b_db, axis=0)
    tl.store(dgk_last + tl.arange(0, BK) + i_k * BK, b_dgk_last, mask=m_k)

    p_dw = tl.make_block_ptr(dw, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_db = tl.make_block_ptr(db, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
        for BK in BK_LIST
        for BV in BK_LIST
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def chunk_dplr_bwd_kernel_dv(
    A_qk,
    kg,
    do,
    dv,
    dh,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    b_dv = tl.zeros([BT, BV], dtype=tl.float32)

    # offset calculation
    A_qk += (bos * H + i_h) * BT
    do += (bos * H + i_h) * V
    dv += (bos * H + i_h) * V
    kg += (bos * H + i_h) * K
    dh += (i_tg * H + i_h) * K*V

    stride_qk = H*K
    stride_vo = H*V
    stride_A = H*BT

    for i_k in range(tl.cdiv(K, BK)):
        p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_kg = tl.make_block_ptr(kg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        b_kg = tl.load(p_kg, boundary_check=(0, 1))
        b_dv += tl.dot(b_kg, b_dh.to(b_kg.dtype))

    p_Aqk = tl.make_block_ptr(A_qk, (BT, T), (1, stride_A), (0, i_t * BT), (BT, BT), (0, 1))
    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], tl.load(p_Aqk, boundary_check=(0, 1)), 0)
    p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    b_dv += tl.dot(b_A.to(b_do.dtype), b_do)
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))


def chunk_dplr_bwd_dv(
    A_qk: torch.Tensor,
    kg: torch.Tensor,
    do: torch.Tensor,
    dh: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> torch.Tensor:
    B, T, H, K, V = *kg.shape, do.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    dv = torch.empty_like(do)

    def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
    chunk_dplr_bwd_kernel_dv[grid](
        A_qk=A_qk,
        kg=kg,
        do=do,
        dv=dv,
        dh=dh,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return dv


def chunk_dplr_bwd_o(
    k: torch.Tensor,
    b: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    gk: torch.Tensor,
    do: torch.Tensor,
    h: torch.Tensor,
    dh: torch.Tensor,
    dv: torch.Tensor,
    w: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
    scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

    B, T, H, K, V = *w.shape, v.shape[-1]

    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    BK = min(triton.next_power_of_2(K), 64) if check_shared_mem() else min(triton.next_power_of_2(K), 32)
    BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(V), 32)
    NK = triton.cdiv(K, BK)
    dq = torch.empty_like(k)
    dk = torch.empty_like(k)
    dw = torch.empty_like(w)
    db = torch.empty_like(b)
    grid = (NK, NT, B * H)

    dgk_last = torch.empty(B, NT, H, K, dtype=torch.float, device=w.device)

    chunk_dplr_bwd_o_kernel[grid](
        k=k,
        b=b,
        v=v,
        v_new=v_new,
        h=h,
        do=do,
        dh=dh,
        dq=dq,
        dk=dk,
        db=db,
        dgk_last=dgk_last,
        w=w,
        dv=dv,
        dw=dw,
        gk=gk,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return dq, dk, dw, db, dgk_last


def chunk_dplr_bwd_dAu(
    v: torch.Tensor,
    v_new: torch.Tensor,
    do: torch.Tensor,
    A_qb: torch.Tensor,
    scale: float,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    B, T, H, V = v.shape
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    if check_shared_mem('ampere'):  # A100
        BV = min(triton.next_power_of_2(V), 128)
    elif check_shared_mem('ada'):  # 4090
        BV = min(triton.next_power_of_2(V), 64)
    else:
        BV = min(triton.next_power_of_2(V), 32)

    grid = (NT, B * H)
    dA_qk = torch.empty(B, T, H, BT, dtype=torch.float, device=v.device)
    dA_qb = torch.empty(B, T, H, BT, dtype=torch.float, device=v.device)
    dv_new = torch.empty_like(v_new)
    chunk_dplr_bwd_kernel_dAu[grid](
        v=v,
        do=do,
        v_new=v_new,
        A_qb=A_qb,
        dA_qk=dA_qk,
        dA_qb=dA_qb,
        dv_new=dv_new,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        V=V,
        BT=BT,
        BV=BV,
    )
    return dv_new, dA_qk, dA_qb
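For a single head and chunk, the gradients assembled by `chunk_dplr_bwd_kernel_dAu` correspond to this PyTorch sketch (a reference restatement under illustrative names, not part of the module):

import torch

def dAu_reference(do_c, v_c, v_new_c, A_qb_c, scale):
    """Per-chunk gradients of chunk_dplr_bwd_kernel_dAu, single head.

    do_c, v_c, v_new_c: [C, V]; A_qb_c: [C, C]; scale: float.
    """
    dA_qk_c = torch.tril(do_c @ v_c.T * scale)      # b_dA_qk, causally masked
    dA_qb_c = torch.tril(do_c @ v_new_c.T * scale)  # b_dA_qb, causally masked
    dv_new_c = torch.tril(A_qb_c).T @ do_c          # b_dv_new, for the recurrent part
    return dA_qk_c, dA_qb_c, dv_new_c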
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/chunk_o_fwd.py
ADDED
@@ -0,0 +1,123 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional

import torch
import triton
import triton.language as tl

from ....ops.utils import prepare_chunk_indices
from ....utils import check_shared_mem, use_cuda_graph

BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32]


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BK_LIST
        for BV in BK_LIST
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_kernel_o(
    qg,
    v,
    v_new,
    A_qk,
    A_qb,
    h,
    o,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_qg = tl.load(p_qg, boundary_check=(0, 1))
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_o += tl.dot(b_qg, b_h)

    p_Aqk = tl.make_block_ptr(A_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_Aqb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_v_new = tl.make_block_ptr(v_new + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
    b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1))
    b_Aqb = tl.load(p_Aqb, boundary_check=(0, 1))
    b_Aqk = tl.where(m_s, b_Aqk, 0)
    b_Aqb = tl.where(m_s, b_Aqb, 0)
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
    b_o = b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_v_new.dtype), b_v_new)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


def chunk_dplr_fwd_o(
    qg: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    A_qk: torch.Tensor,
    A_qb: torch.Tensor,
    h: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> torch.Tensor:
    B, T, H, K, V = *qg.shape, v.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    o = torch.empty_like(v)

    def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
    chunk_dplr_fwd_kernel_o[grid](
        qg=qg,
        v=v,
        v_new=v_new,
        A_qk=A_qk,
        A_qb=A_qb,
        h=h,
        o=o,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return o
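The output assembly in `chunk_dplr_fwd_kernel_o` likewise has a compact PyTorch analogue for a single head and chunk (a reference sketch with illustrative names; `h_c` is the chunk-initial state written by the forward state kernel):

import torch

def chunk_output_reference(qg_c, h_c, A_qk_c, A_qb_c, v_c, v_new_c):
    """Per-chunk output of chunk_dplr_fwd_kernel_o, single head.

    qg_c: [C, K]; h_c: [K, V]; A_qk_c, A_qb_c: [C, C]; v_c, v_new_c: [C, V].
    """
    # inter-chunk contribution plus the two causally-masked intra-chunk terms
    return qg_c @ h_c + torch.tril(A_qk_c) @ v_c + torch.tril(A_qb_c) @ v_new_c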
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/fused_recurrent.py
ADDED
@@ -0,0 +1,273 @@
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
from typing import Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import triton
|
| 8 |
+
import triton.language as tl
|
| 9 |
+
|
| 10 |
+
from ....ops.utils.op import exp
|
| 11 |
+
from ....utils import autocast_custom_bwd, autocast_custom_fwd, input_guard, use_cuda_graph
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@triton.heuristics({
|
| 15 |
+
'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
|
| 16 |
+
'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
|
| 17 |
+
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
|
| 18 |
+
})
|
| 19 |
+
@triton.autotune(
|
| 20 |
+
configs=[
|
| 21 |
+
triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
|
| 22 |
+
for BV in [16, 32, 64]
|
| 23 |
+
for num_warps in [2, 4, 8, 16]
|
| 24 |
+
for num_stages in [2, 3, 4]
|
| 25 |
+
],
|
| 26 |
+
key=['BK'],
|
| 27 |
+
use_cuda_graph=use_cuda_graph,
|
| 28 |
+
)
|
| 29 |
+
@triton.jit(do_not_specialize=['T'])
|
| 30 |
+
def fused_recurrent_dplr_delta_rule_fwd_kernel(
|
| 31 |
+
q,
|
| 32 |
+
k,
|
| 33 |
+
v,
|
| 34 |
+
a,
|
| 35 |
+
b,
|
| 36 |
+
gk,
|
| 37 |
+
o,
|
| 38 |
+
h0,
|
| 39 |
+
ht,
|
| 40 |
+
cu_seqlens,
|
| 41 |
+
scale,
|
| 42 |
+
T,
|
| 43 |
+
B: tl.constexpr,
|
| 44 |
+
H: tl.constexpr,
|
| 45 |
+
K: tl.constexpr,
|
| 46 |
+
V: tl.constexpr,
|
| 47 |
+
BK: tl.constexpr,
|
| 48 |
+
BV: tl.constexpr,
|
| 49 |
+
REVERSE: tl.constexpr,
|
| 50 |
+
USE_INITIAL_STATE: tl.constexpr,
|
| 51 |
+
STORE_FINAL_STATE: tl.constexpr,
|
| 52 |
+
IS_VARLEN: tl.constexpr,
|
| 53 |
+
):
|
| 54 |
+
i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
|
| 55 |
+
i_n, i_h = i_nh // H, i_nh % H
|
| 56 |
+
|
| 57 |
+
if IS_VARLEN:
|
| 58 |
+
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
|
| 59 |
+
T = eos - bos
|
| 60 |
+
else:
|
| 61 |
+
bos, eos = i_n * T, i_n * T + T
|
| 62 |
+
|
| 63 |
+
o_k = tl.arange(0, BK)
|
| 64 |
+
o_v = i_v * BV + tl.arange(0, BV)
|
| 65 |
+
p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
|
| 66 |
+
p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
|
| 67 |
+
p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
|
| 68 |
+
p_b = b + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
|
| 69 |
+
p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
|
| 70 |
+
p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
|
| 71 |
+
p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
|
| 72 |
+
|
| 73 |
+
mask_k = o_k < K
|
| 74 |
+
mask_v = o_v < V
|
| 75 |
+
mask_h = mask_k[None, :] & mask_v[:, None]
|
| 76 |
+
b_h = tl.zeros([BV, BK], dtype=tl.float32)
|
| 77 |
+
|
| 78 |
+
if USE_INITIAL_STATE:
|
| 79 |
+
p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
|
| 80 |
+
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
|
| 81 |
+
|
| 82 |
+
for _ in range(0, T):
|
| 83 |
+
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
|
| 84 |
+
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
|
| 85 |
+
b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
|
| 86 |
+
b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
|
| 87 |
+
b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
|
| 88 |
+
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
|
| 89 |
+
|
| 90 |
+
tmp = tl.sum(b_h * b_a[None, :], axis=1)
|
| 91 |
+
b_h = exp(b_gk)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
|
| 92 |
+
b_o = tl.sum(b_h * b_q[None, :], axis=1)
|
| 93 |
+
|
| 94 |
+
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
|
| 95 |
+
p_q += (-1 if REVERSE else 1) * H*K
|
| 96 |
+
p_k += (-1 if REVERSE else 1) * H*K
|
| 97 |
+
p_a += (-1 if REVERSE else 1) * H*K
|
| 98 |
+
p_b += (-1 if REVERSE else 1) * H*K
|
| 99 |
+
p_gk += (-1 if REVERSE else 1) * H*K
|
| 100 |
+
p_v += (-1 if REVERSE else 1) * H*V
|
| 101 |
+
p_o += (-1 if REVERSE else 1) * H*V
|
| 102 |
+
|
| 103 |
+
if STORE_FINAL_STATE:
|
| 104 |
+
p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
|
| 105 |
+
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def fused_recurrent_dplr_delta_rule_fwd(
|
| 109 |
+
q: torch.Tensor,
|
| 110 |
+
k: torch.Tensor,
|
| 111 |
+
v: torch.Tensor,
|
| 112 |
+
a: torch.Tensor,
|
| 113 |
+
b: torch.Tensor,
|
| 114 |
+
gk: torch.Tensor,
|
| 115 |
+
scale: Optional[float] = 1.0,
|
| 116 |
+
initial_state: Optional[torch.Tensor] = None,
|
| 117 |
+
output_final_state: bool = False,
|
| 118 |
+
reverse: bool = False,
|
| 119 |
+
cu_seqlens: Optional[torch.LongTensor] = None,
|
| 120 |
+
):
|
| 121 |
+
B, T, H, K, V = *k.shape, v.shape[-1]
|
| 122 |
+
N = B if cu_seqlens is None else len(cu_seqlens) - 1
|
| 123 |
+
BK = triton.next_power_of_2(K)
|
| 124 |
+
|
| 125 |
+
h0 = initial_state
|
| 126 |
+
if output_final_state:
|
| 127 |
+
ht = q.new_empty(N, H, K, V, dtype=torch.float32)
|
| 128 |
+
else:
|
| 129 |
+
ht = None
|
| 130 |
+
o = torch.empty_like(v)
|
| 131 |
+
|
| 132 |
+
def grid(meta): return (triton.cdiv(V, meta['BV']), N * H)
|
| 133 |
+
fused_recurrent_dplr_delta_rule_fwd_kernel[grid](
|
| 134 |
+
q,
|
| 135 |
+
k,
|
| 136 |
+
v,
|
| 137 |
+
a,
|
| 138 |
+
b,
|
| 139 |
+
gk,
|
| 140 |
+
o,
|
| 141 |
+
h0,
|
| 142 |
+
ht,
|
| 143 |
+
cu_seqlens,
|
| 144 |
+
scale,
|
| 145 |
+
T=T,
|
| 146 |
+
B=B,
|
| 147 |
+
H=H,
|
| 148 |
+
K=K,
|
| 149 |
+
V=V,
|
| 150 |
+
BK=BK,
|
| 151 |
+
REVERSE=reverse,
|
| 152 |
+
)
|
| 153 |
+
return o, ht
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class FusedRecurrentDPLRDeltaRuleFunction(torch.autograd.Function):
|
| 157 |
+
|
| 158 |
+
@staticmethod
|
| 159 |
+
@input_guard
|
| 160 |
+
@autocast_custom_fwd
|
| 161 |
+
def forward(
|
| 162 |
+
ctx,
|
| 163 |
+
q: torch.Tensor,
|
| 164 |
+
k: torch.Tensor,
|
| 165 |
+
v: torch.Tensor,
|
| 166 |
+
a: torch.Tensor,
|
| 167 |
+
b: torch.Tensor,
|
| 168 |
+
gk: torch.Tensor,
|
| 169 |
+
scale: Optional[float] = 1.0,
|
| 170 |
+
initial_state: Optional[torch.Tensor] = None,
|
| 171 |
+
        output_final_state: bool = False,
        reverse: bool = False,
        cu_seqlens: Optional[torch.LongTensor] = None,
    ):
        o, ht = fused_recurrent_dplr_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            gk=gk,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            reverse=reverse,
            cu_seqlens=cu_seqlens,
        )
        return o, ht

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht):
        raise NotImplementedError(
            "Backward pass for fused_recurrent_dplr_delta_rule is not implemented and will not be supported. "
            "This kernel is only for inference. "
            "For training, please use `chunk_dplr_delta_rule`."
        )


def fused_recurrent_dplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: Optional[float] = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    This function computes the recurrence S_t = S_{t-1} @ (Diag(exp(gk_t)) + a_t b_t^T) + v_t k_t^T step by step.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        a (torch.Tensor):
            a of shape `[B, T, H, K]`.
        b (torch.Tensor):
            b of shape `[B, T, H, K]`.
        gk (torch.Tensor):
            gk of shape `[B, T, H, K]`. decay term in log space!
        scale (Optional[float]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `1.0`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        reverse (Optional[bool]):
            If `True`, process the state passing in reverse order. Default: `False`.
        cu_seqlens (Optional[torch.Tensor]):
            Cumulative sequence lengths of shape `[N + 1]` used for variable-length training,
            consistent with the FlashAttention API.
    """
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = q.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    o, final_state = FusedRecurrentDPLRDeltaRuleFunction.apply(
        q,
        k,
        v,
        a,
        b,
        gk,
        scale,
        initial_state,
        output_final_state,
        reverse,
        cu_seqlens,
    )
    return o, final_state
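A minimal usage sketch for the inference-only entry point above; the import path, shapes, dtype, and device below are assumptions for illustration, not part of the diff:

import torch
# hypothetical import path; adjust to wherever this module is installed
from opencompass.tasks.fla2.ops.generalized_delta_rule.dplr.fused_recurrent import fused_recurrent_dplr_delta_rule

B, T, H, K, V = 2, 128, 4, 64, 64
q, k, a, b = (torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16) for _ in range(4))
gk = -torch.rand(B, T, H, K, device='cuda', dtype=torch.bfloat16)  # log-space decay, kept non-positive
v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16)
# scale=None falls back to 1/sqrt(K) inside the wrapper
o, final_state = fused_recurrent_dplr_delta_rule(q, k, v, a, b, gk, scale=None, output_final_state=True)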
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/naive.py
ADDED
@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-

import torch
from einops import rearrange

# S_t = S_{t-1} @ (Diag(exp(gk_t)) + alpha_t beta_t^T) + v_t k_t^T
# q, k, alpha, beta [B, H, L, D_K]
# v [B, H, L, D_V]


def dplr_recurrence(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True):
    orig_dtype = q.dtype
    b, h, l, d_k = q.shape
    # run the reference recurrence in fp32 for numerical stability
    q, k, v, alpha, beta, gk = map(lambda x: x.float(), [q, k, v, alpha, beta, gk])
    d_v = v.shape[-1]
    o = torch.zeros_like(v)
    S = torch.zeros(b, h, d_k, d_v).to(v)
    q = q * (d_k ** -0.5)

    if initial_state is not None:
        S += initial_state

    for i in range(l):
        _k = k[:, :, i]
        _q = q[:, :, i]
        _v = v[:, :, i]
        _alpha = alpha[:, :, i].clone()
        _beta = beta[:, :, i].clone()
        _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None]
        S = S.clone() * gk[:, :, i].exp()[..., None] + _kv
        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
    S = None if output_final_state is False else S
    return o.to(orig_dtype), S


def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True, chunk_size=32):
    b, h, l, d_k = q.shape
    d_v = v.shape[-1]
    q = q * (d_k ** -0.5)
    assert l % chunk_size == 0

    S = k.new_zeros(b, h, d_k, d_v).to(q)
    if initial_state is not None:
        S += initial_state

    # note that diagonal is masked.
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
    q, k, v, alpha, beta, gk = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d',
                                                       c=chunk_size).float(), [q, k, v, alpha, beta, gk])

    gk_cumsum = gk.cumsum(-2)

    # v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v
    A_ab = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
    A_qk = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
    A_ak = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
    A_qb = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)

    for i in range(chunk_size):
        alpha_i = alpha[:, :, :, i, None]
        q_i = q[:, :, :, i, None]
        gk_i = gk_cumsum[:, :, :, i, None]
        mask = (torch.arange(chunk_size) <= i).to(q.device)
        attn_i = (gk_i - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
        A_qk[:, :, :, i, :] = (q_i * k * attn_i).sum(-1).clone()
        A_qb[:, :, :, i, :] = (q_i * beta * attn_i).sum(-1).clone()
        mask = (torch.arange(chunk_size) < i).to(q.device)
        # shift by one.
        attn_i = (gk_i - gk[:, :, :, i, None] - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
        A_ab[:, :, :, i, :] = (alpha_i * beta * attn_i).sum(-1).clone()
        A_ak[:, :, :, i, :] = (alpha_i * k * attn_i).sum(-1).clone()

    for i in range(1, chunk_size):
        A_ab[..., i, :i] = A_ab[..., i, :i].clone() + (A_ab[..., i, :, None].clone() * A_ab[..., :, :i].clone()).sum(-2)

    A_ab = A_ab + torch.eye(chunk_size, dtype=torch.float, device=q.device)
    u = A_ab @ (A_ak @ v)
    w = A_ab @ ((gk_cumsum - gk).exp() * alpha)

    o = torch.zeros_like(v)
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
    for i in range(0, l // chunk_size):
        q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i]
        v2_i = u_i + w_i @ S

        o_1 = A_qk[:, :, i] @ v_i
        o_2 = A_qb[:, :, i] @ v2_i
        o_3 = (q_i * gk_cumsum[:, :, i].exp()) @ S
        o[:, :, i] = o_1 + o_2 + o_3
        decay = (gk_cumsum[:, :, i, -1, None] - gk_cumsum[:, :, i]).exp()
        S = S * gk_cumsum[:, :, i, -1, :, None].exp() + (k_i * decay).transpose(-1, -2) @ v_i + \
            (beta_i * decay).transpose(-1, -2) @ v2_i
    S = None if output_final_state is False else S
    return rearrange(o, 'b h n c d -> b h (n c) d'), S
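The two references above should agree up to numerical tolerance; a quick CPU consistency-check sketch, where the normalized `alpha` with `beta = -alpha` is an assumed (but common) stable parameterization:

import torch
torch.manual_seed(0)
b, h, l, d_k, d_v = 1, 2, 64, 16, 16
q = torch.randn(b, h, l, d_k)
k = torch.randn(b, h, l, d_k)
v = torch.randn(b, h, l, d_v)
alpha = torch.nn.functional.normalize(torch.randn(b, h, l, d_k), p=2, dim=-1)
beta = -alpha
gk = -torch.rand(b, h, l, d_k)  # log-space decay, non-positive
o1, S1 = dplr_recurrence(q, k, v, alpha, beta, gk)
o2, S2 = dplr_chunkwise(q, k, v, alpha, beta, gk, chunk_size=32)
print((o1 - o2).abs().max(), (S1 - S2).abs().max())  # both should be small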
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/wy_fast_bwd.py
ADDED
@@ -0,0 +1,164 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from ....ops.utils import prepare_chunk_indices
from ....utils import check_shared_mem, is_intel_alchemist, use_cuda_graph

# https://github.com/intel/intel-xpu-backend-for-triton/issues/3449
triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {}


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config(triton_config, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'BK', 'BV'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_bwd_kernel(
    A_ab_inv,
    A_ak,
    ag,
    v,
    dw,
    du,
    dv,
    dv0,
    dag,
    dAak,
    dAab,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    p_Aak_t = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
    p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
    p_dAak = tl.make_block_ptr(dAak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_dAab = tl.make_block_ptr(dAab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    b_A_ab_inv_t = tl.load(p_Aab_inv_t, boundary_check=(0, 1))
    b_A_ak_t = tl.load(p_Aak_t, boundary_check=(0, 1))
    b_A_ak_t = tl.where(tl.arange(0, BT)[:, None] < tl.arange(0, BT)[None, :], b_A_ak_t, 0)
    b_A_ab_inv_t = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A_ab_inv_t, 0)
    b_A_tmp_t = tl.dot(b_A_ak_t, b_A_ab_inv_t).to(v.dtype.element_ty)
    b_dA_tmp = tl.zeros([BT, BT], dtype=tl.float32)

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv0 = tl.make_block_ptr(dv0 + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_du = tl.load(p_du, boundary_check=(0, 1))
        b_dA_tmp += tl.dot(b_du.to(b_v.dtype), tl.trans(b_v))
        b_dv0 = tl.load(p_dv0, boundary_check=(0, 1))
        b_dv = b_dv0 + tl.dot(b_A_tmp_t, b_du)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

    m_i = tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :]
    b_dA_tmp = tl.where(m_i, b_dA_tmp, 0)
    b_dA_ak = tl.dot(b_A_ab_inv_t, b_dA_tmp)
    b_dA_ak = tl.where(m_i, b_dA_ak, 0)
    tl.store(p_dAak, b_dA_ak, boundary_check=(0, 1))
    b_dA_ab_inv = tl.dot(b_dA_tmp, b_A_ak_t)

    for i_k in range(tl.cdiv(K, BK)):
        p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dag = tl.make_block_ptr(dag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_ag = tl.load(p_ag, boundary_check=(0, 1))
        b_dw = tl.load(p_dw, boundary_check=(0, 1))
        b_dA_ab_inv += tl.dot(b_dw, tl.trans(b_ag))
        b_dag = tl.dot(b_A_ab_inv_t.to(b_dw.dtype), b_dw)
        tl.store(p_dag, b_dag.to(p_dag.dtype.element_ty), boundary_check=(0, 1))

    # if we know dL/dA^(-1), for dL/dA, we can use the following formula:
    # dL/dA = -(A^(-1))^T @ (dL/dA^(-1)) @ (A^(-1))^T
    # in the fwd pass we use fwd substitution to calculate (I-lower(A_ab))^-1.
    # denote A = I - lower(A_ab), B = A^-1
    # in the backward pass:
    # dL/dA = -(B)^T @ (dL/dB) @ B^T
    # dL/dA_ab = lower(B^T @ dL/dB @ B^T)
    b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
    b_dA_ab_inv = tl.dot(b_A_ab_inv_t, b_dA_ab_inv)
    b_dA_ab_inv = tl.dot(b_dA_ab_inv, b_A_ab_inv_t)
    b_dA_ab_inv = tl.where(m_i, b_dA_ab_inv, 0)
    tl.store(p_dAab, b_dA_ab_inv, boundary_check=(0, 1))


def chunk_dplr_bwd_wy(
    A_ab_inv: torch.Tensor,
    A_ak: torch.Tensor,
    v: torch.Tensor,
    ag: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
    dv0: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    A_ab_inv, A_ak, v, ag, dw, du = map(lambda x: x.contiguous(), [A_ab_inv, A_ak, v, ag, dw, du])
    B, T, H, K, V = *dw.shape, du.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(V), 32)

    dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float)
    dA_ak = torch.empty_like(A_ak, dtype=torch.float)
    dv = torch.empty_like(v)
    dag = torch.empty_like(ag)

    prepare_wy_repr_bwd_kernel[(NT, B * H)](
        A_ab_inv=A_ab_inv,
        A_ak=A_ak,
        ag=ag,
        v=v,
        dw=dw,
        du=du,
        dv=dv,
        dv0=dv0,
        dag=dag,
        dAak=dA_ak,
        dAab=dA_ab,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return dA_ab, dA_ak, dv, dag
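The gradient identity quoted in the kernel's comments is easy to check against autograd; a standalone PyTorch sketch (the size n is an arbitrary assumption):

import torch

torch.manual_seed(0)
n = 16
P = torch.randn(n, n, dtype=torch.float64, requires_grad=True)
A = torch.eye(n, dtype=torch.float64) - P.tril(-1)   # A = I - lower(A_ab), as in the forward pass
B = torch.linalg.inv(A)                              # B = A^{-1}
dB = torch.randn(n, n, dtype=torch.float64)          # an arbitrary upstream dL/dB
(B * dB).sum().backward()                            # autograd gradient w.r.t. P
# manual: dL/dA = -B^T @ dB @ B^T, and since A = I - lower(A_ab),
# dL/dA_ab = lower(B^T @ dB @ B^T), matching the kernel
dA_ab = (B.mT @ dB @ B.mT).tril(-1)
print(torch.allclose(P.grad, dA_ab))                 # expected: True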
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/dplr/wy_fast_fwd.py
ADDED
@@ -0,0 +1,284 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from ....ops.utils import prepare_chunk_indices
from ....ops.utils.op import gather
from ....utils import is_gather_supported, use_cuda_graph


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk32(
    A_ab,
    A_ab_inv,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,  # placeholder, do not delete
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_Aab = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_Aab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_A_ab = tl.load(p_Aab, boundary_check=(0, 1))
    b_A_ab = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A_ab, 0)
    for i in range(1, BT):
        mask = tl.arange(0, BT) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A_ab, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A_ab, 0) * (tl.arange(0, BT) < i)
        b_A_ab = tl.where(mask[:, None], b_a, b_A_ab)
    b_A_ab += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
    tl.store(p_Aab_inv, b_A_ab.to(p_Aab_inv.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BC'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk64(
    A_ab,
    A_ab_inv,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    GATHER_SUPPORTED: tl.constexpr = is_gather_supported
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    p_A1 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
    p_A2 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
    p_A3 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
    p_A_inv1 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
    p_A_inv2 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
    p_A_inv3 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
    p_A_inv4 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))

    b_A = tl.load(p_A1, boundary_check=(0, 1))
    b_A2 = tl.load(p_A2, boundary_check=(0, 1))
    b_A3 = tl.load(p_A3, boundary_check=(0, 1))
    b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
    b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)

    for i in range(1, BC):
        if GATHER_SUPPORTED:
            row_idx = tl.full([1, BC], i, dtype=tl.int16)
            # [1, BK] -> [BK]
            b_a = tl.sum(gather(b_A, row_idx, axis=0), 0)
            b_a2 = tl.sum(gather(b_A2, row_idx, axis=0), 0)
        else:
            mask = tl.arange(0, BC) == i
            b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
            b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
        mask = tl.arange(0, BC) == i
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
        b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
        b_A2 = tl.where(mask[:, None], b_a2, b_A2)

    # blockwise computation of lower triangular matrix's inverse
    # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
    # here the full matrix is I - [L11, 0; A21, L22], so its lower-left block is -A21
    # and the minus sign of the generic formula cancels
    b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A3 = tl.dot(tl.dot(b_A2, b_A3), b_A)
    # tl.debug_barrier()
    tl.store(p_A_inv1, b_A.to(p_A_inv1.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_A_inv2, b_A2.to(p_A_inv2.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_A_inv3, b_A3.to(p_A_inv3.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    # the inverse of a lower-triangular matrix is lower-triangular: zero out the upper-right block
    tl.store(p_A_inv4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A_inv4.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def wu_fwd_kernel(
    w,
    u,
    ag,
    v,
    A_ab_inv,
    A_ak,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_s = tl.arange(0, BT)

    p_A_ab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_A_ak = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    b_Aab_inv = tl.load(p_A_ab_inv, boundary_check=(0, 1))
    b_Aak = tl.load(p_A_ak, boundary_check=(0, 1))
    b_Aab_inv = tl.where(o_s[:, None] >= o_s[None, :], b_Aab_inv, 0)
    b_Aak = tl.where(o_s[:, None] > o_s[None, :], b_Aak, 0)
    # let's use tf32 here
    b_Aak = tl.dot(b_Aab_inv, b_Aak)
    # (SY 01/04) should be bf16 or tf32? To verify.
    b_Aak = b_Aak.to(v.dtype.element_ty, fp_downcast_rounding="rtne")
    b_Aab_inv = b_Aab_inv.to(ag.dtype.element_ty, fp_downcast_rounding="rtne")

    for i_k in range(tl.cdiv(K, BK)):
        p_ag = tl.make_block_ptr(ag + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_ag = tl.load(p_ag, boundary_check=(0, 1))
        b_w = tl.dot(b_Aab_inv, b_ag)  # both bf16 or fp16
        tl.store(p_w, b_w.to(p_w.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_u = tl.dot(b_Aak, b_v)  # both bf16 or fp16
        tl.store(p_u, b_u.to(p_u.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))


def wu_fwd(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab_inv: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *ag.shape, v.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)

    w = torch.empty_like(ag)
    u = torch.empty_like(v)
    wu_fwd_kernel[(NT, B * H)](
        ag=ag,
        v=v,
        A_ak=A_ak,
        A_ab_inv=A_ab_inv,
        w=w,
        u=u,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return w, u


def prepare_wy_repr_fwd(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    B, T, H, _ = ag.shape
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    BC = min(BT, 32)
    fwd_fn = prepare_wy_repr_fwd_kernel_chunk64 if BT == 64 else prepare_wy_repr_fwd_kernel_chunk32
    A_ab_inv = torch.empty_like(A_ab)
    fwd_fn[(NT, B * H)](
        A_ab=A_ab,
        A_ab_inv=A_ab_inv,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        BC=BC,
    )
    w, u = wu_fwd(
        ag=ag,
        v=v,
        A_ak=A_ak,
        A_ab_inv=A_ab_inv,
        cu_seqlens=cu_seqlens,
        chunk_size=BT
    )
    return w, u, A_ab_inv


fwd_prepare_wy_repr = prepare_wy_repr_fwd

fwd_wu = wu_fwd
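The blockwise inversion used by the chunk-64 kernel can be sanity-checked in plain PyTorch. Because the matrix actually inverted is I minus a strictly lower-triangular block matrix, the lower-left block already carries a minus sign and the generic formula's minus cancels; a sketch with an assumed block size BC:

import torch

torch.manual_seed(0)
BC = 32
L = torch.randn(2 * BC, 2 * BC, dtype=torch.float64).tril(-1)
M = torch.eye(2 * BC, dtype=torch.float64) - L         # M = I - lower(A_ab)
A11_inv = torch.linalg.inv(M[:BC, :BC])                # top-left block inverse
A22_inv = torch.linalg.inv(M[BC:, BC:])                # bottom-right block inverse
lower_left = A22_inv @ L[BC:, :BC] @ A11_inv           # minus cancels: M[BC:, :BC] = -L[BC:, :BC]
blockwise = torch.zeros_like(M)
blockwise[:BC, :BC] = A11_inv
blockwise[BC:, BC:] = A22_inv
blockwise[BC:, :BC] = lower_left
print(torch.allclose(torch.linalg.inv(M), blockwise))  # expected: True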
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .chunk import chunk_iplr_delta_rule
from .fused_recurrent import fused_recurrent_iplr_delta_rule

__all__ = [
    'chunk_iplr_delta_rule',
    'fused_recurrent_iplr_delta_rule'
]
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/chunk.py
ADDED
@@ -0,0 +1,500 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import warnings
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl
from einops import rearrange

from ....ops.generalized_delta_rule.iplr.wy_fast import prepare_wy_repr_fwd
from ....ops.utils import prepare_chunk_indices, prepare_chunk_offsets
from ....utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard, use_cuda_graph

BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]


@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [2, 4, 8, 16]
    ],
    key=['BT', 'BK', 'BV'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_generalized_iplr_delta_rule_fwd_kernel_h(
    k,
    v,
    d,
    b,
    u,
    v_new,
    h,
    h0,
    ht,
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT):
        p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        b_hc = tl.zeros([BK, BV], dtype=tl.float32)
        # since we need to keep all of DK in SRAM, we face a severe SRAM memory burden; subchunking alleviates it
        for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
            p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_b = tl.make_block_ptr(b+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
            p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
            # [BK, BC]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_v = tl.load(p_v, boundary_check=(0, 1))
            b_d = tl.load(p_d, boundary_check=(0, 1))
            b_b = tl.load(p_b, boundary_check=(0, 1))
            b_v2 = tl.dot(b_d, b_h.to(b_d.dtype)) + tl.load(p_u, boundary_check=(0, 1))
            b_hc += tl.dot(b_k, b_v)
            b_hc += tl.dot(b_b, b_v2.to(b_k.dtype))
            tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
        b_h += b_hc

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BKV_LIST
        for BV in BKV_LIST
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_generalized_iplr_delta_rule_fwd_kernel_o(
    q,
    k,
    v,
    u,
    b,
    h,
    o,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    b += (bos * H + i_h) * K
    v += (bos * H + i_h) * V
    u += (bos * H + i_h) * V
    o += (bos * H + i_h) * V
    h += (i_tg * H + i_h) * K * V
    stride_qk = H*K
    stride_vo = H*V

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_Aqk = tl.zeros([BT, BT], dtype=tl.float32)
    b_Aqb = tl.zeros([BT, BT], dtype=tl.float32)

    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_b = tl.make_block_ptr(b, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_b = tl.load(p_b, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BK] @ [BK, BV] -> [BT, BV]
        b_o += tl.dot(b_q, b_h)
        # [BT, BK] @ [BK, BT] -> [BT, BT]
        b_Aqk += tl.dot(b_q, b_k)
        # [BT, BK] @ [BK, BT] -> [BT, BT]
        b_Aqb += tl.dot(b_q, b_b)

    o_i = tl.arange(0, BT)
    m_A = o_i[:, None] >= o_i[None, :]
    b_Aqk = tl.where(m_A, b_Aqk, 0)
    b_Aqb = tl.where(m_A, b_Aqb, 0)

    p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_u = tl.make_block_ptr(u, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_u = tl.load(p_u, boundary_check=(0, 1))
    b_o = (b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_u.dtype), b_u)) * scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


def chunk_generalized_iplr_delta_rule_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    b: torch.Tensor,
    h: torch.Tensor,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> torch.Tensor:
    B, T, H, K, V = *q.shape, v.shape[-1]
    if scale is None:
        scale = k.shape[-1] ** -0.5
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    o = torch.empty_like(v)

    def grid(meta):
        return (triton.cdiv(V, meta['BV']), NT, B * H)

    chunk_generalized_iplr_delta_rule_fwd_kernel_o[grid](
        q=q,
        k=k,
        v=v,
        u=v_new,
        b=b,
        h=h,
        o=o,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return o


def chunk_generalized_iplr_delta_rule_fwd_h(
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    b: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *k.shape, u.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)

    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension larger than 256."
    # H100 can have larger block size

    if check_shared_mem('hopper', k.device.index):
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem('ampere', k.device.index):  # A100
        BV = 32
        BC = 32
    else:
        BV = 16
        BC = 16

    BC = min(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)

    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    h = k.new_empty(B, NT, H, K, V)
    final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None

    v_new = torch.empty_like(u)
    grid = (NK, NV, N * H)

    chunk_generalized_iplr_delta_rule_fwd_kernel_h[grid](
        k=k,
        v=v,
        d=w,
        b=b,
        u=u,
        v_new=v_new,
        h=h,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
    )
    return h, v_new, final_state


def chunk_generalized_iplr_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    T = q.shape[1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    w, u, _ = prepare_wy_repr_fwd(
        a=a,
        b=b,
        k=k,
        v=v,
        cu_seqlens=cu_seqlens,
        chunk_size=BT
    )

    h, v_new, final_state = chunk_generalized_iplr_delta_rule_fwd_h(
        k=k,
        v=v,
        b=b,
        w=w,
        u=u,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        chunk_size=BT
    )
    o = chunk_generalized_iplr_delta_rule_fwd_o(
        q=q,
        k=k,
        v=v,
        v_new=v_new,
        b=b,
        h=h,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=BT
    )
    return o, final_state


class ChunkGeneralizedIPLRDeltaRuleFunction(torch.autograd.Function):

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: Optional[torch.LongTensor] = None,
    ):
        chunk_size = 64

        o, final_state = chunk_generalized_iplr_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            chunk_size=chunk_size
        )
        return o.to(q.dtype), final_state

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(
        ctx,
        do: torch.Tensor,
        dht: torch.Tensor
    ):
        raise NotImplementedError(
            "Backward pass for ChunkGeneralizedIPLRDeltaRuleFunction is not implemented yet. "
            "Stay tuned!"
        )


@torch.compiler.disable
def chunk_iplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
):
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        a (torch.Tensor):
            activations of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        b (torch.Tensor):
            betas of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        scale (Optional[float]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
    """
    assert q.dtype == k.dtype == v.dtype
    assert q.dtype != torch.float32, "ChunkGeneralizedIPLRDeltaRuleFunction does not support float32. Please use bfloat16."

    if head_first:
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
        q, k, v, a, b = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, a, b))
    if not head_first and q.shape[1] < q.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
        )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    scale = k.shape[-1] ** -0.5 if scale is None else scale
    o, final_state = ChunkGeneralizedIPLRDeltaRuleFunction.apply(
        q,
        k,
        v,
        a,
        b,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
    )
    if head_first:
        o = rearrange(o, 'b t h ... -> b h t ...')
    return o, final_state
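A minimal usage sketch for the chunkwise entry point above; the device, dtype, and the `b = -a` parameterization are assumptions for illustration:

import torch
import torch.nn.functional as F

B, T, H, K, V = 2, 256, 4, 64, 64
q = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16)
a = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1).to(torch.bfloat16)
b = -a  # a common identity-plus-low-rank parameterization; an assumption for this sketch
o, final_state = chunk_iplr_delta_rule(q, k, v, a, b, output_final_state=True)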
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/fused_recurrent.py
ADDED
@@ -0,0 +1,452 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from ....utils import input_guard


@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BV in [32, 64]
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=["BK"],
)
@triton.jit
def fused_recurrent_fwd_kernel(
    q,  # query [B, T, H, K]
    k,  # key [B, T, H, K]
    v,  # value [B, T, H, V]
    a,  # a [B, T, H, K]
    b,  # b [B, T, H, K]
    o,  # output [B, T, H, V]
    ha,  # tmp variable [B, T, H, V] for storing intermediate results of (h * a[None, :]).sum(0)
    h0,  # initial hidden state [B, H, K, V]
    ht,  # final hidden state [B, H, K, V]
    cu_seqlens,  # varlen cu_seqlens
    scale,  # K ** -0.5
    H,  # n_heads
    T,  # seq_len
    K: tl.constexpr,  # K
    V: tl.constexpr,  # V
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
    IS_VARLEN: tl.constexpr,
):
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H

    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T

    p_q = q + (bos * H + i_h) * K + tl.arange(0, BK)
    p_k = k + (bos * H + i_h) * K + tl.arange(0, BK)
    p_a = a + (bos * H + i_h) * K + tl.arange(0, BK)
    p_b = b + (bos * H + i_h) * K + tl.arange(0, BK)
    p_ha = ha + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
    p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
    p_o = o + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)

    mask_k = tl.arange(0, BK) < K
    mask_v = (i_v * BV + tl.arange(0, BV)) < V
    mask_h = mask_k[None, :] & mask_v[:, None]

    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        # to store
        tmp = tl.sum(b_h * b_a[None, :], axis=1)
        b_h += (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
        b_o = b_h * b_q[None, :]
        b_o = tl.sum(b_o, axis=1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
        tl.store(p_ha, tmp.to(p_ha.dtype.element_ty), mask=mask_v)
        p_q += K*H
        p_k += K*H
        p_o += V*H
        p_v += V*H
        p_ha += V*H
        p_a += K*H
        p_b += K*H

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)


@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_DHT': lambda args: args['dht'] is not None,
    'USE_DH0': lambda args: args['dh0'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3]
    ],
    key=["BK", "BV"],
)
@triton.jit
def fused_recurrent_bwd_kernel(
    # B: batch_size, H: n_heads, T: seq_len, D: head dim
    # NV: number of splits in the V dimension. NK: number of splits in the K dimension
    q,  # query [B, T, H, K]
    k,  # key [B, T, H, K]
    v,  # value [B, T, H, V]
    a,  # a [B, T, H, K]
    b,  # b [B, T, H, K]
    ha,  # ha [B, T, H, V]
    dht,  # gradient of final state [B, H, K, V]
    dh0,  # gradient of initial state [B, H, K, V]
    do,  # gradient of output [B, T, H, V]
    dq,  # gradient of query [NV, B, T, H, K]
    dk,  # gradient of key [NV, B, T, H, K]
    dv,  # gradient of value [NK, B, T, H, V]
    da,  # gradient of a [NV, B, T, H, K]
    db,  # gradient of b [NV, B, T, H, K]
    dha,  # gradient of ha [NK, B, T, H, V]
    h0,  # initial state [B, H, K, V]
    scale,  # K ** -0.5
    cu_seqlens,  # cu_seqlens
    B,  # batch_size
    H,  # n_heads
    T,  # seq_len
    K: tl.constexpr,  # K
    V: tl.constexpr,  # V
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state h0
    USE_DH0: tl.constexpr,  # whether to use dh0
    USE_DHT: tl.constexpr,  # whether to use dht
    IS_VARLEN: tl.constexpr,
):
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    dk += i_v * B * H * K * T
    db += i_v * B * H * K * T
    dq += i_v * B * H * K * T
    da += i_v * B * H * K * T
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
    mask_k = tl.arange(0, BK) < K
    mask_v = (tl.arange(0, BV) + i_v * BV) < V

    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V + i_v * BV
    ha += (bos * H + i_h) * V + i_v * BV
    a += (bos * H + i_h) * K
    b += (bos * H + i_h) * K
    do += (bos * H + i_h) * V + i_v * BV
    dq += (bos * H + i_h) * K
    dk += (bos * H + i_h) * K
    dv += (bos * H + i_h) * V + i_v * BV
    da += (bos * H + i_h) * K
    db += (bos * H + i_h) * K
    dha += (bos * H + i_h) * V + i_v * BV

    p_q = q + tl.arange(0, BK) + (T - 1) * H*K
    p_k = k + tl.arange(0, BK) + (T - 1) * H*K
    p_v = v + tl.arange(0, BV) + (T - 1) * H*V
    p_ha = ha + tl.arange(0, BV) + (T - 1) * H*V
    p_a = a + tl.arange(0, BK) + (T - 1) * H*K
    p_b = b + tl.arange(0, BK) + (T - 1) * H*K
    p_do = do + tl.arange(0, BV) + (T - 1) * H*V
    p_dk = dk + tl.arange(0, BK) + (T - 1) * H*K
    p_dv = dv + tl.arange(0, BV) + (T - 1) * H*V
    p_dha = dha + tl.arange(0, BV) + (T - 1) * H*V
    p_db = db + tl.arange(0, BK) + (T - 1) * H*K
    p_da = da + tl.arange(0, BK) + (T - 1) * H*K
    p_dq = dq + tl.arange(0, BK) + (T - 1) * H*K

    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_DHT:
        p_ht = dht + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
        b_dh += tl.load(p_ht, mask=mask_k[:, None] & mask_v[None, :], other=0).to(tl.float32)

    for _ in range(T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
        b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32)

        b_dh += b_q[:, None] * b_do[None, :]
        d_k = tl.sum(b_dh * b_v[None, :], axis=1)
        d_v = tl.sum(b_dh * b_k[:, None], axis=0)
        tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_k)
        tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_v)

        b_dha = tl.sum(b_dh * b_b[:, None], axis=0)
        tl.store(p_dha, b_dha.to(p_dha.dtype.element_ty), mask=mask_v)
        b_db = tl.sum(b_dh * b_ha[None, :], axis=1)
        tl.store(p_db, b_db.to(p_db.dtype.element_ty), mask=mask_k)

        b_dh += b_dha[None, :] * b_a[:, None]
        p_do -= H*V
        p_q -= H*K
        p_k -= H*K
        p_v -= H*V
        p_dk -= H*K
        p_dv -= H*V
        p_b -= H*K
        p_db -= H*K
        p_a -= H*K
        p_dha -= H*V
        p_ha -= H*V

    if USE_DH0:
        p_dh0 = dh0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :])

    tl.debug_barrier()

    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    if USE_INITIAL_STATE:
        mask_kv = mask_k[:, None] & mask_v[None, :]
        p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
        b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)

    p_k = k + tl.arange(0, BK)
    p_v = v + tl.arange(0, BV)
    p_ha = ha + tl.arange(0, BV)
    p_do = do + tl.arange(0, BV)
    p_dha = dha + tl.arange(0, BV)
    p_da = da + tl.arange(0, BK)
    p_dq = dq + tl.arange(0, BK)
    p_b = b + tl.arange(0, BK)

    for i in range(0, T):
        b_dha = tl.load(p_dha, mask=mask_v, other=0).to(tl.float32)
        d_a = tl.sum(b_dha[None, :] * b_h, axis=1)
        tl.store(p_da, d_a.to(p_da.dtype.element_ty), mask=mask_k)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32)
        b_h += b_k[:, None] * b_v[None, :] + b_b[:, None] * b_ha[None, :]
        _d_q = b_h * b_do[None, :]
        d_q = tl.sum(_d_q, axis=1) * scale
        tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_k)

        p_k += H*K
        p_do += H*V
        p_v += H*V
        p_da += H*K
        p_dha += H*V
        p_ha += H*V
        p_dq += H*K
        p_b += H*K


class FusedRecurrentIPLRDeltaRuleFunction(torch.autograd.Function):

    @staticmethod
    @input_guard
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        scale: Optional[float] = None,
        initial_state: Optional[torch.Tensor] = None,
        output_final_state: bool = False,
        cu_seqlens: Optional[torch.LongTensor] = None
    ):
        B, T, H, K, V = *k.shape, v.shape[-1]
        N = B if cu_seqlens is None else len(cu_seqlens) - 1

        BK = triton.next_power_of_2(K)
        if output_final_state:
            final_state = q.new_empty(B, H, K, V, dtype=torch.float32)
        else:
            final_state = None

        ha = torch.empty_like(v, dtype=torch.float32)

        def grid(meta): return (
            triton.cdiv(V, meta['BV']),
            N * H
        )
        o = torch.empty_like(v)
        fused_recurrent_fwd_kernel[grid](
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            o=o,
            ha=ha,
            h0=initial_state,
            ht=final_state,
            scale=scale,
            cu_seqlens=cu_seqlens,
            H=H,
            T=T,
            K=K,
            V=V,
            BK=BK,
        )
        ctx.save_for_backward(q, k, v, a, b, ha, initial_state)
        ctx.scale = scale
        ctx.cu_seqlens = cu_seqlens
        return o, final_state

    @staticmethod
    @input_guard
    def backward(ctx, do, dht):
        q, k, v, a, b, ha, initial_state = ctx.saved_tensors
        B, T, H, K, V = *q.shape, v.shape[-1]
        N = B if ctx.cu_seqlens is None else len(ctx.cu_seqlens) - 1
        BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 64)
        NV = triton.cdiv(V, BV)
        scale = ctx.scale

        dq = q.new_empty(NV, *q.shape)
        dk = k.new_empty(NV, *k.shape)
        da = a.new_empty(NV, *a.shape)
        db = b.new_empty(NV, *b.shape)
        dv = torch.empty_like(v)
        dha = torch.empty_like(ha)
        grid = (NV, N * H)

        if initial_state is not None and initial_state.requires_grad:
            dh0 = torch.empty_like(initial_state, dtype=torch.float32)
        else:
            dh0 = None

        fused_recurrent_bwd_kernel[grid](
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            ha=ha,
            dht=dht,
            dh0=dh0,
            do=do,
            dq=dq,
            dk=dk,
            dv=dv,
            da=da,
            db=db,
            dha=dha,
            h0=initial_state,
            scale=scale,
            cu_seqlens=ctx.cu_seqlens,
            B=B,
            H=H,
            T=T,
            K=K,
            V=V,
            BK=BK,
            BV=BV,
        )
        dq = dq.sum(0)
        dk = dk.sum(0)
        da = da.sum(0)
        db = db.sum(0)
        return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), None, dh0, None, None


def fused_recurrent_iplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    This function computes the recurrence S_t = S_{t-1} @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`
        v (torch.Tensor):
            values of shape `[B, T, H, V]`
        a (torch.Tensor):
            as of shape `[B, T, H, K]`
        b (torch.Tensor):
            bs of shape `[B, T, H, K]`
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[B, H, K, V]`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
    """
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = q.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    o, final_state = FusedRecurrentIPLRDeltaRuleFunction.apply(
        q,
        k,
        v,
        a,
        b,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens
    )
    return o, final_state
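
Editor's note: for orientation, a minimal smoke test of the public entry point above might look as follows. This is a sketch, not part of the diff; the shapes follow the docstring, while the CUDA device, dtype, and sizes are illustrative assumptions.

```python
import torch

# Hypothetical sizes: batch 2, seq len 128, 4 heads, K = V = 64.
B, T, H, K, V = 2, 128, 4, 64, 64
q, k, a, b = (torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16) for _ in range(4))
v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16)

o, final_state = fused_recurrent_iplr_delta_rule(q, k, v, a, b, output_final_state=True)
assert o.shape == (B, T, H, V)
assert final_state.shape == (B, H, K, V)
```
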
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/naive.py
ADDED
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-

import torch
from einops import rearrange


# S_t = S_{t-1} @ (I + alpha_t beta_t^T) + v_t k_t^T
# q, k, alpha, beta [B, H, L, D_K]
# v [B, H, L, D_V]
def iplr_recurrence(q, k, v, alpha, beta, initial_state=None, output_final_state=True):
    orig_dtype = q.dtype
    b, h, l, d_k = q.shape
    q, k, v, beta = map(lambda x: x.float(), [q, k, v, beta])
    d_v = v.shape[-1]
    o = torch.zeros_like(v)
    S = torch.zeros(b, h, d_k, d_v).to(v)
    q = q * (d_k ** -0.5)

    if initial_state is not None:
        S += initial_state

    for i in range(l):
        _k = k[:, :, i]
        _q = q[:, :, i]
        _v = v[:, :, i]
        _alpha = alpha[:, :, i]
        _beta = beta[:, :, i]
        _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None]
        S = S + _kv
        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
    S = None if output_final_state is False else S
    return o.to(orig_dtype), S


def iplr_chunkwise(q, k, v, alpha, beta, initial_state=None, output_final_state=True, chunk_size=32):
    b, h, l, d_k = q.shape
    d_v = v.shape[-1]
    q = q * (d_k ** -0.5)
    assert l % chunk_size == 0

    S = k.new_zeros(b, h, d_k, d_v)
    if initial_state is not None:
        S += initial_state

    # note that the diagonal is masked.
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
    q, k, v, alpha, beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, alpha, beta])

    v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v
    attn = (alpha @ beta.transpose(-1, -2)).masked_fill(mask, 0)
    for i in range(1, chunk_size):
        attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)

    attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
    u = attn @ v2
    w = attn @ alpha
    o = torch.zeros_like(v)
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
    for i in range(0, l // chunk_size):
        q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i]
        o_1 = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0) @ v_i
        v2_i = u_i + w_i @ S
        o_2 = (q_i @ beta_i.transpose(-1, -2)).masked_fill_(mask, 0) @ v2_i
        o_3 = q_i @ S
        o[:, :, i] = o_1 + o_2 + o_3
        S = S + k_i.transpose(-1, -2) @ v_i + beta_i.transpose(-1, -2) @ v2_i
    S = None if output_final_state is False else S
    return rearrange(o, 'b h n c d -> b h (n c) d'), S
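
Editor's note: since both functions above implement the same recurrence, a quick sanity check of the chunkwise rewrite is to compare it against the step-by-step reference on random inputs. The sketch below is illustrative (shapes, the small `alpha`/`beta` scaling that keeps the recurrence well-conditioned, and the tolerances are assumptions); the two results should agree up to numerical tolerance.

```python
import torch
import torch.nn.functional as F

B, H, L, DK, DV = 1, 2, 64, 16, 16
q = torch.randn(B, H, L, DK)
k = F.normalize(torch.randn(B, H, L, DK), dim=-1)
v = torch.randn(B, H, L, DV)
# Keep the rank-1 updates small so that products of (I + beta alpha^T) stay bounded.
alpha = 0.1 * F.normalize(torch.randn(B, H, L, DK), dim=-1)
beta = 0.1 * F.normalize(torch.randn(B, H, L, DK), dim=-1)

o_ref, S_ref = iplr_recurrence(q, k, v, alpha, beta)
o_chunk, S_chunk = iplr_chunkwise(q, k, v, alpha, beta, chunk_size=32)

assert torch.allclose(o_ref, o_chunk, atol=1e-3)
assert torch.allclose(S_ref, S_chunk, atol=1e-3)
```
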
build/lib/opencompass/tasks/fla2/ops/generalized_delta_rule/iplr/wy_fast.py
ADDED
@@ -0,0 +1,300 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from ....ops.utils import prepare_chunk_indices
from ....utils import check_shared_mem, is_nvidia_hopper

NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
    ],
    key=['BK']
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk32(
    a,
    b,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BC: tl.constexpr,  # dummy placeholder
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        b_a = tl.load(p_a, boundary_check=(0, 1))
        b_b = tl.load(p_b, boundary_check=(0, 1))
        b_A += tl.dot(b_a, b_b)

    b_A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
    for i in range(1, BT):
        mask = tl.arange(0, BT) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
    b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]

    p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
    ],
    key=['BK']
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk64(
    a,
    b,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BC: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    b_A = tl.zeros([BC, BC], dtype=tl.float32)
    b_A2 = tl.zeros([BC, BC], dtype=tl.float32)
    b_A3 = tl.zeros([BC, BC], dtype=tl.float32)

    for i_k in range(tl.cdiv(K, BK)):
        p_a1 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
        p_a2 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
        p_b1 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
        p_b2 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
        b_a1 = tl.load(p_a1, boundary_check=(0, 1))
        b_a2 = tl.load(p_a2, boundary_check=(0, 1))
        b_b1 = tl.load(p_b1, boundary_check=(0, 1))
        b_b2 = tl.load(p_b2, boundary_check=(0, 1))
        b_A += tl.dot(b_a1, b_b1, allow_tf32=False)
        b_A2 += tl.dot(b_a2, b_b2, allow_tf32=False)
        b_A3 += tl.dot(b_a2, b_b1, allow_tf32=False)

    b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
    b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)

    for i in range(1, BC):
        mask = tl.arange(0, BC) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
        b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
        b_A2 = tl.where(mask[:, None], b_a2, b_A2)

    # blockwise computation of the lower triangular matrix's inverse,
    # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
    b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A3 = tl.dot(tl.dot(b_A2, b_A3, allow_tf32=False), b_A, allow_tf32=False)

    p_A1 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
    p_A2 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
    p_A3 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
    p_A4 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
    tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1))
    # causal mask: the upper-right block is all zeros
    tl.store(p_A4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in NUM_WARPS
    ],
    key=['BT', 'BK', 'BV']
)
@triton.jit(do_not_specialize=['T'])
def wu_fwd_kernel(
    w,
    u,
    a,
    k,
    v,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_Aak = tl.zeros([BT, BT], dtype=tl.float32)

    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_a = tl.load(p_a, boundary_check=(0, 1))
        b_w = tl.dot(b_A, b_a)
        b_Aak += tl.dot(b_a, tl.trans(b_k))
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))

    b_Aak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_Aak, 0)
    b_Aak = b_Aak.to(k.dtype.element_ty)

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_v = tl.dot(b_Aak, b_v).to(v.dtype.element_ty)
        b_u = tl.dot(b_A, b_v)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))


def prepare_wy_repr_fwd(
    a: torch.Tensor,
    b: torch.Tensor,
    v: torch.Tensor,
    k: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    B, T, H, K = a.shape
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    BC = min(BT, 32)
    BK = min(triton.next_power_of_2(K), 64)

    A = torch.empty(B, T, H, BT, device=a.device, dtype=a.dtype)
    fwd_fn = prepare_wy_repr_fwd_kernel_chunk64 if BT == 64 else prepare_wy_repr_fwd_kernel_chunk32

    fwd_fn[(NT, B * H)](
        a=a,
        b=b,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BK=BK,
        BC=BC,
    )
    w, u = wu_fwd(
        a=a,
        v=v,
        k=k,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )
    return w, u, A


def wu_fwd(
    a: torch.Tensor,
    v: torch.Tensor,
    k: torch.Tensor,
    A: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *a.shape, v.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    CONST_TILING = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), CONST_TILING)
    BV = min(triton.next_power_of_2(V), CONST_TILING)

    u = torch.empty_like(v)
    w = torch.empty_like(a)
    wu_fwd_kernel[(NT, B*H)](
        a=a,
        v=v,
        w=w,
        u=u,
        A=A,
        k=k,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return w, u


fwd_prepare_wy_repr = prepare_wy_repr_fwd

fwd_wu = wu_fwd
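
Editor's note: the blockwise inverse identity quoted inside `prepare_wy_repr_fwd_kernel_chunk64` is standard linear algebra for a 2x2 block lower-triangular matrix. As a standalone sanity check (a sketch, not part of the library; sizes and tolerances are illustrative), it can be verified numerically in PyTorch:

```python
import torch

torch.manual_seed(0)
n = 32
# Unit lower-triangular blocks; small off-diagonal entries keep the inverse well-conditioned.
A11 = torch.eye(n, dtype=torch.double) + 0.1 * torch.randn(n, n, dtype=torch.double).tril(-1)
A22 = torch.eye(n, dtype=torch.double) + 0.1 * torch.randn(n, n, dtype=torch.double).tril(-1)
A21 = torch.randn(n, n, dtype=torch.double)

M = torch.zeros(2 * n, 2 * n, dtype=torch.double)
M[:n, :n], M[n:, :n], M[n:, n:] = A11, A21, A22

# [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
inv11, inv22 = torch.linalg.inv(A11), torch.linalg.inv(A22)
blockwise = torch.zeros_like(M)
blockwise[:n, :n] = inv11
blockwise[n:, n:] = inv22
blockwise[n:, :n] = -inv22 @ A21 @ inv11

assert torch.allclose(blockwise, torch.linalg.inv(M), atol=1e-8)
```
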
docs/en/.readthedocs.yaml
ADDED
@@ -0,0 +1,17 @@
version: 2

# Set the version of Python and other tools you might need
build:
  os: ubuntu-22.04
  tools:
    python: "3.8"

formats:
  - epub

sphinx:
  configuration: docs/en/conf.py

python:
  install:
    - requirements: requirements/docs.txt
docs/en/Makefile
ADDED
@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = .
BUILDDIR      = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
docs/en/_static/css/readthedocs.css
ADDED
@@ -0,0 +1,62 @@
.header-logo {
  background-image: url("../image/logo.svg");
  background-size: 275px 80px;
  height: 80px;
  width: 275px;
}

@media screen and (min-width: 1100px) {
  .header-logo {
    top: -25px;
  }
}

pre {
  white-space: pre;
}

@media screen and (min-width: 2000px) {
  .pytorch-content-left {
    width: 1200px;
    margin-left: 30px;
  }
  article.pytorch-article {
    max-width: 1200px;
  }
  .pytorch-breadcrumbs-wrapper {
    width: 1200px;
  }
  .pytorch-right-menu.scrolling-fixed {
    position: fixed;
    top: 45px;
    left: 1580px;
  }
}


article.pytorch-article section code {
  padding: .2em .4em;
  background-color: #f3f4f7;
  border-radius: 5px;
}

/* Disable the change in tables */
article.pytorch-article section table code {
  padding: unset;
  background-color: unset;
  border-radius: unset;
}

table.autosummary td {
  width: 50%
}

img.align-center {
  display: block;
  margin-left: auto;
  margin-right: auto;
}

article.pytorch-article p.rubric {
  font-weight: bold;
}
docs/en/_static/image/logo.svg
ADDED

docs/en/_static/image/logo_icon.svg
ADDED
docs/en/_static/js/custom.js
ADDED
@@ -0,0 +1,20 @@
var collapsedSections = ['Dataset Statistics'];

$(document).ready(function () {
  $('.dataset').DataTable({
    "stateSave": false,
    "lengthChange": false,
    "pageLength": 20,
    "order": [],
    "language": {
      "info": "Show _START_ to _END_ Items(Totally _TOTAL_ )",
      "infoFiltered": "(Filtered from _MAX_ Items)",
      "search": "Search:",
      "zeroRecords": "Item Not Found",
      "paginate": {
        "next": "Next",
        "previous": "Previous"
      },
    }
  });
});
docs/en/_templates/404.html
ADDED
@@ -0,0 +1,18 @@
{% extends "layout.html" %}

{% block body %}

<h1>Page Not Found</h1>
<p>
  The page you are looking for cannot be found.
</p>
<p>
  If you just switched documentation versions, it is likely that the page you were on has moved. You can look for it in
  the table of contents on the left, or go to <a href="{{ pathto(root_doc) }}">the homepage</a>.
</p>
<!-- <p>
  If you cannot find the documentation you want, please <a
  href="">open an issue</a> to tell us!
</p> -->

{% endblock %}
docs/en/_templates/autosummary/class.rst
ADDED
@@ -0,0 +1,13 @@
.. role:: hidden
    :class: hidden-section
.. currentmodule:: {{ module }}


{{ name | underline}}

.. autoclass:: {{ name }}
   :members:

..
  autogenerated from _templates/autosummary/class.rst
  note it does not have :inherited-members:
docs/en/_templates/callable.rst
ADDED
@@ -0,0 +1,14 @@
.. role:: hidden
    :class: hidden-section
.. currentmodule:: {{ module }}


{{ name | underline}}

.. autoclass:: {{ name }}
   :members:
   :special-members: __call__

..
  autogenerated from _templates/callable.rst
  note it does not have :inherited-members:
docs/en/advanced_guides/accelerator_intro.md
ADDED
@@ -0,0 +1,142 @@
# Accelerate Evaluation Inference with vLLM or LMDeploy

## Background

During the OpenCompass evaluation process, the Huggingface transformers library is used for inference by default. While this is a very general solution, there are scenarios where more efficient inference methods are needed to speed up the process, such as leveraging vLLM or LMDeploy.

- [LMDeploy](https://github.com/InternLM/lmdeploy) is a toolkit designed for compressing, deploying, and serving large language models (LLMs), developed by the [MMRazor](https://github.com/open-mmlab/mmrazor) and [MMDeploy](https://github.com/open-mmlab/mmdeploy) teams.
- [vLLM](https://github.com/vllm-project/vllm) is a fast and user-friendly library for LLM inference and serving, featuring advanced serving throughput, efficient PagedAttention memory management, continuous batching of requests, fast model execution via CUDA/HIP graphs, quantization techniques (e.g., GPTQ, AWQ, SqueezeLLM, FP8 KV Cache), and optimized CUDA kernels.

## Preparation for Acceleration

First, check whether the model you want to evaluate supports inference acceleration with vLLM or LMDeploy. Additionally, ensure you have installed vLLM or LMDeploy as per their official documentation. Below are the installation methods for reference:

### LMDeploy Installation Method

Install LMDeploy using pip (Python 3.8+) or from [source](https://github.com/InternLM/lmdeploy/blob/main/docs/en/build.md):

```bash
pip install lmdeploy
```

### vLLM Installation Method

Install vLLM using pip or from [source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):

```bash
pip install vllm
```

## Accelerated Evaluation Using vLLM or LMDeploy

### Method 1: Using Command Line Parameters to Change the Inference Backend

OpenCompass offers one-click evaluation acceleration. During evaluation, it can automatically convert Huggingface transformers models to vLLM or LMDeploy models. Below is an example config for evaluating the GSM8k dataset with the default Huggingface version of the llama3-8b-instruct model:

```python
# eval_gsm8k.py
from mmengine.config import read_base

with read_base():
    # Select a dataset list
    from .datasets.gsm8k.gsm8k_0shot_gen_a58960 import gsm8k_datasets as datasets
    # Select an interested model
    from .models.hf_llama.hf_llama3_8b_instruct import models
```

Here, `hf_llama3_8b_instruct` specifies the original Huggingface model configuration, as shown below:

```python
from opencompass.models import HuggingFacewithChatTemplate

models = [
    dict(
        type=HuggingFacewithChatTemplate,
        abbr='llama-3-8b-instruct-hf',
        path='meta-llama/Meta-Llama-3-8B-Instruct',
        max_out_len=1024,
        batch_size=8,
        run_cfg=dict(num_gpus=1),
        stop_words=['<|end_of_text|>', '<|eot_id|>'],
    )
]
```

To evaluate the GSM8k dataset using the default Huggingface version of the llama3-8b-instruct model, use:

```bash
python run.py config/eval_gsm8k.py
```

To accelerate the evaluation using vLLM or LMDeploy, you can use the following script:

```bash
python run.py config/eval_gsm8k.py -a vllm
```

or

```bash
python run.py config/eval_gsm8k.py -a lmdeploy
```

### Method 2: Accelerating Evaluation via a Deployed Inference Acceleration Service API

OpenCompass also supports accelerating evaluation by deploying vLLM or LMDeploy inference acceleration service APIs. Follow these steps:

1. Install the openai package:

```bash
pip install openai
```

2. Deploy the inference acceleration service API for vLLM or LMDeploy. Below is an example for LMDeploy:

```bash
lmdeploy serve api_server meta-llama/Meta-Llama-3-8B-Instruct --model-name Meta-Llama-3-8B-Instruct --server-port 23333
```

Parameters for starting the api_server can be checked with `lmdeploy serve api_server -h`, e.g. `--tp` for tensor parallelism, `--session-len` for the maximum context window length, `--cache-max-entry-count` for adjusting the k/v cache memory usage ratio, etc.
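
vLLM can serve an OpenAI-compatible API in the same way. A rough sketch of the equivalent command is shown below; the exact flags vary across vLLM versions, so double-check against your installed version's `--help` output:

```bash
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Meta-Llama-3-8B-Instruct \
    --port 23333
```
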
3. Once the service is successfully deployed, modify the evaluation script by changing the model configuration path to the service address, as shown below:

```python
from opencompass.models import OpenAISDK

api_meta_template = dict(
    round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ],
    reserved_roles=[dict(role='SYSTEM', api_role='SYSTEM')],
)

models = [
    dict(
        abbr='Meta-Llama-3-8B-Instruct-LMDeploy-API',
        type=OpenAISDK,
        key='EMPTY',  # API key
        openai_api_base='http://0.0.0.0:23333/v1',  # Service address
        path='Meta-Llama-3-8B-Instruct',  # Model name for service request
        tokenizer_path='meta-llama/Meta-Llama-3.1-8B-Instruct',  # The tokenizer name or path; if set to `None`, uses the default `gpt-4` tokenizer
        rpm_verbose=True,  # Whether to print request rate
        meta_template=api_meta_template,  # Service request template
        query_per_second=1,  # Service request rate
        max_out_len=1024,  # Maximum output length
        max_seq_len=4096,  # Maximum input length
        temperature=0.01,  # Generation temperature
        batch_size=8,  # Batch size
        retry=3,  # Number of retries
    )
]
```

## Acceleration Effect and Performance Comparison

Below is a comparison table of the acceleration effect and performance when using vLLM or LMDeploy on a single A800 GPU to evaluate the Llama-3-8B-Instruct model on the GSM8k dataset:

| Inference Backend | Accuracy | Inference Time (minutes:seconds) | Speedup (relative to Huggingface) |
| ----------------- | -------- | -------------------------------- | --------------------------------- |
| Huggingface       | 74.22    | 24:26                            | 1.0                               |
| LMDeploy          | 73.69    | 11:15                            | 2.2                               |
| vLLM              | 72.63    | 07:52                            | 3.1                               |
docs/en/advanced_guides/circular_eval.md
ADDED
@@ -0,0 +1,113 @@
# CircularEval

## Background

For multiple-choice questions, when a Language Model (LLM) provides the correct option, it does not necessarily imply true understanding and reasoning about the question; it could be a guess. To differentiate these scenarios and reduce LLM bias towards particular options, CircularEval can be utilized. A multiple-choice question is augmented by shuffling its options, and if the LLM correctly answers all variations of the augmented question, it is considered correct under CircularEval.

## Adding Your Own CircularEval Dataset

Generally, to evaluate a dataset using CircularEval, both its loading and evaluation methods need to be rewritten. Modifications are required in both the OpenCompass main library and configuration files. We will use C-Eval as an example for explanation.

OpenCompass main library:

```python
from opencompass.datasets.ceval import CEvalDataset
from opencompass.datasets.circular import CircularDatasetMeta

class CircularCEvalDataset(CEvalDataset, metaclass=CircularDatasetMeta):
    # The overloaded dataset class
    dataset_class = CEvalDataset

    # Splits of the DatasetDict that need CircularEval. For CEvalDataset, which loads [dev, val, test], we only need 'val' and 'test' for CircularEval, not 'dev'
    default_circular_splits = ['val', 'test']

    # List of keys to be shuffled
    default_option_keys = ['A', 'B', 'C', 'D']

    # If the content of 'answer_key' is one of ['A', 'B', 'C', 'D'], representing the correct answer, this field indicates how to update the correct answer after shuffling options. Choose either this or default_answer_key_switch_method
    default_answer_key = 'answer'

    # If the 'answer_key' content is not one of ['A', 'B', 'C', 'D'], a function can be used to specify the correct answer after shuffling options. Choose either this or default_answer_key
    # def default_answer_key_switch_method(item, circular_pattern):
    #     # 'item' is the original data item
    #     # 'circular_pattern' is a tuple indicating the order after shuffling options, e.g., ('D', 'A', 'B', 'C') means the original option A is now D, and so on
    #     item['answer'] = circular_pattern['ABCD'.index(item['answer'])]
    #     return item
```

`CircularCEvalDataset` accepts the `circular_pattern` parameter with two values:

- `circular`: Indicates a single cycle. It is the default value. ABCD is expanded to ABCD, BCDA, CDAB, DABC, a total of 4 variations.
- `all_possible`: Indicates all permutations. ABCD is expanded to ABCD, ABDC, ACBD, ACDB, ADBC, ADCB, BACD, ..., a total of 24 variations.
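
For intuition, the variants behind both patterns can be enumerated in a few lines of plain Python (an illustrative sketch, not OpenCompass API):

```python
from itertools import permutations

options = ['A', 'B', 'C', 'D']

# `circular`: the 4 cyclic rotations of the option order
circular = [tuple(options[i:] + options[:i]) for i in range(len(options))]
print(circular)
# [('A', 'B', 'C', 'D'), ('B', 'C', 'D', 'A'), ('C', 'D', 'A', 'B'), ('D', 'A', 'B', 'C')]

# `all_possible`: every permutation of the options
assert len(list(permutations(options))) == 24
```
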
Additionally, we provide a `CircularEvaluator` to replace `AccEvaluator`. This Evaluator also accepts `circular_pattern`, which should be consistent with the one above. It produces the following metrics:

- `acc_{origin|circular|all_possible}`: Each question variant with shuffled options is treated as a separate question, and accuracy is computed over all of them.
- `perf_{origin|circular|all_possible}`: Following the Circular logic, a question is considered correct only if all of its shuffled-option variations are answered correctly; accuracy is computed over questions.
- `more_{num}_{origin|circular|all_possible}`: Following the Circular logic, a question is deemed correct if the number of its variations answered correctly is greater than or equal to num; accuracy is computed over questions.

OpenCompass configuration file:

```python
from mmengine.config import read_base
from opencompass.datasets.circular import CircularCEvalDataset, CircularEvaluator

with read_base():
    from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets

for d in ceval_datasets:
    # Overloading the load method
    d['type'] = CircularCEvalDataset
    # Renaming for differentiation from non-circular evaluation versions
    d['abbr'] = d['abbr'] + '-circular-4'
    # Overloading the evaluation method
    d['eval_cfg']['evaluator'] = {'type': CircularEvaluator}

# The dataset after the above operations looks like this:
# dict(
#     type=CircularCEvalDataset,
#     path='./data/ceval/formal_ceval',  # Unchanged
#     name='computer_network',  # Unchanged
#     abbr='ceval-computer_network-circular-4',
#     reader_cfg=dict(...),  # Unchanged
#     infer_cfg=dict(...),  # Unchanged
#     eval_cfg=dict(evaluator=dict(type=CircularEvaluator), ...),
# )
```

Additionally, for better presentation of results in CircularEval, consider using the following summarizer:

```python
from mmengine.config import read_base
from opencompass.summarizers import CircularSummarizer

with read_base():
    from ...summarizers.groups.ceval import ceval_summary_groups

new_summary_groups = []
for item in ceval_summary_groups:
    new_summary_groups.append(
        {
            'name': item['name'] + '-circular-4',
            'subsets': [i + '-circular-4' for i in item['subsets']],
        }
    )

summarizer = dict(
    type=CircularSummarizer,
    # Select specific metrics to view
    metric_types=['acc_origin', 'perf_circular'],
    dataset_abbrs = [
        'ceval-circular-4',
        'ceval-humanities-circular-4',
        'ceval-stem-circular-4',
        'ceval-social-science-circular-4',
        'ceval-other-circular-4',
    ],
    summary_groups=new_summary_groups,
)
```

For more complex evaluation examples, refer to this sample code: https://github.com/open-compass/opencompass/tree/main/examples/eval_circular.py
docs/en/advanced_guides/code_eval.md
ADDED
@@ -0,0 +1,104 @@
# Code Evaluation Tutorial

This tutorial primarily focuses on evaluating a model's coding proficiency, using `humaneval` and `mbpp` as examples.

## pass@1

If you only need to generate a single response to evaluate the pass@1 performance, you can directly use [configs/datasets/humaneval/humaneval_gen_8e312c.py](https://github.com/open-compass/opencompass/blob/main/configs/datasets/humaneval/humaneval_gen_8e312c.py) and [configs/datasets/mbpp/deprecated_mbpp_gen_1e1056.py](https://github.com/open-compass/opencompass/blob/main/configs/datasets/mbpp/deprecated_mbpp_gen_1e1056.py), referring to the general [quick start tutorial](../get_started/quick_start.md).

For multilingual evaluation, please refer to the [Multilingual Code Evaluation Tutorial](./code_eval_service.md).

## pass@k

If you need to generate multiple responses for a single example to evaluate the pass@k performance, consider the following two situations. Here we take 10 responses as an example:

### Typical Situation

For most models that support the `num_return_sequences` parameter in HF's generation, we can use it directly to obtain multiple responses. Refer to the following configuration file:

```python
from mmengine.config import read_base
from opencompass.models import HuggingFaceCausalLM
from opencompass.datasets import MBPPDatasetV2, MBPPPassKEvaluator

with read_base():
    from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
    from .datasets.mbpp.deprecated_mbpp_gen_1e1056 import mbpp_datasets

mbpp_datasets[0]['type'] = MBPPDatasetV2
mbpp_datasets[0]['eval_cfg']['evaluator']['type'] = MBPPPassKEvaluator
mbpp_datasets[0]['reader_cfg']['output_column'] = 'test_column'

datasets = []
datasets += humaneval_datasets
datasets += mbpp_datasets

models = [
    dict(
        type=HuggingFaceCausalLM,
        ...,
        generation_kwargs=dict(
            num_return_sequences=10,
            do_sample=True,
            top_p=0.95,
            temperature=0.8,
        ),
        ...,
    )
]
```

For `mbpp`, changes are needed in both the dataset and the evaluation, so we simultaneously modify the `type`, `eval_cfg.evaluator.type`, and `reader_cfg.output_column` fields to accommodate these requirements.

We also need model responses with randomness, so setting the `generation_kwargs` parameter is necessary. Note that `num_return_sequences` controls the number of responses.

Note: `num_return_sequences` must be greater than or equal to k, as pass@k itself is a probability estimate.
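
The probability estimate referred to here is the standard unbiased pass@k estimator from the HumanEval paper (Chen et al., 2021): with n sampled completions of which c pass the tests, pass@k = 1 - C(n-c, k) / C(n, k). A small sketch of that computation (illustrative; not necessarily the exact code the evaluator runs internally):

```python
import numpy as np

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k: 1 - C(n-c, k) / C(n, k), computed in a numerically stable way."""
    if n - c < k:
        return 1.0
    return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

# e.g. 10 samples per problem, 3 of which pass the tests
print(pass_at_k(n=10, c=3, k=1))  # 0.3
```
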
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code Evaluation Tutorial
|
| 2 |
+
|
| 3 |
+
This tutorial primarily focuses on evaluating a model's coding proficiency, using `humaneval` and `mbpp` as examples.
|
| 4 |
+
|
| 5 |
+
## pass@1
|
| 6 |
+
|
| 7 |
+
If you only need to generate a single response to evaluate the pass@1 performance, you can directly use [configs/datasets/humaneval/humaneval_gen_8e312c.py](https://github.com/open-compass/opencompass/blob/main/configs/datasets/humaneval/humaneval_gen_8e312c.py) and [configs/datasets/mbpp/deprecated_mbpp_gen_1e1056.py](https://github.com/open-compass/opencompass/blob/main/configs/datasets/mbpp/deprecated_mbpp_gen_1e1056.py), referring to the general [quick start tutorial](../get_started/quick_start.md).
|
| 8 |
+
|
| 9 |
+
For multilingual evaluation, please refer to the [Multilingual Code Evaluation Tutorial](./code_eval_service.md).
|
| 10 |
+
|
| 11 |
+
## pass@k
|
| 12 |
+
|
| 13 |
+
If you need to generate multiple responses for a single example to evaluate the pass@k performance, consider the following two situations. Here we take 10 responses as an example:
|
| 14 |
+
|
| 15 |
+
### Typical Situation
|
| 16 |
+
|
| 17 |
+
For most models that support the `num_return_sequences` parameter in HF's generation, we can use it directly to obtain multiple responses. Refer to the following configuration file:
|
| 18 |
+
|
| 19 |
+
```python
|
| 20 |
+
from opencompass.datasets import MBPPDatasetV2, MBPPPassKEvaluator
|
| 21 |
+
|
| 22 |
+
with read_base():
|
| 23 |
+
from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
|
| 24 |
+
from .datasets.mbpp.deprecated_mbpp_gen_1e1056 import mbpp_datasets
|
| 25 |
+
|
| 26 |
+
mbpp_datasets[0]['type'] = MBPPDatasetV2
|
| 27 |
+
mbpp_datasets[0]['eval_cfg']['evaluator']['type'] = MBPPPassKEvaluator
|
| 28 |
+
mbpp_datasets[0]['reader_cfg']['output_column'] = 'test_column'
|
| 29 |
+
|
| 30 |
+
datasets = []
|
| 31 |
+
datasets += humaneval_datasets
|
| 32 |
+
datasets += mbpp_datasets
|
| 33 |
+
|
| 34 |
+
models = [
|
| 35 |
+
dict(
|
| 36 |
+
type=HuggingFaceCausalLM,
|
| 37 |
+
...,
|
| 38 |
+
generation_kwargs=dict(
|
| 39 |
+
num_return_sequences=10,
|
| 40 |
+
do_sample=True,
|
| 41 |
+
top_p=0.95,
|
| 42 |
+
temperature=0.8,
|
| 43 |
+
),
|
| 44 |
+
...,
|
| 45 |
+
)
|
| 46 |
+
]
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
For `mbpp`, new changes are needed in the dataset and evaluation, so we simultaneously modify the `type`, `eval_cfg.evaluator.type`, `reader_cfg.output_column` fields to accommodate these requirements.
|
| 50 |
+
|
| 51 |
+
We also need model responses with randomness, thus setting the `generation_kwargs` parameter is necessary. Note that we need to set `num_return_sequences` to get the number of responses.
|
| 52 |
+
|
| 53 |
+
Note: `num_return_sequences` must be greater than or equal to k, as pass@k itself is a probability estimate.
|
| 54 |
+
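For intuition, pass@k is computed with the standard unbiased estimator from the HumanEval paper, which is only defined when at least k samples exist per problem; the evaluators configured above handle this computation for you. A minimal sketch:

```python
import numpy as np

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimator: n = samples generated per problem,
    c = samples that pass the tests, and k <= n is required."""
    if n - c < k:
        return 1.0  # every size-k subset contains at least one passing sample
    return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))
```

For example, with `num_return_sequences=10`, a problem where 3 of the 10 samples pass contributes `pass_at_k(10, 3, 1) = 0.3` to pass@1.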
You can specifically refer to the following configuration file: [examples/eval_code_passk.py](https://github.com/open-compass/opencompass/blob/main/examples/eval_code_passk.py)

### For Models That Do Not Support Multiple Responses

This applies to some HF models with poorly designed APIs or missing features. In this case, we need to construct the dataset repeatedly to achieve the effect of multiple responses. Refer to the following configuration:

```python
from mmengine.config import read_base
from opencompass.datasets import MBPPDatasetV2, MBPPPassKEvaluator
from opencompass.models import HuggingFaceCausalLM

with read_base():
    from .datasets.humaneval.humaneval_gen_8e312c import humaneval_datasets
    from .datasets.mbpp.deprecated_mbpp_gen_1e1056 import mbpp_datasets

humaneval_datasets[0]['abbr'] = 'openai_humaneval_pass10'
humaneval_datasets[0]['num_repeats'] = 10
mbpp_datasets[0]['abbr'] = 'mbpp_pass10'
mbpp_datasets[0]['num_repeats'] = 10
mbpp_datasets[0]['type'] = MBPPDatasetV2
mbpp_datasets[0]['eval_cfg']['evaluator']['type'] = MBPPPassKEvaluator
mbpp_datasets[0]['reader_cfg']['output_column'] = 'test_column'

datasets = []
datasets += humaneval_datasets
datasets += mbpp_datasets

models = [
    dict(
        type=HuggingFaceCausalLM,
        ...,
        generation_kwargs=dict(
            do_sample=True,
            top_p=0.95,
            temperature=0.8,
        ),
        ...,
    )
]
```

Since the dataset's prompt has not been modified, we instead replace the corresponding fields to repeat the dataset.
You need to modify these fields:

- `num_repeats`: the number of times the dataset is repeated
- `abbr`: It is best to change the dataset abbreviation along with the number of repetitions, because the number of dataset entries will change; this prevents potential issues arising from discrepancies with the values in `.cache/dataset_size.json`.

For `mbpp`, modify the `type`, `eval_cfg.evaluator.type`, and `reader_cfg.output_column` fields as well.

We also need model responses with randomness, so setting the `generation_kwargs` parameter is necessary.

You can specifically refer to the following configuration file: [examples/eval_code_passk_repeat_dataset.py](https://github.com/open-compass/opencompass/blob/main/examples/eval_code_passk_repeat_dataset.py)
docs/en/advanced_guides/code_eval_service.md ADDED
@@ -0,0 +1,224 @@
# Code Evaluation Docker Tutorial

To evaluate the code capability of an LLM, we need a separate evaluation environment so that erroneous code is never executed in the development environment, where it could cause damage. The code evaluation service currently used by OpenCompass is the [code-evaluator](https://github.com/open-compass/code-evaluator) project. The following tutorial covers evaluation with this code evaluation service.

1. humaneval-x

This is a multi-programming-language dataset, [humaneval-x](https://huggingface.co/datasets/THUDM/humaneval-x).
You can download the dataset from this [download link](https://github.com/THUDM/CodeGeeX2/tree/main/benchmark/humanevalx). Please download the language files (××.jsonl.gz) that need to be evaluated and place them in the `./data/humanevalx` folder.

The currently supported languages are `python`, `cpp`, `go`, `java`, `js`.

2. DS1000

This is a Python multi-algorithm-library dataset, [ds1000](https://github.com/xlang-ai/DS-1000).
You can download the dataset from this [download link](https://github.com/xlang-ai/DS-1000/blob/main/ds1000_data.zip).

The currently supported algorithm libraries are `Pandas`, `Numpy`, `Tensorflow`, `Scipy`, `Sklearn`, `Pytorch`, `Matplotlib`.

## Launching the Code Evaluation Service

1. Ensure you have installed Docker; please refer to the [Docker installation document](https://docs.docker.com/engine/install/).
2. Pull the source code of the code evaluation service project and build the Docker image.

Choose the Dockerfile corresponding to the dataset you need, and replace `humanevalx` or `ds1000` in the command below.

```shell
git clone https://github.com/open-compass/code-evaluator.git
docker build -t code-eval-{your-dataset}:latest -f docker/{your-dataset}/Dockerfile .
```

3. Create a container with the following commands:

```shell
# Log output format
docker run -it -p 5000:5000 code-eval-{your-dataset}:latest python server.py

# Run the program in the background
# docker run -itd -p 5000:5000 code-eval-{your-dataset}:latest python server.py

# Using different ports
# docker run -itd -p 5001:5001 code-eval-{your-dataset}:latest python server.py --port 5001
```

**Note:**

- If you encounter a timeout during the evaluation of Go, please use the following command when creating the container.

```shell
docker run -it -p 5000:5000 -e GO111MODULE=on -e GOPROXY=https://goproxy.io code-eval-{your-dataset}:latest python server.py
```

4. To ensure you have access to the service, use the following commands to check the connection between the inference environment and the evaluation service. (If both inference and code evaluation run on the same host, skip this step.)

```shell
ping your_service_ip_address
telnet your_service_ip_address your_service_port
```

## Local Code Evaluation

When the model inference and code evaluation services run on the same host or within the same local area network, inference and code evaluation can be performed directly. **Note: DS1000 is currently not supported; please proceed with remote evaluation.**

### Configuration File

We provide [a configuration file](https://github.com/open-compass/opencompass/blob/main/examples/eval_codegeex2.py) for evaluating `codegeex2` on `humanevalx` as a reference.

The dataset and related post-processing configuration files can be found at this [link](https://github.com/open-compass/opencompass/tree/main/configs/datasets/humanevalx); pay attention to the `evaluator` field in `humanevalx_eval_cfg_dict`.

```python
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import HumanevalXDataset, HumanevalXEvaluator

humanevalx_reader_cfg = dict(
    input_columns=['prompt'], output_column='task_id', train_split='test')

humanevalx_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template='{prompt}'),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer, max_out_len=1024))

humanevalx_eval_cfg_dict = {
    lang: dict(
        evaluator=dict(
            type=HumanevalXEvaluator,
            language=lang,
            ip_address='localhost',  # replace with your code_eval_server ip_address and port
            port=5000),  # refer to https://github.com/open-compass/code-evaluator to launch a server
        pred_role='BOT')
    for lang in ['python', 'cpp', 'go', 'java', 'js']  # rust is not supported yet
}

humanevalx_datasets = [
    dict(
        type=HumanevalXDataset,
        abbr=f'humanevalx-{lang}',
        language=lang,
        path='./data/humanevalx',
        reader_cfg=humanevalx_reader_cfg,
        infer_cfg=humanevalx_infer_cfg,
        eval_cfg=humanevalx_eval_cfg_dict[lang])
    for lang in ['python', 'cpp', 'go', 'java', 'js']
]
```

### Task Launch

Refer to the [Quick Start](../get_started/quick_start.md).

## Remote Code Evaluation

When the model inference and code evaluation services are on different machines that cannot reach each other directly, run model inference first and then collect the code evaluation results. The configuration file and inference process can be reused from the previous section.

### Collect Inference Results (Only for HumanevalX)

OpenCompass's `tools` folder provides the script `collect_code_preds.py` to process and collect inference results: pass it the configuration file used to launch the task and specify the corresponding working directory.
The `-r` option works the same as in `run.py`; see the [documentation](https://opencompass.readthedocs.io/en/latest/get_started/quick_start.html#launching-evaluation) for details.

```shell
python tools/collect_code_preds.py [config] [-r latest]
```

The collected results will be organized as follows under the `-r` folder:

```
workdir/humanevalx
├── codegeex2-6b
│   ├── humanevalx_cpp.json
│   ├── humanevalx_go.json
│   ├── humanevalx_java.json
│   ├── humanevalx_js.json
│   └── humanevalx_python.json
├── CodeLlama-13b
│   ├── ...
├── CodeLlama-13b-Instruct
│   ├── ...
├── CodeLlama-13b-Python
│   ├── ...
├── ...
```

For DS1000, you just need to obtain the corresponding prediction file generated by `opencompass`.

### Code Evaluation

Make sure your code evaluation service is started, and use `curl` to request:

#### The following only supports HumanevalX

```shell
curl -X POST -F 'file=@{result_absolute_path}' -F 'dataset={dataset/language}' {your_service_ip_address}:{your_service_port}/evaluate
```

For example:

```shell
curl -X POST -F 'file=@./examples/humanevalx/python.json' -F 'dataset=humanevalx/python' localhost:5000/evaluate
```

Then we have:

```
"{\"pass@1\": 37.19512195121951%}"
```
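If you prefer to submit results from Python rather than `curl`, the same request can be sketched with `requests`; the file path, dataset name, and service address below simply mirror the `curl` example above:

```python
import requests

# Mirrors: curl -X POST -F 'file=@...' -F 'dataset=humanevalx/python' localhost:5000/evaluate
with open('./examples/humanevalx/python.json', 'rb') as f:
    resp = requests.post(
        'http://localhost:5000/evaluate',
        files={'file': f},                      # the collected prediction file
        data={'dataset': 'humanevalx/python'},  # dataset/language pair
    )
print(resp.status_code, resp.text)
```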
Additionally, we offer an extra option named `with_prompt` (defaults to `True`), since some models (like `WizardCoder`) generate complete code without needing the prompt and the prediction to be concatenated. You may refer to the following command for evaluation:

```shell
curl -X POST -F 'file=@./examples/humanevalx/python.json' -F 'dataset=humanevalx/python' -H 'with-prompt: False' localhost:5000/evaluate
```

#### The following only supports DS1000

Make sure the code evaluation service is started, then use `curl` to submit a request:

```shell
curl -X POST -F 'file=@./internlm-chat-7b-hf-v11/ds1000_Numpy.json' localhost:5000/evaluate
```

DS1000 supports additional debug parameters. Be aware that a large amount of logs will be generated when they are turned on:

- `full`: Additionally prints the original prediction, the post-processed prediction, the program run, and the final error for each failing sample.
- `half`: Additionally prints the program run and the final error for each failing sample.
- `error`: Additionally prints the final error for each failing sample.

```shell
curl -X POST -F 'file=@./internlm-chat-7b-hf-v11/ds1000_Numpy.json' -F 'debug=error' localhost:5000/evaluate
```

You can also modify `num_workers` in the same way to control the degree of parallelism.

## Advanced Tutorial

Besides evaluating the supported HumanevalX dataset, users might also need:

### Support New Dataset

Please refer to the [tutorial on supporting new datasets](./new_dataset.md).

### Modify Post-Processing

1. For local evaluation, follow the post-processing section in the tutorial on supporting new datasets to modify the post-processing method.
2. For remote evaluation, please modify the post-processing part in the tool's `collect_code_preds.py`.
3. Some parts of post-processing can also be modified in the code evaluation service; more on this in the next section.

### Debugging Code Evaluation Service

When supporting new datasets or modifying post-processors, you may need to modify the original code evaluation service. Make changes based on the following steps:

1. Remove the installation of the `code-evaluator` in `Dockerfile`, and mount the `code-evaluator` when starting the container instead:

```shell
docker run -it -p 5000:5000 -v /local/path/of/code-evaluator:/workspace/code-evaluator code-eval:latest bash
```

2. Install and start the code evaluation service locally. At this point, any necessary modifications can be made to the local copy of the `code-evaluator`.

```shell
cd code-evaluator && pip install -r requirements.txt
python server.py
```
docs/en/advanced_guides/contamination_eval.md ADDED
@@ -0,0 +1,124 @@
# Data Contamination Assessment

**Data contamination** refers to the phenomenon where data intended for downstream testing tasks appears in the training data of large language models (LLMs), resulting in artificially inflated performance metrics on downstream tasks (such as summarization, natural language inference, text classification) that do not accurately reflect the model's true generalization capabilities.

Since the source of data contamination lies in the training data used by LLMs, the most direct method to detect it is to collide test data with training data and report the extent of overlap between the two. The classic GPT-3 [paper](https://arxiv.org/pdf/2005.14165.pdf) reported on this in Table C.1.

However, today's open-source community often only publishes model parameters, not training datasets. In such cases, how to determine the presence and extent of data contamination remains unsolved. OpenCompass offers two possible solutions.

## Contamination Data Annotation Based on Self-Built Co-Distribution Data

Referencing the method mentioned in Section 5.2 of [Skywork](https://arxiv.org/pdf/2310.19341.pdf), we directly used the dataset [mock_gsm8k_test](https://huggingface.co/datasets/Skywork/mock_gsm8k_test) uploaded to HuggingFace by Skywork.

In this method, the authors used GPT-4 to synthesize data similar in style to the original GSM8K data, and then calculated the perplexity on the GSM8K training set (train), GSM8K test set (test), and GSM8K reference set (ref). Since the GSM8K reference set was newly generated, the authors considered it clean, belonging to no model's training set. They posited:

- If the test set's perplexity is significantly lower than the reference set's, the test set might have appeared in the model's training phase;
- If the training set's perplexity is significantly lower than the test set's, the training set might have been overfitted by the model.
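For intuition, this decision rule can be sketched as a toy function over the three averaged perplexities. The `margin` threshold is purely illustrative (neither Skywork nor OpenCompass prescribes one; the actual report simply lists the three numbers side by side, as in the example output below):

```python
def contamination_hints(train_ppl: float, test_ppl: float, ref_ppl: float,
                        margin: float = 0.1) -> list:
    """Toy reading of the Skywork-style ppl comparison; `margin` is an
    illustrative threshold, not something OpenCompass prescribes."""
    hints = []
    if test_ppl < ref_ppl - margin:
        hints.append('test set may have appeared during training')
    if train_ppl < test_ppl - margin:
        hints.append('training set may have been overfitted')
    return hints

# qwen-7b-hf row from the example output below: train 0.78, test 1.33, ref 1.2
print(contamination_hints(0.78, 1.33, 1.2))  # ['training set may have been overfitted']
```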
The following configuration file can be referenced:

```python
from mmengine.config import read_base

with read_base():
    from .datasets.gsm8k_contamination.gsm8k_contamination_ppl_ecdd22 import gsm8k_datasets  # includes training, test, and reference sets
    from .models.qwen.hf_qwen_7b import models as hf_qwen_7b_model  # model under review
    from .models.yi.hf_yi_6b import models as hf_yi_6b_model

datasets = [*gsm8k_datasets]
models = [*hf_qwen_7b_model, *hf_yi_6b_model]
```

An example output is as follows:

```text
dataset          version    metric       mode     internlm-7b-hf    qwen-7b-hf    yi-6b-hf    chatglm3-6b-base-hf    qwen-14b-hf    baichuan2-13b-base-hf    internlm-20b-hf    aquila2-34b-hf    ...
---------------  ---------  -----------  -------  ----------------  ------------  ----------  ---------------------  -------------  -----------------------  -----------------  ----------------  ...
gsm8k-train-ppl  0b8e46     average_ppl  unknown  1.5               0.78          1.37        1.16                   0.5            0.76                     1.41               0.78              ...
gsm8k-test-ppl   0b8e46     average_ppl  unknown  1.56              1.33          1.42        1.3                    1.15           1.13                     1.52               1.16              ...
gsm8k-ref-ppl    f729ba     average_ppl  unknown  1.55              1.2           1.43        1.35                   1.27           1.19                     1.47               1.35              ...
```

Currently, this solution only supports the GSM8K dataset. We welcome the community to contribute more datasets.

Consider citing the following papers if you find them helpful:

```bibtex
@misc{2023opencompass,
    title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
    author={OpenCompass Contributors},
    howpublished={\url{https://github.com/open-compass/opencompass}},
    year={2023}
}
@misc{wei2023skywork,
    title={Skywork: A More Open Bilingual Foundation Model},
    author={Tianwen Wei and Liang Zhao and Lichang Zhang and Bo Zhu and Lijie Wang and Haihua Yang and Biye Li and Cheng Cheng and Weiwei Lü and Rui Hu and Chenxia Li and Liu Yang and Xilin Luo and Xuejie Wu and Lunan Liu and Wenjun Cheng and Peng Cheng and Jianhao Zhang and Xiaoyu Zhang and Lei Lin and Xiaokun Wang and Yutuan Ma and Chuanhai Dong and Yanqi Sun and Yifu Chen and Yongyi Peng and Xiaojuan Liang and Shuicheng Yan and Han Fang and Yahui Zhou},
    year={2023},
    eprint={2310.19341},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```

## Contamination Data Annotation Based on Classic Pre-trained Sets

Thanks to [Contamination_Detector](https://github.com/liyucheng09/Contamination_Detector) and @liyucheng09 for providing this method.

In this method, the authors search the test datasets (such as C-Eval, ARC, HellaSwag, etc.) using the Common Crawl database and the Bing search engine, then mark each test sample as clean / question contaminated / both question and answer contaminated.

During testing, OpenCompass will report the accuracy or perplexity of C-Eval on the subsets composed of these three labels. Generally, the accuracy ranges from low to high across the clean, question contaminated, and both question and answer contaminated subsets. The authors believe:

- If the performance of the three is relatively close, the contamination level of the model on that test set is light; otherwise, it is heavy.

The following configuration file can be referenced: [link](https://github.com/open-compass/opencompass/blob/main/examples/eval_contamination.py)

```python
from mmengine.config import read_base

with read_base():
    from .datasets.ceval.ceval_clean_ppl import ceval_datasets  # ceval dataset with contamination tags
    from .models.yi.hf_yi_6b import models as hf_yi_6b_model  # model under review
    from .models.qwen.hf_qwen_7b import models as hf_qwen_7b_model
    from .summarizers.contamination import ceval_summarizer as summarizer  # output formatting

datasets = [*ceval_datasets]
models = [*hf_yi_6b_model, *hf_qwen_7b_model]
```

An example output is as follows:

```text
dataset               version    mode    yi-6b-hf          -                              -                                         qwen-7b-hf        -                              -                                         ...
--------------------  ---------  ------  ----------------  -----------------------------  ----------------------------------------  ----------------  -----------------------------  ----------------------------------------  ...
-                     -          -       accuracy - clean  accuracy - input contaminated  accuracy - input-and-label contaminated  accuracy - clean  accuracy - input contaminated  accuracy - input-and-label contaminated  ...
...
ceval-humanities      -          ppl     74.42             75.00                          82.14                                     67.44             50.00                          70.54                                     ...
ceval-stem            -          ppl     53.70             57.14                          85.61                                     47.41             52.38                          67.63                                     ...
ceval-social-science  -          ppl     81.60             84.62                          83.09                                     76.00             61.54                          72.79                                     ...
ceval-other           -          ppl     72.31             73.91                          75.00                                     58.46             39.13                          61.88                                     ...
ceval-hard            -          ppl     44.35             37.50                          70.00                                     41.13             25.00                          30.00                                     ...
ceval                 -          ppl     67.32             71.01                          81.17                                     58.97             49.28                          67.82                                     ...
```

Currently, this solution only supports C-Eval, MMLU, HellaSwag and ARC. [Contamination_Detector](https://github.com/liyucheng09/Contamination_Detector) also includes CSQA and WinoGrande, but these have not yet been implemented in OpenCompass. We welcome the community to contribute more datasets.

Consider citing the following papers if you find them helpful:

```bibtex
@misc{2023opencompass,
    title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
    author={OpenCompass Contributors},
    howpublished={\url{https://github.com/open-compass/opencompass}},
    year={2023}
}
@article{Li2023AnOS,
    title={An Open Source Data Contamination Report for Llama Series Models},
    author={Yucheng Li},
    journal={ArXiv},
    year={2023},
    volume={abs/2310.17589},
    url={https://api.semanticscholar.org/CorpusID:264490711}
}
```
docs/en/advanced_guides/custom_dataset.md ADDED
@@ -0,0 +1,267 @@
# Dataset Quick Evaluation Tutorial

OpenCompass provides two paths for quickly evaluating user-provided data: a data format protocol based on ChatMLDataset, and one based on CustomDataset.
Compared to the complete dataset integration process in [new_dataset.md](./new_dataset.md), these two evaluation paths are more convenient and efficient, entering the evaluation process directly without adding new configuration files.
However, if you have specific needs for custom reading/inference/evaluation, it is still recommended to follow the complete integration process to add a new dataset.

## Data Format Protocol and Fast Evaluation Based on ChatMLDataset

OpenCompass has recently launched a dataset evaluation mode based on the ChatML dialogue template, which allows users to provide a dataset `.json` file that conforms to the ChatML dialogue template and simply set a dataset info config (much like a model config) to start evaluating directly.

### Format Requirements for Data Files

This evaluation method only supports data files in `.json` format, and each sample must comply with the following format:

The format of a text-only dataset with a simple structure:

```jsonl
{
    "question": [
        {
            "role": "system",  # Omittable
            "content": Str
        },
        {
            "role": "user",
            "content": Str
        }
    ],
    "answer": [
        Str
    ]
}
{
    ...
}
...
```

The format for multi-round and multi-modal datasets:

```jsonl
{
    "question": [
        {
            "role": "system",
            "content": Str,
        },
        {
            "role": "user",
            "content": Str or List
            [
                {
                    "type": Str,  # "image"
                    "image_url": Str,
                },
                ...
                {
                    "type": Str,  # "text"
                    "text": Str,
                },
            ]
        },
        {
            "role": "assistant",
            "content": Str
        },
        {
            "role": "user",
            "content": Str or List
        },
        ...
    ],
    "answer": [
        Str,
        Str,
        ...
    ]
}
{
    ...
}
...
```

(As OpenCompass currently does not support multi-modal evaluation, the template above is for reference only.)

When ChatMLDataset reads `.json` files, it uses `pydantic` to perform simple format validation on them.
You can use `tools/chatml_fformat_test.py` to check your provided data file.
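For reference, here is a minimal `pydantic` sketch of the kind of schema the simple text-only format implies. The model names (`Message`, `ChatMLSample`) are illustrative; the actual validation models shipped with OpenCompass may differ:

```python
from typing import List, Literal
from pydantic import BaseModel

class Message(BaseModel):
    role: Literal['system', 'user', 'assistant']
    content: str

class ChatMLSample(BaseModel):
    question: List[Message]  # one or more chat turns
    answer: List[str]        # reference answer(s)

# Raises pydantic.ValidationError if a sample is malformed:
sample = ChatMLSample(
    question=[Message(role='user', content='What is 1+1?')],
    answer=['2'],
)
```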
After format checking, please add a config dictionary named `chatml_datasets` to your running config file to convert the data file into an OpenCompass dataset at runtime.
An example is as follows:

```python
chatml_datasets = [
    dict(
        abbr='YOUR_DATASET_NAME',
        path='YOUR_DATASET_PATH',
        evaluator=dict(
            type='cascade_evaluator',
            rule_evaluator=dict(
                type='math_evaluator',
            ),
            llm_evaluator=dict(
                type='llm_evaluator',
                prompt="YOUR_JUDGE_PROMPT",
                judge_cfg=dict(),  # YOUR Judge Model Config
            )
        ),
        n=1,  # Repeat Number
    ),
]
```

The ChatML evaluation module currently provides four preset evaluators: `mcq_rule_evaluator` (for MCQ evaluation), `math_evaluator` (for LaTeX mathematical formula evaluation), `llm_evaluator` (for answers that are open-ended or difficult to extract), and `cascade_evaluator`, an evaluation mode that cascades a rule evaluator with an LLM evaluator.

In addition, if you have a long-term need for datasets based on ChatML templates, you can contribute your dataset config to `opencompass/config/chatml_datasets`.
An eval example calling these dataset configs is provided in `examples/evalchat_datasets.py`.

## Data Format Protocol and Fast Evaluation Based on CustomDataset

(This module is no longer being updated, but it can still be used if you need quick CLI-based evaluation.)

This module supports two types of tasks: multiple choice (`mcq`) and question & answer (`qa`). For `mcq`, both ppl and gen inference are supported; for `qa`, gen inference is supported.

### Dataset Format

We support datasets in both `.jsonl` and `.csv` formats.

#### Multiple Choice (`mcq`)

For `mcq` datasets, the default fields are as follows:

- `question`: The stem of the multiple-choice question.
- `A`, `B`, `C`, ...: Single uppercase letters representing the options, with no limit on the number. Consecutive letters starting from `A` are parsed as options by default.
- `answer`: The correct answer to the multiple-choice question, which must be one of the options used above, such as `A`, `B`, `C`, etc.

Non-default fields will be read in but are not used by default. To use them, specify them in the `.meta.json` file.

An example of the `.jsonl` format:

```jsonl
{"question": "165+833+650+615=", "A": "2258", "B": "2263", "C": "2281", "answer": "B"}
{"question": "368+959+918+653+978=", "A": "3876", "B": "3878", "C": "3880", "answer": "A"}
{"question": "776+208+589+882+571+996+515+726=", "A": "5213", "B": "5263", "C": "5383", "answer": "B"}
{"question": "803+862+815+100+409+758+262+169=", "A": "4098", "B": "4128", "C": "4178", "answer": "C"}
```

An example of the `.csv` format:

```csv
question,A,B,C,answer
127+545+588+620+556+199=,2632,2635,2645,B
735+603+102+335+605=,2376,2380,2410,B
506+346+920+451+910+142+659+850=,4766,4774,4784,C
504+811+870+445=,2615,2630,2750,B
```

#### Question & Answer (`qa`)

For `qa` datasets, the default fields are as follows:

- `question`: The stem of the question & answer question.
- `answer`: The correct answer to the question. It can be missing, indicating the dataset has no correct answer.

Non-default fields will be read in but are not used by default. To use them, specify them in the `.meta.json` file.

An example of the `.jsonl` format:

```jsonl
{"question": "752+361+181+933+235+986=", "answer": "3448"}
{"question": "712+165+223+711=", "answer": "1811"}
{"question": "921+975+888+539=", "answer": "3323"}
{"question": "752+321+388+643+568+982+468+397=", "answer": "4519"}
```

An example of the `.csv` format:

```csv
question,answer
123+147+874+850+915+163+291+604=,3967
149+646+241+898+822+386=,3142
332+424+582+962+735+798+653+214=,4700
649+215+412+495+220+738+989+452=,4170
```

### Command Line

Custom datasets can be called for evaluation directly through the command line.

```bash
python run.py \
    --models hf_llama2_7b \
    --custom-dataset-path xxx/test_mcq.csv \
    --custom-dataset-data-type mcq \
    --custom-dataset-infer-method ppl
```

```bash
python run.py \
    --models hf_llama2_7b \
    --custom-dataset-path xxx/test_qa.jsonl \
    --custom-dataset-data-type qa \
    --custom-dataset-infer-method gen
```

In most cases, `--custom-dataset-data-type` and `--custom-dataset-infer-method` can be omitted. OpenCompass will set them based on the following logic (a toy sketch of the detection follows this list):

- If options like `A`, `B`, `C`, etc., can be parsed from the dataset file, it is considered an `mcq` dataset; otherwise, it is considered a `qa` dataset.
- The default `infer_method` is `gen`.
### Configuration File

In the original configuration file, simply add a new item to the `datasets` variable. Custom datasets can be mixed with regular datasets.

```python
datasets = [
    {"path": "xxx/test_mcq.csv", "data_type": "mcq", "infer_method": "ppl"},
    {"path": "xxx/test_qa.jsonl", "data_type": "qa", "infer_method": "gen"},
]
```

### Supplemental Information for Dataset `.meta.json`

OpenCompass will try to parse the input dataset file by default, so in most cases the `.meta.json` file is **not necessary**. However, if the dataset field names are not the default ones, or custom prompt words are required, they should be specified in the `.meta.json` file.

The file is placed in the same directory as the dataset, with the filename followed by `.meta.json`. An example file structure is as follows:

```tree
.
├── test_mcq.csv
├── test_mcq.csv.meta.json
├── test_qa.jsonl
└── test_qa.jsonl.meta.json
```

Possible fields in this file include:

- `abbr` (str): Abbreviation of the dataset, serving as its ID.
- `data_type` (str): Type of dataset, options are `mcq` and `qa`.
- `infer_method` (str): Inference method, options are `ppl` and `gen`.
- `human_prompt` (str): User prompt template for generating prompts. Variables in the template are enclosed in `{}`, like `{question}`, `{opt1}`, etc. If `template` exists, this field will be ignored.
- `bot_prompt` (str): Bot prompt template for generating prompts. Variables in the template are enclosed in `{}`, like `{answer}`, etc. If `template` exists, this field will be ignored.
- `template` (str or dict): Question template for generating prompts. Variables in the template are enclosed in `{}`, like `{question}`, `{opt1}`, etc. The relevant syntax for `infer_cfg['prompt_template']['template']` is described [here](../prompt/prompt_template.md).
- `input_columns` (list): List of input fields for reading data.
- `output_column` (str): Output field for reading data.
- `options` (list): List of options for reading data, valid only when `data_type` is `mcq`.

For example:

```json
{
    "human_prompt": "Question: 127 + 545 + 588 + 620 + 556 + 199 =\nA. 2632\nB. 2635\nC. 2645\nAnswer: Let's think step by step, 127 + 545 + 588 + 620 + 556 + 199 = 672 + 588 + 620 + 556 + 199 = 1260 + 620 + 556 + 199 = 1880 + 556 + 199 = 2436 + 199 = 2635. So the answer is B.\nQuestion: {question}\nA. {A}\nB. {B}\nC. {C}\nAnswer: ",
    "bot_prompt": "{answer}"
}
```

or

```json
{
    "template": "Question: {my_question}\nX. {X}\nY. {Y}\nZ. {Z}\nW. {W}\nAnswer:",
    "input_columns": ["my_question", "X", "Y", "Z", "W"],
    "output_column": "my_answer"
}
```
docs/en/advanced_guides/evaluation_lightllm.md ADDED
@@ -0,0 +1,71 @@
# Evaluation with Lightllm

We now support the evaluation of large language models using [Lightllm](https://github.com/ModelTC/lightllm) for inference. Developed by SenseTime, Lightllm is a Python-based LLM inference and serving framework, notable for its lightweight design, easy scalability, and high-speed performance. Lightllm supports various large language models, allowing users to perform model inference through Lightllm and deploy it locally as a service. During the evaluation process, OpenCompass feeds data to Lightllm through an API and processes the responses. OpenCompass has been adapted for compatibility with Lightllm, and this tutorial will guide you on using OpenCompass to evaluate models with Lightllm as the inference backend.

## Setup

### Install OpenCompass

Please follow the [instructions](https://opencompass.readthedocs.io/en/latest/get_started/installation.html) to install OpenCompass and prepare the evaluation datasets.

### Install Lightllm

Please follow the [Lightllm homepage](https://github.com/ModelTC/lightllm) to install Lightllm. Pay attention to aligning the versions of relevant dependencies, especially the version of Transformers.

## Evaluation

We use the evaluation of Humaneval with the llama2-7B model as an example.

### Step-1: Deploy the model locally as a service using Lightllm.

```shell
python -m lightllm.server.api_server --model_dir /path/llama2-7B \
    --host 0.0.0.0 \
    --port 1030 \
    --nccl_port 2066 \
    --max_req_input_len 4096 \
    --max_req_total_len 6144 \
    --tp 1 \
    --trust_remote_code \
    --max_total_token_num 120000
```

**Note:** `tp` can be configured to enable TensorParallel inference on several GPUs, suitable for the inference of very large models.

**Note:** The `max_total_token_num` in the above command will affect the throughput performance during testing. It can be configured according to the documentation on the [Lightllm homepage](https://github.com/ModelTC/lightllm). As long as it does not run out of memory, it is often better to set it as high as possible.

**Note:** If you want to start multiple Lightllm services on the same machine, you need to reconfigure the `port` and `nccl_port` above.

You can use the following Python script to quickly test whether the current service has been successfully started.

```python
import json

import requests

# Match the --port used when launching the api_server above.
url = 'http://localhost:1030/generate'
headers = {'Content-Type': 'application/json'}
data = {
    'inputs': 'What is AI?',
    'parameters': {
        'do_sample': False,
        'ignore_eos': False,
        'max_new_tokens': 1024,
    }
}
response = requests.post(url, headers=headers, data=json.dumps(data))
if response.status_code == 200:
    print(response.json())
else:
    print('Error:', response.status_code, response.text)
```

### Step-2: Evaluate the above model using OpenCompass.

```shell
python run.py examples/eval_lightllm.py
```

You are expected to get the evaluation results after the inference and evaluation.

**Note:** In `eval_lightllm.py`, please align the configured URL with the service address from the previous step.
docs/en/advanced_guides/evaluation_lmdeploy.md ADDED
@@ -0,0 +1,88 @@
# Evaluation with LMDeploy

We now support the evaluation of models accelerated by [LMDeploy](https://github.com/InternLM/lmdeploy). LMDeploy is a toolkit designed for compressing, deploying, and serving LLMs, with remarkable inference performance. Below we illustrate how to evaluate a model with LMDeploy support in OpenCompass.

## Setup

### Install OpenCompass

Please follow the [instructions](https://opencompass.readthedocs.io/en/latest/get_started/installation.html) to install OpenCompass and prepare the evaluation datasets.

### Install LMDeploy

Install lmdeploy via pip (python 3.8+):

```shell
pip install lmdeploy
```

The default prebuilt package is compiled on CUDA 12. If CUDA 11+ is required instead, you can install lmdeploy by:

```shell
export LMDEPLOY_VERSION=0.6.0
export PYTHON_VERSION=310
pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118
```

## Evaluation

When evaluating a model, it is necessary to prepare an evaluation configuration that specifies information such as the evaluation dataset, the model, and inference parameters.

Taking [internlm2-chat-7b](https://huggingface.co/internlm/internlm2-chat-7b) as an example, the evaluation config is as follows:

```python
# configure the dataset
from mmengine.config import read_base

with read_base():
    # choose a list of datasets
    from .datasets.mmlu.mmlu_gen_a484b3 import mmlu_datasets
    from .datasets.ceval.ceval_gen_5f30c7 import ceval_datasets
    from .datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
    from opencompass.configs.datasets.gsm8k.gsm8k_0shot_v2_gen_a58960 import \
        gsm8k_datasets
    # and output the results in a chosen format
    from .summarizers.medium import summarizer

datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

# configure lmdeploy
from opencompass.models import TurboMindModelwithChatTemplate

# configure the model
models = [
    dict(
        type=TurboMindModelwithChatTemplate,
        abbr='internlm2-chat-7b-lmdeploy',
        # model path, which can be the address of a model repository on the Hugging Face Hub or a local path
        path='internlm/internlm2-chat-7b',
        # inference backend of LMDeploy. It can be either 'turbomind' or 'pytorch'.
        # If the model is not supported by 'turbomind', it will fall back to 'pytorch'
        backend='turbomind',
        # For the detailed engine config and generation config, please refer to
        # https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/messages.py
        engine_config=dict(tp=1),
        gen_config=dict(do_sample=False),
        # the max size of the context window
        max_seq_len=7168,
        # the max number of new tokens
        max_out_len=1024,
        # the max number of prompts that LMDeploy receives in the `generate` function
        batch_size=5000,
        run_cfg=dict(num_gpus=1),
    )
]
```

Place the aforementioned configuration in a file, such as `configs/eval_internlm2_lmdeploy.py`. Then, in the home folder of OpenCompass, start evaluation with the following command:

```shell
python run.py configs/eval_internlm2_lmdeploy.py -w outputs
```

You are expected to get the evaluation results after the inference and evaluation.
docs/en/advanced_guides/llm_judge.md ADDED
@@ -0,0 +1,370 @@
# LLM as Judge Evaluation

## Introduction

The GenericLLMEvaluator is particularly useful for scenarios where rule-based methods (like regular expressions) cannot perfectly judge outputs, such as:

- Cases where models output answer content without option identifiers
- Factual judgment datasets that are difficult to evaluate with rules
- Open-ended responses requiring complex understanding and reasoning
- Evaluation that would otherwise require designing many rules

OpenCompass provides the GenericLLMEvaluator component to facilitate LLM-as-judge evaluations.

## Dataset Format

The dataset for LLM judge evaluation should be in either JSON Lines (.jsonl) or CSV format. Each entry should contain at least:

- A problem or question
- A reference answer or gold standard
- (The model's prediction will be generated during evaluation)

Example JSONL format:

```json
{"problem": "What is the capital of France?", "answer": "Paris"}
```

Example CSV format:

```csv
problem,answer
"What is the capital of France?","Paris"
```

## Configuration

### Using LLM for Evaluation via Command Line

Some datasets in OpenCompass already include LLM judge configurations.
You need to use a model service (such as OpenAI or DeepSeek's official API) or start a model service locally using tools like LMDeploy, vLLM, or SGLang.

Then, you can set the environment variables for the evaluation service and evaluate models using the following commands:

```bash
export OC_JUDGE_MODEL=Qwen/Qwen2.5-32B-Instruct
export OC_JUDGE_API_KEY=sk-1234
export OC_JUDGE_API_BASE=http://172.30.56.1:4000/v1
```

Note that by default, OpenCompass will use these three environment variables; if you configure the evaluation service through configuration files instead, these environment variables will not take effect.

### Using LLM for Evaluation via Configuration Files

To set up an LLM judge evaluation, you'll need to configure three main components:

1. Dataset Reader Configuration

```python
reader_cfg = dict(
    input_columns=['problem'],  # Column name for the question
    output_column='answer'  # Column name for the reference answer
)
```

2. Inference Configuration

```python
infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role='HUMAN',
                    prompt='{problem}',  # Template for prompting the model
                ),
            ]
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)
```

3. Evaluation Configuration with LLM Judge

```python
eval_cfg = dict(
    evaluator=dict(
        type=GenericLLMEvaluator,  # Using LLM as evaluator
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(
                begin=[
                    dict(
                        role='SYSTEM',
                        fallback_role='HUMAN',
                        prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.",
                    )
                ],
                round=[
                    dict(role='HUMAN', prompt=YOUR_JUDGE_TEMPLATE),  # Template for the judge
                ],
            ),
        ),
        dataset_cfg=dict(
            type=CustomDataset,
            path='path/to/your/dataset',
            file_name='your_dataset.jsonl',
            reader_cfg=reader_cfg,
        ),
        judge_cfg=YOUR_JUDGE_MODEL_CONFIG,  # Configuration for the judge model
        dict_postprocessor=dict(type=generic_llmjudge_postprocess),  # Post-processing the judge's output
    ),
)
```
|
| 118 |
+
## Using CustomDataset with GenericLLMEvaluator

Here's how to set up a complete configuration for LLM judge evaluation:

```python
from mmengine.config import read_base
from opencompass.models import TurboMindModelwithChatTemplate
from opencompass.datasets import CustomDataset
from opencompass.evaluator import GenericLLMEvaluator
from opencompass.datasets import generic_llmjudge_postprocess
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer

# Import your judge model configuration
with read_base():
    from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_14b_instruct import (
        models as judge_model,
    )

# Define your judge template
JUDGE_TEMPLATE = """
Please evaluate whether the following response correctly answers the question.
Question: {problem}
Reference Answer: {answer}
Model Response: {prediction}

Is the model response correct? If correct, answer "A"; if incorrect, answer "B".
""".strip()

# Dataset reader configuration
reader_cfg = dict(input_columns=['problem'], output_column='answer')

# Inference configuration for the model being evaluated
infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role='HUMAN',
                    prompt='{problem}',
                ),
            ]
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

# Evaluation configuration with LLM judge
eval_cfg = dict(
    evaluator=dict(
        type=GenericLLMEvaluator,
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(
                begin=[
                    dict(
                        role='SYSTEM',
                        fallback_role='HUMAN',
                        prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.",
                    )
                ],
                round=[
                    dict(role='HUMAN', prompt=JUDGE_TEMPLATE),
                ],
            ),
        ),
        dataset_cfg=dict(
            type=CustomDataset,
            path='path/to/your/dataset',
            file_name='your_dataset.jsonl',
            reader_cfg=reader_cfg,
        ),
        judge_cfg=judge_model[0],
        dict_postprocessor=dict(type=generic_llmjudge_postprocess),
    ),
    pred_role='BOT',
)

# Dataset configuration
datasets = [
    dict(
        type=CustomDataset,
        abbr='my-dataset',
        path='path/to/your/dataset',
        file_name='your_dataset.jsonl',
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg=eval_cfg,
    )
]

# Model configuration for the model being evaluated
models = [
    dict(
        type=TurboMindModelwithChatTemplate,
        abbr='model-to-evaluate',
        path='path/to/your/model',
        # ... other model configurations
    )
]

# Output directory
work_dir = './outputs/llm_judge_eval'
```
## GenericLLMEvaluator

The GenericLLMEvaluator is designed to use an LLM as a judge for evaluating model outputs. Key features include:

1. Flexible prompt templates for instructing the judge
2. Support for various judge models (local or API-based)
3. Customizable evaluation criteria through prompt engineering
4. Post-processing of judge outputs to extract structured evaluations

**Important Note**: The current generic version of the judge template only supports outputs in the format of "A" (correct) or "B" (incorrect), and does not support other output formats (such as "CORRECT" or "INCORRECT"). This is because the post-processing function `generic_llmjudge_postprocess` is specifically designed to parse this format.

The evaluator works by:

1. Taking the original problem, reference answer, and model prediction
2. Formatting them into a prompt for the judge model
3. Parsing the judge's response to determine the evaluation result (looking for "A" or "B")
4. Aggregating results across the dataset
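
To make the parsing step concrete, here is a minimal sketch of the kind of "A"/"B" verdict extraction described above. The function name and the unparsed-verdict fallback are illustrative assumptions; the actual `generic_llmjudge_postprocess` implementation in OpenCompass is more thorough.

```python
import re


def parse_judge_verdict(judge_output: str) -> bool:
    """Illustrative sketch (not the OpenCompass implementation).

    Looks for a standalone "A" (correct) or "B" (incorrect) in the
    judge's reply, mirroring the template contract described above.
    """
    match = re.search(r'\b([AB])\b', judge_output)
    if match is None:
        return False  # assumption: treat unparseable verdicts as incorrect
    return match.group(1) == 'A'
```
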
If you would like to see the full details of the evaluation results, you can add `--dump-eval-details` to the command line when you start the job.

Example evaluation output:

```python
{
    'accuracy': 75.0,  # Percentage of responses judged as correct
    'details': [
        {
            'origin_prompt': """
Please evaluate whether the following response correctly answers the question.
Question: What is the capital of France?
Reference Answer: Paris
Model Response: Paris
Is the model response correct? If correct, answer "A"; if incorrect, answer "B".
""",
            'gold': 'Paris',
            'prediction': 'A',
        },
        # ... more results
    ]
}
```
## CascadeEvaluator

OpenCompass also provides a CascadeEvaluator that combines the strengths of rule-based evaluation and LLM-based evaluation. The cascade evaluator has two modes:

1. **Cascade Mode (parallel=False)**: First evaluates all samples with a rule-based evaluator, then sends only the samples deemed incorrect by the rule-based evaluation to an LLM judge for re-evaluation. This reduces reliance on LLM judgments while maintaining accuracy, lowering evaluation cost and time.

2. **Parallel Mode (parallel=True)**: Evaluates all samples with both the rule-based evaluator and the LLM judge, then considers a sample correct if either method marks it as correct. This makes the evaluation more lenient but may cost more, since every sample requires LLM evaluation.

### Configuring CascadeEvaluator

Here's an example of how to configure the CascadeEvaluator:

```python
# Define a rule-based evaluator
rule_evaluator = dict(type=MATHVerifyEvaluator)

# Define an LLM judge evaluator
llm_judge_evaluator = dict(
    type=GenericLLMEvaluator,
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            begin=[
                dict(
                    role='SYSTEM',
                    fallback_role='HUMAN',
                    prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.",
                )
            ],
            round=[
                dict(role='HUMAN', prompt=YOUR_JUDGE_TEMPLATE),
            ],
        ),
    ),
    dataset_cfg=dict(
        type=YourDataset,
        path='path/to/your/dataset',
        reader_cfg=reader_cfg,
    ),
    judge_cfg=dict(),  # Can use environment variables to configure the judge model
)

# Configure cascade evaluator (cascade mode)
cascade_evaluator = dict(
    type=CascadeEvaluator,
    llm_evaluator=llm_judge_evaluator,
    rule_evaluator=rule_evaluator,
    parallel=False  # Cascade mode
)

# For parallel mode, set parallel=True
parallel_evaluator = dict(
    type=CascadeEvaluator,
    llm_evaluator=llm_judge_evaluator,
    rule_evaluator=rule_evaluator,
    parallel=True  # Parallel mode
)

# Use the cascade evaluator in your dataset evaluation config
eval_cfg = dict(evaluator=cascade_evaluator)
```
### Evaluation Results

The cascade evaluator outputs detailed evaluation statistics, including:

- Accuracy of the rule-based evaluation
- Accuracy of the LLM evaluation (for samples that failed rule-based evaluation in cascade mode)
- Final combined accuracy

Example output:

```python
{
    'accuracy': 85.0,  # Final accuracy
    'cascade_stats': {
        'total_samples': 100,
        'rule_correct': 70,      # Number of samples correct by rule evaluation
        'rule_accuracy': 70.0,   # Accuracy of rule evaluation
        'llm_evaluated': 30,     # Number of samples evaluated by LLM (failed samples in cascade mode)
        'llm_correct': 15,       # Number of samples correct by LLM evaluation
        'llm_accuracy': 50.0,    # Accuracy of LLM evaluation
        'final_correct': 85,     # Total correct samples
        'final_accuracy': 85.0,  # Final accuracy
        'parallel_mode': False,  # Whether parallel mode was used
    },
    'details': [
        # Detailed evaluation results for each sample
    ]
}
```

The cascade evaluator is particularly useful for:

1. Scenarios that require balancing evaluation cost and accuracy
2. Cases where rule-based evaluators are available but not fully comprehensive
3. Evaluation tasks that need more nuanced judgment for edge cases
## Complete Example

For a complete working example using GenericLLMEvaluator, refer to the `eval_llm_judge.py` file in the examples directory, which demonstrates how to evaluate mathematical problem-solving.

For a complete working example using CascadeEvaluator, refer to the `eval_cascade_evaluator.py` file in the examples directory, which demonstrates how to evaluate mathematical problem-solving.
docs/en/advanced_guides/longeval.md

# Long Context Evaluation Guidance

## Introduction

Although large language models (LLMs) such as GPT-4 have demonstrated significant advantages in natural language tasks, most current open-source models can only handle texts a few thousand tokens long, which limits their ability to process long contexts, such as reading books or summarizing long documents. To explore how models perform on long contexts, we use the [L-Eval](https://github.com/OpenLMLab/LEval) and [LongBench](https://github.com/THUDM/LongBench) datasets to test a model's ability to handle long contexts.

## Existing Algorithms and Models

When dealing with long context inputs, the two main challenges faced by large models are the inference time cost and catastrophic forgetting. Recently, a large amount of research has been devoted to extending model context length, focusing on three directions:

- Attention mechanisms. The ultimate goal of these methods is to reduce the computation cost of query-key pairs, but they may affect the performance of downstream tasks.
- Input methods. Some studies divide long context inputs into chunks or retrieve pre-existing text segments to enhance the model's ability to handle long contexts, but these methods are only effective for some tasks and are difficult to adapt to multiple downstream tasks.
- Position encoding. This line of research includes RoPE, ALiBi, Position Interpolation, etc., which have shown good results in length extrapolation. These methods have been used to train long context models such as ChatGLM2-6B-32k and LongChat-32k.

First, we introduce some popular position encoding algorithms.
### RoPE

RoPE (Rotary Position Embedding) is a type of positional embedding that injects position information into the Transformer. It encodes the absolute position with a rotation matrix and meanwhile incorporates explicit relative position dependency into the self-attention formulation. A graphic illustration of RoPE is shown below.

<div align="center">
<img src=https://github.com/open-compass/opencompass/assets/75252858/08c57958-0dcb-40d7-b91b-33f20ca2d89f>
</div>

RoPE comes with valuable properties, such as the flexibility of being expanded to any sequence length, decaying inter-token dependency with increasing relative distance, and the capability of equipping linear self-attention with relative position encoding.

RoPE is adopted in many LLMs, including LLaMA, LLaMA 2, and Vicuna-7b-v1.5-16k.
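
As a concrete illustration, below is a minimal NumPy sketch of applying a rotary embedding to a sequence of query (or key) vectors. The dimension pairing and the base of 10000 follow the common RoPE convention; this is an illustrative sketch, not OpenCompass code.

```python
import numpy as np


def apply_rope(x: np.ndarray, base: float = 10000.0) -> np.ndarray:
    """Rotate consecutive dimension pairs of x by position-dependent angles.

    x has shape (seq_len, d) with d even; row i is the query/key at position i.
    """
    seq_len, d = x.shape
    pos = np.arange(seq_len)[:, None]           # (seq_len, 1) position indices
    freqs = base ** (-np.arange(0, d, 2) / d)   # (d/2,) rotation frequencies
    angles = pos * freqs                        # (seq_len, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin          # 2-D rotation of each pair
    out[:, 1::2] = x1 * sin + x2 * cos
    return out
```

Because each pair is rotated by an angle proportional to its position, the dot product between a rotated query and a rotated key depends only on their relative distance, which is the property that makes RoPE a relative position encoding.
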
### ALiBi

Though RoPE and other alternatives to the original sinusoidal position method (like the T5 bias) have improved extrapolation, they are considerably slower than the sinusoidal approach and use extra memory and parameters. Therefore, Attention with Linear Biases (ALiBi) was introduced to facilitate efficient extrapolation.

For an input subsequence of length L, the attention sublayer computes the attention scores for the i-th query

```{math}
q_{i} \in R^{1 \times d}, (1 \leq i \leq L)
```

in each head, given the first i keys

```{math}
K \in R^{i \times d}
```

where d is the head dimension:

```{math}
softmax(q_{i}K^{T})
```

ALiBi negatively biases attention scores with a linearly decreasing penalty proportional to the distance between the relevant key and query. The only modification it applies comes after the query-key dot product, where it adds a static, non-learned bias:

```{math}
softmax(q_{i}K^{T}+m\cdot[-(i-1),...,-2,-1,0])
```

where the scalar m is a head-specific slope fixed before training.

ALiBi eliminates position embeddings and is as fast as the sinusoidal approach. It is used in LLMs including mpt-7b-storywriter, which is designed to handle extremely long inputs.
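
To make the formula concrete, here is a minimal NumPy sketch of the biased attention logits for the i-th query. The slope value is a placeholder (in the ALiBi paper, the slopes form a head-specific geometric sequence).

```python
import numpy as np


def alibi_logits(q_i: np.ndarray, K: np.ndarray, m: float) -> np.ndarray:
    """Attention logits for the i-th query with the ALiBi distance penalty.

    q_i: (d,) query at position i; K: (i, d) keys at positions 1..i;
    m: head-specific slope. Computes the softmax input
    q_i K^T + m * [-(i-1), ..., -2, -1, 0] from the formula above.
    """
    i = K.shape[0]
    bias = m * -np.arange(i - 1, -1, -1, dtype=float)  # [-(i-1), ..., -1, 0]
    return q_i @ K.T + bias
```
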
### Position Interpolation (PI)

Many existing pre-trained LLMs, including LLaMA, use positional encodings that have weak extrapolation properties (e.g., RoPE). Position Interpolation was proposed to enable very long context windows while preserving model quality relatively well on tasks within the original context window size.

The key idea of Position Interpolation is to directly down-scale the position indices so that the maximum position index matches the previous context window limit from the pre-training stage. In other words, to accommodate more input tokens, the algorithm interpolates position encodings at neighboring integer positions, utilizing the fact that position encodings can be applied at non-integer positions, as opposed to extrapolating beyond the trained positions, which may lead to catastrophic values. The algorithm requires only a very short period of fine-tuning for the model to fully adapt to greatly extended context windows.

An illustration of the Position Interpolation method is shown below. The lower left illustrates Position Interpolation, where the position indices (blue and green dots) themselves are downscaled from \[0, 4096\] to \[0, 2048\], forcing them to reside in the pretrained range.

<div align="center">
<img src=https://github.com/open-compass/opencompass/assets/75252858/406454ba-a811-4c66-abbe-3a5528947257>
</div>

Position Interpolation empowers ChatGLM2-6B-32k, a model based on ChatGLM2-6B, to handle a 32k context window.
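
The down-scaling itself is a one-line operation. The sketch below shows it for the 4096-to-2048 example in the figure (illustrative only); the resulting non-integer positions can then be fed into a position encoding such as RoPE.

```python
import numpy as np

# Position Interpolation: rescale extended position indices back into the
# pretrained range (here 4096 target positions into [0, 2048)).
train_ctx, target_ctx = 2048, 4096
positions = np.arange(target_ctx)                        # 0, 1, ..., 4095
scaled_positions = positions * (train_ctx / target_ctx)  # 0.0, 0.5, 1.0, ...
```
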
Next, we introduce some of the long context language models we evaluate.
### XGen-7B-8k

XGen-7B-8k is trained with standard dense attention on sequences of up to 8k tokens, for up to 1.5T tokens in total. To mitigate slow training, XGen-7B-8k is trained in stages with increasing sequence length: first 800B tokens at a sequence length of 2k, then 400B tokens at 4k, and finally 300B tokens at 8k.

### Vicuna-7b-v1.5-16k

Vicuna-7b-v1.5-16k is fine-tuned from LLaMA 2 with supervised instruction fine-tuning and linear RoPE scaling. The training data consists of around 125K conversations collected from ShareGPT, a website where users can share their ChatGPT conversations. These conversations are packed into sequences of 16k tokens each.

### LongChat-7b-v1.5-32k

LongChat-7b-v1.5-32k is fine-tuned from LLaMA 2 models, which were originally pretrained with a 4k context length. The training recipe can be conceptually described in two steps. The first step is condensing RoPE: since the LLaMA model has not observed scenarios where position_ids > 4096 during the pre-training phase, LongChat condenses position_ids > 4096 to lie within 0 to 4096. The second step is fine-tuning the LongChat model on curated conversation data. In this step, the data is cleaned using the FastChat data pipeline and truncated to the model's maximum length.

### ChatGLM2-6B-32k

ChatGLM2-6B-32k further strengthens the ability to understand long texts based on ChatGLM2-6B. Based on the method of Position Interpolation, and trained with a 32k context length during dialogue alignment, ChatGLM2-6B-32k can better handle contexts of up to 32k in length.
## [L-Eval](https://github.com/OpenLMLab/LEval)

L-Eval is a long context dataset built by OpenLMLab, consisting of 18 subtasks that include texts from various fields such as law, economics, and technology. The dataset consists of a total of 411 documents and over 2000 test cases, with an average document length of 7217 words. The subtasks are divided into close-ended and open-ended categories: the 5 close-ended tasks are evaluated using the exact match criterion, and the 13 open-ended tasks are evaluated using Rouge scores.

## [LongBench](https://github.com/THUDM/LongBench)

LongBench is a long context dataset built by THUDM, consisting of 21 subtasks with a total of 4750 test cases. It is the first long context dataset to include both English and Chinese texts, with an average English text length of 6711 words and an average Chinese text length of 13386 characters. The 21 subtasks are divided into 6 types, providing a more comprehensive evaluation of the model's capabilities in various aspects.

<div align="center">
<img src=https://github.com/open-compass/opencompass/assets/75252858/4555e937-c519-4e9c-ad8d-7370430d466a>
</div>
## Evaluation Method

Because different models accept different maximum input lengths, in order to compare these large models fairly, when the input length exceeds the model's maximum input limit we trim the middle part of the input text, so that the prompt instructions are not lost.

## Long Context Ability Ranking

In the LongBench and L-Eval ability rankings, we use each model's average rank across the subtasks as the standard **(the lower the better)**. GPT-4 and GPT-3.5-turbo-16k still occupy a leading position in long context tasks, while models like ChatGLM2-6B-32k show significant improvement in long context ability after applying Position Interpolation to ChatGLM2-6B.

<div align="center">
<img src=https://github.com/open-compass/opencompass/assets/75252858/29b5ad12-d9a3-4255-be0a-f770923fe514>
<img src=https://github.com/open-compass/opencompass/assets/75252858/680b4cda-c2b1-45d1-8c33-196dee1a38f3>
</div>

The original scores are shown below.
| L-Eval            | GPT-4 | GPT-3.5-turbo-16k | chatglm2-6b-32k | vicuna-7b-v1.5-16k | xgen-7b-8k | internlm-chat-7b-8k | longchat-7b-v1.5-32k | chatglm2-6b |
| ----------------- | ----- | ----------------- | --------------- | ------------------ | ---------- | ------------------- | -------------------- | ----------- |
| coursera          | 61.05 | 50                | 45.35           | 26.74              | 33.72      | 40.12               | 27.91                | 38.95       |
| gsm100            | 92    | 78                | 27              | 11                 | 8          | 19                  | 5                    | 8           |
| quality           | 81.19 | 62.87             | 44.55           | 11.39              | 33.66      | 45.54               | 29.7                 | 41.09       |
| tpo               | 72.93 | 74.72             | 56.51           | 17.47              | 44.61      | 60.59               | 17.1                 | 56.51       |
| topic_retrieval   | 100   | 79.33             | 44.67           | 24.67              | 1.33       | 0                   | 25.33                | 1.33        |
|                   |       |                   |                 |                    |            |                     |                      |             |
| financialqa       | 53.49 | 50.32             | 35.41           | 44.59              | 39.28      | 25.09               | 34.07                | 17.82       |
| gov_report        | 50.84 | 50.48             | 42.97           | 48.17              | 38.52      | 31.29               | 36.52                | 41.88       |
| legal_contract_qa | 31.23 | 27.97             | 34.21           | 24.25              | 21.36      | 19.28               | 13.32                | 17.59       |
| meeting_summ      | 31.44 | 33.54             | 29.13           | 28.52              | 27.96      | 17.56               | 22.32                | 15.98       |
| multidocqa        | 37.81 | 35.84             | 28.6            | 26.88              | 24.41      | 22.43               | 21.85                | 19.66       |
| narrativeqa       | 25.87 | 25.73             | 18.24           | 20.58              | 16.87      | 13.81               | 16.87                | 1.16        |
| nq                | 67.36 | 66.91             | 41.06           | 36.44              | 29.43      | 16.42               | 35.02                | 0.92        |
| news_summ         | 34.52 | 40.41             | 32.72           | 33.98              | 26.87      | 22.48               | 30.33                | 29.51       |
| paper_assistant   | 42.26 | 41.76             | 34.59           | 35.83              | 25.39      | 28.25               | 30.42                | 30.43       |
| patent_summ       | 48.61 | 50.62             | 46.04           | 48.87              | 46.53      | 30.3                | 41.6                 | 41.25       |
| review_summ       | 31.98 | 33.37             | 21.88           | 29.21              | 26.85      | 16.61               | 20.02                | 19.68       |
| scientificqa      | 49.76 | 48.32             | 31.27           | 31                 | 27.43      | 33.01               | 20.98                | 13.61       |
| tvshow_summ       | 34.84 | 31.36             | 23.97           | 27.88              | 26.6       | 14.55               | 25.09                | 19.45       |

| LongBench           | GPT-4 | GPT-3.5-turbo-16k | chatglm2-6b-32k | longchat-7b-v1.5-32k | vicuna-7b-v1.5-16k | internlm-chat-7b-8k | chatglm2-6b | xgen-7b-8k |
| ------------------- | ----- | ----------------- | --------------- | -------------------- | ------------------ | ------------------- | ----------- | ---------- |
| NarrativeQA         | 31.2  | 25.79             | 19.27           | 19.19                | 23.65              | 12.24               | 13.09       | 18.85      |
| Qasper              | 42.77 | 43.4              | 33.93           | 30.36                | 31.45              | 24.81               | 22.52       | 20.18      |
| MultiFieldQA-en     | 55.1  | 54.35             | 45.58           | 44.6                 | 43.38              | 25.41               | 38.09       | 37         |
| MultiFieldQA-zh     | 64.4  | 61.92             | 52.94           | 32.35                | 44.65              | 36.13               | 37.67       | 14.7       |
|                     |       |                   |                 |                      |                    |                     |             |            |
| HotpotQA            | 59.85 | 52.49             | 46.41           | 34.43                | 34.17              | 27.42               | 27.35       | 28.78      |
| 2WikiMQA            | 67.52 | 41.7              | 33.63           | 23.06                | 20.45              | 26.24               | 22.83       | 20.13      |
| Musique             | 37.53 | 27.5              | 21.57           | 12.42                | 13.92              | 9.75                | 7.26        | 11.34      |
| DuReader (zh)       | 38.65 | 29.37             | 38.53           | 20.25                | 20.42              | 11.11               | 17.18       | 8.57       |
|                     |       |                   |                 |                      |                    |                     |             |            |
| GovReport           | 32.09 | 29.92             | 32.47           | 29.83                | 29.27              | 18.38               | 22.86       | 23.37      |
| QMSum               | 24.37 | 23.67             | 23.19           | 22.71                | 23.37              | 18.45               | 21.23       | 21.12      |
| Multi_news          | 28.52 | 27.05             | 25.12           | 26.1                 | 27.83              | 24.52               | 24.7        | 23.69      |
| VCSUM (zh)          | 15.54 | 16.88             | 15.95           | 13.46                | 15.76              | 12.91               | 14.07       | 0.98       |
|                     |       |                   |                 |                      |                    |                     |             |            |
| TREC                | 78.5  | 73.5              | 30.96           | 29.23                | 32.06              | 39                  | 24.46       | 29.31      |
| TriviaQA            | 92.19 | 92.75             | 80.64           | 64.19                | 46.53              | 79.55               | 64.19       | 69.58      |
| SAMSum              | 46.32 | 43.16             | 29.49           | 25.23                | 25.23              | 43.05               | 20.22       | 16.05      |
| LSHT (zh)           | 41.5  | 34.5              | 22.75           | 20                   | 24.75              | 20.5                | 16          | 18.67      |
|                     |       |                   |                 |                      |                    |                     |             |            |
| Passage Count       | 8.5   | 3                 | 3               | 1                    | 3                  | 1.76                | 3           | 1          |
| PassageRetrieval-en | 75    | 73                | 57.5            | 20.5                 | 16.5               | 7                   | 5.5         | 12         |
| PassageRetrieval-zh | 96    | 82.5              | 58              | 15                   | 21                 | 2.29                | 5           | 3.75       |
|                     |       |                   |                 |                      |                    |                     |             |            |
| LCC                 | 59.25 | 53.49             | 53.3            | 51.46                | 49.3               | 49.32               | 46.59       | 44.1       |
| RepoBench-P         | 55.42 | 55.95             | 46.66           | 52.18                | 41.49              | 35.86               | 41.97       | 41.83      |
docs/en/advanced_guides/math_verify.md

# General Math Evaluation Guidance

## Introduction

Mathematical reasoning is a crucial capability for large language models (LLMs). To evaluate a model's mathematical abilities, we need to test its capability to solve mathematical problems step by step and provide accurate final answers. OpenCompass provides a convenient way to evaluate mathematical reasoning through the CustomDataset and MATHVerifyEvaluator components.

## Dataset Format

The math evaluation dataset should be in either JSON Lines (.jsonl) or CSV format. Each problem should contain at least:

- A problem statement
- A solution/answer (typically in LaTeX format with the final answer in \\boxed{})

Example JSONL format:

```json
{"problem": "Find the value of x if 2x + 3 = 7", "solution": "Let's solve step by step:\n2x + 3 = 7\n2x = 7 - 3\n2x = 4\nx = 2\nTherefore, \\boxed{2}"}
```

Example CSV format:

```csv
problem,solution
"Find the value of x if 2x + 3 = 7","Let's solve step by step:\n2x + 3 = 7\n2x = 7 - 3\n2x = 4\nx = 2\nTherefore, \\boxed{2}"
```
## Configuration

To evaluate mathematical reasoning, you'll need to set up three main components:

1. Dataset Reader Configuration

```python
math_reader_cfg = dict(
    input_columns=['problem'],  # Column name for the question
    output_column='solution'    # Column name for the answer
)
```

2. Inference Configuration

```python
math_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role='HUMAN',
                    prompt='{problem}\nPlease reason step by step, and put your final answer within \\boxed{}.',
                ),
            ]
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)
```

3. Evaluation Configuration

```python
math_eval_cfg = dict(
    evaluator=dict(type=MATHVerifyEvaluator),
)
```
## Using CustomDataset

Here's how to set up a complete configuration for math evaluation:

```python
from mmengine.config import read_base
from opencompass.models import TurboMindModelwithChatTemplate
from opencompass.datasets import CustomDataset

math_datasets = [
    dict(
        type=CustomDataset,
        abbr='my-math-dataset',       # Dataset abbreviation
        path='path/to/your/dataset',  # Path to your dataset file
        reader_cfg=math_reader_cfg,
        infer_cfg=math_infer_cfg,
        eval_cfg=math_eval_cfg,
    )
]
```
## MATHVerifyEvaluator

The MATHVerifyEvaluator is specifically designed to evaluate mathematical answers. It is built on the math_verify library, which provides mathematical expression parsing and verification capabilities, supporting extraction and equivalence verification for both LaTeX and general expressions.

The MATHVerifyEvaluator:

1. Extracts answers from both predictions and references using LaTeX extraction
2. Handles various LaTeX formats and environments
3. Verifies mathematical equivalence between predicted and reference answers
4. Provides detailed evaluation results, including:
   - Accuracy score
   - Detailed comparison between predictions and references
   - Parse results of both predicted and reference answers

The evaluator supports:

- Basic arithmetic operations
- Fractions and decimals
- Algebraic expressions
- Trigonometric functions
- Roots and exponents
- Mathematical symbols and operators
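
To make the extraction step concrete, here is a minimal regex sketch of pulling a final answer out of a `\boxed{...}` expression. This is only an illustration: the actual math_verify library handles nested braces, multiple LaTeX environments, and genuine mathematical equivalence rather than string matching.

```python
import re
from typing import Optional


def extract_boxed_answer(solution: str) -> Optional[str]:
    """Illustrative sketch: return the content of the last \\boxed{...}.

    Assumes the boxed content contains no nested braces; math_verify's
    real parser is far more robust than this.
    """
    matches = re.findall(r'\\boxed\{([^{}]*)\}', solution)
    return matches[-1] if matches else None


# e.g. extract_boxed_answer('... Therefore, \\boxed{2}') returns '2'
```
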
Example evaluation output:

```python
{
    'accuracy': 85.0,  # Percentage of correct answers
    'details': [
        {
            'predictions': 'x = 2',  # Parsed prediction
            'references': 'x = 2',   # Parsed reference
            'correct': True          # Whether they match
        },
        # ... more results
    ]
}
```

## Complete Example

Here's a complete example of how to set up math evaluation:

```python
from mmengine.config import read_base
from opencompass.models import TurboMindModelwithChatTemplate
from opencompass.datasets import CustomDataset
from opencompass.openicl.icl_evaluator.math_evaluator import MATHVerifyEvaluator
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer

# Dataset reader configuration
math_reader_cfg = dict(input_columns=['problem'], output_column='solution')

# Inference configuration
math_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role='HUMAN',
                    prompt='{problem}\nPlease reason step by step, and put your final answer within \\boxed{}.',
                ),
            ]
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

# Evaluation configuration
math_eval_cfg = dict(
    evaluator=dict(type=MATHVerifyEvaluator),
)

# Dataset configuration
math_datasets = [
    dict(
        type=CustomDataset,
        abbr='my-math-dataset',
        path='path/to/your/dataset.jsonl',  # or .csv
        reader_cfg=math_reader_cfg,
        infer_cfg=math_infer_cfg,
        eval_cfg=math_eval_cfg,
    )
]

# Model configuration
models = [
    dict(
        type=TurboMindModelwithChatTemplate,
        abbr='your-model-name',
        path='your/model/path',
        # ... other model configurations
    )
]

# Output directory
work_dir = './outputs/math_eval'
```
docs/en/advanced_guides/needleinahaystack_eval.md

# Needle In A Haystack Evaluation

## Introduction to the Needle In A Haystack Test

The Needle In A Haystack test (inspired by [NeedleInAHaystack](https://github.com/gkamradt/LLMTest_NeedleInAHaystack/blob/main/LLMNeedleHaystackTester.py)) is an evaluation method in which key information is randomly inserted into long texts to form the prompt for large language models (LLMs). The test assesses whether LLMs can extract critical information from long texts, thereby evaluating their fundamental ability to comprehend and process long-context documents.

## Task Overview

Within the `OpenCompass` framework, under `NeedleBench`, we designed a series of progressively challenging evaluation tasks to comprehensively assess LLMs' long-text information extraction and reasoning capabilities. For a complete description, please refer to our [technical report](https://arxiv.org/abs/2407.11963).

- **Single-Needle Retrieval Task (S-RT)**: Evaluates the LLM's ability to retrieve a single piece of key information from a long text, testing precise recall of specific details within extensive narratives. This corresponds to the **original Needle In A Haystack test** setup.

- **Multi-Needle Retrieval Task (M-RT)**: Explores the LLM's ability to retrieve multiple relevant pieces of information from long texts, simulating complex queries over comprehensive documents.

- **Multi-Needle Reasoning Task (M-RS)**: Assesses the LLM's ability to integrate multiple key pieces of information extracted from long texts for reasoning, requiring a comprehensive understanding of the content.

- **Ancestral Trace Challenge (ATC)**: Tests the LLM's capability to handle multi-layer logical challenges within realistic long-text contexts through "kinship trace needles." In the ATC task, no irrelevant (haystack) texts are added; every piece of text is critical, and models must reason through all details to answer accurately.

> **Note:** NeedleBench (v2) includes several optimizations and adjustments in dataset construction and task details. For a detailed comparison between the old and new versions, as well as a summary of updates, please refer to [opencompass/configs/datasets/needlebench_v2/readme.md](https://github.com/open-compass/opencompass/blob/main/opencompass/configs/datasets/needlebench_v2/readme.md).

## Evaluation Steps

> Note: In the latest `OpenCompass` codebase, the NeedleBench dataset is automatically loaded from the [Huggingface interface](https://huggingface.co/datasets/opencompass/NeedleBench); there is no need for manual download or configuration.
### `OpenCompass` Environment Setup

```bash
conda create --name opencompass python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y
conda activate opencompass
git clone https://github.com/open-compass/opencompass opencompass
cd opencompass
pip install -e .
```

### Dataset Configuration

We have pre-configured various long-context settings (4k, 8k, 32k, 128k, 200k, 1000k) in `opencompass/configs/datasets/needlebench_v2`, and you can flexibly define your parameters by adjusting the configuration files.

### Evaluation Example

#### Evaluating the `Qwen2.5-7B` Model Deployed with `VLLM`

To evaluate the `Qwen2.5-7B` model deployed with `VLLM` on all tasks under NeedleBench-128K, use the following commands. They leverage pre-defined model and dataset configuration files, so no additional configuration is needed:

##### Local Evaluation

If evaluating locally, the command will use all available GPUs. You can control GPU visibility using `CUDA_VISIBLE_DEVICES`:

```bash
# Local evaluation
python run.py --datasets needlebench_v2_128k --models vllm_qwen2_5_7b_instruct_128k --summarizer needlebench/needlebench_v2_128k_summarizer
```

##### Evaluation on a Slurm Cluster

For Slurm environments, you can add options like `--slurm -p partition_name -q reserved --max-num-workers 16`:

```bash
# Slurm evaluation
python run.py --datasets needlebench_v2_128k --models vllm_qwen2_5_7b_instruct_128k --summarizer needlebench/needlebench_v2_128k_summarizer --slurm -p partition_name -q reserved --max-num-workers 16
```

##### Evaluating Specific Subsets

If you only want to test the original Needle In A Haystack task (e.g., single-needle 128k), adjust the dataset parameter:

```bash
python run.py --datasets needlebench_v2_single_128k --models vllm_qwen2_5_7b_instruct_128k --summarizer needlebench/needlebench_v2_128k_summarizer --slurm -p partition_name -q reserved --max-num-workers 16
```

To evaluate only the Chinese versions, specify the subset dataset after `/`:

```bash
python run.py --datasets needlebench_v2_single_128k/needlebench_zh_datasets --models vllm_qwen2_5_7b_instruct_128k --summarizer needlebench/needlebench_v2_128k_summarizer --slurm -p partition_name -q reserved --max-num-workers 16
```

Ensure `VLLM` is installed beforehand:

```bash
# Install vLLM with CUDA 12.4.
# For other CUDA versions, please refer to the [official documentation](https://docs.vllm.ai/en/latest/getting_started/installation/gpu.html)
pip install vllm
```

#### Evaluating Other `Huggingface` Models

For other models, it is recommended to write your own config file (such as `examples/eval_needlebench_v2.py`) and adjust `max_seq_len` and `max_out_len` so that the model can process the full context.

You can then run the evaluation with:

```bash
python run.py examples/eval_needlebench_v2.py --slurm -p partition_name -q reserved --max-num-workers 16
```

There is no need to manually specify `--datasets`, `--models`, or `--summarizer` again.
### Visualization

The latest version of NeedleBench has visualization built into the summarizer. You can find the corresponding visualizations in the `plots` directory under the output folder, with no need for additional scripts.

### Citation

If you use NeedleBench, please cite us:

```bibtex
@misc{li2025needlebenchllmsretrievalreasoning,
    title={NeedleBench: Can LLMs Do Retrieval and Reasoning in Information-Dense Context?},
    author={Mo Li and Songyang Zhang and Taolin Zhang and Haodong Duan and Yunxin Liu and Kai Chen},
    year={2025},
    eprint={2407.11963},
    archivePrefix={arXiv},
    primaryClass={cs.CL},
    url={https://arxiv.org/abs/2407.11963},
}

@misc{2023opencompass,
    title={OpenCompass: A Universal Evaluation Platform for Foundation Models},
    author={OpenCompass Contributors},
    howpublished={\url{https://github.com/open-compass/opencompass}},
    year={2023}
}

@misc{LLMTest_NeedleInAHaystack,
    title={LLMTest Needle In A Haystack - Pressure Testing LLMs},
    author={gkamradt},
    year={2023},
    howpublished={\url{https://github.com/gkamradt/LLMTest_NeedleInAHaystack}}
}

@misc{wei2023skywork,
    title={Skywork: A More Open Bilingual Foundation Model},
    author={Tianwen Wei and Liang Zhao and Lichang Zhang and Bo Zhu and Lijie Wang and Haihua Yang and Biye Li and Cheng Cheng and Weiwei L\"u and Rui Hu and Chenxia Li and Liu Yang and Xilin Luo and Xuejie Wu and Lunan Liu and Wenjun Cheng and Peng Cheng and Jianhao Zhang and Xiaoyu Zhang and Lei Lin and Xiaokun Wang and Yutuan Ma and Chuanhai Dong and Yanqi Sun and Yifu Chen and Yongyi Peng and Xiaojuan Liang and Shuicheng Yan and Han Fang and Yahui Zhou},
    year={2023},
    eprint={2310.19341},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```
docs/en/advanced_guides/new_dataset.md

# Add a dataset

Although OpenCompass already includes most commonly used datasets, you need to follow the steps below to support a new dataset:

1. Add a dataset script `mydataset.py` to the `opencompass/datasets` folder. This script should include:

   - The dataset and its loading method. Define a `MyDataset` class that implements the data loading method `load` as a static method. This method should return data of type `datasets.Dataset`. We use the Hugging Face dataset as the unified interface for datasets to avoid introducing additional logic. Here's the skeleton (see the concrete sketch at the end of this step):

     ```python
     import datasets
     from .base import BaseDataset

     class MyDataset(BaseDataset):

         @staticmethod
         def load(**kwargs) -> datasets.Dataset:
             pass
     ```

   - (Optional) If the existing evaluators in OpenCompass do not meet your needs, define a `MyDatasetEvaluator` class that implements the scoring method `score`. This method takes `predictions` and `references` as input and, since a dataset may have multiple metrics, returns a dictionary mapping each metric to its score. Here's an example:

     ```python
     from typing import List

     from opencompass.openicl.icl_evaluator import BaseEvaluator

     class MyDatasetEvaluator(BaseEvaluator):

         def score(self, predictions: List, references: List) -> dict:
             pass
     ```

   - (Optional) If the existing postprocessors in OpenCompass do not meet your needs, define a `mydataset_postprocess` method. This method takes an input string and returns the corresponding postprocessed result string. Here's an example:

     ```python
     def mydataset_postprocess(text: str) -> str:
         pass
     ```
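
   As a concrete illustration of the `load` skeleton above, here is a minimal sketch that reads a JSONL file into a Hugging Face `datasets.Dataset`. The file path and field names are hypothetical; adapt them to your dataset.

   ```python
   import json

   import datasets

   from .base import BaseDataset


   class MyDataset(BaseDataset):

       @staticmethod
       def load(path: str, **kwargs) -> datasets.Dataset:
           # Read one JSON object per line, e.g. {"problem": ..., "answer": ...}
           rows = []
           with open(path, 'r', encoding='utf-8') as f:
               for line in f:
                   rows.append(json.loads(line))
           return datasets.Dataset.from_list(rows)
   ```
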
2. After defining the dataset loading, data postprocessing, and evaluator methods, add the following configurations to the configuration file:

   ```python
   from opencompass.datasets import MyDataset, MyDatasetEvaluator, mydataset_postprocess

   mydataset_eval_cfg = dict(
       evaluator=dict(type=MyDatasetEvaluator),
       pred_postprocessor=dict(type=mydataset_postprocess))

   mydataset_datasets = [
       dict(
           type=MyDataset,
           ...,
           reader_cfg=...,
           infer_cfg=...,
           eval_cfg=mydataset_eval_cfg)
   ]
   ```

   - To make your dataset easy for other users to access, you need to specify the channels for downloading the dataset in the configuration file. Specifically, first fill in a dataset name of your choosing in the `path` field of the `mydataset_datasets` configuration; this name will be mapped to the actual download path in the `opencompass/utils/datasets_info.py` file. Here's an example:

     ```python
     mmlu_datasets = [
         dict(
             ...,
             path='opencompass/mmlu',
             ...,
         )
     ]
     ```

   - Next, create a dictionary key in `opencompass/utils/datasets_info.py` with the same name as the one you provided above. If you have already hosted the dataset on HuggingFace or ModelScope, add a dictionary key to the `DATASETS_MAPPING` dictionary and fill in the HuggingFace or ModelScope dataset address in the `hf_id` or `ms_id` key, respectively. You can also specify a default local address. Here's an example:

     ```python
     "opencompass/mmlu": {
         "ms_id": "opencompass/mmlu",
         "hf_id": "opencompass/mmlu",
         "local": "./data/mmlu/",
     }
     ```

   - If you wish the provided dataset to be directly accessible from the OpenCompass OSS repository when used by others, submit the dataset files in the Pull Request phase. We will then transfer the dataset to the OSS on your behalf and create a new dictionary key in `DATASET_URL`.

   - To keep the data source selectable, extend the `load` method in the dataset script `mydataset.py` so that it switches among download sources based on the environment variable `DATASET_SOURCE`. Note that if the environment variable `DATASET_SOURCE` is not set, the dataset defaults to being downloaded from the OSS repository. Here's an example from `opencompass/datasets/cmmlu.py`:

     ```python
     def load(path: str, name: str, **kwargs):
         ...
         if environ.get('DATASET_SOURCE') == 'ModelScope':
             ...
         else:
             ...
         return dataset
     ```
3. After completing the dataset script and config file, register the information of your new dataset in the file `dataset-index.yml` in the main directory so that it can be added to the dataset statistics list on the OpenCompass website.

   - The keys to fill in include `name`: the name of your dataset; `category`: the category of your dataset; `paper`: the URL of the paper or project; and `configpath`: the path to the dataset config file. Here's an example:

     ```yaml
     - mydataset:
         name: MyDataset
         category: Understanding
         paper: https://arxiv.org/pdf/xxxxxxx
         configpath: opencompass/configs/datasets/MyDataset
     ```

Detailed dataset configuration files and other required configuration files can be found in the [Configuration Files](../user_guides/config.md) tutorial. For guides on launching tasks, please refer to the [Quick Start](../get_started/quick_start.md) tutorial.
docs/en/advanced_guides/new_model.md

# Add a Model

Currently, we support HF models, some model APIs, and some third-party models.

## Adding API Models

To add a new API-based model, create a new file named `mymodel_api.py` under the `opencompass/models` directory. In this file, inherit from `BaseAPIModel` and implement the `generate` method for inference and the `get_token_len` method to calculate the length of tokens. Once you have defined the model, you can modify the corresponding configuration file.

```python
from typing import Dict, List, Optional

from ..base_api import BaseAPIModel

class MyModelAPI(BaseAPIModel):

    is_api: bool = True

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 query_per_second: int = 1,
                 retry: int = 2,
                 meta_template: Optional[Dict] = None,
                 **kwargs):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         query_per_second=query_per_second,
                         retry=retry)
        ...

    def generate(
        self,
        inputs,
        max_out_len: int = 512,
        temperature: float = 0.7,
    ) -> List[str]:
        """Generate results given a list of inputs."""
        pass

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized string."""
        pass
```
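
Once the class is defined and exported from `opencompass/models`, it can be referenced from a configuration file in the usual `models` list style used throughout OpenCompass configs. The snippet below is a hypothetical illustration; the abbreviation and path values are placeholders:

```python
from opencompass.models import MyModelAPI  # assumes MyModelAPI is exported there

models = [
    dict(
        type=MyModelAPI,
        abbr='my-model-api',  # display name used in result tables
        path='my-model',      # identifier your class interprets
        max_seq_len=2048,
        query_per_second=1,
        retry=2,
    )
]
```
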
## Adding Third-Party Models

To add a new third-party model, create a new file named `mymodel.py` under the `opencompass/models` directory. In this file, inherit from `BaseModel` and implement the `generate` method for generative inference, the `get_ppl` method for discriminative inference, and the `get_token_len` method to calculate the length of tokens. Once you have defined the model, you can modify the corresponding configuration file.

```python
from typing import Dict, List, Optional

from ..base import BaseModel

class MyModel(BaseModel):

    def __init__(self,
                 pkg_root: str,
                 ckpt_path: str,
                 tokenizer_only: bool = False,
                 meta_template: Optional[Dict] = None,
                 **kwargs):
        ...

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized strings."""
        pass

    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        """Generate results given a list of inputs."""
        pass

    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs."""
        pass
```
docs/en/advanced_guides/objective_judgelm_evaluation.md

# Using Large Models as JudgeLLM for Objective Evaluation

## Introduction

Traditional objective evaluations often rely on comparison against standard answers. In practical applications, however, a model's predictions may vary because of differences in instruction-following ability or imperfections in post-processing functions, which can cause answers to be extracted and compared incorrectly and thus lead to inaccurate evaluation outcomes. To address this, we adopt a process similar to subjective evaluations and introduce a JudgeLLM after prediction to assess the consistency between model responses and the standard answers ([LLM-as-a-Judge](https://arxiv.org/abs/2306.05685)).

Currently, all models supported by the OpenCompass repository can be used directly as JudgeLLM. In addition, we are planning to support dedicated JudgeLLMs.

## Currently Supported Objective Evaluation Datasets

1. MATH ([https://github.com/hendrycks/math](https://github.com/hendrycks/math))

## Custom JudgeLLM Objective Dataset Evaluation

OpenCompass currently supports most datasets that use `GenInferencer` for inference. The specific process for custom JudgeLLM objective evaluation includes:

1. Building evaluation configurations, using API models or open-source models for inference of question answers.
2. Employing a selected evaluation model (JudgeLLM) to assess the outputs of the model.

### Step One: Building Evaluation Configurations, Using MATH as an Example

Below is the config for evaluating the MATH dataset with JudgeLLM, where the evaluated model is *Llama3-8b-instruct* and the JudgeLLM is *Llama3-70b-instruct*. For more detailed config settings, please refer to `examples/eval_math_llm_judge.py`. The following is a briefly annotated version to help users understand the meaning of the configuration file.
```python
|
| 25 |
+
# Most of the code in this file is copied from https://github.com/openai/simple-evals/blob/main/math_eval.py
|
| 26 |
+
from mmengine.config import read_base
|
| 27 |
+
with read_base():
|
| 28 |
+
from .models.hf_llama.hf_llama3_8b_instruct import models as hf_llama3_8b_instruct_model # noqa: F401, F403
|
| 29 |
+
from .models.hf_llama.hf_llama3_70b_instruct import models as hf_llama3_70b_instruct_model # noqa: F401, F403
|
| 30 |
+
from .datasets.math.math_llm_judge import math_datasets # noqa: F401, F403
|
| 31 |
+
from opencompass.datasets import math_judement_preprocess
|
| 32 |
+
from opencompass.partitioners import NaivePartitioner, SizePartitioner
|
| 33 |
+
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
|
| 34 |
+
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
|
| 35 |
+
from opencompass.runners import LocalRunner
|
| 36 |
+
from opencompass.runners import SlurmSequentialRunner
|
| 37 |
+
from opencompass.tasks import OpenICLInferTask
|
| 38 |
+
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
|
| 39 |
+
from opencompass.summarizers import AllObjSummarizer
|
| 40 |
+
from opencompass.openicl.icl_evaluator import LMEvaluator
|
| 41 |
+
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ------------- Prompt Settings ----------------------------------------
|
| 45 |
+
# Evaluation template, please modify the template as needed, JudgeLLM typically uses [Yes] or [No] as the response. For the MATH dataset, the evaluation template is as follows:
|
| 46 |
+
eng_obj_prompt = """
|
| 47 |
+
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
|
| 48 |
+
|
| 49 |
+
Examples:
|
| 50 |
+
|
| 51 |
+
Expression 1: $2x+3$
|
| 52 |
+
Expression 2: $3+2x$
|
| 53 |
+
|
| 54 |
+
[Yes]
|
| 55 |
+
|
| 56 |
+
Expression 1: 3/2
|
| 57 |
+
Expression 2: 1.5
|
| 58 |
+
|
| 59 |
+
[Yes]
|
| 60 |
+
|
| 61 |
+
Expression 1: $x^2+2x+1$
|
| 62 |
+
Expression 2: $y^2+2y+1$
|
| 63 |
+
|
| 64 |
+
[No]
|
| 65 |
+
|
| 66 |
+
Expression 1: $x^2+2x+1$
|
| 67 |
+
Expression 2: $(x+1)^2$
|
| 68 |
+
|
| 69 |
+
[Yes]
|
| 70 |
+
|
| 71 |
+
Expression 1: 3245/5
|
| 72 |
+
Expression 2: 649
|
| 73 |
+
|
| 74 |
+
[No]
|
| 75 |
+
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
|
| 76 |
+
|
| 77 |
+
Expression 1: 2/(-3)
|
| 78 |
+
Expression 2: -2/3
|
| 79 |
+
|
| 80 |
+
[Yes]
|
| 81 |
+
(trivial simplifications are allowed)
|
| 82 |
+
|
| 83 |
+
Expression 1: 72 degrees
|
| 84 |
+
Expression 2: 72
|
| 85 |
+
|
| 86 |
+
[Yes]
|
| 87 |
+
(give benefit of the doubt to units)
|
| 88 |
+
|
| 89 |
+
Expression 1: 64
|
| 90 |
+
Expression 2: 64 square feet
|
| 91 |
+
|
| 92 |
+
[Yes]
|
| 93 |
+
(give benefit of the doubt to units)
|
| 94 |
+
|
| 95 |
+
Expression 1: 64
|
| 96 |
+
Expression 2:
|
| 97 |
+
|
| 98 |
+
[No]
|
| 99 |
+
(only mark as equivalent if both expressions are nonempty)
|
| 100 |
+
|
| 101 |
+
---
|
| 102 |
+
|
| 103 |
+
YOUR TASK
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
Respond with only "[Yes]" or "[No]" (without quotes). Do not include a rationale.
|
| 107 |
+
Expression 1: {obj_gold}
|
| 108 |
+
Expression 2: {prediction}
|
| 109 |
+
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
# ------------- Inference Phase ----------------------------------------
|
| 113 |
+
# Models to be evaluated
|
| 114 |
+
models = [*hf_llama3_8b_instruct_model]
|
| 115 |
+
# Evaluation models
|
| 116 |
+
judge_models = hf_llama3_70b_instruct_model
|
| 117 |
+
|
| 118 |
+
eng_datasets = [*math_datasets]
|
| 119 |
+
chn_datasets = []
|
| 120 |
+
datasets = eng_datasets + chn_datasets
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
for d in eng_datasets:
|
| 124 |
+
d['eval_cfg']= dict(
|
| 125 |
+
evaluator=dict(
|
| 126 |
+
type=LMEvaluator,
|
| 127 |
+
# If you need to preprocess model predictions before judging,
|
| 128 |
+
# you can specify a pred_postprocessor function here
|
| 129 |
+
pred_postprocessor=dict(type=math_judement_preprocess),
|
| 130 |
+
prompt_template=dict(
|
| 131 |
+
type=PromptTemplate,
|
| 132 |
+
template=dict(round=[
|
| 133 |
+
dict(
|
| 134 |
+
role='HUMAN',
|
| 135 |
+
prompt = eng_obj_prompt
|
| 136 |
+
),
|
| 137 |
+
]),
|
| 138 |
+
),
|
| 139 |
+
),
|
| 140 |
+
pred_role="BOT",
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
infer = dict(
|
| 144 |
+
partitioner=dict(type=SizePartitioner, max_task_size=40000),
|
| 145 |
+
runner=dict(
|
| 146 |
+
type=LocalRunner,
|
| 147 |
+
max_num_workers=256,
|
| 148 |
+
task=dict(type=OpenICLInferTask)),
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# ------------- Evaluation Configuration --------------------------------
|
| 152 |
+
eval = dict(
|
| 153 |
+
partitioner=dict(
|
| 154 |
+
type=SubjectiveSizePartitioner, max_task_size=80000, mode='singlescore', models=models, judge_models=judge_models,
|
| 155 |
+
),
|
| 156 |
+
runner=dict(type=LocalRunner,
|
| 157 |
+
max_num_workers=16, task=dict(type=SubjectiveEvalTask)),
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
summarizer = dict(
|
| 161 |
+
type=AllObjSummarizer
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
# Output folder
|
| 165 |
+
work_dir = 'outputs/obj_all/'
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
### Step Two: Launch Evaluation and Output Results
|
| 169 |
+
|
| 170 |
+
```shell
|
| 171 |
+
python run.py eval_math_llm_judge.py
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
This will initiate two rounds of evaluation. The first round involves model inference to obtain predicted answers to questions, and the second round involves JudgeLLM evaluating the consistency between the predicted answers and the standard answers, and scoring them.
|
| 175 |
+
|
| 176 |
+
- The results of model predictions will be saved in `output/.../timestamp/predictions/xxmodel/xxx.json`
|
| 177 |
+
- The JudgeLLM's evaluation responses will be saved in `output/.../timestamp/results/xxmodel/xxx.json`
|
| 178 |
+
- The evaluation report will be output to `output/.../timestamp/summary/timestamp/xxx.csv`
|
| 179 |
+
|
| 180 |
+
## Results
|
| 181 |
+
|
| 182 |
+
Using the Llama3-8b-instruct as the evaluation model and the Llama3-70b-instruct as the evaluator, the MATH dataset was assessed with the following results:
|
| 183 |
+
|
| 184 |
+
| Model | JudgeLLM Evaluation | Naive Evaluation |
|
| 185 |
+
| ------------------- | ------------------- | ---------------- |
|
| 186 |
+
| llama-3-8b-instruct | 27.7 | 27.8 |
|
docs/en/advanced_guides/persistence.md
ADDED
@@ -0,0 +1,65 @@
# Evaluation Results Persistence

## Introduction

Normally, the evaluation results of OpenCompass are saved to your work directory. But in some cases, there may be a need for sharing data among users or quickly browsing existing public evaluation results. Therefore, we provide an interface that can quickly transfer evaluation results to an external public data station, and on this basis, provide functions such as uploading, overwriting, and reading.

## Quick Start

### Uploading

By adding an argument to the evaluation command, or by adding a configuration entry in the Eval script, the evaluation results can be stored in the path you specify. Here are the examples:

(Approach 1) Add an option to the command and specify your public path address.

```bash
opencompass ... -sp '/your_path'
```

(Approach 2) Add configuration in the Eval script.

```python
station_path = '/your_path'
```

### Overwriting

Before uploading, the above storage method first checks whether the same task result already exists in the data station, based on the `abbr` attribute in the model and dataset configuration. If a result already exists, the upload is cancelled. If you need to update these results, please add the `--station-overwrite` option to the command. Here is an example:

```bash
opencompass ... -sp '/your_path' --station-overwrite
```

### Reading

You can directly read existing results from the data station to avoid duplicate evaluation tasks. The read results directly participate in the 'summarize' step. When using this configuration, only tasks whose results are not yet stored in the data station will be initiated. Here is an example:

```bash
opencompass ... -sp '/your_path' --read-from-station
```

### Command Combination

1. Only upload the results under your latest working directory to the data station, without re-running tasks whose results are missing:

```bash
opencompass ... -sp '/your_path' -r latest -m viz
```

## Storage Format of the Data Station

In the data station, the evaluation results are stored as `json` files, one for each `model-dataset` pair. The specific directory form is `/your_path/dataset_name/model_name.json`. Each `json` file stores a dictionary corresponding to the results, including `predictions`, `results`, and `cfg`. Here is an example:

```python
Result = {
    'predictions': List[Dict],
    'results': Dict,
    'cfg': Dict = {
        'models': Dict,
        'datasets': Dict,
        (Only subjective datasets)'judge_models': Dict
    }
}
```

Among these three keys, `predictions` records the model's predictions on each item of data in the dataset, `results` records the total score of the model on the dataset, and `cfg` records the detailed configurations of the model and the dataset in this evaluation task. (A small loading sketch follows.)
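
As a quick illustration of this layout, below is a minimal sketch of loading one stored result file and inspecting its contents. The dataset and model names in the path are placeholders, not values from the original document.

```python
import json

# Hypothetical path following the /your_path/dataset_name/model_name.json layout
path = '/your_path/gsm8k/llama-3-8b-instruct.json'

with open(path, 'r', encoding='utf-8') as f:
    result = json.load(f)

print(result['results'])           # overall scores on the dataset
print(len(result['predictions']))  # number of per-item predictions
print(result['cfg']['models'])     # model config used in this run
```
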
docs/en/advanced_guides/prompt_attack.md
ADDED
@@ -0,0 +1,108 @@
# Prompt Attack

We support prompt attack following the idea of [PromptBench](https://github.com/microsoft/promptbench). The main purpose here is to evaluate the robustness of the prompt instruction: when the prompt that instructs the task is attacked/modified, how well does the task perform compared with the original.

## Set up environment

Some components are necessary for the prompt attack experiment, therefore we need to set up the environment.

```shell
git clone https://github.com/microsoft/promptbench.git
pip install textattack==0.3.8
export PYTHONPATH=$PYTHONPATH:promptbench/
```

## How to attack

### Add a dataset config

We will use the GLUE-wnli dataset as an example; for most configuration settings you can refer to [config.md](../user_guides/config.md) for help.

First, we need to support the basic dataset config. You can find the existing config files in `configs`, or add your own config according to [new-dataset](./new_dataset.md).

Take the following `infer_cfg` as an example: we need to define the prompt template. `adv_prompt` is the basic prompt placeholder to be attacked in the experiment. `sentence1` and `sentence2` are the input columns of this dataset. The attack will only modify the `adv_prompt` field here.

Then, we should use `AttackInferencer` with `original_prompt_list` and `adv_key` to tell the inferencer where to attack and what text to attack.

For more details, refer to the `configs/datasets/promptbench/promptbench_wnli_gen_50662f.py` config file.

```python
original_prompt_list = [
    'Are the following two sentences entailment or not_entailment? Answer me with "A. entailment" or "B. not_entailment", just one word. ',
    "Does the relationship between the given sentences represent entailment or not_entailment? Respond with 'A. entailment' or 'B. not_entailment'.",
    ...,
]

wnli_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(round=[
            dict(
                role="HUMAN",
                prompt="""{adv_prompt}
Sentence 1: {sentence1}
Sentence 2: {sentence2}
Answer:"""),
        ]),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(
        type=AttackInferencer,
        original_prompt_list=original_prompt_list,
        adv_key='adv_prompt'))
```

### Add an eval config

We should use `OpenICLAttackTask` here for the attack task. Also, `NaivePartitioner` should be used, because the attack experiment runs over the whole dataset repeatedly, nearly hundreds of times, to search for the best attack, and we do not want to split the dataset.

```note
Please choose a small dataset (fewer than 1000 examples) for the attack; due to the aforementioned repeated search, the time cost is otherwise enormous.
```

There are several other options in the `attack` config:

- `attack`: attack type; available options include `textfooler`, `textbugger`, `deepwordbug`, `bertattack`, `checklist`, `stresstest`;
- `query_budget`: upper bound on the number of queries, i.e., the total number of runs over the dataset;
- `prompt_topk`: number of top-k prompts to be attacked. In most cases, the original prompt list contains more than 10 prompts, and running the whole set is time-consuming.

```python
# Please run the whole dataset at a time, aka use `NaivePartitioner` only
# Please use `OpenICLAttackTask` if you want to perform an attack experiment
infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=SlurmRunner,
        max_num_workers=8,
        task=dict(type=OpenICLAttackTask),
        retry=0),
)

attack = dict(
    attack='textfooler',
    query_budget=100,
    prompt_topk=2,
)
```

### Run the experiment

Please use `--mode infer` when running the attack experiment, and set the `PYTHONPATH` env variable.

```shell
python run.py examples/eval_attack.py --mode infer
```

All the results will be saved in the `attack` folder.
The output includes the accuracy of the original prompts and the attacked top-k prompts with their accuracy drops, for instance:

```
Prompt: Assess the connection between the following sentences and classify it as 'A. entailment' or 'B. not_entailment'., acc: 59.15%
Prompt: Does the relationship between the given sentences represent entailment or not_entailment? Respond with 'A. entailment' or 'B. not_entailment'., acc: 57.75%
Prompt: Analyze the two provided sentences and decide if their relationship is 'A. entailment' or 'B. not_entailment'., acc: 56.34%
Prompt: Identify whether the given pair of sentences demonstrates entailment or not_entailment. Answer with 'A. entailment' or 'B. not_entailment'., acc: 54.93%
...
Original prompt: Assess the connection between the following sentences and classify it as 'A. entailment' or 'B. not_entailment'.
Attacked prompt: b"Assess the attach between the following sentences and sorted it as 'A. entailment' or 'B. not_entailment'."
Original acc: 59.15%, attacked acc: 40.85%, dropped acc: 18.31%
```
docs/en/advanced_guides/subjective_evaluation.md
ADDED
@@ -0,0 +1,171 @@
# Subjective Evaluation Guidance

## Introduction

Subjective evaluation aims to assess the model's performance in tasks that align with human preferences. The key criterion for this evaluation is human preference, but it comes with a high cost of annotation.

To explore the model's subjective capabilities, we employ a JudgeLLM as a substitute for human assessors ([LLM-as-a-Judge](https://arxiv.org/abs/2306.05685)).

Two popular evaluation methods are:

- Compare Mode: comparing model responses pairwise to calculate their win rate
- Score Mode: scoring each single model response ([Chatbot Arena](https://chat.lmsys.org/))

We support the use of GPT-4 (or another JudgeLLM) for the subjective evaluation of models based on the above methods.

## Currently Supported Subjective Evaluation Datasets

1. AlignBench Chinese Scoring Dataset (https://github.com/THUDM/AlignBench)
2. MTBench English Scoring Dataset, two-turn dialogue (https://github.com/lm-sys/FastChat)
3. MTBench101 English Scoring Dataset, multi-turn dialogue (https://github.com/mtbench101/mt-bench-101)
4. AlpacaEvalv2 English Compare Dataset (https://github.com/tatsu-lab/alpaca_eval)
5. ArenaHard English Compare Dataset, mainly focused on coding (https://github.com/lm-sys/arena-hard/tree/main)
6. Fofo English Scoring Dataset (https://github.com/SalesforceAIResearch/FoFo/)
7. Wildbench English Score and Compare Dataset (https://github.com/allenai/WildBench)

## Initiating Subjective Evaluation

Similar to existing objective evaluation methods, you can configure related settings in `examples/eval_subjective.py`.

### Basic Parameters: Specifying models, datasets, and judgemodels

Similar to objective evaluation, import the models and datasets that need to be evaluated, for example:

```
with read_base():
    from .datasets.subjective.alignbench.alignbench_judgeby_critiquellm import alignbench_datasets
    from .datasets.subjective.alpaca_eval.alpacav2_judgeby_gpt4 import subjective_datasets as alpacav2
    from .models.qwen.hf_qwen_7b import models
```

It is worth noting that, since the model setup parameters for subjective evaluation often differ from those for objective evaluation, inference usually needs to be set up with `do_sample` instead of `greedy`. You can modify the relevant parameters in the configuration file as needed, for example:

```
models = [
    dict(
        type=HuggingFaceChatGLM3,
        abbr='chatglm3-6b-hf2',
        path='THUDM/chatglm3-6b',
        tokenizer_path='THUDM/chatglm3-6b',
        model_kwargs=dict(
            device_map='auto',
            trust_remote_code=True,
        ),
        tokenizer_kwargs=dict(
            padding_side='left',
            truncation_side='left',
            trust_remote_code=True,
        ),
        generation_kwargs=dict(
            do_sample=True,
        ),
        meta_template=api_meta_template,
        max_out_len=2048,
        max_seq_len=4096,
        batch_size=8,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
```

The judgemodel is usually set to a powerful model like GPT-4. You can directly enter your API key according to the configuration in the config file, or use a custom model as the judgemodel.

### Specifying Other Parameters

In addition to the basic parameters, you can also modify the `infer` and `eval` fields in the config to set a more appropriate partitioning method. Three partitioning methods are currently supported: NaivePartitioner, SizePartitioner, and NumberWorkPartitioner. You can also specify your own workdir to save related files.

## Subjective Evaluation with Custom Dataset

The specific process includes:

1. Data preparation
2. Model response generation
3. Evaluating the responses with a JudgeLLM
4. Generating the JudgeLLM's responses and calculating the metric

### Step-1: Data Preparation

This step requires preparing the dataset file and implementing your own dataset class under `opencompass/datasets/subjective/`, returning the read data in the format of a `list of dict` (a minimal sketch of such a class is shown at the end of this step).

Actually, you can prepare the data in any format you like (csv, json, jsonl, etc.). However, to make it easier to get started, it is recommended to construct the data according to the format of the existing subjective datasets, or according to the following json format.
We provide mini test-sets for **Compare Mode** and **Score Mode** as below:

```python
###COREV2
[
    {
        "question": "如果我在空中垂直抛球,球最初向哪个方向行进?",
        "capability": "知识-社会常识",
        "others": {
            "question": "如果我在空中垂直抛球,球最初向哪个方向行进?",
            "evaluating_guidance": "",
            "reference_answer": "上"
        }
    },...]

###CreationV0.1
[
    {
        "question": "请你扮演一个邮件管家,我让你给谁发送什么主题的邮件,你就帮我扩充好邮件正文,并打印在聊天框里。你需要根据我提供的邮件收件人以及邮件主题,来斟酌用词,并使用合适的敬语。现在请给导师发送邮件,询问他是否可以下周三下午15:00进行科研同步会,大约200字。",
        "capability": "邮件通知",
        "others": ""
    },
```

The json must include the following fields:

- 'question': Question description
- 'capability': The capability dimension of the question.
- 'others': Other needed information.

If you want to modify the prompt for each single question, you can fill some other information into 'others' and construct the prompt from it.
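Below is a minimal, hypothetical sketch of such a dataset class, assuming a json file in the format above. The class name `MySubjectiveDataset` is illustrative and the exact base class and registration mechanism follow the existing subjective datasets in the repository, which you should consult; this is not the exact OpenCompass API.

```python
# A hypothetical sketch, not the exact OpenCompass API: it assumes a json file
# whose entries carry the 'question', 'capability' and 'others' fields above.
import json

from datasets import Dataset


class MySubjectiveDataset:

    @staticmethod
    def load(path: str) -> Dataset:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        raw = []
        for item in data:
            raw.append({
                'question': item['question'],
                'capability': item['capability'],
                'others': item['others'],
                # per-question judge prompts can be assembled from 'others' here
            })
        return Dataset.from_list(raw)
```
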
### Step-2: Evaluation Configuration (Compare Mode)

Taking AlignBench as an example, see `configs/datasets/subjective/alignbench/alignbench_judgeby_critiquellm.py`:

1. First, you need to set `subjective_reader_cfg` to receive the relevant fields returned from the custom dataset class and specify the output fields when saving files (see the sketch after this list).
2. Then, you need to specify the root path `data_path` of the dataset and the dataset filename `subjective_all_sets`. If there are multiple sub-files, you can add them to this list.
3. Specify `subjective_infer_cfg` and `subjective_eval_cfg` to configure the corresponding inference and evaluation prompts.
4. Specify additional information such as `mode` at the corresponding location. Note that the fields required for different subjective datasets may vary.
5. Define post-processing and score statistics, for example the postprocessing function `alignbench_postprocess` located under `opencompass/opencompass/datasets/subjective/alignbench`.
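
The following is a hypothetical excerpt illustrating steps 1 and 2 in the style of the AlignBench config; the column names, path, and filename are placeholders that depend on what your custom dataset class actually returns.

```python
# Illustrative reader/path settings; adjust the columns to your dataset class.
subjective_reader_cfg = dict(
    input_columns=['question', 'capability', 'others'],
    output_column='judge',
)

data_path = 'data/subjective/alignbench'   # root path of the dataset
subjective_all_sets = ['alignment_bench']  # one entry per sub-file
```
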
### Step-3: Launch the Evaluation

```shell
python run.py config/eval_subjective_score.py -r
```

The `-r` parameter allows the reuse of model inference and GPT-4 evaluation results.

The responses of the JudgeLLM will be output to `output/.../results/timestamp/xxmodel/xxdataset/.json`.
The evaluation report will be output to `output/.../summary/timestamp/report.csv`.

## Multi-round Subjective Evaluation in OpenCompass

In OpenCompass, we also support subjective multi-turn dialogue evaluation. For instance, the evaluation of MT-Bench can be referred to in `configs/datasets/subjective/multiround`.

In the multi-turn dialogue evaluation, you need to organize the data into the following dialogue structure:

```
"dialogue": [
    {
        "role": "user",
        "content": "Imagine you are participating in a race with a group of people. If you have just overtaken the second person, what's your current position? Where is the person you just overtook?"
    },
    {
        "role": "assistant",
        "content": ""
    },
    {
        "role": "user",
        "content": "If the \"second person\" is changed to \"last person\" in the above question, what would the answer be?"
    },
    {
        "role": "assistant",
        "content": ""
    }
],
```

It is important to note that, because different question types in MTBench use different temperature settings, we need to divide the original data files into three subsets according to temperature and run inference on them separately, setting a different temperature for each subset. For specific settings, please refer to `configs/datasets/subjective/multiround/mtbench_single_judge_diff_temp.py`.
docs/en/conf.py
ADDED
@@ -0,0 +1,234 @@
# flake8: noqa
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import subprocess
import sys

import pytorch_sphinx_theme
from sphinx.builders.html import StandaloneHTMLBuilder

sys.path.insert(0, os.path.abspath('../../'))

# -- Project information -----------------------------------------------------

project = 'OpenCompass'
copyright = '2023, OpenCompass'
author = 'OpenCompass Authors'

# The full version, including alpha/beta/rc tags
version_file = '../../opencompass/__init__.py'


def get_version():
    with open(version_file, 'r') as f:
        exec(compile(f.read(), version_file, 'exec'))
    return locals()['__version__']


release = get_version()

# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.autosummary',
    'sphinx.ext.intersphinx',
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    'myst_parser',
    'sphinx_copybutton',
    'sphinx_tabs.tabs',
    'notfound.extension',
    'sphinxcontrib.jquery',
    'sphinx_design',
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = {
    '.rst': 'restructuredtext',
    '.md': 'markdown',
}

language = 'en'

# The master toctree document.
root_doc = 'index'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages.  See the documentation for
# a list of builtin themes.
#
html_theme = 'pytorch_sphinx_theme'
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]

# Theme options are theme-specific and customize the look and feel of a theme
# further.  For a list of options available for each theme, see the
# documentation.
# yapf: disable
html_theme_options = {
    'menu': [
        {
            'name': 'GitHub',
            'url': 'https://github.com/open-compass/opencompass'
        },
    ],
    # Specify the language of shared menu
    'menu_lang': 'en',
    # Disable the default edit on GitHub
    'default_edit_on_github': False,
}
# yapf: enable

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_css_files = [
    'https://cdn.datatables.net/v/bs4/dt-1.12.1/datatables.min.css',
    'css/readthedocs.css'
]
html_js_files = [
    'https://cdn.datatables.net/v/bs4/dt-1.12.1/datatables.min.js',
    'js/custom.js'
]

html_context = {
    'github_version': 'main',
}

# -- Options for HTMLHelp output ---------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = 'opencompassdoc'

# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',

    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',

    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (root_doc, 'opencompass.tex', 'OpenCompass Documentation', author,
     'manual'),
]

# -- Options for manual page output ------------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(root_doc, 'opencompass', 'OpenCompass Documentation', [author],
              1)]

# -- Options for Texinfo output ----------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (root_doc, 'opencompass', 'OpenCompass Documentation', author,
     'OpenCompass Authors', 'AGI evaluation toolbox and benchmark.',
     'Miscellaneous'),
]

# -- Options for Epub output -------------------------------------------------

# Bibliographic Dublin Core info.
epub_title = project

# The unique identifier of the text. This can be a ISBN number
# or the project homepage.
#
# epub_identifier = ''

# A unique identification for the text.
#
# epub_uid = ''

# A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html']

# set priority when building html
StandaloneHTMLBuilder.supported_image_types = [
    'image/svg+xml', 'image/gif', 'image/png', 'image/jpeg'
]

# -- Extension configuration -------------------------------------------------
# Ignore >>> when copying code
copybutton_prompt_text = r'>>> |\.\.\. '
copybutton_prompt_is_regexp = True

# Auto-generated header anchors
myst_heading_anchors = 3
# Enable "colon_fence" extension of myst.
myst_enable_extensions = ['colon_fence', 'dollarmath']

# Configuration for intersphinx
intersphinx_mapping = {
    'python': ('https://docs.python.org/3', None),
    'numpy': ('https://numpy.org/doc/stable', None),
    'torch': ('https://pytorch.org/docs/stable/', None),
    'mmengine': ('https://mmengine.readthedocs.io/en/latest/', None),
    'transformers':
    ('https://huggingface.co/docs/transformers/main/en/', None),
}
napoleon_custom_sections = [
    # Custom sections for data elements.
    ('Meta fields', 'params_style'),
    ('Data fields', 'params_style'),
]

# Disable docstring inheritance
autodoc_inherit_docstrings = False
# Mock some imports during generate API docs.
autodoc_mock_imports = ['rich', 'attr', 'einops']
# Disable displaying type annotations, these can be very verbose
autodoc_typehints = 'none'

# The not found page
notfound_template = '404.html'


def builder_inited_handler(app):
    subprocess.run(['./statis.py'])


def setup(app):
    app.connect('builder-inited', builder_inited_handler)
docs/en/docutils.conf
ADDED
@@ -0,0 +1,2 @@
[html writers]
table_style: colwidths-auto
docs/en/get_started/faq.md
ADDED
@@ -0,0 +1,128 @@
# FAQ

## General

### What are the differences and connections between `ppl` and `gen`?

`ppl` stands for perplexity, an index used to evaluate a model's language modeling capabilities. In the context of OpenCompass, it generally refers to a method of answering multiple-choice questions: given a context, the model needs to choose the most appropriate option from multiple choices. In this case, we concatenate the n options with the context to form n sequences, then calculate the model's perplexity for these n sequences. We consider the option corresponding to the sequence with the lowest perplexity as the model's reasoning result for this question. This evaluation method is simple and direct in post-processing, with high certainty. (An illustrative sketch of the `ppl` selection procedure follows this answer.)

`gen` is an abbreviation for generate. In the context of OpenCompass, it refers to using the model's continuation of a given context as the reasoning result for a question. Generally, the string obtained from the continuation requires a heavier post-processing process to extract a reliable answer and complete the evaluation.

In terms of usage, multiple-choice questions and some multiple-choice-like questions of the base model use `ppl`, while the base model's multiple-selection and non-multiple-choice questions use `gen`. All questions of the chat model use `gen`, as many commercial API models do not expose the `ppl` interface. However, there are exceptions, such as when we want the base model to output the problem-solving process (e.g., Let's think step by step), in which case we also use `gen`. The overall usage is as shown in the following table:

|            | ppl            | gen                  |
| ---------- | -------------- | -------------------- |
| Base Model | Only MCQ Tasks | Tasks Other Than MCQ |
| Chat Model | None           | All Tasks            |

Similar to `ppl`, conditional log probability (`clp`) calculates the probability of the next token given a context. It is also only applicable to multiple-choice questions, and the probability calculation is limited to the tokens corresponding to the option numbers; the option corresponding to the token with the highest probability is taken as the model's reasoning result. Compared to `ppl`, `clp` is more efficient, requiring only one inference, whereas `ppl` requires n inferences. The drawback is that `clp` is subject to the tokenizer: for example, the presence or absence of space symbols before and after an option can change the tokenizer's encoding result, leading to unreliable test results. Therefore, `clp` is rarely used in OpenCompass.
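To make the `ppl` procedure concrete, here is a small illustrative sketch (not OpenCompass code) that picks the option whose concatenated sequence has the lowest perplexity. The scoring function `avg_nll` is a hypothetical stand-in for a model call returning the average negative log-likelihood of a string.

```python
import math
from typing import Callable, List


def pick_option_by_ppl(context: str, options: List[str],
                       avg_nll: Callable[[str], float]) -> int:
    """Return the index of the option whose full sequence has the lowest
    perplexity. `avg_nll` stands in for a model scoring call."""
    # perplexity = exp(average negative log-likelihood) of each sequence
    ppls = [math.exp(avg_nll(context + opt)) for opt in options]
    return min(range(len(options)), key=lambda i: ppls[i])
```
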
### How does OpenCompass control the number of shots in few-shot evaluations?

In the dataset configuration file, there is a retriever field indicating how to recall samples from the dataset as context examples. The most commonly used is `FixKRetriever`, which means using a fixed set of k samples, hence k-shot. There is also `ZeroRetriever`, which means not using any samples, which in most cases implies 0-shot (see the config sketch below).

On the other hand, in-context samples can also be directly specified in the dataset template. In this case, `ZeroRetriever` is also used, but the evaluation is not 0-shot and needs to be determined based on the specific template. Refer to [prompt](../prompt/prompt_template.md) for more details.
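As an illustration, a 5-shot retriever setting in the style of OpenCompass dataset configs might look like the sketch below; the surrounding `infer_cfg` fields are omitted and the chosen example indices are placeholders.

```python
# Illustrative only: `fix_id_list` picks which items serve as in-context
# examples, so its length determines the number of shots (here, 5-shot).
from opencompass.openicl.icl_retriever import FixKRetriever

retriever = dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4])
```
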
### How does OpenCompass allocate GPUs?

OpenCompass processes evaluation requests using the unit termed "task". Each task is an independent combination of model(s) and dataset(s). The GPU resources needed for a task are determined entirely by the model being evaluated, specifically by the `num_gpus` parameter.

During evaluation, OpenCompass deploys multiple workers to execute tasks in parallel. These workers continuously try to secure GPU resources and run tasks until they succeed. As a result, OpenCompass always strives to leverage all available GPU resources to their maximum capacity.

For instance, if you're using OpenCompass on a local machine equipped with 8 GPUs, and each task demands 4 GPUs, then by default, OpenCompass will employ all 8 GPUs to concurrently run 2 tasks. However, if you adjust the `--max-num-workers` setting to 1, then only one task will be processed at a time, utilizing just 4 GPUs.

### Why doesn't the GPU behavior of HuggingFace models align with my expectations?

This is a complex issue that needs to be explained from both the supply and demand sides:

The supply side refers to how many tasks are being run. A task is a combination of a model and a dataset, and the task count primarily depends on how many models and datasets need to be tested. Additionally, since OpenCompass splits a larger task into multiple smaller tasks, the number of data entries per sub-task (`--max-partition-size`) also affects the number of tasks. (The `--max-partition-size` is proportional to the actual number of data entries, but the relationship is not 1:1.)

The demand side refers to how many workers are running. Since OpenCompass instantiates multiple models for inference simultaneously, we use `--hf-num-gpus` to specify how many GPUs each instance uses. Note that `--hf-num-gpus` is a parameter specific to HuggingFace models, and setting it for non-HuggingFace models will not have any effect. We also use `--max-num-workers` to indicate the maximum number of instances running at the same time. Lastly, due to issues like GPU memory and insufficient load, OpenCompass also supports running multiple instances on the same GPU, which is managed by the parameter `--max-num-workers-per-gpu`. Therefore, it can generally be assumed that we will use a total of `--hf-num-gpus` * `--max-num-workers` / `--max-num-workers-per-gpu` GPUs; for example, with `--hf-num-gpus 2`, `--max-num-workers 8`, and `--max-num-workers-per-gpu 2`, at most 2 * 8 / 2 = 8 GPUs are occupied.

In summary, when tasks run slowly or the GPU load is low, we first need to check if the supply is sufficient. If not, consider reducing `--max-partition-size` to split the tasks into finer parts. Next, we need to check if the demand is sufficient. If not, consider increasing `--max-num-workers` and `--max-num-workers-per-gpu`. Generally, **we set `--hf-num-gpus` to the minimum value that meets the demand and do not adjust it further.**

### How do I control the number of GPUs that OpenCompass occupies?

Currently, there isn't a direct method to specify the number of GPUs OpenCompass can utilize. However, the following are some indirect strategies:

**If evaluating locally:**
You can limit OpenCompass's GPU access by setting the `CUDA_VISIBLE_DEVICES` environment variable. For instance, using `CUDA_VISIBLE_DEVICES=0,1,2,3 python run.py ...` will only expose the first four GPUs to OpenCompass, ensuring it uses no more than these four GPUs simultaneously.

**If using Slurm or DLC:**
Although OpenCompass doesn't have direct access to the resource pool, you can adjust the `--max-num-workers` parameter to restrict the number of evaluation tasks being submitted simultaneously. This will indirectly manage the number of GPUs that OpenCompass employs. For instance, if each task requires 4 GPUs, and you wish to allocate a total of 8 GPUs, then you should set `--max-num-workers` to 2.

### `libGL.so.1` not found

opencv-python depends on some dynamic libraries that are not present in the environment. The simplest solution is to uninstall opencv-python and then install opencv-python-headless.

```bash
pip uninstall opencv-python
pip install opencv-python-headless
```

Alternatively, you can install the corresponding dependency libraries according to the error message:

```bash
sudo apt-get update
sudo apt-get install -y libgl1 libglib2.0-0
```

## Network

### My tasks failed with error: `('Connection aborted.', ConnectionResetError(104, 'Connection reset by peer'))` or `urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='cdn-lfs.huggingface.co', port=443)`

Because of HuggingFace's implementation, OpenCompass requires network access (especially a connection to HuggingFace) the first time it loads some datasets and models. Additionally, it connects to HuggingFace each time it is launched. For a successful run, you may:

- Work behind a proxy by specifying the environment variables `http_proxy` and `https_proxy`;
- Use the cache files from other machines. You may first run the experiment on a machine that has access to the Internet, and then copy the cached files to the offline one. The cached files are located at `~/.cache/huggingface/` by default ([doc](https://huggingface.co/docs/datasets/cache#cache-directory)). When the cached files are ready, you can start the evaluation in offline mode:
  ```bash
  HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 HF_EVALUATE_OFFLINE=1 python run.py ...
  ```
  With this, no network connection is needed for the evaluation. However, an error will still be raised if the files of any dataset or model are missing from the cache.
- Use a mirror like [hf-mirror](https://hf-mirror.com/)
  ```bash
  HF_ENDPOINT=https://hf-mirror.com python run.py ...
  ```

### My server cannot connect to the Internet, how can I use OpenCompass?

Use the cache files from other machines, as suggested in the answer to [Network-Q1](#my-tasks-failed-with-error-connection-aborted-connectionreseterror104-connection-reset-by-peer-or-urllib3exceptionsmaxretryerror-httpsconnectionpoolhostcdn-lfshuggingfaceco-port443).

### In the evaluation phase, I'm running into an error saying that `FileNotFoundError: Couldn't find a module script at opencompass/accuracy.py. Module 'accuracy' doesn't exist on the Hugging Face Hub either.`

HuggingFace tries to load the metric (e.g. `accuracy`) as a module online, and it could fail if the network is unreachable. Please refer to [Network-Q1](#my-tasks-failed-with-error-connection-aborted-connectionreseterror104-connection-reset-by-peer-or-urllib3exceptionsmaxretryerror-httpsconnectionpoolhostcdn-lfshuggingfaceco-port443) for guidelines to fix your network issue.

The issue has been fixed in the latest version of OpenCompass, so you might also consider pulling from the latest version.

## Efficiency

### Why does OpenCompass partition each evaluation request into tasks?

Given the extensive evaluation time and the vast quantity of datasets, conducting a comprehensive linear evaluation of LLM models can be immensely time-consuming. To address this, OpenCompass divides the evaluation request into multiple independent "tasks". These tasks are then dispatched to various GPU groups or nodes, achieving full parallelism and maximizing the efficiency of computational resources.

### How does task partitioning work?

Each task in OpenCompass represents a combination of specific model(s) and portions of the dataset awaiting evaluation. OpenCompass offers a variety of task partitioning strategies, each tailored for different scenarios. During the inference stage, the prevalent partitioning method seeks to balance task size, i.e. computational cost. This cost is heuristically derived from the dataset size and the type of inference.

### Why does it take more time to evaluate LLM models on OpenCompass?

There is a tradeoff between the number of tasks and the time to load the model. For example, if we partition a request that evaluates a model against a dataset into 100 tasks, the model will be loaded 100 times in total. When resources are abundant, these 100 tasks can be executed in parallel, so the additional time spent on model loading can be ignored. However, if resources are limited, these 100 tasks will operate more sequentially, and repeated loadings can become a bottleneck in execution time.

Hence, if users find that the number of tasks greatly exceeds the available GPUs, we advise setting `--max-partition-size` to a larger value.

## Model

### How to use downloaded HuggingFace models?

If you have already downloaded the checkpoints of the model, you can specify the local path of the model. For example:

```bash
python run.py --datasets siqa_gen winograd_ppl --hf-type base --hf-path /path/to/model
```

## Dataset

### How to build a new dataset?

- For building a new objective dataset: [new_dataset](../advanced_guides/new_dataset.md)
- For building a new subjective dataset: [subjective_evaluation](../advanced_guides/subjective_evaluation.md)
docs/en/get_started/installation.md
ADDED
@@ -0,0 +1,142 @@
| 1 |
+
# Installation
|
| 2 |
+
|
| 3 |
+
## Basic Installation
|
| 4 |
+
|
| 5 |
+
1. Prepare the OpenCompass runtime environment using Conda:
|
| 6 |
+
|
| 7 |
+
```conda create --name opencompass python=3.10 -y
|
| 8 |
+
# conda create --name opencompass_lmdeploy python=3.10 -y
|
| 9 |
+
|
| 10 |
+
conda activate opencompass
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
If you want to customize the PyTorch version or related CUDA version, please refer to the [official documentation](https://pytorch.org/get-started/locally/) to set up the PyTorch environment. Note that OpenCompass requires `pytorch>=1.13`.
|
| 14 |
+
|
| 15 |
+
2. Install OpenCompass:
|
| 16 |
+
- pip Installation
|
| 17 |
+
```bash
|
| 18 |
+
# For support of most datasets and models
|
| 19 |
+
pip install -U opencompass
|
| 20 |
+
|
| 21 |
+
# Complete installation (supports more datasets)
|
| 22 |
+
# pip install "opencompass[full]"
|
| 23 |
+
|
| 24 |
+
# API Testing (e.g., OpenAI, Qwen)
|
| 25 |
+
# pip install "opencompass[api]"
|
| 26 |
+
```
|
| 27 |
+
- Building from Source Code If you want to use the latest features of OpenCompass
|
| 28 |
+
```bash
|
| 29 |
+
git clone https://github.com/open-compass/opencompass opencompass
|
| 30 |
+
cd opencompass
|
| 31 |
+
pip install -e .
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
## Other Installations
|
| 35 |
+
|
| 36 |
+
### Inference Backends
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
# Model inference backends. Since these backends often have dependency conflicts,
|
| 40 |
+
# we recommend using separate virtual environments to manage them.
|
| 41 |
+
pip install "opencompass[lmdeploy]"
|
| 42 |
+
# pip install "opencompass[vllm]"
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
- LMDeploy
|
| 46 |
+
|
| 47 |
+
You can check if the inference backend has been installed successfully with the following command. For more information, refer to the [official documentation](https://lmdeploy.readthedocs.io/en/latest/get_started.html)
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
lmdeploy chat internlm/internlm2_5-1_8b-chat --backend turbomind
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
- vLLM
|
| 54 |
+
|
| 55 |
+
You can check if the inference backend has been installed successfully with the following command. For more information, refer to the [official documentation](https://docs.vllm.ai/en/latest/getting_started/quickstart.html)
|
| 56 |
+
|
| 57 |
+
```bash
|
| 58 |
+
vllm serve facebook/opt-125m
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
### API
|
| 62 |
+
|
| 63 |
+
OpenCompass supports different commercial model API calls, which you can install via pip or by referring to the [API dependencies](https://github.com/open-compass/opencompass/blob/main/requirements/api.txt) for specific API model dependencies.
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
pip install "opencompass[api]"
|
| 67 |
+
|
| 68 |
+
# pip install openai # GPT-3.5-Turbo / GPT-4-Turbo / GPT-4 / GPT-4o (API)
|
| 69 |
+
# pip install anthropic # Claude (API)
|
| 70 |
+
# pip install dashscope # Qwen (API)
|
| 71 |
+
# pip install volcengine-python-sdk # ByteDance Volcano Engine (API)
|
| 72 |
+
# ...
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
### Datasets
|
| 76 |
+
|
| 77 |
+
The basic installation supports most fundamental datasets. For certain datasets (e.g., Alpaca-eval, Longbench), additional dependencies need to be installed.
|
| 78 |
+
|
| 79 |
+
You can install these through pip or refer to the [additional dependencies](https://github.com/open-compass/opencompass/blob/main/requirements/extra.txt) for specific dependencies.
|
| 80 |
+
|
| 81 |
+
```bash
|
| 82 |
+
pip install "opencompass[full]"
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
For HumanEvalX / HumanEval+ / MBPP+, you need to manually clone the Git repository and install it.
|
| 86 |
+
|
| 87 |
+
```bash
|
| 88 |
+
git clone --recurse-submodules git@github.com:open-compass/human-eval.git
|
| 89 |
+
cd human-eval
|
| 90 |
+
pip install -e .
|
| 91 |
+
pip install -e evalplus
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
Some agent evaluations require installing numerous dependencies, which may conflict with existing runtime environments. We recommend creating separate conda environments to manage these.
|
| 95 |
+
|
| 96 |
+
```bash
|
| 97 |
+
# T-Eval
|
| 98 |
+
pip install lagent==0.1.2
|
| 99 |
+
# CIBench
|
| 100 |
+
pip install -r requirements/agent.txt
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
# Dataset Preparation
|
| 104 |
+
|
| 105 |
+
The datasets supported by OpenCompass mainly include three parts:
|
| 106 |
+
|
| 107 |
+
1. Huggingface datasets: [Huggingface Datasets](https://huggingface.co/datasets) provides a large number of datasets, which will be **automatically downloaded** when running with this option.
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
2. ModelScope Datasets: [ModelScope OpenCompass Dataset](https://modelscope.cn/organization/opencompass) supports automatic downloading of datasets from ModelScope.
|
| 111 |
+
|
| 112 |
+
To enable this feature, set the environment variable: `export DATASET_SOURCE=ModelScope`. The available datasets include (sourced from OpenCompassData-core.zip):
|
| 113 |
+
|
| 114 |
+
```plain
|
| 115 |
+
humaneval, triviaqa, commonsenseqa, tydiqa, strategyqa, cmmlu, lambada, piqa, ceval, math, LCSTS, Xsum, winogrande, openbookqa, AGIEval, gsm8k, nq, race, siqa, mbpp, mmlu, hellaswag, ARC, BBH, xstory_cloze, summedits, GAOKAO-BENCH, OCNLI, cmnli
|
| 116 |
+
```
|
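If you prefer to set this switch in Python (e.g., at the top of a launcher script) rather than in the shell, a minimal equivalent is:

```python
import os

# Equivalent to `export DATASET_SOURCE=ModelScope`; must be set before
# OpenCompass loads any dataset.
os.environ['DATASET_SOURCE'] = 'ModelScope'
```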
| 117 |
+
|
| 118 |
+
3. Custom datasets: OpenCompass also provides some **self-built** Chinese datasets. Please run the following command to **manually download and extract** them.
|
| 119 |
+
|
| 120 |
+
Running the following commands to download and place the datasets in the `${OpenCompass}/data` directory completes the dataset preparation.
|
| 121 |
+
|
| 122 |
+
```bash
|
| 123 |
+
# Run in the OpenCompass directory
|
| 124 |
+
wget https://github.com/open-compass/opencompass/releases/download/0.2.2.rc1/OpenCompassData-core-20240207.zip
|
| 125 |
+
unzip OpenCompassData-core-20240207.zip
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
If you need to use the more comprehensive dataset (~500M) provided by OpenCompass, you can download and `unzip` it using the following command:
|
| 129 |
+
|
| 130 |
+
```bash
|
| 131 |
+
# For proxy and resumable downloads, try `aria2c -x16 -s16 -k1M "http://ghfast.top/https://github.com/open-compass/opencompass/releases/download/0.2.2.rc1/OpenCompassData-complete-20240207.zip" `
|
| 132 |
+
wget https://github.com/open-compass/opencompass/releases/download/0.2.2.rc1/OpenCompassData-complete-20240207.zip
|
| 133 |
+
unzip OpenCompassData-complete-20240207.zip
|
| 134 |
+
cd ./data
|
| 135 |
+
find . -name "*.zip" -exec unzip "{}" \;
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
The list of datasets included in both `.zip` files can be found [here](https://github.com/open-compass/opencompass/releases/tag/0.2.2.rc1).
|
| 139 |
+
|
| 140 |
+
OpenCompass supports most of the datasets commonly used for performance comparison; please refer to `configs/datasets` for the full list of supported datasets.
|
| 141 |
+
|
| 142 |
+
For the next step, please read [Quick Start](./quick_start.md).
|
docs/en/get_started/quick_start.md
ADDED
|
@@ -0,0 +1,300 @@
|
| 1 |
+
# Quick Start
|
| 2 |
+
|
| 3 |
+

|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
OpenCompass provides a streamlined workflow for evaluating a model, which consists of the following stages: **Configure** -> **Inference** -> **Evaluation** -> **Visualization**.
|
| 8 |
+
|
| 9 |
+
**Configure**: This is your starting point. Here, you'll set up the entire evaluation process, choosing the model(s) and dataset(s) to assess. You also have the option to select an evaluation strategy, the computation backend, and define how you'd like the results displayed.
|
| 10 |
+
|
| 11 |
+
**Inference & Evaluation**: OpenCompass efficiently manages the heavy lifting, conducting parallel inference and evaluation on your chosen model(s) and dataset(s). The **Inference** phase is all about producing outputs from your datasets, whereas the **Evaluation** phase measures how well these outputs align with the gold standard answers. While this procedure is broken down into multiple "tasks" that run concurrently for greater efficiency, be aware that working with limited computational resources might introduce some unexpected overheads, resulting in generally slower evaluation. To understand and address this issue, check out [FAQ: Efficiency](faq.md#efficiency).
|
| 12 |
+
|
| 13 |
+
**Visualization**: Once the evaluation is done, OpenCompass collates the results into an easy-to-read table and saves them as both CSV and TXT files. If you need real-time updates, you can activate lark reporting and get immediate status reports in your Lark clients.
|
| 14 |
+
|
| 15 |
+
Coming up, we'll walk you through the basics of OpenCompass, showcasing evaluations of pretrained models [OPT-125M](https://huggingface.co/facebook/opt-125m) and [OPT-350M](https://huggingface.co/facebook/opt-350m) on the [SIQA](https://huggingface.co/datasets/social_i_qa) and [Winograd](https://huggingface.co/datasets/winograd_wsc) benchmark tasks. Their configuration files can be found at [configs/eval_demo.py](https://github.com/open-compass/opencompass/blob/main/configs/eval_demo.py).
|
| 16 |
+
|
| 17 |
+
Before running this experiment, please make sure you have installed OpenCompass locally; the experiment should run successfully on a single _GTX-1660-6G_ GPU.
|
| 18 |
+
For larger parameterized models like Llama-7B, refer to other examples provided in the [configs directory](https://github.com/open-compass/opencompass/tree/main/configs).
|
| 19 |
+
|
| 20 |
+
## Configuring an Evaluation Task
|
| 21 |
+
|
| 22 |
+
In OpenCompass, each evaluation task consists of the model to be evaluated and the dataset. The entry point for evaluation is `run.py`. Users can select the model and dataset to be tested either via command line or configuration files.
|
| 23 |
+
|
| 24 |
+
`````{tabs}
|
| 25 |
+
````{tab} Command Line (Custom HF Model)
|
| 26 |
+
|
| 27 |
+
For HuggingFace models, users can set model parameters directly through the command line without additional configuration files. For instance, for the `facebook/opt-125m` model, you can evaluate it with the following command:
|
| 28 |
+
|
| 29 |
+
```bash
|
| 30 |
+
python run.py --datasets siqa_gen winograd_ppl \
|
| 31 |
+
--hf-type base \
|
| 32 |
+
--hf-path facebook/opt-125m
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
Note that in this way, OpenCompass only evaluates one model at a time, while the other approaches can evaluate multiple models at once.
|
| 36 |
+
|
| 37 |
+
```{caution}
|
| 38 |
+
`--hf-num-gpus` does not stand for the actual number of GPUs to use in evaluation, but the minimum required number of GPUs for this model. [More](faq.md#how-does-opencompass-allocate-gpus)
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
:::{dropdown} More detailed example
|
| 42 |
+
:animate: fade-in-slide-down
|
| 43 |
+
```bash
|
| 44 |
+
python run.py --datasets siqa_gen winograd_ppl \
|
| 45 |
+
--hf-type base \ # HuggingFace model type, base or chat
|
| 46 |
+
--hf-path facebook/opt-125m \ # HuggingFace model path
|
| 47 |
+
--tokenizer-path facebook/opt-125m \ # HuggingFace tokenizer path (if the same as the model path, can be omitted)
|
| 48 |
+
--tokenizer-kwargs padding_side='left' truncation='left' trust_remote_code=True \ # Arguments to construct the tokenizer
|
| 49 |
+
--model-kwargs device_map='auto' \ # Arguments to construct the model
|
| 50 |
+
--max-seq-len 2048 \ # Maximum sequence length the model can accept
|
| 51 |
+
--max-out-len 100 \ # Maximum number of tokens to generate
|
| 52 |
+
--min-out-len 100 \ # Minimum number of tokens to generate
|
| 53 |
+
--batch-size 64 \ # Batch size
|
| 54 |
+
--hf-num-gpus 1 # Number of GPUs required to run the model
|
| 55 |
+
```
|
| 56 |
+
```{seealso}
|
| 57 |
+
For all HuggingFace related parameters supported by `run.py`, please read [Launching Evaluation Task](../user_guides/experimentation.md#launching-an-evaluation-task).
|
| 58 |
+
```
|
| 59 |
+
:::
|
| 60 |
+
|
| 61 |
+
````
|
| 62 |
+
````{tab} Command Line
|
| 63 |
+
|
| 64 |
+
Users can combine the models and datasets they want to test using `--models` and `--datasets`.
|
| 65 |
+
|
| 66 |
+
```bash
|
| 67 |
+
python run.py --models hf_opt_125m hf_opt_350m --datasets siqa_gen winograd_ppl
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
The models and datasets are pre-stored in the form of configuration files in `configs/models` and `configs/datasets`. Users can view or filter the currently available model and dataset configurations using `tools/list_configs.py`.
|
| 71 |
+
|
| 72 |
+
```bash
|
| 73 |
+
# List all configurations
|
| 74 |
+
python tools/list_configs.py
|
| 75 |
+
# List all configurations related to llama and mmlu
|
| 76 |
+
python tools/list_configs.py llama mmlu
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
:::{dropdown} More about `list_configs`
|
| 80 |
+
:animate: fade-in-slide-down
|
| 81 |
+
|
| 82 |
+
Running `python tools/list_configs.py llama mmlu` gives output like:
|
| 83 |
+
|
| 84 |
+
```text
|
| 85 |
+
+-----------------+-----------------------------------+
|
| 86 |
+
| Model | Config Path |
|
| 87 |
+
|-----------------+-----------------------------------|
|
| 88 |
+
| hf_llama2_13b | configs/models/hf_llama2_13b.py |
|
| 89 |
+
| hf_llama2_70b | configs/models/hf_llama2_70b.py |
|
| 90 |
+
| ... | ... |
|
| 91 |
+
+-----------------+-----------------------------------+
|
| 92 |
+
+-------------------+---------------------------------------------------+
|
| 93 |
+
| Dataset | Config Path |
|
| 94 |
+
|-------------------+---------------------------------------------------|
|
| 95 |
+
| cmmlu_gen | configs/datasets/cmmlu/cmmlu_gen.py |
|
| 96 |
+
| cmmlu_gen_ffe7c0 | configs/datasets/cmmlu/cmmlu_gen_ffe7c0.py |
|
| 97 |
+
| ... | ... |
|
| 98 |
+
+-------------------+---------------------------------------------------+
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
Users can use the names in the first column as input parameters for `--models` and `--datasets` in `python run.py`. For datasets, the same name with different suffixes generally indicates that its prompts or evaluation methods are different.
|
| 102 |
+
:::
|
| 103 |
+
|
| 104 |
+
:::{dropdown} Model not on the list?
|
| 105 |
+
:animate: fade-in-slide-down
|
| 106 |
+
|
| 107 |
+
If you want to evaluate other models, please check out the "Command Line (Custom HF Model)" tab for the way to construct a custom HF model without a configuration file, or "Configuration File" tab to learn the general way to prepare your model configurations.
|
| 108 |
+
|
| 109 |
+
:::
|
| 110 |
+
|
| 111 |
+
````
|
| 112 |
+
|
| 113 |
+
````{tab} Configuration File
|
| 114 |
+
|
| 115 |
+
In addition to configuring the experiment through the command line, OpenCompass also allows users to write the full configuration of the experiment in a configuration file and run it directly through `run.py`. The configuration file is organized in Python format and must include the `datasets` and `models` fields.
|
| 116 |
+
|
| 117 |
+
The test configuration for this time is [configs/eval_demo.py](https://github.com/open-compass/opencompass/blob/main/configs/eval_demo.py). This configuration introduces the required dataset and model configurations through the [inheritance mechanism](../user_guides/config.md#inheritance-mechanism) and combines the `datasets` and `models` fields in the required format.
|
| 118 |
+
|
| 119 |
+
```python
|
| 120 |
+
from mmengine.config import read_base
|
| 121 |
+
|
| 122 |
+
with read_base():
|
| 123 |
+
from .datasets.siqa.siqa_gen import siqa_datasets
|
| 124 |
+
from .datasets.winograd.winograd_ppl import winograd_datasets
|
| 125 |
+
from .models.opt.hf_opt_125m import opt125m
|
| 126 |
+
from .models.opt.hf_opt_350m import opt350m
|
| 127 |
+
|
| 128 |
+
datasets = [*siqa_datasets, *winograd_datasets]
|
| 129 |
+
models = [opt125m, opt350m]
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
When running tasks, we just need to pass the path of the configuration file to `run.py`:
|
| 133 |
+
|
| 134 |
+
```bash
|
| 135 |
+
python run.py configs/eval_demo.py
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
:::{dropdown} More about `models`
|
| 139 |
+
:animate: fade-in-slide-down
|
| 140 |
+
|
| 141 |
+
OpenCompass provides a series of pre-defined model configurations under `configs/models`. Below is the configuration snippet related to [opt-350m](https://github.com/open-compass/opencompass/blob/main/configs/models/opt/hf_opt_350m.py) (`configs/models/opt/hf_opt_350m.py`):
|
| 142 |
+
|
| 143 |
+
```python
|
| 144 |
+
# Evaluate models supported by HuggingFace's `AutoModelForCausalLM` using `HuggingFaceBaseModel`
|
| 145 |
+
from opencompass.models import HuggingFaceBaseModel
|
| 146 |
+
|
| 147 |
+
models = [
|
| 148 |
+
# OPT-350M
|
| 149 |
+
dict(
|
| 150 |
+
type=HuggingFaceBaseModel,
|
| 151 |
+
# Initialization parameters for `HuggingFaceBaseModel`
|
| 152 |
+
path='facebook/opt-350m',
|
| 153 |
+
# Below are common parameters for all models, not specific to HuggingFaceBaseModel
|
| 154 |
+
abbr='opt-350m-hf', # Model abbreviation
|
| 155 |
+
max_out_len=1024, # Maximum number of generated tokens
|
| 156 |
+
batch_size=32, # Batch size
|
| 157 |
+
run_cfg=dict(num_gpus=1), # The required GPU numbers for this model
|
| 158 |
+
)
|
| 159 |
+
]
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
When using configurations, we can specify the relevant files through the command-line argument `--models` or import the model configurations into the `models` list in the configuration file using the inheritance mechanism.
|
| 163 |
+
|
| 164 |
+
```{seealso}
|
| 165 |
+
More information about model configuration can be found in [Prepare Models](../user_guides/models.md).
|
| 166 |
+
```
|
| 167 |
+
:::
|
| 168 |
+
|
| 169 |
+
:::{dropdown} More about `datasets`
|
| 170 |
+
:animate: fade-in-slide-down
|
| 171 |
+
|
| 172 |
+
Similar to models, dataset configuration files are provided under `configs/datasets`. Users can use `--datasets` in the command line or import related configurations in the configuration file via inheritance.
|
| 173 |
+
|
| 174 |
+
Below is a dataset-related configuration snippet from `configs/eval_demo.py`:
|
| 175 |
+
|
| 176 |
+
```python
|
| 177 |
+
from mmengine.config import read_base # Use mmengine.read_base() to read the base configuration
|
| 178 |
+
|
| 179 |
+
with read_base():
|
| 180 |
+
# Directly read the required dataset configurations from the preset dataset configurations
|
| 181 |
+
from .datasets.winograd.winograd_ppl import winograd_datasets # Read Winograd configuration, evaluated based on PPL (perplexity)
|
| 182 |
+
from .datasets.siqa.siqa_gen import siqa_datasets # Read SIQA configuration, evaluated based on generation
|
| 183 |
+
|
| 184 |
+
datasets = [*siqa_datasets, *winograd_datasets] # The final config needs to contain the required evaluation dataset list 'datasets'
|
| 185 |
+
```
|
| 186 |
+
|
| 187 |
+
Dataset configurations typically come in two types, 'ppl' and 'gen', indicating the evaluation method used: `ppl` means discriminative evaluation and `gen` means generative evaluation.
|
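For the same dataset, switching between the two usually just means importing a different config file; a sketch (the `siqa_ppl` module name is an assumption here, check `configs/datasets/siqa` for the exact file names):

```python
from mmengine.config import read_base

with read_base():
    from .datasets.siqa.siqa_gen import siqa_datasets    # generative evaluation
    # from .datasets.siqa.siqa_ppl import siqa_datasets  # discriminative (perplexity) evaluation
```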
| 188 |
+
|
| 189 |
+
Moreover, [configs/datasets/collections](https://github.com/open-compass/opencompass/blob/main/configs/datasets/collections) houses various dataset collections, making it convenient for comprehensive evaluations. OpenCompass often uses [`base_medium.py`](/configs/datasets/collections/base_medium.py) for full-scale model testing. To replicate results, simply import that file, for example:
|
| 190 |
+
|
| 191 |
+
```bash
|
| 192 |
+
python run.py --models hf_llama_7b --datasets base_medium
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
```{seealso}
|
| 196 |
+
You can find more information from [Dataset Preparation](../user_guides/datasets.md).
|
| 197 |
+
```
|
| 198 |
+
:::
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
````
|
| 202 |
+
|
| 203 |
+
`````
|
| 204 |
+
|
| 205 |
+
```{warning}
|
| 206 |
+
OpenCompass usually assumes network is available. If you encounter network issues or wish to run OpenCompass in an offline environment, please refer to [FAQ - Network - Q1](./faq.md#network) for solutions.
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
The following sections will use configuration-based method as an example to explain the other features.
|
| 210 |
+
|
| 211 |
+
## Launching Evaluation
|
| 212 |
+
|
| 213 |
+
Since OpenCompass launches evaluation processes in parallel by default, we can start the evaluation in `--debug` mode for the first run and check if there is any problem. In `--debug` mode, the tasks will be executed sequentially and output will be printed in real time.
|
| 214 |
+
|
| 215 |
+
```bash
|
| 216 |
+
python run.py configs/eval_demo.py -w outputs/demo --debug
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+
The pretrained models 'facebook/opt-350m' and 'facebook/opt-125m' will be automatically downloaded from HuggingFace during the first run.
|
| 220 |
+
If everything is fine, you should see "Starting inference process" on screen:
|
| 221 |
+
|
| 222 |
+
```bash
|
| 223 |
+
[2023-07-12 18:23:55,076] [opencompass.openicl.icl_inferencer.icl_gen_inferencer] [INFO] Starting inference process...
|
| 224 |
+
```
|
| 225 |
+
|
| 226 |
+
Then you can press `ctrl+c` to interrupt the program, and run the following command in normal mode:
|
| 227 |
+
|
| 228 |
+
```bash
|
| 229 |
+
python run.py configs/eval_demo.py -w outputs/demo
|
| 230 |
+
```
|
| 231 |
+
|
| 232 |
+
In normal mode, the evaluation tasks will be executed in parallel in the background, and their output will be redirected to the output directory `outputs/demo/{TIMESTAMP}`. The progress bar on the frontend only indicates the number of completed tasks, regardless of their success or failure. **Any backend task failures will only trigger a warning message in the terminal.**
|
| 233 |
+
|
| 234 |
+
:::{dropdown} More parameters in `run.py`
|
| 235 |
+
:animate: fade-in-slide-down
|
| 236 |
+
Here are some parameters related to evaluation that can help you configure more efficient inference tasks based on your environment:
|
| 237 |
+
|
| 238 |
+
- `-w outputs/demo`: Work directory to save evaluation logs and results. In this case, the experiment result will be saved to `outputs/demo/{TIMESTAMP}`.
|
| 239 |
+
- `-r`: Reuse existing inference results, and skip the finished tasks. If followed by a timestamp, the result under that timestamp in the workspace path will be reused; otherwise, the latest result in the specified workspace path will be reused.
|
| 240 |
+
- `--mode all`: Specify a specific stage of the task.
|
| 241 |
+
- all: (Default) Perform a complete evaluation, including inference and evaluation.
|
| 242 |
+
- infer: Perform inference on each dataset.
|
| 243 |
+
- eval: Perform evaluation based on the inference results.
|
| 244 |
+
- viz: Display evaluation results only.
|
| 245 |
+
- `--max-partition-size 2000`: Dataset partition size. Some datasets may be large, and using this parameter can split them into multiple sub-tasks to efficiently utilize resources. However, if the partition is too fine, the overall speed may be slower due to longer model loading times.
|
| 246 |
+
- `--max-num-workers 32`: Maximum number of parallel tasks. In distributed environments such as Slurm, this parameter specifies the maximum number of submitted tasks. In a local environment, it specifies the maximum number of tasks executed in parallel. Note that the actual number of parallel tasks depends on the available GPU resources and may not be equal to this number.
|
| 247 |
+
|
| 248 |
+
If you are not performing the evaluation on your local machine but using a Slurm cluster, you can specify the following parameters:
|
| 249 |
+
|
| 250 |
+
- `--slurm`: Submit tasks using Slurm on the cluster.
|
| 251 |
+
- `--partition(-p) my_part`: Slurm cluster partition.
|
| 252 |
+
- `--retry 2`: Number of retries for failed tasks.
|
| 253 |
+
|
| 254 |
+
```{seealso}
|
| 255 |
+
The entry also supports submitting tasks to Alibaba Deep Learning Center (DLC), and more customized evaluation strategies. Please refer to [Launching an Evaluation Task](../user_guides/experimentation.md#launching-an-evaluation-task) for details.
|
| 256 |
+
```
|
| 257 |
+
|
| 258 |
+
:::
|
| 259 |
+
|
| 260 |
+
## Visualizing Evaluation Results
|
| 261 |
+
|
| 262 |
+
After the evaluation is complete, the evaluation results table will be printed as follows:
|
| 263 |
+
|
| 264 |
+
```text
|
| 265 |
+
dataset version metric mode opt350m opt125m
|
| 266 |
+
--------- --------- -------- ------ --------- ---------
|
| 267 |
+
siqa e78df3 accuracy gen 21.55 12.44
|
| 268 |
+
winograd b6c7ed accuracy ppl 51.23 49.82
|
| 269 |
+
```
|
| 270 |
+
|
| 271 |
+
All run outputs will be directed to `outputs/demo/` directory with following structure:
|
| 272 |
+
|
| 273 |
+
```text
|
| 274 |
+
outputs/default/
|
| 275 |
+
├── 20200220_120000
|
| 276 |
+
├── 20230220_183030 # one experiment per folder
|
| 277 |
+
│ ├── configs # Dumped config files for record. Multiple configs may be kept if different experiments have been re-run on the same experiment folder
|
| 278 |
+
│ ├── logs # log files for both inference and evaluation stages
|
| 279 |
+
│ │ ├── eval
|
| 280 |
+
│ │ └── infer
|
| 281 |
+
│ ├── predictions # Prediction results for each task
|
| 282 |
+
│ ├── results # Evaluation results for each task
|
| 283 |
+
│ └── summary # Summarized evaluation results for a single experiment
|
| 284 |
+
├── ...
|
| 285 |
+
```
|
| 286 |
+
|
| 287 |
+
The summarization process can be further customized in the configuration to output the averaged scores of some benchmarks (MMLU, C-Eval, etc.).
|
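As a rough illustration, such a customization is expressed through a `summarizer` field in the evaluation config; the sketch below (the listed abbreviations are illustrative) assumes the interface described in the linked document:

```python
# In the evaluation config: the summarizer controls which rows appear in the
# final table; a group abbreviation reports the averaged score of its members.
summarizer = dict(
    dataset_abbrs=['siqa', 'winograd'],
)
```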
| 288 |
+
|
| 289 |
+
More information about obtaining evaluation results can be found in [Results Summary](../user_guides/summarizer.md).
|
| 290 |
+
|
| 291 |
+
## Additional Tutorials
|
| 292 |
+
|
| 293 |
+
To learn more about using OpenCompass, explore the following tutorials:
|
| 294 |
+
|
| 295 |
+
- [Prepare Datasets](../user_guides/datasets.md)
|
| 296 |
+
- [Prepare Models](../user_guides/models.md)
|
| 297 |
+
- [Task Execution and Monitoring](../user_guides/experimentation.md)
|
| 298 |
+
- [Understand Prompts](../prompt/overview.md)
|
| 299 |
+
- [Results Summary](../user_guides/summarizer.md)
|
| 300 |
+
- [Learn about Config](../user_guides/config.md)
|
docs/en/index.rst
ADDED
|
@@ -0,0 +1,99 @@
|
| 1 |
+
Welcome to OpenCompass' documentation!
|
| 2 |
+
==========================================
|
| 3 |
+
|
| 4 |
+
Getting started with OpenCompass
|
| 5 |
+
-------------------------------
|
| 6 |
+
|
| 7 |
+
To help you quickly get familiar with OpenCompass, we recommend you walk through the following documents in order:
|
| 8 |
+
|
| 9 |
+
- First read the GetStarted_ section to set up the environment and run a mini experiment.
|
| 10 |
+
|
| 11 |
+
- Then learn its basic usage through the UserGuides_.
|
| 12 |
+
|
| 13 |
+
- If you want to tune the prompts, refer to the Prompt_.
|
| 14 |
+
|
| 15 |
+
- If you want to customize some modules, like adding a new dataset or model, we have provided the AdvancedGuides_.
|
| 16 |
+
|
| 17 |
+
- There are more handy tools, such as the prompt viewer and the lark bot reporter, all presented in Tools_.
|
| 18 |
+
|
| 19 |
+
We always welcome *PRs* and *Issues* for the betterment of OpenCompass.
|
| 20 |
+
|
| 21 |
+
.. _GetStarted:
|
| 22 |
+
.. toctree::
|
| 23 |
+
:maxdepth: 1
|
| 24 |
+
:caption: Get Started
|
| 25 |
+
|
| 26 |
+
get_started/installation.md
|
| 27 |
+
get_started/quick_start.md
|
| 28 |
+
get_started/faq.md
|
| 29 |
+
|
| 30 |
+
.. _UserGuides:
|
| 31 |
+
.. toctree::
|
| 32 |
+
:maxdepth: 1
|
| 33 |
+
:caption: User Guides
|
| 34 |
+
|
| 35 |
+
user_guides/framework_overview.md
|
| 36 |
+
user_guides/config.md
|
| 37 |
+
user_guides/datasets.md
|
| 38 |
+
user_guides/models.md
|
| 39 |
+
user_guides/evaluation.md
|
| 40 |
+
user_guides/experimentation.md
|
| 41 |
+
user_guides/metrics.md
|
| 42 |
+
user_guides/deepseek_r1.md
|
| 43 |
+
user_guides/interns1.md
|
| 44 |
+
|
| 45 |
+
.. _Prompt:
|
| 46 |
+
.. toctree::
|
| 47 |
+
:maxdepth: 1
|
| 48 |
+
:caption: Prompt
|
| 49 |
+
|
| 50 |
+
prompt/overview.md
|
| 51 |
+
prompt/prompt_template.md
|
| 52 |
+
prompt/meta_template.md
|
| 53 |
+
prompt/chain_of_thought.md
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
.. _AdvancedGuides:
|
| 57 |
+
.. toctree::
|
| 58 |
+
:maxdepth: 1
|
| 59 |
+
:caption: Advanced Guides
|
| 60 |
+
|
| 61 |
+
advanced_guides/new_dataset.md
|
| 62 |
+
advanced_guides/custom_dataset.md
|
| 63 |
+
advanced_guides/new_model.md
|
| 64 |
+
advanced_guides/evaluation_lmdeploy.md
|
| 65 |
+
advanced_guides/accelerator_intro.md
|
| 66 |
+
advanced_guides/math_verify.md
|
| 67 |
+
advanced_guides/llm_judge.md
|
| 68 |
+
advanced_guides/code_eval.md
|
| 69 |
+
advanced_guides/code_eval_service.md
|
| 70 |
+
advanced_guides/subjective_evaluation.md
|
| 71 |
+
advanced_guides/persistence.md
|
| 72 |
+
|
| 73 |
+
.. _Tools:
|
| 74 |
+
.. toctree::
|
| 75 |
+
:maxdepth: 1
|
| 76 |
+
:caption: Tools
|
| 77 |
+
|
| 78 |
+
tools.md
|
| 79 |
+
|
| 80 |
+
.. _Dataset List:
|
| 81 |
+
.. toctree::
|
| 82 |
+
:maxdepth: 1
|
| 83 |
+
:caption: Dataset List
|
| 84 |
+
|
| 85 |
+
dataset_statistics.md
|
| 86 |
+
|
| 87 |
+
.. _Notes:
|
| 88 |
+
.. toctree::
|
| 89 |
+
:maxdepth: 1
|
| 90 |
+
:caption: Notes
|
| 91 |
+
|
| 92 |
+
notes/contribution_guide.md
|
| 93 |
+
notes/academic.md
|
| 94 |
+
|
| 95 |
+
Indexes & Tables
|
| 96 |
+
==================
|
| 97 |
+
|
| 98 |
+
* :ref:`genindex`
|
| 99 |
+
* :ref:`search`
|
docs/en/notes/academic.md
ADDED
|
@@ -0,0 +1,106 @@
|
| 1 |
+
# Guide to Reproducing CompassAcademic Leaderboard Results
|
| 2 |
+
|
| 3 |
+
To provide users with a quick and intuitive overview of the performance of mainstream open-source and commercial models on widely-used datasets, we maintain the [CompassAcademic Leaderboard](https://rank.opencompass.org.cn/leaderboard-llm-academic/?m=REALTIME) for LLMs on our official website, updating it typically every two weeks.
|
| 4 |
+
|
| 5 |
+
Given the continuous iteration of models and datasets, along with ongoing upgrades to OpenCompass, the configuration settings for the CompassAcademic leaderboard may evolve. Specifically, we adhere to the following update principles:
|
| 6 |
+
|
| 7 |
+
- Newly released models are promptly included, while models published six months to one year (or more) ago are removed from the leaderboard.
|
| 8 |
+
- New datasets are incorporated, while datasets nearing performance saturation are phased out.
|
| 9 |
+
- Existing evaluation results on the leaderboard are updated in sync with changes to the evaluation configuration.
|
| 10 |
+
|
| 11 |
+
To support rapid reproducibility, OpenCompass provides the real-time configuration files used in the academic leaderboard.
|
| 12 |
+
|
| 13 |
+
## CompassAcademic Leaderboard Reproduction
|
| 14 |
+
|
| 15 |
+
[eval_academic_leaderboard_REALTIME.py](https://github.com/open-compass/opencompass/blob/main/examples/eval_academic_leaderboard_REALTIME.py) contains the configuration currently used for academic ranking evaluation. You can replicate the evaluation by following the steps below.
|
| 16 |
+
|
| 17 |
+
### 1: Model Configs
|
| 18 |
+
|
| 19 |
+
Firstly, modify the Model List code block in [eval_academic_leaderboard_REALTIME.py](https://github.com/open-compass/opencompass/blob/main/examples/eval_academic_leaderboard_REALTIME.py) to include the model you wish to evaluate.
|
| 20 |
+
|
| 21 |
+
```python
|
| 22 |
+
# Models (add your models here)
|
| 23 |
+
from opencompass.configs.models.hf_internlm.lmdeploy_internlm2_5_7b_chat import \
|
| 24 |
+
models as hf_internlm2_5_7b_chat_model
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
The original example calls an lmdeploy-based model configuration in OpenCompass.
|
| 28 |
+
You can also build your new model configuration based on [this document](https://opencompass.readthedocs.io/zh-cn/latest/user_guides/models.html).
|
| 29 |
+
An example of a configuration that calls the deployed service of Qwen3-235B-A22B based on OpenAISDK is as follows:
|
| 30 |
+
|
| 31 |
+
```python
|
| 32 |
+
from opencompass.models import OpenAISDK
|
| 33 |
+
from opencompass.utils.text_postprocessors import extract_non_reasoning_content
|
| 34 |
+
|
| 35 |
+
qwen3_235b_a22b_model = dict(
|
| 36 |
+
abbr="qwen_3_235b_a22b_thinking", # Used to identify the model configuration
|
| 37 |
+
key="YOUR_SERVE_API_KEY",
|
| 38 |
+
openai_api_base="YOUR_SERVE_API_URL",
|
| 39 |
+
type=OpenAISDK, # The model class; commonly used ones include OpenAISDK, TurboMindModelwithChatTemplate, HuggingFacewithChatTemplate
|
| 40 |
+
path="Qwen/Qwen3-235B-A22B",
|
| 41 |
+
temperature=0.6,
|
| 42 |
+
meta_template=dict(
|
| 43 |
+
round=[
|
| 44 |
+
dict(role='HUMAN', api_role='HUMAN'),
|
| 45 |
+
dict(role='BOT', api_role='BOT', generate=True),
|
| 46 |
+
],
|
| 47 |
+
),
|
| 48 |
+
query_per_second=1,
|
| 49 |
+
max_out_len=32000,
|
| 50 |
+
max_seq_len=32768,
|
| 51 |
+
batch_size=8,
|
| 52 |
+
retry=10,
|
| 53 |
+
extra_body={
|
| 54 |
+
'chat_template_kwargs': {'enable_thinking': True},
|
| 55 |
+
}, # Additional model configurations, such as the Qwen3-series option that controls whether thinking mode is enabled
|
| 56 |
+
pred_postprocessor=dict(type=extract_non_reasoning_content), # this pred_postprocessor extracts the non-reasoning content from models that wrap their reasoning in a think tag
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
models = [
|
| 60 |
+
qwen3_235b_a22b_model,
|
| 61 |
+
]
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
Here are the commonly used parameters for reference.
|
| 65 |
+
|
| 66 |
+
- `max_seq_len` = 65536 or 32768
|
| 67 |
+
- `max_out_len` = 64000 or 32000
|
| 68 |
+
- `temperature` = 0.6
|
| 69 |
+
- `top_p` = 0.95
|
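Applied to the model dict defined above, these defaults would look roughly as follows (a sketch only; whether `top_p` is passed this way depends on the model class in use):

```python
# Hypothetical tweak of the qwen3_235b_a22b_model dict from the example above
qwen3_235b_a22b_model.update(
    max_seq_len=32768,
    max_out_len=32000,
    temperature=0.6,
    top_p=0.95,  # assumption: accepted alongside temperature by OpenAISDK
)
```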
| 70 |
+
|
| 71 |
+
### 2: Verifier Configs
|
| 72 |
+
|
| 73 |
+
Complete your verifier model information in `judge_cfg`.
|
| 74 |
+
For detailed information about LLM verifiers, please refer to [this document](https://opencompass.readthedocs.io/zh-cn/latest/advanced_guides/llm_judge.html).
|
| 75 |
+
At present, CompassAcademic uses [CompassVerifier-32B](https://huggingface.co/opencompass/CompassVerifier-32B); here is a config example using OpenAISDK:
|
| 76 |
+
|
| 77 |
+
```python
|
| 78 |
+
judge_cfg = dict(
|
| 79 |
+
abbr='CompassVerifier',
|
| 80 |
+
type=OpenAISDK,
|
| 81 |
+
path='opencompass/CompassVerifier-32B',
|
| 82 |
+
key='YOUR_API_KEY',
|
| 83 |
+
openai_api_base='YOUR_API_BASE',
|
| 84 |
+
meta_template=dict(
|
| 85 |
+
round=[
|
| 86 |
+
dict(role='HUMAN', api_role='HUMAN'),
|
| 87 |
+
dict(role='BOT', api_role='BOT', generate=True),
|
| 88 |
+
]),
|
| 89 |
+
query_per_second=1,
|
| 90 |
+
batch_size=8,
|
| 91 |
+
temperature=0.001,
|
| 92 |
+
max_out_len=8192,
|
| 93 |
+
max_seq_len=32768,
|
| 94 |
+
mode='mid',
|
| 95 |
+
)
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
### 3: Execute evaluation
|
| 99 |
+
|
| 100 |
+
After completing the above configuration file, you can enter the following content in the CLI to start the evaluation:
|
| 101 |
+
|
| 102 |
+
```bash
|
| 103 |
+
opencompass examples/eval_academic_leaderboard_REALTIME.py
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
For more detailed CLI parameters, please refer to [this document](https://opencompass.readthedocs.io/zh-cn/latest/user_guides/experimentation.html).
|
docs/en/notes/contribution_guide.md
ADDED
|
@@ -0,0 +1,158 @@
|
| 1 |
+
# Contributing to OpenCompass
|
| 2 |
+
|
| 3 |
+
- [Contributing to OpenCompass](#contributing-to-opencompass)
|
| 4 |
+
- [What is PR](#what-is-pr)
|
| 5 |
+
- [Basic Workflow](#basic-workflow)
|
| 6 |
+
- [Procedures in detail](#procedures-in-detail)
|
| 7 |
+
- [1. Get the most recent codebase](#1-get-the-most-recent-codebase)
|
| 8 |
+
- [2. Checkout a new branch from `main` branch](#2-checkout-a-new-branch-from-main-branch)
|
| 9 |
+
- [3. Commit your changes](#3-commit-your-changes)
|
| 10 |
+
- [4. Push your changes to the forked repository and create a PR](#4-push-your-changes-to-the-forked-repository-and-create-a-pr)
|
| 11 |
+
- [5. Discuss and review your code](#5-discuss-and-review-your-code)
|
| 12 |
+
- [6. Merge your branch to `main` branch and delete the branch](#6--merge-your-branch-to-main-branch-and-delete-the-branch)
|
| 13 |
+
- [Code style](#code-style)
|
| 14 |
+
- [Python](#python)
|
| 15 |
+
- [About Contributing Test Datasets](#about-contributing-test-datasets)
|
| 16 |
+
|
| 17 |
+
Thanks for your interest in contributing to OpenCompass! All kinds of contributions are welcome, including but not limited to the following.
|
| 18 |
+
|
| 19 |
+
- Fix typo or bugs
|
| 20 |
+
- Add documentation or translate the documentation into other languages
|
| 21 |
+
- Add new features and components
|
| 22 |
+
|
| 23 |
+
## What is PR
|
| 24 |
+
|
| 25 |
+
`PR` is the abbreviation of `Pull Request`. Here's the definition of `PR` in the [official document](https://docs.github.com/en/github/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests) of Github.
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
Pull requests let you tell others about changes you have pushed to a branch in a repository on GitHub. Once a pull request is opened, you can discuss and review the potential changes with collaborators and add follow-up commits before your changes are merged into the base branch.
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## Basic Workflow
|
| 32 |
+
|
| 33 |
+
1. Get the most recent codebase
|
| 34 |
+
2. Checkout a new branch from `main` branch.
|
| 35 |
+
3. Commit your changes ([Don't forget to use pre-commit hooks!](#3-commit-your-changes))
|
| 36 |
+
4. Push your changes and create a PR
|
| 37 |
+
5. Discuss and review your code
|
| 38 |
+
6. Merge your branch to `main` branch
|
| 39 |
+
|
| 40 |
+
## Procedures in detail
|
| 41 |
+
|
| 42 |
+
### 1. Get the most recent codebase
|
| 43 |
+
|
| 44 |
+
- When you work on your first PR
|
| 45 |
+
|
| 46 |
+
Fork the OpenCompass repository: click the **fork** button at the top right corner of Github page
|
| 47 |
+

|
| 48 |
+
|
| 49 |
+
Clone forked repository to local
|
| 50 |
+
|
| 51 |
+
```bash
|
| 52 |
+
git clone git@github.com:XXX/opencompass.git
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
Add source repository to upstream
|
| 56 |
+
|
| 57 |
+
```bash
|
| 58 |
+
git remote add upstream git@github.com:InternLM/opencompass.git
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
- After your first PR
|
| 62 |
+
|
| 63 |
+
Checkout the latest branch of the local repository and pull the latest branch of the source repository.
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
git checkout main
|
| 67 |
+
git pull upstream main
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
### 2. Checkout a new branch from `main` branch
|
| 71 |
+
|
| 72 |
+
```bash
|
| 73 |
+
git checkout main -b branchname
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
### 3. Commit your changes
|
| 77 |
+
|
| 78 |
+
- If you are a first-time contributor, please install and initialize pre-commit hooks from the repository root directory first.
|
| 79 |
+
|
| 80 |
+
```bash
|
| 81 |
+
pip install -U pre-commit
|
| 82 |
+
pre-commit install
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
- Commit your changes as usual. Pre-commit hooks will be triggered to stylize your code before each commit.
|
| 86 |
+
|
| 87 |
+
```bash
|
| 88 |
+
# coding
|
| 89 |
+
git add [files]
|
| 90 |
+
git commit -m 'messages'
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
```{note}
|
| 94 |
+
Sometimes your code may be changed by pre-commit hooks. In this case, please remember to re-stage the modified files and commit again.
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
### 4. Push your changes to the forked repository and create a PR
|
| 98 |
+
|
| 99 |
+
- Push the branch to your forked remote repository
|
| 100 |
+
|
| 101 |
+
```bash
|
| 102 |
+
git push origin branchname
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
- Create a PR
|
| 106 |
+

|
| 107 |
+
|
| 108 |
+
- Revise the PR message template to describe your motivation and the modifications made in this PR. You can also link the related issue to the PR manually in the PR message (for more information, check out the [official guidance](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue)).
|
| 109 |
+
|
| 110 |
+
- You can also ask a specific person to review the changes you've proposed.
|
| 111 |
+
|
| 112 |
+
### 5. Discuss and review your code
|
| 113 |
+
|
| 114 |
+
- Modify your codes according to reviewers' suggestions and then push your changes.
|
| 115 |
+
|
| 116 |
+
### 6. Merge your branch to `main` branch and delete the branch
|
| 117 |
+
|
| 118 |
+
- After the PR is merged by the maintainer, you can delete the branch you created in your forked repository.
|
| 119 |
+
|
| 120 |
+
```bash
|
| 121 |
+
git branch -d branchname # delete local branch
|
| 122 |
+
git push origin --delete branchname # delete remote branch
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
## Code style
|
| 126 |
+
|
| 127 |
+
### Python
|
| 128 |
+
|
| 129 |
+
We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style.
|
| 130 |
+
|
| 131 |
+
We use the following tools for linting and formatting:
|
| 132 |
+
|
| 133 |
+
- [flake8](https://github.com/PyCQA/flake8): A wrapper around some linter tools.
|
| 134 |
+
- [isort](https://github.com/timothycrosley/isort): A Python utility to sort imports.
|
| 135 |
+
- [yapf](https://github.com/google/yapf): A formatter for Python files.
|
| 136 |
+
- [codespell](https://github.com/codespell-project/codespell): A Python utility to fix common misspellings in text files.
|
| 137 |
+
- [mdformat](https://github.com/executablebooks/mdformat): Mdformat is an opinionated Markdown formatter that can be used to enforce a consistent style in Markdown files.
|
| 138 |
+
- [docformatter](https://github.com/myint/docformatter): A formatter to format docstring.
|
| 139 |
+
|
| 140 |
+
Style configurations of yapf and isort can be found in [setup.cfg](https://github.com/open-compass/opencompass/blob/main/setup.cfg).
|
| 141 |
+
|
| 142 |
+
## About Contributing Test Datasets
|
| 143 |
+
|
| 144 |
+
- Submitting Test Datasets
|
| 145 |
+
- Please implement logic for automatic dataset downloading in the code; or provide a method for obtaining the dataset in the PR. The OpenCompass maintainers will follow up accordingly. If the dataset is not yet public, please indicate so.
|
| 146 |
+
- Submitting Data Configuration Files
|
| 147 |
+
- Provide a README in the same directory as the data configuration. The README should include, but is not limited to:
|
| 148 |
+
- A brief description of the dataset
|
| 149 |
+
- The official link to the dataset
|
| 150 |
+
- Some test examples from the dataset
|
| 151 |
+
- Evaluation results of the dataset on relevant models
|
| 152 |
+
- Citation of the dataset
|
| 153 |
+
- (Optional) Summarizer of the dataset
|
| 154 |
+
- (Optional) If the testing process cannot be achieved simply by concatenating the dataset and model configuration files, a configuration file for conducting the test is also required.
|
| 155 |
+
- (Optional) If necessary, please add a description of the dataset in the relevant documentation sections; this greatly helps users understand the testing scheme. You can refer to the following types of documents in OpenCompass:
|
| 156 |
+
- [Circular Evaluation](../advanced_guides/circular_eval.md)
|
| 157 |
+
- [Code Evaluation](../advanced_guides/code_eval.md)
|
| 158 |
+
- [Contamination Assessment](../advanced_guides/contamination_eval.md)
|
docs/en/notes/news.md
ADDED
|
@@ -0,0 +1,40 @@
|
| 1 |
+
# News
|
| 2 |
+
|
| 3 |
+
- **\[2024.05.08\]** We supported the evaluation of 4 MoE models: [Mixtral-8x22B-v0.1](configs/models/mixtral/hf_mixtral_8x22b_v0_1.py), [Mixtral-8x22B-Instruct-v0.1](configs/models/mixtral/hf_mixtral_8x22b_instruct_v0_1.py), [Qwen1.5-MoE-A2.7B](configs/models/qwen/hf_qwen1_5_moe_a2_7b.py), [Qwen1.5-MoE-A2.7B-Chat](configs/models/qwen/hf_qwen1_5_moe_a2_7b_chat.py). Try them out now!
|
| 4 |
+
- **\[2024.04.30\]** We supported evaluating a model's compression efficiency by calculating its Bits per Character (BPC) metric on an [external corpora](configs/datasets/llm_compression/README.md) ([official paper](https://github.com/hkust-nlp/llm-compression-intelligence)). Check out the [llm-compression](configs/eval_llm_compression.py) evaluation config now! 🔥🔥🔥
|
| 5 |
+
- **\[2024.04.29\]** We report the performance of several famous LLMs on the common benchmarks, welcome to [documentation](https://opencompass.readthedocs.io/en/latest/user_guides/corebench.html) for more information! 🔥🔥🔥.
|
| 6 |
+
- **\[2024.04.26\]** We deprecated the multi-modality evaluation function from OpenCompass; the related implementation has moved to [VLMEvalKit](https://github.com/open-compass/VLMEvalKit), welcome to use it! 🔥🔥🔥.
|
| 7 |
+
- **\[2024.04.26\]** We supported the evaluation of [ArenaHard](configs/eval_subjective_arena_hard.py), welcome to try! 🔥🔥🔥.
|
| 8 |
+
- **\[2024.04.22\]** We supported the evaluation of [LLaMA3](configs/models/hf_llama/hf_llama3_8b.py) and [LLaMA3-Instruct](configs/models/hf_llama/hf_llama3_8b_instruct.py), welcome to try! 🔥🔥🔥
|
| 9 |
+
- **\[2024.02.29\]** We supported the MT-Bench, AlpacaEval and AlignBench; more information can be found [here](https://opencompass.readthedocs.io/en/latest/advanced_guides/subjective_evaluation.html)
|
| 10 |
+
- **\[2024.01.30\]** We release OpenCompass 2.0. Click [CompassKit](https://github.com/open-compass), [CompassHub](https://hub.opencompass.org.cn/home), and [CompassRank](https://rank.opencompass.org.cn/home) for more information !
|
| 11 |
+
- **\[2024.01.17\]** We supported the evaluation of [InternLM2](https://github.com/open-compass/opencompass/blob/main/configs/eval_internlm2_keyset.py) and [InternLM2-Chat](https://github.com/open-compass/opencompass/blob/main/configs/eval_internlm2_chat_keyset.py), InternLM2 showed extremely strong performance in these tests, welcome to try!
|
| 12 |
+
- **\[2024.01.17\]** We supported the needle in a haystack test with multiple needles, more information can be found [here](https://opencompass.readthedocs.io/en/latest/advanced_guides/needleinahaystack_eval.html#id8).
|
| 13 |
+
- **\[2023.12.28\]** We have enabled seamless evaluation of all models developed using [LLaMA2-Accessory](https://github.com/Alpha-VLLM/LLaMA2-Accessory), a powerful toolkit for comprehensive LLM development.
|
| 14 |
+
- **\[2023.12.22\]** We have released [T-Eval](https://github.com/open-compass/T-Eval), a step-by-step evaluation benchmark to gauge your LLMs on tool utilization. Welcome to our [Leaderboard](https://open-compass.github.io/T-Eval/leaderboard.html) for more details!
|
| 15 |
+
- **\[2023.12.10\]** We have released [VLMEvalKit](https://github.com/open-compass/VLMEvalKit), a toolkit for evaluating vision-language models (VLMs), currently support 20+ VLMs and 7 multi-modal benchmarks (including MMBench series).
|
| 16 |
+
- **\[2023.12.10\]** We have supported Mistral AI's MoE LLM: **Mixtral-8x7B-32K**. Welcome to [MixtralKit](https://github.com/open-compass/MixtralKit) for more details about inference and evaluation.
|
| 17 |
+
- **\[2023.11.22\]** We have supported many API-based models, include **Baidu, ByteDance, Huawei, 360**. Welcome to [Models](https://opencompass.readthedocs.io/en/latest/user_guides/models.html) section for more details.
|
| 18 |
+
- **\[2023.11.20\]** Thanks [helloyongyang](https://github.com/helloyongyang) for supporting the evaluation with [LightLLM](https://github.com/ModelTC/lightllm) as backend. Welcome to [Evaluation With LightLLM](https://opencompass.readthedocs.io/en/latest/advanced_guides/evaluation_lightllm.html) for more details.
|
| 19 |
+
- **\[2023.11.13\]** We are delighted to announce the release of OpenCompass v0.1.8. This version enables local loading of evaluation benchmarks, thereby eliminating the need for an internet connection. Please note that with this update, **you must re-download all evaluation datasets** to ensure accurate and up-to-date results.
|
| 20 |
+
- **\[2023.11.06\]** We have supported several API-based models, include **ChatGLM Pro@Zhipu, ABAB-Chat@MiniMax and Xunfei**. Welcome to [Models](https://opencompass.readthedocs.io/en/latest/user_guides/models.html) section for more details.
|
| 21 |
+
- **\[2023.10.24\]** We release a new benchmark for evaluating LLMs’ capabilities of having multi-turn dialogues. Welcome to [BotChat](https://github.com/open-compass/BotChat) for more details.
|
| 22 |
+
- **\[2023.09.26\]** We update the leaderboard with [Qwen](https://github.com/QwenLM/Qwen), one of the best-performing open-source models currently available, welcome to our [homepage](https://opencompass.org.cn) for more details.
|
| 23 |
+
- **\[2023.09.20\]** We update the leaderboard with [InternLM-20B](https://github.com/InternLM/InternLM), welcome to our [homepage](https://opencompass.org.cn) for more details.
|
| 24 |
+
- **\[2023.09.19\]** We update the leaderboard with WeMix-LLaMA2-70B/Phi-1.5-1.3B, welcome to our [homepage](https://opencompass.org.cn) for more details.
|
| 25 |
+
- **\[2023.09.18\]** We have released [long context evaluation guidance](docs/en/advanced_guides/longeval.md).
|
| 26 |
+
- **\[2023.09.08\]** We update the leaderboard with Baichuan-2/Tigerbot-2/Vicuna-v1.5, welcome to our [homepage](https://opencompass.org.cn) for more details.
|
| 27 |
+
- **\[2023.09.06\]** [**Baichuan2**](https://github.com/baichuan-inc/Baichuan2) team adopts OpenCompass to evaluate their models systematically. We deeply appreciate the community's dedication to transparency and reproducibility in LLM evaluation.
|
| 28 |
+
- **\[2023.09.02\]** We have supported the evaluation of [Qwen-VL](https://github.com/QwenLM/Qwen-VL) in OpenCompass.
|
| 29 |
+
- **\[2023.08.25\]** [**TigerBot**](https://github.com/TigerResearch/TigerBot) team adopts OpenCompass to evaluate their models systematically. We deeply appreciate the community's dedication to transparency and reproducibility in LLM evaluation.
|
| 30 |
+
- **\[2023.08.21\]** [**Lagent**](https://github.com/InternLM/lagent) has been released, which is a lightweight framework for building LLM-based agents. We are working with Lagent team to support the evaluation of general tool-use capability, stay tuned!
|
| 31 |
+
- **\[2023.08.18\]** We have supported evaluation for **multi-modality learning**, include **MMBench, SEED-Bench, COCO-Caption, Flickr-30K, OCR-VQA, ScienceQA** and so on. Leaderboard is on the road. Feel free to try multi-modality evaluation with OpenCompass !
|
| 32 |
+
- **\[2023.08.18\]** [Dataset card](https://opencompass.org.cn/dataset-detail/MMLU) is now online. New evaluation benchmarks are welcome to join OpenCompass!
|
| 33 |
+
- **\[2023.08.11\]** [Model comparison](https://opencompass.org.cn/model-compare/GPT-4,ChatGPT,LLaMA-2-70B,LLaMA-65B) is now online. We hope this feature offers deeper insights!
|
| 34 |
+
- **\[2023.08.11\]** We have supported [LEval](https://github.com/OpenLMLab/LEval).
|
| 35 |
+
- **\[2023.08.10\]** OpenCompass is compatible with [LMDeploy](https://github.com/InternLM/lmdeploy). Now you can follow this [instruction](https://opencompass.readthedocs.io/en/latest/advanced_guides/evaluation_lmdeploy.html#) to evaluate the accelerated models provided by **Turbomind**.
|
| 36 |
+
- **\[2023.08.10\]** We have supported [Qwen-7B](https://github.com/QwenLM/Qwen-7B) and [XVERSE-13B](https://github.com/xverse-ai/XVERSE-13B) ! Go to our [leaderboard](https://opencompass.org.cn/leaderboard-llm) for more results! More models are welcome to join OpenCompass.
|
| 37 |
+
- **\[2023.08.09\]** Several new datasets (**CMMLU, TydiQA, SQuAD2.0, DROP**) have been added to our [leaderboard](https://opencompass.org.cn/leaderboard-llm)! More datasets are welcome to join OpenCompass.
|
| 38 |
+
- **\[2023.08.07\]** We have added a [script](tools/eval_mmbench.py) for users to evaluate the inference results of [MMBench](https://opencompass.org.cn/MMBench)-dev.
|
| 39 |
+
- **\[2023.08.05\]** We have supported [GPT-4](https://openai.com/gpt-4)! Go to our [leaderboard](https://opencompass.org.cn/leaderboard-llm) for more results! More models are welcome to join OpenCompass.
|
| 40 |
+
- **\[2023.07.27\]** We have supported [CMMLU](https://github.com/haonan-li/CMMLU)! More datasets are welcome to join OpenCompass.
|
docs/en/prompt/chain_of_thought.md
ADDED
|
@@ -0,0 +1,127 @@
|
| 1 |
+
# Chain of Thought
|
| 2 |
+
|
| 3 |
+
## Background
|
| 4 |
+
|
| 5 |
+
During reasoning, the CoT (Chain of Thought) method is an effective way to help LLMs handle complex questions, for example math problems and relational inference. In OpenCompass, we support multiple types of CoT methods.
|
| 6 |
+
|
| 7 |
+

|
| 8 |
+
|
| 9 |
+
## 1. Zero Shot CoT
|
| 10 |
+
|
| 11 |
+
You can change the `PromptTemplate` of the dataset config by simply adding *Let's think step by step* to create a Zero-Shot CoT prompt for your evaluation:
|
| 12 |
+
|
| 13 |
+
```python
|
| 14 |
+
qa_infer_cfg = dict(
|
| 15 |
+
prompt_template=dict(
|
| 16 |
+
type=PromptTemplate,
|
| 17 |
+
template="Answer the question:\nQ: {question}?\nLet's think step by step:\n"
|
| 18 |
+
),
|
| 19 |
+
retriever=dict(type=ZeroRetriever)
|
| 20 |
+
)
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
## 2. Few Shot CoT
|
| 24 |
+
|
| 25 |
+
Few-shot CoT makes it easier for LLMs to follow your instructions and produce better answers. For few-shot CoT, add your CoT template to `PromptTemplate` as in the following config to create a one-shot prompt:
|
| 26 |
+
|
| 27 |
+
```python
|
| 28 |
+
qa_infer_cfg = dict(
|
| 29 |
+
prompt_template=dict(
|
| 30 |
+
type=PromptTemplate,
|
| 31 |
+
template=
|
| 32 |
+
'''Question: Mark's basketball team scores 25 2 pointers, 8 3 pointers and 10 free throws. Their opponents score double the 2 pointers but half the 3 pointers and free throws. What's the total number of points scored by both teams added together?
|
| 33 |
+
Let's think step by step
|
| 34 |
+
Answer:
|
| 35 |
+
Mark's team scores 25 2 pointers, meaning they scored 25*2= 50 points in 2 pointers.
|
| 36 |
+
His team also scores 6 3 pointers, meaning they scored 8*3= 24 points in 3 pointers
|
| 37 |
+
They scored 10 free throws, and free throws count as one point so they scored 10*1=10 points in free throws.
|
| 38 |
+
All together his team scored 50+24+10= 84 points
|
| 39 |
+
Mark's opponents scored double his team's number of 2 pointers, meaning they scored 50*2=100 points in 2 pointers.
|
| 40 |
+
His opponents scored half his team's number of 3 pointers, meaning they scored 24/2= 12 points in 3 pointers.
|
| 41 |
+
They also scored half Mark's team's points in free throws, meaning they scored 10/2=5 points in free throws.
|
| 42 |
+
All together Mark's opponents scored 100+12+5=117 points
|
| 43 |
+
The total score for the game is both team's scores added together, so it is 84+117=201 points
|
| 44 |
+
The answer is 201
|
| 45 |
+
|
| 46 |
+
Question: {question}\nLet's think step by step:\n{answer}
|
| 47 |
+
'''),
|
| 48 |
+
retriever=dict(type=ZeroRetriever)
|
| 49 |
+
)
|
| 50 |
+
```

## 3. Self-Consistency

The SC (Self-Consistency) method was proposed in [this paper](https://arxiv.org/abs/2203.11171). It samples multiple reasoning paths for a question and takes a majority vote over the answers the LLM generates. The method achieves high accuracy on reasoning tasks, but it consumes more time and resources at inference because of the repeated sampling required for voting. In OpenCompass, you can easily apply the SC method by replacing `GenInferencer` with `SCInferencer` in the dataset configuration and setting the corresponding parameters, for example:

```python
from opencompass.openicl.icl_inferencer import SCInferencer

# This SC gsm8k config can be found at: opencompass/configs/datasets/gsm8k/gsm8k_gen_a3e34a.py
gsm8k_infer_cfg = dict(
    inferencer=dict(
        type=SCInferencer,  # Replace GenInferencer with SCInferencer
        generation_kwargs=dict(do_sample=True, temperature=0.7, top_k=40),  # Set sampling parameters so the model generates varied outputs; currently only works for models loaded from HuggingFace
        infer_type='SC',
        sc_size=SAMPLE_SIZE
    )
)
gsm8k_eval_cfg = dict(sc_size=SAMPLE_SIZE)
```

```{note}
OpenCompass defaults to argmax decoding of the next token. Therefore, if the sampling parameters are not specified, the model's inference results will be identical every time, and evaluating multiple reasoning paths will be ineffective.
```

Here, `SAMPLE_SIZE` is the number of reasoning paths sampled in Self-Consistency; a higher value usually yields better performance. The following figure from the original SC paper shows the relation between the number of reasoning paths and performance on several reasoning tasks:

![image](https://github.com/InternLM/opencompass/assets/28834990/05c7d850-7076-43ca-b165-e6251f9b3001)

The figure shows that, across reasoning tasks, performance tends to improve as the number of reasoning paths increases. For some tasks, however, the gains saturate: beyond a certain point, adding more paths brings no significant improvement. It is therefore worth experimenting on each specific task to find the number of reasoning paths that suits it best.
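
To make the voting step concrete, below is a minimal sketch of the answer aggregation that Self-Consistency performs, assuming each sampled reasoning path has already been reduced to a final answer string (the function and variable names are illustrative, not OpenCompass APIs):

```python
from collections import Counter

def majority_vote(answers: list[str]) -> str:
    """Return the most frequent extracted answer across sampled paths."""
    # Ties are broken by first occurrence, as in Counter.most_common.
    return Counter(answers).most_common(1)[0][0]

# e.g. answers extracted from SAMPLE_SIZE = 5 sampled reasoning paths
sampled_answers = ['201', '201', '198', '201', '117']
print(majority_vote(sampled_answers))  # -> '201'
```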

## 4. Tree-of-Thoughts

In contrast to the conventional CoT approach, which considers only a single reasoning path, Tree-of-Thoughts (ToT) lets the language model explore multiple diverse reasoning paths simultaneously. The model evaluates its intermediate thoughts through self-assessment and makes global choices by looking ahead or backtracking when necessary. Specifically, the process is divided into the following four stages:

**1. Thought Decomposition**

Based on the nature of the problem, break it down into multiple intermediate steps. Each step can be a phrase, an equation, or a writing plan, depending on the problem.

**2. Thought Generation**

Assuming that solving the problem requires k steps, there are two methods to generate reasoning content (see the sketch after this list):

- Independent sampling: for each state, the model independently samples k reasoning candidates from the same CoT prompt, without conditioning on the other candidates.
- Sequential generation: prompts are used to guide generation sequentially, so each reasoning candidate may depend on the previous one.
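
The difference between the two generation modes can be sketched as follows, assuming a generic `generate(prompt, n)` helper that returns `n` completions (the helper and the prompt strings are illustrative assumptions, not OpenCompass APIs):

```python
def generate(prompt: str, n: int) -> list[str]:
    """Hypothetical LLM call that returns n completions for a prompt."""
    raise NotImplementedError  # stand-in for a real model call

def sample_thoughts(state: str, k: int) -> list[str]:
    # Independent sampling: k i.i.d. completions of the same CoT prompt,
    # none of which sees the others.
    return generate(f'{state}\nNext step:', n=k)

def propose_thoughts(state: str, k: int) -> list[str]:
    # Sequential generation: a single completion proposes k distinct next
    # steps in one pass, so later proposals can depend on earlier ones.
    proposal = generate(f'{state}\nPropose {k} different next steps, one per line:', n=1)[0]
    return proposal.strip().split('\n')[:k]
```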

**3. Heuristic Evaluation**

Use heuristic methods to evaluate how much each generated thought contributes to solving the problem. This self-evaluation relies on the model's own feedback, typically by designing prompts that ask the model to score the generated candidates.

**4. Search Algorithm Selection**

Based on how reasoning content is generated and evaluated, select an appropriate search algorithm. For example, breadth-first search (BFS) or depth-first search (DFS) can be used to systematically explore the thought tree, with lookahead and backtracking.
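
As a sketch of how the stages fit together, the following shows a breadth-first search over the thought tree that keeps only the best-scoring states at each level, analogous to the `n_select_sample` parameter in the configuration below; the `generate_thoughts` and `score` callables are assumed to exist and are not OpenCompass APIs:

```python
from typing import Callable

def tot_bfs(root: str, depth: int, breadth: int,
            generate_thoughts: Callable[[str], list[str]],
            score: Callable[[str], float]) -> str:
    """BFS over thoughts, pruning to the `breadth` best states per level."""
    frontier = [root]
    for _ in range(depth):
        # Thought generation: expand every surviving state into candidates.
        candidates = [s + '\n' + t for s in frontier for t in generate_thoughts(s)]
        # Heuristic evaluation + selection: keep the best-scoring states.
        frontier = sorted(candidates, key=score, reverse=True)[:breadth]
    return frontier[0]  # highest-scoring reasoning path found
```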

In OpenCompass, ToT parameters need to be set according to the task. Below is an example configuration for the Game of 24 from the [official paper](https://arxiv.org/pdf/2305.10601.pdf). Currently, ToT inference is supported only with HuggingFace models:

```python
# This ToT Game24 config can be found at: opencompass/configs/datasets/game24/game24_gen_8dfde3.py
from opencompass.datasets import (Game24Dataset, game24_postprocess,
                                  Game24Evaluator, Game24PromptWrapper)
from opencompass.openicl.icl_inferencer import ToTInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever

generation_kwargs = dict(temperature=0.7)

game24_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template='{input}'),  # Pass the input through directly, since the prompt for each step is specified in the wrapper
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=ToTInferencer,  # Replace GenInferencer with ToTInferencer
                    generation_kwargs=generation_kwargs,
                    method_generate='propose',  # Method for generating thoughts: independent sampling ('sample') or sequential generation ('propose')
                    method_evaluate='value',  # Method for evaluating thoughts: voting ('vote') or scoring ('value')
                    method_select='greedy',  # Method for selecting thoughts: greedy ('greedy') or random ('sample')
                    n_evaluate_sample=3,
                    n_select_sample=5,
                    task_wrapper=dict(type=Game24PromptWrapper)  # This wrapper class holds the per-step prompts and the thought generation/evaluation methods; it must be customized for each task
                    ))
```

If you want to use the ToT method on a custom dataset, you will need additional configuration in the `opencompass.datasets.YourDataConfig.py` file to set up the `YourDataPromptWrapper` class, which handles the thought generation and heuristic evaluation steps within the ToT framework. For reasoning tasks similar to the Game of 24, you can refer to the implementation in `opencompass/datasets/game24.py` for guidance.
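
As a rough illustration of what such a wrapper involves, here is a hypothetical skeleton loosely modeled on `Game24PromptWrapper`; the method names, signatures, and score weights below are assumptions for illustration only, so consult `opencompass/datasets/game24.py` for the actual interface:

```python
class YourDataPromptWrapper:  # hypothetical skeleton, not the real interface
    """Holds per-step prompts plus thought generation/evaluation hooks."""

    def propose_prompt_wrap(self, x: str, y: str = '') -> str:
        # Prompt asking the model to propose next thoughts for input x,
        # given the partial reasoning y. (assumed method name/signature)
        return f'Input: {x}\nSteps so far:\n{y}\nPropose possible next steps:'

    def value_prompt_wrap(self, x: str, y: str) -> str:
        # Prompt asking the model to judge a candidate reasoning path.
        return (f'Evaluate whether these steps can solve the input.\n'
                f'Input: {x}\nSteps:\n{y}\nJudge (sure/likely/impossible):')

    def value_outputs_unwrap(self, x: str, y: str,
                             value_outputs: list[str]) -> float:
        # Turn the model's judgments into a numeric heuristic score.
        weights = {'sure': 20.0, 'likely': 1.0, 'impossible': 0.001}
        return sum(weights.get(o.strip().split('\n')[-1].lower(), 0.0)
                   for o in value_outputs)
```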
|