yfan07 commited on
Commit
ac63a19
·
verified ·
1 Parent(s): e214bf0

Restore original SimToken source files

Browse files
.gitattributes CHANGED
The diff for this file is too large to render. See raw diff
 
ChatUniVi/model/multimodal_encoder/eva_vit.py CHANGED
@@ -12,8 +12,8 @@ import torch
12
  import torch.nn as nn
13
  import torch.nn.functional as F
14
  import torch.utils.checkpoint as checkpoint
15
- from timm.layers import drop_path, to_2tuple, trunc_normal_
16
- from timm.models import register_model
17
 
18
  from .utils import download_cached_file
19
 
@@ -445,4 +445,4 @@ def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, pre
445
  if precision == "fp16":
446
  # model.to("cuda")
447
  convert_weights_to_fp16(model)
448
- return model
 
12
  import torch.nn as nn
13
  import torch.nn.functional as F
14
  import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+ from timm.models.registry import register_model
17
 
18
  from .utils import download_cached_file
19
 
 
445
  if precision == "fp16":
446
  # model.to("cuda")
447
  convert_weights_to_fp16(model)
448
+ return model
ChatUniVi/model/multimodal_encoder/utils.py CHANGED
@@ -11,10 +11,7 @@ import os
11
 
12
  import torch
13
  import torch.distributed as dist
14
- from timm.models._hub import (
15
- download_cached_file as timm_download_cached_file,
16
- get_cache_dir as timm_get_cache_dir,
17
- )
18
 
19
 
20
  def setup_for_distributed(is_master):
@@ -127,14 +124,14 @@ def download_cached_file(url, check_hash=True, progress=False):
127
  # a hack to sync the file path across processes
128
  parts = torch.hub.urlparse(url)
129
  filename = os.path.basename(parts.path)
130
- cached_file = os.path.join(timm_get_cache_dir(), filename)
131
 
132
  return cached_file
133
 
134
  if is_main_process():
135
- timm_download_cached_file(url, check_hash, progress)
136
 
137
  if is_dist_avail_and_initialized():
138
  dist.barrier()
139
 
140
- return get_cached_file_path()
 
11
 
12
  import torch
13
  import torch.distributed as dist
14
+ import timm.models.hub as timm_hub
 
 
 
15
 
16
 
17
  def setup_for_distributed(is_master):
 
124
  # a hack to sync the file path across processes
125
  parts = torch.hub.urlparse(url)
126
  filename = os.path.basename(parts.path)
127
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
128
 
129
  return cached_file
130
 
131
  if is_main_process():
132
+ timm_hub.download_cached_file(url, check_hash, progress)
133
 
134
  if is_dist_avail_and_initialized():
135
  dist.barrier()
136
 
137
+ return get_cached_file_path()
configs/config.py CHANGED
@@ -31,7 +31,7 @@ parser = argparse.ArgumentParser(
31
 
32
 
33
 
34
- parser.add_argument("--vision_pretrained",type=str,default='/workspace/SimToken/models/segment_anything/sam_vit_h_4b8939.pth')
35
  parser.add_argument("--vision_tower",type=str,default='openai/clip-vit-large-patch14')
36
  parser.add_argument("--mllm",type=str,default='Chat-UniVi/Chat-UniVi-7B-v1.5')
37
 
@@ -44,9 +44,9 @@ parser.add_argument("--start",type=int,default=0)
44
 
45
  parser.add_argument("--name",type=str,default='testrun')
46
  # path to ref-avs dataset
47
- parser.add_argument("--data_dir",type=str,default='/workspace/SimToken/data',help=f"The data paranet dir. File arch should be: {file_arch}")
48
  # path to pretrained checkpoints
49
- parser.add_argument("--saved_model",type=str,default='/workspace/SimToken/checkpoints/simtoken_pretrained.pth', help="the pretrained simtoken pth")
50
 
51
 
52
  parser.add_argument("--log_root",type=str,default='log', help="where to save log during training")
@@ -64,9 +64,6 @@ parser.add_argument("--lr", type=float, default=5e-5, help='lr to fine tuning ad
64
  # epochs
65
  parser.add_argument("--epochs", type=int, default=10, help='epochs to fine tuning adapters.')
66
  parser.add_argument("--batch_size", type=int, default=8)
67
- parser.add_argument("--ce_loss_weight", type=float, default=1.0, help="Weight for language modeling loss.")
68
- parser.add_argument("--dice_loss_weight", type=float, default=0.5, help="Weight for dice segmentation loss.")
69
- parser.add_argument("--bce_loss_weight", type=float, default=2.0, help="Weight for BCE segmentation loss.")
70
 
71
 
72
  parser.add_argument("--gpu_id", type=str, default="0", help="The GPU device to run generation on.")
@@ -75,36 +72,6 @@ parser.add_argument("--run", type=str, default='train', help="train, test")
75
 
76
  parser.add_argument("--frame_n", type=int, default=10, help="Frame num of each video. Fixed to 10.")
77
  parser.add_argument("--text_max_len", type=int, default=25, help="Maximum textual reference length.")
78
- parser.add_argument("--max_eval_rows", type=int, default=-1, help="Max samples per split during eval; -1 = all.")
79
- parser.add_argument("--subset_manifest", type=str, default="", help="Optional JSON file that fixes train/eval subset indices per split.")
80
- parser.add_argument("--eval_split", type=str, default="test_u", help="Which split to evaluate: test_s, test_u, test_n.")
81
- parser.add_argument("--gate_only", action="store_true", help="Train only A-min referent gate parameters.")
82
- parser.add_argument("--init_from_saved_model", action="store_true", help="Initialize training from --saved_model before updates.")
83
- parser.add_argument("--max_steps", type=int, default=-1, help="Max optimizer steps during training; -1 = full schedule.")
84
- parser.add_argument("--overfit_samples", type=int, default=-1, help="Train on the first N train samples for overfit probes; -1 = full train set.")
85
- parser.add_argument("--log_gate_stats_every", type=int, default=-1, help="Log A-min gate/proj stats every N optimizer steps; -1 = disabled.")
86
- parser.add_argument("--skip_eval_after_train", action="store_true", help="Save checkpoint and exit without post-train evaluation.")
87
- parser.add_argument("--eval_train_only", action="store_true", help="After training, evaluate only the training subset and skip test splits.")
88
- parser.add_argument("--cache_root", type=str, default="/workspace/SimToken/cache_q", help="Root directory for cached q features.")
89
- parser.add_argument("--cache_split", type=str, default="train", help="Dataset split to cache or read cached q features from.")
90
- parser.add_argument("--overwrite_cache", action="store_true", help="Overwrite existing cached q feature files.")
91
- parser.add_argument("--eval_only", action="store_true", help="Only evaluate in cached-gate scripts; do not train.")
92
- parser.add_argument("--disable_gate", action="store_true", help="Force A-min gate to identity for cached pipeline baseline checks.")
93
- parser.add_argument("--gate_checkpoint", type=str, default="", help="Optional referent-gate-only checkpoint to overlay after loading --saved_model.")
94
- parser.add_argument("--save_gate_only", action="store_true", help="In cached-gate training, save only referent_gate parameters.")
95
- parser.add_argument("--use_residual_prompt_bridge", action="store_true", help="Enable the image-conditioned residual prompt bridge before SAM prompt encoding.")
96
- parser.add_argument("--bridge_only", action="store_true", help="Freeze all parameters except the residual prompt bridge.")
97
- parser.add_argument("--bridge_pm_weight", type=float, default=0.0, help="Weight for prompt-manifold teacher loss.")
98
- parser.add_argument("--bridge_rg_weight", type=float, default=0.0, help="Weight for region-semantic teacher loss.")
99
- parser.add_argument("--bridge_norm_weight", type=float, default=0.0, help="Weight for prompt-norm preservation loss.")
100
- parser.add_argument("--bridge_mode", type=str, default="additive", choices=["additive", "directional"], help="Prompt bridge parameterization.")
101
- parser.add_argument("--bridge_condition", type=str, default="image", choices=["image", "q_only"], help="Condition source for the prompt bridge.")
102
- parser.add_argument("--bridge_directional_alpha", type=float, default=0.1, help="Step size used by directional bridge updates after orthogonalization.")
103
- parser.add_argument("--bridge_gate_bias_init", type=float, default=-4.0, help="Initial bias for bridge gate sigmoid.")
104
- parser.add_argument("--bridge_residual_init_std", type=float, default=1e-3, help="Std used to initialize the bridge residual projection.")
105
- parser.add_argument("--bridge_target_frame", type=int, default=5, help="Frame index used to build bridge teachers.")
106
- parser.add_argument("--bridge_sanity_only", action="store_true", help="Run only bridge sanity checks (gradient, identity, teacher norms) and exit.")
107
- parser.add_argument("--bridge_sanity_batches", type=int, default=3, help="How many batches to scan during bridge sanity stats collection.")
108
 
109
 
110
 
 
31
 
32
 
33
 
34
+ parser.add_argument("--vision_pretrained",type=str,default='path/to/segment_anything/sam_vit_h_4b8939.pth')
35
  parser.add_argument("--vision_tower",type=str,default='openai/clip-vit-large-patch14')
36
  parser.add_argument("--mllm",type=str,default='Chat-UniVi/Chat-UniVi-7B-v1.5')
37
 
 
44
 
45
  parser.add_argument("--name",type=str,default='testrun')
46
  # path to ref-avs dataset
47
+ parser.add_argument("--data_dir",type=str,default='data',help=f"The data paranet dir. File arch should be: {file_arch}")
48
  # path to pretrained checkpoints
49
+ parser.add_argument("--saved_model",type=str,default='trained_simtoken.pth', help="the pretrained simtoken pth")
50
 
51
 
52
  parser.add_argument("--log_root",type=str,default='log', help="where to save log during training")
 
64
  # epochs
65
  parser.add_argument("--epochs", type=int, default=10, help='epochs to fine tuning adapters.')
66
  parser.add_argument("--batch_size", type=int, default=8)
 
 
 
67
 
68
 
69
  parser.add_argument("--gpu_id", type=str, default="0", help="The GPU device to run generation on.")
 
72
 
73
  parser.add_argument("--frame_n", type=int, default=10, help="Frame num of each video. Fixed to 10.")
74
  parser.add_argument("--text_max_len", type=int, default=25, help="Maximum textual reference length.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
 
77
 
data/metadata.csv CHANGED
The diff for this file is too large to render. See raw diff
 
load_model.py CHANGED
@@ -208,10 +208,7 @@ def collate_fn(batch, tokenizer=None):
208
 
209
  import torch.multiprocessing as mp
210
  if __name__ == "__main__":
211
- try:
212
- mp.set_start_method("spawn")
213
- except RuntimeError:
214
- pass
215
  set_seed(42)
216
  tokenizer = transformers.AutoTokenizer.from_pretrained(
217
  args.mllm,
@@ -227,17 +224,14 @@ if __name__ == "__main__":
227
  print("seg_token_idx: ", seg_token_idx)
228
 
229
 
230
- if args.eval_split not in {"test_s", "test_u", "test_n"}:
231
- raise ValueError(f"Unsupported eval_split: {args.eval_split}")
 
232
 
233
- val_dataset = REFAVS(args.eval_split, args, tokenizer, input_type='refer')
234
- val_dataloader = DataLoader(
235
- val_dataset,
236
- batch_size=1,
237
- shuffle=False,
238
- num_workers=4,
239
- collate_fn=partial(collate_fn, tokenizer=tokenizer),
240
- )
241
 
242
 
243
 
@@ -343,12 +337,8 @@ if __name__ == "__main__":
343
  model = model.to("cuda")
344
  model.resize_token_embeddings(len(tokenizer))
345
 
346
- missing, unexpected = model.load_state_dict(
347
- torch.load(args.saved_model, map_location="cpu"),
348
- strict=False,
349
- )
350
- print(f"saved model loaded: {args.saved_model}")
351
- print(f"missing keys: {len(missing)} | unexpected keys: {len(unexpected)}")
352
 
353
 
354
  save_root = args.visualization_root
@@ -414,9 +404,7 @@ if __name__ == "__main__":
414
  total_fscore = 0
415
  count = 0
416
 
417
- for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Evaluating on {name}")):
418
- if args.max_eval_rows > 0 and batch_idx >= args.max_eval_rows:
419
- break
420
  input_dict = dict_to_cuda(batch)
421
 
422
  with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
@@ -450,9 +438,6 @@ if __name__ == "__main__":
450
  total_fscore += fscore * num_seg * T
451
  count += num_seg * T
452
 
453
- if count == 0:
454
- raise RuntimeError(f"No samples were evaluated for {name}")
455
-
456
  print(f"\n valuate on {name}: miou: {total_iou/count} fscore: {total_fscore/count}")
457
 
458
 
@@ -462,9 +447,7 @@ if __name__ == "__main__":
462
  total_metric = 0
463
  count = 0
464
 
465
- for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Evaluating on Null")):
466
- if args.max_eval_rows > 0 and batch_idx >= args.max_eval_rows:
467
- break
468
  input_dict = dict_to_cuda(batch)
469
  with torch.no_grad():
470
  output_dict = model.forward(images=input_dict["images"],
@@ -494,13 +477,14 @@ if __name__ == "__main__":
494
  total_metric += null_metric * num_seg * T
495
  count += num_seg * T
496
 
497
- if count == 0:
498
- raise RuntimeError("No samples were evaluated for test_n")
499
-
500
  print(f"\n valuate on test_n_refer, metric: {total_metric / count}")
501
 
502
 
503
- if args.eval_split == "test_n":
504
- valuate_Null(model, val_dataloader)
505
- else:
506
- valuate(model, val_dataloader, args.eval_split)
 
 
 
 
 
208
 
209
  import torch.multiprocessing as mp
210
  if __name__ == "__main__":
211
+ mp.set_start_method("spawn")
 
 
 
212
  set_seed(42)
213
  tokenizer = transformers.AutoTokenizer.from_pretrained(
214
  args.mllm,
 
224
  print("seg_token_idx: ", seg_token_idx)
225
 
226
 
227
+ val_dataset_s = REFAVS('test_s', args, tokenizer, input_type='refer')
228
+ # val_dataset_u = REFAVS('test_u', args, tokenizer, input_type='refer')
229
+ # val_dataset_n = REFAVS('test_n', args, tokenizer, input_type='refer')
230
 
231
+
232
+ val_dataloader_s = DataLoader(val_dataset_s, batch_size=1, shuffle=False, num_workers=4, collate_fn=partial(collate_fn, tokenizer=tokenizer))
233
+ # val_dataloader_u = DataLoader(val_dataset_u, batch_size=1, shuffle=False, num_workers=4, collate_fn=partial(collate_fn, tokenizer=tokenizer))
234
+ # val_dataloader_n = DataLoader(val_dataset_n, batch_size=2, shuffle=False, num_workers=4, collate_fn=partial(collate_fn, tokenizer=tokenizer))
 
 
 
 
235
 
236
 
237
 
 
337
  model = model.to("cuda")
338
  model.resize_token_embeddings(len(tokenizer))
339
 
340
+ model.load_state_dict(torch.load(args.saved_model), strict=False)
341
+ print("saved model loaded")
 
 
 
 
342
 
343
 
344
  save_root = args.visualization_root
 
404
  total_fscore = 0
405
  count = 0
406
 
407
+ for batch in tqdm(dataloader, desc=f"Evaluating on {name}"):
 
 
408
  input_dict = dict_to_cuda(batch)
409
 
410
  with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
 
438
  total_fscore += fscore * num_seg * T
439
  count += num_seg * T
440
 
 
 
 
441
  print(f"\n valuate on {name}: miou: {total_iou/count} fscore: {total_fscore/count}")
442
 
443
 
 
447
  total_metric = 0
448
  count = 0
449
 
450
+ for batch in tqdm(dataloader, desc=f"Evaluating on Null"):
 
 
451
  input_dict = dict_to_cuda(batch)
452
  with torch.no_grad():
453
  output_dict = model.forward(images=input_dict["images"],
 
477
  total_metric += null_metric * num_seg * T
478
  count += num_seg * T
479
 
 
 
 
480
  print(f"\n valuate on test_n_refer, metric: {total_metric / count}")
481
 
482
 
483
+
484
+
485
+ valuate(model, val_dataloader_s, 'test_seen')
486
+
487
+ # valuate(model, val_dataloader_u, 'test_unseen')
488
+ #
489
+ # valuate_Null(model, val_dataloader_u)
490
+
models/avs_model.py CHANGED
@@ -100,74 +100,6 @@ def compute_alignment_loss(q: torch.Tensor, pos_feats: list, neg_feats: list, te
100
  return total_loss / count
101
 
102
 
103
- class ResidualPromptBridge(nn.Module):
104
- def __init__(
105
- self,
106
- embedding_dim: int,
107
- mode: str = "additive",
108
- condition: str = "image",
109
- directional_alpha: float = 0.1,
110
- gate_bias_init: float = -4.0,
111
- residual_init_std: float = 1e-3,
112
- ) -> None:
113
- super().__init__()
114
- self.embedding_dim = embedding_dim
115
- self.mode = mode
116
- self.condition = condition
117
- self.directional_alpha = directional_alpha
118
- self.scale = math.sqrt(float(embedding_dim))
119
- self.attn_proj = nn.Linear(embedding_dim, embedding_dim, bias=False)
120
- self.residual_proj = nn.Linear(embedding_dim, embedding_dim, bias=False)
121
- self.gate = nn.Linear(embedding_dim * 2, embedding_dim)
122
- self.reset_parameters(gate_bias_init=gate_bias_init, residual_init_std=residual_init_std)
123
-
124
- def reset_parameters(self, gate_bias_init: float, residual_init_std: float) -> None:
125
- nn.init.xavier_uniform_(self.attn_proj.weight)
126
- nn.init.normal_(self.residual_proj.weight, mean=0.0, std=residual_init_std)
127
- nn.init.zeros_(self.gate.weight)
128
- nn.init.constant_(self.gate.bias, gate_bias_init)
129
-
130
- def forward(self, q: torch.Tensor, image_embeddings: torch.Tensor) -> dict:
131
- if self.condition == "q_only":
132
- attn = None
133
- region = self.attn_proj(q)
134
- else:
135
- if image_embeddings.dim() != 4:
136
- raise ValueError(
137
- f"ResidualPromptBridge expects image_embeddings [B, C, H, W], got {tuple(image_embeddings.shape)}"
138
- )
139
- image_tokens = image_embeddings.flatten(2).transpose(1, 2) # [B, HW, C]
140
- q_proj = self.attn_proj(q) # [B, C]
141
- attn_logits = torch.bmm(image_tokens, q_proj.unsqueeze(-1)).squeeze(-1) / self.scale
142
- attn = torch.softmax(attn_logits, dim=-1)
143
- region = torch.bmm(attn.unsqueeze(1), image_tokens).squeeze(1)
144
-
145
- gate = torch.sigmoid(self.gate(torch.cat([q, region], dim=-1)))
146
- region_update = self.residual_proj(region)
147
-
148
- if self.mode == "directional":
149
- q_dir = F.normalize(q, dim=-1)
150
- q_parallel = (region_update * q_dir).sum(dim=-1, keepdim=True) * q_dir
151
- region_orth = region_update - q_parallel
152
- region_orth_norm = region_orth.norm(dim=-1, keepdim=True).clamp_min(1e-6)
153
- region_dir = region_orth / region_orth_norm
154
- alpha = self.directional_alpha * gate.mean(dim=-1, keepdim=True)
155
- mixed_dir = F.normalize(q_dir + alpha * region_dir, dim=-1)
156
- p_hat = q.norm(dim=-1, keepdim=True) * mixed_dir
157
- delta = p_hat - q
158
- else:
159
- delta = gate * region_update
160
- p_hat = q + delta
161
-
162
- return {
163
- "p_hat": p_hat,
164
- "attn": attn,
165
- "region": region,
166
- "gate": gate,
167
- "delta": delta,
168
- }
169
-
170
-
171
 
172
 
173
  class Simtoken_MetaModel:
@@ -183,12 +115,6 @@ class Simtoken_MetaModel:
183
  self.config.train_mask_decoder = kwargs["train_mask_decoder"]
184
  self.config.out_dim = kwargs["out_dim"]
185
  self.vision_pretrained = kwargs.get("vision_pretrained", None)
186
- self.config.use_residual_prompt_bridge = kwargs.get("use_residual_prompt_bridge", False)
187
- self.config.bridge_mode = kwargs.get("bridge_mode", "additive")
188
- self.config.bridge_condition = kwargs.get("bridge_condition", "image")
189
- self.config.bridge_directional_alpha = kwargs.get("bridge_directional_alpha", 0.1)
190
- self.config.bridge_gate_bias_init = kwargs.get("bridge_gate_bias_init", -4.0)
191
- self.config.bridge_residual_init_std = kwargs.get("bridge_residual_init_std", 1e-3)
192
  else:
193
  self.vision_pretrained = kwargs.get("vision_pretrained", None)
194
  self.initialize_lisa_modules(self.config)
@@ -217,17 +143,6 @@ class Simtoken_MetaModel:
217
  for param in self.text_hidden_fcs.parameters():
218
  param.requires_grad = True
219
 
220
- self.prompt_bridge = None
221
- if getattr(config, "use_residual_prompt_bridge", False):
222
- self.prompt_bridge = ResidualPromptBridge(
223
- embedding_dim=out_dim,
224
- mode=getattr(config, "bridge_mode", "additive"),
225
- condition=getattr(config, "bridge_condition", "image"),
226
- directional_alpha=getattr(config, "bridge_directional_alpha", 0.1),
227
- gate_bias_init=getattr(config, "bridge_gate_bias_init", -4.0),
228
- residual_init_std=getattr(config, "bridge_residual_init_std", 1e-3),
229
- )
230
-
231
 
232
  class Simtoken_Model(Simtoken_MetaModel, ChatUniViLlamaModel):
233
  def __init__(
@@ -319,104 +234,6 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
319
  self.compress = kwargs.pop("compress", True)
320
 
321
  self.start = kwargs.pop("start")
322
- self.use_residual_prompt_bridge = kwargs.pop("use_residual_prompt_bridge", False)
323
- self.bridge_pm_weight = kwargs.pop("bridge_pm_weight", 0.0)
324
- self.bridge_rg_weight = kwargs.pop("bridge_rg_weight", 0.0)
325
- self.bridge_norm_weight = kwargs.pop("bridge_norm_weight", 0.0)
326
- self.bridge_target_frame = kwargs.pop("bridge_target_frame", 5)
327
-
328
- def _expand_prompt_level_inputs(
329
- self,
330
- image_features: List[torch.Tensor],
331
- masks_list: List[torch.FloatTensor],
332
- refs_num: List[int],
333
- target_frame: int,
334
- dtype: torch.dtype,
335
- device: torch.device,
336
- ) -> tuple:
337
- prompt_image_embeddings = []
338
- prompt_masks = []
339
- prompt_mask_size = self.model.visual_model.prompt_encoder.mask_input_size
340
-
341
- for sample_idx, ref_num in enumerate(refs_num):
342
- frame_feat = image_features[sample_idx][target_frame].to(device=device, dtype=dtype)
343
- for prompt_idx in range(ref_num):
344
- prompt_image_embeddings.append(frame_feat)
345
- mask = masks_list[sample_idx][prompt_idx, target_frame].to(
346
- device=device, dtype=torch.float32
347
- )
348
- mask = F.interpolate(
349
- mask.unsqueeze(0).unsqueeze(0),
350
- size=prompt_mask_size,
351
- mode="nearest",
352
- ).squeeze(0).squeeze(0)
353
- prompt_masks.append(mask)
354
-
355
- return torch.stack(prompt_image_embeddings, dim=0), torch.stack(prompt_masks, dim=0)
356
-
357
- def _compute_prompt_bridge_teachers(
358
- self,
359
- prompt_image_embeddings: torch.Tensor,
360
- prompt_masks: torch.Tensor,
361
- dtype: torch.dtype,
362
- ) -> tuple:
363
- mask_lowres = prompt_masks.unsqueeze(1)
364
- _, dense_mask_embeddings = self.model.visual_model.prompt_encoder(
365
- points=None,
366
- boxes=None,
367
- masks=mask_lowres.to(dtype=dtype),
368
- text_embeds=None,
369
- )
370
- prompt_manifold_teacher = dense_mask_embeddings.mean(dim=(2, 3))
371
-
372
- mask_64 = F.interpolate(
373
- prompt_masks.unsqueeze(1),
374
- size=prompt_image_embeddings.shape[-2:],
375
- mode="nearest",
376
- )
377
- flat_feats = prompt_image_embeddings.flatten(2)
378
- flat_mask = mask_64.flatten(2)
379
- masked_sum = (flat_feats * flat_mask).sum(dim=-1)
380
- mask_area = flat_mask.sum(dim=-1).clamp_min(1.0)
381
- region_teacher = masked_sum / mask_area
382
-
383
- return prompt_manifold_teacher, region_teacher
384
-
385
- def _summarize_prompt_bridge(
386
- self,
387
- q: torch.Tensor,
388
- p_hat: torch.Tensor,
389
- prompt_manifold_teacher: torch.Tensor,
390
- region_teacher: torch.Tensor,
391
- gate: torch.Tensor,
392
- ) -> dict:
393
- delta = p_hat - q
394
- q_norm = q.norm(dim=-1)
395
- p_hat_norm = p_hat.norm(dim=-1)
396
- pm_cos = F.cosine_similarity(p_hat, prompt_manifold_teacher, dim=-1)
397
- rg_cos = F.cosine_similarity(p_hat, region_teacher, dim=-1)
398
- qq_cos = F.cosine_similarity(p_hat, q, dim=-1)
399
- teacher_cos = F.cosine_similarity(prompt_manifold_teacher, region_teacher, dim=-1)
400
- delta_q_cos = F.cosine_similarity(delta, q, dim=-1)
401
- delta_pm_cos = F.cosine_similarity(delta, prompt_manifold_teacher, dim=-1)
402
- delta_rg_cos = F.cosine_similarity(delta, region_teacher, dim=-1)
403
-
404
- return {
405
- "q_norm_mean": q_norm.mean().item(),
406
- "p_hat_norm_mean": p_hat_norm.mean().item(),
407
- "delta_norm_mean": delta.norm(dim=-1).mean().item(),
408
- "cos_p_hat_q_mean": qq_cos.mean().item(),
409
- "cos_p_hat_p_mask_mean": pm_cos.mean().item(),
410
- "cos_p_hat_z_gt_mean": rg_cos.mean().item(),
411
- "cos_delta_q_mean": delta_q_cos.mean().item(),
412
- "cos_delta_p_mask_mean": delta_pm_cos.mean().item(),
413
- "cos_delta_z_gt_mean": delta_rg_cos.mean().item(),
414
- "p_mask_norm_mean": prompt_manifold_teacher.norm(dim=-1).mean().item(),
415
- "z_gt_norm_mean": region_teacher.norm(dim=-1).mean().item(),
416
- "cos_p_mask_z_gt_mean": teacher_cos.mean().item(),
417
- "gate_mean": gate.mean().item(),
418
- "gate_std": gate.std(unbiased=False).item(),
419
- }
420
 
421
 
422
 
@@ -453,7 +270,6 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
453
  epoch: int =0,
454
  inference: bool = False,
455
  num_frames: int = 10,
456
- target_frame: int = None,
457
  contrast: float = 0.0,
458
 
459
  **kwargs,
@@ -466,12 +282,14 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
466
  # audio_embeddings = torch.cat(audio_features, dim=0) # [B*10, 128]
467
  # audio_embeddings = audio_features # [B, 10, 128]
468
 
469
- if target_frame is None:
470
- target_frame = self.bridge_target_frame
 
 
 
471
  else:
472
- target_frame = int(target_frame)
473
- if target_frame < 0 or target_frame >= num_frames:
474
- raise ValueError(f"target_frame must be in [0, {num_frames}), got {target_frame}")
475
 
476
  input_ids, attention_masks, past_key_values, inputs_embeds, labels = super().prepare_inputs_labels_for_multimodal(
477
  input_ids, attention_masks, past_key_values=None, labels=labels, images=images_clip, audio_features=audio_embeddings, target_frame=target_frame, ref_ids=ref_ids
@@ -495,62 +313,7 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
495
  dim=1, ) # [batch_size, seq_len]
496
 
497
 
498
- seg_hidden_states = output_hidden_states[-1][seg_token_mask] # [seg_num, hidden_size]
499
- seg_embeddings = self.model.text_hidden_fcs[0](seg_hidden_states) # [seg_num,256]
500
-
501
- prompt_embeddings_all = seg_embeddings
502
- bridge_metrics = {}
503
- bridge_pm_loss = seg_embeddings.new_zeros(())
504
- bridge_rg_loss = seg_embeddings.new_zeros(())
505
- bridge_norm_loss = seg_embeddings.new_zeros(())
506
- bridge_teacher_loss = seg_embeddings.new_zeros(())
507
- bridge_teacher_loss_raw = seg_embeddings.new_zeros(())
508
- prompt_image_embeddings = None
509
- prompt_manifold_teacher = None
510
- region_teacher = None
511
-
512
- if self.use_residual_prompt_bridge:
513
- prompt_image_embeddings, prompt_masks = self._expand_prompt_level_inputs(
514
- image_features=image_features,
515
- masks_list=masks_list,
516
- refs_num=refs_num,
517
- target_frame=target_frame,
518
- dtype=seg_embeddings.dtype,
519
- device=seg_embeddings.device,
520
- )
521
- bridge_outputs = self.model.prompt_bridge(seg_embeddings, prompt_image_embeddings)
522
- prompt_embeddings_all = bridge_outputs["p_hat"]
523
- prompt_manifold_teacher, region_teacher = self._compute_prompt_bridge_teachers(
524
- prompt_image_embeddings=prompt_image_embeddings,
525
- prompt_masks=prompt_masks,
526
- dtype=seg_embeddings.dtype,
527
- )
528
-
529
- pm_l1 = F.smooth_l1_loss(prompt_embeddings_all, prompt_manifold_teacher)
530
- pm_cos = 1.0 - F.cosine_similarity(
531
- prompt_embeddings_all, prompt_manifold_teacher, dim=-1
532
- ).mean()
533
- bridge_pm_loss = pm_l1 + pm_cos
534
- bridge_rg_loss = 1.0 - F.cosine_similarity(
535
- prompt_embeddings_all, region_teacher, dim=-1
536
- ).mean()
537
- bridge_norm_loss = F.mse_loss(
538
- prompt_embeddings_all.norm(dim=-1),
539
- seg_embeddings.norm(dim=-1),
540
- )
541
- bridge_teacher_loss_raw = bridge_pm_loss + bridge_rg_loss + bridge_norm_loss
542
- bridge_teacher_loss = (
543
- self.bridge_pm_weight * bridge_pm_loss
544
- + self.bridge_rg_weight * bridge_rg_loss
545
- + self.bridge_norm_weight * bridge_norm_loss
546
- )
547
- bridge_metrics = self._summarize_prompt_bridge(
548
- q=seg_embeddings,
549
- p_hat=prompt_embeddings_all,
550
- prompt_manifold_teacher=prompt_manifold_teacher,
551
- region_teacher=region_teacher,
552
- gate=bridge_outputs["gate"],
553
- )
554
 
555
  # print("seg_embeddings in this batch:", seg_embeddings.shape)
556
  # print("vids:", vids)
@@ -574,14 +337,10 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
574
 
575
 
576
  pred_embeddings = []
577
- prompt_embeddings = []
578
- pred_hidden_states = []
579
  #--------------------------------------------------------------------------------------------
580
  pred_idx = 0
581
  for ref_num in refs_num:
582
  pred_embeddings.append(seg_embeddings[pred_idx:pred_idx + ref_num])
583
- prompt_embeddings.append(prompt_embeddings_all[pred_idx:pred_idx + ref_num])
584
- pred_hidden_states.append(seg_hidden_states[pred_idx:pred_idx + ref_num])
585
  pred_idx += ref_num
586
  # list[B]:[num_seg, 256]
587
 
@@ -598,7 +357,7 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
598
  points=None,
599
  boxes=None,
600
  masks=None,
601
- text_embeds=prompt_embeddings[i].unsqueeze(1), # [1, 1 ,256]
602
  )
603
  # 确保数据类型一致
604
  sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
@@ -634,23 +393,10 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
634
  gt_masks = masks_list # list[B]:[num_seg, T, H, W]
635
 
636
  if inference:
637
- result = {
638
- "pred_masks": pred_masks, # list[B]:[num_seg, T, H, W]
639
- "gt_masks": gt_masks, # list[B]:[num_seg, T, H, W]
640
- "seg_embeddings": pred_embeddings, # list[B]:[num_seg, 256]
641
- "prompt_embeddings": prompt_embeddings, # list[B]:[num_seg, 256]
642
- "seg_hidden_states": pred_hidden_states, # list[B]:[num_seg, hidden_size]
643
  }
644
- if self.use_residual_prompt_bridge:
645
- result.update(
646
- {
647
- "bridge_metrics": bridge_metrics,
648
- "bridge_pm_loss": bridge_pm_loss.detach(),
649
- "bridge_rg_loss": bridge_rg_loss.detach(),
650
- "bridge_norm_loss": bridge_norm_loss.detach(),
651
- }
652
- )
653
- return result
654
 
655
  model_output = output
656
  output = model_output.logits
@@ -701,8 +447,6 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
701
  else:
702
  loss = ce_loss + mask_loss
703
 
704
- loss = loss + bridge_teacher_loss
705
-
706
  return {
707
  "loss": loss,
708
  "ce_loss": ce_loss,
@@ -710,12 +454,6 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
710
  "mask_dice_loss": mask_dice_loss,
711
  "mask_loss": mask_loss,
712
  "ct_loss": ct_loss,
713
- "bridge_pm_loss": bridge_pm_loss,
714
- "bridge_rg_loss": bridge_rg_loss,
715
- "bridge_norm_loss": bridge_norm_loss,
716
- "bridge_teacher_loss": bridge_teacher_loss,
717
- "bridge_teacher_loss_raw": bridge_teacher_loss_raw,
718
- "bridge_metrics": bridge_metrics,
719
  "pred_masks": pred_masks,
720
  "gt_masks": gt_masks,
721
  }
@@ -723,3 +461,4 @@ class Simtoken_ForCausalLM(ChatUniViLlamaForCausalLM):
723
 
724
  def evaluate(self, *args, **kwargs):
725
  raise NotImplementedError("This method is not implemented.")
 
 
100
  return total_loss / count
101
 
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
 
105
  class Simtoken_MetaModel:
 
115
  self.config.train_mask_decoder = kwargs["train_mask_decoder"]
116
  self.config.out_dim = kwargs["out_dim"]
117
  self.vision_pretrained = kwargs.get("vision_pretrained", None)
 
 
 
 
 
 
118
  else:
119
  self.vision_pretrained = kwargs.get("vision_pretrained", None)
120
  self.initialize_lisa_modules(self.config)
 
143
  for param in self.text_hidden_fcs.parameters():
144
  param.requires_grad = True
145
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  class Simtoken_Model(Simtoken_MetaModel, ChatUniViLlamaModel):
148
  def __init__(
 
234
  self.compress = kwargs.pop("compress", True)
235
 
236
  self.start = kwargs.pop("start")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
 
239
 
 
270
  epoch: int =0,
271
  inference: bool = False,
272
  num_frames: int = 10,
 
273
  contrast: float = 0.0,
274
 
275
  **kwargs,
 
282
  # audio_embeddings = torch.cat(audio_features, dim=0) # [B*10, 128]
283
  # audio_embeddings = audio_features # [B, 10, 128]
284
 
285
+ # train
286
+ if not inference:
287
+ target_frame = random.randint(0, 9)
288
+ target_frame = 5
289
+
290
  else:
291
+ target_frame = 5
292
+ # print("target_frame", target_frame)
 
293
 
294
  input_ids, attention_masks, past_key_values, inputs_embeds, labels = super().prepare_inputs_labels_for_multimodal(
295
  input_ids, attention_masks, past_key_values=None, labels=labels, images=images_clip, audio_features=audio_embeddings, target_frame=target_frame, ref_ids=ref_ids
 
313
  dim=1, ) # [batch_size, seq_len]
314
 
315
 
316
+ seg_embeddings = self.model.text_hidden_fcs[0](output_hidden_states[-1][seg_token_mask]) # [seg_num,256]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
 
318
  # print("seg_embeddings in this batch:", seg_embeddings.shape)
319
  # print("vids:", vids)
 
337
 
338
 
339
  pred_embeddings = []
 
 
340
  #--------------------------------------------------------------------------------------------
341
  pred_idx = 0
342
  for ref_num in refs_num:
343
  pred_embeddings.append(seg_embeddings[pred_idx:pred_idx + ref_num])
 
 
344
  pred_idx += ref_num
345
  # list[B]:[num_seg, 256]
346
 
 
357
  points=None,
358
  boxes=None,
359
  masks=None,
360
+ text_embeds=pred_embeddings[i].unsqueeze(1), # [1, 1 ,256]
361
  )
362
  # 确保数据类型一致
363
  sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
 
393
  gt_masks = masks_list # list[B]:[num_seg, T, H, W]
394
 
395
  if inference:
396
+ return {
397
+ "pred_masks": pred_masks, # list[B]:[num_seg, T, H, W]
398
+ "gt_masks": gt_masks, # list[B]:[num_seg, T, H, W]
 
 
 
399
  }
 
 
 
 
 
 
 
 
 
 
400
 
401
  model_output = output
402
  output = model_output.logits
 
447
  else:
448
  loss = ce_loss + mask_loss
449
 
 
 
450
  return {
451
  "loss": loss,
452
  "ce_loss": ce_loss,
 
454
  "mask_dice_loss": mask_dice_loss,
455
  "mask_loss": mask_loss,
456
  "ct_loss": ct_loss,
 
 
 
 
 
 
457
  "pred_masks": pred_masks,
458
  "gt_masks": gt_masks,
459
  }
 
461
 
462
  def evaluate(self, *args, **kwargs):
463
  raise NotImplementedError("This method is not implemented.")
464
+
models/segment_anything/modeling/mask_decoder.py CHANGED
@@ -140,17 +140,7 @@ class MaskDecoder(nn.Module):
140
  b, c, h, w = src.shape
141
 
142
  # Run the transformer
143
- referent_token_index = (
144
- 1 + self.num_mask_tokens if sparse_prompt_embeddings.shape[1] > 0 else None
145
- )
146
- hs, src = self.transformer(
147
- src,
148
- pos_src,
149
- tokens,
150
- mask_token_start=1,
151
- num_mask_tokens=self.num_mask_tokens,
152
- referent_token_index=referent_token_index,
153
- )
154
  iou_token_out = hs[:, 0, :]
155
  mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
156
 
@@ -198,17 +188,7 @@ class MaskDecoder(nn.Module):
198
  _, c, h, w = src.shape
199
 
200
  # Run the transformer
201
- referent_token_index = (
202
- 1 + self.num_mask_tokens if sparse_prompt_embeddings.shape[1] > 0 else None
203
- )
204
- hs, src = self.transformer(
205
- src,
206
- pos_src,
207
- tokens,
208
- mask_token_start=1,
209
- num_mask_tokens=self.num_mask_tokens,
210
- referent_token_index=referent_token_index,
211
- )
212
  mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
213
 
214
  # Upscale mask embeddings and predict masks using the mask tokens
 
140
  b, c, h, w = src.shape
141
 
142
  # Run the transformer
143
+ hs, src = self.transformer(src, pos_src, tokens)
 
 
 
 
 
 
 
 
 
 
144
  iou_token_out = hs[:, 0, :]
145
  mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
146
 
 
188
  _, c, h, w = src.shape
189
 
190
  # Run the transformer
191
+ hs, src = self.transformer(src, pos_src, tokens)
 
 
 
 
 
 
 
 
 
 
192
  mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
193
 
194
  # Upscale mask embeddings and predict masks using the mask tokens
models/segment_anything/modeling/transformer.py CHANGED
@@ -9,7 +9,6 @@ from typing import Tuple, Type
9
 
10
  import torch
11
  from torch import Tensor, nn
12
- from torch.nn import functional as F
13
 
14
  from .common import MLPBlock
15
 
@@ -65,9 +64,6 @@ class TwoWayTransformer(nn.Module):
65
  image_embedding: Tensor,
66
  image_pe: Tensor,
67
  point_embedding: Tensor,
68
- mask_token_start: int = None,
69
- num_mask_tokens: int = None,
70
- referent_token_index: int = None,
71
  ) -> Tuple[Tensor, Tensor]:
72
  """
73
  Args:
@@ -98,9 +94,6 @@ class TwoWayTransformer(nn.Module):
98
  keys=keys,
99
  query_pe=point_embedding,
100
  key_pe=image_pe,
101
- mask_token_start=mask_token_start,
102
- num_mask_tokens=num_mask_tokens,
103
- referent_token_index=referent_token_index,
104
  )
105
 
106
  # Apply the final attention layer from the points to the image
@@ -152,19 +145,11 @@ class TwoWayAttentionBlock(nn.Module):
152
  self.cross_attn_image_to_token = Attention(
153
  embedding_dim, num_heads, downsample_rate=attention_downsample_rate
154
  )
155
- self.referent_gate = ReferentGate(embedding_dim)
156
 
157
  self.skip_first_layer_pe = skip_first_layer_pe
158
 
159
  def forward(
160
- self,
161
- queries: Tensor,
162
- keys: Tensor,
163
- query_pe: Tensor,
164
- key_pe: Tensor,
165
- mask_token_start: int = None,
166
- num_mask_tokens: int = None,
167
- referent_token_index: int = None,
168
  ) -> Tuple[Tensor, Tensor]:
169
  # Self attention block
170
  if self.skip_first_layer_pe:
@@ -175,17 +160,6 @@ class TwoWayAttentionBlock(nn.Module):
175
  queries = queries + attn_out
176
  queries = self.norm1(queries)
177
 
178
- if (
179
- mask_token_start is not None
180
- and num_mask_tokens is not None
181
- and referent_token_index is not None
182
- ):
183
- mask_slice = slice(mask_token_start, mask_token_start + num_mask_tokens)
184
- mask_tokens = queries[:, mask_slice, :]
185
- referent_token = queries[:, referent_token_index : referent_token_index + 1, :]
186
- queries = queries.clone()
187
- queries[:, mask_slice, :] = self.referent_gate(mask_tokens, referent_token)
188
-
189
  # Cross attention block, tokens attending to image embedding
190
  q = queries + query_pe
191
  k = keys + key_pe
@@ -208,26 +182,6 @@ class TwoWayAttentionBlock(nn.Module):
208
  return queries, keys
209
 
210
 
211
- class ReferentGate(nn.Module):
212
- def __init__(self, embedding_dim: int) -> None:
213
- super().__init__()
214
- self.gate = nn.Linear(embedding_dim * 2 + 1, embedding_dim)
215
- self.proj = nn.Linear(embedding_dim, embedding_dim)
216
- nn.init.zeros_(self.gate.weight)
217
- nn.init.zeros_(self.gate.bias)
218
- nn.init.zeros_(self.proj.weight)
219
- nn.init.zeros_(self.proj.bias)
220
- self.last_alpha = None
221
-
222
- def forward(self, mask_tokens: Tensor, referent_token: Tensor) -> Tensor:
223
- referent = referent_token.expand_as(mask_tokens)
224
- cosine = F.cosine_similarity(mask_tokens, referent, dim=-1).unsqueeze(-1)
225
- gate_input = torch.cat([mask_tokens, referent, cosine], dim=-1)
226
- alpha = torch.sigmoid(self.gate(gate_input))
227
- self.last_alpha = alpha.detach()
228
- return mask_tokens + alpha * self.proj(referent)
229
-
230
-
231
  class Attention(nn.Module):
232
  """
233
  An attention layer that allows for downscaling the size of the embedding
 
9
 
10
  import torch
11
  from torch import Tensor, nn
 
12
 
13
  from .common import MLPBlock
14
 
 
64
  image_embedding: Tensor,
65
  image_pe: Tensor,
66
  point_embedding: Tensor,
 
 
 
67
  ) -> Tuple[Tensor, Tensor]:
68
  """
69
  Args:
 
94
  keys=keys,
95
  query_pe=point_embedding,
96
  key_pe=image_pe,
 
 
 
97
  )
98
 
99
  # Apply the final attention layer from the points to the image
 
145
  self.cross_attn_image_to_token = Attention(
146
  embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147
  )
 
148
 
149
  self.skip_first_layer_pe = skip_first_layer_pe
150
 
151
  def forward(
152
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
 
 
 
 
 
 
 
153
  ) -> Tuple[Tensor, Tensor]:
154
  # Self attention block
155
  if self.skip_first_layer_pe:
 
160
  queries = queries + attn_out
161
  queries = self.norm1(queries)
162
 
 
 
 
 
 
 
 
 
 
 
 
163
  # Cross attention block, tokens attending to image embedding
164
  q = queries + query_pe
165
  k = keys + key_pe
 
182
  return queries, keys
183
 
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  class Attention(nn.Module):
186
  """
187
  An attention layer that allows for downscaling the size of the embedding
save_audio_feats.py CHANGED
@@ -80,3 +80,4 @@ for vid in vids:
80
  # print(f"{vid}: {audio_embed.shape}")
81
  torch.save(audio_embed, f'{save_dir}/{vid}.pt')
82
  print(f'{vid} embedding saved {audio_embed.shape}')
 
 
80
  # print(f"{vid}: {audio_embed.shape}")
81
  torch.save(audio_embed, f'{save_dir}/{vid}.pt')
82
  print(f'{vid} embedding saved {audio_embed.shape}')
83
+
train.py CHANGED
@@ -1,7 +1,7 @@
1
  import transformers
2
  from datasets import REFAVS
3
  from configs import args
4
- from torch.utils.data import DataLoader, Subset
5
  from functools import partial
6
  from models.llava import conversation as conversation_lib
7
  # from models.avs_model import VISAForCausalLM
@@ -21,9 +21,6 @@ import numpy as np
21
  import re
22
  import time
23
  import os
24
- import sys
25
- import json
26
- from collections import defaultdict
27
 
28
 
29
  import warnings
@@ -216,61 +213,10 @@ def collate_fn(batch, tokenizer=None):
216
  }
217
 
218
 
219
- def maybe_limit_dataset(dataset, max_rows, name):
220
- if max_rows is None or max_rows <= 0:
221
- return dataset
222
- limited_n = min(max_rows, len(dataset))
223
- print(f"max_eval_rows enabled: using first {limited_n} samples from {name}")
224
- return Subset(dataset, list(range(limited_n)))
225
-
226
-
227
- def load_subset_manifest(path):
228
- if not path:
229
- return {}
230
- with open(path, "r", encoding="utf-8") as f:
231
- manifest = json.load(f)
232
- if not isinstance(manifest, dict):
233
- raise ValueError(f"subset_manifest must be a JSON object, got {type(manifest).__name__}")
234
- if "subsets" in manifest:
235
- manifest = manifest["subsets"]
236
- return manifest
237
-
238
-
239
- def maybe_apply_manifest_subset(dataset, manifest, split_name, name):
240
- if split_name not in manifest:
241
- return dataset
242
- indices = manifest[split_name]
243
- if not isinstance(indices, list) or not all(isinstance(i, int) for i in indices):
244
- raise ValueError(f"subset_manifest[{split_name!r}] must be a list of integers")
245
- if not indices:
246
- raise ValueError(f"subset_manifest[{split_name!r}] is empty")
247
- max_index = len(dataset) - 1
248
- bad_indices = [i for i in indices if i < 0 or i > max_index]
249
- if bad_indices:
250
- raise ValueError(
251
- f"subset_manifest[{split_name!r}] contains out-of-range indices; "
252
- f"dataset size={len(dataset)}, examples={bad_indices[:5]}"
253
- )
254
- print(f"subset_manifest enabled: using {len(indices)} fixed samples from {name} ({split_name})")
255
- return Subset(dataset, indices)
256
-
257
-
258
- def checkpoint_requires_lora(saved_model_path):
259
- if not saved_model_path or not os.path.exists(saved_model_path):
260
- return False
261
- state = torch.load(saved_model_path, map_location="cpu")
262
- return any("lora_" in key for key in state.keys())
263
-
264
-
265
  import torch.multiprocessing as mp
266
  if __name__ == "__main__":
267
- try:
268
- mp.set_start_method("spawn")
269
- except RuntimeError:
270
- pass
271
  set_seed(42)
272
- if args.bridge_only and not args.use_residual_prompt_bridge:
273
- raise ValueError("--bridge_only requires --use_residual_prompt_bridge")
274
  tokenizer = transformers.AutoTokenizer.from_pretrained(
275
  args.mllm,
276
  cache_dir=None,
@@ -283,34 +229,17 @@ if __name__ == "__main__":
283
  num_added_tokens = tokenizer.add_tokens("[SEG]")
284
  seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0] # 32000
285
  print("seg_token_idx: ", seg_token_idx)
286
- subset_manifest = load_subset_manifest(args.subset_manifest)
287
 
288
  train_dataset = REFAVS('train', args, tokenizer, input_type='refer')
289
  val_dataset_s_refer = REFAVS('test_s', args, tokenizer, input_type='refer')
290
  val_dataset_u_refer = REFAVS('test_u', args, tokenizer, input_type='refer')
291
  val_dataset_n_refer = REFAVS('test_n', args, tokenizer, input_type='refer')
292
 
293
- train_dataset = maybe_apply_manifest_subset(train_dataset, subset_manifest, "train", "train")
294
- val_dataset_s_refer = maybe_apply_manifest_subset(val_dataset_s_refer, subset_manifest, "test_s", "test_s")
295
- val_dataset_u_refer = maybe_apply_manifest_subset(val_dataset_u_refer, subset_manifest, "test_u", "test_u")
296
- val_dataset_n_refer = maybe_apply_manifest_subset(val_dataset_n_refer, subset_manifest, "test_n", "test_n")
297
-
298
- if args.overfit_samples > 0:
299
- overfit_n = min(args.overfit_samples, len(train_dataset))
300
- train_dataset = Subset(train_dataset, list(range(overfit_n)))
301
- print(f"overfit_samples enabled: using first {overfit_n} train samples")
302
-
303
- train_eval_dataset = maybe_limit_dataset(train_dataset, args.max_eval_rows, "train_eval")
304
- val_dataset_s_refer = maybe_limit_dataset(val_dataset_s_refer, args.max_eval_rows, "test_s")
305
- val_dataset_u_refer = maybe_limit_dataset(val_dataset_u_refer, args.max_eval_rows, "test_u")
306
- val_dataset_n_refer = maybe_limit_dataset(val_dataset_n_refer, args.max_eval_rows, "test_n")
307
-
308
 
309
  g = torch.Generator()
310
  g.manual_seed(42)
311
 
312
  train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, worker_init_fn=seed_worker,collate_fn=partial(collate_fn, tokenizer=tokenizer), generator=g)
313
- train_eval_dataloader = DataLoader(train_eval_dataset, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer))
314
 
315
  val_dataloader_s_refer = DataLoader(val_dataset_s_refer, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer))
316
  val_dataloader_u_refer = DataLoader(val_dataset_u_refer, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer))
@@ -320,25 +249,15 @@ if __name__ == "__main__":
320
  model_args = {
321
  "train_mask_decoder": True,
322
  "out_dim": 256, # 256
323
- "ce_loss_weight": args.ce_loss_weight,
324
- "dice_loss_weight": args.dice_loss_weight,
325
- "bce_loss_weight": args.bce_loss_weight,
326
  "seg_token_idx": seg_token_idx,
327
  "vision_pretrained": args.vision_pretrained, # sam_vit_h_xxx.pth
328
  "vision_tower": args.vision_tower,
329
  "use_im_start_end": False,
330
  "compress": args.compress,
331
  "start": args.start,
332
- "use_residual_prompt_bridge": args.use_residual_prompt_bridge,
333
- "bridge_pm_weight": args.bridge_pm_weight,
334
- "bridge_rg_weight": args.bridge_rg_weight,
335
- "bridge_norm_weight": args.bridge_norm_weight,
336
- "bridge_mode": args.bridge_mode,
337
- "bridge_condition": args.bridge_condition,
338
- "bridge_directional_alpha": args.bridge_directional_alpha,
339
- "bridge_gate_bias_init": args.bridge_gate_bias_init,
340
- "bridge_residual_init_std": args.bridge_residual_init_std,
341
- "bridge_target_frame": args.bridge_target_frame,
342
  }
343
 
344
  model = Simtoken_ForCausalLM.from_pretrained(args.mllm, torch_dtype=torch.float32, low_cpu_mem_usage=True, **model_args)
@@ -374,17 +293,7 @@ if __name__ == "__main__":
374
  for p in model.get_model().mm_projector.parameters():
375
  p.requires_grad = False
376
 
377
- use_lora_checkpoint = (
378
- (args.init_from_saved_model or args.gate_only)
379
- and checkpoint_requires_lora(args.saved_model)
380
- )
381
- if args.bridge_only and use_lora_checkpoint:
382
- print(
383
- "bridge_only notice: saved_model contains LoRA weights, "
384
- "so LoRA modules will be instantiated for checkpoint compatibility and then frozen."
385
- )
386
-
387
- lora_r = 8 if (not args.bridge_only or use_lora_checkpoint) else 0
388
  target_modules = "q_proj,v_proj"
389
  if lora_r > 0:
390
 
@@ -440,11 +349,6 @@ if __name__ == "__main__":
440
  model = model.to("cuda")
441
  model.resize_token_embeddings(len(tokenizer))
442
 
443
- if args.init_from_saved_model or args.gate_only:
444
- state = torch.load(args.saved_model, map_location="cpu")
445
- missing, unexpected = model.load_state_dict(state, strict=False)
446
- print(f"initialized training from saved model: {args.saved_model}")
447
- print(f"missing keys: {len(missing)} | unexpected keys: {len(unexpected)}")
448
 
449
  for name, param in model.audio_feature_layer.named_parameters():
450
  param.requires_grad = True
@@ -452,274 +356,25 @@ if __name__ == "__main__":
452
  # for name, param in model.token_compressor.named_parameters():
453
  # param.requires_grad = True
454
 
 
455
  for n, p in model.named_parameters():
456
  if any(
457
- [
458
- x in n
459
- for x in ["lm_head", "embed_tokens", "mask_decoder", "text_hidden_fcs"]
460
- ]
461
  ):
462
  p.requires_grad = True
463
 
464
- if args.bridge_only:
465
- for p in model.parameters():
466
- p.requires_grad = False
467
- trainable_names = []
468
- for n, p in model.named_parameters():
469
- if "prompt_bridge" in n:
470
- p.requires_grad = True
471
- trainable_names.append(n)
472
- trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
473
- total = sum(p.numel() for p in model.parameters())
474
- print(f"bridge_only enabled: trainable params {trainable} / {total}")
475
- for name in trainable_names:
476
- print(f" bridge trainable: {name}")
477
- elif args.gate_only:
478
- for p in model.parameters():
479
- p.requires_grad = False
480
- for n, p in model.named_parameters():
481
- if "referent_gate" in n:
482
- p.requires_grad = True
483
- trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
484
- total = sum(p.numel() for p in model.parameters())
485
- print(f"gate_only enabled: trainable params {trainable} / {total}")
486
 
487
  print("will save train model")
488
 
489
- def _total_norm(values):
490
- if not values:
491
- return 0.0
492
- return float(sum(v * v for v in values) ** 0.5)
493
-
494
- def collect_referent_gate_stats(model):
495
- gate_modules = [(n, m) for n, m in model.named_modules() if n.endswith("referent_gate")]
496
- proj_norms = []
497
- gate_norms = []
498
- proj_grad_norms = []
499
- gate_grad_norms = []
500
- alpha_tensors = []
501
-
502
- for _, module in gate_modules:
503
- proj_norms.append(module.proj.weight.detach().float().norm().item())
504
- gate_norms.append(module.gate.weight.detach().float().norm().item())
505
- if module.proj.weight.grad is not None:
506
- proj_grad_norms.append(module.proj.weight.grad.detach().float().norm().item())
507
- if module.gate.weight.grad is not None:
508
- gate_grad_norms.append(module.gate.weight.grad.detach().float().norm().item())
509
- if module.last_alpha is not None:
510
- alpha_tensors.append(module.last_alpha.detach().float().reshape(-1))
511
-
512
- stats = {
513
- "modules": len(gate_modules),
514
- "proj_norm": _total_norm(proj_norms),
515
- "gate_norm": _total_norm(gate_norms),
516
- "proj_grad_norm": _total_norm(proj_grad_norms),
517
- "gate_grad_norm": _total_norm(gate_grad_norms),
518
- }
519
-
520
- if alpha_tensors:
521
- alpha = torch.cat(alpha_tensors)
522
- stats.update(
523
- {
524
- "alpha_mean": alpha.mean().item(),
525
- "alpha_std": alpha.std(unbiased=False).item(),
526
- "alpha_min": alpha.min().item(),
527
- "alpha_max": alpha.max().item(),
528
- }
529
- )
530
- else:
531
- stats.update(
532
- {
533
- "alpha_mean": float("nan"),
534
- "alpha_std": float("nan"),
535
- "alpha_min": float("nan"),
536
- "alpha_max": float("nan"),
537
- }
538
- )
539
-
540
- return stats
541
-
542
- def print_referent_gate_optimizer_sanity(model, optimizer):
543
- optimizer_param_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}
544
- gate_params = [(n, p) for n, p in model.named_parameters() if "referent_gate" in n]
545
- trainable_gate = [(n, p) for n, p in gate_params if p.requires_grad]
546
- optimizer_gate = [(n, p) for n, p in gate_params if id(p) in optimizer_param_ids]
547
- optimizer_trainable_gate = [
548
- (n, p) for n, p in gate_params if p.requires_grad and id(p) in optimizer_param_ids
549
- ]
550
- print(
551
- "referent_gate sanity: "
552
- f"params={sum(p.numel() for _, p in gate_params)} | "
553
- f"trainable={sum(p.numel() for _, p in trainable_gate)} | "
554
- f"in_optimizer={sum(p.numel() for _, p in optimizer_gate)} | "
555
- f"trainable_in_optimizer={sum(p.numel() for _, p in optimizer_trainable_gate)}"
556
- )
557
-
558
- stats = collect_referent_gate_stats(model)
559
- print(
560
- "referent_gate init stats: "
561
- f"modules={stats['modules']} | "
562
- f"proj_norm={stats['proj_norm']:.6f} | "
563
- f"gate_norm={stats['gate_norm']:.6f}"
564
- )
565
-
566
- def log_referent_gate_stats(global_step, loss_value):
567
- stats = collect_referent_gate_stats(model)
568
- message = (
569
- f"gate_stats step={global_step} "
570
- f"loss={loss_value:.6f} "
571
- f"proj_norm={stats['proj_norm']:.6f} "
572
- f"gate_norm={stats['gate_norm']:.6f} "
573
- f"proj_grad_norm={stats['proj_grad_norm']:.6f} "
574
- f"gate_grad_norm={stats['gate_grad_norm']:.6f} "
575
- f"alpha_mean={stats['alpha_mean']:.4f} "
576
- f"alpha_std={stats['alpha_std']:.4f} "
577
- f"alpha_min={stats['alpha_min']:.4f} "
578
- f"alpha_max={stats['alpha_max']:.4f}"
579
- )
580
- print(message)
581
- with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
582
- f.write(message + "\n")
583
-
584
- def find_prompt_bridge_module(model):
585
- for _, module in model.named_modules():
586
- if module.__class__.__name__ == "ResidualPromptBridge":
587
- return module
588
- return None
589
-
590
- def collect_prompt_bridge_grad_norms(model):
591
- module = find_prompt_bridge_module(model)
592
- if module is None:
593
- return {}
594
-
595
- def grad_norm(param):
596
- if param.grad is None:
597
- return None
598
- return float(param.grad.detach().float().norm().item())
599
-
600
- return {
601
- "W_a": grad_norm(module.attn_proj.weight),
602
- "W_r": grad_norm(module.residual_proj.weight),
603
- "W_g": grad_norm(module.gate.weight),
604
- "b_g": grad_norm(module.gate.bias),
605
- }
606
-
607
- def print_prompt_bridge_grad_norms(label, norms):
608
- parts = []
609
- for key in ["W_a", "W_r", "W_g", "b_g"]:
610
- value = norms.get(key)
611
- if value is None:
612
- parts.append(f"{key}=None")
613
- else:
614
- parts.append(f"{key}={value:.6e}")
615
- print(f"{label}: " + " | ".join(parts))
616
-
617
- def run_bridge_sanity_checks(model, dataloader):
618
- if not args.use_residual_prompt_bridge:
619
- raise ValueError("--bridge_sanity_only requires --use_residual_prompt_bridge")
620
-
621
- model.train()
622
- batch = next(iter(dataloader))
623
- input_dict = dict_to_cuda(batch)
624
-
625
- output_dict = model.forward(
626
- images=input_dict["images"],
627
- images_clip=input_dict["images_clip"],
628
- audio_features=input_dict["audio_feats"],
629
- image_features=input_dict["image_feats"],
630
- input_ids=input_dict["input_ids"],
631
- labels=input_dict["labels"],
632
- attention_masks=input_dict["attention_masks"],
633
- masks_list=input_dict["masks"],
634
- resize_list=input_dict["resizes"],
635
- orgsize_list=input_dict["orgsizes"],
636
- conversation_list=input_dict["convs"],
637
- refs_num=input_dict["refs_num"],
638
- fids=input_dict["fids"],
639
- vids=input_dict["vids"],
640
- contrast=0.0,
641
- ref_ids=input_dict["ref_ids"],
642
- epoch=0,
643
- inference=False,
644
- target_frame=args.bridge_target_frame,
645
- )
646
-
647
- model.zero_grad(set_to_none=True)
648
- output_dict["mask_loss"].backward(retain_graph=True)
649
- print_prompt_bridge_grad_norms(
650
- "bridge grad check | L_mask only",
651
- collect_prompt_bridge_grad_norms(model),
652
- )
653
-
654
- model.zero_grad(set_to_none=True)
655
- output_dict["bridge_teacher_loss_raw"].backward()
656
- print_prompt_bridge_grad_norms(
657
- "bridge grad check | L_teach only",
658
- collect_prompt_bridge_grad_norms(model),
659
- )
660
-
661
- metrics = output_dict["bridge_metrics"]
662
- print(
663
- "bridge identity check: "
664
- f"delta_norm_mean={metrics['delta_norm_mean']:.6f} | "
665
- f"cos(p_hat,q)={metrics['cos_p_hat_q_mean']:.6f} | "
666
- f"q_norm_mean={metrics['q_norm_mean']:.6f} | "
667
- f"p_hat_norm_mean={metrics['p_hat_norm_mean']:.6f} | "
668
- f"gate_mean={metrics['gate_mean']:.6f} | "
669
- f"gate_std={metrics['gate_std']:.6f}"
670
- )
671
-
672
- teacher_pm_norms = []
673
- teacher_rg_norms = []
674
- teacher_cosines = []
675
- scanned_batches = max(1, args.bridge_sanity_batches)
676
-
677
- model.eval()
678
- with torch.no_grad():
679
- for batch_idx, batch in enumerate(dataloader):
680
- if batch_idx >= scanned_batches:
681
- break
682
- input_dict = dict_to_cuda(batch)
683
- result = model.forward(
684
- images=input_dict["images"],
685
- images_clip=input_dict["images_clip"],
686
- audio_features=input_dict["audio_feats"],
687
- image_features=input_dict["image_feats"],
688
- input_ids=input_dict["input_ids"],
689
- labels=input_dict["labels"],
690
- attention_masks=input_dict["attention_masks"],
691
- masks_list=input_dict["masks"],
692
- resize_list=input_dict["resizes"],
693
- orgsize_list=input_dict["orgsizes"],
694
- conversation_list=input_dict["convs"],
695
- refs_num=input_dict["refs_num"],
696
- fids=input_dict["fids"],
697
- vids=input_dict["vids"],
698
- contrast=0.0,
699
- ref_ids=input_dict["ref_ids"],
700
- inference=True,
701
- target_frame=args.bridge_target_frame,
702
- )
703
- bridge_metrics = result["bridge_metrics"]
704
- teacher_pm_norms.append(bridge_metrics["p_mask_norm_mean"])
705
- teacher_rg_norms.append(bridge_metrics["z_gt_norm_mean"])
706
- teacher_cosines.append(bridge_metrics["cos_p_mask_z_gt_mean"])
707
-
708
- print(
709
- "bridge teacher sanity: "
710
- f"mean||p_mask||={float(np.mean(teacher_pm_norms)):.6f} | "
711
- f"mean||z_gt||={float(np.mean(teacher_rg_norms)):.6f} | "
712
- f"mean cos(p_mask,z_gt)={float(np.mean(teacher_cosines)):.6f}"
713
- )
714
-
715
  def valuate(model, dataloader, args, name):
