yfan07 committed on
Commit
8eb2324
·
verified ·
1 Parent(s): a88521e

Add files using upload-large-folder tool

Browse files
ChatUniVi/model/multimodal_encoder/eva_vit.py CHANGED
@@ -12,8 +12,8 @@ import torch
12
  import torch.nn as nn
13
  import torch.nn.functional as F
14
  import torch.utils.checkpoint as checkpoint
15
- from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
- from timm.models.registry import register_model
17
 
18
  from .utils import download_cached_file
19
 
@@ -445,4 +445,4 @@ def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, pre
445
  if precision == "fp16":
446
  # model.to("cuda")
447
  convert_weights_to_fp16(model)
448
- return model
 
12
  import torch.nn as nn
13
  import torch.nn.functional as F
14
  import torch.utils.checkpoint as checkpoint
15
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
16
+ from timm.models import register_model
17
 
18
  from .utils import download_cached_file
19
 
 
445
  if precision == "fp16":
446
  # model.to("cuda")
447
  convert_weights_to_fp16(model)
448
+ return model
ChatUniVi/model/multimodal_encoder/utils.py CHANGED
@@ -11,7 +11,10 @@ import os
11
 
12
  import torch
13
  import torch.distributed as dist
14
- import timm.models.hub as timm_hub
 
 
 
15
 
16
 
17
  def setup_for_distributed(is_master):
@@ -124,14 +127,14 @@ def download_cached_file(url, check_hash=True, progress=False):
124
  # a hack to sync the file path across processes
125
  parts = torch.hub.urlparse(url)
126
  filename = os.path.basename(parts.path)
127
- cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
128
 
129
  return cached_file
130
 
131
  if is_main_process():
132
- timm_hub.download_cached_file(url, check_hash, progress)
133
 
134
  if is_dist_avail_and_initialized():
135
  dist.barrier()
136
 
137
- return get_cached_file_path()
 
11
 
12
  import torch
13
  import torch.distributed as dist
14
+ from timm.models._hub import (
15
+ download_cached_file as timm_download_cached_file,
16
+ get_cache_dir as timm_get_cache_dir,
17
+ )
18
 
19
 
20
  def setup_for_distributed(is_master):
 
127
  # a hack to sync the file path across processes
128
  parts = torch.hub.urlparse(url)
129
  filename = os.path.basename(parts.path)
130
+ cached_file = os.path.join(timm_get_cache_dir(), filename)
131
 
132
  return cached_file
133
 
134
  if is_main_process():
135
+ timm_download_cached_file(url, check_hash, progress)
136
 
137
  if is_dist_avail_and_initialized():
138
  dist.barrier()
139
 
