yfan07 commited on
Commit
a88521e
·
verified ·
1 Parent(s): f3bc639

Add files using upload-large-folder tool

Browse files
.huggingfaceignore CHANGED
@@ -1,4 +1,8 @@
1
  __pycache__/
2
  **/__pycache__/
3
  *.pyc
 
4
  .git/
 
 
 
 
1
  __pycache__/
2
  **/__pycache__/
3
  *.pyc
4
+ *.pyo
5
  .git/
6
+ **/.pytest_cache/
7
+ **/.cache/
8
+ upload.log
Residual_Prompt_Bridge.md ADDED
@@ -0,0 +1,501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Residual Prompt Bridge 论文导向实验路线图
2
+
3
+ ## 1. 当前主 claim
4
+
5
+ 论文主 claim 现在正式锁定为:
6
+
7
+ > **We propose an image-conditioned directional prompt correction module that orthogonalizes prompt updates to steer language-side prompts toward a more decodable SAM prompt manifold, mitigating cross-distribution prompt interface mismatch.**
8
+
9
+ 对应中文表述:
10
+
11
+ > **我们提出一种图像条件的方向型 prompt correction,通过正交化更新把语言侧 prompt 朝更可解码的 SAM prompt manifold 偏转,从而缓解跨分布的 prompt 接口失配。**
12
+
13
+ 从现在开始,所有实验都只服务这句 claim,不再让方法故事扩散成“大而全系统”。
14
+
15
+ ---
16
+
17
+ ## 2. 当前项目定位
18
+
19
+ 当前 RPB 项目已经完成了最关键的早期筛查:
20
+
21
+ 1. **实现正确性通过**
22
+ - checkpoint / LoRA 兼容问题已修复
23
+ - bridge 路径不会自动破坏 baseline
24
+ - identity-preserving sanity check 已通过
25
+
26
+ 2. **几何机制方向明确**
27
+ - additive residual 不足以推动 `p_hat` 离开 `q`
28
+ - directional bridge 明显优于 additive
29
+ - orthogonalization 能把 residual 预算从径向缩放转成方向修正
30
+
31
+ 3. **当前最小核心已浮现**
32
+ - `image-conditioned`
33
+ - `p_mask-only`
34
+ - `directional`
35
+ - `orthogonal`
36
+ - `single-token correction`
37
+
38
+ 4. **mixed 的角色目前仍未定型**
39
+ - weak mixed 不会抹掉 bridge
40
+ - 但目前更像 enhancer / compatibility probe,而不是稳定的 decoder-facing calibration mechanism
41
+
42
+ 因此,当前最重要的不是继续加模块,而是把这个**最小有效核心**做成稳定、可复现、可投稿的方法骨架。
43
+
44
+ ---
45
+
46
+ ## 3. 两套判据:Mechanism Pass vs Paper Pass
47
+
48
+ ### 3.1 Mechanism pass
49
+
50
+ 回答的问题是:
51
+
52
+ > 这个方法设计是否真的抓住了问题本质?
53
+
54
+ 当前 mechanism pass 需要被下面这些证据支撑:
55
+
56
+ - additive vs directional:directional 明显更能让 `p_hat` 离开 identity
57
+ - without orthogonal vs with orthogonal:orthogonalization 明显改善 `Δp` 的几何利用效率
58
+ - `Δp` 稳定朝 `p_mask`
59
+ - `p_hat` 能明显离开 `q`
60
+ - seen/unseen 的 alignment ratio 健康
61
+ - weak mixed 不会直接把 bridge 拉回 baseline
62
+
63
+ ### 3.2 Paper pass
64
+
65
+ 回答的问题是:
66
+
67
+ > 这个方法是否已经强到能单独撑起一篇顶会方法论文?
68
+
69
+ paper pass 需要下面这些更强条件:
70
+
71
+ - 更大规模评估上有稳定、同向的 headline 趋势
72
+ - 至少在 unseen 上有清晰、可复现的优势
73
+ - seen / null 的代价可接受
74
+ - 2 个随机种子下趋势稳定
75
+ - 最小闭环 ablation 完整
76
+
77
+ 当前状态:
78
+
79
+ - **mechanism pass:接近通过,但还缺更大规模验证和关键 baseline**
80
+ - **paper pass:尚未通过**
81
+
82
+ 后续每组实验都要明确写清楚:它是在推进 mechanism pass,还是在推进 paper pass。
83
+
84
+ ---
85
+
86
+ ## 4. 冻结最小核心方法
87
+
88
+ 在 pure RPB standalone 路线中,当前只保留下列组成:
89
+
90
+ - `image-conditioned correction`
91
+ - `p_mask-only teacher`
92
+ - `directional bridge`
93
+ - `orthogonalized update`
94
+ - `single-token prompt correction`
95
+
96
+ 当前明确**不进入主线**的内容:
97
+
98
+ - `z_gt` 作为主 teacher
99
+ - calibrator
100
+ - refinement
101
+ - 多 token bridge
102
+ - 大而全的完整 bridge 系统
103
+
104
+ 这些内容后续最多作为 ablation、扩展或 hybrid 组件,而不是当前主方法本体。
105
+
106
+ ---
107
+
108
+ ## 5. 当前实验事实总结
109
+
110
+ ### 5.1 已确认的正结果
111
+
112
+ - bridge 可以安全接入,不会自动毁掉 baseline
113
+ - 修复 checkpoint / LoRA 后,RPB 路径与 baseline 基本等价
114
+ - `directional + orthogonal` 后:
115
+ - `Δp` 高度对齐 `p_mask`
116
+ - `Δp` 不再主要沿 `q` 的平行方向浪费预算
117
+ - `p_hat` 能够明显离开 identity 区
118
+ - `p_mask-only teacher-only` 已在 quick eval 上给出:
119
+ - seen 小幅回落但可控
120
+ - unseen 轻微正信号
121
+ - null 基本持平
122
+
123
+ ### 5.2 已确认的负结果
124
+
125
+ - additive residual 不足以真正旋转 prompt
126
+ - `L_mask` 不是早期主矛盾
127
+ - `z_gt` 目前不是 sparse bridge 的主 teacher
128
+ - weak mixed 目前不能稳定把 seen 拉回 baseline
129
+
130
+ ### 5.3 当前最重要的工作假设
131
+
132
+ > `p_mask-only + image-conditioned + directional + orthogonal` 已经抓住主问题,但还需要找到更稳定的 operating point,并证明其 headline 趋势不是噪声。
133
+
134
+ ### 5.4 Fixed dev 阶段 A 当前记录
135
+
136
+ 固定 dev 子集:
137
+
138
+ - `test_s`: 200 samples
139
+ - `test_u`: 200 samples
140
+ - `test_n`: 200 samples
141
+ - manifest: `/workspace/SimToken/dev_subsets_rpb_v1.json`
142
+
143
+ #### Fixed dev baseline
144
+
145
+ | Setting | Seen mIoU | Seen F | Unseen mIoU | Unseen F | Null |
146
+ |---|---:|---:|---:|---:|---:|
147
+ | baseline | 0.72554 | 0.81811 | 0.68531 | 0.77238 | 0.01452 |
148
+
149
+ #### Teacher-only alpha search
150
+
151
+ | Setting | Seen mIoU | Seen F | Unseen mIoU | Unseen F | Null | Seen cos(p_hat,p_mask) | Unseen cos(p_hat,p_mask) | 机制判断 |
152
+ |---|---:|---:|---:|---:|---:|---:|---:|---|
153
+ | image, alpha=0.20 | 0.72517 | 0.81376 | 0.68596 | 0.77730 | 0.01426 | 0.09502 | 0.06611 | 机制最强,Seen/F 有代价 |
154
+ | image, alpha=0.18 | 0.72692 | 0.81705 | 0.68595 | 0.77354 | 0.01448 | 0.02873 | 0.00605 | 性能平衡较好,机制偏弱 |
155
+ | image, alpha=0.15 | 0.72669 | 0.81725 | 0.68569 | 0.77330 | 0.01448 | 0.02373 | 0.00282 | 更接近 identity |
156
+ | image, alpha=0.12 | 0.72651 | 0.81748 | 0.68578 | 0.77314 | 0.01449 | 0.01871 | -0.00046 | 轻扰动区,机制最弱 |
157
+
158
+ 阶段 A 的 teacher-only 结论:
159
+
160
+ - `alpha=0.20` 是机制候选点,能明显改变 prompt geometry。
161
+ - `alpha=0.18` 是性能平衡候选点,seen / unseen / null 都更稳。
162
+ - `alpha=0.12/0.15` 已经过于接近 identity,不适合作为机制主证据。
163
+
164
+ #### Weak mixed 局部验证
165
+
166
+ | Setting | Seen mIoU | Seen F | Unseen mIoU | Unseen F | Null | Seen cos(p_hat,p_mask) | Unseen cos(p_hat,p_mask) | 角色判断 |
167
+ |---|---:|---:|---:|---:|---:|---:|---:|---|
168
+ | image, alpha=0.18, weak mixed | 0.72704 | 0.81554 | 0.68706 | 0.77454 | 0.01451 | 0.04079 | 0.01325 | 当前最佳性能平衡候选 |
169
+ | image, alpha=0.15, weak mixed | 0.72684 | 0.81607 | 0.68674 | 0.77419 | 0.01451 | 0.03382 | 0.00882 | 稳定但略弱于 alpha=0.18 mixed |
170
+
171
+ weak mixed 当前结论:
172
+
173
+ - weak mixed 没有把 bridge 拉回 identity。
174
+ - weak mixed 对 `alpha=0.15/0.18` 都更像 mild enhancement,而不是 destructive pullback。
175
+ - `alpha=0.18 + weak mixed` 是当前 fixed dev 的最佳 operating point。
176
+
177
+ #### q-only directional baseline
178
+
179
+ | Setting | Seen mIoU | Seen F | Unseen mIoU | Unseen F | Null | Seen cos(p_hat,p_mask) | Unseen cos(p_hat,p_mask) | 判断 |
180
+ |---|---:|---:|---:|---:|---:|---:|---:|---|
181
+ | q-only, alpha=0.18 | 0.72311 | 0.81206 | 0.68289 | 0.77666 | 0.01424 | 0.12061 | 0.09598 | alignment 更强但 mIoU 更差 |
182
+
183
+ q-only 结论:
184
+
185
+ - directional / orthogonal 机制本身很强,q-only 也能大幅拉高 teacher alignment。
186
+ - q-only 的 prompt steering 更激进,`gate_mean` 更高,`delta_norm` 更大。
187
+ - q-only mIoU 在 seen / unseen 上都低于 image-conditioned candidate。
188
+ - 当前证据支持:image conditioning 的价值不是单纯提高 teacher cosine,而是约束方向修正,使 prompt steering 与 decoder compatibility 之间的平衡更好。
189
+
190
+ #### 阶段 A 当前候选
191
+
192
+ 当前 fixed dev 最佳候选:
193
+
194
+ > **image-conditioned + p_mask-only + directional + orthogonal + alpha=0.18 + weak mixed**
195
+
196
+ 对应 checkpoint:
197
+
198
+ > `/workspace/SimToken/checkpoints/rpb_dev_mixed_pm_only_a018_wm005.pth`
199
+
200
+ ---
201
+
202
+ ## 6. 实验纪律:停止在 test 上自由调方向
203
+
204
+ 从下一阶段开始,必须冻结一套 **dev tuning subset**,不再继续在 `test_s/test_u/test_n` 上自由调 alpha 和 mixed 设定。
205
+
206
+ 建议立即固定:
207
+
208
+ - `dev_seen`
209
+ - `dev_unseen`
210
+ - `dev_null`
211
+
212
+ 每个 split 可先取 `100` 或 `200` 个样本,后续:
213
+
214
+ - alpha 选择
215
+ - mixed 选择
216
+ - warm-start 配置
217
+ - early stopping
218
+
219
+ 全部只在 dev 上完成。
220
+ 真正的 test split 只用于后续一次性确认和最终表格。
221
+
222
+ ---
223
+
224
+ ## 7. 三阶段推进路线
225
+
226
+ ## 阶段 A:锁最小核心的 operating point
227
+
228
+ ### 目标
229
+
230
+ 回答:
231
+
232
+ > 当前最小核心是否能在更大 quick eval 上形成稳定、可接受的性能-几何平衡?
233
+
234
+ ### 本阶段只做两类实验
235
+
236
+ #### A1. teacher-only operating point 搜索
237
+
238
+ 固定:
239
+
240
+ - image-conditioned
241
+ - `p_mask-only`
242
+ - directional
243
+ - orthogonal
244
+ - single-token
245
+ - 不加 `z_gt`
246
+ - 不加 calibrator
247
+ - 不加 refinement
248
+
249
+ 重点只扫:
250
+
251
+ - `alpha = 0.12, 0.15, 0.18, 0.20`
252
+
253
+ 当前判断是:`0.20` 已经是 promising pass,因此没有必要继续向更大 alpha 发散。
254
+
255
+ #### A2. weak mixed 局部验证
256
+
257
+ 只围绕最佳 teacher-only checkpoint 做 warm-start,不做大 sweep。
258
+
259
+ 建议只测:
260
+
261
+ - `best_alpha`
262
+ - `best_alpha - 0.03`
263
+
264
+ 以及很弱的 mask 强度两档:
265
+
266
+ - `λ_mask = 0.05`
267
+ - `λ_mask = 0.10`
268
+
269
+ mixed 的目标不是涨分,而是判断它的角色到底是:
270
+
271
+ - calibration
272
+ - enhancement
273
+ - 还是 destructive pullback
274
+
275
+ ### 阶段 A 重点指标
276
+
277
+ 几何指标:
278
+
279
+ - `cos(p_hat, p_mask)_seen`
280
+ - `cos(p_hat, p_mask)_unseen`
281
+ - `cos(p_hat, q)`
282
+ - `cos(Δp, p_mask)`
283
+ - `cos(Δp, q)`
284
+ - `align_ratio = cos_u / cos_s`
285
+
286
+ 性能指标:
287
+
288
+ - `mIoU_seen`
289
+ - `mIoU_unseen`
290
+ - `Fscore_seen`
291
+ - `Fscore_unseen`
292
+ - `Null metric`
293
+
294
+ ### 阶段 A 的通过标准
295
+
296
+ 若在 dev 或更大 quick eval 上,能找到一个稳定点满足:
297
+
298
+ - unseen 稳定不差于 baseline,最好有小幅提升
299
+ - seen 代价可控
300
+ - null 基本持平或代价可接受
301
+ - `cos(p_hat, p_mask)` 明显离开 identity 区
302
+ - seen/unseen 的 alignment ratio 健康
303
+
304
+ 则阶段 A 通过。
305
+
306
+ ### 阶段 A 的停止条件
307
+
308
+ 若完成:
309
+
310
+ 1. alpha 局部搜索
311
+ 2. weak mixed 局部搜索
312
+ 3. 100 / 200 样本 quick eval
313
+
314
+ 之后仍出现任一情况,则停止 pure RPB standalone 主线:
315
+
316
+ - 在更大 quick eval 上没有稳定、同向的 unseen 优势
317
+ - seen/unseen tradeoff 对 alpha 高度敏感
318
+ - null 代价无法压到 baseline 附近
319
+ - mixed 始终只是增强器,而不是 decoder-facing calibration
320
+
321
+ ---
322
+
323
+ ## 阶段 B:做最小闭环 ablation
324
+
325
+ 只有阶段 A 通过后,才进入阶段 B。
326
+
327
+ ### 目标
328
+
329
+ 把方法主骨架讲圆,形成 mechanism pass 的闭环证据。
330
+
331
+ ### 必做的 4 个关键 ablation
332
+
333
+ 1. **additive vs directional**
334
+ 2. **directional without orthogonalization vs with orthogonalization**
335
+ 3. **q-only directional vs image-conditioned directional**
336
+ 4. **`p_mask-only` vs `p_mask + weak z_gt`**
337
+
338
+ 这 4 个已经足够支撑方法论证,不再继续扩更多 trick ablation。
339
+
340
+ ### 阶段 B 的补充要求
341
+
342
+ - 至少 2 个随机种子重复
343
+ - 至少一次更大规模验证
344
+ - 建立 geometry-performance coupling:
345
+ - prompt geometry 改写程度
346
+ - 与 seen/unseen 表现之间的关系
347
+ - 与 identity 回缩之间的关系
348
+
349
+ ### 阶段 B 的停止条件
350
+
351
+ 若完成:
352
+
353
+ 1. alpha 局部搜索
354
+ 2. weak mixed 局部搜索
355
+ 3. 100 / 200 样本 quick eval
356
+ 4. 至少一次更大规模验证
357
+ 5. 2 个随机种子重复
358
+
359
+ 后仍满足以下任一条,则停止 pure RPB standalone:
360
+
361
+ - 大子集 / full-split 上没有稳定、同向的 unseen 优势
362
+ - 最优点高度依赖 seed 或 alpha,趋势不稳定
363
+ - null 代价无法控制
364
+ - mixed 无法形成稳定 calibration 作用
365
+ - headline result 仍然只有极弱波动
366
+
367
+ ---
368
+
369
+ ## 阶段 C:决定论文定位
370
+
371
+ ### 路线 1:pure RPB standalone
372
+
373
+ 如果满足:
374
+
375
+ - 更大评估上有稳定 unseen gain
376
+ - seen / null 代价可接受
377
+ - 2 seeds 稳定
378
+ - 最小闭环 ablation 完整
379
+
380
+ 则走:
381
+
382
+ > **pure RPB 方法论文**
383
+
384
+ ### 路线 2:RPB + TTO hybrid
385
+
386
+ 如果出现:
387
+
388
+ - mechanism 成立
389
+ - 但 paper pass 不够硬
390
+ - headline result 仍然偏弱或不稳定
391
+
392
+ 则立刻切换定位:
393
+
394
+ > **RPB + TTO hybrid 方法论文**
395
+
396
+ 此时 RPB 的角色不再是 standalone 主方法,而是:
397
+
398
+ - amortized prompt corrector
399
+ - 改善 test-time refinement 起点质量的前端模块
400
+
401
+ ---
402
+
403
+ ## 8. Hybrid 路线作为明确 Plan B
404
+
405
+ 若 pure RPB 最终只能做到:
406
+
407
+ - unseen 稳定小涨
408
+ - seen 小掉
409
+ - null 持平或略好
410
+
411
+ 那么 standalone 顶会会比较吃力。
412
+ 但此时 RPB 作为前端 prompt corrector 仍很有价值:
413
+
414
+ - 改善初始 `q` 的几何
415
+ - 为 q-LTPO / selective refinement 提供更好的初始化
416
+ - 降低 test-time optimization 的步数和不稳定性
417
+
418
+ hybrid 的论文叙事可以明确写成:
419
+
420
+ 1. train-time:amortized interface correction
421
+ 2. test-time:instance-specific prompt refinement
422
+ 3. 两者结合:同时解决全局接口失配与样本级细化问题
423
+
424
+ 当前判断:hybrid 是非常强的 Plan B,而不是临时补救路线。
425
+
426
+ ---
427
+
428
+ ## 9. 负结果如何写进论文论证链条
429
+
430
+ 当前已经得到了一条清晰的“设计收敛链条”,后续可以直接转写为论文方法论证:
431
+
432
+ ### 为什么不是 additive residual
433
+
434
+ 因为 additive 下:
435
+
436
+ - `Δp` 主要对抗 `q` 的平行分量
437
+ - teacher 方向被大范数 `q` 吞掉
438
+ - 结果更像缩放,而不是旋转
439
+
440
+ ### 为什么要 directional
441
+
442
+ 因为 directional 才能把修正显式变成 prompt 方向控制,而不是数值扰动。
443
+
444
+ ### 为什么要 orthogonal
445
+
446
+ 因为 orthogonalization 才能避免 residual 预算浪费在径向缩放上。
447
+
448
+ ### 为什么当前只保留 `p_mask`
449
+
450
+ 因为当前 sparse bridge 里,`p_mask` 一直是主 teacher,`z_gt` 尚未成为主信号。
451
+
452
+ ### 为什么 mixed 不是主模块
453
+
454
+ 因为 mixed 目前更像 compatibility / enhancement probe,而不是稳定的 calibration mechanism。
455
+
456
+ 这条链条必须在文中明确写出,让 reviewer 看到方法是沿诊断逐步收敛的,而不是盲目堆模块。
457
+
458
+ ---
459
+
460
+ ## 10. 当前最直接的执行建议
461
+
462
+ 接下来不要发散,严格按下面顺序走:
463
+
464
+ 1. **立刻冻结论文主 claim**
465
+ 2. **立刻切换到固定 dev 子集,不再自由用 test 调方向**
466
+ 3. **完成阶段 A:最小核心 operating point 搜索**
467
+ 4. **补关键 baseline:q-only directional**
468
+ 5. **做两种 seed**
469
+ 6. **然后做 pure RPB standalone 的去留决策**
470
+
471
+ 当前最重要的执行原则是:
472
+
473
+ > **先证明最小核心能稳定成立;如果 headline 不够硬,就及时把它升级成 hybrid 前端,而不是继续把 pure RPB 做复杂。**
474
+
475
+ ---
476
+
477
+ ## 11. 当前阶段的明确结论
478
+
479
+ ### 当前方向值得继续吗?
480
+
481
+ **值得。**
482
+
483
+ ### 现在最应该做什么?
484
+
485
+ 不是继续扩模块,而是:
486
+
487
+ - 找到 teacher-only `p_mask-only directional orthogonal` 的最佳 operating point
488
+ - 用 very weak mixed 判断 mixed 是否能形成 calibration
489
+ - 在 dev 和更大 quick eval 上证明趋势不是噪声
490
+
491
+ ### 什么时候该停 pure RPB?
492
+
493
+ 只要阶段 A + B 完成后,headline 仍然弱且不稳定,就停止 pure RPB standalone。
494
+
495
+ ### 停了之后怎么办?
496
+
497
+ 直接转:
498
+
499
+ > **RPB + TTO hybrid**
500
+
501
+ 这条路线当前是明确的 Plan B,而且很可能是更强的顶会方法论文路径。
build_rpb_dev_manifest.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import random
5
+
6
+ import pandas as pd
7
+
8
+
9
+ def sample_indices(size, count, seed):
10
+ if count <= 0:
11
+ return []
12
+ if count > size:
13
+ raise ValueError(f"Requested {count} samples from a split of size {size}")
14
+ rng = random.Random(seed)
15
+ indices = list(range(size))
16
+ rng.shuffle(indices)
17
+ selected = sorted(indices[:count])
18
+ return selected
19
+
20
+
21
+ def main():
22
+ parser = argparse.ArgumentParser(description="Build a fixed subset manifest for RPB dev experiments.")
23
+ parser.add_argument("--metadata", type=str, default="/workspace/SimToken/data/metadata.csv")
24
+ parser.add_argument("--output", type=str, required=True)
25
+ parser.add_argument("--seed", type=int, default=42)
26
+ parser.add_argument("--train_rows", type=int, default=0)
27
+ parser.add_argument("--test_s_rows", type=int, default=200)
28
+ parser.add_argument("--test_u_rows", type=int, default=200)
29
+ parser.add_argument("--test_n_rows", type=int, default=200)
30
+ args = parser.parse_args()
31
+
32
+ metadata = pd.read_csv(args.metadata, header=0)
33
+ split_sizes = {
34
+ "train": int((metadata["split"] == "train").sum()),
35
+ "test_s": int((metadata["split"] == "test_s").sum()),
36
+ "test_u": int((metadata["split"] == "test_u").sum()),
37
+ "test_n": int((metadata["split"] == "test_n").sum()),
38
+ }
39
+
40
+ manifest = {
41
+ "train": sample_indices(split_sizes["train"], args.train_rows, args.seed),
42
+ "test_s": sample_indices(split_sizes["test_s"], args.test_s_rows, args.seed + 1),
43
+ "test_u": sample_indices(split_sizes["test_u"], args.test_u_rows, args.seed + 2),
44
+ "test_n": sample_indices(split_sizes["test_n"], args.test_n_rows, args.seed + 3),
45
+ }
46
+
47
+ # Remove empty entries so train.py only subsets the splits we intentionally fix.
48
+ manifest = {key: value for key, value in manifest.items() if value}
49
+
50
+ os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
51
+ with open(args.output, "w", encoding="utf-8") as f:
52
+ json.dump(
53
+ {
54
+ "metadata": {
55
+ "seed": args.seed,
56
+ "split_sizes": split_sizes,
57
+ "source_metadata": os.path.abspath(args.metadata),
58
+ },
59
+ "subsets": manifest,
60
+ },
61
+ f,
62
+ indent=2,
63
+ )
64
+
65
+ print(f"saved subset manifest to {args.output}")
66
+ for split_name, indices in manifest.items():
67
+ print(f"{split_name}: {len(indices)} samples")
68
+
69
+
70
+ if __name__ == "__main__":
71
+ main()
dev_subsets_rpb_v1.json ADDED
@@ -0,0 +1,620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "seed": 42,
4
+ "split_sizes": {
5
+ "train": 14113,
6
+ "test_s": 2288,
7
+ "test_u": 1656,
8
+ "test_n": 1028
9
+ },
10
+ "source_metadata": "/workspace/SimToken/data/metadata.csv"
11
+ },
12
+ "subsets": {
13
+ "test_s": [
14
+ 6,
15
+ 16,
16
+ 36,
17
+ 71,
18
+ 74,
19
+ 88,
20
+ 108,
21
+ 114,
22
+ 116,
23
+ 122,
24
+ 126,
25
+ 128,
26
+ 134,
27
+ 138,
28
+ 139,
29
+ 146,
30
+ 152,
31
+ 159,
32
+ 177,
33
+ 196,
34
+ 217,
35
+ 219,
36
+ 249,
37
+ 256,
38
+ 268,
39
+ 276,
40
+ 279,
41
+ 286,
42
+ 287,
43
+ 297,
44
+ 298,
45
+ 299,
46
+ 312,
47
+ 313,
48
+ 324,
49
+ 331,
50
+ 332,
51
+ 347,
52
+ 378,
53
+ 383,
54
+ 402,
55
+ 410,
56
+ 412,
57
+ 420,
58
+ 451,
59
+ 452,
60
+ 458,
61
+ 467,
62
+ 477,
63
+ 484,
64
+ 486,
65
+ 497,
66
+ 499,
67
+ 512,
68
+ 526,
69
+ 533,
70
+ 543,
71
+ 550,
72
+ 551,
73
+ 567,
74
+ 574,
75
+ 576,
76
+ 581,
77
+ 594,
78
+ 596,
79
+ 608,
80
+ 616,
81
+ 625,
82
+ 627,
83
+ 642,
84
+ 646,
85
+ 663,
86
+ 692,
87
+ 700,
88
+ 704,
89
+ 724,
90
+ 745,
91
+ 754,
92
+ 795,
93
+ 815,
94
+ 819,
95
+ 831,
96
+ 843,
97
+ 854,
98
+ 867,
99
+ 895,
100
+ 946,
101
+ 953,
102
+ 965,
103
+ 975,
104
+ 979,
105
+ 989,
106
+ 1004,
107
+ 1007,
108
+ 1008,
109
+ 1010,
110
+ 1023,
111
+ 1039,
112
+ 1051,
113
+ 1052,
114
+ 1072,
115
+ 1075,
116
+ 1080,
117
+ 1088,
118
+ 1099,
119
+ 1101,
120
+ 1104,
121
+ 1106,
122
+ 1134,
123
+ 1138,
124
+ 1169,
125
+ 1180,
126
+ 1201,
127
+ 1205,
128
+ 1221,
129
+ 1230,
130
+ 1247,
131
+ 1258,
132
+ 1272,
133
+ 1279,
134
+ 1284,
135
+ 1294,
136
+ 1297,
137
+ 1312,
138
+ 1329,
139
+ 1339,
140
+ 1343,
141
+ 1367,
142
+ 1379,
143
+ 1406,
144
+ 1417,
145
+ 1461,
146
+ 1462,
147
+ 1468,
148
+ 1473,
149
+ 1474,
150
+ 1489,
151
+ 1493,
152
+ 1500,
153
+ 1510,
154
+ 1517,
155
+ 1552,
156
+ 1556,
157
+ 1557,
158
+ 1589,
159
+ 1609,
160
+ 1612,
161
+ 1618,
162
+ 1622,
163
+ 1624,
164
+ 1644,
165
+ 1647,
166
+ 1665,
167
+ 1669,
168
+ 1676,
169
+ 1682,
170
+ 1683,
171
+ 1691,
172
+ 1700,
173
+ 1726,
174
+ 1746,
175
+ 1748,
176
+ 1758,
177
+ 1764,
178
+ 1765,
179
+ 1778,
180
+ 1785,
181
+ 1786,
182
+ 1808,
183
+ 1826,
184
+ 1852,
185
+ 1861,
186
+ 1883,
187
+ 1891,
188
+ 1916,
189
+ 1938,
190
+ 1944,
191
+ 1967,
192
+ 1971,
193
+ 1980,
194
+ 1986,
195
+ 2034,
196
+ 2044,
197
+ 2067,
198
+ 2074,
199
+ 2082,
200
+ 2085,
201
+ 2118,
202
+ 2128,
203
+ 2156,
204
+ 2176,
205
+ 2182,
206
+ 2185,
207
+ 2188,
208
+ 2194,
209
+ 2206,
210
+ 2211,
211
+ 2215,
212
+ 2247,
213
+ 2256
214
+ ],
215
+ "test_u": [
216
+ 4,
217
+ 16,
218
+ 26,
219
+ 38,
220
+ 40,
221
+ 48,
222
+ 50,
223
+ 65,
224
+ 83,
225
+ 92,
226
+ 102,
227
+ 117,
228
+ 120,
229
+ 135,
230
+ 144,
231
+ 153,
232
+ 155,
233
+ 185,
234
+ 200,
235
+ 201,
236
+ 211,
237
+ 219,
238
+ 221,
239
+ 226,
240
+ 227,
241
+ 240,
242
+ 245,
243
+ 251,
244
+ 252,
245
+ 255,
246
+ 267,
247
+ 272,
248
+ 274,
249
+ 276,
250
+ 278,
251
+ 282,
252
+ 284,
253
+ 286,
254
+ 303,
255
+ 309,
256
+ 313,
257
+ 328,
258
+ 345,
259
+ 348,
260
+ 358,
261
+ 363,
262
+ 374,
263
+ 376,
264
+ 379,
265
+ 383,
266
+ 385,
267
+ 387,
268
+ 393,
269
+ 396,
270
+ 400,
271
+ 412,
272
+ 417,
273
+ 428,
274
+ 434,
275
+ 452,
276
+ 453,
277
+ 456,
278
+ 459,
279
+ 463,
280
+ 473,
281
+ 490,
282
+ 493,
283
+ 504,
284
+ 517,
285
+ 525,
286
+ 535,
287
+ 543,
288
+ 544,
289
+ 545,
290
+ 549,
291
+ 550,
292
+ 565,
293
+ 584,
294
+ 585,
295
+ 594,
296
+ 602,
297
+ 603,
298
+ 606,
299
+ 638,
300
+ 642,
301
+ 643,
302
+ 651,
303
+ 684,
304
+ 687,
305
+ 692,
306
+ 700,
307
+ 721,
308
+ 728,
309
+ 752,
310
+ 757,
311
+ 779,
312
+ 783,
313
+ 785,
314
+ 794,
315
+ 803,
316
+ 807,
317
+ 814,
318
+ 847,
319
+ 849,
320
+ 853,
321
+ 854,
322
+ 861,
323
+ 867,
324
+ 884,
325
+ 900,
326
+ 903,
327
+ 906,
328
+ 924,
329
+ 930,
330
+ 931,
331
+ 941,
332
+ 948,
333
+ 957,
334
+ 968,
335
+ 972,
336
+ 980,
337
+ 987,
338
+ 995,
339
+ 996,
340
+ 1007,
341
+ 1009,
342
+ 1028,
343
+ 1033,
344
+ 1034,
345
+ 1040,
346
+ 1054,
347
+ 1098,
348
+ 1104,
349
+ 1111,
350
+ 1121,
351
+ 1126,
352
+ 1134,
353
+ 1155,
354
+ 1161,
355
+ 1167,
356
+ 1180,
357
+ 1186,
358
+ 1192,
359
+ 1212,
360
+ 1214,
361
+ 1219,
362
+ 1226,
363
+ 1254,
364
+ 1256,
365
+ 1259,
366
+ 1261,
367
+ 1270,
368
+ 1278,
369
+ 1285,
370
+ 1288,
371
+ 1290,
372
+ 1305,
373
+ 1310,
374
+ 1323,
375
+ 1325,
376
+ 1343,
377
+ 1360,
378
+ 1375,
379
+ 1376,
380
+ 1404,
381
+ 1411,
382
+ 1426,
383
+ 1429,
384
+ 1442,
385
+ 1449,
386
+ 1452,
387
+ 1456,
388
+ 1475,
389
+ 1478,
390
+ 1479,
391
+ 1484,
392
+ 1493,
393
+ 1499,
394
+ 1500,
395
+ 1501,
396
+ 1506,
397
+ 1517,
398
+ 1523,
399
+ 1528,
400
+ 1536,
401
+ 1545,
402
+ 1546,
403
+ 1550,
404
+ 1561,
405
+ 1570,
406
+ 1598,
407
+ 1609,
408
+ 1611,
409
+ 1625,
410
+ 1632,
411
+ 1634,
412
+ 1635,
413
+ 1641,
414
+ 1654,
415
+ 1655
416
+ ],
417
+ "test_n": [
418
+ 4,
419
+ 5,
420
+ 9,
421
+ 16,
422
+ 20,
423
+ 25,
424
+ 27,
425
+ 33,
426
+ 37,
427
+ 40,
428
+ 45,
429
+ 46,
430
+ 48,
431
+ 53,
432
+ 56,
433
+ 60,
434
+ 62,
435
+ 67,
436
+ 77,
437
+ 78,
438
+ 80,
439
+ 81,
440
+ 86,
441
+ 90,
442
+ 94,
443
+ 99,
444
+ 102,
445
+ 106,
446
+ 108,
447
+ 111,
448
+ 116,
449
+ 121,
450
+ 126,
451
+ 127,
452
+ 132,
453
+ 143,
454
+ 148,
455
+ 153,
456
+ 155,
457
+ 156,
458
+ 158,
459
+ 160,
460
+ 164,
461
+ 168,
462
+ 170,
463
+ 171,
464
+ 173,
465
+ 175,
466
+ 183,
467
+ 184,
468
+ 185,
469
+ 188,
470
+ 189,
471
+ 190,
472
+ 196,
473
+ 202,
474
+ 206,
475
+ 208,
476
+ 212,
477
+ 217,
478
+ 221,
479
+ 222,
480
+ 223,
481
+ 233,
482
+ 242,
483
+ 246,
484
+ 247,
485
+ 259,
486
+ 262,
487
+ 269,
488
+ 283,
489
+ 298,
490
+ 299,
491
+ 306,
492
+ 316,
493
+ 317,
494
+ 323,
495
+ 330,
496
+ 332,
497
+ 334,
498
+ 354,
499
+ 357,
500
+ 367,
501
+ 372,
502
+ 395,
503
+ 397,
504
+ 400,
505
+ 405,
506
+ 407,
507
+ 420,
508
+ 431,
509
+ 435,
510
+ 436,
511
+ 444,
512
+ 446,
513
+ 461,
514
+ 464,
515
+ 470,
516
+ 479,
517
+ 481,
518
+ 483,
519
+ 485,
520
+ 487,
521
+ 494,
522
+ 512,
523
+ 516,
524
+ 520,
525
+ 524,
526
+ 529,
527
+ 530,
528
+ 539,
529
+ 540,
530
+ 541,
531
+ 554,
532
+ 559,
533
+ 560,
534
+ 564,
535
+ 568,
536
+ 571,
537
+ 572,
538
+ 576,
539
+ 577,
540
+ 581,
541
+ 585,
542
+ 592,
543
+ 602,
544
+ 609,
545
+ 620,
546
+ 630,
547
+ 632,
548
+ 677,
549
+ 678,
550
+ 684,
551
+ 693,
552
+ 694,
553
+ 695,
554
+ 702,
555
+ 716,
556
+ 724,
557
+ 727,
558
+ 732,
559
+ 735,
560
+ 736,
561
+ 747,
562
+ 750,
563
+ 752,
564
+ 755,
565
+ 758,
566
+ 764,
567
+ 767,
568
+ 774,
569
+ 775,
570
+ 777,
571
+ 779,
572
+ 780,
573
+ 782,
574
+ 795,
575
+ 800,
576
+ 812,
577
+ 815,
578
+ 818,
579
+ 821,
580
+ 823,
581
+ 825,
582
+ 828,
583
+ 834,
584
+ 841,
585
+ 843,
586
+ 846,
587
+ 848,
588
+ 860,
589
+ 861,
590
+ 863,
591
+ 869,
592
+ 871,
593
+ 878,
594
+ 882,
595
+ 891,
596
+ 893,
597
+ 896,
598
+ 898,
599
+ 899,
600
+ 901,
601
+ 906,
602
+ 930,
603
+ 940,
604
+ 944,
605
+ 969,
606
+ 970,
607
+ 973,
608
+ 980,
609
+ 990,
610
+ 993,
611
+ 996,
612
+ 997,
613
+ 1007,
614
+ 1012,
615
+ 1013,
616
+ 1019,
617
+ 1025
618
+ ]
619
+ }
620
+ }
load_model.py CHANGED
@@ -208,7 +208,10 @@ def collate_fn(batch, tokenizer=None):
208
 