716
  model.eval()
717
 
718
  total_iou = 0
719
  total_fscore = 0
720
  count = 0
721
- bridge_accumulators = defaultdict(float)
722
- bridge_count = 0
723
 
724
  for batch in tqdm(dataloader, desc=f"Evaluating on {name}"):
725
  input_dict = dict_to_cuda(batch)
@@ -740,8 +395,7 @@ if __name__ == "__main__":
740
  vids=input_dict["vids"],
741
  contrast=args.ct_weight,
742
  ref_ids=input_dict["ref_ids"],
743
- inference=True,
744
- target_frame=args.bridge_target_frame)
745
  pred_masks = output_dict["pred_masks"] # list[B]:[num_seg, T, H, W]
746
  gt_masks = output_dict["gt_masks"] # list[B]:[num_seg, T, H, W]
747
  for i in range(len(pred_masks)):
@@ -754,46 +408,23 @@ if __name__ == "__main__":
754
  total_fscore += fscore * num_seg * T
755
  count += num_seg * T
756
 
757
- if args.use_residual_prompt_bridge and "bridge_metrics" in output_dict:
758
- for key, value in output_dict["bridge_metrics"].items():
759
- bridge_accumulators[key] += float(value)
760
- bridge_count += 1
761
-
762
  print(f"\n valuate on {name}: miou: {total_iou/count} fscore: {total_fscore/count}")
