SAMTok
Collection
A unified mask-token interface for MLLMs.
•
5 items
•
Updated
Install a version of `transformers` that supports Qwen3-VL.
import re

import numpy as np
import torch
def extract_mt_token_ids_v1(text):
    """Return every ``<|mt_NNNN|>`` mask-token id in *text*, in order of appearance.

    Only four-digit ids are recognized. Marker tokens such as ``<|mt_start|>``
    and ``<|mt_end|>`` never match because their suffix is not numeric.

    Args:
        text: Decoded model output.

    Returns:
        A list of ints, one per matched mask token.
    """
    token_re = re.compile(r"<\|mt_(\d{4})\|>")
    return [int(match.group(1)) for match in token_re.finditer(text)]
def extract_mt_token_ids_v2(text):
    """Return the code ids of every complete two-code mask group, flattened.

    Only groups spelled exactly as
    ``<|mt_start|><|mt_AAAA|><|mt_BBBB|><|mt_end|>`` (two four-digit codes)
    are recognized. Any other mask-token sequence is ignored.

    Args:
        text: Decoded model output.

    Returns:
        ``[A1, B1, A2, B2, ...]`` as ints, in order of appearance.
    """
    group_re = re.compile(r'<\|mt_start\|><\|mt_(\d{4})\|><\|mt_(\d{4})\|><\|mt_end\|>')
    return [int(code) for pair in group_re.findall(text) for code in pair]
def find_first_index(arr, value):
    """Return the position of the first element of *arr* equal to *value*.

    Args:
        arr: NumPy array to search. It is intended to be 1-D; for N-D input the
            result is the axis-0 index of the first match in row-major order.
        value: Value compared against each element with ``==``.

    Returns:
        The index as a NumPy integer, or the Python int ``-1`` when nothing matches.
    """
    hits = np.nonzero(arr == value)[0]
    if hits.size == 0:
        return -1
    return hits[0]
def fix_mt_format_comprehensive(text):
    """Normalize every ``<|mt_start|>`` mask group to exactly two code tokens.

    The downstream extractor only accepts groups of the form
    ``<|mt_start|><|mt_A|><|mt_B|><|mt_end|>``. Each ``<|mt_start|>`` that is
    followed by one or more ``<|mt_<digits>|>`` tokens (and optionally
    ``<|mt_end|>``) is rewritten as follows:

    * more than two codes: keep the first two;
    * exactly one code: pad with the placeholder ``<|mt_9999|>``;
    * missing ``<|mt_end|>``: append one.

    Unlike chained per-case regexes, this also repairs a short or over-long
    group that runs directly into the next ``<|mt_start|>`` or into plain text.
    A ``<|mt_start|>`` with no code token after it is left untouched.

    Args:
        text: Decoded model output.

    Returns:
        *text* with malformed mask groups repaired; everything else is unchanged.
    """
    group_pattern = re.compile(r'<\|mt_start\|>((?:<\|mt_\d+\|>)+)(?:<\|mt_end\|>)?')
    code_pattern = re.compile(r'<\|mt_\d+\|>')

    def _normalize_group(match):
        # Keep at most the first two code tokens, exactly as they were written.
        codes = code_pattern.findall(match.group(1))[:2]
        # Pad a one-code group with the placeholder second code.
        while len(codes) < 2:
            codes.append('<|mt_9999|>')
        return '<|mt_start|>' + ''.join(codes) + '<|mt_end|>'

    return group_pattern.sub(_normalize_group, text)
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
from projects.samtok.models import DirectResize, VQ_SAM2, VQ_SAM2Config, SAM2Config
# build VLM
# Load the processor from the same checkpoint as the model. The tokenizer used to
# decode the generated mask tokens then always matches the weights being run.
VLM_REPO = "zhouyik/Qwen3-VL-8B-SAMTok"
model = Qwen3VLForConditionalGeneration.from_pretrained(
    VLM_REPO, torch_dtype="auto"
).cuda().eval()
processor = AutoProcessor.from_pretrained(VLM_REPO)
# build SAMTok
# Each mask is encoded as CODEBOOK_DEPTH codes, each one indexing a codebook of
# CODEBOOK_SIZE entries.
CODEBOOK_SIZE = 256
CODEBOOK_DEPTH = 2
# NOTE(review): torch.load reads a local file, so this directory must exist on disk
# (presumably a local download of the Hub repo of the same name). Confirm that
# SAM2Config's ckpt_path is also a local path.
SAMTOK_DIR = "zhouyik/Qwen3-VL-4B-SAMTok"
sam2_config = SAM2Config(
    ckpt_path=f"{SAMTOK_DIR}/sam2.1_hiera_large.pt",
)
vq_sam2_config = VQ_SAM2Config(
    sam2_config=sam2_config,
    codebook_size=CODEBOOK_SIZE,
    codebook_depth=CODEBOOK_DEPTH,
    shared_codebook=False,
    latent_dim=256,
)
vq_sam2 = VQ_SAM2(vq_sam2_config).cuda().eval()
# weights_only=True restricts unpickling to tensors and plain containers. A
# downloaded checkpoint therefore cannot run arbitrary code while loading.
state = torch.load(
    f"{SAMTOK_DIR}/mask_tokenizer_256x2.pth", map_location="cpu", weights_only=True
)
vq_sam2.load_state_dict(state)
sam2_image_processor = DirectResize(1024)
# message
# The image is opened locally for two reasons: its original size is used to resize
# predicted masks back, and it is the pixel source for SAMTok decoding. The chat
# message itself refers to the file by path.
# NOTE(review): `Image` is never imported in this file. It presumably comes from
# `from PIL import Image`, which must be present for this script to run.
image_path = "figs/totoro.jpg"
question = (
    "Could you please give me a detail description of the image? "
    "Please respond with interleaved segmentation masks for the corresponding parts of the answer."
)
image = Image.open(image_path).convert('RGB')
ori_width, ori_height = image.size
image_content = {"type": "image", "image": image_path}
text_content = {"type": "text", "text": question}
messages = [{"role": "user", "content": [image_content, text_content]}]
# VLM inference
# Tokenize the chat with the generation prompt appended, then move every tensor
# to the model's device.
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)
# Greedy decoding (do_sample=False); top_p=1.0 has no effect without sampling.
generated_ids = model.generate(**inputs, max_new_tokens=512, do_sample=False, top_p=1.0)
# Each returned sequence starts with its prompt tokens; keep only the continuation.
generated_ids_trimmed = []
for prompt_ids, full_ids in zip(inputs.input_ids, generated_ids):
    generated_ids_trimmed.append(full_ids[len(prompt_ids):])
