"""EC-SimToken: SimToken + Existence Head for null detection.

Architecture additions over Simtoken_ForCausalLM:
- existence_head: Linear(out_dim, 1) -> sigmoid -> p(object exists)
- BCE existence loss on synthetic null samples (audio-swapped during training)
- Mask loss gated: null-augmented samples skip mask loss

Null augmentation is done in the training script (audio swap), not here.
This module just accepts an optional `is_null` bool tensor per batch.
"""
|
|
| from __future__ import annotations |
|
|
| from typing import List |
|
|
| import random |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from models.avs_model import ( |
| Simtoken_ForCausalLM, |
| dice_loss, |
| sigmoid_ce_loss, |
| compute_alignment_loss, |
| ) |
|
|
|
|
class ECSimtoken_ForCausalLM(Simtoken_ForCausalLM):
    """SimToken with an existence head for null-sample detection.

    Adds a single linear probe (``existence_head``) on top of the projected
    <seg>-token embeddings, trained with BCE so that sigmoid(logit) estimates
    p(the referred object actually exists). Samples flagged via ``is_null``
    are excluded from the mask loss and from the feature memory.

    Extra kwargs (consumed here, not passed to parent):
        exist_loss_weight: float BCE existence loss weight (default 1.0)
    """

    def __init__(self, config, **kwargs):
        # Pop our extra kwarg BEFORE delegating so the parent never sees it.
        self.exist_loss_weight = kwargs.pop("exist_loss_weight", 1.0)
        super().__init__(config, **kwargs)
        # One raw logit per projected seg-token embedding; sigmoid is applied
        # implicitly by binary_cross_entropy_with_logits in the loss below.
        out_dim = config.out_dim
        self.existence_head = nn.Linear(out_dim, 1)

    def model_forward(
        self,
        images: torch.FloatTensor,
        images_clip: torch.FloatTensor,
        audio_features: torch.FloatTensor,
        image_features: torch.FloatTensor,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor,
        attention_masks: torch.LongTensor,
        masks_list: List[torch.FloatTensor],
        resize_list: List[tuple],
        orgsize_list: List[tuple],
        conversation_list: List[str],
        ref_ids: List[torch.LongTensor],
        refs_num: List[int],
        vids,
        fids,
        epoch: int = 0,
        inference: bool = False,
        num_frames: int = 10,
        contrast: float = 0.0,
        is_null: torch.BoolTensor | None = None,
        **kwargs,
    ):
        """Full AVS forward pass with existence supervision.

        Args:
            images / images_clip: per-sample frame tensors (SAM / CLIP views).
            audio_features: per-sample audio features; stacked and projected
                through ``self.audio_feature_layer``.
            image_features: per-sample SAM image embeddings; concatenated so
                sample ``i`` occupies rows ``[i*num_frames, (i+1)*num_frames)``.
            refs_num: number of referred objects (seg tokens) per sample,
                used to split the flat seg-embedding tensor back per sample.
            is_null: optional per-sample bool flags; True = synthetic null
                sample (audio-swapped) -> skip mask loss and memory writes,
                existence target 0.
            contrast: weight for the memory-based contrastive alignment loss
                (disabled when 0.0 or at inference).

        Returns:
            At inference: dict with pred_masks / gt_masks / exist_logit.
            In training: dict additionally containing the total ``loss`` and
            its components (ce / mask bce / mask dice / exist / contrastive).
        """
        batch_size = len(images)
        image_embeddings = torch.cat(image_features, dim=0)

        audio_embeddings = self.audio_feature_layer(
            torch.stack(audio_features, dim=0)
        )

        # NOTE(review): hard-coded frame count/index handed to the multimodal
        # packer — presumably tied to the default num_frames=10 clip; confirm
        # against prepare_inputs_labels_for_multimodal.
        target_frame = 5

        # super(Simtoken_ForCausalLM, self): deliberately SKIP the immediate
        # parent's override and dispatch to the grandparent implementation.
        (
            input_ids_mm,
            attention_masks_mm,
            past_key_values,
            inputs_embeds,
            labels_mm,
        ) = super(Simtoken_ForCausalLM, self).prepare_inputs_labels_for_multimodal(
            input_ids,
            attention_masks,
            past_key_values=None,
            labels=labels,
            images=images_clip,
            audio_features=audio_embeddings,
            target_frame=target_frame,
            ref_ids=ref_ids,
        )

        # Same grandparent dispatch for the LM forward; hidden states are
        # needed to pool the <seg>-token representations below.
        output = super(Simtoken_ForCausalLM, self).forward(
            input_ids=input_ids_mm,
            attention_mask=attention_masks_mm,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels_mm,
            output_hidden_states=True,
        )
        output_hidden_states = output.hidden_states

        # Mark positions whose NEXT label token is <seg> (causal shift), then
        # pad a False column so the mask matches the full sequence length and
        # indexes the hidden state that predicts each <seg> token.
        seg_token_mask = output.labels[..., 1:] == self.seg_token_idx
        seg_token_mask = torch.cat(
            [
                seg_token_mask,
                torch.zeros(
                    (seg_token_mask.shape[0], 1),
                    device=output.labels.device,
                    dtype=torch.bool,
                ),
            ],
            dim=1,
        )

        # Project the last-layer hidden states at the seg positions; result is
        # a flat (total_seg_tokens, out_dim) tensor across the whole batch.
        seg_embeddings = self.model.text_hidden_fcs[0](
            output_hidden_states[-1][seg_token_mask]
        )

        # One existence logit per seg token.
        # NOTE(review): the BCE below compares this against a PER-SAMPLE
        # target — shape-consistent only when every sample contributes exactly
        # one seg token (refs_num == 1); confirm for null-augmented batches.
        exist_logit = self.existence_head(seg_embeddings)

        # fids appears to be a list of per-sample id containers; only the
        # first entry is used — presumably the anchor frame id. TODO confirm.
        fis_flat = [fid[0] for fid in fids]
        ct_loss = torch.tensor(0.0, device=seg_embeddings.device)
        if not inference and contrast > 0.0:
            pos_feats = self.memory.get_positive_features(vids, fis_flat)
            neg_feats = self.memory.get_negative_features_same_vid(vids, fis_flat)
            # Augment each sample's negatives with the (detached, CPU) seg
            # embeddings of the OTHER batch items.
            # NOTE(review): i indexes per-sample negative lists while j
            # indexes flat seg tokens — again assumes one seg token per sample.
            for i in range(len(neg_feats)):
                for j in range(len(seg_embeddings)):
                    if j != i:
                        neg_feats[i].append(seg_embeddings[j].detach().cpu())
            ct_loss = compute_alignment_loss(seg_embeddings, pos_feats, neg_feats)

        # Keep only non-null samples out of the batch for the feature memory
        # (null audio-swapped embeddings would poison positives/negatives).
        valid_vids = [vids[i] for i in range(batch_size) if not (is_null is not None and is_null[i])]
        valid_fids = [fis_flat[i] for i in range(batch_size) if not (is_null is not None and is_null[i])]
        valid_embs = seg_embeddings[
            [i for i in range(batch_size) if not (is_null is not None and is_null[i])]
        ] if valid_vids else seg_embeddings[:0]
        # NOTE(review): this branch has no `inference` guard, so the memory is
        # also updated at inference time whenever valid_vids is non-empty —
        # confirm this is intended.
        if valid_vids:
            self.memory.add_batch(valid_vids, valid_fids, valid_embs)
        # NOTE(review): reached only when EVERY sample is null; it then adds
        # the all-null batch to memory, which seems to contradict the gating
        # above — verify this fallback is deliberate.
        elif not inference:
            self.memory.add_batch(vids, fis_flat, seg_embeddings)

        # Re-group the flat seg embeddings per sample using refs_num.
        pred_embeddings = []
        pred_idx = 0
        for ref_num in refs_num:
            pred_embeddings.append(seg_embeddings[pred_idx : pred_idx + ref_num])
            pred_idx += ref_num

        # Decode one mask stack per sample: seg embeddings act as SAM text
        # prompts; each prompt is decoded against the sample's frame features.
        pred_masks = []
        for i in range(batch_size):
            sparse_embeddings, dense_embeddings = self.model.visual_model.prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
                text_embeds=pred_embeddings[i].unsqueeze(1),
            )
            sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
            dense_embeddings = dense_embeddings.to(pred_embeddings[i].dtype)

            pred_masks_sample = []
            for prompt_idx in range(len(sparse_embeddings)):
                low_res_masks, _ = self.model.visual_model.mask_decoder(
                    image_embeddings=image_embeddings[i * num_frames : (i + 1) * num_frames],
                    image_pe=self.model.visual_model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings[prompt_idx : prompt_idx + 1],
                    dense_prompt_embeddings=dense_embeddings[prompt_idx : prompt_idx + 1],
                    multimask_output=False,
                )
                # Upsample low-res logits back to the original frame size.
                pred_mask = self.model.visual_model.postprocess_masks(
                    low_res_masks,
                    input_size=resize_list[i],
                    original_size=orgsize_list[i],
                )
                pred_masks_sample.append(pred_mask.squeeze(1))
            pred_masks.append(torch.stack(pred_masks_sample, dim=0))

        gt_masks = masks_list

        if inference:
            return {
                "pred_masks": pred_masks,
                "gt_masks": gt_masks,
                "exist_logit": exist_logit,
            }

        ce_loss = output.loss * self.ce_loss_weight

        # Mask losses, gated: null samples contribute nothing (there is no
        # meaningful GT mask for a swapped-audio sample). Per-sample losses
        # are weighted by their mask count and normalized at the end.
        mask_bce_loss = 0.0
        mask_dice_loss = 0.0
        num_masks = 0
        for batch_idx in range(batch_size):
            if is_null is not None and is_null[batch_idx]:
                continue
            gt_mask = gt_masks[batch_idx]
            pred_mask = pred_masks[batch_idx]
            # Flatten (refs, frames, H, W) -> (refs*frames, H, W) so the loss
            # helpers see one mask per row.
            a, b, c, d = gt_mask.shape
            gt_flat = gt_mask.view(a * b, c, d)
            pred_flat = pred_mask.view(a * b, c, d)
            mask_bce_loss += (
                sigmoid_ce_loss(pred_flat, gt_flat, num_masks=gt_flat.shape[0])
                * gt_flat.shape[0]
            )
            mask_dice_loss += (
                dice_loss(pred_flat, gt_flat, num_masks=gt_flat.shape[0])
                * gt_flat.shape[0]
            )
            num_masks += gt_flat.shape[0]

        # +1e-8 keeps the all-null case finite (losses stay 0.0 floats then).
        mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
        mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
        mask_loss = mask_bce_loss + mask_dice_loss

        # Existence BCE: target 1 for real samples, 0 for null samples.
        # (See shape note on exist_logit above.)
        if is_null is not None:
            exist_target = (~is_null).float().to(exist_logit.device)
            exist_loss = F.binary_cross_entropy_with_logits(
                exist_logit.squeeze(-1), exist_target
            )
        else:
            exist_loss = torch.tensor(0.0, device=exist_logit.device)

        loss = (
            ce_loss
            + mask_loss
            + self.exist_loss_weight * exist_loss
            + contrast * ct_loss
        )

        # Mask-loss components may still be Python floats (all-null batch);
        # wrap them so the logging side always receives tensors.
        return {
            "loss": loss,
            "ce_loss": ce_loss,
            "mask_bce_loss": mask_bce_loss if isinstance(mask_bce_loss, torch.Tensor) else torch.tensor(mask_bce_loss),
            "mask_dice_loss": mask_dice_loss if isinstance(mask_dice_loss, torch.Tensor) else torch.tensor(mask_dice_loss),
            "mask_loss": mask_loss if isinstance(mask_loss, torch.Tensor) else torch.tensor(mask_loss),
            "exist_loss": exist_loss,
            "ct_loss": ct_loss,
            "pred_masks": pred_masks,
            "gt_masks": gt_masks,
            "exist_logit": exist_logit,
        }
|
|