Model Card for Qwen3-4B-Instruct-2507-Segmenter
This is the semantic segmenter introduced in the paper *Towards Generalization of Block Attention via Automatic Segmentation and Block Distillation*. It was trained on SemanticSeg, the semantic segmentation dataset proposed in that paper.
How to use
Requires `transformers` >= 4.57.3.
1. Insert the candidate cut points (`<cut>` markers) into the input text.
2. Feed the output of step 1 to the segmenter, adjusting the recursion depth and threshold values if needed.

One possible implementation is shown below:
def insert_marker(txt: str, sep_pattern: "str | None" = None, cut_marker: str = "<cut {}>", marker_pos: 'Literal["left", "right"]' = "right") -> str:
    """Insert numbered candidate cut markers into ``txt``.

    Every match of ``sep_pattern`` becomes a candidate cut point. With
    ``marker_pos="right"`` the marker is placed right after the separator,
    with ``"left"`` right before it. Marker ``0`` is always prepended. One more
    marker is appended after any text that follows the last separator, or
    after the whole text when there is no separator at all.

    Args:
        txt: Text to annotate.
        sep_pattern: Regular expression (string or compiled pattern) matching
            the separators. ``None`` means "no separators": the whole text
            becomes a single candidate block.
        cut_marker: ``str.format`` template for a marker; it receives the
            marker's 0-based index.
        marker_pos: ``"right"`` or ``"left"`` of each separator.

    Returns:
        ``txt`` with markers ``cut_marker.format(0)``, ``cut_marker.format(1)``,
        ... inserted in text order.

    Raises:
        ValueError: if ``marker_pos`` is neither ``"left"`` nor ``"right"``.

    >>> insert_marker("a b", " ")
    '<cut 0>a <cut 1>b<cut 2>'
    """
    import re

    if marker_pos not in ("left", "right"):
        raise ValueError(f"marker_pos must be 'left' or 'right', got {marker_pos!r}")

    matches = [] if sep_pattern is None else list(re.finditer(sep_pattern, txt))
    if not matches:
        # No separator: the whole text is one candidate block.
        return cut_marker.format(0) + txt + cut_marker.format(1)

    pieces = []
    start = 0  # end of the previous separator
    for idx, m in enumerate(matches, start=1):
        if marker_pos == "right":
            pieces.append(txt[start:m.end()] + cut_marker.format(idx))
        else:
            pieces.append(txt[start:m.start()] + cut_marker.format(idx) + m.group(0))
        start = m.end()
    # Close any text that follows the last separator with one more marker.
    if start < len(txt):
        pieces.append(txt[start:] + cut_marker.format(len(matches) + 1))
    return cut_marker.format(0) + "".join(pieces)
def shift_value(states, cut_point_mask):
    """Give every candidate cut point the values of the *next* candidate.

    The segmenter uses the output at candidate ``k + 1`` to predict whether
    candidate ``k`` is a cut, so each candidate position must read its
    successor's values before thresholding. The last candidate of a row has no
    successor. It is filled with ``[1, -inf]``, which gives it a class-1
    softmax probability of 0.

    Args:
        states: Tensor of shape ``[bsz, q_len, ...]``. The ``[1, -inf]`` filler
            assumes the trailing dimension has size 2 (two-class logits).
            NOTE(review): confirm for other heads.
        cut_point_mask: Boolean tensor of shape ``[bsz, q_len]`` marking the
            candidate positions.

    Returns:
        A new tensor; ``states`` itself is not modified. Non-candidate
        positions are copied unchanged. Rows without any candidate are
        returned unchanged instead of raising.
    """
    import torch

    shifted = states.clone()
    rows, cols = torch.nonzero(cut_point_mask, as_tuple=True)
    filler = torch.tensor([1, float("-inf")], device=states.device)
    for b in range(states.shape[0]):
        cand = cols[rows == b]  # candidate positions of row b, ascending
        if cand.numel() == 0:
            continue  # nothing to shift in this row
        # Every read comes from the untouched `states`, so one vectorised
        # assignment equals the pairwise loop.
        shifted[b, cand[:-1], ...] = states[b, cand[1:], ...]
        shifted[b, cand[-1], ...] = filler
    return shifted