209
  import torch.multiprocessing as mp
210
  if __name__ == "__main__":
211
- mp.set_start_method("spawn")
 
 
 
212
  set_seed(42)
213
  tokenizer = transformers.AutoTokenizer.from_pretrained(
214
  args.mllm,
 
208
 
209
  import torch.multiprocessing as mp
210
  if __name__ == "__main__":
211
+ try:
212
+ mp.set_start_method("spawn")
213
+ except RuntimeError:
214
+ pass
215
  set_seed(42)
216
  tokenizer = transformers.AutoTokenizer.from_pretrained(
217
  args.mllm,
train.py CHANGED
@@ -22,6 +22,8 @@ import re
22
  import time
23
  import os
24
  import sys
 
 
25
 
26
 
27
  import warnings
@@ -214,10 +216,61 @@ def collate_fn(batch, tokenizer=None):
214
  }
215
 
216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  import torch.multiprocessing as mp
218
  if __name__ == "__main__":
219
- mp.set_start_method("spawn")
 
 
 
220
  set_seed(42)
 
 
221
  tokenizer = transformers.AutoTokenizer.from_pretrained(
222
  args.mllm,
223
  cache_dir=None,
@@ -230,18 +283,27 @@ if __name__ == "__main__":
230
  num_added_tokens = tokenizer.add_tokens("[SEG]")
231
  seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0] # 32000
232
  print("seg_token_idx: ", seg_token_idx)
 
233
 
234
  train_dataset = REFAVS('train', args, tokenizer, input_type='refer')
235
  val_dataset_s_refer = REFAVS('test_s', args, tokenizer, input_type='refer')
236
  val_dataset_u_refer = REFAVS('test_u', args, tokenizer, input_type='refer')
237
  val_dataset_n_refer = REFAVS('test_n', args, tokenizer, input_type='refer')
238
 
 
 
 
 
 
239
  if args.overfit_samples > 0:
240
  overfit_n = min(args.overfit_samples, len(train_dataset))
241
  train_dataset = Subset(train_dataset, list(range(overfit_n)))
242
  print(f"overfit_samples enabled: using first {overfit_n} train samples")
243
 
244
- train_eval_dataset = train_dataset
 
 
 
245
 
246
 
247
  g = torch.Generator()
@@ -258,15 +320,25 @@ if __name__ == "__main__":
258
  model_args = {
259
  "train_mask_decoder": True,
260
  "out_dim": 256, # 256
261
- "ce_loss_weight": 1.0,
262
- "dice_loss_weight": 0.5,
263
- "bce_loss_weight": 2.0,
264
  "seg_token_idx": seg_token_idx,
265
  "vision_pretrained": args.vision_pretrained, # sam_vit_h_xxx.pth
266
  "vision_tower": args.vision_tower,
267
  "use_im_start_end": False,
268
  "compress": args.compress,
269
  "start": args.start,
 
 
 
 
 
 
 
 
 
 
270
  }
271
 
272
  model = Simtoken_ForCausalLM.from_pretrained(args.mllm, torch_dtype=torch.float32, low_cpu_mem_usage=True, **model_args)
@@ -302,7 +374,17 @@ if __name__ == "__main__":
302
  for p in model.get_model().mm_projector.parameters():
303
  p.requires_grad = False
304
 
305
- lora_r = 8
 
 
 
 
 
 
 
 
 
 
306
  target_modules = "q_proj,v_proj"
307
  if lora_r > 0:
308
 
@@ -370,17 +452,29 @@ if __name__ == "__main__":
370
  # for name, param in model.token_compressor.named_parameters():
371
  # param.requires_grad = True
372
 
373
-
374
  for n, p in model.named_parameters():
375
  if any(
376
- [
377
- x in n
378
- for x in ["lm_head", "embed_tokens", "mask_decoder", "text_hidden_fcs"]
379
- ]
380
  ):
381
  p.requires_grad = True
382
 
383
- if args.gate_only:
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  for p in model.parameters():
385
  p.requires_grad = False
386
  for n, p in model.named_parameters():
@@ -487,12 +581,145 @@ if __name__ == "__main__":
487
  with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
488
  f.write(message + "\n")
489
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
490
  def valuate(model, dataloader, args, name):
491
  model.eval()
492
 
493
  total_iou = 0
494
  total_fscore = 0
495
  count = 0
 
 
496
 
497
  for batch in tqdm(dataloader, desc=f"Evaluating on {name}"):
498
  input_dict = dict_to_cuda(batch)
@@ -513,7 +740,8 @@ if __name__ == "__main__":
513
  vids=input_dict["vids"],
514
  contrast=args.ct_weight,
515
  ref_ids=input_dict["ref_ids"],
516
- inference=True)
 