763
 
764
  with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
765
  f.write(f"valuate on {name}: miou {total_iou/count} true fscore {total_fscore/count} \n")
766
- if bridge_count > 0:
767
- bridge_summary = " | ".join(
768
- f"{key}={bridge_accumulators[key] / bridge_count:.6f}"
769
- for key in sorted(bridge_accumulators.keys())
770
- )
771
- print(f" bridge on {name}: {bridge_summary}")
772
- f.write(f"bridge on {name}: {bridge_summary}\n")
773
-
774
 
775
- if args.bridge_sanity_only:
776
- run_bridge_sanity_checks(model, train_eval_dataloader)
777
- sys.exit(0)
778
 
779
  # ---------------train------------------------------------------
780
 
781
  model.train()
782
  epochs = args.epochs
783
  print("init lr:", args.lr)
784
- trainable_params = [p for p in model.parameters() if p.requires_grad]
785
- optimizer = AdamW(trainable_params, lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)
786
- print_referent_gate_optimizer_sanity(model, optimizer)
787
-
788
- gradient_accumulation_steps = max(1, int(16 // args.batch_size))
789
- step_per_epoch = max(1, len(train_dataloader) // gradient_accumulation_steps)
790
- full_total_steps = epochs * step_per_epoch
791
- total_steps = min(args.max_steps, full_total_steps) if args.max_steps > 0 else full_total_steps
792
  warmup_steps = int(total_steps * 0.1)
793
- print(
794
- f"training schedule: grad_accum={gradient_accumulation_steps} | "
795
- f"step_per_epoch={step_per_epoch} | total_optimizer_steps={total_steps}"
796
- )
797
 
798
  scheduler = get_cosine_schedule_with_warmup(
799
  optimizer,
@@ -802,9 +433,6 @@ if __name__ == "__main__":
802
  )
803
 
804
 
805
- optimizer_step_count = 0
806
- stop_training = False
807
-
808
  for epoch in range(epochs):
809
 
810
  model.train()
@@ -813,9 +441,6 @@ if __name__ == "__main__":
813
 
814
  loop = tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}/{epochs}")
815
  for step, batch in enumerate(loop):
816
- if args.max_steps > 0 and optimizer_step_count >= args.max_steps:
817
- stop_training = True
818
- break
819
  input_dict = dict_to_cuda(batch)
820
  output_dict = model.forward(images=input_dict["images"],
821
  images_clip=input_dict["images_clip"],
@@ -834,7 +459,6 @@ if __name__ == "__main__":
834
  contrast=args.ct_weight,
835
  ref_ids=input_dict["ref_ids"],
836
  epoch=epoch,
837
- gate_only=args.gate_only,
838
  inference=False)
839
 
840
  loss = output_dict["loss"]
@@ -844,57 +468,23 @@ if __name__ == "__main__":
844
 
845
 
846
  if (step + 1) % gradient_accumulation_steps == 0:
847
- optimizer_step_count += 1
848
- if (
849
- args.log_gate_stats_every > 0
850
- and optimizer_step_count % args.log_gate_stats_every == 0
851
- ):
852
- log_referent_gate_stats(
853
- optimizer_step_count,
854
- loss.item() * gradient_accumulation_steps,
855
- )
856
  optimizer.step()
857
  scheduler.step()
858
  optimizer.zero_grad()
859
 
860
  current_lr = scheduler.get_lr()[0]
861
- postfix = {
862
- "lr": current_lr,
863
- "loss": running_loss / ((step + 1) / gradient_accumulation_steps),
864
- }
865
- if args.use_residual_prompt_bridge:
866
- postfix["bridge"] = float(output_dict["bridge_teacher_loss"].item())
867
- postfix["pm"] = float(output_dict["bridge_pm_loss"].item())
868
- postfix["rg"] = float(output_dict["bridge_rg_loss"].item())
869
- loop.set_postfix(**postfix)
870
-
871
- if args.max_steps > 0 and optimizer_step_count >= args.max_steps:
872
- stop_training = True
873
- break
874
 
875
- denom = max(1, optimizer_step_count)
876
- print(f" Epoch {epoch + 1}, Loss:{running_loss / denom :.4f}, Learning Rate:{scheduler.get_last_lr()[0]:.6f}")
877
 
878
 
879
  with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
880
- f.write(f"Epoch {epoch}: running_loss {running_loss / denom} Learning Rate:{scheduler.get_last_lr()[0]:.6f}\n")
881
-
882
- if stop_training:
883
- print(f"stopped early at optimizer step {optimizer_step_count}")
884
- break
885
 
886
 
887
  torch.save(model.state_dict(), os.path.join(args.checkpoint_root, f"{args.name}.pth"))
888
  print(f"trained model saved as {args.name}.pth")
889
 
890
- if args.skip_eval_after_train:
891
- print("skip_eval_after_train enabled: exiting after checkpoint save")
892
- sys.exit(0)
893
-
894
- if args.eval_train_only:
895
- valuate(model, train_eval_dataloader, args, 'train_overfit')
896
- sys.exit(0)
897
-
898
  # ---------------test on seen & unseen ------------------------------------------
899
  model.eval()
900
 
@@ -941,4 +531,4 @@ if __name__ == "__main__":
941
  print(f"\n valuate on test_n_refer, metric: {total_metric/count}")
942
 
943
  with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
944
- f.write(f"\n valuate on test_n_refer: metric {total_metric/count} \n")
 
1
  import transformers
2
  from datasets import REFAVS
3
  from configs import args
4
+ from torch.utils.data import DataLoader
5
  from functools import partial
6
  from models.llava import conversation as conversation_lib
7
  # from models.avs_model import VISAForCausalLM
 
21
  import re
22
  import time
23
  import os
 
 
 
24
 
25
 
26
  import warnings
 
213
  }
214
 
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  import torch.multiprocessing as mp
217
  if __name__ == "__main__":
218
+ mp.set_start_method("spawn")
 
 
 
219
  set_seed(42)
 
 
220
  tokenizer = transformers.AutoTokenizer.from_pretrained(
221
  args.mllm,
222
  cache_dir=None,
 
229
  num_added_tokens = tokenizer.add_tokens("[SEG]")
230
  seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0] # 32000
231
  print("seg_token_idx: ", seg_token_idx)
 
232
 
233
  train_dataset = REFAVS('train', args, tokenizer, input_type='refer')
234
  val_dataset_s_refer = REFAVS('test_s', args, tokenizer, input_type='refer')
235
  val_dataset_u_refer = REFAVS('test_u', args, tokenizer, input_type='refer')
236
  val_dataset_n_refer = REFAVS('test_n', args, tokenizer, input_type='refer')
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  g = torch.Generator()
240
  g.manual_seed(42)
241
 
242
  train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, worker_init_fn=seed_worker,collate_fn=partial(collate_fn, tokenizer=tokenizer), generator=g)
 