def get_cutpoint_label(input_ids: list,
                       cut_token_ids: "List[int]",
                       chunk_bound=None,
                       window_size: int = 1,
                       ):
    """Build per-token labels for the candidate cut points of ``input_ids``.

    Candidate cut points are:
    - the start positions of ``window_size`` consecutive occurrences of
      ``cut_token_ids[0]``;
    - when more than one id is given, also every single occurrence of
      ``cut_token_ids[-1]``.

    Candidates are numbered 0, 1, 2, ... in text order, so candidate ``N`` can
    be matched against a cut number ``N`` taken from ``chunk_bound``.

    Args:
        input_ids: Token ids of one sequence.
        cut_token_ids: Token id(s) of the cut marker (see above).
        chunk_bound: Optional sequence of chunks. The last element of each
            chunk must be a string whose last integer is the number of the cut
            closing that chunk (e.g. ``"<cut 5>"``). The final chunk is skipped
            because its end is not a cut. When ``None``, every candidate is
            labelled 0.
        window_size: Number of consecutive ``cut_token_ids[0]`` tokens forming
            one candidate; must be >= 1.

    Returns:
        A list as long as ``input_ids``: 1 for candidates that close a chunk,
        0 for the other candidates, -100 for non-candidate positions.

    Raises:
        ValueError: if ``window_size < 1``.
    """
    import re

    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    ids = list(input_ids)
    n_tokens = len(ids)

    def _starts(token_id, width):
        # Start index of every run of `width` consecutive `token_id`s (runs may overlap).
        target = [token_id] * width
        return [p for p in range(n_tokens - width + 1) if ids[p:p + width] == target]

    candidates = _starts(cut_token_ids[0], window_size)
    if len(cut_token_ids) > 1:
        # Merge the two searches in text order (dropping duplicates) so the
        # i-th candidate really is the i-th cut; plain concatenation broke this.
        candidates = sorted(set(candidates).union(_starts(cut_token_ids[-1], 1)))

    closing_cuts = set()
    if chunk_bound is not None:
        for chunk in chunk_bound[:-1]:
            closing_cuts.add(int(re.findall(r"\d+", chunk[-1])[-1]))

    labels = [-100] * n_tokens
    for i, pos in enumerate(candidates):
        labels[pos] = 1 if i in closing_cuts else 0
    return labels
