Clean experimental files and restore original SimToken layout
Browse files- Residual_Prompt_Bridge.md +0 -501
- SEG_LTPO_results.md +0 -488
- analyze_d2_csv.py +0 -239
- build_rpb_dev_manifest.py +0 -71
- cache_q_features.py +0 -125
- cache_q_smoke/test_s/000000.pt +0 -3
- cache_q_smoke/test_s/index.jsonl +0 -1
- checkpoints/rpb_dev_mixed_pm_only_a018_wm005.pth +0 -3
- checkpoints/rpb_dev_pm_only_a018.pth +0 -3
- checkpoints/rpb_probe_eval_directional_pm_only_a02.pth +0 -3
- d2_basic.py +0 -340
- d2_llm_space.py +0 -314
- decoder_invariance_check.py +0 -256
- dev_subsets_rpb_v1.json +0 -620
- log/rpb_dev_eval_baseline_step0.txt +0 -5
- log/rpb_dev_eval_pm_only_a02_step0.txt +0 -7
- log/rpb_dev_mixed_pm_only_a015_wm005.txt +0 -11
- log/rpb_dev_mixed_pm_only_a018_wm005.txt +0 -11
- log/rpb_dev_pm_only_a012.txt +0 -11
- log/rpb_dev_pm_only_a015.txt +0 -11
- log/rpb_dev_pm_only_a018.txt +0 -11
- log/rpb_dev_qonly_pm_only_a018.txt +0 -11
- log/rpb_e1_baseline.txt +0 -5
- log/rpb_e4_min.txt +0 -16
- log/rpb_e4_min_v2.txt +0 -11
- log/rpb_probe_a1_teacher_only.txt +0 -22
- log/rpb_probe_a1_teacher_only_v2.txt +0 -22
- log/rpb_probe_a1p_directional_pm_only.txt +0 -22
- log/rpb_probe_a1p_directional_pm_only_a02.txt +0 -22
- log/rpb_probe_eval_directional_pm_only_a02.txt +0 -11
- log/rpb_probe_eval_directional_pm_only_a02_step0.txt +0 -7
- log/rpb_probe_mixed_pm_only_a02_wm005_s80.txt +0 -11
- seg_ltpo.py +0 -1372
- setup_simtoken.md +0 -163
- simtoken_experiment.md +0 -369
- target_frame_sweep.py +0 -265
- train_cached_gate.py +0 -439
- upload_hf.py +0 -74
Residual_Prompt_Bridge.md
DELETED
|
@@ -1,501 +0,0 @@
|
|
| 1 |
-
# Residual Prompt Bridge 论文导向实验路线图
|
| 2 |
-
|
| 3 |
-
## 1. 当前主 claim
|
| 4 |
-
|
| 5 |
-
论文主 claim 现在正式锁定为:
|
| 6 |
-
|
| 7 |
-
> **We propose an image-conditioned directional prompt correction module that orthogonalizes prompt updates to steer language-side prompts toward a more decodable SAM prompt manifold, mitigating cross-distribution prompt interface mismatch.**
|
| 8 |
-
|
| 9 |
-
对应中文表述:
|
| 10 |
-
|
| 11 |
-
> **我们提出一种图像条件的方向型 prompt correction,通过正交化更新把语言侧 prompt 朝更可解码的 SAM prompt manifold 偏转,从而缓解跨分布的 prompt 接口失配。**
|
| 12 |
-
|
| 13 |
-
从现在开始,所有实验都只服务这句 claim,不再让方法故事扩散成“大而全系统”。
|
| 14 |
-
|
| 15 |
-
---
|
| 16 |
-
|
| 17 |
-
## 2. 当前项目定位
|
| 18 |
-
|
| 19 |
-
当前 RPB 项目已经完成了最关键的早期筛查:
|
| 20 |
-
|
| 21 |
-
1. **实现正确性通过**
|
| 22 |
-
- checkpoint / LoRA 兼容问题已修复
|
| 23 |
-
- bridge 路径不会自动破坏 baseline
|
| 24 |
-
- identity-preserving sanity check 已通过
|
| 25 |
-
|
| 26 |
-
2. **几何机制方向明确**
|
| 27 |
-
- additive residual 不足以推动 `p_hat` 离开 `q`
|
| 28 |
-
- directional bridge 明显优于 additive
|
| 29 |
-
- orthogonalization 能把 residual 预算从径向缩放转成方向修正
|
| 30 |
-
|
| 31 |
-
3. **当前最小核心已浮现**
|
| 32 |
-
- `image-conditioned`
|
| 33 |
-
- `p_mask-only`
|
| 34 |
-
- `directional`
|
| 35 |
-
- `orthogonal`
|
| 36 |
-
- `single-token correction`
|
| 37 |
-
|
| 38 |
-
4. **mixed 的角色目前仍未定型**
|
| 39 |
-
- weak mixed 不会抹掉 bridge
|
| 40 |
-
- 但目前更像 enhancer / compatibility probe,而不是稳定的 decoder-facing calibration mechanism
|
| 41 |
-
|
| 42 |
-
因此,当前最重要的不是继续加模块,而是把这个**最小有效核心**做成稳定、可复现、可投稿的方法骨架。
|
| 43 |
-
|
| 44 |
-
---
|
| 45 |
-
|
| 46 |
-
## 3. 两套判据:Mechanism Pass vs Paper Pass
|
| 47 |
-
|
| 48 |
-
### 3.1 Mechanism pass
|
| 49 |
-
|
| 50 |
-
回答的问题是:
|
| 51 |
-
|
| 52 |
-
> 这个方法设计是否真的抓住了问题本质?
|
| 53 |
-
|
| 54 |
-
当前 mechanism pass 需要被下面这些证据支撑:
|
| 55 |
-
|
| 56 |
-
- additive vs directional:directional 明显更能让 `p_hat` 离开 identity
|
| 57 |
-
- without orthogonal vs with orthogonal:orthogonalization 明显改善 `Δp` 的几何利用效率
|
| 58 |
-
- `Δp` 稳定朝 `p_mask`
|
| 59 |
-
- `p_hat` 能明显离开 `q`
|
| 60 |
-
- seen/unseen 的 alignment ratio 健康
|
| 61 |
-
- weak mixed 不会直接把 bridge 拉回 baseline
|
| 62 |
-
|
| 63 |
-
### 3.2 Paper pass
|
| 64 |
-
|
| 65 |
-
回答的问题是:
|
| 66 |
-
|
| 67 |
-
> 这个方法是否已经强到能单独撑起一篇顶会方法论文?
|
| 68 |
-
|
| 69 |
-
paper pass 需要下面这些更强条件:
|
| 70 |
-
|
| 71 |
-
- 更大规模评估上有稳定、同向的 headline 趋势
|
| 72 |
-
- 至少在 unseen 上有清晰、可复现的优势
|
| 73 |
-
- seen / null 的代价可接受
|
| 74 |
-
- 2 个随机种子下趋势稳定
|
| 75 |
-
- 最小闭环 ablation 完整
|
| 76 |
-
|
| 77 |
-
当前状态:
|
| 78 |
-
|
| 79 |
-
- **mechanism pass:接近通过,但还缺更大规模验证和关键 baseline**
|
| 80 |
-
- **paper pass:尚未通过**
|
| 81 |
-
|
| 82 |
-
后续每组实验都要明确写清楚:它是在推进 mechanism pass,还是在推进 paper pass。
|
| 83 |
-
|
| 84 |
-
---
|
| 85 |
-
|
| 86 |
-
## 4. 冻结最小核心方法
|
| 87 |
-
|
| 88 |
-
在 pure RPB standalone 路线中,当前只保留下列组成:
|
| 89 |
-
|
| 90 |
-
- `image-conditioned correction`
|
| 91 |
-
- `p_mask-only teacher`
|
| 92 |
-
- `directional bridge`
|
| 93 |
-
- `orthogonalized update`
|
| 94 |
-
- `single-token prompt correction`
|
| 95 |
-
|
| 96 |
-
当前明确**不进入主线**的内容:
|
| 97 |
-
|
| 98 |
-
- `z_gt` 作为主 teacher
|
| 99 |
-
- calibrator
|
| 100 |
-
- refinement
|
| 101 |
-
- 多 token bridge
|
| 102 |
-
- 大而全的完整 bridge 系统
|
| 103 |
-
|
| 104 |
-
这些内容后续最多作为 ablation、扩展或 hybrid 组件,而不是当前主方法本体。
|
| 105 |
-
|
| 106 |
-
---
|
| 107 |
-
|
| 108 |
-
## 5. 当前实验事实总结
|
| 109 |
-
|
| 110 |
-
### 5.1 已确认的正结果
|
| 111 |
-
|
| 112 |
-
- bridge 可以安全接入,不会自动毁掉 baseline
|
| 113 |
-
- 修复 checkpoint / LoRA 后,RPB 路径与 baseline 基本等价
|
| 114 |
-
- `directional + orthogonal` 后:
|
| 115 |
-
- `Δp` 高度对齐 `p_mask`
|
| 116 |
-
- `Δp` 不再主要沿 `q` 的平行方向浪费预算
|
| 117 |
-
- `p_hat` 能够明显离开 identity 区
|
| 118 |
-
- `p_mask-only teacher-only` 已在 quick eval 上给出:
|
| 119 |
-
- seen 小幅回落但可控
|
| 120 |
-
- unseen 轻微正信号
|
| 121 |
-
- null 基本持平
|
| 122 |
-
|
| 123 |
-
### 5.2 已确认的负结果
|
| 124 |
-
|
| 125 |
-
- additive residual 不足以真正旋转 prompt
|
| 126 |
-
- `L_mask` 不是早期主矛盾
|
| 127 |
-
- `z_gt` 目前不是 sparse bridge 的主 teacher
|
| 128 |
-
- weak mixed 目前不能稳定把 seen 拉回 baseline
|
| 129 |
-
|
| 130 |
-
### 5.3 当前最重要的工作假设
|
| 131 |
-
|
| 132 |
-
> `p_mask-only + image-conditioned + directional + orthogonal` 已经抓住主问题,但还需要找到更稳定的 operating point,并证明其 headline 趋势不是噪声。
|
| 133 |
-
|
| 134 |
-
### 5.4 Fixed dev 阶段 A 当前记录
|
| 135 |
-
|
| 136 |
-
固定 dev 子集:
|
| 137 |
-
|
| 138 |
-
- `test_s`: 200 samples
|
| 139 |
-
- `test_u`: 200 samples
|
| 140 |
-
- `test_n`: 200 samples
|
| 141 |
-
- manifest: `/workspace/SimToken/dev_subsets_rpb_v1.json`
|
| 142 |
-
|
| 143 |
-
#### Fixed dev baseline
|
| 144 |
-
|
| 145 |
-
| Setting | Seen mIoU | Seen F | Unseen mIoU | Unseen F | Null |
|
| 146 |
-
|---|---:|---:|---:|---:|---:|
|
| 147 |
-
| baseline | 0.72554 | 0.81811 | 0.68531 | 0.77238 | 0.01452 |
|
| 148 |
-
|
| 149 |
-
#### Teacher-only alpha search
|
| 150 |
-
|
| 151 |
-
| Setting | Seen mIoU | Seen F | Unseen mIoU | Unseen F | Null | Seen cos(p_hat,p_mask) | Unseen cos(p_hat,p_mask) | 机制判断 |
|
| 152 |
-
|---|---:|---:|---:|---:|---:|---:|---:|---|
|
| 153 |
-
| image, alpha=0.20 | 0.72517 | 0.81376 | 0.68596 | 0.77730 | 0.01426 | 0.09502 | 0.06611 | 机制最强,Seen/F 有代价 |
|
| 154 |
-
| image, alpha=0.18 | 0.72692 | 0.81705 | 0.68595 | 0.77354 | 0.01448 | 0.02873 | 0.00605 | 性能平衡较好,机制偏弱 |
|
| 155 |
-
| image, alpha=0.15 | 0.72669 | 0.81725 | 0.68569 | 0.77330 | 0.01448 | 0.02373 | 0.00282 | 更接近 identity |
|
| 156 |
-
| image, alpha=0.12 | 0.72651 | 0.81748 | 0.68578 | 0.77314 | 0.01449 | 0.01871 | -0.00046 | 轻扰动区,机制最弱 |
|
| 157 |
-
|
| 158 |
-
阶段 A 的 teacher-only 结论:
|
| 159 |
-
|
| 160 |
-
- `alpha=0.20` 是机制候选点,能明显改变 prompt geometry。
|
| 161 |
-
- `alpha=0.18` 是性能平衡候选点,seen / unseen / null 都更稳。
|
| 162 |
-
- `alpha=0.12/0.15` 已经过于接近 identity,不适合作为机制主证据。
|
| 163 |
-
|
| 164 |
-
#### Weak mixed 局部验证
|
| 165 |
-
|
| 166 |
-
| Setting | Seen mIoU | Seen F | Unseen mIoU | Unseen F | Null | Seen cos(p_hat,p_mask) | Unseen cos(p_hat,p_mask) | 角色判断 |
|
| 167 |
-
|---|---:|---:|---:|---:|---:|---:|---:|---|
|
| 168 |
-
| image, alpha=0.18, weak mixed | 0.72704 | 0.81554 | 0.68706 | 0.77454 | 0.01451 | 0.04079 | 0.01325 | 当前最佳性能平衡候选 |
|
| 169 |
-
| image, alpha=0.15, weak mixed | 0.72684 | 0.81607 | 0.68674 | 0.77419 | 0.01451 | 0.03382 | 0.00882 | 稳定但略弱于 alpha=0.18 mixed |
|
| 170 |
-
|
| 171 |
-
weak mixed 当前结论:
|
| 172 |
-
|
| 173 |
-
- weak mixed 没有把 bridge 拉回 identity。
|
| 174 |
-
- weak mixed 对 `alpha=0.15/0.18` 都更像 mild enhancement,而不是 destructive pullback。
|
| 175 |
-
- `alpha=0.18 + weak mixed` 是当前 fixed dev 的最佳 operating point。
|
| 176 |
-
|
| 177 |
-
#### q-only directional baseline
|
| 178 |
-
|
| 179 |
-
| Setting | Seen mIoU | Seen F | Unseen mIoU | Unseen F | Null | Seen cos(p_hat,p_mask) | Unseen cos(p_hat,p_mask) | 判断 |
|
| 180 |
-
|---|---:|---:|---:|---:|---:|---:|---:|---|
|
| 181 |
-
| q-only, alpha=0.18 | 0.72311 | 0.81206 | 0.68289 | 0.77666 | 0.01424 | 0.12061 | 0.09598 | alignment 更强但 mIoU 更差 |
|
| 182 |
-
|
| 183 |
-
q-only 结论:
|
| 184 |
-
|
| 185 |
-
- directional / orthogonal 机制本身很强,q-only 也能大幅拉高 teacher alignment。
|
| 186 |
-
- q-only 的 prompt steering 更激进,`gate_mean` 更高,`delta_norm` 更大。
|
| 187 |
-
- q-only mIoU 在 seen / unseen 上都低于 image-conditioned candidate。
|
| 188 |
-
- 当前证据支持:image conditioning 的价值不是单纯提高 teacher cosine,而是约束方向修正,使 prompt steering 与 decoder compatibility 之间的平衡更好。
|
| 189 |
-
|
| 190 |
-
#### 阶段 A 当前候选
|
| 191 |
-
|
| 192 |
-
当前 fixed dev 最佳候选:
|
| 193 |
-
|
| 194 |
-
> **image-conditioned + p_mask-only + directional + orthogonal + alpha=0.18 + weak mixed**
|
| 195 |
-
|
| 196 |
-
对应 checkpoint:
|
| 197 |
-
|
| 198 |
-
> `/workspace/SimToken/checkpoints/rpb_dev_mixed_pm_only_a018_wm005.pth`
|
| 199 |
-
|
| 200 |
-
---
|
| 201 |
-
|
| 202 |
-
## 6. 实验纪律:停止在 test 上自由调方向
|
| 203 |
-
|
| 204 |
-
从下一阶段开始,必须冻结一套 **dev tuning subset**,不再继续在 `test_s/test_u/test_n` 上自由调 alpha 和 mixed 设定。
|
| 205 |
-
|
| 206 |
-
建议立即固定:
|
| 207 |
-
|
| 208 |
-
- `dev_seen`
|
| 209 |
-
- `dev_unseen`
|
| 210 |
-
- `dev_null`
|
| 211 |
-
|
| 212 |
-
每个 split 可先取 `100` 或 `200` 个样本,后续:
|
| 213 |
-
|
| 214 |
-
- alpha 选择
|
| 215 |
-
- mixed 选择
|
| 216 |
-
- warm-start 配置
|
| 217 |
-
- early stopping
|
| 218 |
-
|
| 219 |
-
全部只在 dev 上完成。
|
| 220 |
-
真正的 test split 只用于后续一次性确认和最终表格。
|
| 221 |
-
|
| 222 |
-
---
|
| 223 |
-
|
| 224 |
-
## 7. 三阶段推进路线
|
| 225 |
-
|
| 226 |
-
## 阶段 A:锁最小核心的 operating point
|
| 227 |
-
|
| 228 |
-
### 目标
|
| 229 |
-
|
| 230 |
-
回答:
|
| 231 |
-
|
| 232 |
-
> 当前最小核心是否能在更大 quick eval 上形成稳定、可接受的性能-几何平衡?
|
| 233 |
-
|
| 234 |
-
### 本阶段只做两类实验
|
| 235 |
-
|
| 236 |
-
#### A1. teacher-only operating point 搜索
|
| 237 |
-
|
| 238 |
-
固定:
|
| 239 |
-
|
| 240 |
-
- image-conditioned
|
| 241 |
-
- `p_mask-only`
|
| 242 |
-
- directional
|
| 243 |
-
- orthogonal
|
| 244 |
-
- single-token
|
| 245 |
-
- 不加 `z_gt`
|
| 246 |
-
- 不加 calibrator
|
| 247 |
-
- 不加 refinement
|
| 248 |
-
|
| 249 |
-
重点只扫:
|
| 250 |
-
|
| 251 |
-
- `alpha = 0.12, 0.15, 0.18, 0.20`
|
| 252 |
-
|
| 253 |
-
当前判断是:`0.20` 已经是 promising pass,因此没有必要继续向更大 alpha 发散。
|
| 254 |
-
|
| 255 |
-
#### A2. weak mixed 局部验证
|
| 256 |
-
|
| 257 |
-
只围绕最佳 teacher-only checkpoint 做 warm-start,不做大 sweep。
|
| 258 |
-
|
| 259 |
-
建议只测:
|
| 260 |
-
|
| 261 |
-
- `best_alpha`
|
| 262 |
-
- `best_alpha - 0.03`
|
| 263 |
-
|
| 264 |
-
以及很弱的 mask 强度两档:
|
| 265 |
-
|
| 266 |
-
- `λ_mask = 0.05`
|
| 267 |
-
- `λ_mask = 0.10`
|
| 268 |
-
|
| 269 |
-
mixed 的目标不是涨分,而是判断它的角色到底是:
|
| 270 |
-
|
| 271 |
-
- calibration
|
| 272 |
-
- enhancement
|
| 273 |
-
- 还是 destructive pullback
|
| 274 |
-
|
| 275 |
-
### 阶段 A 重点指标
|
| 276 |
-
|
| 277 |
-
几何指标:
|
| 278 |
-
|
| 279 |
-
- `cos(p_hat, p_mask)_seen`
|
| 280 |
-
- `cos(p_hat, p_mask)_unseen`
|
| 281 |
-
- `cos(p_hat, q)`
|
| 282 |
-
- `cos(Δp, p_mask)`
|
| 283 |
-
- `cos(Δp, q)`
|
| 284 |
-
- `align_ratio = cos_u / cos_s`
|
| 285 |
-
|
| 286 |
-
性能指标:
|
| 287 |
-
|
| 288 |
-
- `mIoU_seen`
|
| 289 |
-
- `mIoU_unseen`
|
| 290 |
-
- `Fscore_seen`
|
| 291 |
-
- `Fscore_unseen`
|
| 292 |
-
- `Null metric`
|
| 293 |
-
|
| 294 |
-
### 阶段 A 的通过标准
|
| 295 |
-
|
| 296 |
-
若在 dev 或更大 quick eval 上,能找到一个稳定点满足:
|
| 297 |
-
|
| 298 |
-
- unseen 稳定不差于 baseline,最好有小幅提升
|
| 299 |
-
- seen 代价可控
|
| 300 |
-
- null 基本持平或代价可接受
|
| 301 |
-
- `cos(p_hat, p_mask)` 明显离开 identity 区
|
| 302 |
-
- seen/unseen 的 alignment ratio 健康
|
| 303 |
-
|
| 304 |
-
则阶段 A 通过。
|
| 305 |
-
|
| 306 |
-
### 阶段 A 的停止条件
|
| 307 |
-
|
| 308 |
-
若完成:
|
| 309 |
-
|
| 310 |
-
1. alpha 局部搜索
|
| 311 |
-
2. weak mixed 局部搜索
|
| 312 |
-
3. 100 / 200 样本 quick eval
|
| 313 |
-
|
| 314 |
-
之后仍出现任一情况,则停止 pure RPB standalone 主线:
|
| 315 |
-
|
| 316 |
-
- 在更大 quick eval 上没有稳定、同向的 unseen 优势
|
| 317 |
-
- seen/unseen tradeoff 对 alpha 高度敏感
|
| 318 |
-
- null 代价无法压到 baseline 附近
|
| 319 |
-
- mixed 始终只是增强器,而不是 decoder-facing calibration
|
| 320 |
-
|
| 321 |
-
---
|
| 322 |
-
|
| 323 |
-
## 阶段 B:做最小闭环 ablation
|
| 324 |
-
|
| 325 |
-
只有阶段 A 通过后,才进入阶段 B。
|
| 326 |
-
|
| 327 |
-
### 目标
|
| 328 |
-
|
| 329 |
-
把方法主骨架讲圆,形成 mechanism pass 的闭环证据。
|
| 330 |
-
|
| 331 |
-
### 必做的 4 个关键 ablation
|
| 332 |
-
|
| 333 |
-
1. **additive vs directional**
|
| 334 |
-
2. **directional without orthogonalization vs with orthogonalization**
|
| 335 |
-
3. **q-only directional vs image-conditioned directional**
|
| 336 |
-
4. **`p_mask-only` vs `p_mask + weak z_gt`**
|
| 337 |
-
|
| 338 |
-
这 4 个已经足够支撑方法论证,不再继续扩更多 trick ablation。
|
| 339 |
-
|
| 340 |
-
### 阶段 B 的补充要求
|
| 341 |
-
|
| 342 |
-
- 至少 2 个随机种子重复
|
| 343 |
-
- 至少一次更大规模验证
|
| 344 |
-
- 建立 geometry-performance coupling:
|
| 345 |
-
- prompt geometry 改写程度
|
| 346 |
-
- 与 seen/unseen 表现之间的关系
|
| 347 |
-
- 与 identity 回缩之间的关系
|
| 348 |
-
|
| 349 |
-
### 阶段 B 的停止条件
|
| 350 |
-
|
| 351 |
-
若完成:
|
| 352 |
-
|
| 353 |
-
1. alpha 局部搜索
|
| 354 |
-
2. weak mixed 局部搜索
|
| 355 |
-
3. 100 / 200 样本 quick eval
|
| 356 |
-
4. 至少一次更大规模验证
|
| 357 |
-
5. 2 个随机种子重复
|
| 358 |
-
|
| 359 |
-
后仍满足以下任一条,则停止 pure RPB standalone:
|
| 360 |
-
|
| 361 |
-
- 大子集 / full-split 上没有稳定、同向的 unseen 优势
|
| 362 |
-
- 最优点高度依赖 seed 或 alpha,趋势不稳定
|
| 363 |
-
- null 代价无法控制
|
| 364 |
-
- mixed 无法形成稳定 calibration 作用
|
| 365 |
-
- headline result 仍然只有极弱波动
|
| 366 |
-
|
| 367 |
-
---
|
| 368 |
-
|
| 369 |
-
## 阶段 C:决定论文定位
|
| 370 |
-
|
| 371 |
-
### 路线 1:pure RPB standalone
|
| 372 |
-
|
| 373 |
-
如果满足:
|
| 374 |
-
|
| 375 |
-
- 更大评估上有稳定 unseen gain
|
| 376 |
-
- seen / null 代价可接受
|
| 377 |
-
- 2 seeds 稳定
|
| 378 |
-
- 最小闭环 ablation 完整
|
| 379 |
-
|
| 380 |
-
则走:
|
| 381 |
-
|
| 382 |
-
> **pure RPB 方法论文**
|
| 383 |
-
|
| 384 |
-
### 路线 2:RPB + TTO hybrid
|
| 385 |
-
|
| 386 |
-
如果出现:
|
| 387 |
-
|
| 388 |
-
- mechanism 成立
|
| 389 |
-
- 但 paper pass 不够硬
|
| 390 |
-
- headline result 仍然偏弱或不稳定
|
| 391 |
-
|
| 392 |
-
则立刻切换定位:
|
| 393 |
-
|
| 394 |
-
> **RPB + TTO hybrid 方法论文**
|
| 395 |
-
|
| 396 |
-
此时 RPB 的角色不再是 standalone 主方法,而是:
|
| 397 |
-
|
| 398 |
-
- amortized prompt corrector
|
| 399 |
-
- 改善 test-time refinement 起点质量的前端模块
|
| 400 |
-
|
| 401 |
-
---
|
| 402 |
-
|
| 403 |
-
## 8. Hybrid 路线作为明确 Plan B
|
| 404 |
-
|
| 405 |
-
若 pure RPB 最终只能做到:
|
| 406 |
-
|
| 407 |
-
- unseen 稳定小涨
|
| 408 |
-
- seen 小掉
|
| 409 |
-
- null 持平或略好
|
| 410 |
-
|
| 411 |
-
那么 standalone 顶会会比较吃力。
|
| 412 |
-
但此时 RPB 作为前端 prompt corrector 仍很有价值:
|
| 413 |
-
|
| 414 |
-
- 改善初始 `q` 的几何
|
| 415 |
-
- 为 q-LTPO / selective refinement 提供更好的初始化
|
| 416 |
-
- 降低 test-time optimization 的步数和不稳定性
|
| 417 |
-
|
| 418 |
-
hybrid 的论文叙事可以明确写成:
|
| 419 |
-
|
| 420 |
-
1. train-time:amortized interface correction
|
| 421 |
-
2. test-time:instance-specific prompt refinement
|
| 422 |
-
3. 两者结合:同时解决全局接口失配与样本级细化问题
|
| 423 |
-
|
| 424 |
-
当前判断:hybrid 是非常强的 Plan B,而不是临时补救路线。
|
| 425 |
-
|
| 426 |
-
---
|
| 427 |
-
|
| 428 |
-
## 9. 负结果如何写进论文论证链条
|
| 429 |
-
|
| 430 |
-
当前已经得到了一条清晰的“设计收敛链条”,后续可以直接转写为论文方法论证:
|
| 431 |
-
|
| 432 |
-
### 为什么不是 additive residual
|
| 433 |
-
|
| 434 |
-
因为 additive 下:
|
| 435 |
-
|
| 436 |
-
- `Δp` 主要对抗 `q` 的平行分量
|
| 437 |
-
- teacher 方向被大范数 `q` 吞掉
|
| 438 |
-
- 结果更像缩放,而不是旋转
|
| 439 |
-
|
| 440 |
-
### 为什么要 directional
|
| 441 |
-
|
| 442 |
-
因为 directional 才能把修正显式变成 prompt 方向控制,而不是数值扰动。
|
| 443 |
-
|
| 444 |
-
### 为什么要 orthogonal
|
| 445 |
-
|
| 446 |
-
因为 orthogonalization 才能避免 residual 预算浪费在径向缩放上。
|
| 447 |
-
|
| 448 |
-
### 为什么当前只保留 `p_mask`
|
| 449 |
-
|
| 450 |
-
因为当前 sparse bridge 里,`p_mask` 一直是主 teacher,`z_gt` 尚未成为主信号。
|
| 451 |
-
|
| 452 |
-
### 为什么 mixed 不是主模块
|
| 453 |
-
|
| 454 |
-
因为 mixed 目前更像 compatibility / enhancement probe,而不是稳定的 calibration mechanism。
|
| 455 |
-
|
| 456 |
-
这条链条必须在文中明确写出,让 reviewer 看到方法是沿诊断逐步收敛的,而不是盲目堆模块。
|
| 457 |
-
|
| 458 |
-
---
|
| 459 |
-
|
| 460 |
-
## 10. 当前最直接的执行建议
|
| 461 |
-
|
| 462 |
-
接下来不要发散,严格按下面顺序走:
|
| 463 |
-
|
| 464 |
-
1. **立刻冻结论文主 claim**
|
| 465 |
-
2. **立刻切换到固定 dev 子集,不再自由用 test 调方向**
|
| 466 |
-
3. **完成阶段 A:最小核心 operating point 搜索**
|
| 467 |
-
4. **补关键 baseline:q-only directional**
|
| 468 |
-
5. **做两种 seed**
|
| 469 |
-
6. **然后做 pure RPB standalone 的去留决策**
|
| 470 |
-
|
| 471 |
-
当前最重要的执行原则是:
|
| 472 |
-
|
| 473 |
-
> **先证明最小核心能稳定成立;如果 headline 不够硬,就及时把它升级成 hybrid 前端,而不是继续把 pure RPB 做复杂。**
|
| 474 |
-
|
| 475 |
-
---
|
| 476 |
-
|
| 477 |
-
## 11. 当前阶段的明确结论
|
| 478 |
-
|
| 479 |
-
### 当前方向值得继续吗?
|
| 480 |
-
|
| 481 |
-
**值得。**
|
| 482 |
-
|
| 483 |
-
### 现在最应该做什么?
|
| 484 |
-
|
| 485 |
-
不是继续扩模块,而是:
|
| 486 |
-
|
| 487 |
-
- 找到 teacher-only `p_mask-only directional orthogonal` 的最佳 operating point
|
| 488 |
-
- 用 very weak mixed 判断 mixed 是否能形成 calibration
|
| 489 |
-
- 在 dev 和更大 quick eval 上证明趋势不是噪声
|
| 490 |
-
|
| 491 |
-
### 什么时候该停 pure RPB?
|
| 492 |
-
|
| 493 |
-
只要阶段 A + B 完成后,headline 仍然弱且不稳定,就停止 pure RPB standalone。
|
| 494 |
-
|
| 495 |
-
### 停了之后怎么办?
|
| 496 |
-
|
| 497 |
-
直接转:
|
| 498 |
-
|
| 499 |
-
> **RPB + TTO hybrid**
|
| 500 |
-
|
| 501 |
-
这条路线当前是明确的 Plan B,而且很可能是更强的顶会方法论文路径。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SEG_LTPO_results.md
DELETED
|
@@ -1,488 +0,0 @@
|
|
| 1 |
-
# SEG-LTPO: Experimental Results and Analysis
|
| 2 |
-
|
| 3 |
-
---
|
| 4 |
-
|
| 5 |
-
## Method 1: SEG-LTPO-simple (ES-based, zeroth-order)
|
| 6 |
-
|
| 7 |
-
### Overview
|
| 8 |
-
|
| 9 |
-
SEG-LTPO-simple performs test-time optimization of SimToken's single semantic token **Fseg** using antithetic Evolution Strategies (ES), guided by an internal reward signal that requires no ground-truth masks.
|
| 10 |
-
|
| 11 |
-
**Optimization loop** (T=5 steps, 4 anchor frames):
|
| 12 |
-
```
|
| 13 |
-
eps_t ~ N(0, σ_t² I)
|
| 14 |
-
F± = F_curr ± eps_t
|
| 15 |
-
F_curr = F_curr + η_t · (R+ − R−) / (2σ_t²) · eps_t
|
| 16 |
-
best_F = argmax_F R(F) over all evaluated candidates
|
| 17 |
-
```
|
| 18 |
-
|
| 19 |
-
**Reward function:**
|
| 20 |
-
```
|
| 21 |
-
R = λ1·R_temp_feat + λ2·R_iou_pred + λ3·R_align_contrast − λ4·R_area
|
| 22 |
-
= 0.3·R_temp + 0.4·R_iou + 1.0·R_align − 0.3·R_area
|
| 23 |
-
```
|
| 24 |
-
|
| 25 |
-
- **R_align_contrast**: cosine(Fseg, z_inside) − β·cosine(Fseg, z_outside); main signal
|
| 26 |
-
- **R_iou_pred**: SAM's internal mask quality head output
|
| 27 |
-
- **R_temp_feat**: feature-space cosine consistency between adjacent anchor frames
|
| 28 |
-
- **R_area**: average foreground ratio (degenerate-mask penalty)
|
| 29 |
-
|
| 30 |
-
**Reward gating**: accept optimized Fseg only when R(best_F) > R(F_init) + gate_delta.
|
| 31 |
-
|
| 32 |
-
### Results (Unseen split, full 1656 samples)
|
| 33 |
-
|
| 34 |
-
| Method | mIoU | F | Δ mIoU |
|
| 35 |
-
|--------|------|---|--------|
|
| 36 |
-
| Baseline | 0.6989 | 0.7927 | — |
|
| 37 |
-
| Best-of-2 Random | 0.7050 (subset) → 0.7030 (full) | 0.7953 | +0.0040 |
|
| 38 |
-
| SEG-LTPO-simple (ES) | **0.7050** | **0.7960** | **+0.0061** |
|
| 39 |
-
|
| 40 |
-
> Best-of-2 and LTPO-ES results at full scale confirmed in the q-LTPO evaluation run below.
|
| 41 |
-
|
| 42 |
-
### Key Findings
|
| 43 |
-
|
| 44 |
-
1. **Reward signal is valid**: both Best-of-2 and ES-LTPO outperform baseline, confirming R_align_contrast provides useful signal.
|
| 45 |
-
2. **ES update is noisy**: in 500-sample ablation, Best-of-2 (0.7235) slightly outperformed iterative ES (0.7228), due to extremely low SNR of single-sample gradient estimation in 256d space. At full scale (1656), ES-LTPO recovers (+0.0065 vs +0.0040), but the margin over Best-of-2 is small.
|
| 46 |
-
3. **Null stability**: Null S metric change negligible (+0.00025), reward gating effectively suppresses false positives.
|
| 47 |
-
|
| 48 |
-
---
|
| 49 |
-
|
| 50 |
-
## Method 2: q-LTPO-autograd (first-order, Adam maximize)
|
| 51 |
-
|
| 52 |
-
### Overview
|
| 53 |
-
|
| 54 |
-
**Core insight from LTPO analysis**: optimize the variable that is *directly consumed* by the downstream module, using autograd rather than noisy zeroth-order estimation.
|
| 55 |
-
|
| 56 |
-
**Three design decisions borrowed from original LTPO:**
|
| 57 |
-
|
| 58 |
-
1. **Optimize q, not Fseg.** In SimToken+SAM, the token that directly enters the mask decoder's cross-attention is `q = sparse_emb = Fseg.unsqueeze(1)` (prompt encoder passes text_embeds through unchanged). We set `q = nn.Parameter(q_init)` and optimize q directly, bypassing the prompt encoder entirely. This requires no invertibility of ε_p — q_best is used directly for final inference.
|
| 59 |
-
|
| 60 |
-
2. **Use autograd when reward is differentiable.** The mask decoder (transformer + MLP + matmul) is fully differentiable. With soft masks instead of hard thresholds, all reward terms are differentiable w.r.t. q. Adam maximize replaces the low-SNR score-function estimator.
|
| 61 |
-
|
| 62 |
-
3. **Track best_q by task reward (no regularization), gate at the end.** λ_reg penalty is excluded from gating to avoid penalizing solutions that drifted slightly from q_init but achieved better task reward.
|
| 63 |
-
|
| 64 |
-
**Stage 0: Gradient connectivity check (verified)**
|
| 65 |
-
```
|
| 66 |
-
grad_norm (step 0): 0.503070
|
| 67 |
-
reward trajectory: [0.4650, 0.4709, 0.4770, 0.4831, 0.4892] ← strictly monotone
|
| 68 |
-
gradient_connected: True
|
| 69 |
-
```
|
| 70 |
-
|
| 71 |
-
### Optimization loop
|
| 72 |
-
|
| 73 |
-
```python
|
| 74 |
-
q = nn.Parameter(q_init.float().detach().clone())
|
| 75 |
-
optimizer = Adam([q], lr=lr_auto, maximize=True)
|
| 76 |
-
best_q, best_reward = q_init.clone(), R_task(q_init)
|
| 77 |
-
|
| 78 |
-
for step in range(T=5):
|
| 79 |
-
R_full = R_task(q) - λ_reg * ||q - q_init||²
|
| 80 |
-
R_full.backward()
|
| 81 |
-
optimizer.step()
|
| 82 |
-
clip_to_L2_ball(q, q_init, max_drift) # hard norm constraint
|
| 83 |
-
if R_task(q) > best_reward:
|
| 84 |
-
best_q = q.clone()
|
| 85 |
-
|
| 86 |
-
# gating
|
| 87 |
-
use best_q if R_task(best_q) > R_task(q_init) + gate_delta, else q_init
|
| 88 |
-
```
|
| 89 |
-
|
| 90 |
-
**Hyperparameters (auto-scaled from q_init):**
|
| 91 |
-
- `lr = 0.01 × RMS(q_init)`
|
| 92 |
-
- `max_drift = 0.5 × ||q_init||`
|
| 93 |
-
- `λ_reg = 0.01`, `gate_delta = 0.0`
|
| 94 |
-
|
| 95 |
-
### Staged reward build-up
|
| 96 |
-
|
| 97 |
-
**Stage 1** (R_iou + R_area_soft + λ_reg):
|
| 98 |
-
```
|
| 99 |
-
R_task = 0.6·R_iou_pred − 0.2·sigmoid(mask_logits/τ).mean()
|
| 100 |
-
where τ=5.0 (temperature to avoid sigmoid saturation)
|
| 101 |
-
```
|
| 102 |
-
|
| 103 |
-
**Stage 2** (Stage 1 + R_align_det):
|
| 104 |
-
```
|
| 105 |
-
R_task = 0.4·R_iou_pred + 1.0·R_align_det − 0.3·R_area_soft
|
| 106 |
-
R_align_det = mean_t [ cosine(q, stopgrad(z_in^t)) − 0.5·cosine(q, stopgrad(z_out^t)) ]
|
| 107 |
-
```
|
| 108 |
-
z_in/z_out are stopgrad'd to avoid coupling: q first finds a mask, then moves toward the masked region's semantics.
|
| 109 |
-
|
| 110 |
-
### Results (Unseen split)
|
| 111 |
-
|
| 112 |
-
#### 200-sample subset (Stage 1 vs Stage 2 fair comparison, same baseline)
|
| 113 |
-
|
| 114 |
-
| Method | mIoU | F | Δ mIoU |
|
| 115 |
-
|--------|------|---|--------|
|
| 116 |
-
| Baseline | 0.6749 | 0.7763 | — |
|
| 117 |
-
| Best-of-2 ES | 0.6801 | 0.7803 | +0.0052 |
|
| 118 |
-
| LTPO-ES | 0.6838 | 0.7826 | +0.0089 |
|
| 119 |
-
| q-LTPO Stage 1 | 0.6979 | 0.7802 | +0.0230 |
|
| 120 |
-
| q-LTPO Stage 2 | **0.6989** | **0.7810** | **+0.0240** |
|
| 121 |
-
|
| 122 |
-
On 200 samples: Stage 2 marginally better than Stage 1 on both metrics.
|
| 123 |
-
|
| 124 |
-
#### Full evaluation (Unseen, 1656 samples)
|
| 125 |
-
|
| 126 |
-
| Method | mIoU | F | Δ mIoU vs Baseline |
|
| 127 |
-
|--------|------|---|---------------------|
|
| 128 |
-
| Baseline | 0.6990 | 0.7924 | — |
|
| 129 |
-
| Best-of-2 ES | 0.7030 | 0.7953 | +0.0040 (+0.57%) |
|
| 130 |
-
| LTPO-ES | 0.7055 | 0.7969 | +0.0065 (+0.93%) |
|
| 131 |
-
| **q-LTPO Stage 1** | **0.7285** | **0.8013** | **+0.0295 (+4.22%)** |
|
| 132 |
-
| q-LTPO Stage 2 | 0.7273 | 0.8002 | +0.0283 (+4.04%) |
|
| 133 |
-
|
| 134 |
-
**Stage 1 beats Stage 2 on full eval** (opposite of 200-sample trend). R_align_det adds noise at scale: in harder Unseen samples, the initial mask quality is lower, making stopgrad z_in/z_out a less reliable target.
|
| 135 |
-
|
| 136 |
-
### Evaluation Status (after e0 fix)
|
| 137 |
-
|
| 138 |
-
| Split | Baseline mIoU/S | q-LTPO S1 (no e0) | q-LTPO S1 (e0) | Status |
|
| 139 |
-
|-------|-----------------|-------------------|----------------|--------|
|
| 140 |
-
| Unseen (1656) | 0.6990 | **0.7285** | — | Done (pre-e0) |
|
| 141 |
-
| Seen (200-sample) | 0.7483 | 0.7618 (+0.0136) | **0.7634 (+0.0151)** | Quick-val done |
|
| 142 |
-
| Null (200-sample, S↓) | 0.0619 | 0.0646 (+4.4%) | **0.0634 (+2.4%)** | Quick-val done |
|
| 143 |
-
| Unseen (200-sample) | 0.6761 | — | **0.6929 (+0.0168)** | Quick-val done |
|
| 144 |
-
| Seen (full) | — | — | — | Pending |
|
| 145 |
-
| Null (full, S↓) | 0.0120 | 0.0126 (+5.0%) | — | Pending e0 run |
|
| 146 |
-
| Unseen (full) | — | — | — | Pending |
|
| 147 |
-
|
| 148 |
-
---
|
| 149 |
-
|
| 150 |
-
## Null Safety Analysis and e0-Modulated Reward
|
| 151 |
-
|
| 152 |
-
### Root Cause: R_iou_pred is a Conditional Quality Metric
|
| 153 |
-
|
| 154 |
-
The original q-LTPO Stage 1 reward:
|
| 155 |
-
```
|
| 156 |
-
R_task = 0.6·R_iou_pred − 0.2·R_area_soft
|
| 157 |
-
```
|
| 158 |
-
|
| 159 |
-
caused Null S metric degradation (+4.4% on 200-sample quick validation, +5.0% on full Null).
|
| 160 |
-
|
| 161 |
-
**Root cause**: `R_iou_pred` is SAM's internal mask quality head — it measures *how good the mask is given that segmentation was performed*, not *whether the target exists*. On Null frames, SAM still outputs `R_iou_pred ≈ 0.73–0.74` because it confidently segments the most prominent region (even if no audio target exists). The optimizer sees positive `R_iou_pred` and expands the mask accordingly.
|
| 162 |
-
|
| 163 |
-
**Why oracle gating approaches fail methodologically:**
|
| 164 |
-
|
| 165 |
-
- **Path A (gate_delta threshold)**: Distribution analysis showed Null reward_gain p50 = +0.0166 ≈ Seen p50 = +0.0181. The two distributions overlap heavily; any threshold that blocks most Null samples also blocks most Seen/Unseen samples.
|
| 166 |
-
- **Path B (area-based reject rule)**: Threshold 0.02 (area fraction) was derived by observing Null mean_area = 0.0094 vs Seen mean_area = 0.054 from the test distribution. This is benchmark-specific tuning = test-set overfitting. **Not a valid method.**
|
| 167 |
-
|
| 168 |
-
Both oracle approaches are useful for diagnostic analysis only. The principled fix must be structural.
|
| 169 |
-
|
| 170 |
-
### Principled Fix: e0-Modulated Reward
|
| 171 |
-
|
| 172 |
-
**Key insight**: decouple *existence* from *quality*. Use the initial mask area as a proxy for the prior probability that a real target exists.
|
| 173 |
-
|
| 174 |
-
```python
|
| 175 |
-
e0 = stopgrad( sigmoid(lrm_init / area_temp).mean() ) # R_area_soft at q_init
|
| 176 |
-
R_task = λ_iou · e0 · R_iou_pred − λ_area · R_area_soft
|
| 177 |
-
```
|
| 178 |
-
|
| 179 |
-
**Why stopgrad on e0 is critical:**
|
| 180 |
-
- Without stopgrad: gradients flow through e0 → optimizer first inflates area to increase e0, then uses the higher e0 to justify larger R_iou reward ("area gaming").
|
| 181 |
-
- With stopgrad: e0 is a fixed scalar from the initialization. Gradients only flow through the explicit terms `R_iou_pred` and `R_area_soft`.
|
| 182 |
-
|
| 183 |
-
**Effect by split:**
|
| 184 |
-
|
| 185 |
-
| Split | mean e0 | Effective λ_iou = 0.6·e0 | Behavior |
|
| 186 |
-
|-------|---------|--------------------------|----------|
|
| 187 |
-
| Null | 0.037 | 0.022 | Area penalty dominates → conservative |
|
| 188 |
-
| Seen | 0.120 | 0.072 | Balanced optimization |
|
| 189 |
-
| Unseen | 0.150 | 0.090 | Full optimization drive |
|
| 190 |
-
|
| 191 |
-
The 3.2× e0 ratio (Unseen/Null) arises naturally from the initial mask size, providing automatic split-specific optimization strength without any threshold tuning.
|
| 192 |
-
|
| 193 |
-
**Implementation fix also addressed (best_q tracking bug):**
|
| 194 |
-
Before fix, `q_{N+1}` (post-step) was evaluated using `lrm/iou` from `q_N` (pre-step), corrupting best_q selection. Fixed by adding a fresh `no_grad` forward after each `optimizer.step()`.
|
| 195 |
-
|
| 196 |
-
### Quick Validation Results (200 samples each, e0 modulation)
|
| 197 |
-
|
| 198 |
-
#### Null split (S metric, lower is better)
|
| 199 |
-
|
| 200 |
-
| Method | S metric | Δ relative |
|
| 201 |
-
|--------|----------|-----------|
|
| 202 |
-
| Baseline | 0.0619 | — |
|
| 203 |
-
| q-LTPO S1 (no e0) | 0.0646 | +4.4% |
|
| 204 |
-
| **q-LTPO S1 (e0)** | **0.0634** | **+2.4%** |
|
| 205 |
-
|
| 206 |
-
Diagnostic stats with e0:
|
| 207 |
-
```
|
| 208 |
-
acceptance rate : 1.000
|
| 209 |
-
mean e0 : 0.0372
|
| 210 |
-
reward_gain p10/50/90: 0.0 / 0.0000 / +0.0123 ← p50=0 means >50% of samples frozen
|
| 211 |
-
mean drift : 0.4962 ← down from ~0.8 without e0
|
| 212 |
-
area (hard) init→best: 0.0094 → 0.0098 ← minimal area expansion
|
| 213 |
-
reward↑ & area+20%↑ : 0.040 ← low Null-safety risk
|
| 214 |
-
```
|
| 215 |
-
|
| 216 |
-
#### Seen split (mIoU, higher is better)
|
| 217 |
-
|
| 218 |
-
| Method | mIoU | F | Δ mIoU |
|
| 219 |
-
|--------|------|---|--------|
|
| 220 |
-
| Baseline | 0.7483 | — | — |
|
| 221 |
-
| q-LTPO S1 (no e0) | 0.7618 | — | +0.0136 |
|
| 222 |
-
| **q-LTPO S1 (e0)** | **0.7634** | — | **+0.0151** |
|
| 223 |
-
|
| 224 |
-
Diagnostic stats with e0:
|
| 225 |
-
```
|
| 226 |
-
mean e0 : 0.1200
|
| 227 |
-
reward_gain p10/50/90: +0.0026 / +0.0181 / +0.0944
|
| 228 |
-
mean drift : 0.5225
|
| 229 |
-
area (hard) init→best: 0.054 → (slight increase)
|
| 230 |
-
```
|
| 231 |
-
|
| 232 |
-
#### Unseen split (mIoU, higher is better)
|
| 233 |
-
|
| 234 |
-
| Method | mIoU | F | Δ mIoU |
|
| 235 |
-
|--------|------|---|--------|
|
| 236 |
-
| Baseline | 0.6761 | 0.7776 | — |
|
| 237 |
-
| **q-LTPO S1 (e0)** | **0.6929** | **0.7765** | **+0.0168** |
|
| 238 |
-
|
| 239 |
-
Diagnostic stats with e0:
|
| 240 |
-
```
|
| 241 |
-
acceptance rate : 1.000
|
| 242 |
-
mean e0 : 0.1506
|
| 243 |
-
reward_gain p10/50/90: +0.0011 / +0.0055 / +0.0293
|
| 244 |
-
mean drift : 0.6666
|
| 245 |
-
R_iou_pred init→best : 0.8029 → 0.8802
|
| 246 |
-
area (hard) init→best: 0.0635 → 0.0650
|
| 247 |
-
reward↑ & area+20%↑ : 0.125
|
| 248 |
-
```
|
| 249 |
-
|
| 250 |
-
### Analysis: e0 is a Pareto Improvement
|
| 251 |
-
|
| 252 |
-
Three conditions for Pareto improvement all satisfied on quick validation:
|
| 253 |
-
|
| 254 |
-
1. **Null safer**: degradation halved (+4.4% → +2.4%). p50 reward_gain = 0.0000, meaning >50% of Null samples produce `best_q ≈ q_init`.
|
| 255 |
-
2. **Seen maintained and slightly improved**: +0.0151 vs +0.0136 without e0.
|
| 256 |
-
3. **Unseen not hurt — gains even larger**: +0.0168 > Seen +0.0151. The "harder positives suppressed" failure mode did not materialize.
|
| 257 |
-
|
| 258 |
-
**e0 hierarchy confirms split-level discriminability:**
|
| 259 |
-
```
|
| 260 |
-
Null (0.037) << Seen (0.120) < Unseen (0.150)
|
| 261 |
-
```
|
| 262 |
-
The ordering is sensible: Null frames have small/empty initial masks → low e0. Unseen e0 slightly exceeds Seen, possibly because the model produces slightly larger (less specific) masks on novel object-sentence combinations.
|
| 263 |
-
|
| 264 |
-
**Residual Null degradation (+2.4%) assessment**: Acceptable for now. The absolute magnitude is +0.0015 in S metric, while Seen/Unseen absolute gains are 10–11× larger. The residual originates from a small tail of Null samples where e0 is still large enough to permit some mask expansion. Further suppression (e.g., e0², sqrt(e0+ε)) risks hurting harder positives and should only be explored after full-set confirmation.
|
| 265 |
-
|
| 266 |
-
---
|
| 267 |
-
|
| 268 |
-
## Summary and Comparison
|
| 269 |
-
|
| 270 |
-
### Pre-e0 (original q-LTPO Stage 1, full Unseen)
|
| 271 |
-
|
| 272 |
-
| Method | Unseen mIoU | Δ vs Baseline | Relative to ES-LTPO |
|
| 273 |
-
|--------|-------------|---------------|----------------------|
|
| 274 |
-
| Baseline | 0.6990 | — | — |
|
| 275 |
-
| ES-LTPO | 0.7055 | +0.0065 | 1× |
|
| 276 |
-
| **q-LTPO Stage 1** | **0.7285** | **+0.0295** | **4.5×** |
|
| 277 |
-
|
| 278 |
-
### e0-Modulated Stage 1 (quick validation, 200 samples)
|
| 279 |
-
|
| 280 |
-
| Split | Baseline | e0-Stage1 | Δ | e0 |
|
| 281 |
-
|-------|----------|-----------|---|-----|
|
| 282 |
-
| Null (S↓) | 0.0619 | 0.0634 | +2.4% (rel) | 0.037 |
|
| 283 |
-
| Seen | 0.7483 | 0.7634 | +0.0151 | 0.120 |
|
| 284 |
-
| Unseen | 0.6761 | 0.6929 | +0.0168 | 0.150 |
|
| 285 |
-
|
| 286 |
-
q-LTPO-autograd with e0 modulation is the current primary method candidate. It achieves first-order gradient-based optimization with automatic Null-safety via the initial-area existence prior, without any test-set-derived thresholds.
|
| 287 |
-
|
| 288 |
-
---
|
| 289 |
-
|
| 290 |
-
## Hyperparameter Configurations
|
| 291 |
-
|
| 292 |
-
### ES-LTPO (Method 1)
|
| 293 |
-
```python
|
| 294 |
-
LTPOConfig(
|
| 295 |
-
T=5, num_anchors=4,
|
| 296 |
-
sigma_schedule=[0.10, 0.08, 0.06, 0.04, 0.02],
|
| 297 |
-
eta_scale=0.5,
|
| 298 |
-
lambda1=0.3, lambda2=0.4, lambda3=1.0, lambda4=0.3,
|
| 299 |
-
beta=0.5, gate_delta=0.0, trust_delta=None,
|
| 300 |
-
)
|
| 301 |
-
```
|
| 302 |
-
|
| 303 |
-
### q-LTPO Stage 1 with e0 (current primary candidate)
|
| 304 |
-
```python
|
| 305 |
-
QLTPOConfig(
|
| 306 |
-
stage=1, T=5, num_anchors=4,
|
| 307 |
-
lr=0.0, # auto: 0.01 × RMS(q_init)
|
| 308 |
-
max_drift=0.0, # auto: 0.5 × ||q_init||
|
| 309 |
-
lambda_iou=0.6, lambda_area=0.2,
|
| 310 |
-
lambda_reg=0.01, area_temp=5.0,
|
| 311 |
-
gate_delta=0.0,
|
| 312 |
-
e0_modulation="identity", # e0 = R_area_soft(q_init), stopgrad
|
| 313 |
-
e0_eps=1e-4,
|
| 314 |
-
# oracle-only fields (disabled, not used in final method):
|
| 315 |
-
null_area_threshold=0.02,
|
| 316 |
-
null_gate_delta=0.0,
|
| 317 |
-
)
|
| 318 |
-
```
|
| 319 |
-
|
| 320 |
-
### Full Unseen Evaluation with e0 (1656 samples)
|
| 321 |
-
|
| 322 |
-
| Method | mIoU | F | Δ mIoU |
|
| 323 |
-
|--------|------|---|--------|
|
| 324 |
-
| Baseline | 0.6990 | 0.7926 | — |
|
| 325 |
-
| q-LTPO S1 (no e0) | 0.7285 | 0.8013 | +0.0295 (+4.22%) |
|
| 326 |
-
| **q-LTPO S1 (e0)** | **0.7240** | **0.7985** | **+0.0250 (+3.56%)** |
|
| 327 |
-
|
| 328 |
-
e0 版本相比 no-e0 版本 mIoU 略低 (-0.0045),但 Null 安全性更好。F 与 mIoU 的提升比例基本一致(约 60%)。
|
| 329 |
-
|
| 330 |
-
**全量评估状态(更新):**
|
| 331 |
-
|
| 332 |
-
| Split | Baseline | q-LTPO S1 (e0) | Δ | Status |
|
| 333 |
-
|-------|----------|----------------|---|--------|
|
| 334 |
-
| Unseen (full, 1656) | 0.6990 / 0.7926 | 0.7240 / 0.7985 | +3.56% mIoU | ✅ Done |
|
| 335 |
-
| Seen (full) | — | — | — | Pending |
|
| 336 |
-
| Null (full, S↓) | 0.0120 | — | — | Pending |
|
| 337 |
-
|
| 338 |
-
---
|
| 339 |
-
|
| 340 |
-
## Direction B: Boundary Precision Experiments(已结束,结论为失败)
|
| 341 |
-
|
| 342 |
-
### B-Step1: Multimask Post-Processing(彻底失败)
|
| 343 |
-
|
| 344 |
-
用 SAM 多 mask 输出(K=3)替换单 mask 解码,分别用 iou_pred 和 Sobel edge score 选最佳候选。
|
| 345 |
-
|
| 346 |
-
| Method | mIoU | F | ΔF vs s1 |
|
| 347 |
-
|--------|------|---|----------|
|
| 348 |
-
| s1 (single mask) | 0.6979 | 0.8024 | — |
|
| 349 |
-
| s1_mm (iou_pred selection) | 0.6979 | 0.7917 | -0.0107 |
|
| 350 |
-
| s1_mm_edge (Sobel selection) | 0.5715 | 0.6820 | -0.1204 |
|
| 351 |
-
|
| 352 |
-
**根本原因:** SAM 内部的单 mask 选择已经最优;外部重选更差。Sobel 在 1024×1024 归一化空间中选到纹理碎片而非语义目标,灾难性失败。
|
| 353 |
-
|
| 354 |
-
### B1: 非对称面积膨胀惩罚(机制性无效)
|
| 355 |
-
|
| 356 |
-
假设:LTPO 导致 mask 向非目标区域膨胀(精度下降),加惩罚项压制。
|
| 357 |
-
|
| 358 |
-
**实验结论:假设错误。** LTPO 期间 soft area 实际在下降(-16%)而非上升:
|
| 359 |
-
|
| 360 |
-
```
|
| 361 |
-
soft area: 0.1507 → 0.1267 (-16%) ← background logits 更负
|
| 362 |
-
hard area: 0.0635 → 0.0650 (+2.4%) ← 实际 mask 区域微增
|
| 363 |
-
```
|
| 364 |
-
|
| 365 |
-
**"mask sharpening" 现象:** Adam 在 R_iou_pred 驱动下使 logit 更双峰化(前景更正、背景更负),soft area 因 93% 背景像素的贡献减少而下降。B1 惩罚的前提条件(soft area 上升)从未发生:
|
| 366 |
-
|
| 367 |
-
```
|
| 368 |
-
B1 activation rate : 0.025 ← 仅 2.5% 样本触发
|
| 369 |
-
B1 mean excess : 0.00002 ← 可忽略
|
| 370 |
-
```
|
| 371 |
-
|
| 372 |
-
**结论:** Direction B 从多 mask 选择到面积约束全部失败,不再追求。F-score 滞后于 mIoU 的根本原因不是 mask 精度,而是 reward 代理信号质量问题(见 Path A)。
|
| 373 |
-
|
| 374 |
-
---
|
| 375 |
-
|
| 376 |
-
## Direction II: Frame-Adaptive Token Optimization(初步探索,待后续)
|
| 377 |
-
|
| 378 |
-
### 方法设计
|
| 379 |
-
|
| 380 |
-
将单一共享 token q 扩展为视频 token 轨迹:
|
| 381 |
-
|
| 382 |
-
```
|
| 383 |
-
q_t = q_global + delta_t
|
| 384 |
-
```
|
| 385 |
-
|
| 386 |
-
其中 q_global 是全局共享 token,delta_t 是每个 anchor 帧的局部残差,初始化为 0。联合优化:
|
| 387 |
-
|
| 388 |
-
```
|
| 389 |
-
max Σ_t [λ_iou · e0_t · R_iou(q_t) - λ_area · R_area(q_t)]
|
| 390 |
-
- λ_residual · ||delta||² - λ_smooth · Σ_t ||delta_t - delta_{t+1}||² - λ_reg · ||q_global - q_init||²
|
| 391 |
-
```
|
| 392 |
-
|
| 393 |
-
每个 anchor 帧使用各自的 e0_t(per-frame 存在先验)。delta_t 受 hard clip 约束:`||delta_t|| ≤ scale × ||q_init||`。
|
| 394 |
-
|
| 395 |
-
### 200-sample Probe Results(Unseen split)
|
| 396 |
-
|
| 397 |
-
| Method | mIoU | F | reward gain p50 | delta ‖Δ‖ |
|
| 398 |
-
|--------|------|---|-----------------|-----------|
|
| 399 |
-
| baseline | 0.6745 | 0.7763 | — | — |
|
| 400 |
-
| s1 | 0.6945 | 0.7773 | +0.0053 | — |
|
| 401 |
-
| fa_base (无约束) | 0.6945 | 0.7711 | +0.0112 | 1.675 |
|
| 402 |
-
| fa_smooth (λ_smooth=0.01) | 0.6960 | 0.7731 | +0.0104 | 1.488 |
|
| 403 |
-
| fa_c03 (delta clip 0.3×) | 0.6959 | 0.7722 | +0.0112 | — |
|
| 404 |
-
|
| 405 |
-
### 关键发现
|
| 406 |
-
|
| 407 |
-
**Reward-metric gap(核心问题):**
|
| 408 |
-
```
|
| 409 |
-
reward gain p50: s1 = +0.0053 fa_c03 = +0.0112 (fa 高 2.1×)
|
| 410 |
-
R_iou_pred 提升: s1 +0.077 fa_c03 +0.114
|
| 411 |
-
实际 mIoU 提升: s1 +2.96% fa_c03 +3.17% (仅差 0.21%)
|
| 412 |
-
```
|
| 413 |
-
fa 拿到了多得多的 reward,但 mIoU 几乎没有额外提升,F 还略降。
|
| 414 |
-
|
| 415 |
-
**结论:** 瓶颈不是优化结构,而是 R_iou_pred 本身的任务相关性不足。R_iou_pred 衡量"mask 有多干净",不衡量"mask 是否包含正确的音频目标"。所有架构变体(单 token / frame-adaptive)都受同一个天花板限制。
|
| 416 |
-
|
| 417 |
-
Direction II 不在旧 reward 下继续调参,等 Path A(新 reward)有正向信号后再考虑是否重新引入。
|
| 418 |
-
|
| 419 |
-
---
|
| 420 |
-
|
| 421 |
-
## Path A: AVT-Aware Reward 重设计
|
| 422 |
-
|
| 423 |
-
### 动机
|
| 424 |
-
|
| 425 |
-
Ref-AVS 中的 referent 不一定是发声体本身(可能是拿着发声物体的人、与声源相关的对象)。纯音频对齐 reward 会将优化推向 sound source 而非 text 指向的 referent。需要 audio + text + global visual context 共同定义的 referent consistency。
|
| 426 |
-
|
| 427 |
-
### AVT Proxy Reward 设计
|
| 428 |
-
|
| 429 |
-
**核心洞察:** Fseg(= q_init)已经是 audio + video + text 的多模态融合 token,可直接作为 frozen AVT teacher。
|
| 430 |
-
|
| 431 |
-
```python
|
| 432 |
-
R_avt = mean_t cos(z_in_t, q_init)
|
| 433 |
-
R_avt_c = mean_t [cos(z_in_t, q_init) - β · cos(z_out_t, q_init)]
|
| 434 |
-
```
|
| 435 |
-
|
| 436 |
-
- `z_in_t`:anchor 帧 t 的 soft-masked 图像特征(SAM 256-dim 空间)
|
| 437 |
-
- `q_init`:frozen Fseg(AVT anchor,不参与优化梯度)
|
| 438 |
-
- R_avt 高 → mask 区域与查询 referent 对齐;R_avt 低 → mask 指向错误目标
|
| 439 |
-
|
| 440 |
-
与 Stage 2 的区别:Stage 2 用当前 q(移动)对齐 z_in(当前 mask),导致自我确认偏差;R_avt 用 q_init(固定)作为 teacher,打破偏差。
|
| 441 |
-
|
| 442 |
-
### Step A0: Reward–Metric Correlation Study(下一步要做)
|
| 443 |
-
|
| 444 |
-
**目的:** 在进入 full optimization 之前,先用数据验证新 reward 是否比 R_iou_pred 更能预测真实 metric 变化。
|
| 445 |
-
|
| 446 |
-
**实验设置(200 samples, Unseen split):**
|
| 447 |
-
对每个(视频,segment)样本:
|
| 448 |
-
1. Baseline decode → IoU_base, F_base
|
| 449 |
-
2. q-LTPO s1 → q_best;记录 reward_gain、r_avt_gain、r_avt_c_gain(均在 q_ltpo_autograd 内计算)
|
| 450 |
-
3. LTPO decode → IoU_ltpo, F_ltpo
|
| 451 |
-
4. Δ = LTPO - baseline
|
| 452 |
-
|
| 453 |
-
输出 Pearson 相关表:
|
| 454 |
-
|
| 455 |
-
```
|
| 456 |
-
Pearson r with ΔmIoU:
|
| 457 |
-
R_iou_pred_gain : +0.xxx ← 当前 proxy
|
| 458 |
-
R_avt_gain : +0.xxx ← cos(z_in, q_init)
|
| 459 |
-
R_avt_c_gain : +0.xxx ← 对比版本
|
| 460 |
-
|
| 461 |
-
Wrong direction (gain>0 但 Δ<0):
|
| 462 |
-
R_iou / ΔmIoU : 0.xxx
|
| 463 |
-
R_avt / ΔmIoU : 0.xxx
|
| 464 |
-
```
|
| 465 |
-
|
| 466 |
-
**运行命令:**
|
| 467 |
-
```bash
|
| 468 |
-
python load_model.py --eval_split test_u --max_eval_rows 200
|
| 469 |
-
```
|
| 470 |
-
|
| 471 |
-
**判断标准:**
|
| 472 |
-
- `r(R_avt, ΔmIoU) > r(R_iou, ΔmIoU)` → AVT proxy 更好,进入 Step A1
|
| 473 |
-
- 两者相近 → reward 本身不是瓶颈,需要重新审视
|
| 474 |
-
- `R_avt / ΔF wrong frac` 明显低于 `R_iou / ΔF` → AVT 能解释 F-score 不跟随 mIoU 的现象
|
| 475 |
-
|
| 476 |
-
### Step A1: Hybrid Reward(Step A0 验证后)
|
| 477 |
-
|
| 478 |
-
```
|
| 479 |
-
R_task = λ1 · e0 · R_iou_pred + λ2 · R_avt_c - λ3 · R_area_soft
|
| 480 |
-
```
|
| 481 |
-
|
| 482 |
-
- R_iou_pred 继续负责 mask quality(shape quality signal)
|
| 483 |
-
- R_avt_c 负责 referent correctness(task-specific signal)
|
| 484 |
-
- 两者结合才有可能同时维持 IoU 并提升 F
|
| 485 |
-
|
| 486 |
-
候选权重组合:`λ1=0.6, λ2=0.5, λ3=0.2`(AVT 作为辅助项,不完全取代 R_iou)。
|
| 487 |
-
|
| 488 |
-
如果 Step A1 有正向信号,再考虑将 Direction II(frame-adaptive)和新 reward 结合。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
analyze_d2_csv.py
DELETED
|
@@ -1,239 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import csv
|
| 3 |
-
import math
|
| 4 |
-
from collections import defaultdict
|
| 5 |
-
|
| 6 |
-
import numpy as np
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def parse_args():
|
| 10 |
-
parser = argparse.ArgumentParser(description="Analyze D2 frame-level CSV.")
|
| 11 |
-
parser.add_argument("--csv", required=True, help="Path to d2_llm_space.py or d2_basic.py CSV output.")
|
| 12 |
-
parser.add_argument("--beta", type=float, default=1.0)
|
| 13 |
-
parser.add_argument("--failure_iou", type=float, default=0.5)
|
| 14 |
-
parser.add_argument("--bottom_frac", type=float, default=0.2)
|
| 15 |
-
parser.add_argument("--pr_points", type=int, default=10)
|
| 16 |
-
return parser.parse_args()
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def read_rows(path, beta):
|
| 20 |
-
rows = []
|
| 21 |
-
with open(path, newline="") as f:
|
| 22 |
-
reader = csv.DictReader(f)
|
| 23 |
-
for row in reader:
|
| 24 |
-
row_beta = float(row["beta"])
|
| 25 |
-
if abs(row_beta - beta) > 1e-8:
|
| 26 |
-
continue
|
| 27 |
-
q_col = "h_type" if "h_type" in row else "q_type"
|
| 28 |
-
rows.append(
|
| 29 |
-
{
|
| 30 |
-
"sample_idx": int(row["sample_idx"]),
|
| 31 |
-
"frame": int(row["frame"]),
|
| 32 |
-
"anchor_type": row[q_col],
|
| 33 |
-
"s_pred": float(row["s_pred"]),
|
| 34 |
-
"s_gt": float(row["s_gt"]),
|
| 35 |
-
"frame_iou": float(row["frame_iou"]),
|
| 36 |
-
"iou_pred": float(row["iou_pred"]),
|
| 37 |
-
"pred_area": float(row["pred_area"]),
|
| 38 |
-
"gt_area": float(row["gt_area"]),
|
| 39 |
-
}
|
| 40 |
-
)
|
| 41 |
-
if not rows:
|
| 42 |
-
raise RuntimeError(f"No rows found for beta={beta} in {path}")
|
| 43 |
-
return rows
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def corr(x, y):
|
| 47 |
-
x = np.asarray(x, dtype=np.float64)
|
| 48 |
-
y = np.asarray(y, dtype=np.float64)
|
| 49 |
-
if len(x) < 2 or np.std(x) < 1e-12 or np.std(y) < 1e-12:
|
| 50 |
-
return float("nan")
|
| 51 |
-
return float(np.corrcoef(x, y)[0, 1])
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def residualize(y, controls):
|
| 55 |
-
y = np.asarray(y, dtype=np.float64)
|
| 56 |
-
cols = [np.ones(len(y), dtype=np.float64)]
|
| 57 |
-
for control in controls:
|
| 58 |
-
cols.append(np.asarray(control, dtype=np.float64))
|
| 59 |
-
x = np.stack(cols, axis=1)
|
| 60 |
-
coef, *_ = np.linalg.lstsq(x, y, rcond=None)
|
| 61 |
-
return y - x @ coef
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
def r2_score(y, y_pred):
|
| 65 |
-
y = np.asarray(y, dtype=np.float64)
|
| 66 |
-
y_pred = np.asarray(y_pred, dtype=np.float64)
|
| 67 |
-
ss_res = np.sum((y - y_pred) ** 2)
|
| 68 |
-
ss_tot = np.sum((y - y.mean()) ** 2)
|
| 69 |
-
if ss_tot < 1e-12:
|
| 70 |
-
return float("nan")
|
| 71 |
-
return float(1.0 - ss_res / ss_tot)
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def linear_r2(y, features):
|
| 75 |
-
y = np.asarray(y, dtype=np.float64)
|
| 76 |
-
cols = [np.ones(len(y), dtype=np.float64)]
|
| 77 |
-
for feature in features:
|
| 78 |
-
cols.append(np.asarray(feature, dtype=np.float64))
|
| 79 |
-
x = np.stack(cols, axis=1)
|
| 80 |
-
coef, *_ = np.linalg.lstsq(x, y, rcond=None)
|
| 81 |
-
return r2_score(y, x @ coef)
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
def real_rows(rows):
|
| 85 |
-
return [r for r in rows if r["anchor_type"] == "real"]
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def bottom_failure_enrichment(rows, failure_iou, bottom_frac):
|
| 89 |
-
rr = real_rows(rows)
|
| 90 |
-
n = len(rr)
|
| 91 |
-
k = max(1, int(round(n * bottom_frac)))
|
| 92 |
-
sorted_rows = sorted(rr, key=lambda r: r["s_pred"])
|
| 93 |
-
bottom = sorted_rows[:k]
|
| 94 |
-
baseline_rate = np.mean([r["frame_iou"] < failure_iou for r in rr])
|
| 95 |
-
bottom_rate = np.mean([r["frame_iou"] < failure_iou for r in bottom])
|
| 96 |
-
total_failures = sum(r["frame_iou"] < failure_iou for r in rr)
|
| 97 |
-
covered_failures = sum(r["frame_iou"] < failure_iou for r in bottom)
|
| 98 |
-
recall = covered_failures / max(total_failures, 1)
|
| 99 |
-
enrichment = bottom_rate / max(baseline_rate, 1e-12)
|
| 100 |
-
return {
|
| 101 |
-
"n": n,
|
| 102 |
-
"k": k,
|
| 103 |
-
"baseline_failure_rate": baseline_rate,
|
| 104 |
-
"bottom_failure_rate": bottom_rate,
|
| 105 |
-
"bottom_failure_recall": recall,
|
| 106 |
-
"enrichment": enrichment,
|
| 107 |
-
"total_failures": total_failures,
|
| 108 |
-
}
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
def pr_curve(rows, failure_iou, points):
|
| 112 |
-
rr = sorted(real_rows(rows), key=lambda r: r["s_pred"])
|
| 113 |
-
total_failures = sum(r["frame_iou"] < failure_iou for r in rr)
|
| 114 |
-
out = []
|
| 115 |
-
for frac in np.linspace(0.05, 1.0, points):
|
| 116 |
-
k = max(1, int(round(len(rr) * frac)))
|
| 117 |
-
selected = rr[:k]
|
| 118 |
-
failures = sum(r["frame_iou"] < failure_iou for r in selected)
|
| 119 |
-
precision = failures / k
|
| 120 |
-
recall = failures / max(total_failures, 1)
|
| 121 |
-
out.append((frac, precision, recall))
|
| 122 |
-
return out
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
def margin_rows(rows):
|
| 126 |
-
grouped = defaultdict(dict)
|
| 127 |
-
for r in rows:
|
| 128 |
-
key = (r["sample_idx"], r["frame"])
|
| 129 |
-
grouped[key][r["anchor_type"]] = r
|
| 130 |
-
|
| 131 |
-
out = []
|
| 132 |
-
for key, group in grouped.items():
|
| 133 |
-
if "real" not in group:
|
| 134 |
-
continue
|
| 135 |
-
controls = [group[name]["s_pred"] for name in ("shuffled", "wrong_ref") if name in group]
|
| 136 |
-
if not controls:
|
| 137 |
-
continue
|
| 138 |
-
real = group["real"]
|
| 139 |
-
item = dict(real)
|
| 140 |
-
item["s_margin"] = real["s_pred"] - max(controls)
|
| 141 |
-
out.append(item)
|
| 142 |
-
return out
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
def bottom_failure_enrichment_for_score(rows, score_key, failure_iou, bottom_frac):
|
| 146 |
-
n = len(rows)
|
| 147 |
-
k = max(1, int(round(n * bottom_frac)))
|
| 148 |
-
sorted_rows = sorted(rows, key=lambda r: r[score_key])
|
| 149 |
-
bottom = sorted_rows[:k]
|
| 150 |
-
baseline_rate = np.mean([r["frame_iou"] < failure_iou for r in rows])
|
| 151 |
-
bottom_rate = np.mean([r["frame_iou"] < failure_iou for r in bottom])
|
| 152 |
-
total_failures = sum(r["frame_iou"] < failure_iou for r in rows)
|
| 153 |
-
covered_failures = sum(r["frame_iou"] < failure_iou for r in bottom)
|
| 154 |
-
return {
|
| 155 |
-
"n": n,
|
| 156 |
-
"k": k,
|
| 157 |
-
"baseline_failure_rate": baseline_rate,
|
| 158 |
-
"bottom_failure_rate": bottom_rate,
|
| 159 |
-
"bottom_failure_recall": covered_failures / max(total_failures, 1),
|
| 160 |
-
"enrichment": bottom_rate / max(baseline_rate, 1e-12),
|
| 161 |
-
}
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
def main():
|
| 165 |
-
args = parse_args()
|
| 166 |
-
rows = read_rows(args.csv, args.beta)
|
| 167 |
-
rr = real_rows(rows)
|
| 168 |
-
|
| 169 |
-
print(f"CSV: {args.csv}")
|
| 170 |
-
print(f"beta: {args.beta}")
|
| 171 |
-
print(f"real frames: {len(rr)}")
|
| 172 |
-
print(f"failure definition: frame_iou < {args.failure_iou}")
|
| 173 |
-
|
| 174 |
-
print("\nReal s_pred Correlations")
|
| 175 |
-
print(f"corr(s_pred, frame_iou): {corr([r['s_pred'] for r in rr], [r['frame_iou'] for r in rr]):+.4f}")
|
| 176 |
-
print(f"corr(s_pred, iou_pred): {corr([r['s_pred'] for r in rr], [r['iou_pred'] for r in rr]):+.4f}")
|
| 177 |
-
print(f"corr(s_pred, pred_area): {corr([r['s_pred'] for r in rr], [r['pred_area'] for r in rr]):+.4f}")
|
| 178 |
-
|
| 179 |
-
s_pred_values = [r["s_pred"] for r in rr]
|
| 180 |
-
frame_iou_values = [r["frame_iou"] for r in rr]
|
| 181 |
-
iou_pred_values = [r["iou_pred"] for r in rr]
|
| 182 |
-
pred_area_values = [r["pred_area"] for r in rr]
|
| 183 |
-
gt_area_values = [r["gt_area"] for r in rr]
|
| 184 |
-
partial_iou_pred = corr(
|
| 185 |
-
residualize(s_pred_values, [iou_pred_values]),
|
| 186 |
-
residualize(frame_iou_values, [iou_pred_values]),
|
| 187 |
-
)
|
| 188 |
-
partial_iou_area = corr(
|
| 189 |
-
residualize(s_pred_values, [iou_pred_values, pred_area_values]),
|
| 190 |
-
residualize(frame_iou_values, [iou_pred_values, pred_area_values]),
|
| 191 |
-
)
|
| 192 |
-
partial_iou_area_gt = corr(
|
| 193 |
-
residualize(s_pred_values, [iou_pred_values, pred_area_values, gt_area_values]),
|
| 194 |
-
residualize(frame_iou_values, [iou_pred_values, pred_area_values, gt_area_values]),
|
| 195 |
-
)
|
| 196 |
-
r2_iou_pred = linear_r2(frame_iou_values, [iou_pred_values])
|
| 197 |
-
r2_iou_pred_s = linear_r2(frame_iou_values, [iou_pred_values, s_pred_values])
|
| 198 |
-
r2_iou_pred_area = linear_r2(frame_iou_values, [iou_pred_values, pred_area_values])
|
| 199 |
-
r2_iou_pred_area_s = linear_r2(frame_iou_values, [iou_pred_values, pred_area_values, s_pred_values])
|
| 200 |
-
|
| 201 |
-
print("\nPartial Correlation / Residual Gain")
|
| 202 |
-
print(f"partial corr(s_pred, frame_iou | iou_pred): {partial_iou_pred:+.4f}")
|
| 203 |
-
print(f"partial corr(s_pred, frame_iou | iou_pred,pred_area): {partial_iou_area:+.4f}")
|
| 204 |
-
print(f"partial corr(s_pred, frame_iou | iou_pred,pred_area,gt_area): {partial_iou_area_gt:+.4f}")
|
| 205 |
-
print(f"R2 frame_iou ~ iou_pred: {r2_iou_pred:.4f}")
|
| 206 |
-
print(f"R2 frame_iou ~ iou_pred + s_pred: {r2_iou_pred_s:.4f} (gain {r2_iou_pred_s - r2_iou_pred:+.4f})")
|
| 207 |
-
print(f"R2 frame_iou ~ iou_pred + pred_area: {r2_iou_pred_area:.4f}")
|
| 208 |
-
print(f"R2 frame_iou ~ iou_pred + pred_area + s_pred: {r2_iou_pred_area_s:.4f} (gain {r2_iou_pred_area_s - r2_iou_pred_area:+.4f})")
|
| 209 |
-
|
| 210 |
-
stats = bottom_failure_enrichment(rows, args.failure_iou, args.bottom_frac)
|
| 211 |
-
print("\nBottom-k Failure Enrichment")
|
| 212 |
-
print(f"bottom_frac: {args.bottom_frac:.2f} ({stats['k']}/{stats['n']} frames)")
|
| 213 |
-
print(f"total failures: {stats['total_failures']}")
|
| 214 |
-
print(f"random/baseline failure rate: {stats['baseline_failure_rate']:.4f}")
|
| 215 |
-
print(f"bottom-s_pred failure rate: {stats['bottom_failure_rate']:.4f}")
|
| 216 |
-
print(f"bottom-s_pred failure recall: {stats['bottom_failure_recall']:.4f}")
|
| 217 |
-
print(f"enrichment: {stats['enrichment']:.2f}x")
|
| 218 |
-
|
| 219 |
-
print("\nPR Curve Summary")
|
| 220 |
-
print("selected_frac | precision | recall")
|
| 221 |
-
for frac, precision, recall in pr_curve(rows, args.failure_iou, args.pr_points):
|
| 222 |
-
print(f"{frac:.2f} | {precision:.4f} | {recall:.4f}")
|
| 223 |
-
|
| 224 |
-
mr = margin_rows(rows)
|
| 225 |
-
if mr:
|
| 226 |
-
print("\nOffline Margin-D2")
|
| 227 |
-
print(f"margin frames: {len(mr)}")
|
| 228 |
-
print(f"corr(s_margin, frame_iou): {corr([r['s_margin'] for r in mr], [r['frame_iou'] for r in mr]):+.4f}")
|
| 229 |
-
print(f"corr(s_margin, pred_area): {corr([r['s_margin'] for r in mr], [r['pred_area'] for r in mr]):+.4f}")
|
| 230 |
-
mstats = bottom_failure_enrichment_for_score(mr, "s_margin", args.failure_iou, args.bottom_frac)
|
| 231 |
-
print(f"bottom-s_margin failure rate: {mstats['bottom_failure_rate']:.4f}")
|
| 232 |
-
print(f"bottom-s_margin failure recall: {mstats['bottom_failure_recall']:.4f}")
|
| 233 |
-
print(f"margin enrichment: {mstats['enrichment']:.2f}x")
|
| 234 |
-
else:
|
| 235 |
-
print("\nOffline Margin-D2 skipped: shuffled/wrong_ref controls not available.")
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
if __name__ == "__main__":
|
| 239 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build_rpb_dev_manifest.py
DELETED
|
@@ -1,71 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import json
|
| 3 |
-
import os
|
| 4 |
-
import random
|
| 5 |
-
|
| 6 |
-
import pandas as pd
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def sample_indices(size, count, seed):
|
| 10 |
-
if count <= 0:
|
| 11 |
-
return []
|
| 12 |
-
if count > size:
|
| 13 |
-
raise ValueError(f"Requested {count} samples from a split of size {size}")
|
| 14 |
-
rng = random.Random(seed)
|
| 15 |
-
indices = list(range(size))
|
| 16 |
-
rng.shuffle(indices)
|
| 17 |
-
selected = sorted(indices[:count])
|
| 18 |
-
return selected
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def main():
|
| 22 |
-
parser = argparse.ArgumentParser(description="Build a fixed subset manifest for RPB dev experiments.")
|
| 23 |
-
parser.add_argument("--metadata", type=str, default="/workspace/SimToken/data/metadata.csv")
|
| 24 |
-
parser.add_argument("--output", type=str, required=True)
|
| 25 |
-
parser.add_argument("--seed", type=int, default=42)
|
| 26 |
-
parser.add_argument("--train_rows", type=int, default=0)
|
| 27 |
-
parser.add_argument("--test_s_rows", type=int, default=200)
|
| 28 |
-
parser.add_argument("--test_u_rows", type=int, default=200)
|
| 29 |
-
parser.add_argument("--test_n_rows", type=int, default=200)
|
| 30 |
-
args = parser.parse_args()
|
| 31 |
-
|
| 32 |
-
metadata = pd.read_csv(args.metadata, header=0)
|
| 33 |
-
split_sizes = {
|
| 34 |
-
"train": int((metadata["split"] == "train").sum()),
|
| 35 |
-
"test_s": int((metadata["split"] == "test_s").sum()),
|
| 36 |
-
"test_u": int((metadata["split"] == "test_u").sum()),
|
| 37 |
-
"test_n": int((metadata["split"] == "test_n").sum()),
|
| 38 |
-
}
|
| 39 |
-
|
| 40 |
-
manifest = {
|
| 41 |
-
"train": sample_indices(split_sizes["train"], args.train_rows, args.seed),
|
| 42 |
-
"test_s": sample_indices(split_sizes["test_s"], args.test_s_rows, args.seed + 1),
|
| 43 |
-
"test_u": sample_indices(split_sizes["test_u"], args.test_u_rows, args.seed + 2),
|
| 44 |
-
"test_n": sample_indices(split_sizes["test_n"], args.test_n_rows, args.seed + 3),
|
| 45 |
-
}
|
| 46 |
-
|
| 47 |
-
# Remove empty entries so train.py only subsets the splits we intentionally fix.
|
| 48 |
-
manifest = {key: value for key, value in manifest.items() if value}
|
| 49 |
-
|
| 50 |
-
os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
|
| 51 |
-
with open(args.output, "w", encoding="utf-8") as f:
|
| 52 |
-
json.dump(
|
| 53 |
-
{
|
| 54 |
-
"metadata": {
|
| 55 |
-
"seed": args.seed,
|
| 56 |
-
"split_sizes": split_sizes,
|
| 57 |
-
"source_metadata": os.path.abspath(args.metadata),
|
| 58 |
-
},
|
| 59 |
-
"subsets": manifest,
|
| 60 |
-
},
|
| 61 |
-
f,
|
| 62 |
-
indent=2,
|
| 63 |
-
)
|
| 64 |
-
|
| 65 |
-
print(f"saved subset manifest to {args.output}")
|
| 66 |
-
for split_name, indices in manifest.items():
|
| 67 |
-
print(f"{split_name}: {len(indices)} samples")
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
if __name__ == "__main__":
|
| 71 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cache_q_features.py
DELETED
|
@@ -1,125 +0,0 @@
|
|
| 1 |
-
import json
|
| 2 |
-
import os
|
| 3 |
-
from functools import partial
|
| 4 |
-
from itertools import islice
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import transformers
|
| 8 |
-
from torch.utils.data import DataLoader
|
| 9 |
-
from tqdm import tqdm
|
| 10 |
-
|
| 11 |
-
from configs import args
|
| 12 |
-
from datasets import REFAVS
|
| 13 |
-
from decoder_invariance_check import build_model, set_seed
|
| 14 |
-
from load_model import collate_fn, dict_to_cuda
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def _jsonable_size(size):
|
| 18 |
-
if isinstance(size, torch.Tensor):
|
| 19 |
-
return [int(x) for x in size.detach().cpu().tolist()]
|
| 20 |
-
return [int(x) for x in size]
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def main():
|
| 24 |
-
set_seed(42)
|
| 25 |
-
torch.set_grad_enabled(False)
|
| 26 |
-
|
| 27 |
-
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 28 |
-
args.mllm,
|
| 29 |
-
cache_dir=None,
|
| 30 |
-
model_max_length=2048,
|
| 31 |
-
padding_side="right",
|
| 32 |
-
use_fast=False,
|
| 33 |
-
)
|
| 34 |
-
tokenizer.pad_token = tokenizer.unk_token
|
| 35 |
-
tokenizer.add_tokens("[SEG]")
|
| 36 |
-
seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
|
| 37 |
-
|
| 38 |
-
dataset = REFAVS(args.cache_split, args, tokenizer, input_type="refer")
|
| 39 |
-
loader = DataLoader(
|
| 40 |
-
dataset,
|
| 41 |
-
batch_size=1,
|
| 42 |
-
shuffle=False,
|
| 43 |
-
num_workers=0,
|
| 44 |
-
collate_fn=partial(collate_fn, tokenizer=tokenizer),
|
| 45 |
-
)
|
| 46 |
-
|
| 47 |
-
split_root = os.path.join(args.cache_root, args.cache_split)
|
| 48 |
-
os.makedirs(split_root, exist_ok=True)
|
| 49 |
-
index_path = os.path.join(split_root, "index.jsonl")
|
| 50 |
-
if os.path.exists(index_path) and not args.overwrite_cache:
|
| 51 |
-
raise FileExistsError(
|
| 52 |
-
f"{index_path} already exists. Pass --overwrite_cache to rebuild it."
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
limit = args.max_eval_rows if args.max_eval_rows > 0 else len(dataset)
|
| 56 |
-
print(f"cache split={args.cache_split} | samples={min(limit, len(dataset))}")
|
| 57 |
-
print(f"cache root: {split_root}")
|
| 58 |
-
|
| 59 |
-
model = build_model(tokenizer, seg_token_idx)
|
| 60 |
-
model.eval()
|
| 61 |
-
|
| 62 |
-
rows = []
|
| 63 |
-
for sample_idx, batch in enumerate(
|
| 64 |
-
tqdm(islice(loader, limit), total=min(limit, len(dataset)), desc=f"Caching {args.cache_split}")
|
| 65 |
-
):
|
| 66 |
-
batch = dict_to_cuda(batch)
|
| 67 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 68 |
-
output = model.forward(
|
| 69 |
-
images=batch["images"],
|
| 70 |
-
images_clip=batch["images_clip"],
|
| 71 |
-
audio_features=batch["audio_feats"],
|
| 72 |
-
image_features=batch["image_feats"],
|
| 73 |
-
input_ids=batch["input_ids"],
|
| 74 |
-
labels=batch["labels"],
|
| 75 |
-
attention_masks=batch["attention_masks"],
|
| 76 |
-
masks_list=batch["masks"],
|
| 77 |
-
resize_list=batch["resizes"],
|
| 78 |
-
orgsize_list=batch["orgsizes"],
|
| 79 |
-
conversation_list=batch["convs"],
|
| 80 |
-
refs_num=batch["refs_num"],
|
| 81 |
-
fids=batch["fids"],
|
| 82 |
-
vids=batch["vids"],
|
| 83 |
-
contrast=args.ct_weight,
|
| 84 |
-
ref_ids=batch["ref_ids"],
|
| 85 |
-
inference=True,
|
| 86 |
-
)
|
| 87 |
-
|
| 88 |
-
cache_name = f"{sample_idx:06d}.pt"
|
| 89 |
-
cache_path = os.path.join(split_root, cache_name)
|
| 90 |
-
item = {
|
| 91 |
-
"sample_idx": sample_idx,
|
| 92 |
-
"vid": batch["vids"][0],
|
| 93 |
-
"refs": batch["refs"][0],
|
| 94 |
-
"fids": [int(x) for x in batch["fids"][0]],
|
| 95 |
-
"resize": _jsonable_size(batch["resizes"][0]),
|
| 96 |
-
"orgsize": _jsonable_size(batch["orgsizes"][0]),
|
| 97 |
-
"q": output["seg_embeddings"][0].detach().cpu().float(),
|
| 98 |
-
}
|
| 99 |
-
torch.save(item, cache_path)
|
| 100 |
-
rows.append(
|
| 101 |
-
{
|
| 102 |
-
"sample_idx": sample_idx,
|
| 103 |
-
"path": cache_name,
|
| 104 |
-
"vid": item["vid"],
|
| 105 |
-
"refs": item["refs"],
|
| 106 |
-
"fids": item["fids"],
|
| 107 |
-
"resize": item["resize"],
|
| 108 |
-
"orgsize": item["orgsize"],
|
| 109 |
-
"num_seg": int(item["q"].shape[0]),
|
| 110 |
-
}
|
| 111 |
-
)
|
| 112 |
-
|
| 113 |
-
if not rows:
|
| 114 |
-
raise RuntimeError("No samples were cached.")
|
| 115 |
-
|
| 116 |
-
with open(index_path, "w") as f:
|
| 117 |
-
for row in rows:
|
| 118 |
-
f.write(json.dumps(row) + "\n")
|
| 119 |
-
|
| 120 |
-
print(f"cached samples: {len(rows)}")
|
| 121 |
-
print(f"saved index: {index_path}")
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
if __name__ == "__main__":
|
| 125 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cache_q_smoke/test_s/000000.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:5f85d7cf7b83caf6fedb153a2cea2b36dd144ee3c0e34039483e20d208ea92d3
|
| 3 |
-
size 2327
|
|
|
|
|
|
|
|
|
|
|
|
cache_q_smoke/test_s/index.jsonl
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{"sample_idx": 0, "path": "000000.pt", "vid": "-3ABOVeVmpU_136000_146000", "refs": ["the object that keeps making sound at all times"], "fids": [1], "resize": [576, 1024], "orgsize": [720, 1280], "num_seg": 1}
|
|
|
|
|
|
checkpoints/rpb_dev_mixed_pm_only_a018_wm005.pth
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:2c1facc9eac5ffdfd12c97d252af2c8eedc4e526a53931d301b0ef4bed698213
|
| 3 |
-
size 30841132766
|
|
|
|
|
|
|
|
|
|
|
|
checkpoints/rpb_dev_pm_only_a018.pth
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:33e8b6251c69d7d4de055b488a2f2345eece1991831a2f08ce5f1d1cb795ae5f
|
| 3 |
-
size 30841115170
|
|
|
|
|
|
|
|
|
|
|
|
checkpoints/rpb_probe_eval_directional_pm_only_a02.pth
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:6dc5cd6b02f5d54a026694f6a1217f46137fdaa4499a71fa7b9bd95ede17da6c
|
| 3 |
-
size 30841141852
|
|
|
|
|
|
|
|
|
|
|
|
d2_basic.py
DELETED
|
@@ -1,340 +0,0 @@
|
|
| 1 |
-
import csv
|
| 2 |
-
import math
|
| 3 |
-
import os
|
| 4 |
-
from functools import partial
|
| 5 |
-
|
| 6 |
-
import numpy as np
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
import transformers
|
| 10 |
-
from torch.utils.data import DataLoader
|
| 11 |
-
|
| 12 |
-
from configs import args
|
| 13 |
-
from datasets import REFAVS
|
| 14 |
-
from decoder_invariance_check import build_model, set_seed
|
| 15 |
-
from load_model import collate_fn, dict_to_cuda
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def make_loader(tokenizer):
|
| 19 |
-
dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
|
| 20 |
-
return DataLoader(
|
| 21 |
-
dataset,
|
| 22 |
-
batch_size=1,
|
| 23 |
-
shuffle=False,
|
| 24 |
-
num_workers=0,
|
| 25 |
-
collate_fn=partial(collate_fn, tokenizer=tokenizer),
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def build_tokenizer():
|
| 30 |
-
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 31 |
-
args.mllm,
|
| 32 |
-
cache_dir=None,
|
| 33 |
-
model_max_length=2048,
|
| 34 |
-
padding_side="right",
|
| 35 |
-
use_fast=False,
|
| 36 |
-
)
|
| 37 |
-
tokenizer.pad_token = tokenizer.unk_token
|
| 38 |
-
tokenizer.add_tokens("[SEG]")
|
| 39 |
-
seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
|
| 40 |
-
return tokenizer, seg_token_idx
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
def get_q(model, batch):
|
| 44 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 45 |
-
output = model.forward(
|
| 46 |
-
images=batch["images"],
|
| 47 |
-
images_clip=batch["images_clip"],
|
| 48 |
-
audio_features=batch["audio_feats"],
|
| 49 |
-
image_features=batch["image_feats"],
|
| 50 |
-
input_ids=batch["input_ids"],
|
| 51 |
-
labels=batch["labels"],
|
| 52 |
-
attention_masks=batch["attention_masks"],
|
| 53 |
-
masks_list=batch["masks"],
|
| 54 |
-
resize_list=batch["resizes"],
|
| 55 |
-
orgsize_list=batch["orgsizes"],
|
| 56 |
-
conversation_list=batch["convs"],
|
| 57 |
-
refs_num=batch["refs_num"],
|
| 58 |
-
fids=batch["fids"],
|
| 59 |
-
vids=batch["vids"],
|
| 60 |
-
contrast=args.ct_weight,
|
| 61 |
-
ref_ids=batch["ref_ids"],
|
| 62 |
-
inference=True,
|
| 63 |
-
)
|
| 64 |
-
return output["seg_embeddings"][0][0].float()
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
def decode_low_res(model, batch, q):
|
| 68 |
-
visual_model = model.get_model().visual_model
|
| 69 |
-
sparse, dense = visual_model.prompt_encoder(
|
| 70 |
-
points=None,
|
| 71 |
-
boxes=None,
|
| 72 |
-
masks=None,
|
| 73 |
-
text_embeds=q.view(1, 1, -1).to(next(visual_model.parameters()).dtype),
|
| 74 |
-
)
|
| 75 |
-
sparse = sparse.to(q.dtype)
|
| 76 |
-
dense = dense.to(q.dtype)
|
| 77 |
-
|
| 78 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 79 |
-
low_res_masks, iou_predictions = visual_model.mask_decoder(
|
| 80 |
-
image_embeddings=batch["image_feats"][0],
|
| 81 |
-
image_pe=visual_model.prompt_encoder.get_dense_pe(),
|
| 82 |
-
sparse_prompt_embeddings=sparse,
|
| 83 |
-
dense_prompt_embeddings=dense,
|
| 84 |
-
multimask_output=False,
|
| 85 |
-
)
|
| 86 |
-
return low_res_masks.float(), iou_predictions.float().squeeze(-1)
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
def masks_to_64(mask_logits_or_binary):
|
| 90 |
-
if mask_logits_or_binary.ndim == 3:
|
| 91 |
-
mask_logits_or_binary = mask_logits_or_binary.unsqueeze(1)
|
| 92 |
-
return F.interpolate(
|
| 93 |
-
mask_logits_or_binary.float(),
|
| 94 |
-
size=(64, 64),
|
| 95 |
-
mode="bilinear",
|
| 96 |
-
align_corners=False,
|
| 97 |
-
).clamp(0.0, 1.0)
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
def d2_scores(image_embeddings, mask64, q, beta):
|
| 101 |
-
feats = image_embeddings.float()
|
| 102 |
-
if mask64.shape[0] != feats.shape[0]:
|
| 103 |
-
raise ValueError(f"Mask/frame mismatch: {mask64.shape} vs {feats.shape}")
|
| 104 |
-
|
| 105 |
-
q = F.normalize(q.float().view(1, -1), dim=-1)
|
| 106 |
-
mask = mask64.float()
|
| 107 |
-
comp = 1.0 - mask
|
| 108 |
-
|
| 109 |
-
z_in = (feats * mask).sum(dim=(2, 3)) / mask.sum(dim=(2, 3)).clamp_min(1e-6)
|
| 110 |
-
z_out = (feats * comp).sum(dim=(2, 3)) / comp.sum(dim=(2, 3)).clamp_min(1e-6)
|
| 111 |
-
|
| 112 |
-
z_in = F.normalize(z_in, dim=-1)
|
| 113 |
-
z_out = F.normalize(z_out, dim=-1)
|
| 114 |
-
return (z_in @ q.T).squeeze(-1) - beta * (z_out @ q.T).squeeze(-1)
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
def frame_iou(pred_logits, gt_masks):
|
| 118 |
-
pred = (torch.sigmoid(pred_logits.float()) > 0.4).float()
|
| 119 |
-
gt = gt_masks.float()
|
| 120 |
-
if pred.ndim == 4:
|
| 121 |
-
pred = pred.squeeze(1)
|
| 122 |
-
inter = (pred * gt).sum(dim=(1, 2))
|
| 123 |
-
union = torch.maximum(pred, gt).sum(dim=(1, 2))
|
| 124 |
-
num_pixels = pred.shape[-1] * pred.shape[-2]
|
| 125 |
-
no_obj = gt.sum(dim=(1, 2)) == 0
|
| 126 |
-
inter_no_obj = ((1.0 - pred) * (1.0 - gt)).sum(dim=(1, 2))
|
| 127 |
-
inter = torch.where(no_obj, inter_no_obj, inter)
|
| 128 |
-
union = torch.where(no_obj, torch.full_like(union, float(num_pixels)), union)
|
| 129 |
-
return inter / union.clamp_min(1e-7)
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
def frame_fscore_proxy(pred_logits, gt_masks):
|
| 133 |
-
pred = (torch.sigmoid(pred_logits.float()) > 0.4).float()
|
| 134 |
-
gt = gt_masks.float()
|
| 135 |
-
if pred.ndim == 4:
|
| 136 |
-
pred = pred.squeeze(1)
|
| 137 |
-
tp = (pred * gt).sum(dim=(1, 2))
|
| 138 |
-
precision = tp / pred.sum(dim=(1, 2)).clamp_min(1e-7)
|
| 139 |
-
recall = tp / gt.sum(dim=(1, 2)).clamp_min(1e-7)
|
| 140 |
-
beta2 = 0.3
|
| 141 |
-
fscore = (1 + beta2) * precision * recall / (beta2 * precision + recall).clamp_min(1e-7)
|
| 142 |
-
no_obj = gt.sum(dim=(1, 2)) == 0
|
| 143 |
-
return torch.where(no_obj, torch.zeros_like(fscore), fscore)
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
def parse_betas():
|
| 147 |
-
raw = os.environ.get("D2_BETAS", "0.5")
|
| 148 |
-
return [float(x.strip()) for x in raw.split(",") if x.strip()]
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
def collect_q_pool(model, tokenizer, limit):
|
| 152 |
-
q_pool = []
|
| 153 |
-
loader = make_loader(tokenizer)
|
| 154 |
-
for sample_idx, batch in enumerate(loader):
|
| 155 |
-
if sample_idx >= limit:
|
| 156 |
-
break
|
| 157 |
-
batch = dict_to_cuda(batch)
|
| 158 |
-
q = get_q(model, batch)
|
| 159 |
-
q_pool.append(
|
| 160 |
-
{
|
| 161 |
-
"sample_idx": sample_idx,
|
| 162 |
-
"vid": batch["vids"][0],
|
| 163 |
-
"ref": batch["refs"][0][0],
|
| 164 |
-
"fid": int(batch["fids"][0][0]),
|
| 165 |
-
"q": q.cpu(),
|
| 166 |
-
}
|
| 167 |
-
)
|
| 168 |
-
print(f"Collected q {sample_idx}: vid={q_pool[-1]['vid']} ref={q_pool[-1]['ref']}")
|
| 169 |
-
if not q_pool:
|
| 170 |
-
raise RuntimeError("No q vectors collected. Is the selected split empty?")
|
| 171 |
-
return q_pool
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
def choose_shuffled_idx(sample_idx, q_pool):
|
| 175 |
-
if len(q_pool) <= 1:
|
| 176 |
-
return None
|
| 177 |
-
return (sample_idx + 1) % len(q_pool)
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
def choose_wrong_ref_idx(sample_idx, q_pool):
|
| 181 |
-
current = q_pool[sample_idx]
|
| 182 |
-
for item in q_pool:
|
| 183 |
-
if item["sample_idx"] == sample_idx:
|
| 184 |
-
continue
|
| 185 |
-
if item["vid"] == current["vid"] and item["fid"] != current["fid"]:
|
| 186 |
-
return item["sample_idx"]
|
| 187 |
-
for item in q_pool:
|
| 188 |
-
if item["sample_idx"] == sample_idx:
|
| 189 |
-
continue
|
| 190 |
-
if item["vid"] == current["vid"] and item["ref"] != current["ref"]:
|
| 191 |
-
return item["sample_idx"]
|
| 192 |
-
return None
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
def run_d2(model, tokenizer, q_pool, betas, limit):
|
| 196 |
-
rows = []
|
| 197 |
-
loader = make_loader(tokenizer)
|
| 198 |
-
q_lookup = {item["sample_idx"]: item for item in q_pool}
|
| 199 |
-
generator = torch.Generator(device="cuda")
|
| 200 |
-
generator.manual_seed(1234)
|
| 201 |
-
|
| 202 |
-
for sample_idx, batch in enumerate(loader):
|
| 203 |
-
if sample_idx >= limit:
|
| 204 |
-
break
|
| 205 |
-
batch = dict_to_cuda(batch)
|
| 206 |
-
item = q_lookup[sample_idx]
|
| 207 |
-
real_q = item["q"].cuda()
|
| 208 |
-
|
| 209 |
-
low_res_masks, iou_predictions = decode_low_res(model, batch, real_q)
|
| 210 |
-
pred_mask64 = masks_to_64(torch.sigmoid(low_res_masks))
|
| 211 |
-
gt_masks = batch["masks"][0][0].float()
|
| 212 |
-
gt_mask64 = masks_to_64(gt_masks)
|
| 213 |
-
image_embeddings = batch["image_feats"][0].float()
|
| 214 |
-
|
| 215 |
-
pred_logits_hr = model.get_model().visual_model.postprocess_masks(
|
| 216 |
-
low_res_masks.to(batch["image_feats"][0].dtype),
|
| 217 |
-
input_size=batch["resizes"][0],
|
| 218 |
-
original_size=batch["orgsizes"][0],
|
| 219 |
-
).squeeze(1)
|
| 220 |
-
|
| 221 |
-
frame_ious = frame_iou(pred_logits_hr, gt_masks)
|
| 222 |
-
frame_fscores = frame_fscore_proxy(pred_logits_hr, gt_masks)
|
| 223 |
-
pred_area = (torch.sigmoid(pred_logits_hr.float()) > 0.4).float().mean(dim=(1, 2))
|
| 224 |
-
gt_area = gt_masks.float().mean(dim=(1, 2))
|
| 225 |
-
|
| 226 |
-
shuffled_idx = choose_shuffled_idx(sample_idx, q_pool)
|
| 227 |
-
wrong_ref_idx = choose_wrong_ref_idx(sample_idx, q_pool)
|
| 228 |
-
q_controls = [
|
| 229 |
-
("real", real_q, sample_idx),
|
| 230 |
-
("random", torch.randn(real_q.shape, device=real_q.device, generator=generator), None),
|
| 231 |
-
]
|
| 232 |
-
if shuffled_idx is not None:
|
| 233 |
-
q_controls.append(("shuffled", q_lookup[shuffled_idx]["q"].cuda(), shuffled_idx))
|
| 234 |
-
if wrong_ref_idx is not None:
|
| 235 |
-
q_controls.append(("wrong_ref", q_lookup[wrong_ref_idx]["q"].cuda(), wrong_ref_idx))
|
| 236 |
-
|
| 237 |
-
for beta in betas:
|
| 238 |
-
for q_type, q, q_source_idx in q_controls:
|
| 239 |
-
pred_scores = d2_scores(image_embeddings, pred_mask64, q, beta)
|
| 240 |
-
gt_scores = d2_scores(image_embeddings, gt_mask64, q, beta)
|
| 241 |
-
base_info = {
|
| 242 |
-
"sample_idx": sample_idx,
|
| 243 |
-
"vid": item["vid"],
|
| 244 |
-
"ref": item["ref"],
|
| 245 |
-
"fid": item["fid"],
|
| 246 |
-
"split": args.eval_split,
|
| 247 |
-
"frame_iou": math.nan,
|
| 248 |
-
"frame_fscore_proxy": math.nan,
|
| 249 |
-
"iou_pred": math.nan,
|
| 250 |
-
"pred_area": math.nan,
|
| 251 |
-
"gt_area": math.nan,
|
| 252 |
-
}
|
| 253 |
-
for frame_idx in range(pred_scores.shape[0]):
|
| 254 |
-
base_info_frame = dict(base_info)
|
| 255 |
-
base_info_frame.update(
|
| 256 |
-
{
|
| 257 |
-
"frame_iou": frame_ious[frame_idx].item(),
|
| 258 |
-
"frame_fscore_proxy": frame_fscores[frame_idx].item(),
|
| 259 |
-
"iou_pred": iou_predictions[frame_idx].item(),
|
| 260 |
-
"pred_area": pred_area[frame_idx].item(),
|
| 261 |
-
"gt_area": gt_area[frame_idx].item(),
|
| 262 |
-
}
|
| 263 |
-
)
|
| 264 |
-
row = dict(base_info_frame)
|
| 265 |
-
row.update(
|
| 266 |
-
{
|
| 267 |
-
"frame": frame_idx,
|
| 268 |
-
"q_type": q_type,
|
| 269 |
-
"beta": beta,
|
| 270 |
-
"s_pred": pred_scores[frame_idx].item(),
|
| 271 |
-
"s_gt": gt_scores[frame_idx].item(),
|
| 272 |
-
"q_source_idx": q_source_idx if q_source_idx is not None else "",
|
| 273 |
-
}
|
| 274 |
-
)
|
| 275 |
-
rows.append(row)
|
| 276 |
-
|
| 277 |
-
real_rows = [
|
| 278 |
-
r for r in rows if r["sample_idx"] == sample_idx and r["q_type"] == "real" and r["beta"] == betas[0]
|
| 279 |
-
]
|
| 280 |
-
s_pred_values = [r["s_pred"] for r in real_rows]
|
| 281 |
-
print(
|
| 282 |
-
f"D2 {sample_idx}: vid={item['vid']} ref={item['ref']} "
|
| 283 |
-
f"mean_s_pred={np.mean(s_pred_values):.4f} min_s_pred={np.min(s_pred_values):.4f} "
|
| 284 |
-
f"mean_iou={frame_ious.mean().item():.4f}"
|
| 285 |
-
)
|
| 286 |
-
|
| 287 |
-
return rows
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
def print_summary(rows):
|
| 291 |
-
real_rows = [r for r in rows if r["q_type"] == "real"]
|
| 292 |
-
if not real_rows:
|
| 293 |
-
return
|
| 294 |
-
by_beta = sorted(set(r["beta"] for r in real_rows))
|
| 295 |
-
print("\nSummary")
|
| 296 |
-
print(f"rows: {len(rows)}")
|
| 297 |
-
for beta in by_beta:
|
| 298 |
-
beta_rows = [r for r in rows if r["beta"] == beta]
|
| 299 |
-
print(f"\nbeta={beta}")
|
| 300 |
-
for q_type in sorted(set(r["q_type"] for r in beta_rows)):
|
| 301 |
-
qr = [r for r in beta_rows if r["q_type"] == q_type]
|
| 302 |
-
print(
|
| 303 |
-
f"{q_type:10s} "
|
| 304 |
-
f"mean_s_pred={np.mean([r['s_pred'] for r in qr]):+.4f} "
|
| 305 |
-
f"mean_s_gt={np.mean([r['s_gt'] for r in qr]):+.4f}"
|
| 306 |
-
)
|
| 307 |
-
real_beta = [r for r in beta_rows if r["q_type"] == "real"]
|
| 308 |
-
s_pred = np.array([r["s_pred"] for r in real_beta])
|
| 309 |
-
frame_iou_values = np.array([r["frame_iou"] for r in real_beta])
|
| 310 |
-
if len(s_pred) > 1 and np.std(s_pred) > 1e-8 and np.std(frame_iou_values) > 1e-8:
|
| 311 |
-
corr = np.corrcoef(s_pred, frame_iou_values)[0, 1]
|
| 312 |
-
print(f"corr(real s_pred, frame_iou)={corr:+.4f}")
|
| 313 |
-
else:
|
| 314 |
-
print("corr(real s_pred, frame_iou)=nan")
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
def main():
|
| 318 |
-
set_seed(42)
|
| 319 |
-
torch.set_grad_enabled(False)
|
| 320 |
-
betas = parse_betas()
|
| 321 |
-
tokenizer, seg_token_idx = build_tokenizer()
|
| 322 |
-
limit = args.max_eval_rows if args.max_eval_rows > 0 else 30
|
| 323 |
-
print(f"Split: {args.eval_split} | samples: {limit} | betas: {betas}")
|
| 324 |
-
|
| 325 |
-
model = build_model(tokenizer, seg_token_idx)
|
| 326 |
-
q_pool = collect_q_pool(model, tokenizer, limit)
|
| 327 |
-
rows = run_d2(model, tokenizer, q_pool, betas, limit)
|
| 328 |
-
print_summary(rows)
|
| 329 |
-
|
| 330 |
-
csv_path = os.environ.get("D2_BASIC_CSV", f"/workspace/SimToken/d2_basic_{args.eval_split}_{limit}.csv")
|
| 331 |
-
os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
|
| 332 |
-
with open(csv_path, "w", newline="") as f:
|
| 333 |
-
writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
|
| 334 |
-
writer.writeheader()
|
| 335 |
-
writer.writerows(rows)
|
| 336 |
-
print(f"\nSaved CSV: {csv_path}")
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
if __name__ == "__main__":
|
| 340 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
d2_llm_space.py
DELETED
|
@@ -1,314 +0,0 @@
|
|
| 1 |
-
import csv
|
| 2 |
-
import math
|
| 3 |
-
import os
|
| 4 |
-
from functools import partial
|
| 5 |
-
|
| 6 |
-
import numpy as np
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
import transformers
|
| 10 |
-
from torch.utils.data import DataLoader
|
| 11 |
-
|
| 12 |
-
from configs import args
|
| 13 |
-
from datasets import REFAVS
|
| 14 |
-
from decoder_invariance_check import build_model, set_seed
|
| 15 |
-
from d2_basic import frame_fscore_proxy, frame_iou
|
| 16 |
-
from load_model import collate_fn, dict_to_cuda
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def build_tokenizer():
|
| 20 |
-
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 21 |
-
args.mllm,
|
| 22 |
-
cache_dir=None,
|
| 23 |
-
model_max_length=2048,
|
| 24 |
-
padding_side="right",
|
| 25 |
-
use_fast=False,
|
| 26 |
-
)
|
| 27 |
-
tokenizer.pad_token = tokenizer.unk_token
|
| 28 |
-
tokenizer.add_tokens("[SEG]")
|
| 29 |
-
seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
|
| 30 |
-
return tokenizer, seg_token_idx
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def make_loader(tokenizer):
|
| 34 |
-
dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
|
| 35 |
-
return DataLoader(
|
| 36 |
-
dataset,
|
| 37 |
-
batch_size=1,
|
| 38 |
-
shuffle=False,
|
| 39 |
-
num_workers=0,
|
| 40 |
-
collate_fn=partial(collate_fn, tokenizer=tokenizer),
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def forward_for_hidden_and_q(model, batch):
|
| 45 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 46 |
-
output = model.forward(
|
| 47 |
-
images=batch["images"],
|
| 48 |
-
images_clip=batch["images_clip"],
|
| 49 |
-
audio_features=batch["audio_feats"],
|
| 50 |
-
image_features=batch["image_feats"],
|
| 51 |
-
input_ids=batch["input_ids"],
|
| 52 |
-
labels=batch["labels"],
|
| 53 |
-
attention_masks=batch["attention_masks"],
|
| 54 |
-
masks_list=batch["masks"],
|
| 55 |
-
resize_list=batch["resizes"],
|
| 56 |
-
orgsize_list=batch["orgsizes"],
|
| 57 |
-
conversation_list=batch["convs"],
|
| 58 |
-
refs_num=batch["refs_num"],
|
| 59 |
-
fids=batch["fids"],
|
| 60 |
-
vids=batch["vids"],
|
| 61 |
-
contrast=args.ct_weight,
|
| 62 |
-
ref_ids=batch["ref_ids"],
|
| 63 |
-
inference=True,
|
| 64 |
-
)
|
| 65 |
-
h_seg = output["seg_hidden_states"][0][0].float()
|
| 66 |
-
q = output["seg_embeddings"][0][0].float()
|
| 67 |
-
return h_seg, q
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def decode_low_res(model, batch, q):
|
| 71 |
-
visual_model = model.get_model().visual_model
|
| 72 |
-
sparse, dense = visual_model.prompt_encoder(
|
| 73 |
-
points=None,
|
| 74 |
-
boxes=None,
|
| 75 |
-
masks=None,
|
| 76 |
-
text_embeds=q.view(1, 1, -1).to(next(visual_model.parameters()).dtype),
|
| 77 |
-
)
|
| 78 |
-
sparse = sparse.to(q.dtype)
|
| 79 |
-
dense = dense.to(q.dtype)
|
| 80 |
-
|
| 81 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 82 |
-
low_res_masks, iou_predictions = visual_model.mask_decoder(
|
| 83 |
-
image_embeddings=batch["image_feats"][0],
|
| 84 |
-
image_pe=visual_model.prompt_encoder.get_dense_pe(),
|
| 85 |
-
sparse_prompt_embeddings=sparse,
|
| 86 |
-
dense_prompt_embeddings=dense,
|
| 87 |
-
multimask_output=False,
|
| 88 |
-
)
|
| 89 |
-
return low_res_masks.float(), iou_predictions.float().squeeze(-1)
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
def clip_projected_tokens(model, batch):
|
| 93 |
-
images = torch.cat(batch["images_clip"], dim=0)
|
| 94 |
-
with torch.no_grad():
|
| 95 |
-
clip_tokens = model.encode_images(images)
|
| 96 |
-
projector = model.get_model().mm_projector
|
| 97 |
-
clip_tokens = clip_tokens.to(projector.weight.dtype)
|
| 98 |
-
llm_tokens = projector(clip_tokens).float()
|
| 99 |
-
return llm_tokens
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
def infer_square_grid(num_tokens):
|
| 103 |
-
grid = int(math.sqrt(num_tokens))
|
| 104 |
-
if grid * grid != num_tokens:
|
| 105 |
-
raise ValueError(f"Expected square patch-token grid, got {num_tokens} tokens")
|
| 106 |
-
return grid
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def masks_to_token_grid(mask_logits_or_binary, num_tokens):
|
| 110 |
-
if mask_logits_or_binary.ndim == 3:
|
| 111 |
-
mask_logits_or_binary = mask_logits_or_binary.unsqueeze(1)
|
| 112 |
-
grid = infer_square_grid(num_tokens)
|
| 113 |
-
return F.interpolate(
|
| 114 |
-
mask_logits_or_binary.float(),
|
| 115 |
-
size=(grid, grid),
|
| 116 |
-
mode="bilinear",
|
| 117 |
-
align_corners=False,
|
| 118 |
-
).flatten(2).transpose(1, 2).clamp(0.0, 1.0)
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
def d2_scores_llm(llm_tokens, mask_tokens, h_seg, beta):
|
| 122 |
-
if llm_tokens.shape[:2] != mask_tokens.shape[:2]:
|
| 123 |
-
raise ValueError(f"Token/mask mismatch: {llm_tokens.shape} vs {mask_tokens.shape}")
|
| 124 |
-
h = F.normalize(h_seg.float().view(1, -1), dim=-1)
|
| 125 |
-
tokens = llm_tokens.float()
|
| 126 |
-
mask = mask_tokens.float()
|
| 127 |
-
comp = 1.0 - mask
|
| 128 |
-
|
| 129 |
-
z_in = (tokens * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-6)
|
| 130 |
-
z_out = (tokens * comp).sum(dim=1) / comp.sum(dim=1).clamp_min(1e-6)
|
| 131 |
-
|
| 132 |
-
z_in = F.normalize(z_in, dim=-1)
|
| 133 |
-
z_out = F.normalize(z_out, dim=-1)
|
| 134 |
-
return (z_in @ h.T).squeeze(-1) - beta * (z_out @ h.T).squeeze(-1)
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
def parse_betas():
|
| 138 |
-
raw = os.environ.get("D2_BETAS", "0.5")
|
| 139 |
-
return [float(x.strip()) for x in raw.split(",") if x.strip()]
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def collect_hidden_pool(model, tokenizer, limit):
|
| 143 |
-
pool = []
|
| 144 |
-
loader = make_loader(tokenizer)
|
| 145 |
-
for sample_idx, batch in enumerate(loader):
|
| 146 |
-
if sample_idx >= limit:
|
| 147 |
-
break
|
| 148 |
-
batch = dict_to_cuda(batch)
|
| 149 |
-
h_seg, q = forward_for_hidden_and_q(model, batch)
|
| 150 |
-
pool.append(
|
| 151 |
-
{
|
| 152 |
-
"sample_idx": sample_idx,
|
| 153 |
-
"vid": batch["vids"][0],
|
| 154 |
-
"ref": batch["refs"][0][0],
|
| 155 |
-
"fid": int(batch["fids"][0][0]),
|
| 156 |
-
"h": h_seg.cpu(),
|
| 157 |
-
"q": q.cpu(),
|
| 158 |
-
}
|
| 159 |
-
)
|
| 160 |
-
print(f"Collected h {sample_idx}: vid={pool[-1]['vid']} ref={pool[-1]['ref']}")
|
| 161 |
-
if not pool:
|
| 162 |
-
raise RuntimeError("No hidden states collected. Is the selected split empty?")
|
| 163 |
-
return pool
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
def choose_shuffled_idx(sample_idx, pool):
|
| 167 |
-
if len(pool) <= 1:
|
| 168 |
-
return None
|
| 169 |
-
return (sample_idx + 1) % len(pool)
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
def choose_wrong_ref_idx(sample_idx, pool):
|
| 173 |
-
current = pool[sample_idx]
|
| 174 |
-
for item in pool:
|
| 175 |
-
if item["sample_idx"] == sample_idx:
|
| 176 |
-
continue
|
| 177 |
-
if item["vid"] == current["vid"] and item["fid"] != current["fid"]:
|
| 178 |
-
return item["sample_idx"]
|
| 179 |
-
for item in pool:
|
| 180 |
-
if item["sample_idx"] == sample_idx:
|
| 181 |
-
continue
|
| 182 |
-
if item["vid"] == current["vid"] and item["ref"] != current["ref"]:
|
| 183 |
-
return item["sample_idx"]
|
| 184 |
-
return None
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
def run_d2_llm(model, tokenizer, pool, betas, limit):
|
| 188 |
-
rows = []
|
| 189 |
-
lookup = {item["sample_idx"]: item for item in pool}
|
| 190 |
-
generator = torch.Generator(device="cuda")
|
| 191 |
-
generator.manual_seed(1234)
|
| 192 |
-
loader = make_loader(tokenizer)
|
| 193 |
-
|
| 194 |
-
for sample_idx, batch in enumerate(loader):
|
| 195 |
-
if sample_idx >= limit:
|
| 196 |
-
break
|
| 197 |
-
batch = dict_to_cuda(batch)
|
| 198 |
-
item = lookup[sample_idx]
|
| 199 |
-
h_real = item["h"].cuda()
|
| 200 |
-
q_real = item["q"].cuda()
|
| 201 |
-
|
| 202 |
-
low_res_masks, iou_predictions = decode_low_res(model, batch, q_real)
|
| 203 |
-
llm_tokens = clip_projected_tokens(model, batch)
|
| 204 |
-
pred_mask_tokens = masks_to_token_grid(torch.sigmoid(low_res_masks), llm_tokens.shape[1])
|
| 205 |
-
gt_masks = batch["masks"][0][0].float()
|
| 206 |
-
gt_mask_tokens = masks_to_token_grid(gt_masks, llm_tokens.shape[1])
|
| 207 |
-
|
| 208 |
-
pred_logits_hr = model.get_model().visual_model.postprocess_masks(
|
| 209 |
-
low_res_masks.to(batch["image_feats"][0].dtype),
|
| 210 |
-
input_size=batch["resizes"][0],
|
| 211 |
-
original_size=batch["orgsizes"][0],
|
| 212 |
-
).squeeze(1)
|
| 213 |
-
frame_ious = frame_iou(pred_logits_hr, gt_masks)
|
| 214 |
-
frame_fscores = frame_fscore_proxy(pred_logits_hr, gt_masks)
|
| 215 |
-
pred_area = (torch.sigmoid(pred_logits_hr.float()) > 0.4).float().mean(dim=(1, 2))
|
| 216 |
-
gt_area = gt_masks.float().mean(dim=(1, 2))
|
| 217 |
-
|
| 218 |
-
shuffled_idx = choose_shuffled_idx(sample_idx, pool)
|
| 219 |
-
wrong_ref_idx = choose_wrong_ref_idx(sample_idx, pool)
|
| 220 |
-
controls = [
|
| 221 |
-
("real", h_real, sample_idx),
|
| 222 |
-
("random", torch.randn(h_real.shape, device=h_real.device, generator=generator), None),
|
| 223 |
-
]
|
| 224 |
-
if shuffled_idx is not None:
|
| 225 |
-
controls.append(("shuffled", lookup[shuffled_idx]["h"].cuda(), shuffled_idx))
|
| 226 |
-
if wrong_ref_idx is not None:
|
| 227 |
-
controls.append(("wrong_ref", lookup[wrong_ref_idx]["h"].cuda(), wrong_ref_idx))
|
| 228 |
-
|
| 229 |
-
for beta in betas:
|
| 230 |
-
for h_type, h, h_source_idx in controls:
|
| 231 |
-
pred_scores = d2_scores_llm(llm_tokens, pred_mask_tokens, h, beta)
|
| 232 |
-
gt_scores = d2_scores_llm(llm_tokens, gt_mask_tokens, h, beta)
|
| 233 |
-
for frame_idx in range(pred_scores.shape[0]):
|
| 234 |
-
rows.append(
|
| 235 |
-
{
|
| 236 |
-
"sample_idx": sample_idx,
|
| 237 |
-
"vid": item["vid"],
|
| 238 |
-
"ref": item["ref"],
|
| 239 |
-
"fid": item["fid"],
|
| 240 |
-
"split": args.eval_split,
|
| 241 |
-
"frame": frame_idx,
|
| 242 |
-
"h_type": h_type,
|
| 243 |
-
"beta": beta,
|
| 244 |
-
"s_pred": pred_scores[frame_idx].item(),
|
| 245 |
-
"s_gt": gt_scores[frame_idx].item(),
|
| 246 |
-
"h_source_idx": h_source_idx if h_source_idx is not None else "",
|
| 247 |
-
"frame_iou": frame_ious[frame_idx].item(),
|
| 248 |
-
"frame_fscore_proxy": frame_fscores[frame_idx].item(),
|
| 249 |
-
"iou_pred": iou_predictions[frame_idx].item(),
|
| 250 |
-
"pred_area": pred_area[frame_idx].item(),
|
| 251 |
-
"gt_area": gt_area[frame_idx].item(),
|
| 252 |
-
}
|
| 253 |
-
)
|
| 254 |
-
|
| 255 |
-
real_rows = [
|
| 256 |
-
r for r in rows if r["sample_idx"] == sample_idx and r["h_type"] == "real" and r["beta"] == betas[0]
|
| 257 |
-
]
|
| 258 |
-
s_pred_values = [r["s_pred"] for r in real_rows]
|
| 259 |
-
print(
|
| 260 |
-
f"D2-LLM {sample_idx}: vid={item['vid']} ref={item['ref']} "
|
| 261 |
-
f"mean_s_pred={np.mean(s_pred_values):.4f} min_s_pred={np.min(s_pred_values):.4f} "
|
| 262 |
-
f"mean_iou={frame_ious.mean().item():.4f}"
|
| 263 |
-
)
|
| 264 |
-
|
| 265 |
-
return rows
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
def print_summary(rows):
|
| 269 |
-
print("\nSummary")
|
| 270 |
-
print(f"rows: {len(rows)}")
|
| 271 |
-
for beta in sorted(set(r["beta"] for r in rows)):
|
| 272 |
-
beta_rows = [r for r in rows if r["beta"] == beta]
|
| 273 |
-
print(f"\nbeta={beta}")
|
| 274 |
-
for h_type in sorted(set(r["h_type"] for r in beta_rows)):
|
| 275 |
-
hr = [r for r in beta_rows if r["h_type"] == h_type]
|
| 276 |
-
print(
|
| 277 |
-
f"{h_type:10s} "
|
| 278 |
-
f"mean_s_pred={np.mean([r['s_pred'] for r in hr]):+.4f} "
|
| 279 |
-
f"mean_s_gt={np.mean([r['s_gt'] for r in hr]):+.4f}"
|
| 280 |
-
)
|
| 281 |
-
real_rows = [r for r in beta_rows if r["h_type"] == "real"]
|
| 282 |
-
s_pred = np.array([r["s_pred"] for r in real_rows])
|
| 283 |
-
frame_iou_values = np.array([r["frame_iou"] for r in real_rows])
|
| 284 |
-
if len(s_pred) > 1 and np.std(s_pred) > 1e-8 and np.std(frame_iou_values) > 1e-8:
|
| 285 |
-
corr = np.corrcoef(s_pred, frame_iou_values)[0, 1]
|
| 286 |
-
print(f"corr(real s_pred, frame_iou)={corr:+.4f}")
|
| 287 |
-
else:
|
| 288 |
-
print("corr(real s_pred, frame_iou)=nan")
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
def main():
|
| 292 |
-
set_seed(42)
|
| 293 |
-
torch.set_grad_enabled(False)
|
| 294 |
-
betas = parse_betas()
|
| 295 |
-
tokenizer, seg_token_idx = build_tokenizer()
|
| 296 |
-
limit = args.max_eval_rows if args.max_eval_rows > 0 else 30
|
| 297 |
-
print(f"Split: {args.eval_split} | samples: {limit} | betas: {betas}")
|
| 298 |
-
|
| 299 |
-
model = build_model(tokenizer, seg_token_idx)
|
| 300 |
-
pool = collect_hidden_pool(model, tokenizer, limit)
|
| 301 |
-
rows = run_d2_llm(model, tokenizer, pool, betas, limit)
|
| 302 |
-
print_summary(rows)
|
| 303 |
-
|
| 304 |
-
csv_path = os.environ.get("D2_LLM_CSV", f"/workspace/SimToken/d2_llm_{args.eval_split}_{limit}.csv")
|
| 305 |
-
os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
|
| 306 |
-
with open(csv_path, "w", newline="") as f:
|
| 307 |
-
writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
|
| 308 |
-
writer.writeheader()
|
| 309 |
-
writer.writerows(rows)
|
| 310 |
-
print(f"\nSaved CSV: {csv_path}")
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
if __name__ == "__main__":
|
| 314 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
decoder_invariance_check.py
DELETED
|
@@ -1,256 +0,0 @@
|
|
| 1 |
-
import csv
|
| 2 |
-
import os
|
| 3 |
-
import random
|
| 4 |
-
from functools import partial
|
| 5 |
-
|
| 6 |
-
import numpy as np
|
| 7 |
-
import torch
|
| 8 |
-
import transformers
|
| 9 |
-
from peft import LoraConfig, get_peft_model
|
| 10 |
-
from torch.utils.data import DataLoader
|
| 11 |
-
from transformers import AutoConfig
|
| 12 |
-
|
| 13 |
-
from configs import args
|
| 14 |
-
from datasets import REFAVS
|
| 15 |
-
from load_model import collate_fn, dict_to_cuda
|
| 16 |
-
from models.avs_model import Simtoken_ForCausalLM
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def set_seed(seed=42):
|
| 20 |
-
torch.manual_seed(seed)
|
| 21 |
-
np.random.seed(seed)
|
| 22 |
-
random.seed(seed)
|
| 23 |
-
torch.cuda.manual_seed_all(seed)
|
| 24 |
-
torch.backends.cudnn.deterministic = True
|
| 25 |
-
torch.backends.cudnn.benchmark = False
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def find_lora_target_modules(model, target_modules=("q_proj", "v_proj")):
|
| 29 |
-
modules = set()
|
| 30 |
-
excluded = [
|
| 31 |
-
"visual_model",
|
| 32 |
-
"vision_tower",
|
| 33 |
-
"mm_projector",
|
| 34 |
-
"text_hidden_fcs",
|
| 35 |
-
"audio_feature_layer",
|
| 36 |
-
]
|
| 37 |
-
for name, module in model.named_modules():
|
| 38 |
-
if not isinstance(module, torch.nn.Linear):
|
| 39 |
-
continue
|
| 40 |
-
if any(x in name for x in excluded):
|
| 41 |
-
continue
|
| 42 |
-
if any(x in name for x in target_modules):
|
| 43 |
-
modules.add(name)
|
| 44 |
-
return sorted(modules)
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def build_model(tokenizer, seg_token_idx):
|
| 48 |
-
model_args = {
|
| 49 |
-
"train_mask_decoder": True,
|
| 50 |
-
"out_dim": 256,
|
| 51 |
-
"ce_loss_weight": 1.0,
|
| 52 |
-
"dice_loss_weight": 0.5,
|
| 53 |
-
"bce_loss_weight": 2.0,
|
| 54 |
-
"seg_token_idx": seg_token_idx,
|
| 55 |
-
"vision_pretrained": args.vision_pretrained,
|
| 56 |
-
"vision_tower": args.vision_tower,
|
| 57 |
-
"use_im_start_end": False,
|
| 58 |
-
"compress": args.compress,
|
| 59 |
-
"start": args.start,
|
| 60 |
-
}
|
| 61 |
-
|
| 62 |
-
model = Simtoken_ForCausalLM.from_pretrained(
|
| 63 |
-
args.mllm,
|
| 64 |
-
torch_dtype=torch.bfloat16,
|
| 65 |
-
low_cpu_mem_usage=True,
|
| 66 |
-
**model_args,
|
| 67 |
-
)
|
| 68 |
-
|
| 69 |
-
model.config.eos_token_id = tokenizer.eos_token_id
|
| 70 |
-
model.config.bos_token_id = tokenizer.bos_token_id
|
| 71 |
-
model.config.pad_token_id = tokenizer.pad_token_id
|
| 72 |
-
|
| 73 |
-
model.get_model().initialize_vision_modules(model.get_model().config)
|
| 74 |
-
vision_tower = model.get_model().get_vision_tower()
|
| 75 |
-
vision_tower.to(dtype=torch.float32, device="cuda")
|
| 76 |
-
|
| 77 |
-
model_args_from_pt = AutoConfig.from_pretrained(args.mllm)
|
| 78 |
-
model_args_from_pt.use_cluster = True
|
| 79 |
-
model_args_from_pt.freeze = False
|
| 80 |
-
model_args_from_pt.mm_tune = True
|
| 81 |
-
model_args_from_pt.spatial_cluster_rate0 = 64
|
| 82 |
-
model_args_from_pt.spatial_cluster_rate1 = 32
|
| 83 |
-
model_args_from_pt.spatial_cluster_rate2 = 16
|
| 84 |
-
model_args_from_pt.temporal_cluster_rate = 0.0625
|
| 85 |
-
model_args_from_pt.vision_tune = False
|
| 86 |
-
model.get_model().initialize_cluster_modules(model_args_from_pt)
|
| 87 |
-
model.get_model().initialize_lisa_modules(model.get_model().config)
|
| 88 |
-
|
| 89 |
-
lora_config = LoraConfig(
|
| 90 |
-
r=8,
|
| 91 |
-
lora_alpha=16,
|
| 92 |
-
target_modules=find_lora_target_modules(model),
|
| 93 |
-
lora_dropout=0.05,
|
| 94 |
-
bias="none",
|
| 95 |
-
task_type="CAUSAL_LM",
|
| 96 |
-
)
|
| 97 |
-
model = get_peft_model(model, lora_config)
|
| 98 |
-
model = model.to("cuda")
|
| 99 |
-
model.resize_token_embeddings(len(tokenizer))
|
| 100 |
-
|
| 101 |
-
state = torch.load(args.saved_model, map_location="cpu")
|
| 102 |
-
missing, unexpected = model.load_state_dict(state, strict=False)
|
| 103 |
-
print(f"Loaded checkpoint: {args.saved_model}")
|
| 104 |
-
print(f"Missing keys: {len(missing)} | Unexpected keys: {len(unexpected)}")
|
| 105 |
-
|
| 106 |
-
model.eval()
|
| 107 |
-
return model
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
def get_seg_embedding(model, batch):
|
| 111 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 112 |
-
output = model.forward(
|
| 113 |
-
images=batch["images"],
|
| 114 |
-
images_clip=batch["images_clip"],
|
| 115 |
-
audio_features=batch["audio_feats"],
|
| 116 |
-
image_features=batch["image_feats"],
|
| 117 |
-
input_ids=batch["input_ids"],
|
| 118 |
-
labels=batch["labels"],
|
| 119 |
-
attention_masks=batch["attention_masks"],
|
| 120 |
-
masks_list=batch["masks"],
|
| 121 |
-
resize_list=batch["resizes"],
|
| 122 |
-
orgsize_list=batch["orgsizes"],
|
| 123 |
-
conversation_list=batch["convs"],
|
| 124 |
-
refs_num=batch["refs_num"],
|
| 125 |
-
fids=batch["fids"],
|
| 126 |
-
vids=batch["vids"],
|
| 127 |
-
contrast=args.ct_weight,
|
| 128 |
-
ref_ids=batch["ref_ids"],
|
| 129 |
-
inference=True,
|
| 130 |
-
)
|
| 131 |
-
return output["seg_embeddings"][0][0:1]
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
def check_one_sample(model, batch):
|
| 135 |
-
q = get_seg_embedding(model, batch)
|
| 136 |
-
image_embeddings = batch["image_feats"][0]
|
| 137 |
-
|
| 138 |
-
visual_model = model.get_model().visual_model
|
| 139 |
-
sparse, dense = visual_model.prompt_encoder(
|
| 140 |
-
points=None,
|
| 141 |
-
boxes=None,
|
| 142 |
-
masks=None,
|
| 143 |
-
text_embeds=q.unsqueeze(1),
|
| 144 |
-
)
|
| 145 |
-
sparse = sparse.to(q.dtype)
|
| 146 |
-
dense = dense.to(q.dtype)
|
| 147 |
-
|
| 148 |
-
decoder = visual_model.mask_decoder
|
| 149 |
-
image_pe = visual_model.prompt_encoder.get_dense_pe()
|
| 150 |
-
|
| 151 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 152 |
-
full_masks, full_iou = decoder(
|
| 153 |
-
image_embeddings=image_embeddings,
|
| 154 |
-
image_pe=image_pe,
|
| 155 |
-
sparse_prompt_embeddings=sparse,
|
| 156 |
-
dense_prompt_embeddings=dense,
|
| 157 |
-
multimask_output=False,
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
rows = []
|
| 161 |
-
for t in range(image_embeddings.shape[0]):
|
| 162 |
-
single_masks, single_iou = decoder(
|
| 163 |
-
image_embeddings=image_embeddings[t : t + 1],
|
| 164 |
-
image_pe=image_pe,
|
| 165 |
-
sparse_prompt_embeddings=sparse,
|
| 166 |
-
dense_prompt_embeddings=dense,
|
| 167 |
-
multimask_output=False,
|
| 168 |
-
)
|
| 169 |
-
|
| 170 |
-
diff = (full_masks[t : t + 1] - single_masks).float().abs()
|
| 171 |
-
iou_diff = (full_iou[t : t + 1] - single_iou).float().abs()
|
| 172 |
-
rows.append(
|
| 173 |
-
{
|
| 174 |
-
"vid": batch["vids"][0],
|
| 175 |
-
"ref": batch["refs"][0][0],
|
| 176 |
-
"frame": t,
|
| 177 |
-
"max_abs_diff": diff.max().item(),
|
| 178 |
-
"mean_abs_diff": diff.mean().item(),
|
| 179 |
-
"iou_pred_diff": iou_diff.max().item(),
|
| 180 |
-
}
|
| 181 |
-
)
|
| 182 |
-
return rows
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
def main():
|
| 186 |
-
set_seed(42)
|
| 187 |
-
torch.set_grad_enabled(False)
|
| 188 |
-
|
| 189 |
-
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 190 |
-
args.mllm,
|
| 191 |
-
cache_dir=None,
|
| 192 |
-
model_max_length=2048,
|
| 193 |
-
padding_side="right",
|
| 194 |
-
use_fast=False,
|
| 195 |
-
)
|
| 196 |
-
tokenizer.pad_token = tokenizer.unk_token
|
| 197 |
-
tokenizer.add_tokens("[SEG]")
|
| 198 |
-
seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
|
| 199 |
-
|
| 200 |
-
dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
|
| 201 |
-
loader = DataLoader(
|
| 202 |
-
dataset,
|
| 203 |
-
batch_size=1,
|
| 204 |
-
shuffle=False,
|
| 205 |
-
num_workers=0,
|
| 206 |
-
collate_fn=partial(collate_fn, tokenizer=tokenizer),
|
| 207 |
-
)
|
| 208 |
-
|
| 209 |
-
limit = args.max_eval_rows if args.max_eval_rows > 0 else 1
|
| 210 |
-
print(f"Split: {args.eval_split} | samples to check: {limit}")
|
| 211 |
-
|
| 212 |
-
model = build_model(tokenizer, seg_token_idx)
|
| 213 |
-
|
| 214 |
-
all_rows = []
|
| 215 |
-
for sample_idx, batch in enumerate(loader):
|
| 216 |
-
if sample_idx >= limit:
|
| 217 |
-
break
|
| 218 |
-
batch = dict_to_cuda(batch)
|
| 219 |
-
rows = check_one_sample(model, batch)
|
| 220 |
-
all_rows.extend(rows)
|
| 221 |
-
|
| 222 |
-
print(f"\nSample {sample_idx}: vid={rows[0]['vid']} ref={rows[0]['ref']}")
|
| 223 |
-
print("frame | max_abs_diff | mean_abs_diff | iou_pred_diff")
|
| 224 |
-
for row in rows:
|
| 225 |
-
print(
|
| 226 |
-
f"{row['frame']:02d} | "
|
| 227 |
-
f"{row['max_abs_diff']:.8e} | "
|
| 228 |
-
f"{row['mean_abs_diff']:.8e} | "
|
| 229 |
-
f"{row['iou_pred_diff']:.8e}"
|
| 230 |
-
)
|
| 231 |
-
|
| 232 |
-
if not all_rows:
|
| 233 |
-
raise RuntimeError("No rows were checked. Is the selected split empty?")
|
| 234 |
-
|
| 235 |
-
max_diff = max(row["max_abs_diff"] for row in all_rows)
|
| 236 |
-
mean_diff = sum(row["mean_abs_diff"] for row in all_rows) / len(all_rows)
|
| 237 |
-
max_iou_diff = max(row["iou_pred_diff"] for row in all_rows)
|
| 238 |
-
|
| 239 |
-
print("\nSummary")
|
| 240 |
-
print(f"checked frames: {len(all_rows)}")
|
| 241 |
-
print(f"global max_abs_diff: {max_diff:.8e}")
|
| 242 |
-
print(f"average mean_abs_diff: {mean_diff:.8e}")
|
| 243 |
-
print(f"global max_iou_pred_diff: {max_iou_diff:.8e}")
|
| 244 |
-
|
| 245 |
-
csv_path = os.environ.get("DECODER_INVARIANCE_CSV")
|
| 246 |
-
if csv_path:
|
| 247 |
-
os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
|
| 248 |
-
with open(csv_path, "w", newline="") as f:
|
| 249 |
-
writer = csv.DictWriter(f, fieldnames=list(all_rows[0].keys()))
|
| 250 |
-
writer.writeheader()
|
| 251 |
-
writer.writerows(all_rows)
|
| 252 |
-
print(f"Saved CSV: {csv_path}")
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
if __name__ == "__main__":
|
| 256 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dev_subsets_rpb_v1.json
DELETED
|
@@ -1,620 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"metadata": {
|
| 3 |
-
"seed": 42,
|
| 4 |
-
"split_sizes": {
|
| 5 |
-
"train": 14113,
|
| 6 |
-
"test_s": 2288,
|
| 7 |
-
"test_u": 1656,
|
| 8 |
-
"test_n": 1028
|
| 9 |
-
},
|
| 10 |
-
"source_metadata": "/workspace/SimToken/data/metadata.csv"
|
| 11 |
-
},
|
| 12 |
-
"subsets": {
|
| 13 |
-
"test_s": [
|
| 14 |
-
6,
|
| 15 |
-
16,
|
| 16 |
-
36,
|
| 17 |
-
71,
|
| 18 |
-
74,
|
| 19 |
-
88,
|
| 20 |
-
108,
|
| 21 |
-
114,
|
| 22 |
-
116,
|
| 23 |
-
122,
|
| 24 |
-
126,
|
| 25 |
-
128,
|
| 26 |
-
134,
|
| 27 |
-
138,
|
| 28 |
-
139,
|
| 29 |
-
146,
|
| 30 |
-
152,
|
| 31 |
-
159,
|
| 32 |
-
177,
|
| 33 |
-
196,
|
| 34 |
-
217,
|
| 35 |
-
219,
|
| 36 |
-
249,
|
| 37 |
-
256,
|
| 38 |
-
268,
|
| 39 |
-
276,
|
| 40 |
-
279,
|
| 41 |
-
286,
|
| 42 |
-
287,
|
| 43 |
-
297,
|
| 44 |
-
298,
|
| 45 |
-
299,
|
| 46 |
-
312,
|
| 47 |
-
313,
|
| 48 |
-
324,
|
| 49 |
-
331,
|
| 50 |
-
332,
|
| 51 |
-
347,
|
| 52 |
-
378,
|
| 53 |
-
383,
|
| 54 |
-
402,
|
| 55 |
-
410,
|
| 56 |
-
412,
|
| 57 |
-
420,
|
| 58 |
-
451,
|
| 59 |
-
452,
|
| 60 |
-
458,
|
| 61 |
-
467,
|
| 62 |
-
477,
|
| 63 |
-
484,
|
| 64 |
-
486,
|
| 65 |
-
497,
|
| 66 |
-
499,
|
| 67 |
-
512,
|
| 68 |
-
526,
|
| 69 |
-
533,
|
| 70 |
-
543,
|
| 71 |
-
550,
|
| 72 |
-
551,
|
| 73 |
-
567,
|
| 74 |
-
574,
|
| 75 |
-
576,
|
| 76 |
-
581,
|
| 77 |
-
594,
|
| 78 |
-
596,
|
| 79 |
-
608,
|
| 80 |
-
616,
|
| 81 |
-
625,
|
| 82 |
-
627,
|
| 83 |
-
642,
|
| 84 |
-
646,
|
| 85 |
-
663,
|
| 86 |
-
692,
|
| 87 |
-
700,
|
| 88 |
-
704,
|
| 89 |
-
724,
|
| 90 |
-
745,
|
| 91 |
-
754,
|
| 92 |
-
795,
|
| 93 |
-
815,
|
| 94 |
-
819,
|
| 95 |
-
831,
|
| 96 |
-
843,
|
| 97 |
-
854,
|
| 98 |
-
867,
|
| 99 |
-
895,
|
| 100 |
-
946,
|
| 101 |
-
953,
|
| 102 |
-
965,
|
| 103 |
-
975,
|
| 104 |
-
979,
|
| 105 |
-
989,
|
| 106 |
-
1004,
|
| 107 |
-
1007,
|
| 108 |
-
1008,
|
| 109 |
-
1010,
|
| 110 |
-
1023,
|
| 111 |
-
1039,
|
| 112 |
-
1051,
|
| 113 |
-
1052,
|
| 114 |
-
1072,
|
| 115 |
-
1075,
|
| 116 |
-
1080,
|
| 117 |
-
1088,
|
| 118 |
-
1099,
|
| 119 |
-
1101,
|
| 120 |
-
1104,
|
| 121 |
-
1106,
|
| 122 |
-
1134,
|
| 123 |
-
1138,
|
| 124 |
-
1169,
|
| 125 |
-
1180,
|
| 126 |
-
1201,
|
| 127 |
-
1205,
|
| 128 |
-
1221,
|
| 129 |
-
1230,
|
| 130 |
-
1247,
|
| 131 |
-
1258,
|
| 132 |
-
1272,
|
| 133 |
-
1279,
|
| 134 |
-
1284,
|
| 135 |
-
1294,
|
| 136 |
-
1297,
|
| 137 |
-
1312,
|
| 138 |
-
1329,
|
| 139 |
-
1339,
|
| 140 |
-
1343,
|
| 141 |
-
1367,
|
| 142 |
-
1379,
|
| 143 |
-
1406,
|
| 144 |
-
1417,
|
| 145 |
-
1461,
|
| 146 |
-
1462,
|
| 147 |
-
1468,
|
| 148 |
-
1473,
|
| 149 |
-
1474,
|
| 150 |
-
1489,
|
| 151 |
-
1493,
|
| 152 |
-
1500,
|
| 153 |
-
1510,
|
| 154 |
-
1517,
|
| 155 |
-
1552,
|
| 156 |
-
1556,
|
| 157 |
-
1557,
|
| 158 |
-
1589,
|
| 159 |
-
1609,
|
| 160 |
-
1612,
|
| 161 |
-
1618,
|
| 162 |
-
1622,
|
| 163 |
-
1624,
|
| 164 |
-
1644,
|
| 165 |
-
1647,
|
| 166 |
-
1665,
|
| 167 |
-
1669,
|
| 168 |
-
1676,
|
| 169 |
-
1682,
|
| 170 |
-
1683,
|
| 171 |
-
1691,
|
| 172 |
-
1700,
|
| 173 |
-
1726,
|
| 174 |
-
1746,
|
| 175 |
-
1748,
|
| 176 |
-
1758,
|
| 177 |
-
1764,
|
| 178 |
-
1765,
|
| 179 |
-
1778,
|
| 180 |
-
1785,
|
| 181 |
-
1786,
|
| 182 |
-
1808,
|
| 183 |
-
1826,
|
| 184 |
-
1852,
|
| 185 |
-
1861,
|
| 186 |
-
1883,
|
| 187 |
-
1891,
|
| 188 |
-
1916,
|
| 189 |
-
1938,
|
| 190 |
-
1944,
|
| 191 |
-
1967,
|
| 192 |
-
1971,
|
| 193 |
-
1980,
|
| 194 |
-
1986,
|
| 195 |
-
2034,
|
| 196 |
-
2044,
|
| 197 |
-
2067,
|
| 198 |
-
2074,
|
| 199 |
-
2082,
|
| 200 |
-
2085,
|
| 201 |
-
2118,
|
| 202 |
-
2128,
|
| 203 |
-
2156,
|
| 204 |
-
2176,
|
| 205 |
-
2182,
|
| 206 |
-
2185,
|
| 207 |
-
2188,
|
| 208 |
-
2194,
|
| 209 |
-
2206,
|
| 210 |
-
2211,
|
| 211 |
-
2215,
|
| 212 |
-
2247,
|
| 213 |
-
2256
|
| 214 |
-
],
|
| 215 |
-
"test_u": [
|
| 216 |
-
4,
|
| 217 |
-
16,
|
| 218 |
-
26,
|
| 219 |
-
38,
|
| 220 |
-
40,
|
| 221 |
-
48,
|
| 222 |
-
50,
|
| 223 |
-
65,
|
| 224 |
-
83,
|
| 225 |
-
92,
|
| 226 |
-
102,
|
| 227 |
-
117,
|
| 228 |
-
120,
|
| 229 |
-
135,
|
| 230 |
-
144,
|
| 231 |
-
153,
|
| 232 |
-
155,
|
| 233 |
-
185,
|
| 234 |
-
200,
|
| 235 |
-
201,
|
| 236 |
-
211,
|
| 237 |
-
219,
|
| 238 |
-
221,
|
| 239 |
-
226,
|
| 240 |
-
227,
|
| 241 |
-
240,
|
| 242 |
-
245,
|
| 243 |
-
251,
|
| 244 |
-
252,
|
| 245 |
-
255,
|
| 246 |
-
267,
|
| 247 |
-
272,
|
| 248 |
-
274,
|
| 249 |
-
276,
|
| 250 |
-
278,
|
| 251 |
-
282,
|
| 252 |
-
284,
|
| 253 |
-
286,
|
| 254 |
-
303,
|
| 255 |
-
309,
|
| 256 |
-
313,
|
| 257 |
-
328,
|
| 258 |
-
345,
|
| 259 |
-
348,
|
| 260 |
-
358,
|
| 261 |
-
363,
|
| 262 |
-
374,
|
| 263 |
-
376,
|
| 264 |
-
379,
|
| 265 |
-
383,
|
| 266 |
-
385,
|
| 267 |
-
387,
|
| 268 |
-
393,
|
| 269 |
-
396,
|
| 270 |
-
400,
|
| 271 |
-
412,
|
| 272 |
-
417,
|
| 273 |
-
428,
|
| 274 |
-
434,
|
| 275 |
-
452,
|
| 276 |
-
453,
|
| 277 |
-
456,
|
| 278 |
-
459,
|
| 279 |
-
463,
|
| 280 |
-
473,
|
| 281 |
-
490,
|
| 282 |
-
493,
|
| 283 |
-
504,
|
| 284 |
-
517,
|
| 285 |
-
525,
|
| 286 |
-
535,
|
| 287 |
-
543,
|
| 288 |
-
544,
|
| 289 |
-
545,
|
| 290 |
-
549,
|
| 291 |
-
550,
|
| 292 |
-
565,
|
| 293 |
-
584,
|
| 294 |
-
585,
|
| 295 |
-
594,
|
| 296 |
-
602,
|
| 297 |
-
603,
|
| 298 |
-
606,
|
| 299 |
-
638,
|
| 300 |
-
642,
|
| 301 |
-
643,
|
| 302 |
-
651,
|
| 303 |
-
684,
|
| 304 |
-
687,
|
| 305 |
-
692,
|
| 306 |
-
700,
|
| 307 |
-
721,
|
| 308 |
-
728,
|
| 309 |
-
752,
|
| 310 |
-
757,
|
| 311 |
-
779,
|
| 312 |
-
783,
|
| 313 |
-
785,
|
| 314 |
-
794,
|
| 315 |
-
803,
|
| 316 |
-
807,
|
| 317 |
-
814,
|
| 318 |
-
847,
|
| 319 |
-
849,
|
| 320 |
-
853,
|
| 321 |
-
854,
|
| 322 |
-
861,
|
| 323 |
-
867,
|
| 324 |
-
884,
|
| 325 |
-
900,
|
| 326 |
-
903,
|
| 327 |
-
906,
|
| 328 |
-
924,
|
| 329 |
-
930,
|
| 330 |
-
931,
|
| 331 |
-
941,
|
| 332 |
-
948,
|
| 333 |
-
957,
|
| 334 |
-
968,
|
| 335 |
-
972,
|
| 336 |
-
980,
|
| 337 |
-
987,
|
| 338 |
-
995,
|
| 339 |
-
996,
|
| 340 |
-
1007,
|
| 341 |
-
1009,
|
| 342 |
-
1028,
|
| 343 |
-
1033,
|
| 344 |
-
1034,
|
| 345 |
-
1040,
|
| 346 |
-
1054,
|
| 347 |
-
1098,
|
| 348 |
-
1104,
|
| 349 |
-
1111,
|
| 350 |
-
1121,
|
| 351 |
-
1126,
|
| 352 |
-
1134,
|
| 353 |
-
1155,
|
| 354 |
-
1161,
|
| 355 |
-
1167,
|
| 356 |
-
1180,
|
| 357 |
-
1186,
|
| 358 |
-
1192,
|
| 359 |
-
1212,
|
| 360 |
-
1214,
|
| 361 |
-
1219,
|
| 362 |
-
1226,
|
| 363 |
-
1254,
|
| 364 |
-
1256,
|
| 365 |
-
1259,
|
| 366 |
-
1261,
|
| 367 |
-
1270,
|
| 368 |
-
1278,
|
| 369 |
-
1285,
|
| 370 |
-
1288,
|
| 371 |
-
1290,
|
| 372 |
-
1305,
|
| 373 |
-
1310,
|
| 374 |
-
1323,
|
| 375 |
-
1325,
|
| 376 |
-
1343,
|
| 377 |
-
1360,
|
| 378 |
-
1375,
|
| 379 |
-
1376,
|
| 380 |
-
1404,
|
| 381 |
-
1411,
|
| 382 |
-
1426,
|
| 383 |
-
1429,
|
| 384 |
-
1442,
|
| 385 |
-
1449,
|
| 386 |
-
1452,
|
| 387 |
-
1456,
|
| 388 |
-
1475,
|
| 389 |
-
1478,
|
| 390 |
-
1479,
|
| 391 |
-
1484,
|
| 392 |
-
1493,
|
| 393 |
-
1499,
|
| 394 |
-
1500,
|
| 395 |
-
1501,
|
| 396 |
-
1506,
|
| 397 |
-
1517,
|
| 398 |
-
1523,
|
| 399 |
-
1528,
|
| 400 |
-
1536,
|
| 401 |
-
1545,
|
| 402 |
-
1546,
|
| 403 |
-
1550,
|
| 404 |
-
1561,
|
| 405 |
-
1570,
|
| 406 |
-
1598,
|
| 407 |
-
1609,
|
| 408 |
-
1611,
|
| 409 |
-
1625,
|
| 410 |
-
1632,
|
| 411 |
-
1634,
|
| 412 |
-
1635,
|
| 413 |
-
1641,
|
| 414 |
-
1654,
|
| 415 |
-
1655
|
| 416 |
-
],
|
| 417 |
-
"test_n": [
|
| 418 |
-
4,
|
| 419 |
-
5,
|
| 420 |
-
9,
|
| 421 |
-
16,
|
| 422 |
-
20,
|
| 423 |
-
25,
|
| 424 |
-
27,
|
| 425 |
-
33,
|
| 426 |
-
37,
|
| 427 |
-
40,
|
| 428 |
-
45,
|
| 429 |
-
46,
|
| 430 |
-
48,
|
| 431 |
-
53,
|
| 432 |
-
56,
|
| 433 |
-
60,
|
| 434 |
-
62,
|
| 435 |
-
67,
|
| 436 |
-
77,
|
| 437 |
-
78,
|
| 438 |
-
80,
|
| 439 |
-
81,
|
| 440 |
-
86,
|
| 441 |
-
90,
|
| 442 |
-
94,
|
| 443 |
-
99,
|
| 444 |
-
102,
|
| 445 |
-
106,
|
| 446 |
-
108,
|
| 447 |
-
111,
|
| 448 |
-
116,
|
| 449 |
-
121,
|
| 450 |
-
126,
|
| 451 |
-
127,
|
| 452 |
-
132,
|
| 453 |
-
143,
|
| 454 |
-
148,
|
| 455 |
-
153,
|
| 456 |
-
155,
|
| 457 |
-
156,
|
| 458 |
-
158,
|
| 459 |
-
160,
|
| 460 |
-
164,
|
| 461 |
-
168,
|
| 462 |
-
170,
|
| 463 |
-
171,
|
| 464 |
-
173,
|
| 465 |
-
175,
|
| 466 |
-
183,
|
| 467 |
-
184,
|
| 468 |
-
185,
|
| 469 |
-
188,
|
| 470 |
-
189,
|
| 471 |
-
190,
|
| 472 |
-
196,
|
| 473 |
-
202,
|
| 474 |
-
206,
|
| 475 |
-
208,
|
| 476 |
-
212,
|
| 477 |
-
217,
|
| 478 |
-
221,
|
| 479 |
-
222,
|
| 480 |
-
223,
|
| 481 |
-
233,
|
| 482 |
-
242,
|
| 483 |
-
246,
|
| 484 |
-
247,
|
| 485 |
-
259,
|
| 486 |
-
262,
|
| 487 |
-
269,
|
| 488 |
-
283,
|
| 489 |
-
298,
|
| 490 |
-
299,
|
| 491 |
-
306,
|
| 492 |
-
316,
|
| 493 |
-
317,
|
| 494 |
-
323,
|
| 495 |
-
330,
|
| 496 |
-
332,
|
| 497 |
-
334,
|
| 498 |
-
354,
|
| 499 |
-
357,
|
| 500 |
-
367,
|
| 501 |
-
372,
|
| 502 |
-
395,
|
| 503 |
-
397,
|
| 504 |
-
400,
|
| 505 |
-
405,
|
| 506 |
-
407,
|
| 507 |
-
420,
|
| 508 |
-
431,
|
| 509 |
-
435,
|
| 510 |
-
436,
|
| 511 |
-
444,
|
| 512 |
-
446,
|
| 513 |
-
461,
|
| 514 |
-
464,
|
| 515 |
-
470,
|
| 516 |
-
479,
|
| 517 |
-
481,
|
| 518 |
-
483,
|
| 519 |
-
485,
|
| 520 |
-
487,
|
| 521 |
-
494,
|
| 522 |
-
512,
|
| 523 |
-
516,
|
| 524 |
-
520,
|
| 525 |
-
524,
|
| 526 |
-
529,
|
| 527 |
-
530,
|
| 528 |
-
539,
|
| 529 |
-
540,
|
| 530 |
-
541,
|
| 531 |
-
554,
|
| 532 |
-
559,
|
| 533 |
-
560,
|
| 534 |
-
564,
|
| 535 |
-
568,
|
| 536 |
-
571,
|
| 537 |
-
572,
|
| 538 |
-
576,
|
| 539 |
-
577,
|
| 540 |
-
581,
|
| 541 |
-
585,
|
| 542 |
-
592,
|
| 543 |
-
602,
|
| 544 |
-
609,
|
| 545 |
-
620,
|
| 546 |
-
630,
|
| 547 |
-
632,
|
| 548 |
-
677,
|
| 549 |
-
678,
|
| 550 |
-
684,
|
| 551 |
-
693,
|
| 552 |
-
694,
|
| 553 |
-
695,
|
| 554 |
-
702,
|
| 555 |
-
716,
|
| 556 |
-
724,
|
| 557 |
-
727,
|
| 558 |
-
732,
|
| 559 |
-
735,
|
| 560 |
-
736,
|
| 561 |
-
747,
|
| 562 |
-
750,
|
| 563 |
-
752,
|
| 564 |
-
755,
|
| 565 |
-
758,
|
| 566 |
-
764,
|
| 567 |
-
767,
|
| 568 |
-
774,
|
| 569 |
-
775,
|
| 570 |
-
777,
|
| 571 |
-
779,
|
| 572 |
-
780,
|
| 573 |
-
782,
|
| 574 |
-
795,
|
| 575 |
-
800,
|
| 576 |
-
812,
|
| 577 |
-
815,
|
| 578 |
-
818,
|
| 579 |
-
821,
|
| 580 |
-
823,
|
| 581 |
-
825,
|
| 582 |
-
828,
|
| 583 |
-
834,
|
| 584 |
-
841,
|
| 585 |
-
843,
|
| 586 |
-
846,
|
| 587 |
-
848,
|
| 588 |
-
860,
|
| 589 |
-
861,
|
| 590 |
-
863,
|
| 591 |
-
869,
|
| 592 |
-
871,
|
| 593 |
-
878,
|
| 594 |
-
882,
|
| 595 |
-
891,
|
| 596 |
-
893,
|
| 597 |
-
896,
|
| 598 |
-
898,
|
| 599 |
-
899,
|
| 600 |
-
901,
|
| 601 |
-
906,
|
| 602 |
-
930,
|
| 603 |
-
940,
|
| 604 |
-
944,
|
| 605 |
-
969,
|
| 606 |
-
970,
|
| 607 |
-
973,
|
| 608 |
-
980,
|
| 609 |
-
990,
|
| 610 |
-
993,
|
| 611 |
-
996,
|
| 612 |
-
997,
|
| 613 |
-
1007,
|
| 614 |
-
1012,
|
| 615 |
-
1013,
|
| 616 |
-
1019,
|
| 617 |
-
1025
|
| 618 |
-
]
|
| 619 |
-
}
|
| 620 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_dev_eval_baseline_step0.txt
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 0.004542401526123285 Learning Rate:0.000000
|
| 2 |
-
valuate on test_s_refer: miou 0.7255374467872275 true fscore 0.8181094569922425
|
| 3 |
-
valuate on test_u_refer: miou 0.68531153425507 true fscore 0.7723772643739357
|
| 4 |
-
|
| 5 |
-
valuate on test_n_refer: metric 0.014519116841256618
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_dev_eval_pm_only_a02_step0.txt
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 0.013856410048902035 Learning Rate:0.000000
|
| 2 |
-
valuate on test_s_refer: miou 0.7251653336426284 true fscore 0.8137564373598434
|
| 3 |
-
bridge on test_s_refer: cos_delta_p_mask_mean=0.752373 | cos_delta_q_mean=-0.063845 | cos_delta_z_gt_mean=0.066832 | cos_p_hat_p_mask_mean=0.095022 | cos_p_hat_q_mean=0.991696 | cos_p_hat_z_gt_mean=0.058512 | cos_p_mask_z_gt_mean=0.064319 | delta_norm_mean=4.838175 | gate_mean=0.642605 | gate_std=0.066554 | p_hat_norm_mean=37.143986 | p_mask_norm_mean=0.855194 | q_norm_mean=37.143986 | z_gt_norm_mean=1.270137
|
| 4 |
-
valuate on test_u_refer: miou 0.6859597001315854 true fscore 0.7773032036889345
|
| 5 |
-
bridge on test_u_refer: cos_delta_p_mask_mean=0.752107 | cos_delta_q_mean=-0.052752 | cos_delta_z_gt_mean=0.059016 | cos_p_hat_p_mask_mean=0.066111 | cos_p_hat_q_mean=0.994380 | cos_p_hat_z_gt_mean=0.056506 | cos_p_mask_z_gt_mean=0.056127 | delta_norm_mean=3.232154 | gate_mean=0.529798 | gate_std=0.041540 | p_hat_norm_mean=30.350392 | p_mask_norm_mean=0.854621 | q_norm_mean=30.350392 | z_gt_norm_mean=1.131404
|
| 6 |
-
|
| 7 |
-
valuate on test_n_refer: metric 0.014255181886255741
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_dev_mixed_pm_only_a015_wm005.txt
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 0.12634180719032884 Learning Rate:0.000048
|
| 2 |
-
Epoch 1: running_loss 0.06299160566413775 Learning Rate:0.000038
|
| 3 |
-
Epoch 2: running_loss 0.04188278445508331 Learning Rate:0.000021
|
| 4 |
-
Epoch 3: running_loss 0.03136271081166342 Learning Rate:0.000006
|
| 5 |
-
Epoch 4: running_loss 0.025073944311589002 Learning Rate:0.000000
|
| 6 |
-
valuate on test_s_refer: miou 0.7268448945908449 true fscore 0.8160740848700516
|
| 7 |
-
bridge on test_s_refer: cos_delta_p_mask_mean=0.780949 | cos_delta_q_mean=-0.022341 | cos_delta_z_gt_mean=0.080238 | cos_p_hat_p_mask_mean=0.033820 | cos_p_hat_q_mean=0.998889 | cos_p_hat_z_gt_mean=0.053521 | cos_p_mask_z_gt_mean=0.064319 | delta_norm_mean=1.741799 | gate_mean=0.298187 | gate_std=0.074034 | p_hat_norm_mean=37.144979 | p_mask_norm_mean=0.855194 | q_norm_mean=37.144979 | z_gt_norm_mean=1.270137
|
| 8 |
-
valuate on test_u_refer: miou 0.6867437321859904 true fscore 0.774193259445019
|
| 9 |
-
bridge on test_u_refer: cos_delta_p_mask_mean=0.787519 | cos_delta_q_mean=-0.014046 | cos_delta_z_gt_mean=0.070144 | cos_p_hat_p_mask_mean=0.008821 | cos_p_hat_q_mean=0.999587 | cos_p_hat_z_gt_mean=0.052258 | cos_p_mask_z_gt_mean=0.056127 | delta_norm_mean=0.869715 | gate_mean=0.187340 | gate_std=0.030662 | p_hat_norm_mean=30.349741 | p_mask_norm_mean=0.854621 | q_norm_mean=30.349741 | z_gt_norm_mean=1.131404
|
| 10 |
-
|
| 11 |
-
valuate on test_n_refer: metric 0.014510215260088444
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_dev_mixed_pm_only_a018_wm005.txt
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 0.12581317650619894 Learning Rate:0.000048
|
| 2 |
-
Epoch 1: running_loss 0.0626903815427795 Learning Rate:0.000038
|
| 3 |
-
Epoch 2: running_loss 0.04165894452792903 Learning Rate:0.000021
|
| 4 |
-
Epoch 3: running_loss 0.031184122432023287 Learning Rate:0.000006
|
| 5 |
-
Epoch 4: running_loss 0.024928097636438905 Learning Rate:0.000000
|
| 6 |
-
valuate on test_s_refer: miou 0.727035479994347 true fscore 0.8155373766715638
|
| 7 |
-
bridge on test_s_refer: cos_delta_p_mask_mean=0.779142 | cos_delta_q_mean=-0.026866 | cos_delta_z_gt_mean=0.080963 | cos_p_hat_p_mask_mean=0.040792 | cos_p_hat_q_mean=0.998394 | cos_p_hat_z_gt_mean=0.054268 | cos_p_mask_z_gt_mean=0.064319 | delta_norm_mean=2.094408 | gate_mean=0.298949 | gate_std=0.074175 | p_hat_norm_mean=37.145271 | p_mask_norm_mean=0.855194 | q_norm_mean=37.145271 | z_gt_norm_mean=1.270137
|
| 8 |
-
valuate on test_u_refer: miou 0.6870561258980442 true fscore 0.774542552176863
|
| 9 |
-
bridge on test_u_refer: cos_delta_p_mask_mean=0.786014 | cos_delta_q_mean=-0.016895 | cos_delta_z_gt_mean=0.071182 | cos_p_hat_p_mask_mean=0.013252 | cos_p_hat_q_mean=0.999403 | cos_p_hat_z_gt_mean=0.052698 | cos_p_mask_z_gt_mean=0.056127 | delta_norm_mean=1.046129 | gate_mean=0.187813 | gate_std=0.030748 | p_hat_norm_mean=30.349577 | p_mask_norm_mean=0.854621 | q_norm_mean=30.349577 | z_gt_norm_mean=1.131404
|
| 10 |
-
|
| 11 |
-
valuate on test_n_refer: metric 0.014507208950817585
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_dev_pm_only_a012.txt
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 0.1251933453604579 Learning Rate:0.000291
|
| 2 |
-
Epoch 1: running_loss 0.06243458506651223 Learning Rate:0.000225
|
| 3 |
-
Epoch 2: running_loss 0.04142383218277246 Learning Rate:0.000124
|
| 4 |
-
Epoch 3: running_loss 0.030912025278666988 Learning Rate:0.000035
|
| 5 |
-
Epoch 4: running_loss 0.024670254811644553 Learning Rate:0.000000
|
| 6 |
-
valuate on test_s_refer: miou 0.7265147582390341 true fscore 0.8174789174459874
|
| 7 |
-
bridge on test_s_refer: cos_delta_p_mask_mean=0.785657 | cos_delta_q_mean=-0.012593 | cos_delta_z_gt_mean=0.074588 | cos_p_hat_p_mask_mean=0.018714 | cos_p_hat_q_mean=0.999648 | cos_p_hat_z_gt_mean=0.051832 | cos_p_mask_z_gt_mean=0.064319 | delta_norm_mean=0.980784 | gate_mean=0.209955 | gate_std=0.050712 | p_hat_norm_mean=37.145389 | p_mask_norm_mean=0.855194 | q_norm_mean=37.145389 | z_gt_norm_mean=1.270137
|
| 8 |
-
valuate on test_u_refer: miou 0.685781483513075 true fscore 0.7731429794151335
|
| 9 |
-
bridge on test_u_refer: cos_delta_p_mask_mean=0.790605 | cos_delta_q_mean=-0.008125 | cos_delta_z_gt_mean=0.065258 | cos_p_hat_p_mask_mean=-0.000455 | cos_p_hat_q_mean=0.999863 | cos_p_hat_z_gt_mean=0.051334 | cos_p_mask_z_gt_mean=0.056127 | delta_norm_mean=0.502185 | gate_mean=0.135438 | gate_std=0.020096 | p_hat_norm_mean=30.347839 | p_mask_norm_mean=0.854621 | q_norm_mean=30.347839 | z_gt_norm_mean=1.131404
|
| 10 |
-
|
| 11 |
-
valuate on test_n_refer: metric 0.014490844681859016
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_dev_pm_only_a015.txt
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 0.12516111659351736 Learning Rate:0.000291
|
| 2 |
-
Epoch 1: running_loss 0.06237624154891819 Learning Rate:0.000225
|
| 3 |
-
Epoch 2: running_loss 0.04133288407077392 Learning Rate:0.000124
|
| 4 |
-
Epoch 3: running_loss 0.03080323277390562 Learning Rate:0.000035
|
| 5 |
-
Epoch 4: running_loss 0.024568469962105155 Learning Rate:0.000000
|
| 6 |
-
valuate on test_s_refer: miou 0.7266912544447951 true fscore 0.8172510598856024
|
| 7 |
-
bridge on test_s_refer: cos_delta_p_mask_mean=0.784637 | cos_delta_q_mean=-0.015801 | cos_delta_z_gt_mean=0.074893 | cos_p_hat_p_mask_mean=0.023727 | cos_p_hat_q_mean=0.999446 | cos_p_hat_z_gt_mean=0.052317 | cos_p_mask_z_gt_mean=0.064319 | delta_norm_mean=1.230677 | gate_mean=0.210794 | gate_std=0.050954 | p_hat_norm_mean=37.144974 | p_mask_norm_mean=0.855194 | q_norm_mean=37.144974 | z_gt_norm_mean=1.270137
|
| 8 |
-
valuate on test_u_refer: miou 0.6856936469832761 true fscore 0.7733012911863625
|
| 9 |
-
bridge on test_u_refer: cos_delta_p_mask_mean=0.789761 | cos_delta_q_mean=-0.010194 | cos_delta_z_gt_mean=0.065751 | cos_p_hat_p_mask_mean=0.002815 | cos_p_hat_q_mean=0.999784 | cos_p_hat_z_gt_mean=0.051617 | cos_p_mask_z_gt_mean=0.056127 | delta_norm_mean=0.630081 | gate_mean=0.135950 | gate_std=0.020168 | p_hat_norm_mean=30.349286 | p_mask_norm_mean=0.854621 | q_norm_mean=30.349286 | z_gt_norm_mean=1.131404
|
| 10 |
-
|
| 11 |
-
valuate on test_n_refer: metric 0.014483190141618252
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_dev_pm_only_a018.txt
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 0.12512886058539152 Learning Rate:0.000291
|
| 2 |
-
Epoch 1: running_loss 0.062317848962266 Learning Rate:0.000225
|
| 3 |
-
Epoch 2: running_loss 0.04124188135998944 Learning Rate:0.000124
|
| 4 |
-
Epoch 3: running_loss 0.03069439489627257 Learning Rate:0.000035
|
| 5 |
-
Epoch 4: running_loss 0.024466648511588574 Learning Rate:0.000000
|
| 6 |
-
valuate on test_s_refer: miou 0.7269170961743339 true fscore 0.817047117385082
|
| 7 |
-
bridge on test_s_refer: cos_delta_p_mask_mean=0.783528 | cos_delta_q_mean=-0.019011 | cos_delta_z_gt_mean=0.075155 | cos_p_hat_p_mask_mean=0.028732 | cos_p_hat_q_mean=0.999199 | cos_p_hat_z_gt_mean=0.052798 | cos_p_mask_z_gt_mean=0.064319 | delta_norm_mean=1.480661 | gate_mean=0.211391 | gate_std=0.051102 | p_hat_norm_mean=37.145608 | p_mask_norm_mean=0.855194 | q_norm_mean=37.145608 | z_gt_norm_mean=1.270137
|
| 8 |
-
valuate on test_u_refer: miou 0.6859480822706291 true fscore 0.7735356919141486
|
| 9 |
-
bridge on test_u_refer: cos_delta_p_mask_mean=0.788825 | cos_delta_q_mean=-0.012263 | cos_delta_z_gt_mean=0.066219 | cos_p_hat_p_mask_mean=0.006046 | cos_p_hat_q_mean=0.999688 | cos_p_hat_z_gt_mean=0.051902 | cos_p_mask_z_gt_mean=0.056127 | delta_norm_mean=0.757877 | gate_mean=0.136287 | gate_std=0.020245 | p_hat_norm_mean=30.346972 | p_mask_norm_mean=0.854621 | q_norm_mean=30.346972 | z_gt_norm_mean=1.131404
|
| 10 |
-
|
| 11 |
-
valuate on test_n_refer: metric 0.014475596137344837
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_dev_qonly_pm_only_a018.txt
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 0.1250931837130338 Learning Rate:0.000291
|
| 2 |
-
Epoch 1: running_loss 0.06158186250831932 Learning Rate:0.000225
|
| 3 |
-
Epoch 2: running_loss 0.03905615148444971 Learning Rate:0.000124
|
| 4 |
-
Epoch 3: running_loss 0.028493995574535802 Learning Rate:0.000035
|
| 5 |
-
Epoch 4: running_loss 0.022694221674464644 Learning Rate:0.000000
|
| 6 |
-
valuate on test_s_refer: miou 0.7231086666105239 true fscore 0.8120589338685386
|
| 7 |
-
bridge on test_s_refer: cos_delta_p_mask_mean=0.740588 | cos_delta_q_mean=-0.082204 | cos_delta_z_gt_mean=0.083615 | cos_p_hat_p_mask_mean=0.120609 | cos_p_hat_q_mean=0.986413 | cos_p_hat_z_gt_mean=0.063688 | cos_p_mask_z_gt_mean=0.064319 | delta_norm_mean=6.165701 | gate_mean=0.922904 | gate_std=0.048146 | p_hat_norm_mean=37.145128 | p_mask_norm_mean=0.855194 | q_norm_mean=37.145128 | z_gt_norm_mean=1.270137
|
| 8 |
-
valuate on test_u_refer: miou 0.6828930461963626 true fscore 0.7766606059018523
|
| 9 |
-
bridge on test_u_refer: cos_delta_p_mask_mean=0.750842 | cos_delta_q_mean=-0.072793 | cos_delta_z_gt_mean=0.080115 | cos_p_hat_p_mask_mean=0.095975 | cos_p_hat_q_mean=0.989300 | cos_p_hat_z_gt_mean=0.061951 | cos_p_mask_z_gt_mean=0.056127 | delta_norm_mean=4.458672 | gate_mean=0.815494 | gate_std=0.064275 | p_hat_norm_mean=30.349046 | p_mask_norm_mean=0.854621 | q_norm_mean=30.349046 | z_gt_norm_mean=1.131404
|
| 10 |
-
|
| 11 |
-
valuate on test_n_refer: metric 0.014240134507417679
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_e1_baseline.txt
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 0.0045423684641718864 Learning Rate:0.000000
|
| 2 |
-
valuate on test_s_refer: miou 0.7299158895817891 true fscore 0.8098922965396196
|
| 3 |
-
valuate on test_u_refer: miou 0.7330115197712439 true fscore 0.8183729078620672
|
| 4 |
-
|
| 5 |
-
valuate on test_n_refer: metric 0.1223459392786026
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_e4_min.txt
DELETED
|
@@ -1,16 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 7.052718125283718 Learning Rate:0.000097
|
| 2 |
-
Epoch 1: running_loss 3.5262171775102615 Learning Rate:0.000075
|
| 3 |
-
Epoch 2: running_loss 2.35092111180226 Learning Rate:0.000041
|
| 4 |
-
Epoch 3: running_loss 1.7629929669201374 Learning Rate:0.000012
|
| 5 |
-
Epoch 4: running_loss 1.4105001017451286 Learning Rate:0.000000
|
| 6 |
-
Epoch 0: running_loss 7.052717879414558 Learning Rate:0.000097
|
| 7 |
-
Epoch 1: running_loss 3.526217419654131 Learning Rate:0.000075
|
| 8 |
-
Epoch 2: running_loss 2.3509211614727974 Learning Rate:0.000041
|
| 9 |
-
Epoch 3: running_loss 1.762992987409234 Learning Rate:0.000012
|
| 10 |
-
Epoch 4: running_loss 1.410500232875347 Learning Rate:0.000000
|
| 11 |
-
valuate on test_s_refer: miou 0.010701371397460661 true fscore 0.16367542997933923
|
| 12 |
-
bridge on test_s_refer: cos_p_hat_p_mask_mean=-0.003076 | cos_p_hat_q_mean=1.000000 | cos_p_hat_z_gt_mean=0.031631 | cos_p_mask_z_gt_mean=0.072929 | delta_norm_mean=0.003709 | gate_mean=0.019151 | gate_std=0.000754 | p_hat_norm_mean=6.222885 | p_mask_norm_mean=0.854909 | q_norm_mean=6.223040 | z_gt_norm_mean=1.275222
|
| 13 |
-
valuate on test_u_refer: miou 0.03141531638093511 true fscore 0.1579975866433233
|
| 14 |
-
bridge on test_u_refer: cos_p_hat_p_mask_mean=-0.004606 | cos_p_hat_q_mean=1.000000 | cos_p_hat_z_gt_mean=-0.000177 | cos_p_mask_z_gt_mean=0.081724 | delta_norm_mean=0.003449 | gate_mean=0.019014 | gate_std=0.000658 | p_hat_norm_mean=5.875611 | p_mask_norm_mean=0.855032 | q_norm_mean=5.875684 | z_gt_norm_mean=0.969146
|
| 15 |
-
|
| 16 |
-
valuate on test_n_refer: metric 0.15515293180942535
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_e4_min_v2.txt
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 0.2470331892836839 Learning Rate:0.000097
|
| 2 |
-
Epoch 1: running_loss 0.12353144341614097 Learning Rate:0.000075
|
| 3 |
-
Epoch 2: running_loss 0.08232998211557667 Learning Rate:0.000041
|
| 4 |
-
Epoch 3: running_loss 0.0617638936964795 Learning Rate:0.000012
|
| 5 |
-
Epoch 4: running_loss 0.04941030433401465 Learning Rate:0.000000
|
| 6 |
-
valuate on test_s_refer: miou 0.729936970449844 true fscore 0.8099028875399381
|
| 7 |
-
bridge on test_s_refer: cos_p_hat_p_mask_mean=-0.009047 | cos_p_hat_q_mean=1.000000 | cos_p_hat_z_gt_mean=0.060572 | cos_p_mask_z_gt_mean=0.072929 | delta_norm_mean=0.004936 | gate_mean=0.024371 | gate_std=0.005409 | p_hat_norm_mean=36.236958 | p_mask_norm_mean=0.854909 | q_norm_mean=36.239986 | z_gt_norm_mean=1.275222
|
| 8 |
-
valuate on test_u_refer: miou 0.7330397108156467 true fscore 0.8183516443520784
|
| 9 |
-
bridge on test_u_refer: cos_p_hat_p_mask_mean=-0.004755 | cos_p_hat_q_mean=1.000000 | cos_p_hat_z_gt_mean=0.013517 | cos_p_mask_z_gt_mean=0.081724 | delta_norm_mean=0.004417 | gate_mean=0.023295 | gate_std=0.004361 | p_hat_norm_mean=30.846060 | p_mask_norm_mean=0.855032 | q_norm_mean=30.848833 | z_gt_norm_mean=0.969146
|
| 10 |
-
|
| 11 |
-
valuate on test_n_refer: metric 0.12235464155673981
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_probe_a1_teacher_only.txt
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 0.15941409580409527 Learning Rate:0.000150
|
| 2 |
-
Epoch 1: running_loss 0.07969226781278849 Learning Rate:0.000300
|
| 3 |
-
Epoch 2: running_loss 0.05310918173442284 Learning Rate:0.000298
|
| 4 |
-
Epoch 3: running_loss 0.03982830489985645 Learning Rate:0.000291
|
| 5 |
-
Epoch 4: running_loss 0.03184974528849125 Learning Rate:0.000280
|
| 6 |
-
Epoch 5: running_loss 0.02652722302203377 Learning Rate:0.000265
|
| 7 |
-
Epoch 6: running_loss 0.02272333244660071 Learning Rate:0.000246
|
| 8 |
-
Epoch 7: running_loss 0.019872855627909303 Learning Rate:0.000225
|
| 9 |
-
Epoch 8: running_loss 0.017649518532885447 Learning Rate:0.000201
|
| 10 |
-
Epoch 9: running_loss 0.015872883144766092 Learning Rate:0.000176
|
| 11 |
-
Epoch 10: running_loss 0.014423399655656382 Learning Rate:0.000150
|
| 12 |
-
Epoch 11: running_loss 0.013206382282078266 Learning Rate:0.000124
|
| 13 |
-
Epoch 12: running_loss 0.012179449988672366 Learning Rate:0.000099
|
| 14 |
-
Epoch 13: running_loss 0.011303224135190248 Learning Rate:0.000075
|
| 15 |
-
Epoch 14: running_loss 0.010542566950122515 Learning Rate:0.000054
|
| 16 |
-
Epoch 15: running_loss 0.0098747648880817 Learning Rate:0.000035
|
| 17 |
-
Epoch 16: running_loss 0.009292871307800798 Learning Rate:0.000020
|
| 18 |
-
Epoch 17: running_loss 0.008775248295731015 Learning Rate:0.000009
|
| 19 |
-
Epoch 18: running_loss 0.008311718702316284 Learning Rate:0.000002
|
| 20 |
-
Epoch 19: running_loss 0.007893257355317474 Learning Rate:0.000000
|
| 21 |
-
valuate on train_overfit: miou 0.8857842811448791 true fscore 0.9381048823706806
|
| 22 |
-
bridge on train_overfit: cos_p_hat_p_mask_mean=0.004767 | cos_p_hat_q_mean=0.999904 | cos_p_hat_z_gt_mean=0.058385 | cos_p_mask_z_gt_mean=0.065508 | delta_norm_mean=0.571159 | gate_mean=0.425535 | gate_std=0.188610 | p_hat_norm_mean=32.916147 | p_mask_norm_mean=0.854710 | q_norm_mean=33.257832 | z_gt_norm_mean=1.191098
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_probe_a1_teacher_only_v2.txt
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 0.15941409580409527 Learning Rate:0.000150
|
| 2 |
-
Epoch 1: running_loss 0.0796922636218369 Learning Rate:0.000300
|
| 3 |
-
Epoch 2: running_loss 0.05310917769869169 Learning Rate:0.000298
|
| 4 |
-
Epoch 3: running_loss 0.03982830559834838 Learning Rate:0.000291
|
| 5 |
-
Epoch 4: running_loss 0.03184974305331707 Learning Rate:0.000280
|
| 6 |
-
Epoch 5: running_loss 0.02652722333247463 Learning Rate:0.000265
|
| 7 |
-
Epoch 6: running_loss 0.022723329652632986 Learning Rate:0.000246
|
| 8 |
-
Epoch 7: running_loss 0.019872855744324625 Learning Rate:0.000225
|
| 9 |
-
Epoch 8: running_loss 0.017649516980681155 Learning Rate:0.000201
|
| 10 |
-
Epoch 9: running_loss 0.015872882585972546 Learning Rate:0.000176
|
| 11 |
-
Epoch 10: running_loss 0.01442340033298189 Learning Rate:0.000150
|
| 12 |
-
Epoch 11: running_loss 0.013206382825349769 Learning Rate:0.000124
|
| 13 |
-
Epoch 12: running_loss 0.012179449773751773 Learning Rate:0.000099
|
| 14 |
-
Epoch 13: running_loss 0.011303224002144166 Learning Rate:0.000075
|
| 15 |
-
Epoch 14: running_loss 0.010542566763858001 Learning Rate:0.000054
|
| 16 |
-
Epoch 15: running_loss 0.00987476430600509 Learning Rate:0.000035
|
| 17 |
-
Epoch 16: running_loss 0.009292872293907054 Learning Rate:0.000020
|
| 18 |
-
Epoch 17: running_loss 0.0087752483992113 Learning Rate:0.000009
|
| 19 |
-
Epoch 18: running_loss 0.008311718849367216 Learning Rate:0.000002
|
| 20 |
-
Epoch 19: running_loss 0.007893257355317474 Learning Rate:0.000000
|
| 21 |
-
valuate on train_overfit: miou 0.8857840351993218 true fscore 0.9381047114729881
|
| 22 |
-
bridge on train_overfit: cos_delta_p_mask_mean=0.354064 | cos_delta_q_mean=-0.604202 | cos_delta_z_gt_mean=0.126264 | cos_p_hat_p_mask_mean=0.004767 | cos_p_hat_q_mean=0.999904 | cos_p_hat_z_gt_mean=0.058385 | cos_p_mask_z_gt_mean=0.065508 | delta_norm_mean=0.571159 | gate_mean=0.425535 | gate_std=0.188610 | p_hat_norm_mean=32.916147 | p_mask_norm_mean=0.854710 | q_norm_mean=33.257831 | z_gt_norm_mean=1.191098
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_probe_a1p_directional_pm_only.txt
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 0.11214640829712152 Learning Rate:0.000150
|
| 2 |
-
Epoch 1: running_loss 0.05601485609076917 Learning Rate:0.000300
|
| 3 |
-
Epoch 2: running_loss 0.03723815083503723 Learning Rate:0.000298
|
| 4 |
-
Epoch 3: running_loss 0.02785203023813665 Learning Rate:0.000291
|
| 5 |
-
Epoch 4: running_loss 0.022219109814614058 Learning Rate:0.000280
|
| 6 |
-
Epoch 5: running_loss 0.018464789803450305 Learning Rate:0.000265
|
| 7 |
-
Epoch 6: running_loss 0.01578202284872532 Learning Rate:0.000246
|
| 8 |
-
Epoch 7: running_loss 0.013773231767117977 Learning Rate:0.000225
|
| 9 |
-
Epoch 8: running_loss 0.012206872407760885 Learning Rate:0.000201
|
| 10 |
-
Epoch 9: running_loss 0.010958488751202821 Learning Rate:0.000176
|
| 11 |
-
Epoch 10: running_loss 0.009943378030915152 Learning Rate:0.000150
|
| 12 |
-
Epoch 11: running_loss 0.009091336939794322 Learning Rate:0.000124
|
| 13 |
-
Epoch 12: running_loss 0.00837581454274746 Learning Rate:0.000099
|
| 14 |
-
Epoch 13: running_loss 0.007767901090638978 Learning Rate:0.000075
|
| 15 |
-
Epoch 14: running_loss 0.007241058039168516 Learning Rate:0.000054
|
| 16 |
-
Epoch 15: running_loss 0.006779163610190153 Learning Rate:0.000035
|
| 17 |
-
Epoch 16: running_loss 0.006378827452221338 Learning Rate:0.000020
|
| 18 |
-
Epoch 17: running_loss 0.006023053286804093 Learning Rate:0.000009
|
| 19 |
-
Epoch 18: running_loss 0.005704282390836038 Learning Rate:0.000002
|
| 20 |
-
Epoch 19: running_loss 0.005416269856505096 Learning Rate:0.000000
|
| 21 |
-
valuate on train_overfit: miou 0.883418077353781 true fscore 0.937678836286068
|
| 22 |
-
bridge on train_overfit: cos_delta_p_mask_mean=0.818447 | cos_delta_q_mean=-0.029885 | cos_delta_z_gt_mean=0.063824 | cos_p_hat_p_mask_mean=0.047561 | cos_p_hat_q_mean=0.998200 | cos_p_hat_z_gt_mean=0.059441 | cos_p_mask_z_gt_mean=0.065508 | delta_norm_mean=2.004932 | gate_mean=0.598515 | gate_std=0.034498 | p_hat_norm_mean=33.257835 | p_mask_norm_mean=0.854710 | q_norm_mean=33.257834 | z_gt_norm_mean=1.191098
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_probe_a1p_directional_pm_only_a02.txt
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 0.11209722375497222 Learning Rate:0.000150
|
| 2 |
-
Epoch 1: running_loss 0.05594216543249786 Learning Rate:0.000300
|
| 3 |
-
Epoch 2: running_loss 0.03709370751554767 Learning Rate:0.000298
|
| 4 |
-
Epoch 3: running_loss 0.027660266729071736 Learning Rate:0.000291
|
| 5 |
-
Epoch 4: running_loss 0.02200547931715846 Learning Rate:0.000280
|
| 6 |
-
Epoch 5: running_loss 0.018238045663262408 Learning Rate:0.000265
|
| 7 |
-
Epoch 6: running_loss 0.015544687730393239 Learning Rate:0.000246
|
| 8 |
-
Epoch 7: running_loss 0.013526892522349954 Learning Rate:0.000225
|
| 9 |
-
Epoch 8: running_loss 0.01195424489883913 Learning Rate:0.000201
|
| 10 |
-
Epoch 9: running_loss 0.010702831950038672 Learning Rate:0.000176
|
| 11 |
-
Epoch 10: running_loss 0.009686671324412931 Learning Rate:0.000150
|
| 12 |
-
Epoch 11: running_loss 0.008837080444209278 Learning Rate:0.000124
|
| 13 |
-
Epoch 12: running_loss 0.008126160953767024 Learning Rate:0.000099
|
| 14 |
-
Epoch 13: running_loss 0.007524690058614526 Learning Rate:0.000075
|
| 15 |
-
Epoch 14: running_loss 0.007005957514047622 Learning Rate:0.000054
|
| 16 |
-
Epoch 15: running_loss 0.0065534417517483234 Learning Rate:0.000035
|
| 17 |
-
Epoch 16: running_loss 0.006162627901443664 Learning Rate:0.000020
|
| 18 |
-
Epoch 17: running_loss 0.005816713182462586 Learning Rate:0.000009
|
| 19 |
-
Epoch 18: running_loss 0.005507827319793011 Learning Rate:0.000002
|
| 20 |
-
Epoch 19: running_loss 0.005229406012222171 Learning Rate:0.000000
|
| 21 |
-
valuate on train_overfit: miou 0.8791497684578644 true fscore 0.9370119273662567
|
| 22 |
-
bridge on train_overfit: cos_delta_p_mask_mean=0.808940 | cos_delta_q_mean=-0.059708 | cos_delta_z_gt_mean=0.061659 | cos_p_hat_p_mask_mean=0.095240 | cos_p_hat_q_mean=0.992816 | cos_p_hat_z_gt_mean=0.062994 | cos_p_mask_z_gt_mean=0.065508 | delta_norm_mean=4.005328 | gate_mean=0.600366 | gate_std=0.034520 | p_hat_norm_mean=33.257835 | p_mask_norm_mean=0.854710 | q_norm_mean=33.257836 | z_gt_norm_mean=1.191098
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_probe_eval_directional_pm_only_a02.txt
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 0.12453739601187408 Learning Rate:0.000291
|
| 2 |
-
Epoch 1: running_loss 0.06081169372191653 Learning Rate:0.000225
|
| 3 |
-
Epoch 2: running_loss 0.039517335942946374 Learning Rate:0.000124
|
| 4 |
-
Epoch 3: running_loss 0.029158065939554945 Learning Rate:0.000035
|
| 5 |
-
Epoch 4: running_loss 0.02320093212183565 Learning Rate:0.000000
|
| 6 |
-
valuate on test_s_refer: miou 0.7251764057789819 true fscore 0.8044321979023517
|
| 7 |
-
bridge on test_s_refer: cos_delta_p_mask_mean=0.754565 | cos_delta_q_mean=-0.062171 | cos_delta_z_gt_mean=0.077296 | cos_p_hat_p_mask_mean=0.084720 | cos_p_hat_q_mean=0.992132 | cos_p_hat_z_gt_mean=0.070147 | cos_p_mask_z_gt_mean=0.072929 | delta_norm_mean=4.598394 | gate_mean=0.625537 | gate_std=0.054432 | p_hat_norm_mean=36.239987 | p_mask_norm_mean=0.854909 | q_norm_mean=36.239987 | z_gt_norm_mean=1.275222
|
| 8 |
-
valuate on test_u_refer: miou 0.7347305961538223 true fscore 0.8193065231665969
|
| 9 |
-
bridge on test_u_refer: cos_delta_p_mask_mean=0.754954 | cos_delta_q_mean=-0.054195 | cos_delta_z_gt_mean=0.089436 | cos_p_hat_p_mask_mean=0.077127 | cos_p_hat_q_mean=0.994077 | cos_p_hat_z_gt_mean=0.023352 | cos_p_mask_z_gt_mean=0.081724 | delta_norm_mean=3.370293 | gate_mean=0.544416 | gate_std=0.033540 | p_hat_norm_mean=30.852975 | p_mask_norm_mean=0.855032 | q_norm_mean=30.852975 | z_gt_norm_mean=0.969146
|
| 10 |
-
|
| 11 |
-
valuate on test_n_refer: metric 0.12181796133518219
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_probe_eval_directional_pm_only_a02_step0.txt
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 0.01385641098022461 Learning Rate:0.000000
|
| 2 |
-
valuate on test_s_refer: miou 0.7251643069144439 true fscore 0.8044421944022179
|
| 3 |
-
bridge on test_s_refer: cos_delta_p_mask_mean=0.754565 | cos_delta_q_mean=-0.062169 | cos_delta_z_gt_mean=0.077297 | cos_p_hat_p_mask_mean=0.084709 | cos_p_hat_q_mean=0.992133 | cos_p_hat_z_gt_mean=0.070145 | cos_p_mask_z_gt_mean=0.072929 | delta_norm_mean=4.598003 | gate_mean=0.625515 | gate_std=0.054416 | p_hat_norm_mean=36.238429 | p_mask_norm_mean=0.854909 | q_norm_mean=36.238428 | z_gt_norm_mean=1.275222
|
| 4 |
-
valuate on test_u_refer: miou 0.7346898949889146 true fscore 0.819309664927423
|
| 5 |
-
bridge on test_u_refer: cos_delta_p_mask_mean=0.754958 | cos_delta_q_mean=-0.054197 | cos_delta_z_gt_mean=0.089438 | cos_p_hat_p_mask_mean=0.077138 | cos_p_hat_q_mean=0.994077 | cos_p_hat_z_gt_mean=0.023334 | cos_p_mask_z_gt_mean=0.081724 | delta_norm_mean=3.370548 | gate_mean=0.544434 | gate_std=0.033514 | p_hat_norm_mean=30.854847 | p_mask_norm_mean=0.855032 | q_norm_mean=30.854847 | z_gt_norm_mean=0.969146
|
| 6 |
-
|
| 7 |
-
valuate on test_n_refer: metric 0.12185448408126831
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log/rpb_probe_mixed_pm_only_a02_wm005_s80.txt
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
Epoch 0: running_loss 0.11956256674602628 Learning Rate:0.000048
|
| 2 |
-
Epoch 1: running_loss 0.059521447168663144 Learning Rate:0.000038
|
| 3 |
-
Epoch 2: running_loss 0.03955021120297412 Learning Rate:0.000021
|
| 4 |
-
Epoch 3: running_loss 0.029611277248477563 Learning Rate:0.000006
|
| 5 |
-
Epoch 4: running_loss 0.023673650273121894 Learning Rate:0.000000
|
| 6 |
-
valuate on test_s_refer: miou 0.7234249453799384 true fscore 0.8020988971926272
|
| 7 |
-
bridge on test_s_refer: cos_delta_p_mask_mean=0.752115 | cos_delta_q_mean=-0.071252 | cos_delta_z_gt_mean=0.081856 | cos_p_hat_p_mask_mean=0.098034 | cos_p_hat_q_mean=0.989714 | cos_p_hat_z_gt_mean=0.072197 | cos_p_mask_z_gt_mean=0.072929 | delta_norm_mean=5.254162 | gate_mean=0.718218 | gate_std=0.053861 | p_hat_norm_mean=36.239985 | p_mask_norm_mean=0.854909 | q_norm_mean=36.239985 | z_gt_norm_mean=1.275222
|
| 8 |
-
valuate on test_u_refer: miou 0.7361468947966933 true fscore 0.8214005154371261
|
| 9 |
-
bridge on test_u_refer: cos_delta_p_mask_mean=0.754059 | cos_delta_q_mean=-0.063183 | cos_delta_z_gt_mean=0.096618 | cos_p_hat_p_mask_mean=0.090575 | cos_p_hat_q_mean=0.991959 | cos_p_hat_z_gt_mean=0.025874 | cos_p_mask_z_gt_mean=0.081724 | delta_norm_mean=3.926547 | gate_mean=0.635724 | gate_std=0.036734 | p_hat_norm_mean=30.848887 | p_mask_norm_mean=0.855032 | q_norm_mean=30.848887 | z_gt_norm_mean=0.969146
|
| 10 |
-
|
| 11 |
-
valuate on test_n_refer: metric 0.12358559668064117
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
seg_ltpo.py
DELETED
|
@@ -1,1372 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
SEG-LTPO: test-time optimization of SimToken's Fseg / q prompt token.
|
| 3 |
-
|
| 4 |
-
Two optimizers are provided:
|
| 5 |
-
|
| 6 |
-
ltpo_optimize – original antithetic-ES zeroth-order optimizer (Fseg space).
|
| 7 |
-
q_ltpo_autograd – autograd optimizer that directly optimizes q (= sparse
|
| 8 |
-
prompt embedding passed to the mask decoder) via Adam
|
| 9 |
-
maximize, with a differentiable reward. This is the
|
| 10 |
-
recommended path when the reward can be made differentiable.
|
| 11 |
-
|
| 12 |
-
Staged autograd reward build-up:
|
| 13 |
-
Stage 0 check_grad_connectivity — verify ∂R_iou/∂q ≠ 0
|
| 14 |
-
Stage 1 QLTPOConfig(stage=1) — R = 0.6·R_iou − 0.2·R_area_soft − λ_reg·‖q−q₀‖²
|
| 15 |
-
Stage 2 QLTPOConfig(stage=2) — Stage 1 + 1.0·R_align_det (z_in/z_out stopgrad)
|
| 16 |
-
Stage 3 QLTPOConfig(stage=3) — Stage 2 + 0.2·R_temp_feat (full reward)
|
| 17 |
-
|
| 18 |
-
Reward gating: use best_q only when R_task(best_q) > R_task(q_init) + gate_delta.
|
| 19 |
-
|
| 20 |
-
--- ES baseline (original) ---
|
| 21 |
-
Reward:
|
| 22 |
-
R = λ1·R_temp_feat + λ2·R_iou_pred + λ3·R_align_contrast − λ4·R_area
|
| 23 |
-
Update (antithetic ES, step t):
|
| 24 |
-
F_curr = F_curr + η_t · (R+ − R−)/(2σ_t²) · eps_t
|
| 25 |
-
best_F = argmax_F R(F)
|
| 26 |
-
"""
|
| 27 |
-
|
| 28 |
-
from __future__ import annotations
|
| 29 |
-
|
| 30 |
-
from dataclasses import dataclass, field
|
| 31 |
-
from typing import Any, Dict, List, Optional, Tuple
|
| 32 |
-
|
| 33 |
-
import torch
|
| 34 |
-
import torch.nn.functional as F
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
# ---------------------------------------------------------------------------
|
| 38 |
-
# Per-sample diagnostics accumulator for q_ltpo_autograd
|
| 39 |
-
# ---------------------------------------------------------------------------
|
| 40 |
-
|
| 41 |
-
_q_ltpo_stats: List[Dict[str, Any]] = []
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def reset_q_ltpo_stats() -> None:
|
| 45 |
-
global _q_ltpo_stats
|
| 46 |
-
_q_ltpo_stats = []
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def get_q_ltpo_stats() -> List[Dict[str, Any]]:
|
| 50 |
-
return list(_q_ltpo_stats)
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
# ---------------------------------------------------------------------------
|
| 54 |
-
# Configuration
|
| 55 |
-
# ---------------------------------------------------------------------------
|
| 56 |
-
|
| 57 |
-
@dataclass
|
| 58 |
-
class LTPOConfig:
|
| 59 |
-
T: int = 5
|
| 60 |
-
num_anchors: int = 4
|
| 61 |
-
sigma_schedule: List[float] = field(
|
| 62 |
-
default_factory=lambda: [0.10, 0.08, 0.06, 0.04, 0.02]
|
| 63 |
-
)
|
| 64 |
-
eta_scale: float = 0.5 # η_t = eta_scale · σ_t
|
| 65 |
-
|
| 66 |
-
# Reward weights
|
| 67 |
-
lambda1: float = 0.3 # R_temp_feat
|
| 68 |
-
lambda2: float = 0.4 # R_iou_pred
|
| 69 |
-
lambda3: float = 1.0 # R_align_contrast
|
| 70 |
-
lambda4: float = 0.3 # R_area penalty
|
| 71 |
-
|
| 72 |
-
beta: float = 0.5 # background penalty coefficient in R_align_contrast
|
| 73 |
-
|
| 74 |
-
# Reward gating: fall back to F_init when improvement < gate_delta
|
| 75 |
-
gate_delta: float = 0.0
|
| 76 |
-
|
| 77 |
-
# L2 trust-region radius on Fseg; None = disabled
|
| 78 |
-
trust_delta: Optional[float] = None
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
# ---------------------------------------------------------------------------
|
| 82 |
-
# Utilities
|
| 83 |
-
# ---------------------------------------------------------------------------
|
| 84 |
-
|
| 85 |
-
def get_sam_model(model):
|
| 86 |
-
"""Return SAM visual_model, unwrapping a PeftModel wrapper if present."""
|
| 87 |
-
base = model.base_model.model if hasattr(model, "base_model") else model
|
| 88 |
-
return base.model.visual_model
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
def get_anchor_indices(num_frames: int, num_anchors: int) -> List[int]:
|
| 92 |
-
"""Uniformly sample anchor frame indices from [0, num_frames-1]."""
|
| 93 |
-
return [round(v) for v in torch.linspace(0, num_frames - 1, num_anchors).tolist()]
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
def _precompute_dense_emb(
|
| 97 |
-
sam_model, model_dtype: torch.dtype, device: torch.device
|
| 98 |
-
) -> torch.Tensor:
|
| 99 |
-
"""
|
| 100 |
-
Constant 'no-mask' dense embedding from SAM's prompt encoder.
|
| 101 |
-
Independent of Fseg; precompute once per sample to avoid redundant calls.
|
| 102 |
-
Shape: [1, 256, 64, 64].
|
| 103 |
-
"""
|
| 104 |
-
pe = sam_model.prompt_encoder
|
| 105 |
-
H, W = pe.image_embedding_size
|
| 106 |
-
return (
|
| 107 |
-
pe.no_mask_embed.weight # [1, 256]
|
| 108 |
-
.reshape(1, -1, 1, 1)
|
| 109 |
-
.expand(1, -1, H, W)
|
| 110 |
-
.contiguous()
|
| 111 |
-
.to(model_dtype)
|
| 112 |
-
.to(device)
|
| 113 |
-
)
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
# ---------------------------------------------------------------------------
|
| 117 |
-
# Lightweight SAM decode (skips prompt_encoder overhead)
|
| 118 |
-
# ---------------------------------------------------------------------------
|
| 119 |
-
|
| 120 |
-
def _decode_on_anchors(
|
| 121 |
-
fseg: torch.Tensor, # [1, 256] float32
|
| 122 |
-
image_embeds_anchor: torch.Tensor, # [A, 256, 64, 64] model dtype
|
| 123 |
-
dense_emb: torch.Tensor, # [1, 256, 64, 64] model dtype (constant)
|
| 124 |
-
mask_decoder,
|
| 125 |
-
dense_pe: torch.Tensor, # [1, 256, 64, 64]
|
| 126 |
-
model_dtype: torch.dtype,
|
| 127 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 128 |
-
"""
|
| 129 |
-
Decode anchor frames for a given Fseg.
|
| 130 |
-
|
| 131 |
-
Since no points/boxes are used, prompt_encoder simply concatenates
|
| 132 |
-
text_embeds onto an empty sparse tensor, so sparse_emb == Fseg.unsqueeze(1).
|
| 133 |
-
We exploit this to skip the full prompt_encoder call each iteration.
|
| 134 |
-
|
| 135 |
-
Returns:
|
| 136 |
-
low_res_masks: [A, 1, 256, 256]
|
| 137 |
-
iou_preds: [A, 1]
|
| 138 |
-
"""
|
| 139 |
-
sparse_emb = fseg.to(model_dtype).unsqueeze(1) # [1, 1, 256]
|
| 140 |
-
with torch.no_grad():
|
| 141 |
-
low_res_masks, iou_preds = mask_decoder(
|
| 142 |
-
image_embeddings=image_embeds_anchor,
|
| 143 |
-
image_pe=dense_pe,
|
| 144 |
-
sparse_prompt_embeddings=sparse_emb,
|
| 145 |
-
dense_prompt_embeddings=dense_emb,
|
| 146 |
-
multimask_output=False,
|
| 147 |
-
)
|
| 148 |
-
return low_res_masks, iou_preds # [A,1,256,256], [A,1]
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
# ---------------------------------------------------------------------------
|
| 152 |
-
# Reward computation
|
| 153 |
-
# ---------------------------------------------------------------------------
|
| 154 |
-
|
| 155 |
-
def _compute_reward(
|
| 156 |
-
fseg: torch.Tensor, # [1, 256] float32
|
| 157 |
-
low_res_masks: torch.Tensor, # [A, 1, 256, 256]
|
| 158 |
-
iou_preds: torch.Tensor, # [A, 1]
|
| 159 |
-
image_embeds_anchor: torch.Tensor, # [A, 256, 64, 64]
|
| 160 |
-
cfg: LTPOConfig,
|
| 161 |
-
) -> float:
|
| 162 |
-
num_anchor = low_res_masks.shape[0]
|
| 163 |
-
device = fseg.device
|
| 164 |
-
|
| 165 |
-
# Work entirely in float32 for numerical stability
|
| 166 |
-
masks_soft = torch.sigmoid(low_res_masks.float().squeeze(1)) # [A, 256, 256]
|
| 167 |
-
img_embs = image_embeds_anchor.float() # [A, 256, 64, 64]
|
| 168 |
-
|
| 169 |
-
# q lives in SAM's 256-d prompt space (same as Fseg after text_hidden_fcs)
|
| 170 |
-
q = F.normalize(fseg[0].float(), dim=0) # [256]
|
| 171 |
-
|
| 172 |
-
# Downsample soft masks 256×256 → 64×64 to match image_embed spatial dims.
|
| 173 |
-
# Keep as soft weights (no hard threshold) so the reward surface is smooth.
|
| 174 |
-
masks_64 = F.interpolate(
|
| 175 |
-
masks_soft.unsqueeze(1), size=(64, 64),
|
| 176 |
-
mode="bilinear", align_corners=False,
|
| 177 |
-
).squeeze(1) # [A, 64, 64]
|
| 178 |
-
|
| 179 |
-
# ── Per-frame masked pooling ──────────────────────────────────────────
|
| 180 |
-
z_ins: List[torch.Tensor] = []
|
| 181 |
-
z_outs: List[torch.Tensor] = []
|
| 182 |
-
for t in range(num_anchor):
|
| 183 |
-
m = masks_64[t] # [64, 64]
|
| 184 |
-
img = img_embs[t] # [256, 64, 64]
|
| 185 |
-
|
| 186 |
-
# Soft weighted average pooling over foreground / background
|
| 187 |
-
z_in = (img * m.unsqueeze(0)).sum(dim=[1, 2]) / (m.sum() + 1e-6)
|
| 188 |
-
z_out = (img * (1.0 - m).unsqueeze(0)).sum(dim=[1, 2]) / ((1.0 - m).sum() + 1e-6)
|
| 189 |
-
|
| 190 |
-
z_ins.append(F.normalize(z_in, dim=0)) # [256]
|
| 191 |
-
z_outs.append(F.normalize(z_out, dim=0)) # [256]
|
| 192 |
-
|
| 193 |
-
# ── R_align_contrast ──────────────────────────────────────────────────
|
| 194 |
-
# Maximise Fseg↔inside alignment while penalising Fseg↔outside alignment.
|
| 195 |
-
# Contrast term prevents reward-hacking via large masks:
|
| 196 |
-
# a large mask pulls inside and outside features together, shrinking the gap.
|
| 197 |
-
r_align = sum(
|
| 198 |
-
(q @ z_ins[t]) - cfg.beta * (q @ z_outs[t])
|
| 199 |
-
for t in range(num_anchor)
|
| 200 |
-
) / num_anchor
|
| 201 |
-
|
| 202 |
-
# ── R_iou_pred ────────────────────────────────────────────────────────
|
| 203 |
-
# SAM's internal mask-quality head, calibrated during SAM training.
|
| 204 |
-
r_iou = iou_preds.float().mean()
|
| 205 |
-
|
| 206 |
-
# ── R_temp_feat ───────────────────────────────────────────────────────
|
| 207 |
-
# Feature-space consistency between adjacent anchor frames.
|
| 208 |
-
# Harder to game than mask-IoU: large masks pool diverse background
|
| 209 |
-
# features across frames, degrading cosine similarity.
|
| 210 |
-
r_temp = torch.tensor(0.0, device=device)
|
| 211 |
-
if num_anchor > 1:
|
| 212 |
-
r_temp = sum(
|
| 213 |
-
z_ins[t] @ z_ins[t + 1] for t in range(num_anchor - 1)
|
| 214 |
-
) / (num_anchor - 1)
|
| 215 |
-
|
| 216 |
-
# ── R_area ────────────────────────────────────────────────────────────
|
| 217 |
-
r_area = masks_64.mean()
|
| 218 |
-
|
| 219 |
-
R = (cfg.lambda1 * r_temp
|
| 220 |
-
+ cfg.lambda2 * r_iou
|
| 221 |
-
+ cfg.lambda3 * r_align
|
| 222 |
-
- cfg.lambda4 * r_area)
|
| 223 |
-
|
| 224 |
-
return R.item()
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
# ---------------------------------------------------------------------------
|
| 228 |
-
# Ablation baseline: Best-of-2 Random (no iterative update)
|
| 229 |
-
# ---------------------------------------------------------------------------
|
| 230 |
-
|
| 231 |
-
def best_of_2_optimize(
|
| 232 |
-
F_init: torch.Tensor,
|
| 233 |
-
image_embeds: torch.Tensor,
|
| 234 |
-
anchor_indices: List[int],
|
| 235 |
-
sam_model,
|
| 236 |
-
model_dtype: torch.dtype,
|
| 237 |
-
cfg: LTPOConfig,
|
| 238 |
-
) -> torch.Tensor:
|
| 239 |
-
"""
|
| 240 |
-
Best-of-2 Random baseline.
|
| 241 |
-
|
| 242 |
-
Sample one antithetic pair (F+, F-) using the first sigma value,
|
| 243 |
-
evaluate both, return whichever has the higher reward.
|
| 244 |
-
No iterative update — serves as the ablation for the update rule.
|
| 245 |
-
Same reward gating as ltpo_optimize for a fair comparison.
|
| 246 |
-
"""
|
| 247 |
-
device = F_init.device
|
| 248 |
-
image_embeds_anchor = image_embeds[anchor_indices]
|
| 249 |
-
|
| 250 |
-
dense_emb = _precompute_dense_emb(sam_model, model_dtype, device)
|
| 251 |
-
dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device)
|
| 252 |
-
mask_dec = sam_model.mask_decoder
|
| 253 |
-
|
| 254 |
-
lrm0, iou0 = _decode_on_anchors(
|
| 255 |
-
F_init, image_embeds_anchor, dense_emb, mask_dec, dense_pe, model_dtype
|
| 256 |
-
)
|
| 257 |
-
R_init = _compute_reward(F_init, lrm0, iou0, image_embeds_anchor, cfg)
|
| 258 |
-
|
| 259 |
-
sigma = cfg.sigma_schedule[0]
|
| 260 |
-
eps = torch.randn_like(F_init) * sigma
|
| 261 |
-
F_plus = F_init + eps
|
| 262 |
-
F_minus = F_init - eps
|
| 263 |
-
|
| 264 |
-
lrm_p, iou_p = _decode_on_anchors(
|
| 265 |
-
F_plus, image_embeds_anchor, dense_emb, mask_dec, dense_pe, model_dtype
|
| 266 |
-
)
|
| 267 |
-
lrm_m, iou_m = _decode_on_anchors(
|
| 268 |
-
F_minus, image_embeds_anchor, dense_emb, mask_dec, dense_pe, model_dtype
|
| 269 |
-
)
|
| 270 |
-
R_plus = _compute_reward(F_plus, lrm_p, iou_p, image_embeds_anchor, cfg)
|
| 271 |
-
R_minus = _compute_reward(F_minus, lrm_m, iou_m, image_embeds_anchor, cfg)
|
| 272 |
-
|
| 273 |
-
best_R, best_F = R_init, F_init.clone()
|
| 274 |
-
if R_plus > best_R: best_R, best_F = R_plus, F_plus.clone()
|
| 275 |
-
if R_minus > best_R: best_R, best_F = R_minus, F_minus.clone()
|
| 276 |
-
|
| 277 |
-
if best_R <= R_init + cfg.gate_delta:
|
| 278 |
-
return F_init
|
| 279 |
-
return best_F
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
# ---------------------------------------------------------------------------
|
| 283 |
-
# Full-video decode with a given Fseg
|
| 284 |
-
# ---------------------------------------------------------------------------
|
| 285 |
-
|
| 286 |
-
def _sobel_edge(rgb_frames: torch.Tensor) -> torch.Tensor:
|
| 287 |
-
"""Compute Sobel edge magnitude from normalized RGB frames.
|
| 288 |
-
|
| 289 |
-
Args:
|
| 290 |
-
rgb_frames: [T, 3, H, W] float32 (SAM-normalized, CUDA)
|
| 291 |
-
Returns:
|
| 292 |
-
edge: [T, 1, H, W] float32, non-negative
|
| 293 |
-
"""
|
| 294 |
-
gray = rgb_frames.float().mean(dim=1, keepdim=True) # [T, 1, H, W]
|
| 295 |
-
kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
|
| 296 |
-
dtype=torch.float32, device=rgb_frames.device).view(1, 1, 3, 3)
|
| 297 |
-
ky = kx.transpose(2, 3)
|
| 298 |
-
gx = F.conv2d(gray, kx, padding=1)
|
| 299 |
-
gy = F.conv2d(gray, ky, padding=1)
|
| 300 |
-
return torch.sqrt(gx ** 2 + gy ** 2 + 1e-6) # [T, 1, H, W]
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
def _boundary_edge_score(
|
| 304 |
-
low_res_masks: torch.Tensor, # [T, K, 256, 256] logits
|
| 305 |
-
rgb_frames: torch.Tensor, # [T, 3, H, W] float32
|
| 306 |
-
resize: tuple, # (H_resized, W_resized)
|
| 307 |
-
area_temp: float = 5.0,
|
| 308 |
-
) -> torch.Tensor:
|
| 309 |
-
"""Score each of K mask candidates by boundary-edge alignment.
|
| 310 |
-
|
| 311 |
-
R_edge = <soft_boundary_band, Sobel_edge> / (sum(soft_boundary_band) + ε)
|
| 312 |
-
Rewards masks whose boundaries coincide with image edges.
|
| 313 |
-
|
| 314 |
-
Returns: [T, K] float32 scores (higher = better boundary alignment)
|
| 315 |
-
"""
|
| 316 |
-
T, K = low_res_masks.shape[:2]
|
| 317 |
-
H_r, W_r = resize
|
| 318 |
-
|
| 319 |
-
# Upsample all candidates to resized image resolution at once
|
| 320 |
-
masks_up = F.interpolate(
|
| 321 |
-
low_res_masks.reshape(T * K, 1, 256, 256).float(),
|
| 322 |
-
size=(H_r, W_r), mode="bilinear", align_corners=False,
|
| 323 |
-
).reshape(T, K, H_r, W_r) # [T, K, H, W]
|
| 324 |
-
|
| 325 |
-
E = _sobel_edge(rgb_frames[:, :, :H_r, :W_r]) # [T, 1, H, W]
|
| 326 |
-
|
| 327 |
-
m = torch.sigmoid(masks_up / area_temp) # [T, K, H, W]
|
| 328 |
-
b = 4.0 * m * (1.0 - m) # soft boundary band
|
| 329 |
-
num = (b * E.squeeze(1).unsqueeze(1)).sum(dim=[2, 3]) # [T, K]
|
| 330 |
-
den = b.sum(dim=[2, 3]) + 1e-6
|
| 331 |
-
return num / den # [T, K]
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
def decode_full_video(
|
| 335 |
-
fseg: torch.Tensor, # [1, 256] float32
|
| 336 |
-
image_embeds: torch.Tensor, # [T, 256, 64, 64] model dtype on CUDA
|
| 337 |
-
sam_model,
|
| 338 |
-
resize: tuple, # (H_resized, W_resized)
|
| 339 |
-
orgsize: tuple, # (H_orig, W_orig)
|
| 340 |
-
model_dtype: torch.dtype,
|
| 341 |
-
rgb_frames: Optional[torch.Tensor] = None, # [T, 3, H, W]; enables edge selection
|
| 342 |
-
multimask: bool = False, # True = 3 candidates; False = single mask
|
| 343 |
-
) -> torch.Tensor:
|
| 344 |
-
"""Decode all T frames with the given Fseg.
|
| 345 |
-
|
| 346 |
-
Selection logic (applied per-frame):
|
| 347 |
-
- multimask=False, rgb_frames=None : original single-mask decode (baseline)
|
| 348 |
-
- multimask=True, rgb_frames=None : 3 candidates, select by SAM iou_pred
|
| 349 |
-
- multimask=True, rgb_frames=* : 3 candidates, select by boundary-edge score
|
| 350 |
-
(boundary band × Sobel edge; directly rewards boundary-image alignment)
|
| 351 |
-
|
| 352 |
-
Returns raw logit mask [T, H_orig, W_orig] (not yet sigmoid).
|
| 353 |
-
"""
|
| 354 |
-
device = image_embeds.device
|
| 355 |
-
dense_emb = _precompute_dense_emb(sam_model, model_dtype, device)
|
| 356 |
-
dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device)
|
| 357 |
-
sparse_emb = fseg.to(model_dtype).unsqueeze(1) # [1, 1, 256]
|
| 358 |
-
|
| 359 |
-
with torch.no_grad():
|
| 360 |
-
low_res_masks, iou_preds = sam_model.mask_decoder(
|
| 361 |
-
image_embeddings=image_embeds,
|
| 362 |
-
image_pe=dense_pe,
|
| 363 |
-
sparse_prompt_embeddings=sparse_emb,
|
| 364 |
-
dense_prompt_embeddings=dense_emb,
|
| 365 |
-
multimask_output=multimask,
|
| 366 |
-
) # [T, K, 256, 256], [T, K] where K=1 or K=3
|
| 367 |
-
|
| 368 |
-
if multimask:
|
| 369 |
-
T = low_res_masks.shape[0]
|
| 370 |
-
if rgb_frames is not None:
|
| 371 |
-
# Step 1b: boundary-edge score selects best candidate
|
| 372 |
-
scores = _boundary_edge_score(low_res_masks, rgb_frames, resize)
|
| 373 |
-
else:
|
| 374 |
-
# Step 1a: SAM's own iou_pred selects best candidate
|
| 375 |
-
scores = iou_preds
|
| 376 |
-
best_idx = scores.argmax(dim=1) # [T]
|
| 377 |
-
low_res_masks = low_res_masks[torch.arange(T, device=device), best_idx].unsqueeze(1)
|
| 378 |
-
|
| 379 |
-
pred_mask = sam_model.postprocess_masks(
|
| 380 |
-
low_res_masks, input_size=resize, original_size=orgsize
|
| 381 |
-
) # [T, 1, H, W]
|
| 382 |
-
return pred_mask.squeeze(1) # [T, H, W]
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
# ---------------------------------------------------------------------------
|
| 386 |
-
# Main optimisation loop
|
| 387 |
-
# ---------------------------------------------------------------------------
|
| 388 |
-
|
| 389 |
-
def ltpo_optimize(
|
| 390 |
-
F_init: torch.Tensor, # [1, 256] float32 on CUDA
|
| 391 |
-
image_embeds: torch.Tensor, # [T, 256, 64, 64] model dtype on CUDA
|
| 392 |
-
anchor_indices: List[int],
|
| 393 |
-
sam_model,
|
| 394 |
-
model_dtype: torch.dtype,
|
| 395 |
-
cfg: LTPOConfig,
|
| 396 |
-
) -> torch.Tensor:
|
| 397 |
-
"""
|
| 398 |
-
Optimise Fseg at test time via antithetic ES.
|
| 399 |
-
|
| 400 |
-
Returns best Fseg found [1, 256] float32.
|
| 401 |
-
Falls back to F_init when reward gating rejects all updates.
|
| 402 |
-
"""
|
| 403 |
-
device = F_init.device
|
| 404 |
-
image_embeds_anchor = image_embeds[anchor_indices] # [A, 256, 64, 64]
|
| 405 |
-
|
| 406 |
-
# Precompute constants shared across every optimisation step
|
| 407 |
-
dense_emb = _precompute_dense_emb(sam_model, model_dtype, device)
|
| 408 |
-
dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device)
|
| 409 |
-
mask_dec = sam_model.mask_decoder
|
| 410 |
-
|
| 411 |
-
# ── Evaluate initial token ────────────────────────────────────────────
|
| 412 |
-
lrm0, iou0 = _decode_on_anchors(
|
| 413 |
-
F_init, image_embeds_anchor, dense_emb, mask_dec, dense_pe, model_dtype
|
| 414 |
-
)
|
| 415 |
-
R_init = _compute_reward(F_init, lrm0, iou0, image_embeds_anchor, cfg)
|
| 416 |
-
|
| 417 |
-
best_F, best_R = F_init.clone(), R_init
|
| 418 |
-
F_curr = F_init.clone()
|
| 419 |
-
|
| 420 |
-
# ── Optimisation loop ─────────────────────────────────────────────────
|
| 421 |
-
for t in range(cfg.T):
|
| 422 |
-
sigma_t = cfg.sigma_schedule[t]
|
| 423 |
-
eta_t = cfg.eta_scale * sigma_t
|
| 424 |
-
|
| 425 |
-
eps = torch.randn_like(F_curr) * sigma_t
|
| 426 |
-
F_plus = F_curr + eps
|
| 427 |
-
F_minus = F_curr - eps
|
| 428 |
-
|
| 429 |
-
lrm_p, iou_p = _decode_on_anchors(
|
| 430 |
-
F_plus, image_embeds_anchor, dense_emb, mask_dec, dense_pe, model_dtype
|
| 431 |
-
)
|
| 432 |
-
lrm_m, iou_m = _decode_on_anchors(
|
| 433 |
-
F_minus, image_embeds_anchor, dense_emb, mask_dec, dense_pe, model_dtype
|
| 434 |
-
)
|
| 435 |
-
|
| 436 |
-
R_plus = _compute_reward(F_plus, lrm_p, iou_p, image_embeds_anchor, cfg)
|
| 437 |
-
R_minus = _compute_reward(F_minus, lrm_m, iou_m, image_embeds_anchor, cfg)
|
| 438 |
-
|
| 439 |
-
# Track the best token seen across all evaluated candidates
|
| 440 |
-
if R_plus > best_R:
|
| 441 |
-
best_R, best_F = R_plus, F_plus.clone()
|
| 442 |
-
if R_minus > best_R:
|
| 443 |
-
best_R, best_F = R_minus, F_minus.clone()
|
| 444 |
-
|
| 445 |
-
# Antithetic policy-gradient update of the iterate
|
| 446 |
-
# Formula: F_{t+1} = F_t + η_t · (R+ - R−)/(2σ_t²) · eps_t
|
| 447 |
-
grad_est = (R_plus - R_minus) / (2.0 * sigma_t ** 2)
|
| 448 |
-
F_curr = F_curr + eta_t * grad_est * eps
|
| 449 |
-
|
| 450 |
-
# Optional L2 trust-region: keep F_curr within radius trust_delta of F_init
|
| 451 |
-
if cfg.trust_delta is not None:
|
| 452 |
-
diff = F_curr - F_init
|
| 453 |
-
norm = diff.norm()
|
| 454 |
-
if norm > cfg.trust_delta:
|
| 455 |
-
F_curr = F_init + diff * (cfg.trust_delta / norm)
|
| 456 |
-
|
| 457 |
-
# ── Reward gating ─────────────────────────────────────────────────────
|
| 458 |
-
# Reject the update when there is no meaningful improvement over the
|
| 459 |
-
# initial token (handles Null-like samples where no target exists).
|
| 460 |
-
if best_R <= R_init + cfg.gate_delta:
|
| 461 |
-
return F_init
|
| 462 |
-
return best_F
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
# ===========================================================================
|
| 466 |
-
# q-LTPO-autograd: differentiable test-time optimization of the prompt token
|
| 467 |
-
# ===========================================================================
|
| 468 |
-
|
| 469 |
-
@dataclass
|
| 470 |
-
class QLTPOConfig:
|
| 471 |
-
"""Configuration for q_ltpo_autograd (Stages 1–3 + Stage 2-ext variants).
|
| 472 |
-
|
| 473 |
-
stage controls which reward terms are active:
|
| 474 |
-
1 R_iou + R_area_soft + reg (baseline autograd)
|
| 475 |
-
2 Stage 1 + R_align_det (z_in/z_out stopgrad) (self-bootstrapped alignment)
|
| 476 |
-
3 Stage 2 + R_temp_feat (full reward)
|
| 477 |
-
21 Stage 1 + R_tether (P1a: tether probe) (frozen r_ref via q_init attn)
|
| 478 |
-
22 Stage 1 + R_faithful (P1b: faithful ext-ref) (z_in/z_out vs frozen r_ref)
|
| 479 |
-
"""
|
| 480 |
-
stage: int = 1
|
| 481 |
-
T: int = 5
|
| 482 |
-
num_anchors: int = 4
|
| 483 |
-
|
| 484 |
-
# ── Optimizer ──────────────────────────────────────────────────────────
|
| 485 |
-
# lr=0 → auto-set to 0.01 × RMS(q_init); any positive value is used directly
|
| 486 |
-
lr: float = 0.0
|
| 487 |
-
# max_drift=0 → auto-set to 0.5 × ‖q_init‖; any positive value is a hard radius
|
| 488 |
-
max_drift: float = 0.0
|
| 489 |
-
|
| 490 |
-
# ── Stage 1 reward weights ─────────────────────────────────────────────
|
| 491 |
-
lambda_iou: float = 0.6
|
| 492 |
-
lambda_area: float = 0.2
|
| 493 |
-
lambda_reg: float = 0.01
|
| 494 |
-
area_temp: float = 5.0 # sigmoid temperature for R_area_soft
|
| 495 |
-
|
| 496 |
-
# ── Stage 2 additional weights ─────────────────────────────────────────
|
| 497 |
-
lambda_align: float = 1.0
|
| 498 |
-
beta_align: float = 0.5 # background penalty coefficient in R_align
|
| 499 |
-
|
| 500 |
-
# ── Stage 3 additional weights ─────────────────────────────────────────
|
| 501 |
-
lambda_temp: float = 0.2
|
| 502 |
-
|
| 503 |
-
# ── Gating ─────────────────────────────────────────────────────────────
|
| 504 |
-
gate_delta: float = 0.0
|
| 505 |
-
|
| 506 |
-
# ── e0-modulated R_iou (principled Null-safety) ────────────────────────
|
| 507 |
-
# e0 = stopgrad(R_area_soft(q_init)): the initial soft-area fraction acts
|
| 508 |
-
# as an existence prior on the R_iou term.
|
| 509 |
-
# "none" → original behavior (e0 = 1, no modulation)
|
| 510 |
-
# "identity" → e0 = R_area_soft(q_init) [first version]
|
| 511 |
-
# "sqrt" → e0 = sqrt(R_area_soft(q_init) + e0_eps)
|
| 512 |
-
e0_modulation: str = "identity"
|
| 513 |
-
e0_eps: float = 1e-4 # epsilon for "sqrt" variant
|
| 514 |
-
|
| 515 |
-
# ── Stage 2-ext: external reference (stages 21 and 22) ────────────────
|
| 516 |
-
# r_ref = AttnPool(image_feats_anchor, q_init): frozen visual anchor derived
|
| 517 |
-
# from q_init's attention over SAM image features. Breaks Stage 2's
|
| 518 |
-
# self-confirming bias by providing a mask-independent teacher.
|
| 519 |
-
# r_ref_temp: softmax temperature for attention pooling (sqrt(256) = 16).
|
| 520 |
-
r_ref_temp: float = 16.0
|
| 521 |
-
|
| 522 |
-
# ── Direction B: boundary precision rewards ────────────────────────────
|
| 523 |
-
# B1: asymmetric area expansion penalty
|
| 524 |
-
# Only penalises growth beyond (1+τ)×e0; allows mask contraction.
|
| 525 |
-
# Targets the observed pattern where LTPO slightly expands masks into
|
| 526 |
-
# non-target regions (recall↑ but precision↓, hurting F-score).
|
| 527 |
-
# B2: boundary sharpness reward
|
| 528 |
-
# -mean(4m(1-m)) with temperature=1.0; rewards bimodal (certain)
|
| 529 |
-
# mask predictions, encouraging cleaner boundary predictions.
|
| 530 |
-
lambda_area_inc: float = 0.0 # B1 weight (0 = disabled)
|
| 531 |
-
area_inc_tau: float = 0.0 # B1 tolerance band: allow (1+τ)×e0
|
| 532 |
-
lambda_sharp: float = 0.0 # B2 weight (0 = disabled)
|
| 533 |
-
|
| 534 |
-
# ── Oracle Null-safety gate (analysis only; NOT for final method) ──────
|
| 535 |
-
# Derived from test-set distribution (Null area_hard ≈ 0.01, Seen ≈ 0.05)
|
| 536 |
-
# so must not be used in reported results. Set null_gate_delta=0 to disable.
|
| 537 |
-
null_area_threshold: float = 0.02 # hard area fraction below which guard activates
|
| 538 |
-
null_gate_delta: float = 0.0 # 0 = disabled; 0.05 = oracle experiment
|
| 539 |
-
|
| 540 |
-
# ── Direction II: Frame-adaptive token optimization (stage=4) ─────────
|
| 541 |
-
# q_t = q_global + delta_t, where delta_t is a per-anchor residual.
|
| 542 |
-
# Optimizes q_global and {delta_t} jointly with Adam.
|
| 543 |
-
# lambda_residual: soft L2 penalty on delta_t
|
| 544 |
-
# lambda_smooth_temp: temporal smoothness penalty on adjacent delta differences
|
| 545 |
-
# max_delta_drift_scale: per-anchor hard L2 clip = scale × ‖q_init‖
|
| 546 |
-
# Prevents individual anchors from wandering to a completely different visual mode.
|
| 547 |
-
# Keep << max_drift (0.5) so delta stays a "small frame correction" to q_global.
|
| 548 |
-
# 0.1 is tight (delta ≤ 20% of global drift budget), 0.3 is moderate.
|
| 549 |
-
lambda_residual: float = 0.001
|
| 550 |
-
lambda_smooth_temp: float = 0.0
|
| 551 |
-
max_delta_drift_scale: float = 0.1 # per-anchor clip = scale × ‖q_init‖
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
# ---------------------------------------------------------------------------
|
| 555 |
-
# e0 helper
|
| 556 |
-
# ---------------------------------------------------------------------------
|
| 557 |
-
|
| 558 |
-
def _compute_e0(r_area_soft_init: float, cfg: "QLTPOConfig") -> float:
|
| 559 |
-
"""Compute the existence-prior weight from the initial soft area."""
|
| 560 |
-
if cfg.e0_modulation == "identity":
|
| 561 |
-
return r_area_soft_init
|
| 562 |
-
if cfg.e0_modulation == "sqrt":
|
| 563 |
-
return (r_area_soft_init + cfg.e0_eps) ** 0.5
|
| 564 |
-
return 1.0 # "none"
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
# ---------------------------------------------------------------------------
|
| 568 |
-
# Differentiable anchor decode (float32 throughout; no torch.no_grad)
|
| 569 |
-
# ---------------------------------------------------------------------------
|
| 570 |
-
|
| 571 |
-
def _decode_on_anchors_diff(
|
| 572 |
-
q: torch.Tensor, # [1, 256] float32
|
| 573 |
-
image_embeds_anchor_fp32: torch.Tensor, # [A, 256, 64, 64] float32
|
| 574 |
-
dense_emb_fp32: torch.Tensor, # [1, 256, 64, 64] float32
|
| 575 |
-
mask_decoder,
|
| 576 |
-
dense_pe_fp32: torch.Tensor, # [1, 256, 64, 64] float32
|
| 577 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 578 |
-
"""Differentiable mask-decoder forward.
|
| 579 |
-
|
| 580 |
-
All inputs are float32 to avoid fp16 gradient truncation.
|
| 581 |
-
q may be a Parameter (requires_grad=True) or a plain detached tensor.
|
| 582 |
-
Returns low_res_masks [A,1,256,256] and iou_preds [A,1], both float32.
|
| 583 |
-
"""
|
| 584 |
-
sparse_emb = q.unsqueeze(1) # [1, 1, 256]
|
| 585 |
-
low_res_masks, iou_preds = mask_decoder(
|
| 586 |
-
image_embeddings=image_embeds_anchor_fp32,
|
| 587 |
-
image_pe=dense_pe_fp32,
|
| 588 |
-
sparse_prompt_embeddings=sparse_emb,
|
| 589 |
-
dense_prompt_embeddings=dense_emb_fp32,
|
| 590 |
-
multimask_output=False,
|
| 591 |
-
)
|
| 592 |
-
return low_res_masks, iou_preds # [A,1,256,256], [A,1]
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
# ---------------------------------------------------------------------------
|
| 596 |
-
# Differentiable reward components
|
| 597 |
-
# ---------------------------------------------------------------------------
|
| 598 |
-
|
| 599 |
-
def _task_reward_stage1(
|
| 600 |
-
lrm: torch.Tensor, # [A,1,256,256] float32
|
| 601 |
-
iou: torch.Tensor, # [A,1] float32
|
| 602 |
-
cfg: QLTPOConfig,
|
| 603 |
-
e0: float = 1.0,
|
| 604 |
-
) -> torch.Tensor:
|
| 605 |
-
"""Task reward (no regularization): used for best_q tracking and gating.
|
| 606 |
-
|
| 607 |
-
e0 is the stopgrad existence prior: R_area_soft(q_init) scaled via
|
| 608 |
-
cfg.e0_modulation. When e0 << 1 the iou term is suppressed, so the
|
| 609 |
-
optimizer sees only the area-penalty gradient and naturally tends toward
|
| 610 |
-
smaller (more conservative) masks — the correct behavior when the initial
|
| 611 |
-
prediction is near-empty (Null frames).
|
| 612 |
-
|
| 613 |
-
Optional boundary precision terms (Direction B):
|
| 614 |
-
B1 (lambda_area_inc > 0): asymmetric expansion penalty
|
| 615 |
-
-λ_inc · ReLU(r_area - (1+τ)·e0)
|
| 616 |
-
Penalises mask growth beyond the initial area (+ tolerance band τ).
|
| 617 |
-
e0 doubles as the stopgrad initial-area threshold — zero extra cost.
|
| 618 |
-
B2 (lambda_sharp > 0): boundary sharpness reward
|
| 619 |
-
-λ_sharp · mean(4m(1-m)) with m = sigmoid(lrm), temperature=1.0
|
| 620 |
-
Maximises bimodality of mask logits → cleaner boundary predictions.
|
| 621 |
-
"""
|
| 622 |
-
r_iou = iou.mean()
|
| 623 |
-
r_area = torch.sigmoid(lrm / cfg.area_temp).mean()
|
| 624 |
-
R = cfg.lambda_iou * e0 * r_iou - cfg.lambda_area * r_area
|
| 625 |
-
|
| 626 |
-
# B1: penalise expansion beyond (1+τ)×e0 (allow contraction freely)
|
| 627 |
-
if cfg.lambda_area_inc > 0.0:
|
| 628 |
-
area_ceil = (1.0 + cfg.area_inc_tau) * e0
|
| 629 |
-
R = R - cfg.lambda_area_inc * F.relu(r_area - area_ceil)
|
| 630 |
-
|
| 631 |
-
# B2: reward confident (bimodal) boundary predictions
|
| 632 |
-
if cfg.lambda_sharp > 0.0:
|
| 633 |
-
m_sharp = torch.sigmoid(lrm) # temperature=1.0 (sharp)
|
| 634 |
-
boundary_uncertain = 4.0 * m_sharp * (1.0 - m_sharp)
|
| 635 |
-
R = R - cfg.lambda_sharp * boundary_uncertain.mean()
|
| 636 |
-
|
| 637 |
-
return R
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
def _task_reward_stage2(
|
| 641 |
-
q: torch.Tensor, # [1, 256] float32
|
| 642 |
-
lrm: torch.Tensor, # [A,1,256,256] float32
|
| 643 |
-
iou: torch.Tensor, # [A,1] float32
|
| 644 |
-
image_embeds_anchor_fp32: torch.Tensor, # [A, 256, 64, 64] float32
|
| 645 |
-
cfg: QLTPOConfig,
|
| 646 |
-
e0: float = 1.0,
|
| 647 |
-
) -> torch.Tensor:
|
| 648 |
-
"""Stage 2 task reward: Stage 1 + R_align_det (z_in/z_out are stopgrad)."""
|
| 649 |
-
r_s1 = _task_reward_stage1(lrm, iou, cfg, e0)
|
| 650 |
-
|
| 651 |
-
A = lrm.shape[0]
|
| 652 |
-
masks_64 = F.interpolate(
|
| 653 |
-
torch.sigmoid(lrm.squeeze(1) / cfg.area_temp).unsqueeze(1),
|
| 654 |
-
size=(64, 64), mode="bilinear", align_corners=False,
|
| 655 |
-
).squeeze(1) # [A, 64, 64]
|
| 656 |
-
|
| 657 |
-
q_norm = F.normalize(q[0], dim=0) # [256]
|
| 658 |
-
r_align = torch.tensor(0.0, device=q.device)
|
| 659 |
-
for t in range(A):
|
| 660 |
-
m = masks_64[t].detach() # stopgrad on z_in/z_out
|
| 661 |
-
img = image_embeds_anchor_fp32[t] # [256, 64, 64]
|
| 662 |
-
z_in = F.normalize((img * m.unsqueeze(0)).sum(dim=[1, 2]) / (m.sum() + 1e-6), dim=0)
|
| 663 |
-
z_out = F.normalize((img * (1 - m).unsqueeze(0)).sum(dim=[1, 2]) / ((1 - m).sum() + 1e-6), dim=0)
|
| 664 |
-
r_align = r_align + q_norm @ z_in - cfg.beta_align * (q_norm @ z_out)
|
| 665 |
-
r_align = r_align / A
|
| 666 |
-
|
| 667 |
-
return r_s1 + cfg.lambda_align * r_align
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
def _task_reward_stage3(
|
| 671 |
-
q: torch.Tensor,
|
| 672 |
-
lrm: torch.Tensor,
|
| 673 |
-
iou: torch.Tensor,
|
| 674 |
-
image_embeds_anchor_fp32: torch.Tensor,
|
| 675 |
-
cfg: QLTPOConfig,
|
| 676 |
-
e0: float = 1.0,
|
| 677 |
-
) -> torch.Tensor:
|
| 678 |
-
"""Stage 3 task reward: Stage 2 + R_temp_feat."""
|
| 679 |
-
r_s2 = _task_reward_stage2(q, lrm, iou, image_embeds_anchor_fp32, cfg, e0)
|
| 680 |
-
|
| 681 |
-
A = lrm.shape[0]
|
| 682 |
-
if A < 2:
|
| 683 |
-
return r_s2
|
| 684 |
-
|
| 685 |
-
masks_64 = F.interpolate(
|
| 686 |
-
torch.sigmoid(lrm.squeeze(1) / cfg.area_temp).unsqueeze(1),
|
| 687 |
-
size=(64, 64), mode="bilinear", align_corners=False,
|
| 688 |
-
).squeeze(1) # [A, 64, 64]
|
| 689 |
-
|
| 690 |
-
z_ins = []
|
| 691 |
-
for t in range(A):
|
| 692 |
-
m = masks_64[t].detach()
|
| 693 |
-
img = image_embeds_anchor_fp32[t]
|
| 694 |
-
z_in = F.normalize((img * m.unsqueeze(0)).sum(dim=[1, 2]) / (m.sum() + 1e-6), dim=0)
|
| 695 |
-
z_ins.append(z_in)
|
| 696 |
-
|
| 697 |
-
r_temp = sum(z_ins[t] @ z_ins[t + 1] for t in range(A - 1)) / (A - 1)
|
| 698 |
-
return r_s2 + cfg.lambda_temp * r_temp
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
@torch.no_grad()
|
| 702 |
-
def _compute_r_ref(
|
| 703 |
-
q_init: torch.Tensor, # [1, 256] float32
|
| 704 |
-
image_embeds_anchor: torch.Tensor, # [A, 256, 64, 64] float32
|
| 705 |
-
temp: float = 16.0,
|
| 706 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 707 |
-
"""Frozen external visual reference via attention pooling guided by q_init.
|
| 708 |
-
|
| 709 |
-
r_ref: regions most attended by q_init (positive anchor).
|
| 710 |
-
r_neg: regions least attended by q_init (anti-attended negative).
|
| 711 |
-
Both are in the SAM 256d space — no projection needed.
|
| 712 |
-
Computed once before the optimization loop and kept fixed (stopgrad).
|
| 713 |
-
"""
|
| 714 |
-
img_flat = image_embeds_anchor.flatten(2) # [A, 256, H*W]
|
| 715 |
-
q_norm = F.normalize(q_init[0], dim=0) # [256]
|
| 716 |
-
img_norm = F.normalize(img_flat, dim=1) # [A, 256, H*W]
|
| 717 |
-
|
| 718 |
-
# cosine similarity between q and each spatial position
|
| 719 |
-
attn = torch.einsum('d,adp->ap', q_norm, img_norm) # [A, H*W]
|
| 720 |
-
|
| 721 |
-
attn_w_pos = torch.softmax( attn / temp, dim=-1) # [A, H*W]
|
| 722 |
-
attn_w_neg = torch.softmax(-attn / temp, dim=-1) # [A, H*W] anti-attended
|
| 723 |
-
|
| 724 |
-
# soft attention pooling in the original (non-normalized) feature space
|
| 725 |
-
r_ref_frames = torch.einsum('ap,adp->ad', attn_w_pos, img_flat) # [A, 256]
|
| 726 |
-
r_neg_frames = torch.einsum('ap,adp->ad', attn_w_neg, img_flat) # [A, 256]
|
| 727 |
-
|
| 728 |
-
r_ref = F.normalize(r_ref_frames.mean(0), dim=0) # [256]
|
| 729 |
-
r_neg = F.normalize(r_neg_frames.mean(0), dim=0) # [256]
|
| 730 |
-
return r_ref, r_neg
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
def _task_reward_stage2_tether(
|
| 734 |
-
q: torch.Tensor, # [1, 256] float32
|
| 735 |
-
lrm: torch.Tensor, # [A,1,256,256] float32
|
| 736 |
-
iou: torch.Tensor, # [A,1] float32
|
| 737 |
-
r_ref: torch.Tensor, # [256] frozen
|
| 738 |
-
r_neg: torch.Tensor, # [256] frozen
|
| 739 |
-
cfg: QLTPOConfig,
|
| 740 |
-
e0: float = 1.0,
|
| 741 |
-
) -> torch.Tensor:
|
| 742 |
-
"""Stage 21 (P1a tether): Stage 1 + R_tether.
|
| 743 |
-
|
| 744 |
-
R_tether = cos(q, r_ref) - beta·cos(q, r_neg)
|
| 745 |
-
q is pulled toward the frozen visual anchor without touching mask features.
|
| 746 |
-
Tests whether a fixed external reference stabilizes the optimization trajectory.
|
| 747 |
-
"""
|
| 748 |
-
r_s1 = _task_reward_stage1(lrm, iou, cfg, e0)
|
| 749 |
-
q_norm = F.normalize(q[0], dim=0)
|
| 750 |
-
r_tether = q_norm @ r_ref - cfg.beta_align * (q_norm @ r_neg)
|
| 751 |
-
return r_s1 + cfg.lambda_align * r_tether
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
def _task_reward_stage2_faithful(
|
| 755 |
-
q: torch.Tensor, # [1, 256] float32
|
| 756 |
-
lrm: torch.Tensor, # [A,1,256,256] float32
|
| 757 |
-
iou: torch.Tensor, # [A,1] float32
|
| 758 |
-
image_embeds_anchor_fp32: torch.Tensor, # [A, 256, 64, 64] float32
|
| 759 |
-
r_ref: torch.Tensor, # [256] frozen
|
| 760 |
-
cfg: QLTPOConfig,
|
| 761 |
-
e0: float = 1.0,
|
| 762 |
-
) -> torch.Tensor:
|
| 763 |
-
"""Stage 22 (P1b faithful): Stage 1 + R_faithful.
|
| 764 |
-
|
| 765 |
-
R_faithful = mean_t[ cos(z_in(q,t), r_ref) - beta·cos(z_out(q,t), r_ref) ]
|
| 766 |
-
z_in/z_out come from the *current* mask (change during optimization), but the
|
| 767 |
-
teacher r_ref is frozen — breaking Stage 2's self-confirming bias while keeping
|
| 768 |
-
the same structural form (mask-region vs. reference alignment).
|
| 769 |
-
"""
|
| 770 |
-
r_s1 = _task_reward_stage1(lrm, iou, cfg, e0)
|
| 771 |
-
A = lrm.shape[0]
|
| 772 |
-
masks_64 = F.interpolate(
|
| 773 |
-
torch.sigmoid(lrm.squeeze(1) / cfg.area_temp).unsqueeze(1),
|
| 774 |
-
size=(64, 64), mode="bilinear", align_corners=False,
|
| 775 |
-
).squeeze(1) # [A, 64, 64]
|
| 776 |
-
|
| 777 |
-
r_align = torch.tensor(0.0, device=q.device)
|
| 778 |
-
for t in range(A):
|
| 779 |
-
m = masks_64[t].detach() # stopgrad on mask weights only
|
| 780 |
-
img = image_embeds_anchor_fp32[t] # [256, 64, 64]
|
| 781 |
-
z_in = F.normalize((img * m.unsqueeze(0)).sum(dim=[1, 2]) / (m.sum() + 1e-6), dim=0)
|
| 782 |
-
z_out = F.normalize((img * (1 - m).unsqueeze(0)).sum(dim=[1, 2]) / ((1 - m).sum() + 1e-6), dim=0)
|
| 783 |
-
# teacher is r_ref (frozen), not z_in itself — no confirmation bias
|
| 784 |
-
r_align = r_align + z_in @ r_ref - cfg.beta_align * (z_out @ r_ref)
|
| 785 |
-
r_align = r_align / A
|
| 786 |
-
|
| 787 |
-
return r_s1 + cfg.lambda_align * r_align
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
def _decode_on_anchors_diff_adaptive(
|
| 791 |
-
q_global: torch.Tensor, # [1, 256] float32, requires_grad
|
| 792 |
-
delta: torch.Tensor, # [A, 256] float32, requires_grad
|
| 793 |
-
image_embeds_anchor_fp32: torch.Tensor, # [A, 256, 64, 64] float32, detached
|
| 794 |
-
dense_emb_fp32: torch.Tensor, # [1, 256, 64, 64] float32, detached
|
| 795 |
-
mask_decoder,
|
| 796 |
-
dense_pe_fp32: torch.Tensor, # [1, 256, 64, 64] float32, detached
|
| 797 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 798 |
-
"""Frame-adaptive differentiable decode: each anchor t uses q_t = q_global + delta[t].
|
| 799 |
-
|
| 800 |
-
Loops over A anchors to preserve gradient flow through both q_global and delta.
|
| 801 |
-
Returns low_res_masks [A,1,256,256] and iou_preds [A,1], both float32.
|
| 802 |
-
"""
|
| 803 |
-
A = image_embeds_anchor_fp32.shape[0]
|
| 804 |
-
lrm_list: List[torch.Tensor] = []
|
| 805 |
-
iou_list: List[torch.Tensor] = []
|
| 806 |
-
for t in range(A):
|
| 807 |
-
q_t = q_global + delta[t : t + 1] # [1, 256]
|
| 808 |
-
sparse_emb = q_t.unsqueeze(1) # [1, 1, 256]
|
| 809 |
-
lrm_t, iou_t = mask_decoder(
|
| 810 |
-
image_embeddings=image_embeds_anchor_fp32[t : t + 1],
|
| 811 |
-
image_pe=dense_pe_fp32,
|
| 812 |
-
sparse_prompt_embeddings=sparse_emb,
|
| 813 |
-
dense_prompt_embeddings=dense_emb_fp32,
|
| 814 |
-
multimask_output=False,
|
| 815 |
-
) # [1,1,256,256], [1,1]
|
| 816 |
-
lrm_list.append(lrm_t)
|
| 817 |
-
iou_list.append(iou_t)
|
| 818 |
-
return torch.cat(lrm_list, dim=0), torch.cat(iou_list, dim=0) # [A,1,256,256], [A,1]
|
| 819 |
-
|
| 820 |
-
|
| 821 |
-
def _task_reward_frame_adaptive(
|
| 822 |
-
lrm: torch.Tensor, # [A, 1, 256, 256] float32
|
| 823 |
-
iou: torch.Tensor, # [A, 1] float32
|
| 824 |
-
cfg: "QLTPOConfig",
|
| 825 |
-
e0_vec: List[float], # per-anchor existence priors [A]
|
| 826 |
-
) -> torch.Tensor:
|
| 827 |
-
"""Per-anchor task reward averaged over anchors (no regularization)."""
|
| 828 |
-
A = lrm.shape[0]
|
| 829 |
-
R = torch.tensor(0.0, device=lrm.device)
|
| 830 |
-
for t in range(A):
|
| 831 |
-
r_iou_t = iou[t].mean()
|
| 832 |
-
r_area_t = torch.sigmoid(lrm[t] / cfg.area_temp).mean()
|
| 833 |
-
R = R + cfg.lambda_iou * e0_vec[t] * r_iou_t - cfg.lambda_area * r_area_t
|
| 834 |
-
return R / A
|
| 835 |
-
|
| 836 |
-
|
| 837 |
-
def _compute_full_reward_adaptive(
|
| 838 |
-
q_global: torch.Tensor, # [1, 256]
|
| 839 |
-
delta: torch.Tensor, # [A, 256]
|
| 840 |
-
lrm: torch.Tensor, # [A, 1, 256, 256]
|
| 841 |
-
iou: torch.Tensor, # [A, 1]
|
| 842 |
-
q_init: torch.Tensor, # [1, 256] detached
|
| 843 |
-
cfg: "QLTPOConfig",
|
| 844 |
-
e0_vec: List[float],
|
| 845 |
-
) -> torch.Tensor:
|
| 846 |
-
"""Full adaptive reward = task + residual penalty + temporal smoothness + L2 reg."""
|
| 847 |
-
r_task = _task_reward_frame_adaptive(lrm, iou, cfg, e0_vec)
|
| 848 |
-
r_delta = delta.pow(2).sum()
|
| 849 |
-
r_reg = (q_global - q_init).pow(2).sum()
|
| 850 |
-
R = r_task - cfg.lambda_residual * r_delta - cfg.lambda_reg * r_reg
|
| 851 |
-
|
| 852 |
-
A = delta.shape[0]
|
| 853 |
-
if A > 1 and cfg.lambda_smooth_temp > 0.0:
|
| 854 |
-
r_smooth = torch.tensor(0.0, device=delta.device)
|
| 855 |
-
for t in range(A - 1):
|
| 856 |
-
r_smooth = r_smooth + (delta[t] - delta[t + 1]).pow(2).sum()
|
| 857 |
-
R = R - cfg.lambda_smooth_temp * r_smooth / (A - 1)
|
| 858 |
-
|
| 859 |
-
return R
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
def _compute_task_reward(
|
| 863 |
-
q: torch.Tensor,
|
| 864 |
-
lrm: torch.Tensor,
|
| 865 |
-
iou: torch.Tensor,
|
| 866 |
-
image_embeds_anchor_fp32: torch.Tensor,
|
| 867 |
-
cfg: QLTPOConfig,
|
| 868 |
-
e0: float = 1.0,
|
| 869 |
-
r_ref: Optional[torch.Tensor] = None,
|
| 870 |
-
r_neg: Optional[torch.Tensor] = None,
|
| 871 |
-
) -> torch.Tensor:
|
| 872 |
-
"""Dispatch to the correct stage's task reward."""
|
| 873 |
-
if cfg.stage == 1:
|
| 874 |
-
return _task_reward_stage1(lrm, iou, cfg, e0)
|
| 875 |
-
if cfg.stage == 2:
|
| 876 |
-
return _task_reward_stage2(q, lrm, iou, image_embeds_anchor_fp32, cfg, e0)
|
| 877 |
-
if cfg.stage == 21:
|
| 878 |
-
assert r_ref is not None and r_neg is not None, "stage 21 requires r_ref/r_neg"
|
| 879 |
-
return _task_reward_stage2_tether(q, lrm, iou, r_ref, r_neg, cfg, e0)
|
| 880 |
-
if cfg.stage == 22:
|
| 881 |
-
assert r_ref is not None, "stage 22 requires r_ref"
|
| 882 |
-
return _task_reward_stage2_faithful(q, lrm, iou, image_embeds_anchor_fp32, r_ref, cfg, e0)
|
| 883 |
-
return _task_reward_stage3(q, lrm, iou, image_embeds_anchor_fp32, cfg, e0)
|
| 884 |
-
|
| 885 |
-
|
| 886 |
-
def _compute_full_reward(
|
| 887 |
-
q: torch.Tensor,
|
| 888 |
-
lrm: torch.Tensor,
|
| 889 |
-
iou: torch.Tensor,
|
| 890 |
-
image_embeds_anchor_fp32: torch.Tensor,
|
| 891 |
-
q_init: torch.Tensor,
|
| 892 |
-
cfg: QLTPOConfig,
|
| 893 |
-
e0: float = 1.0,
|
| 894 |
-
r_ref: Optional[torch.Tensor] = None,
|
| 895 |
-
r_neg: Optional[torch.Tensor] = None,
|
| 896 |
-
) -> torch.Tensor:
|
| 897 |
-
"""Full reward = task reward + L2 regularization (used for backward)."""
|
| 898 |
-
r_task = _compute_task_reward(q, lrm, iou, image_embeds_anchor_fp32, cfg, e0, r_ref, r_neg)
|
| 899 |
-
r_reg = (q - q_init).pow(2).sum()
|
| 900 |
-
return r_task - cfg.lambda_reg * r_reg
|
| 901 |
-
|
| 902 |
-
|
| 903 |
-
# ---------------------------------------------------------------------------
|
| 904 |
-
# Stage 0: gradient connectivity check
|
| 905 |
-
# ---------------------------------------------------------------------------
|
| 906 |
-
|
| 907 |
-
def check_grad_connectivity(
|
| 908 |
-
F_init: torch.Tensor, # [1, 256] any dtype
|
| 909 |
-
image_embeds: torch.Tensor, # [T, 256, 64, 64] any dtype
|
| 910 |
-
anchor_indices: List[int],
|
| 911 |
-
sam_model,
|
| 912 |
-
model_dtype: torch.dtype,
|
| 913 |
-
num_steps: int = 5,
|
| 914 |
-
lr: float = 0.0,
|
| 915 |
-
) -> dict:
|
| 916 |
-
"""Stage 0: verify ∂R_iou_pred/∂q ≠ 0 and reward rises with Adam maximize.
|
| 917 |
-
|
| 918 |
-
Runs num_steps of Adam on R = R_iou_pred only (the simplest differentiable
|
| 919 |
-
reward, no custom ops required). Returns a diagnostic dict.
|
| 920 |
-
|
| 921 |
-
Usage:
|
| 922 |
-
diag = check_grad_connectivity(F_init, image_embeds, anchors, sam, dtype)
|
| 923 |
-
print(diag['grad_norm_step0'], diag['reward_trajectory'])
|
| 924 |
-
# expect grad_norm > 0 and rewards non-decreasing
|
| 925 |
-
"""
|
| 926 |
-
device = F_init.device
|
| 927 |
-
image_embeds_anchor = image_embeds[anchor_indices].float().detach()
|
| 928 |
-
dense_emb = _precompute_dense_emb(sam_model, model_dtype, device).float().detach()
|
| 929 |
-
dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device).float().detach()
|
| 930 |
-
mask_dec = sam_model.mask_decoder
|
| 931 |
-
|
| 932 |
-
q_init_fp32 = F_init.float().detach()
|
| 933 |
-
if lr <= 0:
|
| 934 |
-
lr = 0.01 * (q_init_fp32.norm() / (q_init_fp32.numel() ** 0.5)).item()
|
| 935 |
-
|
| 936 |
-
q = torch.nn.Parameter(q_init_fp32.clone())
|
| 937 |
-
optimizer = torch.optim.Adam([q], lr=lr, maximize=True)
|
| 938 |
-
|
| 939 |
-
grad_norms, rewards = [], []
|
| 940 |
-
for step in range(num_steps):
|
| 941 |
-
optimizer.zero_grad()
|
| 942 |
-
lrm, iou = _decode_on_anchors_diff(q, image_embeds_anchor, dense_emb, mask_dec, dense_pe)
|
| 943 |
-
R = iou.mean()
|
| 944 |
-
R.backward()
|
| 945 |
-
grad_norm = q.grad.norm().item() if q.grad is not None else 0.0
|
| 946 |
-
grad_norms.append(grad_norm)
|
| 947 |
-
rewards.append(R.item())
|
| 948 |
-
optimizer.step()
|
| 949 |
-
|
| 950 |
-
return {
|
| 951 |
-
"grad_norm_step0": grad_norms[0],
|
| 952 |
-
"grad_norms": grad_norms,
|
| 953 |
-
"reward_trajectory": rewards,
|
| 954 |
-
"gradient_connected": grad_norms[0] > 1e-8,
|
| 955 |
-
}
|
| 956 |
-
|
| 957 |
-
|
| 958 |
-
# ---------------------------------------------------------------------------
|
| 959 |
-
# AVT proxy reward (Step A0: reward–metric correlation study)
|
| 960 |
-
# ---------------------------------------------------------------------------
|
| 961 |
-
|
| 962 |
-
@torch.no_grad()
|
| 963 |
-
def _compute_avt_proxy_reward(
|
| 964 |
-
q_init_fp32: torch.Tensor, # [1, 256] — frozen AVT anchor (= Fseg)
|
| 965 |
-
lrm: torch.Tensor, # [A, 1, 256, 256] float32
|
| 966 |
-
image_embeds_anchor_fp32: torch.Tensor, # [A, 256, 64, 64] float32
|
| 967 |
-
cfg: "QLTPOConfig",
|
| 968 |
-
beta: float = 0.5,
|
| 969 |
-
) -> Tuple[float, float]:
|
| 970 |
-
"""Task-specific proxy reward using frozen q_init (Fseg) as teacher.
|
| 971 |
-
|
| 972 |
-
q_init = Fseg is already the audio+video+text fusion token produced by SimToken.
|
| 973 |
-
Using it as a frozen reference breaks Stage 2's self-confirming bias while
|
| 974 |
-
measuring whether the mask region aligns with the correct referent.
|
| 975 |
-
|
| 976 |
-
Returns:
|
| 977 |
-
R_avt = mean_t cos(z_in_t, q_init) [scalar]
|
| 978 |
-
R_avt_c = mean_t [cos(z_in_t, q_init) - beta·cos(z_out_t, q_init)] [scalar]
|
| 979 |
-
"""
|
| 980 |
-
A = lrm.shape[0]
|
| 981 |
-
q_norm = F.normalize(q_init_fp32[0], dim=0) # [256]
|
| 982 |
-
|
| 983 |
-
masks_64 = F.interpolate(
|
| 984 |
-
torch.sigmoid(lrm.squeeze(1) / cfg.area_temp).unsqueeze(1),
|
| 985 |
-
size=(64, 64), mode="bilinear", align_corners=False,
|
| 986 |
-
).squeeze(1) # [A, 64, 64]
|
| 987 |
-
|
| 988 |
-
r_avt, r_avt_c = 0.0, 0.0
|
| 989 |
-
for t in range(A):
|
| 990 |
-
m = masks_64[t]
|
| 991 |
-
img = image_embeds_anchor_fp32[t]
|
| 992 |
-
z_in = F.normalize(
|
| 993 |
-
(img * m.unsqueeze(0)).sum(dim=[1, 2]) / (m.sum() + 1e-6), dim=0
|
| 994 |
-
)
|
| 995 |
-
z_out = F.normalize(
|
| 996 |
-
(img * (1.0 - m).unsqueeze(0)).sum(dim=[1, 2]) / ((1.0 - m).sum() + 1e-6), dim=0
|
| 997 |
-
)
|
| 998 |
-
c_in = (q_norm @ z_in).item()
|
| 999 |
-
c_out = (q_norm @ z_out).item()
|
| 1000 |
-
r_avt += c_in
|
| 1001 |
-
r_avt_c += c_in - beta * c_out
|
| 1002 |
-
return r_avt / A, r_avt_c / A
|
| 1003 |
-
|
| 1004 |
-
|
| 1005 |
-
# ---------------------------------------------------------------------------
|
| 1006 |
-
# Stage 1–3: q-LTPO-autograd main optimizer
|
| 1007 |
-
# ---------------------------------------------------------------------------
|
| 1008 |
-
|
| 1009 |
-
def q_ltpo_autograd(
|
| 1010 |
-
F_init: torch.Tensor, # [1, 256] any dtype on CUDA
|
| 1011 |
-
image_embeds: torch.Tensor, # [T, 256, 64, 64] any dtype on CUDA
|
| 1012 |
-
anchor_indices: List[int],
|
| 1013 |
-
sam_model,
|
| 1014 |
-
model_dtype: torch.dtype,
|
| 1015 |
-
cfg: QLTPOConfig,
|
| 1016 |
-
) -> torch.Tensor:
|
| 1017 |
-
"""Optimise the SAM prompt token q at test time via Adam maximize.
|
| 1018 |
-
|
| 1019 |
-
q is initialised to F_init (= Fseg after text_hidden_fcs projection).
|
| 1020 |
-
The prompt encoder is bypassed: sparse_emb = q.unsqueeze(1), identical
|
| 1021 |
-
to what prompt_encoder produces when text_embeds is the only prompt.
|
| 1022 |
-
|
| 1023 |
-
All computation is done in float32 to avoid fp16 gradient truncation.
|
| 1024 |
-
Returns best_q as float32 [1, 256]. Falls back to F_init when gating
|
| 1025 |
-
rejects all updates.
|
| 1026 |
-
"""
|
| 1027 |
-
device = F_init.device
|
| 1028 |
-
|
| 1029 |
-
# ── Precompute constants (float32, detached) ──────────────────────────
|
| 1030 |
-
q_init_fp32 = F_init.float().detach()
|
| 1031 |
-
image_embeds_anchor = image_embeds[anchor_indices].float().detach()
|
| 1032 |
-
dense_emb = _precompute_dense_emb(sam_model, model_dtype, device).float().detach()
|
| 1033 |
-
dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device).float().detach()
|
| 1034 |
-
mask_dec = sam_model.mask_decoder
|
| 1035 |
-
|
| 1036 |
-
# ── Auto-scale lr and max_drift from q_init magnitude ─────────────────
|
| 1037 |
-
rms = q_init_fp32.norm() / (q_init_fp32.numel() ** 0.5)
|
| 1038 |
-
lr = cfg.lr if cfg.lr > 0 else 0.01 * rms.item()
|
| 1039 |
-
max_drift = cfg.max_drift if cfg.max_drift > 0 else 0.5 * q_init_fp32.norm().item()
|
| 1040 |
-
|
| 1041 |
-
# ── Precompute frozen external reference (stages 21, 22 only) ────────
|
| 1042 |
-
r_ref, r_neg = None, None
|
| 1043 |
-
if cfg.stage in (21, 22):
|
| 1044 |
-
r_ref, r_neg = _compute_r_ref(q_init_fp32, image_embeds_anchor, cfg.r_ref_temp)
|
| 1045 |
-
|
| 1046 |
-
# ── Baseline forward + e0 existence prior ────────────────────────────
|
| 1047 |
-
with torch.no_grad():
|
| 1048 |
-
lrm0, iou0 = _decode_on_anchors_diff(
|
| 1049 |
-
q_init_fp32, image_embeds_anchor, dense_emb, mask_dec, dense_pe
|
| 1050 |
-
)
|
| 1051 |
-
# e0 = stopgrad(R_area_soft(q_init)): fixes the scalar before the loop.
|
| 1052 |
-
# Suppresses R_iou when the initial mask is near-empty (existence prior).
|
| 1053 |
-
r_area_soft_init = torch.sigmoid(lrm0 / cfg.area_temp).mean().item()
|
| 1054 |
-
e0 = _compute_e0(r_area_soft_init, cfg)
|
| 1055 |
-
|
| 1056 |
-
R_init_task = _compute_task_reward(
|
| 1057 |
-
q_init_fp32, lrm0, iou0, image_embeds_anchor, cfg, e0=e0,
|
| 1058 |
-
r_ref=r_ref, r_neg=r_neg,
|
| 1059 |
-
).item()
|
| 1060 |
-
|
| 1061 |
-
# ── Optimisation setup ────────────────────────────────────────────────
|
| 1062 |
-
q = torch.nn.Parameter(q_init_fp32.clone())
|
| 1063 |
-
optimizer = torch.optim.Adam([q], lr=lr, maximize=True)
|
| 1064 |
-
|
| 1065 |
-
best_q = q.detach().clone()
|
| 1066 |
-
best_reward = R_init_task
|
| 1067 |
-
hit_clip = False
|
| 1068 |
-
|
| 1069 |
-
# ── Optimisation loop ─────────────────────────────────────────────────
|
| 1070 |
-
# Track per-step soft area to diagnose whether B1 penalty ever activates.
|
| 1071 |
-
_step_soft_areas: List[float] = []
|
| 1072 |
-
|
| 1073 |
-
for step in range(cfg.T):
|
| 1074 |
-
optimizer.zero_grad()
|
| 1075 |
-
|
| 1076 |
-
lrm, iou = _decode_on_anchors_diff(
|
| 1077 |
-
q, image_embeds_anchor, dense_emb, mask_dec, dense_pe
|
| 1078 |
-
)
|
| 1079 |
-
R_full = _compute_full_reward(q, lrm, iou, image_embeds_anchor, q_init_fp32, cfg, e0=e0,
|
| 1080 |
-
r_ref=r_ref, r_neg=r_neg)
|
| 1081 |
-
R_full.backward()
|
| 1082 |
-
optimizer.step()
|
| 1083 |
-
|
| 1084 |
-
# Hard L2 norm clip: keep q within max_drift ball around q_init
|
| 1085 |
-
with torch.no_grad():
|
| 1086 |
-
diff = q - q_init_fp32
|
| 1087 |
-
d = diff.norm()
|
| 1088 |
-
if d > max_drift:
|
| 1089 |
-
q.copy_(q_init_fp32 + diff * (max_drift / d))
|
| 1090 |
-
hit_clip = True
|
| 1091 |
-
|
| 1092 |
-
# Fresh no_grad forward on the post-step q_{N+1} for correct tracking.
|
| 1093 |
-
# (Pre-step lrm/iou would mismatch the updated q, causing wrong best_q.)
|
| 1094 |
-
with torch.no_grad():
|
| 1095 |
-
lrm_eval, iou_eval = _decode_on_anchors_diff(
|
| 1096 |
-
q.detach(), image_embeds_anchor, dense_emb, mask_dec, dense_pe
|
| 1097 |
-
)
|
| 1098 |
-
# Record soft area at this step for B1 activation diagnosis
|
| 1099 |
-
_step_soft_areas.append(
|
| 1100 |
-
torch.sigmoid(lrm_eval / cfg.area_temp).mean().item()
|
| 1101 |
-
)
|
| 1102 |
-
r_task = _compute_task_reward(
|
| 1103 |
-
q.detach(), lrm_eval, iou_eval, image_embeds_anchor, cfg, e0=e0,
|
| 1104 |
-
r_ref=r_ref, r_neg=r_neg,
|
| 1105 |
-
).item()
|
| 1106 |
-
if r_task > best_reward:
|
| 1107 |
-
best_reward = r_task
|
| 1108 |
-
best_q = q.detach().clone()
|
| 1109 |
-
|
| 1110 |
-
# Peak excess: how much did soft area exceed e0 at its highest point?
|
| 1111 |
-
# b1_peak_excess > 0 ↔ B1 ReLU was non-zero at that step.
|
| 1112 |
-
# b1_peak_excess = 0 ↔ B1 never activated (area stayed below e0 throughout).
|
| 1113 |
-
_max_step_area = max(_step_soft_areas) if _step_soft_areas else r_area_soft_init
|
| 1114 |
-
b1_peak_excess = max(_max_step_area - e0, 0.0)
|
| 1115 |
-
|
| 1116 |
-
# ── Reward gating: clean re-eval of best_q vs q_init ─────────────────
|
| 1117 |
-
with torch.no_grad():
|
| 1118 |
-
lrm_b, iou_b = _decode_on_anchors_diff(
|
| 1119 |
-
best_q, image_embeds_anchor, dense_emb, mask_dec, dense_pe
|
| 1120 |
-
)
|
| 1121 |
-
R_best_task = _compute_task_reward(
|
| 1122 |
-
best_q, lrm_b, iou_b, image_embeds_anchor, cfg, e0=e0,
|
| 1123 |
-
r_ref=r_ref, r_neg=r_neg,
|
| 1124 |
-
).item()
|
| 1125 |
-
|
| 1126 |
-
area_init = (lrm0 > 0).float().mean().item()
|
| 1127 |
-
effective_gate = (
|
| 1128 |
-
cfg.null_gate_delta
|
| 1129 |
-
if (cfg.null_gate_delta > 0 and area_init < cfg.null_area_threshold)
|
| 1130 |
-
else cfg.gate_delta
|
| 1131 |
-
)
|
| 1132 |
-
accepted = R_best_task > R_init_task + effective_gate
|
| 1133 |
-
|
| 1134 |
-
# ── Mask soft-IoU: how much did the mask actually change? ─────────────
|
| 1135 |
-
# Answers whether q-drift translated into mask change, or fell in a
|
| 1136 |
-
# flat direction of the mask decoder manifold.
|
| 1137 |
-
with torch.no_grad():
|
| 1138 |
-
m0 = torch.sigmoid(lrm0 / cfg.area_temp).squeeze(1) # [A,256,256]
|
| 1139 |
-
mb = torch.sigmoid(lrm_b / cfg.area_temp).squeeze(1) # [A,256,256]
|
| 1140 |
-
inter = (m0 * mb).sum(dim=[1, 2])
|
| 1141 |
-
union = (m0 + mb - m0 * mb).sum(dim=[1, 2])
|
| 1142 |
-
mask_soft_iou = (inter / (union + 1e-6)).mean().item()
|
| 1143 |
-
|
| 1144 |
-
# Soft area at best_q — tracks whether B1 asymmetric penalty worked
|
| 1145 |
-
r_area_soft_best = mb.mean().item() # sigmoid(lrm_b/area_temp).mean()
|
| 1146 |
-
|
| 1147 |
-
# Reward decomposition: iou contribution to reward gain
|
| 1148 |
-
R_iou_contrib_gain = (
|
| 1149 |
-
cfg.lambda_iou * e0 * (iou_b.mean().item() - iou0.mean().item())
|
| 1150 |
-
)
|
| 1151 |
-
|
| 1152 |
-
# AVT proxy reward (Step A0 correlation study)
|
| 1153 |
-
r_avt_init, r_avt_c_init = _compute_avt_proxy_reward(
|
| 1154 |
-
q_init_fp32, lrm0, image_embeds_anchor, cfg
|
| 1155 |
-
)
|
| 1156 |
-
r_avt_best, r_avt_c_best = _compute_avt_proxy_reward(
|
| 1157 |
-
q_init_fp32, lrm_b, image_embeds_anchor, cfg
|
| 1158 |
-
)
|
| 1159 |
-
|
| 1160 |
-
# ── Per-sample diagnostics ────────────────────────────────────────────
|
| 1161 |
-
_q_ltpo_stats.append({
|
| 1162 |
-
"accepted": accepted,
|
| 1163 |
-
"reward_gain": R_best_task - R_init_task,
|
| 1164 |
-
"drift": (best_q - q_init_fp32).norm().item(),
|
| 1165 |
-
"hit_clip": hit_clip,
|
| 1166 |
-
"e0": e0,
|
| 1167 |
-
"R_iou_pred_init": iou0.mean().item(),
|
| 1168 |
-
"R_iou_pred_best": iou_b.mean().item(),
|
| 1169 |
-
"area_hard_init": area_init,
|
| 1170 |
-
"area_hard_best": (lrm_b > 0).float().mean().item(),
|
| 1171 |
-
"r_area_soft_init": r_area_soft_init,
|
| 1172 |
-
"r_area_soft_best": r_area_soft_best,
|
| 1173 |
-
"b1_peak_excess": b1_peak_excess,
|
| 1174 |
-
"mask_soft_iou": mask_soft_iou,
|
| 1175 |
-
"R_iou_contrib_gain": R_iou_contrib_gain,
|
| 1176 |
-
# AVT proxy: frozen q_init as teacher — task-specific alignment
|
| 1177 |
-
"r_avt_init": r_avt_init,
|
| 1178 |
-
"r_avt_best": r_avt_best,
|
| 1179 |
-
"r_avt_gain": r_avt_best - r_avt_init,
|
| 1180 |
-
"r_avt_c_init": r_avt_c_init,
|
| 1181 |
-
"r_avt_c_best": r_avt_c_best,
|
| 1182 |
-
"r_avt_c_gain": r_avt_c_best - r_avt_c_init,
|
| 1183 |
-
})
|
| 1184 |
-
|
| 1185 |
-
if not accepted:
|
| 1186 |
-
return F_init.float()
|
| 1187 |
-
return best_q
|
| 1188 |
-
|
| 1189 |
-
|
| 1190 |
-
# ===========================================================================
|
| 1191 |
-
# Direction II: Frame-adaptive token optimization (stage=4)
|
| 1192 |
-
# q_t = q_global + delta_t — shared global token + per-anchor residual
|
| 1193 |
-
# ===========================================================================
|
| 1194 |
-
|
| 1195 |
-
def q_ltpo_frame_adaptive(
|
| 1196 |
-
F_init: torch.Tensor, # [1, 256] any dtype on CUDA
|
| 1197 |
-
image_embeds: torch.Tensor, # [T, 256, 64, 64] any dtype on CUDA
|
| 1198 |
-
anchor_indices: List[int],
|
| 1199 |
-
sam_model,
|
| 1200 |
-
model_dtype: torch.dtype,
|
| 1201 |
-
cfg: QLTPOConfig,
|
| 1202 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 1203 |
-
"""Frame-adaptive q-LTPO: optimize q_global and per-anchor delta jointly.
|
| 1204 |
-
|
| 1205 |
-
Each anchor frame t gets its own token q_t = q_global + delta_t.
|
| 1206 |
-
delta_t is initialized to zero so q_t starts equal to q_init for all frames.
|
| 1207 |
-
Per-frame existence priors e0_t suppress optimization on near-empty anchors.
|
| 1208 |
-
|
| 1209 |
-
Returns:
|
| 1210 |
-
q_global [1, 256] float32 — shared global token
|
| 1211 |
-
delta [A, 256] float32 — per-anchor residuals (zero if not accepted)
|
| 1212 |
-
"""
|
| 1213 |
-
device = F_init.device
|
| 1214 |
-
A = len(anchor_indices)
|
| 1215 |
-
|
| 1216 |
-
q_init_fp32 = F_init.float().detach()
|
| 1217 |
-
image_embeds_anchor = image_embeds[anchor_indices].float().detach()
|
| 1218 |
-
dense_emb = _precompute_dense_emb(sam_model, model_dtype, device).float().detach()
|
| 1219 |
-
dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device).float().detach()
|
| 1220 |
-
mask_dec = sam_model.mask_decoder
|
| 1221 |
-
|
| 1222 |
-
rms = q_init_fp32.norm() / (q_init_fp32.numel() ** 0.5)
|
| 1223 |
-
lr = cfg.lr if cfg.lr > 0 else 0.01 * rms.item()
|
| 1224 |
-
max_drift = cfg.max_drift if cfg.max_drift > 0 else 0.5 * q_init_fp32.norm().item()
|
| 1225 |
-
max_delta_drift = cfg.max_delta_drift_scale * q_init_fp32.norm().item()
|
| 1226 |
-
|
| 1227 |
-
# ── Baseline: per-anchor e0 existence priors ────────────────────────────
|
| 1228 |
-
with torch.no_grad():
|
| 1229 |
-
lrm0, iou0 = _decode_on_anchors_diff(
|
| 1230 |
-
q_init_fp32, image_embeds_anchor, dense_emb, mask_dec, dense_pe
|
| 1231 |
-
)
|
| 1232 |
-
e0_vec: List[float] = []
|
| 1233 |
-
for t in range(A):
|
| 1234 |
-
e0_t = torch.sigmoid(lrm0[t] / cfg.area_temp).mean().item()
|
| 1235 |
-
e0_vec.append(_compute_e0(e0_t, cfg))
|
| 1236 |
-
e0_global = sum(e0_vec) / A
|
| 1237 |
-
|
| 1238 |
-
R_init_task = _task_reward_frame_adaptive(lrm0, iou0, cfg, e0_vec).item()
|
| 1239 |
-
|
| 1240 |
-
# ── Setup optimization ───────────────────────────────────────────────────
|
| 1241 |
-
q_global = torch.nn.Parameter(q_init_fp32.clone())
|
| 1242 |
-
delta = torch.nn.Parameter(torch.zeros(A, 256, device=device, dtype=torch.float32))
|
| 1243 |
-
optimizer = torch.optim.Adam([q_global, delta], lr=lr, maximize=True)
|
| 1244 |
-
|
| 1245 |
-
best_q_global = q_global.detach().clone()
|
| 1246 |
-
best_delta = delta.detach().clone()
|
| 1247 |
-
best_reward = R_init_task
|
| 1248 |
-
hit_clip = False
|
| 1249 |
-
|
| 1250 |
-
# ── Optimization loop ────────────────────────────────────────────────────
|
| 1251 |
-
for step in range(cfg.T):
|
| 1252 |
-
optimizer.zero_grad()
|
| 1253 |
-
lrm, iou = _decode_on_anchors_diff_adaptive(
|
| 1254 |
-
q_global, delta, image_embeds_anchor, dense_emb, mask_dec, dense_pe
|
| 1255 |
-
)
|
| 1256 |
-
R_full = _compute_full_reward_adaptive(
|
| 1257 |
-
q_global, delta, lrm, iou, q_init_fp32, cfg, e0_vec
|
| 1258 |
-
)
|
| 1259 |
-
R_full.backward()
|
| 1260 |
-
optimizer.step()
|
| 1261 |
-
|
| 1262 |
-
# Clip q_global and each per-anchor delta within trust regions
|
| 1263 |
-
with torch.no_grad():
|
| 1264 |
-
diff = q_global - q_init_fp32
|
| 1265 |
-
d = diff.norm()
|
| 1266 |
-
if d > max_drift:
|
| 1267 |
-
q_global.copy_(q_init_fp32 + diff * (max_drift / d))
|
| 1268 |
-
hit_clip = True
|
| 1269 |
-
for t in range(A):
|
| 1270 |
-
dn = delta[t].norm()
|
| 1271 |
-
if dn > max_delta_drift:
|
| 1272 |
-
delta[t].copy_(delta[t] * (max_delta_drift / dn))
|
| 1273 |
-
|
| 1274 |
-
# Track best (no_grad re-eval of task reward without reg)
|
| 1275 |
-
with torch.no_grad():
|
| 1276 |
-
lrm_eval, iou_eval = _decode_on_anchors_diff_adaptive(
|
| 1277 |
-
q_global.detach(), delta.detach(),
|
| 1278 |
-
image_embeds_anchor, dense_emb, mask_dec, dense_pe
|
| 1279 |
-
)
|
| 1280 |
-
r_task = _task_reward_frame_adaptive(lrm_eval, iou_eval, cfg, e0_vec).item()
|
| 1281 |
-
if r_task > best_reward:
|
| 1282 |
-
best_reward = r_task
|
| 1283 |
-
best_q_global = q_global.detach().clone()
|
| 1284 |
-
best_delta = delta.detach().clone()
|
| 1285 |
-
|
| 1286 |
-
# ── Gating ───────────────────────────────────────────────────────────────
|
| 1287 |
-
with torch.no_grad():
|
| 1288 |
-
lrm_b, iou_b = _decode_on_anchors_diff_adaptive(
|
| 1289 |
-
best_q_global, best_delta, image_embeds_anchor, dense_emb, mask_dec, dense_pe
|
| 1290 |
-
)
|
| 1291 |
-
R_best_task = _task_reward_frame_adaptive(lrm_b, iou_b, cfg, e0_vec).item()
|
| 1292 |
-
|
| 1293 |
-
accepted = R_best_task > R_init_task + cfg.gate_delta
|
| 1294 |
-
|
| 1295 |
-
area_init = (lrm0 > 0).float().mean().item()
|
| 1296 |
-
r_area_soft_init = sum(torch.sigmoid(lrm0[t] / cfg.area_temp).mean().item() for t in range(A)) / A
|
| 1297 |
-
r_area_soft_best = sum(torch.sigmoid(lrm_b[t] / cfg.area_temp).mean().item() for t in range(A)) / A
|
| 1298 |
-
|
| 1299 |
-
# Actual mask soft-IoU between init and best (per anchor, averaged)
|
| 1300 |
-
m0 = torch.sigmoid(lrm0 / cfg.area_temp).squeeze(1) # [A,256,256]
|
| 1301 |
-
mb = torch.sigmoid(lrm_b / cfg.area_temp).squeeze(1) # [A,256,256]
|
| 1302 |
-
inter = (m0 * mb).sum(dim=[1, 2])
|
| 1303 |
-
union = (m0 + mb - m0 * mb).sum(dim=[1, 2])
|
| 1304 |
-
mask_soft_iou_fa = (inter / (union + 1e-6)).mean().item()
|
| 1305 |
-
|
| 1306 |
-
_q_ltpo_stats.append({
|
| 1307 |
-
"accepted": accepted,
|
| 1308 |
-
"reward_gain": R_best_task - R_init_task,
|
| 1309 |
-
"drift": (best_q_global - q_init_fp32).norm().item(),
|
| 1310 |
-
"delta_norm": best_delta.norm().item(),
|
| 1311 |
-
"hit_clip": hit_clip,
|
| 1312 |
-
"e0": e0_global,
|
| 1313 |
-
"R_iou_pred_init": iou0.mean().item(),
|
| 1314 |
-
"R_iou_pred_best": iou_b.mean().item(),
|
| 1315 |
-
"area_hard_init": area_init,
|
| 1316 |
-
"area_hard_best": (lrm_b > 0).float().mean().item(),
|
| 1317 |
-
"r_area_soft_init": r_area_soft_init,
|
| 1318 |
-
"r_area_soft_best": r_area_soft_best,
|
| 1319 |
-
"b1_peak_excess": 0.0,
|
| 1320 |
-
"mask_soft_iou": mask_soft_iou_fa,
|
| 1321 |
-
"R_iou_contrib_gain": cfg.lambda_iou * e0_global * (iou_b.mean().item() - iou0.mean().item()),
|
| 1322 |
-
})
|
| 1323 |
-
|
| 1324 |
-
if not accepted:
|
| 1325 |
-
return q_init_fp32, torch.zeros(A, 256, device=device, dtype=torch.float32)
|
| 1326 |
-
return best_q_global, best_delta
|
| 1327 |
-
|
| 1328 |
-
|
| 1329 |
-
def decode_full_video_adaptive(
|
| 1330 |
-
q_global: torch.Tensor, # [1, 256] float32
|
| 1331 |
-
delta: torch.Tensor, # [A, 256] float32
|
| 1332 |
-
anchor_indices: List[int],
|
| 1333 |
-
image_embeds: torch.Tensor, # [T, 256, 64, 64] model dtype on CUDA
|
| 1334 |
-
sam_model,
|
| 1335 |
-
resize: tuple,
|
| 1336 |
-
orgsize: tuple,
|
| 1337 |
-
model_dtype: torch.dtype,
|
| 1338 |
-
) -> torch.Tensor:
|
| 1339 |
-
"""Decode all T frames with frame-adaptive tokens.
|
| 1340 |
-
|
| 1341 |
-
Each frame is assigned to its nearest anchor by index distance, then decoded
|
| 1342 |
-
with q_t = q_global + delta[anchor_idx].
|
| 1343 |
-
Returns raw logit masks [T, H_orig, W_orig].
|
| 1344 |
-
"""
|
| 1345 |
-
T = image_embeds.shape[0]
|
| 1346 |
-
A = len(anchor_indices)
|
| 1347 |
-
device = image_embeds.device
|
| 1348 |
-
|
| 1349 |
-
dense_emb = _precompute_dense_emb(sam_model, model_dtype, device)
|
| 1350 |
-
dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device)
|
| 1351 |
-
|
| 1352 |
-
# Nearest-anchor assignment for every frame
|
| 1353 |
-
anchor_arr = torch.tensor(anchor_indices, dtype=torch.float32)
|
| 1354 |
-
frame_to_anchor = [int((anchor_arr - t).abs().argmin().item()) for t in range(T)]
|
| 1355 |
-
|
| 1356 |
-
pred_masks: List[torch.Tensor] = []
|
| 1357 |
-
with torch.no_grad():
|
| 1358 |
-
for t in range(T):
|
| 1359 |
-
a = frame_to_anchor[t]
|
| 1360 |
-
q_t = (q_global + delta[a : a + 1]).to(model_dtype) # [1, 256]
|
| 1361 |
-
sparse_emb = q_t.unsqueeze(1) # [1, 1, 256]
|
| 1362 |
-
lrm_t, _ = sam_model.mask_decoder(
|
| 1363 |
-
image_embeddings=image_embeds[t : t + 1],
|
| 1364 |
-
image_pe=dense_pe,
|
| 1365 |
-
sparse_prompt_embeddings=sparse_emb,
|
| 1366 |
-
dense_prompt_embeddings=dense_emb,
|
| 1367 |
-
multimask_output=False,
|
| 1368 |
-
) # [1, 1, 256, 256]
|
| 1369 |
-
pred_t = sam_model.postprocess_masks(lrm_t, input_size=resize, original_size=orgsize)
|
| 1370 |
-
pred_masks.append(pred_t.squeeze(0).squeeze(0)) # [H, W]
|
| 1371 |
-
|
| 1372 |
-
return torch.stack(pred_masks, dim=0) # [T, H_orig, W_orig]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
setup_simtoken.md
DELETED
|
@@ -1,163 +0,0 @@
|
|
| 1 |
-
# SimToken Setup
|
| 2 |
-
|
| 3 |
-
本文档用于在新机器上重建 SimToken 环境,并准备后续 A-min 实验。
|
| 4 |
-
|
| 5 |
-
---
|
| 6 |
-
|
| 7 |
-
## 1. Create Environment
|
| 8 |
-
|
| 9 |
-
先确认 GPU 和 CUDA driver 状态:
|
| 10 |
-
|
| 11 |
-
```bash
|
| 12 |
-
nvidia-smi
|
| 13 |
-
```
|
| 14 |
-
|
| 15 |
-
创建 conda 环境:
|
| 16 |
-
|
| 17 |
-
```bash
|
| 18 |
-
/opt/miniforge3/condabin/conda create -n simtoken python=3.10 -y
|
| 19 |
-
/opt/miniforge3/condabin/conda activate simtoken
|
| 20 |
-
|
| 21 |
-
python -m pip install --upgrade pip wheel "setuptools<81"
|
| 22 |
-
|
| 23 |
-
pip install \
|
| 24 |
-
torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 \
|
| 25 |
-
--index-url https://download.pytorch.org/whl/cu121
|
| 26 |
-
|
| 27 |
-
pip install \
|
| 28 |
-
transformers==4.30.2 \
|
| 29 |
-
peft==0.2.0 \
|
| 30 |
-
accelerate==0.21.0 \
|
| 31 |
-
sentencepiece \
|
| 32 |
-
protobuf \
|
| 33 |
-
safetensors \
|
| 34 |
-
numpy==1.26.4 \
|
| 35 |
-
pandas \
|
| 36 |
-
matplotlib \
|
| 37 |
-
opencv-python \
|
| 38 |
-
pillow \
|
| 39 |
-
tqdm \
|
| 40 |
-
einops \
|
| 41 |
-
timm \
|
| 42 |
-
requests \
|
| 43 |
-
towhee \
|
| 44 |
-
huggingface_hub
|
| 45 |
-
```
|
| 46 |
-
|
| 47 |
-
快速验证:
|
| 48 |
-
|
| 49 |
-
```bash
|
| 50 |
-
python - <<'PY'
|
| 51 |
-
import torch
|
| 52 |
-
print("torch:", torch.__version__)
|
| 53 |
-
print("cuda available:", torch.cuda.is_available())
|
| 54 |
-
print("device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu")
|
| 55 |
-
PY
|
| 56 |
-
```
|
| 57 |
-
|
| 58 |
-
---
|
| 59 |
-
|
| 60 |
-
## 2. Download from HuggingFace
|
| 61 |
-
|
| 62 |
-
如果新机器不使用迁移工具,而是从 HuggingFace 重新初始化,先登录:
|
| 63 |
-
|
| 64 |
-
```bash
|
| 65 |
-
huggingface-cli login
|
| 66 |
-
```
|
| 67 |
-
|
| 68 |
-
下载完整 repo:
|
| 69 |
-
|
| 70 |
-
```bash
|
| 71 |
-
mkdir -p /workspace/SimToken
|
| 72 |
-
cd /workspace/SimToken
|
| 73 |
-
|
| 74 |
-
huggingface-cli download yfan07/SimToken \
|
| 75 |
-
--repo-type model \
|
| 76 |
-
--local-dir . \
|
| 77 |
-
--local-dir-use-symlinks False
|
| 78 |
-
```
|
| 79 |
-
|
| 80 |
-
下载完成后解压数据:
|
| 81 |
-
|
| 82 |
-
```bash
|
| 83 |
-
cd /workspace/SimToken/data
|
| 84 |
-
|
| 85 |
-
tar -xf image_embed.tar
|
| 86 |
-
tar -xzf gt_mask.tar.gz
|
| 87 |
-
tar -xzf audio_embed.tar.gz
|
| 88 |
-
tar -xf media.tar
|
| 89 |
-
```
|
| 90 |
-
|
| 91 |
-
---
|
| 92 |
-
|
| 93 |
-
## 3. Pre-download Model Weights
|
| 94 |
-
|
| 95 |
-
`transformers==4.30.2` 与新版 `huggingface_hub` 可能存在网络/API 兼容问题。建议先用 CLI 将模型下载到本地缓存,实验时再加 `TRANSFORMERS_OFFLINE=1`。
|
| 96 |
-
|
| 97 |
-
```bash
|
| 98 |
-
# Chat-UniVi-7B
|
| 99 |
-
huggingface-cli download Chat-UniVi/Chat-UniVi-7B-v1.5
|
| 100 |
-
|
| 101 |
-
# CLIP ViT-L
|
| 102 |
-
huggingface-cli download openai/clip-vit-large-patch14
|
| 103 |
-
```
|
| 104 |
-
|
| 105 |
-
下载完成后做离线验证:
|
| 106 |
-
|
| 107 |
-
```bash
|
| 108 |
-
cd /workspace/SimToken
|
| 109 |
-
|
| 110 |
-
TRANSFORMERS_OFFLINE=1 /opt/miniforge3/condabin/conda run -n simtoken \
|
| 111 |
-
python -m py_compile train.py load_model.py decoder_invariance_check.py
|
| 112 |
-
```
|
| 113 |
-
|
| 114 |
-
---
|
| 115 |
-
|
| 116 |
-
## 4. Upload to HuggingFace
|
| 117 |
-
|
| 118 |
-
实验结束后,如需重新上传到 HuggingFace,先将数据目录压缩为归档文件,减少文件数量:
|
| 119 |
-
|
| 120 |
-
```bash
|
| 121 |
-
cd /workspace/SimToken/data
|
| 122 |
-
|
| 123 |
-
tar -cf image_embed.tar image_embed/
|
| 124 |
-
tar -czf gt_mask.tar.gz gt_mask/
|
| 125 |
-
tar -czf audio_embed.tar.gz audio_embed/
|
| 126 |
-
tar -cf media.tar media/
|
| 127 |
-
|
| 128 |
-
ls -lh *.tar*
|
| 129 |
-
|
| 130 |
-
# HuggingFace 单文件硬限制为 50GB;如果 image_embed.tar 超过 50GB,
|
| 131 |
-
# 需要切成小于 50GB 的分片再上传。
|
| 132 |
-
split -b 45G -d -a 2 image_embed.tar image_embed.tar.part-
|
| 133 |
-
|
| 134 |
-
# 校验分片拼接后仍能读出完整 tar 文件列表。
|
| 135 |
-
cat image_embed.tar.part-* | tar -tf - | grep -v '/$' | wc -l
|
| 136 |
-
|
| 137 |
-
# 分片校验通过后再删除超大原始 tar,避免上传失败。
|
| 138 |
-
rm -f image_embed.tar
|
| 139 |
-
|
| 140 |
-
rm -rf image_embed/ gt_mask/ audio_embed/ media/
|
| 141 |
-
```
|
| 142 |
-
|
| 143 |
-
下载后如需恢复 `image_embed.tar`:
|
| 144 |
-
|
| 145 |
-
```bash
|
| 146 |
-
cd /workspace/SimToken/data
|
| 147 |
-
cat image_embed.tar.part-* > image_embed.tar
|
| 148 |
-
tar -xf image_embed.tar
|
| 149 |
-
```
|
| 150 |
-
|
| 151 |
-
清理缓存并上传:
|
| 152 |
-
|
| 153 |
-
```bash
|
| 154 |
-
cd /workspace/SimToken
|
| 155 |
-
|
| 156 |
-
find . -name "__pycache__" -prune -exec rm -rf {} +
|
| 157 |
-
find . -name ".pytest_cache" -prune -exec rm -rf {} +
|
| 158 |
-
find . -name ".cache" -prune -exec rm -rf {} +
|
| 159 |
-
find . -name "*.pyc" -delete
|
| 160 |
-
|
| 161 |
-
huggingface-cli login
|
| 162 |
-
python upload_hf.py --repo yfan07/SimToken
|
| 163 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
simtoken_experiment.md
DELETED
|
@@ -1,369 +0,0 @@
|
|
| 1 |
-
# SimToken 实验路线文档
|
| 2 |
-
|
| 3 |
-
## 0. 当前状态
|
| 4 |
-
|
| 5 |
-
前置诊断已经完成,路线收敛到 **A-min dynamic referent gate training**。
|
| 6 |
-
|
| 7 |
-
已确认结论:
|
| 8 |
-
|
| 9 |
-
1. **SAM decoder 下游是逐帧 batch-parallel 解码**
|
| 10 |
-
`mask_decoder(image_embeddings[0:T])[t]` 与 `mask_decoder(image_embeddings[t:t+1])[0]` 只有混合精度数值噪声差异。旧的 decoder-level joint-frame competition 假设关闭。
|
| 11 |
-
|
| 12 |
-
2. **target_frame sweep 基本无效**
|
| 13 |
-
不同 target frame 生成的 q 几乎相同,`cos_to_q5` 通常在 `0.997+`;Seen/Null 上 oracle gain 约 `+0.0009`。这条 TTO 路线关闭。
|
| 14 |
-
|
| 15 |
-
3. **raw SAM-space D2 失效**
|
| 16 |
-
256 维 `q/Fseg` 与 SAM image embedding 不在可直接 cosine 的语义空间,`real q ≈ shuffled/wrong_ref q`,甚至 random q 更高。该定义关闭。
|
| 17 |
-
|
| 18 |
-
4. **LLM-space D2 有弱诊断信号,但不适合作为主 reward**
|
| 19 |
-
用 4096 维 `[SEG]` hidden state 与 `mm_projector(CLIP patch tokens)` 后的视觉 token 计算 D2,可以得到正相关:
|
| 20 |
-
- `corr(s_pred, frame_iou) ≈ +0.316`
|
| 21 |
-
- bottom 20% `s_pred` 中 failure rate 相比随机 baseline 约 `1.60x`
|
| 22 |
-
- 控制 `iou_pred` / `pred_area` 后偏相关约 `+0.14`
|
| 23 |
-
|
| 24 |
-
结论:`s_pred(beta=1.0)` 可以作为诊断信号或 frame-aware gate 的候选输入,但不能作为核心 TTO reward。
|
| 25 |
-
|
| 26 |
-
5. **margin-D2 无效**
|
| 27 |
-
离线 `s_margin = s(real) - max(s(shuffled), s(wrong_ref))` 的 failure enrichment 约 `0.93x`,会抵消掉有用的通用可见性/质量信号。该路线关闭。
|
| 28 |
-
|
| 29 |
-
当前最干净的解释是:
|
| 30 |
-
|
| 31 |
-
> q 本身通常是稳定的 referent anchor;主要瓶颈不在 q 生成,也不在简单 q selection,而在 SAM decoder 如何使用已有的 `mask_token -> q` sparse self-attention path。
|
| 32 |
-
|
| 33 |
-
2026-04-22 更新:
|
| 34 |
-
|
| 35 |
-
完整训练每个 epoch 约 2-4 小时,瓶颈主要在 7B MLLM forward,而不在 gate 本身。因此当前实验策略已调整为:
|
| 36 |
-
|
| 37 |
-
1. 先缓存固定 checkpoint 下的 `q = seg_embeddings`;
|
| 38 |
-
2. 在 cached q + cached SAM image embeddings 上训练 gate-only;
|
| 39 |
-
3. 用 cached eval split 快速判断 gate 是否有泛化收益;
|
| 40 |
-
4. 只有 gate-only 泛化信号成立后,再跑完整 A-min 联合训练。
|
| 41 |
-
|
| 42 |
-
---
|
| 43 |
-
|
| 44 |
-
## 1. A-min 当前实现
|
| 45 |
-
|
| 46 |
-
已在代码中加入 A-min dynamic referent gate:
|
| 47 |
-
|
| 48 |
-
- 文件:`models/segment_anything/modeling/transformer.py`
|
| 49 |
-
- 模块:`ReferentGate`
|
| 50 |
-
- 插入位置:`TwoWayAttentionBlock` 的 sparse self-attention + `norm1` 之后,token-to-image cross-attention 之前
|
| 51 |
-
- 作用对象:只作用于 `mask_tokens`
|
| 52 |
-
- 不作用于:`iou_token` 和 `q/sparse_prompt` 本身
|
| 53 |
-
|
| 54 |
-
SAM token index:
|
| 55 |
-
|
| 56 |
-
```python
|
| 57 |
-
tokens = [iou_token, mask_tokens..., sparse_prompt(q)]
|
| 58 |
-
```
|
| 59 |
-
|
| 60 |
-
因此:
|
| 61 |
-
|
| 62 |
-
```python
|
| 63 |
-
iou_token index: 0
|
| 64 |
-
mask token range: 1 : 1 + num_mask_tokens
|
| 65 |
-
q token index: 1 + num_mask_tokens
|
| 66 |
-
```
|
| 67 |
-
|
| 68 |
-
A-min gate 形式:
|
| 69 |
-
|
| 70 |
-
```python
|
| 71 |
-
alpha = sigmoid(Linear([mask_token, q, cos(mask_token, q)]))
|
| 72 |
-
mask_token = mask_token + alpha * Linear(q)
|
| 73 |
-
```
|
| 74 |
-
|
| 75 |
-
为保证旧 checkpoint 初始行为不变,`proj(q)` 分支使用零初始化。当前也将 `gate` 分支零初始化,使 alpha 有干净观测基线:
|
| 76 |
-
|
| 77 |
-
```python
|
| 78 |
-
nn.init.zeros_(self.gate.weight)
|
| 79 |
-
nn.init.zeros_(self.gate.bias)
|
| 80 |
-
nn.init.zeros_(self.proj.weight)
|
| 81 |
-
nn.init.zeros_(self.proj.bias)
|
| 82 |
-
```
|
| 83 |
-
|
| 84 |
-
初始时 gate 为 identity:
|
| 85 |
-
|
| 86 |
-
```text
|
| 87 |
-
max_abs_diff(gate(mask, q), mask) = 0.0
|
| 88 |
-
alpha_mean = 0.5
|
| 89 |
-
alpha_std = 0.0
|
| 90 |
-
```
|
| 91 |
-
|
| 92 |
-
当前训练 forward 保持完整链路:`prepare_inputs_labels_for_multimodal -> MLLM forward -> text_hidden_fcs -> SAM mask decoder -> loss`。`--gate_only` 只控制参数冻结范围,不再改变 forward 语义。
|
| 93 |
-
|
| 94 |
-
---
|
| 95 |
-
|
| 96 |
-
## 2. 当前新增工具
|
| 97 |
-
|
| 98 |
-
### 2.1 训练脚本增强
|
| 99 |
-
|
| 100 |
-
`train.py` 已加入:
|
| 101 |
-
|
| 102 |
-
- `--max_steps`
|
| 103 |
-
- `--overfit_samples`
|
| 104 |
-
- `--log_gate_stats_every`
|
| 105 |
-
- `--skip_eval_after_train`
|
| 106 |
-
- `--eval_train_only`
|
| 107 |
-
|
| 108 |
-
启动时会打印 referent gate 参数是否 trainable、是否进入 optimizer,以及初始 `proj_norm/gate_norm`。
|
| 109 |
-
|
| 110 |
-
### 2.2 cached q 路线
|
| 111 |
-
|
| 112 |
-
新增脚本:
|
| 113 |
-
|
| 114 |
-
- `cache_q_features.py`
|
| 115 |
-
- 离线缓存 `q = seg_embeddings`
|
| 116 |
-
- cache 文件很小,因为只保存 q 和少量 metadata
|
| 117 |
-
- `image_embeddings` 仍使用已有 `data/image_embed/{vid}.pt`
|
| 118 |
-
- `gt_masks` 仍使用已有 `data/gt_mask/...`
|
| 119 |
-
|
| 120 |
-
- `train_cached_gate.py`
|
| 121 |
-
- 加载 base model 和 cached q
|
| 122 |
-
- 冻结全部参数,只训练 `referent_gate`
|
| 123 |
-
- 支持 `--eval_only`、`--disable_gate`
|
| 124 |
-
- 支持 `--save_gate_only`,只保存 gate 参数,checkpoint 约 1.6MB
|
| 125 |
-
- 支持 `--gate_checkpoint`,在 base checkpoint 上 overlay gate-only checkpoint
|
| 126 |
-
- gate stats 会记录:
|
| 127 |
-
|
| 128 |
-
```text
|
| 129 |
-
batch_miou
|
| 130 |
-
batch_fscore
|
| 131 |
-
proj_norm
|
| 132 |
-
gate_norm
|
| 133 |
-
proj_grad_norm
|
| 134 |
-
gate_grad_norm
|
| 135 |
-
alpha_mean / alpha_std / alpha_min / alpha_max
|
| 136 |
-
```
|
| 137 |
-
|
| 138 |
-
cached 解码已优化:一个 dataloader batch 会展平成 paired frame batch 调用 `mask_decoder.forward_modified_v3`,避免逐 sample 调 decoder 的主要开销,同时不会产生 prompt/image cross product。
|
| 139 |
-
|
| 140 |
-
---
|
| 141 |
-
|
| 142 |
-
## 3. 已完成实验结果
|
| 143 |
-
|
| 144 |
-
### 3.1 cached identity 与原始 pipeline 一致性
|
| 145 |
-
|
| 146 |
-
先用 `test_s` 前 10 条验证 cached pipeline 是否与原始 `load_model.py` 对齐:
|
| 147 |
-
|
| 148 |
-
```text
|
| 149 |
-
cached identity:
|
| 150 |
-
mIoU = 0.9686462879
|
| 151 |
-
Fscore = 0.9868578851
|
| 152 |
-
|
| 153 |
-
original load_model.py:
|
| 154 |
-
mIoU = 0.9686277151
|
| 155 |
-
Fscore = 0.9868472159
|
| 156 |
-
|
| 157 |
-
diff:
|
| 158 |
-
mIoU = +0.0000186
|
| 159 |
-
Fscore = +0.0000107
|
| 160 |
-
```
|
| 161 |
-
|
| 162 |
-
结论:差异远小于 0.001,cached q pipeline 与原始 eval pipeline 一致,可以用于 gate-only 快速验证。
|
| 163 |
-
|
| 164 |
-
### 3.2 gate probe:梯度路径与 alpha 分化
|
| 165 |
-
|
| 166 |
-
在 cached train128 上跑 50 optimizer steps:
|
| 167 |
-
|
| 168 |
-
```text
|
| 169 |
-
step 5:
|
| 170 |
-
proj_norm=0.074015
|
| 171 |
-
gate_norm=0.064479
|
| 172 |
-
proj_grad_norm=0.052291
|
| 173 |
-
gate_grad_norm=0.000170
|
| 174 |
-
alpha_mean=0.4999
|
| 175 |
-
alpha_std=0.0019
|
| 176 |
-
|
| 177 |
-
step 50:
|
| 178 |
-
proj_norm=0.428711
|
| 179 |
-
gate_norm=0.523223
|
| 180 |
-
proj_grad_norm=0.022453
|
| 181 |
-
gate_grad_norm=0.000504
|
| 182 |
-
alpha_mean=0.5063
|
| 183 |
-
alpha_std=0.0112
|
| 184 |
-
```
|
| 185 |
-
|
| 186 |
-
结论:
|
| 187 |
-
|
| 188 |
-
- `proj_norm` 从 0 稳定增长,注入分支有梯度;
|
| 189 |
-
- `gate_norm` 也开始增长,alpha 控制分支参与学习;
|
| 190 |
-
- `alpha_std` 从 0 增长,说明 gate 对不同输入有分化响应;
|
| 191 |
-
- 计算图、冻结范围、optimizer param groups 均正常。
|
| 192 |
-
|
| 193 |
-
### 3.3 overfit32:表达能力验证
|
| 194 |
-
|
| 195 |
-
cached train32 identity baseline:
|
| 196 |
-
|
| 197 |
-
```text
|
| 198 |
-
mIoU = 0.8814558
|
| 199 |
-
Fscore = 0.9375512
|
| 200 |
-
```
|
| 201 |
-
|
| 202 |
-
cached gate overfit32,200 steps,lr=1e-4:
|
| 203 |
-
|
| 204 |
-
```text
|
| 205 |
-
mIoU = 0.9085821
|
| 206 |
-
Fscore = 0.9444574
|
| 207 |
-
```
|
| 208 |
-
|
| 209 |
-
提升:
|
| 210 |
-
|
| 211 |
-
```text
|
| 212 |
-
mIoU = +0.0271263
|
| 213 |
-
Fscore = +0.0069063
|
| 214 |
-
```
|
| 215 |
-
|
| 216 |
-
结论:在 q、SAM image embeddings、mask decoder 原始参数均固定时,仅训练 A-min gate 就能明显提高训练集 mIoU,说明 gate 机制有表达能力,梯度路径通畅。
|
| 217 |
-
|
| 218 |
-
### 3.4 overfit32 泛化评估
|
| 219 |
-
|
| 220 |
-
对 cached eval split 前 200 条,identity baseline:
|
| 221 |
-
|
| 222 |
-
```text
|
| 223 |
-
test_s mIoU = 0.7390979
|
| 224 |
-
test_s Fscore = 0.8190672
|
| 225 |
-
|
| 226 |
-
test_u mIoU = 0.6732285
|
| 227 |
-
test_u Fscore = 0.7734924
|
| 228 |
-
|
| 229 |
-
test_n metric = 0.0606105
|
| 230 |
-
```
|
| 231 |
-
|
| 232 |
-
overfit32 gate checkpoint:
|
| 233 |
-
|
| 234 |
-
```text
|
| 235 |
-
test_s mIoU = 0.7199481
|
| 236 |
-
test_s Fscore = 0.8045849
|
| 237 |
-
|
| 238 |
-
test_u mIoU = 0.6672303
|
| 239 |
-
test_u Fscore = 0.7663978
|
| 240 |
-
|
| 241 |
-
test_n metric = 0.0648588
|
| 242 |
-
```
|
| 243 |
-
|
| 244 |
-
delta:
|
| 245 |
-
|
| 246 |
-
```text
|
| 247 |
-
test_s mIoU = -0.0191498
|
| 248 |
-
test_s Fscore = -0.0144823
|
| 249 |
-
|
| 250 |
-
test_u mIoU = -0.0059983
|
| 251 |
-
test_u Fscore = -0.0070946
|
| 252 |
-
|
| 253 |
-
test_n metric = +0.0042483
|
| 254 |
-
```
|
| 255 |
-
|
| 256 |
-
结论:
|
| 257 |
-
|
| 258 |
-
- overfit32 gate 没有泛化;
|
| 259 |
-
- Null metric 略升,说明小样本过拟合有轻微放大前景的倾向;
|
| 260 |
-
- 这不是方法失败,而是 32 个样本不足以学到泛化 referent anchoring 的预期结果;
|
| 261 |
-
- 下一步应扩大 cached train 样本量,并降低 lr。
|
| 262 |
-
|
| 263 |
-
---
|
| 264 |
-
|
| 265 |
-
## 4. 当前下一步实验:cached train256 gate-only
|
| 266 |
-
|
| 267 |
-
用户已经完成 train256 的 q 缓存。下一步用 train256 跑更保守的 gate-only 泛化实验。
|
| 268 |
-
|
| 269 |
-
### Step 1:训练 cached gate-only train256
|
| 270 |
-
|
| 271 |
-
```bash
|
| 272 |
-
cd /workspace/SimToken
|
| 273 |
-
mkdir -p log checkpoints
|
| 274 |
-
|
| 275 |
-
TRANSFORMERS_OFFLINE=1 python -u -W ignore train_cached_gate.py \
|
| 276 |
-
--cache_split train \
|
| 277 |
-
--cache_root /workspace/SimToken/cache_q \
|
| 278 |
-
--name cached_gate_train256_s300_lr3e5 \
|
| 279 |
-
--epochs 20 \
|
| 280 |
-
--max_steps 300 \
|
| 281 |
-
--batch_size 8 \
|
| 282 |
-
--lr 3e-5 \
|
| 283 |
-
--saved_model /workspace/SimToken/checkpoints/simtoken_pretrained.pth \
|
| 284 |
-
--log_root /workspace/SimToken/log \
|
| 285 |
-
--checkpoint_root /workspace/SimToken/checkpoints \
|
| 286 |
-
--log_gate_stats_every 50 \
|
| 287 |
-
--skip_eval_after_train \
|
| 288 |
-
--save_gate_only \
|
| 289 |
-
2>&1 | tee /workspace/SimToken/log/cached_gate_train256_s300_lr3e5.stdout
|
| 290 |
-
```
|
| 291 |
-
|
| 292 |
-
训练中重点观察:
|
| 293 |
-
|
| 294 |
-
```text
|
| 295 |
-
batch_miou / batch_fscore 是否逐步改善
|
| 296 |
-
proj_norm 是否持续增长
|
| 297 |
-
alpha_std 是否温和分化
|
| 298 |
-
Null 风险:alpha 是否出现极端偏移
|
| 299 |
-
```
|
| 300 |
-
|
| 301 |
-
如果 `proj_norm` 在前 100 steps 仍接近 0,说明 lr=3e-5 可能过小,可以改回 1e-4 或使用分层 lr。
|
| 302 |
-
|
| 303 |
-
### Step 2:评估 cached train256 gate checkpoint
|
| 304 |
-
|
| 305 |
-
```bash
|
| 306 |
-
for split in test_s test_u test_n; do
|
| 307 |
-
TRANSFORMERS_OFFLINE=1 python -u -W ignore train_cached_gate.py \
|
| 308 |
-
--cache_split $split \
|
| 309 |
-
--cache_root /workspace/SimToken/cache_q \
|
| 310 |
-
--batch_size 8 \
|
| 311 |
-
--saved_model /workspace/SimToken/checkpoints/simtoken_pretrained.pth \
|
| 312 |
-
--gate_checkpoint /workspace/SimToken/checkpoints/cached_gate_train256_s300_lr3e5.pth \
|
| 313 |
-
--eval_only \
|
| 314 |
-
--name cached_gate_train256_s300_lr3e5_${split}_200 \
|
| 315 |
-
2>&1 | tee /workspace/SimToken/log/cached_gate_train256_s300_lr3e5_${split}_200.stdout
|
| 316 |
-
done
|
| 317 |
-
```
|
| 318 |
-
|
| 319 |
-
对比 baseline 使用 3.4 中 identity 200 条结果。
|
| 320 |
-
|
| 321 |
-
### Step 3:根据结果决策
|
| 322 |
-
|
| 323 |
-
判断标准:
|
| 324 |
-
|
| 325 |
-
- Seen / Unseen 都提升:进入更大 cached train 或完整 A-min;
|
| 326 |
-
- Seen 提升、Unseen 不提升:gate 仍可能学 dataset pattern,需要更多 train cache 或更强正则;
|
| 327 |
-
- Seen / Unseen 都下降:不要跑完整 A-min,先调 lr、正则或 gate 容量;
|
| 328 |
-
- Null metric 保持 `< 0.07`:暂不加 area penalty;
|
| 329 |
-
- Null metric 超过 `0.10`:强危险信号,需要 area penalty 或约束预测面积。
|
| 330 |
-
|
| 331 |
-
如果 train256 有弱正收益但幅度小,先看 alpha 分布和 hard/easy frames,而不是立刻扩大。若 alpha 在所有帧上几乎一致,可能只是全局偏置;若 hard frames alpha 系统性更高,说明更像 referent anchoring。
|
| 332 |
-
|
| 333 |
-
---
|
| 334 |
-
|
| 335 |
-
## 5. 成功标准
|
| 336 |
-
|
| 337 |
-
A-min 成功不能只看总体 mIoU,需要同时满足:
|
| 338 |
-
|
| 339 |
-
1. Seen / Unseen mIoU 稳定提升;
|
| 340 |
-
2. Unseen 至少不弱于 Seen 的提升趋势;
|
| 341 |
-
3. Null 指标不恶化,预测面积不膨胀;
|
| 342 |
-
4. hard frames 改善更明显;
|
| 343 |
-
5. 如果记录 gate alpha,hard frames 的 alpha 应系统性高于 easy frames。
|
| 344 |
-
|
| 345 |
-
失败解释:
|
| 346 |
-
|
| 347 |
-
- 如果 Seen 提升、Unseen 不提升:可能是 gate 学到数据集模式,而不是 referent anchoring;
|
| 348 |
-
- 如果 Null 恶化:gate 可能放大了通用前景显著性;
|
| 349 |
-
- 如果 gate-only 无变化但完整 A-min 有收益:说明 gate 需要与 mask decoder / text projection 协同适配;
|
| 350 |
-
- 如果全 split 下降:gate 插入位置、初始化或学习率需要重新检查。
|
| 351 |
-
|
| 352 |
-
---
|
| 353 |
-
|
| 354 |
-
## 6. 后续机制分析
|
| 355 |
-
|
| 356 |
-
如果 A-min 有正收益,再做 hook 分析:
|
| 357 |
-
|
| 358 |
-
1. sparse self-attention 中 `mask_token -> q`;
|
| 359 |
-
2. token-to-image attention 中 mask token 对 image tokens 的关注;
|
| 360 |
-
3. A-min 前后 hard/easy frames 的 gate alpha;
|
| 361 |
-
4. `s_pred(beta=1.0)` 与 gate alpha 的关系。
|
| 362 |
-
|
| 363 |
-
这部分用于论文解释,不作为当前阻塞项。
|
| 364 |
-
|
| 365 |
-
---
|
| 366 |
-
|
| 367 |
-
## 7. 当前一句话结论
|
| 368 |
-
|
| 369 |
-
> A-min gate 的梯度路径、表达能力和 cached pipeline 一致性已经通过验证;overfit32 能显著提升训练集但不能泛化。当前主线是用更大 cached train set(已完成 train256 cache)验证 gate-only 泛化,再决定是否投入完整 A-min 联合训练。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
target_frame_sweep.py
DELETED
|
@@ -1,265 +0,0 @@
|
|
| 1 |
-
import csv
|
| 2 |
-
import os
|
| 3 |
-
import random
|
| 4 |
-
from functools import partial
|
| 5 |
-
|
| 6 |
-
import numpy as np
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
import transformers
|
| 10 |
-
from torch.utils.data import DataLoader
|
| 11 |
-
|
| 12 |
-
from configs import args
|
| 13 |
-
from datasets import REFAVS
|
| 14 |
-
from decoder_invariance_check import build_model, set_seed
|
| 15 |
-
from load_model import collate_fn, dict_to_cuda
|
| 16 |
-
from utils import utility
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def decode_with_q(model, batch, q):
|
| 20 |
-
visual_model = model.get_model().visual_model
|
| 21 |
-
image_embeddings = batch["image_feats"][0]
|
| 22 |
-
|
| 23 |
-
sparse, dense = visual_model.prompt_encoder(
|
| 24 |
-
points=None,
|
| 25 |
-
boxes=None,
|
| 26 |
-
masks=None,
|
| 27 |
-
text_embeds=q.unsqueeze(1),
|
| 28 |
-
)
|
| 29 |
-
sparse = sparse.to(q.dtype)
|
| 30 |
-
dense = dense.to(q.dtype)
|
| 31 |
-
|
| 32 |
-
low_res_masks, iou_predictions = visual_model.mask_decoder(
|
| 33 |
-
image_embeddings=image_embeddings,
|
| 34 |
-
image_pe=visual_model.prompt_encoder.get_dense_pe(),
|
| 35 |
-
sparse_prompt_embeddings=sparse,
|
| 36 |
-
dense_prompt_embeddings=dense,
|
| 37 |
-
multimask_output=False,
|
| 38 |
-
)
|
| 39 |
-
pred_masks = visual_model.postprocess_masks(
|
| 40 |
-
low_res_masks,
|
| 41 |
-
input_size=batch["resizes"][0],
|
| 42 |
-
original_size=batch["orgsizes"][0],
|
| 43 |
-
).squeeze(1)
|
| 44 |
-
return pred_masks.unsqueeze(0), iou_predictions.squeeze(-1)
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def get_q_for_target_frame(model, batch, target_frame):
|
| 48 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 49 |
-
output = model.forward(
|
| 50 |
-
images=batch["images"],
|
| 51 |
-
images_clip=batch["images_clip"],
|
| 52 |
-
audio_features=batch["audio_feats"],
|
| 53 |
-
image_features=batch["image_feats"],
|
| 54 |
-
input_ids=batch["input_ids"],
|
| 55 |
-
labels=batch["labels"],
|
| 56 |
-
attention_masks=batch["attention_masks"],
|
| 57 |
-
masks_list=batch["masks"],
|
| 58 |
-
resize_list=batch["resizes"],
|
| 59 |
-
orgsize_list=batch["orgsizes"],
|
| 60 |
-
conversation_list=batch["convs"],
|
| 61 |
-
refs_num=batch["refs_num"],
|
| 62 |
-
fids=batch["fids"],
|
| 63 |
-
vids=batch["vids"],
|
| 64 |
-
contrast=args.ct_weight,
|
| 65 |
-
ref_ids=batch["ref_ids"],
|
| 66 |
-
inference=True,
|
| 67 |
-
target_frame=target_frame,
|
| 68 |
-
)
|
| 69 |
-
return output["seg_embeddings"][0][0:1]
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
def mask_area(pred_masks):
|
| 73 |
-
return (torch.sigmoid(pred_masks) > 0.4).float().mean().item()
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
def mean_mask_iou_to_others(mask, other_masks):
|
| 77 |
-
if not other_masks:
|
| 78 |
-
return 1.0
|
| 79 |
-
binary = (torch.sigmoid(mask) > 0.4).float()
|
| 80 |
-
other_binary = [(torch.sigmoid(m) > 0.4).float() for m in other_masks]
|
| 81 |
-
vals = []
|
| 82 |
-
for other in other_binary:
|
| 83 |
-
inter = (binary * other).sum()
|
| 84 |
-
union = torch.maximum(binary, other).sum()
|
| 85 |
-
vals.append((inter / (union + 1e-7)).item())
|
| 86 |
-
return float(np.mean(vals))
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
def evaluate_one_sample(model, batch, sample_idx):
|
| 90 |
-
rows = []
|
| 91 |
-
qs = []
|
| 92 |
-
pred_masks_by_tf = []
|
| 93 |
-
|
| 94 |
-
gt_masks = batch["masks"][0]
|
| 95 |
-
vid = batch["vids"][0]
|
| 96 |
-
ref = batch["refs"][0][0]
|
| 97 |
-
|
| 98 |
-
for target_frame in range(args.frame_n):
|
| 99 |
-
q = get_q_for_target_frame(model, batch, target_frame)
|
| 100 |
-
qs.append(q.float().squeeze(0))
|
| 101 |
-
|
| 102 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 103 |
-
pred_masks, iou_predictions = decode_with_q(model, batch, q)
|
| 104 |
-
pred_masks_by_tf.append(pred_masks.detach())
|
| 105 |
-
|
| 106 |
-
miou = utility.mask_iou(pred_masks.float(), gt_masks.float())
|
| 107 |
-
fscore = utility.Eval_Fmeasure(pred_masks.float(), gt_masks.float(), None)
|
| 108 |
-
null_metric = utility.metric_s_for_null(pred_masks.float())
|
| 109 |
-
area = mask_area(pred_masks)
|
| 110 |
-
mean_iou_pred = iou_predictions.float().mean().item()
|
| 111 |
-
|
| 112 |
-
rows.append(
|
| 113 |
-
{
|
| 114 |
-
"sample_idx": sample_idx,
|
| 115 |
-
"vid": vid,
|
| 116 |
-
"ref": ref,
|
| 117 |
-
"target_frame": target_frame,
|
| 118 |
-
"mean_iou_pred": mean_iou_pred,
|
| 119 |
-
"mask_area": area,
|
| 120 |
-
"null_metric": float(null_metric),
|
| 121 |
-
"miou": miou,
|
| 122 |
-
"fscore": fscore,
|
| 123 |
-
"cos_to_q5": 0.0,
|
| 124 |
-
"mean_cos_to_other_q": 0.0,
|
| 125 |
-
"mean_mask_iou_to_other_tf": 0.0,
|
| 126 |
-
}
|
| 127 |
-
)
|
| 128 |
-
|
| 129 |
-
q_stack = F.normalize(torch.stack(qs, dim=0), dim=-1)
|
| 130 |
-
q_cos = q_stack @ q_stack.T
|
| 131 |
-
q5_idx = min(5, len(qs) - 1)
|
| 132 |
-
|
| 133 |
-
for i, row in enumerate(rows):
|
| 134 |
-
other = [j for j in range(len(rows)) if j != i]
|
| 135 |
-
row["cos_to_q5"] = q_cos[i, q5_idx].item()
|
| 136 |
-
row["mean_cos_to_other_q"] = q_cos[i, other].mean().item()
|
| 137 |
-
row["mean_mask_iou_to_other_tf"] = mean_mask_iou_to_others(
|
| 138 |
-
pred_masks_by_tf[i], [pred_masks_by_tf[j] for j in other]
|
| 139 |
-
)
|
| 140 |
-
|
| 141 |
-
return rows
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
def print_sample_summary(rows):
|
| 145 |
-
print(f"\nSample {rows[0]['sample_idx']}: vid={rows[0]['vid']} ref={rows[0]['ref']}")
|
| 146 |
-
print("tf | miou | fscore | null_s | iou_pred | area | cos_to_q5 | mean_q_cos")
|
| 147 |
-
for row in rows:
|
| 148 |
-
print(
|
| 149 |
-
f"{row['target_frame']:02d} | "
|
| 150 |
-
f"{row['miou']:.4f} | "
|
| 151 |
-
f"{row['fscore']:.4f} | "
|
| 152 |
-
f"{row['null_metric']:.4f} | "
|
| 153 |
-
f"{row['mean_iou_pred']:.4f} | "
|
| 154 |
-
f"{row['mask_area']:.4f} | "
|
| 155 |
-
f"{row['cos_to_q5']:.4f} | "
|
| 156 |
-
f"{row['mean_cos_to_other_q']:.4f}"
|
| 157 |
-
)
|
| 158 |
-
|
| 159 |
-
best_miou = max(rows, key=lambda x: x["miou"])
|
| 160 |
-
best_iou_pred = max(rows, key=lambda x: x["mean_iou_pred"])
|
| 161 |
-
fixed = rows[min(5, len(rows) - 1)]
|
| 162 |
-
miou_values = [row["miou"] for row in rows]
|
| 163 |
-
q5_values = [row["cos_to_q5"] for row in rows]
|
| 164 |
-
print(
|
| 165 |
-
"Best miou tf="
|
| 166 |
-
f"{best_miou['target_frame']} ({best_miou['miou']:.4f}); "
|
| 167 |
-
"best iou_pred tf="
|
| 168 |
-
f"{best_iou_pred['target_frame']} ({best_iou_pred['mean_iou_pred']:.4f}); "
|
| 169 |
-
f"fixed tf=5 miou={fixed['miou']:.4f}"
|
| 170 |
-
)
|
| 171 |
-
print(
|
| 172 |
-
f"target-frame miou range={max(miou_values) - min(miou_values):.4f}; "
|
| 173 |
-
f"min cos_to_q5={min(q5_values):.4f}"
|
| 174 |
-
)
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
def main():
|
| 178 |
-
set_seed(42)
|
| 179 |
-
torch.set_grad_enabled(False)
|
| 180 |
-
|
| 181 |
-
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 182 |
-
args.mllm,
|
| 183 |
-
cache_dir=None,
|
| 184 |
-
model_max_length=2048,
|
| 185 |
-
padding_side="right",
|
| 186 |
-
use_fast=False,
|
| 187 |
-
)
|
| 188 |
-
tokenizer.pad_token = tokenizer.unk_token
|
| 189 |
-
tokenizer.add_tokens("[SEG]")
|
| 190 |
-
seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
|
| 191 |
-
|
| 192 |
-
dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
|
| 193 |
-
loader = DataLoader(
|
| 194 |
-
dataset,
|
| 195 |
-
batch_size=1,
|
| 196 |
-
shuffle=False,
|
| 197 |
-
num_workers=0,
|
| 198 |
-
collate_fn=partial(collate_fn, tokenizer=tokenizer),
|
| 199 |
-
)
|
| 200 |
-
|
| 201 |
-
limit = args.max_eval_rows if args.max_eval_rows > 0 else 1
|
| 202 |
-
print(f"Split: {args.eval_split} | samples to sweep: {limit}")
|
| 203 |
-
|
| 204 |
-
model = build_model(tokenizer, seg_token_idx)
|
| 205 |
-
|
| 206 |
-
all_rows = []
|
| 207 |
-
for sample_idx, batch in enumerate(loader):
|
| 208 |
-
if sample_idx >= limit:
|
| 209 |
-
break
|
| 210 |
-
batch = dict_to_cuda(batch)
|
| 211 |
-
rows = evaluate_one_sample(model, batch, sample_idx)
|
| 212 |
-
all_rows.extend(rows)
|
| 213 |
-
print_sample_summary(rows)
|
| 214 |
-
|
| 215 |
-
if not all_rows:
|
| 216 |
-
raise RuntimeError("No rows were checked. Is the selected split empty?")
|
| 217 |
-
|
| 218 |
-
fixed_rows = [r for r in all_rows if r["target_frame"] == min(5, args.frame_n - 1)]
|
| 219 |
-
oracle_by_sample = {}
|
| 220 |
-
iou_pred_by_sample = {}
|
| 221 |
-
for row in all_rows:
|
| 222 |
-
key = row["sample_idx"]
|
| 223 |
-
if key not in oracle_by_sample or row["miou"] > oracle_by_sample[key]["miou"]:
|
| 224 |
-
oracle_by_sample[key] = row
|
| 225 |
-
if key not in iou_pred_by_sample or row["mean_iou_pred"] > iou_pred_by_sample[key]["mean_iou_pred"]:
|
| 226 |
-
iou_pred_by_sample[key] = row
|
| 227 |
-
|
| 228 |
-
fixed_miou = np.mean([r["miou"] for r in fixed_rows])
|
| 229 |
-
fixed_null_metric = np.mean([r["null_metric"] for r in fixed_rows])
|
| 230 |
-
oracle_miou = np.mean([r["miou"] for r in oracle_by_sample.values()])
|
| 231 |
-
iou_pred_selected_miou = np.mean([r["miou"] for r in iou_pred_by_sample.values()])
|
| 232 |
-
min_cos_to_q5 = np.mean(
|
| 233 |
-
[min(r["cos_to_q5"] for r in all_rows if r["sample_idx"] == sample_idx) for sample_idx in oracle_by_sample]
|
| 234 |
-
)
|
| 235 |
-
mean_miou_range = np.mean(
|
| 236 |
-
[
|
| 237 |
-
max(r["miou"] for r in all_rows if r["sample_idx"] == sample_idx)
|
| 238 |
-
- min(r["miou"] for r in all_rows if r["sample_idx"] == sample_idx)
|
| 239 |
-
for sample_idx in oracle_by_sample
|
| 240 |
-
]
|
| 241 |
-
)
|
| 242 |
-
|
| 243 |
-
print("\nSummary")
|
| 244 |
-
print(f"samples: {len(fixed_rows)}")
|
| 245 |
-
print(f"fixed target_frame=5 mean miou: {fixed_miou:.4f}")
|
| 246 |
-
print(f"fixed target_frame=5 mean null_s: {fixed_null_metric:.4f}")
|
| 247 |
-
print(f"oracle best-target-frame mean miou: {oracle_miou:.4f}")
|
| 248 |
-
print(f"best-by-iou_pred selected mean miou: {iou_pred_selected_miou:.4f}")
|
| 249 |
-
print(f"oracle gain over fixed: {oracle_miou - fixed_miou:+.4f}")
|
| 250 |
-
print(f"iou_pred-selection gain over fixed: {iou_pred_selected_miou - fixed_miou:+.4f}")
|
| 251 |
-
print(f"mean target-frame miou range: {mean_miou_range:.4f}")
|
| 252 |
-
print(f"mean sample min cos_to_q5: {min_cos_to_q5:.4f}")
|
| 253 |
-
|
| 254 |
-
csv_path = os.environ.get("TARGET_FRAME_SWEEP_CSV")
|
| 255 |
-
if csv_path:
|
| 256 |
-
os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
|
| 257 |
-
with open(csv_path, "w", newline="") as f:
|
| 258 |
-
writer = csv.DictWriter(f, fieldnames=list(all_rows[0].keys()))
|
| 259 |
-
writer.writeheader()
|
| 260 |
-
writer.writerows(all_rows)
|
| 261 |
-
print(f"Saved CSV: {csv_path}")
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
if __name__ == "__main__":
|
| 265 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_cached_gate.py
DELETED
|
@@ -1,439 +0,0 @@
|
|
| 1 |
-
import json
|
| 2 |
-
import os
|
| 3 |
-
import random
|
| 4 |
-
|
| 5 |
-
import cv2
|
| 6 |
-
import numpy as np
|
| 7 |
-
import torch
|
| 8 |
-
import transformers
|
| 9 |
-
from torch.optim import AdamW
|
| 10 |
-
from torch.utils.data import DataLoader, Dataset, Subset
|
| 11 |
-
from tqdm import tqdm
|
| 12 |
-
|
| 13 |
-
from configs import args
|
| 14 |
-
from decoder_invariance_check import build_model, set_seed
|
| 15 |
-
from models.avs_model import dice_loss, sigmoid_ce_loss
|
| 16 |
-
from utils import utility
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def _total_norm(values):
|
| 20 |
-
if not values:
|
| 21 |
-
return 0.0
|
| 22 |
-
return float(sum(v * v for v in values) ** 0.5)
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def collect_referent_gate_stats(model):
|
| 26 |
-
gate_modules = [(n, m) for n, m in model.named_modules() if n.endswith("referent_gate")]
|
| 27 |
-
proj_norms = []
|
| 28 |
-
gate_norms = []
|
| 29 |
-
proj_grad_norms = []
|
| 30 |
-
gate_grad_norms = []
|
| 31 |
-
alpha_tensors = []
|
| 32 |
-
|
| 33 |
-
for _, module in gate_modules:
|
| 34 |
-
proj_norms.append(module.proj.weight.detach().float().norm().item())
|
| 35 |
-
gate_norms.append(module.gate.weight.detach().float().norm().item())
|
| 36 |
-
if module.proj.weight.grad is not None:
|
| 37 |
-
proj_grad_norms.append(module.proj.weight.grad.detach().float().norm().item())
|
| 38 |
-
if module.gate.weight.grad is not None:
|
| 39 |
-
gate_grad_norms.append(module.gate.weight.grad.detach().float().norm().item())
|
| 40 |
-
if module.last_alpha is not None:
|
| 41 |
-
alpha_tensors.append(module.last_alpha.detach().float().reshape(-1))
|
| 42 |
-
|
| 43 |
-
stats = {
|
| 44 |
-
"modules": len(gate_modules),
|
| 45 |
-
"proj_norm": _total_norm(proj_norms),
|
| 46 |
-
"gate_norm": _total_norm(gate_norms),
|
| 47 |
-
"proj_grad_norm": _total_norm(proj_grad_norms),
|
| 48 |
-
"gate_grad_norm": _total_norm(gate_grad_norms),
|
| 49 |
-
}
|
| 50 |
-
|
| 51 |
-
if alpha_tensors:
|
| 52 |
-
alpha = torch.cat(alpha_tensors)
|
| 53 |
-
stats.update(
|
| 54 |
-
{
|
| 55 |
-
"alpha_mean": alpha.mean().item(),
|
| 56 |
-
"alpha_std": alpha.std(unbiased=False).item(),
|
| 57 |
-
"alpha_min": alpha.min().item(),
|
| 58 |
-
"alpha_max": alpha.max().item(),
|
| 59 |
-
}
|
| 60 |
-
)
|
| 61 |
-
else:
|
| 62 |
-
stats.update(
|
| 63 |
-
{
|
| 64 |
-
"alpha_mean": float("nan"),
|
| 65 |
-
"alpha_std": float("nan"),
|
| 66 |
-
"alpha_min": float("nan"),
|
| 67 |
-
"alpha_max": float("nan"),
|
| 68 |
-
}
|
| 69 |
-
)
|
| 70 |
-
|
| 71 |
-
return stats
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def zero_referent_gate(model):
|
| 75 |
-
with torch.no_grad():
|
| 76 |
-
for _, module in model.named_modules():
|
| 77 |
-
if not _.endswith("referent_gate"):
|
| 78 |
-
continue
|
| 79 |
-
module.gate.weight.zero_()
|
| 80 |
-
module.gate.bias.zero_()
|
| 81 |
-
module.proj.weight.zero_()
|
| 82 |
-
module.proj.bias.zero_()
|
| 83 |
-
module.last_alpha = None
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
def referent_gate_state_dict(model):
|
| 87 |
-
return {
|
| 88 |
-
name: param.detach().cpu()
|
| 89 |
-
for name, param in model.state_dict().items()
|
| 90 |
-
if "referent_gate" in name
|
| 91 |
-
}
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
def load_referent_gate_checkpoint(model, path):
|
| 95 |
-
checkpoint = torch.load(path, map_location="cpu")
|
| 96 |
-
if isinstance(checkpoint, dict) and checkpoint.get("type") == "referent_gate_only":
|
| 97 |
-
checkpoint = checkpoint["state_dict"]
|
| 98 |
-
gate_state = {k: v for k, v in checkpoint.items() if "referent_gate" in k}
|
| 99 |
-
if not gate_state:
|
| 100 |
-
raise RuntimeError(f"No referent_gate parameters found in {path}")
|
| 101 |
-
current = model.state_dict()
|
| 102 |
-
missing_shape = [
|
| 103 |
-
k
|
| 104 |
-
for k, v in gate_state.items()
|
| 105 |
-
if k not in current or tuple(current[k].shape) != tuple(v.shape)
|
| 106 |
-
]
|
| 107 |
-
if missing_shape:
|
| 108 |
-
raise RuntimeError(f"Gate checkpoint has incompatible keys: {missing_shape[:5]}")
|
| 109 |
-
current.update(gate_state)
|
| 110 |
-
model.load_state_dict(current, strict=True)
|
| 111 |
-
print(f"loaded referent gate checkpoint: {path} ({len(gate_state)} tensors)")
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
def log_gate_stats(model, step, loss_value, batch_metrics=None):
|
| 115 |
-
stats = collect_referent_gate_stats(model)
|
| 116 |
-
metric_text = ""
|
| 117 |
-
if batch_metrics is not None:
|
| 118 |
-
metric_text = (
|
| 119 |
-
f"batch_miou={batch_metrics['miou']:.4f} "
|
| 120 |
-
f"batch_fscore={batch_metrics['fscore']:.4f} "
|
| 121 |
-
)
|
| 122 |
-
message = (
|
| 123 |
-
f"gate_stats step={step} "
|
| 124 |
-
f"loss={loss_value:.6f} "
|
| 125 |
-
f"{metric_text}"
|
| 126 |
-
f"proj_norm={stats['proj_norm']:.6f} "
|
| 127 |
-
f"gate_norm={stats['gate_norm']:.6f} "
|
| 128 |
-
f"proj_grad_norm={stats['proj_grad_norm']:.6f} "
|
| 129 |
-
f"gate_grad_norm={stats['gate_grad_norm']:.6f} "
|
| 130 |
-
f"alpha_mean={stats['alpha_mean']:.4f} "
|
| 131 |
-
f"alpha_std={stats['alpha_std']:.4f} "
|
| 132 |
-
f"alpha_min={stats['alpha_min']:.4f} "
|
| 133 |
-
f"alpha_max={stats['alpha_max']:.4f}"
|
| 134 |
-
)
|
| 135 |
-
print(message)
|
| 136 |
-
os.makedirs(args.log_root, exist_ok=True)
|
| 137 |
-
with open(os.path.join(args.log_root, f"{args.name}.txt"), "a") as f:
|
| 138 |
-
f.write(message + "\n")
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
class CachedQDataset(Dataset):
|
| 142 |
-
def __init__(self, split, cfg):
|
| 143 |
-
self.split = split
|
| 144 |
-
self.cfg = cfg
|
| 145 |
-
self.root = os.path.join(cfg.cache_root, split)
|
| 146 |
-
self.index_path = os.path.join(self.root, "index.jsonl")
|
| 147 |
-
if not os.path.exists(self.index_path):
|
| 148 |
-
raise FileNotFoundError(f"Missing cache index: {self.index_path}")
|
| 149 |
-
with open(self.index_path) as f:
|
| 150 |
-
self.rows = [json.loads(line) for line in f if line.strip()]
|
| 151 |
-
|
| 152 |
-
def __len__(self):
|
| 153 |
-
return len(self.rows)
|
| 154 |
-
|
| 155 |
-
def _load_masks(self, vid, fids):
|
| 156 |
-
masks = []
|
| 157 |
-
for fid in fids:
|
| 158 |
-
frames = []
|
| 159 |
-
for frame_idx in range(self.cfg.frame_n):
|
| 160 |
-
path = os.path.join(
|
| 161 |
-
self.cfg.data_dir,
|
| 162 |
-
"gt_mask",
|
| 163 |
-
vid,
|
| 164 |
-
f"fid_{int(fid)}",
|
| 165 |
-
f"0000{frame_idx}.png",
|
| 166 |
-
)
|
| 167 |
-
mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
|
| 168 |
-
if mask is None:
|
| 169 |
-
raise FileNotFoundError(path)
|
| 170 |
-
frames.append(torch.as_tensor(mask > 0, dtype=torch.float32))
|
| 171 |
-
masks.append(torch.stack(frames, dim=0))
|
| 172 |
-
return torch.stack(masks, dim=0)
|
| 173 |
-
|
| 174 |
-
def __getitem__(self, idx):
|
| 175 |
-
row = self.rows[idx]
|
| 176 |
-
cache = torch.load(os.path.join(self.root, row["path"]), map_location="cpu")
|
| 177 |
-
vid = cache["vid"]
|
| 178 |
-
return {
|
| 179 |
-
"sample_idx": cache["sample_idx"],
|
| 180 |
-
"vid": vid,
|
| 181 |
-
"refs": cache["refs"],
|
| 182 |
-
"fids": cache["fids"],
|
| 183 |
-
"q": cache["q"].float(),
|
| 184 |
-
"image_embeddings": torch.load(
|
| 185 |
-
os.path.join(self.cfg.data_dir, "image_embed", f"{vid}.pt"),
|
| 186 |
-
map_location="cpu",
|
| 187 |
-
).float(),
|
| 188 |
-
"gt_masks": self._load_masks(vid, cache["fids"]),
|
| 189 |
-
"resize": tuple(cache["resize"]),
|
| 190 |
-
"orgsize": tuple(cache["orgsize"]),
|
| 191 |
-
}
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
def collate_cached(batch):
|
| 195 |
-
return batch
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
def decode_batch(visual_model, batch, device):
|
| 199 |
-
image_pe = visual_model.prompt_encoder.get_dense_pe()
|
| 200 |
-
frame_qs = []
|
| 201 |
-
frame_image_embeddings = []
|
| 202 |
-
prompt_spans = []
|
| 203 |
-
|
| 204 |
-
for sample_idx, sample in enumerate(batch):
|
| 205 |
-
q = sample["q"].to(device=device, dtype=torch.float32)
|
| 206 |
-
image_embeddings = sample["image_embeddings"].to(device=device, dtype=torch.float32)
|
| 207 |
-
frames = image_embeddings.shape[0]
|
| 208 |
-
for prompt_idx in range(q.shape[0]):
|
| 209 |
-
start = len(frame_qs) * frames
|
| 210 |
-
frame_qs.append(q[prompt_idx].unsqueeze(0).expand(frames, -1))
|
| 211 |
-
frame_image_embeddings.append(image_embeddings)
|
| 212 |
-
prompt_spans.append((sample_idx, prompt_idx, start, start + frames))
|
| 213 |
-
|
| 214 |
-
if not frame_qs:
|
| 215 |
-
raise RuntimeError("No cached prompts were provided for decoding.")
|
| 216 |
-
|
| 217 |
-
frame_qs = torch.cat(frame_qs, dim=0)
|
| 218 |
-
frame_image_embeddings = torch.cat(frame_image_embeddings, dim=0)
|
| 219 |
-
sparse_embeddings, dense_embeddings = visual_model.prompt_encoder(
|
| 220 |
-
points=None,
|
| 221 |
-
boxes=None,
|
| 222 |
-
masks=None,
|
| 223 |
-
text_embeds=frame_qs.unsqueeze(1),
|
| 224 |
-
)
|
| 225 |
-
sparse_embeddings = sparse_embeddings.to(frame_qs.dtype)
|
| 226 |
-
dense_embeddings = dense_embeddings.to(frame_qs.dtype)
|
| 227 |
-
|
| 228 |
-
low_res_masks = visual_model.mask_decoder.forward_modified_v3(
|
| 229 |
-
image_embeddings=frame_image_embeddings,
|
| 230 |
-
image_pe=image_pe,
|
| 231 |
-
sparse_prompt_embeddings=sparse_embeddings,
|
| 232 |
-
dense_prompt_embeddings=dense_embeddings,
|
| 233 |
-
).unsqueeze(1)
|
| 234 |
-
|
| 235 |
-
pred_by_sample = [[] for _ in batch]
|
| 236 |
-
for sample_idx, _, start, end in prompt_spans:
|
| 237 |
-
sample = batch[sample_idx]
|
| 238 |
-
pred_mask = visual_model.postprocess_masks(
|
| 239 |
-
low_res_masks[start:end],
|
| 240 |
-
input_size=sample["resize"],
|
| 241 |
-
original_size=sample["orgsize"],
|
| 242 |
-
)
|
| 243 |
-
pred_by_sample[sample_idx].append(pred_mask.squeeze(1))
|
| 244 |
-
|
| 245 |
-
return [torch.stack(pred_masks, dim=0) for pred_masks in pred_by_sample]
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
def decode_sample(visual_model, sample, device):
|
| 249 |
-
return decode_batch(visual_model, [sample], device)[0]
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
def compute_mask_loss(pred_masks, gt_masks):
|
| 253 |
-
mask_bce_loss = 0.0
|
| 254 |
-
mask_dice_loss = 0.0
|
| 255 |
-
num_masks = 0
|
| 256 |
-
|
| 257 |
-
for pred_mask, gt_mask in zip(pred_masks, gt_masks):
|
| 258 |
-
gt_mask = gt_mask.to(device=pred_mask.device, dtype=pred_mask.dtype)
|
| 259 |
-
num_seg, frames, height, width = gt_mask.shape
|
| 260 |
-
gt_flat = gt_mask.view(num_seg * frames, height, width)
|
| 261 |
-
pred_flat = pred_mask.view(num_seg * frames, height, width)
|
| 262 |
-
|
| 263 |
-
mask_bce_loss = mask_bce_loss + (
|
| 264 |
-
sigmoid_ce_loss(pred_flat, gt_flat, num_masks=gt_flat.shape[0])
|
| 265 |
-
* gt_flat.shape[0]
|
| 266 |
-
)
|
| 267 |
-
mask_dice_loss = mask_dice_loss + (
|
| 268 |
-
dice_loss(pred_flat, gt_flat, num_masks=gt_flat.shape[0])
|
| 269 |
-
* gt_flat.shape[0]
|
| 270 |
-
)
|
| 271 |
-
num_masks += gt_flat.shape[0]
|
| 272 |
-
|
| 273 |
-
mask_bce_loss = 2.0 * mask_bce_loss / (num_masks + 1e-8)
|
| 274 |
-
mask_dice_loss = 0.5 * mask_dice_loss / (num_masks + 1e-8)
|
| 275 |
-
return mask_bce_loss + mask_dice_loss
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
def compute_batch_metrics(pred_masks, gt_masks):
|
| 279 |
-
total_iou = 0.0
|
| 280 |
-
total_fscore = 0.0
|
| 281 |
-
count = 0
|
| 282 |
-
for pred_mask, gt_mask in zip(pred_masks, gt_masks):
|
| 283 |
-
gt_mask = gt_mask.to(device=pred_mask.device, dtype=pred_mask.dtype)
|
| 284 |
-
num_seg, frames = pred_mask.shape[:2]
|
| 285 |
-
weight = num_seg * frames
|
| 286 |
-
total_iou += utility.mask_iou(pred_mask.detach().float(), gt_mask.float()) * weight
|
| 287 |
-
total_fscore += utility.Eval_Fmeasure(pred_mask.detach().float(), gt_mask.float(), None) * weight
|
| 288 |
-
count += weight
|
| 289 |
-
return {
|
| 290 |
-
"miou": total_iou / max(1, count),
|
| 291 |
-
"fscore": total_fscore / max(1, count),
|
| 292 |
-
}
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
def evaluate(model, loader):
|
| 296 |
-
model.eval()
|
| 297 |
-
visual_model = model.get_model().visual_model
|
| 298 |
-
total_iou = 0.0
|
| 299 |
-
total_fscore = 0.0
|
| 300 |
-
total_null = 0.0
|
| 301 |
-
count = 0
|
| 302 |
-
|
| 303 |
-
with torch.no_grad():
|
| 304 |
-
for batch in tqdm(loader, desc=f"Cached eval {args.cache_split}"):
|
| 305 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 306 |
-
batch_pred = decode_batch(visual_model, batch, "cuda")
|
| 307 |
-
for sample, pred in zip(batch, batch_pred):
|
| 308 |
-
gt = sample["gt_masks"].to(device=pred.device, dtype=pred.dtype)
|
| 309 |
-
num_seg, frames = pred.shape[:2]
|
| 310 |
-
weight = num_seg * frames
|
| 311 |
-
if args.cache_split == "test_n":
|
| 312 |
-
total_null += float(utility.metric_s_for_null(pred.float())) * weight
|
| 313 |
-
else:
|
| 314 |
-
total_iou += utility.mask_iou(pred.float(), gt.float()) * weight
|
| 315 |
-
total_fscore += utility.Eval_Fmeasure(pred.float(), gt.float(), None) * weight
|
| 316 |
-
count += weight
|
| 317 |
-
|
| 318 |
-
if count == 0:
|
| 319 |
-
raise RuntimeError("No cached samples were evaluated.")
|
| 320 |
-
|
| 321 |
-
if args.cache_split == "test_n":
|
| 322 |
-
print(f"cached valuate on test_n_refer, metric: {total_null / count}")
|
| 323 |
-
else:
|
| 324 |
-
print(
|
| 325 |
-
f"cached valuate on {args.cache_split}: "
|
| 326 |
-
f"miou: {total_iou / count} fscore: {total_fscore / count}"
|
| 327 |
-
)
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
def train(model, loader):
|
| 331 |
-
if args.disable_gate:
|
| 332 |
-
raise ValueError("--disable_gate is only valid with --eval_only")
|
| 333 |
-
|
| 334 |
-
for p in model.parameters():
|
| 335 |
-
p.requires_grad = False
|
| 336 |
-
for name, p in model.named_parameters():
|
| 337 |
-
if "referent_gate" in name:
|
| 338 |
-
p.requires_grad = True
|
| 339 |
-
|
| 340 |
-
gate_params = [p for p in model.parameters() if p.requires_grad]
|
| 341 |
-
optimizer = AdamW(gate_params, lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)
|
| 342 |
-
stats = collect_referent_gate_stats(model)
|
| 343 |
-
print(
|
| 344 |
-
"cached gate init: "
|
| 345 |
-
f"modules={stats['modules']} "
|
| 346 |
-
f"proj_norm={stats['proj_norm']:.6f} "
|
| 347 |
-
f"gate_norm={stats['gate_norm']:.6f} "
|
| 348 |
-
f"trainable_params={sum(p.numel() for p in gate_params)}"
|
| 349 |
-
)
|
| 350 |
-
|
| 351 |
-
visual_model = model.get_model().visual_model
|
| 352 |
-
step = 0
|
| 353 |
-
for epoch in range(args.epochs):
|
| 354 |
-
model.train()
|
| 355 |
-
order_loader = loader
|
| 356 |
-
for batch in tqdm(order_loader, desc=f"Cached gate train {epoch + 1}/{args.epochs}"):
|
| 357 |
-
if args.max_steps > 0 and step >= args.max_steps:
|
| 358 |
-
break
|
| 359 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 360 |
-
pred_masks = decode_batch(visual_model, batch, "cuda")
|
| 361 |
-
gt_masks = [sample["gt_masks"] for sample in batch]
|
| 362 |
-
|
| 363 |
-
loss = compute_mask_loss(pred_masks, gt_masks)
|
| 364 |
-
optimizer.zero_grad()
|
| 365 |
-
loss.backward()
|
| 366 |
-
step += 1
|
| 367 |
-
if args.log_gate_stats_every > 0 and step % args.log_gate_stats_every == 0:
|
| 368 |
-
batch_metrics = compute_batch_metrics(pred_masks, gt_masks)
|
| 369 |
-
log_gate_stats(model, step, loss.item(), batch_metrics)
|
| 370 |
-
optimizer.step()
|
| 371 |
-
|
| 372 |
-
if args.max_steps > 0 and step >= args.max_steps:
|
| 373 |
-
print(f"stopped early at cached optimizer step {step}")
|
| 374 |
-
break
|
| 375 |
-
|
| 376 |
-
os.makedirs(args.checkpoint_root, exist_ok=True)
|
| 377 |
-
ckpt_path = os.path.join(args.checkpoint_root, f"{args.name}.pth")
|
| 378 |
-
if args.save_gate_only:
|
| 379 |
-
torch.save(
|
| 380 |
-
{
|
| 381 |
-
"type": "referent_gate_only",
|
| 382 |
-
"base_model": args.saved_model,
|
| 383 |
-
"state_dict": referent_gate_state_dict(model),
|
| 384 |
-
},
|
| 385 |
-
ckpt_path,
|
| 386 |
-
)
|
| 387 |
-
else:
|
| 388 |
-
torch.save(model.state_dict(), ckpt_path)
|
| 389 |
-
print(f"cached gate model saved as {ckpt_path}")
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
def main():
|
| 393 |
-
set_seed(42)
|
| 394 |
-
random.seed(42)
|
| 395 |
-
np.random.seed(42)
|
| 396 |
-
|
| 397 |
-
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 398 |
-
args.mllm,
|
| 399 |
-
cache_dir=None,
|
| 400 |
-
model_max_length=2048,
|
| 401 |
-
padding_side="right",
|
| 402 |
-
use_fast=False,
|
| 403 |
-
)
|
| 404 |
-
tokenizer.pad_token = tokenizer.unk_token
|
| 405 |
-
tokenizer.add_tokens("[SEG]")
|
| 406 |
-
seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
|
| 407 |
-
|
| 408 |
-
dataset = CachedQDataset(args.cache_split, args)
|
| 409 |
-
if args.overfit_samples > 0:
|
| 410 |
-
n = min(args.overfit_samples, len(dataset))
|
| 411 |
-
dataset = Subset(dataset, list(range(n)))
|
| 412 |
-
print(f"cached overfit_samples enabled: using first {n} samples")
|
| 413 |
-
|
| 414 |
-
loader = DataLoader(
|
| 415 |
-
dataset,
|
| 416 |
-
batch_size=args.batch_size,
|
| 417 |
-
shuffle=not args.eval_only,
|
| 418 |
-
num_workers=4,
|
| 419 |
-
collate_fn=collate_cached,
|
| 420 |
-
)
|
| 421 |
-
|
| 422 |
-
model = build_model(tokenizer, seg_token_idx)
|
| 423 |
-
if args.gate_checkpoint:
|
| 424 |
-
load_referent_gate_checkpoint(model, args.gate_checkpoint)
|
| 425 |
-
if args.disable_gate:
|
| 426 |
-
zero_referent_gate(model)
|
| 427 |
-
print("disable_gate enabled: referent gate forced to identity")
|
| 428 |
-
|
| 429 |
-
if args.eval_only:
|
| 430 |
-
evaluate(model, loader)
|
| 431 |
-
return
|
| 432 |
-
|
| 433 |
-
train(model, loader)
|
| 434 |
-
if not args.skip_eval_after_train:
|
| 435 |
-
evaluate(model, loader)
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
if __name__ == "__main__":
|
| 439 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
upload_hf.py
DELETED
|
@@ -1,74 +0,0 @@
|
|
| 1 |
-
"""Upload the current SimToken workspace to HuggingFace Hub.
|
| 2 |
-
|
| 3 |
-
Example:
|
| 4 |
-
python upload_hf.py --repo yfan07/SimToken
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
from __future__ import annotations
|
| 8 |
-
|
| 9 |
-
import argparse
|
| 10 |
-
import logging
|
| 11 |
-
from pathlib import Path
|
| 12 |
-
|
| 13 |
-
from huggingface_hub import HfApi, create_repo
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
ROOT = Path(__file__).resolve().parent
|
| 17 |
-
|
| 18 |
-
IGNORE_PATTERNS = [
|
| 19 |
-
".git/**",
|
| 20 |
-
"**/__pycache__/**",
|
| 21 |
-
"**/.pytest_cache/**",
|
| 22 |
-
"**/.cache/**",
|
| 23 |
-
"**/*.pyc",
|
| 24 |
-
"**/*.pyo",
|
| 25 |
-
"upload.log",
|
| 26 |
-
]
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def parse_args() -> argparse.Namespace:
|
| 30 |
-
parser = argparse.ArgumentParser(description="Upload SimToken to HuggingFace Hub.")
|
| 31 |
-
parser.add_argument("--repo", required=True, help="Repo id, e.g. yfan07/SimToken")
|
| 32 |
-
parser.add_argument("--repo_type", default="model", choices=["model", "dataset", "space"])
|
| 33 |
-
parser.add_argument("--private", action="store_true", help="Create repo as private if missing.")
|
| 34 |
-
parser.add_argument("--num_workers", type=int, default=4)
|
| 35 |
-
return parser.parse_args()
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def main() -> None:
|
| 39 |
-
args = parse_args()
|
| 40 |
-
logging.basicConfig(
|
| 41 |
-
level=logging.INFO,
|
| 42 |
-
format="%(asctime)s %(levelname)s %(message)s",
|
| 43 |
-
handlers=[logging.FileHandler(ROOT / "upload.log"), logging.StreamHandler()],
|
| 44 |
-
)
|
| 45 |
-
|
| 46 |
-
create_repo(
|
| 47 |
-
repo_id=args.repo,
|
| 48 |
-
repo_type=args.repo_type,
|
| 49 |
-
private=args.private,
|
| 50 |
-
exist_ok=True,
|
| 51 |
-
)
|
| 52 |
-
|
| 53 |
-
api = HfApi()
|
| 54 |
-
if hasattr(api, "upload_large_folder"):
|
| 55 |
-
logging.info("Uploading %s to %s with upload_large_folder", ROOT, args.repo)
|
| 56 |
-
api.upload_large_folder(
|
| 57 |
-
repo_id=args.repo,
|
| 58 |
-
repo_type=args.repo_type,
|
| 59 |
-
folder_path=str(ROOT),
|
| 60 |
-
ignore_patterns=IGNORE_PATTERNS,
|
| 61 |
-
num_workers=args.num_workers,
|
| 62 |
-
)
|
| 63 |
-
else:
|
| 64 |
-
logging.info("Uploading %s to %s with upload_folder", ROOT, args.repo)
|
| 65 |
-
api.upload_folder(
|
| 66 |
-
repo_id=args.repo,
|
| 67 |
-
repo_type=args.repo_type,
|
| 68 |
-
folder_path=str(ROOT),
|
| 69 |
-
ignore_patterns=IGNORE_PATTERNS,
|
| 70 |
-
)
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
if __name__ == "__main__":
|
| 74 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|