Robotics
Transformers
Safetensors
English
rio2
feature-extraction
Mixture of Experts
diffusion-jepa
custom_code
hoguai commited on
Commit
9a5262f
·
verified ·
1 Parent(s): 1ffd957

Upload 7 files

Browse files
config.json ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_model_id": "allenai/MolmoAct2-SO100_101",
3
+ "norm_tag": "so100_so101_molmoact2",
4
+ "state_dim": 6,
5
+ "action_dim": 6,
6
+ "action_horizon": 30,
7
+ "state_history_len": 8,
8
+ "action_history_len": 8,
9
+ "s2_token_count": 16,
10
+ "s2_input_width": 2560,
11
+ "s2_width": 1024,
12
+ "s1_width": 384,
13
+ "s1_layers": 6,
14
+ "s1_heads": 8,
15
+ "s1_dropout": 0.05,
16
+ "flow_inference_steps": 4,
17
+ "temporal_ensemble_enabled": true,
18
+ "temporal_ensemble_max_chunks": 4,
19
+ "temporal_ensemble_decay": 0.15,
20
+ "task_memory_enabled": true,
21
+ "task_memory_slots": 8,
22
+ "task_memory_ema": 0.97,
23
+ "task_memory_alpha": 0.25,
24
+ "task_memory_max_norm": 10.0,
25
+ "s1_policy_mode": "jepa_diffusion",
26
+ "enable_jepa_diffusion": true,
27
+ "diffusion_inference_steps": 1,
28
+ "diffusion_loss_weight": 1.0,
29
+ "consistency_loss_weight": 0.5,
30
+ "flow_loss_weight": 0.1,
31
+ "jepa_loss_weight": 0.1,
32
+ "jepa_action_prior_weight": 0.05,
33
+ "jepa_hidden_dim": 256,
34
+ "jepa_latent_dim": 256,
35
+ "jepa_ema_decay": 0.996,
36
+ "jepa_action_prior_alpha": 0.25,
37
+ "jepa_condition_alpha": 1.0,
38
+ "s1_sampling_noise_scale": 1.0,
39
+ "enable_s1_moe": true,
40
+ "s1_moe_num_experts": 10,
41
+ "s1_moe_top_k": 1,
42
+ "s1_moe_expert_hidden_dim": 177000,
43
+ "s1_moe_residual_scale": 0.1,
44
+ "dtype": "bfloat16",
45
+ "s2_refresh_hz": 8.0,
46
+ "max_s2_cache_age_s": 0.2,
47
+ "action_clip": 1.0,
48
+ "smooth_loss_weight": 0.02,
49
+ "action_l1_weight": 0.0,
50
+ "residual_alpha": 1.0,
51
+ "model_type": "rio2",
52
+ "architectures": [
53
+ "Rio2Model"
54
+ ],
55
+ "weight_format": "safetensors",
56
+ "weight_file": "model.safetensors",
57
+ "runtime_type": "two_rate_weight_preserved",
58
+ "auto_map": {
59
+ "AutoConfig": "configuration_rio2.Rio2Config",
60
+ "AutoModel": "modeling_rio2.Rio2Model",
61
+ "AutoProcessor": "processing_rio2.Rio2Processor"
62
+ }
63
+ }
configuration_rio2.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Inc. team and the Rio2 contributors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ """RIO-2 configuration."""
5
+
6
+ from transformers.configuration_utils import PreTrainedConfig
7
+ from transformers.utils import logging
8
+
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+
13
+ class Rio2Config(PreTrainedConfig):
14
+ r"""
15
+ Configuration class for [`Rio2Model`].
16
+
17
+ RIO-2 is a SO101 robotics policy with real-time
18
+
19
+ - S2: low-frequency semantic/context refresh
20
+ - S1: high-frequency action generation, preferably using the larger module
21
+
22
+ """
23
+
24
+ model_type = "rio2"
25
+ attribute_map = {
26
+ "hidden_size": "s1_width",
27
+ "num_attention_heads": "s1_heads",
28
+ "num_hidden_layers": "s1_layers",
29
+ }
30
+
31
+ def __init__(
32
+ self,
33
+ base_model_id="allenai/MolmoAct2-SO100_101",
34
+ norm_tag="so100_so101_molmoact2",
35
+ rio2_variant="weight_preserved",
36
+ runtime_mode="two_rate_weight_preserved",
37
+ state_dim=6,
38
+ action_dim=6,
39
+ action_horizon=30,
40
+ state_history_len=8,
41
+ action_history_len=8,
42
+ # Compact token fallback path. These remain for tests and for cases
43
+ # where the original action expert cannot be called directly.
44
+ s2_token_count=16,
45
+ s2_input_width=4096,
46
+ s2_width=1024,
47
+ s1_width=384,
48
+ s1_layers=6,
49
+ s1_heads=8,
50
+ s1_dropout=0.05,
51
+ flow_inference_steps=4,
52
+ temporal_ensemble_enabled=True,
53
+ temporal_ensemble_max_chunks=4,
54
+ temporal_ensemble_decay=0.15,
55
+ task_memory_enabled=True,
56
+ task_memory_slots=8,
57
+ task_memory_ema=0.97,
58
+ task_memory_alpha=0.25,
59
+ task_memory_max_norm=10.0,
60
+ # Weight-preserved MolmoAct2 path.
61
+ use_original_s2=True,
62
+ use_original_s1=True,
63
+ prefer_split_action_expert=True,
64
+ fallback_to_predict_action=True,
65
+ action_mode="continuous",
66
+ molmoact_num_steps=10,
67
+ s2_refresh_hz=8.0,
68
+ max_s2_cache_age_s=0.20,
69
+ action_clip=1.0,
70
+ # JEPA-style S1. This keeps the original/online S1 policy weights as
71
+ # the action generator and adds a small latent world-model side head.
72
+ # The target action encoder is updated by EMA and is used only for the
73
+ # self-supervised JEPA loss.
74
+ s1_architecture="jepa_diffusion",
75
+ enable_jepa_s1=False,
76
+ jepa_hidden_dim=256,
77
+ jepa_latent_dim=256,
78
+ jepa_layers=2,
79
+ jepa_heads=4,
80
+ jepa_loss_weight=0.10,
81
+ jepa_ema_decay=0.996,
82
+ use_jepa_action_residual=False,
83
+ jepa_action_alpha=0.0,
84
+ s1_policy_mode="jepa_diffusion",
85
+ enable_jepa_diffusion=True,
86
+ diffusion_inference_steps=1,
87
+ diffusion_loss_weight=1.0,
88
+ consistency_loss_weight=0.50,
89
+ flow_loss_weight=0.10,
90
+ jepa_action_prior_weight=0.05,
91
+ jepa_action_prior_alpha=0.25,
92
+ jepa_condition_alpha=1.0,
93
+ s1_sampling_noise_scale=1.0,
94
+ enable_s1_moe=False,
95
+ s1_moe_num_experts=10,
96
+ s1_moe_top_k=1,
97
+ s1_moe_expert_hidden_dim=105472,
98
+ s1_moe_residual_scale=0.10,
99
+ # Tiny tuning knobs.
100
+ train_adapters_only=True,
101
+ enable_residual_adapter=True,
102
+ residual_alpha=0.0,
103
+ residual_trainable=True,
104
+ enable_s1_lora=False,
105
+ enable_s2_lora=False,
106
+ lora_r=8,
107
+ lora_alpha=16,
108
+ # Training losses for fallback/adapter path.
109
+ smooth_loss_weight=0.02,
110
+ action_l1_weight=0.0,
111
+ torch_dtype="bfloat16",
112
+ load_base_on_init=False,
113
+ trust_remote_code=True,
114
+ **kwargs,
115
+ ):
116
+ self.base_model_id = base_model_id
117
+ self.norm_tag = norm_tag
118
+ self.rio2_variant = rio2_variant
119
+ self.runtime_mode = runtime_mode
120
+
121
+ self.state_dim = state_dim
122
+ self.action_dim = action_dim
123
+ self.action_horizon = action_horizon
124
+ self.state_history_len = state_history_len
125
+ self.action_history_len = action_history_len
126
+
127
+ self.s2_token_count = s2_token_count
128
+ self.s2_input_width = s2_input_width
129
+ self.s2_width = s2_width
130
+ self.s1_width = s1_width
131
+ self.s1_layers = s1_layers
132
+ self.s1_heads = s1_heads
133
+ self.s1_dropout = s1_dropout
134
+ self.flow_inference_steps = flow_inference_steps
135
+ self.temporal_ensemble_enabled = temporal_ensemble_enabled
136
+ self.temporal_ensemble_max_chunks = temporal_ensemble_max_chunks
137
+ self.temporal_ensemble_decay = temporal_ensemble_decay
138
+ self.task_memory_enabled = task_memory_enabled
139
+ self.task_memory_slots = task_memory_slots
140
+ self.task_memory_ema = task_memory_ema
141
+ self.task_memory_alpha = task_memory_alpha
142
+ self.task_memory_max_norm = task_memory_max_norm
143
+
144
+ self.use_original_s2 = use_original_s2
145
+ self.use_original_s1 = use_original_s1
146
+ self.prefer_split_action_expert = prefer_split_action_expert
147
+ self.fallback_to_predict_action = fallback_to_predict_action
148
+ self.action_mode = action_mode
149
+ self.molmoact_num_steps = molmoact_num_steps
150
+
151
+ self.s2_refresh_hz = s2_refresh_hz
152
+ self.max_s2_cache_age_s = max_s2_cache_age_s
153
+ self.action_clip = action_clip
154
+
155
+ self.s1_architecture = s1_architecture
156
+ self.enable_jepa_s1 = enable_jepa_s1
157
+ self.jepa_hidden_dim = jepa_hidden_dim
158
+ self.jepa_latent_dim = jepa_latent_dim
159
+ self.jepa_layers = jepa_layers
160
+ self.jepa_heads = jepa_heads
161
+ self.jepa_loss_weight = jepa_loss_weight
162
+ self.jepa_ema_decay = jepa_ema_decay
163
+ self.use_jepa_action_residual = use_jepa_action_residual
164
+ self.jepa_action_alpha = jepa_action_alpha
165
+ self.s1_policy_mode = s1_policy_mode
166
+ self.enable_jepa_diffusion = enable_jepa_diffusion
167
+ self.diffusion_inference_steps = diffusion_inference_steps
168
+ self.diffusion_loss_weight = diffusion_loss_weight
169
+ self.consistency_loss_weight = consistency_loss_weight
170
+ self.flow_loss_weight = flow_loss_weight
171
+ self.jepa_action_prior_weight = jepa_action_prior_weight
172
+ self.jepa_action_prior_alpha = jepa_action_prior_alpha
173
+ self.jepa_condition_alpha = jepa_condition_alpha
174
+ self.s1_sampling_noise_scale = s1_sampling_noise_scale
175
+ self.enable_s1_moe = enable_s1_moe
176
+ self.s1_moe_num_experts = s1_moe_num_experts
177
+ self.s1_moe_top_k = s1_moe_top_k
178
+ self.s1_moe_expert_hidden_dim = s1_moe_expert_hidden_dim
179
+ self.s1_moe_residual_scale = s1_moe_residual_scale
180
+
181
+ self.train_adapters_only = train_adapters_only
182
+ self.enable_residual_adapter = enable_residual_adapter
183
+ self.residual_alpha = residual_alpha
184
+ self.residual_trainable = residual_trainable
185
+ self.enable_s1_lora = enable_s1_lora
186
+ self.enable_s2_lora = enable_s2_lora
187
+ self.lora_r = lora_r
188
+ self.lora_alpha = lora_alpha
189
+
190
+ self.smooth_loss_weight = smooth_loss_weight
191
+ self.action_l1_weight = action_l1_weight
192
+
193
+ self.torch_dtype = torch_dtype
194
+ self.load_base_on_init = load_base_on_init
195
+ self.trust_remote_code = trust_remote_code
196
+
197
+ super().__init__(**kwargs)
198
+
199
+
200
+ __all__ = ["Rio2Config"]
modeling_rio2.py ADDED
@@ -0,0 +1,1364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Inc. team and the Rio2 contributors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ """PyTorch RIO-2 model.
5
+
6
+ Runtime modes:
7
+ - `refresh_s2(images, instruction)`: low-frequency context refresh.
8
+ - `act_fast(state, ...)`: high-frequency action generation.
9
+ - `forward(..., s2_tokens=...)`: cached-token fallback used for tests and
10
+ for adapter-only training when MolmoAct2 internals are unavailable.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import copy
16
+ import inspect
17
+ import math
18
+ import time
19
+ from collections.abc import Iterable
20
+ from dataclasses import dataclass
21
+ from typing import Any
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ from PIL import Image
28
+
29
+ from transformers.modeling_outputs import ModelOutput
30
+ from transformers.modeling_utils import PreTrainedModel
31
+ from transformers.utils import logging
32
+ from .configuration_rio2 import Rio2Config
33
+
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+ ImageLike = Image.Image | np.ndarray | torch.Tensor
38
+
39
+
40
+ @dataclass
41
+ class Rio2Output(ModelOutput):
42
+ """Output type for RIO-2."""
43
+
44
+ loss: torch.FloatTensor | None = None
45
+ actions: torch.FloatTensor | None = None
46
+ s2_tokens: torch.FloatTensor | None = None
47
+ loss_flow_mse: torch.FloatTensor | None = None
48
+ loss_flow_l1: torch.FloatTensor | None = None
49
+ loss_diffusion: torch.FloatTensor | None = None
50
+ loss_consistency: torch.FloatTensor | None = None
51
+ loss_smooth: torch.FloatTensor | None = None
52
+ loss_jepa: torch.FloatTensor | None = None
53
+ loss_jepa_prior: torch.FloatTensor | None = None
54
+ pred_action_latent: torch.FloatTensor | None = None
55
+ target_action_latent: torch.FloatTensor | None = None
56
+ runtime_path: str | None = None
57
+
58
+
59
+ def _torch_dtype_from_string(dtype_name: str) -> torch.dtype:
60
+ table = {
61
+ "float32": torch.float32,
62
+ "fp32": torch.float32,
63
+ "float16": torch.float16,
64
+ "fp16": torch.float16,
65
+ "bfloat16": torch.bfloat16,
66
+ "bf16": torch.bfloat16,
67
+ }
68
+ return table.get(str(dtype_name).lower(), torch.bfloat16)
69
+
70
+
71
+ def _to_pil_list(images: ImageLike | list[ImageLike] | tuple[ImageLike, ...]) -> list[Image.Image]:
72
+ if isinstance(images, (list, tuple)):
73
+ return [_to_pil_list(x)[0] for x in images]
74
+ if isinstance(images, Image.Image):
75
+ return [images.convert("RGB")]
76
+ if isinstance(images, np.ndarray):
77
+ arr = images
78
+ if arr.ndim == 4:
79
+ return [_to_pil_list(a)[0] for a in arr]
80
+ if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
81
+ arr = np.transpose(arr, (1, 2, 0))
82
+ if arr.ndim == 2:
83
+ arr = np.repeat(arr[..., None], 3, axis=-1)
84
+ if arr.ndim == 3 and arr.shape[-1] == 1:
85
+ arr = np.repeat(arr, 3, axis=-1)
86
+ if arr.ndim == 3 and arr.shape[-1] == 4:
87
+ arr = arr[..., :3]
88
+ if arr.dtype != np.uint8:
89
+ arr = np.clip(arr, 0, 1) if arr.max() <= 1.5 else np.clip(arr, 0, 255)
90
+ arr = (arr * 255).astype(np.uint8) if arr.max() <= 1.5 else arr.astype(np.uint8)
91
+ return [Image.fromarray(arr).convert("RGB")]
92
+ if torch.is_tensor(images):
93
+ x = images.detach().cpu()
94
+ if x.ndim == 4:
95
+ return [_to_pil_list(xx)[0] for xx in x]
96
+ if x.ndim == 3 and x.shape[0] in (1, 3):
97
+ x = x.permute(1, 2, 0)
98
+ arr = x.numpy()
99
+ if arr.dtype != np.uint8:
100
+ arr = np.clip(arr, 0, 1) if arr.max() <= 1.5 else np.clip(arr, 0, 255)
101
+ arr = (arr * 255).astype(np.uint8) if arr.max() <= 1.5 else arr.astype(np.uint8)
102
+ return [Image.fromarray(arr).convert("RGB")]
103
+ raise TypeError(f"Unsupported image type: {type(images)}")
104
+
105
+
106
+ def _move_to_device(batch: Any, device: torch.device, dtype: torch.dtype | None = None) -> Any:
107
+ if torch.is_tensor(batch):
108
+ if batch.is_floating_point() and dtype is not None:
109
+ return batch.to(device=device, dtype=dtype)
110
+ return batch.to(device=device)
111
+ if isinstance(batch, dict):
112
+ return {k: _move_to_device(v, device, dtype) for k, v in batch.items()}
113
+ if isinstance(batch, (list, tuple)):
114
+ return type(batch)(_move_to_device(v, device, dtype) for v in batch)
115
+ return batch
116
+
117
+
118
+ def _first_existing_attr(obj: Any, names: Iterable[str]) -> Any | None:
119
+ for name in names:
120
+ cur = obj
121
+ ok = True
122
+ for part in name.split("."):
123
+ if hasattr(cur, part):
124
+ cur = getattr(cur, part)
125
+ else:
126
+ ok = False
127
+ break
128
+ if ok:
129
+ return cur
130
+ return None
131
+
132
+
133
+ def _safe_signature_accepts(fn: Any, name: str) -> bool:
134
+ try:
135
+ sig = inspect.signature(fn)
136
+ except (TypeError, ValueError):
137
+ return True
138
+ if name in sig.parameters:
139
+ return True
140
+ return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())
141
+
142
+
143
+ class Rio2RMSNorm(nn.Module):
144
+ def __init__(self, dim: int, eps: float = 1e-6):
145
+ super().__init__()
146
+ self.weight = nn.Parameter(torch.ones(dim))
147
+ self.eps = eps
148
+
149
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
150
+ input_dtype = hidden_states.dtype
151
+ hidden_states = hidden_states.float()
152
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
153
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
154
+ return (self.weight * hidden_states).to(input_dtype)
155
+
156
+
157
+ class Rio2SinusoidalTimeEmbedding(nn.Module):
158
+ def __init__(self, dim: int):
159
+ super().__init__()
160
+ self.dim = dim
161
+ self.mlp = nn.Sequential(nn.Linear(dim, dim * 2), nn.SiLU(), nn.Linear(dim * 2, dim))
162
+
163
+ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
164
+ if timesteps.ndim == 0:
165
+ timesteps = timesteps[None]
166
+ half_dim = self.dim // 2
167
+ freqs = torch.exp(
168
+ torch.arange(half_dim, device=timesteps.device, dtype=torch.float32)
169
+ * -(math.log(10000.0) / max(half_dim - 1, 1))
170
+ )
171
+ args = timesteps.float()[:, None] * freqs[None]
172
+ emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
173
+ if emb.shape[-1] < self.dim:
174
+ emb = F.pad(emb, (0, self.dim - emb.shape[-1]))
175
+ return self.mlp(emb.to(dtype=self.mlp[0].weight.dtype))
176
+
177
+
178
+ class Rio2S1MoEResidualExpert(nn.Module):
179
+ def __init__(self, width: int, hidden_dim: int, flat_action_dim: int):
180
+ super().__init__()
181
+ self.net = nn.Sequential(
182
+ nn.Linear(width, hidden_dim),
183
+ nn.SiLU(),
184
+ nn.Linear(hidden_dim, flat_action_dim),
185
+ )
186
+ nn.init.zeros_(self.net[-1].weight)
187
+ nn.init.zeros_(self.net[-1].bias)
188
+
189
+ def forward(self, context: torch.Tensor) -> torch.Tensor:
190
+ return self.net(context)
191
+
192
+
193
+ class Rio2S1MoEResidualBank(nn.Module):
194
+ def __init__(self, config: Rio2Config, width: int):
195
+ super().__init__()
196
+ self.config = config
197
+ self.flat_action_dim = int(config.action_horizon * config.action_dim)
198
+ self.num_experts = int(config.s1_moe_num_experts)
199
+ self.top_k = max(1, min(int(config.s1_moe_top_k), self.num_experts))
200
+ hidden_dim = int(config.s1_moe_expert_hidden_dim)
201
+ self.router = nn.Linear(width, self.num_experts)
202
+ self.experts = nn.ModuleList(
203
+ Rio2S1MoEResidualExpert(width, hidden_dim, self.flat_action_dim)
204
+ for _ in range(self.num_experts)
205
+ )
206
+
207
+ def forward(self, context: torch.Tensor) -> torch.Tensor:
208
+ logits = self.router(context)
209
+ weights, indices = torch.topk(logits, k=self.top_k, dim=-1)
210
+ weights = torch.softmax(weights, dim=-1).to(dtype=context.dtype)
211
+ out = context.new_zeros(context.shape[0], self.flat_action_dim)
212
+ for slot in range(self.top_k):
213
+ slot_indices = indices[:, slot]
214
+ slot_weights = weights[:, slot]
215
+ for expert_id, expert in enumerate(self.experts):
216
+ mask = slot_indices == expert_id
217
+ if not bool(mask.any()):
218
+ continue
219
+ out[mask] = out[mask] + slot_weights[mask, None] * expert(context[mask])
220
+ return out.view(context.shape[0], self.config.action_horizon, self.config.action_dim)
221
+
222
+
223
+ class Rio2S2ContextCompressor(nn.Module):
224
+ """Fallback compressor for cached-token training.
225
+
226
+ Weight-preserved inference prefers the original MolmoAct2 S2/S1 bridge.
227
+ This compressor remains useful for small adapter training, tests, and for
228
+ base versions whose action expert cannot be split cleanly.
229
+ """
230
+
231
+ def __init__(self, config: Rio2Config):
232
+ super().__init__()
233
+ self.config = config
234
+ self.in_proj = nn.Linear(config.s2_input_width, config.s2_width)
235
+ self.query = nn.Parameter(torch.randn(config.s2_token_count, config.s2_width) / math.sqrt(config.s2_width))
236
+ layer = nn.TransformerEncoderLayer(
237
+ d_model=config.s2_width,
238
+ nhead=max(1, min(8, config.s2_width // 64)),
239
+ dim_feedforward=config.s2_width * 4,
240
+ dropout=0.0,
241
+ batch_first=True,
242
+ norm_first=True,
243
+ activation="gelu",
244
+ )
245
+ self.refiner = nn.TransformerEncoder(layer, num_layers=2)
246
+ self.norm = Rio2RMSNorm(config.s2_width)
247
+
248
+ def forward(self, context: torch.Tensor) -> torch.Tensor:
249
+ if context.ndim == 2:
250
+ context = context.unsqueeze(0)
251
+ if context.shape[-1] != self.config.s2_input_width:
252
+ raise ValueError(
253
+ f"S2 context width mismatch: got {context.shape[-1]}, expected {self.config.s2_input_width}."
254
+ )
255
+ hidden_states = self.in_proj(context)
256
+ query = self.query.unsqueeze(0).expand(hidden_states.shape[0], -1, -1)
257
+ scores = (query @ hidden_states.transpose(-1, -2)) / math.sqrt(hidden_states.shape[-1])
258
+ attn = torch.softmax(scores, dim=-1)
259
+ tokens = attn @ hidden_states
260
+ tokens = self.refiner(tokens)
261
+ return self.norm(tokens)
262
+
263
+
264
+ class Rio2MolmoAct2Core(nn.Module):
265
+ """Weight-preserved wrapper around `allenai/MolmoAct2-SO100_101`.
266
+
267
+ The original MolmoAct2 object is loaded once and kept as the source of truth
268
+ for both S2 and S1. `refresh_s2()` extracts cache/context when possible;
269
+ `act_original()` first tries a split action-expert call and falls back to
270
+ `base.predict_action()` for exact original behavior.
271
+ """
272
+
273
+ VLM_CANDIDATES = (
274
+ "vlm",
275
+ "language_model",
276
+ "molmo",
277
+ "backbone",
278
+ "model",
279
+ "text_model",
280
+ )
281
+ ACTION_CANDIDATES = (
282
+ "action_expert",
283
+ "flow_head",
284
+ "action_head",
285
+ "continuous_action_expert",
286
+ "flow_matching_head",
287
+ "policy_head",
288
+ "robot_action_head",
289
+ )
290
+
291
+ def __init__(self, config: Rio2Config):
292
+ super().__init__()
293
+ self.config = config
294
+ self.base = None
295
+ self.processor = None
296
+ self.s2_module = None
297
+ self.s1_module = None
298
+ self.compressor = Rio2S2ContextCompressor(config)
299
+ self.last_pil_images: list[Image.Image] | None = None
300
+ self.last_instruction: str | None = None
301
+ self.last_base_outputs: Any | None = None
302
+ self.last_s2_cache: Any | None = None
303
+ self.last_compact_tokens: torch.Tensor | None = None
304
+ self.last_refresh_time: float = 0.0
305
+ self.last_runtime_path: str = "uninitialized"
306
+
307
+ @property
308
+ def base_device(self) -> torch.device:
309
+ if self.base is None:
310
+ return next(self.compressor.parameters()).device
311
+ return next(self.base.parameters()).device
312
+
313
+ def load_base(self, device: str | torch.device | None = None, device_map: str | None = None):
314
+ from transformers import AutoModelForImageTextToText, AutoProcessor
315
+
316
+ dtype = _torch_dtype_from_string(self.config.torch_dtype)
317
+ self.processor = AutoProcessor.from_pretrained(
318
+ self.config.base_model_id,
319
+ trust_remote_code=self.config.trust_remote_code,
320
+ )
321
+ kwargs = {"trust_remote_code": self.config.trust_remote_code, "dtype": dtype}
322
+ if device_map is not None:
323
+ kwargs["device_map"] = device_map
324
+ try:
325
+ self.base = AutoModelForImageTextToText.from_pretrained(self.config.base_model_id, **kwargs)
326
+ except TypeError:
327
+ kwargs.pop("dtype", None)
328
+ kwargs["torch_dtype"] = dtype
329
+ self.base = AutoModelForImageTextToText.from_pretrained(self.config.base_model_id, **kwargs)
330
+ if device is not None and device_map is None:
331
+ self.base.to(device)
332
+ self.base.eval()
333
+ self.s2_module = _first_existing_attr(self.base, self.VLM_CANDIDATES)
334
+ self.s1_module = _first_existing_attr(self.base, self.ACTION_CANDIDATES)
335
+ if self.s2_module is None:
336
+ logger.warning("RIO-2 could not locate a named MolmoAct2 S2/VLM module; full base forward will be used.")
337
+ if self.s1_module is None:
338
+ logger.warning("RIO-2 could not locate a named MolmoAct2 action expert; predict_action fallback will be used.")
339
+ return self
340
+
341
+ def freeze_base(self):
342
+ if self.base is not None:
343
+ self.base.eval()
344
+ for param in self.base.parameters():
345
+ param.requires_grad = False
346
+
347
+ def unfreeze_action_expert(self):
348
+ if self.s1_module is None:
349
+ return 0
350
+ count = 0
351
+ for param in self.s1_module.parameters():
352
+ param.requires_grad = True
353
+ count += param.numel()
354
+ return count
355
+
356
+ def unfreeze_adapters_only(self):
357
+ self.freeze_base()
358
+ for param in self.compressor.parameters():
359
+ param.requires_grad = True
360
+
361
+ def _extract_sequence_context(self, outputs: Any) -> torch.Tensor | None:
362
+ if outputs is None:
363
+ return None
364
+ if hasattr(outputs, "hidden_states") and outputs.hidden_states is not None:
365
+ return outputs.hidden_states[-1]
366
+ if hasattr(outputs, "last_hidden_state") and outputs.last_hidden_state is not None:
367
+ return outputs.last_hidden_state
368
+ if isinstance(outputs, dict):
369
+ if outputs.get("hidden_states") is not None:
370
+ return outputs["hidden_states"][-1]
371
+ if outputs.get("last_hidden_state") is not None:
372
+ return outputs["last_hidden_state"]
373
+ if hasattr(outputs, "past_key_values") and outputs.past_key_values is not None:
374
+ chunks = []
375
+ for layer in outputs.past_key_values:
376
+ if isinstance(layer, (tuple, list)) and len(layer) >= 2:
377
+ key, value = layer[0], layer[1]
378
+ chunks.append(key.float().mean(dim=(-3, -2)))
379
+ chunks.append(value.float().mean(dim=(-3, -2)))
380
+ if chunks:
381
+ return torch.stack(chunks, dim=1).to(dtype=chunks[0].dtype)
382
+ return None
383
+
384
+ def _extract_cache(self, outputs: Any) -> Any:
385
+ if outputs is None:
386
+ return None
387
+ for name in ("past_key_values", "kv_cache", "cache", "action_cache", "vlm_cache"):
388
+ if hasattr(outputs, name) and getattr(outputs, name) is not None:
389
+ return getattr(outputs, name)
390
+ if isinstance(outputs, dict) and outputs.get(name) is not None:
391
+ return outputs[name]
392
+ return outputs
393
+
394
+ @torch.no_grad()
395
+ def refresh_s2(self, images: ImageLike | list[ImageLike], instruction: str, force: bool = False) -> torch.Tensor:
396
+ if self.base is None or self.processor is None:
397
+ raise RuntimeError("MolmoAct2 base is not loaded. Call model.load_s2_base() first.")
398
+ age = time.time() - self.last_refresh_time
399
+ if (
400
+ not force
401
+ and self.last_compact_tokens is not None
402
+ and self.last_instruction == instruction
403
+ and age < self.config.max_s2_cache_age_s
404
+ ):
405
+ return self.last_compact_tokens
406
+
407
+ pil_images = _to_pil_list(images)
408
+ inputs = self.processor(images=pil_images, text=instruction, return_tensors="pt")
409
+ inputs = _move_to_device(inputs, self.base_device, _torch_dtype_from_string(self.config.torch_dtype))
410
+ try:
411
+ outputs = self.base(**inputs, use_cache=True, output_hidden_states=True, return_dict=True)
412
+ except TypeError:
413
+ outputs = self.base(**inputs, return_dict=True)
414
+
415
+ self.last_base_outputs = outputs
416
+ self.last_s2_cache = self._extract_cache(outputs)
417
+ self.last_pil_images = pil_images
418
+ self.last_instruction = instruction
419
+ self.last_refresh_time = time.time()
420
+ sequence_context = self._extract_sequence_context(outputs)
421
+ if sequence_context is not None:
422
+ sequence_context = sequence_context.to(
423
+ device=next(self.compressor.parameters()).device,
424
+ dtype=next(self.compressor.parameters()).dtype,
425
+ )
426
+ try:
427
+ self.last_compact_tokens = self.compressor(sequence_context).detach()
428
+ except Exception as exc:
429
+ logger.warning("RIO-2 compact-token compression failed: %s", exc)
430
+ self.last_compact_tokens = torch.zeros(
431
+ 1,
432
+ self.config.s2_token_count,
433
+ self.config.s2_width,
434
+ device=next(self.compressor.parameters()).device,
435
+ dtype=next(self.compressor.parameters()).dtype,
436
+ )
437
+ else:
438
+ self.last_compact_tokens = torch.zeros(
439
+ 1,
440
+ self.config.s2_token_count,
441
+ self.config.s2_width,
442
+ device=next(self.compressor.parameters()).device,
443
+ dtype=next(self.compressor.parameters()).dtype,
444
+ )
445
+ return self.last_compact_tokens
446
+
447
+ def _try_split_action_expert(
448
+ self,
449
+ state: torch.Tensor,
450
+ state_history: torch.Tensor | None,
451
+ action_history: torch.Tensor | None,
452
+ num_steps: int,
453
+ ) -> torch.Tensor | None:
454
+ if not self.config.prefer_split_action_expert or self.s1_module is None:
455
+ return None
456
+ candidates = [self.s1_module]
457
+ for method_name in ("predict_action", "sample", "generate_actions", "forward"):
458
+ if hasattr(self.s1_module, method_name):
459
+ candidates.append(getattr(self.s1_module, method_name))
460
+ for fn in candidates:
461
+ try:
462
+ kwargs = {}
463
+ if _safe_signature_accepts(fn, "state"):
464
+ kwargs["state"] = state
465
+ if _safe_signature_accepts(fn, "states"):
466
+ kwargs["states"] = state
467
+ if _safe_signature_accepts(fn, "vlm_kv_cache"):
468
+ kwargs["vlm_kv_cache"] = self.last_s2_cache
469
+ if _safe_signature_accepts(fn, "past_key_values"):
470
+ kwargs["past_key_values"] = self.last_s2_cache
471
+ if _safe_signature_accepts(fn, "s2_cache"):
472
+ kwargs["s2_cache"] = self.last_s2_cache
473
+ if _safe_signature_accepts(fn, "state_history"):
474
+ kwargs["state_history"] = state_history
475
+ if _safe_signature_accepts(fn, "action_history"):
476
+ kwargs["action_history"] = action_history
477
+ if _safe_signature_accepts(fn, "num_steps"):
478
+ kwargs["num_steps"] = num_steps
479
+ if _safe_signature_accepts(fn, "num_flow_steps"):
480
+ kwargs["num_flow_steps"] = num_steps
481
+ out = fn(**kwargs) if kwargs else fn(state)
482
+ actions = self._coerce_actions(out, state)
483
+ if actions is not None:
484
+ self.last_runtime_path = "split_original_action_expert"
485
+ return actions
486
+ except Exception as exc:
487
+ logger.debug("RIO-2 split action expert attempt failed for %s: %s", fn, exc)
488
+ return None
489
+
490
+ def _coerce_actions(self, out: Any, state: torch.Tensor) -> torch.Tensor | None:
491
+ if out is None:
492
+ return None
493
+ if torch.is_tensor(out):
494
+ actions = out
495
+ elif hasattr(out, "actions"):
496
+ actions = torch.as_tensor(out.actions, device=state.device)
497
+ elif isinstance(out, dict) and out.get("actions") is not None:
498
+ actions = torch.as_tensor(out["actions"], device=state.device)
499
+ else:
500
+ return None
501
+ if actions.ndim == 2:
502
+ actions = actions.unsqueeze(0)
503
+ return actions.to(device=state.device, dtype=state.dtype if state.is_floating_point() else torch.float32)
504
+
505
+ @torch.no_grad()
506
+ def predict_action_fallback(self, state: torch.Tensor, num_steps: int) -> torch.Tensor:
507
+ if self.base is None or self.processor is None:
508
+ raise RuntimeError("MolmoAct2 base is not loaded.")
509
+ if self.last_pil_images is None or self.last_instruction is None:
510
+ raise RuntimeError("S2 cache is empty. Call refresh_s2(images, instruction) first.")
511
+ if not hasattr(self.base, "predict_action"):
512
+ raise RuntimeError("MolmoAct2 base has no predict_action method and split action expert was unavailable.")
513
+ state_np = state.detach().float().cpu().numpy()
514
+ out = self.base.predict_action(
515
+ processor=self.processor,
516
+ images=self.last_pil_images,
517
+ task=self.last_instruction,
518
+ state=state_np,
519
+ norm_tag=self.config.norm_tag,
520
+ action_mode=self.config.action_mode,
521
+ num_steps=num_steps,
522
+ )
523
+ actions = torch.as_tensor(out.actions, device=state.device, dtype=state.dtype if state.is_floating_point() else torch.float32)
524
+ if actions.ndim == 2:
525
+ actions = actions.unsqueeze(0)
526
+ self.last_runtime_path = "predict_action_fallback_exact"
527
+ return actions
528
+
529
+ @torch.no_grad()
530
+ def act_original(
531
+ self,
532
+ state: torch.Tensor,
533
+ state_history: torch.Tensor | None = None,
534
+ action_history: torch.Tensor | None = None,
535
+ num_steps: int | None = None,
536
+ ) -> torch.Tensor:
537
+ steps = int(num_steps or self.config.molmoact_num_steps)
538
+ split_actions = self._try_split_action_expert(state, state_history, action_history, steps)
539
+ if split_actions is not None:
540
+ return split_actions
541
+ if self.config.fallback_to_predict_action:
542
+ return self.predict_action_fallback(state, steps)
543
+ raise RuntimeError("No callable original S1/action path was found and fallback_to_predict_action=False.")
544
+
545
+
546
+ class Rio2FastS1FlowActionExpert(nn.Module):
547
+ """Small fallback S1 for cached-token training.
548
+
549
+ In weight-preserved RIO-2, this is not the preferred runtime path. It remains
550
+ as an adapter/student fallback and for upstream tests without downloading
551
+ MolmoAct2.
552
+ """
553
+
554
+ def __init__(self, config: Rio2Config):
555
+ super().__init__()
556
+ self.config = config
557
+ width = config.s1_width
558
+ self.s2_proj = nn.Linear(config.s2_width, width)
559
+ self.state_proj = nn.Linear(config.state_dim, width)
560
+ self.state_hist_proj = nn.Linear(config.state_dim, width)
561
+ self.action_hist_proj = nn.Linear(config.action_dim, width)
562
+ self.noisy_action_proj = nn.Linear(config.action_dim, width)
563
+ self.time_emb = Rio2SinusoidalTimeEmbedding(width)
564
+ self.type_emb = nn.Parameter(torch.randn(5, width) / math.sqrt(width))
565
+ self.memory_proj = nn.Linear(config.s2_width, width)
566
+ self.memory_type_emb = nn.Parameter(torch.randn(1, width) / math.sqrt(width))
567
+ self.memory_gate = nn.Parameter(torch.tensor(-2.0))
568
+ layer = nn.TransformerEncoderLayer(
569
+ d_model=width,
570
+ nhead=config.s1_heads,
571
+ dim_feedforward=width * 4,
572
+ dropout=config.s1_dropout,
573
+ batch_first=True,
574
+ norm_first=True,
575
+ activation="gelu",
576
+ )
577
+ self.blocks = nn.TransformerEncoder(layer, num_layers=config.s1_layers)
578
+ self.norm = Rio2RMSNorm(width)
579
+ self.action_head = nn.Sequential(nn.Linear(width, width), nn.SiLU(), nn.Linear(width, config.action_dim))
580
+ self.noise_head = nn.Sequential(nn.Linear(width, width), nn.SiLU(), nn.Linear(width, config.action_dim))
581
+
582
+ hidden = int(config.jepa_hidden_dim)
583
+ latent = int(config.jepa_latent_dim)
584
+ self.jepa_s2_proj = nn.Linear(config.s2_width, hidden)
585
+ self.jepa_memory_proj = nn.Linear(config.s2_width, hidden)
586
+ self.jepa_state_proj = nn.Linear(config.state_dim, hidden)
587
+ self.jepa_action_hist_proj = nn.Linear(config.action_dim, hidden)
588
+ self.jepa_norm = Rio2RMSNorm(hidden)
589
+ self.jepa_predictor = nn.Sequential(nn.Linear(hidden, hidden), nn.SiLU(), nn.Linear(hidden, latent))
590
+ flat_action_dim = config.action_horizon * config.action_dim
591
+ self.action_encoder = nn.Sequential(nn.Linear(flat_action_dim, hidden), nn.SiLU(), nn.Linear(hidden, latent))
592
+ self.target_action_encoder = copy.deepcopy(self.action_encoder)
593
+ for param in self.target_action_encoder.parameters():
594
+ param.requires_grad = False
595
+ self.jepa_to_action_prior = nn.Sequential(
596
+ nn.Linear(latent, hidden),
597
+ nn.SiLU(),
598
+ nn.Linear(hidden, flat_action_dim),
599
+ )
600
+ self.consistency_head = nn.Sequential(
601
+ nn.Linear(latent, hidden),
602
+ nn.SiLU(),
603
+ nn.Linear(hidden, flat_action_dim),
604
+ )
605
+ self.jepa_condition_proj = nn.Linear(latent, width)
606
+ nn.init.zeros_(self.jepa_to_action_prior[-1].weight)
607
+ nn.init.zeros_(self.jepa_to_action_prior[-1].bias)
608
+ nn.init.zeros_(self.jepa_condition_proj.weight)
609
+ nn.init.zeros_(self.jepa_condition_proj.bias)
610
+ self.moe_residual = Rio2S1MoEResidualBank(config, width) if bool(config.enable_s1_moe) else None
611
+
612
+ def default_task_memory_from_s2(self, s2_tokens):
613
+ if s2_tokens.ndim == 2:
614
+ s2_tokens = s2_tokens.unsqueeze(0)
615
+ batch_size, token_count, width = s2_tokens.shape
616
+ slots = max(1, int(self.config.task_memory_slots))
617
+ if token_count >= slots:
618
+ return s2_tokens[:, :slots]
619
+ pad_value = s2_tokens.mean(dim=1, keepdim=True).expand(batch_size, slots - token_count, width)
620
+ return torch.cat([s2_tokens, pad_value], dim=1)
621
+
622
+ def _prepare_task_memory(self, task_memory, s2_tokens, batch_size, device, dtype):
623
+ if not bool(self.config.task_memory_enabled):
624
+ return None
625
+ if task_memory is None:
626
+ task_memory = self.default_task_memory_from_s2(s2_tokens)
627
+ if task_memory.ndim == 2:
628
+ task_memory = task_memory.unsqueeze(0)
629
+ task_memory = task_memory.to(device=device, dtype=dtype)
630
+ if task_memory.shape[0] == 1 and batch_size > 1:
631
+ task_memory = task_memory.expand(batch_size, -1, -1)
632
+ elif task_memory.shape[0] != batch_size:
633
+ task_memory = task_memory[:1].expand(batch_size, -1, -1)
634
+ slots = max(1, int(self.config.task_memory_slots))
635
+ if task_memory.shape[1] < slots:
636
+ pad = task_memory.mean(dim=1, keepdim=True).expand(batch_size, slots - task_memory.shape[1], task_memory.shape[2])
637
+ task_memory = torch.cat([task_memory, pad], dim=1)
638
+ return task_memory[:, :slots]
639
+
640
+ def _prepare_hist(self, values, length, dim, batch_size, device, dtype):
641
+ if values is None:
642
+ return torch.zeros(batch_size, length, dim, device=device, dtype=dtype)
643
+ if values.ndim == 2:
644
+ values = values.unsqueeze(0)
645
+ values = values.to(device=device, dtype=dtype)
646
+ if values.shape[1] < length:
647
+ pad = torch.zeros(values.shape[0], length - values.shape[1], values.shape[2], device=device, dtype=dtype)
648
+ values = torch.cat([pad, values], dim=1)
649
+ return values[:, -length:]
650
+
651
+ def _decode(self, s2_tokens, state, state_history, action_history, noisy_actions, timesteps, head, jepa_latent=None, task_memory=None):
652
+ if state.ndim == 1:
653
+ state = state.unsqueeze(0)
654
+ if noisy_actions.ndim == 2:
655
+ noisy_actions = noisy_actions.unsqueeze(0)
656
+ if s2_tokens.ndim == 2:
657
+ s2_tokens = s2_tokens.unsqueeze(0)
658
+ batch_size = state.shape[0]
659
+ device = state.device
660
+ dtype = state.dtype if state.is_floating_point() else torch.float32
661
+ state = state.to(device=device, dtype=dtype)
662
+ noisy_actions = noisy_actions.to(device=device, dtype=dtype)
663
+ s2_tokens = s2_tokens.to(device=device, dtype=dtype)
664
+ state_history = self._prepare_hist(state_history, self.config.state_history_len, self.config.state_dim, batch_size, device, dtype)
665
+ action_history = self._prepare_hist(action_history, self.config.action_history_len, self.config.action_dim, batch_size, device, dtype)
666
+ task_memory = self._prepare_task_memory(task_memory, s2_tokens, batch_size, device, dtype)
667
+ s2_tok = self.s2_proj(s2_tokens) + self.type_emb[0]
668
+ token_chunks = [s2_tok]
669
+ if task_memory is not None:
670
+ gate = torch.sigmoid(self.memory_gate).to(dtype=s2_tok.dtype)
671
+ mem_tok = gate * float(self.config.task_memory_alpha) * self.memory_proj(task_memory) + self.memory_type_emb
672
+ token_chunks.append(mem_tok)
673
+ state_tok = self.state_proj(state).unsqueeze(1) + self.type_emb[1]
674
+ state_hist_tok = self.state_hist_proj(state_history) + self.type_emb[2]
675
+ action_hist_tok = self.action_hist_proj(action_history) + self.type_emb[3]
676
+ action_tok = self.noisy_action_proj(noisy_actions) + self.type_emb[4]
677
+ action_tok = action_tok + self.time_emb(timesteps).unsqueeze(1)
678
+ if jepa_latent is not None and bool(self.config.enable_jepa_diffusion):
679
+ cond = self.jepa_condition_proj(jepa_latent.to(device=device, dtype=dtype)).unsqueeze(1)
680
+ action_tok = action_tok + float(self.config.jepa_condition_alpha) * cond.to(dtype=action_tok.dtype)
681
+ token_chunks.extend([state_tok, state_hist_tok, action_hist_tok, action_tok])
682
+ tokens = torch.cat(token_chunks, dim=1)
683
+ tokens = self.blocks(tokens)
684
+ tokens = self.norm(tokens)
685
+ return head(tokens[:, -self.config.action_horizon :])
686
+
687
+ def velocity(self, s2_tokens, state, state_history, action_history, noisy_actions, timesteps, jepa_latent=None, task_memory=None):
688
+ return self._decode(s2_tokens, state, state_history, action_history, noisy_actions, timesteps, self.action_head, jepa_latent, task_memory)
689
+
690
+ def diffusion_noise(self, s2_tokens, state, state_history, action_history, noisy_actions, timesteps, jepa_latent=None, task_memory=None):
691
+ return self._decode(s2_tokens, state, state_history, action_history, noisy_actions, timesteps, self.noise_head, jepa_latent, task_memory)
692
+
693
+ def predict_action_latent(self, s2_tokens, state, action_history=None, task_memory=None):
694
+ if state.ndim == 1:
695
+ state = state.unsqueeze(0)
696
+ if s2_tokens.ndim == 2:
697
+ s2_tokens = s2_tokens.unsqueeze(0)
698
+ batch_size = state.shape[0]
699
+ device = state.device
700
+ dtype = state.dtype if state.is_floating_point() else torch.float32
701
+ action_history = self._prepare_hist(action_history, self.config.action_history_len, self.config.action_dim, batch_size, device, dtype)
702
+ s2_summary = s2_tokens.to(device=device, dtype=dtype).mean(dim=1)
703
+ task_memory = self._prepare_task_memory(task_memory, s2_tokens, batch_size, device, dtype)
704
+ memory_summary = torch.zeros_like(s2_summary) if task_memory is None else task_memory.mean(dim=1)
705
+ hist_summary = action_history.mean(dim=1)
706
+ memory_scale = torch.sigmoid(self.memory_gate).to(dtype=s2_summary.dtype) * float(self.config.task_memory_alpha)
707
+ context = (
708
+ self.jepa_s2_proj(s2_summary)
709
+ + memory_scale * self.jepa_memory_proj(memory_summary)
710
+ + self.jepa_state_proj(state.to(dtype=dtype))
711
+ + self.jepa_action_hist_proj(hist_summary)
712
+ )
713
+ return self.jepa_predictor(self.jepa_norm(context))
714
+
715
+ def moe_action_residual(self, s2_tokens, state, action_history=None, task_memory=None):
716
+ if self.moe_residual is None:
717
+ return None
718
+ if state.ndim == 1:
719
+ state = state.unsqueeze(0)
720
+ if s2_tokens.ndim == 2:
721
+ s2_tokens = s2_tokens.unsqueeze(0)
722
+ batch_size = state.shape[0]
723
+ device = state.device
724
+ dtype = state.dtype if state.is_floating_point() else torch.float32
725
+ action_history = self._prepare_hist(action_history, self.config.action_history_len, self.config.action_dim, batch_size, device, dtype)
726
+ s2_tokens = s2_tokens.to(device=device, dtype=dtype)
727
+ task_memory = self._prepare_task_memory(task_memory, s2_tokens, batch_size, device, dtype)
728
+ context = (
729
+ self.s2_proj(s2_tokens).mean(dim=1)
730
+ + self.state_proj(state.to(dtype=dtype))
731
+ + self.action_hist_proj(action_history).mean(dim=1)
732
+ )
733
+ if task_memory is not None:
734
+ gate = torch.sigmoid(self.memory_gate).to(dtype=context.dtype)
735
+ context = context + gate * float(self.config.task_memory_alpha) * self.memory_proj(task_memory).mean(dim=1)
736
+ return self.moe_residual(context).to(dtype=dtype)
737
+
738
+ def encode_action_latent(self, actions, target=False):
739
+ if actions.ndim == 2:
740
+ actions = actions.unsqueeze(0)
741
+ flat = actions.reshape(actions.shape[0], -1)
742
+ encoder = self.target_action_encoder if target else self.action_encoder
743
+ return F.normalize(encoder(flat).float(), dim=-1).to(dtype=flat.dtype)
744
+
745
+ def action_prior_from_latent(self, latent, dtype):
746
+ prior = self.jepa_to_action_prior(latent).view(latent.shape[0], self.config.action_horizon, self.config.action_dim)
747
+ return prior.to(dtype=dtype)
748
+
749
+ def consistency_action_from_latent(self, latent, dtype):
750
+ actions = self.consistency_head(latent).view(latent.shape[0], self.config.action_horizon, self.config.action_dim)
751
+ return actions.to(dtype=dtype)
752
+
753
+ def jepa_diffusion_sample(self, s2_tokens, state, state_history=None, action_history=None, steps=None, task_memory=None):
754
+ if state.ndim == 1:
755
+ state = state.unsqueeze(0)
756
+ batch_size = state.shape[0]
757
+ dtype = state.dtype if state.is_floating_point() else torch.float32
758
+ jepa_latent = self.predict_action_latent(s2_tokens, state, action_history, task_memory)
759
+ x = self.consistency_action_from_latent(jepa_latent, dtype)
760
+ if float(self.config.jepa_action_prior_alpha) != 0.0:
761
+ x = x + float(self.config.jepa_action_prior_alpha) * self.action_prior_from_latent(jepa_latent, dtype)
762
+ moe_residual = self.moe_action_residual(s2_tokens, state, action_history, task_memory)
763
+ if moe_residual is not None:
764
+ x = x + float(self.config.s1_moe_residual_scale) * moe_residual
765
+ denoise_steps = int(steps if steps is not None else self.config.diffusion_inference_steps)
766
+ denoise_steps = max(0, denoise_steps)
767
+ if denoise_steps > 0:
768
+ x = x + torch.randn_like(x) * float(self.config.s1_sampling_noise_scale) / float(denoise_steps + 1)
769
+ for i in range(denoise_steps):
770
+ frac = float(denoise_steps - i) / float(max(denoise_steps, 1))
771
+ timesteps = torch.full((batch_size,), frac, device=state.device, dtype=dtype)
772
+ eps = self.diffusion_noise(s2_tokens, state, state_history, action_history, x, timesteps, jepa_latent, task_memory)
773
+ x = x - eps / float(denoise_steps + 1)
774
+ return x
775
+
776
+ @torch.no_grad()
777
+ def update_target_encoder(self, decay=None):
778
+ decay = float(self.config.jepa_ema_decay if decay is None else decay)
779
+ for online, target in zip(self.action_encoder.parameters(), self.target_action_encoder.parameters()):
780
+ target.data.mul_(decay).add_(online.data, alpha=1.0 - decay)
781
+
782
+ def freeze_target_encoder(self):
783
+ for param in self.target_action_encoder.parameters():
784
+ param.requires_grad = False
785
+
786
+ def training_loss(self, s2_tokens, state, state_history, action_history, target_actions, task_memory=None):
787
+ if target_actions.ndim == 2:
788
+ target_actions = target_actions.unsqueeze(0)
789
+ batch_size = target_actions.shape[0]
790
+ jepa_latent = self.predict_action_latent(s2_tokens, state, action_history, task_memory) if bool(self.config.enable_jepa_diffusion) else None
791
+ x0 = torch.randn_like(target_actions)
792
+ x1 = target_actions
793
+ timesteps = torch.rand(batch_size, device=target_actions.device, dtype=target_actions.dtype)
794
+ xt = (1.0 - timesteps[:, None, None]) * x0 + timesteps[:, None, None] * x1
795
+ target_velocity = x1 - x0
796
+ pred_velocity = self.velocity(s2_tokens, state, state_history, action_history, xt, timesteps, jepa_latent, task_memory)
797
+ loss_flow_mse = F.mse_loss(pred_velocity, target_velocity)
798
+ loss_flow_l1 = F.l1_loss(pred_velocity, target_velocity)
799
+ if bool(self.config.enable_jepa_diffusion) and float(self.config.diffusion_loss_weight) > 0:
800
+ diffusion_t = torch.rand(batch_size, device=target_actions.device, dtype=target_actions.dtype)
801
+ eps = torch.randn_like(target_actions)
802
+ alpha = torch.cos(diffusion_t[:, None, None] * (math.pi / 2.0))
803
+ sigma = torch.sin(diffusion_t[:, None, None] * (math.pi / 2.0))
804
+ noisy = alpha * target_actions + sigma * eps
805
+ pred_eps = self.diffusion_noise(s2_tokens, state, state_history, action_history, noisy, diffusion_t, jepa_latent, task_memory)
806
+ loss_diffusion = F.mse_loss(pred_eps, eps)
807
+ else:
808
+ loss_diffusion = target_actions.new_tensor(0.0)
809
+ if bool(self.config.enable_jepa_diffusion) and float(self.config.jepa_loss_weight) > 0:
810
+ pred_latent = F.normalize(jepa_latent.float(), dim=-1)
811
+ with torch.no_grad():
812
+ target_latent = self.encode_action_latent(target_actions, target=True).float()
813
+ loss_jepa = F.mse_loss(pred_latent, target_latent)
814
+ else:
815
+ loss_jepa = target_actions.new_tensor(0.0)
816
+ if bool(self.config.enable_jepa_diffusion) and float(self.config.jepa_action_prior_weight) > 0:
817
+ prior_actions = self.action_prior_from_latent(jepa_latent, target_actions.dtype)
818
+ loss_jepa_prior = F.mse_loss(prior_actions, target_actions)
819
+ else:
820
+ loss_jepa_prior = target_actions.new_tensor(0.0)
821
+ if bool(self.config.enable_jepa_diffusion) and float(self.config.consistency_loss_weight) > 0:
822
+ consistency_actions = self.consistency_action_from_latent(jepa_latent, target_actions.dtype)
823
+ moe_residual = self.moe_action_residual(s2_tokens, state, action_history, task_memory)
824
+ if moe_residual is not None:
825
+ consistency_actions = consistency_actions + float(self.config.s1_moe_residual_scale) * moe_residual
826
+ loss_consistency = F.mse_loss(consistency_actions, target_actions)
827
+ else:
828
+ loss_consistency = target_actions.new_tensor(0.0)
829
+ loss_smooth = (target_actions[:, 1:] - target_actions[:, :-1]).pow(2).mean() if target_actions.shape[1] > 1 else target_actions.new_tensor(0.0)
830
+ loss = (
831
+ self.config.flow_loss_weight * (loss_flow_mse + self.config.action_l1_weight * loss_flow_l1)
832
+ + self.config.smooth_loss_weight * loss_smooth
833
+ + self.config.diffusion_loss_weight * loss_diffusion
834
+ + self.config.consistency_loss_weight * loss_consistency
835
+ + self.config.jepa_loss_weight * loss_jepa.to(loss_flow_mse.dtype)
836
+ + self.config.jepa_action_prior_weight * loss_jepa_prior
837
+ )
838
+ return {
839
+ "loss": loss,
840
+ "loss_flow_mse": loss_flow_mse,
841
+ "loss_flow_l1": loss_flow_l1,
842
+ "loss_diffusion": loss_diffusion,
843
+ "loss_consistency": loss_consistency,
844
+ "loss_jepa": loss_jepa,
845
+ "loss_jepa_prior": loss_jepa_prior,
846
+ "loss_smooth": loss_smooth,
847
+ }
848
+
849
+ @torch.no_grad()
850
+ def sample(self, s2_tokens, state, state_history=None, action_history=None, steps=None, task_memory=None):
851
+ if state.ndim == 1:
852
+ state = state.unsqueeze(0)
853
+ if self.config.s1_policy_mode == "jepa_diffusion" and bool(self.config.enable_jepa_diffusion):
854
+ x = self.jepa_diffusion_sample(s2_tokens, state, state_history, action_history, steps=steps, task_memory=task_memory)
855
+ if self.config.action_clip > 0:
856
+ x = x.clamp(-self.config.action_clip, self.config.action_clip)
857
+ return x
858
+ batch_size = state.shape[0]
859
+ steps = steps or self.config.flow_inference_steps
860
+ dtype = state.dtype if state.is_floating_point() else torch.float32
861
+ jepa_latent = self.predict_action_latent(s2_tokens, state, action_history, task_memory) if bool(self.config.enable_jepa_diffusion) else None
862
+ x = torch.randn(batch_size, self.config.action_horizon, self.config.action_dim, device=state.device, dtype=dtype)
863
+ x = x * float(self.config.s1_sampling_noise_scale)
864
+ if jepa_latent is not None and float(self.config.jepa_action_prior_alpha) != 0.0:
865
+ x = x + float(self.config.jepa_action_prior_alpha) * self.action_prior_from_latent(jepa_latent, dtype)
866
+ moe_residual = self.moe_action_residual(s2_tokens, state, action_history, task_memory)
867
+ if moe_residual is not None:
868
+ x = x + float(self.config.s1_moe_residual_scale) * moe_residual
869
+ for i in range(steps):
870
+ timesteps = torch.full((batch_size,), float(i) / max(steps, 1), device=state.device, dtype=x.dtype)
871
+ x = x + self.velocity(s2_tokens, state, state_history, action_history, x, timesteps, jepa_latent, task_memory) / float(steps)
872
+ if self.config.action_clip > 0:
873
+ x = x.clamp(-self.config.action_clip, self.config.action_clip)
874
+ return x
875
+
876
+
877
+
878
+
879
+ class Rio2JepaS1ActionExpert(nn.Module):
880
+ """JEPA-style S1 that preserves the online S1 policy weights.
881
+
882
+ This module does **not** replace the original S1 policy with an unrelated
883
+ world model. Instead it wraps the existing fast flow S1 as `online_s1` and
884
+ adds a small latent prediction side objective:
885
+
886
+ - online_s1: action generator; initialized and trained exactly like the
887
+ existing RIO-2 S1 path, so old S1 checkpoints can be remapped into it.
888
+ - jepa_context_encoder + predictor: predicts future action latent from
889
+ S2 tokens, current state, and action history.
890
+ - target_action_encoder: EMA target encoder for the future action chunk.
891
+ - latent_to_action_delta: optional zero-initialized residual head.
892
+
893
+ Inference defaults to the online S1 policy. JEPA affects actions only when
894
+ `config.use_jepa_action_residual=True` and `config.jepa_action_alpha > 0`.
895
+ """
896
+
897
+ def __init__(self, config: Rio2Config):
898
+ super().__init__()
899
+ self.config = config
900
+ self.online_s1 = Rio2FastS1FlowActionExpert(config)
901
+
902
+ hidden = int(config.jepa_hidden_dim)
903
+ latent = int(config.jepa_latent_dim)
904
+ self.s2_jepa_proj = nn.Linear(config.s2_width, hidden)
905
+ self.state_jepa_proj = nn.Linear(config.state_dim, hidden)
906
+ self.action_hist_jepa_proj = nn.Linear(config.action_dim, hidden)
907
+ self.type_emb = nn.Parameter(torch.randn(3, hidden) / math.sqrt(hidden))
908
+
909
+ layer = nn.TransformerEncoderLayer(
910
+ d_model=hidden,
911
+ nhead=max(1, int(config.jepa_heads)),
912
+ dim_feedforward=hidden * 4,
913
+ dropout=config.s1_dropout,
914
+ batch_first=True,
915
+ norm_first=True,
916
+ activation="gelu",
917
+ )
918
+ self.jepa_context_encoder = nn.TransformerEncoder(layer, num_layers=max(1, int(config.jepa_layers)))
919
+ self.jepa_norm = Rio2RMSNorm(hidden)
920
+ self.jepa_predictor = nn.Sequential(
921
+ nn.Linear(hidden, hidden),
922
+ nn.SiLU(),
923
+ nn.Linear(hidden, latent),
924
+ )
925
+
926
+ flat_action_dim = config.action_horizon * config.action_dim
927
+ self.action_encoder = nn.Sequential(
928
+ nn.Linear(flat_action_dim, hidden),
929
+ nn.SiLU(),
930
+ nn.Linear(hidden, latent),
931
+ )
932
+ self.target_action_encoder = copy.deepcopy(self.action_encoder)
933
+ for param in self.target_action_encoder.parameters():
934
+ param.requires_grad = False
935
+
936
+ self.latent_to_action_delta = nn.Sequential(
937
+ nn.Linear(latent, hidden),
938
+ nn.SiLU(),
939
+ nn.Linear(hidden, flat_action_dim),
940
+ )
941
+ nn.init.zeros_(self.latent_to_action_delta[-1].weight)
942
+ nn.init.zeros_(self.latent_to_action_delta[-1].bias)
943
+
944
+ def _prepare_action_history(self, action_history, batch_size, device, dtype):
945
+ if action_history is None:
946
+ return torch.zeros(batch_size, self.config.action_history_len, self.config.action_dim, device=device, dtype=dtype)
947
+ if action_history.ndim == 2:
948
+ action_history = action_history.unsqueeze(0)
949
+ action_history = action_history.to(device=device, dtype=dtype)
950
+ if action_history.shape[1] < self.config.action_history_len:
951
+ pad = torch.zeros(
952
+ action_history.shape[0],
953
+ self.config.action_history_len - action_history.shape[1],
954
+ action_history.shape[2],
955
+ device=device,
956
+ dtype=dtype,
957
+ )
958
+ action_history = torch.cat([pad, action_history], dim=1)
959
+ return action_history[:, -self.config.action_history_len :]
960
+
961
+ def encode_context(self, s2_tokens, state, action_history=None):
962
+ if state.ndim == 1:
963
+ state = state.unsqueeze(0)
964
+ if s2_tokens.ndim == 2:
965
+ s2_tokens = s2_tokens.unsqueeze(0)
966
+ batch_size = state.shape[0]
967
+ device = state.device
968
+ dtype = state.dtype if state.is_floating_point() else torch.float32
969
+ s2_tokens = s2_tokens.to(device=device, dtype=dtype)
970
+ state = state.to(device=device, dtype=dtype)
971
+ action_history = self._prepare_action_history(action_history, batch_size, device, dtype)
972
+
973
+ s2_tok = self.s2_jepa_proj(s2_tokens) + self.type_emb[0]
974
+ state_tok = self.state_jepa_proj(state).unsqueeze(1) + self.type_emb[1]
975
+ hist_tok = self.action_hist_jepa_proj(action_history) + self.type_emb[2]
976
+ tokens = torch.cat([s2_tok, state_tok, hist_tok], dim=1)
977
+ hidden = self.jepa_context_encoder(tokens)
978
+ hidden = self.jepa_norm(hidden)
979
+ return hidden.mean(dim=1)
980
+
981
+ def predict_action_latent(self, s2_tokens, state, action_history=None):
982
+ context = self.encode_context(s2_tokens, state, action_history)
983
+ return self.jepa_predictor(context)
984
+
985
+ def encode_action_latent(self, actions: torch.Tensor, target: bool = False) -> torch.Tensor:
986
+ if actions.ndim == 2:
987
+ actions = actions.unsqueeze(0)
988
+ flat = actions.reshape(actions.shape[0], -1)
989
+ encoder = self.target_action_encoder if target else self.action_encoder
990
+ latent = encoder(flat)
991
+ return F.normalize(latent.float(), dim=-1).to(dtype=flat.dtype)
992
+
993
+ @torch.no_grad()
994
+ def update_target_encoder(self, decay: float | None = None):
995
+ decay = float(self.config.jepa_ema_decay if decay is None else decay)
996
+ for online, target in zip(self.action_encoder.parameters(), self.target_action_encoder.parameters()):
997
+ target.data.mul_(decay).add_(online.data, alpha=1.0 - decay)
998
+
999
+ def freeze_target_encoder(self):
1000
+ if hasattr(self.online_s1, "freeze_target_encoder"):
1001
+ self.online_s1.freeze_target_encoder()
1002
+ for param in self.target_action_encoder.parameters():
1003
+ param.requires_grad = False
1004
+
1005
+ def training_loss(self, s2_tokens, state, state_history, action_history, target_actions, task_memory=None):
1006
+ base_losses = self.online_s1.training_loss(s2_tokens, state, state_history, action_history, target_actions, task_memory=task_memory)
1007
+ pred_latent = F.normalize(self.predict_action_latent(s2_tokens, state, action_history).float(), dim=-1)
1008
+ with torch.no_grad():
1009
+ target_latent = self.encode_action_latent(target_actions, target=True).float()
1010
+ loss_jepa = F.mse_loss(pred_latent, target_latent)
1011
+ loss = base_losses["loss"] + float(self.config.jepa_loss_weight) * loss_jepa.to(base_losses["loss"].dtype)
1012
+ return {
1013
+ **base_losses,
1014
+ "loss": loss,
1015
+ "loss_jepa": loss_jepa,
1016
+ "pred_action_latent": pred_latent,
1017
+ "target_action_latent": target_latent,
1018
+ }
1019
+
1020
+ @torch.no_grad()
1021
+ def sample(self, s2_tokens, state, state_history=None, action_history=None, steps=None, task_memory=None):
1022
+ actions = self.online_s1.sample(s2_tokens, state, state_history, action_history, steps=steps, task_memory=task_memory)
1023
+ if bool(self.config.use_jepa_action_residual) and float(self.config.jepa_action_alpha) != 0.0:
1024
+ pred_latent = self.predict_action_latent(s2_tokens, state, action_history).to(actions.dtype)
1025
+ delta = self.latent_to_action_delta(pred_latent).view(
1026
+ actions.shape[0], self.config.action_horizon, self.config.action_dim
1027
+ )
1028
+ actions = actions + float(self.config.jepa_action_alpha) * delta
1029
+ if self.config.action_clip > 0:
1030
+ actions = actions.clamp(-self.config.action_clip, self.config.action_clip)
1031
+ return actions
1032
+
1033
+
1034
+ class Rio2ResidualAdapter(nn.Module):
1035
+ """Tiny correction head. Initial output is zero when residual_alpha=0."""
1036
+
1037
+ def __init__(self, config: Rio2Config):
1038
+ super().__init__()
1039
+ width = min(256, max(64, config.s1_width))
1040
+ self.net = nn.Sequential(
1041
+ nn.Linear(config.state_dim, width),
1042
+ nn.SiLU(),
1043
+ nn.Linear(width, config.action_horizon * config.action_dim),
1044
+ )
1045
+ self.config = config
1046
+ nn.init.zeros_(self.net[-1].weight)
1047
+ nn.init.zeros_(self.net[-1].bias)
1048
+
1049
+ def forward(self, state: torch.Tensor) -> torch.Tensor:
1050
+ if state.ndim == 1:
1051
+ state = state.unsqueeze(0)
1052
+ delta = self.net(state).view(state.shape[0], self.config.action_horizon, self.config.action_dim)
1053
+ return delta
1054
+
1055
+
1056
+ class Rio2PreTrainedModel(PreTrainedModel):
1057
+ config_class = Rio2Config
1058
+ base_model_prefix = "rio2"
1059
+ supports_gradient_checkpointing = False
1060
+ _no_split_modules = ["Rio2FastS1FlowActionExpert", "Rio2MolmoAct2Core"]
1061
+
1062
+ def _init_weights(self, module):
1063
+ std = 0.02
1064
+ if isinstance(module, nn.Linear):
1065
+ module.weight.data.normal_(mean=0.0, std=std)
1066
+ if module.bias is not None:
1067
+ module.bias.data.zero_()
1068
+ elif isinstance(module, nn.Embedding):
1069
+ module.weight.data.normal_(mean=0.0, std=std)
1070
+ elif isinstance(module, Rio2RMSNorm):
1071
+ module.weight.data.fill_(1.0)
1072
+
1073
+
1074
+ class Rio2Model(Rio2PreTrainedModel):
1075
+ """RIO-2 weight-preserved SO101 policy integrated as a Transformers model."""
1076
+
1077
+ def __init__(self, config: Rio2Config):
1078
+ super().__init__(config)
1079
+ self.molmoact = Rio2MolmoAct2Core(config)
1080
+ if bool(config.enable_jepa_s1):
1081
+ self.s1_student = Rio2JepaS1ActionExpert(config)
1082
+ else:
1083
+ self.s1_student = Rio2FastS1FlowActionExpert(config)
1084
+ self.residual_adapter = Rio2ResidualAdapter(config) if config.enable_residual_adapter else None
1085
+ self._s2_cache: torch.Tensor | None = None
1086
+ self._s2_cache_time: float = 0.0
1087
+ self._cached_instruction: str | None = None
1088
+ self._action_chunk_history: list[tuple[torch.Tensor, int]] = []
1089
+ self._task_memory_cache: torch.Tensor | None = None
1090
+ self.post_init()
1091
+ if config.load_base_on_init:
1092
+ logger.warning("config.load_base_on_init=True loads MolmoAct2 during construction; prefer load_s2_base().")
1093
+ self.load_s2_base()
1094
+ self.apply_finetuning_policy()
1095
+
1096
+ @property
1097
+ def s2(self):
1098
+ """Backward-compatible alias without duplicate module registration."""
1099
+ return self.molmoact
1100
+
1101
+ @property
1102
+ def s1(self):
1103
+ """Backward-compatible alias without duplicate module registration."""
1104
+ return self.s1_student
1105
+
1106
+ def load_s2_base(self, device: str | torch.device | None = None, device_map: str | None = None):
1107
+ self.molmoact.load_base(device=device, device_map=device_map)
1108
+ self.apply_finetuning_policy()
1109
+ return self
1110
+
1111
+ def freeze_s2_base(self):
1112
+ self.molmoact.freeze_base()
1113
+ return self
1114
+
1115
+ @torch.no_grad()
1116
+ def reset_temporal_ensemble(self):
1117
+ self._action_chunk_history.clear()
1118
+ return self
1119
+
1120
+ @torch.no_grad()
1121
+ def reset_task_memory(self):
1122
+ self._task_memory_cache = None
1123
+ return self
1124
+
1125
+ @torch.no_grad()
1126
+ def update_task_memory(self, s2_tokens: torch.Tensor, reset: bool = False):
1127
+ if not bool(self.config.task_memory_enabled):
1128
+ self._task_memory_cache = None
1129
+ return None
1130
+ device = next(self.s1_student.parameters()).device
1131
+ dtype = next(self.s1_student.parameters()).dtype
1132
+ if hasattr(self.s1_student, "default_task_memory_from_s2"):
1133
+ candidate = self.s1_student.default_task_memory_from_s2(s2_tokens.to(device=device, dtype=dtype)).detach()
1134
+ elif hasattr(self.s1_student, "online_s1"):
1135
+ candidate = self.s1_student.online_s1.default_task_memory_from_s2(s2_tokens.to(device=device, dtype=dtype)).detach()
1136
+ else:
1137
+ return None
1138
+ if (
1139
+ reset
1140
+ or self._task_memory_cache is None
1141
+ or tuple(self._task_memory_cache.shape) != tuple(candidate.shape)
1142
+ ):
1143
+ memory = candidate
1144
+ else:
1145
+ memory = float(self.config.task_memory_ema) * self._task_memory_cache.to(device=device, dtype=dtype)
1146
+ memory = memory + (1.0 - float(self.config.task_memory_ema)) * candidate
1147
+ max_norm = float(self.config.task_memory_max_norm)
1148
+ if max_norm > 0:
1149
+ norms = memory.norm(dim=-1, keepdim=True).clamp_min(1e-6)
1150
+ memory = memory * (max_norm / norms).clamp(max=1.0)
1151
+ self._task_memory_cache = memory.detach()
1152
+ return self._task_memory_cache
1153
+
1154
+ @torch.no_grad()
1155
+ def _apply_temporal_ensemble(self, actions: torch.Tensor, enabled: bool | None = None) -> torch.Tensor:
1156
+ use_ensemble = self.config.temporal_ensemble_enabled if enabled is None else enabled
1157
+ if not use_ensemble or actions.ndim != 3:
1158
+ return actions
1159
+ if self._action_chunk_history and self._action_chunk_history[0][0].shape != actions.shape:
1160
+ self.reset_temporal_ensemble()
1161
+ aged = []
1162
+ for chunk, age in self._action_chunk_history:
1163
+ next_age = age + 1
1164
+ if next_age < actions.shape[1]:
1165
+ aged.append((chunk, next_age))
1166
+ max_chunks = int(max(1, self.config.temporal_ensemble_max_chunks))
1167
+ self._action_chunk_history = [(actions.detach(), 0)] + aged[: max_chunks - 1]
1168
+ blended = []
1169
+ for offset in range(actions.shape[1]):
1170
+ weighted_sum = None
1171
+ weight_sum = 0.0
1172
+ for chunk, age in self._action_chunk_history:
1173
+ idx = age + offset
1174
+ if idx >= actions.shape[1]:
1175
+ continue
1176
+ weight = math.exp(-float(self.config.temporal_ensemble_decay) * age)
1177
+ value = chunk[:, idx]
1178
+ weighted_sum = value * weight if weighted_sum is None else weighted_sum + value * weight
1179
+ weight_sum += weight
1180
+ blended.append(weighted_sum / max(weight_sum, 1e-8))
1181
+ return torch.stack(blended, dim=1)
1182
+
1183
+ def apply_finetuning_policy(self):
1184
+ """Apply the default small-tuning policy.
1185
+
1186
+ Base MolmoAct2 weights are frozen by default. Trainable parameters are
1187
+ compressor/student/residual-adapter parameters, and optionally the
1188
+ detected original action expert when the user explicitly unfreezes it.
1189
+ """
1190
+ if self.config.train_adapters_only:
1191
+ if self.molmoact.base is not None:
1192
+ self.molmoact.freeze_base()
1193
+ for param in self.molmoact.compressor.parameters():
1194
+ param.requires_grad = True
1195
+ for param in self.s1_student.parameters():
1196
+ param.requires_grad = True
1197
+ if hasattr(self.s1_student, "freeze_target_encoder"):
1198
+ self.s1_student.freeze_target_encoder()
1199
+ if self.residual_adapter is not None:
1200
+ for param in self.residual_adapter.parameters():
1201
+ param.requires_grad = bool(self.config.residual_trainable)
1202
+ return self
1203
+
1204
+ def unfreeze_original_s1(self):
1205
+ return self.molmoact.unfreeze_action_expert()
1206
+
1207
+ def trainable_parameter_names(self) -> list[str]:
1208
+ return [name for name, param in self.named_parameters() if param.requires_grad]
1209
+
1210
+ @torch.no_grad()
1211
+ def update_jepa_target_encoder(self, decay: float | None = None):
1212
+ if hasattr(self.s1_student, "update_target_encoder"):
1213
+ self.s1_student.update_target_encoder(decay=decay)
1214
+ return self
1215
+
1216
+ @torch.no_grad()
1217
+ def refresh_s2(self, images: ImageLike | list[ImageLike], instruction: str, force: bool = False) -> torch.Tensor:
1218
+ tokens = self.molmoact.refresh_s2(images, instruction, force=force)
1219
+ if instruction != self._cached_instruction or force:
1220
+ self.reset_temporal_ensemble()
1221
+ self.update_task_memory(tokens, reset=instruction != self._cached_instruction)
1222
+ else:
1223
+ self.update_task_memory(tokens, reset=False)
1224
+ self._s2_cache = tokens.detach()
1225
+ self._s2_cache_time = time.time()
1226
+ self._cached_instruction = instruction
1227
+ return self._s2_cache
1228
+
1229
+ @torch.no_grad()
1230
+ def act_fast(
1231
+ self,
1232
+ state: torch.Tensor,
1233
+ state_history: torch.Tensor | None = None,
1234
+ action_history: torch.Tensor | None = None,
1235
+ steps: int | None = None,
1236
+ use_original: bool | None = None,
1237
+ temporal_ensemble: bool | None = None,
1238
+ ) -> torch.Tensor:
1239
+ use_original = self.config.use_original_s1 if use_original is None else use_original
1240
+ device = next(self.parameters()).device
1241
+ state = state.to(device)
1242
+ state_history = None if state_history is None else state_history.to(device)
1243
+ action_history = None if action_history is None else action_history.to(device)
1244
+
1245
+ if use_original and self.molmoact.base is not None:
1246
+ actions = self.molmoact.act_original(state, state_history, action_history, num_steps=steps)
1247
+ else:
1248
+ if self._s2_cache is None:
1249
+ raise RuntimeError("S2 cache is empty. Call refresh_s2() or pass s2_tokens to forward().")
1250
+ s2_tokens = self._s2_cache.to(device=device, dtype=state.dtype if state.is_floating_point() else torch.float32)
1251
+ task_memory = None if self._task_memory_cache is None else self._task_memory_cache.to(device=device, dtype=s2_tokens.dtype)
1252
+ actions = self.s1_student.sample(s2_tokens, state, state_history, action_history, steps=steps, task_memory=task_memory)
1253
+
1254
+ if self.residual_adapter is not None and float(self.config.residual_alpha) != 0.0:
1255
+ actions = actions + float(self.config.residual_alpha) * self.residual_adapter(state).to(actions.dtype)
1256
+ if self.config.action_clip > 0:
1257
+ actions = actions.clamp(-self.config.action_clip, self.config.action_clip)
1258
+ return self._apply_temporal_ensemble(actions, enabled=temporal_ensemble)
1259
+
1260
+ def forward_from_s2_tokens(
1261
+ self,
1262
+ s2_tokens: torch.Tensor,
1263
+ state: torch.Tensor,
1264
+ state_history: torch.Tensor | None = None,
1265
+ action_history: torch.Tensor | None = None,
1266
+ target_actions: torch.Tensor | None = None,
1267
+ s1_steps: int | None = None,
1268
+ task_memory: torch.Tensor | None = None,
1269
+ return_dict: bool | None = None,
1270
+ ) -> tuple[torch.Tensor] | Rio2Output:
1271
+ return self.forward(
1272
+ state=state,
1273
+ s2_tokens=s2_tokens,
1274
+ state_history=state_history,
1275
+ action_history=action_history,
1276
+ target_actions=target_actions,
1277
+ s1_steps=s1_steps,
1278
+ task_memory=task_memory,
1279
+ return_dict=return_dict,
1280
+ use_original=False,
1281
+ )
1282
+
1283
+ def forward(
1284
+ self,
1285
+ state: torch.Tensor,
1286
+ s2_tokens: torch.Tensor | None = None,
1287
+ state_history: torch.Tensor | None = None,
1288
+ action_history: torch.Tensor | None = None,
1289
+ target_actions: torch.Tensor | None = None,
1290
+ images: ImageLike | list[ImageLike] | None = None,
1291
+ instruction: str | None = None,
1292
+ refresh_s2: bool = False,
1293
+ s1_steps: int | None = None,
1294
+ task_memory: torch.Tensor | None = None,
1295
+ use_original: bool | None = None,
1296
+ return_dict: bool | None = None,
1297
+ **kwargs,
1298
+ ) -> tuple[torch.Tensor] | Rio2Output:
1299
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1300
+ use_original = self.config.use_original_s1 if use_original is None else use_original
1301
+
1302
+ if refresh_s2:
1303
+ if images is None or instruction is None:
1304
+ raise ValueError("`images` and `instruction` are required when refresh_s2=True.")
1305
+ s2_tokens = self.refresh_s2(images, instruction, force=True)
1306
+ elif s2_tokens is None:
1307
+ s2_tokens = self._s2_cache
1308
+
1309
+ device = next(self.parameters()).device
1310
+ state = state.to(device)
1311
+ state_history = None if state_history is None else state_history.to(device)
1312
+ action_history = None if action_history is None else action_history.to(device)
1313
+
1314
+ # Training path: use cached-token/student path by default because the
1315
+ # original MolmoAct2 action expert is usually frozen and remote-code
1316
+ # signatures may not expose target-action training directly.
1317
+ if target_actions is not None:
1318
+ if s2_tokens is None:
1319
+ raise ValueError("Training requires `s2_tokens` or refresh_s2=True.")
1320
+ s2_tokens = s2_tokens.to(device=device, dtype=state.dtype if state.is_floating_point() else torch.float32)
1321
+ task_memory = None if task_memory is None else task_memory.to(device=device, dtype=s2_tokens.dtype)
1322
+ target_actions = target_actions.to(device=device, dtype=state.dtype if state.is_floating_point() else torch.float32)
1323
+ losses = self.s1_student.training_loss(s2_tokens, state, state_history, action_history, target_actions, task_memory=task_memory)
1324
+ output = Rio2Output(
1325
+ loss=losses["loss"],
1326
+ s2_tokens=s2_tokens,
1327
+ loss_flow_mse=losses["loss_flow_mse"],
1328
+ loss_flow_l1=losses["loss_flow_l1"],
1329
+ loss_diffusion=losses.get("loss_diffusion"),
1330
+ loss_consistency=losses.get("loss_consistency"),
1331
+ loss_smooth=losses["loss_smooth"],
1332
+ loss_jepa=losses.get("loss_jepa"),
1333
+ loss_jepa_prior=losses.get("loss_jepa_prior"),
1334
+ pred_action_latent=losses.get("pred_action_latent"),
1335
+ target_action_latent=losses.get("target_action_latent"),
1336
+ runtime_path="jepa_s1_training" if "loss_jepa" in losses else "student_adapter_training",
1337
+ )
1338
+ return tuple(v for v in output.to_tuple() if v is not None) if not return_dict else output
1339
+
1340
+ if use_original and self.molmoact.base is not None:
1341
+ actions = self.act_fast(state, state_history, action_history, steps=s1_steps, use_original=True)
1342
+ runtime_path = self.molmoact.last_runtime_path
1343
+ tokens = self._s2_cache
1344
+ else:
1345
+ if s2_tokens is None:
1346
+ raise ValueError("Pass `s2_tokens`, call refresh_s2(), or set refresh_s2=True.")
1347
+ s2_tokens = s2_tokens.to(device=device, dtype=state.dtype if state.is_floating_point() else torch.float32)
1348
+ if task_memory is None and self._task_memory_cache is not None:
1349
+ task_memory = self._task_memory_cache
1350
+ task_memory = None if task_memory is None else task_memory.to(device=device, dtype=s2_tokens.dtype)
1351
+ actions = self.s1_student.sample(s2_tokens, state, state_history, action_history, steps=s1_steps, task_memory=task_memory)
1352
+ if self.residual_adapter is not None and float(self.config.residual_alpha) != 0.0:
1353
+ actions = actions + float(self.config.residual_alpha) * self.residual_adapter(state).to(actions.dtype)
1354
+ runtime_path = "student_cached_tokens"
1355
+ tokens = s2_tokens
1356
+
1357
+ output = Rio2Output(actions=actions, s2_tokens=tokens, runtime_path=runtime_path)
1358
+ return (actions, tokens) if not return_dict else output
1359
+
1360
+
1361
+ __all__ = [
1362
+ "Rio2Model",
1363
+ "Rio2PreTrainedModel",
1364
+ ]
processing_rio2.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Inc. team and the Rio2 contributors.
2
+ # Licensed under the Apache License, Version 2.0.
3
+ """Processor for Rio2."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import json
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ from transformers.processing_utils import ProcessorMixin
12
+ from transformers.utils import logging
13
+
14
+
15
+ logger = logging.get_logger(__name__)
16
+
17
+
18
+ class Rio2Processor(ProcessorMixin):
19
+ attributes = []
20
+ optional_attributes = []
21
+
22
+ def __init__(self, base_processor=None, base_model_id: str | None = None, **kwargs):
23
+ self.base_processor = base_processor
24
+ self.base_model_id = base_model_id
25
+ self.chat_template = kwargs.pop("chat_template", None)
26
+
27
+ @classmethod
28
+ def from_base_model_id(cls, base_model_id: str, **kwargs):
29
+ from transformers import AutoProcessor
30
+
31
+ base_processor = AutoProcessor.from_pretrained(base_model_id, trust_remote_code=True, **kwargs)
32
+ return cls(base_processor=base_processor, base_model_id=base_model_id)
33
+
34
+ @classmethod
35
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
36
+ path = Path(pretrained_model_name_or_path)
37
+ base_model_id = kwargs.pop("base_model_id", None)
38
+ load_base_processor = bool(kwargs.pop("load_base_processor", False))
39
+ hub_kwargs = {
40
+ key: kwargs.get(key)
41
+ for key in ["cache_dir", "force_download", "proxies", "token", "revision", "local_files_only", "subfolder"]
42
+ if key in kwargs
43
+ }
44
+ if path.exists():
45
+ cfg_path = path / "processor_config.json"
46
+ model_cfg_path = path / "config.json"
47
+ if cfg_path.exists():
48
+ data = json.loads(cfg_path.read_text(encoding="utf-8"))
49
+ base_model_id = base_model_id or data.get("base_model_id")
50
+ if base_model_id is None and model_cfg_path.exists():
51
+ data = json.loads(model_cfg_path.read_text(encoding="utf-8"))
52
+ base_model_id = data.get("base_model_id")
53
+ else:
54
+ try:
55
+ from transformers.utils import cached_file
56
+
57
+ cfg_file = cached_file(pretrained_model_name_or_path, "processor_config.json", **hub_kwargs)
58
+ if cfg_file:
59
+ data = json.loads(Path(cfg_file).read_text(encoding="utf-8"))
60
+ base_model_id = base_model_id or data.get("base_model_id")
61
+ except Exception as exc:
62
+ logger.debug("Could not load RIO-2 processor config from Hub: %s", exc)
63
+ if base_model_id is None:
64
+ try:
65
+ from transformers.utils import cached_file
66
+
67
+ cfg_file = cached_file(pretrained_model_name_or_path, "config.json", **hub_kwargs)
68
+ if cfg_file:
69
+ data = json.loads(Path(cfg_file).read_text(encoding="utf-8"))
70
+ base_model_id = data.get("base_model_id")
71
+ except Exception as exc:
72
+ logger.debug("Could not load RIO-2 model config from Hub: %s", exc)
73
+
74
+ base_processor = None
75
+ if base_model_id and load_base_processor:
76
+ try:
77
+ from transformers import AutoProcessor
78
+
79
+ trust_remote_code = kwargs.pop("trust_remote_code", True)
80
+ base_processor = AutoProcessor.from_pretrained(base_model_id, trust_remote_code=trust_remote_code, **kwargs)
81
+ except Exception as exc:
82
+ logger.warning("Could not load base processor %s: %s", base_model_id, exc)
83
+ return cls(base_processor=base_processor, base_model_id=base_model_id)
84
+
85
+ def save_pretrained(self, save_directory, **kwargs):
86
+ out = Path(save_directory)
87
+ out.mkdir(parents=True, exist_ok=True)
88
+ data = {
89
+ "processor_class": self.__class__.__name__,
90
+ "base_model_id": self.base_model_id,
91
+ "auto_map": {"AutoProcessor": "processing_rio2.Rio2Processor"},
92
+ }
93
+ (out / "processor_config.json").write_text(json.dumps(data, indent=2) + "\n", encoding="utf-8")
94
+ if self.base_processor is not None and kwargs.pop("save_base_processor", False):
95
+ base_dir = out / "base_processor"
96
+ self.base_processor.save_pretrained(base_dir)
97
+ return [str(out / "processor_config.json")]
98
+
99
+ def __call__(
100
+ self,
101
+ images=None,
102
+ instruction: str | None = None,
103
+ state: Any | None = None,
104
+ state_history: Any | None = None,
105
+ action_history: Any | None = None,
106
+ target_actions: Any | None = None,
107
+ **kwargs,
108
+ ) -> dict[str, Any]:
109
+ out: dict[str, Any] = {}
110
+ if self.base_processor is not None and images is not None and instruction is not None:
111
+ out.update(self.base_processor(images=images, text=instruction, return_tensors="pt", **kwargs))
112
+ else:
113
+ if images is not None:
114
+ out["images"] = images
115
+ if instruction is not None:
116
+ out["instruction"] = instruction
117
+ if state is not None:
118
+ out["state"] = state
119
+ if state_history is not None:
120
+ out["state_history"] = state_history
121
+ if action_history is not None:
122
+ out["action_history"] = action_history
123
+ if target_actions is not None:
124
+ out["target_actions"] = target_actions
125
+ return out
126
+
127
+
128
+ __all__ = ["Rio2Processor"]
processor_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "processor_class": "Rio2Processor",
3
+ "base_model_id": "allenai/MolmoAct2-SO100_101",
4
+ "auto_map": {
5
+ "AutoProcessor": "processing_rio2.Rio2Processor"
6
+ }
7
+ }
rio2_export_manifest.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "repo_mode": "custom_code",
3
+ "single_weight_file": "model.safetensors",
4
+ "config_file": "config.json",
5
+ "custom_code_files": [
6
+ "configuration_rio2.py",
7
+ "modeling_rio2.py",
8
+ "processing_rio2.py"
9
+ ],
10
+ "repo_id": "hoguai/RIO-2"
11
+ }
runtime_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_model_id": "__merged_in_model_safetensors__",
3
+ "local_base": true,
4
+ "single_weight_file": "model.safetensors",
5
+ "s1_expanded": true,
6
+ "s1_moe_finetuned": true,
7
+ "requires_finetune_after_expansion": false
8
+ }