140
+ return get_cached_file_path()
configs/config.py CHANGED
@@ -64,6 +64,9 @@ parser.add_argument("--lr", type=float, default=5e-5, help='lr to fine tuning ad
64
  # epochs
65
  parser.add_argument("--epochs", type=int, default=10, help='epochs to fine tuning adapters.')
66
  parser.add_argument("--batch_size", type=int, default=8)
 
 
 
67
 
68
 
69
  parser.add_argument("--gpu_id", type=str, default="0", help="The GPU device to run generation on.")
@@ -73,6 +76,7 @@ parser.add_argument("--run", type=str, default='train', help="train, test")
73
  parser.add_argument("--frame_n", type=int, default=10, help="Frame num of each video. Fixed to 10.")
74
  parser.add_argument("--text_max_len", type=int, default=25, help="Maximum textual reference length.")
75
  parser.add_argument("--max_eval_rows", type=int, default=-1, help="Max samples per split during eval; -1 = all.")
 
76
  parser.add_argument("--eval_split", type=str, default="test_u", help="Which split to evaluate: test_s, test_u, test_n.")
77
  parser.add_argument("--gate_only", action="store_true", help="Train only A-min referent gate parameters.")
78
  parser.add_argument("--init_from_saved_model", action="store_true", help="Initialize training from --saved_model before updates.")
@@ -88,6 +92,19 @@ parser.add_argument("--eval_only", action="store_true", help="Only evaluate in c
88
  parser.add_argument("--disable_gate", action="store_true", help="Force A-min gate to identity for cached pipeline baseline checks.")
89
  parser.add_argument("--gate_checkpoint", type=str, default="", help="Optional referent-gate-only checkpoint to overlay after loading --saved_model.")
90
  parser.add_argument("--save_gate_only", action="store_true", help="In cached-gate training, save only referent_gate parameters.")
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
 
93
 
 
64
  # epochs
65
  parser.add_argument("--epochs", type=int, default=10, help='epochs to fine tuning adapters.')
66
  parser.add_argument("--batch_size", type=int, default=8)
67
+ parser.add_argument("--ce_loss_weight", type=float, default=1.0, help="Weight for language modeling loss.")
68
+ parser.add_argument("--dice_loss_weight", type=float, default=0.5, help="Weight for dice segmentation loss.")
69
+ parser.add_argument("--bce_loss_weight", type=float, default=2.0, help="Weight for BCE segmentation loss.")
70
 
71
 
72
  parser.add_argument("--gpu_id", type=str, default="0", help="The GPU device to run generation on.")
 
76
  parser.add_argument("--frame_n", type=int, default=10, help="Frame num of each video. Fixed to 10.")
77
  parser.add_argument("--text_max_len", type=int, default=25, help="Maximum textual reference length.")
78
  parser.add_argument("--max_eval_rows", type=int, default=-1, help="Max samples per split during eval; -1 = all.")
79
+ parser.add_argument("--subset_manifest", type=str, default="", help="Optional JSON file that fixes train/eval subset indices per split.")
80
  parser.add_argument("--eval_split", type=str, default="test_u", help="Which split to evaluate: test_s, test_u, test_n.")
81
  parser.add_argument("--gate_only", action="store_true", help="Train only A-min referent gate parameters.")
82
  parser.add_argument("--init_from_saved_model", action="store_true", help="Initialize training from --saved_model before updates.")
 
92
  parser.add_argument("--disable_gate", action="store_true", help="Force A-min gate to identity for cached pipeline baseline checks.")
93
  parser.add_argument("--gate_checkpoint", type=str, default="", help="Optional referent-gate-only checkpoint to overlay after loading --saved_model.")
94
  parser.add_argument("--save_gate_only", action="store_true", help="In cached-gate training, save only referent_gate parameters.")
95
+ parser.add_argument("--use_residual_prompt_bridge", action="store_true", help="Enable the image-conditioned residual prompt bridge before SAM prompt encoding.")
96
+ parser.add_argument("--bridge_only", action="store_true", help="Freeze all parameters except the residual prompt bridge.")
97
+ parser.add_argument("--bridge_pm_weight", type=float, default=0.0, help="Weight for prompt-manifold teacher loss.")
98
+ parser.add_argument("--bridge_rg_weight", type=float, default=0.0, help="Weight for region-semantic teacher loss.")
99
+ parser.add_argument("--bridge_norm_weight", type=float, default=0.0, help="Weight for prompt-norm preservation loss.")
100
+ parser.add_argument("--bridge_mode", type=str, default="additive", choices=["additive", "directional"], help="Prompt bridge parameterization.")
101
+ parser.add_argument("--bridge_condition", type=str, default="image", choices=["image", "q_only"], help="Condition source for the prompt bridge.")
102
+ parser.add_argument("--bridge_directional_alpha", type=float, default=0.1, help="Step size used by directional bridge updates after orthogonalization.")
103
+ parser.add_argument("--bridge_gate_bias_init", type=float, default=-4.0, help="Initial bias for bridge gate sigmoid.")
104
+ parser.add_argument("--bridge_residual_init_std", type=float, default=1e-3, help="Std used to initialize the bridge residual projection.")
105
+ parser.add_argument("--bridge_target_frame", type=int, default=5, help="Frame index used to build bridge teachers.")
106
+ parser.add_argument("--bridge_sanity_only", action="store_true", help="Run only bridge sanity checks (gradient, identity, teacher norms) and exit.")
107
+ parser.add_argument("--bridge_sanity_batches", type=int, default=3, help="How many batches to scan during bridge sanity stats collection.")
108
 
109
 
110
 
models/avs_model.py CHANGED
@@ -100,6 +100,74 @@ def compute_alignment_loss(q: torch.Tensor, pos_feats: list, neg_feats: list, te
100
  return total_loss / count
101
 
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
 
105
  class Simtoken_MetaModel:
@@ -115,6 +183,12 @@ class Simtoken_MetaModel:
115
  self.config.train_mask_decoder = kwargs["train_mask_decoder"]
116
  self.config.out_dim = kwargs["out_dim"]
117
  self.vision_pretrained = kwargs.get("vision_pretrained", None)
 
 
 
 
 
 
118
  else:
119
  self.vision_pretrained = kwargs.get("vision_pretrained", None)
120
  self.initialize_lisa_modules(self.config)
@@ -143,6 +217,17 @@ class Simtoken_MetaModel:
143
  for param in self.text_hidden_fcs.parameters():
144
  param.requires_grad = True
145
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  class Simtoken_Model(Simtoken_MetaModel, ChatUniViLlamaModel):
148
  def __init__(
@@ -234,6 +319,104 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
234
  self.compress = kwargs.pop("compress", True)
235
 
236
  self.start = kwargs.pop("start")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
 
239
 
@@ -284,7 +467,7 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
284
  # audio_embeddings = audio_features # [B, 10, 128]
285
 
286
  if target_frame is None:
287
- target_frame = 5
288
  else:
289
  target_frame = int(target_frame)
290
  if target_frame < 0 or target_frame >= num_frames:
@@ -315,6 +498,60 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
315
  seg_hidden_states = output_hidden_states[-1][seg_token_mask] # [seg_num, hidden_size]
316
  seg_embeddings = self.model.text_hidden_fcs[0](seg_hidden_states) # [seg_num,256]
317
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  # print("seg_embeddings in this batch:", seg_embeddings.shape)
319
  # print("vids:", vids)
320
  # print("fids:", fids)
@@ -337,11 +574,13 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
337
 
338
 
339
  pred_embeddings = []
 
340
  pred_hidden_states = []
341
  #--------------------------------------------------------------------------------------------
342
  pred_idx = 0
343
  for ref_num in refs_num:
344
  pred_embeddings.append(seg_embeddings[pred_idx:pred_idx + ref_num])
 
345
  pred_hidden_states.append(seg_hidden_states[pred_idx:pred_idx + ref_num])
346
  pred_idx += ref_num
347
  # list[B]:[num_seg, 256]
@@ -359,7 +598,7 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
359
  points=None,
360
  boxes=None,
361
  masks=None,
362
- text_embeds=pred_embeddings[i].unsqueeze(1), # [1, 1 ,256]
363
  )
364
  # 确保数据类型一致
365
  sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
@@ -395,12 +634,23 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
395
  gt_masks = masks_list # list[B]:[num_seg, T, H, W]
396
 
397
  if inference:
398
- return {
399
  "pred_masks": pred_masks, # list[B]:[num_seg, T, H, W]
400
  "gt_masks": gt_masks, # list[B]:[num_seg, T, H, W]
401
  "seg_embeddings": pred_embeddings, # list[B]:[num_seg, 256]
 
402
  "seg_hidden_states": pred_hidden_states, # list[B]:[num_seg, hidden_size]
403
  }
 
 
 
 
 
 
 
 
 
 
404
 
405
  model_output = output
406
  output = model_output.logits
@@ -451,6 +701,8 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
451
  else:
452
  loss = ce_loss + mask_loss
453
 
 
 
454
  return {
455
  "loss": loss,
456
  "ce_loss": ce_loss,
@@ -458,6 +710,12 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
458
  "mask_dice_loss": mask_dice_loss,
459
  "mask_loss": mask_loss,
460
  "ct_loss": ct_loss,
 
 
 
 
 
 
461
  "pred_masks": pred_masks,
462
  "gt_masks": gt_masks,
463
  }
 
100
  return total_loss / count
101
 
102
 
103
class ResidualPromptBridge(nn.Module):
    """Gated residual update of a prompt embedding ``q``.

    A "region" vector is derived either from ``q`` alone
    (``condition="q_only"``) or by attention-pooling image tokens with a
    projected ``q`` as the query (``condition="image"``).  The region vector is
    projected and merged into ``q`` in one of two ways:

    * ``mode="additive"``: ``p_hat = q + gate * residual_proj(region)``.
    * ``mode="directional"``: the projected region is orthogonalised against
      ``q``, a small step (``directional_alpha * mean(gate)``) is taken in that
      direction on the unit sphere, and the result is rescaled to ``‖q‖``.

    At initialisation the gate weight is zero and its bias is
    ``gate_bias_init``, so every gate value equals ``sigmoid(gate_bias_init)``
    regardless of the input.

    Args:
        embedding_dim: Channel size shared by ``q`` and the image tokens.
        mode: ``"additive"`` or ``"directional"``.
        condition: ``"image"`` or ``"q_only"``.
        directional_alpha: Step-size multiplier used in directional mode.
        gate_bias_init: Constant used to initialise the gate bias.
        residual_init_std: Std of the normal init of the residual projection.

    Raises:
        ValueError: If ``mode`` or ``condition`` is not a supported value.
    """

    _VALID_MODES = ("additive", "directional")
    _VALID_CONDITIONS = ("image", "q_only")

    def __init__(
        self,
        embedding_dim: int,
        mode: str = "additive",
        condition: str = "image",
        directional_alpha: float = 0.1,
        gate_bias_init: float = -4.0,
        residual_init_std: float = 1e-3,
    ) -> None:
        super().__init__()
        # Fail fast: an unrecognised value would otherwise silently select the
        # additive / image code paths in forward().
        if mode not in self._VALID_MODES:
            raise ValueError(
                f"Unknown bridge mode {mode!r}; expected one of {self._VALID_MODES}"
            )
        if condition not in self._VALID_CONDITIONS:
            raise ValueError(
                f"Unknown bridge condition {condition!r}; expected one of {self._VALID_CONDITIONS}"
            )

        self.embedding_dim = embedding_dim
        self.mode = mode
        self.condition = condition
        self.directional_alpha = directional_alpha
        # Attention logits are divided by sqrt(embedding_dim).
        self.scale = math.sqrt(float(embedding_dim))
        self.attn_proj = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.residual_proj = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.gate = nn.Linear(embedding_dim * 2, embedding_dim)
        self.reset_parameters(gate_bias_init=gate_bias_init, residual_init_std=residual_init_std)

    def reset_parameters(self, gate_bias_init: float, residual_init_std: float) -> None:
        """(Re)initialise all projections.

        The gate starts input-independent (zero weight, constant bias) and the
        residual projection starts with small normal weights.
        """
        nn.init.xavier_uniform_(self.attn_proj.weight)
        nn.init.normal_(self.residual_proj.weight, mean=0.0, std=residual_init_std)
        nn.init.zeros_(self.gate.weight)
        nn.init.constant_(self.gate.bias, gate_bias_init)

    def forward(self, q: torch.Tensor, image_embeddings: torch.Tensor) -> dict:
        """Compute the bridged prompt embedding.

        Args:
            q: Prompt embeddings, ``[B, C]``.
            image_embeddings: Image features ``[B, C, H, W]``.  Ignored when
                ``condition == "q_only"``.

        Returns:
            Dict with keys ``"p_hat"`` (updated prompt), ``"attn"`` (attention
            over the ``H*W`` tokens, or ``None`` in ``q_only`` mode),
            ``"region"``, ``"gate"`` and ``"delta"`` (``p_hat - q``).

        Raises:
            ValueError: In image mode, if ``image_embeddings`` is ``None`` or
                is not 4-dimensional.
        """
        if self.condition == "q_only":
            attn = None
            region = self.attn_proj(q)
        else:
            if image_embeddings is None:
                raise ValueError(
                    "ResidualPromptBridge expects image_embeddings [B, C, H, W], got None"
                )
            if image_embeddings.dim() != 4:
                raise ValueError(
                    f"ResidualPromptBridge expects image_embeddings [B, C, H, W], got {tuple(image_embeddings.shape)}"
                )
            image_tokens = image_embeddings.flatten(2).transpose(1, 2)  # [B, HW, C]
            q_proj = self.attn_proj(q)  # [B, C]
            attn_logits = torch.bmm(image_tokens, q_proj.unsqueeze(-1)).squeeze(-1) / self.scale
            attn = torch.softmax(attn_logits, dim=-1)  # [B, HW]
            # Attention-weighted average of the image tokens.
            region = torch.bmm(attn.unsqueeze(1), image_tokens).squeeze(1)  # [B, C]

        gate = torch.sigmoid(self.gate(torch.cat([q, region], dim=-1)))
        region_update = self.residual_proj(region)

        if self.mode == "directional":
            q_dir = F.normalize(q, dim=-1)
            # Remove the component of the update that is parallel to q so the
            # step only rotates the prompt.
            q_parallel = (region_update * q_dir).sum(dim=-1, keepdim=True) * q_dir
            region_orth = region_update - q_parallel
            region_orth_norm = region_orth.norm(dim=-1, keepdim=True).clamp_min(1e-6)
            region_dir = region_orth / region_orth_norm
            # One scalar step size per prompt: the channel-averaged gate.
            alpha = self.directional_alpha * gate.mean(dim=-1, keepdim=True)
            mixed_dir = F.normalize(q_dir + alpha * region_dir, dim=-1)
            # Rescale so p_hat keeps the norm of q.
            p_hat = q.norm(dim=-1, keepdim=True) * mixed_dir
            delta = p_hat - q
        else:
            delta = gate * region_update
            p_hat = q + delta

        return {
            "p_hat": p_hat,
            "attn": attn,
            "region": region,
            "gate": gate,
            "delta": delta,
        }
169
+
170
+
171
 
172
 
173
  class Simtoken_MetaModel:
 
183
  self.config.train_mask_decoder = kwargs["train_mask_decoder"]
184
  self.config.out_dim = kwargs["out_dim"]
185
  self.vision_pretrained = kwargs.get("vision_pretrained", None)
186
+ self.config.use_residual_prompt_bridge = kwargs.get("use_residual_prompt_bridge", False)
187
+ self.config.bridge_mode = kwargs.get("bridge_mode", "additive")
188
+ self.config.bridge_condition = kwargs.get("bridge_condition", "image")
189
+ self.config.bridge_directional_alpha = kwargs.get("bridge_directional_alpha", 0.1)
190
+ self.config.bridge_gate_bias_init = kwargs.get("bridge_gate_bias_init", -4.0)
191
+ self.config.bridge_residual_init_std = kwargs.get("bridge_residual_init_std", 1e-3)
192
  else:
193
  self.vision_pretrained = kwargs.get("vision_pretrained", None)
194
  self.initialize_lisa_modules(self.config)
 
217
  for param in self.text_hidden_fcs.parameters():
218
  param.requires_grad = True
219
 
220
+ self.prompt_bridge = None
221
+ if getattr(config, "use_residual_prompt_bridge", False):
222
+ self.prompt_bridge = ResidualPromptBridge(
223
+ embedding_dim=out_dim,
224
+ mode=getattr(config, "bridge_mode", "additive"),
225
+ condition=getattr(config, "bridge_condition", "image"),
226
+ directional_alpha=getattr(config, "bridge_directional_alpha", 0.1),
227
+ gate_bias_init=getattr(config, "bridge_gate_bias_init", -4.0),
228
+ residual_init_std=getattr(config, "bridge_residual_init_std", 1e-3),
229
+ )
230
+
231
 
232
  class Simtoken_Model(Simtoken_MetaModel, ChatUniViLlamaModel):
233
  def __init__(
 
319
  self.compress = kwargs.pop("compress", True)
320
 
321
  self.start = kwargs.pop("start")
322
+ self.use_residual_prompt_bridge = kwargs.pop("use_residual_prompt_bridge", False)
323
+ self.bridge_pm_weight = kwargs.pop("bridge_pm_weight", 0.0)
324
+ self.bridge_rg_weight = kwargs.pop("bridge_rg_weight", 0.0)
325
+ self.bridge_norm_weight = kwargs.pop("bridge_norm_weight", 0.0)
326
+ self.bridge_target_frame = kwargs.pop("bridge_target_frame", 5)
327
+
328
+ def _expand_prompt_level_inputs(
329
+ self,
330
+ image_features: List[torch.Tensor],
331
+ masks_list: List[torch.FloatTensor],
332
+ refs_num: List[int],
333
+ target_frame: int,
334
+ dtype: torch.dtype,
335
+ device: torch.device,
336
+ ) -> tuple:
337
+ prompt_image_embeddings = []
338
+ prompt_masks = []
339
+ prompt_mask_size = self.model.visual_model.prompt_encoder.mask_input_size
340
+
341
+ for sample_idx, ref_num in enumerate(refs_num):
342
+ frame_feat = image_features[sample_idx][target_frame].to(device=device, dtype=dtype)
343
+ for prompt_idx in range(ref_num):
344
+ prompt_image_embeddings.append(frame_feat)
345
+ mask = masks_list[sample_idx][prompt_idx, target_frame].to(
346
+ device=device, dtype=torch.float32
347
+ )
348
+ mask = F.interpolate(
349
+ mask.unsqueeze(0).unsqueeze(0),
350
+ size=prompt_mask_size,
351
+ mode="nearest",
352
+ ).squeeze(0).squeeze(0)
353
+ prompt_masks.append(mask)
354
+
355
+ return torch.stack(prompt_image_embeddings, dim=0), torch.stack(prompt_masks, dim=0)
356
+
357
+ def _compute_prompt_bridge_teachers(
358
+ self,
359
+ prompt_image_embeddings: torch.Tensor,
360
+ prompt_masks: torch.Tensor,
361
+ dtype: torch.dtype,
362
+ ) -> tuple:
363
+ mask_lowres = prompt_masks.unsqueeze(1)
364
+ _, dense_mask_embeddings = self.model.visual_model.prompt_encoder(
365
+ points=None,
366
+ boxes=None,
367
+ masks=mask_lowres.to(dtype=dtype),
368
+ text_embeds=None,
369
+ )
370
+ prompt_manifold_teacher = dense_mask_embeddings.mean(dim=(2, 3))
371
+
372
+ mask_64 = F.interpolate(
373
+ prompt_masks.unsqueeze(1),
374
+ size=prompt_image_embeddings.shape[-2:],
375
+ mode="nearest",
376
+ )
377
+ flat_feats = prompt_image_embeddings.flatten(2)
378
+ flat_mask = mask_64.flatten(2)
379
+ masked_sum = (flat_feats * flat_mask).sum(dim=-1)
380
+ mask_area = flat_mask.sum(dim=-1).clamp_min(1.0)
381
+ region_teacher = masked_sum / mask_area
382
+
383
+ return prompt_manifold_teacher, region_teacher
384
+
385
+ def _summarize_prompt_bridge(
386
+ self,
387
+ q: torch.Tensor,
388
+ p_hat: torch.Tensor,
389
+ prompt_manifold_teacher: torch.Tensor,
390
+ region_teacher: torch.Tensor,
391
+ gate: torch.Tensor,
392
+ ) -> dict:
393
+ delta = p_hat - q
394
+ q_norm = q.norm(dim=-1)
395
+ p_hat_norm = p_hat.norm(dim=-1)
396
+ pm_cos = F.cosine_similarity(p_hat, prompt_manifold_teacher, dim=-1)
397
+ rg_cos = F.cosine_similarity(p_hat, region_teacher, dim=-1)
398
+ qq_cos = F.cosine_similarity(p_hat, q, dim=-1)
399
+ teacher_cos = F.cosine_similarity(prompt_manifold_teacher, region_teacher, dim=-1)
400
+ delta_q_cos = F.cosine_similarity(delta, q, dim=-1)
401
+ delta_pm_cos = F.cosine_similarity(delta, prompt_manifold_teacher, dim=-1)
402
+ delta_rg_cos = F.cosine_similarity(delta, region_teacher, dim=-1)
403
+
404
+ return {
405
+ "q_norm_mean": q_norm.mean().item(),
406
+ "p_hat_norm_mean": p_hat_norm.mean().item(),
407
+ "delta_norm_mean": delta.norm(dim=-1).mean().item(),
408
+ "cos_p_hat_q_mean": qq_cos.mean().item(),
409
+ "cos_p_hat_p_mask_mean": pm_cos.mean().item(),
410
+ "cos_p_hat_z_gt_mean": rg_cos.mean().item(),
411
+ "cos_delta_q_mean": delta_q_cos.mean().item(),
412
+ "cos_delta_p_mask_mean": delta_pm_cos.mean().item(),
413
+ "cos_delta_z_gt_mean": delta_rg_cos.mean().item(),
414
+ "p_mask_norm_mean": prompt_manifold_teacher.norm(dim=-1).mean().item(),
415
+ "z_gt_norm_mean": region_teacher.norm(dim=-1).mean().item(),
416
+ "cos_p_mask_z_gt_mean": teacher_cos.mean().item(),
417
+ "gate_mean": gate.mean().item(),
418
+ "gate_std": gate.std(unbiased=False).item(),
419
+ }
420
 
421
 
422
 
 
467
  # audio_embeddings = audio_features # [B, 10, 128]
468
 
469
  if target_frame is None:
470
+ target_frame = self.bridge_target_frame
471
  else:
472
  target_frame = int(target_frame)
473
  if target_frame < 0 or target_frame >= num_frames:
 
498
  seg_hidden_states = output_hidden_states[-1][seg_token_mask] # [seg_num, hidden_size]
499
  seg_embeddings = self.model.text_hidden_fcs[0](seg_hidden_states) # [seg_num,256]
500
 
501
+ prompt_embeddings_all = seg_embeddings
502
+ bridge_metrics = {}
503
+ bridge_pm_loss = seg_embeddings.new_zeros(())
504
+ bridge_rg_loss = seg_embeddings.new_zeros(())
505
+ bridge_norm_loss = seg_embeddings.new_zeros(())
506
+ bridge_teacher_loss = seg_embeddings.new_zeros(())
507
+ bridge_teacher_loss_raw = seg_embeddings.new_zeros(())
508
+ prompt_image_embeddings = None
509
+ prompt_manifold_teacher = None
510
+ region_teacher = None
511
+
512
+ if self.use_residual_prompt_bridge:
513
+ prompt_image_embeddings, prompt_masks = self._expand_prompt_level_inputs(
514
+ image_features=image_features,
515
+ masks_list=masks_list,
516
+ refs_num=refs_num,
517
+ target_frame=target_frame,
518
+ dtype=seg_embeddings.dtype,
519
+ device=seg_embeddings.device,
520
+ )
521
+ bridge_outputs = self.model.prompt_bridge(seg_embeddings, prompt_image_embeddings)
522
+ prompt_embeddings_all = bridge_outputs["p_hat"]
523
+ prompt_manifold_teacher, region_teacher = self._compute_prompt_bridge_teachers(
524
+ prompt_image_embeddings=prompt_image_embeddings,
525
+ prompt_masks=prompt_masks,
526
+ dtype=seg_embeddings.dtype,
527
+ )
528
+
529
+ pm_l1 = F.smooth_l1_loss(prompt_embeddings_all, prompt_manifold_teacher)
530
+ pm_cos = 1.0 - F.cosine_similarity(
531
+ prompt_embeddings_all, prompt_manifold_teacher, dim=-1
532
+ ).mean()
533
+ bridge_pm_loss = pm_l1 + pm_cos
534
+ bridge_rg_loss = 1.0 - F.cosine_similarity(
535
+ prompt_embeddings_all, region_teacher, dim=-1
536
+ ).mean()
537
+ bridge_norm_loss = F.mse_loss(
538
+ prompt_embeddings_all.norm(dim=-1),
539
+ seg_embeddings.norm(dim=-1),
540
+ )
541
+ bridge_teacher_loss_raw = bridge_pm_loss + bridge_rg_loss + bridge_norm_loss
542
+ bridge_teacher_loss = (
543
+ self.bridge_pm_weight * bridge_pm_loss
544
+ + self.bridge_rg_weight * bridge_rg_loss
545
+ + self.bridge_norm_weight * bridge_norm_loss
546
+ )
547
+ bridge_metrics = self._summarize_prompt_bridge(
548
+ q=seg_embeddings,
549
+ p_hat=prompt_embeddings_all,
550
+ prompt_manifold_teacher=prompt_manifold_teacher,
551
+ region_teacher=region_teacher,
552
+ gate=bridge_outputs["gate"],
553
+ )
554
+
555
  # print("seg_embeddings in this batch:", seg_embeddings.shape)
556
  # print("vids:", vids)
557
  # print("fids:", fids)
 
574
 
575
 
576
  pred_embeddings = []
577
+ prompt_embeddings = []
578
  pred_hidden_states = []
579
  #--------------------------------------------------------------------------------------------
580
  pred_idx = 0
581
  for ref_num in refs_num:
582
  pred_embeddings.append(seg_embeddings[pred_idx:pred_idx + ref_num])
583
+ prompt_embeddings.append(prompt_embeddings_all[pred_idx:pred_idx + ref_num])
584
  pred_hidden_states.append(seg_hidden_states[pred_idx:pred_idx + ref_num])
585
  pred_idx += ref_num
586
  # list[B]:[num_seg, 256]
 
598
  points=None,
599
  boxes=None,
600
  masks=None,
601
+ text_embeds=prompt_embeddings[i].unsqueeze(1), # [1, 1 ,256]
602
  )
603
  # 确保数据类型一致
604
  sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
 
634
  gt_masks = masks_list # list[B]:[num_seg, T, H, W]
635
 
636
  if inference:
637
+ result = {
638
  "pred_masks": pred_masks, # list[B]:[num_seg, T, H, W]
639
  "gt_masks": gt_masks, # list[B]:[num_seg, T, H, W]
640
  "seg_embeddings": pred_embeddings, # list[B]:[num_seg, 256]
641
+ "prompt_embeddings": prompt_embeddings, # list[B]:[num_seg, 256]
642
  "seg_hidden_states": pred_hidden_states, # list[B]:[num_seg, hidden_size]
643
  }
644
+ if self.use_residual_prompt_bridge:
645
+ result.update(
646
+ {
647
+ "bridge_metrics": bridge_metrics,
648
+ "bridge_pm_loss": bridge_pm_loss.detach(),
649
+ "bridge_rg_loss": bridge_rg_loss.detach(),
650
+ "bridge_norm_loss": bridge_norm_loss.detach(),
651
+ }
652
+ )
653
+ return result
654
 
655
  model_output = output
656
  output = model_output.logits
 
701
  else:
702
  loss = ce_loss + mask_loss
703
 
704
+ loss = loss + bridge_teacher_loss
705
+
706
  return {
707
  "loss": loss,
708
  "ce_loss": ce_loss,
 
710
  "mask_dice_loss": mask_dice_loss,
711
  "mask_loss": mask_loss,
712
  "ct_loss": ct_loss,
713
+ "bridge_pm_loss": bridge_pm_loss,
714
+ "bridge_rg_loss": bridge_rg_loss,
715
+ "bridge_norm_loss": bridge_norm_loss,
716
+ "bridge_teacher_loss": bridge_teacher_loss,
717
+ "bridge_teacher_loss_raw": bridge_teacher_loss_raw,
718
+ "bridge_metrics": bridge_metrics,
719
  "pred_masks": pred_masks,
720
  "gt_masks": gt_masks,
721
  }