517
  pred_masks = output_dict["pred_masks"] # list[B]:[num_seg, T, H, W]
518
  gt_masks = output_dict["gt_masks"] # list[B]:[num_seg, T, H, W]
519
  for i in range(len(pred_masks)):
@@ -526,18 +754,35 @@ if __name__ == "__main__":
526
  total_fscore += fscore * num_seg * T
527
  count += num_seg * T
528
 
 
 
 
 
 
529
  print(f"\n valuate on {name}: miou: {total_iou/count} fscore: {total_fscore/count}")
530
 
531
  with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
532
  f.write(f"valuate on {name}: miou {total_iou/count} true fscore {total_fscore/count} \n")
 
 
 
 
 
 
 
533
 
534
 
 
 
 
 
535
  # ---------------train------------------------------------------
536
 
537
  model.train()
538
  epochs = args.epochs
539
  print("init lr:", args.lr)
540
- optimizer = AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)
 
541
  print_referent_gate_optimizer_sanity(model, optimizer)
542
 
543
  gradient_accumulation_steps = max(1, int(16 // args.batch_size))
@@ -613,7 +858,15 @@ if __name__ == "__main__":
613
  optimizer.zero_grad()
614
 
615
  current_lr = scheduler.get_lr()[0]
616
- loop.set_postfix(lr=current_lr, loss=running_loss / ((step + 1) / gradient_accumulation_steps))
 
 
 
 
 
 
 
 
617
 
618
  if args.max_steps > 0 and optimizer_step_count >= args.max_steps:
619
  stop_training = True
 
22
  import time
23
  import os
24
  import sys
25
+ import json
26
+ from collections import defaultdict
27
 
28
 
29
  import warnings
 
216
  }
217
 
218
 
219
+ def maybe_limit_dataset(dataset, max_rows, name):
220
+ if max_rows is None or max_rows <= 0:
221
+ return dataset
222
+ limited_n = min(max_rows, len(dataset))
223
+ print(f"max_eval_rows enabled: using first {limited_n} samples from {name}")
224
+ return Subset(dataset, list(range(limited_n)))
225
+
226
+
227
+ def load_subset_manifest(path):
228
+ if not path:
229
+ return {}
230
+ with open(path, "r", encoding="utf-8") as f:
231
+ manifest = json.load(f)
232
+ if not isinstance(manifest, dict):
233
+ raise ValueError(f"subset_manifest must be a JSON object, got {type(manifest).__name__}")
234
+ if "subsets" in manifest:
235
+ manifest = manifest["subsets"]
236
+ return manifest
237
+
238
+
239
+ def maybe_apply_manifest_subset(dataset, manifest, split_name, name):
240
+ if split_name not in manifest:
241
+ return dataset
242
+ indices = manifest[split_name]
243
+ if not isinstance(indices, list) or not all(isinstance(i, int) for i in indices):
244
+ raise ValueError(f"subset_manifest[{split_name!r}] must be a list of integers")
245
+ if not indices:
246
+ raise ValueError(f"subset_manifest[{split_name!r}] is empty")
247
+ max_index = len(dataset) - 1
248
+ bad_indices = [i for i in indices if i < 0 or i > max_index]
249
+ if bad_indices:
250
+ raise ValueError(
251
+ f"subset_manifest[{split_name!r}] contains out-of-range indices; "
252
+ f"dataset size={len(dataset)}, examples={bad_indices[:5]}"
253
+ )
254
+ print(f"subset_manifest enabled: using {len(indices)} fixed samples from {name} ({split_name})")
255
+ return Subset(dataset, indices)
256
+
257
+
258
+ def checkpoint_requires_lora(saved_model_path):
259
+ if not saved_model_path or not os.path.exists(saved_model_path):
260
+ return False
261
+ state = torch.load(saved_model_path, map_location="cpu")
262
+ return any("lora_" in key for key in state.keys())
263
+
264
+
265
  import torch.multiprocessing as mp
266
  if __name__ == "__main__":
267
+ try:
268
+ mp.set_start_method("spawn")
269
+ except RuntimeError:
270
+ pass
271
  set_seed(42)
272
+ if args.bridge_only and not args.use_residual_prompt_bridge:
273
+ raise ValueError("--bridge_only requires --use_residual_prompt_bridge")
274
  tokenizer = transformers.AutoTokenizer.from_pretrained(
275
  args.mllm,
276
  cache_dir=None,
 
283
  num_added_tokens = tokenizer.add_tokens("[SEG]")
284
  seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0] # 32000
285
  print("seg_token_idx: ", seg_token_idx)
286
+ subset_manifest = load_subset_manifest(args.subset_manifest)
287
 
288
  train_dataset = REFAVS('train', args, tokenizer, input_type='refer')
289
  val_dataset_s_refer = REFAVS('test_s', args, tokenizer, input_type='refer')
290
  val_dataset_u_refer = REFAVS('test_u', args, tokenizer, input_type='refer')
291
  val_dataset_n_refer = REFAVS('test_n', args, tokenizer, input_type='refer')
292
 
293
+ train_dataset = maybe_apply_manifest_subset(train_dataset, subset_manifest, "train", "train")
294
+ val_dataset_s_refer = maybe_apply_manifest_subset(val_dataset_s_refer, subset_manifest, "test_s", "test_s")
295
+ val_dataset_u_refer = maybe_apply_manifest_subset(val_dataset_u_refer, subset_manifest, "test_u", "test_u")
296
+ val_dataset_n_refer = maybe_apply_manifest_subset(val_dataset_n_refer, subset_manifest, "test_n", "test_n")
297
+
298
  if args.overfit_samples > 0:
299
  overfit_n = min(args.overfit_samples, len(train_dataset))
300
  train_dataset = Subset(train_dataset, list(range(overfit_n)))
301
  print(f"overfit_samples enabled: using first {overfit_n} train samples")
302
 
303
+ train_eval_dataset = maybe_limit_dataset(train_dataset, args.max_eval_rows, "train_eval")
304
+ val_dataset_s_refer = maybe_limit_dataset(val_dataset_s_refer, args.max_eval_rows, "test_s")
305
+ val_dataset_u_refer = maybe_limit_dataset(val_dataset_u_refer, args.max_eval_rows, "test_u")
306
+ val_dataset_n_refer = maybe_limit_dataset(val_dataset_n_refer, args.max_eval_rows, "test_n")
307
 
308
 
309
  g = torch.Generator()
 
320
  model_args = {
321
  "train_mask_decoder": True,
322
  "out_dim": 256, # 256
323
+ "ce_loss_weight": args.ce_loss_weight,
324
+ "dice_loss_weight": args.dice_loss_weight,
325
+ "bce_loss_weight": args.bce_loss_weight,
326
  "seg_token_idx": seg_token_idx,
327
  "vision_pretrained": args.vision_pretrained, # sam_vit_h_xxx.pth
328
  "vision_tower": args.vision_tower,
329
  "use_im_start_end": False,
330
  "compress": args.compress,
331
  "start": args.start,
332
+ "use_residual_prompt_bridge": args.use_residual_prompt_bridge,
333
+ "bridge_pm_weight": args.bridge_pm_weight,
334
+ "bridge_rg_weight": args.bridge_rg_weight,
335
+ "bridge_norm_weight": args.bridge_norm_weight,
336
+ "bridge_mode": args.bridge_mode,
337
+ "bridge_condition": args.bridge_condition,
338
+ "bridge_directional_alpha": args.bridge_directional_alpha,
339
+ "bridge_gate_bias_init": args.bridge_gate_bias_init,
340
+ "bridge_residual_init_std": args.bridge_residual_init_std,
341
+ "bridge_target_frame": args.bridge_target_frame,
342
  }
343
 
344
  model = Simtoken_ForCausalLM.from_pretrained(args.mllm, torch_dtype=torch.float32, low_cpu_mem_usage=True, **model_args)
 
374
  for p in model.get_model().mm_projector.parameters():
375
  p.requires_grad = False
376
 
377
+ use_lora_checkpoint = (
378
+ (args.init_from_saved_model or args.gate_only)
379
+ and checkpoint_requires_lora(args.saved_model)
380
+ )
381
+ if args.bridge_only and use_lora_checkpoint:
382
+ print(
383
+ "bridge_only notice: saved_model contains LoRA weights, "
384
+ "so LoRA modules will be instantiated for checkpoint compatibility and then frozen."
385
+ )
386
+
387
+ lora_r = 8 if (not args.bridge_only or use_lora_checkpoint) else 0
388
  target_modules = "q_proj,v_proj"
389
  if lora_r > 0:
390
 
 
452
  # for name, param in model.token_compressor.named_parameters():
453
  # param.requires_grad = True
454
 
 
455
  for n, p in model.named_parameters():
456
  if any(
457
+ [
458
+ x in n
459
+ for x in ["lm_head", "embed_tokens", "mask_decoder", "text_hidden_fcs"]
460
+ ]
461
  ):
462
  p.requires_grad = True
463
 
464
+ if args.bridge_only:
465
+ for p in model.parameters():
466
+ p.requires_grad = False
467
+ trainable_names = []
468
+ for n, p in model.named_parameters():
469
+ if "prompt_bridge" in n:
470
+ p.requires_grad = True
471
+ trainable_names.append(n)
472
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
473
+ total = sum(p.numel() for p in model.parameters())
474
+ print(f"bridge_only enabled: trainable params {trainable} / {total}")
475
+ for name in trainable_names:
476
+ print(f" bridge trainable: {name}")
477
+ elif args.gate_only:
478
  for p in model.parameters():
479
  p.requires_grad = False
480
  for n, p in model.named_parameters():
 
581
  with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
582
  f.write(message + "\n")
583
 
584
+ def find_prompt_bridge_module(model):
585
+ for _, module in model.named_modules():
586
+ if module.__class__.__name__ == "ResidualPromptBridge":
587
+ return module
588
+ return None
589
+
590
+ def collect_prompt_bridge_grad_norms(model):
591
+ module = find_prompt_bridge_module(model)
592
+ if module is None:
593
+ return {}
594
+
595
+ def grad_norm(param):
596
+ if param.grad is None:
597
+ return None
598
+ return float(param.grad.detach().float().norm().item())
599
+
600
+ return {
601
+ "W_a": grad_norm(module.attn_proj.weight),
602
+ "W_r": grad_norm(module.residual_proj.weight),
603
+ "W_g": grad_norm(module.gate.weight),
604
+ "b_g": grad_norm(module.gate.bias),
605
+ }
606
+
607
+ def print_prompt_bridge_grad_norms(label, norms):
608
+ parts = []
609
+ for key in ["W_a", "W_r", "W_g", "b_g"]:
610
+ value = norms.get(key)
611
+ if value is None:
612
+ parts.append(f"{key}=None")
613
+ else:
614
+ parts.append(f"{key}={value:.6e}")
615
+ print(f"{label}: " + " | ".join(parts))
616
+
617
+ def run_bridge_sanity_checks(model, dataloader):
618
+ if not args.use_residual_prompt_bridge:
619
+ raise ValueError("--bridge_sanity_only requires --use_residual_prompt_bridge")
620
+
621
+ model.train()
622
+ batch = next(iter(dataloader))
623
+ input_dict = dict_to_cuda(batch)
624
+
625
+ output_dict = model.forward(
626
+ images=input_dict["images"],
627
+ images_clip=input_dict["images_clip"],
628
+ audio_features=input_dict["audio_feats"],
629
+ image_features=input_dict["image_feats"],
630
+ input_ids=input_dict["input_ids"],
631
+ labels=input_dict["labels"],
632
+ attention_masks=input_dict["attention_masks"],
633
+ masks_list=input_dict["masks"],
634
+ resize_list=input_dict["resizes"],
635
+ orgsize_list=input_dict["orgsizes"],
636
+ conversation_list=input_dict["convs"],
637
+ refs_num=input_dict["refs_num"],
638
+ fids=input_dict["fids"],
639
+ vids=input_dict["vids"],
640
+ contrast=0.0,
641
+ ref_ids=input_dict["ref_ids"],
642
+ epoch=0,
643
+ inference=False,
644
+ target_frame=args.bridge_target_frame,
645
+ )
646
+
647
+ model.zero_grad(set_to_none=True)
648
+ output_dict["mask_loss"].backward(retain_graph=True)
649
+ print_prompt_bridge_grad_norms(
650
+ "bridge grad check | L_mask only",
651
+ collect_prompt_bridge_grad_norms(model),
652
+ )
653
+
654
+ model.zero_grad(set_to_none=True)
655
+ output_dict["bridge_teacher_loss_raw"].backward()
656
+ print_prompt_bridge_grad_norms(
657
+ "bridge grad check | L_teach only",
658
+ collect_prompt_bridge_grad_norms(model),
659
+ )
660
+
661
+ metrics = output_dict["bridge_metrics"]
662
+ print(
663
+ "bridge identity check: "
664
+ f"delta_norm_mean={metrics['delta_norm_mean']:.6f} | "
665
+ f"cos(p_hat,q)={metrics['cos_p_hat_q_mean']:.6f} | "
666
+ f"q_norm_mean={metrics['q_norm_mean']:.6f} | "
667
+ f"p_hat_norm_mean={metrics['p_hat_norm_mean']:.6f} | "
668
+ f"gate_mean={metrics['gate_mean']:.6f} | "
669
+ f"gate_std={metrics['gate_std']:.6f}"
670
+ )
671
+
672
+ teacher_pm_norms = []
673
+ teacher_rg_norms = []
674
+ teacher_cosines = []
675
+ scanned_batches = max(1, args.bridge_sanity_batches)
676
+
677
+ model.eval()
678
+ with torch.no_grad():
679
+ for batch_idx, batch in enumerate(dataloader):
680
+ if batch_idx >= scanned_batches:
681
+ break
682
+ input_dict = dict_to_cuda(batch)
683
+ result = model.forward(
684
+ images=input_dict["images"],
685
+ images_clip=input_dict["images_clip"],
686
+ audio_features=input_dict["audio_feats"],
687
+ image_features=input_dict["image_feats"],
688
+ input_ids=input_dict["input_ids"],
689
+ labels=input_dict["labels"],
690
+ attention_masks=input_dict["attention_masks"],
691
+ masks_list=input_dict["masks"],
692
+ resize_list=input_dict["resizes"],
693
+ orgsize_list=input_dict["orgsizes"],
694
+ conversation_list=input_dict["convs"],
695
+ refs_num=input_dict["refs_num"],
696
+ fids=input_dict["fids"],
697
+ vids=input_dict["vids"],
698
+ contrast=0.0,
699
+ ref_ids=input_dict["ref_ids"],
700
+ inference=True,
701
+ target_frame=args.bridge_target_frame,
702
+ )
703
+ bridge_metrics = result["bridge_metrics"]
704
+ teacher_pm_norms.append(bridge_metrics["p_mask_norm_mean"])
705
+ teacher_rg_norms.append(bridge_metrics["z_gt_norm_mean"])
706
+ teacher_cosines.append(bridge_metrics["cos_p_mask_z_gt_mean"])
707
+
708
+ print(
709
+ "bridge teacher sanity: "
710
+ f"mean||p_mask||={float(np.mean(teacher_pm_norms)):.6f} | "
711
+ f"mean||z_gt||={float(np.mean(teacher_rg_norms)):.6f} | "
712
+ f"mean cos(p_mask,z_gt)={float(np.mean(teacher_cosines)):.6f}"
713
+ )
714
+
715
  def valuate(model, dataloader, args, name):
716
  model.eval()
717
 
718
  total_iou = 0
719
  total_fscore = 0
720
  count = 0
721
+ bridge_accumulators = defaultdict(float)
722
+ bridge_count = 0
723
 
724
  for batch in tqdm(dataloader, desc=f"Evaluating on {name}"):
725
  input_dict = dict_to_cuda(batch)
 
740
  vids=input_dict["vids"],
741
  contrast=args.ct_weight,
742
  ref_ids=input_dict["ref_ids"],
743
+ inference=True,
744
+ target_frame=args.bridge_target_frame)
745
  pred_masks = output_dict["pred_masks"] # list[B]:[num_seg, T, H, W]
