from typing import List
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F

from ChatUniVi.model.language_model.llama import ChatUniViLlamaForCausalLM, ChatUniViLlamaModel
from models.segment_anything import build_sam_vit_h

def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    scale: float = 1000,
    eps: float = 1e-6,
):
    """
    Compute the DICE loss, similar to generalized IoU for masks.

    Args:
        inputs: A float tensor of shape [N, H, W] holding the mask logits for
            each example.
        targets: A float tensor with the same shape as inputs. Stores the
            binary classification label for each element in inputs (0 for the
            negative class and 1 for the positive class).
        num_masks: Number of masks used to normalize the summed loss.
        scale: Rescaling factor applied inside the ratio for numerical stability.
        eps: Small constant that avoids division by zero.
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1, 2)
    targets = targets.flatten(1, 2)
    numerator = 2 * (inputs / scale * targets).sum(-1)
    denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1)
    loss = 1 - (numerator + eps) / (denominator + eps)
    loss = loss.sum() / (num_masks + 1e-8)
    return loss

def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Args:
        inputs: A float tensor of shape [N, H, W] holding the mask logits for
            each example.
        targets: A float tensor with the same shape as inputs. Stores the
            binary classification label for each element in inputs (0 for the
            negative class and 1 for the positive class).
        num_masks: Number of masks used to normalize the summed loss.
    Returns:
        Loss tensor
    """
    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    loss = loss.flatten(1, 2).mean(1).sum() / (num_masks + 1e-8)
    return loss

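# Illustrative usage of the two mask losses (a minimal sketch, not part of the
# training code): both expect per-mask logits of shape [N, H, W] and normalize
# by the number of masks, so per-batch results can be summed before dividing.
#
#   >>> pred = torch.randn(3, 64, 64)                  # logits for 3 masks
#   >>> gt = torch.randint(0, 2, (3, 64, 64)).float()  # binary targets
#   >>> dice_loss(pred, gt, num_masks=3)               # scalar tensor
#   >>> sigmoid_ce_loss(pred, gt, num_masks=3)         # scalar tensor
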
def compute_alignment_loss(q: torch.Tensor, pos_feats: list, neg_feats: list, temperature: float = 0.07):
    """
    InfoNCE-style alignment loss between [SEG] token embeddings and memory features.

    Args:
        q: [B, D] embedding of the output SEG token.
        pos_feats: per-sample lists of positive embeddings (same object).
        neg_feats: per-sample lists of negative embeddings (other objects).
        temperature: softmax temperature.
    """
    B, D = q.shape
    device = q.device
    total_loss = 0.0
    count = 0

    for i in range(B):
        pos = pos_feats[i]
        neg = neg_feats[i]

        # Skip samples that have no positives in the memory bank yet.
        if len(pos) == 0:
            continue

        anchor = F.normalize(q[i].unsqueeze(0), dim=1)                      # [1, D]
        pos_tensors = F.normalize(torch.stack(pos).to(device), dim=1)       # [P, D]

        # Softmax over positives and negatives together; with no negatives,
        # the positives alone form the candidate set.
        if len(neg) > 0:
            neg_tensors = F.normalize(torch.stack(neg).to(device), dim=1)   # [N, D]
            candidates = torch.cat([pos_tensors, neg_tensors], dim=0)
        else:
            candidates = pos_tensors

        sim = torch.matmul(anchor, candidates.T) / temperature              # [1, P + N]
        log_probs = F.log_softmax(sim, dim=1)
        # Pull the anchor toward its positives, push it from the negatives.
        loss = -log_probs[:, : len(pos)].mean()
        total_loss += loss
        count += 1

    if count == 0:
        return torch.tensor(0.0, device=device, requires_grad=True)

    return total_loss / count

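# Illustrative usage (a sketch with toy dimensions): one anchor embedding per
# batch element, plus per-element lists of positive and negative features. An
# element with no positives is skipped; one with no negatives contrasts its
# positives against each other only.
#
#   >>> q = torch.randn(2, 256)
#   >>> pos = [[torch.randn(256)], [torch.randn(256), torch.randn(256)]]
#   >>> neg = [[torch.randn(256), torch.randn(256)], []]
#   >>> compute_alignment_loss(q, pos, neg)            # scalar tensor
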
class Simtoken_MetaModel:
    def __init__(
        self,
        config,
        **kwargs,
    ):
        super(Simtoken_MetaModel, self).__init__(config)

        self.config = config
        if not hasattr(self.config, "train_mask_decoder"):
            # First-time initialization: record the segmentation settings on
            # the config; the SAM modules are built later by the caller.
            self.config.train_mask_decoder = kwargs["train_mask_decoder"]
            self.config.out_dim = kwargs["out_dim"]
            self.vision_pretrained = kwargs.get("vision_pretrained", None)
        else:
            # The config already carries the segmentation settings (e.g. when
            # loading a trained checkpoint), so build the SAM modules now.
            self.vision_pretrained = kwargs.get("vision_pretrained", None)
            self.initialize_lisa_modules(self.config)

    def initialize_lisa_modules(self, config):
        # SAM backbone: freeze everything, then optionally train the decoder.
        self.visual_model = build_sam_vit_h(self.vision_pretrained)
        for param in self.visual_model.parameters():
            param.requires_grad = False
        if config.train_mask_decoder:
            self.visual_model.mask_decoder.train()
            for param in self.visual_model.mask_decoder.parameters():
                param.requires_grad = True

        # Trainable projection head mapping LLM hidden states to the SAM
        # prompt embedding space.
        in_dim = config.hidden_size
        out_dim = config.out_dim
        text_fc = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim),
            nn.Dropout(0.0),
        ]
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
        self.text_hidden_fcs.train()
        for param in self.text_hidden_fcs.parameters():
            param.requires_grad = True

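# The projection head maps LLM hidden states into SAM's prompt embedding
# space. A minimal sketch, assuming an initialized model and typical sizes
# (4096 for the LLM hidden size, 256 for SAM's prompt dimension):
#
#   >>> h = torch.randn(3, 4096)                  # three [SEG] hidden states
#   >>> model.text_hidden_fcs[0](h).shape         # -> torch.Size([3, 256])
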
class Simtoken_Model(Simtoken_MetaModel, ChatUniViLlamaModel):
    def __init__(
        self,
        config,
        **kwargs,
    ):
        super(Simtoken_Model, self).__init__(config, **kwargs)

        # Fixed multimodal settings: reuse the pretrained CLIP tower and keep
        # the multimodal MLP adapter frozen.
        self.config.use_cache = False
        self.config.vision_tower = self.config.mm_vision_tower
        self.config.mm_vision_select_feature = "patch"
        self.config.image_aspect_ratio = "square"
        self.config.image_grid_pinpoints = None
        self.config.tune_mm_mlp_adapter = False
        self.config.freeze_mm_mlp_adapter = True
        self.config.pretrain_mm_mlp_adapter = None
        self.config.mm_use_im_patch_token = False

class SemanticMemoryBank:
    """Bounded store of [SEG] embeddings keyed by video id (vid) and object id (fid)."""

    def __init__(self, max_per_object=5):
        self.bank = defaultdict(lambda: defaultdict(list))
        self.max_per_object = max_per_object

    def add(self, vid: str, fid: int, feat: torch.Tensor):
        # Store detached on CPU so the bank never retains the autograd graph;
        # keep only the most recent max_per_object entries per key.
        feat = feat.detach().cpu()
        self.bank[vid][fid].append(feat)
        if len(self.bank[vid][fid]) > self.max_per_object:
            self.bank[vid][fid] = self.bank[vid][fid][-self.max_per_object:]

    def add_batch(self, vids: list, fids: list, feats: torch.Tensor):
        for vid, fid, feat in zip(vids, fids, feats):
            self.add(vid, int(fid), feat)

    def get_positive_features(self, vids: list, fids: list):
        # Positives: stored embeddings of the same object in the same video.
        results = []
        for vid, fid in zip(vids, fids):
            pos = self.bank[vid][int(fid)].copy()
            results.append(pos)
        return results

    def get_negative_features_same_vid(self, vids: list, fids: list):
        # Negatives: embeddings of the other objects from the same video.
        results = []
        for vid, fid in zip(vids, fids):
            neg = []
            for other_fid, feats in self.bank[vid].items():
                if other_fid != int(fid):
                    neg.extend(feats)
            results.append(neg)
        return results

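# Illustrative usage (a minimal sketch): embeddings are keyed by (vid, fid);
# positives are read back from the same key, negatives from the other fids of
# the same video.
#
#   >>> bank = SemanticMemoryBank(max_per_object=5)
#   >>> bank.add_batch(["v1", "v1"], [0, 1], torch.randn(2, 256))
#   >>> bank.get_positive_features(["v1"], [0])            # [[tensor of dim 256]]
#   >>> bank.get_negative_features_same_vid(["v1"], [0])   # features stored under fid 1
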
class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
    def __init__(
        self,
        config,
        **kwargs,
    ):
        if not hasattr(config, "train_mask_decoder"):
            # Training from a base checkpoint: pull the loss weights and
            # vision settings out of kwargs.
            config.mm_use_im_start_end = kwargs.pop("use_mm_start_end", True)
            config.mm_vision_tower = kwargs.get("vision_tower", "openai/clip-vit-large-patch14")
            self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
            self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
            self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
        else:
            config.mm_vision_tower = config.vision_tower

        self.seg_token_idx = kwargs.pop("seg_token_idx")

        super().__init__(config)

        self.model = Simtoken_Model(config, **kwargs)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

        # Projects raw audio features (128-d) into the LLM embedding space (4096-d).
        self.audio_feature_layer = nn.Linear(in_features=128, out_features=4096)

        # Cross-frame semantic memory backing the contrastive alignment loss.
        self.memory = SemanticMemoryBank()

        self.compress = kwargs.pop("compress", True)

        # Epoch at which the contrastive loss is switched on.
        self.start = kwargs.pop("start")

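    # Illustrative construction (a sketch; the path and weight values are
    # hypothetical, following the usual HF from_pretrained(**model_args)
    # pattern in which extra kwargs are forwarded to __init__):
    #
    #   >>> model = Simtoken_ForCausalLM.from_pretrained(
    #   ...     "path/to/chatunivi-checkpoint",     # hypothetical checkpoint
    #   ...     train_mask_decoder=True,
    #   ...     out_dim=256,
    #   ...     ce_loss_weight=1.0,
    #   ...     dice_loss_weight=0.5,
    #   ...     bce_loss_weight=2.0,
    #   ...     vision_pretrained="sam_vit_h_4b8939.pth",
    #   ...     seg_token_idx=seg_token_idx,        # tokenizer id of the [SEG] token
    #   ...     start=0,
    #   ... )
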
    def get_visual_embs(self, pixel_values: torch.FloatTensor):
        # The SAM image encoder stays frozen, so no gradients are needed.
        with torch.no_grad():
            image_embeddings = self.model.visual_model.image_encoder(pixel_values)
        return image_embeddings

    def forward(self, **kwargs):
        # Incremental generation steps carry past_key_values and go through
        # the plain language-model forward; full training/inference passes go
        # through model_forward.
        if "past_key_values" in kwargs:
            return super().forward(**kwargs)
        return self.model_forward(**kwargs)

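    # Illustrative training call (a sketch; the tensor shapes are assumptions):
    #
    #   >>> out = model(
    #   ...     images=images,                  # SAM pixel inputs per frame
    #   ...     images_clip=images_clip,        # CLIP inputs for the LLM
    #   ...     audio_features=audio_features,  # list of [T, 128] audio features
    #   ...     image_features=image_features,  # precomputed SAM image embeddings
    #   ...     input_ids=input_ids, labels=labels, attention_masks=attention_masks,
    #   ...     masks_list=masks_list, resize_list=resize_list, orgsize_list=orgsize_list,
    #   ...     conversation_list=conversations, ref_ids=ref_ids, refs_num=refs_num,
    #   ...     vids=vids, fids=fids, epoch=epoch, contrast=0.1,
    #   ... )
    #   >>> out["loss"].backward()
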
    def model_forward(
        self,
        images: torch.FloatTensor,
        images_clip: torch.FloatTensor,
        audio_features: torch.FloatTensor,
        image_features: torch.FloatTensor,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor,
        attention_masks: torch.LongTensor,
        masks_list: List[torch.FloatTensor],
        resize_list: List[tuple],
        orgsize_list: List[tuple],
        conversation_list: List[str],
        ref_ids: List[torch.LongTensor],
        refs_num: List[int],
        vids,
        fids,
        epoch: int = 0,
        inference: bool = False,
        num_frames: int = 10,
        contrast: float = 0.0,
        **kwargs,
    ):
        batch_size = len(images)

        # Precomputed SAM image embeddings for every frame in the batch.
        image_embeddings = torch.cat(image_features, dim=0)

        # Project the audio features into the LLM embedding space.
        audio_embeddings = self.audio_feature_layer(torch.stack(audio_features, dim=0))

        # Target frame selection. A random training-time frame
        # (random.randint(0, num_frames - 1)) could be sampled here, but both
        # the training and inference paths currently use the middle frame.
        target_frame = 5

        input_ids, attention_masks, past_key_values, inputs_embeds, labels = super().prepare_inputs_labels_for_multimodal(
            input_ids, attention_masks, past_key_values=None, labels=labels, images=images_clip,
            audio_features=audio_embeddings, target_frame=target_frame, ref_ids=ref_ids,
        )

        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_masks,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_hidden_states=True,
        )
        output_hidden_states = output.hidden_states

        # Locate the [SEG] tokens. Labels are compared one position ahead so
        # the mask selects the hidden state that predicts each [SEG] token; a
        # zero column is appended to restore the original sequence length.
        seg_token_mask = output.labels[..., 1:] == self.seg_token_idx
        seg_token_mask = torch.cat(
            [seg_token_mask, torch.zeros((seg_token_mask.shape[0], 1), device=output.labels.device).bool()],
            dim=1,
        )

        # Project the [SEG] hidden states into the SAM prompt space.
        seg_embeddings = self.model.text_hidden_fcs[0](output_hidden_states[-1][seg_token_mask])

        # One object id per sample.
        fids_flat = [fid[0] for fid in fids]

        if not inference:
            pos_feats = self.memory.get_positive_features(vids, fids_flat)
            neg_feats = self.memory.get_negative_features_same_vid(vids, fids_flat)

            # The other [SEG] embeddings in the batch also act as negatives.
            for i in range(len(neg_feats)):
                for j in range(len(seg_embeddings)):
                    if j != i:
                        neg_feats[i].append(seg_embeddings[j].detach().cpu())

            ct_loss = compute_alignment_loss(seg_embeddings, pos_feats, neg_feats)

            # Update the bank after the loss so a sample is never contrasted
            # against its own freshly added embedding.
            self.memory.add_batch(vids, fids_flat, seg_embeddings)

        # Regroup the [SEG] embeddings per sample; a sample may contain
        # several referents.
        pred_embeddings = []
        pred_idx = 0
        for ref_num in refs_num:
            pred_embeddings.append(seg_embeddings[pred_idx:pred_idx + ref_num])
            pred_idx += ref_num

        pred_masks = []

        for i in range(batch_size):
            # Encode the [SEG] embeddings as sparse/dense SAM prompts.
            (
                sparse_embeddings,
                dense_embeddings,
            ) = self.model.visual_model.prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
                text_embeds=pred_embeddings[i].unsqueeze(1),
            )
            sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
            dense_embeddings = dense_embeddings.to(pred_embeddings[i].dtype)

            # Decode one mask per prompt across all frames of this sample.
            pred_masks_sample = []
            for prompt_idx in range(len(sparse_embeddings)):
                low_res_masks, iou_predictions = self.model.visual_model.mask_decoder(
                    image_embeddings=image_embeddings[i * num_frames: (i + 1) * num_frames],
                    image_pe=self.model.visual_model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings[prompt_idx: prompt_idx + 1],
                    dense_prompt_embeddings=dense_embeddings[prompt_idx: prompt_idx + 1],
                    multimask_output=False,
                )

                # Upsample to the original image resolution.
                pred_mask = self.model.visual_model.postprocess_masks(
                    low_res_masks,
                    input_size=resize_list[i],
                    original_size=orgsize_list[i],
                )
                pred_masks_sample.append(pred_mask.squeeze(1))

            # [num_refs, num_frames, H, W]
            pred_masks_sample = torch.stack(pred_masks_sample, dim=0)
            pred_masks.append(pred_masks_sample)

        gt_masks = masks_list

        if inference:
            return {
                "pred_masks": pred_masks,
                "gt_masks": gt_masks,
            }

        # Language-modeling loss from the wrapped forward pass.
        ce_loss = output.loss * self.ce_loss_weight

        mask_bce_loss = 0
        mask_dice_loss = 0
        num_masks = 0

        # Accumulate the mask losses over every (referent, frame) mask.
        for batch_idx in range(batch_size):
            gt_mask = gt_masks[batch_idx]
            pred_mask = pred_masks[batch_idx]

            # Flatten [num_refs, num_frames, H, W] -> [num_refs * num_frames, H, W].
            a, b, c, d = gt_mask.shape
            gt_mask = gt_mask.view(a * b, c, d)
            pred_mask = pred_mask.view(a * b, c, d)

            mask_bce_loss += (
                sigmoid_ce_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
                * gt_mask.shape[0]
            )
            mask_dice_loss += (
                dice_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
                * gt_mask.shape[0]
            )
            num_masks += gt_mask.shape[0]

        mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
        mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
        mask_loss = mask_bce_loss + mask_dice_loss

        ct_weight = contrast

        # The contrastive alignment loss only contributes once the memory bank
        # has warmed up for `start` epochs.
        if epoch >= self.start:
            loss = ce_loss + mask_loss + ct_weight * ct_loss
        else:
            loss = ce_loss + mask_loss

        return {
            "loss": loss,
            "ce_loss": ce_loss,
            "mask_bce_loss": mask_bce_loss,
            "mask_dice_loss": mask_dice_loss,
            "mask_loss": mask_loss,
            "ct_loss": ct_loss,
            "pred_masks": pred_masks,
            "gt_masks": gt_masks,
        }

    def evaluate(self, *args, **kwargs):
        raise NotImplementedError("This method is not implemented.")