yfan07 commited on
Commit
e214bf0
·
verified ·
1 Parent(s): d5c375d

Clean experimental files and restore original SimToken layout

Browse files
Residual_Prompt_Bridge.md DELETED
@@ -1,501 +0,0 @@
1
- # Residual Prompt Bridge 论文导向实验路线图
2
-
3
- ## 1. 当前主 claim
4
-
5
- 论文主 claim 现在正式锁定为:
6
-
7
- > **We propose an image-conditioned directional prompt correction module that orthogonalizes prompt updates to steer language-side prompts toward a more decodable SAM prompt manifold, mitigating cross-distribution prompt interface mismatch.**
8
-
9
- 对应中文表述:
10
-
11
- > **我们提出一种图像条件的方向型 prompt correction,通过正交化更新把语言侧 prompt 朝更可解码的 SAM prompt manifold 偏转,从而缓解跨分布的 prompt 接口失配。**
12
-
13
- 从现在开始,所有实验都只服务这句 claim,不再让方法故事扩散成“大而全系统”。
14
-
15
- ---
16
-
17
- ## 2. 当前项目定位
18
-
19
- 当前 RPB 项目已经完成了最关键的早期筛查:
20
-
21
- 1. **实现正确性通过**
22
- - checkpoint / LoRA 兼容问题已修复
23
- - bridge 路径不会自动破坏 baseline
24
- - identity-preserving sanity check 已通过
25
-
26
- 2. **几何机制方向明确**
27
- - additive residual 不足以推动 `p_hat` 离开 `q`
28
- - directional bridge 明显优于 additive
29
- - orthogonalization 能把 residual 预算从径向缩放转成方向修正
30
-
31
- 3. **当前最小核心已浮现**
32
- - `image-conditioned`
33
- - `p_mask-only`
34
- - `directional`
35
- - `orthogonal`
36
- - `single-token correction`
37
-
38
- 4. **mixed 的角色目前仍未定型**
39
- - weak mixed 不会抹掉 bridge
40
- - 但目前更像 enhancer / compatibility probe,而不是稳定的 decoder-facing calibration mechanism
41
-
42
- 因此,当前最重要的不是继续加模块,而是把这个**最小有效核心**做成稳定、可复现、可投稿的方法骨架。
43
-
44
- ---
45
-
46
- ## 3. 两套判据:Mechanism Pass vs Paper Pass
47
-
48
- ### 3.1 Mechanism pass
49
-
50
- 回答的问题是:
51
-
52
- > 这个方法设计是否真的抓住了问题本质?
53
-
54
- 当前 mechanism pass 需要被下面这些证据支撑:
55
-
56
- - additive vs directional:directional 明显更能让 `p_hat` 离开 identity
57
- - without orthogonal vs with orthogonal:orthogonalization 明显改善 `Δp` 的几何利用效率
58
- - `Δp` 稳定朝 `p_mask`
59
- - `p_hat` 能明显离开 `q`
60
- - seen/unseen 的 alignment ratio 健康
61
- - weak mixed 不会直接把 bridge 拉回 baseline
62
-
63
- ### 3.2 Paper pass
64
-
65
- 回答的问题是:
66
-
67
- > 这个方法是否已经强到能单独撑起一篇顶会方法论文?
68
-
69
- paper pass 需要下面这些更强条件:
70
-
71
- - 更大规模评估上有稳定、同向的 headline 趋势
72
- - 至少在 unseen 上有清晰、可复现的优势
73
- - seen / null 的代价可接受
74
- - 2 个随机种子下趋势稳定
75
- - 最小闭环 ablation 完整
76
-
77
- 当前状态:
78
-
79
- - **mechanism pass:接近通过,但还缺更大规模验证和关键 baseline**
80
- - **paper pass:尚未通过**
81
-
82
- 后续每组实验都要明确写清楚:它是在推进 mechanism pass,还是在推进 paper pass。
83
-
84
- ---
85
-
86
- ## 4. 冻结最小核心方法
87
-
88
- 在 pure RPB standalone 路线中,当前只保留下列组成:
89
-
90
- - `image-conditioned correction`
91
- - `p_mask-only teacher`
92
- - `directional bridge`
93
- - `orthogonalized update`
94
- - `single-token prompt correction`
95
-
96
- 当前明确**不进入主线**的内容:
97
-
98
- - `z_gt` 作为主 teacher
99
- - calibrator
100
- - refinement
101
- - 多 token bridge
102
- - 大而全的完整 bridge 系统
103
-
104
- 这些内容后续最多作为 ablation、扩展或 hybrid 组件,而不是当前主方法本体。
105
-
106
- ---
107
-
108
- ## 5. 当前实验事实总结
109
-
110
- ### 5.1 已确认的正结果
111
-
112
- - bridge 可以安全接入,不会自动毁掉 baseline
113
- - 修复 checkpoint / LoRA 后,RPB 路径与 baseline 基本等价
114
- - `directional + orthogonal` 后:
115
- - `Δp` 高度对齐 `p_mask`
116
- - `Δp` 不再主要沿 `q` 的平行方向浪费预算
117
- - `p_hat` 能够明显离开 identity 区
118
- - `p_mask-only teacher-only` 已在 quick eval 上给出:
119
- - seen 小幅回落但可控
120
- - unseen 轻微正信号
121
- - null 基本持平
122
-
123
- ### 5.2 已确认的负结果
124
-
125
- - additive residual 不足以真正旋转 prompt
126
- - `L_mask` 不是早期主矛盾
127
- - `z_gt` 目前不是 sparse bridge 的主 teacher
128
- - weak mixed 目前不能稳定把 seen 拉回 baseline
129
-
130
- ### 5.3 当前最重要的工作假设
131
-
132
- > `p_mask-only + image-conditioned + directional + orthogonal` 已经抓住主问题,但还需要找到更稳定的 operating point,并证明其 headline 趋势不是噪声。
133
-
134
- ### 5.4 Fixed dev 阶段 A 当前记录
135
-
136
- 固定 dev 子集:
137
-
138
- - `test_s`: 200 samples
139
- - `test_u`: 200 samples
140
- - `test_n`: 200 samples
141
- - manifest: `/workspace/SimToken/dev_subsets_rpb_v1.json`
142
-
143
- #### Fixed dev baseline
144
-
145
- | Setting | Seen mIoU | Seen F | Unseen mIoU | Unseen F | Null |
146
- |---|---:|---:|---:|---:|---:|
147
- | baseline | 0.72554 | 0.81811 | 0.68531 | 0.77238 | 0.01452 |
148
-
149
- #### Teacher-only alpha search
150
-
151
- | Setting | Seen mIoU | Seen F | Unseen mIoU | Unseen F | Null | Seen cos(p_hat,p_mask) | Unseen cos(p_hat,p_mask) | 机制判断 |
152
- |---|---:|---:|---:|---:|---:|---:|---:|---|
153
- | image, alpha=0.20 | 0.72517 | 0.81376 | 0.68596 | 0.77730 | 0.01426 | 0.09502 | 0.06611 | 机制最强,Seen/F 有代价 |
154
- | image, alpha=0.18 | 0.72692 | 0.81705 | 0.68595 | 0.77354 | 0.01448 | 0.02873 | 0.00605 | 性能平衡较好,机制偏弱 |
155
- | image, alpha=0.15 | 0.72669 | 0.81725 | 0.68569 | 0.77330 | 0.01448 | 0.02373 | 0.00282 | 更接近 identity |
156
- | image, alpha=0.12 | 0.72651 | 0.81748 | 0.68578 | 0.77314 | 0.01449 | 0.01871 | -0.00046 | 轻扰动区,机制最弱 |
157
-
158
- 阶段 A 的 teacher-only 结论:
159
-
160
- - `alpha=0.20` 是机制候选点,能明显改变 prompt geometry。
161
- - `alpha=0.18` 是性能平衡候选点,seen / unseen / null 都更稳。
162
- - `alpha=0.12/0.15` 已经过于接近 identity,不适合作为机制主证据。
163
-
164
- #### Weak mixed 局部验证
165
-
166
- | Setting | Seen mIoU | Seen F | Unseen mIoU | Unseen F | Null | Seen cos(p_hat,p_mask) | Unseen cos(p_hat,p_mask) | 角色判断 |
167
- |---|---:|---:|---:|---:|---:|---:|---:|---|
168
- | image, alpha=0.18, weak mixed | 0.72704 | 0.81554 | 0.68706 | 0.77454 | 0.01451 | 0.04079 | 0.01325 | 当前最佳性能平衡候选 |
169
- | image, alpha=0.15, weak mixed | 0.72684 | 0.81607 | 0.68674 | 0.77419 | 0.01451 | 0.03382 | 0.00882 | 稳定但略弱于 alpha=0.18 mixed |
170
-
171
- weak mixed 当前结论:
172
-
173
- - weak mixed 没有把 bridge 拉回 identity。
174
- - weak mixed 对 `alpha=0.15/0.18` 都更像 mild enhancement,而不是 destructive pullback。
175
- - `alpha=0.18 + weak mixed` 是当前 fixed dev 的最佳 operating point。
176
-
177
- #### q-only directional baseline
178
-
179
- | Setting | Seen mIoU | Seen F | Unseen mIoU | Unseen F | Null | Seen cos(p_hat,p_mask) | Unseen cos(p_hat,p_mask) | 判断 |
180
- |---|---:|---:|---:|---:|---:|---:|---:|---|
181
- | q-only, alpha=0.18 | 0.72311 | 0.81206 | 0.68289 | 0.77666 | 0.01424 | 0.12061 | 0.09598 | alignment 更强但 mIoU 更差 |
182
-
183
- q-only 结论:
184
-
185
- - directional / orthogonal 机制本身很强,q-only 也能大幅拉高 teacher alignment。
186
- - q-only 的 prompt steering 更激进,`gate_mean` 更高,`delta_norm` 更大。
187
- - q-only mIoU 在 seen / unseen 上都低于 image-conditioned candidate。
188
- - 当前证据支持:image conditioning 的价值不是单纯提高 teacher cosine,而是约束方向修正,使 prompt steering 与 decoder compatibility 之间的平衡更好。
189
-
190
- #### 阶段 A 当前候选
191
-
192
- 当前 fixed dev 最佳候选:
193
-
194
- > **image-conditioned + p_mask-only + directional + orthogonal + alpha=0.18 + weak mixed**
195
-
196
- 对应 checkpoint:
197
-
198
- > `/workspace/SimToken/checkpoints/rpb_dev_mixed_pm_only_a018_wm005.pth`
199
-
200
- ---
201
-
202
- ## 6. 实验纪律:停止在 test 上自由调方向
203
-
204
- 从下一阶段开始,必须冻结一套 **dev tuning subset**,不再继续在 `test_s/test_u/test_n` 上自由调 alpha 和 mixed 设定。
205
-
206
- 建议立即固定:
207
-
208
- - `dev_seen`
209
- - `dev_unseen`
210
- - `dev_null`
211
-
212
- 每个 split 可先取 `100` 或 `200` 个样本,后续:
213
-
214
- - alpha 选择
215
- - mixed 选择
216
- - warm-start 配置
217
- - early stopping
218
-
219
- 全部只在 dev 上完成。
220
- 真正的 test split 只用于后续一次性确认和最终表格。
221
-
222
- ---
223
-
224
- ## 7. 三阶段推进路线
225
-
226
- ## 阶段 A:锁最小核心的 operating point
227
-
228
- ### 目标
229
-
230
- 回答:
231
-
232
- > 当前最小核心是否能在更大 quick eval 上形成稳定、可接受的性能-几何平衡?
233
-
234
- ### 本阶段只做两类实验
235
-
236
- #### A1. teacher-only operating point 搜索
237
-
238
- 固定:
239
-
240
- - image-conditioned
241
- - `p_mask-only`
242
- - directional
243
- - orthogonal
244
- - single-token
245
- - 不加 `z_gt`
246
- - 不加 calibrator
247
- - 不加 refinement
248
-
249
- 重点只扫:
250
-
251
- - `alpha = 0.12, 0.15, 0.18, 0.20`
252
-
253
- 当前判断是:`0.20` 已经是 promising pass,因此没有必要继续向更大 alpha 发散。
254
-
255
- #### A2. weak mixed 局部验证
256
-
257
- 只围绕最佳 teacher-only checkpoint 做 warm-start,不做大 sweep。
258
-
259
- 建议只测:
260
-
261
- - `best_alpha`
262
- - `best_alpha - 0.03`
263
-
264
- 以及很弱的 mask 强度两档:
265
-
266
- - `λ_mask = 0.05`
267
- - `λ_mask = 0.10`
268
-
269
- mixed 的目标不是涨分,而是判断它的角色到底是:
270
-
271
- - calibration
272
- - enhancement
273
- - 还是 destructive pullback
274
-
275
- ### 阶段 A 重点指标
276
-
277
- 几何指标:
278
-
279
- - `cos(p_hat, p_mask)_seen`
280
- - `cos(p_hat, p_mask)_unseen`
281
- - `cos(p_hat, q)`
282
- - `cos(Δp, p_mask)`
283
- - `cos(Δp, q)`
284
- - `align_ratio = cos_u / cos_s`
285
-
286
- 性能指标:
287
-
288
- - `mIoU_seen`
289
- - `mIoU_unseen`
290
- - `Fscore_seen`
291
- - `Fscore_unseen`
292
- - `Null metric`
293
-
294
- ### 阶段 A 的通过标准
295
-
296
- 若在 dev 或更大 quick eval 上,能找到一个稳定点满足:
297
-
298
- - unseen 稳定不差于 baseline,最好有小幅提升
299
- - seen 代价可控
300
- - null 基本持平或代价可接受
301
- - `cos(p_hat, p_mask)` 明显离开 identity 区
302
- - seen/unseen 的 alignment ratio 健康
303
-
304
- 则阶段 A 通过。
305
-
306
- ### 阶段 A 的停止条件
307
-
308
- 若完成:
309
-
310
- 1. alpha 局部搜索
311
- 2. weak mixed 局部搜索
312
- 3. 100 / 200 样本 quick eval
313
-
314
- 之后仍出现任一情况,则停止 pure RPB standalone 主线:
315
-
316
- - 在更大 quick eval 上没有稳定、同向的 unseen 优势
317
- - seen/unseen tradeoff 对 alpha 高度敏感
318
- - null 代价无法压到 baseline 附近
319
- - mixed 始终只是增强器,而不是 decoder-facing calibration
320
-
321
- ---
322
-
323
- ## 阶段 B:做最小闭环 ablation
324
-
325
- 只有阶段 A 通过后,才进入阶段 B。
326
-
327
- ### 目标
328
-
329
- 把方法主骨架讲圆,形成 mechanism pass 的闭环证据。
330
-
331
- ### 必做的 4 个关键 ablation
332
-
333
- 1. **additive vs directional**
334
- 2. **directional without orthogonalization vs with orthogonalization**
335
- 3. **q-only directional vs image-conditioned directional**
336
- 4. **`p_mask-only` vs `p_mask + weak z_gt`**
337
-
338
- 这 4 个已经足够支撑方法论证,不再继续扩更多 trick ablation。
339
-
340
- ### 阶段 B 的补充要求
341
-
342
- - 至少 2 个随机种子重复
343
- - 至少一次更大规模验证
344
- - 建立 geometry-performance coupling:
345
- - prompt geometry 改写程度
346
- - 与 seen/unseen 表现之间的关系
347
- - 与 identity 回缩之间的关系
348
-
349
- ### 阶段 B 的停止条件
350
-
351
- 若完成:
352
-
353
- 1. alpha 局部搜索
354
- 2. weak mixed 局部搜索
355
- 3. 100 / 200 样本 quick eval
356
- 4. 至少一次更大规模验证
357
- 5. 2 个随机种子重复
358
-
359
- 后仍满足以下任一条,则停止 pure RPB standalone:
360
-
361
- - 大子集 / full-split 上没有稳定、同向的 unseen 优势
362
- - 最优点高度依赖 seed 或 alpha,趋势不稳定
363
- - null 代价无法控制
364
- - mixed 无法形成稳定 calibration 作用
365
- - headline result 仍然只有极弱波动
366
-
367
- ---
368
-
369
- ## 阶段 C:决定论文定位
370
-
371
- ### 路线 1:pure RPB standalone
372
-
373
- 如果满足:
374
-
375
- - 更大评估上有稳定 unseen gain
376
- - seen / null 代价可接受
377
- - 2 seeds 稳定
378
- - 最小闭环 ablation 完整
379
-
380
- 则走:
381
-
382
- > **pure RPB 方法论文**
383
-
384
- ### 路线 2:RPB + TTO hybrid
385
-
386
- 如果出现:
387
-
388
- - mechanism 成立
389
- - 但 paper pass 不够硬
390
- - headline result 仍然偏弱或不稳定
391
-
392
- 则立刻切换定位:
393
-
394
- > **RPB + TTO hybrid 方法论文**
395
-
396
- 此时 RPB 的角色不再是 standalone 主方法,而是:
397
-
398
- - amortized prompt corrector
399
- - 改善 test-time refinement 起点质量的前端模块
400
-
401
- ---
402
-
403
- ## 8. Hybrid 路线作为明确 Plan B
404
-
405
- 若 pure RPB 最终只能做到:
406
-
407
- - unseen 稳定小涨
408
- - seen 小掉
409
- - null 持平或略好
410
-
411
- 那么 standalone 顶会会比较吃力。
412
- 但此时 RPB 作为前端 prompt corrector 仍很有价值:
413
-
414
- - 改善初始 `q` 的几何
415
- - 为 q-LTPO / selective refinement 提供更好的初始化
416
- - 降低 test-time optimization 的步数和不稳定性
417
-
418
- hybrid 的论文叙事可以明确写成:
419
-
420
- 1. train-time:amortized interface correction
421
- 2. test-time:instance-specific prompt refinement
422
- 3. 两者结合:同时解决全局接口失配与样本级细化问题
423
-
424
- 当前判断:hybrid 是非常强的 Plan B,而不是临时补救路线。
425
-
426
- ---
427
-
428
- ## 9. 负结果如何写进论文论证链条
429
-
430
- 当前已经得到了一条清晰的“设计收敛链条”,后续可以直接转写为论文方法论证:
431
-
432
- ### 为什么不是 additive residual
433
-
434
- 因为 additive 下:
435
-
436
- - `Δp` 主要对抗 `q` 的平行分量
437
- - teacher 方向被大范数 `q` 吞掉
438
- - 结果更像缩放,而不是旋转
439
-
440
- ### 为什么要 directional
441
-
442
- 因为 directional 才能把修正显式变成 prompt 方向控制,而不是数值扰动。
443
-
444
- ### 为什么要 orthogonal
445
-
446
- 因为 orthogonalization 才能避免 residual 预算浪费在径向缩放上。
447
-
448
- ### 为什么当前只保留 `p_mask`
449
-
450
- 因为当前 sparse bridge 里,`p_mask` 一直是主 teacher,`z_gt` 尚未成为主信号。
451
-
452
- ### 为什么 mixed 不是主模块
453
-
454
- 因为 mixed 目前更像 compatibility / enhancement probe,而不是稳定的 calibration mechanism。
455
-
456
- 这条链条必须在文中明确写出,让 reviewer 看到方法是沿诊断逐步收敛的,而不是盲目堆模块。
457
-
458
- ---
459
-
460
- ## 10. 当前最直接的执行建议
461
-
462
- 接下来不要发散,严格按下面顺序走:
463
-
464
- 1. **立刻冻结论文主 claim**
465
- 2. **立刻切换到固定 dev 子集,不再自由用 test 调方向**
466
- 3. **完成阶段 A:最小核心 operating point 搜索**
467
- 4. **补关键 baseline:q-only directional**
468
- 5. **做两种 seed**
469
- 6. **然后做 pure RPB standalone 的去留决策**
470
-
471
- 当前最重要的执行原则是:
472
-
473
- > **先证明最小核心能稳定成立;如果 headline 不够硬,就及时把它升级成 hybrid 前端,而不是继续把 pure RPB 做复杂。**
474
-
475
- ---
476
-
477
- ## 11. 当前阶段的明确结论
478
-
479
- ### 当前方向值得继续吗?
480
-
481
- **值得。**
482
-
483
- ### 现在最应该做什么?
484
-
485
- 不是继续扩模块,而是:
486
-
487
- - 找到 teacher-only `p_mask-only directional orthogonal` 的最佳 operating point
488
- - 用 very weak mixed 判断 mixed 是否能形成 calibration
489
- - 在 dev 和更大 quick eval 上证明趋势不是噪声
490
-
491
- ### 什么时候该停 pure RPB?
492
-
493
- 只要阶段 A + B 完成后,headline 仍然弱且不稳定,就停止 pure RPB standalone。
494
-
495
- ### 停了之后怎么办?
496
-
497
- 直接转:
498
-
499
- > **RPB + TTO hybrid**
500
-
501
- 这条路线当前是明确的 Plan B,而且很可能是更强的顶会方法论文路径。
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
SEG_LTPO_results.md DELETED
@@ -1,488 +0,0 @@
1
- # SEG-LTPO: Experimental Results and Analysis
2
-
3
- ---
4
-
5
- ## Method 1: SEG-LTPO-simple (ES-based, zeroth-order)
6
-
7
- ### Overview
8
-
9
- SEG-LTPO-simple performs test-time optimization of SimToken's single semantic token **Fseg** using antithetic Evolution Strategies (ES), guided by an internal reward signal that requires no ground-truth masks.
10
-
11
- **Optimization loop** (T=5 steps, 4 anchor frames):
12
- ```
13
- eps_t ~ N(0, σ_t² I)
14
- F± = F_curr ± eps_t
15
- F_curr = F_curr + η_t · (R+ − R−) / (2σ_t²) · eps_t
16
- best_F = argmax_F R(F) over all evaluated candidates
17
- ```
18
-
19
- **Reward function:**
20
- ```
21
- R = λ1·R_temp_feat + λ2·R_iou_pred + λ3·R_align_contrast − λ4·R_area
22
- = 0.3·R_temp + 0.4·R_iou + 1.0·R_align − 0.3·R_area
23
- ```
24
-
25
- - **R_align_contrast**: cosine(Fseg, z_inside) − β·cosine(Fseg, z_outside); main signal
26
- - **R_iou_pred**: SAM's internal mask quality head output
27
- - **R_temp_feat**: feature-space cosine consistency between adjacent anchor frames
28
- - **R_area**: average foreground ratio (degenerate-mask penalty)
29
-
30
- **Reward gating**: accept optimized Fseg only when R(best_F) > R(F_init) + gate_delta.
31
-
32
- ### Results (Unseen split, full 1656 samples)
33
-
34
- | Method | mIoU | F | Δ mIoU |
35
- |--------|------|---|--------|
36
- | Baseline | 0.6989 | 0.7927 | — |
37
- | Best-of-2 Random | 0.7050 (subset) → 0.7030 (full) | 0.7953 | +0.0040 |
38
- | SEG-LTPO-simple (ES) | **0.7050** | **0.7960** | **+0.0061** |
39
-
40
- > Best-of-2 and LTPO-ES results at full scale confirmed in the q-LTPO evaluation run below.
41
-
42
- ### Key Findings
43
-
44
- 1. **Reward signal is valid**: both Best-of-2 and ES-LTPO outperform baseline, confirming R_align_contrast provides useful signal.
45
- 2. **ES update is noisy**: in 500-sample ablation, Best-of-2 (0.7235) slightly outperformed iterative ES (0.7228), due to extremely low SNR of single-sample gradient estimation in 256d space. At full scale (1656), ES-LTPO recovers (+0.0065 vs +0.0040), but the margin over Best-of-2 is small.
46
- 3. **Null stability**: Null S metric change negligible (+0.00025), reward gating effectively suppresses false positives.
47
-
48
- ---
49
-
50
- ## Method 2: q-LTPO-autograd (first-order, Adam maximize)
51
-
52
- ### Overview
53
-
54
- **Core insight from LTPO analysis**: optimize the variable that is *directly consumed* by the downstream module, using autograd rather than noisy zeroth-order estimation.
55
-
56
- **Three design decisions borrowed from original LTPO:**
57
-
58
- 1. **Optimize q, not Fseg.** In SimToken+SAM, the token that directly enters the mask decoder's cross-attention is `q = sparse_emb = Fseg.unsqueeze(1)` (prompt encoder passes text_embeds through unchanged). We set `q = nn.Parameter(q_init)` and optimize q directly, bypassing the prompt encoder entirely. This requires no invertibility of ε_p — q_best is used directly for final inference.
59
-
60
- 2. **Use autograd when reward is differentiable.** The mask decoder (transformer + MLP + matmul) is fully differentiable. With soft masks instead of hard thresholds, all reward terms are differentiable w.r.t. q. Adam maximize replaces the low-SNR score-function estimator.
61
-
62
- 3. **Track best_q by task reward (no regularization), gate at the end.** λ_reg penalty is excluded from gating to avoid penalizing solutions that drifted slightly from q_init but achieved better task reward.
63
-
64
- **Stage 0: Gradient connectivity check (verified)**
65
- ```
66
- grad_norm (step 0): 0.503070
67
- reward trajectory: [0.4650, 0.4709, 0.4770, 0.4831, 0.4892] ← strictly monotone
68
- gradient_connected: True
69
- ```
70
-
71
- ### Optimization loop
72
-
73
- ```python
74
- q = nn.Parameter(q_init.float().detach().clone())
75
- optimizer = Adam([q], lr=lr_auto, maximize=True)
76
- best_q, best_reward = q_init.clone(), R_task(q_init)
77
-
78
- for step in range(T=5):
79
- R_full = R_task(q) - λ_reg * ||q - q_init||²
80
- R_full.backward()
81
- optimizer.step()
82
- clip_to_L2_ball(q, q_init, max_drift) # hard norm constraint
83
- if R_task(q) > best_reward:
84
- best_q = q.clone()
85
-
86
- # gating
87
- use best_q if R_task(best_q) > R_task(q_init) + gate_delta, else q_init
88
- ```
89
-
90
- **Hyperparameters (auto-scaled from q_init):**
91
- - `lr = 0.01 × RMS(q_init)`
92
- - `max_drift = 0.5 × ||q_init||`
93
- - `λ_reg = 0.01`, `gate_delta = 0.0`
94
-
95
- ### Staged reward build-up
96
-
97
- **Stage 1** (R_iou + R_area_soft + λ_reg):
98
- ```
99
- R_task = 0.6·R_iou_pred − 0.2·sigmoid(mask_logits/τ).mean()
100
- where τ=5.0 (temperature to avoid sigmoid saturation)
101
- ```
102
-
103
- **Stage 2** (Stage 1 + R_align_det):
104
- ```
105
- R_task = 0.4·R_iou_pred + 1.0·R_align_det − 0.3·R_area_soft
106
- R_align_det = mean_t [ cosine(q, stopgrad(z_in^t)) − 0.5·cosine(q, stopgrad(z_out^t)) ]
107
- ```
108
- z_in/z_out are stopgrad'd to avoid coupling: q first finds a mask, then moves toward the masked region's semantics.
109
-
110
- ### Results (Unseen split)
111
-
112
- #### 200-sample subset (Stage 1 vs Stage 2 fair comparison, same baseline)
113
-
114
- | Method | mIoU | F | Δ mIoU |
115
- |--------|------|---|--------|
116
- | Baseline | 0.6749 | 0.7763 | — |
117
- | Best-of-2 ES | 0.6801 | 0.7803 | +0.0052 |
118
- | LTPO-ES | 0.6838 | 0.7826 | +0.0089 |
119
- | q-LTPO Stage 1 | 0.6979 | 0.7802 | +0.0230 |
120
- | q-LTPO Stage 2 | **0.6989** | **0.7810** | **+0.0240** |
121
-
122
- On 200 samples: Stage 2 marginally better than Stage 1 on both metrics.
123
-
124
- #### Full evaluation (Unseen, 1656 samples)
125
-
126
- | Method | mIoU | F | Δ mIoU vs Baseline |
127
- |--------|------|---|---------------------|
128
- | Baseline | 0.6990 | 0.7924 | — |
129
- | Best-of-2 ES | 0.7030 | 0.7953 | +0.0040 (+0.57%) |
130
- | LTPO-ES | 0.7055 | 0.7969 | +0.0065 (+0.93%) |
131
- | **q-LTPO Stage 1** | **0.7285** | **0.8013** | **+0.0295 (+4.22%)** |
132
- | q-LTPO Stage 2 | 0.7273 | 0.8002 | +0.0283 (+4.04%) |
133
-
134
- **Stage 1 beats Stage 2 on full eval** (opposite of 200-sample trend). R_align_det adds noise at scale: in harder Unseen samples, the initial mask quality is lower, making stopgrad z_in/z_out a less reliable target.
135
-
136
- ### Evaluation Status (after e0 fix)
137
-
138
- | Split | Baseline mIoU/S | q-LTPO S1 (no e0) | q-LTPO S1 (e0) | Status |
139
- |-------|-----------------|-------------------|----------------|--------|
140
- | Unseen (1656) | 0.6990 | **0.7285** | — | Done (pre-e0) |
141
- | Seen (200-sample) | 0.7483 | 0.7618 (+0.0136) | **0.7634 (+0.0151)** | Quick-val done |
142
- | Null (200-sample, S↓) | 0.0619 | 0.0646 (+4.4%) | **0.0634 (+2.4%)** | Quick-val done |
143
- | Unseen (200-sample) | 0.6761 | — | **0.6929 (+0.0168)** | Quick-val done |
144
- | Seen (full) | — | — | — | Pending |
145
- | Null (full, S↓) | 0.0120 | 0.0126 (+5.0%) | — | Pending e0 run |
146
- | Unseen (full) | — | — | — | Pending |
147
-
148
- ---
149
-
150
- ## Null Safety Analysis and e0-Modulated Reward
151
-
152
- ### Root Cause: R_iou_pred is a Conditional Quality Metric
153
-
154
- The original q-LTPO Stage 1 reward:
155
- ```
156
- R_task = 0.6·R_iou_pred − 0.2·R_area_soft
157
- ```
158
-
159
- caused Null S metric degradation (+4.4% on 200-sample quick validation, +5.0% on full Null).
160
-
161
- **Root cause**: `R_iou_pred` is SAM's internal mask quality head — it measures *how good the mask is given that segmentation was performed*, not *whether the target exists*. On Null frames, SAM still outputs `R_iou_pred ≈ 0.73–0.74` because it confidently segments the most prominent region (even if no audio target exists). The optimizer sees positive `R_iou_pred` and expands the mask accordingly.
162
-
163
- **Why oracle gating approaches fail methodologically:**
164
-
165
- - **Path A (gate_delta threshold)**: Distribution analysis showed Null reward_gain p50 = +0.0166 ≈ Seen p50 = +0.0181. The two distributions overlap heavily; any threshold that blocks most Null samples also blocks most Seen/Unseen samples.
166
- - **Path B (area-based reject rule)**: Threshold 0.02 (area fraction) was derived by observing Null mean_area = 0.0094 vs Seen mean_area = 0.054 from the test distribution. This is benchmark-specific tuning = test-set overfitting. **Not a valid method.**
167
-
168
- Both oracle approaches are useful for diagnostic analysis only. The principled fix must be structural.
169
-
170
- ### Principled Fix: e0-Modulated Reward
171
-
172
- **Key insight**: decouple *existence* from *quality*. Use the initial mask area as a proxy for the prior probability that a real target exists.
173
-
174
- ```python
175
- e0 = stopgrad( sigmoid(lrm_init / area_temp).mean() ) # R_area_soft at q_init
176
- R_task = λ_iou · e0 · R_iou_pred − λ_area · R_area_soft
177
- ```
178
-
179
- **Why stopgrad on e0 is critical:**
180
- - Without stopgrad: gradients flow through e0 → optimizer first inflates area to increase e0, then uses the higher e0 to justify larger R_iou reward ("area gaming").
181
- - With stopgrad: e0 is a fixed scalar from the initialization. Gradients only flow through the explicit terms `R_iou_pred` and `R_area_soft`.
182
-
183
- **Effect by split:**
184
-
185
- | Split | mean e0 | Effective λ_iou = 0.6·e0 | Behavior |
186
- |-------|---------|--------------------------|----------|
187
- | Null | 0.037 | 0.022 | Area penalty dominates → conservative |
188
- | Seen | 0.120 | 0.072 | Balanced optimization |
189
- | Unseen | 0.150 | 0.090 | Full optimization drive |
190
-
191
- The 3.2× e0 ratio (Unseen/Null) arises naturally from the initial mask size, providing automatic split-specific optimization strength without any threshold tuning.
192
-
193
- **Implementation fix also addressed (best_q tracking bug):**
194
- Before fix, `q_{N+1}` (post-step) was evaluated using `lrm/iou` from `q_N` (pre-step), corrupting best_q selection. Fixed by adding a fresh `no_grad` forward after each `optimizer.step()`.
195
-
196
- ### Quick Validation Results (200 samples each, e0 modulation)
197
-
198
- #### Null split (S metric, lower is better)
199
-
200
- | Method | S metric | Δ relative |
201
- |--------|----------|-----------|
202
- | Baseline | 0.0619 | — |
203
- | q-LTPO S1 (no e0) | 0.0646 | +4.4% |
204
- | **q-LTPO S1 (e0)** | **0.0634** | **+2.4%** |
205
-
206
- Diagnostic stats with e0:
207
- ```
208
- acceptance rate : 1.000
209
- mean e0 : 0.0372
210
- reward_gain p10/50/90: 0.0 / 0.0000 / +0.0123 ← p50=0 means >50% of samples frozen
211
- mean drift : 0.4962 ← down from ~0.8 without e0
212
- area (hard) init→best: 0.0094 → 0.0098 ← minimal area expansion
213
- reward↑ & area+20%↑ : 0.040 ← low Null-safety risk
214
- ```
215
-
216
- #### Seen split (mIoU, higher is better)
217
-
218
- | Method | mIoU | F | Δ mIoU |
219
- |--------|------|---|--------|
220
- | Baseline | 0.7483 | — | — |
221
- | q-LTPO S1 (no e0) | 0.7618 | — | +0.0136 |
222
- | **q-LTPO S1 (e0)** | **0.7634** | — | **+0.0151** |
223
-
224
- Diagnostic stats with e0:
225
- ```
226
- mean e0 : 0.1200
227
- reward_gain p10/50/90: +0.0026 / +0.0181 / +0.0944
228
- mean drift : 0.5225
229
- area (hard) init→best: 0.054 → (slight increase)
230
- ```
231
-
232
- #### Unseen split (mIoU, higher is better)
233
-
234
- | Method | mIoU | F | Δ mIoU |
235
- |--------|------|---|--------|
236
- | Baseline | 0.6761 | 0.7776 | — |
237
- | **q-LTPO S1 (e0)** | **0.6929** | **0.7765** | **+0.0168** |
238
-
239
- Diagnostic stats with e0:
240
- ```
241
- acceptance rate : 1.000
242
- mean e0 : 0.1506
243
- reward_gain p10/50/90: +0.0011 / +0.0055 / +0.0293
244
- mean drift : 0.6666
245
- R_iou_pred init→best : 0.8029 → 0.8802
246
- area (hard) init→best: 0.0635 → 0.0650
247
- reward↑ & area+20%↑ : 0.125
248
- ```
249
-
250
- ### Analysis: e0 is a Pareto Improvement
251
-
252
- Three conditions for Pareto improvement all satisfied on quick validation:
253
-
254
- 1. **Null safer**: degradation halved (+4.4% → +2.4%). p50 reward_gain = 0.0000, meaning >50% of Null samples produce `best_q ≈ q_init`.
255
- 2. **Seen maintained and slightly improved**: +0.0151 vs +0.0136 without e0.
256
- 3. **Unseen not hurt — gains even larger**: +0.0168 > Seen +0.0151. The "harder positives suppressed" failure mode did not materialize.
257
-
258
- **e0 hierarchy confirms split-level discriminability:**
259
- ```
260
- Null (0.037) << Seen (0.120) < Unseen (0.150)
261
- ```
262
- The ordering is sensible: Null frames have small/empty initial masks → low e0. Unseen e0 slightly exceeds Seen, possibly because the model produces slightly larger (less specific) masks on novel object-sentence combinations.
263
-
264
- **Residual Null degradation (+2.4%) assessment**: Acceptable for now. The absolute magnitude is +0.0015 in S metric, while Seen/Unseen absolute gains are 10–11× larger. The residual originates from a small tail of Null samples where e0 is still large enough to permit some mask expansion. Further suppression (e.g., e0², sqrt(e0+ε)) risks hurting harder positives and should only be explored after full-set confirmation.
265
-
266
- ---
267
-
268
- ## Summary and Comparison
269
-
270
- ### Pre-e0 (original q-LTPO Stage 1, full Unseen)
271
-
272
- | Method | Unseen mIoU | Δ vs Baseline | Relative to ES-LTPO |
273
- |--------|-------------|---------------|----------------------|
274
- | Baseline | 0.6990 | — | — |
275
- | ES-LTPO | 0.7055 | +0.0065 | 1× |
276
- | **q-LTPO Stage 1** | **0.7285** | **+0.0295** | **4.5×** |
277
-
278
- ### e0-Modulated Stage 1 (quick validation, 200 samples)
279
-
280
- | Split | Baseline | e0-Stage1 | Δ | e0 |
281
- |-------|----------|-----------|---|-----|
282
- | Null (S↓) | 0.0619 | 0.0634 | +2.4% (rel) | 0.037 |
283
- | Seen | 0.7483 | 0.7634 | +0.0151 | 0.120 |
284
- | Unseen | 0.6761 | 0.6929 | +0.0168 | 0.150 |
285
-
286
- q-LTPO-autograd with e0 modulation is the current primary method candidate. It achieves first-order gradient-based optimization with automatic Null-safety via the initial-area existence prior, without any test-set-derived thresholds.
287
-
288
- ---
289
-
290
- ## Hyperparameter Configurations
291
-
292
- ### ES-LTPO (Method 1)
293
- ```python
294
- LTPOConfig(
295
- T=5, num_anchors=4,
296
- sigma_schedule=[0.10, 0.08, 0.06, 0.04, 0.02],
297
- eta_scale=0.5,
298
- lambda1=0.3, lambda2=0.4, lambda3=1.0, lambda4=0.3,
299
- beta=0.5, gate_delta=0.0, trust_delta=None,
300
- )
301
- ```
302
-
303
- ### q-LTPO Stage 1 with e0 (current primary candidate)
304
- ```python
305
- QLTPOConfig(
306
- stage=1, T=5, num_anchors=4,
307
- lr=0.0, # auto: 0.01 × RMS(q_init)
308
- max_drift=0.0, # auto: 0.5 × ||q_init||
309
- lambda_iou=0.6, lambda_area=0.2,
310
- lambda_reg=0.01, area_temp=5.0,
311
- gate_delta=0.0,
312
- e0_modulation="identity", # e0 = R_area_soft(q_init), stopgrad
313
- e0_eps=1e-4,
314
- # oracle-only fields (disabled, not used in final method):
315
- null_area_threshold=0.02,
316
- null_gate_delta=0.0,
317
- )
318
- ```
319
-
320
- ### Full Unseen Evaluation with e0 (1656 samples)
321
-
322
- | Method | mIoU | F | Δ mIoU |
323
- |--------|------|---|--------|
324
- | Baseline | 0.6990 | 0.7926 | — |
325
- | q-LTPO S1 (no e0) | 0.7285 | 0.8013 | +0.0295 (+4.22%) |
326
- | **q-LTPO S1 (e0)** | **0.7240** | **0.7985** | **+0.0250 (+3.56%)** |
327
-
328
- e0 版本相比 no-e0 版本 mIoU 略低 (-0.0045),但 Null 安全性更好。F 与 mIoU 的提升比例基本一致(约 60%)。
329
-
330
- **全量评估状态(更新):**
331
-
332
- | Split | Baseline | q-LTPO S1 (e0) | Δ | Status |
333
- |-------|----------|----------------|---|--------|
334
- | Unseen (full, 1656) | 0.6990 / 0.7926 | 0.7240 / 0.7985 | +3.56% mIoU | ✅ Done |
335
- | Seen (full) | — | — | — | Pending |
336
- | Null (full, S↓) | 0.0120 | — | — | Pending |
337
-
338
- ---
339
-
340
- ## Direction B: Boundary Precision Experiments(已结束,结论为失败)
341
-
342
- ### B-Step1: Multimask Post-Processing(彻底失败)
343
-
344
- 用 SAM 多 mask 输出(K=3)替换单 mask 解码,分别用 iou_pred 和 Sobel edge score 选最佳候选。
345
-
346
- | Method | mIoU | F | ΔF vs s1 |
347
- |--------|------|---|----------|
348
- | s1 (single mask) | 0.6979 | 0.8024 | — |
349
- | s1_mm (iou_pred selection) | 0.6979 | 0.7917 | -0.0107 |
350
- | s1_mm_edge (Sobel selection) | 0.5715 | 0.6820 | -0.1204 |
351
-
352
- **根本原因:** SAM 内部的单 mask 选择已经最优;外部重选更差。Sobel 在 1024×1024 归一化空间中选到纹理碎片而非语义目标,灾难性失败。
353
-
354
- ### B1: 非对称面积膨胀惩罚(机制性无效)
355
-
356
- 假设:LTPO 导致 mask 向非目标区域膨胀(精度下降),加惩罚项压制。
357
-
358
- **实验结论:假设错误。** LTPO 期间 soft area 实际在下降(-16%)而非上升:
359
-
360
- ```
361
- soft area: 0.1507 → 0.1267 (-16%) ← background logits 更负
362
- hard area: 0.0635 → 0.0650 (+2.4%) ← 实际 mask 区域微增
363
- ```
364
-
365
- **"mask sharpening" 现象:** Adam 在 R_iou_pred 驱动下使 logit 更双峰化(前景更正、背景更负),soft area 因 93% 背景像素的贡献减少而下降。B1 惩罚的前提条件(soft area 上升)从未发生:
366
-
367
- ```
368
- B1 activation rate : 0.025 ← 仅 2.5% 样本触发
369
- B1 mean excess : 0.00002 ← 可忽略
370
- ```
371
-
372
- **结论:** Direction B 从多 mask 选择到面积约束全部失败,不再追求。F-score 滞后于 mIoU 的根本原因不是 mask 精度,而是 reward 代理信号质量问题(见 Path A)。
373
-
374
- ---
375
-
376
- ## Direction II: Frame-Adaptive Token Optimization(初步探索,待后续)
377
-
378
- ### 方法设计
379
-
380
- 将单一共享 token q 扩展为视频 token 轨迹:
381
-
382
- ```
383
- q_t = q_global + delta_t
384
- ```
385
-
386
- 其中 q_global 是全局共享 token,delta_t 是每个 anchor 帧的局部残差,初始化为 0。联合优化:
387
-
388
- ```
389
- max Σ_t [λ_iou · e0_t · R_iou(q_t) - λ_area · R_area(q_t)]
390
- - λ_residual · ||delta||² - λ_smooth · Σ_t ||delta_t - delta_{t+1}||² - λ_reg · ||q_global - q_init||²
391
- ```
392
-
393
- 每个 anchor 帧使用各自的 e0_t(per-frame 存在先验)。delta_t 受 hard clip 约束:`||delta_t|| ≤ scale × ||q_init||`。
394
-
395
- ### 200-sample Probe Results(Unseen split)
396
-
397
- | Method | mIoU | F | reward gain p50 | delta ‖Δ‖ |
398
- |--------|------|---|-----------------|-----------|
399
- | baseline | 0.6745 | 0.7763 | — | — |
400
- | s1 | 0.6945 | 0.7773 | +0.0053 | — |
401
- | fa_base (无约束) | 0.6945 | 0.7711 | +0.0112 | 1.675 |
402
- | fa_smooth (λ_smooth=0.01) | 0.6960 | 0.7731 | +0.0104 | 1.488 |
403
- | fa_c03 (delta clip 0.3×) | 0.6959 | 0.7722 | +0.0112 | — |
404
-
405
- ### 关键发现
406
-
407
- **Reward-metric gap(核心问题):**
408
- ```
409
- reward gain p50: s1 = +0.0053 fa_c03 = +0.0112 (fa 高 2.1×)
410
- R_iou_pred 提升: s1 +0.077 fa_c03 +0.114
411
- 实际 mIoU 提升: s1 +2.96% fa_c03 +3.17% (仅差 0.21%)
412
- ```
413
- fa 拿到了多得多的 reward,但 mIoU 几乎没有额外提升,F 还略降。
414
-
415
- **结论:** 瓶颈不是优化结构,而是 R_iou_pred 本身的任务相关性不足。R_iou_pred 衡量"mask 有多干净",不衡量"mask 是否包含正确的音频目标"。所有架构变体(单 token / frame-adaptive)都受同一个天花板限制。
416
-
417
- Direction II 不在旧 reward 下继续调参,等 Path A(新 reward)有正向信号后再考虑是否重新引入。
418
-
419
- ---
420
-
421
- ## Path A: AVT-Aware Reward 重设计
422
-
423
- ### 动机
424
-
425
- Ref-AVS 中的 referent 不一定是发声体本身(可能是拿着发声物体的人、与声源相关的对象)。纯音频对齐 reward 会将优化推向 sound source 而非 text 指向的 referent。需要 audio + text + global visual context 共同定义的 referent consistency。
426
-
427
- ### AVT Proxy Reward 设计
428
-
429
- **核心洞察:** Fseg(= q_init)已经是 audio + video + text 的多模态融合 token,可直接作为 frozen AVT teacher。
430
-
431
- ```python
432
- R_avt = mean_t cos(z_in_t, q_init)
433
- R_avt_c = mean_t [cos(z_in_t, q_init) - β · cos(z_out_t, q_init)]
434
- ```
435
-
436
- - `z_in_t`:anchor 帧 t 的 soft-masked 图像特征(SAM 256-dim 空间)
437
- - `q_init`:frozen Fseg(AVT anchor,不参与优化梯度)
438
- - R_avt 高 → mask 区域与查询 referent 对齐;R_avt 低 → mask 指向错误目标
439
-
440
- 与 Stage 2 的区别:Stage 2 用当前 q(移动)对齐 z_in(当前 mask),导致自我确认偏差;R_avt 用 q_init(固定)作为 teacher,打破偏差。
441
-
442
- ### Step A0: Reward–Metric Correlation Study(下一步要做)
443
-
444
- **目的:** 在进入 full optimization 之前,先用数据验证新 reward 是否比 R_iou_pred 更能预测真实 metric 变化。
445
-
446
- **实验设置(200 samples, Unseen split):**
447
- 对每个(视频,segment)样本:
448
- 1. Baseline decode → IoU_base, F_base
449
- 2. q-LTPO s1 → q_best;记录 reward_gain、r_avt_gain、r_avt_c_gain(均在 q_ltpo_autograd 内计算)
450
- 3. LTPO decode → IoU_ltpo, F_ltpo
451
- 4. Δ = LTPO - baseline
452
-
453
- 输出 Pearson 相关表:
454
-
455
- ```
456
- Pearson r with ΔmIoU:
457
- R_iou_pred_gain : +0.xxx ← 当前 proxy
458
- R_avt_gain : +0.xxx ← cos(z_in, q_init)
459
- R_avt_c_gain : +0.xxx ← 对比版本
460
-
461
- Wrong direction (gain>0 但 Δ<0):
462
- R_iou / ΔmIoU : 0.xxx
463
- R_avt / ΔmIoU : 0.xxx
464
- ```
465
-
466
- **运行命令:**
467
- ```bash
468
- python load_model.py --eval_split test_u --max_eval_rows 200
469
- ```
470
-
471
- **判断标准:**
472
- - `r(R_avt, ΔmIoU) > r(R_iou, ΔmIoU)` → AVT proxy 更好,进入 Step A1
473
- - 两者相近 → reward 本身不是瓶颈,需要重新审视
474
- - `R_avt / ΔF wrong frac` 明显低于 `R_iou / ΔF` → AVT 能解释 F-score 不跟随 mIoU 的现象
475
-
476
- ### Step A1: Hybrid Reward(Step A0 验证后)
477
-
478
- ```
479
- R_task = λ1 · e0 · R_iou_pred + λ2 · R_avt_c - λ3 · R_area_soft
480
- ```
481
-
482
- - R_iou_pred 继续负责 mask quality(shape quality signal)
483
- - R_avt_c 负责 referent correctness(task-specific signal)
484
- - 两者结合才有可能同时维持 IoU 并提升 F
485
-
486
- 候选权重组合:`λ1=0.6, λ2=0.5, λ3=0.2`(AVT 作为辅助项,不完全取代 R_iou)。
487
-
488
- 如果 Step A1 有正向信号,再考虑将 Direction II(frame-adaptive)和新 reward 结合。
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
analyze_d2_csv.py DELETED
@@ -1,239 +0,0 @@
1
- import argparse
2
- import csv
3
- import math
4
- from collections import defaultdict
5
-
6
- import numpy as np
7
-
8
-
9
- def parse_args():
10
- parser = argparse.ArgumentParser(description="Analyze D2 frame-level CSV.")
11
- parser.add_argument("--csv", required=True, help="Path to d2_llm_space.py or d2_basic.py CSV output.")
12
- parser.add_argument("--beta", type=float, default=1.0)
13
- parser.add_argument("--failure_iou", type=float, default=0.5)
14
- parser.add_argument("--bottom_frac", type=float, default=0.2)
15
- parser.add_argument("--pr_points", type=int, default=10)
16
- return parser.parse_args()
17
-
18
-
19
- def read_rows(path, beta):
20
- rows = []
21
- with open(path, newline="") as f:
22
- reader = csv.DictReader(f)
23
- for row in reader:
24
- row_beta = float(row["beta"])
25
- if abs(row_beta - beta) > 1e-8:
26
- continue
27
- q_col = "h_type" if "h_type" in row else "q_type"
28
- rows.append(
29
- {
30
- "sample_idx": int(row["sample_idx"]),
31
- "frame": int(row["frame"]),
32
- "anchor_type": row[q_col],
33
- "s_pred": float(row["s_pred"]),
34
- "s_gt": float(row["s_gt"]),
35
- "frame_iou": float(row["frame_iou"]),
36
- "iou_pred": float(row["iou_pred"]),
37
- "pred_area": float(row["pred_area"]),
38
- "gt_area": float(row["gt_area"]),
39
- }
40
- )
41
- if not rows:
42
- raise RuntimeError(f"No rows found for beta={beta} in {path}")
43
- return rows
44
-
45
-
46
- def corr(x, y):
47
- x = np.asarray(x, dtype=np.float64)
48
- y = np.asarray(y, dtype=np.float64)
49
- if len(x) < 2 or np.std(x) < 1e-12 or np.std(y) < 1e-12:
50
- return float("nan")
51
- return float(np.corrcoef(x, y)[0, 1])
52
-
53
-
54
- def residualize(y, controls):
55
- y = np.asarray(y, dtype=np.float64)
56
- cols = [np.ones(len(y), dtype=np.float64)]
57
- for control in controls:
58
- cols.append(np.asarray(control, dtype=np.float64))
59
- x = np.stack(cols, axis=1)
60
- coef, *_ = np.linalg.lstsq(x, y, rcond=None)
61
- return y - x @ coef
62
-
63
-
64
- def r2_score(y, y_pred):
65
- y = np.asarray(y, dtype=np.float64)
66
- y_pred = np.asarray(y_pred, dtype=np.float64)
67
- ss_res = np.sum((y - y_pred) ** 2)
68
- ss_tot = np.sum((y - y.mean()) ** 2)
69
- if ss_tot < 1e-12:
70
- return float("nan")
71
- return float(1.0 - ss_res / ss_tot)
72
-
73
-
74
- def linear_r2(y, features):
75
- y = np.asarray(y, dtype=np.float64)
76
- cols = [np.ones(len(y), dtype=np.float64)]
77
- for feature in features:
78
- cols.append(np.asarray(feature, dtype=np.float64))
79
- x = np.stack(cols, axis=1)
80
- coef, *_ = np.linalg.lstsq(x, y, rcond=None)
81
- return r2_score(y, x @ coef)
82
-
83
-
84
- def real_rows(rows):
85
- return [r for r in rows if r["anchor_type"] == "real"]
86
-
87
-
88
- def bottom_failure_enrichment(rows, failure_iou, bottom_frac):
89
- rr = real_rows(rows)
90
- n = len(rr)
91
- k = max(1, int(round(n * bottom_frac)))
92
- sorted_rows = sorted(rr, key=lambda r: r["s_pred"])
93
- bottom = sorted_rows[:k]
94
- baseline_rate = np.mean([r["frame_iou"] < failure_iou for r in rr])
95
- bottom_rate = np.mean([r["frame_iou"] < failure_iou for r in bottom])
96
- total_failures = sum(r["frame_iou"] < failure_iou for r in rr)
97
- covered_failures = sum(r["frame_iou"] < failure_iou for r in bottom)
98
- recall = covered_failures / max(total_failures, 1)
99
- enrichment = bottom_rate / max(baseline_rate, 1e-12)
100
- return {
101
- "n": n,
102
- "k": k,
103
- "baseline_failure_rate": baseline_rate,
104
- "bottom_failure_rate": bottom_rate,
105
- "bottom_failure_recall": recall,
106
- "enrichment": enrichment,
107
- "total_failures": total_failures,
108
- }
109
-
110
-
111
- def pr_curve(rows, failure_iou, points):
112
- rr = sorted(real_rows(rows), key=lambda r: r["s_pred"])
113
- total_failures = sum(r["frame_iou"] < failure_iou for r in rr)
114
- out = []
115
- for frac in np.linspace(0.05, 1.0, points):
116
- k = max(1, int(round(len(rr) * frac)))
117
- selected = rr[:k]
118
- failures = sum(r["frame_iou"] < failure_iou for r in selected)
119
- precision = failures / k
120
- recall = failures / max(total_failures, 1)
121
- out.append((frac, precision, recall))
122
- return out
123
-
124
-
125
- def margin_rows(rows):
126
- grouped = defaultdict(dict)
127
- for r in rows:
128
- key = (r["sample_idx"], r["frame"])
129
- grouped[key][r["anchor_type"]] = r
130
-
131
- out = []
132
- for key, group in grouped.items():
133
- if "real" not in group:
134
- continue
135
- controls = [group[name]["s_pred"] for name in ("shuffled", "wrong_ref") if name in group]
136
- if not controls:
137
- continue
138
- real = group["real"]
139
- item = dict(real)
140
- item["s_margin"] = real["s_pred"] - max(controls)
141
- out.append(item)
142
- return out
143
-
144
-
145
- def bottom_failure_enrichment_for_score(rows, score_key, failure_iou, bottom_frac):
146
- n = len(rows)
147
- k = max(1, int(round(n * bottom_frac)))
148
- sorted_rows = sorted(rows, key=lambda r: r[score_key])
149
- bottom = sorted_rows[:k]
150
- baseline_rate = np.mean([r["frame_iou"] < failure_iou for r in rows])
151
- bottom_rate = np.mean([r["frame_iou"] < failure_iou for r in bottom])
152
- total_failures = sum(r["frame_iou"] < failure_iou for r in rows)
153
- covered_failures = sum(r["frame_iou"] < failure_iou for r in bottom)
154
- return {
155
- "n": n,
156
- "k": k,
157
- "baseline_failure_rate": baseline_rate,
158
- "bottom_failure_rate": bottom_rate,
159
- "bottom_failure_recall": covered_failures / max(total_failures, 1),
160
- "enrichment": bottom_rate / max(baseline_rate, 1e-12),
161
- }
162
-
163
-
164
- def main():
165
- args = parse_args()
166
- rows = read_rows(args.csv, args.beta)
167
- rr = real_rows(rows)
168
-
169
- print(f"CSV: {args.csv}")
170
- print(f"beta: {args.beta}")
171
- print(f"real frames: {len(rr)}")
172
- print(f"failure definition: frame_iou < {args.failure_iou}")
173
-
174
- print("\nReal s_pred Correlations")
175
- print(f"corr(s_pred, frame_iou): {corr([r['s_pred'] for r in rr], [r['frame_iou'] for r in rr]):+.4f}")
176
- print(f"corr(s_pred, iou_pred): {corr([r['s_pred'] for r in rr], [r['iou_pred'] for r in rr]):+.4f}")
177
- print(f"corr(s_pred, pred_area): {corr([r['s_pred'] for r in rr], [r['pred_area'] for r in rr]):+.4f}")
178
-
179
- s_pred_values = [r["s_pred"] for r in rr]
180
- frame_iou_values = [r["frame_iou"] for r in rr]
181
- iou_pred_values = [r["iou_pred"] for r in rr]
182
- pred_area_values = [r["pred_area"] for r in rr]
183
- gt_area_values = [r["gt_area"] for r in rr]
184
- partial_iou_pred = corr(
185
- residualize(s_pred_values, [iou_pred_values]),
186
- residualize(frame_iou_values, [iou_pred_values]),
187
- )
188
- partial_iou_area = corr(
189
- residualize(s_pred_values, [iou_pred_values, pred_area_values]),
190
- residualize(frame_iou_values, [iou_pred_values, pred_area_values]),
191
- )
192
- partial_iou_area_gt = corr(
193
- residualize(s_pred_values, [iou_pred_values, pred_area_values, gt_area_values]),
194
- residualize(frame_iou_values, [iou_pred_values, pred_area_values, gt_area_values]),
195
- )
196
- r2_iou_pred = linear_r2(frame_iou_values, [iou_pred_values])
197
- r2_iou_pred_s = linear_r2(frame_iou_values, [iou_pred_values, s_pred_values])
198
- r2_iou_pred_area = linear_r2(frame_iou_values, [iou_pred_values, pred_area_values])
199
- r2_iou_pred_area_s = linear_r2(frame_iou_values, [iou_pred_values, pred_area_values, s_pred_values])
200
-
201
- print("\nPartial Correlation / Residual Gain")
202
- print(f"partial corr(s_pred, frame_iou | iou_pred): {partial_iou_pred:+.4f}")
203
- print(f"partial corr(s_pred, frame_iou | iou_pred,pred_area): {partial_iou_area:+.4f}")
204
- print(f"partial corr(s_pred, frame_iou | iou_pred,pred_area,gt_area): {partial_iou_area_gt:+.4f}")
205
- print(f"R2 frame_iou ~ iou_pred: {r2_iou_pred:.4f}")
206
- print(f"R2 frame_iou ~ iou_pred + s_pred: {r2_iou_pred_s:.4f} (gain {r2_iou_pred_s - r2_iou_pred:+.4f})")
207
- print(f"R2 frame_iou ~ iou_pred + pred_area: {r2_iou_pred_area:.4f}")
208
- print(f"R2 frame_iou ~ iou_pred + pred_area + s_pred: {r2_iou_pred_area_s:.4f} (gain {r2_iou_pred_area_s - r2_iou_pred_area:+.4f})")
209
-
210
- stats = bottom_failure_enrichment(rows, args.failure_iou, args.bottom_frac)
211
- print("\nBottom-k Failure Enrichment")
212
- print(f"bottom_frac: {args.bottom_frac:.2f} ({stats['k']}/{stats['n']} frames)")
213
- print(f"total failures: {stats['total_failures']}")
214
- print(f"random/baseline failure rate: {stats['baseline_failure_rate']:.4f}")
215
- print(f"bottom-s_pred failure rate: {stats['bottom_failure_rate']:.4f}")
216
- print(f"bottom-s_pred failure recall: {stats['bottom_failure_recall']:.4f}")
217
- print(f"enrichment: {stats['enrichment']:.2f}x")
218
-
219
- print("\nPR Curve Summary")
220
- print("selected_frac | precision | recall")
221
- for frac, precision, recall in pr_curve(rows, args.failure_iou, args.pr_points):
222
- print(f"{frac:.2f} | {precision:.4f} | {recall:.4f}")
223
-
224
- mr = margin_rows(rows)
225
- if mr:
226
- print("\nOffline Margin-D2")
227
- print(f"margin frames: {len(mr)}")
228
- print(f"corr(s_margin, frame_iou): {corr([r['s_margin'] for r in mr], [r['frame_iou'] for r in mr]):+.4f}")
229
- print(f"corr(s_margin, pred_area): {corr([r['s_margin'] for r in mr], [r['pred_area'] for r in mr]):+.4f}")
230
- mstats = bottom_failure_enrichment_for_score(mr, "s_margin", args.failure_iou, args.bottom_frac)
231
- print(f"bottom-s_margin failure rate: {mstats['bottom_failure_rate']:.4f}")
232
- print(f"bottom-s_margin failure recall: {mstats['bottom_failure_recall']:.4f}")
233
- print(f"margin enrichment: {mstats['enrichment']:.2f}x")
234
- else:
235
- print("\nOffline Margin-D2 skipped: shuffled/wrong_ref controls not available.")
236
-
237
-
238
- if __name__ == "__main__":
239
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build_rpb_dev_manifest.py DELETED
@@ -1,71 +0,0 @@
1
- import argparse
2
- import json
3
- import os
4
- import random
5
-
6
- import pandas as pd
7
-
8
-
9
- def sample_indices(size, count, seed):
10
- if count <= 0:
11
- return []
12
- if count > size:
13
- raise ValueError(f"Requested {count} samples from a split of size {size}")
14
- rng = random.Random(seed)
15
- indices = list(range(size))
16
- rng.shuffle(indices)
17
- selected = sorted(indices[:count])
18
- return selected
19
-
20
-
21
- def main():
22
- parser = argparse.ArgumentParser(description="Build a fixed subset manifest for RPB dev experiments.")
23
- parser.add_argument("--metadata", type=str, default="/workspace/SimToken/data/metadata.csv")
24
- parser.add_argument("--output", type=str, required=True)
25
- parser.add_argument("--seed", type=int, default=42)
26
- parser.add_argument("--train_rows", type=int, default=0)
27
- parser.add_argument("--test_s_rows", type=int, default=200)
28
- parser.add_argument("--test_u_rows", type=int, default=200)
29
- parser.add_argument("--test_n_rows", type=int, default=200)
30
- args = parser.parse_args()
31
-
32
- metadata = pd.read_csv(args.metadata, header=0)
33
- split_sizes = {
34
- "train": int((metadata["split"] == "train").sum()),
35
- "test_s": int((metadata["split"] == "test_s").sum()),
36
- "test_u": int((metadata["split"] == "test_u").sum()),
37
- "test_n": int((metadata["split"] == "test_n").sum()),
38
- }
39
-
40
- manifest = {
41
- "train": sample_indices(split_sizes["train"], args.train_rows, args.seed),
42
- "test_s": sample_indices(split_sizes["test_s"], args.test_s_rows, args.seed + 1),
43
- "test_u": sample_indices(split_sizes["test_u"], args.test_u_rows, args.seed + 2),
44
- "test_n": sample_indices(split_sizes["test_n"], args.test_n_rows, args.seed + 3),
45
- }
46
-
47
- # Remove empty entries so train.py only subsets the splits we intentionally fix.
48
- manifest = {key: value for key, value in manifest.items() if value}
49
-
50
- os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
51
- with open(args.output, "w", encoding="utf-8") as f:
52
- json.dump(
53
- {
54
- "metadata": {
55
- "seed": args.seed,
56
- "split_sizes": split_sizes,
57
- "source_metadata": os.path.abspath(args.metadata),
58
- },
59
- "subsets": manifest,
60
- },
61
- f,
62
- indent=2,
63
- )
64
-
65
- print(f"saved subset manifest to {args.output}")
66
- for split_name, indices in manifest.items():
67
- print(f"{split_name}: {len(indices)} samples")
68
-
69
-
70
- if __name__ == "__main__":
71
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cache_q_features.py DELETED
@@ -1,125 +0,0 @@
1
- import json
2
- import os
3
- from functools import partial
4
- from itertools import islice
5
-
6
- import torch
7
- import transformers
8
- from torch.utils.data import DataLoader
9
- from tqdm import tqdm
10
-
11
- from configs import args
12
- from datasets import REFAVS
13
- from decoder_invariance_check import build_model, set_seed
14
- from load_model import collate_fn, dict_to_cuda
15
-
16
-
17
- def _jsonable_size(size):
18
- if isinstance(size, torch.Tensor):
19
- return [int(x) for x in size.detach().cpu().tolist()]
20
- return [int(x) for x in size]
21
-
22
-
23
- def main():
24
- set_seed(42)
25
- torch.set_grad_enabled(False)
26
-
27
- tokenizer = transformers.AutoTokenizer.from_pretrained(
28
- args.mllm,
29
- cache_dir=None,
30
- model_max_length=2048,
31
- padding_side="right",
32
- use_fast=False,
33
- )
34
- tokenizer.pad_token = tokenizer.unk_token
35
- tokenizer.add_tokens("[SEG]")
36
- seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
37
-
38
- dataset = REFAVS(args.cache_split, args, tokenizer, input_type="refer")
39
- loader = DataLoader(
40
- dataset,
41
- batch_size=1,
42
- shuffle=False,
43
- num_workers=0,
44
- collate_fn=partial(collate_fn, tokenizer=tokenizer),
45
- )
46
-
47
- split_root = os.path.join(args.cache_root, args.cache_split)
48
- os.makedirs(split_root, exist_ok=True)
49
- index_path = os.path.join(split_root, "index.jsonl")
50
- if os.path.exists(index_path) and not args.overwrite_cache:
51
- raise FileExistsError(
52
- f"{index_path} already exists. Pass --overwrite_cache to rebuild it."
53
- )
54
-
55
- limit = args.max_eval_rows if args.max_eval_rows > 0 else len(dataset)
56
- print(f"cache split={args.cache_split} | samples={min(limit, len(dataset))}")
57
- print(f"cache root: {split_root}")
58
-
59
- model = build_model(tokenizer, seg_token_idx)
60
- model.eval()
61
-
62
- rows = []
63
- for sample_idx, batch in enumerate(
64
- tqdm(islice(loader, limit), total=min(limit, len(dataset)), desc=f"Caching {args.cache_split}")
65
- ):
66
- batch = dict_to_cuda(batch)
67
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
68
- output = model.forward(
69
- images=batch["images"],
70
- images_clip=batch["images_clip"],
71
- audio_features=batch["audio_feats"],
72
- image_features=batch["image_feats"],
73
- input_ids=batch["input_ids"],
74
- labels=batch["labels"],
75
- attention_masks=batch["attention_masks"],
76
- masks_list=batch["masks"],
77
- resize_list=batch["resizes"],
78
- orgsize_list=batch["orgsizes"],
79
- conversation_list=batch["convs"],
80
- refs_num=batch["refs_num"],
81
- fids=batch["fids"],
82
- vids=batch["vids"],
83
- contrast=args.ct_weight,
84
- ref_ids=batch["ref_ids"],
85
- inference=True,
86
- )
87
-
88
- cache_name = f"{sample_idx:06d}.pt"
89
- cache_path = os.path.join(split_root, cache_name)
90
- item = {
91
- "sample_idx": sample_idx,
92
- "vid": batch["vids"][0],
93
- "refs": batch["refs"][0],
94
- "fids": [int(x) for x in batch["fids"][0]],
95
- "resize": _jsonable_size(batch["resizes"][0]),
96
- "orgsize": _jsonable_size(batch["orgsizes"][0]),
97
- "q": output["seg_embeddings"][0].detach().cpu().float(),
98
- }
99
- torch.save(item, cache_path)
100
- rows.append(
101
- {
102
- "sample_idx": sample_idx,
103
- "path": cache_name,
104
- "vid": item["vid"],
105
- "refs": item["refs"],
106
- "fids": item["fids"],
107
- "resize": item["resize"],
108
- "orgsize": item["orgsize"],
109
- "num_seg": int(item["q"].shape[0]),
110
- }
111
- )
112
-
113
- if not rows:
114
- raise RuntimeError("No samples were cached.")
115
-
116
- with open(index_path, "w") as f:
117
- for row in rows:
118
- f.write(json.dumps(row) + "\n")
119
-
120
- print(f"cached samples: {len(rows)}")
121
- print(f"saved index: {index_path}")
122
-
123
-
124
- if __name__ == "__main__":
125
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cache_q_smoke/test_s/000000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5f85d7cf7b83caf6fedb153a2cea2b36dd144ee3c0e34039483e20d208ea92d3
3
- size 2327
 
 
 
 
cache_q_smoke/test_s/index.jsonl DELETED
@@ -1 +0,0 @@
1
- {"sample_idx": 0, "path": "000000.pt", "vid": "-3ABOVeVmpU_136000_146000", "refs": ["the object that keeps making sound at all times"], "fids": [1], "resize": [576, 1024], "orgsize": [720, 1280], "num_seg": 1}
 
 
checkpoints/rpb_dev_mixed_pm_only_a018_wm005.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2c1facc9eac5ffdfd12c97d252af2c8eedc4e526a53931d301b0ef4bed698213
3
- size 30841132766
 
 
 
 
checkpoints/rpb_dev_pm_only_a018.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:33e8b6251c69d7d4de055b488a2f2345eece1991831a2f08ce5f1d1cb795ae5f
3
- size 30841115170
 
 
 
 
checkpoints/rpb_probe_eval_directional_pm_only_a02.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6dc5cd6b02f5d54a026694f6a1217f46137fdaa4499a71fa7b9bd95ede17da6c
3
- size 30841141852
 
 
 
 
d2_basic.py DELETED
@@ -1,340 +0,0 @@
1
- import csv
2
- import math
3
- import os
4
- from functools import partial
5
-
6
- import numpy as np
7
- import torch
8
- import torch.nn.functional as F
9
- import transformers
10
- from torch.utils.data import DataLoader
11
-
12
- from configs import args
13
- from datasets import REFAVS
14
- from decoder_invariance_check import build_model, set_seed
15
- from load_model import collate_fn, dict_to_cuda
16
-
17
-
18
- def make_loader(tokenizer):
19
- dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
20
- return DataLoader(
21
- dataset,
22
- batch_size=1,
23
- shuffle=False,
24
- num_workers=0,
25
- collate_fn=partial(collate_fn, tokenizer=tokenizer),
26
- )
27
-
28
-
29
- def build_tokenizer():
30
- tokenizer = transformers.AutoTokenizer.from_pretrained(
31
- args.mllm,
32
- cache_dir=None,
33
- model_max_length=2048,
34
- padding_side="right",
35
- use_fast=False,
36
- )
37
- tokenizer.pad_token = tokenizer.unk_token
38
- tokenizer.add_tokens("[SEG]")
39
- seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
40
- return tokenizer, seg_token_idx
41
-
42
-
43
- def get_q(model, batch):
44
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
45
- output = model.forward(
46
- images=batch["images"],
47
- images_clip=batch["images_clip"],
48
- audio_features=batch["audio_feats"],
49
- image_features=batch["image_feats"],
50
- input_ids=batch["input_ids"],
51
- labels=batch["labels"],
52
- attention_masks=batch["attention_masks"],
53
- masks_list=batch["masks"],
54
- resize_list=batch["resizes"],
55
- orgsize_list=batch["orgsizes"],
56
- conversation_list=batch["convs"],
57
- refs_num=batch["refs_num"],
58
- fids=batch["fids"],
59
- vids=batch["vids"],
60
- contrast=args.ct_weight,
61
- ref_ids=batch["ref_ids"],
62
- inference=True,
63
- )
64
- return output["seg_embeddings"][0][0].float()
65
-
66
-
67
- def decode_low_res(model, batch, q):
68
- visual_model = model.get_model().visual_model
69
- sparse, dense = visual_model.prompt_encoder(
70
- points=None,
71
- boxes=None,
72
- masks=None,
73
- text_embeds=q.view(1, 1, -1).to(next(visual_model.parameters()).dtype),
74
- )
75
- sparse = sparse.to(q.dtype)
76
- dense = dense.to(q.dtype)
77
-
78
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
79
- low_res_masks, iou_predictions = visual_model.mask_decoder(
80
- image_embeddings=batch["image_feats"][0],
81
- image_pe=visual_model.prompt_encoder.get_dense_pe(),
82
- sparse_prompt_embeddings=sparse,
83
- dense_prompt_embeddings=dense,
84
- multimask_output=False,
85
- )
86
- return low_res_masks.float(), iou_predictions.float().squeeze(-1)
87
-
88
-
89
- def masks_to_64(mask_logits_or_binary):
90
- if mask_logits_or_binary.ndim == 3:
91
- mask_logits_or_binary = mask_logits_or_binary.unsqueeze(1)
92
- return F.interpolate(
93
- mask_logits_or_binary.float(),
94
- size=(64, 64),
95
- mode="bilinear",
96
- align_corners=False,
97
- ).clamp(0.0, 1.0)
98
-
99
-
100
- def d2_scores(image_embeddings, mask64, q, beta):
101
- feats = image_embeddings.float()
102
- if mask64.shape[0] != feats.shape[0]:
103
- raise ValueError(f"Mask/frame mismatch: {mask64.shape} vs {feats.shape}")
104
-
105
- q = F.normalize(q.float().view(1, -1), dim=-1)
106
- mask = mask64.float()
107
- comp = 1.0 - mask
108
-
109
- z_in = (feats * mask).sum(dim=(2, 3)) / mask.sum(dim=(2, 3)).clamp_min(1e-6)
110
- z_out = (feats * comp).sum(dim=(2, 3)) / comp.sum(dim=(2, 3)).clamp_min(1e-6)
111
-
112
- z_in = F.normalize(z_in, dim=-1)
113
- z_out = F.normalize(z_out, dim=-1)
114
- return (z_in @ q.T).squeeze(-1) - beta * (z_out @ q.T).squeeze(-1)
115
-
116
-
117
- def frame_iou(pred_logits, gt_masks):
118
- pred = (torch.sigmoid(pred_logits.float()) > 0.4).float()
119
- gt = gt_masks.float()
120
- if pred.ndim == 4:
121
- pred = pred.squeeze(1)
122
- inter = (pred * gt).sum(dim=(1, 2))
123
- union = torch.maximum(pred, gt).sum(dim=(1, 2))
124
- num_pixels = pred.shape[-1] * pred.shape[-2]
125
- no_obj = gt.sum(dim=(1, 2)) == 0
126
- inter_no_obj = ((1.0 - pred) * (1.0 - gt)).sum(dim=(1, 2))
127
- inter = torch.where(no_obj, inter_no_obj, inter)
128
- union = torch.where(no_obj, torch.full_like(union, float(num_pixels)), union)
129
- return inter / union.clamp_min(1e-7)
130
-
131
-
132
- def frame_fscore_proxy(pred_logits, gt_masks):
133
- pred = (torch.sigmoid(pred_logits.float()) > 0.4).float()
134
- gt = gt_masks.float()
135
- if pred.ndim == 4:
136
- pred = pred.squeeze(1)
137
- tp = (pred * gt).sum(dim=(1, 2))
138
- precision = tp / pred.sum(dim=(1, 2)).clamp_min(1e-7)
139
- recall = tp / gt.sum(dim=(1, 2)).clamp_min(1e-7)
140
- beta2 = 0.3
141
- fscore = (1 + beta2) * precision * recall / (beta2 * precision + recall).clamp_min(1e-7)
142
- no_obj = gt.sum(dim=(1, 2)) == 0
143
- return torch.where(no_obj, torch.zeros_like(fscore), fscore)
144
-
145
-
146
- def parse_betas():
147
- raw = os.environ.get("D2_BETAS", "0.5")
148
- return [float(x.strip()) for x in raw.split(",") if x.strip()]
149
-
150
-
151
- def collect_q_pool(model, tokenizer, limit):
152
- q_pool = []
153
- loader = make_loader(tokenizer)
154
- for sample_idx, batch in enumerate(loader):
155
- if sample_idx >= limit:
156
- break
157
- batch = dict_to_cuda(batch)
158
- q = get_q(model, batch)
159
- q_pool.append(
160
- {
161
- "sample_idx": sample_idx,
162
- "vid": batch["vids"][0],
163
- "ref": batch["refs"][0][0],
164
- "fid": int(batch["fids"][0][0]),
165
- "q": q.cpu(),
166
- }
167
- )
168
- print(f"Collected q {sample_idx}: vid={q_pool[-1]['vid']} ref={q_pool[-1]['ref']}")
169
- if not q_pool:
170
- raise RuntimeError("No q vectors collected. Is the selected split empty?")
171
- return q_pool
172
-
173
-
174
- def choose_shuffled_idx(sample_idx, q_pool):
175
- if len(q_pool) <= 1:
176
- return None
177
- return (sample_idx + 1) % len(q_pool)
178
-
179
-
180
- def choose_wrong_ref_idx(sample_idx, q_pool):
181
- current = q_pool[sample_idx]
182
- for item in q_pool:
183
- if item["sample_idx"] == sample_idx:
184
- continue
185
- if item["vid"] == current["vid"] and item["fid"] != current["fid"]:
186
- return item["sample_idx"]
187
- for item in q_pool:
188
- if item["sample_idx"] == sample_idx:
189
- continue
190
- if item["vid"] == current["vid"] and item["ref"] != current["ref"]:
191
- return item["sample_idx"]
192
- return None
193
-
194
-
195
- def run_d2(model, tokenizer, q_pool, betas, limit):
196
- rows = []
197
- loader = make_loader(tokenizer)
198
- q_lookup = {item["sample_idx"]: item for item in q_pool}
199
- generator = torch.Generator(device="cuda")
200
- generator.manual_seed(1234)
201
-
202
- for sample_idx, batch in enumerate(loader):
203
- if sample_idx >= limit:
204
- break
205
- batch = dict_to_cuda(batch)
206
- item = q_lookup[sample_idx]
207
- real_q = item["q"].cuda()
208
-
209
- low_res_masks, iou_predictions = decode_low_res(model, batch, real_q)
210
- pred_mask64 = masks_to_64(torch.sigmoid(low_res_masks))
211
- gt_masks = batch["masks"][0][0].float()
212
- gt_mask64 = masks_to_64(gt_masks)
213
- image_embeddings = batch["image_feats"][0].float()
214
-
215
- pred_logits_hr = model.get_model().visual_model.postprocess_masks(
216
- low_res_masks.to(batch["image_feats"][0].dtype),
217
- input_size=batch["resizes"][0],
218
- original_size=batch["orgsizes"][0],
219
- ).squeeze(1)
220
-
221
- frame_ious = frame_iou(pred_logits_hr, gt_masks)
222
- frame_fscores = frame_fscore_proxy(pred_logits_hr, gt_masks)
223
- pred_area = (torch.sigmoid(pred_logits_hr.float()) > 0.4).float().mean(dim=(1, 2))
224
- gt_area = gt_masks.float().mean(dim=(1, 2))
225
-
226
- shuffled_idx = choose_shuffled_idx(sample_idx, q_pool)
227
- wrong_ref_idx = choose_wrong_ref_idx(sample_idx, q_pool)
228
- q_controls = [
229
- ("real", real_q, sample_idx),
230
- ("random", torch.randn(real_q.shape, device=real_q.device, generator=generator), None),
231
- ]
232
- if shuffled_idx is not None:
233
- q_controls.append(("shuffled", q_lookup[shuffled_idx]["q"].cuda(), shuffled_idx))
234
- if wrong_ref_idx is not None:
235
- q_controls.append(("wrong_ref", q_lookup[wrong_ref_idx]["q"].cuda(), wrong_ref_idx))
236
-
237
- for beta in betas:
238
- for q_type, q, q_source_idx in q_controls:
239
- pred_scores = d2_scores(image_embeddings, pred_mask64, q, beta)
240
- gt_scores = d2_scores(image_embeddings, gt_mask64, q, beta)
241
- base_info = {
242
- "sample_idx": sample_idx,
243
- "vid": item["vid"],
244
- "ref": item["ref"],
245
- "fid": item["fid"],
246
- "split": args.eval_split,
247
- "frame_iou": math.nan,
248
- "frame_fscore_proxy": math.nan,
249
- "iou_pred": math.nan,
250
- "pred_area": math.nan,
251
- "gt_area": math.nan,
252
- }
253
- for frame_idx in range(pred_scores.shape[0]):
254
- base_info_frame = dict(base_info)
255
- base_info_frame.update(
256
- {
257
- "frame_iou": frame_ious[frame_idx].item(),
258
- "frame_fscore_proxy": frame_fscores[frame_idx].item(),
259
- "iou_pred": iou_predictions[frame_idx].item(),
260
- "pred_area": pred_area[frame_idx].item(),
261
- "gt_area": gt_area[frame_idx].item(),
262
- }
263
- )
264
- row = dict(base_info_frame)
265
- row.update(
266
- {
267
- "frame": frame_idx,
268
- "q_type": q_type,
269
- "beta": beta,
270
- "s_pred": pred_scores[frame_idx].item(),
271
- "s_gt": gt_scores[frame_idx].item(),
272
- "q_source_idx": q_source_idx if q_source_idx is not None else "",
273
- }
274
- )
275
- rows.append(row)
276
-
277
- real_rows = [
278
- r for r in rows if r["sample_idx"] == sample_idx and r["q_type"] == "real" and r["beta"] == betas[0]
279
- ]
280
- s_pred_values = [r["s_pred"] for r in real_rows]
281
- print(
282
- f"D2 {sample_idx}: vid={item['vid']} ref={item['ref']} "
283
- f"mean_s_pred={np.mean(s_pred_values):.4f} min_s_pred={np.min(s_pred_values):.4f} "
284
- f"mean_iou={frame_ious.mean().item():.4f}"
285
- )
286
-
287
- return rows
288
-
289
-
290
- def print_summary(rows):
291
- real_rows = [r for r in rows if r["q_type"] == "real"]
292
- if not real_rows:
293
- return
294
- by_beta = sorted(set(r["beta"] for r in real_rows))
295
- print("\nSummary")
296
- print(f"rows: {len(rows)}")
297
- for beta in by_beta:
298
- beta_rows = [r for r in rows if r["beta"] == beta]
299
- print(f"\nbeta={beta}")
300
- for q_type in sorted(set(r["q_type"] for r in beta_rows)):
301
- qr = [r for r in beta_rows if r["q_type"] == q_type]
302
- print(
303
- f"{q_type:10s} "
304
- f"mean_s_pred={np.mean([r['s_pred'] for r in qr]):+.4f} "
305
- f"mean_s_gt={np.mean([r['s_gt'] for r in qr]):+.4f}"
306
- )
307
- real_beta = [r for r in beta_rows if r["q_type"] == "real"]
308
- s_pred = np.array([r["s_pred"] for r in real_beta])
309
- frame_iou_values = np.array([r["frame_iou"] for r in real_beta])
310
- if len(s_pred) > 1 and np.std(s_pred) > 1e-8 and np.std(frame_iou_values) > 1e-8:
311
- corr = np.corrcoef(s_pred, frame_iou_values)[0, 1]
312
- print(f"corr(real s_pred, frame_iou)={corr:+.4f}")
313
- else:
314
- print("corr(real s_pred, frame_iou)=nan")
315
-
316
-
317
- def main():
318
- set_seed(42)
319
- torch.set_grad_enabled(False)
320
- betas = parse_betas()
321
- tokenizer, seg_token_idx = build_tokenizer()
322
- limit = args.max_eval_rows if args.max_eval_rows > 0 else 30
323
- print(f"Split: {args.eval_split} | samples: {limit} | betas: {betas}")
324
-
325
- model = build_model(tokenizer, seg_token_idx)
326
- q_pool = collect_q_pool(model, tokenizer, limit)
327
- rows = run_d2(model, tokenizer, q_pool, betas, limit)
328
- print_summary(rows)
329
-
330
- csv_path = os.environ.get("D2_BASIC_CSV", f"/workspace/SimToken/d2_basic_{args.eval_split}_{limit}.csv")
331
- os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
332
- with open(csv_path, "w", newline="") as f:
333
- writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
334
- writer.writeheader()
335
- writer.writerows(rows)
336
- print(f"\nSaved CSV: {csv_path}")
337
-
338
-
339
- if __name__ == "__main__":
340
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2_llm_space.py DELETED
@@ -1,314 +0,0 @@
1
- import csv
2
- import math
3
- import os
4
- from functools import partial
5
-
6
- import numpy as np
7
- import torch
8
- import torch.nn.functional as F
9
- import transformers
10
- from torch.utils.data import DataLoader
11
-
12
- from configs import args
13
- from datasets import REFAVS
14
- from decoder_invariance_check import build_model, set_seed
15
- from d2_basic import frame_fscore_proxy, frame_iou
16
- from load_model import collate_fn, dict_to_cuda
17
-
18
-
19
- def build_tokenizer():
20
- tokenizer = transformers.AutoTokenizer.from_pretrained(
21
- args.mllm,
22
- cache_dir=None,
23
- model_max_length=2048,
24
- padding_side="right",
25
- use_fast=False,
26
- )
27
- tokenizer.pad_token = tokenizer.unk_token
28
- tokenizer.add_tokens("[SEG]")
29
- seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
30
- return tokenizer, seg_token_idx
31
-
32
-
33
- def make_loader(tokenizer):
34
- dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
35
- return DataLoader(
36
- dataset,
37
- batch_size=1,
38
- shuffle=False,
39
- num_workers=0,
40
- collate_fn=partial(collate_fn, tokenizer=tokenizer),
41
- )
42
-
43
-
44
- def forward_for_hidden_and_q(model, batch):
45
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
46
- output = model.forward(
47
- images=batch["images"],
48
- images_clip=batch["images_clip"],
49
- audio_features=batch["audio_feats"],
50
- image_features=batch["image_feats"],
51
- input_ids=batch["input_ids"],
52
- labels=batch["labels"],
53
- attention_masks=batch["attention_masks"],
54
- masks_list=batch["masks"],
55
- resize_list=batch["resizes"],
56
- orgsize_list=batch["orgsizes"],
57
- conversation_list=batch["convs"],
58
- refs_num=batch["refs_num"],
59
- fids=batch["fids"],
60
- vids=batch["vids"],
61
- contrast=args.ct_weight,
62
- ref_ids=batch["ref_ids"],
63
- inference=True,
64
- )
65
- h_seg = output["seg_hidden_states"][0][0].float()
66
- q = output["seg_embeddings"][0][0].float()
67
- return h_seg, q
68
-
69
-
70
- def decode_low_res(model, batch, q):
71
- visual_model = model.get_model().visual_model
72
- sparse, dense = visual_model.prompt_encoder(
73
- points=None,
74
- boxes=None,
75
- masks=None,
76
- text_embeds=q.view(1, 1, -1).to(next(visual_model.parameters()).dtype),
77
- )
78
- sparse = sparse.to(q.dtype)
79
- dense = dense.to(q.dtype)
80
-
81
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
82
- low_res_masks, iou_predictions = visual_model.mask_decoder(
83
- image_embeddings=batch["image_feats"][0],
84
- image_pe=visual_model.prompt_encoder.get_dense_pe(),
85
- sparse_prompt_embeddings=sparse,
86
- dense_prompt_embeddings=dense,
87
- multimask_output=False,
88
- )
89
- return low_res_masks.float(), iou_predictions.float().squeeze(-1)
90
-
91
-
92
- def clip_projected_tokens(model, batch):
93
- images = torch.cat(batch["images_clip"], dim=0)
94
- with torch.no_grad():
95
- clip_tokens = model.encode_images(images)
96
- projector = model.get_model().mm_projector
97
- clip_tokens = clip_tokens.to(projector.weight.dtype)
98
- llm_tokens = projector(clip_tokens).float()
99
- return llm_tokens
100
-
101
-
102
- def infer_square_grid(num_tokens):
103
- grid = int(math.sqrt(num_tokens))
104
- if grid * grid != num_tokens:
105
- raise ValueError(f"Expected square patch-token grid, got {num_tokens} tokens")
106
- return grid
107
-
108
-
109
- def masks_to_token_grid(mask_logits_or_binary, num_tokens):
110
- if mask_logits_or_binary.ndim == 3:
111
- mask_logits_or_binary = mask_logits_or_binary.unsqueeze(1)
112
- grid = infer_square_grid(num_tokens)
113
- return F.interpolate(
114
- mask_logits_or_binary.float(),
115
- size=(grid, grid),
116
- mode="bilinear",
117
- align_corners=False,
118
- ).flatten(2).transpose(1, 2).clamp(0.0, 1.0)
119
-
120
-
121
- def d2_scores_llm(llm_tokens, mask_tokens, h_seg, beta):
122
- if llm_tokens.shape[:2] != mask_tokens.shape[:2]:
123
- raise ValueError(f"Token/mask mismatch: {llm_tokens.shape} vs {mask_tokens.shape}")
124
- h = F.normalize(h_seg.float().view(1, -1), dim=-1)
125
- tokens = llm_tokens.float()
126
- mask = mask_tokens.float()
127
- comp = 1.0 - mask
128
-
129
- z_in = (tokens * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-6)
130
- z_out = (tokens * comp).sum(dim=1) / comp.sum(dim=1).clamp_min(1e-6)
131
-
132
- z_in = F.normalize(z_in, dim=-1)
133
- z_out = F.normalize(z_out, dim=-1)
134
- return (z_in @ h.T).squeeze(-1) - beta * (z_out @ h.T).squeeze(-1)
135
-
136
-
137
- def parse_betas():
138
- raw = os.environ.get("D2_BETAS", "0.5")
139
- return [float(x.strip()) for x in raw.split(",") if x.strip()]
140
-
141
-
142
- def collect_hidden_pool(model, tokenizer, limit):
143
- pool = []
144
- loader = make_loader(tokenizer)
145
- for sample_idx, batch in enumerate(loader):
146
- if sample_idx >= limit:
147
- break
148
- batch = dict_to_cuda(batch)
149
- h_seg, q = forward_for_hidden_and_q(model, batch)
150
- pool.append(
151
- {
152
- "sample_idx": sample_idx,
153
- "vid": batch["vids"][0],
154
- "ref": batch["refs"][0][0],
155
- "fid": int(batch["fids"][0][0]),
156
- "h": h_seg.cpu(),
157
- "q": q.cpu(),
158
- }
159
- )
160
- print(f"Collected h {sample_idx}: vid={pool[-1]['vid']} ref={pool[-1]['ref']}")
161
- if not pool:
162
- raise RuntimeError("No hidden states collected. Is the selected split empty?")
163
- return pool
164
-
165
-
166
- def choose_shuffled_idx(sample_idx, pool):
167
- if len(pool) <= 1:
168
- return None
169
- return (sample_idx + 1) % len(pool)
170
-
171
-
172
- def choose_wrong_ref_idx(sample_idx, pool):
173
- current = pool[sample_idx]
174
- for item in pool:
175
- if item["sample_idx"] == sample_idx:
176
- continue
177
- if item["vid"] == current["vid"] and item["fid"] != current["fid"]:
178
- return item["sample_idx"]
179
- for item in pool:
180
- if item["sample_idx"] == sample_idx:
181
- continue
182
- if item["vid"] == current["vid"] and item["ref"] != current["ref"]:
183
- return item["sample_idx"]
184
- return None
185
-
186
-
187
- def run_d2_llm(model, tokenizer, pool, betas, limit):
188
- rows = []
189
- lookup = {item["sample_idx"]: item for item in pool}
190
- generator = torch.Generator(device="cuda")
191
- generator.manual_seed(1234)
192
- loader = make_loader(tokenizer)
193
-
194
- for sample_idx, batch in enumerate(loader):
195
- if sample_idx >= limit:
196
- break
197
- batch = dict_to_cuda(batch)
198
- item = lookup[sample_idx]
199
- h_real = item["h"].cuda()
200
- q_real = item["q"].cuda()
201
-
202
- low_res_masks, iou_predictions = decode_low_res(model, batch, q_real)
203
- llm_tokens = clip_projected_tokens(model, batch)
204
- pred_mask_tokens = masks_to_token_grid(torch.sigmoid(low_res_masks), llm_tokens.shape[1])
205
- gt_masks = batch["masks"][0][0].float()
206
- gt_mask_tokens = masks_to_token_grid(gt_masks, llm_tokens.shape[1])
207
-
208
- pred_logits_hr = model.get_model().visual_model.postprocess_masks(
209
- low_res_masks.to(batch["image_feats"][0].dtype),
210
- input_size=batch["resizes"][0],
211
- original_size=batch["orgsizes"][0],
212
- ).squeeze(1)
213
- frame_ious = frame_iou(pred_logits_hr, gt_masks)
214
- frame_fscores = frame_fscore_proxy(pred_logits_hr, gt_masks)
215
- pred_area = (torch.sigmoid(pred_logits_hr.float()) > 0.4).float().mean(dim=(1, 2))
216
- gt_area = gt_masks.float().mean(dim=(1, 2))
217
-
218
- shuffled_idx = choose_shuffled_idx(sample_idx, pool)
219
- wrong_ref_idx = choose_wrong_ref_idx(sample_idx, pool)
220
- controls = [
221
- ("real", h_real, sample_idx),
222
- ("random", torch.randn(h_real.shape, device=h_real.device, generator=generator), None),
223
- ]
224
- if shuffled_idx is not None:
225
- controls.append(("shuffled", lookup[shuffled_idx]["h"].cuda(), shuffled_idx))
226
- if wrong_ref_idx is not None:
227
- controls.append(("wrong_ref", lookup[wrong_ref_idx]["h"].cuda(), wrong_ref_idx))
228
-
229
- for beta in betas:
230
- for h_type, h, h_source_idx in controls:
231
- pred_scores = d2_scores_llm(llm_tokens, pred_mask_tokens, h, beta)
232
- gt_scores = d2_scores_llm(llm_tokens, gt_mask_tokens, h, beta)
233
- for frame_idx in range(pred_scores.shape[0]):
234
- rows.append(
235
- {
236
- "sample_idx": sample_idx,
237
- "vid": item["vid"],
238
- "ref": item["ref"],
239
- "fid": item["fid"],
240
- "split": args.eval_split,
241
- "frame": frame_idx,
242
- "h_type": h_type,
243
- "beta": beta,
244
- "s_pred": pred_scores[frame_idx].item(),
245
- "s_gt": gt_scores[frame_idx].item(),
246
- "h_source_idx": h_source_idx if h_source_idx is not None else "",
247
- "frame_iou": frame_ious[frame_idx].item(),
248
- "frame_fscore_proxy": frame_fscores[frame_idx].item(),
249
- "iou_pred": iou_predictions[frame_idx].item(),
250
- "pred_area": pred_area[frame_idx].item(),
251
- "gt_area": gt_area[frame_idx].item(),
252
- }
253
- )
254
-
255
- real_rows = [
256
- r for r in rows if r["sample_idx"] == sample_idx and r["h_type"] == "real" and r["beta"] == betas[0]
257
- ]
258
- s_pred_values = [r["s_pred"] for r in real_rows]
259
- print(
260
- f"D2-LLM {sample_idx}: vid={item['vid']} ref={item['ref']} "
261
- f"mean_s_pred={np.mean(s_pred_values):.4f} min_s_pred={np.min(s_pred_values):.4f} "
262
- f"mean_iou={frame_ious.mean().item():.4f}"
263
- )
264
-
265
- return rows
266
-
267
-
268
- def print_summary(rows):
269
- print("\nSummary")
270
- print(f"rows: {len(rows)}")
271
- for beta in sorted(set(r["beta"] for r in rows)):
272
- beta_rows = [r for r in rows if r["beta"] == beta]
273
- print(f"\nbeta={beta}")
274
- for h_type in sorted(set(r["h_type"] for r in beta_rows)):
275
- hr = [r for r in beta_rows if r["h_type"] == h_type]
276
- print(
277
- f"{h_type:10s} "
278
- f"mean_s_pred={np.mean([r['s_pred'] for r in hr]):+.4f} "
279
- f"mean_s_gt={np.mean([r['s_gt'] for r in hr]):+.4f}"
280
- )
281
- real_rows = [r for r in beta_rows if r["h_type"] == "real"]
282
- s_pred = np.array([r["s_pred"] for r in real_rows])
283
- frame_iou_values = np.array([r["frame_iou"] for r in real_rows])
284
- if len(s_pred) > 1 and np.std(s_pred) > 1e-8 and np.std(frame_iou_values) > 1e-8:
285
- corr = np.corrcoef(s_pred, frame_iou_values)[0, 1]
286
- print(f"corr(real s_pred, frame_iou)={corr:+.4f}")
287
- else:
288
- print("corr(real s_pred, frame_iou)=nan")
289
-
290
-
291
- def main():
292
- set_seed(42)
293
- torch.set_grad_enabled(False)
294
- betas = parse_betas()
295
- tokenizer, seg_token_idx = build_tokenizer()
296
- limit = args.max_eval_rows if args.max_eval_rows > 0 else 30
297
- print(f"Split: {args.eval_split} | samples: {limit} | betas: {betas}")
298
-
299
- model = build_model(tokenizer, seg_token_idx)
300
- pool = collect_hidden_pool(model, tokenizer, limit)
301
- rows = run_d2_llm(model, tokenizer, pool, betas, limit)
302
- print_summary(rows)
303
-
304
- csv_path = os.environ.get("D2_LLM_CSV", f"/workspace/SimToken/d2_llm_{args.eval_split}_{limit}.csv")
305
- os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
306
- with open(csv_path, "w", newline="") as f:
307
- writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
308
- writer.writeheader()
309
- writer.writerows(rows)
310
- print(f"\nSaved CSV: {csv_path}")
311
-
312
-
313
- if __name__ == "__main__":
314
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
decoder_invariance_check.py DELETED
@@ -1,256 +0,0 @@
1
- import csv
2
- import os
3
- import random
4
- from functools import partial
5
-
6
- import numpy as np
7
- import torch
8
- import transformers
9
- from peft import LoraConfig, get_peft_model
10
- from torch.utils.data import DataLoader
11
- from transformers import AutoConfig
12
-
13
- from configs import args
14
- from datasets import REFAVS
15
- from load_model import collate_fn, dict_to_cuda
16
- from models.avs_model import Simtoken_ForCausalLM
17
-
18
-
19
- def set_seed(seed=42):
20
- torch.manual_seed(seed)
21
- np.random.seed(seed)
22
- random.seed(seed)
23
- torch.cuda.manual_seed_all(seed)
24
- torch.backends.cudnn.deterministic = True
25
- torch.backends.cudnn.benchmark = False
26
-
27
-
28
- def find_lora_target_modules(model, target_modules=("q_proj", "v_proj")):
29
- modules = set()
30
- excluded = [
31
- "visual_model",
32
- "vision_tower",
33
- "mm_projector",
34
- "text_hidden_fcs",
35
- "audio_feature_layer",
36
- ]
37
- for name, module in model.named_modules():
38
- if not isinstance(module, torch.nn.Linear):
39
- continue
40
- if any(x in name for x in excluded):
41
- continue
42
- if any(x in name for x in target_modules):
43
- modules.add(name)
44
- return sorted(modules)
45
-
46
-
47
- def build_model(tokenizer, seg_token_idx):
48
- model_args = {
49
- "train_mask_decoder": True,
50
- "out_dim": 256,
51
- "ce_loss_weight": 1.0,
52
- "dice_loss_weight": 0.5,
53
- "bce_loss_weight": 2.0,
54
- "seg_token_idx": seg_token_idx,
55
- "vision_pretrained": args.vision_pretrained,
56
- "vision_tower": args.vision_tower,
57
- "use_im_start_end": False,
58
- "compress": args.compress,
59
- "start": args.start,
60
- }
61
-
62
- model = Simtoken_ForCausalLM.from_pretrained(
63
- args.mllm,
64
- torch_dtype=torch.bfloat16,
65
- low_cpu_mem_usage=True,
66
- **model_args,
67
- )
68
-
69
- model.config.eos_token_id = tokenizer.eos_token_id
70
- model.config.bos_token_id = tokenizer.bos_token_id
71
- model.config.pad_token_id = tokenizer.pad_token_id
72
-
73
- model.get_model().initialize_vision_modules(model.get_model().config)
74
- vision_tower = model.get_model().get_vision_tower()
75
- vision_tower.to(dtype=torch.float32, device="cuda")
76
-
77
- model_args_from_pt = AutoConfig.from_pretrained(args.mllm)
78
- model_args_from_pt.use_cluster = True
79
- model_args_from_pt.freeze = False
80
- model_args_from_pt.mm_tune = True
81
- model_args_from_pt.spatial_cluster_rate0 = 64
82
- model_args_from_pt.spatial_cluster_rate1 = 32
83
- model_args_from_pt.spatial_cluster_rate2 = 16
84
- model_args_from_pt.temporal_cluster_rate = 0.0625
85
- model_args_from_pt.vision_tune = False
86
- model.get_model().initialize_cluster_modules(model_args_from_pt)
87
- model.get_model().initialize_lisa_modules(model.get_model().config)
88
-
89
- lora_config = LoraConfig(
90
- r=8,
91
- lora_alpha=16,
92
- target_modules=find_lora_target_modules(model),
93
- lora_dropout=0.05,
94
- bias="none",
95
- task_type="CAUSAL_LM",
96
- )
97
- model = get_peft_model(model, lora_config)
98
- model = model.to("cuda")
99
- model.resize_token_embeddings(len(tokenizer))
100
-
101
- state = torch.load(args.saved_model, map_location="cpu")
102
- missing, unexpected = model.load_state_dict(state, strict=False)
103
- print(f"Loaded checkpoint: {args.saved_model}")
104
- print(f"Missing keys: {len(missing)} | Unexpected keys: {len(unexpected)}")
105
-
106
- model.eval()
107
- return model
108
-
109
-
110
- def get_seg_embedding(model, batch):
111
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
112
- output = model.forward(
113
- images=batch["images"],
114
- images_clip=batch["images_clip"],
115
- audio_features=batch["audio_feats"],
116
- image_features=batch["image_feats"],
117
- input_ids=batch["input_ids"],
118
- labels=batch["labels"],
119
- attention_masks=batch["attention_masks"],
120
- masks_list=batch["masks"],
121
- resize_list=batch["resizes"],
122
- orgsize_list=batch["orgsizes"],
123
- conversation_list=batch["convs"],
124
- refs_num=batch["refs_num"],
125
- fids=batch["fids"],
126
- vids=batch["vids"],
127
- contrast=args.ct_weight,
128
- ref_ids=batch["ref_ids"],
129
- inference=True,
130
- )
131
- return output["seg_embeddings"][0][0:1]
132
-
133
-
134
- def check_one_sample(model, batch):
135
- q = get_seg_embedding(model, batch)
136
- image_embeddings = batch["image_feats"][0]
137
-
138
- visual_model = model.get_model().visual_model
139
- sparse, dense = visual_model.prompt_encoder(
140
- points=None,
141
- boxes=None,
142
- masks=None,
143
- text_embeds=q.unsqueeze(1),
144
- )
145
- sparse = sparse.to(q.dtype)
146
- dense = dense.to(q.dtype)
147
-
148
- decoder = visual_model.mask_decoder
149
- image_pe = visual_model.prompt_encoder.get_dense_pe()
150
-
151
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
152
- full_masks, full_iou = decoder(
153
- image_embeddings=image_embeddings,
154
- image_pe=image_pe,
155
- sparse_prompt_embeddings=sparse,
156
- dense_prompt_embeddings=dense,
157
- multimask_output=False,
158
- )
159
-
160
- rows = []
161
- for t in range(image_embeddings.shape[0]):
162
- single_masks, single_iou = decoder(
163
- image_embeddings=image_embeddings[t : t + 1],
164
- image_pe=image_pe,
165
- sparse_prompt_embeddings=sparse,
166
- dense_prompt_embeddings=dense,
167
- multimask_output=False,
168
- )
169
-
170
- diff = (full_masks[t : t + 1] - single_masks).float().abs()
171
- iou_diff = (full_iou[t : t + 1] - single_iou).float().abs()
172
- rows.append(
173
- {
174
- "vid": batch["vids"][0],
175
- "ref": batch["refs"][0][0],
176
- "frame": t,
177
- "max_abs_diff": diff.max().item(),
178
- "mean_abs_diff": diff.mean().item(),
179
- "iou_pred_diff": iou_diff.max().item(),
180
- }
181
- )
182
- return rows
183
-
184
-
185
- def main():
186
- set_seed(42)
187
- torch.set_grad_enabled(False)
188
-
189
- tokenizer = transformers.AutoTokenizer.from_pretrained(
190
- args.mllm,
191
- cache_dir=None,
192
- model_max_length=2048,
193
- padding_side="right",
194
- use_fast=False,
195
- )
196
- tokenizer.pad_token = tokenizer.unk_token
197
- tokenizer.add_tokens("[SEG]")
198
- seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
199
-
200
- dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
201
- loader = DataLoader(
202
- dataset,
203
- batch_size=1,
204
- shuffle=False,
205
- num_workers=0,
206
- collate_fn=partial(collate_fn, tokenizer=tokenizer),
207
- )
208
-
209
- limit = args.max_eval_rows if args.max_eval_rows > 0 else 1
210
- print(f"Split: {args.eval_split} | samples to check: {limit}")
211
-
212
- model = build_model(tokenizer, seg_token_idx)
213
-
214
- all_rows = []
215
- for sample_idx, batch in enumerate(loader):
216
- if sample_idx >= limit:
217
- break
218
- batch = dict_to_cuda(batch)
219
- rows = check_one_sample(model, batch)
220
- all_rows.extend(rows)
221
-
222
- print(f"\nSample {sample_idx}: vid={rows[0]['vid']} ref={rows[0]['ref']}")
223
- print("frame | max_abs_diff | mean_abs_diff | iou_pred_diff")
224
- for row in rows:
225
- print(
226
- f"{row['frame']:02d} | "
227
- f"{row['max_abs_diff']:.8e} | "
228
- f"{row['mean_abs_diff']:.8e} | "
229
- f"{row['iou_pred_diff']:.8e}"
230
- )
231
-
232
- if not all_rows:
233
- raise RuntimeError("No rows were checked. Is the selected split empty?")
234
-
235
- max_diff = max(row["max_abs_diff"] for row in all_rows)
236
- mean_diff = sum(row["mean_abs_diff"] for row in all_rows) / len(all_rows)
237
- max_iou_diff = max(row["iou_pred_diff"] for row in all_rows)
238
-
239
- print("\nSummary")
240
- print(f"checked frames: {len(all_rows)}")
241
- print(f"global max_abs_diff: {max_diff:.8e}")
242
- print(f"average mean_abs_diff: {mean_diff:.8e}")
243
- print(f"global max_iou_pred_diff: {max_iou_diff:.8e}")
244
-
245
- csv_path = os.environ.get("DECODER_INVARIANCE_CSV")
246
- if csv_path:
247
- os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
248
- with open(csv_path, "w", newline="") as f:
249
- writer = csv.DictWriter(f, fieldnames=list(all_rows[0].keys()))
250
- writer.writeheader()
251
- writer.writerows(all_rows)
252
- print(f"Saved CSV: {csv_path}")
253
-
254
-
255
- if __name__ == "__main__":
256
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dev_subsets_rpb_v1.json DELETED
@@ -1,620 +0,0 @@
1
- {
2
- "metadata": {
3
- "seed": 42,
4
- "split_sizes": {
5
- "train": 14113,
6
- "test_s": 2288,
7
- "test_u": 1656,
8
- "test_n": 1028
9
- },
10
- "source_metadata": "/workspace/SimToken/data/metadata.csv"
11
- },
12
- "subsets": {
13
- "test_s": [
14
- 6,
15
- 16,
16
- 36,
17
- 71,
18
- 74,
19
- 88,
20
- 108,
21
- 114,
22
- 116,
23
- 122,
24
- 126,
25
- 128,
26
- 134,
27
- 138,
28
- 139,
29
- 146,
30
- 152,
31
- 159,
32
- 177,
33
- 196,
34
- 217,
35
- 219,
36
- 249,
37
- 256,
38
- 268,
39
- 276,
40
- 279,
41
- 286,
42
- 287,
43
- 297,
44
- 298,
45
- 299,
46
- 312,
47
- 313,
48
- 324,
49
- 331,
50
- 332,
51
- 347,
52
- 378,
53
- 383,
54
- 402,
55
- 410,
56
- 412,
57
- 420,
58
- 451,
59
- 452,
60
- 458,
61
- 467,
62
- 477,
63
- 484,
64
- 486,
65
- 497,
66
- 499,
67
- 512,
68
- 526,
69
- 533,
70
- 543,
71
- 550,
72
- 551,
73
- 567,
74
- 574,
75
- 576,
76
- 581,
77
- 594,
78
- 596,
79
- 608,
80
- 616,
81
- 625,
82
- 627,
83
- 642,
84
- 646,
85
- 663,
86
- 692,
87
- 700,
88
- 704,
89
- 724,
90
- 745,
91
- 754,
92
- 795,
93
- 815,
94
- 819,
95
- 831,
96
- 843,
97
- 854,
98
- 867,
99
- 895,
100
- 946,
101
- 953,
102
- 965,
103
- 975,
104
- 979,
105
- 989,
106
- 1004,
107
- 1007,
108
- 1008,
109
- 1010,
110
- 1023,
111
- 1039,
112
- 1051,
113
- 1052,
114
- 1072,
115
- 1075,
116
- 1080,
117
- 1088,
118
- 1099,
119
- 1101,
120
- 1104,
121
- 1106,
122
- 1134,
123
- 1138,
124
- 1169,
125
- 1180,
126
- 1201,
127
- 1205,
128
- 1221,
129
- 1230,
130
- 1247,
131
- 1258,
132
- 1272,
133
- 1279,
134
- 1284,
135
- 1294,
136
- 1297,
137
- 1312,
138
- 1329,
139
- 1339,
140
- 1343,
141
- 1367,
142
- 1379,
143
- 1406,
144
- 1417,
145
- 1461,
146
- 1462,
147
- 1468,
148
- 1473,
149
- 1474,
150
- 1489,
151
- 1493,
152
- 1500,
153
- 1510,
154
- 1517,
155
- 1552,
156
- 1556,
157
- 1557,
158
- 1589,
159
- 1609,
160
- 1612,
161
- 1618,
162
- 1622,
163
- 1624,
164
- 1644,
165
- 1647,
166
- 1665,
167
- 1669,
168
- 1676,
169
- 1682,
170
- 1683,
171
- 1691,
172
- 1700,
173
- 1726,
174
- 1746,
175
- 1748,
176
- 1758,
177
- 1764,
178
- 1765,
179
- 1778,
180
- 1785,
181
- 1786,
182
- 1808,
183
- 1826,
184
- 1852,
185
- 1861,
186
- 1883,
187
- 1891,
188
- 1916,
189
- 1938,
190
- 1944,
191
- 1967,
192
- 1971,
193
- 1980,
194
- 1986,
195
- 2034,
196
- 2044,
197
- 2067,
198
- 2074,
199
- 2082,
200
- 2085,
201
- 2118,
202
- 2128,
203
- 2156,
204
- 2176,
205
- 2182,
206
- 2185,
207
- 2188,
208
- 2194,
209
- 2206,
210
- 2211,
211
- 2215,
212
- 2247,
213
- 2256
214
- ],
215
- "test_u": [
216
- 4,
217
- 16,
218
- 26,
219
- 38,
220
- 40,
221
- 48,
222
- 50,
223
- 65,
224
- 83,
225
- 92,
226
- 102,
227
- 117,
228
- 120,
229
- 135,
230
- 144,
231
- 153,
232
- 155,
233
- 185,
234
- 200,
235
- 201,
236
- 211,
237
- 219,
238
- 221,
239
- 226,
240
- 227,
241
- 240,
242
- 245,
243
- 251,
244
- 252,
245
- 255,
246
- 267,
247
- 272,
248
- 274,
249
- 276,
250
- 278,
251
- 282,
252
- 284,
253
- 286,
254
- 303,
255
- 309,
256
- 313,
257
- 328,
258
- 345,
259
- 348,
260
- 358,
261
- 363,
262
- 374,
263
- 376,
264
- 379,
265
- 383,
266
- 385,
267
- 387,
268
- 393,
269
- 396,
270
- 400,
271
- 412,
272
- 417,
273
- 428,
274
- 434,
275
- 452,
276
- 453,
277
- 456,
278
- 459,
279
- 463,
280
- 473,
281
- 490,
282
- 493,
283
- 504,
284
- 517,
285
- 525,
286
- 535,
287
- 543,
288
- 544,
289
- 545,
290
- 549,
291
- 550,
292
- 565,
293
- 584,
294
- 585,
295
- 594,
296
- 602,
297
- 603,
298
- 606,
299
- 638,
300
- 642,
301
- 643,
302
- 651,
303
- 684,
304
- 687,
305
- 692,
306
- 700,
307
- 721,
308
- 728,
309
- 752,
310
- 757,
311
- 779,
312
- 783,
313
- 785,
314
- 794,
315
- 803,
316
- 807,
317
- 814,
318
- 847,
319
- 849,
320
- 853,
321
- 854,
322
- 861,
323
- 867,
324
- 884,
325
- 900,
326
- 903,
327
- 906,
328
- 924,
329
- 930,
330
- 931,
331
- 941,
332
- 948,
333
- 957,
334
- 968,
335
- 972,
336
- 980,
337
- 987,
338
- 995,
339
- 996,
340
- 1007,
341
- 1009,
342
- 1028,
343
- 1033,
344
- 1034,
345
- 1040,
346
- 1054,
347
- 1098,
348
- 1104,
349
- 1111,
350
- 1121,
351
- 1126,
352
- 1134,
353
- 1155,
354
- 1161,
355
- 1167,
356
- 1180,
357
- 1186,
358
- 1192,
359
- 1212,
360
- 1214,
361
- 1219,
362
- 1226,
363
- 1254,
364
- 1256,
365
- 1259,
366
- 1261,
367
- 1270,
368
- 1278,
369
- 1285,
370
- 1288,
371
- 1290,
372
- 1305,
373
- 1310,
374
- 1323,
375
- 1325,
376
- 1343,
377
- 1360,
378
- 1375,
379
- 1376,
380
- 1404,
381
- 1411,
382
- 1426,
383
- 1429,
384
- 1442,
385
- 1449,
386
- 1452,
387
- 1456,
388
- 1475,
389
- 1478,
390
- 1479,
391
- 1484,
392
- 1493,
393
- 1499,
394
- 1500,
395
- 1501,
396
- 1506,
397
- 1517,
398
- 1523,
399
- 1528,
400
- 1536,
401
- 1545,
402
- 1546,
403
- 1550,
404
- 1561,
405
- 1570,
406
- 1598,
407
- 1609,
408
- 1611,
409
- 1625,
410
- 1632,
411
- 1634,
412
- 1635,
413
- 1641,
414
- 1654,
415
- 1655
416
- ],
417
- "test_n": [
418
- 4,
419
- 5,
420
- 9,
421
- 16,
422
- 20,
423
- 25,
424
- 27,
425
- 33,
426
- 37,
427
- 40,
428
- 45,
429
- 46,
430
- 48,
431
- 53,
432
- 56,
433
- 60,
434
- 62,
435
- 67,
436
- 77,
437
- 78,
438
- 80,
439
- 81,
440
- 86,
441
- 90,
442
- 94,
443
- 99,
444
- 102,
445
- 106,
446
- 108,
447
- 111,
448
- 116,
449
- 121,
450
- 126,
451
- 127,
452
- 132,
453
- 143,
454
- 148,
455
- 153,
456
- 155,
457
- 156,
458
- 158,
459
- 160,
460
- 164,
461
- 168,
462
- 170,
463
- 171,
464
- 173,
465
- 175,
466
- 183,
467
- 184,
468
- 185,
469
- 188,
470
- 189,
471
- 190,
472
- 196,
473
- 202,
474
- 206,
475
- 208,
476
- 212,
477
- 217,
478
- 221,
479
- 222,
480
- 223,
481
- 233,
482
- 242,
483
- 246,
484
- 247,
485
- 259,
486
- 262,
487
- 269,
488
- 283,
489
- 298,
490
- 299,
491
- 306,
492
- 316,
493
- 317,
494
- 323,
495
- 330,
496
- 332,
497
- 334,
498
- 354,
499
- 357,
500
- 367,
501
- 372,
502
- 395,
503
- 397,
504
- 400,
505
- 405,
506
- 407,
507
- 420,
508
- 431,
509
- 435,
510
- 436,
511
- 444,
512
- 446,
513
- 461,
514
- 464,
515
- 470,
516
- 479,
517
- 481,
518
- 483,
519
- 485,
520
- 487,
521
- 494,
522
- 512,
523
- 516,
524
- 520,
525
- 524,
526
- 529,
527
- 530,
528
- 539,
529
- 540,
530
- 541,
531
- 554,
532
- 559,
533
- 560,
534
- 564,
535
- 568,
536
- 571,
537
- 572,
538
- 576,
539
- 577,
540
- 581,
541
- 585,
542
- 592,
543
- 602,
544
- 609,
545
- 620,
546
- 630,
547
- 632,
548
- 677,
549
- 678,
550
- 684,
551
- 693,
552
- 694,
553
- 695,
554
- 702,
555
- 716,
556
- 724,
557
- 727,
558
- 732,
559
- 735,
560
- 736,
561
- 747,
562
- 750,
563
- 752,
564
- 755,
565
- 758,
566
- 764,
567
- 767,
568
- 774,
569
- 775,
570
- 777,
571
- 779,
572
- 780,
573
- 782,
574
- 795,
575
- 800,
576
- 812,
577
- 815,
578
- 818,
579
- 821,
580
- 823,
581
- 825,
582
- 828,
583
- 834,
584
- 841,
585
- 843,
586
- 846,
587
- 848,
588
- 860,
589
- 861,
590
- 863,
591
- 869,
592
- 871,
593
- 878,
594
- 882,
595
- 891,
596
- 893,
597
- 896,
598
- 898,
599
- 899,
600
- 901,
601
- 906,
602
- 930,
603
- 940,
604
- 944,
605
- 969,
606
- 970,
607
- 973,
608
- 980,
609
- 990,
610
- 993,
611
- 996,
612
- 997,
613
- 1007,
614
- 1012,
615
- 1013,
616
- 1019,
617
- 1025
618
- ]
619
- }
620
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
log/rpb_dev_eval_baseline_step0.txt DELETED
@@ -1,5 +0,0 @@
1
- Epoch 0: running_loss 0.004542401526123285 Learning Rate:0.000000
2
- valuate on test_s_refer: miou 0.7255374467872275 true fscore 0.8181094569922425
3
- valuate on test_u_refer: miou 0.68531153425507 true fscore 0.7723772643739357
4
-
5
- valuate on test_n_refer: metric 0.014519116841256618
 
 
 
 
 
 
log/rpb_dev_eval_pm_only_a02_step0.txt DELETED
@@ -1,7 +0,0 @@
1
- Epoch 0: running_loss 0.013856410048902035 Learning Rate:0.000000
2
- valuate on test_s_refer: miou 0.7251653336426284 true fscore 0.8137564373598434
3
- bridge on test_s_refer: cos_delta_p_mask_mean=0.752373 | cos_delta_q_mean=-0.063845 | cos_delta_z_gt_mean=0.066832 | cos_p_hat_p_mask_mean=0.095022 | cos_p_hat_q_mean=0.991696 | cos_p_hat_z_gt_mean=0.058512 | cos_p_mask_z_gt_mean=0.064319 | delta_norm_mean=4.838175 | gate_mean=0.642605 | gate_std=0.066554 | p_hat_norm_mean=37.143986 | p_mask_norm_mean=0.855194 | q_norm_mean=37.143986 | z_gt_norm_mean=1.270137
4
- valuate on test_u_refer: miou 0.6859597001315854 true fscore 0.7773032036889345
5
- bridge on test_u_refer: cos_delta_p_mask_mean=0.752107 | cos_delta_q_mean=-0.052752 | cos_delta_z_gt_mean=0.059016 | cos_p_hat_p_mask_mean=0.066111 | cos_p_hat_q_mean=0.994380 | cos_p_hat_z_gt_mean=0.056506 | cos_p_mask_z_gt_mean=0.056127 | delta_norm_mean=3.232154 | gate_mean=0.529798 | gate_std=0.041540 | p_hat_norm_mean=30.350392 | p_mask_norm_mean=0.854621 | q_norm_mean=30.350392 | z_gt_norm_mean=1.131404
6
-
7
- valuate on test_n_refer: metric 0.014255181886255741
 
 
 
 
 
 
 
 
log/rpb_dev_mixed_pm_only_a015_wm005.txt DELETED
@@ -1,11 +0,0 @@
1
- Epoch 0: running_loss 0.12634180719032884 Learning Rate:0.000048
2
- Epoch 1: running_loss 0.06299160566413775 Learning Rate:0.000038
3
- Epoch 2: running_loss 0.04188278445508331 Learning Rate:0.000021
4
- Epoch 3: running_loss 0.03136271081166342 Learning Rate:0.000006
5
- Epoch 4: running_loss 0.025073944311589002 Learning Rate:0.000000
6
- valuate on test_s_refer: miou 0.7268448945908449 true fscore 0.8160740848700516
7
- bridge on test_s_refer: cos_delta_p_mask_mean=0.780949 | cos_delta_q_mean=-0.022341 | cos_delta_z_gt_mean=0.080238 | cos_p_hat_p_mask_mean=0.033820 | cos_p_hat_q_mean=0.998889 | cos_p_hat_z_gt_mean=0.053521 | cos_p_mask_z_gt_mean=0.064319 | delta_norm_mean=1.741799 | gate_mean=0.298187 | gate_std=0.074034 | p_hat_norm_mean=37.144979 | p_mask_norm_mean=0.855194 | q_norm_mean=37.144979 | z_gt_norm_mean=1.270137
8
- valuate on test_u_refer: miou 0.6867437321859904 true fscore 0.774193259445019
9
- bridge on test_u_refer: cos_delta_p_mask_mean=0.787519 | cos_delta_q_mean=-0.014046 | cos_delta_z_gt_mean=0.070144 | cos_p_hat_p_mask_mean=0.008821 | cos_p_hat_q_mean=0.999587 | cos_p_hat_z_gt_mean=0.052258 | cos_p_mask_z_gt_mean=0.056127 | delta_norm_mean=0.869715 | gate_mean=0.187340 | gate_std=0.030662 | p_hat_norm_mean=30.349741 | p_mask_norm_mean=0.854621 | q_norm_mean=30.349741 | z_gt_norm_mean=1.131404
10
-
11
- valuate on test_n_refer: metric 0.014510215260088444
 
 
 
 
 
 
 
 
 
 
 
 
log/rpb_dev_mixed_pm_only_a018_wm005.txt DELETED
@@ -1,11 +0,0 @@
1
- Epoch 0: running_loss 0.12581317650619894 Learning Rate:0.000048
2
- Epoch 1: running_loss 0.0626903815427795 Learning Rate:0.000038
3
- Epoch 2: running_loss 0.04165894452792903 Learning Rate:0.000021
4
- Epoch 3: running_loss 0.031184122432023287 Learning Rate:0.000006
5
- Epoch 4: running_loss 0.024928097636438905 Learning Rate:0.000000
6
- valuate on test_s_refer: miou 0.727035479994347 true fscore 0.8155373766715638
7
- bridge on test_s_refer: cos_delta_p_mask_mean=0.779142 | cos_delta_q_mean=-0.026866 | cos_delta_z_gt_mean=0.080963 | cos_p_hat_p_mask_mean=0.040792 | cos_p_hat_q_mean=0.998394 | cos_p_hat_z_gt_mean=0.054268 | cos_p_mask_z_gt_mean=0.064319 | delta_norm_mean=2.094408 | gate_mean=0.298949 | gate_std=0.074175 | p_hat_norm_mean=37.145271 | p_mask_norm_mean=0.855194 | q_norm_mean=37.145271 | z_gt_norm_mean=1.270137
8
- valuate on test_u_refer: miou 0.6870561258980442 true fscore 0.774542552176863
9
- bridge on test_u_refer: cos_delta_p_mask_mean=0.786014 | cos_delta_q_mean=-0.016895 | cos_delta_z_gt_mean=0.071182 | cos_p_hat_p_mask_mean=0.013252 | cos_p_hat_q_mean=0.999403 | cos_p_hat_z_gt_mean=0.052698 | cos_p_mask_z_gt_mean=0.056127 | delta_norm_mean=1.046129 | gate_mean=0.187813 | gate_std=0.030748 | p_hat_norm_mean=30.349577 | p_mask_norm_mean=0.854621 | q_norm_mean=30.349577 | z_gt_norm_mean=1.131404
10
-
11
- valuate on test_n_refer: metric 0.014507208950817585
 
 
 
 
 
 
 
 
 
 
 
 
log/rpb_dev_pm_only_a012.txt DELETED
@@ -1,11 +0,0 @@
1
- Epoch 0: running_loss 0.1251933453604579 Learning Rate:0.000291
2
- Epoch 1: running_loss 0.06243458506651223 Learning Rate:0.000225
3
- Epoch 2: running_loss 0.04142383218277246 Learning Rate:0.000124
4
- Epoch 3: running_loss 0.030912025278666988 Learning Rate:0.000035
5
- Epoch 4: running_loss 0.024670254811644553 Learning Rate:0.000000
6
- valuate on test_s_refer: miou 0.7265147582390341 true fscore 0.8174789174459874
7
- bridge on test_s_refer: cos_delta_p_mask_mean=0.785657 | cos_delta_q_mean=-0.012593 | cos_delta_z_gt_mean=0.074588 | cos_p_hat_p_mask_mean=0.018714 | cos_p_hat_q_mean=0.999648 | cos_p_hat_z_gt_mean=0.051832 | cos_p_mask_z_gt_mean=0.064319 | delta_norm_mean=0.980784 | gate_mean=0.209955 | gate_std=0.050712 | p_hat_norm_mean=37.145389 | p_mask_norm_mean=0.855194 | q_norm_mean=37.145389 | z_gt_norm_mean=1.270137
8
- valuate on test_u_refer: miou 0.685781483513075 true fscore 0.7731429794151335
9
- bridge on test_u_refer: cos_delta_p_mask_mean=0.790605 | cos_delta_q_mean=-0.008125 | cos_delta_z_gt_mean=0.065258 | cos_p_hat_p_mask_mean=-0.000455 | cos_p_hat_q_mean=0.999863 | cos_p_hat_z_gt_mean=0.051334 | cos_p_mask_z_gt_mean=0.056127 | delta_norm_mean=0.502185 | gate_mean=0.135438 | gate_std=0.020096 | p_hat_norm_mean=30.347839 | p_mask_norm_mean=0.854621 | q_norm_mean=30.347839 | z_gt_norm_mean=1.131404
10
-
11
- valuate on test_n_refer: metric 0.014490844681859016
 
 
 
 
 
 
 
 
 
 
 
 
log/rpb_dev_pm_only_a015.txt DELETED
@@ -1,11 +0,0 @@
1
- Epoch 0: running_loss 0.12516111659351736 Learning Rate:0.000291
2
- Epoch 1: running_loss 0.06237624154891819 Learning Rate:0.000225
3
- Epoch 2: running_loss 0.04133288407077392 Learning Rate:0.000124
4
- Epoch 3: running_loss 0.03080323277390562 Learning Rate:0.000035
5
- Epoch 4: running_loss 0.024568469962105155 Learning Rate:0.000000
6
- valuate on test_s_refer: miou 0.7266912544447951 true fscore 0.8172510598856024
7
- bridge on test_s_refer: cos_delta_p_mask_mean=0.784637 | cos_delta_q_mean=-0.015801 | cos_delta_z_gt_mean=0.074893 | cos_p_hat_p_mask_mean=0.023727 | cos_p_hat_q_mean=0.999446 | cos_p_hat_z_gt_mean=0.052317 | cos_p_mask_z_gt_mean=0.064319 | delta_norm_mean=1.230677 | gate_mean=0.210794 | gate_std=0.050954 | p_hat_norm_mean=37.144974 | p_mask_norm_mean=0.855194 | q_norm_mean=37.144974 | z_gt_norm_mean=1.270137
8
- valuate on test_u_refer: miou 0.6856936469832761 true fscore 0.7733012911863625
9
- bridge on test_u_refer: cos_delta_p_mask_mean=0.789761 | cos_delta_q_mean=-0.010194 | cos_delta_z_gt_mean=0.065751 | cos_p_hat_p_mask_mean=0.002815 | cos_p_hat_q_mean=0.999784 | cos_p_hat_z_gt_mean=0.051617 | cos_p_mask_z_gt_mean=0.056127 | delta_norm_mean=0.630081 | gate_mean=0.135950 | gate_std=0.020168 | p_hat_norm_mean=30.349286 | p_mask_norm_mean=0.854621 | q_norm_mean=30.349286 | z_gt_norm_mean=1.131404
10
-
11
- valuate on test_n_refer: metric 0.014483190141618252
 
 
 
 
 
 
 
 
 
 
 
 
log/rpb_dev_pm_only_a018.txt DELETED
@@ -1,11 +0,0 @@
1
- Epoch 0: running_loss 0.12512886058539152 Learning Rate:0.000291
2
- Epoch 1: running_loss 0.062317848962266 Learning Rate:0.000225
3
- Epoch 2: running_loss 0.04124188135998944 Learning Rate:0.000124
4
- Epoch 3: running_loss 0.03069439489627257 Learning Rate:0.000035
5
- Epoch 4: running_loss 0.024466648511588574 Learning Rate:0.000000
6
- valuate on test_s_refer: miou 0.7269170961743339 true fscore 0.817047117385082
7
- bridge on test_s_refer: cos_delta_p_mask_mean=0.783528 | cos_delta_q_mean=-0.019011 | cos_delta_z_gt_mean=0.075155 | cos_p_hat_p_mask_mean=0.028732 | cos_p_hat_q_mean=0.999199 | cos_p_hat_z_gt_mean=0.052798 | cos_p_mask_z_gt_mean=0.064319 | delta_norm_mean=1.480661 | gate_mean=0.211391 | gate_std=0.051102 | p_hat_norm_mean=37.145608 | p_mask_norm_mean=0.855194 | q_norm_mean=37.145608 | z_gt_norm_mean=1.270137
8
- valuate on test_u_refer: miou 0.6859480822706291 true fscore 0.7735356919141486
9
- bridge on test_u_refer: cos_delta_p_mask_mean=0.788825 | cos_delta_q_mean=-0.012263 | cos_delta_z_gt_mean=0.066219 | cos_p_hat_p_mask_mean=0.006046 | cos_p_hat_q_mean=0.999688 | cos_p_hat_z_gt_mean=0.051902 | cos_p_mask_z_gt_mean=0.056127 | delta_norm_mean=0.757877 | gate_mean=0.136287 | gate_std=0.020245 | p_hat_norm_mean=30.346972 | p_mask_norm_mean=0.854621 | q_norm_mean=30.346972 | z_gt_norm_mean=1.131404
10
-
11
- valuate on test_n_refer: metric 0.014475596137344837
 
 
 
 
 
 
 
 
 
 
 
 
log/rpb_dev_qonly_pm_only_a018.txt DELETED
@@ -1,11 +0,0 @@
1
- Epoch 0: running_loss 0.1250931837130338 Learning Rate:0.000291
2
- Epoch 1: running_loss 0.06158186250831932 Learning Rate:0.000225
3
- Epoch 2: running_loss 0.03905615148444971 Learning Rate:0.000124
4
- Epoch 3: running_loss 0.028493995574535802 Learning Rate:0.000035
5
- Epoch 4: running_loss 0.022694221674464644 Learning Rate:0.000000
6
- valuate on test_s_refer: miou 0.7231086666105239 true fscore 0.8120589338685386
7
- bridge on test_s_refer: cos_delta_p_mask_mean=0.740588 | cos_delta_q_mean=-0.082204 | cos_delta_z_gt_mean=0.083615 | cos_p_hat_p_mask_mean=0.120609 | cos_p_hat_q_mean=0.986413 | cos_p_hat_z_gt_mean=0.063688 | cos_p_mask_z_gt_mean=0.064319 | delta_norm_mean=6.165701 | gate_mean=0.922904 | gate_std=0.048146 | p_hat_norm_mean=37.145128 | p_mask_norm_mean=0.855194 | q_norm_mean=37.145128 | z_gt_norm_mean=1.270137
8
- valuate on test_u_refer: miou 0.6828930461963626 true fscore 0.7766606059018523
9
- bridge on test_u_refer: cos_delta_p_mask_mean=0.750842 | cos_delta_q_mean=-0.072793 | cos_delta_z_gt_mean=0.080115 | cos_p_hat_p_mask_mean=0.095975 | cos_p_hat_q_mean=0.989300 | cos_p_hat_z_gt_mean=0.061951 | cos_p_mask_z_gt_mean=0.056127 | delta_norm_mean=4.458672 | gate_mean=0.815494 | gate_std=0.064275 | p_hat_norm_mean=30.349046 | p_mask_norm_mean=0.854621 | q_norm_mean=30.349046 | z_gt_norm_mean=1.131404
10
-
11
- valuate on test_n_refer: metric 0.014240134507417679
 
 
 
 
 
 
 
 
 
 
 
 
log/rpb_e1_baseline.txt DELETED
@@ -1,5 +0,0 @@
1
- Epoch 0: running_loss 0.0045423684641718864 Learning Rate:0.000000
2
- valuate on test_s_refer: miou 0.7299158895817891 true fscore 0.8098922965396196
3
- valuate on test_u_refer: miou 0.7330115197712439 true fscore 0.8183729078620672
4
-
5
- valuate on test_n_refer: metric 0.1223459392786026
 
 
 
 
 
 
log/rpb_e4_min.txt DELETED
@@ -1,16 +0,0 @@
1
- Epoch 0: running_loss 7.052718125283718 Learning Rate:0.000097
2
- Epoch 1: running_loss 3.5262171775102615 Learning Rate:0.000075
3
- Epoch 2: running_loss 2.35092111180226 Learning Rate:0.000041
4
- Epoch 3: running_loss 1.7629929669201374 Learning Rate:0.000012
5
- Epoch 4: running_loss 1.4105001017451286 Learning Rate:0.000000
6
- Epoch 0: running_loss 7.052717879414558 Learning Rate:0.000097
7
- Epoch 1: running_loss 3.526217419654131 Learning Rate:0.000075
8
- Epoch 2: running_loss 2.3509211614727974 Learning Rate:0.000041
9
- Epoch 3: running_loss 1.762992987409234 Learning Rate:0.000012
10
- Epoch 4: running_loss 1.410500232875347 Learning Rate:0.000000
11
- valuate on test_s_refer: miou 0.010701371397460661 true fscore 0.16367542997933923
12
- bridge on test_s_refer: cos_p_hat_p_mask_mean=-0.003076 | cos_p_hat_q_mean=1.000000 | cos_p_hat_z_gt_mean=0.031631 | cos_p_mask_z_gt_mean=0.072929 | delta_norm_mean=0.003709 | gate_mean=0.019151 | gate_std=0.000754 | p_hat_norm_mean=6.222885 | p_mask_norm_mean=0.854909 | q_norm_mean=6.223040 | z_gt_norm_mean=1.275222
13
- valuate on test_u_refer: miou 0.03141531638093511 true fscore 0.1579975866433233
14
- bridge on test_u_refer: cos_p_hat_p_mask_mean=-0.004606 | cos_p_hat_q_mean=1.000000 | cos_p_hat_z_gt_mean=-0.000177 | cos_p_mask_z_gt_mean=0.081724 | delta_norm_mean=0.003449 | gate_mean=0.019014 | gate_std=0.000658 | p_hat_norm_mean=5.875611 | p_mask_norm_mean=0.855032 | q_norm_mean=5.875684 | z_gt_norm_mean=0.969146
15
-
16
- valuate on test_n_refer: metric 0.15515293180942535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
log/rpb_e4_min_v2.txt DELETED
@@ -1,11 +0,0 @@
1
- Epoch 0: running_loss 0.2470331892836839 Learning Rate:0.000097
2
- Epoch 1: running_loss 0.12353144341614097 Learning Rate:0.000075
3
- Epoch 2: running_loss 0.08232998211557667 Learning Rate:0.000041
4
- Epoch 3: running_loss 0.0617638936964795 Learning Rate:0.000012
5
- Epoch 4: running_loss 0.04941030433401465 Learning Rate:0.000000
6
- valuate on test_s_refer: miou 0.729936970449844 true fscore 0.8099028875399381
7
- bridge on test_s_refer: cos_p_hat_p_mask_mean=-0.009047 | cos_p_hat_q_mean=1.000000 | cos_p_hat_z_gt_mean=0.060572 | cos_p_mask_z_gt_mean=0.072929 | delta_norm_mean=0.004936 | gate_mean=0.024371 | gate_std=0.005409 | p_hat_norm_mean=36.236958 | p_mask_norm_mean=0.854909 | q_norm_mean=36.239986 | z_gt_norm_mean=1.275222
8
- valuate on test_u_refer: miou 0.7330397108156467 true fscore 0.8183516443520784
9
- bridge on test_u_refer: cos_p_hat_p_mask_mean=-0.004755 | cos_p_hat_q_mean=1.000000 | cos_p_hat_z_gt_mean=0.013517 | cos_p_mask_z_gt_mean=0.081724 | delta_norm_mean=0.004417 | gate_mean=0.023295 | gate_std=0.004361 | p_hat_norm_mean=30.846060 | p_mask_norm_mean=0.855032 | q_norm_mean=30.848833 | z_gt_norm_mean=0.969146
10
-
11
- valuate on test_n_refer: metric 0.12235464155673981
 
 
 
 
 
 
 
 
 
 
 
 
log/rpb_probe_a1_teacher_only.txt DELETED
@@ -1,22 +0,0 @@
1
- Epoch 0: running_loss 0.15941409580409527 Learning Rate:0.000150
2
- Epoch 1: running_loss 0.07969226781278849 Learning Rate:0.000300
3
- Epoch 2: running_loss 0.05310918173442284 Learning Rate:0.000298
4
- Epoch 3: running_loss 0.03982830489985645 Learning Rate:0.000291
5
- Epoch 4: running_loss 0.03184974528849125 Learning Rate:0.000280
6
- Epoch 5: running_loss 0.02652722302203377 Learning Rate:0.000265
7
- Epoch 6: running_loss 0.02272333244660071 Learning Rate:0.000246
8
- Epoch 7: running_loss 0.019872855627909303 Learning Rate:0.000225
9
- Epoch 8: running_loss 0.017649518532885447 Learning Rate:0.000201
10
- Epoch 9: running_loss 0.015872883144766092 Learning Rate:0.000176
11
- Epoch 10: running_loss 0.014423399655656382 Learning Rate:0.000150
12
- Epoch 11: running_loss 0.013206382282078266 Learning Rate:0.000124
13
- Epoch 12: running_loss 0.012179449988672366 Learning Rate:0.000099
14
- Epoch 13: running_loss 0.011303224135190248 Learning Rate:0.000075
15
- Epoch 14: running_loss 0.010542566950122515 Learning Rate:0.000054
16
- Epoch 15: running_loss 0.0098747648880817 Learning Rate:0.000035
17
- Epoch 16: running_loss 0.009292871307800798 Learning Rate:0.000020
18
- Epoch 17: running_loss 0.008775248295731015 Learning Rate:0.000009
19
- Epoch 18: running_loss 0.008311718702316284 Learning Rate:0.000002
20
- Epoch 19: running_loss 0.007893257355317474 Learning Rate:0.000000
21
- valuate on train_overfit: miou 0.8857842811448791 true fscore 0.9381048823706806
22
- bridge on train_overfit: cos_p_hat_p_mask_mean=0.004767 | cos_p_hat_q_mean=0.999904 | cos_p_hat_z_gt_mean=0.058385 | cos_p_mask_z_gt_mean=0.065508 | delta_norm_mean=0.571159 | gate_mean=0.425535 | gate_std=0.188610 | p_hat_norm_mean=32.916147 | p_mask_norm_mean=0.854710 | q_norm_mean=33.257832 | z_gt_norm_mean=1.191098
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
log/rpb_probe_a1_teacher_only_v2.txt DELETED
@@ -1,22 +0,0 @@
1
- Epoch 0: running_loss 0.15941409580409527 Learning Rate:0.000150
2
- Epoch 1: running_loss 0.0796922636218369 Learning Rate:0.000300
3
- Epoch 2: running_loss 0.05310917769869169 Learning Rate:0.000298
4
- Epoch 3: running_loss 0.03982830559834838 Learning Rate:0.000291
5
- Epoch 4: running_loss 0.03184974305331707 Learning Rate:0.000280
6
- Epoch 5: running_loss 0.02652722333247463 Learning Rate:0.000265
7
- Epoch 6: running_loss 0.022723329652632986 Learning Rate:0.000246
8
- Epoch 7: running_loss 0.019872855744324625 Learning Rate:0.000225
9
- Epoch 8: running_loss 0.017649516980681155 Learning Rate:0.000201
10
- Epoch 9: running_loss 0.015872882585972546 Learning Rate:0.000176
11
- Epoch 10: running_loss 0.01442340033298189 Learning Rate:0.000150
12
- Epoch 11: running_loss 0.013206382825349769 Learning Rate:0.000124
13
- Epoch 12: running_loss 0.012179449773751773 Learning Rate:0.000099
14
- Epoch 13: running_loss 0.011303224002144166 Learning Rate:0.000075
15
- Epoch 14: running_loss 0.010542566763858001 Learning Rate:0.000054
16
- Epoch 15: running_loss 0.00987476430600509 Learning Rate:0.000035
17
- Epoch 16: running_loss 0.009292872293907054 Learning Rate:0.000020
18
- Epoch 17: running_loss 0.0087752483992113 Learning Rate:0.000009
19
- Epoch 18: running_loss 0.008311718849367216 Learning Rate:0.000002
20
- Epoch 19: running_loss 0.007893257355317474 Learning Rate:0.000000
21
- valuate on train_overfit: miou 0.8857840351993218 true fscore 0.9381047114729881
22
- bridge on train_overfit: cos_delta_p_mask_mean=0.354064 | cos_delta_q_mean=-0.604202 | cos_delta_z_gt_mean=0.126264 | cos_p_hat_p_mask_mean=0.004767 | cos_p_hat_q_mean=0.999904 | cos_p_hat_z_gt_mean=0.058385 | cos_p_mask_z_gt_mean=0.065508 | delta_norm_mean=0.571159 | gate_mean=0.425535 | gate_std=0.188610 | p_hat_norm_mean=32.916147 | p_mask_norm_mean=0.854710 | q_norm_mean=33.257831 | z_gt_norm_mean=1.191098
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
log/rpb_probe_a1p_directional_pm_only.txt DELETED
@@ -1,22 +0,0 @@
1
- Epoch 0: running_loss 0.11214640829712152 Learning Rate:0.000150
2
- Epoch 1: running_loss 0.05601485609076917 Learning Rate:0.000300
3
- Epoch 2: running_loss 0.03723815083503723 Learning Rate:0.000298
4
- Epoch 3: running_loss 0.02785203023813665 Learning Rate:0.000291
5
- Epoch 4: running_loss 0.022219109814614058 Learning Rate:0.000280
6
- Epoch 5: running_loss 0.018464789803450305 Learning Rate:0.000265
7
- Epoch 6: running_loss 0.01578202284872532 Learning Rate:0.000246
8
- Epoch 7: running_loss 0.013773231767117977 Learning Rate:0.000225
9
- Epoch 8: running_loss 0.012206872407760885 Learning Rate:0.000201
10
- Epoch 9: running_loss 0.010958488751202821 Learning Rate:0.000176
11
- Epoch 10: running_loss 0.009943378030915152 Learning Rate:0.000150
12
- Epoch 11: running_loss 0.009091336939794322 Learning Rate:0.000124
13
- Epoch 12: running_loss 0.00837581454274746 Learning Rate:0.000099
14
- Epoch 13: running_loss 0.007767901090638978 Learning Rate:0.000075
15
- Epoch 14: running_loss 0.007241058039168516 Learning Rate:0.000054
16
- Epoch 15: running_loss 0.006779163610190153 Learning Rate:0.000035
17
- Epoch 16: running_loss 0.006378827452221338 Learning Rate:0.000020
18
- Epoch 17: running_loss 0.006023053286804093 Learning Rate:0.000009
19
- Epoch 18: running_loss 0.005704282390836038 Learning Rate:0.000002
20
- Epoch 19: running_loss 0.005416269856505096 Learning Rate:0.000000
21
- valuate on train_overfit: miou 0.883418077353781 true fscore 0.937678836286068
22
- bridge on train_overfit: cos_delta_p_mask_mean=0.818447 | cos_delta_q_mean=-0.029885 | cos_delta_z_gt_mean=0.063824 | cos_p_hat_p_mask_mean=0.047561 | cos_p_hat_q_mean=0.998200 | cos_p_hat_z_gt_mean=0.059441 | cos_p_mask_z_gt_mean=0.065508 | delta_norm_mean=2.004932 | gate_mean=0.598515 | gate_std=0.034498 | p_hat_norm_mean=33.257835 | p_mask_norm_mean=0.854710 | q_norm_mean=33.257834 | z_gt_norm_mean=1.191098
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
log/rpb_probe_a1p_directional_pm_only_a02.txt DELETED
@@ -1,22 +0,0 @@
1
- Epoch 0: running_loss 0.11209722375497222 Learning Rate:0.000150
2
- Epoch 1: running_loss 0.05594216543249786 Learning Rate:0.000300
3
- Epoch 2: running_loss 0.03709370751554767 Learning Rate:0.000298
4
- Epoch 3: running_loss 0.027660266729071736 Learning Rate:0.000291
5
- Epoch 4: running_loss 0.02200547931715846 Learning Rate:0.000280
6
- Epoch 5: running_loss 0.018238045663262408 Learning Rate:0.000265
7
- Epoch 6: running_loss 0.015544687730393239 Learning Rate:0.000246
8
- Epoch 7: running_loss 0.013526892522349954 Learning Rate:0.000225
9
- Epoch 8: running_loss 0.01195424489883913 Learning Rate:0.000201
10
- Epoch 9: running_loss 0.010702831950038672 Learning Rate:0.000176
11
- Epoch 10: running_loss 0.009686671324412931 Learning Rate:0.000150
12
- Epoch 11: running_loss 0.008837080444209278 Learning Rate:0.000124
13
- Epoch 12: running_loss 0.008126160953767024 Learning Rate:0.000099
14
- Epoch 13: running_loss 0.007524690058614526 Learning Rate:0.000075
15
- Epoch 14: running_loss 0.007005957514047622 Learning Rate:0.000054
16
- Epoch 15: running_loss 0.0065534417517483234 Learning Rate:0.000035
17
- Epoch 16: running_loss 0.006162627901443664 Learning Rate:0.000020
18
- Epoch 17: running_loss 0.005816713182462586 Learning Rate:0.000009
19
- Epoch 18: running_loss 0.005507827319793011 Learning Rate:0.000002
20
- Epoch 19: running_loss 0.005229406012222171 Learning Rate:0.000000
21
- valuate on train_overfit: miou 0.8791497684578644 true fscore 0.9370119273662567
22
- bridge on train_overfit: cos_delta_p_mask_mean=0.808940 | cos_delta_q_mean=-0.059708 | cos_delta_z_gt_mean=0.061659 | cos_p_hat_p_mask_mean=0.095240 | cos_p_hat_q_mean=0.992816 | cos_p_hat_z_gt_mean=0.062994 | cos_p_mask_z_gt_mean=0.065508 | delta_norm_mean=4.005328 | gate_mean=0.600366 | gate_std=0.034520 | p_hat_norm_mean=33.257835 | p_mask_norm_mean=0.854710 | q_norm_mean=33.257836 | z_gt_norm_mean=1.191098
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
log/rpb_probe_eval_directional_pm_only_a02.txt DELETED
@@ -1,11 +0,0 @@
1
- Epoch 0: running_loss 0.12453739601187408 Learning Rate:0.000291
2
- Epoch 1: running_loss 0.06081169372191653 Learning Rate:0.000225
3
- Epoch 2: running_loss 0.039517335942946374 Learning Rate:0.000124
4
- Epoch 3: running_loss 0.029158065939554945 Learning Rate:0.000035
5
- Epoch 4: running_loss 0.02320093212183565 Learning Rate:0.000000
6
- valuate on test_s_refer: miou 0.7251764057789819 true fscore 0.8044321979023517
7
- bridge on test_s_refer: cos_delta_p_mask_mean=0.754565 | cos_delta_q_mean=-0.062171 | cos_delta_z_gt_mean=0.077296 | cos_p_hat_p_mask_mean=0.084720 | cos_p_hat_q_mean=0.992132 | cos_p_hat_z_gt_mean=0.070147 | cos_p_mask_z_gt_mean=0.072929 | delta_norm_mean=4.598394 | gate_mean=0.625537 | gate_std=0.054432 | p_hat_norm_mean=36.239987 | p_mask_norm_mean=0.854909 | q_norm_mean=36.239987 | z_gt_norm_mean=1.275222
8
- valuate on test_u_refer: miou 0.7347305961538223 true fscore 0.8193065231665969
9
- bridge on test_u_refer: cos_delta_p_mask_mean=0.754954 | cos_delta_q_mean=-0.054195 | cos_delta_z_gt_mean=0.089436 | cos_p_hat_p_mask_mean=0.077127 | cos_p_hat_q_mean=0.994077 | cos_p_hat_z_gt_mean=0.023352 | cos_p_mask_z_gt_mean=0.081724 | delta_norm_mean=3.370293 | gate_mean=0.544416 | gate_std=0.033540 | p_hat_norm_mean=30.852975 | p_mask_norm_mean=0.855032 | q_norm_mean=30.852975 | z_gt_norm_mean=0.969146
10
-
11
- valuate on test_n_refer: metric 0.12181796133518219
 
 
 
 
 
 
 
 
 
 
 
 
log/rpb_probe_eval_directional_pm_only_a02_step0.txt DELETED
@@ -1,7 +0,0 @@
1
- Epoch 0: running_loss 0.01385641098022461 Learning Rate:0.000000
2
- valuate on test_s_refer: miou 0.7251643069144439 true fscore 0.8044421944022179
3
- bridge on test_s_refer: cos_delta_p_mask_mean=0.754565 | cos_delta_q_mean=-0.062169 | cos_delta_z_gt_mean=0.077297 | cos_p_hat_p_mask_mean=0.084709 | cos_p_hat_q_mean=0.992133 | cos_p_hat_z_gt_mean=0.070145 | cos_p_mask_z_gt_mean=0.072929 | delta_norm_mean=4.598003 | gate_mean=0.625515 | gate_std=0.054416 | p_hat_norm_mean=36.238429 | p_mask_norm_mean=0.854909 | q_norm_mean=36.238428 | z_gt_norm_mean=1.275222
4
- valuate on test_u_refer: miou 0.7346898949889146 true fscore 0.819309664927423
5
- bridge on test_u_refer: cos_delta_p_mask_mean=0.754958 | cos_delta_q_mean=-0.054197 | cos_delta_z_gt_mean=0.089438 | cos_p_hat_p_mask_mean=0.077138 | cos_p_hat_q_mean=0.994077 | cos_p_hat_z_gt_mean=0.023334 | cos_p_mask_z_gt_mean=0.081724 | delta_norm_mean=3.370548 | gate_mean=0.544434 | gate_std=0.033514 | p_hat_norm_mean=30.854847 | p_mask_norm_mean=0.855032 | q_norm_mean=30.854847 | z_gt_norm_mean=0.969146
6
-
7
- valuate on test_n_refer: metric 0.12185448408126831
 
 
 
 
 
 
 
 
log/rpb_probe_mixed_pm_only_a02_wm005_s80.txt DELETED
@@ -1,11 +0,0 @@
1
- Epoch 0: running_loss 0.11956256674602628 Learning Rate:0.000048
2
- Epoch 1: running_loss 0.059521447168663144 Learning Rate:0.000038
3
- Epoch 2: running_loss 0.03955021120297412 Learning Rate:0.000021
4
- Epoch 3: running_loss 0.029611277248477563 Learning Rate:0.000006
5
- Epoch 4: running_loss 0.023673650273121894 Learning Rate:0.000000
6
- valuate on test_s_refer: miou 0.7234249453799384 true fscore 0.8020988971926272
7
- bridge on test_s_refer: cos_delta_p_mask_mean=0.752115 | cos_delta_q_mean=-0.071252 | cos_delta_z_gt_mean=0.081856 | cos_p_hat_p_mask_mean=0.098034 | cos_p_hat_q_mean=0.989714 | cos_p_hat_z_gt_mean=0.072197 | cos_p_mask_z_gt_mean=0.072929 | delta_norm_mean=5.254162 | gate_mean=0.718218 | gate_std=0.053861 | p_hat_norm_mean=36.239985 | p_mask_norm_mean=0.854909 | q_norm_mean=36.239985 | z_gt_norm_mean=1.275222
8
- valuate on test_u_refer: miou 0.7361468947966933 true fscore 0.8214005154371261
9
- bridge on test_u_refer: cos_delta_p_mask_mean=0.754059 | cos_delta_q_mean=-0.063183 | cos_delta_z_gt_mean=0.096618 | cos_p_hat_p_mask_mean=0.090575 | cos_p_hat_q_mean=0.991959 | cos_p_hat_z_gt_mean=0.025874 | cos_p_mask_z_gt_mean=0.081724 | delta_norm_mean=3.926547 | gate_mean=0.635724 | gate_std=0.036734 | p_hat_norm_mean=30.848887 | p_mask_norm_mean=0.855032 | q_norm_mean=30.848887 | z_gt_norm_mean=0.969146
10
-
11
- valuate on test_n_refer: metric 0.12358559668064117
 
 
 
 
 
 
 
 
 
 
 
 
seg_ltpo.py DELETED
@@ -1,1372 +0,0 @@
1
- """
2
- SEG-LTPO: test-time optimization of SimToken's Fseg / q prompt token.
3
-
4
- Two optimizers are provided:
5
-
6
- ltpo_optimize – original antithetic-ES zeroth-order optimizer (Fseg space).
7
- q_ltpo_autograd – autograd optimizer that directly optimizes q (= sparse
8
- prompt embedding passed to the mask decoder) via Adam
9
- maximize, with a differentiable reward. This is the
10
- recommended path when the reward can be made differentiable.
11
-
12
- Staged autograd reward build-up:
13
- Stage 0 check_grad_connectivity — verify ∂R_iou/∂q ≠ 0
14
- Stage 1 QLTPOConfig(stage=1) — R = 0.6·R_iou − 0.2·R_area_soft − λ_reg·‖q−q₀‖²
15
- Stage 2 QLTPOConfig(stage=2) — Stage 1 + 1.0·R_align_det (z_in/z_out stopgrad)
16
- Stage 3 QLTPOConfig(stage=3) — Stage 2 + 0.2·R_temp_feat (full reward)
17
-
18
- Reward gating: use best_q only when R_task(best_q) > R_task(q_init) + gate_delta.
19
-
20
- --- ES baseline (original) ---
21
- Reward:
22
- R = λ1·R_temp_feat + λ2·R_iou_pred + λ3·R_align_contrast − λ4·R_area
23
- Update (antithetic ES, step t):
24
- F_curr = F_curr + η_t · (R+ − R−)/(2σ_t²) · eps_t
25
- best_F = argmax_F R(F)
26
- """
27
-
28
- from __future__ import annotations
29
-
30
- from dataclasses import dataclass, field
31
- from typing import Any, Dict, List, Optional, Tuple
32
-
33
- import torch
34
- import torch.nn.functional as F
35
-
36
-
37
- # ---------------------------------------------------------------------------
38
- # Per-sample diagnostics accumulator for q_ltpo_autograd
39
- # ---------------------------------------------------------------------------
40
-
41
- _q_ltpo_stats: List[Dict[str, Any]] = []
42
-
43
-
44
- def reset_q_ltpo_stats() -> None:
45
- global _q_ltpo_stats
46
- _q_ltpo_stats = []
47
-
48
-
49
- def get_q_ltpo_stats() -> List[Dict[str, Any]]:
50
- return list(_q_ltpo_stats)
51
-
52
-
53
- # ---------------------------------------------------------------------------
54
- # Configuration
55
- # ---------------------------------------------------------------------------
56
-
57
- @dataclass
58
- class LTPOConfig:
59
- T: int = 5
60
- num_anchors: int = 4
61
- sigma_schedule: List[float] = field(
62
- default_factory=lambda: [0.10, 0.08, 0.06, 0.04, 0.02]
63
- )
64
- eta_scale: float = 0.5 # η_t = eta_scale · σ_t
65
-
66
- # Reward weights
67
- lambda1: float = 0.3 # R_temp_feat
68
- lambda2: float = 0.4 # R_iou_pred
69
- lambda3: float = 1.0 # R_align_contrast
70
- lambda4: float = 0.3 # R_area penalty
71
-
72
- beta: float = 0.5 # background penalty coefficient in R_align_contrast
73
-
74
- # Reward gating: fall back to F_init when improvement < gate_delta
75
- gate_delta: float = 0.0
76
-
77
- # L2 trust-region radius on Fseg; None = disabled
78
- trust_delta: Optional[float] = None
79
-
80
-
81
- # ---------------------------------------------------------------------------
82
- # Utilities
83
- # ---------------------------------------------------------------------------
84
-
85
- def get_sam_model(model):
86
- """Return SAM visual_model, unwrapping a PeftModel wrapper if present."""
87
- base = model.base_model.model if hasattr(model, "base_model") else model
88
- return base.model.visual_model
89
-
90
-
91
- def get_anchor_indices(num_frames: int, num_anchors: int) -> List[int]:
92
- """Uniformly sample anchor frame indices from [0, num_frames-1]."""
93
- return [round(v) for v in torch.linspace(0, num_frames - 1, num_anchors).tolist()]
94
-
95
-
96
- def _precompute_dense_emb(
97
- sam_model, model_dtype: torch.dtype, device: torch.device
98
- ) -> torch.Tensor:
99
- """
100
- Constant 'no-mask' dense embedding from SAM's prompt encoder.
101
- Independent of Fseg; precompute once per sample to avoid redundant calls.
102
- Shape: [1, 256, 64, 64].
103
- """
104
- pe = sam_model.prompt_encoder
105
- H, W = pe.image_embedding_size
106
- return (
107
- pe.no_mask_embed.weight # [1, 256]
108
- .reshape(1, -1, 1, 1)
109
- .expand(1, -1, H, W)
110
- .contiguous()
111
- .to(model_dtype)
112
- .to(device)
113
- )
114
-
115
-
116
- # ---------------------------------------------------------------------------
117
- # Lightweight SAM decode (skips prompt_encoder overhead)
118
- # ---------------------------------------------------------------------------
119
-
120
- def _decode_on_anchors(
121
- fseg: torch.Tensor, # [1, 256] float32
122
- image_embeds_anchor: torch.Tensor, # [A, 256, 64, 64] model dtype
123
- dense_emb: torch.Tensor, # [1, 256, 64, 64] model dtype (constant)
124
- mask_decoder,
125
- dense_pe: torch.Tensor, # [1, 256, 64, 64]
126
- model_dtype: torch.dtype,
127
- ) -> Tuple[torch.Tensor, torch.Tensor]:
128
- """
129
- Decode anchor frames for a given Fseg.
130
-
131
- Since no points/boxes are used, prompt_encoder simply concatenates
132
- text_embeds onto an empty sparse tensor, so sparse_emb == Fseg.unsqueeze(1).
133
- We exploit this to skip the full prompt_encoder call each iteration.
134
-
135
- Returns:
136
- low_res_masks: [A, 1, 256, 256]
137
- iou_preds: [A, 1]
138
- """
139
- sparse_emb = fseg.to(model_dtype).unsqueeze(1) # [1, 1, 256]
140
- with torch.no_grad():
141
- low_res_masks, iou_preds = mask_decoder(
142
- image_embeddings=image_embeds_anchor,
143
- image_pe=dense_pe,
144
- sparse_prompt_embeddings=sparse_emb,
145
- dense_prompt_embeddings=dense_emb,
146
- multimask_output=False,
147
- )
148
- return low_res_masks, iou_preds # [A,1,256,256], [A,1]
149
-
150
-
151
- # ---------------------------------------------------------------------------
152
- # Reward computation
153
- # ---------------------------------------------------------------------------
154
-
155
- def _compute_reward(
156
- fseg: torch.Tensor, # [1, 256] float32
157
- low_res_masks: torch.Tensor, # [A, 1, 256, 256]
158
- iou_preds: torch.Tensor, # [A, 1]
159
- image_embeds_anchor: torch.Tensor, # [A, 256, 64, 64]
160
- cfg: LTPOConfig,
161
- ) -> float:
162
- num_anchor = low_res_masks.shape[0]
163
- device = fseg.device
164
-
165
- # Work entirely in float32 for numerical stability
166
- masks_soft = torch.sigmoid(low_res_masks.float().squeeze(1)) # [A, 256, 256]
167
- img_embs = image_embeds_anchor.float() # [A, 256, 64, 64]
168
-
169
- # q lives in SAM's 256-d prompt space (same as Fseg after text_hidden_fcs)
170
- q = F.normalize(fseg[0].float(), dim=0) # [256]
171
-
172
- # Downsample soft masks 256×256 → 64×64 to match image_embed spatial dims.
173
- # Keep as soft weights (no hard threshold) so the reward surface is smooth.
174
- masks_64 = F.interpolate(
175
- masks_soft.unsqueeze(1), size=(64, 64),
176
- mode="bilinear", align_corners=False,
177
- ).squeeze(1) # [A, 64, 64]
178
-
179
- # ── Per-frame masked pooling ──────────────────────────────────────────
180
- z_ins: List[torch.Tensor] = []
181
- z_outs: List[torch.Tensor] = []
182
- for t in range(num_anchor):
183
- m = masks_64[t] # [64, 64]
184
- img = img_embs[t] # [256, 64, 64]
185
-
186
- # Soft weighted average pooling over foreground / background
187
- z_in = (img * m.unsqueeze(0)).sum(dim=[1, 2]) / (m.sum() + 1e-6)
188
- z_out = (img * (1.0 - m).unsqueeze(0)).sum(dim=[1, 2]) / ((1.0 - m).sum() + 1e-6)
189
-
190
- z_ins.append(F.normalize(z_in, dim=0)) # [256]
191
- z_outs.append(F.normalize(z_out, dim=0)) # [256]
192
-
193
- # ── R_align_contrast ──────────────────────────────────────────────────
194
- # Maximise Fseg↔inside alignment while penalising Fseg↔outside alignment.
195
- # Contrast term prevents reward-hacking via large masks:
196
- # a large mask pulls inside and outside features together, shrinking the gap.
197
- r_align = sum(
198
- (q @ z_ins[t]) - cfg.beta * (q @ z_outs[t])
199
- for t in range(num_anchor)
200
- ) / num_anchor
201
-
202
- # ── R_iou_pred ────────────────────────────────────────────────────────
203
- # SAM's internal mask-quality head, calibrated during SAM training.
204
- r_iou = iou_preds.float().mean()
205
-
206
- # ── R_temp_feat ───────────────────────────────────────────────────────
207
- # Feature-space consistency between adjacent anchor frames.
208
- # Harder to game than mask-IoU: large masks pool diverse background
209
- # features across frames, degrading cosine similarity.
210
- r_temp = torch.tensor(0.0, device=device)
211
- if num_anchor > 1:
212
- r_temp = sum(
213
- z_ins[t] @ z_ins[t + 1] for t in range(num_anchor - 1)
214
- ) / (num_anchor - 1)
215
-
216
- # ── R_area ────────────────────────────────────────────────────────────
217
- r_area = masks_64.mean()
218
-
219
- R = (cfg.lambda1 * r_temp
220
- + cfg.lambda2 * r_iou
221
- + cfg.lambda3 * r_align
222
- - cfg.lambda4 * r_area)
223
-
224
- return R.item()
225
-
226
-
227
- # ---------------------------------------------------------------------------
228
- # Ablation baseline: Best-of-2 Random (no iterative update)
229
- # ---------------------------------------------------------------------------
230
-
231
- def best_of_2_optimize(
232
- F_init: torch.Tensor,
233
- image_embeds: torch.Tensor,
234
- anchor_indices: List[int],
235
- sam_model,
236
- model_dtype: torch.dtype,
237
- cfg: LTPOConfig,
238
- ) -> torch.Tensor:
239
- """
240
- Best-of-2 Random baseline.
241
-
242
- Sample one antithetic pair (F+, F-) using the first sigma value,
243
- evaluate both, return whichever has the higher reward.
244
- No iterative update — serves as the ablation for the update rule.
245
- Same reward gating as ltpo_optimize for a fair comparison.
246
- """
247
- device = F_init.device
248
- image_embeds_anchor = image_embeds[anchor_indices]
249
-
250
- dense_emb = _precompute_dense_emb(sam_model, model_dtype, device)
251
- dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device)
252
- mask_dec = sam_model.mask_decoder
253
-
254
- lrm0, iou0 = _decode_on_anchors(
255
- F_init, image_embeds_anchor, dense_emb, mask_dec, dense_pe, model_dtype
256
- )
257
- R_init = _compute_reward(F_init, lrm0, iou0, image_embeds_anchor, cfg)
258
-
259
- sigma = cfg.sigma_schedule[0]
260
- eps = torch.randn_like(F_init) * sigma
261
- F_plus = F_init + eps
262
- F_minus = F_init - eps
263
-
264
- lrm_p, iou_p = _decode_on_anchors(
265
- F_plus, image_embeds_anchor, dense_emb, mask_dec, dense_pe, model_dtype
266
- )
267
- lrm_m, iou_m = _decode_on_anchors(
268
- F_minus, image_embeds_anchor, dense_emb, mask_dec, dense_pe, model_dtype
269
- )
270
- R_plus = _compute_reward(F_plus, lrm_p, iou_p, image_embeds_anchor, cfg)
271
- R_minus = _compute_reward(F_minus, lrm_m, iou_m, image_embeds_anchor, cfg)
272
-
273
- best_R, best_F = R_init, F_init.clone()
274
- if R_plus > best_R: best_R, best_F = R_plus, F_plus.clone()
275
- if R_minus > best_R: best_R, best_F = R_minus, F_minus.clone()
276
-
277
- if best_R <= R_init + cfg.gate_delta:
278
- return F_init
279
- return best_F
280
-
281
-
282
- # ---------------------------------------------------------------------------
283
- # Full-video decode with a given Fseg
284
- # ---------------------------------------------------------------------------
285
-
286
- def _sobel_edge(rgb_frames: torch.Tensor) -> torch.Tensor:
287
- """Compute Sobel edge magnitude from normalized RGB frames.
288
-
289
- Args:
290
- rgb_frames: [T, 3, H, W] float32 (SAM-normalized, CUDA)
291
- Returns:
292
- edge: [T, 1, H, W] float32, non-negative
293
- """
294
- gray = rgb_frames.float().mean(dim=1, keepdim=True) # [T, 1, H, W]
295
- kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
296
- dtype=torch.float32, device=rgb_frames.device).view(1, 1, 3, 3)
297
- ky = kx.transpose(2, 3)
298
- gx = F.conv2d(gray, kx, padding=1)
299
- gy = F.conv2d(gray, ky, padding=1)
300
- return torch.sqrt(gx ** 2 + gy ** 2 + 1e-6) # [T, 1, H, W]
301
-
302
-
303
- def _boundary_edge_score(
304
- low_res_masks: torch.Tensor, # [T, K, 256, 256] logits
305
- rgb_frames: torch.Tensor, # [T, 3, H, W] float32
306
- resize: tuple, # (H_resized, W_resized)
307
- area_temp: float = 5.0,
308
- ) -> torch.Tensor:
309
- """Score each of K mask candidates by boundary-edge alignment.
310
-
311
- R_edge = <soft_boundary_band, Sobel_edge> / (sum(soft_boundary_band) + ε)
312
- Rewards masks whose boundaries coincide with image edges.
313
-
314
- Returns: [T, K] float32 scores (higher = better boundary alignment)
315
- """
316
- T, K = low_res_masks.shape[:2]
317
- H_r, W_r = resize
318
-
319
- # Upsample all candidates to resized image resolution at once
320
- masks_up = F.interpolate(
321
- low_res_masks.reshape(T * K, 1, 256, 256).float(),
322
- size=(H_r, W_r), mode="bilinear", align_corners=False,
323
- ).reshape(T, K, H_r, W_r) # [T, K, H, W]
324
-
325
- E = _sobel_edge(rgb_frames[:, :, :H_r, :W_r]) # [T, 1, H, W]
326
-
327
- m = torch.sigmoid(masks_up / area_temp) # [T, K, H, W]
328
- b = 4.0 * m * (1.0 - m) # soft boundary band
329
- num = (b * E.squeeze(1).unsqueeze(1)).sum(dim=[2, 3]) # [T, K]
330
- den = b.sum(dim=[2, 3]) + 1e-6
331
- return num / den # [T, K]
332
-
333
-
334
- def decode_full_video(
335
- fseg: torch.Tensor, # [1, 256] float32
336
- image_embeds: torch.Tensor, # [T, 256, 64, 64] model dtype on CUDA
337
- sam_model,
338
- resize: tuple, # (H_resized, W_resized)
339
- orgsize: tuple, # (H_orig, W_orig)
340
- model_dtype: torch.dtype,
341
- rgb_frames: Optional[torch.Tensor] = None, # [T, 3, H, W]; enables edge selection
342
- multimask: bool = False, # True = 3 candidates; False = single mask
343
- ) -> torch.Tensor:
344
- """Decode all T frames with the given Fseg.
345
-
346
- Selection logic (applied per-frame):
347
- - multimask=False, rgb_frames=None : original single-mask decode (baseline)
348
- - multimask=True, rgb_frames=None : 3 candidates, select by SAM iou_pred
349
- - multimask=True, rgb_frames=* : 3 candidates, select by boundary-edge score
350
- (boundary band × Sobel edge; directly rewards boundary-image alignment)
351
-
352
- Returns raw logit mask [T, H_orig, W_orig] (not yet sigmoid).
353
- """
354
- device = image_embeds.device
355
- dense_emb = _precompute_dense_emb(sam_model, model_dtype, device)
356
- dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device)
357
- sparse_emb = fseg.to(model_dtype).unsqueeze(1) # [1, 1, 256]
358
-
359
- with torch.no_grad():
360
- low_res_masks, iou_preds = sam_model.mask_decoder(
361
- image_embeddings=image_embeds,
362
- image_pe=dense_pe,
363
- sparse_prompt_embeddings=sparse_emb,
364
- dense_prompt_embeddings=dense_emb,
365
- multimask_output=multimask,
366
- ) # [T, K, 256, 256], [T, K] where K=1 or K=3
367
-
368
- if multimask:
369
- T = low_res_masks.shape[0]
370
- if rgb_frames is not None:
371
- # Step 1b: boundary-edge score selects best candidate
372
- scores = _boundary_edge_score(low_res_masks, rgb_frames, resize)
373
- else:
374
- # Step 1a: SAM's own iou_pred selects best candidate
375
- scores = iou_preds
376
- best_idx = scores.argmax(dim=1) # [T]
377
- low_res_masks = low_res_masks[torch.arange(T, device=device), best_idx].unsqueeze(1)
378
-
379
- pred_mask = sam_model.postprocess_masks(
380
- low_res_masks, input_size=resize, original_size=orgsize
381
- ) # [T, 1, H, W]
382
- return pred_mask.squeeze(1) # [T, H, W]
383
-
384
-
385
- # ---------------------------------------------------------------------------
386
- # Main optimisation loop
387
- # ---------------------------------------------------------------------------
388
-
389
- def ltpo_optimize(
390
- F_init: torch.Tensor, # [1, 256] float32 on CUDA
391
- image_embeds: torch.Tensor, # [T, 256, 64, 64] model dtype on CUDA
392
- anchor_indices: List[int],
393
- sam_model,
394
- model_dtype: torch.dtype,
395
- cfg: LTPOConfig,
396
- ) -> torch.Tensor:
397
- """
398
- Optimise Fseg at test time via antithetic ES.
399
-
400
- Returns best Fseg found [1, 256] float32.
401
- Falls back to F_init when reward gating rejects all updates.
402
- """
403
- device = F_init.device
404
- image_embeds_anchor = image_embeds[anchor_indices] # [A, 256, 64, 64]
405
-
406
- # Precompute constants shared across every optimisation step
407
- dense_emb = _precompute_dense_emb(sam_model, model_dtype, device)
408
- dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device)
409
- mask_dec = sam_model.mask_decoder
410
-
411
- # ── Evaluate initial token ────────────────────────────────────────────
412
- lrm0, iou0 = _decode_on_anchors(
413
- F_init, image_embeds_anchor, dense_emb, mask_dec, dense_pe, model_dtype
414
- )
415
- R_init = _compute_reward(F_init, lrm0, iou0, image_embeds_anchor, cfg)
416
-
417
- best_F, best_R = F_init.clone(), R_init
418
- F_curr = F_init.clone()
419
-
420
- # ── Optimisation loop ─────────────────────────────────────────────────
421
- for t in range(cfg.T):
422
- sigma_t = cfg.sigma_schedule[t]
423
- eta_t = cfg.eta_scale * sigma_t
424
-
425
- eps = torch.randn_like(F_curr) * sigma_t
426
- F_plus = F_curr + eps
427
- F_minus = F_curr - eps
428
-
429
- lrm_p, iou_p = _decode_on_anchors(
430
- F_plus, image_embeds_anchor, dense_emb, mask_dec, dense_pe, model_dtype
431
- )
432
- lrm_m, iou_m = _decode_on_anchors(
433
- F_minus, image_embeds_anchor, dense_emb, mask_dec, dense_pe, model_dtype
434
- )
435
-
436
- R_plus = _compute_reward(F_plus, lrm_p, iou_p, image_embeds_anchor, cfg)
437
- R_minus = _compute_reward(F_minus, lrm_m, iou_m, image_embeds_anchor, cfg)
438
-
439
- # Track the best token seen across all evaluated candidates
440
- if R_plus > best_R:
441
- best_R, best_F = R_plus, F_plus.clone()
442
- if R_minus > best_R:
443
- best_R, best_F = R_minus, F_minus.clone()
444
-
445
- # Antithetic policy-gradient update of the iterate
446
- # Formula: F_{t+1} = F_t + η_t · (R+ - R−)/(2σ_t²) · eps_t
447
- grad_est = (R_plus - R_minus) / (2.0 * sigma_t ** 2)
448
- F_curr = F_curr + eta_t * grad_est * eps
449
-
450
- # Optional L2 trust-region: keep F_curr within radius trust_delta of F_init
451
- if cfg.trust_delta is not None:
452
- diff = F_curr - F_init
453
- norm = diff.norm()
454
- if norm > cfg.trust_delta:
455
- F_curr = F_init + diff * (cfg.trust_delta / norm)
456
-
457
- # ── Reward gating ─────────────────────────────────────────────────────
458
- # Reject the update when there is no meaningful improvement over the
459
- # initial token (handles Null-like samples where no target exists).
460
- if best_R <= R_init + cfg.gate_delta:
461
- return F_init
462
- return best_F
463
-
464
-
465
- # ===========================================================================
466
- # q-LTPO-autograd: differentiable test-time optimization of the prompt token
467
- # ===========================================================================
468
-
469
- @dataclass
470
- class QLTPOConfig:
471
- """Configuration for q_ltpo_autograd (Stages 1–3 + Stage 2-ext variants).
472
-
473
- stage controls which reward terms are active:
474
- 1 R_iou + R_area_soft + reg (baseline autograd)
475
- 2 Stage 1 + R_align_det (z_in/z_out stopgrad) (self-bootstrapped alignment)
476
- 3 Stage 2 + R_temp_feat (full reward)
477
- 21 Stage 1 + R_tether (P1a: tether probe) (frozen r_ref via q_init attn)
478
- 22 Stage 1 + R_faithful (P1b: faithful ext-ref) (z_in/z_out vs frozen r_ref)
479
- """
480
- stage: int = 1
481
- T: int = 5
482
- num_anchors: int = 4
483
-
484
- # ── Optimizer ──────────────────────────────────────────────────────────
485
- # lr=0 → auto-set to 0.01 × RMS(q_init); any positive value is used directly
486
- lr: float = 0.0
487
- # max_drift=0 → auto-set to 0.5 × ‖q_init‖; any positive value is a hard radius
488
- max_drift: float = 0.0
489
-
490
- # ── Stage 1 reward weights ─────────────────────────────────────────────
491
- lambda_iou: float = 0.6
492
- lambda_area: float = 0.2
493
- lambda_reg: float = 0.01
494
- area_temp: float = 5.0 # sigmoid temperature for R_area_soft
495
-
496
- # ── Stage 2 additional weights ─────────────────────────────────────────
497
- lambda_align: float = 1.0
498
- beta_align: float = 0.5 # background penalty coefficient in R_align
499
-
500
- # ── Stage 3 additional weights ─────────────────────────────────────────
501
- lambda_temp: float = 0.2
502
-
503
- # ── Gating ─────────────────────────────────────────────────────────────
504
- gate_delta: float = 0.0
505
-
506
- # ── e0-modulated R_iou (principled Null-safety) ────────────────────────
507
- # e0 = stopgrad(R_area_soft(q_init)): the initial soft-area fraction acts
508
- # as an existence prior on the R_iou term.
509
- # "none" → original behavior (e0 = 1, no modulation)
510
- # "identity" → e0 = R_area_soft(q_init) [first version]
511
- # "sqrt" → e0 = sqrt(R_area_soft(q_init) + e0_eps)
512
- e0_modulation: str = "identity"
513
- e0_eps: float = 1e-4 # epsilon for "sqrt" variant
514
-
515
- # ── Stage 2-ext: external reference (stages 21 and 22) ────────────────
516
- # r_ref = AttnPool(image_feats_anchor, q_init): frozen visual anchor derived
517
- # from q_init's attention over SAM image features. Breaks Stage 2's
518
- # self-confirming bias by providing a mask-independent teacher.
519
- # r_ref_temp: softmax temperature for attention pooling (sqrt(256) = 16).
520
- r_ref_temp: float = 16.0
521
-
522
- # ── Direction B: boundary precision rewards ────────────────────────────
523
- # B1: asymmetric area expansion penalty
524
- # Only penalises growth beyond (1+τ)×e0; allows mask contraction.
525
- # Targets the observed pattern where LTPO slightly expands masks into
526
- # non-target regions (recall↑ but precision↓, hurting F-score).
527
- # B2: boundary sharpness reward
528
- # -mean(4m(1-m)) with temperature=1.0; rewards bimodal (certain)
529
- # mask predictions, encouraging cleaner boundary predictions.
530
- lambda_area_inc: float = 0.0 # B1 weight (0 = disabled)
531
- area_inc_tau: float = 0.0 # B1 tolerance band: allow (1+τ)×e0
532
- lambda_sharp: float = 0.0 # B2 weight (0 = disabled)
533
-
534
- # ── Oracle Null-safety gate (analysis only; NOT for final method) ──────
535
- # Derived from test-set distribution (Null area_hard ≈ 0.01, Seen ≈ 0.05)
536
- # so must not be used in reported results. Set null_gate_delta=0 to disable.
537
- null_area_threshold: float = 0.02 # hard area fraction below which guard activates
538
- null_gate_delta: float = 0.0 # 0 = disabled; 0.05 = oracle experiment
539
-
540
- # ── Direction II: Frame-adaptive token optimization (stage=4) ─────────
541
- # q_t = q_global + delta_t, where delta_t is a per-anchor residual.
542
- # Optimizes q_global and {delta_t} jointly with Adam.
543
- # lambda_residual: soft L2 penalty on delta_t
544
- # lambda_smooth_temp: temporal smoothness penalty on adjacent delta differences
545
- # max_delta_drift_scale: per-anchor hard L2 clip = scale × ‖q_init‖
546
- # Prevents individual anchors from wandering to a completely different visual mode.
547
- # Keep << max_drift (0.5) so delta stays a "small frame correction" to q_global.
548
- # 0.1 is tight (delta ≤ 20% of global drift budget), 0.3 is moderate.
549
- lambda_residual: float = 0.001
550
- lambda_smooth_temp: float = 0.0
551
- max_delta_drift_scale: float = 0.1 # per-anchor clip = scale × ‖q_init‖
552
-
553
-
554
- # ---------------------------------------------------------------------------
555
- # e0 helper
556
- # ---------------------------------------------------------------------------
557
-
558
- def _compute_e0(r_area_soft_init: float, cfg: "QLTPOConfig") -> float:
559
- """Compute the existence-prior weight from the initial soft area."""
560
- if cfg.e0_modulation == "identity":
561
- return r_area_soft_init
562
- if cfg.e0_modulation == "sqrt":
563
- return (r_area_soft_init + cfg.e0_eps) ** 0.5
564
- return 1.0 # "none"
565
-
566
-
567
- # ---------------------------------------------------------------------------
568
- # Differentiable anchor decode (float32 throughout; no torch.no_grad)
569
- # ---------------------------------------------------------------------------
570
-
571
- def _decode_on_anchors_diff(
572
- q: torch.Tensor, # [1, 256] float32
573
- image_embeds_anchor_fp32: torch.Tensor, # [A, 256, 64, 64] float32
574
- dense_emb_fp32: torch.Tensor, # [1, 256, 64, 64] float32
575
- mask_decoder,
576
- dense_pe_fp32: torch.Tensor, # [1, 256, 64, 64] float32
577
- ) -> Tuple[torch.Tensor, torch.Tensor]:
578
- """Differentiable mask-decoder forward.
579
-
580
- All inputs are float32 to avoid fp16 gradient truncation.
581
- q may be a Parameter (requires_grad=True) or a plain detached tensor.
582
- Returns low_res_masks [A,1,256,256] and iou_preds [A,1], both float32.
583
- """
584
- sparse_emb = q.unsqueeze(1) # [1, 1, 256]
585
- low_res_masks, iou_preds = mask_decoder(
586
- image_embeddings=image_embeds_anchor_fp32,
587
- image_pe=dense_pe_fp32,
588
- sparse_prompt_embeddings=sparse_emb,
589
- dense_prompt_embeddings=dense_emb_fp32,
590
- multimask_output=False,
591
- )
592
- return low_res_masks, iou_preds # [A,1,256,256], [A,1]
593
-
594
-
595
- # ---------------------------------------------------------------------------
596
- # Differentiable reward components
597
- # ---------------------------------------------------------------------------
598
-
599
- def _task_reward_stage1(
600
- lrm: torch.Tensor, # [A,1,256,256] float32
601
- iou: torch.Tensor, # [A,1] float32
602
- cfg: QLTPOConfig,
603
- e0: float = 1.0,
604
- ) -> torch.Tensor:
605
- """Task reward (no regularization): used for best_q tracking and gating.
606
-
607
- e0 is the stopgrad existence prior: R_area_soft(q_init) scaled via
608
- cfg.e0_modulation. When e0 << 1 the iou term is suppressed, so the
609
- optimizer sees only the area-penalty gradient and naturally tends toward
610
- smaller (more conservative) masks — the correct behavior when the initial
611
- prediction is near-empty (Null frames).
612
-
613
- Optional boundary precision terms (Direction B):
614
- B1 (lambda_area_inc > 0): asymmetric expansion penalty
615
- -λ_inc · ReLU(r_area - (1+τ)·e0)
616
- Penalises mask growth beyond the initial area (+ tolerance band τ).
617
- e0 doubles as the stopgrad initial-area threshold — zero extra cost.
618
- B2 (lambda_sharp > 0): boundary sharpness reward
619
- -λ_sharp · mean(4m(1-m)) with m = sigmoid(lrm), temperature=1.0
620
- Maximises bimodality of mask logits → cleaner boundary predictions.
621
- """
622
- r_iou = iou.mean()
623
- r_area = torch.sigmoid(lrm / cfg.area_temp).mean()
624
- R = cfg.lambda_iou * e0 * r_iou - cfg.lambda_area * r_area
625
-
626
- # B1: penalise expansion beyond (1+τ)×e0 (allow contraction freely)
627
- if cfg.lambda_area_inc > 0.0:
628
- area_ceil = (1.0 + cfg.area_inc_tau) * e0
629
- R = R - cfg.lambda_area_inc * F.relu(r_area - area_ceil)
630
-
631
- # B2: reward confident (bimodal) boundary predictions
632
- if cfg.lambda_sharp > 0.0:
633
- m_sharp = torch.sigmoid(lrm) # temperature=1.0 (sharp)
634
- boundary_uncertain = 4.0 * m_sharp * (1.0 - m_sharp)
635
- R = R - cfg.lambda_sharp * boundary_uncertain.mean()
636
-
637
- return R
638
-
639
-
640
- def _task_reward_stage2(
641
- q: torch.Tensor, # [1, 256] float32
642
- lrm: torch.Tensor, # [A,1,256,256] float32
643
- iou: torch.Tensor, # [A,1] float32
644
- image_embeds_anchor_fp32: torch.Tensor, # [A, 256, 64, 64] float32
645
- cfg: QLTPOConfig,
646
- e0: float = 1.0,
647
- ) -> torch.Tensor:
648
- """Stage 2 task reward: Stage 1 + R_align_det (z_in/z_out are stopgrad)."""
649
- r_s1 = _task_reward_stage1(lrm, iou, cfg, e0)
650
-
651
- A = lrm.shape[0]
652
- masks_64 = F.interpolate(
653
- torch.sigmoid(lrm.squeeze(1) / cfg.area_temp).unsqueeze(1),
654
- size=(64, 64), mode="bilinear", align_corners=False,
655
- ).squeeze(1) # [A, 64, 64]
656
-
657
- q_norm = F.normalize(q[0], dim=0) # [256]
658
- r_align = torch.tensor(0.0, device=q.device)
659
- for t in range(A):
660
- m = masks_64[t].detach() # stopgrad on z_in/z_out
661
- img = image_embeds_anchor_fp32[t] # [256, 64, 64]
662
- z_in = F.normalize((img * m.unsqueeze(0)).sum(dim=[1, 2]) / (m.sum() + 1e-6), dim=0)
663
- z_out = F.normalize((img * (1 - m).unsqueeze(0)).sum(dim=[1, 2]) / ((1 - m).sum() + 1e-6), dim=0)
664
- r_align = r_align + q_norm @ z_in - cfg.beta_align * (q_norm @ z_out)
665
- r_align = r_align / A
666
-
667
- return r_s1 + cfg.lambda_align * r_align
668
-
669
-
670
- def _task_reward_stage3(
671
- q: torch.Tensor,
672
- lrm: torch.Tensor,
673
- iou: torch.Tensor,
674
- image_embeds_anchor_fp32: torch.Tensor,
675
- cfg: QLTPOConfig,
676
- e0: float = 1.0,
677
- ) -> torch.Tensor:
678
- """Stage 3 task reward: Stage 2 + R_temp_feat."""
679
- r_s2 = _task_reward_stage2(q, lrm, iou, image_embeds_anchor_fp32, cfg, e0)
680
-
681
- A = lrm.shape[0]
682
- if A < 2:
683
- return r_s2
684
-
685
- masks_64 = F.interpolate(
686
- torch.sigmoid(lrm.squeeze(1) / cfg.area_temp).unsqueeze(1),
687
- size=(64, 64), mode="bilinear", align_corners=False,
688
- ).squeeze(1) # [A, 64, 64]
689
-
690
- z_ins = []
691
- for t in range(A):
692
- m = masks_64[t].detach()
693
- img = image_embeds_anchor_fp32[t]
694
- z_in = F.normalize((img * m.unsqueeze(0)).sum(dim=[1, 2]) / (m.sum() + 1e-6), dim=0)
695
- z_ins.append(z_in)
696
-
697
- r_temp = sum(z_ins[t] @ z_ins[t + 1] for t in range(A - 1)) / (A - 1)
698
- return r_s2 + cfg.lambda_temp * r_temp
699
-
700
-
701
- @torch.no_grad()
702
- def _compute_r_ref(
703
- q_init: torch.Tensor, # [1, 256] float32
704
- image_embeds_anchor: torch.Tensor, # [A, 256, 64, 64] float32
705
- temp: float = 16.0,
706
- ) -> Tuple[torch.Tensor, torch.Tensor]:
707
- """Frozen external visual reference via attention pooling guided by q_init.
708
-
709
- r_ref: regions most attended by q_init (positive anchor).
710
- r_neg: regions least attended by q_init (anti-attended negative).
711
- Both are in the SAM 256d space — no projection needed.
712
- Computed once before the optimization loop and kept fixed (stopgrad).
713
- """
714
- img_flat = image_embeds_anchor.flatten(2) # [A, 256, H*W]
715
- q_norm = F.normalize(q_init[0], dim=0) # [256]
716
- img_norm = F.normalize(img_flat, dim=1) # [A, 256, H*W]
717
-
718
- # cosine similarity between q and each spatial position
719
- attn = torch.einsum('d,adp->ap', q_norm, img_norm) # [A, H*W]
720
-
721
- attn_w_pos = torch.softmax( attn / temp, dim=-1) # [A, H*W]
722
- attn_w_neg = torch.softmax(-attn / temp, dim=-1) # [A, H*W] anti-attended
723
-
724
- # soft attention pooling in the original (non-normalized) feature space
725
- r_ref_frames = torch.einsum('ap,adp->ad', attn_w_pos, img_flat) # [A, 256]
726
- r_neg_frames = torch.einsum('ap,adp->ad', attn_w_neg, img_flat) # [A, 256]
727
-
728
- r_ref = F.normalize(r_ref_frames.mean(0), dim=0) # [256]
729
- r_neg = F.normalize(r_neg_frames.mean(0), dim=0) # [256]
730
- return r_ref, r_neg
731
-
732
-
733
- def _task_reward_stage2_tether(
734
- q: torch.Tensor, # [1, 256] float32
735
- lrm: torch.Tensor, # [A,1,256,256] float32
736
- iou: torch.Tensor, # [A,1] float32
737
- r_ref: torch.Tensor, # [256] frozen
738
- r_neg: torch.Tensor, # [256] frozen
739
- cfg: QLTPOConfig,
740
- e0: float = 1.0,
741
- ) -> torch.Tensor:
742
- """Stage 21 (P1a tether): Stage 1 + R_tether.
743
-
744
- R_tether = cos(q, r_ref) - beta·cos(q, r_neg)
745
- q is pulled toward the frozen visual anchor without touching mask features.
746
- Tests whether a fixed external reference stabilizes the optimization trajectory.
747
- """
748
- r_s1 = _task_reward_stage1(lrm, iou, cfg, e0)
749
- q_norm = F.normalize(q[0], dim=0)
750
- r_tether = q_norm @ r_ref - cfg.beta_align * (q_norm @ r_neg)
751
- return r_s1 + cfg.lambda_align * r_tether
752
-
753
-
754
- def _task_reward_stage2_faithful(
755
- q: torch.Tensor, # [1, 256] float32
756
- lrm: torch.Tensor, # [A,1,256,256] float32
757
- iou: torch.Tensor, # [A,1] float32
758
- image_embeds_anchor_fp32: torch.Tensor, # [A, 256, 64, 64] float32
759
- r_ref: torch.Tensor, # [256] frozen
760
- cfg: QLTPOConfig,
761
- e0: float = 1.0,
762
- ) -> torch.Tensor:
763
- """Stage 22 (P1b faithful): Stage 1 + R_faithful.
764
-
765
- R_faithful = mean_t[ cos(z_in(q,t), r_ref) - beta·cos(z_out(q,t), r_ref) ]
766
- z_in/z_out come from the *current* mask (change during optimization), but the
767
- teacher r_ref is frozen — breaking Stage 2's self-confirming bias while keeping
768
- the same structural form (mask-region vs. reference alignment).
769
- """
770
- r_s1 = _task_reward_stage1(lrm, iou, cfg, e0)
771
- A = lrm.shape[0]
772
- masks_64 = F.interpolate(
773
- torch.sigmoid(lrm.squeeze(1) / cfg.area_temp).unsqueeze(1),
774
- size=(64, 64), mode="bilinear", align_corners=False,
775
- ).squeeze(1) # [A, 64, 64]
776
-
777
- r_align = torch.tensor(0.0, device=q.device)
778
- for t in range(A):
779
- m = masks_64[t].detach() # stopgrad on mask weights only
780
- img = image_embeds_anchor_fp32[t] # [256, 64, 64]
781
- z_in = F.normalize((img * m.unsqueeze(0)).sum(dim=[1, 2]) / (m.sum() + 1e-6), dim=0)
782
- z_out = F.normalize((img * (1 - m).unsqueeze(0)).sum(dim=[1, 2]) / ((1 - m).sum() + 1e-6), dim=0)
783
- # teacher is r_ref (frozen), not z_in itself — no confirmation bias
784
- r_align = r_align + z_in @ r_ref - cfg.beta_align * (z_out @ r_ref)
785
- r_align = r_align / A
786
-
787
- return r_s1 + cfg.lambda_align * r_align
788
-
789
-
790
- def _decode_on_anchors_diff_adaptive(
791
- q_global: torch.Tensor, # [1, 256] float32, requires_grad
792
- delta: torch.Tensor, # [A, 256] float32, requires_grad
793
- image_embeds_anchor_fp32: torch.Tensor, # [A, 256, 64, 64] float32, detached
794
- dense_emb_fp32: torch.Tensor, # [1, 256, 64, 64] float32, detached
795
- mask_decoder,
796
- dense_pe_fp32: torch.Tensor, # [1, 256, 64, 64] float32, detached
797
- ) -> Tuple[torch.Tensor, torch.Tensor]:
798
- """Frame-adaptive differentiable decode: each anchor t uses q_t = q_global + delta[t].
799
-
800
- Loops over A anchors to preserve gradient flow through both q_global and delta.
801
- Returns low_res_masks [A,1,256,256] and iou_preds [A,1], both float32.
802
- """
803
- A = image_embeds_anchor_fp32.shape[0]
804
- lrm_list: List[torch.Tensor] = []
805
- iou_list: List[torch.Tensor] = []
806
- for t in range(A):
807
- q_t = q_global + delta[t : t + 1] # [1, 256]
808
- sparse_emb = q_t.unsqueeze(1) # [1, 1, 256]
809
- lrm_t, iou_t = mask_decoder(
810
- image_embeddings=image_embeds_anchor_fp32[t : t + 1],
811
- image_pe=dense_pe_fp32,
812
- sparse_prompt_embeddings=sparse_emb,
813
- dense_prompt_embeddings=dense_emb_fp32,
814
- multimask_output=False,
815
- ) # [1,1,256,256], [1,1]
816
- lrm_list.append(lrm_t)
817
- iou_list.append(iou_t)
818
- return torch.cat(lrm_list, dim=0), torch.cat(iou_list, dim=0) # [A,1,256,256], [A,1]
819
-
820
-
821
- def _task_reward_frame_adaptive(
822
- lrm: torch.Tensor, # [A, 1, 256, 256] float32
823
- iou: torch.Tensor, # [A, 1] float32
824
- cfg: "QLTPOConfig",
825
- e0_vec: List[float], # per-anchor existence priors [A]
826
- ) -> torch.Tensor:
827
- """Per-anchor task reward averaged over anchors (no regularization)."""
828
- A = lrm.shape[0]
829
- R = torch.tensor(0.0, device=lrm.device)
830
- for t in range(A):
831
- r_iou_t = iou[t].mean()
832
- r_area_t = torch.sigmoid(lrm[t] / cfg.area_temp).mean()
833
- R = R + cfg.lambda_iou * e0_vec[t] * r_iou_t - cfg.lambda_area * r_area_t
834
- return R / A
835
-
836
-
837
- def _compute_full_reward_adaptive(
838
- q_global: torch.Tensor, # [1, 256]
839
- delta: torch.Tensor, # [A, 256]
840
- lrm: torch.Tensor, # [A, 1, 256, 256]
841
- iou: torch.Tensor, # [A, 1]
842
- q_init: torch.Tensor, # [1, 256] detached
843
- cfg: "QLTPOConfig",
844
- e0_vec: List[float],
845
- ) -> torch.Tensor:
846
- """Full adaptive reward = task + residual penalty + temporal smoothness + L2 reg."""
847
- r_task = _task_reward_frame_adaptive(lrm, iou, cfg, e0_vec)
848
- r_delta = delta.pow(2).sum()
849
- r_reg = (q_global - q_init).pow(2).sum()
850
- R = r_task - cfg.lambda_residual * r_delta - cfg.lambda_reg * r_reg
851
-
852
- A = delta.shape[0]
853
- if A > 1 and cfg.lambda_smooth_temp > 0.0:
854
- r_smooth = torch.tensor(0.0, device=delta.device)
855
- for t in range(A - 1):
856
- r_smooth = r_smooth + (delta[t] - delta[t + 1]).pow(2).sum()
857
- R = R - cfg.lambda_smooth_temp * r_smooth / (A - 1)
858
-
859
- return R
860
-
861
-
862
- def _compute_task_reward(
863
- q: torch.Tensor,
864
- lrm: torch.Tensor,
865
- iou: torch.Tensor,
866
- image_embeds_anchor_fp32: torch.Tensor,
867
- cfg: QLTPOConfig,
868
- e0: float = 1.0,
869
- r_ref: Optional[torch.Tensor] = None,
870
- r_neg: Optional[torch.Tensor] = None,
871
- ) -> torch.Tensor:
872
- """Dispatch to the correct stage's task reward."""
873
- if cfg.stage == 1:
874
- return _task_reward_stage1(lrm, iou, cfg, e0)
875
- if cfg.stage == 2:
876
- return _task_reward_stage2(q, lrm, iou, image_embeds_anchor_fp32, cfg, e0)
877
- if cfg.stage == 21:
878
- assert r_ref is not None and r_neg is not None, "stage 21 requires r_ref/r_neg"
879
- return _task_reward_stage2_tether(q, lrm, iou, r_ref, r_neg, cfg, e0)
880
- if cfg.stage == 22:
881
- assert r_ref is not None, "stage 22 requires r_ref"
882
- return _task_reward_stage2_faithful(q, lrm, iou, image_embeds_anchor_fp32, r_ref, cfg, e0)
883
- return _task_reward_stage3(q, lrm, iou, image_embeds_anchor_fp32, cfg, e0)
884
-
885
-
886
- def _compute_full_reward(
887
- q: torch.Tensor,
888
- lrm: torch.Tensor,
889
- iou: torch.Tensor,
890
- image_embeds_anchor_fp32: torch.Tensor,
891
- q_init: torch.Tensor,
892
- cfg: QLTPOConfig,
893
- e0: float = 1.0,
894
- r_ref: Optional[torch.Tensor] = None,
895
- r_neg: Optional[torch.Tensor] = None,
896
- ) -> torch.Tensor:
897
- """Full reward = task reward + L2 regularization (used for backward)."""
898
- r_task = _compute_task_reward(q, lrm, iou, image_embeds_anchor_fp32, cfg, e0, r_ref, r_neg)
899
- r_reg = (q - q_init).pow(2).sum()
900
- return r_task - cfg.lambda_reg * r_reg
901
-
902
-
903
- # ---------------------------------------------------------------------------
904
- # Stage 0: gradient connectivity check
905
- # ---------------------------------------------------------------------------
906
-
907
- def check_grad_connectivity(
908
- F_init: torch.Tensor, # [1, 256] any dtype
909
- image_embeds: torch.Tensor, # [T, 256, 64, 64] any dtype
910
- anchor_indices: List[int],
911
- sam_model,
912
- model_dtype: torch.dtype,
913
- num_steps: int = 5,
914
- lr: float = 0.0,
915
- ) -> dict:
916
- """Stage 0: verify ∂R_iou_pred/∂q ≠ 0 and reward rises with Adam maximize.
917
-
918
- Runs num_steps of Adam on R = R_iou_pred only (the simplest differentiable
919
- reward, no custom ops required). Returns a diagnostic dict.
920
-
921
- Usage:
922
- diag = check_grad_connectivity(F_init, image_embeds, anchors, sam, dtype)
923
- print(diag['grad_norm_step0'], diag['reward_trajectory'])
924
- # expect grad_norm > 0 and rewards non-decreasing
925
- """
926
- device = F_init.device
927
- image_embeds_anchor = image_embeds[anchor_indices].float().detach()
928
- dense_emb = _precompute_dense_emb(sam_model, model_dtype, device).float().detach()
929
- dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device).float().detach()
930
- mask_dec = sam_model.mask_decoder
931
-
932
- q_init_fp32 = F_init.float().detach()
933
- if lr <= 0:
934
- lr = 0.01 * (q_init_fp32.norm() / (q_init_fp32.numel() ** 0.5)).item()
935
-
936
- q = torch.nn.Parameter(q_init_fp32.clone())
937
- optimizer = torch.optim.Adam([q], lr=lr, maximize=True)
938
-
939
- grad_norms, rewards = [], []
940
- for step in range(num_steps):
941
- optimizer.zero_grad()
942
- lrm, iou = _decode_on_anchors_diff(q, image_embeds_anchor, dense_emb, mask_dec, dense_pe)
943
- R = iou.mean()
944
- R.backward()
945
- grad_norm = q.grad.norm().item() if q.grad is not None else 0.0
946
- grad_norms.append(grad_norm)
947
- rewards.append(R.item())
948
- optimizer.step()
949
-
950
- return {
951
- "grad_norm_step0": grad_norms[0],
952
- "grad_norms": grad_norms,
953
- "reward_trajectory": rewards,
954
- "gradient_connected": grad_norms[0] > 1e-8,
955
- }
956
-
957
-
958
- # ---------------------------------------------------------------------------
959
- # AVT proxy reward (Step A0: reward–metric correlation study)
960
- # ---------------------------------------------------------------------------
961
-
962
- @torch.no_grad()
963
- def _compute_avt_proxy_reward(
964
- q_init_fp32: torch.Tensor, # [1, 256] — frozen AVT anchor (= Fseg)
965
- lrm: torch.Tensor, # [A, 1, 256, 256] float32
966
- image_embeds_anchor_fp32: torch.Tensor, # [A, 256, 64, 64] float32
967
- cfg: "QLTPOConfig",
968
- beta: float = 0.5,
969
- ) -> Tuple[float, float]:
970
- """Task-specific proxy reward using frozen q_init (Fseg) as teacher.
971
-
972
- q_init = Fseg is already the audio+video+text fusion token produced by SimToken.
973
- Using it as a frozen reference breaks Stage 2's self-confirming bias while
974
- measuring whether the mask region aligns with the correct referent.
975
-
976
- Returns:
977
- R_avt = mean_t cos(z_in_t, q_init) [scalar]
978
- R_avt_c = mean_t [cos(z_in_t, q_init) - beta·cos(z_out_t, q_init)] [scalar]
979
- """
980
- A = lrm.shape[0]
981
- q_norm = F.normalize(q_init_fp32[0], dim=0) # [256]
982
-
983
- masks_64 = F.interpolate(
984
- torch.sigmoid(lrm.squeeze(1) / cfg.area_temp).unsqueeze(1),
985
- size=(64, 64), mode="bilinear", align_corners=False,
986
- ).squeeze(1) # [A, 64, 64]
987
-
988
- r_avt, r_avt_c = 0.0, 0.0
989
- for t in range(A):
990
- m = masks_64[t]
991
- img = image_embeds_anchor_fp32[t]
992
- z_in = F.normalize(
993
- (img * m.unsqueeze(0)).sum(dim=[1, 2]) / (m.sum() + 1e-6), dim=0
994
- )
995
- z_out = F.normalize(
996
- (img * (1.0 - m).unsqueeze(0)).sum(dim=[1, 2]) / ((1.0 - m).sum() + 1e-6), dim=0
997
- )
998
- c_in = (q_norm @ z_in).item()
999
- c_out = (q_norm @ z_out).item()
1000
- r_avt += c_in
1001
- r_avt_c += c_in - beta * c_out
1002
- return r_avt / A, r_avt_c / A
1003
-
1004
-
1005
- # ---------------------------------------------------------------------------
1006
- # Stage 1–3: q-LTPO-autograd main optimizer
1007
- # ---------------------------------------------------------------------------
1008
-
1009
- def q_ltpo_autograd(
1010
- F_init: torch.Tensor, # [1, 256] any dtype on CUDA
1011
- image_embeds: torch.Tensor, # [T, 256, 64, 64] any dtype on CUDA
1012
- anchor_indices: List[int],
1013
- sam_model,
1014
- model_dtype: torch.dtype,
1015
- cfg: QLTPOConfig,
1016
- ) -> torch.Tensor:
1017
- """Optimise the SAM prompt token q at test time via Adam maximize.
1018
-
1019
- q is initialised to F_init (= Fseg after text_hidden_fcs projection).
1020
- The prompt encoder is bypassed: sparse_emb = q.unsqueeze(1), identical
1021
- to what prompt_encoder produces when text_embeds is the only prompt.
1022
-
1023
- All computation is done in float32 to avoid fp16 gradient truncation.
1024
- Returns best_q as float32 [1, 256]. Falls back to F_init when gating
1025
- rejects all updates.
1026
- """
1027
- device = F_init.device
1028
-
1029
- # ── Precompute constants (float32, detached) ──────────────────────────
1030
- q_init_fp32 = F_init.float().detach()
1031
- image_embeds_anchor = image_embeds[anchor_indices].float().detach()
1032
- dense_emb = _precompute_dense_emb(sam_model, model_dtype, device).float().detach()
1033
- dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device).float().detach()
1034
- mask_dec = sam_model.mask_decoder
1035
-
1036
- # ── Auto-scale lr and max_drift from q_init magnitude ─────────────────
1037
- rms = q_init_fp32.norm() / (q_init_fp32.numel() ** 0.5)
1038
- lr = cfg.lr if cfg.lr > 0 else 0.01 * rms.item()
1039
- max_drift = cfg.max_drift if cfg.max_drift > 0 else 0.5 * q_init_fp32.norm().item()
1040
-
1041
- # ── Precompute frozen external reference (stages 21, 22 only) ────────
1042
- r_ref, r_neg = None, None
1043
- if cfg.stage in (21, 22):
1044
- r_ref, r_neg = _compute_r_ref(q_init_fp32, image_embeds_anchor, cfg.r_ref_temp)
1045
-
1046
- # ── Baseline forward + e0 existence prior ────────────────────────────
1047
- with torch.no_grad():
1048
- lrm0, iou0 = _decode_on_anchors_diff(
1049
- q_init_fp32, image_embeds_anchor, dense_emb, mask_dec, dense_pe
1050
- )
1051
- # e0 = stopgrad(R_area_soft(q_init)): fixes the scalar before the loop.
1052
- # Suppresses R_iou when the initial mask is near-empty (existence prior).
1053
- r_area_soft_init = torch.sigmoid(lrm0 / cfg.area_temp).mean().item()
1054
- e0 = _compute_e0(r_area_soft_init, cfg)
1055
-
1056
- R_init_task = _compute_task_reward(
1057
- q_init_fp32, lrm0, iou0, image_embeds_anchor, cfg, e0=e0,
1058
- r_ref=r_ref, r_neg=r_neg,
1059
- ).item()
1060
-
1061
- # ── Optimisation setup ────────────────────────────────────────────────
1062
- q = torch.nn.Parameter(q_init_fp32.clone())
1063
- optimizer = torch.optim.Adam([q], lr=lr, maximize=True)
1064
-
1065
- best_q = q.detach().clone()
1066
- best_reward = R_init_task
1067
- hit_clip = False
1068
-
1069
- # ── Optimisation loop ─────────────────────────────────────────────────
1070
- # Track per-step soft area to diagnose whether B1 penalty ever activates.
1071
- _step_soft_areas: List[float] = []
1072
-
1073
- for step in range(cfg.T):
1074
- optimizer.zero_grad()
1075
-
1076
- lrm, iou = _decode_on_anchors_diff(
1077
- q, image_embeds_anchor, dense_emb, mask_dec, dense_pe
1078
- )
1079
- R_full = _compute_full_reward(q, lrm, iou, image_embeds_anchor, q_init_fp32, cfg, e0=e0,
1080
- r_ref=r_ref, r_neg=r_neg)
1081
- R_full.backward()
1082
- optimizer.step()
1083
-
1084
- # Hard L2 norm clip: keep q within max_drift ball around q_init
1085
- with torch.no_grad():
1086
- diff = q - q_init_fp32
1087
- d = diff.norm()
1088
- if d > max_drift:
1089
- q.copy_(q_init_fp32 + diff * (max_drift / d))
1090
- hit_clip = True
1091
-
1092
- # Fresh no_grad forward on the post-step q_{N+1} for correct tracking.
1093
- # (Pre-step lrm/iou would mismatch the updated q, causing wrong best_q.)
1094
- with torch.no_grad():
1095
- lrm_eval, iou_eval = _decode_on_anchors_diff(
1096
- q.detach(), image_embeds_anchor, dense_emb, mask_dec, dense_pe
1097
- )
1098
- # Record soft area at this step for B1 activation diagnosis
1099
- _step_soft_areas.append(
1100
- torch.sigmoid(lrm_eval / cfg.area_temp).mean().item()
1101
- )
1102
- r_task = _compute_task_reward(
1103
- q.detach(), lrm_eval, iou_eval, image_embeds_anchor, cfg, e0=e0,
1104
- r_ref=r_ref, r_neg=r_neg,
1105
- ).item()
1106
- if r_task > best_reward:
1107
- best_reward = r_task
1108
- best_q = q.detach().clone()
1109
-
1110
- # Peak excess: how much did soft area exceed e0 at its highest point?
1111
- # b1_peak_excess > 0 ↔ B1 ReLU was non-zero at that step.
1112
- # b1_peak_excess = 0 ↔ B1 never activated (area stayed below e0 throughout).
1113
- _max_step_area = max(_step_soft_areas) if _step_soft_areas else r_area_soft_init
1114
- b1_peak_excess = max(_max_step_area - e0, 0.0)
1115
-
1116
- # ── Reward gating: clean re-eval of best_q vs q_init ─────────────────
1117
- with torch.no_grad():
1118
- lrm_b, iou_b = _decode_on_anchors_diff(
1119
- best_q, image_embeds_anchor, dense_emb, mask_dec, dense_pe
1120
- )
1121
- R_best_task = _compute_task_reward(
1122
- best_q, lrm_b, iou_b, image_embeds_anchor, cfg, e0=e0,
1123
- r_ref=r_ref, r_neg=r_neg,
1124
- ).item()
1125
-
1126
- area_init = (lrm0 > 0).float().mean().item()
1127
- effective_gate = (
1128
- cfg.null_gate_delta
1129
- if (cfg.null_gate_delta > 0 and area_init < cfg.null_area_threshold)
1130
- else cfg.gate_delta
1131
- )
1132
- accepted = R_best_task > R_init_task + effective_gate
1133
-
1134
- # ── Mask soft-IoU: how much did the mask actually change? ─────────────
1135
- # Answers whether q-drift translated into mask change, or fell in a
1136
- # flat direction of the mask decoder manifold.
1137
- with torch.no_grad():
1138
- m0 = torch.sigmoid(lrm0 / cfg.area_temp).squeeze(1) # [A,256,256]
1139
- mb = torch.sigmoid(lrm_b / cfg.area_temp).squeeze(1) # [A,256,256]
1140
- inter = (m0 * mb).sum(dim=[1, 2])
1141
- union = (m0 + mb - m0 * mb).sum(dim=[1, 2])
1142
- mask_soft_iou = (inter / (union + 1e-6)).mean().item()
1143
-
1144
- # Soft area at best_q — tracks whether B1 asymmetric penalty worked
1145
- r_area_soft_best = mb.mean().item() # sigmoid(lrm_b/area_temp).mean()
1146
-
1147
- # Reward decomposition: iou contribution to reward gain
1148
- R_iou_contrib_gain = (
1149
- cfg.lambda_iou * e0 * (iou_b.mean().item() - iou0.mean().item())
1150
- )
1151
-
1152
- # AVT proxy reward (Step A0 correlation study)
1153
- r_avt_init, r_avt_c_init = _compute_avt_proxy_reward(
1154
- q_init_fp32, lrm0, image_embeds_anchor, cfg
1155
- )
1156
- r_avt_best, r_avt_c_best = _compute_avt_proxy_reward(
1157
- q_init_fp32, lrm_b, image_embeds_anchor, cfg
1158
- )
1159
-
1160
- # ── Per-sample diagnostics ────────────────────────────────────────────
1161
- _q_ltpo_stats.append({
1162
- "accepted": accepted,
1163
- "reward_gain": R_best_task - R_init_task,
1164
- "drift": (best_q - q_init_fp32).norm().item(),
1165
- "hit_clip": hit_clip,
1166
- "e0": e0,
1167
- "R_iou_pred_init": iou0.mean().item(),
1168
- "R_iou_pred_best": iou_b.mean().item(),
1169
- "area_hard_init": area_init,
1170
- "area_hard_best": (lrm_b > 0).float().mean().item(),
1171
- "r_area_soft_init": r_area_soft_init,
1172
- "r_area_soft_best": r_area_soft_best,
1173
- "b1_peak_excess": b1_peak_excess,
1174
- "mask_soft_iou": mask_soft_iou,
1175
- "R_iou_contrib_gain": R_iou_contrib_gain,
1176
- # AVT proxy: frozen q_init as teacher — task-specific alignment
1177
- "r_avt_init": r_avt_init,
1178
- "r_avt_best": r_avt_best,
1179
- "r_avt_gain": r_avt_best - r_avt_init,
1180
- "r_avt_c_init": r_avt_c_init,
1181
- "r_avt_c_best": r_avt_c_best,
1182
- "r_avt_c_gain": r_avt_c_best - r_avt_c_init,
1183
- })
1184
-
1185
- if not accepted:
1186
- return F_init.float()
1187
- return best_q
1188
-
1189
-
1190
- # ===========================================================================
1191
- # Direction II: Frame-adaptive token optimization (stage=4)
1192
- # q_t = q_global + delta_t — shared global token + per-anchor residual
1193
- # ===========================================================================
1194
-
1195
- def q_ltpo_frame_adaptive(
1196
- F_init: torch.Tensor, # [1, 256] any dtype on CUDA
1197
- image_embeds: torch.Tensor, # [T, 256, 64, 64] any dtype on CUDA
1198
- anchor_indices: List[int],
1199
- sam_model,
1200
- model_dtype: torch.dtype,
1201
- cfg: QLTPOConfig,
1202
- ) -> Tuple[torch.Tensor, torch.Tensor]:
1203
- """Frame-adaptive q-LTPO: optimize q_global and per-anchor delta jointly.
1204
-
1205
- Each anchor frame t gets its own token q_t = q_global + delta_t.
1206
- delta_t is initialized to zero so q_t starts equal to q_init for all frames.
1207
- Per-frame existence priors e0_t suppress optimization on near-empty anchors.
1208
-
1209
- Returns:
1210
- q_global [1, 256] float32 — shared global token
1211
- delta [A, 256] float32 — per-anchor residuals (zero if not accepted)
1212
- """
1213
- device = F_init.device
1214
- A = len(anchor_indices)
1215
-
1216
- q_init_fp32 = F_init.float().detach()
1217
- image_embeds_anchor = image_embeds[anchor_indices].float().detach()
1218
- dense_emb = _precompute_dense_emb(sam_model, model_dtype, device).float().detach()
1219
- dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device).float().detach()
1220
- mask_dec = sam_model.mask_decoder
1221
-
1222
- rms = q_init_fp32.norm() / (q_init_fp32.numel() ** 0.5)
1223
- lr = cfg.lr if cfg.lr > 0 else 0.01 * rms.item()
1224
- max_drift = cfg.max_drift if cfg.max_drift > 0 else 0.5 * q_init_fp32.norm().item()
1225
- max_delta_drift = cfg.max_delta_drift_scale * q_init_fp32.norm().item()
1226
-
1227
- # ── Baseline: per-anchor e0 existence priors ────────────────────────────
1228
- with torch.no_grad():
1229
- lrm0, iou0 = _decode_on_anchors_diff(
1230
- q_init_fp32, image_embeds_anchor, dense_emb, mask_dec, dense_pe
1231
- )
1232
- e0_vec: List[float] = []
1233
- for t in range(A):
1234
- e0_t = torch.sigmoid(lrm0[t] / cfg.area_temp).mean().item()
1235
- e0_vec.append(_compute_e0(e0_t, cfg))
1236
- e0_global = sum(e0_vec) / A
1237
-
1238
- R_init_task = _task_reward_frame_adaptive(lrm0, iou0, cfg, e0_vec).item()
1239
-
1240
- # ── Setup optimization ───────────────────────────────────────────────────
1241
- q_global = torch.nn.Parameter(q_init_fp32.clone())
1242
- delta = torch.nn.Parameter(torch.zeros(A, 256, device=device, dtype=torch.float32))
1243
- optimizer = torch.optim.Adam([q_global, delta], lr=lr, maximize=True)
1244
-
1245
- best_q_global = q_global.detach().clone()
1246
- best_delta = delta.detach().clone()
1247
- best_reward = R_init_task
1248
- hit_clip = False
1249
-
1250
- # ── Optimization loop ────────────────────────────────────────────────────
1251
- for step in range(cfg.T):
1252
- optimizer.zero_grad()
1253
- lrm, iou = _decode_on_anchors_diff_adaptive(
1254
- q_global, delta, image_embeds_anchor, dense_emb, mask_dec, dense_pe
1255
- )
1256
- R_full = _compute_full_reward_adaptive(
1257
- q_global, delta, lrm, iou, q_init_fp32, cfg, e0_vec
1258
- )
1259
- R_full.backward()
1260
- optimizer.step()
1261
-
1262
- # Clip q_global and each per-anchor delta within trust regions
1263
- with torch.no_grad():
1264
- diff = q_global - q_init_fp32
1265
- d = diff.norm()
1266
- if d > max_drift:
1267
- q_global.copy_(q_init_fp32 + diff * (max_drift / d))
1268
- hit_clip = True
1269
- for t in range(A):
1270
- dn = delta[t].norm()
1271
- if dn > max_delta_drift:
1272
- delta[t].copy_(delta[t] * (max_delta_drift / dn))
1273
-
1274
- # Track best (no_grad re-eval of task reward without reg)
1275
- with torch.no_grad():
1276
- lrm_eval, iou_eval = _decode_on_anchors_diff_adaptive(
1277
- q_global.detach(), delta.detach(),
1278
- image_embeds_anchor, dense_emb, mask_dec, dense_pe
1279
- )
1280
- r_task = _task_reward_frame_adaptive(lrm_eval, iou_eval, cfg, e0_vec).item()
1281
- if r_task > best_reward:
1282
- best_reward = r_task
1283
- best_q_global = q_global.detach().clone()
1284
- best_delta = delta.detach().clone()
1285
-
1286
- # ── Gating ───────────────────────────────────────────────────────────────
1287
- with torch.no_grad():
1288
- lrm_b, iou_b = _decode_on_anchors_diff_adaptive(
1289
- best_q_global, best_delta, image_embeds_anchor, dense_emb, mask_dec, dense_pe
1290
- )
1291
- R_best_task = _task_reward_frame_adaptive(lrm_b, iou_b, cfg, e0_vec).item()
1292
-
1293
- accepted = R_best_task > R_init_task + cfg.gate_delta
1294
-
1295
- area_init = (lrm0 > 0).float().mean().item()
1296
- r_area_soft_init = sum(torch.sigmoid(lrm0[t] / cfg.area_temp).mean().item() for t in range(A)) / A
1297
- r_area_soft_best = sum(torch.sigmoid(lrm_b[t] / cfg.area_temp).mean().item() for t in range(A)) / A
1298
-
1299
- # Actual mask soft-IoU between init and best (per anchor, averaged)
1300
- m0 = torch.sigmoid(lrm0 / cfg.area_temp).squeeze(1) # [A,256,256]
1301
- mb = torch.sigmoid(lrm_b / cfg.area_temp).squeeze(1) # [A,256,256]
1302
- inter = (m0 * mb).sum(dim=[1, 2])
1303
- union = (m0 + mb - m0 * mb).sum(dim=[1, 2])
1304
- mask_soft_iou_fa = (inter / (union + 1e-6)).mean().item()
1305
-
1306
- _q_ltpo_stats.append({
1307
- "accepted": accepted,
1308
- "reward_gain": R_best_task - R_init_task,
1309
- "drift": (best_q_global - q_init_fp32).norm().item(),
1310
- "delta_norm": best_delta.norm().item(),
1311
- "hit_clip": hit_clip,
1312
- "e0": e0_global,
1313
- "R_iou_pred_init": iou0.mean().item(),
1314
- "R_iou_pred_best": iou_b.mean().item(),
1315
- "area_hard_init": area_init,
1316
- "area_hard_best": (lrm_b > 0).float().mean().item(),
1317
- "r_area_soft_init": r_area_soft_init,
1318
- "r_area_soft_best": r_area_soft_best,
1319
- "b1_peak_excess": 0.0,
1320
- "mask_soft_iou": mask_soft_iou_fa,
1321
- "R_iou_contrib_gain": cfg.lambda_iou * e0_global * (iou_b.mean().item() - iou0.mean().item()),
1322
- })
1323
-
1324
- if not accepted:
1325
- return q_init_fp32, torch.zeros(A, 256, device=device, dtype=torch.float32)
1326
- return best_q_global, best_delta
1327
-
1328
-
1329
- def decode_full_video_adaptive(
1330
- q_global: torch.Tensor, # [1, 256] float32
1331
- delta: torch.Tensor, # [A, 256] float32
1332
- anchor_indices: List[int],
1333
- image_embeds: torch.Tensor, # [T, 256, 64, 64] model dtype on CUDA
1334
- sam_model,
1335
- resize: tuple,
1336
- orgsize: tuple,
1337
- model_dtype: torch.dtype,
1338
- ) -> torch.Tensor:
1339
- """Decode all T frames with frame-adaptive tokens.
1340
-
1341
- Each frame is assigned to its nearest anchor by index distance, then decoded
1342
- with q_t = q_global + delta[anchor_idx].
1343
- Returns raw logit masks [T, H_orig, W_orig].
1344
- """
1345
- T = image_embeds.shape[0]
1346
- A = len(anchor_indices)
1347
- device = image_embeds.device
1348
-
1349
- dense_emb = _precompute_dense_emb(sam_model, model_dtype, device)
1350
- dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device)
1351
-
1352
- # Nearest-anchor assignment for every frame
1353
- anchor_arr = torch.tensor(anchor_indices, dtype=torch.float32)
1354
- frame_to_anchor = [int((anchor_arr - t).abs().argmin().item()) for t in range(T)]
1355
-
1356
- pred_masks: List[torch.Tensor] = []
1357
- with torch.no_grad():
1358
- for t in range(T):
1359
- a = frame_to_anchor[t]
1360
- q_t = (q_global + delta[a : a + 1]).to(model_dtype) # [1, 256]
1361
- sparse_emb = q_t.unsqueeze(1) # [1, 1, 256]
1362
- lrm_t, _ = sam_model.mask_decoder(
1363
- image_embeddings=image_embeds[t : t + 1],
1364
- image_pe=dense_pe,
1365
- sparse_prompt_embeddings=sparse_emb,
1366
- dense_prompt_embeddings=dense_emb,
1367
- multimask_output=False,
1368
- ) # [1, 1, 256, 256]
1369
- pred_t = sam_model.postprocess_masks(lrm_t, input_size=resize, original_size=orgsize)
1370
- pred_masks.append(pred_t.squeeze(0).squeeze(0)) # [H, W]
1371
-
1372
- return torch.stack(pred_masks, dim=0) # [T, H_orig, W_orig]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
setup_simtoken.md DELETED
@@ -1,163 +0,0 @@
1
- # SimToken Setup
2
-
3
- 本文档用于在新机器上重建 SimToken 环境,并准备后续 A-min 实验。
4
-
5
- ---
6
-
7
- ## 1. Create Environment
8
-
9
- 先确认 GPU 和 CUDA driver 状态:
10
-
11
- ```bash
12
- nvidia-smi
13
- ```
14
-
15
- 创建 conda 环境:
16
-
17
- ```bash
18
- /opt/miniforge3/condabin/conda create -n simtoken python=3.10 -y
19
- /opt/miniforge3/condabin/conda activate simtoken
20
-
21
- python -m pip install --upgrade pip wheel "setuptools<81"
22
-
23
- pip install \
24
- torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 \
25
- --index-url https://download.pytorch.org/whl/cu121
26
-
27
- pip install \
28
- transformers==4.30.2 \
29
- peft==0.2.0 \
30
- accelerate==0.21.0 \
31
- sentencepiece \
32
- protobuf \
33
- safetensors \
34
- numpy==1.26.4 \
35
- pandas \
36
- matplotlib \
37
- opencv-python \
38
- pillow \
39
- tqdm \
40
- einops \
41
- timm \
42
- requests \
43
- towhee \
44
- huggingface_hub
45
- ```
46
-
47
- 快速验证:
48
-
49
- ```bash
50
- python - <<'PY'
51
- import torch
52
- print("torch:", torch.__version__)
53
- print("cuda available:", torch.cuda.is_available())
54
- print("device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu")
55
- PY
56
- ```
57
-
58
- ---
59
-
60
- ## 2. Download from HuggingFace
61
-
62
- 如果新机器不使用迁移工具,而是从 HuggingFace 重新初始化,先登录:
63
-
64
- ```bash
65
- huggingface-cli login
66
- ```
67
-
68
- 下载完整 repo:
69
-
70
- ```bash
71
- mkdir -p /workspace/SimToken
72
- cd /workspace/SimToken
73
-
74
- huggingface-cli download yfan07/SimToken \
75
- --repo-type model \
76
- --local-dir . \
77
- --local-dir-use-symlinks False
78
- ```
79
-
80
- 下载完成后解压数据:
81
-
82
- ```bash
83
- cd /workspace/SimToken/data
84
-
85
- tar -xf image_embed.tar
86
- tar -xzf gt_mask.tar.gz
87
- tar -xzf audio_embed.tar.gz
88
- tar -xf media.tar
89
- ```
90
-
91
- ---
92
-
93
- ## 3. Pre-download Model Weights
94
-
95
- `transformers==4.30.2` 与新版 `huggingface_hub` 可能存在网络/API 兼容问题。建议先用 CLI 将模型下载到本地缓存,实验时再加 `TRANSFORMERS_OFFLINE=1`。
96
-
97
- ```bash
98
- # Chat-UniVi-7B
99
- huggingface-cli download Chat-UniVi/Chat-UniVi-7B-v1.5
100
-
101
- # CLIP ViT-L
102
- huggingface-cli download openai/clip-vit-large-patch14
103
- ```
104
-
105
- 下载完成后做离线验证:
106
-
107
- ```bash
108
- cd /workspace/SimToken
109
-
110
- TRANSFORMERS_OFFLINE=1 /opt/miniforge3/condabin/conda run -n simtoken \
111
- python -m py_compile train.py load_model.py decoder_invariance_check.py
112
- ```
113
-
114
- ---
115
-
116
- ## 4. Upload to HuggingFace
117
-
118
- 实验结束后,如需重新上传到 HuggingFace,先将数据目录压缩为归档文件,减少文件数量:
119
-
120
- ```bash
121
- cd /workspace/SimToken/data
122
-
123
- tar -cf image_embed.tar image_embed/
124
- tar -czf gt_mask.tar.gz gt_mask/
125
- tar -czf audio_embed.tar.gz audio_embed/
126
- tar -cf media.tar media/
127
-
128
- ls -lh *.tar*
129
-
130
- # HuggingFace 单文件硬限制为 50GB;如果 image_embed.tar 超过 50GB,
131
- # 需要切成小于 50GB 的分片再上传。
132
- split -b 45G -d -a 2 image_embed.tar image_embed.tar.part-
133
-
134
- # 校验分片拼接后仍能读出完整 tar 文件列表。
135
- cat image_embed.tar.part-* | tar -tf - | grep -v '/$' | wc -l
136
-
137
- # 分片校验通过后再删除超大原始 tar,避免上传失败。
138
- rm -f image_embed.tar
139
-
140
- rm -rf image_embed/ gt_mask/ audio_embed/ media/
141
- ```
142
-
143
- 下载后如需恢复 `image_embed.tar`:
144
-
145
- ```bash
146
- cd /workspace/SimToken/data
147
- cat image_embed.tar.part-* > image_embed.tar
148
- tar -xf image_embed.tar
149
- ```
150
-
151
- 清理缓存并上传:
152
-
153
- ```bash
154
- cd /workspace/SimToken
155
-
156
- find . -name "__pycache__" -prune -exec rm -rf {} +
157
- find . -name ".pytest_cache" -prune -exec rm -rf {} +
158
- find . -name ".cache" -prune -exec rm -rf {} +
159
- find . -name "*.pyc" -delete
160
-
161
- huggingface-cli login
162
- python upload_hf.py --repo yfan07/SimToken
163
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
simtoken_experiment.md DELETED
@@ -1,369 +0,0 @@
1
- # SimToken 实验路线文档
2
-
3
- ## 0. 当前状态
4
-
5
- 前置诊断已经完成,路线收敛到 **A-min dynamic referent gate training**。
6
-
7
- 已确认结论:
8
-
9
- 1. **SAM decoder 下游是逐帧 batch-parallel 解码**
10
- `mask_decoder(image_embeddings[0:T])[t]` 与 `mask_decoder(image_embeddings[t:t+1])[0]` 只有混合精度数值噪声差异。旧的 decoder-level joint-frame competition 假设关闭。
11
-
12
- 2. **target_frame sweep 基本无效**
13
- 不同 target frame 生成的 q 几乎相同,`cos_to_q5` 通常在 `0.997+`;Seen/Null 上 oracle gain 约 `+0.0009`。这条 TTO 路线关闭。
14
-
15
- 3. **raw SAM-space D2 失效**
16
- 256 维 `q/Fseg` 与 SAM image embedding 不在可直接 cosine 的语义空间,`real q ≈ shuffled/wrong_ref q`,甚至 random q 更高。该定义关闭。
17
-
18
- 4. **LLM-space D2 有弱诊断信号,但不适合作为主 reward**
19
- 用 4096 维 `[SEG]` hidden state 与 `mm_projector(CLIP patch tokens)` 后的视觉 token 计算 D2,可以得到正相关:
20
- - `corr(s_pred, frame_iou) ≈ +0.316`
21
- - bottom 20% `s_pred` 中 failure rate 相比随机 baseline 约 `1.60x`
22
- - 控制 `iou_pred` / `pred_area` 后偏相关约 `+0.14`
23
-
24
- 结论:`s_pred(beta=1.0)` 可以作为诊断信号或 frame-aware gate 的候选输入,但不能作为核心 TTO reward。
25
-
26
- 5. **margin-D2 无效**
27
- 离线 `s_margin = s(real) - max(s(shuffled), s(wrong_ref))` 的 failure enrichment 约 `0.93x`,会抵消掉有用的通用可见性/质量信号。该路线关闭。
28
-
29
- 当前最干净的解释是:
30
-
31
- > q 本身通常是稳定的 referent anchor;主要瓶颈不在 q 生成,也不在简单 q selection,而在 SAM decoder 如何使用已有的 `mask_token -> q` sparse self-attention path。
32
-
33
- 2026-04-22 更新:
34
-
35
- 完整训练每个 epoch 约 2-4 小时,瓶颈主要在 7B MLLM forward,而不在 gate 本身。因此当前实验策略已调整为:
36
-
37
- 1. 先缓存固定 checkpoint 下的 `q = seg_embeddings`;
38
- 2. 在 cached q + cached SAM image embeddings 上训练 gate-only;
39
- 3. 用 cached eval split 快速判断 gate 是否有泛化收益;
40
- 4. 只有 gate-only 泛化信号成立后,再跑完整 A-min 联合训练。
41
-
42
- ---
43
-
44
- ## 1. A-min 当前实现
45
-
46
- 已在代码中加入 A-min dynamic referent gate:
47
-
48
- - 文件:`models/segment_anything/modeling/transformer.py`
49
- - 模块:`ReferentGate`
50
- - 插入位置:`TwoWayAttentionBlock` 的 sparse self-attention + `norm1` 之后,token-to-image cross-attention 之前
51
- - 作用对象:只作用于 `mask_tokens`
52
- - 不作用于:`iou_token` 和 `q/sparse_prompt` 本身
53
-
54
- SAM token index:
55
-
56
- ```python
57
- tokens = [iou_token, mask_tokens..., sparse_prompt(q)]
58
- ```
59
-
60
- 因此:
61
-
62
- ```python
63
- iou_token index: 0
64
- mask token range: 1 : 1 + num_mask_tokens
65
- q token index: 1 + num_mask_tokens
66
- ```
67
-
68
- A-min gate 形式:
69
-
70
- ```python
71
- alpha = sigmoid(Linear([mask_token, q, cos(mask_token, q)]))
72
- mask_token = mask_token + alpha * Linear(q)
73
- ```
74
-
75
- 为保证旧 checkpoint 初始行为不变,`proj(q)` 分支使用零初始化。当前也将 `gate` 分支零初始化,使 alpha 有干净观测基线:
76
-
77
- ```python
78
- nn.init.zeros_(self.gate.weight)
79
- nn.init.zeros_(self.gate.bias)
80
- nn.init.zeros_(self.proj.weight)
81
- nn.init.zeros_(self.proj.bias)
82
- ```
83
-
84
- 初始时 gate 为 identity:
85
-
86
- ```text
87
- max_abs_diff(gate(mask, q), mask) = 0.0
88
- alpha_mean = 0.5
89
- alpha_std = 0.0
90
- ```
91
-
92
- 当前训练 forward 保持完整链路:`prepare_inputs_labels_for_multimodal -> MLLM forward -> text_hidden_fcs -> SAM mask decoder -> loss`。`--gate_only` 只控制参数冻结范围,不再改变 forward 语义。
93
-
94
- ---
95
-
96
- ## 2. 当前新增工具
97
-
98
- ### 2.1 训练脚本增强
99
-
100
- `train.py` 已加入:
101
-
102
- - `--max_steps`
103
- - `--overfit_samples`
104
- - `--log_gate_stats_every`
105
- - `--skip_eval_after_train`
106
- - `--eval_train_only`
107
-
108
- 启动时会打印 referent gate 参数是否 trainable、是否进入 optimizer,以及初始 `proj_norm/gate_norm`。
109
-
110
- ### 2.2 cached q 路线
111
-
112
- 新增脚本:
113
-
114
- - `cache_q_features.py`
115
- - 离线缓存 `q = seg_embeddings`
116
- - cache 文件很小,因为只保存 q 和少量 metadata
117
- - `image_embeddings` 仍使用已有 `data/image_embed/{vid}.pt`
118
- - `gt_masks` 仍使用已有 `data/gt_mask/...`
119
-
120
- - `train_cached_gate.py`
121
- - 加载 base model 和 cached q
122
- - 冻结全部参数,只训练 `referent_gate`
123
- - 支持 `--eval_only`、`--disable_gate`
124
- - 支持 `--save_gate_only`,只保存 gate 参数,checkpoint 约 1.6MB
125
- - 支持 `--gate_checkpoint`,在 base checkpoint 上 overlay gate-only checkpoint
126
- - gate stats 会记录:
127
-
128
- ```text
129
- batch_miou
130
- batch_fscore
131
- proj_norm
132
- gate_norm
133
- proj_grad_norm
134
- gate_grad_norm
135
- alpha_mean / alpha_std / alpha_min / alpha_max
136
- ```
137
-
138
- cached 解码已优化:一个 dataloader batch 会展平成 paired frame batch 调用 `mask_decoder.forward_modified_v3`,避免逐 sample 调 decoder 的主要开销,同时不会产生 prompt/image cross product。
139
-
140
- ---
141
-
142
- ## 3. 已完成实验结果
143
-
144
- ### 3.1 cached identity 与原始 pipeline 一致性
145
-
146
- 先用 `test_s` 前 10 条验证 cached pipeline 是否与原始 `load_model.py` 对齐:
147
-
148
- ```text
149
- cached identity:
150
- mIoU = 0.9686462879
151
- Fscore = 0.9868578851
152
-
153
- original load_model.py:
154
- mIoU = 0.9686277151
155
- Fscore = 0.9868472159
156
-
157
- diff:
158
- mIoU = +0.0000186
159
- Fscore = +0.0000107
160
- ```
161
-
162
- 结论:差异远小于 0.001,cached q pipeline 与原始 eval pipeline 一致,可以用于 gate-only 快速验证。
163
-
164
- ### 3.2 gate probe:梯度路径与 alpha 分化
165
-
166
- 在 cached train128 上跑 50 optimizer steps:
167
-
168
- ```text
169
- step 5:
170
- proj_norm=0.074015
171
- gate_norm=0.064479
172
- proj_grad_norm=0.052291
173
- gate_grad_norm=0.000170
174
- alpha_mean=0.4999
175
- alpha_std=0.0019
176
-
177
- step 50:
178
- proj_norm=0.428711
179
- gate_norm=0.523223
180
- proj_grad_norm=0.022453
181
- gate_grad_norm=0.000504
182
- alpha_mean=0.5063
183
- alpha_std=0.0112
184
- ```
185
-
186
- 结论:
187
-
188
- - `proj_norm` 从 0 稳定增长,注入分支有梯度;
189
- - `gate_norm` 也开始增长,alpha 控制分支参与学习;
190
- - `alpha_std` 从 0 增长,说明 gate 对不同输入有分化响应;
191
- - 计算图、冻结范围、optimizer param groups 均正常。
192
-
193
- ### 3.3 overfit32:表达能力验证
194
-
195
- cached train32 identity baseline:
196
-
197
- ```text
198
- mIoU = 0.8814558
199
- Fscore = 0.9375512
200
- ```
201
-
202
- cached gate overfit32,200 steps,lr=1e-4:
203
-
204
- ```text
205
- mIoU = 0.9085821
206
- Fscore = 0.9444574
207
- ```
208
-
209
- 提升:
210
-
211
- ```text
212
- mIoU = +0.0271263
213
- Fscore = +0.0069063
214
- ```
215
-
216
- 结论:在 q、SAM image embeddings、mask decoder 原始参数均固定时,仅训练 A-min gate 就能明显提高训练集 mIoU,说明 gate 机制有表达能力,梯度路径通畅。
217
-
218
- ### 3.4 overfit32 泛化评估
219
-
220
- 对 cached eval split 前 200 条,identity baseline:
221
-
222
- ```text
223
- test_s mIoU = 0.7390979
224
- test_s Fscore = 0.8190672
225
-
226
- test_u mIoU = 0.6732285
227
- test_u Fscore = 0.7734924
228
-
229
- test_n metric = 0.0606105
230
- ```
231
-
232
- overfit32 gate checkpoint:
233
-
234
- ```text
235
- test_s mIoU = 0.7199481
236
- test_s Fscore = 0.8045849
237
-
238
- test_u mIoU = 0.6672303
239
- test_u Fscore = 0.7663978
240
-
241
- test_n metric = 0.0648588
242
- ```
243
-
244
- delta:
245
-
246
- ```text
247
- test_s mIoU = -0.0191498
248
- test_s Fscore = -0.0144823
249
-
250
- test_u mIoU = -0.0059983
251
- test_u Fscore = -0.0070946
252
-
253
- test_n metric = +0.0042483
254
- ```
255
-
256
- 结论:
257
-
258
- - overfit32 gate 没有泛化;
259
- - Null metric 略升,说明小样本过拟合有轻微放大前景的倾向;
260
- - 这不是方法失败,而是 32 个样本不足以学到泛化 referent anchoring 的预期结果;
261
- - 下一步应扩大 cached train 样本量,并降低 lr。
262
-
263
- ---
264
-
265
- ## 4. 当前下一步实验:cached train256 gate-only
266
-
267
- 用户已经完成 train256 的 q 缓存。下一步用 train256 跑更保守的 gate-only 泛化实验。
268
-
269
- ### Step 1:训练 cached gate-only train256
270
-
271
- ```bash
272
- cd /workspace/SimToken
273
- mkdir -p log checkpoints
274
-
275
- TRANSFORMERS_OFFLINE=1 python -u -W ignore train_cached_gate.py \
276
- --cache_split train \
277
- --cache_root /workspace/SimToken/cache_q \
278
- --name cached_gate_train256_s300_lr3e5 \
279
- --epochs 20 \
280
- --max_steps 300 \
281
- --batch_size 8 \
282
- --lr 3e-5 \
283
- --saved_model /workspace/SimToken/checkpoints/simtoken_pretrained.pth \
284
- --log_root /workspace/SimToken/log \
285
- --checkpoint_root /workspace/SimToken/checkpoints \
286
- --log_gate_stats_every 50 \
287
- --skip_eval_after_train \
288
- --save_gate_only \
289
- 2>&1 | tee /workspace/SimToken/log/cached_gate_train256_s300_lr3e5.stdout
290
- ```
291
-
292
- 训练中重点观察:
293
-
294
- ```text
295
- batch_miou / batch_fscore 是否逐步改善
296
- proj_norm 是否持续增长
297
- alpha_std 是否温和分化
298
- Null 风险:alpha 是否出现极端偏移
299
- ```
300
-
301
- 如果 `proj_norm` 在前 100 steps 仍接近 0,说明 lr=3e-5 可能过小,可以改回 1e-4 或使用分层 lr。
302
-
303
- ### Step 2:评估 cached train256 gate checkpoint
304
-
305
- ```bash
306
- for split in test_s test_u test_n; do
307
- TRANSFORMERS_OFFLINE=1 python -u -W ignore train_cached_gate.py \
308
- --cache_split $split \
309
- --cache_root /workspace/SimToken/cache_q \
310
- --batch_size 8 \
311
- --saved_model /workspace/SimToken/checkpoints/simtoken_pretrained.pth \
312
- --gate_checkpoint /workspace/SimToken/checkpoints/cached_gate_train256_s300_lr3e5.pth \
313
- --eval_only \
314
- --name cached_gate_train256_s300_lr3e5_${split}_200 \
315
- 2>&1 | tee /workspace/SimToken/log/cached_gate_train256_s300_lr3e5_${split}_200.stdout
316
- done
317
- ```
318
-
319
- 对比 baseline 使用 3.4 中 identity 200 条结果。
320
-
321
- ### Step 3:根据结果决策
322
-
323
- 判断标准:
324
-
325
- - Seen / Unseen 都提升:进入更大 cached train 或完整 A-min;
326
- - Seen 提升、Unseen 不提升:gate 仍可能学 dataset pattern,需要更多 train cache 或更强正则;
327
- - Seen / Unseen 都下降:不要跑完整 A-min,先调 lr、正则或 gate 容量;
328
- - Null metric 保持 `< 0.07`:暂不加 area penalty;
329
- - Null metric 超过 `0.10`:强危险信号,需要 area penalty 或约束预测面积。
330
-
331
- 如果 train256 有弱正收益但幅度小,先看 alpha 分布和 hard/easy frames,而不是立刻扩大。若 alpha 在所有帧上几乎一致,可能只是全局偏置;若 hard frames alpha 系统性更高,说明更像 referent anchoring。
332
-
333
- ---
334
-
335
- ## 5. 成功标准
336
-
337
- A-min 成功不能只看总体 mIoU,需要同时满足:
338
-
339
- 1. Seen / Unseen mIoU 稳定提升;
340
- 2. Unseen 至少不弱于 Seen 的提升趋势;
341
- 3. Null 指标不恶化,预测面积不膨胀;
342
- 4. hard frames 改善更明显;
343
- 5. 如果记录 gate alpha,hard frames 的 alpha 应系统性高于 easy frames。
344
-
345
- 失败解释:
346
-
347
- - 如果 Seen 提升、Unseen 不提升:可能是 gate 学到数据集模式,而不是 referent anchoring;
348
- - 如果 Null 恶化:gate 可能放大了通用前景显著性;
349
- - 如果 gate-only 无变化但完整 A-min 有收益:说明 gate 需要与 mask decoder / text projection 协同适配;
350
- - 如果全 split 下降:gate 插入位置、初始化或学习率需要重新检查。
351
-
352
- ---
353
-
354
- ## 6. 后续机制分析
355
-
356
- 如果 A-min 有正收益,再做 hook 分析:
357
-
358
- 1. sparse self-attention 中 `mask_token -> q`;
359
- 2. token-to-image attention 中 mask token 对 image tokens 的关注;
360
- 3. A-min 前后 hard/easy frames 的 gate alpha;
361
- 4. `s_pred(beta=1.0)` 与 gate alpha 的关系。
362
-
363
- 这部分用于论文解释,不作为当前阻塞项。
364
-
365
- ---
366
-
367
- ## 7. 当前一句话结论
368
-
369
- > A-min gate 的梯度路径、表达能力和 cached pipeline 一致性已经通过验证;overfit32 能显著提升训练集但不能泛化。当前主线是用更大 cached train set(已完成 train256 cache)验证 gate-only 泛化,再决定是否投入完整 A-min 联合训练。
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
target_frame_sweep.py DELETED
@@ -1,265 +0,0 @@
1
- import csv
2
- import os
3
- import random
4
- from functools import partial
5
-
6
- import numpy as np
7
- import torch
8
- import torch.nn.functional as F
9
- import transformers
10
- from torch.utils.data import DataLoader
11
-
12
- from configs import args
13
- from datasets import REFAVS
14
- from decoder_invariance_check import build_model, set_seed
15
- from load_model import collate_fn, dict_to_cuda
16
- from utils import utility
17
-
18
-
19
- def decode_with_q(model, batch, q):
20
- visual_model = model.get_model().visual_model
21
- image_embeddings = batch["image_feats"][0]
22
-
23
- sparse, dense = visual_model.prompt_encoder(
24
- points=None,
25
- boxes=None,
26
- masks=None,
27
- text_embeds=q.unsqueeze(1),
28
- )
29
- sparse = sparse.to(q.dtype)
30
- dense = dense.to(q.dtype)
31
-
32
- low_res_masks, iou_predictions = visual_model.mask_decoder(
33
- image_embeddings=image_embeddings,
34
- image_pe=visual_model.prompt_encoder.get_dense_pe(),
35
- sparse_prompt_embeddings=sparse,
36
- dense_prompt_embeddings=dense,
37
- multimask_output=False,
38
- )
39
- pred_masks = visual_model.postprocess_masks(
40
- low_res_masks,
41
- input_size=batch["resizes"][0],
42
- original_size=batch["orgsizes"][0],
43
- ).squeeze(1)
44
- return pred_masks.unsqueeze(0), iou_predictions.squeeze(-1)
45
-
46
-
47
- def get_q_for_target_frame(model, batch, target_frame):
48
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
49
- output = model.forward(
50
- images=batch["images"],
51
- images_clip=batch["images_clip"],
52
- audio_features=batch["audio_feats"],
53
- image_features=batch["image_feats"],
54
- input_ids=batch["input_ids"],
55
- labels=batch["labels"],
56
- attention_masks=batch["attention_masks"],
57
- masks_list=batch["masks"],
58
- resize_list=batch["resizes"],
59
- orgsize_list=batch["orgsizes"],
60
- conversation_list=batch["convs"],
61
- refs_num=batch["refs_num"],
62
- fids=batch["fids"],
63
- vids=batch["vids"],
64
- contrast=args.ct_weight,
65
- ref_ids=batch["ref_ids"],
66
- inference=True,
67
- target_frame=target_frame,
68
- )
69
- return output["seg_embeddings"][0][0:1]
70
-
71
-
72
- def mask_area(pred_masks):
73
- return (torch.sigmoid(pred_masks) > 0.4).float().mean().item()
74
-
75
-
76
- def mean_mask_iou_to_others(mask, other_masks):
77
- if not other_masks:
78
- return 1.0
79
- binary = (torch.sigmoid(mask) > 0.4).float()
80
- other_binary = [(torch.sigmoid(m) > 0.4).float() for m in other_masks]
81
- vals = []
82
- for other in other_binary:
83
- inter = (binary * other).sum()
84
- union = torch.maximum(binary, other).sum()
85
- vals.append((inter / (union + 1e-7)).item())
86
- return float(np.mean(vals))
87
-
88
-
89
- def evaluate_one_sample(model, batch, sample_idx):
90
- rows = []
91
- qs = []
92
- pred_masks_by_tf = []
93
-
94
- gt_masks = batch["masks"][0]
95
- vid = batch["vids"][0]
96
- ref = batch["refs"][0][0]
97
-
98
- for target_frame in range(args.frame_n):
99
- q = get_q_for_target_frame(model, batch, target_frame)
100
- qs.append(q.float().squeeze(0))
101
-
102
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
103
- pred_masks, iou_predictions = decode_with_q(model, batch, q)
104
- pred_masks_by_tf.append(pred_masks.detach())
105
-
106
- miou = utility.mask_iou(pred_masks.float(), gt_masks.float())
107
- fscore = utility.Eval_Fmeasure(pred_masks.float(), gt_masks.float(), None)
108
- null_metric = utility.metric_s_for_null(pred_masks.float())
109
- area = mask_area(pred_masks)
110
- mean_iou_pred = iou_predictions.float().mean().item()
111
-
112
- rows.append(
113
- {
114
- "sample_idx": sample_idx,
115
- "vid": vid,
116
- "ref": ref,
117
- "target_frame": target_frame,
118
- "mean_iou_pred": mean_iou_pred,
119
- "mask_area": area,
120
- "null_metric": float(null_metric),
121
- "miou": miou,
122
- "fscore": fscore,
123
- "cos_to_q5": 0.0,
124
- "mean_cos_to_other_q": 0.0,
125
- "mean_mask_iou_to_other_tf": 0.0,
126
- }
127
- )
128
-
129
- q_stack = F.normalize(torch.stack(qs, dim=0), dim=-1)
130
- q_cos = q_stack @ q_stack.T
131
- q5_idx = min(5, len(qs) - 1)
132
-
133
- for i, row in enumerate(rows):
134
- other = [j for j in range(len(rows)) if j != i]
135
- row["cos_to_q5"] = q_cos[i, q5_idx].item()
136
- row["mean_cos_to_other_q"] = q_cos[i, other].mean().item()
137
- row["mean_mask_iou_to_other_tf"] = mean_mask_iou_to_others(
138
- pred_masks_by_tf[i], [pred_masks_by_tf[j] for j in other]
139
- )
140
-
141
- return rows
142
-
143
-
144
- def print_sample_summary(rows):
145
- print(f"\nSample {rows[0]['sample_idx']}: vid={rows[0]['vid']} ref={rows[0]['ref']}")
146
- print("tf | miou | fscore | null_s | iou_pred | area | cos_to_q5 | mean_q_cos")
147
- for row in rows:
148
- print(
149
- f"{row['target_frame']:02d} | "
150
- f"{row['miou']:.4f} | "
151
- f"{row['fscore']:.4f} | "
152
- f"{row['null_metric']:.4f} | "
153
- f"{row['mean_iou_pred']:.4f} | "
154
- f"{row['mask_area']:.4f} | "
155
- f"{row['cos_to_q5']:.4f} | "
156
- f"{row['mean_cos_to_other_q']:.4f}"
157
- )
158
-
159
- best_miou = max(rows, key=lambda x: x["miou"])
160
- best_iou_pred = max(rows, key=lambda x: x["mean_iou_pred"])
161
- fixed = rows[min(5, len(rows) - 1)]
162
- miou_values = [row["miou"] for row in rows]
163
- q5_values = [row["cos_to_q5"] for row in rows]
164
- print(
165
- "Best miou tf="
166
- f"{best_miou['target_frame']} ({best_miou['miou']:.4f}); "
167
- "best iou_pred tf="
168
- f"{best_iou_pred['target_frame']} ({best_iou_pred['mean_iou_pred']:.4f}); "
169
- f"fixed tf=5 miou={fixed['miou']:.4f}"
170
- )
171
- print(
172
- f"target-frame miou range={max(miou_values) - min(miou_values):.4f}; "
173
- f"min cos_to_q5={min(q5_values):.4f}"
174
- )
175
-
176
-
177
- def main():
178
- set_seed(42)
179
- torch.set_grad_enabled(False)
180
-
181
- tokenizer = transformers.AutoTokenizer.from_pretrained(
182
- args.mllm,
183
- cache_dir=None,
184
- model_max_length=2048,
185
- padding_side="right",
186
- use_fast=False,
187
- )
188
- tokenizer.pad_token = tokenizer.unk_token
189
- tokenizer.add_tokens("[SEG]")
190
- seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
191
-
192
- dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
193
- loader = DataLoader(
194
- dataset,
195
- batch_size=1,
196
- shuffle=False,
197
- num_workers=0,
198
- collate_fn=partial(collate_fn, tokenizer=tokenizer),
199
- )
200
-
201
- limit = args.max_eval_rows if args.max_eval_rows > 0 else 1
202
- print(f"Split: {args.eval_split} | samples to sweep: {limit}")
203
-
204
- model = build_model(tokenizer, seg_token_idx)
205
-
206
- all_rows = []
207
- for sample_idx, batch in enumerate(loader):
208
- if sample_idx >= limit:
209
- break
210
- batch = dict_to_cuda(batch)
211
- rows = evaluate_one_sample(model, batch, sample_idx)
212
- all_rows.extend(rows)
213
- print_sample_summary(rows)
214
-
215
- if not all_rows:
216
- raise RuntimeError("No rows were checked. Is the selected split empty?")
217
-
218
- fixed_rows = [r for r in all_rows if r["target_frame"] == min(5, args.frame_n - 1)]
219
- oracle_by_sample = {}
220
- iou_pred_by_sample = {}
221
- for row in all_rows:
222
- key = row["sample_idx"]
223
- if key not in oracle_by_sample or row["miou"] > oracle_by_sample[key]["miou"]:
224
- oracle_by_sample[key] = row
225
- if key not in iou_pred_by_sample or row["mean_iou_pred"] > iou_pred_by_sample[key]["mean_iou_pred"]:
226
- iou_pred_by_sample[key] = row
227
-
228
- fixed_miou = np.mean([r["miou"] for r in fixed_rows])
229
- fixed_null_metric = np.mean([r["null_metric"] for r in fixed_rows])
230
- oracle_miou = np.mean([r["miou"] for r in oracle_by_sample.values()])
231
- iou_pred_selected_miou = np.mean([r["miou"] for r in iou_pred_by_sample.values()])
232
- min_cos_to_q5 = np.mean(
233
- [min(r["cos_to_q5"] for r in all_rows if r["sample_idx"] == sample_idx) for sample_idx in oracle_by_sample]
234
- )
235
- mean_miou_range = np.mean(
236
- [
237
- max(r["miou"] for r in all_rows if r["sample_idx"] == sample_idx)
238
- - min(r["miou"] for r in all_rows if r["sample_idx"] == sample_idx)
239
- for sample_idx in oracle_by_sample
240
- ]
241
- )
242
-
243
- print("\nSummary")
244
- print(f"samples: {len(fixed_rows)}")
245
- print(f"fixed target_frame=5 mean miou: {fixed_miou:.4f}")
246
- print(f"fixed target_frame=5 mean null_s: {fixed_null_metric:.4f}")
247
- print(f"oracle best-target-frame mean miou: {oracle_miou:.4f}")
248
- print(f"best-by-iou_pred selected mean miou: {iou_pred_selected_miou:.4f}")
249
- print(f"oracle gain over fixed: {oracle_miou - fixed_miou:+.4f}")
250
- print(f"iou_pred-selection gain over fixed: {iou_pred_selected_miou - fixed_miou:+.4f}")
251
- print(f"mean target-frame miou range: {mean_miou_range:.4f}")
252
- print(f"mean sample min cos_to_q5: {min_cos_to_q5:.4f}")
253
-
254
- csv_path = os.environ.get("TARGET_FRAME_SWEEP_CSV")
255
- if csv_path:
256
- os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
257
- with open(csv_path, "w", newline="") as f:
258
- writer = csv.DictWriter(f, fieldnames=list(all_rows[0].keys()))
259
- writer.writeheader()
260
- writer.writerows(all_rows)
261
- print(f"Saved CSV: {csv_path}")
262
-
263
-
264
- if __name__ == "__main__":
265
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train_cached_gate.py DELETED
@@ -1,439 +0,0 @@
1
- import json
2
- import os
3
- import random
4
-
5
- import cv2
6
- import numpy as np
7
- import torch
8
- import transformers
9
- from torch.optim import AdamW
10
- from torch.utils.data import DataLoader, Dataset, Subset
11
- from tqdm import tqdm
12
-
13
- from configs import args
14
- from decoder_invariance_check import build_model, set_seed
15
- from models.avs_model import dice_loss, sigmoid_ce_loss
16
- from utils import utility
17
-
18
-
19
- def _total_norm(values):
20
- if not values:
21
- return 0.0
22
- return float(sum(v * v for v in values) ** 0.5)
23
-
24
-
25
- def collect_referent_gate_stats(model):
26
- gate_modules = [(n, m) for n, m in model.named_modules() if n.endswith("referent_gate")]
27
- proj_norms = []
28
- gate_norms = []
29
- proj_grad_norms = []
30
- gate_grad_norms = []
31
- alpha_tensors = []
32
-
33
- for _, module in gate_modules:
34
- proj_norms.append(module.proj.weight.detach().float().norm().item())
35
- gate_norms.append(module.gate.weight.detach().float().norm().item())
36
- if module.proj.weight.grad is not None:
37
- proj_grad_norms.append(module.proj.weight.grad.detach().float().norm().item())
38
- if module.gate.weight.grad is not None:
39
- gate_grad_norms.append(module.gate.weight.grad.detach().float().norm().item())
40
- if module.last_alpha is not None:
41
- alpha_tensors.append(module.last_alpha.detach().float().reshape(-1))
42
-
43
- stats = {
44
- "modules": len(gate_modules),
45
- "proj_norm": _total_norm(proj_norms),
46
- "gate_norm": _total_norm(gate_norms),
47
- "proj_grad_norm": _total_norm(proj_grad_norms),
48
- "gate_grad_norm": _total_norm(gate_grad_norms),
49
- }
50
-
51
- if alpha_tensors:
52
- alpha = torch.cat(alpha_tensors)
53
- stats.update(
54
- {
55
- "alpha_mean": alpha.mean().item(),
56
- "alpha_std": alpha.std(unbiased=False).item(),
57
- "alpha_min": alpha.min().item(),
58
- "alpha_max": alpha.max().item(),
59
- }
60
- )
61
- else:
62
- stats.update(
63
- {
64
- "alpha_mean": float("nan"),
65
- "alpha_std": float("nan"),
66
- "alpha_min": float("nan"),
67
- "alpha_max": float("nan"),
68
- }
69
- )
70
-
71
- return stats
72
-
73
-
74
- def zero_referent_gate(model):
75
- with torch.no_grad():
76
- for _, module in model.named_modules():
77
- if not _.endswith("referent_gate"):
78
- continue
79
- module.gate.weight.zero_()
80
- module.gate.bias.zero_()
81
- module.proj.weight.zero_()
82
- module.proj.bias.zero_()
83
- module.last_alpha = None
84
-
85
-
86
- def referent_gate_state_dict(model):
87
- return {
88
- name: param.detach().cpu()
89
- for name, param in model.state_dict().items()
90
- if "referent_gate" in name
91
- }
92
-
93
-
94
- def load_referent_gate_checkpoint(model, path):
95
- checkpoint = torch.load(path, map_location="cpu")
96
- if isinstance(checkpoint, dict) and checkpoint.get("type") == "referent_gate_only":
97
- checkpoint = checkpoint["state_dict"]
98
- gate_state = {k: v for k, v in checkpoint.items() if "referent_gate" in k}
99
- if not gate_state:
100
- raise RuntimeError(f"No referent_gate parameters found in {path}")
101
- current = model.state_dict()
102
- missing_shape = [
103
- k
104
- for k, v in gate_state.items()
105
- if k not in current or tuple(current[k].shape) != tuple(v.shape)
106
- ]
107
- if missing_shape:
108
- raise RuntimeError(f"Gate checkpoint has incompatible keys: {missing_shape[:5]}")
109
- current.update(gate_state)
110
- model.load_state_dict(current, strict=True)
111
- print(f"loaded referent gate checkpoint: {path} ({len(gate_state)} tensors)")
112
-
113
-
114
- def log_gate_stats(model, step, loss_value, batch_metrics=None):
115
- stats = collect_referent_gate_stats(model)
116
- metric_text = ""
117
- if batch_metrics is not None:
118
- metric_text = (
119
- f"batch_miou={batch_metrics['miou']:.4f} "
120
- f"batch_fscore={batch_metrics['fscore']:.4f} "
121
- )
122
- message = (
123
- f"gate_stats step={step} "
124
- f"loss={loss_value:.6f} "
125
- f"{metric_text}"
126
- f"proj_norm={stats['proj_norm']:.6f} "
127
- f"gate_norm={stats['gate_norm']:.6f} "
128
- f"proj_grad_norm={stats['proj_grad_norm']:.6f} "
129
- f"gate_grad_norm={stats['gate_grad_norm']:.6f} "
130
- f"alpha_mean={stats['alpha_mean']:.4f} "
131
- f"alpha_std={stats['alpha_std']:.4f} "
132
- f"alpha_min={stats['alpha_min']:.4f} "
133
- f"alpha_max={stats['alpha_max']:.4f}"
134
- )
135
- print(message)
136
- os.makedirs(args.log_root, exist_ok=True)
137
- with open(os.path.join(args.log_root, f"{args.name}.txt"), "a") as f:
138
- f.write(message + "\n")
139
-
140
-
141
- class CachedQDataset(Dataset):
142
- def __init__(self, split, cfg):
143
- self.split = split
144
- self.cfg = cfg
145
- self.root = os.path.join(cfg.cache_root, split)
146
- self.index_path = os.path.join(self.root, "index.jsonl")
147
- if not os.path.exists(self.index_path):
148
- raise FileNotFoundError(f"Missing cache index: {self.index_path}")
149
- with open(self.index_path) as f:
150
- self.rows = [json.loads(line) for line in f if line.strip()]
151
-
152
- def __len__(self):
153
- return len(self.rows)
154
-
155
- def _load_masks(self, vid, fids):
156
- masks = []
157
- for fid in fids:
158
- frames = []
159
- for frame_idx in range(self.cfg.frame_n):
160
- path = os.path.join(
161
- self.cfg.data_dir,
162
- "gt_mask",
163
- vid,
164
- f"fid_{int(fid)}",
165
- f"0000{frame_idx}.png",
166
- )
167
- mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
168
- if mask is None:
169
- raise FileNotFoundError(path)
170
- frames.append(torch.as_tensor(mask > 0, dtype=torch.float32))
171
- masks.append(torch.stack(frames, dim=0))
172
- return torch.stack(masks, dim=0)
173
-
174
- def __getitem__(self, idx):
175
- row = self.rows[idx]
176
- cache = torch.load(os.path.join(self.root, row["path"]), map_location="cpu")
177
- vid = cache["vid"]
178
- return {
179
- "sample_idx": cache["sample_idx"],
180
- "vid": vid,
181
- "refs": cache["refs"],
182
- "fids": cache["fids"],
183
- "q": cache["q"].float(),
184
- "image_embeddings": torch.load(
185
- os.path.join(self.cfg.data_dir, "image_embed", f"{vid}.pt"),
186
- map_location="cpu",
187
- ).float(),
188
- "gt_masks": self._load_masks(vid, cache["fids"]),
189
- "resize": tuple(cache["resize"]),
190
- "orgsize": tuple(cache["orgsize"]),
191
- }
192
-
193
-
194
- def collate_cached(batch):
195
- return batch
196
-
197
-
198
- def decode_batch(visual_model, batch, device):
199
- image_pe = visual_model.prompt_encoder.get_dense_pe()
200
- frame_qs = []
201
- frame_image_embeddings = []
202
- prompt_spans = []
203
-
204
- for sample_idx, sample in enumerate(batch):
205
- q = sample["q"].to(device=device, dtype=torch.float32)
206
- image_embeddings = sample["image_embeddings"].to(device=device, dtype=torch.float32)
207
- frames = image_embeddings.shape[0]
208
- for prompt_idx in range(q.shape[0]):
209
- start = len(frame_qs) * frames
210
- frame_qs.append(q[prompt_idx].unsqueeze(0).expand(frames, -1))
211
- frame_image_embeddings.append(image_embeddings)
212
- prompt_spans.append((sample_idx, prompt_idx, start, start + frames))
213
-
214
- if not frame_qs:
215
- raise RuntimeError("No cached prompts were provided for decoding.")
216
-
217
- frame_qs = torch.cat(frame_qs, dim=0)
218
- frame_image_embeddings = torch.cat(frame_image_embeddings, dim=0)
219
- sparse_embeddings, dense_embeddings = visual_model.prompt_encoder(
220
- points=None,
221
- boxes=None,
222
- masks=None,
223
- text_embeds=frame_qs.unsqueeze(1),
224
- )
225
- sparse_embeddings = sparse_embeddings.to(frame_qs.dtype)
226
- dense_embeddings = dense_embeddings.to(frame_qs.dtype)
227
-
228
- low_res_masks = visual_model.mask_decoder.forward_modified_v3(
229
- image_embeddings=frame_image_embeddings,
230
- image_pe=image_pe,
231
- sparse_prompt_embeddings=sparse_embeddings,
232
- dense_prompt_embeddings=dense_embeddings,
233
- ).unsqueeze(1)
234
-
235
- pred_by_sample = [[] for _ in batch]
236
- for sample_idx, _, start, end in prompt_spans:
237
- sample = batch[sample_idx]
238
- pred_mask = visual_model.postprocess_masks(
239
- low_res_masks[start:end],
240
- input_size=sample["resize"],
241
- original_size=sample["orgsize"],
242
- )
243
- pred_by_sample[sample_idx].append(pred_mask.squeeze(1))
244
-
245
- return [torch.stack(pred_masks, dim=0) for pred_masks in pred_by_sample]
246
-
247
-
248
- def decode_sample(visual_model, sample, device):
249
- return decode_batch(visual_model, [sample], device)[0]
250
-
251
-
252
- def compute_mask_loss(pred_masks, gt_masks):
253
- mask_bce_loss = 0.0
254
- mask_dice_loss = 0.0
255
- num_masks = 0
256
-
257
- for pred_mask, gt_mask in zip(pred_masks, gt_masks):
258
- gt_mask = gt_mask.to(device=pred_mask.device, dtype=pred_mask.dtype)
259
- num_seg, frames, height, width = gt_mask.shape
260
- gt_flat = gt_mask.view(num_seg * frames, height, width)
261
- pred_flat = pred_mask.view(num_seg * frames, height, width)
262
-
263
- mask_bce_loss = mask_bce_loss + (
264
- sigmoid_ce_loss(pred_flat, gt_flat, num_masks=gt_flat.shape[0])
265
- * gt_flat.shape[0]
266
- )
267
- mask_dice_loss = mask_dice_loss + (
268
- dice_loss(pred_flat, gt_flat, num_masks=gt_flat.shape[0])
269
- * gt_flat.shape[0]
270
- )
271
- num_masks += gt_flat.shape[0]
272
-
273
- mask_bce_loss = 2.0 * mask_bce_loss / (num_masks + 1e-8)
274
- mask_dice_loss = 0.5 * mask_dice_loss / (num_masks + 1e-8)
275
- return mask_bce_loss + mask_dice_loss
276
-
277
-
278
- def compute_batch_metrics(pred_masks, gt_masks):
279
- total_iou = 0.0
280
- total_fscore = 0.0
281
- count = 0
282
- for pred_mask, gt_mask in zip(pred_masks, gt_masks):
283
- gt_mask = gt_mask.to(device=pred_mask.device, dtype=pred_mask.dtype)
284
- num_seg, frames = pred_mask.shape[:2]
285
- weight = num_seg * frames
286
- total_iou += utility.mask_iou(pred_mask.detach().float(), gt_mask.float()) * weight
287
- total_fscore += utility.Eval_Fmeasure(pred_mask.detach().float(), gt_mask.float(), None) * weight
288
- count += weight
289
- return {
290
- "miou": total_iou / max(1, count),
291
- "fscore": total_fscore / max(1, count),
292
- }
293
-
294
-
295
- def evaluate(model, loader):
296
- model.eval()
297
- visual_model = model.get_model().visual_model
298
- total_iou = 0.0
299
- total_fscore = 0.0
300
- total_null = 0.0
301
- count = 0
302
-
303
- with torch.no_grad():
304
- for batch in tqdm(loader, desc=f"Cached eval {args.cache_split}"):
305
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
306
- batch_pred = decode_batch(visual_model, batch, "cuda")
307
- for sample, pred in zip(batch, batch_pred):
308
- gt = sample["gt_masks"].to(device=pred.device, dtype=pred.dtype)
309
- num_seg, frames = pred.shape[:2]
310
- weight = num_seg * frames
311
- if args.cache_split == "test_n":
312
- total_null += float(utility.metric_s_for_null(pred.float())) * weight
313
- else:
314
- total_iou += utility.mask_iou(pred.float(), gt.float()) * weight
315
- total_fscore += utility.Eval_Fmeasure(pred.float(), gt.float(), None) * weight
316
- count += weight
317
-
318
- if count == 0:
319
- raise RuntimeError("No cached samples were evaluated.")
320
-
321
- if args.cache_split == "test_n":
322
- print(f"cached valuate on test_n_refer, metric: {total_null / count}")
323
- else:
324
- print(
325
- f"cached valuate on {args.cache_split}: "
326
- f"miou: {total_iou / count} fscore: {total_fscore / count}"
327
- )
328
-
329
-
330
- def train(model, loader):
331
- if args.disable_gate:
332
- raise ValueError("--disable_gate is only valid with --eval_only")
333
-
334
- for p in model.parameters():
335
- p.requires_grad = False
336
- for name, p in model.named_parameters():
337
- if "referent_gate" in name:
338
- p.requires_grad = True
339
-
340
- gate_params = [p for p in model.parameters() if p.requires_grad]
341
- optimizer = AdamW(gate_params, lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)
342
- stats = collect_referent_gate_stats(model)
343
- print(
344
- "cached gate init: "
345
- f"modules={stats['modules']} "
346
- f"proj_norm={stats['proj_norm']:.6f} "
347
- f"gate_norm={stats['gate_norm']:.6f} "
348
- f"trainable_params={sum(p.numel() for p in gate_params)}"
349
- )
350
-
351
- visual_model = model.get_model().visual_model
352
- step = 0
353
- for epoch in range(args.epochs):
354
- model.train()
355
- order_loader = loader
356
- for batch in tqdm(order_loader, desc=f"Cached gate train {epoch + 1}/{args.epochs}"):
357
- if args.max_steps > 0 and step >= args.max_steps:
358
- break
359
- with torch.cuda.amp.autocast(dtype=torch.bfloat16):
360
- pred_masks = decode_batch(visual_model, batch, "cuda")
361
- gt_masks = [sample["gt_masks"] for sample in batch]
362
-
363
- loss = compute_mask_loss(pred_masks, gt_masks)
364
- optimizer.zero_grad()
365
- loss.backward()
366
- step += 1
367
- if args.log_gate_stats_every > 0 and step % args.log_gate_stats_every == 0:
368
- batch_metrics = compute_batch_metrics(pred_masks, gt_masks)
369
- log_gate_stats(model, step, loss.item(), batch_metrics)
370
- optimizer.step()
371
-
372
- if args.max_steps > 0 and step >= args.max_steps:
373
- print(f"stopped early at cached optimizer step {step}")
374
- break
375
-
376
- os.makedirs(args.checkpoint_root, exist_ok=True)
377
- ckpt_path = os.path.join(args.checkpoint_root, f"{args.name}.pth")
378
- if args.save_gate_only:
379
- torch.save(
380
- {
381
- "type": "referent_gate_only",
382
- "base_model": args.saved_model,
383
- "state_dict": referent_gate_state_dict(model),
384
- },
385
- ckpt_path,
386
- )
387
- else:
388
- torch.save(model.state_dict(), ckpt_path)
389
- print(f"cached gate model saved as {ckpt_path}")
390
-
391
-
392
- def main():
393
- set_seed(42)
394
- random.seed(42)
395
- np.random.seed(42)
396
-
397
- tokenizer = transformers.AutoTokenizer.from_pretrained(
398
- args.mllm,
399
- cache_dir=None,
400
- model_max_length=2048,
401
- padding_side="right",
402
- use_fast=False,
403
- )
404
- tokenizer.pad_token = tokenizer.unk_token
405
- tokenizer.add_tokens("[SEG]")
406
- seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
407
-
408
- dataset = CachedQDataset(args.cache_split, args)
409
- if args.overfit_samples > 0:
410
- n = min(args.overfit_samples, len(dataset))
411
- dataset = Subset(dataset, list(range(n)))
412
- print(f"cached overfit_samples enabled: using first {n} samples")
413
-
414
- loader = DataLoader(
415
- dataset,
416
- batch_size=args.batch_size,
417
- shuffle=not args.eval_only,
418
- num_workers=4,
419
- collate_fn=collate_cached,
420
- )
421
-
422
- model = build_model(tokenizer, seg_token_idx)
423
- if args.gate_checkpoint:
424
- load_referent_gate_checkpoint(model, args.gate_checkpoint)
425
- if args.disable_gate:
426
- zero_referent_gate(model)
427
- print("disable_gate enabled: referent gate forced to identity")
428
-
429
- if args.eval_only:
430
- evaluate(model, loader)
431
- return
432
-
433
- train(model, loader)
434
- if not args.skip_eval_after_train:
435
- evaluate(model, loader)
436
-
437
-
438
- if __name__ == "__main__":
439
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
upload_hf.py DELETED
@@ -1,74 +0,0 @@
1
- """Upload the current SimToken workspace to HuggingFace Hub.
2
-
3
- Example:
4
- python upload_hf.py --repo yfan07/SimToken
5
- """
6
-
7
- from __future__ import annotations
8
-
9
- import argparse
10
- import logging
11
- from pathlib import Path
12
-
13
- from huggingface_hub import HfApi, create_repo
14
-
15
-
16
- ROOT = Path(__file__).resolve().parent
17
-
18
- IGNORE_PATTERNS = [
19
- ".git/**",
20
- "**/__pycache__/**",
21
- "**/.pytest_cache/**",
22
- "**/.cache/**",
23
- "**/*.pyc",
24
- "**/*.pyo",
25
- "upload.log",
26
- ]
27
-
28
-
29
- def parse_args() -> argparse.Namespace:
30
- parser = argparse.ArgumentParser(description="Upload SimToken to HuggingFace Hub.")
31
- parser.add_argument("--repo", required=True, help="Repo id, e.g. yfan07/SimToken")
32
- parser.add_argument("--repo_type", default="model", choices=["model", "dataset", "space"])
33
- parser.add_argument("--private", action="store_true", help="Create repo as private if missing.")
34
- parser.add_argument("--num_workers", type=int, default=4)
35
- return parser.parse_args()
36
-
37
-
38
- def main() -> None:
39
- args = parse_args()
40
- logging.basicConfig(
41
- level=logging.INFO,
42
- format="%(asctime)s %(levelname)s %(message)s",
43
- handlers=[logging.FileHandler(ROOT / "upload.log"), logging.StreamHandler()],
44
- )
45
-
46
- create_repo(
47
- repo_id=args.repo,
48
- repo_type=args.repo_type,
49
- private=args.private,
50
- exist_ok=True,
51
- )
52
-
53
- api = HfApi()
54
- if hasattr(api, "upload_large_folder"):
55
- logging.info("Uploading %s to %s with upload_large_folder", ROOT, args.repo)
56
- api.upload_large_folder(
57
- repo_id=args.repo,
58
- repo_type=args.repo_type,
59
- folder_path=str(ROOT),
60
- ignore_patterns=IGNORE_PATTERNS,
61
- num_workers=args.num_workers,
62
- )
63
- else:
64
- logging.info("Uploading %s to %s with upload_folder", ROOT, args.repo)
65
- api.upload_folder(
66
- repo_id=args.repo,
67
- repo_type=args.repo_type,
68
- folder_path=str(ROOT),
69
- ignore_patterns=IGNORE_PATTERNS,
70
- )
71
-
72
-
73
- if __name__ == "__main__":
74
- main()