746
  gt_masks = output_dict["gt_masks"] # list[B]:[num_seg, T, H, W]
747
  for i in range(len(pred_masks)):
 
754
  total_fscore += fscore * num_seg * T
755
  count += num_seg * T
756
 
757
+ if args.use_residual_prompt_bridge and "bridge_metrics" in output_dict:
758
+ for key, value in output_dict["bridge_metrics"].items():
759
+ bridge_accumulators[key] += float(value)
760
+ bridge_count += 1
761
+
762
  print(f"\n valuate on {name}: miou: {total_iou/count} fscore: {total_fscore/count}")
763
 
764
  with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
765
  f.write(f"valuate on {name}: miou {total_iou/count} true fscore {total_fscore/count} \n")
766
+ if bridge_count > 0:
767
+ bridge_summary = " | ".join(
768
+ f"{key}={bridge_accumulators[key] / bridge_count:.6f}"
769
+ for key in sorted(bridge_accumulators.keys())
770
+ )
771
+ print(f" bridge on {name}: {bridge_summary}")
772
+ f.write(f"bridge on {name}: {bridge_summary}\n")
773
 
774
 
775
+ if args.bridge_sanity_only:
776
+ run_bridge_sanity_checks(model, train_eval_dataloader)
777
+ sys.exit(0)
778
+
779
  # ---------------train------------------------------------------
