Jack04810 commited on
Commit
cc0721b
·
verified ·
1 Parent(s): 4438605

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. README.md +712 -0
  2. client_utils/openai_api.py +123 -0
  3. config/__init__.py +21 -0
  4. config/__pycache__/config_rlsd_chartqa.cpython-310.pyc +0 -0
  5. config/__pycache__/config_trimode.cpython-310.pyc +0 -0
  6. config/__pycache__/config_trimode_antidegen.cpython-310.pyc +0 -0
  7. config/__pycache__/loader.cpython-312.pyc +0 -0
  8. config/config_7B.py +82 -0
  9. config/config_aok.py +119 -0
  10. config/config_llavacot.py +118 -0
  11. config/config_low.py +120 -0
  12. config/config_opd_7b_chartqa.py +48 -0
  13. config/config_rlsd_chartqa.py +152 -0
  14. config/config_trimode.py +88 -0
  15. default_config_8gpu.yaml +16 -0
  16. default_config_8gpu_deepspeed.yaml +21 -0
  17. default_config_zero2_8gpu.yaml +18 -0
  18. eval/eval_chartqa.py +310 -0
  19. figs/chartqa.png +0 -0
  20. kill_all.sh +55 -0
  21. main.py +522 -0
  22. main_llm.py +197 -0
  23. main_sft.py +80 -0
  24. multi_node_config_raw.yaml +21 -0
  25. opsd_utils/__pycache__/opsd_loss.cpython-312.pyc +0 -0
  26. opsd_utils/gate_policy.py +107 -0
  27. opsd_utils/health_monitor.py +410 -0
  28. opsd_utils/privileged/__pycache__/providers.cpython-310.pyc +0 -0
  29. opsd_utils/privileged/image_utils.py +143 -0
  30. opsd_utils/prompt_builder.py +265 -0
  31. outputs/logs/.ipynb_checkpoints/train_opd_7b_ds_20260614_175014-checkpoint.log +0 -0
  32. outputs/opd-7b-chartqa-ds/checkpoint-1764/zero_to_fp32.py +760 -0
  33. outputs/opd-7b-chartqa-ds/checkpoint-2352/preprocessor_config.json +171 -0
  34. outputs/opd-7b-chartqa-ds/checkpoint-588/config.json +235 -0
  35. papers/full_text.txt +1211 -0
  36. requirements.txt +16 -0
  37. reward_utils/__pycache__/format_checks.cpython-310.pyc +0 -0
  38. reward_utils/compute_rewards.py +126 -0
  39. reward_utils/refiner.py +162 -0
  40. tests/test_data_health_probe.py +28 -0
  41. tests/test_degeneration_probe.py +98 -0
  42. tests/test_health_monitor.py +72 -0
  43. tests/test_mode_router_rlsd.py +97 -0
  44. tests/test_privileged.py +172 -0
  45. tests/test_privileged_debug_artifacts.py +40 -0
  46. tests/test_slice_completion_logits.py +18 -0
  47. tests/test_teacher_dual_image.py +43 -0
  48. tests/test_vocab_align.py +71 -0
  49. trainer/DyMETrainer_7B.py +983 -0
  50. trainer/__init__.py +3 -0
README.md ADDED
@@ -0,0 +1,712 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DyME: Empowering Small-scale VLMs with Reliable Thinking Capabilities
2
+
3
+ [![ICLR 2026](https://img.shields.io/badge/ICLR-2026-blue.svg)](#)
4
+ [![arXiv](https://img.shields.io/badge/arXiv-2506.23061-b31b1b.svg)](https://arxiv.org/abs/2506.23061)
5
+
6
+ This repository provides the official implementation of **DyME** (**Dy**namically selecting between **M**emorization and **E**xploration), accepted at **ICLR 2026**.
7
+
8
+ ## Overview
9
+
10
+ Small-scale Vision-Language Models (SVLMs) are highly suited for proprietary tasks, but equipping them with reasoning and thinking capabilities remains challenging. Traditional Supervised Fine-Tuning (SFT) can force memorization of pseudo thinking traces, while Reinforcement Learning with Verifiable Reward (RLVR) often leads to unstable exploration (advantage collapse) due to limited model capacity.
11
+
12
+ **DyME** is a novel training paradigm that dynamically synergizes SFT and RLVR. At each optimization step, DyME dynamically selects between Memorization (via SFT) and Exploration (via RLVR), ensuring every update contributes to an optimal trade-off. To further enhance this, we introduce a **Visual Supervision mechanism** (a visual checker and refiner) to inject dynamically enhanced, image-grounded guidance during training.
13
+
14
+ Extensive experiments show that DyME delivers substantial performance improvements, establishing it as a robust strategy for stabilizing SVLM learning.
15
+
16
+
17
+ ## Repository Structure
18
+
19
+ ```text
20
+ DyME/
21
+ ├── client_utils/ # Client tools for online Visual Supervision (LLM API)
22
+ ├── data/ # Preprocessed textual datasets
23
+ ├── data_utils/ # Data processing and formatting scripts
24
+ │ ├── aokvqa/
25
+ │ ├── chart/
26
+ │ └── commom_util.py
27
+ ├── eval/ # Evaluation scripts for different benchmarks
28
+ ├── reward_utils/ # Reward function implementations for RLVR
29
+ ├── config/ # Modular configuration files for experiments
30
+ ├── opsd_utils/ # Privileged-context OPSD / TriMode extensions for DyMETrainer
31
+ ├── default_config.yaml # Default DDP (MULTI_GPU, no DeepSpeed required)
32
+ ├── default_config_deepspeed.yaml # Optional ZeRO-0 only (no sharding); needs pip install deepspeed
33
+ ├── default_config_zero2.yaml # ZeRO-2 student sharding (OPD 7B colocate)
34
+ ├── default_config_zero3_colocate.yaml # ZeRO-3 + CPU optimizer offload (tight VRAM)
35
+ ├── configs/deepspeed/ # DeepSpeed JSON templates (HF official "auto" fields)
36
+ ├── main.py # Entry point for DyME training
37
+ ├── main_*.py # Additional experimental variants (e.g., 7B, LLM-only)
38
+ ├── requirements.txt # Python dependencies
39
+ └── ...
40
+ ```
41
+
42
+
43
+ ## Configuration
44
+
45
+ Before launching training, please prepare the relevant configuration files. The main settings are managed through configuration files such as `config/config.py` and `default_config.yaml`.
46
+
47
+ ### `CLIENT_CONFIG`
48
+
49
+ This configuration is required when **Visual Supervision** is enabled. It specifies the online large-model API used by the visual checker and visual refiner.
50
+
51
+ ### `TRAINING_CONFIG`
52
+
53
+ This section contains standard training hyperparameters for both the memorization phase and the exploration phase, including optimizer settings, batch size, learning rate, and related options.
54
+
55
+ ### `RL_CONFIG`
56
+
57
+ This section defines critical variables for reward computation and response parsing during RLVR training. In particular, the following delimiters must be properly specified:
58
+
59
+ * `answer_flag`: used to explicitly separate the final answer from auxiliary generated content such as intermediate reasoning traces.
60
+ * `end_flag`: used to mark the end of generation.
61
+
62
+ These delimiters are essential for stable parsing, reward assignment, and evaluation consistency.
63
+
64
+ ### `DYME_OPSD_CONFIG` (OPSD / TriMode)
65
+
66
+ `config/config.py` defines `DYME_OPSD_CONFIG`, merged into `CONFIG["opsd"]`. When `enabled=False` (default), training follows the original DyME behavior. Set `enabled=True` or pass CLI flags to activate privileged-context **Self-OPSD** inside `DyMETrainer`.
67
+
68
+ | Field | Description |
69
+ | --- | --- |
70
+ | `enabled` | Master switch. `False` → original DyME only. |
71
+ | `mode` | Routing mode (see table below). |
72
+ | `privileged_profile` | Teacher preset: `text` \| `visual` \| `hybrid` (default **`hybrid`** in `config_trimode.py`). |
73
+ | `privileged_providers` | Override provider list; default derived from profile. |
74
+ | `privileged_image` | Teacher image layout: `mode` `single` (ChartQA default) or `dual` (full + crop); plus `crop_strategy`, `bbox_coord`, `margin_ratio`. |
75
+ | `privileged_debug` | Periodic artifact logging: `save_images`, `image_subdir` (`logs/images`), `max_samples_per_detail`. |
76
+ | `gate.correct_threshold` | Reward threshold to count a rollout as correct. |
77
+ | `gate.teacher_recoverable` | Recoverability gate: `privileged_available` (default) or `logprob_gain`. |
78
+ | `loss.beta` | JSD temperature for OPSD distillation. |
79
+ | `loss.opsd_weight` / `grpo_weight` / `sft_weight` | Per-mode loss weights. |
80
+
81
+ **Routing modes (`mode`):**
82
+
83
+ | Mode | Behavior |
84
+ | --- | --- |
85
+ | `dyme` | Original DyME: any correct rollout → GRPO; all wrong → SFT. |
86
+ | `trimode` | Any correct → OPSD (replaces GRPO); all wrong → SFT (DyME cold-start via `sft_check`, ignores recoverable). |
87
+ | `opsd_only` | All prompts use OPSD. |
88
+ | `replace_sft` | Any correct → GRPO; all wrong → OPSD (no SFT). |
89
+ | `opsd_on_wrong` | Any correct → GRPO; all wrong + recoverable → OPSD; all wrong + not recoverable → SFT (legacy three-way routing). |
90
+ | `grpo_opsd_joint` | Any correct → GRPO (+ optional joint OPSD loss); all wrong + recoverable → OPSD; else SFT. |
91
+
92
+ Under `trimode`, the SFT share is determined by accuracy (how often prompts are all-wrong) and DyME's per-group `sft_check` (teacher injection on the first generation only)—no extra `sft_ratio` hyperparameter.
93
+
94
+ **Privileged profiles** (`privileged_profile`):
95
+
96
+ | Profile | Teacher images | Teacher text suffix |
97
+ | --- | --- | --- |
98
+ | `text` | Single full image (same as student) | hint + answer |
99
+ | `visual` | **Dual**: full + evidence crop | Visual Facts only (no answer leak) |
100
+ | `hybrid` | Single full image by default (`privileged_image.mode=single`); dual with `mode=dual` | Visual Facts + hint + answer |
101
+
102
+ Student `collate_fn` never reads privileged fields. With `privileged_image.mode=dual`, teacher forward uses `[full, crop]`; crop comes from normalized `evidence_bbox` (C2), A-OKVQA `visual_fact` heuristic (D2), or center fallback (D1). ChartQA defaults to `single` (no crop zoom).
103
+
104
+ **Privileged providers** (under `opsd_utils/privileged/`):
105
+
106
+ * `text` — uses the `hint` / `answer` fields in training samples.
107
+ * `visual_facts` — uses `visual_fact` JSON (B1 raw string), plus ChartQA `visual_fact_hint` (F1) and `visual_fact_deplot` (F2).
108
+ * `crop` — evidence region as second teacher image (via `image_utils`, not a text suffix).
109
+ * `hybrid` — combines text + visual_facts providers per profile.
110
+
111
+ **Debug / artifact logging**
112
+
113
+ * Verbose OPSD logs: `--opsd_debug` or `DYME_OPSD_DEBUG=1`.
114
+ * Full diagnostic bundle every N steps: `--opsd_detail_every N` or `DYME_OPSD_DETAIL_EVERY`.
115
+ * On detail steps, teacher privileged images are saved under `{output_dir}/logs/images/` as `step_XXXXXX_idx_Y_full.png`, `_crop.png`, and `_meta.json` (controlled by `privileged_debug.max_samples_per_detail`).
116
+
117
+ **ChartQA visual-facts preprocessing (run on server before TriMode training)**
118
+
119
+ TriMode with `privileged_providers=text,visual_facts` requires `visual_fact_hint` / `visual_fact_deplot` (and optionally `visual_fact`) on each sample. Raw `train_medium.json` only has `hint` — without this step, logs show `visual_fact_len=0` and the VisualFacts teacher channel is empty.
120
+
121
+ From the repo root on your GPU server:
122
+
123
+ ```bash
124
+ cd /path/to/agentic-rl-main # project root (parent of scripts/, config/, data/)
125
+
126
+ # F1: copy hint → visual_fact_hint (+ visual_fact for backward compat)
127
+ python scripts/build_visual_facts_chartqa.py \
128
+ --input data/chartqa/train_medium.json \
129
+ --output data/chartqa/train_medium_vf_hint.json \
130
+ --also-set-visual-fact
131
+
132
+ # F2: DePlot offline table extraction (google/deplot, batched GPU inference; default ON)
133
+ python scripts/build_visual_facts_chartqa_deplot.py \
134
+ --input data/chartqa/train_medium_vf_hint.json \
135
+ --output data/chartqa/train_medium_vf_full.json \
136
+ --batch-size 8 \
137
+ --cache data/chartqa/deplot_cache.json
138
+
139
+ # Fast placeholder-only mode (no GPU / CI): add --no-enabled
140
+ # DYME_DEPLOT_ENABLED=0 bash scripts/train_local_gpus.sh
141
+
142
+ # quick sanity check (expect non-zero lengths)
143
+ python -c "
144
+ import json, random
145
+ d = json.load(open('data/chartqa/train_medium_vf_full.json', encoding='utf-8'))
146
+ s = random.choice(d)
147
+ assert s.get('visual_fact_hint'), 'missing visual_fact_hint'
148
+ assert s.get('visual_fact_deplot'), 'missing visual_fact_deplot'
149
+ print('ok', len(d), 'records; sample visual_fact_hint len', len(s['visual_fact_hint']))
150
+ "
151
+ ```
152
+
153
+ `config/config.py` points `train_dataset` at `data/chartqa/train_medium_vf_full.json`. Generated `*_vf_*.json` files are gitignored — **generate them on each server** (or copy from shared storage); do not rely on cloning them from GitHub.
154
+
155
+ `scripts/train_local_gpus.sh` will auto-run the two Python steps above if `train_medium_vf_full.json` is missing.
156
+
157
+ **Training examples (TriMode + hybrid default)**
158
+
159
+ ```bash
160
+ # Text-only OPSD ablation
161
+ python main.py --config trimode --opsd_privilege_profile text
162
+
163
+ # Vision-OPD style (no answer text to teacher)
164
+ python main.py --config trimode --opsd_privilege_profile visual
165
+
166
+ # Full hybrid (default in config_trimode)
167
+ python main.py --config trimode --opsd_privilege_profile hybrid --opsd_detail_every 10
168
+ ```
169
+
170
+ **Privileged sample schema**
171
+
172
+ | Field | Used by | Notes |
173
+ | --- | --- | --- |
174
+ | `prompt`, `image` | Student + teacher | Student always single full image |
175
+ | `hint`, `answer` | Teacher (`text` / `hybrid`) | Never in student collate |
176
+ | `visual_fact` | Teacher | Raw JSON string (A-OKVQA) |
177
+ | `visual_fact_hint` | Teacher (ChartQA F1) | Hint placeholder pipeline |
178
+ | `visual_fact_deplot` | Teacher (ChartQA F2) | DePlot `parsed_table` text (`google/deplot`; placeholder skipped) |
179
+ | `evidence_bbox` | Teacher crop | Normalized `[x0,y0,x1,y1]` in `[0,1]` |
180
+
181
+ Adapter helpers for future datasets: `data_utils/privileged_schema.py` (`normalize_evidence_bbox`, `parse_visual_fact`, `resolve_crop_bbox`).
182
+
183
+ For legacy ChartQA single-field preprocessing, see `scripts/build_visual_facts_chartqa.py`.
184
+
185
+
186
+
187
+ ## Data Preparation
188
+
189
+ We provide example preprocessing scripts in the `data_utils/` directory. After preprocessing, the training data should be organized as a list of dictionaries (e.g., `metadata_list`) following the format below:
190
+
191
+ ```python
192
+ metadata_list.append({
193
+ "question": question, # Full prompt used for training
194
+ "question_wo_prompt": question, # Raw question without prompt template
195
+ "answer": answer, # SFT target; should follow the answer_flag format
196
+ "image": image_save_path, # Local path to the corresponding image
197
+ })
198
+ ```
199
+
200
+ ### Field Description
201
+
202
+ * `question`: the complete model input used during training.
203
+ * `question_wo_prompt`: the raw question content without any additional prompt wrapper.
204
+ * `answer`: the training target for SFT; this field should be formatted consistently with the delimiter specification in `RL_CONFIG`.
205
+ * `image`: the local file path of the associated image, if applicable.
206
+
207
+
208
+
209
+ ## Environment Setup
210
+
211
+ Please first install the required dependencies and configure the distributed training environment:
212
+
213
+ ```bash
214
+ pip install -r requirements.txt
215
+ accelerate config
216
+ ```
217
+
218
+ The `accelerate config` step is required to initialize the distributed environment for both training and evaluation.
219
+
220
+
221
+
222
+ ## Dataset Setup
223
+
224
+ ### Text Data
225
+
226
+ Preprocessed text splits are provided under the `data/` directory.
227
+
228
+ ### Image Data
229
+
230
+ Due to storage constraints, image datasets are not included in this repository. Download scripts write images under `data/images/` by default:
231
+
232
+ ```text
233
+ data/images/
234
+ ├── chartqa/
235
+ │ ├── images/ # train_000000.png, val_000000.png, test_000000.png, ...
236
+ │ └── json/ # train.json, val.json, test.json (from download.py)
237
+ └── aokvqa/
238
+ ├── images/ # train_0000000.png, ...
239
+ └── json/ # train.json, validation.json, test.json (from download.py)
240
+ ```
241
+
242
+ **ChartQA** (images only, no API required):
243
+
244
+ ```bash
245
+ python data_utils/chart/download.py
246
+ ```
247
+
248
+ **A-OKVQA** (images only by default; set `FETCH_VISUAL_FACTS=1` only if local VLM APIs are running on ports 23333–23340):
249
+
250
+ ```bash
251
+ python data_utils/aokvqa/download.py
252
+ ```
253
+
254
+ If you already downloaded ChartQA to `chartqa_output/` at the project root, move it into the canonical layout:
255
+
256
+ ```bash
257
+ mkdir -p data/images
258
+ mv chartqa_output data/images/chartqa
259
+ ```
260
+
261
+ Preprocessed text annotations with hints live separately under `data/chartqa/` and `data/aokvqa/`. Image paths inside those JSON files are resolved automatically at load time (legacy prefixes like `/chartqa_output/` map to `data/images/chartqa/`).
262
+
263
+ ### Demo Samples
264
+
265
+ A small subset of demo images for verifying the data loading pipeline may be provided in a future update.
266
+
267
+
268
+
269
+ ## Dataset Examples
270
+
271
+ ### ChartQA
272
+
273
+ **ChartQA** is a visual question answering benchmark grounded in chart images. To illustrate different supervision granularities, we provide representative examples with three levels of reasoning-trace quality: **High**, **Medium**, and **Low**.
274
+
275
+ <div align="center">
276
+
277
+ | Example |
278
+ | --------------------------------------------------------------- |
279
+ | <img src="figs/chartqa.png" alt="ChartQA Example" width="220"/> |
280
+
281
+ </div>
282
+
283
+ #### High-quality Example
284
+
285
+ <details>
286
+ <summary><code>High-quality ChartQA Example</code></summary>
287
+
288
+ ```json
289
+ {
290
+ "question": "When does the unfavorable view reach the peak?",
291
+ "answer": "2017",
292
+ "hint": "<SUMMARY> To solve the problem, I will examine the image to identify trends in unfavorable views of Pakistan in India over time. I'll closely inspect the data points within the graph to determine the year where the \"very unfavorable view\" reaches its peak. This involves identifying the maximum value on the vertical axis and noting the corresponding year on the horizontal axis. </SUMMARY> \n\n<CAPTION> The image is a line graph titled \"Very unfavorable views of Pakistan increasing in India,\" with the subtitle \"Very unfavorable view of Pakistan.\" The y-axis represents the percentage of unfavorable views, ranging from 0% to 100%. The x-axis displays years from 2013 to 2017. The data points show the percentages of very unfavorable views over these years, with specific values marked: 54% in 2013, 49% in 2014, 51% in 2015, 55% in 2016, and 64% in 2017. The graph shows a general upward trend in unfavorable views, peaking in 2017. </CAPTION> \n\n<REASONING> To determine when the unfavorable view reaches its peak, one should observe the graph for the data point with the highest percentage on the y-axis. The graph shows percentages for each year from 2013 to 2017: starting at 54% in 2013, decreasing to 49% in 2014, and then gradually increasing to 51% in 2015 and 55% in 2016. The graph culminates with the highest percentage of 64% in 2017. Thus, the peak of unfavorable views is associated with the year 2017. </REASONING> \n\n<CONCLUSION> 2017 </CONCLUSION>"
293
+ }
294
+ ```
295
+
296
+ </details>
297
+
298
+ #### Medium-quality Example
299
+
300
+ <details>
301
+ <summary><code>Medium-quality ChartQA Example</code></summary>
302
+
303
+ ```json
304
+ {
305
+ "question": "When does the unfavorable view reach the peak?",
306
+ "answer": "2017",
307
+ "hint": "Goal: Find the year when the unfavorable view reaches its peak.\nObservation: The data shows the values for each year are: 2013: 0, 2014: 0, 2015: 0, 2016: 55, and 2017: 64.\nReasoning: By comparing the values in each year, the highest value is 64, which occurs in 2017.\nConclusion: The unfavorable view reaches its peak in 2017."
308
+ }
309
+ ```
310
+
311
+ </details>
312
+
313
+ #### Low-quality Example
314
+
315
+ <details>
316
+ <summary><code>Low-quality ChartQA Example</code></summary>
317
+
318
+ ```json
319
+ {
320
+ "question": "When does the unfavorable view reach the peak?",
321
+ "answer": "2017",
322
+ "hint": "I'm trying to figure out the year when the unfavorable view reaches its highest point. Looking at the data, I see that the values for each year are pretty low until 2016, where it jumps to 55. But the peak doesn't happen until 2017, when the value spikes to 64. So, it seems like the unfavorable view really hits its maximum in 2017."
323
+ }
324
+ ```
325
+
326
+ </details>
327
+
328
+
329
+ ### A-OKVQA
330
+
331
+ **A-OKVQA** is an open-ended visual question answering benchmark that requires world knowledge, commonsense reasoning, and visual understanding. Below we provide a representative example together with its corresponding annotation.
332
+
333
+ <div align="center">
334
+
335
+ | Example |
336
+ | -------------------------------------------------------------- |
337
+ | <img src="figs/aokvqa.png" alt="A-OKVQA Example" width="220"/> |
338
+
339
+ </div>
340
+
341
+ <details>
342
+ <summary><code>View A-OKVQA JSON Example</code></summary>
343
+
344
+ ```json
345
+ {
346
+ "question": "What is the man by the bags awaiting?",
347
+ "answer": "cab",
348
+ "visual_fact": "{\n \"description\": \"The image shows a man standing in the middle of a street, facing away from the camera. He is holding a red bag in one hand and appears to be pulling a black suitcase with wheels. Another black suitcase is lying on the ground near him. The setting is an urban area with houses, parked cars, and trees in the background. The man seems to be waiting or preparing to cross the street.\",\n \"objects\": [\n {\n \"name\": \"man\",\n \"attributes\": [\"wearing a light blue and white shirt\", \"blue jeans\", \"carrying a red bag\", \"pulling a black suitcase\"],\n \"position\": \"center\"\n },\n {\n \"name\": \"red bag\",\n \"attributes\": [\"held by the man\"],\n \"position\": \"left side of the man\"\n },\n {\n \"name\": \"black suitcase\",\n \"attributes\": [\"with wheels\", \"being pulled by the man\"],\n \"position\": \"near the man's feet\"\n },\n {\n \"name\": \"black suitcase\",\n \"attributes\": [\"on the ground\"],\n \"position\": \"on the ground near the man\"\n },\n {\n \"name\": \"street\",\n \"attributes\": [\"asphalt\", \"urban setting\"],\n \"position\": \"foreground\"\n },\n {\n \"name\": \"houses\",\n \"attributes\": [\"visible in the background\"],\n \"position\": \"left side\"\n },\n {\n \"name\": \"parked cars\",\n \"attributes\": [\"red SUV\", \"other vehicles\"],\n \"position\": \"left and center background\"\n },\n {\n \"name\": \"trees\",\n \"attributes\": [\"green foliage\"],\n \"position\": \"right side\"\n }\n ]\n}",
349
+ "hint": "A train would not be on the street, he would not have luggage waiting for a delivery, and the skateboarder is there and not paying attention to him, so a cab is the only plausible answer."
350
+ }
351
+ ```
352
+
353
+ </details>
354
+
355
+
356
+
357
+ ### GSM8K
358
+
359
+ **GSM8K** is a mathematical word problem benchmark. Since it is text-only, we provide a representative question-answer example together with its reasoning trace.
360
+
361
+ <details>
362
+ <summary><code>View GSM8K JSON Example</code></summary>
363
+
364
+ ```json
365
+ {
366
+ "question": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?",
367
+ "answer": "72",
368
+ "hint": "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72"
369
+ }
370
+ ```
371
+
372
+ </details>
373
+
374
+
375
+
376
+ ## Training
377
+
378
+ All training scripts are launched using `accelerate`. Pass `--config` as a **Python config file path** (recommended) or a shorthand alias (`norm`, `trimode`, `llavacot`, `low`, `aok`).
379
+
380
+ **Important:** `num_processes` must match the number of visible GPUs on your node. Helper scripts auto-detect GPU count and use **native PyTorch DDP** (`default_config.yaml`, `distributed_type: MULTI_GPU`) — **DeepSpeed is not required** for 0.5B multi-GPU training.
381
+
382
+ Optional: if you already have `deepspeed` installed and want the Accelerate integration without parameter sharding, use ZeRO-0 (`default_config_deepspeed.yaml`, `zero_stage: 0`). Do **not** use ZeRO-2/3 for 0.5B-only RLSD unless you need the integration path.
383
+
384
+ **7B OPD (student + frozen teacher on each GPU):** use DeepSpeed ZeRO to shard the **trainable 0.5B student**; the frozen 7B teacher stays outside DeepSpeed on `cuda:{LOCAL_RANK}`.
385
+
386
+ ```bash
387
+ # ZeRO-2 colocate (recommended first try on 2×80G)
388
+ bash scripts/train_opd_7b_chartqa_deepspeed.sh
389
+
390
+ # Tighter memory: ZeRO-3 + CPU optimizer offload
391
+ ACCELERATE_CONFIG=default_config_zero3_colocate.yaml bash scripts/train_opd_7b_chartqa_deepspeed.sh
392
+ ```
393
+
394
+ Refs: [Transformers DeepSpeed](https://huggingface.co/docs/transformers/deepspeed), [Accelerate DeepSpeed](https://huggingface.co/docs/accelerate/usage_guides/deepspeed).
395
+
396
+ ```bash
397
+ # 4-GPU node (default DDP, recommended)
398
+ bash scripts/train_rlsd_chartqa.sh
399
+
400
+ # Explicit DDP config
401
+ accelerate launch --config_file default_config.yaml --num_processes 4 main.py --config config/config.py --mode rl
402
+
403
+ # Optional ZeRO-0 (requires deepspeed, no sharding)
404
+ ACCELERATE_CONFIG=default_config_deepspeed.yaml bash scripts/train_rlsd_chartqa.sh
405
+ ```
406
+
407
+ Or override explicitly: `NUM_GPUS=4 bash scripts/train_trimode.sh`
408
+
409
+ For TriMode on **all visible local GPUs** (auto-detect via `CUDA_VISIBLE_DEVICES` / `torch.cuda.device_count()`):
410
+
411
+ ```bash
412
+ # 1) One-time (or when raw data changes): enrich ChartQA JSON on the server — see
413
+ # "ChartQA visual-facts preprocessing" above. train_local_gpus.sh also auto-runs
414
+ # this if train_medium_vf_full.json is absent.
415
+
416
+ # 2) Start training (default: OPSD verbose off, detail every 50 steps, probe on)
417
+ bash scripts/train_local_gpus.sh
418
+
419
+ # Optional: full verbose debug (large logs)
420
+ # DYME_OPSD_DEBUG=1 DYME_OPSD_DETAIL_EVERY=10 bash scripts/train_local_gpus.sh
421
+
422
+ # Roll back to original trimode config (pre-antidegen hyperparameters)
423
+ # DYME_CONFIG=config/config_trimode.py bash scripts/train_local_gpus.sh
424
+ ```
425
+
426
+ ### Anti-degeneration config (`config_trimode_antidegen`)
427
+
428
+ `scripts/train_local_gpus.sh` defaults to **`config/config_trimode_antidegen.py`** (alias `trimode_antidegen`). Overrides are based on offline analysis of `train_trimode_4gpu_20260610_173637.log` (1225 steps):
429
+
430
+ | Issue | Baseline log evidence | Antidegen change |
431
+ |-------|----------------------|------------------|
432
+ | Logit collapse | `LOGIT_MODE_COLLAPSE` 212×; step 1 clip 0→1.0; step 1175 clip≈0.92 | `max_completion_length=150`, `temperature=0.7`, `repetition_penalty=1.25` |
433
+ | Step-1 gradient shock | `GEN_CLIP_COLLAPSE` from step 1; `OPT_GRAD_SPIKE` 44× | `learning_rate=5e-5`, `warmup_steps=50` |
434
+ | OPSD coverage low | `opsd_mask` mean 5.6%; 492/1226 zero-mask steps | `require_format_for_opsd=False` (env default `DYME_OPSD_REQUIRE_FORMAT=0`) |
435
+ | RL signal sparse | `RL_ZERO_SIGNAL` expected in trimode | `reward_weights=[0.5, 1.5, 1.0]` (format, context F1, acc) |
436
+ | visual_fact empty | `visual_fact_empty_rate=0` throughout | no data change |
437
+
438
+ Environment overrides:
439
+
440
+ ```bash
441
+ export DYME_CONFIG=config/config_trimode_antidegen.py # default in train_local_gpus.sh
442
+ export DYME_OPSD_REQUIRE_FORMAT=0 # antidegen default; set 1 to restore strict gate
443
+ export DYME_REWARD_WEIGHTS=0.5,1.5,1.0 # format, context, accuracy
444
+ ```
445
+
446
+ After a new run (~200+ steps), compare against the baseline log:
447
+
448
+ ```bash
449
+ python scripts/parse_trimode_log.py outputs/logs/train_trimode_*_new.log
450
+ python scripts/degeneration_report.py outputs/logs/train_trimode_*_new.log
451
+ python scripts/compare_trimode_logs.py train_trimode_4gpu_20260610_173637.log outputs/logs/train_trimode_*_new.log
452
+ ```
453
+
454
+ Success criteria (candidate vs baseline): step 1 `clip` &lt; 1.0; `LOGIT_MODE_COLLAPSE` count down &gt;30%; `opsd_mask` mean &gt; 8%; step 200+ `mean_length` median &lt; 130. `RL_ZERO_SIGNAL` may remain high (trimode design).
455
+
456
+
457
+ ### 1. Training DyME (original)
458
+
459
+ Default config keeps OPSD disabled (`DYME_OPSD_CONFIG.enabled=False`):
460
+
461
+ ```bash
462
+ accelerate launch main.py --config config/config.py --mode rl
463
+ ```
464
+
465
+ ### OPSD debug logging + tee
466
+
467
+ When debugging OPSD / TriMode (e.g. NCCL timeout), enable verbose logs and save stdout/stderr:
468
+
469
+ ```bash
470
+ export DYME_OPSD_DEBUG=1
471
+ mkdir -p ./outputs/logs
472
+ LOG_FILE=./outputs/logs/train_$(date +%Y%m%d_%H%M%S).log
473
+
474
+ accelerate launch --config_file default_config.yaml --num_processes "$(nvidia-smi -L | wc -l)" main.py \
475
+ --config config/config_trimode.py \
476
+ --mode rl \
477
+ --opsd_enabled \
478
+ --opsd_debug \
479
+ --opsd_mode trimode \
480
+ --opsd_providers text,visual_facts \
481
+ 2>&1 | tee "${LOG_FILE}"
482
+ ```
483
+
484
+ Logs are prefixed with `[OPSD-DEBUG]` and include rank, step, `[SYNC_POINT]` markers before every distributed collective in the OPSD chain (reward gather, teacher prompt build, metrics gather, OPSD loss). Search the log for the last `[SYNC_POINT]` on each rank to locate where a hang occurred.
485
+
486
+ You can also use the helper script (debug + tee enabled by default):
487
+
488
+ ```bash
489
+ bash scripts/train_trimode.sh
490
+ ```
491
+
492
+ Disable debug when not needed: `DYME_OPSD_DEBUG=0 bash scripts/train_trimode.sh`
493
+
494
+ ### Periodic weak-signal diagnostics (`[OPSD-DETAIL]`)
495
+
496
+ Separate from per-step `[OPSD-DEBUG]` spam: every **N global steps** (default **10**, rank 0 only) a full diagnostic bundle is printed to investigate **reward ≈ 0** and **gradient ≈ 0** while the OPSD chain still runs:
497
+
498
+ - Generation: EOS rate, clipped ratio, effective completion tokens, decoded samples
499
+ - Reward: format / acc / context breakdown, advantage stats, per-sample table
500
+ - Routing: OPSD mask ratio, TriMode counts, advantage token distribution
501
+ - Loss: GRPO per-token logps, coef\_1, clip counts, weak-signal hints
502
+ - OPSD JSD: per-token JSD, student/teacher top-1 agreement, max-JSD token
503
+
504
+ Configure via config, CLI, or env:
505
+
506
+ ```bash
507
+ # default: every 10 steps (config_trimode.py)
508
+ export DYME_OPSD_DETAIL_EVERY=10
509
+
510
+ python main.py --config config/config_trimode.py --mode rl \
511
+ --opsd_enabled --opsd_detail_every 10
512
+
513
+ # disable periodic detail
514
+ export DYME_OPSD_DETAIL_EVERY=0
515
+ ```
516
+
517
+ Search logs for `[OPSD-DETAIL]` (not `[OPSD-DEBUG]`).
518
+
519
+ **Per-generate probe (`[OPSD-PROBE]`)** — enabled by default in `config_trimode.py`; fires on every `(re)generate` on rank 0 (no need to wait for step 10). Logs raw `completion_ids`, decode with/without special tokens, `eos_idx`, flags `ONE_TOKEN` / `EMPTY_DECODE` / `FIRST_IS_EOS`, and patterns `PAREN_THEN_EOS` / `REPEAT_LOOP`. Disable with `DYME_OPSD_PROBE_ON_GENERATE=0` or `--no_opsd_probe_on_generate`.
520
+
521
+ **Deep generate debug (`[OPSD-GENDBG]`)** — runs alongside `[OPSD-PROBE]` when probe is enabled. Before each `model.generate`, logs model training context, prompt tail tokens/decode, and first-token logits (`p_eos`, `p_token_340`, `entropy`, `top5`) via **per-sample** forward (up to `probe_sample_count`, default 4) to avoid OOM on large VLM batches. After generate, logs greedy-vs-actual first token, delta vs previous regenerate, and cross-rank summary.
522
+
523
+ ```bash
524
+ export DYME_OPSD_PROBE_ON_GENERATE=1 # default in config_trimode
525
+ grep -E '\[OPSD-(PROBE|GENDBG)\]' train.log
526
+ ```
527
+
528
+ | Observation in logs | Likely root cause |
529
+ |---------------------|-------------------|
530
+ | `p_eos` very high + `greedy_token_id==eos` | Weight collapse / train-mode distribution |
531
+ | `prompt_tail_decode` ends with unclosed template + high `p_token_340` | Prompt / chat template issue (legacy `"Answer: .."` quoted placeholder biased token 340 `)`; fixed in `data_utils/rl_prompt.py`) |
532
+ | `greedy_matches_actual=False` with low `p_eos` | Sampling noise (temperature / top_p) |
533
+ | Large `one_token_count` gap across ranks in `cross_rank` | Data sharding / batch composition |
534
+ | `delta_one_token_count` spikes at `generate_call_index>=2` | Weight drift after optimizer step |
535
+
536
+ Optional env overrides:
537
+
538
+ ```bash
539
+ export DYME_OPSD_PROBE_FIRST_TOKEN_LOGITS=0 # skip extra forward before generate
540
+ export DYME_OPSD_PROBE_PROMPT_TAIL_TOKENS=24
541
+ export DYME_OPSD_PROBE_LOG_MODEL_CONTEXT=0
542
+ ```
543
+
544
+ ### 2. Training TriMode (DyME + OPSD)
545
+
546
+ Use `config/config_trimode.py` (OPSD pre-enabled) or override on the base config via CLI:
547
+
548
+ ```bash
549
+ accelerate launch main.py \
550
+ --config config/config_trimode.py \
551
+ --mode rl \
552
+ --opsd_enabled \
553
+ --opsd_mode trimode \
554
+ --opsd_providers text,visual_facts
555
+ ```
556
+
557
+ Equivalent one-liner with base config + CLI only:
558
+
559
+ ```bash
560
+ accelerate launch main.py --config config/config.py --mode rl \
561
+ --opsd_enabled --opsd_mode trimode --opsd_providers text,visual_facts
562
+ ```
563
+
564
+ **CLI OPSD flags** (override `CONFIG["opsd"]`):
565
+
566
+ | Flag | Description |
567
+ | --- | --- |
568
+ | `--opsd_enabled` | Enable OPSD / TriMode extensions. |
569
+ | `--opsd_debug` | Verbose OPSD chain logs (`[OPSD-DEBUG]`, or env `DYME_OPSD_DEBUG=1`). |
570
+ | `--opsd_detail_every N` | Full weak-signal bundle every N steps (`[OPSD-DETAIL]`, default 10; `0` = off). |
571
+ | `--opsd_probe_on_generate` / `--no_opsd_probe_on_generate` | Per-generate `[OPSD-PROBE]` on rank 0 (trimode default on). |
572
+ | `--opsd_mode MODE` | Routing mode: `trimode` (legacy), `rlsd` (anti-leakage), `copsd_opd`, `dyme`, `opsd_only`, `replace_sft`, … |
573
+ | `--opsd_providers LIST` | Comma-separated providers, e.g. `text`, `format_only`, `visual_facts`. Empty = same-prompt OPD only. |
574
+
575
+ ### 2b. RLSD / anti-leakage OPSD (recommended for ChartQA)
576
+
577
+ `trimode` routes OPSD on **correct** completions and injects gold answer into the teacher prompt (information leakage). Use **`rlsd`** instead:
578
+
579
+ - **Correct** → GRPO (on-policy self-learning, no privileged suffix)
580
+ - **Wrong** → same-prompt OPSD / OPD (no `[Reference Answer]` in teacher)
581
+ - **All-wrong group** → online SFT replace on the first generation (DyME cold-start; no separate offline SFT phase)
582
+
583
+ **Important — online SFT ≠ offline SFT:** From step 0, training is always **RL + sparse online SFT** (typically 1/8 of completions per prompt when the group is all-wrong). There is no dedicated SFT-only phase unless you run a separate offline stage (see below).
584
+
585
+ **Anti-collapse knobs (ChartQA RLSD / OPD):**
586
+
587
+ | Env / config | Purpose |
588
+ | --- | --- |
589
+ | `DYME_MAX_COMPLETION_LENGTH`, `DYME_TEMPERATURE`, `DYME_REPETITION_PENALTY` | Antidegen decoding (RLSD defaults: 128 / 0.6 / 1.35) |
590
+ | `DYME_FORMAT_MIN_THINKING` | Minimum chars before `Answer:` for format reward (default 8) |
591
+ | `DYME_OPSD_SKIP_DEGENERATE=0` | Never skip OPSD on degenerate completions |
592
+ | `DYME_OPSD_DEGEN_WARMUP_STEPS` | Before this step, degenerate samples still run OPSD (default 200) |
593
+ | `DYME_SFT_WARMUP_SLOTS` | During warmup, inject GT into first N gens per all-wrong group (default 2) |
594
+
595
+ **OPD config trap:** `config/config_opd_7b_chartqa.py` must inherit `CONFIG["training"]["dyme_args"]` from `config_rlsd_chartqa` (not stale `TRAINING_CONFIG["dyme_args"]` from antidegen). If logs show `max_new_tokens=150, temperature=0.7`, you are on the wrong decode path — stop and `git pull`.
596
+
597
+ **Stop-training heuristics:** If after ~200 steps you see `degenerate_rate≈1`, `opsd_mask_true=0`, `grad_norm=0`, and `format_mean≈1` with `accuracy=0`, the run is collapsed — restart from base 0.5B or an early checkpoint.
598
+
599
+ ```bash
600
+ bash scripts/train_rlsd_chartqa.sh
601
+ # or: --config config/config_rlsd_chartqa.py --opsd_mode rlsd --opsd_providers format_only
602
+ ```
603
+
604
+ **Two-stage cold start (optional offline SFT → RLSD/OPD):**
605
+
606
+ ```bash
607
+ bash scripts/train_chartqa_sft.sh
608
+ export DYME_PRETRAINED_MODEL=./outputs/chartqa-sft/final_checkpoint
609
+ bash scripts/train_opd_7b_chartqa_deepspeed.sh
610
+ ```
611
+
612
+ **200-step smoke (OPD fixes):**
613
+
614
+ ```bash
615
+ bash scripts/train_opd_7b_smoke.sh
616
+ # Success: degenerate_rate<0.5, opsd_mask>8%, advantage_abs_mean>0, grad_norm>0
617
+ ```
618
+
619
+ **Cross-model OPD (7B frozen teacher + 0.5B student):**
620
+
621
+ ```bash
622
+ # Default: teacher on each rank's GPU (cuda:LOCAL_RANK). 2-GPU: student+teacher share the same card per rank.
623
+ # Optional dedicated teacher GPU: export DYME_TEACHER_DEVICE_MAP=cuda:1
624
+ # Vocab alignment debug at startup: DYME_VOCAB_ALIGN_FULL=1 (exhaustive) or DYME_VOCAB_ALIGN_STRIDE=500
625
+ bash scripts/train_opd_7b_chartqa.sh
626
+ ```
627
+
628
+ Note: `main.py --mode rl --config config/config.py` uses **`dyme_args`** (not the unused `grpo_args` block in the same file). Pure GRPO baselines use `main_rebuttal.py`.
629
+
630
+ **Helper scripts** (under `scripts/`):
631
+
632
+ ```bash
633
+ # TriMode on ChartQA (legacy; leakage risk on ChartQA)
634
+ bash scripts/train_trimode.sh
635
+
636
+ # Anti-leakage RLSD (recommended)
637
+ bash scripts/train_rlsd_chartqa.sh
638
+
639
+ # Ablation matrix: MODE=dyme|trimode|replace_sft|opsd_only|...
640
+ MODE=trimode DYME_OPSD_PROVIDERS=text,visual_facts bash scripts/train_baselines.sh
641
+
642
+ # Post-training eval (set CHECKPOINT_DIR)
643
+ CHECKPOINT_DIR=./outputs/trimode-chartqa/final_checkpoint bash scripts/run_eval_ablation.sh
644
+ ```
645
+
646
+ ### 3. Reproducing Baselines
647
+
648
+ To reproduce baseline settings such as standard SFT or RL training, use **`main_sft.py`** (offline ChartQA SFT) or `main.py` with `--opsd_enabled` off for pure DyME.
649
+
650
+ #### Supervised Fine-Tuning (SFT) — offline two-stage
651
+
652
+ ```bash
653
+ bash scripts/train_chartqa_sft.sh
654
+ # or: accelerate launch main_sft.py --config config/config_rlsd_chartqa.py
655
+ ```
656
+
657
+ Then point RLSD/OPD at the SFT checkpoint via `DYME_PRETRAINED_MODEL` or `MODEL_CONFIG.pretrained_model_path`.
658
+
659
+ #### Reinforcement Learning (GRPO / RL)
660
+
661
+ ```bash
662
+ accelerate launch main.py --config config/config.py --mode rl
663
+ ```
664
+
665
+ (`main_rebuttal.py` is referenced in the original DyME paper repo but is not shipped here; use `main_sft.py` + `main.py` instead.)
666
+
667
+ ### 4. Additional Experimental Variants
668
+
669
+ For specific experimental settings such as different model scales or architecture-specific ablations, please use the corresponding scripts:
670
+
671
+ * `main_7B.py`: experiments at the 7B scale
672
+ * `main_llm.py`: LLM-specific variants
673
+ * `main_change.py`: additional ablation settings
674
+
675
+
676
+ ## Evaluation
677
+
678
+ We support multi-process evaluation through `accelerate`. Evaluation scripts are located in the `eval/` directory and can be launched as Python modules.
679
+
680
+ ### General Usage
681
+
682
+ ```bash
683
+ accelerate launch -m eval.<eval_script_name>
684
+ ```
685
+
686
+ ### Example: ChartQA Evaluation
687
+
688
+ ```bash
689
+ accelerate launch -m eval.eval_chartqa
690
+ ```
691
+
692
+ ### Evaluation Setup
693
+
694
+ Before running evaluation, please open the corresponding evaluation script (for example, `eval_chartqa.py`) and modify the following fields as needed:
695
+
696
+ * `model_id`: the path or identifier of the checkpoint to be evaluated
697
+ * prompt templates: these should match the formatting used during training
698
+
699
+ Ensuring consistency between training and evaluation prompts is important for obtaining reliable results.
700
+
701
+ ## Citation
702
+
703
+ If you find this repository useful in your research, please consider citing our paper:
704
+
705
+ ```bibtex
706
+ @inproceedings{dyme2026,
707
+ title={Empowering Small VLMs to Think with Dynamic Memorization and Exploration},
708
+ author={Jiazhen Liu, Yuchuan Deng, Long Chen},
709
+ booktitle={ICLR},
710
+ year={2026},
711
+ }
712
+ ```
client_utils/openai_api.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import httpx
4
+ from openai import OpenAI
5
+ from typing import Optional
6
+
7
+
8
+ # It's good practice to define a simple configuration object or use a dictionary
9
+ # for passing credentials, rather than a generic object.
10
+ # For this example, we'll assume a config object like this.
11
+ class ClientConfig:
12
+ def __init__(self, api_key: str, base_url: str, model_id: str):
13
+ self.api_key = api_key
14
+ self.base_url = base_url
15
+ self.model_id = model_id
16
+
17
+
18
+ class OpenAIClient:
19
+ """
20
+ A client wrapper for interacting with the OpenAI ChatCompletion API.
21
+ It handles client initialization and API calls with retry logic.
22
+ """
23
+
24
+ def __init__(self, config: ClientConfig, max_retries: int = 3):
25
+ # The OpenAI client is initialized directly within the class constructor.
26
+ # This improves encapsulation by making the class self-contained.
27
+ # It takes the configuration object as a direct argument.
28
+ custom_http_client = httpx.Client(trust_env=False)
29
+ self.client = OpenAI(
30
+ api_key=config['api_key'], # Required: your API key
31
+ base_url=config['api_base'], # Optional: only needed for third-party services
32
+ http_client=custom_http_client,
33
+ )
34
+ self.model_id = config['model_id']
35
+ self.max_retries = max_retries
36
+
37
+ def get_completion(
38
+ self,
39
+ user_prompt: str,
40
+ system_prompt: Optional[str] = None,
41
+ max_tokens: int = 1024
42
+ ) -> Optional[str]:
43
+ """
44
+ Calls the OpenAI ChatCompletion API and returns the result.
45
+ Includes retry logic for handling transient errors.
46
+
47
+ Args:
48
+ user_prompt (str): The main input/prompt from the user.
49
+ system_prompt (Optional[str]): The system-level instruction for the model. Defaults to None.
50
+ max_tokens (int): The maximum number of tokens to generate. Defaults to 1024.
51
+
52
+ Returns:
53
+ Optional[str]: The content of the model's response, or None if the API call fails after all retries.
54
+ """
55
+ # Build the message list based on provided prompts.
56
+ messages = []
57
+ if system_prompt:
58
+ messages.append({'role': 'system', 'content': system_prompt})
59
+ messages.append({'role': 'user', 'content': user_prompt})
60
+
61
+ # Implement a clear retry loop instead of 'while True'.
62
+ for attempt in range(self.max_retries):
63
+ try:
64
+ # Make the API call to the chat completions endpoint.
65
+ response = self.client.chat.completions.create(
66
+ model=self.model_id,
67
+ messages=messages,
68
+ max_tokens=max_tokens,
69
+ # temperature=0.2 # You can add other parameters here as needed.
70
+ )
71
+ # If the call is successful, return the message content and exit the loop.
72
+ return response.choices[0].message.content
73
+
74
+ except Exception as e:
75
+ # If an error occurs, print a helpful message.
76
+ print(f"API call failed on attempt {attempt + 1}/{self.max_retries}. Error: {e}")
77
+
78
+ # If this was the last attempt, break the loop to return None.
79
+ if attempt + 1 == self.max_retries:
80
+ print("All retry attempts failed.")
81
+ break
82
+
83
+ # Wait for a short period before trying again.
84
+ print("Retrying in 2 seconds...")
85
+ time.sleep(2)
86
+
87
+ # Return None if all retries fail.
88
+ return None
89
+
90
+
91
+ # --- How to Use the Refactored Class ---
92
+ if __name__ == '__main__':
93
+ # 1. Define your configuration
94
+ # Replace with your actual credentials and mode
95
+
96
+ CLIENT_CONFIG = {
97
+ "client_type": "openai",
98
+ "api_key": "none",
99
+ "api_base": "http://127.0.0.1:23333/v1",
100
+ "timeout": 60,
101
+ "model_id": "Qwen/Qwen2.5-14B-Instruct-AWQ",
102
+ "init_port": 23333,
103
+ "num_server": 8
104
+ }
105
+ # 2. Instantiate the client
106
+ my_client = OpenAIClient(config=CLIENT_CONFIG)
107
+
108
+ # 3. Define your prompts
109
+ user_message = "What is the capital of France?"
110
+ system_message = "You are a helpful assistant that provides concise answers."
111
+
112
+ # 4. Get the model's response
113
+ response_content = my_client.get_completion(
114
+ user_prompt=user_message,
115
+ system_prompt=system_message
116
+ )
117
+
118
+ # 5. Print the result
119
+ if response_content:
120
+ print("\nModel Response:")
121
+ print(response_content)
122
+ else:
123
+ print("\nFailed to get a response from the model.")
config/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from config.config import (
2
+ CLIENT_CONFIG,
3
+ CONFIG,
4
+ DATASET_CONFIG,
5
+ DYME_OPSD_CONFIG,
6
+ MODEL_CONFIG,
7
+ RL_CONFIG,
8
+ TRAINING_CONFIG,
9
+ save_config,
10
+ )
11
+
12
+ __all__ = [
13
+ "CLIENT_CONFIG",
14
+ "CONFIG",
15
+ "DATASET_CONFIG",
16
+ "DYME_OPSD_CONFIG",
17
+ "MODEL_CONFIG",
18
+ "RL_CONFIG",
19
+ "TRAINING_CONFIG",
20
+ "save_config",
21
+ ]
config/__pycache__/config_rlsd_chartqa.cpython-310.pyc ADDED
Binary file (4 kB). View file
 
config/__pycache__/config_trimode.cpython-310.pyc ADDED
Binary file (2.62 kB). View file
 
config/__pycache__/config_trimode_antidegen.cpython-310.pyc ADDED
Binary file (1.92 kB). View file
 
config/__pycache__/loader.cpython-312.pyc ADDED
Binary file (3.36 kB). View file
 
config/config_7B.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+ # ====== Model Configuration ======
5
+ MODEL_CONFIG = {
6
+ "pretrained_model_path": "Qwen/Qwen2.5-VL-7B-Instruct",
7
+ "use_flash_attention_2": False,
8
+ "torch_dtype": "bfloat16",
9
+ }
10
+
11
+ # ====== Training Configuration ======
12
+ TRAINING_CONFIG = {
13
+ "task": 'chart',
14
+ "num_gpus": 8,
15
+ "num_client": 8,
16
+
17
+ "dyme_args": {
18
+ "output_dir": '/path/to/dyme-qwen25_7B-chart-llava_cot',
19
+ "logging_steps": 1,
20
+ "num_generations": 8,
21
+ "max_completion_length": 300,
22
+ "per_device_train_batch_size": 1,
23
+ "gradient_accumulation_steps": 16,
24
+ "num_train_epochs": 10,
25
+ "learning_rate": 1e-5,
26
+ "bf16": True,
27
+ "gradient_checkpointing": False,
28
+ "ddp_find_unused_parameters": False,
29
+ "max_grad_norm": 1.0,
30
+ "save_steps": 100,
31
+ "weight_decay": 0.01,
32
+ "warmup_steps": 0,
33
+ "eval_strategy": "steps",
34
+ "eval_steps": 10000,
35
+ "beta": 0.0, # GRPO specific
36
+ "loss_type": 'grpo', # GRPO specific
37
+ "seed": 42,
38
+ },
39
+ }
40
+
41
+ RL_CONFIG = {
42
+ "answer_flag": "Answer:",
43
+ "end_flag": "<|im_end|>"
44
+ }
45
+
46
+ # ====== Client Configuration for Reward Calculation ======
47
+ CLIENT_CONFIG = {
48
+ "client_type": "openai",
49
+ "api_key": "none",
50
+ "api_base": "http://127.0.0.1:%s/v1",
51
+ "timeout": 60,
52
+ "model_id": "Qwen/Qwen2.5-14B-Instruct-AWQ",
53
+ "init_port": 23333,
54
+ "num_server": 8
55
+ }
56
+
57
+ # ====== Dataset Configuration ======
58
+ DATASET_CONFIG = {
59
+
60
+ "train_dataset": "/path/to/data/chartqa_output/llavacot/json/chartqa_train_processed.json",
61
+ "eval_dataset": "HuggingFaceM4/ChartQA",
62
+ }
63
+
64
+
65
+ CONFIG = {
66
+ "model": MODEL_CONFIG,
67
+ "training": TRAINING_CONFIG,
68
+ "rl": RL_CONFIG,
69
+ "client": CLIENT_CONFIG,
70
+ "dataset": DATASET_CONFIG,
71
+ }
72
+
73
+
74
+ def save_config(config, config_path="./config.json"):
75
+ import json
76
+ with open(config_path, "w") as f:
77
+ json.dump(config, f, indent=4)
78
+
79
+
80
+ if __name__ == "__main__":
81
+ save_config(CONFIG)
82
+
config/config_aok.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+ from data_utils.paths import project_path
5
+
6
+
7
+ MODEL_CONFIG = {
8
+ "pretrained_model_path": "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
9
+ "use_flash_attention_2": True,
10
+ "torch_dtype": "bfloat16",
11
+ }
12
+
13
+
14
+ TRAINING_CONFIG = {
15
+ "task": 'world',
16
+ "num_gpus": 8,
17
+ "num_client": 8,
18
+ "dyme_args": {
19
+ "output_dir": '/path/to/dyme-aok-online',
20
+ "logging_steps": 1,
21
+ "num_generations": 4,
22
+ "max_completion_length": 300,
23
+ "per_device_train_batch_size": 1,
24
+ "gradient_accumulation_steps": 16,
25
+ "num_train_epochs": 10,
26
+ "learning_rate": 1e-5,
27
+ "bf16": True,
28
+ "gradient_checkpointing": False,
29
+ "ddp_find_unused_parameters": False,
30
+ "max_grad_norm": 1.0,
31
+ "save_strategy": "epoch",
32
+ "weight_decay": 0.01,
33
+ "warmup_steps": 0,
34
+ "beta": 0.0,
35
+ "loss_type": 'grpo',
36
+ "seed": 42,
37
+ },
38
+ "sft_args": {
39
+ "output_dir": '/path/to/sft-aok',
40
+ "logging_steps": 1,
41
+ "per_device_train_batch_size": 2,
42
+ "gradient_accumulation_steps": 4,
43
+ "num_train_epochs": 10,
44
+ "learning_rate": 1e-5,
45
+ "bf16": True,
46
+ "gradient_checkpointing": False,
47
+ "ddp_find_unused_parameters": False,
48
+ "max_grad_norm": 1.0,
49
+ "save_strategy": "epoch",
50
+ "weight_decay": 0.01,
51
+ "warmup_steps": 0,
52
+ "seed": 42,
53
+ "remove_unused_columns": False
54
+ },
55
+ "grpo_args":{
56
+ "output_dir": '/path/to/grpo-aok',
57
+ "logging_steps": 1,
58
+ "num_generations": 4,
59
+ "max_completion_length": 576,
60
+ "max_prompt_length": None,
61
+ "per_device_train_batch_size": 2,
62
+ "gradient_accumulation_steps": 4,
63
+ "num_train_epochs": 10,
64
+ "learning_rate": 1e-5,
65
+ "bf16": True,
66
+ "gradient_checkpointing": False,
67
+ "ddp_find_unused_parameters": False,
68
+ "max_grad_norm": 1.0,
69
+ "save_strategy": "epoch",
70
+ "weight_decay": 0.01,
71
+ "warmup_steps": 0,
72
+ "beta": 0.04,
73
+ "loss_type": 'grpo',
74
+ "seed": 42,
75
+ }
76
+
77
+ }
78
+
79
+ RL_CONFIG = {
80
+ "answer_flag": "Answer:",
81
+ "end_flag": "<|im_end|>"
82
+ }
83
+
84
+ # ====== Client Configuration for Reward Calculation ======
85
+ CLIENT_CONFIG = {
86
+ "client_type": "openai",
87
+ "api_key": "none",
88
+ "api_base": "http://127.0.0.1:%s/v1",
89
+ "timeout": 60,
90
+ "model_id": "Qwen/Qwen2.5-14B-Instruct-AWQ",
91
+ "init_port": 23333,
92
+ "num_server": 8
93
+ }
94
+
95
+ # ====== Dataset Configuration ======
96
+ DATASET_CONFIG = {
97
+ "train_dataset": project_path("data/aokvqa/train.json"),
98
+ "eval_dataset": "HuggingFaceM4/A-OKVQA",
99
+ }
100
+
101
+ # ====== Full Configuration ======
102
+ CONFIG = {
103
+ "model": MODEL_CONFIG,
104
+ "training": TRAINING_CONFIG,
105
+ "rl": RL_CONFIG,
106
+ "client": CLIENT_CONFIG,
107
+ "dataset": DATASET_CONFIG,
108
+ }
109
+
110
+ # Save configuration to a file for reference
111
+ def save_config(config, config_path="./config.json"):
112
+ import json
113
+ with open(config_path, "w") as f:
114
+ json.dump(config, f, indent=4)
115
+
116
+ # Example usage to save config
117
+ if __name__ == "__main__":
118
+ save_config(CONFIG)
119
+
config/config_llavacot.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+ # ====== Model Configuration ======
5
+ MODEL_CONFIG = {
6
+ "pretrained_model_path": '/path/to/sft-llavaov-chart-llava_cot/checkpoint-802', # two-stage grpo
7
+ "use_flash_attention_2": True,
8
+ "torch_dtype": "bfloat16",
9
+ }
10
+
11
+ # ====== Training Configuration ======
12
+ TRAINING_CONFIG = {
13
+ "task": 'chart',
14
+ "num_gpus": 8,
15
+ "num_client": 8,
16
+ "dyme_args": {
17
+ "output_dir": '/path/to/dyme-llavaov-chart-llava_cot',
18
+ "logging_steps": 1,
19
+ "num_generations": 4,
20
+ "max_completion_length": 300,
21
+ "per_device_train_batch_size": 1,
22
+ "gradient_accumulation_steps": 16,
23
+ "num_train_epochs": 10,
24
+ "learning_rate": 8e-5,
25
+ "bf16": True,
26
+ "gradient_checkpointing": False,
27
+ "ddp_find_unused_parameters": False,
28
+ "max_grad_norm": 1.0,
29
+ "save_strategy": "epoch",
30
+ "weight_decay": 0.01,
31
+ "warmup_steps": 0,
32
+ "beta": 0.0, # GRPO specific
33
+ "loss_type": 'grpo', # GRPO specific
34
+ "seed": 42,
35
+ },
36
+ "sft_args": {
37
+ "output_dir": '/path/to/sft-llavaov-chart-llava_cot',
38
+ "logging_steps": 1,
39
+ "per_device_train_batch_size": 2,
40
+ "gradient_accumulation_steps": 4,
41
+ "num_train_epochs": 10,
42
+ "learning_rate": 1e-5,
43
+ "bf16": True,
44
+ "gradient_checkpointing": False,
45
+ "ddp_find_unused_parameters": False,
46
+ "max_grad_norm": 1.0,
47
+ "max_length": 4096,
48
+ "save_strategy": "epoch",
49
+ "weight_decay": 0.01,
50
+ "warmup_steps": 0,
51
+ "seed": 42,
52
+ "remove_unused_columns": False
53
+ },
54
+ "grpo_args":{
55
+ "output_dir": '/path/to/grpo-llavaov-chart-beta',
56
+ "logging_steps": 1,
57
+ "num_generations": 4,
58
+ "max_completion_length": 576,
59
+ "max_prompt_length": None,
60
+ "per_device_train_batch_size": 4,
61
+ "gradient_accumulation_steps": 4,
62
+ "num_train_epochs": 10,
63
+ "learning_rate": 1e-5,
64
+ "bf16": True,
65
+ "gradient_checkpointing": False,
66
+ "ddp_find_unused_parameters": False,
67
+ "max_grad_norm": 1.0,
68
+ "save_strategy": "epoch",
69
+ "weight_decay": 0.01,
70
+ "warmup_steps": 0,
71
+ "beta": 0.04, # GRPO specific
72
+ "loss_type": 'grpo', # GRPO specific
73
+ "seed": 42,
74
+ }
75
+
76
+ }
77
+
78
+ RL_CONFIG = {
79
+ "answer_flag": "Answer:",
80
+ "end_flag": "<|im_end|>"
81
+ }
82
+
83
+ # ====== Client Configuration for Reward Calculation ======
84
+ CLIENT_CONFIG = {
85
+ "client_type": "openai",
86
+ "api_key": "none",
87
+ "api_base": "http://127.0.0.1:%s/v1",
88
+ "timeout": 60,
89
+ "model_id": "Qwen/Qwen2.5-14B-Instruct-AWQ",
90
+ "init_port": 23333,
91
+ "num_server": 8
92
+ }
93
+
94
+ # ====== Dataset Configuration ======
95
+ DATASET_CONFIG = {
96
+ "train_dataset": "/path/to/data/chartqa_output/llavacot/json/chartqa_train_processed.json",
97
+ "eval_dataset": "HuggingFaceM4/ChartQA",
98
+ }
99
+
100
+ # ====== Full Configuration ======
101
+ CONFIG = {
102
+ "model": MODEL_CONFIG,
103
+ "training": TRAINING_CONFIG,
104
+ "rl": RL_CONFIG,
105
+ "client": CLIENT_CONFIG,
106
+ "dataset": DATASET_CONFIG,
107
+ }
108
+
109
+ # Save configuration to a file for reference
110
+ def save_config(config, config_path="./config.json"):
111
+ import json
112
+ with open(config_path, "w") as f:
113
+ json.dump(config, f, indent=4)
114
+
115
+ # Example usage to save config
116
+ if __name__ == "__main__":
117
+ save_config(CONFIG)
118
+
config/config_low.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+ # ====== Model Configuration ======
5
+ MODEL_CONFIG = {
6
+ "pretrained_model_path": "/path/to/sft-llavaov-chart-low/checkpoint-296",
7
+ "use_flash_attention_2": True,
8
+ "torch_dtype": "bfloat16",
9
+ }
10
+
11
+ # ====== Training Configuration ======
12
+ TRAINING_CONFIG = {
13
+ "task": 'chart',
14
+ "num_gpus": 8,
15
+ "num_client": 8,
16
+
17
+ "dyme_args": {
18
+ "output_dir": '/path/to/dyme-llavaov-chart-low',
19
+ "logging_steps": 1,
20
+ "num_generations": 4,
21
+ "max_completion_length": 300,
22
+ "per_device_train_batch_size": 1,
23
+ "gradient_accumulation_steps": 16,
24
+ "num_train_epochs": 10,
25
+ "learning_rate": 1e-5,
26
+ "bf16": True,
27
+ "gradient_checkpointing": False,
28
+ "ddp_find_unused_parameters": False,
29
+ "max_grad_norm": 1.0,
30
+ "save_strategy": "epoch",
31
+ "weight_decay": 0.01,
32
+ "warmup_steps": 0,
33
+ "beta": 0.0, # GRPO specific
34
+ "loss_type": 'grpo', # GRPO specific
35
+ "seed": 42,
36
+ },
37
+ "sft_args": {
38
+ "output_dir": '/path/to/sft-llavaov-chart-low',
39
+ "logging_steps": 1,
40
+ "per_device_train_batch_size": 2,
41
+ "gradient_accumulation_steps": 4,
42
+ "num_train_epochs": 10,
43
+ "learning_rate": 1e-5,
44
+ "bf16": True,
45
+ "gradient_checkpointing": False,
46
+ "ddp_find_unused_parameters": False,
47
+ "max_grad_norm": 1.0,
48
+ "max_length": 4096,
49
+ # "save_steps": 100,
50
+ "save_strategy": "epoch",
51
+ "weight_decay": 0.01,
52
+ "warmup_steps": 0,
53
+ "seed": 42,
54
+ "remove_unused_columns": False
55
+ },
56
+ "grpo_args":{
57
+ "output_dir": '/path/to/grpo-llavaov-chart-low',
58
+ "logging_steps": 1,
59
+ "num_generations": 4,
60
+ "max_completion_length": 576,
61
+ "max_prompt_length": None,
62
+ "per_device_train_batch_size": 2,
63
+ "gradient_accumulation_steps": 4,
64
+ "num_train_epochs": 10,
65
+ "learning_rate": 1e-5,
66
+ "bf16": True,
67
+ "gradient_checkpointing": False,
68
+ "ddp_find_unused_parameters": False,
69
+ "max_grad_norm": 1.0,
70
+ "save_strategy": "epoch",
71
+ "weight_decay": 0.01,
72
+ "warmup_steps": 0,
73
+ "beta": 0.0, # GRPO specific
74
+ "loss_type": 'grpo', # GRPO specific
75
+ "seed": 42,
76
+ }
77
+
78
+ }
79
+
80
+ RL_CONFIG = {
81
+ "answer_flag": "Answer:",
82
+ "end_flag": "<|im_end|>"
83
+ }
84
+
85
+ # ====== Client Configuration for Reward Calculation ======
86
+ CLIENT_CONFIG = {
87
+ "client_type": "openai",
88
+ "api_key": "none",
89
+ "api_base": "http://127.0.0.1:%s/v1",
90
+ "timeout": 60,
91
+ "model_id": "Qwen/Qwen2.5-14B-Instruct-AWQ",
92
+ "init_port": 23333,
93
+ "num_server": 8
94
+ }
95
+
96
+ # ====== Dataset Configuration ======
97
+ DATASET_CONFIG = {
98
+ "train_dataset": "/path/to/data/chartqa_output/json/train_low.json",
99
+ "eval_dataset": "HuggingFaceM4/ChartQA",
100
+ }
101
+
102
+ # ====== Full Configuration ======
103
+ CONFIG = {
104
+ "model": MODEL_CONFIG,
105
+ "training": TRAINING_CONFIG,
106
+ "rl": RL_CONFIG,
107
+ "client": CLIENT_CONFIG,
108
+ "dataset": DATASET_CONFIG,
109
+ }
110
+
111
+ # Save configuration to a file for reference
112
+ def save_config(config, config_path="./config.json"):
113
+ import json
114
+ with open(config_path, "w") as f:
115
+ json.dump(config, f, indent=4)
116
+
117
+ # Example usage to save config
118
+ if __name__ == "__main__":
119
+ save_config(CONFIG)
120
+
config/config_opd_7b_chartqa.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ COPSD-style cross-model OPD on ChartQA (Method 2).
3
+
4
+ Frozen LLaVA-OneVision 7B teacher; student default 0.5B.
5
+ Inherits RLSD routing + embedded SFT cold-start gates from config_rlsd_chartqa.
6
+ """
7
+ import os
8
+
9
+ import config.config_rlsd_chartqa as rlsd
10
+ from data_utils.paths import OUTPUTS_DIR
11
+
12
+ MODEL_CONFIG = {
13
+ **rlsd.MODEL_CONFIG,
14
+ "teacher_model_path": os.environ.get(
15
+ "DYME_TEACHER_MODEL",
16
+ "llava-hf/llava-onevision-qwen2-7b-ov-hf",
17
+ ),
18
+ "teacher_dtype": os.environ.get("DYME_TEACHER_DTYPE", "bfloat16"),
19
+ "teacher_device_map": os.environ.get("DYME_TEACHER_DEVICE_MAP") or None,
20
+ }
21
+
22
+ DYME_OPSD_CONFIG = {
23
+ **rlsd.DYME_OPSD_CONFIG,
24
+ "mode": os.environ.get("DYME_OPSD_MODE", "rlsd"),
25
+ "privileged_providers": [],
26
+ "loss": {
27
+ **rlsd.DYME_OPSD_CONFIG.get("loss", {}),
28
+ "opsd_weight": float(os.environ.get("DYME_OPSD_WEIGHT", "1.5")),
29
+ },
30
+ }
31
+
32
+ CONFIG = {
33
+ "model": MODEL_CONFIG,
34
+ "training": {
35
+ **rlsd.CONFIG["training"],
36
+ "dyme_args": {
37
+ **rlsd.CONFIG["training"]["dyme_args"],
38
+ "output_dir": os.environ.get(
39
+ "DYME_OUTPUT_DIR",
40
+ os.path.join(OUTPUTS_DIR, "opd-7b-chartqa"),
41
+ ),
42
+ },
43
+ },
44
+ "rl": rlsd.CONFIG["rl"],
45
+ "opsd": DYME_OPSD_CONFIG,
46
+ "client": rlsd.CONFIG["client"],
47
+ "dataset": rlsd.CONFIG["dataset"],
48
+ }
config/config_rlsd_chartqa.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RLSD / anti-leakage ChartQA config (Method 1).
3
+
4
+ - mode=rlsd: correct → GRPO, wrong → same-prompt OPSD, all-wrong group → online SFT
5
+ - No gold answer / visual_facts in teacher privileged context
6
+ - Hyperparameters based on config_trimode_antidegen
7
+ """
8
+ import os
9
+
10
+ import config.config_trimode_antidegen as antidegen
11
+ from data_utils.paths import OUTPUTS_DIR
12
+
13
+ MODEL_CONFIG = dict(antidegen.MODEL_CONFIG)
14
+
15
+ TRAINING_CONFIG = dict(antidegen.TRAINING_CONFIG)
16
+
17
+ _reward_weights_raw = os.environ.get("DYME_REWARD_WEIGHTS", "0.5,1.5,1.0")
18
+ try:
19
+ _reward_weights = [float(x.strip()) for x in _reward_weights_raw.split(",") if x.strip()]
20
+ if len(_reward_weights) != 3:
21
+ raise ValueError("expected 3 weights")
22
+ except ValueError:
23
+ _reward_weights = [0.5, 1.5, 1.0]
24
+
25
+ _providers_raw = os.environ.get("DYME_OPSD_PROVIDERS", "format_only").strip()
26
+ _privileged_providers = [p.strip() for p in _providers_raw.split(",") if p.strip()] if _providers_raw else []
27
+
28
+ _skip_degen_env = os.environ.get("DYME_OPSD_SKIP_DEGENERATE", "").strip().lower()
29
+ if _skip_degen_env in ("0", "false", "no", "off"):
30
+ _skip_degenerate_for_opsd = False
31
+ elif _skip_degen_env in ("1", "true", "yes", "on"):
32
+ _skip_degenerate_for_opsd = True
33
+ else:
34
+ _skip_degenerate_for_opsd = True
35
+
36
+ # Embedded SFT cold-start + RLSD warmup gates (env overrides optional).
37
+ _RLSD_GATE_DEFAULTS = {
38
+ "skip_degenerate_for_opsd": _skip_degenerate_for_opsd,
39
+ "degen_skip_warmup_steps": 200,
40
+ "sft_warmup_steps": 500,
41
+ "sft_warmup_slots_per_group": 4,
42
+ # First N steps: skip generate, 100% GT injection, pure SFT NLL (no OPSD/GRPO).
43
+ "sft_cold_start_frac": 0.08,
44
+ }
45
+
46
+ DYME_OPSD_CONFIG = {
47
+ **antidegen.DYME_OPSD_CONFIG,
48
+ "mode": os.environ.get("DYME_OPSD_MODE", "rlsd"),
49
+ "text_include_gold": False,
50
+ "privileged_profile": os.environ.get("DYME_OPSD_PRIVILEGE_PROFILE", "text"),
51
+ "privileged_providers": _privileged_providers,
52
+ "gate": {
53
+ **antidegen.DYME_OPSD_CONFIG.get("gate", {}),
54
+ "per_completion_opsd": True,
55
+ "recoverable_without_privilege": True,
56
+ "require_format_for_opsd": os.environ.get("DYME_OPSD_REQUIRE_FORMAT", "0").strip().lower()
57
+ not in ("0", "false", "no", "off"),
58
+ "online_sft_on_all_wrong": True,
59
+ # ChartQA short numeric answers lack "Answer:" — do not block OPSD on format alone
60
+ "opsd_degenerate_require_answer_flag": False,
61
+ **_RLSD_GATE_DEFAULTS,
62
+ },
63
+ "loss": {
64
+ **antidegen.DYME_OPSD_CONFIG.get("loss", {}),
65
+ "acc_gate": True,
66
+ "opsd_weight": float(os.environ.get("DYME_OPSD_WEIGHT", "1.5")),
67
+ "grpo_weight": 1.0,
68
+ },
69
+ "reward_weights": _reward_weights,
70
+ }
71
+
72
+ _dyme_args = {
73
+ **TRAINING_CONFIG["dyme_args"],
74
+ "output_dir": os.environ.get(
75
+ "DYME_OUTPUT_DIR",
76
+ os.path.join(OUTPUTS_DIR, "rlsd-chartqa"),
77
+ ),
78
+ # Mitigate early RL collapse (newline + bare number + immediate EOS)
79
+ "max_completion_length": 96,
80
+ "temperature": 0.5,
81
+ "repetition_penalty": 1.5,
82
+ }
83
+ _max_steps_raw = os.environ.get("DYME_MAX_STEPS", "").strip()
84
+ if _max_steps_raw:
85
+ _dyme_args["max_steps"] = int(_max_steps_raw)
86
+
87
+ _temp_raw = os.environ.get("DYME_TEMPERATURE", "").strip()
88
+ if _temp_raw:
89
+ _dyme_args["temperature"] = float(_temp_raw)
90
+ _rep_raw = os.environ.get("DYME_REPETITION_PENALTY", "").strip()
91
+ if _rep_raw:
92
+ _dyme_args["repetition_penalty"] = float(_rep_raw)
93
+ _max_len_raw = os.environ.get("DYME_MAX_COMPLETION_LENGTH", "").strip()
94
+ if _max_len_raw:
95
+ _dyme_args["max_completion_length"] = int(_max_len_raw)
96
+
97
+ # Keep module-level TRAINING_CONFIG in sync so imports of TRAINING_CONFIG["dyme_args"] match CONFIG.
98
+ TRAINING_CONFIG = {**TRAINING_CONFIG, "dyme_args": _dyme_args}
99
+
100
+ # Optional env overrides for gate defaults (see _RLSD_GATE_DEFAULTS above).
101
+ _degen_warmup_raw = os.environ.get("DYME_OPSD_DEGEN_WARMUP_STEPS", "").strip()
102
+ if _degen_warmup_raw:
103
+ DYME_OPSD_CONFIG["gate"]["degen_skip_warmup_steps"] = int(_degen_warmup_raw)
104
+
105
+ _sft_warmup_raw = os.environ.get("DYME_SFT_WARMUP_STEPS", "").strip()
106
+ if _sft_warmup_raw:
107
+ DYME_OPSD_CONFIG["gate"]["sft_warmup_steps"] = int(_sft_warmup_raw)
108
+
109
+ _sft_slots_raw = os.environ.get("DYME_SFT_WARMUP_SLOTS", "").strip()
110
+ if _sft_slots_raw:
111
+ DYME_OPSD_CONFIG["gate"]["sft_warmup_slots_per_group"] = int(_sft_slots_raw)
112
+
113
+ _cold_start_steps_raw = os.environ.get("DYME_SFT_COLD_START_STEPS", "").strip()
114
+ if _cold_start_steps_raw:
115
+ DYME_OPSD_CONFIG["gate"]["sft_cold_start_steps"] = int(_cold_start_steps_raw)
116
+ DYME_OPSD_CONFIG["gate"].pop("sft_cold_start_frac", None)
117
+ else:
118
+ _cold_start_frac_raw = os.environ.get("DYME_SFT_COLD_START_FRAC", "").strip()
119
+ if _cold_start_frac_raw:
120
+ DYME_OPSD_CONFIG["gate"]["sft_cold_start_frac"] = float(_cold_start_frac_raw)
121
+
122
+ CONFIG = {
123
+ "model": MODEL_CONFIG,
124
+ "training": {
125
+ **TRAINING_CONFIG,
126
+ "dyme_args": _dyme_args,
127
+ "sft_args": {
128
+ "output_dir": os.environ.get(
129
+ "DYME_SFT_OUTPUT_DIR",
130
+ os.path.join(OUTPUTS_DIR, "chartqa-sft"),
131
+ ),
132
+ "logging_steps": 10,
133
+ "per_device_train_batch_size": 2,
134
+ "gradient_accumulation_steps": 4,
135
+ "num_train_epochs": int(os.environ.get("DYME_SFT_EPOCHS", "2")),
136
+ "learning_rate": 1e-5,
137
+ "bf16": True,
138
+ "gradient_checkpointing": True,
139
+ "ddp_find_unused_parameters": False,
140
+ "max_grad_norm": 1.0,
141
+ "save_strategy": "epoch",
142
+ "weight_decay": 0.01,
143
+ "warmup_steps": 0,
144
+ "seed": 42,
145
+ "remove_unused_columns": False,
146
+ },
147
+ },
148
+ "rl": antidegen.CONFIG["rl"],
149
+ "opsd": DYME_OPSD_CONFIG,
150
+ "client": antidegen.CONFIG["client"],
151
+ "dataset": antidegen.CONFIG["dataset"],
152
+ }
config/config_trimode.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from config import CLIENT_CONFIG, DATASET_CONFIG, DYME_OPSD_CONFIG, MODEL_CONFIG, RL_CONFIG, TRAINING_CONFIG
4
+ from data_utils.paths import OUTPUTS_DIR
5
+
6
+ MODEL_CONFIG = dict(MODEL_CONFIG)
7
+
8
+ TRAINING_CONFIG = {
9
+ **TRAINING_CONFIG,
10
+ "dyme_args": {
11
+ **TRAINING_CONFIG["dyme_args"],
12
+ "output_dir": os.environ.get("DYME_OUTPUT_DIR", os.path.join(OUTPUTS_DIR, "dyme-trimode")),
13
+ },
14
+ }
15
+
16
+ _detail_every_raw = os.environ.get("DYME_OPSD_DETAIL_EVERY", "10")
17
+ try:
18
+ _detail_every = max(0, int(_detail_every_raw))
19
+ except ValueError:
20
+ _detail_every = 10
21
+
22
+ _probe_raw = os.environ.get("DYME_OPSD_PROBE_ON_GENERATE", "1").strip().lower()
23
+ _probe_on_generate = _probe_raw not in ("0", "false", "no", "off")
24
+
25
+ _first_logits_raw = os.environ.get("DYME_OPSD_PROBE_FIRST_TOKEN_LOGITS", "1").strip().lower()
26
+ _probe_first_token_logits = _first_logits_raw not in ("0", "false", "no", "off")
27
+
28
+ _tail_raw = os.environ.get("DYME_OPSD_PROBE_PROMPT_TAIL_TOKENS", "16").strip()
29
+ try:
30
+ _probe_prompt_tail_tokens = max(1, int(_tail_raw))
31
+ except ValueError:
32
+ _probe_prompt_tail_tokens = 16
33
+
34
+ _model_ctx_raw = os.environ.get("DYME_OPSD_PROBE_LOG_MODEL_CONTEXT", "1").strip().lower()
35
+ _probe_log_model_context = _model_ctx_raw not in ("0", "false", "no", "off")
36
+
37
+ _health_raw = os.environ.get("DYME_OPSD_HEALTH_MONITOR", "1").strip().lower()
38
+ _health_monitor_enabled = _health_raw not in ("0", "false", "no", "off")
39
+
40
+ _require_format_raw = os.environ.get("DYME_OPSD_REQUIRE_FORMAT", "1").strip().lower()
41
+ _require_format_for_opsd = _require_format_raw not in ("0", "false", "no", "off")
42
+
43
+ DYME_OPSD_CONFIG = {
44
+ **DYME_OPSD_CONFIG,
45
+ "enabled": True,
46
+ "mode": os.environ.get("DYME_OPSD_MODE", "trimode"),
47
+ "privileged_profile": os.environ.get("DYME_OPSD_PRIVILEGE_PROFILE", "hybrid"),
48
+ "privileged_providers": os.environ.get("DYME_OPSD_PROVIDERS", "text,visual_facts").split(","),
49
+ "privileged_image": {
50
+ **DYME_OPSD_CONFIG.get("privileged_image", {}),
51
+ "mode": os.environ.get("DYME_OPSD_PRIVILEGE_IMAGE_MODE", "single"),
52
+ "crop_strategy": os.environ.get("DYME_OPSD_CROP_STRATEGY", "bbox_then_center"),
53
+ "bbox_coord": "normalized",
54
+ "margin_ratio": float(os.environ.get("DYME_OPSD_CROP_MARGIN", "0.25")),
55
+ },
56
+ "privileged_debug": {
57
+ **DYME_OPSD_CONFIG.get("privileged_debug", {}),
58
+ "save_images": os.environ.get("DYME_OPSD_SAVE_PRIVILEGED_IMAGES", "1").strip().lower()
59
+ not in ("0", "false", "no", "off"),
60
+ "image_subdir": os.environ.get("DYME_OPSD_PRIVILEGED_IMAGE_DIR", "logs/images"),
61
+ "max_samples_per_detail": int(os.environ.get("DYME_OPSD_PRIVILEGED_IMAGE_MAX", "2")),
62
+ },
63
+ "gate": {
64
+ **DYME_OPSD_CONFIG.get("gate", {}),
65
+ "require_format_for_opsd": _require_format_for_opsd,
66
+ },
67
+ "debug": {
68
+ **DYME_OPSD_CONFIG.get("debug", {}),
69
+ "detail_every": _detail_every,
70
+ "probe_on_generate": _probe_on_generate,
71
+ "probe_first_token_logits": _probe_first_token_logits,
72
+ "probe_prompt_tail_tokens": _probe_prompt_tail_tokens,
73
+ "probe_log_model_context": _probe_log_model_context,
74
+ "health_monitor": {
75
+ **DYME_OPSD_CONFIG.get("debug", {}).get("health_monitor", {}),
76
+ "enabled": _health_monitor_enabled,
77
+ },
78
+ },
79
+ }
80
+
81
+ CONFIG = {
82
+ "model": MODEL_CONFIG,
83
+ "training": TRAINING_CONFIG,
84
+ "rl": RL_CONFIG,
85
+ "opsd": DYME_OPSD_CONFIG,
86
+ "client": CLIENT_CONFIG,
87
+ "dataset": DATASET_CONFIG,
88
+ }
default_config_8gpu.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: MULTI_GPU
4
+ downcast_bf16: 'no'
5
+ enable_cpu_affinity: false
6
+ machine_rank: 0
7
+ main_training_function: main
8
+ mixed_precision: bf16
9
+ num_machines: 1
10
+ num_processes: 8
11
+ rdzv_backend: static
12
+ same_network: true
13
+ tpu_env: []
14
+ tpu_use_cluster: false
15
+ tpu_use_sudo: false
16
+ use_cpu: false
default_config_8gpu_deepspeed.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ gradient_accumulation_steps: 16
5
+ zero3_init_flag: false
6
+ zero_stage: 0
7
+ distributed_type: DEEPSPEED
8
+ downcast_bf16: 'no'
9
+ enable_cpu_affinity: false
10
+ machine_rank: 0
11
+ main_training_function: main
12
+ mixed_precision: bf16
13
+ num_machines: 1
14
+ # Optional ZeRO-0 for 8-GPU nodes. Requires: pip install deepspeed
15
+ num_processes: 8
16
+ rdzv_backend: static
17
+ same_network: true
18
+ tpu_env: []
19
+ tpu_use_cluster: false
20
+ tpu_use_sudo: false
21
+ use_cpu: false
default_config_zero2_8gpu.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ deepspeed_config_file: configs/deepspeed/zero2_bf16.json
5
+ zero3_init_flag: false
6
+ distributed_type: DEEPSPEED
7
+ downcast_bf16: 'no'
8
+ enable_cpu_affinity: false
9
+ machine_rank: 0
10
+ main_training_function: main
11
+ num_machines: 1
12
+ num_processes: 8
13
+ rdzv_backend: static
14
+ same_network: true
15
+ tpu_env: []
16
+ tpu_use_cluster: false
17
+ tpu_use_sudo: false
18
+ use_cpu: false
eval/eval_chartqa.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ from PIL import Image
6
+ from accelerate import Accelerator
7
+ # Ensure this path is correct and the utility is available.
8
+ from datasets import load_dataset
9
+ from torch.distributed import all_gather_object
10
+ from transformers import AutoProcessor, AutoConfig, AutoTokenizer, LlavaOnevisionForConditionalGeneration
11
+ from trl.models import unwrap_model_for_generation
12
+
13
+ from data_utils.chart.evaluator import eval_one_chart
14
+ from data_utils.rl_prompt import PROMPT_TEMPLATE
15
+ from reward_utils.compute_rewards import split_initial_context
16
+
17
+ accelerator = Accelerator()
18
+ from tqdm import tqdm
19
+ import numpy as np
20
+
21
+ DEVICE = accelerator.device
22
+
23
+ # Model and Processor Configuration
24
+ model_args = {} # Use {"torch_dtype": torch.bfloat16} if desired and supported
25
+
26
+ _eval_parser = argparse.ArgumentParser(add_help=False)
27
+ _eval_parser.add_argument("--model_path", default=None)
28
+ _eval_args, _ = _eval_parser.parse_known_args()
29
+ model_id = (
30
+ _eval_args.model_path
31
+ or os.environ.get("CHECKPOINT_DIR")
32
+ or "/path/to/dyme-k-8/final_checkpoint"
33
+ )
34
+
35
+ config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
36
+ tokenizer = AutoTokenizer.from_pretrained(model_id, config=config, trust_remote_code=True)
37
+ model = LlavaOnevisionForConditionalGeneration.from_pretrained(
38
+ model_id,
39
+ torch_dtype=torch.bfloat16,
40
+ low_cpu_mem_usage=True,
41
+ ).to(DEVICE)
42
+
43
+ model.eval()
44
+ # Make sure model and processor are loaded before being potentially used in generate_inner if it were called
45
+ # model = Idefics3ForConditionalGeneration.from_pretrained(model_id, **model_args).to(DEVICE)
46
+
47
+ processor = AutoProcessor.from_pretrained(model_id)
48
+
49
+ # Configure image processor size
50
+ # This can consume significant VRAM. Ensure it's intended.
51
+ if hasattr(processor.image_processor, 'size') and isinstance(processor.image_processor.size, dict):
52
+ # if 'longest_edge' in processor.image_processor.size:
53
+ # print('Setting image processor longest_edge to 2048')
54
+ # processor.image_processor.size['longest_edge'] = 512 * 4
55
+ processor.tokenizer.padding_side = 'left'
56
+ else:
57
+ print(
58
+ f"Warning: Could not directly set 'longest_edge' via dict. Current image processor size config: {processor.image_processor.size}"
59
+ )
60
+ # Attempt an alternative if applicable, e.g.
61
+ # processor.image_processor.size = {"longest_edge": 512 * 4} # if size itself can be replaced
62
+ # Or this might indicate that `size` is a single value or a different structure.
63
+
64
+ def run_kh_batch(batch_data_list): # Renamed from run_kh, takes a batch
65
+ batch_images = []
66
+ batch_formatted_prompts_for_chat_template = []
67
+
68
+ for item in batch_data_list:
69
+ image_path = item['image_path']
70
+ # 'item_model_input_text' already contains chart instructions + raw_question
71
+ item_model_input_text = item['model_input_text'].strip()
72
+
73
+ # question_with_tags = prompt + item_model_input_text
74
+ # question_with_tags = f"""{item_model_input_text} Think step by step and then answer the question."""
75
+ question_with_tags = PROMPT_TEMPLATE.format(question=item_model_input_text)
76
+ if isinstance(image_path, str):
77
+ image = Image.open(image_path).convert("RGB")
78
+ else:
79
+ image = image_path.convert("RGB") # Assuming image_path is already a PIL Image object
80
+ batch_images.append(image)
81
+
82
+ messages = [
83
+ {
84
+ "role": "user",
85
+ "content": [
86
+ {"type": "image"},
87
+ {"type": "text", "text": question_with_tags},
88
+ ]
89
+ },
90
+ ]
91
+ try:
92
+ templated_prompt_str = processor.apply_chat_template(messages, add_generation_prompt=True)
93
+ templated_prompt_str = templated_prompt_str.strip()
94
+ except Exception:
95
+ templated_prompt_str = f"USER: <image>\n{question_with_tags}\nASSISTANT:"
96
+ batch_formatted_prompts_for_chat_template.append(templated_prompt_str)
97
+
98
+ inputs = processor(
99
+ text=batch_formatted_prompts_for_chat_template,
100
+ images=batch_images,
101
+ return_tensors="pt",
102
+ padding=True,
103
+ truncation=True
104
+ )
105
+ # inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
106
+ inputs = {
107
+ k: v.to(DEVICE).to(torch.bfloat16) if v.is_floating_point() else v.to(DEVICE)
108
+ for k, v in inputs.items()
109
+ }
110
+
111
+ with unwrap_model_for_generation(model, accelerator) as unwrapped_model_instance:
112
+ generated_ids = unwrapped_model_instance.generate(
113
+ **inputs,
114
+ max_new_tokens=1024,
115
+ do_sample=False,
116
+ )
117
+
118
+ input_ids_length = inputs['input_ids'].shape[1]
119
+ newly_generated_ids = generated_ids[:, input_ids_length:]
120
+
121
+ generated_texts = processor.batch_decode(
122
+ newly_generated_ids,
123
+ skip_special_tokens=True, # Special tokens like <eos> are removed. <image> might be too.
124
+ )
125
+ return [text.strip('.').strip() for text in generated_texts]
126
+
127
+
128
+ # --- Main Evaluation Logic ---
129
+ task = 'chart'
130
+ # dt_record_local is initialized inside the if task == 'chart' block
131
+
132
+ if task == 'chart':
133
+ dt_record_local = {} # Store results for the current process
134
+ if accelerator.is_main_process:
135
+ print("Loading ChartQA dataset...")
136
+ try:
137
+ full_dataset = load_dataset("HuggingFaceM4/ChartQA", trust_remote_code=True)['test']
138
+ except Exception as e:
139
+ if accelerator.is_main_process:
140
+ print(f"Failed to load dataset directly. Error: {e}")
141
+ print("Attempting to load with specific revision if applicable, or check path/connection.")
142
+ # For example, you can try a specific revision (if known) or ensure path and network connection are correct
143
+ # full_dataset = load_dataset("HuggingFaceM4/ChartQA", revision="main", trust_remote_code=True)['test']
144
+ raise # Re-raise the exception since we cannot proceed without the dataset
145
+
146
+ # full_dataset = full_dataset.select(range(80)) # Uncomment for quick testing
147
+
148
+ eval_datasets_all_prepared = []
149
+ # chart_instructions_prefix = (
150
+ # "For the question below, follow the following instructions:\n"
151
+ # # ... (your detailed instructions) ...
152
+ # + "-Try to include the full label from the graph when asked about an entity.\n"
153
+ # + "Question: "
154
+ # )
155
+
156
+ for d_item in tqdm(full_dataset, desc="Preparing dataset", disable=not accelerator.is_main_process):
157
+ image_path = d_item['image']
158
+ raw_question = d_item['query']
159
+ answer_list = d_item.get('label') # Use .get() in case 'label' field does not exist
160
+ if not answer_list: # If 'label' is missing or an empty list
161
+ if accelerator.is_main_process:
162
+ tqdm.write(
163
+ f"Warning: Item missing 'label' or 'label' is empty. Query: {raw_question[:50]}..."
164
+ )
165
+ # Decide how to handle this: skip this sample or use a default answer
166
+ continue # Skip this sample
167
+ answer = answer_list[0]
168
+
169
+ model_input_text_for_template = raw_question
170
+ eval_datasets_all_prepared.append({
171
+ 'image_path': image_path,
172
+ 'model_input_text': model_input_text_for_template,
173
+ 'answer': answer,
174
+ 'original_question': raw_question
175
+ })
176
+
177
+ num_processes = accelerator.num_processes
178
+ process_index = accelerator.process_index
179
+ total_items = len(eval_datasets_all_prepared)
180
+
181
+ if total_items == 0:
182
+ if accelerator.is_main_process:
183
+ print("No data prepared for evaluation after filtering. Exiting chart evaluation.")
184
+ else:
185
+ items_per_proc = total_items // num_processes
186
+ extra_items = total_items % num_processes
187
+ local_start_index = process_index * items_per_proc + min(process_index, extra_items)
188
+ num_local_items = items_per_proc + (1 if process_index < extra_items else 0)
189
+ local_end_index = local_start_index + num_local_items
190
+ eval_datasets_local = eval_datasets_all_prepared[local_start_index:local_end_index]
191
+
192
+ BATCH_SIZE = 32 # Adjust according to your VRAM
193
+ REPORT_INTERVAL_BATCHES = 1 # Report once every N local batches (main process prints global stats)
194
+
195
+ # if accelerator.is_main_process:
196
+ # print(f"Total items for evaluation: {total_items}")
197
+ # print(f"Process {process_index} handling {len(eval_datasets_local)} items.")
198
+ # print(f"Batch size per process: {BATCH_SIZE}, Reporting interval: {REPORT_INTERVAL_BATCHES} local batches.")
199
+
200
+ pbar = None
201
+ if accelerator.is_main_process and len(eval_datasets_local) > 0: # Create pbar only if there is data
202
+ pbar = tqdm(total=len(eval_datasets_local), desc=f"Eval Proc {process_index}", dynamic_ncols=True)
203
+
204
+ dt_record_local['res'] = []
205
+ num_local_batches = (len(eval_datasets_local) + BATCH_SIZE - 1) // BATCH_SIZE
206
+
207
+ for batch_idx_local in range(num_local_batches):
208
+ start_idx = batch_idx_local * BATCH_SIZE
209
+ end_idx = min((batch_idx_local + 1) * BATCH_SIZE, len(eval_datasets_local))
210
+ current_batch_list = eval_datasets_local[start_idx:end_idx]
211
+
212
+ if not current_batch_list:
213
+ continue
214
+
215
+ batch_predictions_texts = run_kh_batch(current_batch_list)
216
+
217
+ for item_idx_in_batch, full_pred_text in enumerate(batch_predictions_texts):
218
+ original_item = current_batch_list[item_idx_in_batch]
219
+ ground_truth_answer = original_item['answer']
220
+
221
+ _, parsed_pred_answer = split_initial_context(full_pred_text)
222
+ if not parsed_pred_answer.strip():
223
+ parsed_pred_answer = full_pred_text # Fallback to full prediction if parsed answer is empty
224
+
225
+ score = eval_one_chart(parsed_pred_answer, ground_truth_answer) # nlp object is global
226
+ dt_record_local['res'].append(score)
227
+
228
+ # (Optional) Main process prints a few prediction details
229
+ if accelerator.is_main_process:
230
+ print(full_pred_text, "######", ground_truth_answer, "######", score)
231
+
232
+ if pbar:
233
+ pbar.update(len(current_batch_list))
234
+
235
+ # --- Intermediate reporting logic ---
236
+ is_last_local_batch = (batch_idx_local == num_local_batches - 1)
237
+ # Every REPORT_INTERVAL_BATCHES local batches, or on the last local batch of this process,
238
+ # perform synchronization and reporting
239
+ should_sync_and_report = ((batch_idx_local + 1) % REPORT_INTERVAL_BATCHES == 0) or is_last_local_batch
240
+
241
+ # Make sure that even if REPORT_INTERVAL_BATCHES is 1, we do not report when there is no data
242
+ # (e.g., len(eval_datasets_local) == 0)
243
+ if len(eval_datasets_local) == 0: # If the current process has no data, skip reporting logic
244
+ should_sync_and_report = False
245
+ # If num_local_batches > 0, this check ensures we report only when there is data
246
+
247
+ if num_local_batches == 0 and is_last_local_batch: # Special case: process has no data but must join final sync
248
+ should_sync_and_report = True
249
+
250
+ if should_sync_and_report:
251
+ accelerator.wait_for_everyone() # Wait for all processes to reach the sync point
252
+
253
+ gathered_all_processes_data = [None] * num_processes
254
+ # Each process sends its *current accumulated* dt_record_local
255
+ # If a process has no data, dt_record_local['res'] is an empty list, which is fine
256
+ all_gather_object(gathered_all_processes_data, dt_record_local)
257
+
258
+ if accelerator.is_main_process:
259
+ current_global_scores_list = []
260
+ for process_data_dict in gathered_all_processes_data:
261
+ if process_data_dict and 'res' in process_data_dict:
262
+ current_global_scores_list.extend(process_data_dict['res'])
263
+
264
+ total_samples_processed_globally = len(current_global_scores_list)
265
+
266
+ report_title = "--- Intermediate Report ---"
267
+ # Check whether this is the final reporting point where all processes have finished
268
+ # A simple heuristic: if this is the last local batch on the main process
269
+ # and the total collected samples equal the total number of items
270
+ if is_last_local_batch and total_samples_processed_globally == total_items:
271
+ report_title = "--- Final Report ---"
272
+ elif is_last_local_batch: # Last batch on main process but perhaps not all samples are done yet
273
+ report_title = (
274
+ f"--- Report (Main Proc Last Batch, {batch_idx_local + 1}/{num_local_batches}) ---"
275
+ )
276
+
277
+ tqdm.write(f"\n{report_title}") # Use tqdm.write to avoid clashing with the progress bar
278
+ if current_global_scores_list:
279
+ mean_acc_global = np.array(current_global_scores_list).mean()
280
+ if accelerator.is_main_process:
281
+ print(f"Global samples processed: {total_samples_processed_globally} / {total_items}")
282
+ print(f"Current Global Mean Accuracy: {mean_acc_global:.4f}")
283
+ if pbar:
284
+ pbar.set_description(
285
+ f"Global Acc: {mean_acc_global:.4f} ({total_samples_processed_globally}/{total_items})"
286
+ )
287
+ else:
288
+ if accelerator.is_main_process:
289
+ print(
290
+ f"No scores to report globally yet (Total processed: {total_samples_processed_globally})."
291
+ )
292
+
293
+ accelerator.wait_for_everyone() # Sync again after reporting in case some processes move ahead faster
294
+
295
+ if pbar:
296
+ pbar.close()
297
+
298
+ # Final metrics have already been printed in the last report
299
+ # (when is_last_local_batch is True)
300
+ if accelerator.is_main_process and len(eval_datasets_local) == 0 and total_items > 0:
301
+ print(
302
+ "Main process had no data, but other processes might have. "
303
+ "Final global metrics are printed by the last reporting sync."
304
+ )
305
+ elif accelerator.is_main_process and total_items == 0:
306
+ print("No data was prepared for evaluation. Nothing to report.")
307
+
308
+ else:
309
+ if accelerator.is_main_process:
310
+ print(f"Task '{task}' is not configured for batched evaluation in this script.")
figs/chartqa.png ADDED
kill_all.sh ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -e
3
+
4
+ readonly WORKER_HOSTS=(
5
+ "xx.xx.xx.xx"
6
+ )
7
+
8
+ readonly REMOTE_USER="root"
9
+
10
+ readonly TRAIN_SCRIPT="main"
11
+
12
+
13
+
14
+
15
+ echo "--- Killing local processes matching '${TRAIN_SCRIPT}' first ---"
16
+
17
+ pkill -9 -f "${TRAIN_SCRIPT}" || true
18
+ pkill -f python
19
+ echo "Local check complete."
20
+ echo
21
+
22
+
23
+ echo "🛑 Sending targeted kill signal to processes matching '${TRAIN_SCRIPT}' on all remote hosts in parallel..."
24
+
25
+ for HOST in "${WORKER_HOSTS[@]}"; do
26
+
27
+ (
28
+ echo "--- Processing host: ${HOST} ---"
29
+
30
+ ssh -n "${REMOTE_USER}@${HOST}" "
31
+ set -e # 远程脚本也应该在出错时停止
32
+ pkill -f python
33
+ # 精确查找由 python 启动的、且包含 TRAIN_SCRIPT 名称的进程
34
+ # 这是为了避免误杀其他同名进程(比如一个名为 'main_rebuttal' 的shell脚本)
35
+ PIDS=\$(pgrep -f \"python.*${TRAIN_SCRIPT}\")
36
+
37
+ if [ -z \"\$PIDS\" ]; then
38
+ echo '[INFO] ✅ No matching processes found on this host.'
39
+ else
40
+ echo '[WARN] 🔥 Found processes to kill:'
41
+ # 在杀死前显示详细信息,增加安全性
42
+ ps -fp \$PIDS
43
+ echo '[KILL] Killing PIDs: '\$PIDS'...'
44
+ kill -9 \$PIDS
45
+ echo '[OK] ✅ Processes killed successfully.'
46
+ fi
47
+ "
48
+ echo "--- Finished host: ${HOST} ---"
49
+ echo
50
+ ) &
51
+ done
52
+
53
+ wait
54
+
55
+ echo "🎉 All hosts have been processed."
main.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train_grpo.py
2
+ """
3
+ Main script for training a Llava-based model using the custom MyGRPOTrainer.
4
+
5
+ This script handles:
6
+ 1. Configuration loading.
7
+ 2. Initialization of Weights & Biases (wandb) and Hugging Face Accelerate.
8
+ 3. Loading the model and processor.
9
+ 4. Preparing the training and evaluation datasets.
10
+ 5. Setting up and running the GPRO trainer.
11
+ """
12
+ import argparse
13
+ import os
14
+ from functools import partial
15
+ from typing import Dict, Any
16
+
17
+ import torch
18
+ import wandb
19
+ from accelerate import Accelerator
20
+ from datasets import Dataset, load_dataset
21
+ from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
22
+ from trl import GRPOConfig
23
+
24
+ from config.loader import load_config
25
+ from data_utils.commom_util import collate_fn, define_task_data_func
26
+ from trainer.DyMETrainer import DyMETrainer
27
+ from reward_utils.checker import RewardCalculator, RewardCalculatorLocal
28
+ from reward_utils.refiner import ContextRefiner, ContextRefinerLocal
29
+ from opsd_utils import debug_log as opsd_debug
30
+ from opsd_utils.teacher_batching import (
31
+ log_teacher_placement,
32
+ resolve_teacher_device_map,
33
+ )
34
+ from opsd_utils.deepspeed_utils import (
35
+ deepspeed_zero_stage,
36
+ gradient_checkpointing_enable_kwargs,
37
+ is_deepspeed_accelerate_config,
38
+ should_disable_gradient_checkpointing,
39
+ uses_deepspeed_json_file,
40
+ )
41
+
42
+
43
+ def _run_cross_model_vocab_checks(model, processor, teacher_model, model_config: Dict[str, Any]) -> None:
44
+ """Startup checks for cross-model OPD vocab slice + tokenizer alignment."""
45
+ from transformers import AutoProcessor
46
+
47
+ from opsd_utils.vocab_align import print_vocab_align_report, verify_shared_tokenizer_alignment
48
+
49
+ student_vocab = getattr(model.config, "vocab_size", len(processor.tokenizer))
50
+ teacher_vocab = getattr(teacher_model.config, "vocab_size", student_vocab)
51
+ shared = min(student_vocab, teacher_vocab)
52
+ print(
53
+ f"[OPSD-VOCAB] lm_head widths: student={student_vocab} teacher={teacher_vocab} "
54
+ f"shared_slice={shared}",
55
+ flush=True,
56
+ )
57
+ if student_vocab == teacher_vocab:
58
+ print("[OPSD-VOCAB] vocab sizes match — no slice needed", flush=True)
59
+ return
60
+
61
+ teacher_path = model_config.get("teacher_model_path")
62
+ teacher_processor = AutoProcessor.from_pretrained(teacher_path)
63
+ full_scan = os.environ.get("DYME_VOCAB_ALIGN_FULL", "0").strip().lower() in ("1", "true", "yes")
64
+ stride = int(os.environ.get("DYME_VOCAB_ALIGN_STRIDE", "500"))
65
+ report = verify_shared_tokenizer_alignment(
66
+ processor.tokenizer,
67
+ teacher_processor.tokenizer,
68
+ shared_vocab=shared,
69
+ full_scan=full_scan,
70
+ sample_stride=stride,
71
+ )
72
+ print_vocab_align_report(report)
73
+
74
+
75
+ def _wandb_disabled_by_env() -> bool:
76
+ if os.environ.get("WANDB_DISABLED", "").lower() in ("true", "1", "yes", "on"):
77
+ return True
78
+ if os.environ.get("WANDB_MODE", "").lower() in ("disabled", "off"):
79
+ return True
80
+ return False
81
+
82
+
83
+ def _try_wandb_login() -> bool:
84
+ """Return True if wandb credentials are available (env, offline, or prior login)."""
85
+ if os.environ.get("WANDB_MODE", "").lower() == "offline":
86
+ return True
87
+ wandb_key = os.environ.get("WANDB_API_KEY")
88
+ if wandb_key:
89
+ wandb.login(key=wandb_key)
90
+ return True
91
+ try:
92
+ wandb.login(relogin=False)
93
+ key = wandb.api.api_key
94
+ return bool(key and len(key) >= 40)
95
+ except Exception:
96
+ return False
97
+
98
+
99
+ def setup_accelerator_and_wandb(bf16, want_wandb: bool) -> tuple[Accelerator, bool]:
100
+ """
101
+ Initialize Accelerator and optionally wandb.
102
+
103
+ Returns:
104
+ (accelerator, use_wandb)
105
+ """
106
+ use_wandb = want_wandb and not _wandb_disabled_by_env()
107
+ if use_wandb:
108
+ use_wandb = _try_wandb_login()
109
+
110
+ accel_kwargs: dict = {}
111
+ # bf16 for DDP/MULTI_GPU only; with deepspeed_config_file, precision lives in the JSON.
112
+ if bf16 and not uses_deepspeed_json_file():
113
+ accel_kwargs["mixed_precision"] = "bf16"
114
+ if use_wandb:
115
+ accel_kwargs["log_with"] = "wandb"
116
+ return Accelerator(**accel_kwargs), use_wandb
117
+
118
+
119
+ def load_model_and_processor(model_config: Dict[str, Any]):
120
+ """
121
+ Loads the pre-trained vision-language model and its associated processor.
122
+
123
+ Args:
124
+ model_config (Dict[str, Any]): Configuration dictionary for the model.
125
+
126
+ Returns:
127
+ Tuple[LlavaOnevisionForConditionalGeneration, PreTrainedProcessor]: The loaded model and processor.
128
+ """
129
+ model_id = model_config['pretrained_model_path']
130
+
131
+ model = LlavaOnevisionForConditionalGeneration.from_pretrained(
132
+ model_id,
133
+ torch_dtype=getattr(torch, model_config['torch_dtype']),
134
+ attn_implementation='flash_attention_2' if model_config['use_flash_attention_2'] else 'sdpa',
135
+ low_cpu_mem_usage=True,
136
+ )
137
+
138
+ # Freeze the vision tower to save memory and computation
139
+ model.base_model.vision_tower.requires_grad_(False)
140
+
141
+ processor = AutoProcessor.from_pretrained(model_id)
142
+ processor.tokenizer.padding_side = "left"
143
+
144
+ return model, processor
145
+
146
+
147
+ def load_teacher_model(model_config: Dict[str, Any], *, local_rank: int = 0, num_gpus: int = 1):
148
+ """Load optional frozen teacher for cross-model OPD (e.g. LLaVA-OneVision 7B)."""
149
+ teacher_path = model_config.get("teacher_model_path")
150
+ if not teacher_path:
151
+ return None
152
+
153
+ dtype_name = model_config.get("teacher_dtype", model_config.get("torch_dtype", "bfloat16"))
154
+ torch_dtype = getattr(torch, dtype_name)
155
+ requested_map = model_config.get("teacher_device_map")
156
+ if not requested_map:
157
+ env_map = os.environ.get("DYME_TEACHER_DEVICE_MAP", "").strip()
158
+ if env_map:
159
+ requested_map = env_map
160
+
161
+ device_map = resolve_teacher_device_map(
162
+ requested_map,
163
+ local_rank=local_rank,
164
+ num_gpus=max(1, num_gpus),
165
+ )
166
+ log_teacher_placement(
167
+ local_rank=local_rank,
168
+ num_gpus=max(1, num_gpus),
169
+ teacher_path=teacher_path,
170
+ resolved_device=device_map,
171
+ requested_map=requested_map,
172
+ )
173
+
174
+ load_kwargs: Dict[str, Any] = {
175
+ "torch_dtype": torch_dtype,
176
+ "low_cpu_mem_usage": True,
177
+ "device_map": device_map,
178
+ }
179
+
180
+ teacher = LlavaOnevisionForConditionalGeneration.from_pretrained(
181
+ teacher_path,
182
+ attn_implementation='flash_attention_2' if model_config.get('use_flash_attention_2') else 'sdpa',
183
+ **load_kwargs,
184
+ )
185
+ teacher.eval()
186
+ teacher.requires_grad_(False)
187
+ if hasattr(teacher, "base_model") and hasattr(teacher.base_model, "vision_tower"):
188
+ teacher.base_model.vision_tower.requires_grad_(False)
189
+
190
+ return teacher
191
+
192
+
193
+ def prepare_datasets(task: str, dataset_config: Dict[str, Any], mode='rl') -> (Dataset, Dataset):
194
+ """
195
+ Prepares the training and evaluation datasets based on the specified task.
196
+
197
+ Args:
198
+ task (str): The name of the task (e.g., 'chartqa').
199
+ dataset_config (Dict[str, Any]): Configuration for datasets.
200
+
201
+ Returns:
202
+ Tuple[Dataset, Dataset]: The training and evaluation datasets.
203
+ """
204
+ data_func = define_task_data_func(task, mode=mode)
205
+
206
+ # Create training dataset
207
+ train_data_list = data_func(json_path=dataset_config['train_dataset'])
208
+ train_dataset = Dataset.from_list(train_data_list)
209
+
210
+ # Create evaluation dataset
211
+ if 'chart' in task:
212
+ eval_dataset = load_dataset(dataset_config['eval_dataset'])['test']
213
+ # Note: You can uncomment the line below for quick testing/debugging.
214
+ # eval_dataset = eval_dataset.select(range(1000, 1100))
215
+
216
+ else:
217
+ # Extend this section for other tasks if needed in the future.
218
+ eval_dataset = None
219
+
220
+ return train_dataset, eval_dataset
221
+
222
+
223
+ def main():
224
+ """
225
+ Main function to orchestrate the model training pipeline.
226
+ """
227
+
228
+ parser = argparse.ArgumentParser(description="Train a Llava model using either SFT or GRPO.")
229
+
230
+ parser.add_argument(
231
+ '--config', type=str, default='config/config.py',
232
+ help="Python config path (e.g. config/config.py, config/config_trimode.py) "
233
+ "or shorthand alias: norm | trimode | llavacot | low | aok",
234
+ )
235
+ parser.add_argument(
236
+ '--mode', type=str, default='rl',
237
+ )
238
+ parser.add_argument(
239
+ '--opsd_mode', type=str, default=None,
240
+ help="OPSD routing mode: dyme | trimode | rlsd | copsd_opd | opsd_only | replace_sft | opsd_on_wrong | grpo_opsd_joint",
241
+ )
242
+ parser.add_argument(
243
+ '--opsd_providers', type=str, default=None,
244
+ help="Comma-separated privileged providers: text,visual_facts,crop,hybrid",
245
+ )
246
+ parser.add_argument(
247
+ '--opsd_privilege_profile', type=str, default=None,
248
+ help="Privileged profile preset: text | visual | hybrid (default hybrid in config_trimode)",
249
+ )
250
+ parser.add_argument(
251
+ '--reward_weights', type=str, default=None,
252
+ help="Comma-separated reward weights: format,context,acc (e.g. 0.5,1.5,1.0). "
253
+ "Overrides config; env DYME_REWARD_WEIGHTS also supported in antidegen config.",
254
+ )
255
+ parser.add_argument(
256
+ '--opsd_enabled', action='store_true',
257
+ help="Enable OPSD / TriMode training extensions",
258
+ )
259
+ parser.add_argument(
260
+ '--opsd_debug', action='store_true',
261
+ help="Enable verbose OPSD debug logs (or set env DYME_OPSD_DEBUG=1)",
262
+ )
263
+ parser.add_argument(
264
+ '--opsd_detail_every', type=int, default=None,
265
+ help="Emit full weak-signal diagnostic bundle every N global steps on rank 0 "
266
+ "(default 10; config opsd.debug.detail_every or env DYME_OPSD_DETAIL_EVERY)",
267
+ )
268
+ parser.add_argument(
269
+ '--opsd_probe_on_generate', dest='opsd_probe_on_generate', action='store_true',
270
+ help="Emit [OPSD-PROBE] on every (re)generate on rank 0 (config_trimode default on)",
271
+ )
272
+ parser.add_argument(
273
+ '--no_opsd_probe_on_generate', dest='opsd_probe_on_generate', action='store_false',
274
+ help="Disable per-generate [OPSD-PROBE] logs",
275
+ )
276
+ parser.set_defaults(opsd_probe_on_generate=None)
277
+ parser.add_argument(
278
+ '--no_opsd_probe_first_token_logits', dest='opsd_probe_first_token_logits', action='store_false',
279
+ help="Disable pre-generate first-token logits probe ([OPSD-GENDBG])",
280
+ )
281
+ parser.set_defaults(opsd_probe_first_token_logits=None)
282
+ parser.add_argument(
283
+ '--wandb', dest='wandb', action='store_true',
284
+ help="Force enable Weights & Biases logging",
285
+ )
286
+ parser.add_argument(
287
+ '--no_wandb', dest='wandb', action='store_false',
288
+ help="Disable Weights & Biases logging (or set WANDB_MODE=offline/disabled)",
289
+ )
290
+ parser.set_defaults(wandb=None)
291
+
292
+ args = parser.parse_args()
293
+ mode = args.mode
294
+
295
+ # 1. Load Configurations
296
+ CONFIG = load_config(args.config)
297
+ model_config = CONFIG['model']
298
+ training_config = CONFIG['training']
299
+ rl_config = CONFIG['rl']
300
+ client_config = CONFIG['client']
301
+ dataset_config = CONFIG['dataset']
302
+ task = training_config['task']
303
+ opsd_config = dict(CONFIG.get('opsd', {"enabled": False, "mode": "dyme"}))
304
+ if args.opsd_enabled:
305
+ opsd_config["enabled"] = True
306
+ if args.opsd_mode is not None:
307
+ opsd_config["enabled"] = True
308
+ opsd_config["mode"] = args.opsd_mode
309
+ if args.opsd_providers is not None:
310
+ opsd_config["privileged_providers"] = [p.strip() for p in args.opsd_providers.split(",") if p.strip()]
311
+ if args.opsd_privilege_profile is not None:
312
+ opsd_config["privileged_profile"] = args.opsd_privilege_profile.strip()
313
+ reward_weights_raw = args.reward_weights or os.environ.get("DYME_REWARD_WEIGHTS")
314
+ if reward_weights_raw:
315
+ parts = [p.strip() for p in reward_weights_raw.split(",") if p.strip()]
316
+ if len(parts) != 3:
317
+ raise ValueError(
318
+ f"reward_weights must have exactly 3 comma-separated values (format,context,acc), got: {reward_weights_raw!r}"
319
+ )
320
+ opsd_config["reward_weights"] = [float(p) for p in parts]
321
+ debug_cfg = opsd_config.setdefault("debug", {})
322
+ detail_every = debug_cfg.get("detail_every", 10)
323
+ if args.opsd_detail_every is not None:
324
+ detail_every = max(0, args.opsd_detail_every)
325
+ debug_cfg["detail_every"] = detail_every
326
+ probe_on_generate = debug_cfg.get("probe_on_generate", False)
327
+ if args.opsd_probe_on_generate is not None:
328
+ probe_on_generate = args.opsd_probe_on_generate
329
+ debug_cfg["probe_on_generate"] = probe_on_generate
330
+ probe_first_token_logits = debug_cfg.get("probe_first_token_logits", True)
331
+ if args.opsd_probe_first_token_logits is not None:
332
+ probe_first_token_logits = args.opsd_probe_first_token_logits
333
+ debug_cfg["probe_first_token_logits"] = probe_first_token_logits
334
+
335
+ debug_enabled = opsd_debug.configure(
336
+ enabled=args.opsd_debug or None,
337
+ detail_every=detail_every,
338
+ probe_on_generate=probe_on_generate,
339
+ probe_first_token_logits=probe_first_token_logits,
340
+ probe_prompt_tail_tokens=debug_cfg.get("probe_prompt_tail_tokens", 16),
341
+ probe_log_model_context=debug_cfg.get("probe_log_model_context", True),
342
+ )
343
+ if debug_enabled:
344
+ opsd_debug.log_config("main", "resolved OPSD config", opsd_config)
345
+ opsd_debug.log("main", "training entry", mode=mode, config_path=args.config)
346
+
347
+ # 2. Setup Environment
348
+ want_wandb = True if args.wandb is None else args.wandb
349
+ accelerator, use_wandb = setup_accelerator_and_wandb(
350
+ bf16=training_config['dyme_args']['bf16'],
351
+ want_wandb=want_wandb,
352
+ )
353
+ if want_wandb and not use_wandb and args.wandb is True:
354
+ raise RuntimeError(
355
+ "wandb was requested (--wandb) but no API key is configured. "
356
+ "Run `wandb login`, set WANDB_API_KEY, or use WANDB_MODE=offline."
357
+ )
358
+ if accelerator.is_main_process:
359
+ if use_wandb:
360
+ print("[DyME] wandb enabled for training logs")
361
+ elif want_wandb:
362
+ print(
363
+ "[DyME] wandb disabled (no credentials). Training continues with report_to=none. "
364
+ "Run `wandb login`, export WANDB_API_KEY, or pass --wandb after configuring."
365
+ )
366
+ device_id = accelerator.process_index
367
+ opsd_debug.configure(
368
+ enabled=debug_enabled,
369
+ detail_every=detail_every,
370
+ probe_on_generate=probe_on_generate,
371
+ probe_first_token_logits=probe_first_token_logits,
372
+ probe_prompt_tail_tokens=debug_cfg.get("probe_prompt_tail_tokens", 16),
373
+ probe_log_model_context=debug_cfg.get("probe_log_model_context", True),
374
+ rank=accelerator.process_index,
375
+ world_size=accelerator.num_processes,
376
+ )
377
+ if debug_enabled:
378
+ opsd_debug.log(
379
+ "main",
380
+ "accelerator initialized",
381
+ process_index=accelerator.process_index,
382
+ local_process_index=accelerator.local_process_index,
383
+ num_processes=accelerator.num_processes,
384
+ device=str(accelerator.device),
385
+ )
386
+
387
+ visible_gpus = torch.cuda.device_count()
388
+ local_rank = int(os.environ.get("LOCAL_RANK", accelerator.local_process_index))
389
+ if visible_gpus == 0:
390
+ raise RuntimeError("No CUDA devices are visible to this process.")
391
+ if accelerator.num_processes > visible_gpus:
392
+ raise RuntimeError(
393
+ f"GPU/process mismatch: launched {accelerator.num_processes} distributed processes "
394
+ f"but only {visible_gpus} CUDA device(s) are visible "
395
+ f"(CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES', '<unset>')}).\n"
396
+ f"Fix: accelerate launch --num_processes {visible_gpus} ...\n"
397
+ f"Or: bash scripts/train_local_gpus.sh (auto-detects {visible_gpus} GPU(s))"
398
+ )
399
+ if local_rank >= visible_gpus:
400
+ raise RuntimeError(
401
+ f"LOCAL_RANK={local_rank} but only {visible_gpus} GPU(s) visible. "
402
+ f"Reduce --num_processes to {visible_gpus}."
403
+ )
404
+ if accelerator.is_main_process:
405
+ print(
406
+ f"[DyME] Distributed launch OK: num_processes={accelerator.num_processes}, "
407
+ f"visible_gpus={visible_gpus}, CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES', '<unset>')}"
408
+ )
409
+
410
+ # 3. Initialize Model and Processor
411
+ ds_zero_stage = deepspeed_zero_stage()
412
+ if accelerator.is_main_process and is_deepspeed_accelerate_config():
413
+ print(
414
+ f"[DyME] DeepSpeed enabled via ACCELERATE_CONFIG "
415
+ f"({os.environ.get('ACCELERATE_CONFIG', '<unset>')}), ZeRO stage={ds_zero_stage}",
416
+ flush=True,
417
+ )
418
+
419
+ model, processor = load_model_and_processor(model_config)
420
+ if os.environ.get("DYME_GRADIENT_CHECKPOINTING", "").strip().lower() in ("1", "true", "yes", "on"):
421
+ if should_disable_gradient_checkpointing():
422
+ if accelerator.is_main_process:
423
+ print(
424
+ "[DyME] gradient checkpointing skipped: incompatible with DeepSpeed ZeRO-1/2 "
425
+ "(multiple student forwards / checkpoint backward). "
426
+ "Use ZeRO-3, DDP, or DYME_GRADIENT_CHECKPOINTING=0.",
427
+ flush=True,
428
+ )
429
+ else:
430
+ gc_kwargs = gradient_checkpointing_enable_kwargs()
431
+ if gc_kwargs:
432
+ model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs)
433
+ else:
434
+ model.gradient_checkpointing_enable()
435
+ if accelerator.is_main_process:
436
+ mode = f"use_reentrant={gc_kwargs['use_reentrant']}" if gc_kwargs else "default"
437
+ print(
438
+ f"[DyME] gradient checkpointing enabled on student "
439
+ f"(DYME_GRADIENT_CHECKPOINTING, {mode})",
440
+ flush=True,
441
+ )
442
+
443
+ cold_start_frac = float(
444
+ opsd_config.get("gate", {}).get("sft_cold_start_frac", 0.0) or 0.0
445
+ )
446
+ cold_start_steps = opsd_config.get("gate", {}).get("sft_cold_start_steps")
447
+ lazy_teacher = bool(cold_start_steps) or cold_start_frac > 0.0
448
+
449
+ teacher_model = None
450
+ teacher_model_config = None
451
+ if lazy_teacher:
452
+ teacher_model_config = dict(model_config)
453
+ if accelerator.is_main_process:
454
+ print(
455
+ "[DyME] SFT cold-start enabled: deferring 7B teacher load until RL phase",
456
+ flush=True,
457
+ )
458
+ else:
459
+ teacher_model = load_teacher_model(
460
+ model_config,
461
+ local_rank=local_rank,
462
+ num_gpus=visible_gpus,
463
+ )
464
+ if accelerator.is_main_process and teacher_model is not None:
465
+ _run_cross_model_vocab_checks(
466
+ model,
467
+ processor,
468
+ teacher_model,
469
+ model_config,
470
+ )
471
+
472
+ # 4. Prepare Datasets
473
+ train_dataset, eval_dataset = prepare_datasets(task, dataset_config, mode=mode)
474
+
475
+ # 5. Initialize Reward Calculator
476
+ # checker = RewardCalculator(rl_config, client_config.copy(), gpu_id=device_id)
477
+ # refiner = ContextRefiner(rl_config, client_config.copy(), gpu_id=device_id)
478
+
479
+ checker = RewardCalculatorLocal(rl_config, client_config.copy(), gpu_id=device_id)
480
+ refiner = ContextRefinerLocal(rl_config, client_config.copy(), gpu_id=device_id)
481
+ # 6. Define Training Arguments
482
+ dyme_args = dict(training_config['dyme_args'])
483
+ if ds_zero_stage is not None and ds_zero_stage >= 3:
484
+ dyme_args.setdefault("ds3_gather_for_generation", True)
485
+ if not use_wandb:
486
+ dyme_args["report_to"] = "none"
487
+ training_args = GRPOConfig(**dyme_args)
488
+
489
+ collate_fn_with_processor = partial(collate_fn, processor=processor)
490
+ # 7. Initialize the Trainer
491
+ dyme_trainer = DyMETrainer(
492
+ model=model,
493
+ checker=checker,
494
+ refiner=refiner,
495
+ args=training_args,
496
+ train_dataset=train_dataset,
497
+ eval_dataset=eval_dataset,
498
+ processing_class=processor,
499
+ processing_func=collate_fn_with_processor,
500
+ task_name=task,
501
+ end_flag=rl_config['end_flag'],
502
+ opsd_config=opsd_config,
503
+ teacher_model=teacher_model,
504
+ teacher_model_config=teacher_model_config,
505
+ )
506
+
507
+ # 8. Start Training
508
+ dyme_trainer.train()
509
+
510
+ output_dir = training_args.output_dir
511
+ output_dir = os.path.join(output_dir, "final_checkpoint")
512
+ if accelerator.is_main_process and is_deepspeed_accelerate_config():
513
+ print(
514
+ "[DyME] Saving consolidated student checkpoint (DeepSpeed ZeRO gather if configured)...",
515
+ flush=True,
516
+ )
517
+ dyme_trainer.save_model(output_dir)
518
+ if accelerator.is_main_process:
519
+ processor.save_pretrained(output_dir)
520
+ print(f"Model and processor saved to {output_dir}")
521
+ if __name__ == "__main__":
522
+ main()
main_llm.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train_grpo.py
2
+ """
3
+ Main script for training a Llava-based model using the custom MyGRPOTrainer.
4
+
5
+ This script handles:
6
+ 1. Configuration loading.
7
+ 2. Initialization of Weights & Biases (wandb) and Hugging Face Accelerate.
8
+ 3. Loading the model and processor.
9
+ 4. Preparing the training and evaluation datasets.
10
+ 5. Setting up and running the GRPO trainer.
11
+ """
12
+ import argparse
13
+ import os
14
+ from functools import partial
15
+ from typing import Dict, Any
16
+
17
+ import torch
18
+ import wandb
19
+ from accelerate import Accelerator
20
+ from datasets import Dataset, load_dataset
21
+ from peft import LoraConfig, get_peft_model, TaskType
22
+ from transformers import AutoProcessor, AutoModelForCausalLM
23
+ from trl import GRPOConfig
24
+ from config.config_llm import CONFIG
25
+ from data_utils.commom_util import collate_fn, define_task_data_func, collate_fn_woI
26
+ from trainer.DyMETrainer_llm import DyMETrainer
27
+ from reward_utils.checker import RewardCalculator, RewardCalculatorLocal
28
+ from reward_utils.refiner import ContextRefiner, ContextRefinerLocal
29
+
30
+ def print_trainable_parameters(model):
31
+ """
32
+ Prints the number of trainable parameters in the model.
33
+ """
34
+ trainable_params = 0
35
+ all_param = 0
36
+ for _, param in model.named_parameters():
37
+ all_param += param.numel()
38
+ if param.requires_grad:
39
+ trainable_params += param.numel()
40
+ print(
41
+ f"trainable params: {trainable_params} || all params: {all_param} || "
42
+ f"trainable%: {100 * trainable_params / all_param:.2f}"
43
+ )
44
+
45
+ def setup_accelerator_and_wandb(bf16) -> Accelerator:
46
+ """
47
+ Initializes Weights & Biases and the Hugging Face Accelerator.
48
+
49
+ Returns:
50
+ Accelerator: The configured accelerator instance.
51
+ """
52
+ wandb_key = os.environ.get("WANDB_API_KEY")
53
+ if wandb_key:
54
+ wandb.login(key=wandb_key)
55
+ if bf16:
56
+ accelerator = Accelerator(mixed_precision="bf16", log_with="wandb")
57
+ else:
58
+ accelerator = Accelerator(log_with="wandb")
59
+ return accelerator
60
+
61
+
62
+
63
+ def load_model_and_processor(model_config: Dict[str, Any], peft_config: Dict[str, Any]):
64
+ """
65
+ Loads the base model, applies LoRA configuration, and loads its processor.
66
+
67
+ Args:
68
+ model_config (Dict[str, Any]): Configuration dictionary for the model.
69
+ peft_config (Dict[str, Any]): Configuration dictionary for PEFT (LoRA).
70
+
71
+ Returns:
72
+ Tuple[PeftModel, PreTrainedProcessor]: The loaded PEFT model and processor.
73
+ """
74
+ model_id = model_config['pretrained_model_path']
75
+
76
+ # Load base model
77
+ base_model = AutoModelForCausalLM.from_pretrained(
78
+ model_id,
79
+ torch_dtype=getattr(torch, model_config['torch_dtype']),
80
+ attn_implementation='flash_attention_2' if model_config['use_flash_attention_2'] else 'sdpa',
81
+ low_cpu_mem_usage=True,
82
+ )
83
+
84
+ processor = AutoProcessor.from_pretrained(model_id, padding_side='left')
85
+ processor._tokenizer.padding_side = "left"
86
+ lora_config = peft_config
87
+
88
+ model = get_peft_model(base_model, lora_config)
89
+
90
+ print("LoRA model created:")
91
+ print_trainable_parameters(model)
92
+
93
+ return model, processor
94
+ # ## --- LoRA modification End --- ##
95
+
96
+
97
+ def prepare_datasets(task: str, dataset_config: Dict[str, Any]) -> (Dataset, Dataset):
98
+ """
99
+ Prepares the training and evaluation datasets based on the specified task.
100
+ """
101
+ data_func = define_task_data_func(task)
102
+ train_data_list = data_func(json_path=dataset_config['train_dataset'])
103
+ train_dataset = Dataset.from_list(train_data_list)
104
+
105
+ if 'chart' in task:
106
+ eval_dataset = load_dataset(dataset_config['eval_dataset'])['test']
107
+ else:
108
+ eval_dataset = None
109
+
110
+ return train_dataset, eval_dataset
111
+
112
+
113
+ def main():
114
+ """
115
+ Main function to orchestrate the model training pipeline.
116
+ """
117
+
118
+ parser = argparse.ArgumentParser(description="Train a model using GRPO with LoRA.")
119
+
120
+ parser.add_argument(
121
+ '--config', type=str, default='norm',
122
+ help="config file to use: 'norm' or 'llavacot'..."
123
+ )
124
+ args = parser.parse_args()
125
+ config_select = args.config
126
+
127
+ if config_select == 'norm':
128
+ from config_llm import CONFIG
129
+
130
+ # 1. Load Configurations
131
+ model_config = CONFIG['model']
132
+ training_config = CONFIG['training']
133
+ rl_config = CONFIG['rl']
134
+ client_config = CONFIG['client']
135
+ dataset_config = CONFIG['dataset']
136
+ peft_config = LoraConfig(
137
+ r=16,
138
+ lora_alpha=64,
139
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
140
+ task_type="CAUSAL_LM",
141
+ lora_dropout=0.05,
142
+ )
143
+
144
+ task = training_config['task']
145
+
146
+ # 2. Setup Environment
147
+ accelerator = setup_accelerator_and_wandb(bf16=training_config['dyme_args']['bf16'])
148
+ device_id = accelerator.process_index
149
+
150
+ # 3. Initialize Model and Processor
151
+ # ## --- LoRA modification Start --- ##
152
+ # Pass peft_config to the model loading function
153
+ model, processor = load_model_and_processor(model_config, peft_config)
154
+ # ## --- LoRA modification End --- ##
155
+
156
+ # 4. Prepare Datasets
157
+ train_dataset, eval_dataset = prepare_datasets(task, dataset_config)
158
+
159
+ # 5. Initialize Reward Calculator
160
+ checker = RewardCalculatorLocal(rl_config, client_config.copy(), gpu_id=device_id)
161
+ refiner = ContextRefinerLocal(rl_config, client_config.copy(), gpu_id=device_id)
162
+
163
+ # 6. Define Training Arguments
164
+ training_args = GRPOConfig(**training_config['dyme_args'])
165
+
166
+ collate_fn_with_processor = partial(collate_fn_woI, processor=processor)
167
+
168
+ # 7. Initialize the Trainer
169
+ # Trainer handles PeftModel automatically
170
+ dyme_trainer = DyMETrainer(
171
+ model=model,
172
+ checker=checker,
173
+ refiner=refiner,
174
+ args=training_args,
175
+ train_dataset=train_dataset,
176
+ eval_dataset=eval_dataset,
177
+ processing_class=processor,
178
+ processing_func=collate_fn_with_processor,
179
+ task_name=task,
180
+ end_flag=rl_config['end_flag'],
181
+ )
182
+
183
+ # 8. Start Training
184
+ dyme_trainer.train()
185
+
186
+ # When saving, the Trainer automatically saves only the LoRA adapter weights
187
+ output_dir = training_args.output_dir
188
+ output_dir = os.path.join(output_dir, "final_checkpoint")
189
+ dyme_trainer.save_model(output_dir)
190
+
191
+ if accelerator.is_main_process:
192
+ # Non-model files like the processor still need to be saved manually
193
+ processor.save_pretrained(output_dir)
194
+ print(f"LoRA adapters and processor saved to {output_dir}")
195
+
196
+ if __name__ == "__main__":
197
+ main()
main_sft.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Offline supervised fine-tuning for ChartQA (two-stage cold start before RLSD/OPD).
3
+
4
+ Usage:
5
+ accelerate launch main_sft.py --config config/config_rlsd_chartqa.py
6
+ bash scripts/train_chartqa_sft.sh
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import argparse
11
+ import os
12
+ from functools import partial
13
+
14
+ from accelerate import Accelerator
15
+ from datasets import Dataset
16
+ from transformers import Trainer, TrainingArguments
17
+
18
+ from config.loader import load_config
19
+ from data_utils.commom_util import collate_fn, define_task_data_func
20
+ from main import load_model_and_processor
21
+
22
+
23
+ def main() -> None:
24
+ parser = argparse.ArgumentParser(description="ChartQA offline SFT (hint + Answer GT).")
25
+ parser.add_argument(
26
+ "--config",
27
+ type=str,
28
+ default="config/config_rlsd_chartqa.py",
29
+ help="Config module (uses training.sft_args and dataset.train_dataset).",
30
+ )
31
+ parser.add_argument(
32
+ "--pretrained_model_path",
33
+ type=str,
34
+ default=None,
35
+ help="Override CONFIG model path (e.g. base 0.5B before RL).",
36
+ )
37
+ args = parser.parse_args()
38
+
39
+ config = load_config(args.config)
40
+ model_config = dict(config["model"])
41
+ if args.pretrained_model_path:
42
+ model_config["pretrained_model_path"] = args.pretrained_model_path
43
+
44
+ training_config = config["training"]
45
+ task = training_config["task"]
46
+ sft_args = dict(training_config.get("sft_args") or config.get("training", {}).get("sft_args", {}))
47
+ if not sft_args:
48
+ raise ValueError("Config must define training.sft_args for offline SFT.")
49
+
50
+ output_dir = os.environ.get("DYME_SFT_OUTPUT_DIR", sft_args.get("output_dir", "./outputs/chartqa-sft"))
51
+ sft_args["output_dir"] = output_dir
52
+ sft_args.setdefault("remove_unused_columns", False)
53
+
54
+ accelerator = Accelerator()
55
+ if accelerator.is_main_process:
56
+ os.makedirs(output_dir, exist_ok=True)
57
+
58
+ model, processor = load_model_and_processor(model_config)
59
+ data_func = define_task_data_func(task, mode="sft")
60
+ train_list = data_func(json_path=config["dataset"]["train_dataset"])
61
+ train_dataset = Dataset.from_list(train_list)
62
+
63
+ label_id = processor.tokenizer.convert_tokens_to_ids("<|im_start|>")
64
+ data_collator = partial(collate_fn, processor=processor, label_id=label_id)
65
+
66
+ train_args = TrainingArguments(**sft_args)
67
+ trainer = Trainer(
68
+ model=model,
69
+ args=train_args,
70
+ train_dataset=train_dataset,
71
+ data_collator=data_collator,
72
+ )
73
+ trainer.train()
74
+ trainer.save_model(os.path.join(output_dir, "final_checkpoint"))
75
+ if accelerator.is_main_process:
76
+ processor.save_pretrained(os.path.join(output_dir, "final_checkpoint"))
77
+
78
+
79
+ if __name__ == "__main__":
80
+ main()
multi_node_config_raw.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ distributed_type: MULTI_GPU
3
+ downcast_bf16: 'no'
4
+ gpu_ids: all
5
+ machine_rank: 0
6
+ main_process_ip: 'xx.xx.xx.xx'
7
+ main_process_port: 36001
8
+ main_training_function: main
9
+ mixed_precision: 'bf16'
10
+ num_machines: 2
11
+ num_processes: 16
12
+ rdzv_backend: static
13
+ same_network: true
14
+ tpu_use_cluster: false
15
+ tpu_use_sudo: false
16
+ use_cpu: false
17
+ # Optional DeepSpeed ZeRO-0 (no sharding). Prefer MULTI_GPU DDP for single-node training.
18
+ deepspeed_config:
19
+ zero_stage: 0
20
+ gradient_accumulation_steps: 1
21
+ zero3_init_flag: false
opsd_utils/__pycache__/opsd_loss.cpython-312.pyc ADDED
Binary file (5.03 kB). View file
 
opsd_utils/gate_policy.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """RLSD warmup gates for OPSD degenerate skip, denser online SFT, and embedded SFT cold start."""
2
+ from __future__ import annotations
3
+
4
+ import math
5
+ from typing import Any, Mapping, Optional
6
+
7
+
8
+ def current_global_step(trainer: Any) -> int:
9
+ return int(getattr(getattr(trainer, "state", None), "global_step", getattr(trainer, "_step", 0)) or 0)
10
+
11
+
12
+ def resolve_max_training_steps(trainer: Any) -> Optional[int]:
13
+ """Resolve total optimizer steps for gate math (cold start frac, warmup windows).
14
+
15
+ Priority: TrainingArguments.max_steps > Trainer.state.max_steps > epoch estimate.
16
+ HF sets state.max_steps when max_steps<=0 from num_train_epochs * len(dataloader).
17
+ """
18
+ args = getattr(trainer, "args", None)
19
+ if args is not None:
20
+ arg_max = getattr(args, "max_steps", None)
21
+ if arg_max is not None and int(arg_max) > 0:
22
+ return int(arg_max)
23
+
24
+ state = getattr(trainer, "state", None)
25
+ if state is not None:
26
+ state_max = getattr(state, "max_steps", None)
27
+ if state_max is not None and int(state_max) > 0:
28
+ return int(state_max)
29
+
30
+ if args is None:
31
+ return None
32
+
33
+ num_epochs = getattr(args, "num_train_epochs", None)
34
+ grad_accum = max(1, int(getattr(args, "gradient_accumulation_steps", 1) or 1))
35
+ if num_epochs is None or float(num_epochs) <= 0:
36
+ return None
37
+
38
+ dataloader = getattr(trainer, "train_dataloader", None)
39
+ if dataloader is None and hasattr(trainer, "get_train_dataloader"):
40
+ try:
41
+ dataloader = trainer.get_train_dataloader()
42
+ except Exception:
43
+ dataloader = None
44
+ if dataloader is None:
45
+ return None
46
+
47
+ try:
48
+ steps_per_epoch = len(dataloader)
49
+ except TypeError:
50
+ return None
51
+ if steps_per_epoch <= 0:
52
+ return None
53
+
54
+ total = math.ceil(float(num_epochs) * steps_per_epoch / grad_accum)
55
+ return total if total > 0 else None
56
+
57
+
58
+ def sft_cold_start_steps(opsd_config: Mapping[str, Any], max_steps: Optional[int]) -> int:
59
+ """Steps at start of training devoted to embedded offline-style SFT (no generate / no OPSD)."""
60
+ gate = opsd_config.get("gate", {})
61
+ steps_env = gate.get("sft_cold_start_steps")
62
+ if steps_env is not None:
63
+ return max(0, int(steps_env))
64
+ frac = float(gate.get("sft_cold_start_frac", 0.0) or 0.0)
65
+ if frac <= 0.0 or max_steps is None or max_steps <= 0:
66
+ return 0
67
+ return max(1, int(max_steps * frac))
68
+
69
+
70
+ def in_sft_cold_start(
71
+ opsd_config: Mapping[str, Any],
72
+ global_step: int,
73
+ max_steps: Optional[int],
74
+ ) -> bool:
75
+ cold_steps = sft_cold_start_steps(opsd_config, max_steps)
76
+ return cold_steps > 0 and global_step < cold_steps
77
+
78
+
79
+ def resolve_skip_degenerate_opsd(
80
+ opsd_config: Mapping[str, Any],
81
+ global_step: int,
82
+ max_steps: Optional[int] = None,
83
+ ) -> bool:
84
+ gate = opsd_config.get("gate", {})
85
+ if not gate.get("skip_degenerate_for_opsd", False):
86
+ return False
87
+ cold_end = sft_cold_start_steps(opsd_config, max_steps)
88
+ warmup = int(gate.get("degen_skip_warmup_steps", 200))
89
+ # Do not skip degenerate OPSD during embedded SFT cold start or its degen warmup window.
90
+ threshold = cold_end + warmup if cold_end > 0 else warmup
91
+ return global_step >= threshold
92
+
93
+
94
+ def sft_slots_for_step(
95
+ opsd_config: Mapping[str, Any],
96
+ global_step: int,
97
+ max_steps: Optional[int] = None,
98
+ ) -> int:
99
+ if in_sft_cold_start(opsd_config, global_step, max_steps):
100
+ return 0
101
+ gate = opsd_config.get("gate", {})
102
+ warmup_steps = int(gate.get("sft_warmup_steps", 200))
103
+ cold_end = sft_cold_start_steps(opsd_config, max_steps)
104
+ effective_warmup_end = cold_end + warmup_steps if cold_end > 0 else warmup_steps
105
+ if global_step < effective_warmup_end:
106
+ return max(1, int(gate.get("sft_warmup_slots_per_group", 2)))
107
+ return 1
opsd_utils/health_monitor.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training health monitor: degeneration alerts, rolling stats, cross-step correlation."""
2
+ from __future__ import annotations
3
+
4
+ import math
5
+ from collections import deque
6
+ from typing import Any, Optional
7
+
8
+ from opsd_utils import debug_log as opsd_debug
9
+
10
+ ALERT_GEN_CLIP_COLLAPSE = "GEN_CLIP_COLLAPSE"
11
+ ALERT_GEN_REPEAT_DEGEN = "GEN_REPEAT_DEGEN"
12
+ ALERT_OPT_GRAD_SPIKE = "OPT_GRAD_SPIKE"
13
+ ALERT_OPT_NAN_INF = "OPT_NAN_INF"
14
+ ALERT_RL_ZERO_SIGNAL = "RL_ZERO_SIGNAL"
15
+ ALERT_REWARD_FORMAT_HACK = "REWARD_FORMAT_HACK"
16
+ ALERT_DATA_EMPTY_VF = "DATA_EMPTY_VF"
17
+ ALERT_LOGIT_MODE_COLLAPSE = "LOGIT_MODE_COLLAPSE"
18
+ ALERT_ANSWER_TOKEN_DRIFT = "ANSWER_TOKEN_DRIFT"
19
+ ALERT_CLIP_FALSE_HEALTHY = "CLIP_FALSE_HEALTHY"
20
+ ALERT_OPSD_LEAKAGE_PATTERN = "OPSD_LEAKAGE_PATTERN"
21
+ ALERT_OPSD_ON_CORRECT = "OPSD_ON_CORRECT"
22
+
23
+
24
+ def _safe_float(v: Any, default: float = 0.0) -> float:
25
+ try:
26
+ if v is None:
27
+ return default
28
+ f = float(v)
29
+ if math.isnan(f) or math.isinf(f):
30
+ return default
31
+ return f
32
+ except (TypeError, ValueError):
33
+ return default
34
+
35
+
36
+ def _rolling_mean_std(values: list[float]) -> tuple[float, float]:
37
+ if not values:
38
+ return 0.0, 0.0
39
+ mean = sum(values) / len(values)
40
+ if len(values) < 2:
41
+ return mean, 0.0
42
+ var = sum((x - mean) ** 2 for x in values) / len(values)
43
+ return mean, math.sqrt(var)
44
+
45
+
46
+ class TrainingHealthMonitor:
47
+ """Collect per-step signals, emit layered [OPSD-HEALTH] logs, expose metrics keys."""
48
+
49
+ def __init__(self, config: Optional[dict[str, Any]] = None):
50
+ cfg = config or {}
51
+ self.enabled = bool(cfg.get("enabled", True))
52
+ self.window = max(2, int(cfg.get("window", 20)))
53
+ self.log_on_generate = bool(cfg.get("log_on_generate", True))
54
+ self.log_every_step = bool(cfg.get("log_every_step", True))
55
+ self.log_detail_bundle = bool(cfg.get("log_detail_bundle", True))
56
+ self.log_alerts_immediately = bool(cfg.get("log_alerts_immediately", True))
57
+ self.metrics_every_step = bool(cfg.get("metrics_every_step", True))
58
+
59
+ self._history: deque[dict[str, Any]] = deque(maxlen=self.window)
60
+ self._step_fields: dict[str, Any] = {}
61
+ self._step_alerts: list[str] = []
62
+ self._p_greedy_history: deque[float] = deque(maxlen=5)
63
+ self._p_answer_history: deque[float] = deque(maxlen=5)
64
+ self._eos_history: deque[float] = deque(maxlen=5)
65
+ self._last_step: Optional[int] = None
66
+
67
+ def reset_step(self, step: int) -> None:
68
+ self._step_fields = {"global_step": step}
69
+ self._step_alerts = []
70
+ self._last_step = step
71
+
72
+ def _emit_alert(self, step: int, code: str, **fields: Any) -> None:
73
+ if code not in self._step_alerts:
74
+ self._step_alerts.append(code)
75
+ if self.log_alerts_immediately and opsd_debug.should_log_health_alerts_immediately():
76
+ opsd_debug.log_health("ALERT", code, global_step=step, **fields)
77
+
78
+ def _check_generate_alerts(self, step: int, stats: dict[str, Any], logits: dict[str, Any]) -> list[str]:
79
+ clipped = _safe_float(stats.get("clipped_rate"))
80
+ eos_rate = _safe_float(stats.get("eos_terminated_rate"))
81
+ degenerate_rate = _safe_float(stats.get("degenerate_rate"))
82
+ repeat_loop = int(stats.get("repeat_loop_count", 0) or 0)
83
+ p_greedy = _safe_float(logits.get("p_greedy_first"))
84
+ p_eos = _safe_float(logits.get("p_eos_first"))
85
+ p_answer = _safe_float(logits.get("p_answer_first"))
86
+
87
+ if clipped > 0.8 and degenerate_rate < 0.05:
88
+ self._emit_alert(
89
+ step,
90
+ ALERT_CLIP_FALSE_HEALTHY,
91
+ clipped_rate=clipped,
92
+ degenerate_rate=degenerate_rate,
93
+ hint="high clip with low degenerate_rate often masks Answer-only collapse",
94
+ )
95
+ if clipped > 0.7 and eos_rate < 0.3:
96
+ self._emit_alert(
97
+ step,
98
+ ALERT_GEN_CLIP_COLLAPSE,
99
+ clipped_rate=clipped,
100
+ eos_rate=eos_rate,
101
+ hint="raise repetition_penalty, lower temperature, or shorten max_completion_length",
102
+ )
103
+ if degenerate_rate > 0.5 or repeat_loop > 0:
104
+ self._emit_alert(
105
+ step,
106
+ ALERT_GEN_REPEAT_DEGEN,
107
+ degenerate_rate=degenerate_rate,
108
+ repeat_loop_count=repeat_loop,
109
+ )
110
+
111
+ if p_greedy > 0:
112
+ self._p_greedy_history.append(p_greedy)
113
+ self._eos_history.append(eos_rate)
114
+ if (
115
+ len(self._p_greedy_history) >= 3
116
+ and all(p > 0.99 for p in list(self._p_greedy_history)[-3:])
117
+ and len(self._eos_history) >= 2
118
+ and self._eos_history[-1] < self._eos_history[-2] - 0.1
119
+ ):
120
+ self._emit_alert(
121
+ step,
122
+ ALERT_LOGIT_MODE_COLLAPSE,
123
+ p_greedy_first=p_greedy,
124
+ p_eos_first=p_eos,
125
+ eos_rate=eos_rate,
126
+ hint="first token collapsed to Goal: template; EOS probability near zero",
127
+ )
128
+ if p_answer > 0:
129
+ self._p_answer_history.append(p_answer)
130
+ if len(self._p_answer_history) >= 3 and all(
131
+ p < 0.5 for p in list(self._p_answer_history)[-3:]
132
+ ):
133
+ self._emit_alert(
134
+ step,
135
+ ALERT_ANSWER_TOKEN_DRIFT,
136
+ p_answer_first=p_answer,
137
+ hint="first-token Answer probability low for 3 consecutive generate batches",
138
+ )
139
+ return list(self._step_alerts)
140
+
141
+ def record_generate(
142
+ self,
143
+ step: int,
144
+ stats: dict[str, Any],
145
+ logits_stats: Optional[dict[str, Any]] = None,
146
+ ) -> list[str]:
147
+ if not self.enabled:
148
+ return []
149
+ logits_stats = logits_stats or {}
150
+ self._step_fields.update(
151
+ {
152
+ "degenerate_rate": stats.get("degenerate_rate"),
153
+ "clipped_rate": stats.get("clipped_rate"),
154
+ "eos_terminated_rate": stats.get("eos_terminated_rate"),
155
+ "repeat_loop_count": stats.get("repeat_loop_count"),
156
+ "char_repeat_count": stats.get("char_repeat_count", 0),
157
+ "p_greedy_first": logits_stats.get("p_greedy_first"),
158
+ "p_eos_first": logits_stats.get("p_eos_first"),
159
+ "p_answer_first": logits_stats.get("p_answer_first"),
160
+ "entropy_first": logits_stats.get("entropy_first"),
161
+ "degenerate_rate_format": stats.get("degenerate_rate_format"),
162
+ "degenerate_rate_repeat": stats.get("degenerate_rate_repeat"),
163
+ "format_without_thinking_rate": stats.get("format_without_thinking_rate"),
164
+ }
165
+ )
166
+ alerts = self._check_generate_alerts(step, stats, logits_stats)
167
+ alert_str = ",".join(alerts) if alerts else "none"
168
+
169
+ if self.log_on_generate and opsd_debug.should_log_health_on_generate():
170
+ opsd_debug.log_health(
171
+ "generate",
172
+ "batch health",
173
+ global_step=step,
174
+ degenerate_rate=stats.get("degenerate_rate"),
175
+ clipped_rate=stats.get("clipped_rate"),
176
+ eos_rate=stats.get("eos_terminated_rate"),
177
+ repeat_loop_count=stats.get("repeat_loop_count"),
178
+ char_repeat_count=stats.get("char_repeat_count", 0),
179
+ p_greedy=logits_stats.get("p_greedy_first"),
180
+ p_eos=logits_stats.get("p_eos_first"),
181
+ alerts=alert_str,
182
+ )
183
+ return alerts
184
+
185
+ def record_data(self, step: int, fields: dict[str, Any]) -> None:
186
+ if not self.enabled:
187
+ return
188
+ self._step_fields.update(fields)
189
+ vf_empty = _safe_float(fields.get("visual_fact_empty_rate"))
190
+ if vf_empty > 0.5:
191
+ self._emit_alert(
192
+ step,
193
+ ALERT_DATA_EMPTY_VF,
194
+ visual_fact_empty_rate=vf_empty,
195
+ hint="rebuild train_medium_vf_full.json with visual_fact hints",
196
+ )
197
+
198
+ def record_routing(self, step: int, fields: dict[str, Any]) -> None:
199
+ if not self.enabled:
200
+ return
201
+ self._step_fields.update(fields)
202
+ format_mean = _safe_float(fields.get("format_mean"))
203
+ acc_mean = _safe_float(fields.get("accuracy_mean"))
204
+ degenerate_rate = _safe_float(self._step_fields.get("degenerate_rate"))
205
+ if format_mean > 0.7 and acc_mean < 0.05 and degenerate_rate > 0.4:
206
+ self._emit_alert(
207
+ step,
208
+ ALERT_REWARD_FORMAT_HACK,
209
+ format_mean=format_mean,
210
+ accuracy_mean=acc_mean,
211
+ degenerate_rate=degenerate_rate,
212
+ )
213
+ opsd_on_correct = _safe_float(fields.get("opsd_on_correct_rate"))
214
+ if opsd_on_correct > 0.01:
215
+ self._emit_alert(
216
+ step,
217
+ ALERT_OPSD_ON_CORRECT,
218
+ opsd_on_correct_rate=opsd_on_correct,
219
+ )
220
+ leakage_skip = int(fields.get("opsd_skipped_leakage", 0) or 0)
221
+ if leakage_skip > 0:
222
+ self._emit_alert(
223
+ step,
224
+ ALERT_OPSD_LEAKAGE_PATTERN,
225
+ opsd_skipped_leakage=leakage_skip,
226
+ )
227
+
228
+ def record_loss(self, step: int, fields: dict[str, Any]) -> None:
229
+ if not self.enabled:
230
+ return
231
+ self._step_fields.update(fields)
232
+ loss_val = fields.get("combined_loss_scalar", fields.get("grpo_loss_scalar"))
233
+ if loss_val is not None and not math.isfinite(_safe_float(loss_val, default=float("nan"))):
234
+ self._emit_alert(step, ALERT_OPT_NAN_INF, loss=loss_val)
235
+
236
+ adv_abs = _safe_float(fields.get("advantages_abs_mean"))
237
+ zero_grpo = _safe_float(fields.get("grpo_zero_loss_rate"))
238
+ if adv_abs < 1e-6 and zero_grpo > 0.8:
239
+ self._emit_alert(
240
+ step,
241
+ ALERT_RL_ZERO_SIGNAL,
242
+ advantages_abs_mean=adv_abs,
243
+ grpo_zero_loss_rate=zero_grpo,
244
+ )
245
+
246
+ def record_optimizer(self, step: int, grad_norm: Optional[float], lr: Optional[float]) -> None:
247
+ if not self.enabled:
248
+ return
249
+ gn = _safe_float(grad_norm) if grad_norm is not None else None
250
+ if gn is not None:
251
+ self._step_fields["grad_norm"] = gn
252
+ hist = [h.get("grad_norm") for h in self._history if h.get("grad_norm") is not None]
253
+ if len(hist) >= 3:
254
+ mean, std = _rolling_mean_std([float(x) for x in hist])
255
+ if std > 1e-8 and gn > mean + 3 * std:
256
+ self._emit_alert(
257
+ step,
258
+ ALERT_OPT_GRAD_SPIKE,
259
+ grad_norm=gn,
260
+ rolling_mean=mean,
261
+ rolling_std=std,
262
+ )
263
+ if lr is not None:
264
+ self._step_fields["learning_rate"] = lr
265
+
266
+ def correlate(self) -> dict[str, Any]:
267
+ """Cross-step deltas and root-cause hints from rolling history."""
268
+ hints: list[str] = []
269
+ out: dict[str, Any] = {"root_cause_hints": hints}
270
+
271
+ if len(self._history) < 2:
272
+ out["root_cause_hints"] = ["insufficient history for correlation"]
273
+ return out
274
+
275
+ prev = self._history[-1]
276
+ prev2 = self._history[-2] if len(self._history) >= 2 else prev
277
+
278
+ for key in ("grad_norm", "clipped_rate", "eos_terminated_rate", "p_greedy_first", "degenerate_rate"):
279
+ cur_v = self._step_fields.get(key)
280
+ old_v = prev.get(key)
281
+ if cur_v is not None and old_v is not None:
282
+ out[f"delta_{key}"] = _safe_float(cur_v) - _safe_float(old_v)
283
+
284
+ gn_prev = prev.get("grad_norm")
285
+ clip_cur = self._step_fields.get("clipped_rate")
286
+ if gn_prev is not None and clip_cur is not None and _safe_float(clip_cur) > 0.7:
287
+ hints.append("high clip rate may follow recent gradient update (check delta_grad_norm)")
288
+
289
+ p_prev = prev2.get("p_greedy_first")
290
+ p_cur = self._step_fields.get("p_greedy_first")
291
+ eos_prev = prev2.get("eos_terminated_rate")
292
+ eos_cur = self._step_fields.get("eos_terminated_rate")
293
+ if (
294
+ p_prev is not None
295
+ and p_cur is not None
296
+ and _safe_float(p_cur) > 0.99
297
+ and eos_prev is not None
298
+ and eos_cur is not None
299
+ and _safe_float(eos_prev) > 0.5
300
+ and _safe_float(eos_cur) < 0.2
301
+ ):
302
+ hints.append("after gradient step: p_greedy rose to ~1.0 and eos_rate collapsed")
303
+
304
+ if ALERT_RL_ZERO_SIGNAL in self._step_alerts and ALERT_GEN_REPEAT_DEGEN in self._step_alerts:
305
+ hints.append("RL zero signal co-occurs with repetition degeneration")
306
+
307
+ if not hints:
308
+ hints.append("none")
309
+ out["root_cause_hints"] = hints
310
+ return out
311
+
312
+ def maybe_log_detail_bundle(self, step: int) -> None:
313
+ if not self.enabled or not self.log_detail_bundle:
314
+ return
315
+ if not opsd_debug.should_log_detail(step):
316
+ return
317
+ opsd_debug.log_health_detail_banner(step, "TRAINING HEALTH BUNDLE")
318
+ corr = self.correlate()
319
+ hist_keys = (
320
+ "degenerate_rate",
321
+ "clipped_rate",
322
+ "eos_terminated_rate",
323
+ "grad_norm",
324
+ "p_greedy_first",
325
+ "grpo_zero_loss_rate",
326
+ "sft_replaced_ratio",
327
+ )
328
+ rolling: dict[str, Any] = {}
329
+ for key in hist_keys:
330
+ vals = [_safe_float(h[key]) for h in self._history if h.get(key) is not None]
331
+ if vals:
332
+ mean, std = _rolling_mean_std(vals)
333
+ rolling[f"{key}_mean"] = mean
334
+ rolling[f"{key}_std"] = std
335
+
336
+ snapshot_fields = {k: v for k, v in self._step_fields.items() if k != "global_step"}
337
+ opsd_debug.log_health_detail(
338
+ "health",
339
+ "step snapshot",
340
+ global_step=step,
341
+ alerts=self._step_alerts or ["none"],
342
+ **snapshot_fields,
343
+ **rolling,
344
+ )
345
+ opsd_debug.log_health_detail(
346
+ "health",
347
+ "cross-step correlation",
348
+ global_step=step,
349
+ **corr,
350
+ )
351
+
352
+ def finish_step(self, step: int) -> dict[str, float]:
353
+ """L2 step summary log + metrics keys for Trainer.log()."""
354
+ snapshot = dict(self._step_fields)
355
+ snapshot["alert_count"] = len(self._step_alerts)
356
+ snapshot["alerts"] = list(self._step_alerts)
357
+ self._history.append(snapshot)
358
+
359
+ if self.log_every_step and opsd_debug.should_log_health_every_step():
360
+ corr = self.correlate()
361
+ opsd_debug.log_health(
362
+ "step",
363
+ "step summary",
364
+ global_step=step,
365
+ grad_norm=snapshot.get("grad_norm"),
366
+ lr=snapshot.get("learning_rate"),
367
+ sft_replaced_ratio=snapshot.get("sft_replaced_ratio"),
368
+ grpo_zero_loss_rate=snapshot.get("grpo_zero_loss_rate"),
369
+ degenerate_rate=snapshot.get("degenerate_rate"),
370
+ clipped_rate=snapshot.get("clipped_rate"),
371
+ eos_rate=snapshot.get("eos_terminated_rate"),
372
+ alert_count=len(self._step_alerts),
373
+ hints=corr.get("root_cause_hints"),
374
+ )
375
+
376
+ self.maybe_log_detail_bundle(step)
377
+
378
+ if not self.metrics_every_step:
379
+ return {}
380
+
381
+ metrics: dict[str, float] = {}
382
+ mapping = {
383
+ "completions/degenerate_rate": "degenerate_rate",
384
+ "completions/eos_rate": "eos_terminated_rate",
385
+ "completions/repeat_loop_count": "repeat_loop_count",
386
+ "routing/sft_replaced_ratio": "sft_replaced_ratio",
387
+ "routing/opsd_skipped_degenerate": "opsd_skipped_degenerate",
388
+ "routing/opsd_skipped_leakage": "opsd_skipped_leakage",
389
+ "routing/opsd_on_correct_rate": "opsd_on_correct_rate",
390
+ "routing/grpo_on_correct_rate": "grpo_on_correct_rate",
391
+ "routing/opd_teacher_call_rate": "opd_teacher_call_rate",
392
+ "teacher/privileged_suffix_has_gold_rate": "privileged_suffix_has_gold_rate",
393
+ "teacher/visual_fact_empty_rate": "visual_fact_empty_rate",
394
+ "teacher/suffix_len_mean": "teacher_suffix_len_mean",
395
+ "signal/grpo_zero_loss_rate": "grpo_zero_loss_rate",
396
+ "signal/advantage_abs_mean": "advantages_abs_mean",
397
+ "logits/p_greedy_first": "p_greedy_first",
398
+ "logits/p_eos_first": "p_eos_first",
399
+ "logits/p_answer_first": "p_answer_first",
400
+ "completions/degenerate_rate_format": "degenerate_rate_format",
401
+ "completions/degenerate_rate_repeat": "degenerate_rate_repeat",
402
+ "metrics/format_without_thinking_rate": "format_without_thinking_rate",
403
+ "phase/sft_cold_start": "phase_sft_cold_start",
404
+ "health/alert_count": "alert_count",
405
+ }
406
+ for metric_key, field_key in mapping.items():
407
+ val = snapshot.get(field_key)
408
+ if val is not None:
409
+ metrics[metric_key] = _safe_float(val)
410
+ return metrics
opsd_utils/privileged/__pycache__/providers.cpython-310.pyc ADDED
Binary file (5.43 kB). View file
 
opsd_utils/privileged/image_utils.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Image loading and crop utilities for privileged teacher dual-image forward."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Any, Optional
5
+
6
+ from PIL import Image
7
+
8
+ from data_utils.paths import resolve_image_path
9
+ from data_utils.privileged_schema import resolve_crop_bbox
10
+ from opsd_utils import debug_log as opsd_debug
11
+
12
+
13
+ def load_rgb(image: Any) -> Optional[Image.Image]:
14
+ """Load sample image as RGB PIL from path or in-memory object."""
15
+ if image is None:
16
+ return None
17
+ if isinstance(image, Image.Image):
18
+ return image.convert("RGB") if image.mode != "RGB" else image
19
+ if isinstance(image, str):
20
+ path = resolve_image_path(image)
21
+ try:
22
+ img = Image.open(path)
23
+ return img.convert("RGB")
24
+ except (FileNotFoundError, OSError):
25
+ opsd_debug.log("privileged_image", "load_rgb failed", path=path)
26
+ return None
27
+ return None
28
+
29
+
30
+ def center_crop(img: Image.Image, margin_ratio: float = 0.25) -> Image.Image:
31
+ w, h = img.size
32
+ margin_w = int(w * margin_ratio)
33
+ margin_h = int(h * margin_ratio)
34
+ return img.crop((margin_w, margin_h, w - margin_w, h - margin_h))
35
+
36
+
37
+ def crop_image(
38
+ img: Image.Image,
39
+ bbox_norm: Optional[list[float]] = None,
40
+ strategy: str = "center",
41
+ margin_ratio: float = 0.25,
42
+ fallback_reason: Optional[str] = None,
43
+ ) -> tuple[Image.Image, str]:
44
+ """
45
+ Crop image using C2 normalized bbox or center fallback.
46
+ Returns (cropped_image, crop_strategy_used).
47
+ """
48
+ if bbox_norm is not None and strategy in ("bbox", "heuristic", "bbox_then_center"):
49
+ w, h = img.size
50
+ x0 = int(bbox_norm[0] * w)
51
+ y0 = int(bbox_norm[1] * h)
52
+ x1 = int(bbox_norm[2] * w)
53
+ y1 = int(bbox_norm[3] * h)
54
+ x0, x1 = max(0, min(x0, w - 1)), max(1, min(x1, w))
55
+ y0, y1 = max(0, min(y0, h - 1)), max(1, min(y1, h))
56
+ if x1 > x0 and y1 > y0:
57
+ used = strategy if strategy != "bbox_then_center" else "bbox"
58
+ opsd_debug.log(
59
+ "privileged_image",
60
+ "crop_image bbox",
61
+ strategy=used,
62
+ bbox_norm=bbox_norm,
63
+ crop_px=(x0, y0, x1, y1),
64
+ fallback_reason=fallback_reason,
65
+ )
66
+ return img.crop((x0, y0, x1, y1)), used
67
+
68
+ crop = center_crop(img, margin_ratio=margin_ratio)
69
+ used = "center_fallback" if fallback_reason else "center"
70
+ opsd_debug.log(
71
+ "privileged_image",
72
+ "crop_image center",
73
+ strategy=used,
74
+ bbox_norm=bbox_norm,
75
+ margin_ratio=margin_ratio,
76
+ fallback_reason=fallback_reason,
77
+ )
78
+ return crop, used
79
+
80
+
81
+ def heuristic_crop_from_visual_fact(
82
+ img: Image.Image,
83
+ sample: dict[str, Any],
84
+ crop_cfg: Optional[dict[str, Any]] = None,
85
+ ) -> tuple[Image.Image, str, Optional[list[float]]]:
86
+ """D2 with D1 fallback: heuristic bbox from visual_fact, else center crop."""
87
+ crop_cfg = crop_cfg or {}
88
+ margin_ratio = float(crop_cfg.get("margin_ratio", 0.25))
89
+ bbox_norm, strategy = resolve_crop_bbox(sample, crop_cfg)
90
+ fallback_reason = None
91
+ if strategy == "center" and sample.get("visual_fact"):
92
+ fallback_reason = "heuristic_failed"
93
+ crop, used = crop_image(
94
+ img,
95
+ bbox_norm=bbox_norm,
96
+ strategy=strategy if bbox_norm else "center",
97
+ margin_ratio=margin_ratio,
98
+ fallback_reason=fallback_reason,
99
+ )
100
+ return crop, used, bbox_norm
101
+
102
+
103
+ def resolve_teacher_images(
104
+ sample: dict[str, Any],
105
+ profile: str,
106
+ crop_cfg: Optional[dict[str, Any]] = None,
107
+ ) -> tuple[list[Image.Image], dict[str, Any]]:
108
+ """
109
+ Build teacher image list for privileged forward.
110
+ text -> [full]; visual/hybrid + mode=dual -> [full, crop]; otherwise [full].
111
+ Returns (images, debug_meta).
112
+ """
113
+ crop_cfg = crop_cfg or {}
114
+ image = sample.get("image")
115
+ if image is None:
116
+ return [], {"crop_strategy": "none", "num_teacher_images": 0, "has_bbox": False}
117
+
118
+ full = load_rgb(image)
119
+ if full is None:
120
+ return [], {"crop_strategy": "load_failed", "num_teacher_images": 0, "has_bbox": False}
121
+
122
+ image_mode = str(crop_cfg.get("mode", "single")).strip().lower()
123
+ if profile == "text" or image_mode in ("single", "full", "off", "none"):
124
+ meta = {
125
+ "crop_strategy": "single_full",
126
+ "num_teacher_images": 1,
127
+ "has_bbox": False,
128
+ "bbox_norm": None,
129
+ "image_mode": image_mode,
130
+ }
131
+ return [full], meta
132
+
133
+ crop, crop_strategy, bbox_norm = heuristic_crop_from_visual_fact(full, sample, crop_cfg)
134
+ meta = {
135
+ "crop_strategy": crop_strategy,
136
+ "num_teacher_images": 2,
137
+ "has_bbox": bbox_norm is not None,
138
+ "bbox_norm": bbox_norm,
139
+ "full_size": full.size,
140
+ "crop_size": crop.size,
141
+ "image_mode": image_mode,
142
+ }
143
+ return [full, crop], meta
opsd_utils/prompt_builder.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Optional
3
+
4
+ import torch
5
+ from PIL import Image
6
+
7
+ from opsd_utils import debug_log as opsd_debug
8
+ from opsd_utils.privileged import build_privileged_context, maybe_save_privileged_images
9
+ from opsd_utils.teacher_batching import (
10
+ count_image_tokens,
11
+ process_teacher_sample,
12
+ stack_teacher_processor_batches,
13
+ )
14
+
15
+
16
+ def _build_teacher_text(student_prompt: str, privileged_suffix: str) -> str:
17
+ teacher_text = student_prompt
18
+ if privileged_suffix.strip():
19
+ teacher_text = f"{student_prompt}\n\n{privileged_suffix.strip()}"
20
+ return teacher_text
21
+
22
+
23
+ def tokenize_teacher_prompt(
24
+ processor,
25
+ student_prompt: str,
26
+ privileged_suffix: str,
27
+ images: Any,
28
+ ) -> dict:
29
+ """Tokenize teacher multimodal prompt = student question + privileged suffix + N images."""
30
+ if isinstance(images, list):
31
+ pil_images = [img for img in images if isinstance(img, Image.Image)]
32
+ else:
33
+ from opsd_utils.privileged.image_utils import load_rgb
34
+
35
+ one = load_rgb(images)
36
+ pil_images = [one] if one is not None else []
37
+
38
+ teacher_text = _build_teacher_text(student_prompt, privileged_suffix)
39
+
40
+ opsd_debug.log(
41
+ "teacher_prompt",
42
+ "tokenize_teacher_prompt",
43
+ num_images=len(pil_images),
44
+ suffix_len=len(privileged_suffix.strip()),
45
+ teacher_text_len=len(teacher_text),
46
+ )
47
+
48
+ batch = process_teacher_sample(processor, teacher_text, pil_images)
49
+
50
+ opsd_debug.log(
51
+ "teacher_prompt",
52
+ "tokenize_teacher_prompt result",
53
+ input_ids_shape=tuple(batch["input_ids"].shape),
54
+ has_pixel_values="pixel_values" in batch,
55
+ pixel_values_shape=tuple(batch["pixel_values"].shape) if "pixel_values" in batch else None,
56
+ image_token_count=count_image_tokens(batch["input_ids"], processor),
57
+ )
58
+ return batch
59
+
60
+
61
+ def build_teacher_prompt_batch(
62
+ processor,
63
+ samples: list[dict[str, Any]],
64
+ indices: list[int],
65
+ provider_names: list[str],
66
+ device,
67
+ *,
68
+ opsd_config: Optional[dict[str, Any]] = None,
69
+ global_step: Optional[int] = None,
70
+ output_dir: Optional[str] = None,
71
+ ) -> dict[str, Any]:
72
+ """Build padded teacher prompt tensors for OPSD samples at given indices."""
73
+ opsd_config = opsd_config or {}
74
+ privileged_profile = opsd_config.get("privileged_profile", "hybrid")
75
+ crop_cfg = opsd_config.get("privileged_image") or {}
76
+ privileged_debug_cfg = opsd_config.get("privileged_debug") or {}
77
+
78
+ opsd_debug.log(
79
+ "teacher_prompt",
80
+ "build_teacher_prompt_batch enter",
81
+ num_indices=len(indices),
82
+ indices=indices,
83
+ num_samples=len(samples),
84
+ provider_names=provider_names,
85
+ privileged_profile=privileged_profile,
86
+ device=str(device),
87
+ global_step=global_step,
88
+ )
89
+
90
+ if not indices:
91
+ opsd_debug.log("teacher_prompt", "empty indices, return {}")
92
+ return {}
93
+
94
+ sample_payloads: list[dict[str, Any]] = []
95
+ for idx in indices:
96
+ sample = samples[idx]
97
+ suffix, teacher_images = build_privileged_context(
98
+ sample,
99
+ provider_names,
100
+ privileged_profile=privileged_profile,
101
+ crop_cfg=crop_cfg,
102
+ opsd_config=opsd_config,
103
+ )
104
+ if not teacher_images:
105
+ from opsd_utils.privileged.image_utils import load_rgb
106
+
107
+ full = load_rgb(sample.get("image"))
108
+ teacher_images = [full] if full is not None else []
109
+
110
+ full_img = teacher_images[0] if teacher_images else None
111
+ crop_img = teacher_images[1] if len(teacher_images) > 1 else None
112
+ maybe_save_privileged_images(
113
+ global_step,
114
+ idx,
115
+ full_img,
116
+ crop_img,
117
+ meta={
118
+ "privileged_profile": privileged_profile,
119
+ "num_teacher_images": len(teacher_images),
120
+ "suffix_len": len(suffix.strip()),
121
+ },
122
+ output_dir=output_dir,
123
+ privileged_debug_cfg=privileged_debug_cfg,
124
+ )
125
+
126
+ teacher_text = _build_teacher_text(sample["prompt"], suffix)
127
+ sample_payloads.append(
128
+ {
129
+ "teacher_text": teacher_text,
130
+ "images": teacher_images,
131
+ "suffix_len": len(suffix.strip()),
132
+ "num_teacher_images": len(teacher_images),
133
+ }
134
+ )
135
+
136
+ batch = _build_teacher_batch_with_oom_retry(processor, sample_payloads)
137
+
138
+ out = {
139
+ "teacher_prompt_ids": batch["input_ids"].to(device),
140
+ "teacher_prompt_mask": batch["attention_mask"].to(device),
141
+ }
142
+ if batch.get("pixel_values_list"):
143
+ out["teacher_pixel_values_list"] = [pv.to(device) for pv in batch["pixel_values_list"]]
144
+ if batch.get("image_sizes_list"):
145
+ out["teacher_image_sizes_list"] = [sz.to(device) for sz in batch["image_sizes_list"]]
146
+
147
+ teacher_num_images = [int(max(0, n)) for n in batch.get("batch_num_images", [])]
148
+ if not teacher_num_images:
149
+ teacher_num_images = [p["num_teacher_images"] for p in sample_payloads]
150
+ out["teacher_num_images"] = torch.tensor(teacher_num_images, device=device, dtype=torch.long)
151
+
152
+ student_len = None
153
+ if indices and samples[indices[0]].get("prompt"):
154
+ student_messages = [
155
+ {
156
+ "role": "user",
157
+ "content": [{"type": "image"}, {"type": "text", "text": samples[indices[0]]["prompt"]}],
158
+ }
159
+ ]
160
+ student_text = processor.apply_chat_template(student_messages, add_generation_prompt=True)
161
+ student_len = len(processor(text=[student_text], return_tensors="pt")["input_ids"][0])
162
+
163
+ teacher_len = int(out["teacher_prompt_ids"].shape[1])
164
+ opsd_debug.log(
165
+ "teacher_prompt",
166
+ "build_teacher_prompt_batch done",
167
+ teacher_prompt_ids_shape=tuple(out["teacher_prompt_ids"].shape),
168
+ teacher_prompt_mask_shape=tuple(out["teacher_prompt_mask"].shape),
169
+ has_teacher_pixel_values=bool(out.get("teacher_pixel_values_list")),
170
+ teacher_pixel_values_shapes=[
171
+ tuple(pv.shape) for pv in out.get("teacher_pixel_values_list", [])[:4]
172
+ ],
173
+ teacher_images_count=sample_payloads[0]["num_teacher_images"] if sample_payloads else 0,
174
+ teacher_num_images=teacher_num_images,
175
+ teacher_image_token_counts=batch.get("image_token_counts"),
176
+ teacher_prompt_len=teacher_len,
177
+ vision_placeholder_delta=(teacher_len - student_len) if student_len else None,
178
+ )
179
+ opsd_debug.log_detail(
180
+ "teacher_prompt",
181
+ "teacher prompt batch built",
182
+ global_step=global_step,
183
+ batch_size=len(indices),
184
+ teacher_prompt_len=teacher_len,
185
+ teacher_pixel_values_shapes=[
186
+ tuple(pv.shape) for pv in out.get("teacher_pixel_values_list", [])[:4]
187
+ ],
188
+ teacher_image_token_counts=batch.get("image_token_counts"),
189
+ )
190
+
191
+ from opsd_utils.leakage import privileged_suffix_has_gold
192
+
193
+ vf_empty = 0
194
+ gold_suffix_count = 0
195
+ for idx in indices:
196
+ sample = samples[idx]
197
+ vf = (
198
+ sample.get("visual_fact_hint")
199
+ or sample.get("visual_fact")
200
+ or sample.get("visual_facts")
201
+ or ""
202
+ )
203
+ if not str(vf).strip():
204
+ vf_empty += 1
205
+ priv_suffix, _ = build_privileged_context(
206
+ sample,
207
+ provider_names,
208
+ privileged_profile=privileged_profile,
209
+ crop_cfg=crop_cfg,
210
+ opsd_config=opsd_config,
211
+ )
212
+ if privileged_suffix_has_gold(priv_suffix, sample):
213
+ gold_suffix_count += 1
214
+ suffix_lens = [p["suffix_len"] for p in sample_payloads]
215
+ n_idx = max(len(indices), 1)
216
+ out["teacher_stats"] = {
217
+ "teacher_suffix_len_mean": float(sum(suffix_lens) / len(suffix_lens)) if suffix_lens else 0.0,
218
+ "visual_fact_empty_rate": vf_empty / n_idx,
219
+ "privileged_suffix_has_gold_rate": gold_suffix_count / n_idx,
220
+ "num_teacher_images_mean": float(
221
+ sum(p["num_teacher_images"] for p in sample_payloads) / len(sample_payloads)
222
+ )
223
+ if sample_payloads
224
+ else 0.0,
225
+ }
226
+ return out
227
+
228
+
229
+ def _build_teacher_batch_with_oom_retry(
230
+ processor,
231
+ sample_payloads: list[dict[str, Any]],
232
+ ) -> dict:
233
+ """Process each teacher sample separately; on OOM halve micro-batch and retry."""
234
+ n = len(sample_payloads)
235
+ if n == 0:
236
+ return {}
237
+ micro = n
238
+ while micro >= 1:
239
+ try:
240
+ per_sample_batches: list[dict[str, Any]] = []
241
+ for start in range(0, n, micro):
242
+ end = min(start + micro, n)
243
+ for payload in sample_payloads[start:end]:
244
+ per_sample_batches.append(
245
+ process_teacher_sample(
246
+ processor,
247
+ payload["teacher_text"],
248
+ payload["images"],
249
+ )
250
+ )
251
+ return stack_teacher_processor_batches(processor, per_sample_batches)
252
+ except RuntimeError as exc:
253
+ if "out of memory" not in str(exc).lower() or micro == 1:
254
+ raise
255
+ opsd_debug.log(
256
+ "teacher_forward_oom",
257
+ "teacher prompt batch OOM, halving micro-batch",
258
+ original_batch=n,
259
+ micro_batch_size=micro,
260
+ new_micro_batch_size=max(1, micro // 2),
261
+ )
262
+ if torch.cuda.is_available():
263
+ torch.cuda.empty_cache()
264
+ micro = max(1, micro // 2)
265
+ return {}
outputs/logs/.ipynb_checkpoints/train_opd_7b_ds_20260614_175014-checkpoint.log ADDED
The diff for this file is too large to render. See raw diff
 
outputs/opd-7b-chartqa-ds/checkpoint-1764/zero_to_fp32.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
+ @dataclass
40
+ class zero_model_state:
41
+ buffers: dict()
42
+ param_shapes: dict()
43
+ shared_params: list
44
+ ds_version: int
45
+ frozen_param_shapes: dict()
46
+ frozen_param_fragments: dict()
47
+
48
+
49
+ debug = 0
50
+
51
+ # load to cpu
52
+ device = torch.device('cpu')
53
+
54
+
55
+ def atoi(text):
56
+ return int(text) if text.isdigit() else text
57
+
58
+
59
+ def natural_keys(text):
60
+ '''
61
+ alist.sort(key=natural_keys) sorts in human order
62
+ http://nedbatchelder.com/blog/200712/human_sorting.html
63
+ (See Toothy's implementation in the comments)
64
+ '''
65
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
66
+
67
+
68
+ def get_model_state_file(checkpoint_dir, zero_stage):
69
+ if not os.path.isdir(checkpoint_dir):
70
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
71
+
72
+ # there should be only one file
73
+ if zero_stage <= 2:
74
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
75
+ elif zero_stage == 3:
76
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
77
+
78
+ if not os.path.exists(file):
79
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
80
+
81
+ return file
82
+
83
+
84
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
85
+ # XXX: need to test that this simple glob rule works for multi-node setup too
86
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
87
+
88
+ if len(ckpt_files) == 0:
89
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
90
+
91
+ return ckpt_files
92
+
93
+
94
+ def get_optim_files(checkpoint_dir):
95
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
96
+
97
+
98
+ def get_model_state_files(checkpoint_dir):
99
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
100
+
101
+
102
+ def parse_model_states(files):
103
+ zero_model_states = []
104
+ for file in files:
105
+ state_dict = torch.load(file, map_location=device, weights_only=False)
106
+
107
+ if BUFFER_NAMES not in state_dict:
108
+ raise ValueError(f"{file} is not a model state checkpoint")
109
+ buffer_names = state_dict[BUFFER_NAMES]
110
+ if debug:
111
+ print("Found buffers:", buffer_names)
112
+
113
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
114
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
115
+ param_shapes = state_dict[PARAM_SHAPES]
116
+
117
+ # collect parameters that are included in param_shapes
118
+ param_names = []
119
+ for s in param_shapes:
120
+ for name in s.keys():
121
+ param_names.append(name)
122
+
123
+ # update with frozen parameters
124
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
125
+ if frozen_param_shapes is not None:
126
+ if debug:
127
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
128
+ param_names += list(frozen_param_shapes.keys())
129
+
130
+ # handle shared params
131
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
132
+
133
+ ds_version = state_dict.get(DS_VERSION, None)
134
+
135
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
136
+
137
+ z_model_state = zero_model_state(buffers=buffers,
138
+ param_shapes=param_shapes,
139
+ shared_params=shared_params,
140
+ ds_version=ds_version,
141
+ frozen_param_shapes=frozen_param_shapes,
142
+ frozen_param_fragments=frozen_param_fragments)
143
+ zero_model_states.append(z_model_state)
144
+
145
+ return zero_model_states
146
+
147
+
148
+ def parse_optim_states(files, ds_checkpoint_dir):
149
+ total_files = len(files)
150
+ state_dicts = []
151
+ for f in tqdm(files, desc='Loading checkpoint shards'):
152
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
153
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
154
+ # and also handle the case where it was already removed by another helper script
155
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
156
+ state_dicts.append(state_dict)
157
+
158
+ if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
159
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
160
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
161
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
162
+
163
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
164
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
165
+ # use the max of the partition_count to get the dp world_size.
166
+
167
+ if type(world_size) is list:
168
+ world_size = max(world_size)
169
+
170
+ if world_size != total_files:
171
+ raise ValueError(
172
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
173
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
174
+ )
175
+
176
+ # the groups are named differently in each stage
177
+ if zero_stage <= 2:
178
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
179
+ elif zero_stage == 3:
180
+ fp32_groups_key = FP32_FLAT_GROUPS
181
+ else:
182
+ raise ValueError(f"unknown zero stage {zero_stage}")
183
+
184
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
185
+ return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
189
+ """
190
+ Returns fp32 state_dict reconstructed from ds checkpoint
191
+
192
+ Args:
193
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
194
+
195
+ """
196
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
197
+
198
+ optim_files = get_optim_files(ds_checkpoint_dir)
199
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
200
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
201
+
202
+ model_files = get_model_state_files(ds_checkpoint_dir)
203
+
204
+ zero_model_states = parse_model_states(model_files)
205
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
206
+
207
+ if zero_stage <= 2:
208
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
209
+ exclude_frozen_parameters)
210
+ elif zero_stage == 3:
211
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
212
+ exclude_frozen_parameters)
213
+
214
+
215
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
216
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
217
+ return
218
+
219
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
220
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
221
+
222
+ if debug:
223
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
224
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
225
+
226
+ wanted_params = len(frozen_param_shapes)
227
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
228
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
229
+ print(f'Frozen params: Have {avail_numel} numels to process.')
230
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
231
+
232
+ total_params = 0
233
+ total_numel = 0
234
+ for name, shape in frozen_param_shapes.items():
235
+ total_params += 1
236
+ unpartitioned_numel = shape.numel()
237
+ total_numel += unpartitioned_numel
238
+
239
+ state_dict[name] = frozen_param_fragments[name]
240
+
241
+ if debug:
242
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
243
+
244
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
253
+ param_shapes = zero_model_states[0].param_shapes
254
+
255
+ # Reconstruction protocol:
256
+ #
257
+ # XXX: document this
258
+
259
+ if debug:
260
+ for i in range(world_size):
261
+ for j in range(len(fp32_flat_groups[0])):
262
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
263
+
264
+ # XXX: memory usage doubles here (zero2)
265
+ num_param_groups = len(fp32_flat_groups[0])
266
+ merged_single_partition_of_fp32_groups = []
267
+ for i in range(num_param_groups):
268
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
269
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
270
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
271
+ avail_numel = sum(
272
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
273
+
274
+ if debug:
275
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
276
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
277
+ # not asserting if there is a mismatch due to possible padding
278
+ print(f"Have {avail_numel} numels to process.")
279
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
280
+
281
+ # params
282
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
283
+ # out-of-core computing solution
284
+ total_numel = 0
285
+ total_params = 0
286
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
287
+ offset = 0
288
+ avail_numel = full_single_fp32_vector.numel()
289
+ for name, shape in shapes.items():
290
+
291
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
292
+ total_numel += unpartitioned_numel
293
+ total_params += 1
294
+
295
+ if debug:
296
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
297
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
298
+ offset += unpartitioned_numel
299
+
300
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
301
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
302
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
303
+ # live optimizer object, so we are checking that the numbers are within the right range
304
+ align_to = 2 * world_size
305
+
306
+ def zero2_align(x):
307
+ return align_to * math.ceil(x / align_to)
308
+
309
+ if debug:
310
+ print(f"original offset={offset}, avail_numel={avail_numel}")
311
+
312
+ offset = zero2_align(offset)
313
+ avail_numel = zero2_align(avail_numel)
314
+
315
+ if debug:
316
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
317
+
318
+ # Sanity check
319
+ if offset != avail_numel:
320
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
321
+
322
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
326
+ exclude_frozen_parameters):
327
+ state_dict = OrderedDict()
328
+
329
+ # buffers
330
+ buffers = zero_model_states[0].buffers
331
+ state_dict.update(buffers)
332
+ if debug:
333
+ print(f"added {len(buffers)} buffers")
334
+
335
+ if not exclude_frozen_parameters:
336
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
337
+
338
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
339
+
340
+ # recover shared parameters
341
+ for pair in zero_model_states[0].shared_params:
342
+ if pair[1] in state_dict:
343
+ state_dict[pair[0]] = state_dict[pair[1]]
344
+
345
+ return state_dict
346
+
347
+
348
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
349
+ remainder = unpartitioned_numel % world_size
350
+ padding_numel = (world_size - remainder) if remainder else 0
351
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
352
+ return partitioned_numel, padding_numel
353
+
354
+
355
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
356
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
357
+ return
358
+
359
+ if debug:
360
+ for i in range(world_size):
361
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
362
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
363
+
364
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
365
+ wanted_params = len(frozen_param_shapes)
366
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
367
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
368
+ print(f'Frozen params: Have {avail_numel} numels to process.')
369
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
370
+
371
+ total_params = 0
372
+ total_numel = 0
373
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
374
+ total_params += 1
375
+ unpartitioned_numel = shape.numel()
376
+ total_numel += unpartitioned_numel
377
+
378
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
379
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
380
+
381
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
382
+
383
+ if debug:
384
+ print(
385
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
386
+ )
387
+
388
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
+ class GatheredTensor:
392
+ """
393
+ A pseudo tensor that collects partitioned weights.
394
+ It is more memory efficient when there are multiple groups.
395
+ """
396
+
397
+ def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
398
+ self.flat_groups = flat_groups
399
+ self.flat_groups_offset = flat_groups_offset
400
+ self.offset = offset
401
+ self.partitioned_numel = partitioned_numel
402
+ self.shape = shape
403
+ self.dtype = self.flat_groups[0][0].dtype
404
+
405
+ def contiguous(self):
406
+ """
407
+ Merge partitioned weights from flat_groups into a single tensor.
408
+ """
409
+ end_idx = self.offset + self.partitioned_numel
410
+ world_size = len(self.flat_groups)
411
+ pad_flat_param_chunks = []
412
+
413
+ for rank_i in range(world_size):
414
+ # for each rank, we need to collect weights from related group/groups
415
+ flat_groups_at_rank_i = self.flat_groups[rank_i]
416
+ start_group_id = None
417
+ end_group_id = None
418
+ for group_id in range(len(self.flat_groups_offset)):
419
+ if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
420
+ start_group_id = group_id
421
+ if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
422
+ end_group_id = group_id
423
+ break
424
+ # collect weights from related group/groups
425
+ for group_id in range(start_group_id, end_group_id + 1):
426
+ flat_tensor = flat_groups_at_rank_i[group_id]
427
+ start_offset = self.offset - self.flat_groups_offset[group_id]
428
+ end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
429
+ pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
430
+
431
+ # collect weights from all ranks
432
+ pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
433
+ param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
434
+ return param
435
+
436
+
437
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
438
+ param_shapes = zero_model_states[0].param_shapes
439
+ avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
440
+
441
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
442
+ # param, re-consolidating each param, while dealing with padding if any
443
+
444
+ # merge list of dicts, preserving order
445
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
446
+
447
+ if debug:
448
+ for i in range(world_size):
449
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
450
+
451
+ wanted_params = len(param_shapes)
452
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
453
+ # not asserting if there is a mismatch due to possible padding
454
+ avail_numel = fp32_flat_groups[0].numel() * world_size
455
+ print(f"Trainable params: Have {avail_numel} numels to process.")
456
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
457
+
458
+ # params
459
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
460
+ # out-of-core computing solution
461
+ offset = 0
462
+ total_numel = 0
463
+ total_params = 0
464
+ flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
465
+ for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
466
+ unpartitioned_numel = shape.numel()
467
+ total_numel += unpartitioned_numel
468
+ total_params += 1
469
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
470
+
471
+ if debug:
472
+ print(
473
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
474
+ )
475
+
476
+ # memory efficient tensor
477
+ tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
478
+ state_dict[name] = tensor
479
+ offset += partitioned_numel
480
+
481
+ offset *= world_size
482
+
483
+ # Sanity check
484
+ if offset != avail_numel:
485
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
486
+
487
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
491
+ exclude_frozen_parameters):
492
+ state_dict = OrderedDict()
493
+
494
+ # buffers
495
+ buffers = zero_model_states[0].buffers
496
+ state_dict.update(buffers)
497
+ if debug:
498
+ print(f"added {len(buffers)} buffers")
499
+
500
+ if not exclude_frozen_parameters:
501
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
502
+
503
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
504
+
505
+ # recover shared parameters
506
+ for pair in zero_model_states[0].shared_params:
507
+ if pair[1] in state_dict:
508
+ state_dict[pair[0]] = state_dict[pair[1]]
509
+
510
+ return state_dict
511
+
512
+
513
+ def to_torch_tensor(state_dict, return_empty_tensor=False):
514
+ """
515
+ Convert state_dict of GatheredTensor to torch tensor
516
+ """
517
+ torch_state_dict = {}
518
+ converted_tensors = {}
519
+ for name, tensor in state_dict.items():
520
+ tensor_id = id(tensor)
521
+ if tensor_id in converted_tensors: # shared tensors
522
+ shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
523
+ torch_state_dict[name] = shared_tensor
524
+ else:
525
+ converted_tensors[tensor_id] = name
526
+ if return_empty_tensor:
527
+ torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
528
+ else:
529
+ torch_state_dict[name] = tensor.contiguous()
530
+ return torch_state_dict
531
+
532
+
533
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
534
+ tag=None,
535
+ exclude_frozen_parameters=False,
536
+ lazy_mode=False):
537
+ """
538
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
539
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
540
+ via a model hub.
541
+
542
+ Args:
543
+ - ``checkpoint_dir``: path to the desired checkpoint folder
544
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
545
+ - ``exclude_frozen_parameters``: exclude frozen parameters
546
+ - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient.
547
+ Convert the pesduo tensor to torch tensor by ``.contiguous()``
548
+
549
+ Returns:
550
+ - pytorch ``state_dict``
551
+
552
+ A typical usage might be ::
553
+
554
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
555
+ # do the training and checkpoint saving
556
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
557
+ model = model.cpu() # move to cpu
558
+ model.load_state_dict(state_dict)
559
+ # submit to model hub or save the model to share with others
560
+
561
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
562
+ application. i.e. you will need to re-initialize the deepspeed engine, since
563
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
564
+
565
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
566
+
567
+ Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
568
+ You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
569
+ the checkpoint. Or you can load state_dict in lazy mode ::
570
+
571
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
572
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
573
+ for name, lazy_tensor in state_dict.item():
574
+ tensor = lazy_tensor.contiguous() # to cpu
575
+ print(name, tensor)
576
+ # del tensor to release memory if it no longer in use
577
+ """
578
+ if tag is None:
579
+ latest_path = os.path.join(checkpoint_dir, 'latest')
580
+ if os.path.isfile(latest_path):
581
+ with open(latest_path, 'r') as fd:
582
+ tag = fd.read().strip()
583
+ else:
584
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
585
+
586
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
587
+
588
+ if not os.path.isdir(ds_checkpoint_dir):
589
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
590
+
591
+ state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
592
+ if lazy_mode:
593
+ return state_dict
594
+ else:
595
+ return to_torch_tensor(state_dict)
596
+
597
+
598
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
599
+ output_dir,
600
+ max_shard_size="5GB",
601
+ safe_serialization=False,
602
+ tag=None,
603
+ exclude_frozen_parameters=False):
604
+ """
605
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
606
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
607
+
608
+ Args:
609
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
610
+ - ``output_dir``: directory to the pytorch fp32 state_dict output files
611
+ - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
612
+ - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
613
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
614
+ - ``exclude_frozen_parameters``: exclude frozen parameters
615
+ """
616
+
617
+ # Dependency pre-check
618
+ if safe_serialization:
619
+ try:
620
+ from safetensors.torch import save_file
621
+ except ImportError:
622
+ print('If you want to use `safe_serialization`, please `pip install safetensors`')
623
+ raise
624
+ if max_shard_size is not None:
625
+ try:
626
+ from huggingface_hub import split_torch_state_dict_into_shards
627
+ except ImportError:
628
+ print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
629
+ raise
630
+
631
+ # Convert zero checkpoint to state_dict
632
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
633
+ tag,
634
+ exclude_frozen_parameters,
635
+ lazy_mode=True)
636
+
637
+ # Shard the model if it is too big.
638
+ weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
639
+ if max_shard_size is not None:
640
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
641
+ # an memory-efficient approach for sharding
642
+ empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
643
+ state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
644
+ filename_pattern=filename_pattern,
645
+ max_shard_size=max_shard_size)
646
+ else:
647
+ from collections import namedtuple
648
+ StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
649
+ state_dict_split = StateDictSplit(is_sharded=False,
650
+ filename_to_tensors={weights_name: list(state_dict.keys())})
651
+
652
+ # Save the model by shard
653
+ os.makedirs(output_dir, exist_ok=True)
654
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
655
+ for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
656
+ shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
657
+ shard_state_dict = to_torch_tensor(shard_state_dict)
658
+ output_path = os.path.join(output_dir, shard_file)
659
+ if safe_serialization:
660
+ save_file(shard_state_dict, output_path, metadata={"format": "pt"})
661
+ else:
662
+ torch.save(shard_state_dict, output_path)
663
+ # release the memory of current shard
664
+ for tensor_name in list(shard_state_dict.keys()):
665
+ del state_dict[tensor_name]
666
+ del shard_state_dict[tensor_name]
667
+ del shard_state_dict
668
+ gc.collect()
669
+
670
+ # Save index if sharded
671
+ if state_dict_split.is_sharded:
672
+ index = {
673
+ "metadata": state_dict_split.metadata,
674
+ "weight_map": state_dict_split.tensor_to_filename,
675
+ }
676
+ save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
677
+ save_index_file = os.path.join(output_dir, save_index_file)
678
+ with open(save_index_file, "w", encoding="utf-8") as f:
679
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
680
+ f.write(content)
681
+
682
+
683
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
684
+ """
685
+ 1. Put the provided model to cpu
686
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
687
+ 3. Load it into the provided model
688
+
689
+ Args:
690
+ - ``model``: the model object to update
691
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
692
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
693
+
694
+ Returns:
695
+ - ``model`: modified model
696
+
697
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
698
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
699
+ conveniently placed for you in the checkpoint folder.
700
+
701
+ A typical usage might be ::
702
+
703
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
704
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
705
+ # submit to model hub or save the model to share with others
706
+
707
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
708
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
709
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
710
+
711
+ """
712
+ logger.info("Extracting fp32 weights")
713
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
714
+
715
+ logger.info("Overwriting model with fp32 weights")
716
+ model = model.cpu()
717
+ model.load_state_dict(state_dict, strict=False)
718
+
719
+ return model
720
+
721
+
722
+ if __name__ == "__main__":
723
+ parser = argparse.ArgumentParser()
724
+ parser.add_argument("checkpoint_dir",
725
+ type=str,
726
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
727
+ parser.add_argument("output_dir",
728
+ type=str,
729
+ help="directory to the pytorch fp32 state_dict output files"
730
+ "(e.g. path/checkpoint-12-output/)")
731
+ parser.add_argument(
732
+ "--max_shard_size",
733
+ type=str,
734
+ default="5GB",
735
+ help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
736
+ "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
737
+ "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
738
+ "without CPU OOM issues.")
739
+ parser.add_argument(
740
+ "--safe_serialization",
741
+ default=False,
742
+ action='store_true',
743
+ help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
744
+ parser.add_argument("-t",
745
+ "--tag",
746
+ type=str,
747
+ default=None,
748
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
749
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
750
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
751
+ args = parser.parse_args()
752
+
753
+ debug = args.debug
754
+
755
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
756
+ args.output_dir,
757
+ max_shard_size=args.max_shard_size,
758
+ safe_serialization=args.safe_serialization,
759
+ tag=args.tag,
760
+ exclude_frozen_parameters=args.exclude_frozen_parameters)
outputs/opd-7b-chartqa-ds/checkpoint-2352/preprocessor_config.json ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_convert_rgb": true,
3
+ "do_normalize": true,
4
+ "do_pad": true,
5
+ "do_rescale": true,
6
+ "do_resize": true,
7
+ "image_grid_pinpoints": [
8
+ [
9
+ 384,
10
+ 384
11
+ ],
12
+ [
13
+ 384,
14
+ 768
15
+ ],
16
+ [
17
+ 384,
18
+ 1152
19
+ ],
20
+ [
21
+ 384,
22
+ 1536
23
+ ],
24
+ [
25
+ 384,
26
+ 1920
27
+ ],
28
+ [
29
+ 384,
30
+ 2304
31
+ ],
32
+ [
33
+ 768,
34
+ 384
35
+ ],
36
+ [
37
+ 768,
38
+ 768
39
+ ],
40
+ [
41
+ 768,
42
+ 1152
43
+ ],
44
+ [
45
+ 768,
46
+ 1536
47
+ ],
48
+ [
49
+ 768,
50
+ 1920
51
+ ],
52
+ [
53
+ 768,
54
+ 2304
55
+ ],
56
+ [
57
+ 1152,
58
+ 384
59
+ ],
60
+ [
61
+ 1152,
62
+ 768
63
+ ],
64
+ [
65
+ 1152,
66
+ 1152
67
+ ],
68
+ [
69
+ 1152,
70
+ 1536
71
+ ],
72
+ [
73
+ 1152,
74
+ 1920
75
+ ],
76
+ [
77
+ 1152,
78
+ 2304
79
+ ],
80
+ [
81
+ 1536,
82
+ 384
83
+ ],
84
+ [
85
+ 1536,
86
+ 768
87
+ ],
88
+ [
89
+ 1536,
90
+ 1152
91
+ ],
92
+ [
93
+ 1536,
94
+ 1536
95
+ ],
96
+ [
97
+ 1536,
98
+ 1920
99
+ ],
100
+ [
101
+ 1536,
102
+ 2304
103
+ ],
104
+ [
105
+ 1920,
106
+ 384
107
+ ],
108
+ [
109
+ 1920,
110
+ 768
111
+ ],
112
+ [
113
+ 1920,
114
+ 1152
115
+ ],
116
+ [
117
+ 1920,
118
+ 1536
119
+ ],
120
+ [
121
+ 1920,
122
+ 1920
123
+ ],
124
+ [
125
+ 1920,
126
+ 2304
127
+ ],
128
+ [
129
+ 2304,
130
+ 384
131
+ ],
132
+ [
133
+ 2304,
134
+ 768
135
+ ],
136
+ [
137
+ 2304,
138
+ 1152
139
+ ],
140
+ [
141
+ 2304,
142
+ 1536
143
+ ],
144
+ [
145
+ 2304,
146
+ 1920
147
+ ],
148
+ [
149
+ 2304,
150
+ 2304
151
+ ]
152
+ ],
153
+ "image_mean": [
154
+ 0.5,
155
+ 0.5,
156
+ 0.5
157
+ ],
158
+ "image_processor_type": "LlavaOnevisionImageProcessor",
159
+ "image_std": [
160
+ 0.5,
161
+ 0.5,
162
+ 0.5
163
+ ],
164
+ "processor_class": "LlavaOnevisionProcessor",
165
+ "resample": 3,
166
+ "rescale_factor": 0.00392156862745098,
167
+ "size": {
168
+ "height": 384,
169
+ "width": 384
170
+ }
171
+ }
outputs/opd-7b-chartqa-ds/checkpoint-588/config.json ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlavaOnevisionForConditionalGeneration"
4
+ ],
5
+ "dtype": "bfloat16",
6
+ "eos_token_id": 151645,
7
+ "ignore_index": -100,
8
+ "image_grid_pinpoints": [
9
+ [
10
+ 384,
11
+ 384
12
+ ],
13
+ [
14
+ 384,
15
+ 768
16
+ ],
17
+ [
18
+ 384,
19
+ 1152
20
+ ],
21
+ [
22
+ 384,
23
+ 1536
24
+ ],
25
+ [
26
+ 384,
27
+ 1920
28
+ ],
29
+ [
30
+ 384,
31
+ 2304
32
+ ],
33
+ [
34
+ 768,
35
+ 384
36
+ ],
37
+ [
38
+ 768,
39
+ 768
40
+ ],
41
+ [
42
+ 768,
43
+ 1152
44
+ ],
45
+ [
46
+ 768,
47
+ 1536
48
+ ],
49
+ [
50
+ 768,
51
+ 1920
52
+ ],
53
+ [
54
+ 768,
55
+ 2304
56
+ ],
57
+ [
58
+ 1152,
59
+ 384
60
+ ],
61
+ [
62
+ 1152,
63
+ 768
64
+ ],
65
+ [
66
+ 1152,
67
+ 1152
68
+ ],
69
+ [
70
+ 1152,
71
+ 1536
72
+ ],
73
+ [
74
+ 1152,
75
+ 1920
76
+ ],
77
+ [
78
+ 1152,
79
+ 2304
80
+ ],
81
+ [
82
+ 1536,
83
+ 384
84
+ ],
85
+ [
86
+ 1536,
87
+ 768
88
+ ],
89
+ [
90
+ 1536,
91
+ 1152
92
+ ],
93
+ [
94
+ 1536,
95
+ 1536
96
+ ],
97
+ [
98
+ 1536,
99
+ 1920
100
+ ],
101
+ [
102
+ 1536,
103
+ 2304
104
+ ],
105
+ [
106
+ 1920,
107
+ 384
108
+ ],
109
+ [
110
+ 1920,
111
+ 768
112
+ ],
113
+ [
114
+ 1920,
115
+ 1152
116
+ ],
117
+ [
118
+ 1920,
119
+ 1536
120
+ ],
121
+ [
122
+ 1920,
123
+ 1920
124
+ ],
125
+ [
126
+ 1920,
127
+ 2304
128
+ ],
129
+ [
130
+ 2304,
131
+ 384
132
+ ],
133
+ [
134
+ 2304,
135
+ 768
136
+ ],
137
+ [
138
+ 2304,
139
+ 1152
140
+ ],
141
+ [
142
+ 2304,
143
+ 1536
144
+ ],
145
+ [
146
+ 2304,
147
+ 1920
148
+ ],
149
+ [
150
+ 2304,
151
+ 2304
152
+ ]
153
+ ],
154
+ "image_token_index": 151646,
155
+ "model_type": "llava_onevision",
156
+ "multimodal_projector_bias": true,
157
+ "pad_token_id": 151643,
158
+ "projector_hidden_act": "gelu",
159
+ "text_config": {
160
+ "_name_or_path": "Qwen/Qwen2-0.5B-Instruct",
161
+ "architectures": [
162
+ "Qwen2ForCausalLM"
163
+ ],
164
+ "attention_dropout": 0.0,
165
+ "bos_token_id": 151643,
166
+ "dtype": "bfloat16",
167
+ "eos_token_id": 151645,
168
+ "hidden_act": "silu",
169
+ "hidden_size": 896,
170
+ "initializer_range": 0.02,
171
+ "intermediate_size": 4864,
172
+ "layer_types": [
173
+ "full_attention",
174
+ "full_attention",
175
+ "full_attention",
176
+ "full_attention",
177
+ "full_attention",
178
+ "full_attention",
179
+ "full_attention",
180
+ "full_attention",
181
+ "full_attention",
182
+ "full_attention",
183
+ "full_attention",
184
+ "full_attention",
185
+ "full_attention",
186
+ "full_attention",
187
+ "full_attention",
188
+ "full_attention",
189
+ "full_attention",
190
+ "full_attention",
191
+ "full_attention",
192
+ "full_attention",
193
+ "full_attention",
194
+ "full_attention",
195
+ "full_attention",
196
+ "full_attention"
197
+ ],
198
+ "max_position_embeddings": 32768,
199
+ "max_window_layers": 24,
200
+ "model_type": "qwen2",
201
+ "num_attention_heads": 14,
202
+ "num_hidden_layers": 24,
203
+ "num_key_value_heads": 2,
204
+ "rms_norm_eps": 1e-06,
205
+ "rope_scaling": null,
206
+ "rope_theta": 1000000.0,
207
+ "sliding_window": null,
208
+ "tie_word_embeddings": true,
209
+ "use_cache": true,
210
+ "use_sliding_window": false,
211
+ "vocab_size": 152000
212
+ },
213
+ "tie_word_embeddings": false,
214
+ "transformers_version": "4.57.1",
215
+ "use_image_newline_parameter": true,
216
+ "video_token_index": 151647,
217
+ "vision_aspect_ratio": "anyres_max_9",
218
+ "vision_config": {
219
+ "attention_dropout": 0.0,
220
+ "dtype": "bfloat16",
221
+ "hidden_act": "gelu_pytorch_tanh",
222
+ "hidden_size": 1152,
223
+ "image_size": 384,
224
+ "intermediate_size": 4304,
225
+ "layer_norm_eps": 1e-06,
226
+ "model_type": "siglip_vision_model",
227
+ "num_attention_heads": 16,
228
+ "num_channels": 3,
229
+ "num_hidden_layers": 26,
230
+ "patch_size": 14,
231
+ "vision_use_head": false
232
+ },
233
+ "vision_feature_layer": -1,
234
+ "vision_feature_select_strategy": "full"
235
+ }
papers/full_text.txt ADDED
@@ -0,0 +1,1211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ --- PAGE 1 ---
3
+ PublishedasaconferencepaperatICLR2026
4
+ EMPOWERING SMALL VLMS TO THINK WITH DY-
5
+ NAMIC MEMORIZATION AND EXPLORATION
6
+ JiazhenLiu ,YuchuanDeng ,andLongChen∗
7
+ TheHongKongUniversityofScienceandTechnology
8
+ https://github.com/HKUST-LongGroup/DyME
9
+ ABSTRACT
10
+ Small-scaleVision–LanguageModels(SVLMs)areexceptionallywell-suitedfor
11
+ proprietarytasks. Equippingthemwiththinkingcapabilitiesisacriticalstepto
12
+ enhance their performance and reliability in these specific domains. However,
13
+ existingtrainingparadigms,includingSupervisedFine-Tuning(SFT)andRein-
14
+ forcementLearningwithVerifiableReward(RLVR),imposesubstantialdemands
15
+ onthebaseVLM,exceedingthecapacityofSVLMs. Consequently,directlyap-
16
+ plyingtheseparadigmstoSVLMsfailstoinstillthedesiredthinkingabilities. A
17
+ natural solution is to combine SFT and RLVR, leveraging their complementar-
18
+ ity to reduce the dependence on model capacity. Yet the core challenge lies in
19
+ managingtheinherenttrade-off: excessiverelianceonSFTcanforcethemodel
20
+ tomemorizepseudothinkingtraces,whileover-emphasizingRLVRcanleadto
21
+ unstableexploration(i.e.,advantagecollapse). Toaddressthis,weproposeDyME,
22
+ anoveltrainingparadigmthatDynamicallyselectsbetweenMemorization(via
23
+ SFT) and Exploration (via RLVR) at each optimization step. By ensuring that
24
+ every update contributes to the trade-off, DyME serves as a robust, standalone
25
+ strategythatstabilizesSVLMlearning. Complementingthisparadigm,wefurther
26
+ introduceasynergisticVisualSupervisionmechanism(comprisingavisualchecker
27
+ andrefiner)designedtoinjectdynamicallyenhanced,image-groundedguidance
28
+ duringoptimization. Extensiveexperimentsacrossdiversedomainsdemonstrate
29
+ that DyME consistently achieves this balance, and thus delivers substantial per-
30
+ formance improvements on specialized tasks. These results establish DyME as
31
+ apracticalandeffectivesolutionforempoweringSVLMswithreliablethinking
32
+ capabilities.
33
+ 1 INTRODUCTION
34
+ EquippingVision–LanguageModels(VLMs)withthinkingcapabilitiesisapivotalstepthatmoves
35
+ thembeyondrecognitiontowardreasoning. Recentstudieshaveadvancedthisgoalthroughspe-
36
+ cializedtraining,achievingstrongresultsonaspectrumofvisualtasks,fromrecognition-intensive
37
+ applicationslikegrounding(Laietal.,2025;Liu&Chen,2025;Pengetal.,2025;Liuetal.,2025c;a)
38
+ toreasoning-intensivechallengessuchaschartunderstanding(Zhangetal.,2025a;Xiaetal.,2024)
39
+ andgeometricproblemsolving(Shenetal.,2025;Chenetal.,2025b;Xiaetal.,2025). Whilethis
40
+ progressissignificant,thesuccessoftheseapproachesiscontingentuponthebaseVLMpossessing
41
+ strongfoundationalcapabilities,namely,sufficientcapacityandrobustinstructionadherence(Yang
42
+ etal.,2025a). Inpractice,onlyahandfulofVLMsmeettheseprerequisites,presentingasignificant
43
+ challengeforSmall-scaleVLMs(SVLMs)whichstruggletodevelopthinkingcapabilitiesunder
44
+ existingtrainingparadigms.
45
+ Tocontextualizethislimitation,webrieflyreviewthetwodominantparadigms,bothofwhichare
46
+ primarilytailoredforLarge-scaleVLMs(LVLMs). 1)SupervisedFine-Tuning(SFT)onChain-of-
47
+ Thought(CoT)data(Xuetal.,2024;Lietal.,2024b;Xiaetal.,2025;Gaoetal.,2025): VLMsare
48
+ supervisedtomemorizepredefinedthinkingpatternsfromlarge-scaleCoTannotations. SinceCoT
49
+ dataareoftenverboseandcontainmuchvision-irrelevantcontent,modelsmustpossesssufficient
50
+ capacity to absorb long textual content without compromising visual grounding (Marafioti et al.,
51
+ ∗Correspondingauthor(longchen@ust.hk)
52
+ 1
53
+ 6202
54
+ beF
55
+ 72
56
+ ]VC.sc[
57
+ 2v16032.6052:viXra
58
+ --- PAGE 2 ---
59
+ PublishedasaconferencepaperatICLR2026
60
+ After CoT SFT training During RLVR training
61
+ (Input Image)
62
+ Rapid Decline in Brazilians' Which year shows a greater divergence of opinions about Brazil's Which year has the most divergent
63
+ Assessment of Economy economy, 2010 or 2012? Output the thinking process and then give opinions about Brazil's economy? Output
64
+ the final answer in <answer> </answer> tag. the thinking process and then give the
65
+ final answer in <answer> </answer> tag.
66
+ We locate 2010 and 2012. To answer the question, we To answer the question, we examine the
67
+ Then, extract 2010 has 62 first locate the relevant years chart and compare the value differences
68
+ and 36, 2012 has 65 and in the chart. Then, we extract for each year. In 2015, the values are 87
69
+ 35. Comparing 62 − 36 = 26 the values for each year. and 13, showing the most significant
70
+ < 65 − 35 = 30, the year with Compare the differences, the divergence. <answer>2015</answer>
71
+ greater divergence is 2012. year with greater divergence is
72
+ <answer>2012</answer> 2010. <answer>2010</answer> 2015 has the greatest divergence.
73
+ Large-scale VLM
74
+ LVLMs: vision preserved SVLMs: vision compromised SVLMs poorly follow instructions
75
+ Small-scale VLM
76
+ Grounded thinking traces Pseudo thinking traces (fail) Advantage collapsing (fail)
77
+ a) SFT and RL paradigms fail to enable SVLMs to think.
78
+ Two-stage training DyME training
79
+ Extent
80
+ 1. Locate the year; Constrained
81
+ 2. Extract the value; exploration RL RL
82
+ 3. Get the answer. RL safe to
83
+ <answer>2012</answer> explore
84
+ (hard to achieve) (single step)
85
+ SFT . t < . h . a e g n c r s h o w a u e r n r t d > .. e 2 . d A 0 1 v f 2 t a e < l r u / t a e h n i f n s r w o k m i e n r g > , RL c T ap h a in b k il i i n ti g e s Switcher Switcher c T ap h a in b k il i i n ti g es
86
+ (numerous steps)
87
+ 2012 has the greatest div- need to
88
+ e (n rg o e f n o c r e m . at answer) RL A d co v l a la n p ta s g e e memorize SFT SFT
89
+ b) Two-stage training vs. DyME
90
+ Figure1: TrainingparadigmsforenablingVLMthinking. TheLVLMisQwen2.5-VL-32B(Bai
91
+ etal.,2025)andtheSVLMisSmolVLM-500M(Marafiotietal.,2025). (a)Existingparadigmsare
92
+ effectiveforLVLMsbutunsuitableforSVLMs. (b)Thetwo-stagetrainingparadigm(SFT→RL)
93
+ facesachallengingtrade-off. OurproposedDyMEdynamicallybalancesthistrade-off.
94
+ 2025). This capability gap is illustrated in Fig. 1a: After SFT, LVLMs can generate grounded
95
+ thinkingtraceswithaccurateintermediatevalues(ingreen),whileSVLMscannot. 2)Reinforcement
96
+ LearningwithVerifiableReward(RLVR)(Zhangetal.,2025a;Chenetal.,2025b;Pengetal.,
97
+ 2025;Shenetal.,2025): ontheotherhand,promotesexplorationofthinkingpatternsratherthan
98
+ imitations. Inthisparadigm,VLMsareinstructedtogenerateathoughtprocessfollowedbyastrictly
99
+ formattedanswer(e.g.,enclosedintags). Thisformatenablesverifiablerewardstoreinforcecorrect
100
+ generationsandpenalizeincorrectones. Owingtoitsrelianceoninstructionadherence,thisapproach
101
+ ispracticalprimarilyforstrongVLMsthatcanreliablygeneratestructuredoutputs.
102
+ Consequently,bothestablishedparadigmsareinadequate
103
+ forinstillingthinkinginSVLMs. Theextremelylimited
104
+ Baseline
105
+ capacity(e.g.,under1Bparameters)ofSVLMsrenders Baseline +CoT SFT
106
+ +RLVR
107
+ the SFT paradigm ineffective, as a high volume of tex- +Two-stage
108
+ +CoT SFT +Ours
109
+ tual information in CoT data can overwhelm the capac-
110
+ ity(Marafiotietal.,2025;Chenetal.,2025a). Moreover, +RLVR
111
+ the limited instruction adherence of SVLMs frequently
112
+ +Two-stage
113
+ resultsinunverifiableoutputs(Chuetal.,2025;Guoetal.,
114
+ 2025),precipitatingadvantagecollapseduringRLVR.We +Ours
115
+ quantitatively verify these limitations (cf., Fig. 2): both
116
+ 35 40 45 50 55 60 65 70 75
117
+ SFTandRLVRparadigmsindeedimpairtheperformance.
118
+ Figure2: PerformanceofSmolVLM-
119
+ Considering that SVLMs offer high efficiency and are 500M (Marafioti et al., 2025) on
120
+ crucialfordeploymentonedgedevices(Marafiotietal., ChartQA (Masry et al., 2022). Ex-
121
+ 2025), enabling them to think addresses a strong prac- istingparadigmsdegradeperformance,
122
+ tical demand. Thinking enhances the reliability and whereasDyMEyieldsimprovements.
123
+ performance of vision tasks (Zhang et al., 2025a), and
124
+ task-specificSVLMsprovideacompellingalternativeto
125
+ LVLMsinresource-constrainedsettings. Thismotivatesthedevelopmentofanewtrainingparadigm
126
+ thatempowersSVLMswiththinkingcapabilities,atleastforspecializedtasks.
127
+ 2
128
+ --- PAGE 3 ---
129
+ PublishedasaconferencepaperatICLR2026
130
+ ApromisingsolutionistofuseSFTandRLVR,asawell-calibratedtrade-offcanlowerthehigh
131
+ demandsonthebasemodel(DeepSeek,Inc.,2025;Yanetal.,2025): SFTencouragesthemodelto
132
+ memorizeverifiablethinkingpatternstopreventadvantagecollapse,whileRLforcesexplorationto
133
+ preventrigidtemplatesfromoverwhelmingthemodel’scapacity. Thecentralchallenge,however,is
134
+ thatSVLMsstruggletoachievethisbalance. Existinghybridmethods,liketwo-stagetraining(Chen
135
+ etal.,2025a;Chuetal.,2025)orannealedSFTlosses1(Zhangetal.,2025b),relyonastatictrade-off
136
+ governedbyhyperparameterssetempirically. Thisrigidityisthecriticalflawbecausetheminimal
137
+ capacityofSVLMsmeansthewindowforasuccessfulstaticbalanceisincrediblynarrow,making
138
+ failurealmostinevitable(cf. Fig. 1b). Ourrepeatedtrialswithtwo-stagetrainingconfirmedthisissue,
139
+ withperformanceoftenfallingbelowthebaseline(cf. Fig. 2).
140
+ SVLMs therefore require a more intelligent paradigm to navigate this trade-off. To this end, we
141
+ proposeDyME(DynamicMemorize–Explore),whichintegratesSFTandRLVRthroughadynamic
142
+ switchingmechanism. AsillustratedinFig.1b,DyMEassessesthemodel’sgenerationateachstep
143
+ andadaptsitstrainingmodeaccordingly. Whenthemodelfailstofollowinstructions,itswitchestoa
144
+ memorizationmode(SFT)toguaranteestableoptimizationsignals. Conversely,forvalidgenerations,
145
+ itengagesanexplorationmode(RLVR)toencouragediverseandgroundedthinking.Thisstate-driven
146
+ approachensuresmemorizationandexplorationarealwayscomplementary,dynamicallymaintaining
147
+ thedelicatetrade-off. Whilethisdynamicswitchingaloneguaranteestrainingstability,wefurther
148
+ maximizethemodel’spotentialbyincorporatingasynergisticVisualSupervisionmechanism. This
149
+ modulefacilitatesanadaptiveinteraction: theCoTground-truthguidesthescoringofexploration
150
+ (viaavisualchecker),whilesuccessfulexplorationtracesdynamicallyrefinetheCoTground-truth
151
+ (viaavisualrefiner).
152
+ TheaforementioneddesignmakesDyMEahighlyeffectiveparadigmforempoweringthinkingin
153
+ SVLMsforspecifictasks. Wevalidatethisacrossthreediversedomains,rangingfromrecognition-
154
+ intensivetasks(medicalVQA)toreasoning-intensivechallenges(chartunderstandingandgeometric
155
+ problemsolving). Remarkably, usingonlyafewthousandtrainingsamples, DyMEachievessub-
156
+ stantial performance gains, enabling it to match or even surpass several LVLMs. Our primary
157
+ contributionsareasfollows:
158
+ 1. We propose DyME, the first training paradigm that equips SVLMs with thinking capabilities,
159
+ substantiallyreducingrelianceonthebaseVLM’sinitialcapacity.
160
+ 2. Throughdynamicswitchingandsynergisticsupervision,DyMEalleviatespseudothinkingtraces
161
+ andadvantagecollapseinSVLMs,yieldingimage-groundedthinkingandconsistentperformance
162
+ improvements.
163
+ 3. WedemonstratetheeffectivenessandpracticalityofDyMEacrossthreediversedomains,each
164
+ consistentlyshowingsubstantialperformancegainswithonlyafewthousandtrainingsamples.
165
+ 2 RELATED WORK
166
+ Vision-LanguageModels. ModernVLMs,suchasLLaVA(Liuetal.,2024a)andQwen-VL(Bai
167
+ etal.,2023),havedemonstratedremarkablecapabilitiesacrossawidearrayofvisiontasks. How-
168
+ ever,theirsubstantialparametercountsandcomputationaldemandsrestricttheiruseinresource-
169
+ constrained environments like edge devices. This has motivated a growing interest in SVLMs
170
+ designedforefficiency(Zhouetal.,2024;Marafiotietal.,2025;Korrapati,2024). Althoughworks
171
+ likeTinyLLaVA(Zhouetal.,2024)andSmolVLM(Marafiotietal.,2025)haveshownthatcarefully
172
+ designedSVLMscanachievecompetitiveperformance, theyexhibitacriticalweakness. Recent
173
+ studieshighlightthattheirperformancedegradessignificantlyontasksrequiringcomplex,multi-step
174
+ instructionfollowing,indicatingagapintheircompositionalunderstandingandgeneralreasoning
175
+ abilities(Albalaketal.,2022;Ghoshetal.,2024;Liuetal.,2025b).
176
+ Empowering Thinking Capabilities in VLMs. Recent advances in LLM thinking (e.g., GPT-
177
+ o1(OpenAI,2024),DeepSeek-R1(Guoetal.,2025))havemotivatedeffortstoequipVLMswith
178
+ similarcapabilitiesviadedicatedtrainingparadigms.
179
+ SFTonCoTdata (Xu et al., 2024; Xia et al., 2024; 2025; Gao et al., 2025; Yang et al., 2025b).
180
+ Thisparadigmleverageslarge-scaleCoTsupervisiontoteachmodelstomemorizeandgeneralize
181
+ thinkingpatterns. Multimodal-CoT(Zhangetal.,2023)wasanearlyattemptusingfusedvisual–text
182
+ 1Seethesupplementarymaterialforfurthercomparison.
183
+ 3
184
+ --- PAGE 4 ---
185
+ PublishedasaconferencepaperatICLR2026
186
+ inputs,butitssmallscaledatalimitedgenuinethinking. Subsequentworkshighlighttheroleofscale:
187
+ G-LLaVA(Gaoetal.,2025)constructs170Kgeometry-specificCoTsamples;ChartVLM(Xiaetal.,
188
+ 2024)compilesalargechartcorpus;andLLaVA-CoT(Xuetal.,2024)aswellasR1-OneVision(Yang
189
+ etal.,2025b)curatediverse,structuredCoTdatathroughlarge-scalepromptengineering. These
190
+ approachesfacelonginputs,requiringlargeVLMsthatcanprocessrichtextualinformationwhile
191
+ preservingvisualgrounding(Marafiotietal.,2025;Zhaietal.,2023).
192
+ RLwithVerifiableReward(RLVR)(Zhangetal.,2025a;Chenetal.,2025b;Pengetal.,2025;Shen
193
+ et al., 2025; Liu et al., 2025c). RLVR adopts a distinct paradigm that elicits thinking through
194
+ autonomousexplorationwithminimalexternalsupervision. ThepopularlyusedalgorithmisGroup
195
+ RelativePolicyOptimization(GRPO),introducedbyDeepSeek-Math(Shaoetal.,2024), which
196
+ exploitsmodels’abilitytoproducestructuredoutputsthatseparatethinkingfromfinalanswers. It
197
+ leverages rule-verifiable data to optimize high-scoring generations, while light SFT is employed
198
+ forcold-startwhentheoutputstructureisunclear. ThisparadigmhasbeenextendedtoVLMsin
199
+ severalworks. R1-V(Chenetal.,2025b)appliesGRPOtoVLMs,enablingthinkingintaskssuchas
200
+ countingandgeometry. LMM-R1(Pengetal.,2025)introducesatwo-stagepipelinethattransfers
201
+ textualthinkingintomultimodallearning. VisualRFT(Liuetal.,2025c)andR1-VL(Zhangetal.,
202
+ 2025a)incorporatevision-specificrewardstoguidefine-grained,visuallygroundedoptimization.
203
+ SinceGRPOdependsonmodels’initialstructuredthinkingability,thesemethodstypicallybuildon
204
+ strongVLMs,suchastheQwen-VLseries(Baietal.,2025).
205
+ HybridTrainingParadigms(Chuetal.,2025;Yanetal.,2025;Zhangetal.,2025b). Toharnessthe
206
+ complementarystrengthsofSFTandRL,researchershavealsoinvestigatedhybridparadigms. A
207
+ commonapproachisatwo-stagetrainingprocess(Chuetal.,2025)thatfirstusesSFTtoteachthe
208
+ modelthedesiredoutputformat,followedbyRLforexploration. Althoughintuitive,thismethodis
209
+ highlysensitivetotheamountofSFT,aparameterthatisparticularlychallengingtotuneforSVLMs,
210
+ asthesesmallermodelscaneasilybecometrappedinsuboptimalstates. Alternativestrategiesattempt
211
+ tocontinuouslyblendSFTwithRL,forinstance, byincorporatingSFTasanannealedauxiliary
212
+ loss(Zhangetal.,2025b)orbymanagingitsinfluencewithanempiricalshapingfunction(Yanetal.,
213
+ 2025). However,allthesestrategiesultimatelyrelyonanempiricallydeterminedbalancebetween
214
+ thetwoparadigms. ThisrigidityrepresentsacriticalflawwhenappliedtoSVLMs. Theabsenceof
215
+ adaptivecontrolovertheSFTweightrendersthesemethodsbrittleandunreliable.
216
+ Thus,existingparadigmsarenotdirectlytransferabletoSVLMsduetotheirinherentlimitations
217
+ inmodelcapacityandinstruction-followingability. Thishighlightstheneedforanoveltraining
218
+ paradigmthatimposesminimalrequirementsonthebaseVLM.
219
+ 3 APPROACH
220
+ 3.1 PRELIMINARIES
221
+ Wefirstbrieflyrecapthetwotrainingparadigms(SFTandRLVR)thatunderlieourmethod. Let
222
+ D ={(x ,y )}N bethetrainingset,wherexdenotestheinput(e.g. animage-instructionpair)and
223
+ i i i=1
224
+ ythedesiredoutput. Themodeldefinesaconditionaldistributionp (y |x)withparametersθ.
225
+ θ
226
+ Supervised Fine-Tuning (SFT). For each training pair (x,y) in D, SFT updates the model by
227
+ minimizingthenegativelog-likelihood(cross-entropy)ofthedesiredoutputyundertheconditional
228
+ distributionp (y |x):
229
+ θ
230
+ L (θ)=−E (cid:2) logp (y |x) (cid:3) . (1)
231
+ SFT (x,y)∼D θ
232
+ Thisteacher-forcinglossallowsmodelstomemorizeextensivetrainingexamples,compellingthe
233
+ modeltoabsorbthisknowledge.
234
+ GroupRelativePolicyOptimization(GRPO).GRPOisanRLalgorithmthatexploresopen-ended
235
+ generationbycomparingcandidateoutputswithinagroup. Foreachinputx,thepolicyp samplesa
236
+ θ
237
+ set{y˜k}K ;arewardfunctionr (y˜k)iscomputedbasedonthecorrectnessoftheoutputanswer,
238
+ k=1 a
239
+ andeachsample’sadvantageAismeasuredrelativetotheothergroupmembers:
240
+ (cid:118)
241
+ A(y˜k) = r a (y˜ σ k) + − ε r¯ a, r¯ a = K 1 (cid:88) K r a (y˜j), σ= (cid:117) (cid:117) (cid:116) K 1 (cid:88) K (r a (y˜j)−r¯)2, (2)
242
+ j=1 j=1
243
+ 4
244
+ --- PAGE 5 ---
245
+ PublishedasaconferencepaperatICLR2026
246
+ where ε is a small constant for numerical stability. The policy then updates its parameters by
247
+ minimizingthefollowingloss,regularisedbyaKLconstraint:
248
+ L (θ)=−E E (cid:104) min (cid:0) r (x,y˜)A(y˜),clip (cid:0) r (x,y˜);1−ϵ,1+ϵ (cid:1) A(y˜) (cid:1)(cid:105)
249
+ GRPO x∼D y˜∼pθ θ θ
250
+ +βD (cid:2) p (·|x)∥p (·|x) (cid:3) , where r (x,y˜) = p θ (y˜|x) . (3)
251
+ KL θ ref θ p (y˜|x)
252
+ old
253
+ TheclipandKLtermsworktogethertokeepeachupdateclosetosaferegionsoftheparameterspace:
254
+ theclipgatelimitsstepsizearoundtherolloutpolicyp ,whiletheKLterm(βD )tethersthe
255
+ old KL
256
+ policytothereferencep (typicallytheinitialmodel).
257
+ ref
258
+ GradientCompatibilityofSFTandGRPO.Below,werevealthattheoptimizationobjectivesof
259
+ SFTandGRPOareformallyequivalent,withtheformertargetingtheground-truthdatadistribution
260
+ andthelatteraninternalone.
261
+ ThegradientoftheSFTlossisstraightforward:
262
+ ∇ L (θ)=−E [∇ logp (y |x)]. (4)
263
+ θ SFT (x,y)∼D θ θ
264
+ Similarly,theGRPOgradient(ignoringclippingandanyKL-penalty)canbewrittenas
265
+ ∇ L (θ)=−E (cid:2) r (x,y˜)A(y˜)∇ logp (y˜|x) (cid:3) . (5)
266
+ θ GRPO x∼D, θ θ θ
267
+ y˜∼pold(·|x)
268
+ ThiscomparisonshowsthattheSFTgradientisaspecialcaseoftheGRPOgradient,obtainedwhen
269
+ theground-truthsampleisusedwithunitadvantage. Thisequivalenceenablesaunifiedlossthat
270
+ balancesexternalimitation(SFT)withinternalrefinement(GRPO).Achievingthisfusionrequires
271
+ dynamicallyweightingthetwosignals(§3.2)andensuringstylisticconsistencybetweenexternal
272
+ ground-truthandself-generatedoutputs(§3.3).
273
+ 3.2 DYNAMICMEMORIZE–EXPLORE(DYME)
274
+ Torealizethiscomplementarity,weproposetheDynamicMemorize–Explore(DyME)paradigm,
275
+ whichadaptivelyswitchesbetweenSFTandGRPOateachtrainingstep. Inthefollowing,wefirst
276
+ outlinetheoverallpipelineandthenelaborateontheoptimizationproceduresforeachmode.
277
+ Overall.AsshowninFig.3a,eachtrainingstepbeginswithaninputx=(I,q),whereIistheimage
278
+ andq isaninstruction. ThepolicySVLMp generatesK responses{y˜k}K . Eachresponseis
279
+ θ k=1
280
+ parsedintoathinkingtraceandafinalanswer,whichisthenverifiedforcorrectnessusingpredefined
281
+ rules. Theverificationresultsfallintotwocategories: eitherallresponsesareincorrect(including
282
+ those that fail to parse), or at least one is correct. The decision rule: if at least one response is
283
+ correct, themodelproceedswithGRPO-basedexploration; otherwise, itfallsbacktoSFT-based
284
+ memorization. Formally,thetrainingmodeisswitchedas:
285
+ (cid:26) GRPO, if max r (y˜k)=1,
286
+ mode(x)= k a (6)
287
+ SFT, otherwise,
288
+ where r (y˜k) ∈ {0,1} indicates whether y˜k passes rule-based verification. Though simple, this
289
+ a
290
+ decisionruleishighlyeffective. Whenallresponsesareincorrect,theanswerrewardsareessentially
291
+ allzeroandthenormalizedadvantagesbecomedominatedbynoise,makingGRPOupdatesfora
292
+ smallSVLMunstable. Inthisregime,fallingbacktoSFTprovidesalow-variance,ground-truth
293
+ guidedgradient. Conversely,theappearanceofatleastonecorrectresponseindicatesthatthecurrent
294
+ policy has already discovered a feasible solution for this input, so GRPO can safely exploit the
295
+ relativeadvantagestodriveexploration.
296
+ GRPOMode. DyMEintroducesakeyrefinementtotheoriginalGRPO:beyondtheanswerreward
297
+ r ,itincorporatesanauxiliaryrewardr forthinkingtraces. Thisrewardiscomputedbyevaluating
298
+ a t
299
+ thegeneratedtracesagainstexpectedthinkingpatterns(e.g.,viatoken-levelF1scoreground-truth
300
+ comparison),promotingstructuredthinking.
301
+ Giventheserewards,weupdatethepolicyusingamodifiedGRPOobjective. Unlikethestandard
302
+ formulation(Eqs.2&3),weomittheKLpenaltyandclippingterms,asthedynamicintegrationof
303
+ 5
304
+ --- PAGE 6 ---
305
+ PublishedasaconferencepaperatICLR2026
306
+ GiveGni vveinsu vails ufaacl tfsacts
307
+ AnyA cnoyr rceocrtr ect FromFr tohme ftihgeu frieg, uthree, vthaelu veasl rueepsr reespernetseedn bteyd by
308
+ GRPGORPO
309
+ resproenspseosn?ses? the tthhree eth breaers b caarns cbaen d biree cdtilrye cotblys eorbvseedr.ved.
310
+ TakiTnagk tihnegi rt hmeiera mn eyaienl dysie alnd so avne roavllerall
311
+ q q SwiStcwhietrcher or or averaavgeer aogf ea popf raopxpirmoaxtiemlya t8e.l3y, 8w.3h,i cwhhich
312
+ SVLSMVLM provpirdoevsi ad ecso an cciosne csiusme smuamrym oafr yt hoef the
313
+ distrdibisutrtiiobnu trieofnl ercetfeledc itnVe dtih siuenV actihlhs ueaC racth.lh eaCcrkth.eerc:k 0er: 0
314
+ I I SFTSFT
315
+ OveOravlelr pailpl epliipneeline VisuVails uRaelf Rineefriner
316
+ pip pip
317
+ AnsAwnesrw Reerw Raerwdard
318
+ GTGT VisuVails uRaelf Rineefriner To aTnosw aners wtheer qthuee sqtuioens,t i.o..n, ,s .t.e.,p s1t eisp 1to is to
319
+ (0, 1(,0 1, ,1 0, )1, 0) extraecxtt rtahcet vthaelu veasl fuoers Lfoart vLiaat, vPiao,r tPuogratlu,gal,
320
+ anda Rnodm Raonmiaa:n 3i.a0:, 38..09,, 8a.n9d, a 1n3d.0 1,3.0,
321
+ respreecstpiveecltyiv. eSltye.p S 2te ips 2to i sc atolc cualalcteu ltahtee the
322
+ VisuVails uCahl eCckheercker
323
+ averaavgeer aogf et hoefs teh veaselu veasl:u (e3s.:0 (+3 .80. 9+ +8.9 +
324
+ Add Athdidn kthiningk rienwg arredward logiltosgits 13.01) 3/. 03) = / 83. 3=. 8T.3h.e Trehfeorreef,o trhee, athvee raavgeer aisge is
325
+ (0, 1(+0, 1, +1+, 1, +0), 0) 8.3. 8.3.
326
+ VisuVails uCahl eCckheerc:k 1er.0: 1.0
327
+ GRGPOR PmOo dmeode SFT S FmTo dmeode
328
+ (a)ThepipelineforDyME. (b)Visualrefinerandchecker.
329
+ Figure3: WorkflowandmodulecomponentsofDyME.Ateachtrainingstep,DyMEdynamically
330
+ switchesbetweenmemorization(viaSFT)andexploration(viaGRPO)modesbasedonitsgenerations.
331
+ Visualsupervisionisintroducedthroughthevisualrefinerandvisualchecker. Therefinerenhances
332
+ the targets for memorization by incorporating richer visual elements (green), while the checker
333
+ rewardsthethinkingcontextgeneratedbasedontheirvisualrelevance.
334
+ SFTalreadystabilizestraining. Thisavoidsoverlyconservativeupdatesandyieldsacleanergradient
335
+ form,enablingsmootheralignmentbetweenSFTandGRPO:
336
+ L˜ (θ)=−E E [r (x,y˜)A(y˜)], (7)
337
+ GRPO x∼D y˜∼pθ(·|x) θ
338
+ where A(y˜k) is the group-normalized advantage calculated from the combined answer (r ) and
339
+ a
340
+ thinking(r )rewards,andr (x,y˜k)= pθ(y˜|x) istheimportancesamplingratio.
341
+ t θ pold(y˜|x)
342
+ SFTMode.WhentrainingfallsbacktoSFT,themodelisoptimizedtowardtheground-truthresponse
343
+ yusingthestandardsupervisedloss(Eq.1). Thisensuresthatwheneverthemodelfailstoexplore
344
+ effectively,itreceivesastable,ground-truth-guidedgradientupdatetocorrectitsbehavior.
345
+ DyMEObjective. Thefinallossdynamicallycombinesthetwoobjectivesbasedonresponsecorrect-
346
+ ness:
347
+ (cid:20) (cid:21) (cid:18) (cid:20) (cid:21)(cid:19)
348
+ L (θ)=1 maxr (y˜k)=1 ·L˜ (θ)+ 1−1 maxr (y˜k)=1 ·L (θ), (8)
349
+ DyME a GRPO a SFT
350
+ k k
351
+ where1[·]istheindicatorfunction,returning1iftheconditionholds,0otherwise.
352
+ 3.3 VISIONSUPERVISION
353
+ DyME with Visual Supervision. While the aforementioned Pure DyME (using standard r and
354
+ t
355
+ staticground-truth)alreadyguaranteestrainingstabilitythroughitsdynamicswitchingmechanism,
356
+ wecanfurtherexploitthisdynamicnaturetomaximizeperformance. Specifically,theswitching
357
+ mechanismallowsustotailorthesupervisionsignalsateachoptimizationstep: refiningthereward
358
+ duringexplorationandenhancingtheground-truthduringmemorization. Tothisend,weintroducea
359
+ checker–refinerframework(cf. Fig.3b),whichconstitutestheFullDyME.
360
+ Thisframeworkreorganizestheground-truthtoadheretoapredefinedstructure,cruciallytransform-
361
+ ingitintoagroundedthinkingtrace. Therefinerrestructurestheexternalground-truthintostructured,
362
+ visuallygroundedresponses,whilethecheckerevaluatesself-generatedoutputsfortheirstructural
363
+ organizationandcoverageofvisualcontent. Werefertotheresultingsupervisionsignalscollectively
364
+ asvisionsupervision. TheimplementationiscarriedoutviaLLM-basedpromptengineering.
365
+ VisualFactsI arecentraltorealizingvisionsupervision. Theyaredefinedasfine-grainedvisual
366
+ c
367
+ componentsextractedfromeachimage,includingobjects,attributes,andstates. Theseelementsplay
368
+ 6
369
+ --- PAGE 7 ---
370
+ PublishedasaconferencepaperatICLR2026
371
+ adualrole: theyprovideevidenceforevaluatinggenerationsagainsttheimageandserveasbuilding
372
+ blocksforconstructingcompleteground-truthresponses.
373
+ VisualChecker.Thevisualcheckerevaluatesresponsesalongtwodimensions:(i)whethertheoutput
374
+ containssufficientcorrectvisualelementscomparedtoI ,and(ii)whetheritalignsstylisticallywith
375
+ c
376
+ providedexamples.TheseexamplesmaybemanuallydefinedorextractedfromtheSFTground-truth.
377
+ VisualRefiner. TherefinerproducesvisuallygroundedresponsesforSFTbyleveragingthemodel’s
378
+ validatedexplorations. High-scoringtracesidentifiedbythevisualcheckerarestoredinadynamic
379
+ examplepool. AnLLMthendrawsfromthispooltogenerateground-truthresponses,integrating
380
+ structuraltemplateswithvisualfactsfromI andreferencingthecollectedexamples.
381
+ c
382
+ Inessence,theacquisitionofVisualFacts,theevaluationbytheVisualChecker,andthesynthesisby
383
+ theVisualRefinerareallimplementedviastructuredpromptengineeringusingQwen2.5-14B.Please
384
+ refertotheSupplementaryMaterialsforthefullpromptsusedinourpipeline.
385
+ 4 EXPERIMENTS
386
+ TorigorouslyevaluateDyME,westructureourexperimentsintotwoparts: (1)AlgorithmicVal-
387
+ idation, where we evaluate “Pure DyME” in a controlled setting using offline data to isolate the
388
+ contributionofourdynamicswitchingmechanism;and(2)SystemEffectiveness,whereweevaluate
389
+ thefullDyMEpipeline(withVisualSupervision)acrossdiversedomainstodemonstrateitspractical
390
+ capabilityinempoweringSVLMs.
391
+ 4.1 PARTI:ALGORITHMICVALIDATION(PUREDYME)
392
+ Setup.SinceSVLMslackintrinsicreasoningcapabilitiesandcannotautonomouslydiscovercomplex
393
+ reasoningpaths,pre-constructedCoTdataisamandatoryprerequisiteforalltrainingparadigms. We
394
+ thereforeevaluatedallmethodsonChartQA(Masryetal.,2022)usingLLaVA-OV-S(Lietal.,2024a),
395
+ the0.5Bvariant,withthreepre-constructedCoTdatasetsofvaryingqualities: Low(Undesigned)
396
+ containing unstructured traces (∼80 words); Medium (Standard) consisting of semi-structured
397
+ traces(∼89words)fromQwen2.5-14B;andHigh(Premium)comprisinghighlystructuredtraces
398
+ (∼142words)fromGPT-4o. Followingestablishedprotocols(Liuetal.,2023;Masryetal.,2022),
399
+ wereportrelaxedcorrectness,whichallowsa5%tolerancefornumericalanswers.
400
+ Wepresentathreefoldevaluationtovalidatedatarobustness,designoptimality,andgeneralization:
401
+ (1)RobustnesstoDataQuality. Table1(a)demonstratedDyME’ssuperiority. OnLowqualitydata,
402
+ PureDyME(61.9%)significantlyoutperformstheunstableTwo-stagebaseline(57.6%). Remarkably,
403
+ using only Medium data, it surpasses the SFT baseline trained on premium High (GPT-4o) data
404
+ (61.6%). ThisconfirmsthatDyMEactsasarobuststudent,effectivelymaximizingdataefficiency.
405
+ (2)OptimalityofBinarySwitching. Tovalidateourbinarydesign,wecompareditagainstthree
406
+ alternativeswitchingheuristicsinTable1(b): (i)RewardThresholding,whichswitchestoRLonly
407
+ if the batch average reward exceeds a threshold t; (ii) SFT Annealing, which applies a weighted
408
+ SFTlossalongsideRLateverystep;and(iii)SFTBudget,whichperformsfocusedSFTupdateson
409
+ accumulatedfailurecases(hardmining).
410
+ Results: RewardThresholdingprovesbrittle,collapsingatsuboptimalthresholds(t=0.5,52.4%).
411
+ SFTAnnealingincursaheavycomputationaltax(+25%)duetotheauxiliarySFTgradientcalculation.
412
+ SFTBudgetyieldsinferiorresults(59.6%)asoverwhelmingthemodelwithconcentratedfailures
413
+ destabilizeslearning. Incontrast,DyME’sbinaryswitchisparameter-free,efficient,andempirically
414
+ optimal(64.9%).
415
+ (3) Mechanism Generality. Going beyond the primary setup, while DyME is primarily tailored
416
+ forSVLMs, weverifytheuniversalityofitscoreswitchingmechanism(seeSupplementary). In
417
+ thetext-onlydomain,itbooststhesmall-scaleQwen2.5-0.5BonGSM8K(Cobbeetal.,2021)to
418
+ 55.3%(+5.8%overGRPO),confirmingDyMEisaneffectiveparadigmforempoweringthinkingin
419
+ small-parametermodelsregardlessofmodality. Moreover,theparadigmscaleseffectively: onthe
420
+ strongerQwen2.5-VL-7B,itfurtherimprovesChartQAperformanceto89.6%(+2.3%).
421
+ 7
422
+ --- PAGE 8 ---
423
+ PublishedasaconferencepaperatICLR2026
424
+ Table1: AlgorithmicValidationofPureDyME.(a)DyMEoutperformsSFTandTwo-stagevariants
425
+ (w/andw/oKLpenalty)acrossalldataqualities. (b)Thebinaryswitchismorerobustandefficient
426
+ thansoftorhard-miningalternatives(evaluatedonMediumdata).
427
+ (a)RobustnessacrossDataQuality (b)SwitchingStrategyAblation
428
+ Method Low Medium High Strategy Hyperparam. Acc. Cost
429
+ SFT 50.5 57.8 61.6 RewardThreshold t=0.5/0.8/0.9 52.4/64.1/63.4 None
430
+ Two-stage 57.6 59.9 54.5 SFTAnnealing Cosine 64.0 +25%
431
+ Two-stage(w/KL) 55.4 60.8 62.7 SFTBudget HardMining 59.6 Budget-dep.
432
+ PureDyME 61.9 64.9 68.5 BinarySwitch(Ours) – 64.9 Baseline
433
+ 4.2 PARTII:SYSTEMEFFECTIVENESS(FULLDYME)
434
+ Havingvalidatedthealgorithmiccore,wenowevaluatetheFullDyMEpipeline,augmentedwith
435
+ VisualSupervision,acrossthreediversedomains:MedicalVQA,ChartUnderstanding,andGeometry.
436
+ Eachfollowedtheevaluationprotocolsofpriorwork(Zongetal.,2024).
437
+ Setup & Source of I . Unlike Part I, here we activate the Visual Supervision module to enable
438
+ c
439
+ the full online loop. Crucially, to demonstrate DyME’s capability to bootstrap from raw signals,
440
+ we utilize the “Undesigned” CoT data (defined in §4.1) derived from SLAKE (Liu et al., 2021),
441
+ ChartQA(Masryetal.,2022),andGeo170K(Gaoetal.,2025)asthecommontrainingsourcefor
442
+ all methods. Acquiring the necessary visual facts (I ) is a fully automated process: we leverage
443
+ c
444
+ standarddomaintools(e.g.,BiomedGPT(Zhangetal.,2024a)formedical,DePlot(Liuetal.,2023)
445
+ forcharts)orpromptgeneralistLLMs(e.g.,Qwen2.5(Team,2024))toparseimagesintostructured
446
+ textualdescriptions. Theautomatedpipelineandpromptsareincludedinthesupplementary.
447
+ EvaluationProtocol. Weusedofficialtrain-testsplitsforSLAKE(Accuracy/Recall)andChartQA
448
+ (Relaxedcorrectness). ForGeometry, sinceGeo170K(Gaoetal.,2025)providesnotestset, we
449
+ evaluatedAccuracyonMathVerse(Zhangetal.,2024b),consistentwithZongetal.(2024).
450
+ 4.2.1 MAINRESULTS
451
+ DyMEvs. ExistingTrainingParadigms. Thecomprehen-
452
+ siveresultsinTable2showthatDyMEconsistentlydeliv-
453
+ erssubstantialgains. Notably,aftertrainingwithDyME,
454
+ SmolVLM improves from 49.9 to 55.6 (+5.7), LLaVA-
455
+ OV-S from 50.7 to 55.4 (+4.7), and InternVL2-S from
456
+ 56.3to58.1(+1.8). Incontrast,existingparadigmstend
457
+ todegradeperformance(e.g.,SFTlowersSmolVLMto
458
+ 44.1),validatingouranalysisthatSFTyieldspseudothink-
459
+ ingtracesandGRPOfacesadvantagecollapse(cf. Fig.4).
460
+ DyME effectively mitigates these issues. It promotes
461
+ grounded traces that are concise yet informative (cf.
462
+ Fig.5),aligningwellwiththelimitedcapacityofSVLMs.
463
+ Importantly,DyMEplacesminimaldemandsonthebase
464
+ Figure 4: Training rewards. GRPO
465
+ model: evenSmolVLM(0.5B)achievessubstantialgains,
466
+ andtwo-stagetrainingsufferfromsevere
467
+ anditstilldeliversimprovements(+2.6%)onextensively
468
+ advantagecollapse.
469
+ pretrainedmodelslikeInternVL2-S.Wefurthercorrobo-
470
+ ratedthesefindingsthroughmanualinspection,asdetailed
471
+ intheSupplementaryMaterial.
472
+ MatchingtheEfficacyofGPT-4oSupervisionwithOpen-SourceModels. Comparingresults
473
+ between Part I and Part II reveals a crucial finding: LLaVA-OV-S trained with the full DyME
474
+ pipeline(usingtheaccessibleQwen2.5-14B)achieves67.5%(Table2). Thiseffectivelymatchesthe
475
+ performanceofPureDyMEtrainedonexpensiveGPT-4odata(68.5%,cf. Table1). Thisprovesthat
476
+ fullDyMEallowsopen-sourcesupervisiontoachievetrainingoutcomescomparabletothosederived
477
+ fromtop-tierproprietarymodels,eliminatingtheneedforexpensivedataannotation.
478
+ DyME-trainedSVLMsCanBeCompetitivewithLVLMs. Weensuredfairnessbyexposingall
479
+ baselinestoourtrainingdata. AsshowninTable2,SVLMstrainedwithDyMEcansurpassstronger
480
+ 8
481
+ --- PAGE 9 ---
482
+ PublishedasaconferencepaperatICLR2026
483
+ Table2: Comparisonsacrossthreedomains: medicalVQA,chartunderstanding,andgeometry
484
+ solving.TheevaluationfollowstheVLMEvalKitframework(Duanetal.,2024).ForSVLMs,existing
485
+ trainingparadigmsdegradetheirperformance,whereasDyMEconsistentlybringsimprovements. The
486
+ bestperformanceachievedbyeachSVLMishighlightedinbold,withtherelativeimprovementalso
487
+ indicated. Notably,afterbeingtrainedwithDyME,SVLMsachieveperformancecomparabletothat
488
+ ofMoVA(underlined).
489
+ Model ViT LLM Medical Chart Geometry Avg.
490
+ LVLMs
491
+ LLaVA-Med(Lietal.,2023) CLIP-ViT-L/14 Vicuna-7B 64.3 – – –
492
+ Cambrian-1(Tongetal.,2024) Hybrid-3B Llama3-8B – 72.6 22.0 –
493
+ LLaVA-1.5(Liuetal.,2024a) CLIP-ViT-L/14 Vicuna-7B 69.4 17.8 – –
494
+ LLaVA-1.6(Liuetal.,2024b) CLIP-ViT-L/14 Vicuna-7B 78.2 49.2 13.4 47.0
495
+ MoVA(Zongetal.,2024) Hybrid-3B Vicuna-7B 74.5 68.3 19.7 54.2
496
+ LLaVA-OV-L(Lietal.,2024a) SigLIP-SO400M Qwen2-7B 75.7 80.9 24.5 60.4
497
+ InternVL2-L(Chenetal.,2024) InternViT-300M InternLM2.5-7B 80.2 82.1 37.3 66.5
498
+ SVLMs
499
+ SmolVLM(Marafiotietal.,2025) SigLIP-93M SmolLM2-360M 72.1 63.2 14.6 49.9
500
+ +CoTSFT SigLIP-93M SmolLM2-360M 60.1 57.7 14.5 44.1
501
+ +GRPO SigLIP-93M SmolLM2-360M 61.1 53.8 17.1 44.0
502
+ +Two-stage SigLIP-93M SmolLM2-360M 59.4 60.1 16.7 45.4
503
+ +DyME SigLIP-93M SmolLM2-360M 78.1 69.7 18.9 55.6
504
+ (+6.0%) (+6.5%) (+4.3%) (+5.7%)
505
+ LLaVA-OV-S(Lietal.,2024a) SigLIP-400M Qwen2-0.5B 74.9 61.4 15.9 50.7
506
+ +Two-stage SigLIP-400M Qwen2-0.5B 74.5 52.9 16.5 48.0
507
+ +DyME SigLIP-400M Qwen2-0.5B 78.3 67.5 20.4 55.4
508
+ (+3.4%) (+6.1%) (+4.5%) (+4.7%)
509
+ InternVL2-S(Chenetal.,2024) InternViT-300M Qwen2-0.5B 78.3 71.9 18.7 56.3
510
+ +Two-stage InternViT-300M Qwen2-0.5B 73.6 55.7 17.1 48.8
511
+ +DyME InternViT-300M Qwen2-0.5B 80.0 74.5 19.8 58.1
512
+ (+1.7%) (+2.6%) (+1.1%) (+1.8%)
513
+ LVLMslikeMoVA(54.2)onthesespecializeddomains,withSmolVLMreaching55.6andLLaVA-
514
+ OV-S55.4. Asaresult,DyME-trainedSVLMsbecomereliableoptionsfortask-specificapplications
515
+ onresource-constrainededgedevices.
516
+ 4.2.2 ABLATIONSTUDY
517
+ Todissectthesourceofthesegains,weconductedanablationstudytoanalyzethecontributionof
518
+ DyME’sfourcorecomponentswithinthefullpipeline: thememorizationmode,explorationmode,
519
+ visualrefiner,andvisualchecker. Table3showstheperformanceimpact.
520
+ DynamicSwitchingMechanism. Theresults
521
+ confirmthatMemorizationandExplorationare Table3: Ablationstudy. Model: LLaVA-OV-S.
522
+ symbiotic. Disabling memorization causes a
523
+ catastrophicdrop(55.4→43.9),effectivelyre-
524
+ DyMEVariant Medical Chart Geometry Average
525
+ vertingtounconstrained,unstableexploration.
526
+ DyME(full) 78.3 67.5 20.4 55.4
527
+ Conversely, removing exploration (50.4) re- w/omemorization 63.2 53.4 15.0 43.9(20.6%↓)
528
+ strictsthemodeltothestaticimitationofsub-
529
+ w/oexploration 75.5 61.3 14.5 50.4(9.0%↓)
530
+ w/ovisualrefiner 75.6 62.3 16.8 51.6(6.9%↓)
531
+ optimaldata. AsshowninFig.4,theirdynamic w/ovisualchecker 76.9 64.3 17.1 52.8(4.7%↓)
532
+ interplay prevents the advantage collapse ob-
533
+ servedinbaselines,ensuringoptimizationstabilitythroughoutthelearningprocess.
534
+ Visual Supervision. Removing the visual checker and refiner drops performance by 4.7% and
535
+ 6.9%,respectively. Thisvalidatesthepivotalroleofvisualsupervisioninbootstrappingfromnoisy,
536
+ undesigneddata. GiventhelimitedcapacityofSVLMs,theyareeasilypronetohallucinationwhen
537
+ trainedonlow-qualitytraces. Thevisualcomponentsactasadynamicdenoiser,ensuringthatraw,
538
+ imperfectdataisfilteredandrefinedintogroundedvisualfacts(I )beforeoptimization,thusenabling
539
+ c
540
+ robustlearningevenfromweaksupervision.
541
+ 9
542
+ --- PAGE 10 ---
543
+ PublishedasaconferencepaperatICLR2026
544
+ What is the difference between The difference between the Given AB // CD, angle 1 = 50.0,
545
+ the values of 2017 and 2016? values of 2017 and 2016 is then what the degree of angle 2? Angle 2 is equal to angle 3,
546
+ (Input Image) 19000. angle 2 is also 50 degrees.
547
+ (Original)
548
+ Extraction: Extraction:
549
+ data is value 36700 for 2017 AB is parallel to CD,
550
+ and 29000 for 2016 angle 1 = 50°.
551
+ Calculation: Calculation:
552
+ 36700 - 29000 = 7700 angle 2 = 180° - angle 1
553
+ Conclusion: = 180° - 50° = 130°.
554
+ The difference between the Conclusion:
555
+ 2017 and 2016 values is 7700.(DyME) The degree of angle 2 is 130°.
556
+ Answer: 7700 Answer: 130°
557
+ Figure 5: Showcases on chart understanding and geometry solving. We use LLaVA-OV-S
558
+ todemonstratetheresults. TheSVLMoriginallyproduceshallucinatedanswers(red), whilethe
559
+ DyME-trainedmodelgeneratesstructuredthinkingtraces(green)thatincorporategroundedvalues,
560
+ effectivelyimprovingtheperformance.
561
+ 4.3 TRAININGEFFICIENCY&DISCUSSION
562
+ Weanalyzethecomputationalefficiencyandperformancetrade-offsassociatedwithdifferentconfig-
563
+ urationsofDyME.ThecomparativeresultsaredetailedinTable4.
564
+ Computational Efficiency vs. Data Cost. The frame-
565
+ work offers two distinct operating regimes catering to Table4: Cost-BenefitAnalysis. Time
566
+ different resource profiles. Pure DyME represents the measuredinsec/step. Runon8xH800.
567
+ high-efficiency regime: when offline CoT data is pre-
568
+ constructed,itmaintainstrainingthroughputcomparable
569
+ Method Ext.Model Time Acc.
570
+ tostandardGRPO(∼14s/step)whiledeliveringsuperior
571
+ GRPO(Baseline) Qwen2.5-14B† 14.8s 60.8
572
+ performance. Incontrast,FullDyME(withVisualSupervi- PureDyME Qwen2.5-14B† 14.0s 64.9
573
+ sion)prioritizesdataautonomy. Whiletheonlineinterac- PureDyME GPT-4o† 19.1s 68.5
574
+ tionintroducesacomputationaloverhead(∼1.6×training FullDyME Qwen2.5-7B 21.2s 66.8
575
+ FullDyME Qwen2.5-14B 23.4s 67.5
576
+ time),itenablesthemodeltobootstraphigh-performance †Usedforofflinedataconstructiononly.
577
+ reasoningsolelyfromopen-sourcemodels,bypassingthe
578
+ dependencyonexpensive,proprietarydataannotation(e.g.,GPT-4o).
579
+ Sensitivity to External Model Capacity. For Full DyME, we further examine the impact of the
580
+ externalhelper’ssizeonsystemperformance. AsshowninTable4, replacingtheQwen2.5-14B
581
+ helperwiththesmaller7Bvariantresultsinanegligibleperformancevariation(67.5%→66.8%).
582
+ Thisindicatesthatourstructuredpromptengineeringeffectivelydecomposescomplexreasoning
583
+ tasks, allowing even smaller external models to provide sufficient guidance for SVLMs without
584
+ necessitatingheavy-weightmodels.
585
+ ApplicabilityofVisualSupervision. TheeffectivenessoftheVisualSupervisionmodulerelieson
586
+ theexplicitextractionofVisualFacts(I ). Thisprocesscreatesspecificapplicabilityboundaries. For
587
+ c
588
+ domainsinvolvingabstractsemantics(e.g.,ironyinmemes)orunstructuredperception(e.g.,dense
589
+ crowds),convertingholisticvisualsignalsintodiscretetextmayresultininformationloss. Insuch
590
+ scenarios,revertingtothePureDyMEparadigmservesasamorerobustalternative.
591
+ 5 CONCLUSION
592
+ Inthiswork,weintroducedDyME,anoveltrainingparadigmdesignedtoempowerthinkingcapa-
593
+ bilitieswithinSVLMs. Atitscore,DyMEcombinesmemorization(viaSFT)modeandexploration
594
+ (viaRLVR)modethroughadynamicswitchingmechanism. Ourexperimentsdemonstratethatthis
595
+ approachnotonlyresolvesthecriticaltrade-offbetweenthesetwomodesbutalsoyieldssubstantial
596
+ performance gains on a wide spectrum of vision tasks, from recognition-intensive to reasoning-
597
+ intensivescenarios. ThesuccessofDyMEisattributedtoitscarefullydesignedcomponents: the
598
+ dynamicswitchingmechanismaddressespseudothinkingtracesandadvantagecollapse,whilethe
599
+ visualcheckerandrefinerprovidecoordinated,high-qualityvisualsupervision. Itimposesminimal
600
+ requirementsonthebaseVLM,makingitbroadlyapplicabletoawiderangeofmodels,including
601
+ extremelylightweightSVLMs. Therefore,DyMEservesasthepracticalsolutionforempowering
602
+ SVLMstothink.
603
+ 10
604
+ --- PAGE 11 ---
605
+ PublishedasaconferencepaperatICLR2026
606
+ ACKNOWLEDGMENT
607
+ ThisworkwassupportedbytheHongKongSARRGCGeneralResearchFund(16219025),National
608
+ NaturalScienceFoundationofChinaYoungScholarFundCategoryB(62522216),NationalNatural
609
+ ScienceFoundationofChinaYoungScholarFundCategoryC(62402408),andHongKongSAR
610
+ RGCEarlyCareerScheme(26208924).
611
+ REFERENCES
612
+ Alon Albalak, Akshat Shrivastava, Chinnadhurai Sankar, Adithya Sagar, and Mike Ross. Data-
613
+ efficiencywithasinglegpu: Anexplorationoftransfermethodsforsmalllanguagemodels. arXiv
614
+ preprintarXiv:2210.03871,2022.
615
+ JinzeBai,ShuaiBai,ShushengYang,ShijieWang,SinanTan,PengWang,JunyangLin,ChangZhou,
616
+ andJingrenZhou. Qwen-VL:Aversatilevision-languagemodelforunderstanding,localization,
617
+ textreading,andbeyond. arXivpreprintarXiv:2308.12966,2023.
618
+ ShuaiBai,KeqinChen,XuejingLiu,JialinWang,WenbinGe,SiboSong,KaiDang,PengWang,
619
+ ShijieWang,JunTang,HumenZhong,YuanzhiZhu,MingkunYang,ZhaohaiLi,JianqiangWan,
620
+ PengfeiWang,WeiDing,ZherenFu,YihengXu,JiaboYe,XiZhang,TianbaoXie,ZesenCheng,
621
+ HangZhang,ZhiboYang,HaiyangXu,andJunyangLin. Qwen2.5-VLtechnicalreport. arXiv
622
+ preprintarXiv:2502.13923,2025.
623
+ HardyChen,HaoqinTu,FaliWang,HuiLiu,XianfengTang,XinyaDu,YuyinZhou,andCihangXie.
624
+ SFTorRL?anearlyinvestigationintotrainingR1-likereasoninglargevision-languagemodels.
625
+ arXivpreprintarXiv:2504.11468,2025a.
626
+ LiangChen,LeiLi,HaozheZhao,YifanSong,andVinci. R1-V:Reinforcingsupergeneralization
627
+ abilityinvision-languagemodelswithlessthan$3. https://github.com/Deep-Agent/
628
+ R1-V,2025b. Accessed: 2025-02-02.
629
+ ZheChen,WeiyunWang,YueCao,YangzhouLiu,ZhangweiGao,ErfeiCui,JinguoZhu,Shenglong
630
+ Ye,HaoTian,ZhaoyangLiu,etal. Expandingperformanceboundariesofopen-sourcemultimodal
631
+ modelswithmodel,data,andtest-timescaling. arXivpreprintarXiv:2412.05271,2024.
632
+ TianzheChu,YuexiangZhai,JihanYang,ShengbangTong,SainingXie,DaleSchuurmans,QuocV
633
+ Le,SergeyLevine,andYiMa.SFTmemorizes,RLgeneralizes:Acomparativestudyoffoundation
634
+ modelpost-training. arXivpreprintarXiv:2501.17161,2025.
635
+ Karl Cobbe, Vineet Kosaraju, Mohammad Bavarian, Mark Chen, Heewoo Jun, Lukasz Kaiser,
636
+ MatthiasPlappert,JerryTworek,JacobHilton,ReiichiroNakano,etal. Trainingverifierstosolve
637
+ mathwordproblems. arXivpreprintarXiv:2110.14168,2021.
638
+ DeepSeek, Inc. DeepSeek-R1 Release. https://api-docs.deepseek.com/news/
639
+ news250120,January2025. Accessed: Jun.21,2025.
640
+ HaodongDuan,JunmingYang,YuxuanQiao,XinyuFang,LinChen,YuanLiu,XiaoyiDong,Yuhang
641
+ Zang,PanZhang,JiaqiWang,etal. VLMEvalKit: Anopen-sourcetoolkitforevaluatinglarge
642
+ multi-modalitymodels. InACMMM,2024.
643
+ Jiahui Gao, Renjie Pi, Jipeng Zhang, Jiacheng Ye, Wanjun Zhong, Yufei Wang, Lanqing Hong,
644
+ JianhuaHan,HangXu,ZhenguoLi,etal. G-LLaVA:Solvinggeometricproblemwithmulti-modal
645
+ largelanguagemodel. InICLR,2025.
646
+ AkashGhosh, ArkadeepAcharya, SriparnaSaha, VinijaJain, andAmanChadha. Exploringthe
647
+ frontierofvision-languagemodels:Asurveyofcurrentmethodologiesandfuturedirections. arXiv
648
+ preprintarXiv:2404.07214,2024.
649
+ Daya Guo, Dejian Yang, Haowei Zhang, Junxiao Song, Ruoyu Zhang, Runxin Xu, Qihao Zhu,
650
+ ShirongMa,PeiyiWang,XiaoBi,etal. Deepseek-R1: IncentivizingreasoningcapabilityinLLMs
651
+ viareinforcementlearning. arXivpreprintarXiv:2501.12948,2025.
652
+ VikKorrapati. Moondream. https://moondream.ai/,2024. Accessed: 2025-03-27.
653
+ 11
654
+ --- PAGE 12 ---
655
+ PublishedasaconferencepaperatICLR2026
656
+ Yuxiang Lai, Jike Zhong, Ming Li, Shitian Zhao, and Xiaofeng Yang. Med-R1: Reinforce-
657
+ ment learning for generalizable medical reasoning in vision-language models. arXiv preprint
658
+ arXiv:2503.13939,2025.
659
+ BoLi,YuanhanZhang,DongGuo,RenruiZhang,FengLi,HaoZhang,KaichenZhang,Yanwei
660
+ Li,ZiweiLiu,andChunyuanLi. LLaVA-OneVision: Easyvisualtasktransfer. arXivpreprint
661
+ arXiv:2408.03326,2024a.
662
+ Chunyuan Li, Cliff Wong, Sheng Zhang, Naoto Usuyama, Haotian Liu, Jianwei Yang, Tristan
663
+ Naumann,HoifungPoon,andJianfengGao. LLaVA-Med: Trainingalargelanguage-and-vision
664
+ assistantforbiomedicineinoneday. AdvancesinNeuralInformationProcessingSystems,36:
665
+ 28541–28564,2023.
666
+ Zhuowan Li, Bhavan Jasani, Peng Tang, and Shabnam Ghadar. Synthesize step-by-step: Tools
667
+ templatesandLLMsasdatageneratorsforreasoning-basedchartVQA. InCVPR,2024b.
668
+ BoLiu,Li-MingZhan,LiXu,LinMa,YanYang,andXiao-MingWu. SLAKE:Asemantically-
669
+ labeledknowledge-enhanceddatasetformedicalvisualquestionanswering. InISBI,2021.
670
+ FangyuLiu,JulianEisenschlos,FrancescoPiccinno,SyrineKrichene,ChenxiPang,KentonLee,
671
+ MandarJoshi,WenhuChen,NigelCollier,andYaseminAltun. DePlot: One-shotvisuallanguage
672
+ reasoningbyplot-to-tabletranslation. InFindingsoftheACL,2023.
673
+ HaotianLiu,ChunyuanLi,YuhengLi,andYongJaeLee. Improvedbaselineswithvisualinstruction
674
+ tuning. InCVPR,2024a.
675
+ Haotian Liu, Chunyuan Li, Yuheng Li, Bo Li, Yuanhan Zhang, Sheng Shen, and Yong Jae Lee.
676
+ LLaVA-NeXT:Improvedreasoning,OCR,andworldknowledge,January2024b. URLhttps:
677
+ //llava-vl.github.io/blog/2024-01-30-llava-next/.
678
+ JiazhenLiuandLongChen. Segmentationasaplug-and-playcapabilityforfrozenmultimodalllms.
679
+ arXivpreprintarXiv:2510.16785,2025.
680
+ JiazhenLiu,MingkuanFeng,andLongChen. Better,stronger,faster: Tacklingthetrilemmainmllm-
681
+ basedsegmentationwithsimultaneoustextualmaskprediction. arXivpreprintarXiv:2512.00395,
682
+ 2025a.
683
+ JiazhenLiu,YuhanFu,RuobingXie,RunquanXie,XingwuSun,FengzongLian,ZhanhuiKang,
684
+ andXirongLi. PhD:Achatgpt-promptedvisualhallucinationevaluationdataset. InCVPR,2025b.
685
+ ZiyuLiu,ZeyiSun,YuhangZang,XiaoyiDong,YuhangCao,HaodongDuan,DahuaLin,andJiaqi
686
+ Wang. Visual-RFT:Visualreinforcementfine-tuning. arXivpreprintarXiv:2503.01785,2025c.
687
+ Andre´sMarafioti,OrrZohar,MiquelFarre´,MerveNoyan,ElieBakouch,PedroCuenca,CyrilZakka,
688
+ Loubna Ben Allal, Anton Lozhkov, Nouamane Tazi, et al. SmolVLM: Redefining small and
689
+ efficientmultimodalmodels. arXivpreprintarXiv:2504.05299,2025.
690
+ AhmedMasry,DoXuanLong,JiaQingTan,ShafiqJoty,andEnamulHoque.ChartQA:Abenchmark
691
+ for question answering about charts with visual and logical reasoning. In Smaranda Muresan,
692
+ PreslavNakov,andAlineVillavicencio(eds.),FindingsoftheACL,May2022.
693
+ OpenAI. IntroducingOpenAIo1. https://openai.com/o1/,December2024. Accessed: Jun.
694
+ 21,2025.
695
+ Yingzhe Peng, Gongrui Zhang, Miaosen Zhang, Zhiyuan You, Jie Liu, Qipeng Zhu, Kai Yang,
696
+ XingzhongXu,XinGeng,andXuYang. LMM-R1: Empowering3BLMMswithstrongreasoning
697
+ abilitiesthroughtwo-stagerule-basedrl. arXivpreprintarXiv:2503.07536,2025.
698
+ Zhihong Shao, Peiyi Wang, Qihao Zhu, Runxin Xu, Junxiao Song, Xiao Bi, Haowei Zhang,
699
+ Mingchuan Zhang, YK Li, Y Wu, et al. DeepSeekMath: Pushing the limits of mathematical
700
+ reasoninginopenlanguagemodels. arXivpreprintarXiv:2402.03300,2024.
701
+ HaozhanShen,PengLiu,JingchengLi,ChunxinFang,YiboMa,JiajiaLiao,QiaoliShen,Zilun
702
+ Zhang,KangjiaZhao,QianqianZhang,RuochenXu,andTianchengZhao. VLM-R1: Astable
703
+ andgeneralizableR1-stylelargevision-languagemodel. arXivpreprintarXiv:2504.07615,2025.
704
+ 12
705
+ --- PAGE 13 ---
706
+ PublishedasaconferencepaperatICLR2026
707
+ QwenTeam. Qwen2.5: Apartyoffoundationmodels,September2024. URLhttps://qwenlm.
708
+ github.io/blog/qwen2.5/.
709
+ PeterTong,EllisBrown,PenghaoWu,SanghyunWoo,AdithyaJairamVedagiriIYER,SaiCharitha
710
+ Akula,ShushengYang,JihanYang,ManojMiddepogu,ZitengWang,etal. Cambrian-1: Afully
711
+ open,vision-centricexplorationofmultimodalLLMs. AdvancesinNeuralInformationProcessing
712
+ Systems,37:87310–87356,2024.
713
+ RenqiuXia,BoZhang,HanchengYe,XiangchaoYan,QiLiu,HongbinZhou,ZijunChen,PengYe,
714
+ MinDou,BotianShi,etal. ChartX&ChartVLM:Aversatilebenchmarkandfoundationmodel
715
+ forcomplicatedchartreasoning. arXivpreprintarXiv:2402.12185,2024.
716
+ RenqiuXia,MingshengLi,HanchengYe,WenjieWu,HongbinZhou,JiakangYuan,TianshuoPeng,
717
+ XinyuCai,XiangchaoYan,BinWang,etal. GeoX:Geometricproblemsolvingthroughunified
718
+ formalizedvision-languagepre-training. InICLR,2025.
719
+ Guowei Xu, Peng Jin, Hao Li, Yibing Song, Lichao Sun, and Li Yuan. LLaVA-CoT: Let vision
720
+ languagemodelsreasonstep-by-step. arXivpreprintarXiv:2411.10440,2024.
721
+ Jianhao Yan, Yafu Li, Zican Hu, Zhi Wang, Ganqu Cui, Xiaoye Qu, Yu Cheng, and Yue Zhang.
722
+ Learningtoreasonunderoff-policyguidance. arXivpreprintarXiv:2504.14945,2025.
723
+ LeleYang,MuxiDiao,KongmingLiang,andZhanyuMa. GRPOforLLaVA. https://github.
724
+ com/PRIS-CV/GRPO-for-Llava,2025a.
725
+ YiYang,XiaoxuanHe,HongkunPan,XiyanJiang,YanDeng,XingtaoYang,HaoyuLu,Dacheng
726
+ Yin, Fengyun Rao, Minfeng Zhu, et al. R1-OneVision: Advancing generalized multimodal
727
+ reasoningthroughcross-modalformalization. arXivpreprintarXiv:2503.10615,2025b.
728
+ YuexiangZhai,ShengbangTong,XiaoLi,MuCai,QingQu,YongJaeLee,andYiMa. Investigating
729
+ thecatastrophicforgettinginmultimodallargelanguagemodelfine-tuning. InCPAL,2023.
730
+ JingyiZhang,JiaxingHuang,HuanjinYao,ShunyuLiu,XikunZhang,ShijianLu,andDachengTao.
731
+ R1-VL:Learningtoreasonwithmultimodallargelanguagemodelsviastep-wisegrouprelative
732
+ policyoptimization. arXivpreprintarXiv:2503.12937,2025a.
733
+ KaiZhang,RongZhou,EashanAdhikarla,ZhilingYan,YixinLiu,JunYu,ZhengliangLiu,Xun
734
+ Chen,BrianDDavison,HuiRen,etal. Ageneralistvision–languagefoundationmodelfordiverse
735
+ biomedicaltasks. NatureMedicine,pp.1–13,2024a.
736
+ RenruiZhang,DongzhiJiang,YichiZhang,HaokunLin,ZiyuGuo,PengshuoQiu,AojunZhou,
737
+ PanLu,Kai-WeiChang,PengGao,etal. MathVerse: Doesyourmulti-modalllmtrulyseethe
738
+ diagramsinvisualmathproblems? 2024b.
739
+ WenhaoZhang,YuexiangXie,YuchangSun,YanxiChen,GuoyinWang,YaliangLi,BolinDing,
740
+ andJingrenZhou. On-policyRLmeetsoff-policyexperts: Harmonizingsupervisedfine-tuning
741
+ andreinforcementlearningviadynamicweighting. arXivpreprintarXiv:2508.11408,2025b.
742
+ ZhuoshengZhang,AstonZhang,MuLi,HaiZhao,GeorgeKarypis,andAlexSmola. Multimodal
743
+ chain-of-thoughtreasoninginlanguagemodels. arXivpreprintarXiv:2302.00923,2023.
744
+ Baichuan Zhou, Ying Hu, Xi Weng, Junlong Jia, Jie Luo, Xien Liu, Ji Wu, and Lei
745
+ Huang. TinyLLaVA: A framework of small-scale large multimodal models. arXiv preprint
746
+ arXiv:2402.14289,2024.
747
+ ZhuofanZong,BingqiMa,DazhongShen,GuangluSong,HaoShao,DongzhiJiang,HongshengLi,
748
+ andYuLiu. MoVA:Adaptingmixtureofvisionexpertstomultimodalcontext. InNeurIPS,2024.
749
+ 13
750
+ --- PAGE 14 ---
751
+ PublishedasaconferencepaperatICLR2026
752
+ Empowering Small VLMs to Think with Dynamic Memorization
753
+ and Exploration
754
+ SupplementaryMaterial
755
+ Inthesupplementarymaterials,wereport:
756
+ • LLMinstructionsusedforconstructingvisionsupervision(§S1);
757
+ • Detailedexperimentalsetupandadditionalexperimentalresults(§S2);
758
+ • ShowcasesofSVLMstrainedviaDyMEperformingonmedicalVQA,chartunderstanding,
759
+ andgeometryproblemsolving(§S3);
760
+ S1 LLM INSTRUCTIONS FOR VISION SUPERVISION
761
+ TheinstructionsforconstructingI ,thevisualrefiner,andthevisualcheckerarelistedasfollows.
762
+ c
763
+ S1.1 INSTRUCTIONSFOREXTRACTINGVISUALELEMENTS
764
+ I isprimarilyderivedfromtwosources: groundtruthcaptions,andtheoutputsfromspecialized
765
+ c
766
+ toolssuchasthechart-parsingmodelDeplot. PromptS1isemployedtoextractvisualelementsfrom
767
+ captions.
768
+ 1 You are a helpful assistant that analyzes images and provides visual
769
+ facts.
770
+ 2 Your response MUST be a single, valid JSON object.
771
+ 3 The JSON object should contain:
772
+ 4 1. "description": A detailed and accurate description of the image.
773
+ 5 2. "objects": A list of key objects, including their name, attributes,
774
+ and approximate position in the image.
775
+ 6
776
+ 7 Example format:
777
+ 8 {
778
+ 9 "description": "A person riding a bicycle on a city street.... (
779
+ detailed description here)",
780
+ 10 "objects": [
781
+ 11 {"name": "person", "attributes": ["wearing helmet", "blue shirt"], "
782
+ position": "center"},
783
+ 12 {"name": "bicycle", "attributes": ["red", "mountain bike"], "position
784
+ ": "center"},
785
+ 13 {"name": "street", "attributes": ["asphalt", "wet"], "position": "
786
+ bottom"}
787
+ 14 ]
788
+ 15
789
+ 16 Analyze the attached image and provide the visual facts in the required
790
+ JSON format.
791
+ 17 For context, the user will be asked this question about the image (do not
792
+ answer the question, just use it for context):
793
+ 18 "{question}"
794
+ 19 }
795
+ PromptS1: AutomatedVisualFactExtraction
796
+ S1.2 INSTRUCTIONSFORVISUALREFINER
797
+ ThevisualrefinerprimarilytakesthinkingtemplatesfromatemplatepoolandI asinputtoconstruct
798
+ c
799
+ thegroundtruthforSFT.
800
+ 1 Given:
801
+ 2 <IC>: the data of an image
802
+ 3 <Q>: a question
803
+ 1
804
+ --- PAGE 15 ---
805
+ PublishedasaconferencepaperatICLR2026
806
+ 4 <A>: a reference answer
807
+ 5 <T>: a writing template
808
+ 6
809
+ 7 Goal:
810
+ 8 Transform the visual information in <IC> into a textualized data
811
+ description and incorporate it into a smooth, natural explanation
812
+ that reasons why the correct answer to <Q> is <A>, using the format
813
+ and tone defined by <T>.
814
+ 9
815
+ 10 Example:
816
+ 11 <IC>: [
817
+ 12 {"object": "bar", "attributes": ["˜120k", "Q4"], "label": "Product A"},
818
+ 13 {"object": "bar", "attributes": ["˜150k", "Q4"], "label": "Product B"},
819
+ 14 {"object": "bar", "attributes": ["˜90k", "Q4"], "label": "Product C"},
820
+ 15 {"title": "Quarterly Revenue"}
821
+ 16 ]
822
+ 17 <Q>: Which product has the highest revenue in Q4?
823
+ 18 <A>: product b
824
+ 19 <T>:
825
+ 20 [Extraction] Extract the relevant data points from <IC> that are
826
+ needed to answer <Q>.
827
+ 21 [Calculation] Analyze and compare the extracted data to determine the
828
+ correct result.
829
+ 22 [Conclusion] Based on the analysis, state why <A> is the correct
830
+ answer to <Q>.
831
+ 23
832
+ 24 <Output>:
833
+ 25 [Extraction] The image shows a chart comparing three items-Product A,
834
+ Product B, and Product C-in Q4. The bar heights indicate their
835
+ revenues: Product A at about 120k, Product B at about 150k, and
836
+ Product C at about 90k.
837
+ 26 [Calculation] Comparing these values, Product B’s revenue (150k) is
838
+ greater than both Product A’s (120k) and Product C’s (90k).
839
+ 27 [Conclusion] Since Product B has the highest revenue among the three,
840
+ the correct answer to the question "Which product has the
841
+ highest revenue in Q4?" is product b.
842
+ 28
843
+ 29 Now, according to the requirements and the examples above, convert my
844
+ input into the target reasoning text. Please give me the result
845
+ directly without any explanation or description.
846
+ 30
847
+ 31 <IC>: %s
848
+ 32 <Q>: %s
849
+ 33 <A>: %s
850
+ 34 <T>: %s
851
+ 35 <Output>:
852
+ PromptS2: Ground-truthconstructionforChartUnderstandingSFT
853
+ Promptsfortheotherdomainsfollowasimilardesign.
854
+ S1.3 INSTRUCTIONSFORVISUALCHECKER
855
+ Thevisualcheckerisprimarilyresponsibleforscoringthethinkingtraceofresponsesgeneratedin
856
+ theGRPOprocess. Itevaluatesthesetraceswithreferencetoexemplars,basedontheirfluencyand
857
+ thedegreetowhichthementionedvisualelementsalignwithI . Promptsfortheotherdomains
858
+ c
859
+ followasimilardesign.
860
+ 1 Given
861
+ 2 <IC>: the data of an image
862
+ 3 <Q>: a question
863
+ 4 <A>: a reference answer
864
+ 5 <R>: a reasoning text
865
+ 6
866
+ 2
867
+ --- PAGE 16 ---
868
+ PublishedasaconferencepaperatICLR2026
869
+ 7 Goal:
870
+ 8 Assess whether <R> correctly and reasonably uses visible data in <IC> to
871
+ justify that the correct answer to <Q> is <A>. Rate the quality as
872
+ low / medium / high according to:
873
+ 9 (a) low: Does not use data from <IC> at all, or the language is not
874
+ fluent/natural, or it fails to indicate the answer to <Q> is <A>.
875
+ 10 (b) medium: Uses data from <IC> and is written fluently, but the
876
+ reasoning is overly brief or insufficiently clear.
877
+ 11 (c) high: Uses data from <IC> and is written fluently; the reasoning
878
+ progresses step by step with depth, each step is correct and
879
+ reasonable; the data from <IC> appears exactly where it should;
880
+ overall, the reasoning text provides very strong support that the
881
+ answer to <Q> is <A>.
882
+ 12
883
+ 13 Example:
884
+ 14 <IC>: [
885
+ 15 {"object": "bar", "attributes": ["˜120k", "Q4"], "label": "Product A"},
886
+ 16 {"object": "bar", "attributes": ["˜150k", "Q4"], "label": "Product B"},
887
+ 17 {"object": "bar", "attributes": ["˜90k", "Q4"], "label": "Product C"},
888
+ 18 {"title": "Quarterly Revenue"}
889
+ 19 ]
890
+ 20 <Q>: Which product has the highest revenue in Q4?
891
+ 21 <A>: product b
892
+ 22 <R>:
893
+ 23 [Extraction] Reads Q4 bar heights: A ˜120k, B ˜150k, C ˜90k.
894
+ 24 [Calculation] Compares values: B > A and B > C.
895
+ 25 [Conclusion] Therefore, Product B is highest, matching the answer "
896
+ product b".
897
+ 26
898
+ 27 <Output>: medium
899
+ 28
900
+ 29 According to the requirements and examples above, score the input into
901
+ three categories. Please give me the result directly without any
902
+ explanation or description.
903
+ 30
904
+ 31 <IC>: %s
905
+ 32 <Q>: %s
906
+ 33 <A>: %s
907
+ 34 <R>: %s
908
+ 35 <Output>:
909
+ PromptS3: ScoringgenerationsduringGRPOforChartUnderstanding
910
+ S2 EXPERIMENTAL DETAILS AND EXTRA RESULTS
911
+ S2.1 TRAININGANDTESTSETTING
912
+ First,weprovidethestatisticalinformationforthetrainingandtestingphasesofourexperiments
913
+ intheTab. S1. Thetrainingdatasetforeachdomainconsistsofonlyafewthousandsamples. In
914
+ addition,Fig.S1visualizesacomparisonbetweentheground-truthresponsesproducedbytherefiner
915
+ andtheoriginalground-truth,showingthattherefinedversionsarenoticeablymorestructuredand
916
+ placegreateremphasisonintermediatevalues.
917
+ TableS1: Trainingandtestingsetup. DyMEempowersthinkingcapabilitiesbasedonsmalltraining
918
+ sets.
919
+ Domain Trainingset #Trainingsamples SourceofI Testset
920
+ c
921
+ MedicalVQA SLAKE-Train 4,919 BiomedGPT SLAKE-Test
922
+ ChartUnderstanding ChartQA-Train 4,576 DePlot ChartQA-Test
923
+ GeometrySolving Geo170K 6,417 Collected MathVerse
924
+ 3
925
+ --- PAGE 17 ---
926
+ PublishedasaconferencepaperatICLR2026
927
+ Medical(SLAKE). Chart(ChartQA). Math(Geo170k).
928
+ Q:Ifangle1is35◦,whatisangle
929
+ 2? Choices: A:55◦, B:135◦, C:
930
+ Q: What is the sum of making 145◦,D:155◦.
931
+ progress,losingground,andsame w/orefiner:
932
+ Q: What is the MR weighting in in2011? Sinceangle1is35degree,angle
933
+ thisimage? w/orefiner: 2 is 180 - 35 = 145 degree us-
934
+ w/orefiner: Toanswerthequestion,wefirstlo- ing a supplementary relationship.
935
+ Basedonsignalvariationstypical catetherelevantpartofthechart ChooseC.
936
+ of this image, the most plausible for 2011, then extract the values w/refiner:
937
+ answerisT2. forthethreeseriesandsumthem Extraction:TwoparallellinesCD
938
+ w/refiner: uptoobtain99. andABwithatransversal;angle1
939
+ Analysis: On T2, fluid is bright; w/refiner: is35degree,andangle2isonthe
940
+ bone/airandfastflowaredark. Extraction: Read 2011 values: sameside.
941
+ Extraction: In this axial slice, Making progress = 45, Losing Calculation:Same-sideinterioran-
942
+ intraluminalfluidisbright;vessels ground=29,Same=25. glesaresupplementary,soangle2
943
+ lowsignal;ribs/walldark. Calculation:45+29+25=99. =180-35=145degree.
944
+ Conclusion:PatternmatchesT2. Conclusion:Thesumis99. Conclusion:Answer:C.
945
+ FigureS1: Comparisonofground-truthresponsesbeforeandafterrefinement. Comparedto
946
+ theoriginalground-truth,therefinerinjectsrichervisualelementsandenforcesamorestructured
947
+ organization,therebyreducingthelearningburdenforSVLMs.
948
+ S2.2 EXTRARESULTS
949
+ Wealsoreportadditionalexperimentalcontent,includingthediscussionontrainingstrategiesand
950
+ dataorganizationformats,aswellasacomparativeanalysiswithothersimilarmethodsthatintegrate
951
+ SFTandRL.
952
+ Specifically, (1) we first demonstrate the importance of
953
+ constructing vision supervision, which proves essential Table S2: Two-stage training on
954
+ fortrainingSVLMstoproducegroundedthinkingtraces. ChartQA. Rel-corr denotes the
955
+ (2)Wethenexaminetheimpactofstructuredversusopen- relaxed-correctnessmetric. I indicates
956
+ c
957
+ endedoutputformatsonthinkingperformance. (3)Fur- whetheranexplicitimage-contentfield
958
+ thermore,tovalidateourearlierobservationthatSVLMs issupervised(✓yes;✗no).
959
+ arepronetoconvergingtolocaloptima,wepresentperfor-
960
+ manceacrossdifferenttrainingepochs,showingthatSFT
961
+ Model I Rel-corr
962
+ c
963
+ training saturates after only one epoch. (4) We provide
964
+ SmolVLM ✓ 64.32
965
+ adetailedcomparisonwithalternativemethodsthatinte- SmolVLM ✗ 60.09
966
+ grateSFTandRL.(5)Finally,weextendourevaluation LLaVA-OV-S ✓ 63.62
967
+ tostrongerbasemodelsandpuretextualdomains,and(6) LLaVA-OV-S ✗ 52.90
968
+ validatethequalityofgeneratedthinkingtracesthrough
969
+ humanevaluation.
970
+ (1)Intermediatevaluesmatter. AsshowninTableS2,wereporttheeffectofapplyingtwo-stage
971
+ trainingwithvisualsupervisiononSmolVLMandLLaVA-OV-S.Incorporatingvisualsupervision
972
+ significantlyimprovesthebestperformanceachievedduringtraining,despitecertaininstabilities,
973
+ thereby validating itscritical role forSVLMs. Thiseffect isfurther illustrated inFig. S1, where
974
+ visualsupervisioncompelsSVLMstogenerateintermediatereasoningenrichedwithvisualelements,
975
+ whichmakeaclearcontributiontothefinalanswer.
976
+ (2)StructuredthinkingalleviatesthelearningburdenofSVLMs. TableS3reportstheperfor-
977
+ mancegapbetweentrainingwithstructuredthinkingground-truthandwithunconstrainedground-
978
+ 4
979
+ --- PAGE 18 ---
980
+ PublishedasaconferencepaperatICLR2026
981
+ 1.0
982
+ 0.8
983
+ 0.6
984
+ 0.4
985
+ 0.2
986
+ 0.0
987
+ 0 500 1000 1500 2000 2500
988
+ Steps
989
+ ycilop-ffO
990
+ fo
991
+ noitroporP
992
+ DyME
993
+ CHORD-
994
+ LUFFY
995
+ FigureS2: Relativeoff-policyinfluenceduringtraining. Eachcurveisnormalizedtoitsinitial
996
+ valueforcomparability. DyMEmeasuresSFT/(SFT+RL)(rawinlightertone,Gaussian-smoothed
997
+ indarkertone),CHORD-µtrackstheglobalweightµ(t),andLUFFYadoptsapolicy-shapingproxy
998
+ E[f(π (a))]withf(x)= x . Allmethodsrevealtheshiftfromoff-policyguidancetoon-policy
999
+ θ x+γ
1000
+ optimization,albeitwithdistinctdecaydynamics.
1001
+ TableS3: Effectoftemplatedoutputacrossmodelsandtasks. ✓denotesfixed-templateoutput,
1002
+ whereas✗denotesfree-formgeneration.
1003
+ Model Template Chart Medical
1004
+ SmolVLM ✓ 60.10 59.38
1005
+ SmolVLM ✗ 59.24 56.13
1006
+ LLaVA-OV-S ✓ 52.87 74.52
1007
+ LLaVA-OV-S ✗ 50.86 72.64
1008
+ truth. Whileopen-endedexplorationisoftenbeneficialforLVLMs,thelimitedcapacityofSVLMs
1009
+ makesunconstrainedexplorationlesseffective,asittendstobeaimlessandincreasesthelearning
1010
+ burden. GiventhatSVLMsaredesignedfortask-specificratherthangeneral-purposescenarios,em-
1011
+ ployingtailoredthinkingtemplatesforeachtaskprovesmoresuitableandyieldsbetterperformance.
1012
+ Forinstance,SmolVLMachieves60.10vs.59.24onChartQAand59.38vs.56.13onMedicalVQA,
1013
+ withLLaVA-OV-Sexhibitingsimilargains.
1014
+ (3)ComparisonbetweenannealedSFTlossandDyME.
1015
+ AsshowninFig.S2, wecomparetherelativeSFT(off- Table S4: SVLM performance satu-
1016
+ policy)influenceacrosstrainingstepsforthreeapproaches: rates after a single training epoch.
1017
+ DyME,CHORD(Zhangetal.,2025b),andLUFFY(Yan
1018
+ Scoreisdomain-specific: chartdomain
1019
+ etal.,2025). ForDyMEandCHORD,thecurvesrepresent usesRel-corr,whilethemedicaldo-
1020
+ thenormalizedweightoftheSFTlossateachstep,while mainusestheaverageofaccuracyand
1021
+ forLUFFYthecurvereflectsthetrajectoryofSFTgradi- recallvalues.
1022
+ entshapingasafunctionofpredictionprobability(which
1023
+ generally correlates with training steps). These curves
1024
+ Model Domain Epoch Score
1025
+ highlightthedynamicnatureofDyME.Becauseoftheex-
1026
+ 1 60.70
1027
+ tremelylimitedcapacityofSVLMs,theirlearningpatterns
1028
+ LLaVA-OV-S Chart 5 60.44
1029
+ canshiftsignificantlyevenbetweenadjacentsteps,leading 10 60.12
1030
+ torapidforgettingofpreviouslyacquiredmodes. Unlike
1031
+ 1 60.22
1032
+ CHORD, which relies on a smooth annealing schedule Chart 5 63.21
1033
+ thatdecaysquicklyandisill-suitedtosuchsmallmodels, SmolVLM 10 62.22
1034
+ DyMEassignsweightsdirectlybasedonmodeloutputs. 1 71.73
1035
+ Medical 5 71.80
1036
+ Thisproducesahighlydynamicandirregulardecay,bet- 10 72.05
1037
+ ter accommodating the instability of SVLMs. LUFFY
1038
+ adopts a shaping function f(x) = x (γ=0.1), which
1039
+ x+γ
1040
+ alsoinducesadynamicdecaywithprobabilitybutremainsheuristicandmaynotalignwellwith
1041
+ thelocal-optimumtendencyofSVLMs. Overall,DyMEisexplicitlytailoredforSVLMs,whereas
1042
+ 5
1043
+ --- PAGE 19 ---
1044
+ PublishedasaconferencepaperatICLR2026
1045
+ TableS5: Detailedlearningtrajectoriesdemonstratingrigoroustuning. Wereporttheperfor-
1046
+ manceacrossmultiplesettingstoshowtheirfulllearningtrajectories. Two-stagebaselinesinclude
1047
+ variantswithandwithoutKLpenaltiestoensureoptimalperformanceiscaptured.
1048
+ DataQuality Method Performanceacrossepochs(1,3,5,10) Bestperf.
1049
+ DyME(ours,pure) Reportfinalscoredirectly 61.9
1050
+ SFT 43.1 → 47.9 → 50.0 → 50.5 50.5
1051
+ Low
1052
+ Two-stage 57.6 → 52.7 → 50.8 → 50.7 57.6
1053
+ Two-stage(w/KL) 54.2 → 55.4 → 52.6 → 54.2 55.4
1054
+ DyME(ours,pure) Reportfinalscoredirectly 64.9
1055
+ SFT 53.6 → 56.5 → 57.8 → 56.4 57.8
1056
+ Medium
1057
+ Two-stage 59.9 → 52.8 → 53.0 → 53.1 59.9
1058
+ Two-stage(w/KL) 59.0 → 60.6 → 60.6 → 60.8 60.8
1059
+ DyME(ours,pure) Reportfinalscoredirectly 68.5
1060
+ SFT 58.2 → 59.1 → 61.0 → 61.6 61.6
1061
+ High
1062
+ Two-stage 51.6 → 54.0 → 54.5 → 54.4 54.5
1063
+ Two-stage(w/KL) 61.7 → 60.9 → 62.7 → 61.8 62.7
1064
+ CHORDandLUFFYmaybemoreappropriateforstrongerbasemodels,reflectingcomplementary
1065
+ strengths.
1066
+ (4)SVLMsconvergerapidly.TableS4showsthatSVLMsconvergeextremelyquickly:performance
1067
+ afteronlyoneepochiscomparableto,orevenexceeds,thataftertenepochs(e.g.,LLaVA-OV-S
1068
+ achieves 60.70 vs. 60.12 on the Chart domain). This indicates that the very limited capacity of
1069
+ SVLMsmakesthempronetooverfittingtolocaloptima. Italsosubstantiatesourearlierclaimthat
1070
+ suchrapidconvergenceleavesonlyanarrowwindowforbalancingSFTandRL,makingitdifficultto
1071
+ achievethetrade-offthroughempiricalhyperparametertuning. Consequently,staticfusionmethods
1072
+ areunsuitableforSVLMs.
1073
+ To ensure a rigorous comparison, we further report the full learning trajectories of baselines in
1074
+ Table S5. We evaluated the Two-stage baseline (with and without KL penalty) and SFT across
1075
+ multipleepochs(1,3,5,10)tocapturetheirpeakperformance. Theresultsconfirmthatevenwith
1076
+ optimalstopping,thebaselinesconsistentlyunderperformDyME,whichachievessuperiorresultsina
1077
+ singletrainingrunwithouttheneedforepochselection.
1078
+ (5)Generalityacrosscomplexreasoningandpuretext. TodemonstratethescalabilityofDyME,
1079
+ we applied it to two new domains without modifying the core algorithm: Physical Reasoning
1080
+ (A-OKVQA)andPureTextReasoning(GSM8K).
1081
+ • PhysicalReasoning(A-OKVQA):Weaddressedthechallengeofopen-endedvisualreasoning
1082
+ bytestingonA-OKVQA.WeusedQwen2.5-VL-7BtoautomaticallygenerateVisualFactsusing
1083
+ the prompt defined in §S1 (e.g., “man, wearing a light blue and white shirt...”). As shown in
1084
+ TableS6,DyMEachievedamassivegainof+18.8%(54.2%→73.0%),provingthatthemethod
1085
+ scaleseffortlesslytotasksrequiringworldknowledgeandcommonsense.
1086
+ • PureTextReasoning(GSM8K):Inpuretextdomains,the“VisualFact”extractionstepisnaturally
1087
+ skipped. OntheGSM8Kmathbenchmark,DyMEimprovedQwen2.5-0.5Bfrom49.5%to55.3%,
1088
+ demonstratingthattheparadigmgeneralizesevenwhen“vision”isabsent.
1089
+ Theseresults,combinedwiththeChartQAimprovementsonthestrongerQwen2.5-VL-7Bmodel,
1090
+ confirm that DyME is not limited by the extraction step. By leveraging off-the-shelf LVLMs to
1091
+ automatevisualfactgeneration,theframeworkisimmediatelyapplicabletodiversevisualandtextual
1092
+ domains.
1093
+ Limitations on Abstract Visuals. We acknowledge that the VS module may face challenges
1094
+ in scenarios where “Visual Facts” are intrinsically difficult to define or extract, such as memes
1095
+ (relying on irony or cultural context) or highly abstract non-commonsense reasoning. However,
1096
+ 6
1097
+ --- PAGE 20 ---
1098
+ PublishedasaconferencepaperatICLR2026
1099
+ Table S6: Generality of DyME across New Domains. We demonstrate performance gains on
1100
+ ComplexScenes(A-OKVQA),PureText(GSM8K),andwithstrongerbasemodels(Qwen2.5-VL-
1101
+ 7B).BaselinesfortextusestandardGRPO.
1102
+ Domain Task BaseModel Method Baseline(%) DyME(%)
1103
+ WorldKnowledge A-OKVQA LLaVA-OV-S Two-stage 54.2 73.0(+18.8)
1104
+ PureText GSM8K Qwen2.5-0.5B GRPO 49.5 55.3(+5.8)
1105
+ NewLVLM ChartQA Qwen2.5-VL-7B SFT 87.3 89.6(+2.3)
1106
+ ourprimaryobjectiveistoempowerSVLMsforpractical,real-worldproductiontasks(e.g.,chart
1107
+ processing,medicaldiagnostics,geometricsolving). Inthesestructuredandsemi-structureddomains
1108
+ whereSVLMsaremostcommonlydeployed,VisualFactsarewell-definedandDyMEproveshighly
1109
+ effective.
1110
+ (6) Human evaluation of CoT quality. Automatic metrics like relaxed accuracy do not fully
1111
+ reflect the quality of the reasoning process. To verify whether DyME generates genuinely better
1112
+ thinkingtraces,weconductedahumanevaluationon100randomlysampledinstancesfromChartQA.
1113
+ AnnotatorsjudgedthevalidityofthegeneratedCoTbasedonitslogicalcoherenceandgrounding.
1114
+ As shown in Table S7, DyME produces traces that are slightly more concise (shorter length) but
1115
+ significantlymorevalid(validityrate∼70%)comparedtotheTwo-stagebaseline(∼30-40%). This
1116
+ confirms that DyME effectively mitigates the generation of “pseudo thinking traces” that plague
1117
+ standardSFT/Two-stagetraining.
1118
+ TableS7: HumanevaluationofCoTqualityonChartQA.
1119
+ BaseModel Method Avg. CoTLength HumanEval(Valid%)
1120
+ Two-stage ∼76.3Words 31%
1121
+ LLaVA-OV-S
1122
+ DyME ∼69.7Words 68%
1123
+ Two-stage ∼84.5Words 40%
1124
+ SmolVLM
1125
+ DyME ∼75.4Words 72%
1126
+ S3 SHOWCASES
1127
+ Before presenting the model outputs, we first illustrate the data quality definitions used in our
1128
+ Algorithmic Validation (Section 4.1 of the main paper). Table S8 showcases examples of Low
1129
+ (Undesigned),Medium(Standard),andHigh(Premium)qualityChain-of-Thoughtsupervisionfor
1130
+ thesamequestion. ThisvisualizesthesignificantgapinstructureanddetailthatDyMEmustbridge
1131
+ whentrainedonnon-premiumdata. Furthermore,TableS9illustratesthecomprehensiveformatof
1132
+ oursupervisiondata,encompassingtheinputimage,theassociatedquestion,theextractedvisual
1133
+ facts,andtheground-truthresponse.
1134
+ WepresentdialogueinstancesofSmolVLM,LLaVA-OV-S,andInternVL2-S,whichweretrained
1135
+ withDyMEinourexperiments,ontasksinthedomainsofmedicalVQA,chartunderstanding,and
1136
+ geometry. AsshowninFigS3toFigS5,allmodelstrainedwithDyMEdemonstratetheabilityto
1137
+ generateeffectivethinkingtraceswithaccurateintermediatevalues(ingreen),whichplayacrucial
1138
+ roleinreachingthefinalcorrectanswer.
1139
+ Thethinkingprocessesofthesemodelsarerelativelyfixedandtemplate-based. Thisisaresultof
1140
+ our vision supervision module taking effect, as SVLMs should not be overly broad and general;
1141
+ otherwise,theirexplorationcaneasilydivergeinsteadofconverging.
1142
+ 7
1143
+ --- PAGE 21 ---
1144
+ PublishedasaconferencepaperatICLR2026
1145
+ TableS8: Showcaseofchain-of-thought(CoT)supervisionwithdifferentqualityforChartQA.
1146
+ Type Content
1147
+ Question Whendoestheunfavorableviewreachthepeak?
1148
+ High-quality <SUMMARY>IinspectthelinegraphshowingveryunfavorableviewsofPakistaninIndiaover
1149
+ CoT time,andidentifytheyearwherethepercentagepeaksbylocatingthemaximumpointonthe
1150
+ curve.</SUMMARY>
1151
+ <CAPTION>Linegraphtitled“VeryunfavorableviewsofPakistanincreasinginIndia”.The
1152
+ y-axisshowstheshareofrespondentswithaveryunfavorableview(0–100%). Thex-axis
1153
+ showsyears2013–2017.Markedvalues:54%(2013),49%(2014),51%(2015),55%(2016),
1154
+ and64%(2017).Theseriesexhibitsanoverallupwardtrendandpeaksin2017.</CAPTION>
1155
+ <REASONING>Thepeakcorrespondstothehighestpercentageonthey-axis. Amongthe
1156
+ annotatedyears,2017hasthelargestvalue(64%),higherthan2013–2016.Therefore,thepeak
1157
+ occursin2017.</REASONING>
1158
+ <CONCLUSION>2017</CONCLUSION>
1159
+ Medium-quality Goal:Determinetheyearwhentheunfavorableviewreachesitsmaximum.
1160
+ CoT Observation:Thevaluesforeachyearare:2013:0,2014:0,2015:0,2016:55,and2017:64.
1161
+ Reasoning:Comparingthesevalues,thelargestvalueis64,whichappearsin2017.
1162
+ Conclusion:Theunfavorableviewreachesitspeakin2017.
1163
+ Low-quality I’mtryingtofigureouttheyearwhentheunfavorableviewreachesitshighestpoint.Thevalues
1164
+ CoT staylowuntil2016,wheretheyriseto55,andthenclimbfurtherto64in2017.Thismakes
1165
+ 2017theyearwhentheunfavorableviewreachesitsmaximum.
1166
+ TableS9: ShowcaseofacompleteCoTtrainingexample.
1167
+ Type Content
1168
+ Question Whatisthemanbythebagsawaiting?
1169
+ VisualFact {”description”: ”Amanstandsinthemiddleofaresidentialstreet,facingawayfromthe
1170
+ camera. Heholdsaredbagandpullsawheeledblacksuitcase,withanotherblacksuitcase
1171
+ placednearby.Houses,parkedcars,andtreeslinethebackground,suggestingheiswaiting
1172
+ tocrossorbepickedup.”,”objects”:[ ”name”:”man”,”attributes”:[”lightblueandwhite
1173
+ shirt”,”bluejeans”,”carryingaredbag”,”pullingawheeledsuitcase”],”position”:”center”,
1174
+ ”name”:”redbag”,”attributes”:[”heldbytheman”],”position”:”leftoftheman”, ”name”:
1175
+ ”blacksuitcase”,”attributes”:[”wheeled”,”beingpulled”],”position”:”neartheman’sfeet”
1176
+ , ”name”:”blacksuitcase”,”attributes”:[”lyingontheground”],”position”:”neartheman”
1177
+ , ”name”:”street”,”attributes”:[”asphalt”,”urbanresidential”],”position”:”foreground”,
1178
+ ”name”:”houses”,”attributes”:[”residentialhomes”],”position”:”backgroundleft”, ”name”:
1179
+ ”parkedcars”,”attributes”:[”redSUV”,”othervehicles”],”position”:”backgroundcenterand
1180
+ left”, ”name”:”trees”,”attributes”:[”greenfoliage”],”position”:”backgroundright” ]}
1181
+ GTResponse Atrainwouldnotappearonaresidentialstreet,andtheman’sluggagesuggestsheiswaiting
1182
+ fortransportationratherthanadelivery.Thenearbyskateboarderisnotinteractingwithhim,
1183
+ leavingacabasthemostplausibleoption.
1184
+ 8
1185
+ --- PAGE 22 ---
1186
+ PublishedasaconferencepaperatICLR2026
1187
+ (a)ShowcaseofMedicalVQA
1188
+ (b)ShowcaseofChartUnderstanding
1189
+ (c)ShowcaseofGeometrySolving
1190
+ FigureS3: ShowcasesofSmolVLM.TheSVLMoriginallyproduceshallucinatedanswers(red),
1191
+ whiletheDyME-trainedmodelgeneratesstructuredthinkingtraces(green)thatincorporategrounded
1192
+ values,effectivelyimprovingtheperformance.
1193
+ 9
1194
+ --- PAGE 23 ---
1195
+ PublishedasaconferencepaperatICLR2026
1196
+ (a)ShowcaseofMedicalVQA
1197
+ (b)ShowcaseofChartUnderstanding
1198
+ (c)ShowcaseofGeometrySolving
1199
+ FigureS4: ShowcasesofInternVL2-S.TheSVLMoriginallyproduceshallucinatedanswers(red),
1200
+ whiletheDyME-trainedmodelgeneratesstructuredthinkingtraces(green)thatincorporategrounded
1201
+ values,effectivelyimprovingtheperformance.
1202
+ 10
1203
+ --- PAGE 24 ---
1204
+ PublishedasaconferencepaperatICLR2026
1205
+ (a)ShowcaseofChartUnderstanding
1206
+ (b)ShowcaseofMedicalVQA
1207
+ (c)ShowcaseofGeometrySolving
1208
+ FigureS5: ShowcasesofLLaVA-OV-S.TheSVLMoriginallyproduceshallucinatedanswers(red),
1209
+ whiletheDyME-trainedmodelgeneratesstructuredthinkingtraces(green)thatincorporategrounded
1210
+ values,effectivelyimprovingtheperformance.
1211
+ 11
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.11.0
2
+ pycocotools
3
+ matplotlib
4
+ datasets
5
+ peft==0.17.0
6
+ qwen_vl_utils
7
+ wandb
8
+ transformers==4.57.1
9
+ deepspeed==0.18.1
10
+ pycocotools
11
+ trl==0.23.1
12
+ flash-attn==2.7.4.post1
13
+ scikit-image
14
+ openai
15
+ spacy
16
+ autoawq==0.2.9
reward_utils/__pycache__/format_checks.cpython-310.pyc ADDED
Binary file (2.82 kB). View file
 
reward_utils/compute_rewards.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import concurrent.futures
2
+ from typing import List, Dict, Any
3
+ from .checker import RewardCalculator
4
+
5
+ def split_initial_context(text: str):
6
+ text = text.lower()
7
+ flag = 'answer:'
8
+ if flag in text:
9
+ ans = text.split(flag)[-1].strip()
10
+ context = text.split(flag)[0].strip()
11
+ ans = ans.strip('.')
12
+ else:
13
+ context = text
14
+ ans = text
15
+ return context, ans
16
+
17
+ def calculate_rewards_in_parallel(
18
+ checker: RewardCalculator,
19
+ batch_data: Dict[str, Any],
20
+ gpu_id: int,
21
+ num_threads: int = 8,
22
+ task='chart'):
23
+ """
24
+ Calculates accuracy rewards for a batch of data in parallel using a thread pool.
25
+
26
+ Args:
27
+ batch_data: A dictionary containing lists of data, including 'response',
28
+ 'prompt', 'image', 'answer', and an optional 'tp' (answer_type).
29
+ gpu_id: The ID of the GPU to be used for processing.
30
+ num_threads: The number of parallel threads to use.
31
+
32
+ Returns:
33
+ A list of calculated reward scores for each item in the batch.
34
+ """
35
+ # Extract lists of data from the input dictionary
36
+ responses = batch_data['response']
37
+ predictions = []
38
+ for r in responses:
39
+ c, p = split_initial_context(r)
40
+ predictions.append(p)
41
+ prompts = batch_data['prompt']
42
+ # questions = batch_data['question']
43
+ answers = batch_data['answer']
44
+ hints = batch_data['hints'] if 'hints' in batch_data else [""] * len(responses)
45
+ num_samples = len(responses)
46
+
47
+ # Safely get 'answer_types', providing a list of Nones as a default
48
+ # This fixes a bug in the original code.
49
+ answer_types = batch_data.get('tp', [None] * num_samples)
50
+
51
+ # Prepare the arguments for each task by zipping the data together.
52
+ # This creates an iterator of tuples, where each tuple contains all args for one call.
53
+ in_answers = answers
54
+ if 'world' in task:
55
+ in_answers = batch_data['direct_answers']
56
+ task_answer_args = zip(
57
+ predictions,
58
+ in_answers,
59
+ [task] * num_samples,
60
+ # [gpu_id] * num_samples,
61
+ # answer_types,
62
+ # hints
63
+ )
64
+ task_thinking_args = zip(
65
+ responses,
66
+ prompts,
67
+ answers,
68
+ hints,
69
+ [task] * num_samples
70
+ )
71
+
72
+ # Use a ThreadPoolExecutor to process the data in parallel.
73
+ with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
74
+ # Instead of a separate function, use a lambda to unpack the arguments.
75
+ # The '*' operator unpacks each tuple from task_args into positional arguments
76
+ # for the get_acc_reward function.
77
+
78
+ format_rewards = list(executor.map(lambda r: checker.get_format_reward(r, task=task), responses))
79
+ answer_rewards = list(executor.map(lambda args: checker.get_answer_reward(*args), task_answer_args))
80
+ thinking_rewards = list(executor.map(
81
+ lambda args: checker.get_thinking_reward_prompt(*args), task_thinking_args))
82
+
83
+ rewards = [0 if f == 0 else f + a + t for f, a, t in zip(format_rewards, answer_rewards, thinking_rewards)]
84
+
85
+ return rewards, format_rewards, answer_rewards, thinking_rewards
86
+
87
+ def refine_context_in_parallel(
88
+ refiner,
89
+ questions: List[str],
90
+ hints: List[str],
91
+ reference_answers: List[str],
92
+ task,
93
+ gpu_id: int,
94
+ num_threads: int = 8):
95
+ """
96
+ Refines contexts for a batch of data in parallel using a thread pool.
97
+
98
+ Args:
99
+ questions: A list of questions.
100
+ hints: A list of hints corresponding to each question.
101
+ reference_answers: A list of reference answers.
102
+ tasks: A list of task types corresponding to each question.
103
+ gpu_id: The ID of the GPU to be used for processing.
104
+ num_threads: The number of parallel threads to use.
105
+
106
+ Returns:
107
+ A list of refined contexts for each question.
108
+ """
109
+ num_samples = len(questions)
110
+ tasks = [task] * num_samples
111
+ # Prepare the arguments for each task by zipping the data together.
112
+ task_args = zip(
113
+ questions,
114
+ hints,
115
+ reference_answers,
116
+ tasks,
117
+ [gpu_id] * num_samples
118
+ )
119
+
120
+ # Use a ThreadPoolExecutor to process the data in parallel.
121
+ with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
122
+ refined_contexts = list(executor.map(
123
+ lambda args: refiner.refine_hint(*args), task_args
124
+ ))
125
+
126
+ return refined_contexts
reward_utils/refiner.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import Optional
3
+
4
+ from client_utils.openai_api import OpenAIClient
5
+ from data_utils.chart.evaluator import eval_one_chart
6
+ from data_utils.commom_util import prompt_ic
7
+
8
+ import os
9
+ import time
10
+ from filelock import FileLock
11
+ TEMPLATE_FILE = "best_template.txt"
12
+ LOCK_FILE = "best_template.txt.lock"
13
+
14
+
15
+ TEMPLATE_REFRESH_INTERVAL = 60 # Interval (in seconds) to refresh template from file
16
+
17
+
18
+ class ContextRefiner:
19
+ """
20
+ A class to refine hints / reasoning with an external LLM.
21
+ Encapsulates logic for template management and refinement calls.
22
+ """
23
+
24
+ def __init__(self, RL_CONFIG, CLIENT_CONFIG, gpu_id=0):
25
+ """
26
+ Initializes the ContextRefiner.
27
+
28
+ Args:
29
+ RL_CONFIG: RL-related configuration dict.
30
+ CLIENT_CONFIG: LLM client configuration dict.
31
+ gpu_id: process / GPU id used to select API server.
32
+ """
33
+ self.refine_templetes = ["""Goal: [State the user's objective, e.g., Find the year with the highest sales]
34
+ Observation: [List key data points from the chart, e.g., 2020: 150, 2021: 200, 2022: 180]
35
+ Reasoning: [State the logical step, e.g., Compare the values. 200 is the maximum.]
36
+ Conclusion: [Draw the conclusion, e.g., The year with the highest sales was 2021.]
37
+ """]
38
+ self.template_lock = FileLock(LOCK_FILE)
39
+ # Set to 0 so that the first call will force a refresh from file
40
+ self.last_template_check_time = 0
41
+ if CLIENT_CONFIG['client_type'] == 'openai':
42
+ if CLIENT_CONFIG['init_port'] is not None:
43
+ num_server = int(CLIENT_CONFIG['num_server'])
44
+ server_id = gpu_id % num_server
45
+ CLIENT_CONFIG['api_base'] = CLIENT_CONFIG['api_base'] % str(CLIENT_CONFIG['init_port'] + server_id)
46
+ self.client = OpenAIClient(config=CLIENT_CONFIG)
47
+ else:
48
+ raise ValueError(f"Client type '{CLIENT_CONFIG['client_type']}' not supported.")
49
+
50
+ def _check_and_update_template(self):
51
+ """
52
+ (Private method) Check whether we need to refresh the template from file.
53
+ This operation is process-safe.
54
+ """
55
+ current_time = time.time()
56
+
57
+ # 1. Check whether the refresh interval has passed
58
+ if (current_time - self.last_template_check_time) < TEMPLATE_REFRESH_INTERVAL:
59
+ return # Not yet time, keep using cached template
60
+
61
+ # 2. Try to acquire the lock and read (short timeout since reading should be fast)
62
+ try:
63
+ # print(f"[Process {os.getpid()}] Checking for template update...") # Uncomment for debugging
64
+ with self.template_lock.acquire(timeout=5):
65
+
66
+ # --- Lock acquired, safe to read ---
67
+ if not os.path.exists(TEMPLATE_FILE):
68
+ # File does not exist, keep default template
69
+ self.last_template_check_time = current_time
70
+ return
71
+
72
+ with open(TEMPLATE_FILE, "r", encoding="utf-8") as f:
73
+ new_template = f.read().strip()
74
+
75
+ # If file content is valid and different, update in-memory template
76
+ if new_template and new_template != self.refine_templetes[0]:
77
+ self.refine_templetes = [new_template]
78
+ print(f"[Process {os.getpid()}] Refiner template updated from file.")
79
+
80
+ # Regardless of success, update last check time to avoid frequent retries
81
+ self.last_template_check_time = current_time
82
+
83
+ except TimeoutError:
84
+ # Failed to acquire lock (another process is likely writing)
85
+ # Do not block; skip and try again next time
86
+ print(f"[Process {os.getpid()}] Failed to acquire lock for template read, using cached version.")
87
+ # Update time to avoid immediate retry
88
+ self.last_template_check_time = current_time
89
+
90
+ except Exception as e:
91
+ print(f"[Process {os.getpid()}] Error reading template file: {e}")
92
+ self.last_template_check_time = current_time
93
+
94
+ def refine_hint(self, question, hint: str, reference_answer: str, task: str, gpu_id=None):
95
+ if hint == "":
96
+ return hint
97
+
98
+ self._check_and_update_template()
99
+ system_prompt = None
100
+ if 'medical' in task:
101
+ system_prompt = 'You are a seasoned professional in the field of medical image analysis, demonstrating exceptional expertise and insight into complex medical imaging data. Your output should be only judgement, without any additional text or explanation.'
102
+ elif 'math' in task:
103
+ system_prompt = 'You are a seasoned professional in the field of mathematics, demonstrating exceptional expertise and insight into complex mathematical problems. Your output should be only judgement, without any additional text or explanation.'
104
+ elif 'chart' in task:
105
+ system_prompt = 'You are a seasoned professional in the field of chart analysis, demonstrating exceptional expertise and insight into complex chart data. Your output should be only judgement, without any additional text or explanation.'
106
+ elif 'world' in task:
107
+ system_prompt = 'You are a seasoned professional in the field of world knowledge and image analysis, demonstrating exceptional expertise and insight into complex real-world scenarios. Your output should be only judgement, without any additional text or explanation.'
108
+ else:
109
+ Exception('Unknown expert task')
110
+
111
+ try:
112
+ in_context_example = self.client.get_completion(
113
+ prompt_ic % hint,
114
+ system_prompt=system_prompt,
115
+ max_tokens=5000
116
+ )
117
+
118
+ if 'chart' in task or 'world' in task:
119
+ if 'chart' in task:
120
+ from data_utils.chart.prompts import prompt_thinking_reward, prompt_refine
121
+ else:
122
+ from data_utils.aokvqa.prompts import prompt_thinking_reward, prompt_refine
123
+ # Construct the final prompt for the evaluator model.
124
+ evaluation_prompt = prompt_refine % (
125
+ in_context_example,
126
+ question,
127
+ reference_answer,
128
+ self.refine_templetes[0]
129
+ )
130
+ output = self.client.get_completion(
131
+ evaluation_prompt,
132
+ system_prompt=system_prompt,
133
+ max_tokens=1000
134
+ )
135
+ return output
136
+ else:
137
+ raise ValueError(f"Task '{task}' not supported for thinking reward.")
138
+ except Exception as e:
139
+ print(f"An error occurred during thinking reward prompt generation: {e}")
140
+ return None
141
+
142
+
143
+ class ContextRefinerLocal:
144
+ """
145
+ A local (non-LLM) refiner that simply returns the original hint.
146
+ Used when remote refinement is disabled or not desired.
147
+ """
148
+
149
+ def __init__(self, RL_CONFIG, CLIENT_CONFIG, gpu_id=0):
150
+ """
151
+ Initializes the local ContextRefiner.
152
+
153
+ Args:
154
+ RL_CONFIG: RL-related configuration dict.
155
+ CLIENT_CONFIG: client configuration dict (unused here).
156
+ gpu_id: process / GPU id (unused here).
157
+ """
158
+ # Do nothing; local refiner is a no-op.
159
+ pass
160
+
161
+ def refine_hint(self, question, hint: str, reference_answer: str, task: str, gpu_id=None):
162
+ return hint
tests/test_data_health_probe.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for batch data health diagnostics."""
2
+ import torch
3
+
4
+ from opsd_utils.diagnostics import (
5
+ _detect_char_repeat,
6
+ summarize_batch_data_health,
7
+ )
8
+
9
+
10
+ def test_detect_char_repeat_cjk():
11
+ assert _detect_char_repeat("Goal: " + "其" * 10)
12
+
13
+
14
+ def test_summarize_batch_data_health_empty_vf():
15
+ samples = [
16
+ {"prompt": "q1", "visual_fact_hint": ""},
17
+ {"prompt": "q2", "visual_fact_hint": "bar value 3"},
18
+ ]
19
+ stats = summarize_batch_data_health(samples)
20
+ assert stats["visual_fact_empty_rate"] == 0.5
21
+ assert stats["batch_size"] == 2
22
+
23
+
24
+ def test_summarize_batch_data_health_pixel_nan():
25
+ samples = [{"prompt": "q", "visual_fact_hint": "x"}]
26
+ pixel = torch.tensor([float("nan"), 1.0, 2.0])
27
+ stats = summarize_batch_data_health(samples, pixel_values=pixel)
28
+ assert stats["pixel_has_nan"] is True
tests/test_degeneration_probe.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Regression tests for completion degeneration heuristics."""
2
+ from unittest.mock import patch
3
+
4
+ import torch
5
+
6
+ from opsd_utils import debug_log as opsd_debug
7
+ from opsd_utils.diagnostics import (
8
+ _detect_char_repeat,
9
+ _detect_degeneration,
10
+ _detect_repeat_loop,
11
+ _detect_single_token_repeat,
12
+ _max_same_token_run,
13
+ is_degenerate_completion,
14
+ log_generate_probe,
15
+ )
16
+
17
+
18
+ class _FakeTokenizer:
19
+ eos_token_id = 151645
20
+ pad_token_id = 151643
21
+ bos_token_id = None
22
+
23
+ def decode(self, ids, skip_special_tokens=False):
24
+ if isinstance(ids, torch.Tensor):
25
+ ids = ids.tolist()
26
+ return " ".join(str(i) for i in ids)
27
+
28
+
29
+ def test_single_token_repeat_detects_cjk_loop():
30
+ ids = [39992, 25, 7379] + [41146] * 40
31
+ assert _detect_single_token_repeat(ids)
32
+ assert _max_same_token_run(ids) == (40, 41146)
33
+
34
+
35
+ def test_ngram_repeat_not_limited_to_first_eight_tokens():
36
+ prefix = list(range(20))
37
+ gram = [9, 8, 7]
38
+ ids = prefix + gram * 5
39
+ assert _detect_repeat_loop(ids)
40
+
41
+
42
+ def test_char_repeat_detects_qiqiqi():
43
+ assert _detect_char_repeat("其其其其其其")
44
+
45
+
46
+ def test_is_degenerate_completion_detects_repeat():
47
+ ids = [39992, 25] + [41146] * 20
48
+ assert is_degenerate_completion(ids, "Goal: x\n" + "其" * 40)
49
+
50
+
51
+ def test_short_numeric_answer_not_degenerate_without_answer_flag():
52
+ ids = [198, 17, 15, 18, 15] # \n2030
53
+ assert not is_degenerate_completion(ids, "\n2030", require_answer_flag=False)
54
+ assert is_degenerate_completion(ids, "\n2030", require_answer_flag=True)
55
+
56
+
57
+ def test_degeneration_flags_missing_answer():
58
+ ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
59
+ text = "Goal: test\nObservation: x\nReasoning: y\nConclusion: z"
60
+ is_deg, reasons = _detect_degeneration(ids, text, answer_flag="Answer:")
61
+ assert is_deg
62
+ assert any(r.startswith("ANSWER_FLAG_COUNT") for r in reasons)
63
+
64
+
65
+ def test_log_generate_probe_does_not_shadow_tokenizer_across_samples():
66
+ """sample[0] single-token repeat must not break decode for sample[1]."""
67
+ repeat_tail = [24] * 12
68
+ row0 = [39992, 25, 7379] + repeat_tail + [0] * (200 - 3 - len(repeat_tail))
69
+ row1 = [39992, 25, 7379, 100, 101, 102] + [0] * 194
70
+ completion_ids = torch.tensor([row0, row1], dtype=torch.long)
71
+ completion_mask = torch.tensor(
72
+ [[1] * 15 + [0] * 185, [1] * 6 + [0] * 194],
73
+ dtype=torch.long,
74
+ )
75
+ is_eos = torch.zeros_like(completion_mask, dtype=torch.bool)
76
+ is_eos[:, 14] = True
77
+ is_eos[:, 5] = True
78
+ eos_idx = torch.tensor([14, 5], dtype=torch.long)
79
+ completions = ["Goal: repeat\n" + "x" * 20, "Goal: ok\nAnswer: 1"]
80
+
81
+ with patch.object(opsd_debug, "should_log_probe", return_value=True):
82
+ stats = log_generate_probe(
83
+ global_step=1,
84
+ trainer_step=1,
85
+ prompt_length=100,
86
+ prompt_completion_ids=torch.zeros(2, 300, dtype=torch.long),
87
+ completion_ids=completion_ids,
88
+ completion_mask=completion_mask,
89
+ is_eos=is_eos,
90
+ eos_idx=eos_idx,
91
+ completions=completions,
92
+ tokenizer=_FakeTokenizer(),
93
+ generation_config=None,
94
+ max_completion_length=200,
95
+ num_generations=1,
96
+ sample_count=2,
97
+ )
98
+ assert stats["degenerate_count"] >= 1
tests/test_health_monitor.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for TrainingHealthMonitor alerts and correlation."""
2
+ from opsd_utils.health_monitor import (
3
+ ALERT_GEN_CLIP_COLLAPSE,
4
+ ALERT_GEN_REPEAT_DEGEN,
5
+ ALERT_RL_ZERO_SIGNAL,
6
+ TrainingHealthMonitor,
7
+ )
8
+
9
+
10
+ def test_clip_collapse_alert():
11
+ hm = TrainingHealthMonitor({"enabled": True, "log_alerts_immediately": False})
12
+ hm.reset_step(1)
13
+ alerts = hm.record_generate(
14
+ 1,
15
+ {"clipped_rate": 0.85, "eos_terminated_rate": 0.1, "degenerate_rate": 0.2, "repeat_loop_count": 0},
16
+ {"p_greedy_first": 0.99, "p_eos_first": 1e-6},
17
+ )
18
+ assert ALERT_GEN_CLIP_COLLAPSE in alerts
19
+
20
+
21
+ def test_repeat_degen_alert():
22
+ hm = TrainingHealthMonitor({"enabled": True, "log_alerts_immediately": False})
23
+ hm.reset_step(2)
24
+ alerts = hm.record_generate(
25
+ 2,
26
+ {"clipped_rate": 0.3, "eos_terminated_rate": 0.5, "degenerate_rate": 0.6, "repeat_loop_count": 1},
27
+ {},
28
+ )
29
+ assert ALERT_GEN_REPEAT_DEGEN in alerts
30
+
31
+
32
+ def test_rl_zero_signal_alert():
33
+ hm = TrainingHealthMonitor({"enabled": True, "log_alerts_immediately": False})
34
+ hm.reset_step(3)
35
+ hm.record_loss(3, {"advantages_abs_mean": 0.0, "grpo_zero_loss_rate": 0.95})
36
+ assert ALERT_RL_ZERO_SIGNAL in hm._step_alerts
37
+
38
+
39
+ def test_correlate_hints_after_history():
40
+ hm = TrainingHealthMonitor({"enabled": True, "window": 5, "log_every_step": False})
41
+ hm.reset_step(0)
42
+ hm.record_generate(
43
+ 0,
44
+ {"clipped_rate": 0.1, "eos_terminated_rate": 0.9, "degenerate_rate": 0.1, "repeat_loop_count": 0},
45
+ {"p_greedy_first": 0.8, "p_eos_first": 0.01},
46
+ )
47
+ hm.record_optimizer(0, 0.5, 8e-5)
48
+ hm.finish_step(0)
49
+
50
+ hm.reset_step(1)
51
+ hm.record_generate(
52
+ 1,
53
+ {"clipped_rate": 0.9, "eos_terminated_rate": 0.05, "degenerate_rate": 0.5, "repeat_loop_count": 1},
54
+ {"p_greedy_first": 0.995, "p_eos_first": 1e-6},
55
+ )
56
+ hm.record_optimizer(1, 2.5, 8e-5)
57
+ corr = hm.correlate()
58
+ assert "delta_clipped_rate" in corr or "root_cause_hints" in corr
59
+
60
+
61
+ def test_finish_step_returns_metrics_keys():
62
+ hm = TrainingHealthMonitor({"enabled": True, "metrics_every_step": True, "log_every_step": False})
63
+ hm.reset_step(1)
64
+ hm.record_generate(
65
+ 1,
66
+ {"clipped_rate": 0.2, "eos_terminated_rate": 0.8, "degenerate_rate": 0.1, "repeat_loop_count": 0},
67
+ {"p_greedy_first": 0.9, "p_eos_first": 0.001},
68
+ )
69
+ hm.record_optimizer(1, 1.0, 8e-5)
70
+ metrics = hm.finish_step(1)
71
+ assert "completions/degenerate_rate" in metrics
72
+ assert "health/alert_count" in metrics
tests/test_mode_router_rlsd.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """RLSD / COPSD anti-leakage routing tests."""
2
+
3
+ import os
4
+ import sys
5
+
6
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7
+
8
+ import torch
9
+
10
+ from opsd_utils.constants import MODE_GRPO, MODE_OPSD, MODE_SFT
11
+ from opsd_utils.mode_router import route_completion_modes, route_prompt_modes
12
+
13
+
14
+ def _rlsd_cfg(**gate):
15
+ base = {
16
+ "enabled": True,
17
+ "mode": "rlsd",
18
+ "gate": {
19
+ "correct_threshold": 0.5,
20
+ "per_completion_opsd": True,
21
+ "require_format_for_opsd": False,
22
+ **gate,
23
+ },
24
+ }
25
+ return base
26
+
27
+
28
+ def test_rlsd_prompt_correct_grpo():
29
+ acc = torch.tensor([[1.0, 0.0]])
30
+ modes = route_prompt_modes(acc, 2, _rlsd_cfg(), recoverable_flags=[True])
31
+ assert modes == [MODE_GRPO]
32
+
33
+
34
+ def test_rlsd_prompt_wrong_opsd_when_recoverable():
35
+ acc = torch.tensor([[0.0, 0.0]])
36
+ modes = route_prompt_modes(acc, 2, _rlsd_cfg(), recoverable_flags=[True])
37
+ assert modes == [MODE_OPSD]
38
+
39
+
40
+ def test_rlsd_prompt_wrong_sft_when_not_recoverable():
41
+ acc = torch.tensor([[0.0, 0.0]])
42
+ modes = route_prompt_modes(acc, 2, _rlsd_cfg(), recoverable_flags=[False])
43
+ assert modes == [MODE_SFT]
44
+
45
+
46
+ def test_rlsd_per_completion_routing():
47
+ acc = torch.tensor([[1.0, 0.0]])
48
+ fmt = torch.tensor([[1.0, 0.5]])
49
+ modes = route_completion_modes(acc, 2, 2, _rlsd_cfg(), [True], format_rewards=fmt)
50
+ assert modes == [MODE_GRPO, MODE_OPSD]
51
+
52
+
53
+ def test_copsd_opd_alias_matches_rlsd():
54
+ acc = torch.tensor([[0.0, 1.0]])
55
+ cfg = _rlsd_cfg()
56
+ cfg["mode"] = "copsd_opd"
57
+ modes = route_completion_modes(acc, 2, 2, cfg, [True])
58
+ assert modes == [MODE_OPSD, MODE_GRPO]
59
+
60
+
61
+ def test_rlsd_all_wrong_group_first_completion_sft_cold_start():
62
+ """All-wrong group: gen 0 → SFT replace, other wrong gens → OPSD."""
63
+ acc = torch.tensor([[0.0, 0.0, 0.0, 0.0]])
64
+ modes = route_completion_modes(acc, 4, 4, _rlsd_cfg(), [True])
65
+ assert modes == [MODE_SFT, MODE_OPSD, MODE_OPSD, MODE_OPSD]
66
+
67
+
68
+ def test_rlsd_all_wrong_two_prompts():
69
+ acc = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
70
+ modes = route_completion_modes(acc, 2, 4, _rlsd_cfg(), [True, True])
71
+ assert modes == [MODE_SFT, MODE_OPSD, MODE_SFT, MODE_OPSD]
72
+
73
+
74
+ def test_rlsd_partial_correct_no_cold_start_on_wrong():
75
+ acc = torch.tensor([[1.0, 0.0, 0.0]])
76
+ modes = route_completion_modes(acc, 3, 3, _rlsd_cfg(), [True])
77
+ assert modes == [MODE_GRPO, MODE_OPSD, MODE_OPSD]
78
+
79
+
80
+ def test_online_sft_on_all_wrong_can_be_disabled():
81
+ acc = torch.tensor([[0.0, 0.0]])
82
+ cfg = _rlsd_cfg(online_sft_on_all_wrong=False)
83
+ modes = route_completion_modes(acc, 2, 2, cfg, [True])
84
+ assert modes == [MODE_OPSD, MODE_OPSD]
85
+
86
+
87
+ if __name__ == "__main__":
88
+ test_rlsd_prompt_correct_grpo()
89
+ test_rlsd_prompt_wrong_opsd_when_recoverable()
90
+ test_rlsd_prompt_wrong_sft_when_not_recoverable()
91
+ test_rlsd_per_completion_routing()
92
+ test_copsd_opd_alias_matches_rlsd()
93
+ test_rlsd_all_wrong_group_first_completion_sft_cold_start()
94
+ test_rlsd_all_wrong_two_prompts()
95
+ test_rlsd_partial_correct_no_cold_start_on_wrong()
96
+ test_online_sft_on_all_wrong_can_be_disabled()
97
+ print("RLSD routing tests passed.")
tests/test_privileged.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import sys
4
+ import tempfile
5
+
6
+ import pytest
7
+ from PIL import Image
8
+
9
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10
+
11
+ from data_utils.privileged_schema import (
12
+ heuristic_bbox_from_visual_fact,
13
+ normalize_evidence_bbox,
14
+ parse_visual_fact,
15
+ resolve_crop_bbox,
16
+ )
17
+ from opsd_utils import debug_log as opsd_debug
18
+ from opsd_utils.privileged import build_privileged_context, maybe_save_privileged_images
19
+ from opsd_utils.privileged.image_utils import crop_image, load_rgb, resolve_teacher_images
20
+ from opsd_utils.privileged.profiles import effective_profile
21
+
22
+
23
+ def _make_image(path: str, size=(100, 100), color=(255, 0, 0)):
24
+ img = Image.new("RGB", size, color)
25
+ img.save(path)
26
+ return path
27
+
28
+
29
+ def test_text_provider():
30
+ sample = {"hint": "Rep=67", "answer": "Answer: 131"}
31
+ suffix, images = build_privileged_context(sample, ["text"], privileged_profile="text")
32
+ assert "Rep=67" in suffix
33
+ assert "131" in suffix
34
+ assert images == []
35
+
36
+
37
+ def test_hybrid_provider_suffix():
38
+ img = Image.new("RGB", (32, 32))
39
+ sample = {"hint": "step", "visual_fact": "bar=3", "answer": "Answer: 3", "image": img}
40
+ suffix, images = build_privileged_context(
41
+ sample,
42
+ privileged_profile="hybrid",
43
+ opsd_config={"privileged_image": {"mode": "dual"}},
44
+ )
45
+ assert "Visual Facts" in suffix
46
+ assert "Reference" in suffix
47
+ assert len(images) == 2
48
+
49
+
50
+ def test_hybrid_default_single_image_for_chartqa():
51
+ img = Image.new("RGB", (32, 32))
52
+ sample = {"hint": "step", "visual_fact": "bar=3", "answer": "Answer: 3", "image": img}
53
+ suffix, images = build_privileged_context(sample, privileged_profile="hybrid")
54
+ assert "Visual Facts" in suffix
55
+ assert "Reference" in suffix
56
+ assert len(images) == 1
57
+
58
+
59
+ def test_visual_profile_excludes_answer():
60
+ img = Image.new("RGB", (32, 32))
61
+ sample = {"hint": "secret", "visual_fact": '{"objects":[]}', "answer": "Answer: 3", "image": img}
62
+ suffix, _ = build_privileged_context(sample, privileged_profile="visual")
63
+ assert "Visual Facts" in suffix
64
+ assert "Reference Answer" not in suffix
65
+
66
+
67
+ def test_math_lm_downgrade():
68
+ sample = {"hint": "step", "answer": "Answer: 1"}
69
+ profile = effective_profile(sample, "hybrid")
70
+ assert profile == "text"
71
+
72
+
73
+ def test_normalize_evidence_bbox_c2():
74
+ assert normalize_evidence_bbox([0.1, 0.2, 0.8, 0.9]) == [0.1, 0.2, 0.8, 0.9]
75
+ assert normalize_evidence_bbox([0.1, 0.2, 1.5, 0.9]) is None
76
+
77
+
78
+ def test_heuristic_bbox_d2():
79
+ vf = json.dumps({"objects": [{"name": "cat", "position": "center"}]})
80
+ bbox = heuristic_bbox_from_visual_fact(vf)
81
+ assert bbox == [0.25, 0.25, 0.75, 0.75]
82
+
83
+
84
+ def test_crop_image_normalized_bbox():
85
+ img = Image.new("RGB", (100, 100), (0, 255, 0))
86
+ crop, strategy = crop_image(img, bbox_norm=[0.2, 0.2, 0.8, 0.8], strategy="bbox")
87
+ assert strategy == "bbox"
88
+ assert crop.size[0] > 0
89
+
90
+
91
+ def test_resolve_teacher_images_dual():
92
+ img = Image.new("RGB", (80, 80), (0, 0, 255))
93
+ sample = {
94
+ "image": img,
95
+ "visual_fact": json.dumps({"objects": [{"position": "top"}]}),
96
+ }
97
+ images, meta = resolve_teacher_images(sample, "hybrid", crop_cfg={"mode": "dual"})
98
+ assert len(images) == 2
99
+ assert meta["num_teacher_images"] == 2
100
+ assert meta["crop_strategy"] in ("heuristic", "center", "center_fallback", "bbox")
101
+
102
+
103
+ def test_chartqa_enriched_visual_fact_hint():
104
+ """Enriched ChartQA records (F1+F2) should activate VisualFactsProvider."""
105
+ from data_utils.chart.deplot_pipeline import build_deplot_visual_fact
106
+
107
+ sample = {
108
+ "hint": "Goal: Find the lowest value.\nObservation: values are 70, 72, 77.",
109
+ "answer": "Answer: 70",
110
+ "visual_fact_hint": "Goal: Find the lowest value.\nObservation: values are 70, 72, 77.",
111
+ "visual_fact": "Goal: Find the lowest value.\nObservation: values are 70, 72, 77.",
112
+ "visual_fact_deplot": build_deplot_visual_fact(
113
+ {"question": "q"}, "Year | Value\n2019 | 70\n2020 | 72"
114
+ ),
115
+ "image": Image.new("RGB", (64, 64)),
116
+ }
117
+ suffix, images = build_privileged_context(
118
+ sample,
119
+ ["text", "visual_facts"],
120
+ privileged_profile="hybrid",
121
+ )
122
+ assert "Visual Facts - Hint" in suffix
123
+ assert "Visual Facts - DePlot" in suffix
124
+ assert "2019 | 70" in suffix
125
+ assert "Reference Reasoning" in suffix
126
+ assert len(images) == 1
127
+ vf_raw = sample.get("visual_fact") or sample.get("visual_facts")
128
+ assert vf_raw and len(vf_raw.strip()) > 0
129
+
130
+
131
+ def test_visual_facts_f1_f2_merge():
132
+ from data_utils.chart.deplot_pipeline import build_deplot_visual_fact
133
+
134
+ sample = {
135
+ "visual_fact_hint": "hint table",
136
+ "visual_fact_deplot": build_deplot_visual_fact(
137
+ {"question": "q"}, "Col | Val\nA | 1"
138
+ ),
139
+ "image": Image.new("RGB", (32, 32)),
140
+ }
141
+ suffix, _ = build_privileged_context(sample, privileged_profile="hybrid")
142
+ assert "Visual Facts - Hint" in suffix
143
+ assert "Visual Facts - DePlot" in suffix
144
+ assert "Col | Val" in suffix
145
+
146
+
147
+ def test_parse_visual_fact_b1():
148
+ raw = {"objects": [{"name": "a"}]}
149
+ text = parse_visual_fact(raw)
150
+ assert "objects" in text
151
+
152
+
153
+ def test_debug_artifacts_respect_detail_every():
154
+ opsd_debug.configure(enabled=True, detail_every=10, rank=0, world_size=1)
155
+ with tempfile.TemporaryDirectory() as tmp:
156
+ img = Image.new("RGB", (32, 32))
157
+ path = maybe_save_privileged_images(5, 0, img, img, meta={"crop_strategy": "center"}, output_dir=tmp)
158
+ assert path is None
159
+ assert not os.path.exists(os.path.join(tmp, "logs", "images"))
160
+
161
+ path = maybe_save_privileged_images(10, 0, img, img, meta={"crop_strategy": "center"}, output_dir=tmp)
162
+ assert path is not None
163
+ assert os.path.exists(f"{path}_full.png")
164
+ assert os.path.exists(f"{path}_meta.json")
165
+
166
+
167
+ if __name__ == "__main__":
168
+ test_text_provider()
169
+ test_hybrid_provider_suffix()
170
+ test_math_lm_downgrade()
171
+ test_debug_artifacts_respect_detail_every()
172
+ print("Privileged provider tests passed.")
tests/test_privileged_debug_artifacts.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import sys
4
+ import tempfile
5
+
6
+ from PIL import Image
7
+
8
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9
+
10
+ from opsd_utils import debug_log as opsd_debug
11
+ from opsd_utils.privileged.debug_artifacts import maybe_save_privileged_images
12
+
13
+
14
+ def test_max_samples_per_detail():
15
+ opsd_debug.configure(enabled=True, detail_every=1, rank=0, world_size=1)
16
+ cfg = {"save_images": True, "image_subdir": "logs/images", "max_samples_per_detail": 1}
17
+ img = Image.new("RGB", (16, 16))
18
+ with tempfile.TemporaryDirectory() as tmp:
19
+ p0 = maybe_save_privileged_images(1, 0, img, None, meta={}, output_dir=tmp, privileged_debug_cfg=cfg)
20
+ p1 = maybe_save_privileged_images(1, 1, img, None, meta={}, output_dir=tmp, privileged_debug_cfg=cfg)
21
+ assert p0 is not None
22
+ assert p1 is None
23
+
24
+
25
+ def test_meta_sidecar():
26
+ opsd_debug.configure(enabled=True, detail_every=1, rank=0, world_size=1)
27
+ img = Image.new("RGB", (16, 16))
28
+ with tempfile.TemporaryDirectory() as tmp:
29
+ prefix = maybe_save_privileged_images(
30
+ 1,
31
+ 0,
32
+ img,
33
+ img,
34
+ meta={"privileged_profile": "hybrid", "crop_strategy": "bbox"},
35
+ output_dir=tmp,
36
+ )
37
+ with open(f"{prefix}_meta.json", encoding="utf-8") as f:
38
+ meta = json.load(f)
39
+ assert meta["privileged_profile"] == "hybrid"
40
+ assert meta["crop_strategy"] == "bbox"
tests/test_slice_completion_logits.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for completion logit slicing shared by GRPO and OPSD."""
2
+ import os
3
+ import sys
4
+
5
+ import torch
6
+
7
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8
+
9
+ from opsd_utils.opsd_loss import slice_student_completion_logits
10
+
11
+
12
+ def test_slice_matches_legacy_grpo_path():
13
+ logits_to_keep = 4
14
+ full = torch.randn(2, 20, 8)
15
+ legacy = full[:, -logits_to_keep - 1 :, :]
16
+ legacy = legacy[:, :-1, :]
17
+ legacy = legacy[:, -logits_to_keep:, :]
18
+ assert torch.allclose(legacy, slice_student_completion_logits(full, logits_to_keep))
tests/test_teacher_dual_image.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import pytest
5
+ from PIL import Image
6
+
7
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8
+
9
+ from opsd_utils.privileged.image_utils import resolve_teacher_images
10
+ from opsd_utils.privileged.profiles import effective_profile
11
+
12
+
13
+ def test_text_profile_single_image():
14
+ img = Image.new("RGB", (64, 64))
15
+ sample = {"image": img, "hint": "h", "answer": "Answer: 1"}
16
+ images, meta = resolve_teacher_images(sample, "text")
17
+ assert len(images) == 1
18
+ assert meta["num_teacher_images"] == 1
19
+
20
+
21
+ def test_hybrid_profile_dual_image():
22
+ img = Image.new("RGB", (64, 64))
23
+ sample = {"image": img, "evidence_bbox": [0.1, 0.1, 0.9, 0.9]}
24
+ images, meta = resolve_teacher_images(sample, "hybrid", crop_cfg={"mode": "dual"})
25
+ assert len(images) == 2
26
+ assert meta["has_bbox"] is True
27
+
28
+
29
+ def test_hybrid_profile_single_image_by_default():
30
+ img = Image.new("RGB", (64, 64))
31
+ sample = {"image": img, "evidence_bbox": [0.1, 0.1, 0.9, 0.9]}
32
+ images, meta = resolve_teacher_images(sample, "hybrid")
33
+ assert len(images) == 1
34
+ assert meta["num_teacher_images"] == 1
35
+ assert meta["crop_strategy"] == "single_full"
36
+
37
+
38
+ def test_no_image_empty():
39
+ sample = {"hint": "only text"}
40
+ assert effective_profile(sample, "hybrid") == "text"
41
+ images, meta = resolve_teacher_images(sample, "text")
42
+ assert images == []
43
+ assert meta["num_teacher_images"] == 0
tests/test_vocab_align.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for cross-model vocab alignment diagnostics."""
2
+
3
+ import os
4
+ import sys
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10
+
11
+ from opsd_utils.opsd_loss import generalized_jsd_loss
12
+ from opsd_utils.vocab_align import (
13
+ align_cross_model_logits,
14
+ reset_vocab_align_debug,
15
+ verify_shared_tokenizer_alignment,
16
+ )
17
+
18
+
19
+ class _Tok:
20
+ def __init__(self, vocab_size: int, offset: int = 0):
21
+ self._size = vocab_size
22
+ self._offset = offset
23
+
24
+ def __len__(self):
25
+ return self._size
26
+
27
+ def decode(self, ids, skip_special_tokens=False):
28
+ i = ids[0]
29
+ return f"tok_{i + self._offset}"
30
+
31
+ def convert_ids_to_tokens(self, i):
32
+ return f"tok_{i + self._offset}"
33
+
34
+
35
+ def test_align_slice_renorm_via_log_softmax():
36
+ reset_vocab_align_debug()
37
+ student = torch.randn(1, 3, 100, requires_grad=True)
38
+ teacher = torch.randn(1, 3, 128)
39
+ s, t = align_cross_model_logits(student, teacher, log_renorm_check=False)
40
+ assert s.shape[-1] == t.shape[-1] == 100
41
+ t_probs = F.softmax(t[0, 0], dim=-1)
42
+ assert abs(float(t_probs.sum()) - 1.0) < 1e-4
43
+
44
+
45
+ def test_generalized_jsd_renormalizes_after_slice():
46
+ reset_vocab_align_debug()
47
+ student = torch.randn(1, 5, 152000, requires_grad=True)
48
+ teacher = torch.randn(1, 5, 152128)
49
+ mask = torch.ones(1, 5)
50
+ loss = generalized_jsd_loss(student, teacher, mask)
51
+ assert loss.ndim == 0
52
+ assert loss.requires_grad
53
+
54
+
55
+ def test_tokenizer_alignment_detects_mismatch():
56
+ st = _Tok(200, offset=0)
57
+ tt = _Tok(200, offset=1)
58
+ report = verify_shared_tokenizer_alignment(
59
+ st, tt, shared_vocab=200, full_scan=True, sample_stride=1
60
+ )
61
+ assert not report["aligned"]
62
+ assert report["mismatch_count"] > 0
63
+
64
+
65
+ def test_tokenizer_alignment_passes_identical():
66
+ st = _Tok(1000, offset=0)
67
+ tt = _Tok(1200, offset=0)
68
+ report = verify_shared_tokenizer_alignment(
69
+ st, tt, shared_vocab=1000, full_scan=False, sample_stride=100
70
+ )
71
+ assert report["aligned"]
trainer/DyMETrainer_7B.py ADDED
@@ -0,0 +1,983 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import os
3
+ import textwrap
4
+ import warnings
5
+ from collections import defaultdict, deque
6
+ from collections.abc import Sized
7
+ from contextlib import nullcontext
8
+ from typing import Any, Callable, Optional, Union
9
+
10
+ from torch.nn.utils.rnn import pad_sequence
11
+
12
+ import datasets
13
+ import torch
14
+ import torch.utils.data
15
+ import transformers
16
+ from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
17
+ from datasets import Dataset, IterableDataset
18
+ from packaging import version
19
+ from torch import nn
20
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
21
+ from torch.utils.data import DataLoader, Sampler, DistributedSampler
22
+ from transformers import (
23
+ AutoModelForCausalLM,
24
+ AutoModelForSequenceClassification,
25
+ AutoTokenizer,
26
+ GenerationConfig,
27
+ PreTrainedModel,
28
+ PreTrainedTokenizerBase,
29
+ Trainer,
30
+ TrainerCallback,
31
+ is_wandb_available,
32
+ )
33
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
34
+ from transformers.trainer_utils import seed_worker
35
+ from transformers.utils import is_datasets_available, is_peft_available
36
+
37
+ from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
38
+ from trl.extras.profiling import profiling_context, profiling_decorator
39
+ from trl.import_utils import is_liger_kernel_available, is_vllm_available
40
+ from trl.models import create_reference_model, prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
41
+ # from trl.models.utils import _ForwardRedirection
42
+ from trl.trainer.callbacks import SyncRefModelCallback
43
+ from trl.trainer.grpo_config import GRPOConfig
44
+ from trl.trainer.utils import (
45
+ disable_dropout_in_model,
46
+ generate_model_card,
47
+ get_comet_experiment_url,
48
+ pad,
49
+ print_prompt_completions_sample,
50
+ selective_log_softmax,
51
+ )
52
+
53
+ from trl.models import prepare_deepspeed, unwrap_model_for_generation
54
+ from trl.trainer.grpo_config import GRPOConfig
55
+ from trl.trainer.utils import generate_model_card, get_comet_experiment_url, selective_log_softmax
56
+
57
+ import concurrent.futures
58
+ from datasets import Dataset, IterableDataset
59
+
60
+ from reward_utils import checker
61
+ from reward_utils.checker import RewardCalculator
62
+ from reward_utils.compute_rewards import calculate_rewards_in_parallel, refine_context_in_parallel
63
+
64
+ if is_wandb_available():
65
+ import wandb
66
+
67
+
68
+
69
+ # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
70
+ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
71
+ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
72
+
73
+
74
+ class RepeatSampler(Sampler):
75
+ """
76
+ Sampler that repeats the indices of a dataset in a structured manner.
77
+
78
+ Args:
79
+ data_source (`Sized`):
80
+ Dataset to sample from.
81
+ mini_repeat_count (`int`):
82
+ Number of times to repeat each index per batch.
83
+ batch_size (`int`, *optional*, defaults to `1`):
84
+ Number of unique indices per batch.
85
+ repeat_count (`int`, *optional*, defaults to `1`):
86
+ Number of times to repeat the full sampling process.
87
+ shuffle (`bool`, *optional*, defaults to `True`):
88
+ Whether to shuffle the dataset.
89
+ seed (`int` or `None`, *optional*, defaults to `None`):
90
+ Random seed for reproducibility (only affects this sampler).
91
+
92
+ Example:
93
+ ```python
94
+ >>> sampler = RepeatRandomSampler(["a", "b", "c", "d", "e", "f", "g"], mini_repeat_count=2, batch_size=3, repeat_count=4)
95
+ >>> list(sampler)
96
+ [4, 4, 3, 3, 0, 0,
97
+ 4, 4, 3, 3, 0, 0,
98
+ 4, 4, 3, 3, 0, 0,
99
+ 4, 4, 3, 3, 0, 0,
100
+
101
+ 1, 1, 2, 2, 6, 6,
102
+ 1, 1, 2, 2, 6, 6,
103
+ 1, 1, 2, 2, 6, 6,
104
+ 1, 1, 2, 2, 6, 6]
105
+ ```
106
+
107
+ ```txt
108
+ mini_repeat_count = 3
109
+ - - -
110
+ [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, |
111
+ 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, |
112
+ 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, |
113
+ repeat_count = 2
114
+ 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, |
115
+ 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, |
116
+ 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, ...] |
117
+ --------- --------- --------- ---------
118
+ --------- --------- --------- ---------
119
+ --------- --------- --------- ---------
120
+ batch_size = 12
121
+ ```
122
+ """
123
+
124
+ def __init__(
125
+ self,
126
+ data_source: Sized,
127
+ mini_repeat_count: int,
128
+ batch_size: int = 1,
129
+ repeat_count: int = 1,
130
+ shuffle: bool = True,
131
+ seed: Optional[int] = None,
132
+ ):
133
+ self.data_source = data_source
134
+ self.mini_repeat_count = mini_repeat_count
135
+ self.batch_size = batch_size
136
+ self.repeat_count = repeat_count
137
+ self.num_samples = len(data_source)
138
+ self.shuffle = shuffle
139
+ self.seed = seed
140
+
141
+ if shuffle:
142
+ self.generator = torch.Generator() # Create a local random generator
143
+ if seed is not None:
144
+ self.generator.manual_seed(seed)
145
+
146
+ def __iter__(self):
147
+ if self.shuffle:
148
+ # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7)
149
+ indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
150
+ else:
151
+ indexes = list(range(self.num_samples))
152
+
153
+ # [2, 4, 3, 1, 0, 6, 5]
154
+ # -> [[2, 4, 3], [1, 0, 6], [5]] (batch_size = 3)
155
+ indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
156
+
157
+ # [[2, 4, 3], [1, 0, 6], [5]]
158
+ # -> [[2, 4, 3], [1, 0, 6]]
159
+ indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
160
+
161
+ for chunk in indexes:
162
+ for _ in range(self.repeat_count):
163
+ for index in chunk:
164
+ for _ in range(self.mini_repeat_count):
165
+ yield index
166
+
167
+ def __len__(self) -> int:
168
+ return self.num_samples * self.mini_repeat_count * self.repeat_count
169
+
170
+
171
+ # torch.nanstd doesn't exist, so we define it here
172
+ def nanstd(tensor: torch.Tensor) -> torch.Tensor:
173
+ """
174
+ Compute the standard deviation of a tensor, ignoring NaNs. This function only supports 1D tensors.
175
+
176
+ Args:
177
+ tensor (`torch.Tensor`):
178
+ Input tensor of shape `(N,)`.
179
+
180
+ Returns:
181
+ `torch.Tensor`:
182
+ Standard deviation of the tensor, ignoring NaNs.
183
+ """
184
+ variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True)) ** 2) # Compute variance ignoring NaNs
185
+ count = torch.sum(~torch.isnan(tensor)) # Count of non-NaN values
186
+ variance *= count / (count - 1) # Bessel's correction
187
+ return torch.sqrt(variance)
188
+
189
+
190
+ def split_tensor_dict(
191
+ tensor_dict: dict[str, Optional[torch.Tensor]], num_chunks: int, image_patch_id=151655, patch_id_times=4
192
+ ) -> list[dict[str, Optional[torch.Tensor]]]:
193
+ """
194
+ Splits a dictionary of tensors along the first dimension into `num_chunks` equal parts.
195
+
196
+ Example:
197
+ >>> x = torch.arange(12).reshape(6, 2)
198
+ >>> y = torch.arange(6).reshape(6, 1)
199
+ >>> tensor_dict = {"x": x, "y": y}
200
+ >>> split_tensor_dict(tensor_dict, 3)
201
+ [
202
+ {"x": tensor([[0, 1], [2, 3]]), "y": tensor([[0], [1]])},
203
+ {"x": tensor([[4, 5], [6, 7]]), "y": tensor([[2], [3]])},
204
+ {"x": tensor([[ 8, 9], [10, 11]]), "y": tensor([[4], [5]])}
205
+ ]
206
+ """
207
+ if image_patch_id is None:
208
+ first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None)
209
+ chunk_size = first_tensor.shape[0] // num_chunks
210
+ # has = []
211
+ # if 'has_correct' in tensor_dict:
212
+ # has = tensor_dict['has_correct']
213
+ # del tensor_dict['has_correct']
214
+ l1 = []
215
+ for i in range(num_chunks):
216
+ dt = {
217
+ key: tensor[i * chunk_size : (i + 1) * chunk_size] if tensor is not None else None
218
+ for key, tensor in tensor_dict.items()
219
+ }
220
+ # if len(has) > 0:
221
+ # dt['has_correct'] = has[i]
222
+ l1.append(dt)
223
+
224
+ return l1
225
+ else:
226
+ first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None)
227
+ chunk_size = first_tensor.shape[0] // num_chunks
228
+ l1 = []
229
+ for i in range(num_chunks):
230
+ dt = {}
231
+ for key, tensor in tensor_dict.items():
232
+ if key != 'pixel_values':
233
+ dt[key] = tensor[i * chunk_size : (i + 1) * chunk_size] if tensor is not None else None
234
+
235
+ l1.append(dt)
236
+
237
+ if 'pixel_values' in tensor_dict:
238
+ raw_pixel_values = tensor_dict['pixel_values']
239
+ start_image_patch = 0
240
+ for dt in l1:
241
+ batch_input_ids = dt['prompt_ids']
242
+ num_image_patches = (batch_input_ids == image_patch_id).sum().item() * patch_id_times
243
+ batch_pixel_values = raw_pixel_values[start_image_patch : start_image_patch + num_image_patches]
244
+ start_image_patch += num_image_patches
245
+ dt['pixel_values'] = batch_pixel_values
246
+
247
+ return l1
248
+
249
+
250
+ def nanmin(tensor: torch.Tensor) -> torch.Tensor:
251
+ """
252
+ Compute the minimum value of a tensor, ignoring NaNs. This function only supports 1D tensors.
253
+
254
+ Args:
255
+ tensor (`torch.Tensor`): Input tensor of shape `(N,)`.
256
+
257
+ Returns:
258
+ `torch.Tensor`: Minimum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN.
259
+ """
260
+ if torch.isnan(tensor).all():
261
+ return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
262
+ return torch.min(tensor[~torch.isnan(tensor)])
263
+
264
+
265
+ def nanmax(tensor: torch.Tensor) -> torch.Tensor:
266
+ """
267
+ Compute the maximum value of a tensor, ignoring NaNs. This function only supports 1D tensors.
268
+
269
+ Args:
270
+ tensor (`torch.Tensor`): Input tensor of shape `(N,)`.
271
+
272
+ Returns:
273
+ `torch.Tensor`: Maximum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN.
274
+ """
275
+ if torch.isnan(tensor).all():
276
+ return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
277
+ return torch.max(tensor[~torch.isnan(tensor)])
278
+
279
+
280
+ class DyMETrainer(Trainer):
281
+
282
+ def __init__(
283
+ self,
284
+ model: PreTrainedModel,
285
+ checker = None,
286
+ refiner=None,
287
+ args: Optional[GRPOConfig] = None,
288
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
289
+ eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
290
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
291
+ callbacks: Optional[list[TrainerCallback]] = None,
292
+ optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
293
+ processing_func = None,
294
+ task_name: str = None,
295
+ end_flag: str = '<|im_end|>',
296
+ ):
297
+ self.task_name = task_name
298
+ self.reward_weights = torch.nn.Parameter(torch.ones(3), requires_grad=False)
299
+ self.reward_func_names = ['format', 'thinking', 'accuracy']
300
+ # Models
301
+ # Trained model
302
+ model_init_kwargs = args.model_init_kwargs or {}
303
+
304
+ # Enable gradient checkpointing if requested
305
+ if args.gradient_checkpointing:
306
+ model = self._enable_gradient_checkpointing(model, args)
307
+
308
+ # Processing class
309
+ if processing_class is None:
310
+ processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
311
+
312
+ # Training arguments
313
+ self.max_prompt_length = args.max_prompt_length
314
+ self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
315
+ self.num_generations = args.num_generations # = G in the GRPO paper
316
+ self.temperature = args.temperature
317
+ self.top_p = args.top_p
318
+ self.top_k = args.top_k
319
+ self.min_p = args.min_p
320
+ self.repetition_penalty = args.repetition_penalty
321
+ self.use_liger_loss = args.use_liger_loss
322
+ self.loss_type = args.loss_type
323
+ self.scale_rewards = args.scale_rewards
324
+ self.mask_truncated_completions = args.mask_truncated_completions
325
+ self.end_flag = end_flag
326
+ self.checker = checker
327
+ self.refiner = refiner
328
+ # Datasets
329
+ self.shuffle_dataset = args.shuffle_dataset
330
+
331
+ if (
332
+ isinstance(train_dataset, IterableDataset)
333
+ or isinstance(eval_dataset, IterableDataset)
334
+ or (
335
+ isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values())
336
+ )
337
+ ):
338
+ # See https://github.com/huggingface/trl/issues/3213
339
+ raise NotImplementedError(
340
+ "Iterable datasets are not yet supported in GRPOTrainer. Please use a standard dataset instead."
341
+ )
342
+
343
+ # Multi-step
344
+ self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
345
+ self.epsilon_low = args.epsilon
346
+ self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
347
+ # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
348
+ self._step = 0
349
+ # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
350
+ # `_get_train_sampler` and `_prepare_inputs`.
351
+ self._buffered_inputs = None
352
+
353
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
354
+ # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
355
+ # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
356
+ # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
357
+ # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
358
+ # This acts as a flag to indicate that the warning has already been issued.
359
+ model.warnings_issued["estimate_tokens"] = True
360
+ def data_collator(features): # No data collation is needed in GRPO
361
+ return features
362
+
363
+ super().__init__(
364
+ model=model,
365
+ args=args,
366
+ data_collator=data_collator,
367
+ train_dataset=train_dataset,
368
+ eval_dataset=eval_dataset,
369
+ processing_class=processing_class,
370
+ callbacks=callbacks,
371
+ optimizers=optimizers,
372
+ )
373
+
374
+ # Reference model
375
+ self.beta = args.beta
376
+ assert self.beta == 0
377
+
378
+ # Disable dropout in the models
379
+ if args.disable_dropout:
380
+ disable_dropout_in_model(model)
381
+
382
+ # Initialize the metrics
383
+ self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
384
+ self._total_train_tokens = 0
385
+ self.log_completions = args.log_completions
386
+ self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
387
+ self.num_completions_to_print = args.num_completions_to_print
388
+ # maxlen is set to the total number of forward passes per step. This value of `maxlen` ensures we log only the
389
+ # final optimization step.
390
+ maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.gradient_accumulation_steps
391
+ self._textual_logs = {
392
+ "prompt": deque(maxlen=maxlen),
393
+ "completion": deque(maxlen=maxlen),
394
+ "rewards": defaultdict(lambda: deque(maxlen=maxlen)),
395
+ }
396
+
397
+ # Check if the effective batch size can be divided by the number of generations
398
+ if self.num_generations < 2:
399
+ raise ValueError(
400
+ "GRPO requires at least 2 generations per prompt to calculate the advantages. You provided "
401
+ f"{self.num_generations}, which is less than the minimum required."
402
+ )
403
+ num_processes = self.accelerator.num_processes
404
+ effective_batch_size = args.per_device_train_batch_size * num_processes * args.gradient_accumulation_steps
405
+ possible_values = [
406
+ n_gen for n_gen in range(2, effective_batch_size + 1) if (effective_batch_size) % n_gen == 0
407
+ ]
408
+ if self.num_generations not in possible_values:
409
+ raise ValueError(
410
+ f"The effective train batch size ({num_processes} x {args.per_device_train_batch_size} x "
411
+ f"{args.gradient_accumulation_steps}) must be evenly divisible by the number of generations per "
412
+ f"prompt ({self.num_generations}). Given the current effective train batch size, the valid values for "
413
+ f"the number of generations are: {possible_values}."
414
+ )
415
+
416
+ # Ensure each process receives a unique seed to prevent duplicate completions when generating with
417
+ # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
418
+ # it's safer to set it in all cases.
419
+ set_seed(args.seed, device_specific=True)
420
+
421
+
422
+ self.generation_config = GenerationConfig(
423
+ max_new_tokens=self.max_completion_length,
424
+ do_sample=True,
425
+ pad_token_id=processing_class.tokenizer.pad_token_id,
426
+ bos_token_id=processing_class.tokenizer.bos_token_id,
427
+ eos_token_id=processing_class.tokenizer.eos_token_id,
428
+ temperature=self.temperature,
429
+ top_p=self.top_p,
430
+ top_k=self.top_k,
431
+ min_p=self.min_p,
432
+ repetition_penalty=self.repetition_penalty,
433
+ cache_implementation=args.cache_implementation,
434
+ use_cache=False if self.args.gradient_checkpointing else True
435
+ )
436
+
437
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
438
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
439
+ # self.model_accepts_loss_kwargs to False to enable scaling.
440
+ self.model_accepts_loss_kwargs = False
441
+ self.processing_func = processing_func
442
+
443
+ def _set_signature_columns_if_needed(self):
444
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
445
+ # By default, this method sets `self._signature_columns` to the model's expected inputs.
446
+ # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
447
+ # Instead, we set them to the columns expected by the `training_step` method, hence the override.
448
+ if self._signature_columns is None:
449
+ self._signature_columns = ["prompt"]
450
+
451
+ def get_train_dataloader(self):
452
+ if self.train_dataset is None:
453
+ raise ValueError("Trainer: training requires a train_dataset.")
454
+
455
+ train_dataset = self.train_dataset
456
+ data_collator = self.data_collator
457
+ if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
458
+ train_dataset = self._remove_unused_columns(train_dataset, description="training")
459
+ else:
460
+ data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
461
+
462
+ dataloader_params = {
463
+ "batch_size": self._train_batch_size * self.args.gradient_accumulation_steps, # < this is the change
464
+ "collate_fn": data_collator,
465
+ "num_workers": self.args.dataloader_num_workers,
466
+ "pin_memory": self.args.dataloader_pin_memory,
467
+ "persistent_workers": self.args.dataloader_persistent_workers,
468
+ }
469
+
470
+ if not isinstance(train_dataset, torch.utils.data.IterableDataset):
471
+ dataloader_params["sampler"] = self._get_train_sampler()
472
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
473
+ dataloader_params["worker_init_fn"] = seed_worker
474
+ dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
475
+ dl = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
476
+ return dl
477
+
478
+ def _get_train_sampler(self) -> Sampler:
479
+ effective_batch_size = (
480
+ self.args.per_device_train_batch_size
481
+ * self.accelerator.num_processes
482
+ * self.args.gradient_accumulation_steps
483
+ )
484
+ return RepeatSampler(
485
+ data_source=self.train_dataset,
486
+ mini_repeat_count=self.num_generations,
487
+ batch_size=effective_batch_size // self.num_generations,
488
+ repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
489
+ shuffle=self.shuffle_dataset,
490
+ seed=self.args.seed,
491
+ )
492
+
493
+ def _get_eval_sampler(self, eval_dataset):
494
+ # eval_dataset 是一个 map-style Dataset(非 IterableDataset)
495
+ return DistributedSampler(
496
+ dataset=eval_dataset,
497
+ num_replicas=self.accelerator.num_processes,
498
+ rank=self.accelerator.process_index,
499
+ shuffle=False,
500
+ seed=self.args.seed,
501
+ )
502
+
503
+ def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
504
+ """Enables gradient checkpointing for the model."""
505
+ # Ensure use_cache is disabled
506
+ model.config.use_cache = False
507
+
508
+ # Enable gradient checkpointing on the base model for PEFT
509
+ if is_peft_model(model):
510
+ model.base_model.gradient_checkpointing_enable()
511
+ # Enable gradient checkpointing for non-PEFT models
512
+ else:
513
+ model.gradient_checkpointing_enable()
514
+
515
+ gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
516
+ use_reentrant = (
517
+ "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
518
+ )
519
+
520
+ if use_reentrant:
521
+ model.enable_input_require_grads()
522
+
523
+ return model
524
+
525
+ @profiling_decorator
526
+ def _get_last_hidden_state(self, unwrapped_model, input_ids, attention_mask, logits_to_keep=None):
527
+ if is_peft_model(unwrapped_model):
528
+ unwrapped_model = unwrapped_model.base_model.model
529
+ last_hidden_state = unwrapped_model.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
530
+ last_hidden_state = last_hidden_state[:, :-1, :] # (B, L-1, H)
531
+ if logits_to_keep is not None:
532
+ last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H)
533
+ return last_hidden_state
534
+
535
+ # Get the per-token log probabilities for the completions for the model and the reference model
536
+ @profiling_decorator
537
+ def _get_per_token_logps(self, model, input_ids, attention_mask, pixel_values, image_grid_thws, logits_to_keep, batch_size=None) -> torch.Tensor:
538
+ batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak
539
+ all_logps = []
540
+ patch_s = 0
541
+ for i in range(0, input_ids.size(0), batch_size):
542
+ input_ids_batch = input_ids[i : i + batch_size]
543
+ attention_mask_batch = attention_mask[i : i + batch_size]
544
+ img_id = self.processing_class.tokenizer.convert_tokens_to_ids('<|image_pad|>')
545
+ image_patch_nums = (input_ids_batch == img_id).sum().item() * 4 # 每个 image_pad 对应 4 个图像 patch
546
+ # print("image_patch_nums", image_patch_nums, pixel_values.shape)
547
+ pixel_values_batch = pixel_values[patch_s : patch_s + image_patch_nums]
548
+ patch_s += image_patch_nums
549
+ image_grid_thw_batch = image_grid_thws[i : i + batch_size]
550
+ # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
551
+ logits = model(
552
+ input_ids=input_ids_batch, pixel_values=pixel_values_batch, image_grid_thw=image_grid_thw_batch,
553
+ attention_mask=attention_mask_batch
554
+ ).logits
555
+ # logits = logits[:, :-1, :] # (B, L-1, H)
556
+ if logits_to_keep is not None:
557
+ logits = logits[:, -logits_to_keep-1:, :] # (B, logits_to_keep, H)
558
+
559
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
560
+ input_ids_batch = input_ids_batch[:, -logits_to_keep:]
561
+ # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
562
+ # See https://github.com/huggingface/trl/issues/2770
563
+ logits = logits[:, -logits_to_keep:]
564
+ # Divide logits by sampling temperature.
565
+ # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
566
+ logits = logits / self.temperature
567
+ logps = selective_log_softmax(logits, input_ids_batch) # compute logprobs for the input tokens
568
+ all_logps.append(logps)
569
+ return torch.cat(all_logps, dim=0)
570
+
571
+ @profiling_decorator
572
+ def _prepare_inputs(
573
+ self, accumulated_local_batch: dict[str, Union[torch.Tensor, Any]]
574
+ ) -> dict[str, Union[torch.Tensor, Any]]:
575
+
576
+ mode = "train" if self.model.training else "eval"
577
+ if mode == "train":
578
+ generate_every = self.args.gradient_accumulation_steps * self.num_iterations
579
+ if self._step % generate_every == 0 or self._buffered_inputs is None:
580
+ # self._buffered_inputs=None can occur when resuming from a checkpoint
581
+ accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)
582
+ self._buffered_inputs = split_tensor_dict(
583
+ accumulated_local_batch, self.args.gradient_accumulation_steps
584
+ )
585
+ inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
586
+ self._step += 1
587
+ else:
588
+ # In evaluation, there is neither gradient accumulation, nor multiple iterations
589
+ inputs = self._generate_and_score_completions(accumulated_local_batch)
590
+ return inputs
591
+
592
+ def _generate_and_score_completions(
593
+ self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
594
+ ) -> dict[str, Union[torch.Tensor, Any]]:
595
+
596
+ # TODO
597
+ device = self.accelerator.device
598
+ mode = "train" if self.model.training else "eval"
599
+
600
+ inputs_for_generate = inputs.copy()
601
+
602
+ # 去除answer key
603
+ inputs_for_generate = [{k: v for k, v in x.items() if k != 'answer'} for x in inputs_for_generate]
604
+
605
+ dt_generate_dt = self.processing_func(inputs_for_generate)
606
+ prompt_inputs_generate = super(DyMETrainer, self)._prepare_inputs(dt_generate_dt)
607
+ if 'labels' in prompt_inputs_generate:
608
+ del prompt_inputs_generate["labels"]
609
+ prompt_ids = prompt_inputs_generate["input_ids"]
610
+ prompt_mask = prompt_inputs_generate["attention_mask"]
611
+ pixel_values = prompt_inputs_generate["pixel_values"]
612
+
613
+ image_grid_thws = prompt_inputs_generate["image_grid_thw"]
614
+
615
+ # Regular generation path
616
+ with unwrap_model_for_generation(
617
+ self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
618
+ ) as unwrapped_model:
619
+ with (
620
+ FSDP.summon_full_params(self.model_wrapped, recurse=False)
621
+ if self.is_fsdp_enabled
622
+ else nullcontext()
623
+ ):
624
+ prompt_completion_ids = unwrapped_model.generate(**prompt_inputs_generate, generation_config=self.generation_config)
625
+
626
+
627
+ # Compute prompt length and extract completion ids
628
+ prompt_length = prompt_ids.size(1)
629
+ prompt_ids = prompt_completion_ids[:, :prompt_length]
630
+ completion_ids = prompt_completion_ids[:, prompt_length:]
631
+
632
+ # Mask everything after the first EOS token
633
+ is_eos = completion_ids == self.processing_class.tokenizer.eos_token_id
634
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
635
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
636
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
637
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
638
+
639
+ # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
640
+ if self.mask_truncated_completions:
641
+ truncated_completions = ~is_eos.any(dim=1)
642
+ completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()
643
+
644
+ completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
645
+
646
+ batch_size = len(completion_ids)
647
+ images = [x['image'] for x in inputs]
648
+ prompts = [x['prompt'] for x in inputs]
649
+ question_wo_prompts = [x['question_wo_prompt'] for x in inputs]
650
+ hints = [x.get('hint', '') for x in inputs]
651
+ answers = [x['answer'] for x in inputs]
652
+ images_path = [image if isinstance(image, str) else image.filename for image in images]
653
+ batch_data = {'prompt': prompts, 'hints': hints,
654
+ 'image': images_path, 'response': completions, 'answer': answers}
655
+
656
+ gpu_id = self.accelerator.device.index
657
+ all_rewards, format_rewards, acc_rewards, context_rewards = calculate_rewards_in_parallel(self.checker, batch_data,
658
+ gpu_id=gpu_id,
659
+ task=self.task_name, num_threads=1)
660
+ all_rewards = torch.tensor(all_rewards, dtype=torch.float32).to(self.accelerator.device)
661
+ format_rewards = torch.tensor(format_rewards, dtype=torch.float32).to(self.accelerator.device)
662
+ context_rewards = torch.tensor(context_rewards, dtype=torch.float32).to(self.accelerator.device)
663
+ acc_rewards = torch.tensor(acc_rewards, dtype=torch.float32).to(self.accelerator.device)
664
+
665
+ rewards_per_func = torch.zeros([len(all_rewards), 3], device=device)
666
+
667
+ rewards_per_func[:, 0] = format_rewards.clone()
668
+ rewards_per_func[:, 1] = context_rewards.clone()
669
+ rewards_per_func[:, -1] = acc_rewards.clone()
670
+
671
+ rewards_per_func = gather(rewards_per_func)
672
+
673
+ # Apply weights to each reward function's output and sum
674
+ rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
675
+
676
+ # Compute grouped-wise rewards
677
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
678
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
679
+
680
+ # Normalize the rewards to compute the advantages
681
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
682
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
683
+ advantages = rewards - mean_grouped_rewards
684
+ if self.scale_rewards:
685
+ advantages = advantages / (std_grouped_rewards + 1e-4)
686
+
687
+ # Slice to keep only the local part of the data
688
+ process_slice = slice(
689
+ self.accelerator.process_index * len(prompts),
690
+ (self.accelerator.process_index + 1) * len(prompts),
691
+ )
692
+ advantages = advantages[process_slice]
693
+ advantages = advantages.reshape(-1, 1)
694
+ acc_rewards = acc_rewards.view(-1, self.num_generations)
695
+ format_rewards = format_rewards.view(-1, self.num_generations)
696
+
697
+ has_correct = (acc_rewards > 0.5).sum(1)
698
+ format_rewards = format_rewards.view(-1)
699
+
700
+ sft_check = []
701
+ for i in range(batch_size):
702
+ batch_id = i // self.num_generations
703
+ sft_check.append((has_correct[batch_id] == 0) & (i % self.num_generations == 0))
704
+
705
+ hints = refine_context_in_parallel(self.refiner, question_wo_prompts, hints, answers, task=self.task_name, gpu_id=gpu_id, num_threads=1)
706
+
707
+ sft_gt = [hint + '\n' + answer + self.end_flag for hint, answer in zip(hints, answers)]
708
+
709
+ sft_dt = self.processing_class.tokenizer(sft_gt, return_tensors="pt", padding=True,
710
+ padding_side="right")
711
+ sft_padded_ids = sft_dt['input_ids'].to(device)
712
+ sft_attn_masks = sft_dt['attention_mask'].to(device)
713
+ sft_advantages = torch.ones_like(sft_attn_masks, device=device)
714
+
715
+ final_completion_id_list = []
716
+ final_completion_mask_list = []
717
+ final_advantange_list = []
718
+
719
+ for i in range(len(sft_padded_ids)):
720
+ batch_id = i // self.num_generations
721
+ if has_correct[batch_id] == 0:
722
+ if sft_check[i]: # 第一个修改为正确答案,其他的保留为错误的。
723
+ completion_id_ = torch.cat([sft_padded_ids[i], completion_ids[i][0:0]])
724
+ completion_mask_ = torch.cat([sft_attn_masks[i], completion_mask[i][0:0]])
725
+ advantange_ = torch.cat([sft_advantages[i], advantages[i][0:0]])
726
+ advantange_[:] = 1
727
+ else:
728
+ completion_id_ = torch.cat([completion_ids[i], completion_ids[i][0:0]])
729
+ completion_mask_ = torch.cat([completion_mask[i], sft_attn_masks[i][0:0]])
730
+ advantange_ = torch.cat([advantages[i], sft_advantages[i][0:0]])
731
+ advantange_ = advantange_.repeat_interleave(len(completion_id_))
732
+ advantange_[:] = 0
733
+
734
+ else:
735
+ completion_id_ = torch.cat([completion_ids[i], sft_padded_ids[i][0:0]])
736
+ completion_mask_ = torch.cat([completion_mask[i], sft_attn_masks[i][0:0]])
737
+ advantange_ = torch.cat([advantages[i], sft_advantages[i][0:0]])
738
+ # 如果advantange_是一个数字的话需要扩展维度
739
+ advantange_ = advantange_.repeat_interleave(len(completion_id_))
740
+
741
+ if has_correct[batch_id] == self.num_generations: # 全部正确时停止优化
742
+ advantange_[:] = 0
743
+
744
+ final_completion_id_list.append(completion_id_)
745
+ final_completion_mask_list.append(completion_mask_)
746
+ final_advantange_list.append(advantange_)
747
+
748
+ completion_ids = pad_sequence(final_completion_id_list, batch_first=True,
749
+ padding_value=self.processing_class.tokenizer.pad_token_id).long()
750
+ completion_mask = pad_sequence(final_completion_mask_list, batch_first=True, padding_value=0)
751
+ completion_advantange = pad_sequence(final_advantange_list, batch_first=True, padding_value=0)
752
+ completion_ids = completion_ids.to(device)
753
+ completion_mask = completion_mask.to(device)
754
+ completion_advantange = completion_advantange.to(device)
755
+ input_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1).long()
756
+ attention_completion_mask = torch.cat([prompt_mask, completion_mask], dim=1)
757
+
758
+ for s, a in enumerate(completion_advantange):
759
+ if acc_rewards.view(-1)[s] > 0 and format_rewards.view(-1)[s] > 0 and a[0] < 0:
760
+ print('no')
761
+
762
+ if self.accelerator.device.index == 0:
763
+ completion_id = completion_ids[0]
764
+ completion_id_pos = completion_id[(completion_advantange[0] > 0) & (completion_mask[0] > 0)]
765
+ completion_id_neg = completion_id[(completion_advantange[0] < 0) & (completion_mask[0] > 0)]
766
+
767
+ show = self.processing_class.decode(completion_id_pos, skip_special_tokens=False)
768
+ show_neg = self.processing_class.decode(completion_id_neg, skip_special_tokens=False)
769
+ print("\n=====has_correct====================\n", has_correct,)
770
+ print("\n=====prediction====================\n", completions[0],)
771
+ if show != "":
772
+ print("\n=====POS GT====================\n", show)
773
+ if show_neg != "":
774
+ print("\n======NEG GT===================\n", show_neg)
775
+
776
+ # Concatenate prompt_mask with completion_mask for logit computation
777
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
778
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
779
+ batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
780
+
781
+ with torch.no_grad():
782
+ # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
783
+ # computation here, and use per_token_logps.detach() instead.
784
+ if self.num_iterations > 1:
785
+ old_per_token_logps = self._get_per_token_logps(self.model, input_completion_ids, attention_completion_mask, pixel_values, image_grid_thws,
786
+ logits_to_keep, batch_size)
787
+ else:
788
+ old_per_token_logps = None
789
+
790
+ # Log the metrics
791
+ if mode == "train":
792
+ self.state.num_input_tokens_seen += self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
793
+
794
+ # log completion lengths, mean, min, max
795
+ agg_completion_mask = self.accelerator.gather_for_metrics(completion_mask.sum(1))
796
+ self._metrics[mode]["completions/mean_length"].append(agg_completion_mask.float().mean().item())
797
+
798
+ # identify sequences that terminated with EOS and log their lengths
799
+ agg_terminated_with_eos = self.accelerator.gather_for_metrics(is_eos.any(dim=1))
800
+ term_completion_mask = agg_completion_mask[agg_terminated_with_eos]
801
+ clipped_completions_ratio = 1 - len(term_completion_mask) / len(agg_completion_mask)
802
+ self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
803
+
804
+ # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
805
+ for i, reward_func_name in enumerate(self.reward_func_names):
806
+ mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
807
+ self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
808
+ self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
809
+ self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
810
+
811
+ for i, name in enumerate(self.reward_func_names):
812
+ self._textual_logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
813
+ # completion_advantange: (batch_size, seq_len) 或 (batch_size, n)
814
+ mask_pos = completion_advantange > 0 # 正优势位置
815
+ row_min = completion_advantange.min(dim=1, keepdim=True).values.abs() # (batch, 1)
816
+
817
+ # 只对正优势加 abs(row_min),其余位置设 0
818
+ # completion_advantange = torch.where(
819
+ # mask_pos,
820
+ # completion_advantange + row_min, # broadcasting 自动对齐到每一行
821
+ # torch.zeros_like(completion_advantange)
822
+ # )
823
+ return {
824
+ "prompt_ids": prompt_ids,
825
+ "prompt_mask": prompt_mask,
826
+ "pixel_values": pixel_values,
827
+ "completion_ids": completion_ids,
828
+ "completion_mask": completion_mask,
829
+ "advantages": completion_advantange,
830
+ "old_per_token_logps": old_per_token_logps,
831
+ # "has_correct": has_correct,
832
+ "image_grid_thws": image_grid_thws
833
+ }
834
+
835
+ def compute_liger_loss(self, unwrapped_model, inputs):
836
+ # Compute the per-token log probabilities for the model
837
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
838
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
839
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
840
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
841
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
842
+
843
+ # Compute the KL divergence between the model and the reference model
844
+ ref_per_token_logps = None
845
+
846
+ # get the last hidden state of the model
847
+ last_hidden_state = self._get_last_hidden_state(unwrapped_model, input_ids, attention_mask, logits_to_keep)
848
+
849
+ # compute loss and metrics using liger grpo loss
850
+ loss, metrics = self.liger_grpo_loss(
851
+ _input=last_hidden_state,
852
+ lin_weight=unwrapped_model.lm_head.weight,
853
+ selected_token_ids=completion_ids,
854
+ attention_mask=completion_mask,
855
+ advantages=inputs["advantages"][:, 0],
856
+ bias=unwrapped_model.lm_head.bias,
857
+ old_per_token_logps=inputs["old_per_token_logps"],
858
+ ref_per_token_logps=ref_per_token_logps,
859
+ )
860
+ # Extract metrics from the liger_grpo_loss output
861
+ # KL divergence is the first metric when beta is non-zero
862
+ mean_kl = metrics[0] if self.beta != 0.0 else None
863
+ clip_ratio = metrics[-1]
864
+
865
+ mode = "train" if self.model.training else "eval"
866
+ if self.beta != 0.0:
867
+ self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
868
+ self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
869
+ return loss
870
+
871
+ @profiling_decorator
872
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
873
+ if return_outputs:
874
+ raise ValueError("The GRPOTrainer does not support returning outputs")
875
+ if self.use_liger_loss:
876
+ # Compute the loss using the liger grpo loss
877
+ unwrapped_model = self.accelerator.unwrap_model(model)
878
+ return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
879
+ else:
880
+ return self._compute_loss(model, inputs)
881
+
882
+ def _compute_loss(self, model, inputs):
883
+ # return torch.nn.Parameter(torch.tensor(0.0, device=self.accelerator.device)) # Dummy loss for compatibility
884
+ # Compute the per-token log probabilities for the model
885
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
886
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
887
+ pixel_values = inputs["pixel_values"]
888
+
889
+ # has_correct = inputs["has_correct"]
890
+ image_grid_thws = inputs["image_grid_thws"]
891
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
892
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
893
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
894
+ try:
895
+ per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, pixel_values, image_grid_thws,
896
+ logits_to_keep)
897
+ except Exception as e:
898
+ print(f"Error in _get_per_token_logps: {e}")
899
+ raise e
900
+
901
+ # sft_loss = -(per_token_logps * completion_mask).sum(-1) / completion_mask.sum(-1)
902
+ advantages = inputs["advantages"][:, 0]
903
+ # sft_loss = (sft_loss * (advantages > 0)).sum() * (has_correct == 0)
904
+ # return sft_loss
905
+ # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's computation (see
906
+ # _generate_and_score_completions) and use per_token_logps.detach() instead.
907
+ old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
908
+ coef_1 = torch.exp(per_token_logps - old_per_token_logps)
909
+ coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
910
+ per_token_loss1 = coef_1 * advantages.unsqueeze(1)
911
+ per_token_loss2 = coef_2 * advantages.unsqueeze(1)
912
+ per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
913
+
914
+ if self.loss_type == "grpo":
915
+ loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
916
+ elif self.loss_type == "bnpo":
917
+ loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
918
+ elif self.loss_type == "dr_grpo":
919
+ loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
920
+ else:
921
+ raise ValueError(f"Unknown loss type: {self.loss_type}")
922
+ # loss = (has_correct > 0) * loss + sft_loss
923
+ # loss = (has_correct > 0) * loss
924
+ # Log the metrics
925
+ mode = "train" if self.model.training else "eval"
926
+
927
+ # Compute the clipped probability ratios
928
+ is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
929
+ is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
930
+ is_region_clipped = is_low_clipped | is_high_clipped
931
+
932
+ low_clip = (is_low_clipped * completion_mask).sum() / completion_mask.sum()
933
+ high_clip = (is_high_clipped * completion_mask).sum() / completion_mask.sum()
934
+ clip_ratio = (is_region_clipped * completion_mask).sum() / completion_mask.sum()
935
+
936
+ gathered_low_clip = self.accelerator.gather_for_metrics(low_clip)
937
+ self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
938
+ self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
939
+ gathered_high_clip = self.accelerator.gather_for_metrics(high_clip)
940
+ self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
941
+ self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
942
+ gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio)
943
+ self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
944
+ return loss
945
+
946
+ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
947
+ inputs = self._prepare_inputs(inputs)
948
+ with torch.no_grad():
949
+ with self.compute_loss_context_manager():
950
+ loss = self.compute_loss(model, inputs)
951
+ loss = loss.mean().detach()
952
+ return loss, None, None
953
+
954
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
955
+ mode = "train" if self.model.training else "eval"
956
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics
957
+
958
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
959
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
960
+ if mode == "eval":
961
+ metrics = {f"eval_{key}": val for key, val in metrics.items()}
962
+
963
+ logs = {**logs, **metrics}
964
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
965
+ super().log(logs, start_time)
966
+ else: # transformers<=4.46
967
+ super().log(logs)
968
+ self._metrics[mode].clear()
969
+
970
+ if self.accelerator.is_main_process and self.log_completions:
971
+ if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
972
+ import pandas as pd
973
+
974
+ table = {
975
+ "step": [str(self.state.global_step)] * len(self._textual_logs["prompt"]),
976
+ "prompt": self._textual_logs["prompt"],
977
+ "completion": self._textual_logs["completion"],
978
+ **self._textual_logs["rewards"],
979
+ }
980
+ df = pd.DataFrame(table)
981
+ if self.wandb_log_unique_prompts:
982
+ df = df.drop_duplicates(subset=["prompt"])
983
+ wandb.log({"completions": wandb.Table(dataframe=df)})
trainer/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .DyMETrainer import DyMETrainer
2
+
3
+ __all__ = ["DyMETrainer"]