# SimToken/models/ec_simtoken_model.py
# Uploaded by yfan07 ("Upload folder using huggingface_hub"),
# commit a95e79a (verified).
"""EC-SimToken: SimToken + Existence Head for null detection.
Architecture additions over Simtoken_ForCausalLM:
- existence_head: Linear(out_dim, 1) → sigmoid → p(object exists)
- BCE existence loss on synthetic null samples (audio-swapped during training)
- Mask loss gated: null-augmented samples skip mask loss
Null augmentation is done in the training script (audio swap), not here.
This module just accepts an optional `is_null` bool tensor per batch.
"""
from __future__ import annotations
from typing import List
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.avs_model import (
Simtoken_ForCausalLM,
dice_loss,
sigmoid_ce_loss,
compute_alignment_loss,
)
class ECSimtoken_ForCausalLM(Simtoken_ForCausalLM):
    """SimToken with an existence head for null-sample detection.

    A single linear layer (``existence_head``) maps every projected [SEG]
    embedding to one logit; after a sigmoid it is read as p(object exists).
    During training the logit is supervised with BCE against ``~is_null``.

    Extra kwargs (consumed here, not passed to parent):
        exist_loss_weight: float  BCE existence loss weight (default 1.0)
    """

    def __init__(self, config, **kwargs):
        """Build the parent model, then attach the existence head.

        Args:
            config: Model config; ``config.out_dim`` is the input width of
                the existence head.
            **kwargs: Forwarded to the parent constructor after
                ``exist_loss_weight`` has been removed.
        """
        # Pop before calling the parent so it never sees the extra kwarg, but
        # assign attributes only once the parent initialiser has finished.
        exist_loss_weight = kwargs.pop("exist_loss_weight", 1.0)
        super().__init__(config, **kwargs)
        self.exist_loss_weight = exist_loss_weight
        self.existence_head = nn.Linear(config.out_dim, 1)

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def model_forward(
        self,
        images: torch.FloatTensor,
        images_clip: torch.FloatTensor,
        audio_features: torch.FloatTensor,
        image_features: torch.FloatTensor,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor,
        attention_masks: torch.LongTensor,
        masks_list: List[torch.FloatTensor],
        resize_list: List[tuple],
        orgsize_list: List[tuple],
        conversation_list: List[str],
        ref_ids: List[torch.LongTensor],
        refs_num: List[int],
        vids,
        fids,
        epoch: int = 0,
        inference: bool = False,
        num_frames: int = 10,
        contrast: float = 0.0,
        is_null: torch.BoolTensor | None = None,  # [B] True = synthetic null sample
        **kwargs,
    ):
        """Run the LLM, decode masks from the [SEG] embeddings, compute losses.

        Args:
            images: Only its length is used here (the batch size).
            images_clip: Passed as ``images`` to the multimodal input
                preparation.
            audio_features: Sequence of per-sample audio features; stacked
                and projected by ``self.audio_feature_layer``.
            image_features: Sequence of per-sample image embeddings;
                concatenated along dim 0 and sliced per sample in chunks of
                ``num_frames`` for the mask decoder.
            input_ids, labels, attention_masks: Token-level LLM inputs.
            masks_list: Ground-truth masks, one 4-D tensor per sample.
            resize_list, orgsize_list: Per-sample sizes handed to
                ``postprocess_masks``.
            conversation_list: Unused in this method.
            ref_ids: Forwarded to the multimodal input preparation.
            refs_num: Number of [SEG] embeddings belonging to each sample.
            vids, fids: Per-sample keys for the feature memory; only the
                first element of each ``fids`` entry is used.
            epoch: Unused in this method.
            inference: When True, skip memory updates and losses and return
                predictions only.
            num_frames: Number of image-embedding rows per sample.
            contrast: Weight of the alignment loss; the loss is only
                computed when this is > 0 and ``inference`` is False.
            is_null: Optional per-sample flags, True for null-augmented
                samples. Null samples are excluded from the mask loss and
                from memory updates, and define the existence target.
            **kwargs: Ignored.

        Returns:
            dict: In inference mode ``pred_masks``, ``gt_masks`` and
            ``exist_logit``. Otherwise additionally ``loss`` and the
            individual loss terms (``ce_loss``, ``mask_bce_loss``,
            ``mask_dice_loss``, ``mask_loss``, ``exist_loss``, ``ct_loss``).
        """
        batch_size = len(images)
        image_embeddings = torch.cat(image_features, dim=0)  # [BT, 256, 64, 64]
        audio_embeddings = self.audio_feature_layer(
            torch.stack(audio_features, dim=0)
        )  # [B, T, 4096]
        target_frame = 5  # fixed as in original

        # Deliberately skip Simtoken_ForCausalLM in the MRO for both calls.
        (
            input_ids_mm,
            attention_masks_mm,
            past_key_values,
            inputs_embeds,
            labels_mm,
        ) = super(Simtoken_ForCausalLM, self).prepare_inputs_labels_for_multimodal(
            input_ids,
            attention_masks,
            past_key_values=None,
            labels=labels,
            images=images_clip,
            audio_features=audio_embeddings,
            target_frame=target_frame,
            ref_ids=ref_ids,
        )
        output = super(Simtoken_ForCausalLM, self).forward(
            input_ids=input_ids_mm,
            attention_mask=attention_masks_mm,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels_mm,
            output_hidden_states=True,
        )
        output_hidden_states = output.hidden_states

        # Mark positions whose *next* label is [SEG]; pad one False column so
        # the mask has the same sequence length as the hidden states.
        seg_token_mask = output.labels[..., 1:] == self.seg_token_idx
        seg_token_mask = torch.cat(
            [
                seg_token_mask,
                torch.zeros(
                    (seg_token_mask.shape[0], 1),
                    device=output.labels.device,
                    dtype=torch.bool,
                ),
            ],
            dim=1,
        )  # [B, seq_len]
        seg_embeddings = self.model.text_hidden_fcs[0](
            output_hidden_states[-1][seg_token_mask]
        )  # [seg_num, 256]  (seg_num == B when refs_num == [1]*B)
        device = seg_embeddings.device

        # ── Existence head ────────────────────────────────────────────────
        exist_logit = self.existence_head(seg_embeddings)  # [seg_num, 1]

        # ── Memory / contrastive (optional, gated by contrast weight) ────
        fis_flat = [fid[0] for fid in fids]
        ct_loss = torch.tensor(0.0, device=device)

        # Read the null flags once as plain Python bools instead of testing
        # tensor elements repeatedly. Only needed on the training path.
        null_flags = [False] * batch_size
        if not inference and is_null is not None:
            raw_flags = is_null.tolist() if torch.is_tensor(is_null) else is_null
            null_flags = [bool(flag) for flag in raw_flags]

        if not inference:
            if contrast > 0.0:
                pos_feats = self.memory.get_positive_features(vids, fis_flat)
                neg_feats = self.memory.get_negative_features_same_vid(vids, fis_flat)
                # Every other embedding of the batch is an extra negative;
                # detach and move to CPU once rather than per (i, j) pair.
                cpu_embeddings = seg_embeddings.detach().cpu()
                for i, negatives in enumerate(neg_feats):
                    negatives.extend(
                        cpu_embeddings[j]
                        for j in range(len(cpu_embeddings))
                        if j != i
                    )
                ct_loss = compute_alignment_loss(seg_embeddings, pos_feats, neg_feats)

            # Only add non-null samples to memory, whether or not the
            # contrastive loss is active.
            # NOTE(review): index-based filtering assumes one [SEG] per
            # sample (seg_num == B) - confirm before mixing null
            # augmentation with multi-reference batches.
            keep_idx = [i for i in range(batch_size) if not null_flags[i]]
            if len(keep_idx) == batch_size:
                self.memory.add_batch(vids, fis_flat, seg_embeddings)
            elif keep_idx:
                self.memory.add_batch(
                    [vids[i] for i in keep_idx],
                    [fis_flat[i] for i in keep_idx],
                    seg_embeddings[keep_idx],
                )

        # ── Reorganise seg embeddings per batch item ──────────────────────
        pred_embeddings = []
        pred_idx = 0
        for ref_num in refs_num:
            pred_embeddings.append(seg_embeddings[pred_idx : pred_idx + ref_num])
            pred_idx += ref_num

        # ── SAM mask decoder ──────────────────────────────────────────────
        visual_model = self.model.visual_model
        pred_masks = []
        for i in range(batch_size):
            sparse_embeddings, dense_embeddings = visual_model.prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
                text_embeds=pred_embeddings[i].unsqueeze(1),
            )
            sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
            dense_embeddings = dense_embeddings.to(pred_embeddings[i].dtype)
            sample_image_embeddings = image_embeddings[
                i * num_frames : (i + 1) * num_frames
            ]
            pred_masks_sample = []
            for prompt_idx in range(len(sparse_embeddings)):
                low_res_masks, _ = visual_model.mask_decoder(
                    image_embeddings=sample_image_embeddings,
                    image_pe=visual_model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings[prompt_idx : prompt_idx + 1],
                    dense_prompt_embeddings=dense_embeddings[prompt_idx : prompt_idx + 1],
                    multimask_output=False,
                )
                pred_mask = visual_model.postprocess_masks(
                    low_res_masks,
                    input_size=resize_list[i],
                    original_size=orgsize_list[i],
                )  # [T, 1, H, W]
                pred_masks_sample.append(pred_mask.squeeze(1))
            pred_masks.append(torch.stack(pred_masks_sample, dim=0))  # [num_seg, T, H, W]

        gt_masks = masks_list
        if inference:
            return {
                "pred_masks": pred_masks,
                "gt_masks": gt_masks,
                "exist_logit": exist_logit,  # [seg_num, 1]
            }

        # ── Losses ────────────────────────────────────────────────────────
        ce_loss = output.loss * self.ce_loss_weight

        # Mask loss - skip null-augmented samples
        mask_bce_loss = 0.0
        mask_dice_loss = 0.0
        num_masks = 0
        for batch_idx in range(batch_size):
            if null_flags[batch_idx]:
                continue  # null sample: no mask loss
            gt_mask = gt_masks[batch_idx]
            pred_mask = pred_masks[batch_idx]
            a, b, c, d = gt_mask.shape
            # reshape (not view) so non-contiguous ground-truth masks work too.
            gt_flat = gt_mask.reshape(a * b, c, d)
            pred_flat = pred_mask.reshape(a * b, c, d)
            n_flat = gt_flat.shape[0]
            # The helpers average over num_masks; re-weight by the mask count
            # so the final division yields a mean over all kept masks.
            mask_bce_loss += sigmoid_ce_loss(pred_flat, gt_flat, num_masks=n_flat) * n_flat
            mask_dice_loss += dice_loss(pred_flat, gt_flat, num_masks=n_flat) * n_flat
            num_masks += n_flat
        mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
        mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
        mask_loss = mask_bce_loss + mask_dice_loss

        # Existence loss (BCE)
        if is_null is not None:
            exist_pred = exist_logit.squeeze(-1)  # [seg_num]
            exist_target = (~is_null).float().to(exist_logit.device)  # [B]
            if exist_target.shape[0] != exist_pred.shape[0]:
                # Several [SEG] tokens per sample: repeat each sample's
                # target once per reference so it lines up with the logits.
                repeats = torch.as_tensor(refs_num, device=exist_target.device)
                exist_target = exist_target.repeat_interleave(repeats)
            exist_loss = F.binary_cross_entropy_with_logits(exist_pred, exist_target)
        else:
            exist_loss = torch.tensor(0.0, device=exist_logit.device)

        loss = (
            ce_loss
            + mask_loss
            + self.exist_loss_weight * exist_loss
            + contrast * ct_loss
        )

        def _as_tensor(value):
            # If every sample was null the mask terms are still Python
            # floats; wrap them on the model device so all reported losses
            # live on the same device.
            if isinstance(value, torch.Tensor):
                return value
            return torch.tensor(value, device=device)

        return {
            "loss": loss,
            "ce_loss": ce_loss,
            "mask_bce_loss": _as_tensor(mask_bce_loss),
            "mask_dice_loss": _as_tensor(mask_dice_loss),
            "mask_loss": _as_tensor(mask_loss),
            "exist_loss": exist_loss,
            "ct_loss": ct_loss,
            "pred_masks": pred_masks,
            "gt_masks": gt_masks,
            "exist_logit": exist_logit,
        }