Quickstart

Install a version of Transformers that supports Qwen2.5-VL.

Using 🤗 Transformers to Chat

import re
from typing import Optional, Tuple

import numpy as np
import torch

def extract_mt_token_ids_v1(text):
    """Return every ``<|mt_NNNN|>`` mask-token id found in *text*, in order.

    Only ids written with exactly four digits are recognized. The result is a
    flat list of ints; no grouping into masks is performed here.
    """
    return list(map(int, re.findall(r"<\|mt_(\d{4})\|>", text)))

def extract_mt_token_ids_v2(text):
    """Return mask-token ids taken only from well-formed two-code groups.

    A group must look exactly like
    ``<|mt_start|><|mt_NNNN|><|mt_NNNN|><|mt_end|>`` (four-digit ids). Both ids
    of every matching group are appended in order, so the result always has
    even length; anything not matching that shape is ignored.
    """
    group_re = re.compile(r'<\|mt_start\|><\|mt_(\d{4})\|><\|mt_(\d{4})\|><\|mt_end\|>')
    return [int(code) for pair in group_re.findall(text) for code in pair]

def find_first_index(arr, value):
    """Return the index of the first element of *arr* equal to *value*, or -1.

    *arr* may be a NumPy array or any array-like (e.g. a list). It is converted
    with ``np.asarray`` so the comparison is always elementwise. Without that,
    ``list == value`` collapses to a single bool and the search silently fails.
    For N-d input the result is the axis-0 index of the first match in
    row-major order.

    Returns a plain ``int`` in both the found and not-found cases.
    """
    indices = np.where(np.asarray(arr) == value)[0]
    return int(indices[0]) if indices.size > 0 else -1

def fix_mt_format_comprehensive(text):
    """Repair malformed ``<|mt_start|> ... <|mt_end|>`` mask-token groups.

    The rewrites below are applied in this exact order, because each one sees
    the output of the previous:

    1. A closed group with more than two ``<|mt_N|>`` codes keeps only the
       first two.
    2. A closed group with a single code gets ``<|mt_9999|>`` inserted as its
       second code (presumably an out-of-range placeholder).
    3. An unclosed group whose single code is not followed by another
       ``<|mt_`` token gets ``<|mt_9999|><|mt_end|>`` appended.

    Well-formed two-code groups are left unchanged. Returns the repaired
    string.
    """
    repairs = (
        (r'(<\|mt_start\|>)(<\|mt_\d+\|>)(<\|mt_\d+\|>)(?:<\|mt_\d+\|>)+<\|mt_end\|>',
         r'\1\2\3<|mt_end|>'),
        (r'(<\|mt_start\|>)(<\|mt_\d+\|>)(<\|mt_end\|>)',
         r'\1\2<|mt_9999|><|mt_end|>'),
        (r'(<\|mt_start\|>)(<\|mt_\d+\|>)(?!<\|mt_)',
         r'\1\2<|mt_9999|><|mt_end|>'),
    )
    for pattern, replacement in repairs:
        text = re.sub(pattern, replacement, text)
    return text

def extract_think_and_answer_robust(response: str) -> Tuple[Optional[str], Optional[str]]:
    """Split a model response into its ``<think>`` and ``<answer>`` parts.

    First the properly closed ``<think>...</think>`` and
    ``<answer>...</answer>`` sections are looked up (first occurrence, across
    newlines). If either is missing, the response is split once on the first
    of ``<answer>`` or ``</think>`` that occurs in it. Only the missing parts
    are filled from that split: the text before the separator becomes the
    reasoning, the text after it the answer.

    Returns ``(think, answer)``; either element is ``None`` when nothing could
    be recovered for it.
    """
    think_hit = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
    answer_hit = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
    think_content = think_hit.group(1) if think_hit else None
    answer_content = answer_hit.group(1) if answer_hit else None

    if think_content is not None and answer_content is not None:
        return think_content, answer_content

    # Fallback for unterminated tags: '<answer>' takes precedence over '</think>'.
    for separator in ("<answer>", "</think>"):
        if separator not in response:
            continue
        before, after = response.split(separator, 1)
        if think_content is None:
            think_content = before
        if answer_content is None:
            answer_content = after
        break

    return think_content, answer_content

