Robotics
Transformers
Safetensors
English
prts_qwen3_vl
feature-extraction
vision-language-action
vla
libero
qwen3-vl
prts
custom_code
Instructions to use TeleEmbodied/PRTS-4B-LIBERO with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use TeleEmbodied/PRTS-4B-LIBERO with Transformers:
# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("TeleEmbodied/PRTS-4B-LIBERO", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
Add files using upload-large-folder tool
Browse files- .gitattributes +1 -0
- README.md +88 -0
- added_tokens.json +2081 -0
- chat_template.jinja +120 -0
- config.json +155 -0
- configuration_prts_qwen3_vl.py +345 -0
- dit_action_head.py +1230 -0
- generation_config.json +12 -0
- merges.txt +0 -0
- model-00001-of-00003.safetensors +3 -0
- model-00002-of-00003.safetensors +3 -0
- model-00003-of-00003.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_prts_qwen3_vl.py +935 -0
- modeling_qwen3_vl.py +1645 -0
- preprocessor_config.json +39 -0
- processing_prts_qwen3_vl.py +352 -0
- special_tokens_map.json +31 -0
- statistics.json +333 -0
- tokenizer.json +3 -0
- tokenizer_config.json +0 -0
- training_args.bin +3 -0
- video_preprocessor_config.json +41 -0
- vocab.json +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: cc-by-nc-4.0
|
| 3 |
+
library_name: transformers
|
| 4 |
+
pipeline_tag: robotics
|
| 5 |
+
tags:
|
| 6 |
+
- robotics
|
| 7 |
+
- vision-language-action
|
| 8 |
+
- vla
|
| 9 |
+
- libero
|
| 10 |
+
- qwen3-vl
|
| 11 |
+
- prts
|
| 12 |
+
- custom_code
|
| 13 |
+
language:
|
| 14 |
+
- en
|
| 15 |
+
base_model: TeleEmbodied/PRTS-4B
|
| 16 |
+
---
|
| 17 |
+
|
| 18 |
+
<h1 align="center">PRTS-4B-LIBERO</h1>
|
| 19 |
+
|
| 20 |
+
<p align="center">
|
| 21 |
+
<a href="https://arxiv.org/abs/2604.27472"><img src="https://img.shields.io/badge/arXiv-2604.27472-b31b1b.svg" alt="arXiv"></a>
|
| 22 |
+
|
| 23 |
+
<a href="https://github.com/TeleHuman/PRTS"><img src="https://img.shields.io/badge/GitHub-PRTS-181717.svg" alt="GitHub"></a>
|
| 24 |
+
|
| 25 |
+
<a href="https://huggingface.co/TeleEmbodied/PRTS-4B"><img src="https://img.shields.io/badge/Base-PRTS--4B-yellow.svg" alt="Base model"></a>
|
| 26 |
+
</p>
|
| 27 |
+
|
| 28 |
+
**PRTS-4B-LIBERO** is the LIBERO fine-tuned variant of [`TeleEmbodied/PRTS-4B`](https://huggingface.co/TeleEmbodied/PRTS-4B). This is the exact checkpoint used to report the LIBERO numbers in the PRTS paper. For the base model card (architecture, prompt format, contrastive RL design), please refer to the parent [PRTS-4B](https://huggingface.co/TeleEmbodied/PRTS-4B) repository.
|
| 29 |
+
|
| 30 |
+
## Post-training budget
|
| 31 |
+
|
| 32 |
+
Fine-tuned from `TeleEmbodied/PRTS-4B` with the launch script [`scripts/ft/launch_finetune.sh`](https://github.com/TeleHuman/PRTS/blob/main/scripts/ft/launch_finetune.sh) in the open-source repo. Key settings:
|
| 33 |
+
|
| 34 |
+
| | |
|
| 35 |
+
| :--- | :--- |
|
| 36 |
+
| Base model | `TeleEmbodied/PRTS-4B` |
|
| 37 |
+
| Dataset config | `configs/post-train/libero.yaml` |
|
| 38 |
+
| Embodiment tag | `libero_panda` |
|
| 39 |
+
| Hardware | 4 GPUs, DeepSpeed ZeRO-2, bf16, `flash_attention_3`, no gradient checkpointing |
|
| 40 |
+
| Steps | 30,000 total, 5,000 warmup, save every 10,000 |
|
| 41 |
+
| Effective batch | 8 (per-device) × 4 GPUs × 1 (grad-acc) = **32** |
|
| 42 |
+
| LRs | `1e-5` for vision / merger / LLM; `1e-4` for the action head |
|
| 43 |
+
| Scheduler | `cosine_with_min_lr` (min `1e-6`) |
|
| 44 |
+
| Optimizer | AdamW (β1=0.9, β2=0.95, ε=1e-8), weight decay `1e-8`, grad clip `1.0` |
|
| 45 |
+
| Action head | DiT-L + MoT action expert, chunk size `20`, max action dim `32` |
|
| 46 |
+
| Action normalization | `QUANTILE` (stats bundled in this checkpoint) |
|
| 47 |
+
| Seed | 42 |
|
| 48 |
+
|
| 49 |
+
## Loading for evaluation
|
| 50 |
+
|
| 51 |
+
This checkpoint plugs into the policy server [`scripts/serve_policy.py`](https://github.com/TeleHuman/PRTS/blob/main/scripts/serve_policy.py). Update the `EnvMode.LIBERO` entry in `DEFAULT_CHECKPOINT` so that `dir=` points to your local download of this repo. Normalization stats are already bundled in the checkpoint, so `dataset_path` can be left as `None`:
|
| 52 |
+
|
| 53 |
+
```python
|
| 54 |
+
EnvMode.LIBERO: Checkpoint(
|
| 55 |
+
config="prts_libero",
|
| 56 |
+
dir="/path/to/PRTS-4B-libero", # local download path of this repo
|
| 57 |
+
action_dim=7,
|
| 58 |
+
dataset_path=None, # normalization stats are bundled in the checkpoint
|
| 59 |
+
state_mode="QUANTILE",
|
| 60 |
+
),
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## Running LIBERO evaluation
|
| 64 |
+
|
| 65 |
+
Follow the LIBERO simulation setup in [`examples/libero/README.md`](https://github.com/TeleHuman/PRTS/blob/main/examples/libero/README.md), then start the policy server from the PRTS repo root with [`examples/libero/run_libero_server.sh`](https://github.com/TeleHuman/PRTS/blob/main/examples/libero/run_libero_server.sh):
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
bash examples/libero/run_libero_server.sh
|
| 69 |
+
# which runs:
|
| 70 |
+
# CUDA_VISIBLE_DEVICES=0 python scripts/serve_policy.py --env LIBERO --port 10000
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
The LIBERO simulator (Terminal 1 in the example README) connects to this server over websocket and rolls out the 4 LIBERO task suites.
|
| 74 |
+
|
| 75 |
+
## License
|
| 76 |
+
|
| 77 |
+
Released under [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) — free for academic and non-commercial research; commercial use is **not** permitted.
|
| 78 |
+
|
| 79 |
+
## Citation
|
| 80 |
+
|
| 81 |
+
```bibtex
|
| 82 |
+
@article{zhang2026prts,
|
| 83 |
+
title = {PRTS: A Primitive Reasoning and Tasking System via Contrastive Representations},
|
| 84 |
+
author = {Yang Zhang and Jiangyuan Zhao and Chenyou Fan and Fangzheng Yan and Tian Li and Haitong Tang and Sen Fu and Xuan'er Wu and Qizhen Weng and Weinan Zhang and Xiu Li and Chi Zhang and Chenjia Bai and Xuelong Li},
|
| 85 |
+
journal = {arXiv preprint arXiv:2604.27472},
|
| 86 |
+
year = {2026},
|
| 87 |
+
}
|
| 88 |
+
```
|
added_tokens.json
ADDED
|
@@ -0,0 +1,2081 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"</think>": 151668,
|
| 3 |
+
"</tool_call>": 151658,
|
| 4 |
+
"</tool_response>": 151666,
|
| 5 |
+
"<think>": 151667,
|
| 6 |
+
"<tool_call>": 151657,
|
| 7 |
+
"<tool_response>": 151665,
|
| 8 |
+
"<|action_end|>": 151671,
|
| 9 |
+
"<|action_pad|>": 151670,
|
| 10 |
+
"<|action_start|>": 151669,
|
| 11 |
+
"<|action_token_0|>": 151674,
|
| 12 |
+
"<|action_token_1000|>": 152674,
|
| 13 |
+
"<|action_token_1001|>": 152675,
|
| 14 |
+
"<|action_token_1002|>": 152676,
|
| 15 |
+
"<|action_token_1003|>": 152677,
|
| 16 |
+
"<|action_token_1004|>": 152678,
|
| 17 |
+
"<|action_token_1005|>": 152679,
|
| 18 |
+
"<|action_token_1006|>": 152680,
|
| 19 |
+
"<|action_token_1007|>": 152681,
|
| 20 |
+
"<|action_token_1008|>": 152682,
|
| 21 |
+
"<|action_token_1009|>": 152683,
|
| 22 |
+
"<|action_token_100|>": 151774,
|
| 23 |
+
"<|action_token_1010|>": 152684,
|
| 24 |
+
"<|action_token_1011|>": 152685,
|
| 25 |
+
"<|action_token_1012|>": 152686,
|
| 26 |
+
"<|action_token_1013|>": 152687,
|
| 27 |
+
"<|action_token_1014|>": 152688,
|
| 28 |
+
"<|action_token_1015|>": 152689,
|
| 29 |
+
"<|action_token_1016|>": 152690,
|
| 30 |
+
"<|action_token_1017|>": 152691,
|
| 31 |
+
"<|action_token_1018|>": 152692,
|
| 32 |
+
"<|action_token_1019|>": 152693,
|
| 33 |
+
"<|action_token_101|>": 151775,
|
| 34 |
+
"<|action_token_1020|>": 152694,
|
| 35 |
+
"<|action_token_1021|>": 152695,
|
| 36 |
+
"<|action_token_1022|>": 152696,
|
| 37 |
+
"<|action_token_1023|>": 152697,
|
| 38 |
+
"<|action_token_1024|>": 152698,
|
| 39 |
+
"<|action_token_1025|>": 152699,
|
| 40 |
+
"<|action_token_1026|>": 152700,
|
| 41 |
+
"<|action_token_1027|>": 152701,
|
| 42 |
+
"<|action_token_1028|>": 152702,
|
| 43 |
+
"<|action_token_1029|>": 152703,
|
| 44 |
+
"<|action_token_102|>": 151776,
|
| 45 |
+
"<|action_token_1030|>": 152704,
|
| 46 |
+
"<|action_token_1031|>": 152705,
|
| 47 |
+
"<|action_token_1032|>": 152706,
|
| 48 |
+
"<|action_token_1033|>": 152707,
|
| 49 |
+
"<|action_token_1034|>": 152708,
|
| 50 |
+
"<|action_token_1035|>": 152709,
|
| 51 |
+
"<|action_token_1036|>": 152710,
|
| 52 |
+
"<|action_token_1037|>": 152711,
|
| 53 |
+
"<|action_token_1038|>": 152712,
|
| 54 |
+
"<|action_token_1039|>": 152713,
|
| 55 |
+
"<|action_token_103|>": 151777,
|
| 56 |
+
"<|action_token_1040|>": 152714,
|
| 57 |
+
"<|action_token_1041|>": 152715,
|
| 58 |
+
"<|action_token_1042|>": 152716,
|
| 59 |
+
"<|action_token_1043|>": 152717,
|
| 60 |
+
"<|action_token_1044|>": 152718,
|
| 61 |
+
"<|action_token_1045|>": 152719,
|
| 62 |
+
"<|action_token_1046|>": 152720,
|
| 63 |
+
"<|action_token_1047|>": 152721,
|
| 64 |
+
"<|action_token_1048|>": 152722,
|
| 65 |
+
"<|action_token_1049|>": 152723,
|
| 66 |
+
"<|action_token_104|>": 151778,
|
| 67 |
+
"<|action_token_1050|>": 152724,
|
| 68 |
+
"<|action_token_1051|>": 152725,
|
| 69 |
+
"<|action_token_1052|>": 152726,
|
| 70 |
+
"<|action_token_1053|>": 152727,
|
| 71 |
+
"<|action_token_1054|>": 152728,
|
| 72 |
+
"<|action_token_1055|>": 152729,
|
| 73 |
+
"<|action_token_1056|>": 152730,
|
| 74 |
+
"<|action_token_1057|>": 152731,
|
| 75 |
+
"<|action_token_1058|>": 152732,
|
| 76 |
+
"<|action_token_1059|>": 152733,
|
| 77 |
+
"<|action_token_105|>": 151779,
|
| 78 |
+
"<|action_token_1060|>": 152734,
|
| 79 |
+
"<|action_token_1061|>": 152735,
|
| 80 |
+
"<|action_token_1062|>": 152736,
|
| 81 |
+
"<|action_token_1063|>": 152737,
|
| 82 |
+
"<|action_token_1064|>": 152738,
|
| 83 |
+
"<|action_token_1065|>": 152739,
|
| 84 |
+
"<|action_token_1066|>": 152740,
|
| 85 |
+
"<|action_token_1067|>": 152741,
|
| 86 |
+
"<|action_token_1068|>": 152742,
|
| 87 |
+
"<|action_token_1069|>": 152743,
|
| 88 |
+
"<|action_token_106|>": 151780,
|
| 89 |
+
"<|action_token_1070|>": 152744,
|
| 90 |
+
"<|action_token_1071|>": 152745,
|
| 91 |
+
"<|action_token_1072|>": 152746,
|
| 92 |
+
"<|action_token_1073|>": 152747,
|
| 93 |
+
"<|action_token_1074|>": 152748,
|
| 94 |
+
"<|action_token_1075|>": 152749,
|
| 95 |
+
"<|action_token_1076|>": 152750,
|
| 96 |
+
"<|action_token_1077|>": 152751,
|
| 97 |
+
"<|action_token_1078|>": 152752,
|
| 98 |
+
"<|action_token_1079|>": 152753,
|
| 99 |
+
"<|action_token_107|>": 151781,
|
| 100 |
+
"<|action_token_1080|>": 152754,
|
| 101 |
+
"<|action_token_1081|>": 152755,
|
| 102 |
+
"<|action_token_1082|>": 152756,
|
| 103 |
+
"<|action_token_1083|>": 152757,
|
| 104 |
+
"<|action_token_1084|>": 152758,
|
| 105 |
+
"<|action_token_1085|>": 152759,
|
| 106 |
+
"<|action_token_1086|>": 152760,
|
| 107 |
+
"<|action_token_1087|>": 152761,
|
| 108 |
+
"<|action_token_1088|>": 152762,
|
| 109 |
+
"<|action_token_1089|>": 152763,
|
| 110 |
+
"<|action_token_108|>": 151782,
|
| 111 |
+
"<|action_token_1090|>": 152764,
|
| 112 |
+
"<|action_token_1091|>": 152765,
|
| 113 |
+
"<|action_token_1092|>": 152766,
|
| 114 |
+
"<|action_token_1093|>": 152767,
|
| 115 |
+
"<|action_token_1094|>": 152768,
|
| 116 |
+
"<|action_token_1095|>": 152769,
|
| 117 |
+
"<|action_token_1096|>": 152770,
|
| 118 |
+
"<|action_token_1097|>": 152771,
|
| 119 |
+
"<|action_token_1098|>": 152772,
|
| 120 |
+
"<|action_token_1099|>": 152773,
|
| 121 |
+
"<|action_token_109|>": 151783,
|
| 122 |
+
"<|action_token_10|>": 151684,
|
| 123 |
+
"<|action_token_1100|>": 152774,
|
| 124 |
+
"<|action_token_1101|>": 152775,
|
| 125 |
+
"<|action_token_1102|>": 152776,
|
| 126 |
+
"<|action_token_1103|>": 152777,
|
| 127 |
+
"<|action_token_1104|>": 152778,
|
| 128 |
+
"<|action_token_1105|>": 152779,
|
| 129 |
+
"<|action_token_1106|>": 152780,
|
| 130 |
+
"<|action_token_1107|>": 152781,
|
| 131 |
+
"<|action_token_1108|>": 152782,
|
| 132 |
+
"<|action_token_1109|>": 152783,
|
| 133 |
+
"<|action_token_110|>": 151784,
|
| 134 |
+
"<|action_token_1110|>": 152784,
|
| 135 |
+
"<|action_token_1111|>": 152785,
|
| 136 |
+
"<|action_token_1112|>": 152786,
|
| 137 |
+
"<|action_token_1113|>": 152787,
|
| 138 |
+
"<|action_token_1114|>": 152788,
|
| 139 |
+
"<|action_token_1115|>": 152789,
|
| 140 |
+
"<|action_token_1116|>": 152790,
|
| 141 |
+
"<|action_token_1117|>": 152791,
|
| 142 |
+
"<|action_token_1118|>": 152792,
|
| 143 |
+
"<|action_token_1119|>": 152793,
|
| 144 |
+
"<|action_token_111|>": 151785,
|
| 145 |
+
"<|action_token_1120|>": 152794,
|
| 146 |
+
"<|action_token_1121|>": 152795,
|
| 147 |
+
"<|action_token_1122|>": 152796,
|
| 148 |
+
"<|action_token_1123|>": 152797,
|
| 149 |
+
"<|action_token_1124|>": 152798,
|
| 150 |
+
"<|action_token_1125|>": 152799,
|
| 151 |
+
"<|action_token_1126|>": 152800,
|
| 152 |
+
"<|action_token_1127|>": 152801,
|
| 153 |
+
"<|action_token_1128|>": 152802,
|
| 154 |
+
"<|action_token_1129|>": 152803,
|
| 155 |
+
"<|action_token_112|>": 151786,
|
| 156 |
+
"<|action_token_1130|>": 152804,
|
| 157 |
+
"<|action_token_1131|>": 152805,
|
| 158 |
+
"<|action_token_1132|>": 152806,
|
| 159 |
+
"<|action_token_1133|>": 152807,
|
| 160 |
+
"<|action_token_1134|>": 152808,
|
| 161 |
+
"<|action_token_1135|>": 152809,
|
| 162 |
+
"<|action_token_1136|>": 152810,
|
| 163 |
+
"<|action_token_1137|>": 152811,
|
| 164 |
+
"<|action_token_1138|>": 152812,
|
| 165 |
+
"<|action_token_1139|>": 152813,
|
| 166 |
+
"<|action_token_113|>": 151787,
|
| 167 |
+
"<|action_token_1140|>": 152814,
|
| 168 |
+
"<|action_token_1141|>": 152815,
|
| 169 |
+
"<|action_token_1142|>": 152816,
|
| 170 |
+
"<|action_token_1143|>": 152817,
|
| 171 |
+
"<|action_token_1144|>": 152818,
|
| 172 |
+
"<|action_token_1145|>": 152819,
|
| 173 |
+
"<|action_token_1146|>": 152820,
|
| 174 |
+
"<|action_token_1147|>": 152821,
|
| 175 |
+
"<|action_token_1148|>": 152822,
|
| 176 |
+
"<|action_token_1149|>": 152823,
|
| 177 |
+
"<|action_token_114|>": 151788,
|
| 178 |
+
"<|action_token_1150|>": 152824,
|
| 179 |
+
"<|action_token_1151|>": 152825,
|
| 180 |
+
"<|action_token_1152|>": 152826,
|
| 181 |
+
"<|action_token_1153|>": 152827,
|
| 182 |
+
"<|action_token_1154|>": 152828,
|
| 183 |
+
"<|action_token_1155|>": 152829,
|
| 184 |
+
"<|action_token_1156|>": 152830,
|
| 185 |
+
"<|action_token_1157|>": 152831,
|
| 186 |
+
"<|action_token_1158|>": 152832,
|
| 187 |
+
"<|action_token_1159|>": 152833,
|
| 188 |
+
"<|action_token_115|>": 151789,
|
| 189 |
+
"<|action_token_1160|>": 152834,
|
| 190 |
+
"<|action_token_1161|>": 152835,
|
| 191 |
+
"<|action_token_1162|>": 152836,
|
| 192 |
+
"<|action_token_1163|>": 152837,
|
| 193 |
+
"<|action_token_1164|>": 152838,
|
| 194 |
+
"<|action_token_1165|>": 152839,
|
| 195 |
+
"<|action_token_1166|>": 152840,
|
| 196 |
+
"<|action_token_1167|>": 152841,
|
| 197 |
+
"<|action_token_1168|>": 152842,
|
| 198 |
+
"<|action_token_1169|>": 152843,
|
| 199 |
+
"<|action_token_116|>": 151790,
|
| 200 |
+
"<|action_token_1170|>": 152844,
|
| 201 |
+
"<|action_token_1171|>": 152845,
|
| 202 |
+
"<|action_token_1172|>": 152846,
|
| 203 |
+
"<|action_token_1173|>": 152847,
|
| 204 |
+
"<|action_token_1174|>": 152848,
|
| 205 |
+
"<|action_token_1175|>": 152849,
|
| 206 |
+
"<|action_token_1176|>": 152850,
|
| 207 |
+
"<|action_token_1177|>": 152851,
|
| 208 |
+
"<|action_token_1178|>": 152852,
|
| 209 |
+
"<|action_token_1179|>": 152853,
|
| 210 |
+
"<|action_token_117|>": 151791,
|
| 211 |
+
"<|action_token_1180|>": 152854,
|
| 212 |
+
"<|action_token_1181|>": 152855,
|
| 213 |
+
"<|action_token_1182|>": 152856,
|
| 214 |
+
"<|action_token_1183|>": 152857,
|
| 215 |
+
"<|action_token_1184|>": 152858,
|
| 216 |
+
"<|action_token_1185|>": 152859,
|
| 217 |
+
"<|action_token_1186|>": 152860,
|
| 218 |
+
"<|action_token_1187|>": 152861,
|
| 219 |
+
"<|action_token_1188|>": 152862,
|
| 220 |
+
"<|action_token_1189|>": 152863,
|
| 221 |
+
"<|action_token_118|>": 151792,
|
| 222 |
+
"<|action_token_1190|>": 152864,
|
| 223 |
+
"<|action_token_1191|>": 152865,
|
| 224 |
+
"<|action_token_1192|>": 152866,
|
| 225 |
+
"<|action_token_1193|>": 152867,
|
| 226 |
+
"<|action_token_1194|>": 152868,
|
| 227 |
+
"<|action_token_1195|>": 152869,
|
| 228 |
+
"<|action_token_1196|>": 152870,
|
| 229 |
+
"<|action_token_1197|>": 152871,
|
| 230 |
+
"<|action_token_1198|>": 152872,
|
| 231 |
+
"<|action_token_1199|>": 152873,
|
| 232 |
+
"<|action_token_119|>": 151793,
|
| 233 |
+
"<|action_token_11|>": 151685,
|
| 234 |
+
"<|action_token_1200|>": 152874,
|
| 235 |
+
"<|action_token_1201|>": 152875,
|
| 236 |
+
"<|action_token_1202|>": 152876,
|
| 237 |
+
"<|action_token_1203|>": 152877,
|
| 238 |
+
"<|action_token_1204|>": 152878,
|
| 239 |
+
"<|action_token_1205|>": 152879,
|
| 240 |
+
"<|action_token_1206|>": 152880,
|
| 241 |
+
"<|action_token_1207|>": 152881,
|
| 242 |
+
"<|action_token_1208|>": 152882,
|
| 243 |
+
"<|action_token_1209|>": 152883,
|
| 244 |
+
"<|action_token_120|>": 151794,
|
| 245 |
+
"<|action_token_1210|>": 152884,
|
| 246 |
+
"<|action_token_1211|>": 152885,
|
| 247 |
+
"<|action_token_1212|>": 152886,
|
| 248 |
+
"<|action_token_1213|>": 152887,
|
| 249 |
+
"<|action_token_1214|>": 152888,
|
| 250 |
+
"<|action_token_1215|>": 152889,
|
| 251 |
+
"<|action_token_1216|>": 152890,
|
| 252 |
+
"<|action_token_1217|>": 152891,
|
| 253 |
+
"<|action_token_1218|>": 152892,
|
| 254 |
+
"<|action_token_1219|>": 152893,
|
| 255 |
+
"<|action_token_121|>": 151795,
|
| 256 |
+
"<|action_token_1220|>": 152894,
|
| 257 |
+
"<|action_token_1221|>": 152895,
|
| 258 |
+
"<|action_token_1222|>": 152896,
|
| 259 |
+
"<|action_token_1223|>": 152897,
|
| 260 |
+
"<|action_token_1224|>": 152898,
|
| 261 |
+
"<|action_token_1225|>": 152899,
|
| 262 |
+
"<|action_token_1226|>": 152900,
|
| 263 |
+
"<|action_token_1227|>": 152901,
|
| 264 |
+
"<|action_token_1228|>": 152902,
|
| 265 |
+
"<|action_token_1229|>": 152903,
|
| 266 |
+
"<|action_token_122|>": 151796,
|
| 267 |
+
"<|action_token_1230|>": 152904,
|
| 268 |
+
"<|action_token_1231|>": 152905,
|
| 269 |
+
"<|action_token_1232|>": 152906,
|
| 270 |
+
"<|action_token_1233|>": 152907,
|
| 271 |
+
"<|action_token_1234|>": 152908,
|
| 272 |
+
"<|action_token_1235|>": 152909,
|
| 273 |
+
"<|action_token_1236|>": 152910,
|
| 274 |
+
"<|action_token_1237|>": 152911,
|
| 275 |
+
"<|action_token_1238|>": 152912,
|
| 276 |
+
"<|action_token_1239|>": 152913,
|
| 277 |
+
"<|action_token_123|>": 151797,
|
| 278 |
+
"<|action_token_1240|>": 152914,
|
| 279 |
+
"<|action_token_1241|>": 152915,
|
| 280 |
+
"<|action_token_1242|>": 152916,
|
| 281 |
+
"<|action_token_1243|>": 152917,
|
| 282 |
+
"<|action_token_1244|>": 152918,
|
| 283 |
+
"<|action_token_1245|>": 152919,
|
| 284 |
+
"<|action_token_1246|>": 152920,
|
| 285 |
+
"<|action_token_1247|>": 152921,
|
| 286 |
+
"<|action_token_1248|>": 152922,
|
| 287 |
+
"<|action_token_1249|>": 152923,
|
| 288 |
+
"<|action_token_124|>": 151798,
|
| 289 |
+
"<|action_token_1250|>": 152924,
|
| 290 |
+
"<|action_token_1251|>": 152925,
|
| 291 |
+
"<|action_token_1252|>": 152926,
|
| 292 |
+
"<|action_token_1253|>": 152927,
|
| 293 |
+
"<|action_token_1254|>": 152928,
|
| 294 |
+
"<|action_token_1255|>": 152929,
|
| 295 |
+
"<|action_token_1256|>": 152930,
|
| 296 |
+
"<|action_token_1257|>": 152931,
|
| 297 |
+
"<|action_token_1258|>": 152932,
|
| 298 |
+
"<|action_token_1259|>": 152933,
|
| 299 |
+
"<|action_token_125|>": 151799,
|
| 300 |
+
"<|action_token_1260|>": 152934,
|
| 301 |
+
"<|action_token_1261|>": 152935,
|
| 302 |
+
"<|action_token_1262|>": 152936,
|
| 303 |
+
"<|action_token_1263|>": 152937,
|
| 304 |
+
"<|action_token_1264|>": 152938,
|
| 305 |
+
"<|action_token_1265|>": 152939,
|
| 306 |
+
"<|action_token_1266|>": 152940,
|
| 307 |
+
"<|action_token_1267|>": 152941,
|
| 308 |
+
"<|action_token_1268|>": 152942,
|
| 309 |
+
"<|action_token_1269|>": 152943,
|
| 310 |
+
"<|action_token_126|>": 151800,
|
| 311 |
+
"<|action_token_1270|>": 152944,
|
| 312 |
+
"<|action_token_1271|>": 152945,
|
| 313 |
+
"<|action_token_1272|>": 152946,
|
| 314 |
+
"<|action_token_1273|>": 152947,
|
| 315 |
+
"<|action_token_1274|>": 152948,
|
| 316 |
+
"<|action_token_1275|>": 152949,
|
| 317 |
+
"<|action_token_1276|>": 152950,
|
| 318 |
+
"<|action_token_1277|>": 152951,
|
| 319 |
+
"<|action_token_1278|>": 152952,
|
| 320 |
+
"<|action_token_1279|>": 152953,
|
| 321 |
+
"<|action_token_127|>": 151801,
|
| 322 |
+
"<|action_token_1280|>": 152954,
|
| 323 |
+
"<|action_token_1281|>": 152955,
|
| 324 |
+
"<|action_token_1282|>": 152956,
|
| 325 |
+
"<|action_token_1283|>": 152957,
|
| 326 |
+
"<|action_token_1284|>": 152958,
|
| 327 |
+
"<|action_token_1285|>": 152959,
|
| 328 |
+
"<|action_token_1286|>": 152960,
|
| 329 |
+
"<|action_token_1287|>": 152961,
|
| 330 |
+
"<|action_token_1288|>": 152962,
|
| 331 |
+
"<|action_token_1289|>": 152963,
|
| 332 |
+
"<|action_token_128|>": 151802,
|
| 333 |
+
"<|action_token_1290|>": 152964,
|
| 334 |
+
"<|action_token_1291|>": 152965,
|
| 335 |
+
"<|action_token_1292|>": 152966,
|
| 336 |
+
"<|action_token_1293|>": 152967,
|
| 337 |
+
"<|action_token_1294|>": 152968,
|
| 338 |
+
"<|action_token_1295|>": 152969,
|
| 339 |
+
"<|action_token_1296|>": 152970,
|
| 340 |
+
"<|action_token_1297|>": 152971,
|
| 341 |
+
"<|action_token_1298|>": 152972,
|
| 342 |
+
"<|action_token_1299|>": 152973,
|
| 343 |
+
"<|action_token_129|>": 151803,
|
| 344 |
+
"<|action_token_12|>": 151686,
|
| 345 |
+
"<|action_token_1300|>": 152974,
|
| 346 |
+
"<|action_token_1301|>": 152975,
|
| 347 |
+
"<|action_token_1302|>": 152976,
|
| 348 |
+
"<|action_token_1303|>": 152977,
|
| 349 |
+
"<|action_token_1304|>": 152978,
|
| 350 |
+
"<|action_token_1305|>": 152979,
|
| 351 |
+
"<|action_token_1306|>": 152980,
|
| 352 |
+
"<|action_token_1307|>": 152981,
|
| 353 |
+
"<|action_token_1308|>": 152982,
|
| 354 |
+
"<|action_token_1309|>": 152983,
|
| 355 |
+
"<|action_token_130|>": 151804,
|
| 356 |
+
"<|action_token_1310|>": 152984,
|
| 357 |
+
"<|action_token_1311|>": 152985,
|
| 358 |
+
"<|action_token_1312|>": 152986,
|
| 359 |
+
"<|action_token_1313|>": 152987,
|
| 360 |
+
"<|action_token_1314|>": 152988,
|
| 361 |
+
"<|action_token_1315|>": 152989,
|
| 362 |
+
"<|action_token_1316|>": 152990,
|
| 363 |
+
"<|action_token_1317|>": 152991,
|
| 364 |
+
"<|action_token_1318|>": 152992,
|
| 365 |
+
"<|action_token_1319|>": 152993,
|
| 366 |
+
"<|action_token_131|>": 151805,
|
| 367 |
+
"<|action_token_1320|>": 152994,
|
| 368 |
+
"<|action_token_1321|>": 152995,
|
| 369 |
+
"<|action_token_1322|>": 152996,
|
| 370 |
+
"<|action_token_1323|>": 152997,
|
| 371 |
+
"<|action_token_1324|>": 152998,
|
| 372 |
+
"<|action_token_1325|>": 152999,
|
| 373 |
+
"<|action_token_1326|>": 153000,
|
| 374 |
+
"<|action_token_1327|>": 153001,
|
| 375 |
+
"<|action_token_1328|>": 153002,
|
| 376 |
+
"<|action_token_1329|>": 153003,
|
| 377 |
+
"<|action_token_132|>": 151806,
|
| 378 |
+
"<|action_token_1330|>": 153004,
|
| 379 |
+
"<|action_token_1331|>": 153005,
|
| 380 |
+
"<|action_token_1332|>": 153006,
|
| 381 |
+
"<|action_token_1333|>": 153007,
|
| 382 |
+
"<|action_token_1334|>": 153008,
|
| 383 |
+
"<|action_token_1335|>": 153009,
|
| 384 |
+
"<|action_token_1336|>": 153010,
|
| 385 |
+
"<|action_token_1337|>": 153011,
|
| 386 |
+
"<|action_token_1338|>": 153012,
|
| 387 |
+
"<|action_token_1339|>": 153013,
|
| 388 |
+
"<|action_token_133|>": 151807,
|
| 389 |
+
"<|action_token_1340|>": 153014,
|
| 390 |
+
"<|action_token_1341|>": 153015,
|
| 391 |
+
"<|action_token_1342|>": 153016,
|
| 392 |
+
"<|action_token_1343|>": 153017,
|
| 393 |
+
"<|action_token_1344|>": 153018,
|
| 394 |
+
"<|action_token_1345|>": 153019,
|
| 395 |
+
"<|action_token_1346|>": 153020,
|
| 396 |
+
"<|action_token_1347|>": 153021,
|
| 397 |
+
"<|action_token_1348|>": 153022,
|
| 398 |
+
"<|action_token_1349|>": 153023,
|
| 399 |
+
"<|action_token_134|>": 151808,
|
| 400 |
+
"<|action_token_1350|>": 153024,
|
| 401 |
+
"<|action_token_1351|>": 153025,
|
| 402 |
+
"<|action_token_1352|>": 153026,
|
| 403 |
+
"<|action_token_1353|>": 153027,
|
| 404 |
+
"<|action_token_1354|>": 153028,
|
| 405 |
+
"<|action_token_1355|>": 153029,
|
| 406 |
+
"<|action_token_1356|>": 153030,
|
| 407 |
+
"<|action_token_1357|>": 153031,
|
| 408 |
+
"<|action_token_1358|>": 153032,
|
| 409 |
+
"<|action_token_1359|>": 153033,
|
| 410 |
+
"<|action_token_135|>": 151809,
|
| 411 |
+
"<|action_token_1360|>": 153034,
|
| 412 |
+
"<|action_token_1361|>": 153035,
|
| 413 |
+
"<|action_token_1362|>": 153036,
|
| 414 |
+
"<|action_token_1363|>": 153037,
|
| 415 |
+
"<|action_token_1364|>": 153038,
|
| 416 |
+
"<|action_token_1365|>": 153039,
|
| 417 |
+
"<|action_token_1366|>": 153040,
|
| 418 |
+
"<|action_token_1367|>": 153041,
|
| 419 |
+
"<|action_token_1368|>": 153042,
|
| 420 |
+
"<|action_token_1369|>": 153043,
|
| 421 |
+
"<|action_token_136|>": 151810,
|
| 422 |
+
"<|action_token_1370|>": 153044,
|
| 423 |
+
"<|action_token_1371|>": 153045,
|
| 424 |
+
"<|action_token_1372|>": 153046,
|
| 425 |
+
"<|action_token_1373|>": 153047,
|
| 426 |
+
"<|action_token_1374|>": 153048,
|
| 427 |
+
"<|action_token_1375|>": 153049,
|
| 428 |
+
"<|action_token_1376|>": 153050,
|
| 429 |
+
"<|action_token_1377|>": 153051,
|
| 430 |
+
"<|action_token_1378|>": 153052,
|
| 431 |
+
"<|action_token_1379|>": 153053,
|
| 432 |
+
"<|action_token_137|>": 151811,
|
| 433 |
+
"<|action_token_1380|>": 153054,
|
| 434 |
+
"<|action_token_1381|>": 153055,
|
| 435 |
+
"<|action_token_1382|>": 153056,
|
| 436 |
+
"<|action_token_1383|>": 153057,
|
| 437 |
+
"<|action_token_1384|>": 153058,
|
| 438 |
+
"<|action_token_1385|>": 153059,
|
| 439 |
+
"<|action_token_1386|>": 153060,
|
| 440 |
+
"<|action_token_1387|>": 153061,
|
| 441 |
+
"<|action_token_1388|>": 153062,
|
| 442 |
+
"<|action_token_1389|>": 153063,
|
| 443 |
+
"<|action_token_138|>": 151812,
|
| 444 |
+
"<|action_token_1390|>": 153064,
|
| 445 |
+
"<|action_token_1391|>": 153065,
|
| 446 |
+
"<|action_token_1392|>": 153066,
|
| 447 |
+
"<|action_token_1393|>": 153067,
|
| 448 |
+
"<|action_token_1394|>": 153068,
|
| 449 |
+
"<|action_token_1395|>": 153069,
|
| 450 |
+
"<|action_token_1396|>": 153070,
|
| 451 |
+
"<|action_token_1397|>": 153071,
|
| 452 |
+
"<|action_token_1398|>": 153072,
|
| 453 |
+
"<|action_token_1399|>": 153073,
|
| 454 |
+
"<|action_token_139|>": 151813,
|
| 455 |
+
"<|action_token_13|>": 151687,
|
| 456 |
+
"<|action_token_1400|>": 153074,
|
| 457 |
+
"<|action_token_1401|>": 153075,
|
| 458 |
+
"<|action_token_1402|>": 153076,
|
| 459 |
+
"<|action_token_1403|>": 153077,
|
| 460 |
+
"<|action_token_1404|>": 153078,
|
| 461 |
+
"<|action_token_1405|>": 153079,
|
| 462 |
+
"<|action_token_1406|>": 153080,
|
| 463 |
+
"<|action_token_1407|>": 153081,
|
| 464 |
+
"<|action_token_1408|>": 153082,
|
| 465 |
+
"<|action_token_1409|>": 153083,
|
| 466 |
+
"<|action_token_140|>": 151814,
|
| 467 |
+
"<|action_token_1410|>": 153084,
|
| 468 |
+
"<|action_token_1411|>": 153085,
|
| 469 |
+
"<|action_token_1412|>": 153086,
|
| 470 |
+
"<|action_token_1413|>": 153087,
|
| 471 |
+
"<|action_token_1414|>": 153088,
|
| 472 |
+
"<|action_token_1415|>": 153089,
|
| 473 |
+
"<|action_token_1416|>": 153090,
|
| 474 |
+
"<|action_token_1417|>": 153091,
|
| 475 |
+
"<|action_token_1418|>": 153092,
|
| 476 |
+
"<|action_token_1419|>": 153093,
|
| 477 |
+
"<|action_token_141|>": 151815,
|
| 478 |
+
"<|action_token_1420|>": 153094,
|
| 479 |
+
"<|action_token_1421|>": 153095,
|
| 480 |
+
"<|action_token_1422|>": 153096,
|
| 481 |
+
"<|action_token_1423|>": 153097,
|
| 482 |
+
"<|action_token_1424|>": 153098,
|
| 483 |
+
"<|action_token_1425|>": 153099,
|
| 484 |
+
"<|action_token_1426|>": 153100,
|
| 485 |
+
"<|action_token_1427|>": 153101,
|
| 486 |
+
"<|action_token_1428|>": 153102,
|
| 487 |
+
"<|action_token_1429|>": 153103,
|
| 488 |
+
"<|action_token_142|>": 151816,
|
| 489 |
+
"<|action_token_1430|>": 153104,
|
| 490 |
+
"<|action_token_1431|>": 153105,
|
| 491 |
+
"<|action_token_1432|>": 153106,
|
| 492 |
+
"<|action_token_1433|>": 153107,
|
| 493 |
+
"<|action_token_1434|>": 153108,
|
| 494 |
+
"<|action_token_1435|>": 153109,
|
| 495 |
+
"<|action_token_1436|>": 153110,
|
| 496 |
+
"<|action_token_1437|>": 153111,
|
| 497 |
+
"<|action_token_1438|>": 153112,
|
| 498 |
+
"<|action_token_1439|>": 153113,
|
| 499 |
+
"<|action_token_143|>": 151817,
|
| 500 |
+
"<|action_token_1440|>": 153114,
|
| 501 |
+
"<|action_token_1441|>": 153115,
|
| 502 |
+
"<|action_token_1442|>": 153116,
|
| 503 |
+
"<|action_token_1443|>": 153117,
|
| 504 |
+
"<|action_token_1444|>": 153118,
|
| 505 |
+
"<|action_token_1445|>": 153119,
|
| 506 |
+
"<|action_token_1446|>": 153120,
|
| 507 |
+
"<|action_token_1447|>": 153121,
|
| 508 |
+
"<|action_token_1448|>": 153122,
|
| 509 |
+
"<|action_token_1449|>": 153123,
|
| 510 |
+
"<|action_token_144|>": 151818,
|
| 511 |
+
"<|action_token_1450|>": 153124,
|
| 512 |
+
"<|action_token_1451|>": 153125,
|
| 513 |
+
"<|action_token_1452|>": 153126,
|
| 514 |
+
"<|action_token_1453|>": 153127,
|
| 515 |
+
"<|action_token_1454|>": 153128,
|
| 516 |
+
"<|action_token_1455|>": 153129,
|
| 517 |
+
"<|action_token_1456|>": 153130,
|
| 518 |
+
"<|action_token_1457|>": 153131,
|
| 519 |
+
"<|action_token_1458|>": 153132,
|
| 520 |
+
"<|action_token_1459|>": 153133,
|
| 521 |
+
"<|action_token_145|>": 151819,
|
| 522 |
+
"<|action_token_1460|>": 153134,
|
| 523 |
+
"<|action_token_1461|>": 153135,
|
| 524 |
+
"<|action_token_1462|>": 153136,
|
| 525 |
+
"<|action_token_1463|>": 153137,
|
| 526 |
+
"<|action_token_1464|>": 153138,
|
| 527 |
+
"<|action_token_1465|>": 153139,
|
| 528 |
+
"<|action_token_1466|>": 153140,
|
| 529 |
+
"<|action_token_1467|>": 153141,
|
| 530 |
+
"<|action_token_1468|>": 153142,
|
| 531 |
+
"<|action_token_1469|>": 153143,
|
| 532 |
+
"<|action_token_146|>": 151820,
|
| 533 |
+
"<|action_token_1470|>": 153144,
|
| 534 |
+
"<|action_token_1471|>": 153145,
|
| 535 |
+
"<|action_token_1472|>": 153146,
|
| 536 |
+
"<|action_token_1473|>": 153147,
|
| 537 |
+
"<|action_token_1474|>": 153148,
|
| 538 |
+
"<|action_token_1475|>": 153149,
|
| 539 |
+
"<|action_token_1476|>": 153150,
|
| 540 |
+
"<|action_token_1477|>": 153151,
|
| 541 |
+
"<|action_token_1478|>": 153152,
|
| 542 |
+
"<|action_token_1479|>": 153153,
|
| 543 |
+
"<|action_token_147|>": 151821,
|
| 544 |
+
"<|action_token_1480|>": 153154,
|
| 545 |
+
"<|action_token_1481|>": 153155,
|
| 546 |
+
"<|action_token_1482|>": 153156,
|
| 547 |
+
"<|action_token_1483|>": 153157,
|
| 548 |
+
"<|action_token_1484|>": 153158,
|
| 549 |
+
"<|action_token_1485|>": 153159,
|
| 550 |
+
"<|action_token_1486|>": 153160,
|
| 551 |
+
"<|action_token_1487|>": 153161,
|
| 552 |
+
"<|action_token_1488|>": 153162,
|
| 553 |
+
"<|action_token_1489|>": 153163,
|
| 554 |
+
"<|action_token_148|>": 151822,
|
| 555 |
+
"<|action_token_1490|>": 153164,
|
| 556 |
+
"<|action_token_1491|>": 153165,
|
| 557 |
+
"<|action_token_1492|>": 153166,
|
| 558 |
+
"<|action_token_1493|>": 153167,
|
| 559 |
+
"<|action_token_1494|>": 153168,
|
| 560 |
+
"<|action_token_1495|>": 153169,
|
| 561 |
+
"<|action_token_1496|>": 153170,
|
| 562 |
+
"<|action_token_1497|>": 153171,
|
| 563 |
+
"<|action_token_1498|>": 153172,
|
| 564 |
+
"<|action_token_1499|>": 153173,
|
| 565 |
+
"<|action_token_149|>": 151823,
|
| 566 |
+
"<|action_token_14|>": 151688,
|
| 567 |
+
"<|action_token_1500|>": 153174,
|
| 568 |
+
"<|action_token_1501|>": 153175,
|
| 569 |
+
"<|action_token_1502|>": 153176,
|
| 570 |
+
"<|action_token_1503|>": 153177,
|
| 571 |
+
"<|action_token_1504|>": 153178,
|
| 572 |
+
"<|action_token_1505|>": 153179,
|
| 573 |
+
"<|action_token_1506|>": 153180,
|
| 574 |
+
"<|action_token_1507|>": 153181,
|
| 575 |
+
"<|action_token_1508|>": 153182,
|
| 576 |
+
"<|action_token_1509|>": 153183,
|
| 577 |
+
"<|action_token_150|>": 151824,
|
| 578 |
+
"<|action_token_1510|>": 153184,
|
| 579 |
+
"<|action_token_1511|>": 153185,
|
| 580 |
+
"<|action_token_1512|>": 153186,
|
| 581 |
+
"<|action_token_1513|>": 153187,
|
| 582 |
+
"<|action_token_1514|>": 153188,
|
| 583 |
+
"<|action_token_1515|>": 153189,
|
| 584 |
+
"<|action_token_1516|>": 153190,
|
| 585 |
+
"<|action_token_1517|>": 153191,
|
| 586 |
+
"<|action_token_1518|>": 153192,
|
| 587 |
+
"<|action_token_1519|>": 153193,
|
| 588 |
+
"<|action_token_151|>": 151825,
|
| 589 |
+
"<|action_token_1520|>": 153194,
|
| 590 |
+
"<|action_token_1521|>": 153195,
|
| 591 |
+
"<|action_token_1522|>": 153196,
|
| 592 |
+
"<|action_token_1523|>": 153197,
|
| 593 |
+
"<|action_token_1524|>": 153198,
|
| 594 |
+
"<|action_token_1525|>": 153199,
|
| 595 |
+
"<|action_token_1526|>": 153200,
|
| 596 |
+
"<|action_token_1527|>": 153201,
|
| 597 |
+
"<|action_token_1528|>": 153202,
|
| 598 |
+
"<|action_token_1529|>": 153203,
|
| 599 |
+
"<|action_token_152|>": 151826,
|
| 600 |
+
"<|action_token_1530|>": 153204,
|
| 601 |
+
"<|action_token_1531|>": 153205,
|
| 602 |
+
"<|action_token_1532|>": 153206,
|
| 603 |
+
"<|action_token_1533|>": 153207,
|
| 604 |
+
"<|action_token_1534|>": 153208,
|
| 605 |
+
"<|action_token_1535|>": 153209,
|
| 606 |
+
"<|action_token_1536|>": 153210,
|
| 607 |
+
"<|action_token_1537|>": 153211,
|
| 608 |
+
"<|action_token_1538|>": 153212,
|
| 609 |
+
"<|action_token_1539|>": 153213,
|
| 610 |
+
"<|action_token_153|>": 151827,
|
| 611 |
+
"<|action_token_1540|>": 153214,
|
| 612 |
+
"<|action_token_1541|>": 153215,
|
| 613 |
+
"<|action_token_1542|>": 153216,
|
| 614 |
+
"<|action_token_1543|>": 153217,
|
| 615 |
+
"<|action_token_1544|>": 153218,
|
| 616 |
+
"<|action_token_1545|>": 153219,
|
| 617 |
+
"<|action_token_1546|>": 153220,
|
| 618 |
+
"<|action_token_1547|>": 153221,
|
| 619 |
+
"<|action_token_1548|>": 153222,
|
| 620 |
+
"<|action_token_1549|>": 153223,
|
| 621 |
+
"<|action_token_154|>": 151828,
|
| 622 |
+
"<|action_token_1550|>": 153224,
|
| 623 |
+
"<|action_token_1551|>": 153225,
|
| 624 |
+
"<|action_token_1552|>": 153226,
|
| 625 |
+
"<|action_token_1553|>": 153227,
|
| 626 |
+
"<|action_token_1554|>": 153228,
|
| 627 |
+
"<|action_token_1555|>": 153229,
|
| 628 |
+
"<|action_token_1556|>": 153230,
|
| 629 |
+
"<|action_token_1557|>": 153231,
|
| 630 |
+
"<|action_token_1558|>": 153232,
|
| 631 |
+
"<|action_token_1559|>": 153233,
|
| 632 |
+
"<|action_token_155|>": 151829,
|
| 633 |
+
"<|action_token_1560|>": 153234,
|
| 634 |
+
"<|action_token_1561|>": 153235,
|
| 635 |
+
"<|action_token_1562|>": 153236,
|
| 636 |
+
"<|action_token_1563|>": 153237,
|
| 637 |
+
"<|action_token_1564|>": 153238,
|
| 638 |
+
"<|action_token_1565|>": 153239,
|
| 639 |
+
"<|action_token_1566|>": 153240,
|
| 640 |
+
"<|action_token_1567|>": 153241,
|
| 641 |
+
"<|action_token_1568|>": 153242,
|
| 642 |
+
"<|action_token_1569|>": 153243,
|
| 643 |
+
"<|action_token_156|>": 151830,
|
| 644 |
+
"<|action_token_1570|>": 153244,
|
| 645 |
+
"<|action_token_1571|>": 153245,
|
| 646 |
+
"<|action_token_1572|>": 153246,
|
| 647 |
+
"<|action_token_1573|>": 153247,
|
| 648 |
+
"<|action_token_1574|>": 153248,
|
| 649 |
+
"<|action_token_1575|>": 153249,
|
| 650 |
+
"<|action_token_1576|>": 153250,
|
| 651 |
+
"<|action_token_1577|>": 153251,
|
| 652 |
+
"<|action_token_1578|>": 153252,
|
| 653 |
+
"<|action_token_1579|>": 153253,
|
| 654 |
+
"<|action_token_157|>": 151831,
|
| 655 |
+
"<|action_token_1580|>": 153254,
|
| 656 |
+
"<|action_token_1581|>": 153255,
|
| 657 |
+
"<|action_token_1582|>": 153256,
|
| 658 |
+
"<|action_token_1583|>": 153257,
|
| 659 |
+
"<|action_token_1584|>": 153258,
|
| 660 |
+
"<|action_token_1585|>": 153259,
|
| 661 |
+
"<|action_token_1586|>": 153260,
|
| 662 |
+
"<|action_token_1587|>": 153261,
|
| 663 |
+
"<|action_token_1588|>": 153262,
|
| 664 |
+
"<|action_token_1589|>": 153263,
|
| 665 |
+
"<|action_token_158|>": 151832,
|
| 666 |
+
"<|action_token_1590|>": 153264,
|
| 667 |
+
"<|action_token_1591|>": 153265,
|
| 668 |
+
"<|action_token_1592|>": 153266,
|
| 669 |
+
"<|action_token_1593|>": 153267,
|
| 670 |
+
"<|action_token_1594|>": 153268,
|
| 671 |
+
"<|action_token_1595|>": 153269,
|
| 672 |
+
"<|action_token_1596|>": 153270,
|
| 673 |
+
"<|action_token_1597|>": 153271,
|
| 674 |
+
"<|action_token_1598|>": 153272,
|
| 675 |
+
"<|action_token_1599|>": 153273,
|
| 676 |
+
"<|action_token_159|>": 151833,
|
| 677 |
+
"<|action_token_15|>": 151689,
|
| 678 |
+
"<|action_token_1600|>": 153274,
|
| 679 |
+
"<|action_token_1601|>": 153275,
|
| 680 |
+
"<|action_token_1602|>": 153276,
|
| 681 |
+
"<|action_token_1603|>": 153277,
|
| 682 |
+
"<|action_token_1604|>": 153278,
|
| 683 |
+
"<|action_token_1605|>": 153279,
|
| 684 |
+
"<|action_token_1606|>": 153280,
|
| 685 |
+
"<|action_token_1607|>": 153281,
|
| 686 |
+
"<|action_token_1608|>": 153282,
|
| 687 |
+
"<|action_token_1609|>": 153283,
|
| 688 |
+
"<|action_token_160|>": 151834,
|
| 689 |
+
"<|action_token_1610|>": 153284,
|
| 690 |
+
"<|action_token_1611|>": 153285,
|
| 691 |
+
"<|action_token_1612|>": 153286,
|
| 692 |
+
"<|action_token_1613|>": 153287,
|
| 693 |
+
"<|action_token_1614|>": 153288,
|
| 694 |
+
"<|action_token_1615|>": 153289,
|
| 695 |
+
"<|action_token_1616|>": 153290,
|
| 696 |
+
"<|action_token_1617|>": 153291,
|
| 697 |
+
"<|action_token_1618|>": 153292,
|
| 698 |
+
"<|action_token_1619|>": 153293,
|
| 699 |
+
"<|action_token_161|>": 151835,
|
| 700 |
+
"<|action_token_1620|>": 153294,
|
| 701 |
+
"<|action_token_1621|>": 153295,
|
| 702 |
+
"<|action_token_1622|>": 153296,
|
| 703 |
+
"<|action_token_1623|>": 153297,
|
| 704 |
+
"<|action_token_1624|>": 153298,
|
| 705 |
+
"<|action_token_1625|>": 153299,
|
| 706 |
+
"<|action_token_1626|>": 153300,
|
| 707 |
+
"<|action_token_1627|>": 153301,
|
| 708 |
+
"<|action_token_1628|>": 153302,
|
| 709 |
+
"<|action_token_1629|>": 153303,
|
| 710 |
+
"<|action_token_162|>": 151836,
|
| 711 |
+
"<|action_token_1630|>": 153304,
|
| 712 |
+
"<|action_token_1631|>": 153305,
|
| 713 |
+
"<|action_token_1632|>": 153306,
|
| 714 |
+
"<|action_token_1633|>": 153307,
|
| 715 |
+
"<|action_token_1634|>": 153308,
|
| 716 |
+
"<|action_token_1635|>": 153309,
|
| 717 |
+
"<|action_token_1636|>": 153310,
|
| 718 |
+
"<|action_token_1637|>": 153311,
|
| 719 |
+
"<|action_token_1638|>": 153312,
|
| 720 |
+
"<|action_token_1639|>": 153313,
|
| 721 |
+
"<|action_token_163|>": 151837,
|
| 722 |
+
"<|action_token_1640|>": 153314,
|
| 723 |
+
"<|action_token_1641|>": 153315,
|
| 724 |
+
"<|action_token_1642|>": 153316,
|
| 725 |
+
"<|action_token_1643|>": 153317,
|
| 726 |
+
"<|action_token_1644|>": 153318,
|
| 727 |
+
"<|action_token_1645|>": 153319,
|
| 728 |
+
"<|action_token_1646|>": 153320,
|
| 729 |
+
"<|action_token_1647|>": 153321,
|
| 730 |
+
"<|action_token_1648|>": 153322,
|
| 731 |
+
"<|action_token_1649|>": 153323,
|
| 732 |
+
"<|action_token_164|>": 151838,
|
| 733 |
+
"<|action_token_1650|>": 153324,
|
| 734 |
+
"<|action_token_1651|>": 153325,
|
| 735 |
+
"<|action_token_1652|>": 153326,
|
| 736 |
+
"<|action_token_1653|>": 153327,
|
| 737 |
+
"<|action_token_1654|>": 153328,
|
| 738 |
+
"<|action_token_1655|>": 153329,
|
| 739 |
+
"<|action_token_1656|>": 153330,
|
| 740 |
+
"<|action_token_1657|>": 153331,
|
| 741 |
+
"<|action_token_1658|>": 153332,
|
| 742 |
+
"<|action_token_1659|>": 153333,
|
| 743 |
+
"<|action_token_165|>": 151839,
|
| 744 |
+
"<|action_token_1660|>": 153334,
|
| 745 |
+
"<|action_token_1661|>": 153335,
|
| 746 |
+
"<|action_token_1662|>": 153336,
|
| 747 |
+
"<|action_token_1663|>": 153337,
|
| 748 |
+
"<|action_token_1664|>": 153338,
|
| 749 |
+
"<|action_token_1665|>": 153339,
|
| 750 |
+
"<|action_token_1666|>": 153340,
|
| 751 |
+
"<|action_token_1667|>": 153341,
|
| 752 |
+
"<|action_token_1668|>": 153342,
|
| 753 |
+
"<|action_token_1669|>": 153343,
|
| 754 |
+
"<|action_token_166|>": 151840,
|
| 755 |
+
"<|action_token_1670|>": 153344,
|
| 756 |
+
"<|action_token_1671|>": 153345,
|
| 757 |
+
"<|action_token_1672|>": 153346,
|
| 758 |
+
"<|action_token_1673|>": 153347,
|
| 759 |
+
"<|action_token_1674|>": 153348,
|
| 760 |
+
"<|action_token_1675|>": 153349,
|
| 761 |
+
"<|action_token_1676|>": 153350,
|
| 762 |
+
"<|action_token_1677|>": 153351,
|
| 763 |
+
"<|action_token_1678|>": 153352,
|
| 764 |
+
"<|action_token_1679|>": 153353,
|
| 765 |
+
"<|action_token_167|>": 151841,
|
| 766 |
+
"<|action_token_1680|>": 153354,
|
| 767 |
+
"<|action_token_1681|>": 153355,
|
| 768 |
+
"<|action_token_1682|>": 153356,
|
| 769 |
+
"<|action_token_1683|>": 153357,
|
| 770 |
+
"<|action_token_1684|>": 153358,
|
| 771 |
+
"<|action_token_1685|>": 153359,
|
| 772 |
+
"<|action_token_1686|>": 153360,
|
| 773 |
+
"<|action_token_1687|>": 153361,
|
| 774 |
+
"<|action_token_1688|>": 153362,
|
| 775 |
+
"<|action_token_1689|>": 153363,
|
| 776 |
+
"<|action_token_168|>": 151842,
|
| 777 |
+
"<|action_token_1690|>": 153364,
|
| 778 |
+
"<|action_token_1691|>": 153365,
|
| 779 |
+
"<|action_token_1692|>": 153366,
|
| 780 |
+
"<|action_token_1693|>": 153367,
|
| 781 |
+
"<|action_token_1694|>": 153368,
|
| 782 |
+
"<|action_token_1695|>": 153369,
|
| 783 |
+
"<|action_token_1696|>": 153370,
|
| 784 |
+
"<|action_token_1697|>": 153371,
|
| 785 |
+
"<|action_token_1698|>": 153372,
|
| 786 |
+
"<|action_token_1699|>": 153373,
|
| 787 |
+
"<|action_token_169|>": 151843,
|
| 788 |
+
"<|action_token_16|>": 151690,
|
| 789 |
+
"<|action_token_1700|>": 153374,
|
| 790 |
+
"<|action_token_1701|>": 153375,
|
| 791 |
+
"<|action_token_1702|>": 153376,
|
| 792 |
+
"<|action_token_1703|>": 153377,
|
| 793 |
+
"<|action_token_1704|>": 153378,
|
| 794 |
+
"<|action_token_1705|>": 153379,
|
| 795 |
+
"<|action_token_1706|>": 153380,
|
| 796 |
+
"<|action_token_1707|>": 153381,
|
| 797 |
+
"<|action_token_1708|>": 153382,
|
| 798 |
+
"<|action_token_1709|>": 153383,
|
| 799 |
+
"<|action_token_170|>": 151844,
|
| 800 |
+
"<|action_token_1710|>": 153384,
|
| 801 |
+
"<|action_token_1711|>": 153385,
|
| 802 |
+
"<|action_token_1712|>": 153386,
|
| 803 |
+
"<|action_token_1713|>": 153387,
|
| 804 |
+
"<|action_token_1714|>": 153388,
|
| 805 |
+
"<|action_token_1715|>": 153389,
|
| 806 |
+
"<|action_token_1716|>": 153390,
|
| 807 |
+
"<|action_token_1717|>": 153391,
|
| 808 |
+
"<|action_token_1718|>": 153392,
|
| 809 |
+
"<|action_token_1719|>": 153393,
|
| 810 |
+
"<|action_token_171|>": 151845,
|
| 811 |
+
"<|action_token_1720|>": 153394,
|
| 812 |
+
"<|action_token_1721|>": 153395,
|
| 813 |
+
"<|action_token_1722|>": 153396,
|
| 814 |
+
"<|action_token_1723|>": 153397,
|
| 815 |
+
"<|action_token_1724|>": 153398,
|
| 816 |
+
"<|action_token_1725|>": 153399,
|
| 817 |
+
"<|action_token_1726|>": 153400,
|
| 818 |
+
"<|action_token_1727|>": 153401,
|
| 819 |
+
"<|action_token_1728|>": 153402,
|
| 820 |
+
"<|action_token_1729|>": 153403,
|
| 821 |
+
"<|action_token_172|>": 151846,
|
| 822 |
+
"<|action_token_1730|>": 153404,
|
| 823 |
+
"<|action_token_1731|>": 153405,
|
| 824 |
+
"<|action_token_1732|>": 153406,
|
| 825 |
+
"<|action_token_1733|>": 153407,
|
| 826 |
+
"<|action_token_1734|>": 153408,
|
| 827 |
+
"<|action_token_1735|>": 153409,
|
| 828 |
+
"<|action_token_1736|>": 153410,
|
| 829 |
+
"<|action_token_1737|>": 153411,
|
| 830 |
+
"<|action_token_1738|>": 153412,
|
| 831 |
+
"<|action_token_1739|>": 153413,
|
| 832 |
+
"<|action_token_173|>": 151847,
|
| 833 |
+
"<|action_token_1740|>": 153414,
|
| 834 |
+
"<|action_token_1741|>": 153415,
|
| 835 |
+
"<|action_token_1742|>": 153416,
|
| 836 |
+
"<|action_token_1743|>": 153417,
|
| 837 |
+
"<|action_token_1744|>": 153418,
|
| 838 |
+
"<|action_token_1745|>": 153419,
|
| 839 |
+
"<|action_token_1746|>": 153420,
|
| 840 |
+
"<|action_token_1747|>": 153421,
|
| 841 |
+
"<|action_token_1748|>": 153422,
|
| 842 |
+
"<|action_token_1749|>": 153423,
|
| 843 |
+
"<|action_token_174|>": 151848,
|
| 844 |
+
"<|action_token_1750|>": 153424,
|
| 845 |
+
"<|action_token_1751|>": 153425,
|
| 846 |
+
"<|action_token_1752|>": 153426,
|
| 847 |
+
"<|action_token_1753|>": 153427,
|
| 848 |
+
"<|action_token_1754|>": 153428,
|
| 849 |
+
"<|action_token_1755|>": 153429,
|
| 850 |
+
"<|action_token_1756|>": 153430,
|
| 851 |
+
"<|action_token_1757|>": 153431,
|
| 852 |
+
"<|action_token_1758|>": 153432,
|
| 853 |
+
"<|action_token_1759|>": 153433,
|
| 854 |
+
"<|action_token_175|>": 151849,
|
| 855 |
+
"<|action_token_1760|>": 153434,
|
| 856 |
+
"<|action_token_1761|>": 153435,
|
| 857 |
+
"<|action_token_1762|>": 153436,
|
| 858 |
+
"<|action_token_1763|>": 153437,
|
| 859 |
+
"<|action_token_1764|>": 153438,
|
| 860 |
+
"<|action_token_1765|>": 153439,
|
| 861 |
+
"<|action_token_1766|>": 153440,
|
| 862 |
+
"<|action_token_1767|>": 153441,
|
| 863 |
+
"<|action_token_1768|>": 153442,
|
| 864 |
+
"<|action_token_1769|>": 153443,
|
| 865 |
+
"<|action_token_176|>": 151850,
|
| 866 |
+
"<|action_token_1770|>": 153444,
|
| 867 |
+
"<|action_token_1771|>": 153445,
|
| 868 |
+
"<|action_token_1772|>": 153446,
|
| 869 |
+
"<|action_token_1773|>": 153447,
|
| 870 |
+
"<|action_token_1774|>": 153448,
|
| 871 |
+
"<|action_token_1775|>": 153449,
|
| 872 |
+
"<|action_token_1776|>": 153450,
|
| 873 |
+
"<|action_token_1777|>": 153451,
|
| 874 |
+
"<|action_token_1778|>": 153452,
|
| 875 |
+
"<|action_token_1779|>": 153453,
|
| 876 |
+
"<|action_token_177|>": 151851,
|
| 877 |
+
"<|action_token_1780|>": 153454,
|
| 878 |
+
"<|action_token_1781|>": 153455,
|
| 879 |
+
"<|action_token_1782|>": 153456,
|
| 880 |
+
"<|action_token_1783|>": 153457,
|
| 881 |
+
"<|action_token_1784|>": 153458,
|
| 882 |
+
"<|action_token_1785|>": 153459,
|
| 883 |
+
"<|action_token_1786|>": 153460,
|
| 884 |
+
"<|action_token_1787|>": 153461,
|
| 885 |
+
"<|action_token_1788|>": 153462,
|
| 886 |
+
"<|action_token_1789|>": 153463,
|
| 887 |
+
"<|action_token_178|>": 151852,
|
| 888 |
+
"<|action_token_1790|>": 153464,
|
| 889 |
+
"<|action_token_1791|>": 153465,
|
| 890 |
+
"<|action_token_1792|>": 153466,
|
| 891 |
+
"<|action_token_1793|>": 153467,
|
| 892 |
+
"<|action_token_1794|>": 153468,
|
| 893 |
+
"<|action_token_1795|>": 153469,
|
| 894 |
+
"<|action_token_1796|>": 153470,
|
| 895 |
+
"<|action_token_1797|>": 153471,
|
| 896 |
+
"<|action_token_1798|>": 153472,
|
| 897 |
+
"<|action_token_1799|>": 153473,
|
| 898 |
+
"<|action_token_179|>": 151853,
|
| 899 |
+
"<|action_token_17|>": 151691,
|
| 900 |
+
"<|action_token_1800|>": 153474,
|
| 901 |
+
"<|action_token_1801|>": 153475,
|
| 902 |
+
"<|action_token_1802|>": 153476,
|
| 903 |
+
"<|action_token_1803|>": 153477,
|
| 904 |
+
"<|action_token_1804|>": 153478,
|
| 905 |
+
"<|action_token_1805|>": 153479,
|
| 906 |
+
"<|action_token_1806|>": 153480,
|
| 907 |
+
"<|action_token_1807|>": 153481,
|
| 908 |
+
"<|action_token_1808|>": 153482,
|
| 909 |
+
"<|action_token_1809|>": 153483,
|
| 910 |
+
"<|action_token_180|>": 151854,
|
| 911 |
+
"<|action_token_1810|>": 153484,
|
| 912 |
+
"<|action_token_1811|>": 153485,
|
| 913 |
+
"<|action_token_1812|>": 153486,
|
| 914 |
+
"<|action_token_1813|>": 153487,
|
| 915 |
+
"<|action_token_1814|>": 153488,
|
| 916 |
+
"<|action_token_1815|>": 153489,
|
| 917 |
+
"<|action_token_1816|>": 153490,
|
| 918 |
+
"<|action_token_1817|>": 153491,
|
| 919 |
+
"<|action_token_1818|>": 153492,
|
| 920 |
+
"<|action_token_1819|>": 153493,
|
| 921 |
+
"<|action_token_181|>": 151855,
|
| 922 |
+
"<|action_token_1820|>": 153494,
|
| 923 |
+
"<|action_token_1821|>": 153495,
|
| 924 |
+
"<|action_token_1822|>": 153496,
|
| 925 |
+
"<|action_token_1823|>": 153497,
|
| 926 |
+
"<|action_token_1824|>": 153498,
|
| 927 |
+
"<|action_token_1825|>": 153499,
|
| 928 |
+
"<|action_token_1826|>": 153500,
|
| 929 |
+
"<|action_token_1827|>": 153501,
|
| 930 |
+
"<|action_token_1828|>": 153502,
|
| 931 |
+
"<|action_token_1829|>": 153503,
|
| 932 |
+
"<|action_token_182|>": 151856,
|
| 933 |
+
"<|action_token_1830|>": 153504,
|
| 934 |
+
"<|action_token_1831|>": 153505,
|
| 935 |
+
"<|action_token_1832|>": 153506,
|
| 936 |
+
"<|action_token_1833|>": 153507,
|
| 937 |
+
"<|action_token_1834|>": 153508,
|
| 938 |
+
"<|action_token_1835|>": 153509,
|
| 939 |
+
"<|action_token_1836|>": 153510,
|
| 940 |
+
"<|action_token_1837|>": 153511,
|
| 941 |
+
"<|action_token_1838|>": 153512,
|
| 942 |
+
"<|action_token_1839|>": 153513,
|
| 943 |
+
"<|action_token_183|>": 151857,
|
| 944 |
+
"<|action_token_1840|>": 153514,
|
| 945 |
+
"<|action_token_1841|>": 153515,
|
| 946 |
+
"<|action_token_1842|>": 153516,
|
| 947 |
+
"<|action_token_1843|>": 153517,
|
| 948 |
+
"<|action_token_1844|>": 153518,
|
| 949 |
+
"<|action_token_1845|>": 153519,
|
| 950 |
+
"<|action_token_1846|>": 153520,
|
| 951 |
+
"<|action_token_1847|>": 153521,
|
| 952 |
+
"<|action_token_1848|>": 153522,
|
| 953 |
+
"<|action_token_1849|>": 153523,
|
| 954 |
+
"<|action_token_184|>": 151858,
|
| 955 |
+
"<|action_token_1850|>": 153524,
|
| 956 |
+
"<|action_token_1851|>": 153525,
|
| 957 |
+
"<|action_token_1852|>": 153526,
|
| 958 |
+
"<|action_token_1853|>": 153527,
|
| 959 |
+
"<|action_token_1854|>": 153528,
|
| 960 |
+
"<|action_token_1855|>": 153529,
|
| 961 |
+
"<|action_token_1856|>": 153530,
|
| 962 |
+
"<|action_token_1857|>": 153531,
|
| 963 |
+
"<|action_token_1858|>": 153532,
|
| 964 |
+
"<|action_token_1859|>": 153533,
|
| 965 |
+
"<|action_token_185|>": 151859,
|
| 966 |
+
"<|action_token_1860|>": 153534,
|
| 967 |
+
"<|action_token_1861|>": 153535,
|
| 968 |
+
"<|action_token_1862|>": 153536,
|
| 969 |
+
"<|action_token_1863|>": 153537,
|
| 970 |
+
"<|action_token_1864|>": 153538,
|
| 971 |
+
"<|action_token_1865|>": 153539,
|
| 972 |
+
"<|action_token_1866|>": 153540,
|
| 973 |
+
"<|action_token_1867|>": 153541,
|
| 974 |
+
"<|action_token_1868|>": 153542,
|
| 975 |
+
"<|action_token_1869|>": 153543,
|
| 976 |
+
"<|action_token_186|>": 151860,
|
| 977 |
+
"<|action_token_1870|>": 153544,
|
| 978 |
+
"<|action_token_1871|>": 153545,
|
| 979 |
+
"<|action_token_1872|>": 153546,
|
| 980 |
+
"<|action_token_1873|>": 153547,
|
| 981 |
+
"<|action_token_1874|>": 153548,
|
| 982 |
+
"<|action_token_1875|>": 153549,
|
| 983 |
+
"<|action_token_1876|>": 153550,
|
| 984 |
+
"<|action_token_1877|>": 153551,
|
| 985 |
+
"<|action_token_1878|>": 153552,
|
| 986 |
+
"<|action_token_1879|>": 153553,
|
| 987 |
+
"<|action_token_187|>": 151861,
|
| 988 |
+
"<|action_token_1880|>": 153554,
|
| 989 |
+
"<|action_token_1881|>": 153555,
|
| 990 |
+
"<|action_token_1882|>": 153556,
|
| 991 |
+
"<|action_token_1883|>": 153557,
|
| 992 |
+
"<|action_token_1884|>": 153558,
|
| 993 |
+
"<|action_token_1885|>": 153559,
|
| 994 |
+
"<|action_token_1886|>": 153560,
|
| 995 |
+
"<|action_token_1887|>": 153561,
|
| 996 |
+
"<|action_token_1888|>": 153562,
|
| 997 |
+
"<|action_token_1889|>": 153563,
|
| 998 |
+
"<|action_token_188|>": 151862,
|
| 999 |
+
"<|action_token_1890|>": 153564,
|
| 1000 |
+
"<|action_token_1891|>": 153565,
|
| 1001 |
+
"<|action_token_1892|>": 153566,
|
| 1002 |
+
"<|action_token_1893|>": 153567,
|
| 1003 |
+
"<|action_token_1894|>": 153568,
|
| 1004 |
+
"<|action_token_1895|>": 153569,
|
| 1005 |
+
"<|action_token_1896|>": 153570,
|
| 1006 |
+
"<|action_token_1897|>": 153571,
|
| 1007 |
+
"<|action_token_1898|>": 153572,
|
| 1008 |
+
"<|action_token_1899|>": 153573,
|
| 1009 |
+
"<|action_token_189|>": 151863,
|
| 1010 |
+
"<|action_token_18|>": 151692,
|
| 1011 |
+
"<|action_token_1900|>": 153574,
|
| 1012 |
+
"<|action_token_1901|>": 153575,
|
| 1013 |
+
"<|action_token_1902|>": 153576,
|
| 1014 |
+
"<|action_token_1903|>": 153577,
|
| 1015 |
+
"<|action_token_1904|>": 153578,
|
| 1016 |
+
"<|action_token_1905|>": 153579,
|
| 1017 |
+
"<|action_token_1906|>": 153580,
|
| 1018 |
+
"<|action_token_1907|>": 153581,
|
| 1019 |
+
"<|action_token_1908|>": 153582,
|
| 1020 |
+
"<|action_token_1909|>": 153583,
|
| 1021 |
+
"<|action_token_190|>": 151864,
|
| 1022 |
+
"<|action_token_1910|>": 153584,
|
| 1023 |
+
"<|action_token_1911|>": 153585,
|
| 1024 |
+
"<|action_token_1912|>": 153586,
|
| 1025 |
+
"<|action_token_1913|>": 153587,
|
| 1026 |
+
"<|action_token_1914|>": 153588,
|
| 1027 |
+
"<|action_token_1915|>": 153589,
|
| 1028 |
+
"<|action_token_1916|>": 153590,
|
| 1029 |
+
"<|action_token_1917|>": 153591,
|
| 1030 |
+
"<|action_token_1918|>": 153592,
|
| 1031 |
+
"<|action_token_1919|>": 153593,
|
| 1032 |
+
"<|action_token_191|>": 151865,
|
| 1033 |
+
"<|action_token_1920|>": 153594,
|
| 1034 |
+
"<|action_token_1921|>": 153595,
|
| 1035 |
+
"<|action_token_1922|>": 153596,
|
| 1036 |
+
"<|action_token_1923|>": 153597,
|
| 1037 |
+
"<|action_token_1924|>": 153598,
|
| 1038 |
+
"<|action_token_1925|>": 153599,
|
| 1039 |
+
"<|action_token_1926|>": 153600,
|
| 1040 |
+
"<|action_token_1927|>": 153601,
|
| 1041 |
+
"<|action_token_1928|>": 153602,
|
| 1042 |
+
"<|action_token_1929|>": 153603,
|
| 1043 |
+
"<|action_token_192|>": 151866,
|
| 1044 |
+
"<|action_token_1930|>": 153604,
|
| 1045 |
+
"<|action_token_1931|>": 153605,
|
| 1046 |
+
"<|action_token_1932|>": 153606,
|
| 1047 |
+
"<|action_token_1933|>": 153607,
|
| 1048 |
+
"<|action_token_1934|>": 153608,
|
| 1049 |
+
"<|action_token_1935|>": 153609,
|
| 1050 |
+
"<|action_token_1936|>": 153610,
|
| 1051 |
+
"<|action_token_1937|>": 153611,
|
| 1052 |
+
"<|action_token_1938|>": 153612,
|
| 1053 |
+
"<|action_token_1939|>": 153613,
|
| 1054 |
+
"<|action_token_193|>": 151867,
|
| 1055 |
+
"<|action_token_1940|>": 153614,
|
| 1056 |
+
"<|action_token_1941|>": 153615,
|
| 1057 |
+
"<|action_token_1942|>": 153616,
|
| 1058 |
+
"<|action_token_1943|>": 153617,
|
| 1059 |
+
"<|action_token_1944|>": 153618,
|
| 1060 |
+
"<|action_token_1945|>": 153619,
|
| 1061 |
+
"<|action_token_1946|>": 153620,
|
| 1062 |
+
"<|action_token_1947|>": 153621,
|
| 1063 |
+
"<|action_token_1948|>": 153622,
|
| 1064 |
+
"<|action_token_1949|>": 153623,
|
| 1065 |
+
"<|action_token_194|>": 151868,
|
| 1066 |
+
"<|action_token_1950|>": 153624,
|
| 1067 |
+
"<|action_token_1951|>": 153625,
|
| 1068 |
+
"<|action_token_1952|>": 153626,
|
| 1069 |
+
"<|action_token_1953|>": 153627,
|
| 1070 |
+
"<|action_token_1954|>": 153628,
|
| 1071 |
+
"<|action_token_1955|>": 153629,
|
| 1072 |
+
"<|action_token_1956|>": 153630,
|
| 1073 |
+
"<|action_token_1957|>": 153631,
|
| 1074 |
+
"<|action_token_1958|>": 153632,
|
| 1075 |
+
"<|action_token_1959|>": 153633,
|
| 1076 |
+
"<|action_token_195|>": 151869,
|
| 1077 |
+
"<|action_token_1960|>": 153634,
|
| 1078 |
+
"<|action_token_1961|>": 153635,
|
| 1079 |
+
"<|action_token_1962|>": 153636,
|
| 1080 |
+
"<|action_token_1963|>": 153637,
|
| 1081 |
+
"<|action_token_1964|>": 153638,
|
| 1082 |
+
"<|action_token_1965|>": 153639,
|
| 1083 |
+
"<|action_token_1966|>": 153640,
|
| 1084 |
+
"<|action_token_1967|>": 153641,
|
| 1085 |
+
"<|action_token_1968|>": 153642,
|
| 1086 |
+
"<|action_token_1969|>": 153643,
|
| 1087 |
+
"<|action_token_196|>": 151870,
|
| 1088 |
+
"<|action_token_1970|>": 153644,
|
| 1089 |
+
"<|action_token_1971|>": 153645,
|
| 1090 |
+
"<|action_token_1972|>": 153646,
|
| 1091 |
+
"<|action_token_1973|>": 153647,
|
| 1092 |
+
"<|action_token_1974|>": 153648,
|
| 1093 |
+
"<|action_token_1975|>": 153649,
|
| 1094 |
+
"<|action_token_1976|>": 153650,
|
| 1095 |
+
"<|action_token_1977|>": 153651,
|
| 1096 |
+
"<|action_token_1978|>": 153652,
|
| 1097 |
+
"<|action_token_1979|>": 153653,
|
| 1098 |
+
"<|action_token_197|>": 151871,
|
| 1099 |
+
"<|action_token_1980|>": 153654,
|
| 1100 |
+
"<|action_token_1981|>": 153655,
|
| 1101 |
+
"<|action_token_1982|>": 153656,
|
| 1102 |
+
"<|action_token_1983|>": 153657,
|
| 1103 |
+
"<|action_token_1984|>": 153658,
|
| 1104 |
+
"<|action_token_1985|>": 153659,
|
| 1105 |
+
"<|action_token_1986|>": 153660,
|
| 1106 |
+
"<|action_token_1987|>": 153661,
|
| 1107 |
+
"<|action_token_1988|>": 153662,
|
| 1108 |
+
"<|action_token_1989|>": 153663,
|
| 1109 |
+
"<|action_token_198|>": 151872,
|
| 1110 |
+
"<|action_token_1990|>": 153664,
|
| 1111 |
+
"<|action_token_1991|>": 153665,
|
| 1112 |
+
"<|action_token_1992|>": 153666,
|
| 1113 |
+
"<|action_token_1993|>": 153667,
|
| 1114 |
+
"<|action_token_1994|>": 153668,
|
| 1115 |
+
"<|action_token_1995|>": 153669,
|
| 1116 |
+
"<|action_token_1996|>": 153670,
|
| 1117 |
+
"<|action_token_1997|>": 153671,
|
| 1118 |
+
"<|action_token_1998|>": 153672,
|
| 1119 |
+
"<|action_token_1999|>": 153673,
|
| 1120 |
+
"<|action_token_199|>": 151873,
|
| 1121 |
+
"<|action_token_19|>": 151693,
|
| 1122 |
+
"<|action_token_1|>": 151675,
|
| 1123 |
+
"<|action_token_2000|>": 153674,
|
| 1124 |
+
"<|action_token_2001|>": 153675,
|
| 1125 |
+
"<|action_token_2002|>": 153676,
|
| 1126 |
+
"<|action_token_2003|>": 153677,
|
| 1127 |
+
"<|action_token_2004|>": 153678,
|
| 1128 |
+
"<|action_token_2005|>": 153679,
|
| 1129 |
+
"<|action_token_2006|>": 153680,
|
| 1130 |
+
"<|action_token_2007|>": 153681,
|
| 1131 |
+
"<|action_token_2008|>": 153682,
|
| 1132 |
+
"<|action_token_2009|>": 153683,
|
| 1133 |
+
"<|action_token_200|>": 151874,
|
| 1134 |
+
"<|action_token_2010|>": 153684,
|
| 1135 |
+
"<|action_token_2011|>": 153685,
|
| 1136 |
+
"<|action_token_2012|>": 153686,
|
| 1137 |
+
"<|action_token_2013|>": 153687,
|
| 1138 |
+
"<|action_token_2014|>": 153688,
|
| 1139 |
+
"<|action_token_2015|>": 153689,
|
| 1140 |
+
"<|action_token_2016|>": 153690,
|
| 1141 |
+
"<|action_token_2017|>": 153691,
|
| 1142 |
+
"<|action_token_2018|>": 153692,
|
| 1143 |
+
"<|action_token_2019|>": 153693,
|
| 1144 |
+
"<|action_token_201|>": 151875,
|
| 1145 |
+
"<|action_token_2020|>": 153694,
|
| 1146 |
+
"<|action_token_2021|>": 153695,
|
| 1147 |
+
"<|action_token_2022|>": 153696,
|
| 1148 |
+
"<|action_token_2023|>": 153697,
|
| 1149 |
+
"<|action_token_2024|>": 153698,
|
| 1150 |
+
"<|action_token_2025|>": 153699,
|
| 1151 |
+
"<|action_token_2026|>": 153700,
|
| 1152 |
+
"<|action_token_2027|>": 153701,
|
| 1153 |
+
"<|action_token_2028|>": 153702,
|
| 1154 |
+
"<|action_token_2029|>": 153703,
|
| 1155 |
+
"<|action_token_202|>": 151876,
|
| 1156 |
+
"<|action_token_2030|>": 153704,
|
| 1157 |
+
"<|action_token_2031|>": 153705,
|
| 1158 |
+
"<|action_token_2032|>": 153706,
|
| 1159 |
+
"<|action_token_2033|>": 153707,
|
| 1160 |
+
"<|action_token_2034|>": 153708,
|
| 1161 |
+
"<|action_token_2035|>": 153709,
|
| 1162 |
+
"<|action_token_2036|>": 153710,
|
| 1163 |
+
"<|action_token_2037|>": 153711,
|
| 1164 |
+
"<|action_token_2038|>": 153712,
|
| 1165 |
+
"<|action_token_2039|>": 153713,
|
| 1166 |
+
"<|action_token_203|>": 151877,
|
| 1167 |
+
"<|action_token_2040|>": 153714,
|
| 1168 |
+
"<|action_token_2041|>": 153715,
|
| 1169 |
+
"<|action_token_2042|>": 153716,
|
| 1170 |
+
"<|action_token_2043|>": 153717,
|
| 1171 |
+
"<|action_token_2044|>": 153718,
|
| 1172 |
+
"<|action_token_2045|>": 153719,
|
| 1173 |
+
"<|action_token_2046|>": 153720,
|
| 1174 |
+
"<|action_token_2047|>": 153721,
|
| 1175 |
+
"<|action_token_204|>": 151878,
|
| 1176 |
+
"<|action_token_205|>": 151879,
|
| 1177 |
+
"<|action_token_206|>": 151880,
|
| 1178 |
+
"<|action_token_207|>": 151881,
|
| 1179 |
+
"<|action_token_208|>": 151882,
|
| 1180 |
+
"<|action_token_209|>": 151883,
|
| 1181 |
+
"<|action_token_20|>": 151694,
|
| 1182 |
+
"<|action_token_210|>": 151884,
|
| 1183 |
+
"<|action_token_211|>": 151885,
|
| 1184 |
+
"<|action_token_212|>": 151886,
|
| 1185 |
+
"<|action_token_213|>": 151887,
|
| 1186 |
+
"<|action_token_214|>": 151888,
|
| 1187 |
+
"<|action_token_215|>": 151889,
|
| 1188 |
+
"<|action_token_216|>": 151890,
|
| 1189 |
+
"<|action_token_217|>": 151891,
|
| 1190 |
+
"<|action_token_218|>": 151892,
|
| 1191 |
+
"<|action_token_219|>": 151893,
|
| 1192 |
+
"<|action_token_21|>": 151695,
|
| 1193 |
+
"<|action_token_220|>": 151894,
|
| 1194 |
+
"<|action_token_221|>": 151895,
|
| 1195 |
+
"<|action_token_222|>": 151896,
|
| 1196 |
+
"<|action_token_223|>": 151897,
|
| 1197 |
+
"<|action_token_224|>": 151898,
|
| 1198 |
+
"<|action_token_225|>": 151899,
|
| 1199 |
+
"<|action_token_226|>": 151900,
|
| 1200 |
+
"<|action_token_227|>": 151901,
|
| 1201 |
+
"<|action_token_228|>": 151902,
|
| 1202 |
+
"<|action_token_229|>": 151903,
|
| 1203 |
+
"<|action_token_22|>": 151696,
|
| 1204 |
+
"<|action_token_230|>": 151904,
|
| 1205 |
+
"<|action_token_231|>": 151905,
|
| 1206 |
+
"<|action_token_232|>": 151906,
|
| 1207 |
+
"<|action_token_233|>": 151907,
|
| 1208 |
+
"<|action_token_234|>": 151908,
|
| 1209 |
+
"<|action_token_235|>": 151909,
|
| 1210 |
+
"<|action_token_236|>": 151910,
|
| 1211 |
+
"<|action_token_237|>": 151911,
|
| 1212 |
+
"<|action_token_238|>": 151912,
|
| 1213 |
+
"<|action_token_239|>": 151913,
|
| 1214 |
+
"<|action_token_23|>": 151697,
|
| 1215 |
+
"<|action_token_240|>": 151914,
|
| 1216 |
+
"<|action_token_241|>": 151915,
|
| 1217 |
+
"<|action_token_242|>": 151916,
|
| 1218 |
+
"<|action_token_243|>": 151917,
|
| 1219 |
+
"<|action_token_244|>": 151918,
|
| 1220 |
+
"<|action_token_245|>": 151919,
|
| 1221 |
+
"<|action_token_246|>": 151920,
|
| 1222 |
+
"<|action_token_247|>": 151921,
|
| 1223 |
+
"<|action_token_248|>": 151922,
|
| 1224 |
+
"<|action_token_249|>": 151923,
|
| 1225 |
+
"<|action_token_24|>": 151698,
|
| 1226 |
+
"<|action_token_250|>": 151924,
|
| 1227 |
+
"<|action_token_251|>": 151925,
|
| 1228 |
+
"<|action_token_252|>": 151926,
|
| 1229 |
+
"<|action_token_253|>": 151927,
|
| 1230 |
+
"<|action_token_254|>": 151928,
|
| 1231 |
+
"<|action_token_255|>": 151929,
|
| 1232 |
+
"<|action_token_256|>": 151930,
|
| 1233 |
+
"<|action_token_257|>": 151931,
|
| 1234 |
+
"<|action_token_258|>": 151932,
|
| 1235 |
+
"<|action_token_259|>": 151933,
|
| 1236 |
+
"<|action_token_25|>": 151699,
|
| 1237 |
+
"<|action_token_260|>": 151934,
|
| 1238 |
+
"<|action_token_261|>": 151935,
|
| 1239 |
+
"<|action_token_262|>": 151936,
|
| 1240 |
+
"<|action_token_263|>": 151937,
|
| 1241 |
+
"<|action_token_264|>": 151938,
|
| 1242 |
+
"<|action_token_265|>": 151939,
|
| 1243 |
+
"<|action_token_266|>": 151940,
|
| 1244 |
+
"<|action_token_267|>": 151941,
|
| 1245 |
+
"<|action_token_268|>": 151942,
|
| 1246 |
+
"<|action_token_269|>": 151943,
|
| 1247 |
+
"<|action_token_26|>": 151700,
|
| 1248 |
+
"<|action_token_270|>": 151944,
|
| 1249 |
+
"<|action_token_271|>": 151945,
|
| 1250 |
+
"<|action_token_272|>": 151946,
|
| 1251 |
+
"<|action_token_273|>": 151947,
|
| 1252 |
+
"<|action_token_274|>": 151948,
|
| 1253 |
+
"<|action_token_275|>": 151949,
|
| 1254 |
+
"<|action_token_276|>": 151950,
|
| 1255 |
+
"<|action_token_277|>": 151951,
|
| 1256 |
+
"<|action_token_278|>": 151952,
|
| 1257 |
+
"<|action_token_279|>": 151953,
|
| 1258 |
+
"<|action_token_27|>": 151701,
|
| 1259 |
+
"<|action_token_280|>": 151954,
|
| 1260 |
+
"<|action_token_281|>": 151955,
|
| 1261 |
+
"<|action_token_282|>": 151956,
|
| 1262 |
+
"<|action_token_283|>": 151957,
|
| 1263 |
+
"<|action_token_284|>": 151958,
|
| 1264 |
+
"<|action_token_285|>": 151959,
|
| 1265 |
+
"<|action_token_286|>": 151960,
|
| 1266 |
+
"<|action_token_287|>": 151961,
|
| 1267 |
+
"<|action_token_288|>": 151962,
|
| 1268 |
+
"<|action_token_289|>": 151963,
|
| 1269 |
+
"<|action_token_28|>": 151702,
|
| 1270 |
+
"<|action_token_290|>": 151964,
|
| 1271 |
+
"<|action_token_291|>": 151965,
|
| 1272 |
+
"<|action_token_292|>": 151966,
|
| 1273 |
+
"<|action_token_293|>": 151967,
|
| 1274 |
+
"<|action_token_294|>": 151968,
|
| 1275 |
+
"<|action_token_295|>": 151969,
|
| 1276 |
+
"<|action_token_296|>": 151970,
|
| 1277 |
+
"<|action_token_297|>": 151971,
|
| 1278 |
+
"<|action_token_298|>": 151972,
|
| 1279 |
+
"<|action_token_299|>": 151973,
|
| 1280 |
+
"<|action_token_29|>": 151703,
|
| 1281 |
+
"<|action_token_2|>": 151676,
|
| 1282 |
+
"<|action_token_300|>": 151974,
|
| 1283 |
+
"<|action_token_301|>": 151975,
|
| 1284 |
+
"<|action_token_302|>": 151976,
|
| 1285 |
+
"<|action_token_303|>": 151977,
|
| 1286 |
+
"<|action_token_304|>": 151978,
|
| 1287 |
+
"<|action_token_305|>": 151979,
|
| 1288 |
+
"<|action_token_306|>": 151980,
|
| 1289 |
+
"<|action_token_307|>": 151981,
|
| 1290 |
+
"<|action_token_308|>": 151982,
|
| 1291 |
+
"<|action_token_309|>": 151983,
|
| 1292 |
+
"<|action_token_30|>": 151704,
|
| 1293 |
+
"<|action_token_310|>": 151984,
|
| 1294 |
+
"<|action_token_311|>": 151985,
|
| 1295 |
+
"<|action_token_312|>": 151986,
|
| 1296 |
+
"<|action_token_313|>": 151987,
|
| 1297 |
+
"<|action_token_314|>": 151988,
|
| 1298 |
+
"<|action_token_315|>": 151989,
|
| 1299 |
+
"<|action_token_316|>": 151990,
|
| 1300 |
+
"<|action_token_317|>": 151991,
|
| 1301 |
+
"<|action_token_318|>": 151992,
|
| 1302 |
+
"<|action_token_319|>": 151993,
|
| 1303 |
+
"<|action_token_31|>": 151705,
|
| 1304 |
+
"<|action_token_320|>": 151994,
|
| 1305 |
+
"<|action_token_321|>": 151995,
|
| 1306 |
+
"<|action_token_322|>": 151996,
|
| 1307 |
+
"<|action_token_323|>": 151997,
|
| 1308 |
+
"<|action_token_324|>": 151998,
|
| 1309 |
+
"<|action_token_325|>": 151999,
|
| 1310 |
+
"<|action_token_326|>": 152000,
|
| 1311 |
+
"<|action_token_327|>": 152001,
|
| 1312 |
+
"<|action_token_328|>": 152002,
|
| 1313 |
+
"<|action_token_329|>": 152003,
|
| 1314 |
+
"<|action_token_32|>": 151706,
|
| 1315 |
+
"<|action_token_330|>": 152004,
|
| 1316 |
+
"<|action_token_331|>": 152005,
|
| 1317 |
+
"<|action_token_332|>": 152006,
|
| 1318 |
+
"<|action_token_333|>": 152007,
|
| 1319 |
+
"<|action_token_334|>": 152008,
|
| 1320 |
+
"<|action_token_335|>": 152009,
|
| 1321 |
+
"<|action_token_336|>": 152010,
|
| 1322 |
+
"<|action_token_337|>": 152011,
|
| 1323 |
+
"<|action_token_338|>": 152012,
|
| 1324 |
+
"<|action_token_339|>": 152013,
|
| 1325 |
+
"<|action_token_33|>": 151707,
|
| 1326 |
+
"<|action_token_340|>": 152014,
|
| 1327 |
+
"<|action_token_341|>": 152015,
|
| 1328 |
+
"<|action_token_342|>": 152016,
|
| 1329 |
+
"<|action_token_343|>": 152017,
|
| 1330 |
+
"<|action_token_344|>": 152018,
|
| 1331 |
+
"<|action_token_345|>": 152019,
|
| 1332 |
+
"<|action_token_346|>": 152020,
|
| 1333 |
+
"<|action_token_347|>": 152021,
|
| 1334 |
+
"<|action_token_348|>": 152022,
|
| 1335 |
+
"<|action_token_349|>": 152023,
|
| 1336 |
+
"<|action_token_34|>": 151708,
|
| 1337 |
+
"<|action_token_350|>": 152024,
|
| 1338 |
+
"<|action_token_351|>": 152025,
|
| 1339 |
+
"<|action_token_352|>": 152026,
|
| 1340 |
+
"<|action_token_353|>": 152027,
|
| 1341 |
+
"<|action_token_354|>": 152028,
|
| 1342 |
+
"<|action_token_355|>": 152029,
|
| 1343 |
+
"<|action_token_356|>": 152030,
|
| 1344 |
+
"<|action_token_357|>": 152031,
|
| 1345 |
+
"<|action_token_358|>": 152032,
|
| 1346 |
+
"<|action_token_359|>": 152033,
|
| 1347 |
+
"<|action_token_35|>": 151709,
|
| 1348 |
+
"<|action_token_360|>": 152034,
|
| 1349 |
+
"<|action_token_361|>": 152035,
|
| 1350 |
+
"<|action_token_362|>": 152036,
|
| 1351 |
+
"<|action_token_363|>": 152037,
|
| 1352 |
+
"<|action_token_364|>": 152038,
|
| 1353 |
+
"<|action_token_365|>": 152039,
|
| 1354 |
+
"<|action_token_366|>": 152040,
|
| 1355 |
+
"<|action_token_367|>": 152041,
|
| 1356 |
+
"<|action_token_368|>": 152042,
|
| 1357 |
+
"<|action_token_369|>": 152043,
|
| 1358 |
+
"<|action_token_36|>": 151710,
|
| 1359 |
+
"<|action_token_370|>": 152044,
|
| 1360 |
+
"<|action_token_371|>": 152045,
|
| 1361 |
+
"<|action_token_372|>": 152046,
|
| 1362 |
+
"<|action_token_373|>": 152047,
|
| 1363 |
+
"<|action_token_374|>": 152048,
|
| 1364 |
+
"<|action_token_375|>": 152049,
|
| 1365 |
+
"<|action_token_376|>": 152050,
|
| 1366 |
+
"<|action_token_377|>": 152051,
|
| 1367 |
+
"<|action_token_378|>": 152052,
|
| 1368 |
+
"<|action_token_379|>": 152053,
|
| 1369 |
+
"<|action_token_37|>": 151711,
|
| 1370 |
+
"<|action_token_380|>": 152054,
|
| 1371 |
+
"<|action_token_381|>": 152055,
|
| 1372 |
+
"<|action_token_382|>": 152056,
|
| 1373 |
+
"<|action_token_383|>": 152057,
|
| 1374 |
+
"<|action_token_384|>": 152058,
|
| 1375 |
+
"<|action_token_385|>": 152059,
|
| 1376 |
+
"<|action_token_386|>": 152060,
|
| 1377 |
+
"<|action_token_387|>": 152061,
|
| 1378 |
+
"<|action_token_388|>": 152062,
|
| 1379 |
+
"<|action_token_389|>": 152063,
|
| 1380 |
+
"<|action_token_38|>": 151712,
|
| 1381 |
+
"<|action_token_390|>": 152064,
|
| 1382 |
+
"<|action_token_391|>": 152065,
|
| 1383 |
+
"<|action_token_392|>": 152066,
|
| 1384 |
+
"<|action_token_393|>": 152067,
|
| 1385 |
+
"<|action_token_394|>": 152068,
|
| 1386 |
+
"<|action_token_395|>": 152069,
|
| 1387 |
+
"<|action_token_396|>": 152070,
|
| 1388 |
+
"<|action_token_397|>": 152071,
|
| 1389 |
+
"<|action_token_398|>": 152072,
|
| 1390 |
+
"<|action_token_399|>": 152073,
|
| 1391 |
+
"<|action_token_39|>": 151713,
|
| 1392 |
+
"<|action_token_3|>": 151677,
|
| 1393 |
+
"<|action_token_400|>": 152074,
|
| 1394 |
+
"<|action_token_401|>": 152075,
|
| 1395 |
+
"<|action_token_402|>": 152076,
|
| 1396 |
+
"<|action_token_403|>": 152077,
|
| 1397 |
+
"<|action_token_404|>": 152078,
|
| 1398 |
+
"<|action_token_405|>": 152079,
|
| 1399 |
+
"<|action_token_406|>": 152080,
|
| 1400 |
+
"<|action_token_407|>": 152081,
|
| 1401 |
+
"<|action_token_408|>": 152082,
|
| 1402 |
+
"<|action_token_409|>": 152083,
|
| 1403 |
+
"<|action_token_40|>": 151714,
|
| 1404 |
+
"<|action_token_410|>": 152084,
|
| 1405 |
+
"<|action_token_411|>": 152085,
|
| 1406 |
+
"<|action_token_412|>": 152086,
|
| 1407 |
+
"<|action_token_413|>": 152087,
|
| 1408 |
+
"<|action_token_414|>": 152088,
|
| 1409 |
+
"<|action_token_415|>": 152089,
|
| 1410 |
+
"<|action_token_416|>": 152090,
|
| 1411 |
+
"<|action_token_417|>": 152091,
|
| 1412 |
+
"<|action_token_418|>": 152092,
|
| 1413 |
+
"<|action_token_419|>": 152093,
|
| 1414 |
+
"<|action_token_41|>": 151715,
|
| 1415 |
+
"<|action_token_420|>": 152094,
|
| 1416 |
+
"<|action_token_421|>": 152095,
|
| 1417 |
+
"<|action_token_422|>": 152096,
|
| 1418 |
+
"<|action_token_423|>": 152097,
|
| 1419 |
+
"<|action_token_424|>": 152098,
|
| 1420 |
+
"<|action_token_425|>": 152099,
|
| 1421 |
+
"<|action_token_426|>": 152100,
|
| 1422 |
+
"<|action_token_427|>": 152101,
|
| 1423 |
+
"<|action_token_428|>": 152102,
|
| 1424 |
+
"<|action_token_429|>": 152103,
|
| 1425 |
+
"<|action_token_42|>": 151716,
|
| 1426 |
+
"<|action_token_430|>": 152104,
|
| 1427 |
+
"<|action_token_431|>": 152105,
|
| 1428 |
+
"<|action_token_432|>": 152106,
|
| 1429 |
+
"<|action_token_433|>": 152107,
|
| 1430 |
+
"<|action_token_434|>": 152108,
|
| 1431 |
+
"<|action_token_435|>": 152109,
|
| 1432 |
+
"<|action_token_436|>": 152110,
|
| 1433 |
+
"<|action_token_437|>": 152111,
|
| 1434 |
+
"<|action_token_438|>": 152112,
|
| 1435 |
+
"<|action_token_439|>": 152113,
|
| 1436 |
+
"<|action_token_43|>": 151717,
|
| 1437 |
+
"<|action_token_440|>": 152114,
|
| 1438 |
+
"<|action_token_441|>": 152115,
|
| 1439 |
+
"<|action_token_442|>": 152116,
|
| 1440 |
+
"<|action_token_443|>": 152117,
|
| 1441 |
+
"<|action_token_444|>": 152118,
|
| 1442 |
+
"<|action_token_445|>": 152119,
|
| 1443 |
+
"<|action_token_446|>": 152120,
|
| 1444 |
+
"<|action_token_447|>": 152121,
|
| 1445 |
+
"<|action_token_448|>": 152122,
|
| 1446 |
+
"<|action_token_449|>": 152123,
|
| 1447 |
+
"<|action_token_44|>": 151718,
|
| 1448 |
+
"<|action_token_450|>": 152124,
|
| 1449 |
+
"<|action_token_451|>": 152125,
|
| 1450 |
+
"<|action_token_452|>": 152126,
|
| 1451 |
+
"<|action_token_453|>": 152127,
|
| 1452 |
+
"<|action_token_454|>": 152128,
|
| 1453 |
+
"<|action_token_455|>": 152129,
|
| 1454 |
+
"<|action_token_456|>": 152130,
|
| 1455 |
+
"<|action_token_457|>": 152131,
|
| 1456 |
+
"<|action_token_458|>": 152132,
|
| 1457 |
+
"<|action_token_459|>": 152133,
|
| 1458 |
+
"<|action_token_45|>": 151719,
|
| 1459 |
+
"<|action_token_460|>": 152134,
|
| 1460 |
+
"<|action_token_461|>": 152135,
|
| 1461 |
+
"<|action_token_462|>": 152136,
|
| 1462 |
+
"<|action_token_463|>": 152137,
|
| 1463 |
+
"<|action_token_464|>": 152138,
|
| 1464 |
+
"<|action_token_465|>": 152139,
|
| 1465 |
+
"<|action_token_466|>": 152140,
|
| 1466 |
+
"<|action_token_467|>": 152141,
|
| 1467 |
+
"<|action_token_468|>": 152142,
|
| 1468 |
+
"<|action_token_469|>": 152143,
|
| 1469 |
+
"<|action_token_46|>": 151720,
|
| 1470 |
+
"<|action_token_470|>": 152144,
|
| 1471 |
+
"<|action_token_471|>": 152145,
|
| 1472 |
+
"<|action_token_472|>": 152146,
|
| 1473 |
+
"<|action_token_473|>": 152147,
|
| 1474 |
+
"<|action_token_474|>": 152148,
|
| 1475 |
+
"<|action_token_475|>": 152149,
|
| 1476 |
+
"<|action_token_476|>": 152150,
|
| 1477 |
+
"<|action_token_477|>": 152151,
|
| 1478 |
+
"<|action_token_478|>": 152152,
|
| 1479 |
+
"<|action_token_479|>": 152153,
|
| 1480 |
+
"<|action_token_47|>": 151721,
|
| 1481 |
+
"<|action_token_480|>": 152154,
|
| 1482 |
+
"<|action_token_481|>": 152155,
|
| 1483 |
+
"<|action_token_482|>": 152156,
|
| 1484 |
+
"<|action_token_483|>": 152157,
|
| 1485 |
+
"<|action_token_484|>": 152158,
|
| 1486 |
+
"<|action_token_485|>": 152159,
|
| 1487 |
+
"<|action_token_486|>": 152160,
|
| 1488 |
+
"<|action_token_487|>": 152161,
|
| 1489 |
+
"<|action_token_488|>": 152162,
|
| 1490 |
+
"<|action_token_489|>": 152163,
|
| 1491 |
+
"<|action_token_48|>": 151722,
|
| 1492 |
+
"<|action_token_490|>": 152164,
|
| 1493 |
+
"<|action_token_491|>": 152165,
|
| 1494 |
+
"<|action_token_492|>": 152166,
|
| 1495 |
+
"<|action_token_493|>": 152167,
|
| 1496 |
+
"<|action_token_494|>": 152168,
|
| 1497 |
+
"<|action_token_495|>": 152169,
|
| 1498 |
+
"<|action_token_496|>": 152170,
|
| 1499 |
+
"<|action_token_497|>": 152171,
|
| 1500 |
+
"<|action_token_498|>": 152172,
|
| 1501 |
+
"<|action_token_499|>": 152173,
|
| 1502 |
+
"<|action_token_49|>": 151723,
|
| 1503 |
+
"<|action_token_4|>": 151678,
|
| 1504 |
+
"<|action_token_500|>": 152174,
|
| 1505 |
+
"<|action_token_501|>": 152175,
|
| 1506 |
+
"<|action_token_502|>": 152176,
|
| 1507 |
+
"<|action_token_503|>": 152177,
|
| 1508 |
+
"<|action_token_504|>": 152178,
|
| 1509 |
+
"<|action_token_505|>": 152179,
|
| 1510 |
+
"<|action_token_506|>": 152180,
|
| 1511 |
+
"<|action_token_507|>": 152181,
|
| 1512 |
+
"<|action_token_508|>": 152182,
|
| 1513 |
+
"<|action_token_509|>": 152183,
|
| 1514 |
+
"<|action_token_50|>": 151724,
|
| 1515 |
+
"<|action_token_510|>": 152184,
|
| 1516 |
+
"<|action_token_511|>": 152185,
|
| 1517 |
+
"<|action_token_512|>": 152186,
|
| 1518 |
+
"<|action_token_513|>": 152187,
|
| 1519 |
+
"<|action_token_514|>": 152188,
|
| 1520 |
+
"<|action_token_515|>": 152189,
|
| 1521 |
+
"<|action_token_516|>": 152190,
|
| 1522 |
+
"<|action_token_517|>": 152191,
|
| 1523 |
+
"<|action_token_518|>": 152192,
|
| 1524 |
+
"<|action_token_519|>": 152193,
|
| 1525 |
+
"<|action_token_51|>": 151725,
|
| 1526 |
+
"<|action_token_520|>": 152194,
|
| 1527 |
+
"<|action_token_521|>": 152195,
|
| 1528 |
+
"<|action_token_522|>": 152196,
|
| 1529 |
+
"<|action_token_523|>": 152197,
|
| 1530 |
+
"<|action_token_524|>": 152198,
|
| 1531 |
+
"<|action_token_525|>": 152199,
|
| 1532 |
+
"<|action_token_526|>": 152200,
|
| 1533 |
+
"<|action_token_527|>": 152201,
|
| 1534 |
+
"<|action_token_528|>": 152202,
|
| 1535 |
+
"<|action_token_529|>": 152203,
|
| 1536 |
+
"<|action_token_52|>": 151726,
|
| 1537 |
+
"<|action_token_530|>": 152204,
|
| 1538 |
+
"<|action_token_531|>": 152205,
|
| 1539 |
+
"<|action_token_532|>": 152206,
|
| 1540 |
+
"<|action_token_533|>": 152207,
|
| 1541 |
+
"<|action_token_534|>": 152208,
|
| 1542 |
+
"<|action_token_535|>": 152209,
|
| 1543 |
+
"<|action_token_536|>": 152210,
|
| 1544 |
+
"<|action_token_537|>": 152211,
|
| 1545 |
+
"<|action_token_538|>": 152212,
|
| 1546 |
+
"<|action_token_539|>": 152213,
|
| 1547 |
+
"<|action_token_53|>": 151727,
|
| 1548 |
+
"<|action_token_540|>": 152214,
|
| 1549 |
+
"<|action_token_541|>": 152215,
|
| 1550 |
+
"<|action_token_542|>": 152216,
|
| 1551 |
+
"<|action_token_543|>": 152217,
|
| 1552 |
+
"<|action_token_544|>": 152218,
|
| 1553 |
+
"<|action_token_545|>": 152219,
|
| 1554 |
+
"<|action_token_546|>": 152220,
|
| 1555 |
+
"<|action_token_547|>": 152221,
|
| 1556 |
+
"<|action_token_548|>": 152222,
|
| 1557 |
+
"<|action_token_549|>": 152223,
|
| 1558 |
+
"<|action_token_54|>": 151728,
|
| 1559 |
+
"<|action_token_550|>": 152224,
|
| 1560 |
+
"<|action_token_551|>": 152225,
|
| 1561 |
+
"<|action_token_552|>": 152226,
|
| 1562 |
+
"<|action_token_553|>": 152227,
|
| 1563 |
+
"<|action_token_554|>": 152228,
|
| 1564 |
+
"<|action_token_555|>": 152229,
|
| 1565 |
+
"<|action_token_556|>": 152230,
|
| 1566 |
+
"<|action_token_557|>": 152231,
|
| 1567 |
+
"<|action_token_558|>": 152232,
|
| 1568 |
+
"<|action_token_559|>": 152233,
|
| 1569 |
+
"<|action_token_55|>": 151729,
|
| 1570 |
+
"<|action_token_560|>": 152234,
|
| 1571 |
+
"<|action_token_561|>": 152235,
|
| 1572 |
+
"<|action_token_562|>": 152236,
|
| 1573 |
+
"<|action_token_563|>": 152237,
|
| 1574 |
+
"<|action_token_564|>": 152238,
|
| 1575 |
+
"<|action_token_565|>": 152239,
|
| 1576 |
+
"<|action_token_566|>": 152240,
|
| 1577 |
+
"<|action_token_567|>": 152241,
|
| 1578 |
+
"<|action_token_568|>": 152242,
|
| 1579 |
+
"<|action_token_569|>": 152243,
|
| 1580 |
+
"<|action_token_56|>": 151730,
|
| 1581 |
+
"<|action_token_570|>": 152244,
|
| 1582 |
+
"<|action_token_571|>": 152245,
|
| 1583 |
+
"<|action_token_572|>": 152246,
|
| 1584 |
+
"<|action_token_573|>": 152247,
|
| 1585 |
+
"<|action_token_574|>": 152248,
|
| 1586 |
+
"<|action_token_575|>": 152249,
|
| 1587 |
+
"<|action_token_576|>": 152250,
|
| 1588 |
+
"<|action_token_577|>": 152251,
|
| 1589 |
+
"<|action_token_578|>": 152252,
|
| 1590 |
+
"<|action_token_579|>": 152253,
|
| 1591 |
+
"<|action_token_57|>": 151731,
|
| 1592 |
+
"<|action_token_580|>": 152254,
|
| 1593 |
+
"<|action_token_581|>": 152255,
|
| 1594 |
+
"<|action_token_582|>": 152256,
|
| 1595 |
+
"<|action_token_583|>": 152257,
|
| 1596 |
+
"<|action_token_584|>": 152258,
|
| 1597 |
+
"<|action_token_585|>": 152259,
|
| 1598 |
+
"<|action_token_586|>": 152260,
|
| 1599 |
+
"<|action_token_587|>": 152261,
|
| 1600 |
+
"<|action_token_588|>": 152262,
|
| 1601 |
+
"<|action_token_589|>": 152263,
|
| 1602 |
+
"<|action_token_58|>": 151732,
|
| 1603 |
+
"<|action_token_590|>": 152264,
|
| 1604 |
+
"<|action_token_591|>": 152265,
|
| 1605 |
+
"<|action_token_592|>": 152266,
|
| 1606 |
+
"<|action_token_593|>": 152267,
|
| 1607 |
+
"<|action_token_594|>": 152268,
|
| 1608 |
+
"<|action_token_595|>": 152269,
|
| 1609 |
+
"<|action_token_596|>": 152270,
|
| 1610 |
+
"<|action_token_597|>": 152271,
|
| 1611 |
+
"<|action_token_598|>": 152272,
|
| 1612 |
+
"<|action_token_599|>": 152273,
|
| 1613 |
+
"<|action_token_59|>": 151733,
|
| 1614 |
+
"<|action_token_5|>": 151679,
|
| 1615 |
+
"<|action_token_600|>": 152274,
|
| 1616 |
+
"<|action_token_601|>": 152275,
|
| 1617 |
+
"<|action_token_602|>": 152276,
|
| 1618 |
+
"<|action_token_603|>": 152277,
|
| 1619 |
+
"<|action_token_604|>": 152278,
|
| 1620 |
+
"<|action_token_605|>": 152279,
|
| 1621 |
+
"<|action_token_606|>": 152280,
|
| 1622 |
+
"<|action_token_607|>": 152281,
|
| 1623 |
+
"<|action_token_608|>": 152282,
|
| 1624 |
+
"<|action_token_609|>": 152283,
|
| 1625 |
+
"<|action_token_60|>": 151734,
|
| 1626 |
+
"<|action_token_610|>": 152284,
|
| 1627 |
+
"<|action_token_611|>": 152285,
|
| 1628 |
+
"<|action_token_612|>": 152286,
|
| 1629 |
+
"<|action_token_613|>": 152287,
|
| 1630 |
+
"<|action_token_614|>": 152288,
|
| 1631 |
+
"<|action_token_615|>": 152289,
|
| 1632 |
+
"<|action_token_616|>": 152290,
|
| 1633 |
+
"<|action_token_617|>": 152291,
|
| 1634 |
+
"<|action_token_618|>": 152292,
|
| 1635 |
+
"<|action_token_619|>": 152293,
|
| 1636 |
+
"<|action_token_61|>": 151735,
|
| 1637 |
+
"<|action_token_620|>": 152294,
|
| 1638 |
+
"<|action_token_621|>": 152295,
|
| 1639 |
+
"<|action_token_622|>": 152296,
|
| 1640 |
+
"<|action_token_623|>": 152297,
|
| 1641 |
+
"<|action_token_624|>": 152298,
|
| 1642 |
+
"<|action_token_625|>": 152299,
|
| 1643 |
+
"<|action_token_626|>": 152300,
|
| 1644 |
+
"<|action_token_627|>": 152301,
|
| 1645 |
+
"<|action_token_628|>": 152302,
|
| 1646 |
+
"<|action_token_629|>": 152303,
|
| 1647 |
+
"<|action_token_62|>": 151736,
|
| 1648 |
+
"<|action_token_630|>": 152304,
|
| 1649 |
+
"<|action_token_631|>": 152305,
|
| 1650 |
+
"<|action_token_632|>": 152306,
|
| 1651 |
+
"<|action_token_633|>": 152307,
|
| 1652 |
+
"<|action_token_634|>": 152308,
|
| 1653 |
+
"<|action_token_635|>": 152309,
|
| 1654 |
+
"<|action_token_636|>": 152310,
|
| 1655 |
+
"<|action_token_637|>": 152311,
|
| 1656 |
+
"<|action_token_638|>": 152312,
|
| 1657 |
+
"<|action_token_639|>": 152313,
|
| 1658 |
+
"<|action_token_63|>": 151737,
|
| 1659 |
+
"<|action_token_640|>": 152314,
|
| 1660 |
+
"<|action_token_641|>": 152315,
|
| 1661 |
+
"<|action_token_642|>": 152316,
|
| 1662 |
+
"<|action_token_643|>": 152317,
|
| 1663 |
+
"<|action_token_644|>": 152318,
|
| 1664 |
+
"<|action_token_645|>": 152319,
|
| 1665 |
+
"<|action_token_646|>": 152320,
|
| 1666 |
+
"<|action_token_647|>": 152321,
|
| 1667 |
+
"<|action_token_648|>": 152322,
|
| 1668 |
+
"<|action_token_649|>": 152323,
|
| 1669 |
+
"<|action_token_64|>": 151738,
|
| 1670 |
+
"<|action_token_650|>": 152324,
|
| 1671 |
+
"<|action_token_651|>": 152325,
|
| 1672 |
+
"<|action_token_652|>": 152326,
|
| 1673 |
+
"<|action_token_653|>": 152327,
|
| 1674 |
+
"<|action_token_654|>": 152328,
|
| 1675 |
+
"<|action_token_655|>": 152329,
|
| 1676 |
+
"<|action_token_656|>": 152330,
|
| 1677 |
+
"<|action_token_657|>": 152331,
|
| 1678 |
+
"<|action_token_658|>": 152332,
|
| 1679 |
+
"<|action_token_659|>": 152333,
|
| 1680 |
+
"<|action_token_65|>": 151739,
|
| 1681 |
+
"<|action_token_660|>": 152334,
|
| 1682 |
+
"<|action_token_661|>": 152335,
|
| 1683 |
+
"<|action_token_662|>": 152336,
|
| 1684 |
+
"<|action_token_663|>": 152337,
|
| 1685 |
+
"<|action_token_664|>": 152338,
|
| 1686 |
+
"<|action_token_665|>": 152339,
|
| 1687 |
+
"<|action_token_666|>": 152340,
|
| 1688 |
+
"<|action_token_667|>": 152341,
|
| 1689 |
+
"<|action_token_668|>": 152342,
|
| 1690 |
+
"<|action_token_669|>": 152343,
|
| 1691 |
+
"<|action_token_66|>": 151740,
|
| 1692 |
+
"<|action_token_670|>": 152344,
|
| 1693 |
+
"<|action_token_671|>": 152345,
|
| 1694 |
+
"<|action_token_672|>": 152346,
|
| 1695 |
+
"<|action_token_673|>": 152347,
|
| 1696 |
+
"<|action_token_674|>": 152348,
|
| 1697 |
+
"<|action_token_675|>": 152349,
|
| 1698 |
+
"<|action_token_676|>": 152350,
|
| 1699 |
+
"<|action_token_677|>": 152351,
|
| 1700 |
+
"<|action_token_678|>": 152352,
|
| 1701 |
+
"<|action_token_679|>": 152353,
|
| 1702 |
+
"<|action_token_67|>": 151741,
|
| 1703 |
+
"<|action_token_680|>": 152354,
|
| 1704 |
+
"<|action_token_681|>": 152355,
|
| 1705 |
+
"<|action_token_682|>": 152356,
|
| 1706 |
+
"<|action_token_683|>": 152357,
|
| 1707 |
+
"<|action_token_684|>": 152358,
|
| 1708 |
+
"<|action_token_685|>": 152359,
|
| 1709 |
+
"<|action_token_686|>": 152360,
|
| 1710 |
+
"<|action_token_687|>": 152361,
|
| 1711 |
+
"<|action_token_688|>": 152362,
|
| 1712 |
+
"<|action_token_689|>": 152363,
|
| 1713 |
+
"<|action_token_68|>": 151742,
|
| 1714 |
+
"<|action_token_690|>": 152364,
|
| 1715 |
+
"<|action_token_691|>": 152365,
|
| 1716 |
+
"<|action_token_692|>": 152366,
|
| 1717 |
+
"<|action_token_693|>": 152367,
|
| 1718 |
+
"<|action_token_694|>": 152368,
|
| 1719 |
+
"<|action_token_695|>": 152369,
|
| 1720 |
+
"<|action_token_696|>": 152370,
|
| 1721 |
+
"<|action_token_697|>": 152371,
|
| 1722 |
+
"<|action_token_698|>": 152372,
|
| 1723 |
+
"<|action_token_699|>": 152373,
|
| 1724 |
+
"<|action_token_69|>": 151743,
|
| 1725 |
+
"<|action_token_6|>": 151680,
|
| 1726 |
+
"<|action_token_700|>": 152374,
|
| 1727 |
+
"<|action_token_701|>": 152375,
|
| 1728 |
+
"<|action_token_702|>": 152376,
|
| 1729 |
+
"<|action_token_703|>": 152377,
|
| 1730 |
+
"<|action_token_704|>": 152378,
|
| 1731 |
+
"<|action_token_705|>": 152379,
|
| 1732 |
+
"<|action_token_706|>": 152380,
|
| 1733 |
+
"<|action_token_707|>": 152381,
|
| 1734 |
+
"<|action_token_708|>": 152382,
|
| 1735 |
+
"<|action_token_709|>": 152383,
|
| 1736 |
+
"<|action_token_70|>": 151744,
|
| 1737 |
+
"<|action_token_710|>": 152384,
|
| 1738 |
+
"<|action_token_711|>": 152385,
|
| 1739 |
+
"<|action_token_712|>": 152386,
|
| 1740 |
+
"<|action_token_713|>": 152387,
|
| 1741 |
+
"<|action_token_714|>": 152388,
|
| 1742 |
+
"<|action_token_715|>": 152389,
|
| 1743 |
+
"<|action_token_716|>": 152390,
|
| 1744 |
+
"<|action_token_717|>": 152391,
|
| 1745 |
+
"<|action_token_718|>": 152392,
|
| 1746 |
+
"<|action_token_719|>": 152393,
|
| 1747 |
+
"<|action_token_71|>": 151745,
|
| 1748 |
+
"<|action_token_720|>": 152394,
|
| 1749 |
+
"<|action_token_721|>": 152395,
|
| 1750 |
+
"<|action_token_722|>": 152396,
|
| 1751 |
+
"<|action_token_723|>": 152397,
|
| 1752 |
+
"<|action_token_724|>": 152398,
|
| 1753 |
+
"<|action_token_725|>": 152399,
|
| 1754 |
+
"<|action_token_726|>": 152400,
|
| 1755 |
+
"<|action_token_727|>": 152401,
|
| 1756 |
+
"<|action_token_728|>": 152402,
|
| 1757 |
+
"<|action_token_729|>": 152403,
|
| 1758 |
+
"<|action_token_72|>": 151746,
|
| 1759 |
+
"<|action_token_730|>": 152404,
|
| 1760 |
+
"<|action_token_731|>": 152405,
|
| 1761 |
+
"<|action_token_732|>": 152406,
|
| 1762 |
+
"<|action_token_733|>": 152407,
|
| 1763 |
+
"<|action_token_734|>": 152408,
|
| 1764 |
+
"<|action_token_735|>": 152409,
|
| 1765 |
+
"<|action_token_736|>": 152410,
|
| 1766 |
+
"<|action_token_737|>": 152411,
|
| 1767 |
+
"<|action_token_738|>": 152412,
|
| 1768 |
+
"<|action_token_739|>": 152413,
|
| 1769 |
+
"<|action_token_73|>": 151747,
|
| 1770 |
+
"<|action_token_740|>": 152414,
|
| 1771 |
+
"<|action_token_741|>": 152415,
|
| 1772 |
+
"<|action_token_742|>": 152416,
|
| 1773 |
+
"<|action_token_743|>": 152417,
|
| 1774 |
+
"<|action_token_744|>": 152418,
|
| 1775 |
+
"<|action_token_745|>": 152419,
|
| 1776 |
+
"<|action_token_746|>": 152420,
|
| 1777 |
+
"<|action_token_747|>": 152421,
|
| 1778 |
+
"<|action_token_748|>": 152422,
|
| 1779 |
+
"<|action_token_749|>": 152423,
|
| 1780 |
+
"<|action_token_74|>": 151748,
|
| 1781 |
+
"<|action_token_750|>": 152424,
|
| 1782 |
+
"<|action_token_751|>": 152425,
|
| 1783 |
+
"<|action_token_752|>": 152426,
|
| 1784 |
+
"<|action_token_753|>": 152427,
|
| 1785 |
+
"<|action_token_754|>": 152428,
|
| 1786 |
+
"<|action_token_755|>": 152429,
|
| 1787 |
+
"<|action_token_756|>": 152430,
|
| 1788 |
+
"<|action_token_757|>": 152431,
|
| 1789 |
+
"<|action_token_758|>": 152432,
|
| 1790 |
+
"<|action_token_759|>": 152433,
|
| 1791 |
+
"<|action_token_75|>": 151749,
|
| 1792 |
+
"<|action_token_760|>": 152434,
|
| 1793 |
+
"<|action_token_761|>": 152435,
|
| 1794 |
+
"<|action_token_762|>": 152436,
|
| 1795 |
+
"<|action_token_763|>": 152437,
|
| 1796 |
+
"<|action_token_764|>": 152438,
|
| 1797 |
+
"<|action_token_765|>": 152439,
|
| 1798 |
+
"<|action_token_766|>": 152440,
|
| 1799 |
+
"<|action_token_767|>": 152441,
|
| 1800 |
+
"<|action_token_768|>": 152442,
|
| 1801 |
+
"<|action_token_769|>": 152443,
|
| 1802 |
+
"<|action_token_76|>": 151750,
|
| 1803 |
+
"<|action_token_770|>": 152444,
|
| 1804 |
+
"<|action_token_771|>": 152445,
|
| 1805 |
+
"<|action_token_772|>": 152446,
|
| 1806 |
+
"<|action_token_773|>": 152447,
|
| 1807 |
+
"<|action_token_774|>": 152448,
|
| 1808 |
+
"<|action_token_775|>": 152449,
|
| 1809 |
+
"<|action_token_776|>": 152450,
|
| 1810 |
+
"<|action_token_777|>": 152451,
|
| 1811 |
+
"<|action_token_778|>": 152452,
|
| 1812 |
+
"<|action_token_779|>": 152453,
|
| 1813 |
+
"<|action_token_77|>": 151751,
|
| 1814 |
+
"<|action_token_780|>": 152454,
|
| 1815 |
+
"<|action_token_781|>": 152455,
|
| 1816 |
+
"<|action_token_782|>": 152456,
|
| 1817 |
+
"<|action_token_783|>": 152457,
|
| 1818 |
+
"<|action_token_784|>": 152458,
|
| 1819 |
+
"<|action_token_785|>": 152459,
|
| 1820 |
+
"<|action_token_786|>": 152460,
|
| 1821 |
+
"<|action_token_787|>": 152461,
|
| 1822 |
+
"<|action_token_788|>": 152462,
|
| 1823 |
+
"<|action_token_789|>": 152463,
|
| 1824 |
+
"<|action_token_78|>": 151752,
|
| 1825 |
+
"<|action_token_790|>": 152464,
|
| 1826 |
+
"<|action_token_791|>": 152465,
|
| 1827 |
+
"<|action_token_792|>": 152466,
|
| 1828 |
+
"<|action_token_793|>": 152467,
|
| 1829 |
+
"<|action_token_794|>": 152468,
|
| 1830 |
+
"<|action_token_795|>": 152469,
|
| 1831 |
+
"<|action_token_796|>": 152470,
|
| 1832 |
+
"<|action_token_797|>": 152471,
|
| 1833 |
+
"<|action_token_798|>": 152472,
|
| 1834 |
+
"<|action_token_799|>": 152473,
|
| 1835 |
+
"<|action_token_79|>": 151753,
|
| 1836 |
+
"<|action_token_7|>": 151681,
|
| 1837 |
+
"<|action_token_800|>": 152474,
|
| 1838 |
+
"<|action_token_801|>": 152475,
|
| 1839 |
+
"<|action_token_802|>": 152476,
|
| 1840 |
+
"<|action_token_803|>": 152477,
|
| 1841 |
+
"<|action_token_804|>": 152478,
|
| 1842 |
+
"<|action_token_805|>": 152479,
|
| 1843 |
+
"<|action_token_806|>": 152480,
|
| 1844 |
+
"<|action_token_807|>": 152481,
|
| 1845 |
+
"<|action_token_808|>": 152482,
|
| 1846 |
+
"<|action_token_809|>": 152483,
|
| 1847 |
+
"<|action_token_80|>": 151754,
|
| 1848 |
+
"<|action_token_810|>": 152484,
|
| 1849 |
+
"<|action_token_811|>": 152485,
|
| 1850 |
+
"<|action_token_812|>": 152486,
|
| 1851 |
+
"<|action_token_813|>": 152487,
|
| 1852 |
+
"<|action_token_814|>": 152488,
|
| 1853 |
+
"<|action_token_815|>": 152489,
|
| 1854 |
+
"<|action_token_816|>": 152490,
|
| 1855 |
+
"<|action_token_817|>": 152491,
|
| 1856 |
+
"<|action_token_818|>": 152492,
|
| 1857 |
+
"<|action_token_819|>": 152493,
|
| 1858 |
+
"<|action_token_81|>": 151755,
|
| 1859 |
+
"<|action_token_820|>": 152494,
|
| 1860 |
+
"<|action_token_821|>": 152495,
|
| 1861 |
+
"<|action_token_822|>": 152496,
|
| 1862 |
+
"<|action_token_823|>": 152497,
|
| 1863 |
+
"<|action_token_824|>": 152498,
|
| 1864 |
+
"<|action_token_825|>": 152499,
|
| 1865 |
+
"<|action_token_826|>": 152500,
|
| 1866 |
+
"<|action_token_827|>": 152501,
|
| 1867 |
+
"<|action_token_828|>": 152502,
|
| 1868 |
+
"<|action_token_829|>": 152503,
|
| 1869 |
+
"<|action_token_82|>": 151756,
|
| 1870 |
+
"<|action_token_830|>": 152504,
|
| 1871 |
+
"<|action_token_831|>": 152505,
|
| 1872 |
+
"<|action_token_832|>": 152506,
|
| 1873 |
+
"<|action_token_833|>": 152507,
|
| 1874 |
+
"<|action_token_834|>": 152508,
|
| 1875 |
+
"<|action_token_835|>": 152509,
|
| 1876 |
+
"<|action_token_836|>": 152510,
|
| 1877 |
+
"<|action_token_837|>": 152511,
|
| 1878 |
+
"<|action_token_838|>": 152512,
|
| 1879 |
+
"<|action_token_839|>": 152513,
|
| 1880 |
+
"<|action_token_83|>": 151757,
|
| 1881 |
+
"<|action_token_840|>": 152514,
|
| 1882 |
+
"<|action_token_841|>": 152515,
|
| 1883 |
+
"<|action_token_842|>": 152516,
|
| 1884 |
+
"<|action_token_843|>": 152517,
|
| 1885 |
+
"<|action_token_844|>": 152518,
|
| 1886 |
+
"<|action_token_845|>": 152519,
|
| 1887 |
+
"<|action_token_846|>": 152520,
|
| 1888 |
+
"<|action_token_847|>": 152521,
|
| 1889 |
+
"<|action_token_848|>": 152522,
|
| 1890 |
+
"<|action_token_849|>": 152523,
|
| 1891 |
+
"<|action_token_84|>": 151758,
|
| 1892 |
+
"<|action_token_850|>": 152524,
|
| 1893 |
+
"<|action_token_851|>": 152525,
|
| 1894 |
+
"<|action_token_852|>": 152526,
|
| 1895 |
+
"<|action_token_853|>": 152527,
|
| 1896 |
+
"<|action_token_854|>": 152528,
|
| 1897 |
+
"<|action_token_855|>": 152529,
|
| 1898 |
+
"<|action_token_856|>": 152530,
|
| 1899 |
+
"<|action_token_857|>": 152531,
|
| 1900 |
+
"<|action_token_858|>": 152532,
|
| 1901 |
+
"<|action_token_859|>": 152533,
|
| 1902 |
+
"<|action_token_85|>": 151759,
|
| 1903 |
+
"<|action_token_860|>": 152534,
|
| 1904 |
+
"<|action_token_861|>": 152535,
|
| 1905 |
+
"<|action_token_862|>": 152536,
|
| 1906 |
+
"<|action_token_863|>": 152537,
|
| 1907 |
+
"<|action_token_864|>": 152538,
|
| 1908 |
+
"<|action_token_865|>": 152539,
|
| 1909 |
+
"<|action_token_866|>": 152540,
|
| 1910 |
+
"<|action_token_867|>": 152541,
|
| 1911 |
+
"<|action_token_868|>": 152542,
|
| 1912 |
+
"<|action_token_869|>": 152543,
|
| 1913 |
+
"<|action_token_86|>": 151760,
|
| 1914 |
+
"<|action_token_870|>": 152544,
|
| 1915 |
+
"<|action_token_871|>": 152545,
|
| 1916 |
+
"<|action_token_872|>": 152546,
|
| 1917 |
+
"<|action_token_873|>": 152547,
|
| 1918 |
+
"<|action_token_874|>": 152548,
|
| 1919 |
+
"<|action_token_875|>": 152549,
|
| 1920 |
+
"<|action_token_876|>": 152550,
|
| 1921 |
+
"<|action_token_877|>": 152551,
|
| 1922 |
+
"<|action_token_878|>": 152552,
|
| 1923 |
+
"<|action_token_879|>": 152553,
|
| 1924 |
+
"<|action_token_87|>": 151761,
|
| 1925 |
+
"<|action_token_880|>": 152554,
|
| 1926 |
+
"<|action_token_881|>": 152555,
|
| 1927 |
+
"<|action_token_882|>": 152556,
|
| 1928 |
+
"<|action_token_883|>": 152557,
|
| 1929 |
+
"<|action_token_884|>": 152558,
|
| 1930 |
+
"<|action_token_885|>": 152559,
|
| 1931 |
+
"<|action_token_886|>": 152560,
|
| 1932 |
+
"<|action_token_887|>": 152561,
|
| 1933 |
+
"<|action_token_888|>": 152562,
|
| 1934 |
+
"<|action_token_889|>": 152563,
|
| 1935 |
+
"<|action_token_88|>": 151762,
|
| 1936 |
+
"<|action_token_890|>": 152564,
|
| 1937 |
+
"<|action_token_891|>": 152565,
|
| 1938 |
+
"<|action_token_892|>": 152566,
|
| 1939 |
+
"<|action_token_893|>": 152567,
|
| 1940 |
+
"<|action_token_894|>": 152568,
|
| 1941 |
+
"<|action_token_895|>": 152569,
|
| 1942 |
+
"<|action_token_896|>": 152570,
|
| 1943 |
+
"<|action_token_897|>": 152571,
|
| 1944 |
+
"<|action_token_898|>": 152572,
|
| 1945 |
+
"<|action_token_899|>": 152573,
|
| 1946 |
+
"<|action_token_89|>": 151763,
|
| 1947 |
+
"<|action_token_8|>": 151682,
|
| 1948 |
+
"<|action_token_900|>": 152574,
|
| 1949 |
+
"<|action_token_901|>": 152575,
|
| 1950 |
+
"<|action_token_902|>": 152576,
|
| 1951 |
+
"<|action_token_903|>": 152577,
|
| 1952 |
+
"<|action_token_904|>": 152578,
|
| 1953 |
+
"<|action_token_905|>": 152579,
|
| 1954 |
+
"<|action_token_906|>": 152580,
|
| 1955 |
+
"<|action_token_907|>": 152581,
|
| 1956 |
+
"<|action_token_908|>": 152582,
|
| 1957 |
+
"<|action_token_909|>": 152583,
|
| 1958 |
+
"<|action_token_90|>": 151764,
|
| 1959 |
+
"<|action_token_910|>": 152584,
|
| 1960 |
+
"<|action_token_911|>": 152585,
|
| 1961 |
+
"<|action_token_912|>": 152586,
|
| 1962 |
+
"<|action_token_913|>": 152587,
|
| 1963 |
+
"<|action_token_914|>": 152588,
|
| 1964 |
+
"<|action_token_915|>": 152589,
|
| 1965 |
+
"<|action_token_916|>": 152590,
|
| 1966 |
+
"<|action_token_917|>": 152591,
|
| 1967 |
+
"<|action_token_918|>": 152592,
|
| 1968 |
+
"<|action_token_919|>": 152593,
|
| 1969 |
+
"<|action_token_91|>": 151765,
|
| 1970 |
+
"<|action_token_920|>": 152594,
|
| 1971 |
+
"<|action_token_921|>": 152595,
|
| 1972 |
+
"<|action_token_922|>": 152596,
|
| 1973 |
+
"<|action_token_923|>": 152597,
|
| 1974 |
+
"<|action_token_924|>": 152598,
|
| 1975 |
+
"<|action_token_925|>": 152599,
|
| 1976 |
+
"<|action_token_926|>": 152600,
|
| 1977 |
+
"<|action_token_927|>": 152601,
|
| 1978 |
+
"<|action_token_928|>": 152602,
|
| 1979 |
+
"<|action_token_929|>": 152603,
|
| 1980 |
+
"<|action_token_92|>": 151766,
|
| 1981 |
+
"<|action_token_930|>": 152604,
|
| 1982 |
+
"<|action_token_931|>": 152605,
|
| 1983 |
+
"<|action_token_932|>": 152606,
|
| 1984 |
+
"<|action_token_933|>": 152607,
|
| 1985 |
+
"<|action_token_934|>": 152608,
|
| 1986 |
+
"<|action_token_935|>": 152609,
|
| 1987 |
+
"<|action_token_936|>": 152610,
|
| 1988 |
+
"<|action_token_937|>": 152611,
|
| 1989 |
+
"<|action_token_938|>": 152612,
|
| 1990 |
+
"<|action_token_939|>": 152613,
|
| 1991 |
+
"<|action_token_93|>": 151767,
|
| 1992 |
+
"<|action_token_940|>": 152614,
|
| 1993 |
+
"<|action_token_941|>": 152615,
|
| 1994 |
+
"<|action_token_942|>": 152616,
|
| 1995 |
+
"<|action_token_943|>": 152617,
|
| 1996 |
+
"<|action_token_944|>": 152618,
|
| 1997 |
+
"<|action_token_945|>": 152619,
|
| 1998 |
+
"<|action_token_946|>": 152620,
|
| 1999 |
+
"<|action_token_947|>": 152621,
|
| 2000 |
+
"<|action_token_948|>": 152622,
|
| 2001 |
+
"<|action_token_949|>": 152623,
|
| 2002 |
+
"<|action_token_94|>": 151768,
|
| 2003 |
+
"<|action_token_950|>": 152624,
|
| 2004 |
+
"<|action_token_951|>": 152625,
|
| 2005 |
+
"<|action_token_952|>": 152626,
|
| 2006 |
+
"<|action_token_953|>": 152627,
|
| 2007 |
+
"<|action_token_954|>": 152628,
|
| 2008 |
+
"<|action_token_955|>": 152629,
|
| 2009 |
+
"<|action_token_956|>": 152630,
|
| 2010 |
+
"<|action_token_957|>": 152631,
|
| 2011 |
+
"<|action_token_958|>": 152632,
|
| 2012 |
+
"<|action_token_959|>": 152633,
|
| 2013 |
+
"<|action_token_95|>": 151769,
|
| 2014 |
+
"<|action_token_960|>": 152634,
|
| 2015 |
+
"<|action_token_961|>": 152635,
|
| 2016 |
+
"<|action_token_962|>": 152636,
|
| 2017 |
+
"<|action_token_963|>": 152637,
|
| 2018 |
+
"<|action_token_964|>": 152638,
|
| 2019 |
+
"<|action_token_965|>": 152639,
|
| 2020 |
+
"<|action_token_966|>": 152640,
|
| 2021 |
+
"<|action_token_967|>": 152641,
|
| 2022 |
+
"<|action_token_968|>": 152642,
|
| 2023 |
+
"<|action_token_969|>": 152643,
|
| 2024 |
+
"<|action_token_96|>": 151770,
|
| 2025 |
+
"<|action_token_970|>": 152644,
|
| 2026 |
+
"<|action_token_971|>": 152645,
|
| 2027 |
+
"<|action_token_972|>": 152646,
|
| 2028 |
+
"<|action_token_973|>": 152647,
|
| 2029 |
+
"<|action_token_974|>": 152648,
|
| 2030 |
+
"<|action_token_975|>": 152649,
|
| 2031 |
+
"<|action_token_976|>": 152650,
|
| 2032 |
+
"<|action_token_977|>": 152651,
|
| 2033 |
+
"<|action_token_978|>": 152652,
|
| 2034 |
+
"<|action_token_979|>": 152653,
|
| 2035 |
+
"<|action_token_97|>": 151771,
|
| 2036 |
+
"<|action_token_980|>": 152654,
|
| 2037 |
+
"<|action_token_981|>": 152655,
|
| 2038 |
+
"<|action_token_982|>": 152656,
|
| 2039 |
+
"<|action_token_983|>": 152657,
|
| 2040 |
+
"<|action_token_984|>": 152658,
|
| 2041 |
+
"<|action_token_985|>": 152659,
|
| 2042 |
+
"<|action_token_986|>": 152660,
|
| 2043 |
+
"<|action_token_987|>": 152661,
|
| 2044 |
+
"<|action_token_988|>": 152662,
|
| 2045 |
+
"<|action_token_989|>": 152663,
|
| 2046 |
+
"<|action_token_98|>": 151772,
|
| 2047 |
+
"<|action_token_990|>": 152664,
|
| 2048 |
+
"<|action_token_991|>": 152665,
|
| 2049 |
+
"<|action_token_992|>": 152666,
|
| 2050 |
+
"<|action_token_993|>": 152667,
|
| 2051 |
+
"<|action_token_994|>": 152668,
|
| 2052 |
+
"<|action_token_995|>": 152669,
|
| 2053 |
+
"<|action_token_996|>": 152670,
|
| 2054 |
+
"<|action_token_997|>": 152671,
|
| 2055 |
+
"<|action_token_998|>": 152672,
|
| 2056 |
+
"<|action_token_999|>": 152673,
|
| 2057 |
+
"<|action_token_99|>": 151773,
|
| 2058 |
+
"<|action_token_9|>": 151683,
|
| 2059 |
+
"<|box_end|>": 151649,
|
| 2060 |
+
"<|box_start|>": 151648,
|
| 2061 |
+
"<|endoftext|>": 151643,
|
| 2062 |
+
"<|file_sep|>": 151664,
|
| 2063 |
+
"<|fim_middle|>": 151660,
|
| 2064 |
+
"<|fim_pad|>": 151662,
|
| 2065 |
+
"<|fim_prefix|>": 151659,
|
| 2066 |
+
"<|fim_suffix|>": 151661,
|
| 2067 |
+
"<|goal_repr|>": 151672,
|
| 2068 |
+
"<|im_end|>": 151645,
|
| 2069 |
+
"<|im_start|>": 151644,
|
| 2070 |
+
"<|image_pad|>": 151655,
|
| 2071 |
+
"<|object_ref_end|>": 151647,
|
| 2072 |
+
"<|object_ref_start|>": 151646,
|
| 2073 |
+
"<|obs_repr|>": 151673,
|
| 2074 |
+
"<|quad_end|>": 151651,
|
| 2075 |
+
"<|quad_start|>": 151650,
|
| 2076 |
+
"<|repo_name|>": 151663,
|
| 2077 |
+
"<|video_pad|>": 151656,
|
| 2078 |
+
"<|vision_end|>": 151653,
|
| 2079 |
+
"<|vision_pad|>": 151654,
|
| 2080 |
+
"<|vision_start|>": 151652
|
| 2081 |
+
}
|
chat_template.jinja
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{%- if tools %}
|
| 2 |
+
{{- '<|im_start|>system\n' }}
|
| 3 |
+
{%- if messages[0].role == 'system' %}
|
| 4 |
+
{%- if messages[0].content is string %}
|
| 5 |
+
{{- messages[0].content }}
|
| 6 |
+
{%- else %}
|
| 7 |
+
{%- for content in messages[0].content %}
|
| 8 |
+
{%- if 'text' in content %}
|
| 9 |
+
{{- content.text }}
|
| 10 |
+
{%- endif %}
|
| 11 |
+
{%- endfor %}
|
| 12 |
+
{%- endif %}
|
| 13 |
+
{{- '\n\n' }}
|
| 14 |
+
{%- endif %}
|
| 15 |
+
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
| 16 |
+
{%- for tool in tools %}
|
| 17 |
+
{{- "\n" }}
|
| 18 |
+
{{- tool | tojson }}
|
| 19 |
+
{%- endfor %}
|
| 20 |
+
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
| 21 |
+
{%- else %}
|
| 22 |
+
{%- if messages[0].role == 'system' %}
|
| 23 |
+
{{- '<|im_start|>system\n' }}
|
| 24 |
+
{%- if messages[0].content is string %}
|
| 25 |
+
{{- messages[0].content }}
|
| 26 |
+
{%- else %}
|
| 27 |
+
{%- for content in messages[0].content %}
|
| 28 |
+
{%- if 'text' in content %}
|
| 29 |
+
{{- content.text }}
|
| 30 |
+
{%- endif %}
|
| 31 |
+
{%- endfor %}
|
| 32 |
+
{%- endif %}
|
| 33 |
+
{{- '<|im_end|>\n' }}
|
| 34 |
+
{%- endif %}
|
| 35 |
+
{%- endif %}
|
| 36 |
+
{%- set image_count = namespace(value=0) %}
|
| 37 |
+
{%- set video_count = namespace(value=0) %}
|
| 38 |
+
{%- for message in messages %}
|
| 39 |
+
{%- if message.role == "user" %}
|
| 40 |
+
{{- '<|im_start|>' + message.role + '\n' }}
|
| 41 |
+
{%- if message.content is string %}
|
| 42 |
+
{{- message.content }}
|
| 43 |
+
{%- else %}
|
| 44 |
+
{%- for content in message.content %}
|
| 45 |
+
{%- if content.type == 'image' or 'image' in content or 'image_url' in content %}
|
| 46 |
+
{%- set image_count.value = image_count.value + 1 %}
|
| 47 |
+
{%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}
|
| 48 |
+
<|vision_start|><|image_pad|><|vision_end|>
|
| 49 |
+
{%- elif content.type == 'video' or 'video' in content %}
|
| 50 |
+
{%- set video_count.value = video_count.value + 1 %}
|
| 51 |
+
{%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}
|
| 52 |
+
<|vision_start|><|video_pad|><|vision_end|>
|
| 53 |
+
{%- elif 'text' in content %}
|
| 54 |
+
{{- content.text }}
|
| 55 |
+
{%- endif %}
|
| 56 |
+
{%- endfor %}
|
| 57 |
+
{%- endif %}
|
| 58 |
+
{{- '<|im_end|>\n' }}
|
| 59 |
+
{%- elif message.role == "assistant" %}
|
| 60 |
+
{{- '<|im_start|>' + message.role + '\n' }}
|
| 61 |
+
{%- if message.content is string %}
|
| 62 |
+
{{- message.content }}
|
| 63 |
+
{%- else %}
|
| 64 |
+
{%- for content_item in message.content %}
|
| 65 |
+
{%- if 'text' in content_item %}
|
| 66 |
+
{{- content_item.text }}
|
| 67 |
+
{%- endif %}
|
| 68 |
+
{%- endfor %}
|
| 69 |
+
{%- endif %}
|
| 70 |
+
{%- if message.tool_calls %}
|
| 71 |
+
{%- for tool_call in message.tool_calls %}
|
| 72 |
+
{%- if (loop.first and message.content) or (not loop.first) %}
|
| 73 |
+
{{- '\n' }}
|
| 74 |
+
{%- endif %}
|
| 75 |
+
{%- if tool_call.function %}
|
| 76 |
+
{%- set tool_call = tool_call.function %}
|
| 77 |
+
{%- endif %}
|
| 78 |
+
{{- '<tool_call>\n{"name": "' }}
|
| 79 |
+
{{- tool_call.name }}
|
| 80 |
+
{{- '", "arguments": ' }}
|
| 81 |
+
{%- if tool_call.arguments is string %}
|
| 82 |
+
{{- tool_call.arguments }}
|
| 83 |
+
{%- else %}
|
| 84 |
+
{{- tool_call.arguments | tojson }}
|
| 85 |
+
{%- endif %}
|
| 86 |
+
{{- '}\n</tool_call>' }}
|
| 87 |
+
{%- endfor %}
|
| 88 |
+
{%- endif %}
|
| 89 |
+
{{- '<|im_end|>\n' }}
|
| 90 |
+
{%- elif message.role == "tool" %}
|
| 91 |
+
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
| 92 |
+
{{- '<|im_start|>user' }}
|
| 93 |
+
{%- endif %}
|
| 94 |
+
{{- '\n<tool_response>\n' }}
|
| 95 |
+
{%- if message.content is string %}
|
| 96 |
+
{{- message.content }}
|
| 97 |
+
{%- else %}
|
| 98 |
+
{%- for content in message.content %}
|
| 99 |
+
{%- if content.type == 'image' or 'image' in content or 'image_url' in content %}
|
| 100 |
+
{%- set image_count.value = image_count.value + 1 %}
|
| 101 |
+
{%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}
|
| 102 |
+
<|vision_start|><|image_pad|><|vision_end|>
|
| 103 |
+
{%- elif content.type == 'video' or 'video' in content %}
|
| 104 |
+
{%- set video_count.value = video_count.value + 1 %}
|
| 105 |
+
{%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}
|
| 106 |
+
<|vision_start|><|video_pad|><|vision_end|>
|
| 107 |
+
{%- elif 'text' in content %}
|
| 108 |
+
{{- content.text }}
|
| 109 |
+
{%- endif %}
|
| 110 |
+
{%- endfor %}
|
| 111 |
+
{%- endif %}
|
| 112 |
+
{{- '\n</tool_response>' }}
|
| 113 |
+
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
| 114 |
+
{{- '<|im_end|>\n' }}
|
| 115 |
+
{%- endif %}
|
| 116 |
+
{%- endif %}
|
| 117 |
+
{%- endfor %}
|
| 118 |
+
{%- if add_generation_prompt %}
|
| 119 |
+
{{- '<|im_start|>assistant\n' }}
|
| 120 |
+
{%- endif %}
|
config.json
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"action_chunk_size": 20,
|
| 3 |
+
"action_expert_config": {
|
| 4 |
+
"action_end_token_id": null,
|
| 5 |
+
"action_start_token_id": 151669,
|
| 6 |
+
"action_token_id": 151670,
|
| 7 |
+
"attention_bias": false,
|
| 8 |
+
"attention_dropout": 0.0,
|
| 9 |
+
"bos_token_id": 151643,
|
| 10 |
+
"crl_goal_repr_token_id": 151672,
|
| 11 |
+
"crl_obs_repr_token_id": 151673,
|
| 12 |
+
"dtype": "bfloat16",
|
| 13 |
+
"eos_token_id": 151645,
|
| 14 |
+
"head_dim": 128,
|
| 15 |
+
"hidden_act": "silu",
|
| 16 |
+
"hidden_size": 1280,
|
| 17 |
+
"image_token_id": 151655,
|
| 18 |
+
"initializer_range": 0.02,
|
| 19 |
+
"intermediate_size": 2432,
|
| 20 |
+
"max_position_embeddings": 262144,
|
| 21 |
+
"model_type": "prts_qwen3_vl_text",
|
| 22 |
+
"num_attention_heads": 32,
|
| 23 |
+
"num_hidden_layers": 36,
|
| 24 |
+
"num_key_value_heads": 8,
|
| 25 |
+
"rms_norm_eps": 1e-06,
|
| 26 |
+
"rope_scaling": {
|
| 27 |
+
"mrope_interleaved": true,
|
| 28 |
+
"mrope_section": [
|
| 29 |
+
24,
|
| 30 |
+
20,
|
| 31 |
+
20
|
| 32 |
+
],
|
| 33 |
+
"rope_type": "default"
|
| 34 |
+
},
|
| 35 |
+
"rope_theta": 5000000,
|
| 36 |
+
"tie_word_embeddings": true,
|
| 37 |
+
"use_cache": true,
|
| 38 |
+
"video_token_id": 151656,
|
| 39 |
+
"vision_start_token_id": 151652,
|
| 40 |
+
"vocab_size": 153722
|
| 41 |
+
},
|
| 42 |
+
"action_start_token_id": 151669,
|
| 43 |
+
"architectures": [
|
| 44 |
+
"PRTS_Qwen3VL"
|
| 45 |
+
],
|
| 46 |
+
"auto_map": {
|
| 47 |
+
"AutoConfig": "configuration_prts_qwen3_vl.PRTS_FlowMatchingConfig_Qwen3VL",
|
| 48 |
+
"AutoModel": "modeling_prts_qwen3_vl.PRTS_Qwen3VL"
|
| 49 |
+
},
|
| 50 |
+
"crl_embed_dim": 256,
|
| 51 |
+
"crl_encoder_init_w": 0.001,
|
| 52 |
+
"crl_goal_repr_token_id": 151672,
|
| 53 |
+
"crl_logsumexp_reg_weight": 0.0,
|
| 54 |
+
"crl_loss_weight": 0.0,
|
| 55 |
+
"crl_obs_repr_token_id": 151673,
|
| 56 |
+
"crl_repr_norm": true,
|
| 57 |
+
"dit_action_head_config": {
|
| 58 |
+
"add_pos_embed": true,
|
| 59 |
+
"attend_text_every_n_blocks": 2,
|
| 60 |
+
"attention_head_dim": 48,
|
| 61 |
+
"attn_implementation": "sdpa",
|
| 62 |
+
"dropout": 0.2,
|
| 63 |
+
"final_dropout": true,
|
| 64 |
+
"interleave_self_attention": true,
|
| 65 |
+
"mlp_mult": 4,
|
| 66 |
+
"noise_beta_alpha": 1.5,
|
| 67 |
+
"noise_beta_beta": 1.0,
|
| 68 |
+
"noise_s": 0.999,
|
| 69 |
+
"norm_type": "ada_norm",
|
| 70 |
+
"num_attention_heads": 32,
|
| 71 |
+
"num_layers": 16,
|
| 72 |
+
"num_timestep_buckets": 1000,
|
| 73 |
+
"output_dim": 1024,
|
| 74 |
+
"use_alternate_vl_dit": true,
|
| 75 |
+
"use_mot_action_expert": true
|
| 76 |
+
},
|
| 77 |
+
"dtype": "bfloat16",
|
| 78 |
+
"embodiment_tag": "libero_panda",
|
| 79 |
+
"flow_matching_action_loss_weight": 1.0,
|
| 80 |
+
"flow_matching_sub_goal_loss_weight": 0.0,
|
| 81 |
+
"image_token_id": 151655,
|
| 82 |
+
"label2id": null,
|
| 83 |
+
"max_action_dim": 32,
|
| 84 |
+
"model_type": "prts_qwen3_vl",
|
| 85 |
+
"num_denoise_steps": 5,
|
| 86 |
+
"pad_token_id": 151643,
|
| 87 |
+
"text_config": {
|
| 88 |
+
"action_end_token_id": null,
|
| 89 |
+
"action_start_token_id": 151669,
|
| 90 |
+
"action_token_id": 151670,
|
| 91 |
+
"attention_bias": false,
|
| 92 |
+
"attention_dropout": 0.0,
|
| 93 |
+
"bos_token_id": 151643,
|
| 94 |
+
"crl_goal_repr_token_id": 151672,
|
| 95 |
+
"crl_obs_repr_token_id": 151673,
|
| 96 |
+
"dtype": "bfloat16",
|
| 97 |
+
"eos_token_id": 151645,
|
| 98 |
+
"head_dim": 128,
|
| 99 |
+
"hidden_act": "silu",
|
| 100 |
+
"hidden_size": 2560,
|
| 101 |
+
"image_token_id": 151655,
|
| 102 |
+
"initializer_range": 0.02,
|
| 103 |
+
"intermediate_size": 9728,
|
| 104 |
+
"max_position_embeddings": 262144,
|
| 105 |
+
"model_type": "prts_qwen3_vl_text",
|
| 106 |
+
"num_attention_heads": 32,
|
| 107 |
+
"num_hidden_layers": 36,
|
| 108 |
+
"num_key_value_heads": 8,
|
| 109 |
+
"rms_norm_eps": 1e-06,
|
| 110 |
+
"rope_scaling": {
|
| 111 |
+
"mrope_interleaved": true,
|
| 112 |
+
"mrope_section": [
|
| 113 |
+
24,
|
| 114 |
+
20,
|
| 115 |
+
20
|
| 116 |
+
],
|
| 117 |
+
"rope_type": "default"
|
| 118 |
+
},
|
| 119 |
+
"rope_theta": 5000000,
|
| 120 |
+
"tie_word_embeddings": true,
|
| 121 |
+
"use_cache": false,
|
| 122 |
+
"video_token_id": 151656,
|
| 123 |
+
"vision_start_token_id": 151652,
|
| 124 |
+
"vocab_size": 153722
|
| 125 |
+
},
|
| 126 |
+
"tie_word_embeddings": true,
|
| 127 |
+
"transformers_version": "4.57.3",
|
| 128 |
+
"use_cache": true,
|
| 129 |
+
"use_fast_action_tokenizer": true,
|
| 130 |
+
"video_token_id": 151656,
|
| 131 |
+
"vision_config": {
|
| 132 |
+
"deepstack_visual_indexes": [
|
| 133 |
+
5,
|
| 134 |
+
11,
|
| 135 |
+
17
|
| 136 |
+
],
|
| 137 |
+
"depth": 24,
|
| 138 |
+
"dtype": "bfloat16",
|
| 139 |
+
"hidden_act": "gelu_pytorch_tanh",
|
| 140 |
+
"hidden_size": 1024,
|
| 141 |
+
"in_channels": 3,
|
| 142 |
+
"initializer_range": 0.02,
|
| 143 |
+
"intermediate_size": 4096,
|
| 144 |
+
"model_type": "qwen3_vl",
|
| 145 |
+
"num_heads": 16,
|
| 146 |
+
"num_position_embeddings": 2304,
|
| 147 |
+
"out_hidden_size": 2560,
|
| 148 |
+
"patch_size": 16,
|
| 149 |
+
"spatial_merge_size": 2,
|
| 150 |
+
"temporal_patch_size": 2
|
| 151 |
+
},
|
| 152 |
+
"vision_end_token_id": 151653,
|
| 153 |
+
"vision_start_token_id": 151652,
|
| 154 |
+
"vocab_size": 153722
|
| 155 |
+
}
|
configuration_prts_qwen3_vl.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 TeleAI Rhodes Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Configuration classes for PRTS built on Qwen3-VL."""
|
| 16 |
+
|
| 17 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 18 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
| 19 |
+
from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLVisionConfig
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class PRTS_Qwen3VLTextConfig(PretrainedConfig):
|
| 23 |
+
r"""
|
| 24 |
+
This is the configuration class to store the configuration of a PRTS Text Model based on Qwen3-VL.
|
| 25 |
+
It extends PretrainedConfig with Qwen3-VL text model parameters and PRTS-specific parameters.
|
| 26 |
+
|
| 27 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
vocab_size (`int`, *optional*, defaults to 151936):
|
| 31 |
+
Vocabulary size of the Qwen3VL model.
|
| 32 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
| 33 |
+
Dimension of the hidden representations.
|
| 34 |
+
intermediate_size (`int`, *optional*, defaults to 22016):
|
| 35 |
+
Dimension of the MLP representations.
|
| 36 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
| 37 |
+
Number of hidden layers in the Transformer encoder.
|
| 38 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 39 |
+
Number of attention heads for each attention layer.
|
| 40 |
+
num_key_value_heads (`int`, *optional*, defaults to 32):
|
| 41 |
+
Number of key-value heads for Grouped Query Attention.
|
| 42 |
+
head_dim (`int`, *optional*, defaults to 128):
|
| 43 |
+
The dimension of the head.
|
| 44 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 45 |
+
The non-linear activation function.
|
| 46 |
+
max_position_embeddings (`int`, *optional*, defaults to 128000):
|
| 47 |
+
The maximum sequence length.
|
| 48 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 49 |
+
The standard deviation of the truncated_normal_initializer.
|
| 50 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 51 |
+
The epsilon used by the rms normalization layers.
|
| 52 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 53 |
+
Whether or not the model should return the last key/values attentions.
|
| 54 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 55 |
+
Whether the model's input and output word embeddings should be tied.
|
| 56 |
+
rope_theta (`float`, *optional*, defaults to 5000000.0):
|
| 57 |
+
The base period of the RoPE embeddings.
|
| 58 |
+
rope_scaling (`Dict`, *optional*):
|
| 59 |
+
Dictionary containing the scaling configuration for the RoPE embeddings.
|
| 60 |
+
attention_bias (`bool`, *optional*, defaults to `False`):
|
| 61 |
+
Whether to use a bias in the query, key, value and output projection layers.
|
| 62 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 63 |
+
The dropout ratio for the attention probabilities.
|
| 64 |
+
image_token_id (`int`, *optional*):
|
| 65 |
+
Token index used as placeholder for image embeddings.
|
| 66 |
+
video_token_id (`int`, *optional*):
|
| 67 |
+
Token index used as placeholder for video embeddings.
|
| 68 |
+
action_token_id (`int`, *optional*):
|
| 69 |
+
Token index used as placeholder for action embeddings.
|
| 70 |
+
action_start_token_id (`int`, *optional*):
|
| 71 |
+
Token index for action sequence start.
|
| 72 |
+
action_end_token_id (`int`, *optional*):
|
| 73 |
+
Token index for action sequence end.
|
| 74 |
+
vision_start_token_id (`int`, *optional*):
|
| 75 |
+
Token index for vision sequence start.
|
| 76 |
+
**kwargs:
|
| 77 |
+
Additional keyword arguments passed to PretrainedConfig.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
model_type = "prts_qwen3_vl_text" # TODO (zy): check if this is correct
|
| 81 |
+
base_config_key = "text_config"
|
| 82 |
+
|
| 83 |
+
def __init__(
|
| 84 |
+
self,
|
| 85 |
+
vocab_size=151936,
|
| 86 |
+
hidden_size=4096,
|
| 87 |
+
intermediate_size=22016,
|
| 88 |
+
num_hidden_layers=32,
|
| 89 |
+
num_attention_heads=32,
|
| 90 |
+
num_key_value_heads=32,
|
| 91 |
+
head_dim=128,
|
| 92 |
+
hidden_act="silu",
|
| 93 |
+
max_position_embeddings=128000,
|
| 94 |
+
initializer_range=0.02,
|
| 95 |
+
rms_norm_eps=1e-6,
|
| 96 |
+
use_cache=True,
|
| 97 |
+
tie_word_embeddings=False,
|
| 98 |
+
rope_theta=5000000.0,
|
| 99 |
+
rope_scaling=None,
|
| 100 |
+
attention_bias=False,
|
| 101 |
+
attention_dropout=0.0,
|
| 102 |
+
# PRTS specific
|
| 103 |
+
action_token_id=None,
|
| 104 |
+
action_start_token_id=None,
|
| 105 |
+
action_end_token_id=None,
|
| 106 |
+
crl_goal_repr_token_id=None,
|
| 107 |
+
crl_obs_repr_token_id=None,
|
| 108 |
+
**kwargs,
|
| 109 |
+
):
|
| 110 |
+
self.vocab_size = vocab_size
|
| 111 |
+
self.max_position_embeddings = max_position_embeddings
|
| 112 |
+
self.hidden_size = hidden_size
|
| 113 |
+
self.intermediate_size = intermediate_size
|
| 114 |
+
self.num_hidden_layers = num_hidden_layers
|
| 115 |
+
self.num_attention_heads = num_attention_heads
|
| 116 |
+
|
| 117 |
+
# for backward compatibility
|
| 118 |
+
if num_key_value_heads is None:
|
| 119 |
+
num_key_value_heads = num_attention_heads
|
| 120 |
+
|
| 121 |
+
self.num_key_value_heads = num_key_value_heads
|
| 122 |
+
self.head_dim = head_dim
|
| 123 |
+
self.hidden_act = hidden_act
|
| 124 |
+
self.initializer_range = initializer_range
|
| 125 |
+
self.rms_norm_eps = rms_norm_eps
|
| 126 |
+
self.use_cache = use_cache
|
| 127 |
+
self.rope_theta = rope_theta
|
| 128 |
+
self.rope_scaling = rope_scaling
|
| 129 |
+
self.attention_bias = attention_bias
|
| 130 |
+
self.attention_dropout = attention_dropout
|
| 131 |
+
|
| 132 |
+
# Validate rope config
|
| 133 |
+
rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
|
| 134 |
+
|
| 135 |
+
# PRTS specific token IDs
|
| 136 |
+
self.action_token_id = action_token_id
|
| 137 |
+
self.action_start_token_id = action_start_token_id
|
| 138 |
+
self.action_end_token_id = action_end_token_id
|
| 139 |
+
self.crl_goal_repr_token_id = crl_goal_repr_token_id
|
| 140 |
+
self.crl_obs_repr_token_id = crl_obs_repr_token_id
|
| 141 |
+
|
| 142 |
+
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class PRTS_FlowMatchingConfig_Qwen3VL(PretrainedConfig):
|
| 146 |
+
r"""
|
| 147 |
+
This is the configuration class to store the configuration of a PRTS model based on Qwen3-VL.
|
| 148 |
+
It extends PretrainedConfig with Qwen3-VL model parameters and PRTS-specific parameters for action prediction.
|
| 149 |
+
|
| 150 |
+
[`PRTS_FlowMatchingConfig_Qwen3VL`] is the configuration class to store the configuration of a PRTS model. It is used to
|
| 151 |
+
instantiate a PRTS model according to the specified arguments, defining the vision encoder, text encoder,
|
| 152 |
+
action expert, and flow matching components.
|
| 153 |
+
|
| 154 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `PRTS_Qwen3VLTextConfig`):
|
| 158 |
+
The config object or dictionary of the text backbone.
|
| 159 |
+
vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`):
|
| 160 |
+
The config object or dictionary of the vision backbone.
|
| 161 |
+
max_action_dim (`int`, *optional*, defaults to 14):
|
| 162 |
+
Maximum dimension of action vectors. Used for padding different robot action spaces.
|
| 163 |
+
action_chunk_size (`int`, *optional*, defaults to 100):
|
| 164 |
+
Number of action timesteps to predict in each forward pass.
|
| 165 |
+
num_denoise_steps (`int`, *optional*, defaults to 4):
|
| 166 |
+
Number of denoising steps for flow matching during inference.
|
| 167 |
+
flow_matching_action_loss_weight (`float`, *optional*, defaults to 1.0):
|
| 168 |
+
Weight for the flow matching action loss.
|
| 169 |
+
crl_loss_weight (`float`, *optional*, defaults to 0.0):
|
| 170 |
+
Weight for the Contrastive Reinforcement Learning (CRL) loss. Set to 0 to disable.
|
| 171 |
+
crl_embed_dim (`int`, *optional*, defaults to 256):
|
| 172 |
+
Dimension of the CRL embedding space for action and goal encoders.
|
| 173 |
+
crl_logsumexp_reg_weight (`float`, *optional*, defaults to 0.0):
|
| 174 |
+
Weight for logsumexp regularization on CRL logits.
|
| 175 |
+
image_token_id (`int`, *optional*):
|
| 176 |
+
Token id for image placeholders.
|
| 177 |
+
video_token_id (`int`, *optional*):
|
| 178 |
+
Token id for video placeholders.
|
| 179 |
+
vision_start_token_id (`int`, *optional*):
|
| 180 |
+
Token id for vision start marker.
|
| 181 |
+
vision_end_token_id (`int`, *optional*):
|
| 182 |
+
Token id for vision end marker.
|
| 183 |
+
**kwargs:
|
| 184 |
+
Additional keyword arguments passed to PretrainedConfig.
|
| 185 |
+
|
| 186 |
+
Example:
|
| 187 |
+
|
| 188 |
+
```python
|
| 189 |
+
>>> from prts.models import PRTS_FlowMatchingConfig_Qwen3VL, PRTS_Qwen3VL
|
| 190 |
+
|
| 191 |
+
>>> # Initializing a PRTS Qwen3-VL configuration
|
| 192 |
+
>>> configuration = PRTS_FlowMatchingConfig_Qwen3VL()
|
| 193 |
+
|
| 194 |
+
>>> # Initializing a model from the configuration
|
| 195 |
+
>>> model = PRTS_Qwen3VL(configuration)
|
| 196 |
+
|
| 197 |
+
>>> # Accessing the model configuration
|
| 198 |
+
>>> configuration = model.config
|
| 199 |
+
```
|
| 200 |
+
"""
|
| 201 |
+
|
| 202 |
+
model_type = "prts_qwen3_vl"
|
| 203 |
+
sub_configs = {
|
| 204 |
+
"vision_config": Qwen3VLVisionConfig,
|
| 205 |
+
"text_config": PRTS_Qwen3VLTextConfig,
|
| 206 |
+
}
|
| 207 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 208 |
+
|
| 209 |
+
def __init__(
|
| 210 |
+
self,
|
| 211 |
+
text_config=None,
|
| 212 |
+
vision_config=None,
|
| 213 |
+
image_token_id=151655,
|
| 214 |
+
video_token_id=151656,
|
| 215 |
+
vision_start_token_id=151652,
|
| 216 |
+
vision_end_token_id=151653,
|
| 217 |
+
tie_word_embeddings=False,
|
| 218 |
+
# PRTS specific
|
| 219 |
+
max_action_dim=32,
|
| 220 |
+
action_chunk_size=50,
|
| 221 |
+
num_denoise_steps=4,
|
| 222 |
+
flow_matching_action_loss_weight=0.,
|
| 223 |
+
use_fast_action_tokenizer=True,
|
| 224 |
+
# Embodiment tag: identifies the robot embodiment used for finetuning.
|
| 225 |
+
# Stores the delta_action_mask key so eval code can recover it without
|
| 226 |
+
# needing the training dataset config.
|
| 227 |
+
embodiment_tag=None,
|
| 228 |
+
# DiT action head config
|
| 229 |
+
dit_action_head_config=None,
|
| 230 |
+
# CRL (Contrastive Reinforcement Learning) parameters
|
| 231 |
+
crl_loss_weight=0.,
|
| 232 |
+
crl_embed_dim=256,
|
| 233 |
+
crl_logsumexp_reg_weight=0.0,
|
| 234 |
+
crl_encoder_init_w=1e-12, # Cold initialization weight for encoder last layer
|
| 235 |
+
crl_repr_norm=True, # Whether to L2-normalize CRL representations
|
| 236 |
+
**kwargs,
|
| 237 |
+
):
|
| 238 |
+
# Initialize vision config
|
| 239 |
+
if isinstance(vision_config, dict):
|
| 240 |
+
self.vision_config = self.sub_configs["vision_config"](**vision_config)
|
| 241 |
+
elif vision_config is None:
|
| 242 |
+
self.vision_config = self.sub_configs["vision_config"]()
|
| 243 |
+
|
| 244 |
+
# Initialize text config
|
| 245 |
+
if isinstance(text_config, dict):
|
| 246 |
+
self.text_config = self.sub_configs["text_config"](**text_config)
|
| 247 |
+
elif text_config is None:
|
| 248 |
+
# For BC use all kwargs to init `TextConfig`
|
| 249 |
+
self.text_config = self.sub_configs["text_config"](**kwargs)
|
| 250 |
+
|
| 251 |
+
# PRTS-specific parameters
|
| 252 |
+
self.max_action_dim = max_action_dim
|
| 253 |
+
self.action_chunk_size = action_chunk_size
|
| 254 |
+
self.num_denoise_steps = num_denoise_steps
|
| 255 |
+
self.flow_matching_action_loss_weight = flow_matching_action_loss_weight
|
| 256 |
+
self.use_fast_action_tokenizer = use_fast_action_tokenizer
|
| 257 |
+
self.embodiment_tag = embodiment_tag
|
| 258 |
+
|
| 259 |
+
# DiT action head config (nested dict)
|
| 260 |
+
# cross_attention_dim defaults to text_config.hidden_size at model init time
|
| 261 |
+
_default_dit_config = {
|
| 262 |
+
# Architecture — aligned with GR00T N1.6 (32 layers, inner_dim=32×48=1536)
|
| 263 |
+
"num_layers": 16, # 32
|
| 264 |
+
"num_attention_heads": 32,
|
| 265 |
+
"attention_head_dim": 48,
|
| 266 |
+
"output_dim": 1024,
|
| 267 |
+
# Regularisation
|
| 268 |
+
"dropout": 0.2,
|
| 269 |
+
"interleave_self_attention": True,
|
| 270 |
+
"norm_type": "ada_norm",
|
| 271 |
+
"final_dropout": True,
|
| 272 |
+
# Action-head specifics
|
| 273 |
+
"add_pos_embed": True,
|
| 274 |
+
# Noise schedule
|
| 275 |
+
"noise_beta_alpha": 1.5,
|
| 276 |
+
"noise_beta_beta": 1.0,
|
| 277 |
+
"noise_s": 0.999,
|
| 278 |
+
"num_timestep_buckets": 1000,
|
| 279 |
+
# Attention backend
|
| 280 |
+
"attn_implementation": "sdpa",
|
| 281 |
+
# AlternateVLDiT — separate visual / text token cross-attention
|
| 282 |
+
"use_alternate_vl_dit": True,
|
| 283 |
+
"attend_text_every_n_blocks": 2,
|
| 284 |
+
# MoT-style action expert: forwards full VLM ``past_key_values`` into the head;
|
| 285 |
+
# expert depth defaults to text_config.num_hidden_layers (override with expert_num_layers).
|
| 286 |
+
"use_mot_action_expert": False,
|
| 287 |
+
"mlp_mult": 4, # FFN hidden dim = inner_dim * mlp_mult (standard DiT only)
|
| 288 |
+
}
|
| 289 |
+
if dit_action_head_config is not None:
|
| 290 |
+
_default_dit_config.update(dit_action_head_config)
|
| 291 |
+
self.dit_action_head_config = _default_dit_config
|
| 292 |
+
|
| 293 |
+
# CRL (Contrastive Reinforcement Learning) parameters
|
| 294 |
+
self.crl_loss_weight = crl_loss_weight
|
| 295 |
+
self.crl_embed_dim = crl_embed_dim
|
| 296 |
+
self.crl_logsumexp_reg_weight = crl_logsumexp_reg_weight
|
| 297 |
+
self.crl_encoder_init_w = crl_encoder_init_w
|
| 298 |
+
self.crl_repr_norm = crl_repr_norm
|
| 299 |
+
|
| 300 |
+
# Token IDs
|
| 301 |
+
self.image_token_id = image_token_id
|
| 302 |
+
self.video_token_id = video_token_id
|
| 303 |
+
self.vision_start_token_id = vision_start_token_id
|
| 304 |
+
self.vision_end_token_id = vision_end_token_id
|
| 305 |
+
|
| 306 |
+
# # Propagate token IDs to text config
|
| 307 |
+
# if self.image_token_id is not None:
|
| 308 |
+
# self.text_config.image_token_id = self.image_token_id
|
| 309 |
+
# if self.video_token_id is not None:
|
| 310 |
+
# self.text_config.video_token_id = self.video_token_id
|
| 311 |
+
# if self.vision_start_token_id is not None:
|
| 312 |
+
# self.text_config.vision_start_token_id = self.vision_start_token_id
|
| 313 |
+
|
| 314 |
+
# Ensure vocab sizes are consistent
|
| 315 |
+
# if hasattr(self.text_config, 'vocab_size'):
|
| 316 |
+
# self.vocab_size = self.text_config.vocab_size
|
| 317 |
+
|
| 318 |
+
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
|
| 319 |
+
|
| 320 |
+
# TODO (zy): 这里需要看下是不是在VLConfig传入这些state action的特殊token更合适更灵活
|
| 321 |
+
@property
|
| 322 |
+
def action_token_id(self):
|
| 323 |
+
"""Get action token id from text config."""
|
| 324 |
+
return getattr(self.text_config, 'action_token_id', None)
|
| 325 |
+
|
| 326 |
+
@action_token_id.setter
|
| 327 |
+
def action_token_id(self, value):
|
| 328 |
+
"""Set action token id in text config."""
|
| 329 |
+
if hasattr(self.text_config, 'action_token_id'):
|
| 330 |
+
self.text_config.action_token_id = value
|
| 331 |
+
|
| 332 |
+
def __getattribute__(self, key):
|
| 333 |
+
if "text_config" in super().__getattribute__("__dict__") and key not in [
|
| 334 |
+
"dtype",
|
| 335 |
+
"_attn_implementation_internal",
|
| 336 |
+
]:
|
| 337 |
+
text_config = super().__getattribute__("text_config")
|
| 338 |
+
if key in text_config.__dict__:
|
| 339 |
+
return getattr(text_config, key)
|
| 340 |
+
|
| 341 |
+
return super().__getattribute__(key)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
PRTS_FlowMatchingConfig_Qwen3VL.register_for_auto_class()
|
| 345 |
+
__all__ = ["PRTS_FlowMatchingConfig_Qwen3VL", "PRTS_Qwen3VLTextConfig"]
|
dit_action_head.py
ADDED
|
@@ -0,0 +1,1230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DiT (Diffusion Transformer) based flow matching action head for PRTS.
|
| 3 |
+
|
| 4 |
+
Replaces the Qwen3VLTextModel-based fm_action_expert with a lightweight DiT
|
| 5 |
+
that uses explicit cross-attention to VLM hidden states, following the architecture
|
| 6 |
+
from GR00T / pi05.
|
| 7 |
+
|
| 8 |
+
Architecture:
|
| 9 |
+
ActionEncoder(noisy_actions + dof_mask, timestep)
|
| 10 |
+
→ action_features
|
| 11 |
+
→ DiT(cross-attn to VLM hidden states, ada-norm timestep conditioning)
|
| 12 |
+
→ ActionDecoder → predicted velocity
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from torch.distributions import Beta
|
| 21 |
+
from typing import Optional
|
| 22 |
+
|
| 23 |
+
from transformers.cache_utils import Cache
|
| 24 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# DIT_PRESETS = {
|
| 28 |
+
# "DiT-B": {"num_attention_heads": 12, "attention_head_dim": 64, "output_dim": 768},
|
| 29 |
+
# "DiT-L": {"num_attention_heads": 32, "attention_head_dim": 48, "output_dim": 1536},
|
| 30 |
+
# }
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class SinusoidalPositionalEncoding(nn.Module):
|
| 34 |
+
"""Sinusoidal positional encoding for sequence positions or timesteps."""
|
| 35 |
+
|
| 36 |
+
def __init__(self, embedding_dim: int):
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.embedding_dim = embedding_dim
|
| 39 |
+
|
| 40 |
+
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
| 41 |
+
timesteps = timesteps.float()
|
| 42 |
+
squeeze = False
|
| 43 |
+
if timesteps.dim() == 1:
|
| 44 |
+
timesteps = timesteps.unsqueeze(1)
|
| 45 |
+
squeeze = True
|
| 46 |
+
|
| 47 |
+
half_dim = self.embedding_dim // 2
|
| 48 |
+
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * (
|
| 49 |
+
math.log(10000.0) / half_dim
|
| 50 |
+
)
|
| 51 |
+
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
| 52 |
+
enc = torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1)
|
| 53 |
+
|
| 54 |
+
if squeeze:
|
| 55 |
+
enc = enc.squeeze(1)
|
| 56 |
+
return enc
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class TimestepEncoder(nn.Module):
|
| 60 |
+
"""Projects scalar timesteps to embedding space via sinusoidal encoding + MLP."""
|
| 61 |
+
|
| 62 |
+
def __init__(self, embedding_dim: int):
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.sinusoidal = SinusoidalPositionalEncoding(256)
|
| 65 |
+
self.linear_1 = nn.Linear(256, embedding_dim)
|
| 66 |
+
self.act = nn.SiLU()
|
| 67 |
+
self.linear_2 = nn.Linear(embedding_dim, embedding_dim)
|
| 68 |
+
|
| 69 |
+
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
| 70 |
+
t_emb = self.sinusoidal(timesteps)
|
| 71 |
+
t_emb = self.linear_1(t_emb.to(dtype=self.linear_1.weight.dtype))
|
| 72 |
+
t_emb = self.act(t_emb)
|
| 73 |
+
t_emb = self.linear_2(t_emb)
|
| 74 |
+
return t_emb
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class AdaLayerNorm(nn.Module):
|
| 78 |
+
"""Adaptive Layer Normalization conditioned on timestep embeddings.
|
| 79 |
+
|
| 80 |
+
Applies scale-shift modulation: out = norm(x) * (1 + scale) + shift,
|
| 81 |
+
where (scale, shift) are linearly projected from the timestep embedding.
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
def __init__(self, embedding_dim: int, eps: float = 1e-5):
|
| 85 |
+
super().__init__()
|
| 86 |
+
self.silu = nn.SiLU()
|
| 87 |
+
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
| 88 |
+
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=False)
|
| 89 |
+
|
| 90 |
+
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
|
| 91 |
+
temb = self.linear(self.silu(temb))
|
| 92 |
+
scale, shift = temb.chunk(2, dim=-1)
|
| 93 |
+
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
| 94 |
+
return x
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class DiTAttention(nn.Module):
|
| 98 |
+
"""Multi-head attention supporting both self-attention and cross-attention.
|
| 99 |
+
|
| 100 |
+
Supports two backends selected via ``attn_implementation``:
|
| 101 |
+
|
| 102 |
+
* ``"sdpa"`` (default) – uses :func:`F.scaled_dot_product_attention`, which
|
| 103 |
+
dispatches automatically to FlashAttention / memory-efficient attention
|
| 104 |
+
depending on the installed PyTorch build. The encoder padding mask is
|
| 105 |
+
expanded to ``(B, 1, 1, S)`` and passed as ``attn_mask``.
|
| 106 |
+
|
| 107 |
+
* ``"flash_attention_2"`` – calls the ``flash_attn`` package directly for
|
| 108 |
+
lower memory usage and higher throughput. For cross-attention with an
|
| 109 |
+
encoder padding mask the k/v tensors are unpadded and
|
| 110 |
+
:func:`flash_attn_varlen_func` is used so that padding tokens are never
|
| 111 |
+
processed. For self-attention (no mask) the simpler
|
| 112 |
+
:func:`flash_attn_func` is used.
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
def __init__(
|
| 116 |
+
self,
|
| 117 |
+
query_dim: int,
|
| 118 |
+
num_heads: int,
|
| 119 |
+
head_dim: int,
|
| 120 |
+
cross_attention_dim: Optional[int] = None,
|
| 121 |
+
dropout: float = 0.0,
|
| 122 |
+
bias: bool = True,
|
| 123 |
+
attn_implementation: str = "sdpa",
|
| 124 |
+
):
|
| 125 |
+
super().__init__()
|
| 126 |
+
self.num_heads = num_heads
|
| 127 |
+
self.head_dim = head_dim
|
| 128 |
+
self.attn_implementation = attn_implementation
|
| 129 |
+
inner_dim = num_heads * head_dim
|
| 130 |
+
|
| 131 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
| 132 |
+
kv_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
| 133 |
+
self.to_k = nn.Linear(kv_dim, inner_dim, bias=bias)
|
| 134 |
+
self.to_v = nn.Linear(kv_dim, inner_dim, bias=bias)
|
| 135 |
+
self.to_out = nn.Sequential(
|
| 136 |
+
nn.Linear(inner_dim, query_dim, bias=bias),
|
| 137 |
+
nn.Dropout(dropout),
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
# ------------------------------------------------------------------
|
| 141 |
+
# Flash-Attention backend
|
| 142 |
+
# ------------------------------------------------------------------
|
| 143 |
+
|
| 144 |
+
def _flash_attn_forward(
|
| 145 |
+
self,
|
| 146 |
+
q: torch.Tensor,
|
| 147 |
+
k: torch.Tensor,
|
| 148 |
+
v: torch.Tensor,
|
| 149 |
+
attention_mask: Optional[torch.Tensor],
|
| 150 |
+
) -> torch.Tensor:
|
| 151 |
+
"""Run Flash Attention via HuggingFace's ``_flash_attention_forward``.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
q: ``(B, T_q, H, D)``
|
| 155 |
+
k: ``(B, T_k, H, D)``
|
| 156 |
+
v: ``(B, T_k, H, D)``
|
| 157 |
+
attention_mask: ``(B, T_k)`` bool, True = valid token.
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
``(B, T_q, H*D)``
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
B, T_q, H, D = q.shape
|
| 164 |
+
# _flash_attention_forward returns (B, T_q, H, D); handles unpad/varlen internally.
|
| 165 |
+
out = _flash_attention_forward(
|
| 166 |
+
q, k, v,
|
| 167 |
+
attention_mask=attention_mask,
|
| 168 |
+
query_length=T_q,
|
| 169 |
+
is_causal=False,
|
| 170 |
+
dropout=0.0,
|
| 171 |
+
)
|
| 172 |
+
return out.reshape(B, T_q, H * D)
|
| 173 |
+
|
| 174 |
+
# ------------------------------------------------------------------
|
| 175 |
+
# Forward
|
| 176 |
+
# ------------------------------------------------------------------
|
| 177 |
+
|
| 178 |
+
def forward(
|
| 179 |
+
self,
|
| 180 |
+
hidden_states: torch.Tensor,
|
| 181 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 182 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 183 |
+
) -> torch.Tensor:
|
| 184 |
+
B, T, _ = hidden_states.shape
|
| 185 |
+
|
| 186 |
+
q = self.to_q(hidden_states)
|
| 187 |
+
kv_input = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
| 188 |
+
k = self.to_k(kv_input)
|
| 189 |
+
v = self.to_v(kv_input)
|
| 190 |
+
|
| 191 |
+
if self.attn_implementation == "flash_attention_2":
|
| 192 |
+
# Flash Attention expects (B, S, H, D)
|
| 193 |
+
q = q.view(B, T, self.num_heads, self.head_dim)
|
| 194 |
+
k = k.view(B, -1, self.num_heads, self.head_dim)
|
| 195 |
+
v = v.view(B, -1, self.num_heads, self.head_dim)
|
| 196 |
+
attn_output = self._flash_attn_forward(q, k, v, attention_mask)
|
| 197 |
+
else:
|
| 198 |
+
# SDPA expects (B, H, S, D)
|
| 199 |
+
q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
|
| 200 |
+
k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
| 201 |
+
v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
| 202 |
+
|
| 203 |
+
# Expand (B, S) bool mask → (B, 1, 1, S) for broadcasting.
|
| 204 |
+
sdpa_mask = None
|
| 205 |
+
if attention_mask is not None:
|
| 206 |
+
if attention_mask.dim() == 2:
|
| 207 |
+
sdpa_mask = attention_mask[:, None, None, :]
|
| 208 |
+
else:
|
| 209 |
+
sdpa_mask = attention_mask
|
| 210 |
+
|
| 211 |
+
attn_output = F.scaled_dot_product_attention(
|
| 212 |
+
q, k, v, attn_mask=sdpa_mask, dropout_p=0.0
|
| 213 |
+
)
|
| 214 |
+
attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, -1)
|
| 215 |
+
|
| 216 |
+
return self.to_out(attn_output)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class FeedForward(nn.Module):
|
| 220 |
+
"""Feed-forward network with GELU activation."""
|
| 221 |
+
|
| 222 |
+
def __init__(self, dim: int, dropout: float = 0.0, mult: int = 4):
|
| 223 |
+
super().__init__()
|
| 224 |
+
inner_dim = dim * mult
|
| 225 |
+
self.net = nn.Sequential(
|
| 226 |
+
nn.Linear(dim, inner_dim),
|
| 227 |
+
nn.GELU(approximate="tanh"),
|
| 228 |
+
nn.Dropout(dropout),
|
| 229 |
+
nn.Linear(inner_dim, dim),
|
| 230 |
+
nn.Dropout(dropout),
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 234 |
+
return self.net(x)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class BasicTransformerBlock(nn.Module):
|
| 238 |
+
"""Transformer block with self/cross-attention, optional AdaLayerNorm, and feed-forward.
|
| 239 |
+
|
| 240 |
+
When cross_attention_dim is set, the attention block performs cross-attention
|
| 241 |
+
to encoder_hidden_states. Otherwise, it performs self-attention.
|
| 242 |
+
"""
|
| 243 |
+
|
| 244 |
+
def __init__(
|
| 245 |
+
self,
|
| 246 |
+
dim: int,
|
| 247 |
+
num_attention_heads: int,
|
| 248 |
+
attention_head_dim: int,
|
| 249 |
+
dropout: float = 0.0,
|
| 250 |
+
cross_attention_dim: Optional[int] = None,
|
| 251 |
+
norm_type: str = "ada_norm",
|
| 252 |
+
final_dropout: bool = False,
|
| 253 |
+
attn_implementation: str = "sdpa",
|
| 254 |
+
):
|
| 255 |
+
super().__init__()
|
| 256 |
+
self.norm_type = norm_type
|
| 257 |
+
|
| 258 |
+
if norm_type == "ada_norm":
|
| 259 |
+
self.norm1 = AdaLayerNorm(dim)
|
| 260 |
+
else:
|
| 261 |
+
self.norm1 = nn.LayerNorm(dim)
|
| 262 |
+
|
| 263 |
+
self.attn1 = DiTAttention(
|
| 264 |
+
query_dim=dim,
|
| 265 |
+
num_heads=num_attention_heads,
|
| 266 |
+
head_dim=attention_head_dim,
|
| 267 |
+
cross_attention_dim=cross_attention_dim,
|
| 268 |
+
dropout=dropout,
|
| 269 |
+
attn_implementation=attn_implementation,
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
self.norm3 = nn.LayerNorm(dim)
|
| 273 |
+
self.ff = FeedForward(dim, dropout=dropout)
|
| 274 |
+
self.final_dropout = nn.Dropout(dropout) if final_dropout else None
|
| 275 |
+
|
| 276 |
+
def forward(
|
| 277 |
+
self,
|
| 278 |
+
hidden_states: torch.Tensor,
|
| 279 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 280 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 281 |
+
temb: Optional[torch.Tensor] = None,
|
| 282 |
+
) -> torch.Tensor:
|
| 283 |
+
if self.norm_type == "ada_norm":
|
| 284 |
+
norm_hidden_states = self.norm1(hidden_states, temb)
|
| 285 |
+
else:
|
| 286 |
+
norm_hidden_states = self.norm1(hidden_states)
|
| 287 |
+
|
| 288 |
+
attn_output = self.attn1(
|
| 289 |
+
norm_hidden_states,
|
| 290 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 291 |
+
attention_mask=encoder_attention_mask,
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
if self.final_dropout is not None:
|
| 295 |
+
attn_output = self.final_dropout(attn_output)
|
| 296 |
+
|
| 297 |
+
hidden_states = attn_output + hidden_states
|
| 298 |
+
|
| 299 |
+
norm_hidden_states = self.norm3(hidden_states)
|
| 300 |
+
ff_output = self.ff(norm_hidden_states)
|
| 301 |
+
hidden_states = ff_output + hidden_states
|
| 302 |
+
|
| 303 |
+
return hidden_states
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class DiT(nn.Module):
|
| 307 |
+
"""Diffusion Transformer with cross-attention to VLM context features.
|
| 308 |
+
|
| 309 |
+
Interleaves cross-attention blocks (attending to encoder_hidden_states)
|
| 310 |
+
with self-attention blocks when interleave_self_attention=True.
|
| 311 |
+
Uses AdaLayerNorm for timestep conditioning throughout.
|
| 312 |
+
|
| 313 |
+
Output block applies timestep-conditioned scale-shift before final projection.
|
| 314 |
+
"""
|
| 315 |
+
|
| 316 |
+
def __init__(
|
| 317 |
+
self,
|
| 318 |
+
num_attention_heads: int = 12,
|
| 319 |
+
attention_head_dim: int = 64,
|
| 320 |
+
output_dim: int = 768,
|
| 321 |
+
num_layers: int = 12,
|
| 322 |
+
dropout: float = 0.1,
|
| 323 |
+
norm_type: str = "ada_norm",
|
| 324 |
+
final_dropout: bool = True,
|
| 325 |
+
interleave_self_attention: bool = False,
|
| 326 |
+
cross_attention_dim: Optional[int] = None,
|
| 327 |
+
attn_implementation: str = "sdpa",
|
| 328 |
+
):
|
| 329 |
+
super().__init__()
|
| 330 |
+
self.inner_dim = num_attention_heads * attention_head_dim
|
| 331 |
+
self.output_dim = output_dim
|
| 332 |
+
self.num_layers = num_layers
|
| 333 |
+
self.interleave_self_attention = interleave_self_attention
|
| 334 |
+
|
| 335 |
+
self.timestep_encoder = TimestepEncoder(self.inner_dim)
|
| 336 |
+
|
| 337 |
+
all_blocks = []
|
| 338 |
+
for idx in range(num_layers):
|
| 339 |
+
use_self_attn = idx % 2 == 1 and interleave_self_attention
|
| 340 |
+
curr_cross_attention_dim = cross_attention_dim if not use_self_attn else None
|
| 341 |
+
|
| 342 |
+
all_blocks.append(
|
| 343 |
+
BasicTransformerBlock(
|
| 344 |
+
dim=self.inner_dim,
|
| 345 |
+
num_attention_heads=num_attention_heads,
|
| 346 |
+
attention_head_dim=attention_head_dim,
|
| 347 |
+
dropout=dropout,
|
| 348 |
+
cross_attention_dim=curr_cross_attention_dim,
|
| 349 |
+
norm_type=norm_type,
|
| 350 |
+
final_dropout=final_dropout,
|
| 351 |
+
attn_implementation=attn_implementation,
|
| 352 |
+
)
|
| 353 |
+
)
|
| 354 |
+
self.transformer_blocks = nn.ModuleList(all_blocks)
|
| 355 |
+
|
| 356 |
+
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
| 357 |
+
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
|
| 358 |
+
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
|
| 359 |
+
|
| 360 |
+
def forward(
|
| 361 |
+
self,
|
| 362 |
+
hidden_states: torch.Tensor,
|
| 363 |
+
encoder_hidden_states: torch.Tensor,
|
| 364 |
+
timestep: torch.LongTensor,
|
| 365 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 366 |
+
) -> torch.Tensor:
|
| 367 |
+
temb = self.timestep_encoder(timestep)
|
| 368 |
+
|
| 369 |
+
hidden_states = hidden_states.contiguous()
|
| 370 |
+
encoder_hidden_states = encoder_hidden_states.contiguous()
|
| 371 |
+
|
| 372 |
+
for idx, block in enumerate(self.transformer_blocks):
|
| 373 |
+
if idx % 2 == 1 and self.interleave_self_attention:
|
| 374 |
+
hidden_states = block(
|
| 375 |
+
hidden_states,
|
| 376 |
+
encoder_hidden_states=None,
|
| 377 |
+
encoder_attention_mask=None,
|
| 378 |
+
temb=temb,
|
| 379 |
+
)
|
| 380 |
+
else:
|
| 381 |
+
hidden_states = block(
|
| 382 |
+
hidden_states,
|
| 383 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 384 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 385 |
+
temb=temb,
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
conditioning = temb
|
| 389 |
+
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=-1)
|
| 390 |
+
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
| 391 |
+
return self.proj_out_2(hidden_states)
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class AlternateVLDiT(DiT):
|
| 395 |
+
"""DiT variant that separates visual and text tokens during cross-attention.
|
| 396 |
+
|
| 397 |
+
Mirrors GR00T's AlternateVLDiT: even-indexed blocks do cross-attention,
|
| 398 |
+
alternating every ``attend_text_every_n_blocks`` between text tokens and
|
| 399 |
+
visual tokens. Odd-indexed blocks do self-attention (requires
|
| 400 |
+
``interleave_self_attention=True``).
|
| 401 |
+
|
| 402 |
+
When no visual tokens are present (``image_mask`` is None or all-False),
|
| 403 |
+
all valid tokens are treated as text.
|
| 404 |
+
"""
|
| 405 |
+
|
| 406 |
+
def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs):
|
| 407 |
+
super().__init__(*args, **kwargs)
|
| 408 |
+
assert self.interleave_self_attention, (
|
| 409 |
+
"AlternateVLDiT requires interleave_self_attention=True"
|
| 410 |
+
)
|
| 411 |
+
self.attend_text_every_n_blocks = attend_text_every_n_blocks
|
| 412 |
+
|
| 413 |
+
def forward(
|
| 414 |
+
self,
|
| 415 |
+
hidden_states: torch.Tensor,
|
| 416 |
+
encoder_hidden_states: torch.Tensor,
|
| 417 |
+
timestep: torch.LongTensor,
|
| 418 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 419 |
+
image_mask: Optional[torch.Tensor] = None,
|
| 420 |
+
) -> torch.Tensor:
|
| 421 |
+
"""
|
| 422 |
+
Args:
|
| 423 |
+
encoder_attention_mask: (B, S) bool – True = valid VLM token.
|
| 424 |
+
image_mask: (B, S) bool – True = visual token position.
|
| 425 |
+
If None, all valid tokens are treated as text.
|
| 426 |
+
"""
|
| 427 |
+
temb = self.timestep_encoder(timestep)
|
| 428 |
+
hidden_states = hidden_states.contiguous()
|
| 429 |
+
encoder_hidden_states = encoder_hidden_states.contiguous()
|
| 430 |
+
|
| 431 |
+
B, S, _ = encoder_hidden_states.shape
|
| 432 |
+
backbone_mask = (
|
| 433 |
+
encoder_attention_mask.bool()
|
| 434 |
+
if encoder_attention_mask is not None
|
| 435 |
+
else torch.ones(B, S, dtype=torch.bool, device=hidden_states.device)
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
if image_mask is not None and image_mask.any():
|
| 439 |
+
vis_mask = image_mask.bool() & backbone_mask # visual tokens
|
| 440 |
+
text_mask = (~image_mask.bool()) & backbone_mask # text tokens
|
| 441 |
+
else:
|
| 442 |
+
# No visual tokens – treat everything as text.
|
| 443 |
+
vis_mask = torch.zeros_like(backbone_mask)
|
| 444 |
+
text_mask = backbone_mask
|
| 445 |
+
|
| 446 |
+
for idx, block in enumerate(self.transformer_blocks):
|
| 447 |
+
if idx % 2 == 1:
|
| 448 |
+
# Self-attention block.
|
| 449 |
+
hidden_states = block(
|
| 450 |
+
hidden_states,
|
| 451 |
+
encoder_hidden_states=None,
|
| 452 |
+
encoder_attention_mask=None,
|
| 453 |
+
temb=temb,
|
| 454 |
+
)
|
| 455 |
+
else:
|
| 456 |
+
# Cross-attention block: alternate text / visual every N blocks.
|
| 457 |
+
if idx % (2 * self.attend_text_every_n_blocks) == 0:
|
| 458 |
+
curr_mask = text_mask
|
| 459 |
+
else:
|
| 460 |
+
curr_mask = vis_mask
|
| 461 |
+
hidden_states = block(
|
| 462 |
+
hidden_states,
|
| 463 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 464 |
+
encoder_attention_mask=curr_mask,
|
| 465 |
+
temb=temb,
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
conditioning = temb
|
| 469 |
+
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=-1)
|
| 470 |
+
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
| 471 |
+
return self.proj_out_2(hidden_states)
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
class ActionEncoder(nn.Module):
|
| 475 |
+
"""Encodes noisy actions (optionally concatenated with DOF mask) and timestep
|
| 476 |
+
into hidden features via MLP + sinusoidal time encoding.
|
| 477 |
+
|
| 478 |
+
Architecture: Linear → concat(action_emb, time_emb) → SiLU + Linear → Linear
|
| 479 |
+
"""
|
| 480 |
+
|
| 481 |
+
def __init__(self, action_input_dim: int, hidden_size: int):
|
| 482 |
+
super().__init__()
|
| 483 |
+
self.hidden_size = hidden_size
|
| 484 |
+
self.layer1 = nn.Linear(action_input_dim, hidden_size)
|
| 485 |
+
self.layer2 = nn.Linear(2 * hidden_size, hidden_size)
|
| 486 |
+
self.layer3 = nn.Linear(hidden_size, hidden_size)
|
| 487 |
+
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
| 488 |
+
|
| 489 |
+
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
|
| 490 |
+
"""
|
| 491 |
+
Args:
|
| 492 |
+
actions: (B, T, action_input_dim) noisy actions (+ DOF mask)
|
| 493 |
+
timesteps: (B,) discretized timesteps
|
| 494 |
+
"""
|
| 495 |
+
B, T, _ = actions.shape
|
| 496 |
+
timesteps_expanded = timesteps.unsqueeze(1).expand(-1, T)
|
| 497 |
+
|
| 498 |
+
a_emb = self.layer1(actions)
|
| 499 |
+
tau_emb = self.pos_encoding(timesteps_expanded).to(dtype=a_emb.dtype)
|
| 500 |
+
|
| 501 |
+
x = torch.cat([a_emb, tau_emb], dim=-1)
|
| 502 |
+
x = F.silu(self.layer2(x))
|
| 503 |
+
x = self.layer3(x)
|
| 504 |
+
return x
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
class ActionDecoder(nn.Module):
|
| 508 |
+
"""2-layer MLP that decodes DiT output to action-space velocity."""
|
| 509 |
+
|
| 510 |
+
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
|
| 511 |
+
super().__init__()
|
| 512 |
+
self.layer1 = nn.Linear(input_dim, hidden_dim)
|
| 513 |
+
self.layer2 = nn.Linear(hidden_dim, output_dim)
|
| 514 |
+
|
| 515 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 516 |
+
return self.layer2(F.relu(self.layer1(x)))
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
class FlowMatchingDiTHead(nn.Module):
|
| 520 |
+
"""Flow matching action head using DiT (Diffusion Transformer).
|
| 521 |
+
|
| 522 |
+
Replaces the fm_action_expert (Qwen3VLTextModel-based) with a DiT that uses
|
| 523 |
+
explicit cross-attention to VLM hidden states instead of KV cache continuation.
|
| 524 |
+
|
| 525 |
+
Training:
|
| 526 |
+
1. Sample noise and timestep from Beta distribution
|
| 527 |
+
2. Compute noisy trajectory: x_t = (1-t)*noise + t*actions
|
| 528 |
+
3. Compute velocity target: v = actions - noise
|
| 529 |
+
4. Encode noisy actions + DOF mask + timestep → action features
|
| 530 |
+
5. Prepend learned future query tokens
|
| 531 |
+
6. Run DiT with cross-attention to VLM hidden states
|
| 532 |
+
7. Decode to action-space velocity prediction
|
| 533 |
+
|
| 534 |
+
Inference:
|
| 535 |
+
Euler integration from pure noise (t=0) to clean actions (t=1)
|
| 536 |
+
over num_inference_timesteps steps.
|
| 537 |
+
"""
|
| 538 |
+
|
| 539 |
+
def __init__(
|
| 540 |
+
self,
|
| 541 |
+
action_dim: int,
|
| 542 |
+
action_chunk_size: int,
|
| 543 |
+
cross_attention_dim: int,
|
| 544 |
+
num_inference_timesteps: int = 4,
|
| 545 |
+
config: Optional[dict] = None,
|
| 546 |
+
):
|
| 547 |
+
super().__init__()
|
| 548 |
+
cfg = {
|
| 549 |
+
"num_layers": 16,
|
| 550 |
+
"num_attention_heads": 12,
|
| 551 |
+
"attention_head_dim": 64,
|
| 552 |
+
"output_dim": 1024,
|
| 553 |
+
"dropout": 0.2,
|
| 554 |
+
"interleave_self_attention": True,
|
| 555 |
+
"norm_type": "ada_norm",
|
| 556 |
+
"final_dropout": True,
|
| 557 |
+
"add_pos_embed": True,
|
| 558 |
+
"noise_beta_alpha": 1.5,
|
| 559 |
+
"noise_beta_beta": 1.0,
|
| 560 |
+
"noise_s": 0.999,
|
| 561 |
+
"num_timestep_buckets": 1000,
|
| 562 |
+
"attn_implementation": "sdpa",
|
| 563 |
+
"use_alternate_vl_dit": False,
|
| 564 |
+
"attend_text_every_n_blocks": 2,
|
| 565 |
+
}
|
| 566 |
+
if config is not None:
|
| 567 |
+
cfg.update(config)
|
| 568 |
+
# dit_model_type = config.get("dit_model_type")
|
| 569 |
+
# if dit_model_type and dit_model_type in DIT_PRESETS:
|
| 570 |
+
# cfg.update(DIT_PRESETS[dit_model_type])
|
| 571 |
+
# cfg.pop("dit_model_type", None)
|
| 572 |
+
|
| 573 |
+
self.action_dim = action_dim
|
| 574 |
+
self.action_chunk_size = action_chunk_size
|
| 575 |
+
self.num_inference_timesteps = num_inference_timesteps
|
| 576 |
+
self.num_timestep_buckets = cfg["num_timestep_buckets"]
|
| 577 |
+
self.noise_s = cfg["noise_s"]
|
| 578 |
+
self.use_alternate_vl_dit = cfg["use_alternate_vl_dit"]
|
| 579 |
+
self.add_pos_embed = cfg["add_pos_embed"]
|
| 580 |
+
|
| 581 |
+
num_attention_heads = cfg["num_attention_heads"]
|
| 582 |
+
attention_head_dim = cfg["attention_head_dim"]
|
| 583 |
+
output_dim = cfg["output_dim"]
|
| 584 |
+
inner_dim = num_attention_heads * attention_head_dim
|
| 585 |
+
|
| 586 |
+
dit_kwargs = dict(
|
| 587 |
+
num_attention_heads=num_attention_heads,
|
| 588 |
+
attention_head_dim=attention_head_dim,
|
| 589 |
+
output_dim=output_dim,
|
| 590 |
+
num_layers=cfg["num_layers"],
|
| 591 |
+
dropout=cfg["dropout"],
|
| 592 |
+
norm_type=cfg["norm_type"],
|
| 593 |
+
final_dropout=cfg["final_dropout"],
|
| 594 |
+
interleave_self_attention=cfg["interleave_self_attention"],
|
| 595 |
+
cross_attention_dim=cross_attention_dim,
|
| 596 |
+
attn_implementation=cfg["attn_implementation"],
|
| 597 |
+
)
|
| 598 |
+
if self.use_alternate_vl_dit:
|
| 599 |
+
self.dit = AlternateVLDiT(
|
| 600 |
+
**dit_kwargs,
|
| 601 |
+
attend_text_every_n_blocks=cfg["attend_text_every_n_blocks"],
|
| 602 |
+
)
|
| 603 |
+
else:
|
| 604 |
+
self.dit = DiT(**dit_kwargs)
|
| 605 |
+
|
| 606 |
+
# action_dim * 2: noisy action + DOF mask concatenated
|
| 607 |
+
self.action_encoder = ActionEncoder(action_dim * 2, inner_dim)
|
| 608 |
+
self.action_decoder = ActionDecoder(output_dim, inner_dim, action_dim)
|
| 609 |
+
|
| 610 |
+
if self.add_pos_embed:
|
| 611 |
+
max_seq_len = max(action_chunk_size, 256)
|
| 612 |
+
self.position_embedding = nn.Embedding(max_seq_len, inner_dim)
|
| 613 |
+
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
| 614 |
+
|
| 615 |
+
# self.beta_dist = Beta(cfg["noise_beta_alpha"], cfg["noise_beta_beta"])
|
| 616 |
+
self._beta_alpha = cfg["noise_beta_alpha"]
|
| 617 |
+
self._beta_beta = cfg["noise_beta_beta"]
|
| 618 |
+
|
| 619 |
+
def reset_parameters(self):
|
| 620 |
+
"""Re-apply proper initialization.
|
| 621 |
+
|
| 622 |
+
HuggingFace from_pretrained calls _init_weights on modules whose
|
| 623 |
+
parameters are absent from the checkpoint, overwriting any custom
|
| 624 |
+
init done in __init__. Call this after from_pretrained when loading
|
| 625 |
+
from a base VLM checkpoint that does not contain DiT weights.
|
| 626 |
+
"""
|
| 627 |
+
if self.add_pos_embed:
|
| 628 |
+
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
| 629 |
+
for module in self.modules():
|
| 630 |
+
if isinstance(module, nn.Linear):
|
| 631 |
+
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
|
| 632 |
+
if module.bias is not None:
|
| 633 |
+
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
|
| 634 |
+
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
| 635 |
+
nn.init.uniform_(module.bias, -bound, bound)
|
| 636 |
+
elif isinstance(module, nn.LayerNorm):
|
| 637 |
+
if module.elementwise_affine:
|
| 638 |
+
nn.init.ones_(module.weight)
|
| 639 |
+
nn.init.zeros_(module.bias)
|
| 640 |
+
|
| 641 |
+
def sample_time(self, batch_size: int, device, dtype) -> torch.Tensor:
|
| 642 |
+
beta_dist = Beta(self._beta_alpha, self._beta_beta)
|
| 643 |
+
sample = beta_dist.sample([batch_size]).to(device, dtype=dtype).clamp(max=self.noise_s)
|
| 644 |
+
return (self.noise_s - sample) / self.noise_s
|
| 645 |
+
|
| 646 |
+
def _encode_actions(
|
| 647 |
+
self,
|
| 648 |
+
noisy_actions: torch.Tensor,
|
| 649 |
+
t_discretized: torch.Tensor,
|
| 650 |
+
action_dof_mask: Optional[torch.Tensor],
|
| 651 |
+
device,
|
| 652 |
+
) -> torch.Tensor:
|
| 653 |
+
"""Encode noisy actions with DOF mask and timestep, add position embeddings."""
|
| 654 |
+
if action_dof_mask is not None:
|
| 655 |
+
encoder_input = torch.cat(
|
| 656 |
+
[noisy_actions, action_dof_mask.to(noisy_actions.dtype)], dim=-1
|
| 657 |
+
)
|
| 658 |
+
else:
|
| 659 |
+
encoder_input = torch.cat(
|
| 660 |
+
[noisy_actions, torch.ones_like(noisy_actions)], dim=-1
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
action_features = self.action_encoder(encoder_input, t_discretized)
|
| 664 |
+
|
| 665 |
+
if self.add_pos_embed:
|
| 666 |
+
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
| 667 |
+
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
|
| 668 |
+
action_features = action_features + pos_embs
|
| 669 |
+
|
| 670 |
+
return action_features
|
| 671 |
+
|
| 672 |
+
def _dit_forward(
|
| 673 |
+
self,
|
| 674 |
+
sa_embs: torch.Tensor,
|
| 675 |
+
vl_embs: torch.Tensor,
|
| 676 |
+
t_discretized: torch.LongTensor,
|
| 677 |
+
encoder_attention_mask: Optional[torch.Tensor],
|
| 678 |
+
image_mask: Optional[torch.Tensor],
|
| 679 |
+
) -> torch.Tensor:
|
| 680 |
+
if self.use_alternate_vl_dit:
|
| 681 |
+
return self.dit(
|
| 682 |
+
hidden_states=sa_embs,
|
| 683 |
+
encoder_hidden_states=vl_embs,
|
| 684 |
+
timestep=t_discretized,
|
| 685 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 686 |
+
image_mask=image_mask,
|
| 687 |
+
)
|
| 688 |
+
return self.dit(
|
| 689 |
+
hidden_states=sa_embs,
|
| 690 |
+
encoder_hidden_states=vl_embs,
|
| 691 |
+
timestep=t_discretized,
|
| 692 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 693 |
+
)
|
| 694 |
+
|
| 695 |
+
def forward(
|
| 696 |
+
self,
|
| 697 |
+
vl_embs: torch.Tensor,
|
| 698 |
+
actions: torch.Tensor,
|
| 699 |
+
action_dof_mask: Optional[torch.Tensor] = None,
|
| 700 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 701 |
+
image_mask: Optional[torch.Tensor] = None,
|
| 702 |
+
) -> tuple:
|
| 703 |
+
"""Training forward pass.
|
| 704 |
+
|
| 705 |
+
Args:
|
| 706 |
+
vl_embs: (B, S, D) VLM hidden states for cross-attention
|
| 707 |
+
actions: (B, T, action_dim) ground truth action trajectories
|
| 708 |
+
action_dof_mask: (B, T, action_dim) DOF validity mask
|
| 709 |
+
encoder_attention_mask: (B, S) bool – True = valid VLM token
|
| 710 |
+
image_mask: (B, S) bool – True = visual token (used by AlternateVLDiT)
|
| 711 |
+
|
| 712 |
+
Returns:
|
| 713 |
+
(pred_v, velocity): predicted velocity and target velocity, both (B, T, action_dim)
|
| 714 |
+
"""
|
| 715 |
+
device = vl_embs.device
|
| 716 |
+
B = actions.shape[0]
|
| 717 |
+
|
| 718 |
+
noise = torch.randn(actions.shape, device=device, dtype=actions.dtype)
|
| 719 |
+
t = self.sample_time(B, device=device, dtype=actions.dtype)
|
| 720 |
+
t_expanded = t[:, None, None]
|
| 721 |
+
|
| 722 |
+
noisy_trajectory = (1 - t_expanded) * noise + t_expanded * actions
|
| 723 |
+
velocity = actions - noise
|
| 724 |
+
|
| 725 |
+
t_discretized = (t * self.num_timestep_buckets).long()
|
| 726 |
+
|
| 727 |
+
action_features = self._encode_actions(noisy_trajectory, t_discretized, action_dof_mask, device)
|
| 728 |
+
|
| 729 |
+
model_output = self._dit_forward(
|
| 730 |
+
action_features, vl_embs, t_discretized, encoder_attention_mask, image_mask
|
| 731 |
+
)
|
| 732 |
+
|
| 733 |
+
pred = self.action_decoder(model_output)
|
| 734 |
+
pred_v = pred[:, :actions.shape[1]]
|
| 735 |
+
|
| 736 |
+
return pred_v, velocity
|
| 737 |
+
|
| 738 |
+
@torch.no_grad()
|
| 739 |
+
def predict_action(
|
| 740 |
+
self,
|
| 741 |
+
vl_embs: torch.Tensor,
|
| 742 |
+
action_dof_mask: Optional[torch.Tensor] = None,
|
| 743 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 744 |
+
image_mask: Optional[torch.Tensor] = None,
|
| 745 |
+
) -> torch.Tensor:
|
| 746 |
+
"""Inference: denoise actions from noise using Euler integration.
|
| 747 |
+
|
| 748 |
+
Args:
|
| 749 |
+
vl_embs: (B, S, D) VLM hidden states
|
| 750 |
+
action_dof_mask: optional (B, T, action_dim) or (1, T, action_dim) DOF mask
|
| 751 |
+
encoder_attention_mask: (B, S) bool – True = valid VLM token
|
| 752 |
+
image_mask: (B, S) bool – True = visual token (used by AlternateVLDiT)
|
| 753 |
+
|
| 754 |
+
Returns:
|
| 755 |
+
(B, T, action_dim) denoised action trajectories
|
| 756 |
+
"""
|
| 757 |
+
B = vl_embs.shape[0]
|
| 758 |
+
device = vl_embs.device
|
| 759 |
+
dtype = vl_embs.dtype
|
| 760 |
+
|
| 761 |
+
actions = torch.randn(
|
| 762 |
+
(B, self.action_chunk_size, self.action_dim),
|
| 763 |
+
device=device, dtype=dtype,
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
dt = 1.0 / self.num_inference_timesteps
|
| 767 |
+
|
| 768 |
+
for step in range(self.num_inference_timesteps):
|
| 769 |
+
t_cont = step / float(self.num_inference_timesteps)
|
| 770 |
+
t_discretized_val = int(t_cont * self.num_timestep_buckets)
|
| 771 |
+
timesteps_tensor = torch.full((B,), t_discretized_val, device=device, dtype=torch.long)
|
| 772 |
+
|
| 773 |
+
action_features = self._encode_actions(actions, timesteps_tensor, action_dof_mask, device)
|
| 774 |
+
|
| 775 |
+
model_output = self._dit_forward(
|
| 776 |
+
action_features, vl_embs, timesteps_tensor, encoder_attention_mask, image_mask
|
| 777 |
+
)
|
| 778 |
+
|
| 779 |
+
pred = self.action_decoder(model_output)
|
| 780 |
+
pred_velocity = pred[:, :self.action_chunk_size]
|
| 781 |
+
|
| 782 |
+
actions = actions + dt * pred_velocity
|
| 783 |
+
|
| 784 |
+
return actions
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
# ============================================================================
|
| 788 |
+
# Pi0.5-style KV-cache action expert (VLM K/V concat + GQA + SwiGLU FFN)
|
| 789 |
+
# ============================================================================
|
| 790 |
+
class AdaRMSNorm(nn.Module):
|
| 791 |
+
"""Adaptive RMS normalization: (scale, shift, gate) from cond; zero-init."""
|
| 792 |
+
|
| 793 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 794 |
+
super().__init__()
|
| 795 |
+
self.eps = eps
|
| 796 |
+
self.modulation = nn.Linear(dim, dim * 3)
|
| 797 |
+
nn.init.zeros_(self.modulation.weight)
|
| 798 |
+
nn.init.zeros_(self.modulation.bias)
|
| 799 |
+
|
| 800 |
+
def forward(self, x: torch.Tensor, cond: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 801 |
+
var = x.float().pow(2).mean(-1, keepdim=True)
|
| 802 |
+
normed = (x * torch.rsqrt(var + self.eps)).to(x.dtype)
|
| 803 |
+
scale, shift, gate = self.modulation(cond).chunk(3, dim=-1)
|
| 804 |
+
normed = normed * (1 + scale[:, None]) + shift[:, None]
|
| 805 |
+
return normed, gate[:, None]
|
| 806 |
+
|
| 807 |
+
|
| 808 |
+
class SwiGLUFeedForward(nn.Module):
|
| 809 |
+
"""SiLU(gate_proj(x)) * up_proj(x) → down_proj."""
|
| 810 |
+
|
| 811 |
+
def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0, bias: bool = True):
|
| 812 |
+
super().__init__()
|
| 813 |
+
self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias)
|
| 814 |
+
self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
|
| 815 |
+
self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
|
| 816 |
+
self.dropout = nn.Dropout(dropout)
|
| 817 |
+
|
| 818 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 819 |
+
return self.down_proj(self.dropout(F.silu(self.gate_proj(x)) * self.up_proj(x)))
|
| 820 |
+
|
| 821 |
+
|
| 822 |
+
class MoTAttention(nn.Module):
|
| 823 |
+
"""Action Q attends to concatenated [VLM KV cache ; action KV]; GQA expand for SDPA."""
|
| 824 |
+
|
| 825 |
+
def __init__(
|
| 826 |
+
self,
|
| 827 |
+
hidden_size: int,
|
| 828 |
+
num_attention_heads: int,
|
| 829 |
+
num_kv_heads: int,
|
| 830 |
+
head_dim: int,
|
| 831 |
+
dropout: float = 0.0,
|
| 832 |
+
bias: bool = True,
|
| 833 |
+
):
|
| 834 |
+
super().__init__()
|
| 835 |
+
if num_attention_heads % num_kv_heads != 0:
|
| 836 |
+
raise ValueError(
|
| 837 |
+
f"num_attention_heads ({num_attention_heads}) must be divisible by "
|
| 838 |
+
f"num_kv_heads ({num_kv_heads})"
|
| 839 |
+
)
|
| 840 |
+
self.num_attention_heads = num_attention_heads
|
| 841 |
+
self.num_kv_heads = num_kv_heads
|
| 842 |
+
self.head_dim = head_dim
|
| 843 |
+
q_dim = num_attention_heads * head_dim
|
| 844 |
+
kv_dim = num_kv_heads * head_dim
|
| 845 |
+
self.q_proj = nn.Linear(hidden_size, q_dim, bias=bias)
|
| 846 |
+
self.k_proj = nn.Linear(hidden_size, kv_dim, bias=bias)
|
| 847 |
+
self.v_proj = nn.Linear(hidden_size, kv_dim, bias=bias)
|
| 848 |
+
self.o_proj = nn.Linear(q_dim, hidden_size, bias=bias)
|
| 849 |
+
self.dropout = nn.Dropout(dropout)
|
| 850 |
+
|
| 851 |
+
def forward(
|
| 852 |
+
self,
|
| 853 |
+
action_hidden: torch.Tensor,
|
| 854 |
+
vlm_cached_k: torch.Tensor,
|
| 855 |
+
vlm_cached_v: torch.Tensor,
|
| 856 |
+
vlm_attention_mask: Optional[torch.Tensor] = None,
|
| 857 |
+
) -> torch.Tensor:
|
| 858 |
+
B, T_a, _ = action_hidden.shape
|
| 859 |
+
|
| 860 |
+
q = self.q_proj(action_hidden)
|
| 861 |
+
act_k = self.k_proj(action_hidden)
|
| 862 |
+
act_v = self.v_proj(action_hidden)
|
| 863 |
+
|
| 864 |
+
q = q.view(B, T_a, self.num_attention_heads, self.head_dim).transpose(1, 2)
|
| 865 |
+
act_k = act_k.view(B, T_a, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
| 866 |
+
act_v = act_v.view(B, T_a, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
| 867 |
+
|
| 868 |
+
k = torch.cat([vlm_cached_k, act_k], dim=2)
|
| 869 |
+
v = torch.cat([vlm_cached_v, act_v], dim=2)
|
| 870 |
+
|
| 871 |
+
repeat_factor = self.num_attention_heads // self.num_kv_heads
|
| 872 |
+
k = k.repeat_interleave(repeat_factor, dim=1)
|
| 873 |
+
v = v.repeat_interleave(repeat_factor, dim=1)
|
| 874 |
+
|
| 875 |
+
sdpa_mask = None
|
| 876 |
+
if vlm_attention_mask is not None:
|
| 877 |
+
action_mask = vlm_attention_mask.new_ones(B, T_a)
|
| 878 |
+
combined_mask = torch.cat([vlm_attention_mask, action_mask], dim=1)
|
| 879 |
+
sdpa_mask = combined_mask[:, None, None, :]
|
| 880 |
+
|
| 881 |
+
attn_out = F.scaled_dot_product_attention(
|
| 882 |
+
q, k, v, attn_mask=sdpa_mask, dropout_p=0.0,
|
| 883 |
+
)
|
| 884 |
+
attn_out = attn_out.transpose(1, 2).contiguous().view(B, T_a, -1)
|
| 885 |
+
return self.dropout(self.o_proj(attn_out))
|
| 886 |
+
|
| 887 |
+
|
| 888 |
+
class MoTBlock(nn.Module):
|
| 889 |
+
"""AdaRMSNorm → attention → gated residual → AdaRMSNorm → SwiGLU FFN → gated residual."""
|
| 890 |
+
|
| 891 |
+
def __init__(
|
| 892 |
+
self,
|
| 893 |
+
hidden_size: int,
|
| 894 |
+
num_attention_heads: int,
|
| 895 |
+
num_kv_heads: int,
|
| 896 |
+
head_dim: int,
|
| 897 |
+
intermediate_size: int,
|
| 898 |
+
dropout: float = 0.0,
|
| 899 |
+
):
|
| 900 |
+
super().__init__()
|
| 901 |
+
self.pre_attn_norm = AdaRMSNorm(hidden_size)
|
| 902 |
+
self.attn = MoTAttention(
|
| 903 |
+
hidden_size=hidden_size,
|
| 904 |
+
num_attention_heads=num_attention_heads,
|
| 905 |
+
num_kv_heads=num_kv_heads,
|
| 906 |
+
head_dim=head_dim,
|
| 907 |
+
dropout=dropout,
|
| 908 |
+
)
|
| 909 |
+
self.pre_ffn_norm = AdaRMSNorm(hidden_size)
|
| 910 |
+
self.ffn = SwiGLUFeedForward(hidden_size, intermediate_size, dropout=dropout)
|
| 911 |
+
|
| 912 |
+
def forward(
|
| 913 |
+
self,
|
| 914 |
+
action_hidden: torch.Tensor,
|
| 915 |
+
vlm_cached_k: torch.Tensor,
|
| 916 |
+
vlm_cached_v: torch.Tensor,
|
| 917 |
+
adarms_cond: torch.Tensor,
|
| 918 |
+
vlm_attention_mask: Optional[torch.Tensor] = None,
|
| 919 |
+
) -> torch.Tensor:
|
| 920 |
+
normed, gate1 = self.pre_attn_norm(action_hidden, adarms_cond)
|
| 921 |
+
attn_out = self.attn(normed, vlm_cached_k, vlm_cached_v, vlm_attention_mask)
|
| 922 |
+
action_hidden = action_hidden + attn_out * gate1
|
| 923 |
+
|
| 924 |
+
normed2, gate2 = self.pre_ffn_norm(action_hidden, adarms_cond)
|
| 925 |
+
action_hidden = action_hidden + self.ffn(normed2) * gate2
|
| 926 |
+
return action_hidden
|
| 927 |
+
|
| 928 |
+
|
| 929 |
+
class MoTDiT(nn.Module):
|
| 930 |
+
"""Stack of ActionBlocks; each block uses one VLM layer's KV pair."""
|
| 931 |
+
|
| 932 |
+
def __init__(
|
| 933 |
+
self,
|
| 934 |
+
hidden_size: int,
|
| 935 |
+
num_attention_heads: int,
|
| 936 |
+
num_kv_heads: int,
|
| 937 |
+
head_dim: int,
|
| 938 |
+
intermediate_size: int,
|
| 939 |
+
num_layers: int,
|
| 940 |
+
dropout: float = 0.2,
|
| 941 |
+
):
|
| 942 |
+
super().__init__()
|
| 943 |
+
self.num_layers = num_layers
|
| 944 |
+
self.blocks = nn.ModuleList([
|
| 945 |
+
MoTBlock(
|
| 946 |
+
hidden_size=hidden_size,
|
| 947 |
+
num_attention_heads=num_attention_heads,
|
| 948 |
+
num_kv_heads=num_kv_heads,
|
| 949 |
+
head_dim=head_dim,
|
| 950 |
+
intermediate_size=intermediate_size,
|
| 951 |
+
dropout=dropout,
|
| 952 |
+
)
|
| 953 |
+
for _ in range(num_layers)
|
| 954 |
+
])
|
| 955 |
+
self.final_norm = AdaRMSNorm(hidden_size)
|
| 956 |
+
|
| 957 |
+
def forward(
|
| 958 |
+
self,
|
| 959 |
+
action_hidden: torch.Tensor,
|
| 960 |
+
vlm_kv_cache: list,
|
| 961 |
+
adarms_cond: torch.Tensor,
|
| 962 |
+
vlm_attention_mask: Optional[torch.Tensor] = None,
|
| 963 |
+
) -> torch.Tensor:
|
| 964 |
+
for idx, block in enumerate(self.blocks):
|
| 965 |
+
cached_k, cached_v = vlm_kv_cache[idx]
|
| 966 |
+
action_hidden = block(
|
| 967 |
+
action_hidden, cached_k, cached_v, adarms_cond, vlm_attention_mask,
|
| 968 |
+
)
|
| 969 |
+
action_hidden, _ = self.final_norm(action_hidden, adarms_cond)
|
| 970 |
+
return action_hidden
|
| 971 |
+
|
| 972 |
+
|
| 973 |
+
def _kv_pairs_from_past_key_values(past_key_values: Cache) -> list[tuple[torch.Tensor, torch.Tensor]]:
|
| 974 |
+
"""Per-layer (K, V) from a HuggingFace decoder KV cache (order matches transformer layers)."""
|
| 975 |
+
return [
|
| 976 |
+
(past_key_values[i][0], past_key_values[i][1])
|
| 977 |
+
for i in range(len(past_key_values))
|
| 978 |
+
]
|
| 979 |
+
|
| 980 |
+
|
| 981 |
+
class MoTFlowMatchingHead(nn.Module):
|
| 982 |
+
"""Flow matching head: MoT-style action expert over VLM KV cache (concat + GQA)."""
|
| 983 |
+
|
| 984 |
+
def __init__(
|
| 985 |
+
self,
|
| 986 |
+
action_dim: int,
|
| 987 |
+
action_chunk_size: int,
|
| 988 |
+
vlm_config,
|
| 989 |
+
num_inference_timesteps: int = 10,
|
| 990 |
+
config: Optional[dict] = None,
|
| 991 |
+
):
|
| 992 |
+
super().__init__()
|
| 993 |
+
|
| 994 |
+
_vlm_num_q_heads = 8 # vlm_config.num_attention_heads // 2 # optional: 8
|
| 995 |
+
_vlm_num_kv_heads = vlm_config.num_key_value_heads # 8
|
| 996 |
+
_vlm_head_dim = getattr(
|
| 997 |
+
vlm_config, "head_dim", vlm_config.hidden_size // vlm_config.num_attention_heads
|
| 998 |
+
) # 128
|
| 999 |
+
|
| 1000 |
+
cfg = {
|
| 1001 |
+
"hidden_size": 1024, # vlm_config.hidden_size // 2,
|
| 1002 |
+
# "hidden_size": vlm_config.hidden_size // 2,
|
| 1003 |
+
"intermediate_size": vlm_config.intermediate_size // 4,
|
| 1004 |
+
"expert_num_layers": vlm_config.num_hidden_layers,
|
| 1005 |
+
# Attention dims default to VLM values (required for KV cache compat)
|
| 1006 |
+
"num_attention_heads": _vlm_num_q_heads,
|
| 1007 |
+
"num_kv_heads": _vlm_num_kv_heads,
|
| 1008 |
+
"head_dim": _vlm_head_dim,
|
| 1009 |
+
# Noise schedule
|
| 1010 |
+
"dropout": 0.2,
|
| 1011 |
+
"add_pos_embed": True,
|
| 1012 |
+
"noise_beta_alpha": 1.5,
|
| 1013 |
+
"noise_beta_beta": 1.0,
|
| 1014 |
+
"noise_s": 0.999,
|
| 1015 |
+
"num_timestep_buckets": 1000,
|
| 1016 |
+
}
|
| 1017 |
+
if config is not None:
|
| 1018 |
+
config = cfg.copy()
|
| 1019 |
+
|
| 1020 |
+
num_attention_heads = cfg["num_attention_heads"]
|
| 1021 |
+
num_kv_heads = cfg["num_kv_heads"]
|
| 1022 |
+
head_dim = cfg["head_dim"]
|
| 1023 |
+
hidden_size = cfg["hidden_size"]
|
| 1024 |
+
intermediate_size = cfg["intermediate_size"]
|
| 1025 |
+
num_layers = cfg["expert_num_layers"]
|
| 1026 |
+
|
| 1027 |
+
self.action_dim = action_dim
|
| 1028 |
+
self.action_chunk_size = action_chunk_size
|
| 1029 |
+
self.num_inference_timesteps = num_inference_timesteps
|
| 1030 |
+
self.num_timestep_buckets = cfg["num_timestep_buckets"]
|
| 1031 |
+
self.noise_s = cfg["noise_s"]
|
| 1032 |
+
self.add_pos_embed = cfg["add_pos_embed"]
|
| 1033 |
+
|
| 1034 |
+
self.action_in_proj = nn.Linear(action_dim * 2, hidden_size)
|
| 1035 |
+
self.action_out_proj = nn.Linear(hidden_size, action_dim)
|
| 1036 |
+
|
| 1037 |
+
self.time_sinusoidal = SinusoidalPositionalEncoding(hidden_size)
|
| 1038 |
+
self.time_mlp_1 = nn.Linear(hidden_size, hidden_size)
|
| 1039 |
+
self.time_mlp_2 = nn.Linear(hidden_size, hidden_size)
|
| 1040 |
+
|
| 1041 |
+
if self.add_pos_embed:
|
| 1042 |
+
max_seq = max(action_chunk_size, 256)
|
| 1043 |
+
self.position_embedding = nn.Embedding(max_seq, hidden_size)
|
| 1044 |
+
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
| 1045 |
+
|
| 1046 |
+
self.dit = MoTDiT(
|
| 1047 |
+
hidden_size=hidden_size,
|
| 1048 |
+
num_attention_heads=num_attention_heads,
|
| 1049 |
+
num_kv_heads=num_kv_heads,
|
| 1050 |
+
head_dim=head_dim,
|
| 1051 |
+
intermediate_size=intermediate_size,
|
| 1052 |
+
num_layers=num_layers,
|
| 1053 |
+
dropout=cfg["dropout"],
|
| 1054 |
+
)
|
| 1055 |
+
|
| 1056 |
+
self._beta_alpha = cfg["noise_beta_alpha"]
|
| 1057 |
+
self._beta_beta = cfg["noise_beta_beta"]
|
| 1058 |
+
|
| 1059 |
+
@property
|
| 1060 |
+
def num_dit_layers(self) -> int:
|
| 1061 |
+
"""Number of expert blocks; must match ``len(past_key_values.key_cache)``."""
|
| 1062 |
+
return self.dit.num_layers
|
| 1063 |
+
|
| 1064 |
+
def _vlm_kv_list_from_past(self, past_key_values: Cache) -> list[tuple[torch.Tensor, torch.Tensor]]:
|
| 1065 |
+
n = len(past_key_values)
|
| 1066 |
+
if n != self.num_dit_layers:
|
| 1067 |
+
raise ValueError(
|
| 1068 |
+
f"MoT expert has {self.num_dit_layers} blocks but `past_key_values` has {n} "
|
| 1069 |
+
"layers. Set `dit_action_head_config['expert_num_layers']` to match "
|
| 1070 |
+
"`text_config.num_hidden_layers`."
|
| 1071 |
+
)
|
| 1072 |
+
return _kv_pairs_from_past_key_values(past_key_values)
|
| 1073 |
+
|
| 1074 |
+
def reset_parameters(self):
|
| 1075 |
+
"""Re-apply proper initialization after from_pretrained."""
|
| 1076 |
+
if self.add_pos_embed:
|
| 1077 |
+
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
| 1078 |
+
for module in self.modules():
|
| 1079 |
+
if isinstance(module, AdaRMSNorm):
|
| 1080 |
+
nn.init.zeros_(module.modulation.weight)
|
| 1081 |
+
nn.init.zeros_(module.modulation.bias)
|
| 1082 |
+
elif isinstance(module, nn.Linear):
|
| 1083 |
+
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
|
| 1084 |
+
if module.bias is not None:
|
| 1085 |
+
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
|
| 1086 |
+
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
| 1087 |
+
nn.init.uniform_(module.bias, -bound, bound)
|
| 1088 |
+
|
| 1089 |
+
def _compute_adarms_cond(self, t_discretized: torch.Tensor) -> torch.Tensor:
|
| 1090 |
+
t_emb = self.time_sinusoidal(t_discretized.float())
|
| 1091 |
+
t_emb = t_emb.to(dtype=self.time_mlp_1.weight.dtype)
|
| 1092 |
+
t_emb = F.silu(self.time_mlp_1(t_emb))
|
| 1093 |
+
t_emb = F.silu(self.time_mlp_2(t_emb))
|
| 1094 |
+
return t_emb
|
| 1095 |
+
|
| 1096 |
+
def sample_time(self, batch_size: int, device, dtype) -> torch.Tensor:
|
| 1097 |
+
beta_dist = Beta(self._beta_alpha, self._beta_beta)
|
| 1098 |
+
sample = beta_dist.sample([batch_size]).to(device, dtype=dtype).clamp(max=self.noise_s)
|
| 1099 |
+
return (self.noise_s - sample) / self.noise_s
|
| 1100 |
+
|
| 1101 |
+
def _prepare_action_embeds(
|
| 1102 |
+
self,
|
| 1103 |
+
noisy_actions: torch.Tensor,
|
| 1104 |
+
action_dof_mask: Optional[torch.Tensor],
|
| 1105 |
+
) -> torch.Tensor:
|
| 1106 |
+
if action_dof_mask is not None:
|
| 1107 |
+
x = torch.cat(
|
| 1108 |
+
[noisy_actions, action_dof_mask.to(noisy_actions.dtype)], dim=-1,
|
| 1109 |
+
)
|
| 1110 |
+
else:
|
| 1111 |
+
x = torch.cat([noisy_actions, torch.ones_like(noisy_actions)], dim=-1)
|
| 1112 |
+
|
| 1113 |
+
tokens = self.action_in_proj(x)
|
| 1114 |
+
|
| 1115 |
+
if self.add_pos_embed:
|
| 1116 |
+
pos_ids = torch.arange(tokens.shape[1], dtype=torch.long, device=noisy_actions.device)
|
| 1117 |
+
tokens = tokens + self.position_embedding(pos_ids).unsqueeze(0)
|
| 1118 |
+
|
| 1119 |
+
return tokens
|
| 1120 |
+
|
| 1121 |
+
def forward(
|
| 1122 |
+
self,
|
| 1123 |
+
past_key_values: Cache,
|
| 1124 |
+
actions: torch.Tensor,
|
| 1125 |
+
action_dof_mask: Optional[torch.Tensor] = None,
|
| 1126 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 1127 |
+
) -> tuple:
|
| 1128 |
+
"""Training: returns (pred_velocity, target_velocity).
|
| 1129 |
+
|
| 1130 |
+
Args:
|
| 1131 |
+
past_key_values: VLM decoder KV cache; layer count must equal ``num_dit_layers``.
|
| 1132 |
+
"""
|
| 1133 |
+
vlm_kv_cache = self._vlm_kv_list_from_past(past_key_values)
|
| 1134 |
+
device = actions.device
|
| 1135 |
+
B = actions.shape[0]
|
| 1136 |
+
|
| 1137 |
+
noise = torch.randn(actions.shape, device=device, dtype=actions.dtype)
|
| 1138 |
+
t = self.sample_time(B, device=device, dtype=actions.dtype)
|
| 1139 |
+
t_expanded = t[:, None, None]
|
| 1140 |
+
|
| 1141 |
+
noisy_trajectory = (1 - t_expanded) * noise + t_expanded * actions
|
| 1142 |
+
velocity = actions - noise
|
| 1143 |
+
|
| 1144 |
+
t_discretized = (t * self.num_timestep_buckets).long()
|
| 1145 |
+
adarms_cond = self._compute_adarms_cond(t_discretized)
|
| 1146 |
+
|
| 1147 |
+
action_tokens = self._prepare_action_embeds(noisy_trajectory, action_dof_mask)
|
| 1148 |
+
|
| 1149 |
+
output = self.dit(
|
| 1150 |
+
action_tokens, vlm_kv_cache, adarms_cond, encoder_attention_mask,
|
| 1151 |
+
)
|
| 1152 |
+
|
| 1153 |
+
pred = self.action_out_proj(output)
|
| 1154 |
+
pred_v = pred[:, :actions.shape[1]]
|
| 1155 |
+
return pred_v, velocity
|
| 1156 |
+
|
| 1157 |
+
def compute_velocity(
|
| 1158 |
+
self,
|
| 1159 |
+
past_key_values: Cache,
|
| 1160 |
+
actions: torch.Tensor,
|
| 1161 |
+
noise: torch.Tensor,
|
| 1162 |
+
t: torch.Tensor,
|
| 1163 |
+
action_dof_mask: Optional[torch.Tensor] = None,
|
| 1164 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 1165 |
+
) -> torch.Tensor:
|
| 1166 |
+
"""Compute velocity prediction for pre-sampled noise and timestep.
|
| 1167 |
+
|
| 1168 |
+
Used by DiffusionNFT where noise and timestep must be shared between
|
| 1169 |
+
the current policy (v_θ) and the reference policy (v_old).
|
| 1170 |
+
|
| 1171 |
+
Args:
|
| 1172 |
+
past_key_values: VLM decoder KV cache
|
| 1173 |
+
actions: (B, T, action_dim) ground truth actions (x_0)
|
| 1174 |
+
noise: (B, T, action_dim) pre-sampled noise (ε)
|
| 1175 |
+
t: (B,) continuous timesteps in [0, 1)
|
| 1176 |
+
action_dof_mask, encoder_attention_mask,
|
| 1177 |
+
|
| 1178 |
+
Returns:
|
| 1179 |
+
pred_v: (B, T, action_dim) predicted velocity
|
| 1180 |
+
"""
|
| 1181 |
+
vlm_kv_cache = self._vlm_kv_list_from_past(past_key_values)
|
| 1182 |
+
device = actions.device
|
| 1183 |
+
t_expanded = t[:, None, None]
|
| 1184 |
+
|
| 1185 |
+
noisy_trajectory = (1 - t_expanded) * noise + t_expanded * actions
|
| 1186 |
+
t_discretized = (t * self.num_timestep_buckets).long()
|
| 1187 |
+
adarms_cond = self._compute_adarms_cond(t_discretized)
|
| 1188 |
+
action_tokens = self._prepare_action_embeds(noisy_trajectory, action_dof_mask)
|
| 1189 |
+
output = self.dit(
|
| 1190 |
+
action_tokens, vlm_kv_cache, adarms_cond, encoder_attention_mask,
|
| 1191 |
+
)
|
| 1192 |
+
pred = self.action_out_proj(output)
|
| 1193 |
+
return pred[:, :actions.shape[1]]
|
| 1194 |
+
|
| 1195 |
+
|
| 1196 |
+
@torch.no_grad()
|
| 1197 |
+
def predict_action(
|
| 1198 |
+
self,
|
| 1199 |
+
past_key_values: Cache,
|
| 1200 |
+
action_dof_mask: Optional[torch.Tensor] = None,
|
| 1201 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 1202 |
+
) -> torch.Tensor:
|
| 1203 |
+
"""Inference: Euler integration, returns (B, chunk_size, action_dim)."""
|
| 1204 |
+
k0 = past_key_values[0][0]
|
| 1205 |
+
B = k0.shape[0]
|
| 1206 |
+
device = k0.device
|
| 1207 |
+
dtype = k0.dtype
|
| 1208 |
+
vlm_kv_cache = self._vlm_kv_list_from_past(past_key_values)
|
| 1209 |
+
|
| 1210 |
+
actions = torch.randn(
|
| 1211 |
+
(B, self.action_chunk_size, self.action_dim),
|
| 1212 |
+
device=device, dtype=dtype,
|
| 1213 |
+
)
|
| 1214 |
+
dt = 1.0 / self.num_inference_timesteps
|
| 1215 |
+
|
| 1216 |
+
for step in range(self.num_inference_timesteps):
|
| 1217 |
+
t_cont = step / float(self.num_inference_timesteps)
|
| 1218 |
+
t_disc_val = int(t_cont * self.num_timestep_buckets)
|
| 1219 |
+
t_tensor = torch.full((B,), t_disc_val, device=device, dtype=torch.long)
|
| 1220 |
+
|
| 1221 |
+
adarms_cond = self._compute_adarms_cond(t_tensor)
|
| 1222 |
+
action_tokens = self._prepare_action_embeds(actions, action_dof_mask)
|
| 1223 |
+
|
| 1224 |
+
output = self.dit(
|
| 1225 |
+
action_tokens, vlm_kv_cache, adarms_cond, encoder_attention_mask,
|
| 1226 |
+
)
|
| 1227 |
+
pred_velocity = self.action_out_proj(output)[:, :self.action_chunk_size]
|
| 1228 |
+
actions = actions + dt * pred_velocity
|
| 1229 |
+
|
| 1230 |
+
return actions
|
generation_config.json
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"do_sample": true,
|
| 3 |
+
"eos_token_id": [
|
| 4 |
+
151645,
|
| 5 |
+
151643
|
| 6 |
+
],
|
| 7 |
+
"pad_token_id": 151643,
|
| 8 |
+
"temperature": 0.7,
|
| 9 |
+
"top_k": 20,
|
| 10 |
+
"top_p": 0.8,
|
| 11 |
+
"transformers_version": "4.57.3"
|
| 12 |
+
}
|
merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
model-00001-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:543d8bd40ff00d487b24b85c724587b96839434f61bc51d438319042e0cf0fcb
|
| 3 |
+
size 4999639200
|
model-00002-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:303b0fe49a58b0397c013f89c7047483c0910b77b8dc665ccb6f0e410cdd848c
|
| 3 |
+
size 4995750056
|
model-00003-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:31a73b4447c20c0bb0cf72b8a999e01ff08a0df3354686b449bdcb560a010c66
|
| 3 |
+
size 981882944
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_prts_qwen3_vl.py
ADDED
|
@@ -0,0 +1,935 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 TeleAI Rhodes Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Main VLA model architecture based on Qwen3-VL."""
|
| 16 |
+
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
| 25 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 26 |
+
|
| 27 |
+
from transformers.modeling_outputs import ModelOutput
|
| 28 |
+
from transformers.cache_utils import Cache
|
| 29 |
+
from transformers.processing_utils import Unpack
|
| 30 |
+
from transformers.utils import TransformersKwargs, is_torchdynamo_compiling
|
| 31 |
+
|
| 32 |
+
from .modeling_qwen3_vl import (
|
| 33 |
+
Qwen3VLForConditionalGeneration,
|
| 34 |
+
Qwen3VLTextModel,
|
| 35 |
+
Qwen3VLVisionModel,
|
| 36 |
+
)
|
| 37 |
+
from .configuration_prts_qwen3_vl import PRTS_FlowMatchingConfig_Qwen3VL
|
| 38 |
+
from .dit_action_head import FlowMatchingDiTHead, MoTFlowMatchingHead
|
| 39 |
+
|
| 40 |
+
ACTION_DATASET_NAMES = []
|
| 41 |
+
|
| 42 |
+
# ----------------------------- Print Customization -----------------------------
|
| 43 |
+
from colorama import init, Fore, Style
|
| 44 |
+
from datetime import datetime
|
| 45 |
+
|
| 46 |
+
# Initialize colorama
|
| 47 |
+
init(autoreset=True)
|
| 48 |
+
|
| 49 |
+
class CustomPrinter:
|
| 50 |
+
"""Custom colored printer."""
|
| 51 |
+
|
| 52 |
+
# Define message type configuration
|
| 53 |
+
TYPE_CONFIG = {
|
| 54 |
+
'normal': {
|
| 55 |
+
'color': Fore.WHITE,
|
| 56 |
+
'icon': '',
|
| 57 |
+
'prefix': '',
|
| 58 |
+
'style': Style.NORMAL
|
| 59 |
+
},
|
| 60 |
+
'important': {
|
| 61 |
+
'color': Fore.CYAN,
|
| 62 |
+
'icon': '💡',
|
| 63 |
+
'prefix': 'IMPORTANT',
|
| 64 |
+
'style': Style.BRIGHT
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
@classmethod
|
| 69 |
+
def print(cls, message, msg_type='normal', show_time=True, show_icon=True, end='\n'):
|
| 70 |
+
"""
|
| 71 |
+
Custom print function.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
message: The message content to print
|
| 75 |
+
msg_type: Message type ('normal', 'info', 'success', 'warning', 'error', 'fail', 'debug', 'important')
|
| 76 |
+
show_time: Whether to display a timestamp
|
| 77 |
+
show_icon: Whether to display the icon
|
| 78 |
+
end: Line terminator
|
| 79 |
+
"""
|
| 80 |
+
# Get configuration for the message type
|
| 81 |
+
config = cls.TYPE_CONFIG.get(msg_type, cls.TYPE_CONFIG['normal'])
|
| 82 |
+
|
| 83 |
+
# Build prefix parts
|
| 84 |
+
prefix_parts = []
|
| 85 |
+
|
| 86 |
+
# Add timestamp
|
| 87 |
+
if show_time:
|
| 88 |
+
timestamp = datetime.now().strftime('%H:%M:%S')
|
| 89 |
+
prefix_parts.append(f"[{timestamp}]")
|
| 90 |
+
|
| 91 |
+
# Add icon and prefix text
|
| 92 |
+
icon_text = f"{config['icon']} " if show_icon else ""
|
| 93 |
+
prefix_parts.append(f"{icon_text}{config['prefix']}")
|
| 94 |
+
|
| 95 |
+
if config['prefix'] == '':
|
| 96 |
+
full_message = message
|
| 97 |
+
else:
|
| 98 |
+
# Combine prefix parts
|
| 99 |
+
prefix = " ".join(prefix_parts)
|
| 100 |
+
|
| 101 |
+
# Construct full message
|
| 102 |
+
full_message = f"{prefix}: {message}"
|
| 103 |
+
|
| 104 |
+
# Apply color and style and print
|
| 105 |
+
formatted_message = f"{config['style']}{config['color']}{full_message}"
|
| 106 |
+
print(formatted_message, end=end)
|
| 107 |
+
|
| 108 |
+
@classmethod
|
| 109 |
+
def normal(cls, message, **kwargs):
|
| 110 |
+
"""Convenience: normal-level print."""
|
| 111 |
+
cls.print(message, 'normal', **kwargs)
|
| 112 |
+
|
| 113 |
+
@classmethod
|
| 114 |
+
def important(cls, message, **kwargs):
|
| 115 |
+
"""Convenience: important-level print."""
|
| 116 |
+
cls.print(message, 'important', **kwargs)
|
| 117 |
+
|
| 118 |
+
def important(message, **kwargs):
|
| 119 |
+
CustomPrinter.important(message, **kwargs)
|
| 120 |
+
|
| 121 |
+
# -------------------------------------------------------------
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def create_sinusoidal_pos_embedding(
|
| 125 |
+
time: torch.Tensor,
|
| 126 |
+
dimension: int,
|
| 127 |
+
min_period: float = 4e-3,
|
| 128 |
+
max_period: float = 4.0,
|
| 129 |
+
device="cpu",
|
| 130 |
+
) -> torch.Tensor:
|
| 131 |
+
"""
|
| 132 |
+
Computes sine-cosine positional embedding vectors for scalar positions (diffusion timesteps).
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
time: Tensor of shape (batch_size,) containing timestep values
|
| 136 |
+
dimension: Embedding dimension (must be even)
|
| 137 |
+
min_period: Minimum period for sinusoidal encoding
|
| 138 |
+
max_period: Maximum period for sinusoidal encoding
|
| 139 |
+
device: Device to create tensors on
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
Positional embeddings of shape (batch_size, dimension)
|
| 143 |
+
"""
|
| 144 |
+
if dimension % 2 != 0:
|
| 145 |
+
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
| 146 |
+
|
| 147 |
+
if time.ndim != 1:
|
| 148 |
+
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
| 149 |
+
|
| 150 |
+
fraction = torch.linspace(0.0, 1.0, dimension // 2, device=device)
|
| 151 |
+
period = min_period * (max_period / min_period) ** fraction
|
| 152 |
+
|
| 153 |
+
scaling_factor = 1.0 / period * 2 * math.pi
|
| 154 |
+
sin_input = scaling_factor[None, :] * time[:, None]
|
| 155 |
+
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
| 156 |
+
return pos_emb
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class ContrastiveEncoder(nn.Module):
|
| 160 |
+
"""
|
| 161 |
+
MLP projector for Contrastive Reinforcement Learning (CRL) embeddings.
|
| 162 |
+
|
| 163 |
+
Projects hidden states to a shared latent space for contrastive learning,
|
| 164 |
+
with L2 normalization for stable similarity computation.
|
| 165 |
+
|
| 166 |
+
Architecture: N-layer MLP with LayerNorm and Swish activation,
|
| 167 |
+
followed by a cold-initialized output projection.
|
| 168 |
+
[Linear -> LayerNorm -> Swish] x N -> Linear (cold init)
|
| 169 |
+
|
| 170 |
+
Matches stable_contrastive_rl's Q network structure (default: 4 hidden layers).
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
input_dim: Dimension of input hidden states
|
| 174 |
+
output_dim: Dimension of output embeddings (default: 256)
|
| 175 |
+
hidden_dim: Dimension of hidden layers (default: 1024)
|
| 176 |
+
num_layers: Number of hidden layers (default: 4)
|
| 177 |
+
repr_norm: Whether to L2-normalize outputs (default: False)
|
| 178 |
+
init_w: Small value for last layer weight initialization for cold init (default: 1e-12)
|
| 179 |
+
"""
|
| 180 |
+
def __init__(
|
| 181 |
+
self,
|
| 182 |
+
input_dim: int,
|
| 183 |
+
output_dim: int = 256,
|
| 184 |
+
hidden_dim: int = 1024,
|
| 185 |
+
num_layers: int = 4,
|
| 186 |
+
repr_norm: bool = False,
|
| 187 |
+
init_w: float = 1e-12,
|
| 188 |
+
):
|
| 189 |
+
super().__init__()
|
| 190 |
+
self.num_layers = num_layers
|
| 191 |
+
self.repr_norm = repr_norm
|
| 192 |
+
|
| 193 |
+
# Build hidden layers with LayerNorm
|
| 194 |
+
self.hidden_layers = nn.ModuleList()
|
| 195 |
+
self.layer_norms = nn.ModuleList()
|
| 196 |
+
|
| 197 |
+
for i in range(num_layers):
|
| 198 |
+
in_dim = input_dim if i == 0 else hidden_dim
|
| 199 |
+
self.hidden_layers.append(nn.Linear(in_dim, hidden_dim))
|
| 200 |
+
self.layer_norms.append(nn.LayerNorm(hidden_dim))
|
| 201 |
+
|
| 202 |
+
# Output projection layer with cold initialization
|
| 203 |
+
self.output_proj = nn.Linear(hidden_dim, output_dim)
|
| 204 |
+
self.output_proj.weight.data.uniform_(-init_w, init_w)
|
| 205 |
+
self.output_proj.bias.data.fill_(0)
|
| 206 |
+
|
| 207 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 208 |
+
"""
|
| 209 |
+
Project input to L2-normalized embedding space.
|
| 210 |
+
|
| 211 |
+
Args:
|
| 212 |
+
x: Input tensor of shape (batch_size, input_dim)
|
| 213 |
+
|
| 214 |
+
Returns:
|
| 215 |
+
L2-normalized embeddings of shape (batch_size, output_dim)
|
| 216 |
+
"""
|
| 217 |
+
# Pass through hidden layers
|
| 218 |
+
for fc, norm in zip(self.hidden_layers, self.layer_norms):
|
| 219 |
+
x = fc(x)
|
| 220 |
+
x = norm(x)
|
| 221 |
+
x = F.silu(x)
|
| 222 |
+
|
| 223 |
+
# Output projection
|
| 224 |
+
x = self.output_proj(x)
|
| 225 |
+
|
| 226 |
+
# Optional L2 normalization
|
| 227 |
+
if self.repr_norm:
|
| 228 |
+
x = F.normalize(x, dim=-1)
|
| 229 |
+
|
| 230 |
+
return x
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
@dataclass
|
| 235 |
+
class PRTS_Qwen3VL_ModelOutputWithPast(ModelOutput):
|
| 236 |
+
"""
|
| 237 |
+
Output class for PRTS model based on Qwen3-VL.
|
| 238 |
+
|
| 239 |
+
Args:
|
| 240 |
+
loss: Combined total loss
|
| 241 |
+
flow_loss: Flow matching loss for action prediction
|
| 242 |
+
cross_entropy_loss: Standard language modeling loss
|
| 243 |
+
crl_loss: Contrastive Reinforcement Learning loss for goal-action alignment
|
| 244 |
+
logits: Language model logits
|
| 245 |
+
past_key_values: Cached key-value states
|
| 246 |
+
hidden_states: Hidden states from all layers (if output_hidden_states=True)
|
| 247 |
+
attentions: Attention weights (if output_attentions=True)
|
| 248 |
+
rope_deltas: RoPE position delta information
|
| 249 |
+
channel_loss_dict: Per-dataset loss values for logging
|
| 250 |
+
channel_loss_count_dict: Per-dataset token counts for loss normalization
|
| 251 |
+
"""
|
| 252 |
+
loss: Optional[torch.FloatTensor] = None
|
| 253 |
+
flow_loss: Optional[torch.FloatTensor] = None
|
| 254 |
+
cross_entropy_loss: Optional[torch.FloatTensor] = None
|
| 255 |
+
crl_loss: Optional[torch.FloatTensor] = None
|
| 256 |
+
logits: Optional[torch.FloatTensor] = None
|
| 257 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None
|
| 258 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 259 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 260 |
+
rope_deltas: Optional[torch.LongTensor] = None
|
| 261 |
+
|
| 262 |
+
crl_num_samples: Optional[torch.LongTensor] = None
|
| 263 |
+
channel_loss_dict: Optional[dict] = None
|
| 264 |
+
channel_loss_count_dict: Optional[dict] = None
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class PRTS_Qwen3VL(Qwen3VLForConditionalGeneration):
|
| 268 |
+
"""
|
| 269 |
+
Vision-Language-Action model based on Qwen3-VL.
|
| 270 |
+
|
| 271 |
+
This model extends Qwen3-VL to support:
|
| 272 |
+
1. Proprioceptive state embedding and prediction
|
| 273 |
+
2. Sub-task description generation (language format)
|
| 274 |
+
3. Action chunk prediction via flow matching (continuous actions)
|
| 275 |
+
4. Optional discrete action tokenization (fast mode)
|
| 276 |
+
|
| 277 |
+
The model uses a flow matching approach for continuous action prediction, with a DiT
|
| 278 |
+
(Diffusion Transformer) action head that cross-attends to VLM hidden states.
|
| 279 |
+
"""
|
| 280 |
+
config: PRTS_FlowMatchingConfig_Qwen3VL
|
| 281 |
+
|
| 282 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 283 |
+
_no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]
|
| 284 |
+
|
| 285 |
+
def __init__(
|
| 286 |
+
self,
|
| 287 |
+
config: PRTS_FlowMatchingConfig_Qwen3VL,
|
| 288 |
+
):
|
| 289 |
+
"""
|
| 290 |
+
Initialize the PRTS Qwen3-VL model for action processing.
|
| 291 |
+
|
| 292 |
+
Args:
|
| 293 |
+
config: Model configuration
|
| 294 |
+
use_fast_tokenizer (bool): Whether to use FAST tokenizer for discrete actions
|
| 295 |
+
flow_matching_action_loss_weight (float): Weight for flow matching action loss
|
| 296 |
+
"""
|
| 297 |
+
super().__init__(config)
|
| 298 |
+
|
| 299 |
+
# The parent class initializes:
|
| 300 |
+
# - self.visual: Qwen3VLVisionModel
|
| 301 |
+
# - self.language_model: Qwen3VLTextModel
|
| 302 |
+
# - self.lm_head: Language model head
|
| 303 |
+
# - self.rope_deltas: Cached rope deltas
|
| 304 |
+
# We keep these and add PRTS-specific components
|
| 305 |
+
|
| 306 |
+
# PRTS-specific parameters
|
| 307 |
+
self.action_dim = config.max_action_dim
|
| 308 |
+
self.use_fast_tokenizer = config.use_fast_action_tokenizer
|
| 309 |
+
self.flow_matching_action_loss_weight = config.flow_matching_action_loss_weight
|
| 310 |
+
|
| 311 |
+
# Loss functions
|
| 312 |
+
self.loss_fct = CrossEntropyLoss(reduction="none")
|
| 313 |
+
self.loss_mse = MSELoss(reduction="none")
|
| 314 |
+
|
| 315 |
+
# DiT-based flow matching action head: standard (+ AlternateVLDiT) or pi0.5 KV expert
|
| 316 |
+
self.use_mot_action_expert = config.dit_action_head_config.get(
|
| 317 |
+
"use_mot_action_expert", False
|
| 318 |
+
)
|
| 319 |
+
if config.flow_matching_action_loss_weight > 0.:
|
| 320 |
+
if self.use_mot_action_expert:
|
| 321 |
+
self.dit_action_head = MoTFlowMatchingHead(
|
| 322 |
+
action_dim=self.action_dim,
|
| 323 |
+
action_chunk_size=config.action_chunk_size,
|
| 324 |
+
vlm_config=config.text_config,
|
| 325 |
+
num_inference_timesteps=config.num_denoise_steps,
|
| 326 |
+
config=config.dit_action_head_config,
|
| 327 |
+
)
|
| 328 |
+
else:
|
| 329 |
+
self.dit_action_head = FlowMatchingDiTHead(
|
| 330 |
+
action_dim=self.action_dim,
|
| 331 |
+
action_chunk_size=config.action_chunk_size,
|
| 332 |
+
cross_attention_dim=config.text_config.hidden_size,
|
| 333 |
+
num_inference_timesteps=config.num_denoise_steps,
|
| 334 |
+
config=config.dit_action_head_config,
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
# CRL (Contrastive Reinforcement Learning) components
|
| 338 |
+
if config.crl_loss_weight > 0.:
|
| 339 |
+
hidden_size = config.text_config.hidden_size
|
| 340 |
+
# Current encoders (trainable)
|
| 341 |
+
self.crl_action_encoder = ContrastiveEncoder(
|
| 342 |
+
input_dim=hidden_size,
|
| 343 |
+
output_dim=config.crl_embed_dim,
|
| 344 |
+
init_w=config.crl_encoder_init_w,
|
| 345 |
+
repr_norm=config.crl_repr_norm,
|
| 346 |
+
)
|
| 347 |
+
self.crl_goal_encoder = ContrastiveEncoder(
|
| 348 |
+
input_dim=hidden_size,
|
| 349 |
+
output_dim=config.crl_embed_dim,
|
| 350 |
+
init_w=config.crl_encoder_init_w,
|
| 351 |
+
repr_norm=config.crl_repr_norm,
|
| 352 |
+
)
|
| 353 |
+
# Learnable temperature (log-space for numerical stability, CLIP recipe).
|
| 354 |
+
self.crl_logit_scale = nn.Parameter(
|
| 355 |
+
torch.ones([], requires_grad=True) * math.log(1 / 0.2)
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
# Initialize weights
|
| 359 |
+
self.post_init()
|
| 360 |
+
|
| 361 |
+
# Print parameter counts
|
| 362 |
+
visual_params = sum(p.numel() for p in self.visual.parameters())
|
| 363 |
+
language_params = sum(p.numel() for p in self.language_model.parameters())
|
| 364 |
+
model_params = visual_params + language_params
|
| 365 |
+
important(f"Backbone VLM (visual + language_model) parameters: {model_params / 1e6:.2f}M")
|
| 366 |
+
important(f"Flow Matching Loss coefficient: {self.flow_matching_action_loss_weight}")
|
| 367 |
+
|
| 368 |
+
if config.flow_matching_action_loss_weight > 0.:
|
| 369 |
+
dit_params = sum(p.numel() for p in self.dit_action_head.parameters())
|
| 370 |
+
# Get the inner model type name for logging
|
| 371 |
+
if hasattr(self.dit_action_head, 'dit'):
|
| 372 |
+
dit_head_type = type(self.dit_action_head.dit).__name__
|
| 373 |
+
else:
|
| 374 |
+
dit_head_type = type(self.dit_action_head).__name__
|
| 375 |
+
important(f"DiT Action Head ({dit_head_type}) parameters: {dit_params / 1e6:.2f}M")
|
| 376 |
+
|
| 377 |
+
if config.crl_loss_weight > 0.:
|
| 378 |
+
crl_params = sum(p.numel() for p in self.crl_action_encoder.parameters())
|
| 379 |
+
crl_params += sum(p.numel() for p in self.crl_goal_encoder.parameters())
|
| 380 |
+
important(f"CRL Encoders (action + goal) parameters: {crl_params / 1e6:.2f}M")
|
| 381 |
+
important(f"CRL Loss coefficient: {config.crl_loss_weight}")
|
| 382 |
+
important(f"CRL Encoder init_w: {config.crl_encoder_init_w}")
|
| 383 |
+
important(f"CRL Repr Norm: {config.crl_repr_norm}")
|
| 384 |
+
|
| 385 |
+
self.fast_action_token_start_idx = 200000
|
| 386 |
+
self.use_multi_positive = True
|
| 387 |
+
|
| 388 |
+
def get_input_embeddings(self):
|
| 389 |
+
return self.language_model.get_input_embeddings()
|
| 390 |
+
|
| 391 |
+
def set_input_embeddings(self, value):
|
| 392 |
+
self.language_model.set_input_embeddings(value)
|
| 393 |
+
|
| 394 |
+
def set_decoder(self, decoder):
|
| 395 |
+
self.language_model = decoder
|
| 396 |
+
|
| 397 |
+
def get_decoder(self):
|
| 398 |
+
return self.language_model
|
| 399 |
+
|
| 400 |
+
def get_output_embeddings(self):
|
| 401 |
+
return self.lm_head
|
| 402 |
+
|
| 403 |
+
def set_output_embeddings(self, new_embeddings):
|
| 404 |
+
self.lm_head = new_embeddings
|
| 405 |
+
|
| 406 |
+
def to_float32_flow_matching_head(self):
|
| 407 |
+
"""Convert flow matching heads to float32 for numerical stability."""
|
| 408 |
+
if hasattr(self, 'dit_action_head'):
|
| 409 |
+
self.dit_action_head = self.dit_action_head.to(dtype=torch.float32)
|
| 410 |
+
|
| 411 |
+
def set_fast_action_info(self, action_mapper, fast_action_token_start_idx):
|
| 412 |
+
"""Set information for fast (discrete) action tokenization."""
|
| 413 |
+
self.action_mapper = action_mapper
|
| 414 |
+
self.fast_action_token_start_idx = fast_action_token_start_idx
|
| 415 |
+
|
| 416 |
+
def get_placeholder_mask_with_special_token(
|
| 417 |
+
self,
|
| 418 |
+
input_ids: torch.LongTensor,
|
| 419 |
+
inputs_embeds: torch.FloatTensor,
|
| 420 |
+
special_features: torch.FloatTensor,
|
| 421 |
+
special_pad_token_id: int,
|
| 422 |
+
):
|
| 423 |
+
"""
|
| 424 |
+
Get placeholder mask for a specific special token (e.g., state tokens).
|
| 425 |
+
|
| 426 |
+
Similar to get_placeholder_mask but for custom special tokens beyond image/video.
|
| 427 |
+
"""
|
| 428 |
+
if input_ids is None:
|
| 429 |
+
special_mask = inputs_embeds == self.get_input_embeddings()(
|
| 430 |
+
torch.tensor(special_pad_token_id, dtype=torch.long, device=inputs_embeds.device)
|
| 431 |
+
)
|
| 432 |
+
special_mask = special_mask.all(-1)
|
| 433 |
+
else:
|
| 434 |
+
special_mask = input_ids == special_pad_token_id
|
| 435 |
+
|
| 436 |
+
n_special_tokens = special_mask.sum()
|
| 437 |
+
special_mask = special_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
| 438 |
+
if special_features is not None and inputs_embeds[special_mask].numel() != special_features.numel():
|
| 439 |
+
raise ValueError(
|
| 440 |
+
f"Features and tokens do not match: tokens: {n_special_tokens}, features {special_features.shape[0]}"
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
return special_mask
|
| 444 |
+
|
| 445 |
+
def forward(
|
| 446 |
+
self,
|
| 447 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 448 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 449 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 450 |
+
past_key_values: Optional[Cache] = None,
|
| 451 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 452 |
+
labels: Optional[torch.LongTensor] = None,
|
| 453 |
+
# use_cache: Optional[bool] = None,
|
| 454 |
+
# output_attentions: Optional[bool] = None,
|
| 455 |
+
# output_hidden_states: Optional[bool] = None,
|
| 456 |
+
# return_dict: Optional[bool] = None,
|
| 457 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 458 |
+
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
| 459 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
| 460 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
| 461 |
+
# rope_deltas: Optional[torch.LongTensor] = None,
|
| 462 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 463 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 464 |
+
actions: Optional[torch.Tensor] = None,
|
| 465 |
+
action_is_pad: torch.Tensor | None = None,
|
| 466 |
+
action_dof_mask: Optional[torch.Tensor] = None,
|
| 467 |
+
dataset_names: Optional[List[str]] = None,
|
| 468 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 469 |
+
) -> Union[tuple, PRTS_Qwen3VL_ModelOutputWithPast]:
|
| 470 |
+
"""
|
| 471 |
+
Forward pass for PRTS_Qwen3VL model.
|
| 472 |
+
|
| 473 |
+
This extends Qwen3VLForConditionalGeneration.forward with:
|
| 474 |
+
- State embedding injection
|
| 475 |
+
- Action chunk flow matching
|
| 476 |
+
- DeepStack visual feature handling
|
| 477 |
+
"""
|
| 478 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 479 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
# 1. Prepare input embeddings
|
| 483 |
+
if inputs_embeds is None:
|
| 484 |
+
inputs_embeds = self.get_input_embeddings()(input_ids)
|
| 485 |
+
|
| 486 |
+
image_mask = None
|
| 487 |
+
video_mask = None
|
| 488 |
+
|
| 489 |
+
# 2. Process images with deepstack features
|
| 490 |
+
deepstack_image_embeds = None
|
| 491 |
+
if pixel_values is not None:
|
| 492 |
+
image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw, image_max_seqlen=kwargs['image_max_seqlen'])
|
| 493 |
+
image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
|
| 494 |
+
image_mask, _ = self.get_placeholder_mask(
|
| 495 |
+
input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
|
| 496 |
+
)
|
| 497 |
+
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
| 498 |
+
|
| 499 |
+
# 3. Process videos with deepstack features
|
| 500 |
+
deepstack_video_embeds = None
|
| 501 |
+
if pixel_values_videos is not None:
|
| 502 |
+
video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
| 503 |
+
video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
|
| 504 |
+
_, video_mask = self.get_placeholder_mask(
|
| 505 |
+
input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
|
| 506 |
+
)
|
| 507 |
+
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
| 508 |
+
|
| 509 |
+
# 4. Aggregate deepstack visual features
|
| 510 |
+
visual_pos_masks = None
|
| 511 |
+
deepstack_visual_embeds = None
|
| 512 |
+
if image_mask is not None and video_mask is not None:
|
| 513 |
+
# aggregate visual_pos_masks and deepstack_visual_embeds
|
| 514 |
+
image_mask = image_mask[..., 0]
|
| 515 |
+
video_mask = video_mask[..., 0]
|
| 516 |
+
visual_pos_masks = image_mask | video_mask
|
| 517 |
+
deepstack_visual_embeds = []
|
| 518 |
+
image_mask_joint = image_mask[visual_pos_masks]
|
| 519 |
+
video_mask_joint = video_mask[visual_pos_masks]
|
| 520 |
+
for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
|
| 521 |
+
embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
|
| 522 |
+
embed_joint[image_mask_joint, :] = img_embed
|
| 523 |
+
embed_joint[video_mask_joint, :] = vid_embed
|
| 524 |
+
deepstack_visual_embeds.append(embed_joint)
|
| 525 |
+
elif image_mask is not None:
|
| 526 |
+
image_mask = image_mask[..., 0]
|
| 527 |
+
visual_pos_masks = image_mask
|
| 528 |
+
deepstack_visual_embeds = deepstack_image_embeds
|
| 529 |
+
elif video_mask is not None:
|
| 530 |
+
video_mask = video_mask[..., 0]
|
| 531 |
+
visual_pos_masks = video_mask
|
| 532 |
+
deepstack_visual_embeds = deepstack_video_embeds
|
| 533 |
+
|
| 534 |
+
if attention_mask is not None:
|
| 535 |
+
attention_mask = attention_mask.to(inputs_embeds.device)
|
| 536 |
+
|
| 537 |
+
# 7. Calculate position IDs using Qwen3VL's rope index
|
| 538 |
+
if position_ids is None:
|
| 539 |
+
attention_mask_tensor = (
|
| 540 |
+
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
|
| 541 |
+
)
|
| 542 |
+
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
|
| 543 |
+
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
|
| 544 |
+
if attention_mask_tensor.dtype.is_floating_point:
|
| 545 |
+
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
|
| 546 |
+
attention_mask_tensor = (1.0 - attention_mask_tensor).int()
|
| 547 |
+
|
| 548 |
+
prefill_compiled_stage = is_torchdynamo_compiling() and (
|
| 549 |
+
(input_ids is not None and input_ids.shape[1] != 1)
|
| 550 |
+
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
|
| 551 |
+
)
|
| 552 |
+
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
|
| 553 |
+
(cache_position is not None and cache_position[0] == 0)
|
| 554 |
+
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
| 555 |
+
)
|
| 556 |
+
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
|
| 557 |
+
position_ids, rope_deltas = self.get_rope_index(
|
| 558 |
+
input_ids,
|
| 559 |
+
image_grid_thw,
|
| 560 |
+
video_grid_thw,
|
| 561 |
+
attention_mask=attention_mask_tensor,
|
| 562 |
+
)
|
| 563 |
+
self.rope_deltas = rope_deltas
|
| 564 |
+
else:
|
| 565 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
| 566 |
+
delta = (
|
| 567 |
+
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
| 568 |
+
if cache_position is not None
|
| 569 |
+
else 0
|
| 570 |
+
)
|
| 571 |
+
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
| 572 |
+
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
| 573 |
+
if cache_position is not None: # otherwise `deltas` is an int `0`
|
| 574 |
+
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
| 575 |
+
position_ids = position_ids.add(delta)
|
| 576 |
+
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
| 577 |
+
|
| 578 |
+
_lm_extra_kwargs: dict = {}
|
| 579 |
+
|
| 580 |
+
_use_cache = (
|
| 581 |
+
self.use_mot_action_expert
|
| 582 |
+
and self.flow_matching_action_loss_weight > 0.
|
| 583 |
+
and actions is not None
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
+
vlm_outputs = self.language_model(
|
| 587 |
+
input_ids=None,
|
| 588 |
+
position_ids=position_ids,
|
| 589 |
+
attention_mask=attention_mask,
|
| 590 |
+
past_key_values=past_key_values,
|
| 591 |
+
inputs_embeds=inputs_embeds,
|
| 592 |
+
use_cache=_use_cache,
|
| 593 |
+
cache_position=cache_position,
|
| 594 |
+
visual_pos_masks=visual_pos_masks,
|
| 595 |
+
deepstack_visual_embeds=deepstack_visual_embeds,
|
| 596 |
+
output_hidden_states=False,
|
| 597 |
+
**_lm_extra_kwargs,
|
| 598 |
+
**kwargs,
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
vlm_hidden_states = vlm_outputs.last_hidden_state
|
| 602 |
+
|
| 603 |
+
# 11. Run DiT action head if actions are present
|
| 604 |
+
dit_pred_v = None
|
| 605 |
+
dit_velocity = None
|
| 606 |
+
if actions is not None and self.flow_matching_action_loss_weight > 0:
|
| 607 |
+
# vlm_hidden_states shape: bs, seq_length, hidden_size
|
| 608 |
+
actions_for_dit = actions.to(vlm_hidden_states.device, dtype=vlm_hidden_states.dtype)
|
| 609 |
+
dof_mask_for_dit = action_dof_mask.to(vlm_hidden_states.device, dtype=vlm_hidden_states.dtype) if action_dof_mask is not None else None
|
| 610 |
+
# Pass attention_mask so DiT cross-attention ignores padding tokens
|
| 611 |
+
dit_encoder_attention_mask = attention_mask.bool() if attention_mask is not None else None
|
| 612 |
+
|
| 613 |
+
if self.use_mot_action_expert and vlm_outputs.past_key_values is not None:
|
| 614 |
+
dit_pred_v, dit_velocity = self.dit_action_head(
|
| 615 |
+
vlm_outputs.past_key_values,
|
| 616 |
+
actions_for_dit,
|
| 617 |
+
dof_mask_for_dit,
|
| 618 |
+
encoder_attention_mask=dit_encoder_attention_mask,
|
| 619 |
+
)
|
| 620 |
+
else:
|
| 621 |
+
# Standard: pass single (last-layer) VLM hidden states
|
| 622 |
+
dit_image_mask = visual_pos_masks.bool() if visual_pos_masks is not None else None
|
| 623 |
+
dit_pred_v, dit_velocity = self.dit_action_head(
|
| 624 |
+
vlm_hidden_states, actions_for_dit, dof_mask_for_dit,
|
| 625 |
+
encoder_attention_mask=dit_encoder_attention_mask,
|
| 626 |
+
image_mask=dit_image_mask,
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
# 12. Compute logits
|
| 630 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 631 |
+
logits = self.lm_head(vlm_hidden_states[:, slice_indices, :])
|
| 632 |
+
|
| 633 |
+
# 13. Compute losses
|
| 634 |
+
loss = None
|
| 635 |
+
cross_entropy_loss, flow_loss = None, None
|
| 636 |
+
channel_loss_dict = None
|
| 637 |
+
channel_loss_count_dict = None
|
| 638 |
+
|
| 639 |
+
if labels is not None:
|
| 640 |
+
loss = 0
|
| 641 |
+
action_accuracy = 0
|
| 642 |
+
unique_datasets_name = list(set(dataset_names)) if dataset_names is not None else []
|
| 643 |
+
|
| 644 |
+
# Compute cross-entropy loss
|
| 645 |
+
shift_logits = logits[..., :-1, :].float().contiguous()
|
| 646 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 647 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 648 |
+
shift_labels = shift_labels.view(-1)
|
| 649 |
+
|
| 650 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 651 |
+
non_ignored_mask = shift_labels != -100
|
| 652 |
+
_cross_entropy_loss = self.loss_fct(shift_logits, shift_labels)
|
| 653 |
+
cross_entropy_loss = (
|
| 654 |
+
_cross_entropy_loss[non_ignored_mask].mean()
|
| 655 |
+
if non_ignored_mask.any()
|
| 656 |
+
else (_cross_entropy_loss.sum() * 0.0)
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
# Add cross-entropy loss to total
|
| 660 |
+
if not torch.isnan(cross_entropy_loss):
|
| 661 |
+
loss += cross_entropy_loss
|
| 662 |
+
else:
|
| 663 |
+
with torch.no_grad():
|
| 664 |
+
cross_entropy_loss.detach()
|
| 665 |
+
|
| 666 |
+
# Compute action token prediction accuracy (for logging)
|
| 667 |
+
shift_logits_for_acc = logits[..., :-1, :].contiguous()
|
| 668 |
+
action_preds = shift_logits_for_acc.argmax(dim=-1)
|
| 669 |
+
shift_labels_for_acc = labels[..., 1:].contiguous()
|
| 670 |
+
|
| 671 |
+
action_mask = (
|
| 672 |
+
shift_labels_for_acc >= self.fast_action_token_start_idx
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
if self.use_fast_tokenizer and action_mask.any():
|
| 676 |
+
correct_preds = (action_preds == shift_labels_for_acc) & action_mask
|
| 677 |
+
action_accuracy = (
|
| 678 |
+
correct_preds.sum().float() / action_mask.sum().float()
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
if channel_loss_dict is None:
|
| 682 |
+
channel_loss_dict = {}
|
| 683 |
+
channel_loss_count_dict = {}
|
| 684 |
+
|
| 685 |
+
channel_loss_dict["action_accuracy"] = action_accuracy.detach()
|
| 686 |
+
channel_loss_count_dict["action_accuracy"] = torch.tensor(1, device=action_accuracy.device)
|
| 687 |
+
|
| 688 |
+
# 14. Compute flow matching loss (DiT action head)
|
| 689 |
+
if dit_pred_v is not None and self.flow_matching_action_loss_weight > 0:
|
| 690 |
+
if channel_loss_dict is not None:
|
| 691 |
+
channel_loss_dict.update(
|
| 692 |
+
{
|
| 693 |
+
f"flow_matching/{dataset_name}": torch.tensor(0.0, device=logits.device)
|
| 694 |
+
for dataset_name in ACTION_DATASET_NAMES
|
| 695 |
+
}
|
| 696 |
+
)
|
| 697 |
+
channel_loss_count_dict.update(
|
| 698 |
+
{
|
| 699 |
+
f"flow_matching/{dataset_name}": torch.tensor(0, device=logits.device)
|
| 700 |
+
for dataset_name in ACTION_DATASET_NAMES
|
| 701 |
+
}
|
| 702 |
+
)
|
| 703 |
+
else:
|
| 704 |
+
channel_loss_dict = {
|
| 705 |
+
f"flow_matching/{dataset_name}": torch.tensor(0.0, device=logits.device)
|
| 706 |
+
for dataset_name in ACTION_DATASET_NAMES
|
| 707 |
+
}
|
| 708 |
+
channel_loss_count_dict = {
|
| 709 |
+
f"flow_matching/{dataset_name}": torch.tensor(0, device=logits.device)
|
| 710 |
+
for dataset_name in ACTION_DATASET_NAMES
|
| 711 |
+
}
|
| 712 |
+
|
| 713 |
+
# Compute flow matching loss: MSE between predicted and target velocity
|
| 714 |
+
_fm_loss = self.loss_mse(dit_pred_v, dit_velocity)
|
| 715 |
+
|
| 716 |
+
# Apply DOF mask (zero out invalid action dimensions)
|
| 717 |
+
if action_dof_mask is not None:
|
| 718 |
+
valid_action_dim = int(action_dof_mask[0, 0, :].sum(dim=-1).item()) # NOTE: only support 单种具身实体数据微调
|
| 719 |
+
_fm_loss = _fm_loss[:, :, :valid_action_dim]
|
| 720 |
+
|
| 721 |
+
# Apply action_is_pad mask: exclude padding timesteps from loss
|
| 722 |
+
# action_is_pad: (B, T), True = pad timestep → should not contribute to loss
|
| 723 |
+
if action_is_pad is not None:
|
| 724 |
+
valid_timestep_mask = ~action_is_pad[:, :_fm_loss.shape[1]] # align length
|
| 725 |
+
_fm_loss = _fm_loss * valid_timestep_mask.unsqueeze(-1)
|
| 726 |
+
flow_loss = _fm_loss.sum() / (valid_timestep_mask.sum() * _fm_loss.shape[-1])
|
| 727 |
+
else:
|
| 728 |
+
flow_loss = _fm_loss.mean()
|
| 729 |
+
|
| 730 |
+
if not torch.isnan(flow_loss):
|
| 731 |
+
loss = loss + self.flow_matching_action_loss_weight * flow_loss if loss is not None else self.flow_matching_action_loss_weight * flow_loss
|
| 732 |
+
else:
|
| 733 |
+
with torch.no_grad():
|
| 734 |
+
flow_loss.detach()
|
| 735 |
+
|
| 736 |
+
# Per-dataset flow matching loss logging
|
| 737 |
+
logging_fm_loss = _fm_loss.detach().mean(dim=(1, 2)) # Sum over chunk_size and action_dim
|
| 738 |
+
|
| 739 |
+
action_dataset_names = dataset_names if dataset_names is not None else []
|
| 740 |
+
unique_action_datasets = list(set(action_dataset_names))
|
| 741 |
+
|
| 742 |
+
for dataset_name_i in unique_action_datasets:
|
| 743 |
+
action_dataset_mask = torch.tensor(
|
| 744 |
+
[name == dataset_name_i for name in action_dataset_names],
|
| 745 |
+
device=logits.device,
|
| 746 |
+
)
|
| 747 |
+
if action_dataset_mask.any():
|
| 748 |
+
dataset_fm_loss = logging_fm_loss[action_dataset_mask].sum()
|
| 749 |
+
dataset_fm_count = action_dataset_mask.sum()
|
| 750 |
+
|
| 751 |
+
prefixed_key = f"flow_matching/{dataset_name_i}"
|
| 752 |
+
channel_loss_dict[prefixed_key] += dataset_fm_loss
|
| 753 |
+
channel_loss_count_dict[prefixed_key] += dataset_fm_count
|
| 754 |
+
|
| 755 |
+
elif self.flow_matching_action_loss_weight > 0:
|
| 756 |
+
# Dummy loss to keep all DiT parameters in computation graph
|
| 757 |
+
dummy_params = [p.sum() * 0.0 for p in self.dit_action_head.parameters() if p.requires_grad]
|
| 758 |
+
dummy_loss = sum(dummy_params) if len(dummy_params) > 0 else torch.tensor(0.0, device=logits.device)
|
| 759 |
+
loss = (loss + dummy_loss) if loss is not None else dummy_loss
|
| 760 |
+
|
| 761 |
+
return PRTS_Qwen3VL_ModelOutputWithPast(
|
| 762 |
+
loss=loss,
|
| 763 |
+
cross_entropy_loss=(
|
| 764 |
+
cross_entropy_loss.detach() if cross_entropy_loss is not None else None
|
| 765 |
+
),
|
| 766 |
+
flow_loss=(
|
| 767 |
+
flow_loss.detach() if flow_loss is not None else None
|
| 768 |
+
),
|
| 769 |
+
crl_loss=None,
|
| 770 |
+
logits=logits,
|
| 771 |
+
past_key_values=vlm_outputs.past_key_values,
|
| 772 |
+
# hidden_states=vlm_outputs.hidden_states,
|
| 773 |
+
# attentions=vlm_outputs.attentions,
|
| 774 |
+
crl_num_samples=None,
|
| 775 |
+
rope_deltas=self.rope_deltas,
|
| 776 |
+
channel_loss_dict=channel_loss_dict,
|
| 777 |
+
channel_loss_count_dict=channel_loss_count_dict,
|
| 778 |
+
)
|
| 779 |
+
|
| 780 |
+
|
| 781 |
+
def embed_prefix(
|
| 782 |
+
self,
|
| 783 |
+
input_ids: torch.LongTensor,
|
| 784 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 785 |
+
pixel_values: torch.Tensor | None = None,
|
| 786 |
+
pixel_values_videos: torch.FloatTensor | None = None,
|
| 787 |
+
image_grid_thw: torch.LongTensor | None = None,
|
| 788 |
+
video_grid_thw: torch.LongTensor | None = None,
|
| 789 |
+
**kwargs,
|
| 790 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
|
| 791 |
+
"""
|
| 792 |
+
Embed prefix tokens including vision, DeepStack, and (optionally) state features.
|
| 793 |
+
|
| 794 |
+
Returns:
|
| 795 |
+
(inputs_embeds, visual_pos_masks, deepstack_visual_embeds)
|
| 796 |
+
"""
|
| 797 |
+
if inputs_embeds is None:
|
| 798 |
+
inputs_embeds = self.get_input_embeddings()(input_ids)
|
| 799 |
+
|
| 800 |
+
image_mask = None
|
| 801 |
+
video_mask = None
|
| 802 |
+
deepstack_image_embeds = None
|
| 803 |
+
deepstack_video_embeds = None
|
| 804 |
+
|
| 805 |
+
if pixel_values is not None:
|
| 806 |
+
image_embeds, deepstack_image_embeds = self.get_image_features(
|
| 807 |
+
pixel_values, image_grid_thw,
|
| 808 |
+
image_max_seqlen=kwargs.get('image_max_seqlen'),
|
| 809 |
+
)
|
| 810 |
+
image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
|
| 811 |
+
image_mask, _ = self.get_placeholder_mask(
|
| 812 |
+
input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
|
| 813 |
+
)
|
| 814 |
+
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
| 815 |
+
|
| 816 |
+
if pixel_values_videos is not None:
|
| 817 |
+
video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
| 818 |
+
video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
|
| 819 |
+
_, video_mask = self.get_placeholder_mask(
|
| 820 |
+
input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
|
| 821 |
+
)
|
| 822 |
+
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
| 823 |
+
|
| 824 |
+
visual_pos_masks = None
|
| 825 |
+
deepstack_visual_embeds = None
|
| 826 |
+
if image_mask is not None and video_mask is not None:
|
| 827 |
+
image_mask = image_mask[..., 0]
|
| 828 |
+
video_mask = video_mask[..., 0]
|
| 829 |
+
visual_pos_masks = image_mask | video_mask
|
| 830 |
+
deepstack_visual_embeds = []
|
| 831 |
+
image_mask_joint = image_mask[visual_pos_masks]
|
| 832 |
+
video_mask_joint = video_mask[visual_pos_masks]
|
| 833 |
+
for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
|
| 834 |
+
embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
|
| 835 |
+
embed_joint[image_mask_joint, :] = img_embed
|
| 836 |
+
embed_joint[video_mask_joint, :] = vid_embed
|
| 837 |
+
deepstack_visual_embeds.append(embed_joint)
|
| 838 |
+
elif image_mask is not None:
|
| 839 |
+
image_mask = image_mask[..., 0]
|
| 840 |
+
visual_pos_masks = image_mask
|
| 841 |
+
deepstack_visual_embeds = deepstack_image_embeds
|
| 842 |
+
elif video_mask is not None:
|
| 843 |
+
video_mask = video_mask[..., 0]
|
| 844 |
+
visual_pos_masks = video_mask
|
| 845 |
+
deepstack_visual_embeds = deepstack_video_embeds
|
| 846 |
+
|
| 847 |
+
return inputs_embeds, visual_pos_masks, deepstack_visual_embeds
|
| 848 |
+
|
| 849 |
+
@torch.no_grad()
|
| 850 |
+
def sample_actions(
|
| 851 |
+
self,
|
| 852 |
+
input_ids: torch.LongTensor | None = None,
|
| 853 |
+
position_ids: torch.LongTensor | None = None,
|
| 854 |
+
attention_mask: torch.Tensor | None = None,
|
| 855 |
+
past_key_values: list[torch.FloatTensor] | None = None,
|
| 856 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 857 |
+
cache_position: torch.LongTensor | None = None,
|
| 858 |
+
pixel_values: torch.Tensor | None = None,
|
| 859 |
+
pixel_values_videos: torch.FloatTensor | None = None,
|
| 860 |
+
image_grid_thw: torch.LongTensor | None = None,
|
| 861 |
+
video_grid_thw: torch.LongTensor | None = None,
|
| 862 |
+
action_dof_mask: Optional[torch.Tensor] = None,
|
| 863 |
+
**kwargs,
|
| 864 |
+
) -> Tuple[torch.Tensor, Any]:
|
| 865 |
+
"""
|
| 866 |
+
Sample actions using DiT-based flow matching denoising.
|
| 867 |
+
|
| 868 |
+
1. Computes position_ids via get_rope_index
|
| 869 |
+
2. Embeds the prefix (with DeepStack visual features)
|
| 870 |
+
3. Runs the language model to get hidden states
|
| 871 |
+
4. Uses DiT action head to denoise actions via cross-attention to VLM features
|
| 872 |
+
|
| 873 |
+
Returns:
|
| 874 |
+
(x_t, outputs) — denoised action trajectories and language-model outputs
|
| 875 |
+
"""
|
| 876 |
+
if position_ids is None:
|
| 877 |
+
position_ids, _ = self.get_rope_index(
|
| 878 |
+
input_ids,
|
| 879 |
+
image_grid_thw=image_grid_thw,
|
| 880 |
+
video_grid_thw=video_grid_thw,
|
| 881 |
+
attention_mask=attention_mask,
|
| 882 |
+
)
|
| 883 |
+
|
| 884 |
+
visual_pos_masks = None
|
| 885 |
+
deepstack_visual_embeds = None
|
| 886 |
+
if inputs_embeds is None:
|
| 887 |
+
inputs_embeds, visual_pos_masks, deepstack_visual_embeds = self.embed_prefix(
|
| 888 |
+
input_ids,
|
| 889 |
+
pixel_values=pixel_values,
|
| 890 |
+
pixel_values_videos=pixel_values_videos,
|
| 891 |
+
image_grid_thw=image_grid_thw,
|
| 892 |
+
video_grid_thw=video_grid_thw,
|
| 893 |
+
**kwargs,
|
| 894 |
+
)
|
| 895 |
+
|
| 896 |
+
_sample_use_cache = (
|
| 897 |
+
self.use_mot_action_expert and self.flow_matching_action_loss_weight > 0
|
| 898 |
+
)
|
| 899 |
+
outputs = self.language_model(
|
| 900 |
+
input_ids=None,
|
| 901 |
+
position_ids=position_ids,
|
| 902 |
+
attention_mask=attention_mask,
|
| 903 |
+
past_key_values=past_key_values,
|
| 904 |
+
inputs_embeds=inputs_embeds,
|
| 905 |
+
use_cache=_sample_use_cache,
|
| 906 |
+
cache_position=cache_position,
|
| 907 |
+
visual_pos_masks=visual_pos_masks,
|
| 908 |
+
deepstack_visual_embeds=deepstack_visual_embeds,
|
| 909 |
+
output_hidden_states=False,
|
| 910 |
+
)
|
| 911 |
+
|
| 912 |
+
vlm_hidden_states = outputs.last_hidden_state
|
| 913 |
+
dit_encoder_attention_mask = attention_mask.bool() if attention_mask is not None else None
|
| 914 |
+
|
| 915 |
+
if self.use_mot_action_expert and outputs.past_key_values is not None:
|
| 916 |
+
x_t = self.dit_action_head.predict_action(
|
| 917 |
+
outputs.past_key_values,
|
| 918 |
+
action_dof_mask,
|
| 919 |
+
encoder_attention_mask=dit_encoder_attention_mask,
|
| 920 |
+
)
|
| 921 |
+
else:
|
| 922 |
+
dit_image_mask = visual_pos_masks.bool() if visual_pos_masks is not None else None
|
| 923 |
+
x_t = self.dit_action_head.predict_action(
|
| 924 |
+
vlm_hidden_states, action_dof_mask,
|
| 925 |
+
encoder_attention_mask=dit_encoder_attention_mask,
|
| 926 |
+
image_mask=dit_image_mask,
|
| 927 |
+
)
|
| 928 |
+
|
| 929 |
+
return x_t, outputs
|
| 930 |
+
|
| 931 |
+
|
| 932 |
+
PRTS_Qwen3VL.register_for_auto_class()
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
__all__ = ["PRTS_Qwen3VL", "PRTS_Qwen3VL_ModelOutputWithPast"]
|
modeling_qwen3_vl.py
ADDED
|
@@ -0,0 +1,1645 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from src/transformers/models/qwen3_vl/modular_qwen3_vl.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_qwen3_vl.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
# coding=utf-8
|
| 8 |
+
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
|
| 22 |
+
from dataclasses import dataclass
|
| 23 |
+
from typing import Any, Callable, Optional, Union
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torch.distributed as dist
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
|
| 30 |
+
from transformers.activations import ACT2FN
|
| 31 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 32 |
+
from transformers.generation import GenerationMixin
|
| 33 |
+
from transformers.integrations import use_kernel_forward_from_hub
|
| 34 |
+
from transformers.masking_utils import create_causal_mask
|
| 35 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 36 |
+
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 37 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
| 38 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 39 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 40 |
+
from transformers.processing_utils import Unpack
|
| 41 |
+
from transformers.utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling
|
| 42 |
+
from transformers.utils.deprecation import deprecate_kwarg
|
| 43 |
+
from transformers.utils.generic import check_model_inputs
|
| 44 |
+
from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig
|
| 45 |
+
# 在文件头部导入
|
| 46 |
+
|
| 47 |
+
try:
|
| 48 |
+
from qwen_rope_kernel_2 import fused_qwen_rope as fused_qwen_rope_v2
|
| 49 |
+
HAS_QWEN_ROPE_V2 = True
|
| 50 |
+
except ImportError:
|
| 51 |
+
print("No qwen_rope_kernel_2 found")
|
| 52 |
+
HAS_QWEN_ROPE_V2 = False
|
| 53 |
+
|
| 54 |
+
try:
|
| 55 |
+
from fused_rmsnorm import RMSNormModelFunction as _FUSED_RMSFUNC
|
| 56 |
+
HAS_FUSED_RMSNORM = True
|
| 57 |
+
except ImportError:
|
| 58 |
+
print("No fused_rmsnorm found")
|
| 59 |
+
HAS_FUSED_RMSNORM = False
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class Qwen3VLVisionMLP(nn.Module):
|
| 63 |
+
def __init__(self, config):
|
| 64 |
+
super().__init__()
|
| 65 |
+
self.hidden_size = config.hidden_size
|
| 66 |
+
self.intermediate_size = config.intermediate_size
|
| 67 |
+
self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
|
| 68 |
+
self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
|
| 69 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 70 |
+
|
| 71 |
+
def forward(self, hidden_state):
|
| 72 |
+
return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class Qwen3VLVisionPatchEmbed(nn.Module):
|
| 76 |
+
def __init__(self, config) -> None:
|
| 77 |
+
super().__init__()
|
| 78 |
+
self.patch_size = config.patch_size
|
| 79 |
+
self.temporal_patch_size = config.temporal_patch_size
|
| 80 |
+
self.in_channels = config.in_channels
|
| 81 |
+
self.embed_dim = config.hidden_size
|
| 82 |
+
|
| 83 |
+
kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
|
| 84 |
+
self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)
|
| 85 |
+
|
| 86 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 87 |
+
target_dtype = self.proj.weight.dtype
|
| 88 |
+
hidden_states = hidden_states.view(
|
| 89 |
+
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
|
| 90 |
+
)
|
| 91 |
+
hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
|
| 92 |
+
return hidden_states
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class Qwen3VLVisionRotaryEmbedding(nn.Module):
|
| 96 |
+
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 97 |
+
|
| 98 |
+
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
| 99 |
+
super().__init__()
|
| 100 |
+
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
| 101 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 102 |
+
|
| 103 |
+
def forward(self, seqlen: int) -> torch.Tensor:
|
| 104 |
+
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
| 105 |
+
freqs = torch.outer(seq, self.inv_freq)
|
| 106 |
+
return freqs
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class Qwen3VLVisionPatchMerger(nn.Module):
|
| 110 |
+
def __init__(self, config: Qwen3VLVisionConfig, use_postshuffle_norm=False) -> None:
|
| 111 |
+
super().__init__()
|
| 112 |
+
self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
|
| 113 |
+
self.use_postshuffle_norm = use_postshuffle_norm
|
| 114 |
+
self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6)
|
| 115 |
+
self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
|
| 116 |
+
self.act_fn = nn.GELU()
|
| 117 |
+
self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
|
| 118 |
+
|
| 119 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 120 |
+
x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size)
|
| 121 |
+
x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
|
| 122 |
+
return x
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def rotate_half(x):
|
| 126 |
+
"""Rotates half the hidden dims of the input."""
|
| 127 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 128 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 129 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def apply_rotary_pos_emb_vision(
|
| 133 |
+
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
| 134 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 135 |
+
|
| 136 |
+
if HAS_QWEN_ROPE_V2 and q.is_cuda and q.dtype == torch.bfloat16 and q.shape[-1] in (64, 128):
|
| 137 |
+
# qwen_rope_kernel_2 handles (S, D) cos/sin for (S, H, D) input naturally.
|
| 138 |
+
# The kernel REQUIRES cos/sin to be 2D [S, D] if input is 3D [S, H, D].
|
| 139 |
+
# It DOES NOT support 3D [S, 1, D] for cos/sin.
|
| 140 |
+
|
| 141 |
+
if cos.dtype != torch.float32:
|
| 142 |
+
cos = cos.to(torch.float32)
|
| 143 |
+
if sin.dtype != torch.float32:
|
| 144 |
+
sin = sin.to(torch.float32)
|
| 145 |
+
|
| 146 |
+
# Proactively squeeze [S, 1, D] -> [S, D] to satisfy kernel requirements
|
| 147 |
+
# This is a view operation, zero memory copy overhead.
|
| 148 |
+
if cos.ndim == 3 and cos.shape[1] == 1:
|
| 149 |
+
cos = cos.squeeze(1)
|
| 150 |
+
sin = sin.squeeze(1)
|
| 151 |
+
|
| 152 |
+
return fused_qwen_rope_v2(q, cos, sin), fused_qwen_rope_v2(k, cos, sin)
|
| 153 |
+
|
| 154 |
+
orig_q_dtype = q.dtype
|
| 155 |
+
orig_k_dtype = k.dtype
|
| 156 |
+
q, k = q.float(), k.float()
|
| 157 |
+
if cos.ndim == 2:
|
| 158 |
+
cos = cos.unsqueeze(-2)
|
| 159 |
+
sin = sin.unsqueeze(-2)
|
| 160 |
+
if cos.dtype != torch.float32:
|
| 161 |
+
cos = cos.to(torch.float32)
|
| 162 |
+
if sin.dtype != torch.float32:
|
| 163 |
+
sin = sin.to(torch.float32)
|
| 164 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 165 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 166 |
+
return q_embed.to(orig_q_dtype), k_embed.to(orig_k_dtype)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 170 |
+
"""
|
| 171 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 172 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 173 |
+
"""
|
| 174 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 175 |
+
if n_rep == 1:
|
| 176 |
+
return hidden_states
|
| 177 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 178 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def eager_attention_forward(
|
| 182 |
+
module: nn.Module,
|
| 183 |
+
query: torch.Tensor,
|
| 184 |
+
key: torch.Tensor,
|
| 185 |
+
value: torch.Tensor,
|
| 186 |
+
attention_mask: Optional[torch.Tensor],
|
| 187 |
+
scaling: float,
|
| 188 |
+
dropout: float = 0.0,
|
| 189 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 190 |
+
):
|
| 191 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 192 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 193 |
+
|
| 194 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 195 |
+
if attention_mask is not None:
|
| 196 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 197 |
+
attn_weights = attn_weights + causal_mask
|
| 198 |
+
|
| 199 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 200 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 201 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 202 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 203 |
+
|
| 204 |
+
return attn_output, attn_weights
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class Qwen3VLVisionAttention(nn.Module):
|
| 208 |
+
def __init__(self, config: Qwen3VLVisionConfig) -> None:
|
| 209 |
+
super().__init__()
|
| 210 |
+
self.dim = config.hidden_size
|
| 211 |
+
self.num_heads = config.num_heads
|
| 212 |
+
self.head_dim = self.dim // self.num_heads
|
| 213 |
+
self.num_key_value_groups = 1 # needed for eager attention
|
| 214 |
+
self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
|
| 215 |
+
self.proj = nn.Linear(self.dim, self.dim)
|
| 216 |
+
self.scaling = self.head_dim**-0.5
|
| 217 |
+
self.config = config
|
| 218 |
+
self.attention_dropout = 0.0
|
| 219 |
+
self.is_causal = False
|
| 220 |
+
|
| 221 |
+
def forward(
|
| 222 |
+
self,
|
| 223 |
+
hidden_states: torch.Tensor,
|
| 224 |
+
cu_seqlens: torch.Tensor,
|
| 225 |
+
rotary_pos_emb: Optional[torch.Tensor] = None,
|
| 226 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 227 |
+
**kwargs,
|
| 228 |
+
) -> torch.Tensor:
|
| 229 |
+
seq_length = hidden_states.shape[0]
|
| 230 |
+
query_states, key_states, value_states = (
|
| 231 |
+
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
| 232 |
+
)
|
| 233 |
+
cos, sin = position_embeddings
|
| 234 |
+
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
|
| 235 |
+
|
| 236 |
+
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
| 237 |
+
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
| 238 |
+
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
| 239 |
+
|
| 240 |
+
attention_interface: Callable = eager_attention_forward
|
| 241 |
+
if self.config._attn_implementation != "eager":
|
| 242 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 243 |
+
|
| 244 |
+
if self.config._attn_implementation in ["flash_attention_2", "flash_attention_3"]:
|
| 245 |
+
# Flash Attention 2: Use cu_seqlens for variable length attention
|
| 246 |
+
if "image_max_seqlen" in kwargs and kwargs["image_max_seqlen"] is not None:
|
| 247 |
+
max_seqlen = kwargs["image_max_seqlen"]
|
| 248 |
+
else:
|
| 249 |
+
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
| 250 |
+
|
| 251 |
+
attn_output, _ = attention_interface(
|
| 252 |
+
self,
|
| 253 |
+
query_states,
|
| 254 |
+
key_states,
|
| 255 |
+
value_states,
|
| 256 |
+
attention_mask=None,
|
| 257 |
+
scaling=self.scaling,
|
| 258 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 259 |
+
cu_seq_lens_q=cu_seqlens,
|
| 260 |
+
cu_seq_lens_k=cu_seqlens,
|
| 261 |
+
max_length_q=max_seqlen,
|
| 262 |
+
max_length_k=max_seqlen,
|
| 263 |
+
is_causal=False,
|
| 264 |
+
**kwargs,
|
| 265 |
+
)
|
| 266 |
+
else:
|
| 267 |
+
# Other implementations: Process each chunk separately
|
| 268 |
+
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
|
| 269 |
+
splits = [
|
| 270 |
+
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
|
| 271 |
+
]
|
| 272 |
+
|
| 273 |
+
attn_outputs = [
|
| 274 |
+
attention_interface(
|
| 275 |
+
self,
|
| 276 |
+
q,
|
| 277 |
+
k,
|
| 278 |
+
v,
|
| 279 |
+
attention_mask=None,
|
| 280 |
+
scaling=self.scaling,
|
| 281 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 282 |
+
is_causal=False,
|
| 283 |
+
**kwargs,
|
| 284 |
+
)[0]
|
| 285 |
+
for q, k, v in zip(*splits)
|
| 286 |
+
]
|
| 287 |
+
attn_output = torch.cat(attn_outputs, dim=1)
|
| 288 |
+
|
| 289 |
+
attn_output = attn_output.reshape(seq_length, -1).contiguous()
|
| 290 |
+
attn_output = self.proj(attn_output)
|
| 291 |
+
return attn_output
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class Qwen3VLVisionBlock(GradientCheckpointingLayer):
|
| 295 |
+
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
|
| 296 |
+
super().__init__()
|
| 297 |
+
self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
|
| 298 |
+
self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
|
| 299 |
+
self.attn = Qwen3VLVisionAttention(config=config)
|
| 300 |
+
self.mlp = Qwen3VLVisionMLP(config=config)
|
| 301 |
+
|
| 302 |
+
def forward(
|
| 303 |
+
self,
|
| 304 |
+
hidden_states: torch.Tensor,
|
| 305 |
+
cu_seqlens: torch.Tensor,
|
| 306 |
+
rotary_pos_emb: Optional[torch.Tensor] = None,
|
| 307 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 308 |
+
**kwargs,
|
| 309 |
+
) -> torch.Tensor:
|
| 310 |
+
hidden_states = hidden_states + self.attn(
|
| 311 |
+
self.norm1(hidden_states),
|
| 312 |
+
cu_seqlens=cu_seqlens,
|
| 313 |
+
rotary_pos_emb=rotary_pos_emb,
|
| 314 |
+
position_embeddings=position_embeddings,
|
| 315 |
+
**kwargs,
|
| 316 |
+
)
|
| 317 |
+
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
| 318 |
+
return hidden_states
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class Qwen3VLTextRotaryEmbedding(nn.Module):
|
| 322 |
+
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 323 |
+
|
| 324 |
+
def __init__(self, config: Qwen3VLTextConfig, device=None):
|
| 325 |
+
super().__init__()
|
| 326 |
+
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
| 327 |
+
self.rope_type = config.rope_scaling.get("rope_type", "default")
|
| 328 |
+
else:
|
| 329 |
+
self.rope_type = "default"
|
| 330 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 331 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 332 |
+
|
| 333 |
+
self.config = config
|
| 334 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 335 |
+
|
| 336 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 337 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 338 |
+
self.original_inv_freq = self.inv_freq
|
| 339 |
+
|
| 340 |
+
self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20])
|
| 341 |
+
|
| 342 |
+
def apply_interleaved_mrope(self, freqs, mrope_section):
|
| 343 |
+
"""Apply interleaved MRoPE to 3D rotary embeddings.
|
| 344 |
+
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
|
| 345 |
+
interleaved [THTHWHTHW...TT], preserving frequency continuity.
|
| 346 |
+
args:
|
| 347 |
+
x: (3, bs, seq_len, head_dim // 2)
|
| 348 |
+
mrope_section: (3,)
|
| 349 |
+
returns:
|
| 350 |
+
x_t: (bs, seq_len, head_dim // 2)
|
| 351 |
+
"""
|
| 352 |
+
freqs_t = freqs[0] # just overwrite the first dimension T
|
| 353 |
+
for dim, offset in enumerate((1, 2), start=1): # H, W
|
| 354 |
+
length = mrope_section[dim] * 3
|
| 355 |
+
idx = slice(offset, length, 3)
|
| 356 |
+
freqs_t[..., idx] = freqs[dim, ..., idx]
|
| 357 |
+
return freqs_t
|
| 358 |
+
|
| 359 |
+
@torch.no_grad()
|
| 360 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 361 |
+
def forward(self, x, position_ids):
|
| 362 |
+
if position_ids.ndim == 2:
|
| 363 |
+
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
| 364 |
+
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
|
| 365 |
+
device = inv_freq_expanded.device
|
| 366 |
+
position_ids_expanded = position_ids[:, :, None, :].float().to(device)
|
| 367 |
+
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(2, 3)
|
| 368 |
+
freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
|
| 369 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 370 |
+
cos = emb.cos() * self.attention_scaling
|
| 371 |
+
sin = emb.sin() * self.attention_scaling
|
| 372 |
+
return cos.contiguous(), sin.contiguous()
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
| 376 |
+
class Qwen3VLTextRMSNorm(nn.Module):
|
| 377 |
+
def __init__(self, hidden_size, eps: float = 1e-6) -> None:
|
| 378 |
+
"""
|
| 379 |
+
Qwen3VLTextRMSNorm is equivalent to T5LayerNorm
|
| 380 |
+
"""
|
| 381 |
+
super().__init__()
|
| 382 |
+
self.weight = nn.Parameter(torch.ones(hidden_size, dtype=torch.bfloat16))
|
| 383 |
+
self.variance_epsilon = eps
|
| 384 |
+
|
| 385 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 386 |
+
if HAS_FUSED_RMSNORM and hidden_states.is_cuda:
|
| 387 |
+
x = hidden_states if hidden_states.dtype == torch.bfloat16 else hidden_states.to(torch.bfloat16)
|
| 388 |
+
x = x.contiguous()
|
| 389 |
+
return _FUSED_RMSFUNC.apply(x, self.weight, self.variance_epsilon, self.weight.shape[0])
|
| 390 |
+
input_dtype = hidden_states.dtype
|
| 391 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 392 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 393 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 394 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 395 |
+
|
| 396 |
+
def extra_repr(self):
|
| 397 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 401 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 402 |
+
|
| 403 |
+
Args:
|
| 404 |
+
q (`torch.Tensor`): The query tensor.
|
| 405 |
+
k (`torch.Tensor`): The key tensor.
|
| 406 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 407 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 408 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 409 |
+
Deprecated and unused.
|
| 410 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 411 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 412 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 413 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 414 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 415 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 416 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 417 |
+
Returns:
|
| 418 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 419 |
+
"""
|
| 420 |
+
if HAS_QWEN_ROPE_V2 and q.is_cuda and q.dtype == torch.bfloat16 and q.shape[-1] in (64, 128):
|
| 421 |
+
# qwen_rope_kernel_2 handles (S, D) cos/sin for (S, H, D) input naturally.
|
| 422 |
+
if cos.dtype != torch.float32:
|
| 423 |
+
cos = cos.to(torch.float32)
|
| 424 |
+
if sin.dtype != torch.float32:
|
| 425 |
+
sin = sin.to(torch.float32)
|
| 426 |
+
return fused_qwen_rope_v2(q, cos, sin), fused_qwen_rope_v2(k, cos, sin)
|
| 427 |
+
|
| 428 |
+
if cos.ndim != q.ndim:
|
| 429 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 430 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 431 |
+
if cos.dtype != q.dtype:
|
| 432 |
+
cos = cos.to(q.dtype)
|
| 433 |
+
sin = sin.to(q.dtype)
|
| 434 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 435 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 436 |
+
return q_embed, k_embed
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
class Qwen3VLTextAttention(nn.Module):
|
| 440 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 441 |
+
|
| 442 |
+
def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
|
| 443 |
+
super().__init__()
|
| 444 |
+
self.config = config
|
| 445 |
+
self.layer_idx = layer_idx
|
| 446 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 447 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 448 |
+
self.scaling = self.head_dim**-0.5
|
| 449 |
+
self.attention_dropout = config.attention_dropout
|
| 450 |
+
self.is_causal = True
|
| 451 |
+
|
| 452 |
+
self.q_proj = nn.Linear(
|
| 453 |
+
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
| 454 |
+
)
|
| 455 |
+
self.k_proj = nn.Linear(
|
| 456 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 457 |
+
)
|
| 458 |
+
self.v_proj = nn.Linear(
|
| 459 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
| 460 |
+
)
|
| 461 |
+
self.o_proj = nn.Linear(
|
| 462 |
+
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
| 463 |
+
)
|
| 464 |
+
self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
|
| 465 |
+
self.k_norm = Qwen3VLTextRMSNorm(
|
| 466 |
+
self.head_dim, eps=config.rms_norm_eps
|
| 467 |
+
) # thus post q_norm does not need reshape
|
| 468 |
+
|
| 469 |
+
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
|
| 470 |
+
def forward(
|
| 471 |
+
self,
|
| 472 |
+
hidden_states: torch.Tensor,
|
| 473 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 474 |
+
attention_mask: Optional[torch.Tensor],
|
| 475 |
+
past_key_values: Optional[Cache] = None,
|
| 476 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 477 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 478 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 479 |
+
input_shape = hidden_states.shape[:-1]
|
| 480 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 481 |
+
|
| 482 |
+
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
| 483 |
+
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
| 484 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 485 |
+
|
| 486 |
+
cos, sin = position_embeddings
|
| 487 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 488 |
+
|
| 489 |
+
if past_key_values is not None:
|
| 490 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 491 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 492 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 493 |
+
|
| 494 |
+
attention_interface: Callable = eager_attention_forward
|
| 495 |
+
if self.config._attn_implementation != "eager":
|
| 496 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 497 |
+
|
| 498 |
+
attn_output, attn_weights = attention_interface(
|
| 499 |
+
self,
|
| 500 |
+
query_states,
|
| 501 |
+
key_states,
|
| 502 |
+
value_states,
|
| 503 |
+
attention_mask,
|
| 504 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 505 |
+
scaling=self.scaling,
|
| 506 |
+
**kwargs,
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 510 |
+
attn_output = self.o_proj(attn_output)
|
| 511 |
+
return attn_output, attn_weights
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
class Qwen3VLTextMLP(nn.Module):
|
| 515 |
+
def __init__(self, config):
|
| 516 |
+
super().__init__()
|
| 517 |
+
self.config = config
|
| 518 |
+
self.hidden_size = config.hidden_size
|
| 519 |
+
self.intermediate_size = config.intermediate_size
|
| 520 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 521 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 522 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 523 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 524 |
+
|
| 525 |
+
def forward(self, x):
|
| 526 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 527 |
+
return down_proj
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
class Qwen3VLTextDecoderLayer(GradientCheckpointingLayer):
|
| 531 |
+
def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
|
| 532 |
+
super().__init__()
|
| 533 |
+
self.hidden_size = config.hidden_size
|
| 534 |
+
|
| 535 |
+
self.self_attn = Qwen3VLTextAttention(config=config, layer_idx=layer_idx)
|
| 536 |
+
|
| 537 |
+
self.mlp = Qwen3VLTextMLP(config)
|
| 538 |
+
self.input_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 539 |
+
self.post_attention_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 540 |
+
|
| 541 |
+
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
|
| 542 |
+
def forward(
|
| 543 |
+
self,
|
| 544 |
+
hidden_states: torch.Tensor,
|
| 545 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 546 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 547 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 548 |
+
past_key_values: Optional[Cache] = None,
|
| 549 |
+
use_cache: Optional[bool] = False,
|
| 550 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 551 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 552 |
+
) -> torch.Tensor:
|
| 553 |
+
residual = hidden_states
|
| 554 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 555 |
+
# Self Attention. DEBUG: When we use packing mode, here we would enter `qwen3vl_forward` in `train_utils.py`
|
| 556 |
+
hidden_states, _ = self.self_attn(
|
| 557 |
+
hidden_states=hidden_states,
|
| 558 |
+
attention_mask=attention_mask,
|
| 559 |
+
position_ids=position_ids,
|
| 560 |
+
past_key_values=past_key_values,
|
| 561 |
+
use_cache=use_cache,
|
| 562 |
+
cache_position=cache_position,
|
| 563 |
+
position_embeddings=position_embeddings,
|
| 564 |
+
**kwargs,
|
| 565 |
+
)
|
| 566 |
+
hidden_states = residual + hidden_states
|
| 567 |
+
|
| 568 |
+
# Fully Connected
|
| 569 |
+
residual = hidden_states
|
| 570 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 571 |
+
hidden_states = self.mlp(hidden_states)
|
| 572 |
+
hidden_states = residual + hidden_states
|
| 573 |
+
return hidden_states
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
@dataclass
|
| 577 |
+
@auto_docstring(
|
| 578 |
+
custom_intro="""
|
| 579 |
+
Base class for Llava outputs, with hidden states and attentions.
|
| 580 |
+
"""
|
| 581 |
+
)
|
| 582 |
+
class Qwen3VLModelOutputWithPast(ModelOutput):
|
| 583 |
+
r"""
|
| 584 |
+
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 585 |
+
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
| 586 |
+
|
| 587 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
| 588 |
+
`past_key_values` input) to speed up sequential decoding.
|
| 589 |
+
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
|
| 590 |
+
The rope index difference between sequence length and multimodal rope.
|
| 591 |
+
"""
|
| 592 |
+
|
| 593 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 594 |
+
past_key_values: Optional[Cache] = None
|
| 595 |
+
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
| 596 |
+
attentions: Optional[tuple[torch.FloatTensor]] = None
|
| 597 |
+
rope_deltas: Optional[torch.LongTensor] = None
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
@auto_docstring
|
| 601 |
+
class Qwen3VLPreTrainedModel(PreTrainedModel):
|
| 602 |
+
config: Qwen3VLConfig
|
| 603 |
+
base_model_prefix = "model"
|
| 604 |
+
supports_gradient_checkpointing = True
|
| 605 |
+
_no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]
|
| 606 |
+
_skip_keys_device_placement = "past_key_values"
|
| 607 |
+
_supports_flash_attn = True
|
| 608 |
+
_supports_sdpa = True
|
| 609 |
+
|
| 610 |
+
_can_compile_fullgraph = True
|
| 611 |
+
_supports_attention_backend = True
|
| 612 |
+
_can_record_outputs = {
|
| 613 |
+
"hidden_states": Qwen3VLTextDecoderLayer,
|
| 614 |
+
"attentions": Qwen3VLTextAttention,
|
| 615 |
+
}
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
class Qwen3VLVisionModel(Qwen3VLPreTrainedModel):
|
| 619 |
+
config: Qwen3VLVisionConfig
|
| 620 |
+
_no_split_modules = ["Qwen3VLVisionBlock"]
|
| 621 |
+
|
| 622 |
+
def __init__(self, config, *inputs, **kwargs) -> None:
|
| 623 |
+
super().__init__(config, *inputs, **kwargs)
|
| 624 |
+
self.spatial_merge_size = config.spatial_merge_size
|
| 625 |
+
self.patch_size = config.patch_size
|
| 626 |
+
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
|
| 627 |
+
|
| 628 |
+
self.patch_embed = Qwen3VLVisionPatchEmbed(
|
| 629 |
+
config=config,
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
|
| 633 |
+
self.num_grid_per_side = int(config.num_position_embeddings**0.5)
|
| 634 |
+
|
| 635 |
+
head_dim = config.hidden_size // config.num_heads
|
| 636 |
+
self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2)
|
| 637 |
+
|
| 638 |
+
self.blocks = nn.ModuleList([Qwen3VLVisionBlock(config) for _ in range(config.depth)])
|
| 639 |
+
self.merger = Qwen3VLVisionPatchMerger(
|
| 640 |
+
config=config,
|
| 641 |
+
use_postshuffle_norm=False,
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
self.deepstack_visual_indexes = config.deepstack_visual_indexes
|
| 645 |
+
self.deepstack_merger_list = nn.ModuleList(
|
| 646 |
+
[
|
| 647 |
+
Qwen3VLVisionPatchMerger(
|
| 648 |
+
config=config,
|
| 649 |
+
use_postshuffle_norm=True,
|
| 650 |
+
)
|
| 651 |
+
for _ in range(len(config.deepstack_visual_indexes))
|
| 652 |
+
]
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
self.gradient_checkpointing = False
|
| 656 |
+
|
| 657 |
+
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
| 658 |
+
merge_size = self.spatial_merge_size
|
| 659 |
+
|
| 660 |
+
max_hw = int(grid_thw[:, 1:].max().item())
|
| 661 |
+
freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2)
|
| 662 |
+
device = freq_table.device
|
| 663 |
+
|
| 664 |
+
total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
|
| 665 |
+
# pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
|
| 666 |
+
pos_ids_cpu = torch.empty((total_tokens, 2) , dtype=torch.long , device="cpu")
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
offset = 0
|
| 670 |
+
for num_frames, height, width in grid_thw.numpy():
|
| 671 |
+
merged_h, merged_w = height // merge_size, width // merge_size
|
| 672 |
+
|
| 673 |
+
block_rows = torch.arange(merged_h, device="cpu") # block row indices
|
| 674 |
+
block_cols = torch.arange(merged_w, device="cpu") # block col indices
|
| 675 |
+
intra_row = torch.arange(merge_size, device="cpu") # intra-block row offsets
|
| 676 |
+
intra_col = torch.arange(merge_size, device="cpu") # intra-block col offsets
|
| 677 |
+
|
| 678 |
+
# Compute full-resolution positions
|
| 679 |
+
row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
|
| 680 |
+
col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
|
| 681 |
+
|
| 682 |
+
row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
|
| 683 |
+
col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
|
| 684 |
+
|
| 685 |
+
coords = torch.stack((row_idx, col_idx), dim=-1)
|
| 686 |
+
|
| 687 |
+
if num_frames > 1:
|
| 688 |
+
coords = coords.repeat(num_frames, 1)
|
| 689 |
+
|
| 690 |
+
num_tokens = coords.shape[0]
|
| 691 |
+
pos_ids_cpu[offset : offset + num_tokens] = coords
|
| 692 |
+
offset += num_tokens
|
| 693 |
+
|
| 694 |
+
pos_ids = pos_ids_cpu.to(device , non_blocking=True)
|
| 695 |
+
embeddings = freq_table[pos_ids] # lookup rotary embeddings
|
| 696 |
+
embeddings = embeddings.flatten(1)
|
| 697 |
+
return embeddings
|
| 698 |
+
|
| 699 |
+
def fast_pos_embed_interpolate(self, grid_thw):
|
| 700 |
+
# grid_thw 已经是 CPU Tensor,直接解包
|
| 701 |
+
grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
|
| 702 |
+
|
| 703 |
+
idx_accum = [[] for _ in range(4)]
|
| 704 |
+
weight_accum = [[] for _ in range(4)]
|
| 705 |
+
|
| 706 |
+
# 预取配置,避免循环内 getattr
|
| 707 |
+
num_grid = self.num_grid_per_side
|
| 708 |
+
|
| 709 |
+
# 这一步依然需要在 CPU 循环计算,因为 H/W 是变长的,但这只是纯算数,很快
|
| 710 |
+
for h, w in zip(grid_hs, grid_ws):
|
| 711 |
+
|
| 712 |
+
h_idxs = torch.linspace(0, num_grid - 1, h)
|
| 713 |
+
w_idxs = torch.linspace(0, num_grid - 1, w)
|
| 714 |
+
|
| 715 |
+
h_idxs_floor = h_idxs.int()
|
| 716 |
+
w_idxs_floor = w_idxs.int()
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
h_idxs_ceil = (h_idxs_floor + 1).clamp(max=num_grid - 1)
|
| 720 |
+
w_idxs_ceil = (w_idxs_floor + 1).clamp(max=num_grid - 1)
|
| 721 |
+
|
| 722 |
+
dh = h_idxs - h_idxs_floor
|
| 723 |
+
dw = w_idxs - w_idxs_floor
|
| 724 |
+
|
| 725 |
+
base_h = h_idxs_floor * num_grid
|
| 726 |
+
base_h_ceil = h_idxs_ceil * num_grid
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
indices = [
|
| 730 |
+
(base_h[:, None] + w_idxs_floor[None, :]).flatten(),
|
| 731 |
+
(base_h[:, None] + w_idxs_ceil[None, :]).flatten(),
|
| 732 |
+
(base_h_ceil[:, None] + w_idxs_floor[None, :]).flatten(),
|
| 733 |
+
(base_h_ceil[:, None] + w_idxs_ceil[None, :]).flatten(),
|
| 734 |
+
]
|
| 735 |
+
|
| 736 |
+
weights = [
|
| 737 |
+
((1 - dh)[:, None] * (1 - dw)[None, :]).flatten(),
|
| 738 |
+
((1 - dh)[:, None] * dw[None, :]).flatten(),
|
| 739 |
+
(dh[:, None] * (1 - dw)[None, :]).flatten(),
|
| 740 |
+
(dh[:, None] * dw[None, :]).flatten(),
|
| 741 |
+
]
|
| 742 |
+
|
| 743 |
+
# 直接 Append Tensor,不做 tolist()
|
| 744 |
+
for i in range(4):
|
| 745 |
+
idx_accum[i].append(indices[i])
|
| 746 |
+
weight_accum[i].append(weights[i])
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
target_device = self.pos_embed.weight.device
|
| 750 |
+
target_dtype = self.pos_embed.weight.dtype
|
| 751 |
+
|
| 752 |
+
idx_tensor = torch.stack([torch.cat(acc) for acc in idx_accum]).to(device=target_device, dtype=torch.long)
|
| 753 |
+
weight_tensor = torch.stack([torch.cat(acc) for acc in weight_accum]).to(device=target_device, dtype=target_dtype)
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
|
| 757 |
+
patch_pos_embeds = pos_embeds.sum(dim=0)
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
merge_size = self.config.spatial_merge_size
|
| 761 |
+
indices_list = []
|
| 762 |
+
current_offset = 0
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
for t, h, w in zip(grid_ts.tolist(), grid_hs.tolist(), grid_ws.tolist()):
|
| 766 |
+
|
| 767 |
+
local_ids = torch.arange(h * w, device='cpu').view(h, w)
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
local_ids_permuted = (
|
| 771 |
+
local_ids.view(h // merge_size, merge_size, w // merge_size, merge_size)
|
| 772 |
+
.permute(0, 2, 1, 3)
|
| 773 |
+
.reshape(-1)
|
| 774 |
+
)
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
global_ids = local_ids_permuted + current_offset
|
| 778 |
+
|
| 779 |
+
|
| 780 |
+
if t > 1:
|
| 781 |
+
global_ids = global_ids.repeat(t)
|
| 782 |
+
|
| 783 |
+
indices_list.append(global_ids)
|
| 784 |
+
current_offset += h * w
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
all_indices = torch.cat(indices_list).to(target_device)
|
| 788 |
+
|
| 789 |
+
|
| 790 |
+
patch_pos_embeds = patch_pos_embeds[all_indices]
|
| 791 |
+
|
| 792 |
+
return patch_pos_embeds
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 796 |
+
"""
|
| 797 |
+
Args:
|
| 798 |
+
hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
|
| 799 |
+
The final hidden states of the model.
|
| 800 |
+
grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
|
| 801 |
+
The temporal, height and width of feature shape of each image in LLM.
|
| 802 |
+
|
| 803 |
+
Returns:
|
| 804 |
+
`torch.Tensor`: hidden_states.
|
| 805 |
+
"""
|
| 806 |
+
hidden_states = self.patch_embed(hidden_states)
|
| 807 |
+
|
| 808 |
+
#move grid_thw to cpu
|
| 809 |
+
grid_thw_cpu = grid_thw.cpu()
|
| 810 |
+
|
| 811 |
+
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_cpu)
|
| 812 |
+
hidden_states = hidden_states + pos_embeds
|
| 813 |
+
|
| 814 |
+
rotary_pos_emb = self.rot_pos_emb(grid_thw_cpu)
|
| 815 |
+
|
| 816 |
+
seq_len, _ = hidden_states.size()
|
| 817 |
+
hidden_states = hidden_states.reshape(seq_len, -1)
|
| 818 |
+
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
| 819 |
+
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
| 820 |
+
cos = emb.cos().to(torch.float32).unsqueeze(-2).contiguous()
|
| 821 |
+
sin = emb.sin().to(torch.float32).unsqueeze(-2).contiguous()
|
| 822 |
+
cos = cos.to(device=hidden_states.device, non_blocking=True)
|
| 823 |
+
sin = sin.to(device=hidden_states.device, non_blocking=True)
|
| 824 |
+
position_embeddings = (cos, sin)
|
| 825 |
+
|
| 826 |
+
#use the grid_thw in gpu
|
| 827 |
+
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
| 828 |
+
dim=0,
|
| 829 |
+
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
| 830 |
+
)
|
| 831 |
+
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
| 832 |
+
cu_seqlens = cu_seqlens.to(device=hidden_states.device)
|
| 833 |
+
|
| 834 |
+
|
| 835 |
+
deepstack_feature_lists = []
|
| 836 |
+
for layer_num, blk in enumerate(self.blocks):
|
| 837 |
+
if self.gradient_checkpointing and self.training:
|
| 838 |
+
blk.gradient_checkpointing = False
|
| 839 |
+
def create_custom_forward(module):
|
| 840 |
+
def custom_forward(*inputs):
|
| 841 |
+
return module(inputs[0], inputs[1], inputs[2], inputs[3], **inputs[4])
|
| 842 |
+
return custom_forward
|
| 843 |
+
|
| 844 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 845 |
+
create_custom_forward(blk),
|
| 846 |
+
hidden_states,
|
| 847 |
+
cu_seqlens,
|
| 848 |
+
None,
|
| 849 |
+
position_embeddings,
|
| 850 |
+
kwargs,
|
| 851 |
+
)
|
| 852 |
+
else:
|
| 853 |
+
hidden_states = blk(
|
| 854 |
+
hidden_states,
|
| 855 |
+
cu_seqlens=cu_seqlens,
|
| 856 |
+
position_embeddings=position_embeddings,
|
| 857 |
+
**kwargs,
|
| 858 |
+
)
|
| 859 |
+
if layer_num in self.deepstack_visual_indexes:
|
| 860 |
+
deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
|
| 861 |
+
hidden_states
|
| 862 |
+
)
|
| 863 |
+
deepstack_feature_lists.append(deepstack_feature)
|
| 864 |
+
|
| 865 |
+
hidden_states = self.merger(hidden_states)
|
| 866 |
+
|
| 867 |
+
return hidden_states, deepstack_feature_lists
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
@auto_docstring(
|
| 871 |
+
custom_intro=(
|
| 872 |
+
"Text part of Qwen3VL, "
|
| 873 |
+
"not a pure text-only model, as DeepStack integrates visual features into the early hidden states."
|
| 874 |
+
)
|
| 875 |
+
)
|
| 876 |
+
class Qwen3VLTextModel(Qwen3VLPreTrainedModel):
|
| 877 |
+
config: Qwen3VLTextConfig
|
| 878 |
+
_no_split_modules = ["Qwen3VLTextDecoderLayer"]
|
| 879 |
+
|
| 880 |
+
def __init__(self, config: Qwen3VLTextConfig):
|
| 881 |
+
super().__init__(config)
|
| 882 |
+
self.padding_idx = config.pad_token_id
|
| 883 |
+
self.vocab_size = config.vocab_size
|
| 884 |
+
|
| 885 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 886 |
+
self.layers = nn.ModuleList(
|
| 887 |
+
[Qwen3VLTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 888 |
+
)
|
| 889 |
+
self.norm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 890 |
+
self.rotary_emb = Qwen3VLTextRotaryEmbedding(config=config)
|
| 891 |
+
self.gradient_checkpointing = False
|
| 892 |
+
|
| 893 |
+
# Initialize weights and apply final processing
|
| 894 |
+
self.post_init()
|
| 895 |
+
|
| 896 |
+
|
| 897 |
+
def get_input_embeddings(self):
|
| 898 |
+
return self.embed_tokens
|
| 899 |
+
|
| 900 |
+
def set_input_embeddings(self, value):
|
| 901 |
+
self.embed_tokens = value
|
| 902 |
+
|
| 903 |
+
@check_model_inputs()
|
| 904 |
+
@auto_docstring
|
| 905 |
+
def forward(
|
| 906 |
+
self,
|
| 907 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 908 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 909 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 910 |
+
past_key_values: Optional[Cache] = None,
|
| 911 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 912 |
+
use_cache: Optional[bool] = None,
|
| 913 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 914 |
+
# args for deepstack
|
| 915 |
+
visual_pos_masks: Optional[torch.Tensor] = None,
|
| 916 |
+
deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
|
| 917 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 918 |
+
) -> Union[tuple, BaseModelOutputWithPast]:
|
| 919 |
+
r"""
|
| 920 |
+
visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*):
|
| 921 |
+
The mask of the visual positions.
|
| 922 |
+
deepstack_visual_embeds (`list[torch.Tensor]`, *optional*):
|
| 923 |
+
The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim).
|
| 924 |
+
The feature is extracted from the different visual encoder layers, and fed to the decoder
|
| 925 |
+
hidden states. It's from the paper DeepStack(https://arxiv.org/abs/2406.04334).
|
| 926 |
+
"""
|
| 927 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 928 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 929 |
+
|
| 930 |
+
# torch.jit.trace() doesn't support cache objects in the output
|
| 931 |
+
if use_cache and past_key_values is None and not torch.jit.is_tracing():
|
| 932 |
+
past_key_values = DynamicCache(config=self.config)
|
| 933 |
+
|
| 934 |
+
if inputs_embeds is None:
|
| 935 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 936 |
+
|
| 937 |
+
if cache_position is None:
|
| 938 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 939 |
+
cache_position = torch.arange(
|
| 940 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 941 |
+
)
|
| 942 |
+
|
| 943 |
+
# the hard coded `3` is for temporal, height and width.
|
| 944 |
+
if position_ids is None:
|
| 945 |
+
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) # (3, bs, seq_length)
|
| 946 |
+
elif position_ids.ndim == 2:
|
| 947 |
+
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
| 948 |
+
|
| 949 |
+
if position_ids.ndim == 3 and position_ids.shape[0] == 4:
|
| 950 |
+
text_position_ids = position_ids[0]
|
| 951 |
+
position_ids = position_ids[1:]
|
| 952 |
+
else:
|
| 953 |
+
text_position_ids = position_ids[0]
|
| 954 |
+
# NOTE: Attention! When we use packing mode, this `create_causal_mask` is overwrited, and directly return `attention_mask`.
|
| 955 |
+
attention_mask = create_causal_mask(
|
| 956 |
+
config=self.config,
|
| 957 |
+
input_embeds=inputs_embeds,
|
| 958 |
+
attention_mask=attention_mask,
|
| 959 |
+
cache_position=cache_position,
|
| 960 |
+
past_key_values=past_key_values,
|
| 961 |
+
position_ids=text_position_ids,
|
| 962 |
+
)
|
| 963 |
+
|
| 964 |
+
hidden_states = inputs_embeds
|
| 965 |
+
|
| 966 |
+
# create position embeddings to be shared across the decoder layers
|
| 967 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 968 |
+
cos, sin = position_embeddings
|
| 969 |
+
cos = cos.to(device=hidden_states.device, non_blocking=True).unsqueeze(1).contiguous()
|
| 970 |
+
sin = sin.to(device=hidden_states.device, non_blocking=True).unsqueeze(1).contiguous()
|
| 971 |
+
position_embeddings = (cos, sin)
|
| 972 |
+
|
| 973 |
+
# decoder layers
|
| 974 |
+
for layer_idx, decoder_layer in enumerate(self.layers):
|
| 975 |
+
if self.gradient_checkpointing and self.training:
|
| 976 |
+
decoder_layer.gradient_checkpointing = False
|
| 977 |
+
def create_custom_forward(module): # DEBUG: Here we enter the Qwen3VLTextDecoderLayer forward
|
| 978 |
+
def custom_forward(*inputs):
|
| 979 |
+
# inputs: hidden_states, position_embeddings, attention_mask, position_ids, past_key_values, use_cache, cache_position, kwargs_dict
|
| 980 |
+
return module(
|
| 981 |
+
inputs[0],
|
| 982 |
+
inputs[1],
|
| 983 |
+
attention_mask=inputs[2],
|
| 984 |
+
position_ids=inputs[3],
|
| 985 |
+
past_key_values=inputs[4],
|
| 986 |
+
use_cache=inputs[5],
|
| 987 |
+
cache_position=inputs[6],
|
| 988 |
+
**inputs[7]
|
| 989 |
+
)
|
| 990 |
+
return custom_forward
|
| 991 |
+
|
| 992 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 993 |
+
create_custom_forward(decoder_layer),
|
| 994 |
+
hidden_states,
|
| 995 |
+
position_embeddings,
|
| 996 |
+
attention_mask,
|
| 997 |
+
text_position_ids,
|
| 998 |
+
past_key_values,
|
| 999 |
+
False, # use_cache
|
| 1000 |
+
cache_position,
|
| 1001 |
+
kwargs,
|
| 1002 |
+
)
|
| 1003 |
+
else:
|
| 1004 |
+
layer_outputs = decoder_layer(
|
| 1005 |
+
hidden_states,
|
| 1006 |
+
attention_mask=attention_mask,
|
| 1007 |
+
position_ids=text_position_ids,
|
| 1008 |
+
past_key_values=past_key_values,
|
| 1009 |
+
cache_position=cache_position,
|
| 1010 |
+
position_embeddings=position_embeddings,
|
| 1011 |
+
**kwargs,
|
| 1012 |
+
)
|
| 1013 |
+
hidden_states = layer_outputs
|
| 1014 |
+
|
| 1015 |
+
# add visual features to the hidden states of first several layers
|
| 1016 |
+
if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)):
|
| 1017 |
+
hidden_states = self._deepstack_process(
|
| 1018 |
+
hidden_states,
|
| 1019 |
+
visual_pos_masks,
|
| 1020 |
+
deepstack_visual_embeds[layer_idx],
|
| 1021 |
+
)
|
| 1022 |
+
|
| 1023 |
+
hidden_states = self.norm(hidden_states)
|
| 1024 |
+
|
| 1025 |
+
return BaseModelOutputWithPast(
|
| 1026 |
+
last_hidden_state=hidden_states,
|
| 1027 |
+
past_key_values=past_key_values,
|
| 1028 |
+
)
|
| 1029 |
+
|
| 1030 |
+
def _deepstack_process(
|
| 1031 |
+
self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor
|
| 1032 |
+
):
|
| 1033 |
+
visual_pos_masks = visual_pos_masks.to(hidden_states.device)
|
| 1034 |
+
visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
|
| 1035 |
+
local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds
|
| 1036 |
+
hidden_states[visual_pos_masks, :] = local_this
|
| 1037 |
+
return hidden_states
|
| 1038 |
+
|
| 1039 |
+
|
| 1040 |
+
@dataclass
|
| 1041 |
+
@auto_docstring(
|
| 1042 |
+
custom_intro="""
|
| 1043 |
+
Base class for Qwen3VL causal language model (or autoregressive) outputs.
|
| 1044 |
+
"""
|
| 1045 |
+
)
|
| 1046 |
+
class Qwen3VLCausalLMOutputWithPast(ModelOutput):
|
| 1047 |
+
r"""
|
| 1048 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 1049 |
+
Language modeling loss (for next-token prediction).
|
| 1050 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 1051 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 1052 |
+
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 1053 |
+
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
| 1054 |
+
|
| 1055 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
| 1056 |
+
`past_key_values` input) to speed up sequential decoding.
|
| 1057 |
+
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
|
| 1058 |
+
The rope index difference between sequence length and multimodal rope.
|
| 1059 |
+
"""
|
| 1060 |
+
|
| 1061 |
+
loss: Optional[torch.FloatTensor] = None
|
| 1062 |
+
logits: Optional[torch.FloatTensor] = None
|
| 1063 |
+
past_key_values: Optional[Cache] = None
|
| 1064 |
+
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
| 1065 |
+
attentions: Optional[tuple[torch.FloatTensor]] = None
|
| 1066 |
+
rope_deltas: Optional[torch.LongTensor] = None
|
| 1067 |
+
|
| 1068 |
+
|
| 1069 |
+
class Qwen3VLForConditionalGeneration(Qwen3VLPreTrainedModel, GenerationMixin):
|
| 1070 |
+
_checkpoint_conversion_mapping = {}
|
| 1071 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1072 |
+
# Reference: fix gemma3 grad acc #37208
|
| 1073 |
+
accepts_loss_kwargs = False
|
| 1074 |
+
config: Qwen3VLConfig
|
| 1075 |
+
_no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]
|
| 1076 |
+
|
| 1077 |
+
def __init__(self, config):
|
| 1078 |
+
super().__init__(config)
|
| 1079 |
+
# Directly initialize visual and language_model instead of using Qwen3VLModel
|
| 1080 |
+
self.visual = Qwen3VLVisionModel._from_config(config.vision_config)
|
| 1081 |
+
self.language_model = Qwen3VLTextModel._from_config(config.text_config)
|
| 1082 |
+
self.rope_deltas = None # cache rope_deltas here
|
| 1083 |
+
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
| 1084 |
+
|
| 1085 |
+
self.post_init()
|
| 1086 |
+
|
| 1087 |
+
def get_input_embeddings(self):
|
| 1088 |
+
return self.language_model.get_input_embeddings()
|
| 1089 |
+
|
| 1090 |
+
def set_input_embeddings(self, value):
|
| 1091 |
+
self.language_model.set_input_embeddings(value)
|
| 1092 |
+
|
| 1093 |
+
def set_decoder(self, decoder):
|
| 1094 |
+
self.language_model = decoder
|
| 1095 |
+
|
| 1096 |
+
def get_decoder(self):
|
| 1097 |
+
return self.language_model
|
| 1098 |
+
|
| 1099 |
+
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
|
| 1100 |
+
self.gradient_checkpointing = True
|
| 1101 |
+
self.visual.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
|
| 1102 |
+
self.language_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
|
| 1103 |
+
|
| 1104 |
+
|
| 1105 |
+
def get_rope_index(
|
| 1106 |
+
self,
|
| 1107 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1108 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
| 1109 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
| 1110 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1111 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 1112 |
+
"""Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids."""
|
| 1113 |
+
|
| 1114 |
+
# Since we use timestamps to seperate videos, like <t1> <vision_start> <frame1> <vision_end> <t2> <vision_start> <frame2> <vision_end>, the video_grid_thw should also be split
|
| 1115 |
+
if video_grid_thw is not None:
|
| 1116 |
+
video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
|
| 1117 |
+
video_grid_thw[:, 0] = 1
|
| 1118 |
+
|
| 1119 |
+
spatial_merge_size = self.config.vision_config.spatial_merge_size
|
| 1120 |
+
image_token_id = self.config.image_token_id
|
| 1121 |
+
video_token_id = self.config.video_token_id
|
| 1122 |
+
vision_start_token_id = self.config.vision_start_token_id
|
| 1123 |
+
mrope_position_deltas = []
|
| 1124 |
+
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
|
| 1125 |
+
total_input_ids = input_ids
|
| 1126 |
+
if attention_mask is None:
|
| 1127 |
+
attention_mask = torch.ones_like(total_input_ids)
|
| 1128 |
+
position_ids = torch.ones(
|
| 1129 |
+
3,
|
| 1130 |
+
input_ids.shape[0],
|
| 1131 |
+
input_ids.shape[1],
|
| 1132 |
+
dtype=input_ids.dtype,
|
| 1133 |
+
device=input_ids.device,
|
| 1134 |
+
)
|
| 1135 |
+
image_index, video_index = 0, 0
|
| 1136 |
+
attention_mask = attention_mask.to(total_input_ids.device)
|
| 1137 |
+
for i, input_ids in enumerate(total_input_ids):
|
| 1138 |
+
input_ids = input_ids[attention_mask[i] == 1]
|
| 1139 |
+
image_nums, video_nums = 0, 0
|
| 1140 |
+
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
|
| 1141 |
+
vision_tokens = input_ids[vision_start_indices + 1]
|
| 1142 |
+
image_nums = (vision_tokens == image_token_id).sum()
|
| 1143 |
+
video_nums = (vision_tokens == video_token_id).sum()
|
| 1144 |
+
input_tokens = input_ids.tolist()
|
| 1145 |
+
llm_pos_ids_list: list = []
|
| 1146 |
+
st = 0
|
| 1147 |
+
remain_images, remain_videos = image_nums, video_nums
|
| 1148 |
+
for _ in range(image_nums + video_nums):
|
| 1149 |
+
if image_token_id in input_tokens and remain_images > 0:
|
| 1150 |
+
ed_image = input_tokens.index(image_token_id, st)
|
| 1151 |
+
else:
|
| 1152 |
+
ed_image = len(input_tokens) + 1
|
| 1153 |
+
if video_token_id in input_tokens and remain_videos > 0:
|
| 1154 |
+
ed_video = input_tokens.index(video_token_id, st)
|
| 1155 |
+
else:
|
| 1156 |
+
ed_video = len(input_tokens) + 1
|
| 1157 |
+
if ed_image < ed_video:
|
| 1158 |
+
t, h, w = (
|
| 1159 |
+
image_grid_thw[image_index][0],
|
| 1160 |
+
image_grid_thw[image_index][1],
|
| 1161 |
+
image_grid_thw[image_index][2],
|
| 1162 |
+
)
|
| 1163 |
+
image_index += 1
|
| 1164 |
+
remain_images -= 1
|
| 1165 |
+
ed = ed_image
|
| 1166 |
+
|
| 1167 |
+
else:
|
| 1168 |
+
t, h, w = (
|
| 1169 |
+
video_grid_thw[video_index][0],
|
| 1170 |
+
video_grid_thw[video_index][1],
|
| 1171 |
+
video_grid_thw[video_index][2],
|
| 1172 |
+
)
|
| 1173 |
+
video_index += 1
|
| 1174 |
+
remain_videos -= 1
|
| 1175 |
+
ed = ed_video
|
| 1176 |
+
llm_grid_t, llm_grid_h, llm_grid_w = (
|
| 1177 |
+
t.item(),
|
| 1178 |
+
h.item() // spatial_merge_size,
|
| 1179 |
+
w.item() // spatial_merge_size,
|
| 1180 |
+
)
|
| 1181 |
+
text_len = ed - st
|
| 1182 |
+
|
| 1183 |
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
| 1184 |
+
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
| 1185 |
+
|
| 1186 |
+
# t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos)
|
| 1187 |
+
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
|
| 1188 |
+
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
|
| 1189 |
+
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
|
| 1190 |
+
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
|
| 1191 |
+
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
| 1192 |
+
|
| 1193 |
+
if st < len(input_tokens):
|
| 1194 |
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
| 1195 |
+
text_len = len(input_tokens) - st
|
| 1196 |
+
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
| 1197 |
+
|
| 1198 |
+
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
| 1199 |
+
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
|
| 1200 |
+
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
|
| 1201 |
+
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
|
| 1202 |
+
return position_ids, mrope_position_deltas
|
| 1203 |
+
else:
|
| 1204 |
+
if attention_mask is not None:
|
| 1205 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 1206 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 1207 |
+
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
|
| 1208 |
+
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
|
| 1209 |
+
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
|
| 1210 |
+
else:
|
| 1211 |
+
position_ids = (
|
| 1212 |
+
torch.arange(input_ids.shape[1], device=input_ids.device)
|
| 1213 |
+
.view(1, 1, -1)
|
| 1214 |
+
.expand(3, input_ids.shape[0], -1)
|
| 1215 |
+
)
|
| 1216 |
+
mrope_position_deltas = torch.zeros(
|
| 1217 |
+
[input_ids.shape[0], 1],
|
| 1218 |
+
device=input_ids.device,
|
| 1219 |
+
dtype=input_ids.dtype,
|
| 1220 |
+
)
|
| 1221 |
+
|
| 1222 |
+
return position_ids, mrope_position_deltas
|
| 1223 |
+
|
| 1224 |
+
def get_video_features(
|
| 1225 |
+
self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
|
| 1226 |
+
):
|
| 1227 |
+
"""
|
| 1228 |
+
Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned.
|
| 1229 |
+
|
| 1230 |
+
Args:
|
| 1231 |
+
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
| 1232 |
+
The tensors corresponding to the input videos.
|
| 1233 |
+
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
| 1234 |
+
The temporal, height and width of feature shape of each video in LLM.
|
| 1235 |
+
"""
|
| 1236 |
+
# Same implementation as for images
|
| 1237 |
+
return self.get_image_features(pixel_values_videos, video_grid_thw)
|
| 1238 |
+
|
| 1239 |
+
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None, **kwargs):
|
| 1240 |
+
"""
|
| 1241 |
+
Encodes images into continuous embeddings that can be forwarded to the language model.
|
| 1242 |
+
|
| 1243 |
+
Args:
|
| 1244 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
| 1245 |
+
The tensors corresponding to the input images.
|
| 1246 |
+
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
| 1247 |
+
The temporal, height and width of feature shape of each image in LLM.
|
| 1248 |
+
"""
|
| 1249 |
+
pixel_values = pixel_values.type(self.visual.dtype)
|
| 1250 |
+
image_embeds, deepstack_feature_lists = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs)
|
| 1251 |
+
split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
|
| 1252 |
+
image_embeds = torch.split(image_embeds, split_sizes)
|
| 1253 |
+
return image_embeds, deepstack_feature_lists
|
| 1254 |
+
|
| 1255 |
+
def get_placeholder_mask(
|
| 1256 |
+
self,
|
| 1257 |
+
input_ids: torch.LongTensor,
|
| 1258 |
+
inputs_embeds: torch.FloatTensor,
|
| 1259 |
+
image_features: Optional[torch.FloatTensor] = None,
|
| 1260 |
+
video_features: Optional[torch.FloatTensor] = None,
|
| 1261 |
+
):
|
| 1262 |
+
"""
|
| 1263 |
+
Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
|
| 1264 |
+
equal to the length of multimodal features. If the lengths are different, an error is raised.
|
| 1265 |
+
"""
|
| 1266 |
+
if input_ids is None:
|
| 1267 |
+
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
| 1268 |
+
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
| 1269 |
+
)
|
| 1270 |
+
special_image_mask = special_image_mask.all(-1)
|
| 1271 |
+
special_video_mask = inputs_embeds == self.get_input_embeddings()(
|
| 1272 |
+
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
| 1273 |
+
)
|
| 1274 |
+
special_video_mask = special_video_mask.all(-1)
|
| 1275 |
+
else:
|
| 1276 |
+
special_image_mask = input_ids == self.config.image_token_id
|
| 1277 |
+
special_video_mask = input_ids == self.config.video_token_id
|
| 1278 |
+
|
| 1279 |
+
n_image_tokens = special_image_mask.sum()
|
| 1280 |
+
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
| 1281 |
+
if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
| 1282 |
+
raise ValueError(
|
| 1283 |
+
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
|
| 1284 |
+
)
|
| 1285 |
+
|
| 1286 |
+
n_video_tokens = special_video_mask.sum()
|
| 1287 |
+
special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
| 1288 |
+
if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
|
| 1289 |
+
raise ValueError(
|
| 1290 |
+
f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
|
| 1291 |
+
)
|
| 1292 |
+
|
| 1293 |
+
return special_image_mask, special_video_mask
|
| 1294 |
+
|
| 1295 |
+
@check_model_inputs()
|
| 1296 |
+
def forward(
|
| 1297 |
+
self,
|
| 1298 |
+
input_ids: torch.LongTensor = None,
|
| 1299 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1300 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1301 |
+
past_key_values: Optional[Cache] = None,
|
| 1302 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1303 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1304 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 1305 |
+
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
| 1306 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
| 1307 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
| 1308 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1309 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1310 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 1311 |
+
) -> Union[tuple, Qwen3VLCausalLMOutputWithPast]:
|
| 1312 |
+
r"""
|
| 1313 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1314 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 1315 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 1316 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 1317 |
+
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
| 1318 |
+
The temporal, height and width of feature shape of each image in LLM.
|
| 1319 |
+
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
| 1320 |
+
The temporal, height and width of feature shape of each video in LLM.
|
| 1321 |
+
|
| 1322 |
+
Example:
|
| 1323 |
+
TODO: Add example
|
| 1324 |
+
"""
|
| 1325 |
+
# Inlined from Qwen3VLModel.forward
|
| 1326 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 1327 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 1328 |
+
|
| 1329 |
+
if inputs_embeds is None:
|
| 1330 |
+
inputs_embeds = self.get_input_embeddings()(input_ids)
|
| 1331 |
+
|
| 1332 |
+
image_mask = None
|
| 1333 |
+
video_mask = None
|
| 1334 |
+
|
| 1335 |
+
if pixel_values is not None:
|
| 1336 |
+
image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw, image_max_seqlen=kwargs.get("image_max_seqlen"))
|
| 1337 |
+
image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
|
| 1338 |
+
|
| 1339 |
+
image_mask, _ = self.get_placeholder_mask(
|
| 1340 |
+
input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
|
| 1341 |
+
)
|
| 1342 |
+
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
| 1343 |
+
|
| 1344 |
+
|
| 1345 |
+
if pixel_values_videos is not None:
|
| 1346 |
+
video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
| 1347 |
+
video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
|
| 1348 |
+
|
| 1349 |
+
_, video_mask = self.get_placeholder_mask(
|
| 1350 |
+
input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
|
| 1351 |
+
)
|
| 1352 |
+
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
| 1353 |
+
|
| 1354 |
+
|
| 1355 |
+
visual_pos_masks = None
|
| 1356 |
+
deepstack_visual_embeds = None
|
| 1357 |
+
if image_mask is not None and video_mask is not None:
|
| 1358 |
+
# aggregate visual_pos_masks and deepstack_visual_embeds
|
| 1359 |
+
image_mask = image_mask[..., 0]
|
| 1360 |
+
video_mask = video_mask[..., 0]
|
| 1361 |
+
visual_pos_masks = image_mask | video_mask
|
| 1362 |
+
deepstack_visual_embeds = []
|
| 1363 |
+
image_mask_joint = image_mask[visual_pos_masks]
|
| 1364 |
+
video_mask_joint = video_mask[visual_pos_masks]
|
| 1365 |
+
for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
|
| 1366 |
+
embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
|
| 1367 |
+
embed_joint[image_mask_joint, :] = img_embed
|
| 1368 |
+
embed_joint[video_mask_joint, :] = vid_embed
|
| 1369 |
+
deepstack_visual_embeds.append(embed_joint)
|
| 1370 |
+
elif image_mask is not None:
|
| 1371 |
+
image_mask = image_mask[..., 0]
|
| 1372 |
+
visual_pos_masks = image_mask
|
| 1373 |
+
deepstack_visual_embeds = deepstack_image_embeds
|
| 1374 |
+
elif video_mask is not None:
|
| 1375 |
+
video_mask = video_mask[..., 0]
|
| 1376 |
+
visual_pos_masks = video_mask
|
| 1377 |
+
deepstack_visual_embeds = deepstack_video_embeds
|
| 1378 |
+
|
| 1379 |
+
if position_ids is None:
|
| 1380 |
+
attention_mask_tensor = (
|
| 1381 |
+
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
|
| 1382 |
+
)
|
| 1383 |
+
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
|
| 1384 |
+
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
|
| 1385 |
+
# Only apply conversion for floating point tensors (inverted masks)
|
| 1386 |
+
if attention_mask_tensor.dtype.is_floating_point:
|
| 1387 |
+
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
|
| 1388 |
+
attention_mask_tensor = (1.0 - attention_mask_tensor).int()
|
| 1389 |
+
|
| 1390 |
+
# Calculate RoPE index once per generation in the pre-fill stage only.
|
| 1391 |
+
# When compiling, we can't check tensor values thus we check only input length
|
| 1392 |
+
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
| 1393 |
+
# models currently cannot do asssisted decoding
|
| 1394 |
+
prefill_compiled_stage = is_torchdynamo_compiling() and (
|
| 1395 |
+
(input_ids is not None and input_ids.shape[1] != 1)
|
| 1396 |
+
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
|
| 1397 |
+
)
|
| 1398 |
+
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
|
| 1399 |
+
(cache_position is not None and cache_position[0] == 0)
|
| 1400 |
+
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
| 1401 |
+
)
|
| 1402 |
+
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
|
| 1403 |
+
position_ids, rope_deltas = self.get_rope_index(
|
| 1404 |
+
input_ids,
|
| 1405 |
+
image_grid_thw,
|
| 1406 |
+
video_grid_thw,
|
| 1407 |
+
attention_mask=attention_mask_tensor,
|
| 1408 |
+
)
|
| 1409 |
+
self.rope_deltas = rope_deltas
|
| 1410 |
+
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
| 1411 |
+
else:
|
| 1412 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
| 1413 |
+
delta = (
|
| 1414 |
+
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
| 1415 |
+
if cache_position is not None
|
| 1416 |
+
else 0
|
| 1417 |
+
)
|
| 1418 |
+
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
| 1419 |
+
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
| 1420 |
+
if cache_position is not None: # otherwise `deltas` is an int `0`
|
| 1421 |
+
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
| 1422 |
+
position_ids = position_ids.add(delta)
|
| 1423 |
+
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
| 1424 |
+
|
| 1425 |
+
if kwargs.get("max_seqlen") is not None:
|
| 1426 |
+
try:
|
| 1427 |
+
self.language_model.config.max_seqlen = int(kwargs.get("max_seqlen"))
|
| 1428 |
+
except Exception:
|
| 1429 |
+
self.language_model.config.max_seqlen = kwargs.get("max_seqlen")
|
| 1430 |
+
|
| 1431 |
+
outputs = self.language_model(
|
| 1432 |
+
input_ids=None,
|
| 1433 |
+
position_ids=position_ids,
|
| 1434 |
+
attention_mask=attention_mask,
|
| 1435 |
+
past_key_values=past_key_values,
|
| 1436 |
+
inputs_embeds=inputs_embeds,
|
| 1437 |
+
cache_position=cache_position,
|
| 1438 |
+
visual_pos_masks=visual_pos_masks,
|
| 1439 |
+
deepstack_visual_embeds=deepstack_visual_embeds,
|
| 1440 |
+
**kwargs,
|
| 1441 |
+
)
|
| 1442 |
+
|
| 1443 |
+
hidden_states = outputs[0]
|
| 1444 |
+
|
| 1445 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 1446 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1447 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 1448 |
+
|
| 1449 |
+
loss = None
|
| 1450 |
+
if labels is not None:
|
| 1451 |
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
| 1452 |
+
|
| 1453 |
+
return Qwen3VLCausalLMOutputWithPast(
|
| 1454 |
+
loss=loss,
|
| 1455 |
+
logits=logits,
|
| 1456 |
+
past_key_values=outputs.past_key_values,
|
| 1457 |
+
rope_deltas=self.rope_deltas,
|
| 1458 |
+
)
|
| 1459 |
+
|
| 1460 |
+
def prepare_inputs_for_generation(
|
| 1461 |
+
self,
|
| 1462 |
+
input_ids,
|
| 1463 |
+
past_key_values=None,
|
| 1464 |
+
attention_mask=None,
|
| 1465 |
+
inputs_embeds=None,
|
| 1466 |
+
cache_position=None,
|
| 1467 |
+
position_ids=None,
|
| 1468 |
+
use_cache=True,
|
| 1469 |
+
pixel_values=None,
|
| 1470 |
+
pixel_values_videos=None,
|
| 1471 |
+
image_grid_thw=None,
|
| 1472 |
+
video_grid_thw=None,
|
| 1473 |
+
**kwargs,
|
| 1474 |
+
):
|
| 1475 |
+
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
|
| 1476 |
+
|
| 1477 |
+
model_inputs = super().prepare_inputs_for_generation(
|
| 1478 |
+
input_ids,
|
| 1479 |
+
past_key_values=past_key_values,
|
| 1480 |
+
attention_mask=attention_mask,
|
| 1481 |
+
inputs_embeds=inputs_embeds,
|
| 1482 |
+
cache_position=cache_position,
|
| 1483 |
+
position_ids=position_ids,
|
| 1484 |
+
pixel_values=pixel_values,
|
| 1485 |
+
pixel_values_videos=pixel_values_videos,
|
| 1486 |
+
image_grid_thw=image_grid_thw,
|
| 1487 |
+
video_grid_thw=video_grid_thw,
|
| 1488 |
+
use_cache=use_cache,
|
| 1489 |
+
**kwargs,
|
| 1490 |
+
)
|
| 1491 |
+
|
| 1492 |
+
# Qwen3VL position_ids are prepareed with rope_deltas in forward
|
| 1493 |
+
model_inputs["position_ids"] = None
|
| 1494 |
+
|
| 1495 |
+
if cache_position[0] != 0:
|
| 1496 |
+
model_inputs["pixel_values"] = None
|
| 1497 |
+
model_inputs["pixel_values_videos"] = None
|
| 1498 |
+
|
| 1499 |
+
return model_inputs
|
| 1500 |
+
|
| 1501 |
+
def _get_image_nums_and_video_nums(
|
| 1502 |
+
self,
|
| 1503 |
+
input_ids: Optional[torch.LongTensor],
|
| 1504 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1505 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 1506 |
+
"""
|
| 1507 |
+
Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
|
| 1508 |
+
These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
|
| 1509 |
+
|
| 1510 |
+
Args:
|
| 1511 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 1512 |
+
Indices of input sequence tokens in the vocabulary.
|
| 1513 |
+
|
| 1514 |
+
Returns:
|
| 1515 |
+
image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
|
| 1516 |
+
video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
|
| 1517 |
+
"""
|
| 1518 |
+
image_token_id = self.config.image_token_id
|
| 1519 |
+
video_token_id = self.config.video_token_id
|
| 1520 |
+
vision_start_token_id = self.config.vision_start_token_id
|
| 1521 |
+
|
| 1522 |
+
if inputs_embeds is not None:
|
| 1523 |
+
vision_start_mask = (
|
| 1524 |
+
inputs_embeds
|
| 1525 |
+
== self.get_input_embeddings()(
|
| 1526 |
+
torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
|
| 1527 |
+
)
|
| 1528 |
+
)[..., 0]
|
| 1529 |
+
image_mask = (
|
| 1530 |
+
inputs_embeds
|
| 1531 |
+
== self.get_input_embeddings()(
|
| 1532 |
+
torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
| 1533 |
+
)
|
| 1534 |
+
)[..., 0]
|
| 1535 |
+
video_mask = (
|
| 1536 |
+
inputs_embeds
|
| 1537 |
+
== self.get_input_embeddings()(
|
| 1538 |
+
torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
| 1539 |
+
)
|
| 1540 |
+
)[..., 0]
|
| 1541 |
+
else:
|
| 1542 |
+
vision_start_mask = input_ids == vision_start_token_id
|
| 1543 |
+
image_mask = input_ids == image_token_id
|
| 1544 |
+
video_mask = input_ids == video_token_id
|
| 1545 |
+
|
| 1546 |
+
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
|
| 1547 |
+
image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
|
| 1548 |
+
video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
|
| 1549 |
+
|
| 1550 |
+
return image_nums, video_nums
|
| 1551 |
+
|
| 1552 |
+
def _expand_inputs_for_generation(
|
| 1553 |
+
self,
|
| 1554 |
+
expand_size: int = 1,
|
| 1555 |
+
is_encoder_decoder: bool = False,
|
| 1556 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1557 |
+
**model_kwargs,
|
| 1558 |
+
) -> tuple[torch.LongTensor, dict[str, Any]]:
|
| 1559 |
+
# Overwritten -- Support for expanding tensors without a batch size dimension
|
| 1560 |
+
# e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
|
| 1561 |
+
# pixel_values.shape[0] is sum(seqlen_images for samples)
|
| 1562 |
+
# image_grid_thw.shape[0] is sum(num_images for samples)
|
| 1563 |
+
|
| 1564 |
+
if expand_size == 1:
|
| 1565 |
+
return input_ids, model_kwargs
|
| 1566 |
+
|
| 1567 |
+
visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
|
| 1568 |
+
|
| 1569 |
+
def _expand_dict_for_generation_visual(dict_to_expand):
|
| 1570 |
+
image_grid_thw = model_kwargs.get("image_grid_thw", None)
|
| 1571 |
+
video_grid_thw = model_kwargs.get("video_grid_thw", None)
|
| 1572 |
+
image_nums, video_nums = self._get_image_nums_and_video_nums(
|
| 1573 |
+
input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
|
| 1574 |
+
)
|
| 1575 |
+
|
| 1576 |
+
def _repeat_interleave_samples(x, lengths, repeat_times):
|
| 1577 |
+
samples = torch.split(x, lengths)
|
| 1578 |
+
repeat_args = [repeat_times] + [1] * (x.dim() - 1)
|
| 1579 |
+
result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
|
| 1580 |
+
return result
|
| 1581 |
+
|
| 1582 |
+
for key in dict_to_expand:
|
| 1583 |
+
if key == "pixel_values":
|
| 1584 |
+
# split images into samples
|
| 1585 |
+
samples = torch.split(image_grid_thw, list(image_nums))
|
| 1586 |
+
# compute the sequence length of images for each sample
|
| 1587 |
+
lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
|
| 1588 |
+
dict_to_expand[key] = _repeat_interleave_samples(
|
| 1589 |
+
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
|
| 1590 |
+
)
|
| 1591 |
+
elif key == "image_grid_thw":
|
| 1592 |
+
# get the num of images for each sample
|
| 1593 |
+
lengths = list(image_nums)
|
| 1594 |
+
dict_to_expand[key] = _repeat_interleave_samples(
|
| 1595 |
+
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
|
| 1596 |
+
)
|
| 1597 |
+
elif key == "pixel_values_videos":
|
| 1598 |
+
samples = torch.split(video_grid_thw, list(video_nums))
|
| 1599 |
+
lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
|
| 1600 |
+
dict_to_expand[key] = _repeat_interleave_samples(
|
| 1601 |
+
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
|
| 1602 |
+
)
|
| 1603 |
+
elif key == "video_grid_thw":
|
| 1604 |
+
lengths = list(video_nums)
|
| 1605 |
+
dict_to_expand[key] = _repeat_interleave_samples(
|
| 1606 |
+
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
|
| 1607 |
+
)
|
| 1608 |
+
elif key == "second_per_grid_ts":
|
| 1609 |
+
dict_to_expand[key] = _repeat_interleave_samples(
|
| 1610 |
+
dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
|
| 1611 |
+
)
|
| 1612 |
+
return dict_to_expand
|
| 1613 |
+
|
| 1614 |
+
def _expand_dict_for_generation(dict_to_expand):
|
| 1615 |
+
for key in dict_to_expand:
|
| 1616 |
+
if (
|
| 1617 |
+
key != "cache_position"
|
| 1618 |
+
and dict_to_expand[key] is not None
|
| 1619 |
+
and isinstance(dict_to_expand[key], torch.Tensor)
|
| 1620 |
+
and key not in visual_keys
|
| 1621 |
+
):
|
| 1622 |
+
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
|
| 1623 |
+
return dict_to_expand
|
| 1624 |
+
|
| 1625 |
+
model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
|
| 1626 |
+
|
| 1627 |
+
if input_ids is not None:
|
| 1628 |
+
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
|
| 1629 |
+
|
| 1630 |
+
model_kwargs = _expand_dict_for_generation(model_kwargs)
|
| 1631 |
+
|
| 1632 |
+
if is_encoder_decoder:
|
| 1633 |
+
if model_kwargs.get("encoder_outputs") is None:
|
| 1634 |
+
raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
|
| 1635 |
+
model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
|
| 1636 |
+
|
| 1637 |
+
return input_ids, model_kwargs
|
| 1638 |
+
|
| 1639 |
+
|
| 1640 |
+
__all__ = [
|
| 1641 |
+
"Qwen3VLVisionModel",
|
| 1642 |
+
"Qwen3VLForConditionalGeneration",
|
| 1643 |
+
"Qwen3VLPreTrainedModel",
|
| 1644 |
+
"Qwen3VLTextModel",
|
| 1645 |
+
]
|
preprocessor_config.json
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"crop_size": null,
|
| 3 |
+
"data_format": "channels_first",
|
| 4 |
+
"default_to_square": true,
|
| 5 |
+
"device": null,
|
| 6 |
+
"disable_grouping": null,
|
| 7 |
+
"do_center_crop": null,
|
| 8 |
+
"do_convert_rgb": true,
|
| 9 |
+
"do_normalize": true,
|
| 10 |
+
"do_pad": null,
|
| 11 |
+
"do_rescale": true,
|
| 12 |
+
"do_resize": true,
|
| 13 |
+
"image_mean": [
|
| 14 |
+
0.5,
|
| 15 |
+
0.5,
|
| 16 |
+
0.5
|
| 17 |
+
],
|
| 18 |
+
"image_processor_type": "Qwen2VLImageProcessorFast",
|
| 19 |
+
"image_std": [
|
| 20 |
+
0.5,
|
| 21 |
+
0.5,
|
| 22 |
+
0.5
|
| 23 |
+
],
|
| 24 |
+
"input_data_format": null,
|
| 25 |
+
"max_pixels": 147456,
|
| 26 |
+
"merge_size": 2,
|
| 27 |
+
"min_pixels": 65536,
|
| 28 |
+
"pad_size": null,
|
| 29 |
+
"patch_size": 16,
|
| 30 |
+
"processor_class": "PRTS_Qwen3VLProcessor",
|
| 31 |
+
"resample": 3,
|
| 32 |
+
"rescale_factor": 0.00392156862745098,
|
| 33 |
+
"return_tensors": null,
|
| 34 |
+
"size": {
|
| 35 |
+
"longest_edge": 147456,
|
| 36 |
+
"shortest_edge": 65536
|
| 37 |
+
},
|
| 38 |
+
"temporal_patch_size": 2
|
| 39 |
+
}
|
processing_prts_qwen3_vl.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 TeleAI Rhodes Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Processor for PRTS built on Qwen3-VL (hub / trust_remote_code; no prts package required)."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import logging
|
| 20 |
+
from typing import Optional, Union
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 25 |
+
from transformers.image_utils import ImageInput
|
| 26 |
+
from transformers.processing_utils import (
|
| 27 |
+
ImagesKwargs,
|
| 28 |
+
MultiModalData,
|
| 29 |
+
ProcessingKwargs,
|
| 30 |
+
ProcessorMixin,
|
| 31 |
+
Unpack,
|
| 32 |
+
VideosKwargs,
|
| 33 |
+
)
|
| 34 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 35 |
+
from transformers.utils.logging import get_logger
|
| 36 |
+
from transformers.video_utils import VideoInput
|
| 37 |
+
|
| 38 |
+
ACTION_START_TOKEN = "<|action_start|>"
|
| 39 |
+
ACTION_PLACEHOLDER_TOKEN = "<|action_pad|>"
|
| 40 |
+
ACTION_END_TOKEN = "<|action_end|>"
|
| 41 |
+
CRL_GOAL_REPR_TOKEN = "<|goal_repr|>"
|
| 42 |
+
CRL_OBS_REPR_TOKEN = "<|obs_repr|>"
|
| 43 |
+
VISION_START_TOKEN = "<|vision_start|>" # beginning of vision input
|
| 44 |
+
IMAGE_PLACEHOLDER_TOKEN = "<|image_pad|>" # image placeholder
|
| 45 |
+
VIDEO_PLACEHOLDER_TOKEN = "<|video_pad|>" # video placeholder
|
| 46 |
+
|
| 47 |
+
logger = get_logger(__name__)
|
| 48 |
+
if not logger.handlers:
|
| 49 |
+
handler = logging.StreamHandler()
|
| 50 |
+
handler.setLevel(logging.INFO)
|
| 51 |
+
handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
|
| 52 |
+
logger.addHandler(handler)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Qwen3VLVideosProcessorKwargs(VideosKwargs, total=False):
|
| 56 |
+
pass
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class Qwen3VLImagesKwargs(ImagesKwargs):
|
| 60 |
+
min_pixels: Optional[int]
|
| 61 |
+
max_pixels: Optional[int]
|
| 62 |
+
patch_size: Optional[int]
|
| 63 |
+
temporal_patch_size: Optional[int]
|
| 64 |
+
merge_size: Optional[int]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class Qwen3VLProcessorKwargs(ProcessingKwargs, total=False):
|
| 68 |
+
images_kwargs: Qwen3VLImagesKwargs
|
| 69 |
+
videos_kwargs: Qwen3VLVideosProcessorKwargs
|
| 70 |
+
_defaults = {
|
| 71 |
+
"text_kwargs": {
|
| 72 |
+
"padding": False,
|
| 73 |
+
"return_token_type_ids": False,
|
| 74 |
+
"return_mm_token_type_ids": False,
|
| 75 |
+
},
|
| 76 |
+
"videos_kwargs": {"return_metadata": True},
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class PRTS_Qwen3VLProcessor(ProcessorMixin):
|
| 81 |
+
r"""
|
| 82 |
+
Constructs a PRTS processor which wraps a Qwen3-VL image processor and a Qwen2 tokenizer into a single processor.
|
| 83 |
+
|
| 84 |
+
This processor is built independently (not inheriting from Qwen3VLProcessor) to avoid tight coupling,
|
| 85 |
+
while maintaining compatibility with Qwen3-VL's timestamp-based video processing approach.
|
| 86 |
+
|
| 87 |
+
[`PRTS_Qwen3VLProcessor`] offers all the functionalities needed for PRTS model with:
|
| 88 |
+
- Action token handling (discrete and continuous)
|
| 89 |
+
- State token handling for proprioceptive inputs
|
| 90 |
+
- Expert trigger tokens for flow matching action prediction
|
| 91 |
+
- Qwen3-VL compatible image/video processing with timestamp-based video handling
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
image_processor ([`Qwen2VLImageProcessor`], *optional*):
|
| 95 |
+
The image processor is a required input.
|
| 96 |
+
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
| 97 |
+
The tokenizer is a required input.
|
| 98 |
+
video_processor ([`Qwen3VLVideoProcessor`], *optional*):
|
| 99 |
+
The video processor is a required input.
|
| 100 |
+
chat_template (`str`, *optional*):
|
| 101 |
+
A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
attributes = ["image_processor", "tokenizer", "video_processor"]
|
| 105 |
+
image_processor_class = "AutoImageProcessor"
|
| 106 |
+
video_processor_class = "AutoVideoProcessor"
|
| 107 |
+
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
| 108 |
+
|
| 109 |
+
def __init__(self, image_processor=None, tokenizer=None, video_processor=None,
|
| 110 |
+
chat_template=None, **kwargs):
|
| 111 |
+
# Initialize base ProcessorMixin
|
| 112 |
+
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
|
| 113 |
+
|
| 114 |
+
# Get image/video tokens from tokenizer
|
| 115 |
+
self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
|
| 116 |
+
self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
|
| 117 |
+
self.image_token_id = (
|
| 118 |
+
tokenizer.image_token_id
|
| 119 |
+
if getattr(tokenizer, "image_token_id", None)
|
| 120 |
+
else tokenizer.convert_tokens_to_ids(self.image_token)
|
| 121 |
+
)
|
| 122 |
+
self.video_token_id = (
|
| 123 |
+
tokenizer.video_token_id
|
| 124 |
+
if getattr(tokenizer, "video_token_id", None)
|
| 125 |
+
else tokenizer.convert_tokens_to_ids(self.video_token)
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# Qwen3-VL vision tokens
|
| 129 |
+
self.vision_start_token = (
|
| 130 |
+
"<|vision_start|>" if not hasattr(tokenizer, "vision_start_token") else tokenizer.vision_start_token
|
| 131 |
+
)
|
| 132 |
+
self.vision_end_token = (
|
| 133 |
+
"<|vision_end|>" if not hasattr(tokenizer, "vision_end_token") else tokenizer.vision_end_token
|
| 134 |
+
)
|
| 135 |
+
self.vision_start_token_id = (
|
| 136 |
+
tokenizer.vision_start_token_id
|
| 137 |
+
if getattr(tokenizer, "vision_start_token_id", None)
|
| 138 |
+
else tokenizer.convert_tokens_to_ids(self.vision_start_token)
|
| 139 |
+
)
|
| 140 |
+
self.vision_end_token_id = (
|
| 141 |
+
tokenizer.vision_end_token_id
|
| 142 |
+
if getattr(tokenizer, "vision_end_token_id", None)
|
| 143 |
+
else tokenizer.convert_tokens_to_ids(self.vision_end_token)
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
prts_special_tokens = [
|
| 147 |
+
ACTION_START_TOKEN,
|
| 148 |
+
ACTION_PLACEHOLDER_TOKEN,
|
| 149 |
+
ACTION_END_TOKEN,
|
| 150 |
+
CRL_GOAL_REPR_TOKEN,
|
| 151 |
+
CRL_OBS_REPR_TOKEN,
|
| 152 |
+
]
|
| 153 |
+
num_new_tokens = tokenizer.add_tokens(prts_special_tokens, special_tokens=True)
|
| 154 |
+
logger.info(f"Added {num_new_tokens} new special tokens to the tokenizer.")
|
| 155 |
+
|
| 156 |
+
self.action_token = getattr(tokenizer, "action_token", ACTION_PLACEHOLDER_TOKEN)
|
| 157 |
+
self.action_token_id = tokenizer.convert_tokens_to_ids(self.action_token)
|
| 158 |
+
token_dict = {
|
| 159 |
+
"action_start_token_id": ACTION_START_TOKEN,
|
| 160 |
+
"action_token_id": ACTION_PLACEHOLDER_TOKEN,
|
| 161 |
+
"vision_start_token_id": VISION_START_TOKEN,
|
| 162 |
+
"image_token_id": IMAGE_PLACEHOLDER_TOKEN,
|
| 163 |
+
"video_token_id": VIDEO_PLACEHOLDER_TOKEN,
|
| 164 |
+
"crl_goal_repr_token_id": CRL_GOAL_REPR_TOKEN,
|
| 165 |
+
"crl_obs_repr_token_id": CRL_OBS_REPR_TOKEN,
|
| 166 |
+
}
|
| 167 |
+
self.token_ids = {key: tokenizer.convert_tokens_to_ids(value) for key, value in token_dict.items()}
|
| 168 |
+
|
| 169 |
+
def __call__(
|
| 170 |
+
self,
|
| 171 |
+
images: Optional[ImageInput] = None,
|
| 172 |
+
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
|
| 173 |
+
videos: Optional[VideoInput] = None,
|
| 174 |
+
actions: Union[torch.Tensor] = None,
|
| 175 |
+
**kwargs: Unpack[Qwen3VLProcessorKwargs],
|
| 176 |
+
) -> BatchFeature:
|
| 177 |
+
output_kwargs = self._merge_kwargs(
|
| 178 |
+
Qwen3VLProcessorKwargs,
|
| 179 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 180 |
+
**kwargs,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
image_inputs = {}
|
| 184 |
+
if images is not None:
|
| 185 |
+
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
|
| 186 |
+
image_grid_thw = image_inputs["image_grid_thw"]
|
| 187 |
+
else:
|
| 188 |
+
image_grid_thw = None
|
| 189 |
+
|
| 190 |
+
videos_inputs = {}
|
| 191 |
+
if videos is not None:
|
| 192 |
+
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
|
| 193 |
+
video_grid_thw = videos_inputs["video_grid_thw"]
|
| 194 |
+
if "return_metadata" not in kwargs:
|
| 195 |
+
video_metadata = videos_inputs.pop("video_metadata", None)
|
| 196 |
+
else:
|
| 197 |
+
video_metadata = videos_inputs.get("video_metadata", None)
|
| 198 |
+
else:
|
| 199 |
+
video_grid_thw = None
|
| 200 |
+
video_metadata = None
|
| 201 |
+
|
| 202 |
+
if not isinstance(text, list):
|
| 203 |
+
text = [text]
|
| 204 |
+
|
| 205 |
+
text = text.copy()
|
| 206 |
+
|
| 207 |
+
if image_grid_thw is not None:
|
| 208 |
+
merge_length = self.image_processor.merge_size**2
|
| 209 |
+
index = 0
|
| 210 |
+
for i in range(len(text)):
|
| 211 |
+
while self.image_token in text[i]:
|
| 212 |
+
num_image_tokens = image_grid_thw[index].prod() // merge_length
|
| 213 |
+
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
|
| 214 |
+
index += 1
|
| 215 |
+
text[i] = text[i].replace("<|placeholder|>", self.image_token)
|
| 216 |
+
|
| 217 |
+
if video_grid_thw is not None:
|
| 218 |
+
merge_length = self.video_processor.merge_size**2
|
| 219 |
+
index = 0
|
| 220 |
+
for i in range(len(text)):
|
| 221 |
+
while self.video_token in text[i]:
|
| 222 |
+
if video_metadata is not None and index < len(video_metadata):
|
| 223 |
+
metadata = video_metadata[index]
|
| 224 |
+
if metadata.fps is None:
|
| 225 |
+
logger.warning_once(
|
| 226 |
+
"Qwen3VL requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. "
|
| 227 |
+
"Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. "
|
| 228 |
+
"Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
|
| 229 |
+
)
|
| 230 |
+
metadata.fps = 24 if metadata.fps is None else metadata.fps
|
| 231 |
+
|
| 232 |
+
curr_timestamp = self._calculate_timestamps(
|
| 233 |
+
metadata.frames_indices,
|
| 234 |
+
metadata.fps,
|
| 235 |
+
self.video_processor.merge_size,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
video_placeholder = ""
|
| 239 |
+
frame_seqlen = video_grid_thw[index][1:].prod() // merge_length
|
| 240 |
+
for frame_idx in range(video_grid_thw[index][0]):
|
| 241 |
+
curr_time = curr_timestamp[frame_idx]
|
| 242 |
+
video_placeholder += f"<{curr_time:.1f} seconds>"
|
| 243 |
+
video_placeholder += (
|
| 244 |
+
self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text[i]:
|
| 248 |
+
text[i] = text[i].replace(
|
| 249 |
+
f"{self.vision_start_token}{self.video_token}{self.vision_end_token}",
|
| 250 |
+
video_placeholder,
|
| 251 |
+
1,
|
| 252 |
+
)
|
| 253 |
+
else:
|
| 254 |
+
text[i] = text[i].replace(self.video_token, video_placeholder, 1)
|
| 255 |
+
else:
|
| 256 |
+
num_video_tokens = video_grid_thw[index].prod() // merge_length
|
| 257 |
+
text[i] = text[i].replace(self.video_token, "<|placeholder|>" * num_video_tokens, 1)
|
| 258 |
+
|
| 259 |
+
index += 1
|
| 260 |
+
text[i] = text[i].replace("<|placeholder|>", self.video_token)
|
| 261 |
+
|
| 262 |
+
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
| 263 |
+
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
|
| 264 |
+
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
| 265 |
+
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
|
| 266 |
+
|
| 267 |
+
if return_mm_token_type_ids:
|
| 268 |
+
array_ids = np.array(text_inputs["input_ids"])
|
| 269 |
+
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
| 270 |
+
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
| 271 |
+
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
| 272 |
+
|
| 273 |
+
output_data = {**text_inputs, **image_inputs, **videos_inputs}
|
| 274 |
+
if actions is not None:
|
| 275 |
+
output_data["actions"] = actions
|
| 276 |
+
|
| 277 |
+
return BatchFeature(data=output_data, tensor_type=return_tensors)
|
| 278 |
+
|
| 279 |
+
def _calculate_timestamps(self, indices: Union[list[int], np.ndarray], video_fps: float, merge_size: int = 2):
|
| 280 |
+
if not isinstance(indices, list):
|
| 281 |
+
indices = indices.tolist()
|
| 282 |
+
if len(indices) % merge_size != 0:
|
| 283 |
+
indices.extend(indices[-1] for _ in range(merge_size - len(indices) % merge_size))
|
| 284 |
+
timestamps = [idx / video_fps for idx in indices]
|
| 285 |
+
timestamps = [
|
| 286 |
+
(timestamps[i] + timestamps[i + merge_size - 1]) / 2 for i in range(0, len(timestamps), merge_size)
|
| 287 |
+
]
|
| 288 |
+
return timestamps
|
| 289 |
+
|
| 290 |
+
def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
|
| 291 |
+
vision_data = {}
|
| 292 |
+
if image_sizes is not None:
|
| 293 |
+
images_kwargs = Qwen3VLProcessorKwargs._defaults.get("images_kwargs", {})
|
| 294 |
+
images_kwargs.update(kwargs)
|
| 295 |
+
merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
|
| 296 |
+
|
| 297 |
+
num_image_patches = [
|
| 298 |
+
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
|
| 299 |
+
for image_size in image_sizes
|
| 300 |
+
]
|
| 301 |
+
num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
|
| 302 |
+
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
| 303 |
+
|
| 304 |
+
if video_sizes is not None:
|
| 305 |
+
videos_kwargs = Qwen3VLProcessorKwargs._defaults.get("videos_kwargs", {})
|
| 306 |
+
videos_kwargs.update(kwargs)
|
| 307 |
+
merge_size = videos_kwargs.get("merge_size", None) or self.video_processor.merge_size
|
| 308 |
+
num_video_patches = [
|
| 309 |
+
self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
|
| 310 |
+
for video_size in video_sizes
|
| 311 |
+
]
|
| 312 |
+
num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
|
| 313 |
+
vision_data["num_video_tokens"] = num_video_tokens
|
| 314 |
+
|
| 315 |
+
return MultiModalData(**vision_data)
|
| 316 |
+
|
| 317 |
+
def set_action_tokenizer(self, action_tokenizer):
|
| 318 |
+
self.action_tokenizer = action_tokenizer
|
| 319 |
+
|
| 320 |
+
prts_fast_action_tokens = [f"<|action_token_{i}|>" for i in range(action_tokenizer.vocab_size)]
|
| 321 |
+
num_new_tokens = self.tokenizer.add_tokens(prts_fast_action_tokens, special_tokens=True)
|
| 322 |
+
logger.info(f"Added {num_new_tokens} FAST action tokens to the tokenizer.")
|
| 323 |
+
|
| 324 |
+
self.action_token_start_index = self.tokenizer.convert_tokens_to_ids("<|action_token_0|>")
|
| 325 |
+
self.action_vocab_size = action_tokenizer.vocab_size
|
| 326 |
+
|
| 327 |
+
token_ids = self.tokenizer.convert_tokens_to_ids(prts_fast_action_tokens)
|
| 328 |
+
self.action_mapper = {k: v for k, v in zip(prts_fast_action_tokens, token_ids, strict=True)}
|
| 329 |
+
|
| 330 |
+
def preprocess_action(self, actions, **kwargs):
|
| 331 |
+
raise NotImplementedError
|
| 332 |
+
|
| 333 |
+
def post_process_image_text_to_text(
|
| 334 |
+
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
|
| 335 |
+
):
|
| 336 |
+
return self.tokenizer.batch_decode(
|
| 337 |
+
generated_outputs,
|
| 338 |
+
skip_special_tokens=skip_special_tokens,
|
| 339 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 340 |
+
**kwargs,
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
@property
|
| 344 |
+
def model_input_names(self):
|
| 345 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
| 346 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 347 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
PRTS_Qwen3VLProcessor.register_for_auto_class()
|
| 351 |
+
|
| 352 |
+
__all__ = ["PRTS_Qwen3VLProcessor"]
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<|im_start|>",
|
| 4 |
+
"<|im_end|>",
|
| 5 |
+
"<|object_ref_start|>",
|
| 6 |
+
"<|object_ref_end|>",
|
| 7 |
+
"<|box_start|>",
|
| 8 |
+
"<|box_end|>",
|
| 9 |
+
"<|quad_start|>",
|
| 10 |
+
"<|quad_end|>",
|
| 11 |
+
"<|vision_start|>",
|
| 12 |
+
"<|vision_end|>",
|
| 13 |
+
"<|vision_pad|>",
|
| 14 |
+
"<|image_pad|>",
|
| 15 |
+
"<|video_pad|>"
|
| 16 |
+
],
|
| 17 |
+
"eos_token": {
|
| 18 |
+
"content": "<|im_end|>",
|
| 19 |
+
"lstrip": false,
|
| 20 |
+
"normalized": false,
|
| 21 |
+
"rstrip": false,
|
| 22 |
+
"single_word": false
|
| 23 |
+
},
|
| 24 |
+
"pad_token": {
|
| 25 |
+
"content": "<|endoftext|>",
|
| 26 |
+
"lstrip": false,
|
| 27 |
+
"normalized": false,
|
| 28 |
+
"rstrip": false,
|
| 29 |
+
"single_word": false
|
| 30 |
+
}
|
| 31 |
+
}
|
statistics.json
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"state_mode": "QUANTILE",
|
| 3 |
+
"features": {
|
| 4 |
+
"observation.state": {
|
| 5 |
+
"dtype": "float32",
|
| 6 |
+
"shape": [
|
| 7 |
+
8
|
| 8 |
+
],
|
| 9 |
+
"names": {
|
| 10 |
+
"motors": [
|
| 11 |
+
"x",
|
| 12 |
+
"y",
|
| 13 |
+
"z",
|
| 14 |
+
"roll",
|
| 15 |
+
"pitch",
|
| 16 |
+
"yaw",
|
| 17 |
+
"pad",
|
| 18 |
+
"gripper"
|
| 19 |
+
]
|
| 20 |
+
}
|
| 21 |
+
},
|
| 22 |
+
"action": {
|
| 23 |
+
"dtype": "float32",
|
| 24 |
+
"shape": [
|
| 25 |
+
7
|
| 26 |
+
],
|
| 27 |
+
"names": {
|
| 28 |
+
"motors": [
|
| 29 |
+
"x",
|
| 30 |
+
"y",
|
| 31 |
+
"z",
|
| 32 |
+
"roll",
|
| 33 |
+
"pitch",
|
| 34 |
+
"yaw",
|
| 35 |
+
"gripper"
|
| 36 |
+
]
|
| 37 |
+
}
|
| 38 |
+
}
|
| 39 |
+
},
|
| 40 |
+
"stats": {
|
| 41 |
+
"observation.state": {
|
| 42 |
+
"min": [
|
| 43 |
+
-0.4828203022480011,
|
| 44 |
+
-0.3255046010017395,
|
| 45 |
+
0.008128180168569088,
|
| 46 |
+
0.35277295112609863,
|
| 47 |
+
-3.641430377960205,
|
| 48 |
+
-1.842738389968872,
|
| 49 |
+
-0.0013586411951109767,
|
| 50 |
+
-0.042040832340717316
|
| 51 |
+
],
|
| 52 |
+
"max": [
|
| 53 |
+
0.21031762659549713,
|
| 54 |
+
0.39128610491752625,
|
| 55 |
+
1.3660105466842651,
|
| 56 |
+
3.6714255809783936,
|
| 57 |
+
3.560650587081909,
|
| 58 |
+
1.386339545249939,
|
| 59 |
+
0.04233968257904053,
|
| 60 |
+
0.0013633022317662835
|
| 61 |
+
],
|
| 62 |
+
"mean": [
|
| 63 |
+
-0.046518828719854355,
|
| 64 |
+
0.034408919513225555,
|
| 65 |
+
0.7645694613456726,
|
| 66 |
+
2.9716713428497314,
|
| 67 |
+
-0.2204727977514267,
|
| 68 |
+
-0.12557993829250336,
|
| 69 |
+
0.026915358379483223,
|
| 70 |
+
-0.027191326022148132
|
| 71 |
+
],
|
| 72 |
+
"std": [
|
| 73 |
+
0.10494082421064377,
|
| 74 |
+
0.1517619788646698,
|
| 75 |
+
0.3785194456577301,
|
| 76 |
+
0.3442671298980713,
|
| 77 |
+
0.9068173766136169,
|
| 78 |
+
0.32538288831710815,
|
| 79 |
+
0.014175750315189362,
|
| 80 |
+
0.014058776199817657
|
| 81 |
+
],
|
| 82 |
+
"count": [
|
| 83 |
+
273465
|
| 84 |
+
],
|
| 85 |
+
"q01": [
|
| 86 |
+
-0.3991248905658722,
|
| 87 |
+
-0.2688351273536682,
|
| 88 |
+
0.03826696425676346,
|
| 89 |
+
1.508958101272583,
|
| 90 |
+
-2.7197911739349365,
|
| 91 |
+
-1.0805085897445679,
|
| 92 |
+
0.0017423711251467466,
|
| 93 |
+
-0.04002561420202255
|
| 94 |
+
],
|
| 95 |
+
"q99": [
|
| 96 |
+
0.13556525111198425,
|
| 97 |
+
0.33566486835479736,
|
| 98 |
+
1.2706660032272339,
|
| 99 |
+
3.277346134185791,
|
| 100 |
+
2.406111240386963,
|
| 101 |
+
0.5977716445922852,
|
| 102 |
+
0.04031316190958023,
|
| 103 |
+
-0.00177810771856457
|
| 104 |
+
]
|
| 105 |
+
},
|
| 106 |
+
"action": {
|
| 107 |
+
"min": [
|
| 108 |
+
-0.9375,
|
| 109 |
+
-0.9375,
|
| 110 |
+
-0.9375,
|
| 111 |
+
-0.2582142949104309,
|
| 112 |
+
-0.375,
|
| 113 |
+
-0.3675000071525574,
|
| 114 |
+
-1.0
|
| 115 |
+
],
|
| 116 |
+
"max": [
|
| 117 |
+
0.9375,
|
| 118 |
+
0.9375,
|
| 119 |
+
0.9375,
|
| 120 |
+
0.3557142913341522,
|
| 121 |
+
0.375,
|
| 122 |
+
0.375,
|
| 123 |
+
1.0
|
| 124 |
+
],
|
| 125 |
+
"mean": [
|
| 126 |
+
0.06278152763843536,
|
| 127 |
+
0.08684158325195312,
|
| 128 |
+
-0.0903734639286995,
|
| 129 |
+
0.0005407554563134909,
|
| 130 |
+
0.005643464159220457,
|
| 131 |
+
-0.005229106638580561,
|
| 132 |
+
-0.0496407225728035
|
| 133 |
+
],
|
| 134 |
+
"std": [
|
| 135 |
+
0.33551836013793945,
|
| 136 |
+
0.37847793102264404,
|
| 137 |
+
0.4446770250797272,
|
| 138 |
+
0.03924214467406273,
|
| 139 |
+
0.06341660022735596,
|
| 140 |
+
0.07792268693447113,
|
| 141 |
+
1.000144362449646
|
| 142 |
+
],
|
| 143 |
+
"count": [
|
| 144 |
+
273465
|
| 145 |
+
],
|
| 146 |
+
"q01": [
|
| 147 |
+
-0.7044642567634583,
|
| 148 |
+
-0.8008928298950195,
|
| 149 |
+
-0.9375,
|
| 150 |
+
-0.11464285850524902,
|
| 151 |
+
-0.1639285683631897,
|
| 152 |
+
-0.2239285707473755,
|
| 153 |
+
-1.0
|
| 154 |
+
],
|
| 155 |
+
"q99": [
|
| 156 |
+
0.9375,
|
| 157 |
+
0.8678571581840515,
|
| 158 |
+
0.9375,
|
| 159 |
+
0.13178572058677673,
|
| 160 |
+
0.19285714626312256,
|
| 161 |
+
0.335357129573822,
|
| 162 |
+
1.0
|
| 163 |
+
]
|
| 164 |
+
}
|
| 165 |
+
},
|
| 166 |
+
"datasets": {
|
| 167 |
+
"libero_4_suites": {
|
| 168 |
+
"features": {
|
| 169 |
+
"observation.state": {
|
| 170 |
+
"dtype": "float32",
|
| 171 |
+
"shape": [
|
| 172 |
+
8
|
| 173 |
+
],
|
| 174 |
+
"names": {
|
| 175 |
+
"motors": [
|
| 176 |
+
"x",
|
| 177 |
+
"y",
|
| 178 |
+
"z",
|
| 179 |
+
"roll",
|
| 180 |
+
"pitch",
|
| 181 |
+
"yaw",
|
| 182 |
+
"pad",
|
| 183 |
+
"gripper"
|
| 184 |
+
]
|
| 185 |
+
}
|
| 186 |
+
},
|
| 187 |
+
"action": {
|
| 188 |
+
"dtype": "float32",
|
| 189 |
+
"shape": [
|
| 190 |
+
7
|
| 191 |
+
],
|
| 192 |
+
"names": {
|
| 193 |
+
"motors": [
|
| 194 |
+
"x",
|
| 195 |
+
"y",
|
| 196 |
+
"z",
|
| 197 |
+
"roll",
|
| 198 |
+
"pitch",
|
| 199 |
+
"yaw",
|
| 200 |
+
"gripper"
|
| 201 |
+
]
|
| 202 |
+
}
|
| 203 |
+
}
|
| 204 |
+
},
|
| 205 |
+
"stats": {
|
| 206 |
+
"observation.state": {
|
| 207 |
+
"min": [
|
| 208 |
+
-0.4828203022480011,
|
| 209 |
+
-0.3255046010017395,
|
| 210 |
+
0.008128180168569088,
|
| 211 |
+
0.35277295112609863,
|
| 212 |
+
-3.641430377960205,
|
| 213 |
+
-1.842738389968872,
|
| 214 |
+
-0.0013586411951109767,
|
| 215 |
+
-0.042040832340717316
|
| 216 |
+
],
|
| 217 |
+
"max": [
|
| 218 |
+
0.21031762659549713,
|
| 219 |
+
0.39128610491752625,
|
| 220 |
+
1.3660105466842651,
|
| 221 |
+
3.6714255809783936,
|
| 222 |
+
3.560650587081909,
|
| 223 |
+
1.386339545249939,
|
| 224 |
+
0.04233968257904053,
|
| 225 |
+
0.0013633022317662835
|
| 226 |
+
],
|
| 227 |
+
"mean": [
|
| 228 |
+
-0.046518828719854355,
|
| 229 |
+
0.034408919513225555,
|
| 230 |
+
0.7645694613456726,
|
| 231 |
+
2.9716713428497314,
|
| 232 |
+
-0.2204727977514267,
|
| 233 |
+
-0.12557993829250336,
|
| 234 |
+
0.026915358379483223,
|
| 235 |
+
-0.027191326022148132
|
| 236 |
+
],
|
| 237 |
+
"std": [
|
| 238 |
+
0.10494082421064377,
|
| 239 |
+
0.1517619788646698,
|
| 240 |
+
0.3785194456577301,
|
| 241 |
+
0.3442671298980713,
|
| 242 |
+
0.9068173766136169,
|
| 243 |
+
0.32538288831710815,
|
| 244 |
+
0.014175750315189362,
|
| 245 |
+
0.014058776199817657
|
| 246 |
+
],
|
| 247 |
+
"count": [
|
| 248 |
+
273465
|
| 249 |
+
],
|
| 250 |
+
"q01": [
|
| 251 |
+
-0.3991248905658722,
|
| 252 |
+
-0.2688351273536682,
|
| 253 |
+
0.03826696425676346,
|
| 254 |
+
1.508958101272583,
|
| 255 |
+
-2.7197911739349365,
|
| 256 |
+
-1.0805085897445679,
|
| 257 |
+
0.0017423711251467466,
|
| 258 |
+
-0.04002561420202255
|
| 259 |
+
],
|
| 260 |
+
"q99": [
|
| 261 |
+
0.13556525111198425,
|
| 262 |
+
0.33566486835479736,
|
| 263 |
+
1.2706660032272339,
|
| 264 |
+
3.277346134185791,
|
| 265 |
+
2.406111240386963,
|
| 266 |
+
0.5977716445922852,
|
| 267 |
+
0.04031316190958023,
|
| 268 |
+
-0.00177810771856457
|
| 269 |
+
]
|
| 270 |
+
},
|
| 271 |
+
"action": {
|
| 272 |
+
"min": [
|
| 273 |
+
-0.9375,
|
| 274 |
+
-0.9375,
|
| 275 |
+
-0.9375,
|
| 276 |
+
-0.2582142949104309,
|
| 277 |
+
-0.375,
|
| 278 |
+
-0.3675000071525574,
|
| 279 |
+
-1.0
|
| 280 |
+
],
|
| 281 |
+
"max": [
|
| 282 |
+
0.9375,
|
| 283 |
+
0.9375,
|
| 284 |
+
0.9375,
|
| 285 |
+
0.3557142913341522,
|
| 286 |
+
0.375,
|
| 287 |
+
0.375,
|
| 288 |
+
1.0
|
| 289 |
+
],
|
| 290 |
+
"mean": [
|
| 291 |
+
0.06278152763843536,
|
| 292 |
+
0.08684158325195312,
|
| 293 |
+
-0.0903734639286995,
|
| 294 |
+
0.0005407554563134909,
|
| 295 |
+
0.005643464159220457,
|
| 296 |
+
-0.005229106638580561,
|
| 297 |
+
-0.0496407225728035
|
| 298 |
+
],
|
| 299 |
+
"std": [
|
| 300 |
+
0.33551836013793945,
|
| 301 |
+
0.37847793102264404,
|
| 302 |
+
0.4446770250797272,
|
| 303 |
+
0.03924214467406273,
|
| 304 |
+
0.06341660022735596,
|
| 305 |
+
0.07792268693447113,
|
| 306 |
+
1.000144362449646
|
| 307 |
+
],
|
| 308 |
+
"count": [
|
| 309 |
+
273465
|
| 310 |
+
],
|
| 311 |
+
"q01": [
|
| 312 |
+
-0.7044642567634583,
|
| 313 |
+
-0.8008928298950195,
|
| 314 |
+
-0.9375,
|
| 315 |
+
-0.11464285850524902,
|
| 316 |
+
-0.1639285683631897,
|
| 317 |
+
-0.2239285707473755,
|
| 318 |
+
-1.0
|
| 319 |
+
],
|
| 320 |
+
"q99": [
|
| 321 |
+
0.9375,
|
| 322 |
+
0.8678571581840515,
|
| 323 |
+
0.9375,
|
| 324 |
+
0.13178572058677673,
|
| 325 |
+
0.19285714626312256,
|
| 326 |
+
0.335357129573822,
|
| 327 |
+
1.0
|
| 328 |
+
]
|
| 329 |
+
}
|
| 330 |
+
}
|
| 331 |
+
}
|
| 332 |
+
}
|
| 333 |
+
}
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5482df2482307db564c0595428d3dfdad4bf5dbd9d3d5156052ca12f93b7d3ed
|
| 3 |
+
size 11828002
|
tokenizer_config.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cb6e322ff4c32859f3014cdd5e49182ec932f2b10cc2e365df3439522af926a7
|
| 3 |
+
size 10129
|
video_preprocessor_config.json
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"crop_size": null,
|
| 3 |
+
"data_format": "channels_first",
|
| 4 |
+
"default_to_square": true,
|
| 5 |
+
"device": null,
|
| 6 |
+
"do_center_crop": null,
|
| 7 |
+
"do_convert_rgb": true,
|
| 8 |
+
"do_normalize": true,
|
| 9 |
+
"do_rescale": true,
|
| 10 |
+
"do_resize": true,
|
| 11 |
+
"do_sample_frames": true,
|
| 12 |
+
"fps": 2.0,
|
| 13 |
+
"image_mean": [
|
| 14 |
+
0.5,
|
| 15 |
+
0.5,
|
| 16 |
+
0.5
|
| 17 |
+
],
|
| 18 |
+
"image_std": [
|
| 19 |
+
0.5,
|
| 20 |
+
0.5,
|
| 21 |
+
0.5
|
| 22 |
+
],
|
| 23 |
+
"input_data_format": null,
|
| 24 |
+
"max_frames": 8,
|
| 25 |
+
"merge_size": 2,
|
| 26 |
+
"min_frames": 4,
|
| 27 |
+
"num_frames": null,
|
| 28 |
+
"pad_size": null,
|
| 29 |
+
"patch_size": 16,
|
| 30 |
+
"processor_class": "PRTS_Qwen3VLProcessor",
|
| 31 |
+
"resample": 3,
|
| 32 |
+
"rescale_factor": 0.00392156862745098,
|
| 33 |
+
"return_metadata": false,
|
| 34 |
+
"size": {
|
| 35 |
+
"longest_edge": 147456,
|
| 36 |
+
"shortest_edge": 65536
|
| 37 |
+
},
|
| 38 |
+
"temporal_patch_size": 2,
|
| 39 |
+
"video_metadata": null,
|
| 40 |
+
"video_processor_type": "Qwen3VLVideoProcessor"
|
| 41 |
+
}
|
vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|