"""EC-SimToken: SimToken + Existence Head for null detection. Architecture additions over Simtoken_ForCausalLM: - existence_head: Linear(out_dim, 1) → sigmoid → p(object exists) - BCE existence loss on synthetic null samples (audio-swapped during training) - Mask loss gated: null-augmented samples skip mask loss Null augmentation is done in the training script (audio swap), not here. This module just accepts an optional `is_null` bool tensor per batch. """ from __future__ import annotations from typing import List import random import torch import torch.nn as nn import torch.nn.functional as F from models.avs_model import ( Simtoken_ForCausalLM, dice_loss, sigmoid_ce_loss, compute_alignment_loss, ) class ECSimtoken_ForCausalLM(Simtoken_ForCausalLM): """SimToken with an existence head for null-sample detection. Extra kwargs (consumed here, not passed to parent): exist_loss_weight: float BCE existence loss weight (default 1.0) """ def __init__(self, config, **kwargs): self.exist_loss_weight = kwargs.pop("exist_loss_weight", 1.0) super().__init__(config, **kwargs) out_dim = config.out_dim self.existence_head = nn.Linear(out_dim, 1) # ------------------------------------------------------------------ # Forward # ------------------------------------------------------------------ def model_forward( self, images: torch.FloatTensor, images_clip: torch.FloatTensor, audio_features: torch.FloatTensor, image_features: torch.FloatTensor, input_ids: torch.LongTensor, labels: torch.LongTensor, attention_masks: torch.LongTensor, masks_list: List[torch.FloatTensor], resize_list: List[tuple], orgsize_list: List[tuple], conversation_list: List[str], ref_ids: List[torch.LongTensor], refs_num: List[int], vids, fids, epoch: int = 0, inference: bool = False, num_frames: int = 10, contrast: float = 0.0, is_null: torch.BoolTensor = None, # [B] True = synthetic null sample **kwargs, ): batch_size = len(images) image_embeddings = torch.cat(image_features, dim=0) # [BT, 256, 64, 64] audio_embeddings = self.audio_feature_layer( torch.stack(audio_features, dim=0) ) # [B, T, 4096] target_frame = 5 # fixed as in original ( input_ids_mm, attention_masks_mm, past_key_values, inputs_embeds, labels_mm, ) = super(Simtoken_ForCausalLM, self).prepare_inputs_labels_for_multimodal( input_ids, attention_masks, past_key_values=None, labels=labels, images=images_clip, audio_features=audio_embeddings, target_frame=target_frame, ref_ids=ref_ids, ) output = super(Simtoken_ForCausalLM, self).forward( input_ids=input_ids_mm, attention_mask=attention_masks_mm, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels_mm, output_hidden_states=True, ) output_hidden_states = output.hidden_states seg_token_mask = output.labels[..., 1:] == self.seg_token_idx seg_token_mask = torch.cat( [ seg_token_mask, torch.zeros( (seg_token_mask.shape[0], 1), device=output.labels.device, dtype=torch.bool, ), ], dim=1, ) # [B, seq_len] seg_embeddings = self.model.text_hidden_fcs[0]( output_hidden_states[-1][seg_token_mask] ) # [seg_num, 256] (seg_num == B when refs_num == [1]*B) # ── Existence head ──────────────────────────────────────────────── exist_logit = self.existence_head(seg_embeddings) # [seg_num, 1] # ── Memory / contrastive (optional, gated by contrast weight) ──── fis_flat = [fid[0] for fid in fids] ct_loss = torch.tensor(0.0, device=seg_embeddings.device) if not inference and contrast > 0.0: pos_feats = self.memory.get_positive_features(vids, fis_flat) neg_feats = self.memory.get_negative_features_same_vid(vids, fis_flat) for i in range(len(neg_feats)): for j in range(len(seg_embeddings)): if j != i: neg_feats[i].append(seg_embeddings[j].detach().cpu()) ct_loss = compute_alignment_loss(seg_embeddings, pos_feats, neg_feats) # Only add non-null samples to memory valid_vids = [vids[i] for i in range(batch_size) if not (is_null is not None and is_null[i])] valid_fids = [fis_flat[i] for i in range(batch_size) if not (is_null is not None and is_null[i])] valid_embs = seg_embeddings[ [i for i in range(batch_size) if not (is_null is not None and is_null[i])] ] if valid_vids else seg_embeddings[:0] if valid_vids: self.memory.add_batch(valid_vids, valid_fids, valid_embs) elif not inference: self.memory.add_batch(vids, fis_flat, seg_embeddings) # ── Reorganise seg embeddings per batch item ────────────────────── pred_embeddings = [] pred_idx = 0 for ref_num in refs_num: pred_embeddings.append(seg_embeddings[pred_idx : pred_idx + ref_num]) pred_idx += ref_num # ── SAM mask decoder ────────────────────────────────────────────── pred_masks = [] for i in range(batch_size): sparse_embeddings, dense_embeddings = self.model.visual_model.prompt_encoder( points=None, boxes=None, masks=None, text_embeds=pred_embeddings[i].unsqueeze(1), ) sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype) dense_embeddings = dense_embeddings.to(pred_embeddings[i].dtype) pred_masks_sample = [] for prompt_idx in range(len(sparse_embeddings)): low_res_masks, _ = self.model.visual_model.mask_decoder( image_embeddings=image_embeddings[i * num_frames : (i + 1) * num_frames], image_pe=self.model.visual_model.prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings[prompt_idx : prompt_idx + 1], dense_prompt_embeddings=dense_embeddings[prompt_idx : prompt_idx + 1], multimask_output=False, ) pred_mask = self.model.visual_model.postprocess_masks( low_res_masks, input_size=resize_list[i], original_size=orgsize_list[i], ) # [T, 1, H, W] pred_masks_sample.append(pred_mask.squeeze(1)) pred_masks.append(torch.stack(pred_masks_sample, dim=0)) # [num_seg, T, H, W] gt_masks = masks_list if inference: return { "pred_masks": pred_masks, "gt_masks": gt_masks, "exist_logit": exist_logit, # [seg_num, 1] } # ── Losses ──────────────────────────────────────────────────────── ce_loss = output.loss * self.ce_loss_weight # Mask loss — skip null-augmented samples mask_bce_loss = 0.0 mask_dice_loss = 0.0 num_masks = 0 for batch_idx in range(batch_size): if is_null is not None and is_null[batch_idx]: continue # null sample: no mask loss gt_mask = gt_masks[batch_idx] pred_mask = pred_masks[batch_idx] a, b, c, d = gt_mask.shape gt_flat = gt_mask.view(a * b, c, d) pred_flat = pred_mask.view(a * b, c, d) mask_bce_loss += ( sigmoid_ce_loss(pred_flat, gt_flat, num_masks=gt_flat.shape[0]) * gt_flat.shape[0] ) mask_dice_loss += ( dice_loss(pred_flat, gt_flat, num_masks=gt_flat.shape[0]) * gt_flat.shape[0] ) num_masks += gt_flat.shape[0] mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8) mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8) mask_loss = mask_bce_loss + mask_dice_loss # Existence loss (BCE) if is_null is not None: exist_target = (~is_null).float().to(exist_logit.device) exist_loss = F.binary_cross_entropy_with_logits( exist_logit.squeeze(-1), exist_target ) else: exist_loss = torch.tensor(0.0, device=exist_logit.device) loss = ( ce_loss + mask_loss + self.exist_loss_weight * exist_loss + contrast * ct_loss ) return { "loss": loss, "ce_loss": ce_loss, "mask_bce_loss": mask_bce_loss if isinstance(mask_bce_loss, torch.Tensor) else torch.tensor(mask_bce_loss), "mask_dice_loss": mask_dice_loss if isinstance(mask_dice_loss, torch.Tensor) else torch.tensor(mask_dice_loss), "mask_loss": mask_loss if isinstance(mask_loss, torch.Tensor) else torch.tensor(mask_loss), "exist_loss": exist_loss, "ct_loss": ct_loss, "pred_masks": pred_masks, "gt_masks": gt_masks, "exist_logit": exist_logit, }