243
 
244
  val_dataloader_s_refer = DataLoader(val_dataset_s_refer, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer))
245
  val_dataloader_u_refer = DataLoader(val_dataset_u_refer, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer))
 
249
  model_args = {
250
  "train_mask_decoder": True,
251
  "out_dim": 256, # 256
252
+ "ce_loss_weight": 1.0,
253
+ "dice_loss_weight": 0.5,
254
+ "bce_loss_weight": 2.0,
255
  "seg_token_idx": seg_token_idx,
256
  "vision_pretrained": args.vision_pretrained, # sam_vit_h_xxx.pth
257
  "vision_tower": args.vision_tower,
258
  "use_im_start_end": False,
259
  "compress": args.compress,
260
  "start": args.start,
 
 
 
 
 
 
 
 
 
 
261
  }
262
 
263
  model = Simtoken_ForCausalLM.from_pretrained(args.mllm, torch_dtype=torch.float32, low_cpu_mem_usage=True, **model_args)
 
293
  for p in model.get_model().mm_projector.parameters():
294
  p.requires_grad = False
295
 
296
+ lora_r = 8
 
 
 
 
 
 
 
 
 
 
297
  target_modules = "q_proj,v_proj"
298
  if lora_r > 0:
299
 
 
349
  model = model.to("cuda")