780
 
781
  model.train()
782
  epochs = args.epochs
783
  print("init lr:", args.lr)
784
+ trainable_params = [p for p in model.parameters() if p.requires_grad]
785
+ optimizer = AdamW(trainable_params, lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)
786
  print_referent_gate_optimizer_sanity(model, optimizer)
787
 
788
  gradient_accumulation_steps = max(1, int(16 // args.batch_size))
 
858
  optimizer.zero_grad()
859
 
860
  current_lr = scheduler.get_lr()[0]
861
+ postfix = {
862
+ "lr": current_lr,
863
+ "loss": running_loss / ((step + 1) / gradient_accumulation_steps),
864
+ }
865
+ if args.use_residual_prompt_bridge:
866
+ postfix["bridge"] = float(output_dict["bridge_teacher_loss"].item())
867
+ postfix["pm"] = float(output_dict["bridge_pm_loss"].item())
868
+ postfix["rg"] = float(output_dict["bridge_rg_loss"].item())
869
+ loop.set_postfix(**postfix)
870
 
871
  if args.max_steps > 0 and optimizer_step_count >= args.max_steps:
872
  stop_training = True
upload_hf.py CHANGED
@@ -1,120 +1,73 @@
1
- """
2
- Upload SimToken folder to HuggingFace.
3
-
4
- Usage:
5
- python upload_hf.py --repo your-username/SimToken [--private]
6
 
7
- Features:
8
- - Automatic retry on rate limit (HTTP 429) with exponential backoff
9
- - Built-in resumption: upload_large_folder caches progress locally;
10
- re-running the script will skip already-uploaded files
11
- - Logs to both console and upload.log
12
  """
13
 
 
 
14
  import argparse
15
  import logging
16
- import time
17
  from pathlib import Path
18
 
19
- from huggingface_hub import HfApi
20
- from huggingface_hub.utils import HfHubHTTPError
 
 
21
 
22
- # ── Config ─────────────────────────────────────────────────────────────────
23
- FOLDER = Path(__file__).parent # SimToken directory
24
  IGNORE_PATTERNS = [
 
25
  "**/__pycache__/**",
 
 
26
  "**/*.pyc",
 
27
  "upload.log",
28
  ]
29
 
30
- NUM_WORKERS = 1 # conservative; increase to 8 if no rate-limit errors
31
- MAX_RETRIES = 10
32
- # ───────────────────────────────────────────────────────────────────────────
33
-
34
- logging.basicConfig(
35
- level=logging.INFO,
36
- format="%(asctime)s %(levelname)-8s %(message)s",
37
- datefmt="%H:%M:%S",
38
- handlers=[
39
- logging.FileHandler(FOLDER / "upload.log"),
40
- logging.StreamHandler(),
41
- ],
42
- )
43
- log = logging.getLogger(__name__)
44
-
45
 
46
- def parse_args():
47
- p = argparse.ArgumentParser()
48
- p.add_argument("--repo", required=True,
49
- help="HuggingFace repo id, e.g. your-username/SimToken")
50
- return p.parse_args()
 
 
51
 
52
 
53
- def main():
54
  args = parse_args()
55
- api = HfApi()
 
 
 
 
56
 
57
- # ── 1. Create repo (idempotent) ────────────────────────────────────────
58
- log.info(f"Ensuring repo '{args.repo}' exists ...")
59
- api.create_repo(
60
  repo_id=args.repo,
61
- repo_type="model",
62
- private=False,
63
  exist_ok=True,
64
  )
65
- log.info("Repo ready.")
66
-
67
- # ── 2. Upload with retry ───────────────────────────────────────────────
68
- for attempt in range(1, MAX_RETRIES + 1):
69
- try:
70
- log.info(f"[Attempt {attempt}/{MAX_RETRIES}] Starting upload_large_folder ...")
71
- log.info(f" folder : {FOLDER}")
72
- log.info(f" repo : {args.repo}")
73
- log.info(f" workers: {NUM_WORKERS}")
74
- log.info(" (re-running this script will resume from where it left off)")
75
-
76
- api.upload_large_folder(
77
- folder_path=str(FOLDER),
78
- repo_id=args.repo,
79
- repo_type="model",
80
- ignore_patterns=IGNORE_PATTERNS,
81
- num_workers=NUM_WORKERS,
82
- print_report=True,
83
- print_report_every=120, # print progress every 2 minutes
84
- )
85
-
86
- log.info("Upload complete!")
87
- return
88
-
89
- except HfHubHTTPError as e:
90
- status = e.response.status_code if e.response is not None else "?"
91
- if status == 429:
92
- # Two possible 429 causes:
93
- # 1. API request rate (resets in ~300s)
94
- # 2. Commit rate limit: 128 commits/hour (resets in ~3600s)
95
- # Wait long enough to cover the commit rate limit reset.
96
- wait = 3700
97
- log.warning(f"Rate limited (HTTP 429). Waiting {wait}s (~1 hour) for commit rate limit reset ...")
98
- time.sleep(wait)
99
- elif status in (500, 502, 503, 504):
100
- # Transient server error
101
- wait = 30 * attempt
102
- log.warning(f"Server error (HTTP {status}). Waiting {wait}s before retry ...")
103
- time.sleep(wait)
104
- else:
105
- log.error(f"HTTP error {status}: {e}")
106
- raise
107
-
108
- except Exception as e:
109
- if attempt < MAX_RETRIES:
110
- wait = 30 * attempt
111
- log.warning(f"Unexpected error: {e}. Retrying in {wait}s ...")
112
- time.sleep(wait)
113
- else:
114
- log.error(f"All {MAX_RETRIES} attempts failed. Last error: {e}")
115
- raise
116
 
117
- log.error("Upload did not complete after all retries.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
 
120
  if __name__ == "__main__":
 
1
+ """Upload the current SimToken workspace to HuggingFace Hub.
 
 
 
 
2
 
3
+ Example:
4
+ python upload_hf.py --repo yfan07/SimToken
 
 
 
5
  """
6
 
7
+ from __future__ import annotations
8
+
9
  import argparse
10
  import logging
 
11
  from pathlib import Path
12
 
13
+ from huggingface_hub import HfApi, create_repo
14
+
15
+
16
+ ROOT = Path(__file__).resolve().parent
17
 
 
 
18
  IGNORE_PATTERNS = [
19
+ ".git/**",
20
  "**/__pycache__/**",
21
+ "**/.pytest_cache/**",
22
+ "**/.cache/**",
23
  "**/*.pyc",
24
+ "**/*.pyo",
25
  "upload.log",
26
  ]
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ def parse_args() -> argparse.Namespace:
30
+ parser = argparse.ArgumentParser(description="Upload SimToken to HuggingFace Hub.")
31
+ parser.add_argument("--repo", required=True, help="Repo id, e.g. yfan07/SimToken")
32
+ parser.add_argument("--repo_type", default="model", choices=["model", "dataset", "space"])
33
+ parser.add_argument("--private", action="store_true", help="Create repo as private if missing.")
34
+ parser.add_argument("--num_workers", type=int, default=4)
35
+ return parser.parse_args()
36
 
37
 
38
+ def main() -> None:
39
  args = parse_args()
40
+ logging.basicConfig(
41
+ level=logging.INFO,
42
+ format="%(asctime)s %(levelname)s %(message)s",
43
+ handlers=[logging.FileHandler(ROOT / "upload.log"), logging.StreamHandler()],
44
+ )
45
 
46
+ create_repo(
 
 
47
  repo_id=args.repo,
48
+ repo_type=args.repo_type,
49
+ private=args.private,
50
  exist_ok=True,
51
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ api = HfApi()
54
+ if hasattr(api, "upload_large_folder"):
55
+ logging.info("Uploading %s to %s with upload_large_folder", ROOT, args.repo)
56
+ api.upload_large_folder(
57
+ repo_id=args.repo,
58
+ repo_type=args.repo_type,
59
+ folder_path=str(ROOT),
60
+ ignore_patterns=IGNORE_PATTERNS,
61
+ num_workers=args.num_workers,
62
+ )
63
+ else:
64
+ logging.info("Uploading %s to %s with upload_folder", ROOT, args.repo)
65
+ api.upload_folder(
66
+ repo_id=args.repo,
67
+ repo_type=args.repo_type,
68
+ folder_path=str(ROOT),
69
+ ignore_patterns=IGNORE_PATTERNS,
70
+ )
71
 
72
 
73
  if __name__ == "__main__":