from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from projects.samtok.models import DirectResize, VQ_SAM2, VQ_SAM2Config, SAM2Config

# build VLM
# Hugging Face repo id shared by the VLM weights, the processor and the SAMTok
# checkpoint paths below; kept in one place so the references cannot drift.
MODEL_ID = "zhouyik/Qwen2.5-VL-3B-SAMTok-gres-rl"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype="auto"
).cuda().eval()
processor = AutoProcessor.from_pretrained(MODEL_ID)

# build SAMTok
# Mask-token vocabulary layout: CODEBOOK_DEPTH codebooks of CODEBOOK_SIZE codes
# each. The decode step below subtracts book_id * CODEBOOK_SIZE from each
# token id to recover the per-codebook code.
CODEBOOK_SIZE = 256
CODEBOOK_DEPTH = 2
# NOTE(review): the checkpoint paths below are built from MODEL_ID but used as
# local file paths, i.e. they presumably assume the repo was downloaded into
# the working directory — confirm.
sam2_config = SAM2Config(
    ckpt_path=f"{MODEL_ID}/sam2.1_hiera_large.pt",
)
vq_sam2_config = VQ_SAM2Config(
    sam2_config=sam2_config,
    codebook_size=CODEBOOK_SIZE,
    codebook_depth=CODEBOOK_DEPTH,
    shared_codebook=False,
    latent_dim=256,
)
vq_sam2 = VQ_SAM2(vq_sam2_config).cuda().eval()
# torch.load unpickles the file: only load mask-tokenizer checkpoints you trust.
state = torch.load(f"{MODEL_ID}/mask_tokenizer_256x2.pth", map_location="cpu")
vq_sam2.load_state_dict(state)
# Resizes the image for the SAM2 branch (target size 1024).
sam2_image_processor = DirectResize(1024)

# message
# Inputs for this example: an image file and the referring phrase to segment.
image_path = "figs/totoro.jpg"
phrase = "the biggest totoro"
# The prompt requests <think>...</think><answer>...</answer> output, which is
# what extract_think_and_answer_robust parses in the decode step.
question = f"Please segment {phrase} in this image. A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>"
# NOTE(review): `Image` (PIL) is used here but never imported in this snippet;
# presumably `from PIL import Image` is required — confirm.
image = Image.open(image_path).convert('RGB')
# Original resolution, used later to resize predicted masks back to the image.
ori_width, ori_height = image.size
# Single user turn: the image (by path) followed by the text prompt.
messages = [
    dict(
        role="user",
        content=[
            dict(type="image", image=image_path),
            dict(type="text", text=question),
        ],
    )
]

# VLM inference
# Render the chat template to a prompt string; tokenization happens in the
# processor call below.
text = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

# NOTE(review): `process_vision_info` is never imported in this snippet;
# presumably it comes from `qwen_vl_utils` — confirm.
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
).to("cuda")

# Greedy decoding; top_p only affects sampling, so it is inert here.
generated_ids = model.generate(**inputs, max_new_tokens=512, do_sample=False, top_p=1.0)
# Each output row starts with its prompt; keep only the newly generated part.
generated_ids_trimmed = [
    completion[len(prompt):]
    for prompt, completion in zip(inputs.input_ids, generated_ids)
]
# NOTE(review): if the <|mt_*|> mask tokens are registered as special tokens,
# skip_special_tokens=True would strip them from the text — confirm.
output_text = processor.batch_decode(
    generated_ids_trimmed,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)

