| | from typing import List, Optional, Tuple, Union |
| |
|
| | import os |
| | import torch |
| | import numpy as np |
| | import torch.nn as nn |
| | import matplotlib.pyplot as plt |
| | from PIL import Image |
| | import torch.nn.functional as F |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| | from model.IXC.modeling_internlm_xcomposer2 import InternLMXComposer2ForCausalLM |
| | from model.IXC.modeling_internlm2 import InternLM2Model |
| | from model.sam2.build_sam import build_sam2_hf |
| | from model.sam2.utils.transforms import SAM2Transforms |
| | from transformers import TextStreamer |
| | try: |
| | from transformers.generation.streamers import BaseStreamer |
| | except: |
| | BaseStreamer = None |
| |
|
| |
|
def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    scale=1000,
    eps=1e-6,
):
    """
    Compute the DICE loss, similar to generalized IoU for masks.

    Args:
        inputs: A float tensor of shape (N, ...) holding per-element logits;
            dimension 0 indexes masks and all remaining dimensions are
            flattened into one per-mask vector.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks: Normaliser; the summed per-mask loss is divided by this.
        scale: Predictions and targets are divided by this before summation
            (presumably to keep the sums small in low-precision dtypes).
        eps: Smoothing term added to both numerator and denominator.

    Returns:
        Scalar tensor: sum of per-mask dice losses divided by ``num_masks``.
    """
    inputs = inputs.sigmoid()
    # flatten(1) collapses every non-mask dimension, so the loss is computed
    # per mask for any input rank >= 2 (identical to flatten(1, 2) for 3-D).
    inputs = inputs.flatten(1)
    targets = targets.flatten(1)
    numerator = 2 * (inputs / scale * targets).sum(-1)
    denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1)
    loss = 1 - (numerator + eps) / (denominator + eps)
    return loss.sum() / (num_masks + 1e-8)
