PixDLM / model /PixDLM.py
WhynotHug's picture
Upload folder using huggingface_hub
3334467 verified
Raw
History Blame Contribute Delete
48.3 kB
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BitsAndBytesConfig, CLIPVisionModel
import copy
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
DEFAULT_IMAGE_PATCH_TOKEN)
from .llava.language_model.llava_llama import (LlavaLlamaForCausalLM,
LlavaLlamaModel)
from .segment_anything import build_sam_vit_h
from .segment_anything.modeling import (
MaskDecoder,
PromptEncoder,
TwoWayTransformer,
LayerNorm2d,
MaskDecoderMultiScale,
Three_Level_Multi_Scale_Decoder,
)
from utils.matcher import match_pred
from typing import Any, Dict, List, Tuple
def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    scale=1000,
    eps=1e-6,
):
    """Soft DICE loss computed on mask logits.

    Args:
        inputs: Mask logits. A sigmoid is applied here, and dimensions 1 and 2
            are flattened together, so a 3-D tensor is expected
            (presumably ``(num_masks, H, W)`` -- confirm against callers).
        targets: Binary ground-truth masks with the same shape as ``inputs``
            (0 for the negative class and 1 for the positive class).
        num_masks: Normaliser applied to the summed per-mask loss.
        scale: Every term is divided by this value before being summed.
        eps: Smoothing constant added to numerator and denominator.

    Returns:
        A scalar tensor: the sum of per-mask DICE losses divided by
        ``num_masks`` (plus a tiny constant so ``num_masks == 0`` is safe).
    """
    probs = inputs.sigmoid().flatten(1, 2)
    flat_targets = targets.flatten(1, 2)

    # Scaling is applied before the reduction, exactly as many times as in
    # the reference formulation, so results match it bit-for-bit.
    overlap = (probs / scale * flat_targets).sum(-1)
    total = (probs / scale).sum(-1) + (flat_targets / scale).sum(-1)

    per_mask = 1 - (2 * overlap + eps) / (total + eps)
    return per_mask.sum() / (num_masks + 1e-8)
def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """Per-mask binary cross-entropy on logits, normalised by ``num_masks``.

    Args:
        inputs: Mask logits. Dimensions 1 and 2 are flattened together, so a
            3-D tensor is expected (presumably ``(num_masks, H, W)`` --
            confirm against callers).
        targets: Binary labels with the same shape as ``inputs``
            (0 for the negative class and 1 for the positive class).
        num_masks: Normaliser applied to the summed per-mask loss.

    Returns:
        A scalar tensor: each mask's pixel-averaged BCE, summed over masks
        and divided by ``num_masks`` (plus a tiny constant).
    """
    per_pixel = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none"
    )
    per_mask = per_pixel.flatten(1, 2).mean(1)
    return per_mask.sum() / (num_masks + 1e-8)
def overlap_loss(inputs: torch.Tensor,
                 targets: torch.Tensor,
                 num_masks: float,
                 batch_seg_token_count: torch.Tensor):
    """Penalise pixels claimed by two or more predicted masks of one question.

    The masks are split into consecutive per-question groups. Inside each
    group, binary cross-entropy is evaluated only on pixels where at least
    two predictions have a positive logit; all other pixels get weight 0.

    Args:
        inputs: Mask logits, 3-D (the code unpacks the group as ``n, h, w``).
        targets: Ground-truth masks with the same shape as ``inputs``.
        num_masks: Normaliser for the summed loss. ``0`` short-circuits to a
            zero that is still connected to ``inputs``.
        batch_seg_token_count: 1-D integer tensor holding the number of masks
            belonging to each question, in order. Its cumulative sum must
            not exceed ``len(targets)``.

    Returns:
        A scalar tensor on the same device as ``inputs``.
    """
    if num_masks == 0:
        return inputs.sum() * 0

    # Group boundaries [0, c0, c0 + c1, ...]. The leading zero is allocated
    # on the counts' own device so the function also works on CPU and on a
    # non-default GPU (the previous hard-coded .cuda() did not).
    ends = batch_seg_token_count.cumsum(-1)
    boundaries = torch.cat(
        [torch.zeros(1, dtype=torch.long, device=ends.device), ends], dim=0
    )

    # Start from a tensor so the return type is a tensor even when every
    # group turns out to be empty.
    loss = inputs.new_zeros(())
    for start_i, end_i in zip(boundaries[:-1], boundaries[1:]):
        assert end_i <= len(targets), (targets.shape, boundaries)
        question_inputs = inputs[start_i:end_i]
        question_targets = targets[start_i:end_i]
        if len(question_targets) == 0:
            continue

        # A pixel overlaps when >= 2 masks of this question predict it.
        overlap_area = (question_inputs > 0).sum(dim=0) >= 2
        weight = overlap_area[None].expand_as(question_inputs).to(
            question_inputs.dtype
        )

        q_loss = F.binary_cross_entropy_with_logits(
            question_inputs, question_targets, weight=weight, reduction="none"
        )
        loss = loss + q_loss.flatten(1, 2).mean(1).sum()

    return loss / (num_masks + 1e-8)