# decode mask
# Split the generation into reasoning / answer parts. If no answer section can
# be located, fall back to the whole decoded response so the mask-token regexes
# below never receive None.
thinking_content, answer_content = extract_think_and_answer_robust(output_text[0])
if answer_content is None:
    answer_content = output_text[0]

# Mask tokens come in groups of CODEBOOK_DEPTH ids. A count that is not a
# multiple of CODEBOOK_DEPTH means at least one group is malformed: repair the
# <|mt_start|>...<|mt_end|> groups and re-parse the *repaired* text. The repair
# result used to be discarded and the unrepaired text parsed again.
quant_ids = extract_mt_token_ids_v1(answer_content)
if len(quant_ids) % CODEBOOK_DEPTH != 0:
    answer_content = fix_mt_format_comprehensive(answer_content)
    output_text = [answer_content]
    quant_ids = extract_mt_token_ids_v2(answer_content)

# Convert each group of token ids into per-codebook codes.
# NOTE(review): the tag format and code1/code2 handling assume CODEBOOK_DEPTH == 2.
remap_quant_ids = []
tags = []
for bs_id in range(len(quant_ids) // CODEBOOK_DEPTH):
    chunk_quant_ids = quant_ids[bs_id * CODEBOOK_DEPTH:(bs_id + 1) * CODEBOOK_DEPTH]
    # The raw token ids become the keys of text_token_2d_mask_mapping.
    tags.append(f"{chunk_quant_ids[0]}-{chunk_quant_ids[1]}")
    # Subtract the per-codebook offset book_id * CODEBOOK_SIZE.
    remap_chunk_quant_ids = [
        quant_id - book_id * CODEBOOK_SIZE
        for book_id, quant_id in enumerate(chunk_quant_ids)
    ]
    code1 = remap_chunk_quant_ids[0]
    code2 = remap_chunk_quant_ids[1]
    # Out-of-range second codes (e.g. from the <|mt_9999|> placeholder inserted
    # by fix_mt_format_comprehensive) are replaced with -1.
    # NOTE(review): code1 is not range-checked; presumably -1 means "no code"
    # to VQ_SAM2.forward_with_codes — confirm.
    if not 0 <= code2 < CODEBOOK_SIZE:
        code2 = -1
    remap_quant_ids.append([code1, code2])

batch_size = len(remap_quant_ids)
text_token_2d_mask_mapping = {}
# With no mask tokens there is nothing to decode. Skip SAMTok, because an empty
# list would become a 1-D zero-length LongTensor rather than an (N, 2) code
# tensor.
if batch_size > 0:
    sam2_image = sam2_image_processor.apply_image(np.array(image))
    # HWC uint8 image -> NCHW tensor on the tokenizer's dtype/device, one copy per mask.
    sam2_pixel_values = torch.from_numpy(sam2_image).permute(2, 0, 1).contiguous()
    sam2_pixel_values = sam2_pixel_values.unsqueeze(0).to(vq_sam2.dtype).to(vq_sam2.device)
    sam2_pixel_values = sam2_pixel_values.repeat(batch_size, 1, 1, 1)

    quant_ids = torch.LongTensor(remap_quant_ids).to(vq_sam2.device)

    with torch.no_grad():
        _pred_masks = vq_sam2.forward_with_codes(sam2_pixel_values, quant_ids)
    # Resize back to the original image resolution.
    _pred_masks = torch.nn.functional.interpolate(_pred_masks, size=(ori_height, ori_width), mode='bilinear')
    # NOTE(review): 0.5 threshold — confirm whether forward_with_codes returns
    # probabilities or logits.
    _pred_masks = _pred_masks > 0.5
    _pred_masks = _pred_masks[:, 0, :, :].cpu().numpy().astype(np.uint8)
    text_token_2d_mask_mapping = dict(zip(tags, _pred_masks))
Downloads last month
34
Safetensors
Model size
4B params
Tensor type
BF16
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Collection including zhouyik/Qwen2.5-VL-3B-SAMTok-gres-rl