zstriving commited on
Commit
f313e88
·
verified ·
1 Parent(s): 113b391

Add files using upload-large-folder tool

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: cc-by-nc-4.0
3
+ library_name: transformers
4
+ pipeline_tag: robotics
5
+ tags:
6
+ - robotics
7
+ - vision-language-action
8
+ - vla
9
+ - libero
10
+ - qwen3-vl
11
+ - prts
12
+ - custom_code
13
+ language:
14
+ - en
15
+ base_model: TeleEmbodied/PRTS-4B
16
+ ---
17
+
18
+ <h1 align="center">PRTS-4B-LIBERO</h1>
19
+
20
+ <p align="center">
21
+ <a href="https://arxiv.org/abs/2604.27472"><img src="https://img.shields.io/badge/arXiv-2604.27472-b31b1b.svg" alt="arXiv"></a>
22
+ &nbsp;
23
+ <a href="https://github.com/TeleHuman/PRTS"><img src="https://img.shields.io/badge/GitHub-PRTS-181717.svg" alt="GitHub"></a>
24
+ &nbsp;
25
+ <a href="https://huggingface.co/TeleEmbodied/PRTS-4B"><img src="https://img.shields.io/badge/Base-PRTS--4B-yellow.svg" alt="Base model"></a>
26
+ </p>
27
+
28
+ **PRTS-4B-LIBERO** is the LIBERO fine-tuned variant of [`TeleEmbodied/PRTS-4B`](https://huggingface.co/TeleEmbodied/PRTS-4B). This is the exact checkpoint used to report the LIBERO numbers in the PRTS paper. For the base model card (architecture, prompt format, contrastive RL design), please refer to the parent [PRTS-4B](https://huggingface.co/TeleEmbodied/PRTS-4B) repository.
29
+
30
+ ## Post-training budget
31
+
32
+ Fine-tuned from `TeleEmbodied/PRTS-4B` with the launch script [`scripts/ft/launch_finetune.sh`](https://github.com/TeleHuman/PRTS/blob/main/scripts/ft/launch_finetune.sh) in the open-source repo. Key settings:
33
+
34
+ | | |
35
+ | :--- | :--- |
36
+ | Base model | `TeleEmbodied/PRTS-4B` |
37
+ | Dataset config | `configs/post-train/libero.yaml` |
38
+ | Embodiment tag | `libero_panda` |
39
+ | Hardware | 4 GPUs, DeepSpeed ZeRO-2, bf16, `flash_attention_3`, no gradient checkpointing |
40
+ | Steps | 30,000 total, 5,000 warmup, save every 10,000 |
41
+ | Effective batch | 8 (per-device) × 4 GPUs × 1 (grad-acc) = **32** |
42
+ | LRs | `1e-5` for vision / merger / LLM; `1e-4` for the action head |
43
+ | Scheduler | `cosine_with_min_lr` (min `1e-6`) |
44
+ | Optimizer | AdamW (β1=0.9, β2=0.95, ε=1e-8), weight decay `1e-8`, grad clip `1.0` |
45
+ | Action head | DiT-L + MoT action expert, chunk size `20`, max action dim `32` |
46
+ | Action normalization | `QUANTILE` (stats bundled in this checkpoint) |
47
+ | Seed | 42 |
48
+
49
+ ## Loading for evaluation
50
+
51
+ This checkpoint plugs into the policy server [`scripts/serve_policy.py`](https://github.com/TeleHuman/PRTS/blob/main/scripts/serve_policy.py). Update the `EnvMode.LIBERO` entry in `DEFAULT_CHECKPOINT` so that `dir=` points to your local download of this repo. Normalization stats are already bundled in the checkpoint, so `dataset_path` can be left as `None`:
52
+
53
+ ```python
54
+ EnvMode.LIBERO: Checkpoint(
55
+ config="prts_libero",
56
+ dir="/path/to/PRTS-4B-libero", # local download path of this repo
57
+ action_dim=7,
58
+ dataset_path=None, # normalization stats are bundled in the checkpoint
59
+ state_mode="QUANTILE",
60
+ ),
61
+ ```
62
+
63
+ ## Running LIBERO evaluation
64
+
65
+ Follow the LIBERO simulation setup in [`examples/libero/README.md`](https://github.com/TeleHuman/PRTS/blob/main/examples/libero/README.md), then start the policy server from the PRTS repo root with [`examples/libero/run_libero_server.sh`](https://github.com/TeleHuman/PRTS/blob/main/examples/libero/run_libero_server.sh):
66
+
67
+ ```bash
68
+ bash examples/libero/run_libero_server.sh
69
+ # which runs:
70
+ # CUDA_VISIBLE_DEVICES=0 python scripts/serve_policy.py --env LIBERO --port 10000
71
+ ```
72
+
73
+ The LIBERO simulator (Terminal 1 in the example README) connects to this server over websocket and rolls out the 4 LIBERO task suites.
74
+
75
+ ## License
76
+
77
+ Released under [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) — free for academic and non-commercial research; commercial use is **not** permitted.
78
+
79
+ ## Citation
80
+
81
+ ```bibtex
82
+ @article{zhang2026prts,
83
+ title = {PRTS: A Primitive Reasoning and Tasking System via Contrastive Representations},
84
+ author = {Yang Zhang and Jiangyuan Zhao and Chenyou Fan and Fangzheng Yan and Tian Li and Haitong Tang and Sen Fu and Xuan'er Wu and Qizhen Weng and Weinan Zhang and Xiu Li and Chi Zhang and Chenjia Bai and Xuelong Li},
85
+ journal = {arXiv preprint arXiv:2604.27472},
86
+ year = {2026},
87
+ }
88
+ ```
added_tokens.json ADDED
@@ -0,0 +1,2081 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<think>": 151667,
6
+ "<tool_call>": 151657,
7
+ "<tool_response>": 151665,
8
+ "<|action_end|>": 151671,
9
+ "<|action_pad|>": 151670,
10
+ "<|action_start|>": 151669,
11
+ "<|action_token_0|>": 151674,
12
+ "<|action_token_1000|>": 152674,
13
+ "<|action_token_1001|>": 152675,
14
+ "<|action_token_1002|>": 152676,
15
+ "<|action_token_1003|>": 152677,
16
+ "<|action_token_1004|>": 152678,
17
+ "<|action_token_1005|>": 152679,
18
+ "<|action_token_1006|>": 152680,
19
+ "<|action_token_1007|>": 152681,
20
+ "<|action_token_1008|>": 152682,
21
+ "<|action_token_1009|>": 152683,
22
+ "<|action_token_100|>": 151774,
23
+ "<|action_token_1010|>": 152684,
24
+ "<|action_token_1011|>": 152685,
25
+ "<|action_token_1012|>": 152686,
26
+ "<|action_token_1013|>": 152687,
27
+ "<|action_token_1014|>": 152688,
28
+ "<|action_token_1015|>": 152689,
29
+ "<|action_token_1016|>": 152690,
30
+ "<|action_token_1017|>": 152691,
31
+ "<|action_token_1018|>": 152692,
32
+ "<|action_token_1019|>": 152693,
33
+ "<|action_token_101|>": 151775,
34
+ "<|action_token_1020|>": 152694,
35
+ "<|action_token_1021|>": 152695,
36
+ "<|action_token_1022|>": 152696,
37
+ "<|action_token_1023|>": 152697,
38
+ "<|action_token_1024|>": 152698,
39
+ "<|action_token_1025|>": 152699,
40
+ "<|action_token_1026|>": 152700,
41
+ "<|action_token_1027|>": 152701,
42
+ "<|action_token_1028|>": 152702,
43
+ "<|action_token_1029|>": 152703,
44
+ "<|action_token_102|>": 151776,
45
+ "<|action_token_1030|>": 152704,
46
+ "<|action_token_1031|>": 152705,
47
+ "<|action_token_1032|>": 152706,
48
+ "<|action_token_1033|>": 152707,
49
+ "<|action_token_1034|>": 152708,
50
+ "<|action_token_1035|>": 152709,
51
+ "<|action_token_1036|>": 152710,
52
+ "<|action_token_1037|>": 152711,
53
+ "<|action_token_1038|>": 152712,
54
+ "<|action_token_1039|>": 152713,
55
+ "<|action_token_103|>": 151777,
56
+ "<|action_token_1040|>": 152714,
57
+ "<|action_token_1041|>": 152715,
58
+ "<|action_token_1042|>": 152716,
59
+ "<|action_token_1043|>": 152717,
60
+ "<|action_token_1044|>": 152718,
61
+ "<|action_token_1045|>": 152719,
62
+ "<|action_token_1046|>": 152720,
63
+ "<|action_token_1047|>": 152721,
64
+ "<|action_token_1048|>": 152722,
65
+ "<|action_token_1049|>": 152723,
66
+ "<|action_token_104|>": 151778,
67
+ "<|action_token_1050|>": 152724,
68
+ "<|action_token_1051|>": 152725,
69
+ "<|action_token_1052|>": 152726,
70
+ "<|action_token_1053|>": 152727,
71
+ "<|action_token_1054|>": 152728,
72
+ "<|action_token_1055|>": 152729,
73
+ "<|action_token_1056|>": 152730,
74
+ "<|action_token_1057|>": 152731,
75
+ "<|action_token_1058|>": 152732,
76
+ "<|action_token_1059|>": 152733,
77
+ "<|action_token_105|>": 151779,
78
+ "<|action_token_1060|>": 152734,
79
+ "<|action_token_1061|>": 152735,
80
+ "<|action_token_1062|>": 152736,
81
+ "<|action_token_1063|>": 152737,
82
+ "<|action_token_1064|>": 152738,
83
+ "<|action_token_1065|>": 152739,
84
+ "<|action_token_1066|>": 152740,
85
+ "<|action_token_1067|>": 152741,
86
+ "<|action_token_1068|>": 152742,
87
+ "<|action_token_1069|>": 152743,
88
+ "<|action_token_106|>": 151780,
89
+ "<|action_token_1070|>": 152744,
90
+ "<|action_token_1071|>": 152745,
91
+ "<|action_token_1072|>": 152746,
92
+ "<|action_token_1073|>": 152747,
93
+ "<|action_token_1074|>": 152748,
94
+ "<|action_token_1075|>": 152749,
95
+ "<|action_token_1076|>": 152750,
96
+ "<|action_token_1077|>": 152751,
97
+ "<|action_token_1078|>": 152752,
98
+ "<|action_token_1079|>": 152753,
99
+ "<|action_token_107|>": 151781,
100
+ "<|action_token_1080|>": 152754,
101
+ "<|action_token_1081|>": 152755,
102
+ "<|action_token_1082|>": 152756,
103
+ "<|action_token_1083|>": 152757,
104
+ "<|action_token_1084|>": 152758,
105
+ "<|action_token_1085|>": 152759,
106
+ "<|action_token_1086|>": 152760,
107
+ "<|action_token_1087|>": 152761,
108
+ "<|action_token_1088|>": 152762,
109
+ "<|action_token_1089|>": 152763,
110
+ "<|action_token_108|>": 151782,
111
+ "<|action_token_1090|>": 152764,
112
+ "<|action_token_1091|>": 152765,
113
+ "<|action_token_1092|>": 152766,
114
+ "<|action_token_1093|>": 152767,
115
+ "<|action_token_1094|>": 152768,
116
+ "<|action_token_1095|>": 152769,
117
+ "<|action_token_1096|>": 152770,
118
+ "<|action_token_1097|>": 152771,
119
+ "<|action_token_1098|>": 152772,
120
+ "<|action_token_1099|>": 152773,
121
+ "<|action_token_109|>": 151783,
122
+ "<|action_token_10|>": 151684,
123
+ "<|action_token_1100|>": 152774,
124
+ "<|action_token_1101|>": 152775,
125
+ "<|action_token_1102|>": 152776,
126
+ "<|action_token_1103|>": 152777,
127
+ "<|action_token_1104|>": 152778,
128
+ "<|action_token_1105|>": 152779,
129
+ "<|action_token_1106|>": 152780,
130
+ "<|action_token_1107|>": 152781,
131
+ "<|action_token_1108|>": 152782,
132
+ "<|action_token_1109|>": 152783,
133
+ "<|action_token_110|>": 151784,
134
+ "<|action_token_1110|>": 152784,
135
+ "<|action_token_1111|>": 152785,
136
+ "<|action_token_1112|>": 152786,
137
+ "<|action_token_1113|>": 152787,
138
+ "<|action_token_1114|>": 152788,
139
+ "<|action_token_1115|>": 152789,
140
+ "<|action_token_1116|>": 152790,
141
+ "<|action_token_1117|>": 152791,
142
+ "<|action_token_1118|>": 152792,
143
+ "<|action_token_1119|>": 152793,
144
+ "<|action_token_111|>": 151785,
145
+ "<|action_token_1120|>": 152794,
146
+ "<|action_token_1121|>": 152795,
147
+ "<|action_token_1122|>": 152796,
148
+ "<|action_token_1123|>": 152797,
149
+ "<|action_token_1124|>": 152798,
150
+ "<|action_token_1125|>": 152799,
151
+ "<|action_token_1126|>": 152800,
152
+ "<|action_token_1127|>": 152801,
153
+ "<|action_token_1128|>": 152802,
154
+ "<|action_token_1129|>": 152803,
155
+ "<|action_token_112|>": 151786,
156
+ "<|action_token_1130|>": 152804,
157
+ "<|action_token_1131|>": 152805,
158
+ "<|action_token_1132|>": 152806,
159
+ "<|action_token_1133|>": 152807,
160
+ "<|action_token_1134|>": 152808,
161
+ "<|action_token_1135|>": 152809,
162
+ "<|action_token_1136|>": 152810,
163
+ "<|action_token_1137|>": 152811,
164
+ "<|action_token_1138|>": 152812,
165
+ "<|action_token_1139|>": 152813,
166
+ "<|action_token_113|>": 151787,
167
+ "<|action_token_1140|>": 152814,
168
+ "<|action_token_1141|>": 152815,
169
+ "<|action_token_1142|>": 152816,
170
+ "<|action_token_1143|>": 152817,
171
+ "<|action_token_1144|>": 152818,
172
+ "<|action_token_1145|>": 152819,
173
+ "<|action_token_1146|>": 152820,
174
+ "<|action_token_1147|>": 152821,
175
+ "<|action_token_1148|>": 152822,
176
+ "<|action_token_1149|>": 152823,
177
+ "<|action_token_114|>": 151788,
178
+ "<|action_token_1150|>": 152824,
179
+ "<|action_token_1151|>": 152825,
180
+ "<|action_token_1152|>": 152826,
181
+ "<|action_token_1153|>": 152827,
182
+ "<|action_token_1154|>": 152828,
183
+ "<|action_token_1155|>": 152829,
184
+ "<|action_token_1156|>": 152830,
185
+ "<|action_token_1157|>": 152831,
186
+ "<|action_token_1158|>": 152832,
187
+ "<|action_token_1159|>": 152833,
188
+ "<|action_token_115|>": 151789,
189
+ "<|action_token_1160|>": 152834,
190
+ "<|action_token_1161|>": 152835,
191
+ "<|action_token_1162|>": 152836,
192
+ "<|action_token_1163|>": 152837,
193
+ "<|action_token_1164|>": 152838,
194
+ "<|action_token_1165|>": 152839,
195
+ "<|action_token_1166|>": 152840,
196
+ "<|action_token_1167|>": 152841,
197
+ "<|action_token_1168|>": 152842,
198
+ "<|action_token_1169|>": 152843,
199
+ "<|action_token_116|>": 151790,
200
+ "<|action_token_1170|>": 152844,
201
+ "<|action_token_1171|>": 152845,
202
+ "<|action_token_1172|>": 152846,
203
+ "<|action_token_1173|>": 152847,
204
+ "<|action_token_1174|>": 152848,
205
+ "<|action_token_1175|>": 152849,
206
+ "<|action_token_1176|>": 152850,
207
+ "<|action_token_1177|>": 152851,
208
+ "<|action_token_1178|>": 152852,
209
+ "<|action_token_1179|>": 152853,
210
+ "<|action_token_117|>": 151791,
211
+ "<|action_token_1180|>": 152854,
212
+ "<|action_token_1181|>": 152855,
213
+ "<|action_token_1182|>": 152856,
214
+ "<|action_token_1183|>": 152857,
215
+ "<|action_token_1184|>": 152858,
216
+ "<|action_token_1185|>": 152859,
217
+ "<|action_token_1186|>": 152860,
218
+ "<|action_token_1187|>": 152861,
219
+ "<|action_token_1188|>": 152862,
220
+ "<|action_token_1189|>": 152863,
221
+ "<|action_token_118|>": 151792,
222
+ "<|action_token_1190|>": 152864,
223
+ "<|action_token_1191|>": 152865,
224
+ "<|action_token_1192|>": 152866,
225
+ "<|action_token_1193|>": 152867,
226
+ "<|action_token_1194|>": 152868,
227
+ "<|action_token_1195|>": 152869,
228
+ "<|action_token_1196|>": 152870,
229
+ "<|action_token_1197|>": 152871,
230
+ "<|action_token_1198|>": 152872,
231
+ "<|action_token_1199|>": 152873,
232
+ "<|action_token_119|>": 151793,
233
+ "<|action_token_11|>": 151685,
234
+ "<|action_token_1200|>": 152874,
235
+ "<|action_token_1201|>": 152875,
236
+ "<|action_token_1202|>": 152876,
237
+ "<|action_token_1203|>": 152877,
238
+ "<|action_token_1204|>": 152878,
239
+ "<|action_token_1205|>": 152879,
240
+ "<|action_token_1206|>": 152880,
241
+ "<|action_token_1207|>": 152881,
242
+ "<|action_token_1208|>": 152882,
243
+ "<|action_token_1209|>": 152883,
244
+ "<|action_token_120|>": 151794,
245
+ "<|action_token_1210|>": 152884,
246
+ "<|action_token_1211|>": 152885,
247
+ "<|action_token_1212|>": 152886,
248
+ "<|action_token_1213|>": 152887,
249
+ "<|action_token_1214|>": 152888,
250
+ "<|action_token_1215|>": 152889,
251
+ "<|action_token_1216|>": 152890,
252
+ "<|action_token_1217|>": 152891,
253
+ "<|action_token_1218|>": 152892,
254
+ "<|action_token_1219|>": 152893,
255
+ "<|action_token_121|>": 151795,
256
+ "<|action_token_1220|>": 152894,
257
+ "<|action_token_1221|>": 152895,
258
+ "<|action_token_1222|>": 152896,
259
+ "<|action_token_1223|>": 152897,
260
+ "<|action_token_1224|>": 152898,
261
+ "<|action_token_1225|>": 152899,
262
+ "<|action_token_1226|>": 152900,
263
+ "<|action_token_1227|>": 152901,
264
+ "<|action_token_1228|>": 152902,
265
+ "<|action_token_1229|>": 152903,
266
+ "<|action_token_122|>": 151796,
267
+ "<|action_token_1230|>": 152904,
268
+ "<|action_token_1231|>": 152905,
269
+ "<|action_token_1232|>": 152906,
270
+ "<|action_token_1233|>": 152907,
271
+ "<|action_token_1234|>": 152908,
272
+ "<|action_token_1235|>": 152909,
273
+ "<|action_token_1236|>": 152910,
274
+ "<|action_token_1237|>": 152911,
275
+ "<|action_token_1238|>": 152912,
276
+ "<|action_token_1239|>": 152913,
277
+ "<|action_token_123|>": 151797,
278
+ "<|action_token_1240|>": 152914,
279
+ "<|action_token_1241|>": 152915,
280
+ "<|action_token_1242|>": 152916,
281
+ "<|action_token_1243|>": 152917,
282
+ "<|action_token_1244|>": 152918,
283
+ "<|action_token_1245|>": 152919,
284
+ "<|action_token_1246|>": 152920,
285
+ "<|action_token_1247|>": 152921,
286
+ "<|action_token_1248|>": 152922,
287
+ "<|action_token_1249|>": 152923,
288
+ "<|action_token_124|>": 151798,
289
+ "<|action_token_1250|>": 152924,
290
+ "<|action_token_1251|>": 152925,
291
+ "<|action_token_1252|>": 152926,
292
+ "<|action_token_1253|>": 152927,
293
+ "<|action_token_1254|>": 152928,
294
+ "<|action_token_1255|>": 152929,
295
+ "<|action_token_1256|>": 152930,
296
+ "<|action_token_1257|>": 152931,
297
+ "<|action_token_1258|>": 152932,
298
+ "<|action_token_1259|>": 152933,
299
+ "<|action_token_125|>": 151799,
300
+ "<|action_token_1260|>": 152934,
301
+ "<|action_token_1261|>": 152935,
302
+ "<|action_token_1262|>": 152936,
303
+ "<|action_token_1263|>": 152937,
304
+ "<|action_token_1264|>": 152938,
305
+ "<|action_token_1265|>": 152939,
306
+ "<|action_token_1266|>": 152940,
307
+ "<|action_token_1267|>": 152941,
308
+ "<|action_token_1268|>": 152942,
309
+ "<|action_token_1269|>": 152943,
310
+ "<|action_token_126|>": 151800,
311
+ "<|action_token_1270|>": 152944,
312
+ "<|action_token_1271|>": 152945,
313
+ "<|action_token_1272|>": 152946,
314
+ "<|action_token_1273|>": 152947,
315
+ "<|action_token_1274|>": 152948,
316
+ "<|action_token_1275|>": 152949,
317
+ "<|action_token_1276|>": 152950,
318
+ "<|action_token_1277|>": 152951,
319
+ "<|action_token_1278|>": 152952,
320
+ "<|action_token_1279|>": 152953,
321
+ "<|action_token_127|>": 151801,
322
+ "<|action_token_1280|>": 152954,
323
+ "<|action_token_1281|>": 152955,
324
+ "<|action_token_1282|>": 152956,
325
+ "<|action_token_1283|>": 152957,
326
+ "<|action_token_1284|>": 152958,
327
+ "<|action_token_1285|>": 152959,
328
+ "<|action_token_1286|>": 152960,
329
+ "<|action_token_1287|>": 152961,
330
+ "<|action_token_1288|>": 152962,
331
+ "<|action_token_1289|>": 152963,
332
+ "<|action_token_128|>": 151802,
333
+ "<|action_token_1290|>": 152964,
334
+ "<|action_token_1291|>": 152965,
335
+ "<|action_token_1292|>": 152966,
336
+ "<|action_token_1293|>": 152967,
337
+ "<|action_token_1294|>": 152968,
338
+ "<|action_token_1295|>": 152969,
339
+ "<|action_token_1296|>": 152970,
340
+ "<|action_token_1297|>": 152971,
341
+ "<|action_token_1298|>": 152972,
342
+ "<|action_token_1299|>": 152973,
343
+ "<|action_token_129|>": 151803,
344
+ "<|action_token_12|>": 151686,
345
+ "<|action_token_1300|>": 152974,
346
+ "<|action_token_1301|>": 152975,
347
+ "<|action_token_1302|>": 152976,
348
+ "<|action_token_1303|>": 152977,
349
+ "<|action_token_1304|>": 152978,
350
+ "<|action_token_1305|>": 152979,
351
+ "<|action_token_1306|>": 152980,
352
+ "<|action_token_1307|>": 152981,
353
+ "<|action_token_1308|>": 152982,
354
+ "<|action_token_1309|>": 152983,
355
+ "<|action_token_130|>": 151804,
356
+ "<|action_token_1310|>": 152984,
357
+ "<|action_token_1311|>": 152985,
358
+ "<|action_token_1312|>": 152986,
359
+ "<|action_token_1313|>": 152987,
360
+ "<|action_token_1314|>": 152988,
361
+ "<|action_token_1315|>": 152989,
362
+ "<|action_token_1316|>": 152990,
363
+ "<|action_token_1317|>": 152991,
364
+ "<|action_token_1318|>": 152992,
365
+ "<|action_token_1319|>": 152993,
366
+ "<|action_token_131|>": 151805,
367
+ "<|action_token_1320|>": 152994,
368
+ "<|action_token_1321|>": 152995,
369
+ "<|action_token_1322|>": 152996,
370
+ "<|action_token_1323|>": 152997,
371
+ "<|action_token_1324|>": 152998,
372
+ "<|action_token_1325|>": 152999,
373
+ "<|action_token_1326|>": 153000,
374
+ "<|action_token_1327|>": 153001,
375
+ "<|action_token_1328|>": 153002,
376
+ "<|action_token_1329|>": 153003,
377
+ "<|action_token_132|>": 151806,
378
+ "<|action_token_1330|>": 153004,
379
+ "<|action_token_1331|>": 153005,
380
+ "<|action_token_1332|>": 153006,
381
+ "<|action_token_1333|>": 153007,
382
+ "<|action_token_1334|>": 153008,
383
+ "<|action_token_1335|>": 153009,
384
+ "<|action_token_1336|>": 153010,
385
+ "<|action_token_1337|>": 153011,
386
+ "<|action_token_1338|>": 153012,
387
+ "<|action_token_1339|>": 153013,
388
+ "<|action_token_133|>": 151807,
389
+ "<|action_token_1340|>": 153014,
390
+ "<|action_token_1341|>": 153015,
391
+ "<|action_token_1342|>": 153016,
392
+ "<|action_token_1343|>": 153017,
393
+ "<|action_token_1344|>": 153018,
394
+ "<|action_token_1345|>": 153019,
395
+ "<|action_token_1346|>": 153020,
396
+ "<|action_token_1347|>": 153021,
397
+ "<|action_token_1348|>": 153022,
398
+ "<|action_token_1349|>": 153023,
399
+ "<|action_token_134|>": 151808,
400
+ "<|action_token_1350|>": 153024,
401
+ "<|action_token_1351|>": 153025,
402
+ "<|action_token_1352|>": 153026,
403
+ "<|action_token_1353|>": 153027,
404
+ "<|action_token_1354|>": 153028,
405
+ "<|action_token_1355|>": 153029,
406
+ "<|action_token_1356|>": 153030,
407
+ "<|action_token_1357|>": 153031,
408
+ "<|action_token_1358|>": 153032,
409
+ "<|action_token_1359|>": 153033,
410
+ "<|action_token_135|>": 151809,
411
+ "<|action_token_1360|>": 153034,
412
+ "<|action_token_1361|>": 153035,
413
+ "<|action_token_1362|>": 153036,
414
+ "<|action_token_1363|>": 153037,
415
+ "<|action_token_1364|>": 153038,
416
+ "<|action_token_1365|>": 153039,
417
+ "<|action_token_1366|>": 153040,
418
+ "<|action_token_1367|>": 153041,
419
+ "<|action_token_1368|>": 153042,
420
+ "<|action_token_1369|>": 153043,
421
+ "<|action_token_136|>": 151810,
422
+ "<|action_token_1370|>": 153044,
423
+ "<|action_token_1371|>": 153045,
424
+ "<|action_token_1372|>": 153046,
425
+ "<|action_token_1373|>": 153047,
426
+ "<|action_token_1374|>": 153048,
427
+ "<|action_token_1375|>": 153049,
428
+ "<|action_token_1376|>": 153050,
429
+ "<|action_token_1377|>": 153051,
430
+ "<|action_token_1378|>": 153052,
431
+ "<|action_token_1379|>": 153053,
432
+ "<|action_token_137|>": 151811,
433
+ "<|action_token_1380|>": 153054,
434
+ "<|action_token_1381|>": 153055,
435
+ "<|action_token_1382|>": 153056,
436
+ "<|action_token_1383|>": 153057,
437
+ "<|action_token_1384|>": 153058,
438
+ "<|action_token_1385|>": 153059,
439
+ "<|action_token_1386|>": 153060,
440
+ "<|action_token_1387|>": 153061,
441
+ "<|action_token_1388|>": 153062,
442
+ "<|action_token_1389|>": 153063,
443
+ "<|action_token_138|>": 151812,
444
+ "<|action_token_1390|>": 153064,
445
+ "<|action_token_1391|>": 153065,
446
+ "<|action_token_1392|>": 153066,
447
+ "<|action_token_1393|>": 153067,
448
+ "<|action_token_1394|>": 153068,
449
+ "<|action_token_1395|>": 153069,
450
+ "<|action_token_1396|>": 153070,
451
+ "<|action_token_1397|>": 153071,
452
+ "<|action_token_1398|>": 153072,
453
+ "<|action_token_1399|>": 153073,
454
+ "<|action_token_139|>": 151813,
455
+ "<|action_token_13|>": 151687,
456
+ "<|action_token_1400|>": 153074,
457
+ "<|action_token_1401|>": 153075,
458
+ "<|action_token_1402|>": 153076,
459
+ "<|action_token_1403|>": 153077,
460
+ "<|action_token_1404|>": 153078,
461
+ "<|action_token_1405|>": 153079,
462
+ "<|action_token_1406|>": 153080,
463
+ "<|action_token_1407|>": 153081,
464
+ "<|action_token_1408|>": 153082,
465
+ "<|action_token_1409|>": 153083,
466
+ "<|action_token_140|>": 151814,
467
+ "<|action_token_1410|>": 153084,
468
+ "<|action_token_1411|>": 153085,
469
+ "<|action_token_1412|>": 153086,
470
+ "<|action_token_1413|>": 153087,
471
+ "<|action_token_1414|>": 153088,
472
+ "<|action_token_1415|>": 153089,
473
+ "<|action_token_1416|>": 153090,
474
+ "<|action_token_1417|>": 153091,
475
+ "<|action_token_1418|>": 153092,
476
+ "<|action_token_1419|>": 153093,
477
+ "<|action_token_141|>": 151815,
478
+ "<|action_token_1420|>": 153094,
479
+ "<|action_token_1421|>": 153095,
480
+ "<|action_token_1422|>": 153096,
481
+ "<|action_token_1423|>": 153097,
482
+ "<|action_token_1424|>": 153098,
483
+ "<|action_token_1425|>": 153099,
484
+ "<|action_token_1426|>": 153100,
485
+ "<|action_token_1427|>": 153101,
486
+ "<|action_token_1428|>": 153102,
487
+ "<|action_token_1429|>": 153103,
488
+ "<|action_token_142|>": 151816,
489
+ "<|action_token_1430|>": 153104,
490
+ "<|action_token_1431|>": 153105,
491
+ "<|action_token_1432|>": 153106,
492
+ "<|action_token_1433|>": 153107,
493
+ "<|action_token_1434|>": 153108,
494
+ "<|action_token_1435|>": 153109,
495
+ "<|action_token_1436|>": 153110,
496
+ "<|action_token_1437|>": 153111,
497
+ "<|action_token_1438|>": 153112,
498
+ "<|action_token_1439|>": 153113,
499
+ "<|action_token_143|>": 151817,
500
+ "<|action_token_1440|>": 153114,
501
+ "<|action_token_1441|>": 153115,
502
+ "<|action_token_1442|>": 153116,
503
+ "<|action_token_1443|>": 153117,
504
+ "<|action_token_1444|>": 153118,
505
+ "<|action_token_1445|>": 153119,
506
+ "<|action_token_1446|>": 153120,
507
+ "<|action_token_1447|>": 153121,
508
+ "<|action_token_1448|>": 153122,
509
+ "<|action_token_1449|>": 153123,
510
+ "<|action_token_144|>": 151818,
511
+ "<|action_token_1450|>": 153124,
512
+ "<|action_token_1451|>": 153125,
513
+ "<|action_token_1452|>": 153126,
514
+ "<|action_token_1453|>": 153127,
515
+ "<|action_token_1454|>": 153128,
516
+ "<|action_token_1455|>": 153129,
517
+ "<|action_token_1456|>": 153130,
518
+ "<|action_token_1457|>": 153131,
519
+ "<|action_token_1458|>": 153132,
520
+ "<|action_token_1459|>": 153133,
521
+ "<|action_token_145|>": 151819,
522
+ "<|action_token_1460|>": 153134,
523
+ "<|action_token_1461|>": 153135,
524
+ "<|action_token_1462|>": 153136,
525
+ "<|action_token_1463|>": 153137,
526
+ "<|action_token_1464|>": 153138,
527
+ "<|action_token_1465|>": 153139,
528
+ "<|action_token_1466|>": 153140,
529
+ "<|action_token_1467|>": 153141,
530
+ "<|action_token_1468|>": 153142,
531
+ "<|action_token_1469|>": 153143,
532
+ "<|action_token_146|>": 151820,
533
+ "<|action_token_1470|>": 153144,
534
+ "<|action_token_1471|>": 153145,
535
+ "<|action_token_1472|>": 153146,
536
+ "<|action_token_1473|>": 153147,
537
+ "<|action_token_1474|>": 153148,
538
+ "<|action_token_1475|>": 153149,
539
+ "<|action_token_1476|>": 153150,
540
+ "<|action_token_1477|>": 153151,
541
+ "<|action_token_1478|>": 153152,
542
+ "<|action_token_1479|>": 153153,
543
+ "<|action_token_147|>": 151821,
544
+ "<|action_token_1480|>": 153154,
545
+ "<|action_token_1481|>": 153155,
546
+ "<|action_token_1482|>": 153156,
547
+ "<|action_token_1483|>": 153157,
548
+ "<|action_token_1484|>": 153158,
549
+ "<|action_token_1485|>": 153159,
550
+ "<|action_token_1486|>": 153160,
551
+ "<|action_token_1487|>": 153161,
552
+ "<|action_token_1488|>": 153162,
553
+ "<|action_token_1489|>": 153163,
554
+ "<|action_token_148|>": 151822,
555
+ "<|action_token_1490|>": 153164,
556
+ "<|action_token_1491|>": 153165,
557
+ "<|action_token_1492|>": 153166,
558
+ "<|action_token_1493|>": 153167,
559
+ "<|action_token_1494|>": 153168,
560
+ "<|action_token_1495|>": 153169,
561
+ "<|action_token_1496|>": 153170,
562
+ "<|action_token_1497|>": 153171,
563
+ "<|action_token_1498|>": 153172,
564
+ "<|action_token_1499|>": 153173,
565
+ "<|action_token_149|>": 151823,
566
+ "<|action_token_14|>": 151688,
567
+ "<|action_token_1500|>": 153174,
568
+ "<|action_token_1501|>": 153175,
569
+ "<|action_token_1502|>": 153176,
570
+ "<|action_token_1503|>": 153177,
571
+ "<|action_token_1504|>": 153178,
572
+ "<|action_token_1505|>": 153179,
573
+ "<|action_token_1506|>": 153180,
574
+ "<|action_token_1507|>": 153181,
575
+ "<|action_token_1508|>": 153182,
576
+ "<|action_token_1509|>": 153183,
577
+ "<|action_token_150|>": 151824,
578
+ "<|action_token_1510|>": 153184,
579
+ "<|action_token_1511|>": 153185,
580
+ "<|action_token_1512|>": 153186,
581
+ "<|action_token_1513|>": 153187,
582
+ "<|action_token_1514|>": 153188,
583
+ "<|action_token_1515|>": 153189,
584
+ "<|action_token_1516|>": 153190,
585
+ "<|action_token_1517|>": 153191,
586
+ "<|action_token_1518|>": 153192,
587
+ "<|action_token_1519|>": 153193,
588
+ "<|action_token_151|>": 151825,
589
+ "<|action_token_1520|>": 153194,
590
+ "<|action_token_1521|>": 153195,
591
+ "<|action_token_1522|>": 153196,
592
+ "<|action_token_1523|>": 153197,
593
+ "<|action_token_1524|>": 153198,
594
+ "<|action_token_1525|>": 153199,
595
+ "<|action_token_1526|>": 153200,
596
+ "<|action_token_1527|>": 153201,
597
+ "<|action_token_1528|>": 153202,
598
+ "<|action_token_1529|>": 153203,
599
+ "<|action_token_152|>": 151826,
600
+ "<|action_token_1530|>": 153204,
601
+ "<|action_token_1531|>": 153205,
602
+ "<|action_token_1532|>": 153206,
603
+ "<|action_token_1533|>": 153207,
604
+ "<|action_token_1534|>": 153208,
605
+ "<|action_token_1535|>": 153209,
606
+ "<|action_token_1536|>": 153210,
607
+ "<|action_token_1537|>": 153211,
608
+ "<|action_token_1538|>": 153212,
609
+ "<|action_token_1539|>": 153213,
610
+ "<|action_token_153|>": 151827,
611
+ "<|action_token_1540|>": 153214,
612
+ "<|action_token_1541|>": 153215,
613
+ "<|action_token_1542|>": 153216,
614
+ "<|action_token_1543|>": 153217,
615
+ "<|action_token_1544|>": 153218,
616
+ "<|action_token_1545|>": 153219,
617
+ "<|action_token_1546|>": 153220,
618
+ "<|action_token_1547|>": 153221,
619
+ "<|action_token_1548|>": 153222,
620
+ "<|action_token_1549|>": 153223,
621
+ "<|action_token_154|>": 151828,
622
+ "<|action_token_1550|>": 153224,
623
+ "<|action_token_1551|>": 153225,
624
+ "<|action_token_1552|>": 153226,
625
+ "<|action_token_1553|>": 153227,
626
+ "<|action_token_1554|>": 153228,
627
+ "<|action_token_1555|>": 153229,
628
+ "<|action_token_1556|>": 153230,
629
+ "<|action_token_1557|>": 153231,
630
+ "<|action_token_1558|>": 153232,
631
+ "<|action_token_1559|>": 153233,
632
+ "<|action_token_155|>": 151829,
633
+ "<|action_token_1560|>": 153234,
634
+ "<|action_token_1561|>": 153235,
635
+ "<|action_token_1562|>": 153236,
636
+ "<|action_token_1563|>": 153237,
637
+ "<|action_token_1564|>": 153238,
638
+ "<|action_token_1565|>": 153239,
639
+ "<|action_token_1566|>": 153240,
640
+ "<|action_token_1567|>": 153241,
641
+ "<|action_token_1568|>": 153242,
642
+ "<|action_token_1569|>": 153243,
643
+ "<|action_token_156|>": 151830,
644
+ "<|action_token_1570|>": 153244,
645
+ "<|action_token_1571|>": 153245,
646
+ "<|action_token_1572|>": 153246,
647
+ "<|action_token_1573|>": 153247,
648
+ "<|action_token_1574|>": 153248,
649
+ "<|action_token_1575|>": 153249,
650
+ "<|action_token_1576|>": 153250,
651
+ "<|action_token_1577|>": 153251,
652
+ "<|action_token_1578|>": 153252,
653
+ "<|action_token_1579|>": 153253,
654
+ "<|action_token_157|>": 151831,
655
+ "<|action_token_1580|>": 153254,
656
+ "<|action_token_1581|>": 153255,
657
+ "<|action_token_1582|>": 153256,
658
+ "<|action_token_1583|>": 153257,
659
+ "<|action_token_1584|>": 153258,
660
+ "<|action_token_1585|>": 153259,
661
+ "<|action_token_1586|>": 153260,
662
+ "<|action_token_1587|>": 153261,
663
+ "<|action_token_1588|>": 153262,
664
+ "<|action_token_1589|>": 153263,
665
+ "<|action_token_158|>": 151832,
666
+ "<|action_token_1590|>": 153264,
667
+ "<|action_token_1591|>": 153265,
668
+ "<|action_token_1592|>": 153266,
669
+ "<|action_token_1593|>": 153267,
670
+ "<|action_token_1594|>": 153268,
671
+ "<|action_token_1595|>": 153269,
672
+ "<|action_token_1596|>": 153270,
673
+ "<|action_token_1597|>": 153271,
674
+ "<|action_token_1598|>": 153272,
675
+ "<|action_token_1599|>": 153273,
676
+ "<|action_token_159|>": 151833,
677
+ "<|action_token_15|>": 151689,
678
+ "<|action_token_1600|>": 153274,
679
+ "<|action_token_1601|>": 153275,
680
+ "<|action_token_1602|>": 153276,
681
+ "<|action_token_1603|>": 153277,
682
+ "<|action_token_1604|>": 153278,
683
+ "<|action_token_1605|>": 153279,
684
+ "<|action_token_1606|>": 153280,
685
+ "<|action_token_1607|>": 153281,
686
+ "<|action_token_1608|>": 153282,
687
+ "<|action_token_1609|>": 153283,
688
+ "<|action_token_160|>": 151834,
689
+ "<|action_token_1610|>": 153284,
690
+ "<|action_token_1611|>": 153285,
691
+ "<|action_token_1612|>": 153286,
692
+ "<|action_token_1613|>": 153287,
693
+ "<|action_token_1614|>": 153288,
694
+ "<|action_token_1615|>": 153289,
695
+ "<|action_token_1616|>": 153290,
696
+ "<|action_token_1617|>": 153291,
697
+ "<|action_token_1618|>": 153292,
698
+ "<|action_token_1619|>": 153293,
699
+ "<|action_token_161|>": 151835,
700
+ "<|action_token_1620|>": 153294,
701
+ "<|action_token_1621|>": 153295,
702
+ "<|action_token_1622|>": 153296,
703
+ "<|action_token_1623|>": 153297,
704
+ "<|action_token_1624|>": 153298,
705
+ "<|action_token_1625|>": 153299,
706
+ "<|action_token_1626|>": 153300,
707
+ "<|action_token_1627|>": 153301,
708
+ "<|action_token_1628|>": 153302,
709
+ "<|action_token_1629|>": 153303,
710
+ "<|action_token_162|>": 151836,
711
+ "<|action_token_1630|>": 153304,
712
+ "<|action_token_1631|>": 153305,
713
+ "<|action_token_1632|>": 153306,
714
+ "<|action_token_1633|>": 153307,
715
+ "<|action_token_1634|>": 153308,
716
+ "<|action_token_1635|>": 153309,
717
+ "<|action_token_1636|>": 153310,
718
+ "<|action_token_1637|>": 153311,
719
+ "<|action_token_1638|>": 153312,
720
+ "<|action_token_1639|>": 153313,
721
+ "<|action_token_163|>": 151837,
722
+ "<|action_token_1640|>": 153314,
723
+ "<|action_token_1641|>": 153315,
724
+ "<|action_token_1642|>": 153316,
725
+ "<|action_token_1643|>": 153317,
726
+ "<|action_token_1644|>": 153318,
727
+ "<|action_token_1645|>": 153319,
728
+ "<|action_token_1646|>": 153320,
729
+ "<|action_token_1647|>": 153321,
730
+ "<|action_token_1648|>": 153322,
731
+ "<|action_token_1649|>": 153323,
732
+ "<|action_token_164|>": 151838,
733
+ "<|action_token_1650|>": 153324,
734
+ "<|action_token_1651|>": 153325,
735
+ "<|action_token_1652|>": 153326,
736
+ "<|action_token_1653|>": 153327,
737
+ "<|action_token_1654|>": 153328,
738
+ "<|action_token_1655|>": 153329,
739
+ "<|action_token_1656|>": 153330,
740
+ "<|action_token_1657|>": 153331,
741
+ "<|action_token_1658|>": 153332,
742
+ "<|action_token_1659|>": 153333,
743
+ "<|action_token_165|>": 151839,
744
+ "<|action_token_1660|>": 153334,
745
+ "<|action_token_1661|>": 153335,
746
+ "<|action_token_1662|>": 153336,
747
+ "<|action_token_1663|>": 153337,
748
+ "<|action_token_1664|>": 153338,
749
+ "<|action_token_1665|>": 153339,
750
+ "<|action_token_1666|>": 153340,
751
+ "<|action_token_1667|>": 153341,
752
+ "<|action_token_1668|>": 153342,
753
+ "<|action_token_1669|>": 153343,
754
+ "<|action_token_166|>": 151840,
755
+ "<|action_token_1670|>": 153344,
756
+ "<|action_token_1671|>": 153345,
757
+ "<|action_token_1672|>": 153346,
758
+ "<|action_token_1673|>": 153347,
759
+ "<|action_token_1674|>": 153348,
760
+ "<|action_token_1675|>": 153349,
761
+ "<|action_token_1676|>": 153350,
762
+ "<|action_token_1677|>": 153351,
763
+ "<|action_token_1678|>": 153352,
764
+ "<|action_token_1679|>": 153353,
765
+ "<|action_token_167|>": 151841,
766
+ "<|action_token_1680|>": 153354,
767
+ "<|action_token_1681|>": 153355,
768
+ "<|action_token_1682|>": 153356,
769
+ "<|action_token_1683|>": 153357,
770
+ "<|action_token_1684|>": 153358,
771
+ "<|action_token_1685|>": 153359,
772
+ "<|action_token_1686|>": 153360,
773
+ "<|action_token_1687|>": 153361,
774
+ "<|action_token_1688|>": 153362,
775
+ "<|action_token_1689|>": 153363,
776
+ "<|action_token_168|>": 151842,
777
+ "<|action_token_1690|>": 153364,
778
+ "<|action_token_1691|>": 153365,
779
+ "<|action_token_1692|>": 153366,
780
+ "<|action_token_1693|>": 153367,
781
+ "<|action_token_1694|>": 153368,
782
+ "<|action_token_1695|>": 153369,
783
+ "<|action_token_1696|>": 153370,
784
+ "<|action_token_1697|>": 153371,
785
+ "<|action_token_1698|>": 153372,
786
+ "<|action_token_1699|>": 153373,
787
+ "<|action_token_169|>": 151843,
788
+ "<|action_token_16|>": 151690,
789
+ "<|action_token_1700|>": 153374,
790
+ "<|action_token_1701|>": 153375,
791
+ "<|action_token_1702|>": 153376,
792
+ "<|action_token_1703|>": 153377,
793
+ "<|action_token_1704|>": 153378,
794
+ "<|action_token_1705|>": 153379,
795
+ "<|action_token_1706|>": 153380,
796
+ "<|action_token_1707|>": 153381,
797
+ "<|action_token_1708|>": 153382,
798
+ "<|action_token_1709|>": 153383,
799
+ "<|action_token_170|>": 151844,
800
+ "<|action_token_1710|>": 153384,
801
+ "<|action_token_1711|>": 153385,
802
+ "<|action_token_1712|>": 153386,
803
+ "<|action_token_1713|>": 153387,
804
+ "<|action_token_1714|>": 153388,
805
+ "<|action_token_1715|>": 153389,
806
+ "<|action_token_1716|>": 153390,
807
+ "<|action_token_1717|>": 153391,
808
+ "<|action_token_1718|>": 153392,
809
+ "<|action_token_1719|>": 153393,
810
+ "<|action_token_171|>": 151845,
811
+ "<|action_token_1720|>": 153394,
812
+ "<|action_token_1721|>": 153395,
813
+ "<|action_token_1722|>": 153396,
814
+ "<|action_token_1723|>": 153397,
815
+ "<|action_token_1724|>": 153398,
816
+ "<|action_token_1725|>": 153399,
817
+ "<|action_token_1726|>": 153400,
818
+ "<|action_token_1727|>": 153401,
819
+ "<|action_token_1728|>": 153402,
820
+ "<|action_token_1729|>": 153403,
821
+ "<|action_token_172|>": 151846,
822
+ "<|action_token_1730|>": 153404,
823
+ "<|action_token_1731|>": 153405,
824
+ "<|action_token_1732|>": 153406,
825
+ "<|action_token_1733|>": 153407,
826
+ "<|action_token_1734|>": 153408,
827
+ "<|action_token_1735|>": 153409,
828
+ "<|action_token_1736|>": 153410,
829
+ "<|action_token_1737|>": 153411,
830
+ "<|action_token_1738|>": 153412,
831
+ "<|action_token_1739|>": 153413,
832
+ "<|action_token_173|>": 151847,
833
+ "<|action_token_1740|>": 153414,
834
+ "<|action_token_1741|>": 153415,
835
+ "<|action_token_1742|>": 153416,
836
+ "<|action_token_1743|>": 153417,
837
+ "<|action_token_1744|>": 153418,
838
+ "<|action_token_1745|>": 153419,
839
+ "<|action_token_1746|>": 153420,
840
+ "<|action_token_1747|>": 153421,
841
+ "<|action_token_1748|>": 153422,
842
+ "<|action_token_1749|>": 153423,
843
+ "<|action_token_174|>": 151848,
844
+ "<|action_token_1750|>": 153424,
845
+ "<|action_token_1751|>": 153425,
846
+ "<|action_token_1752|>": 153426,
847
+ "<|action_token_1753|>": 153427,
848
+ "<|action_token_1754|>": 153428,
849
+ "<|action_token_1755|>": 153429,
850
+ "<|action_token_1756|>": 153430,
851
+ "<|action_token_1757|>": 153431,
852
+ "<|action_token_1758|>": 153432,
853
+ "<|action_token_1759|>": 153433,
854
+ "<|action_token_175|>": 151849,
855
+ "<|action_token_1760|>": 153434,
856
+ "<|action_token_1761|>": 153435,
857
+ "<|action_token_1762|>": 153436,
858
+ "<|action_token_1763|>": 153437,
859
+ "<|action_token_1764|>": 153438,
860
+ "<|action_token_1765|>": 153439,
861
+ "<|action_token_1766|>": 153440,
862
+ "<|action_token_1767|>": 153441,
863
+ "<|action_token_1768|>": 153442,
864
+ "<|action_token_1769|>": 153443,
865
+ "<|action_token_176|>": 151850,
866
+ "<|action_token_1770|>": 153444,
867
+ "<|action_token_1771|>": 153445,
868
+ "<|action_token_1772|>": 153446,
869
+ "<|action_token_1773|>": 153447,
870
+ "<|action_token_1774|>": 153448,
871
+ "<|action_token_1775|>": 153449,
872
+ "<|action_token_1776|>": 153450,
873
+ "<|action_token_1777|>": 153451,
874
+ "<|action_token_1778|>": 153452,
875
+ "<|action_token_1779|>": 153453,
876
+ "<|action_token_177|>": 151851,
877
+ "<|action_token_1780|>": 153454,
878
+ "<|action_token_1781|>": 153455,
879
+ "<|action_token_1782|>": 153456,
880
+ "<|action_token_1783|>": 153457,
881
+ "<|action_token_1784|>": 153458,
882
+ "<|action_token_1785|>": 153459,
883
+ "<|action_token_1786|>": 153460,
884
+ "<|action_token_1787|>": 153461,
885
+ "<|action_token_1788|>": 153462,
886
+ "<|action_token_1789|>": 153463,
887
+ "<|action_token_178|>": 151852,
888
+ "<|action_token_1790|>": 153464,
889
+ "<|action_token_1791|>": 153465,
890
+ "<|action_token_1792|>": 153466,
891
+ "<|action_token_1793|>": 153467,
892
+ "<|action_token_1794|>": 153468,
893
+ "<|action_token_1795|>": 153469,
894
+ "<|action_token_1796|>": 153470,
895
+ "<|action_token_1797|>": 153471,
896
+ "<|action_token_1798|>": 153472,
897
+ "<|action_token_1799|>": 153473,
898
+ "<|action_token_179|>": 151853,
899
+ "<|action_token_17|>": 151691,
900
+ "<|action_token_1800|>": 153474,
901
+ "<|action_token_1801|>": 153475,
902
+ "<|action_token_1802|>": 153476,
903
+ "<|action_token_1803|>": 153477,
904
+ "<|action_token_1804|>": 153478,
905
+ "<|action_token_1805|>": 153479,
906
+ "<|action_token_1806|>": 153480,
907
+ "<|action_token_1807|>": 153481,
908
+ "<|action_token_1808|>": 153482,
909
+ "<|action_token_1809|>": 153483,
910
+ "<|action_token_180|>": 151854,
911
+ "<|action_token_1810|>": 153484,
912
+ "<|action_token_1811|>": 153485,
913
+ "<|action_token_1812|>": 153486,
914
+ "<|action_token_1813|>": 153487,
915
+ "<|action_token_1814|>": 153488,
916
+ "<|action_token_1815|>": 153489,
917
+ "<|action_token_1816|>": 153490,
918
+ "<|action_token_1817|>": 153491,
919
+ "<|action_token_1818|>": 153492,
920
+ "<|action_token_1819|>": 153493,
921
+ "<|action_token_181|>": 151855,
922
+ "<|action_token_1820|>": 153494,
923
+ "<|action_token_1821|>": 153495,
924
+ "<|action_token_1822|>": 153496,
925
+ "<|action_token_1823|>": 153497,
926
+ "<|action_token_1824|>": 153498,
927
+ "<|action_token_1825|>": 153499,
928
+ "<|action_token_1826|>": 153500,
929
+ "<|action_token_1827|>": 153501,
930
+ "<|action_token_1828|>": 153502,
931
+ "<|action_token_1829|>": 153503,
932
+ "<|action_token_182|>": 151856,
933
+ "<|action_token_1830|>": 153504,
934
+ "<|action_token_1831|>": 153505,
935
+ "<|action_token_1832|>": 153506,
936
+ "<|action_token_1833|>": 153507,
937
+ "<|action_token_1834|>": 153508,
938
+ "<|action_token_1835|>": 153509,
939
+ "<|action_token_1836|>": 153510,
940
+ "<|action_token_1837|>": 153511,
941
+ "<|action_token_1838|>": 153512,
942
+ "<|action_token_1839|>": 153513,
943
+ "<|action_token_183|>": 151857,
944
+ "<|action_token_1840|>": 153514,
945
+ "<|action_token_1841|>": 153515,
946
+ "<|action_token_1842|>": 153516,
947
+ "<|action_token_1843|>": 153517,
948
+ "<|action_token_1844|>": 153518,
949
+ "<|action_token_1845|>": 153519,
950
+ "<|action_token_1846|>": 153520,
951
+ "<|action_token_1847|>": 153521,
952
+ "<|action_token_1848|>": 153522,
953
+ "<|action_token_1849|>": 153523,
954
+ "<|action_token_184|>": 151858,
955
+ "<|action_token_1850|>": 153524,
956
+ "<|action_token_1851|>": 153525,
957
+ "<|action_token_1852|>": 153526,
958
+ "<|action_token_1853|>": 153527,
959
+ "<|action_token_1854|>": 153528,
960
+ "<|action_token_1855|>": 153529,
961
+ "<|action_token_1856|>": 153530,
962
+ "<|action_token_1857|>": 153531,
963
+ "<|action_token_1858|>": 153532,
964
+ "<|action_token_1859|>": 153533,
965
+ "<|action_token_185|>": 151859,
966
+ "<|action_token_1860|>": 153534,
967
+ "<|action_token_1861|>": 153535,
968
+ "<|action_token_1862|>": 153536,
969
+ "<|action_token_1863|>": 153537,
970
+ "<|action_token_1864|>": 153538,
971
+ "<|action_token_1865|>": 153539,
972
+ "<|action_token_1866|>": 153540,
973
+ "<|action_token_1867|>": 153541,
974
+ "<|action_token_1868|>": 153542,
975
+ "<|action_token_1869|>": 153543,
976
+ "<|action_token_186|>": 151860,
977
+ "<|action_token_1870|>": 153544,
978
+ "<|action_token_1871|>": 153545,
979
+ "<|action_token_1872|>": 153546,
980
+ "<|action_token_1873|>": 153547,
981
+ "<|action_token_1874|>": 153548,
982
+ "<|action_token_1875|>": 153549,
983
+ "<|action_token_1876|>": 153550,
984
+ "<|action_token_1877|>": 153551,
985
+ "<|action_token_1878|>": 153552,
986
+ "<|action_token_1879|>": 153553,
987
+ "<|action_token_187|>": 151861,
988
+ "<|action_token_1880|>": 153554,
989
+ "<|action_token_1881|>": 153555,
990
+ "<|action_token_1882|>": 153556,
991
+ "<|action_token_1883|>": 153557,
992
+ "<|action_token_1884|>": 153558,
993
+ "<|action_token_1885|>": 153559,
994
+ "<|action_token_1886|>": 153560,
995
+ "<|action_token_1887|>": 153561,
996
+ "<|action_token_1888|>": 153562,
997
+ "<|action_token_1889|>": 153563,
998
+ "<|action_token_188|>": 151862,
999
+ "<|action_token_1890|>": 153564,
1000
+ "<|action_token_1891|>": 153565,
1001
+ "<|action_token_1892|>": 153566,
1002
+ "<|action_token_1893|>": 153567,
1003
+ "<|action_token_1894|>": 153568,
1004
+ "<|action_token_1895|>": 153569,
1005
+ "<|action_token_1896|>": 153570,
1006
+ "<|action_token_1897|>": 153571,
1007
+ "<|action_token_1898|>": 153572,
1008
+ "<|action_token_1899|>": 153573,
1009
+ "<|action_token_189|>": 151863,
1010
+ "<|action_token_18|>": 151692,
1011
+ "<|action_token_1900|>": 153574,
1012
+ "<|action_token_1901|>": 153575,
1013
+ "<|action_token_1902|>": 153576,
1014
+ "<|action_token_1903|>": 153577,
1015
+ "<|action_token_1904|>": 153578,
1016
+ "<|action_token_1905|>": 153579,
1017
+ "<|action_token_1906|>": 153580,
1018
+ "<|action_token_1907|>": 153581,
1019
+ "<|action_token_1908|>": 153582,
1020
+ "<|action_token_1909|>": 153583,
1021
+ "<|action_token_190|>": 151864,
1022
+ "<|action_token_1910|>": 153584,
1023
+ "<|action_token_1911|>": 153585,
1024
+ "<|action_token_1912|>": 153586,
1025
+ "<|action_token_1913|>": 153587,
1026
+ "<|action_token_1914|>": 153588,
1027
+ "<|action_token_1915|>": 153589,
1028
+ "<|action_token_1916|>": 153590,
1029
+ "<|action_token_1917|>": 153591,
1030
+ "<|action_token_1918|>": 153592,
1031
+ "<|action_token_1919|>": 153593,
1032
+ "<|action_token_191|>": 151865,
1033
+ "<|action_token_1920|>": 153594,
1034
+ "<|action_token_1921|>": 153595,
1035
+ "<|action_token_1922|>": 153596,
1036
+ "<|action_token_1923|>": 153597,
1037
+ "<|action_token_1924|>": 153598,
1038
+ "<|action_token_1925|>": 153599,
1039
+ "<|action_token_1926|>": 153600,
1040
+ "<|action_token_1927|>": 153601,
1041
+ "<|action_token_1928|>": 153602,
1042
+ "<|action_token_1929|>": 153603,
1043
+ "<|action_token_192|>": 151866,
1044
+ "<|action_token_1930|>": 153604,
1045
+ "<|action_token_1931|>": 153605,
1046
+ "<|action_token_1932|>": 153606,
1047
+ "<|action_token_1933|>": 153607,
1048
+ "<|action_token_1934|>": 153608,
1049
+ "<|action_token_1935|>": 153609,
1050
+ "<|action_token_1936|>": 153610,
1051
+ "<|action_token_1937|>": 153611,
1052
+ "<|action_token_1938|>": 153612,
1053
+ "<|action_token_1939|>": 153613,
1054
+ "<|action_token_193|>": 151867,
1055
+ "<|action_token_1940|>": 153614,
1056
+ "<|action_token_1941|>": 153615,
1057
+ "<|action_token_1942|>": 153616,
1058
+ "<|action_token_1943|>": 153617,
1059
+ "<|action_token_1944|>": 153618,
1060
+ "<|action_token_1945|>": 153619,
1061
+ "<|action_token_1946|>": 153620,
1062
+ "<|action_token_1947|>": 153621,
1063
+ "<|action_token_1948|>": 153622,
1064
+ "<|action_token_1949|>": 153623,
1065
+ "<|action_token_194|>": 151868,
1066
+ "<|action_token_1950|>": 153624,
1067
+ "<|action_token_1951|>": 153625,
1068
+ "<|action_token_1952|>": 153626,
1069
+ "<|action_token_1953|>": 153627,
1070
+ "<|action_token_1954|>": 153628,
1071
+ "<|action_token_1955|>": 153629,
1072
+ "<|action_token_1956|>": 153630,
1073
+ "<|action_token_1957|>": 153631,
1074
+ "<|action_token_1958|>": 153632,
1075
+ "<|action_token_1959|>": 153633,
1076
+ "<|action_token_195|>": 151869,
1077
+ "<|action_token_1960|>": 153634,
1078
+ "<|action_token_1961|>": 153635,
1079
+ "<|action_token_1962|>": 153636,
1080
+ "<|action_token_1963|>": 153637,
1081
+ "<|action_token_1964|>": 153638,
1082
+ "<|action_token_1965|>": 153639,
1083
+ "<|action_token_1966|>": 153640,
1084
+ "<|action_token_1967|>": 153641,
1085
+ "<|action_token_1968|>": 153642,
1086
+ "<|action_token_1969|>": 153643,
1087
+ "<|action_token_196|>": 151870,
1088
+ "<|action_token_1970|>": 153644,
1089
+ "<|action_token_1971|>": 153645,
1090
+ "<|action_token_1972|>": 153646,
1091
+ "<|action_token_1973|>": 153647,
1092
+ "<|action_token_1974|>": 153648,
1093
+ "<|action_token_1975|>": 153649,
1094
+ "<|action_token_1976|>": 153650,
1095
+ "<|action_token_1977|>": 153651,
1096
+ "<|action_token_1978|>": 153652,
1097
+ "<|action_token_1979|>": 153653,
1098
+ "<|action_token_197|>": 151871,
1099
+ "<|action_token_1980|>": 153654,
1100
+ "<|action_token_1981|>": 153655,
1101
+ "<|action_token_1982|>": 153656,
1102
+ "<|action_token_1983|>": 153657,
1103
+ "<|action_token_1984|>": 153658,
1104
+ "<|action_token_1985|>": 153659,
1105
+ "<|action_token_1986|>": 153660,
1106
+ "<|action_token_1987|>": 153661,
1107
+ "<|action_token_1988|>": 153662,
1108
+ "<|action_token_1989|>": 153663,
1109
+ "<|action_token_198|>": 151872,
1110
+ "<|action_token_1990|>": 153664,
1111
+ "<|action_token_1991|>": 153665,
1112
+ "<|action_token_1992|>": 153666,
1113
+ "<|action_token_1993|>": 153667,
1114
+ "<|action_token_1994|>": 153668,
1115
+ "<|action_token_1995|>": 153669,
1116
+ "<|action_token_1996|>": 153670,
1117
+ "<|action_token_1997|>": 153671,
1118
+ "<|action_token_1998|>": 153672,
1119
+ "<|action_token_1999|>": 153673,
1120
+ "<|action_token_199|>": 151873,
1121
+ "<|action_token_19|>": 151693,
1122
+ "<|action_token_1|>": 151675,
1123
+ "<|action_token_2000|>": 153674,
1124
+ "<|action_token_2001|>": 153675,
1125
+ "<|action_token_2002|>": 153676,
1126
+ "<|action_token_2003|>": 153677,
1127
+ "<|action_token_2004|>": 153678,
1128
+ "<|action_token_2005|>": 153679,
1129
+ "<|action_token_2006|>": 153680,
1130
+ "<|action_token_2007|>": 153681,
1131
+ "<|action_token_2008|>": 153682,
1132
+ "<|action_token_2009|>": 153683,
1133
+ "<|action_token_200|>": 151874,
1134
+ "<|action_token_2010|>": 153684,
1135
+ "<|action_token_2011|>": 153685,
1136
+ "<|action_token_2012|>": 153686,
1137
+ "<|action_token_2013|>": 153687,
1138
+ "<|action_token_2014|>": 153688,
1139
+ "<|action_token_2015|>": 153689,
1140
+ "<|action_token_2016|>": 153690,
1141
+ "<|action_token_2017|>": 153691,
1142
+ "<|action_token_2018|>": 153692,
1143
+ "<|action_token_2019|>": 153693,
1144
+ "<|action_token_201|>": 151875,
1145
+ "<|action_token_2020|>": 153694,
1146
+ "<|action_token_2021|>": 153695,
1147
+ "<|action_token_2022|>": 153696,
1148
+ "<|action_token_2023|>": 153697,
1149
+ "<|action_token_2024|>": 153698,
1150
+ "<|action_token_2025|>": 153699,
1151
+ "<|action_token_2026|>": 153700,
1152
+ "<|action_token_2027|>": 153701,
1153
+ "<|action_token_2028|>": 153702,
1154
+ "<|action_token_2029|>": 153703,
1155
+ "<|action_token_202|>": 151876,
1156
+ "<|action_token_2030|>": 153704,
1157
+ "<|action_token_2031|>": 153705,
1158
+ "<|action_token_2032|>": 153706,
1159
+ "<|action_token_2033|>": 153707,
1160
+ "<|action_token_2034|>": 153708,
1161
+ "<|action_token_2035|>": 153709,
1162
+ "<|action_token_2036|>": 153710,
1163
+ "<|action_token_2037|>": 153711,
1164
+ "<|action_token_2038|>": 153712,
1165
+ "<|action_token_2039|>": 153713,
1166
+ "<|action_token_203|>": 151877,
1167
+ "<|action_token_2040|>": 153714,
1168
+ "<|action_token_2041|>": 153715,
1169
+ "<|action_token_2042|>": 153716,
1170
+ "<|action_token_2043|>": 153717,
1171
+ "<|action_token_2044|>": 153718,
1172
+ "<|action_token_2045|>": 153719,
1173
+ "<|action_token_2046|>": 153720,
1174
+ "<|action_token_2047|>": 153721,
1175
+ "<|action_token_204|>": 151878,
1176
+ "<|action_token_205|>": 151879,
1177
+ "<|action_token_206|>": 151880,
1178
+ "<|action_token_207|>": 151881,
1179
+ "<|action_token_208|>": 151882,
1180
+ "<|action_token_209|>": 151883,
1181
+ "<|action_token_20|>": 151694,
1182
+ "<|action_token_210|>": 151884,
1183
+ "<|action_token_211|>": 151885,
1184
+ "<|action_token_212|>": 151886,
1185
+ "<|action_token_213|>": 151887,
1186
+ "<|action_token_214|>": 151888,
1187
+ "<|action_token_215|>": 151889,
1188
+ "<|action_token_216|>": 151890,
1189
+ "<|action_token_217|>": 151891,
1190
+ "<|action_token_218|>": 151892,
1191
+ "<|action_token_219|>": 151893,
1192
+ "<|action_token_21|>": 151695,
1193
+ "<|action_token_220|>": 151894,
1194
+ "<|action_token_221|>": 151895,
1195
+ "<|action_token_222|>": 151896,
1196
+ "<|action_token_223|>": 151897,
1197
+ "<|action_token_224|>": 151898,
1198
+ "<|action_token_225|>": 151899,
1199
+ "<|action_token_226|>": 151900,
1200
+ "<|action_token_227|>": 151901,
1201
+ "<|action_token_228|>": 151902,
1202
+ "<|action_token_229|>": 151903,
1203
+ "<|action_token_22|>": 151696,
1204
+ "<|action_token_230|>": 151904,
1205
+ "<|action_token_231|>": 151905,
1206
+ "<|action_token_232|>": 151906,
1207
+ "<|action_token_233|>": 151907,
1208
+ "<|action_token_234|>": 151908,
1209
+ "<|action_token_235|>": 151909,
1210
+ "<|action_token_236|>": 151910,
1211
+ "<|action_token_237|>": 151911,
1212
+ "<|action_token_238|>": 151912,
1213
+ "<|action_token_239|>": 151913,
1214
+ "<|action_token_23|>": 151697,
1215
+ "<|action_token_240|>": 151914,
1216
+ "<|action_token_241|>": 151915,
1217
+ "<|action_token_242|>": 151916,
1218
+ "<|action_token_243|>": 151917,
1219
+ "<|action_token_244|>": 151918,
1220
+ "<|action_token_245|>": 151919,
1221
+ "<|action_token_246|>": 151920,
1222
+ "<|action_token_247|>": 151921,
1223
+ "<|action_token_248|>": 151922,
1224
+ "<|action_token_249|>": 151923,
1225
+ "<|action_token_24|>": 151698,
1226
+ "<|action_token_250|>": 151924,
1227
+ "<|action_token_251|>": 151925,
1228
+ "<|action_token_252|>": 151926,
1229
+ "<|action_token_253|>": 151927,
1230
+ "<|action_token_254|>": 151928,
1231
+ "<|action_token_255|>": 151929,
1232
+ "<|action_token_256|>": 151930,
1233
+ "<|action_token_257|>": 151931,
1234
+ "<|action_token_258|>": 151932,
1235
+ "<|action_token_259|>": 151933,
1236
+ "<|action_token_25|>": 151699,
1237
+ "<|action_token_260|>": 151934,
1238
+ "<|action_token_261|>": 151935,
1239
+ "<|action_token_262|>": 151936,
1240
+ "<|action_token_263|>": 151937,
1241
+ "<|action_token_264|>": 151938,
1242
+ "<|action_token_265|>": 151939,
1243
+ "<|action_token_266|>": 151940,
1244
+ "<|action_token_267|>": 151941,
1245
+ "<|action_token_268|>": 151942,
1246
+ "<|action_token_269|>": 151943,
1247
+ "<|action_token_26|>": 151700,
1248
+ "<|action_token_270|>": 151944,
1249
+ "<|action_token_271|>": 151945,
1250
+ "<|action_token_272|>": 151946,
1251
+ "<|action_token_273|>": 151947,
1252
+ "<|action_token_274|>": 151948,
1253
+ "<|action_token_275|>": 151949,
1254
+ "<|action_token_276|>": 151950,
1255
+ "<|action_token_277|>": 151951,
1256
+ "<|action_token_278|>": 151952,
1257
+ "<|action_token_279|>": 151953,
1258
+ "<|action_token_27|>": 151701,
1259
+ "<|action_token_280|>": 151954,
1260
+ "<|action_token_281|>": 151955,
1261
+ "<|action_token_282|>": 151956,
1262
+ "<|action_token_283|>": 151957,
1263
+ "<|action_token_284|>": 151958,
1264
+ "<|action_token_285|>": 151959,
1265
+ "<|action_token_286|>": 151960,
1266
+ "<|action_token_287|>": 151961,
1267
+ "<|action_token_288|>": 151962,
1268
+ "<|action_token_289|>": 151963,
1269
+ "<|action_token_28|>": 151702,
1270
+ "<|action_token_290|>": 151964,
1271
+ "<|action_token_291|>": 151965,
1272
+ "<|action_token_292|>": 151966,
1273
+ "<|action_token_293|>": 151967,
1274
+ "<|action_token_294|>": 151968,
1275
+ "<|action_token_295|>": 151969,
1276
+ "<|action_token_296|>": 151970,
1277
+ "<|action_token_297|>": 151971,
1278
+ "<|action_token_298|>": 151972,
1279
+ "<|action_token_299|>": 151973,
1280
+ "<|action_token_29|>": 151703,
1281
+ "<|action_token_2|>": 151676,
1282
+ "<|action_token_300|>": 151974,
1283
+ "<|action_token_301|>": 151975,
1284
+ "<|action_token_302|>": 151976,
1285
+ "<|action_token_303|>": 151977,
1286
+ "<|action_token_304|>": 151978,
1287
+ "<|action_token_305|>": 151979,
1288
+ "<|action_token_306|>": 151980,
1289
+ "<|action_token_307|>": 151981,
1290
+ "<|action_token_308|>": 151982,
1291
+ "<|action_token_309|>": 151983,
1292
+ "<|action_token_30|>": 151704,
1293
+ "<|action_token_310|>": 151984,
1294
+ "<|action_token_311|>": 151985,
1295
+ "<|action_token_312|>": 151986,
1296
+ "<|action_token_313|>": 151987,
1297
+ "<|action_token_314|>": 151988,
1298
+ "<|action_token_315|>": 151989,
1299
+ "<|action_token_316|>": 151990,
1300
+ "<|action_token_317|>": 151991,
1301
+ "<|action_token_318|>": 151992,
1302
+ "<|action_token_319|>": 151993,
1303
+ "<|action_token_31|>": 151705,
1304
+ "<|action_token_320|>": 151994,
1305
+ "<|action_token_321|>": 151995,
1306
+ "<|action_token_322|>": 151996,
1307
+ "<|action_token_323|>": 151997,
1308
+ "<|action_token_324|>": 151998,
1309
+ "<|action_token_325|>": 151999,
1310
+ "<|action_token_326|>": 152000,
1311
+ "<|action_token_327|>": 152001,
1312
+ "<|action_token_328|>": 152002,
1313
+ "<|action_token_329|>": 152003,
1314
+ "<|action_token_32|>": 151706,
1315
+ "<|action_token_330|>": 152004,
1316
+ "<|action_token_331|>": 152005,
1317
+ "<|action_token_332|>": 152006,
1318
+ "<|action_token_333|>": 152007,
1319
+ "<|action_token_334|>": 152008,
1320
+ "<|action_token_335|>": 152009,
1321
+ "<|action_token_336|>": 152010,
1322
+ "<|action_token_337|>": 152011,
1323
+ "<|action_token_338|>": 152012,
1324
+ "<|action_token_339|>": 152013,
1325
+ "<|action_token_33|>": 151707,
1326
+ "<|action_token_340|>": 152014,
1327
+ "<|action_token_341|>": 152015,
1328
+ "<|action_token_342|>": 152016,
1329
+ "<|action_token_343|>": 152017,
1330
+ "<|action_token_344|>": 152018,
1331
+ "<|action_token_345|>": 152019,
1332
+ "<|action_token_346|>": 152020,
1333
+ "<|action_token_347|>": 152021,
1334
+ "<|action_token_348|>": 152022,
1335
+ "<|action_token_349|>": 152023,
1336
+ "<|action_token_34|>": 151708,
1337
+ "<|action_token_350|>": 152024,
1338
+ "<|action_token_351|>": 152025,
1339
+ "<|action_token_352|>": 152026,
1340
+ "<|action_token_353|>": 152027,
1341
+ "<|action_token_354|>": 152028,
1342
+ "<|action_token_355|>": 152029,
1343
+ "<|action_token_356|>": 152030,
1344
+ "<|action_token_357|>": 152031,
1345
+ "<|action_token_358|>": 152032,
1346
+ "<|action_token_359|>": 152033,
1347
+ "<|action_token_35|>": 151709,
1348
+ "<|action_token_360|>": 152034,
1349
+ "<|action_token_361|>": 152035,
1350
+ "<|action_token_362|>": 152036,
1351
+ "<|action_token_363|>": 152037,
1352
+ "<|action_token_364|>": 152038,
1353
+ "<|action_token_365|>": 152039,
1354
+ "<|action_token_366|>": 152040,
1355
+ "<|action_token_367|>": 152041,
1356
+ "<|action_token_368|>": 152042,
1357
+ "<|action_token_369|>": 152043,
1358
+ "<|action_token_36|>": 151710,
1359
+ "<|action_token_370|>": 152044,
1360
+ "<|action_token_371|>": 152045,
1361
+ "<|action_token_372|>": 152046,
1362
+ "<|action_token_373|>": 152047,
1363
+ "<|action_token_374|>": 152048,
1364
+ "<|action_token_375|>": 152049,
1365
+ "<|action_token_376|>": 152050,
1366
+ "<|action_token_377|>": 152051,
1367
+ "<|action_token_378|>": 152052,
1368
+ "<|action_token_379|>": 152053,
1369
+ "<|action_token_37|>": 151711,
1370
+ "<|action_token_380|>": 152054,
1371
+ "<|action_token_381|>": 152055,
1372
+ "<|action_token_382|>": 152056,
1373
+ "<|action_token_383|>": 152057,
1374
+ "<|action_token_384|>": 152058,
1375
+ "<|action_token_385|>": 152059,
1376
+ "<|action_token_386|>": 152060,
1377
+ "<|action_token_387|>": 152061,
1378
+ "<|action_token_388|>": 152062,
1379
+ "<|action_token_389|>": 152063,
1380
+ "<|action_token_38|>": 151712,
1381
+ "<|action_token_390|>": 152064,
1382
+ "<|action_token_391|>": 152065,
1383
+ "<|action_token_392|>": 152066,
1384
+ "<|action_token_393|>": 152067,
1385
+ "<|action_token_394|>": 152068,
1386
+ "<|action_token_395|>": 152069,
1387
+ "<|action_token_396|>": 152070,
1388
+ "<|action_token_397|>": 152071,
1389
+ "<|action_token_398|>": 152072,
1390
+ "<|action_token_399|>": 152073,
1391
+ "<|action_token_39|>": 151713,
1392
+ "<|action_token_3|>": 151677,
1393
+ "<|action_token_400|>": 152074,
1394
+ "<|action_token_401|>": 152075,
1395
+ "<|action_token_402|>": 152076,
1396
+ "<|action_token_403|>": 152077,
1397
+ "<|action_token_404|>": 152078,
1398
+ "<|action_token_405|>": 152079,
1399
+ "<|action_token_406|>": 152080,
1400
+ "<|action_token_407|>": 152081,
1401
+ "<|action_token_408|>": 152082,
1402
+ "<|action_token_409|>": 152083,
1403
+ "<|action_token_40|>": 151714,
1404
+ "<|action_token_410|>": 152084,
1405
+ "<|action_token_411|>": 152085,
1406
+ "<|action_token_412|>": 152086,
1407
+ "<|action_token_413|>": 152087,
1408
+ "<|action_token_414|>": 152088,
1409
+ "<|action_token_415|>": 152089,
1410
+ "<|action_token_416|>": 152090,
1411
+ "<|action_token_417|>": 152091,
1412
+ "<|action_token_418|>": 152092,
1413
+ "<|action_token_419|>": 152093,
1414
+ "<|action_token_41|>": 151715,
1415
+ "<|action_token_420|>": 152094,
1416
+ "<|action_token_421|>": 152095,
1417
+ "<|action_token_422|>": 152096,
1418
+ "<|action_token_423|>": 152097,
1419
+ "<|action_token_424|>": 152098,
1420
+ "<|action_token_425|>": 152099,
1421
+ "<|action_token_426|>": 152100,
1422
+ "<|action_token_427|>": 152101,
1423
+ "<|action_token_428|>": 152102,
1424
+ "<|action_token_429|>": 152103,
1425
+ "<|action_token_42|>": 151716,
1426
+ "<|action_token_430|>": 152104,
1427
+ "<|action_token_431|>": 152105,
1428
+ "<|action_token_432|>": 152106,
1429
+ "<|action_token_433|>": 152107,
1430
+ "<|action_token_434|>": 152108,
1431
+ "<|action_token_435|>": 152109,
1432
+ "<|action_token_436|>": 152110,
1433
+ "<|action_token_437|>": 152111,
1434
+ "<|action_token_438|>": 152112,
1435
+ "<|action_token_439|>": 152113,
1436
+ "<|action_token_43|>": 151717,
1437
+ "<|action_token_440|>": 152114,
1438
+ "<|action_token_441|>": 152115,
1439
+ "<|action_token_442|>": 152116,
1440
+ "<|action_token_443|>": 152117,
1441
+ "<|action_token_444|>": 152118,
1442
+ "<|action_token_445|>": 152119,
1443
+ "<|action_token_446|>": 152120,
1444
+ "<|action_token_447|>": 152121,
1445
+ "<|action_token_448|>": 152122,
1446
+ "<|action_token_449|>": 152123,
1447
+ "<|action_token_44|>": 151718,
1448
+ "<|action_token_450|>": 152124,
1449
+ "<|action_token_451|>": 152125,
1450
+ "<|action_token_452|>": 152126,
1451
+ "<|action_token_453|>": 152127,
1452
+ "<|action_token_454|>": 152128,
1453
+ "<|action_token_455|>": 152129,
1454
+ "<|action_token_456|>": 152130,
1455
+ "<|action_token_457|>": 152131,
1456
+ "<|action_token_458|>": 152132,
1457
+ "<|action_token_459|>": 152133,
1458
+ "<|action_token_45|>": 151719,
1459
+ "<|action_token_460|>": 152134,
1460
+ "<|action_token_461|>": 152135,
1461
+ "<|action_token_462|>": 152136,
1462
+ "<|action_token_463|>": 152137,
1463
+ "<|action_token_464|>": 152138,
1464
+ "<|action_token_465|>": 152139,
1465
+ "<|action_token_466|>": 152140,
1466
+ "<|action_token_467|>": 152141,
1467
+ "<|action_token_468|>": 152142,
1468
+ "<|action_token_469|>": 152143,
1469
+ "<|action_token_46|>": 151720,
1470
+ "<|action_token_470|>": 152144,
1471
+ "<|action_token_471|>": 152145,
1472
+ "<|action_token_472|>": 152146,
1473
+ "<|action_token_473|>": 152147,
1474
+ "<|action_token_474|>": 152148,
1475
+ "<|action_token_475|>": 152149,
1476
+ "<|action_token_476|>": 152150,
1477
+ "<|action_token_477|>": 152151,
1478
+ "<|action_token_478|>": 152152,
1479
+ "<|action_token_479|>": 152153,
1480
+ "<|action_token_47|>": 151721,
1481
+ "<|action_token_480|>": 152154,
1482
+ "<|action_token_481|>": 152155,
1483
+ "<|action_token_482|>": 152156,
1484
+ "<|action_token_483|>": 152157,
1485
+ "<|action_token_484|>": 152158,
1486
+ "<|action_token_485|>": 152159,
1487
+ "<|action_token_486|>": 152160,
1488
+ "<|action_token_487|>": 152161,
1489
+ "<|action_token_488|>": 152162,
1490
+ "<|action_token_489|>": 152163,
1491
+ "<|action_token_48|>": 151722,
1492
+ "<|action_token_490|>": 152164,
1493
+ "<|action_token_491|>": 152165,
1494
+ "<|action_token_492|>": 152166,
1495
+ "<|action_token_493|>": 152167,
1496
+ "<|action_token_494|>": 152168,
1497
+ "<|action_token_495|>": 152169,
1498
+ "<|action_token_496|>": 152170,
1499
+ "<|action_token_497|>": 152171,
1500
+ "<|action_token_498|>": 152172,
1501
+ "<|action_token_499|>": 152173,
1502
+ "<|action_token_49|>": 151723,
1503
+ "<|action_token_4|>": 151678,
1504
+ "<|action_token_500|>": 152174,
1505
+ "<|action_token_501|>": 152175,
1506
+ "<|action_token_502|>": 152176,
1507
+ "<|action_token_503|>": 152177,
1508
+ "<|action_token_504|>": 152178,
1509
+ "<|action_token_505|>": 152179,
1510
+ "<|action_token_506|>": 152180,
1511
+ "<|action_token_507|>": 152181,
1512
+ "<|action_token_508|>": 152182,
1513
+ "<|action_token_509|>": 152183,
1514
+ "<|action_token_50|>": 151724,
1515
+ "<|action_token_510|>": 152184,
1516
+ "<|action_token_511|>": 152185,
1517
+ "<|action_token_512|>": 152186,
1518
+ "<|action_token_513|>": 152187,
1519
+ "<|action_token_514|>": 152188,
1520
+ "<|action_token_515|>": 152189,
1521
+ "<|action_token_516|>": 152190,
1522
+ "<|action_token_517|>": 152191,
1523
+ "<|action_token_518|>": 152192,
1524
+ "<|action_token_519|>": 152193,
1525
+ "<|action_token_51|>": 151725,
1526
+ "<|action_token_520|>": 152194,
1527
+ "<|action_token_521|>": 152195,
1528
+ "<|action_token_522|>": 152196,
1529
+ "<|action_token_523|>": 152197,
1530
+ "<|action_token_524|>": 152198,
1531
+ "<|action_token_525|>": 152199,
1532
+ "<|action_token_526|>": 152200,
1533
+ "<|action_token_527|>": 152201,
1534
+ "<|action_token_528|>": 152202,
1535
+ "<|action_token_529|>": 152203,
1536
+ "<|action_token_52|>": 151726,
1537
+ "<|action_token_530|>": 152204,
1538
+ "<|action_token_531|>": 152205,
1539
+ "<|action_token_532|>": 152206,
1540
+ "<|action_token_533|>": 152207,
1541
+ "<|action_token_534|>": 152208,
1542
+ "<|action_token_535|>": 152209,
1543
+ "<|action_token_536|>": 152210,
1544
+ "<|action_token_537|>": 152211,
1545
+ "<|action_token_538|>": 152212,
1546
+ "<|action_token_539|>": 152213,
1547
+ "<|action_token_53|>": 151727,
1548
+ "<|action_token_540|>": 152214,
1549
+ "<|action_token_541|>": 152215,
1550
+ "<|action_token_542|>": 152216,
1551
+ "<|action_token_543|>": 152217,
1552
+ "<|action_token_544|>": 152218,
1553
+ "<|action_token_545|>": 152219,
1554
+ "<|action_token_546|>": 152220,
1555
+ "<|action_token_547|>": 152221,
1556
+ "<|action_token_548|>": 152222,
1557
+ "<|action_token_549|>": 152223,
1558
+ "<|action_token_54|>": 151728,
1559
+ "<|action_token_550|>": 152224,
1560
+ "<|action_token_551|>": 152225,
1561
+ "<|action_token_552|>": 152226,
1562
+ "<|action_token_553|>": 152227,
1563
+ "<|action_token_554|>": 152228,
1564
+ "<|action_token_555|>": 152229,
1565
+ "<|action_token_556|>": 152230,
1566
+ "<|action_token_557|>": 152231,
1567
+ "<|action_token_558|>": 152232,
1568
+ "<|action_token_559|>": 152233,
1569
+ "<|action_token_55|>": 151729,
1570
+ "<|action_token_560|>": 152234,
1571
+ "<|action_token_561|>": 152235,
1572
+ "<|action_token_562|>": 152236,
1573
+ "<|action_token_563|>": 152237,
1574
+ "<|action_token_564|>": 152238,
1575
+ "<|action_token_565|>": 152239,
1576
+ "<|action_token_566|>": 152240,
1577
+ "<|action_token_567|>": 152241,
1578
+ "<|action_token_568|>": 152242,
1579
+ "<|action_token_569|>": 152243,
1580
+ "<|action_token_56|>": 151730,
1581
+ "<|action_token_570|>": 152244,
1582
+ "<|action_token_571|>": 152245,
1583
+ "<|action_token_572|>": 152246,
1584
+ "<|action_token_573|>": 152247,
1585
+ "<|action_token_574|>": 152248,
1586
+ "<|action_token_575|>": 152249,
1587
+ "<|action_token_576|>": 152250,
1588
+ "<|action_token_577|>": 152251,
1589
+ "<|action_token_578|>": 152252,
1590
+ "<|action_token_579|>": 152253,
1591
+ "<|action_token_57|>": 151731,
1592
+ "<|action_token_580|>": 152254,
1593
+ "<|action_token_581|>": 152255,
1594
+ "<|action_token_582|>": 152256,
1595
+ "<|action_token_583|>": 152257,
1596
+ "<|action_token_584|>": 152258,
1597
+ "<|action_token_585|>": 152259,
1598
+ "<|action_token_586|>": 152260,
1599
+ "<|action_token_587|>": 152261,
1600
+ "<|action_token_588|>": 152262,
1601
+ "<|action_token_589|>": 152263,
1602
+ "<|action_token_58|>": 151732,
1603
+ "<|action_token_590|>": 152264,
1604
+ "<|action_token_591|>": 152265,
1605
+ "<|action_token_592|>": 152266,
1606
+ "<|action_token_593|>": 152267,
1607
+ "<|action_token_594|>": 152268,
1608
+ "<|action_token_595|>": 152269,
1609
+ "<|action_token_596|>": 152270,
1610
+ "<|action_token_597|>": 152271,
1611
+ "<|action_token_598|>": 152272,
1612
+ "<|action_token_599|>": 152273,
1613
+ "<|action_token_59|>": 151733,
1614
+ "<|action_token_5|>": 151679,
1615
+ "<|action_token_600|>": 152274,
1616
+ "<|action_token_601|>": 152275,
1617
+ "<|action_token_602|>": 152276,
1618
+ "<|action_token_603|>": 152277,
1619
+ "<|action_token_604|>": 152278,
1620
+ "<|action_token_605|>": 152279,
1621
+ "<|action_token_606|>": 152280,
1622
+ "<|action_token_607|>": 152281,
1623
+ "<|action_token_608|>": 152282,
1624
+ "<|action_token_609|>": 152283,
1625
+ "<|action_token_60|>": 151734,
1626
+ "<|action_token_610|>": 152284,
1627
+ "<|action_token_611|>": 152285,
1628
+ "<|action_token_612|>": 152286,
1629
+ "<|action_token_613|>": 152287,
1630
+ "<|action_token_614|>": 152288,
1631
+ "<|action_token_615|>": 152289,
1632
+ "<|action_token_616|>": 152290,
1633
+ "<|action_token_617|>": 152291,
1634
+ "<|action_token_618|>": 152292,
1635
+ "<|action_token_619|>": 152293,
1636
+ "<|action_token_61|>": 151735,
1637
+ "<|action_token_620|>": 152294,
1638
+ "<|action_token_621|>": 152295,
1639
+ "<|action_token_622|>": 152296,
1640
+ "<|action_token_623|>": 152297,
1641
+ "<|action_token_624|>": 152298,
1642
+ "<|action_token_625|>": 152299,
1643
+ "<|action_token_626|>": 152300,
1644
+ "<|action_token_627|>": 152301,
1645
+ "<|action_token_628|>": 152302,
1646
+ "<|action_token_629|>": 152303,
1647
+ "<|action_token_62|>": 151736,
1648
+ "<|action_token_630|>": 152304,
1649
+ "<|action_token_631|>": 152305,
1650
+ "<|action_token_632|>": 152306,
1651
+ "<|action_token_633|>": 152307,
1652
+ "<|action_token_634|>": 152308,
1653
+ "<|action_token_635|>": 152309,
1654
+ "<|action_token_636|>": 152310,
1655
+ "<|action_token_637|>": 152311,
1656
+ "<|action_token_638|>": 152312,
1657
+ "<|action_token_639|>": 152313,
1658
+ "<|action_token_63|>": 151737,
1659
+ "<|action_token_640|>": 152314,
1660
+ "<|action_token_641|>": 152315,
1661
+ "<|action_token_642|>": 152316,
1662
+ "<|action_token_643|>": 152317,
1663
+ "<|action_token_644|>": 152318,
1664
+ "<|action_token_645|>": 152319,
1665
+ "<|action_token_646|>": 152320,
1666
+ "<|action_token_647|>": 152321,
1667
+ "<|action_token_648|>": 152322,
1668
+ "<|action_token_649|>": 152323,
1669
+ "<|action_token_64|>": 151738,
1670
+ "<|action_token_650|>": 152324,
1671
+ "<|action_token_651|>": 152325,
1672
+ "<|action_token_652|>": 152326,
1673
+ "<|action_token_653|>": 152327,
1674
+ "<|action_token_654|>": 152328,
1675
+ "<|action_token_655|>": 152329,
1676
+ "<|action_token_656|>": 152330,
1677
+ "<|action_token_657|>": 152331,
1678
+ "<|action_token_658|>": 152332,
1679
+ "<|action_token_659|>": 152333,
1680
+ "<|action_token_65|>": 151739,
1681
+ "<|action_token_660|>": 152334,
1682
+ "<|action_token_661|>": 152335,
1683
+ "<|action_token_662|>": 152336,
1684
+ "<|action_token_663|>": 152337,
1685
+ "<|action_token_664|>": 152338,
1686
+ "<|action_token_665|>": 152339,
1687
+ "<|action_token_666|>": 152340,
1688
+ "<|action_token_667|>": 152341,
1689
+ "<|action_token_668|>": 152342,
1690
+ "<|action_token_669|>": 152343,
1691
+ "<|action_token_66|>": 151740,
1692
+ "<|action_token_670|>": 152344,
1693
+ "<|action_token_671|>": 152345,
1694
+ "<|action_token_672|>": 152346,
1695
+ "<|action_token_673|>": 152347,
1696
+ "<|action_token_674|>": 152348,
1697
+ "<|action_token_675|>": 152349,
1698
+ "<|action_token_676|>": 152350,
1699
+ "<|action_token_677|>": 152351,
1700
+ "<|action_token_678|>": 152352,
1701
+ "<|action_token_679|>": 152353,
1702
+ "<|action_token_67|>": 151741,
1703
+ "<|action_token_680|>": 152354,
1704
+ "<|action_token_681|>": 152355,
1705
+ "<|action_token_682|>": 152356,
1706
+ "<|action_token_683|>": 152357,
1707
+ "<|action_token_684|>": 152358,
1708
+ "<|action_token_685|>": 152359,
1709
+ "<|action_token_686|>": 152360,
1710
+ "<|action_token_687|>": 152361,
1711
+ "<|action_token_688|>": 152362,
1712
+ "<|action_token_689|>": 152363,
1713
+ "<|action_token_68|>": 151742,
1714
+ "<|action_token_690|>": 152364,
1715
+ "<|action_token_691|>": 152365,
1716
+ "<|action_token_692|>": 152366,
1717
+ "<|action_token_693|>": 152367,
1718
+ "<|action_token_694|>": 152368,
1719
+ "<|action_token_695|>": 152369,
1720
+ "<|action_token_696|>": 152370,
1721
+ "<|action_token_697|>": 152371,
1722
+ "<|action_token_698|>": 152372,
1723
+ "<|action_token_699|>": 152373,
1724
+ "<|action_token_69|>": 151743,
1725
+ "<|action_token_6|>": 151680,
1726
+ "<|action_token_700|>": 152374,
1727
+ "<|action_token_701|>": 152375,
1728
+ "<|action_token_702|>": 152376,
1729
+ "<|action_token_703|>": 152377,
1730
+ "<|action_token_704|>": 152378,
1731
+ "<|action_token_705|>": 152379,
1732
+ "<|action_token_706|>": 152380,
1733
+ "<|action_token_707|>": 152381,
1734
+ "<|action_token_708|>": 152382,
1735
+ "<|action_token_709|>": 152383,
1736
+ "<|action_token_70|>": 151744,
1737
+ "<|action_token_710|>": 152384,
1738
+ "<|action_token_711|>": 152385,
1739
+ "<|action_token_712|>": 152386,
1740
+ "<|action_token_713|>": 152387,
1741
+ "<|action_token_714|>": 152388,
1742
+ "<|action_token_715|>": 152389,
1743
+ "<|action_token_716|>": 152390,
1744
+ "<|action_token_717|>": 152391,
1745
+ "<|action_token_718|>": 152392,
1746
+ "<|action_token_719|>": 152393,
1747
+ "<|action_token_71|>": 151745,
1748
+ "<|action_token_720|>": 152394,
1749
+ "<|action_token_721|>": 152395,
1750
+ "<|action_token_722|>": 152396,
1751
+ "<|action_token_723|>": 152397,
1752
+ "<|action_token_724|>": 152398,
1753
+ "<|action_token_725|>": 152399,
1754
+ "<|action_token_726|>": 152400,
1755
+ "<|action_token_727|>": 152401,
1756
+ "<|action_token_728|>": 152402,
1757
+ "<|action_token_729|>": 152403,
1758
+ "<|action_token_72|>": 151746,
1759
+ "<|action_token_730|>": 152404,
1760
+ "<|action_token_731|>": 152405,
1761
+ "<|action_token_732|>": 152406,
1762
+ "<|action_token_733|>": 152407,
1763
+ "<|action_token_734|>": 152408,
1764
+ "<|action_token_735|>": 152409,
1765
+ "<|action_token_736|>": 152410,
1766
+ "<|action_token_737|>": 152411,
1767
+ "<|action_token_738|>": 152412,
1768
+ "<|action_token_739|>": 152413,
1769
+ "<|action_token_73|>": 151747,
1770
+ "<|action_token_740|>": 152414,
1771
+ "<|action_token_741|>": 152415,
1772
+ "<|action_token_742|>": 152416,
1773
+ "<|action_token_743|>": 152417,
1774
+ "<|action_token_744|>": 152418,
1775
+ "<|action_token_745|>": 152419,
1776
+ "<|action_token_746|>": 152420,
1777
+ "<|action_token_747|>": 152421,
1778
+ "<|action_token_748|>": 152422,
1779
+ "<|action_token_749|>": 152423,
1780
+ "<|action_token_74|>": 151748,
1781
+ "<|action_token_750|>": 152424,
1782
+ "<|action_token_751|>": 152425,
1783
+ "<|action_token_752|>": 152426,
1784
+ "<|action_token_753|>": 152427,
1785
+ "<|action_token_754|>": 152428,
1786
+ "<|action_token_755|>": 152429,
1787
+ "<|action_token_756|>": 152430,
1788
+ "<|action_token_757|>": 152431,
1789
+ "<|action_token_758|>": 152432,
1790
+ "<|action_token_759|>": 152433,
1791
+ "<|action_token_75|>": 151749,
1792
+ "<|action_token_760|>": 152434,
1793
+ "<|action_token_761|>": 152435,
1794
+ "<|action_token_762|>": 152436,
1795
+ "<|action_token_763|>": 152437,
1796
+ "<|action_token_764|>": 152438,
1797
+ "<|action_token_765|>": 152439,
1798
+ "<|action_token_766|>": 152440,
1799
+ "<|action_token_767|>": 152441,
1800
+ "<|action_token_768|>": 152442,
1801
+ "<|action_token_769|>": 152443,
1802
+ "<|action_token_76|>": 151750,
1803
+ "<|action_token_770|>": 152444,
1804
+ "<|action_token_771|>": 152445,
1805
+ "<|action_token_772|>": 152446,
1806
+ "<|action_token_773|>": 152447,
1807
+ "<|action_token_774|>": 152448,
1808
+ "<|action_token_775|>": 152449,
1809
+ "<|action_token_776|>": 152450,
1810
+ "<|action_token_777|>": 152451,
1811
+ "<|action_token_778|>": 152452,
1812
+ "<|action_token_779|>": 152453,
1813
+ "<|action_token_77|>": 151751,
1814
+ "<|action_token_780|>": 152454,
1815
+ "<|action_token_781|>": 152455,
1816
+ "<|action_token_782|>": 152456,
1817
+ "<|action_token_783|>": 152457,
1818
+ "<|action_token_784|>": 152458,
1819
+ "<|action_token_785|>": 152459,
1820
+ "<|action_token_786|>": 152460,
1821
+ "<|action_token_787|>": 152461,
1822
+ "<|action_token_788|>": 152462,
1823
+ "<|action_token_789|>": 152463,
1824
+ "<|action_token_78|>": 151752,
1825
+ "<|action_token_790|>": 152464,
1826
+ "<|action_token_791|>": 152465,
1827
+ "<|action_token_792|>": 152466,
1828
+ "<|action_token_793|>": 152467,
1829
+ "<|action_token_794|>": 152468,
1830
+ "<|action_token_795|>": 152469,
1831
+ "<|action_token_796|>": 152470,
1832
+ "<|action_token_797|>": 152471,
1833
+ "<|action_token_798|>": 152472,
1834
+ "<|action_token_799|>": 152473,
1835
+ "<|action_token_79|>": 151753,
1836
+ "<|action_token_7|>": 151681,
1837
+ "<|action_token_800|>": 152474,
1838
+ "<|action_token_801|>": 152475,
1839
+ "<|action_token_802|>": 152476,
1840
+ "<|action_token_803|>": 152477,
1841
+ "<|action_token_804|>": 152478,
1842
+ "<|action_token_805|>": 152479,
1843
+ "<|action_token_806|>": 152480,
1844
+ "<|action_token_807|>": 152481,
1845
+ "<|action_token_808|>": 152482,
1846
+ "<|action_token_809|>": 152483,
1847
+ "<|action_token_80|>": 151754,
1848
+ "<|action_token_810|>": 152484,
1849
+ "<|action_token_811|>": 152485,
1850
+ "<|action_token_812|>": 152486,
1851
+ "<|action_token_813|>": 152487,
1852
+ "<|action_token_814|>": 152488,
1853
+ "<|action_token_815|>": 152489,
1854
+ "<|action_token_816|>": 152490,
1855
+ "<|action_token_817|>": 152491,
1856
+ "<|action_token_818|>": 152492,
1857
+ "<|action_token_819|>": 152493,
1858
+ "<|action_token_81|>": 151755,
1859
+ "<|action_token_820|>": 152494,
1860
+ "<|action_token_821|>": 152495,
1861
+ "<|action_token_822|>": 152496,
1862
+ "<|action_token_823|>": 152497,
1863
+ "<|action_token_824|>": 152498,
1864
+ "<|action_token_825|>": 152499,
1865
+ "<|action_token_826|>": 152500,
1866
+ "<|action_token_827|>": 152501,
1867
+ "<|action_token_828|>": 152502,
1868
+ "<|action_token_829|>": 152503,
1869
+ "<|action_token_82|>": 151756,
1870
+ "<|action_token_830|>": 152504,
1871
+ "<|action_token_831|>": 152505,
1872
+ "<|action_token_832|>": 152506,
1873
+ "<|action_token_833|>": 152507,
1874
+ "<|action_token_834|>": 152508,
1875
+ "<|action_token_835|>": 152509,
1876
+ "<|action_token_836|>": 152510,
1877
+ "<|action_token_837|>": 152511,
1878
+ "<|action_token_838|>": 152512,
1879
+ "<|action_token_839|>": 152513,
1880
+ "<|action_token_83|>": 151757,
1881
+ "<|action_token_840|>": 152514,
1882
+ "<|action_token_841|>": 152515,
1883
+ "<|action_token_842|>": 152516,
1884
+ "<|action_token_843|>": 152517,
1885
+ "<|action_token_844|>": 152518,
1886
+ "<|action_token_845|>": 152519,
1887
+ "<|action_token_846|>": 152520,
1888
+ "<|action_token_847|>": 152521,
1889
+ "<|action_token_848|>": 152522,
1890
+ "<|action_token_849|>": 152523,
1891
+ "<|action_token_84|>": 151758,
1892
+ "<|action_token_850|>": 152524,
1893
+ "<|action_token_851|>": 152525,
1894
+ "<|action_token_852|>": 152526,
1895
+ "<|action_token_853|>": 152527,
1896
+ "<|action_token_854|>": 152528,
1897
+ "<|action_token_855|>": 152529,
1898
+ "<|action_token_856|>": 152530,
1899
+ "<|action_token_857|>": 152531,
1900
+ "<|action_token_858|>": 152532,
1901
+ "<|action_token_859|>": 152533,
1902
+ "<|action_token_85|>": 151759,
1903
+ "<|action_token_860|>": 152534,
1904
+ "<|action_token_861|>": 152535,
1905
+ "<|action_token_862|>": 152536,
1906
+ "<|action_token_863|>": 152537,
1907
+ "<|action_token_864|>": 152538,
1908
+ "<|action_token_865|>": 152539,
1909
+ "<|action_token_866|>": 152540,
1910
+ "<|action_token_867|>": 152541,
1911
+ "<|action_token_868|>": 152542,
1912
+ "<|action_token_869|>": 152543,
1913
+ "<|action_token_86|>": 151760,
1914
+ "<|action_token_870|>": 152544,
1915
+ "<|action_token_871|>": 152545,
1916
+ "<|action_token_872|>": 152546,
1917
+ "<|action_token_873|>": 152547,
1918
+ "<|action_token_874|>": 152548,
1919
+ "<|action_token_875|>": 152549,
1920
+ "<|action_token_876|>": 152550,
1921
+ "<|action_token_877|>": 152551,
1922
+ "<|action_token_878|>": 152552,
1923
+ "<|action_token_879|>": 152553,
1924
+ "<|action_token_87|>": 151761,
1925
+ "<|action_token_880|>": 152554,
1926
+ "<|action_token_881|>": 152555,
1927
+ "<|action_token_882|>": 152556,
1928
+ "<|action_token_883|>": 152557,
1929
+ "<|action_token_884|>": 152558,
1930
+ "<|action_token_885|>": 152559,
1931
+ "<|action_token_886|>": 152560,
1932
+ "<|action_token_887|>": 152561,
1933
+ "<|action_token_888|>": 152562,
1934
+ "<|action_token_889|>": 152563,
1935
+ "<|action_token_88|>": 151762,
1936
+ "<|action_token_890|>": 152564,
1937
+ "<|action_token_891|>": 152565,
1938
+ "<|action_token_892|>": 152566,
1939
+ "<|action_token_893|>": 152567,
1940
+ "<|action_token_894|>": 152568,
1941
+ "<|action_token_895|>": 152569,
1942
+ "<|action_token_896|>": 152570,
1943
+ "<|action_token_897|>": 152571,
1944
+ "<|action_token_898|>": 152572,
1945
+ "<|action_token_899|>": 152573,
1946
+ "<|action_token_89|>": 151763,
1947
+ "<|action_token_8|>": 151682,
1948
+ "<|action_token_900|>": 152574,
1949
+ "<|action_token_901|>": 152575,
1950
+ "<|action_token_902|>": 152576,
1951
+ "<|action_token_903|>": 152577,
1952
+ "<|action_token_904|>": 152578,
1953
+ "<|action_token_905|>": 152579,
1954
+ "<|action_token_906|>": 152580,
1955
+ "<|action_token_907|>": 152581,
1956
+ "<|action_token_908|>": 152582,
1957
+ "<|action_token_909|>": 152583,
1958
+ "<|action_token_90|>": 151764,
1959
+ "<|action_token_910|>": 152584,
1960
+ "<|action_token_911|>": 152585,
1961
+ "<|action_token_912|>": 152586,
1962
+ "<|action_token_913|>": 152587,
1963
+ "<|action_token_914|>": 152588,
1964
+ "<|action_token_915|>": 152589,
1965
+ "<|action_token_916|>": 152590,
1966
+ "<|action_token_917|>": 152591,
1967
+ "<|action_token_918|>": 152592,
1968
+ "<|action_token_919|>": 152593,
1969
+ "<|action_token_91|>": 151765,
1970
+ "<|action_token_920|>": 152594,
1971
+ "<|action_token_921|>": 152595,
1972
+ "<|action_token_922|>": 152596,
1973
+ "<|action_token_923|>": 152597,
1974
+ "<|action_token_924|>": 152598,
1975
+ "<|action_token_925|>": 152599,
1976
+ "<|action_token_926|>": 152600,
1977
+ "<|action_token_927|>": 152601,
1978
+ "<|action_token_928|>": 152602,
1979
+ "<|action_token_929|>": 152603,
1980
+ "<|action_token_92|>": 151766,
1981
+ "<|action_token_930|>": 152604,
1982
+ "<|action_token_931|>": 152605,
1983
+ "<|action_token_932|>": 152606,
1984
+ "<|action_token_933|>": 152607,
1985
+ "<|action_token_934|>": 152608,
1986
+ "<|action_token_935|>": 152609,
1987
+ "<|action_token_936|>": 152610,
1988
+ "<|action_token_937|>": 152611,
1989
+ "<|action_token_938|>": 152612,
1990
+ "<|action_token_939|>": 152613,
1991
+ "<|action_token_93|>": 151767,
1992
+ "<|action_token_940|>": 152614,
1993
+ "<|action_token_941|>": 152615,
1994
+ "<|action_token_942|>": 152616,
1995
+ "<|action_token_943|>": 152617,
1996
+ "<|action_token_944|>": 152618,
1997
+ "<|action_token_945|>": 152619,
1998
+ "<|action_token_946|>": 152620,
1999
+ "<|action_token_947|>": 152621,
2000
+ "<|action_token_948|>": 152622,
2001
+ "<|action_token_949|>": 152623,
2002
+ "<|action_token_94|>": 151768,
2003
+ "<|action_token_950|>": 152624,
2004
+ "<|action_token_951|>": 152625,
2005
+ "<|action_token_952|>": 152626,
2006
+ "<|action_token_953|>": 152627,
2007
+ "<|action_token_954|>": 152628,
2008
+ "<|action_token_955|>": 152629,
2009
+ "<|action_token_956|>": 152630,
2010
+ "<|action_token_957|>": 152631,
2011
+ "<|action_token_958|>": 152632,
2012
+ "<|action_token_959|>": 152633,
2013
+ "<|action_token_95|>": 151769,
2014
+ "<|action_token_960|>": 152634,
2015
+ "<|action_token_961|>": 152635,
2016
+ "<|action_token_962|>": 152636,
2017
+ "<|action_token_963|>": 152637,
2018
+ "<|action_token_964|>": 152638,
2019
+ "<|action_token_965|>": 152639,
2020
+ "<|action_token_966|>": 152640,
2021
+ "<|action_token_967|>": 152641,
2022
+ "<|action_token_968|>": 152642,
2023
+ "<|action_token_969|>": 152643,
2024
+ "<|action_token_96|>": 151770,
2025
+ "<|action_token_970|>": 152644,
2026
+ "<|action_token_971|>": 152645,
2027
+ "<|action_token_972|>": 152646,
2028
+ "<|action_token_973|>": 152647,
2029
+ "<|action_token_974|>": 152648,
2030
+ "<|action_token_975|>": 152649,
2031
+ "<|action_token_976|>": 152650,
2032
+ "<|action_token_977|>": 152651,
2033
+ "<|action_token_978|>": 152652,
2034
+ "<|action_token_979|>": 152653,
2035
+ "<|action_token_97|>": 151771,
2036
+ "<|action_token_980|>": 152654,
2037
+ "<|action_token_981|>": 152655,
2038
+ "<|action_token_982|>": 152656,
2039
+ "<|action_token_983|>": 152657,
2040
+ "<|action_token_984|>": 152658,
2041
+ "<|action_token_985|>": 152659,
2042
+ "<|action_token_986|>": 152660,
2043
+ "<|action_token_987|>": 152661,
2044
+ "<|action_token_988|>": 152662,
2045
+ "<|action_token_989|>": 152663,
2046
+ "<|action_token_98|>": 151772,
2047
+ "<|action_token_990|>": 152664,
2048
+ "<|action_token_991|>": 152665,
2049
+ "<|action_token_992|>": 152666,
2050
+ "<|action_token_993|>": 152667,
2051
+ "<|action_token_994|>": 152668,
2052
+ "<|action_token_995|>": 152669,
2053
+ "<|action_token_996|>": 152670,
2054
+ "<|action_token_997|>": 152671,
2055
+ "<|action_token_998|>": 152672,
2056
+ "<|action_token_999|>": 152673,
2057
+ "<|action_token_99|>": 151773,
2058
+ "<|action_token_9|>": 151683,
2059
+ "<|box_end|>": 151649,
2060
+ "<|box_start|>": 151648,
2061
+ "<|endoftext|>": 151643,
2062
+ "<|file_sep|>": 151664,
2063
+ "<|fim_middle|>": 151660,
2064
+ "<|fim_pad|>": 151662,
2065
+ "<|fim_prefix|>": 151659,
2066
+ "<|fim_suffix|>": 151661,
2067
+ "<|goal_repr|>": 151672,
2068
+ "<|im_end|>": 151645,
2069
+ "<|im_start|>": 151644,
2070
+ "<|image_pad|>": 151655,
2071
+ "<|object_ref_end|>": 151647,
2072
+ "<|object_ref_start|>": 151646,
2073
+ "<|obs_repr|>": 151673,
2074
+ "<|quad_end|>": 151651,
2075
+ "<|quad_start|>": 151650,
2076
+ "<|repo_name|>": 151663,
2077
+ "<|video_pad|>": 151656,
2078
+ "<|vision_end|>": 151653,
2079
+ "<|vision_pad|>": 151654,
2080
+ "<|vision_start|>": 151652
2081
+ }
chat_template.jinja ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {%- if messages[0].content is string %}
5
+ {{- messages[0].content }}
6
+ {%- else %}
7
+ {%- for content in messages[0].content %}
8
+ {%- if 'text' in content %}
9
+ {{- content.text }}
10
+ {%- endif %}
11
+ {%- endfor %}
12
+ {%- endif %}
13
+ {{- '\n\n' }}
14
+ {%- endif %}
15
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
16
+ {%- for tool in tools %}
17
+ {{- "\n" }}
18
+ {{- tool | tojson }}
19
+ {%- endfor %}
20
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
21
+ {%- else %}
22
+ {%- if messages[0].role == 'system' %}
23
+ {{- '<|im_start|>system\n' }}
24
+ {%- if messages[0].content is string %}
25
+ {{- messages[0].content }}
26
+ {%- else %}
27
+ {%- for content in messages[0].content %}
28
+ {%- if 'text' in content %}
29
+ {{- content.text }}
30
+ {%- endif %}
31
+ {%- endfor %}
32
+ {%- endif %}
33
+ {{- '<|im_end|>\n' }}
34
+ {%- endif %}
35
+ {%- endif %}
36
+ {%- set image_count = namespace(value=0) %}
37
+ {%- set video_count = namespace(value=0) %}
38
+ {%- for message in messages %}
39
+ {%- if message.role == "user" %}
40
+ {{- '<|im_start|>' + message.role + '\n' }}
41
+ {%- if message.content is string %}
42
+ {{- message.content }}
43
+ {%- else %}
44
+ {%- for content in message.content %}
45
+ {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}
46
+ {%- set image_count.value = image_count.value + 1 %}
47
+ {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}
48
+ <|vision_start|><|image_pad|><|vision_end|>
49
+ {%- elif content.type == 'video' or 'video' in content %}
50
+ {%- set video_count.value = video_count.value + 1 %}
51
+ {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}
52
+ <|vision_start|><|video_pad|><|vision_end|>
53
+ {%- elif 'text' in content %}
54
+ {{- content.text }}
55
+ {%- endif %}
56
+ {%- endfor %}
57
+ {%- endif %}
58
+ {{- '<|im_end|>\n' }}
59
+ {%- elif message.role == "assistant" %}
60
+ {{- '<|im_start|>' + message.role + '\n' }}
61
+ {%- if message.content is string %}
62
+ {{- message.content }}
63
+ {%- else %}
64
+ {%- for content_item in message.content %}
65
+ {%- if 'text' in content_item %}
66
+ {{- content_item.text }}
67
+ {%- endif %}
68
+ {%- endfor %}
69
+ {%- endif %}
70
+ {%- if message.tool_calls %}
71
+ {%- for tool_call in message.tool_calls %}
72
+ {%- if (loop.first and message.content) or (not loop.first) %}
73
+ {{- '\n' }}
74
+ {%- endif %}
75
+ {%- if tool_call.function %}
76
+ {%- set tool_call = tool_call.function %}
77
+ {%- endif %}
78
+ {{- '<tool_call>\n{"name": "' }}
79
+ {{- tool_call.name }}
80
+ {{- '", "arguments": ' }}
81
+ {%- if tool_call.arguments is string %}
82
+ {{- tool_call.arguments }}
83
+ {%- else %}
84
+ {{- tool_call.arguments | tojson }}
85
+ {%- endif %}
86
+ {{- '}\n</tool_call>' }}
87
+ {%- endfor %}
88
+ {%- endif %}
89
+ {{- '<|im_end|>\n' }}
90
+ {%- elif message.role == "tool" %}
91
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
92
+ {{- '<|im_start|>user' }}
93
+ {%- endif %}
94
+ {{- '\n<tool_response>\n' }}
95
+ {%- if message.content is string %}
96
+ {{- message.content }}
97
+ {%- else %}
98
+ {%- for content in message.content %}
99
+ {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}
100
+ {%- set image_count.value = image_count.value + 1 %}
101
+ {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}
102
+ <|vision_start|><|image_pad|><|vision_end|>
103
+ {%- elif content.type == 'video' or 'video' in content %}
104
+ {%- set video_count.value = video_count.value + 1 %}
105
+ {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}
106
+ <|vision_start|><|video_pad|><|vision_end|>
107
+ {%- elif 'text' in content %}
108
+ {{- content.text }}
109
+ {%- endif %}
110
+ {%- endfor %}
111
+ {%- endif %}
112
+ {{- '\n</tool_response>' }}
113
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
114
+ {{- '<|im_end|>\n' }}
115
+ {%- endif %}
116
+ {%- endif %}
117
+ {%- endfor %}
118
+ {%- if add_generation_prompt %}
119
+ {{- '<|im_start|>assistant\n' }}
120
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "action_chunk_size": 20,
3
+ "action_expert_config": {
4
+ "action_end_token_id": null,
5
+ "action_start_token_id": 151669,
6
+ "action_token_id": 151670,
7
+ "attention_bias": false,
8
+ "attention_dropout": 0.0,
9
+ "bos_token_id": 151643,
10
+ "crl_goal_repr_token_id": 151672,
11
+ "crl_obs_repr_token_id": 151673,
12
+ "dtype": "bfloat16",
13
+ "eos_token_id": 151645,
14
+ "head_dim": 128,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 1280,
17
+ "image_token_id": 151655,
18
+ "initializer_range": 0.02,
19
+ "intermediate_size": 2432,
20
+ "max_position_embeddings": 262144,
21
+ "model_type": "prts_qwen3_vl_text",
22
+ "num_attention_heads": 32,
23
+ "num_hidden_layers": 36,
24
+ "num_key_value_heads": 8,
25
+ "rms_norm_eps": 1e-06,
26
+ "rope_scaling": {
27
+ "mrope_interleaved": true,
28
+ "mrope_section": [
29
+ 24,
30
+ 20,
31
+ 20
32
+ ],
33
+ "rope_type": "default"
34
+ },
35
+ "rope_theta": 5000000,
36
+ "tie_word_embeddings": true,
37
+ "use_cache": true,
38
+ "video_token_id": 151656,
39
+ "vision_start_token_id": 151652,
40
+ "vocab_size": 153722
41
+ },
42
+ "action_start_token_id": 151669,
43
+ "architectures": [
44
+ "PRTS_Qwen3VL"
45
+ ],
46
+ "auto_map": {
47
+ "AutoConfig": "configuration_prts_qwen3_vl.PRTS_FlowMatchingConfig_Qwen3VL",
48
+ "AutoModel": "modeling_prts_qwen3_vl.PRTS_Qwen3VL"
49
+ },
50
+ "crl_embed_dim": 256,
51
+ "crl_encoder_init_w": 0.001,
52
+ "crl_goal_repr_token_id": 151672,
53
+ "crl_logsumexp_reg_weight": 0.0,
54
+ "crl_loss_weight": 0.0,
55
+ "crl_obs_repr_token_id": 151673,
56
+ "crl_repr_norm": true,
57
+ "dit_action_head_config": {
58
+ "add_pos_embed": true,
59
+ "attend_text_every_n_blocks": 2,
60
+ "attention_head_dim": 48,
61
+ "attn_implementation": "sdpa",
62
+ "dropout": 0.2,
63
+ "final_dropout": true,
64
+ "interleave_self_attention": true,
65
+ "mlp_mult": 4,
66
+ "noise_beta_alpha": 1.5,
67
+ "noise_beta_beta": 1.0,
68
+ "noise_s": 0.999,
69
+ "norm_type": "ada_norm",
70
+ "num_attention_heads": 32,
71
+ "num_layers": 16,
72
+ "num_timestep_buckets": 1000,
73
+ "output_dim": 1024,
74
+ "use_alternate_vl_dit": true,
75
+ "use_mot_action_expert": true
76
+ },
77
+ "dtype": "bfloat16",
78
+ "embodiment_tag": "libero_panda",
79
+ "flow_matching_action_loss_weight": 1.0,
80
+ "flow_matching_sub_goal_loss_weight": 0.0,
81
+ "image_token_id": 151655,
82
+ "label2id": null,
83
+ "max_action_dim": 32,
84
+ "model_type": "prts_qwen3_vl",
85
+ "num_denoise_steps": 5,
86
+ "pad_token_id": 151643,
87
+ "text_config": {
88
+ "action_end_token_id": null,
89
+ "action_start_token_id": 151669,
90
+ "action_token_id": 151670,
91
+ "attention_bias": false,
92
+ "attention_dropout": 0.0,
93
+ "bos_token_id": 151643,
94
+ "crl_goal_repr_token_id": 151672,
95
+ "crl_obs_repr_token_id": 151673,
96
+ "dtype": "bfloat16",
97
+ "eos_token_id": 151645,
98
+ "head_dim": 128,
99
+ "hidden_act": "silu",
100
+ "hidden_size": 2560,
101
+ "image_token_id": 151655,
102
+ "initializer_range": 0.02,
103
+ "intermediate_size": 9728,
104
+ "max_position_embeddings": 262144,
105
+ "model_type": "prts_qwen3_vl_text",
106
+ "num_attention_heads": 32,
107
+ "num_hidden_layers": 36,
108
+ "num_key_value_heads": 8,
109
+ "rms_norm_eps": 1e-06,
110
+ "rope_scaling": {
111
+ "mrope_interleaved": true,
112
+ "mrope_section": [
113
+ 24,
114
+ 20,
115
+ 20
116
+ ],
117
+ "rope_type": "default"
118
+ },
119
+ "rope_theta": 5000000,
120
+ "tie_word_embeddings": true,
121
+ "use_cache": false,
122
+ "video_token_id": 151656,
123
+ "vision_start_token_id": 151652,
124
+ "vocab_size": 153722
125
+ },
126
+ "tie_word_embeddings": true,
127
+ "transformers_version": "4.57.3",
128
+ "use_cache": true,
129
+ "use_fast_action_tokenizer": true,
130
+ "video_token_id": 151656,
131
+ "vision_config": {
132
+ "deepstack_visual_indexes": [
133
+ 5,
134
+ 11,
135
+ 17
136
+ ],
137
+ "depth": 24,
138
+ "dtype": "bfloat16",
139
+ "hidden_act": "gelu_pytorch_tanh",
140
+ "hidden_size": 1024,
141
+ "in_channels": 3,
142
+ "initializer_range": 0.02,
143
+ "intermediate_size": 4096,
144
+ "model_type": "qwen3_vl",
145
+ "num_heads": 16,
146
+ "num_position_embeddings": 2304,
147
+ "out_hidden_size": 2560,
148
+ "patch_size": 16,
149
+ "spatial_merge_size": 2,
150
+ "temporal_patch_size": 2
151
+ },
152
+ "vision_end_token_id": 151653,
153
+ "vision_start_token_id": 151652,
154
+ "vocab_size": 153722
155
+ }
configuration_prts_qwen3_vl.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 TeleAI Rhodes Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Configuration classes for PRTS built on Qwen3-VL."""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.modeling_rope_utils import rope_config_validation
19
+ from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLVisionConfig
20
+
21
+
22
+ class PRTS_Qwen3VLTextConfig(PretrainedConfig):
23
+ r"""
24
+ This is the configuration class to store the configuration of a PRTS Text Model based on Qwen3-VL.
25
+ It extends PretrainedConfig with Qwen3-VL text model parameters and PRTS-specific parameters.
26
+
27
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
28
+
29
+ Args:
30
+ vocab_size (`int`, *optional*, defaults to 151936):
31
+ Vocabulary size of the Qwen3VL model.
32
+ hidden_size (`int`, *optional*, defaults to 4096):
33
+ Dimension of the hidden representations.
34
+ intermediate_size (`int`, *optional*, defaults to 22016):
35
+ Dimension of the MLP representations.
36
+ num_hidden_layers (`int`, *optional*, defaults to 32):
37
+ Number of hidden layers in the Transformer encoder.
38
+ num_attention_heads (`int`, *optional*, defaults to 32):
39
+ Number of attention heads for each attention layer.
40
+ num_key_value_heads (`int`, *optional*, defaults to 32):
41
+ Number of key-value heads for Grouped Query Attention.
42
+ head_dim (`int`, *optional*, defaults to 128):
43
+ The dimension of the head.
44
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
45
+ The non-linear activation function.
46
+ max_position_embeddings (`int`, *optional*, defaults to 128000):
47
+ The maximum sequence length.
48
+ initializer_range (`float`, *optional*, defaults to 0.02):
49
+ The standard deviation of the truncated_normal_initializer.
50
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
51
+ The epsilon used by the rms normalization layers.
52
+ use_cache (`bool`, *optional*, defaults to `True`):
53
+ Whether or not the model should return the last key/values attentions.
54
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
55
+ Whether the model's input and output word embeddings should be tied.
56
+ rope_theta (`float`, *optional*, defaults to 5000000.0):
57
+ The base period of the RoPE embeddings.
58
+ rope_scaling (`Dict`, *optional*):
59
+ Dictionary containing the scaling configuration for the RoPE embeddings.
60
+ attention_bias (`bool`, *optional*, defaults to `False`):
61
+ Whether to use a bias in the query, key, value and output projection layers.
62
+ attention_dropout (`float`, *optional*, defaults to 0.0):
63
+ The dropout ratio for the attention probabilities.
64
+ image_token_id (`int`, *optional*):
65
+ Token index used as placeholder for image embeddings.
66
+ video_token_id (`int`, *optional*):
67
+ Token index used as placeholder for video embeddings.
68
+ action_token_id (`int`, *optional*):
69
+ Token index used as placeholder for action embeddings.
70
+ action_start_token_id (`int`, *optional*):
71
+ Token index for action sequence start.
72
+ action_end_token_id (`int`, *optional*):
73
+ Token index for action sequence end.
74
+ vision_start_token_id (`int`, *optional*):
75
+ Token index for vision sequence start.
76
+ **kwargs:
77
+ Additional keyword arguments passed to PretrainedConfig.
78
+ """
79
+
80
+ model_type = "prts_qwen3_vl_text" # TODO (zy): check if this is correct
81
+ base_config_key = "text_config"
82
+
83
+ def __init__(
84
+ self,
85
+ vocab_size=151936,
86
+ hidden_size=4096,
87
+ intermediate_size=22016,
88
+ num_hidden_layers=32,
89
+ num_attention_heads=32,
90
+ num_key_value_heads=32,
91
+ head_dim=128,
92
+ hidden_act="silu",
93
+ max_position_embeddings=128000,
94
+ initializer_range=0.02,
95
+ rms_norm_eps=1e-6,
96
+ use_cache=True,
97
+ tie_word_embeddings=False,
98
+ rope_theta=5000000.0,
99
+ rope_scaling=None,
100
+ attention_bias=False,
101
+ attention_dropout=0.0,
102
+ # PRTS specific
103
+ action_token_id=None,
104
+ action_start_token_id=None,
105
+ action_end_token_id=None,
106
+ crl_goal_repr_token_id=None,
107
+ crl_obs_repr_token_id=None,
108
+ **kwargs,
109
+ ):
110
+ self.vocab_size = vocab_size
111
+ self.max_position_embeddings = max_position_embeddings
112
+ self.hidden_size = hidden_size
113
+ self.intermediate_size = intermediate_size
114
+ self.num_hidden_layers = num_hidden_layers
115
+ self.num_attention_heads = num_attention_heads
116
+
117
+ # for backward compatibility
118
+ if num_key_value_heads is None:
119
+ num_key_value_heads = num_attention_heads
120
+
121
+ self.num_key_value_heads = num_key_value_heads
122
+ self.head_dim = head_dim
123
+ self.hidden_act = hidden_act
124
+ self.initializer_range = initializer_range
125
+ self.rms_norm_eps = rms_norm_eps
126
+ self.use_cache = use_cache
127
+ self.rope_theta = rope_theta
128
+ self.rope_scaling = rope_scaling
129
+ self.attention_bias = attention_bias
130
+ self.attention_dropout = attention_dropout
131
+
132
+ # Validate rope config
133
+ rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
134
+
135
+ # PRTS specific token IDs
136
+ self.action_token_id = action_token_id
137
+ self.action_start_token_id = action_start_token_id
138
+ self.action_end_token_id = action_end_token_id
139
+ self.crl_goal_repr_token_id = crl_goal_repr_token_id
140
+ self.crl_obs_repr_token_id = crl_obs_repr_token_id
141
+
142
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
143
+
144
+
145
+ class PRTS_FlowMatchingConfig_Qwen3VL(PretrainedConfig):
146
+ r"""
147
+ This is the configuration class to store the configuration of a PRTS model based on Qwen3-VL.
148
+ It extends PretrainedConfig with Qwen3-VL model parameters and PRTS-specific parameters for action prediction.
149
+
150
+ [`PRTS_FlowMatchingConfig_Qwen3VL`] is the configuration class to store the configuration of a PRTS model. It is used to
151
+ instantiate a PRTS model according to the specified arguments, defining the vision encoder, text encoder,
152
+ action expert, and flow matching components.
153
+
154
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
155
+
156
+ Args:
157
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `PRTS_Qwen3VLTextConfig`):
158
+ The config object or dictionary of the text backbone.
159
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`):
160
+ The config object or dictionary of the vision backbone.
161
+ max_action_dim (`int`, *optional*, defaults to 14):
162
+ Maximum dimension of action vectors. Used for padding different robot action spaces.
163
+ action_chunk_size (`int`, *optional*, defaults to 100):
164
+ Number of action timesteps to predict in each forward pass.
165
+ num_denoise_steps (`int`, *optional*, defaults to 4):
166
+ Number of denoising steps for flow matching during inference.
167
+ flow_matching_action_loss_weight (`float`, *optional*, defaults to 1.0):
168
+ Weight for the flow matching action loss.
169
+ crl_loss_weight (`float`, *optional*, defaults to 0.0):
170
+ Weight for the Contrastive Reinforcement Learning (CRL) loss. Set to 0 to disable.
171
+ crl_embed_dim (`int`, *optional*, defaults to 256):
172
+ Dimension of the CRL embedding space for action and goal encoders.
173
+ crl_logsumexp_reg_weight (`float`, *optional*, defaults to 0.0):
174
+ Weight for logsumexp regularization on CRL logits.
175
+ image_token_id (`int`, *optional*):
176
+ Token id for image placeholders.
177
+ video_token_id (`int`, *optional*):
178
+ Token id for video placeholders.
179
+ vision_start_token_id (`int`, *optional*):
180
+ Token id for vision start marker.
181
+ vision_end_token_id (`int`, *optional*):
182
+ Token id for vision end marker.
183
+ **kwargs:
184
+ Additional keyword arguments passed to PretrainedConfig.
185
+
186
+ Example:
187
+
188
+ ```python
189
+ >>> from prts.models import PRTS_FlowMatchingConfig_Qwen3VL, PRTS_Qwen3VL
190
+
191
+ >>> # Initializing a PRTS Qwen3-VL configuration
192
+ >>> configuration = PRTS_FlowMatchingConfig_Qwen3VL()
193
+
194
+ >>> # Initializing a model from the configuration
195
+ >>> model = PRTS_Qwen3VL(configuration)
196
+
197
+ >>> # Accessing the model configuration
198
+ >>> configuration = model.config
199
+ ```
200
+ """
201
+
202
+ model_type = "prts_qwen3_vl"
203
+ sub_configs = {
204
+ "vision_config": Qwen3VLVisionConfig,
205
+ "text_config": PRTS_Qwen3VLTextConfig,
206
+ }
207
+ keys_to_ignore_at_inference = ["past_key_values"]
208
+
209
+ def __init__(
210
+ self,
211
+ text_config=None,
212
+ vision_config=None,
213
+ image_token_id=151655,
214
+ video_token_id=151656,
215
+ vision_start_token_id=151652,
216
+ vision_end_token_id=151653,
217
+ tie_word_embeddings=False,
218
+ # PRTS specific
219
+ max_action_dim=32,
220
+ action_chunk_size=50,
221
+ num_denoise_steps=4,
222
+ flow_matching_action_loss_weight=0.,
223
+ use_fast_action_tokenizer=True,
224
+ # Embodiment tag: identifies the robot embodiment used for finetuning.
225
+ # Stores the delta_action_mask key so eval code can recover it without
226
+ # needing the training dataset config.
227
+ embodiment_tag=None,
228
+ # DiT action head config
229
+ dit_action_head_config=None,
230
+ # CRL (Contrastive Reinforcement Learning) parameters
231
+ crl_loss_weight=0.,
232
+ crl_embed_dim=256,
233
+ crl_logsumexp_reg_weight=0.0,
234
+ crl_encoder_init_w=1e-12, # Cold initialization weight for encoder last layer
235
+ crl_repr_norm=True, # Whether to L2-normalize CRL representations
236
+ **kwargs,
237
+ ):
238
+ # Initialize vision config
239
+ if isinstance(vision_config, dict):
240
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
241
+ elif vision_config is None:
242
+ self.vision_config = self.sub_configs["vision_config"]()
243
+
244
+ # Initialize text config
245
+ if isinstance(text_config, dict):
246
+ self.text_config = self.sub_configs["text_config"](**text_config)
247
+ elif text_config is None:
248
+ # For BC use all kwargs to init `TextConfig`
249
+ self.text_config = self.sub_configs["text_config"](**kwargs)
250
+
251
+ # PRTS-specific parameters
252
+ self.max_action_dim = max_action_dim
253
+ self.action_chunk_size = action_chunk_size
254
+ self.num_denoise_steps = num_denoise_steps
255
+ self.flow_matching_action_loss_weight = flow_matching_action_loss_weight
256
+ self.use_fast_action_tokenizer = use_fast_action_tokenizer
257
+ self.embodiment_tag = embodiment_tag
258
+
259
+ # DiT action head config (nested dict)
260
+ # cross_attention_dim defaults to text_config.hidden_size at model init time
261
+ _default_dit_config = {
262
+ # Architecture — aligned with GR00T N1.6 (32 layers, inner_dim=32×48=1536)
263
+ "num_layers": 16, # 32
264
+ "num_attention_heads": 32,
265
+ "attention_head_dim": 48,
266
+ "output_dim": 1024,
267
+ # Regularisation
268
+ "dropout": 0.2,
269
+ "interleave_self_attention": True,
270
+ "norm_type": "ada_norm",
271
+ "final_dropout": True,
272
+ # Action-head specifics
273
+ "add_pos_embed": True,
274
+ # Noise schedule
275
+ "noise_beta_alpha": 1.5,
276
+ "noise_beta_beta": 1.0,
277
+ "noise_s": 0.999,
278
+ "num_timestep_buckets": 1000,
279
+ # Attention backend
280
+ "attn_implementation": "sdpa",
281
+ # AlternateVLDiT — separate visual / text token cross-attention
282
+ "use_alternate_vl_dit": True,
283
+ "attend_text_every_n_blocks": 2,
284
+ # MoT-style action expert: forwards full VLM ``past_key_values`` into the head;
285
+ # expert depth defaults to text_config.num_hidden_layers (override with expert_num_layers).
286
+ "use_mot_action_expert": False,
287
+ "mlp_mult": 4, # FFN hidden dim = inner_dim * mlp_mult (standard DiT only)
288
+ }
289
+ if dit_action_head_config is not None:
290
+ _default_dit_config.update(dit_action_head_config)
291
+ self.dit_action_head_config = _default_dit_config
292
+
293
+ # CRL (Contrastive Reinforcement Learning) parameters
294
+ self.crl_loss_weight = crl_loss_weight
295
+ self.crl_embed_dim = crl_embed_dim
296
+ self.crl_logsumexp_reg_weight = crl_logsumexp_reg_weight
297
+ self.crl_encoder_init_w = crl_encoder_init_w
298
+ self.crl_repr_norm = crl_repr_norm
299
+
300
+ # Token IDs
301
+ self.image_token_id = image_token_id
302
+ self.video_token_id = video_token_id
303
+ self.vision_start_token_id = vision_start_token_id
304
+ self.vision_end_token_id = vision_end_token_id
305
+
306
+ # # Propagate token IDs to text config
307
+ # if self.image_token_id is not None:
308
+ # self.text_config.image_token_id = self.image_token_id
309
+ # if self.video_token_id is not None:
310
+ # self.text_config.video_token_id = self.video_token_id
311
+ # if self.vision_start_token_id is not None:
312
+ # self.text_config.vision_start_token_id = self.vision_start_token_id
313
+
314
+ # Ensure vocab sizes are consistent
315
+ # if hasattr(self.text_config, 'vocab_size'):
316
+ # self.vocab_size = self.text_config.vocab_size
317
+
318
+ super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
319
+
320
+ # TODO (zy): 这里需要看下是不是在VLConfig传入这些state action的特殊token更合适更灵活
321
+ @property
322
+ def action_token_id(self):
323
+ """Get action token id from text config."""
324
+ return getattr(self.text_config, 'action_token_id', None)
325
+
326
+ @action_token_id.setter
327
+ def action_token_id(self, value):
328
+ """Set action token id in text config."""
329
+ if hasattr(self.text_config, 'action_token_id'):
330
+ self.text_config.action_token_id = value
331
+
332
+ def __getattribute__(self, key):
333
+ if "text_config" in super().__getattribute__("__dict__") and key not in [
334
+ "dtype",
335
+ "_attn_implementation_internal",
336
+ ]:
337
+ text_config = super().__getattribute__("text_config")
338
+ if key in text_config.__dict__:
339
+ return getattr(text_config, key)
340
+
341
+ return super().__getattribute__(key)
342
+
343
+
344
+ PRTS_FlowMatchingConfig_Qwen3VL.register_for_auto_class()
345
+ __all__ = ["PRTS_FlowMatchingConfig_Qwen3VL", "PRTS_Qwen3VLTextConfig"]
dit_action_head.py ADDED
@@ -0,0 +1,1230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DiT (Diffusion Transformer) based flow matching action head for PRTS.
3
+
4
+ Replaces the Qwen3VLTextModel-based fm_action_expert with a lightweight DiT
5
+ that uses explicit cross-attention to VLM hidden states, following the architecture
6
+ from GR00T / pi05.
7
+
8
+ Architecture:
9
+ ActionEncoder(noisy_actions + dof_mask, timestep)
10
+ → action_features
11
+ → DiT(cross-attn to VLM hidden states, ada-norm timestep conditioning)
12
+ → ActionDecoder → predicted velocity
13
+ """
14
+
15
+ import math
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from torch.distributions import Beta
21
+ from typing import Optional
22
+
23
+ from transformers.cache_utils import Cache
24
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
25
+
26
+
27
+ # DIT_PRESETS = {
28
+ # "DiT-B": {"num_attention_heads": 12, "attention_head_dim": 64, "output_dim": 768},
29
+ # "DiT-L": {"num_attention_heads": 32, "attention_head_dim": 48, "output_dim": 1536},
30
+ # }
31
+
32
+
33
+ class SinusoidalPositionalEncoding(nn.Module):
34
+ """Sinusoidal positional encoding for sequence positions or timesteps."""
35
+
36
+ def __init__(self, embedding_dim: int):
37
+ super().__init__()
38
+ self.embedding_dim = embedding_dim
39
+
40
+ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
41
+ timesteps = timesteps.float()
42
+ squeeze = False
43
+ if timesteps.dim() == 1:
44
+ timesteps = timesteps.unsqueeze(1)
45
+ squeeze = True
46
+
47
+ half_dim = self.embedding_dim // 2
48
+ exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * (
49
+ math.log(10000.0) / half_dim
50
+ )
51
+ freqs = timesteps.unsqueeze(-1) * exponent.exp()
52
+ enc = torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1)
53
+
54
+ if squeeze:
55
+ enc = enc.squeeze(1)
56
+ return enc
57
+
58
+
59
+ class TimestepEncoder(nn.Module):
60
+ """Projects scalar timesteps to embedding space via sinusoidal encoding + MLP."""
61
+
62
+ def __init__(self, embedding_dim: int):
63
+ super().__init__()
64
+ self.sinusoidal = SinusoidalPositionalEncoding(256)
65
+ self.linear_1 = nn.Linear(256, embedding_dim)
66
+ self.act = nn.SiLU()
67
+ self.linear_2 = nn.Linear(embedding_dim, embedding_dim)
68
+
69
+ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
70
+ t_emb = self.sinusoidal(timesteps)
71
+ t_emb = self.linear_1(t_emb.to(dtype=self.linear_1.weight.dtype))
72
+ t_emb = self.act(t_emb)
73
+ t_emb = self.linear_2(t_emb)
74
+ return t_emb
75
+
76
+
77
+ class AdaLayerNorm(nn.Module):
78
+ """Adaptive Layer Normalization conditioned on timestep embeddings.
79
+
80
+ Applies scale-shift modulation: out = norm(x) * (1 + scale) + shift,
81
+ where (scale, shift) are linearly projected from the timestep embedding.
82
+ """
83
+
84
+ def __init__(self, embedding_dim: int, eps: float = 1e-5):
85
+ super().__init__()
86
+ self.silu = nn.SiLU()
87
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
88
+ self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=False)
89
+
90
+ def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
91
+ temb = self.linear(self.silu(temb))
92
+ scale, shift = temb.chunk(2, dim=-1)
93
+ x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
94
+ return x
95
+
96
+
97
+ class DiTAttention(nn.Module):
98
+ """Multi-head attention supporting both self-attention and cross-attention.
99
+
100
+ Supports two backends selected via ``attn_implementation``:
101
+
102
+ * ``"sdpa"`` (default) – uses :func:`F.scaled_dot_product_attention`, which
103
+ dispatches automatically to FlashAttention / memory-efficient attention
104
+ depending on the installed PyTorch build. The encoder padding mask is
105
+ expanded to ``(B, 1, 1, S)`` and passed as ``attn_mask``.
106
+
107
+ * ``"flash_attention_2"`` – calls the ``flash_attn`` package directly for
108
+ lower memory usage and higher throughput. For cross-attention with an
109
+ encoder padding mask the k/v tensors are unpadded and
110
+ :func:`flash_attn_varlen_func` is used so that padding tokens are never
111
+ processed. For self-attention (no mask) the simpler
112
+ :func:`flash_attn_func` is used.
113
+ """
114
+
115
+ def __init__(
116
+ self,
117
+ query_dim: int,
118
+ num_heads: int,
119
+ head_dim: int,
120
+ cross_attention_dim: Optional[int] = None,
121
+ dropout: float = 0.0,
122
+ bias: bool = True,
123
+ attn_implementation: str = "sdpa",
124
+ ):
125
+ super().__init__()
126
+ self.num_heads = num_heads
127
+ self.head_dim = head_dim
128
+ self.attn_implementation = attn_implementation
129
+ inner_dim = num_heads * head_dim
130
+
131
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
132
+ kv_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
133
+ self.to_k = nn.Linear(kv_dim, inner_dim, bias=bias)
134
+ self.to_v = nn.Linear(kv_dim, inner_dim, bias=bias)
135
+ self.to_out = nn.Sequential(
136
+ nn.Linear(inner_dim, query_dim, bias=bias),
137
+ nn.Dropout(dropout),
138
+ )
139
+
140
+ # ------------------------------------------------------------------
141
+ # Flash-Attention backend
142
+ # ------------------------------------------------------------------
143
+
144
+ def _flash_attn_forward(
145
+ self,
146
+ q: torch.Tensor,
147
+ k: torch.Tensor,
148
+ v: torch.Tensor,
149
+ attention_mask: Optional[torch.Tensor],
150
+ ) -> torch.Tensor:
151
+ """Run Flash Attention via HuggingFace's ``_flash_attention_forward``.
152
+
153
+ Args:
154
+ q: ``(B, T_q, H, D)``
155
+ k: ``(B, T_k, H, D)``
156
+ v: ``(B, T_k, H, D)``
157
+ attention_mask: ``(B, T_k)`` bool, True = valid token.
158
+
159
+ Returns:
160
+ ``(B, T_q, H*D)``
161
+ """
162
+
163
+ B, T_q, H, D = q.shape
164
+ # _flash_attention_forward returns (B, T_q, H, D); handles unpad/varlen internally.
165
+ out = _flash_attention_forward(
166
+ q, k, v,
167
+ attention_mask=attention_mask,
168
+ query_length=T_q,
169
+ is_causal=False,
170
+ dropout=0.0,
171
+ )
172
+ return out.reshape(B, T_q, H * D)
173
+
174
+ # ------------------------------------------------------------------
175
+ # Forward
176
+ # ------------------------------------------------------------------
177
+
178
+ def forward(
179
+ self,
180
+ hidden_states: torch.Tensor,
181
+ encoder_hidden_states: Optional[torch.Tensor] = None,
182
+ attention_mask: Optional[torch.Tensor] = None,
183
+ ) -> torch.Tensor:
184
+ B, T, _ = hidden_states.shape
185
+
186
+ q = self.to_q(hidden_states)
187
+ kv_input = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
188
+ k = self.to_k(kv_input)
189
+ v = self.to_v(kv_input)
190
+
191
+ if self.attn_implementation == "flash_attention_2":
192
+ # Flash Attention expects (B, S, H, D)
193
+ q = q.view(B, T, self.num_heads, self.head_dim)
194
+ k = k.view(B, -1, self.num_heads, self.head_dim)
195
+ v = v.view(B, -1, self.num_heads, self.head_dim)
196
+ attn_output = self._flash_attn_forward(q, k, v, attention_mask)
197
+ else:
198
+ # SDPA expects (B, H, S, D)
199
+ q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
200
+ k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
201
+ v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
202
+
203
+ # Expand (B, S) bool mask → (B, 1, 1, S) for broadcasting.
204
+ sdpa_mask = None
205
+ if attention_mask is not None:
206
+ if attention_mask.dim() == 2:
207
+ sdpa_mask = attention_mask[:, None, None, :]
208
+ else:
209
+ sdpa_mask = attention_mask
210
+
211
+ attn_output = F.scaled_dot_product_attention(
212
+ q, k, v, attn_mask=sdpa_mask, dropout_p=0.0
213
+ )
214
+ attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, -1)
215
+
216
+ return self.to_out(attn_output)
217
+
218
+
219
+ class FeedForward(nn.Module):
220
+ """Feed-forward network with GELU activation."""
221
+
222
+ def __init__(self, dim: int, dropout: float = 0.0, mult: int = 4):
223
+ super().__init__()
224
+ inner_dim = dim * mult
225
+ self.net = nn.Sequential(
226
+ nn.Linear(dim, inner_dim),
227
+ nn.GELU(approximate="tanh"),
228
+ nn.Dropout(dropout),
229
+ nn.Linear(inner_dim, dim),
230
+ nn.Dropout(dropout),
231
+ )
232
+
233
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
234
+ return self.net(x)
235
+
236
+
237
+ class BasicTransformerBlock(nn.Module):
238
+ """Transformer block with self/cross-attention, optional AdaLayerNorm, and feed-forward.
239
+
240
+ When cross_attention_dim is set, the attention block performs cross-attention
241
+ to encoder_hidden_states. Otherwise, it performs self-attention.
242
+ """
243
+
244
+ def __init__(
245
+ self,
246
+ dim: int,
247
+ num_attention_heads: int,
248
+ attention_head_dim: int,
249
+ dropout: float = 0.0,
250
+ cross_attention_dim: Optional[int] = None,
251
+ norm_type: str = "ada_norm",
252
+ final_dropout: bool = False,
253
+ attn_implementation: str = "sdpa",
254
+ ):
255
+ super().__init__()
256
+ self.norm_type = norm_type
257
+
258
+ if norm_type == "ada_norm":
259
+ self.norm1 = AdaLayerNorm(dim)
260
+ else:
261
+ self.norm1 = nn.LayerNorm(dim)
262
+
263
+ self.attn1 = DiTAttention(
264
+ query_dim=dim,
265
+ num_heads=num_attention_heads,
266
+ head_dim=attention_head_dim,
267
+ cross_attention_dim=cross_attention_dim,
268
+ dropout=dropout,
269
+ attn_implementation=attn_implementation,
270
+ )
271
+
272
+ self.norm3 = nn.LayerNorm(dim)
273
+ self.ff = FeedForward(dim, dropout=dropout)
274
+ self.final_dropout = nn.Dropout(dropout) if final_dropout else None
275
+
276
+ def forward(
277
+ self,
278
+ hidden_states: torch.Tensor,
279
+ encoder_hidden_states: Optional[torch.Tensor] = None,
280
+ encoder_attention_mask: Optional[torch.Tensor] = None,
281
+ temb: Optional[torch.Tensor] = None,
282
+ ) -> torch.Tensor:
283
+ if self.norm_type == "ada_norm":
284
+ norm_hidden_states = self.norm1(hidden_states, temb)
285
+ else:
286
+ norm_hidden_states = self.norm1(hidden_states)
287
+
288
+ attn_output = self.attn1(
289
+ norm_hidden_states,
290
+ encoder_hidden_states=encoder_hidden_states,
291
+ attention_mask=encoder_attention_mask,
292
+ )
293
+
294
+ if self.final_dropout is not None:
295
+ attn_output = self.final_dropout(attn_output)
296
+
297
+ hidden_states = attn_output + hidden_states
298
+
299
+ norm_hidden_states = self.norm3(hidden_states)
300
+ ff_output = self.ff(norm_hidden_states)
301
+ hidden_states = ff_output + hidden_states
302
+
303
+ return hidden_states
304
+
305
+
306
+ class DiT(nn.Module):
307
+ """Diffusion Transformer with cross-attention to VLM context features.
308
+
309
+ Interleaves cross-attention blocks (attending to encoder_hidden_states)
310
+ with self-attention blocks when interleave_self_attention=True.
311
+ Uses AdaLayerNorm for timestep conditioning throughout.
312
+
313
+ Output block applies timestep-conditioned scale-shift before final projection.
314
+ """
315
+
316
+ def __init__(
317
+ self,
318
+ num_attention_heads: int = 12,
319
+ attention_head_dim: int = 64,
320
+ output_dim: int = 768,
321
+ num_layers: int = 12,
322
+ dropout: float = 0.1,
323
+ norm_type: str = "ada_norm",
324
+ final_dropout: bool = True,
325
+ interleave_self_attention: bool = False,
326
+ cross_attention_dim: Optional[int] = None,
327
+ attn_implementation: str = "sdpa",
328
+ ):
329
+ super().__init__()
330
+ self.inner_dim = num_attention_heads * attention_head_dim
331
+ self.output_dim = output_dim
332
+ self.num_layers = num_layers
333
+ self.interleave_self_attention = interleave_self_attention
334
+
335
+ self.timestep_encoder = TimestepEncoder(self.inner_dim)
336
+
337
+ all_blocks = []
338
+ for idx in range(num_layers):
339
+ use_self_attn = idx % 2 == 1 and interleave_self_attention
340
+ curr_cross_attention_dim = cross_attention_dim if not use_self_attn else None
341
+
342
+ all_blocks.append(
343
+ BasicTransformerBlock(
344
+ dim=self.inner_dim,
345
+ num_attention_heads=num_attention_heads,
346
+ attention_head_dim=attention_head_dim,
347
+ dropout=dropout,
348
+ cross_attention_dim=curr_cross_attention_dim,
349
+ norm_type=norm_type,
350
+ final_dropout=final_dropout,
351
+ attn_implementation=attn_implementation,
352
+ )
353
+ )
354
+ self.transformer_blocks = nn.ModuleList(all_blocks)
355
+
356
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
357
+ self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
358
+ self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
359
+
360
+ def forward(
361
+ self,
362
+ hidden_states: torch.Tensor,
363
+ encoder_hidden_states: torch.Tensor,
364
+ timestep: torch.LongTensor,
365
+ encoder_attention_mask: Optional[torch.Tensor] = None,
366
+ ) -> torch.Tensor:
367
+ temb = self.timestep_encoder(timestep)
368
+
369
+ hidden_states = hidden_states.contiguous()
370
+ encoder_hidden_states = encoder_hidden_states.contiguous()
371
+
372
+ for idx, block in enumerate(self.transformer_blocks):
373
+ if idx % 2 == 1 and self.interleave_self_attention:
374
+ hidden_states = block(
375
+ hidden_states,
376
+ encoder_hidden_states=None,
377
+ encoder_attention_mask=None,
378
+ temb=temb,
379
+ )
380
+ else:
381
+ hidden_states = block(
382
+ hidden_states,
383
+ encoder_hidden_states=encoder_hidden_states,
384
+ encoder_attention_mask=encoder_attention_mask,
385
+ temb=temb,
386
+ )
387
+
388
+ conditioning = temb
389
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=-1)
390
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
391
+ return self.proj_out_2(hidden_states)
392
+
393
+
394
+ class AlternateVLDiT(DiT):
395
+ """DiT variant that separates visual and text tokens during cross-attention.
396
+
397
+ Mirrors GR00T's AlternateVLDiT: even-indexed blocks do cross-attention,
398
+ alternating every ``attend_text_every_n_blocks`` between text tokens and
399
+ visual tokens. Odd-indexed blocks do self-attention (requires
400
+ ``interleave_self_attention=True``).
401
+
402
+ When no visual tokens are present (``image_mask`` is None or all-False),
403
+ all valid tokens are treated as text.
404
+ """
405
+
406
+ def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs):
407
+ super().__init__(*args, **kwargs)
408
+ assert self.interleave_self_attention, (
409
+ "AlternateVLDiT requires interleave_self_attention=True"
410
+ )
411
+ self.attend_text_every_n_blocks = attend_text_every_n_blocks
412
+
413
+ def forward(
414
+ self,
415
+ hidden_states: torch.Tensor,
416
+ encoder_hidden_states: torch.Tensor,
417
+ timestep: torch.LongTensor,
418
+ encoder_attention_mask: Optional[torch.Tensor] = None,
419
+ image_mask: Optional[torch.Tensor] = None,
420
+ ) -> torch.Tensor:
421
+ """
422
+ Args:
423
+ encoder_attention_mask: (B, S) bool – True = valid VLM token.
424
+ image_mask: (B, S) bool – True = visual token position.
425
+ If None, all valid tokens are treated as text.
426
+ """
427
+ temb = self.timestep_encoder(timestep)
428
+ hidden_states = hidden_states.contiguous()
429
+ encoder_hidden_states = encoder_hidden_states.contiguous()
430
+
431
+ B, S, _ = encoder_hidden_states.shape
432
+ backbone_mask = (
433
+ encoder_attention_mask.bool()
434
+ if encoder_attention_mask is not None
435
+ else torch.ones(B, S, dtype=torch.bool, device=hidden_states.device)
436
+ )
437
+
438
+ if image_mask is not None and image_mask.any():
439
+ vis_mask = image_mask.bool() & backbone_mask # visual tokens
440
+ text_mask = (~image_mask.bool()) & backbone_mask # text tokens
441
+ else:
442
+ # No visual tokens – treat everything as text.
443
+ vis_mask = torch.zeros_like(backbone_mask)
444
+ text_mask = backbone_mask
445
+
446
+ for idx, block in enumerate(self.transformer_blocks):
447
+ if idx % 2 == 1:
448
+ # Self-attention block.
449
+ hidden_states = block(
450
+ hidden_states,
451
+ encoder_hidden_states=None,
452
+ encoder_attention_mask=None,
453
+ temb=temb,
454
+ )
455
+ else:
456
+ # Cross-attention block: alternate text / visual every N blocks.
457
+ if idx % (2 * self.attend_text_every_n_blocks) == 0:
458
+ curr_mask = text_mask
459
+ else:
460
+ curr_mask = vis_mask
461
+ hidden_states = block(
462
+ hidden_states,
463
+ encoder_hidden_states=encoder_hidden_states,
464
+ encoder_attention_mask=curr_mask,
465
+ temb=temb,
466
+ )
467
+
468
+ conditioning = temb
469
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=-1)
470
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
471
+ return self.proj_out_2(hidden_states)
472
+
473
+
474
+ class ActionEncoder(nn.Module):
475
+ """Encodes noisy actions (optionally concatenated with DOF mask) and timestep
476
+ into hidden features via MLP + sinusoidal time encoding.
477
+
478
+ Architecture: Linear → concat(action_emb, time_emb) → SiLU + Linear → Linear
479
+ """
480
+
481
+ def __init__(self, action_input_dim: int, hidden_size: int):
482
+ super().__init__()
483
+ self.hidden_size = hidden_size
484
+ self.layer1 = nn.Linear(action_input_dim, hidden_size)
485
+ self.layer2 = nn.Linear(2 * hidden_size, hidden_size)
486
+ self.layer3 = nn.Linear(hidden_size, hidden_size)
487
+ self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
488
+
489
+ def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
490
+ """
491
+ Args:
492
+ actions: (B, T, action_input_dim) noisy actions (+ DOF mask)
493
+ timesteps: (B,) discretized timesteps
494
+ """
495
+ B, T, _ = actions.shape
496
+ timesteps_expanded = timesteps.unsqueeze(1).expand(-1, T)
497
+
498
+ a_emb = self.layer1(actions)
499
+ tau_emb = self.pos_encoding(timesteps_expanded).to(dtype=a_emb.dtype)
500
+
501
+ x = torch.cat([a_emb, tau_emb], dim=-1)
502
+ x = F.silu(self.layer2(x))
503
+ x = self.layer3(x)
504
+ return x
505
+
506
+
507
+ class ActionDecoder(nn.Module):
508
+ """2-layer MLP that decodes DiT output to action-space velocity."""
509
+
510
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
511
+ super().__init__()
512
+ self.layer1 = nn.Linear(input_dim, hidden_dim)
513
+ self.layer2 = nn.Linear(hidden_dim, output_dim)
514
+
515
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
516
+ return self.layer2(F.relu(self.layer1(x)))
517
+
518
+
519
+ class FlowMatchingDiTHead(nn.Module):
520
+ """Flow matching action head using DiT (Diffusion Transformer).
521
+
522
+ Replaces the fm_action_expert (Qwen3VLTextModel-based) with a DiT that uses
523
+ explicit cross-attention to VLM hidden states instead of KV cache continuation.
524
+
525
+ Training:
526
+ 1. Sample noise and timestep from Beta distribution
527
+ 2. Compute noisy trajectory: x_t = (1-t)*noise + t*actions
528
+ 3. Compute velocity target: v = actions - noise
529
+ 4. Encode noisy actions + DOF mask + timestep → action features
530
+ 5. Prepend learned future query tokens
531
+ 6. Run DiT with cross-attention to VLM hidden states
532
+ 7. Decode to action-space velocity prediction
533
+
534
+ Inference:
535
+ Euler integration from pure noise (t=0) to clean actions (t=1)
536
+ over num_inference_timesteps steps.
537
+ """
538
+
539
+ def __init__(
540
+ self,
541
+ action_dim: int,
542
+ action_chunk_size: int,
543
+ cross_attention_dim: int,
544
+ num_inference_timesteps: int = 4,
545
+ config: Optional[dict] = None,
546
+ ):
547
+ super().__init__()
548
+ cfg = {
549
+ "num_layers": 16,
550
+ "num_attention_heads": 12,
551
+ "attention_head_dim": 64,
552
+ "output_dim": 1024,
553
+ "dropout": 0.2,
554
+ "interleave_self_attention": True,
555
+ "norm_type": "ada_norm",
556
+ "final_dropout": True,
557
+ "add_pos_embed": True,
558
+ "noise_beta_alpha": 1.5,
559
+ "noise_beta_beta": 1.0,
560
+ "noise_s": 0.999,
561
+ "num_timestep_buckets": 1000,
562
+ "attn_implementation": "sdpa",
563
+ "use_alternate_vl_dit": False,
564
+ "attend_text_every_n_blocks": 2,
565
+ }
566
+ if config is not None:
567
+ cfg.update(config)
568
+ # dit_model_type = config.get("dit_model_type")
569
+ # if dit_model_type and dit_model_type in DIT_PRESETS:
570
+ # cfg.update(DIT_PRESETS[dit_model_type])
571
+ # cfg.pop("dit_model_type", None)
572
+
573
+ self.action_dim = action_dim
574
+ self.action_chunk_size = action_chunk_size
575
+ self.num_inference_timesteps = num_inference_timesteps
576
+ self.num_timestep_buckets = cfg["num_timestep_buckets"]
577
+ self.noise_s = cfg["noise_s"]
578
+ self.use_alternate_vl_dit = cfg["use_alternate_vl_dit"]
579
+ self.add_pos_embed = cfg["add_pos_embed"]
580
+
581
+ num_attention_heads = cfg["num_attention_heads"]
582
+ attention_head_dim = cfg["attention_head_dim"]
583
+ output_dim = cfg["output_dim"]
584
+ inner_dim = num_attention_heads * attention_head_dim
585
+
586
+ dit_kwargs = dict(
587
+ num_attention_heads=num_attention_heads,
588
+ attention_head_dim=attention_head_dim,
589
+ output_dim=output_dim,
590
+ num_layers=cfg["num_layers"],
591
+ dropout=cfg["dropout"],
592
+ norm_type=cfg["norm_type"],
593
+ final_dropout=cfg["final_dropout"],
594
+ interleave_self_attention=cfg["interleave_self_attention"],
595
+ cross_attention_dim=cross_attention_dim,
596
+ attn_implementation=cfg["attn_implementation"],
597
+ )
598
+ if self.use_alternate_vl_dit:
599
+ self.dit = AlternateVLDiT(
600
+ **dit_kwargs,
601
+ attend_text_every_n_blocks=cfg["attend_text_every_n_blocks"],
602
+ )
603
+ else:
604
+ self.dit = DiT(**dit_kwargs)
605
+
606
+ # action_dim * 2: noisy action + DOF mask concatenated
607
+ self.action_encoder = ActionEncoder(action_dim * 2, inner_dim)
608
+ self.action_decoder = ActionDecoder(output_dim, inner_dim, action_dim)
609
+
610
+ if self.add_pos_embed:
611
+ max_seq_len = max(action_chunk_size, 256)
612
+ self.position_embedding = nn.Embedding(max_seq_len, inner_dim)
613
+ nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
614
+
615
+ # self.beta_dist = Beta(cfg["noise_beta_alpha"], cfg["noise_beta_beta"])
616
+ self._beta_alpha = cfg["noise_beta_alpha"]
617
+ self._beta_beta = cfg["noise_beta_beta"]
618
+
619
+ def reset_parameters(self):
620
+ """Re-apply proper initialization.
621
+
622
+ HuggingFace from_pretrained calls _init_weights on modules whose
623
+ parameters are absent from the checkpoint, overwriting any custom
624
+ init done in __init__. Call this after from_pretrained when loading
625
+ from a base VLM checkpoint that does not contain DiT weights.
626
+ """
627
+ if self.add_pos_embed:
628
+ nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
629
+ for module in self.modules():
630
+ if isinstance(module, nn.Linear):
631
+ nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
632
+ if module.bias is not None:
633
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
634
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
635
+ nn.init.uniform_(module.bias, -bound, bound)
636
+ elif isinstance(module, nn.LayerNorm):
637
+ if module.elementwise_affine:
638
+ nn.init.ones_(module.weight)
639
+ nn.init.zeros_(module.bias)
640
+
641
+ def sample_time(self, batch_size: int, device, dtype) -> torch.Tensor:
642
+ beta_dist = Beta(self._beta_alpha, self._beta_beta)
643
+ sample = beta_dist.sample([batch_size]).to(device, dtype=dtype).clamp(max=self.noise_s)
644
+ return (self.noise_s - sample) / self.noise_s
645
+
646
+ def _encode_actions(
647
+ self,
648
+ noisy_actions: torch.Tensor,
649
+ t_discretized: torch.Tensor,
650
+ action_dof_mask: Optional[torch.Tensor],
651
+ device,
652
+ ) -> torch.Tensor:
653
+ """Encode noisy actions with DOF mask and timestep, add position embeddings."""
654
+ if action_dof_mask is not None:
655
+ encoder_input = torch.cat(
656
+ [noisy_actions, action_dof_mask.to(noisy_actions.dtype)], dim=-1
657
+ )
658
+ else:
659
+ encoder_input = torch.cat(
660
+ [noisy_actions, torch.ones_like(noisy_actions)], dim=-1
661
+ )
662
+
663
+ action_features = self.action_encoder(encoder_input, t_discretized)
664
+
665
+ if self.add_pos_embed:
666
+ pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
667
+ pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
668
+ action_features = action_features + pos_embs
669
+
670
+ return action_features
671
+
672
+ def _dit_forward(
673
+ self,
674
+ sa_embs: torch.Tensor,
675
+ vl_embs: torch.Tensor,
676
+ t_discretized: torch.LongTensor,
677
+ encoder_attention_mask: Optional[torch.Tensor],
678
+ image_mask: Optional[torch.Tensor],
679
+ ) -> torch.Tensor:
680
+ if self.use_alternate_vl_dit:
681
+ return self.dit(
682
+ hidden_states=sa_embs,
683
+ encoder_hidden_states=vl_embs,
684
+ timestep=t_discretized,
685
+ encoder_attention_mask=encoder_attention_mask,
686
+ image_mask=image_mask,
687
+ )
688
+ return self.dit(
689
+ hidden_states=sa_embs,
690
+ encoder_hidden_states=vl_embs,
691
+ timestep=t_discretized,
692
+ encoder_attention_mask=encoder_attention_mask,
693
+ )
694
+
695
+ def forward(
696
+ self,
697
+ vl_embs: torch.Tensor,
698
+ actions: torch.Tensor,
699
+ action_dof_mask: Optional[torch.Tensor] = None,
700
+ encoder_attention_mask: Optional[torch.Tensor] = None,
701
+ image_mask: Optional[torch.Tensor] = None,
702
+ ) -> tuple:
703
+ """Training forward pass.
704
+
705
+ Args:
706
+ vl_embs: (B, S, D) VLM hidden states for cross-attention
707
+ actions: (B, T, action_dim) ground truth action trajectories
708
+ action_dof_mask: (B, T, action_dim) DOF validity mask
709
+ encoder_attention_mask: (B, S) bool – True = valid VLM token
710
+ image_mask: (B, S) bool – True = visual token (used by AlternateVLDiT)
711
+
712
+ Returns:
713
+ (pred_v, velocity): predicted velocity and target velocity, both (B, T, action_dim)
714
+ """
715
+ device = vl_embs.device
716
+ B = actions.shape[0]
717
+
718
+ noise = torch.randn(actions.shape, device=device, dtype=actions.dtype)
719
+ t = self.sample_time(B, device=device, dtype=actions.dtype)
720
+ t_expanded = t[:, None, None]
721
+
722
+ noisy_trajectory = (1 - t_expanded) * noise + t_expanded * actions
723
+ velocity = actions - noise
724
+
725
+ t_discretized = (t * self.num_timestep_buckets).long()
726
+
727
+ action_features = self._encode_actions(noisy_trajectory, t_discretized, action_dof_mask, device)
728
+
729
+ model_output = self._dit_forward(
730
+ action_features, vl_embs, t_discretized, encoder_attention_mask, image_mask
731
+ )
732
+
733
+ pred = self.action_decoder(model_output)
734
+ pred_v = pred[:, :actions.shape[1]]
735
+
736
+ return pred_v, velocity
737
+
738
+ @torch.no_grad()
739
+ def predict_action(
740
+ self,
741
+ vl_embs: torch.Tensor,
742
+ action_dof_mask: Optional[torch.Tensor] = None,
743
+ encoder_attention_mask: Optional[torch.Tensor] = None,
744
+ image_mask: Optional[torch.Tensor] = None,
745
+ ) -> torch.Tensor:
746
+ """Inference: denoise actions from noise using Euler integration.
747
+
748
+ Args:
749
+ vl_embs: (B, S, D) VLM hidden states
750
+ action_dof_mask: optional (B, T, action_dim) or (1, T, action_dim) DOF mask
751
+ encoder_attention_mask: (B, S) bool – True = valid VLM token
752
+ image_mask: (B, S) bool – True = visual token (used by AlternateVLDiT)
753
+
754
+ Returns:
755
+ (B, T, action_dim) denoised action trajectories
756
+ """
757
+ B = vl_embs.shape[0]
758
+ device = vl_embs.device
759
+ dtype = vl_embs.dtype
760
+
761
+ actions = torch.randn(
762
+ (B, self.action_chunk_size, self.action_dim),
763
+ device=device, dtype=dtype,
764
+ )
765
+
766
+ dt = 1.0 / self.num_inference_timesteps
767
+
768
+ for step in range(self.num_inference_timesteps):
769
+ t_cont = step / float(self.num_inference_timesteps)
770
+ t_discretized_val = int(t_cont * self.num_timestep_buckets)
771
+ timesteps_tensor = torch.full((B,), t_discretized_val, device=device, dtype=torch.long)
772
+
773
+ action_features = self._encode_actions(actions, timesteps_tensor, action_dof_mask, device)
774
+
775
+ model_output = self._dit_forward(
776
+ action_features, vl_embs, timesteps_tensor, encoder_attention_mask, image_mask
777
+ )
778
+
779
+ pred = self.action_decoder(model_output)
780
+ pred_velocity = pred[:, :self.action_chunk_size]
781
+
782
+ actions = actions + dt * pred_velocity
783
+
784
+ return actions
785
+
786
+
787
+ # ============================================================================
788
+ # Pi0.5-style KV-cache action expert (VLM K/V concat + GQA + SwiGLU FFN)
789
+ # ============================================================================
790
+ class AdaRMSNorm(nn.Module):
791
+ """Adaptive RMS normalization: (scale, shift, gate) from cond; zero-init."""
792
+
793
+ def __init__(self, dim: int, eps: float = 1e-6):
794
+ super().__init__()
795
+ self.eps = eps
796
+ self.modulation = nn.Linear(dim, dim * 3)
797
+ nn.init.zeros_(self.modulation.weight)
798
+ nn.init.zeros_(self.modulation.bias)
799
+
800
+ def forward(self, x: torch.Tensor, cond: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
801
+ var = x.float().pow(2).mean(-1, keepdim=True)
802
+ normed = (x * torch.rsqrt(var + self.eps)).to(x.dtype)
803
+ scale, shift, gate = self.modulation(cond).chunk(3, dim=-1)
804
+ normed = normed * (1 + scale[:, None]) + shift[:, None]
805
+ return normed, gate[:, None]
806
+
807
+
808
+ class SwiGLUFeedForward(nn.Module):
809
+ """SiLU(gate_proj(x)) * up_proj(x) → down_proj."""
810
+
811
+ def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0, bias: bool = True):
812
+ super().__init__()
813
+ self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias)
814
+ self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
815
+ self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
816
+ self.dropout = nn.Dropout(dropout)
817
+
818
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
819
+ return self.down_proj(self.dropout(F.silu(self.gate_proj(x)) * self.up_proj(x)))
820
+
821
+
822
+ class MoTAttention(nn.Module):
823
+ """Action Q attends to concatenated [VLM KV cache ; action KV]; GQA expand for SDPA."""
824
+
825
+ def __init__(
826
+ self,
827
+ hidden_size: int,
828
+ num_attention_heads: int,
829
+ num_kv_heads: int,
830
+ head_dim: int,
831
+ dropout: float = 0.0,
832
+ bias: bool = True,
833
+ ):
834
+ super().__init__()
835
+ if num_attention_heads % num_kv_heads != 0:
836
+ raise ValueError(
837
+ f"num_attention_heads ({num_attention_heads}) must be divisible by "
838
+ f"num_kv_heads ({num_kv_heads})"
839
+ )
840
+ self.num_attention_heads = num_attention_heads
841
+ self.num_kv_heads = num_kv_heads
842
+ self.head_dim = head_dim
843
+ q_dim = num_attention_heads * head_dim
844
+ kv_dim = num_kv_heads * head_dim
845
+ self.q_proj = nn.Linear(hidden_size, q_dim, bias=bias)
846
+ self.k_proj = nn.Linear(hidden_size, kv_dim, bias=bias)
847
+ self.v_proj = nn.Linear(hidden_size, kv_dim, bias=bias)
848
+ self.o_proj = nn.Linear(q_dim, hidden_size, bias=bias)
849
+ self.dropout = nn.Dropout(dropout)
850
+
851
+ def forward(
852
+ self,
853
+ action_hidden: torch.Tensor,
854
+ vlm_cached_k: torch.Tensor,
855
+ vlm_cached_v: torch.Tensor,
856
+ vlm_attention_mask: Optional[torch.Tensor] = None,
857
+ ) -> torch.Tensor:
858
+ B, T_a, _ = action_hidden.shape
859
+
860
+ q = self.q_proj(action_hidden)
861
+ act_k = self.k_proj(action_hidden)
862
+ act_v = self.v_proj(action_hidden)
863
+
864
+ q = q.view(B, T_a, self.num_attention_heads, self.head_dim).transpose(1, 2)
865
+ act_k = act_k.view(B, T_a, self.num_kv_heads, self.head_dim).transpose(1, 2)
866
+ act_v = act_v.view(B, T_a, self.num_kv_heads, self.head_dim).transpose(1, 2)
867
+
868
+ k = torch.cat([vlm_cached_k, act_k], dim=2)
869
+ v = torch.cat([vlm_cached_v, act_v], dim=2)
870
+
871
+ repeat_factor = self.num_attention_heads // self.num_kv_heads
872
+ k = k.repeat_interleave(repeat_factor, dim=1)
873
+ v = v.repeat_interleave(repeat_factor, dim=1)
874
+
875
+ sdpa_mask = None
876
+ if vlm_attention_mask is not None:
877
+ action_mask = vlm_attention_mask.new_ones(B, T_a)
878
+ combined_mask = torch.cat([vlm_attention_mask, action_mask], dim=1)
879
+ sdpa_mask = combined_mask[:, None, None, :]
880
+
881
+ attn_out = F.scaled_dot_product_attention(
882
+ q, k, v, attn_mask=sdpa_mask, dropout_p=0.0,
883
+ )
884
+ attn_out = attn_out.transpose(1, 2).contiguous().view(B, T_a, -1)
885
+ return self.dropout(self.o_proj(attn_out))
886
+
887
+
888
+ class MoTBlock(nn.Module):
889
+ """AdaRMSNorm → attention → gated residual → AdaRMSNorm → SwiGLU FFN → gated residual."""
890
+
891
+ def __init__(
892
+ self,
893
+ hidden_size: int,
894
+ num_attention_heads: int,
895
+ num_kv_heads: int,
896
+ head_dim: int,
897
+ intermediate_size: int,
898
+ dropout: float = 0.0,
899
+ ):
900
+ super().__init__()
901
+ self.pre_attn_norm = AdaRMSNorm(hidden_size)
902
+ self.attn = MoTAttention(
903
+ hidden_size=hidden_size,
904
+ num_attention_heads=num_attention_heads,
905
+ num_kv_heads=num_kv_heads,
906
+ head_dim=head_dim,
907
+ dropout=dropout,
908
+ )
909
+ self.pre_ffn_norm = AdaRMSNorm(hidden_size)
910
+ self.ffn = SwiGLUFeedForward(hidden_size, intermediate_size, dropout=dropout)
911
+
912
+ def forward(
913
+ self,
914
+ action_hidden: torch.Tensor,
915
+ vlm_cached_k: torch.Tensor,
916
+ vlm_cached_v: torch.Tensor,
917
+ adarms_cond: torch.Tensor,
918
+ vlm_attention_mask: Optional[torch.Tensor] = None,
919
+ ) -> torch.Tensor:
920
+ normed, gate1 = self.pre_attn_norm(action_hidden, adarms_cond)
921
+ attn_out = self.attn(normed, vlm_cached_k, vlm_cached_v, vlm_attention_mask)
922
+ action_hidden = action_hidden + attn_out * gate1
923
+
924
+ normed2, gate2 = self.pre_ffn_norm(action_hidden, adarms_cond)
925
+ action_hidden = action_hidden + self.ffn(normed2) * gate2
926
+ return action_hidden
927
+
928
+
929
+ class MoTDiT(nn.Module):
930
+ """Stack of ActionBlocks; each block uses one VLM layer's KV pair."""
931
+
932
+ def __init__(
933
+ self,
934
+ hidden_size: int,
935
+ num_attention_heads: int,
936
+ num_kv_heads: int,
937
+ head_dim: int,
938
+ intermediate_size: int,
939
+ num_layers: int,
940
+ dropout: float = 0.2,
941
+ ):
942
+ super().__init__()
943
+ self.num_layers = num_layers
944
+ self.blocks = nn.ModuleList([
945
+ MoTBlock(
946
+ hidden_size=hidden_size,
947
+ num_attention_heads=num_attention_heads,
948
+ num_kv_heads=num_kv_heads,
949
+ head_dim=head_dim,
950
+ intermediate_size=intermediate_size,
951
+ dropout=dropout,
952
+ )
953
+ for _ in range(num_layers)
954
+ ])
955
+ self.final_norm = AdaRMSNorm(hidden_size)
956
+
957
+ def forward(
958
+ self,
959
+ action_hidden: torch.Tensor,
960
+ vlm_kv_cache: list,
961
+ adarms_cond: torch.Tensor,
962
+ vlm_attention_mask: Optional[torch.Tensor] = None,
963
+ ) -> torch.Tensor:
964
+ for idx, block in enumerate(self.blocks):
965
+ cached_k, cached_v = vlm_kv_cache[idx]
966
+ action_hidden = block(
967
+ action_hidden, cached_k, cached_v, adarms_cond, vlm_attention_mask,
968
+ )
969
+ action_hidden, _ = self.final_norm(action_hidden, adarms_cond)
970
+ return action_hidden
971
+
972
+
973
+ def _kv_pairs_from_past_key_values(past_key_values: Cache) -> list[tuple[torch.Tensor, torch.Tensor]]:
974
+ """Per-layer (K, V) from a HuggingFace decoder KV cache (order matches transformer layers)."""
975
+ return [
976
+ (past_key_values[i][0], past_key_values[i][1])
977
+ for i in range(len(past_key_values))
978
+ ]
979
+
980
+
981
+ class MoTFlowMatchingHead(nn.Module):
982
+ """Flow matching head: MoT-style action expert over VLM KV cache (concat + GQA)."""
983
+
984
+ def __init__(
985
+ self,
986
+ action_dim: int,
987
+ action_chunk_size: int,
988
+ vlm_config,
989
+ num_inference_timesteps: int = 10,
990
+ config: Optional[dict] = None,
991
+ ):
992
+ super().__init__()
993
+
994
+ _vlm_num_q_heads = 8 # vlm_config.num_attention_heads // 2 # optional: 8
995
+ _vlm_num_kv_heads = vlm_config.num_key_value_heads # 8
996
+ _vlm_head_dim = getattr(
997
+ vlm_config, "head_dim", vlm_config.hidden_size // vlm_config.num_attention_heads
998
+ ) # 128
999
+
1000
+ cfg = {
1001
+ "hidden_size": 1024, # vlm_config.hidden_size // 2,
1002
+ # "hidden_size": vlm_config.hidden_size // 2,
1003
+ "intermediate_size": vlm_config.intermediate_size // 4,
1004
+ "expert_num_layers": vlm_config.num_hidden_layers,
1005
+ # Attention dims default to VLM values (required for KV cache compat)
1006
+ "num_attention_heads": _vlm_num_q_heads,
1007
+ "num_kv_heads": _vlm_num_kv_heads,
1008
+ "head_dim": _vlm_head_dim,
1009
+ # Noise schedule
1010
+ "dropout": 0.2,
1011
+ "add_pos_embed": True,
1012
+ "noise_beta_alpha": 1.5,
1013
+ "noise_beta_beta": 1.0,
1014
+ "noise_s": 0.999,
1015
+ "num_timestep_buckets": 1000,
1016
+ }
1017
+ if config is not None:
1018
+ config = cfg.copy()
1019
+
1020
+ num_attention_heads = cfg["num_attention_heads"]
1021
+ num_kv_heads = cfg["num_kv_heads"]
1022
+ head_dim = cfg["head_dim"]
1023
+ hidden_size = cfg["hidden_size"]
1024
+ intermediate_size = cfg["intermediate_size"]
1025
+ num_layers = cfg["expert_num_layers"]
1026
+
1027
+ self.action_dim = action_dim
1028
+ self.action_chunk_size = action_chunk_size
1029
+ self.num_inference_timesteps = num_inference_timesteps
1030
+ self.num_timestep_buckets = cfg["num_timestep_buckets"]
1031
+ self.noise_s = cfg["noise_s"]
1032
+ self.add_pos_embed = cfg["add_pos_embed"]
1033
+
1034
+ self.action_in_proj = nn.Linear(action_dim * 2, hidden_size)
1035
+ self.action_out_proj = nn.Linear(hidden_size, action_dim)
1036
+
1037
+ self.time_sinusoidal = SinusoidalPositionalEncoding(hidden_size)
1038
+ self.time_mlp_1 = nn.Linear(hidden_size, hidden_size)
1039
+ self.time_mlp_2 = nn.Linear(hidden_size, hidden_size)
1040
+
1041
+ if self.add_pos_embed:
1042
+ max_seq = max(action_chunk_size, 256)
1043
+ self.position_embedding = nn.Embedding(max_seq, hidden_size)
1044
+ nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
1045
+
1046
+ self.dit = MoTDiT(
1047
+ hidden_size=hidden_size,
1048
+ num_attention_heads=num_attention_heads,
1049
+ num_kv_heads=num_kv_heads,
1050
+ head_dim=head_dim,
1051
+ intermediate_size=intermediate_size,
1052
+ num_layers=num_layers,
1053
+ dropout=cfg["dropout"],
1054
+ )
1055
+
1056
+ self._beta_alpha = cfg["noise_beta_alpha"]
1057
+ self._beta_beta = cfg["noise_beta_beta"]
1058
+
1059
+ @property
1060
+ def num_dit_layers(self) -> int:
1061
+ """Number of expert blocks; must match ``len(past_key_values.key_cache)``."""
1062
+ return self.dit.num_layers
1063
+
1064
+ def _vlm_kv_list_from_past(self, past_key_values: Cache) -> list[tuple[torch.Tensor, torch.Tensor]]:
1065
+ n = len(past_key_values)
1066
+ if n != self.num_dit_layers:
1067
+ raise ValueError(
1068
+ f"MoT expert has {self.num_dit_layers} blocks but `past_key_values` has {n} "
1069
+ "layers. Set `dit_action_head_config['expert_num_layers']` to match "
1070
+ "`text_config.num_hidden_layers`."
1071
+ )
1072
+ return _kv_pairs_from_past_key_values(past_key_values)
1073
+
1074
+ def reset_parameters(self):
1075
+ """Re-apply proper initialization after from_pretrained."""
1076
+ if self.add_pos_embed:
1077
+ nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
1078
+ for module in self.modules():
1079
+ if isinstance(module, AdaRMSNorm):
1080
+ nn.init.zeros_(module.modulation.weight)
1081
+ nn.init.zeros_(module.modulation.bias)
1082
+ elif isinstance(module, nn.Linear):
1083
+ nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
1084
+ if module.bias is not None:
1085
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
1086
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
1087
+ nn.init.uniform_(module.bias, -bound, bound)
1088
+
1089
+ def _compute_adarms_cond(self, t_discretized: torch.Tensor) -> torch.Tensor:
1090
+ t_emb = self.time_sinusoidal(t_discretized.float())
1091
+ t_emb = t_emb.to(dtype=self.time_mlp_1.weight.dtype)
1092
+ t_emb = F.silu(self.time_mlp_1(t_emb))
1093
+ t_emb = F.silu(self.time_mlp_2(t_emb))
1094
+ return t_emb
1095
+
1096
+ def sample_time(self, batch_size: int, device, dtype) -> torch.Tensor:
1097
+ beta_dist = Beta(self._beta_alpha, self._beta_beta)
1098
+ sample = beta_dist.sample([batch_size]).to(device, dtype=dtype).clamp(max=self.noise_s)
1099
+ return (self.noise_s - sample) / self.noise_s
1100
+
1101
+ def _prepare_action_embeds(
1102
+ self,
1103
+ noisy_actions: torch.Tensor,
1104
+ action_dof_mask: Optional[torch.Tensor],
1105
+ ) -> torch.Tensor:
1106
+ if action_dof_mask is not None:
1107
+ x = torch.cat(
1108
+ [noisy_actions, action_dof_mask.to(noisy_actions.dtype)], dim=-1,
1109
+ )
1110
+ else:
1111
+ x = torch.cat([noisy_actions, torch.ones_like(noisy_actions)], dim=-1)
1112
+
1113
+ tokens = self.action_in_proj(x)
1114
+
1115
+ if self.add_pos_embed:
1116
+ pos_ids = torch.arange(tokens.shape[1], dtype=torch.long, device=noisy_actions.device)
1117
+ tokens = tokens + self.position_embedding(pos_ids).unsqueeze(0)
1118
+
1119
+ return tokens
1120
+
1121
+ def forward(
1122
+ self,
1123
+ past_key_values: Cache,
1124
+ actions: torch.Tensor,
1125
+ action_dof_mask: Optional[torch.Tensor] = None,
1126
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1127
+ ) -> tuple:
1128
+ """Training: returns (pred_velocity, target_velocity).
1129
+
1130
+ Args:
1131
+ past_key_values: VLM decoder KV cache; layer count must equal ``num_dit_layers``.
1132
+ """
1133
+ vlm_kv_cache = self._vlm_kv_list_from_past(past_key_values)
1134
+ device = actions.device
1135
+ B = actions.shape[0]
1136
+
1137
+ noise = torch.randn(actions.shape, device=device, dtype=actions.dtype)
1138
+ t = self.sample_time(B, device=device, dtype=actions.dtype)
1139
+ t_expanded = t[:, None, None]
1140
+
1141
+ noisy_trajectory = (1 - t_expanded) * noise + t_expanded * actions
1142
+ velocity = actions - noise
1143
+
1144
+ t_discretized = (t * self.num_timestep_buckets).long()
1145
+ adarms_cond = self._compute_adarms_cond(t_discretized)
1146
+
1147
+ action_tokens = self._prepare_action_embeds(noisy_trajectory, action_dof_mask)
1148
+
1149
+ output = self.dit(
1150
+ action_tokens, vlm_kv_cache, adarms_cond, encoder_attention_mask,
1151
+ )
1152
+
1153
+ pred = self.action_out_proj(output)
1154
+ pred_v = pred[:, :actions.shape[1]]
1155
+ return pred_v, velocity
1156
+
1157
+ def compute_velocity(
1158
+ self,
1159
+ past_key_values: Cache,
1160
+ actions: torch.Tensor,
1161
+ noise: torch.Tensor,
1162
+ t: torch.Tensor,
1163
+ action_dof_mask: Optional[torch.Tensor] = None,
1164
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1165
+ ) -> torch.Tensor:
1166
+ """Compute velocity prediction for pre-sampled noise and timestep.
1167
+
1168
+ Used by DiffusionNFT where noise and timestep must be shared between
1169
+ the current policy (v_θ) and the reference policy (v_old).
1170
+
1171
+ Args:
1172
+ past_key_values: VLM decoder KV cache
1173
+ actions: (B, T, action_dim) ground truth actions (x_0)
1174
+ noise: (B, T, action_dim) pre-sampled noise (ε)
1175
+ t: (B,) continuous timesteps in [0, 1)
1176
+ action_dof_mask, encoder_attention_mask,
1177
+
1178
+ Returns:
1179
+ pred_v: (B, T, action_dim) predicted velocity
1180
+ """
1181
+ vlm_kv_cache = self._vlm_kv_list_from_past(past_key_values)
1182
+ device = actions.device
1183
+ t_expanded = t[:, None, None]
1184
+
1185
+ noisy_trajectory = (1 - t_expanded) * noise + t_expanded * actions
1186
+ t_discretized = (t * self.num_timestep_buckets).long()
1187
+ adarms_cond = self._compute_adarms_cond(t_discretized)
1188
+ action_tokens = self._prepare_action_embeds(noisy_trajectory, action_dof_mask)
1189
+ output = self.dit(
1190
+ action_tokens, vlm_kv_cache, adarms_cond, encoder_attention_mask,
1191
+ )
1192
+ pred = self.action_out_proj(output)
1193
+ return pred[:, :actions.shape[1]]
1194
+
1195
+
1196
+ @torch.no_grad()
1197
+ def predict_action(
1198
+ self,
1199
+ past_key_values: Cache,
1200
+ action_dof_mask: Optional[torch.Tensor] = None,
1201
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1202
+ ) -> torch.Tensor:
1203
+ """Inference: Euler integration, returns (B, chunk_size, action_dim)."""
1204
+ k0 = past_key_values[0][0]
1205
+ B = k0.shape[0]
1206
+ device = k0.device
1207
+ dtype = k0.dtype
1208
+ vlm_kv_cache = self._vlm_kv_list_from_past(past_key_values)
1209
+
1210
+ actions = torch.randn(
1211
+ (B, self.action_chunk_size, self.action_dim),
1212
+ device=device, dtype=dtype,
1213
+ )
1214
+ dt = 1.0 / self.num_inference_timesteps
1215
+
1216
+ for step in range(self.num_inference_timesteps):
1217
+ t_cont = step / float(self.num_inference_timesteps)
1218
+ t_disc_val = int(t_cont * self.num_timestep_buckets)
1219
+ t_tensor = torch.full((B,), t_disc_val, device=device, dtype=torch.long)
1220
+
1221
+ adarms_cond = self._compute_adarms_cond(t_tensor)
1222
+ action_tokens = self._prepare_action_embeds(actions, action_dof_mask)
1223
+
1224
+ output = self.dit(
1225
+ action_tokens, vlm_kv_cache, adarms_cond, encoder_attention_mask,
1226
+ )
1227
+ pred_velocity = self.action_out_proj(output)[:, :self.action_chunk_size]
1228
+ actions = actions + dt * pred_velocity
1229
+
1230
+ return actions
generation_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_sample": true,
3
+ "eos_token_id": [
4
+ 151645,
5
+ 151643
6
+ ],
7
+ "pad_token_id": 151643,
8
+ "temperature": 0.7,
9
+ "top_k": 20,
10
+ "top_p": 0.8,
11
+ "transformers_version": "4.57.3"
12
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:543d8bd40ff00d487b24b85c724587b96839434f61bc51d438319042e0cf0fcb
3
+ size 4999639200
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:303b0fe49a58b0397c013f89c7047483c0910b77b8dc665ccb6f0e410cdd848c
3
+ size 4995750056
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31a73b4447c20c0bb0cf72b8a999e01ff08a0df3354686b449bdcb560a010c66
3
+ size 981882944
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_prts_qwen3_vl.py ADDED
@@ -0,0 +1,935 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 TeleAI Rhodes Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Main VLA model architecture based on Qwen3-VL."""
16
+
17
+ from dataclasses import dataclass
18
+
19
+ import math
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from torch.nn import CrossEntropyLoss, MSELoss
25
+ from typing import Any, Dict, List, Optional, Tuple, Union
26
+
27
+ from transformers.modeling_outputs import ModelOutput
28
+ from transformers.cache_utils import Cache
29
+ from transformers.processing_utils import Unpack
30
+ from transformers.utils import TransformersKwargs, is_torchdynamo_compiling
31
+
32
+ from .modeling_qwen3_vl import (
33
+ Qwen3VLForConditionalGeneration,
34
+ Qwen3VLTextModel,
35
+ Qwen3VLVisionModel,
36
+ )
37
+ from .configuration_prts_qwen3_vl import PRTS_FlowMatchingConfig_Qwen3VL
38
+ from .dit_action_head import FlowMatchingDiTHead, MoTFlowMatchingHead
39
+
40
+ ACTION_DATASET_NAMES = []
41
+
42
+ # ----------------------------- Print Customization -----------------------------
43
+ from colorama import init, Fore, Style
44
+ from datetime import datetime
45
+
46
+ # Initialize colorama
47
+ init(autoreset=True)
48
+
49
+ class CustomPrinter:
50
+ """Custom colored printer."""
51
+
52
+ # Define message type configuration
53
+ TYPE_CONFIG = {
54
+ 'normal': {
55
+ 'color': Fore.WHITE,
56
+ 'icon': '',
57
+ 'prefix': '',
58
+ 'style': Style.NORMAL
59
+ },
60
+ 'important': {
61
+ 'color': Fore.CYAN,
62
+ 'icon': '💡',
63
+ 'prefix': 'IMPORTANT',
64
+ 'style': Style.BRIGHT
65
+ }
66
+ }
67
+
68
+ @classmethod
69
+ def print(cls, message, msg_type='normal', show_time=True, show_icon=True, end='\n'):
70
+ """
71
+ Custom print function.
72
+
73
+ Args:
74
+ message: The message content to print
75
+ msg_type: Message type ('normal', 'info', 'success', 'warning', 'error', 'fail', 'debug', 'important')
76
+ show_time: Whether to display a timestamp
77
+ show_icon: Whether to display the icon
78
+ end: Line terminator
79
+ """
80
+ # Get configuration for the message type
81
+ config = cls.TYPE_CONFIG.get(msg_type, cls.TYPE_CONFIG['normal'])
82
+
83
+ # Build prefix parts
84
+ prefix_parts = []
85
+
86
+ # Add timestamp
87
+ if show_time:
88
+ timestamp = datetime.now().strftime('%H:%M:%S')
89
+ prefix_parts.append(f"[{timestamp}]")
90
+
91
+ # Add icon and prefix text
92
+ icon_text = f"{config['icon']} " if show_icon else ""
93
+ prefix_parts.append(f"{icon_text}{config['prefix']}")
94
+
95
+ if config['prefix'] == '':
96
+ full_message = message
97
+ else:
98
+ # Combine prefix parts
99
+ prefix = " ".join(prefix_parts)
100
+
101
+ # Construct full message
102
+ full_message = f"{prefix}: {message}"
103
+
104
+ # Apply color and style and print
105
+ formatted_message = f"{config['style']}{config['color']}{full_message}"
106
+ print(formatted_message, end=end)
107
+
108
+ @classmethod
109
+ def normal(cls, message, **kwargs):
110
+ """Convenience: normal-level print."""
111
+ cls.print(message, 'normal', **kwargs)
112
+
113
+ @classmethod
114
+ def important(cls, message, **kwargs):
115
+ """Convenience: important-level print."""
116
+ cls.print(message, 'important', **kwargs)
117
+
118
+ def important(message, **kwargs):
119
+ CustomPrinter.important(message, **kwargs)
120
+
121
+ # -------------------------------------------------------------
122
+
123
+
124
+ def create_sinusoidal_pos_embedding(
125
+ time: torch.Tensor,
126
+ dimension: int,
127
+ min_period: float = 4e-3,
128
+ max_period: float = 4.0,
129
+ device="cpu",
130
+ ) -> torch.Tensor:
131
+ """
132
+ Computes sine-cosine positional embedding vectors for scalar positions (diffusion timesteps).
133
+
134
+ Args:
135
+ time: Tensor of shape (batch_size,) containing timestep values
136
+ dimension: Embedding dimension (must be even)
137
+ min_period: Minimum period for sinusoidal encoding
138
+ max_period: Maximum period for sinusoidal encoding
139
+ device: Device to create tensors on
140
+
141
+ Returns:
142
+ Positional embeddings of shape (batch_size, dimension)
143
+ """
144
+ if dimension % 2 != 0:
145
+ raise ValueError(f"dimension ({dimension}) must be divisible by 2")
146
+
147
+ if time.ndim != 1:
148
+ raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
149
+
150
+ fraction = torch.linspace(0.0, 1.0, dimension // 2, device=device)
151
+ period = min_period * (max_period / min_period) ** fraction
152
+
153
+ scaling_factor = 1.0 / period * 2 * math.pi
154
+ sin_input = scaling_factor[None, :] * time[:, None]
155
+ pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
156
+ return pos_emb
157
+
158
+
159
+ class ContrastiveEncoder(nn.Module):
160
+ """
161
+ MLP projector for Contrastive Reinforcement Learning (CRL) embeddings.
162
+
163
+ Projects hidden states to a shared latent space for contrastive learning,
164
+ with L2 normalization for stable similarity computation.
165
+
166
+ Architecture: N-layer MLP with LayerNorm and Swish activation,
167
+ followed by a cold-initialized output projection.
168
+ [Linear -> LayerNorm -> Swish] x N -> Linear (cold init)
169
+
170
+ Matches stable_contrastive_rl's Q network structure (default: 4 hidden layers).
171
+
172
+ Args:
173
+ input_dim: Dimension of input hidden states
174
+ output_dim: Dimension of output embeddings (default: 256)
175
+ hidden_dim: Dimension of hidden layers (default: 1024)
176
+ num_layers: Number of hidden layers (default: 4)
177
+ repr_norm: Whether to L2-normalize outputs (default: False)
178
+ init_w: Small value for last layer weight initialization for cold init (default: 1e-12)
179
+ """
180
+ def __init__(
181
+ self,
182
+ input_dim: int,
183
+ output_dim: int = 256,
184
+ hidden_dim: int = 1024,
185
+ num_layers: int = 4,
186
+ repr_norm: bool = False,
187
+ init_w: float = 1e-12,
188
+ ):
189
+ super().__init__()
190
+ self.num_layers = num_layers
191
+ self.repr_norm = repr_norm
192
+
193
+ # Build hidden layers with LayerNorm
194
+ self.hidden_layers = nn.ModuleList()
195
+ self.layer_norms = nn.ModuleList()
196
+
197
+ for i in range(num_layers):
198
+ in_dim = input_dim if i == 0 else hidden_dim
199
+ self.hidden_layers.append(nn.Linear(in_dim, hidden_dim))
200
+ self.layer_norms.append(nn.LayerNorm(hidden_dim))
201
+
202
+ # Output projection layer with cold initialization
203
+ self.output_proj = nn.Linear(hidden_dim, output_dim)
204
+ self.output_proj.weight.data.uniform_(-init_w, init_w)
205
+ self.output_proj.bias.data.fill_(0)
206
+
207
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
208
+ """
209
+ Project input to L2-normalized embedding space.
210
+
211
+ Args:
212
+ x: Input tensor of shape (batch_size, input_dim)
213
+
214
+ Returns:
215
+ L2-normalized embeddings of shape (batch_size, output_dim)
216
+ """
217
+ # Pass through hidden layers
218
+ for fc, norm in zip(self.hidden_layers, self.layer_norms):
219
+ x = fc(x)
220
+ x = norm(x)
221
+ x = F.silu(x)
222
+
223
+ # Output projection
224
+ x = self.output_proj(x)
225
+
226
+ # Optional L2 normalization
227
+ if self.repr_norm:
228
+ x = F.normalize(x, dim=-1)
229
+
230
+ return x
231
+
232
+
233
+
234
+ @dataclass
235
+ class PRTS_Qwen3VL_ModelOutputWithPast(ModelOutput):
236
+ """
237
+ Output class for PRTS model based on Qwen3-VL.
238
+
239
+ Args:
240
+ loss: Combined total loss
241
+ flow_loss: Flow matching loss for action prediction
242
+ cross_entropy_loss: Standard language modeling loss
243
+ crl_loss: Contrastive Reinforcement Learning loss for goal-action alignment
244
+ logits: Language model logits
245
+ past_key_values: Cached key-value states
246
+ hidden_states: Hidden states from all layers (if output_hidden_states=True)
247
+ attentions: Attention weights (if output_attentions=True)
248
+ rope_deltas: RoPE position delta information
249
+ channel_loss_dict: Per-dataset loss values for logging
250
+ channel_loss_count_dict: Per-dataset token counts for loss normalization
251
+ """
252
+ loss: Optional[torch.FloatTensor] = None
253
+ flow_loss: Optional[torch.FloatTensor] = None
254
+ cross_entropy_loss: Optional[torch.FloatTensor] = None
255
+ crl_loss: Optional[torch.FloatTensor] = None
256
+ logits: Optional[torch.FloatTensor] = None
257
+ past_key_values: Optional[List[torch.FloatTensor]] = None
258
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
259
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
260
+ rope_deltas: Optional[torch.LongTensor] = None
261
+
262
+ crl_num_samples: Optional[torch.LongTensor] = None
263
+ channel_loss_dict: Optional[dict] = None
264
+ channel_loss_count_dict: Optional[dict] = None
265
+
266
+
267
+ class PRTS_Qwen3VL(Qwen3VLForConditionalGeneration):
268
+ """
269
+ Vision-Language-Action model based on Qwen3-VL.
270
+
271
+ This model extends Qwen3-VL to support:
272
+ 1. Proprioceptive state embedding and prediction
273
+ 2. Sub-task description generation (language format)
274
+ 3. Action chunk prediction via flow matching (continuous actions)
275
+ 4. Optional discrete action tokenization (fast mode)
276
+
277
+ The model uses a flow matching approach for continuous action prediction, with a DiT
278
+ (Diffusion Transformer) action head that cross-attends to VLM hidden states.
279
+ """
280
+ config: PRTS_FlowMatchingConfig_Qwen3VL
281
+
282
+ _tied_weights_keys = ["lm_head.weight"]
283
+ _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]
284
+
285
+ def __init__(
286
+ self,
287
+ config: PRTS_FlowMatchingConfig_Qwen3VL,
288
+ ):
289
+ """
290
+ Initialize the PRTS Qwen3-VL model for action processing.
291
+
292
+ Args:
293
+ config: Model configuration
294
+ use_fast_tokenizer (bool): Whether to use FAST tokenizer for discrete actions
295
+ flow_matching_action_loss_weight (float): Weight for flow matching action loss
296
+ """
297
+ super().__init__(config)
298
+
299
+ # The parent class initializes:
300
+ # - self.visual: Qwen3VLVisionModel
301
+ # - self.language_model: Qwen3VLTextModel
302
+ # - self.lm_head: Language model head
303
+ # - self.rope_deltas: Cached rope deltas
304
+ # We keep these and add PRTS-specific components
305
+
306
+ # PRTS-specific parameters
307
+ self.action_dim = config.max_action_dim
308
+ self.use_fast_tokenizer = config.use_fast_action_tokenizer
309
+ self.flow_matching_action_loss_weight = config.flow_matching_action_loss_weight
310
+
311
+ # Loss functions
312
+ self.loss_fct = CrossEntropyLoss(reduction="none")
313
+ self.loss_mse = MSELoss(reduction="none")
314
+
315
+ # DiT-based flow matching action head: standard (+ AlternateVLDiT) or pi0.5 KV expert
316
+ self.use_mot_action_expert = config.dit_action_head_config.get(
317
+ "use_mot_action_expert", False
318
+ )
319
+ if config.flow_matching_action_loss_weight > 0.:
320
+ if self.use_mot_action_expert:
321
+ self.dit_action_head = MoTFlowMatchingHead(
322
+ action_dim=self.action_dim,
323
+ action_chunk_size=config.action_chunk_size,
324
+ vlm_config=config.text_config,
325
+ num_inference_timesteps=config.num_denoise_steps,
326
+ config=config.dit_action_head_config,
327
+ )
328
+ else:
329
+ self.dit_action_head = FlowMatchingDiTHead(
330
+ action_dim=self.action_dim,
331
+ action_chunk_size=config.action_chunk_size,
332
+ cross_attention_dim=config.text_config.hidden_size,
333
+ num_inference_timesteps=config.num_denoise_steps,
334
+ config=config.dit_action_head_config,
335
+ )
336
+
337
+ # CRL (Contrastive Reinforcement Learning) components
338
+ if config.crl_loss_weight > 0.:
339
+ hidden_size = config.text_config.hidden_size
340
+ # Current encoders (trainable)
341
+ self.crl_action_encoder = ContrastiveEncoder(
342
+ input_dim=hidden_size,
343
+ output_dim=config.crl_embed_dim,
344
+ init_w=config.crl_encoder_init_w,
345
+ repr_norm=config.crl_repr_norm,
346
+ )
347
+ self.crl_goal_encoder = ContrastiveEncoder(
348
+ input_dim=hidden_size,
349
+ output_dim=config.crl_embed_dim,
350
+ init_w=config.crl_encoder_init_w,
351
+ repr_norm=config.crl_repr_norm,
352
+ )
353
+ # Learnable temperature (log-space for numerical stability, CLIP recipe).
354
+ self.crl_logit_scale = nn.Parameter(
355
+ torch.ones([], requires_grad=True) * math.log(1 / 0.2)
356
+ )
357
+
358
+ # Initialize weights
359
+ self.post_init()
360
+
361
+ # Print parameter counts
362
+ visual_params = sum(p.numel() for p in self.visual.parameters())
363
+ language_params = sum(p.numel() for p in self.language_model.parameters())
364
+ model_params = visual_params + language_params
365
+ important(f"Backbone VLM (visual + language_model) parameters: {model_params / 1e6:.2f}M")
366
+ important(f"Flow Matching Loss coefficient: {self.flow_matching_action_loss_weight}")
367
+
368
+ if config.flow_matching_action_loss_weight > 0.:
369
+ dit_params = sum(p.numel() for p in self.dit_action_head.parameters())
370
+ # Get the inner model type name for logging
371
+ if hasattr(self.dit_action_head, 'dit'):
372
+ dit_head_type = type(self.dit_action_head.dit).__name__
373
+ else:
374
+ dit_head_type = type(self.dit_action_head).__name__
375
+ important(f"DiT Action Head ({dit_head_type}) parameters: {dit_params / 1e6:.2f}M")
376
+
377
+ if config.crl_loss_weight > 0.:
378
+ crl_params = sum(p.numel() for p in self.crl_action_encoder.parameters())
379
+ crl_params += sum(p.numel() for p in self.crl_goal_encoder.parameters())
380
+ important(f"CRL Encoders (action + goal) parameters: {crl_params / 1e6:.2f}M")
381
+ important(f"CRL Loss coefficient: {config.crl_loss_weight}")
382
+ important(f"CRL Encoder init_w: {config.crl_encoder_init_w}")
383
+ important(f"CRL Repr Norm: {config.crl_repr_norm}")
384
+
385
+ self.fast_action_token_start_idx = 200000
386
+ self.use_multi_positive = True
387
+
388
+ def get_input_embeddings(self):
389
+ return self.language_model.get_input_embeddings()
390
+
391
+ def set_input_embeddings(self, value):
392
+ self.language_model.set_input_embeddings(value)
393
+
394
+ def set_decoder(self, decoder):
395
+ self.language_model = decoder
396
+
397
+ def get_decoder(self):
398
+ return self.language_model
399
+
400
+ def get_output_embeddings(self):
401
+ return self.lm_head
402
+
403
+ def set_output_embeddings(self, new_embeddings):
404
+ self.lm_head = new_embeddings
405
+
406
+ def to_float32_flow_matching_head(self):
407
+ """Convert flow matching heads to float32 for numerical stability."""
408
+ if hasattr(self, 'dit_action_head'):
409
+ self.dit_action_head = self.dit_action_head.to(dtype=torch.float32)
410
+
411
+ def set_fast_action_info(self, action_mapper, fast_action_token_start_idx):
412
+ """Set information for fast (discrete) action tokenization."""
413
+ self.action_mapper = action_mapper
414
+ self.fast_action_token_start_idx = fast_action_token_start_idx
415
+
416
+ def get_placeholder_mask_with_special_token(
417
+ self,
418
+ input_ids: torch.LongTensor,
419
+ inputs_embeds: torch.FloatTensor,
420
+ special_features: torch.FloatTensor,
421
+ special_pad_token_id: int,
422
+ ):
423
+ """
424
+ Get placeholder mask for a specific special token (e.g., state tokens).
425
+
426
+ Similar to get_placeholder_mask but for custom special tokens beyond image/video.
427
+ """
428
+ if input_ids is None:
429
+ special_mask = inputs_embeds == self.get_input_embeddings()(
430
+ torch.tensor(special_pad_token_id, dtype=torch.long, device=inputs_embeds.device)
431
+ )
432
+ special_mask = special_mask.all(-1)
433
+ else:
434
+ special_mask = input_ids == special_pad_token_id
435
+
436
+ n_special_tokens = special_mask.sum()
437
+ special_mask = special_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
438
+ if special_features is not None and inputs_embeds[special_mask].numel() != special_features.numel():
439
+ raise ValueError(
440
+ f"Features and tokens do not match: tokens: {n_special_tokens}, features {special_features.shape[0]}"
441
+ )
442
+
443
+ return special_mask
444
+
445
+ def forward(
446
+ self,
447
+ input_ids: Optional[torch.LongTensor] = None,
448
+ attention_mask: Optional[torch.Tensor] = None,
449
+ position_ids: Optional[torch.LongTensor] = None,
450
+ past_key_values: Optional[Cache] = None,
451
+ inputs_embeds: Optional[torch.FloatTensor] = None,
452
+ labels: Optional[torch.LongTensor] = None,
453
+ # use_cache: Optional[bool] = None,
454
+ # output_attentions: Optional[bool] = None,
455
+ # output_hidden_states: Optional[bool] = None,
456
+ # return_dict: Optional[bool] = None,
457
+ pixel_values: Optional[torch.Tensor] = None,
458
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
459
+ image_grid_thw: Optional[torch.LongTensor] = None,
460
+ video_grid_thw: Optional[torch.LongTensor] = None,
461
+ # rope_deltas: Optional[torch.LongTensor] = None,
462
+ cache_position: Optional[torch.LongTensor] = None,
463
+ logits_to_keep: Union[int, torch.Tensor] = 0,
464
+ actions: Optional[torch.Tensor] = None,
465
+ action_is_pad: torch.Tensor | None = None,
466
+ action_dof_mask: Optional[torch.Tensor] = None,
467
+ dataset_names: Optional[List[str]] = None,
468
+ **kwargs: Unpack[TransformersKwargs],
469
+ ) -> Union[tuple, PRTS_Qwen3VL_ModelOutputWithPast]:
470
+ """
471
+ Forward pass for PRTS_Qwen3VL model.
472
+
473
+ This extends Qwen3VLForConditionalGeneration.forward with:
474
+ - State embedding injection
475
+ - Action chunk flow matching
476
+ - DeepStack visual feature handling
477
+ """
478
+ if (input_ids is None) ^ (inputs_embeds is not None):
479
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
480
+
481
+
482
+ # 1. Prepare input embeddings
483
+ if inputs_embeds is None:
484
+ inputs_embeds = self.get_input_embeddings()(input_ids)
485
+
486
+ image_mask = None
487
+ video_mask = None
488
+
489
+ # 2. Process images with deepstack features
490
+ deepstack_image_embeds = None
491
+ if pixel_values is not None:
492
+ image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw, image_max_seqlen=kwargs['image_max_seqlen'])
493
+ image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
494
+ image_mask, _ = self.get_placeholder_mask(
495
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
496
+ )
497
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
498
+
499
+ # 3. Process videos with deepstack features
500
+ deepstack_video_embeds = None
501
+ if pixel_values_videos is not None:
502
+ video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
503
+ video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
504
+ _, video_mask = self.get_placeholder_mask(
505
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
506
+ )
507
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
508
+
509
+ # 4. Aggregate deepstack visual features
510
+ visual_pos_masks = None
511
+ deepstack_visual_embeds = None
512
+ if image_mask is not None and video_mask is not None:
513
+ # aggregate visual_pos_masks and deepstack_visual_embeds
514
+ image_mask = image_mask[..., 0]
515
+ video_mask = video_mask[..., 0]
516
+ visual_pos_masks = image_mask | video_mask
517
+ deepstack_visual_embeds = []
518
+ image_mask_joint = image_mask[visual_pos_masks]
519
+ video_mask_joint = video_mask[visual_pos_masks]
520
+ for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
521
+ embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
522
+ embed_joint[image_mask_joint, :] = img_embed
523
+ embed_joint[video_mask_joint, :] = vid_embed
524
+ deepstack_visual_embeds.append(embed_joint)
525
+ elif image_mask is not None:
526
+ image_mask = image_mask[..., 0]
527
+ visual_pos_masks = image_mask
528
+ deepstack_visual_embeds = deepstack_image_embeds
529
+ elif video_mask is not None:
530
+ video_mask = video_mask[..., 0]
531
+ visual_pos_masks = video_mask
532
+ deepstack_visual_embeds = deepstack_video_embeds
533
+
534
+ if attention_mask is not None:
535
+ attention_mask = attention_mask.to(inputs_embeds.device)
536
+
537
+ # 7. Calculate position IDs using Qwen3VL's rope index
538
+ if position_ids is None:
539
+ attention_mask_tensor = (
540
+ attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
541
+ )
542
+ if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
543
+ attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
544
+ if attention_mask_tensor.dtype.is_floating_point:
545
+ attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
546
+ attention_mask_tensor = (1.0 - attention_mask_tensor).int()
547
+
548
+ prefill_compiled_stage = is_torchdynamo_compiling() and (
549
+ (input_ids is not None and input_ids.shape[1] != 1)
550
+ or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
551
+ )
552
+ prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
553
+ (cache_position is not None and cache_position[0] == 0)
554
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
555
+ )
556
+ if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
557
+ position_ids, rope_deltas = self.get_rope_index(
558
+ input_ids,
559
+ image_grid_thw,
560
+ video_grid_thw,
561
+ attention_mask=attention_mask_tensor,
562
+ )
563
+ self.rope_deltas = rope_deltas
564
+ else:
565
+ batch_size, seq_length, _ = inputs_embeds.shape
566
+ delta = (
567
+ (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
568
+ if cache_position is not None
569
+ else 0
570
+ )
571
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
572
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
573
+ if cache_position is not None: # otherwise `deltas` is an int `0`
574
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
575
+ position_ids = position_ids.add(delta)
576
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
577
+
578
+ _lm_extra_kwargs: dict = {}
579
+
580
+ _use_cache = (
581
+ self.use_mot_action_expert
582
+ and self.flow_matching_action_loss_weight > 0.
583
+ and actions is not None
584
+ )
585
+
586
+ vlm_outputs = self.language_model(
587
+ input_ids=None,
588
+ position_ids=position_ids,
589
+ attention_mask=attention_mask,
590
+ past_key_values=past_key_values,
591
+ inputs_embeds=inputs_embeds,
592
+ use_cache=_use_cache,
593
+ cache_position=cache_position,
594
+ visual_pos_masks=visual_pos_masks,
595
+ deepstack_visual_embeds=deepstack_visual_embeds,
596
+ output_hidden_states=False,
597
+ **_lm_extra_kwargs,
598
+ **kwargs,
599
+ )
600
+
601
+ vlm_hidden_states = vlm_outputs.last_hidden_state
602
+
603
+ # 11. Run DiT action head if actions are present
604
+ dit_pred_v = None
605
+ dit_velocity = None
606
+ if actions is not None and self.flow_matching_action_loss_weight > 0:
607
+ # vlm_hidden_states shape: bs, seq_length, hidden_size
608
+ actions_for_dit = actions.to(vlm_hidden_states.device, dtype=vlm_hidden_states.dtype)
609
+ dof_mask_for_dit = action_dof_mask.to(vlm_hidden_states.device, dtype=vlm_hidden_states.dtype) if action_dof_mask is not None else None
610
+ # Pass attention_mask so DiT cross-attention ignores padding tokens
611
+ dit_encoder_attention_mask = attention_mask.bool() if attention_mask is not None else None
612
+
613
+ if self.use_mot_action_expert and vlm_outputs.past_key_values is not None:
614
+ dit_pred_v, dit_velocity = self.dit_action_head(
615
+ vlm_outputs.past_key_values,
616
+ actions_for_dit,
617
+ dof_mask_for_dit,
618
+ encoder_attention_mask=dit_encoder_attention_mask,
619
+ )
620
+ else:
621
+ # Standard: pass single (last-layer) VLM hidden states
622
+ dit_image_mask = visual_pos_masks.bool() if visual_pos_masks is not None else None
623
+ dit_pred_v, dit_velocity = self.dit_action_head(
624
+ vlm_hidden_states, actions_for_dit, dof_mask_for_dit,
625
+ encoder_attention_mask=dit_encoder_attention_mask,
626
+ image_mask=dit_image_mask,
627
+ )
628
+
629
+ # 12. Compute logits
630
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
631
+ logits = self.lm_head(vlm_hidden_states[:, slice_indices, :])
632
+
633
+ # 13. Compute losses
634
+ loss = None
635
+ cross_entropy_loss, flow_loss = None, None
636
+ channel_loss_dict = None
637
+ channel_loss_count_dict = None
638
+
639
+ if labels is not None:
640
+ loss = 0
641
+ action_accuracy = 0
642
+ unique_datasets_name = list(set(dataset_names)) if dataset_names is not None else []
643
+
644
+ # Compute cross-entropy loss
645
+ shift_logits = logits[..., :-1, :].float().contiguous()
646
+ shift_labels = labels[..., 1:].contiguous()
647
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
648
+ shift_labels = shift_labels.view(-1)
649
+
650
+ shift_labels = shift_labels.to(shift_logits.device)
651
+ non_ignored_mask = shift_labels != -100
652
+ _cross_entropy_loss = self.loss_fct(shift_logits, shift_labels)
653
+ cross_entropy_loss = (
654
+ _cross_entropy_loss[non_ignored_mask].mean()
655
+ if non_ignored_mask.any()
656
+ else (_cross_entropy_loss.sum() * 0.0)
657
+ )
658
+
659
+ # Add cross-entropy loss to total
660
+ if not torch.isnan(cross_entropy_loss):
661
+ loss += cross_entropy_loss
662
+ else:
663
+ with torch.no_grad():
664
+ cross_entropy_loss.detach()
665
+
666
+ # Compute action token prediction accuracy (for logging)
667
+ shift_logits_for_acc = logits[..., :-1, :].contiguous()
668
+ action_preds = shift_logits_for_acc.argmax(dim=-1)
669
+ shift_labels_for_acc = labels[..., 1:].contiguous()
670
+
671
+ action_mask = (
672
+ shift_labels_for_acc >= self.fast_action_token_start_idx
673
+ )
674
+
675
+ if self.use_fast_tokenizer and action_mask.any():
676
+ correct_preds = (action_preds == shift_labels_for_acc) & action_mask
677
+ action_accuracy = (
678
+ correct_preds.sum().float() / action_mask.sum().float()
679
+ )
680
+
681
+ if channel_loss_dict is None:
682
+ channel_loss_dict = {}
683
+ channel_loss_count_dict = {}
684
+
685
+ channel_loss_dict["action_accuracy"] = action_accuracy.detach()
686
+ channel_loss_count_dict["action_accuracy"] = torch.tensor(1, device=action_accuracy.device)
687
+
688
+ # 14. Compute flow matching loss (DiT action head)
689
+ if dit_pred_v is not None and self.flow_matching_action_loss_weight > 0:
690
+ if channel_loss_dict is not None:
691
+ channel_loss_dict.update(
692
+ {
693
+ f"flow_matching/{dataset_name}": torch.tensor(0.0, device=logits.device)
694
+ for dataset_name in ACTION_DATASET_NAMES
695
+ }
696
+ )
697
+ channel_loss_count_dict.update(
698
+ {
699
+ f"flow_matching/{dataset_name}": torch.tensor(0, device=logits.device)
700
+ for dataset_name in ACTION_DATASET_NAMES
701
+ }
702
+ )
703
+ else:
704
+ channel_loss_dict = {
705
+ f"flow_matching/{dataset_name}": torch.tensor(0.0, device=logits.device)
706
+ for dataset_name in ACTION_DATASET_NAMES
707
+ }
708
+ channel_loss_count_dict = {
709
+ f"flow_matching/{dataset_name}": torch.tensor(0, device=logits.device)
710
+ for dataset_name in ACTION_DATASET_NAMES
711
+ }
712
+
713
+ # Compute flow matching loss: MSE between predicted and target velocity
714
+ _fm_loss = self.loss_mse(dit_pred_v, dit_velocity)
715
+
716
+ # Apply DOF mask (zero out invalid action dimensions)
717
+ if action_dof_mask is not None:
718
+ valid_action_dim = int(action_dof_mask[0, 0, :].sum(dim=-1).item()) # NOTE: only support 单种具身实体数据微调
719
+ _fm_loss = _fm_loss[:, :, :valid_action_dim]
720
+
721
+ # Apply action_is_pad mask: exclude padding timesteps from loss
722
+ # action_is_pad: (B, T), True = pad timestep → should not contribute to loss
723
+ if action_is_pad is not None:
724
+ valid_timestep_mask = ~action_is_pad[:, :_fm_loss.shape[1]] # align length
725
+ _fm_loss = _fm_loss * valid_timestep_mask.unsqueeze(-1)
726
+ flow_loss = _fm_loss.sum() / (valid_timestep_mask.sum() * _fm_loss.shape[-1])
727
+ else:
728
+ flow_loss = _fm_loss.mean()
729
+
730
+ if not torch.isnan(flow_loss):
731
+ loss = loss + self.flow_matching_action_loss_weight * flow_loss if loss is not None else self.flow_matching_action_loss_weight * flow_loss
732
+ else:
733
+ with torch.no_grad():
734
+ flow_loss.detach()
735
+
736
+ # Per-dataset flow matching loss logging
737
+ logging_fm_loss = _fm_loss.detach().mean(dim=(1, 2)) # Sum over chunk_size and action_dim
738
+
739
+ action_dataset_names = dataset_names if dataset_names is not None else []
740
+ unique_action_datasets = list(set(action_dataset_names))
741
+
742
+ for dataset_name_i in unique_action_datasets:
743
+ action_dataset_mask = torch.tensor(
744
+ [name == dataset_name_i for name in action_dataset_names],
745
+ device=logits.device,
746
+ )
747
+ if action_dataset_mask.any():
748
+ dataset_fm_loss = logging_fm_loss[action_dataset_mask].sum()
749
+ dataset_fm_count = action_dataset_mask.sum()
750
+
751
+ prefixed_key = f"flow_matching/{dataset_name_i}"
752
+ channel_loss_dict[prefixed_key] += dataset_fm_loss
753
+ channel_loss_count_dict[prefixed_key] += dataset_fm_count
754
+
755
+ elif self.flow_matching_action_loss_weight > 0:
756
+ # Dummy loss to keep all DiT parameters in computation graph
757
+ dummy_params = [p.sum() * 0.0 for p in self.dit_action_head.parameters() if p.requires_grad]
758
+ dummy_loss = sum(dummy_params) if len(dummy_params) > 0 else torch.tensor(0.0, device=logits.device)
759
+ loss = (loss + dummy_loss) if loss is not None else dummy_loss
760
+
761
+ return PRTS_Qwen3VL_ModelOutputWithPast(
762
+ loss=loss,
763
+ cross_entropy_loss=(
764
+ cross_entropy_loss.detach() if cross_entropy_loss is not None else None
765
+ ),
766
+ flow_loss=(
767
+ flow_loss.detach() if flow_loss is not None else None
768
+ ),
769
+ crl_loss=None,
770
+ logits=logits,
771
+ past_key_values=vlm_outputs.past_key_values,
772
+ # hidden_states=vlm_outputs.hidden_states,
773
+ # attentions=vlm_outputs.attentions,
774
+ crl_num_samples=None,
775
+ rope_deltas=self.rope_deltas,
776
+ channel_loss_dict=channel_loss_dict,
777
+ channel_loss_count_dict=channel_loss_count_dict,
778
+ )
779
+
780
+
781
+ def embed_prefix(
782
+ self,
783
+ input_ids: torch.LongTensor,
784
+ inputs_embeds: torch.FloatTensor | None = None,
785
+ pixel_values: torch.Tensor | None = None,
786
+ pixel_values_videos: torch.FloatTensor | None = None,
787
+ image_grid_thw: torch.LongTensor | None = None,
788
+ video_grid_thw: torch.LongTensor | None = None,
789
+ **kwargs,
790
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
791
+ """
792
+ Embed prefix tokens including vision, DeepStack, and (optionally) state features.
793
+
794
+ Returns:
795
+ (inputs_embeds, visual_pos_masks, deepstack_visual_embeds)
796
+ """
797
+ if inputs_embeds is None:
798
+ inputs_embeds = self.get_input_embeddings()(input_ids)
799
+
800
+ image_mask = None
801
+ video_mask = None
802
+ deepstack_image_embeds = None
803
+ deepstack_video_embeds = None
804
+
805
+ if pixel_values is not None:
806
+ image_embeds, deepstack_image_embeds = self.get_image_features(
807
+ pixel_values, image_grid_thw,
808
+ image_max_seqlen=kwargs.get('image_max_seqlen'),
809
+ )
810
+ image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
811
+ image_mask, _ = self.get_placeholder_mask(
812
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
813
+ )
814
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
815
+
816
+ if pixel_values_videos is not None:
817
+ video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
818
+ video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
819
+ _, video_mask = self.get_placeholder_mask(
820
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
821
+ )
822
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
823
+
824
+ visual_pos_masks = None
825
+ deepstack_visual_embeds = None
826
+ if image_mask is not None and video_mask is not None:
827
+ image_mask = image_mask[..., 0]
828
+ video_mask = video_mask[..., 0]
829
+ visual_pos_masks = image_mask | video_mask
830
+ deepstack_visual_embeds = []
831
+ image_mask_joint = image_mask[visual_pos_masks]
832
+ video_mask_joint = video_mask[visual_pos_masks]
833
+ for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
834
+ embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
835
+ embed_joint[image_mask_joint, :] = img_embed
836
+ embed_joint[video_mask_joint, :] = vid_embed
837
+ deepstack_visual_embeds.append(embed_joint)
838
+ elif image_mask is not None:
839
+ image_mask = image_mask[..., 0]
840
+ visual_pos_masks = image_mask
841
+ deepstack_visual_embeds = deepstack_image_embeds
842
+ elif video_mask is not None:
843
+ video_mask = video_mask[..., 0]
844
+ visual_pos_masks = video_mask
845
+ deepstack_visual_embeds = deepstack_video_embeds
846
+
847
+ return inputs_embeds, visual_pos_masks, deepstack_visual_embeds
848
+
849
+ @torch.no_grad()
850
+ def sample_actions(
851
+ self,
852
+ input_ids: torch.LongTensor | None = None,
853
+ position_ids: torch.LongTensor | None = None,
854
+ attention_mask: torch.Tensor | None = None,
855
+ past_key_values: list[torch.FloatTensor] | None = None,
856
+ inputs_embeds: torch.FloatTensor | None = None,
857
+ cache_position: torch.LongTensor | None = None,
858
+ pixel_values: torch.Tensor | None = None,
859
+ pixel_values_videos: torch.FloatTensor | None = None,
860
+ image_grid_thw: torch.LongTensor | None = None,
861
+ video_grid_thw: torch.LongTensor | None = None,
862
+ action_dof_mask: Optional[torch.Tensor] = None,
863
+ **kwargs,
864
+ ) -> Tuple[torch.Tensor, Any]:
865
+ """
866
+ Sample actions using DiT-based flow matching denoising.
867
+
868
+ 1. Computes position_ids via get_rope_index
869
+ 2. Embeds the prefix (with DeepStack visual features)
870
+ 3. Runs the language model to get hidden states
871
+ 4. Uses DiT action head to denoise actions via cross-attention to VLM features
872
+
873
+ Returns:
874
+ (x_t, outputs) — denoised action trajectories and language-model outputs
875
+ """
876
+ if position_ids is None:
877
+ position_ids, _ = self.get_rope_index(
878
+ input_ids,
879
+ image_grid_thw=image_grid_thw,
880
+ video_grid_thw=video_grid_thw,
881
+ attention_mask=attention_mask,
882
+ )
883
+
884
+ visual_pos_masks = None
885
+ deepstack_visual_embeds = None
886
+ if inputs_embeds is None:
887
+ inputs_embeds, visual_pos_masks, deepstack_visual_embeds = self.embed_prefix(
888
+ input_ids,
889
+ pixel_values=pixel_values,
890
+ pixel_values_videos=pixel_values_videos,
891
+ image_grid_thw=image_grid_thw,
892
+ video_grid_thw=video_grid_thw,
893
+ **kwargs,
894
+ )
895
+
896
+ _sample_use_cache = (
897
+ self.use_mot_action_expert and self.flow_matching_action_loss_weight > 0
898
+ )
899
+ outputs = self.language_model(
900
+ input_ids=None,
901
+ position_ids=position_ids,
902
+ attention_mask=attention_mask,
903
+ past_key_values=past_key_values,
904
+ inputs_embeds=inputs_embeds,
905
+ use_cache=_sample_use_cache,
906
+ cache_position=cache_position,
907
+ visual_pos_masks=visual_pos_masks,
908
+ deepstack_visual_embeds=deepstack_visual_embeds,
909
+ output_hidden_states=False,
910
+ )
911
+
912
+ vlm_hidden_states = outputs.last_hidden_state
913
+ dit_encoder_attention_mask = attention_mask.bool() if attention_mask is not None else None
914
+
915
+ if self.use_mot_action_expert and outputs.past_key_values is not None:
916
+ x_t = self.dit_action_head.predict_action(
917
+ outputs.past_key_values,
918
+ action_dof_mask,
919
+ encoder_attention_mask=dit_encoder_attention_mask,
920
+ )
921
+ else:
922
+ dit_image_mask = visual_pos_masks.bool() if visual_pos_masks is not None else None
923
+ x_t = self.dit_action_head.predict_action(
924
+ vlm_hidden_states, action_dof_mask,
925
+ encoder_attention_mask=dit_encoder_attention_mask,
926
+ image_mask=dit_image_mask,
927
+ )
928
+
929
+ return x_t, outputs
930
+
931
+
932
+ PRTS_Qwen3VL.register_for_auto_class()
933
+
934
+
935
+ __all__ = ["PRTS_Qwen3VL", "PRTS_Qwen3VL_ModelOutputWithPast"]
modeling_qwen3_vl.py ADDED
@@ -0,0 +1,1645 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/qwen3_vl/modular_qwen3_vl.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_qwen3_vl.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from dataclasses import dataclass
23
+ from typing import Any, Callable, Optional, Union
24
+
25
+ import torch
26
+ import torch.distributed as dist
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+
30
+ from transformers.activations import ACT2FN
31
+ from transformers.cache_utils import Cache, DynamicCache
32
+ from transformers.generation import GenerationMixin
33
+ from transformers.integrations import use_kernel_forward_from_hub
34
+ from transformers.masking_utils import create_causal_mask
35
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
36
+ from transformers.modeling_layers import GradientCheckpointingLayer
37
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
38
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
39
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
40
+ from transformers.processing_utils import Unpack
41
+ from transformers.utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling
42
+ from transformers.utils.deprecation import deprecate_kwarg
43
+ from transformers.utils.generic import check_model_inputs
44
+ from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig
45
+ # 在文件头部导入
46
+
47
+ try:
48
+ from qwen_rope_kernel_2 import fused_qwen_rope as fused_qwen_rope_v2
49
+ HAS_QWEN_ROPE_V2 = True
50
+ except ImportError:
51
+ print("No qwen_rope_kernel_2 found")
52
+ HAS_QWEN_ROPE_V2 = False
53
+
54
+ try:
55
+ from fused_rmsnorm import RMSNormModelFunction as _FUSED_RMSFUNC
56
+ HAS_FUSED_RMSNORM = True
57
+ except ImportError:
58
+ print("No fused_rmsnorm found")
59
+ HAS_FUSED_RMSNORM = False
60
+
61
+
62
+ class Qwen3VLVisionMLP(nn.Module):
63
+ def __init__(self, config):
64
+ super().__init__()
65
+ self.hidden_size = config.hidden_size
66
+ self.intermediate_size = config.intermediate_size
67
+ self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
68
+ self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
69
+ self.act_fn = ACT2FN[config.hidden_act]
70
+
71
+ def forward(self, hidden_state):
72
+ return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
73
+
74
+
75
+ class Qwen3VLVisionPatchEmbed(nn.Module):
76
+ def __init__(self, config) -> None:
77
+ super().__init__()
78
+ self.patch_size = config.patch_size
79
+ self.temporal_patch_size = config.temporal_patch_size
80
+ self.in_channels = config.in_channels
81
+ self.embed_dim = config.hidden_size
82
+
83
+ kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
84
+ self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)
85
+
86
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
87
+ target_dtype = self.proj.weight.dtype
88
+ hidden_states = hidden_states.view(
89
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
90
+ )
91
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
92
+ return hidden_states
93
+
94
+
95
+ class Qwen3VLVisionRotaryEmbedding(nn.Module):
96
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
97
+
98
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
99
+ super().__init__()
100
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
101
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
102
+
103
+ def forward(self, seqlen: int) -> torch.Tensor:
104
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
105
+ freqs = torch.outer(seq, self.inv_freq)
106
+ return freqs
107
+
108
+
109
+ class Qwen3VLVisionPatchMerger(nn.Module):
110
+ def __init__(self, config: Qwen3VLVisionConfig, use_postshuffle_norm=False) -> None:
111
+ super().__init__()
112
+ self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
113
+ self.use_postshuffle_norm = use_postshuffle_norm
114
+ self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6)
115
+ self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
116
+ self.act_fn = nn.GELU()
117
+ self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
118
+
119
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
120
+ x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size)
121
+ x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
122
+ return x
123
+
124
+
125
+ def rotate_half(x):
126
+ """Rotates half the hidden dims of the input."""
127
+ x1 = x[..., : x.shape[-1] // 2]
128
+ x2 = x[..., x.shape[-1] // 2 :]
129
+ return torch.cat((-x2, x1), dim=-1)
130
+
131
+
132
+ def apply_rotary_pos_emb_vision(
133
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
134
+ ) -> tuple[torch.Tensor, torch.Tensor]:
135
+
136
+ if HAS_QWEN_ROPE_V2 and q.is_cuda and q.dtype == torch.bfloat16 and q.shape[-1] in (64, 128):
137
+ # qwen_rope_kernel_2 handles (S, D) cos/sin for (S, H, D) input naturally.
138
+ # The kernel REQUIRES cos/sin to be 2D [S, D] if input is 3D [S, H, D].
139
+ # It DOES NOT support 3D [S, 1, D] for cos/sin.
140
+
141
+ if cos.dtype != torch.float32:
142
+ cos = cos.to(torch.float32)
143
+ if sin.dtype != torch.float32:
144
+ sin = sin.to(torch.float32)
145
+
146
+ # Proactively squeeze [S, 1, D] -> [S, D] to satisfy kernel requirements
147
+ # This is a view operation, zero memory copy overhead.
148
+ if cos.ndim == 3 and cos.shape[1] == 1:
149
+ cos = cos.squeeze(1)
150
+ sin = sin.squeeze(1)
151
+
152
+ return fused_qwen_rope_v2(q, cos, sin), fused_qwen_rope_v2(k, cos, sin)
153
+
154
+ orig_q_dtype = q.dtype
155
+ orig_k_dtype = k.dtype
156
+ q, k = q.float(), k.float()
157
+ if cos.ndim == 2:
158
+ cos = cos.unsqueeze(-2)
159
+ sin = sin.unsqueeze(-2)
160
+ if cos.dtype != torch.float32:
161
+ cos = cos.to(torch.float32)
162
+ if sin.dtype != torch.float32:
163
+ sin = sin.to(torch.float32)
164
+ q_embed = (q * cos) + (rotate_half(q) * sin)
165
+ k_embed = (k * cos) + (rotate_half(k) * sin)
166
+ return q_embed.to(orig_q_dtype), k_embed.to(orig_k_dtype)
167
+
168
+
169
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
170
+ """
171
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
172
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
173
+ """
174
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
175
+ if n_rep == 1:
176
+ return hidden_states
177
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
178
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
179
+
180
+
181
+ def eager_attention_forward(
182
+ module: nn.Module,
183
+ query: torch.Tensor,
184
+ key: torch.Tensor,
185
+ value: torch.Tensor,
186
+ attention_mask: Optional[torch.Tensor],
187
+ scaling: float,
188
+ dropout: float = 0.0,
189
+ **kwargs: Unpack[TransformersKwargs],
190
+ ):
191
+ key_states = repeat_kv(key, module.num_key_value_groups)
192
+ value_states = repeat_kv(value, module.num_key_value_groups)
193
+
194
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
195
+ if attention_mask is not None:
196
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
197
+ attn_weights = attn_weights + causal_mask
198
+
199
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
200
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
201
+ attn_output = torch.matmul(attn_weights, value_states)
202
+ attn_output = attn_output.transpose(1, 2).contiguous()
203
+
204
+ return attn_output, attn_weights
205
+
206
+
207
+ class Qwen3VLVisionAttention(nn.Module):
208
+ def __init__(self, config: Qwen3VLVisionConfig) -> None:
209
+ super().__init__()
210
+ self.dim = config.hidden_size
211
+ self.num_heads = config.num_heads
212
+ self.head_dim = self.dim // self.num_heads
213
+ self.num_key_value_groups = 1 # needed for eager attention
214
+ self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
215
+ self.proj = nn.Linear(self.dim, self.dim)
216
+ self.scaling = self.head_dim**-0.5
217
+ self.config = config
218
+ self.attention_dropout = 0.0
219
+ self.is_causal = False
220
+
221
+ def forward(
222
+ self,
223
+ hidden_states: torch.Tensor,
224
+ cu_seqlens: torch.Tensor,
225
+ rotary_pos_emb: Optional[torch.Tensor] = None,
226
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
227
+ **kwargs,
228
+ ) -> torch.Tensor:
229
+ seq_length = hidden_states.shape[0]
230
+ query_states, key_states, value_states = (
231
+ self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
232
+ )
233
+ cos, sin = position_embeddings
234
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
235
+
236
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
237
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
238
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
239
+
240
+ attention_interface: Callable = eager_attention_forward
241
+ if self.config._attn_implementation != "eager":
242
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
243
+
244
+ if self.config._attn_implementation in ["flash_attention_2", "flash_attention_3"]:
245
+ # Flash Attention 2: Use cu_seqlens for variable length attention
246
+ if "image_max_seqlen" in kwargs and kwargs["image_max_seqlen"] is not None:
247
+ max_seqlen = kwargs["image_max_seqlen"]
248
+ else:
249
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
250
+
251
+ attn_output, _ = attention_interface(
252
+ self,
253
+ query_states,
254
+ key_states,
255
+ value_states,
256
+ attention_mask=None,
257
+ scaling=self.scaling,
258
+ dropout=0.0 if not self.training else self.attention_dropout,
259
+ cu_seq_lens_q=cu_seqlens,
260
+ cu_seq_lens_k=cu_seqlens,
261
+ max_length_q=max_seqlen,
262
+ max_length_k=max_seqlen,
263
+ is_causal=False,
264
+ **kwargs,
265
+ )
266
+ else:
267
+ # Other implementations: Process each chunk separately
268
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
269
+ splits = [
270
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
271
+ ]
272
+
273
+ attn_outputs = [
274
+ attention_interface(
275
+ self,
276
+ q,
277
+ k,
278
+ v,
279
+ attention_mask=None,
280
+ scaling=self.scaling,
281
+ dropout=0.0 if not self.training else self.attention_dropout,
282
+ is_causal=False,
283
+ **kwargs,
284
+ )[0]
285
+ for q, k, v in zip(*splits)
286
+ ]
287
+ attn_output = torch.cat(attn_outputs, dim=1)
288
+
289
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
290
+ attn_output = self.proj(attn_output)
291
+ return attn_output
292
+
293
+
294
+ class Qwen3VLVisionBlock(GradientCheckpointingLayer):
295
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
296
+ super().__init__()
297
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
298
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
299
+ self.attn = Qwen3VLVisionAttention(config=config)
300
+ self.mlp = Qwen3VLVisionMLP(config=config)
301
+
302
+ def forward(
303
+ self,
304
+ hidden_states: torch.Tensor,
305
+ cu_seqlens: torch.Tensor,
306
+ rotary_pos_emb: Optional[torch.Tensor] = None,
307
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
308
+ **kwargs,
309
+ ) -> torch.Tensor:
310
+ hidden_states = hidden_states + self.attn(
311
+ self.norm1(hidden_states),
312
+ cu_seqlens=cu_seqlens,
313
+ rotary_pos_emb=rotary_pos_emb,
314
+ position_embeddings=position_embeddings,
315
+ **kwargs,
316
+ )
317
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
318
+ return hidden_states
319
+
320
+
321
+ class Qwen3VLTextRotaryEmbedding(nn.Module):
322
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
323
+
324
+ def __init__(self, config: Qwen3VLTextConfig, device=None):
325
+ super().__init__()
326
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
327
+ self.rope_type = config.rope_scaling.get("rope_type", "default")
328
+ else:
329
+ self.rope_type = "default"
330
+ self.max_seq_len_cached = config.max_position_embeddings
331
+ self.original_max_seq_len = config.max_position_embeddings
332
+
333
+ self.config = config
334
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
335
+
336
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
337
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
338
+ self.original_inv_freq = self.inv_freq
339
+
340
+ self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20])
341
+
342
+ def apply_interleaved_mrope(self, freqs, mrope_section):
343
+ """Apply interleaved MRoPE to 3D rotary embeddings.
344
+ Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
345
+ interleaved [THTHWHTHW...TT], preserving frequency continuity.
346
+ args:
347
+ x: (3, bs, seq_len, head_dim // 2)
348
+ mrope_section: (3,)
349
+ returns:
350
+ x_t: (bs, seq_len, head_dim // 2)
351
+ """
352
+ freqs_t = freqs[0] # just overwrite the first dimension T
353
+ for dim, offset in enumerate((1, 2), start=1): # H, W
354
+ length = mrope_section[dim] * 3
355
+ idx = slice(offset, length, 3)
356
+ freqs_t[..., idx] = freqs[dim, ..., idx]
357
+ return freqs_t
358
+
359
+ @torch.no_grad()
360
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
361
+ def forward(self, x, position_ids):
362
+ if position_ids.ndim == 2:
363
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
364
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
365
+ device = inv_freq_expanded.device
366
+ position_ids_expanded = position_ids[:, :, None, :].float().to(device)
367
+ freqs = (inv_freq_expanded @ position_ids_expanded).transpose(2, 3)
368
+ freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
369
+ emb = torch.cat((freqs, freqs), dim=-1)
370
+ cos = emb.cos() * self.attention_scaling
371
+ sin = emb.sin() * self.attention_scaling
372
+ return cos.contiguous(), sin.contiguous()
373
+
374
+
375
+ @use_kernel_forward_from_hub("RMSNorm")
376
+ class Qwen3VLTextRMSNorm(nn.Module):
377
+ def __init__(self, hidden_size, eps: float = 1e-6) -> None:
378
+ """
379
+ Qwen3VLTextRMSNorm is equivalent to T5LayerNorm
380
+ """
381
+ super().__init__()
382
+ self.weight = nn.Parameter(torch.ones(hidden_size, dtype=torch.bfloat16))
383
+ self.variance_epsilon = eps
384
+
385
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
386
+ if HAS_FUSED_RMSNORM and hidden_states.is_cuda:
387
+ x = hidden_states if hidden_states.dtype == torch.bfloat16 else hidden_states.to(torch.bfloat16)
388
+ x = x.contiguous()
389
+ return _FUSED_RMSFUNC.apply(x, self.weight, self.variance_epsilon, self.weight.shape[0])
390
+ input_dtype = hidden_states.dtype
391
+ hidden_states = hidden_states.to(torch.float32)
392
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
393
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
394
+ return self.weight * hidden_states.to(input_dtype)
395
+
396
+ def extra_repr(self):
397
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
398
+
399
+
400
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
401
+ """Applies Rotary Position Embedding to the query and key tensors.
402
+
403
+ Args:
404
+ q (`torch.Tensor`): The query tensor.
405
+ k (`torch.Tensor`): The key tensor.
406
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
407
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
408
+ position_ids (`torch.Tensor`, *optional*):
409
+ Deprecated and unused.
410
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
411
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
412
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
413
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
414
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
415
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
416
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
417
+ Returns:
418
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
419
+ """
420
+ if HAS_QWEN_ROPE_V2 and q.is_cuda and q.dtype == torch.bfloat16 and q.shape[-1] in (64, 128):
421
+ # qwen_rope_kernel_2 handles (S, D) cos/sin for (S, H, D) input naturally.
422
+ if cos.dtype != torch.float32:
423
+ cos = cos.to(torch.float32)
424
+ if sin.dtype != torch.float32:
425
+ sin = sin.to(torch.float32)
426
+ return fused_qwen_rope_v2(q, cos, sin), fused_qwen_rope_v2(k, cos, sin)
427
+
428
+ if cos.ndim != q.ndim:
429
+ cos = cos.unsqueeze(unsqueeze_dim)
430
+ sin = sin.unsqueeze(unsqueeze_dim)
431
+ if cos.dtype != q.dtype:
432
+ cos = cos.to(q.dtype)
433
+ sin = sin.to(q.dtype)
434
+ q_embed = (q * cos) + (rotate_half(q) * sin)
435
+ k_embed = (k * cos) + (rotate_half(k) * sin)
436
+ return q_embed, k_embed
437
+
438
+
439
+ class Qwen3VLTextAttention(nn.Module):
440
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
441
+
442
+ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
443
+ super().__init__()
444
+ self.config = config
445
+ self.layer_idx = layer_idx
446
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
447
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
448
+ self.scaling = self.head_dim**-0.5
449
+ self.attention_dropout = config.attention_dropout
450
+ self.is_causal = True
451
+
452
+ self.q_proj = nn.Linear(
453
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
454
+ )
455
+ self.k_proj = nn.Linear(
456
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
457
+ )
458
+ self.v_proj = nn.Linear(
459
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
460
+ )
461
+ self.o_proj = nn.Linear(
462
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
463
+ )
464
+ self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
465
+ self.k_norm = Qwen3VLTextRMSNorm(
466
+ self.head_dim, eps=config.rms_norm_eps
467
+ ) # thus post q_norm does not need reshape
468
+
469
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
470
+ def forward(
471
+ self,
472
+ hidden_states: torch.Tensor,
473
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
474
+ attention_mask: Optional[torch.Tensor],
475
+ past_key_values: Optional[Cache] = None,
476
+ cache_position: Optional[torch.LongTensor] = None,
477
+ **kwargs: Unpack[FlashAttentionKwargs],
478
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
479
+ input_shape = hidden_states.shape[:-1]
480
+ hidden_shape = (*input_shape, -1, self.head_dim)
481
+
482
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
483
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
484
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
485
+
486
+ cos, sin = position_embeddings
487
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
488
+
489
+ if past_key_values is not None:
490
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
491
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
492
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
493
+
494
+ attention_interface: Callable = eager_attention_forward
495
+ if self.config._attn_implementation != "eager":
496
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
497
+
498
+ attn_output, attn_weights = attention_interface(
499
+ self,
500
+ query_states,
501
+ key_states,
502
+ value_states,
503
+ attention_mask,
504
+ dropout=0.0 if not self.training else self.attention_dropout,
505
+ scaling=self.scaling,
506
+ **kwargs,
507
+ )
508
+
509
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
510
+ attn_output = self.o_proj(attn_output)
511
+ return attn_output, attn_weights
512
+
513
+
514
+ class Qwen3VLTextMLP(nn.Module):
515
+ def __init__(self, config):
516
+ super().__init__()
517
+ self.config = config
518
+ self.hidden_size = config.hidden_size
519
+ self.intermediate_size = config.intermediate_size
520
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
521
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
522
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
523
+ self.act_fn = ACT2FN[config.hidden_act]
524
+
525
+ def forward(self, x):
526
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
527
+ return down_proj
528
+
529
+
530
+ class Qwen3VLTextDecoderLayer(GradientCheckpointingLayer):
531
+ def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
532
+ super().__init__()
533
+ self.hidden_size = config.hidden_size
534
+
535
+ self.self_attn = Qwen3VLTextAttention(config=config, layer_idx=layer_idx)
536
+
537
+ self.mlp = Qwen3VLTextMLP(config)
538
+ self.input_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
539
+ self.post_attention_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
540
+
541
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
542
+ def forward(
543
+ self,
544
+ hidden_states: torch.Tensor,
545
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
546
+ attention_mask: Optional[torch.Tensor] = None,
547
+ position_ids: Optional[torch.LongTensor] = None,
548
+ past_key_values: Optional[Cache] = None,
549
+ use_cache: Optional[bool] = False,
550
+ cache_position: Optional[torch.LongTensor] = None,
551
+ **kwargs: Unpack[TransformersKwargs],
552
+ ) -> torch.Tensor:
553
+ residual = hidden_states
554
+ hidden_states = self.input_layernorm(hidden_states)
555
+ # Self Attention. DEBUG: When we use packing mode, here we would enter `qwen3vl_forward` in `train_utils.py`
556
+ hidden_states, _ = self.self_attn(
557
+ hidden_states=hidden_states,
558
+ attention_mask=attention_mask,
559
+ position_ids=position_ids,
560
+ past_key_values=past_key_values,
561
+ use_cache=use_cache,
562
+ cache_position=cache_position,
563
+ position_embeddings=position_embeddings,
564
+ **kwargs,
565
+ )
566
+ hidden_states = residual + hidden_states
567
+
568
+ # Fully Connected
569
+ residual = hidden_states
570
+ hidden_states = self.post_attention_layernorm(hidden_states)
571
+ hidden_states = self.mlp(hidden_states)
572
+ hidden_states = residual + hidden_states
573
+ return hidden_states
574
+
575
+
576
+ @dataclass
577
+ @auto_docstring(
578
+ custom_intro="""
579
+ Base class for Llava outputs, with hidden states and attentions.
580
+ """
581
+ )
582
+ class Qwen3VLModelOutputWithPast(ModelOutput):
583
+ r"""
584
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
585
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
586
+
587
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
588
+ `past_key_values` input) to speed up sequential decoding.
589
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
590
+ The rope index difference between sequence length and multimodal rope.
591
+ """
592
+
593
+ last_hidden_state: Optional[torch.FloatTensor] = None
594
+ past_key_values: Optional[Cache] = None
595
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
596
+ attentions: Optional[tuple[torch.FloatTensor]] = None
597
+ rope_deltas: Optional[torch.LongTensor] = None
598
+
599
+
600
+ @auto_docstring
601
+ class Qwen3VLPreTrainedModel(PreTrainedModel):
602
+ config: Qwen3VLConfig
603
+ base_model_prefix = "model"
604
+ supports_gradient_checkpointing = True
605
+ _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]
606
+ _skip_keys_device_placement = "past_key_values"
607
+ _supports_flash_attn = True
608
+ _supports_sdpa = True
609
+
610
+ _can_compile_fullgraph = True
611
+ _supports_attention_backend = True
612
+ _can_record_outputs = {
613
+ "hidden_states": Qwen3VLTextDecoderLayer,
614
+ "attentions": Qwen3VLTextAttention,
615
+ }
616
+
617
+
618
+ class Qwen3VLVisionModel(Qwen3VLPreTrainedModel):
619
+ config: Qwen3VLVisionConfig
620
+ _no_split_modules = ["Qwen3VLVisionBlock"]
621
+
622
+ def __init__(self, config, *inputs, **kwargs) -> None:
623
+ super().__init__(config, *inputs, **kwargs)
624
+ self.spatial_merge_size = config.spatial_merge_size
625
+ self.patch_size = config.patch_size
626
+ self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
627
+
628
+ self.patch_embed = Qwen3VLVisionPatchEmbed(
629
+ config=config,
630
+ )
631
+
632
+ self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
633
+ self.num_grid_per_side = int(config.num_position_embeddings**0.5)
634
+
635
+ head_dim = config.hidden_size // config.num_heads
636
+ self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2)
637
+
638
+ self.blocks = nn.ModuleList([Qwen3VLVisionBlock(config) for _ in range(config.depth)])
639
+ self.merger = Qwen3VLVisionPatchMerger(
640
+ config=config,
641
+ use_postshuffle_norm=False,
642
+ )
643
+
644
+ self.deepstack_visual_indexes = config.deepstack_visual_indexes
645
+ self.deepstack_merger_list = nn.ModuleList(
646
+ [
647
+ Qwen3VLVisionPatchMerger(
648
+ config=config,
649
+ use_postshuffle_norm=True,
650
+ )
651
+ for _ in range(len(config.deepstack_visual_indexes))
652
+ ]
653
+ )
654
+
655
+ self.gradient_checkpointing = False
656
+
657
+ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
658
+ merge_size = self.spatial_merge_size
659
+
660
+ max_hw = int(grid_thw[:, 1:].max().item())
661
+ freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2)
662
+ device = freq_table.device
663
+
664
+ total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
665
+ # pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
666
+ pos_ids_cpu = torch.empty((total_tokens, 2) , dtype=torch.long , device="cpu")
667
+
668
+
669
+ offset = 0
670
+ for num_frames, height, width in grid_thw.numpy():
671
+ merged_h, merged_w = height // merge_size, width // merge_size
672
+
673
+ block_rows = torch.arange(merged_h, device="cpu") # block row indices
674
+ block_cols = torch.arange(merged_w, device="cpu") # block col indices
675
+ intra_row = torch.arange(merge_size, device="cpu") # intra-block row offsets
676
+ intra_col = torch.arange(merge_size, device="cpu") # intra-block col offsets
677
+
678
+ # Compute full-resolution positions
679
+ row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
680
+ col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
681
+
682
+ row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
683
+ col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
684
+
685
+ coords = torch.stack((row_idx, col_idx), dim=-1)
686
+
687
+ if num_frames > 1:
688
+ coords = coords.repeat(num_frames, 1)
689
+
690
+ num_tokens = coords.shape[0]
691
+ pos_ids_cpu[offset : offset + num_tokens] = coords
692
+ offset += num_tokens
693
+
694
+ pos_ids = pos_ids_cpu.to(device , non_blocking=True)
695
+ embeddings = freq_table[pos_ids] # lookup rotary embeddings
696
+ embeddings = embeddings.flatten(1)
697
+ return embeddings
698
+
699
+ def fast_pos_embed_interpolate(self, grid_thw):
700
+ # grid_thw 已经是 CPU Tensor,直接解包
701
+ grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
702
+
703
+ idx_accum = [[] for _ in range(4)]
704
+ weight_accum = [[] for _ in range(4)]
705
+
706
+ # 预取配置,避免循环内 getattr
707
+ num_grid = self.num_grid_per_side
708
+
709
+ # 这一步依然需要在 CPU 循环计算,因为 H/W 是变长的,但这只是纯算数,很快
710
+ for h, w in zip(grid_hs, grid_ws):
711
+
712
+ h_idxs = torch.linspace(0, num_grid - 1, h)
713
+ w_idxs = torch.linspace(0, num_grid - 1, w)
714
+
715
+ h_idxs_floor = h_idxs.int()
716
+ w_idxs_floor = w_idxs.int()
717
+
718
+
719
+ h_idxs_ceil = (h_idxs_floor + 1).clamp(max=num_grid - 1)
720
+ w_idxs_ceil = (w_idxs_floor + 1).clamp(max=num_grid - 1)
721
+
722
+ dh = h_idxs - h_idxs_floor
723
+ dw = w_idxs - w_idxs_floor
724
+
725
+ base_h = h_idxs_floor * num_grid
726
+ base_h_ceil = h_idxs_ceil * num_grid
727
+
728
+
729
+ indices = [
730
+ (base_h[:, None] + w_idxs_floor[None, :]).flatten(),
731
+ (base_h[:, None] + w_idxs_ceil[None, :]).flatten(),
732
+ (base_h_ceil[:, None] + w_idxs_floor[None, :]).flatten(),
733
+ (base_h_ceil[:, None] + w_idxs_ceil[None, :]).flatten(),
734
+ ]
735
+
736
+ weights = [
737
+ ((1 - dh)[:, None] * (1 - dw)[None, :]).flatten(),
738
+ ((1 - dh)[:, None] * dw[None, :]).flatten(),
739
+ (dh[:, None] * (1 - dw)[None, :]).flatten(),
740
+ (dh[:, None] * dw[None, :]).flatten(),
741
+ ]
742
+
743
+ # 直接 Append Tensor,不做 tolist()
744
+ for i in range(4):
745
+ idx_accum[i].append(indices[i])
746
+ weight_accum[i].append(weights[i])
747
+
748
+
749
+ target_device = self.pos_embed.weight.device
750
+ target_dtype = self.pos_embed.weight.dtype
751
+
752
+ idx_tensor = torch.stack([torch.cat(acc) for acc in idx_accum]).to(device=target_device, dtype=torch.long)
753
+ weight_tensor = torch.stack([torch.cat(acc) for acc in weight_accum]).to(device=target_device, dtype=target_dtype)
754
+
755
+
756
+ pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
757
+ patch_pos_embeds = pos_embeds.sum(dim=0)
758
+
759
+
760
+ merge_size = self.config.spatial_merge_size
761
+ indices_list = []
762
+ current_offset = 0
763
+
764
+
765
+ for t, h, w in zip(grid_ts.tolist(), grid_hs.tolist(), grid_ws.tolist()):
766
+
767
+ local_ids = torch.arange(h * w, device='cpu').view(h, w)
768
+
769
+
770
+ local_ids_permuted = (
771
+ local_ids.view(h // merge_size, merge_size, w // merge_size, merge_size)
772
+ .permute(0, 2, 1, 3)
773
+ .reshape(-1)
774
+ )
775
+
776
+
777
+ global_ids = local_ids_permuted + current_offset
778
+
779
+
780
+ if t > 1:
781
+ global_ids = global_ids.repeat(t)
782
+
783
+ indices_list.append(global_ids)
784
+ current_offset += h * w
785
+
786
+
787
+ all_indices = torch.cat(indices_list).to(target_device)
788
+
789
+
790
+ patch_pos_embeds = patch_pos_embeds[all_indices]
791
+
792
+ return patch_pos_embeds
793
+
794
+
795
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
796
+ """
797
+ Args:
798
+ hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
799
+ The final hidden states of the model.
800
+ grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
801
+ The temporal, height and width of feature shape of each image in LLM.
802
+
803
+ Returns:
804
+ `torch.Tensor`: hidden_states.
805
+ """
806
+ hidden_states = self.patch_embed(hidden_states)
807
+
808
+ #move grid_thw to cpu
809
+ grid_thw_cpu = grid_thw.cpu()
810
+
811
+ pos_embeds = self.fast_pos_embed_interpolate(grid_thw_cpu)
812
+ hidden_states = hidden_states + pos_embeds
813
+
814
+ rotary_pos_emb = self.rot_pos_emb(grid_thw_cpu)
815
+
816
+ seq_len, _ = hidden_states.size()
817
+ hidden_states = hidden_states.reshape(seq_len, -1)
818
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
819
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
820
+ cos = emb.cos().to(torch.float32).unsqueeze(-2).contiguous()
821
+ sin = emb.sin().to(torch.float32).unsqueeze(-2).contiguous()
822
+ cos = cos.to(device=hidden_states.device, non_blocking=True)
823
+ sin = sin.to(device=hidden_states.device, non_blocking=True)
824
+ position_embeddings = (cos, sin)
825
+
826
+ #use the grid_thw in gpu
827
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
828
+ dim=0,
829
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
830
+ )
831
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
832
+ cu_seqlens = cu_seqlens.to(device=hidden_states.device)
833
+
834
+
835
+ deepstack_feature_lists = []
836
+ for layer_num, blk in enumerate(self.blocks):
837
+ if self.gradient_checkpointing and self.training:
838
+ blk.gradient_checkpointing = False
839
+ def create_custom_forward(module):
840
+ def custom_forward(*inputs):
841
+ return module(inputs[0], inputs[1], inputs[2], inputs[3], **inputs[4])
842
+ return custom_forward
843
+
844
+ hidden_states = self._gradient_checkpointing_func(
845
+ create_custom_forward(blk),
846
+ hidden_states,
847
+ cu_seqlens,
848
+ None,
849
+ position_embeddings,
850
+ kwargs,
851
+ )
852
+ else:
853
+ hidden_states = blk(
854
+ hidden_states,
855
+ cu_seqlens=cu_seqlens,
856
+ position_embeddings=position_embeddings,
857
+ **kwargs,
858
+ )
859
+ if layer_num in self.deepstack_visual_indexes:
860
+ deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
861
+ hidden_states
862
+ )
863
+ deepstack_feature_lists.append(deepstack_feature)
864
+
865
+ hidden_states = self.merger(hidden_states)
866
+
867
+ return hidden_states, deepstack_feature_lists
868
+
869
+
870
+ @auto_docstring(
871
+ custom_intro=(
872
+ "Text part of Qwen3VL, "
873
+ "not a pure text-only model, as DeepStack integrates visual features into the early hidden states."
874
+ )
875
+ )
876
+ class Qwen3VLTextModel(Qwen3VLPreTrainedModel):
877
+ config: Qwen3VLTextConfig
878
+ _no_split_modules = ["Qwen3VLTextDecoderLayer"]
879
+
880
+ def __init__(self, config: Qwen3VLTextConfig):
881
+ super().__init__(config)
882
+ self.padding_idx = config.pad_token_id
883
+ self.vocab_size = config.vocab_size
884
+
885
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
886
+ self.layers = nn.ModuleList(
887
+ [Qwen3VLTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
888
+ )
889
+ self.norm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
890
+ self.rotary_emb = Qwen3VLTextRotaryEmbedding(config=config)
891
+ self.gradient_checkpointing = False
892
+
893
+ # Initialize weights and apply final processing
894
+ self.post_init()
895
+
896
+
897
+ def get_input_embeddings(self):
898
+ return self.embed_tokens
899
+
900
+ def set_input_embeddings(self, value):
901
+ self.embed_tokens = value
902
+
903
+ @check_model_inputs()
904
+ @auto_docstring
905
+ def forward(
906
+ self,
907
+ input_ids: Optional[torch.LongTensor] = None,
908
+ attention_mask: Optional[torch.Tensor] = None,
909
+ position_ids: Optional[torch.LongTensor] = None,
910
+ past_key_values: Optional[Cache] = None,
911
+ inputs_embeds: Optional[torch.FloatTensor] = None,
912
+ use_cache: Optional[bool] = None,
913
+ cache_position: Optional[torch.LongTensor] = None,
914
+ # args for deepstack
915
+ visual_pos_masks: Optional[torch.Tensor] = None,
916
+ deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
917
+ **kwargs: Unpack[FlashAttentionKwargs],
918
+ ) -> Union[tuple, BaseModelOutputWithPast]:
919
+ r"""
920
+ visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*):
921
+ The mask of the visual positions.
922
+ deepstack_visual_embeds (`list[torch.Tensor]`, *optional*):
923
+ The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim).
924
+ The feature is extracted from the different visual encoder layers, and fed to the decoder
925
+ hidden states. It's from the paper DeepStack(https://arxiv.org/abs/2406.04334).
926
+ """
927
+ if (input_ids is None) ^ (inputs_embeds is not None):
928
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
929
+
930
+ # torch.jit.trace() doesn't support cache objects in the output
931
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
932
+ past_key_values = DynamicCache(config=self.config)
933
+
934
+ if inputs_embeds is None:
935
+ inputs_embeds = self.embed_tokens(input_ids)
936
+
937
+ if cache_position is None:
938
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
939
+ cache_position = torch.arange(
940
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
941
+ )
942
+
943
+ # the hard coded `3` is for temporal, height and width.
944
+ if position_ids is None:
945
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) # (3, bs, seq_length)
946
+ elif position_ids.ndim == 2:
947
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
948
+
949
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
950
+ text_position_ids = position_ids[0]
951
+ position_ids = position_ids[1:]
952
+ else:
953
+ text_position_ids = position_ids[0]
954
+ # NOTE: Attention! When we use packing mode, this `create_causal_mask` is overwrited, and directly return `attention_mask`.
955
+ attention_mask = create_causal_mask(
956
+ config=self.config,
957
+ input_embeds=inputs_embeds,
958
+ attention_mask=attention_mask,
959
+ cache_position=cache_position,
960
+ past_key_values=past_key_values,
961
+ position_ids=text_position_ids,
962
+ )
963
+
964
+ hidden_states = inputs_embeds
965
+
966
+ # create position embeddings to be shared across the decoder layers
967
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
968
+ cos, sin = position_embeddings
969
+ cos = cos.to(device=hidden_states.device, non_blocking=True).unsqueeze(1).contiguous()
970
+ sin = sin.to(device=hidden_states.device, non_blocking=True).unsqueeze(1).contiguous()
971
+ position_embeddings = (cos, sin)
972
+
973
+ # decoder layers
974
+ for layer_idx, decoder_layer in enumerate(self.layers):
975
+ if self.gradient_checkpointing and self.training:
976
+ decoder_layer.gradient_checkpointing = False
977
+ def create_custom_forward(module): # DEBUG: Here we enter the Qwen3VLTextDecoderLayer forward
978
+ def custom_forward(*inputs):
979
+ # inputs: hidden_states, position_embeddings, attention_mask, position_ids, past_key_values, use_cache, cache_position, kwargs_dict
980
+ return module(
981
+ inputs[0],
982
+ inputs[1],
983
+ attention_mask=inputs[2],
984
+ position_ids=inputs[3],
985
+ past_key_values=inputs[4],
986
+ use_cache=inputs[5],
987
+ cache_position=inputs[6],
988
+ **inputs[7]
989
+ )
990
+ return custom_forward
991
+
992
+ layer_outputs = self._gradient_checkpointing_func(
993
+ create_custom_forward(decoder_layer),
994
+ hidden_states,
995
+ position_embeddings,
996
+ attention_mask,
997
+ text_position_ids,
998
+ past_key_values,
999
+ False, # use_cache
1000
+ cache_position,
1001
+ kwargs,
1002
+ )
1003
+ else:
1004
+ layer_outputs = decoder_layer(
1005
+ hidden_states,
1006
+ attention_mask=attention_mask,
1007
+ position_ids=text_position_ids,
1008
+ past_key_values=past_key_values,
1009
+ cache_position=cache_position,
1010
+ position_embeddings=position_embeddings,
1011
+ **kwargs,
1012
+ )
1013
+ hidden_states = layer_outputs
1014
+
1015
+ # add visual features to the hidden states of first several layers
1016
+ if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)):
1017
+ hidden_states = self._deepstack_process(
1018
+ hidden_states,
1019
+ visual_pos_masks,
1020
+ deepstack_visual_embeds[layer_idx],
1021
+ )
1022
+
1023
+ hidden_states = self.norm(hidden_states)
1024
+
1025
+ return BaseModelOutputWithPast(
1026
+ last_hidden_state=hidden_states,
1027
+ past_key_values=past_key_values,
1028
+ )
1029
+
1030
+ def _deepstack_process(
1031
+ self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor
1032
+ ):
1033
+ visual_pos_masks = visual_pos_masks.to(hidden_states.device)
1034
+ visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
1035
+ local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds
1036
+ hidden_states[visual_pos_masks, :] = local_this
1037
+ return hidden_states
1038
+
1039
+
1040
+ @dataclass
1041
+ @auto_docstring(
1042
+ custom_intro="""
1043
+ Base class for Qwen3VL causal language model (or autoregressive) outputs.
1044
+ """
1045
+ )
1046
+ class Qwen3VLCausalLMOutputWithPast(ModelOutput):
1047
+ r"""
1048
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
1049
+ Language modeling loss (for next-token prediction).
1050
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
1051
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
1052
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1053
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
1054
+
1055
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
1056
+ `past_key_values` input) to speed up sequential decoding.
1057
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1058
+ The rope index difference between sequence length and multimodal rope.
1059
+ """
1060
+
1061
+ loss: Optional[torch.FloatTensor] = None
1062
+ logits: Optional[torch.FloatTensor] = None
1063
+ past_key_values: Optional[Cache] = None
1064
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
1065
+ attentions: Optional[tuple[torch.FloatTensor]] = None
1066
+ rope_deltas: Optional[torch.LongTensor] = None
1067
+
1068
+
1069
+ class Qwen3VLForConditionalGeneration(Qwen3VLPreTrainedModel, GenerationMixin):
1070
+ _checkpoint_conversion_mapping = {}
1071
+ _tied_weights_keys = ["lm_head.weight"]
1072
+ # Reference: fix gemma3 grad acc #37208
1073
+ accepts_loss_kwargs = False
1074
+ config: Qwen3VLConfig
1075
+ _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]
1076
+
1077
+ def __init__(self, config):
1078
+ super().__init__(config)
1079
+ # Directly initialize visual and language_model instead of using Qwen3VLModel
1080
+ self.visual = Qwen3VLVisionModel._from_config(config.vision_config)
1081
+ self.language_model = Qwen3VLTextModel._from_config(config.text_config)
1082
+ self.rope_deltas = None # cache rope_deltas here
1083
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
1084
+
1085
+ self.post_init()
1086
+
1087
+ def get_input_embeddings(self):
1088
+ return self.language_model.get_input_embeddings()
1089
+
1090
+ def set_input_embeddings(self, value):
1091
+ self.language_model.set_input_embeddings(value)
1092
+
1093
+ def set_decoder(self, decoder):
1094
+ self.language_model = decoder
1095
+
1096
+ def get_decoder(self):
1097
+ return self.language_model
1098
+
1099
+ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
1100
+ self.gradient_checkpointing = True
1101
+ self.visual.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
1102
+ self.language_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
1103
+
1104
+
1105
+ def get_rope_index(
1106
+ self,
1107
+ input_ids: Optional[torch.LongTensor] = None,
1108
+ image_grid_thw: Optional[torch.LongTensor] = None,
1109
+ video_grid_thw: Optional[torch.LongTensor] = None,
1110
+ attention_mask: Optional[torch.Tensor] = None,
1111
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1112
+ """Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids."""
1113
+
1114
+ # Since we use timestamps to seperate videos, like <t1> <vision_start> <frame1> <vision_end> <t2> <vision_start> <frame2> <vision_end>, the video_grid_thw should also be split
1115
+ if video_grid_thw is not None:
1116
+ video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
1117
+ video_grid_thw[:, 0] = 1
1118
+
1119
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
1120
+ image_token_id = self.config.image_token_id
1121
+ video_token_id = self.config.video_token_id
1122
+ vision_start_token_id = self.config.vision_start_token_id
1123
+ mrope_position_deltas = []
1124
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
1125
+ total_input_ids = input_ids
1126
+ if attention_mask is None:
1127
+ attention_mask = torch.ones_like(total_input_ids)
1128
+ position_ids = torch.ones(
1129
+ 3,
1130
+ input_ids.shape[0],
1131
+ input_ids.shape[1],
1132
+ dtype=input_ids.dtype,
1133
+ device=input_ids.device,
1134
+ )
1135
+ image_index, video_index = 0, 0
1136
+ attention_mask = attention_mask.to(total_input_ids.device)
1137
+ for i, input_ids in enumerate(total_input_ids):
1138
+ input_ids = input_ids[attention_mask[i] == 1]
1139
+ image_nums, video_nums = 0, 0
1140
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
1141
+ vision_tokens = input_ids[vision_start_indices + 1]
1142
+ image_nums = (vision_tokens == image_token_id).sum()
1143
+ video_nums = (vision_tokens == video_token_id).sum()
1144
+ input_tokens = input_ids.tolist()
1145
+ llm_pos_ids_list: list = []
1146
+ st = 0
1147
+ remain_images, remain_videos = image_nums, video_nums
1148
+ for _ in range(image_nums + video_nums):
1149
+ if image_token_id in input_tokens and remain_images > 0:
1150
+ ed_image = input_tokens.index(image_token_id, st)
1151
+ else:
1152
+ ed_image = len(input_tokens) + 1
1153
+ if video_token_id in input_tokens and remain_videos > 0:
1154
+ ed_video = input_tokens.index(video_token_id, st)
1155
+ else:
1156
+ ed_video = len(input_tokens) + 1
1157
+ if ed_image < ed_video:
1158
+ t, h, w = (
1159
+ image_grid_thw[image_index][0],
1160
+ image_grid_thw[image_index][1],
1161
+ image_grid_thw[image_index][2],
1162
+ )
1163
+ image_index += 1
1164
+ remain_images -= 1
1165
+ ed = ed_image
1166
+
1167
+ else:
1168
+ t, h, w = (
1169
+ video_grid_thw[video_index][0],
1170
+ video_grid_thw[video_index][1],
1171
+ video_grid_thw[video_index][2],
1172
+ )
1173
+ video_index += 1
1174
+ remain_videos -= 1
1175
+ ed = ed_video
1176
+ llm_grid_t, llm_grid_h, llm_grid_w = (
1177
+ t.item(),
1178
+ h.item() // spatial_merge_size,
1179
+ w.item() // spatial_merge_size,
1180
+ )
1181
+ text_len = ed - st
1182
+
1183
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1184
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1185
+
1186
+ # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos)
1187
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
1188
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
1189
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
1190
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
1191
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
1192
+
1193
+ if st < len(input_tokens):
1194
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1195
+ text_len = len(input_tokens) - st
1196
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1197
+
1198
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
1199
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
1200
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
1201
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
1202
+ return position_ids, mrope_position_deltas
1203
+ else:
1204
+ if attention_mask is not None:
1205
+ position_ids = attention_mask.long().cumsum(-1) - 1
1206
+ position_ids.masked_fill_(attention_mask == 0, 1)
1207
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
1208
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
1209
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
1210
+ else:
1211
+ position_ids = (
1212
+ torch.arange(input_ids.shape[1], device=input_ids.device)
1213
+ .view(1, 1, -1)
1214
+ .expand(3, input_ids.shape[0], -1)
1215
+ )
1216
+ mrope_position_deltas = torch.zeros(
1217
+ [input_ids.shape[0], 1],
1218
+ device=input_ids.device,
1219
+ dtype=input_ids.dtype,
1220
+ )
1221
+
1222
+ return position_ids, mrope_position_deltas
1223
+
1224
+ def get_video_features(
1225
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
1226
+ ):
1227
+ """
1228
+ Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned.
1229
+
1230
+ Args:
1231
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1232
+ The tensors corresponding to the input videos.
1233
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1234
+ The temporal, height and width of feature shape of each video in LLM.
1235
+ """
1236
+ # Same implementation as for images
1237
+ return self.get_image_features(pixel_values_videos, video_grid_thw)
1238
+
1239
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None, **kwargs):
1240
+ """
1241
+ Encodes images into continuous embeddings that can be forwarded to the language model.
1242
+
1243
+ Args:
1244
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1245
+ The tensors corresponding to the input images.
1246
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1247
+ The temporal, height and width of feature shape of each image in LLM.
1248
+ """
1249
+ pixel_values = pixel_values.type(self.visual.dtype)
1250
+ image_embeds, deepstack_feature_lists = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs)
1251
+ split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
1252
+ image_embeds = torch.split(image_embeds, split_sizes)
1253
+ return image_embeds, deepstack_feature_lists
1254
+
1255
+ def get_placeholder_mask(
1256
+ self,
1257
+ input_ids: torch.LongTensor,
1258
+ inputs_embeds: torch.FloatTensor,
1259
+ image_features: Optional[torch.FloatTensor] = None,
1260
+ video_features: Optional[torch.FloatTensor] = None,
1261
+ ):
1262
+ """
1263
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
1264
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
1265
+ """
1266
+ if input_ids is None:
1267
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
1268
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
1269
+ )
1270
+ special_image_mask = special_image_mask.all(-1)
1271
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
1272
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
1273
+ )
1274
+ special_video_mask = special_video_mask.all(-1)
1275
+ else:
1276
+ special_image_mask = input_ids == self.config.image_token_id
1277
+ special_video_mask = input_ids == self.config.video_token_id
1278
+
1279
+ n_image_tokens = special_image_mask.sum()
1280
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
1281
+ if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
1282
+ raise ValueError(
1283
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
1284
+ )
1285
+
1286
+ n_video_tokens = special_video_mask.sum()
1287
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
1288
+ if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
1289
+ raise ValueError(
1290
+ f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
1291
+ )
1292
+
1293
+ return special_image_mask, special_video_mask
1294
+
1295
+ @check_model_inputs()
1296
+ def forward(
1297
+ self,
1298
+ input_ids: torch.LongTensor = None,
1299
+ attention_mask: Optional[torch.Tensor] = None,
1300
+ position_ids: Optional[torch.LongTensor] = None,
1301
+ past_key_values: Optional[Cache] = None,
1302
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1303
+ labels: Optional[torch.LongTensor] = None,
1304
+ pixel_values: Optional[torch.Tensor] = None,
1305
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
1306
+ image_grid_thw: Optional[torch.LongTensor] = None,
1307
+ video_grid_thw: Optional[torch.LongTensor] = None,
1308
+ cache_position: Optional[torch.LongTensor] = None,
1309
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1310
+ **kwargs: Unpack[TransformersKwargs],
1311
+ ) -> Union[tuple, Qwen3VLCausalLMOutputWithPast]:
1312
+ r"""
1313
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1314
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1315
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1316
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1317
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1318
+ The temporal, height and width of feature shape of each image in LLM.
1319
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1320
+ The temporal, height and width of feature shape of each video in LLM.
1321
+
1322
+ Example:
1323
+ TODO: Add example
1324
+ """
1325
+ # Inlined from Qwen3VLModel.forward
1326
+ if (input_ids is None) ^ (inputs_embeds is not None):
1327
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1328
+
1329
+ if inputs_embeds is None:
1330
+ inputs_embeds = self.get_input_embeddings()(input_ids)
1331
+
1332
+ image_mask = None
1333
+ video_mask = None
1334
+
1335
+ if pixel_values is not None:
1336
+ image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw, image_max_seqlen=kwargs.get("image_max_seqlen"))
1337
+ image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
1338
+
1339
+ image_mask, _ = self.get_placeholder_mask(
1340
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
1341
+ )
1342
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
1343
+
1344
+
1345
+ if pixel_values_videos is not None:
1346
+ video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
1347
+ video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
1348
+
1349
+ _, video_mask = self.get_placeholder_mask(
1350
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
1351
+ )
1352
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
1353
+
1354
+
1355
+ visual_pos_masks = None
1356
+ deepstack_visual_embeds = None
1357
+ if image_mask is not None and video_mask is not None:
1358
+ # aggregate visual_pos_masks and deepstack_visual_embeds
1359
+ image_mask = image_mask[..., 0]
1360
+ video_mask = video_mask[..., 0]
1361
+ visual_pos_masks = image_mask | video_mask
1362
+ deepstack_visual_embeds = []
1363
+ image_mask_joint = image_mask[visual_pos_masks]
1364
+ video_mask_joint = video_mask[visual_pos_masks]
1365
+ for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
1366
+ embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
1367
+ embed_joint[image_mask_joint, :] = img_embed
1368
+ embed_joint[video_mask_joint, :] = vid_embed
1369
+ deepstack_visual_embeds.append(embed_joint)
1370
+ elif image_mask is not None:
1371
+ image_mask = image_mask[..., 0]
1372
+ visual_pos_masks = image_mask
1373
+ deepstack_visual_embeds = deepstack_image_embeds
1374
+ elif video_mask is not None:
1375
+ video_mask = video_mask[..., 0]
1376
+ visual_pos_masks = video_mask
1377
+ deepstack_visual_embeds = deepstack_video_embeds
1378
+
1379
+ if position_ids is None:
1380
+ attention_mask_tensor = (
1381
+ attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
1382
+ )
1383
+ if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
1384
+ attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
1385
+ # Only apply conversion for floating point tensors (inverted masks)
1386
+ if attention_mask_tensor.dtype.is_floating_point:
1387
+ attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
1388
+ attention_mask_tensor = (1.0 - attention_mask_tensor).int()
1389
+
1390
+ # Calculate RoPE index once per generation in the pre-fill stage only.
1391
+ # When compiling, we can't check tensor values thus we check only input length
1392
+ # It is safe to assume that `length!=1` means we're in pre-fill because compiled
1393
+ # models currently cannot do asssisted decoding
1394
+ prefill_compiled_stage = is_torchdynamo_compiling() and (
1395
+ (input_ids is not None and input_ids.shape[1] != 1)
1396
+ or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
1397
+ )
1398
+ prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
1399
+ (cache_position is not None and cache_position[0] == 0)
1400
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
1401
+ )
1402
+ if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
1403
+ position_ids, rope_deltas = self.get_rope_index(
1404
+ input_ids,
1405
+ image_grid_thw,
1406
+ video_grid_thw,
1407
+ attention_mask=attention_mask_tensor,
1408
+ )
1409
+ self.rope_deltas = rope_deltas
1410
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
1411
+ else:
1412
+ batch_size, seq_length, _ = inputs_embeds.shape
1413
+ delta = (
1414
+ (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
1415
+ if cache_position is not None
1416
+ else 0
1417
+ )
1418
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
1419
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
1420
+ if cache_position is not None: # otherwise `deltas` is an int `0`
1421
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
1422
+ position_ids = position_ids.add(delta)
1423
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
1424
+
1425
+ if kwargs.get("max_seqlen") is not None:
1426
+ try:
1427
+ self.language_model.config.max_seqlen = int(kwargs.get("max_seqlen"))
1428
+ except Exception:
1429
+ self.language_model.config.max_seqlen = kwargs.get("max_seqlen")
1430
+
1431
+ outputs = self.language_model(
1432
+ input_ids=None,
1433
+ position_ids=position_ids,
1434
+ attention_mask=attention_mask,
1435
+ past_key_values=past_key_values,
1436
+ inputs_embeds=inputs_embeds,
1437
+ cache_position=cache_position,
1438
+ visual_pos_masks=visual_pos_masks,
1439
+ deepstack_visual_embeds=deepstack_visual_embeds,
1440
+ **kwargs,
1441
+ )
1442
+
1443
+ hidden_states = outputs[0]
1444
+
1445
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1446
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1447
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1448
+
1449
+ loss = None
1450
+ if labels is not None:
1451
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
1452
+
1453
+ return Qwen3VLCausalLMOutputWithPast(
1454
+ loss=loss,
1455
+ logits=logits,
1456
+ past_key_values=outputs.past_key_values,
1457
+ rope_deltas=self.rope_deltas,
1458
+ )
1459
+
1460
+ def prepare_inputs_for_generation(
1461
+ self,
1462
+ input_ids,
1463
+ past_key_values=None,
1464
+ attention_mask=None,
1465
+ inputs_embeds=None,
1466
+ cache_position=None,
1467
+ position_ids=None,
1468
+ use_cache=True,
1469
+ pixel_values=None,
1470
+ pixel_values_videos=None,
1471
+ image_grid_thw=None,
1472
+ video_grid_thw=None,
1473
+ **kwargs,
1474
+ ):
1475
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
1476
+
1477
+ model_inputs = super().prepare_inputs_for_generation(
1478
+ input_ids,
1479
+ past_key_values=past_key_values,
1480
+ attention_mask=attention_mask,
1481
+ inputs_embeds=inputs_embeds,
1482
+ cache_position=cache_position,
1483
+ position_ids=position_ids,
1484
+ pixel_values=pixel_values,
1485
+ pixel_values_videos=pixel_values_videos,
1486
+ image_grid_thw=image_grid_thw,
1487
+ video_grid_thw=video_grid_thw,
1488
+ use_cache=use_cache,
1489
+ **kwargs,
1490
+ )
1491
+
1492
+ # Qwen3VL position_ids are prepareed with rope_deltas in forward
1493
+ model_inputs["position_ids"] = None
1494
+
1495
+ if cache_position[0] != 0:
1496
+ model_inputs["pixel_values"] = None
1497
+ model_inputs["pixel_values_videos"] = None
1498
+
1499
+ return model_inputs
1500
+
1501
+ def _get_image_nums_and_video_nums(
1502
+ self,
1503
+ input_ids: Optional[torch.LongTensor],
1504
+ inputs_embeds: Optional[torch.Tensor] = None,
1505
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1506
+ """
1507
+ Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
1508
+ These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
1509
+
1510
+ Args:
1511
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1512
+ Indices of input sequence tokens in the vocabulary.
1513
+
1514
+ Returns:
1515
+ image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
1516
+ video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
1517
+ """
1518
+ image_token_id = self.config.image_token_id
1519
+ video_token_id = self.config.video_token_id
1520
+ vision_start_token_id = self.config.vision_start_token_id
1521
+
1522
+ if inputs_embeds is not None:
1523
+ vision_start_mask = (
1524
+ inputs_embeds
1525
+ == self.get_input_embeddings()(
1526
+ torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
1527
+ )
1528
+ )[..., 0]
1529
+ image_mask = (
1530
+ inputs_embeds
1531
+ == self.get_input_embeddings()(
1532
+ torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
1533
+ )
1534
+ )[..., 0]
1535
+ video_mask = (
1536
+ inputs_embeds
1537
+ == self.get_input_embeddings()(
1538
+ torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
1539
+ )
1540
+ )[..., 0]
1541
+ else:
1542
+ vision_start_mask = input_ids == vision_start_token_id
1543
+ image_mask = input_ids == image_token_id
1544
+ video_mask = input_ids == video_token_id
1545
+
1546
+ vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
1547
+ image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
1548
+ video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
1549
+
1550
+ return image_nums, video_nums
1551
+
1552
+ def _expand_inputs_for_generation(
1553
+ self,
1554
+ expand_size: int = 1,
1555
+ is_encoder_decoder: bool = False,
1556
+ input_ids: Optional[torch.LongTensor] = None,
1557
+ **model_kwargs,
1558
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
1559
+ # Overwritten -- Support for expanding tensors without a batch size dimension
1560
+ # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
1561
+ # pixel_values.shape[0] is sum(seqlen_images for samples)
1562
+ # image_grid_thw.shape[0] is sum(num_images for samples)
1563
+
1564
+ if expand_size == 1:
1565
+ return input_ids, model_kwargs
1566
+
1567
+ visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
1568
+
1569
+ def _expand_dict_for_generation_visual(dict_to_expand):
1570
+ image_grid_thw = model_kwargs.get("image_grid_thw", None)
1571
+ video_grid_thw = model_kwargs.get("video_grid_thw", None)
1572
+ image_nums, video_nums = self._get_image_nums_and_video_nums(
1573
+ input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
1574
+ )
1575
+
1576
+ def _repeat_interleave_samples(x, lengths, repeat_times):
1577
+ samples = torch.split(x, lengths)
1578
+ repeat_args = [repeat_times] + [1] * (x.dim() - 1)
1579
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
1580
+ return result
1581
+
1582
+ for key in dict_to_expand:
1583
+ if key == "pixel_values":
1584
+ # split images into samples
1585
+ samples = torch.split(image_grid_thw, list(image_nums))
1586
+ # compute the sequence length of images for each sample
1587
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
1588
+ dict_to_expand[key] = _repeat_interleave_samples(
1589
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1590
+ )
1591
+ elif key == "image_grid_thw":
1592
+ # get the num of images for each sample
1593
+ lengths = list(image_nums)
1594
+ dict_to_expand[key] = _repeat_interleave_samples(
1595
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1596
+ )
1597
+ elif key == "pixel_values_videos":
1598
+ samples = torch.split(video_grid_thw, list(video_nums))
1599
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
1600
+ dict_to_expand[key] = _repeat_interleave_samples(
1601
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1602
+ )
1603
+ elif key == "video_grid_thw":
1604
+ lengths = list(video_nums)
1605
+ dict_to_expand[key] = _repeat_interleave_samples(
1606
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1607
+ )
1608
+ elif key == "second_per_grid_ts":
1609
+ dict_to_expand[key] = _repeat_interleave_samples(
1610
+ dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
1611
+ )
1612
+ return dict_to_expand
1613
+
1614
+ def _expand_dict_for_generation(dict_to_expand):
1615
+ for key in dict_to_expand:
1616
+ if (
1617
+ key != "cache_position"
1618
+ and dict_to_expand[key] is not None
1619
+ and isinstance(dict_to_expand[key], torch.Tensor)
1620
+ and key not in visual_keys
1621
+ ):
1622
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
1623
+ return dict_to_expand
1624
+
1625
+ model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
1626
+
1627
+ if input_ids is not None:
1628
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
1629
+
1630
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
1631
+
1632
+ if is_encoder_decoder:
1633
+ if model_kwargs.get("encoder_outputs") is None:
1634
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
1635
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
1636
+
1637
+ return input_ids, model_kwargs
1638
+
1639
+
1640
+ __all__ = [
1641
+ "Qwen3VLVisionModel",
1642
+ "Qwen3VLForConditionalGeneration",
1643
+ "Qwen3VLPreTrainedModel",
1644
+ "Qwen3VLTextModel",
1645
+ ]
preprocessor_config.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": null,
3
+ "data_format": "channels_first",
4
+ "default_to_square": true,
5
+ "device": null,
6
+ "disable_grouping": null,
7
+ "do_center_crop": null,
8
+ "do_convert_rgb": true,
9
+ "do_normalize": true,
10
+ "do_pad": null,
11
+ "do_rescale": true,
12
+ "do_resize": true,
13
+ "image_mean": [
14
+ 0.5,
15
+ 0.5,
16
+ 0.5
17
+ ],
18
+ "image_processor_type": "Qwen2VLImageProcessorFast",
19
+ "image_std": [
20
+ 0.5,
21
+ 0.5,
22
+ 0.5
23
+ ],
24
+ "input_data_format": null,
25
+ "max_pixels": 147456,
26
+ "merge_size": 2,
27
+ "min_pixels": 65536,
28
+ "pad_size": null,
29
+ "patch_size": 16,
30
+ "processor_class": "PRTS_Qwen3VLProcessor",
31
+ "resample": 3,
32
+ "rescale_factor": 0.00392156862745098,
33
+ "return_tensors": null,
34
+ "size": {
35
+ "longest_edge": 147456,
36
+ "shortest_edge": 65536
37
+ },
38
+ "temporal_patch_size": 2
39
+ }
processing_prts_qwen3_vl.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 TeleAI Rhodes Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Processor for PRTS built on Qwen3-VL (hub / trust_remote_code; no prts package required)."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import logging
20
+ from typing import Optional, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ from transformers.feature_extraction_utils import BatchFeature
25
+ from transformers.image_utils import ImageInput
26
+ from transformers.processing_utils import (
27
+ ImagesKwargs,
28
+ MultiModalData,
29
+ ProcessingKwargs,
30
+ ProcessorMixin,
31
+ Unpack,
32
+ VideosKwargs,
33
+ )
34
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
35
+ from transformers.utils.logging import get_logger
36
+ from transformers.video_utils import VideoInput
37
+
38
+ ACTION_START_TOKEN = "<|action_start|>"
39
+ ACTION_PLACEHOLDER_TOKEN = "<|action_pad|>"
40
+ ACTION_END_TOKEN = "<|action_end|>"
41
+ CRL_GOAL_REPR_TOKEN = "<|goal_repr|>"
42
+ CRL_OBS_REPR_TOKEN = "<|obs_repr|>"
43
+ VISION_START_TOKEN = "<|vision_start|>" # beginning of vision input
44
+ IMAGE_PLACEHOLDER_TOKEN = "<|image_pad|>" # image placeholder
45
+ VIDEO_PLACEHOLDER_TOKEN = "<|video_pad|>" # video placeholder
46
+
47
+ logger = get_logger(__name__)
48
+ if not logger.handlers:
49
+ handler = logging.StreamHandler()
50
+ handler.setLevel(logging.INFO)
51
+ handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
52
+ logger.addHandler(handler)
53
+
54
+
55
+ class Qwen3VLVideosProcessorKwargs(VideosKwargs, total=False):
56
+ pass
57
+
58
+
59
+ class Qwen3VLImagesKwargs(ImagesKwargs):
60
+ min_pixels: Optional[int]
61
+ max_pixels: Optional[int]
62
+ patch_size: Optional[int]
63
+ temporal_patch_size: Optional[int]
64
+ merge_size: Optional[int]
65
+
66
+
67
+ class Qwen3VLProcessorKwargs(ProcessingKwargs, total=False):
68
+ images_kwargs: Qwen3VLImagesKwargs
69
+ videos_kwargs: Qwen3VLVideosProcessorKwargs
70
+ _defaults = {
71
+ "text_kwargs": {
72
+ "padding": False,
73
+ "return_token_type_ids": False,
74
+ "return_mm_token_type_ids": False,
75
+ },
76
+ "videos_kwargs": {"return_metadata": True},
77
+ }
78
+
79
+
80
+ class PRTS_Qwen3VLProcessor(ProcessorMixin):
81
+ r"""
82
+ Constructs a PRTS processor which wraps a Qwen3-VL image processor and a Qwen2 tokenizer into a single processor.
83
+
84
+ This processor is built independently (not inheriting from Qwen3VLProcessor) to avoid tight coupling,
85
+ while maintaining compatibility with Qwen3-VL's timestamp-based video processing approach.
86
+
87
+ [`PRTS_Qwen3VLProcessor`] offers all the functionalities needed for PRTS model with:
88
+ - Action token handling (discrete and continuous)
89
+ - State token handling for proprioceptive inputs
90
+ - Expert trigger tokens for flow matching action prediction
91
+ - Qwen3-VL compatible image/video processing with timestamp-based video handling
92
+
93
+ Args:
94
+ image_processor ([`Qwen2VLImageProcessor`], *optional*):
95
+ The image processor is a required input.
96
+ tokenizer ([`Qwen2TokenizerFast`], *optional*):
97
+ The tokenizer is a required input.
98
+ video_processor ([`Qwen3VLVideoProcessor`], *optional*):
99
+ The video processor is a required input.
100
+ chat_template (`str`, *optional*):
101
+ A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
102
+ """
103
+
104
+ attributes = ["image_processor", "tokenizer", "video_processor"]
105
+ image_processor_class = "AutoImageProcessor"
106
+ video_processor_class = "AutoVideoProcessor"
107
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
108
+
109
+ def __init__(self, image_processor=None, tokenizer=None, video_processor=None,
110
+ chat_template=None, **kwargs):
111
+ # Initialize base ProcessorMixin
112
+ super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
113
+
114
+ # Get image/video tokens from tokenizer
115
+ self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
116
+ self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
117
+ self.image_token_id = (
118
+ tokenizer.image_token_id
119
+ if getattr(tokenizer, "image_token_id", None)
120
+ else tokenizer.convert_tokens_to_ids(self.image_token)
121
+ )
122
+ self.video_token_id = (
123
+ tokenizer.video_token_id
124
+ if getattr(tokenizer, "video_token_id", None)
125
+ else tokenizer.convert_tokens_to_ids(self.video_token)
126
+ )
127
+
128
+ # Qwen3-VL vision tokens
129
+ self.vision_start_token = (
130
+ "<|vision_start|>" if not hasattr(tokenizer, "vision_start_token") else tokenizer.vision_start_token
131
+ )
132
+ self.vision_end_token = (
133
+ "<|vision_end|>" if not hasattr(tokenizer, "vision_end_token") else tokenizer.vision_end_token
134
+ )
135
+ self.vision_start_token_id = (
136
+ tokenizer.vision_start_token_id
137
+ if getattr(tokenizer, "vision_start_token_id", None)
138
+ else tokenizer.convert_tokens_to_ids(self.vision_start_token)
139
+ )
140
+ self.vision_end_token_id = (
141
+ tokenizer.vision_end_token_id
142
+ if getattr(tokenizer, "vision_end_token_id", None)
143
+ else tokenizer.convert_tokens_to_ids(self.vision_end_token)
144
+ )
145
+
146
+ prts_special_tokens = [
147
+ ACTION_START_TOKEN,
148
+ ACTION_PLACEHOLDER_TOKEN,
149
+ ACTION_END_TOKEN,
150
+ CRL_GOAL_REPR_TOKEN,
151
+ CRL_OBS_REPR_TOKEN,
152
+ ]
153
+ num_new_tokens = tokenizer.add_tokens(prts_special_tokens, special_tokens=True)
154
+ logger.info(f"Added {num_new_tokens} new special tokens to the tokenizer.")
155
+
156
+ self.action_token = getattr(tokenizer, "action_token", ACTION_PLACEHOLDER_TOKEN)
157
+ self.action_token_id = tokenizer.convert_tokens_to_ids(self.action_token)
158
+ token_dict = {
159
+ "action_start_token_id": ACTION_START_TOKEN,
160
+ "action_token_id": ACTION_PLACEHOLDER_TOKEN,
161
+ "vision_start_token_id": VISION_START_TOKEN,
162
+ "image_token_id": IMAGE_PLACEHOLDER_TOKEN,
163
+ "video_token_id": VIDEO_PLACEHOLDER_TOKEN,
164
+ "crl_goal_repr_token_id": CRL_GOAL_REPR_TOKEN,
165
+ "crl_obs_repr_token_id": CRL_OBS_REPR_TOKEN,
166
+ }
167
+ self.token_ids = {key: tokenizer.convert_tokens_to_ids(value) for key, value in token_dict.items()}
168
+
169
+ def __call__(
170
+ self,
171
+ images: Optional[ImageInput] = None,
172
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
173
+ videos: Optional[VideoInput] = None,
174
+ actions: Union[torch.Tensor] = None,
175
+ **kwargs: Unpack[Qwen3VLProcessorKwargs],
176
+ ) -> BatchFeature:
177
+ output_kwargs = self._merge_kwargs(
178
+ Qwen3VLProcessorKwargs,
179
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
180
+ **kwargs,
181
+ )
182
+
183
+ image_inputs = {}
184
+ if images is not None:
185
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
186
+ image_grid_thw = image_inputs["image_grid_thw"]
187
+ else:
188
+ image_grid_thw = None
189
+
190
+ videos_inputs = {}
191
+ if videos is not None:
192
+ videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
193
+ video_grid_thw = videos_inputs["video_grid_thw"]
194
+ if "return_metadata" not in kwargs:
195
+ video_metadata = videos_inputs.pop("video_metadata", None)
196
+ else:
197
+ video_metadata = videos_inputs.get("video_metadata", None)
198
+ else:
199
+ video_grid_thw = None
200
+ video_metadata = None
201
+
202
+ if not isinstance(text, list):
203
+ text = [text]
204
+
205
+ text = text.copy()
206
+
207
+ if image_grid_thw is not None:
208
+ merge_length = self.image_processor.merge_size**2
209
+ index = 0
210
+ for i in range(len(text)):
211
+ while self.image_token in text[i]:
212
+ num_image_tokens = image_grid_thw[index].prod() // merge_length
213
+ text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
214
+ index += 1
215
+ text[i] = text[i].replace("<|placeholder|>", self.image_token)
216
+
217
+ if video_grid_thw is not None:
218
+ merge_length = self.video_processor.merge_size**2
219
+ index = 0
220
+ for i in range(len(text)):
221
+ while self.video_token in text[i]:
222
+ if video_metadata is not None and index < len(video_metadata):
223
+ metadata = video_metadata[index]
224
+ if metadata.fps is None:
225
+ logger.warning_once(
226
+ "Qwen3VL requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. "
227
+ "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. "
228
+ "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
229
+ )
230
+ metadata.fps = 24 if metadata.fps is None else metadata.fps
231
+
232
+ curr_timestamp = self._calculate_timestamps(
233
+ metadata.frames_indices,
234
+ metadata.fps,
235
+ self.video_processor.merge_size,
236
+ )
237
+
238
+ video_placeholder = ""
239
+ frame_seqlen = video_grid_thw[index][1:].prod() // merge_length
240
+ for frame_idx in range(video_grid_thw[index][0]):
241
+ curr_time = curr_timestamp[frame_idx]
242
+ video_placeholder += f"<{curr_time:.1f} seconds>"
243
+ video_placeholder += (
244
+ self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token
245
+ )
246
+
247
+ if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text[i]:
248
+ text[i] = text[i].replace(
249
+ f"{self.vision_start_token}{self.video_token}{self.vision_end_token}",
250
+ video_placeholder,
251
+ 1,
252
+ )
253
+ else:
254
+ text[i] = text[i].replace(self.video_token, video_placeholder, 1)
255
+ else:
256
+ num_video_tokens = video_grid_thw[index].prod() // merge_length
257
+ text[i] = text[i].replace(self.video_token, "<|placeholder|>" * num_video_tokens, 1)
258
+
259
+ index += 1
260
+ text[i] = text[i].replace("<|placeholder|>", self.video_token)
261
+
262
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
263
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
264
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
265
+ self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
266
+
267
+ if return_mm_token_type_ids:
268
+ array_ids = np.array(text_inputs["input_ids"])
269
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
270
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
271
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
272
+
273
+ output_data = {**text_inputs, **image_inputs, **videos_inputs}
274
+ if actions is not None:
275
+ output_data["actions"] = actions
276
+
277
+ return BatchFeature(data=output_data, tensor_type=return_tensors)
278
+
279
+ def _calculate_timestamps(self, indices: Union[list[int], np.ndarray], video_fps: float, merge_size: int = 2):
280
+ if not isinstance(indices, list):
281
+ indices = indices.tolist()
282
+ if len(indices) % merge_size != 0:
283
+ indices.extend(indices[-1] for _ in range(merge_size - len(indices) % merge_size))
284
+ timestamps = [idx / video_fps for idx in indices]
285
+ timestamps = [
286
+ (timestamps[i] + timestamps[i + merge_size - 1]) / 2 for i in range(0, len(timestamps), merge_size)
287
+ ]
288
+ return timestamps
289
+
290
+ def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
291
+ vision_data = {}
292
+ if image_sizes is not None:
293
+ images_kwargs = Qwen3VLProcessorKwargs._defaults.get("images_kwargs", {})
294
+ images_kwargs.update(kwargs)
295
+ merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
296
+
297
+ num_image_patches = [
298
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
299
+ for image_size in image_sizes
300
+ ]
301
+ num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
302
+ vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
303
+
304
+ if video_sizes is not None:
305
+ videos_kwargs = Qwen3VLProcessorKwargs._defaults.get("videos_kwargs", {})
306
+ videos_kwargs.update(kwargs)
307
+ merge_size = videos_kwargs.get("merge_size", None) or self.video_processor.merge_size
308
+ num_video_patches = [
309
+ self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
310
+ for video_size in video_sizes
311
+ ]
312
+ num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
313
+ vision_data["num_video_tokens"] = num_video_tokens
314
+
315
+ return MultiModalData(**vision_data)
316
+
317
+ def set_action_tokenizer(self, action_tokenizer):
318
+ self.action_tokenizer = action_tokenizer
319
+
320
+ prts_fast_action_tokens = [f"<|action_token_{i}|>" for i in range(action_tokenizer.vocab_size)]
321
+ num_new_tokens = self.tokenizer.add_tokens(prts_fast_action_tokens, special_tokens=True)
322
+ logger.info(f"Added {num_new_tokens} FAST action tokens to the tokenizer.")
323
+
324
+ self.action_token_start_index = self.tokenizer.convert_tokens_to_ids("<|action_token_0|>")
325
+ self.action_vocab_size = action_tokenizer.vocab_size
326
+
327
+ token_ids = self.tokenizer.convert_tokens_to_ids(prts_fast_action_tokens)
328
+ self.action_mapper = {k: v for k, v in zip(prts_fast_action_tokens, token_ids, strict=True)}
329
+
330
+ def preprocess_action(self, actions, **kwargs):
331
+ raise NotImplementedError
332
+
333
+ def post_process_image_text_to_text(
334
+ self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
335
+ ):
336
+ return self.tokenizer.batch_decode(
337
+ generated_outputs,
338
+ skip_special_tokens=skip_special_tokens,
339
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
340
+ **kwargs,
341
+ )
342
+
343
+ @property
344
+ def model_input_names(self):
345
+ tokenizer_input_names = self.tokenizer.model_input_names
346
+ image_processor_input_names = self.image_processor.model_input_names
347
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
348
+
349
+
350
+ PRTS_Qwen3VLProcessor.register_for_auto_class()
351
+
352
+ __all__ = ["PRTS_Qwen3VLProcessor"]
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
statistics.json ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "state_mode": "QUANTILE",
3
+ "features": {
4
+ "observation.state": {
5
+ "dtype": "float32",
6
+ "shape": [
7
+ 8
8
+ ],
9
+ "names": {
10
+ "motors": [
11
+ "x",
12
+ "y",
13
+ "z",
14
+ "roll",
15
+ "pitch",
16
+ "yaw",
17
+ "pad",
18
+ "gripper"
19
+ ]
20
+ }
21
+ },
22
+ "action": {
23
+ "dtype": "float32",
24
+ "shape": [
25
+ 7
26
+ ],
27
+ "names": {
28
+ "motors": [
29
+ "x",
30
+ "y",
31
+ "z",
32
+ "roll",
33
+ "pitch",
34
+ "yaw",
35
+ "gripper"
36
+ ]
37
+ }
38
+ }
39
+ },
40
+ "stats": {
41
+ "observation.state": {
42
+ "min": [
43
+ -0.4828203022480011,
44
+ -0.3255046010017395,
45
+ 0.008128180168569088,
46
+ 0.35277295112609863,
47
+ -3.641430377960205,
48
+ -1.842738389968872,
49
+ -0.0013586411951109767,
50
+ -0.042040832340717316
51
+ ],
52
+ "max": [
53
+ 0.21031762659549713,
54
+ 0.39128610491752625,
55
+ 1.3660105466842651,
56
+ 3.6714255809783936,
57
+ 3.560650587081909,
58
+ 1.386339545249939,
59
+ 0.04233968257904053,
60
+ 0.0013633022317662835
61
+ ],
62
+ "mean": [
63
+ -0.046518828719854355,
64
+ 0.034408919513225555,
65
+ 0.7645694613456726,
66
+ 2.9716713428497314,
67
+ -0.2204727977514267,
68
+ -0.12557993829250336,
69
+ 0.026915358379483223,
70
+ -0.027191326022148132
71
+ ],
72
+ "std": [
73
+ 0.10494082421064377,
74
+ 0.1517619788646698,
75
+ 0.3785194456577301,
76
+ 0.3442671298980713,
77
+ 0.9068173766136169,
78
+ 0.32538288831710815,
79
+ 0.014175750315189362,
80
+ 0.014058776199817657
81
+ ],
82
+ "count": [
83
+ 273465
84
+ ],
85
+ "q01": [
86
+ -0.3991248905658722,
87
+ -0.2688351273536682,
88
+ 0.03826696425676346,
89
+ 1.508958101272583,
90
+ -2.7197911739349365,
91
+ -1.0805085897445679,
92
+ 0.0017423711251467466,
93
+ -0.04002561420202255
94
+ ],
95
+ "q99": [
96
+ 0.13556525111198425,
97
+ 0.33566486835479736,
98
+ 1.2706660032272339,
99
+ 3.277346134185791,
100
+ 2.406111240386963,
101
+ 0.5977716445922852,
102
+ 0.04031316190958023,
103
+ -0.00177810771856457
104
+ ]
105
+ },
106
+ "action": {
107
+ "min": [
108
+ -0.9375,
109
+ -0.9375,
110
+ -0.9375,
111
+ -0.2582142949104309,
112
+ -0.375,
113
+ -0.3675000071525574,
114
+ -1.0
115
+ ],
116
+ "max": [
117
+ 0.9375,
118
+ 0.9375,
119
+ 0.9375,
120
+ 0.3557142913341522,
121
+ 0.375,
122
+ 0.375,
123
+ 1.0
124
+ ],
125
+ "mean": [
126
+ 0.06278152763843536,
127
+ 0.08684158325195312,
128
+ -0.0903734639286995,
129
+ 0.0005407554563134909,
130
+ 0.005643464159220457,
131
+ -0.005229106638580561,
132
+ -0.0496407225728035
133
+ ],
134
+ "std": [
135
+ 0.33551836013793945,
136
+ 0.37847793102264404,
137
+ 0.4446770250797272,
138
+ 0.03924214467406273,
139
+ 0.06341660022735596,
140
+ 0.07792268693447113,
141
+ 1.000144362449646
142
+ ],
143
+ "count": [
144
+ 273465
145
+ ],
146
+ "q01": [
147
+ -0.7044642567634583,
148
+ -0.8008928298950195,
149
+ -0.9375,
150
+ -0.11464285850524902,
151
+ -0.1639285683631897,
152
+ -0.2239285707473755,
153
+ -1.0
154
+ ],
155
+ "q99": [
156
+ 0.9375,
157
+ 0.8678571581840515,
158
+ 0.9375,
159
+ 0.13178572058677673,
160
+ 0.19285714626312256,
161
+ 0.335357129573822,
162
+ 1.0
163
+ ]
164
+ }
165
+ },
166
+ "datasets": {
167
+ "libero_4_suites": {
168
+ "features": {
169
+ "observation.state": {
170
+ "dtype": "float32",
171
+ "shape": [
172
+ 8
173
+ ],
174
+ "names": {
175
+ "motors": [
176
+ "x",
177
+ "y",
178
+ "z",
179
+ "roll",
180
+ "pitch",
181
+ "yaw",
182
+ "pad",
183
+ "gripper"
184
+ ]
185
+ }
186
+ },
187
+ "action": {
188
+ "dtype": "float32",
189
+ "shape": [
190
+ 7
191
+ ],
192
+ "names": {
193
+ "motors": [
194
+ "x",
195
+ "y",
196
+ "z",
197
+ "roll",
198
+ "pitch",
199
+ "yaw",
200
+ "gripper"
201
+ ]
202
+ }
203
+ }
204
+ },
205
+ "stats": {
206
+ "observation.state": {
207
+ "min": [
208
+ -0.4828203022480011,
209
+ -0.3255046010017395,
210
+ 0.008128180168569088,
211
+ 0.35277295112609863,
212
+ -3.641430377960205,
213
+ -1.842738389968872,
214
+ -0.0013586411951109767,
215
+ -0.042040832340717316
216
+ ],
217
+ "max": [
218
+ 0.21031762659549713,
219
+ 0.39128610491752625,
220
+ 1.3660105466842651,
221
+ 3.6714255809783936,
222
+ 3.560650587081909,
223
+ 1.386339545249939,
224
+ 0.04233968257904053,
225
+ 0.0013633022317662835
226
+ ],
227
+ "mean": [
228
+ -0.046518828719854355,
229
+ 0.034408919513225555,
230
+ 0.7645694613456726,
231
+ 2.9716713428497314,
232
+ -0.2204727977514267,
233
+ -0.12557993829250336,
234
+ 0.026915358379483223,
235
+ -0.027191326022148132
236
+ ],
237
+ "std": [
238
+ 0.10494082421064377,
239
+ 0.1517619788646698,
240
+ 0.3785194456577301,
241
+ 0.3442671298980713,
242
+ 0.9068173766136169,
243
+ 0.32538288831710815,
244
+ 0.014175750315189362,
245
+ 0.014058776199817657
246
+ ],
247
+ "count": [
248
+ 273465
249
+ ],
250
+ "q01": [
251
+ -0.3991248905658722,
252
+ -0.2688351273536682,
253
+ 0.03826696425676346,
254
+ 1.508958101272583,
255
+ -2.7197911739349365,
256
+ -1.0805085897445679,
257
+ 0.0017423711251467466,
258
+ -0.04002561420202255
259
+ ],
260
+ "q99": [
261
+ 0.13556525111198425,
262
+ 0.33566486835479736,
263
+ 1.2706660032272339,
264
+ 3.277346134185791,
265
+ 2.406111240386963,
266
+ 0.5977716445922852,
267
+ 0.04031316190958023,
268
+ -0.00177810771856457
269
+ ]
270
+ },
271
+ "action": {
272
+ "min": [
273
+ -0.9375,
274
+ -0.9375,
275
+ -0.9375,
276
+ -0.2582142949104309,
277
+ -0.375,
278
+ -0.3675000071525574,
279
+ -1.0
280
+ ],
281
+ "max": [
282
+ 0.9375,
283
+ 0.9375,
284
+ 0.9375,
285
+ 0.3557142913341522,
286
+ 0.375,
287
+ 0.375,
288
+ 1.0
289
+ ],
290
+ "mean": [
291
+ 0.06278152763843536,
292
+ 0.08684158325195312,
293
+ -0.0903734639286995,
294
+ 0.0005407554563134909,
295
+ 0.005643464159220457,
296
+ -0.005229106638580561,
297
+ -0.0496407225728035
298
+ ],
299
+ "std": [
300
+ 0.33551836013793945,
301
+ 0.37847793102264404,
302
+ 0.4446770250797272,
303
+ 0.03924214467406273,
304
+ 0.06341660022735596,
305
+ 0.07792268693447113,
306
+ 1.000144362449646
307
+ ],
308
+ "count": [
309
+ 273465
310
+ ],
311
+ "q01": [
312
+ -0.7044642567634583,
313
+ -0.8008928298950195,
314
+ -0.9375,
315
+ -0.11464285850524902,
316
+ -0.1639285683631897,
317
+ -0.2239285707473755,
318
+ -1.0
319
+ ],
320
+ "q99": [
321
+ 0.9375,
322
+ 0.8678571581840515,
323
+ 0.9375,
324
+ 0.13178572058677673,
325
+ 0.19285714626312256,
326
+ 0.335357129573822,
327
+ 1.0
328
+ ]
329
+ }
330
+ }
331
+ }
332
+ }
333
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5482df2482307db564c0595428d3dfdad4bf5dbd9d3d5156052ca12f93b7d3ed
3
+ size 11828002
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb6e322ff4c32859f3014cdd5e49182ec932f2b10cc2e365df3439522af926a7
3
+ size 10129
video_preprocessor_config.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": null,
3
+ "data_format": "channels_first",
4
+ "default_to_square": true,
5
+ "device": null,
6
+ "do_center_crop": null,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "do_sample_frames": true,
12
+ "fps": 2.0,
13
+ "image_mean": [
14
+ 0.5,
15
+ 0.5,
16
+ 0.5
17
+ ],
18
+ "image_std": [
19
+ 0.5,
20
+ 0.5,
21
+ 0.5
22
+ ],
23
+ "input_data_format": null,
24
+ "max_frames": 8,
25
+ "merge_size": 2,
26
+ "min_frames": 4,
27
+ "num_frames": null,
28
+ "pad_size": null,
29
+ "patch_size": 16,
30
+ "processor_class": "PRTS_Qwen3VLProcessor",
31
+ "resample": 3,
32
+ "rescale_factor": 0.00392156862745098,
33
+ "return_metadata": false,
34
+ "size": {
35
+ "longest_edge": 147456,
36
+ "shortest_edge": 65536
37
+ },
38
+ "temporal_patch_size": 2,
39
+ "video_metadata": null,
40
+ "video_processor_type": "Qwen3VLVideoProcessor"
41
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff