yfan07 committed on
Commit
65fb4ac
·
verified ·
1 Parent(s): 08ff7f7

Add files using upload-large-folder tool

Browse files
analyze_d2_csv.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import csv
3
+ import math
4
+ from collections import defaultdict
5
+
6
+ import numpy as np
7
+
8
+
9
def parse_args():
    """Parse the command-line options of the D2 CSV analysis script.

    Returns:
        argparse.Namespace with ``csv`` (required path), ``beta``,
        ``failure_iou``, ``bottom_frac`` (floats) and ``pr_points`` (int).
    """
    parser = argparse.ArgumentParser(description="Analyze D2 frame-level CSV.")
    parser.add_argument("--csv", required=True, help="Path to d2_llm_space.py or d2_basic.py CSV output.")
    # The remaining options only differ by flag name, type and default.
    numeric_options = (
        ("--beta", float, 1.0),
        ("--failure_iou", float, 0.5),
        ("--bottom_frac", float, 0.2),
        ("--pr_points", int, 10),
    )
    for flag, value_type, default in numeric_options:
        parser.add_argument(flag, type=value_type, default=default)
    return parser.parse_args()
17
+
18
+
19
def read_rows(path, beta):
    """Load the frame-level rows of a D2 CSV that belong to one beta value.

    Args:
        path: CSV file to read; it must contain the columns accessed below.
        beta: Beta value to keep. Rows whose ``beta`` column differs from it
            by more than 1e-8 are skipped.

    Returns:
        A list of dicts with typed values. The anchor type is read from the
        ``h_type`` column when present, otherwise from ``q_type``, and is
        stored under ``anchor_type``.

    Raises:
        RuntimeError: If no row matches ``beta``. The message lists the beta
            values that were found so a mismatched ``--beta`` is easy to fix.
    """
    rows = []
    seen_betas = set()
    with open(path, newline="") as f:
        for row in csv.DictReader(f):
            row_beta = float(row["beta"])
            seen_betas.add(row_beta)
            if abs(row_beta - beta) > 1e-8:
                continue
            q_col = "h_type" if "h_type" in row else "q_type"
            rows.append(
                {
                    "sample_idx": int(row["sample_idx"]),
                    "frame": int(row["frame"]),
                    "anchor_type": row[q_col],
                    "s_pred": float(row["s_pred"]),
                    "s_gt": float(row["s_gt"]),
                    "frame_iou": float(row["frame_iou"]),
                    "iou_pred": float(row["iou_pred"]),
                    "pred_area": float(row["pred_area"]),
                    "gt_area": float(row["gt_area"]),
                }
            )
    if not rows:
        available = ", ".join(f"{b:g}" for b in sorted(seen_betas)) or "none"
        raise RuntimeError(
            f"No rows found for beta={beta} in {path} (available beta values: {available})"
        )
    return rows
44
+
45
+
46
def corr(x, y):
    """Return the Pearson correlation of ``x`` and ``y`` as a Python float.

    NaN is returned when fewer than two values are given or when either
    input is (numerically) constant, i.e. its standard deviation is < 1e-12.
    """
    a = np.asarray(x, dtype=np.float64)
    b = np.asarray(y, dtype=np.float64)
    # Short-circuit order matters: the length test guards np.std on empties.
    degenerate = len(a) < 2 or np.std(a) < 1e-12 or np.std(b) < 1e-12
    if degenerate:
        return float("nan")
    return float(np.corrcoef(a, b)[0, 1])
52
+
53
+
54
def residualize(y, controls):
    """Return the residuals of an ordinary-least-squares fit of ``y``.

    The design matrix is an intercept column followed by one column per
    entry of ``controls`` (each a sequence with the same length as ``y``).
    """
    target = np.asarray(y, dtype=np.float64)
    intercept = np.ones(len(target), dtype=np.float64)
    design = np.stack(
        [intercept, *(np.asarray(c, dtype=np.float64) for c in controls)],
        axis=1,
    )
    weights = np.linalg.lstsq(design, target, rcond=None)[0]
    return target - design @ weights
62
+
63
+
64
def r2_score(y, y_pred):
    """Return the coefficient of determination of ``y_pred`` against ``y``.

    NaN is returned when ``y`` has (numerically) zero variance, because the
    total sum of squares would make the ratio undefined.
    """
    truth = np.asarray(y, dtype=np.float64)
    estimate = np.asarray(y_pred, dtype=np.float64)
    residual_ss = np.sum((truth - estimate) ** 2)
    total_ss = np.sum((truth - truth.mean()) ** 2)
    return float("nan") if total_ss < 1e-12 else float(1.0 - residual_ss / total_ss)
72
+
73
+
74
def linear_r2(y, features):
    """Return the in-sample R^2 of a least-squares fit of ``y`` on ``features``.

    An intercept column is always added in front of the feature columns; the
    score itself is computed by ``r2_score``.
    """
    target = np.asarray(y, dtype=np.float64)
    columns = [np.ones(len(target), dtype=np.float64)]
    columns.extend(np.asarray(f, dtype=np.float64) for f in features)
    design = np.stack(columns, axis=1)
    weights = np.linalg.lstsq(design, target, rcond=None)[0]
    return r2_score(target, design @ weights)
82
+
83
+
84
def real_rows(rows):
    """Return the rows whose ``anchor_type`` is ``"real"``, in input order."""
    return list(filter(lambda r: r["anchor_type"] == "real", rows))
86
+
87
+
88
def bottom_failure_enrichment(rows, failure_iou, bottom_frac):
    """Measure how strongly low ``s_pred`` real frames concentrate failures.

    Only rows with ``anchor_type == "real"`` are used. A frame counts as a
    failure when ``frame_iou < failure_iou``; the "bottom" set is the
    ``bottom_frac`` fraction of frames with the lowest ``s_pred``.

    The statistics are computed by ``bottom_failure_enrichment_for_score`` so
    the two entry points cannot drift apart; this wrapper only adds the
    ``total_failures`` count.

    Returns:
        Dict with keys ``n``, ``k``, ``baseline_failure_rate``,
        ``bottom_failure_rate``, ``bottom_failure_recall``, ``enrichment``
        and ``total_failures``.
    """
    rr = real_rows(rows)
    stats = bottom_failure_enrichment_for_score(rr, "s_pred", failure_iou, bottom_frac)
    stats["total_failures"] = sum(r["frame_iou"] < failure_iou for r in rr)
    return stats
109
+
110
+
111
def pr_curve(rows, failure_iou, points):
    """Tabulate failure precision/recall when selecting the lowest-``s_pred`` frames.

    Real rows are ranked by ascending ``s_pred``. For ``points`` selection
    fractions evenly spaced in [0.05, 1.0], the lowest-scoring fraction is
    selected (at least one frame) and compared with the failure label
    ``frame_iou < failure_iou``.

    Returns:
        List of ``(fraction, precision, recall)`` tuples.
    """
    ranked = sorted(real_rows(rows), key=lambda r: r["s_pred"])
    failed = [r["frame_iou"] < failure_iou for r in ranked]
    # Avoid a zero division when the data contains no failure at all.
    recall_denominator = max(sum(failed), 1)
    curve = []
    for frac in np.linspace(0.05, 1.0, points):
        k = max(1, int(round(len(ranked) * frac)))
        hits = sum(failed[:k])
        curve.append((frac, hits / k, hits / recall_denominator))
    return curve
123
+
124
+
125
def margin_rows(rows):
    """Attach a score margin over control anchors to every real frame.

    Rows are grouped by ``(sample_idx, frame)``. For each group holding a
    ``real`` row and at least one of the ``shuffled`` / ``wrong_ref``
    controls, a copy of the real row is emitted with
    ``s_margin = real s_pred - max(control s_pred)``. Groups without a real
    row or without any control are dropped.
    """
    by_frame = defaultdict(dict)
    for r in rows:
        by_frame[(r["sample_idx"], r["frame"])][r["anchor_type"]] = r

    out = []
    for group in by_frame.values():
        real = group.get("real")
        if real is None:
            continue
        control_scores = [
            group[name]["s_pred"] for name in ("shuffled", "wrong_ref") if name in group
        ]
        if control_scores:
            out.append({**real, "s_margin": real["s_pred"] - max(control_scores)})
    return out
143
+
144
+
145
def bottom_failure_enrichment_for_score(rows, score_key, failure_iou, bottom_frac):
    """Compute failure enrichment in the lowest-scoring fraction of ``rows``.

    Args:
        rows: Row dicts containing ``frame_iou`` and ``score_key``.
        score_key: Name of the score used for ranking (ascending).
        failure_iou: A row is a failure when ``frame_iou < failure_iou``.
        bottom_frac: Fraction of rows forming the bottom set (at least one).

    Returns:
        Dict with ``n``, ``k``, ``baseline_failure_rate``,
        ``bottom_failure_rate``, ``bottom_failure_recall`` and ``enrichment``
        (bottom rate divided by baseline rate).
    """
    n = len(rows)
    k = max(1, int(round(n * bottom_frac)))
    lowest = sorted(rows, key=lambda r: r[score_key])[:k]

    all_failed = [r["frame_iou"] < failure_iou for r in rows]
    lowest_failed = [r["frame_iou"] < failure_iou for r in lowest]
    baseline_rate = np.mean(all_failed)
    bottom_rate = np.mean(lowest_failed)

    return {
        "n": n,
        "k": k,
        "baseline_failure_rate": baseline_rate,
        "bottom_failure_rate": bottom_rate,
        # max(...) guards keep both ratios defined when there are no failures.
        "bottom_failure_recall": sum(lowest_failed) / max(sum(all_failed), 1),
        "enrichment": bottom_rate / max(baseline_rate, 1e-12),
    }
162
+
163
+
164
def main():
    """Run the full D2 CSV analysis and print a plain-text report.

    The report covers raw correlations of ``s_pred`` on real frames, partial
    correlations and R^2 gains after controlling for ``iou_pred`` and mask
    areas, bottom-k failure enrichment, a precision/recall table and, when
    control anchors are present, the same statistics for the score margin.
    """
    args = parse_args()
    rows = read_rows(args.csv, args.beta)
    rr = real_rows(rows)

    print(f"CSV: {args.csv}")
    print(f"beta: {args.beta}")
    print(f"real frames: {len(rr)}")
    print(f"failure definition: frame_iou < {args.failure_iou}")

    # Column views over the real frames, reused by every statistic below.
    s_pred_values = [r["s_pred"] for r in rr]
    frame_iou_values = [r["frame_iou"] for r in rr]
    iou_pred_values = [r["iou_pred"] for r in rr]
    pred_area_values = [r["pred_area"] for r in rr]
    gt_area_values = [r["gt_area"] for r in rr]

    print("\nReal s_pred Correlations")
    print(f"corr(s_pred, frame_iou): {corr(s_pred_values, frame_iou_values):+.4f}")
    print(f"corr(s_pred, iou_pred): {corr(s_pred_values, iou_pred_values):+.4f}")
    print(f"corr(s_pred, pred_area): {corr(s_pred_values, pred_area_values):+.4f}")

    def partial_corr(controls):
        # Correlate what is left of both variables after removing the controls.
        return corr(
            residualize(s_pred_values, controls),
            residualize(frame_iou_values, controls),
        )

    partial_iou_pred = partial_corr([iou_pred_values])
    partial_iou_area = partial_corr([iou_pred_values, pred_area_values])
    partial_iou_area_gt = partial_corr([iou_pred_values, pred_area_values, gt_area_values])
    r2_iou_pred = linear_r2(frame_iou_values, [iou_pred_values])
    r2_iou_pred_s = linear_r2(frame_iou_values, [iou_pred_values, s_pred_values])
    r2_iou_pred_area = linear_r2(frame_iou_values, [iou_pred_values, pred_area_values])
    r2_iou_pred_area_s = linear_r2(
        frame_iou_values, [iou_pred_values, pred_area_values, s_pred_values]
    )

    print("\nPartial Correlation / Residual Gain")
    print(f"partial corr(s_pred, frame_iou | iou_pred): {partial_iou_pred:+.4f}")
    print(f"partial corr(s_pred, frame_iou | iou_pred,pred_area): {partial_iou_area:+.4f}")
    print(f"partial corr(s_pred, frame_iou | iou_pred,pred_area,gt_area): {partial_iou_area_gt:+.4f}")
    print(f"R2 frame_iou ~ iou_pred: {r2_iou_pred:.4f}")
    print(f"R2 frame_iou ~ iou_pred + s_pred: {r2_iou_pred_s:.4f} (gain {r2_iou_pred_s - r2_iou_pred:+.4f})")
    print(f"R2 frame_iou ~ iou_pred + pred_area: {r2_iou_pred_area:.4f}")
    print(f"R2 frame_iou ~ iou_pred + pred_area + s_pred: {r2_iou_pred_area_s:.4f} (gain {r2_iou_pred_area_s - r2_iou_pred_area:+.4f})")

    stats = bottom_failure_enrichment(rows, args.failure_iou, args.bottom_frac)
    print("\nBottom-k Failure Enrichment")
    print(f"bottom_frac: {args.bottom_frac:.2f} ({stats['k']}/{stats['n']} frames)")
    print(f"total failures: {stats['total_failures']}")
    print(f"random/baseline failure rate: {stats['baseline_failure_rate']:.4f}")
    print(f"bottom-s_pred failure rate: {stats['bottom_failure_rate']:.4f}")
    print(f"bottom-s_pred failure recall: {stats['bottom_failure_recall']:.4f}")
    print(f"enrichment: {stats['enrichment']:.2f}x")

    print("\nPR Curve Summary")
    print("selected_frac | precision | recall")
    for frac, precision, recall in pr_curve(rows, args.failure_iou, args.pr_points):
        print(f"{frac:.2f} | {precision:.4f} | {recall:.4f}")

    mr = margin_rows(rows)
    if not mr:
        print("\nOffline Margin-D2 skipped: shuffled/wrong_ref controls not available.")
        return

    margins = [r["s_margin"] for r in mr]
    print("\nOffline Margin-D2")
    print(f"margin frames: {len(mr)}")
    print(f"corr(s_margin, frame_iou): {corr(margins, [r['frame_iou'] for r in mr]):+.4f}")
    print(f"corr(s_margin, pred_area): {corr(margins, [r['pred_area'] for r in mr]):+.4f}")
    mstats = bottom_failure_enrichment_for_score(mr, "s_margin", args.failure_iou, args.bottom_frac)
    print(f"bottom-s_margin failure rate: {mstats['bottom_failure_rate']:.4f}")
    print(f"bottom-s_margin failure recall: {mstats['bottom_failure_recall']:.4f}")
    print(f"margin enrichment: {mstats['enrichment']:.2f}x")


if __name__ == "__main__":
    main()
cache_q_features.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from functools import partial
4
+ from itertools import islice
5
+
6
+ import torch
7
+ import transformers
8
+ from torch.utils.data import DataLoader
9
+ from tqdm import tqdm
10
+
11
+ from configs import args
12
+ from datasets import REFAVS
13
+ from decoder_invariance_check import build_model, set_seed
14
+ from load_model import collate_fn, dict_to_cuda
15
+
16
+
17
+ def _jsonable_size(size):
18
+ if isinstance(size, torch.Tensor):
19
+ return [int(x) for x in size.detach().cpu().tolist()]
20
+ return [int(x) for x in size]
21
+
22
+
23
def main():
    """Run the model over one dataset split and cache its [SEG] embeddings.

    For every sample (batch size 1) the model is run in inference mode and
    ``output["seg_embeddings"][0]`` is saved, together with sample metadata,
    to ``<cache_root>/<cache_split>/<sample_idx:06d>.pt``. After the loop an
    ``index.jsonl`` file describing all cached samples is written.

    Raises:
        FileExistsError: If the index already exists and ``--overwrite_cache``
            was not given.
        RuntimeError: If no sample was cached.
    """
    set_seed(42)
    torch.set_grad_enabled(False)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.mllm,
        cache_dir=None,
        model_max_length=2048,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]

    dataset = REFAVS(args.cache_split, args, tokenizer, input_type="refer")
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
    )

    split_root = os.path.join(args.cache_root, args.cache_split)
    os.makedirs(split_root, exist_ok=True)
    index_path = os.path.join(split_root, "index.jsonl")
    # Refuse to clobber an existing cache before the expensive model build.
    if os.path.exists(index_path) and not args.overwrite_cache:
        raise FileExistsError(
            f"{index_path} already exists. Pass --overwrite_cache to rebuild it."
        )

    limit = args.max_eval_rows if args.max_eval_rows > 0 else len(dataset)
    num_samples = min(limit, len(dataset))
    print(f"cache split={args.cache_split} | samples={num_samples}")
    print(f"cache root: {split_root}")

    model = build_model(tokenizer, seg_token_idx)
    model.eval()

    rows = []
    progress = tqdm(islice(loader, limit), total=num_samples, desc=f"Caching {args.cache_split}")
    for sample_idx, batch in enumerate(progress):
        batch = dict_to_cuda(batch)
        # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            output = model.forward(
                images=batch["images"],
                images_clip=batch["images_clip"],
                audio_features=batch["audio_feats"],
                image_features=batch["image_feats"],
                input_ids=batch["input_ids"],
                labels=batch["labels"],
                attention_masks=batch["attention_masks"],
                masks_list=batch["masks"],
                resize_list=batch["resizes"],
                orgsize_list=batch["orgsizes"],
                conversation_list=batch["convs"],
                refs_num=batch["refs_num"],
                fids=batch["fids"],
                vids=batch["vids"],
                contrast=args.ct_weight,
                ref_ids=batch["ref_ids"],
                inference=True,
            )

        cache_name = f"{sample_idx:06d}.pt"
        cache_path = os.path.join(split_root, cache_name)
        item = {
            "sample_idx": sample_idx,
            "vid": batch["vids"][0],
            "refs": batch["refs"][0],
            "fids": [int(x) for x in batch["fids"][0]],
            "resize": _jsonable_size(batch["resizes"][0]),
            "orgsize": _jsonable_size(batch["orgsizes"][0]),
            # Stored on CPU in float32 so the cache loads without a GPU.
            "q": output["seg_embeddings"][0].detach().cpu().float(),
        }
        torch.save(item, cache_path)
        rows.append(
            {
                "sample_idx": sample_idx,
                "path": cache_name,
                "vid": item["vid"],
                "refs": item["refs"],
                "fids": item["fids"],
                "resize": item["resize"],
                "orgsize": item["orgsize"],
                "num_seg": int(item["q"].shape[0]),
            }
        )

    if not rows:
        raise RuntimeError("No samples were cached.")

    # The index is written last, so its presence implies a completed run.
    with open(index_path, "w") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")

    print(f"cached samples: {len(rows)}")
    print(f"saved index: {index_path}")


if __name__ == "__main__":
    main()
d2_basic.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import math
3
+ import os
4
+ from functools import partial
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import transformers
10
+ from torch.utils.data import DataLoader
11
+
12
+ from configs import args
13
+ from datasets import REFAVS
14
+ from decoder_invariance_check import build_model, set_seed
15
+ from load_model import collate_fn, dict_to_cuda
16
+
17
+
18
def make_loader(tokenizer):
    """Build a sequential, single-sample DataLoader over ``args.eval_split``.

    The loader is unshuffled and runs in the main process, so iterating it
    twice yields the samples in the same order.
    """
    collate = partial(collate_fn, tokenizer=tokenizer)
    eval_dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
    return DataLoader(
        eval_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=collate,
    )
27
+
28
+
29
def build_tokenizer():
    """Load the tokenizer named by ``args.mllm`` and register the ``[SEG]`` token.

    Returns:
        ``(tokenizer, seg_token_idx)`` where ``seg_token_idx`` is the first
        input id produced when tokenizing ``"[SEG]"`` without special tokens.
    """
    load_options = {
        "cache_dir": None,
        "model_max_length": 2048,
        "padding_side": "right",
        "use_fast": False,
    }
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.mllm, **load_options)
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    seg_ids = tokenizer("[SEG]", add_special_tokens=False).input_ids
    return tokenizer, seg_ids[0]
41
+
42
+
43
def get_q(model, batch):
    """Run the model in inference mode and return one segmentation embedding.

    Args:
        model: Model returned by ``build_model``.
        batch: Collated batch already moved to the GPU.

    Returns:
        ``output["seg_embeddings"][0][0]`` cast to float32, i.e. the first
        embedding of the first batch entry.
    """
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model.forward(
            images=batch["images"],
            images_clip=batch["images_clip"],
            audio_features=batch["audio_feats"],
            image_features=batch["image_feats"],
            input_ids=batch["input_ids"],
            labels=batch["labels"],
            attention_masks=batch["attention_masks"],
            masks_list=batch["masks"],
            resize_list=batch["resizes"],
            orgsize_list=batch["orgsizes"],
            conversation_list=batch["convs"],
            refs_num=batch["refs_num"],
            fids=batch["fids"],
            vids=batch["vids"],
            contrast=args.ct_weight,
            ref_ids=batch["ref_ids"],
            inference=True,
        )
    return output["seg_embeddings"][0][0].float()
65
+
66
+
67
def decode_low_res(model, batch, q):
    """Decode low-resolution masks for the first batch entry from a prompt ``q``.

    ``q`` is fed to the visual model's prompt encoder as the only (text)
    prompt, and the mask decoder is run on ``batch["image_feats"][0]``.

    Returns:
        ``(low_res_masks, iou_predictions)``, both float32; the IoU
        predictions have their last dimension squeezed away.
    """
    visual_model = model.get_model().visual_model
    # The prompt encoder receives q in the visual model's parameter dtype.
    encoder_dtype = next(visual_model.parameters()).dtype
    sparse, dense = visual_model.prompt_encoder(
        points=None,
        boxes=None,
        masks=None,
        text_embeds=q.view(1, 1, -1).to(encoder_dtype),
    )
    sparse = sparse.to(q.dtype)
    dense = dense.to(q.dtype)

    # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        low_res_masks, iou_predictions = visual_model.mask_decoder(
            image_embeddings=batch["image_feats"][0],
            image_pe=visual_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse,
            dense_prompt_embeddings=dense,
            multimask_output=False,
        )
    return low_res_masks.float(), iou_predictions.float().squeeze(-1)
87
+
88
+
89
def masks_to_64(mask_logits_or_binary):
    """Bilinearly resize masks to 64x64 and clamp the result to [0, 1].

    A 3-D input gets a singleton dimension inserted at position 1 before
    resizing, so the output is always 4-D.
    """
    masks = mask_logits_or_binary
    if masks.ndim == 3:
        masks = masks.unsqueeze(1)
    resized = F.interpolate(
        masks.float(),
        size=(64, 64),
        mode="bilinear",
        align_corners=False,
    )
    return resized.clamp(0.0, 1.0)
98
+
99
+
100
def d2_scores(image_embeddings, mask64, q, beta):
    """Score how well ``q`` matches features inside vs. outside a soft mask.

    For every frame the features are mask-weighted averaged over the last two
    dimensions, once with ``mask64`` and once with its complement. Both pooled
    vectors and ``q`` are L2-normalized and the result is
    ``cos(inside, q) - beta * cos(outside, q)``.

    Raises:
        ValueError: If ``mask64`` and ``image_embeddings`` disagree on their
            first dimension.
    """
    feats = image_embeddings.float()
    if mask64.shape[0] != feats.shape[0]:
        raise ValueError(f"Mask/frame mismatch: {mask64.shape} vs {feats.shape}")

    query = F.normalize(q.float().view(1, -1), dim=-1)
    inside = mask64.float()

    def pooled_similarity(weights):
        # Weighted mean over dims (2, 3); the clamp avoids dividing by an empty mask.
        pooled = (feats * weights).sum(dim=(2, 3)) / weights.sum(dim=(2, 3)).clamp_min(1e-6)
        return (F.normalize(pooled, dim=-1) @ query.T).squeeze(-1)

    return pooled_similarity(inside) - beta * pooled_similarity(1.0 - inside)
115
+
116
+
117
def frame_iou(pred_logits, gt_masks, threshold=0.4):
    """Compute the per-frame IoU between thresholded predictions and ground truth.

    Args:
        pred_logits: Mask logits; a 4-D tensor has dimension 1 squeezed away
            after thresholding.
        gt_masks: Ground-truth masks, reduced over dims (1, 2).
        threshold: Probability cut-off applied to ``sigmoid(pred_logits)``.
            Defaults to 0.4, the value previously hard-coded here.

    Returns:
        One IoU value per frame. For frames whose ground truth is empty, the
        score is the fraction of pixels that are background in both masks,
        so an empty prediction scores 1.0.
    """
    pred = (torch.sigmoid(pred_logits.float()) > threshold).float()
    gt = gt_masks.float()
    if pred.ndim == 4:
        pred = pred.squeeze(1)
    inter = (pred * gt).sum(dim=(1, 2))
    union = torch.maximum(pred, gt).sum(dim=(1, 2))
    num_pixels = pred.shape[-1] * pred.shape[-2]
    # Empty-target frames: reward agreement on background instead of 0/0.
    no_obj = gt.sum(dim=(1, 2)) == 0
    inter_no_obj = ((1.0 - pred) * (1.0 - gt)).sum(dim=(1, 2))
    inter = torch.where(no_obj, inter_no_obj, inter)
    union = torch.where(no_obj, torch.full_like(union, float(num_pixels)), union)
    return inter / union.clamp_min(1e-7)
130
+
131
+
132
def frame_fscore_proxy(pred_logits, gt_masks, threshold=0.4):
    """Compute a per-frame F-score (beta^2 = 0.3) from thresholded predictions.

    Args:
        pred_logits: Mask logits; a 4-D tensor has dimension 1 squeezed away
            after thresholding.
        gt_masks: Ground-truth masks, reduced over dims (1, 2).
        threshold: Probability cut-off applied to ``sigmoid(pred_logits)``.
            Defaults to 0.4, the value previously hard-coded here.

    Returns:
        One F-score per frame; frames with an empty ground truth score 0.
    """
    pred = (torch.sigmoid(pred_logits.float()) > threshold).float()
    gt = gt_masks.float()
    if pred.ndim == 4:
        pred = pred.squeeze(1)
    tp = (pred * gt).sum(dim=(1, 2))
    precision = tp / pred.sum(dim=(1, 2)).clamp_min(1e-7)
    recall = tp / gt.sum(dim=(1, 2)).clamp_min(1e-7)
    beta2 = 0.3
    fscore = (1 + beta2) * precision * recall / (beta2 * precision + recall).clamp_min(1e-7)
    no_obj = gt.sum(dim=(1, 2)) == 0
    return torch.where(no_obj, torch.zeros_like(fscore), fscore)
144
+
145
+
146
def parse_betas():
    """Read the beta values to evaluate from the ``D2_BETAS`` environment variable.

    The variable holds a comma-separated list of floats (default ``"0.5"``);
    blank entries are ignored.

    Raises:
        ValueError: If an entry is not a float, or if no value remains.
            Failing here avoids running the model only to crash later on an
            empty beta list.
    """
    raw = os.environ.get("D2_BETAS", "0.5")
    betas = [float(x.strip()) for x in raw.split(",") if x.strip()]
    if not betas:
        raise ValueError(f"D2_BETAS must contain at least one beta value, got {raw!r}")
    return betas
149
+
150
+
151
def collect_q_pool(model, tokenizer, limit):
    """Collect the segmentation embedding of the first ``limit`` samples.

    Returns:
        A list of dicts, one per sample in loader order, holding
        ``sample_idx``, ``vid``, the first ``ref`` and ``fid`` of the sample,
        and the embedding ``q`` moved to the CPU.

    Raises:
        RuntimeError: If no embedding could be collected.
    """
    from itertools import islice

    q_pool = []
    loader = make_loader(tokenizer)
    # islice stops after `limit` batches, so no extra sample is loaded just
    # to be discarded.
    for sample_idx, batch in enumerate(islice(loader, max(limit, 0))):
        batch = dict_to_cuda(batch)
        q = get_q(model, batch)
        q_pool.append(
            {
                "sample_idx": sample_idx,
                "vid": batch["vids"][0],
                "ref": batch["refs"][0][0],
                "fid": int(batch["fids"][0][0]),
                "q": q.cpu(),
            }
        )
        print(f"Collected q {sample_idx}: vid={q_pool[-1]['vid']} ref={q_pool[-1]['ref']}")
    if not q_pool:
        raise RuntimeError("No q vectors collected. Is the selected split empty?")
    return q_pool
172
+
173
+
174
def choose_shuffled_idx(sample_idx, q_pool):
    """Return the index of the next pool entry (wrapping around) as a control.

    ``None`` is returned when the pool has at most one entry, since no other
    sample is available.
    """
    pool_size = len(q_pool)
    return (sample_idx + 1) % pool_size if pool_size > 1 else None
178
+
179
+
180
def choose_wrong_ref_idx(sample_idx, q_pool):
    """Pick another sample from the same video that refers to something else.

    Candidates are the other pool entries sharing the current entry's ``vid``.
    The first candidate with a different ``fid`` wins; if there is none, the
    first candidate with a different ``ref`` is used. Returns the candidate's
    ``sample_idx`` or ``None`` when no candidate qualifies.
    """
    current = q_pool[sample_idx]
    same_video = [
        other
        for other in q_pool
        if other["sample_idx"] != sample_idx and other["vid"] == current["vid"]
    ]
    # A differing fid takes priority over a differing ref across the whole pool.
    for field in ("fid", "ref"):
        for other in same_video:
            if other[field] != current[field]:
                return other["sample_idx"]
    return None
193
+
194
+
195
def run_d2(model, tokenizer, q_pool, betas, limit):
    """Compute frame-level D2 scores for real and control prompts.

    For each of the first ``limit`` samples, masks are decoded from the
    sample's own embedding and scored (``d2_scores``) against four kinds of
    prompt: the ``real`` embedding, a seeded ``random`` vector, a ``shuffled``
    embedding from the next pool entry and, when available, a ``wrong_ref``
    embedding from the same video. Scores are computed for the predicted and
    the ground-truth mask at every beta.

    Args:
        model: Model returned by ``build_model``.
        tokenizer: Tokenizer used to build the evaluation loader.
        q_pool: Output of ``collect_q_pool``; must cover every visited sample.
        betas: Non-empty list of beta values; the first one is used for the
            per-sample progress line.
        limit: Maximum number of samples to process.

    Returns:
        A list of row dicts (one per sample, beta, prompt type and frame)
        whose key order defines the CSV column order.
    """
    rows = []
    loader = make_loader(tokenizer)
    q_lookup = {item["sample_idx"]: item for item in q_pool}
    # Dedicated generator so the random control is reproducible across runs.
    generator = torch.Generator(device="cuda")
    generator.manual_seed(1234)
    visual_model = model.get_model().visual_model

    for sample_idx, batch in enumerate(loader):
        if sample_idx >= limit:
            break
        batch = dict_to_cuda(batch)
        item = q_lookup[sample_idx]
        real_q = item["q"].cuda()

        low_res_masks, iou_predictions = decode_low_res(model, batch, real_q)
        pred_mask64 = masks_to_64(torch.sigmoid(low_res_masks))
        gt_masks = batch["masks"][0][0].float()
        gt_mask64 = masks_to_64(gt_masks)
        image_embeddings = batch["image_feats"][0].float()

        pred_logits_hr = visual_model.postprocess_masks(
            low_res_masks.to(batch["image_feats"][0].dtype),
            input_size=batch["resizes"][0],
            original_size=batch["orgsizes"][0],
        ).squeeze(1)

        frame_ious = frame_iou(pred_logits_hr, gt_masks)
        frame_fscores = frame_fscore_proxy(pred_logits_hr, gt_masks)
        pred_area = (torch.sigmoid(pred_logits_hr.float()) > 0.4).float().mean(dim=(1, 2))
        gt_area = gt_masks.float().mean(dim=(1, 2))

        shuffled_idx = choose_shuffled_idx(sample_idx, q_pool)
        wrong_ref_idx = choose_wrong_ref_idx(sample_idx, q_pool)
        q_controls = [
            ("real", real_q, sample_idx),
            ("random", torch.randn(real_q.shape, device=real_q.device, generator=generator), None),
        ]
        if shuffled_idx is not None:
            q_controls.append(("shuffled", q_lookup[shuffled_idx]["q"].cuda(), shuffled_idx))
        if wrong_ref_idx is not None:
            q_controls.append(("wrong_ref", q_lookup[wrong_ref_idx]["q"].cuda(), wrong_ref_idx))

        # Real-prompt scores at the first beta, gathered while the rows are
        # produced instead of rescanning every accumulated row per sample.
        s_pred_values = []

        for beta in betas:
            for q_type, q, q_source_idx in q_controls:
                pred_scores = d2_scores(image_embeddings, pred_mask64, q, beta)
                gt_scores = d2_scores(image_embeddings, gt_mask64, q, beta)
                track = q_type == "real" and beta == betas[0]
                for frame_idx in range(pred_scores.shape[0]):
                    s_pred = pred_scores[frame_idx].item()
                    # Key order is significant: it becomes the CSV header.
                    rows.append(
                        {
                            "sample_idx": sample_idx,
                            "vid": item["vid"],
                            "ref": item["ref"],
                            "fid": item["fid"],
                            "split": args.eval_split,
                            "frame_iou": frame_ious[frame_idx].item(),
                            "frame_fscore_proxy": frame_fscores[frame_idx].item(),
                            "iou_pred": iou_predictions[frame_idx].item(),
                            "pred_area": pred_area[frame_idx].item(),
                            "gt_area": gt_area[frame_idx].item(),
                            "frame": frame_idx,
                            "q_type": q_type,
                            "beta": beta,
                            "s_pred": s_pred,
                            "s_gt": gt_scores[frame_idx].item(),
                            "q_source_idx": q_source_idx if q_source_idx is not None else "",
                        }
                    )
                    if track:
                        s_pred_values.append(s_pred)

        print(
            f"D2 {sample_idx}: vid={item['vid']} ref={item['ref']} "
            f"mean_s_pred={np.mean(s_pred_values):.4f} min_s_pred={np.min(s_pred_values):.4f} "
            f"mean_iou={frame_ious.mean().item():.4f}"
        )

    return rows
288
+
289
+
290
def print_summary(rows):
    """Print per-beta mean scores for each prompt type and the real-score correlation.

    Nothing is printed when ``rows`` contains no ``real`` row. The correlation
    between real ``s_pred`` and ``frame_iou`` is reported as ``nan`` when it
    is undefined (fewer than two rows or a constant column).
    """
    real_betas = sorted({r["beta"] for r in rows if r["q_type"] == "real"})
    if not real_betas:
        return

    print("\nSummary")
    print(f"rows: {len(rows)}")
    for beta in real_betas:
        print(f"\nbeta={beta}")
        by_type = {}
        for r in rows:
            if r["beta"] == beta:
                by_type.setdefault(r["q_type"], []).append(r)

        for q_type in sorted(by_type):
            qr = by_type[q_type]
            print(
                f"{q_type:10s} "
                f"mean_s_pred={np.mean([r['s_pred'] for r in qr]):+.4f} "
                f"mean_s_gt={np.mean([r['s_gt'] for r in qr]):+.4f}"
            )

        # Every beta listed here came from a real row, so the key exists.
        real_beta = by_type["real"]
        s_pred = np.array([r["s_pred"] for r in real_beta])
        frame_iou_values = np.array([r["frame_iou"] for r in real_beta])
        defined = len(s_pred) > 1 and np.std(s_pred) > 1e-8 and np.std(frame_iou_values) > 1e-8
        if not defined:
            print("corr(real s_pred, frame_iou)=nan")
            continue
        corr = np.corrcoef(s_pred, frame_iou_values)[0, 1]
        print(f"corr(real s_pred, frame_iou)={corr:+.4f}")
315
+
316
+
317
def main():
    """Run the basic D2 experiment end to end and save the rows as CSV.

    The number of samples comes from ``args.max_eval_rows`` (30 when it is not
    positive). The output path can be overridden with the ``D2_BASIC_CSV``
    environment variable.
    """
    set_seed(42)
    torch.set_grad_enabled(False)

    betas = parse_betas()
    tokenizer, seg_token_idx = build_tokenizer()
    limit = 30 if args.max_eval_rows <= 0 else args.max_eval_rows
    print(f"Split: {args.eval_split} | samples: {limit} | betas: {betas}")

    # NOTE(review): unlike cache_q_features.py, model.eval() is not called
    # here — confirm that build_model already returns the model in eval mode.
    model = build_model(tokenizer, seg_token_idx)
    q_pool = collect_q_pool(model, tokenizer, limit)
    rows = run_d2(model, tokenizer, q_pool, betas, limit)
    print_summary(rows)

    default_csv = f"/workspace/SimToken/d2_basic_{args.eval_split}_{limit}.csv"
    csv_path = os.environ.get("D2_BASIC_CSV", default_csv)
    out_dir = os.path.dirname(os.path.abspath(csv_path))
    os.makedirs(out_dir, exist_ok=True)
    # The first row's key order defines the CSV header.
    columns = list(rows[0].keys())
    with open(csv_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=columns)
        writer.writeheader()
        writer.writerows(rows)
    print(f"\nSaved CSV: {csv_path}")


if __name__ == "__main__":
    main()
d2_llm_space.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import math
3
+ import os
4
+ from functools import partial
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import transformers
10
+ from torch.utils.data import DataLoader
11
+
12
+ from configs import args
13
+ from datasets import REFAVS
14
+ from decoder_invariance_check import build_model, set_seed
15
+ from d2_basic import frame_fscore_proxy, frame_iou
16
+ from load_model import collate_fn, dict_to_cuda
17
+
18
+
19
def build_tokenizer():
    """Load the tokenizer named by ``args.mllm`` and register the ``[SEG]`` token.

    Returns:
        ``(tokenizer, seg_token_idx)`` where ``seg_token_idx`` is the first
        input id produced when tokenizing ``"[SEG]"`` without special tokens.
    """
    load_options = {
        "cache_dir": None,
        "model_max_length": 2048,
        "padding_side": "right",
        "use_fast": False,
    }
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.mllm, **load_options)
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    seg_ids = tokenizer("[SEG]", add_special_tokens=False).input_ids
    return tokenizer, seg_ids[0]
31
+
32
+
33
def make_loader(tokenizer):
    """Build a sequential, single-sample DataLoader over ``args.eval_split``.

    The loader is unshuffled and runs in the main process, so iterating it
    twice yields the samples in the same order.
    """
    collate = partial(collate_fn, tokenizer=tokenizer)
    eval_dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
    return DataLoader(
        eval_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=collate,
    )
42
+
43
+
44
def forward_for_hidden_and_q(model, batch):
    """Run the model in inference mode and return the [SEG] hidden state and embedding.

    Args:
        model: Model returned by ``build_model``.
        batch: Collated batch already moved to the GPU.

    Returns:
        ``(h_seg, q)``: ``output["seg_hidden_states"][0][0]`` and
        ``output["seg_embeddings"][0][0]``, both cast to float32.
    """
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model.forward(
            images=batch["images"],
            images_clip=batch["images_clip"],
            audio_features=batch["audio_feats"],
            image_features=batch["image_feats"],
            input_ids=batch["input_ids"],
            labels=batch["labels"],
            attention_masks=batch["attention_masks"],
            masks_list=batch["masks"],
            resize_list=batch["resizes"],
            orgsize_list=batch["orgsizes"],
            conversation_list=batch["convs"],
            refs_num=batch["refs_num"],
            fids=batch["fids"],
            vids=batch["vids"],
            contrast=args.ct_weight,
            ref_ids=batch["ref_ids"],
            inference=True,
        )
    h_seg = output["seg_hidden_states"][0][0].float()
    q = output["seg_embeddings"][0][0].float()
    return h_seg, q
68
+
69
+
70
def decode_low_res(model, batch, q):
    """Decode low-resolution masks for the first batch entry from a prompt ``q``.

    ``q`` is fed to the visual model's prompt encoder as the only (text)
    prompt, and the mask decoder is run on ``batch["image_feats"][0]``.

    Returns:
        ``(low_res_masks, iou_predictions)``, both float32; the IoU
        predictions have their last dimension squeezed away.
    """
    visual_model = model.get_model().visual_model
    # The prompt encoder receives q in the visual model's parameter dtype.
    encoder_dtype = next(visual_model.parameters()).dtype
    sparse, dense = visual_model.prompt_encoder(
        points=None,
        boxes=None,
        masks=None,
        text_embeds=q.view(1, 1, -1).to(encoder_dtype),
    )
    sparse = sparse.to(q.dtype)
    dense = dense.to(q.dtype)

    # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        low_res_masks, iou_predictions = visual_model.mask_decoder(
            image_embeddings=batch["image_feats"][0],
            image_pe=visual_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse,
            dense_prompt_embeddings=dense,
            multimask_output=False,
        )
    return low_res_masks.float(), iou_predictions.float().squeeze(-1)
90
+
91
+
92
def clip_projected_tokens(model, batch):
    """Encode the CLIP images of a batch and project them with ``mm_projector``.

    The tensors in ``batch["images_clip"]`` are concatenated along dim 0,
    encoded with ``model.encode_images`` and passed through the model's
    ``mm_projector`` in the projector's weight dtype. The result is float32.
    """
    clip_batch = torch.cat(batch["images_clip"], dim=0)
    with torch.no_grad():
        encoded = model.encode_images(clip_batch)
    projector = model.get_model().mm_projector
    projected = projector(encoded.to(projector.weight.dtype))
    return projected.float()
100
+
101
+
102
def infer_square_grid(num_tokens):
    """Return the side length of a square grid holding ``num_tokens`` tokens.

    Raises:
        ValueError: If ``num_tokens`` is not a perfect square.
    """
    side = int(math.sqrt(num_tokens))
    if side * side == num_tokens:
        return side
    raise ValueError(f"Expected square patch-token grid, got {num_tokens} tokens")
107
+
108
+
109
def masks_to_token_grid(mask_logits_or_binary, num_tokens):
    """Resize masks to the square token grid and flatten them to one weight per token.

    A 3-D input gets a singleton dimension inserted at position 1. The masks
    are bilinearly resized to ``grid x grid`` (``grid**2 == num_tokens``),
    flattened over the spatial dimensions, transposed so tokens come before
    the channel dimension, and clamped to [0, 1].

    Raises:
        ValueError: If ``num_tokens`` is not a perfect square.
    """
    masks = mask_logits_or_binary
    if masks.ndim == 3:
        masks = masks.unsqueeze(1)
    side = infer_square_grid(num_tokens)
    resized = F.interpolate(
        masks.float(),
        size=(side, side),
        mode="bilinear",
        align_corners=False,
    )
    per_token = resized.flatten(2).transpose(1, 2)
    return per_token.clamp(0.0, 1.0)
119
+
120
+
121
def d2_scores_llm(llm_tokens, mask_tokens, h_seg, beta):
    """Score how well ``h_seg`` matches tokens inside vs. outside a soft mask.

    Tokens are mask-weighted averaged over dim 1, once with ``mask_tokens``
    and once with its complement. Both pooled vectors and ``h_seg`` are
    L2-normalized and the result is
    ``cos(inside, h_seg) - beta * cos(outside, h_seg)``.

    Raises:
        ValueError: If the first two dimensions of ``llm_tokens`` and
            ``mask_tokens`` differ.
    """
    if llm_tokens.shape[:2] != mask_tokens.shape[:2]:
        raise ValueError(f"Token/mask mismatch: {llm_tokens.shape} vs {mask_tokens.shape}")

    anchor = F.normalize(h_seg.float().view(1, -1), dim=-1)
    tokens = llm_tokens.float()
    inside = mask_tokens.float()

    def pooled_similarity(weights):
        # Weighted mean over the token dim; the clamp avoids dividing by an empty mask.
        pooled = (tokens * weights).sum(dim=1) / weights.sum(dim=1).clamp_min(1e-6)
        return (F.normalize(pooled, dim=-1) @ anchor.T).squeeze(-1)

    return pooled_similarity(inside) - beta * pooled_similarity(1.0 - inside)
135
+
136
+
137
def parse_betas():
    """Parse the comma-separated ``D2_BETAS`` environment variable.

    Blank entries are ignored. When the variable is unset, ``"0.5"`` is used.

    Returns:
        A non-empty list of floats in the order given.

    Raises:
        ValueError: If an entry is not a number, or if no entry is left after
            dropping blanks (an empty list would only fail later, far from
            the cause).
    """
    raw = os.environ.get("D2_BETAS", "0.5")
    pieces = (piece.strip() for piece in raw.split(","))
    betas = [float(piece) for piece in pieces if piece]
    if not betas:
        raise ValueError(f"D2_BETAS must list at least one number, got {raw!r}")
    return betas
140
+
141
+
142
def collect_hidden_pool(model, tokenizer, limit):
    """Run the model on the first ``limit`` samples and cache their anchors.

    Args:
        model: Model passed through to ``forward_for_hidden_and_q``.
        tokenizer: Tokenizer used to build the loader via ``make_loader``.
        limit: Maximum number of samples to collect.

    Returns:
        A list of dicts, one per sample, holding the sample index, the
        ``vid`` / ``ref`` / ``fid`` identifiers and CPU copies of the two
        tensors returned by ``forward_for_hidden_and_q`` (``"h"`` and ``"q"``).

    Raises:
        RuntimeError: If nothing was collected.
    """
    entries = []
    for position, batch in enumerate(make_loader(tokenizer)):
        if position >= limit:
            break
        batch = dict_to_cuda(batch)
        h_seg, q = forward_for_hidden_and_q(model, batch)
        entry = {
            "sample_idx": position,
            "vid": batch["vids"][0],
            "ref": batch["refs"][0][0],
            "fid": int(batch["fids"][0][0]),
            # Keep the cached tensors on the CPU.
            "h": h_seg.cpu(),
            "q": q.cpu(),
        }
        entries.append(entry)
        print(f"Collected h {position}: vid={entry['vid']} ref={entry['ref']}")

    if not entries:
        raise RuntimeError("No hidden states collected. Is the selected split empty?")
    return entries
164
+
165
+
166
def choose_shuffled_idx(sample_idx, pool):
    """Pick the next pool position (cyclically) as the "shuffled" control.

    Returns ``None`` when the pool has at most one entry, because there is no
    other sample to borrow from.
    """
    size = len(pool)
    return (sample_idx + 1) % size if size > 1 else None
170
+
171
+
172
def choose_wrong_ref_idx(sample_idx, pool):
    """Find another sample of the same video that refers to something else.

    The first pool entry with the same ``vid`` but a different ``fid`` wins.
    If none exists, the first entry with the same ``vid`` but a different
    ``ref`` is used instead.

    Args:
        sample_idx: Position of the current sample in ``pool``.
        pool: List of dicts with ``sample_idx``, ``vid``, ``fid`` and ``ref``.

    Returns:
        The ``sample_idx`` of the chosen entry, or ``None`` if there is none.
    """
    anchor = pool[sample_idx]
    same_video = [
        entry
        for entry in pool
        if entry["sample_idx"] != sample_idx and entry["vid"] == anchor["vid"]
    ]
    # Prefer a differing fid; only fall back to a differing ref afterwards.
    for key in ("fid", "ref"):
        for entry in same_video:
            if entry[key] != anchor[key]:
                return entry["sample_idx"]
    return None
185
+
186
+
187
def run_d2_llm(model, tokenizer, pool, betas, limit):
    """Score predicted and ground-truth masks against real and control anchors.

    For each of the first ``limit`` batches from ``make_loader(tokenizer)``,
    the cached ``q`` of the matching pool entry is decoded into low-res masks.
    Predicted and ground-truth masks are pooled onto the token grid, and
    ``d2_scores_llm`` is evaluated for every beta and every anchor type:
    ``"real"``, ``"random"`` and, when a donor sample exists, ``"shuffled"``
    and ``"wrong_ref"``.

    Args:
        model: Model used for decoding and CLIP token projection.
        tokenizer: Tokenizer used to build the loader.
        pool: Entries produced by ``collect_hidden_pool``.
        betas: Beta values to evaluate; the first one is used for the
            per-sample progress line.
        limit: Maximum number of samples to process.

    Returns:
        A list with one row dict per (sample, beta, anchor type, frame).
    """
    rows = []
    lookup = {item["sample_idx"]: item for item in pool}
    # Fixed-seed generator so the "random" control anchors are reproducible.
    generator = torch.Generator(device="cuda")
    generator.manual_seed(1234)
    loader = make_loader(tokenizer)
    log_beta = betas[0] if betas else None

    for sample_idx, batch in enumerate(loader):
        if sample_idx >= limit:
            break
        batch = dict_to_cuda(batch)
        item = lookup[sample_idx]
        h_real = item["h"].cuda()
        q_real = item["q"].cuda()

        low_res_masks, iou_predictions = decode_low_res(model, batch, q_real)
        llm_tokens = clip_projected_tokens(model, batch)
        pred_mask_tokens = masks_to_token_grid(torch.sigmoid(low_res_masks), llm_tokens.shape[1])
        gt_masks = batch["masks"][0][0].float()
        gt_mask_tokens = masks_to_token_grid(gt_masks, llm_tokens.shape[1])

        pred_logits_hr = model.get_model().visual_model.postprocess_masks(
            low_res_masks.to(batch["image_feats"][0].dtype),
            input_size=batch["resizes"][0],
            original_size=batch["orgsizes"][0],
        ).squeeze(1)
        frame_ious = frame_iou(pred_logits_hr, gt_masks)
        frame_fscores = frame_fscore_proxy(pred_logits_hr, gt_masks)
        pred_area = (torch.sigmoid(pred_logits_hr.float()) > 0.4).float().mean(dim=(1, 2))
        gt_area = gt_masks.float().mean(dim=(1, 2))

        shuffled_idx = choose_shuffled_idx(sample_idx, pool)
        wrong_ref_idx = choose_wrong_ref_idx(sample_idx, pool)
        controls = [
            ("real", h_real, sample_idx),
            ("random", torch.randn(h_real.shape, device=h_real.device, generator=generator), None),
        ]
        if shuffled_idx is not None:
            controls.append(("shuffled", lookup[shuffled_idx]["h"].cuda(), shuffled_idx))
        if wrong_ref_idx is not None:
            controls.append(("wrong_ref", lookup[wrong_ref_idx]["h"].cuda(), wrong_ref_idx))

        # Collected while the rows are built, so the progress line below does
        # not have to re-scan every row accumulated so far for each sample.
        s_pred_values = []
        for beta in betas:
            for h_type, h, h_source_idx in controls:
                pred_scores = d2_scores_llm(llm_tokens, pred_mask_tokens, h, beta)
                gt_scores = d2_scores_llm(llm_tokens, gt_mask_tokens, h, beta)
                track = h_type == "real" and beta == log_beta
                for frame_idx in range(pred_scores.shape[0]):
                    s_pred = pred_scores[frame_idx].item()
                    rows.append(
                        {
                            "sample_idx": sample_idx,
                            "vid": item["vid"],
                            "ref": item["ref"],
                            "fid": item["fid"],
                            "split": args.eval_split,
                            "frame": frame_idx,
                            "h_type": h_type,
                            "beta": beta,
                            "s_pred": s_pred,
                            "s_gt": gt_scores[frame_idx].item(),
                            "h_source_idx": h_source_idx if h_source_idx is not None else "",
                            "frame_iou": frame_ious[frame_idx].item(),
                            "frame_fscore_proxy": frame_fscores[frame_idx].item(),
                            "iou_pred": iou_predictions[frame_idx].item(),
                            "pred_area": pred_area[frame_idx].item(),
                            "gt_area": gt_area[frame_idx].item(),
                        }
                    )
                    if track:
                        s_pred_values.append(s_pred)

        print(
            f"D2-LLM {sample_idx}: vid={item['vid']} ref={item['ref']} "
            f"mean_s_pred={np.mean(s_pred_values):.4f} min_s_pred={np.min(s_pred_values):.4f} "
            f"mean_iou={frame_ious.mean().item():.4f}"
        )

    return rows
266
+
267
+
268
def print_summary(rows):
    """Print per-beta mean scores by anchor type plus a score/IoU correlation.

    For every beta (ascending) and every anchor type (sorted by name) the mean
    ``s_pred`` and ``s_gt`` are printed. Then the Pearson correlation between
    the "real" ``s_pred`` values and ``frame_iou`` is printed, or ``nan`` when
    there are fewer than two real rows or either series is constant.

    Args:
        rows: Row dicts with ``beta``, ``h_type``, ``s_pred``, ``s_gt`` and
            ``frame_iou`` keys.
    """
    print("\nSummary")
    print(f"rows: {len(rows)}")

    by_beta = {}
    for row in rows:
        by_beta.setdefault(row["beta"], []).append(row)

    for beta in sorted(by_beta):
        print(f"\nbeta={beta}")

        by_type = {}
        for row in by_beta[beta]:
            by_type.setdefault(row["h_type"], []).append(row)

        for h_type in sorted(by_type):
            group = by_type[h_type]
            mean_pred = np.mean([row["s_pred"] for row in group])
            mean_gt = np.mean([row["s_gt"] for row in group])
            print(
                f"{h_type:10s} "
                f"mean_s_pred={mean_pred:+.4f} "
                f"mean_s_gt={mean_gt:+.4f}"
            )

        real_group = by_type.get("real", [])
        s_pred = np.array([row["s_pred"] for row in real_group])
        ious = np.array([row["frame_iou"] for row in real_group])
        # The correlation is undefined for a single point or a constant series.
        usable = len(s_pred) > 1 and np.std(s_pred) > 1e-8 and np.std(ious) > 1e-8
        if not usable:
            print("corr(real s_pred, frame_iou)=nan")
            continue
        corr = np.corrcoef(s_pred, ious)[0, 1]
        print(f"corr(real s_pred, frame_iou)={corr:+.4f}")
289
+
290
+
291
def main():
    """Run the D2-LLM analysis end to end and write the rows to a CSV file.

    The sample limit comes from ``args.max_eval_rows`` (30 when it is not
    positive). The output path comes from the ``D2_LLM_CSV`` environment
    variable, with a default under ``/workspace/SimToken``.
    """
    set_seed(42)
    torch.set_grad_enabled(False)

    betas = parse_betas()
    tokenizer, seg_token_idx = build_tokenizer()
    limit = 30 if not args.max_eval_rows > 0 else args.max_eval_rows
    print(f"Split: {args.eval_split} | samples: {limit} | betas: {betas}")

    model = build_model(tokenizer, seg_token_idx)
    pool = collect_hidden_pool(model, tokenizer, limit)
    rows = run_d2_llm(model, tokenizer, pool, betas, limit)
    print_summary(rows)

    default_path = f"/workspace/SimToken/d2_llm_{args.eval_split}_{limit}.csv"
    csv_path = os.environ.get("D2_LLM_CSV", default_path)
    # Create the target directory first so open() cannot fail on a missing folder.
    target_dir = os.path.dirname(os.path.abspath(csv_path))
    os.makedirs(target_dir, exist_ok=True)
    with open(csv_path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)
    print(f"\nSaved CSV: {csv_path}")
311
+
312
+
313
# Run the analysis only when this file is executed as a script.
if __name__ == "__main__":
    main()
decoder_invariance_check.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import os
3
+ import random
4
+ from functools import partial
5
+
6
+ import numpy as np
7
+ import torch
8
+ import transformers
9
+ from peft import LoraConfig, get_peft_model
10
+ from torch.utils.data import DataLoader
11
+ from transformers import AutoConfig
12
+
13
+ from configs import args
14
+ from datasets import REFAVS
15
+ from load_model import collate_fn, dict_to_cuda
16
+ from models.avs_model import Simtoken_ForCausalLM
17
+
18
+
19
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Seed applied to ``random``, ``numpy.random``, ``torch`` and all
            CUDA devices.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
26
+
27
+
28
def find_lora_target_modules(model, target_modules=("q_proj", "v_proj")):
    """List the ``nn.Linear`` submodules that should receive LoRA adapters.

    A module qualifies when it is a ``torch.nn.Linear``, its qualified name
    contains one of ``target_modules``, and the name contains none of the
    excluded component names listed below.

    Args:
        model: Module whose ``named_modules()`` are scanned.
        target_modules: Name fragments that select a layer.

    Returns:
        Sorted list of matching qualified module names.
    """
    skip_markers = (
        "visual_model",
        "vision_tower",
        "mm_projector",
        "text_hidden_fcs",
        "audio_feature_layer",
    )

    def _wanted(name, module):
        if not isinstance(module, torch.nn.Linear):
            return False
        if any(marker in name for marker in skip_markers):
            return False
        return any(fragment in name for fragment in target_modules)

    return sorted({name for name, module in model.named_modules() if _wanted(name, module)})
45
+
46
+
47
def build_model(tokenizer, seg_token_idx):
    """Build the Simtoken model, wrap it with LoRA and load the checkpoint.

    The model is created from ``args.mllm`` in bfloat16. Its vision, cluster
    and LISA modules are initialised, LoRA adapters are attached to the layers
    chosen by ``find_lora_target_modules``, the embeddings are resized to the
    tokenizer, and ``args.saved_model`` is loaded non-strictly.

    Args:
        tokenizer: Tokenizer supplying the special token ids and vocab size.
        seg_token_idx: Token id passed to the model as ``seg_token_idx``.

    Returns:
        The model on CUDA, in eval mode.
    """
    model_kwargs = dict(
        train_mask_decoder=True,
        out_dim=256,
        ce_loss_weight=1.0,
        dice_loss_weight=0.5,
        bce_loss_weight=2.0,
        seg_token_idx=seg_token_idx,
        vision_pretrained=args.vision_pretrained,
        vision_tower=args.vision_tower,
        use_im_start_end=False,
        compress=args.compress,
        start=args.start,
    )
    model = Simtoken_ForCausalLM.from_pretrained(
        args.mllm,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        **model_kwargs,
    )

    for token_attr in ("eos_token_id", "bos_token_id", "pad_token_id"):
        setattr(model.config, token_attr, getattr(tokenizer, token_attr))

    base = model.get_model()
    base.initialize_vision_modules(base.config)
    model.get_model().get_vision_tower().to(dtype=torch.float32, device="cuda")

    # Cluster-module settings are layered on top of the pretrained config.
    cluster_cfg = AutoConfig.from_pretrained(args.mllm)
    cluster_overrides = {
        "use_cluster": True,
        "freeze": False,
        "mm_tune": True,
        "spatial_cluster_rate0": 64,
        "spatial_cluster_rate1": 32,
        "spatial_cluster_rate2": 16,
        "temporal_cluster_rate": 0.0625,
        "vision_tune": False,
    }
    for option, value in cluster_overrides.items():
        setattr(cluster_cfg, option, value)
    model.get_model().initialize_cluster_modules(cluster_cfg)
    model.get_model().initialize_lisa_modules(model.get_model().config)

    peft_cfg = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=find_lora_target_modules(model),
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_cfg)
    model = model.to("cuda")
    model.resize_token_embeddings(len(tokenizer))

    # Non-strict load: key mismatches are reported as counts, not errors.
    checkpoint = torch.load(args.saved_model, map_location="cpu")
    missing, unexpected = model.load_state_dict(checkpoint, strict=False)
    print(f"Loaded checkpoint: {args.saved_model}")
    print(f"Missing keys: {len(missing)} | Unexpected keys: {len(unexpected)}")

    model.eval()
    return model
108
+
109
+
110
def get_seg_embedding(model, batch):
    """Run an inference forward pass and return the first segmentation embedding.

    Args:
        model: Model whose ``forward`` is called under bfloat16 autocast.
        batch: Collated batch dict already moved to the GPU.

    Returns:
        ``output["seg_embeddings"][0][0:1]``: the first row of the first
        sample's embeddings, with the leading dimension kept.
    """
    forward_kwargs = {
        "images": batch["images"],
        "images_clip": batch["images_clip"],
        "audio_features": batch["audio_feats"],
        "image_features": batch["image_feats"],
        "input_ids": batch["input_ids"],
        "labels": batch["labels"],
        "attention_masks": batch["attention_masks"],
        "masks_list": batch["masks"],
        "resize_list": batch["resizes"],
        "orgsize_list": batch["orgsizes"],
        "conversation_list": batch["convs"],
        "refs_num": batch["refs_num"],
        "fids": batch["fids"],
        "vids": batch["vids"],
        "contrast": args.ct_weight,
        "ref_ids": batch["ref_ids"],
        "inference": True,
    }
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        output = model.forward(**forward_kwargs)
    first_sample = output["seg_embeddings"][0]
    return first_sample[0:1]
132
+
133
+
134
def check_one_sample(model, batch):
    """Compare batched and frame-by-frame mask decoding for one sample.

    The segmentation embedding is turned into prompt embeddings once. The mask
    decoder is then run on all frame embeddings together and again on each
    frame separately, and the absolute differences are recorded.

    Both decodes run inside the same bfloat16 autocast context, so the
    reported differences reflect batching only and not a precision mismatch.

    Args:
        model: Model providing ``get_model().visual_model``.
        batch: Collated batch dict already moved to the GPU.

    Returns:
        One dict per frame with ``vid``, ``ref``, ``frame``, ``max_abs_diff``,
        ``mean_abs_diff`` and ``iou_pred_diff``.
    """
    q = get_seg_embedding(model, batch)
    image_embeddings = batch["image_feats"][0]

    visual_model = model.get_model().visual_model
    sparse, dense = visual_model.prompt_encoder(
        points=None,
        boxes=None,
        masks=None,
        text_embeds=q.unsqueeze(1),
    )
    sparse = sparse.to(q.dtype)
    dense = dense.to(q.dtype)

    decoder = visual_model.mask_decoder
    image_pe = visual_model.prompt_encoder.get_dense_pe()

    def _decode(embeddings):
        # Same prompt and positional encoding for every call; only the image
        # embeddings differ between the batched and per-frame runs.
        return decoder(
            image_embeddings=embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse,
            dense_prompt_embeddings=dense,
            multimask_output=False,
        )

    rows = []
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        full_masks, full_iou = _decode(image_embeddings)

        for t in range(image_embeddings.shape[0]):
            single_masks, single_iou = _decode(image_embeddings[t : t + 1])

            diff = (full_masks[t : t + 1] - single_masks).float().abs()
            iou_diff = (full_iou[t : t + 1] - single_iou).float().abs()
            rows.append(
                {
                    "vid": batch["vids"][0],
                    "ref": batch["refs"][0][0],
                    "frame": t,
                    "max_abs_diff": diff.max().item(),
                    "mean_abs_diff": diff.mean().item(),
                    "iou_pred_diff": iou_diff.max().item(),
                }
            )
    return rows
183
+
184
+
185
def main():
    """Check decoder batch invariance on the first samples of the eval split.

    The number of samples comes from ``args.max_eval_rows`` (1 when it is not
    positive). Per-frame differences are printed for each sample, followed by
    a global summary. When ``DECODER_INVARIANCE_CSV`` is set, all rows are
    also written to that path.

    Raises:
        RuntimeError: If no frame was checked.
    """
    set_seed(42)
    torch.set_grad_enabled(False)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.mllm,
        cache_dir=None,
        model_max_length=2048,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]

    loader = DataLoader(
        REFAVS(args.eval_split, args, tokenizer, input_type="refer"),
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
    )

    limit = 1 if not args.max_eval_rows > 0 else args.max_eval_rows
    print(f"Split: {args.eval_split} | samples to check: {limit}")

    model = build_model(tokenizer, seg_token_idx)

    all_rows = []
    for sample_idx, batch in enumerate(loader):
        if sample_idx >= limit:
            break
        rows = check_one_sample(model, dict_to_cuda(batch))
        all_rows.extend(rows)

        print(f"\nSample {sample_idx}: vid={rows[0]['vid']} ref={rows[0]['ref']}")
        print("frame | max_abs_diff | mean_abs_diff | iou_pred_diff")
        for row in rows:
            print(
                f"{row['frame']:02d} | "
                f"{row['max_abs_diff']:.8e} | "
                f"{row['mean_abs_diff']:.8e} | "
                f"{row['iou_pred_diff']:.8e}"
            )

    if not all_rows:
        raise RuntimeError("No rows were checked. Is the selected split empty?")

    max_diff = max(row["max_abs_diff"] for row in all_rows)
    mean_diff = sum(row["mean_abs_diff"] for row in all_rows) / len(all_rows)
    max_iou_diff = max(row["iou_pred_diff"] for row in all_rows)

    print("\nSummary")
    print(f"checked frames: {len(all_rows)}")
    print(f"global max_abs_diff: {max_diff:.8e}")
    print(f"average mean_abs_diff: {mean_diff:.8e}")
    print(f"global max_iou_pred_diff: {max_iou_diff:.8e}")

    # The CSV export is optional and only happens when a path is configured.
    csv_path = os.environ.get("DECODER_INVARIANCE_CSV")
    if not csv_path:
        return
    os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
    with open(csv_path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=list(all_rows[0].keys()))
        writer.writeheader()
        writer.writerows(all_rows)
    print(f"Saved CSV: {csv_path}")
253
+
254
+
255
# Run the invariance check only when this file is executed as a script.
if __name__ == "__main__":
    main()
load_model.py CHANGED
@@ -1,12 +1,3 @@
1
# Compatibility shim: transformers==4.30.2 still passes ``use_auth_token=...``
# to hf_hub_download, a keyword that huggingface_hub>=0.20 no longer accepts.
# The patch is installed before transformers is imported so the reference
# bound inside transformers.utils.hub is the patched function.
import huggingface_hub as _hfhub

_hfhub_orig = _hfhub.hf_hub_download


def _hfhub_compat(*args, use_auth_token=None, token=None, **kwargs):
    """Forward to the original download, mapping ``use_auth_token`` to ``token``."""
    effective_token = token or use_auth_token
    return _hfhub_orig(*args, token=effective_token, **kwargs)


_hfhub.hf_hub_download = _hfhub_compat
9
-
10
  import transformers
11
 
12
  from torch.cuda.amp import autocast, GradScaler
@@ -217,7 +208,7 @@ def collate_fn(batch, tokenizer=None):
217
 
218
  import torch.multiprocessing as mp
219
  if __name__ == "__main__":
220
- mp.set_start_method("spawn", force=True)
221
  set_seed(42)
222
  tokenizer = transformers.AutoTokenizer.from_pretrained(
223
  args.mllm,
@@ -233,9 +224,17 @@ if __name__ == "__main__":
233
  print("seg_token_idx: ", seg_token_idx)
234
 
235
 
236
- _split = args.eval_split
237
- _dataset = REFAVS(_split, args, tokenizer, input_type='refer')
238
- _dataloader = DataLoader(_dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer))
 
 
 
 
 
 
 
 
239
 
240
 
241
 
@@ -341,8 +340,12 @@ if __name__ == "__main__":
341
  model = model.to("cuda")
342
  model.resize_token_embeddings(len(tokenizer))
343
 
344
- model.load_state_dict(torch.load(args.saved_model), strict=False)
345
- print("saved model loaded")
 
 
 
 
346
 
347
 
348
  save_root = args.visualization_root
@@ -401,16 +404,15 @@ if __name__ == "__main__":
401
  print("visualization finished")
402
 
403
 
404
- def valuate(model, dataloader, name, max_rows=-1):
405
  model.eval()
406
 
407
  total_iou = 0
408
  total_fscore = 0
409
  count = 0
410
 
411
- _total = min(max_rows, len(dataloader)) if max_rows > 0 else len(dataloader)
412
- for i, batch in enumerate(tqdm(dataloader, desc=f"Evaluating on {name}", total=_total)):
413
- if 0 < max_rows <= i:
414
  break
415
  input_dict = dict_to_cuda(batch)
416
 
@@ -445,39 +447,40 @@ if __name__ == "__main__":
445
  total_fscore += fscore * num_seg * T
446
  count += num_seg * T
447
 
 
 
 
448
  print(f"\n valuate on {name}: miou: {total_iou/count} fscore: {total_fscore/count}")
449
 
450
 
451
- def valuate_Null(model, dataloader, max_rows=-1):
452
  model.eval()
453
 
454
  total_metric = 0
455
  count = 0
456
 
457
- _total = min(max_rows, len(dataloader)) if max_rows > 0 else len(dataloader)
458
- for i, batch in enumerate(tqdm(dataloader, desc=f"Evaluating on Null", total=_total)):
459
- if 0 < max_rows <= i:
460
  break
461
  input_dict = dict_to_cuda(batch)
462
- with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
463
- with torch.no_grad():
464
- output_dict = model.forward(images=input_dict["images"],
465
- images_clip=input_dict["images_clip"],
466
- audio_features=input_dict["audio_feats"],
467
- image_features=input_dict["image_feats"],
468
- input_ids=input_dict["input_ids"],
469
- labels=input_dict["labels"],
470
- attention_masks=input_dict["attention_masks"],
471
- masks_list=input_dict["masks"],
472
- resize_list=input_dict["resizes"],
473
- orgsize_list=input_dict["orgsizes"],
474
- conversation_list=input_dict["convs"],
475
- refs_num=input_dict["refs_num"],
476
- fids=input_dict["fids"],
477
- vids=input_dict["vids"],
478
- contrast=args.ct_weight,
479
- ref_ids=input_dict["ref_ids"],
480
- inference=True)
481
  pred_masks = output_dict["pred_masks"] # list[B]:[num_seg, T, H, W]
482
  gt_masks = output_dict["gt_masks"] # list[B]:[num_seg, T, H, W]
483
  for i in range(len(pred_masks)):
@@ -488,637 +491,13 @@ if __name__ == "__main__":
488
  total_metric += null_metric * num_seg * T
489
  count += num_seg * T
490
 
491
- print(f"\n valuate on test_n_refer, metric: {total_metric / count}")
492
-
493
-
494
 
 
495
 
496
- from seg_ltpo import (
497
- LTPOConfig, ltpo_optimize, best_of_2_optimize, decode_full_video,
498
- get_sam_model, get_anchor_indices,
499
- QLTPOConfig, q_ltpo_autograd, check_grad_connectivity,
500
- reset_q_ltpo_stats, get_q_ltpo_stats,
501
- q_ltpo_frame_adaptive, decode_full_video_adaptive,
502
- _compute_avt_proxy_reward,
503
- )
504
-
505
def print_q_ltpo_stats(name: str) -> None:
    """Print aggregate diagnostics over the records from ``get_q_ltpo_stats()``.

    Nothing is printed when there are no records. Otherwise means, rates and
    10th/50th/90th percentiles of the per-record fields are reported under a
    header containing ``name``.

    Args:
        name: Label shown in the header line.
    """
    stats = get_q_ltpo_stats()
    if not stats:
        return
    n = len(stats)

    def _mean(key):
        # Field that every record is expected to carry.
        return sum(s[key] for s in stats) / n

    def _mean_opt(key):
        # Optional field; records without it count as 0.0.
        return sum(s.get(key, 0.0) for s in stats) / n

    def _pct(v, p):
        return v[max(0, int(len(v) * p / 100) - 1)]

    acc_rate = _mean("accepted")
    mean_gain = _mean("reward_gain")
    mean_drift = _mean("drift")
    clip_rate = _mean("hit_clip")
    mean_iou_init = _mean("R_iou_pred_init")
    mean_iou_best = _mean("R_iou_pred_best")
    mean_area_init = _mean("area_hard_init")
    mean_area_best = _mean("area_hard_best")
    # Null safety: reward improved but the predicted area grew by more than 20 %.
    risky = [
        s for s in stats
        if s["reward_gain"] > 0 and s["area_hard_best"] > s["area_hard_init"] * 1.2
    ]
    null_risk = len(risky) / n
    gains = sorted(s["reward_gain"] for s in stats)
    mean_e0 = _mean("e0")
    mean_mask_iou = _mean_opt("mask_soft_iou")
    mean_iou_contrib = _mean_opt("R_iou_contrib_gain")
    mean_soft_area_init = _mean_opt("r_area_soft_init")
    mean_soft_area_best = _mean_opt("r_area_soft_best")
    # B1 activation diagnostics
    b1_excesses = sorted(s.get("b1_peak_excess", 0.0) for s in stats)
    b1_act_rate = sum(1 for v in b1_excesses if v > 1e-8) / n
    b1_mean_excess = sum(b1_excesses) / n

    print(f"\n [q-LTPO stats | {name} | n={n}]")
    print(f" acceptance rate : {acc_rate:.3f}")
    print(f" mean e0 (exist prior): {mean_e0:.4f} ← should differ Null vs Seen")
    print(f" mean reward gain : {mean_gain:+.4f}")
    print(f" reward_gain p10/50/90: {_pct(gains,10):+.4f} / {_pct(gains,50):+.4f} / {_pct(gains,90):+.4f}")
    print(f" mean drift ‖q−q₀‖ : {mean_drift:.4f}")
    print(f" hit-clip ratio : {clip_rate:.3f}")
    print(f" R_iou_pred init→best : {mean_iou_init:.4f} → {mean_iou_best:.4f}")
    print(f" R_iou_contrib_gain : {mean_iou_contrib:+.4f} ← λ_iou·e0·Δiou")
    print(f" mask soft-IoU(init,best): {mean_mask_iou:.4f} ← 1.0=mask不变")
    print(f" area (hard) init→best: {mean_area_init:.4f} → {mean_area_best:.4f}")
    print(f" soft area init→best : {mean_soft_area_init:.4f} → {mean_soft_area_best:.4f}")
    print(f" B1 activation rate : {b1_act_rate:.3f} ← frac(peak_area > e0)")
    print(f" B1 mean excess : {b1_mean_excess:.5f} ← mean ReLU(peak_area - e0)")
    print(f" B1 excess p10/50/90 : {_pct(b1_excesses,10):.5f} / {_pct(b1_excesses,50):.5f} / {_pct(b1_excesses,90):.5f}")
    print(f" reward↑ & area+20%↑ : {null_risk:.3f} ← Null safety indicator")
    # Direction II: frame-adaptive delta diagnostics
    delta_norms = [s.get("delta_norm", 0.0) for s in stats]
    if any(v > 0 for v in delta_norms):
        print(f" mean delta ‖Δ‖ : {sum(delta_norms)/n:.4f} ← per-anchor residual norm")
554
-
555
def valuate_ltpo(model, dataloader, name, ltpo_cfg, optimize_fn=None,
                 max_rows=-1, multimask=False, use_edge=False):
    """
    Evaluate with SEG-LTPO test-time optimisation + optional boundary refinement.

    decode_mode:
        multimask=False, use_edge=False : original single-mask decode (default)
        multimask=True,  use_edge=False : 3 candidates, SAM iou_pred selection (step 1a)
        multimask=True,  use_edge=True  : 3 candidates, boundary-edge score (step 1b)

    Args:
        model: Model to evaluate; switched to eval mode.
        dataloader: Yields collated batches.
        name: Label used in the progress bar and the result line.
        ltpo_cfg: Config passed to the optimiser; ``num_anchors`` is read here.
        optimize_fn: Optimiser callable; defaults to ``ltpo_optimize``.
        max_rows: Stop after this many batches when positive.
        multimask: Forwarded to ``decode_full_video``.
        use_edge: When true, the RGB frames are passed to ``decode_full_video``.
    """
    if optimize_fn is None:
        optimize_fn = ltpo_optimize
    model.eval()
    sam_model = get_sam_model(model)
    model_dtype = torch.bfloat16
    num_frames = 10
    anchor_indices = get_anchor_indices(num_frames, ltpo_cfg.num_anchors)

    total_iou = 0
    total_fscore = 0
    count = 0

    _total = min(max_rows, len(dataloader)) if max_rows > 0 else len(dataloader)
    for i, batch in enumerate(tqdm(dataloader, desc=f"LTPO Evaluating on {name}", total=_total)):
        if 0 < max_rows <= i:
            break
        input_dict = dict_to_cuda(batch)

        # ── Step 1: standard forward pass (LLM + SAM decode) ──────────
        with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True), torch.no_grad():
            output_dict = model.forward(
                images=input_dict["images"],
                images_clip=input_dict["images_clip"],
                audio_features=input_dict["audio_feats"],
                image_features=input_dict["image_feats"],
                input_ids=input_dict["input_ids"],
                labels=input_dict["labels"],
                attention_masks=input_dict["attention_masks"],
                masks_list=input_dict["masks"],
                resize_list=input_dict["resizes"],
                orgsize_list=input_dict["orgsizes"],
                conversation_list=input_dict["convs"],
                refs_num=input_dict["refs_num"],
                fids=input_dict["fids"],
                vids=input_dict["vids"],
                contrast=args.ct_weight,
                ref_ids=input_dict["ref_ids"],
                inference=True,
            )

        gt_masks = output_dict["gt_masks"]  # list[B]:[num_seg, T, H, W]
        seg_emb_list = output_dict["seg_embeddings"]  # list[B]:[num_seg, 256]

        for b in range(len(input_dict["images"])):
            image_embeds_b = input_dict["image_feats"][b]  # [T, 256, 64, 64]
            resize_b = input_dict["resizes"][b]
            orgsize_b = input_dict["orgsizes"][b]
            rgb_b = input_dict["images"][b] if use_edge else None  # [T,3,H,W]

            # Convert initial Fseg to float32 for stable optimisation.
            # seg_emb_list[b]: [num_seg, 256] in bfloat16
            F_init_b = seg_emb_list[b].detach().float()  # [num_seg, 256]

            pred_masks_ltpo = []
            for seg_idx in range(F_init_b.shape[0]):
                fseg_init = F_init_b[seg_idx : seg_idx + 1]  # [1, 256]

                # ── Step 2: optimisation (float32, outside autocast) ──────
                best_fseg = optimize_fn(
                    fseg_init, image_embeds_b, anchor_indices,
                    sam_model, model_dtype, ltpo_cfg,
                )  # [1, 256] float32

                # ── Step 3: decode full video with best Fseg ──────────────
                pred_mask = decode_full_video(
                    best_fseg, image_embeds_b, sam_model,
                    resize_b, orgsize_b, model_dtype,
                    rgb_frames=rgb_b, multimask=multimask,
                )  # [T, H, W]
                pred_masks_ltpo.append(pred_mask)

            pred_masks_b = torch.stack(pred_masks_ltpo, dim=0)  # [num_seg, T, H, W]

            num_seg = pred_masks_b.shape[0]
            T_ = pred_masks_b.shape[1]
            iou = utility.mask_iou(pred_masks_b, gt_masks[b])
            fscore = utility.Eval_Fmeasure(pred_masks_b, gt_masks[b], None)

            total_iou += iou * num_seg * T_
            total_fscore += fscore * num_seg * T_
            count += num_seg * T_

    if count == 0:
        # An empty loader would otherwise end in a ZeroDivisionError below.
        print(f"\n LTPO valuate on {name}: no samples evaluated")
        return

    print(f"\n LTPO valuate on {name}: miou: {total_iou/count:.4f} fscore: {total_fscore/count:.4f}")
649
-
650
-
651
def valuate_ltpo_null(model, dataloader, ltpo_cfg, optimize_fn=None, max_rows=-1):
    """LTPO evaluation for Null split: measures S metric (lower = fewer false-positive masks).

    Args:
        model: Model to evaluate; switched to eval mode.
        dataloader: Yields collated batches.
        ltpo_cfg: Config passed to the optimiser; ``num_anchors`` is read here.
        optimize_fn: Optimiser callable; defaults to ``ltpo_optimize``.
        max_rows: Stop after this many batches when positive.
    """
    if optimize_fn is None:
        optimize_fn = ltpo_optimize
    model.eval()
    sam_model = get_sam_model(model)
    model_dtype = torch.bfloat16
    num_frames = 10
    anchor_indices = get_anchor_indices(num_frames, ltpo_cfg.num_anchors)

    total_metric = 0
    count = 0

    _total = min(max_rows, len(dataloader)) if max_rows > 0 else len(dataloader)
    for i, batch in enumerate(tqdm(dataloader, desc="LTPO Evaluating on Null", total=_total)):
        if 0 < max_rows <= i:
            break
        input_dict = dict_to_cuda(batch)

        with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True), torch.no_grad():
            output_dict = model.forward(
                images=input_dict["images"],
                images_clip=input_dict["images_clip"],
                audio_features=input_dict["audio_feats"],
                image_features=input_dict["image_feats"],
                input_ids=input_dict["input_ids"],
                labels=input_dict["labels"],
                attention_masks=input_dict["attention_masks"],
                masks_list=input_dict["masks"],
                resize_list=input_dict["resizes"],
                orgsize_list=input_dict["orgsizes"],
                conversation_list=input_dict["convs"],
                refs_num=input_dict["refs_num"],
                fids=input_dict["fids"],
                vids=input_dict["vids"],
                contrast=args.ct_weight,
                ref_ids=input_dict["ref_ids"],
                inference=True,
            )

        seg_emb_list = output_dict["seg_embeddings"]  # list[B]:[num_seg, 256]

        for b in range(len(input_dict["images"])):
            image_embeds_b = input_dict["image_feats"][b]
            resize_b = input_dict["resizes"][b]
            orgsize_b = input_dict["orgsizes"][b]
            # Optimise in float32; the embeddings are detached from the forward pass.
            F_init_b = seg_emb_list[b].detach().float()

            pred_masks_ltpo = []
            for seg_idx in range(F_init_b.shape[0]):
                fseg_init = F_init_b[seg_idx : seg_idx + 1]
                best_fseg = optimize_fn(
                    fseg_init, image_embeds_b, anchor_indices,
                    sam_model, model_dtype, ltpo_cfg,
                )
                pred_mask = decode_full_video(
                    best_fseg, image_embeds_b, sam_model,
                    resize_b, orgsize_b, model_dtype,
                )
                pred_masks_ltpo.append(pred_mask)

            pred_masks_b = torch.stack(pred_masks_ltpo, dim=0)  # [num_seg, T, H, W]
            num_seg = pred_masks_b.shape[0]
            T_ = pred_masks_b.shape[1]
            null_metric = utility.metric_s_for_null(pred_masks_b)

            total_metric += null_metric * num_seg * T_
            count += num_seg * T_

    if count == 0:
        # An empty loader would otherwise end in a ZeroDivisionError below.
        print("\n LTPO valuate on Null: no samples evaluated")
        return

    print(f"\n LTPO valuate on Null: S metric: {total_metric/count:.4f}")
722
-
723
-
724
def valuate_ltpo_adaptive(model, dataloader, name, ltpo_cfg, max_rows=-1):
    """Evaluate with Direction II frame-adaptive token optimization.

    Each segmentation embedding is refined with ``q_ltpo_frame_adaptive``,
    decoded with ``decode_full_video_adaptive``, and the frame-weighted mIoU
    and F-score are printed at the end.

    Args:
        model: Model to evaluate; switched to eval mode.
        dataloader: Yields collated batches.
        name: Label used in the progress bar and the result line.
        ltpo_cfg: Config passed to the optimiser; ``num_anchors`` is read here.
        max_rows: Stop after this many batches when positive.
    """
    model.eval()
    sam_model = get_sam_model(model)
    model_dtype = torch.bfloat16
    num_frames = 10
    anchor_indices = get_anchor_indices(num_frames, ltpo_cfg.num_anchors)

    iou_sum = 0
    fscore_sum = 0
    frame_count = 0

    _total = min(max_rows, len(dataloader)) if max_rows > 0 else len(dataloader)
    progress = tqdm(dataloader, desc=f"FA-LTPO Evaluating on {name}", total=_total)
    for batch_no, batch in enumerate(progress):
        if 0 < max_rows <= batch_no:
            break
        inputs = dict_to_cuda(batch)

        forward_kwargs = dict(
            images=inputs["images"],
            images_clip=inputs["images_clip"],
            audio_features=inputs["audio_feats"],
            image_features=inputs["image_feats"],
            input_ids=inputs["input_ids"],
            labels=inputs["labels"],
            attention_masks=inputs["attention_masks"],
            masks_list=inputs["masks"],
            resize_list=inputs["resizes"],
            orgsize_list=inputs["orgsizes"],
            conversation_list=inputs["convs"],
            refs_num=inputs["refs_num"],
            fids=inputs["fids"],
            vids=inputs["vids"],
            contrast=args.ct_weight,
            ref_ids=inputs["ref_ids"],
            inference=True,
        )
        with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
            with torch.no_grad():
                outputs = model.forward(**forward_kwargs)

        gt_masks = outputs["gt_masks"]  # list[B]:[num_seg, T, H, W]
        seg_emb_list = outputs["seg_embeddings"]  # list[B]:[num_seg, 256]

        for b in range(len(inputs["images"])):
            frame_embeds = inputs["image_feats"][b]
            resize_b = inputs["resizes"][b]
            orgsize_b = inputs["orgsizes"][b]
            init_tokens = seg_emb_list[b].detach().float()

            decoded = []
            for seg_idx in range(init_tokens.shape[0]):
                q_global, delta = q_ltpo_frame_adaptive(
                    init_tokens[seg_idx : seg_idx + 1], frame_embeds, anchor_indices,
                    sam_model, model_dtype, ltpo_cfg,
                )
                decoded.append(
                    decode_full_video_adaptive(
                        q_global, delta, anchor_indices,
                        frame_embeds, sam_model,
                        resize_b, orgsize_b, model_dtype,
                    )
                )

            pred_masks_b = torch.stack(decoded, dim=0)
            num_seg, T_ = pred_masks_b.shape[0], pred_masks_b.shape[1]
            iou = utility.mask_iou(pred_masks_b, gt_masks[b])
            fscore = utility.Eval_Fmeasure(pred_masks_b, gt_masks[b], None)

            # Weight each sample by its number of (segment, frame) pairs.
            iou_sum += iou * num_seg * T_
            fscore_sum += fscore * num_seg * T_
            frame_count += num_seg * T_

    print(f"\n FA-LTPO valuate on {name}: miou: {iou_sum/frame_count:.4f} fscore: {fscore_sum/frame_count:.4f}")
800
-
801
- # ── Step A0: reward–metric correlation study ─────────────────────────
802
-
803
- def _print_correlation_report(per_sample: list) -> None:
804
- import numpy as np
805
- n = len(per_sample)
806
- if n == 0:
807
- return
808
-
809
- r_iou = np.array([s["reward_gain"] for s in per_sample], dtype=float)
810
- r_avt = np.array([s["r_avt_gain"] for s in per_sample], dtype=float)
811
- r_avt_c = np.array([s["r_avt_c_gain"] for s in per_sample], dtype=float)
812
- dm = np.array([s["delta_miou"] for s in per_sample], dtype=float)
813
- df = np.array([s["delta_f"] for s in per_sample], dtype=float)
814
-
815
- def pearson(x, y):
816
- x = x - x.mean(); y = y - y.mean()
817
- denom = np.sqrt((x ** 2).sum() * (y ** 2).sum())
818
- return float((x * y).sum() / (denom + 1e-12))
819
-
820
- def wrong_frac(gains, deltas):
821
- return sum(1 for g, d in zip(gains, deltas) if g > 0 and d < 0) / n
822
-
823
- print(f"\n [Step A0: Reward–Metric Correlation | n={n}]")
824
- print(f" mean ΔmIoU : {dm.mean():+.4f} (std {dm.std():.4f})")
825
- print(f" mean ΔF : {df.mean():+.4f} (std {df.std():.4f})")
826
- print(f"\n Pearson r with ΔmIoU :")
827
- print(f" R_iou_pred_gain : {pearson(r_iou, dm):+.3f} ← current proxy")
828
- print(f" R_avt_gain : {pearson(r_avt, dm):+.3f} ← cos(z_in, q_init)")
829
- print(f" R_avt_c_gain : {pearson(r_avt_c, dm):+.3f} ← cos(z_in,q)-β·cos(z_out,q)")
830
- print(f"\n Pearson r with ΔF :")
831
- print(f" R_iou_pred_gain : {pearson(r_iou, df):+.3f}")
832
- print(f" R_avt_gain : {pearson(r_avt, df):+.3f}")
833
- print(f" R_avt_c_gain : {pearson(r_avt_c, df):+.3f}")
834
- print(f"\n Wrong direction (gain>0 but Δ<0):")
835
- print(f" R_iou / ΔmIoU : {wrong_frac(r_iou, dm):.3f}")
836
- print(f" R_avt / ΔmIoU : {wrong_frac(r_avt, dm):.3f}")
837
- print(f" R_iou / ΔF : {wrong_frac(r_iou, df):.3f}")
838
- print(f" R_avt / ΔF : {wrong_frac(r_avt, df):.3f}")
839
-
840
- def valuate_ltpo_correlation_study(model, dataloader, ltpo_cfg, max_rows=-1):
841
- """Step A0: per-sample reward–metric correlation study.
842
-
843
- For each (video, segment) sample runs:
844
- 1. Baseline decode (q_init → mask → IoU/F)
845
- 2. q-LTPO s1 (q_best → mask → IoU/F)
846
- Records reward signals and ΔmIoU / ΔF per sample, then prints
847
- Pearson correlation table to identify which reward best predicts
848
- actual metric improvement.
849
- """
850
- model.eval()
851
- sam_model = get_sam_model(model)
852
- model_dtype = torch.bfloat16
853
- anchor_indices = get_anchor_indices(10, ltpo_cfg.num_anchors)
854
-
855
- per_sample = []
856
-
857
- _total = min(max_rows, len(dataloader)) if max_rows > 0 else len(dataloader)
858
- for i, batch in enumerate(
859
- tqdm(dataloader, desc="Correlation study (s1)", total=_total)
860
- ):
861
- if 0 < max_rows <= i:
862
- break
863
- input_dict = dict_to_cuda(batch)
864
-
865
- with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
866
- with torch.no_grad():
867
- output_dict = model.forward(
868
- images=input_dict["images"],
869
- images_clip=input_dict["images_clip"],
870
- audio_features=input_dict["audio_feats"],
871
- image_features=input_dict["image_feats"],
872
- input_ids=input_dict["input_ids"],
873
- labels=input_dict["labels"],
874
- attention_masks=input_dict["attention_masks"],
875
- masks_list=input_dict["masks"],
876
- resize_list=input_dict["resizes"],
877
- orgsize_list=input_dict["orgsizes"],
878
- conversation_list=input_dict["convs"],
879
- refs_num=input_dict["refs_num"],
880
- fids=input_dict["fids"],
881
- vids=input_dict["vids"],
882
- contrast=args.ct_weight,
883
- ref_ids=input_dict["ref_ids"],
884
- inference=True,
885
- )
886
-
887
- gt_masks = output_dict["gt_masks"] # list[B]:[num_seg, T, H, W]
888
- seg_emb_list = output_dict["seg_embeddings"] # list[B]:[num_seg, 256]
889
-
890
- for b in range(len(input_dict["images"])):
891
- image_embeds_b = input_dict["image_feats"][b]
892
- resize_b = input_dict["resizes"][b]
893
- orgsize_b = input_dict["orgsizes"][b]
894
- F_init_b = seg_emb_list[b].detach().float()
895
-
896
- for seg_idx in range(F_init_b.shape[0]):
897
- q_init = F_init_b[seg_idx : seg_idx + 1] # [1, 256]
898
- gt_seg = gt_masks[b][seg_idx : seg_idx + 1] # [1, T, H, W]
899
-
900
- # Baseline decode (q_init, no LTPO)
901
- with torch.no_grad():
902
- pred_base = decode_full_video(
903
- q_init, image_embeds_b, sam_model,
904
- resize_b, orgsize_b, model_dtype,
905
- ).unsqueeze(0) # [1, T, H, W]
906
- iou_base = utility.mask_iou(pred_base, gt_seg)
907
- f_base = utility.Eval_Fmeasure(pred_base, gt_seg, None)
908
-
909
- # LTPO (s1) — also computes r_avt inside q_ltpo_autograd
910
- reset_q_ltpo_stats()
911
- q_best = q_ltpo_autograd(
912
- q_init, image_embeds_b, anchor_indices,
913
- sam_model, model_dtype, ltpo_cfg,
914
- )
915
- stat = get_q_ltpo_stats()[0]
916
-
917
- with torch.no_grad():
918
- pred_ltpo = decode_full_video(
919
- q_best, image_embeds_b, sam_model,
920
- resize_b, orgsize_b, model_dtype,
921
- ).unsqueeze(0)
922
- iou_ltpo = utility.mask_iou(pred_ltpo, gt_seg)
923
- f_ltpo = utility.Eval_Fmeasure(pred_ltpo, gt_seg, None)
924
-
925
- per_sample.append({
926
- "reward_gain": stat["reward_gain"],
927
- "r_avt_gain": stat.get("r_avt_gain", 0.0),
928
- "r_avt_c_gain": stat.get("r_avt_c_gain", 0.0),
929
- "e0": stat["e0"],
930
- "accepted": stat["accepted"],
931
- "delta_miou": float(iou_ltpo - iou_base),
932
- "delta_f": float(f_ltpo - f_base),
933
- })
934
-
935
- _print_correlation_report(per_sample)
936
-
937
- # ── Stage 0: gradient connectivity check ─────────────────────────────
938
- # Loads one image_embed directly from disk — no dataloader, no gt_mask,
939
- # no media frames required. F_init is a unit-scale random vector that
940
- # mimics the distribution of Fseg (SAM prompt embeddings are in ℝ^256
941
- # with per-dim std ≈ 0.05–0.3; we use std=0.1 as a neutral initialisation).
942
- def run_stage0_check():
943
- import glob
944
- sam_model = get_sam_model(model)
945
- model_dtype = torch.bfloat16
946
-
947
- embed_files = sorted(glob.glob(os.path.join(args.data_dir, "image_embed", "*.pt")))
948
- if not embed_files:
949
- print("[Stage 0] ERROR: no .pt files found in data/image_embed/")
950
- return False
951
-
952
- img_embs = torch.load(embed_files[0], map_location="cuda") # [T, 256, 64, 64]
953
- if img_embs.dim() == 3: # [256,64,64] → [1,256,64,64]
954
- img_embs = img_embs.unsqueeze(0)
955
-
956
- torch.manual_seed(42)
957
- F_init = torch.randn(1, 256, device="cuda") * 0.1 # [1, 256] float32
958
-
959
- anchors = get_anchor_indices(img_embs.shape[0], 4)
960
- diag = check_grad_connectivity(F_init, img_embs, anchors, sam_model, model_dtype)
961
- print("\n[Stage 0] Gradient connectivity check:")
962
- print(f" file used : {os.path.basename(embed_files[0])}")
963
- print(f" gradient_connected : {diag['gradient_connected']}")
964
- print(f" grad_norm (step 0) : {diag['grad_norm_step0']:.6f}")
965
- print(f" reward trajectory : {[f'{r:.4f}' for r in diag['reward_trajectory']]}")
966
- return diag["gradient_connected"]
967
-
968
- # ── Bypass equivalence test ───────────────────────────────────────────
969
- # Three controlled tests to verify that fseg.unsqueeze(1) (bypass) is
970
- # numerically equivalent to prompt_encoder(text_embeds=fseg.unsqueeze(1)):
971
- # Test 1 — dense_emb dtype: dense_A.to(bfloat16) vs dense_emb_bf16 (exact 0?)
972
- # Test 2 — matched-prec anchor decode: same decoder, same inputs, both bfloat16
973
- # Test 3 — full-video (all T frames) matched-prec decode
974
- # If all pass, delta_bypass_init = 0 and the +4.22% is purely from optimization.
975
- def run_bypass_test():
976
- from seg_ltpo import _precompute_dense_emb
977
-
978
- sam_model = get_sam_model(model)
979
- pe = sam_model.prompt_encoder
980
- mask_dec = sam_model.mask_decoder
981
- model_dtype = torch.bfloat16
982
-
983
- # Get one real Fseg via a standard forward pass on the first batch
984
- batch = next(iter(_dataloader))
985
- input_dict = dict_to_cuda(batch)
986
- with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
987
- with torch.no_grad():
988
- output_dict = model.forward(
989
- images=input_dict["images"],
990
- images_clip=input_dict["images_clip"],
991
- audio_features=input_dict["audio_feats"],
992
- image_features=input_dict["image_feats"],
993
- input_ids=input_dict["input_ids"],
994
- labels=input_dict["labels"],
995
- attention_masks=input_dict["attention_masks"],
996
- masks_list=input_dict["masks"],
997
- resize_list=input_dict["resizes"],
998
- orgsize_list=input_dict["orgsizes"],
999
- conversation_list=input_dict["convs"],
1000
- refs_num=input_dict["refs_num"],
1001
- fids=input_dict["fids"],
1002
- vids=input_dict["vids"],
1003
- contrast=args.ct_weight,
1004
- ref_ids=input_dict["ref_ids"],
1005
- inference=True,
1006
- )
1007
 
1008
- fseg = output_dict["seg_embeddings"][0][0:1].detach() # [1,256] bfloat16
1009
- image_embeds = input_dict["image_feats"][0] # [T,256,64,64] float32
1010
- device = fseg.device
1011
-
1012
- anchor_indices = get_anchor_indices(image_embeds.shape[0], 4)
1013
- img_anc = image_embeds[anchor_indices] # [A,256,64,64] float32
1014
- dense_emb_bf16 = _precompute_dense_emb(sam_model, model_dtype, device) # [1,256,64,64] bfloat16
1015
- dense_pe = pe.get_dense_pe().to(device) # float32
1016
-
1017
- def _decode(img, sparse_emb, dense_emb):
1018
- return mask_dec(
1019
- image_embeddings=img,
1020
- image_pe=dense_pe,
1021
- sparse_prompt_embeddings=sparse_emb,
1022
- dense_prompt_embeddings=dense_emb,
1023
- multimask_output=False,
1024
- )
1025
-
1026
- def _check(label, tensor_a, tensor_b, exact=False):
1027
- err = (tensor_a.float() - tensor_b.float()).abs().max().item()
1028
- tol = 0.0 if exact else 1e-4
1029
- status = "PASS" if err <= tol else "FAIL"
1030
- print(f" [{status}] {label:50s} max|A-B| = {err:.2e}")
1031
- return err <= tol
1032
-
1033
- print(f"\n[Bypass Test] fseg dtype={fseg.dtype} norm={fseg.float().norm().item():.4f}")
1034
-
1035
- with torch.no_grad():
1036
- # Get prompt_encoder outputs (called outside autocast → float32)
1037
- sparse_A, dense_A = pe(points=None, boxes=None, masks=None,
1038
- text_embeds=fseg.unsqueeze(1))
1039
- sparse_B = fseg.unsqueeze(1) # bypass sparse: identical tensor
1040
-
1041
- # ── Test 1: dense_emb dtype artifact ────────────────────────────────
1042
- # Hypothesis: dense_A (float32) and dense_emb_bf16 differ only because
1043
- # no_mask_embed.weight is float32; casting to bfloat16 should give exact 0.
1044
- print("\n [Test 1] dense_emb dtype artifact (expected: exact 0)")
1045
- t1 = _check("dense_A.to(bfloat16) vs dense_emb_bf16",
1046
- dense_A.to(torch.bfloat16), dense_emb_bf16, exact=True)
1047
-
1048
- # ── Test 2: matched-precision decode on anchors ──────────────────────
1049
- # Both paths use bfloat16 sparse + bfloat16 dense.
1050
- # If sparse_emb is identical and dense_emb is identical (per Test 1),
1051
- # masks and iou_preds must be identical (same decoder, same inputs).
1052
- print("\n [Test 2] matched-precision anchor decode (expected: exact 0)")
1053
- dense_A_bf16 = dense_A.to(model_dtype)
1054
- masks_A, iou_A = _decode(img_anc, sparse_A, dense_A_bf16)
1055
- masks_B, iou_B = _decode(img_anc, sparse_B, dense_emb_bf16)
1056
- _check("sparse_emb", sparse_A, sparse_B, exact=True)
1057
- t2m = _check("masks (anchors, matched prec)", masks_A, masks_B, exact=True)
1058
- t2i = _check("iou_preds (anchors, matched prec)", iou_A, iou_B, exact=True)
1059
- t2 = t2m and t2i
1060
-
1061
- # ── Test 3: full-video bypass-init baseline (all T frames) ──────────
1062
- # Extend Test 2 to all T frames; quantifies delta_bypass_init over
1063
- # the complete video rather than just the 4 anchor frames.
1064
- print(f"\n [Test 3] full-video matched-precision decode (T={image_embeds.shape[0]} frames)")
1065
- masks_full_A, _ = _decode(image_embeds, sparse_A, dense_A_bf16)
1066
- masks_full_B, _ = _decode(image_embeds, sparse_B, dense_emb_bf16)
1067
- t3 = _check("masks (all frames, matched prec)", masks_full_A, masks_full_B, exact=True)
1068
-
1069
- print("\n ── Verdict ──────────────────────────────────────────────────────")
1070
- if t1 and t2 and t3:
1071
- print(" ALL PASS — bypass is algebraically and numerically equivalent to")
1072
- print(" prompt_encoder path under matched precision. delta_bypass_init = 0.")
1073
- print(" The +4.22% mIoU improvement is purely from q-LTPO optimization.")
1074
- else:
1075
- failures = []
1076
- if not t1: failures.append("Test 1 (dense dtype)")
1077
- if not t2: failures.append("Test 2 (anchor decode)")
1078
- if not t3: failures.append("Test 3 (full-video decode)")
1079
- print(f" FAIL in: {', '.join(failures)}")
1080
- print(" delta_bypass_init ≠ 0; need per-sample mIoU comparison to quantify.")
1081
-
1082
- # ── Run evaluation ────────────────────────────────────────────────────
1083
-
1084
- ltpo_cfg = LTPOConfig()
1085
- q_ltpo_cfg_s1 = QLTPOConfig(stage=1)
1086
- q_ltpo_cfg_s2 = QLTPOConfig(stage=2)
1087
- q_ltpo_cfg_s21 = QLTPOConfig(stage=21) # P1a: tether probe
1088
- q_ltpo_cfg_s22 = QLTPOConfig(stage=22) # P1b: faithful ext-ref
1089
-
1090
- # ── Direction B: boundary precision probes ──────────────────────────────
1091
- q_ltpo_cfg_b1_w03 = QLTPOConfig(stage=1, lambda_area_inc=0.3, area_inc_tau=0.0)
1092
- q_ltpo_cfg_b1_w10 = QLTPOConfig(stage=1, lambda_area_inc=1.0, area_inc_tau=0.0)
1093
-
1094
- # ── Direction II: Frame-adaptive token optimization ─────────────────────
1095
- # fa_c03: delta clipped at 0.3×‖q_init‖ — moderate constraint.
1096
- # First probe to answer: "does constrained frame-adaptive beat shared q?"
1097
- # If yes → ablate tighter/looser constraints and smoothness in follow-up.
1098
- q_ltpo_cfg_fa_c03 = QLTPOConfig(stage=1, lambda_residual=0.001, lambda_smooth_temp=0.0, max_delta_drift_scale=0.3)
1099
-
1100
- max_rows = args.max_eval_rows # -1 = all rows
1101
-
1102
- # --max_eval_rows 0 → Stage 0 + bypass equivalence check, then exit
1103
- if max_rows == 0:
1104
- run_stage0_check()
1105
- run_bypass_test()
1106
- elif _split == 'test_n':
1107
- # Null safety check: baseline + Stage 1 + frame-adaptive
1108
- valuate_Null(model, _dataloader, max_rows=max_rows)
1109
- for cfg_name, cfg in [("s1", q_ltpo_cfg_s1)]:
1110
- reset_q_ltpo_stats()
1111
- valuate_ltpo_null(model, _dataloader, cfg,
1112
- optimize_fn=q_ltpo_autograd, max_rows=max_rows)
1113
- print_q_ltpo_stats(f"null_q_ltpo_{cfg_name}")
1114
- reset_q_ltpo_stats()
1115
- valuate_ltpo_adaptive(model, _dataloader, "null_fa_c03",
1116
- q_ltpo_cfg_fa_c03, max_rows=max_rows)
1117
- print_q_ltpo_stats("null_fa_c03")
1118
  else:
1119
- valuate(model, _dataloader, _split, max_rows=max_rows)
1120
- # Step A0: reward–metric correlation study (s1 + AVT proxy signals)
1121
- valuate_ltpo_correlation_study(
1122
- model, _dataloader, q_ltpo_cfg_s1, max_rows=max_rows
1123
- )
1124
-
 
 
 
 
 
 
 
 
 
 
1
  import transformers
2
 
3
  from torch.cuda.amp import autocast, GradScaler
 
208
 
209
  import torch.multiprocessing as mp
210
  if __name__ == "__main__":
211
+ mp.set_start_method("spawn")
212
  set_seed(42)
213
  tokenizer = transformers.AutoTokenizer.from_pretrained(
214
  args.mllm,
 
224
  print("seg_token_idx: ", seg_token_idx)
225
 
226
 
227
+ if args.eval_split not in {"test_s", "test_u", "test_n"}:
228
+ raise ValueError(f"Unsupported eval_split: {args.eval_split}")
229
+
230
+ val_dataset = REFAVS(args.eval_split, args, tokenizer, input_type='refer')
231
+ val_dataloader = DataLoader(
232
+ val_dataset,
233
+ batch_size=1,
234
+ shuffle=False,
235
+ num_workers=4,
236
+ collate_fn=partial(collate_fn, tokenizer=tokenizer),
237
+ )
238
 
239
 
240
 
 
340
  model = model.to("cuda")
341
  model.resize_token_embeddings(len(tokenizer))
342
 
343
+ missing, unexpected = model.load_state_dict(
344
+ torch.load(args.saved_model, map_location="cpu"),
345
+ strict=False,
346
+ )
347
+ print(f"saved model loaded: {args.saved_model}")
348
+ print(f"missing keys: {len(missing)} | unexpected keys: {len(unexpected)}")
349
 
350
 
351
  save_root = args.visualization_root
 
404
  print("visualization finished")
405
 
406
 
407
+ def valuate(model, dataloader, name):
408
  model.eval()
409
 
410
  total_iou = 0
411
  total_fscore = 0
412
  count = 0
413
 
414
+ for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Evaluating on {name}")):
415
+ if args.max_eval_rows > 0 and batch_idx >= args.max_eval_rows:
 
416
  break
417
  input_dict = dict_to_cuda(batch)
418
 
 
447
  total_fscore += fscore * num_seg * T
448
  count += num_seg * T
449
 
450
+ if count == 0:
451
+ raise RuntimeError(f"No samples were evaluated for {name}")
452
+
453
  print(f"\n valuate on {name}: miou: {total_iou/count} fscore: {total_fscore/count}")
454
 
455
 
456
+ def valuate_Null(model, dataloader):
457
  model.eval()
458
 
459
  total_metric = 0
460
  count = 0
461
 
462
+ for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Evaluating on Null")):
463
+ if args.max_eval_rows > 0 and batch_idx >= args.max_eval_rows:
 
464
  break
465
  input_dict = dict_to_cuda(batch)
466
+ with torch.no_grad():
467
+ output_dict = model.forward(images=input_dict["images"],
468
+ images_clip=input_dict["images_clip"],
469
+ audio_features=input_dict["audio_feats"],
470
+ image_features=input_dict["image_feats"],
471
+ input_ids=input_dict["input_ids"],
472
+ labels=input_dict["labels"],
473
+ attention_masks=input_dict["attention_masks"],
474
+ masks_list=input_dict["masks"],
475
+ resize_list=input_dict["resizes"],
476
+ orgsize_list=input_dict["orgsizes"],
477
+ conversation_list=input_dict["convs"],
478
+ refs_num=input_dict["refs_num"],
479
+ fids=input_dict["fids"],
480
+ vids=input_dict["vids"],
481
+ contrast=args.ct_weight,
482
+ ref_ids=input_dict["ref_ids"],
483
+ inference=True)
 
484
  pred_masks = output_dict["pred_masks"] # list[B]:[num_seg, T, H, W]
485
  gt_masks = output_dict["gt_masks"] # list[B]:[num_seg, T, H, W]
486
  for i in range(len(pred_masks)):
 
491
  total_metric += null_metric * num_seg * T
492
  count += num_seg * T
493
 
494
+ if count == 0:
495
+ raise RuntimeError("No samples were evaluated for test_n")
 
496
 
497
+ print(f"\n valuate on test_n_refer, metric: {total_metric / count}")
498
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
 
500
+ if args.eval_split == "test_n":
501
+ valuate_Null(model, val_dataloader)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502
  else:
503
+ valuate(model, val_dataloader, args.eval_split)
 
 
 
 
 
save_audio_feats.py CHANGED
@@ -80,4 +80,3 @@ for vid in vids:
80
  # print(f"{vid}: {audio_embed.shape}")
81
  torch.save(audio_embed, f'{save_dir}/{vid}.pt')
82
  print(f'{vid} embedding saved {audio_embed.shape}')
83
-
 
80
  # print(f"{vid}: {audio_embed.shape}")
81
  torch.save(audio_embed, f'{save_dir}/{vid}.pt')
82
  print(f'{vid} embedding saved {audio_embed.shape}')
 
setup_simtoken.md CHANGED
@@ -1,12 +1,22 @@
1
  # SimToken Setup
2
 
 
 
3
  ---
4
 
5
  ## 1. Create Environment
6
 
 
 
7
  ```bash
8
- conda create -n simtoken python=3.10 -y
9
- conda activate simtoken
 
 
 
 
 
 
10
 
11
  python -m pip install --upgrade pip wheel "setuptools<81"
12
 
@@ -34,17 +44,63 @@ pip install \
34
  huggingface_hub
35
  ```
36
 
 
 
 
 
 
 
 
 
 
 
 
37
  ---
38
 
39
- ## 2. Download from HuggingFace(新机器初始化)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- 登录 HuggingFace(token 在 https://huggingface.co/settings/tokens 生成)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  ```bash
44
  huggingface-cli login
45
  ```
46
 
47
- 下载完整 repo(代码 + 权重 + 压缩数据包,共约 190G)
48
 
49
  ```bash
50
  mkdir -p /workspace/SimToken
@@ -56,97 +112,108 @@ huggingface-cli download yfan07/SimToken \
56
  --local-dir-use-symlinks False
57
  ```
58
 
59
- 下载完成后解压数据
60
 
61
  ```bash
62
  cd /workspace/SimToken/data
63
 
64
- tar -xf image_embed.tar # ~5–10 分钟
65
  tar -xzf gt_mask.tar.gz
66
  tar -xzf audio_embed.tar.gz
67
  tar -xf media.tar
68
  ```
69
 
70
-
71
  ---
72
 
73
- ## 3. Pre-download Model Weights(首次使用必做)
74
 
75
- `transformers==4.30.2` 与新版 `huggingface_hub` 存在 API 兼容问题(`use_auth_token` 已移除)。
76
- 解决方案:先用 CLI 将模型下载到本地缓存,之后运行实验时加 `TRANSFORMERS_OFFLINE=1`,跳过所有网络请求。
77
 
78
  ```bash
79
- # Chat-UniVi-7B(~14G)
80
  huggingface-cli download Chat-UniVi/Chat-UniVi-7B-v1.5
81
 
82
- # CLIP ViT-L(~1.6G)
83
  huggingface-cli download openai/clip-vit-large-patch14
84
  ```
85
 
86
- 下载完成后即永久缓存,新 session 无需重复下载。
 
 
 
 
 
 
 
87
 
88
  ---
89
 
90
- ## 4. Example Evaluation
91
 
92
- 所有评测命令统一加 `TRANSFORMERS_OFFLINE=1`:
93
 
94
  ```bash
95
  cd /workspace/SimToken
96
 
97
- # Unseen split(全量 1656 样本)
98
- TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_u
 
 
 
99
 
100
- # Seen split
101
- TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_s
102
 
103
- # Null split(S metric,越低越好)
104
- TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_n
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
- # 限制样本数(快速验证)
107
- TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_u --max_eval_rows 50
108
 
109
- # Stage 0 梯度连通性 + bypass 等价性检查(仅诊断)
110
- TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_u --max_eval_rows 0
 
111
  ```
112
 
113
- 每次评估依次输出:Baseline + q-LTPO Stage 1 两组结果及诊断统计。
114
-
115
  ---
116
 
117
- ## 5. Upload to HuggingFace(实验结束后)
118
-
119
- 数据目录以压缩包形式存储,可大幅减少文件数量,避免 HuggingFace commit 频率限制。
120
 
121
- **第一步:将数据目录压缩为归档文件(如尚未压缩)**
122
 
123
  ```bash
124
  cd /workspace/SimToken/data
125
 
126
- tar -cf image_embed.tar image_embed/ # 不压缩(.pt 已是二进制)
127
  tar -czf gt_mask.tar.gz gt_mask/
128
  tar -czf audio_embed.tar.gz audio_embed/
129
  tar -cf media.tar media/
130
 
131
- # 确认压缩包存在后删除原始目录
132
  ls -lh *.tar*
133
  rm -rf image_embed/ gt_mask/ audio_embed/ media/
134
  ```
135
 
136
- **第二步:清理缓存并上传**
137
 
138
  ```bash
139
- find /workspace/SimToken -name "__pycache__" -exec rm -rf {} + 2>/dev/null
140
- find /workspace/SimToken -name "*.pyc" -delete
141
 
142
- huggingface-cli login # token 在 https://huggingface.co/settings/tokens 生成(需 Write 权限)
 
 
 
143
 
144
- cd /workspace/SimToken
145
  python upload_hf.py --repo yfan07/SimToken
146
  ```
147
-
148
- **注意事项:**
149
- - 建议在 `tmux` 里运行,防止 SSH 断开:`tmux new -s upload`,完成后 `Ctrl+B D` detach
150
- - 支持断点续传:中断后重新执行同一命令会自动跳过已上传文件
151
- - 遇到 rate limit(HTTP 429)时脚本会自动等待约 1 小时后重试
152
- - 监控进度:`tail -f /workspace/SimToken/upload.log`
 
1
  # SimToken Setup
2
 
3
+ 本文档用于在新机器上重建 SimToken 环境,并准备后续 A-min 实验。
4
+
5
  ---
6
 
7
  ## 1. Create Environment
8
 
9
+ 先确认 GPU 和 CUDA driver 状态:
10
+
11
  ```bash
12
+ nvidia-smi
13
+ ```
14
+
15
+ 创建 conda 环境:
16
+
17
+ ```bash
18
+ /opt/miniforge3/condabin/conda create -n simtoken python=3.10 -y
19
+ /opt/miniforge3/condabin/conda activate simtoken
20
 
21
  python -m pip install --upgrade pip wheel "setuptools<81"
22
 
 
44
  huggingface_hub
45
  ```
46
 
47
+ 快速验证:
48
+
49
+ ```bash
50
+ python - <<'PY'
51
+ import torch
52
+ print("torch:", torch.__version__)
53
+ print("cuda available:", torch.cuda.is_available())
54
+ print("device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu")
55
+ PY
56
+ ```
57
+
58
  ---
59
 
60
+ ## 2. Check Workspace After Migration
61
+
62
+ 使用服务器平台的迁移工具完成目录迁移后,在新机器上确认关键文件:
63
+
64
+ ```bash
65
+ cd /workspace/SimToken
66
+
67
+ ls -lh checkpoints/simtoken_pretrained.pth
68
+ ls -lh models/segment_anything/sam_vit_h_4b8939.pth
69
+ ls -d data/image_embed data/gt_mask data/audio_embed data/media
70
+ ```
71
+
72
+ 如果迁移后只有压缩包而没有解压目录,重新解压:
73
+
74
+ ```bash
75
+ cd /workspace/SimToken/data
76
+
77
+ tar -xf image_embed.tar
78
+ tar -xzf gt_mask.tar.gz
79
+ tar -xzf audio_embed.tar.gz
80
+ tar -xf media.tar
81
+ ```
82
 
83
+ 清理迁移中不需要的缓存:
84
+
85
+ ```bash
86
+ cd /workspace/SimToken
87
+ find . -name "__pycache__" -prune -exec rm -rf {} +
88
+ find . -name ".pytest_cache" -prune -exec rm -rf {} +
89
+ find . -name ".cache" -prune -exec rm -rf {} +
90
+ find . -name "*.pyc" -delete
91
+ ```
92
+
93
+ ---
94
+
95
+ ## 3. Download from HuggingFace
96
+
97
+ 如果新机器不使用迁移工具,而是从 HuggingFace 重新初始化,先登录:
98
 
99
  ```bash
100
  huggingface-cli login
101
  ```
102
 
103
+ 下载完整 repo:
104
 
105
  ```bash
106
  mkdir -p /workspace/SimToken
 
112
  --local-dir-use-symlinks False
113
  ```
114
 
115
+ 下载完成后解压数据:
116
 
117
  ```bash
118
  cd /workspace/SimToken/data
119
 
120
+ tar -xf image_embed.tar
121
  tar -xzf gt_mask.tar.gz
122
  tar -xzf audio_embed.tar.gz
123
  tar -xf media.tar
124
  ```
125
 
 
126
  ---
127
 
128
+ ## 4. Pre-download Model Weights
129
 
130
+ `transformers==4.30.2` 与新版 `huggingface_hub` 可能存在网络/API 兼容问题。建议先用 CLI 将模型下载到本地缓存,实验时再加 `TRANSFORMERS_OFFLINE=1`。
 
131
 
132
  ```bash
133
+ # Chat-UniVi-7B
134
  huggingface-cli download Chat-UniVi/Chat-UniVi-7B-v1.5
135
 
136
+ # CLIP ViT-L
137
  huggingface-cli download openai/clip-vit-large-patch14
138
  ```
139
 
140
+ 下载完成后做离线验证:
141
+
142
+ ```bash
143
+ cd /workspace/SimToken
144
+
145
+ TRANSFORMERS_OFFLINE=1 /opt/miniforge3/condabin/conda run -n simtoken \
146
+ python -m py_compile train.py load_model.py decoder_invariance_check.py
147
+ ```
148
 
149
  ---
150
 
151
+ ## 5. Smoke Test
152
 
153
+ 先跑一个轻量 sanity check,确认 checkpoint、数据和离线模型缓存都能正常读取:
154
 
155
  ```bash
156
  cd /workspace/SimToken
157
 
158
+ TRANSFORMERS_OFFLINE=1 /opt/miniforge3/condabin/conda run -n simtoken \
159
+ python decoder_invariance_check.py \
160
+ --eval_split test_s \
161
+ --max_eval_rows 1
162
+ ```
163
 
164
+ 如果可以正常加载模型并输出 per-frame diff,就可以启动完整 A-min 训练:
 
165
 
166
+ ```bash
167
+ cd /workspace/SimToken
168
+ mkdir -p log checkpoints
169
+
170
+ TRANSFORMERS_OFFLINE=1 /opt/miniforge3/condabin/conda run -n simtoken \
171
+ python -W ignore train.py \
172
+ --name amin_full_e1 \
173
+ --init_from_saved_model \
174
+ --epochs 1 \
175
+ --batch_size 2 \
176
+ --lr 1e-4 \
177
+ --saved_model /workspace/SimToken/checkpoints/simtoken_pretrained.pth \
178
+ --log_root /workspace/SimToken/log \
179
+ --checkpoint_root /workspace/SimToken/checkpoints
180
+ ```
181
 
182
+ 启动日志中应出现:
 
183
 
184
+ ```text
185
+ initialized training from saved model: /workspace/SimToken/checkpoints/simtoken_pretrained.pth
186
+ missing keys: ... | unexpected keys: ...
187
  ```
188
 
 
 
189
  ---
190
 
191
+ ## 6. Upload to HuggingFace
 
 
192
 
193
+ 实验结束后,如需重新上传到 HuggingFace,先将数据目录压缩为归档文件,减少文件数量:
194
 
195
  ```bash
196
  cd /workspace/SimToken/data
197
 
198
+ tar -cf image_embed.tar image_embed/
199
  tar -czf gt_mask.tar.gz gt_mask/
200
  tar -czf audio_embed.tar.gz audio_embed/
201
  tar -cf media.tar media/
202
 
 
203
  ls -lh *.tar*
204
  rm -rf image_embed/ gt_mask/ audio_embed/ media/
205
  ```
206
 
207
+ 清理缓存并上传:
208
 
209
  ```bash
210
+ cd /workspace/SimToken
 
211
 
212
+ find . -name "__pycache__" -prune -exec rm -rf {} +
213
+ find . -name ".pytest_cache" -prune -exec rm -rf {} +
214
+ find . -name ".cache" -prune -exec rm -rf {} +
215
+ find . -name "*.pyc" -delete
216
 
217
+ huggingface-cli login
218
  python upload_hf.py --repo yfan07/SimToken
219
  ```
 
 
 
 
 
 
simtoken_experiment.md ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SimToken 实验路线文档
2
+
3
+ ## 0. 当前状态
4
+
5
+ 前置诊断已经完成,路线收敛到 **A-min dynamic referent gate training**。
6
+
7
+ 已确认结论:
8
+
9
+ 1. **SAM decoder 下游是逐帧 batch-parallel 解码**
10
+ `mask_decoder(image_embeddings[0:T])[t]` 与 `mask_decoder(image_embeddings[t:t+1])[0]` 只有混合精度数值噪声差异。旧的 decoder-level joint-frame competition 假设关闭。
11
+
12
+ 2. **target_frame sweep 基本无效**
13
+ 不同 target frame 生成的 q 几乎相同,`cos_to_q5` 通常在 `0.997+`;Seen/Null 上 oracle gain 约 `+0.0009`。这条 TTO 路线关闭。
14
+
15
+ 3. **raw SAM-space D2 失效**
16
+ 256 维 `q/Fseg` 与 SAM image embedding 不在可直接 cosine 的语义空间,`real q ≈ shuffled/wrong_ref q`,甚至 random q 更高。该定义关闭。
17
+
18
+ 4. **LLM-space D2 有弱诊断信号,但不适合作为主 reward**
19
+ 用 4096 维 `[SEG]` hidden state 与 `mm_projector(CLIP patch tokens)` 后的视觉 token 计算 D2,可以得到正相关:
20
+ - `corr(s_pred, frame_iou) ≈ +0.316`
21
+ - bottom 20% `s_pred` 中 failure rate 相比随机 baseline 约 `1.60x`
22
+ - 控制 `iou_pred` / `pred_area` 后偏相关约 `+0.14`
23
+
24
+ 结论:`s_pred(beta=1.0)` 可以作为诊断信号或 frame-aware gate 的候选输入,但不能作为核心 TTO reward。
25
+
26
+ 5. **margin-D2 无效**
27
+ 离线 `s_margin = s(real) - max(s(shuffled), s(wrong_ref))` 的 failure enrichment 约 `0.93x`,会抵消掉有用的通用可见性/质量信号。该路线关闭。
28
+
29
+ 当前最干净的解释是:
30
+
31
+ > q 本身通常是稳定的 referent anchor;主要瓶颈不在 q 生成,也不在简单 q selection,而在 SAM decoder 如何使用已有的 `mask_token -> q` sparse self-attention path。
32
+
33
+ 2026-04-22 更新:
34
+
35
+ 完整训练每个 epoch 约 2-4 小时,瓶颈主要在 7B MLLM forward,而不在 gate 本身。因此当前实验策略已调整为:
36
+
37
+ 1. 先缓存固定 checkpoint 下的 `q = seg_embeddings`;
38
+ 2. 在 cached q + cached SAM image embeddings 上训练 gate-only;
39
+ 3. 用 cached eval split 快速判断 gate 是否有泛化收益;
40
+ 4. 只有 gate-only 泛化信号成立后,再跑完整 A-min 联合训练。
41
+
42
+ ---
43
+
44
+ ## 1. A-min 当前实现
45
+
46
+ 已在代码中加入 A-min dynamic referent gate:
47
+
48
+ - 文件:`models/segment_anything/modeling/transformer.py`
49
+ - 模块:`ReferentGate`
50
+ - 插入位置:`TwoWayAttentionBlock` 的 sparse self-attention + `norm1` 之后,token-to-image cross-attention 之前
51
+ - 作用对象:只作用于 `mask_tokens`
52
+ - 不作用于:`iou_token` 和 `q/sparse_prompt` 本身
53
+
54
+ SAM token index:
55
+
56
+ ```python
57
+ tokens = [iou_token, mask_tokens..., sparse_prompt(q)]
58
+ ```
59
+
60
+ 因此:
61
+
62
+ ```python
63
+ iou_token index: 0
64
+ mask token range: 1 : 1 + num_mask_tokens
65
+ q token index: 1 + num_mask_tokens
66
+ ```
67
+
68
+ A-min gate 形式:
69
+
70
+ ```python
71
+ alpha = sigmoid(Linear([mask_token, q, cos(mask_token, q)]))
72
+ mask_token = mask_token + alpha * Linear(q)
73
+ ```
74
+
75
+ 为保证旧 checkpoint 初始行为不变,`proj(q)` 分支使用零初始化。当前也将 `gate` 分支零初始化,使 alpha 有干净观测基线:
76
+
77
+ ```python
78
+ nn.init.zeros_(self.gate.weight)
79
+ nn.init.zeros_(self.gate.bias)
80
+ nn.init.zeros_(self.proj.weight)
81
+ nn.init.zeros_(self.proj.bias)
82
+ ```
83
+
84
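+ 初始时 gate 为 identity(`proj` 零初始化使注入项 `alpha * proj(q)` 恒为 0;`gate` 零初始化使 `alpha = sigmoid(0) = 0.5`,因此 `alpha_mean = 0.5` 并不代表有信号注入):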
85
+
86
+ ```text
87
+ max_abs_diff(gate(mask, q), mask) = 0.0
88
+ alpha_mean = 0.5
89
+ alpha_std = 0.0
90
+ ```
91
+
92
+ 当前训练 forward 保持完整链路:`prepare_inputs_labels_for_multimodal -> MLLM forward -> text_hidden_fcs -> SAM mask decoder -> loss`。`--gate_only` 只控制参数冻结范围,不再改变 forward 语义。
93
+
94
+ ---
95
+
96
+ ## 2. 当前新增工具
97
+
98
+ ### 2.1 训练脚本增强
99
+
100
+ `train.py` 已加入:
101
+
102
+ - `--max_steps`
103
+ - `--overfit_samples`
104
+ - `--log_gate_stats_every`
105
+ - `--skip_eval_after_train`
106
+ - `--eval_train_only`
107
+
108
+ 启动时会打印 referent gate 参数是否 trainable、是否进入 optimizer,以及初始 `proj_norm/gate_norm`。
109
+
110
+ ### 2.2 cached q 路线
111
+
112
+ 新增脚本:
113
+
114
+ - `cache_q_features.py`
115
+ - 离线缓存 `q = seg_embeddings`
116
+ - cache 文件很小,因为只保存 q 和少量 metadata
117
+ - `image_embeddings` 仍使用已有 `data/image_embed/{vid}.pt`
118
+ - `gt_masks` 仍使用已有 `data/gt_mask/...`
119
+
120
+ - `train_cached_gate.py`
121
+ - 加载 base model 和 cached q
122
+ - 冻结全部参数,只训练 `referent_gate`
123
+ - 支持 `--eval_only`、`--disable_gate`
124
+ - 支持 `--save_gate_only`,只保存 gate 参数,checkpoint 约 1.6MB
125
+ - 支持 `--gate_checkpoint`,在 base checkpoint 上 overlay gate-only checkpoint
126
+ - gate stats 会记录:
127
+
128
+ ```text
129
+ batch_miou
130
+ batch_fscore
131
+ proj_norm
132
+ gate_norm
133
+ proj_grad_norm
134
+ gate_grad_norm
135
+ alpha_mean / alpha_std / alpha_min / alpha_max
136
+ ```
137
+
138
+ cached 解码已优化:一个 dataloader batch 会展平成 paired frame batch 调用 `mask_decoder.forward_modified_v3`,避免逐 sample 调 decoder 的主要开销,同时不会产生 prompt/image cross product。
139
+
140
+ ---
141
+
142
+ ## 3. 已完成实验结果
143
+
144
+ ### 3.1 cached identity 与原始 pipeline 一致性
145
+
146
+ 先用 `test_s` 前 10 条验证 cached pipeline 是否与原始 `load_model.py` 对齐:
147
+
148
+ ```text
149
+ cached identity:
150
+ mIoU = 0.9686462879
151
+ Fscore = 0.9868578851
152
+
153
+ original load_model.py:
154
+ mIoU = 0.9686277151
155
+ Fscore = 0.9868472159
156
+
157
+ diff:
158
+ mIoU = +0.0000186
159
+ Fscore = +0.0000107
160
+ ```
161
+
162
+ 结论:差异远小于 0.001,cached q pipeline 与原始 eval pipeline 一致,可以用于 gate-only 快速验证。
163
+
164
+ ### 3.2 gate probe:梯度路径与 alpha 分化
165
+
166
+ 在 cached train128 上跑 50 optimizer steps:
167
+
168
+ ```text
169
+ step 5:
170
+ proj_norm=0.074015
171
+ gate_norm=0.064479
172
+ proj_grad_norm=0.052291
173
+ gate_grad_norm=0.000170
174
+ alpha_mean=0.4999
175
+ alpha_std=0.0019
176
+
177
+ step 50:
178
+ proj_norm=0.428711
179
+ gate_norm=0.523223
180
+ proj_grad_norm=0.022453
181
+ gate_grad_norm=0.000504
182
+ alpha_mean=0.5063
183
+ alpha_std=0.0112
184
+ ```
185
+
186
+ 结论:
187
+
188
+ - `proj_norm` 从 0 稳定增长,注入分支有梯度;
189
+ - `gate_norm` 也开始增长,alpha 控制分支参与学习;
190
+ - `alpha_std` 从 0 增长,说明 gate 对不同输入有分化响应;
191
+ - 计算图、冻结范围、optimizer param groups 均正常。
192
+
193
+ ### 3.3 overfit32:表达能力验证
194
+
195
+ cached train32 identity baseline:
196
+
197
+ ```text
198
+ mIoU = 0.8814558
199
+ Fscore = 0.9375512
200
+ ```
201
+
202
+ cached gate overfit32,200 steps,lr=1e-4:
203
+
204
+ ```text
205
+ mIoU = 0.9085821
206
+ Fscore = 0.9444574
207
+ ```
208
+
209
+ 提升:
210
+
211
+ ```text
212
+ mIoU = +0.0271263
213
+ Fscore = +0.0069063
214
+ ```
215
+
216
+ 结论:在 q、SAM image embeddings、mask decoder 原始参数均固定时,仅训练 A-min gate 就能明显提高训练集 mIoU,说明 gate 机制有表达能力,梯度路径通畅。
217
+
218
+ ### 3.4 overfit32 泛化评估
219
+
220
+ 对 cached eval split 前 200 条,identity baseline:
221
+
222
+ ```text
223
+ test_s mIoU = 0.7390979
224
+ test_s Fscore = 0.8190672
225
+
226
+ test_u mIoU = 0.6732285
227
+ test_u Fscore = 0.7734924
228
+
229
+ test_n metric = 0.0606105
230
+ ```
231
+
232
+ overfit32 gate checkpoint:
233
+
234
+ ```text
235
+ test_s mIoU = 0.7199481
236
+ test_s Fscore = 0.8045849
237
+
238
+ test_u mIoU = 0.6672303
239
+ test_u Fscore = 0.7663978
240
+
241
+ test_n metric = 0.0648588
242
+ ```
243
+
244
+ delta:
245
+
246
+ ```text
247
+ test_s mIoU = -0.0191498
248
+ test_s Fscore = -0.0144823
249
+
250
+ test_u mIoU = -0.0059983
251
+ test_u Fscore = -0.0070946
252
+
253
+ test_n metric = +0.0042483
254
+ ```
255
+
256
+ 结论:
257
+
258
+ - overfit32 gate 没有泛化;
259
+ - Null metric 略升,说明小样本过拟合有轻微放大前景的倾向;
260
+ - 这并不说明方法失败,而是符合预期:32 个样本不足以学到可泛化的 referent anchoring;
261
+ - 下一步应扩大 cached train 样本量,并降低 lr。
262
+
263
+ ---
264
+
265
+ ## 4. 当前下一步实验:cached train256 gate-only
266
+
267
+ 用户已经完成 train256 的 q 缓存。下一步用 train256 跑更保守的 gate-only 泛化实验。
268
+
269
+ ### Step 1:训练 cached gate-only train256
270
+
271
+ ```bash
272
+ cd /workspace/SimToken
273
+ mkdir -p log checkpoints
274
+
275
+ TRANSFORMERS_OFFLINE=1 python -u -W ignore train_cached_gate.py \
276
+ --cache_split train \
277
+ --cache_root /workspace/SimToken/cache_q \
278
+ --name cached_gate_train256_s300_lr3e5 \
279
+ --epochs 20 \
280
+ --max_steps 300 \
281
+ --batch_size 8 \
282
+ --lr 3e-5 \
283
+ --saved_model /workspace/SimToken/checkpoints/simtoken_pretrained.pth \
284
+ --log_root /workspace/SimToken/log \
285
+ --checkpoint_root /workspace/SimToken/checkpoints \
286
+ --log_gate_stats_every 50 \
287
+ --skip_eval_after_train \
288
+ --save_gate_only \
289
+ 2>&1 | tee /workspace/SimToken/log/cached_gate_train256_s300_lr3e5.stdout
290
+ ```
291
+
292
+ 训练中重点观察:
293
+
294
+ ```text
295
+ batch_miou / batch_fscore 是否逐步改善
296
+ proj_norm 是否持续增长
297
+ alpha_std 是否温和分化
298
+ Null 风险:alpha 是否出现极端偏移
299
+ ```
300
+
301
+ 如果 `proj_norm` 在前 100 steps 仍接近 0,说明 lr=3e-5 可能过小,可以改回 1e-4 或使用分层 lr。
302
+
303
+ ### Step 2:评估 cached train256 gate checkpoint
304
+
305
+ ```bash
306
+ for split in test_s test_u test_n; do
307
+ TRANSFORMERS_OFFLINE=1 python -u -W ignore train_cached_gate.py \
308
+ --cache_split $split \
309
+ --cache_root /workspace/SimToken/cache_q \
310
+ --batch_size 8 \
311
+ --saved_model /workspace/SimToken/checkpoints/simtoken_pretrained.pth \
312
+ --gate_checkpoint /workspace/SimToken/checkpoints/cached_gate_train256_s300_lr3e5.pth \
313
+ --eval_only \
314
+ --name cached_gate_train256_s300_lr3e5_${split}_200 \
315
+ 2>&1 | tee /workspace/SimToken/log/cached_gate_train256_s300_lr3e5_${split}_200.stdout
316
+ done
317
+ ```
318
+
319
+ 对比 baseline 使用 3.4 中 identity 前 200 条的结果;评估时需保证同样只覆盖对应 split 的前 200 条样本,否则指标不能直接对比。
320
+
321
+ ### Step 3:根据结果决策
322
+
323
+ 判断标准:
324
+
325
+ - Seen / Unseen 都提升:进入更大 cached train 或完整 A-min;
326
+ - Seen 提升、Unseen 不提升:gate 仍可能学 dataset pattern,需要更多 train cache 或更强正则;
327
+ - Seen / Unseen 都下降:不要跑完整 A-min,先调 lr、正则或 gate 容量;
328
+ - Null metric 保持 `< 0.07`:暂不加 area penalty;
329
+ - Null metric 超过 `0.10`:强危险信号,需要 area penalty 或约束预测面积。
330
+
331
+ 如果 train256 有弱正收益但幅度小,先看 alpha 分布和 hard/easy frames,而不是立刻扩大。若 alpha 在所有帧上几乎一致,可能只是全局偏置;若 hard frames alpha 系统性更高,说明更像 referent anchoring。
332
+
333
+ ---
334
+
335
+ ## 5. 成功标准
336
+
337
+ A-min 成功不能只看总体 mIoU,需要同时满足:
338
+
339
+ 1. Seen / Unseen mIoU 稳定提升;
340
+ 2. Unseen 至少不弱于 Seen 的提升趋势;
341
+ 3. Null 指标不恶化,预测面积不膨胀;
342
+ 4. hard frames 改善更明显;
343
+ 5. 如果记录 gate alpha,hard frames 的 alpha 应系统性高于 easy frames。
344
+
345
+ 失败解释:
346
+
347
+ - 如果 Seen 提升、Unseen 不提升:可能是 gate 学到数据集模式,而不是 referent anchoring;
348
+ - 如果 Null 恶化:gate 可能放大了通用前景显著性;
349
+ - 如果 gate-only 无变化但完整 A-min 有收益:说明 gate 需要与 mask decoder / text projection 协同适配;
350
+ - 如果全 split 下降:gate 插入位置、初始化或学习率需要重新检查。
351
+
352
+ ---
353
+
354
+ ## 6. 后续机制分析
355
+
356
+ 如果 A-min 有正收益,再做 hook 分析:
357
+
358
+ 1. sparse self-attention 中 `mask_token -> q`;
359
+ 2. token-to-image attention 中 mask token 对 image tokens 的关注;
360
+ 3. A-min 前后 hard/easy frames 的 gate alpha;
361
+ 4. `s_pred(beta=1.0)` 与 gate alpha 的关系。
362
+
363
+ 这部分用于论文解释,不作为当前阻塞项。
364
+
365
+ ---
366
+
367
+ ## 7. 当前一句话结论
368
+
369
+ > A-min gate 的梯度路径、表达能力和 cached pipeline 一致性已经通过验证;overfit32 能显著提升训练集但不能泛化。当前主线是用更大 cached train set(已完成 train256 cache)验证 gate-only 泛化,再决定是否投入完整 A-min 联合训练。
target_frame_sweep.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import os
3
+ import random
4
+ from functools import partial
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import transformers
10
+ from torch.utils.data import DataLoader
11
+
12
+ from configs import args
13
+ from datasets import REFAVS
14
+ from decoder_invariance_check import build_model, set_seed
15
+ from load_model import collate_fn, dict_to_cuda
16
+ from utils import utility
17
+
18
+
19
def decode_with_q(model, batch, q):
    """Decode masks for the first sample of ``batch`` using ``q`` as the prompt.

    ``q`` is passed to the visual model's prompt encoder as ``text_embeds``
    (with a new dimension inserted at index 1). The resulting sparse and dense
    prompts are cast to ``q.dtype`` and fed to the mask decoder together with
    the cached image embeddings ``batch["image_feats"][0]``.

    Returns:
        A tuple ``(pred_masks, iou_predictions)``: the post-processed masks
        with dimension 1 squeezed and a new leading dimension added, and the
        decoder's IoU predictions with the last dimension squeezed.
    """
    sam = model.get_model().visual_model
    cached_embeddings = batch["image_feats"][0]

    sparse_prompt, dense_prompt = sam.prompt_encoder(
        points=None,
        boxes=None,
        masks=None,
        text_embeds=q.unsqueeze(1),
    )
    # Align the prompt dtypes with q before they reach the decoder.
    sparse_prompt = sparse_prompt.to(q.dtype)
    dense_prompt = dense_prompt.to(q.dtype)

    low_res_masks, iou_predictions = sam.mask_decoder(
        image_embeddings=cached_embeddings,
        image_pe=sam.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_prompt,
        dense_prompt_embeddings=dense_prompt,
        multimask_output=False,
    )

    full_res_masks = sam.postprocess_masks(
        low_res_masks,
        input_size=batch["resizes"][0],
        original_size=batch["orgsizes"][0],
    )
    pred_masks = full_res_masks.squeeze(1)
    return pred_masks.unsqueeze(0), iou_predictions.squeeze(-1)
45
+
46
+
47
def get_q_for_target_frame(model, batch, target_frame):
    """Run the model in inference mode and return q for ``target_frame``.

    The full forward pass is executed under bfloat16 CUDA autocast with
    ``inference=True`` and the requested ``target_frame``. The returned value
    is the first row (kept as a length-1 slice) of the first entry in the
    output's ``"seg_embeddings"``.
    """
    forward_kwargs = {
        "images": batch["images"],
        "images_clip": batch["images_clip"],
        "audio_features": batch["audio_feats"],
        "image_features": batch["image_feats"],
        "input_ids": batch["input_ids"],
        "labels": batch["labels"],
        "attention_masks": batch["attention_masks"],
        "masks_list": batch["masks"],
        "resize_list": batch["resizes"],
        "orgsize_list": batch["orgsizes"],
        "conversation_list": batch["convs"],
        "refs_num": batch["refs_num"],
        "fids": batch["fids"],
        "vids": batch["vids"],
        "contrast": args.ct_weight,
        "ref_ids": batch["ref_ids"],
        "inference": True,
        "target_frame": target_frame,
    }
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        output = model.forward(**forward_kwargs)

    first_sample_embeddings = output["seg_embeddings"][0]
    # Slice (rather than index) so the leading dimension is preserved.
    return first_sample_embeddings[:1]
70
+
71
+
72
def mask_area(pred_masks):
    """Return the fraction of elements whose sigmoid probability exceeds 0.4.

    ``pred_masks`` holds mask logits; the result is a Python float in [0, 1].
    """
    probabilities = torch.sigmoid(pred_masks)
    foreground = (probabilities > 0.4).float()
    return foreground.mean().item()
74
+
75
+
76
def mean_mask_iou_to_others(mask, other_masks):
    """Return the mean IoU between ``mask`` and each mask in ``other_masks``.

    All inputs are mask logits; each is binarized by thresholding its sigmoid
    at 0.4 before the IoU is computed. When ``other_masks`` is empty the
    function returns ``1.0``.
    """
    if not other_masks:
        return 1.0

    def _binarize(logits):
        return (torch.sigmoid(logits) > 0.4).float()

    reference = _binarize(mask)

    def _iou_with_reference(candidate):
        intersection = (reference * candidate).sum()
        union = torch.maximum(reference, candidate).sum()
        # The small epsilon keeps the ratio finite when both masks are empty.
        return (intersection / (union + 1e-7)).item()

    ious = [_iou_with_reference(_binarize(m)) for m in other_masks]
    return float(np.mean(ious))
87
+
88
+
89
def evaluate_one_sample(model, batch, sample_idx):
    """Sweep every target frame for one sample and collect per-frame metrics.

    For each ``target_frame`` in ``range(args.frame_n)`` this computes q via
    ``get_q_for_target_frame``, decodes masks with ``decode_with_q`` and
    records segmentation metrics. After the sweep, each row is completed with:

    * ``cos_to_q5``: cosine similarity between this frame's q and the q at
      index ``min(5, n_frames - 1)``;
    * ``mean_cos_to_other_q``: mean cosine similarity to the q of every other
      target frame;
    * ``mean_mask_iou_to_other_tf``: mean IoU between this frame's predicted
      masks and those of every other target frame.

    Args:
        model: Model exposing ``forward`` and ``get_model().visual_model``.
        batch: Collated batch already moved to the GPU; only its first sample
            is evaluated.
        sample_idx: Index stored in each row to identify the sample.

    Returns:
        A list with one dict per target frame. The key order defines the CSV
        column order used by ``main``.
    """
    rows = []
    qs = []
    pred_masks_by_tf = []

    gt_masks = batch["masks"][0]
    vid = batch["vids"][0]
    ref = batch["refs"][0][0]

    for target_frame in range(args.frame_n):
        q = get_q_for_target_frame(model, batch, target_frame)
        qs.append(q.float().squeeze(0))

        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            pred_masks, iou_predictions = decode_with_q(model, batch, q)
        pred_masks_by_tf.append(pred_masks.detach())

        # Coerce every metric to a plain Python float, as was already done for
        # null_metric. NOTE(review): the utility helpers may return 0-dim
        # tensors (possibly on the GPU) -- confirm; float() is a no-op for
        # values that are already floats and keeps the CSV output and the
        # np.mean aggregation in main() working either way.
        miou = float(utility.mask_iou(pred_masks.float(), gt_masks.float()))
        fscore = float(utility.Eval_Fmeasure(pred_masks.float(), gt_masks.float(), None))
        null_metric = float(utility.metric_s_for_null(pred_masks.float()))
        area = mask_area(pred_masks)
        mean_iou_pred = iou_predictions.float().mean().item()

        rows.append(
            {
                "sample_idx": sample_idx,
                "vid": vid,
                "ref": ref,
                "target_frame": target_frame,
                "mean_iou_pred": mean_iou_pred,
                "mask_area": area,
                "null_metric": null_metric,
                "miou": miou,
                "fscore": fscore,
                # Placeholders, filled in once all frames have been swept.
                "cos_to_q5": 0.0,
                "mean_cos_to_other_q": 0.0,
                "mean_mask_iou_to_other_tf": 0.0,
            }
        )

    # Pairwise cosine similarity between the q vectors of all target frames.
    q_stack = F.normalize(torch.stack(qs, dim=0), dim=-1)
    q_cos = q_stack @ q_stack.T
    q5_idx = min(5, len(qs) - 1)

    for i, row in enumerate(rows):
        other = [j for j in range(len(rows)) if j != i]
        row["cos_to_q5"] = q_cos[i, q5_idx].item()
        # With a single frame there are no peers; the mean of an empty
        # selection would be NaN, so use 1.0 to match the convention of
        # mean_mask_iou_to_others for an empty peer list.
        row["mean_cos_to_other_q"] = q_cos[i, other].mean().item() if other else 1.0
        row["mean_mask_iou_to_other_tf"] = mean_mask_iou_to_others(
            pred_masks_by_tf[i], [pred_masks_by_tf[j] for j in other]
        )

    return rows
142
+
143
+
144
def print_sample_summary(rows):
    """Print a per-target-frame table and a short summary for one sample.

    Args:
        rows: The list of row dicts produced for a single sample, one per
            target frame. An empty list prints nothing.

    The summary reports the frame with the best ``miou``, the frame with the
    best ``mean_iou_pred``, the "fixed" reference frame (index
    ``min(5, len(rows) - 1)``), the spread of ``miou`` across frames and the
    minimum ``cos_to_q5``.
    """
    if not rows:
        # Nothing to summarize; avoid indexing rows[0] below.
        return

    header = rows[0]
    print(f"\nSample {header['sample_idx']}: vid={header['vid']} ref={header['ref']}")
    print("tf | miou | fscore | null_s | iou_pred | area | cos_to_q5 | mean_q_cos")
    for row in rows:
        print(
            f"{row['target_frame']:02d} | "
            f"{row['miou']:.4f} | "
            f"{row['fscore']:.4f} | "
            f"{row['null_metric']:.4f} | "
            f"{row['mean_iou_pred']:.4f} | "
            f"{row['mask_area']:.4f} | "
            f"{row['cos_to_q5']:.4f} | "
            f"{row['mean_cos_to_other_q']:.4f}"
        )

    best_miou = max(rows, key=lambda x: x["miou"])
    best_iou_pred = max(rows, key=lambda x: x["mean_iou_pred"])
    fixed = rows[min(5, len(rows) - 1)]
    miou_values = [row["miou"] for row in rows]
    q5_values = [row["cos_to_q5"] for row in rows]
    print(
        "Best miou tf="
        f"{best_miou['target_frame']} ({best_miou['miou']:.4f}); "
        "best iou_pred tf="
        f"{best_iou_pred['target_frame']} ({best_iou_pred['mean_iou_pred']:.4f}); "
        # Report the frame actually used: the index is clamped when fewer
        # than six frames exist, so a literal "5" would be wrong there.
        f"fixed tf={fixed['target_frame']} miou={fixed['miou']:.4f}"
    )
    print(
        f"target-frame miou range={max(miou_values) - min(miou_values):.4f}; "
        f"min cos_to_q5={min(q5_values):.4f}"
    )
175
+
176
+
177
def main():
    """Run the target-frame sweep on ``args.eval_split`` and print a summary.

    Evaluates the first ``args.max_eval_rows`` samples (one sample when that
    value is not positive), prints a per-sample table, then prints aggregate
    statistics comparing the fixed reference frame with the oracle frame
    (best ``miou``) and the frame selected by the highest ``mean_iou_pred``.
    When the ``TARGET_FRAME_SWEEP_CSV`` environment variable is set, all rows
    are also written to that CSV path.

    Raises:
        RuntimeError: If no rows were produced.
    """
    set_seed(42)
    torch.set_grad_enabled(False)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.mllm,
        cache_dir=None,
        model_max_length=2048,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]

    dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
    )

    limit = args.max_eval_rows if args.max_eval_rows > 0 else 1
    print(f"Split: {args.eval_split} | samples to sweep: {limit}")

    model = build_model(tokenizer, seg_token_idx)

    all_rows = []
    for sample_idx, batch in enumerate(loader):
        if sample_idx >= limit:
            break
        batch = dict_to_cuda(batch)
        rows = evaluate_one_sample(model, batch, sample_idx)
        all_rows.extend(rows)
        print_sample_summary(rows)

    if not all_rows:
        raise RuntimeError("No rows were checked. Is the selected split empty?")

    # The reference frame index is clamped when fewer than six frames exist.
    fixed_tf = min(5, args.frame_n - 1)

    # Group rows per sample once instead of rescanning all_rows for every
    # sample; dict insertion order keeps samples in evaluation order.
    rows_by_sample = {}
    for row in all_rows:
        rows_by_sample.setdefault(row["sample_idx"], []).append(row)
    per_sample_rows = list(rows_by_sample.values())

    fixed_rows = [r for r in all_rows if r["target_frame"] == fixed_tf]
    # max() keeps the first maximal row, i.e. ties go to the earliest frame.
    oracle_rows = [max(rows, key=lambda r: r["miou"]) for rows in per_sample_rows]
    iou_pred_rows = [max(rows, key=lambda r: r["mean_iou_pred"]) for rows in per_sample_rows]

    fixed_miou = np.mean([r["miou"] for r in fixed_rows])
    fixed_null_metric = np.mean([r["null_metric"] for r in fixed_rows])
    oracle_miou = np.mean([r["miou"] for r in oracle_rows])
    iou_pred_selected_miou = np.mean([r["miou"] for r in iou_pred_rows])
    min_cos_to_q5 = np.mean([min(r["cos_to_q5"] for r in rows) for rows in per_sample_rows])
    mean_miou_range = np.mean(
        [
            max(r["miou"] for r in rows) - min(r["miou"] for r in rows)
            for rows in per_sample_rows
        ]
    )

    print("\nSummary")
    print(f"samples: {len(fixed_rows)}")
    print(f"fixed target_frame={fixed_tf} mean miou: {fixed_miou:.4f}")
    print(f"fixed target_frame={fixed_tf} mean null_s: {fixed_null_metric:.4f}")
    print(f"oracle best-target-frame mean miou: {oracle_miou:.4f}")
    print(f"best-by-iou_pred selected mean miou: {iou_pred_selected_miou:.4f}")
    print(f"oracle gain over fixed: {oracle_miou - fixed_miou:+.4f}")
    print(f"iou_pred-selection gain over fixed: {iou_pred_selected_miou - fixed_miou:+.4f}")
    print(f"mean target-frame miou range: {mean_miou_range:.4f}")
    print(f"mean sample min cos_to_q5: {min_cos_to_q5:.4f}")

    csv_path = os.environ.get("TARGET_FRAME_SWEEP_CSV")
    if csv_path:
        # abspath guarantees a non-empty directory even for a bare filename.
        os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
        with open(csv_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=list(all_rows[0].keys()))
            writer.writeheader()
            writer.writerows(all_rows)
        print(f"Saved CSV: {csv_path}")
262
+
263
+
264
+ if __name__ == "__main__":
265
+ main()
train.py CHANGED
@@ -1,7 +1,7 @@
1
  import transformers
2
  from datasets import REFAVS
3
  from configs import args
4
- from torch.utils.data import DataLoader
5
  from functools import partial
6
  from models.llava import conversation as conversation_lib
7
  # from models.avs_model import VISAForCausalLM
@@ -21,6 +21,7 @@ import numpy as np
21
  import re
22
  import time
23
  import os
 
24
 
25
 
26
  import warnings
@@ -235,11 +236,19 @@ if __name__ == "__main__":
235
  val_dataset_u_refer = REFAVS('test_u', args, tokenizer, input_type='refer')
236
  val_dataset_n_refer = REFAVS('test_n', args, tokenizer, input_type='refer')
237
 
 
 
 
 
 
 
 
238
 
239
  g = torch.Generator()
240
  g.manual_seed(42)
241
 
242
  train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, worker_init_fn=seed_worker,collate_fn=partial(collate_fn, tokenizer=tokenizer), generator=g)
 
243
 
244
  val_dataloader_s_refer = DataLoader(val_dataset_s_refer, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer))
245
  val_dataloader_u_refer = DataLoader(val_dataset_u_refer, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer))
@@ -349,6 +358,11 @@ if __name__ == "__main__":
349
  model = model.to("cuda")
350
  model.resize_token_embeddings(len(tokenizer))
351
 
 
 
 
 
 
352
 
353
  for name, param in model.audio_feature_layer.named_parameters():
354
  param.requires_grad = True
@@ -366,9 +380,113 @@ if __name__ == "__main__":
366
  ):
367
  p.requires_grad = True
368
 
 
 
 
 
 
 
 
 
 
369
 
370
  print("will save train model")
371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  def valuate(model, dataloader, args, name):
373
  model.eval()
374
 
@@ -420,11 +538,17 @@ if __name__ == "__main__":
420
  epochs = args.epochs
421
  print("init lr:", args.lr)
422
  optimizer = AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)
 
423
 
424
- gradient_accumulation_steps = int(16 // args.batch_size)
425
- step_per_epoch = len(train_dataloader) // gradient_accumulation_steps
426
- total_steps = epochs * step_per_epoch
 
427
  warmup_steps = int(total_steps * 0.1)
 
 
 
 
428
 
429
  scheduler = get_cosine_schedule_with_warmup(
430
  optimizer,
@@ -433,6 +557,9 @@ if __name__ == "__main__":
433
  )
434
 
435
 
 
 
 
436
  for epoch in range(epochs):
437
 
438
  model.train()
@@ -441,6 +568,9 @@ if __name__ == "__main__":
441
 
442
  loop = tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}/{epochs}")
443
  for step, batch in enumerate(loop):
 
 
 
444
  input_dict = dict_to_cuda(batch)
445
  output_dict = model.forward(images=input_dict["images"],
446
  images_clip=input_dict["images_clip"],
@@ -459,6 +589,7 @@ if __name__ == "__main__":
459
  contrast=args.ct_weight,
460
  ref_ids=input_dict["ref_ids"],
461
  epoch=epoch,
 
462
  inference=False)
463
 
464
  loss = output_dict["loss"]
@@ -468,6 +599,15 @@ if __name__ == "__main__":
468
 
469
 
470
  if (step + 1) % gradient_accumulation_steps == 0:
 
 
 
 
 
 
 
 
 
471
  optimizer.step()
472
  scheduler.step()
473
  optimizer.zero_grad()
@@ -475,16 +615,33 @@ if __name__ == "__main__":
475
  current_lr = scheduler.get_lr()[0]
476
  loop.set_postfix(lr=current_lr, loss=running_loss / ((step + 1) / gradient_accumulation_steps))
477
 
478
- print(f" Epoch {epoch + 1}, Loss:{running_loss / ((step + 1) / gradient_accumulation_steps) :.4f}, Learning Rate:{scheduler.get_last_lr()[0]:.6f}")
 
 
 
 
 
479
 
480
 
481
  with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
482
- f.write(f"Epoch {epoch}: running_loss {running_loss / len(train_dataloader) * gradient_accumulation_steps} Learning Rate:{scheduler.get_last_lr()[0]:.6f}\n")
 
 
 
 
483
 
484
 
485
  torch.save(model.state_dict(), os.path.join(args.checkpoint_root, f"{args.name}.pth"))
486
  print(f"trained model saved as {args.name}.pth")
487
 
 
 
 
 
 
 
 
 
488
  # ---------------test on seen & unseen ------------------------------------------
489
  model.eval()
490
 
@@ -531,4 +688,4 @@ if __name__ == "__main__":
531
  print(f"\n valuate on test_n_refer, metric: {total_metric/count}")
532
 
533
  with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
534
- f.write(f"\n valuate on test_n_refer: metric {total_metric/count} \n")
 
1
  import transformers
2
  from datasets import REFAVS
3
  from configs import args
4
+ from torch.utils.data import DataLoader, Subset
5
  from functools import partial
6
  from models.llava import conversation as conversation_lib
7
  # from models.avs_model import VISAForCausalLM
 
21
  import re
22
  import time
23
  import os
24
+ import sys
25
 
26
 
27
  import warnings
 
236
  val_dataset_u_refer = REFAVS('test_u', args, tokenizer, input_type='refer')
237
  val_dataset_n_refer = REFAVS('test_n', args, tokenizer, input_type='refer')
238
 
239
+ if args.overfit_samples > 0:
240
+ overfit_n = min(args.overfit_samples, len(train_dataset))
241
+ train_dataset = Subset(train_dataset, list(range(overfit_n)))
242
+ print(f"overfit_samples enabled: using first {overfit_n} train samples")
243
+
244
+ train_eval_dataset = train_dataset
245
+
246
 
247
  g = torch.Generator()
248
  g.manual_seed(42)
249
 
250
  train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, worker_init_fn=seed_worker,collate_fn=partial(collate_fn, tokenizer=tokenizer), generator=g)
251
+ train_eval_dataloader = DataLoader(train_eval_dataset, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer))
252
 
253
  val_dataloader_s_refer = DataLoader(val_dataset_s_refer, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer))
254
  val_dataloader_u_refer = DataLoader(val_dataset_u_refer, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer))
 
358
  model = model.to("cuda")
359
  model.resize_token_embeddings(len(tokenizer))
360
 
361
+ if args.init_from_saved_model or args.gate_only:
362
+ state = torch.load(args.saved_model, map_location="cpu")
363
+ missing, unexpected = model.load_state_dict(state, strict=False)
364
+ print(f"initialized training from saved model: {args.saved_model}")
365
+ print(f"missing keys: {len(missing)} | unexpected keys: {len(unexpected)}")
366
 
367
  for name, param in model.audio_feature_layer.named_parameters():
368
  param.requires_grad = True
 
380
  ):
381
  p.requires_grad = True
382
 
383
+ if args.gate_only:
384
+ for p in model.parameters():
385
+ p.requires_grad = False
386
+ for n, p in model.named_parameters():
387
+ if "referent_gate" in n:
388
+ p.requires_grad = True
389
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
390
+ total = sum(p.numel() for p in model.parameters())
391
+ print(f"gate_only enabled: trainable params {trainable} / {total}")
392
 
393
  print("will save train model")
394
 
395
def _total_norm(values):
    """Combine individual norms into a single overall L2 norm.

    Returns ``0.0`` when ``values`` is empty; otherwise the square root of the
    sum of the squared entries, as a Python float.
    """
    squared = [v * v for v in values]
    return float(sum(squared) ** 0.5) if squared else 0.0
399
+
400
def collect_referent_gate_stats(model):
    """Gather weight, gradient and alpha statistics over all referent gates.

    A gate is any submodule whose qualified name ends with ``"referent_gate"``.
    Only the ``weight`` tensors of each gate's ``proj`` and ``gate`` layers are
    measured. Gradient norms skip layers whose gradient is ``None``, and alpha
    statistics use each gate's ``last_alpha`` when it is not ``None``.

    Returns:
        A dict with the number of gate modules, the combined weight and
        gradient norms, and the mean/std/min/max of all recorded alpha values
        (NaN when no alpha has been recorded).
    """
    gates = [module for name, module in model.named_modules() if name.endswith("referent_gate")]

    def _norm(tensor):
        return tensor.detach().float().norm().item()

    proj_norms = [_norm(g.proj.weight) for g in gates]
    gate_norms = [_norm(g.gate.weight) for g in gates]
    proj_grad_norms = [_norm(g.proj.weight.grad) for g in gates if g.proj.weight.grad is not None]
    gate_grad_norms = [_norm(g.gate.weight.grad) for g in gates if g.gate.weight.grad is not None]
    alpha_chunks = [
        g.last_alpha.detach().float().reshape(-1) for g in gates if g.last_alpha is not None
    ]

    stats = {
        "modules": len(gates),
        "proj_norm": _total_norm(proj_norms),
        "gate_norm": _total_norm(gate_norms),
        "proj_grad_norm": _total_norm(proj_grad_norms),
        "gate_grad_norm": _total_norm(gate_grad_norms),
    }

    if alpha_chunks:
        alpha = torch.cat(alpha_chunks)
        alpha_summary = (
            alpha.mean().item(),
            # Population standard deviation over all recorded alpha values.
            alpha.std(unbiased=False).item(),
            alpha.min().item(),
            alpha.max().item(),
        )
    else:
        alpha_summary = (float("nan"),) * 4

    stats.update(zip(("alpha_mean", "alpha_std", "alpha_min", "alpha_max"), alpha_summary))
    return stats
447
+
448
def print_referent_gate_optimizer_sanity(model, optimizer):
    """Print how many referent-gate parameters are trainable and optimized.

    Gate parameters are those whose name contains ``"referent_gate"``. The
    first line reports element counts for: all gate parameters, those with
    ``requires_grad``, those present in any optimizer param group, and those
    satisfying both. The second line reports the initial gate statistics.
    """
    optimized_ids = set()
    for group in optimizer.param_groups:
        optimized_ids.update(id(p) for p in group["params"])

    gate_params = [p for name, p in model.named_parameters() if "referent_gate" in name]

    def _numel(params):
        return sum(p.numel() for p in params)

    n_params = _numel(gate_params)
    n_trainable = _numel(p for p in gate_params if p.requires_grad)
    n_in_optimizer = _numel(p for p in gate_params if id(p) in optimized_ids)
    n_both = _numel(p for p in gate_params if p.requires_grad and id(p) in optimized_ids)
    print(
        "referent_gate sanity: "
        f"params={n_params} | "
        f"trainable={n_trainable} | "
        f"in_optimizer={n_in_optimizer} | "
        f"trainable_in_optimizer={n_both}"
    )

    init_stats = collect_referent_gate_stats(model)
    print(
        "referent_gate init stats: "
        f"modules={init_stats['modules']} | "
        f"proj_norm={init_stats['proj_norm']:.6f} | "
        f"gate_norm={init_stats['gate_norm']:.6f}"
    )
471
+
472
def log_referent_gate_stats(global_step, loss_value):
    """Print the current referent-gate statistics and append them to the log.

    The statistics are collected from the enclosing scope's ``model`` and the
    line is appended to ``<args.log_root>/<args.name>.txt``.
    """
    stats = collect_referent_gate_stats(model)
    fields = [
        f"gate_stats step={global_step}",
        f"loss={loss_value:.6f}",
        f"proj_norm={stats['proj_norm']:.6f}",
        f"gate_norm={stats['gate_norm']:.6f}",
        f"proj_grad_norm={stats['proj_grad_norm']:.6f}",
        f"gate_grad_norm={stats['gate_grad_norm']:.6f}",
        f"alpha_mean={stats['alpha_mean']:.4f}",
        f"alpha_std={stats['alpha_std']:.4f}",
        f"alpha_min={stats['alpha_min']:.4f}",
        f"alpha_max={stats['alpha_max']:.4f}",
    ]
    message = " ".join(fields)
    print(message)

    log_path = os.path.join(args.log_root, f'{args.name}.txt')
    with open(log_path, "a") as f:
        f.write(message + "\n")
489
+
490
  def valuate(model, dataloader, args, name):
491
  model.eval()
492
 
 
538
  epochs = args.epochs
539
  print("init lr:", args.lr)
540
  optimizer = AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)
541
+ print_referent_gate_optimizer_sanity(model, optimizer)
542
 
543
+ gradient_accumulation_steps = max(1, int(16 // args.batch_size))
544
+ step_per_epoch = max(1, len(train_dataloader) // gradient_accumulation_steps)
545
+ full_total_steps = epochs * step_per_epoch
546
+ total_steps = min(args.max_steps, full_total_steps) if args.max_steps > 0 else full_total_steps
547
  warmup_steps = int(total_steps * 0.1)
548
+ print(
549
+ f"training schedule: grad_accum={gradient_accumulation_steps} | "
550
+ f"step_per_epoch={step_per_epoch} | total_optimizer_steps={total_steps}"
551
+ )
552
 
553
  scheduler = get_cosine_schedule_with_warmup(
554
  optimizer,
 
557
  )
558
 
559
 
560
+ optimizer_step_count = 0
561
+ stop_training = False
562
+
563
  for epoch in range(epochs):
564
 
565
  model.train()
 
568
 
569
  loop = tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}/{epochs}")
570
  for step, batch in enumerate(loop):
571
+ if args.max_steps > 0 and optimizer_step_count >= args.max_steps:
572
+ stop_training = True
573
+ break
574
  input_dict = dict_to_cuda(batch)
575
  output_dict = model.forward(images=input_dict["images"],
576
  images_clip=input_dict["images_clip"],
 
589
  contrast=args.ct_weight,
590
  ref_ids=input_dict["ref_ids"],
591
  epoch=epoch,
592
+ gate_only=args.gate_only,
593
  inference=False)
594
 
595
  loss = output_dict["loss"]
 
599
 
600
 
601
  if (step + 1) % gradient_accumulation_steps == 0:
602
+ optimizer_step_count += 1
603
+ if (
604
+ args.log_gate_stats_every > 0
605
+ and optimizer_step_count % args.log_gate_stats_every == 0
606
+ ):
607
+ log_referent_gate_stats(
608
+ optimizer_step_count,
609
+ loss.item() * gradient_accumulation_steps,
610
+ )
611
  optimizer.step()
612
  scheduler.step()
613
  optimizer.zero_grad()
 
615
  current_lr = scheduler.get_lr()[0]
616
  loop.set_postfix(lr=current_lr, loss=running_loss / ((step + 1) / gradient_accumulation_steps))
617
 
618
+ if args.max_steps > 0 and optimizer_step_count >= args.max_steps:
619
+ stop_training = True
620
+ break
621
+
622
+ denom = max(1, optimizer_step_count)
623
+ print(f" Epoch {epoch + 1}, Loss:{running_loss / denom :.4f}, Learning Rate:{scheduler.get_last_lr()[0]:.6f}")
624
 
625
 
626
  with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
627
+ f.write(f"Epoch {epoch}: running_loss {running_loss / denom} Learning Rate:{scheduler.get_last_lr()[0]:.6f}\n")
628
+
629
+ if stop_training:
630
+ print(f"stopped early at optimizer step {optimizer_step_count}")
631
+ break
632
 
633
 
634
  torch.save(model.state_dict(), os.path.join(args.checkpoint_root, f"{args.name}.pth"))
635
  print(f"trained model saved as {args.name}.pth")
636
 
637
+ if args.skip_eval_after_train:
638
+ print("skip_eval_after_train enabled: exiting after checkpoint save")
639
+ sys.exit(0)
640
+
641
+ if args.eval_train_only:
642
+ valuate(model, train_eval_dataloader, args, 'train_overfit')
643
+ sys.exit(0)
644
+
645
  # ---------------test on seen & unseen ------------------------------------------
646
  model.eval()
647
 
 
688
  print(f"\n valuate on test_n_refer, metric: {total_metric/count}")
689
 
690
  with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
691
+ f.write(f"\n valuate on test_n_refer: metric {total_metric/count} \n")
train_cached_gate.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import transformers
9
+ from torch.optim import AdamW
10
+ from torch.utils.data import DataLoader, Dataset, Subset
11
+ from tqdm import tqdm
12
+
13
+ from configs import args
14
+ from decoder_invariance_check import build_model, set_seed
15
+ from models.avs_model import dice_loss, sigmoid_ce_loss
16
+ from utils import utility
17
+
18
+
19
+ def _total_norm(values):
20
+ if not values:
21
+ return 0.0
22
+ return float(sum(v * v for v in values) ** 0.5)
23
+
24
+
25
def collect_referent_gate_stats(model):
    """Summarize weight, gradient and alpha statistics of the referent gates.

    Every submodule whose qualified name ends with ``"referent_gate"`` is
    inspected; each is expected to expose ``proj`` and ``gate`` layers with a
    ``weight`` tensor and a ``last_alpha`` attribute (tensor or ``None``).

    Args:
        model: Module tree to scan via ``named_modules()``.

    Returns:
        A dict with the number of gate modules, the ``_total_norm`` of the
        per-module weight norms and gradient norms for ``proj`` and ``gate``,
        and mean/std/min/max over all cached ``last_alpha`` values (``nan``
        when no module has a cached alpha).
    """

    def tensor_norm(tensor):
        # Norms are computed in float32 on a detached view.
        return tensor.detach().float().norm().item()

    gates = [
        module
        for name, module in model.named_modules()
        if name.endswith("referent_gate")
    ]

    proj_norms, gate_norms = [], []
    proj_grad_norms, gate_grad_norms = [], []
    alpha_chunks = []

    for gate_module in gates:
        proj_weight = gate_module.proj.weight
        gate_weight = gate_module.gate.weight
        proj_norms.append(tensor_norm(proj_weight))
        gate_norms.append(tensor_norm(gate_weight))
        # Gradients are absent before the first backward pass.
        if proj_weight.grad is not None:
            proj_grad_norms.append(tensor_norm(proj_weight.grad))
        if gate_weight.grad is not None:
            gate_grad_norms.append(tensor_norm(gate_weight.grad))
        if gate_module.last_alpha is not None:
            alpha_chunks.append(gate_module.last_alpha.detach().float().reshape(-1))

    stats = {
        "modules": len(gates),
        "proj_norm": _total_norm(proj_norms),
        "gate_norm": _total_norm(gate_norms),
        "proj_grad_norm": _total_norm(proj_grad_norms),
        "gate_grad_norm": _total_norm(gate_grad_norms),
    }

    if alpha_chunks:
        alpha = torch.cat(alpha_chunks)
        stats["alpha_mean"] = alpha.mean().item()
        stats["alpha_std"] = alpha.std(unbiased=False).item()
        stats["alpha_min"] = alpha.min().item()
        stats["alpha_max"] = alpha.max().item()
    else:
        for key in ("alpha_mean", "alpha_std", "alpha_min", "alpha_max"):
            stats[key] = float("nan")

    return stats
72
+
73
+
74
def zero_referent_gate(model):
    """Zero the parameters of every referent gate in ``model`` in place.

    For each submodule whose qualified name ends with ``"referent_gate"``,
    the weights (and biases, when present) of its ``gate`` and ``proj``
    layers are set to zero and its cached ``last_alpha`` is cleared.

    Args:
        model: Module tree to scan via ``named_modules()``.
    """
    with torch.no_grad():
        for name, module in model.named_modules():
            if not name.endswith("referent_gate"):
                continue
            for layer in (module.gate, module.proj):
                layer.weight.zero_()
                # Layers built with ``bias=False`` expose ``bias`` as None.
                bias = getattr(layer, "bias", None)
                if bias is not None:
                    bias.zero_()
            module.last_alpha = None
84
+
85
+
86
def referent_gate_state_dict(model):
    """Extract the referent-gate entries of ``model.state_dict()``.

    Args:
        model: Module whose state dict is filtered.

    Returns:
        A dict mapping every state-dict key containing ``"referent_gate"`` to
        its tensor, detached and moved to the CPU.
    """
    gate_state = {}
    for key, tensor in model.state_dict().items():
        if "referent_gate" not in key:
            continue
        gate_state[key] = tensor.detach().cpu()
    return gate_state
92
+
93
+
94
def load_referent_gate_checkpoint(model, path):
    """Load only the referent-gate tensors from a checkpoint into ``model``.

    The checkpoint may be either a plain state dict or a wrapper dict whose
    ``"type"`` is ``"referent_gate_only"`` and whose ``"state_dict"`` holds
    the tensors. Keys that do not contain ``"referent_gate"`` are ignored.

    Args:
        model: Module that receives the gate tensors.
        path: Checkpoint file readable by ``torch.load``.

    Raises:
        RuntimeError: If the checkpoint has no referent-gate tensors, or if
            any of them is missing from ``model`` or has a different shape.
    """
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    loaded = torch.load(path, map_location="cpu")
    is_gate_only = isinstance(loaded, dict) and loaded.get("type") == "referent_gate_only"
    source_state = loaded["state_dict"] if is_gate_only else loaded

    gate_state = {}
    for key, tensor in source_state.items():
        if "referent_gate" in key:
            gate_state[key] = tensor
    if not gate_state:
        raise RuntimeError(f"No referent_gate parameters found in {path}")

    model_state = model.state_dict()
    incompatible = []
    for key, tensor in gate_state.items():
        if key in model_state and tuple(model_state[key].shape) == tuple(tensor.shape):
            continue
        incompatible.append(key)
    if incompatible:
        raise RuntimeError(f"Gate checkpoint has incompatible keys: {incompatible[:5]}")

    # Overlay the gate tensors on the current state so a strict load succeeds.
    model_state.update(gate_state)
    model.load_state_dict(model_state, strict=True)
    print(f"loaded referent gate checkpoint: {path} ({len(gate_state)} tensors)")
112
+
113
+
114
def log_gate_stats(model, step, loss_value, batch_metrics=None):
    """Print one line of referent-gate statistics and append it to the log.

    Args:
        model: Module passed to ``collect_referent_gate_stats``.
        step: Step counter written into the message.
        loss_value: Scalar loss to report.
        batch_metrics: Optional dict with ``"miou"`` and ``"fscore"`` entries;
            when given, both are included in the message.

    The line is appended to ``<args.log_root>/<args.name>.txt``; the log
    directory is created if it does not exist.
    """
    stats = collect_referent_gate_stats(model)

    pieces = [
        f"gate_stats step={step} ",
        f"loss={loss_value:.6f} ",
    ]
    if batch_metrics is not None:
        pieces.append(f"batch_miou={batch_metrics['miou']:.4f} ")
        pieces.append(f"batch_fscore={batch_metrics['fscore']:.4f} ")
    pieces.extend(
        [
            f"proj_norm={stats['proj_norm']:.6f} ",
            f"gate_norm={stats['gate_norm']:.6f} ",
            f"proj_grad_norm={stats['proj_grad_norm']:.6f} ",
            f"gate_grad_norm={stats['gate_grad_norm']:.6f} ",
            f"alpha_mean={stats['alpha_mean']:.4f} ",
            f"alpha_std={stats['alpha_std']:.4f} ",
            f"alpha_min={stats['alpha_min']:.4f} ",
            f"alpha_max={stats['alpha_max']:.4f}",
        ]
    )
    message = "".join(pieces)

    print(message)
    os.makedirs(args.log_root, exist_ok=True)
    log_path = os.path.join(args.log_root, f"{args.name}.txt")
    with open(log_path, "a") as log_file:
        log_file.write(message + "\n")
139
+
140
+
141
class CachedQDataset(Dataset):
    """Dataset over cached prompt embeddings listed in ``index.jsonl``.

    Each index row names a cache file (``row["path"]``) under
    ``<cfg.cache_root>/<split>``. Image embeddings and ground-truth masks are
    read from ``cfg.data_dir`` when an item is requested.
    """

    def __init__(self, split, cfg):
        """Read the cache index for ``split``.

        Args:
            split: Name of the cache sub-directory.
            cfg: Config object providing ``cache_root``, ``data_dir`` and
                ``frame_n``.

        Raises:
            FileNotFoundError: If the split has no ``index.jsonl``.
        """
        self.split = split
        self.cfg = cfg
        self.root = os.path.join(cfg.cache_root, split)
        self.index_path = os.path.join(self.root, "index.jsonl")
        if not os.path.exists(self.index_path):
            raise FileNotFoundError(f"Missing cache index: {self.index_path}")
        rows = []
        with open(self.index_path) as index_file:
            for line in index_file:
                # Blank lines in the index are skipped.
                if line.strip():
                    rows.append(json.loads(line))
        self.rows = rows

    def __len__(self):
        """Return the number of index rows."""
        return len(self.rows)

    def _load_masks(self, vid, fids):
        """Load binary ground-truth masks for every ``fid`` of a video.

        Reads ``cfg.frame_n`` grayscale PNGs per fid from
        ``<data_dir>/gt_mask/<vid>/fid_<fid>/0000<frame>.png`` and binarizes
        them with ``> 0``.

        Returns:
            A float32 tensor stacked as (fid, frame, ...image dims).

        Raises:
            FileNotFoundError: If an expected mask image cannot be read.
        """
        per_fid = []
        for fid in fids:
            frame_masks = []
            for frame_idx in range(self.cfg.frame_n):
                mask_path = os.path.join(
                    self.cfg.data_dir,
                    "gt_mask",
                    vid,
                    f"fid_{int(fid)}",
                    f"0000{frame_idx}.png",
                )
                image = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                # cv2.imread signals an unreadable file by returning None.
                if image is None:
                    raise FileNotFoundError(mask_path)
                frame_masks.append(torch.as_tensor(image > 0, dtype=torch.float32))
            per_fid.append(torch.stack(frame_masks, dim=0))
        return torch.stack(per_fid, dim=0)

    def __getitem__(self, idx):
        """Assemble one sample dict from the cache, embeddings and masks."""
        entry = self.rows[idx]
        cache = torch.load(os.path.join(self.root, entry["path"]), map_location="cpu")
        vid = cache["vid"]

        sample = {
            "sample_idx": cache["sample_idx"],
            "vid": vid,
            "refs": cache["refs"],
            "fids": cache["fids"],
            "q": cache["q"].float(),
        }
        embed_path = os.path.join(self.cfg.data_dir, "image_embed", f"{vid}.pt")
        sample["image_embeddings"] = torch.load(embed_path, map_location="cpu").float()
        sample["gt_masks"] = self._load_masks(vid, cache["fids"])
        sample["resize"] = tuple(cache["resize"])
        sample["orgsize"] = tuple(cache["orgsize"])
        return sample
192
+
193
+
194
def collate_cached(batch):
    """Identity collate function: hand the list of samples back unchanged.

    Used as ``collate_fn`` so the loader yields the per-sample dicts as-is
    instead of stacking them.
    """
    return batch
196
+
197
+
198
def decode_batch(visual_model, batch, device):
    """Decode masks for every cached prompt of every sample in one pass.

    Each prompt vector ``q[i]`` is repeated once per frame of its sample and
    paired with that sample's image embeddings; all (prompt, frame) rows are
    concatenated, run through the prompt encoder and mask decoder together,
    then split back per sample.

    Args:
        visual_model: Object exposing ``prompt_encoder`` (callable, with
            ``get_dense_pe()``), ``mask_decoder.forward_modified_v3`` and
            ``postprocess_masks``.
        batch: List of sample dicts with ``"q"`` (one row per prompt),
            ``"image_embeddings"`` (first axis indexes frames), ``"resize"``
            and ``"orgsize"``.
        device: Device on which decoding runs.

    Returns:
        A list with one tensor per sample, stacked over that sample's prompts.

    Raises:
        RuntimeError: If the batch contains no prompts at all.
    """
    image_pe = visual_model.prompt_encoder.get_dense_pe()
    frame_qs = []
    frame_image_embeddings = []
    prompt_spans = []
    # Running row offset into the concatenated tensors. Tracking it
    # explicitly keeps the spans correct even when samples differ in their
    # number of frames.
    offset = 0

    for sample_idx, sample in enumerate(batch):
        q = sample["q"].to(device=device, dtype=torch.float32)
        image_embeddings = sample["image_embeddings"].to(device=device, dtype=torch.float32)
        frames = image_embeddings.shape[0]
        for prompt_idx in range(q.shape[0]):
            frame_qs.append(q[prompt_idx].unsqueeze(0).expand(frames, -1))
            frame_image_embeddings.append(image_embeddings)
            prompt_spans.append((sample_idx, prompt_idx, offset, offset + frames))
            offset += frames

    if not frame_qs:
        raise RuntimeError("No cached prompts were provided for decoding.")

    frame_qs = torch.cat(frame_qs, dim=0)
    frame_image_embeddings = torch.cat(frame_image_embeddings, dim=0)
    sparse_embeddings, dense_embeddings = visual_model.prompt_encoder(
        points=None,
        boxes=None,
        masks=None,
        text_embeds=frame_qs.unsqueeze(1),
    )
    sparse_embeddings = sparse_embeddings.to(frame_qs.dtype)
    dense_embeddings = dense_embeddings.to(frame_qs.dtype)

    low_res_masks = visual_model.mask_decoder.forward_modified_v3(
        image_embeddings=frame_image_embeddings,
        image_pe=image_pe,
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
    ).unsqueeze(1)

    pred_by_sample = [[] for _ in batch]
    for sample_idx, _, start, end in prompt_spans:
        sample = batch[sample_idx]
        # Resize this prompt's frames using the owning sample's sizes.
        pred_mask = visual_model.postprocess_masks(
            low_res_masks[start:end],
            input_size=sample["resize"],
            original_size=sample["orgsize"],
        )
        pred_by_sample[sample_idx].append(pred_mask.squeeze(1))

    return [torch.stack(pred_masks, dim=0) for pred_masks in pred_by_sample]
246
+
247
+
248
def decode_sample(visual_model, sample, device):
    """Decode a single sample by running it as a one-element batch."""
    per_sample_predictions = decode_batch(visual_model, [sample], device)
    return per_sample_predictions[0]
250
+
251
+
252
def compute_mask_loss(pred_masks, gt_masks):
    """Combine BCE and dice losses over all masks of a batch.

    Each sample's losses are multiplied by its mask count (``num_seg *
    frames``) and the sums are divided by the total count, so every mask
    contributes equally. The result is ``2.0 * bce + 0.5 * dice``.

    Args:
        pred_masks: Iterable of per-sample prediction tensors.
        gt_masks: Iterable of per-sample target tensors, unpacked as
            (num_seg, frames, height, width).

    Returns:
        The weighted sum of the two loss terms.
    """
    bce_total = 0.0
    dice_total = 0.0
    mask_count = 0

    for prediction, target in zip(pred_masks, gt_masks):
        target = target.to(device=prediction.device, dtype=prediction.dtype)
        num_seg, frames, height, width = target.shape
        # Fold the (segment, frame) axes into a single mask axis.
        target_flat = target.view(num_seg * frames, height, width)
        prediction_flat = prediction.view(num_seg * frames, height, width)
        sample_masks = target_flat.shape[0]

        sample_bce = sigmoid_ce_loss(prediction_flat, target_flat, num_masks=sample_masks)
        bce_total = bce_total + sample_bce * sample_masks

        sample_dice = dice_loss(prediction_flat, target_flat, num_masks=sample_masks)
        dice_total = dice_total + sample_dice * sample_masks

        mask_count += sample_masks

    # The epsilon keeps the division defined for an empty batch.
    weighted_bce = 2.0 * bce_total / (mask_count + 1e-8)
    weighted_dice = 0.5 * dice_total / (mask_count + 1e-8)
    return weighted_bce + weighted_dice
276
+
277
+
278
def compute_batch_metrics(pred_masks, gt_masks):
    """Compute mask-count-weighted IoU and F-score for one batch.

    Args:
        pred_masks: Iterable of per-sample prediction tensors whose first two
            axes are (num_seg, frames).
        gt_masks: Iterable of matching target tensors.

    Returns:
        A dict with ``"miou"`` and ``"fscore"``, each averaged with weight
        ``num_seg * frames`` per sample (denominator clamped to at least 1).
    """
    iou_sum = 0.0
    fscore_sum = 0.0
    weight_sum = 0

    for prediction, target in zip(pred_masks, gt_masks):
        target = target.to(device=prediction.device, dtype=prediction.dtype)
        num_seg, frames = prediction.shape[:2]
        weight = num_seg * frames
        sample_iou = utility.mask_iou(prediction.detach().float(), target.float())
        iou_sum += sample_iou * weight
        sample_fscore = utility.Eval_Fmeasure(prediction.detach().float(), target.float(), None)
        fscore_sum += sample_fscore * weight
        weight_sum += weight

    metrics = {}
    metrics["miou"] = iou_sum / max(1, weight_sum)
    metrics["fscore"] = fscore_sum / max(1, weight_sum)
    return metrics
293
+
294
+
295
def evaluate(model, loader):
    """Evaluate the cached decoder path on ``loader`` and print the metrics.

    For the ``"test_n"`` cache split the null metric
    (``utility.metric_s_for_null``) is reported; for any other split the IoU
    and F-score are reported. All metrics are averaged with weight
    ``num_seg * frames`` per sample.

    Args:
        model: Model whose ``get_model().visual_model`` performs decoding.
        loader: Iterable of batches (lists of sample dicts).

    Raises:
        RuntimeError: If the loader produced no samples.
    """
    model.eval()
    visual_model = model.get_model().visual_model
    iou_sum = 0.0
    fscore_sum = 0.0
    null_sum = 0.0
    weight_sum = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc=f"Cached eval {args.cache_split}"):
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                predictions = decode_batch(visual_model, batch, "cuda")
            for sample, pred in zip(batch, predictions):
                gt = sample["gt_masks"].to(device=pred.device, dtype=pred.dtype)
                num_seg, frames = pred.shape[:2]
                weight = num_seg * frames
                if args.cache_split != "test_n":
                    iou_sum += utility.mask_iou(pred.float(), gt.float()) * weight
                    fscore_sum += utility.Eval_Fmeasure(pred.float(), gt.float(), None) * weight
                else:
                    null_sum += float(utility.metric_s_for_null(pred.float())) * weight
                weight_sum += weight

    if weight_sum == 0:
        raise RuntimeError("No cached samples were evaluated.")

    if args.cache_split != "test_n":
        print(
            f"cached valuate on {args.cache_split}: "
            f"miou: {iou_sum / weight_sum} fscore: {fscore_sum / weight_sum}"
        )
    else:
        print(f"cached valuate on test_n_refer, metric: {null_sum / weight_sum}")
328
+
329
+
330
def train(model, loader):
    """Train only the referent-gate parameters on cached batches.

    All parameters are frozen except those whose name contains
    ``"referent_gate"``. Training runs for ``args.epochs`` epochs or until
    ``args.max_steps`` optimizer steps (when positive), then writes a
    checkpoint to ``<args.checkpoint_root>/<args.name>.pth`` — gate tensors
    only when ``args.save_gate_only`` is set, the full state dict otherwise.

    Args:
        model: Model whose ``get_model().visual_model`` performs decoding.
        loader: Iterable of batches (lists of sample dicts).

    Raises:
        ValueError: If ``args.disable_gate`` is set.
    """
    if args.disable_gate:
        raise ValueError("--disable_gate is only valid with --eval_only")

    # Freeze everything except the referent-gate parameters.
    for param_name, param in model.named_parameters():
        param.requires_grad = "referent_gate" in param_name

    gate_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = AdamW(gate_params, lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)

    init_stats = collect_referent_gate_stats(model)
    trainable = sum(param.numel() for param in gate_params)
    print(
        "cached gate init: "
        f"modules={init_stats['modules']} "
        f"proj_norm={init_stats['proj_norm']:.6f} "
        f"gate_norm={init_stats['gate_norm']:.6f} "
        f"trainable_params={trainable}"
    )

    visual_model = model.get_model().visual_model
    step = 0
    for epoch in range(args.epochs):
        model.train()
        progress = tqdm(loader, desc=f"Cached gate train {epoch + 1}/{args.epochs}")
        for batch in progress:
            if args.max_steps > 0 and step >= args.max_steps:
                break
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                pred_masks = decode_batch(visual_model, batch, "cuda")
            gt_masks = [sample["gt_masks"] for sample in batch]

            loss = compute_mask_loss(pred_masks, gt_masks)
            optimizer.zero_grad()
            loss.backward()
            step += 1
            # Log before optimizer.step() so the gradients of this step are
            # still attached to the parameters when the stats are collected.
            should_log = args.log_gate_stats_every > 0 and step % args.log_gate_stats_every == 0
            if should_log:
                batch_metrics = compute_batch_metrics(pred_masks, gt_masks)
                log_gate_stats(model, step, loss.item(), batch_metrics)
            optimizer.step()

        if args.max_steps > 0 and step >= args.max_steps:
            print(f"stopped early at cached optimizer step {step}")
            break

    os.makedirs(args.checkpoint_root, exist_ok=True)
    ckpt_path = os.path.join(args.checkpoint_root, f"{args.name}.pth")
    if args.save_gate_only:
        payload = {
            "type": "referent_gate_only",
            "base_model": args.saved_model,
            "state_dict": referent_gate_state_dict(model),
        }
    else:
        payload = model.state_dict()
    torch.save(payload, ckpt_path)
    print(f"cached gate model saved as {ckpt_path}")
390
+
391
+
392
def main():
    """Entry point: build the data loader and model, then evaluate or train.

    Seeds the RNGs, prepares the tokenizer with the ``[SEG]`` token, builds
    the cached dataset (optionally truncated to ``args.overfit_samples``),
    optionally loads a gate checkpoint and/or zeroes the gate, and finally
    either evaluates only or trains followed by an optional evaluation.
    """
    set_seed(42)
    random.seed(42)
    np.random.seed(42)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.mllm,
        cache_dir=None,
        model_max_length=2048,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    seg_ids = tokenizer("[SEG]", add_special_tokens=False).input_ids
    seg_token_idx = seg_ids[0]

    dataset = CachedQDataset(args.cache_split, args)
    if args.overfit_samples > 0:
        # Restrict the run to the first n samples.
        n = min(args.overfit_samples, len(dataset))
        dataset = Subset(dataset, list(range(n)))
        print(f"cached overfit_samples enabled: using first {n} samples")

    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=not args.eval_only,
        num_workers=4,
        collate_fn=collate_cached,
    )

    model = build_model(tokenizer, seg_token_idx)
    if args.gate_checkpoint:
        load_referent_gate_checkpoint(model, args.gate_checkpoint)
    if args.disable_gate:
        zero_referent_gate(model)
        print("disable_gate enabled: referent gate forced to identity")

    if args.eval_only:
        evaluate(model, loader)
    else:
        train(model, loader)
        if not args.skip_eval_after_train:
            evaluate(model, loader)
436
+
437
+
438
# Run the cached gate pipeline only when executed as a script.
if __name__ == "__main__":
    main()