Add files using upload-large-folder tool
Browse files- README.md +712 -0
- client_utils/openai_api.py +123 -0
- config/__init__.py +21 -0
- config/__pycache__/config_rlsd_chartqa.cpython-310.pyc +0 -0
- config/__pycache__/config_trimode.cpython-310.pyc +0 -0
- config/__pycache__/config_trimode_antidegen.cpython-310.pyc +0 -0
- config/__pycache__/loader.cpython-312.pyc +0 -0
- config/config_7B.py +82 -0
- config/config_aok.py +119 -0
- config/config_llavacot.py +118 -0
- config/config_low.py +120 -0
- config/config_opd_7b_chartqa.py +48 -0
- config/config_rlsd_chartqa.py +152 -0
- config/config_trimode.py +88 -0
- default_config_8gpu.yaml +16 -0
- default_config_8gpu_deepspeed.yaml +21 -0
- default_config_zero2_8gpu.yaml +18 -0
- eval/eval_chartqa.py +310 -0
- figs/chartqa.png +0 -0
- kill_all.sh +55 -0
- main.py +522 -0
- main_llm.py +197 -0
- main_sft.py +80 -0
- multi_node_config_raw.yaml +21 -0
- opsd_utils/__pycache__/opsd_loss.cpython-312.pyc +0 -0
- opsd_utils/gate_policy.py +107 -0
- opsd_utils/health_monitor.py +410 -0
- opsd_utils/privileged/__pycache__/providers.cpython-310.pyc +0 -0
- opsd_utils/privileged/image_utils.py +143 -0
- opsd_utils/prompt_builder.py +265 -0
- outputs/logs/.ipynb_checkpoints/train_opd_7b_ds_20260614_175014-checkpoint.log +0 -0
- outputs/opd-7b-chartqa-ds/checkpoint-1764/zero_to_fp32.py +760 -0
- outputs/opd-7b-chartqa-ds/checkpoint-2352/preprocessor_config.json +171 -0
- outputs/opd-7b-chartqa-ds/checkpoint-588/config.json +235 -0
- papers/full_text.txt +1211 -0
- requirements.txt +16 -0
- reward_utils/__pycache__/format_checks.cpython-310.pyc +0 -0
- reward_utils/compute_rewards.py +126 -0
- reward_utils/refiner.py +162 -0
- tests/test_data_health_probe.py +28 -0
- tests/test_degeneration_probe.py +98 -0
- tests/test_health_monitor.py +72 -0
- tests/test_mode_router_rlsd.py +97 -0
- tests/test_privileged.py +172 -0
- tests/test_privileged_debug_artifacts.py +40 -0
- tests/test_slice_completion_logits.py +18 -0
- tests/test_teacher_dual_image.py +43 -0
- tests/test_vocab_align.py +71 -0
- trainer/DyMETrainer_7B.py +983 -0
- trainer/__init__.py +3 -0
README.md
ADDED
|
@@ -0,0 +1,712 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DyME: Empowering Small-scale VLMs with Reliable Thinking Capabilities
|
| 2 |
+
|
| 3 |
+
[](#)
|
| 4 |
+
[](https://arxiv.org/abs/2506.23061)
|
| 5 |
+
|
| 6 |
+
This repository provides the official implementation of **DyME** (**Dy**namically selecting between **M**emorization and **E**xploration), accepted at **ICLR 2026**.
|
| 7 |
+
|
| 8 |
+
## Overview
|
| 9 |
+
|
| 10 |
+
Small-scale Vision-Language Models (SVLMs) are highly suited for proprietary tasks, but equipping them with reasoning and thinking capabilities remains challenging. Traditional Supervised Fine-Tuning (SFT) can force memorization of pseudo thinking traces, while Reinforcement Learning with Verifiable Reward (RLVR) often leads to unstable exploration (advantage collapse) due to limited model capacity.
|
| 11 |
+
|
| 12 |
+
**DyME** is a novel training paradigm that dynamically synergizes SFT and RLVR. At each optimization step, DyME dynamically selects between Memorization (via SFT) and Exploration (via RLVR), ensuring every update contributes to an optimal trade-off. To further enhance this, we introduce a **Visual Supervision mechanism** (a visual checker and refiner) to inject dynamically enhanced, image-grounded guidance during training.
|
| 13 |
+
|
| 14 |
+
Extensive experiments show that DyME delivers substantial performance improvements, establishing it as a robust strategy for stabilizing SVLM learning.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
## Repository Structure
|
| 18 |
+
|
| 19 |
+
```text
|
| 20 |
+
DyME/
|
| 21 |
+
├── client_utils/ # Client tools for online Visual Supervision (LLM API)
|
| 22 |
+
├── data/ # Preprocessed textual datasets
|
| 23 |
+
├── data_utils/ # Data processing and formatting scripts
|
| 24 |
+
│ ├── aokvqa/
|
| 25 |
+
│ ├── chart/
|
| 26 |
+
│ └── commom_util.py
|
| 27 |
+
├── eval/ # Evaluation scripts for different benchmarks
|
| 28 |
+
├── reward_utils/ # Reward function implementations for RLVR
|
| 29 |
+
├── config/ # Modular configuration files for experiments
|
| 30 |
+
├── opsd_utils/ # Privileged-context OPSD / TriMode extensions for DyMETrainer
|
| 31 |
+
├── default_config.yaml # Default DDP (MULTI_GPU, no DeepSpeed required)
|
| 32 |
+
├── default_config_deepspeed.yaml # Optional ZeRO-0 only (no sharding); needs pip install deepspeed
|
| 33 |
+
├── default_config_zero2.yaml # ZeRO-2 student sharding (OPD 7B colocate)
|
| 34 |
+
├── default_config_zero3_colocate.yaml # ZeRO-3 + CPU optimizer offload (tight VRAM)
|
| 35 |
+
├── configs/deepspeed/ # DeepSpeed JSON templates (HF official "auto" fields)
|
| 36 |
+
├── main.py # Entry point for DyME training
|
| 37 |
+
├── main_*.py # Additional experimental variants (e.g., 7B, LLM-only)
|
| 38 |
+
├── requirements.txt # Python dependencies
|
| 39 |
+
└── ...
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
## Configuration
|
| 44 |
+
|
| 45 |
+
Before launching training, please prepare the relevant configuration files. The main settings are managed through configuration files such as `config/config.py` and `default_config.yaml`.
|
| 46 |
+
|
| 47 |
+
### `CLIENT_CONFIG`
|
| 48 |
+
|
| 49 |
+
This configuration is required when **Visual Supervision** is enabled. It specifies the online large-model API used by the visual checker and visual refiner.
|
| 50 |
+
|
| 51 |
+
### `TRAINING_CONFIG`
|
| 52 |
+
|
| 53 |
+
This section contains standard training hyperparameters for both the memorization phase and the exploration phase, including optimizer settings, batch size, learning rate, and related options.
|
| 54 |
+
|
| 55 |
+
### `RL_CONFIG`
|
| 56 |
+
|
| 57 |
+
This section defines critical variables for reward computation and response parsing during RLVR training. In particular, the following delimiters must be properly specified:
|
| 58 |
+
|
| 59 |
+
* `answer_flag`: used to explicitly separate the final answer from auxiliary generated content such as intermediate reasoning traces.
|
| 60 |
+
* `end_flag`: used to mark the end of generation.
|
| 61 |
+
|
| 62 |
+
These delimiters are essential for stable parsing, reward assignment, and evaluation consistency.
|
| 63 |
+
|
| 64 |
+
### `DYME_OPSD_CONFIG` (OPSD / TriMode)
|
| 65 |
+
|
| 66 |
+
`config/config.py` defines `DYME_OPSD_CONFIG`, merged into `CONFIG["opsd"]`. When `enabled=False` (default), training follows the original DyME behavior. Set `enabled=True` or pass CLI flags to activate privileged-context **Self-OPSD** inside `DyMETrainer`.
|
| 67 |
+
|
| 68 |
+
| Field | Description |
|
| 69 |
+
| --- | --- |
|
| 70 |
+
| `enabled` | Master switch. `False` → original DyME only. |
|
| 71 |
+
| `mode` | Routing mode (see table below). |
|
| 72 |
+
| `privileged_profile` | Teacher preset: `text` \| `visual` \| `hybrid` (default **`hybrid`** in `config_trimode.py`). |
|
| 73 |
+
| `privileged_providers` | Override provider list; default derived from profile. |
|
| 74 |
+
| `privileged_image` | Teacher image layout: `mode` `single` (ChartQA default) or `dual` (full + crop); plus `crop_strategy`, `bbox_coord`, `margin_ratio`. |
|
| 75 |
+
| `privileged_debug` | Periodic artifact logging: `save_images`, `image_subdir` (`logs/images`), `max_samples_per_detail`. |
|
| 76 |
+
| `gate.correct_threshold` | Reward threshold to count a rollout as correct. |
|
| 77 |
+
| `gate.teacher_recoverable` | Recoverability gate: `privileged_available` (default) or `logprob_gain`. |
|
| 78 |
+
| `loss.beta` | JSD temperature for OPSD distillation. |
|
| 79 |
+
| `loss.opsd_weight` / `grpo_weight` / `sft_weight` | Per-mode loss weights. |
|
| 80 |
+
|
| 81 |
+
**Routing modes (`mode`):**
|
| 82 |
+
|
| 83 |
+
| Mode | Behavior |
|
| 84 |
+
| --- | --- |
|
| 85 |
+
| `dyme` | Original DyME: any correct rollout → GRPO; all wrong → SFT. |
|
| 86 |
+
| `trimode` | Any correct → OPSD (replaces GRPO); all wrong → SFT (DyME cold-start via `sft_check`, ignores recoverable). |
|
| 87 |
+
| `opsd_only` | All prompts use OPSD. |
|
| 88 |
+
| `replace_sft` | Any correct → GRPO; all wrong → OPSD (no SFT). |
|
| 89 |
+
| `opsd_on_wrong` | Any correct → GRPO; all wrong + recoverable → OPSD; all wrong + not recoverable → SFT (legacy three-way routing). |
|
| 90 |
+
| `grpo_opsd_joint` | Any correct → GRPO (+ optional joint OPSD loss); all wrong + recoverable → OPSD; else SFT. |
|
| 91 |
+
|
| 92 |
+
Under `trimode`, the SFT share is determined by accuracy (how often prompts are all-wrong) and DyME's per-group `sft_check` (teacher injection on the first generation only)—no extra `sft_ratio` hyperparameter.
|
| 93 |
+
|
| 94 |
+
**Privileged profiles** (`privileged_profile`):
|
| 95 |
+
|
| 96 |
+
| Profile | Teacher images | Teacher text suffix |
|
| 97 |
+
| --- | --- | --- |
|
| 98 |
+
| `text` | Single full image (same as student) | hint + answer |
|
| 99 |
+
| `visual` | **Dual**: full + evidence crop | Visual Facts only (no answer leak) |
|
| 100 |
+
| `hybrid` | Single full image by default (`privileged_image.mode=single`); dual with `mode=dual` | Visual Facts + hint + answer |
|
| 101 |
+
|
| 102 |
+
Student `collate_fn` never reads privileged fields. With `privileged_image.mode=dual`, teacher forward uses `[full, crop]`; crop comes from normalized `evidence_bbox` (C2), A-OKVQA `visual_fact` heuristic (D2), or center fallback (D1). ChartQA defaults to `single` (no crop zoom).
|
| 103 |
+
|
| 104 |
+
**Privileged providers** (under `opsd_utils/privileged/`):
|
| 105 |
+
|
| 106 |
+
* `text` — uses the `hint` / `answer` fields in training samples.
|
| 107 |
+
* `visual_facts` — uses `visual_fact` JSON (B1 raw string), plus ChartQA `visual_fact_hint` (F1) and `visual_fact_deplot` (F2).
|
| 108 |
+
* `crop` — evidence region as second teacher image (via `image_utils`, not a text suffix).
|
| 109 |
+
* `hybrid` — combines text + visual_facts providers per profile.
|
| 110 |
+
|
| 111 |
+
**Debug / artifact logging**
|
| 112 |
+
|
| 113 |
+
* Verbose OPSD logs: `--opsd_debug` or `DYME_OPSD_DEBUG=1`.
|
| 114 |
+
* Full diagnostic bundle every N steps: `--opsd_detail_every N` or `DYME_OPSD_DETAIL_EVERY`.
|
| 115 |
+
* On detail steps, teacher privileged images are saved under `{output_dir}/logs/images/` as `step_XXXXXX_idx_Y_full.png`, `_crop.png`, and `_meta.json` (controlled by `privileged_debug.max_samples_per_detail`).
|
| 116 |
+
|
| 117 |
+
**ChartQA visual-facts preprocessing (run on server before TriMode training)**
|
| 118 |
+
|
| 119 |
+
TriMode with `privileged_providers=text,visual_facts` requires `visual_fact_hint` / `visual_fact_deplot` (and optionally `visual_fact`) on each sample. Raw `train_medium.json` only has `hint` — without this step, logs show `visual_fact_len=0` and the VisualFacts teacher channel is empty.
|
| 120 |
+
|
| 121 |
+
From the repo root on your GPU server:
|
| 122 |
+
|
| 123 |
+
```bash
|
| 124 |
+
cd /path/to/agentic-rl-main # project root (parent of scripts/, config/, data/)
|
| 125 |
+
|
| 126 |
+
# F1: copy hint → visual_fact_hint (+ visual_fact for backward compat)
|
| 127 |
+
python scripts/build_visual_facts_chartqa.py \
|
| 128 |
+
--input data/chartqa/train_medium.json \
|
| 129 |
+
--output data/chartqa/train_medium_vf_hint.json \
|
| 130 |
+
--also-set-visual-fact
|
| 131 |
+
|
| 132 |
+
# F2: DePlot offline table extraction (google/deplot, batched GPU inference; default ON)
|
| 133 |
+
python scripts/build_visual_facts_chartqa_deplot.py \
|
| 134 |
+
--input data/chartqa/train_medium_vf_hint.json \
|
| 135 |
+
--output data/chartqa/train_medium_vf_full.json \
|
| 136 |
+
--batch-size 8 \
|
| 137 |
+
--cache data/chartqa/deplot_cache.json
|
| 138 |
+
|
| 139 |
+
# Fast placeholder-only mode (no GPU / CI): add --no-enabled
|
| 140 |
+
# DYME_DEPLOT_ENABLED=0 bash scripts/train_local_gpus.sh
|
| 141 |
+
|
| 142 |
+
# quick sanity check (expect non-zero lengths)
|
| 143 |
+
python -c "
|
| 144 |
+
import json, random
|
| 145 |
+
d = json.load(open('data/chartqa/train_medium_vf_full.json', encoding='utf-8'))
|
| 146 |
+
s = random.choice(d)
|
| 147 |
+
assert s.get('visual_fact_hint'), 'missing visual_fact_hint'
|
| 148 |
+
assert s.get('visual_fact_deplot'), 'missing visual_fact_deplot'
|
| 149 |
+
print('ok', len(d), 'records; sample visual_fact_hint len', len(s['visual_fact_hint']))
|
| 150 |
+
"
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
`config/config.py` points `train_dataset` at `data/chartqa/train_medium_vf_full.json`. Generated `*_vf_*.json` files are gitignored — **generate them on each server** (or copy from shared storage); do not rely on cloning them from GitHub.
|
| 154 |
+
|
| 155 |
+
`scripts/train_local_gpus.sh` will auto-run the two Python steps above if `train_medium_vf_full.json` is missing.
|
| 156 |
+
|
| 157 |
+
**Training examples (TriMode + hybrid default)**
|
| 158 |
+
|
| 159 |
+
```bash
|
| 160 |
+
# Text-only OPSD ablation
|
| 161 |
+
python main.py --config trimode --opsd_privilege_profile text
|
| 162 |
+
|
| 163 |
+
# Vision-OPD style (no answer text to teacher)
|
| 164 |
+
python main.py --config trimode --opsd_privilege_profile visual
|
| 165 |
+
|
| 166 |
+
# Full hybrid (default in config_trimode)
|
| 167 |
+
python main.py --config trimode --opsd_privilege_profile hybrid --opsd_detail_every 10
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
**Privileged sample schema**
|
| 171 |
+
|
| 172 |
+
| Field | Used by | Notes |
|
| 173 |
+
| --- | --- | --- |
|
| 174 |
+
| `prompt`, `image` | Student + teacher | Student always single full image |
|
| 175 |
+
| `hint`, `answer` | Teacher (`text` / `hybrid`) | Never in student collate |
|
| 176 |
+
| `visual_fact` | Teacher | Raw JSON string (A-OKVQA) |
|
| 177 |
+
| `visual_fact_hint` | Teacher (ChartQA F1) | Hint placeholder pipeline |
|
| 178 |
+
| `visual_fact_deplot` | Teacher (ChartQA F2) | DePlot `parsed_table` text (`google/deplot`; placeholder skipped) |
|
| 179 |
+
| `evidence_bbox` | Teacher crop | Normalized `[x0,y0,x1,y1]` in `[0,1]` |
|
| 180 |
+
|
| 181 |
+
Adapter helpers for future datasets: `data_utils/privileged_schema.py` (`normalize_evidence_bbox`, `parse_visual_fact`, `resolve_crop_bbox`).
|
| 182 |
+
|
| 183 |
+
For legacy ChartQA single-field preprocessing, see `scripts/build_visual_facts_chartqa.py`.
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
## Data Preparation
|
| 188 |
+
|
| 189 |
+
We provide example preprocessing scripts in the `data_utils/` directory. After preprocessing, the training data should be organized as a list of dictionaries (e.g., `metadata_list`) following the format below:
|
| 190 |
+
|
| 191 |
+
```python
|
| 192 |
+
metadata_list.append({
|
| 193 |
+
"question": question, # Full prompt used for training
|
| 194 |
+
"question_wo_prompt": question, # Raw question without prompt template
|
| 195 |
+
"answer": answer, # SFT target; should follow the answer_flag format
|
| 196 |
+
"image": image_save_path, # Local path to the corresponding image
|
| 197 |
+
})
|
| 198 |
+
```
|
| 199 |
+
|
| 200 |
+
### Field Description
|
| 201 |
+
|
| 202 |
+
* `question`: the complete model input used during training.
|
| 203 |
+
* `question_wo_prompt`: the raw question content without any additional prompt wrapper.
|
| 204 |
+
* `answer`: the training target for SFT; this field should be formatted consistently with the delimiter specification in `RL_CONFIG`.
|
| 205 |
+
* `image`: the local file path of the associated image, if applicable.
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
## Environment Setup
|
| 210 |
+
|
| 211 |
+
Please first install the required dependencies and configure the distributed training environment:
|
| 212 |
+
|
| 213 |
+
```bash
|
| 214 |
+
pip install -r requirements.txt
|
| 215 |
+
accelerate config
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
The `accelerate config` step is required to initialize the distributed environment for both training and evaluation.
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
## Dataset Setup
|
| 223 |
+
|
| 224 |
+
### Text Data
|
| 225 |
+
|
| 226 |
+
Preprocessed text splits are provided under the `data/` directory.
|
| 227 |
+
|
| 228 |
+
### Image Data
|
| 229 |
+
|
| 230 |
+
Due to storage constraints, image datasets are not included in this repository. Download scripts write images under `data/images/` by default:
|
| 231 |
+
|
| 232 |
+
```text
|
| 233 |
+
data/images/
|
| 234 |
+
├── chartqa/
|
| 235 |
+
│ ├── images/ # train_000000.png, val_000000.png, test_000000.png, ...
|
| 236 |
+
│ └── json/ # train.json, val.json, test.json (from download.py)
|
| 237 |
+
└── aokvqa/
|
| 238 |
+
├── images/ # train_0000000.png, ...
|
| 239 |
+
└── json/ # train.json, validation.json, test.json (from download.py)
|
| 240 |
+
```
|
| 241 |
+
|
| 242 |
+
**ChartQA** (images only, no API required):
|
| 243 |
+
|
| 244 |
+
```bash
|
| 245 |
+
python data_utils/chart/download.py
|
| 246 |
+
```
|
| 247 |
+
|
| 248 |
+
**A-OKVQA** (images only by default; set `FETCH_VISUAL_FACTS=1` only if local VLM APIs are running on ports 23333–23340):
|
| 249 |
+
|
| 250 |
+
```bash
|
| 251 |
+
python data_utils/aokvqa/download.py
|
| 252 |
+
```
|
| 253 |
+
|
| 254 |
+
If you already downloaded ChartQA to `chartqa_output/` at the project root, move it into the canonical layout:
|
| 255 |
+
|
| 256 |
+
```bash
|
| 257 |
+
mkdir -p data/images
|
| 258 |
+
mv chartqa_output data/images/chartqa
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
Preprocessed text annotations with hints live separately under `data/chartqa/` and `data/aokvqa/`. Image paths inside those JSON files are resolved automatically at load time (legacy prefixes like `/chartqa_output/` map to `data/images/chartqa/`).
|
| 262 |
+
|
| 263 |
+
### Demo Samples
|
| 264 |
+
|
| 265 |
+
A small subset of demo images for verifying the data loading pipeline may be provided in a future update.
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
## Dataset Examples
|
| 270 |
+
|
| 271 |
+
### ChartQA
|
| 272 |
+
|
| 273 |
+
**ChartQA** is a visual question answering benchmark grounded in chart images. To illustrate different supervision granularities, we provide representative examples with three levels of reasoning-trace quality: **High**, **Medium**, and **Low**.
|
| 274 |
+
|
| 275 |
+
<div align="center">
|
| 276 |
+
|
| 277 |
+
| Example |
|
| 278 |
+
| --------------------------------------------------------------- |
|
| 279 |
+
| <img src="figs/chartqa.png" alt="ChartQA Example" width="220"/> |
|
| 280 |
+
|
| 281 |
+
</div>
|
| 282 |
+
|
| 283 |
+
#### High-quality Example
|
| 284 |
+
|
| 285 |
+
<details>
|
| 286 |
+
<summary><code>High-quality ChartQA Example</code></summary>
|
| 287 |
+
|
| 288 |
+
```json
|
| 289 |
+
{
|
| 290 |
+
"question": "When does the unfavorable view reach the peak?",
|
| 291 |
+
"answer": "2017",
|
| 292 |
+
"hint": "<SUMMARY> To solve the problem, I will examine the image to identify trends in unfavorable views of Pakistan in India over time. I'll closely inspect the data points within the graph to determine the year where the \"very unfavorable view\" reaches its peak. This involves identifying the maximum value on the vertical axis and noting the corresponding year on the horizontal axis. </SUMMARY> \n\n<CAPTION> The image is a line graph titled \"Very unfavorable views of Pakistan increasing in India,\" with the subtitle \"Very unfavorable view of Pakistan.\" The y-axis represents the percentage of unfavorable views, ranging from 0% to 100%. The x-axis displays years from 2013 to 2017. The data points show the percentages of very unfavorable views over these years, with specific values marked: 54% in 2013, 49% in 2014, 51% in 2015, 55% in 2016, and 64% in 2017. The graph shows a general upward trend in unfavorable views, peaking in 2017. </CAPTION> \n\n<REASONING> To determine when the unfavorable view reaches its peak, one should observe the graph for the data point with the highest percentage on the y-axis. The graph shows percentages for each year from 2013 to 2017: starting at 54% in 2013, decreasing to 49% in 2014, and then gradually increasing to 51% in 2015 and 55% in 2016. The graph culminates with the highest percentage of 64% in 2017. Thus, the peak of unfavorable views is associated with the year 2017. </REASONING> \n\n<CONCLUSION> 2017 </CONCLUSION>"
|
| 293 |
+
}
|
| 294 |
+
```
|
| 295 |
+
|
| 296 |
+
</details>
|
| 297 |
+
|
| 298 |
+
#### Medium-quality Example
|
| 299 |
+
|
| 300 |
+
<details>
|
| 301 |
+
<summary><code>Medium-quality ChartQA Example</code></summary>
|
| 302 |
+
|
| 303 |
+
```json
|
| 304 |
+
{
|
| 305 |
+
"question": "When does the unfavorable view reach the peak?",
|
| 306 |
+
"answer": "2017",
|
| 307 |
+
"hint": "Goal: Find the year when the unfavorable view reaches its peak.\nObservation: The data shows the values for each year are: 2013: 0, 2014: 0, 2015: 0, 2016: 55, and 2017: 64.\nReasoning: By comparing the values in each year, the highest value is 64, which occurs in 2017.\nConclusion: The unfavorable view reaches its peak in 2017."
|
| 308 |
+
}
|
| 309 |
+
```
|
| 310 |
+
|
| 311 |
+
</details>
|
| 312 |
+
|
| 313 |
+
#### Low-quality Example
|
| 314 |
+
|
| 315 |
+
<details>
|
| 316 |
+
<summary><code>Low-quality ChartQA Example</code></summary>
|
| 317 |
+
|
| 318 |
+
```json
|
| 319 |
+
{
|
| 320 |
+
"question": "When does the unfavorable view reach the peak?",
|
| 321 |
+
"answer": "2017",
|
| 322 |
+
"hint": "I'm trying to figure out the year when the unfavorable view reaches its highest point. Looking at the data, I see that the values for each year are pretty low until 2016, where it jumps to 55. But the peak doesn't happen until 2017, when the value spikes to 64. So, it seems like the unfavorable view really hits its maximum in 2017."
|
| 323 |
+
}
|
| 324 |
+
```
|
| 325 |
+
|
| 326 |
+
</details>
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
### A-OKVQA
|
| 330 |
+
|
| 331 |
+
**A-OKVQA** is an open-ended visual question answering benchmark that requires world knowledge, commonsense reasoning, and visual understanding. Below we provide a representative example together with its corresponding annotation.
|
| 332 |
+
|
| 333 |
+
<div align="center">
|
| 334 |
+
|
| 335 |
+
| Example |
|
| 336 |
+
| -------------------------------------------------------------- |
|
| 337 |
+
| <img src="figs/aokvqa.png" alt="A-OKVQA Example" width="220"/> |
|
| 338 |
+
|
| 339 |
+
</div>
|
| 340 |
+
|
| 341 |
+
<details>
|
| 342 |
+
<summary><code>View A-OKVQA JSON Example</code></summary>
|
| 343 |
+
|
| 344 |
+
```json
|
| 345 |
+
{
|
| 346 |
+
"question": "What is the man by the bags awaiting?",
|
| 347 |
+
"answer": "cab",
|
| 348 |
+
"visual_fact": "{\n \"description\": \"The image shows a man standing in the middle of a street, facing away from the camera. He is holding a red bag in one hand and appears to be pulling a black suitcase with wheels. Another black suitcase is lying on the ground near him. The setting is an urban area with houses, parked cars, and trees in the background. The man seems to be waiting or preparing to cross the street.\",\n \"objects\": [\n {\n \"name\": \"man\",\n \"attributes\": [\"wearing a light blue and white shirt\", \"blue jeans\", \"carrying a red bag\", \"pulling a black suitcase\"],\n \"position\": \"center\"\n },\n {\n \"name\": \"red bag\",\n \"attributes\": [\"held by the man\"],\n \"position\": \"left side of the man\"\n },\n {\n \"name\": \"black suitcase\",\n \"attributes\": [\"with wheels\", \"being pulled by the man\"],\n \"position\": \"near the man's feet\"\n },\n {\n \"name\": \"black suitcase\",\n \"attributes\": [\"on the ground\"],\n \"position\": \"on the ground near the man\"\n },\n {\n \"name\": \"street\",\n \"attributes\": [\"asphalt\", \"urban setting\"],\n \"position\": \"foreground\"\n },\n {\n \"name\": \"houses\",\n \"attributes\": [\"visible in the background\"],\n \"position\": \"left side\"\n },\n {\n \"name\": \"parked cars\",\n \"attributes\": [\"red SUV\", \"other vehicles\"],\n \"position\": \"left and center background\"\n },\n {\n \"name\": \"trees\",\n \"attributes\": [\"green foliage\"],\n \"position\": \"right side\"\n }\n ]\n}",
|
| 349 |
+
"hint": "A train would not be on the street, he would not have luggage waiting for a delivery, and the skateboarder is there and not paying attention to him, so a cab is the only plausible answer."
|
| 350 |
+
}
|
| 351 |
+
```
|
| 352 |
+
|
| 353 |
+
</details>
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
### GSM8K
|
| 358 |
+
|
| 359 |
+
**GSM8K** is a mathematical word problem benchmark. Since it is text-only, we provide a representative question-answer example together with its reasoning trace.
|
| 360 |
+
|
| 361 |
+
<details>
|
| 362 |
+
<summary><code>View GSM8K JSON Example</code></summary>
|
| 363 |
+
|
| 364 |
+
```json
|
| 365 |
+
{
|
| 366 |
+
"question": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?",
|
| 367 |
+
"answer": "72",
|
| 368 |
+
"hint": "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72"
|
| 369 |
+
}
|
| 370 |
+
```
|
| 371 |
+
|
| 372 |
+
</details>
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
## Training
|
| 377 |
+
|
| 378 |
+
All training scripts are launched using `accelerate`. Pass `--config` as a **Python config file path** (recommended) or a shorthand alias (`norm`, `trimode`, `llavacot`, `low`, `aok`).
|
| 379 |
+
|
| 380 |
+
**Important:** `num_processes` must match the number of visible GPUs on your node. Helper scripts auto-detect GPU count and use **native PyTorch DDP** (`default_config.yaml`, `distributed_type: MULTI_GPU`) — **DeepSpeed is not required** for 0.5B multi-GPU training.
|
| 381 |
+
|
| 382 |
+
Optional: if you already have `deepspeed` installed and want the Accelerate integration without parameter sharding, use ZeRO-0 (`default_config_deepspeed.yaml`, `zero_stage: 0`). Do **not** use ZeRO-2/3 for 0.5B-only RLSD unless you need the integration path.
|
| 383 |
+
|
| 384 |
+
**7B OPD (student + frozen teacher on each GPU):** use DeepSpeed ZeRO to shard the **trainable 0.5B student**; the frozen 7B teacher stays outside DeepSpeed on `cuda:{LOCAL_RANK}`.
|
| 385 |
+
|
| 386 |
+
```bash
|
| 387 |
+
# ZeRO-2 colocate (recommended first try on 2×80G)
|
| 388 |
+
bash scripts/train_opd_7b_chartqa_deepspeed.sh
|
| 389 |
+
|
| 390 |
+
# Tighter memory: ZeRO-3 + CPU optimizer offload
|
| 391 |
+
ACCELERATE_CONFIG=default_config_zero3_colocate.yaml bash scripts/train_opd_7b_chartqa_deepspeed.sh
|
| 392 |
+
```
|
| 393 |
+
|
| 394 |
+
Refs: [Transformers DeepSpeed](https://huggingface.co/docs/transformers/deepspeed), [Accelerate DeepSpeed](https://huggingface.co/docs/accelerate/usage_guides/deepspeed).
|
| 395 |
+
|
| 396 |
+
```bash
|
| 397 |
+
# 4-GPU node (default DDP, recommended)
|
| 398 |
+
bash scripts/train_rlsd_chartqa.sh
|
| 399 |
+
|
| 400 |
+
# Explicit DDP config
|
| 401 |
+
accelerate launch --config_file default_config.yaml --num_processes 4 main.py --config config/config.py --mode rl
|
| 402 |
+
|
| 403 |
+
# Optional ZeRO-0 (requires deepspeed, no sharding)
|
| 404 |
+
ACCELERATE_CONFIG=default_config_deepspeed.yaml bash scripts/train_rlsd_chartqa.sh
|
| 405 |
+
```
|
| 406 |
+
|
| 407 |
+
Or override explicitly: `NUM_GPUS=4 bash scripts/train_trimode.sh`
|
| 408 |
+
|
| 409 |
+
For TriMode on **all visible local GPUs** (auto-detect via `CUDA_VISIBLE_DEVICES` / `torch.cuda.device_count()`):
|
| 410 |
+
|
| 411 |
+
```bash
|
| 412 |
+
# 1) One-time (or when raw data changes): enrich ChartQA JSON on the server — see
|
| 413 |
+
# "ChartQA visual-facts preprocessing" above. train_local_gpus.sh also auto-runs
|
| 414 |
+
# this if train_medium_vf_full.json is absent.
|
| 415 |
+
|
| 416 |
+
# 2) Start training (default: OPSD verbose off, detail every 50 steps, probe on)
|
| 417 |
+
bash scripts/train_local_gpus.sh
|
| 418 |
+
|
| 419 |
+
# Optional: full verbose debug (large logs)
|
| 420 |
+
# DYME_OPSD_DEBUG=1 DYME_OPSD_DETAIL_EVERY=10 bash scripts/train_local_gpus.sh
|
| 421 |
+
|
| 422 |
+
# Roll back to original trimode config (pre-antidegen hyperparameters)
|
| 423 |
+
# DYME_CONFIG=config/config_trimode.py bash scripts/train_local_gpus.sh
|
| 424 |
+
```
|
| 425 |
+
|
| 426 |
+
### Anti-degeneration config (`config_trimode_antidegen`)
|
| 427 |
+
|
| 428 |
+
`scripts/train_local_gpus.sh` defaults to **`config/config_trimode_antidegen.py`** (alias `trimode_antidegen`). Overrides are based on offline analysis of `train_trimode_4gpu_20260610_173637.log` (1225 steps):
|
| 429 |
+
|
| 430 |
+
| Issue | Baseline log evidence | Antidegen change |
|
| 431 |
+
|-------|----------------------|------------------|
|
| 432 |
+
| Logit collapse | `LOGIT_MODE_COLLAPSE` 212×; step 1 clip 0→1.0; step 1175 clip≈0.92 | `max_completion_length=150`, `temperature=0.7`, `repetition_penalty=1.25` |
|
| 433 |
+
| Step-1 gradient shock | `GEN_CLIP_COLLAPSE` from step 1; `OPT_GRAD_SPIKE` 44× | `learning_rate=5e-5`, `warmup_steps=50` |
|
| 434 |
+
| OPSD coverage low | `opsd_mask` mean 5.6%; 492/1226 zero-mask steps | `require_format_for_opsd=False` (env default `DYME_OPSD_REQUIRE_FORMAT=0`) |
|
| 435 |
+
| RL signal sparse | `RL_ZERO_SIGNAL` expected in trimode | `reward_weights=[0.5, 1.5, 1.0]` (format, context F1, acc) |
|
| 436 |
+
| visual_fact empty | `visual_fact_empty_rate=0` throughout | no data change |
|
| 437 |
+
|
| 438 |
+
Environment overrides:
|
| 439 |
+
|
| 440 |
+
```bash
|
| 441 |
+
export DYME_CONFIG=config/config_trimode_antidegen.py # default in train_local_gpus.sh
|
| 442 |
+
export DYME_OPSD_REQUIRE_FORMAT=0 # antidegen default; set 1 to restore strict gate
|
| 443 |
+
export DYME_REWARD_WEIGHTS=0.5,1.5,1.0 # format, context, accuracy
|
| 444 |
+
```
|
| 445 |
+
|
| 446 |
+
After a new run (~200+ steps), compare against the baseline log:
|
| 447 |
+
|
| 448 |
+
```bash
|
| 449 |
+
python scripts/parse_trimode_log.py outputs/logs/train_trimode_*_new.log
|
| 450 |
+
python scripts/degeneration_report.py outputs/logs/train_trimode_*_new.log
|
| 451 |
+
python scripts/compare_trimode_logs.py train_trimode_4gpu_20260610_173637.log outputs/logs/train_trimode_*_new.log
|
| 452 |
+
```
|
| 453 |
+
|
| 454 |
+
Success criteria (candidate vs baseline): step 1 `clip` < 1.0; `LOGIT_MODE_COLLAPSE` count down >30%; `opsd_mask` mean > 8%; step 200+ `mean_length` median < 130. `RL_ZERO_SIGNAL` may remain high (trimode design).
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
### 1. Training DyME (original)
|
| 458 |
+
|
| 459 |
+
Default config keeps OPSD disabled (`DYME_OPSD_CONFIG.enabled=False`):
|
| 460 |
+
|
| 461 |
+
```bash
|
| 462 |
+
accelerate launch main.py --config config/config.py --mode rl
|
| 463 |
+
```
|
| 464 |
+
|
| 465 |
+
### OPSD debug logging + tee
|
| 466 |
+
|
| 467 |
+
When debugging OPSD / TriMode (e.g. NCCL timeout), enable verbose logs and save stdout/stderr:
|
| 468 |
+
|
| 469 |
+
```bash
|
| 470 |
+
export DYME_OPSD_DEBUG=1
|
| 471 |
+
mkdir -p ./outputs/logs
|
| 472 |
+
LOG_FILE=./outputs/logs/train_$(date +%Y%m%d_%H%M%S).log
|
| 473 |
+
|
| 474 |
+
accelerate launch --config_file default_config.yaml --num_processes "$(nvidia-smi -L | wc -l)" main.py \
|
| 475 |
+
--config config/config_trimode.py \
|
| 476 |
+
--mode rl \
|
| 477 |
+
--opsd_enabled \
|
| 478 |
+
--opsd_debug \
|
| 479 |
+
--opsd_mode trimode \
|
| 480 |
+
--opsd_providers text,visual_facts \
|
| 481 |
+
2>&1 | tee "${LOG_FILE}"
|
| 482 |
+
```
|
| 483 |
+
|
| 484 |
+
Logs are prefixed with `[OPSD-DEBUG]` and include rank, step, `[SYNC_POINT]` markers before every distributed collective in the OPSD chain (reward gather, teacher prompt build, metrics gather, OPSD loss). Search the log for the last `[SYNC_POINT]` on each rank to locate where a hang occurred.
|
| 485 |
+
|
| 486 |
+
You can also use the helper script (debug + tee enabled by default):
|
| 487 |
+
|
| 488 |
+
```bash
|
| 489 |
+
bash scripts/train_trimode.sh
|
| 490 |
+
```
|
| 491 |
+
|
| 492 |
+
Disable debug when not needed: `DYME_OPSD_DEBUG=0 bash scripts/train_trimode.sh`
|
| 493 |
+
|
| 494 |
+
### Periodic weak-signal diagnostics (`[OPSD-DETAIL]`)
|
| 495 |
+
|
| 496 |
+
Separate from per-step `[OPSD-DEBUG]` spam: every **N global steps** (default **10**, rank 0 only) a full diagnostic bundle is printed to investigate **reward ≈ 0** and **gradient ≈ 0** while the OPSD chain still runs:
|
| 497 |
+
|
| 498 |
+
- Generation: EOS rate, clipped ratio, effective completion tokens, decoded samples
|
| 499 |
+
- Reward: format / acc / context breakdown, advantage stats, per-sample table
|
| 500 |
+
- Routing: OPSD mask ratio, TriMode counts, advantage token distribution
|
| 501 |
+
- Loss: GRPO per-token logps, coef\_1, clip counts, weak-signal hints
|
| 502 |
+
- OPSD JSD: per-token JSD, student/teacher top-1 agreement, max-JSD token
|
| 503 |
+
|
| 504 |
+
Configure via config, CLI, or env:
|
| 505 |
+
|
| 506 |
+
```bash
|
| 507 |
+
# default: every 10 steps (config_trimode.py)
|
| 508 |
+
export DYME_OPSD_DETAIL_EVERY=10
|
| 509 |
+
|
| 510 |
+
python main.py --config config/config_trimode.py --mode rl \
|
| 511 |
+
--opsd_enabled --opsd_detail_every 10
|
| 512 |
+
|
| 513 |
+
# disable periodic detail
|
| 514 |
+
export DYME_OPSD_DETAIL_EVERY=0
|
| 515 |
+
```
|
| 516 |
+
|
| 517 |
+
Search logs for `[OPSD-DETAIL]` (not `[OPSD-DEBUG]`).
|
| 518 |
+
|
| 519 |
+
**Per-generate probe (`[OPSD-PROBE]`)** — enabled by default in `config_trimode.py`; fires on every `(re)generate` on rank 0 (no need to wait for step 10). Logs raw `completion_ids`, decode with/without special tokens, `eos_idx`, flags `ONE_TOKEN` / `EMPTY_DECODE` / `FIRST_IS_EOS`, and patterns `PAREN_THEN_EOS` / `REPEAT_LOOP`. Disable with `DYME_OPSD_PROBE_ON_GENERATE=0` or `--no_opsd_probe_on_generate`.
|
| 520 |
+
|
| 521 |
+
**Deep generate debug (`[OPSD-GENDBG]`)** — runs alongside `[OPSD-PROBE]` when probe is enabled. Before each `model.generate`, logs model training context, prompt tail tokens/decode, and first-token logits (`p_eos`, `p_token_340`, `entropy`, `top5`) via **per-sample** forward (up to `probe_sample_count`, default 4) to avoid OOM on large VLM batches. After generate, logs greedy-vs-actual first token, delta vs previous regenerate, and cross-rank summary.
|
| 522 |
+
|
| 523 |
+
```bash
|
| 524 |
+
export DYME_OPSD_PROBE_ON_GENERATE=1 # default in config_trimode
|
| 525 |
+
grep -E '\[OPSD-(PROBE|GENDBG)\]' train.log
|
| 526 |
+
```
|
| 527 |
+
|
| 528 |
+
| Observation in logs | Likely root cause |
|
| 529 |
+
|---------------------|-------------------|
|
| 530 |
+
| `p_eos` very high + `greedy_token_id==eos` | Weight collapse / train-mode distribution |
|
| 531 |
+
| `prompt_tail_decode` ends with unclosed template + high `p_token_340` | Prompt / chat template issue (legacy `"Answer: .."` quoted placeholder biased token 340 `)`; fixed in `data_utils/rl_prompt.py`) |
|
| 532 |
+
| `greedy_matches_actual=False` with low `p_eos` | Sampling noise (temperature / top_p) |
|
| 533 |
+
| Large `one_token_count` gap across ranks in `cross_rank` | Data sharding / batch composition |
|
| 534 |
+
| `delta_one_token_count` spikes at `generate_call_index>=2` | Weight drift after optimizer step |
|
| 535 |
+
|
| 536 |
+
Optional env overrides:
|
| 537 |
+
|
| 538 |
+
```bash
|
| 539 |
+
export DYME_OPSD_PROBE_FIRST_TOKEN_LOGITS=0 # skip extra forward before generate
|
| 540 |
+
export DYME_OPSD_PROBE_PROMPT_TAIL_TOKENS=24
|
| 541 |
+
export DYME_OPSD_PROBE_LOG_MODEL_CONTEXT=0
|
| 542 |
+
```
|
| 543 |
+
|
| 544 |
+
### 2. Training TriMode (DyME + OPSD)
|
| 545 |
+
|
| 546 |
+
Use `config/config_trimode.py` (OPSD pre-enabled) or override on the base config via CLI:
|
| 547 |
+
|
| 548 |
+
```bash
|
| 549 |
+
accelerate launch main.py \
|
| 550 |
+
--config config/config_trimode.py \
|
| 551 |
+
--mode rl \
|
| 552 |
+
--opsd_enabled \
|
| 553 |
+
--opsd_mode trimode \
|
| 554 |
+
--opsd_providers text,visual_facts
|
| 555 |
+
```
|
| 556 |
+
|
| 557 |
+
Equivalent one-liner with base config + CLI only:
|
| 558 |
+
|
| 559 |
+
```bash
|
| 560 |
+
accelerate launch main.py --config config/config.py --mode rl \
|
| 561 |
+
--opsd_enabled --opsd_mode trimode --opsd_providers text,visual_facts
|
| 562 |
+
```
|
| 563 |
+
|
| 564 |
+
**CLI OPSD flags** (override `CONFIG["opsd"]`):
|
| 565 |
+
|
| 566 |
+
| Flag | Description |
|
| 567 |
+
| --- | --- |
|
| 568 |
+
| `--opsd_enabled` | Enable OPSD / TriMode extensions. |
|
| 569 |
+
| `--opsd_debug` | Verbose OPSD chain logs (`[OPSD-DEBUG]`, or env `DYME_OPSD_DEBUG=1`). |
|
| 570 |
+
| `--opsd_detail_every N` | Full weak-signal bundle every N steps (`[OPSD-DETAIL]`, default 10; `0` = off). |
|
| 571 |
+
| `--opsd_probe_on_generate` / `--no_opsd_probe_on_generate` | Per-generate `[OPSD-PROBE]` on rank 0 (trimode default on). |
|
| 572 |
+
| `--opsd_mode MODE` | Routing mode: `trimode` (legacy), `rlsd` (anti-leakage), `copsd_opd`, `dyme`, `opsd_only`, `replace_sft`, … |
|
| 573 |
+
| `--opsd_providers LIST` | Comma-separated providers, e.g. `text`, `format_only`, `visual_facts`. Empty = same-prompt OPD only. |
|
| 574 |
+
|
| 575 |
+
### 2b. RLSD / anti-leakage OPSD (recommended for ChartQA)
|
| 576 |
+
|
| 577 |
+
`trimode` routes OPSD on **correct** completions and injects gold answer into the teacher prompt (information leakage). Use **`rlsd`** instead:
|
| 578 |
+
|
| 579 |
+
- **Correct** → GRPO (on-policy self-learning, no privileged suffix)
|
| 580 |
+
- **Wrong** → same-prompt OPSD / OPD (no `[Reference Answer]` in teacher)
|
| 581 |
+
- **All-wrong group** → online SFT replace on the first generation (DyME cold-start; no separate offline SFT phase)
|
| 582 |
+
|
| 583 |
+
**Important — online SFT ≠ offline SFT:** From step 0, training is always **RL + sparse online SFT** (typically 1/8 of completions per prompt when the group is all-wrong). There is no dedicated SFT-only phase unless you run a separate offline stage (see below).
|
| 584 |
+
|
| 585 |
+
**Anti-collapse knobs (ChartQA RLSD / OPD):**
|
| 586 |
+
|
| 587 |
+
| Env / config | Purpose |
|
| 588 |
+
| --- | --- |
|
| 589 |
+
| `DYME_MAX_COMPLETION_LENGTH`, `DYME_TEMPERATURE`, `DYME_REPETITION_PENALTY` | Antidegen decoding (RLSD defaults: 128 / 0.6 / 1.35) |
|
| 590 |
+
| `DYME_FORMAT_MIN_THINKING` | Minimum chars before `Answer:` for format reward (default 8) |
|
| 591 |
+
| `DYME_OPSD_SKIP_DEGENERATE=0` | Never skip OPSD on degenerate completions |
|
| 592 |
+
| `DYME_OPSD_DEGEN_WARMUP_STEPS` | Before this step, degenerate samples still run OPSD (default 200) |
|
| 593 |
+
| `DYME_SFT_WARMUP_SLOTS` | During warmup, inject GT into first N gens per all-wrong group (default 2) |
|
| 594 |
+
|
| 595 |
+
**OPD config trap:** `config/config_opd_7b_chartqa.py` must inherit `CONFIG["training"]["dyme_args"]` from `config_rlsd_chartqa` (not stale `TRAINING_CONFIG["dyme_args"]` from antidegen). If logs show `max_new_tokens=150, temperature=0.7`, you are on the wrong decode path — stop and `git pull`.
|
| 596 |
+
|
| 597 |
+
**Stop-training heuristics:** If after ~200 steps you see `degenerate_rate≈1`, `opsd_mask_true=0`, `grad_norm=0`, and `format_mean≈1` with `accuracy=0`, the run is collapsed — restart from base 0.5B or an early checkpoint.
|
| 598 |
+
|
| 599 |
+
```bash
|
| 600 |
+
bash scripts/train_rlsd_chartqa.sh
|
| 601 |
+
# or: --config config/config_rlsd_chartqa.py --opsd_mode rlsd --opsd_providers format_only
|
| 602 |
+
```
|
| 603 |
+
|
| 604 |
+
**Two-stage cold start (optional offline SFT → RLSD/OPD):**
|
| 605 |
+
|
| 606 |
+
```bash
|
| 607 |
+
bash scripts/train_chartqa_sft.sh
|
| 608 |
+
export DYME_PRETRAINED_MODEL=./outputs/chartqa-sft/final_checkpoint
|
| 609 |
+
bash scripts/train_opd_7b_chartqa_deepspeed.sh
|
| 610 |
+
```
|
| 611 |
+
|
| 612 |
+
**200-step smoke (OPD fixes):**
|
| 613 |
+
|
| 614 |
+
```bash
|
| 615 |
+
bash scripts/train_opd_7b_smoke.sh
|
| 616 |
+
# Success: degenerate_rate<0.5, opsd_mask>8%, advantage_abs_mean>0, grad_norm>0
|
| 617 |
+
```
|
| 618 |
+
|
| 619 |
+
**Cross-model OPD (7B frozen teacher + 0.5B student):**
|
| 620 |
+
|
| 621 |
+
```bash
|
| 622 |
+
# Default: teacher on each rank's GPU (cuda:LOCAL_RANK). 2-GPU: student+teacher share the same card per rank.
|
| 623 |
+
# Optional dedicated teacher GPU: export DYME_TEACHER_DEVICE_MAP=cuda:1
|
| 624 |
+
# Vocab alignment debug at startup: DYME_VOCAB_ALIGN_FULL=1 (exhaustive) or DYME_VOCAB_ALIGN_STRIDE=500
|
| 625 |
+
bash scripts/train_opd_7b_chartqa.sh
|
| 626 |
+
```
|
| 627 |
+
|
| 628 |
+
Note: `main.py --mode rl --config config/config.py` uses **`dyme_args`** (not the unused `grpo_args` block in the same file). Pure GRPO baselines use `main_rebuttal.py`.
|
| 629 |
+
|
| 630 |
+
**Helper scripts** (under `scripts/`):
|
| 631 |
+
|
| 632 |
+
```bash
|
| 633 |
+
# TriMode on ChartQA (legacy; leakage risk on ChartQA)
|
| 634 |
+
bash scripts/train_trimode.sh
|
| 635 |
+
|
| 636 |
+
# Anti-leakage RLSD (recommended)
|
| 637 |
+
bash scripts/train_rlsd_chartqa.sh
|
| 638 |
+
|
| 639 |
+
# Ablation matrix: MODE=dyme|trimode|replace_sft|opsd_only|...
|
| 640 |
+
MODE=trimode DYME_OPSD_PROVIDERS=text,visual_facts bash scripts/train_baselines.sh
|
| 641 |
+
|
| 642 |
+
# Post-training eval (set CHECKPOINT_DIR)
|
| 643 |
+
CHECKPOINT_DIR=./outputs/trimode-chartqa/final_checkpoint bash scripts/run_eval_ablation.sh
|
| 644 |
+
```
|
| 645 |
+
|
| 646 |
+
### 3. Reproducing Baselines
|
| 647 |
+
|
| 648 |
+
To reproduce baseline settings such as standard SFT or RL training, use **`main_sft.py`** (offline ChartQA SFT) or `main.py` with `--opsd_enabled` off for pure DyME.
|
| 649 |
+
|
| 650 |
+
#### Supervised Fine-Tuning (SFT) — offline two-stage
|
| 651 |
+
|
| 652 |
+
```bash
|
| 653 |
+
bash scripts/train_chartqa_sft.sh
|
| 654 |
+
# or: accelerate launch main_sft.py --config config/config_rlsd_chartqa.py
|
| 655 |
+
```
|
| 656 |
+
|
| 657 |
+
Then point RLSD/OPD at the SFT checkpoint via `DYME_PRETRAINED_MODEL` or `MODEL_CONFIG.pretrained_model_path`.
|
| 658 |
+
|
| 659 |
+
#### Reinforcement Learning (GRPO / RL)
|
| 660 |
+
|
| 661 |
+
```bash
|
| 662 |
+
accelerate launch main.py --config config/config.py --mode rl
|
| 663 |
+
```
|
| 664 |
+
|
| 665 |
+
(`main_rebuttal.py` is referenced in the original DyME paper repo but is not shipped here; use `main_sft.py` + `main.py` instead.)
|
| 666 |
+
|
| 667 |
+
### 4. Additional Experimental Variants
|
| 668 |
+
|
| 669 |
+
For specific experimental settings such as different model scales or architecture-specific ablations, please use the corresponding scripts:
|
| 670 |
+
|
| 671 |
+
* `main_7B.py`: experiments at the 7B scale
|
| 672 |
+
* `main_llm.py`: LLM-specific variants
|
| 673 |
+
* `main_change.py`: additional ablation settings
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
## Evaluation
|
| 677 |
+
|
| 678 |
+
We support multi-process evaluation through `accelerate`. Evaluation scripts are located in the `eval/` directory and can be launched as Python modules.
|
| 679 |
+
|
| 680 |
+
### General Usage
|
| 681 |
+
|
| 682 |
+
```bash
|
| 683 |
+
accelerate launch -m eval.<eval_script_name>
|
| 684 |
+
```
|
| 685 |
+
|
| 686 |
+
### Example: ChartQA Evaluation
|
| 687 |
+
|
| 688 |
+
```bash
|
| 689 |
+
accelerate launch -m eval.eval_chartqa
|
| 690 |
+
```
|
| 691 |
+
|
| 692 |
+
### Evaluation Setup
|
| 693 |
+
|
| 694 |
+
Before running evaluation, please open the corresponding evaluation script (for example, `eval_chartqa.py`) and modify the following fields as needed:
|
| 695 |
+
|
| 696 |
+
* `model_id`: the path or identifier of the checkpoint to be evaluated
|
| 697 |
+
* prompt templates: these should match the formatting used during training
|
| 698 |
+
|
| 699 |
+
Ensuring consistency between training and evaluation prompts is important for obtaining reliable results.
|
| 700 |
+
|
| 701 |
+
## Citation
|
| 702 |
+
|
| 703 |
+
If you find this repository useful in your research, please consider citing our paper:
|
| 704 |
+
|
| 705 |
+
```bibtex
|
| 706 |
+
@inproceedings{dyme2026,
|
| 707 |
+
title={Empowering Small VLMs to Think with Dynamic Memorization and Exploration},
|
| 708 |
+
author={Jiazhen Liu, Yuchuan Deng, Long Chen},
|
| 709 |
+
booktitle={ICLR},
|
| 710 |
+
year={2026},
|
| 711 |
+
}
|
| 712 |
+
```
|
client_utils/openai_api.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
|
| 3 |
+
import httpx
|
| 4 |
+
from openai import OpenAI
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# It's good practice to define a simple configuration object or use a dictionary
|
| 9 |
+
# for passing credentials, rather than a generic object.
|
| 10 |
+
# For this example, we'll assume a config object like this.
|
| 11 |
+
class ClientConfig:
|
| 12 |
+
def __init__(self, api_key: str, base_url: str, model_id: str):
|
| 13 |
+
self.api_key = api_key
|
| 14 |
+
self.base_url = base_url
|
| 15 |
+
self.model_id = model_id
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class OpenAIClient:
|
| 19 |
+
"""
|
| 20 |
+
A client wrapper for interacting with the OpenAI ChatCompletion API.
|
| 21 |
+
It handles client initialization and API calls with retry logic.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, config: ClientConfig, max_retries: int = 3):
|
| 25 |
+
# The OpenAI client is initialized directly within the class constructor.
|
| 26 |
+
# This improves encapsulation by making the class self-contained.
|
| 27 |
+
# It takes the configuration object as a direct argument.
|
| 28 |
+
custom_http_client = httpx.Client(trust_env=False)
|
| 29 |
+
self.client = OpenAI(
|
| 30 |
+
api_key=config['api_key'], # Required: your API key
|
| 31 |
+
base_url=config['api_base'], # Optional: only needed for third-party services
|
| 32 |
+
http_client=custom_http_client,
|
| 33 |
+
)
|
| 34 |
+
self.model_id = config['model_id']
|
| 35 |
+
self.max_retries = max_retries
|
| 36 |
+
|
| 37 |
+
def get_completion(
|
| 38 |
+
self,
|
| 39 |
+
user_prompt: str,
|
| 40 |
+
system_prompt: Optional[str] = None,
|
| 41 |
+
max_tokens: int = 1024
|
| 42 |
+
) -> Optional[str]:
|
| 43 |
+
"""
|
| 44 |
+
Calls the OpenAI ChatCompletion API and returns the result.
|
| 45 |
+
Includes retry logic for handling transient errors.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
user_prompt (str): The main input/prompt from the user.
|
| 49 |
+
system_prompt (Optional[str]): The system-level instruction for the model. Defaults to None.
|
| 50 |
+
max_tokens (int): The maximum number of tokens to generate. Defaults to 1024.
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
Optional[str]: The content of the model's response, or None if the API call fails after all retries.
|
| 54 |
+
"""
|
| 55 |
+
# Build the message list based on provided prompts.
|
| 56 |
+
messages = []
|
| 57 |
+
if system_prompt:
|
| 58 |
+
messages.append({'role': 'system', 'content': system_prompt})
|
| 59 |
+
messages.append({'role': 'user', 'content': user_prompt})
|
| 60 |
+
|
| 61 |
+
# Implement a clear retry loop instead of 'while True'.
|
| 62 |
+
for attempt in range(self.max_retries):
|
| 63 |
+
try:
|
| 64 |
+
# Make the API call to the chat completions endpoint.
|
| 65 |
+
response = self.client.chat.completions.create(
|
| 66 |
+
model=self.model_id,
|
| 67 |
+
messages=messages,
|
| 68 |
+
max_tokens=max_tokens,
|
| 69 |
+
# temperature=0.2 # You can add other parameters here as needed.
|
| 70 |
+
)
|
| 71 |
+
# If the call is successful, return the message content and exit the loop.
|
| 72 |
+
return response.choices[0].message.content
|
| 73 |
+
|
| 74 |
+
except Exception as e:
|
| 75 |
+
# If an error occurs, print a helpful message.
|
| 76 |
+
print(f"API call failed on attempt {attempt + 1}/{self.max_retries}. Error: {e}")
|
| 77 |
+
|
| 78 |
+
# If this was the last attempt, break the loop to return None.
|
| 79 |
+
if attempt + 1 == self.max_retries:
|
| 80 |
+
print("All retry attempts failed.")
|
| 81 |
+
break
|
| 82 |
+
|
| 83 |
+
# Wait for a short period before trying again.
|
| 84 |
+
print("Retrying in 2 seconds...")
|
| 85 |
+
time.sleep(2)
|
| 86 |
+
|
| 87 |
+
# Return None if all retries fail.
|
| 88 |
+
return None
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# --- How to Use the Refactored Class ---
|
| 92 |
+
if __name__ == '__main__':
|
| 93 |
+
# 1. Define your configuration
|
| 94 |
+
# Replace with your actual credentials and mode
|
| 95 |
+
|
| 96 |
+
CLIENT_CONFIG = {
|
| 97 |
+
"client_type": "openai",
|
| 98 |
+
"api_key": "none",
|
| 99 |
+
"api_base": "http://127.0.0.1:23333/v1",
|
| 100 |
+
"timeout": 60,
|
| 101 |
+
"model_id": "Qwen/Qwen2.5-14B-Instruct-AWQ",
|
| 102 |
+
"init_port": 23333,
|
| 103 |
+
"num_server": 8
|
| 104 |
+
}
|
| 105 |
+
# 2. Instantiate the client
|
| 106 |
+
my_client = OpenAIClient(config=CLIENT_CONFIG)
|
| 107 |
+
|
| 108 |
+
# 3. Define your prompts
|
| 109 |
+
user_message = "What is the capital of France?"
|
| 110 |
+
system_message = "You are a helpful assistant that provides concise answers."
|
| 111 |
+
|
| 112 |
+
# 4. Get the model's response
|
| 113 |
+
response_content = my_client.get_completion(
|
| 114 |
+
user_prompt=user_message,
|
| 115 |
+
system_prompt=system_message
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
# 5. Print the result
|
| 119 |
+
if response_content:
|
| 120 |
+
print("\nModel Response:")
|
| 121 |
+
print(response_content)
|
| 122 |
+
else:
|
| 123 |
+
print("\nFailed to get a response from the model.")
|
config/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from config.config import (
|
| 2 |
+
CLIENT_CONFIG,
|
| 3 |
+
CONFIG,
|
| 4 |
+
DATASET_CONFIG,
|
| 5 |
+
DYME_OPSD_CONFIG,
|
| 6 |
+
MODEL_CONFIG,
|
| 7 |
+
RL_CONFIG,
|
| 8 |
+
TRAINING_CONFIG,
|
| 9 |
+
save_config,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"CLIENT_CONFIG",
|
| 14 |
+
"CONFIG",
|
| 15 |
+
"DATASET_CONFIG",
|
| 16 |
+
"DYME_OPSD_CONFIG",
|
| 17 |
+
"MODEL_CONFIG",
|
| 18 |
+
"RL_CONFIG",
|
| 19 |
+
"TRAINING_CONFIG",
|
| 20 |
+
"save_config",
|
| 21 |
+
]
|
config/__pycache__/config_rlsd_chartqa.cpython-310.pyc
ADDED
|
Binary file (4 kB). View file
|
|
|
config/__pycache__/config_trimode.cpython-310.pyc
ADDED
|
Binary file (2.62 kB). View file
|
|
|
config/__pycache__/config_trimode_antidegen.cpython-310.pyc
ADDED
|
Binary file (1.92 kB). View file
|
|
|
config/__pycache__/loader.cpython-312.pyc
ADDED
|
Binary file (3.36 kB). View file
|
|
|
config/config_7B.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
# ====== Model Configuration ======
|
| 5 |
+
MODEL_CONFIG = {
|
| 6 |
+
"pretrained_model_path": "Qwen/Qwen2.5-VL-7B-Instruct",
|
| 7 |
+
"use_flash_attention_2": False,
|
| 8 |
+
"torch_dtype": "bfloat16",
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
# ====== Training Configuration ======
|
| 12 |
+
TRAINING_CONFIG = {
|
| 13 |
+
"task": 'chart',
|
| 14 |
+
"num_gpus": 8,
|
| 15 |
+
"num_client": 8,
|
| 16 |
+
|
| 17 |
+
"dyme_args": {
|
| 18 |
+
"output_dir": '/path/to/dyme-qwen25_7B-chart-llava_cot',
|
| 19 |
+
"logging_steps": 1,
|
| 20 |
+
"num_generations": 8,
|
| 21 |
+
"max_completion_length": 300,
|
| 22 |
+
"per_device_train_batch_size": 1,
|
| 23 |
+
"gradient_accumulation_steps": 16,
|
| 24 |
+
"num_train_epochs": 10,
|
| 25 |
+
"learning_rate": 1e-5,
|
| 26 |
+
"bf16": True,
|
| 27 |
+
"gradient_checkpointing": False,
|
| 28 |
+
"ddp_find_unused_parameters": False,
|
| 29 |
+
"max_grad_norm": 1.0,
|
| 30 |
+
"save_steps": 100,
|
| 31 |
+
"weight_decay": 0.01,
|
| 32 |
+
"warmup_steps": 0,
|
| 33 |
+
"eval_strategy": "steps",
|
| 34 |
+
"eval_steps": 10000,
|
| 35 |
+
"beta": 0.0, # GRPO specific
|
| 36 |
+
"loss_type": 'grpo', # GRPO specific
|
| 37 |
+
"seed": 42,
|
| 38 |
+
},
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
RL_CONFIG = {
|
| 42 |
+
"answer_flag": "Answer:",
|
| 43 |
+
"end_flag": "<|im_end|>"
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
# ====== Client Configuration for Reward Calculation ======
|
| 47 |
+
CLIENT_CONFIG = {
|
| 48 |
+
"client_type": "openai",
|
| 49 |
+
"api_key": "none",
|
| 50 |
+
"api_base": "http://127.0.0.1:%s/v1",
|
| 51 |
+
"timeout": 60,
|
| 52 |
+
"model_id": "Qwen/Qwen2.5-14B-Instruct-AWQ",
|
| 53 |
+
"init_port": 23333,
|
| 54 |
+
"num_server": 8
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
# ====== Dataset Configuration ======
|
| 58 |
+
DATASET_CONFIG = {
|
| 59 |
+
|
| 60 |
+
"train_dataset": "/path/to/data/chartqa_output/llavacot/json/chartqa_train_processed.json",
|
| 61 |
+
"eval_dataset": "HuggingFaceM4/ChartQA",
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
CONFIG = {
|
| 66 |
+
"model": MODEL_CONFIG,
|
| 67 |
+
"training": TRAINING_CONFIG,
|
| 68 |
+
"rl": RL_CONFIG,
|
| 69 |
+
"client": CLIENT_CONFIG,
|
| 70 |
+
"dataset": DATASET_CONFIG,
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def save_config(config, config_path="./config.json"):
|
| 75 |
+
import json
|
| 76 |
+
with open(config_path, "w") as f:
|
| 77 |
+
json.dump(config, f, indent=4)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
if __name__ == "__main__":
|
| 81 |
+
save_config(CONFIG)
|
| 82 |
+
|
config/config_aok.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from data_utils.paths import project_path
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
MODEL_CONFIG = {
|
| 8 |
+
"pretrained_model_path": "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
|
| 9 |
+
"use_flash_attention_2": True,
|
| 10 |
+
"torch_dtype": "bfloat16",
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
TRAINING_CONFIG = {
|
| 15 |
+
"task": 'world',
|
| 16 |
+
"num_gpus": 8,
|
| 17 |
+
"num_client": 8,
|
| 18 |
+
"dyme_args": {
|
| 19 |
+
"output_dir": '/path/to/dyme-aok-online',
|
| 20 |
+
"logging_steps": 1,
|
| 21 |
+
"num_generations": 4,
|
| 22 |
+
"max_completion_length": 300,
|
| 23 |
+
"per_device_train_batch_size": 1,
|
| 24 |
+
"gradient_accumulation_steps": 16,
|
| 25 |
+
"num_train_epochs": 10,
|
| 26 |
+
"learning_rate": 1e-5,
|
| 27 |
+
"bf16": True,
|
| 28 |
+
"gradient_checkpointing": False,
|
| 29 |
+
"ddp_find_unused_parameters": False,
|
| 30 |
+
"max_grad_norm": 1.0,
|
| 31 |
+
"save_strategy": "epoch",
|
| 32 |
+
"weight_decay": 0.01,
|
| 33 |
+
"warmup_steps": 0,
|
| 34 |
+
"beta": 0.0,
|
| 35 |
+
"loss_type": 'grpo',
|
| 36 |
+
"seed": 42,
|
| 37 |
+
},
|
| 38 |
+
"sft_args": {
|
| 39 |
+
"output_dir": '/path/to/sft-aok',
|
| 40 |
+
"logging_steps": 1,
|
| 41 |
+
"per_device_train_batch_size": 2,
|
| 42 |
+
"gradient_accumulation_steps": 4,
|
| 43 |
+
"num_train_epochs": 10,
|
| 44 |
+
"learning_rate": 1e-5,
|
| 45 |
+
"bf16": True,
|
| 46 |
+
"gradient_checkpointing": False,
|
| 47 |
+
"ddp_find_unused_parameters": False,
|
| 48 |
+
"max_grad_norm": 1.0,
|
| 49 |
+
"save_strategy": "epoch",
|
| 50 |
+
"weight_decay": 0.01,
|
| 51 |
+
"warmup_steps": 0,
|
| 52 |
+
"seed": 42,
|
| 53 |
+
"remove_unused_columns": False
|
| 54 |
+
},
|
| 55 |
+
"grpo_args":{
|
| 56 |
+
"output_dir": '/path/to/grpo-aok',
|
| 57 |
+
"logging_steps": 1,
|
| 58 |
+
"num_generations": 4,
|
| 59 |
+
"max_completion_length": 576,
|
| 60 |
+
"max_prompt_length": None,
|
| 61 |
+
"per_device_train_batch_size": 2,
|
| 62 |
+
"gradient_accumulation_steps": 4,
|
| 63 |
+
"num_train_epochs": 10,
|
| 64 |
+
"learning_rate": 1e-5,
|
| 65 |
+
"bf16": True,
|
| 66 |
+
"gradient_checkpointing": False,
|
| 67 |
+
"ddp_find_unused_parameters": False,
|
| 68 |
+
"max_grad_norm": 1.0,
|
| 69 |
+
"save_strategy": "epoch",
|
| 70 |
+
"weight_decay": 0.01,
|
| 71 |
+
"warmup_steps": 0,
|
| 72 |
+
"beta": 0.04,
|
| 73 |
+
"loss_type": 'grpo',
|
| 74 |
+
"seed": 42,
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
RL_CONFIG = {
|
| 80 |
+
"answer_flag": "Answer:",
|
| 81 |
+
"end_flag": "<|im_end|>"
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
# ====== Client Configuration for Reward Calculation ======
|
| 85 |
+
CLIENT_CONFIG = {
|
| 86 |
+
"client_type": "openai",
|
| 87 |
+
"api_key": "none",
|
| 88 |
+
"api_base": "http://127.0.0.1:%s/v1",
|
| 89 |
+
"timeout": 60,
|
| 90 |
+
"model_id": "Qwen/Qwen2.5-14B-Instruct-AWQ",
|
| 91 |
+
"init_port": 23333,
|
| 92 |
+
"num_server": 8
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
# ====== Dataset Configuration ======
|
| 96 |
+
DATASET_CONFIG = {
|
| 97 |
+
"train_dataset": project_path("data/aokvqa/train.json"),
|
| 98 |
+
"eval_dataset": "HuggingFaceM4/A-OKVQA",
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
# ====== Full Configuration ======
|
| 102 |
+
CONFIG = {
|
| 103 |
+
"model": MODEL_CONFIG,
|
| 104 |
+
"training": TRAINING_CONFIG,
|
| 105 |
+
"rl": RL_CONFIG,
|
| 106 |
+
"client": CLIENT_CONFIG,
|
| 107 |
+
"dataset": DATASET_CONFIG,
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
# Save configuration to a file for reference
|
| 111 |
+
def save_config(config, config_path="./config.json"):
|
| 112 |
+
import json
|
| 113 |
+
with open(config_path, "w") as f:
|
| 114 |
+
json.dump(config, f, indent=4)
|
| 115 |
+
|
| 116 |
+
# Example usage to save config
|
| 117 |
+
if __name__ == "__main__":
|
| 118 |
+
save_config(CONFIG)
|
| 119 |
+
|
config/config_llavacot.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
# ====== Model Configuration ======
|
| 5 |
+
MODEL_CONFIG = {
|
| 6 |
+
"pretrained_model_path": '/path/to/sft-llavaov-chart-llava_cot/checkpoint-802', # two-stage grpo
|
| 7 |
+
"use_flash_attention_2": True,
|
| 8 |
+
"torch_dtype": "bfloat16",
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
# ====== Training Configuration ======
|
| 12 |
+
TRAINING_CONFIG = {
|
| 13 |
+
"task": 'chart',
|
| 14 |
+
"num_gpus": 8,
|
| 15 |
+
"num_client": 8,
|
| 16 |
+
"dyme_args": {
|
| 17 |
+
"output_dir": '/path/to/dyme-llavaov-chart-llava_cot',
|
| 18 |
+
"logging_steps": 1,
|
| 19 |
+
"num_generations": 4,
|
| 20 |
+
"max_completion_length": 300,
|
| 21 |
+
"per_device_train_batch_size": 1,
|
| 22 |
+
"gradient_accumulation_steps": 16,
|
| 23 |
+
"num_train_epochs": 10,
|
| 24 |
+
"learning_rate": 8e-5,
|
| 25 |
+
"bf16": True,
|
| 26 |
+
"gradient_checkpointing": False,
|
| 27 |
+
"ddp_find_unused_parameters": False,
|
| 28 |
+
"max_grad_norm": 1.0,
|
| 29 |
+
"save_strategy": "epoch",
|
| 30 |
+
"weight_decay": 0.01,
|
| 31 |
+
"warmup_steps": 0,
|
| 32 |
+
"beta": 0.0, # GRPO specific
|
| 33 |
+
"loss_type": 'grpo', # GRPO specific
|
| 34 |
+
"seed": 42,
|
| 35 |
+
},
|
| 36 |
+
"sft_args": {
|
| 37 |
+
"output_dir": '/path/to/sft-llavaov-chart-llava_cot',
|
| 38 |
+
"logging_steps": 1,
|
| 39 |
+
"per_device_train_batch_size": 2,
|
| 40 |
+
"gradient_accumulation_steps": 4,
|
| 41 |
+
"num_train_epochs": 10,
|
| 42 |
+
"learning_rate": 1e-5,
|
| 43 |
+
"bf16": True,
|
| 44 |
+
"gradient_checkpointing": False,
|
| 45 |
+
"ddp_find_unused_parameters": False,
|
| 46 |
+
"max_grad_norm": 1.0,
|
| 47 |
+
"max_length": 4096,
|
| 48 |
+
"save_strategy": "epoch",
|
| 49 |
+
"weight_decay": 0.01,
|
| 50 |
+
"warmup_steps": 0,
|
| 51 |
+
"seed": 42,
|
| 52 |
+
"remove_unused_columns": False
|
| 53 |
+
},
|
| 54 |
+
"grpo_args":{
|
| 55 |
+
"output_dir": '/path/to/grpo-llavaov-chart-beta',
|
| 56 |
+
"logging_steps": 1,
|
| 57 |
+
"num_generations": 4,
|
| 58 |
+
"max_completion_length": 576,
|
| 59 |
+
"max_prompt_length": None,
|
| 60 |
+
"per_device_train_batch_size": 4,
|
| 61 |
+
"gradient_accumulation_steps": 4,
|
| 62 |
+
"num_train_epochs": 10,
|
| 63 |
+
"learning_rate": 1e-5,
|
| 64 |
+
"bf16": True,
|
| 65 |
+
"gradient_checkpointing": False,
|
| 66 |
+
"ddp_find_unused_parameters": False,
|
| 67 |
+
"max_grad_norm": 1.0,
|
| 68 |
+
"save_strategy": "epoch",
|
| 69 |
+
"weight_decay": 0.01,
|
| 70 |
+
"warmup_steps": 0,
|
| 71 |
+
"beta": 0.04, # GRPO specific
|
| 72 |
+
"loss_type": 'grpo', # GRPO specific
|
| 73 |
+
"seed": 42,
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
RL_CONFIG = {
|
| 79 |
+
"answer_flag": "Answer:",
|
| 80 |
+
"end_flag": "<|im_end|>"
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
# ====== Client Configuration for Reward Calculation ======
|
| 84 |
+
CLIENT_CONFIG = {
|
| 85 |
+
"client_type": "openai",
|
| 86 |
+
"api_key": "none",
|
| 87 |
+
"api_base": "http://127.0.0.1:%s/v1",
|
| 88 |
+
"timeout": 60,
|
| 89 |
+
"model_id": "Qwen/Qwen2.5-14B-Instruct-AWQ",
|
| 90 |
+
"init_port": 23333,
|
| 91 |
+
"num_server": 8
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
# ====== Dataset Configuration ======
|
| 95 |
+
DATASET_CONFIG = {
|
| 96 |
+
"train_dataset": "/path/to/data/chartqa_output/llavacot/json/chartqa_train_processed.json",
|
| 97 |
+
"eval_dataset": "HuggingFaceM4/ChartQA",
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
# ====== Full Configuration ======
|
| 101 |
+
CONFIG = {
|
| 102 |
+
"model": MODEL_CONFIG,
|
| 103 |
+
"training": TRAINING_CONFIG,
|
| 104 |
+
"rl": RL_CONFIG,
|
| 105 |
+
"client": CLIENT_CONFIG,
|
| 106 |
+
"dataset": DATASET_CONFIG,
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
# Save configuration to a file for reference
|
| 110 |
+
def save_config(config, config_path="./config.json"):
|
| 111 |
+
import json
|
| 112 |
+
with open(config_path, "w") as f:
|
| 113 |
+
json.dump(config, f, indent=4)
|
| 114 |
+
|
| 115 |
+
# Example usage to save config
|
| 116 |
+
if __name__ == "__main__":
|
| 117 |
+
save_config(CONFIG)
|
| 118 |
+
|
config/config_low.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
# ====== Model Configuration ======
|
| 5 |
+
MODEL_CONFIG = {
|
| 6 |
+
"pretrained_model_path": "/path/to/sft-llavaov-chart-low/checkpoint-296",
|
| 7 |
+
"use_flash_attention_2": True,
|
| 8 |
+
"torch_dtype": "bfloat16",
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
# ====== Training Configuration ======
|
| 12 |
+
TRAINING_CONFIG = {
|
| 13 |
+
"task": 'chart',
|
| 14 |
+
"num_gpus": 8,
|
| 15 |
+
"num_client": 8,
|
| 16 |
+
|
| 17 |
+
"dyme_args": {
|
| 18 |
+
"output_dir": '/path/to/dyme-llavaov-chart-low',
|
| 19 |
+
"logging_steps": 1,
|
| 20 |
+
"num_generations": 4,
|
| 21 |
+
"max_completion_length": 300,
|
| 22 |
+
"per_device_train_batch_size": 1,
|
| 23 |
+
"gradient_accumulation_steps": 16,
|
| 24 |
+
"num_train_epochs": 10,
|
| 25 |
+
"learning_rate": 1e-5,
|
| 26 |
+
"bf16": True,
|
| 27 |
+
"gradient_checkpointing": False,
|
| 28 |
+
"ddp_find_unused_parameters": False,
|
| 29 |
+
"max_grad_norm": 1.0,
|
| 30 |
+
"save_strategy": "epoch",
|
| 31 |
+
"weight_decay": 0.01,
|
| 32 |
+
"warmup_steps": 0,
|
| 33 |
+
"beta": 0.0, # GRPO specific
|
| 34 |
+
"loss_type": 'grpo', # GRPO specific
|
| 35 |
+
"seed": 42,
|
| 36 |
+
},
|
| 37 |
+
"sft_args": {
|
| 38 |
+
"output_dir": '/path/to/sft-llavaov-chart-low',
|
| 39 |
+
"logging_steps": 1,
|
| 40 |
+
"per_device_train_batch_size": 2,
|
| 41 |
+
"gradient_accumulation_steps": 4,
|
| 42 |
+
"num_train_epochs": 10,
|
| 43 |
+
"learning_rate": 1e-5,
|
| 44 |
+
"bf16": True,
|
| 45 |
+
"gradient_checkpointing": False,
|
| 46 |
+
"ddp_find_unused_parameters": False,
|
| 47 |
+
"max_grad_norm": 1.0,
|
| 48 |
+
"max_length": 4096,
|
| 49 |
+
# "save_steps": 100,
|
| 50 |
+
"save_strategy": "epoch",
|
| 51 |
+
"weight_decay": 0.01,
|
| 52 |
+
"warmup_steps": 0,
|
| 53 |
+
"seed": 42,
|
| 54 |
+
"remove_unused_columns": False
|
| 55 |
+
},
|
| 56 |
+
"grpo_args":{
|
| 57 |
+
"output_dir": '/path/to/grpo-llavaov-chart-low',
|
| 58 |
+
"logging_steps": 1,
|
| 59 |
+
"num_generations": 4,
|
| 60 |
+
"max_completion_length": 576,
|
| 61 |
+
"max_prompt_length": None,
|
| 62 |
+
"per_device_train_batch_size": 2,
|
| 63 |
+
"gradient_accumulation_steps": 4,
|
| 64 |
+
"num_train_epochs": 10,
|
| 65 |
+
"learning_rate": 1e-5,
|
| 66 |
+
"bf16": True,
|
| 67 |
+
"gradient_checkpointing": False,
|
| 68 |
+
"ddp_find_unused_parameters": False,
|
| 69 |
+
"max_grad_norm": 1.0,
|
| 70 |
+
"save_strategy": "epoch",
|
| 71 |
+
"weight_decay": 0.01,
|
| 72 |
+
"warmup_steps": 0,
|
| 73 |
+
"beta": 0.0, # GRPO specific
|
| 74 |
+
"loss_type": 'grpo', # GRPO specific
|
| 75 |
+
"seed": 42,
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
RL_CONFIG = {
|
| 81 |
+
"answer_flag": "Answer:",
|
| 82 |
+
"end_flag": "<|im_end|>"
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
# ====== Client Configuration for Reward Calculation ======
|
| 86 |
+
CLIENT_CONFIG = {
|
| 87 |
+
"client_type": "openai",
|
| 88 |
+
"api_key": "none",
|
| 89 |
+
"api_base": "http://127.0.0.1:%s/v1",
|
| 90 |
+
"timeout": 60,
|
| 91 |
+
"model_id": "Qwen/Qwen2.5-14B-Instruct-AWQ",
|
| 92 |
+
"init_port": 23333,
|
| 93 |
+
"num_server": 8
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
# ====== Dataset Configuration ======
|
| 97 |
+
DATASET_CONFIG = {
|
| 98 |
+
"train_dataset": "/path/to/data/chartqa_output/json/train_low.json",
|
| 99 |
+
"eval_dataset": "HuggingFaceM4/ChartQA",
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
# ====== Full Configuration ======
|
| 103 |
+
CONFIG = {
|
| 104 |
+
"model": MODEL_CONFIG,
|
| 105 |
+
"training": TRAINING_CONFIG,
|
| 106 |
+
"rl": RL_CONFIG,
|
| 107 |
+
"client": CLIENT_CONFIG,
|
| 108 |
+
"dataset": DATASET_CONFIG,
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
# Save configuration to a file for reference
|
| 112 |
+
def save_config(config, config_path="./config.json"):
|
| 113 |
+
import json
|
| 114 |
+
with open(config_path, "w") as f:
|
| 115 |
+
json.dump(config, f, indent=4)
|
| 116 |
+
|
| 117 |
+
# Example usage to save config
|
| 118 |
+
if __name__ == "__main__":
|
| 119 |
+
save_config(CONFIG)
|
| 120 |
+
|
config/config_opd_7b_chartqa.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
COPSD-style cross-model OPD on ChartQA (Method 2).
|
| 3 |
+
|
| 4 |
+
Frozen LLaVA-OneVision 7B teacher; student default 0.5B.
|
| 5 |
+
Inherits RLSD routing + embedded SFT cold-start gates from config_rlsd_chartqa.
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
import config.config_rlsd_chartqa as rlsd
|
| 10 |
+
from data_utils.paths import OUTPUTS_DIR
|
| 11 |
+
|
| 12 |
+
MODEL_CONFIG = {
|
| 13 |
+
**rlsd.MODEL_CONFIG,
|
| 14 |
+
"teacher_model_path": os.environ.get(
|
| 15 |
+
"DYME_TEACHER_MODEL",
|
| 16 |
+
"llava-hf/llava-onevision-qwen2-7b-ov-hf",
|
| 17 |
+
),
|
| 18 |
+
"teacher_dtype": os.environ.get("DYME_TEACHER_DTYPE", "bfloat16"),
|
| 19 |
+
"teacher_device_map": os.environ.get("DYME_TEACHER_DEVICE_MAP") or None,
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
DYME_OPSD_CONFIG = {
|
| 23 |
+
**rlsd.DYME_OPSD_CONFIG,
|
| 24 |
+
"mode": os.environ.get("DYME_OPSD_MODE", "rlsd"),
|
| 25 |
+
"privileged_providers": [],
|
| 26 |
+
"loss": {
|
| 27 |
+
**rlsd.DYME_OPSD_CONFIG.get("loss", {}),
|
| 28 |
+
"opsd_weight": float(os.environ.get("DYME_OPSD_WEIGHT", "1.5")),
|
| 29 |
+
},
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
CONFIG = {
|
| 33 |
+
"model": MODEL_CONFIG,
|
| 34 |
+
"training": {
|
| 35 |
+
**rlsd.CONFIG["training"],
|
| 36 |
+
"dyme_args": {
|
| 37 |
+
**rlsd.CONFIG["training"]["dyme_args"],
|
| 38 |
+
"output_dir": os.environ.get(
|
| 39 |
+
"DYME_OUTPUT_DIR",
|
| 40 |
+
os.path.join(OUTPUTS_DIR, "opd-7b-chartqa"),
|
| 41 |
+
),
|
| 42 |
+
},
|
| 43 |
+
},
|
| 44 |
+
"rl": rlsd.CONFIG["rl"],
|
| 45 |
+
"opsd": DYME_OPSD_CONFIG,
|
| 46 |
+
"client": rlsd.CONFIG["client"],
|
| 47 |
+
"dataset": rlsd.CONFIG["dataset"],
|
| 48 |
+
}
|
config/config_rlsd_chartqa.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RLSD / anti-leakage ChartQA config (Method 1).
|
| 3 |
+
|
| 4 |
+
- mode=rlsd: correct → GRPO, wrong → same-prompt OPSD, all-wrong group → online SFT
|
| 5 |
+
- No gold answer / visual_facts in teacher privileged context
|
| 6 |
+
- Hyperparameters based on config_trimode_antidegen
|
| 7 |
+
"""
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
import config.config_trimode_antidegen as antidegen
|
| 11 |
+
from data_utils.paths import OUTPUTS_DIR
|
| 12 |
+
|
| 13 |
+
MODEL_CONFIG = dict(antidegen.MODEL_CONFIG)
|
| 14 |
+
|
| 15 |
+
TRAINING_CONFIG = dict(antidegen.TRAINING_CONFIG)
|
| 16 |
+
|
| 17 |
+
_reward_weights_raw = os.environ.get("DYME_REWARD_WEIGHTS", "0.5,1.5,1.0")
|
| 18 |
+
try:
|
| 19 |
+
_reward_weights = [float(x.strip()) for x in _reward_weights_raw.split(",") if x.strip()]
|
| 20 |
+
if len(_reward_weights) != 3:
|
| 21 |
+
raise ValueError("expected 3 weights")
|
| 22 |
+
except ValueError:
|
| 23 |
+
_reward_weights = [0.5, 1.5, 1.0]
|
| 24 |
+
|
| 25 |
+
_providers_raw = os.environ.get("DYME_OPSD_PROVIDERS", "format_only").strip()
|
| 26 |
+
_privileged_providers = [p.strip() for p in _providers_raw.split(",") if p.strip()] if _providers_raw else []
|
| 27 |
+
|
| 28 |
+
_skip_degen_env = os.environ.get("DYME_OPSD_SKIP_DEGENERATE", "").strip().lower()
|
| 29 |
+
if _skip_degen_env in ("0", "false", "no", "off"):
|
| 30 |
+
_skip_degenerate_for_opsd = False
|
| 31 |
+
elif _skip_degen_env in ("1", "true", "yes", "on"):
|
| 32 |
+
_skip_degenerate_for_opsd = True
|
| 33 |
+
else:
|
| 34 |
+
_skip_degenerate_for_opsd = True
|
| 35 |
+
|
| 36 |
+
# Embedded SFT cold-start + RLSD warmup gates (env overrides optional).
|
| 37 |
+
_RLSD_GATE_DEFAULTS = {
|
| 38 |
+
"skip_degenerate_for_opsd": _skip_degenerate_for_opsd,
|
| 39 |
+
"degen_skip_warmup_steps": 200,
|
| 40 |
+
"sft_warmup_steps": 500,
|
| 41 |
+
"sft_warmup_slots_per_group": 4,
|
| 42 |
+
# First N steps: skip generate, 100% GT injection, pure SFT NLL (no OPSD/GRPO).
|
| 43 |
+
"sft_cold_start_frac": 0.08,
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
DYME_OPSD_CONFIG = {
|
| 47 |
+
**antidegen.DYME_OPSD_CONFIG,
|
| 48 |
+
"mode": os.environ.get("DYME_OPSD_MODE", "rlsd"),
|
| 49 |
+
"text_include_gold": False,
|
| 50 |
+
"privileged_profile": os.environ.get("DYME_OPSD_PRIVILEGE_PROFILE", "text"),
|
| 51 |
+
"privileged_providers": _privileged_providers,
|
| 52 |
+
"gate": {
|
| 53 |
+
**antidegen.DYME_OPSD_CONFIG.get("gate", {}),
|
| 54 |
+
"per_completion_opsd": True,
|
| 55 |
+
"recoverable_without_privilege": True,
|
| 56 |
+
"require_format_for_opsd": os.environ.get("DYME_OPSD_REQUIRE_FORMAT", "0").strip().lower()
|
| 57 |
+
not in ("0", "false", "no", "off"),
|
| 58 |
+
"online_sft_on_all_wrong": True,
|
| 59 |
+
# ChartQA short numeric answers lack "Answer:" — do not block OPSD on format alone
|
| 60 |
+
"opsd_degenerate_require_answer_flag": False,
|
| 61 |
+
**_RLSD_GATE_DEFAULTS,
|
| 62 |
+
},
|
| 63 |
+
"loss": {
|
| 64 |
+
**antidegen.DYME_OPSD_CONFIG.get("loss", {}),
|
| 65 |
+
"acc_gate": True,
|
| 66 |
+
"opsd_weight": float(os.environ.get("DYME_OPSD_WEIGHT", "1.5")),
|
| 67 |
+
"grpo_weight": 1.0,
|
| 68 |
+
},
|
| 69 |
+
"reward_weights": _reward_weights,
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
_dyme_args = {
|
| 73 |
+
**TRAINING_CONFIG["dyme_args"],
|
| 74 |
+
"output_dir": os.environ.get(
|
| 75 |
+
"DYME_OUTPUT_DIR",
|
| 76 |
+
os.path.join(OUTPUTS_DIR, "rlsd-chartqa"),
|
| 77 |
+
),
|
| 78 |
+
# Mitigate early RL collapse (newline + bare number + immediate EOS)
|
| 79 |
+
"max_completion_length": 96,
|
| 80 |
+
"temperature": 0.5,
|
| 81 |
+
"repetition_penalty": 1.5,
|
| 82 |
+
}
|
| 83 |
+
_max_steps_raw = os.environ.get("DYME_MAX_STEPS", "").strip()
|
| 84 |
+
if _max_steps_raw:
|
| 85 |
+
_dyme_args["max_steps"] = int(_max_steps_raw)
|
| 86 |
+
|
| 87 |
+
_temp_raw = os.environ.get("DYME_TEMPERATURE", "").strip()
|
| 88 |
+
if _temp_raw:
|
| 89 |
+
_dyme_args["temperature"] = float(_temp_raw)
|
| 90 |
+
_rep_raw = os.environ.get("DYME_REPETITION_PENALTY", "").strip()
|
| 91 |
+
if _rep_raw:
|
| 92 |
+
_dyme_args["repetition_penalty"] = float(_rep_raw)
|
| 93 |
+
_max_len_raw = os.environ.get("DYME_MAX_COMPLETION_LENGTH", "").strip()
|
| 94 |
+
if _max_len_raw:
|
| 95 |
+
_dyme_args["max_completion_length"] = int(_max_len_raw)
|
| 96 |
+
|
| 97 |
+
# Keep module-level TRAINING_CONFIG in sync so imports of TRAINING_CONFIG["dyme_args"] match CONFIG.
|
| 98 |
+
TRAINING_CONFIG = {**TRAINING_CONFIG, "dyme_args": _dyme_args}
|
| 99 |
+
|
| 100 |
+
# Optional env overrides for gate defaults (see _RLSD_GATE_DEFAULTS above).
|
| 101 |
+
_degen_warmup_raw = os.environ.get("DYME_OPSD_DEGEN_WARMUP_STEPS", "").strip()
|
| 102 |
+
if _degen_warmup_raw:
|
| 103 |
+
DYME_OPSD_CONFIG["gate"]["degen_skip_warmup_steps"] = int(_degen_warmup_raw)
|
| 104 |
+
|
| 105 |
+
_sft_warmup_raw = os.environ.get("DYME_SFT_WARMUP_STEPS", "").strip()
|
| 106 |
+
if _sft_warmup_raw:
|
| 107 |
+
DYME_OPSD_CONFIG["gate"]["sft_warmup_steps"] = int(_sft_warmup_raw)
|
| 108 |
+
|
| 109 |
+
_sft_slots_raw = os.environ.get("DYME_SFT_WARMUP_SLOTS", "").strip()
|
| 110 |
+
if _sft_slots_raw:
|
| 111 |
+
DYME_OPSD_CONFIG["gate"]["sft_warmup_slots_per_group"] = int(_sft_slots_raw)
|
| 112 |
+
|
| 113 |
+
_cold_start_steps_raw = os.environ.get("DYME_SFT_COLD_START_STEPS", "").strip()
|
| 114 |
+
if _cold_start_steps_raw:
|
| 115 |
+
DYME_OPSD_CONFIG["gate"]["sft_cold_start_steps"] = int(_cold_start_steps_raw)
|
| 116 |
+
DYME_OPSD_CONFIG["gate"].pop("sft_cold_start_frac", None)
|
| 117 |
+
else:
|
| 118 |
+
_cold_start_frac_raw = os.environ.get("DYME_SFT_COLD_START_FRAC", "").strip()
|
| 119 |
+
if _cold_start_frac_raw:
|
| 120 |
+
DYME_OPSD_CONFIG["gate"]["sft_cold_start_frac"] = float(_cold_start_frac_raw)
|
| 121 |
+
|
| 122 |
+
CONFIG = {
|
| 123 |
+
"model": MODEL_CONFIG,
|
| 124 |
+
"training": {
|
| 125 |
+
**TRAINING_CONFIG,
|
| 126 |
+
"dyme_args": _dyme_args,
|
| 127 |
+
"sft_args": {
|
| 128 |
+
"output_dir": os.environ.get(
|
| 129 |
+
"DYME_SFT_OUTPUT_DIR",
|
| 130 |
+
os.path.join(OUTPUTS_DIR, "chartqa-sft"),
|
| 131 |
+
),
|
| 132 |
+
"logging_steps": 10,
|
| 133 |
+
"per_device_train_batch_size": 2,
|
| 134 |
+
"gradient_accumulation_steps": 4,
|
| 135 |
+
"num_train_epochs": int(os.environ.get("DYME_SFT_EPOCHS", "2")),
|
| 136 |
+
"learning_rate": 1e-5,
|
| 137 |
+
"bf16": True,
|
| 138 |
+
"gradient_checkpointing": True,
|
| 139 |
+
"ddp_find_unused_parameters": False,
|
| 140 |
+
"max_grad_norm": 1.0,
|
| 141 |
+
"save_strategy": "epoch",
|
| 142 |
+
"weight_decay": 0.01,
|
| 143 |
+
"warmup_steps": 0,
|
| 144 |
+
"seed": 42,
|
| 145 |
+
"remove_unused_columns": False,
|
| 146 |
+
},
|
| 147 |
+
},
|
| 148 |
+
"rl": antidegen.CONFIG["rl"],
|
| 149 |
+
"opsd": DYME_OPSD_CONFIG,
|
| 150 |
+
"client": antidegen.CONFIG["client"],
|
| 151 |
+
"dataset": antidegen.CONFIG["dataset"],
|
| 152 |
+
}
|
config/config_trimode.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from config import CLIENT_CONFIG, DATASET_CONFIG, DYME_OPSD_CONFIG, MODEL_CONFIG, RL_CONFIG, TRAINING_CONFIG
|
| 4 |
+
from data_utils.paths import OUTPUTS_DIR
|
| 5 |
+
|
| 6 |
+
MODEL_CONFIG = dict(MODEL_CONFIG)
|
| 7 |
+
|
| 8 |
+
TRAINING_CONFIG = {
|
| 9 |
+
**TRAINING_CONFIG,
|
| 10 |
+
"dyme_args": {
|
| 11 |
+
**TRAINING_CONFIG["dyme_args"],
|
| 12 |
+
"output_dir": os.environ.get("DYME_OUTPUT_DIR", os.path.join(OUTPUTS_DIR, "dyme-trimode")),
|
| 13 |
+
},
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
_detail_every_raw = os.environ.get("DYME_OPSD_DETAIL_EVERY", "10")
|
| 17 |
+
try:
|
| 18 |
+
_detail_every = max(0, int(_detail_every_raw))
|
| 19 |
+
except ValueError:
|
| 20 |
+
_detail_every = 10
|
| 21 |
+
|
| 22 |
+
_probe_raw = os.environ.get("DYME_OPSD_PROBE_ON_GENERATE", "1").strip().lower()
|
| 23 |
+
_probe_on_generate = _probe_raw not in ("0", "false", "no", "off")
|
| 24 |
+
|
| 25 |
+
_first_logits_raw = os.environ.get("DYME_OPSD_PROBE_FIRST_TOKEN_LOGITS", "1").strip().lower()
|
| 26 |
+
_probe_first_token_logits = _first_logits_raw not in ("0", "false", "no", "off")
|
| 27 |
+
|
| 28 |
+
_tail_raw = os.environ.get("DYME_OPSD_PROBE_PROMPT_TAIL_TOKENS", "16").strip()
|
| 29 |
+
try:
|
| 30 |
+
_probe_prompt_tail_tokens = max(1, int(_tail_raw))
|
| 31 |
+
except ValueError:
|
| 32 |
+
_probe_prompt_tail_tokens = 16
|
| 33 |
+
|
| 34 |
+
_model_ctx_raw = os.environ.get("DYME_OPSD_PROBE_LOG_MODEL_CONTEXT", "1").strip().lower()
|
| 35 |
+
_probe_log_model_context = _model_ctx_raw not in ("0", "false", "no", "off")
|
| 36 |
+
|
| 37 |
+
_health_raw = os.environ.get("DYME_OPSD_HEALTH_MONITOR", "1").strip().lower()
|
| 38 |
+
_health_monitor_enabled = _health_raw not in ("0", "false", "no", "off")
|
| 39 |
+
|
| 40 |
+
_require_format_raw = os.environ.get("DYME_OPSD_REQUIRE_FORMAT", "1").strip().lower()
|
| 41 |
+
_require_format_for_opsd = _require_format_raw not in ("0", "false", "no", "off")
|
| 42 |
+
|
| 43 |
+
DYME_OPSD_CONFIG = {
|
| 44 |
+
**DYME_OPSD_CONFIG,
|
| 45 |
+
"enabled": True,
|
| 46 |
+
"mode": os.environ.get("DYME_OPSD_MODE", "trimode"),
|
| 47 |
+
"privileged_profile": os.environ.get("DYME_OPSD_PRIVILEGE_PROFILE", "hybrid"),
|
| 48 |
+
"privileged_providers": os.environ.get("DYME_OPSD_PROVIDERS", "text,visual_facts").split(","),
|
| 49 |
+
"privileged_image": {
|
| 50 |
+
**DYME_OPSD_CONFIG.get("privileged_image", {}),
|
| 51 |
+
"mode": os.environ.get("DYME_OPSD_PRIVILEGE_IMAGE_MODE", "single"),
|
| 52 |
+
"crop_strategy": os.environ.get("DYME_OPSD_CROP_STRATEGY", "bbox_then_center"),
|
| 53 |
+
"bbox_coord": "normalized",
|
| 54 |
+
"margin_ratio": float(os.environ.get("DYME_OPSD_CROP_MARGIN", "0.25")),
|
| 55 |
+
},
|
| 56 |
+
"privileged_debug": {
|
| 57 |
+
**DYME_OPSD_CONFIG.get("privileged_debug", {}),
|
| 58 |
+
"save_images": os.environ.get("DYME_OPSD_SAVE_PRIVILEGED_IMAGES", "1").strip().lower()
|
| 59 |
+
not in ("0", "false", "no", "off"),
|
| 60 |
+
"image_subdir": os.environ.get("DYME_OPSD_PRIVILEGED_IMAGE_DIR", "logs/images"),
|
| 61 |
+
"max_samples_per_detail": int(os.environ.get("DYME_OPSD_PRIVILEGED_IMAGE_MAX", "2")),
|
| 62 |
+
},
|
| 63 |
+
"gate": {
|
| 64 |
+
**DYME_OPSD_CONFIG.get("gate", {}),
|
| 65 |
+
"require_format_for_opsd": _require_format_for_opsd,
|
| 66 |
+
},
|
| 67 |
+
"debug": {
|
| 68 |
+
**DYME_OPSD_CONFIG.get("debug", {}),
|
| 69 |
+
"detail_every": _detail_every,
|
| 70 |
+
"probe_on_generate": _probe_on_generate,
|
| 71 |
+
"probe_first_token_logits": _probe_first_token_logits,
|
| 72 |
+
"probe_prompt_tail_tokens": _probe_prompt_tail_tokens,
|
| 73 |
+
"probe_log_model_context": _probe_log_model_context,
|
| 74 |
+
"health_monitor": {
|
| 75 |
+
**DYME_OPSD_CONFIG.get("debug", {}).get("health_monitor", {}),
|
| 76 |
+
"enabled": _health_monitor_enabled,
|
| 77 |
+
},
|
| 78 |
+
},
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
CONFIG = {
|
| 82 |
+
"model": MODEL_CONFIG,
|
| 83 |
+
"training": TRAINING_CONFIG,
|
| 84 |
+
"rl": RL_CONFIG,
|
| 85 |
+
"opsd": DYME_OPSD_CONFIG,
|
| 86 |
+
"client": CLIENT_CONFIG,
|
| 87 |
+
"dataset": DATASET_CONFIG,
|
| 88 |
+
}
|
default_config_8gpu.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
distributed_type: MULTI_GPU
|
| 4 |
+
downcast_bf16: 'no'
|
| 5 |
+
enable_cpu_affinity: false
|
| 6 |
+
machine_rank: 0
|
| 7 |
+
main_training_function: main
|
| 8 |
+
mixed_precision: bf16
|
| 9 |
+
num_machines: 1
|
| 10 |
+
num_processes: 8
|
| 11 |
+
rdzv_backend: static
|
| 12 |
+
same_network: true
|
| 13 |
+
tpu_env: []
|
| 14 |
+
tpu_use_cluster: false
|
| 15 |
+
tpu_use_sudo: false
|
| 16 |
+
use_cpu: false
|
default_config_8gpu_deepspeed.yaml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
deepspeed_config:
|
| 4 |
+
gradient_accumulation_steps: 16
|
| 5 |
+
zero3_init_flag: false
|
| 6 |
+
zero_stage: 0
|
| 7 |
+
distributed_type: DEEPSPEED
|
| 8 |
+
downcast_bf16: 'no'
|
| 9 |
+
enable_cpu_affinity: false
|
| 10 |
+
machine_rank: 0
|
| 11 |
+
main_training_function: main
|
| 12 |
+
mixed_precision: bf16
|
| 13 |
+
num_machines: 1
|
| 14 |
+
# Optional ZeRO-0 for 8-GPU nodes. Requires: pip install deepspeed
|
| 15 |
+
num_processes: 8
|
| 16 |
+
rdzv_backend: static
|
| 17 |
+
same_network: true
|
| 18 |
+
tpu_env: []
|
| 19 |
+
tpu_use_cluster: false
|
| 20 |
+
tpu_use_sudo: false
|
| 21 |
+
use_cpu: false
|
default_config_zero2_8gpu.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
deepspeed_config:
|
| 4 |
+
deepspeed_config_file: configs/deepspeed/zero2_bf16.json
|
| 5 |
+
zero3_init_flag: false
|
| 6 |
+
distributed_type: DEEPSPEED
|
| 7 |
+
downcast_bf16: 'no'
|
| 8 |
+
enable_cpu_affinity: false
|
| 9 |
+
machine_rank: 0
|
| 10 |
+
main_training_function: main
|
| 11 |
+
num_machines: 1
|
| 12 |
+
num_processes: 8
|
| 13 |
+
rdzv_backend: static
|
| 14 |
+
same_network: true
|
| 15 |
+
tpu_env: []
|
| 16 |
+
tpu_use_cluster: false
|
| 17 |
+
tpu_use_sudo: false
|
| 18 |
+
use_cpu: false
|
eval/eval_chartqa.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from accelerate import Accelerator
|
| 7 |
+
# Ensure this path is correct and the utility is available.
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
from torch.distributed import all_gather_object
|
| 10 |
+
from transformers import AutoProcessor, AutoConfig, AutoTokenizer, LlavaOnevisionForConditionalGeneration
|
| 11 |
+
from trl.models import unwrap_model_for_generation
|
| 12 |
+
|
| 13 |
+
from data_utils.chart.evaluator import eval_one_chart
|
| 14 |
+
from data_utils.rl_prompt import PROMPT_TEMPLATE
|
| 15 |
+
from reward_utils.compute_rewards import split_initial_context
|
| 16 |
+
|
| 17 |
+
accelerator = Accelerator()
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
DEVICE = accelerator.device
|
| 22 |
+
|
| 23 |
+
# Model and Processor Configuration
|
| 24 |
+
model_args = {} # Use {"torch_dtype": torch.bfloat16} if desired and supported
|
| 25 |
+
|
| 26 |
+
_eval_parser = argparse.ArgumentParser(add_help=False)
|
| 27 |
+
_eval_parser.add_argument("--model_path", default=None)
|
| 28 |
+
_eval_args, _ = _eval_parser.parse_known_args()
|
| 29 |
+
model_id = (
|
| 30 |
+
_eval_args.model_path
|
| 31 |
+
or os.environ.get("CHECKPOINT_DIR")
|
| 32 |
+
or "/path/to/dyme-k-8/final_checkpoint"
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
|
| 36 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, config=config, trust_remote_code=True)
|
| 37 |
+
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
|
| 38 |
+
model_id,
|
| 39 |
+
torch_dtype=torch.bfloat16,
|
| 40 |
+
low_cpu_mem_usage=True,
|
| 41 |
+
).to(DEVICE)
|
| 42 |
+
|
| 43 |
+
model.eval()
|
| 44 |
+
# Make sure model and processor are loaded before being potentially used in generate_inner if it were called
|
| 45 |
+
# model = Idefics3ForConditionalGeneration.from_pretrained(model_id, **model_args).to(DEVICE)
|
| 46 |
+
|
| 47 |
+
processor = AutoProcessor.from_pretrained(model_id)
|
| 48 |
+
|
| 49 |
+
# Configure image processor size
|
| 50 |
+
# This can consume significant VRAM. Ensure it's intended.
|
| 51 |
+
if hasattr(processor.image_processor, 'size') and isinstance(processor.image_processor.size, dict):
|
| 52 |
+
# if 'longest_edge' in processor.image_processor.size:
|
| 53 |
+
# print('Setting image processor longest_edge to 2048')
|
| 54 |
+
# processor.image_processor.size['longest_edge'] = 512 * 4
|
| 55 |
+
processor.tokenizer.padding_side = 'left'
|
| 56 |
+
else:
|
| 57 |
+
print(
|
| 58 |
+
f"Warning: Could not directly set 'longest_edge' via dict. Current image processor size config: {processor.image_processor.size}"
|
| 59 |
+
)
|
| 60 |
+
# Attempt an alternative if applicable, e.g.
|
| 61 |
+
# processor.image_processor.size = {"longest_edge": 512 * 4} # if size itself can be replaced
|
| 62 |
+
# Or this might indicate that `size` is a single value or a different structure.
|
| 63 |
+
|
| 64 |
+
def run_kh_batch(batch_data_list): # Renamed from run_kh, takes a batch
|
| 65 |
+
batch_images = []
|
| 66 |
+
batch_formatted_prompts_for_chat_template = []
|
| 67 |
+
|
| 68 |
+
for item in batch_data_list:
|
| 69 |
+
image_path = item['image_path']
|
| 70 |
+
# 'item_model_input_text' already contains chart instructions + raw_question
|
| 71 |
+
item_model_input_text = item['model_input_text'].strip()
|
| 72 |
+
|
| 73 |
+
# question_with_tags = prompt + item_model_input_text
|
| 74 |
+
# question_with_tags = f"""{item_model_input_text} Think step by step and then answer the question."""
|
| 75 |
+
question_with_tags = PROMPT_TEMPLATE.format(question=item_model_input_text)
|
| 76 |
+
if isinstance(image_path, str):
|
| 77 |
+
image = Image.open(image_path).convert("RGB")
|
| 78 |
+
else:
|
| 79 |
+
image = image_path.convert("RGB") # Assuming image_path is already a PIL Image object
|
| 80 |
+
batch_images.append(image)
|
| 81 |
+
|
| 82 |
+
messages = [
|
| 83 |
+
{
|
| 84 |
+
"role": "user",
|
| 85 |
+
"content": [
|
| 86 |
+
{"type": "image"},
|
| 87 |
+
{"type": "text", "text": question_with_tags},
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
]
|
| 91 |
+
try:
|
| 92 |
+
templated_prompt_str = processor.apply_chat_template(messages, add_generation_prompt=True)
|
| 93 |
+
templated_prompt_str = templated_prompt_str.strip()
|
| 94 |
+
except Exception:
|
| 95 |
+
templated_prompt_str = f"USER: <image>\n{question_with_tags}\nASSISTANT:"
|
| 96 |
+
batch_formatted_prompts_for_chat_template.append(templated_prompt_str)
|
| 97 |
+
|
| 98 |
+
inputs = processor(
|
| 99 |
+
text=batch_formatted_prompts_for_chat_template,
|
| 100 |
+
images=batch_images,
|
| 101 |
+
return_tensors="pt",
|
| 102 |
+
padding=True,
|
| 103 |
+
truncation=True
|
| 104 |
+
)
|
| 105 |
+
# inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
|
| 106 |
+
inputs = {
|
| 107 |
+
k: v.to(DEVICE).to(torch.bfloat16) if v.is_floating_point() else v.to(DEVICE)
|
| 108 |
+
for k, v in inputs.items()
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
with unwrap_model_for_generation(model, accelerator) as unwrapped_model_instance:
|
| 112 |
+
generated_ids = unwrapped_model_instance.generate(
|
| 113 |
+
**inputs,
|
| 114 |
+
max_new_tokens=1024,
|
| 115 |
+
do_sample=False,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
input_ids_length = inputs['input_ids'].shape[1]
|
| 119 |
+
newly_generated_ids = generated_ids[:, input_ids_length:]
|
| 120 |
+
|
| 121 |
+
generated_texts = processor.batch_decode(
|
| 122 |
+
newly_generated_ids,
|
| 123 |
+
skip_special_tokens=True, # Special tokens like <eos> are removed. <image> might be too.
|
| 124 |
+
)
|
| 125 |
+
return [text.strip('.').strip() for text in generated_texts]
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# --- Main Evaluation Logic ---
|
| 129 |
+
task = 'chart'
|
| 130 |
+
# dt_record_local is initialized inside the if task == 'chart' block
|
| 131 |
+
|
| 132 |
+
if task == 'chart':
|
| 133 |
+
dt_record_local = {} # Store results for the current process
|
| 134 |
+
if accelerator.is_main_process:
|
| 135 |
+
print("Loading ChartQA dataset...")
|
| 136 |
+
try:
|
| 137 |
+
full_dataset = load_dataset("HuggingFaceM4/ChartQA", trust_remote_code=True)['test']
|
| 138 |
+
except Exception as e:
|
| 139 |
+
if accelerator.is_main_process:
|
| 140 |
+
print(f"Failed to load dataset directly. Error: {e}")
|
| 141 |
+
print("Attempting to load with specific revision if applicable, or check path/connection.")
|
| 142 |
+
# For example, you can try a specific revision (if known) or ensure path and network connection are correct
|
| 143 |
+
# full_dataset = load_dataset("HuggingFaceM4/ChartQA", revision="main", trust_remote_code=True)['test']
|
| 144 |
+
raise # Re-raise the exception since we cannot proceed without the dataset
|
| 145 |
+
|
| 146 |
+
# full_dataset = full_dataset.select(range(80)) # Uncomment for quick testing
|
| 147 |
+
|
| 148 |
+
eval_datasets_all_prepared = []
|
| 149 |
+
# chart_instructions_prefix = (
|
| 150 |
+
# "For the question below, follow the following instructions:\n"
|
| 151 |
+
# # ... (your detailed instructions) ...
|
| 152 |
+
# + "-Try to include the full label from the graph when asked about an entity.\n"
|
| 153 |
+
# + "Question: "
|
| 154 |
+
# )
|
| 155 |
+
|
| 156 |
+
for d_item in tqdm(full_dataset, desc="Preparing dataset", disable=not accelerator.is_main_process):
|
| 157 |
+
image_path = d_item['image']
|
| 158 |
+
raw_question = d_item['query']
|
| 159 |
+
answer_list = d_item.get('label') # Use .get() in case 'label' field does not exist
|
| 160 |
+
if not answer_list: # If 'label' is missing or an empty list
|
| 161 |
+
if accelerator.is_main_process:
|
| 162 |
+
tqdm.write(
|
| 163 |
+
f"Warning: Item missing 'label' or 'label' is empty. Query: {raw_question[:50]}..."
|
| 164 |
+
)
|
| 165 |
+
# Decide how to handle this: skip this sample or use a default answer
|
| 166 |
+
continue # Skip this sample
|
| 167 |
+
answer = answer_list[0]
|
| 168 |
+
|
| 169 |
+
model_input_text_for_template = raw_question
|
| 170 |
+
eval_datasets_all_prepared.append({
|
| 171 |
+
'image_path': image_path,
|
| 172 |
+
'model_input_text': model_input_text_for_template,
|
| 173 |
+
'answer': answer,
|
| 174 |
+
'original_question': raw_question
|
| 175 |
+
})
|
| 176 |
+
|
| 177 |
+
num_processes = accelerator.num_processes
|
| 178 |
+
process_index = accelerator.process_index
|
| 179 |
+
total_items = len(eval_datasets_all_prepared)
|
| 180 |
+
|
| 181 |
+
if total_items == 0:
|
| 182 |
+
if accelerator.is_main_process:
|
| 183 |
+
print("No data prepared for evaluation after filtering. Exiting chart evaluation.")
|
| 184 |
+
else:
|
| 185 |
+
items_per_proc = total_items // num_processes
|
| 186 |
+
extra_items = total_items % num_processes
|
| 187 |
+
local_start_index = process_index * items_per_proc + min(process_index, extra_items)
|
| 188 |
+
num_local_items = items_per_proc + (1 if process_index < extra_items else 0)
|
| 189 |
+
local_end_index = local_start_index + num_local_items
|
| 190 |
+
eval_datasets_local = eval_datasets_all_prepared[local_start_index:local_end_index]
|
| 191 |
+
|
| 192 |
+
BATCH_SIZE = 32 # Adjust according to your VRAM
|
| 193 |
+
REPORT_INTERVAL_BATCHES = 1 # Report once every N local batches (main process prints global stats)
|
| 194 |
+
|
| 195 |
+
# if accelerator.is_main_process:
|
| 196 |
+
# print(f"Total items for evaluation: {total_items}")
|
| 197 |
+
# print(f"Process {process_index} handling {len(eval_datasets_local)} items.")
|
| 198 |
+
# print(f"Batch size per process: {BATCH_SIZE}, Reporting interval: {REPORT_INTERVAL_BATCHES} local batches.")
|
| 199 |
+
|
| 200 |
+
pbar = None
|
| 201 |
+
if accelerator.is_main_process and len(eval_datasets_local) > 0: # Create pbar only if there is data
|
| 202 |
+
pbar = tqdm(total=len(eval_datasets_local), desc=f"Eval Proc {process_index}", dynamic_ncols=True)
|
| 203 |
+
|
| 204 |
+
dt_record_local['res'] = []
|
| 205 |
+
num_local_batches = (len(eval_datasets_local) + BATCH_SIZE - 1) // BATCH_SIZE
|
| 206 |
+
|
| 207 |
+
for batch_idx_local in range(num_local_batches):
|
| 208 |
+
start_idx = batch_idx_local * BATCH_SIZE
|
| 209 |
+
end_idx = min((batch_idx_local + 1) * BATCH_SIZE, len(eval_datasets_local))
|
| 210 |
+
current_batch_list = eval_datasets_local[start_idx:end_idx]
|
| 211 |
+
|
| 212 |
+
if not current_batch_list:
|
| 213 |
+
continue
|
| 214 |
+
|
| 215 |
+
batch_predictions_texts = run_kh_batch(current_batch_list)
|
| 216 |
+
|
| 217 |
+
for item_idx_in_batch, full_pred_text in enumerate(batch_predictions_texts):
|
| 218 |
+
original_item = current_batch_list[item_idx_in_batch]
|
| 219 |
+
ground_truth_answer = original_item['answer']
|
| 220 |
+
|
| 221 |
+
_, parsed_pred_answer = split_initial_context(full_pred_text)
|
| 222 |
+
if not parsed_pred_answer.strip():
|
| 223 |
+
parsed_pred_answer = full_pred_text # Fallback to full prediction if parsed answer is empty
|
| 224 |
+
|
| 225 |
+
score = eval_one_chart(parsed_pred_answer, ground_truth_answer) # nlp object is global
|
| 226 |
+
dt_record_local['res'].append(score)
|
| 227 |
+
|
| 228 |
+
# (Optional) Main process prints a few prediction details
|
| 229 |
+
if accelerator.is_main_process:
|
| 230 |
+
print(full_pred_text, "######", ground_truth_answer, "######", score)
|
| 231 |
+
|
| 232 |
+
if pbar:
|
| 233 |
+
pbar.update(len(current_batch_list))
|
| 234 |
+
|
| 235 |
+
# --- Intermediate reporting logic ---
|
| 236 |
+
is_last_local_batch = (batch_idx_local == num_local_batches - 1)
|
| 237 |
+
# Every REPORT_INTERVAL_BATCHES local batches, or on the last local batch of this process,
|
| 238 |
+
# perform synchronization and reporting
|
| 239 |
+
should_sync_and_report = ((batch_idx_local + 1) % REPORT_INTERVAL_BATCHES == 0) or is_last_local_batch
|
| 240 |
+
|
| 241 |
+
# Make sure that even if REPORT_INTERVAL_BATCHES is 1, we do not report when there is no data
|
| 242 |
+
# (e.g., len(eval_datasets_local) == 0)
|
| 243 |
+
if len(eval_datasets_local) == 0: # If the current process has no data, skip reporting logic
|
| 244 |
+
should_sync_and_report = False
|
| 245 |
+
# If num_local_batches > 0, this check ensures we report only when there is data
|
| 246 |
+
|
| 247 |
+
if num_local_batches == 0 and is_last_local_batch: # Special case: process has no data but must join final sync
|
| 248 |
+
should_sync_and_report = True
|
| 249 |
+
|
| 250 |
+
if should_sync_and_report:
|
| 251 |
+
accelerator.wait_for_everyone() # Wait for all processes to reach the sync point
|
| 252 |
+
|
| 253 |
+
gathered_all_processes_data = [None] * num_processes
|
| 254 |
+
# Each process sends its *current accumulated* dt_record_local
|
| 255 |
+
# If a process has no data, dt_record_local['res'] is an empty list, which is fine
|
| 256 |
+
all_gather_object(gathered_all_processes_data, dt_record_local)
|
| 257 |
+
|
| 258 |
+
if accelerator.is_main_process:
|
| 259 |
+
current_global_scores_list = []
|
| 260 |
+
for process_data_dict in gathered_all_processes_data:
|
| 261 |
+
if process_data_dict and 'res' in process_data_dict:
|
| 262 |
+
current_global_scores_list.extend(process_data_dict['res'])
|
| 263 |
+
|
| 264 |
+
total_samples_processed_globally = len(current_global_scores_list)
|
| 265 |
+
|
| 266 |
+
report_title = "--- Intermediate Report ---"
|
| 267 |
+
# Check whether this is the final reporting point where all processes have finished
|
| 268 |
+
# A simple heuristic: if this is the last local batch on the main process
|
| 269 |
+
# and the total collected samples equal the total number of items
|
| 270 |
+
if is_last_local_batch and total_samples_processed_globally == total_items:
|
| 271 |
+
report_title = "--- Final Report ---"
|
| 272 |
+
elif is_last_local_batch: # Last batch on main process but perhaps not all samples are done yet
|
| 273 |
+
report_title = (
|
| 274 |
+
f"--- Report (Main Proc Last Batch, {batch_idx_local + 1}/{num_local_batches}) ---"
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
tqdm.write(f"\n{report_title}") # Use tqdm.write to avoid clashing with the progress bar
|
| 278 |
+
if current_global_scores_list:
|
| 279 |
+
mean_acc_global = np.array(current_global_scores_list).mean()
|
| 280 |
+
if accelerator.is_main_process:
|
| 281 |
+
print(f"Global samples processed: {total_samples_processed_globally} / {total_items}")
|
| 282 |
+
print(f"Current Global Mean Accuracy: {mean_acc_global:.4f}")
|
| 283 |
+
if pbar:
|
| 284 |
+
pbar.set_description(
|
| 285 |
+
f"Global Acc: {mean_acc_global:.4f} ({total_samples_processed_globally}/{total_items})"
|
| 286 |
+
)
|
| 287 |
+
else:
|
| 288 |
+
if accelerator.is_main_process:
|
| 289 |
+
print(
|
| 290 |
+
f"No scores to report globally yet (Total processed: {total_samples_processed_globally})."
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
accelerator.wait_for_everyone() # Sync again after reporting in case some processes move ahead faster
|
| 294 |
+
|
| 295 |
+
if pbar:
|
| 296 |
+
pbar.close()
|
| 297 |
+
|
| 298 |
+
# Final metrics have already been printed in the last report
|
| 299 |
+
# (when is_last_local_batch is True)
|
| 300 |
+
if accelerator.is_main_process and len(eval_datasets_local) == 0 and total_items > 0:
|
| 301 |
+
print(
|
| 302 |
+
"Main process had no data, but other processes might have. "
|
| 303 |
+
"Final global metrics are printed by the last reporting sync."
|
| 304 |
+
)
|
| 305 |
+
elif accelerator.is_main_process and total_items == 0:
|
| 306 |
+
print("No data was prepared for evaluation. Nothing to report.")
|
| 307 |
+
|
| 308 |
+
else:
|
| 309 |
+
if accelerator.is_main_process:
|
| 310 |
+
print(f"Task '{task}' is not configured for batched evaluation in this script.")
|
figs/chartqa.png
ADDED
|
kill_all.sh
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -e
|
| 3 |
+
|
| 4 |
+
readonly WORKER_HOSTS=(
|
| 5 |
+
"xx.xx.xx.xx"
|
| 6 |
+
)
|
| 7 |
+
|
| 8 |
+
readonly REMOTE_USER="root"
|
| 9 |
+
|
| 10 |
+
readonly TRAIN_SCRIPT="main"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
echo "--- Killing local processes matching '${TRAIN_SCRIPT}' first ---"
|
| 16 |
+
|
| 17 |
+
pkill -9 -f "${TRAIN_SCRIPT}" || true
|
| 18 |
+
pkill -f python
|
| 19 |
+
echo "Local check complete."
|
| 20 |
+
echo
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
echo "🛑 Sending targeted kill signal to processes matching '${TRAIN_SCRIPT}' on all remote hosts in parallel..."
|
| 24 |
+
|
| 25 |
+
for HOST in "${WORKER_HOSTS[@]}"; do
|
| 26 |
+
|
| 27 |
+
(
|
| 28 |
+
echo "--- Processing host: ${HOST} ---"
|
| 29 |
+
|
| 30 |
+
ssh -n "${REMOTE_USER}@${HOST}" "
|
| 31 |
+
set -e # 远程脚本也应该在出错时停止
|
| 32 |
+
pkill -f python
|
| 33 |
+
# 精确查找由 python 启动的、且包含 TRAIN_SCRIPT 名称的进程
|
| 34 |
+
# 这是为了避免误杀其他同名进程(比如一个名为 'main_rebuttal' 的shell脚本)
|
| 35 |
+
PIDS=\$(pgrep -f \"python.*${TRAIN_SCRIPT}\")
|
| 36 |
+
|
| 37 |
+
if [ -z \"\$PIDS\" ]; then
|
| 38 |
+
echo '[INFO] ✅ No matching processes found on this host.'
|
| 39 |
+
else
|
| 40 |
+
echo '[WARN] 🔥 Found processes to kill:'
|
| 41 |
+
# 在杀死前显示详细信息,增加安全性
|
| 42 |
+
ps -fp \$PIDS
|
| 43 |
+
echo '[KILL] Killing PIDs: '\$PIDS'...'
|
| 44 |
+
kill -9 \$PIDS
|
| 45 |
+
echo '[OK] ✅ Processes killed successfully.'
|
| 46 |
+
fi
|
| 47 |
+
"
|
| 48 |
+
echo "--- Finished host: ${HOST} ---"
|
| 49 |
+
echo
|
| 50 |
+
) &
|
| 51 |
+
done
|
| 52 |
+
|
| 53 |
+
wait
|
| 54 |
+
|
| 55 |
+
echo "🎉 All hosts have been processed."
|
main.py
ADDED
|
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# train_grpo.py
|
| 2 |
+
"""
|
| 3 |
+
Main script for training a Llava-based model using the custom MyGRPOTrainer.
|
| 4 |
+
|
| 5 |
+
This script handles:
|
| 6 |
+
1. Configuration loading.
|
| 7 |
+
2. Initialization of Weights & Biases (wandb) and Hugging Face Accelerate.
|
| 8 |
+
3. Loading the model and processor.
|
| 9 |
+
4. Preparing the training and evaluation datasets.
|
| 10 |
+
5. Setting up and running the GPRO trainer.
|
| 11 |
+
"""
|
| 12 |
+
import argparse
|
| 13 |
+
import os
|
| 14 |
+
from functools import partial
|
| 15 |
+
from typing import Dict, Any
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import wandb
|
| 19 |
+
from accelerate import Accelerator
|
| 20 |
+
from datasets import Dataset, load_dataset
|
| 21 |
+
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
|
| 22 |
+
from trl import GRPOConfig
|
| 23 |
+
|
| 24 |
+
from config.loader import load_config
|
| 25 |
+
from data_utils.commom_util import collate_fn, define_task_data_func
|
| 26 |
+
from trainer.DyMETrainer import DyMETrainer
|
| 27 |
+
from reward_utils.checker import RewardCalculator, RewardCalculatorLocal
|
| 28 |
+
from reward_utils.refiner import ContextRefiner, ContextRefinerLocal
|
| 29 |
+
from opsd_utils import debug_log as opsd_debug
|
| 30 |
+
from opsd_utils.teacher_batching import (
|
| 31 |
+
log_teacher_placement,
|
| 32 |
+
resolve_teacher_device_map,
|
| 33 |
+
)
|
| 34 |
+
from opsd_utils.deepspeed_utils import (
|
| 35 |
+
deepspeed_zero_stage,
|
| 36 |
+
gradient_checkpointing_enable_kwargs,
|
| 37 |
+
is_deepspeed_accelerate_config,
|
| 38 |
+
should_disable_gradient_checkpointing,
|
| 39 |
+
uses_deepspeed_json_file,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _run_cross_model_vocab_checks(model, processor, teacher_model, model_config: Dict[str, Any]) -> None:
|
| 44 |
+
"""Startup checks for cross-model OPD vocab slice + tokenizer alignment."""
|
| 45 |
+
from transformers import AutoProcessor
|
| 46 |
+
|
| 47 |
+
from opsd_utils.vocab_align import print_vocab_align_report, verify_shared_tokenizer_alignment
|
| 48 |
+
|
| 49 |
+
student_vocab = getattr(model.config, "vocab_size", len(processor.tokenizer))
|
| 50 |
+
teacher_vocab = getattr(teacher_model.config, "vocab_size", student_vocab)
|
| 51 |
+
shared = min(student_vocab, teacher_vocab)
|
| 52 |
+
print(
|
| 53 |
+
f"[OPSD-VOCAB] lm_head widths: student={student_vocab} teacher={teacher_vocab} "
|
| 54 |
+
f"shared_slice={shared}",
|
| 55 |
+
flush=True,
|
| 56 |
+
)
|
| 57 |
+
if student_vocab == teacher_vocab:
|
| 58 |
+
print("[OPSD-VOCAB] vocab sizes match — no slice needed", flush=True)
|
| 59 |
+
return
|
| 60 |
+
|
| 61 |
+
teacher_path = model_config.get("teacher_model_path")
|
| 62 |
+
teacher_processor = AutoProcessor.from_pretrained(teacher_path)
|
| 63 |
+
full_scan = os.environ.get("DYME_VOCAB_ALIGN_FULL", "0").strip().lower() in ("1", "true", "yes")
|
| 64 |
+
stride = int(os.environ.get("DYME_VOCAB_ALIGN_STRIDE", "500"))
|
| 65 |
+
report = verify_shared_tokenizer_alignment(
|
| 66 |
+
processor.tokenizer,
|
| 67 |
+
teacher_processor.tokenizer,
|
| 68 |
+
shared_vocab=shared,
|
| 69 |
+
full_scan=full_scan,
|
| 70 |
+
sample_stride=stride,
|
| 71 |
+
)
|
| 72 |
+
print_vocab_align_report(report)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _wandb_disabled_by_env() -> bool:
|
| 76 |
+
if os.environ.get("WANDB_DISABLED", "").lower() in ("true", "1", "yes", "on"):
|
| 77 |
+
return True
|
| 78 |
+
if os.environ.get("WANDB_MODE", "").lower() in ("disabled", "off"):
|
| 79 |
+
return True
|
| 80 |
+
return False
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _try_wandb_login() -> bool:
|
| 84 |
+
"""Return True if wandb credentials are available (env, offline, or prior login)."""
|
| 85 |
+
if os.environ.get("WANDB_MODE", "").lower() == "offline":
|
| 86 |
+
return True
|
| 87 |
+
wandb_key = os.environ.get("WANDB_API_KEY")
|
| 88 |
+
if wandb_key:
|
| 89 |
+
wandb.login(key=wandb_key)
|
| 90 |
+
return True
|
| 91 |
+
try:
|
| 92 |
+
wandb.login(relogin=False)
|
| 93 |
+
key = wandb.api.api_key
|
| 94 |
+
return bool(key and len(key) >= 40)
|
| 95 |
+
except Exception:
|
| 96 |
+
return False
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def setup_accelerator_and_wandb(bf16, want_wandb: bool) -> tuple[Accelerator, bool]:
|
| 100 |
+
"""
|
| 101 |
+
Initialize Accelerator and optionally wandb.
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
(accelerator, use_wandb)
|
| 105 |
+
"""
|
| 106 |
+
use_wandb = want_wandb and not _wandb_disabled_by_env()
|
| 107 |
+
if use_wandb:
|
| 108 |
+
use_wandb = _try_wandb_login()
|
| 109 |
+
|
| 110 |
+
accel_kwargs: dict = {}
|
| 111 |
+
# bf16 for DDP/MULTI_GPU only; with deepspeed_config_file, precision lives in the JSON.
|
| 112 |
+
if bf16 and not uses_deepspeed_json_file():
|
| 113 |
+
accel_kwargs["mixed_precision"] = "bf16"
|
| 114 |
+
if use_wandb:
|
| 115 |
+
accel_kwargs["log_with"] = "wandb"
|
| 116 |
+
return Accelerator(**accel_kwargs), use_wandb
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def load_model_and_processor(model_config: Dict[str, Any]):
|
| 120 |
+
"""
|
| 121 |
+
Loads the pre-trained vision-language model and its associated processor.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
model_config (Dict[str, Any]): Configuration dictionary for the model.
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
Tuple[LlavaOnevisionForConditionalGeneration, PreTrainedProcessor]: The loaded model and processor.
|
| 128 |
+
"""
|
| 129 |
+
model_id = model_config['pretrained_model_path']
|
| 130 |
+
|
| 131 |
+
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
|
| 132 |
+
model_id,
|
| 133 |
+
torch_dtype=getattr(torch, model_config['torch_dtype']),
|
| 134 |
+
attn_implementation='flash_attention_2' if model_config['use_flash_attention_2'] else 'sdpa',
|
| 135 |
+
low_cpu_mem_usage=True,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# Freeze the vision tower to save memory and computation
|
| 139 |
+
model.base_model.vision_tower.requires_grad_(False)
|
| 140 |
+
|
| 141 |
+
processor = AutoProcessor.from_pretrained(model_id)
|
| 142 |
+
processor.tokenizer.padding_side = "left"
|
| 143 |
+
|
| 144 |
+
return model, processor
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def load_teacher_model(model_config: Dict[str, Any], *, local_rank: int = 0, num_gpus: int = 1):
|
| 148 |
+
"""Load optional frozen teacher for cross-model OPD (e.g. LLaVA-OneVision 7B)."""
|
| 149 |
+
teacher_path = model_config.get("teacher_model_path")
|
| 150 |
+
if not teacher_path:
|
| 151 |
+
return None
|
| 152 |
+
|
| 153 |
+
dtype_name = model_config.get("teacher_dtype", model_config.get("torch_dtype", "bfloat16"))
|
| 154 |
+
torch_dtype = getattr(torch, dtype_name)
|
| 155 |
+
requested_map = model_config.get("teacher_device_map")
|
| 156 |
+
if not requested_map:
|
| 157 |
+
env_map = os.environ.get("DYME_TEACHER_DEVICE_MAP", "").strip()
|
| 158 |
+
if env_map:
|
| 159 |
+
requested_map = env_map
|
| 160 |
+
|
| 161 |
+
device_map = resolve_teacher_device_map(
|
| 162 |
+
requested_map,
|
| 163 |
+
local_rank=local_rank,
|
| 164 |
+
num_gpus=max(1, num_gpus),
|
| 165 |
+
)
|
| 166 |
+
log_teacher_placement(
|
| 167 |
+
local_rank=local_rank,
|
| 168 |
+
num_gpus=max(1, num_gpus),
|
| 169 |
+
teacher_path=teacher_path,
|
| 170 |
+
resolved_device=device_map,
|
| 171 |
+
requested_map=requested_map,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
load_kwargs: Dict[str, Any] = {
|
| 175 |
+
"torch_dtype": torch_dtype,
|
| 176 |
+
"low_cpu_mem_usage": True,
|
| 177 |
+
"device_map": device_map,
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
teacher = LlavaOnevisionForConditionalGeneration.from_pretrained(
|
| 181 |
+
teacher_path,
|
| 182 |
+
attn_implementation='flash_attention_2' if model_config.get('use_flash_attention_2') else 'sdpa',
|
| 183 |
+
**load_kwargs,
|
| 184 |
+
)
|
| 185 |
+
teacher.eval()
|
| 186 |
+
teacher.requires_grad_(False)
|
| 187 |
+
if hasattr(teacher, "base_model") and hasattr(teacher.base_model, "vision_tower"):
|
| 188 |
+
teacher.base_model.vision_tower.requires_grad_(False)
|
| 189 |
+
|
| 190 |
+
return teacher
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def prepare_datasets(task: str, dataset_config: Dict[str, Any], mode='rl') -> (Dataset, Dataset):
|
| 194 |
+
"""
|
| 195 |
+
Prepares the training and evaluation datasets based on the specified task.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
task (str): The name of the task (e.g., 'chartqa').
|
| 199 |
+
dataset_config (Dict[str, Any]): Configuration for datasets.
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
Tuple[Dataset, Dataset]: The training and evaluation datasets.
|
| 203 |
+
"""
|
| 204 |
+
data_func = define_task_data_func(task, mode=mode)
|
| 205 |
+
|
| 206 |
+
# Create training dataset
|
| 207 |
+
train_data_list = data_func(json_path=dataset_config['train_dataset'])
|
| 208 |
+
train_dataset = Dataset.from_list(train_data_list)
|
| 209 |
+
|
| 210 |
+
# Create evaluation dataset
|
| 211 |
+
if 'chart' in task:
|
| 212 |
+
eval_dataset = load_dataset(dataset_config['eval_dataset'])['test']
|
| 213 |
+
# Note: You can uncomment the line below for quick testing/debugging.
|
| 214 |
+
# eval_dataset = eval_dataset.select(range(1000, 1100))
|
| 215 |
+
|
| 216 |
+
else:
|
| 217 |
+
# Extend this section for other tasks if needed in the future.
|
| 218 |
+
eval_dataset = None
|
| 219 |
+
|
| 220 |
+
return train_dataset, eval_dataset
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def main():
|
| 224 |
+
"""
|
| 225 |
+
Main function to orchestrate the model training pipeline.
|
| 226 |
+
"""
|
| 227 |
+
|
| 228 |
+
parser = argparse.ArgumentParser(description="Train a Llava model using either SFT or GRPO.")
|
| 229 |
+
|
| 230 |
+
parser.add_argument(
|
| 231 |
+
'--config', type=str, default='config/config.py',
|
| 232 |
+
help="Python config path (e.g. config/config.py, config/config_trimode.py) "
|
| 233 |
+
"or shorthand alias: norm | trimode | llavacot | low | aok",
|
| 234 |
+
)
|
| 235 |
+
parser.add_argument(
|
| 236 |
+
'--mode', type=str, default='rl',
|
| 237 |
+
)
|
| 238 |
+
parser.add_argument(
|
| 239 |
+
'--opsd_mode', type=str, default=None,
|
| 240 |
+
help="OPSD routing mode: dyme | trimode | rlsd | copsd_opd | opsd_only | replace_sft | opsd_on_wrong | grpo_opsd_joint",
|
| 241 |
+
)
|
| 242 |
+
parser.add_argument(
|
| 243 |
+
'--opsd_providers', type=str, default=None,
|
| 244 |
+
help="Comma-separated privileged providers: text,visual_facts,crop,hybrid",
|
| 245 |
+
)
|
| 246 |
+
parser.add_argument(
|
| 247 |
+
'--opsd_privilege_profile', type=str, default=None,
|
| 248 |
+
help="Privileged profile preset: text | visual | hybrid (default hybrid in config_trimode)",
|
| 249 |
+
)
|
| 250 |
+
parser.add_argument(
|
| 251 |
+
'--reward_weights', type=str, default=None,
|
| 252 |
+
help="Comma-separated reward weights: format,context,acc (e.g. 0.5,1.5,1.0). "
|
| 253 |
+
"Overrides config; env DYME_REWARD_WEIGHTS also supported in antidegen config.",
|
| 254 |
+
)
|
| 255 |
+
parser.add_argument(
|
| 256 |
+
'--opsd_enabled', action='store_true',
|
| 257 |
+
help="Enable OPSD / TriMode training extensions",
|
| 258 |
+
)
|
| 259 |
+
parser.add_argument(
|
| 260 |
+
'--opsd_debug', action='store_true',
|
| 261 |
+
help="Enable verbose OPSD debug logs (or set env DYME_OPSD_DEBUG=1)",
|
| 262 |
+
)
|
| 263 |
+
parser.add_argument(
|
| 264 |
+
'--opsd_detail_every', type=int, default=None,
|
| 265 |
+
help="Emit full weak-signal diagnostic bundle every N global steps on rank 0 "
|
| 266 |
+
"(default 10; config opsd.debug.detail_every or env DYME_OPSD_DETAIL_EVERY)",
|
| 267 |
+
)
|
| 268 |
+
parser.add_argument(
|
| 269 |
+
'--opsd_probe_on_generate', dest='opsd_probe_on_generate', action='store_true',
|
| 270 |
+
help="Emit [OPSD-PROBE] on every (re)generate on rank 0 (config_trimode default on)",
|
| 271 |
+
)
|
| 272 |
+
parser.add_argument(
|
| 273 |
+
'--no_opsd_probe_on_generate', dest='opsd_probe_on_generate', action='store_false',
|
| 274 |
+
help="Disable per-generate [OPSD-PROBE] logs",
|
| 275 |
+
)
|
| 276 |
+
parser.set_defaults(opsd_probe_on_generate=None)
|
| 277 |
+
parser.add_argument(
|
| 278 |
+
'--no_opsd_probe_first_token_logits', dest='opsd_probe_first_token_logits', action='store_false',
|
| 279 |
+
help="Disable pre-generate first-token logits probe ([OPSD-GENDBG])",
|
| 280 |
+
)
|
| 281 |
+
parser.set_defaults(opsd_probe_first_token_logits=None)
|
| 282 |
+
parser.add_argument(
|
| 283 |
+
'--wandb', dest='wandb', action='store_true',
|
| 284 |
+
help="Force enable Weights & Biases logging",
|
| 285 |
+
)
|
| 286 |
+
parser.add_argument(
|
| 287 |
+
'--no_wandb', dest='wandb', action='store_false',
|
| 288 |
+
help="Disable Weights & Biases logging (or set WANDB_MODE=offline/disabled)",
|
| 289 |
+
)
|
| 290 |
+
parser.set_defaults(wandb=None)
|
| 291 |
+
|
| 292 |
+
args = parser.parse_args()
|
| 293 |
+
mode = args.mode
|
| 294 |
+
|
| 295 |
+
# 1. Load Configurations
|
| 296 |
+
CONFIG = load_config(args.config)
|
| 297 |
+
model_config = CONFIG['model']
|
| 298 |
+
training_config = CONFIG['training']
|
| 299 |
+
rl_config = CONFIG['rl']
|
| 300 |
+
client_config = CONFIG['client']
|
| 301 |
+
dataset_config = CONFIG['dataset']
|
| 302 |
+
task = training_config['task']
|
| 303 |
+
opsd_config = dict(CONFIG.get('opsd', {"enabled": False, "mode": "dyme"}))
|
| 304 |
+
if args.opsd_enabled:
|
| 305 |
+
opsd_config["enabled"] = True
|
| 306 |
+
if args.opsd_mode is not None:
|
| 307 |
+
opsd_config["enabled"] = True
|
| 308 |
+
opsd_config["mode"] = args.opsd_mode
|
| 309 |
+
if args.opsd_providers is not None:
|
| 310 |
+
opsd_config["privileged_providers"] = [p.strip() for p in args.opsd_providers.split(",") if p.strip()]
|
| 311 |
+
if args.opsd_privilege_profile is not None:
|
| 312 |
+
opsd_config["privileged_profile"] = args.opsd_privilege_profile.strip()
|
| 313 |
+
reward_weights_raw = args.reward_weights or os.environ.get("DYME_REWARD_WEIGHTS")
|
| 314 |
+
if reward_weights_raw:
|
| 315 |
+
parts = [p.strip() for p in reward_weights_raw.split(",") if p.strip()]
|
| 316 |
+
if len(parts) != 3:
|
| 317 |
+
raise ValueError(
|
| 318 |
+
f"reward_weights must have exactly 3 comma-separated values (format,context,acc), got: {reward_weights_raw!r}"
|
| 319 |
+
)
|
| 320 |
+
opsd_config["reward_weights"] = [float(p) for p in parts]
|
| 321 |
+
debug_cfg = opsd_config.setdefault("debug", {})
|
| 322 |
+
detail_every = debug_cfg.get("detail_every", 10)
|
| 323 |
+
if args.opsd_detail_every is not None:
|
| 324 |
+
detail_every = max(0, args.opsd_detail_every)
|
| 325 |
+
debug_cfg["detail_every"] = detail_every
|
| 326 |
+
probe_on_generate = debug_cfg.get("probe_on_generate", False)
|
| 327 |
+
if args.opsd_probe_on_generate is not None:
|
| 328 |
+
probe_on_generate = args.opsd_probe_on_generate
|
| 329 |
+
debug_cfg["probe_on_generate"] = probe_on_generate
|
| 330 |
+
probe_first_token_logits = debug_cfg.get("probe_first_token_logits", True)
|
| 331 |
+
if args.opsd_probe_first_token_logits is not None:
|
| 332 |
+
probe_first_token_logits = args.opsd_probe_first_token_logits
|
| 333 |
+
debug_cfg["probe_first_token_logits"] = probe_first_token_logits
|
| 334 |
+
|
| 335 |
+
debug_enabled = opsd_debug.configure(
|
| 336 |
+
enabled=args.opsd_debug or None,
|
| 337 |
+
detail_every=detail_every,
|
| 338 |
+
probe_on_generate=probe_on_generate,
|
| 339 |
+
probe_first_token_logits=probe_first_token_logits,
|
| 340 |
+
probe_prompt_tail_tokens=debug_cfg.get("probe_prompt_tail_tokens", 16),
|
| 341 |
+
probe_log_model_context=debug_cfg.get("probe_log_model_context", True),
|
| 342 |
+
)
|
| 343 |
+
if debug_enabled:
|
| 344 |
+
opsd_debug.log_config("main", "resolved OPSD config", opsd_config)
|
| 345 |
+
opsd_debug.log("main", "training entry", mode=mode, config_path=args.config)
|
| 346 |
+
|
| 347 |
+
# 2. Setup Environment
|
| 348 |
+
want_wandb = True if args.wandb is None else args.wandb
|
| 349 |
+
accelerator, use_wandb = setup_accelerator_and_wandb(
|
| 350 |
+
bf16=training_config['dyme_args']['bf16'],
|
| 351 |
+
want_wandb=want_wandb,
|
| 352 |
+
)
|
| 353 |
+
if want_wandb and not use_wandb and args.wandb is True:
|
| 354 |
+
raise RuntimeError(
|
| 355 |
+
"wandb was requested (--wandb) but no API key is configured. "
|
| 356 |
+
"Run `wandb login`, set WANDB_API_KEY, or use WANDB_MODE=offline."
|
| 357 |
+
)
|
| 358 |
+
if accelerator.is_main_process:
|
| 359 |
+
if use_wandb:
|
| 360 |
+
print("[DyME] wandb enabled for training logs")
|
| 361 |
+
elif want_wandb:
|
| 362 |
+
print(
|
| 363 |
+
"[DyME] wandb disabled (no credentials). Training continues with report_to=none. "
|
| 364 |
+
"Run `wandb login`, export WANDB_API_KEY, or pass --wandb after configuring."
|
| 365 |
+
)
|
| 366 |
+
device_id = accelerator.process_index
|
| 367 |
+
opsd_debug.configure(
|
| 368 |
+
enabled=debug_enabled,
|
| 369 |
+
detail_every=detail_every,
|
| 370 |
+
probe_on_generate=probe_on_generate,
|
| 371 |
+
probe_first_token_logits=probe_first_token_logits,
|
| 372 |
+
probe_prompt_tail_tokens=debug_cfg.get("probe_prompt_tail_tokens", 16),
|
| 373 |
+
probe_log_model_context=debug_cfg.get("probe_log_model_context", True),
|
| 374 |
+
rank=accelerator.process_index,
|
| 375 |
+
world_size=accelerator.num_processes,
|
| 376 |
+
)
|
| 377 |
+
if debug_enabled:
|
| 378 |
+
opsd_debug.log(
|
| 379 |
+
"main",
|
| 380 |
+
"accelerator initialized",
|
| 381 |
+
process_index=accelerator.process_index,
|
| 382 |
+
local_process_index=accelerator.local_process_index,
|
| 383 |
+
num_processes=accelerator.num_processes,
|
| 384 |
+
device=str(accelerator.device),
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
visible_gpus = torch.cuda.device_count()
|
| 388 |
+
local_rank = int(os.environ.get("LOCAL_RANK", accelerator.local_process_index))
|
| 389 |
+
if visible_gpus == 0:
|
| 390 |
+
raise RuntimeError("No CUDA devices are visible to this process.")
|
| 391 |
+
if accelerator.num_processes > visible_gpus:
|
| 392 |
+
raise RuntimeError(
|
| 393 |
+
f"GPU/process mismatch: launched {accelerator.num_processes} distributed processes "
|
| 394 |
+
f"but only {visible_gpus} CUDA device(s) are visible "
|
| 395 |
+
f"(CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES', '<unset>')}).\n"
|
| 396 |
+
f"Fix: accelerate launch --num_processes {visible_gpus} ...\n"
|
| 397 |
+
f"Or: bash scripts/train_local_gpus.sh (auto-detects {visible_gpus} GPU(s))"
|
| 398 |
+
)
|
| 399 |
+
if local_rank >= visible_gpus:
|
| 400 |
+
raise RuntimeError(
|
| 401 |
+
f"LOCAL_RANK={local_rank} but only {visible_gpus} GPU(s) visible. "
|
| 402 |
+
f"Reduce --num_processes to {visible_gpus}."
|
| 403 |
+
)
|
| 404 |
+
if accelerator.is_main_process:
|
| 405 |
+
print(
|
| 406 |
+
f"[DyME] Distributed launch OK: num_processes={accelerator.num_processes}, "
|
| 407 |
+
f"visible_gpus={visible_gpus}, CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES', '<unset>')}"
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
# 3. Initialize Model and Processor
|
| 411 |
+
ds_zero_stage = deepspeed_zero_stage()
|
| 412 |
+
if accelerator.is_main_process and is_deepspeed_accelerate_config():
|
| 413 |
+
print(
|
| 414 |
+
f"[DyME] DeepSpeed enabled via ACCELERATE_CONFIG "
|
| 415 |
+
f"({os.environ.get('ACCELERATE_CONFIG', '<unset>')}), ZeRO stage={ds_zero_stage}",
|
| 416 |
+
flush=True,
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
model, processor = load_model_and_processor(model_config)
|
| 420 |
+
if os.environ.get("DYME_GRADIENT_CHECKPOINTING", "").strip().lower() in ("1", "true", "yes", "on"):
|
| 421 |
+
if should_disable_gradient_checkpointing():
|
| 422 |
+
if accelerator.is_main_process:
|
| 423 |
+
print(
|
| 424 |
+
"[DyME] gradient checkpointing skipped: incompatible with DeepSpeed ZeRO-1/2 "
|
| 425 |
+
"(multiple student forwards / checkpoint backward). "
|
| 426 |
+
"Use ZeRO-3, DDP, or DYME_GRADIENT_CHECKPOINTING=0.",
|
| 427 |
+
flush=True,
|
| 428 |
+
)
|
| 429 |
+
else:
|
| 430 |
+
gc_kwargs = gradient_checkpointing_enable_kwargs()
|
| 431 |
+
if gc_kwargs:
|
| 432 |
+
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs)
|
| 433 |
+
else:
|
| 434 |
+
model.gradient_checkpointing_enable()
|
| 435 |
+
if accelerator.is_main_process:
|
| 436 |
+
mode = f"use_reentrant={gc_kwargs['use_reentrant']}" if gc_kwargs else "default"
|
| 437 |
+
print(
|
| 438 |
+
f"[DyME] gradient checkpointing enabled on student "
|
| 439 |
+
f"(DYME_GRADIENT_CHECKPOINTING, {mode})",
|
| 440 |
+
flush=True,
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
cold_start_frac = float(
|
| 444 |
+
opsd_config.get("gate", {}).get("sft_cold_start_frac", 0.0) or 0.0
|
| 445 |
+
)
|
| 446 |
+
cold_start_steps = opsd_config.get("gate", {}).get("sft_cold_start_steps")
|
| 447 |
+
lazy_teacher = bool(cold_start_steps) or cold_start_frac > 0.0
|
| 448 |
+
|
| 449 |
+
teacher_model = None
|
| 450 |
+
teacher_model_config = None
|
| 451 |
+
if lazy_teacher:
|
| 452 |
+
teacher_model_config = dict(model_config)
|
| 453 |
+
if accelerator.is_main_process:
|
| 454 |
+
print(
|
| 455 |
+
"[DyME] SFT cold-start enabled: deferring 7B teacher load until RL phase",
|
| 456 |
+
flush=True,
|
| 457 |
+
)
|
| 458 |
+
else:
|
| 459 |
+
teacher_model = load_teacher_model(
|
| 460 |
+
model_config,
|
| 461 |
+
local_rank=local_rank,
|
| 462 |
+
num_gpus=visible_gpus,
|
| 463 |
+
)
|
| 464 |
+
if accelerator.is_main_process and teacher_model is not None:
|
| 465 |
+
_run_cross_model_vocab_checks(
|
| 466 |
+
model,
|
| 467 |
+
processor,
|
| 468 |
+
teacher_model,
|
| 469 |
+
model_config,
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
# 4. Prepare Datasets
|
| 473 |
+
train_dataset, eval_dataset = prepare_datasets(task, dataset_config, mode=mode)
|
| 474 |
+
|
| 475 |
+
# 5. Initialize Reward Calculator
|
| 476 |
+
# checker = RewardCalculator(rl_config, client_config.copy(), gpu_id=device_id)
|
| 477 |
+
# refiner = ContextRefiner(rl_config, client_config.copy(), gpu_id=device_id)
|
| 478 |
+
|
| 479 |
+
checker = RewardCalculatorLocal(rl_config, client_config.copy(), gpu_id=device_id)
|
| 480 |
+
refiner = ContextRefinerLocal(rl_config, client_config.copy(), gpu_id=device_id)
|
| 481 |
+
# 6. Define Training Arguments
|
| 482 |
+
dyme_args = dict(training_config['dyme_args'])
|
| 483 |
+
if ds_zero_stage is not None and ds_zero_stage >= 3:
|
| 484 |
+
dyme_args.setdefault("ds3_gather_for_generation", True)
|
| 485 |
+
if not use_wandb:
|
| 486 |
+
dyme_args["report_to"] = "none"
|
| 487 |
+
training_args = GRPOConfig(**dyme_args)
|
| 488 |
+
|
| 489 |
+
collate_fn_with_processor = partial(collate_fn, processor=processor)
|
| 490 |
+
# 7. Initialize the Trainer
|
| 491 |
+
dyme_trainer = DyMETrainer(
|
| 492 |
+
model=model,
|
| 493 |
+
checker=checker,
|
| 494 |
+
refiner=refiner,
|
| 495 |
+
args=training_args,
|
| 496 |
+
train_dataset=train_dataset,
|
| 497 |
+
eval_dataset=eval_dataset,
|
| 498 |
+
processing_class=processor,
|
| 499 |
+
processing_func=collate_fn_with_processor,
|
| 500 |
+
task_name=task,
|
| 501 |
+
end_flag=rl_config['end_flag'],
|
| 502 |
+
opsd_config=opsd_config,
|
| 503 |
+
teacher_model=teacher_model,
|
| 504 |
+
teacher_model_config=teacher_model_config,
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
# 8. Start Training
|
| 508 |
+
dyme_trainer.train()
|
| 509 |
+
|
| 510 |
+
output_dir = training_args.output_dir
|
| 511 |
+
output_dir = os.path.join(output_dir, "final_checkpoint")
|
| 512 |
+
if accelerator.is_main_process and is_deepspeed_accelerate_config():
|
| 513 |
+
print(
|
| 514 |
+
"[DyME] Saving consolidated student checkpoint (DeepSpeed ZeRO gather if configured)...",
|
| 515 |
+
flush=True,
|
| 516 |
+
)
|
| 517 |
+
dyme_trainer.save_model(output_dir)
|
| 518 |
+
if accelerator.is_main_process:
|
| 519 |
+
processor.save_pretrained(output_dir)
|
| 520 |
+
print(f"Model and processor saved to {output_dir}")
|
| 521 |
+
if __name__ == "__main__":
|
| 522 |
+
main()
|
main_llm.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# train_grpo.py
|
| 2 |
+
"""
|
| 3 |
+
Main script for training a Llava-based model using the custom MyGRPOTrainer.
|
| 4 |
+
|
| 5 |
+
This script handles:
|
| 6 |
+
1. Configuration loading.
|
| 7 |
+
2. Initialization of Weights & Biases (wandb) and Hugging Face Accelerate.
|
| 8 |
+
3. Loading the model and processor.
|
| 9 |
+
4. Preparing the training and evaluation datasets.
|
| 10 |
+
5. Setting up and running the GRPO trainer.
|
| 11 |
+
"""
|
| 12 |
+
import argparse
|
| 13 |
+
import os
|
| 14 |
+
from functools import partial
|
| 15 |
+
from typing import Dict, Any
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import wandb
|
| 19 |
+
from accelerate import Accelerator
|
| 20 |
+
from datasets import Dataset, load_dataset
|
| 21 |
+
from peft import LoraConfig, get_peft_model, TaskType
|
| 22 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
| 23 |
+
from trl import GRPOConfig
|
| 24 |
+
from config.config_llm import CONFIG
|
| 25 |
+
from data_utils.commom_util import collate_fn, define_task_data_func, collate_fn_woI
|
| 26 |
+
from trainer.DyMETrainer_llm import DyMETrainer
|
| 27 |
+
from reward_utils.checker import RewardCalculator, RewardCalculatorLocal
|
| 28 |
+
from reward_utils.refiner import ContextRefiner, ContextRefinerLocal
|
| 29 |
+
|
| 30 |
+
def print_trainable_parameters(model):
|
| 31 |
+
"""
|
| 32 |
+
Prints the number of trainable parameters in the model.
|
| 33 |
+
"""
|
| 34 |
+
trainable_params = 0
|
| 35 |
+
all_param = 0
|
| 36 |
+
for _, param in model.named_parameters():
|
| 37 |
+
all_param += param.numel()
|
| 38 |
+
if param.requires_grad:
|
| 39 |
+
trainable_params += param.numel()
|
| 40 |
+
print(
|
| 41 |
+
f"trainable params: {trainable_params} || all params: {all_param} || "
|
| 42 |
+
f"trainable%: {100 * trainable_params / all_param:.2f}"
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
def setup_accelerator_and_wandb(bf16) -> Accelerator:
|
| 46 |
+
"""
|
| 47 |
+
Initializes Weights & Biases and the Hugging Face Accelerator.
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
Accelerator: The configured accelerator instance.
|
| 51 |
+
"""
|
| 52 |
+
wandb_key = os.environ.get("WANDB_API_KEY")
|
| 53 |
+
if wandb_key:
|
| 54 |
+
wandb.login(key=wandb_key)
|
| 55 |
+
if bf16:
|
| 56 |
+
accelerator = Accelerator(mixed_precision="bf16", log_with="wandb")
|
| 57 |
+
else:
|
| 58 |
+
accelerator = Accelerator(log_with="wandb")
|
| 59 |
+
return accelerator
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def load_model_and_processor(model_config: Dict[str, Any], peft_config: Dict[str, Any]):
|
| 64 |
+
"""
|
| 65 |
+
Loads the base model, applies LoRA configuration, and loads its processor.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
model_config (Dict[str, Any]): Configuration dictionary for the model.
|
| 69 |
+
peft_config (Dict[str, Any]): Configuration dictionary for PEFT (LoRA).
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
Tuple[PeftModel, PreTrainedProcessor]: The loaded PEFT model and processor.
|
| 73 |
+
"""
|
| 74 |
+
model_id = model_config['pretrained_model_path']
|
| 75 |
+
|
| 76 |
+
# Load base model
|
| 77 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 78 |
+
model_id,
|
| 79 |
+
torch_dtype=getattr(torch, model_config['torch_dtype']),
|
| 80 |
+
attn_implementation='flash_attention_2' if model_config['use_flash_attention_2'] else 'sdpa',
|
| 81 |
+
low_cpu_mem_usage=True,
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
processor = AutoProcessor.from_pretrained(model_id, padding_side='left')
|
| 85 |
+
processor._tokenizer.padding_side = "left"
|
| 86 |
+
lora_config = peft_config
|
| 87 |
+
|
| 88 |
+
model = get_peft_model(base_model, lora_config)
|
| 89 |
+
|
| 90 |
+
print("LoRA model created:")
|
| 91 |
+
print_trainable_parameters(model)
|
| 92 |
+
|
| 93 |
+
return model, processor
|
| 94 |
+
# ## --- LoRA modification End --- ##
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def prepare_datasets(task: str, dataset_config: Dict[str, Any]) -> (Dataset, Dataset):
|
| 98 |
+
"""
|
| 99 |
+
Prepares the training and evaluation datasets based on the specified task.
|
| 100 |
+
"""
|
| 101 |
+
data_func = define_task_data_func(task)
|
| 102 |
+
train_data_list = data_func(json_path=dataset_config['train_dataset'])
|
| 103 |
+
train_dataset = Dataset.from_list(train_data_list)
|
| 104 |
+
|
| 105 |
+
if 'chart' in task:
|
| 106 |
+
eval_dataset = load_dataset(dataset_config['eval_dataset'])['test']
|
| 107 |
+
else:
|
| 108 |
+
eval_dataset = None
|
| 109 |
+
|
| 110 |
+
return train_dataset, eval_dataset
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def main():
|
| 114 |
+
"""
|
| 115 |
+
Main function to orchestrate the model training pipeline.
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
parser = argparse.ArgumentParser(description="Train a model using GRPO with LoRA.")
|
| 119 |
+
|
| 120 |
+
parser.add_argument(
|
| 121 |
+
'--config', type=str, default='norm',
|
| 122 |
+
help="config file to use: 'norm' or 'llavacot'..."
|
| 123 |
+
)
|
| 124 |
+
args = parser.parse_args()
|
| 125 |
+
config_select = args.config
|
| 126 |
+
|
| 127 |
+
if config_select == 'norm':
|
| 128 |
+
from config_llm import CONFIG
|
| 129 |
+
|
| 130 |
+
# 1. Load Configurations
|
| 131 |
+
model_config = CONFIG['model']
|
| 132 |
+
training_config = CONFIG['training']
|
| 133 |
+
rl_config = CONFIG['rl']
|
| 134 |
+
client_config = CONFIG['client']
|
| 135 |
+
dataset_config = CONFIG['dataset']
|
| 136 |
+
peft_config = LoraConfig(
|
| 137 |
+
r=16,
|
| 138 |
+
lora_alpha=64,
|
| 139 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
|
| 140 |
+
task_type="CAUSAL_LM",
|
| 141 |
+
lora_dropout=0.05,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
task = training_config['task']
|
| 145 |
+
|
| 146 |
+
# 2. Setup Environment
|
| 147 |
+
accelerator = setup_accelerator_and_wandb(bf16=training_config['dyme_args']['bf16'])
|
| 148 |
+
device_id = accelerator.process_index
|
| 149 |
+
|
| 150 |
+
# 3. Initialize Model and Processor
|
| 151 |
+
# ## --- LoRA modification Start --- ##
|
| 152 |
+
# Pass peft_config to the model loading function
|
| 153 |
+
model, processor = load_model_and_processor(model_config, peft_config)
|
| 154 |
+
# ## --- LoRA modification End --- ##
|
| 155 |
+
|
| 156 |
+
# 4. Prepare Datasets
|
| 157 |
+
train_dataset, eval_dataset = prepare_datasets(task, dataset_config)
|
| 158 |
+
|
| 159 |
+
# 5. Initialize Reward Calculator
|
| 160 |
+
checker = RewardCalculatorLocal(rl_config, client_config.copy(), gpu_id=device_id)
|
| 161 |
+
refiner = ContextRefinerLocal(rl_config, client_config.copy(), gpu_id=device_id)
|
| 162 |
+
|
| 163 |
+
# 6. Define Training Arguments
|
| 164 |
+
training_args = GRPOConfig(**training_config['dyme_args'])
|
| 165 |
+
|
| 166 |
+
collate_fn_with_processor = partial(collate_fn_woI, processor=processor)
|
| 167 |
+
|
| 168 |
+
# 7. Initialize the Trainer
|
| 169 |
+
# Trainer handles PeftModel automatically
|
| 170 |
+
dyme_trainer = DyMETrainer(
|
| 171 |
+
model=model,
|
| 172 |
+
checker=checker,
|
| 173 |
+
refiner=refiner,
|
| 174 |
+
args=training_args,
|
| 175 |
+
train_dataset=train_dataset,
|
| 176 |
+
eval_dataset=eval_dataset,
|
| 177 |
+
processing_class=processor,
|
| 178 |
+
processing_func=collate_fn_with_processor,
|
| 179 |
+
task_name=task,
|
| 180 |
+
end_flag=rl_config['end_flag'],
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
# 8. Start Training
|
| 184 |
+
dyme_trainer.train()
|
| 185 |
+
|
| 186 |
+
# When saving, the Trainer automatically saves only the LoRA adapter weights
|
| 187 |
+
output_dir = training_args.output_dir
|
| 188 |
+
output_dir = os.path.join(output_dir, "final_checkpoint")
|
| 189 |
+
dyme_trainer.save_model(output_dir)
|
| 190 |
+
|
| 191 |
+
if accelerator.is_main_process:
|
| 192 |
+
# Non-model files like the processor still need to be saved manually
|
| 193 |
+
processor.save_pretrained(output_dir)
|
| 194 |
+
print(f"LoRA adapters and processor saved to {output_dir}")
|
| 195 |
+
|
| 196 |
+
if __name__ == "__main__":
|
| 197 |
+
main()
|
main_sft.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Offline supervised fine-tuning for ChartQA (two-stage cold start before RLSD/OPD).
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
accelerate launch main_sft.py --config config/config_rlsd_chartqa.py
|
| 6 |
+
bash scripts/train_chartqa_sft.sh
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import os
|
| 12 |
+
from functools import partial
|
| 13 |
+
|
| 14 |
+
from accelerate import Accelerator
|
| 15 |
+
from datasets import Dataset
|
| 16 |
+
from transformers import Trainer, TrainingArguments
|
| 17 |
+
|
| 18 |
+
from config.loader import load_config
|
| 19 |
+
from data_utils.commom_util import collate_fn, define_task_data_func
|
| 20 |
+
from main import load_model_and_processor
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def main() -> None:
|
| 24 |
+
parser = argparse.ArgumentParser(description="ChartQA offline SFT (hint + Answer GT).")
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--config",
|
| 27 |
+
type=str,
|
| 28 |
+
default="config/config_rlsd_chartqa.py",
|
| 29 |
+
help="Config module (uses training.sft_args and dataset.train_dataset).",
|
| 30 |
+
)
|
| 31 |
+
parser.add_argument(
|
| 32 |
+
"--pretrained_model_path",
|
| 33 |
+
type=str,
|
| 34 |
+
default=None,
|
| 35 |
+
help="Override CONFIG model path (e.g. base 0.5B before RL).",
|
| 36 |
+
)
|
| 37 |
+
args = parser.parse_args()
|
| 38 |
+
|
| 39 |
+
config = load_config(args.config)
|
| 40 |
+
model_config = dict(config["model"])
|
| 41 |
+
if args.pretrained_model_path:
|
| 42 |
+
model_config["pretrained_model_path"] = args.pretrained_model_path
|
| 43 |
+
|
| 44 |
+
training_config = config["training"]
|
| 45 |
+
task = training_config["task"]
|
| 46 |
+
sft_args = dict(training_config.get("sft_args") or config.get("training", {}).get("sft_args", {}))
|
| 47 |
+
if not sft_args:
|
| 48 |
+
raise ValueError("Config must define training.sft_args for offline SFT.")
|
| 49 |
+
|
| 50 |
+
output_dir = os.environ.get("DYME_SFT_OUTPUT_DIR", sft_args.get("output_dir", "./outputs/chartqa-sft"))
|
| 51 |
+
sft_args["output_dir"] = output_dir
|
| 52 |
+
sft_args.setdefault("remove_unused_columns", False)
|
| 53 |
+
|
| 54 |
+
accelerator = Accelerator()
|
| 55 |
+
if accelerator.is_main_process:
|
| 56 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 57 |
+
|
| 58 |
+
model, processor = load_model_and_processor(model_config)
|
| 59 |
+
data_func = define_task_data_func(task, mode="sft")
|
| 60 |
+
train_list = data_func(json_path=config["dataset"]["train_dataset"])
|
| 61 |
+
train_dataset = Dataset.from_list(train_list)
|
| 62 |
+
|
| 63 |
+
label_id = processor.tokenizer.convert_tokens_to_ids("<|im_start|>")
|
| 64 |
+
data_collator = partial(collate_fn, processor=processor, label_id=label_id)
|
| 65 |
+
|
| 66 |
+
train_args = TrainingArguments(**sft_args)
|
| 67 |
+
trainer = Trainer(
|
| 68 |
+
model=model,
|
| 69 |
+
args=train_args,
|
| 70 |
+
train_dataset=train_dataset,
|
| 71 |
+
data_collator=data_collator,
|
| 72 |
+
)
|
| 73 |
+
trainer.train()
|
| 74 |
+
trainer.save_model(os.path.join(output_dir, "final_checkpoint"))
|
| 75 |
+
if accelerator.is_main_process:
|
| 76 |
+
processor.save_pretrained(os.path.join(output_dir, "final_checkpoint"))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
if __name__ == "__main__":
|
| 80 |
+
main()
|
multi_node_config_raw.yaml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
distributed_type: MULTI_GPU
|
| 3 |
+
downcast_bf16: 'no'
|
| 4 |
+
gpu_ids: all
|
| 5 |
+
machine_rank: 0
|
| 6 |
+
main_process_ip: 'xx.xx.xx.xx'
|
| 7 |
+
main_process_port: 36001
|
| 8 |
+
main_training_function: main
|
| 9 |
+
mixed_precision: 'bf16'
|
| 10 |
+
num_machines: 2
|
| 11 |
+
num_processes: 16
|
| 12 |
+
rdzv_backend: static
|
| 13 |
+
same_network: true
|
| 14 |
+
tpu_use_cluster: false
|
| 15 |
+
tpu_use_sudo: false
|
| 16 |
+
use_cpu: false
|
| 17 |
+
# Optional DeepSpeed ZeRO-0 (no sharding). Prefer MULTI_GPU DDP for single-node training.
|
| 18 |
+
deepspeed_config:
|
| 19 |
+
zero_stage: 0
|
| 20 |
+
gradient_accumulation_steps: 1
|
| 21 |
+
zero3_init_flag: false
|
opsd_utils/__pycache__/opsd_loss.cpython-312.pyc
ADDED
|
Binary file (5.03 kB). View file
|
|
|
opsd_utils/gate_policy.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RLSD warmup gates for OPSD degenerate skip, denser online SFT, and embedded SFT cold start."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
from typing import Any, Mapping, Optional
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def current_global_step(trainer: Any) -> int:
|
| 9 |
+
return int(getattr(getattr(trainer, "state", None), "global_step", getattr(trainer, "_step", 0)) or 0)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def resolve_max_training_steps(trainer: Any) -> Optional[int]:
|
| 13 |
+
"""Resolve total optimizer steps for gate math (cold start frac, warmup windows).
|
| 14 |
+
|
| 15 |
+
Priority: TrainingArguments.max_steps > Trainer.state.max_steps > epoch estimate.
|
| 16 |
+
HF sets state.max_steps when max_steps<=0 from num_train_epochs * len(dataloader).
|
| 17 |
+
"""
|
| 18 |
+
args = getattr(trainer, "args", None)
|
| 19 |
+
if args is not None:
|
| 20 |
+
arg_max = getattr(args, "max_steps", None)
|
| 21 |
+
if arg_max is not None and int(arg_max) > 0:
|
| 22 |
+
return int(arg_max)
|
| 23 |
+
|
| 24 |
+
state = getattr(trainer, "state", None)
|
| 25 |
+
if state is not None:
|
| 26 |
+
state_max = getattr(state, "max_steps", None)
|
| 27 |
+
if state_max is not None and int(state_max) > 0:
|
| 28 |
+
return int(state_max)
|
| 29 |
+
|
| 30 |
+
if args is None:
|
| 31 |
+
return None
|
| 32 |
+
|
| 33 |
+
num_epochs = getattr(args, "num_train_epochs", None)
|
| 34 |
+
grad_accum = max(1, int(getattr(args, "gradient_accumulation_steps", 1) or 1))
|
| 35 |
+
if num_epochs is None or float(num_epochs) <= 0:
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
dataloader = getattr(trainer, "train_dataloader", None)
|
| 39 |
+
if dataloader is None and hasattr(trainer, "get_train_dataloader"):
|
| 40 |
+
try:
|
| 41 |
+
dataloader = trainer.get_train_dataloader()
|
| 42 |
+
except Exception:
|
| 43 |
+
dataloader = None
|
| 44 |
+
if dataloader is None:
|
| 45 |
+
return None
|
| 46 |
+
|
| 47 |
+
try:
|
| 48 |
+
steps_per_epoch = len(dataloader)
|
| 49 |
+
except TypeError:
|
| 50 |
+
return None
|
| 51 |
+
if steps_per_epoch <= 0:
|
| 52 |
+
return None
|
| 53 |
+
|
| 54 |
+
total = math.ceil(float(num_epochs) * steps_per_epoch / grad_accum)
|
| 55 |
+
return total if total > 0 else None
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def sft_cold_start_steps(opsd_config: Mapping[str, Any], max_steps: Optional[int]) -> int:
|
| 59 |
+
"""Steps at start of training devoted to embedded offline-style SFT (no generate / no OPSD)."""
|
| 60 |
+
gate = opsd_config.get("gate", {})
|
| 61 |
+
steps_env = gate.get("sft_cold_start_steps")
|
| 62 |
+
if steps_env is not None:
|
| 63 |
+
return max(0, int(steps_env))
|
| 64 |
+
frac = float(gate.get("sft_cold_start_frac", 0.0) or 0.0)
|
| 65 |
+
if frac <= 0.0 or max_steps is None or max_steps <= 0:
|
| 66 |
+
return 0
|
| 67 |
+
return max(1, int(max_steps * frac))
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def in_sft_cold_start(
|
| 71 |
+
opsd_config: Mapping[str, Any],
|
| 72 |
+
global_step: int,
|
| 73 |
+
max_steps: Optional[int],
|
| 74 |
+
) -> bool:
|
| 75 |
+
cold_steps = sft_cold_start_steps(opsd_config, max_steps)
|
| 76 |
+
return cold_steps > 0 and global_step < cold_steps
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def resolve_skip_degenerate_opsd(
|
| 80 |
+
opsd_config: Mapping[str, Any],
|
| 81 |
+
global_step: int,
|
| 82 |
+
max_steps: Optional[int] = None,
|
| 83 |
+
) -> bool:
|
| 84 |
+
gate = opsd_config.get("gate", {})
|
| 85 |
+
if not gate.get("skip_degenerate_for_opsd", False):
|
| 86 |
+
return False
|
| 87 |
+
cold_end = sft_cold_start_steps(opsd_config, max_steps)
|
| 88 |
+
warmup = int(gate.get("degen_skip_warmup_steps", 200))
|
| 89 |
+
# Do not skip degenerate OPSD during embedded SFT cold start or its degen warmup window.
|
| 90 |
+
threshold = cold_end + warmup if cold_end > 0 else warmup
|
| 91 |
+
return global_step >= threshold
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def sft_slots_for_step(
|
| 95 |
+
opsd_config: Mapping[str, Any],
|
| 96 |
+
global_step: int,
|
| 97 |
+
max_steps: Optional[int] = None,
|
| 98 |
+
) -> int:
|
| 99 |
+
if in_sft_cold_start(opsd_config, global_step, max_steps):
|
| 100 |
+
return 0
|
| 101 |
+
gate = opsd_config.get("gate", {})
|
| 102 |
+
warmup_steps = int(gate.get("sft_warmup_steps", 200))
|
| 103 |
+
cold_end = sft_cold_start_steps(opsd_config, max_steps)
|
| 104 |
+
effective_warmup_end = cold_end + warmup_steps if cold_end > 0 else warmup_steps
|
| 105 |
+
if global_step < effective_warmup_end:
|
| 106 |
+
return max(1, int(gate.get("sft_warmup_slots_per_group", 2)))
|
| 107 |
+
return 1
|
opsd_utils/health_monitor.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training health monitor: degeneration alerts, rolling stats, cross-step correlation."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
from collections import deque
|
| 6 |
+
from typing import Any, Optional
|
| 7 |
+
|
| 8 |
+
from opsd_utils import debug_log as opsd_debug
|
| 9 |
+
|
| 10 |
+
ALERT_GEN_CLIP_COLLAPSE = "GEN_CLIP_COLLAPSE"
|
| 11 |
+
ALERT_GEN_REPEAT_DEGEN = "GEN_REPEAT_DEGEN"
|
| 12 |
+
ALERT_OPT_GRAD_SPIKE = "OPT_GRAD_SPIKE"
|
| 13 |
+
ALERT_OPT_NAN_INF = "OPT_NAN_INF"
|
| 14 |
+
ALERT_RL_ZERO_SIGNAL = "RL_ZERO_SIGNAL"
|
| 15 |
+
ALERT_REWARD_FORMAT_HACK = "REWARD_FORMAT_HACK"
|
| 16 |
+
ALERT_DATA_EMPTY_VF = "DATA_EMPTY_VF"
|
| 17 |
+
ALERT_LOGIT_MODE_COLLAPSE = "LOGIT_MODE_COLLAPSE"
|
| 18 |
+
ALERT_ANSWER_TOKEN_DRIFT = "ANSWER_TOKEN_DRIFT"
|
| 19 |
+
ALERT_CLIP_FALSE_HEALTHY = "CLIP_FALSE_HEALTHY"
|
| 20 |
+
ALERT_OPSD_LEAKAGE_PATTERN = "OPSD_LEAKAGE_PATTERN"
|
| 21 |
+
ALERT_OPSD_ON_CORRECT = "OPSD_ON_CORRECT"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _safe_float(v: Any, default: float = 0.0) -> float:
|
| 25 |
+
try:
|
| 26 |
+
if v is None:
|
| 27 |
+
return default
|
| 28 |
+
f = float(v)
|
| 29 |
+
if math.isnan(f) or math.isinf(f):
|
| 30 |
+
return default
|
| 31 |
+
return f
|
| 32 |
+
except (TypeError, ValueError):
|
| 33 |
+
return default
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _rolling_mean_std(values: list[float]) -> tuple[float, float]:
|
| 37 |
+
if not values:
|
| 38 |
+
return 0.0, 0.0
|
| 39 |
+
mean = sum(values) / len(values)
|
| 40 |
+
if len(values) < 2:
|
| 41 |
+
return mean, 0.0
|
| 42 |
+
var = sum((x - mean) ** 2 for x in values) / len(values)
|
| 43 |
+
return mean, math.sqrt(var)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class TrainingHealthMonitor:
|
| 47 |
+
"""Collect per-step signals, emit layered [OPSD-HEALTH] logs, expose metrics keys."""
|
| 48 |
+
|
| 49 |
+
def __init__(self, config: Optional[dict[str, Any]] = None):
|
| 50 |
+
cfg = config or {}
|
| 51 |
+
self.enabled = bool(cfg.get("enabled", True))
|
| 52 |
+
self.window = max(2, int(cfg.get("window", 20)))
|
| 53 |
+
self.log_on_generate = bool(cfg.get("log_on_generate", True))
|
| 54 |
+
self.log_every_step = bool(cfg.get("log_every_step", True))
|
| 55 |
+
self.log_detail_bundle = bool(cfg.get("log_detail_bundle", True))
|
| 56 |
+
self.log_alerts_immediately = bool(cfg.get("log_alerts_immediately", True))
|
| 57 |
+
self.metrics_every_step = bool(cfg.get("metrics_every_step", True))
|
| 58 |
+
|
| 59 |
+
self._history: deque[dict[str, Any]] = deque(maxlen=self.window)
|
| 60 |
+
self._step_fields: dict[str, Any] = {}
|
| 61 |
+
self._step_alerts: list[str] = []
|
| 62 |
+
self._p_greedy_history: deque[float] = deque(maxlen=5)
|
| 63 |
+
self._p_answer_history: deque[float] = deque(maxlen=5)
|
| 64 |
+
self._eos_history: deque[float] = deque(maxlen=5)
|
| 65 |
+
self._last_step: Optional[int] = None
|
| 66 |
+
|
| 67 |
+
def reset_step(self, step: int) -> None:
|
| 68 |
+
self._step_fields = {"global_step": step}
|
| 69 |
+
self._step_alerts = []
|
| 70 |
+
self._last_step = step
|
| 71 |
+
|
| 72 |
+
def _emit_alert(self, step: int, code: str, **fields: Any) -> None:
|
| 73 |
+
if code not in self._step_alerts:
|
| 74 |
+
self._step_alerts.append(code)
|
| 75 |
+
if self.log_alerts_immediately and opsd_debug.should_log_health_alerts_immediately():
|
| 76 |
+
opsd_debug.log_health("ALERT", code, global_step=step, **fields)
|
| 77 |
+
|
| 78 |
+
def _check_generate_alerts(self, step: int, stats: dict[str, Any], logits: dict[str, Any]) -> list[str]:
|
| 79 |
+
clipped = _safe_float(stats.get("clipped_rate"))
|
| 80 |
+
eos_rate = _safe_float(stats.get("eos_terminated_rate"))
|
| 81 |
+
degenerate_rate = _safe_float(stats.get("degenerate_rate"))
|
| 82 |
+
repeat_loop = int(stats.get("repeat_loop_count", 0) or 0)
|
| 83 |
+
p_greedy = _safe_float(logits.get("p_greedy_first"))
|
| 84 |
+
p_eos = _safe_float(logits.get("p_eos_first"))
|
| 85 |
+
p_answer = _safe_float(logits.get("p_answer_first"))
|
| 86 |
+
|
| 87 |
+
if clipped > 0.8 and degenerate_rate < 0.05:
|
| 88 |
+
self._emit_alert(
|
| 89 |
+
step,
|
| 90 |
+
ALERT_CLIP_FALSE_HEALTHY,
|
| 91 |
+
clipped_rate=clipped,
|
| 92 |
+
degenerate_rate=degenerate_rate,
|
| 93 |
+
hint="high clip with low degenerate_rate often masks Answer-only collapse",
|
| 94 |
+
)
|
| 95 |
+
if clipped > 0.7 and eos_rate < 0.3:
|
| 96 |
+
self._emit_alert(
|
| 97 |
+
step,
|
| 98 |
+
ALERT_GEN_CLIP_COLLAPSE,
|
| 99 |
+
clipped_rate=clipped,
|
| 100 |
+
eos_rate=eos_rate,
|
| 101 |
+
hint="raise repetition_penalty, lower temperature, or shorten max_completion_length",
|
| 102 |
+
)
|
| 103 |
+
if degenerate_rate > 0.5 or repeat_loop > 0:
|
| 104 |
+
self._emit_alert(
|
| 105 |
+
step,
|
| 106 |
+
ALERT_GEN_REPEAT_DEGEN,
|
| 107 |
+
degenerate_rate=degenerate_rate,
|
| 108 |
+
repeat_loop_count=repeat_loop,
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
if p_greedy > 0:
|
| 112 |
+
self._p_greedy_history.append(p_greedy)
|
| 113 |
+
self._eos_history.append(eos_rate)
|
| 114 |
+
if (
|
| 115 |
+
len(self._p_greedy_history) >= 3
|
| 116 |
+
and all(p > 0.99 for p in list(self._p_greedy_history)[-3:])
|
| 117 |
+
and len(self._eos_history) >= 2
|
| 118 |
+
and self._eos_history[-1] < self._eos_history[-2] - 0.1
|
| 119 |
+
):
|
| 120 |
+
self._emit_alert(
|
| 121 |
+
step,
|
| 122 |
+
ALERT_LOGIT_MODE_COLLAPSE,
|
| 123 |
+
p_greedy_first=p_greedy,
|
| 124 |
+
p_eos_first=p_eos,
|
| 125 |
+
eos_rate=eos_rate,
|
| 126 |
+
hint="first token collapsed to Goal: template; EOS probability near zero",
|
| 127 |
+
)
|
| 128 |
+
if p_answer > 0:
|
| 129 |
+
self._p_answer_history.append(p_answer)
|
| 130 |
+
if len(self._p_answer_history) >= 3 and all(
|
| 131 |
+
p < 0.5 for p in list(self._p_answer_history)[-3:]
|
| 132 |
+
):
|
| 133 |
+
self._emit_alert(
|
| 134 |
+
step,
|
| 135 |
+
ALERT_ANSWER_TOKEN_DRIFT,
|
| 136 |
+
p_answer_first=p_answer,
|
| 137 |
+
hint="first-token Answer probability low for 3 consecutive generate batches",
|
| 138 |
+
)
|
| 139 |
+
return list(self._step_alerts)
|
| 140 |
+
|
| 141 |
+
def record_generate(
|
| 142 |
+
self,
|
| 143 |
+
step: int,
|
| 144 |
+
stats: dict[str, Any],
|
| 145 |
+
logits_stats: Optional[dict[str, Any]] = None,
|
| 146 |
+
) -> list[str]:
|
| 147 |
+
if not self.enabled:
|
| 148 |
+
return []
|
| 149 |
+
logits_stats = logits_stats or {}
|
| 150 |
+
self._step_fields.update(
|
| 151 |
+
{
|
| 152 |
+
"degenerate_rate": stats.get("degenerate_rate"),
|
| 153 |
+
"clipped_rate": stats.get("clipped_rate"),
|
| 154 |
+
"eos_terminated_rate": stats.get("eos_terminated_rate"),
|
| 155 |
+
"repeat_loop_count": stats.get("repeat_loop_count"),
|
| 156 |
+
"char_repeat_count": stats.get("char_repeat_count", 0),
|
| 157 |
+
"p_greedy_first": logits_stats.get("p_greedy_first"),
|
| 158 |
+
"p_eos_first": logits_stats.get("p_eos_first"),
|
| 159 |
+
"p_answer_first": logits_stats.get("p_answer_first"),
|
| 160 |
+
"entropy_first": logits_stats.get("entropy_first"),
|
| 161 |
+
"degenerate_rate_format": stats.get("degenerate_rate_format"),
|
| 162 |
+
"degenerate_rate_repeat": stats.get("degenerate_rate_repeat"),
|
| 163 |
+
"format_without_thinking_rate": stats.get("format_without_thinking_rate"),
|
| 164 |
+
}
|
| 165 |
+
)
|
| 166 |
+
alerts = self._check_generate_alerts(step, stats, logits_stats)
|
| 167 |
+
alert_str = ",".join(alerts) if alerts else "none"
|
| 168 |
+
|
| 169 |
+
if self.log_on_generate and opsd_debug.should_log_health_on_generate():
|
| 170 |
+
opsd_debug.log_health(
|
| 171 |
+
"generate",
|
| 172 |
+
"batch health",
|
| 173 |
+
global_step=step,
|
| 174 |
+
degenerate_rate=stats.get("degenerate_rate"),
|
| 175 |
+
clipped_rate=stats.get("clipped_rate"),
|
| 176 |
+
eos_rate=stats.get("eos_terminated_rate"),
|
| 177 |
+
repeat_loop_count=stats.get("repeat_loop_count"),
|
| 178 |
+
char_repeat_count=stats.get("char_repeat_count", 0),
|
| 179 |
+
p_greedy=logits_stats.get("p_greedy_first"),
|
| 180 |
+
p_eos=logits_stats.get("p_eos_first"),
|
| 181 |
+
alerts=alert_str,
|
| 182 |
+
)
|
| 183 |
+
return alerts
|
| 184 |
+
|
| 185 |
+
def record_data(self, step: int, fields: dict[str, Any]) -> None:
|
| 186 |
+
if not self.enabled:
|
| 187 |
+
return
|
| 188 |
+
self._step_fields.update(fields)
|
| 189 |
+
vf_empty = _safe_float(fields.get("visual_fact_empty_rate"))
|
| 190 |
+
if vf_empty > 0.5:
|
| 191 |
+
self._emit_alert(
|
| 192 |
+
step,
|
| 193 |
+
ALERT_DATA_EMPTY_VF,
|
| 194 |
+
visual_fact_empty_rate=vf_empty,
|
| 195 |
+
hint="rebuild train_medium_vf_full.json with visual_fact hints",
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
def record_routing(self, step: int, fields: dict[str, Any]) -> None:
|
| 199 |
+
if not self.enabled:
|
| 200 |
+
return
|
| 201 |
+
self._step_fields.update(fields)
|
| 202 |
+
format_mean = _safe_float(fields.get("format_mean"))
|
| 203 |
+
acc_mean = _safe_float(fields.get("accuracy_mean"))
|
| 204 |
+
degenerate_rate = _safe_float(self._step_fields.get("degenerate_rate"))
|
| 205 |
+
if format_mean > 0.7 and acc_mean < 0.05 and degenerate_rate > 0.4:
|
| 206 |
+
self._emit_alert(
|
| 207 |
+
step,
|
| 208 |
+
ALERT_REWARD_FORMAT_HACK,
|
| 209 |
+
format_mean=format_mean,
|
| 210 |
+
accuracy_mean=acc_mean,
|
| 211 |
+
degenerate_rate=degenerate_rate,
|
| 212 |
+
)
|
| 213 |
+
opsd_on_correct = _safe_float(fields.get("opsd_on_correct_rate"))
|
| 214 |
+
if opsd_on_correct > 0.01:
|
| 215 |
+
self._emit_alert(
|
| 216 |
+
step,
|
| 217 |
+
ALERT_OPSD_ON_CORRECT,
|
| 218 |
+
opsd_on_correct_rate=opsd_on_correct,
|
| 219 |
+
)
|
| 220 |
+
leakage_skip = int(fields.get("opsd_skipped_leakage", 0) or 0)
|
| 221 |
+
if leakage_skip > 0:
|
| 222 |
+
self._emit_alert(
|
| 223 |
+
step,
|
| 224 |
+
ALERT_OPSD_LEAKAGE_PATTERN,
|
| 225 |
+
opsd_skipped_leakage=leakage_skip,
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
def record_loss(self, step: int, fields: dict[str, Any]) -> None:
|
| 229 |
+
if not self.enabled:
|
| 230 |
+
return
|
| 231 |
+
self._step_fields.update(fields)
|
| 232 |
+
loss_val = fields.get("combined_loss_scalar", fields.get("grpo_loss_scalar"))
|
| 233 |
+
if loss_val is not None and not math.isfinite(_safe_float(loss_val, default=float("nan"))):
|
| 234 |
+
self._emit_alert(step, ALERT_OPT_NAN_INF, loss=loss_val)
|
| 235 |
+
|
| 236 |
+
adv_abs = _safe_float(fields.get("advantages_abs_mean"))
|
| 237 |
+
zero_grpo = _safe_float(fields.get("grpo_zero_loss_rate"))
|
| 238 |
+
if adv_abs < 1e-6 and zero_grpo > 0.8:
|
| 239 |
+
self._emit_alert(
|
| 240 |
+
step,
|
| 241 |
+
ALERT_RL_ZERO_SIGNAL,
|
| 242 |
+
advantages_abs_mean=adv_abs,
|
| 243 |
+
grpo_zero_loss_rate=zero_grpo,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
def record_optimizer(self, step: int, grad_norm: Optional[float], lr: Optional[float]) -> None:
|
| 247 |
+
if not self.enabled:
|
| 248 |
+
return
|
| 249 |
+
gn = _safe_float(grad_norm) if grad_norm is not None else None
|
| 250 |
+
if gn is not None:
|
| 251 |
+
self._step_fields["grad_norm"] = gn
|
| 252 |
+
hist = [h.get("grad_norm") for h in self._history if h.get("grad_norm") is not None]
|
| 253 |
+
if len(hist) >= 3:
|
| 254 |
+
mean, std = _rolling_mean_std([float(x) for x in hist])
|
| 255 |
+
if std > 1e-8 and gn > mean + 3 * std:
|
| 256 |
+
self._emit_alert(
|
| 257 |
+
step,
|
| 258 |
+
ALERT_OPT_GRAD_SPIKE,
|
| 259 |
+
grad_norm=gn,
|
| 260 |
+
rolling_mean=mean,
|
| 261 |
+
rolling_std=std,
|
| 262 |
+
)
|
| 263 |
+
if lr is not None:
|
| 264 |
+
self._step_fields["learning_rate"] = lr
|
| 265 |
+
|
| 266 |
+
def correlate(self) -> dict[str, Any]:
|
| 267 |
+
"""Cross-step deltas and root-cause hints from rolling history."""
|
| 268 |
+
hints: list[str] = []
|
| 269 |
+
out: dict[str, Any] = {"root_cause_hints": hints}
|
| 270 |
+
|
| 271 |
+
if len(self._history) < 2:
|
| 272 |
+
out["root_cause_hints"] = ["insufficient history for correlation"]
|
| 273 |
+
return out
|
| 274 |
+
|
| 275 |
+
prev = self._history[-1]
|
| 276 |
+
prev2 = self._history[-2] if len(self._history) >= 2 else prev
|
| 277 |
+
|
| 278 |
+
for key in ("grad_norm", "clipped_rate", "eos_terminated_rate", "p_greedy_first", "degenerate_rate"):
|
| 279 |
+
cur_v = self._step_fields.get(key)
|
| 280 |
+
old_v = prev.get(key)
|
| 281 |
+
if cur_v is not None and old_v is not None:
|
| 282 |
+
out[f"delta_{key}"] = _safe_float(cur_v) - _safe_float(old_v)
|
| 283 |
+
|
| 284 |
+
gn_prev = prev.get("grad_norm")
|
| 285 |
+
clip_cur = self._step_fields.get("clipped_rate")
|
| 286 |
+
if gn_prev is not None and clip_cur is not None and _safe_float(clip_cur) > 0.7:
|
| 287 |
+
hints.append("high clip rate may follow recent gradient update (check delta_grad_norm)")
|
| 288 |
+
|
| 289 |
+
p_prev = prev2.get("p_greedy_first")
|
| 290 |
+
p_cur = self._step_fields.get("p_greedy_first")
|
| 291 |
+
eos_prev = prev2.get("eos_terminated_rate")
|
| 292 |
+
eos_cur = self._step_fields.get("eos_terminated_rate")
|
| 293 |
+
if (
|
| 294 |
+
p_prev is not None
|
| 295 |
+
and p_cur is not None
|
| 296 |
+
and _safe_float(p_cur) > 0.99
|
| 297 |
+
and eos_prev is not None
|
| 298 |
+
and eos_cur is not None
|
| 299 |
+
and _safe_float(eos_prev) > 0.5
|
| 300 |
+
and _safe_float(eos_cur) < 0.2
|
| 301 |
+
):
|
| 302 |
+
hints.append("after gradient step: p_greedy rose to ~1.0 and eos_rate collapsed")
|
| 303 |
+
|
| 304 |
+
if ALERT_RL_ZERO_SIGNAL in self._step_alerts and ALERT_GEN_REPEAT_DEGEN in self._step_alerts:
|
| 305 |
+
hints.append("RL zero signal co-occurs with repetition degeneration")
|
| 306 |
+
|
| 307 |
+
if not hints:
|
| 308 |
+
hints.append("none")
|
| 309 |
+
out["root_cause_hints"] = hints
|
| 310 |
+
return out
|
| 311 |
+
|
| 312 |
+
def maybe_log_detail_bundle(self, step: int) -> None:
|
| 313 |
+
if not self.enabled or not self.log_detail_bundle:
|
| 314 |
+
return
|
| 315 |
+
if not opsd_debug.should_log_detail(step):
|
| 316 |
+
return
|
| 317 |
+
opsd_debug.log_health_detail_banner(step, "TRAINING HEALTH BUNDLE")
|
| 318 |
+
corr = self.correlate()
|
| 319 |
+
hist_keys = (
|
| 320 |
+
"degenerate_rate",
|
| 321 |
+
"clipped_rate",
|
| 322 |
+
"eos_terminated_rate",
|
| 323 |
+
"grad_norm",
|
| 324 |
+
"p_greedy_first",
|
| 325 |
+
"grpo_zero_loss_rate",
|
| 326 |
+
"sft_replaced_ratio",
|
| 327 |
+
)
|
| 328 |
+
rolling: dict[str, Any] = {}
|
| 329 |
+
for key in hist_keys:
|
| 330 |
+
vals = [_safe_float(h[key]) for h in self._history if h.get(key) is not None]
|
| 331 |
+
if vals:
|
| 332 |
+
mean, std = _rolling_mean_std(vals)
|
| 333 |
+
rolling[f"{key}_mean"] = mean
|
| 334 |
+
rolling[f"{key}_std"] = std
|
| 335 |
+
|
| 336 |
+
snapshot_fields = {k: v for k, v in self._step_fields.items() if k != "global_step"}
|
| 337 |
+
opsd_debug.log_health_detail(
|
| 338 |
+
"health",
|
| 339 |
+
"step snapshot",
|
| 340 |
+
global_step=step,
|
| 341 |
+
alerts=self._step_alerts or ["none"],
|
| 342 |
+
**snapshot_fields,
|
| 343 |
+
**rolling,
|
| 344 |
+
)
|
| 345 |
+
opsd_debug.log_health_detail(
|
| 346 |
+
"health",
|
| 347 |
+
"cross-step correlation",
|
| 348 |
+
global_step=step,
|
| 349 |
+
**corr,
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
def finish_step(self, step: int) -> dict[str, float]:
|
| 353 |
+
"""L2 step summary log + metrics keys for Trainer.log()."""
|
| 354 |
+
snapshot = dict(self._step_fields)
|
| 355 |
+
snapshot["alert_count"] = len(self._step_alerts)
|
| 356 |
+
snapshot["alerts"] = list(self._step_alerts)
|
| 357 |
+
self._history.append(snapshot)
|
| 358 |
+
|
| 359 |
+
if self.log_every_step and opsd_debug.should_log_health_every_step():
|
| 360 |
+
corr = self.correlate()
|
| 361 |
+
opsd_debug.log_health(
|
| 362 |
+
"step",
|
| 363 |
+
"step summary",
|
| 364 |
+
global_step=step,
|
| 365 |
+
grad_norm=snapshot.get("grad_norm"),
|
| 366 |
+
lr=snapshot.get("learning_rate"),
|
| 367 |
+
sft_replaced_ratio=snapshot.get("sft_replaced_ratio"),
|
| 368 |
+
grpo_zero_loss_rate=snapshot.get("grpo_zero_loss_rate"),
|
| 369 |
+
degenerate_rate=snapshot.get("degenerate_rate"),
|
| 370 |
+
clipped_rate=snapshot.get("clipped_rate"),
|
| 371 |
+
eos_rate=snapshot.get("eos_terminated_rate"),
|
| 372 |
+
alert_count=len(self._step_alerts),
|
| 373 |
+
hints=corr.get("root_cause_hints"),
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
self.maybe_log_detail_bundle(step)
|
| 377 |
+
|
| 378 |
+
if not self.metrics_every_step:
|
| 379 |
+
return {}
|
| 380 |
+
|
| 381 |
+
metrics: dict[str, float] = {}
|
| 382 |
+
mapping = {
|
| 383 |
+
"completions/degenerate_rate": "degenerate_rate",
|
| 384 |
+
"completions/eos_rate": "eos_terminated_rate",
|
| 385 |
+
"completions/repeat_loop_count": "repeat_loop_count",
|
| 386 |
+
"routing/sft_replaced_ratio": "sft_replaced_ratio",
|
| 387 |
+
"routing/opsd_skipped_degenerate": "opsd_skipped_degenerate",
|
| 388 |
+
"routing/opsd_skipped_leakage": "opsd_skipped_leakage",
|
| 389 |
+
"routing/opsd_on_correct_rate": "opsd_on_correct_rate",
|
| 390 |
+
"routing/grpo_on_correct_rate": "grpo_on_correct_rate",
|
| 391 |
+
"routing/opd_teacher_call_rate": "opd_teacher_call_rate",
|
| 392 |
+
"teacher/privileged_suffix_has_gold_rate": "privileged_suffix_has_gold_rate",
|
| 393 |
+
"teacher/visual_fact_empty_rate": "visual_fact_empty_rate",
|
| 394 |
+
"teacher/suffix_len_mean": "teacher_suffix_len_mean",
|
| 395 |
+
"signal/grpo_zero_loss_rate": "grpo_zero_loss_rate",
|
| 396 |
+
"signal/advantage_abs_mean": "advantages_abs_mean",
|
| 397 |
+
"logits/p_greedy_first": "p_greedy_first",
|
| 398 |
+
"logits/p_eos_first": "p_eos_first",
|
| 399 |
+
"logits/p_answer_first": "p_answer_first",
|
| 400 |
+
"completions/degenerate_rate_format": "degenerate_rate_format",
|
| 401 |
+
"completions/degenerate_rate_repeat": "degenerate_rate_repeat",
|
| 402 |
+
"metrics/format_without_thinking_rate": "format_without_thinking_rate",
|
| 403 |
+
"phase/sft_cold_start": "phase_sft_cold_start",
|
| 404 |
+
"health/alert_count": "alert_count",
|
| 405 |
+
}
|
| 406 |
+
for metric_key, field_key in mapping.items():
|
| 407 |
+
val = snapshot.get(field_key)
|
| 408 |
+
if val is not None:
|
| 409 |
+
metrics[metric_key] = _safe_float(val)
|
| 410 |
+
return metrics
|
opsd_utils/privileged/__pycache__/providers.cpython-310.pyc
ADDED
|
Binary file (5.43 kB). View file
|
|
|
opsd_utils/privileged/image_utils.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Image loading and crop utilities for privileged teacher dual-image forward."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import Any, Optional
|
| 5 |
+
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
from data_utils.paths import resolve_image_path
|
| 9 |
+
from data_utils.privileged_schema import resolve_crop_bbox
|
| 10 |
+
from opsd_utils import debug_log as opsd_debug
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def load_rgb(image: Any) -> Optional[Image.Image]:
|
| 14 |
+
"""Load sample image as RGB PIL from path or in-memory object."""
|
| 15 |
+
if image is None:
|
| 16 |
+
return None
|
| 17 |
+
if isinstance(image, Image.Image):
|
| 18 |
+
return image.convert("RGB") if image.mode != "RGB" else image
|
| 19 |
+
if isinstance(image, str):
|
| 20 |
+
path = resolve_image_path(image)
|
| 21 |
+
try:
|
| 22 |
+
img = Image.open(path)
|
| 23 |
+
return img.convert("RGB")
|
| 24 |
+
except (FileNotFoundError, OSError):
|
| 25 |
+
opsd_debug.log("privileged_image", "load_rgb failed", path=path)
|
| 26 |
+
return None
|
| 27 |
+
return None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def center_crop(img: Image.Image, margin_ratio: float = 0.25) -> Image.Image:
|
| 31 |
+
w, h = img.size
|
| 32 |
+
margin_w = int(w * margin_ratio)
|
| 33 |
+
margin_h = int(h * margin_ratio)
|
| 34 |
+
return img.crop((margin_w, margin_h, w - margin_w, h - margin_h))
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def crop_image(
|
| 38 |
+
img: Image.Image,
|
| 39 |
+
bbox_norm: Optional[list[float]] = None,
|
| 40 |
+
strategy: str = "center",
|
| 41 |
+
margin_ratio: float = 0.25,
|
| 42 |
+
fallback_reason: Optional[str] = None,
|
| 43 |
+
) -> tuple[Image.Image, str]:
|
| 44 |
+
"""
|
| 45 |
+
Crop image using C2 normalized bbox or center fallback.
|
| 46 |
+
Returns (cropped_image, crop_strategy_used).
|
| 47 |
+
"""
|
| 48 |
+
if bbox_norm is not None and strategy in ("bbox", "heuristic", "bbox_then_center"):
|
| 49 |
+
w, h = img.size
|
| 50 |
+
x0 = int(bbox_norm[0] * w)
|
| 51 |
+
y0 = int(bbox_norm[1] * h)
|
| 52 |
+
x1 = int(bbox_norm[2] * w)
|
| 53 |
+
y1 = int(bbox_norm[3] * h)
|
| 54 |
+
x0, x1 = max(0, min(x0, w - 1)), max(1, min(x1, w))
|
| 55 |
+
y0, y1 = max(0, min(y0, h - 1)), max(1, min(y1, h))
|
| 56 |
+
if x1 > x0 and y1 > y0:
|
| 57 |
+
used = strategy if strategy != "bbox_then_center" else "bbox"
|
| 58 |
+
opsd_debug.log(
|
| 59 |
+
"privileged_image",
|
| 60 |
+
"crop_image bbox",
|
| 61 |
+
strategy=used,
|
| 62 |
+
bbox_norm=bbox_norm,
|
| 63 |
+
crop_px=(x0, y0, x1, y1),
|
| 64 |
+
fallback_reason=fallback_reason,
|
| 65 |
+
)
|
| 66 |
+
return img.crop((x0, y0, x1, y1)), used
|
| 67 |
+
|
| 68 |
+
crop = center_crop(img, margin_ratio=margin_ratio)
|
| 69 |
+
used = "center_fallback" if fallback_reason else "center"
|
| 70 |
+
opsd_debug.log(
|
| 71 |
+
"privileged_image",
|
| 72 |
+
"crop_image center",
|
| 73 |
+
strategy=used,
|
| 74 |
+
bbox_norm=bbox_norm,
|
| 75 |
+
margin_ratio=margin_ratio,
|
| 76 |
+
fallback_reason=fallback_reason,
|
| 77 |
+
)
|
| 78 |
+
return crop, used
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def heuristic_crop_from_visual_fact(
|
| 82 |
+
img: Image.Image,
|
| 83 |
+
sample: dict[str, Any],
|
| 84 |
+
crop_cfg: Optional[dict[str, Any]] = None,
|
| 85 |
+
) -> tuple[Image.Image, str, Optional[list[float]]]:
|
| 86 |
+
"""D2 with D1 fallback: heuristic bbox from visual_fact, else center crop."""
|
| 87 |
+
crop_cfg = crop_cfg or {}
|
| 88 |
+
margin_ratio = float(crop_cfg.get("margin_ratio", 0.25))
|
| 89 |
+
bbox_norm, strategy = resolve_crop_bbox(sample, crop_cfg)
|
| 90 |
+
fallback_reason = None
|
| 91 |
+
if strategy == "center" and sample.get("visual_fact"):
|
| 92 |
+
fallback_reason = "heuristic_failed"
|
| 93 |
+
crop, used = crop_image(
|
| 94 |
+
img,
|
| 95 |
+
bbox_norm=bbox_norm,
|
| 96 |
+
strategy=strategy if bbox_norm else "center",
|
| 97 |
+
margin_ratio=margin_ratio,
|
| 98 |
+
fallback_reason=fallback_reason,
|
| 99 |
+
)
|
| 100 |
+
return crop, used, bbox_norm
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def resolve_teacher_images(
|
| 104 |
+
sample: dict[str, Any],
|
| 105 |
+
profile: str,
|
| 106 |
+
crop_cfg: Optional[dict[str, Any]] = None,
|
| 107 |
+
) -> tuple[list[Image.Image], dict[str, Any]]:
|
| 108 |
+
"""
|
| 109 |
+
Build teacher image list for privileged forward.
|
| 110 |
+
text -> [full]; visual/hybrid + mode=dual -> [full, crop]; otherwise [full].
|
| 111 |
+
Returns (images, debug_meta).
|
| 112 |
+
"""
|
| 113 |
+
crop_cfg = crop_cfg or {}
|
| 114 |
+
image = sample.get("image")
|
| 115 |
+
if image is None:
|
| 116 |
+
return [], {"crop_strategy": "none", "num_teacher_images": 0, "has_bbox": False}
|
| 117 |
+
|
| 118 |
+
full = load_rgb(image)
|
| 119 |
+
if full is None:
|
| 120 |
+
return [], {"crop_strategy": "load_failed", "num_teacher_images": 0, "has_bbox": False}
|
| 121 |
+
|
| 122 |
+
image_mode = str(crop_cfg.get("mode", "single")).strip().lower()
|
| 123 |
+
if profile == "text" or image_mode in ("single", "full", "off", "none"):
|
| 124 |
+
meta = {
|
| 125 |
+
"crop_strategy": "single_full",
|
| 126 |
+
"num_teacher_images": 1,
|
| 127 |
+
"has_bbox": False,
|
| 128 |
+
"bbox_norm": None,
|
| 129 |
+
"image_mode": image_mode,
|
| 130 |
+
}
|
| 131 |
+
return [full], meta
|
| 132 |
+
|
| 133 |
+
crop, crop_strategy, bbox_norm = heuristic_crop_from_visual_fact(full, sample, crop_cfg)
|
| 134 |
+
meta = {
|
| 135 |
+
"crop_strategy": crop_strategy,
|
| 136 |
+
"num_teacher_images": 2,
|
| 137 |
+
"has_bbox": bbox_norm is not None,
|
| 138 |
+
"bbox_norm": bbox_norm,
|
| 139 |
+
"full_size": full.size,
|
| 140 |
+
"crop_size": crop.size,
|
| 141 |
+
"image_mode": image_mode,
|
| 142 |
+
}
|
| 143 |
+
return [full, crop], meta
|
opsd_utils/prompt_builder.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Any, Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
from opsd_utils import debug_log as opsd_debug
|
| 8 |
+
from opsd_utils.privileged import build_privileged_context, maybe_save_privileged_images
|
| 9 |
+
from opsd_utils.teacher_batching import (
|
| 10 |
+
count_image_tokens,
|
| 11 |
+
process_teacher_sample,
|
| 12 |
+
stack_teacher_processor_batches,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _build_teacher_text(student_prompt: str, privileged_suffix: str) -> str:
|
| 17 |
+
teacher_text = student_prompt
|
| 18 |
+
if privileged_suffix.strip():
|
| 19 |
+
teacher_text = f"{student_prompt}\n\n{privileged_suffix.strip()}"
|
| 20 |
+
return teacher_text
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def tokenize_teacher_prompt(
|
| 24 |
+
processor,
|
| 25 |
+
student_prompt: str,
|
| 26 |
+
privileged_suffix: str,
|
| 27 |
+
images: Any,
|
| 28 |
+
) -> dict:
|
| 29 |
+
"""Tokenize teacher multimodal prompt = student question + privileged suffix + N images."""
|
| 30 |
+
if isinstance(images, list):
|
| 31 |
+
pil_images = [img for img in images if isinstance(img, Image.Image)]
|
| 32 |
+
else:
|
| 33 |
+
from opsd_utils.privileged.image_utils import load_rgb
|
| 34 |
+
|
| 35 |
+
one = load_rgb(images)
|
| 36 |
+
pil_images = [one] if one is not None else []
|
| 37 |
+
|
| 38 |
+
teacher_text = _build_teacher_text(student_prompt, privileged_suffix)
|
| 39 |
+
|
| 40 |
+
opsd_debug.log(
|
| 41 |
+
"teacher_prompt",
|
| 42 |
+
"tokenize_teacher_prompt",
|
| 43 |
+
num_images=len(pil_images),
|
| 44 |
+
suffix_len=len(privileged_suffix.strip()),
|
| 45 |
+
teacher_text_len=len(teacher_text),
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
batch = process_teacher_sample(processor, teacher_text, pil_images)
|
| 49 |
+
|
| 50 |
+
opsd_debug.log(
|
| 51 |
+
"teacher_prompt",
|
| 52 |
+
"tokenize_teacher_prompt result",
|
| 53 |
+
input_ids_shape=tuple(batch["input_ids"].shape),
|
| 54 |
+
has_pixel_values="pixel_values" in batch,
|
| 55 |
+
pixel_values_shape=tuple(batch["pixel_values"].shape) if "pixel_values" in batch else None,
|
| 56 |
+
image_token_count=count_image_tokens(batch["input_ids"], processor),
|
| 57 |
+
)
|
| 58 |
+
return batch
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def build_teacher_prompt_batch(
|
| 62 |
+
processor,
|
| 63 |
+
samples: list[dict[str, Any]],
|
| 64 |
+
indices: list[int],
|
| 65 |
+
provider_names: list[str],
|
| 66 |
+
device,
|
| 67 |
+
*,
|
| 68 |
+
opsd_config: Optional[dict[str, Any]] = None,
|
| 69 |
+
global_step: Optional[int] = None,
|
| 70 |
+
output_dir: Optional[str] = None,
|
| 71 |
+
) -> dict[str, Any]:
|
| 72 |
+
"""Build padded teacher prompt tensors for OPSD samples at given indices."""
|
| 73 |
+
opsd_config = opsd_config or {}
|
| 74 |
+
privileged_profile = opsd_config.get("privileged_profile", "hybrid")
|
| 75 |
+
crop_cfg = opsd_config.get("privileged_image") or {}
|
| 76 |
+
privileged_debug_cfg = opsd_config.get("privileged_debug") or {}
|
| 77 |
+
|
| 78 |
+
opsd_debug.log(
|
| 79 |
+
"teacher_prompt",
|
| 80 |
+
"build_teacher_prompt_batch enter",
|
| 81 |
+
num_indices=len(indices),
|
| 82 |
+
indices=indices,
|
| 83 |
+
num_samples=len(samples),
|
| 84 |
+
provider_names=provider_names,
|
| 85 |
+
privileged_profile=privileged_profile,
|
| 86 |
+
device=str(device),
|
| 87 |
+
global_step=global_step,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
if not indices:
|
| 91 |
+
opsd_debug.log("teacher_prompt", "empty indices, return {}")
|
| 92 |
+
return {}
|
| 93 |
+
|
| 94 |
+
sample_payloads: list[dict[str, Any]] = []
|
| 95 |
+
for idx in indices:
|
| 96 |
+
sample = samples[idx]
|
| 97 |
+
suffix, teacher_images = build_privileged_context(
|
| 98 |
+
sample,
|
| 99 |
+
provider_names,
|
| 100 |
+
privileged_profile=privileged_profile,
|
| 101 |
+
crop_cfg=crop_cfg,
|
| 102 |
+
opsd_config=opsd_config,
|
| 103 |
+
)
|
| 104 |
+
if not teacher_images:
|
| 105 |
+
from opsd_utils.privileged.image_utils import load_rgb
|
| 106 |
+
|
| 107 |
+
full = load_rgb(sample.get("image"))
|
| 108 |
+
teacher_images = [full] if full is not None else []
|
| 109 |
+
|
| 110 |
+
full_img = teacher_images[0] if teacher_images else None
|
| 111 |
+
crop_img = teacher_images[1] if len(teacher_images) > 1 else None
|
| 112 |
+
maybe_save_privileged_images(
|
| 113 |
+
global_step,
|
| 114 |
+
idx,
|
| 115 |
+
full_img,
|
| 116 |
+
crop_img,
|
| 117 |
+
meta={
|
| 118 |
+
"privileged_profile": privileged_profile,
|
| 119 |
+
"num_teacher_images": len(teacher_images),
|
| 120 |
+
"suffix_len": len(suffix.strip()),
|
| 121 |
+
},
|
| 122 |
+
output_dir=output_dir,
|
| 123 |
+
privileged_debug_cfg=privileged_debug_cfg,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
teacher_text = _build_teacher_text(sample["prompt"], suffix)
|
| 127 |
+
sample_payloads.append(
|
| 128 |
+
{
|
| 129 |
+
"teacher_text": teacher_text,
|
| 130 |
+
"images": teacher_images,
|
| 131 |
+
"suffix_len": len(suffix.strip()),
|
| 132 |
+
"num_teacher_images": len(teacher_images),
|
| 133 |
+
}
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
batch = _build_teacher_batch_with_oom_retry(processor, sample_payloads)
|
| 137 |
+
|
| 138 |
+
out = {
|
| 139 |
+
"teacher_prompt_ids": batch["input_ids"].to(device),
|
| 140 |
+
"teacher_prompt_mask": batch["attention_mask"].to(device),
|
| 141 |
+
}
|
| 142 |
+
if batch.get("pixel_values_list"):
|
| 143 |
+
out["teacher_pixel_values_list"] = [pv.to(device) for pv in batch["pixel_values_list"]]
|
| 144 |
+
if batch.get("image_sizes_list"):
|
| 145 |
+
out["teacher_image_sizes_list"] = [sz.to(device) for sz in batch["image_sizes_list"]]
|
| 146 |
+
|
| 147 |
+
teacher_num_images = [int(max(0, n)) for n in batch.get("batch_num_images", [])]
|
| 148 |
+
if not teacher_num_images:
|
| 149 |
+
teacher_num_images = [p["num_teacher_images"] for p in sample_payloads]
|
| 150 |
+
out["teacher_num_images"] = torch.tensor(teacher_num_images, device=device, dtype=torch.long)
|
| 151 |
+
|
| 152 |
+
student_len = None
|
| 153 |
+
if indices and samples[indices[0]].get("prompt"):
|
| 154 |
+
student_messages = [
|
| 155 |
+
{
|
| 156 |
+
"role": "user",
|
| 157 |
+
"content": [{"type": "image"}, {"type": "text", "text": samples[indices[0]]["prompt"]}],
|
| 158 |
+
}
|
| 159 |
+
]
|
| 160 |
+
student_text = processor.apply_chat_template(student_messages, add_generation_prompt=True)
|
| 161 |
+
student_len = len(processor(text=[student_text], return_tensors="pt")["input_ids"][0])
|
| 162 |
+
|
| 163 |
+
teacher_len = int(out["teacher_prompt_ids"].shape[1])
|
| 164 |
+
opsd_debug.log(
|
| 165 |
+
"teacher_prompt",
|
| 166 |
+
"build_teacher_prompt_batch done",
|
| 167 |
+
teacher_prompt_ids_shape=tuple(out["teacher_prompt_ids"].shape),
|
| 168 |
+
teacher_prompt_mask_shape=tuple(out["teacher_prompt_mask"].shape),
|
| 169 |
+
has_teacher_pixel_values=bool(out.get("teacher_pixel_values_list")),
|
| 170 |
+
teacher_pixel_values_shapes=[
|
| 171 |
+
tuple(pv.shape) for pv in out.get("teacher_pixel_values_list", [])[:4]
|
| 172 |
+
],
|
| 173 |
+
teacher_images_count=sample_payloads[0]["num_teacher_images"] if sample_payloads else 0,
|
| 174 |
+
teacher_num_images=teacher_num_images,
|
| 175 |
+
teacher_image_token_counts=batch.get("image_token_counts"),
|
| 176 |
+
teacher_prompt_len=teacher_len,
|
| 177 |
+
vision_placeholder_delta=(teacher_len - student_len) if student_len else None,
|
| 178 |
+
)
|
| 179 |
+
opsd_debug.log_detail(
|
| 180 |
+
"teacher_prompt",
|
| 181 |
+
"teacher prompt batch built",
|
| 182 |
+
global_step=global_step,
|
| 183 |
+
batch_size=len(indices),
|
| 184 |
+
teacher_prompt_len=teacher_len,
|
| 185 |
+
teacher_pixel_values_shapes=[
|
| 186 |
+
tuple(pv.shape) for pv in out.get("teacher_pixel_values_list", [])[:4]
|
| 187 |
+
],
|
| 188 |
+
teacher_image_token_counts=batch.get("image_token_counts"),
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
from opsd_utils.leakage import privileged_suffix_has_gold
|
| 192 |
+
|
| 193 |
+
vf_empty = 0
|
| 194 |
+
gold_suffix_count = 0
|
| 195 |
+
for idx in indices:
|
| 196 |
+
sample = samples[idx]
|
| 197 |
+
vf = (
|
| 198 |
+
sample.get("visual_fact_hint")
|
| 199 |
+
or sample.get("visual_fact")
|
| 200 |
+
or sample.get("visual_facts")
|
| 201 |
+
or ""
|
| 202 |
+
)
|
| 203 |
+
if not str(vf).strip():
|
| 204 |
+
vf_empty += 1
|
| 205 |
+
priv_suffix, _ = build_privileged_context(
|
| 206 |
+
sample,
|
| 207 |
+
provider_names,
|
| 208 |
+
privileged_profile=privileged_profile,
|
| 209 |
+
crop_cfg=crop_cfg,
|
| 210 |
+
opsd_config=opsd_config,
|
| 211 |
+
)
|
| 212 |
+
if privileged_suffix_has_gold(priv_suffix, sample):
|
| 213 |
+
gold_suffix_count += 1
|
| 214 |
+
suffix_lens = [p["suffix_len"] for p in sample_payloads]
|
| 215 |
+
n_idx = max(len(indices), 1)
|
| 216 |
+
out["teacher_stats"] = {
|
| 217 |
+
"teacher_suffix_len_mean": float(sum(suffix_lens) / len(suffix_lens)) if suffix_lens else 0.0,
|
| 218 |
+
"visual_fact_empty_rate": vf_empty / n_idx,
|
| 219 |
+
"privileged_suffix_has_gold_rate": gold_suffix_count / n_idx,
|
| 220 |
+
"num_teacher_images_mean": float(
|
| 221 |
+
sum(p["num_teacher_images"] for p in sample_payloads) / len(sample_payloads)
|
| 222 |
+
)
|
| 223 |
+
if sample_payloads
|
| 224 |
+
else 0.0,
|
| 225 |
+
}
|
| 226 |
+
return out
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def _build_teacher_batch_with_oom_retry(
|
| 230 |
+
processor,
|
| 231 |
+
sample_payloads: list[dict[str, Any]],
|
| 232 |
+
) -> dict:
|
| 233 |
+
"""Process each teacher sample separately; on OOM halve micro-batch and retry."""
|
| 234 |
+
n = len(sample_payloads)
|
| 235 |
+
if n == 0:
|
| 236 |
+
return {}
|
| 237 |
+
micro = n
|
| 238 |
+
while micro >= 1:
|
| 239 |
+
try:
|
| 240 |
+
per_sample_batches: list[dict[str, Any]] = []
|
| 241 |
+
for start in range(0, n, micro):
|
| 242 |
+
end = min(start + micro, n)
|
| 243 |
+
for payload in sample_payloads[start:end]:
|
| 244 |
+
per_sample_batches.append(
|
| 245 |
+
process_teacher_sample(
|
| 246 |
+
processor,
|
| 247 |
+
payload["teacher_text"],
|
| 248 |
+
payload["images"],
|
| 249 |
+
)
|
| 250 |
+
)
|
| 251 |
+
return stack_teacher_processor_batches(processor, per_sample_batches)
|
| 252 |
+
except RuntimeError as exc:
|
| 253 |
+
if "out of memory" not in str(exc).lower() or micro == 1:
|
| 254 |
+
raise
|
| 255 |
+
opsd_debug.log(
|
| 256 |
+
"teacher_forward_oom",
|
| 257 |
+
"teacher prompt batch OOM, halving micro-batch",
|
| 258 |
+
original_batch=n,
|
| 259 |
+
micro_batch_size=micro,
|
| 260 |
+
new_micro_batch_size=max(1, micro // 2),
|
| 261 |
+
)
|
| 262 |
+
if torch.cuda.is_available():
|
| 263 |
+
torch.cuda.empty_cache()
|
| 264 |
+
micro = max(1, micro // 2)
|
| 265 |
+
return {}
|
outputs/logs/.ipynb_checkpoints/train_opd_7b_ds_20260614_175014-checkpoint.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
outputs/opd-7b-chartqa-ds/checkpoint-1764/zero_to_fp32.py
ADDED
|
@@ -0,0 +1,760 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright (c) Microsoft Corporation.
|
| 4 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 5 |
+
|
| 6 |
+
# DeepSpeed Team
|
| 7 |
+
|
| 8 |
+
# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
|
| 9 |
+
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
|
| 10 |
+
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
|
| 11 |
+
# application.
|
| 12 |
+
#
|
| 13 |
+
# example:
|
| 14 |
+
# python zero_to_fp32.py . output_dir/
|
| 15 |
+
# or
|
| 16 |
+
# python zero_to_fp32.py . output_dir/ --safe_serialization
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import torch
|
| 20 |
+
import glob
|
| 21 |
+
import math
|
| 22 |
+
import os
|
| 23 |
+
import re
|
| 24 |
+
import gc
|
| 25 |
+
import json
|
| 26 |
+
import numpy as np
|
| 27 |
+
from tqdm import tqdm
|
| 28 |
+
from collections import OrderedDict
|
| 29 |
+
from dataclasses import dataclass
|
| 30 |
+
|
| 31 |
+
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
|
| 32 |
+
# DeepSpeed data structures it has to be available in the current python environment.
|
| 33 |
+
from deepspeed.utils import logger
|
| 34 |
+
from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
|
| 35 |
+
FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
|
| 36 |
+
FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass
|
| 40 |
+
class zero_model_state:
|
| 41 |
+
buffers: dict()
|
| 42 |
+
param_shapes: dict()
|
| 43 |
+
shared_params: list
|
| 44 |
+
ds_version: int
|
| 45 |
+
frozen_param_shapes: dict()
|
| 46 |
+
frozen_param_fragments: dict()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
debug = 0
|
| 50 |
+
|
| 51 |
+
# load to cpu
|
| 52 |
+
device = torch.device('cpu')
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def atoi(text):
|
| 56 |
+
return int(text) if text.isdigit() else text
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def natural_keys(text):
|
| 60 |
+
'''
|
| 61 |
+
alist.sort(key=natural_keys) sorts in human order
|
| 62 |
+
http://nedbatchelder.com/blog/200712/human_sorting.html
|
| 63 |
+
(See Toothy's implementation in the comments)
|
| 64 |
+
'''
|
| 65 |
+
return [atoi(c) for c in re.split(r'(\d+)', text)]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_model_state_file(checkpoint_dir, zero_stage):
|
| 69 |
+
if not os.path.isdir(checkpoint_dir):
|
| 70 |
+
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
|
| 71 |
+
|
| 72 |
+
# there should be only one file
|
| 73 |
+
if zero_stage <= 2:
|
| 74 |
+
file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
|
| 75 |
+
elif zero_stage == 3:
|
| 76 |
+
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
|
| 77 |
+
|
| 78 |
+
if not os.path.exists(file):
|
| 79 |
+
raise FileNotFoundError(f"can't find model states file at '{file}'")
|
| 80 |
+
|
| 81 |
+
return file
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_checkpoint_files(checkpoint_dir, glob_pattern):
|
| 85 |
+
# XXX: need to test that this simple glob rule works for multi-node setup too
|
| 86 |
+
ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
|
| 87 |
+
|
| 88 |
+
if len(ckpt_files) == 0:
|
| 89 |
+
raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
|
| 90 |
+
|
| 91 |
+
return ckpt_files
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def get_optim_files(checkpoint_dir):
|
| 95 |
+
return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_model_state_files(checkpoint_dir):
|
| 99 |
+
return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def parse_model_states(files):
|
| 103 |
+
zero_model_states = []
|
| 104 |
+
for file in files:
|
| 105 |
+
state_dict = torch.load(file, map_location=device, weights_only=False)
|
| 106 |
+
|
| 107 |
+
if BUFFER_NAMES not in state_dict:
|
| 108 |
+
raise ValueError(f"{file} is not a model state checkpoint")
|
| 109 |
+
buffer_names = state_dict[BUFFER_NAMES]
|
| 110 |
+
if debug:
|
| 111 |
+
print("Found buffers:", buffer_names)
|
| 112 |
+
|
| 113 |
+
# recover just the buffers while restoring them to fp32 if they were saved in fp16
|
| 114 |
+
buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
|
| 115 |
+
param_shapes = state_dict[PARAM_SHAPES]
|
| 116 |
+
|
| 117 |
+
# collect parameters that are included in param_shapes
|
| 118 |
+
param_names = []
|
| 119 |
+
for s in param_shapes:
|
| 120 |
+
for name in s.keys():
|
| 121 |
+
param_names.append(name)
|
| 122 |
+
|
| 123 |
+
# update with frozen parameters
|
| 124 |
+
frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
|
| 125 |
+
if frozen_param_shapes is not None:
|
| 126 |
+
if debug:
|
| 127 |
+
print(f"Found frozen_param_shapes: {frozen_param_shapes}")
|
| 128 |
+
param_names += list(frozen_param_shapes.keys())
|
| 129 |
+
|
| 130 |
+
# handle shared params
|
| 131 |
+
shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
|
| 132 |
+
|
| 133 |
+
ds_version = state_dict.get(DS_VERSION, None)
|
| 134 |
+
|
| 135 |
+
frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
|
| 136 |
+
|
| 137 |
+
z_model_state = zero_model_state(buffers=buffers,
|
| 138 |
+
param_shapes=param_shapes,
|
| 139 |
+
shared_params=shared_params,
|
| 140 |
+
ds_version=ds_version,
|
| 141 |
+
frozen_param_shapes=frozen_param_shapes,
|
| 142 |
+
frozen_param_fragments=frozen_param_fragments)
|
| 143 |
+
zero_model_states.append(z_model_state)
|
| 144 |
+
|
| 145 |
+
return zero_model_states
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def parse_optim_states(files, ds_checkpoint_dir):
|
| 149 |
+
total_files = len(files)
|
| 150 |
+
state_dicts = []
|
| 151 |
+
for f in tqdm(files, desc='Loading checkpoint shards'):
|
| 152 |
+
state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
|
| 153 |
+
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
|
| 154 |
+
# and also handle the case where it was already removed by another helper script
|
| 155 |
+
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
|
| 156 |
+
state_dicts.append(state_dict)
|
| 157 |
+
|
| 158 |
+
if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
|
| 159 |
+
raise ValueError(f"{files[0]} is not a zero checkpoint")
|
| 160 |
+
zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
|
| 161 |
+
world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
|
| 162 |
+
|
| 163 |
+
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert
|
| 164 |
+
# parameters can be different from data parallelism for non-expert parameters. So we can just
|
| 165 |
+
# use the max of the partition_count to get the dp world_size.
|
| 166 |
+
|
| 167 |
+
if type(world_size) is list:
|
| 168 |
+
world_size = max(world_size)
|
| 169 |
+
|
| 170 |
+
if world_size != total_files:
|
| 171 |
+
raise ValueError(
|
| 172 |
+
f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
|
| 173 |
+
"Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# the groups are named differently in each stage
|
| 177 |
+
if zero_stage <= 2:
|
| 178 |
+
fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
|
| 179 |
+
elif zero_stage == 3:
|
| 180 |
+
fp32_groups_key = FP32_FLAT_GROUPS
|
| 181 |
+
else:
|
| 182 |
+
raise ValueError(f"unknown zero stage {zero_stage}")
|
| 183 |
+
|
| 184 |
+
fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
|
| 185 |
+
return zero_stage, world_size, fp32_flat_groups
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
|
| 189 |
+
"""
|
| 190 |
+
Returns fp32 state_dict reconstructed from ds checkpoint
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
- ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
|
| 194 |
+
|
| 195 |
+
"""
|
| 196 |
+
print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
|
| 197 |
+
|
| 198 |
+
optim_files = get_optim_files(ds_checkpoint_dir)
|
| 199 |
+
zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
|
| 200 |
+
print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
|
| 201 |
+
|
| 202 |
+
model_files = get_model_state_files(ds_checkpoint_dir)
|
| 203 |
+
|
| 204 |
+
zero_model_states = parse_model_states(model_files)
|
| 205 |
+
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
|
| 206 |
+
|
| 207 |
+
if zero_stage <= 2:
|
| 208 |
+
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
| 209 |
+
exclude_frozen_parameters)
|
| 210 |
+
elif zero_stage == 3:
|
| 211 |
+
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
| 212 |
+
exclude_frozen_parameters)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _zero2_merge_frozen_params(state_dict, zero_model_states):
|
| 216 |
+
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
|
| 217 |
+
return
|
| 218 |
+
|
| 219 |
+
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
| 220 |
+
frozen_param_fragments = zero_model_states[0].frozen_param_fragments
|
| 221 |
+
|
| 222 |
+
if debug:
|
| 223 |
+
num_elem = sum(s.numel() for s in frozen_param_shapes.values())
|
| 224 |
+
print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
|
| 225 |
+
|
| 226 |
+
wanted_params = len(frozen_param_shapes)
|
| 227 |
+
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
|
| 228 |
+
avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
|
| 229 |
+
print(f'Frozen params: Have {avail_numel} numels to process.')
|
| 230 |
+
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
|
| 231 |
+
|
| 232 |
+
total_params = 0
|
| 233 |
+
total_numel = 0
|
| 234 |
+
for name, shape in frozen_param_shapes.items():
|
| 235 |
+
total_params += 1
|
| 236 |
+
unpartitioned_numel = shape.numel()
|
| 237 |
+
total_numel += unpartitioned_numel
|
| 238 |
+
|
| 239 |
+
state_dict[name] = frozen_param_fragments[name]
|
| 240 |
+
|
| 241 |
+
if debug:
|
| 242 |
+
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
|
| 243 |
+
|
| 244 |
+
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def _has_callable(obj, fn):
|
| 248 |
+
attr = getattr(obj, fn, None)
|
| 249 |
+
return callable(attr)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
|
| 253 |
+
param_shapes = zero_model_states[0].param_shapes
|
| 254 |
+
|
| 255 |
+
# Reconstruction protocol:
|
| 256 |
+
#
|
| 257 |
+
# XXX: document this
|
| 258 |
+
|
| 259 |
+
if debug:
|
| 260 |
+
for i in range(world_size):
|
| 261 |
+
for j in range(len(fp32_flat_groups[0])):
|
| 262 |
+
print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
|
| 263 |
+
|
| 264 |
+
# XXX: memory usage doubles here (zero2)
|
| 265 |
+
num_param_groups = len(fp32_flat_groups[0])
|
| 266 |
+
merged_single_partition_of_fp32_groups = []
|
| 267 |
+
for i in range(num_param_groups):
|
| 268 |
+
merged_partitions = [sd[i] for sd in fp32_flat_groups]
|
| 269 |
+
full_single_fp32_vector = torch.cat(merged_partitions, 0)
|
| 270 |
+
merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
|
| 271 |
+
avail_numel = sum(
|
| 272 |
+
[full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
|
| 273 |
+
|
| 274 |
+
if debug:
|
| 275 |
+
wanted_params = sum([len(shapes) for shapes in param_shapes])
|
| 276 |
+
wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
|
| 277 |
+
# not asserting if there is a mismatch due to possible padding
|
| 278 |
+
print(f"Have {avail_numel} numels to process.")
|
| 279 |
+
print(f"Need {wanted_numel} numels in {wanted_params} params.")
|
| 280 |
+
|
| 281 |
+
# params
|
| 282 |
+
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
|
| 283 |
+
# out-of-core computing solution
|
| 284 |
+
total_numel = 0
|
| 285 |
+
total_params = 0
|
| 286 |
+
for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
|
| 287 |
+
offset = 0
|
| 288 |
+
avail_numel = full_single_fp32_vector.numel()
|
| 289 |
+
for name, shape in shapes.items():
|
| 290 |
+
|
| 291 |
+
unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
|
| 292 |
+
total_numel += unpartitioned_numel
|
| 293 |
+
total_params += 1
|
| 294 |
+
|
| 295 |
+
if debug:
|
| 296 |
+
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
|
| 297 |
+
state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
|
| 298 |
+
offset += unpartitioned_numel
|
| 299 |
+
|
| 300 |
+
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
|
| 301 |
+
# avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
|
| 302 |
+
# paddings performed in the code it's almost impossible to predict the exact numbers w/o the
|
| 303 |
+
# live optimizer object, so we are checking that the numbers are within the right range
|
| 304 |
+
align_to = 2 * world_size
|
| 305 |
+
|
| 306 |
+
def zero2_align(x):
|
| 307 |
+
return align_to * math.ceil(x / align_to)
|
| 308 |
+
|
| 309 |
+
if debug:
|
| 310 |
+
print(f"original offset={offset}, avail_numel={avail_numel}")
|
| 311 |
+
|
| 312 |
+
offset = zero2_align(offset)
|
| 313 |
+
avail_numel = zero2_align(avail_numel)
|
| 314 |
+
|
| 315 |
+
if debug:
|
| 316 |
+
print(f"aligned offset={offset}, avail_numel={avail_numel}")
|
| 317 |
+
|
| 318 |
+
# Sanity check
|
| 319 |
+
if offset != avail_numel:
|
| 320 |
+
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
|
| 321 |
+
|
| 322 |
+
print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
| 326 |
+
exclude_frozen_parameters):
|
| 327 |
+
state_dict = OrderedDict()
|
| 328 |
+
|
| 329 |
+
# buffers
|
| 330 |
+
buffers = zero_model_states[0].buffers
|
| 331 |
+
state_dict.update(buffers)
|
| 332 |
+
if debug:
|
| 333 |
+
print(f"added {len(buffers)} buffers")
|
| 334 |
+
|
| 335 |
+
if not exclude_frozen_parameters:
|
| 336 |
+
_zero2_merge_frozen_params(state_dict, zero_model_states)
|
| 337 |
+
|
| 338 |
+
_zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
|
| 339 |
+
|
| 340 |
+
# recover shared parameters
|
| 341 |
+
for pair in zero_model_states[0].shared_params:
|
| 342 |
+
if pair[1] in state_dict:
|
| 343 |
+
state_dict[pair[0]] = state_dict[pair[1]]
|
| 344 |
+
|
| 345 |
+
return state_dict
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
|
| 349 |
+
remainder = unpartitioned_numel % world_size
|
| 350 |
+
padding_numel = (world_size - remainder) if remainder else 0
|
| 351 |
+
partitioned_numel = math.ceil(unpartitioned_numel / world_size)
|
| 352 |
+
return partitioned_numel, padding_numel
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
|
| 356 |
+
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
|
| 357 |
+
return
|
| 358 |
+
|
| 359 |
+
if debug:
|
| 360 |
+
for i in range(world_size):
|
| 361 |
+
num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
|
| 362 |
+
print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
|
| 363 |
+
|
| 364 |
+
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
| 365 |
+
wanted_params = len(frozen_param_shapes)
|
| 366 |
+
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
|
| 367 |
+
avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
|
| 368 |
+
print(f'Frozen params: Have {avail_numel} numels to process.')
|
| 369 |
+
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
|
| 370 |
+
|
| 371 |
+
total_params = 0
|
| 372 |
+
total_numel = 0
|
| 373 |
+
for name, shape in zero_model_states[0].frozen_param_shapes.items():
|
| 374 |
+
total_params += 1
|
| 375 |
+
unpartitioned_numel = shape.numel()
|
| 376 |
+
total_numel += unpartitioned_numel
|
| 377 |
+
|
| 378 |
+
param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
|
| 379 |
+
state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
|
| 380 |
+
|
| 381 |
+
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
|
| 382 |
+
|
| 383 |
+
if debug:
|
| 384 |
+
print(
|
| 385 |
+
f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class GatheredTensor:
|
| 392 |
+
"""
|
| 393 |
+
A pseudo tensor that collects partitioned weights.
|
| 394 |
+
It is more memory efficient when there are multiple groups.
|
| 395 |
+
"""
|
| 396 |
+
|
| 397 |
+
def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
|
| 398 |
+
self.flat_groups = flat_groups
|
| 399 |
+
self.flat_groups_offset = flat_groups_offset
|
| 400 |
+
self.offset = offset
|
| 401 |
+
self.partitioned_numel = partitioned_numel
|
| 402 |
+
self.shape = shape
|
| 403 |
+
self.dtype = self.flat_groups[0][0].dtype
|
| 404 |
+
|
| 405 |
+
def contiguous(self):
|
| 406 |
+
"""
|
| 407 |
+
Merge partitioned weights from flat_groups into a single tensor.
|
| 408 |
+
"""
|
| 409 |
+
end_idx = self.offset + self.partitioned_numel
|
| 410 |
+
world_size = len(self.flat_groups)
|
| 411 |
+
pad_flat_param_chunks = []
|
| 412 |
+
|
| 413 |
+
for rank_i in range(world_size):
|
| 414 |
+
# for each rank, we need to collect weights from related group/groups
|
| 415 |
+
flat_groups_at_rank_i = self.flat_groups[rank_i]
|
| 416 |
+
start_group_id = None
|
| 417 |
+
end_group_id = None
|
| 418 |
+
for group_id in range(len(self.flat_groups_offset)):
|
| 419 |
+
if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
|
| 420 |
+
start_group_id = group_id
|
| 421 |
+
if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
|
| 422 |
+
end_group_id = group_id
|
| 423 |
+
break
|
| 424 |
+
# collect weights from related group/groups
|
| 425 |
+
for group_id in range(start_group_id, end_group_id + 1):
|
| 426 |
+
flat_tensor = flat_groups_at_rank_i[group_id]
|
| 427 |
+
start_offset = self.offset - self.flat_groups_offset[group_id]
|
| 428 |
+
end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
|
| 429 |
+
pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
|
| 430 |
+
|
| 431 |
+
# collect weights from all ranks
|
| 432 |
+
pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
|
| 433 |
+
param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
|
| 434 |
+
return param
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
|
| 438 |
+
param_shapes = zero_model_states[0].param_shapes
|
| 439 |
+
avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
|
| 440 |
+
|
| 441 |
+
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
|
| 442 |
+
# param, re-consolidating each param, while dealing with padding if any
|
| 443 |
+
|
| 444 |
+
# merge list of dicts, preserving order
|
| 445 |
+
param_shapes = {k: v for d in param_shapes for k, v in d.items()}
|
| 446 |
+
|
| 447 |
+
if debug:
|
| 448 |
+
for i in range(world_size):
|
| 449 |
+
print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
|
| 450 |
+
|
| 451 |
+
wanted_params = len(param_shapes)
|
| 452 |
+
wanted_numel = sum(shape.numel() for shape in param_shapes.values())
|
| 453 |
+
# not asserting if there is a mismatch due to possible padding
|
| 454 |
+
avail_numel = fp32_flat_groups[0].numel() * world_size
|
| 455 |
+
print(f"Trainable params: Have {avail_numel} numels to process.")
|
| 456 |
+
print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
|
| 457 |
+
|
| 458 |
+
# params
|
| 459 |
+
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
|
| 460 |
+
# out-of-core computing solution
|
| 461 |
+
offset = 0
|
| 462 |
+
total_numel = 0
|
| 463 |
+
total_params = 0
|
| 464 |
+
flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
|
| 465 |
+
for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
|
| 466 |
+
unpartitioned_numel = shape.numel()
|
| 467 |
+
total_numel += unpartitioned_numel
|
| 468 |
+
total_params += 1
|
| 469 |
+
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
|
| 470 |
+
|
| 471 |
+
if debug:
|
| 472 |
+
print(
|
| 473 |
+
f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
# memory efficient tensor
|
| 477 |
+
tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
|
| 478 |
+
state_dict[name] = tensor
|
| 479 |
+
offset += partitioned_numel
|
| 480 |
+
|
| 481 |
+
offset *= world_size
|
| 482 |
+
|
| 483 |
+
# Sanity check
|
| 484 |
+
if offset != avail_numel:
|
| 485 |
+
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
|
| 486 |
+
|
| 487 |
+
print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
| 491 |
+
exclude_frozen_parameters):
|
| 492 |
+
state_dict = OrderedDict()
|
| 493 |
+
|
| 494 |
+
# buffers
|
| 495 |
+
buffers = zero_model_states[0].buffers
|
| 496 |
+
state_dict.update(buffers)
|
| 497 |
+
if debug:
|
| 498 |
+
print(f"added {len(buffers)} buffers")
|
| 499 |
+
|
| 500 |
+
if not exclude_frozen_parameters:
|
| 501 |
+
_zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
|
| 502 |
+
|
| 503 |
+
_zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
|
| 504 |
+
|
| 505 |
+
# recover shared parameters
|
| 506 |
+
for pair in zero_model_states[0].shared_params:
|
| 507 |
+
if pair[1] in state_dict:
|
| 508 |
+
state_dict[pair[0]] = state_dict[pair[1]]
|
| 509 |
+
|
| 510 |
+
return state_dict
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def to_torch_tensor(state_dict, return_empty_tensor=False):
|
| 514 |
+
"""
|
| 515 |
+
Convert state_dict of GatheredTensor to torch tensor
|
| 516 |
+
"""
|
| 517 |
+
torch_state_dict = {}
|
| 518 |
+
converted_tensors = {}
|
| 519 |
+
for name, tensor in state_dict.items():
|
| 520 |
+
tensor_id = id(tensor)
|
| 521 |
+
if tensor_id in converted_tensors: # shared tensors
|
| 522 |
+
shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
|
| 523 |
+
torch_state_dict[name] = shared_tensor
|
| 524 |
+
else:
|
| 525 |
+
converted_tensors[tensor_id] = name
|
| 526 |
+
if return_empty_tensor:
|
| 527 |
+
torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
|
| 528 |
+
else:
|
| 529 |
+
torch_state_dict[name] = tensor.contiguous()
|
| 530 |
+
return torch_state_dict
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
|
| 534 |
+
tag=None,
|
| 535 |
+
exclude_frozen_parameters=False,
|
| 536 |
+
lazy_mode=False):
|
| 537 |
+
"""
|
| 538 |
+
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
|
| 539 |
+
``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
|
| 540 |
+
via a model hub.
|
| 541 |
+
|
| 542 |
+
Args:
|
| 543 |
+
- ``checkpoint_dir``: path to the desired checkpoint folder
|
| 544 |
+
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
|
| 545 |
+
- ``exclude_frozen_parameters``: exclude frozen parameters
|
| 546 |
+
- ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient.
|
| 547 |
+
Convert the pesduo tensor to torch tensor by ``.contiguous()``
|
| 548 |
+
|
| 549 |
+
Returns:
|
| 550 |
+
- pytorch ``state_dict``
|
| 551 |
+
|
| 552 |
+
A typical usage might be ::
|
| 553 |
+
|
| 554 |
+
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
|
| 555 |
+
# do the training and checkpoint saving
|
| 556 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
|
| 557 |
+
model = model.cpu() # move to cpu
|
| 558 |
+
model.load_state_dict(state_dict)
|
| 559 |
+
# submit to model hub or save the model to share with others
|
| 560 |
+
|
| 561 |
+
In this example the ``model`` will no longer be usable in the deepspeed context of the same
|
| 562 |
+
application. i.e. you will need to re-initialize the deepspeed engine, since
|
| 563 |
+
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
|
| 564 |
+
|
| 565 |
+
If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
|
| 566 |
+
|
| 567 |
+
Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
|
| 568 |
+
You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
|
| 569 |
+
the checkpoint. Or you can load state_dict in lazy mode ::
|
| 570 |
+
|
| 571 |
+
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
|
| 572 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
|
| 573 |
+
for name, lazy_tensor in state_dict.item():
|
| 574 |
+
tensor = lazy_tensor.contiguous() # to cpu
|
| 575 |
+
print(name, tensor)
|
| 576 |
+
# del tensor to release memory if it no longer in use
|
| 577 |
+
"""
|
| 578 |
+
if tag is None:
|
| 579 |
+
latest_path = os.path.join(checkpoint_dir, 'latest')
|
| 580 |
+
if os.path.isfile(latest_path):
|
| 581 |
+
with open(latest_path, 'r') as fd:
|
| 582 |
+
tag = fd.read().strip()
|
| 583 |
+
else:
|
| 584 |
+
raise ValueError(f"Unable to find 'latest' file at {latest_path}")
|
| 585 |
+
|
| 586 |
+
ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
|
| 587 |
+
|
| 588 |
+
if not os.path.isdir(ds_checkpoint_dir):
|
| 589 |
+
raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
|
| 590 |
+
|
| 591 |
+
state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
|
| 592 |
+
if lazy_mode:
|
| 593 |
+
return state_dict
|
| 594 |
+
else:
|
| 595 |
+
return to_torch_tensor(state_dict)
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
|
| 599 |
+
output_dir,
|
| 600 |
+
max_shard_size="5GB",
|
| 601 |
+
safe_serialization=False,
|
| 602 |
+
tag=None,
|
| 603 |
+
exclude_frozen_parameters=False):
|
| 604 |
+
"""
|
| 605 |
+
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
|
| 606 |
+
loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
|
| 607 |
+
|
| 608 |
+
Args:
|
| 609 |
+
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
|
| 610 |
+
- ``output_dir``: directory to the pytorch fp32 state_dict output files
|
| 611 |
+
- ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
|
| 612 |
+
- ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
| 613 |
+
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
|
| 614 |
+
- ``exclude_frozen_parameters``: exclude frozen parameters
|
| 615 |
+
"""
|
| 616 |
+
|
| 617 |
+
# Dependency pre-check
|
| 618 |
+
if safe_serialization:
|
| 619 |
+
try:
|
| 620 |
+
from safetensors.torch import save_file
|
| 621 |
+
except ImportError:
|
| 622 |
+
print('If you want to use `safe_serialization`, please `pip install safetensors`')
|
| 623 |
+
raise
|
| 624 |
+
if max_shard_size is not None:
|
| 625 |
+
try:
|
| 626 |
+
from huggingface_hub import split_torch_state_dict_into_shards
|
| 627 |
+
except ImportError:
|
| 628 |
+
print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
|
| 629 |
+
raise
|
| 630 |
+
|
| 631 |
+
# Convert zero checkpoint to state_dict
|
| 632 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
|
| 633 |
+
tag,
|
| 634 |
+
exclude_frozen_parameters,
|
| 635 |
+
lazy_mode=True)
|
| 636 |
+
|
| 637 |
+
# Shard the model if it is too big.
|
| 638 |
+
weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
|
| 639 |
+
if max_shard_size is not None:
|
| 640 |
+
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
|
| 641 |
+
# an memory-efficient approach for sharding
|
| 642 |
+
empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
|
| 643 |
+
state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
|
| 644 |
+
filename_pattern=filename_pattern,
|
| 645 |
+
max_shard_size=max_shard_size)
|
| 646 |
+
else:
|
| 647 |
+
from collections import namedtuple
|
| 648 |
+
StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
|
| 649 |
+
state_dict_split = StateDictSplit(is_sharded=False,
|
| 650 |
+
filename_to_tensors={weights_name: list(state_dict.keys())})
|
| 651 |
+
|
| 652 |
+
# Save the model by shard
|
| 653 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 654 |
+
filename_to_tensors = state_dict_split.filename_to_tensors.items()
|
| 655 |
+
for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
|
| 656 |
+
shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
|
| 657 |
+
shard_state_dict = to_torch_tensor(shard_state_dict)
|
| 658 |
+
output_path = os.path.join(output_dir, shard_file)
|
| 659 |
+
if safe_serialization:
|
| 660 |
+
save_file(shard_state_dict, output_path, metadata={"format": "pt"})
|
| 661 |
+
else:
|
| 662 |
+
torch.save(shard_state_dict, output_path)
|
| 663 |
+
# release the memory of current shard
|
| 664 |
+
for tensor_name in list(shard_state_dict.keys()):
|
| 665 |
+
del state_dict[tensor_name]
|
| 666 |
+
del shard_state_dict[tensor_name]
|
| 667 |
+
del shard_state_dict
|
| 668 |
+
gc.collect()
|
| 669 |
+
|
| 670 |
+
# Save index if sharded
|
| 671 |
+
if state_dict_split.is_sharded:
|
| 672 |
+
index = {
|
| 673 |
+
"metadata": state_dict_split.metadata,
|
| 674 |
+
"weight_map": state_dict_split.tensor_to_filename,
|
| 675 |
+
}
|
| 676 |
+
save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
|
| 677 |
+
save_index_file = os.path.join(output_dir, save_index_file)
|
| 678 |
+
with open(save_index_file, "w", encoding="utf-8") as f:
|
| 679 |
+
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
| 680 |
+
f.write(content)
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
|
| 684 |
+
"""
|
| 685 |
+
1. Put the provided model to cpu
|
| 686 |
+
2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
|
| 687 |
+
3. Load it into the provided model
|
| 688 |
+
|
| 689 |
+
Args:
|
| 690 |
+
- ``model``: the model object to update
|
| 691 |
+
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
|
| 692 |
+
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
|
| 693 |
+
|
| 694 |
+
Returns:
|
| 695 |
+
- ``model`: modified model
|
| 696 |
+
|
| 697 |
+
Make sure you have plenty of CPU memory available before you call this function. If you don't
|
| 698 |
+
have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
|
| 699 |
+
conveniently placed for you in the checkpoint folder.
|
| 700 |
+
|
| 701 |
+
A typical usage might be ::
|
| 702 |
+
|
| 703 |
+
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
|
| 704 |
+
model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
|
| 705 |
+
# submit to model hub or save the model to share with others
|
| 706 |
+
|
| 707 |
+
Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
|
| 708 |
+
of the same application. i.e. you will need to re-initialize the deepspeed engine, since
|
| 709 |
+
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
|
| 710 |
+
|
| 711 |
+
"""
|
| 712 |
+
logger.info("Extracting fp32 weights")
|
| 713 |
+
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
|
| 714 |
+
|
| 715 |
+
logger.info("Overwriting model with fp32 weights")
|
| 716 |
+
model = model.cpu()
|
| 717 |
+
model.load_state_dict(state_dict, strict=False)
|
| 718 |
+
|
| 719 |
+
return model
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
if __name__ == "__main__":
|
| 723 |
+
parser = argparse.ArgumentParser()
|
| 724 |
+
parser.add_argument("checkpoint_dir",
|
| 725 |
+
type=str,
|
| 726 |
+
help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
|
| 727 |
+
parser.add_argument("output_dir",
|
| 728 |
+
type=str,
|
| 729 |
+
help="directory to the pytorch fp32 state_dict output files"
|
| 730 |
+
"(e.g. path/checkpoint-12-output/)")
|
| 731 |
+
parser.add_argument(
|
| 732 |
+
"--max_shard_size",
|
| 733 |
+
type=str,
|
| 734 |
+
default="5GB",
|
| 735 |
+
help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
|
| 736 |
+
"lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
|
| 737 |
+
"We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
|
| 738 |
+
"without CPU OOM issues.")
|
| 739 |
+
parser.add_argument(
|
| 740 |
+
"--safe_serialization",
|
| 741 |
+
default=False,
|
| 742 |
+
action='store_true',
|
| 743 |
+
help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
|
| 744 |
+
parser.add_argument("-t",
|
| 745 |
+
"--tag",
|
| 746 |
+
type=str,
|
| 747 |
+
default=None,
|
| 748 |
+
help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
|
| 749 |
+
parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
|
| 750 |
+
parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
|
| 751 |
+
args = parser.parse_args()
|
| 752 |
+
|
| 753 |
+
debug = args.debug
|
| 754 |
+
|
| 755 |
+
convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
|
| 756 |
+
args.output_dir,
|
| 757 |
+
max_shard_size=args.max_shard_size,
|
| 758 |
+
safe_serialization=args.safe_serialization,
|
| 759 |
+
tag=args.tag,
|
| 760 |
+
exclude_frozen_parameters=args.exclude_frozen_parameters)
|
outputs/opd-7b-chartqa-ds/checkpoint-2352/preprocessor_config.json
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"do_convert_rgb": true,
|
| 3 |
+
"do_normalize": true,
|
| 4 |
+
"do_pad": true,
|
| 5 |
+
"do_rescale": true,
|
| 6 |
+
"do_resize": true,
|
| 7 |
+
"image_grid_pinpoints": [
|
| 8 |
+
[
|
| 9 |
+
384,
|
| 10 |
+
384
|
| 11 |
+
],
|
| 12 |
+
[
|
| 13 |
+
384,
|
| 14 |
+
768
|
| 15 |
+
],
|
| 16 |
+
[
|
| 17 |
+
384,
|
| 18 |
+
1152
|
| 19 |
+
],
|
| 20 |
+
[
|
| 21 |
+
384,
|
| 22 |
+
1536
|
| 23 |
+
],
|
| 24 |
+
[
|
| 25 |
+
384,
|
| 26 |
+
1920
|
| 27 |
+
],
|
| 28 |
+
[
|
| 29 |
+
384,
|
| 30 |
+
2304
|
| 31 |
+
],
|
| 32 |
+
[
|
| 33 |
+
768,
|
| 34 |
+
384
|
| 35 |
+
],
|
| 36 |
+
[
|
| 37 |
+
768,
|
| 38 |
+
768
|
| 39 |
+
],
|
| 40 |
+
[
|
| 41 |
+
768,
|
| 42 |
+
1152
|
| 43 |
+
],
|
| 44 |
+
[
|
| 45 |
+
768,
|
| 46 |
+
1536
|
| 47 |
+
],
|
| 48 |
+
[
|
| 49 |
+
768,
|
| 50 |
+
1920
|
| 51 |
+
],
|
| 52 |
+
[
|
| 53 |
+
768,
|
| 54 |
+
2304
|
| 55 |
+
],
|
| 56 |
+
[
|
| 57 |
+
1152,
|
| 58 |
+
384
|
| 59 |
+
],
|
| 60 |
+
[
|
| 61 |
+
1152,
|
| 62 |
+
768
|
| 63 |
+
],
|
| 64 |
+
[
|
| 65 |
+
1152,
|
| 66 |
+
1152
|
| 67 |
+
],
|
| 68 |
+
[
|
| 69 |
+
1152,
|
| 70 |
+
1536
|
| 71 |
+
],
|
| 72 |
+
[
|
| 73 |
+
1152,
|
| 74 |
+
1920
|
| 75 |
+
],
|
| 76 |
+
[
|
| 77 |
+
1152,
|
| 78 |
+
2304
|
| 79 |
+
],
|
| 80 |
+
[
|
| 81 |
+
1536,
|
| 82 |
+
384
|
| 83 |
+
],
|
| 84 |
+
[
|
| 85 |
+
1536,
|
| 86 |
+
768
|
| 87 |
+
],
|
| 88 |
+
[
|
| 89 |
+
1536,
|
| 90 |
+
1152
|
| 91 |
+
],
|
| 92 |
+
[
|
| 93 |
+
1536,
|
| 94 |
+
1536
|
| 95 |
+
],
|
| 96 |
+
[
|
| 97 |
+
1536,
|
| 98 |
+
1920
|
| 99 |
+
],
|
| 100 |
+
[
|
| 101 |
+
1536,
|
| 102 |
+
2304
|
| 103 |
+
],
|
| 104 |
+
[
|
| 105 |
+
1920,
|
| 106 |
+
384
|
| 107 |
+
],
|
| 108 |
+
[
|
| 109 |
+
1920,
|
| 110 |
+
768
|
| 111 |
+
],
|
| 112 |
+
[
|
| 113 |
+
1920,
|
| 114 |
+
1152
|
| 115 |
+
],
|
| 116 |
+
[
|
| 117 |
+
1920,
|
| 118 |
+
1536
|
| 119 |
+
],
|
| 120 |
+
[
|
| 121 |
+
1920,
|
| 122 |
+
1920
|
| 123 |
+
],
|
| 124 |
+
[
|
| 125 |
+
1920,
|
| 126 |
+
2304
|
| 127 |
+
],
|
| 128 |
+
[
|
| 129 |
+
2304,
|
| 130 |
+
384
|
| 131 |
+
],
|
| 132 |
+
[
|
| 133 |
+
2304,
|
| 134 |
+
768
|
| 135 |
+
],
|
| 136 |
+
[
|
| 137 |
+
2304,
|
| 138 |
+
1152
|
| 139 |
+
],
|
| 140 |
+
[
|
| 141 |
+
2304,
|
| 142 |
+
1536
|
| 143 |
+
],
|
| 144 |
+
[
|
| 145 |
+
2304,
|
| 146 |
+
1920
|
| 147 |
+
],
|
| 148 |
+
[
|
| 149 |
+
2304,
|
| 150 |
+
2304
|
| 151 |
+
]
|
| 152 |
+
],
|
| 153 |
+
"image_mean": [
|
| 154 |
+
0.5,
|
| 155 |
+
0.5,
|
| 156 |
+
0.5
|
| 157 |
+
],
|
| 158 |
+
"image_processor_type": "LlavaOnevisionImageProcessor",
|
| 159 |
+
"image_std": [
|
| 160 |
+
0.5,
|
| 161 |
+
0.5,
|
| 162 |
+
0.5
|
| 163 |
+
],
|
| 164 |
+
"processor_class": "LlavaOnevisionProcessor",
|
| 165 |
+
"resample": 3,
|
| 166 |
+
"rescale_factor": 0.00392156862745098,
|
| 167 |
+
"size": {
|
| 168 |
+
"height": 384,
|
| 169 |
+
"width": 384
|
| 170 |
+
}
|
| 171 |
+
}
|
outputs/opd-7b-chartqa-ds/checkpoint-588/config.json
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"LlavaOnevisionForConditionalGeneration"
|
| 4 |
+
],
|
| 5 |
+
"dtype": "bfloat16",
|
| 6 |
+
"eos_token_id": 151645,
|
| 7 |
+
"ignore_index": -100,
|
| 8 |
+
"image_grid_pinpoints": [
|
| 9 |
+
[
|
| 10 |
+
384,
|
| 11 |
+
384
|
| 12 |
+
],
|
| 13 |
+
[
|
| 14 |
+
384,
|
| 15 |
+
768
|
| 16 |
+
],
|
| 17 |
+
[
|
| 18 |
+
384,
|
| 19 |
+
1152
|
| 20 |
+
],
|
| 21 |
+
[
|
| 22 |
+
384,
|
| 23 |
+
1536
|
| 24 |
+
],
|
| 25 |
+
[
|
| 26 |
+
384,
|
| 27 |
+
1920
|
| 28 |
+
],
|
| 29 |
+
[
|
| 30 |
+
384,
|
| 31 |
+
2304
|
| 32 |
+
],
|
| 33 |
+
[
|
| 34 |
+
768,
|
| 35 |
+
384
|
| 36 |
+
],
|
| 37 |
+
[
|
| 38 |
+
768,
|
| 39 |
+
768
|
| 40 |
+
],
|
| 41 |
+
[
|
| 42 |
+
768,
|
| 43 |
+
1152
|
| 44 |
+
],
|
| 45 |
+
[
|
| 46 |
+
768,
|
| 47 |
+
1536
|
| 48 |
+
],
|
| 49 |
+
[
|
| 50 |
+
768,
|
| 51 |
+
1920
|
| 52 |
+
],
|
| 53 |
+
[
|
| 54 |
+
768,
|
| 55 |
+
2304
|
| 56 |
+
],
|
| 57 |
+
[
|
| 58 |
+
1152,
|
| 59 |
+
384
|
| 60 |
+
],
|
| 61 |
+
[
|
| 62 |
+
1152,
|
| 63 |
+
768
|
| 64 |
+
],
|
| 65 |
+
[
|
| 66 |
+
1152,
|
| 67 |
+
1152
|
| 68 |
+
],
|
| 69 |
+
[
|
| 70 |
+
1152,
|
| 71 |
+
1536
|
| 72 |
+
],
|
| 73 |
+
[
|
| 74 |
+
1152,
|
| 75 |
+
1920
|
| 76 |
+
],
|
| 77 |
+
[
|
| 78 |
+
1152,
|
| 79 |
+
2304
|
| 80 |
+
],
|
| 81 |
+
[
|
| 82 |
+
1536,
|
| 83 |
+
384
|
| 84 |
+
],
|
| 85 |
+
[
|
| 86 |
+
1536,
|
| 87 |
+
768
|
| 88 |
+
],
|
| 89 |
+
[
|
| 90 |
+
1536,
|
| 91 |
+
1152
|
| 92 |
+
],
|
| 93 |
+
[
|
| 94 |
+
1536,
|
| 95 |
+
1536
|
| 96 |
+
],
|
| 97 |
+
[
|
| 98 |
+
1536,
|
| 99 |
+
1920
|
| 100 |
+
],
|
| 101 |
+
[
|
| 102 |
+
1536,
|
| 103 |
+
2304
|
| 104 |
+
],
|
| 105 |
+
[
|
| 106 |
+
1920,
|
| 107 |
+
384
|
| 108 |
+
],
|
| 109 |
+
[
|
| 110 |
+
1920,
|
| 111 |
+
768
|
| 112 |
+
],
|
| 113 |
+
[
|
| 114 |
+
1920,
|
| 115 |
+
1152
|
| 116 |
+
],
|
| 117 |
+
[
|
| 118 |
+
1920,
|
| 119 |
+
1536
|
| 120 |
+
],
|
| 121 |
+
[
|
| 122 |
+
1920,
|
| 123 |
+
1920
|
| 124 |
+
],
|
| 125 |
+
[
|
| 126 |
+
1920,
|
| 127 |
+
2304
|
| 128 |
+
],
|
| 129 |
+
[
|
| 130 |
+
2304,
|
| 131 |
+
384
|
| 132 |
+
],
|
| 133 |
+
[
|
| 134 |
+
2304,
|
| 135 |
+
768
|
| 136 |
+
],
|
| 137 |
+
[
|
| 138 |
+
2304,
|
| 139 |
+
1152
|
| 140 |
+
],
|
| 141 |
+
[
|
| 142 |
+
2304,
|
| 143 |
+
1536
|
| 144 |
+
],
|
| 145 |
+
[
|
| 146 |
+
2304,
|
| 147 |
+
1920
|
| 148 |
+
],
|
| 149 |
+
[
|
| 150 |
+
2304,
|
| 151 |
+
2304
|
| 152 |
+
]
|
| 153 |
+
],
|
| 154 |
+
"image_token_index": 151646,
|
| 155 |
+
"model_type": "llava_onevision",
|
| 156 |
+
"multimodal_projector_bias": true,
|
| 157 |
+
"pad_token_id": 151643,
|
| 158 |
+
"projector_hidden_act": "gelu",
|
| 159 |
+
"text_config": {
|
| 160 |
+
"_name_or_path": "Qwen/Qwen2-0.5B-Instruct",
|
| 161 |
+
"architectures": [
|
| 162 |
+
"Qwen2ForCausalLM"
|
| 163 |
+
],
|
| 164 |
+
"attention_dropout": 0.0,
|
| 165 |
+
"bos_token_id": 151643,
|
| 166 |
+
"dtype": "bfloat16",
|
| 167 |
+
"eos_token_id": 151645,
|
| 168 |
+
"hidden_act": "silu",
|
| 169 |
+
"hidden_size": 896,
|
| 170 |
+
"initializer_range": 0.02,
|
| 171 |
+
"intermediate_size": 4864,
|
| 172 |
+
"layer_types": [
|
| 173 |
+
"full_attention",
|
| 174 |
+
"full_attention",
|
| 175 |
+
"full_attention",
|
| 176 |
+
"full_attention",
|
| 177 |
+
"full_attention",
|
| 178 |
+
"full_attention",
|
| 179 |
+
"full_attention",
|
| 180 |
+
"full_attention",
|
| 181 |
+
"full_attention",
|
| 182 |
+
"full_attention",
|
| 183 |
+
"full_attention",
|
| 184 |
+
"full_attention",
|
| 185 |
+
"full_attention",
|
| 186 |
+
"full_attention",
|
| 187 |
+
"full_attention",
|
| 188 |
+
"full_attention",
|
| 189 |
+
"full_attention",
|
| 190 |
+
"full_attention",
|
| 191 |
+
"full_attention",
|
| 192 |
+
"full_attention",
|
| 193 |
+
"full_attention",
|
| 194 |
+
"full_attention",
|
| 195 |
+
"full_attention",
|
| 196 |
+
"full_attention"
|
| 197 |
+
],
|
| 198 |
+
"max_position_embeddings": 32768,
|
| 199 |
+
"max_window_layers": 24,
|
| 200 |
+
"model_type": "qwen2",
|
| 201 |
+
"num_attention_heads": 14,
|
| 202 |
+
"num_hidden_layers": 24,
|
| 203 |
+
"num_key_value_heads": 2,
|
| 204 |
+
"rms_norm_eps": 1e-06,
|
| 205 |
+
"rope_scaling": null,
|
| 206 |
+
"rope_theta": 1000000.0,
|
| 207 |
+
"sliding_window": null,
|
| 208 |
+
"tie_word_embeddings": true,
|
| 209 |
+
"use_cache": true,
|
| 210 |
+
"use_sliding_window": false,
|
| 211 |
+
"vocab_size": 152000
|
| 212 |
+
},
|
| 213 |
+
"tie_word_embeddings": false,
|
| 214 |
+
"transformers_version": "4.57.1",
|
| 215 |
+
"use_image_newline_parameter": true,
|
| 216 |
+
"video_token_index": 151647,
|
| 217 |
+
"vision_aspect_ratio": "anyres_max_9",
|
| 218 |
+
"vision_config": {
|
| 219 |
+
"attention_dropout": 0.0,
|
| 220 |
+
"dtype": "bfloat16",
|
| 221 |
+
"hidden_act": "gelu_pytorch_tanh",
|
| 222 |
+
"hidden_size": 1152,
|
| 223 |
+
"image_size": 384,
|
| 224 |
+
"intermediate_size": 4304,
|
| 225 |
+
"layer_norm_eps": 1e-06,
|
| 226 |
+
"model_type": "siglip_vision_model",
|
| 227 |
+
"num_attention_heads": 16,
|
| 228 |
+
"num_channels": 3,
|
| 229 |
+
"num_hidden_layers": 26,
|
| 230 |
+
"patch_size": 14,
|
| 231 |
+
"vision_use_head": false
|
| 232 |
+
},
|
| 233 |
+
"vision_feature_layer": -1,
|
| 234 |
+
"vision_feature_select_strategy": "full"
|
| 235 |
+
}
|
papers/full_text.txt
ADDED
|
@@ -0,0 +1,1211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
--- PAGE 1 ---
|
| 3 |
+
PublishedasaconferencepaperatICLR2026
|
| 4 |
+
EMPOWERING SMALL VLMS TO THINK WITH DY-
|
| 5 |
+
NAMIC MEMORIZATION AND EXPLORATION
|
| 6 |
+
JiazhenLiu ,YuchuanDeng ,andLongChen∗
|
| 7 |
+
TheHongKongUniversityofScienceandTechnology
|
| 8 |
+
https://github.com/HKUST-LongGroup/DyME
|
| 9 |
+
ABSTRACT
|
| 10 |
+
Small-scaleVision–LanguageModels(SVLMs)areexceptionallywell-suitedfor
|
| 11 |
+
proprietarytasks. Equippingthemwiththinkingcapabilitiesisacriticalstepto
|
| 12 |
+
enhance their performance and reliability in these specific domains. However,
|
| 13 |
+
existingtrainingparadigms,includingSupervisedFine-Tuning(SFT)andRein-
|
| 14 |
+
forcementLearningwithVerifiableReward(RLVR),imposesubstantialdemands
|
| 15 |
+
onthebaseVLM,exceedingthecapacityofSVLMs. Consequently,directlyap-
|
| 16 |
+
plyingtheseparadigmstoSVLMsfailstoinstillthedesiredthinkingabilities. A
|
| 17 |
+
natural solution is to combine SFT and RLVR, leveraging their complementar-
|
| 18 |
+
ity to reduce the dependence on model capacity. Yet the core challenge lies in
|
| 19 |
+
managingtheinherenttrade-off: excessiverelianceonSFTcanforcethemodel
|
| 20 |
+
tomemorizepseudothinkingtraces,whileover-emphasizingRLVRcanleadto
|
| 21 |
+
unstableexploration(i.e.,advantagecollapse). Toaddressthis,weproposeDyME,
|
| 22 |
+
anoveltrainingparadigmthatDynamicallyselectsbetweenMemorization(via
|
| 23 |
+
SFT) and Exploration (via RLVR) at each optimization step. By ensuring that
|
| 24 |
+
every update contributes to the trade-off, DyME serves as a robust, standalone
|
| 25 |
+
strategythatstabilizesSVLMlearning. Complementingthisparadigm,wefurther
|
| 26 |
+
introduceasynergisticVisualSupervisionmechanism(comprisingavisualchecker
|
| 27 |
+
andrefiner)designedtoinjectdynamicallyenhanced,image-groundedguidance
|
| 28 |
+
duringoptimization. Extensiveexperimentsacrossdiversedomainsdemonstrate
|
| 29 |
+
that DyME consistently achieves this balance, and thus delivers substantial per-
|
| 30 |
+
formance improvements on specialized tasks. These results establish DyME as
|
| 31 |
+
apracticalandeffectivesolutionforempoweringSVLMswithreliablethinking
|
| 32 |
+
capabilities.
|
| 33 |
+
1 INTRODUCTION
|
| 34 |
+
EquippingVision–LanguageModels(VLMs)withthinkingcapabilitiesisapivotalstepthatmoves
|
| 35 |
+
thembeyondrecognitiontowardreasoning. Recentstudieshaveadvancedthisgoalthroughspe-
|
| 36 |
+
cializedtraining,achievingstrongresultsonaspectrumofvisualtasks,fromrecognition-intensive
|
| 37 |
+
applicationslikegrounding(Laietal.,2025;Liu&Chen,2025;Pengetal.,2025;Liuetal.,2025c;a)
|
| 38 |
+
toreasoning-intensivechallengessuchaschartunderstanding(Zhangetal.,2025a;Xiaetal.,2024)
|
| 39 |
+
andgeometricproblemsolving(Shenetal.,2025;Chenetal.,2025b;Xiaetal.,2025). Whilethis
|
| 40 |
+
progressissignificant,thesuccessoftheseapproachesiscontingentuponthebaseVLMpossessing
|
| 41 |
+
strongfoundationalcapabilities,namely,sufficientcapacityandrobustinstructionadherence(Yang
|
| 42 |
+
etal.,2025a). Inpractice,onlyahandfulofVLMsmeettheseprerequisites,presentingasignificant
|
| 43 |
+
challengeforSmall-scaleVLMs(SVLMs)whichstruggletodevelopthinkingcapabilitiesunder
|
| 44 |
+
existingtrainingparadigms.
|
| 45 |
+
Tocontextualizethislimitation,webrieflyreviewthetwodominantparadigms,bothofwhichare
|
| 46 |
+
primarilytailoredforLarge-scaleVLMs(LVLMs). 1)SupervisedFine-Tuning(SFT)onChain-of-
|
| 47 |
+
Thought(CoT)data(Xuetal.,2024;Lietal.,2024b;Xiaetal.,2025;Gaoetal.,2025): VLMsare
|
| 48 |
+
supervisedtomemorizepredefinedthinkingpatternsfromlarge-scaleCoTannotations. SinceCoT
|
| 49 |
+
dataareoftenverboseandcontainmuchvision-irrelevantcontent,modelsmustpossesssufficient
|
| 50 |
+
capacity to absorb long textual content without compromising visual grounding (Marafioti et al.,
|
| 51 |
+
∗Correspondingauthor(longchen@ust.hk)
|
| 52 |
+
1
|
| 53 |
+
6202
|
| 54 |
+
beF
|
| 55 |
+
72
|
| 56 |
+
]VC.sc[
|
| 57 |
+
2v16032.6052:viXra
|
| 58 |
+
--- PAGE 2 ---
|
| 59 |
+
PublishedasaconferencepaperatICLR2026
|
| 60 |
+
After CoT SFT training During RLVR training
|
| 61 |
+
(Input Image)
|
| 62 |
+
Rapid Decline in Brazilians' Which year shows a greater divergence of opinions about Brazil's Which year has the most divergent
|
| 63 |
+
Assessment of Economy economy, 2010 or 2012? Output the thinking process and then give opinions about Brazil's economy? Output
|
| 64 |
+
the final answer in <answer> </answer> tag. the thinking process and then give the
|
| 65 |
+
final answer in <answer> </answer> tag.
|
| 66 |
+
We locate 2010 and 2012. To answer the question, we To answer the question, we examine the
|
| 67 |
+
Then, extract 2010 has 62 first locate the relevant years chart and compare the value differences
|
| 68 |
+
and 36, 2012 has 65 and in the chart. Then, we extract for each year. In 2015, the values are 87
|
| 69 |
+
35. Comparing 62 − 36 = 26 the values for each year. and 13, showing the most significant
|
| 70 |
+
< 65 − 35 = 30, the year with Compare the differences, the divergence. <answer>2015</answer>
|
| 71 |
+
greater divergence is 2012. year with greater divergence is
|
| 72 |
+
<answer>2012</answer> 2010. <answer>2010</answer> 2015 has the greatest divergence.
|
| 73 |
+
Large-scale VLM
|
| 74 |
+
LVLMs: vision preserved SVLMs: vision compromised SVLMs poorly follow instructions
|
| 75 |
+
Small-scale VLM
|
| 76 |
+
Grounded thinking traces Pseudo thinking traces (fail) Advantage collapsing (fail)
|
| 77 |
+
a) SFT and RL paradigms fail to enable SVLMs to think.
|
| 78 |
+
Two-stage training DyME training
|
| 79 |
+
Extent
|
| 80 |
+
1. Locate the year; Constrained
|
| 81 |
+
2. Extract the value; exploration RL RL
|
| 82 |
+
3. Get the answer. RL safe to
|
| 83 |
+
<answer>2012</answer> explore
|
| 84 |
+
(hard to achieve) (single step)
|
| 85 |
+
SFT . t < . h . a e g n c r s h o w a u e r n r t d > .. e 2 . d A 0 1 v f 2 t a e < l r u / t a e h n i f n s r w o k m i e n r g > , RL c T ap h a in b k il i i n ti g e s Switcher Switcher c T ap h a in b k il i i n ti g es
|
| 86 |
+
(numerous steps)
|
| 87 |
+
2012 has the greatest div- need to
|
| 88 |
+
e (n rg o e f n o c r e m . at answer) RL A d co v l a la n p ta s g e e memorize SFT SFT
|
| 89 |
+
b) Two-stage training vs. DyME
|
| 90 |
+
Figure1: TrainingparadigmsforenablingVLMthinking. TheLVLMisQwen2.5-VL-32B(Bai
|
| 91 |
+
etal.,2025)andtheSVLMisSmolVLM-500M(Marafiotietal.,2025). (a)Existingparadigmsare
|
| 92 |
+
effectiveforLVLMsbutunsuitableforSVLMs. (b)Thetwo-stagetrainingparadigm(SFT→RL)
|
| 93 |
+
facesachallengingtrade-off. OurproposedDyMEdynamicallybalancesthistrade-off.
|
| 94 |
+
2025). This capability gap is illustrated in Fig. 1a: After SFT, LVLMs can generate grounded
|
| 95 |
+
thinkingtraceswithaccurateintermediatevalues(ingreen),whileSVLMscannot. 2)Reinforcement
|
| 96 |
+
LearningwithVerifiableReward(RLVR)(Zhangetal.,2025a;Chenetal.,2025b;Pengetal.,
|
| 97 |
+
2025;Shenetal.,2025): ontheotherhand,promotesexplorationofthinkingpatternsratherthan
|
| 98 |
+
imitations. Inthisparadigm,VLMsareinstructedtogenerateathoughtprocessfollowedbyastrictly
|
| 99 |
+
formattedanswer(e.g.,enclosedintags). Thisformatenablesverifiablerewardstoreinforcecorrect
|
| 100 |
+
generationsandpenalizeincorrectones. Owingtoitsrelianceoninstructionadherence,thisapproach
|
| 101 |
+
ispracticalprimarilyforstrongVLMsthatcanreliablygeneratestructuredoutputs.
|
| 102 |
+
Consequently,bothestablishedparadigmsareinadequate
|
| 103 |
+
forinstillingthinkinginSVLMs. Theextremelylimited
|
| 104 |
+
Baseline
|
| 105 |
+
capacity(e.g.,under1Bparameters)ofSVLMsrenders Baseline +CoT SFT
|
| 106 |
+
+RLVR
|
| 107 |
+
the SFT paradigm ineffective, as a high volume of tex- +Two-stage
|
| 108 |
+
+CoT SFT +Ours
|
| 109 |
+
tual information in CoT data can overwhelm the capac-
|
| 110 |
+
ity(Marafiotietal.,2025;Chenetal.,2025a). Moreover, +RLVR
|
| 111 |
+
the limited instruction adherence of SVLMs frequently
|
| 112 |
+
+Two-stage
|
| 113 |
+
resultsinunverifiableoutputs(Chuetal.,2025;Guoetal.,
|
| 114 |
+
2025),precipitatingadvantagecollapseduringRLVR.We +Ours
|
| 115 |
+
quantitatively verify these limitations (cf., Fig. 2): both
|
| 116 |
+
35 40 45 50 55 60 65 70 75
|
| 117 |
+
SFTandRLVRparadigmsindeedimpairtheperformance.
|
| 118 |
+
Figure2: PerformanceofSmolVLM-
|
| 119 |
+
Considering that SVLMs offer high efficiency and are 500M (Marafioti et al., 2025) on
|
| 120 |
+
crucialfordeploymentonedgedevices(Marafiotietal., ChartQA (Masry et al., 2022). Ex-
|
| 121 |
+
2025), enabling them to think addresses a strong prac- istingparadigmsdegradeperformance,
|
| 122 |
+
tical demand. Thinking enhances the reliability and whereasDyMEyieldsimprovements.
|
| 123 |
+
performance of vision tasks (Zhang et al., 2025a), and
|
| 124 |
+
task-specificSVLMsprovideacompellingalternativeto
|
| 125 |
+
LVLMsinresource-constrainedsettings. Thismotivatesthedevelopmentofanewtrainingparadigm
|
| 126 |
+
thatempowersSVLMswiththinkingcapabilities,atleastforspecializedtasks.
|
| 127 |
+
2
|
| 128 |
+
--- PAGE 3 ---
|
| 129 |
+
PublishedasaconferencepaperatICLR2026
|
| 130 |
+
ApromisingsolutionistofuseSFTandRLVR,asawell-calibratedtrade-offcanlowerthehigh
|
| 131 |
+
demandsonthebasemodel(DeepSeek,Inc.,2025;Yanetal.,2025): SFTencouragesthemodelto
|
| 132 |
+
memorizeverifiablethinkingpatternstopreventadvantagecollapse,whileRLforcesexplorationto
|
| 133 |
+
preventrigidtemplatesfromoverwhelmingthemodel’scapacity. Thecentralchallenge,however,is
|
| 134 |
+
thatSVLMsstruggletoachievethisbalance. Existinghybridmethods,liketwo-stagetraining(Chen
|
| 135 |
+
etal.,2025a;Chuetal.,2025)orannealedSFTlosses1(Zhangetal.,2025b),relyonastatictrade-off
|
| 136 |
+
governedbyhyperparameterssetempirically. Thisrigidityisthecriticalflawbecausetheminimal
|
| 137 |
+
capacityofSVLMsmeansthewindowforasuccessfulstaticbalanceisincrediblynarrow,making
|
| 138 |
+
failurealmostinevitable(cf. Fig. 1b). Ourrepeatedtrialswithtwo-stagetrainingconfirmedthisissue,
|
| 139 |
+
withperformanceoftenfallingbelowthebaseline(cf. Fig. 2).
|
| 140 |
+
SVLMs therefore require a more intelligent paradigm to navigate this trade-off. To this end, we
|
| 141 |
+
proposeDyME(DynamicMemorize–Explore),whichintegratesSFTandRLVRthroughadynamic
|
| 142 |
+
switchingmechanism. AsillustratedinFig.1b,DyMEassessesthemodel’sgenerationateachstep
|
| 143 |
+
andadaptsitstrainingmodeaccordingly. Whenthemodelfailstofollowinstructions,itswitchestoa
|
| 144 |
+
memorizationmode(SFT)toguaranteestableoptimizationsignals. Conversely,forvalidgenerations,
|
| 145 |
+
itengagesanexplorationmode(RLVR)toencouragediverseandgroundedthinking.Thisstate-driven
|
| 146 |
+
approachensuresmemorizationandexplorationarealwayscomplementary,dynamicallymaintaining
|
| 147 |
+
thedelicatetrade-off. Whilethisdynamicswitchingaloneguaranteestrainingstability,wefurther
|
| 148 |
+
maximizethemodel’spotentialbyincorporatingasynergisticVisualSupervisionmechanism. This
|
| 149 |
+
modulefacilitatesanadaptiveinteraction: theCoTground-truthguidesthescoringofexploration
|
| 150 |
+
(viaavisualchecker),whilesuccessfulexplorationtracesdynamicallyrefinetheCoTground-truth
|
| 151 |
+
(viaavisualrefiner).
|
| 152 |
+
TheaforementioneddesignmakesDyMEahighlyeffectiveparadigmforempoweringthinkingin
|
| 153 |
+
SVLMsforspecifictasks. Wevalidatethisacrossthreediversedomains,rangingfromrecognition-
|
| 154 |
+
intensivetasks(medicalVQA)toreasoning-intensivechallenges(chartunderstandingandgeometric
|
| 155 |
+
problemsolving). Remarkably, usingonlyafewthousandtrainingsamples, DyMEachievessub-
|
| 156 |
+
stantial performance gains, enabling it to match or even surpass several LVLMs. Our primary
|
| 157 |
+
contributionsareasfollows:
|
| 158 |
+
1. We propose DyME, the first training paradigm that equips SVLMs with thinking capabilities,
|
| 159 |
+
substantiallyreducingrelianceonthebaseVLM’sinitialcapacity.
|
| 160 |
+
2. Throughdynamicswitchingandsynergisticsupervision,DyMEalleviatespseudothinkingtraces
|
| 161 |
+
andadvantagecollapseinSVLMs,yieldingimage-groundedthinkingandconsistentperformance
|
| 162 |
+
improvements.
|
| 163 |
+
3. WedemonstratetheeffectivenessandpracticalityofDyMEacrossthreediversedomains,each
|
| 164 |
+
consistentlyshowingsubstantialperformancegainswithonlyafewthousandtrainingsamples.
|
| 165 |
+
2 RELATED WORK
|
| 166 |
+
Vision-LanguageModels. ModernVLMs,suchasLLaVA(Liuetal.,2024a)andQwen-VL(Bai
|
| 167 |
+
etal.,2023),havedemonstratedremarkablecapabilitiesacrossawidearrayofvisiontasks. How-
|
| 168 |
+
ever,theirsubstantialparametercountsandcomputationaldemandsrestricttheiruseinresource-
|
| 169 |
+
constrained environments like edge devices. This has motivated a growing interest in SVLMs
|
| 170 |
+
designedforefficiency(Zhouetal.,2024;Marafiotietal.,2025;Korrapati,2024). Althoughworks
|
| 171 |
+
likeTinyLLaVA(Zhouetal.,2024)andSmolVLM(Marafiotietal.,2025)haveshownthatcarefully
|
| 172 |
+
designedSVLMscanachievecompetitiveperformance, theyexhibitacriticalweakness. Recent
|
| 173 |
+
studieshighlightthattheirperformancedegradessignificantlyontasksrequiringcomplex,multi-step
|
| 174 |
+
instructionfollowing,indicatingagapintheircompositionalunderstandingandgeneralreasoning
|
| 175 |
+
abilities(Albalaketal.,2022;Ghoshetal.,2024;Liuetal.,2025b).
|
| 176 |
+
Empowering Thinking Capabilities in VLMs. Recent advances in LLM thinking (e.g., GPT-
|
| 177 |
+
o1(OpenAI,2024),DeepSeek-R1(Guoetal.,2025))havemotivatedeffortstoequipVLMswith
|
| 178 |
+
similarcapabilitiesviadedicatedtrainingparadigms.
|
| 179 |
+
SFTonCoTdata (Xu et al., 2024; Xia et al., 2024; 2025; Gao et al., 2025; Yang et al., 2025b).
|
| 180 |
+
Thisparadigmleverageslarge-scaleCoTsupervisiontoteachmodelstomemorizeandgeneralize
|
| 181 |
+
thinkingpatterns. Multimodal-CoT(Zhangetal.,2023)wasanearlyattemptusingfusedvisual–text
|
| 182 |
+
1Seethesupplementarymaterialforfurthercomparison.
|
| 183 |
+
3
|
| 184 |
+
--- PAGE 4 ---
|
| 185 |
+
PublishedasaconferencepaperatICLR2026
|
| 186 |
+
inputs,butitssmallscaledatalimitedgenuinethinking. Subsequentworkshighlighttheroleofscale:
|
| 187 |
+
G-LLaVA(Gaoetal.,2025)constructs170Kgeometry-specificCoTsamples;ChartVLM(Xiaetal.,
|
| 188 |
+
2024)compilesalargechartcorpus;andLLaVA-CoT(Xuetal.,2024)aswellasR1-OneVision(Yang
|
| 189 |
+
etal.,2025b)curatediverse,structuredCoTdatathroughlarge-scalepromptengineering. These
|
| 190 |
+
approachesfacelonginputs,requiringlargeVLMsthatcanprocessrichtextualinformationwhile
|
| 191 |
+
preservingvisualgrounding(Marafiotietal.,2025;Zhaietal.,2023).
|
| 192 |
+
RLwithVerifiableReward(RLVR)(Zhangetal.,2025a;Chenetal.,2025b;Pengetal.,2025;Shen
|
| 193 |
+
et al., 2025; Liu et al., 2025c). RLVR adopts a distinct paradigm that elicits thinking through
|
| 194 |
+
autonomousexplorationwithminimalexternalsupervision. ThepopularlyusedalgorithmisGroup
|
| 195 |
+
RelativePolicyOptimization(GRPO),introducedbyDeepSeek-Math(Shaoetal.,2024), which
|
| 196 |
+
exploitsmodels’abilitytoproducestructuredoutputsthatseparatethinkingfromfinalanswers. It
|
| 197 |
+
leverages rule-verifiable data to optimize high-scoring generations, while light SFT is employed
|
| 198 |
+
forcold-startwhentheoutputstructureisunclear. ThisparadigmhasbeenextendedtoVLMsin
|
| 199 |
+
severalworks. R1-V(Chenetal.,2025b)appliesGRPOtoVLMs,enablingthinkingintaskssuchas
|
| 200 |
+
countingandgeometry. LMM-R1(Pengetal.,2025)introducesatwo-stagepipelinethattransfers
|
| 201 |
+
textualthinkingintomultimodallearning. VisualRFT(Liuetal.,2025c)andR1-VL(Zhangetal.,
|
| 202 |
+
2025a)incorporatevision-specificrewardstoguidefine-grained,visuallygroundedoptimization.
|
| 203 |
+
SinceGRPOdependsonmodels’initialstructuredthinkingability,thesemethodstypicallybuildon
|
| 204 |
+
strongVLMs,suchastheQwen-VLseries(Baietal.,2025).
|
| 205 |
+
HybridTrainingParadigms(Chuetal.,2025;Yanetal.,2025;Zhangetal.,2025b). Toharnessthe
|
| 206 |
+
complementarystrengthsofSFTandRL,researchershavealsoinvestigatedhybridparadigms. A
|
| 207 |
+
commonapproachisatwo-stagetrainingprocess(Chuetal.,2025)thatfirstusesSFTtoteachthe
|
| 208 |
+
modelthedesiredoutputformat,followedbyRLforexploration. Althoughintuitive,thismethodis
|
| 209 |
+
highlysensitivetotheamountofSFT,aparameterthatisparticularlychallengingtotuneforSVLMs,
|
| 210 |
+
asthesesmallermodelscaneasilybecometrappedinsuboptimalstates. Alternativestrategiesattempt
|
| 211 |
+
tocontinuouslyblendSFTwithRL,forinstance, byincorporatingSFTasanannealedauxiliary
|
| 212 |
+
loss(Zhangetal.,2025b)orbymanagingitsinfluencewithanempiricalshapingfunction(Yanetal.,
|
| 213 |
+
2025). However,allthesestrategiesultimatelyrelyonanempiricallydeterminedbalancebetween
|
| 214 |
+
thetwoparadigms. ThisrigidityrepresentsacriticalflawwhenappliedtoSVLMs. Theabsenceof
|
| 215 |
+
adaptivecontrolovertheSFTweightrendersthesemethodsbrittleandunreliable.
|
| 216 |
+
Thus,existingparadigmsarenotdirectlytransferabletoSVLMsduetotheirinherentlimitations
|
| 217 |
+
inmodelcapacityandinstruction-followingability. Thishighlightstheneedforanoveltraining
|
| 218 |
+
paradigmthatimposesminimalrequirementsonthebaseVLM.
|
| 219 |
+
3 APPROACH
|
| 220 |
+
3.1 PRELIMINARIES
|
| 221 |
+
Wefirstbrieflyrecapthetwotrainingparadigms(SFTandRLVR)thatunderlieourmethod. Let
|
| 222 |
+
D ={(x ,y )}N bethetrainingset,wherexdenotestheinput(e.g. animage-instructionpair)and
|
| 223 |
+
i i i=1
|
| 224 |
+
ythedesiredoutput. Themodeldefinesaconditionaldistributionp (y |x)withparametersθ.
|
| 225 |
+
θ
|
| 226 |
+
Supervised Fine-Tuning (SFT). For each training pair (x,y) in D, SFT updates the model by
|
| 227 |
+
minimizingthenegativelog-likelihood(cross-entropy)ofthedesiredoutputyundertheconditional
|
| 228 |
+
distributionp (y |x):
|
| 229 |
+
θ
|
| 230 |
+
L (θ)=−E (cid:2) logp (y |x) (cid:3) . (1)
|
| 231 |
+
SFT (x,y)∼D θ
|
| 232 |
+
Thisteacher-forcinglossallowsmodelstomemorizeextensivetrainingexamples,compellingthe
|
| 233 |
+
modeltoabsorbthisknowledge.
|
| 234 |
+
GroupRelativePolicyOptimization(GRPO).GRPOisanRLalgorithmthatexploresopen-ended
|
| 235 |
+
generationbycomparingcandidateoutputswithinagroup. Foreachinputx,thepolicyp samplesa
|
| 236 |
+
θ
|
| 237 |
+
set{y˜k}K ;arewardfunctionr (y˜k)iscomputedbasedonthecorrectnessoftheoutputanswer,
|
| 238 |
+
k=1 a
|
| 239 |
+
andeachsample’sadvantageAismeasuredrelativetotheothergroupmembers:
|
| 240 |
+
(cid:118)
|
| 241 |
+
A(y˜k) = r a (y˜ σ k) + − ε r¯ a, r¯ a = K 1 (cid:88) K r a (y˜j), σ= (cid:117) (cid:117) (cid:116) K 1 (cid:88) K (r a (y˜j)−r¯)2, (2)
|
| 242 |
+
j=1 j=1
|
| 243 |
+
4
|
| 244 |
+
--- PAGE 5 ---
|
| 245 |
+
PublishedasaconferencepaperatICLR2026
|
| 246 |
+
where ε is a small constant for numerical stability. The policy then updates its parameters by
|
| 247 |
+
minimizingthefollowingloss,regularisedbyaKLconstraint:
|
| 248 |
+
L (θ)=−E E (cid:104) min (cid:0) r (x,y˜)A(y˜),clip (cid:0) r (x,y˜);1−ϵ,1+ϵ (cid:1) A(y˜) (cid:1)(cid:105)
|
| 249 |
+
GRPO x∼D y˜∼pθ θ θ
|
| 250 |
+
+βD (cid:2) p (·|x)∥p (·|x) (cid:3) , where r (x,y˜) = p θ (y˜|x) . (3)
|
| 251 |
+
KL θ ref θ p (y˜|x)
|
| 252 |
+
old
|
| 253 |
+
TheclipandKLtermsworktogethertokeepeachupdateclosetosaferegionsoftheparameterspace:
|
| 254 |
+
theclipgatelimitsstepsizearoundtherolloutpolicyp ,whiletheKLterm(βD )tethersthe
|
| 255 |
+
old KL
|
| 256 |
+
policytothereferencep (typicallytheinitialmodel).
|
| 257 |
+
ref
|
| 258 |
+
GradientCompatibilityofSFTandGRPO.Below,werevealthattheoptimizationobjectivesof
|
| 259 |
+
SFTandGRPOareformallyequivalent,withtheformertargetingtheground-truthdatadistribution
|
| 260 |
+
andthelatteraninternalone.
|
| 261 |
+
ThegradientoftheSFTlossisstraightforward:
|
| 262 |
+
∇ L (θ)=−E [∇ logp (y |x)]. (4)
|
| 263 |
+
θ SFT (x,y)∼D θ θ
|
| 264 |
+
Similarly,theGRPOgradient(ignoringclippingandanyKL-penalty)canbewrittenas
|
| 265 |
+
∇ L (θ)=−E (cid:2) r (x,y˜)A(y˜)∇ logp (y˜|x) (cid:3) . (5)
|
| 266 |
+
θ GRPO x∼D, θ θ θ
|
| 267 |
+
y˜∼pold(·|x)
|
| 268 |
+
ThiscomparisonshowsthattheSFTgradientisaspecialcaseoftheGRPOgradient,obtainedwhen
|
| 269 |
+
theground-truthsampleisusedwithunitadvantage. Thisequivalenceenablesaunifiedlossthat
|
| 270 |
+
balancesexternalimitation(SFT)withinternalrefinement(GRPO).Achievingthisfusionrequires
|
| 271 |
+
dynamicallyweightingthetwosignals(§3.2)andensuringstylisticconsistencybetweenexternal
|
| 272 |
+
ground-truthandself-generatedoutputs(§3.3).
|
| 273 |
+
3.2 DYNAMICMEMORIZE–EXPLORE(DYME)
|
| 274 |
+
Torealizethiscomplementarity,weproposetheDynamicMemorize–Explore(DyME)paradigm,
|
| 275 |
+
whichadaptivelyswitchesbetweenSFTandGRPOateachtrainingstep. Inthefollowing,wefirst
|
| 276 |
+
outlinetheoverallpipelineandthenelaborateontheoptimizationproceduresforeachmode.
|
| 277 |
+
Overall.AsshowninFig.3a,eachtrainingstepbeginswithaninputx=(I,q),whereIistheimage
|
| 278 |
+
andq isaninstruction. ThepolicySVLMp generatesK responses{y˜k}K . Eachresponseis
|
| 279 |
+
θ k=1
|
| 280 |
+
parsedintoathinkingtraceandafinalanswer,whichisthenverifiedforcorrectnessusingpredefined
|
| 281 |
+
rules. Theverificationresultsfallintotwocategories: eitherallresponsesareincorrect(including
|
| 282 |
+
those that fail to parse), or at least one is correct. The decision rule: if at least one response is
|
| 283 |
+
correct, themodelproceedswithGRPO-basedexploration; otherwise, itfallsbacktoSFT-based
|
| 284 |
+
memorization. Formally,thetrainingmodeisswitchedas:
|
| 285 |
+
(cid:26) GRPO, if max r (y˜k)=1,
|
| 286 |
+
mode(x)= k a (6)
|
| 287 |
+
SFT, otherwise,
|
| 288 |
+
where r (y˜k) ∈ {0,1} indicates whether y˜k passes rule-based verification. Though simple, this
|
| 289 |
+
a
|
| 290 |
+
decisionruleishighlyeffective. Whenallresponsesareincorrect,theanswerrewardsareessentially
|
| 291 |
+
allzeroandthenormalizedadvantagesbecomedominatedbynoise,makingGRPOupdatesfora
|
| 292 |
+
smallSVLMunstable. Inthisregime,fallingbacktoSFTprovidesalow-variance,ground-truth
|
| 293 |
+
guidedgradient. Conversely,theappearanceofatleastonecorrectresponseindicatesthatthecurrent
|
| 294 |
+
policy has already discovered a feasible solution for this input, so GRPO can safely exploit the
|
| 295 |
+
relativeadvantagestodriveexploration.
|
| 296 |
+
GRPOMode. DyMEintroducesakeyrefinementtotheoriginalGRPO:beyondtheanswerreward
|
| 297 |
+
r ,itincorporatesanauxiliaryrewardr forthinkingtraces. Thisrewardiscomputedbyevaluating
|
| 298 |
+
a t
|
| 299 |
+
thegeneratedtracesagainstexpectedthinkingpatterns(e.g.,viatoken-levelF1scoreground-truth
|
| 300 |
+
comparison),promotingstructuredthinking.
|
| 301 |
+
Giventheserewards,weupdatethepolicyusingamodifiedGRPOobjective. Unlikethestandard
|
| 302 |
+
formulation(Eqs.2&3),weomittheKLpenaltyandclippingterms,asthedynamicintegrationof
|
| 303 |
+
5
|
| 304 |
+
--- PAGE 6 ---
|
| 305 |
+
PublishedasaconferencepaperatICLR2026
|
| 306 |
+
GiveGni vveinsu vails ufaacl tfsacts
|
| 307 |
+
AnyA cnoyr rceocrtr ect FromFr tohme ftihgeu frieg, uthree, vthaelu veasl rueepsr reespernetseedn bteyd by
|
| 308 |
+
GRPGORPO
|
| 309 |
+
resproenspseosn?ses? the tthhree eth breaers b caarns cbaen d biree cdtilrye cotblys eorbvseedr.ved.
|
| 310 |
+
TakiTnagk tihnegi rt hmeiera mn eyaienl dysie alnd so avne roavllerall
|
| 311 |
+
q q SwiStcwhietrcher or or averaavgeer aogf ea popf raopxpirmoaxtiemlya t8e.l3y, 8w.3h,i cwhhich
|
| 312 |
+
SVLSMVLM provpirdoevsi ad ecso an cciosne csiusme smuamrym oafr yt hoef the
|
| 313 |
+
distrdibisutrtiiobnu trieofnl ercetfeledc itnVe dtih siuenV actihlhs ueaC racth.lh eaCcrkth.eerc:k 0er: 0
|
| 314 |
+
I I SFTSFT
|
| 315 |
+
OveOravlelr pailpl epliipneeline VisuVails uRaelf Rineefriner
|
| 316 |
+
pip pip
|
| 317 |
+
AnsAwnesrw Reerw Raerwdard
|
| 318 |
+
GTGT VisuVails uRaelf Rineefriner To aTnosw aners wtheer qthuee sqtuioens,t i.o..n, ,s .t.e.,p s1t eisp 1to is to
|
| 319 |
+
(0, 1(,0 1, ,1 0, )1, 0) extraecxtt rtahcet vthaelu veasl fuoers Lfoart vLiaat, vPiao,r tPuogratlu,gal,
|
| 320 |
+
anda Rnodm Raonmiaa:n 3i.a0:, 38..09,, 8a.n9d, a 1n3d.0 1,3.0,
|
| 321 |
+
respreecstpiveecltyiv. eSltye.p S 2te ips 2to i sc atolc cualalcteu ltahtee the
|
| 322 |
+
VisuVails uCahl eCckheercker
|
| 323 |
+
averaavgeer aogf et hoefs teh veaselu veasl:u (e3s.:0 (+3 .80. 9+ +8.9 +
|
| 324 |
+
Add Athdidn kthiningk rienwg arredward logiltosgits 13.01) 3/. 03) = / 83. 3=. 8T.3h.e Trehfeorreef,o trhee, athvee raavgeer aisge is
|
| 325 |
+
(0, 1(+0, 1, +1+, 1, +0), 0) 8.3. 8.3.
|
| 326 |
+
VisuVails uCahl eCckheerc:k 1er.0: 1.0
|
| 327 |
+
GRGPOR PmOo dmeode SFT S FmTo dmeode
|
| 328 |
+
(a)ThepipelineforDyME. (b)Visualrefinerandchecker.
|
| 329 |
+
Figure3: WorkflowandmodulecomponentsofDyME.Ateachtrainingstep,DyMEdynamically
|
| 330 |
+
switchesbetweenmemorization(viaSFT)andexploration(viaGRPO)modesbasedonitsgenerations.
|
| 331 |
+
Visualsupervisionisintroducedthroughthevisualrefinerandvisualchecker. Therefinerenhances
|
| 332 |
+
the targets for memorization by incorporating richer visual elements (green), while the checker
|
| 333 |
+
rewardsthethinkingcontextgeneratedbasedontheirvisualrelevance.
|
| 334 |
+
SFTalreadystabilizestraining. Thisavoidsoverlyconservativeupdatesandyieldsacleanergradient
|
| 335 |
+
form,enablingsmootheralignmentbetweenSFTandGRPO:
|
| 336 |
+
L˜ (θ)=−E E [r (x,y˜)A(y˜)], (7)
|
| 337 |
+
GRPO x∼D y˜∼pθ(·|x) θ
|
| 338 |
+
where A(y˜k) is the group-normalized advantage calculated from the combined answer (r ) and
|
| 339 |
+
a
|
| 340 |
+
thinking(r )rewards,andr (x,y˜k)= pθ(y˜|x) istheimportancesamplingratio.
|
| 341 |
+
t θ pold(y˜|x)
|
| 342 |
+
SFTMode.WhentrainingfallsbacktoSFT,themodelisoptimizedtowardtheground-truthresponse
|
| 343 |
+
yusingthestandardsupervisedloss(Eq.1). Thisensuresthatwheneverthemodelfailstoexplore
|
| 344 |
+
effectively,itreceivesastable,ground-truth-guidedgradientupdatetocorrectitsbehavior.
|
| 345 |
+
DyMEObjective. Thefinallossdynamicallycombinesthetwoobjectivesbasedonresponsecorrect-
|
| 346 |
+
ness:
|
| 347 |
+
(cid:20) (cid:21) (cid:18) (cid:20) (cid:21)(cid:19)
|
| 348 |
+
L (θ)=1 maxr (y˜k)=1 ·L˜ (θ)+ 1−1 maxr (y˜k)=1 ·L (θ), (8)
|
| 349 |
+
DyME a GRPO a SFT
|
| 350 |
+
k k
|
| 351 |
+
where1[·]istheindicatorfunction,returning1iftheconditionholds,0otherwise.
|
| 352 |
+
3.3 VISIONSUPERVISION
|
| 353 |
+
DyME with Visual Supervision. While the aforementioned Pure DyME (using standard r and
|
| 354 |
+
t
|
| 355 |
+
staticground-truth)alreadyguaranteestrainingstabilitythroughitsdynamicswitchingmechanism,
|
| 356 |
+
wecanfurtherexploitthisdynamicnaturetomaximizeperformance. Specifically,theswitching
|
| 357 |
+
mechanismallowsustotailorthesupervisionsignalsateachoptimizationstep: refiningthereward
|
| 358 |
+
duringexplorationandenhancingtheground-truthduringmemorization. Tothisend,weintroducea
|
| 359 |
+
checker–refinerframework(cf. Fig.3b),whichconstitutestheFullDyME.
|
| 360 |
+
Thisframeworkreorganizestheground-truthtoadheretoapredefinedstructure,cruciallytransform-
|
| 361 |
+
ingitintoagroundedthinkingtrace. Therefinerrestructurestheexternalground-truthintostructured,
|
| 362 |
+
visuallygroundedresponses,whilethecheckerevaluatesself-generatedoutputsfortheirstructural
|
| 363 |
+
organizationandcoverageofvisualcontent. Werefertotheresultingsupervisionsignalscollectively
|
| 364 |
+
asvisionsupervision. TheimplementationiscarriedoutviaLLM-basedpromptengineering.
|
| 365 |
+
VisualFactsI arecentraltorealizingvisionsupervision. Theyaredefinedasfine-grainedvisual
|
| 366 |
+
c
|
| 367 |
+
componentsextractedfromeachimage,includingobjects,attributes,andstates. Theseelementsplay
|
| 368 |
+
6
|
| 369 |
+
--- PAGE 7 ---
|
| 370 |
+
PublishedasaconferencepaperatICLR2026
|
| 371 |
+
adualrole: theyprovideevidenceforevaluatinggenerationsagainsttheimageandserveasbuilding
|
| 372 |
+
blocksforconstructingcompleteground-truthresponses.
|
| 373 |
+
VisualChecker.Thevisualcheckerevaluatesresponsesalongtwodimensions:(i)whethertheoutput
|
| 374 |
+
containssufficientcorrectvisualelementscomparedtoI ,and(ii)whetheritalignsstylisticallywith
|
| 375 |
+
c
|
| 376 |
+
providedexamples.TheseexamplesmaybemanuallydefinedorextractedfromtheSFTground-truth.
|
| 377 |
+
VisualRefiner. TherefinerproducesvisuallygroundedresponsesforSFTbyleveragingthemodel’s
|
| 378 |
+
validatedexplorations. High-scoringtracesidentifiedbythevisualcheckerarestoredinadynamic
|
| 379 |
+
examplepool. AnLLMthendrawsfromthispooltogenerateground-truthresponses,integrating
|
| 380 |
+
structuraltemplateswithvisualfactsfromI andreferencingthecollectedexamples.
|
| 381 |
+
c
|
| 382 |
+
Inessence,theacquisitionofVisualFacts,theevaluationbytheVisualChecker,andthesynthesisby
|
| 383 |
+
theVisualRefinerareallimplementedviastructuredpromptengineeringusingQwen2.5-14B.Please
|
| 384 |
+
refertotheSupplementaryMaterialsforthefullpromptsusedinourpipeline.
|
| 385 |
+
4 EXPERIMENTS
|
| 386 |
+
TorigorouslyevaluateDyME,westructureourexperimentsintotwoparts: (1)AlgorithmicVal-
|
| 387 |
+
idation, where we evaluate “Pure DyME” in a controlled setting using offline data to isolate the
|
| 388 |
+
contributionofourdynamicswitchingmechanism;and(2)SystemEffectiveness,whereweevaluate
|
| 389 |
+
thefullDyMEpipeline(withVisualSupervision)acrossdiversedomainstodemonstrateitspractical
|
| 390 |
+
capabilityinempoweringSVLMs.
|
| 391 |
+
4.1 PARTI:ALGORITHMICVALIDATION(PUREDYME)
|
| 392 |
+
Setup.SinceSVLMslackintrinsicreasoningcapabilitiesandcannotautonomouslydiscovercomplex
|
| 393 |
+
reasoningpaths,pre-constructedCoTdataisamandatoryprerequisiteforalltrainingparadigms. We
|
| 394 |
+
thereforeevaluatedallmethodsonChartQA(Masryetal.,2022)usingLLaVA-OV-S(Lietal.,2024a),
|
| 395 |
+
the0.5Bvariant,withthreepre-constructedCoTdatasetsofvaryingqualities: Low(Undesigned)
|
| 396 |
+
containing unstructured traces (∼80 words); Medium (Standard) consisting of semi-structured
|
| 397 |
+
traces(∼89words)fromQwen2.5-14B;andHigh(Premium)comprisinghighlystructuredtraces
|
| 398 |
+
(∼142words)fromGPT-4o. Followingestablishedprotocols(Liuetal.,2023;Masryetal.,2022),
|
| 399 |
+
wereportrelaxedcorrectness,whichallowsa5%tolerancefornumericalanswers.
|
| 400 |
+
Wepresentathreefoldevaluationtovalidatedatarobustness,designoptimality,andgeneralization:
|
| 401 |
+
(1)RobustnesstoDataQuality. Table1(a)demonstratedDyME’ssuperiority. OnLowqualitydata,
|
| 402 |
+
PureDyME(61.9%)significantlyoutperformstheunstableTwo-stagebaseline(57.6%). Remarkably,
|
| 403 |
+
using only Medium data, it surpasses the SFT baseline trained on premium High (GPT-4o) data
|
| 404 |
+
(61.6%). ThisconfirmsthatDyMEactsasarobuststudent,effectivelymaximizingdataefficiency.
|
| 405 |
+
(2)OptimalityofBinarySwitching. Tovalidateourbinarydesign,wecompareditagainstthree
|
| 406 |
+
alternativeswitchingheuristicsinTable1(b): (i)RewardThresholding,whichswitchestoRLonly
|
| 407 |
+
if the batch average reward exceeds a threshold t; (ii) SFT Annealing, which applies a weighted
|
| 408 |
+
SFTlossalongsideRLateverystep;and(iii)SFTBudget,whichperformsfocusedSFTupdateson
|
| 409 |
+
accumulatedfailurecases(hardmining).
|
| 410 |
+
Results: RewardThresholdingprovesbrittle,collapsingatsuboptimalthresholds(t=0.5,52.4%).
|
| 411 |
+
SFTAnnealingincursaheavycomputationaltax(+25%)duetotheauxiliarySFTgradientcalculation.
|
| 412 |
+
SFTBudgetyieldsinferiorresults(59.6%)asoverwhelmingthemodelwithconcentratedfailures
|
| 413 |
+
destabilizeslearning. Incontrast,DyME’sbinaryswitchisparameter-free,efficient,andempirically
|
| 414 |
+
optimal(64.9%).
|
| 415 |
+
(3) Mechanism Generality. Going beyond the primary setup, while DyME is primarily tailored
|
| 416 |
+
forSVLMs, weverifytheuniversalityofitscoreswitchingmechanism(seeSupplementary). In
|
| 417 |
+
thetext-onlydomain,itbooststhesmall-scaleQwen2.5-0.5BonGSM8K(Cobbeetal.,2021)to
|
| 418 |
+
55.3%(+5.8%overGRPO),confirmingDyMEisaneffectiveparadigmforempoweringthinkingin
|
| 419 |
+
small-parametermodelsregardlessofmodality. Moreover,theparadigmscaleseffectively: onthe
|
| 420 |
+
strongerQwen2.5-VL-7B,itfurtherimprovesChartQAperformanceto89.6%(+2.3%).
|
| 421 |
+
7
|
| 422 |
+
--- PAGE 8 ---
|
| 423 |
+
PublishedasaconferencepaperatICLR2026
|
| 424 |
+
Table1: AlgorithmicValidationofPureDyME.(a)DyMEoutperformsSFTandTwo-stagevariants
|
| 425 |
+
(w/andw/oKLpenalty)acrossalldataqualities. (b)Thebinaryswitchismorerobustandefficient
|
| 426 |
+
thansoftorhard-miningalternatives(evaluatedonMediumdata).
|
| 427 |
+
(a)RobustnessacrossDataQuality (b)SwitchingStrategyAblation
|
| 428 |
+
Method Low Medium High Strategy Hyperparam. Acc. Cost
|
| 429 |
+
SFT 50.5 57.8 61.6 RewardThreshold t=0.5/0.8/0.9 52.4/64.1/63.4 None
|
| 430 |
+
Two-stage 57.6 59.9 54.5 SFTAnnealing Cosine 64.0 +25%
|
| 431 |
+
Two-stage(w/KL) 55.4 60.8 62.7 SFTBudget HardMining 59.6 Budget-dep.
|
| 432 |
+
PureDyME 61.9 64.9 68.5 BinarySwitch(Ours) – 64.9 Baseline
|
| 433 |
+
4.2 PARTII:SYSTEMEFFECTIVENESS(FULLDYME)
|
| 434 |
+
Havingvalidatedthealgorithmiccore,wenowevaluatetheFullDyMEpipeline,augmentedwith
|
| 435 |
+
VisualSupervision,acrossthreediversedomains:MedicalVQA,ChartUnderstanding,andGeometry.
|
| 436 |
+
Eachfollowedtheevaluationprotocolsofpriorwork(Zongetal.,2024).
|
| 437 |
+
Setup & Source of I . Unlike Part I, here we activate the Visual Supervision module to enable
|
| 438 |
+
c
|
| 439 |
+
the full online loop. Crucially, to demonstrate DyME’s capability to bootstrap from raw signals,
|
| 440 |
+
we utilize the “Undesigned” CoT data (defined in §4.1) derived from SLAKE (Liu et al., 2021),
|
| 441 |
+
ChartQA(Masryetal.,2022),andGeo170K(Gaoetal.,2025)asthecommontrainingsourcefor
|
| 442 |
+
all methods. Acquiring the necessary visual facts (I ) is a fully automated process: we leverage
|
| 443 |
+
c
|
| 444 |
+
standarddomaintools(e.g.,BiomedGPT(Zhangetal.,2024a)formedical,DePlot(Liuetal.,2023)
|
| 445 |
+
forcharts)orpromptgeneralistLLMs(e.g.,Qwen2.5(Team,2024))toparseimagesintostructured
|
| 446 |
+
textualdescriptions. Theautomatedpipelineandpromptsareincludedinthesupplementary.
|
| 447 |
+
EvaluationProtocol. Weusedofficialtrain-testsplitsforSLAKE(Accuracy/Recall)andChartQA
|
| 448 |
+
(Relaxedcorrectness). ForGeometry, sinceGeo170K(Gaoetal.,2025)providesnotestset, we
|
| 449 |
+
evaluatedAccuracyonMathVerse(Zhangetal.,2024b),consistentwithZongetal.(2024).
|
| 450 |
+
4.2.1 MAINRESULTS
|
| 451 |
+
DyMEvs. ExistingTrainingParadigms. Thecomprehen-
|
| 452 |
+
siveresultsinTable2showthatDyMEconsistentlydeliv-
|
| 453 |
+
erssubstantialgains. Notably,aftertrainingwithDyME,
|
| 454 |
+
SmolVLM improves from 49.9 to 55.6 (+5.7), LLaVA-
|
| 455 |
+
OV-S from 50.7 to 55.4 (+4.7), and InternVL2-S from
|
| 456 |
+
56.3to58.1(+1.8). Incontrast,existingparadigmstend
|
| 457 |
+
todegradeperformance(e.g.,SFTlowersSmolVLMto
|
| 458 |
+
44.1),validatingouranalysisthatSFTyieldspseudothink-
|
| 459 |
+
ingtracesandGRPOfacesadvantagecollapse(cf. Fig.4).
|
| 460 |
+
DyME effectively mitigates these issues. It promotes
|
| 461 |
+
grounded traces that are concise yet informative (cf.
|
| 462 |
+
Fig.5),aligningwellwiththelimitedcapacityofSVLMs.
|
| 463 |
+
Importantly,DyMEplacesminimaldemandsonthebase
|
| 464 |
+
Figure 4: Training rewards. GRPO
|
| 465 |
+
model: evenSmolVLM(0.5B)achievessubstantialgains,
|
| 466 |
+
andtwo-stagetrainingsufferfromsevere
|
| 467 |
+
anditstilldeliversimprovements(+2.6%)onextensively
|
| 468 |
+
advantagecollapse.
|
| 469 |
+
pretrainedmodelslikeInternVL2-S.Wefurthercorrobo-
|
| 470 |
+
ratedthesefindingsthroughmanualinspection,asdetailed
|
| 471 |
+
intheSupplementaryMaterial.
|
| 472 |
+
MatchingtheEfficacyofGPT-4oSupervisionwithOpen-SourceModels. Comparingresults
|
| 473 |
+
between Part I and Part II reveals a crucial finding: LLaVA-OV-S trained with the full DyME
|
| 474 |
+
pipeline(usingtheaccessibleQwen2.5-14B)achieves67.5%(Table2). Thiseffectivelymatchesthe
|
| 475 |
+
performanceofPureDyMEtrainedonexpensiveGPT-4odata(68.5%,cf. Table1). Thisprovesthat
|
| 476 |
+
fullDyMEallowsopen-sourcesupervisiontoachievetrainingoutcomescomparabletothosederived
|
| 477 |
+
fromtop-tierproprietarymodels,eliminatingtheneedforexpensivedataannotation.
|
| 478 |
+
DyME-trainedSVLMsCanBeCompetitivewithLVLMs. Weensuredfairnessbyexposingall
|
| 479 |
+
baselinestoourtrainingdata. AsshowninTable2,SVLMstrainedwithDyMEcansurpassstronger
|
| 480 |
+
8
|
| 481 |
+
--- PAGE 9 ---
|
| 482 |
+
PublishedasaconferencepaperatICLR2026
|
| 483 |
+
Table2: Comparisonsacrossthreedomains: medicalVQA,chartunderstanding,andgeometry
|
| 484 |
+
solving.TheevaluationfollowstheVLMEvalKitframework(Duanetal.,2024).ForSVLMs,existing
|
| 485 |
+
trainingparadigmsdegradetheirperformance,whereasDyMEconsistentlybringsimprovements. The
|
| 486 |
+
bestperformanceachievedbyeachSVLMishighlightedinbold,withtherelativeimprovementalso
|
| 487 |
+
indicated. Notably,afterbeingtrainedwithDyME,SVLMsachieveperformancecomparabletothat
|
| 488 |
+
ofMoVA(underlined).
|
| 489 |
+
Model ViT LLM Medical Chart Geometry Avg.
|
| 490 |
+
LVLMs
|
| 491 |
+
LLaVA-Med(Lietal.,2023) CLIP-ViT-L/14 Vicuna-7B 64.3 – – –
|
| 492 |
+
Cambrian-1(Tongetal.,2024) Hybrid-3B Llama3-8B – 72.6 22.0 –
|
| 493 |
+
LLaVA-1.5(Liuetal.,2024a) CLIP-ViT-L/14 Vicuna-7B 69.4 17.8 – –
|
| 494 |
+
LLaVA-1.6(Liuetal.,2024b) CLIP-ViT-L/14 Vicuna-7B 78.2 49.2 13.4 47.0
|
| 495 |
+
MoVA(Zongetal.,2024) Hybrid-3B Vicuna-7B 74.5 68.3 19.7 54.2
|
| 496 |
+
LLaVA-OV-L(Lietal.,2024a) SigLIP-SO400M Qwen2-7B 75.7 80.9 24.5 60.4
|
| 497 |
+
InternVL2-L(Chenetal.,2024) InternViT-300M InternLM2.5-7B 80.2 82.1 37.3 66.5
|
| 498 |
+
SVLMs
|
| 499 |
+
SmolVLM(Marafiotietal.,2025) SigLIP-93M SmolLM2-360M 72.1 63.2 14.6 49.9
|
| 500 |
+
+CoTSFT SigLIP-93M SmolLM2-360M 60.1 57.7 14.5 44.1
|
| 501 |
+
+GRPO SigLIP-93M SmolLM2-360M 61.1 53.8 17.1 44.0
|
| 502 |
+
+Two-stage SigLIP-93M SmolLM2-360M 59.4 60.1 16.7 45.4
|
| 503 |
+
+DyME SigLIP-93M SmolLM2-360M 78.1 69.7 18.9 55.6
|
| 504 |
+
(+6.0%) (+6.5%) (+4.3%) (+5.7%)
|
| 505 |
+
LLaVA-OV-S(Lietal.,2024a) SigLIP-400M Qwen2-0.5B 74.9 61.4 15.9 50.7
|
| 506 |
+
+Two-stage SigLIP-400M Qwen2-0.5B 74.5 52.9 16.5 48.0
|
| 507 |
+
+DyME SigLIP-400M Qwen2-0.5B 78.3 67.5 20.4 55.4
|
| 508 |
+
(+3.4%) (+6.1%) (+4.5%) (+4.7%)
|
| 509 |
+
InternVL2-S(Chenetal.,2024) InternViT-300M Qwen2-0.5B 78.3 71.9 18.7 56.3
|
| 510 |
+
+Two-stage InternViT-300M Qwen2-0.5B 73.6 55.7 17.1 48.8
|
| 511 |
+
+DyME InternViT-300M Qwen2-0.5B 80.0 74.5 19.8 58.1
|
| 512 |
+
(+1.7%) (+2.6%) (+1.1%) (+1.8%)
|
| 513 |
+
LVLMslikeMoVA(54.2)onthesespecializeddomains,withSmolVLMreaching55.6andLLaVA-
|
| 514 |
+
OV-S55.4. Asaresult,DyME-trainedSVLMsbecomereliableoptionsfortask-specificapplications
|
| 515 |
+
onresource-constrainededgedevices.
|
| 516 |
+
4.2.2 ABLATIONSTUDY
|
| 517 |
+
Todissectthesourceofthesegains,weconductedanablationstudytoanalyzethecontributionof
|
| 518 |
+
DyME’sfourcorecomponentswithinthefullpipeline: thememorizationmode,explorationmode,
|
| 519 |
+
visualrefiner,andvisualchecker. Table3showstheperformanceimpact.
|
| 520 |
+
DynamicSwitchingMechanism. Theresults
|
| 521 |
+
confirmthatMemorizationandExplorationare Table3: Ablationstudy. Model: LLaVA-OV-S.
|
| 522 |
+
symbiotic. Disabling memorization causes a
|
| 523 |
+
catastrophicdrop(55.4→43.9),effectivelyre-
|
| 524 |
+
DyMEVariant Medical Chart Geometry Average
|
| 525 |
+
vertingtounconstrained,unstableexploration.
|
| 526 |
+
DyME(full) 78.3 67.5 20.4 55.4
|
| 527 |
+
Conversely, removing exploration (50.4) re- w/omemorization 63.2 53.4 15.0 43.9(20.6%↓)
|
| 528 |
+
strictsthemodeltothestaticimitationofsub-
|
| 529 |
+
w/oexploration 75.5 61.3 14.5 50.4(9.0%↓)
|
| 530 |
+
w/ovisualrefiner 75.6 62.3 16.8 51.6(6.9%↓)
|
| 531 |
+
optimaldata. AsshowninFig.4,theirdynamic w/ovisualchecker 76.9 64.3 17.1 52.8(4.7%↓)
|
| 532 |
+
interplay prevents the advantage collapse ob-
|
| 533 |
+
servedinbaselines,ensuringoptimizationstabilitythroughoutthelearningprocess.
|
| 534 |
+
Visual Supervision. Removing the visual checker and refiner drops performance by 4.7% and
|
| 535 |
+
6.9%,respectively. Thisvalidatesthepivotalroleofvisualsupervisioninbootstrappingfromnoisy,
|
| 536 |
+
undesigneddata. GiventhelimitedcapacityofSVLMs,theyareeasilypronetohallucinationwhen
|
| 537 |
+
trainedonlow-qualitytraces. Thevisualcomponentsactasadynamicdenoiser,ensuringthatraw,
|
| 538 |
+
imperfectdataisfilteredandrefinedintogroundedvisualfacts(I )beforeoptimization,thusenabling
|
| 539 |
+
c
|
| 540 |
+
robustlearningevenfromweaksupervision.
|
| 541 |
+
9
|
| 542 |
+
--- PAGE 10 ---
|
| 543 |
+
PublishedasaconferencepaperatICLR2026
|
| 544 |
+
What is the difference between The difference between the Given AB // CD, angle 1 = 50.0,
|
| 545 |
+
the values of 2017 and 2016? values of 2017 and 2016 is then what the degree of angle 2? Angle 2 is equal to angle 3,
|
| 546 |
+
(Input Image) 19000. angle 2 is also 50 degrees.
|
| 547 |
+
(Original)
|
| 548 |
+
Extraction: Extraction:
|
| 549 |
+
data is value 36700 for 2017 AB is parallel to CD,
|
| 550 |
+
and 29000 for 2016 angle 1 = 50°.
|
| 551 |
+
Calculation: Calculation:
|
| 552 |
+
36700 - 29000 = 7700 angle 2 = 180° - angle 1
|
| 553 |
+
Conclusion: = 180° - 50° = 130°.
|
| 554 |
+
The difference between the Conclusion:
|
| 555 |
+
2017 and 2016 values is 7700.(DyME) The degree of angle 2 is 130°.
|
| 556 |
+
Answer: 7700 Answer: 130°
|
| 557 |
+
Figure 5: Showcases on chart understanding and geometry solving. We use LLaVA-OV-S
|
| 558 |
+
todemonstratetheresults. TheSVLMoriginallyproduceshallucinatedanswers(red), whilethe
|
| 559 |
+
DyME-trainedmodelgeneratesstructuredthinkingtraces(green)thatincorporategroundedvalues,
|
| 560 |
+
effectivelyimprovingtheperformance.
|
| 561 |
+
4.3 TRAININGEFFICIENCY&DISCUSSION
|
| 562 |
+
Weanalyzethecomputationalefficiencyandperformancetrade-offsassociatedwithdifferentconfig-
|
| 563 |
+
urationsofDyME.ThecomparativeresultsaredetailedinTable4.
|
| 564 |
+
Computational Efficiency vs. Data Cost. The frame-
|
| 565 |
+
work offers two distinct operating regimes catering to Table4: Cost-BenefitAnalysis. Time
|
| 566 |
+
different resource profiles. Pure DyME represents the measuredinsec/step. Runon8xH800.
|
| 567 |
+
high-efficiency regime: when offline CoT data is pre-
|
| 568 |
+
constructed,itmaintainstrainingthroughputcomparable
|
| 569 |
+
Method Ext.Model Time Acc.
|
| 570 |
+
tostandardGRPO(∼14s/step)whiledeliveringsuperior
|
| 571 |
+
GRPO(Baseline) Qwen2.5-14B† 14.8s 60.8
|
| 572 |
+
performance. Incontrast,FullDyME(withVisualSupervi- PureDyME Qwen2.5-14B† 14.0s 64.9
|
| 573 |
+
sion)prioritizesdataautonomy. Whiletheonlineinterac- PureDyME GPT-4o† 19.1s 68.5
|
| 574 |
+
tionintroducesacomputationaloverhead(∼1.6×training FullDyME Qwen2.5-7B 21.2s 66.8
|
| 575 |
+
FullDyME Qwen2.5-14B 23.4s 67.5
|
| 576 |
+
time),itenablesthemodeltobootstraphigh-performance †Usedforofflinedataconstructiononly.
|
| 577 |
+
reasoningsolelyfromopen-sourcemodels,bypassingthe
|
| 578 |
+
dependencyonexpensive,proprietarydataannotation(e.g.,GPT-4o).
|
| 579 |
+
Sensitivity to External Model Capacity. For Full DyME, we further examine the impact of the
|
| 580 |
+
externalhelper’ssizeonsystemperformance. AsshowninTable4, replacingtheQwen2.5-14B
|
| 581 |
+
helperwiththesmaller7Bvariantresultsinanegligibleperformancevariation(67.5%→66.8%).
|
| 582 |
+
Thisindicatesthatourstructuredpromptengineeringeffectivelydecomposescomplexreasoning
|
| 583 |
+
tasks, allowing even smaller external models to provide sufficient guidance for SVLMs without
|
| 584 |
+
necessitatingheavy-weightmodels.
|
| 585 |
+
ApplicabilityofVisualSupervision. TheeffectivenessoftheVisualSupervisionmodulerelieson
|
| 586 |
+
theexplicitextractionofVisualFacts(I ). Thisprocesscreatesspecificapplicabilityboundaries. For
|
| 587 |
+
c
|
| 588 |
+
domainsinvolvingabstractsemantics(e.g.,ironyinmemes)orunstructuredperception(e.g.,dense
|
| 589 |
+
crowds),convertingholisticvisualsignalsintodiscretetextmayresultininformationloss. Insuch
|
| 590 |
+
scenarios,revertingtothePureDyMEparadigmservesasamorerobustalternative.
|
| 591 |
+
5 CONCLUSION
|
| 592 |
+
Inthiswork,weintroducedDyME,anoveltrainingparadigmdesignedtoempowerthinkingcapa-
|
| 593 |
+
bilitieswithinSVLMs. Atitscore,DyMEcombinesmemorization(viaSFT)modeandexploration
|
| 594 |
+
(viaRLVR)modethroughadynamicswitchingmechanism. Ourexperimentsdemonstratethatthis
|
| 595 |
+
approachnotonlyresolvesthecriticaltrade-offbetweenthesetwomodesbutalsoyieldssubstantial
|
| 596 |
+
performance gains on a wide spectrum of vision tasks, from recognition-intensive to reasoning-
|
| 597 |
+
intensivescenarios. ThesuccessofDyMEisattributedtoitscarefullydesignedcomponents: the
|
| 598 |
+
dynamicswitchingmechanismaddressespseudothinkingtracesandadvantagecollapse,whilethe
|
| 599 |
+
visualcheckerandrefinerprovidecoordinated,high-qualityvisualsupervision. Itimposesminimal
|
| 600 |
+
requirementsonthebaseVLM,makingitbroadlyapplicabletoawiderangeofmodels,including
|
| 601 |
+
extremelylightweightSVLMs. Therefore,DyMEservesasthepracticalsolutionforempowering
|
| 602 |
+
SVLMstothink.
|
| 603 |
+
10
|
| 604 |
+
--- PAGE 11 ---
|
| 605 |
+
PublishedasaconferencepaperatICLR2026
|
| 606 |
+
ACKNOWLEDGMENT
|
| 607 |
+
ThisworkwassupportedbytheHongKongSARRGCGeneralResearchFund(16219025),National
|
| 608 |
+
NaturalScienceFoundationofChinaYoungScholarFundCategoryB(62522216),NationalNatural
|
| 609 |
+
ScienceFoundationofChinaYoungScholarFundCategoryC(62402408),andHongKongSAR
|
| 610 |
+
RGCEarlyCareerScheme(26208924).
|
| 611 |
+
REFERENCES
|
| 612 |
+
Alon Albalak, Akshat Shrivastava, Chinnadhurai Sankar, Adithya Sagar, and Mike Ross. Data-
|
| 613 |
+
efficiencywithasinglegpu: Anexplorationoftransfermethodsforsmalllanguagemodels. arXiv
|
| 614 |
+
preprintarXiv:2210.03871,2022.
|
| 615 |
+
JinzeBai,ShuaiBai,ShushengYang,ShijieWang,SinanTan,PengWang,JunyangLin,ChangZhou,
|
| 616 |
+
andJingrenZhou. Qwen-VL:Aversatilevision-languagemodelforunderstanding,localization,
|
| 617 |
+
textreading,andbeyond. arXivpreprintarXiv:2308.12966,2023.
|
| 618 |
+
ShuaiBai,KeqinChen,XuejingLiu,JialinWang,WenbinGe,SiboSong,KaiDang,PengWang,
|
| 619 |
+
ShijieWang,JunTang,HumenZhong,YuanzhiZhu,MingkunYang,ZhaohaiLi,JianqiangWan,
|
| 620 |
+
PengfeiWang,WeiDing,ZherenFu,YihengXu,JiaboYe,XiZhang,TianbaoXie,ZesenCheng,
|
| 621 |
+
HangZhang,ZhiboYang,HaiyangXu,andJunyangLin. Qwen2.5-VLtechnicalreport. arXiv
|
| 622 |
+
preprintarXiv:2502.13923,2025.
|
| 623 |
+
HardyChen,HaoqinTu,FaliWang,HuiLiu,XianfengTang,XinyaDu,YuyinZhou,andCihangXie.
|
| 624 |
+
SFTorRL?anearlyinvestigationintotrainingR1-likereasoninglargevision-languagemodels.
|
| 625 |
+
arXivpreprintarXiv:2504.11468,2025a.
|
| 626 |
+
LiangChen,LeiLi,HaozheZhao,YifanSong,andVinci. R1-V:Reinforcingsupergeneralization
|
| 627 |
+
abilityinvision-languagemodelswithlessthan$3. https://github.com/Deep-Agent/
|
| 628 |
+
R1-V,2025b. Accessed: 2025-02-02.
|
| 629 |
+
ZheChen,WeiyunWang,YueCao,YangzhouLiu,ZhangweiGao,ErfeiCui,JinguoZhu,Shenglong
|
| 630 |
+
Ye,HaoTian,ZhaoyangLiu,etal. Expandingperformanceboundariesofopen-sourcemultimodal
|
| 631 |
+
modelswithmodel,data,andtest-timescaling. arXivpreprintarXiv:2412.05271,2024.
|
| 632 |
+
TianzheChu,YuexiangZhai,JihanYang,ShengbangTong,SainingXie,DaleSchuurmans,QuocV
|
| 633 |
+
Le,SergeyLevine,andYiMa.SFTmemorizes,RLgeneralizes:Acomparativestudyoffoundation
|
| 634 |
+
modelpost-training. arXivpreprintarXiv:2501.17161,2025.
|
| 635 |
+
Karl Cobbe, Vineet Kosaraju, Mohammad Bavarian, Mark Chen, Heewoo Jun, Lukasz Kaiser,
|
| 636 |
+
MatthiasPlappert,JerryTworek,JacobHilton,ReiichiroNakano,etal. Trainingverifierstosolve
|
| 637 |
+
mathwordproblems. arXivpreprintarXiv:2110.14168,2021.
|
| 638 |
+
DeepSeek, Inc. DeepSeek-R1 Release. https://api-docs.deepseek.com/news/
|
| 639 |
+
news250120,January2025. Accessed: Jun.21,2025.
|
| 640 |
+
HaodongDuan,JunmingYang,YuxuanQiao,XinyuFang,LinChen,YuanLiu,XiaoyiDong,Yuhang
|
| 641 |
+
Zang,PanZhang,JiaqiWang,etal. VLMEvalKit: Anopen-sourcetoolkitforevaluatinglarge
|
| 642 |
+
multi-modalitymodels. InACMMM,2024.
|
| 643 |
+
Jiahui Gao, Renjie Pi, Jipeng Zhang, Jiacheng Ye, Wanjun Zhong, Yufei Wang, Lanqing Hong,
|
| 644 |
+
JianhuaHan,HangXu,ZhenguoLi,etal. G-LLaVA:Solvinggeometricproblemwithmulti-modal
|
| 645 |
+
largelanguagemodel. InICLR,2025.
|
| 646 |
+
AkashGhosh, ArkadeepAcharya, SriparnaSaha, VinijaJain, andAmanChadha. Exploringthe
|
| 647 |
+
frontierofvision-languagemodels:Asurveyofcurrentmethodologiesandfuturedirections. arXiv
|
| 648 |
+
preprintarXiv:2404.07214,2024.
|
| 649 |
+
Daya Guo, Dejian Yang, Haowei Zhang, Junxiao Song, Ruoyu Zhang, Runxin Xu, Qihao Zhu,
|
| 650 |
+
ShirongMa,PeiyiWang,XiaoBi,etal. Deepseek-R1: IncentivizingreasoningcapabilityinLLMs
|
| 651 |
+
viareinforcementlearning. arXivpreprintarXiv:2501.12948,2025.
|
| 652 |
+
VikKorrapati. Moondream. https://moondream.ai/,2024. Accessed: 2025-03-27.
|
| 653 |
+
11
|
| 654 |
+
--- PAGE 12 ---
|
| 655 |
+
PublishedasaconferencepaperatICLR2026
|
| 656 |
+
Yuxiang Lai, Jike Zhong, Ming Li, Shitian Zhao, and Xiaofeng Yang. Med-R1: Reinforce-
|
| 657 |
+
ment learning for generalizable medical reasoning in vision-language models. arXiv preprint
|
| 658 |
+
arXiv:2503.13939,2025.
|
| 659 |
+
BoLi,YuanhanZhang,DongGuo,RenruiZhang,FengLi,HaoZhang,KaichenZhang,Yanwei
|
| 660 |
+
Li,ZiweiLiu,andChunyuanLi. LLaVA-OneVision: Easyvisualtasktransfer. arXivpreprint
|
| 661 |
+
arXiv:2408.03326,2024a.
|
| 662 |
+
Chunyuan Li, Cliff Wong, Sheng Zhang, Naoto Usuyama, Haotian Liu, Jianwei Yang, Tristan
|
| 663 |
+
Naumann,HoifungPoon,andJianfengGao. LLaVA-Med: Trainingalargelanguage-and-vision
|
| 664 |
+
assistantforbiomedicineinoneday. AdvancesinNeuralInformationProcessingSystems,36:
|
| 665 |
+
28541–28564,2023.
|
| 666 |
+
Zhuowan Li, Bhavan Jasani, Peng Tang, and Shabnam Ghadar. Synthesize step-by-step: Tools
|
| 667 |
+
templatesandLLMsasdatageneratorsforreasoning-basedchartVQA. InCVPR,2024b.
|
| 668 |
+
BoLiu,Li-MingZhan,LiXu,LinMa,YanYang,andXiao-MingWu. SLAKE:Asemantically-
|
| 669 |
+
labeledknowledge-enhanceddatasetformedicalvisualquestionanswering. InISBI,2021.
|
| 670 |
+
FangyuLiu,JulianEisenschlos,FrancescoPiccinno,SyrineKrichene,ChenxiPang,KentonLee,
|
| 671 |
+
MandarJoshi,WenhuChen,NigelCollier,andYaseminAltun. DePlot: One-shotvisuallanguage
|
| 672 |
+
reasoningbyplot-to-tabletranslation. InFindingsoftheACL,2023.
|
| 673 |
+
HaotianLiu,ChunyuanLi,YuhengLi,andYongJaeLee. Improvedbaselineswithvisualinstruction
|
| 674 |
+
tuning. InCVPR,2024a.
|
| 675 |
+
Haotian Liu, Chunyuan Li, Yuheng Li, Bo Li, Yuanhan Zhang, Sheng Shen, and Yong Jae Lee.
|
| 676 |
+
LLaVA-NeXT:Improvedreasoning,OCR,andworldknowledge,January2024b. URLhttps:
|
| 677 |
+
//llava-vl.github.io/blog/2024-01-30-llava-next/.
|
| 678 |
+
JiazhenLiuandLongChen. Segmentationasaplug-and-playcapabilityforfrozenmultimodalllms.
|
| 679 |
+
arXivpreprintarXiv:2510.16785,2025.
|
| 680 |
+
JiazhenLiu,MingkuanFeng,andLongChen. Better,stronger,faster: Tacklingthetrilemmainmllm-
|
| 681 |
+
basedsegmentationwithsimultaneoustextualmaskprediction. arXivpreprintarXiv:2512.00395,
|
| 682 |
+
2025a.
|
| 683 |
+
JiazhenLiu,YuhanFu,RuobingXie,RunquanXie,XingwuSun,FengzongLian,ZhanhuiKang,
|
| 684 |
+
andXirongLi. PhD:Achatgpt-promptedvisualhallucinationevaluationdataset. InCVPR,2025b.
|
| 685 |
+
ZiyuLiu,ZeyiSun,YuhangZang,XiaoyiDong,YuhangCao,HaodongDuan,DahuaLin,andJiaqi
|
| 686 |
+
Wang. Visual-RFT:Visualreinforcementfine-tuning. arXivpreprintarXiv:2503.01785,2025c.
|
| 687 |
+
Andre´sMarafioti,OrrZohar,MiquelFarre´,MerveNoyan,ElieBakouch,PedroCuenca,CyrilZakka,
|
| 688 |
+
Loubna Ben Allal, Anton Lozhkov, Nouamane Tazi, et al. SmolVLM: Redefining small and
|
| 689 |
+
efficientmultimodalmodels. arXivpreprintarXiv:2504.05299,2025.
|
| 690 |
+
AhmedMasry,DoXuanLong,JiaQingTan,ShafiqJoty,andEnamulHoque.ChartQA:Abenchmark
|
| 691 |
+
for question answering about charts with visual and logical reasoning. In Smaranda Muresan,
|
| 692 |
+
PreslavNakov,andAlineVillavicencio(eds.),FindingsoftheACL,May2022.
|
| 693 |
+
OpenAI. IntroducingOpenAIo1. https://openai.com/o1/,December2024. Accessed: Jun.
|
| 694 |
+
21,2025.
|
| 695 |
+
Yingzhe Peng, Gongrui Zhang, Miaosen Zhang, Zhiyuan You, Jie Liu, Qipeng Zhu, Kai Yang,
|
| 696 |
+
XingzhongXu,XinGeng,andXuYang. LMM-R1: Empowering3BLMMswithstrongreasoning
|
| 697 |
+
abilitiesthroughtwo-stagerule-basedrl. arXivpreprintarXiv:2503.07536,2025.
|
| 698 |
+
Zhihong Shao, Peiyi Wang, Qihao Zhu, Runxin Xu, Junxiao Song, Xiao Bi, Haowei Zhang,
|
| 699 |
+
Mingchuan Zhang, YK Li, Y Wu, et al. DeepSeekMath: Pushing the limits of mathematical
|
| 700 |
+
reasoninginopenlanguagemodels. arXivpreprintarXiv:2402.03300,2024.
|
| 701 |
+
HaozhanShen,PengLiu,JingchengLi,ChunxinFang,YiboMa,JiajiaLiao,QiaoliShen,Zilun
|
| 702 |
+
Zhang,KangjiaZhao,QianqianZhang,RuochenXu,andTianchengZhao. VLM-R1: Astable
|
| 703 |
+
andgeneralizableR1-stylelargevision-languagemodel. arXivpreprintarXiv:2504.07615,2025.
|
| 704 |
+
12
|
| 705 |
+
--- PAGE 13 ---
|
| 706 |
+
PublishedasaconferencepaperatICLR2026
|
| 707 |
+
QwenTeam. Qwen2.5: Apartyoffoundationmodels,September2024. URLhttps://qwenlm.
|
| 708 |
+
github.io/blog/qwen2.5/.
|
| 709 |
+
PeterTong,EllisBrown,PenghaoWu,SanghyunWoo,AdithyaJairamVedagiriIYER,SaiCharitha
|
| 710 |
+
Akula,ShushengYang,JihanYang,ManojMiddepogu,ZitengWang,etal. Cambrian-1: Afully
|
| 711 |
+
open,vision-centricexplorationofmultimodalLLMs. AdvancesinNeuralInformationProcessing
|
| 712 |
+
Systems,37:87310–87356,2024.
|
| 713 |
+
RenqiuXia,BoZhang,HanchengYe,XiangchaoYan,QiLiu,HongbinZhou,ZijunChen,PengYe,
|
| 714 |
+
MinDou,BotianShi,etal. ChartX&ChartVLM:Aversatilebenchmarkandfoundationmodel
|
| 715 |
+
forcomplicatedchartreasoning. arXivpreprintarXiv:2402.12185,2024.
|
| 716 |
+
RenqiuXia,MingshengLi,HanchengYe,WenjieWu,HongbinZhou,JiakangYuan,TianshuoPeng,
|
| 717 |
+
XinyuCai,XiangchaoYan,BinWang,etal. GeoX:Geometricproblemsolvingthroughunified
|
| 718 |
+
formalizedvision-languagepre-training. InICLR,2025.
|
| 719 |
+
Guowei Xu, Peng Jin, Hao Li, Yibing Song, Lichao Sun, and Li Yuan. LLaVA-CoT: Let vision
|
| 720 |
+
languagemodelsreasonstep-by-step. arXivpreprintarXiv:2411.10440,2024.
|
| 721 |
+
Jianhao Yan, Yafu Li, Zican Hu, Zhi Wang, Ganqu Cui, Xiaoye Qu, Yu Cheng, and Yue Zhang.
|
| 722 |
+
Learningtoreasonunderoff-policyguidance. arXivpreprintarXiv:2504.14945,2025.
|
| 723 |
+
LeleYang,MuxiDiao,KongmingLiang,andZhanyuMa. GRPOforLLaVA. https://github.
|
| 724 |
+
com/PRIS-CV/GRPO-for-Llava,2025a.
|
| 725 |
+
YiYang,XiaoxuanHe,HongkunPan,XiyanJiang,YanDeng,XingtaoYang,HaoyuLu,Dacheng
|
| 726 |
+
Yin, Fengyun Rao, Minfeng Zhu, et al. R1-OneVision: Advancing generalized multimodal
|
| 727 |
+
reasoningthroughcross-modalformalization. arXivpreprintarXiv:2503.10615,2025b.
|
| 728 |
+
YuexiangZhai,ShengbangTong,XiaoLi,MuCai,QingQu,YongJaeLee,andYiMa. Investigating
|
| 729 |
+
thecatastrophicforgettinginmultimodallargelanguagemodelfine-tuning. InCPAL,2023.
|
| 730 |
+
JingyiZhang,JiaxingHuang,HuanjinYao,ShunyuLiu,XikunZhang,ShijianLu,andDachengTao.
|
| 731 |
+
R1-VL:Learningtoreasonwithmultimodallargelanguagemodelsviastep-wisegrouprelative
|
| 732 |
+
policyoptimization. arXivpreprintarXiv:2503.12937,2025a.
|
| 733 |
+
KaiZhang,RongZhou,EashanAdhikarla,ZhilingYan,YixinLiu,JunYu,ZhengliangLiu,Xun
|
| 734 |
+
Chen,BrianDDavison,HuiRen,etal. Ageneralistvision–languagefoundationmodelfordiverse
|
| 735 |
+
biomedicaltasks. NatureMedicine,pp.1–13,2024a.
|
| 736 |
+
RenruiZhang,DongzhiJiang,YichiZhang,HaokunLin,ZiyuGuo,PengshuoQiu,AojunZhou,
|
| 737 |
+
PanLu,Kai-WeiChang,PengGao,etal. MathVerse: Doesyourmulti-modalllmtrulyseethe
|
| 738 |
+
diagramsinvisualmathproblems? 2024b.
|
| 739 |
+
WenhaoZhang,YuexiangXie,YuchangSun,YanxiChen,GuoyinWang,YaliangLi,BolinDing,
|
| 740 |
+
andJingrenZhou. On-policyRLmeetsoff-policyexperts: Harmonizingsupervisedfine-tuning
|
| 741 |
+
andreinforcementlearningviadynamicweighting. arXivpreprintarXiv:2508.11408,2025b.
|
| 742 |
+
ZhuoshengZhang,AstonZhang,MuLi,HaiZhao,GeorgeKarypis,andAlexSmola. Multimodal
|
| 743 |
+
chain-of-thoughtreasoninginlanguagemodels. arXivpreprintarXiv:2302.00923,2023.
|
| 744 |
+
Baichuan Zhou, Ying Hu, Xi Weng, Junlong Jia, Jie Luo, Xien Liu, Ji Wu, and Lei
|
| 745 |
+
Huang. TinyLLaVA: A framework of small-scale large multimodal models. arXiv preprint
|
| 746 |
+
arXiv:2402.14289,2024.
|
| 747 |
+
ZhuofanZong,BingqiMa,DazhongShen,GuangluSong,HaoShao,DongzhiJiang,HongshengLi,
|
| 748 |
+
andYuLiu. MoVA:Adaptingmixtureofvisionexpertstomultimodalcontext. InNeurIPS,2024.
|
| 749 |
+
13
|
| 750 |
+
--- PAGE 14 ---
|
| 751 |
+
PublishedasaconferencepaperatICLR2026
|
| 752 |
+
Empowering Small VLMs to Think with Dynamic Memorization
|
| 753 |
+
and Exploration
|
| 754 |
+
SupplementaryMaterial
|
| 755 |
+
Inthesupplementarymaterials,wereport:
|
| 756 |
+
• LLMinstructionsusedforconstructingvisionsupervision(§S1);
|
| 757 |
+
• Detailedexperimentalsetupandadditionalexperimentalresults(§S2);
|
| 758 |
+
• ShowcasesofSVLMstrainedviaDyMEperformingonmedicalVQA,chartunderstanding,
|
| 759 |
+
andgeometryproblemsolving(§S3);
|
| 760 |
+
S1 LLM INSTRUCTIONS FOR VISION SUPERVISION
|
| 761 |
+
TheinstructionsforconstructingI ,thevisualrefiner,andthevisualcheckerarelistedasfollows.
|
| 762 |
+
c
|
| 763 |
+
S1.1 INSTRUCTIONSFOREXTRACTINGVISUALELEMENTS
|
| 764 |
+
I isprimarilyderivedfromtwosources: groundtruthcaptions,andtheoutputsfromspecialized
|
| 765 |
+
c
|
| 766 |
+
toolssuchasthechart-parsingmodelDeplot. PromptS1isemployedtoextractvisualelementsfrom
|
| 767 |
+
captions.
|
| 768 |
+
1 You are a helpful assistant that analyzes images and provides visual
|
| 769 |
+
facts.
|
| 770 |
+
2 Your response MUST be a single, valid JSON object.
|
| 771 |
+
3 The JSON object should contain:
|
| 772 |
+
4 1. "description": A detailed and accurate description of the image.
|
| 773 |
+
5 2. "objects": A list of key objects, including their name, attributes,
|
| 774 |
+
and approximate position in the image.
|
| 775 |
+
6
|
| 776 |
+
7 Example format:
|
| 777 |
+
8 {
|
| 778 |
+
9 "description": "A person riding a bicycle on a city street.... (
|
| 779 |
+
detailed description here)",
|
| 780 |
+
10 "objects": [
|
| 781 |
+
11 {"name": "person", "attributes": ["wearing helmet", "blue shirt"], "
|
| 782 |
+
position": "center"},
|
| 783 |
+
12 {"name": "bicycle", "attributes": ["red", "mountain bike"], "position
|
| 784 |
+
": "center"},
|
| 785 |
+
13 {"name": "street", "attributes": ["asphalt", "wet"], "position": "
|
| 786 |
+
bottom"}
|
| 787 |
+
14 ]
|
| 788 |
+
15
|
| 789 |
+
16 Analyze the attached image and provide the visual facts in the required
|
| 790 |
+
JSON format.
|
| 791 |
+
17 For context, the user will be asked this question about the image (do not
|
| 792 |
+
answer the question, just use it for context):
|
| 793 |
+
18 "{question}"
|
| 794 |
+
19 }
|
| 795 |
+
PromptS1: AutomatedVisualFactExtraction
|
| 796 |
+
S1.2 INSTRUCTIONSFORVISUALREFINER
|
| 797 |
+
ThevisualrefinerprimarilytakesthinkingtemplatesfromatemplatepoolandI asinputtoconstruct
|
| 798 |
+
c
|
| 799 |
+
thegroundtruthforSFT.
|
| 800 |
+
1 Given:
|
| 801 |
+
2 <IC>: the data of an image
|
| 802 |
+
3 <Q>: a question
|
| 803 |
+
1
|
| 804 |
+
--- PAGE 15 ---
|
| 805 |
+
PublishedasaconferencepaperatICLR2026
|
| 806 |
+
4 <A>: a reference answer
|
| 807 |
+
5 <T>: a writing template
|
| 808 |
+
6
|
| 809 |
+
7 Goal:
|
| 810 |
+
8 Transform the visual information in <IC> into a textualized data
|
| 811 |
+
description and incorporate it into a smooth, natural explanation
|
| 812 |
+
that reasons why the correct answer to <Q> is <A>, using the format
|
| 813 |
+
and tone defined by <T>.
|
| 814 |
+
9
|
| 815 |
+
10 Example:
|
| 816 |
+
11 <IC>: [
|
| 817 |
+
12 {"object": "bar", "attributes": ["˜120k", "Q4"], "label": "Product A"},
|
| 818 |
+
13 {"object": "bar", "attributes": ["˜150k", "Q4"], "label": "Product B"},
|
| 819 |
+
14 {"object": "bar", "attributes": ["˜90k", "Q4"], "label": "Product C"},
|
| 820 |
+
15 {"title": "Quarterly Revenue"}
|
| 821 |
+
16 ]
|
| 822 |
+
17 <Q>: Which product has the highest revenue in Q4?
|
| 823 |
+
18 <A>: product b
|
| 824 |
+
19 <T>:
|
| 825 |
+
20 [Extraction] Extract the relevant data points from <IC> that are
|
| 826 |
+
needed to answer <Q>.
|
| 827 |
+
21 [Calculation] Analyze and compare the extracted data to determine the
|
| 828 |
+
correct result.
|
| 829 |
+
22 [Conclusion] Based on the analysis, state why <A> is the correct
|
| 830 |
+
answer to <Q>.
|
| 831 |
+
23
|
| 832 |
+
24 <Output>:
|
| 833 |
+
25 [Extraction] The image shows a chart comparing three items-Product A,
|
| 834 |
+
Product B, and Product C-in Q4. The bar heights indicate their
|
| 835 |
+
revenues: Product A at about 120k, Product B at about 150k, and
|
| 836 |
+
Product C at about 90k.
|
| 837 |
+
26 [Calculation] Comparing these values, Product B’s revenue (150k) is
|
| 838 |
+
greater than both Product A’s (120k) and Product C’s (90k).
|
| 839 |
+
27 [Conclusion] Since Product B has the highest revenue among the three,
|
| 840 |
+
the correct answer to the question "Which product has the
|
| 841 |
+
highest revenue in Q4?" is product b.
|
| 842 |
+
28
|
| 843 |
+
29 Now, according to the requirements and the examples above, convert my
|
| 844 |
+
input into the target reasoning text. Please give me the result
|
| 845 |
+
directly without any explanation or description.
|
| 846 |
+
30
|
| 847 |
+
31 <IC>: %s
|
| 848 |
+
32 <Q>: %s
|
| 849 |
+
33 <A>: %s
|
| 850 |
+
34 <T>: %s
|
| 851 |
+
35 <Output>:
|
| 852 |
+
PromptS2: Ground-truthconstructionforChartUnderstandingSFT
|
| 853 |
+
Promptsfortheotherdomainsfollowasimilardesign.
|
| 854 |
+
S1.3 INSTRUCTIONSFORVISUALCHECKER
|
| 855 |
+
Thevisualcheckerisprimarilyresponsibleforscoringthethinkingtraceofresponsesgeneratedin
|
| 856 |
+
theGRPOprocess. Itevaluatesthesetraceswithreferencetoexemplars,basedontheirfluencyand
|
| 857 |
+
thedegreetowhichthementionedvisualelementsalignwithI . Promptsfortheotherdomains
|
| 858 |
+
c
|
| 859 |
+
followasimilardesign.
|
| 860 |
+
1 Given
|
| 861 |
+
2 <IC>: the data of an image
|
| 862 |
+
3 <Q>: a question
|
| 863 |
+
4 <A>: a reference answer
|
| 864 |
+
5 <R>: a reasoning text
|
| 865 |
+
6
|
| 866 |
+
2
|
| 867 |
+
--- PAGE 16 ---
|
| 868 |
+
PublishedasaconferencepaperatICLR2026
|
| 869 |
+
7 Goal:
|
| 870 |
+
8 Assess whether <R> correctly and reasonably uses visible data in <IC> to
|
| 871 |
+
justify that the correct answer to <Q> is <A>. Rate the quality as
|
| 872 |
+
low / medium / high according to:
|
| 873 |
+
9 (a) low: Does not use data from <IC> at all, or the language is not
|
| 874 |
+
fluent/natural, or it fails to indicate the answer to <Q> is <A>.
|
| 875 |
+
10 (b) medium: Uses data from <IC> and is written fluently, but the
|
| 876 |
+
reasoning is overly brief or insufficiently clear.
|
| 877 |
+
11 (c) high: Uses data from <IC> and is written fluently; the reasoning
|
| 878 |
+
progresses step by step with depth, each step is correct and
|
| 879 |
+
reasonable; the data from <IC> appears exactly where it should;
|
| 880 |
+
overall, the reasoning text provides very strong support that the
|
| 881 |
+
answer to <Q> is <A>.
|
| 882 |
+
12
|
| 883 |
+
13 Example:
|
| 884 |
+
14 <IC>: [
|
| 885 |
+
15 {"object": "bar", "attributes": ["˜120k", "Q4"], "label": "Product A"},
|
| 886 |
+
16 {"object": "bar", "attributes": ["˜150k", "Q4"], "label": "Product B"},
|
| 887 |
+
17 {"object": "bar", "attributes": ["˜90k", "Q4"], "label": "Product C"},
|
| 888 |
+
18 {"title": "Quarterly Revenue"}
|
| 889 |
+
19 ]
|
| 890 |
+
20 <Q>: Which product has the highest revenue in Q4?
|
| 891 |
+
21 <A>: product b
|
| 892 |
+
22 <R>:
|
| 893 |
+
23 [Extraction] Reads Q4 bar heights: A ˜120k, B ˜150k, C ˜90k.
|
| 894 |
+
24 [Calculation] Compares values: B > A and B > C.
|
| 895 |
+
25 [Conclusion] Therefore, Product B is highest, matching the answer "
|
| 896 |
+
product b".
|
| 897 |
+
26
|
| 898 |
+
27 <Output>: medium
|
| 899 |
+
28
|
| 900 |
+
29 According to the requirements and examples above, score the input into
|
| 901 |
+
three categories. Please give me the result directly without any
|
| 902 |
+
explanation or description.
|
| 903 |
+
30
|
| 904 |
+
31 <IC>: %s
|
| 905 |
+
32 <Q>: %s
|
| 906 |
+
33 <A>: %s
|
| 907 |
+
34 <R>: %s
|
| 908 |
+
35 <Output>:
|
| 909 |
+
PromptS3: ScoringgenerationsduringGRPOforChartUnderstanding
|
| 910 |
+
S2 EXPERIMENTAL DETAILS AND EXTRA RESULTS
|
| 911 |
+
S2.1 TRAININGANDTESTSETTING
|
| 912 |
+
First,weprovidethestatisticalinformationforthetrainingandtestingphasesofourexperiments
|
| 913 |
+
intheTab. S1. Thetrainingdatasetforeachdomainconsistsofonlyafewthousandsamples. In
|
| 914 |
+
addition,Fig.S1visualizesacomparisonbetweentheground-truthresponsesproducedbytherefiner
|
| 915 |
+
andtheoriginalground-truth,showingthattherefinedversionsarenoticeablymorestructuredand
|
| 916 |
+
placegreateremphasisonintermediatevalues.
|
| 917 |
+
TableS1: Trainingandtestingsetup. DyMEempowersthinkingcapabilitiesbasedonsmalltraining
|
| 918 |
+
sets.
|
| 919 |
+
Domain Trainingset #Trainingsamples SourceofI Testset
|
| 920 |
+
c
|
| 921 |
+
MedicalVQA SLAKE-Train 4,919 BiomedGPT SLAKE-Test
|
| 922 |
+
ChartUnderstanding ChartQA-Train 4,576 DePlot ChartQA-Test
|
| 923 |
+
GeometrySolving Geo170K 6,417 Collected MathVerse
|
| 924 |
+
3
|
| 925 |
+
--- PAGE 17 ---
|
| 926 |
+
PublishedasaconferencepaperatICLR2026
|
| 927 |
+
Medical(SLAKE). Chart(ChartQA). Math(Geo170k).
|
| 928 |
+
Q:Ifangle1is35◦,whatisangle
|
| 929 |
+
2? Choices: A:55◦, B:135◦, C:
|
| 930 |
+
Q: What is the sum of making 145◦,D:155◦.
|
| 931 |
+
progress,losingground,andsame w/orefiner:
|
| 932 |
+
Q: What is the MR weighting in in2011? Sinceangle1is35degree,angle
|
| 933 |
+
thisimage? w/orefiner: 2 is 180 - 35 = 145 degree us-
|
| 934 |
+
w/orefiner: Toanswerthequestion,wefirstlo- ing a supplementary relationship.
|
| 935 |
+
Basedonsignalvariationstypical catetherelevantpartofthechart ChooseC.
|
| 936 |
+
of this image, the most plausible for 2011, then extract the values w/refiner:
|
| 937 |
+
answerisT2. forthethreeseriesandsumthem Extraction:TwoparallellinesCD
|
| 938 |
+
w/refiner: uptoobtain99. andABwithatransversal;angle1
|
| 939 |
+
Analysis: On T2, fluid is bright; w/refiner: is35degree,andangle2isonthe
|
| 940 |
+
bone/airandfastflowaredark. Extraction: Read 2011 values: sameside.
|
| 941 |
+
Extraction: In this axial slice, Making progress = 45, Losing Calculation:Same-sideinterioran-
|
| 942 |
+
intraluminalfluidisbright;vessels ground=29,Same=25. glesaresupplementary,soangle2
|
| 943 |
+
lowsignal;ribs/walldark. Calculation:45+29+25=99. =180-35=145degree.
|
| 944 |
+
Conclusion:PatternmatchesT2. Conclusion:Thesumis99. Conclusion:Answer:C.
|
| 945 |
+
FigureS1: Comparisonofground-truthresponsesbeforeandafterrefinement. Comparedto
|
| 946 |
+
theoriginalground-truth,therefinerinjectsrichervisualelementsandenforcesamorestructured
|
| 947 |
+
organization,therebyreducingthelearningburdenforSVLMs.
|
| 948 |
+
S2.2 EXTRARESULTS
|
| 949 |
+
Wealsoreportadditionalexperimentalcontent,includingthediscussionontrainingstrategiesand
|
| 950 |
+
dataorganizationformats,aswellasacomparativeanalysiswithothersimilarmethodsthatintegrate
|
| 951 |
+
SFTandRL.
|
| 952 |
+
Specifically, (1) we first demonstrate the importance of
|
| 953 |
+
constructing vision supervision, which proves essential Table S2: Two-stage training on
|
| 954 |
+
fortrainingSVLMstoproducegroundedthinkingtraces. ChartQA. Rel-corr denotes the
|
| 955 |
+
(2)Wethenexaminetheimpactofstructuredversusopen- relaxed-correctnessmetric. I indicates
|
| 956 |
+
c
|
| 957 |
+
endedoutputformatsonthinkingperformance. (3)Fur- whetheranexplicitimage-contentfield
|
| 958 |
+
thermore,tovalidateourearlierobservationthatSVLMs issupervised(✓yes;✗no).
|
| 959 |
+
arepronetoconvergingtolocaloptima,wepresentperfor-
|
| 960 |
+
manceacrossdifferenttrainingepochs,showingthatSFT
|
| 961 |
+
Model I Rel-corr
|
| 962 |
+
c
|
| 963 |
+
training saturates after only one epoch. (4) We provide
|
| 964 |
+
SmolVLM ✓ 64.32
|
| 965 |
+
adetailedcomparisonwithalternativemethodsthatinte- SmolVLM ✗ 60.09
|
| 966 |
+
grateSFTandRL.(5)Finally,weextendourevaluation LLaVA-OV-S ✓ 63.62
|
| 967 |
+
tostrongerbasemodelsandpuretextualdomains,and(6) LLaVA-OV-S ✗ 52.90
|
| 968 |
+
validatethequalityofgeneratedthinkingtracesthrough
|
| 969 |
+
humanevaluation.
|
| 970 |
+
(1)Intermediatevaluesmatter. AsshowninTableS2,wereporttheeffectofapplyingtwo-stage
|
| 971 |
+
trainingwithvisualsupervisiononSmolVLMandLLaVA-OV-S.Incorporatingvisualsupervision
|
| 972 |
+
significantlyimprovesthebestperformanceachievedduringtraining,despitecertaininstabilities,
|
| 973 |
+
thereby validating itscritical role forSVLMs. Thiseffect isfurther illustrated inFig. S1, where
|
| 974 |
+
visualsupervisioncompelsSVLMstogenerateintermediatereasoningenrichedwithvisualelements,
|
| 975 |
+
whichmakeaclearcontributiontothefinalanswer.
|
| 976 |
+
(2)StructuredthinkingalleviatesthelearningburdenofSVLMs. TableS3reportstheperfor-
|
| 977 |
+
mancegapbetweentrainingwithstructuredthinkingground-truthandwithunconstrainedground-
|
| 978 |
+
4
|
| 979 |
+
--- PAGE 18 ---
|
| 980 |
+
PublishedasaconferencepaperatICLR2026
|
| 981 |
+
1.0
|
| 982 |
+
0.8
|
| 983 |
+
0.6
|
| 984 |
+
0.4
|
| 985 |
+
0.2
|
| 986 |
+
0.0
|
| 987 |
+
0 500 1000 1500 2000 2500
|
| 988 |
+
Steps
|
| 989 |
+
ycilop-ffO
|
| 990 |
+
fo
|
| 991 |
+
noitroporP
|
| 992 |
+
DyME
|
| 993 |
+
CHORD-
|
| 994 |
+
LUFFY
|
| 995 |
+
FigureS2: Relativeoff-policyinfluenceduringtraining. Eachcurveisnormalizedtoitsinitial
|
| 996 |
+
valueforcomparability. DyMEmeasuresSFT/(SFT+RL)(rawinlightertone,Gaussian-smoothed
|
| 997 |
+
indarkertone),CHORD-µtrackstheglobalweightµ(t),andLUFFYadoptsapolicy-shapingproxy
|
| 998 |
+
E[f(π (a))]withf(x)= x . Allmethodsrevealtheshiftfromoff-policyguidancetoon-policy
|
| 999 |
+
θ x+γ
|
| 1000 |
+
optimization,albeitwithdistinctdecaydynamics.
|
| 1001 |
+
TableS3: Effectoftemplatedoutputacrossmodelsandtasks. ✓denotesfixed-templateoutput,
|
| 1002 |
+
whereas✗denotesfree-formgeneration.
|
| 1003 |
+
Model Template Chart Medical
|
| 1004 |
+
SmolVLM ✓ 60.10 59.38
|
| 1005 |
+
SmolVLM ✗ 59.24 56.13
|
| 1006 |
+
LLaVA-OV-S ✓ 52.87 74.52
|
| 1007 |
+
LLaVA-OV-S ✗ 50.86 72.64
|
| 1008 |
+
truth. Whileopen-endedexplorationisoftenbeneficialforLVLMs,thelimitedcapacityofSVLMs
|
| 1009 |
+
makesunconstrainedexplorationlesseffective,asittendstobeaimlessandincreasesthelearning
|
| 1010 |
+
burden. GiventhatSVLMsaredesignedfortask-specificratherthangeneral-purposescenarios,em-
|
| 1011 |
+
ployingtailoredthinkingtemplatesforeachtaskprovesmoresuitableandyieldsbetterperformance.
|
| 1012 |
+
Forinstance,SmolVLMachieves60.10vs.59.24onChartQAand59.38vs.56.13onMedicalVQA,
|
| 1013 |
+
withLLaVA-OV-Sexhibitingsimilargains.
|
| 1014 |
+
(3)ComparisonbetweenannealedSFTlossandDyME.
|
| 1015 |
+
AsshowninFig.S2, wecomparetherelativeSFT(off- Table S4: SVLM performance satu-
|
| 1016 |
+
policy)influenceacrosstrainingstepsforthreeapproaches: rates after a single training epoch.
|
| 1017 |
+
DyME,CHORD(Zhangetal.,2025b),andLUFFY(Yan
|
| 1018 |
+
Scoreisdomain-specific: chartdomain
|
| 1019 |
+
etal.,2025). ForDyMEandCHORD,thecurvesrepresent usesRel-corr,whilethemedicaldo-
|
| 1020 |
+
thenormalizedweightoftheSFTlossateachstep,while mainusestheaverageofaccuracyand
|
| 1021 |
+
forLUFFYthecurvereflectsthetrajectoryofSFTgradi- recallvalues.
|
| 1022 |
+
entshapingasafunctionofpredictionprobability(which
|
| 1023 |
+
generally correlates with training steps). These curves
|
| 1024 |
+
Model Domain Epoch Score
|
| 1025 |
+
highlightthedynamicnatureofDyME.Becauseoftheex-
|
| 1026 |
+
1 60.70
|
| 1027 |
+
tremelylimitedcapacityofSVLMs,theirlearningpatterns
|
| 1028 |
+
LLaVA-OV-S Chart 5 60.44
|
| 1029 |
+
canshiftsignificantlyevenbetweenadjacentsteps,leading 10 60.12
|
| 1030 |
+
torapidforgettingofpreviouslyacquiredmodes. Unlike
|
| 1031 |
+
1 60.22
|
| 1032 |
+
CHORD, which relies on a smooth annealing schedule Chart 5 63.21
|
| 1033 |
+
thatdecaysquicklyandisill-suitedtosuchsmallmodels, SmolVLM 10 62.22
|
| 1034 |
+
DyMEassignsweightsdirectlybasedonmodeloutputs. 1 71.73
|
| 1035 |
+
Medical 5 71.80
|
| 1036 |
+
Thisproducesahighlydynamicandirregulardecay,bet- 10 72.05
|
| 1037 |
+
ter accommodating the instability of SVLMs. LUFFY
|
| 1038 |
+
adopts a shaping function f(x) = x (γ=0.1), which
|
| 1039 |
+
x+γ
|
| 1040 |
+
alsoinducesadynamicdecaywithprobabilitybutremainsheuristicandmaynotalignwellwith
|
| 1041 |
+
thelocal-optimumtendencyofSVLMs. Overall,DyMEisexplicitlytailoredforSVLMs,whereas
|
| 1042 |
+
5
|
| 1043 |
+
--- PAGE 19 ---
|
| 1044 |
+
PublishedasaconferencepaperatICLR2026
|
| 1045 |
+
TableS5: Detailedlearningtrajectoriesdemonstratingrigoroustuning. Wereporttheperfor-
|
| 1046 |
+
manceacrossmultiplesettingstoshowtheirfulllearningtrajectories. Two-stagebaselinesinclude
|
| 1047 |
+
variantswithandwithoutKLpenaltiestoensureoptimalperformanceiscaptured.
|
| 1048 |
+
DataQuality Method Performanceacrossepochs(1,3,5,10) Bestperf.
|
| 1049 |
+
DyME(ours,pure) Reportfinalscoredirectly 61.9
|
| 1050 |
+
SFT 43.1 → 47.9 → 50.0 → 50.5 50.5
|
| 1051 |
+
Low
|
| 1052 |
+
Two-stage 57.6 → 52.7 → 50.8 → 50.7 57.6
|
| 1053 |
+
Two-stage(w/KL) 54.2 → 55.4 → 52.6 → 54.2 55.4
|
| 1054 |
+
DyME(ours,pure) Reportfinalscoredirectly 64.9
|
| 1055 |
+
SFT 53.6 → 56.5 → 57.8 → 56.4 57.8
|
| 1056 |
+
Medium
|
| 1057 |
+
Two-stage 59.9 → 52.8 → 53.0 → 53.1 59.9
|
| 1058 |
+
Two-stage(w/KL) 59.0 → 60.6 → 60.6 → 60.8 60.8
|
| 1059 |
+
DyME(ours,pure) Reportfinalscoredirectly 68.5
|
| 1060 |
+
SFT 58.2 → 59.1 → 61.0 → 61.6 61.6
|
| 1061 |
+
High
|
| 1062 |
+
Two-stage 51.6 → 54.0 → 54.5 → 54.4 54.5
|
| 1063 |
+
Two-stage(w/KL) 61.7 → 60.9 → 62.7 → 61.8 62.7
|
| 1064 |
+
CHORDandLUFFYmaybemoreappropriateforstrongerbasemodels,reflectingcomplementary
|
| 1065 |
+
strengths.
|
| 1066 |
+
(4)SVLMsconvergerapidly.TableS4showsthatSVLMsconvergeextremelyquickly:performance
|
| 1067 |
+
afteronlyoneepochiscomparableto,orevenexceeds,thataftertenepochs(e.g.,LLaVA-OV-S
|
| 1068 |
+
achieves 60.70 vs. 60.12 on the Chart domain). This indicates that the very limited capacity of
|
| 1069 |
+
SVLMsmakesthempronetooverfittingtolocaloptima. Italsosubstantiatesourearlierclaimthat
|
| 1070 |
+
suchrapidconvergenceleavesonlyanarrowwindowforbalancingSFTandRL,makingitdifficultto
|
| 1071 |
+
achievethetrade-offthroughempiricalhyperparametertuning. Consequently,staticfusionmethods
|
| 1072 |
+
areunsuitableforSVLMs.
|
| 1073 |
+
To ensure a rigorous comparison, we further report the full learning trajectories of baselines in
|
| 1074 |
+
Table S5. We evaluated the Two-stage baseline (with and without KL penalty) and SFT across
|
| 1075 |
+
multipleepochs(1,3,5,10)tocapturetheirpeakperformance. Theresultsconfirmthatevenwith
|
| 1076 |
+
optimalstopping,thebaselinesconsistentlyunderperformDyME,whichachievessuperiorresultsina
|
| 1077 |
+
singletrainingrunwithouttheneedforepochselection.
|
| 1078 |
+
(5)Generalityacrosscomplexreasoningandpuretext. TodemonstratethescalabilityofDyME,
|
| 1079 |
+
we applied it to two new domains without modifying the core algorithm: Physical Reasoning
|
| 1080 |
+
(A-OKVQA)andPureTextReasoning(GSM8K).
|
| 1081 |
+
• PhysicalReasoning(A-OKVQA):Weaddressedthechallengeofopen-endedvisualreasoning
|
| 1082 |
+
bytestingonA-OKVQA.WeusedQwen2.5-VL-7BtoautomaticallygenerateVisualFactsusing
|
| 1083 |
+
the prompt defined in §S1 (e.g., “man, wearing a light blue and white shirt...”). As shown in
|
| 1084 |
+
TableS6,DyMEachievedamassivegainof+18.8%(54.2%→73.0%),provingthatthemethod
|
| 1085 |
+
scaleseffortlesslytotasksrequiringworldknowledgeandcommonsense.
|
| 1086 |
+
• PureTextReasoning(GSM8K):Inpuretextdomains,the“VisualFact”extractionstepisnaturally
|
| 1087 |
+
skipped. OntheGSM8Kmathbenchmark,DyMEimprovedQwen2.5-0.5Bfrom49.5%to55.3%,
|
| 1088 |
+
demonstratingthattheparadigmgeneralizesevenwhen“vision”isabsent.
|
| 1089 |
+
Theseresults,combinedwiththeChartQAimprovementsonthestrongerQwen2.5-VL-7Bmodel,
|
| 1090 |
+
confirm that DyME is not limited by the extraction step. By leveraging off-the-shelf LVLMs to
|
| 1091 |
+
automatevisualfactgeneration,theframeworkisimmediatelyapplicabletodiversevisualandtextual
|
| 1092 |
+
domains.
|
| 1093 |
+
Limitations on Abstract Visuals. We acknowledge that the VS module may face challenges
|
| 1094 |
+
in scenarios where “Visual Facts” are intrinsically difficult to define or extract, such as memes
|
| 1095 |
+
(relying on irony or cultural context) or highly abstract non-commonsense reasoning. However,
|
| 1096 |
+
6
|
| 1097 |
+
--- PAGE 20 ---
|
| 1098 |
+
PublishedasaconferencepaperatICLR2026
|
| 1099 |
+
Table S6: Generality of DyME across New Domains. We demonstrate performance gains on
|
| 1100 |
+
ComplexScenes(A-OKVQA),PureText(GSM8K),andwithstrongerbasemodels(Qwen2.5-VL-
|
| 1101 |
+
7B).BaselinesfortextusestandardGRPO.
|
| 1102 |
+
Domain Task BaseModel Method Baseline(%) DyME(%)
|
| 1103 |
+
WorldKnowledge A-OKVQA LLaVA-OV-S Two-stage 54.2 73.0(+18.8)
|
| 1104 |
+
PureText GSM8K Qwen2.5-0.5B GRPO 49.5 55.3(+5.8)
|
| 1105 |
+
NewLVLM ChartQA Qwen2.5-VL-7B SFT 87.3 89.6(+2.3)
|
| 1106 |
+
ourprimaryobjectiveistoempowerSVLMsforpractical,real-worldproductiontasks(e.g.,chart
|
| 1107 |
+
processing,medicaldiagnostics,geometricsolving). Inthesestructuredandsemi-structureddomains
|
| 1108 |
+
whereSVLMsaremostcommonlydeployed,VisualFactsarewell-definedandDyMEproveshighly
|
| 1109 |
+
effective.
|
| 1110 |
+
(6) Human evaluation of CoT quality. Automatic metrics like relaxed accuracy do not fully
|
| 1111 |
+
reflect the quality of the reasoning process. To verify whether DyME generates genuinely better
|
| 1112 |
+
thinkingtraces,weconductedahumanevaluationon100randomlysampledinstancesfromChartQA.
|
| 1113 |
+
AnnotatorsjudgedthevalidityofthegeneratedCoTbasedonitslogicalcoherenceandgrounding.
|
| 1114 |
+
As shown in Table S7, DyME produces traces that are slightly more concise (shorter length) but
|
| 1115 |
+
significantlymorevalid(validityrate∼70%)comparedtotheTwo-stagebaseline(∼30-40%). This
|
| 1116 |
+
confirms that DyME effectively mitigates the generation of “pseudo thinking traces” that plague
|
| 1117 |
+
standardSFT/Two-stagetraining.
|
| 1118 |
+
TableS7: HumanevaluationofCoTqualityonChartQA.
|
| 1119 |
+
BaseModel Method Avg. CoTLength HumanEval(Valid%)
|
| 1120 |
+
Two-stage ∼76.3Words 31%
|
| 1121 |
+
LLaVA-OV-S
|
| 1122 |
+
DyME ∼69.7Words 68%
|
| 1123 |
+
Two-stage ∼84.5Words 40%
|
| 1124 |
+
SmolVLM
|
| 1125 |
+
DyME ∼75.4Words 72%
|
| 1126 |
+
S3 SHOWCASES
|
| 1127 |
+
Before presenting the model outputs, we first illustrate the data quality definitions used in our
|
| 1128 |
+
Algorithmic Validation (Section 4.1 of the main paper). Table S8 showcases examples of Low
|
| 1129 |
+
(Undesigned),Medium(Standard),andHigh(Premium)qualityChain-of-Thoughtsupervisionfor
|
| 1130 |
+
thesamequestion. ThisvisualizesthesignificantgapinstructureanddetailthatDyMEmustbridge
|
| 1131 |
+
whentrainedonnon-premiumdata. Furthermore,TableS9illustratesthecomprehensiveformatof
|
| 1132 |
+
oursupervisiondata,encompassingtheinputimage,theassociatedquestion,theextractedvisual
|
| 1133 |
+
facts,andtheground-truthresponse.
|
| 1134 |
+
WepresentdialogueinstancesofSmolVLM,LLaVA-OV-S,andInternVL2-S,whichweretrained
|
| 1135 |
+
withDyMEinourexperiments,ontasksinthedomainsofmedicalVQA,chartunderstanding,and
|
| 1136 |
+
geometry. AsshowninFigS3toFigS5,allmodelstrainedwithDyMEdemonstratetheabilityto
|
| 1137 |
+
generateeffectivethinkingtraceswithaccurateintermediatevalues(ingreen),whichplayacrucial
|
| 1138 |
+
roleinreachingthefinalcorrectanswer.
|
| 1139 |
+
Thethinkingprocessesofthesemodelsarerelativelyfixedandtemplate-based. Thisisaresultof
|
| 1140 |
+
our vision supervision module taking effect, as SVLMs should not be overly broad and general;
|
| 1141 |
+
otherwise,theirexplorationcaneasilydivergeinsteadofconverging.
|
| 1142 |
+
7
|
| 1143 |
+
--- PAGE 21 ---
|
| 1144 |
+
PublishedasaconferencepaperatICLR2026
|
| 1145 |
+
TableS8: Showcaseofchain-of-thought(CoT)supervisionwithdifferentqualityforChartQA.
|
| 1146 |
+
Type Content
|
| 1147 |
+
Question Whendoestheunfavorableviewreachthepeak?
|
| 1148 |
+
High-quality <SUMMARY>IinspectthelinegraphshowingveryunfavorableviewsofPakistaninIndiaover
|
| 1149 |
+
CoT time,andidentifytheyearwherethepercentagepeaksbylocatingthemaximumpointonthe
|
| 1150 |
+
curve.</SUMMARY>
|
| 1151 |
+
<CAPTION>Linegraphtitled“VeryunfavorableviewsofPakistanincreasinginIndia”.The
|
| 1152 |
+
y-axisshowstheshareofrespondentswithaveryunfavorableview(0–100%). Thex-axis
|
| 1153 |
+
showsyears2013–2017.Markedvalues:54%(2013),49%(2014),51%(2015),55%(2016),
|
| 1154 |
+
and64%(2017).Theseriesexhibitsanoverallupwardtrendandpeaksin2017.</CAPTION>
|
| 1155 |
+
<REASONING>Thepeakcorrespondstothehighestpercentageonthey-axis. Amongthe
|
| 1156 |
+
annotatedyears,2017hasthelargestvalue(64%),higherthan2013–2016.Therefore,thepeak
|
| 1157 |
+
occursin2017.</REASONING>
|
| 1158 |
+
<CONCLUSION>2017</CONCLUSION>
|
| 1159 |
+
Medium-quality Goal:Determinetheyearwhentheunfavorableviewreachesitsmaximum.
|
| 1160 |
+
CoT Observation:Thevaluesforeachyearare:2013:0,2014:0,2015:0,2016:55,and2017:64.
|
| 1161 |
+
Reasoning:Comparingthesevalues,thelargestvalueis64,whichappearsin2017.
|
| 1162 |
+
Conclusion:Theunfavorableviewreachesitspeakin2017.
|
| 1163 |
+
Low-quality I’mtryingtofigureouttheyearwhentheunfavorableviewreachesitshighestpoint.Thevalues
|
| 1164 |
+
CoT staylowuntil2016,wheretheyriseto55,andthenclimbfurtherto64in2017.Thismakes
|
| 1165 |
+
2017theyearwhentheunfavorableviewreachesitsmaximum.
|
| 1166 |
+
TableS9: ShowcaseofacompleteCoTtrainingexample.
|
| 1167 |
+
Type Content
|
| 1168 |
+
Question Whatisthemanbythebagsawaiting?
|
| 1169 |
+
VisualFact {”description”: ”Amanstandsinthemiddleofaresidentialstreet,facingawayfromthe
|
| 1170 |
+
camera. Heholdsaredbagandpullsawheeledblacksuitcase,withanotherblacksuitcase
|
| 1171 |
+
placednearby.Houses,parkedcars,andtreeslinethebackground,suggestingheiswaiting
|
| 1172 |
+
tocrossorbepickedup.”,”objects”:[ ”name”:”man”,”attributes”:[”lightblueandwhite
|
| 1173 |
+
shirt”,”bluejeans”,”carryingaredbag”,”pullingawheeledsuitcase”],”position”:”center”,
|
| 1174 |
+
”name”:”redbag”,”attributes”:[”heldbytheman”],”position”:”leftoftheman”, ”name”:
|
| 1175 |
+
”blacksuitcase”,”attributes”:[”wheeled”,”beingpulled”],”position”:”neartheman’sfeet”
|
| 1176 |
+
, ”name”:”blacksuitcase”,”attributes”:[”lyingontheground”],”position”:”neartheman”
|
| 1177 |
+
, ”name”:”street”,”attributes”:[”asphalt”,”urbanresidential”],”position”:”foreground”,
|
| 1178 |
+
”name”:”houses”,”attributes”:[”residentialhomes”],”position”:”backgroundleft”, ”name”:
|
| 1179 |
+
”parkedcars”,”attributes”:[”redSUV”,”othervehicles”],”position”:”backgroundcenterand
|
| 1180 |
+
left”, ”name”:”trees”,”attributes”:[”greenfoliage”],”position”:”backgroundright” ]}
|
| 1181 |
+
GTResponse Atrainwouldnotappearonaresidentialstreet,andtheman’sluggagesuggestsheiswaiting
|
| 1182 |
+
fortransportationratherthanadelivery.Thenearbyskateboarderisnotinteractingwithhim,
|
| 1183 |
+
leavingacabasthemostplausibleoption.
|
| 1184 |
+
8
|
| 1185 |
+
--- PAGE 22 ---
|
| 1186 |
+
PublishedasaconferencepaperatICLR2026
|
| 1187 |
+
(a)ShowcaseofMedicalVQA
|
| 1188 |
+
(b)ShowcaseofChartUnderstanding
|
| 1189 |
+
(c)ShowcaseofGeometrySolving
|
| 1190 |
+
FigureS3: ShowcasesofSmolVLM.TheSVLMoriginallyproduceshallucinatedanswers(red),
|
| 1191 |
+
whiletheDyME-trainedmodelgeneratesstructuredthinkingtraces(green)thatincorporategrounded
|
| 1192 |
+
values,effectivelyimprovingtheperformance.
|
| 1193 |
+
9
|
| 1194 |
+
--- PAGE 23 ---
|
| 1195 |
+
PublishedasaconferencepaperatICLR2026
|
| 1196 |
+
(a)ShowcaseofMedicalVQA
|
| 1197 |
+
(b)ShowcaseofChartUnderstanding
|
| 1198 |
+
(c)ShowcaseofGeometrySolving
|
| 1199 |
+
FigureS4: ShowcasesofInternVL2-S.TheSVLMoriginallyproduceshallucinatedanswers(red),
|
| 1200 |
+
whiletheDyME-trainedmodelgeneratesstructuredthinkingtraces(green)thatincorporategrounded
|
| 1201 |
+
values,effectivelyimprovingtheperformance.
|
| 1202 |
+
10
|
| 1203 |
+
--- PAGE 24 ---
|
| 1204 |
+
PublishedasaconferencepaperatICLR2026
|
| 1205 |
+
(a)ShowcaseofChartUnderstanding
|
| 1206 |
+
(b)ShowcaseofMedicalVQA
|
| 1207 |
+
(c)ShowcaseofGeometrySolving
|
| 1208 |
+
FigureS5: ShowcasesofLLaVA-OV-S.TheSVLMoriginallyproduceshallucinatedanswers(red),
|
| 1209 |
+
whiletheDyME-trainedmodelgeneratesstructuredthinkingtraces(green)thatincorporategrounded
|
| 1210 |
+
values,effectivelyimprovingtheperformance.
|
| 1211 |
+
11
|
requirements.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==1.11.0
|
| 2 |
+
pycocotools
|
| 3 |
+
matplotlib
|
| 4 |
+
datasets
|
| 5 |
+
peft==0.17.0
|
| 6 |
+
qwen_vl_utils
|
| 7 |
+
wandb
|
| 8 |
+
transformers==4.57.1
|
| 9 |
+
deepspeed==0.18.1
|
| 10 |
+
pycocotools
|
| 11 |
+
trl==0.23.1
|
| 12 |
+
flash-attn==2.7.4.post1
|
| 13 |
+
scikit-image
|
| 14 |
+
openai
|
| 15 |
+
spacy
|
| 16 |
+
autoawq==0.2.9
|
reward_utils/__pycache__/format_checks.cpython-310.pyc
ADDED
|
Binary file (2.82 kB). View file
|
|
|
reward_utils/compute_rewards.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import concurrent.futures
|
| 2 |
+
from typing import List, Dict, Any
|
| 3 |
+
from .checker import RewardCalculator
|
| 4 |
+
|
| 5 |
+
def split_initial_context(text: str):
|
| 6 |
+
text = text.lower()
|
| 7 |
+
flag = 'answer:'
|
| 8 |
+
if flag in text:
|
| 9 |
+
ans = text.split(flag)[-1].strip()
|
| 10 |
+
context = text.split(flag)[0].strip()
|
| 11 |
+
ans = ans.strip('.')
|
| 12 |
+
else:
|
| 13 |
+
context = text
|
| 14 |
+
ans = text
|
| 15 |
+
return context, ans
|
| 16 |
+
|
| 17 |
+
def calculate_rewards_in_parallel(
|
| 18 |
+
checker: RewardCalculator,
|
| 19 |
+
batch_data: Dict[str, Any],
|
| 20 |
+
gpu_id: int,
|
| 21 |
+
num_threads: int = 8,
|
| 22 |
+
task='chart'):
|
| 23 |
+
"""
|
| 24 |
+
Calculates accuracy rewards for a batch of data in parallel using a thread pool.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
batch_data: A dictionary containing lists of data, including 'response',
|
| 28 |
+
'prompt', 'image', 'answer', and an optional 'tp' (answer_type).
|
| 29 |
+
gpu_id: The ID of the GPU to be used for processing.
|
| 30 |
+
num_threads: The number of parallel threads to use.
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
A list of calculated reward scores for each item in the batch.
|
| 34 |
+
"""
|
| 35 |
+
# Extract lists of data from the input dictionary
|
| 36 |
+
responses = batch_data['response']
|
| 37 |
+
predictions = []
|
| 38 |
+
for r in responses:
|
| 39 |
+
c, p = split_initial_context(r)
|
| 40 |
+
predictions.append(p)
|
| 41 |
+
prompts = batch_data['prompt']
|
| 42 |
+
# questions = batch_data['question']
|
| 43 |
+
answers = batch_data['answer']
|
| 44 |
+
hints = batch_data['hints'] if 'hints' in batch_data else [""] * len(responses)
|
| 45 |
+
num_samples = len(responses)
|
| 46 |
+
|
| 47 |
+
# Safely get 'answer_types', providing a list of Nones as a default
|
| 48 |
+
# This fixes a bug in the original code.
|
| 49 |
+
answer_types = batch_data.get('tp', [None] * num_samples)
|
| 50 |
+
|
| 51 |
+
# Prepare the arguments for each task by zipping the data together.
|
| 52 |
+
# This creates an iterator of tuples, where each tuple contains all args for one call.
|
| 53 |
+
in_answers = answers
|
| 54 |
+
if 'world' in task:
|
| 55 |
+
in_answers = batch_data['direct_answers']
|
| 56 |
+
task_answer_args = zip(
|
| 57 |
+
predictions,
|
| 58 |
+
in_answers,
|
| 59 |
+
[task] * num_samples,
|
| 60 |
+
# [gpu_id] * num_samples,
|
| 61 |
+
# answer_types,
|
| 62 |
+
# hints
|
| 63 |
+
)
|
| 64 |
+
task_thinking_args = zip(
|
| 65 |
+
responses,
|
| 66 |
+
prompts,
|
| 67 |
+
answers,
|
| 68 |
+
hints,
|
| 69 |
+
[task] * num_samples
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# Use a ThreadPoolExecutor to process the data in parallel.
|
| 73 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
|
| 74 |
+
# Instead of a separate function, use a lambda to unpack the arguments.
|
| 75 |
+
# The '*' operator unpacks each tuple from task_args into positional arguments
|
| 76 |
+
# for the get_acc_reward function.
|
| 77 |
+
|
| 78 |
+
format_rewards = list(executor.map(lambda r: checker.get_format_reward(r, task=task), responses))
|
| 79 |
+
answer_rewards = list(executor.map(lambda args: checker.get_answer_reward(*args), task_answer_args))
|
| 80 |
+
thinking_rewards = list(executor.map(
|
| 81 |
+
lambda args: checker.get_thinking_reward_prompt(*args), task_thinking_args))
|
| 82 |
+
|
| 83 |
+
rewards = [0 if f == 0 else f + a + t for f, a, t in zip(format_rewards, answer_rewards, thinking_rewards)]
|
| 84 |
+
|
| 85 |
+
return rewards, format_rewards, answer_rewards, thinking_rewards
|
| 86 |
+
|
| 87 |
+
def refine_context_in_parallel(
|
| 88 |
+
refiner,
|
| 89 |
+
questions: List[str],
|
| 90 |
+
hints: List[str],
|
| 91 |
+
reference_answers: List[str],
|
| 92 |
+
task,
|
| 93 |
+
gpu_id: int,
|
| 94 |
+
num_threads: int = 8):
|
| 95 |
+
"""
|
| 96 |
+
Refines contexts for a batch of data in parallel using a thread pool.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
questions: A list of questions.
|
| 100 |
+
hints: A list of hints corresponding to each question.
|
| 101 |
+
reference_answers: A list of reference answers.
|
| 102 |
+
tasks: A list of task types corresponding to each question.
|
| 103 |
+
gpu_id: The ID of the GPU to be used for processing.
|
| 104 |
+
num_threads: The number of parallel threads to use.
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
A list of refined contexts for each question.
|
| 108 |
+
"""
|
| 109 |
+
num_samples = len(questions)
|
| 110 |
+
tasks = [task] * num_samples
|
| 111 |
+
# Prepare the arguments for each task by zipping the data together.
|
| 112 |
+
task_args = zip(
|
| 113 |
+
questions,
|
| 114 |
+
hints,
|
| 115 |
+
reference_answers,
|
| 116 |
+
tasks,
|
| 117 |
+
[gpu_id] * num_samples
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# Use a ThreadPoolExecutor to process the data in parallel.
|
| 121 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
|
| 122 |
+
refined_contexts = list(executor.map(
|
| 123 |
+
lambda args: refiner.refine_hint(*args), task_args
|
| 124 |
+
))
|
| 125 |
+
|
| 126 |
+
return refined_contexts
|
reward_utils/refiner.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
from client_utils.openai_api import OpenAIClient
|
| 5 |
+
from data_utils.chart.evaluator import eval_one_chart
|
| 6 |
+
from data_utils.commom_util import prompt_ic
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import time
|
| 10 |
+
from filelock import FileLock
|
| 11 |
+
TEMPLATE_FILE = "best_template.txt"
|
| 12 |
+
LOCK_FILE = "best_template.txt.lock"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
TEMPLATE_REFRESH_INTERVAL = 60 # Interval (in seconds) to refresh template from file
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ContextRefiner:
|
| 19 |
+
"""
|
| 20 |
+
A class to refine hints / reasoning with an external LLM.
|
| 21 |
+
Encapsulates logic for template management and refinement calls.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, RL_CONFIG, CLIENT_CONFIG, gpu_id=0):
|
| 25 |
+
"""
|
| 26 |
+
Initializes the ContextRefiner.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
RL_CONFIG: RL-related configuration dict.
|
| 30 |
+
CLIENT_CONFIG: LLM client configuration dict.
|
| 31 |
+
gpu_id: process / GPU id used to select API server.
|
| 32 |
+
"""
|
| 33 |
+
self.refine_templetes = ["""Goal: [State the user's objective, e.g., Find the year with the highest sales]
|
| 34 |
+
Observation: [List key data points from the chart, e.g., 2020: 150, 2021: 200, 2022: 180]
|
| 35 |
+
Reasoning: [State the logical step, e.g., Compare the values. 200 is the maximum.]
|
| 36 |
+
Conclusion: [Draw the conclusion, e.g., The year with the highest sales was 2021.]
|
| 37 |
+
"""]
|
| 38 |
+
self.template_lock = FileLock(LOCK_FILE)
|
| 39 |
+
# Set to 0 so that the first call will force a refresh from file
|
| 40 |
+
self.last_template_check_time = 0
|
| 41 |
+
if CLIENT_CONFIG['client_type'] == 'openai':
|
| 42 |
+
if CLIENT_CONFIG['init_port'] is not None:
|
| 43 |
+
num_server = int(CLIENT_CONFIG['num_server'])
|
| 44 |
+
server_id = gpu_id % num_server
|
| 45 |
+
CLIENT_CONFIG['api_base'] = CLIENT_CONFIG['api_base'] % str(CLIENT_CONFIG['init_port'] + server_id)
|
| 46 |
+
self.client = OpenAIClient(config=CLIENT_CONFIG)
|
| 47 |
+
else:
|
| 48 |
+
raise ValueError(f"Client type '{CLIENT_CONFIG['client_type']}' not supported.")
|
| 49 |
+
|
| 50 |
+
def _check_and_update_template(self):
|
| 51 |
+
"""
|
| 52 |
+
(Private method) Check whether we need to refresh the template from file.
|
| 53 |
+
This operation is process-safe.
|
| 54 |
+
"""
|
| 55 |
+
current_time = time.time()
|
| 56 |
+
|
| 57 |
+
# 1. Check whether the refresh interval has passed
|
| 58 |
+
if (current_time - self.last_template_check_time) < TEMPLATE_REFRESH_INTERVAL:
|
| 59 |
+
return # Not yet time, keep using cached template
|
| 60 |
+
|
| 61 |
+
# 2. Try to acquire the lock and read (short timeout since reading should be fast)
|
| 62 |
+
try:
|
| 63 |
+
# print(f"[Process {os.getpid()}] Checking for template update...") # Uncomment for debugging
|
| 64 |
+
with self.template_lock.acquire(timeout=5):
|
| 65 |
+
|
| 66 |
+
# --- Lock acquired, safe to read ---
|
| 67 |
+
if not os.path.exists(TEMPLATE_FILE):
|
| 68 |
+
# File does not exist, keep default template
|
| 69 |
+
self.last_template_check_time = current_time
|
| 70 |
+
return
|
| 71 |
+
|
| 72 |
+
with open(TEMPLATE_FILE, "r", encoding="utf-8") as f:
|
| 73 |
+
new_template = f.read().strip()
|
| 74 |
+
|
| 75 |
+
# If file content is valid and different, update in-memory template
|
| 76 |
+
if new_template and new_template != self.refine_templetes[0]:
|
| 77 |
+
self.refine_templetes = [new_template]
|
| 78 |
+
print(f"[Process {os.getpid()}] Refiner template updated from file.")
|
| 79 |
+
|
| 80 |
+
# Regardless of success, update last check time to avoid frequent retries
|
| 81 |
+
self.last_template_check_time = current_time
|
| 82 |
+
|
| 83 |
+
except TimeoutError:
|
| 84 |
+
# Failed to acquire lock (another process is likely writing)
|
| 85 |
+
# Do not block; skip and try again next time
|
| 86 |
+
print(f"[Process {os.getpid()}] Failed to acquire lock for template read, using cached version.")
|
| 87 |
+
# Update time to avoid immediate retry
|
| 88 |
+
self.last_template_check_time = current_time
|
| 89 |
+
|
| 90 |
+
except Exception as e:
|
| 91 |
+
print(f"[Process {os.getpid()}] Error reading template file: {e}")
|
| 92 |
+
self.last_template_check_time = current_time
|
| 93 |
+
|
| 94 |
+
def refine_hint(self, question, hint: str, reference_answer: str, task: str, gpu_id=None):
|
| 95 |
+
if hint == "":
|
| 96 |
+
return hint
|
| 97 |
+
|
| 98 |
+
self._check_and_update_template()
|
| 99 |
+
system_prompt = None
|
| 100 |
+
if 'medical' in task:
|
| 101 |
+
system_prompt = 'You are a seasoned professional in the field of medical image analysis, demonstrating exceptional expertise and insight into complex medical imaging data. Your output should be only judgement, without any additional text or explanation.'
|
| 102 |
+
elif 'math' in task:
|
| 103 |
+
system_prompt = 'You are a seasoned professional in the field of mathematics, demonstrating exceptional expertise and insight into complex mathematical problems. Your output should be only judgement, without any additional text or explanation.'
|
| 104 |
+
elif 'chart' in task:
|
| 105 |
+
system_prompt = 'You are a seasoned professional in the field of chart analysis, demonstrating exceptional expertise and insight into complex chart data. Your output should be only judgement, without any additional text or explanation.'
|
| 106 |
+
elif 'world' in task:
|
| 107 |
+
system_prompt = 'You are a seasoned professional in the field of world knowledge and image analysis, demonstrating exceptional expertise and insight into complex real-world scenarios. Your output should be only judgement, without any additional text or explanation.'
|
| 108 |
+
else:
|
| 109 |
+
Exception('Unknown expert task')
|
| 110 |
+
|
| 111 |
+
try:
|
| 112 |
+
in_context_example = self.client.get_completion(
|
| 113 |
+
prompt_ic % hint,
|
| 114 |
+
system_prompt=system_prompt,
|
| 115 |
+
max_tokens=5000
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
if 'chart' in task or 'world' in task:
|
| 119 |
+
if 'chart' in task:
|
| 120 |
+
from data_utils.chart.prompts import prompt_thinking_reward, prompt_refine
|
| 121 |
+
else:
|
| 122 |
+
from data_utils.aokvqa.prompts import prompt_thinking_reward, prompt_refine
|
| 123 |
+
# Construct the final prompt for the evaluator model.
|
| 124 |
+
evaluation_prompt = prompt_refine % (
|
| 125 |
+
in_context_example,
|
| 126 |
+
question,
|
| 127 |
+
reference_answer,
|
| 128 |
+
self.refine_templetes[0]
|
| 129 |
+
)
|
| 130 |
+
output = self.client.get_completion(
|
| 131 |
+
evaluation_prompt,
|
| 132 |
+
system_prompt=system_prompt,
|
| 133 |
+
max_tokens=1000
|
| 134 |
+
)
|
| 135 |
+
return output
|
| 136 |
+
else:
|
| 137 |
+
raise ValueError(f"Task '{task}' not supported for thinking reward.")
|
| 138 |
+
except Exception as e:
|
| 139 |
+
print(f"An error occurred during thinking reward prompt generation: {e}")
|
| 140 |
+
return None
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class ContextRefinerLocal:
|
| 144 |
+
"""
|
| 145 |
+
A local (non-LLM) refiner that simply returns the original hint.
|
| 146 |
+
Used when remote refinement is disabled or not desired.
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
def __init__(self, RL_CONFIG, CLIENT_CONFIG, gpu_id=0):
|
| 150 |
+
"""
|
| 151 |
+
Initializes the local ContextRefiner.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
RL_CONFIG: RL-related configuration dict.
|
| 155 |
+
CLIENT_CONFIG: client configuration dict (unused here).
|
| 156 |
+
gpu_id: process / GPU id (unused here).
|
| 157 |
+
"""
|
| 158 |
+
# Do nothing; local refiner is a no-op.
|
| 159 |
+
pass
|
| 160 |
+
|
| 161 |
+
def refine_hint(self, question, hint: str, reference_answer: str, task: str, gpu_id=None):
|
| 162 |
+
return hint
|
tests/test_data_health_probe.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for batch data health diagnostics."""
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from opsd_utils.diagnostics import (
|
| 5 |
+
_detect_char_repeat,
|
| 6 |
+
summarize_batch_data_health,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_detect_char_repeat_cjk():
|
| 11 |
+
assert _detect_char_repeat("Goal: " + "其" * 10)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def test_summarize_batch_data_health_empty_vf():
|
| 15 |
+
samples = [
|
| 16 |
+
{"prompt": "q1", "visual_fact_hint": ""},
|
| 17 |
+
{"prompt": "q2", "visual_fact_hint": "bar value 3"},
|
| 18 |
+
]
|
| 19 |
+
stats = summarize_batch_data_health(samples)
|
| 20 |
+
assert stats["visual_fact_empty_rate"] == 0.5
|
| 21 |
+
assert stats["batch_size"] == 2
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def test_summarize_batch_data_health_pixel_nan():
|
| 25 |
+
samples = [{"prompt": "q", "visual_fact_hint": "x"}]
|
| 26 |
+
pixel = torch.tensor([float("nan"), 1.0, 2.0])
|
| 27 |
+
stats = summarize_batch_data_health(samples, pixel_values=pixel)
|
| 28 |
+
assert stats["pixel_has_nan"] is True
|
tests/test_degeneration_probe.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Regression tests for completion degeneration heuristics."""
|
| 2 |
+
from unittest.mock import patch
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from opsd_utils import debug_log as opsd_debug
|
| 7 |
+
from opsd_utils.diagnostics import (
|
| 8 |
+
_detect_char_repeat,
|
| 9 |
+
_detect_degeneration,
|
| 10 |
+
_detect_repeat_loop,
|
| 11 |
+
_detect_single_token_repeat,
|
| 12 |
+
_max_same_token_run,
|
| 13 |
+
is_degenerate_completion,
|
| 14 |
+
log_generate_probe,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class _FakeTokenizer:
|
| 19 |
+
eos_token_id = 151645
|
| 20 |
+
pad_token_id = 151643
|
| 21 |
+
bos_token_id = None
|
| 22 |
+
|
| 23 |
+
def decode(self, ids, skip_special_tokens=False):
|
| 24 |
+
if isinstance(ids, torch.Tensor):
|
| 25 |
+
ids = ids.tolist()
|
| 26 |
+
return " ".join(str(i) for i in ids)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def test_single_token_repeat_detects_cjk_loop():
|
| 30 |
+
ids = [39992, 25, 7379] + [41146] * 40
|
| 31 |
+
assert _detect_single_token_repeat(ids)
|
| 32 |
+
assert _max_same_token_run(ids) == (40, 41146)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_ngram_repeat_not_limited_to_first_eight_tokens():
|
| 36 |
+
prefix = list(range(20))
|
| 37 |
+
gram = [9, 8, 7]
|
| 38 |
+
ids = prefix + gram * 5
|
| 39 |
+
assert _detect_repeat_loop(ids)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def test_char_repeat_detects_qiqiqi():
|
| 43 |
+
assert _detect_char_repeat("其其其其其其")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def test_is_degenerate_completion_detects_repeat():
|
| 47 |
+
ids = [39992, 25] + [41146] * 20
|
| 48 |
+
assert is_degenerate_completion(ids, "Goal: x\n" + "其" * 40)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def test_short_numeric_answer_not_degenerate_without_answer_flag():
|
| 52 |
+
ids = [198, 17, 15, 18, 15] # \n2030
|
| 53 |
+
assert not is_degenerate_completion(ids, "\n2030", require_answer_flag=False)
|
| 54 |
+
assert is_degenerate_completion(ids, "\n2030", require_answer_flag=True)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def test_degeneration_flags_missing_answer():
|
| 58 |
+
ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
| 59 |
+
text = "Goal: test\nObservation: x\nReasoning: y\nConclusion: z"
|
| 60 |
+
is_deg, reasons = _detect_degeneration(ids, text, answer_flag="Answer:")
|
| 61 |
+
assert is_deg
|
| 62 |
+
assert any(r.startswith("ANSWER_FLAG_COUNT") for r in reasons)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def test_log_generate_probe_does_not_shadow_tokenizer_across_samples():
|
| 66 |
+
"""sample[0] single-token repeat must not break decode for sample[1]."""
|
| 67 |
+
repeat_tail = [24] * 12
|
| 68 |
+
row0 = [39992, 25, 7379] + repeat_tail + [0] * (200 - 3 - len(repeat_tail))
|
| 69 |
+
row1 = [39992, 25, 7379, 100, 101, 102] + [0] * 194
|
| 70 |
+
completion_ids = torch.tensor([row0, row1], dtype=torch.long)
|
| 71 |
+
completion_mask = torch.tensor(
|
| 72 |
+
[[1] * 15 + [0] * 185, [1] * 6 + [0] * 194],
|
| 73 |
+
dtype=torch.long,
|
| 74 |
+
)
|
| 75 |
+
is_eos = torch.zeros_like(completion_mask, dtype=torch.bool)
|
| 76 |
+
is_eos[:, 14] = True
|
| 77 |
+
is_eos[:, 5] = True
|
| 78 |
+
eos_idx = torch.tensor([14, 5], dtype=torch.long)
|
| 79 |
+
completions = ["Goal: repeat\n" + "x" * 20, "Goal: ok\nAnswer: 1"]
|
| 80 |
+
|
| 81 |
+
with patch.object(opsd_debug, "should_log_probe", return_value=True):
|
| 82 |
+
stats = log_generate_probe(
|
| 83 |
+
global_step=1,
|
| 84 |
+
trainer_step=1,
|
| 85 |
+
prompt_length=100,
|
| 86 |
+
prompt_completion_ids=torch.zeros(2, 300, dtype=torch.long),
|
| 87 |
+
completion_ids=completion_ids,
|
| 88 |
+
completion_mask=completion_mask,
|
| 89 |
+
is_eos=is_eos,
|
| 90 |
+
eos_idx=eos_idx,
|
| 91 |
+
completions=completions,
|
| 92 |
+
tokenizer=_FakeTokenizer(),
|
| 93 |
+
generation_config=None,
|
| 94 |
+
max_completion_length=200,
|
| 95 |
+
num_generations=1,
|
| 96 |
+
sample_count=2,
|
| 97 |
+
)
|
| 98 |
+
assert stats["degenerate_count"] >= 1
|
tests/test_health_monitor.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for TrainingHealthMonitor alerts and correlation."""
|
| 2 |
+
from opsd_utils.health_monitor import (
|
| 3 |
+
ALERT_GEN_CLIP_COLLAPSE,
|
| 4 |
+
ALERT_GEN_REPEAT_DEGEN,
|
| 5 |
+
ALERT_RL_ZERO_SIGNAL,
|
| 6 |
+
TrainingHealthMonitor,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_clip_collapse_alert():
|
| 11 |
+
hm = TrainingHealthMonitor({"enabled": True, "log_alerts_immediately": False})
|
| 12 |
+
hm.reset_step(1)
|
| 13 |
+
alerts = hm.record_generate(
|
| 14 |
+
1,
|
| 15 |
+
{"clipped_rate": 0.85, "eos_terminated_rate": 0.1, "degenerate_rate": 0.2, "repeat_loop_count": 0},
|
| 16 |
+
{"p_greedy_first": 0.99, "p_eos_first": 1e-6},
|
| 17 |
+
)
|
| 18 |
+
assert ALERT_GEN_CLIP_COLLAPSE in alerts
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def test_repeat_degen_alert():
|
| 22 |
+
hm = TrainingHealthMonitor({"enabled": True, "log_alerts_immediately": False})
|
| 23 |
+
hm.reset_step(2)
|
| 24 |
+
alerts = hm.record_generate(
|
| 25 |
+
2,
|
| 26 |
+
{"clipped_rate": 0.3, "eos_terminated_rate": 0.5, "degenerate_rate": 0.6, "repeat_loop_count": 1},
|
| 27 |
+
{},
|
| 28 |
+
)
|
| 29 |
+
assert ALERT_GEN_REPEAT_DEGEN in alerts
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def test_rl_zero_signal_alert():
|
| 33 |
+
hm = TrainingHealthMonitor({"enabled": True, "log_alerts_immediately": False})
|
| 34 |
+
hm.reset_step(3)
|
| 35 |
+
hm.record_loss(3, {"advantages_abs_mean": 0.0, "grpo_zero_loss_rate": 0.95})
|
| 36 |
+
assert ALERT_RL_ZERO_SIGNAL in hm._step_alerts
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def test_correlate_hints_after_history():
|
| 40 |
+
hm = TrainingHealthMonitor({"enabled": True, "window": 5, "log_every_step": False})
|
| 41 |
+
hm.reset_step(0)
|
| 42 |
+
hm.record_generate(
|
| 43 |
+
0,
|
| 44 |
+
{"clipped_rate": 0.1, "eos_terminated_rate": 0.9, "degenerate_rate": 0.1, "repeat_loop_count": 0},
|
| 45 |
+
{"p_greedy_first": 0.8, "p_eos_first": 0.01},
|
| 46 |
+
)
|
| 47 |
+
hm.record_optimizer(0, 0.5, 8e-5)
|
| 48 |
+
hm.finish_step(0)
|
| 49 |
+
|
| 50 |
+
hm.reset_step(1)
|
| 51 |
+
hm.record_generate(
|
| 52 |
+
1,
|
| 53 |
+
{"clipped_rate": 0.9, "eos_terminated_rate": 0.05, "degenerate_rate": 0.5, "repeat_loop_count": 1},
|
| 54 |
+
{"p_greedy_first": 0.995, "p_eos_first": 1e-6},
|
| 55 |
+
)
|
| 56 |
+
hm.record_optimizer(1, 2.5, 8e-5)
|
| 57 |
+
corr = hm.correlate()
|
| 58 |
+
assert "delta_clipped_rate" in corr or "root_cause_hints" in corr
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def test_finish_step_returns_metrics_keys():
|
| 62 |
+
hm = TrainingHealthMonitor({"enabled": True, "metrics_every_step": True, "log_every_step": False})
|
| 63 |
+
hm.reset_step(1)
|
| 64 |
+
hm.record_generate(
|
| 65 |
+
1,
|
| 66 |
+
{"clipped_rate": 0.2, "eos_terminated_rate": 0.8, "degenerate_rate": 0.1, "repeat_loop_count": 0},
|
| 67 |
+
{"p_greedy_first": 0.9, "p_eos_first": 0.001},
|
| 68 |
+
)
|
| 69 |
+
hm.record_optimizer(1, 1.0, 8e-5)
|
| 70 |
+
metrics = hm.finish_step(1)
|
| 71 |
+
assert "completions/degenerate_rate" in metrics
|
| 72 |
+
assert "health/alert_count" in metrics
|
tests/test_mode_router_rlsd.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RLSD / COPSD anti-leakage routing tests."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from opsd_utils.constants import MODE_GRPO, MODE_OPSD, MODE_SFT
|
| 11 |
+
from opsd_utils.mode_router import route_completion_modes, route_prompt_modes
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _rlsd_cfg(**gate):
|
| 15 |
+
base = {
|
| 16 |
+
"enabled": True,
|
| 17 |
+
"mode": "rlsd",
|
| 18 |
+
"gate": {
|
| 19 |
+
"correct_threshold": 0.5,
|
| 20 |
+
"per_completion_opsd": True,
|
| 21 |
+
"require_format_for_opsd": False,
|
| 22 |
+
**gate,
|
| 23 |
+
},
|
| 24 |
+
}
|
| 25 |
+
return base
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def test_rlsd_prompt_correct_grpo():
|
| 29 |
+
acc = torch.tensor([[1.0, 0.0]])
|
| 30 |
+
modes = route_prompt_modes(acc, 2, _rlsd_cfg(), recoverable_flags=[True])
|
| 31 |
+
assert modes == [MODE_GRPO]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def test_rlsd_prompt_wrong_opsd_when_recoverable():
|
| 35 |
+
acc = torch.tensor([[0.0, 0.0]])
|
| 36 |
+
modes = route_prompt_modes(acc, 2, _rlsd_cfg(), recoverable_flags=[True])
|
| 37 |
+
assert modes == [MODE_OPSD]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_rlsd_prompt_wrong_sft_when_not_recoverable():
|
| 41 |
+
acc = torch.tensor([[0.0, 0.0]])
|
| 42 |
+
modes = route_prompt_modes(acc, 2, _rlsd_cfg(), recoverable_flags=[False])
|
| 43 |
+
assert modes == [MODE_SFT]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def test_rlsd_per_completion_routing():
|
| 47 |
+
acc = torch.tensor([[1.0, 0.0]])
|
| 48 |
+
fmt = torch.tensor([[1.0, 0.5]])
|
| 49 |
+
modes = route_completion_modes(acc, 2, 2, _rlsd_cfg(), [True], format_rewards=fmt)
|
| 50 |
+
assert modes == [MODE_GRPO, MODE_OPSD]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def test_copsd_opd_alias_matches_rlsd():
|
| 54 |
+
acc = torch.tensor([[0.0, 1.0]])
|
| 55 |
+
cfg = _rlsd_cfg()
|
| 56 |
+
cfg["mode"] = "copsd_opd"
|
| 57 |
+
modes = route_completion_modes(acc, 2, 2, cfg, [True])
|
| 58 |
+
assert modes == [MODE_OPSD, MODE_GRPO]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def test_rlsd_all_wrong_group_first_completion_sft_cold_start():
|
| 62 |
+
"""All-wrong group: gen 0 → SFT replace, other wrong gens → OPSD."""
|
| 63 |
+
acc = torch.tensor([[0.0, 0.0, 0.0, 0.0]])
|
| 64 |
+
modes = route_completion_modes(acc, 4, 4, _rlsd_cfg(), [True])
|
| 65 |
+
assert modes == [MODE_SFT, MODE_OPSD, MODE_OPSD, MODE_OPSD]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def test_rlsd_all_wrong_two_prompts():
|
| 69 |
+
acc = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
|
| 70 |
+
modes = route_completion_modes(acc, 2, 4, _rlsd_cfg(), [True, True])
|
| 71 |
+
assert modes == [MODE_SFT, MODE_OPSD, MODE_SFT, MODE_OPSD]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def test_rlsd_partial_correct_no_cold_start_on_wrong():
|
| 75 |
+
acc = torch.tensor([[1.0, 0.0, 0.0]])
|
| 76 |
+
modes = route_completion_modes(acc, 3, 3, _rlsd_cfg(), [True])
|
| 77 |
+
assert modes == [MODE_GRPO, MODE_OPSD, MODE_OPSD]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def test_online_sft_on_all_wrong_can_be_disabled():
|
| 81 |
+
acc = torch.tensor([[0.0, 0.0]])
|
| 82 |
+
cfg = _rlsd_cfg(online_sft_on_all_wrong=False)
|
| 83 |
+
modes = route_completion_modes(acc, 2, 2, cfg, [True])
|
| 84 |
+
assert modes == [MODE_OPSD, MODE_OPSD]
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
if __name__ == "__main__":
|
| 88 |
+
test_rlsd_prompt_correct_grpo()
|
| 89 |
+
test_rlsd_prompt_wrong_opsd_when_recoverable()
|
| 90 |
+
test_rlsd_prompt_wrong_sft_when_not_recoverable()
|
| 91 |
+
test_rlsd_per_completion_routing()
|
| 92 |
+
test_copsd_opd_alias_matches_rlsd()
|
| 93 |
+
test_rlsd_all_wrong_group_first_completion_sft_cold_start()
|
| 94 |
+
test_rlsd_all_wrong_two_prompts()
|
| 95 |
+
test_rlsd_partial_correct_no_cold_start_on_wrong()
|
| 96 |
+
test_online_sft_on_all_wrong_can_be_disabled()
|
| 97 |
+
print("RLSD routing tests passed.")
|
tests/test_privileged.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import tempfile
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 10 |
+
|
| 11 |
+
from data_utils.privileged_schema import (
|
| 12 |
+
heuristic_bbox_from_visual_fact,
|
| 13 |
+
normalize_evidence_bbox,
|
| 14 |
+
parse_visual_fact,
|
| 15 |
+
resolve_crop_bbox,
|
| 16 |
+
)
|
| 17 |
+
from opsd_utils import debug_log as opsd_debug
|
| 18 |
+
from opsd_utils.privileged import build_privileged_context, maybe_save_privileged_images
|
| 19 |
+
from opsd_utils.privileged.image_utils import crop_image, load_rgb, resolve_teacher_images
|
| 20 |
+
from opsd_utils.privileged.profiles import effective_profile
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _make_image(path: str, size=(100, 100), color=(255, 0, 0)):
|
| 24 |
+
img = Image.new("RGB", size, color)
|
| 25 |
+
img.save(path)
|
| 26 |
+
return path
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def test_text_provider():
|
| 30 |
+
sample = {"hint": "Rep=67", "answer": "Answer: 131"}
|
| 31 |
+
suffix, images = build_privileged_context(sample, ["text"], privileged_profile="text")
|
| 32 |
+
assert "Rep=67" in suffix
|
| 33 |
+
assert "131" in suffix
|
| 34 |
+
assert images == []
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_hybrid_provider_suffix():
|
| 38 |
+
img = Image.new("RGB", (32, 32))
|
| 39 |
+
sample = {"hint": "step", "visual_fact": "bar=3", "answer": "Answer: 3", "image": img}
|
| 40 |
+
suffix, images = build_privileged_context(
|
| 41 |
+
sample,
|
| 42 |
+
privileged_profile="hybrid",
|
| 43 |
+
opsd_config={"privileged_image": {"mode": "dual"}},
|
| 44 |
+
)
|
| 45 |
+
assert "Visual Facts" in suffix
|
| 46 |
+
assert "Reference" in suffix
|
| 47 |
+
assert len(images) == 2
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def test_hybrid_default_single_image_for_chartqa():
|
| 51 |
+
img = Image.new("RGB", (32, 32))
|
| 52 |
+
sample = {"hint": "step", "visual_fact": "bar=3", "answer": "Answer: 3", "image": img}
|
| 53 |
+
suffix, images = build_privileged_context(sample, privileged_profile="hybrid")
|
| 54 |
+
assert "Visual Facts" in suffix
|
| 55 |
+
assert "Reference" in suffix
|
| 56 |
+
assert len(images) == 1
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def test_visual_profile_excludes_answer():
|
| 60 |
+
img = Image.new("RGB", (32, 32))
|
| 61 |
+
sample = {"hint": "secret", "visual_fact": '{"objects":[]}', "answer": "Answer: 3", "image": img}
|
| 62 |
+
suffix, _ = build_privileged_context(sample, privileged_profile="visual")
|
| 63 |
+
assert "Visual Facts" in suffix
|
| 64 |
+
assert "Reference Answer" not in suffix
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def test_math_lm_downgrade():
|
| 68 |
+
sample = {"hint": "step", "answer": "Answer: 1"}
|
| 69 |
+
profile = effective_profile(sample, "hybrid")
|
| 70 |
+
assert profile == "text"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def test_normalize_evidence_bbox_c2():
|
| 74 |
+
assert normalize_evidence_bbox([0.1, 0.2, 0.8, 0.9]) == [0.1, 0.2, 0.8, 0.9]
|
| 75 |
+
assert normalize_evidence_bbox([0.1, 0.2, 1.5, 0.9]) is None
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def test_heuristic_bbox_d2():
|
| 79 |
+
vf = json.dumps({"objects": [{"name": "cat", "position": "center"}]})
|
| 80 |
+
bbox = heuristic_bbox_from_visual_fact(vf)
|
| 81 |
+
assert bbox == [0.25, 0.25, 0.75, 0.75]
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def test_crop_image_normalized_bbox():
|
| 85 |
+
img = Image.new("RGB", (100, 100), (0, 255, 0))
|
| 86 |
+
crop, strategy = crop_image(img, bbox_norm=[0.2, 0.2, 0.8, 0.8], strategy="bbox")
|
| 87 |
+
assert strategy == "bbox"
|
| 88 |
+
assert crop.size[0] > 0
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def test_resolve_teacher_images_dual():
|
| 92 |
+
img = Image.new("RGB", (80, 80), (0, 0, 255))
|
| 93 |
+
sample = {
|
| 94 |
+
"image": img,
|
| 95 |
+
"visual_fact": json.dumps({"objects": [{"position": "top"}]}),
|
| 96 |
+
}
|
| 97 |
+
images, meta = resolve_teacher_images(sample, "hybrid", crop_cfg={"mode": "dual"})
|
| 98 |
+
assert len(images) == 2
|
| 99 |
+
assert meta["num_teacher_images"] == 2
|
| 100 |
+
assert meta["crop_strategy"] in ("heuristic", "center", "center_fallback", "bbox")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def test_chartqa_enriched_visual_fact_hint():
|
| 104 |
+
"""Enriched ChartQA records (F1+F2) should activate VisualFactsProvider."""
|
| 105 |
+
from data_utils.chart.deplot_pipeline import build_deplot_visual_fact
|
| 106 |
+
|
| 107 |
+
sample = {
|
| 108 |
+
"hint": "Goal: Find the lowest value.\nObservation: values are 70, 72, 77.",
|
| 109 |
+
"answer": "Answer: 70",
|
| 110 |
+
"visual_fact_hint": "Goal: Find the lowest value.\nObservation: values are 70, 72, 77.",
|
| 111 |
+
"visual_fact": "Goal: Find the lowest value.\nObservation: values are 70, 72, 77.",
|
| 112 |
+
"visual_fact_deplot": build_deplot_visual_fact(
|
| 113 |
+
{"question": "q"}, "Year | Value\n2019 | 70\n2020 | 72"
|
| 114 |
+
),
|
| 115 |
+
"image": Image.new("RGB", (64, 64)),
|
| 116 |
+
}
|
| 117 |
+
suffix, images = build_privileged_context(
|
| 118 |
+
sample,
|
| 119 |
+
["text", "visual_facts"],
|
| 120 |
+
privileged_profile="hybrid",
|
| 121 |
+
)
|
| 122 |
+
assert "Visual Facts - Hint" in suffix
|
| 123 |
+
assert "Visual Facts - DePlot" in suffix
|
| 124 |
+
assert "2019 | 70" in suffix
|
| 125 |
+
assert "Reference Reasoning" in suffix
|
| 126 |
+
assert len(images) == 1
|
| 127 |
+
vf_raw = sample.get("visual_fact") or sample.get("visual_facts")
|
| 128 |
+
assert vf_raw and len(vf_raw.strip()) > 0
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def test_visual_facts_f1_f2_merge():
|
| 132 |
+
from data_utils.chart.deplot_pipeline import build_deplot_visual_fact
|
| 133 |
+
|
| 134 |
+
sample = {
|
| 135 |
+
"visual_fact_hint": "hint table",
|
| 136 |
+
"visual_fact_deplot": build_deplot_visual_fact(
|
| 137 |
+
{"question": "q"}, "Col | Val\nA | 1"
|
| 138 |
+
),
|
| 139 |
+
"image": Image.new("RGB", (32, 32)),
|
| 140 |
+
}
|
| 141 |
+
suffix, _ = build_privileged_context(sample, privileged_profile="hybrid")
|
| 142 |
+
assert "Visual Facts - Hint" in suffix
|
| 143 |
+
assert "Visual Facts - DePlot" in suffix
|
| 144 |
+
assert "Col | Val" in suffix
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def test_parse_visual_fact_b1():
|
| 148 |
+
raw = {"objects": [{"name": "a"}]}
|
| 149 |
+
text = parse_visual_fact(raw)
|
| 150 |
+
assert "objects" in text
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def test_debug_artifacts_respect_detail_every():
|
| 154 |
+
opsd_debug.configure(enabled=True, detail_every=10, rank=0, world_size=1)
|
| 155 |
+
with tempfile.TemporaryDirectory() as tmp:
|
| 156 |
+
img = Image.new("RGB", (32, 32))
|
| 157 |
+
path = maybe_save_privileged_images(5, 0, img, img, meta={"crop_strategy": "center"}, output_dir=tmp)
|
| 158 |
+
assert path is None
|
| 159 |
+
assert not os.path.exists(os.path.join(tmp, "logs", "images"))
|
| 160 |
+
|
| 161 |
+
path = maybe_save_privileged_images(10, 0, img, img, meta={"crop_strategy": "center"}, output_dir=tmp)
|
| 162 |
+
assert path is not None
|
| 163 |
+
assert os.path.exists(f"{path}_full.png")
|
| 164 |
+
assert os.path.exists(f"{path}_meta.json")
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
if __name__ == "__main__":
|
| 168 |
+
test_text_provider()
|
| 169 |
+
test_hybrid_provider_suffix()
|
| 170 |
+
test_math_lm_downgrade()
|
| 171 |
+
test_debug_artifacts_respect_detail_every()
|
| 172 |
+
print("Privileged provider tests passed.")
|
tests/test_privileged_debug_artifacts.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import tempfile
|
| 5 |
+
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 9 |
+
|
| 10 |
+
from opsd_utils import debug_log as opsd_debug
|
| 11 |
+
from opsd_utils.privileged.debug_artifacts import maybe_save_privileged_images
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def test_max_samples_per_detail():
|
| 15 |
+
opsd_debug.configure(enabled=True, detail_every=1, rank=0, world_size=1)
|
| 16 |
+
cfg = {"save_images": True, "image_subdir": "logs/images", "max_samples_per_detail": 1}
|
| 17 |
+
img = Image.new("RGB", (16, 16))
|
| 18 |
+
with tempfile.TemporaryDirectory() as tmp:
|
| 19 |
+
p0 = maybe_save_privileged_images(1, 0, img, None, meta={}, output_dir=tmp, privileged_debug_cfg=cfg)
|
| 20 |
+
p1 = maybe_save_privileged_images(1, 1, img, None, meta={}, output_dir=tmp, privileged_debug_cfg=cfg)
|
| 21 |
+
assert p0 is not None
|
| 22 |
+
assert p1 is None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def test_meta_sidecar():
|
| 26 |
+
opsd_debug.configure(enabled=True, detail_every=1, rank=0, world_size=1)
|
| 27 |
+
img = Image.new("RGB", (16, 16))
|
| 28 |
+
with tempfile.TemporaryDirectory() as tmp:
|
| 29 |
+
prefix = maybe_save_privileged_images(
|
| 30 |
+
1,
|
| 31 |
+
0,
|
| 32 |
+
img,
|
| 33 |
+
img,
|
| 34 |
+
meta={"privileged_profile": "hybrid", "crop_strategy": "bbox"},
|
| 35 |
+
output_dir=tmp,
|
| 36 |
+
)
|
| 37 |
+
with open(f"{prefix}_meta.json", encoding="utf-8") as f:
|
| 38 |
+
meta = json.load(f)
|
| 39 |
+
assert meta["privileged_profile"] == "hybrid"
|
| 40 |
+
assert meta["crop_strategy"] == "bbox"
|
tests/test_slice_completion_logits.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for completion logit slicing shared by GRPO and OPSD."""
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 8 |
+
|
| 9 |
+
from opsd_utils.opsd_loss import slice_student_completion_logits
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def test_slice_matches_legacy_grpo_path():
|
| 13 |
+
logits_to_keep = 4
|
| 14 |
+
full = torch.randn(2, 20, 8)
|
| 15 |
+
legacy = full[:, -logits_to_keep - 1 :, :]
|
| 16 |
+
legacy = legacy[:, :-1, :]
|
| 17 |
+
legacy = legacy[:, -logits_to_keep:, :]
|
| 18 |
+
assert torch.allclose(legacy, slice_student_completion_logits(full, logits_to_keep))
|
tests/test_teacher_dual_image.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import pytest
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 8 |
+
|
| 9 |
+
from opsd_utils.privileged.image_utils import resolve_teacher_images
|
| 10 |
+
from opsd_utils.privileged.profiles import effective_profile
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def test_text_profile_single_image():
|
| 14 |
+
img = Image.new("RGB", (64, 64))
|
| 15 |
+
sample = {"image": img, "hint": "h", "answer": "Answer: 1"}
|
| 16 |
+
images, meta = resolve_teacher_images(sample, "text")
|
| 17 |
+
assert len(images) == 1
|
| 18 |
+
assert meta["num_teacher_images"] == 1
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def test_hybrid_profile_dual_image():
|
| 22 |
+
img = Image.new("RGB", (64, 64))
|
| 23 |
+
sample = {"image": img, "evidence_bbox": [0.1, 0.1, 0.9, 0.9]}
|
| 24 |
+
images, meta = resolve_teacher_images(sample, "hybrid", crop_cfg={"mode": "dual"})
|
| 25 |
+
assert len(images) == 2
|
| 26 |
+
assert meta["has_bbox"] is True
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def test_hybrid_profile_single_image_by_default():
|
| 30 |
+
img = Image.new("RGB", (64, 64))
|
| 31 |
+
sample = {"image": img, "evidence_bbox": [0.1, 0.1, 0.9, 0.9]}
|
| 32 |
+
images, meta = resolve_teacher_images(sample, "hybrid")
|
| 33 |
+
assert len(images) == 1
|
| 34 |
+
assert meta["num_teacher_images"] == 1
|
| 35 |
+
assert meta["crop_strategy"] == "single_full"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def test_no_image_empty():
|
| 39 |
+
sample = {"hint": "only text"}
|
| 40 |
+
assert effective_profile(sample, "hybrid") == "text"
|
| 41 |
+
images, meta = resolve_teacher_images(sample, "text")
|
| 42 |
+
assert images == []
|
| 43 |
+
assert meta["num_teacher_images"] == 0
|
tests/test_vocab_align.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for cross-model vocab alignment diagnostics."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 10 |
+
|
| 11 |
+
from opsd_utils.opsd_loss import generalized_jsd_loss
|
| 12 |
+
from opsd_utils.vocab_align import (
|
| 13 |
+
align_cross_model_logits,
|
| 14 |
+
reset_vocab_align_debug,
|
| 15 |
+
verify_shared_tokenizer_alignment,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class _Tok:
|
| 20 |
+
def __init__(self, vocab_size: int, offset: int = 0):
|
| 21 |
+
self._size = vocab_size
|
| 22 |
+
self._offset = offset
|
| 23 |
+
|
| 24 |
+
def __len__(self):
|
| 25 |
+
return self._size
|
| 26 |
+
|
| 27 |
+
def decode(self, ids, skip_special_tokens=False):
|
| 28 |
+
i = ids[0]
|
| 29 |
+
return f"tok_{i + self._offset}"
|
| 30 |
+
|
| 31 |
+
def convert_ids_to_tokens(self, i):
|
| 32 |
+
return f"tok_{i + self._offset}"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_align_slice_renorm_via_log_softmax():
|
| 36 |
+
reset_vocab_align_debug()
|
| 37 |
+
student = torch.randn(1, 3, 100, requires_grad=True)
|
| 38 |
+
teacher = torch.randn(1, 3, 128)
|
| 39 |
+
s, t = align_cross_model_logits(student, teacher, log_renorm_check=False)
|
| 40 |
+
assert s.shape[-1] == t.shape[-1] == 100
|
| 41 |
+
t_probs = F.softmax(t[0, 0], dim=-1)
|
| 42 |
+
assert abs(float(t_probs.sum()) - 1.0) < 1e-4
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def test_generalized_jsd_renormalizes_after_slice():
|
| 46 |
+
reset_vocab_align_debug()
|
| 47 |
+
student = torch.randn(1, 5, 152000, requires_grad=True)
|
| 48 |
+
teacher = torch.randn(1, 5, 152128)
|
| 49 |
+
mask = torch.ones(1, 5)
|
| 50 |
+
loss = generalized_jsd_loss(student, teacher, mask)
|
| 51 |
+
assert loss.ndim == 0
|
| 52 |
+
assert loss.requires_grad
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def test_tokenizer_alignment_detects_mismatch():
|
| 56 |
+
st = _Tok(200, offset=0)
|
| 57 |
+
tt = _Tok(200, offset=1)
|
| 58 |
+
report = verify_shared_tokenizer_alignment(
|
| 59 |
+
st, tt, shared_vocab=200, full_scan=True, sample_stride=1
|
| 60 |
+
)
|
| 61 |
+
assert not report["aligned"]
|
| 62 |
+
assert report["mismatch_count"] > 0
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def test_tokenizer_alignment_passes_identical():
|
| 66 |
+
st = _Tok(1000, offset=0)
|
| 67 |
+
tt = _Tok(1200, offset=0)
|
| 68 |
+
report = verify_shared_tokenizer_alignment(
|
| 69 |
+
st, tt, shared_vocab=1000, full_scan=False, sample_stride=100
|
| 70 |
+
)
|
| 71 |
+
assert report["aligned"]
|
trainer/DyMETrainer_7B.py
ADDED
|
@@ -0,0 +1,983 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import os
|
| 3 |
+
import textwrap
|
| 4 |
+
import warnings
|
| 5 |
+
from collections import defaultdict, deque
|
| 6 |
+
from collections.abc import Sized
|
| 7 |
+
from contextlib import nullcontext
|
| 8 |
+
from typing import Any, Callable, Optional, Union
|
| 9 |
+
|
| 10 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 11 |
+
|
| 12 |
+
import datasets
|
| 13 |
+
import torch
|
| 14 |
+
import torch.utils.data
|
| 15 |
+
import transformers
|
| 16 |
+
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
|
| 17 |
+
from datasets import Dataset, IterableDataset
|
| 18 |
+
from packaging import version
|
| 19 |
+
from torch import nn
|
| 20 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 21 |
+
from torch.utils.data import DataLoader, Sampler, DistributedSampler
|
| 22 |
+
from transformers import (
|
| 23 |
+
AutoModelForCausalLM,
|
| 24 |
+
AutoModelForSequenceClassification,
|
| 25 |
+
AutoTokenizer,
|
| 26 |
+
GenerationConfig,
|
| 27 |
+
PreTrainedModel,
|
| 28 |
+
PreTrainedTokenizerBase,
|
| 29 |
+
Trainer,
|
| 30 |
+
TrainerCallback,
|
| 31 |
+
is_wandb_available,
|
| 32 |
+
)
|
| 33 |
+
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
| 34 |
+
from transformers.trainer_utils import seed_worker
|
| 35 |
+
from transformers.utils import is_datasets_available, is_peft_available
|
| 36 |
+
|
| 37 |
+
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
|
| 38 |
+
from trl.extras.profiling import profiling_context, profiling_decorator
|
| 39 |
+
from trl.import_utils import is_liger_kernel_available, is_vllm_available
|
| 40 |
+
from trl.models import create_reference_model, prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
|
| 41 |
+
# from trl.models.utils import _ForwardRedirection
|
| 42 |
+
from trl.trainer.callbacks import SyncRefModelCallback
|
| 43 |
+
from trl.trainer.grpo_config import GRPOConfig
|
| 44 |
+
from trl.trainer.utils import (
|
| 45 |
+
disable_dropout_in_model,
|
| 46 |
+
generate_model_card,
|
| 47 |
+
get_comet_experiment_url,
|
| 48 |
+
pad,
|
| 49 |
+
print_prompt_completions_sample,
|
| 50 |
+
selective_log_softmax,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
from trl.models import prepare_deepspeed, unwrap_model_for_generation
|
| 54 |
+
from trl.trainer.grpo_config import GRPOConfig
|
| 55 |
+
from trl.trainer.utils import generate_model_card, get_comet_experiment_url, selective_log_softmax
|
| 56 |
+
|
| 57 |
+
import concurrent.futures
|
| 58 |
+
from datasets import Dataset, IterableDataset
|
| 59 |
+
|
| 60 |
+
from reward_utils import checker
|
| 61 |
+
from reward_utils.checker import RewardCalculator
|
| 62 |
+
from reward_utils.compute_rewards import calculate_rewards_in_parallel, refine_context_in_parallel
|
| 63 |
+
|
| 64 |
+
if is_wandb_available():
|
| 65 |
+
import wandb
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
|
| 70 |
+
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
|
| 71 |
+
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class RepeatSampler(Sampler):
|
| 75 |
+
"""
|
| 76 |
+
Sampler that repeats the indices of a dataset in a structured manner.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
data_source (`Sized`):
|
| 80 |
+
Dataset to sample from.
|
| 81 |
+
mini_repeat_count (`int`):
|
| 82 |
+
Number of times to repeat each index per batch.
|
| 83 |
+
batch_size (`int`, *optional*, defaults to `1`):
|
| 84 |
+
Number of unique indices per batch.
|
| 85 |
+
repeat_count (`int`, *optional*, defaults to `1`):
|
| 86 |
+
Number of times to repeat the full sampling process.
|
| 87 |
+
shuffle (`bool`, *optional*, defaults to `True`):
|
| 88 |
+
Whether to shuffle the dataset.
|
| 89 |
+
seed (`int` or `None`, *optional*, defaults to `None`):
|
| 90 |
+
Random seed for reproducibility (only affects this sampler).
|
| 91 |
+
|
| 92 |
+
Example:
|
| 93 |
+
```python
|
| 94 |
+
>>> sampler = RepeatRandomSampler(["a", "b", "c", "d", "e", "f", "g"], mini_repeat_count=2, batch_size=3, repeat_count=4)
|
| 95 |
+
>>> list(sampler)
|
| 96 |
+
[4, 4, 3, 3, 0, 0,
|
| 97 |
+
4, 4, 3, 3, 0, 0,
|
| 98 |
+
4, 4, 3, 3, 0, 0,
|
| 99 |
+
4, 4, 3, 3, 0, 0,
|
| 100 |
+
|
| 101 |
+
1, 1, 2, 2, 6, 6,
|
| 102 |
+
1, 1, 2, 2, 6, 6,
|
| 103 |
+
1, 1, 2, 2, 6, 6,
|
| 104 |
+
1, 1, 2, 2, 6, 6]
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
```txt
|
| 108 |
+
mini_repeat_count = 3
|
| 109 |
+
- - -
|
| 110 |
+
[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, |
|
| 111 |
+
4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, |
|
| 112 |
+
8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, |
|
| 113 |
+
repeat_count = 2
|
| 114 |
+
0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, |
|
| 115 |
+
4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, |
|
| 116 |
+
8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, ...] |
|
| 117 |
+
--------- --------- --------- ---------
|
| 118 |
+
--------- --------- --------- ---------
|
| 119 |
+
--------- --------- --------- ---------
|
| 120 |
+
batch_size = 12
|
| 121 |
+
```
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
def __init__(
|
| 125 |
+
self,
|
| 126 |
+
data_source: Sized,
|
| 127 |
+
mini_repeat_count: int,
|
| 128 |
+
batch_size: int = 1,
|
| 129 |
+
repeat_count: int = 1,
|
| 130 |
+
shuffle: bool = True,
|
| 131 |
+
seed: Optional[int] = None,
|
| 132 |
+
):
|
| 133 |
+
self.data_source = data_source
|
| 134 |
+
self.mini_repeat_count = mini_repeat_count
|
| 135 |
+
self.batch_size = batch_size
|
| 136 |
+
self.repeat_count = repeat_count
|
| 137 |
+
self.num_samples = len(data_source)
|
| 138 |
+
self.shuffle = shuffle
|
| 139 |
+
self.seed = seed
|
| 140 |
+
|
| 141 |
+
if shuffle:
|
| 142 |
+
self.generator = torch.Generator() # Create a local random generator
|
| 143 |
+
if seed is not None:
|
| 144 |
+
self.generator.manual_seed(seed)
|
| 145 |
+
|
| 146 |
+
def __iter__(self):
|
| 147 |
+
if self.shuffle:
|
| 148 |
+
# E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7)
|
| 149 |
+
indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
|
| 150 |
+
else:
|
| 151 |
+
indexes = list(range(self.num_samples))
|
| 152 |
+
|
| 153 |
+
# [2, 4, 3, 1, 0, 6, 5]
|
| 154 |
+
# -> [[2, 4, 3], [1, 0, 6], [5]] (batch_size = 3)
|
| 155 |
+
indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
|
| 156 |
+
|
| 157 |
+
# [[2, 4, 3], [1, 0, 6], [5]]
|
| 158 |
+
# -> [[2, 4, 3], [1, 0, 6]]
|
| 159 |
+
indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
|
| 160 |
+
|
| 161 |
+
for chunk in indexes:
|
| 162 |
+
for _ in range(self.repeat_count):
|
| 163 |
+
for index in chunk:
|
| 164 |
+
for _ in range(self.mini_repeat_count):
|
| 165 |
+
yield index
|
| 166 |
+
|
| 167 |
+
def __len__(self) -> int:
|
| 168 |
+
return self.num_samples * self.mini_repeat_count * self.repeat_count
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# torch.nanstd doesn't exist, so we define it here
|
| 172 |
+
def nanstd(tensor: torch.Tensor) -> torch.Tensor:
|
| 173 |
+
"""
|
| 174 |
+
Compute the standard deviation of a tensor, ignoring NaNs. This function only supports 1D tensors.
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
tensor (`torch.Tensor`):
|
| 178 |
+
Input tensor of shape `(N,)`.
|
| 179 |
+
|
| 180 |
+
Returns:
|
| 181 |
+
`torch.Tensor`:
|
| 182 |
+
Standard deviation of the tensor, ignoring NaNs.
|
| 183 |
+
"""
|
| 184 |
+
variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True)) ** 2) # Compute variance ignoring NaNs
|
| 185 |
+
count = torch.sum(~torch.isnan(tensor)) # Count of non-NaN values
|
| 186 |
+
variance *= count / (count - 1) # Bessel's correction
|
| 187 |
+
return torch.sqrt(variance)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def split_tensor_dict(
|
| 191 |
+
tensor_dict: dict[str, Optional[torch.Tensor]], num_chunks: int, image_patch_id=151655, patch_id_times=4
|
| 192 |
+
) -> list[dict[str, Optional[torch.Tensor]]]:
|
| 193 |
+
"""
|
| 194 |
+
Splits a dictionary of tensors along the first dimension into `num_chunks` equal parts.
|
| 195 |
+
|
| 196 |
+
Example:
|
| 197 |
+
>>> x = torch.arange(12).reshape(6, 2)
|
| 198 |
+
>>> y = torch.arange(6).reshape(6, 1)
|
| 199 |
+
>>> tensor_dict = {"x": x, "y": y}
|
| 200 |
+
>>> split_tensor_dict(tensor_dict, 3)
|
| 201 |
+
[
|
| 202 |
+
{"x": tensor([[0, 1], [2, 3]]), "y": tensor([[0], [1]])},
|
| 203 |
+
{"x": tensor([[4, 5], [6, 7]]), "y": tensor([[2], [3]])},
|
| 204 |
+
{"x": tensor([[ 8, 9], [10, 11]]), "y": tensor([[4], [5]])}
|
| 205 |
+
]
|
| 206 |
+
"""
|
| 207 |
+
if image_patch_id is None:
|
| 208 |
+
first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None)
|
| 209 |
+
chunk_size = first_tensor.shape[0] // num_chunks
|
| 210 |
+
# has = []
|
| 211 |
+
# if 'has_correct' in tensor_dict:
|
| 212 |
+
# has = tensor_dict['has_correct']
|
| 213 |
+
# del tensor_dict['has_correct']
|
| 214 |
+
l1 = []
|
| 215 |
+
for i in range(num_chunks):
|
| 216 |
+
dt = {
|
| 217 |
+
key: tensor[i * chunk_size : (i + 1) * chunk_size] if tensor is not None else None
|
| 218 |
+
for key, tensor in tensor_dict.items()
|
| 219 |
+
}
|
| 220 |
+
# if len(has) > 0:
|
| 221 |
+
# dt['has_correct'] = has[i]
|
| 222 |
+
l1.append(dt)
|
| 223 |
+
|
| 224 |
+
return l1
|
| 225 |
+
else:
|
| 226 |
+
first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None)
|
| 227 |
+
chunk_size = first_tensor.shape[0] // num_chunks
|
| 228 |
+
l1 = []
|
| 229 |
+
for i in range(num_chunks):
|
| 230 |
+
dt = {}
|
| 231 |
+
for key, tensor in tensor_dict.items():
|
| 232 |
+
if key != 'pixel_values':
|
| 233 |
+
dt[key] = tensor[i * chunk_size : (i + 1) * chunk_size] if tensor is not None else None
|
| 234 |
+
|
| 235 |
+
l1.append(dt)
|
| 236 |
+
|
| 237 |
+
if 'pixel_values' in tensor_dict:
|
| 238 |
+
raw_pixel_values = tensor_dict['pixel_values']
|
| 239 |
+
start_image_patch = 0
|
| 240 |
+
for dt in l1:
|
| 241 |
+
batch_input_ids = dt['prompt_ids']
|
| 242 |
+
num_image_patches = (batch_input_ids == image_patch_id).sum().item() * patch_id_times
|
| 243 |
+
batch_pixel_values = raw_pixel_values[start_image_patch : start_image_patch + num_image_patches]
|
| 244 |
+
start_image_patch += num_image_patches
|
| 245 |
+
dt['pixel_values'] = batch_pixel_values
|
| 246 |
+
|
| 247 |
+
return l1
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def nanmin(tensor: torch.Tensor) -> torch.Tensor:
|
| 251 |
+
"""
|
| 252 |
+
Compute the minimum value of a tensor, ignoring NaNs. This function only supports 1D tensors.
|
| 253 |
+
|
| 254 |
+
Args:
|
| 255 |
+
tensor (`torch.Tensor`): Input tensor of shape `(N,)`.
|
| 256 |
+
|
| 257 |
+
Returns:
|
| 258 |
+
`torch.Tensor`: Minimum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN.
|
| 259 |
+
"""
|
| 260 |
+
if torch.isnan(tensor).all():
|
| 261 |
+
return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
|
| 262 |
+
return torch.min(tensor[~torch.isnan(tensor)])
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def nanmax(tensor: torch.Tensor) -> torch.Tensor:
|
| 266 |
+
"""
|
| 267 |
+
Compute the maximum value of a tensor, ignoring NaNs. This function only supports 1D tensors.
|
| 268 |
+
|
| 269 |
+
Args:
|
| 270 |
+
tensor (`torch.Tensor`): Input tensor of shape `(N,)`.
|
| 271 |
+
|
| 272 |
+
Returns:
|
| 273 |
+
`torch.Tensor`: Maximum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN.
|
| 274 |
+
"""
|
| 275 |
+
if torch.isnan(tensor).all():
|
| 276 |
+
return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
|
| 277 |
+
return torch.max(tensor[~torch.isnan(tensor)])
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class DyMETrainer(Trainer):
|
| 281 |
+
|
| 282 |
+
def __init__(
|
| 283 |
+
self,
|
| 284 |
+
model: PreTrainedModel,
|
| 285 |
+
checker = None,
|
| 286 |
+
refiner=None,
|
| 287 |
+
args: Optional[GRPOConfig] = None,
|
| 288 |
+
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
| 289 |
+
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
|
| 290 |
+
processing_class: Optional[PreTrainedTokenizerBase] = None,
|
| 291 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
| 292 |
+
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
|
| 293 |
+
processing_func = None,
|
| 294 |
+
task_name: str = None,
|
| 295 |
+
end_flag: str = '<|im_end|>',
|
| 296 |
+
):
|
| 297 |
+
self.task_name = task_name
|
| 298 |
+
self.reward_weights = torch.nn.Parameter(torch.ones(3), requires_grad=False)
|
| 299 |
+
self.reward_func_names = ['format', 'thinking', 'accuracy']
|
| 300 |
+
# Models
|
| 301 |
+
# Trained model
|
| 302 |
+
model_init_kwargs = args.model_init_kwargs or {}
|
| 303 |
+
|
| 304 |
+
# Enable gradient checkpointing if requested
|
| 305 |
+
if args.gradient_checkpointing:
|
| 306 |
+
model = self._enable_gradient_checkpointing(model, args)
|
| 307 |
+
|
| 308 |
+
# Processing class
|
| 309 |
+
if processing_class is None:
|
| 310 |
+
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
|
| 311 |
+
|
| 312 |
+
# Training arguments
|
| 313 |
+
self.max_prompt_length = args.max_prompt_length
|
| 314 |
+
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
|
| 315 |
+
self.num_generations = args.num_generations # = G in the GRPO paper
|
| 316 |
+
self.temperature = args.temperature
|
| 317 |
+
self.top_p = args.top_p
|
| 318 |
+
self.top_k = args.top_k
|
| 319 |
+
self.min_p = args.min_p
|
| 320 |
+
self.repetition_penalty = args.repetition_penalty
|
| 321 |
+
self.use_liger_loss = args.use_liger_loss
|
| 322 |
+
self.loss_type = args.loss_type
|
| 323 |
+
self.scale_rewards = args.scale_rewards
|
| 324 |
+
self.mask_truncated_completions = args.mask_truncated_completions
|
| 325 |
+
self.end_flag = end_flag
|
| 326 |
+
self.checker = checker
|
| 327 |
+
self.refiner = refiner
|
| 328 |
+
# Datasets
|
| 329 |
+
self.shuffle_dataset = args.shuffle_dataset
|
| 330 |
+
|
| 331 |
+
if (
|
| 332 |
+
isinstance(train_dataset, IterableDataset)
|
| 333 |
+
or isinstance(eval_dataset, IterableDataset)
|
| 334 |
+
or (
|
| 335 |
+
isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values())
|
| 336 |
+
)
|
| 337 |
+
):
|
| 338 |
+
# See https://github.com/huggingface/trl/issues/3213
|
| 339 |
+
raise NotImplementedError(
|
| 340 |
+
"Iterable datasets are not yet supported in GRPOTrainer. Please use a standard dataset instead."
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
# Multi-step
|
| 344 |
+
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
|
| 345 |
+
self.epsilon_low = args.epsilon
|
| 346 |
+
self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
|
| 347 |
+
# Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
|
| 348 |
+
self._step = 0
|
| 349 |
+
# Buffer the batch to reuse generated outputs across multiple updates. For more details, see
|
| 350 |
+
# `_get_train_sampler` and `_prepare_inputs`.
|
| 351 |
+
self._buffered_inputs = None
|
| 352 |
+
|
| 353 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
| 354 |
+
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
|
| 355 |
+
# "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
|
| 356 |
+
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
|
| 357 |
+
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
|
| 358 |
+
# This acts as a flag to indicate that the warning has already been issued.
|
| 359 |
+
model.warnings_issued["estimate_tokens"] = True
|
| 360 |
+
def data_collator(features): # No data collation is needed in GRPO
|
| 361 |
+
return features
|
| 362 |
+
|
| 363 |
+
super().__init__(
|
| 364 |
+
model=model,
|
| 365 |
+
args=args,
|
| 366 |
+
data_collator=data_collator,
|
| 367 |
+
train_dataset=train_dataset,
|
| 368 |
+
eval_dataset=eval_dataset,
|
| 369 |
+
processing_class=processing_class,
|
| 370 |
+
callbacks=callbacks,
|
| 371 |
+
optimizers=optimizers,
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
# Reference model
|
| 375 |
+
self.beta = args.beta
|
| 376 |
+
assert self.beta == 0
|
| 377 |
+
|
| 378 |
+
# Disable dropout in the models
|
| 379 |
+
if args.disable_dropout:
|
| 380 |
+
disable_dropout_in_model(model)
|
| 381 |
+
|
| 382 |
+
# Initialize the metrics
|
| 383 |
+
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
|
| 384 |
+
self._total_train_tokens = 0
|
| 385 |
+
self.log_completions = args.log_completions
|
| 386 |
+
self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
|
| 387 |
+
self.num_completions_to_print = args.num_completions_to_print
|
| 388 |
+
# maxlen is set to the total number of forward passes per step. This value of `maxlen` ensures we log only the
|
| 389 |
+
# final optimization step.
|
| 390 |
+
maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.gradient_accumulation_steps
|
| 391 |
+
self._textual_logs = {
|
| 392 |
+
"prompt": deque(maxlen=maxlen),
|
| 393 |
+
"completion": deque(maxlen=maxlen),
|
| 394 |
+
"rewards": defaultdict(lambda: deque(maxlen=maxlen)),
|
| 395 |
+
}
|
| 396 |
+
|
| 397 |
+
# Check if the effective batch size can be divided by the number of generations
|
| 398 |
+
if self.num_generations < 2:
|
| 399 |
+
raise ValueError(
|
| 400 |
+
"GRPO requires at least 2 generations per prompt to calculate the advantages. You provided "
|
| 401 |
+
f"{self.num_generations}, which is less than the minimum required."
|
| 402 |
+
)
|
| 403 |
+
num_processes = self.accelerator.num_processes
|
| 404 |
+
effective_batch_size = args.per_device_train_batch_size * num_processes * args.gradient_accumulation_steps
|
| 405 |
+
possible_values = [
|
| 406 |
+
n_gen for n_gen in range(2, effective_batch_size + 1) if (effective_batch_size) % n_gen == 0
|
| 407 |
+
]
|
| 408 |
+
if self.num_generations not in possible_values:
|
| 409 |
+
raise ValueError(
|
| 410 |
+
f"The effective train batch size ({num_processes} x {args.per_device_train_batch_size} x "
|
| 411 |
+
f"{args.gradient_accumulation_steps}) must be evenly divisible by the number of generations per "
|
| 412 |
+
f"prompt ({self.num_generations}). Given the current effective train batch size, the valid values for "
|
| 413 |
+
f"the number of generations are: {possible_values}."
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
# Ensure each process receives a unique seed to prevent duplicate completions when generating with
|
| 417 |
+
# transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
|
| 418 |
+
# it's safer to set it in all cases.
|
| 419 |
+
set_seed(args.seed, device_specific=True)
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
self.generation_config = GenerationConfig(
|
| 423 |
+
max_new_tokens=self.max_completion_length,
|
| 424 |
+
do_sample=True,
|
| 425 |
+
pad_token_id=processing_class.tokenizer.pad_token_id,
|
| 426 |
+
bos_token_id=processing_class.tokenizer.bos_token_id,
|
| 427 |
+
eos_token_id=processing_class.tokenizer.eos_token_id,
|
| 428 |
+
temperature=self.temperature,
|
| 429 |
+
top_p=self.top_p,
|
| 430 |
+
top_k=self.top_k,
|
| 431 |
+
min_p=self.min_p,
|
| 432 |
+
repetition_penalty=self.repetition_penalty,
|
| 433 |
+
cache_implementation=args.cache_implementation,
|
| 434 |
+
use_cache=False if self.args.gradient_checkpointing else True
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
| 438 |
+
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
| 439 |
+
# self.model_accepts_loss_kwargs to False to enable scaling.
|
| 440 |
+
self.model_accepts_loss_kwargs = False
|
| 441 |
+
self.processing_func = processing_func
|
| 442 |
+
|
| 443 |
+
def _set_signature_columns_if_needed(self):
|
| 444 |
+
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
|
| 445 |
+
# By default, this method sets `self._signature_columns` to the model's expected inputs.
|
| 446 |
+
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
|
| 447 |
+
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
|
| 448 |
+
if self._signature_columns is None:
|
| 449 |
+
self._signature_columns = ["prompt"]
|
| 450 |
+
|
| 451 |
+
def get_train_dataloader(self):
|
| 452 |
+
if self.train_dataset is None:
|
| 453 |
+
raise ValueError("Trainer: training requires a train_dataset.")
|
| 454 |
+
|
| 455 |
+
train_dataset = self.train_dataset
|
| 456 |
+
data_collator = self.data_collator
|
| 457 |
+
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
| 458 |
+
train_dataset = self._remove_unused_columns(train_dataset, description="training")
|
| 459 |
+
else:
|
| 460 |
+
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
|
| 461 |
+
|
| 462 |
+
dataloader_params = {
|
| 463 |
+
"batch_size": self._train_batch_size * self.args.gradient_accumulation_steps, # < this is the change
|
| 464 |
+
"collate_fn": data_collator,
|
| 465 |
+
"num_workers": self.args.dataloader_num_workers,
|
| 466 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
| 467 |
+
"persistent_workers": self.args.dataloader_persistent_workers,
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
| 471 |
+
dataloader_params["sampler"] = self._get_train_sampler()
|
| 472 |
+
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
| 473 |
+
dataloader_params["worker_init_fn"] = seed_worker
|
| 474 |
+
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
| 475 |
+
dl = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
|
| 476 |
+
return dl
|
| 477 |
+
|
| 478 |
+
def _get_train_sampler(self) -> Sampler:
|
| 479 |
+
effective_batch_size = (
|
| 480 |
+
self.args.per_device_train_batch_size
|
| 481 |
+
* self.accelerator.num_processes
|
| 482 |
+
* self.args.gradient_accumulation_steps
|
| 483 |
+
)
|
| 484 |
+
return RepeatSampler(
|
| 485 |
+
data_source=self.train_dataset,
|
| 486 |
+
mini_repeat_count=self.num_generations,
|
| 487 |
+
batch_size=effective_batch_size // self.num_generations,
|
| 488 |
+
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
|
| 489 |
+
shuffle=self.shuffle_dataset,
|
| 490 |
+
seed=self.args.seed,
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
def _get_eval_sampler(self, eval_dataset):
|
| 494 |
+
# eval_dataset 是一个 map-style Dataset(非 IterableDataset)
|
| 495 |
+
return DistributedSampler(
|
| 496 |
+
dataset=eval_dataset,
|
| 497 |
+
num_replicas=self.accelerator.num_processes,
|
| 498 |
+
rank=self.accelerator.process_index,
|
| 499 |
+
shuffle=False,
|
| 500 |
+
seed=self.args.seed,
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
|
| 504 |
+
"""Enables gradient checkpointing for the model."""
|
| 505 |
+
# Ensure use_cache is disabled
|
| 506 |
+
model.config.use_cache = False
|
| 507 |
+
|
| 508 |
+
# Enable gradient checkpointing on the base model for PEFT
|
| 509 |
+
if is_peft_model(model):
|
| 510 |
+
model.base_model.gradient_checkpointing_enable()
|
| 511 |
+
# Enable gradient checkpointing for non-PEFT models
|
| 512 |
+
else:
|
| 513 |
+
model.gradient_checkpointing_enable()
|
| 514 |
+
|
| 515 |
+
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
| 516 |
+
use_reentrant = (
|
| 517 |
+
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
if use_reentrant:
|
| 521 |
+
model.enable_input_require_grads()
|
| 522 |
+
|
| 523 |
+
return model
|
| 524 |
+
|
| 525 |
+
@profiling_decorator
|
| 526 |
+
def _get_last_hidden_state(self, unwrapped_model, input_ids, attention_mask, logits_to_keep=None):
|
| 527 |
+
if is_peft_model(unwrapped_model):
|
| 528 |
+
unwrapped_model = unwrapped_model.base_model.model
|
| 529 |
+
last_hidden_state = unwrapped_model.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
|
| 530 |
+
last_hidden_state = last_hidden_state[:, :-1, :] # (B, L-1, H)
|
| 531 |
+
if logits_to_keep is not None:
|
| 532 |
+
last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H)
|
| 533 |
+
return last_hidden_state
|
| 534 |
+
|
| 535 |
+
# Get the per-token log probabilities for the completions for the model and the reference model
|
| 536 |
+
@profiling_decorator
|
| 537 |
+
def _get_per_token_logps(self, model, input_ids, attention_mask, pixel_values, image_grid_thws, logits_to_keep, batch_size=None) -> torch.Tensor:
|
| 538 |
+
batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak
|
| 539 |
+
all_logps = []
|
| 540 |
+
patch_s = 0
|
| 541 |
+
for i in range(0, input_ids.size(0), batch_size):
|
| 542 |
+
input_ids_batch = input_ids[i : i + batch_size]
|
| 543 |
+
attention_mask_batch = attention_mask[i : i + batch_size]
|
| 544 |
+
img_id = self.processing_class.tokenizer.convert_tokens_to_ids('<|image_pad|>')
|
| 545 |
+
image_patch_nums = (input_ids_batch == img_id).sum().item() * 4 # 每个 image_pad 对应 4 个图像 patch
|
| 546 |
+
# print("image_patch_nums", image_patch_nums, pixel_values.shape)
|
| 547 |
+
pixel_values_batch = pixel_values[patch_s : patch_s + image_patch_nums]
|
| 548 |
+
patch_s += image_patch_nums
|
| 549 |
+
image_grid_thw_batch = image_grid_thws[i : i + batch_size]
|
| 550 |
+
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
|
| 551 |
+
logits = model(
|
| 552 |
+
input_ids=input_ids_batch, pixel_values=pixel_values_batch, image_grid_thw=image_grid_thw_batch,
|
| 553 |
+
attention_mask=attention_mask_batch
|
| 554 |
+
).logits
|
| 555 |
+
# logits = logits[:, :-1, :] # (B, L-1, H)
|
| 556 |
+
if logits_to_keep is not None:
|
| 557 |
+
logits = logits[:, -logits_to_keep-1:, :] # (B, logits_to_keep, H)
|
| 558 |
+
|
| 559 |
+
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
| 560 |
+
input_ids_batch = input_ids_batch[:, -logits_to_keep:]
|
| 561 |
+
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
|
| 562 |
+
# See https://github.com/huggingface/trl/issues/2770
|
| 563 |
+
logits = logits[:, -logits_to_keep:]
|
| 564 |
+
# Divide logits by sampling temperature.
|
| 565 |
+
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
|
| 566 |
+
logits = logits / self.temperature
|
| 567 |
+
logps = selective_log_softmax(logits, input_ids_batch) # compute logprobs for the input tokens
|
| 568 |
+
all_logps.append(logps)
|
| 569 |
+
return torch.cat(all_logps, dim=0)
|
| 570 |
+
|
| 571 |
+
@profiling_decorator
|
| 572 |
+
def _prepare_inputs(
|
| 573 |
+
self, accumulated_local_batch: dict[str, Union[torch.Tensor, Any]]
|
| 574 |
+
) -> dict[str, Union[torch.Tensor, Any]]:
|
| 575 |
+
|
| 576 |
+
mode = "train" if self.model.training else "eval"
|
| 577 |
+
if mode == "train":
|
| 578 |
+
generate_every = self.args.gradient_accumulation_steps * self.num_iterations
|
| 579 |
+
if self._step % generate_every == 0 or self._buffered_inputs is None:
|
| 580 |
+
# self._buffered_inputs=None can occur when resuming from a checkpoint
|
| 581 |
+
accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)
|
| 582 |
+
self._buffered_inputs = split_tensor_dict(
|
| 583 |
+
accumulated_local_batch, self.args.gradient_accumulation_steps
|
| 584 |
+
)
|
| 585 |
+
inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
|
| 586 |
+
self._step += 1
|
| 587 |
+
else:
|
| 588 |
+
# In evaluation, there is neither gradient accumulation, nor multiple iterations
|
| 589 |
+
inputs = self._generate_and_score_completions(accumulated_local_batch)
|
| 590 |
+
return inputs
|
| 591 |
+
|
| 592 |
+
def _generate_and_score_completions(
|
| 593 |
+
self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
|
| 594 |
+
) -> dict[str, Union[torch.Tensor, Any]]:
|
| 595 |
+
|
| 596 |
+
# TODO
|
| 597 |
+
device = self.accelerator.device
|
| 598 |
+
mode = "train" if self.model.training else "eval"
|
| 599 |
+
|
| 600 |
+
inputs_for_generate = inputs.copy()
|
| 601 |
+
|
| 602 |
+
# 去除answer key
|
| 603 |
+
inputs_for_generate = [{k: v for k, v in x.items() if k != 'answer'} for x in inputs_for_generate]
|
| 604 |
+
|
| 605 |
+
dt_generate_dt = self.processing_func(inputs_for_generate)
|
| 606 |
+
prompt_inputs_generate = super(DyMETrainer, self)._prepare_inputs(dt_generate_dt)
|
| 607 |
+
if 'labels' in prompt_inputs_generate:
|
| 608 |
+
del prompt_inputs_generate["labels"]
|
| 609 |
+
prompt_ids = prompt_inputs_generate["input_ids"]
|
| 610 |
+
prompt_mask = prompt_inputs_generate["attention_mask"]
|
| 611 |
+
pixel_values = prompt_inputs_generate["pixel_values"]
|
| 612 |
+
|
| 613 |
+
image_grid_thws = prompt_inputs_generate["image_grid_thw"]
|
| 614 |
+
|
| 615 |
+
# Regular generation path
|
| 616 |
+
with unwrap_model_for_generation(
|
| 617 |
+
self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
| 618 |
+
) as unwrapped_model:
|
| 619 |
+
with (
|
| 620 |
+
FSDP.summon_full_params(self.model_wrapped, recurse=False)
|
| 621 |
+
if self.is_fsdp_enabled
|
| 622 |
+
else nullcontext()
|
| 623 |
+
):
|
| 624 |
+
prompt_completion_ids = unwrapped_model.generate(**prompt_inputs_generate, generation_config=self.generation_config)
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
# Compute prompt length and extract completion ids
|
| 628 |
+
prompt_length = prompt_ids.size(1)
|
| 629 |
+
prompt_ids = prompt_completion_ids[:, :prompt_length]
|
| 630 |
+
completion_ids = prompt_completion_ids[:, prompt_length:]
|
| 631 |
+
|
| 632 |
+
# Mask everything after the first EOS token
|
| 633 |
+
is_eos = completion_ids == self.processing_class.tokenizer.eos_token_id
|
| 634 |
+
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
|
| 635 |
+
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
| 636 |
+
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
|
| 637 |
+
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
| 638 |
+
|
| 639 |
+
# If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
|
| 640 |
+
if self.mask_truncated_completions:
|
| 641 |
+
truncated_completions = ~is_eos.any(dim=1)
|
| 642 |
+
completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()
|
| 643 |
+
|
| 644 |
+
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
| 645 |
+
|
| 646 |
+
batch_size = len(completion_ids)
|
| 647 |
+
images = [x['image'] for x in inputs]
|
| 648 |
+
prompts = [x['prompt'] for x in inputs]
|
| 649 |
+
question_wo_prompts = [x['question_wo_prompt'] for x in inputs]
|
| 650 |
+
hints = [x.get('hint', '') for x in inputs]
|
| 651 |
+
answers = [x['answer'] for x in inputs]
|
| 652 |
+
images_path = [image if isinstance(image, str) else image.filename for image in images]
|
| 653 |
+
batch_data = {'prompt': prompts, 'hints': hints,
|
| 654 |
+
'image': images_path, 'response': completions, 'answer': answers}
|
| 655 |
+
|
| 656 |
+
gpu_id = self.accelerator.device.index
|
| 657 |
+
all_rewards, format_rewards, acc_rewards, context_rewards = calculate_rewards_in_parallel(self.checker, batch_data,
|
| 658 |
+
gpu_id=gpu_id,
|
| 659 |
+
task=self.task_name, num_threads=1)
|
| 660 |
+
all_rewards = torch.tensor(all_rewards, dtype=torch.float32).to(self.accelerator.device)
|
| 661 |
+
format_rewards = torch.tensor(format_rewards, dtype=torch.float32).to(self.accelerator.device)
|
| 662 |
+
context_rewards = torch.tensor(context_rewards, dtype=torch.float32).to(self.accelerator.device)
|
| 663 |
+
acc_rewards = torch.tensor(acc_rewards, dtype=torch.float32).to(self.accelerator.device)
|
| 664 |
+
|
| 665 |
+
rewards_per_func = torch.zeros([len(all_rewards), 3], device=device)
|
| 666 |
+
|
| 667 |
+
rewards_per_func[:, 0] = format_rewards.clone()
|
| 668 |
+
rewards_per_func[:, 1] = context_rewards.clone()
|
| 669 |
+
rewards_per_func[:, -1] = acc_rewards.clone()
|
| 670 |
+
|
| 671 |
+
rewards_per_func = gather(rewards_per_func)
|
| 672 |
+
|
| 673 |
+
# Apply weights to each reward function's output and sum
|
| 674 |
+
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
|
| 675 |
+
|
| 676 |
+
# Compute grouped-wise rewards
|
| 677 |
+
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
|
| 678 |
+
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
|
| 679 |
+
|
| 680 |
+
# Normalize the rewards to compute the advantages
|
| 681 |
+
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
| 682 |
+
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
| 683 |
+
advantages = rewards - mean_grouped_rewards
|
| 684 |
+
if self.scale_rewards:
|
| 685 |
+
advantages = advantages / (std_grouped_rewards + 1e-4)
|
| 686 |
+
|
| 687 |
+
# Slice to keep only the local part of the data
|
| 688 |
+
process_slice = slice(
|
| 689 |
+
self.accelerator.process_index * len(prompts),
|
| 690 |
+
(self.accelerator.process_index + 1) * len(prompts),
|
| 691 |
+
)
|
| 692 |
+
advantages = advantages[process_slice]
|
| 693 |
+
advantages = advantages.reshape(-1, 1)
|
| 694 |
+
acc_rewards = acc_rewards.view(-1, self.num_generations)
|
| 695 |
+
format_rewards = format_rewards.view(-1, self.num_generations)
|
| 696 |
+
|
| 697 |
+
has_correct = (acc_rewards > 0.5).sum(1)
|
| 698 |
+
format_rewards = format_rewards.view(-1)
|
| 699 |
+
|
| 700 |
+
sft_check = []
|
| 701 |
+
for i in range(batch_size):
|
| 702 |
+
batch_id = i // self.num_generations
|
| 703 |
+
sft_check.append((has_correct[batch_id] == 0) & (i % self.num_generations == 0))
|
| 704 |
+
|
| 705 |
+
hints = refine_context_in_parallel(self.refiner, question_wo_prompts, hints, answers, task=self.task_name, gpu_id=gpu_id, num_threads=1)
|
| 706 |
+
|
| 707 |
+
sft_gt = [hint + '\n' + answer + self.end_flag for hint, answer in zip(hints, answers)]
|
| 708 |
+
|
| 709 |
+
sft_dt = self.processing_class.tokenizer(sft_gt, return_tensors="pt", padding=True,
|
| 710 |
+
padding_side="right")
|
| 711 |
+
sft_padded_ids = sft_dt['input_ids'].to(device)
|
| 712 |
+
sft_attn_masks = sft_dt['attention_mask'].to(device)
|
| 713 |
+
sft_advantages = torch.ones_like(sft_attn_masks, device=device)
|
| 714 |
+
|
| 715 |
+
final_completion_id_list = []
|
| 716 |
+
final_completion_mask_list = []
|
| 717 |
+
final_advantange_list = []
|
| 718 |
+
|
| 719 |
+
for i in range(len(sft_padded_ids)):
|
| 720 |
+
batch_id = i // self.num_generations
|
| 721 |
+
if has_correct[batch_id] == 0:
|
| 722 |
+
if sft_check[i]: # 第一个修改为正确答案,其他的保留为错误的。
|
| 723 |
+
completion_id_ = torch.cat([sft_padded_ids[i], completion_ids[i][0:0]])
|
| 724 |
+
completion_mask_ = torch.cat([sft_attn_masks[i], completion_mask[i][0:0]])
|
| 725 |
+
advantange_ = torch.cat([sft_advantages[i], advantages[i][0:0]])
|
| 726 |
+
advantange_[:] = 1
|
| 727 |
+
else:
|
| 728 |
+
completion_id_ = torch.cat([completion_ids[i], completion_ids[i][0:0]])
|
| 729 |
+
completion_mask_ = torch.cat([completion_mask[i], sft_attn_masks[i][0:0]])
|
| 730 |
+
advantange_ = torch.cat([advantages[i], sft_advantages[i][0:0]])
|
| 731 |
+
advantange_ = advantange_.repeat_interleave(len(completion_id_))
|
| 732 |
+
advantange_[:] = 0
|
| 733 |
+
|
| 734 |
+
else:
|
| 735 |
+
completion_id_ = torch.cat([completion_ids[i], sft_padded_ids[i][0:0]])
|
| 736 |
+
completion_mask_ = torch.cat([completion_mask[i], sft_attn_masks[i][0:0]])
|
| 737 |
+
advantange_ = torch.cat([advantages[i], sft_advantages[i][0:0]])
|
| 738 |
+
# 如果advantange_是一个数字的话需要扩展维度
|
| 739 |
+
advantange_ = advantange_.repeat_interleave(len(completion_id_))
|
| 740 |
+
|
| 741 |
+
if has_correct[batch_id] == self.num_generations: # 全部正确时停止优化
|
| 742 |
+
advantange_[:] = 0
|
| 743 |
+
|
| 744 |
+
final_completion_id_list.append(completion_id_)
|
| 745 |
+
final_completion_mask_list.append(completion_mask_)
|
| 746 |
+
final_advantange_list.append(advantange_)
|
| 747 |
+
|
| 748 |
+
completion_ids = pad_sequence(final_completion_id_list, batch_first=True,
|
| 749 |
+
padding_value=self.processing_class.tokenizer.pad_token_id).long()
|
| 750 |
+
completion_mask = pad_sequence(final_completion_mask_list, batch_first=True, padding_value=0)
|
| 751 |
+
completion_advantange = pad_sequence(final_advantange_list, batch_first=True, padding_value=0)
|
| 752 |
+
completion_ids = completion_ids.to(device)
|
| 753 |
+
completion_mask = completion_mask.to(device)
|
| 754 |
+
completion_advantange = completion_advantange.to(device)
|
| 755 |
+
input_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1).long()
|
| 756 |
+
attention_completion_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
| 757 |
+
|
| 758 |
+
for s, a in enumerate(completion_advantange):
|
| 759 |
+
if acc_rewards.view(-1)[s] > 0 and format_rewards.view(-1)[s] > 0 and a[0] < 0:
|
| 760 |
+
print('no')
|
| 761 |
+
|
| 762 |
+
if self.accelerator.device.index == 0:
|
| 763 |
+
completion_id = completion_ids[0]
|
| 764 |
+
completion_id_pos = completion_id[(completion_advantange[0] > 0) & (completion_mask[0] > 0)]
|
| 765 |
+
completion_id_neg = completion_id[(completion_advantange[0] < 0) & (completion_mask[0] > 0)]
|
| 766 |
+
|
| 767 |
+
show = self.processing_class.decode(completion_id_pos, skip_special_tokens=False)
|
| 768 |
+
show_neg = self.processing_class.decode(completion_id_neg, skip_special_tokens=False)
|
| 769 |
+
print("\n=====has_correct====================\n", has_correct,)
|
| 770 |
+
print("\n=====prediction====================\n", completions[0],)
|
| 771 |
+
if show != "":
|
| 772 |
+
print("\n=====POS GT====================\n", show)
|
| 773 |
+
if show_neg != "":
|
| 774 |
+
print("\n======NEG GT===================\n", show_neg)
|
| 775 |
+
|
| 776 |
+
# Concatenate prompt_mask with completion_mask for logit computation
|
| 777 |
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
|
| 778 |
+
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
| 779 |
+
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
|
| 780 |
+
|
| 781 |
+
with torch.no_grad():
|
| 782 |
+
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
|
| 783 |
+
# computation here, and use per_token_logps.detach() instead.
|
| 784 |
+
if self.num_iterations > 1:
|
| 785 |
+
old_per_token_logps = self._get_per_token_logps(self.model, input_completion_ids, attention_completion_mask, pixel_values, image_grid_thws,
|
| 786 |
+
logits_to_keep, batch_size)
|
| 787 |
+
else:
|
| 788 |
+
old_per_token_logps = None
|
| 789 |
+
|
| 790 |
+
# Log the metrics
|
| 791 |
+
if mode == "train":
|
| 792 |
+
self.state.num_input_tokens_seen += self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
|
| 793 |
+
|
| 794 |
+
# log completion lengths, mean, min, max
|
| 795 |
+
agg_completion_mask = self.accelerator.gather_for_metrics(completion_mask.sum(1))
|
| 796 |
+
self._metrics[mode]["completions/mean_length"].append(agg_completion_mask.float().mean().item())
|
| 797 |
+
|
| 798 |
+
# identify sequences that terminated with EOS and log their lengths
|
| 799 |
+
agg_terminated_with_eos = self.accelerator.gather_for_metrics(is_eos.any(dim=1))
|
| 800 |
+
term_completion_mask = agg_completion_mask[agg_terminated_with_eos]
|
| 801 |
+
clipped_completions_ratio = 1 - len(term_completion_mask) / len(agg_completion_mask)
|
| 802 |
+
self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
|
| 803 |
+
|
| 804 |
+
# Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
|
| 805 |
+
for i, reward_func_name in enumerate(self.reward_func_names):
|
| 806 |
+
mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
|
| 807 |
+
self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
|
| 808 |
+
self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
|
| 809 |
+
self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
|
| 810 |
+
|
| 811 |
+
for i, name in enumerate(self.reward_func_names):
|
| 812 |
+
self._textual_logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
|
| 813 |
+
# completion_advantange: (batch_size, seq_len) 或 (batch_size, n)
|
| 814 |
+
mask_pos = completion_advantange > 0 # 正优势位置
|
| 815 |
+
row_min = completion_advantange.min(dim=1, keepdim=True).values.abs() # (batch, 1)
|
| 816 |
+
|
| 817 |
+
# 只对正优势加 abs(row_min),其余位置设 0
|
| 818 |
+
# completion_advantange = torch.where(
|
| 819 |
+
# mask_pos,
|
| 820 |
+
# completion_advantange + row_min, # broadcasting 自动对齐到每一行
|
| 821 |
+
# torch.zeros_like(completion_advantange)
|
| 822 |
+
# )
|
| 823 |
+
return {
|
| 824 |
+
"prompt_ids": prompt_ids,
|
| 825 |
+
"prompt_mask": prompt_mask,
|
| 826 |
+
"pixel_values": pixel_values,
|
| 827 |
+
"completion_ids": completion_ids,
|
| 828 |
+
"completion_mask": completion_mask,
|
| 829 |
+
"advantages": completion_advantange,
|
| 830 |
+
"old_per_token_logps": old_per_token_logps,
|
| 831 |
+
# "has_correct": has_correct,
|
| 832 |
+
"image_grid_thws": image_grid_thws
|
| 833 |
+
}
|
| 834 |
+
|
| 835 |
+
def compute_liger_loss(self, unwrapped_model, inputs):
|
| 836 |
+
# Compute the per-token log probabilities for the model
|
| 837 |
+
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
|
| 838 |
+
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
|
| 839 |
+
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
| 840 |
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
| 841 |
+
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
| 842 |
+
|
| 843 |
+
# Compute the KL divergence between the model and the reference model
|
| 844 |
+
ref_per_token_logps = None
|
| 845 |
+
|
| 846 |
+
# get the last hidden state of the model
|
| 847 |
+
last_hidden_state = self._get_last_hidden_state(unwrapped_model, input_ids, attention_mask, logits_to_keep)
|
| 848 |
+
|
| 849 |
+
# compute loss and metrics using liger grpo loss
|
| 850 |
+
loss, metrics = self.liger_grpo_loss(
|
| 851 |
+
_input=last_hidden_state,
|
| 852 |
+
lin_weight=unwrapped_model.lm_head.weight,
|
| 853 |
+
selected_token_ids=completion_ids,
|
| 854 |
+
attention_mask=completion_mask,
|
| 855 |
+
advantages=inputs["advantages"][:, 0],
|
| 856 |
+
bias=unwrapped_model.lm_head.bias,
|
| 857 |
+
old_per_token_logps=inputs["old_per_token_logps"],
|
| 858 |
+
ref_per_token_logps=ref_per_token_logps,
|
| 859 |
+
)
|
| 860 |
+
# Extract metrics from the liger_grpo_loss output
|
| 861 |
+
# KL divergence is the first metric when beta is non-zero
|
| 862 |
+
mean_kl = metrics[0] if self.beta != 0.0 else None
|
| 863 |
+
clip_ratio = metrics[-1]
|
| 864 |
+
|
| 865 |
+
mode = "train" if self.model.training else "eval"
|
| 866 |
+
if self.beta != 0.0:
|
| 867 |
+
self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
|
| 868 |
+
self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
|
| 869 |
+
return loss
|
| 870 |
+
|
| 871 |
+
@profiling_decorator
|
| 872 |
+
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
| 873 |
+
if return_outputs:
|
| 874 |
+
raise ValueError("The GRPOTrainer does not support returning outputs")
|
| 875 |
+
if self.use_liger_loss:
|
| 876 |
+
# Compute the loss using the liger grpo loss
|
| 877 |
+
unwrapped_model = self.accelerator.unwrap_model(model)
|
| 878 |
+
return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
|
| 879 |
+
else:
|
| 880 |
+
return self._compute_loss(model, inputs)
|
| 881 |
+
|
| 882 |
+
def _compute_loss(self, model, inputs):
|
| 883 |
+
# return torch.nn.Parameter(torch.tensor(0.0, device=self.accelerator.device)) # Dummy loss for compatibility
|
| 884 |
+
# Compute the per-token log probabilities for the model
|
| 885 |
+
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
|
| 886 |
+
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
|
| 887 |
+
pixel_values = inputs["pixel_values"]
|
| 888 |
+
|
| 889 |
+
# has_correct = inputs["has_correct"]
|
| 890 |
+
image_grid_thws = inputs["image_grid_thws"]
|
| 891 |
+
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
| 892 |
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
| 893 |
+
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
| 894 |
+
try:
|
| 895 |
+
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, pixel_values, image_grid_thws,
|
| 896 |
+
logits_to_keep)
|
| 897 |
+
except Exception as e:
|
| 898 |
+
print(f"Error in _get_per_token_logps: {e}")
|
| 899 |
+
raise e
|
| 900 |
+
|
| 901 |
+
# sft_loss = -(per_token_logps * completion_mask).sum(-1) / completion_mask.sum(-1)
|
| 902 |
+
advantages = inputs["advantages"][:, 0]
|
| 903 |
+
# sft_loss = (sft_loss * (advantages > 0)).sum() * (has_correct == 0)
|
| 904 |
+
# return sft_loss
|
| 905 |
+
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's computation (see
|
| 906 |
+
# _generate_and_score_completions) and use per_token_logps.detach() instead.
|
| 907 |
+
old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
|
| 908 |
+
coef_1 = torch.exp(per_token_logps - old_per_token_logps)
|
| 909 |
+
coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
|
| 910 |
+
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
|
| 911 |
+
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
|
| 912 |
+
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
|
| 913 |
+
|
| 914 |
+
if self.loss_type == "grpo":
|
| 915 |
+
loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
|
| 916 |
+
elif self.loss_type == "bnpo":
|
| 917 |
+
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
|
| 918 |
+
elif self.loss_type == "dr_grpo":
|
| 919 |
+
loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
|
| 920 |
+
else:
|
| 921 |
+
raise ValueError(f"Unknown loss type: {self.loss_type}")
|
| 922 |
+
# loss = (has_correct > 0) * loss + sft_loss
|
| 923 |
+
# loss = (has_correct > 0) * loss
|
| 924 |
+
# Log the metrics
|
| 925 |
+
mode = "train" if self.model.training else "eval"
|
| 926 |
+
|
| 927 |
+
# Compute the clipped probability ratios
|
| 928 |
+
is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
|
| 929 |
+
is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
|
| 930 |
+
is_region_clipped = is_low_clipped | is_high_clipped
|
| 931 |
+
|
| 932 |
+
low_clip = (is_low_clipped * completion_mask).sum() / completion_mask.sum()
|
| 933 |
+
high_clip = (is_high_clipped * completion_mask).sum() / completion_mask.sum()
|
| 934 |
+
clip_ratio = (is_region_clipped * completion_mask).sum() / completion_mask.sum()
|
| 935 |
+
|
| 936 |
+
gathered_low_clip = self.accelerator.gather_for_metrics(low_clip)
|
| 937 |
+
self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
|
| 938 |
+
self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
|
| 939 |
+
gathered_high_clip = self.accelerator.gather_for_metrics(high_clip)
|
| 940 |
+
self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
|
| 941 |
+
self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
|
| 942 |
+
gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio)
|
| 943 |
+
self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
|
| 944 |
+
return loss
|
| 945 |
+
|
| 946 |
+
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
|
| 947 |
+
inputs = self._prepare_inputs(inputs)
|
| 948 |
+
with torch.no_grad():
|
| 949 |
+
with self.compute_loss_context_manager():
|
| 950 |
+
loss = self.compute_loss(model, inputs)
|
| 951 |
+
loss = loss.mean().detach()
|
| 952 |
+
return loss, None, None
|
| 953 |
+
|
| 954 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
| 955 |
+
mode = "train" if self.model.training else "eval"
|
| 956 |
+
metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics
|
| 957 |
+
|
| 958 |
+
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
|
| 959 |
+
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
|
| 960 |
+
if mode == "eval":
|
| 961 |
+
metrics = {f"eval_{key}": val for key, val in metrics.items()}
|
| 962 |
+
|
| 963 |
+
logs = {**logs, **metrics}
|
| 964 |
+
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
| 965 |
+
super().log(logs, start_time)
|
| 966 |
+
else: # transformers<=4.46
|
| 967 |
+
super().log(logs)
|
| 968 |
+
self._metrics[mode].clear()
|
| 969 |
+
|
| 970 |
+
if self.accelerator.is_main_process and self.log_completions:
|
| 971 |
+
if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
|
| 972 |
+
import pandas as pd
|
| 973 |
+
|
| 974 |
+
table = {
|
| 975 |
+
"step": [str(self.state.global_step)] * len(self._textual_logs["prompt"]),
|
| 976 |
+
"prompt": self._textual_logs["prompt"],
|
| 977 |
+
"completion": self._textual_logs["completion"],
|
| 978 |
+
**self._textual_logs["rewards"],
|
| 979 |
+
}
|
| 980 |
+
df = pd.DataFrame(table)
|
| 981 |
+
if self.wandb_log_unique_prompts:
|
| 982 |
+
df = df.drop_duplicates(subset=["prompt"])
|
| 983 |
+
wandb.log({"completions": wandb.Table(dataframe=df)})
|
trainer/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .DyMETrainer import DyMETrainer
|
| 2 |
+
|
| 3 |
+
__all__ = ["DyMETrainer"]
|