from typing import Literal import torch import torch.nn as nn import torch.nn.functional as F from mmdet.models import CrossEntropyLoss from xtuner.registry import BUILDER from xtuner.model.utils import get_peft_model_state_dict from projects.lisa.datasets.utils import DEFAULT_IMAGE_TOKEN from projects.lisa.models.lisa import LisaModel from xtuner.utils import PROMPT_TEMPLATE from xtuner.tools.utils import get_stop_criteria from transformers import GenerationConfig from projects.llava_sam2.models.preprocess.image_resize import DirectResize import numpy as np from .utils import dynamic_preprocess import torchvision.transforms as T from torchvision.transforms.functional import InterpolationMode from pycocotools import mask as _mask from types import MethodType from xtuner.model.utils import guess_load_checkpoint from mmcv.ops import point_sample from mmdet.models.utils import get_uncertain_point_coords_with_randomness class VideoLLaVASAMBaselineModel(LisaModel): def __init__(self, mllm, tokenizer, grounding_encoder, loss_mask=None, loss_dice=None, torch_dtype=torch.bfloat16, pretrained_pth=None, frozen_sam2_decoder=True, special_tokens=None, loss_sample_points=False, num_points=12544, # for slow fast arch fast_pool=False, fast_pool_size=4, use_fast_supervision=False, # for inference phi3=True, template=None, # for arch selection arch_type:Literal['intern_vl', 'qwen', 'llava']='intern_vl', # for inference large model split_model=False, ): super(LisaModel, self).__init__() self.split_model = split_model if split_model: mllm.model_split = split_model self.special_tokens = special_tokens self.mllm = BUILDER.build(mllm) self.mllm.get_model().initialize_vision_modules(self.mllm.get_model().config) vision_tower = self.mllm.get_model().get_vision_tower() vision_tower.to(dtype=torch_dtype) self.arch_type = arch_type self.fast_pool = fast_pool self.fast_pool_size = fast_pool_size self.tokenizer = BUILDER.build(tokenizer) self.mllm.seg_token_idx = self.tokenizer("[SEG]", add_special_tokens=False).input_ids[0] self.grounding_encoder = BUILDER.build(grounding_encoder) self.grounding_encoder.requires_grad_(False) if not frozen_sam2_decoder: self.grounding_encoder.sam2_model.sam_mask_decoder.requires_grad_(True) self.loss_mask = BUILDER.build(loss_mask) self.loss_dice = BUILDER.build(loss_dice) if use_fast_supervision: self.loss_exists = BUILDER.build(dict( type=CrossEntropyLoss, use_sigmoid=True, reduction='mean', loss_weight=1.0) ) self.torch_dtype = torch_dtype if pretrained_pth is not None: pretrained_state_dict = guess_load_checkpoint(pretrained_pth) self.load_state_dict(pretrained_state_dict, strict=False) print(f'Load pretrained weight from {pretrained_pth}') self.loss_sample_points = loss_sample_points self.num_points = num_points self.oversample_ratio = 3.0 self.importance_sample_ratio = 0.75 self.use_fast_supervision = use_fast_supervision self.phi3 = phi3 self.template = template def activation_checkpointing_disable(self): self.mllm.gradient_checkpointing_disable() def _add_special_tokens(self): special_tokens = self.special_tokens _num_new_tokens = self.tokenizer.add_tokens(special_tokens, special_tokens=True) # if not isinstance(self.mllm.model.language_model.get_output_embeddings(), nn.Linear): # print("Change the lm_head to nn.Linear !!!") # transposed = False # old_lm_head = self.mllm.model.language_model.get_output_embeddings() # old_num_tokens, old_lm_head_dim = ( # old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size() # ) # new_lm_head_shape = (old_lm_head_dim, len(tokenizer)) if not transposed else ( # len(tokenizer), old_lm_head_dim) # has_new_lm_head_bias = old_lm_head.bias is not None # new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias).to(self.device) # new_lm_head.weight = old_lm_head.weight # new_lm_head.bias = old_lm_head.bias # self.mllm.model.language_model.set_output_embeddings(new_lm_head) # this is already done in mllm # if num_new_tokens > 0: # self.mllm.model.language_model.resize_token_embeddings(len(self.tokenizer)) # assert isinstance(self.mllm, InternVL_Slowfast) self.seg_token_idx = self.tokenizer("[SEG]", add_special_tokens=False).input_ids[0] def state_dict(self, *args, **kwargs): state_dict = super(LisaModel, self).state_dict(*args, **kwargs) from collections import OrderedDict to_return = OrderedDict() # Step 1. visual_encoder if self.mllm.use_visual_encoder_lora: to_return.update( get_peft_model_state_dict( self.mllm.model.vision_model, state_dict=state_dict)) raise NotImplementedError elif not self.mllm.freeze_visual_encoder: to_return.update({ k: v for k, v in state_dict.items() if 'visual_encoder.' in k }) raise NotImplementedError # Step 2. LLM if self.mllm.use_llm_lora: if self.arch_type == 'intern_vl': to_return.update( get_peft_model_state_dict(self.mllm.model.language_model, state_dict=state_dict) ) elif self.arch_type == 'qwen': to_return.update( get_peft_model_state_dict(self.mllm.model.model, state_dict=state_dict) ) elif self.arch_type == 'llava': to_return.update( get_peft_model_state_dict(self.mllm.model.language_model, state_dict=state_dict) ) elif not self.mllm.freeze_llm: to_return.update( {k: v for k, v in state_dict.items() if 'llm.' in k}) raise NotImplementedError # Step 3. Projector to_return.update( {k: v for k, v in state_dict.items() if 'mlp1.' in k}) to_return.update( {k: v for k, v in state_dict.items() if 'model.multi_modal_projector.' in k}) # Step 4. mask decoder of grounding model (SAM/SAM2) to_return.update( {k: v for k, v in state_dict.items() if 'mask_decoder' in k}) # Step 5. others (fcs) to_return.update( {k: v for k, v in state_dict.items() if 'text_hidden_fcs.' in k}) to_return.update( {k: v for k, v in state_dict.items() if 'text_exist_fcs.' in k} ) to_return.update( {k: v for k, v in state_dict.items() if 'lm_head.weight' in k or 'output' in k and 'sam2_model' not in k}) to_return.update( {k: v for k, v in state_dict.items() if 'embed_tokens.weight' in k or 'tok_embeddings' in k}) return to_return def check_obj_number(self, pred_embeddings_list_video, gt_masks_video, fix_number=5): assert len(pred_embeddings_list_video) == len(gt_masks_video) ret_pred_embeddings_list_video = [] ret_gt_masks_video = [] for pred_mebeds, gt_masks in zip(pred_embeddings_list_video, gt_masks_video): # assert len(pred_mebeds) == len(gt_masks) if len(pred_mebeds) != len(gt_masks): min_num = min(len(pred_mebeds), len(gt_masks)) pred_mebeds = pred_mebeds[:min_num] gt_masks = gt_masks[:min_num] if len(pred_mebeds) != fix_number: if len(pred_mebeds) > fix_number: _idxs = torch.randperm(pred_mebeds.shape[0]) _idxs = _idxs[:fix_number] pred_mebeds = pred_mebeds[_idxs] gt_masks = gt_masks[_idxs] else: n_repeat = fix_number // len(pred_mebeds) + 1 pred_mebeds = torch.cat([pred_mebeds] * n_repeat, dim=0)[:fix_number] gt_masks = torch.cat([gt_masks] * n_repeat, dim=0)[:fix_number] ret_pred_embeddings_list_video.append(pred_mebeds) ret_gt_masks_video.append(gt_masks) return ret_pred_embeddings_list_video, ret_gt_masks_video def forward(self, data, data_samples=None, mode='loss'): g_pixel_values = data.pop('g_pixel_values', None) gt_masks = data.pop('masks', None) frames_per_batch = data.pop('frames_per_batch', None) input_ids = data['input_ids'] fast_exists = data.pop('fast_exists', None) # if self.arch_type == 'llava' and data.get('pixel_values', None) is not None: # data['pixel_values'] = data['pixel_values'].to(self.torch_dtype) if self.fast_pool: output = self.mllm(data, data_samples, mode, fast_token_idx=self.fast_token_idx) else: output = self.mllm(data, data_samples, mode) if gt_masks is None: return {'llm_loss': output.loss, 'loss_mask': output.loss * 0.0, 'loss_dice': output.loss * 0.0} assert frames_per_batch, "Video Lisa require frames_per_batch !!!" # print('frmaes_per_batch: ', frames_per_batch) ori_size_list = [] for i_bs, mask in enumerate(gt_masks): mask_shape = mask.shape[-2:] ori_size_list += [mask_shape] * frames_per_batch[i_bs] seg_token_mask = input_ids == self.seg_token_idx # seg_token_mask = seg_token_mask[:, 1:] # seg_token_mask = torch.cat([ # seg_token_mask, # seg_token_mask.new_zeros(seg_token_mask.shape[0], 1)], dim=-1) hidden_states = output.hidden_states hidden_states = self.text_hidden_fcs(hidden_states[-1]) pred_embeddings = hidden_states[seg_token_mask] seg_token_counts = seg_token_mask.int().sum(-1) pred_embeddings_list_ = torch.split(pred_embeddings, seg_token_counts.tolist(), dim=0) pred_embeddings_list = [] for item in pred_embeddings_list_: if len(item) != 0: pred_embeddings_list.append(item) # print('pred_embeddings_list: ', [item.shape for item in pred_embeddings_list]) pred_embeddings_list_video, success = self.genetate_video_pred_embeddings( pred_embeddings_list, frames_per_batch) if not success: return {'llm_loss': output.loss, 'loss_mask': output.loss * 0.0, 'loss_dice': output.loss * 0.0} if self.use_fast_supervision and fast_exists is not None: # gt_exists = [] # for id_x, _fast_exists in enumerate(fast_exists): # num_tot = _fast_exists.shape[0] # num_conv = gt_masks[id_x].shape[0] // frames_per_batch[id_x] # assert num_tot % num_conv == 0 # gt_exists.append(_fast_exists.reshape(num_conv, num_tot // num_conv)) fast_flag = input_ids == self.fast_token_idx fast_tokens = output.hidden_states[-1][fast_flag] exists_logit = self.text_exist_fcs(fast_tokens[self.fast_pool_size ** 2 - 1::self.fast_pool_size ** 2]) gt_exists = torch.cat(fast_exists) loss_exists = self.loss_exists(exists_logit, gt_exists) else: loss_exists = None gt_masks_video = self.process_video_gt_masks(gt_masks, frames_per_batch) pred_embeddings_list_video, gt_masks_video = self.check_obj_number( pred_embeddings_list_video, gt_masks_video ) g_pixel_values = torch.stack([ self.grounding_encoder.preprocess_image(pixel) for pixel in g_pixel_values ]) num_objs = pred_embeddings_list_video[0].shape[0] num_frames = len(pred_embeddings_list_video) language_embeddings = torch.cat(pred_embeddings_list_video, dim=0)[:, None] sam_states = self.grounding_encoder.get_sam2_embeddings(g_pixel_values, expand_size=num_objs) pred_masks = self.grounding_encoder.inject_language_embd(sam_states, language_embeddings, nf_nobj=(num_frames, num_objs)) bs = len(pred_masks) loss_mask, loss_dice = 0, 0 accuracy = 0 for i in range(bs): pred_mask = pred_masks[i] gt_mask = gt_masks_video[i] pred_mask = F.interpolate(pred_mask.unsqueeze(0), size=ori_size_list[i], mode='bilinear').squeeze(0) if len(pred_mask) != len(gt_mask): print('Warning !!! Pred and GT not equal !!!') _zero = pred_mask.sum() * 0.0 loss_mask += _zero loss_dice += _zero accuracy += _zero else: if self.loss_sample_points: sampled_pred_mask, sampled_gt_mask = self.sample_points(pred_mask, gt_mask) sam_loss_dice = self.loss_dice( sampled_pred_mask, sampled_gt_mask, avg_factor=(len(gt_mask) + 1e-4)) sam_loss_mask = self.loss_mask( sampled_pred_mask.reshape(-1), sampled_gt_mask.reshape(-1), avg_factor=(pred_mask.shape[0] * sampled_pred_mask.shape[1] + 1e-4)) else: sam_loss_mask = self.loss_mask(pred_mask, gt_mask) sam_loss_dice = self.loss_dice(pred_mask, gt_mask) accuracy += torch.eq((pred_mask.sigmoid() > 0.5), gt_mask).to(pred_mask).mean() loss_mask += sam_loss_mask loss_dice += sam_loss_dice loss_dict = { 'loss_mask': loss_mask / (bs + 1e-4), 'loss_dice': loss_dice / (bs + 1e-4), 'llm_loss': output.loss, } if loss_exists is not None: loss_dict['loss_exists'] = loss_exists return loss_dict def sample_points(self, mask_pred, gt_masks): gt_masks = gt_masks.unsqueeze(1) gt_masks = gt_masks.to(mask_pred) mask_pred = mask_pred.unsqueeze(1) # (N, 1, h, w) with torch.no_grad(): points_coords = get_uncertain_point_coords_with_randomness( mask_pred.to(torch.float32), None, self.num_points, self.oversample_ratio, self.importance_sample_ratio) # shape (num_total_gts, h, w) -> (num_total_gts, num_points) mask_point_targets = point_sample( gt_masks.float(), points_coords).squeeze(1) # shape (num_queries, h, w) -> (num_queries, num_points) mask_point_preds = point_sample( mask_pred.to(torch.float32), points_coords.to(torch.float32)).squeeze(1) return mask_point_preds.to(mask_pred.dtype), mask_point_targets.to(mask_pred.dtype) def genetate_video_pred_embeddings(self, pred_embeddings_list, frames_per_batch): if len(pred_embeddings_list) == len(frames_per_batch): success = True else: success = False print("len(pred_embeddings_list):{} is not equal to len(frames_per_batch):{} !!!".format(len(pred_embeddings_list), len(frames_per_batch))) pred_embeddings_list_video = [] for pred_embedding_batch, frame_nums in zip(pred_embeddings_list, frames_per_batch): pred_embeddings_list_video += [pred_embedding_batch] * frame_nums return pred_embeddings_list_video, success def process_video_gt_masks(self, gt_masks, frames_per_batch): gt_masks_video = [] assert len(gt_masks) == len(frames_per_batch) for gt_masks_batch, frames_num in zip(gt_masks, frames_per_batch): N, H, W = gt_masks_batch.shape assert N % frames_num == 0 gt_masks_batch = gt_masks_batch.reshape( N // frames_num, frames_num, H, W) for i in range(frames_num): gt_masks_video.append(gt_masks_batch[:, i]) return gt_masks_video def preparing_for_generation(self, metainfo, **kwargs): # set stop criteria and generation configs for model assert hasattr(self, 'tokenizer'), "The Model does not have the tokenizer!!!" self.bot_name = 'BOT' if 'template' in metainfo.keys(): template = metainfo['template'] else: template = PROMPT_TEMPLATE['phi3_chat'] if self.template is None: self.template = template stop_words = [] stop_words += self.template.get('STOP_WORDS', []) stop_criteria = get_stop_criteria( tokenizer=self.tokenizer, stop_words=stop_words) self.stop_criteria = stop_criteria default_generation_kwargs = dict( max_new_tokens=512, do_sample=False, eos_token_id=self.tokenizer.eos_token_id, pad_token_id=( self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id ), ) default_generation_kwargs.update(metainfo.get('generation_kwargs', {})) self.gen_config = GenerationConfig(**default_generation_kwargs) self.init_prediction_config = True self.mllm.to(self.torch_dtype) # self.text_hidden_fcs.to(self.torch_dtype) # if getattr(self, 'text_exist_fcs', None) is not None: # self.text_exist_fcs.to(self.torch_dtype) # for sam image processor self.extra_image_processor = DirectResize(target_length=1024, ) # for multi image process self.min_dynamic_patch = 1 if 'max_dynamic_patch' in metainfo.keys(): self.max_dynamic_patch = metainfo['max_dynamic_patch'] else: self.max_dynamic_patch = 12 self.downsample_ratio = 0.5 self.image_size = 448 self.use_thumbnail = True patch_size = 14 self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2)) self.IMAGENET_MEAN = (0.485, 0.456, 0.406) self.IMAGENET_STD = (0.229, 0.224, 0.225) self.IMG_CONTEXT_TOKEN = '' self.IMG_START_TOKEN = '' self.IMG_END_TOKEN = '' self.transformer = T.Compose([ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD) ]) def predict_video(self, pixel_values, text_prompts, **kwargs): ori_h, ori_w = kwargs['ori_height'], kwargs['ori_width'] _input_ids = kwargs['input_ids'] g_pixel_values = kwargs.pop('g_pixel_values', None) g_pixel_values = torch.stack([ self.grounding_encoder.preprocess_image(pixel) for pixel in g_pixel_values ]) fast_pixel_values = kwargs.pop('fast_pixel_values', None) if fast_pixel_values is None: fast_token_idx = None else: fast_token_idx = self.fast_token_idx predictions = [] pred_masks = [] is_exists_list = [] for input_ids in _input_ids: input_ids = torch.tensor(input_ids).unsqueeze(0) attention_mask = torch.ones_like(input_ids, dtype=torch.bool) pixel_values = pixel_values.to(dtype=self.torch_dtype) if fast_pixel_values is not None: fast_pixel_values = fast_pixel_values.to(dtype=self.torch_dtype) mm_inputs = { 'pixel_values': pixel_values, 'input_ids': input_ids, 'attention_mask': attention_mask, 'position_ids': None, 'past_key_values': None, 'labels': None, 'fast_pixel_values': fast_pixel_values, 'fast_token_idx': fast_token_idx, } generate_output = self.mllm.evaluate( images_clip=pixel_values, images=g_pixel_values, input_ids=input_ids, resize_list=[], original_size_list=[], max_new_tokens=512, ) generate_output = self.mllm.generate( **mm_inputs, generation_config=self.gen_config, streamer=None, bos_token_id=self.tokenizer.bos_token_id, stopping_criteria=self.stop_criteria, output_hidden_states=True, return_dict_in_generate=True ) predict = self.tokenizer.decode(generate_output.sequences[0], skip_special_tokens=False).strip() # input_text = self.tokenizer.decode(mm_inputs['input_ids'][0], skip_special_tokens=False) # print(input_text, generate_output.sequences[0], '\n', predict, self.tokenizer("[SEG]", add_special_tokens=False).input_ids[0]) predictions.append(predict) hidden_states = generate_output.hidden_states last_hidden_states = [item[-1][0] for item in hidden_states] last_hidden_states = torch.cat(last_hidden_states, dim=0) seg_hidden_states = get_seg_hidden_states( last_hidden_states, generate_output.sequences[0][:-1], seg_id=self.seg_token_idx ) if len(seg_hidden_states) == 0: print("Warning, no [SEG] tokens !!!") pred_masks.append(torch.zeros((g_pixel_values.shape[0], ori_h, ori_w), dtype=torch.int)) continue elif len(seg_hidden_states) > 1: print("Warning, {} [SEG] tokens !!!".format(len(seg_hidden_states))) seg_hidden_states = seg_hidden_states[:1] seg_hidden_states = self.text_hidden_fcs(seg_hidden_states) seg_hidden_states = seg_hidden_states.to(dtype=torch.float32) sam_states = self.grounding_encoder.get_sam2_embeddings(g_pixel_values) # TODO: change 5 if len(pixel_values) < 5: pred_mask = self.grounding_encoder.language_embd_inference(sam_states, [seg_hidden_states] * pixel_values.shape[0]) else: pred_mask = self.grounding_encoder.language_embd_inference(sam_states, [seg_hidden_states] * 5) pred_mask = F.interpolate( pred_mask, size=(ori_h, ori_w), mode='bilinear', align_corners=False, ) pred_mask = pred_mask[:, 0] pred_mask = pred_mask.sigmoid() > 0.5 pred_mask = pred_mask.int() # supervision if self.use_fast_supervision and (input_ids == self.fast_token_idx).sum() > 0: fast_flag = input_ids.squeeze(0) == self.fast_token_idx len_out = generate_output.sequences[0][:-1].shape[0] fast_tokens = last_hidden_states[:-len_out][fast_flag].to(dtype=torch.float32) exists_logit = self.text_exist_fcs(fast_tokens[self.fast_pool_size ** 2 - 1::self.fast_pool_size ** 2]) is_exists = exists_logit.squeeze(-1).sigmoid() > 0.5 is_exists_list.append(is_exists) not_exists = torch.logical_not(is_exists) if torch.any(not_exists): pred_mask[not_exists] = pred_mask[not_exists] * 0 pred_masks.append(pred_mask) assert len(pred_masks) == len(text_prompts) ret_dict = { 'prediction': predictions, 'prediction_masks': [mask_to_rle(_item.cpu().numpy()) for _item in pred_masks], } if 'id' in kwargs.keys(): ret_dict['id'] = kwargs['id'] if len(is_exists_list) > 0: ret_dict['is_exists'] = is_exists_list return ret_dict def predict_forward( self, pixel_values, text_prompts, ori_image_size=None, ori_image=None, mode='eval', **kwargs ): assert self.init_prediction_config, "Please set prediction configs using self.preparing_for_generation()" if kwargs.get('type', 'image') == 'video': return self.predict_video(pixel_values, text_prompts, **kwargs) if mode == 'demo_video': return self.predict_demo_video( pixel_values, text_prompts, ori_image_size, ori_image, **kwargs) input_dict = {} # prepare images assert ori_image is not None, "InternVL2 only support process the image from scratch !!!" image = ori_image # for pixel segmentation tasks if ori_image_size is not None and 'masks' in kwargs.keys(): g_image = np.array(image) # for grounding g_image = self.extra_image_processor.apply_image(g_image) g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous() input_dict['g_pixel_values'] = g_pixel_values images = dynamic_preprocess(image, self.min_dynamic_patch, self.max_dynamic_patch, self.image_size, self.use_thumbnail) pixel_values = [self.transformer(image) for image in images] pixel_values = torch.stack(pixel_values).to(self.torch_dtype) input_dict['pixel_values'] = pixel_values num_image_tokens = pixel_values.shape[0] * self.patch_token image_token_str = f'{self.IMG_START_TOKEN}' \ f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \ f'{self.IMG_END_TOKEN}' ret_predictions = [] ret_masks = [] if isinstance(text_prompts, str): text_prompts = [text_prompts] for text_prompt in text_prompts: # add template for text text_prompt = text_prompt.replace(DEFAULT_IMAGE_TOKEN, image_token_str) input_text = '' input_text += self.template['INSTRUCTION'].format( input=text_prompt, round=1, bot_name=self.bot_name) ids = self.tokenizer.encode(input_text) ids = torch.tensor(ids).cuda().unsqueeze(0) attention_mask = torch.ones_like(ids, dtype=torch.bool) mm_inputs = { 'pixel_values': input_dict['pixel_values'], 'input_ids': ids, 'attention_mask': attention_mask, 'position_ids': None, 'past_key_values': None, 'labels': None } generate_output = self.mllm.generate( **mm_inputs, generation_config=self.gen_config, streamer=None, bos_token_id=self.tokenizer.bos_token_id, stopping_criteria=self.stop_criteria, output_hidden_states=True, return_dict_in_generate=True ) predict = self.tokenizer.decode( generate_output.sequences[0], skip_special_tokens=False).strip() # print(predict) ret_predictions.append(predict) # refcoco test need debug !!! if ori_image_size is not None and 'masks' in kwargs.keys(): hidden_states = generate_output.hidden_states last_hidden_states = [item[-1][0] for item in hidden_states] last_hidden_states = torch.cat(last_hidden_states, dim=0) seg_hidden_states = get_seg_hidden_states( last_hidden_states, generate_output.sequences[0][:-1], seg_id=self.seg_token_idx ) if mode == 'demo': all_seg_hidden_states = self.text_hidden_fcs(seg_hidden_states) for seg_hidden_states in all_seg_hidden_states: seg_hidden_states = seg_hidden_states.unsqueeze(0) g_pixel_values = torch.stack([ self.grounding_encoder.preprocess_image(pixel, dtype=self.torch_dtype) for pixel in [input_dict['g_pixel_values']] ]) sam_states = self.grounding_encoder.get_sam2_embeddings(g_pixel_values) pred_masks = self.grounding_encoder.inject_language_embd(sam_states, [seg_hidden_states]) w, h = ori_image_size masks = F.interpolate(pred_masks, size=(h, w), mode='bilinear', align_corners=False) masks = masks[:, 0] masks = masks.sigmoid() > 0.5 masks = masks.int() ret_masks.append(masks) print('Done gcg demos') continue if len(seg_hidden_states) == 0: print("Warning, no [SEG] tokens !!!") ret_masks.append(None) continue elif len(seg_hidden_states) > 1: print("Warning, {} [SEG] tokens !!!".format(len(seg_hidden_states))) seg_hidden_states = seg_hidden_states[:1] seg_hidden_states = self.text_hidden_fcs(seg_hidden_states) g_pixel_values = torch.stack([ self.grounding_encoder.preprocess_image(pixel, dtype=self.torch_dtype) for pixel in [input_dict['g_pixel_values']] ]) sam_states = self.grounding_encoder.get_sam2_embeddings(g_pixel_values) pred_masks = self.grounding_encoder.inject_language_embd(sam_states, [seg_hidden_states]) w, h = ori_image_size masks = F.interpolate(pred_masks, size=(h, w), mode='bilinear', align_corners=False) masks = masks[:, 0] masks = masks.sigmoid() > 0.5 masks = masks.int() ret_masks.append(masks) if len(ret_predictions) == 1: ret_predictions = ret_predictions[0] if len(ret_masks) == 0: return {'prediction': ret_predictions} _ret_masks = [] for i, ret_mask in enumerate(ret_masks): if ret_mask is None: _ret_masks.append(None) else: ret_mask = ret_mask.cpu().numpy() _ret_masks.append(mask_to_rle(ret_mask)) if mode == 'demo': return { 'prediction': ret_predictions, 'prediction_masks': ret_masks, } if 'masks' not in kwargs.keys(): gt_masks = None else: gt_masks = mask_to_rle(kwargs['masks'].cpu().numpy()) return { 'prediction': ret_predictions, 'prediction_masks': _ret_masks, 'gt_masks': gt_masks, } def predict_demo_video( self, pixel_values, text_prompts, ori_image_size=None, ori_image=None, **kwargs ): input_dict = {} # prepare images assert ori_image is not None, "InternVL2 only support process the image from scratch !!!" assert isinstance(ori_image, list) all_image_token_str = '' all_pixel_values = [] for idx_img, image in enumerate(ori_image): images = dynamic_preprocess(image, self.min_dynamic_patch, 1, self.image_size, self.use_thumbnail) pixel_values = [self.transformer(image) for image in images] all_pixel_values += pixel_values num_image_tokens = len(pixel_values) * self.patch_token image_token_str = f'{self.IMG_START_TOKEN}' \ f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \ f'{self.IMG_END_TOKEN}' image_token_str = f"Frame-{idx_img + 1}: " + image_token_str + '\n' all_image_token_str += image_token_str all_pixel_values = torch.stack(all_pixel_values).to(self.torch_dtype) input_dict['pixel_values'] = all_pixel_values ret_predictions = [] ret_masks = [] if isinstance(text_prompts, str): text_prompts = [text_prompts] for text_prompt in text_prompts: # add template for text text_prompt = text_prompt.replace(DEFAULT_IMAGE_TOKEN, all_image_token_str) input_text = '' input_text += self.template['INSTRUCTION'].format( input=text_prompt, round=1, bot_name=self.bot_name) ids = self.tokenizer.encode(input_text) ids = torch.tensor(ids).cuda().unsqueeze(0) attention_mask = torch.ones_like(ids, dtype=torch.bool) mm_inputs = { 'pixel_values': input_dict['pixel_values'], 'input_ids': ids, 'attention_mask': attention_mask, 'position_ids': None, 'past_key_values': None, 'labels': None } generate_output = self.mllm.generate( **mm_inputs, generation_config=self.gen_config, streamer=None, bos_token_id=self.tokenizer.bos_token_id, stopping_criteria=self.stop_criteria, output_hidden_states=True, return_dict_in_generate=True ) predict = self.tokenizer.decode( generate_output.sequences[0], skip_special_tokens=False).strip() # print(predict) ret_predictions.append(predict) return { 'prediction': ret_predictions } def get_seg_hidden_states(hidden_states, output_ids, seg_id): seg_mask = output_ids == seg_id n_out = len(seg_mask) return hidden_states[-n_out:][seg_mask] def mask_to_rle(mask): rle = [] for m in mask: rle.append(_mask.encode(np.asfortranarray(m.astype(np.uint8)))) rle[-1]['counts'] = rle[-1]['counts'].decode() return rle