SAMTok

A unified mask-token interface for MLLMs.
Install a version of `transformers` that supports Qwen2.5-VL.
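The exact pins below are an assumption: any `transformers` release that includes Qwen2.5-VL (support landed in v4.49.0), plus the `qwen-vl-utils` package that provides the `process_vision_info` helper used below, should work:

```bash
pip install "transformers>=4.49.0" qwen-vl-utils
```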
```python
import re
from typing import Optional, Tuple

import numpy as np


def extract_mt_token_ids_v1(text):
    """Extract every <|mt_XXXX|> code id appearing in `text`."""
    pattern = r"<\|mt_(\d{4})\|>"
    return [int(x) for x in re.findall(pattern, text)]

def extract_mt_token_ids_v2(text):
    """Extract code-id pairs from well-formed <|mt_start|>...<|mt_end|> groups."""
    pattern = re.compile(r'<\|mt_start\|><\|mt_(\d{4})\|><\|mt_(\d{4})\|><\|mt_end\|>')
    matches = pattern.findall(text)
    ret_list = []
    for num1, num2 in matches:
        ret_list.append(int(num1))
        ret_list.append(int(num2))
    return ret_list

def find_first_index(arr, value):
    """Return the first index of `value` in `arr`, or -1 if it is absent."""
    indices = np.where(arr == value)[0]
    return indices[0] if len(indices) > 0 else -1
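
# Example: find_first_index(np.array([5, 8, 8]), 8) -> 1; a missing value -> -1.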

def fix_mt_format_comprehensive(text):
    """Repair malformed mask-token groups so each carries exactly two codes."""
    # Too many codes in a group: keep the first two, drop the rest.
    pattern_too_many = r'(<\|mt_start\|>)(<\|mt_\d+\|>)(<\|mt_\d+\|>)(?:<\|mt_\d+\|>)+<\|mt_end\|>'
    replacement_too_many = r'\1\2\3<|mt_end|>'
    text = re.sub(pattern_too_many, replacement_too_many, text)
    # Only one code before <|mt_end|>: pad with the <|mt_9999|> sentinel.
    pattern_too_few_with_end = r'(<\|mt_start\|>)(<\|mt_\d+\|>)(<\|mt_end\|>)'
    replacement_too_few = r'\1\2<|mt_9999|><|mt_end|>'
    text = re.sub(pattern_too_few_with_end, replacement_too_few, text)
    # Only one code and no <|mt_end|>: pad and close the group.
    pattern_too_few_no_end = r'(<\|mt_start\|>)(<\|mt_\d+\|>)(?!<\|mt_)'
    replacement_too_few_no_end = r'\1\2<|mt_9999|><|mt_end|>'
    text = re.sub(pattern_too_few_no_end, replacement_too_few_no_end, text)
    return text
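
# Quick sanity check of the helpers above (the strings are hypothetical model
# outputs, not real SAMTok generations). With two codes per group:
#   extract_mt_token_ids_v1("<|mt_start|><|mt_0012|><|mt_0387|><|mt_end|>")  -> [12, 387]
#   extract_mt_token_ids_v2 on the same string                               -> [12, 387]
# A truncated group is padded with the <|mt_9999|> sentinel and re-closed:
#   fix_mt_format_comprehensive("<|mt_start|><|mt_0012|>")
#   -> "<|mt_start|><|mt_0012|><|mt_9999|><|mt_end|>"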

def extract_think_and_answer_robust(response: str) -> Tuple[Optional[str], Optional[str]]:
    """Split a response into its <think> and <answer> parts, tolerating missing tags."""
    think_content = None
    answer_content = None
    think_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
    answer_pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
    think_match = think_pattern.search(response)
    if think_match:
        think_content = think_match.group(1)
    answer_match = answer_pattern.search(response)
    if answer_match:
        answer_content = answer_match.group(1)
    # Fall back to plain splitting when one of the tag pairs is missing or unclosed.
    if answer_content is None or think_content is None:
        if '<answer>' in response:
            head, tail = response.split('<answer>', 1)
            if think_content is None:
                think_content = head
            if answer_content is None:
                answer_content = tail
        elif '</think>' in response:
            head, tail = response.split('</think>', 1)
            if think_content is None:
                think_content = head
            if answer_content is None:
                answer_content = tail
    return think_content, answer_content
```
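A quick check of the fallback path on a hypothetical response whose `<answer>` tag is never closed:

```python
resp = "<think>find the largest figure</think><answer><|mt_start|><|mt_0012|><|mt_0387|><|mt_end|>"
think, answer = extract_think_and_answer_robust(resp)
print(think)   # find the largest figure
print(answer)  # <|mt_start|><|mt_0012|><|mt_0387|><|mt_end|>
```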
```python
import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

from projects.samtok.models import DirectResize, VQ_SAM2, VQ_SAM2Config, SAM2Config

# build VLM
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "zhouyik/Qwen2.5-VL-3B-SAMTok-gres-rl", torch_dtype="auto"
).cuda().eval()
processor = AutoProcessor.from_pretrained("zhouyik/Qwen2.5-VL-3B-SAMTok-gres-rl")

# build SAMTok
CODEBOOK_SIZE = 256
CODEBOOK_DEPTH = 2
sam2_config = SAM2Config(
    ckpt_path="zhouyik/Qwen2.5-VL-3B-SAMTok-gres-rl/sam2.1_hiera_large.pt",
)
vq_sam2_config = VQ_SAM2Config(
    sam2_config=sam2_config,
    codebook_size=CODEBOOK_SIZE,
    codebook_depth=CODEBOOK_DEPTH,
    shared_codebook=False,
    latent_dim=256,
)
vq_sam2 = VQ_SAM2(vq_sam2_config).cuda().eval()
state = torch.load("zhouyik/Qwen2.5-VL-3B-SAMTok-gres-rl/mask_tokenizer_256x2.pth", map_location="cpu")
vq_sam2.load_state_dict(state)
sam2_image_processor = DirectResize(1024)
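
# Note on the id layout (an assumption, inferred from the remap step further below):
# with shared_codebook=False the two codebooks occupy disjoint token-id ranges, so
# ids 0..255 address book 0 and ids 256..511 address book 1. A decoded pair like
# (12, 387) therefore maps to per-book codes (12, 387 - 256) = (12, 131).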
# message
image_path = "figs/totoro.jpg"
phrase = "the biggest totoro"
question = f"Please segment {phrase} in this image. A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>"
image = Image.open(image_path).convert('RGB')
ori_width, ori_height = image.size
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": image_path,
            },
            {"type": "text", "text": question},
        ],
    }
]
# VLM inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")
generated_ids = model.generate(
    **inputs,
    max_new_tokens=512,
    do_sample=False,
    top_p=1.0,
)
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
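
# output_text[0] should follow the prompt's tag format, e.g. (illustrative, not a real generation):
# "<think> ... </think><answer> ... <|mt_start|><|mt_0012|><|mt_0387|><|mt_end|> ... </answer>"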
# decode mask
thinking_content, answer_content = extract_think_and_answer_robust(output_text[0])
quant_ids = extract_mt_token_ids_v1(answer_content)
if len(quant_ids) % CODEBOOK_DEPTH != 0:
    # Repair malformed token groups, then re-extract from the repaired text.
    answer_content = fix_mt_format_comprehensive(answer_content)
    quant_ids = extract_mt_token_ids_v2(answer_content)
batch_size = len(quant_ids) // CODEBOOK_DEPTH
remap_quant_ids = []
tags = []
for bs_id in range(batch_size):
    chunk_quant_ids = quant_ids[bs_id * CODEBOOK_DEPTH:(bs_id + 1) * CODEBOOK_DEPTH]
    tags.append(f"{chunk_quant_ids[0]}-{chunk_quant_ids[1]}")
    # Map global token ids back to per-codebook indices (book k is offset by k * CODEBOOK_SIZE).
    remap_chunk_quant_ids = [quant_id - book_id * CODEBOOK_SIZE for book_id, quant_id in enumerate(chunk_quant_ids)]
    code1, code2 = remap_chunk_quant_ids
    # An out-of-range second code (e.g. the <|mt_9999|> sentinel) is flagged as -1.
    if not (0 <= code2 < CODEBOOK_SIZE):
        code2 = -1
    remap_quant_ids.append([code1, code2])
batch_size = len(remap_quant_ids)
sam2_image = np.array(image)
sam2_image = sam2_image_processor.apply_image(sam2_image)
sam2_pixel_values = torch.from_numpy(sam2_image).permute(2, 0, 1).contiguous()
sam2_pixel_values = sam2_pixel_values.unsqueeze(0).to(vq_sam2.dtype).to(vq_sam2.device)
sam2_pixel_values = sam2_pixel_values.repeat(batch_size, 1, 1, 1)
quant_ids = torch.LongTensor(remap_quant_ids).to(vq_sam2.device)
with torch.no_grad():
    _pred_masks = vq_sam2.forward_with_codes(sam2_pixel_values, quant_ids)
_pred_masks = torch.nn.functional.interpolate(_pred_masks, size=(ori_height, ori_width), mode='bilinear')
_pred_masks = _pred_masks > 0.5
_pred_masks = _pred_masks[:, 0, :, :].cpu().numpy().astype(np.uint8)
text_token_2d_mask_mapping = {tag: _pred_mask for tag, _pred_mask in zip(tags, _pred_masks)}
```
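To inspect the result, each binary mask can be saved under its token-pair tag. This snippet and its file names are illustrative, not part of the original pipeline:

```python
from PIL import Image

# Save each predicted mask as a black-and-white PNG named after its token tag.
for tag, mask in text_token_2d_mask_mapping.items():
    Image.fromarray(mask * 255).save(f"mask_{tag}.png")
```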