class PixDLMMetaModel:
    """Mixin that attaches the PixDLM mask-decoding modules to a backbone.

    The class has no base of its own; ``__init__`` forwards ``config`` to the
    next class in the MRO, so it is meant to be combined with a model class
    (see ``PixDLMModel``).
    """

    def __init__(
        self,
        config,
        **kwargs,
    ):
        """Store settings and, for configs that already describe a trained
        model, build the segmentation modules.

        Keyword Args:
            logger: Optional logger used for build messages.
            local_rank: Process rank; messages are logged on rank 0 only.
            three_level_multi_scale_decoder: Optional override written into
                the config.
            train_mask_decoder, out_dim: Required when the config does not
                yet carry ``train_mask_decoder``.
            vision_pretrained: Passed to ``build_sam_vit_h`` when the SAM
                branch is used.
        """
        super().__init__(config)
        self.logger = kwargs.get("logger", None)
        self.local_rank = kwargs.get("local_rank", 1)
        self.config = config
        self.sam2_config = 'configs/sam2.1/sam2.1_hiera_l.yaml'

        decoder_flag = "three_level_multi_scale_decoder"
        if decoder_flag in kwargs:
            self.config.three_level_multi_scale_decoder = kwargs[decoder_flag]

        # Both configuration paths record the same checkpoint location.
        self.vision_pretrained = kwargs.get("vision_pretrained", None)

        if hasattr(self.config, "train_mask_decoder"):
            # NOTE(review): the module build is treated as part of this
            # branch only (config already carries segmentation settings).
            # Confirm against the original indentation of the file.
            self.initialize_pixdlm_modules(self.config)
        else:
            self.config.train_mask_decoder = kwargs["train_mask_decoder"]
            self.config.out_dim = kwargs["out_dim"]

    def initialize_pixdlm_modules(self, config):
        """Create the mask decoder stack and the text projection head.

        With ``self.config.vision_tower_for_mask`` set, a prompt encoder, a
        multi-scale mask decoder and two small convolutional adapters are
        built. Otherwise a frozen SAM ViT-H is loaded, and its mask decoder
        is unfrozen when ``config.train_mask_decoder`` is true. In both
        cases a trainable MLP ``text_hidden_fcs`` is created.
        """
        if not self.config.vision_tower_for_mask:
            self.visual_model = build_sam_vit_h(self.vision_pretrained)
            self.visual_model.requires_grad_(False)
            if config.train_mask_decoder:
                sam_decoder = self.visual_model.mask_decoder
                sam_decoder.train()
                sam_decoder.requires_grad_(True)
        else:
            prompt_embed_dim = 256
            image_size = config.resize_vision_tower_size
            mask_decoder_transformer_depth = 2

            if self.local_rank == 0 and self.logger is not None:
                self.logger.info('--------build_sam_decoder--------')
                self.logger.info(
                    '--------sam decoder image size {}--------'.format(image_size)
                )

            # The embedding grid is the image side divided by a patch of 14.
            vit_patch_size = 14
            grid = image_size // vit_patch_size
            self.prompt_encoder = PromptEncoder(
                embed_dim=prompt_embed_dim,
                image_embedding_size=(grid, grid),
                input_image_size=(image_size, image_size),
                mask_in_chans=16,
            )

            if getattr(config, "three_level_multi_scale_decoder", False):
                decoder_cls = Three_Level_Multi_Scale_Decoder
            else:
                decoder_cls = MaskDecoderMultiScale
            self.mask_decoder = decoder_cls(
                num_multimask_outputs=3,
                transformer=TwoWayTransformer(
                    depth=mask_decoder_transformer_depth,
                    embedding_dim=prompt_embed_dim,
                    mlp_dim=2048,
                    num_heads=8,
                ),
                transformer_dim=prompt_embed_dim,
                iou_head_depth=3,
                iou_head_hidden_dim=256,
                image_feature_scale_num=config.image_feature_scale_num,
            )

            embed_dim = self.config.hidden_size
            out_chans = prompt_embed_dim

            # hidden_size channels -> 256 channels for the mask decoder.
            neck_layers = [
                nn.Conv2d(embed_dim, out_chans, kernel_size=1, bias=False),
                LayerNorm2d(out_chans),
                nn.Conv2d(
                    out_chans, out_chans, kernel_size=3, padding=1, bias=False
                ),
                LayerNorm2d(out_chans),
            ]
            self.image_feature_neck = nn.Sequential(*neck_layers)

            # 256 channels -> hidden_size channels.
            self.sam_to_embed_conv = nn.Sequential(
                nn.Conv2d(out_chans, embed_dim, kernel_size=1, bias=False),
                LayerNorm2d(embed_dim),
            )

        # Projection from LM hidden states to decoder prompt embeddings.
        in_dim = config.hidden_size
        out_dim = config.out_dim
        projection = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim),
            nn.Dropout(0.0),
        )
        self.text_hidden_fcs = nn.ModuleList([projection])
        self.text_hidden_fcs.train()
        self.text_hidden_fcs.requires_grad_(True)
class PixDLMModel(PixDLMMetaModel, LlavaLlamaModel):
    """LLaVA-LLaMA backbone combined with the PixDLM segmentation mixin."""

    def __init__(
        self,
        config,
        **kwargs,
    ):
        """Build the backbone, then pin the multimodal config options.

        All keyword arguments are forwarded unchanged to the parent
        constructors.
        """
        super().__init__(config, **kwargs)

        cfg = self.config
        cfg.use_cache = False
        cfg.vision_tower = cfg.mm_vision_tower

        # Fixed multimodal settings, applied in this order.
        pinned_options = {
            "mm_vision_select_feature": "patch",
            "image_aspect_ratio": "square",
            "image_grid_pinpoints": None,
            "tune_mm_mlp_adapter": False,
            "freeze_mm_mlp_adapter": True,
            "pretrain_mm_mlp_adapter": None,
            "mm_use_im_patch_token": False,
        }
        for option, value in pinned_options.items():
            setattr(cfg, option, value)