350
  model.resize_token_embeddings(len(tokenizer))
351
 
 
 
 
 
 
352
 
353
  for name, param in model.audio_feature_layer.named_parameters():
354
  param.requires_grad = True
 
356
  # for name, param in model.token_compressor.named_parameters():
357
  # param.requires_grad = True
358
 
359
+
360
  for n, p in model.named_parameters():
361
  if any(
362
+ [
363
+ x in n
364
+ for x in ["lm_head", "embed_tokens", "mask_decoder", "text_hidden_fcs"]
365
+ ]
366
  ):
367
  p.requires_grad = True
368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
 
370
  print("will save train model")
371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  def valuate(model, dataloader, args, name):
373
  model.eval()
374
 
375
  total_iou = 0
376
  total_fscore = 0
377
  count = 0
 
 
378
 
379
  for batch in tqdm(dataloader, desc=f"Evaluating on {name}"):
380
  input_dict = dict_to_cuda(batch)
 
395
  vids=input_dict["vids"],
396
  contrast=args.ct_weight,
397
  ref_ids=input_dict["ref_ids"],
398
+ inference=True)
 
399
  pred_masks = output_dict["pred_masks"] # list[B]:[num_seg, T, H, W]
400
  gt_masks = output_dict["gt_masks"] # list[B]:[num_seg, T, H, W]
