Quickstart

Install a version of Transformers that supports Qwen3-VL.

Using 🤗 Transformers to Chat

import re
import numpy as np

def extract_mt_token_ids_v1(text):
    """Return every 4-digit <|mt_NNNN|> token id found in *text*, in order."""
    return [int(token_id) for token_id in re.findall(r"<\|mt_(\d{4})\|>", text)]

def extract_mt_token_ids_v2(text):
    """Return token ids only from well-formed two-token <|mt_start|>...<|mt_end|> groups.

    Each matched group contributes its two 4-digit ids, flattened in order.
    """
    group_re = re.compile(
        r'<\|mt_start\|><\|mt_(\d{4})\|><\|mt_(\d{4})\|><\|mt_end\|>'
    )
    ids = []
    for first, second in group_re.findall(text):
        ids.extend((int(first), int(second)))
    return ids

def find_first_index(arr, value):
    """Return the index of the first element of *arr* equal to *value*, or -1 if absent."""
    (matches,) = np.where(arr == value)
    if matches.size:
        return matches[0]
    return -1

def fix_mt_format_comprehensive(text):
    """Repair malformed <|mt_start|>...<|mt_end|> groups so each holds exactly two tokens.

    Three substitutions are applied, in this order:
      1. groups with more than two tokens are truncated to their first two;
      2. closed groups with a single token gain a <|mt_9999|> placeholder;
      3. a lone-token group missing its <|mt_end|> gains the placeholder
         and the closing tag.
    """
    repairs = (
        (r'(<\|mt_start\|>)(<\|mt_\d+\|>)(<\|mt_\d+\|>)(?:<\|mt_\d+\|>)+<\|mt_end\|>',
         r'\1\2\3<|mt_end|>'),
        (r'(<\|mt_start\|>)(<\|mt_\d+\|>)(<\|mt_end\|>)',
         r'\1\2<|mt_9999|><|mt_end|>'),
        (r'(<\|mt_start\|>)(<\|mt_\d+\|>)(?!<\|mt_)',
         r'\1\2<|mt_9999|><|mt_end|>'),
    )
    for pattern, replacement in repairs:
        text = re.sub(pattern, replacement, text)
    return text

import torch  # was missing: needed for torch.load / LongTensor / no_grad / interpolate below
from PIL import Image  # was missing: needed for Image.open below

from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
from projects.samtok.models import DirectResize, VQ_SAM2, VQ_SAM2Config, SAM2Config

# build VLM
model = Qwen3VLForConditionalGeneration.from_pretrained(
    "zhouyik/Qwen3-VL-8B-SAMTok", torch_dtype="auto"
).cuda().eval()
processor = AutoProcessor.from_pretrained("zhouyik/Qwen3-VL-8B-SAMTok")

# build SAMTok (mask tokenizer): 2 codebooks of 256 entries each, so every
# mask is represented by a pair of codes
CODEBOOK_SIZE = 256
CODEBOOK_DEPTH = 2
sam2_config = SAM2Config(
    ckpt_path="zhouyik/Qwen3-VL-8B-SAMTok/sam2.1_hiera_large.pt",
)
vq_sam2_config = VQ_SAM2Config(
    sam2_config=sam2_config,
    codebook_size=CODEBOOK_SIZE,
    codebook_depth=CODEBOOK_DEPTH,
    shared_codebook=False,
    latent_dim=256,
)
vq_sam2 = VQ_SAM2(vq_sam2_config).cuda().eval()
state = torch.load("zhouyik/Qwen3-VL-8B-SAMTok/mask_tokenizer_256x2.pth", map_location="cpu")
vq_sam2.load_state_dict(state)
sam2_image_processor = DirectResize(1024)

# message
image_path = "figs/totoro.jpg"
question = "Could you please give me a detail description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer."
image = Image.open(image_path).convert('RGB')
# original resolution, kept so predicted masks can be resized back to it
ori_width, ori_height = image.size
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": image_path,
            },
            {"type": "text", "text": question},
        ],
    }
]

# VLM inference
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt"
)
inputs = inputs.to(model.device)

generated_ids = model.generate(
    **inputs, 
    max_new_tokens=512,
    do_sample=False,
    top_p=1.0,
)
# strip the prompt tokens so only the newly generated continuation is decoded
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
# NOTE(review): skip_special_tokens=True assumes the <|mt_*|> markers are NOT
# registered as special tokens (they must survive decoding for the extraction
# below to find anything) -- verify against the tokenizer config.
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)

# decode mask: each mask is encoded as a pair of <|mt_NNNN|> ids
quant_ids = extract_mt_token_ids_v1(output_text[0])
if len(quant_ids) % CODEBOOK_DEPTH != 0:
    # malformed output: repair the token groups, then keep only well-formed pairs
    output_text = [fix_mt_format_comprehensive(output_text[0])]
    quant_ids = extract_mt_token_ids_v2(output_text[0])

batch_size = len(quant_ids) // CODEBOOK_DEPTH
remap_quant_ids = []
tags = []
for bs_id in range(batch_size):
    chunk_quant_ids = quant_ids[bs_id*CODEBOOK_DEPTH:(bs_id+1)*CODEBOOK_DEPTH]
    tags.append(f"{chunk_quant_ids[0]}-{chunk_quant_ids[1]}")
    # book k stores ids offset by k*CODEBOOK_SIZE; remap back into [0, CODEBOOK_SIZE)
    remap_chunk_quant_ids = [quant_id - book_id*CODEBOOK_SIZE for book_id, quant_id in enumerate(chunk_quant_ids)]
    code1 = remap_chunk_quant_ids[0]
    code2 = remap_chunk_quant_ids[1]
    if not (code2 >= 0 and code2 < CODEBOOK_SIZE):
        code2 = -1  # out-of-range second code: mark invalid instead of crashing downstream
    remap_chunk_quant_ids_error_handle = [code1, code2]
    remap_quant_ids.append(remap_chunk_quant_ids_error_handle)

batch_size = len(remap_quant_ids)
sam2_image = np.array(image)
sam2_image = sam2_image_processor.apply_image(sam2_image)
sam2_pixel_values = torch.from_numpy(sam2_image).permute(2, 0, 1).contiguous()  # HWC -> CHW
sam2_pixel_values = sam2_pixel_values.unsqueeze(0).to(vq_sam2.dtype).to(vq_sam2.device)
# one copy of the image per decoded mask
sam2_pixel_values = sam2_pixel_values.repeat(batch_size, 1, 1, 1)

quant_ids = torch.LongTensor(remap_quant_ids).to(vq_sam2.device)

with torch.no_grad():
    _pred_masks = vq_sam2.forward_with_codes(sam2_pixel_values, quant_ids)
# resize to the original image resolution and binarize
_pred_masks = torch.nn.functional.interpolate(_pred_masks, size=(ori_height, ori_width), mode='bilinear')
_pred_masks = _pred_masks > 0.5
_pred_masks = _pred_masks[:, 0, :, :].cpu().numpy().astype(np.uint8)
# map each "code1-code2" tag to its binary 2D mask at original resolution
text_token_2d_mask_mapping = {tag: _pred_mask for tag, _pred_mask in zip(tags, _pred_masks)}
Downloads last month
175
Safetensors
Model size
9B params
Tensor type
BF16
Β·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Spaces using zhouyik/Qwen3-VL-8B-SAMTok 2

Collection including zhouyik/Qwen3-VL-8B-SAMTok