class PixDLMForCausalLM(LlavaLlamaForCausalLM):
    """LLaVA causal language model extended with segmentation-mask decoding.

    Hidden states found at segmentation-token positions are projected by
    ``model.text_hidden_fcs`` and fed as text prompts to a mask decoder.
    ``model_forward`` is the training / teacher-forced inference path;
    ``evaluate`` generates text first and then decodes masks.

    NOTE(review): this file was recovered without indentation. Block nesting
    was reconstructed from the code's semantics; the few places where two
    readings were possible carry their own NOTE(review).
    """

    def __init__(
        self,
        config,
        **kwargs,
    ):
        """Configure the model from ``config`` and keyword options.

        Keyword Args:
            seg_token_idx: Required. Token id, or list of ids, marking
                segmentation tokens.
            seg_token_num: Number of segmentation tokens per target
                (default 1).
            image_feature_scale_num: Number of image-feature scales
                (default 3).
            resize_vision_tower, resize_vision_tower_size,
            pad_train_clip_images, vision_tower_for_mask,
            separate_mm_projector: Copied into ``config``; defaults are
                True / 448 / True / True / False.
            ce_loss_weight, dice_loss_weight, bce_loss_weight: Loss weights.
            logger, local_rank, tokenizer, masks_process_with_clip,
            is_multipath_encoder, freeze_vision,
            three_level_multi_scale_decoder: Optional settings.

        Remaining keyword arguments are forwarded to ``PixDLMModel``.
        """
        # PixDLM defaults; explicitly passed values win.
        kwargs.setdefault("image_feature_scale_num", 3)
        kwargs.setdefault("pad_train_clip_images", True)
        kwargs.setdefault("resize_vision_tower", True)
        kwargs.setdefault("resize_vision_tower_size", 448)
        kwargs.setdefault("vision_tower_for_mask", True)
        kwargs.setdefault("separate_mm_projector", False)

        self.logger = kwargs.get("logger", None)

        config.resize_vision_tower = kwargs["resize_vision_tower"]
        config.resize_vision_tower_size = kwargs["resize_vision_tower_size"]
        config.pad_train_clip_images = kwargs["pad_train_clip_images"]
        config.vision_tower_for_mask = kwargs["vision_tower_for_mask"]
        config.separate_mm_projector = kwargs["separate_mm_projector"]
        config.mm_projector_hidden_dim = 2
        config.mm_projector_out_dim = 1
        self.image_feature_scale_num = kwargs["image_feature_scale_num"]
        config.image_feature_scale_num = self.image_feature_scale_num

        if not hasattr(config, "train_mask_decoder"):
            # Config without segmentation settings: take everything from
            # the keyword arguments.
            config.mm_use_im_start_end = kwargs.pop("use_mm_start_end", True)
            config.mm_vision_tower = kwargs.get(
                "vision_tower", "openai/clip-vit-large-patch14"
            )
            self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
            self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
            self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
        else:
            # NOTE(review): the three weight assignments below are read as
            # belonging to this branch; placed after the if/else they would
            # silently discard the weights popped above.
            config.mm_vision_tower = config.vision_tower
            self.ce_loss_weight = kwargs.pop(
                "ce_loss_weight", getattr(config, "ce_loss_weight", 1.0)
            )
            self.dice_loss_weight = kwargs.pop(
                "dice_loss_weight", getattr(config, "dice_loss_weight", 1.0)
            )
            self.bce_loss_weight = kwargs.pop(
                "bce_loss_weight", getattr(config, "bce_loss_weight", 1.0)
            )

        self.vision_tower_for_mask = kwargs["vision_tower_for_mask"]
        self.seg_token_idx = kwargs.pop("seg_token_idx")
        self.seg_token_num = kwargs.get("seg_token_num", 1)
        self.tokenizer = kwargs.get("tokenizer", None)
        self.local_rank = kwargs.get("local_rank", 1)
        self.pad_train_clip_images = kwargs["pad_train_clip_images"]
        self.masks_process_with_clip = kwargs.get("masks_process_with_clip", False)
        self.is_multipath_encoder = kwargs.get("is_multipath_encoder", True)
        config.is_multipath_encoder = self.is_multipath_encoder
        self.freeze_vision = kwargs.get("freeze_vision", False)
        config.freeze_vision = self.freeze_vision

        # An explicit keyword beats the value stored in the config.
        decoder_override = kwargs.get("three_level_multi_scale_decoder", None)
        if decoder_override is not None:
            self.three_level_multi_scale_decoder = decoder_override
        else:
            self.three_level_multi_scale_decoder = getattr(
                config, "three_level_multi_scale_decoder", False
            )
        config.three_level_multi_scale_decoder = self.three_level_multi_scale_decoder

        # NOTE(review): the scalars below live in plain Python lists that are
        # created before nn.Module.__init__ runs, so they are not registered
        # parameters (not in state_dict(), not moved by .to()). Left as is:
        # registering them would change checkpoint keys.
        if isinstance(self.seg_token_idx, list):
            if self.local_rank == 0 and self.logger is not None:
                self.logger.info("Initialize multi-seg scalar")
            seg_token_num = len(self.seg_token_idx)
            scalar = 1 / seg_token_num
            self.multiseg_scalar = [
                torch.nn.Parameter(torch.ones([]) * scalar)
                for _ in range(seg_token_num)
            ]
        if self.image_feature_scale_num > 1:
            scalar = 1 / self.image_feature_scale_num
            self.multiscale_scalar = [
                torch.nn.Parameter(torch.ones([]) * scalar)
                for _ in range(self.image_feature_scale_num)
            ]

        super().__init__(config)
        self.model = PixDLMModel(config, **kwargs)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()
        self.iter = 0
        self.iter1 = 0

        if config.resize_vision_tower_size != 224:
            if self.local_rank == 0 and self.logger is not None:
                self.logger.info('--------mm_projector requires grad--------')
            for name, param in self.model.named_parameters():
                if "mm_projector" in name:
                    param.requires_grad = True

    def get_visual_embs(self, pixel_values: torch.FloatTensor):
        """Encode images one at a time with the SAM image encoder.

        Runs under ``torch.no_grad()`` and releases cached CUDA memory around
        every image. Returns the per-image embeddings concatenated on dim 0.
        """
        with torch.no_grad():
            image_embeddings_list = []
            for i in range(pixel_values.shape[0]):
                torch.cuda.empty_cache()
                image_embeddings = self.model.visual_model.image_encoder(
                    pixel_values[i].unsqueeze(0)
                )
                image_embeddings_list.append(image_embeddings)
            torch.cuda.empty_cache()
            image_embeddings = torch.cat(image_embeddings_list, 0)
        return image_embeddings

    def forward(self, **kwargs):
        """Dispatch to the plain LM forward when ``past_key_values`` is
        supplied (incremental decoding), otherwise to ``model_forward``."""
        if "past_key_values" in kwargs:
            return super().forward(**kwargs)
        return self.model_forward(**kwargs)

    # ------------------------------------------------------------------
    # Private helpers shared by model_forward() and evaluate().
    # ------------------------------------------------------------------

    def _seg_token_mask(self, token_ids):
        """Boolean mask of segmentation tokens in ``token_ids[:, 1:]``."""
        shifted = token_ids[:, 1:]
        if isinstance(self.seg_token_idx, list):
            mask = torch.zeros_like(shifted).bool()
            for seg_token_idx in self.seg_token_idx:
                mask = mask | (shifted == seg_token_idx)
            return mask
        return shifted == self.seg_token_idx

    def _fuse_seg_embeddings(self, embeddings):
        """Group segmentation-token embeddings per target and fuse them.

        The rows are viewed as ``(targets, feature_scales, seg_tokens, dim)``
        and the seg-token axis is collapsed with ``multiseg_scalar`` weights
        (or simply dropped when there is a single token per target).
        """
        seg_token_num = self.seg_token_num
        feat_scale_num = self.image_feature_scale_num
        group_size = seg_token_num * feat_scale_num
        # Must match the view below, not just seg_token_num.
        assert len(embeddings) % group_size == 0
        grouped = embeddings.view(
            len(embeddings) // group_size,
            feat_scale_num,
            seg_token_num,
            embeddings.shape[-1],
        )
        if seg_token_num > 1:
            fused = grouped[:, :, 0] * 0
            for k in range(seg_token_num):
                fused = fused + self.multiseg_scalar[k] * grouped[:, :, k]
            return fused
        return grouped[:, :, 0]

    def _fit_prompt_to_features(self, dense_embeddings, feat_h, feat_w):
        """Resize dense prompt embeddings and the positional encoding to a
        ``feat_h x feat_w`` feature grid; both are returned untouched when
        they already match."""
        if dense_embeddings.shape[-1] != feat_w or dense_embeddings.shape[-2] != feat_h:
            dense_embeddings = F.interpolate(
                dense_embeddings.float(),
                size=(feat_h, feat_w),
                mode='bilinear',
                align_corners=False,
            ).to(dense_embeddings.dtype)

        image_pe = self.model.prompt_encoder.get_dense_pe()
        default_size = self.model.prompt_encoder.image_embedding_size[0]
        if feat_h != default_size or feat_w != default_size:
            image_pe = F.interpolate(
                image_pe.float(),
                size=(feat_h, feat_w),
                mode='bilinear',
                align_corners=False,
            ).to(image_pe.dtype)
        return dense_embeddings, image_pe

    def _decode_masks(self, img_feats, sparse_embeddings, dense_embeddings,
                      multimask_output):
        """Run the mask decoder over every feature scale.

        With more than one scale, each level receives the previous level's
        masks and the outputs are resized to 128 x 128 and summed with the
        ``multiscale_scalar`` weights. With one scale the decoder output is
        returned directly.
        """
        if self.image_feature_scale_num > 1:
            out_size = 128
            low_res_masks = img_feats.new_zeros(
                [sparse_embeddings.shape[0], 1, out_size, out_size]
            )
            level_masks = None  # nothing to refine at the first level
            for level in range(self.image_feature_scale_num):
                feat_h, feat_w = img_feats[level].shape[1], img_feats[level].shape[2]
                dense_level, image_pe = self._fit_prompt_to_features(
                    dense_embeddings, feat_h, feat_w
                )
                level_masks, _ = self.model.mask_decoder(
                    image_embeddings=img_feats[level].unsqueeze(0),
                    image_pe=image_pe,
                    sparse_prompt_embeddings=sparse_embeddings[:, level].unsqueeze(1),
                    dense_prompt_embeddings=dense_level,
                    multimask_output=multimask_output,
                    previous_masks=level_masks,
                    level_num=level,
                )
                resized = F.interpolate(
                    level_masks.float(),
                    (out_size, out_size),
                    mode="bilinear",
                    align_corners=False,
                ).to(level_masks)
                low_res_masks = low_res_masks + self.multiscale_scalar[level] * resized
            return low_res_masks

        feat_h, feat_w = img_feats[0].shape[1], img_feats[0].shape[2]
        dense_level, image_pe = self._fit_prompt_to_features(
            dense_embeddings, feat_h, feat_w
        )
        low_res_masks, _ = self.model.mask_decoder(
            image_embeddings=img_feats[0].unsqueeze(0),
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_embeddings[:, 0].unsqueeze(1),
            dense_prompt_embeddings=dense_level,
            multimask_output=multimask_output,
        )
        return low_res_masks

    def _sam_multilayer_features(self, images_clip, image_index):
        """Fetch multi-layer vision-tower features for one image, or None.

        Only used when ``three_level_multi_scale_decoder`` is enabled and the
        vision tower offers ``forward_sam_multilayer_features``. ``None``
        tells the caller to fall back to the LM-side image features.
        """
        if not (
            getattr(self, "three_level_multi_scale_decoder", False)
            and hasattr(self.model, "vision_tower")
            and hasattr(self.model.vision_tower, "forward_sam_multilayer_features")
        ):
            return None
        if image_index >= len(images_clip):
            return None

        feats = self.model.vision_tower.forward_sam_multilayer_features(
            images_clip[image_index].unsqueeze(0)
        )
        if not (isinstance(feats, list) and len(feats) > 0):
            return None

        adapted = []
        has_neck = hasattr(self.model, "image_feature_neck")
        for feat in feats[: self.image_feature_scale_num]:
            feat = feat.squeeze(0)
            if has_neck:
                neck_input = self.model.image_feature_neck[0]
                feat = feat.to(neck_input.weight.dtype)
                if feat.shape[0] != neck_input.in_channels:
                    # NOTE(review): this path leaves the feature with
                    # hidden_size channels and does not run the neck
                    # afterwards -- confirm the decoder accepts that.
                    feat = self.model.sam_to_embed_conv(feat.unsqueeze(0)).squeeze(0)
                else:
                    feat = self.model.image_feature_neck(feat.unsqueeze(0)).squeeze(0)
            adapted.append(feat)
        return torch.stack(adapted, dim=0) if adapted else None

    @staticmethod
    def _mask_score(mask_logits):
        """Mean sigmoid probability over the pixels predicted positive."""
        positive = (mask_logits > 0).flatten(1)
        probs = mask_logits.sigmoid().flatten(1)
        return (probs * positive).sum(1) / (positive.sum(1) + 1e-6)

    # ------------------------------------------------------------------

    def model_forward(
        self,
        images: torch.FloatTensor,
        images_clip: torch.FloatTensor,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor,
        attention_masks: torch.LongTensor,
        offset: torch.LongTensor,
        masks_list: List[torch.FloatTensor],
        label_list: List[torch.Tensor],
        resize_list: List[tuple],
        inference: bool = False,
        clip_resize_list=None,
        txt_feat=None,
        **kwargs,
    ):
        """Teacher-forced forward pass producing masks and, in training, losses.

        Args:
            images: Inputs for the SAM encoder; only encoded when
                ``vision_tower_for_mask`` is off. ``images.shape[0]`` is the
                number of images.
            images_clip: Vision-tower inputs, one row per image.
            input_ids, labels, attention_masks: Tokenised conversations, one
                row per conversation.
            offset: ``len(images) + 1`` boundaries; conversations
                ``offset[i]:offset[i + 1]`` belong to image ``i``.
            masks_list: Ground-truth masks per image.
            label_list: Per-image label tensors; only ``.shape`` is used, as
                the output mask size.
            resize_list: Per-image input sizes for the SAM branch.
            inference: Return predictions instead of losses. Requires a
                single image.
            clip_resize_list: Per-image sizes of the vision-tower inputs.
            txt_feat: Forwarded unchanged to the language-model forward.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            With ``inference=True`` a dict with ``pred_masks``, ``gt_masks``,
            ``batch_seg_token_counts`` and ``mask_scores``; otherwise a dict
            with ``loss``, ``ce_loss``, ``mask_bce_loss``, ``mask_dice_loss``
            and ``mask_loss``.
        """
        if not self.vision_tower_for_mask:
            image_embeddings = self.get_visual_embs(images)
        batch_size = images.shape[0]
        assert batch_size == len(offset) - 1

        # Mask over input positions; a trailing False column restores the
        # length lost by the one-token shift.
        seg_token_mask = self._seg_token_mask(input_ids)
        seg_token_mask = torch.cat(
            [seg_token_mask, seg_token_mask.new_zeros((seg_token_mask.shape[0], 1))],
            dim=1,
        )

        if inference:
            length = input_ids.shape[0]
            assert images_clip.shape[0] == 1
            images_clip_extend = images_clip.expand(length, -1, -1, -1).contiguous()
            extend_clip_resize_list = [clip_resize_list[0]] * length
            output_i = super().forward(
                images=images_clip_extend,
                attention_mask=attention_masks[:length],
                input_ids=input_ids[:length],
                output_hidden_states=True,
                clip_resize_list=extend_clip_resize_list,
                txt_feat=txt_feat,
            )
            # presumably a single tensor in eval mode (project-specific
            # forward) -- TODO confirm
            output_hidden_states = [output_i.hidden_states]
            output_image_features = torch.stack(output_i.image_features, dim=0)
            torch.cuda.empty_cache()
            output = None
        else:
            # Repeat every image once per conversation that refers to it.
            # The per-image tensor stays available as ``images_clip``.
            images_clip_list = []
            extend_clip_resize_list = []
            for i in range(len(offset) - 1):
                repeats = int(offset[i + 1] - offset[i])
                images_clip_list.append(
                    images_clip[i].unsqueeze(0).expand(repeats, -1, -1, -1).contiguous()
                )
                extend_clip_resize_list.extend([clip_resize_list[i]] * repeats)
            images_clip_extend = torch.cat(images_clip_list, dim=0)
            output = super().forward(
                images=images_clip_extend,
                attention_mask=attention_masks,
                input_ids=input_ids,
                labels=labels,
                output_hidden_states=True,
                txt_feat=txt_feat,
                clip_resize_list=extend_clip_resize_list,
            )
            output_hidden_states = output.hidden_states
            output_image_features = torch.stack(output.image_features, dim=0)

        # The LM sequence starts with the visual tokens; pad the mask to match.
        # (_last_visual_token_num is expected to be set by the forward above.)
        num_image_tokens = self._last_visual_token_num
        seg_token_mask = torch.cat(
            [
                seg_token_mask.new_zeros((seg_token_mask.shape[0], num_image_tokens)),
                seg_token_mask,
            ],
            dim=1,
        )

        assert len(self.model.text_hidden_fcs) == 1
        last_hidden_state = self.model.text_hidden_fcs[0](output_hidden_states[-1])
        pred_embeddings = last_hidden_state[seg_token_mask]

        # Convert per-conversation token counts into per-image row ranges.
        seg_token_counts = seg_token_mask.int().sum(-1)
        seg_token_offset = seg_token_counts.cumsum(-1)
        seg_token_offset = torch.cat(
            [seg_token_offset.new_zeros(1), seg_token_offset], dim=0
        )
        seg_token_offset = seg_token_offset[offset]

        tokens_per_mask = self.seg_token_num * self.image_feature_scale_num
        pred_embeddings_per_image = []
        batch_seg_token_counts = []
        for i in range(len(seg_token_offset) - 1):
            start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1]
            # Masks per conversation. The divisor is parenthesised: the
            # former "// a*b" evaluated as "(x // a) * b".
            batch_seg_token_counts.append(
                seg_token_counts[offset[i]:offset[i + 1]] // tokens_per_mask
            )
            pred_embeddings_per_image.append(
                self._fuse_seg_embeddings(pred_embeddings[start_i:end_i])
            )
        pred_embeddings = pred_embeddings_per_image

        # Image features of the first conversation of every image, reshaped
        # from (scales, tokens, channels) to (scales, channels, side, side).
        multi_scale_num = output_image_features.shape[0]
        patch_count = output_image_features.shape[2]
        patch_size = int(patch_count ** 0.5)
        channels = output_image_features.shape[-1]
        single_img_embeddings = []
        for i in range(len(offset) - 1):
            first, last = int(offset[i]), int(offset[i + 1])
            per_conversation = output_image_features[:, first:last]
            single_img_embeddings.append(
                per_conversation.permute(0, 1, 3, 2).view(
                    multi_scale_num, last - first, channels, patch_size, patch_size
                )[:, 0]
            )

        multimask_output = False
        pred_masks = []
        mask_scores = []
        for i in range(len(pred_embeddings)):
            if self.vision_tower_for_mask:
                sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                    points=None,
                    boxes=None,
                    masks=None,
                    text_embeds=pred_embeddings[i],
                )
                sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)

                # Indexed with the per-image tensor: the expanded one holds a
                # row per conversation, so row i is not image i there.
                _img_embeddings = self._sam_multilayer_features(images_clip, i)
                if _img_embeddings is None:
                    _img_embeddings = self.model.image_feature_neck(
                        single_img_embeddings[i]
                    )

                low_res_masks = self._decode_masks(
                    _img_embeddings, sparse_embeddings, dense_embeddings,
                    multimask_output,
                )
                pred_mask = self.postprocess_masks(
                    low_res_masks,
                    input_size=clip_resize_list[i],
                    original_size=label_list[i].shape,
                )
            else:
                (
                    sparse_embeddings,
                    dense_embeddings,
                ) = self.model.visual_model.prompt_encoder(
                    points=None,
                    boxes=None,
                    masks=None,
                    text_embeds=pred_embeddings[i],
                )
                sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
                low_res_masks, _ = self.model.visual_model.mask_decoder(
                    image_embeddings=image_embeddings[i].unsqueeze(0),
                    image_pe=self.model.visual_model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=multimask_output,
                )
                pred_mask = self.model.visual_model.postprocess_masks(
                    low_res_masks,
                    input_size=resize_list[i],
                    original_size=label_list[i].shape,
                )
            pred_masks.append(pred_mask[:, 0])
            mask_scores.append(self._mask_score(pred_mask[:, 0]))

        gt_masks = masks_list
        if inference:
            return {
                "pred_masks": pred_masks,
                "gt_masks": gt_masks,
                "batch_seg_token_counts": batch_seg_token_counts,
                "mask_scores": mask_scores,
            }

        ce_loss = output.loss * self.ce_loss_weight
        loss = ce_loss

        # Zeros that stay attached to the graph even when no mask exists.
        mask_bce_loss = pred_masks[0].sum() * 0
        mask_dice_loss = pred_masks[0].sum() * 0
        num_masks = 0
        for batch_idx, pred_mask in enumerate(pred_masks):
            gt_mask = gt_masks[batch_idx]
            assert (
                gt_mask.shape[0] == pred_mask.shape[0]
            ), "gt_mask.shape: {}, pred_mask.shape: {}".format(
                gt_mask.shape, pred_mask.shape
            )
            # Each helper averages over its masks; multiply back so the final
            # division by the global mask count weights every mask equally.
            mask_bce_loss += (
                sigmoid_ce_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
                * gt_mask.shape[0]
            )
            mask_dice_loss += (
                dice_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
                * gt_mask.shape[0]
            )
            num_masks += gt_mask.shape[0]

        mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
        mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
        mask_loss = mask_bce_loss + mask_dice_loss
        loss = loss + mask_loss

        return {
            "loss": loss,
            "ce_loss": ce_loss,
            "mask_bce_loss": mask_bce_loss,
            "mask_dice_loss": mask_dice_loss,
            "mask_loss": mask_loss,
        }

    def evaluate(
        self,
        images_clip,
        images,
        input_ids,
        resize_list,
        clip_resize_list,
        original_size_list,
        max_new_tokens=32,
        tokenizer=None,
        txt_feat=None
    ):
        """Generate text for every prompt, then decode masks for one image.

        Args:
            images_clip: Vision-tower input. The feature reshape below only
                succeeds for a single image.
            images: SAM-encoder input, used when ``vision_tower_for_mask`` is
                off.
            input_ids: Iterable of 1-D prompt id tensors. A prompt is cut at
                its first id ``0`` (presumably padding -- confirm).
            resize_list, clip_resize_list, original_size_list: Sizes handed
                to mask post-processing.
            max_new_tokens: Generation budget per prompt.
            tokenizer: Unused; kept for call compatibility.
            txt_feat: Forwarded to the image-feature extraction only.

        Returns:
            ``(all_output_ids, pred_masks, batch_seg_token_counts,
            mask_scores)``.

        Timing statistics are printed to stdout. They are wall-clock values
        taken without device synchronisation.
        """
        import time

        # NOTE(review): IMAGE_TOKEN_INDEX was used here without ever being
        # imported in this module (NameError on every call). It is assumed
        # to live next to the other image-token constants -- confirm.
        from utils.utils import IMAGE_TOKEN_INDEX

        all_pred_embeddings = []
        all_output_ids = []
        batch_seg_token_counts = []

        total_start_time = time.time()
        text_gen_time = 0
        embedding_process_time = 0
        num_images = len(input_ids)

        with torch.no_grad():
            encoding_start = time.time()
            # Image-token-only dummy ids: only the returned image features
            # are of interest.
            dummy_ids = torch.full(
                (images_clip.shape[0], 2),
                IMAGE_TOKEN_INDEX,
                dtype=torch.long,
                device=images_clip.device,
            )
            output_image_features, _, _, _, _, _ = self.prepare_inputs_labels_for_multimodal(
                input_ids=dummy_ids,
                attention_mask=None,
                past_key_values=None,
                labels=None,
                images=images_clip,
                clip_resize_list=clip_resize_list,
                txt_feat=txt_feat,
            )
            multi_scale_num = self.image_feature_scale_num
            output_image_features = torch.stack(output_image_features, dim=0)
            image_encoding_time = time.time() - encoding_start

            tokens_per_mask = self.seg_token_num * self.image_feature_scale_num
            for input_id in input_ids:
                if 0 in input_id:
                    unk_start = torch.where(input_id == 0)[0].min()
                    _input_id = input_id[:unk_start]
                else:
                    _input_id = input_id

                gen_start = time.time()
                outputs = self.generate(
                    images=images_clip,
                    input_ids=_input_id[None],
                    max_new_tokens=max_new_tokens,
                    num_beams=1,
                    output_hidden_states=True,
                    return_dict_in_generate=True,
                    clip_resize_list=clip_resize_list
                )
                text_gen_time += (time.time() - gen_start)

                embed_start = time.time()
                output_hidden_states = outputs.hidden_states[-1]
                output_ids = outputs.sequences
                all_output_ids.append(output_ids)

                seg_token_mask = self._seg_token_mask(output_ids)
                num_image_tokens = self._last_visual_token_num
                seg_token_mask = torch.cat(
                    [
                        seg_token_mask.new_zeros(
                            (seg_token_mask.shape[0], num_image_tokens)
                        ),
                        seg_token_mask,
                    ],
                    dim=1,
                )

                assert len(self.model.text_hidden_fcs) == 1
                last_hidden_state = self.model.text_hidden_fcs[0](output_hidden_states)
                pred_embeddings = last_hidden_state[seg_token_mask]
                if len(pred_embeddings) % tokens_per_mask != 0:
                    # Incomplete token group generated: predict nothing for
                    # this prompt.
                    seg_token_mask = torch.zeros_like(seg_token_mask)
                    pred_embeddings = last_hidden_state[seg_token_mask]
                seg_token_counts = seg_token_mask.int().sum(-1)

                fused = self._fuse_seg_embeddings(pred_embeddings)
                batch_seg_token_counts.append(len(fused))
                all_pred_embeddings.append(fused)
                embedding_process_time += (time.time() - embed_start)

            batch_seg_token_counts = [
                torch.tensor(batch_seg_token_counts).to(seg_token_counts)
            ]
            pred_embeddings = [torch.cat(all_pred_embeddings)]

            mask_start = time.time()
            multimask_output = False
            pred_masks = []
            mask_scores = []
            if not self.vision_tower_for_mask:
                image_embeddings = self.get_visual_embs(images)
            else:
                # Grid side derived from the token count (1024 tokens ->
                # 32 x 32 for the default 448-pixel tower with 14-pixel
                # patches) instead of being hard-coded.
                token_count = output_image_features.shape[2]
                side = int(token_count ** 0.5)
                channels = output_image_features.shape[-1]
                flat_features = output_image_features.flatten(1, 2)
                img_embeddings = [
                    flat_features.view(multi_scale_num, token_count, channels)
                    .permute(0, 2, 1)
                    .view(multi_scale_num, channels, side, side)
                ]

            for i in range(len(pred_embeddings)):
                if self.vision_tower_for_mask:
                    sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                        points=None,
                        boxes=None,
                        masks=None,
                        text_embeds=pred_embeddings[i],
                    )
                    sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
                    _img_embeddings = self.model.image_feature_neck(img_embeddings[i])
                    low_res_masks = self._decode_masks(
                        _img_embeddings, sparse_embeddings, dense_embeddings,
                        multimask_output,
                    )
                    pred_mask = self.postprocess_masks(
                        low_res_masks,
                        input_size=clip_resize_list[i],
                        original_size=original_size_list[i],
                    )
                else:
                    (
                        sparse_embeddings,
                        dense_embeddings,
                    ) = self.model.visual_model.prompt_encoder(
                        points=None,
                        boxes=None,
                        masks=None,
                        text_embeds=pred_embeddings[i],
                    )
                    sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
                    low_res_masks, _ = self.model.visual_model.mask_decoder(
                        image_embeddings=image_embeddings[i].unsqueeze(0),
                        image_pe=self.model.visual_model.prompt_encoder.get_dense_pe(),
                        sparse_prompt_embeddings=sparse_embeddings,
                        dense_prompt_embeddings=dense_embeddings,
                        multimask_output=multimask_output,
                    )
                    pred_mask = self.model.visual_model.postprocess_masks(
                        low_res_masks,
                        input_size=resize_list[i],
                        original_size=original_size_list[i],
                    )
                pred_masks.append(pred_mask[:, 0])
                mask_scores.append(self._mask_score(pred_mask[:, 0]))
            mask_gen_time = time.time() - mask_start

            total_time = time.time() - total_start_time
            avg_time = total_time / num_images
            fps = num_images / total_time
            print(f"\n{'='*50}")
            print("Inference Speed Statistics")
            print(f"{'='*50}")
            print(f"Total images: {num_images}")
            print(f"Total time: {total_time:.4f}s")
            print(f" - Image encoding: {image_encoding_time:.4f}s ({image_encoding_time/total_time*100:.1f}%)")
            print(f" - Text generation: {text_gen_time:.4f}s ({text_gen_time/total_time*100:.1f}%)")
            print(f" - Embedding processing: {embedding_process_time:.4f}s ({embedding_process_time/total_time*100:.1f}%)")
            print(f" - Mask generation: {mask_gen_time:.4f}s ({mask_gen_time/total_time*100:.1f}%)")
            print(f"Average time per image: {avg_time:.4f}s")
            print(f"FPS: {fps:.2f}")
            print(f"{'='*50}\n")
        return all_output_ids, pred_masks, batch_seg_token_counts, mask_scores

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format and the dtype of
            the input. (H, W) is given by original_size unless
            ``masks_process_with_clip`` is set, in which case the padded
            square is returned as is.
        """
        target_size = max(input_size)
        dtype = masks.dtype
        if self.vision_tower_for_mask:
            # Upsample to the padded square the vision tower was fed.
            masks = F.interpolate(
                masks.float(),
                (target_size, target_size),
                mode="bilinear",
                align_corners=False,
            )
        if not self.masks_process_with_clip:
            # NOTE(review): cropping and the final resize are read as one
            # block guarded by this flag -- confirm against the original
            # indentation. (The former asserts input_size[k] <= target_size
            # were always true by construction and have been dropped.)
            masks = masks[..., : input_size[0], : input_size[1]]
            masks = F.interpolate(
                masks, original_size, mode="bilinear", align_corners=False
            )
        masks = masks.to(dtype)
        return masks