| |
|
| |
|
def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Compute the per-mask mean binary cross-entropy (on logits), summed over masks.

    Args:
        inputs: A float tensor of shape (N, ...) holding per-element logits;
            dimension 0 indexes masks and all remaining dimensions are
            averaged per mask.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks: Normaliser; the summed per-mask loss is divided by this.

    Returns:
        Loss tensor (scalar).
    """
    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # flatten(1) averages over every non-mask dimension, matching
    # flatten(1, 2) for 3-D input while also supporting other ranks.
    return loss.flatten(1).mean(1).sum() / (num_masks + 1e-8)
| |
|
| |
|
class GeoPixelMetaModel:
    """Mixin that attaches the SAM2 grounding modules to a model.

    It has no base class of its own. Its ``super().__init__(config)`` call is
    meant to reach the next class in a multiple-inheritance MRO, which must
    accept ``config``.
    """

    def __init__(
        self,
        config,
        **kwargs,
    ):
        super(GeoPixelMetaModel, self).__init__(config)
        self.config = config
        cfg = self.config
        # Existing config values win; otherwise fall back to kwargs, then defaults.
        decoder_default = kwargs.get("train_mask_decoder", False)
        cfg.train_mask_decoder = getattr(cfg, "train_mask_decoder", decoder_default)
        dim_default = kwargs.get("out_dim", 256)
        cfg.out_dim = getattr(cfg, "out_dim", dim_default)
        self.vision_pretrained = kwargs.get("vision_pretrained", None)
        self.initialize_geopixel_modules(cfg)

    def initialize_geopixel_modules(self, config):
        """Build the frozen SAM2 model and the trainable text projection head.

        Args:
            config: model config; must provide ``train_mask_decoder``,
                ``hidden_size`` and ``out_dim``.
        """
        self.visual_model = build_sam2_hf(self.vision_pretrained)
        sam = self.visual_model

        self._transform = SAM2Transforms(
            resolution=sam.image_size,
            mask_threshold=0.0,
            max_hole_area=0.0,
            max_sprinkle_area=0.0,
        )

        # (H, W) of each backbone feature level, finest first; presumably tied
        # to the SAM2 image size — verify if that changes.
        self._bb_feat_sizes = [(side, side) for side in (256, 128, 64)]

        # Freeze all of SAM2, then optionally re-enable only its mask decoder.
        sam.requires_grad_(False)
        if config.train_mask_decoder:
            decoder = sam.sam_mask_decoder
            decoder.train()
            decoder.requires_grad_(True)

        # Projects LLM hidden states (hidden_size) into the prompt space (out_dim).
        width = config.hidden_size
        head = nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(inplace=True),
            nn.Linear(width, config.out_dim),
            nn.Dropout(0.0),
        )
        self.text_hidden_fcs = nn.ModuleList([head])
        self.text_hidden_fcs.train()
        self.text_hidden_fcs.requires_grad_(True)
| |
|
| |
|
class GeoPixelModel(GeoPixelMetaModel, InternLM2Model):
    """InternLM2 model with the GeoPixel SAM2 grounding modules mixed in."""

    def __init__(self, config, **kwargs):
        # The mixin's __init__ runs first and forwards ``config`` to InternLM2Model.
        super().__init__(config, **kwargs)
        # The KV cache is disabled in this model's config.
        self.config.use_cache = False
| |
|
| |
|
class GeoPixelForCausalLM(InternLMXComposer2ForCausalLM):
    """InternLM-XComposer2 causal LM extended with SAM2-based mask prediction.

    Hidden states selected by a segmentation-token mask are projected by
    ``self.model.text_hidden_fcs`` and used as text prompts for the SAM2 mask
    decoder held in ``self.model.visual_model``.
    """

    def __init__(self, config, **kwargs):
        # Consume loss weights and the segmentation token id here so they are not
        # forwarded to GeoPixelModel. The weights default to None; model_forward
        # multiplies them into the losses, so they must be set for training.
        self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
        self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
        self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
        self.seg_token_idx = kwargs.pop("seg_token_idx")

        super().__init__(config)
        self.model = GeoPixelModel(config, **kwargs)
        self.vocab_size = config.vocab_size
        self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def encode_g_img(self, image):
        """
        Calculates the SAM2 image embeddings for a single grounding image.

        Arguments:
            image (str, np.ndarray or torch.Tensor): a path to an image file, or
                an array indexed as (H, W, ...) so that ``shape[:2]`` is (H, W).

        Returns:
            ``(features, orig_hw)`` where ``features`` is the dict returned by
            ``get_visual_embs`` and ``orig_hw`` is ``[(H, W)]``; or None when
            ``image`` is None or the path has an unsupported extension.
        """
        if image is None:
            return None
        if isinstance(image, str):
            _, ext = os.path.splitext(image)
            if ext.lower() not in {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.tif', '.tiff'}:
                print('Unknow input format', image)
                return None
            # Convert to RGB so grayscale/palette/RGBA files still give the 3
            # channels asserted below; the context manager closes the file.
            with Image.open(image) as pil_image:
                image = pil_image.convert("RGB")
            w, h = image.size
            _orig_hw = [(h, w)]
        else:
            # NOTE(review): the docstring documents np.ndarray input; tensors are
            # still accepted for backward compatibility, but SAM2Transforms may
            # only handle PIL/ndarray — confirm before relying on tensor input.
            assert isinstance(image, (np.ndarray, torch.Tensor))
            _orig_hw = [image.shape[:2]]
        image = self.model._transform(image)
        image = image[None, ...].to(self.device)
        assert (len(image.shape) == 4 and image.shape[1] == 3), f"image must be of size 1x3xHxW, got {image.shape}"
        features = self.get_visual_embs(image)
        return features, _orig_hw

    def get_visual_embs(self, img_batch: torch.FloatTensor):
        """Run the frozen SAM2 image encoder on a (B, 3, H, W) batch.

        Returns:
            dict with ``"image_embed"`` (the lowest-resolution level, reshaped
            to ``_bb_feat_sizes[-1]``) and ``"high_res_feats"`` (the remaining
            levels, finest first).
        """
        with torch.no_grad():
            torch.cuda.empty_cache()
            img_batch = img_batch.to(self.device)
            batch_size = img_batch.shape[0]
            assert (
                len(img_batch.shape) == 4 and img_batch.shape[1] == 3
            ), f"grounding_img_batch must be of size Bx3xHxW, got {img_batch.shape}"
            backbone_out = self.model.visual_model.forward_image(img_batch)
            _, vision_feats, _, _ = self.model.visual_model._prepare_backbone_features(backbone_out)
            if self.model.visual_model.directly_add_no_mem_embed:
                vision_feats[-1] = vision_feats[-1] + self.model.visual_model.no_mem_embed
            # Pair levels from the coarsest end, reshape each to (B, C, H, W),
            # then restore finest-first order.
            feats = [
                feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
                for feat, feat_size in zip(vision_feats[::-1], self.model._bb_feat_sizes[::-1])
            ][::-1]
            features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
        return features

    def forward(self, **kwargs):
        """Use the parent forward when ``past_key_values`` is passed (e.g. during
        incremental decoding); otherwise use ``model_forward``."""
        return super().forward(**kwargs) if "past_key_values" in kwargs else self.model_forward(**kwargs)

    def _decode_masks(self, pred_embeddings, image_g_features, ori_hw):
        """Decode SAM2 masks for each sample from its projected embeddings.

        Args:
            pred_embeddings: per-sample tensors of projected embeddings, one row
                per selected segmentation token.
            image_g_features: dict returned by ``get_visual_embs``.
            ori_hw: per-sample original (H, W) to resize masks back to.

        Returns:
            List with one entry per sample: a tensor of masks (presumably one
            per embedding row), or ``[]`` when the sample has no embeddings.
        """
        visual_model = self.model.visual_model
        all_pred_masks = []
        for i, embeds in enumerate(pred_embeddings):
            if embeds.numel() == 0:
                all_pred_masks.append([])
                continue
            sparse_embeddings, dense_embeddings = visual_model.sam_prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
                text_embeds=embeds.unsqueeze(1),
            )
            batch_mode = embeds.shape[0] > 1
            high_res_features = [
                feat_level[i].unsqueeze(0)
                for feat_level in image_g_features["high_res_feats"]
            ]
            sparse_embeddings = sparse_embeddings.to(embeds.dtype)
            image_g_embeds = image_g_features['image_embed'][i].unsqueeze(0).to(torch.bfloat16)
            low_res_masks, _, _, _ = visual_model.sam_mask_decoder(
                image_embeddings=image_g_embeds,
                image_pe=visual_model.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                repeat_image=batch_mode,
                multimask_output=False,
                high_res_features=high_res_features,
            )
            pred_masks = self.model._transform.postprocess_masks(
                low_res_masks,
                ori_hw[i],
            )
            all_pred_masks.append(pred_masks[:, 0])
        return all_pred_masks

    def model_forward(
        self,
        inference: bool = False,
        **kwargs,
    ):
        """Forward pass with mask prediction for 'grounding' batches.

        Non-grounding batches (or calls without ``samples``) go straight to the
        parent forward. For grounding batches this returns
        ``{"pred_masks", "gt_masks"}`` when ``inference`` is True. Otherwise it
        returns a ``CausalLMOutputWithPast`` whose loss is the weighted sum of
        the LM loss and the mask BCE/dice losses; the individual terms are
        attached as attributes.
        """
        samples = kwargs.get('samples', None)
        if not (samples and samples['data_type'][0] == 'grounding'):
            return super().forward(**kwargs)

        kwargs['output_hidden_states'] = True
        kwargs['use_cache'] = False

        torch.cuda.empty_cache()
        model_output = super().forward(**kwargs)
        # Both paths use the same per-layer hidden states; [-1] is fed to the
        # projection head (assumes the usual tuple-of-layers convention).
        output_hidden_states = model_output.hidden_states
        # Read the mask from the parent output before anything else uses it.
        seg_token_mask = model_output.seg_token_mask

        if inference:
            assert len(samples['text_input']) == 1 and len(samples['image'][0]) == 1

        assert len(self.model.text_hidden_fcs) == 1
        last_hidden_state = self.model.text_hidden_fcs[0](output_hidden_states[-1])

        pred_embeddings = [states[masks] for states, masks in zip(last_hidden_state, seg_token_mask)]
        image_g_batch = torch.cat(samples['image_g'][0], dim=0)
        image_g_features = self.get_visual_embs(image_g_batch)
        pred_masks = self._decode_masks(pred_embeddings, image_g_features, samples['ori_hw'][0])
        gt_masks = samples['masks'][0]

        if inference:
            return {
                "pred_masks": pred_masks,
                "gt_masks": gt_masks,
            }

        ce_loss = model_output.loss * self.ce_loss_weight
        mask_bce_loss = 0
        mask_dice_loss = 0
        num_masks = 0

        for batch_idx, cur_pred_masks in enumerate(pred_masks):
            if len(cur_pred_masks) == 0:
                # No segmentation embeddings for this sample, so there is no
                # predicted mask to supervise.
                continue
            cur_gt_masks = torch.stack(
                [
                    torch.from_numpy(gt_mask).to(dtype=cur_pred_masks.dtype, device=cur_pred_masks.device)
                    for gt_mask in gt_masks[batch_idx]
                ],
                dim=0,
            )
            assert (
                cur_gt_masks.shape[0] == cur_pred_masks.shape[0]
            ), "gt_masks.shape: {}, pred_masks.shape: {}".format(
                cur_gt_masks.shape, cur_pred_masks.shape
            )
            n_cur = cur_gt_masks.shape[0]
            # Each helper normalises by n_cur; multiply back so the final
            # division below averages over all masks in the batch.
            mask_bce_loss += sigmoid_ce_loss(cur_pred_masks, cur_gt_masks, num_masks=n_cur) * n_cur
            mask_dice_loss += dice_loss(cur_pred_masks, cur_gt_masks, num_masks=n_cur) * n_cur
            num_masks += n_cur

        mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
        mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
        mask_loss = mask_bce_loss + mask_dice_loss

        loss = ce_loss + mask_loss
        outputs = CausalLMOutputWithPast(
            loss=loss,
            logits=model_output.logits,
            past_key_values=model_output.past_key_values,
            hidden_states=output_hidden_states,
            attentions=model_output.attentions,
        )
        outputs.ce_loss = ce_loss
        outputs.mask_bce_loss = mask_bce_loss
        outputs.mask_dice_loss = mask_dice_loss
        outputs.mask_loss = mask_loss
        return outputs

    def evaluate(
        self,
        tokenizer,
        query: str,
        images: Optional[List] = None,
        hd_num: int = 9,
        history: Optional[List[Tuple[str, str]]] = None,
        max_new_tokens: int = 1024,
        stream: bool = False,
        **kwargs,
    ):
        """Greedy-generate a response and, for a single image path, decode masks.

        Returns:
            ``(response, all_pred_masks)``. ``all_pred_masks`` is empty unless
            ``images`` is exactly one path string. Otherwise it holds one entry
            per output sequence: a mask tensor, or ``[]`` when that sequence
            has no segmentation tokens.
        """
        # None defaults avoid sharing one mutable list across calls.
        images = [] if images is None else images
        history = [] if history is None else history
        all_pred_masks = []
        with torch.no_grad():
            inputs, im_mask, _ = self.interleav_wrap_chat(query, images, history=history, hd_num=hd_num)
            inputs = {
                k: v.to(self.device)
                for k, v in inputs.items() if torch.is_tensor(v)
            }
            eos_token_id = [tokenizer.eos_token_id]
            streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) if stream else None

            outputs = self.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                im_mask=im_mask,
                input_ids=None,
                streamer=streamer,
                num_beams=1,
                do_sample=False,
                temperature=1.0,
                top_p=1.0,
                top_k=0,
                eos_token_id=eos_token_id,
                repetition_penalty=1.0,
                infer_mode='base',
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            output_ids = outputs['sequences']
            response = tokenizer.decode(output_ids[0].cpu().tolist(), skip_special_tokens=True)
            response = response.replace("[UNUSED_TOKEN_145]", "")

            if len(images) == 1 and isinstance(images[0], str):
                output_hidden_states = outputs.hidden_states[-1]
                seg_token_mask = output_ids[:, 1:-1] == self.seg_token_idx
                inputs_embeds_len = inputs['inputs_embeds'].size(1)
                # Prompt positions are never segmentation tokens; build the
                # padding on the mask's own device instead of the default GPU.
                prompt_pad = torch.zeros(
                    (seg_token_mask.shape[0], inputs_embeds_len),
                    dtype=torch.bool,
                    device=seg_token_mask.device,
                )
                seg_token_mask = torch.cat([prompt_pad, seg_token_mask], dim=1)
                assert len(self.model.text_hidden_fcs) == 1
                last_hidden_state = self.model.text_hidden_fcs[0](output_hidden_states)
                pred_embeddings = [states[masks] for states, masks in zip(last_hidden_state, seg_token_mask)]
                image_g_features, ori_hw = self.encode_g_img(images[0])
                all_pred_masks = self._decode_masks(pred_embeddings, image_g_features, ori_hw)

        return response, all_pred_masks