401
  for i in range(len(pred_masks)):
 
408
  total_fscore += fscore * num_seg * T
409
  count += num_seg * T
410
 
 
 
 
 
 
411
  print(f"\n valuate on {name}: miou: {total_iou/count} fscore: {total_fscore/count}")
412
 
413
  with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
414
  f.write(f"valuate on {name}: miou {total_iou/count} true fscore {total_fscore/count} \n")
 
 
 
 
 
 
 
 
415
 
 
 
 
416
 
417
  # ---------------train------------------------------------------
418
 
419
  model.train()
420
  epochs = args.epochs
421
  print("init lr:", args.lr)
422
+ optimizer = AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)
423
+
424
+ gradient_accumulation_steps = int(16 // args.batch_size)
425
+ step_per_epoch = len(train_dataloader) // gradient_accumulation_steps
426
+ total_steps = epochs * step_per_epoch
 
 
 
427
  warmup_steps = int(total_steps * 0.1)
 
 
 
 
428
 
429
  scheduler = get_cosine_schedule_with_warmup(
430
  optimizer,
 
433
  )
434
 
435
 
 
 
 
436
  for epoch in range(epochs):
437
 
438
  model.train()
 
441
 
442
  loop = tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}/{epochs}")
443
  for step, batch in enumerate(loop):
 
 
 