# NOTE(review): skip_special_tokens=True assumes the <|mt_...|> mask tokens are not
# registered as special tokens. If they are, they would be stripped here before
# mask decoding. Confirm.
output_text = processor.batch_decode(
    generated_ids_trimmed,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)
# decode mask
# Extract the raw mask-token ids. A count that is not a multiple of CODEBOOK_DEPTH
# means at least one <|mt_start|>...<|mt_end|> group is malformed. In that case,
# repair the groups and re-extract from complete groups only.
# NOTE(review): an evenly divisible count can still hide misaligned groups
# (e.g. one short group plus one long group); those are paired as-is.
quant_ids = extract_mt_token_ids_v1(output_text[0])
if len(quant_ids) % CODEBOOK_DEPTH != 0:
    output_text = [fix_mt_format_comprehensive(output_text[0])]
    quant_ids = extract_mt_token_ids_v2(output_text[0])

# In the token ids, the k-th code of each group is offset by k * CODEBOOK_SIZE.
# Subtracting that offset gives codebook-local indices.
remap_quant_ids = []
tags = []
for bs_id in range(len(quant_ids) // CODEBOOK_DEPTH):
    chunk_quant_ids = quant_ids[bs_id * CODEBOOK_DEPTH:(bs_id + 1) * CODEBOOK_DEPTH]
    codes = [quant_id - book_id * CODEBOOK_SIZE for book_id, quant_id in enumerate(chunk_quant_ids)]
    if not 0 <= codes[0] < CODEBOOK_SIZE:
        # The first code has no fallback value. Skip the group rather than pass an
        # index outside the first codebook to the decoder.
        continue
    # Later codes outside their codebook are replaced by -1. This includes the
    # <|mt_9999|> padding inserted by fix_mt_format_comprehensive.
    # NOTE(review): assumes forward_with_codes treats -1 as "no code"; confirm.
    codes[1:] = [code if 0 <= code < CODEBOOK_SIZE else -1 for code in codes[1:]]
    tags.append("-".join(str(quant_id) for quant_id in chunk_quant_ids))
    remap_quant_ids.append(codes)

batch_size = len(remap_quant_ids)
if batch_size == 0:
    # No mask tokens to decode: do not run SAMTok on an empty batch.
    _pred_masks = np.zeros((0, ori_height, ori_width), dtype=np.uint8)
else:
    sam2_image = sam2_image_processor.apply_image(np.array(image))
    sam2_pixel_values = torch.from_numpy(sam2_image).permute(2, 0, 1).contiguous()
    sam2_pixel_values = sam2_pixel_values.unsqueeze(0).to(vq_sam2.dtype).to(vq_sam2.device)
    # One copy of the image per mask to decode.
    sam2_pixel_values = sam2_pixel_values.repeat(batch_size, 1, 1, 1)
    quant_ids = torch.tensor(remap_quant_ids, dtype=torch.long, device=vq_sam2.device)
    with torch.no_grad():
        _pred_masks = vq_sam2.forward_with_codes(sam2_pixel_values, quant_ids)
    # Resize predictions back to the original image resolution, then binarize.
    _pred_masks = torch.nn.functional.interpolate(_pred_masks, size=(ori_height, ori_width), mode='bilinear')
    _pred_masks = _pred_masks > 0.5
    _pred_masks = _pred_masks[:, 0, :, :].cpu().numpy().astype(np.uint8)
# Map each "<first id>-<second id>" tag to its (H, W) uint8 mask.
text_token_2d_mask_mapping = dict(zip(tags, _pred_masks))