def model_cut(txt, insert_pattern, model, tokenizer, return_txt=True, depth=1, threshold=None):
    """Split ``txt`` into semantic blocks with the segmenter.

    Pipeline:
    1. Insert candidate markers at every match of ``insert_pattern``
       (``insert_marker``) and collapse them to the plain ``<cut>`` token.
    2. Score the candidates with ``model`` and keep those whose shifted
       class-1 probability reaches the threshold.
    3. With ``depth > 1``, repeat the scoring inside every segment from the
       previous level that still holds more than two candidates. Level ``d``
       uses ``threshold[d]``.

    Args:
        txt: Input text.
        insert_pattern: Regex passed to ``insert_marker`` as ``sep_pattern``.
        model: The segmenter. ``model(...).logits`` are expected to be
            two-class logits per token. NOTE(review): inferred from the
            ``[1, -inf]`` filler in ``shift_value``.
        tokenizer: Matching tokenizer. The cut token id is read from
            ``tokenizer.additional_special_tokens_ids[-1]``, which assumes
            ``<cut>`` is the last additional special token. TODO confirm.
        return_txt: When true, also decode the blocks and group the candidate
            probabilities per block.
        depth: Number of recursion levels; 0 disables segmentation.
        threshold: One threshold per level, a single float used at every
            level, or ``None`` for 0.4 at every level.

    Returns:
        Dict with these keys:
        - ``"blocks"``: decoded block texts with markers removed; empty blocks
          are dropped.
        - ``"cut_prob"``: per-block lists of candidate probabilities.
        - ``"chunk_id"``: newline-joined ``"<cut a> --- <cut b>"`` spans.
        - ``"parallel degree"``: ``len(cut positions) + 1``.
        - ``"cut positions"``: exclusive token end index of each block; the
          last one is the sequence length.
        - ``"threshold"``: the thresholds used.
        - ``"length"``: number of tokens.

        The three text-related fields are empty when ``return_txt`` is false.

    Raises:
        ValueError: if ``depth`` is negative, or if fewer thresholds than
            recursion levels are given.
    """
    import re

    import torch
    import torch.nn.functional as F

    if depth < 0:
        raise ValueError(f"depth must be >= 0, got {depth}")
    # A fresh list per call: the old mutable default was handed back to callers.
    if threshold is None:
        threshold = [0.40] * depth
    elif isinstance(threshold, (int, float)):
        threshold = [float(threshold)] * depth
    if len(threshold) < depth:
        raise ValueError(
            f"need one threshold per recursion level: depth={depth}, "
            f"got {len(threshold)} threshold value(s)"
        )

    # 1. Candidate cut points: numbered markers -> the single "<cut>" token fed to the model.
    txt_marker = insert_marker(txt=txt, sep_pattern=insert_pattern, marker_pos="right")
    input_txt = re.sub(r"<cut \d+>", "<cut>", txt_marker)
    inputs = tokenizer(input_txt, return_tensors="pt", truncation=True).to(model.device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    seq_len = input_ids.shape[-1]

    cut_labels = get_cutpoint_label(
        input_ids=input_ids[0].tolist(),  # [0], not squeeze(): stays a list for 1-token inputs
        cut_token_ids=[tokenizer.additional_special_tokens_ids[-1]],
        window_size=1,
    )
    cut_point_mask = (torch.tensor(cut_labels, device=model.device) != -100)[None, :]

    # 2. Recursive scoring.
    #    - seg_ends[k] is the exclusive end of segment k; segment k starts at
    #      seg_ends[k-1], or at 0 for the first segment.
    #    - seg_prob is the previous level's class-1 probability row
    #      [1, seq_len], or None before the first level.
    model.eval()
    seg_ends = [seq_len]
    seg_prob = None
    for d in range(depth):
        new_ends = []
        pieces = []  # consecutive pieces of this level's probability row
        start = 0
        for end in seg_ends:
            # The window also contains the segment's closing cut token (`end`).
            window = slice(start, end + 1)
            win_mask = cut_point_mask[..., window]
            if win_mask.sum() > 2:
                with torch.no_grad():
                    outputs = model(input_ids=input_ids[..., window],
                                    attention_mask=attention_mask[..., window])
                # The next candidate predicts the current one, so shift before thresholding.
                shifted_logits = shift_value(states=outputs.logits, cut_point_mask=win_mask)
                probs = F.softmax(shifted_logits, dim=-1)[..., 1]
                hits = torch.nonzero((probs >= threshold[d]) & win_mask, as_tuple=True)[-1]
                # Window index 0 is the segment's opening boundary. It is already a
                # cut, or the start of the text; re-selecting it would duplicate a
                # cut position and produce an empty block.
                hits = hits[hits > 0].sort().values
                new_ends.extend((hits + start).tolist())
                if seg_prob is None:
                    pieces.append(probs)
                else:
                    # Keep the previous probability of the opening boundary. Drop the
                    # closing token, which opens (and is reported by) the next segment.
                    inner = probs[:, 1:-1] if end < seq_len else probs[:, 1:]
                    pieces.extend([seg_prob[:, start].unsqueeze(-1), inner])
            elif seg_prob is None:
                # Too few candidates to split at the first level: report zero
                # probability so the row still covers every token.
                pieces.append(torch.zeros((input_ids.shape[0], end - start), device=model.device))
            else:
                pieces.append(seg_prob[..., start:end])
            new_ends.append(end)
            start = end
        seg_ends = new_ends
        seg_prob = torch.cat(pieces, dim=-1)

    if seg_prob is None:  # depth == 0: nothing was scored
        seg_prob = torch.zeros(input_ids.shape, device=model.device)
    cut_pos = seg_ends
    cut_prob = seg_prob[cut_point_mask].tolist()
    p_degree = len(cut_pos) + 1

    # 3. Optional decoding of the blocks.
    blocks = []
    chunk_id = []
    block_probs = []
    if return_txt:
        start = 0
        n_seen = 0  # markers consumed so far = index of this block's first marker
        for end in cut_pos:
            block_txt = tokenizer.batch_decode(input_ids[:, start:end])[0]
            n_markers = block_txt.count("<cut>")
            chunk_id.append("<cut {}> --- <cut {}>".format(n_seen, n_seen + n_markers))
            block_probs.append(cut_prob[n_seen:n_seen + n_markers])
            block_txt = block_txt.replace("<cut>", "")
            if block_txt.strip():
                blocks.append(block_txt)
            start = end
            n_seen += n_markers

    return {"blocks": blocks,
            "cut_prob": block_probs,
            "chunk_id": "\n".join(chunk_id),
            "parallel degree": p_degree,
            "cut positions": cut_pos,
            "threshold": threshold,
            "length": seq_len}
You can customize the recursion depth and the threshold values through the `depth` and `threshold` arguments of `model_cut` to control the granularity of the final segmentation.
The segmenter was trained with a threshold of 0.5, but we find that it also performs well with thresholds between 0.2 and 0.5. We recommend pairing each recursion level with its own threshold value: `threshold[d]` is applied at recursion level `d`.
Typical combinations:
- Recursion depth 1, threshold `[0.4]` (e.g. LongbenchSeg, LoCoMoSeg);
- Recursion depth 2, threshold `[0.2, 0.4]` (e.g. ChatQA2Seg) or `[0.4, 0.4]`.
Note: remember to shift the final logits of the candidate cut points (as `shift_value` does), because the segmenter is trained to use the next candidate point to predict the current one.
An example:
from transformers import AutoTokenizer, AutoModelForCausalLM
# Sample input to segment: a snippet of Python test code. Its content is plain
# data for the segmenter, so it is kept exactly as-is (indentation was not
# preserved in this card).
txt = '''
import numpy as np
from numpy.testing import assert_array_equal, assert_array_almost_equal
import scipy.stats.distributions as distrs
from scipy.stats.kde import gaussian_kde
from scipy.integrate import quad
import pytest
def augment_grid(x, n_inner_points):
test_arr = [
np.linspace(x[i], x[i + 1], n_inner_points + 1, endpoint=False)
for i in np.arange(len(x) - 1)
]
test_arr.append([x[-1]])
return np.concatenate(test_arr)
def circle_fun(x, low, high):
x = np.array(x)
center = 0.5 * (high + low)
radius = 0.5 * (high - low)
res = np.zeros_like(x)
center_dist = np.abs(x - center)
is_in = center_dist <= radius
res[is_in] = np.sqrt(radius ** 2 - center_dist[is_in] ** 2)
return res
class TestCont:
"""Regression tests for `Cont` class"""
def test_init_errors(self):
def check_one_input(def_args, var):
with pytest.raises(TypeError, match=f"`{var}`.*numpy array"):
def_args[var] = {"a": None}
Cont(**def_args)
with pytest.raises(TypeError, match=f"`{var}`.*float"):
def_args[var] = ["a", "a"]
Cont(**def_args)
with pytest.raises(TypeError, match=f"`{var}`.*finite values"):
def_args[var] = [0, np.nan]
Cont(**def_args)
with pytest.raises(TypeError, match=f"`{var}`.*finite values"):
def_args[var] = [0, np.inf]
Cont(**def_args)
with pytest.raises(ValueError, match=f"`{var}`.*1d array"):
def_args[var] = [[0, 1]]
Cont(**def_args)
check_one_input({"y": [1, 1]}, "x")
check_one_input({"x": [0, 1]}, "y")
with pytest.raises(ValueError, match="[Ll]engths.*match"):
Cont([0, 1], [1, 1, 1])
with pytest.raises(ValueError, match="two"):
Cont([1], [1])
with pytest.warns(UserWarning, match="`x`.*not sorted.*`x` and `y`"):
rv = Cont([1, 0], [0, 2])
rv_ref = Cont([0, 1], [2, 0])
_test_equal_rand(rv, rv_ref)
with pytest.raises(ValueError, match="`y`.*negative"):
Cont([0, 1], [1, -1])
with pytest.raises(ValueError, match="`y`.*no positive"):
Cont([0, 1], [0, 0])
def test_init(self):
x_ref = np.array([0, 1, 2])
y_ref = np.array([0, 1, 0])
rv_ref = Cont(x_ref, y_ref)
class TestFromRVAccuracy:
"""Accuracy of `Cont.from_rv()`"""
# Output of `from_rv()` should have CDF that differs from original CDF by
# no more than `thres`
@pytest.mark.slow
@pytest.mark.parametrize(
"distr_dict,thres",
[
(DISTRIBUTIONS_COMMON, 1e-4),
(DISTRIBUTIONS_INF_DENSITY, 1e-3),
(DISTRIBUTIONS_HEAVY_TAILS, 5e-3),
],
)
def test_cdf_maxerror(self, distr_dict, thres):
test_passed = {
name: TestFromRVAccuracy.from_rv_cdf_maxerror(distr) <= thres
for name, distr in distr_dict.items()
}
assert all(test_passed.values())
class TestFromSampleAccuracy:
"""Accuracy of `Cont.from_sample()`"""
# Output of `from_sample()` should differ from original density estimate by
# no more than `thres` (with default density estimator)
@pytest.mark.slow
@pytest.mark.parametrize(
"distr_dict,thres",
[
(DISTRIBUTIONS_COMMON, 1e-4),
(DISTRIBUTIONS_INF_DENSITY, 1.5e-4),
(DISTRIBUTIONS_HEAVY_TAILS, 1e-4),
],
)
def test_close_cdf(self, distr_dict, thres):
rng = np.random.default_rng(101)
test_passed = {
name: TestFromSampleAccuracy.simulated_cdf_error(distr, rng) <= thres
for name, distr in distr_dict.items()
}
'''
import torch  # needed for dtype=torch.bfloat16; the example otherwise imports only from transformers

# Hugging Face repository of the segmenter checkpoint.
model_id = "Syon-Li/Qwen3-4B-Instruct-2507-Segmenter"
# trust_remote_code=True lets transformers run the modelling code stored in the model repository.
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
)
# Candidate cut points after every run of newlines; one recursion level with threshold 0.4.
results = model_cut(txt=txt, model=model, tokenizer=tokenizer, insert_pattern=r"\n{1,}", depth=1, threshold=[0.40])
print(results)
The insert pattern does not have to be `r"\n{1,}"`; you can customize it to suit your input text.
If you find this useful, please cite:
@article{li2026towards,
title={Towards Generalization of Block Attention via Automatic Segmentation and Block Distillation},
author={Li, Shuaiyi and Zhang, Zhisong and Wang, Yan and Zhu, Lei and Ma, Dongyang and Deng, Chenlong and Deng, Yang and Lam, Wai},
journal={arXiv preprint arXiv:2605.15913},
year={2026}
}
- Downloads last month: 461
Model tree for Syon-Li/Qwen3-4B-Instruct-2507-Segmenter
Base model
Qwen/Qwen3-4B-Instruct-2507