444
  input_dict = dict_to_cuda(batch)
445
  output_dict = model.forward(images=input_dict["images"],
446
  images_clip=input_dict["images_clip"],
 
459
  contrast=args.ct_weight,
460
  ref_ids=input_dict["ref_ids"],
461
  epoch=epoch,
 
462
  inference=False)
463
 
464
  loss = output_dict["loss"]
 
468
 
469
 
470
  if (step + 1) % gradient_accumulation_steps == 0:
 
 
 
 
 
 
 
 
 
471
  optimizer.step()
472
  scheduler.step()
473
  optimizer.zero_grad()
474
 
475
  current_lr = scheduler.get_lr()[0]
476
+ loop.set_postfix(lr=current_lr, loss=running_loss / ((step + 1) / gradient_accumulation_steps))
 
 
 
 
 
 
 
 
 
 
 
 
477
 
478
+ print(f" Epoch {epoch + 1}, Loss:{running_loss / ((step + 1) / gradient_accumulation_steps) :.4f}, Learning Rate:{scheduler.get_last_lr()[0]:.6f}")
 
479
 
480
 
481
  with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
482
+ f.write(f"Epoch {epoch}: running_loss {running_loss / len(train_dataloader) * gradient_accumulation_steps} Learning Rate:{scheduler.get_last_lr()[0]:.6f}\n")
 
 
 
 
483
 
484
 
485
  torch.save(model.state_dict(), os.path.join(args.checkpoint_root, f"{args.name}.pth"))
486
  print(f"trained model saved as {args.name}.pth")
487
 
 
 
 
 
 
 
 
 
488
  # ---------------test on seen & unseen ------------------------------------------
489
  model.eval()
490
 
 
531
  print(f"\n valuate on test_n_refer, metric: {total_metric/count}")
532
 
533
  with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
534
+ f.write(f"\n valuate on test_n_refer: metric {total_metric/count} \n")