Robotics
Transformers
Safetensors
English
rio2
feature-extraction
Mixture of Experts
diffusion-jepa
custom_code
RIO-2 / modeling_rio2.py
hoguai's picture
Upload 7 files
9a5262f verified
# Copyright 2026 The HuggingFace Inc. team and the Rio2 contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
"""PyTorch RIO-2 model.
Runtime modes:
- `refresh_s2(images, instruction)`: low-frequency context refresh.
- `act_fast(state, ...)`: high-frequency action generation.
- `forward(..., s2_tokens=...)`: cached-token fallback used for tests and
for adapter-only training when MolmoAct2 internals are unavailable.
"""
from __future__ import annotations
import copy
import inspect
import math
import time
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from .configuration_rio2 import Rio2Config
logger = logging.get_logger(__name__)
ImageLike = Image.Image | np.ndarray | torch.Tensor
@dataclass
class Rio2Output(ModelOutput):
    """Output type for RIO-2.

    Every field defaults to ``None``, so a producer only needs to populate
    the entries that are relevant to the path it executed.
    """

    # NOTE(review): the per-field notes below are inferred from the field
    # names only; confirm them against the code that fills this container.

    # Presumably the combined training objective.
    loss: torch.FloatTensor | None = None
    # Presumably the predicted action chunk.
    actions: torch.FloatTensor | None = None
    # Presumably the compact S2 context tokens used for the prediction.
    s2_tokens: torch.FloatTensor | None = None
    # Individual loss terms (flow matching, diffusion, consistency,
    # smoothness, JEPA) -- presumably reported for logging.
    loss_flow_mse: torch.FloatTensor | None = None
    loss_flow_l1: torch.FloatTensor | None = None
    loss_diffusion: torch.FloatTensor | None = None
    loss_consistency: torch.FloatTensor | None = None
    loss_smooth: torch.FloatTensor | None = None
    loss_jepa: torch.FloatTensor | None = None
    loss_jepa_prior: torch.FloatTensor | None = None
    # Predicted / target action latents -- presumably from the JEPA branch.
    pred_action_latent: torch.FloatTensor | None = None
    target_action_latent: torch.FloatTensor | None = None
    # Presumably a label naming which runtime path produced the actions.
    runtime_path: str | None = None
def _torch_dtype_from_string(dtype_name: str) -> torch.dtype:
table = {
"float32": torch.float32,
"fp32": torch.float32,
"float16": torch.float16,
"fp16": torch.float16,
"bfloat16": torch.bfloat16,
"bf16": torch.bfloat16,
}
return table.get(str(dtype_name).lower(), torch.bfloat16)
def _to_pil_list(images: ImageLike | list[ImageLike] | tuple[ImageLike, ...]) -> list[Image.Image]:
    """Convert one image or a batch of images into a list of RGB PIL images.

    Accepted inputs are PIL images, NumPy arrays, torch tensors, and
    lists/tuples of those. 4-D arrays/tensors are split along their first
    axis. Non-``uint8`` data whose maximum is at most 1.5 is treated as being
    in ``[0, 1]`` and scaled by 255; anything else is clipped to ``[0, 255]``.

    Raises:
        TypeError: If ``images`` is none of the supported types.
    """
    if isinstance(images, (list, tuple)):
        return [_to_pil_list(x)[0] for x in images]
    if isinstance(images, Image.Image):
        return [images.convert("RGB")]
    if torch.is_tensor(images):
        x = images.detach().cpu()
        if x.ndim == 4:
            return [_to_pil_list(frame)[0] for frame in x]
        # Tensors are taken to be channel-first when the leading axis looks
        # like a channel count.
        if x.ndim == 3 and x.shape[0] in (1, 3):
            x = x.permute(1, 2, 0)
        # NumPy has no bfloat16, so ``.numpy()`` would raise on such tensors.
        if x.dtype == torch.bfloat16:
            x = x.float()
        # Reuse the ndarray path so grayscale, single-channel and RGBA
        # tensors receive the same normalisation as arrays do.
        return _to_pil_list(x.numpy())
    if isinstance(images, np.ndarray):
        arr = images
        if arr.ndim == 4:
            return [_to_pil_list(a)[0] for a in arr]
        # Channel-first layout -> channel-last.
        if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
            arr = np.transpose(arr, (1, 2, 0))
        # Grayscale (H, W) or (H, W, 1) -> three identical channels.
        if arr.ndim == 2:
            arr = np.repeat(arr[..., None], 3, axis=-1)
        if arr.ndim == 3 and arr.shape[-1] == 1:
            arr = np.repeat(arr, 3, axis=-1)
        # Drop the alpha channel.
        if arr.ndim == 3 and arr.shape[-1] == 4:
            arr = arr[..., :3]
        if arr.dtype != np.uint8:
            if arr.max() <= 1.5:
                arr = (np.clip(arr, 0, 1) * 255).astype(np.uint8)
            else:
                arr = np.clip(arr, 0, 255).astype(np.uint8)
        return [Image.fromarray(arr).convert("RGB")]
    raise TypeError(f"Unsupported image type: {type(images)}")
def _move_to_device(batch: Any, device: torch.device, dtype: torch.dtype | None = None) -> Any:
if torch.is_tensor(batch):
if batch.is_floating_point() and dtype is not None:
return batch.to(device=device, dtype=dtype)
return batch.to(device=device)
if isinstance(batch, dict):
return {k: _move_to_device(v, device, dtype) for k, v in batch.items()}
if isinstance(batch, (list, tuple)):
return type(batch)(_move_to_device(v, device, dtype) for v in batch)
return batch
def _first_existing_attr(obj: Any, names: Iterable[str]) -> Any | None:
for name in names:
cur = obj
ok = True
for part in name.split("."):
if hasattr(cur, part):
cur = getattr(cur, part)
else:
ok = False
break
if ok:
return cur
return None
def _safe_signature_accepts(fn: Any, name: str) -> bool:
try:
sig = inspect.signature(fn)
except (TypeError, ValueError):
return True
if name in sig.parameters:
return True
return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())
class Rio2RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalise in float32 and return a tensor in the input's dtype."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(dim=-1, keepdim=True)
        normalised = upcast * torch.rsqrt(mean_square + self.eps)
        scaled = self.weight * normalised
        return scaled.to(original_dtype)
class Rio2SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal timestep features followed by a two-layer SiLU MLP."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(nn.Linear(dim, dim * 2), nn.SiLU(), nn.Linear(dim * 2, dim))

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Embed a scalar or 1-D batch of timesteps into ``(batch, dim)``."""
        # A 0-d timestep is promoted to a batch of one.
        if timesteps.ndim == 0:
            timesteps = timesteps[None]
        half = self.dim // 2
        log_step = math.log(10000.0) / max(half - 1, 1)
        positions = torch.arange(half, device=timesteps.device, dtype=torch.float32)
        freqs = torch.exp(positions * -log_step)
        angles = timesteps.float()[:, None] * freqs[None]
        features = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        # An odd ``dim`` leaves one slot uncovered by sin/cos; zero-fill it.
        missing = self.dim - features.shape[-1]
        if missing > 0:
            features = F.pad(features, (0, missing))
        mlp_dtype = self.mlp[0].weight.dtype
        return self.mlp(features.to(dtype=mlp_dtype))
class Rio2S1MoEResidualExpert(nn.Module):
    """Two-layer SiLU MLP mapping a context vector to a flat action residual.

    The output layer starts at zero, so a freshly built expert contributes
    nothing.
    """

    def __init__(self, width: int, hidden_dim: int, flat_action_dim: int):
        super().__init__()
        hidden_layer = nn.Linear(width, hidden_dim)
        output_layer = nn.Linear(hidden_dim, flat_action_dim)
        self.net = nn.Sequential(hidden_layer, nn.SiLU(), output_layer)
        for tensor in (output_layer.weight, output_layer.bias):
            nn.init.zeros_(tensor)

    def forward(self, context: torch.Tensor) -> torch.Tensor:
        """Return the residual predicted for ``context``."""
        return self.net(context)
class Rio2S1MoEResidualBank(nn.Module):
    """Top-k routed bank of residual experts producing an action-chunk residual."""

    def __init__(self, config: Rio2Config, width: int):
        super().__init__()
        self.config = config
        self.flat_action_dim = int(config.action_horizon * config.action_dim)
        self.num_experts = int(config.s1_moe_num_experts)
        # Keep the requested top-k inside [1, num_experts].
        requested_k = int(config.s1_moe_top_k)
        self.top_k = max(1, min(requested_k, self.num_experts))
        hidden_dim = int(config.s1_moe_expert_hidden_dim)
        self.router = nn.Linear(width, self.num_experts)
        self.experts = nn.ModuleList(
            [Rio2S1MoEResidualExpert(width, hidden_dim, self.flat_action_dim) for _ in range(self.num_experts)]
        )

    def forward(self, context: torch.Tensor) -> torch.Tensor:
        """Mix the top-k experts per row.

        The softmax is taken over the selected router logits only. The result
        is reshaped to ``(batch, action_horizon, action_dim)``.
        """
        batch = context.shape[0]
        top_logits, top_ids = torch.topk(self.router(context), k=self.top_k, dim=-1)
        gates = torch.softmax(top_logits, dim=-1).to(dtype=context.dtype)
        out = context.new_zeros(batch, self.flat_action_dim)
        for slot in range(self.top_k):
            chosen = top_ids[:, slot]
            gate = gates[:, slot]
            for expert_id, expert in enumerate(self.experts):
                rows = chosen == expert_id
                # Skip experts that no row selected in this slot.
                if bool(rows.any()):
                    out[rows] = out[rows] + gate[rows, None] * expert(context[rows])
        return out.view(batch, self.config.action_horizon, self.config.action_dim)
class Rio2S2ContextCompressor(nn.Module):
    """Fallback compressor for cached-token training.

    Weight-preserved inference prefers the original MolmoAct2 S2/S1 bridge.
    This compressor remains useful for small adapter training, tests, and for
    base versions whose action expert cannot be split cleanly.

    It pools a variable-length context into ``config.s2_token_count`` tokens
    of width ``config.s2_width`` using learned queries, then refines them
    with a two-layer transformer encoder and an RMS norm.
    """

    def __init__(self, config: Rio2Config):
        super().__init__()
        self.config = config
        self.in_proj = nn.Linear(config.s2_input_width, config.s2_width)
        self.query = nn.Parameter(torch.randn(config.s2_token_count, config.s2_width) / math.sqrt(config.s2_width))
        # Aim for roughly 64 channels per head (at most 8 heads), then step
        # down until the head count divides the width: multi-head attention
        # rejects widths that are not a multiple of the head count.
        nhead = max(1, min(8, config.s2_width // 64))
        while config.s2_width % nhead != 0:
            nhead -= 1
        layer = nn.TransformerEncoderLayer(
            d_model=config.s2_width,
            nhead=nhead,
            dim_feedforward=config.s2_width * 4,
            dropout=0.0,
            batch_first=True,
            norm_first=True,
            activation="gelu",
        )
        self.refiner = nn.TransformerEncoder(layer, num_layers=2)
        self.norm = Rio2RMSNorm(config.s2_width)

    def forward(self, context: torch.Tensor) -> torch.Tensor:
        """Compress ``context`` into a fixed number of tokens.

        Args:
            context: ``(seq, s2_input_width)`` or
                ``(batch, seq, s2_input_width)``; a 2-D input is treated as a
                batch of one.

        Returns:
            Tensor of shape ``(batch, s2_token_count, s2_width)``.

        Raises:
            ValueError: If the last dimension is not ``config.s2_input_width``.
        """
        if context.ndim == 2:
            context = context.unsqueeze(0)
        if context.shape[-1] != self.config.s2_input_width:
            raise ValueError(
                f"S2 context width mismatch: got {context.shape[-1]}, expected {self.config.s2_input_width}."
            )
        hidden_states = self.in_proj(context)
        query = self.query.unsqueeze(0).expand(hidden_states.shape[0], -1, -1)
        # Single-head scaled dot-product pooling of the sequence onto the
        # learned queries.
        scores = (query @ hidden_states.transpose(-1, -2)) / math.sqrt(hidden_states.shape[-1])
        attn = torch.softmax(scores, dim=-1)
        tokens = attn @ hidden_states
        tokens = self.refiner(tokens)
        return self.norm(tokens)
class Rio2MolmoAct2Core(nn.Module):
    """Weight-preserved wrapper around `allenai/MolmoAct2-SO100_101`.

    The original MolmoAct2 object is loaded once and kept as the source of truth
    for both S2 and S1. `refresh_s2()` extracts cache/context when possible;
    `act_original()` first tries a split action-expert call and falls back to
    `base.predict_action()` for exact original behavior.
    """

    # Attribute paths probed, in order, on the loaded base model to find its
    # VLM (S2) sub-module and its action expert (S1) sub-module.
    VLM_CANDIDATES = (
        "vlm",
        "language_model",
        "molmo",
        "backbone",
        "model",
        "text_model",
    )
    ACTION_CANDIDATES = (
        "action_expert",
        "flow_head",
        "action_head",
        "continuous_action_expert",
        "flow_matching_head",
        "policy_head",
        "robot_action_head",
    )

    def __init__(self, config: Rio2Config):
        super().__init__()
        self.config = config
        # Populated by ``load_base``.
        self.base = None
        self.processor = None
        self.s2_module = None
        self.s1_module = None
        self.compressor = Rio2S2ContextCompressor(config)
        # State remembered from the most recent ``refresh_s2`` call.
        self.last_pil_images: list[Image.Image] | None = None
        self.last_instruction: str | None = None
        self.last_base_outputs: Any | None = None
        self.last_s2_cache: Any | None = None
        self.last_compact_tokens: torch.Tensor | None = None
        self.last_refresh_time: float = 0.0
        self.last_runtime_path: str = "uninitialized"

    @property
    def base_device(self) -> torch.device:
        """Device of the base model, or of the compressor while no base is loaded."""
        if self.base is None:
            return next(self.compressor.parameters()).device
        return next(self.base.parameters()).device

    def load_base(self, device: str | torch.device | None = None, device_map: str | None = None):
        """Load the processor and base model named by ``config.base_model_id``.

        Args:
            device: Device to move the model to; only used when
                ``device_map`` is not given.
            device_map: Forwarded to ``from_pretrained`` when not ``None``.

        Returns:
            ``self``.
        """
        from transformers import AutoModelForImageTextToText, AutoProcessor

        dtype = _torch_dtype_from_string(self.config.torch_dtype)
        self.processor = AutoProcessor.from_pretrained(
            self.config.base_model_id,
            trust_remote_code=self.config.trust_remote_code,
        )
        kwargs = {"trust_remote_code": self.config.trust_remote_code, "dtype": dtype}
        if device_map is not None:
            kwargs["device_map"] = device_map
        try:
            self.base = AutoModelForImageTextToText.from_pretrained(self.config.base_model_id, **kwargs)
        except TypeError:
            # Retry with the ``torch_dtype`` keyword when ``dtype`` is rejected.
            kwargs.pop("dtype", None)
            kwargs["torch_dtype"] = dtype
            self.base = AutoModelForImageTextToText.from_pretrained(self.config.base_model_id, **kwargs)
        if device is not None and device_map is None:
            self.base.to(device)
        self.base.eval()
        self.s2_module = _first_existing_attr(self.base, self.VLM_CANDIDATES)
        self.s1_module = _first_existing_attr(self.base, self.ACTION_CANDIDATES)
        if self.s2_module is None:
            logger.warning("RIO-2 could not locate a named MolmoAct2 S2/VLM module; full base forward will be used.")
        if self.s1_module is None:
            logger.warning("RIO-2 could not locate a named MolmoAct2 action expert; predict_action fallback will be used.")
        return self

    def freeze_base(self):
        """Put the base model in eval mode and disable gradients for all of its parameters."""
        if self.base is not None:
            self.base.eval()
            for param in self.base.parameters():
                param.requires_grad = False

    def unfreeze_action_expert(self):
        """Enable gradients on the located action expert.

        Returns:
            Number of scalar parameters made trainable (0 when no action
            expert was located).
        """
        if self.s1_module is None:
            return 0
        count = 0
        for param in self.s1_module.parameters():
            param.requires_grad = True
            count += param.numel()
        return count

    def unfreeze_adapters_only(self):
        """Freeze the base model and leave only the compressor trainable."""
        self.freeze_base()
        for param in self.compressor.parameters():
            param.requires_grad = True

    def _extract_sequence_context(self, outputs: Any) -> torch.Tensor | None:
        """Pull a sequence-of-features tensor out of a base-model output.

        Preference order: last entry of ``hidden_states``,
        ``last_hidden_state`` (as attributes, then as dict keys), and finally
        a summary built from ``past_key_values``. Returns ``None`` when
        nothing usable is found.
        """
        if outputs is None:
            return None
        if hasattr(outputs, "hidden_states") and outputs.hidden_states is not None:
            return outputs.hidden_states[-1]
        if hasattr(outputs, "last_hidden_state") and outputs.last_hidden_state is not None:
            return outputs.last_hidden_state
        if isinstance(outputs, dict):
            if outputs.get("hidden_states") is not None:
                return outputs["hidden_states"][-1]
            if outputs.get("last_hidden_state") is not None:
                return outputs["last_hidden_state"]
        if hasattr(outputs, "past_key_values") and outputs.past_key_values is not None:
            chunks = []
            for layer in outputs.past_key_values:
                if isinstance(layer, (tuple, list)) and len(layer) >= 2:
                    key, value = layer[0], layer[1]
                    # NOTE(review): averages over the third- and second-to-last
                    # axes -- presumably heads and sequence positions for a
                    # (batch, heads, seq, head_dim) cache; confirm the layout.
                    chunks.append(key.float().mean(dim=(-3, -2)))
                    chunks.append(value.float().mean(dim=(-3, -2)))
            if chunks:
                # One pseudo-token per key/value tensor.
                return torch.stack(chunks, dim=1).to(dtype=chunks[0].dtype)
        return None

    def _extract_cache(self, outputs: Any) -> Any:
        """Return the first cache-like field of ``outputs``, or ``outputs`` itself when none is set."""
        if outputs is None:
            return None
        for name in ("past_key_values", "kv_cache", "cache", "action_cache", "vlm_cache"):
            if hasattr(outputs, name) and getattr(outputs, name) is not None:
                return getattr(outputs, name)
            if isinstance(outputs, dict) and outputs.get(name) is not None:
                return outputs[name]
        return outputs

    def _zero_compact_tokens(self) -> torch.Tensor:
        """Build the all-zero ``(1, s2_token_count, s2_width)`` placeholder tokens."""
        reference = next(self.compressor.parameters())
        return torch.zeros(
            1,
            self.config.s2_token_count,
            self.config.s2_width,
            device=reference.device,
            dtype=reference.dtype,
        )

    @torch.no_grad()
    def refresh_s2(self, images: ImageLike | list[ImageLike], instruction: str, force: bool = False) -> torch.Tensor:
        """Run the base model on ``images``/``instruction`` and cache compact tokens.

        The previous tokens are returned without recomputation when ``force``
        is false, the instruction is unchanged and the cache is younger than
        ``config.max_s2_cache_age_s``. The images are not part of that check.

        Returns:
            Compact tokens. These are all zeros when no sequence context can
            be extracted or the compressor fails.

        Raises:
            RuntimeError: If the base model or processor has not been loaded.
        """
        if self.base is None or self.processor is None:
            raise RuntimeError("MolmoAct2 base is not loaded. Call model.load_s2_base() first.")
        age = time.time() - self.last_refresh_time
        if (
            not force
            and self.last_compact_tokens is not None
            and self.last_instruction == instruction
            and age < self.config.max_s2_cache_age_s
        ):
            return self.last_compact_tokens
        pil_images = _to_pil_list(images)
        inputs = self.processor(images=pil_images, text=instruction, return_tensors="pt")
        inputs = _move_to_device(inputs, self.base_device, _torch_dtype_from_string(self.config.torch_dtype))
        try:
            outputs = self.base(**inputs, use_cache=True, output_hidden_states=True, return_dict=True)
        except TypeError:
            # The base forward may not accept the cache/hidden-state flags.
            outputs = self.base(**inputs, return_dict=True)
        self.last_base_outputs = outputs
        self.last_s2_cache = self._extract_cache(outputs)
        self.last_pil_images = pil_images
        self.last_instruction = instruction
        self.last_refresh_time = time.time()
        sequence_context = self._extract_sequence_context(outputs)
        compact_tokens = None
        if sequence_context is not None:
            reference = next(self.compressor.parameters())
            sequence_context = sequence_context.to(device=reference.device, dtype=reference.dtype)
            try:
                compact_tokens = self.compressor(sequence_context).detach()
            except Exception as exc:
                # Deliberate best-effort: a width mismatch (or similar) must
                # not break the refresh; zero tokens are used instead.
                logger.warning("RIO-2 compact-token compression failed: %s", exc)
        if compact_tokens is None:
            compact_tokens = self._zero_compact_tokens()
        self.last_compact_tokens = compact_tokens
        return self.last_compact_tokens

    def _try_split_action_expert(
        self,
        state: torch.Tensor,
        state_history: torch.Tensor | None,
        action_history: torch.Tensor | None,
        num_steps: int,
    ) -> torch.Tensor | None:
        """Try to call the located action expert directly.

        The module itself and a few conventionally named methods are tried in
        turn. Each is given whichever of the known keyword arguments its
        signature accepts. The first call that yields usable actions wins.

        Returns:
            The actions, or ``None`` when the split path is disabled,
            unavailable, or every attempt failed.
        """
        if not self.config.prefer_split_action_expert or self.s1_module is None:
            return None
        candidates = [self.s1_module]
        for method_name in ("predict_action", "sample", "generate_actions", "forward"):
            if hasattr(self.s1_module, method_name):
                candidates.append(getattr(self.s1_module, method_name))
        # Keyword names a callable might use for each piece of information.
        offered = {
            "state": state,
            "states": state,
            "vlm_kv_cache": self.last_s2_cache,
            "past_key_values": self.last_s2_cache,
            "s2_cache": self.last_s2_cache,
            "state_history": state_history,
            "action_history": action_history,
            "num_steps": num_steps,
            "num_flow_steps": num_steps,
        }
        for fn in candidates:
            try:
                kwargs = {name: value for name, value in offered.items() if _safe_signature_accepts(fn, name)}
                out = fn(**kwargs) if kwargs else fn(state)
                actions = self._coerce_actions(out, state)
                if actions is not None:
                    self.last_runtime_path = "split_original_action_expert"
                    return actions
            except Exception as exc:
                # Probing unknown callables: failures are expected, so they
                # are logged and the next candidate is tried.
                logger.debug("RIO-2 split action expert attempt failed for %s: %s", fn, exc)
        return None

    def _coerce_actions(self, out: Any, state: torch.Tensor) -> torch.Tensor | None:
        """Normalise an action-expert result into a batched tensor.

        Accepts a tensor, an object with an ``actions`` attribute, or a dict
        with an ``"actions"`` entry; anything else yields ``None``. 2-D
        results gain a leading batch axis. The result lives on ``state``'s
        device and uses ``state``'s dtype (float32 for non-float states).
        """
        if out is None:
            return None
        if torch.is_tensor(out):
            actions = out
        elif hasattr(out, "actions"):
            actions = torch.as_tensor(out.actions, device=state.device)
        elif isinstance(out, dict) and out.get("actions") is not None:
            actions = torch.as_tensor(out["actions"], device=state.device)
        else:
            return None
        if actions.ndim == 2:
            actions = actions.unsqueeze(0)
        return actions.to(device=state.device, dtype=state.dtype if state.is_floating_point() else torch.float32)

    @torch.no_grad()
    def predict_action_fallback(self, state: torch.Tensor, num_steps: int) -> torch.Tensor:
        """Call ``base.predict_action`` using the images/instruction of the last refresh.

        Raises:
            RuntimeError: If the base is not loaded, ``refresh_s2`` has not
                run yet, the base has no ``predict_action`` method, or the
                call returns nothing that can be read as actions.
        """
        if self.base is None or self.processor is None:
            raise RuntimeError("MolmoAct2 base is not loaded.")
        if self.last_pil_images is None or self.last_instruction is None:
            raise RuntimeError("S2 cache is empty. Call refresh_s2(images, instruction) first.")
        if not hasattr(self.base, "predict_action"):
            raise RuntimeError("MolmoAct2 base has no predict_action method and split action expert was unavailable.")
        state_np = state.detach().float().cpu().numpy()
        out = self.base.predict_action(
            processor=self.processor,
            images=self.last_pil_images,
            task=self.last_instruction,
            state=state_np,
            norm_tag=self.config.norm_tag,
            action_mode=self.config.action_mode,
            num_steps=num_steps,
        )
        # Same normalisation as the split path, so tensor and dict results
        # are accepted as well as objects exposing ``.actions``.
        actions = self._coerce_actions(out, state)
        if actions is None:
            raise RuntimeError("MolmoAct2 predict_action returned no usable actions.")
        self.last_runtime_path = "predict_action_fallback_exact"
        return actions

    @torch.no_grad()
    def act_original(
        self,
        state: torch.Tensor,
        state_history: torch.Tensor | None = None,
        action_history: torch.Tensor | None = None,
        num_steps: int | None = None,
    ) -> torch.Tensor:
        """Produce actions with the original model: split expert first, then ``predict_action``.

        ``num_steps`` falls back to ``config.molmoact_num_steps`` when it is
        ``None`` or 0.

        Raises:
            RuntimeError: If the split path yields nothing and
                ``config.fallback_to_predict_action`` is false.
        """
        steps = int(num_steps or self.config.molmoact_num_steps)
        split_actions = self._try_split_action_expert(state, state_history, action_history, steps)
        if split_actions is not None:
            return split_actions
        if self.config.fallback_to_predict_action:
            return self.predict_action_fallback(state, steps)
        raise RuntimeError("No callable original S1/action path was found and fallback_to_predict_action=False.")
class Rio2FastS1FlowActionExpert(nn.Module):
"""Small fallback S1 for cached-token training.
In weight-preserved RIO-2, this is not the preferred runtime path. It remains
as an adapter/student fallback and for upstream tests without downloading
MolmoAct2.
"""
def __init__(self, config: Rio2Config):
    """Build the token projections, transformer trunk, output heads and JEPA modules."""
    super().__init__()
    self.config = config
    width = config.s1_width
    s2_width = config.s2_width
    state_dim = config.state_dim
    action_dim = config.action_dim

    def two_layer(in_dim, mid_dim, out_dim):
        # Linear -> SiLU -> Linear, the MLP shape shared by every head below.
        return nn.Sequential(nn.Linear(in_dim, mid_dim), nn.SiLU(), nn.Linear(mid_dim, out_dim))

    def scaled_randn(rows):
        return nn.Parameter(torch.randn(rows, width) / math.sqrt(width))

    # Per-stream projections into the trunk width.
    self.s2_proj = nn.Linear(s2_width, width)
    self.state_proj = nn.Linear(state_dim, width)
    self.state_hist_proj = nn.Linear(state_dim, width)
    self.action_hist_proj = nn.Linear(action_dim, width)
    self.noisy_action_proj = nn.Linear(action_dim, width)
    self.time_emb = Rio2SinusoidalTimeEmbedding(width)
    self.type_emb = scaled_randn(5)
    self.memory_proj = nn.Linear(s2_width, width)
    self.memory_type_emb = scaled_randn(1)
    self.memory_gate = nn.Parameter(torch.tensor(-2.0))

    # Transformer trunk.
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=width,
        nhead=config.s1_heads,
        dim_feedforward=width * 4,
        dropout=config.s1_dropout,
        batch_first=True,
        norm_first=True,
        activation="gelu",
    )
    self.blocks = nn.TransformerEncoder(encoder_layer, num_layers=config.s1_layers)
    self.norm = Rio2RMSNorm(width)
    self.action_head = two_layer(width, width, action_dim)
    self.noise_head = two_layer(width, width, action_dim)

    # JEPA latent predictor and action encoders.
    hidden = int(config.jepa_hidden_dim)
    latent = int(config.jepa_latent_dim)
    self.jepa_s2_proj = nn.Linear(s2_width, hidden)
    self.jepa_memory_proj = nn.Linear(s2_width, hidden)
    self.jepa_state_proj = nn.Linear(state_dim, hidden)
    self.jepa_action_hist_proj = nn.Linear(action_dim, hidden)
    self.jepa_norm = Rio2RMSNorm(hidden)
    self.jepa_predictor = two_layer(hidden, hidden, latent)
    flat_action_dim = config.action_horizon * config.action_dim
    self.action_encoder = two_layer(flat_action_dim, hidden, latent)
    # The target encoder starts as a frozen copy of the online encoder.
    self.target_action_encoder = copy.deepcopy(self.action_encoder)
    for frozen in self.target_action_encoder.parameters():
        frozen.requires_grad = False
    self.jepa_to_action_prior = two_layer(latent, hidden, flat_action_dim)
    self.consistency_head = two_layer(latent, hidden, flat_action_dim)
    self.jepa_condition_proj = nn.Linear(latent, width)

    # Zero-initialise the prior's output layer and the latent conditioning.
    prior_out = self.jepa_to_action_prior[-1]
    for tensor in (
        prior_out.weight,
        prior_out.bias,
        self.jepa_condition_proj.weight,
        self.jepa_condition_proj.bias,
    ):
        nn.init.zeros_(tensor)

    if bool(config.enable_s1_moe):
        self.moe_residual = Rio2S1MoEResidualBank(config, width)
    else:
        self.moe_residual = None
def default_task_memory_from_s2(self, s2_tokens):
if s2_tokens.ndim == 2:
s2_tokens = s2_tokens.unsqueeze(0)
batch_size, token_count, width = s2_tokens.shape
slots = max(1, int(self.config.task_memory_slots))
if token_count >= slots:
return s2_tokens[:, :slots]
pad_value = s2_tokens.mean(dim=1, keepdim=True).expand(batch_size, slots - token_count, width)
return torch.cat([s2_tokens, pad_value], dim=1)
def _prepare_task_memory(self, task_memory, s2_tokens, batch_size, device, dtype):
if not bool(self.config.task_memory_enabled):
return None
if task_memory is None:
task_memory = self.default_task_memory_from_s2(s2_tokens)
if task_memory.ndim == 2:
task_memory = task_memory.unsqueeze(0)
task_memory = task_memory.to(device=device, dtype=dtype)
if task_memory.shape[0] == 1 and batch_size > 1:
task_memory = task_memory.expand(batch_size, -1, -1)
elif task_memory.shape[0] != batch_size:
task_memory = task_memory[:1].expand(batch_size, -1, -1)
slots = max(1, int(self.config.task_memory_slots))
if task_memory.shape[1] < slots:
pad = task_memory.mean(dim=1, keepdim=True).expand(batch_size, slots - task_memory.shape[1], task_memory.shape[2])
task_memory = torch.cat([task_memory, pad], dim=1)
return task_memory[:, :slots]
def _prepare_hist(self, values, length, dim, batch_size, device, dtype):
if values is None:
return torch.zeros(batch_size, length, dim, device=device, dtype=dtype)
if values.ndim == 2:
values = values.unsqueeze(0)
values = values.to(device=device, dtype=dtype)
if values.shape[1] < length:
pad = torch.zeros(values.shape[0], length - values.shape[1], values.shape[2], device=device, dtype=dtype)
values = torch.cat([pad, values], dim=1)
return values[:, -length:]
def _decode(self, s2_tokens, state, state_history, action_history, noisy_actions, timesteps, head, jepa_latent=None, task_memory=None):
if state.ndim == 1:
state = state.unsqueeze(0)
if noisy_actions.ndim == 2:
noisy_actions = noisy_actions.unsqueeze(0)
if s2_tokens.ndim == 2:
s2_tokens = s2_tokens.unsqueeze(0)
batch_size = state.shape[0]
device = state.device
dtype = state.dtype if state.is_floating_point() else torch.float32
state = state.to(device=device, dtype=dtype)
noisy_actions = noisy_actions.to(device=device, dtype=dtype)
s2_tokens = s2_tokens.to(device=device, dtype=dtype)
state_history = self._prepare_hist(state_history, self.config.state_history_len, self.config.state_dim, batch_size, device, dtype)
action_history = self._prepare_hist(action_history, self.config.action_history_len, self.config.action_dim, batch_size, device, dtype)
task_memory = self._prepare_task_memory(task_memory, s2_tokens, batch_size, device, dtype)
s2_tok = self.s2_proj(s2_tokens) + self.type_emb[0]
token_chunks = [s2_tok]
if task_memory is not None:
gate = torch.sigmoid(self.memory_gate).to(dtype=s2_tok.dtype)
mem_tok = gate * float(self.config.task_memory_alpha) * self.memory_proj(task_memory) + self.memory_type_emb
token_chunks.append(mem_tok)
state_tok = self.state_proj(state).unsqueeze(1) + self.type_emb[1]
state_hist_tok = self.state_hist_proj(state_history) + self.type_emb[2]
action_hist_tok = self.action_hist_proj(action_history) + self.type_emb[3]
action_tok = self.noisy_action_proj(noisy_actions) + self.type_emb[4]
action_tok = action_tok + self.time_emb(timesteps).unsqueeze(1)
if jepa_latent is not None and bool(self.config.enable_jepa_diffusion):
cond = self.jepa_condition_proj(jepa_latent.to(device=device, dtype=dtype)).unsqueeze(1)
action_tok = action_tok + float(self.config.jepa_condition_alpha) * cond.to(dtype=action_tok.dtype)
token_chunks.extend([state_tok, state_hist_tok, action_hist_tok, action_tok])
tokens = torch.cat(token_chunks, dim=1)
tokens = self.blocks(tokens)
tokens = self.norm(tokens)
return head(tokens[:, -self.config.action_horizon :])
def velocity(self, s2_tokens, state, state_history, action_history, noisy_actions, timesteps, jepa_latent=None, task_memory=None):
return self._decode(s2_tokens, state, state_history, action_history, noisy_actions, timesteps, self.action_head, jepa_latent, task_memory)
def diffusion_noise(self, s2_tokens, state, state_history, action_history, noisy_actions, timesteps, jepa_latent=None, task_memory=None):
return self._decode(s2_tokens, state, state_history, action_history, noisy_actions, timesteps, self.noise_head, jepa_latent, task_memory)
def predict_action_latent(self, s2_tokens, state, action_history=None, task_memory=None):
if state.ndim == 1:
state = state.unsqueeze(0)
if s2_tokens.ndim == 2:
s2_tokens = s2_tokens.unsqueeze(0)
batch_size = state.shape[0]
device = state.device
dtype = state.dtype if state.is_floating_point() else torch.float32
action_history = self._prepare_hist(action_history, self.config.action_history_len, self.config.action_dim, batch_size, device, dtype)
s2_summary = s2_tokens.to(device=device, dtype=dtype).mean(dim=1)
task_memory = self._prepare_task_memory(task_memory, s2_tokens, batch_size, device, dtype)
memory_summary = torch.zeros_like(s2_summary) if task_memory is None else task_memory.mean(dim=1)
hist_summary = action_history.mean(dim=1)
memory_scale = torch.sigmoid(self.memory_gate).to(dtype=s2_summary.dtype) * float(self.config.task_memory_alpha)
context = (
self.jepa_s2_proj(s2_summary)
+ memory_scale * self.jepa_memory_proj(memory_summary)
+ self.jepa_state_proj(state.to(dtype=dtype))
+ self.jepa_action_hist_proj(hist_summary)
)
return self.jepa_predictor(self.jepa_norm(context))
def moe_action_residual(self, s2_tokens, state, action_history=None, task_memory=None):
if self.moe_residual is None:
return None
if state.ndim == 1:
state = state.unsqueeze(0)
if s2_tokens.ndim == 2:
s2_tokens = s2_tokens.unsqueeze(0)
batch_size = state.shape[0]
device = state.device
dtype = state.dtype if state.is_floating_point() else torch.float32
action_history = self._prepare_hist(action_history, self.config.action_history_len, self.config.action_dim, batch_size, device, dtype)
s2_tokens = s2_tokens.to(device=device, dtype=dtype)
task_memory = self._prepare_task_memory(task_memory, s2_tokens, batch_size, device, dtype)
context = (
self.s2_proj(s2_tokens).mean(dim=1)
+ self.state_proj(state.to(dtype=dtype))
+ self.action_hist_proj(action_history).mean(dim=1)
)
if task_memory is not None:
gate = torch.sigmoid(self.memory_gate).to(dtype=context.dtype)
context = context + gate * float(self.config.task_memory_alpha) * self.memory_proj(task_memory).mean(dim=1)
return self.moe_residual(context).to(dtype=dtype)
def encode_action_latent(self, actions, target=False):
if actions.ndim == 2:
actions = actions.unsqueeze(0)
flat = actions.reshape(actions.shape[0], -1)
encoder = self.target_action_encoder if target else self.action_encoder
return F.normalize(encoder(flat).float(), dim=-1).to(dtype=flat.dtype)
def action_prior_from_latent(self, latent, dtype):
prior = self.jepa_to_action_prior(latent).view(latent.shape[0], self.config.action_horizon, self.config.action_dim)
return prior.to(dtype=dtype)
def consistency_action_from_latent(self, latent, dtype):
actions = self.consistency_head(latent).view(latent.shape[0], self.config.action_horizon, self.config.action_dim)
return actions.to(dtype=dtype)
def jepa_diffusion_sample(self, s2_tokens, state, state_history=None, action_history=None, steps=None, task_memory=None):
if state.ndim == 1:
state = state.unsqueeze(0)
batch_size = state.shape[0]
dtype = state.dtype if state.is_floating_point() else torch.float32
jepa_latent = self.predict_action_latent(s2_tokens, state, action_history, task_memory)
x = self.consistency_action_from_latent(jepa_latent, dtype)
if float(self.config.jepa_action_prior_alpha) != 0.0:
x = x + float(self.config.jepa_action_prior_alpha) * self.action_prior_from_latent(jepa_latent, dtype)
moe_residual = self.moe_action_residual(s2_tokens, state, action_history, task_memory)
if moe_residual is not None:
x = x + float(self.config.s1_moe_residual_scale) * moe_residual
denoise_steps = int(steps if steps is not None else self.config.diffusion_inference_steps)
denoise_steps = max(0, denoise_steps)
if denoise_steps > 0:
x = x + torch.randn_like(x) * float(self.config.s1_sampling_noise_scale) / float(denoise_steps + 1)
for i in range(denoise_steps):
frac = float(denoise_steps - i) / float(max(denoise_steps, 1))
timesteps = torch.full((batch_size,), frac, device=state.device, dtype=dtype)
eps = self.diffusion_noise(s2_tokens, state, state_history, action_history, x, timesteps, jepa_latent, task_memory)
x = x - eps / float(denoise_steps + 1)
return x
@torch.no_grad()
def update_target_encoder(self, decay=None):
decay = float(self.config.jepa_ema_decay if decay is None else decay)
for online, target in zip(self.action_encoder.parameters(), self.target_action_encoder.parameters()):
target.data.mul_(decay).add_(online.data, alpha=1.0 - decay)
def freeze_target_encoder(self):
for param in self.target_action_encoder.parameters():
param.requires_grad = False
def training_loss(self, s2_tokens, state, state_history, action_history, target_actions, task_memory=None):
    """Compute the weighted S1 training objective for one batch.

    The total combines a flow-matching velocity loss (MSE plus weighted L1)
    and a smoothness penalty with, when ``config.enable_jepa_diffusion`` is
    set and the matching weight is positive, a noise-prediction loss, a JEPA
    latent loss, a JEPA action-prior loss and a consistency loss. Terms that
    are switched off are reported as zero scalars.

    Args:
        s2_tokens: S2 context tokens, forwarded to the prediction heads.
        state: Current state, forwarded to the prediction heads.
        state_history: State history, forwarded to ``velocity`` and
            ``diffusion_noise``.
        action_history: Action history, forwarded to the prediction heads.
        target_actions: Ground-truth action chunk; a 2-D tensor is treated as
            a single sample and gets a leading batch dimension.
        task_memory: Optional task memory, forwarded to the prediction heads.

    Returns:
        Dict with the scalar ``loss`` and its components ``loss_flow_mse``,
        ``loss_flow_l1``, ``loss_diffusion``, ``loss_consistency``,
        ``loss_jepa``, ``loss_jepa_prior`` and ``loss_smooth``.
    """
    cfg = self.config
    if target_actions.ndim == 2:
        target_actions = target_actions.unsqueeze(0)
    batch_size = target_actions.shape[0]
    use_jepa = bool(cfg.enable_jepa_diffusion)
    jepa_latent = self.predict_action_latent(s2_tokens, state, action_history, task_memory) if use_jepa else None

    # Flow matching: regress the constant velocity of the straight path
    # from noise (x0) to the target chunk (x1) at a random time t.
    x0 = torch.randn_like(target_actions)
    x1 = target_actions
    timesteps = torch.rand(batch_size, device=target_actions.device, dtype=target_actions.dtype)
    t = timesteps[:, None, None]
    xt = (1.0 - t) * x0 + t * x1
    target_velocity = x1 - x0
    pred_velocity = self.velocity(s2_tokens, state, state_history, action_history, xt, timesteps, jepa_latent, task_memory)
    loss_flow_mse = F.mse_loss(pred_velocity, target_velocity)
    loss_flow_l1 = F.l1_loss(pred_velocity, target_velocity)

    # Noise prediction on a cosine/sine-mixed noisy target.
    if use_jepa and float(cfg.diffusion_loss_weight) > 0:
        diffusion_t = torch.rand(batch_size, device=target_actions.device, dtype=target_actions.dtype)
        eps = torch.randn_like(target_actions)
        alpha = torch.cos(diffusion_t[:, None, None] * (math.pi / 2.0))
        sigma = torch.sin(diffusion_t[:, None, None] * (math.pi / 2.0))
        noisy = alpha * target_actions + sigma * eps
        pred_eps = self.diffusion_noise(s2_tokens, state, state_history, action_history, noisy, diffusion_t, jepa_latent, task_memory)
        loss_diffusion = F.mse_loss(pred_eps, eps)
    else:
        loss_diffusion = target_actions.new_tensor(0.0)

    # JEPA: match the normalised predicted latent to the target encoder's
    # latent; the target side is computed without gradients.
    if use_jepa and float(cfg.jepa_loss_weight) > 0:
        pred_latent = F.normalize(jepa_latent.float(), dim=-1)
        with torch.no_grad():
            target_latent = self.encode_action_latent(target_actions, target=True).float()
        loss_jepa = F.mse_loss(pred_latent, target_latent)
    else:
        loss_jepa = target_actions.new_tensor(0.0)

    if use_jepa and float(cfg.jepa_action_prior_weight) > 0:
        prior_actions = self.action_prior_from_latent(jepa_latent, target_actions.dtype)
        loss_jepa_prior = F.mse_loss(prior_actions, target_actions)
    else:
        loss_jepa_prior = target_actions.new_tensor(0.0)

    if use_jepa and float(cfg.consistency_loss_weight) > 0:
        consistency_actions = self.consistency_action_from_latent(jepa_latent, target_actions.dtype)
        moe_residual = self.moe_action_residual(s2_tokens, state, action_history, task_memory)
        if moe_residual is not None:
            consistency_actions = consistency_actions + float(cfg.s1_moe_residual_scale) * moe_residual
        loss_consistency = F.mse_loss(consistency_actions, target_actions)
    else:
        loss_consistency = target_actions.new_tensor(0.0)

    # Smoothness is measured on the model's one-step action estimate
    # x1_hat = xt + (1 - t) * v_pred. Measuring it on `target_actions` would
    # give a constant with no gradient, so it could never regularise the policy.
    if target_actions.shape[1] > 1:
        pred_actions = xt + (1.0 - t) * pred_velocity
        loss_smooth = (pred_actions[:, 1:] - pred_actions[:, :-1]).pow(2).mean()
    else:
        loss_smooth = target_actions.new_tensor(0.0)

    loss = (
        cfg.flow_loss_weight * (loss_flow_mse + cfg.action_l1_weight * loss_flow_l1)
        + cfg.smooth_loss_weight * loss_smooth
        + cfg.diffusion_loss_weight * loss_diffusion
        + cfg.consistency_loss_weight * loss_consistency
        + cfg.jepa_loss_weight * loss_jepa.to(loss_flow_mse.dtype)
        + cfg.jepa_action_prior_weight * loss_jepa_prior
    )
    return {
        "loss": loss,
        "loss_flow_mse": loss_flow_mse,
        "loss_flow_l1": loss_flow_l1,
        "loss_diffusion": loss_diffusion,
        "loss_consistency": loss_consistency,
        "loss_jepa": loss_jepa,
        "loss_jepa_prior": loss_jepa_prior,
        "loss_smooth": loss_smooth,
    }
@torch.no_grad()
def sample(self, s2_tokens, state, state_history=None, action_history=None, steps=None, task_memory=None):
    """Generate an action chunk for the given context.

    When ``config.s1_policy_mode == "jepa_diffusion"`` and JEPA diffusion is
    enabled, sampling is delegated to ``jepa_diffusion_sample``. Otherwise
    the chunk is obtained by Euler integration of the velocity field,
    starting from scaled Gaussian noise (optionally shifted by the latent
    action prior and the MoE residual). The result is clamped to
    ``[-action_clip, action_clip]`` when ``config.action_clip > 0``.

    Args:
        s2_tokens: S2 context tokens, forwarded to the prediction heads.
        state: Current state; a 1-D tensor is treated as a batch of one.
        state_history: Optional state history.
        action_history: Optional action history.
        steps: Number of integration / denoising steps. On the flow path a
            falsy value falls back to ``config.flow_inference_steps``.
        task_memory: Optional task memory.

    Returns:
        Action chunk tensor.
    """
    cfg = self.config
    if state.ndim == 1:
        state = state.unsqueeze(0)

    if cfg.s1_policy_mode == "jepa_diffusion" and bool(cfg.enable_jepa_diffusion):
        actions = self.jepa_diffusion_sample(s2_tokens, state, state_history, action_history, steps=steps, task_memory=task_memory)
    else:
        batch_size = state.shape[0]
        num_steps = steps or cfg.flow_inference_steps
        dtype = state.dtype if state.is_floating_point() else torch.float32
        latent = None
        if bool(cfg.enable_jepa_diffusion):
            latent = self.predict_action_latent(s2_tokens, state, action_history, task_memory)

        # Starting point of the integration.
        actions = torch.randn(batch_size, cfg.action_horizon, cfg.action_dim, device=state.device, dtype=dtype)
        actions = actions * float(cfg.s1_sampling_noise_scale)
        if latent is not None and float(cfg.jepa_action_prior_alpha) != 0.0:
            actions = actions + float(cfg.jepa_action_prior_alpha) * self.action_prior_from_latent(latent, dtype)
        residual = self.moe_action_residual(s2_tokens, state, action_history, task_memory)
        if residual is not None:
            actions = actions + float(cfg.s1_moe_residual_scale) * residual

        # Euler steps at t = 0, 1/n, ..., (n-1)/n.
        denominator = max(num_steps, 1)
        for step in range(num_steps):
            timesteps = torch.full((batch_size,), float(step) / denominator, device=state.device, dtype=actions.dtype)
            velocity = self.velocity(s2_tokens, state, state_history, action_history, actions, timesteps, latent, task_memory)
            actions = actions + velocity / float(num_steps)

    if cfg.action_clip > 0:
        actions = actions.clamp(-cfg.action_clip, cfg.action_clip)
    return actions
class Rio2JepaS1ActionExpert(nn.Module):
"""JEPA-style S1 that preserves the online S1 policy weights.
This module does **not** replace the original S1 policy with an unrelated
world model. Instead it wraps the existing fast flow S1 as `online_s1` and
adds a small latent prediction side objective:
- online_s1: action generator; initialized and trained exactly like the
existing RIO-2 S1 path, so old S1 checkpoints can be remapped into it.
- jepa_context_encoder + predictor: predicts future action latent from
S2 tokens, current state, and action history.
- target_action_encoder: EMA target encoder for the future action chunk.
- latent_to_action_delta: optional zero-initialized residual head.
Inference defaults to the online S1 policy. JEPA affects actions only when
`config.use_jepa_action_residual=True` and `config.jepa_action_alpha` is non-zero.
"""
def __init__(self, config: Rio2Config):
    """Build the wrapped online S1 policy and the JEPA side-objective modules.

    Args:
        config: Model configuration. Read here: ``jepa_hidden_dim``,
            ``jepa_latent_dim``, ``jepa_heads``, ``jepa_layers``,
            ``s2_width``, ``state_dim``, ``action_dim``, ``action_horizon``
            and ``s1_dropout``.
    """
    super().__init__()
    self.config = config
    self.online_s1 = Rio2FastS1FlowActionExpert(config)
    hidden = int(config.jepa_hidden_dim)
    latent = int(config.jepa_latent_dim)

    # Context tokens: S2 tokens, current state and action history are each
    # projected to `hidden` and tagged with a learned type embedding.
    self.s2_jepa_proj = nn.Linear(config.s2_width, hidden)
    self.state_jepa_proj = nn.Linear(config.state_dim, hidden)
    self.action_hist_jepa_proj = nn.Linear(config.action_dim, hidden)
    self.type_emb = nn.Parameter(torch.randn(3, hidden) / math.sqrt(hidden))
    layer = nn.TransformerEncoderLayer(
        d_model=hidden,
        nhead=max(1, int(config.jepa_heads)),
        dim_feedforward=hidden * 4,
        dropout=config.s1_dropout,
        batch_first=True,
        norm_first=True,
        activation="gelu",
    )
    self.jepa_context_encoder = nn.TransformerEncoder(layer, num_layers=max(1, int(config.jepa_layers)))
    self.jepa_norm = Rio2RMSNorm(hidden)
    self.jepa_predictor = nn.Sequential(
        nn.Linear(hidden, hidden),
        nn.SiLU(),
        nn.Linear(hidden, latent),
    )

    # Online action encoder and its gradient-free copy used as JEPA target.
    flat_action_dim = config.action_horizon * config.action_dim
    self.action_encoder = nn.Sequential(
        nn.Linear(flat_action_dim, hidden),
        nn.SiLU(),
        nn.Linear(hidden, latent),
    )
    self.target_action_encoder = copy.deepcopy(self.action_encoder)
    for param in self.target_action_encoder.parameters():
        param.requires_grad = False

    # Residual head: zero output layer, so it adds nothing before training.
    self.latent_to_action_delta = nn.Sequential(
        nn.Linear(latent, hidden),
        nn.SiLU(),
        nn.Linear(hidden, flat_action_dim),
    )
    delta_out = self.latent_to_action_delta[-1]
    nn.init.zeros_(delta_out.weight)
    nn.init.zeros_(delta_out.bias)
    # The enclosing model's post_init() runs `_init_weights` over every
    # nn.Linear and would overwrite the zeros with N(0, 0.02) weights.
    # Transformers skips modules flagged as already initialised.
    delta_out._is_hf_initialized = True
def _prepare_action_history(self, action_history, batch_size, device, dtype):
if action_history is None:
return torch.zeros(batch_size, self.config.action_history_len, self.config.action_dim, device=device, dtype=dtype)
if action_history.ndim == 2:
action_history = action_history.unsqueeze(0)
action_history = action_history.to(device=device, dtype=dtype)
if action_history.shape[1] < self.config.action_history_len:
pad = torch.zeros(
action_history.shape[0],
self.config.action_history_len - action_history.shape[1],
action_history.shape[2],
device=device,
dtype=dtype,
)
action_history = torch.cat([pad, action_history], dim=1)
return action_history[:, -self.config.action_history_len :]
def encode_context(self, s2_tokens, state, action_history=None):
    """Pool S2 tokens, current state and action history into one context vector.

    Each input is projected to the JEPA hidden width, tagged with its type
    embedding and concatenated along the token axis. The sequence is run
    through the context encoder and the final norm, then mean-pooled.

    Args:
        s2_tokens: S2 tokens; a 2-D tensor is treated as a single sample.
        state: Current state; a 1-D tensor is treated as a single sample.
        action_history: Optional action history, normalised by
            ``_prepare_action_history``.

    Returns:
        Mean over the token axis of the encoded sequence.
    """
    if state.ndim == 1:
        state = state.unsqueeze(0)
    if s2_tokens.ndim == 2:
        s2_tokens = s2_tokens.unsqueeze(0)

    # Everything follows the state's device; non-floating states use float32.
    device = state.device
    dtype = state.dtype if state.is_floating_point() else torch.float32
    batch_size = state.shape[0]
    s2_tokens = s2_tokens.to(device=device, dtype=dtype)
    state = state.to(device=device, dtype=dtype)
    history = self._prepare_action_history(action_history, batch_size, device, dtype)

    sequence = torch.cat(
        [
            self.s2_jepa_proj(s2_tokens) + self.type_emb[0],
            self.state_jepa_proj(state).unsqueeze(1) + self.type_emb[1],
            self.action_hist_jepa_proj(history) + self.type_emb[2],
        ],
        dim=1,
    )
    encoded = self.jepa_norm(self.jepa_context_encoder(sequence))
    return encoded.mean(dim=1)
def predict_action_latent(self, s2_tokens, state, action_history=None):
    """Predict the future-action latent from the pooled context.

    Args:
        s2_tokens: S2 tokens, passed to ``encode_context``.
        state: Current state, passed to ``encode_context``.
        action_history: Optional action history, passed to ``encode_context``.

    Returns:
        Output of ``jepa_predictor`` applied to the pooled context.
    """
    pooled_context = self.encode_context(s2_tokens, state, action_history)
    latent = self.jepa_predictor(pooled_context)
    return latent
def encode_action_latent(self, actions: torch.Tensor, target: bool = False) -> torch.Tensor:
    """Encode an action chunk into a unit-norm latent.

    Args:
        actions: Action chunk; a 2-D tensor is treated as a single sample.
            Everything after the batch dimension is flattened.
        target: Use the target encoder instead of the online encoder.

    Returns:
        L2-normalised latent (normalised in float32) in the dtype of the
        flattened input.
    """
    if actions.ndim == 2:
        actions = actions.unsqueeze(0)
    flat = actions.reshape(actions.shape[0], -1)
    if target:
        encoder = self.target_action_encoder
    else:
        encoder = self.action_encoder
    unit_latent = F.normalize(encoder(flat).float(), dim=-1)
    return unit_latent.to(dtype=flat.dtype)
@torch.no_grad()
def update_target_encoder(self, decay: float | None = None):
    """Apply one EMA step to every target action encoder owned by this module.

    Each target parameter becomes ``decay * target + (1 - decay) * online``.
    The wrapped ``online_s1`` policy is updated too when it exposes
    ``update_target_encoder``, mirroring how ``freeze_target_encoder``
    delegates to it. Without the delegation, the wrapped policy's target
    encoder would never be updated through this method.

    Args:
        decay: EMA coefficient; ``None`` uses ``config.jepa_ema_decay``.
    """
    decay = float(self.config.jepa_ema_decay if decay is None else decay)
    if hasattr(self.online_s1, "update_target_encoder"):
        self.online_s1.update_target_encoder(decay=decay)
    for online, target in zip(self.action_encoder.parameters(), self.target_action_encoder.parameters()):
        target.data.mul_(decay).add_(online.data, alpha=1.0 - decay)
def freeze_target_encoder(self):
    """Disable gradients for the target encoders of this module and of ``online_s1``.

    The wrapped policy is asked to freeze its own target encoder when it
    provides ``freeze_target_encoder``.
    """
    inner_freeze = getattr(self.online_s1, "freeze_target_encoder", None)
    if inner_freeze is not None:
        inner_freeze()
    for frozen in self.target_action_encoder.parameters():
        frozen.requires_grad_(False)
def training_loss(self, s2_tokens, state, state_history, action_history, target_actions, task_memory=None):
    """Add the JEPA latent-prediction loss to the online S1 training losses.

    Args:
        s2_tokens: S2 tokens.
        state: Current state.
        state_history: State history, forwarded to ``online_s1`` only.
        action_history: Action history.
        target_actions: Ground-truth action chunk.
        task_memory: Optional task memory, forwarded to ``online_s1`` only.

    Returns:
        The dict returned by ``online_s1.training_loss`` with ``loss``
        replaced by ``base loss + jepa_loss_weight * loss_jepa``,
        ``loss_jepa`` replaced by this module's latent loss, and the
        normalised ``pred_action_latent`` / ``target_action_latent`` added.
    """
    base_losses = self.online_s1.training_loss(s2_tokens, state, state_history, action_history, target_actions, task_memory=task_memory)

    raw_latent = self.predict_action_latent(s2_tokens, state, action_history)
    pred_latent = F.normalize(raw_latent.float(), dim=-1)
    # The target latent comes from the EMA encoder and carries no gradient.
    with torch.no_grad():
        target_latent = self.encode_action_latent(target_actions, target=True).float()
    loss_jepa = F.mse_loss(pred_latent, target_latent)

    base_total = base_losses["loss"]
    total = base_total + float(self.config.jepa_loss_weight) * loss_jepa.to(base_total.dtype)

    merged = dict(base_losses)
    merged.update(
        loss=total,
        loss_jepa=loss_jepa,
        pred_action_latent=pred_latent,
        target_action_latent=target_latent,
    )
    return merged
@torch.no_grad()
def sample(self, s2_tokens, state, state_history=None, action_history=None, steps=None, task_memory=None):
    """Sample actions from the online S1 policy, optionally adding the JEPA residual.

    The residual ``jepa_action_alpha * latent_to_action_delta(latent)`` is
    added only when ``config.use_jepa_action_residual`` is set and
    ``config.jepa_action_alpha`` is non-zero. The result is clamped to
    ``[-action_clip, action_clip]`` when ``config.action_clip > 0``.

    Args:
        s2_tokens: S2 tokens.
        state: Current state.
        state_history: Optional state history, forwarded to ``online_s1``.
        action_history: Optional action history.
        steps: Optional number of sampling steps, forwarded to ``online_s1``.
        task_memory: Optional task memory, forwarded to ``online_s1``.

    Returns:
        Action chunk tensor.
    """
    cfg = self.config
    actions = self.online_s1.sample(s2_tokens, state, state_history, action_history, steps=steps, task_memory=task_memory)

    if bool(cfg.use_jepa_action_residual) and float(cfg.jepa_action_alpha) != 0.0:
        latent = self.predict_action_latent(s2_tokens, state, action_history).to(actions.dtype)
        flat_delta = self.latent_to_action_delta(latent)
        delta = flat_delta.view(actions.shape[0], cfg.action_horizon, cfg.action_dim)
        actions = actions + float(cfg.jepa_action_alpha) * delta

    clip = cfg.action_clip
    if clip > 0:
        actions = actions.clamp(-clip, clip)
    return actions
class Rio2ResidualAdapter(nn.Module):
    """Tiny state-conditioned correction head.

    The output layer is zero-initialised, so a freshly built adapter
    predicts an all-zero correction until it has been trained.
    """

    def __init__(self, config: Rio2Config):
        """Build the two-layer MLP.

        Args:
            config: Model configuration. Read here: ``s1_width`` (hidden
                width, clamped to ``[64, 256]``), ``state_dim``,
                ``action_horizon`` and ``action_dim``.
        """
        super().__init__()
        self.config = config
        width = min(256, max(64, config.s1_width))
        self.net = nn.Sequential(
            nn.Linear(config.state_dim, width),
            nn.SiLU(),
            nn.Linear(width, config.action_horizon * config.action_dim),
        )
        output_layer = self.net[-1]
        nn.init.zeros_(output_layer.weight)
        nn.init.zeros_(output_layer.bias)
        # The enclosing model's post_init() runs `_init_weights` over every
        # nn.Linear and would overwrite the zeros with N(0, 0.02) weights.
        # Transformers skips modules flagged as already initialised.
        output_layer._is_hf_initialized = True

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Predict an action correction from the current state.

        Args:
            state: State tensor; a 1-D tensor is treated as a single sample.

        Returns:
            Tensor of shape ``(batch, action_horizon, action_dim)``.
        """
        if state.ndim == 1:
            state = state.unsqueeze(0)
        return self.net(state).view(state.shape[0], self.config.action_horizon, self.config.action_dim)
class Rio2PreTrainedModel(PreTrainedModel):
    """Base class wiring RIO-2 models into the Transformers loading machinery."""

    config_class = Rio2Config
    base_model_prefix = "rio2"
    supports_gradient_checkpointing = False
    _no_split_modules = ["Rio2FastS1FlowActionExpert", "Rio2MolmoAct2Core"]

    def _init_weights(self, module):
        """Initialise one submodule.

        Linear and embedding weights are drawn from N(0, 0.02), linear biases
        are zeroed and RMSNorm weights are set to one. Other module types are
        left untouched.
        """
        init_std = 0.02
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=init_std)
            # Only linear layers carry a bias.
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
            return
        if isinstance(module, Rio2RMSNorm):
            module.weight.data.fill_(1.0)
class Rio2Model(Rio2PreTrainedModel):
"""RIO-2 weight-preserved SO101 policy integrated as a Transformers model."""
def __init__(self, config: Rio2Config):
# Assemble the S2 core, the S1 student and the optional residual adapter,
# then set up the runtime caches used by refresh_s2(), act_fast() and
# _apply_temporal_ensemble().
super().__init__(config)
self.molmoact = Rio2MolmoAct2Core(config)
# enable_jepa_s1 selects the JEPA wrapper; otherwise the plain fast flow
# expert is used as the student.
if bool(config.enable_jepa_s1):
self.s1_student = Rio2JepaS1ActionExpert(config)
else:
self.s1_student = Rio2FastS1FlowActionExpert(config)
self.residual_adapter = Rio2ResidualAdapter(config) if config.enable_residual_adapter else None
# Runtime state kept as plain attributes: last S2 tokens with their
# wall-clock time and instruction, recent action chunks paired with their
# age, and the task-memory tensor.
self._s2_cache: torch.Tensor | None = None
self._s2_cache_time: float = 0.0
self._cached_instruction: str | None = None
self._action_chunk_history: list[tuple[torch.Tensor, int]] = []
self._task_memory_cache: torch.Tensor | None = None
self.post_init()
# Loading the base model during construction is supported but discouraged
# (see the warning text).
if config.load_base_on_init:
logger.warning("config.load_base_on_init=True loads MolmoAct2 during construction; prefer load_s2_base().")
self.load_s2_base()
# NOTE(review): presumably applied regardless of load_base_on_init
# (load_s2_base() already re-applies the policy itself) — confirm nesting.
self.apply_finetuning_policy()
@property
def s2(self):
    """Alias for ``self.molmoact``, kept for backward compatibility.

    Exposed as a property so the module is not registered a second time.
    """
    core = self.molmoact
    return core
@property
def s1(self):
    """Alias for ``self.s1_student``, kept for backward compatibility.

    Exposed as a property so the module is not registered a second time.
    """
    student = self.s1_student
    return student
def load_s2_base(self, device: str | torch.device | None = None, device_map: str | None = None):
    """Load the base model into the MolmoAct2 core and re-apply the finetuning policy.

    Args:
        device: Forwarded to ``molmoact.load_base``.
        device_map: Forwarded to ``molmoact.load_base``.

    Returns:
        ``self``, to allow call chaining.
    """
    core = self.molmoact
    core.load_base(device=device, device_map=device_map)
    # Re-applied after loading so the freshly loaded base is covered as well.
    self.apply_finetuning_policy()
    return self
def freeze_s2_base(self):
    """Freeze the base model held by the MolmoAct2 core and return ``self``."""
    core = self.molmoact
    core.freeze_base()
    return self
@torch.no_grad()
def reset_temporal_ensemble(self):
    """Forget all stored action chunks (the list is emptied in place) and return ``self``."""
    del self._action_chunk_history[:]
    return self
@torch.no_grad()
def reset_task_memory(self):
    """Drop the cached task memory and return ``self``."""
    setattr(self, "_task_memory_cache", None)
    return self
@torch.no_grad()
def update_task_memory(self, s2_tokens: torch.Tensor, reset: bool = False):
    """Refresh the cached task memory from new S2 tokens.

    A candidate memory is built with the student's
    ``default_task_memory_from_s2`` (or that of its ``online_s1``). It
    replaces the cache when ``reset`` is set, when nothing is cached or when
    the shapes differ. Otherwise cache and candidate are blended with weight
    ``config.task_memory_ema`` on the cache. When
    ``config.task_memory_max_norm`` is positive, vectors along the last
    dimension are rescaled so their norm does not exceed it.

    Args:
        s2_tokens: Fresh S2 tokens.
        reset: Discard the previous memory instead of blending with it.

    Returns:
        The updated cache. ``None`` when task memory is disabled (the cache
        is cleared) or when the student provides no memory builder (the
        cache is left as it was).
    """
    cfg = self.config
    if not bool(cfg.task_memory_enabled):
        self._task_memory_cache = None
        return None

    # Memory lives on the student's device and in its dtype.
    device = next(self.s1_student.parameters()).device
    dtype = next(self.s1_student.parameters()).dtype

    student = self.s1_student
    if hasattr(student, "default_task_memory_from_s2"):
        build_memory = student.default_task_memory_from_s2
    elif hasattr(student, "online_s1"):
        build_memory = student.online_s1.default_task_memory_from_s2
    else:
        return None
    candidate = build_memory(s2_tokens.to(device=device, dtype=dtype)).detach()

    previous = self._task_memory_cache
    start_fresh = reset or previous is None or tuple(previous.shape) != tuple(candidate.shape)
    if start_fresh:
        memory = candidate
    else:
        ema = float(cfg.task_memory_ema)
        memory = ema * previous.to(device=device, dtype=dtype) + (1.0 - ema) * candidate

    max_norm = float(cfg.task_memory_max_norm)
    if max_norm > 0:
        # Only shrink: the scale factor is capped at 1.
        norms = memory.norm(dim=-1, keepdim=True).clamp_min(1e-6)
        memory = memory * (max_norm / norms).clamp(max=1.0)

    self._task_memory_cache = memory.detach()
    return self._task_memory_cache
@torch.no_grad()
def _apply_temporal_ensemble(self, actions: torch.Tensor, enabled: bool | None = None) -> torch.Tensor:
use_ensemble = self.config.temporal_ensemble_enabled if enabled is None else enabled
if not use_ensemble or actions.ndim != 3:
return actions
if self._action_chunk_history and self._action_chunk_history[0][0].shape != actions.shape:
self.reset_temporal_ensemble()
aged = []
for chunk, age in self._action_chunk_history:
next_age = age + 1
if next_age < actions.shape[1]:
aged.append((chunk, next_age))
max_chunks = int(max(1, self.config.temporal_ensemble_max_chunks))
self._action_chunk_history = [(actions.detach(), 0)] + aged[: max_chunks - 1]
blended = []
for offset in range(actions.shape[1]):
weighted_sum = None
weight_sum = 0.0
for chunk, age in self._action_chunk_history:
idx = age + offset
if idx >= actions.shape[1]:
continue
weight = math.exp(-float(self.config.temporal_ensemble_decay) * age)
value = chunk[:, idx]
weighted_sum = value * weight if weighted_sum is None else weighted_sum + value * weight
weight_sum += weight
blended.append(weighted_sum / max(weight_sum, 1e-8))
return torch.stack(blended, dim=1)
def apply_finetuning_policy(self):
"""Apply the default small-tuning policy.
Base MolmoAct2 weights are frozen by default. Trainable parameters are
compressor/student/residual-adapter parameters, and optionally the
detected original action expert when the user explicitly unfreezes it.
"""
# NOTE(review): confirm which of the statements below are meant to sit under
# `train_adapters_only`, in particular the target-encoder freeze and the
# residual-adapter toggle.
if self.config.train_adapters_only:
# The base can only be frozen once it has been loaded.
if self.molmoact.base is not None:
self.molmoact.freeze_base()
# Compressor and S1 student parameters are (re-)enabled for training.
for param in self.molmoact.compressor.parameters():
param.requires_grad = True
for param in self.s1_student.parameters():
param.requires_grad = True
# The loop above also re-enabled the target-encoder parameters; freeze them again.
if hasattr(self.s1_student, "freeze_target_encoder"):
self.s1_student.freeze_target_encoder()
# The residual adapter follows config.residual_trainable.
if self.residual_adapter is not None:
for param in self.residual_adapter.parameters():
param.requires_grad = bool(self.config.residual_trainable)
# Returns self so calls can be chained.
return self
def unfreeze_original_s1(self):
    """Unfreeze the original action expert; returns what ``molmoact.unfreeze_action_expert()`` returns."""
    result = self.molmoact.unfreeze_action_expert()
    return result
def trainable_parameter_names(self) -> list[str]:
    """Return the names of all parameters with ``requires_grad`` set, in ``named_parameters()`` order."""
    names = []
    for name, param in self.named_parameters():
        if param.requires_grad:
            names.append(name)
    return names
@torch.no_grad()
def update_jepa_target_encoder(self, decay: float | None = None):
    """Run the student's target-encoder EMA update, if it has one.

    Args:
        decay: EMA coefficient, forwarded as ``decay=`` to the student.

    Returns:
        ``self``, to allow call chaining.
    """
    ema_update = getattr(self.s1_student, "update_target_encoder", None)
    if ema_update is not None:
        ema_update(decay=decay)
    return self
@torch.no_grad()
def refresh_s2(self, images: ImageLike | list[ImageLike], instruction: str, force: bool = False) -> torch.Tensor:
    """Recompute the S2 context tokens and update the runtime caches.

    The temporal ensemble is reset when the instruction changed or ``force``
    is set. Task memory is restarted only when the instruction changed;
    otherwise it is updated incrementally.

    Args:
        images: Image or list of images, forwarded to ``molmoact.refresh_s2``.
        instruction: Task instruction.
        force: Forwarded to ``molmoact.refresh_s2``.

    Returns:
        The detached tokens now stored in the S2 cache.
    """
    tokens = self.molmoact.refresh_s2(images, instruction, force=force)

    instruction_changed = instruction != self._cached_instruction
    if instruction_changed or force:
        self.reset_temporal_ensemble()
    self.update_task_memory(tokens, reset=instruction_changed)

    self._s2_cache = tokens.detach()
    self._s2_cache_time = time.time()
    self._cached_instruction = instruction
    return self._s2_cache
@torch.no_grad()
def act_fast(
self,
state: torch.Tensor,
state_history: torch.Tensor | None = None,
action_history: torch.Tensor | None = None,
steps: int | None = None,
use_original: bool | None = None,
temporal_ensemble: bool | None = None,
) -> torch.Tensor:
# High-frequency action generation from the current state, using either
# the original MolmoAct2 action path or the S1 student conditioned on the
# cached S2 tokens.
# An explicit `use_original` argument overrides config.use_original_s1.
use_original = self.config.use_original_s1 if use_original is None else use_original
# Inputs are moved to the device of the model's first parameter.
device = next(self.parameters()).device
state = state.to(device)
state_history = None if state_history is None else state_history.to(device)
action_history = None if action_history is None else action_history.to(device)
# The original path is taken only when a base model is actually loaded.
if use_original and self.molmoact.base is not None:
actions = self.molmoact.act_original(state, state_history, action_history, num_steps=steps)
else:
# The student path needs S2 tokens cached by an earlier refresh_s2() call.
if self._s2_cache is None:
raise RuntimeError("S2 cache is empty. Call refresh_s2() or pass s2_tokens to forward().")
# Cached tokens follow the state's dtype (float32 for non-floating states).
s2_tokens = self._s2_cache.to(device=device, dtype=state.dtype if state.is_floating_point() else torch.float32)
task_memory = None if self._task_memory_cache is None else self._task_memory_cache.to(device=device, dtype=s2_tokens.dtype)
actions = self.s1_student.sample(s2_tokens, state, state_history, action_history, steps=steps, task_memory=task_memory)
# Optional residual correction, scaled by config.residual_alpha.
if self.residual_adapter is not None and float(self.config.residual_alpha) != 0.0:
actions = actions + float(self.config.residual_alpha) * self.residual_adapter(state).to(actions.dtype)
# NOTE(review): confirm whether this clamp is meant to cover the
# original-S1 path as well, or only the student/residual path.
if self.config.action_clip > 0:
actions = actions.clamp(-self.config.action_clip, self.config.action_clip)
# Blend with recently emitted chunks; `temporal_ensemble` overrides the config switch.
return self._apply_temporal_ensemble(actions, enabled=temporal_ensemble)
def forward_from_s2_tokens(
    self,
    s2_tokens: torch.Tensor,
    state: torch.Tensor,
    state_history: torch.Tensor | None = None,
    action_history: torch.Tensor | None = None,
    target_actions: torch.Tensor | None = None,
    s1_steps: int | None = None,
    task_memory: torch.Tensor | None = None,
    return_dict: bool | None = None,
) -> tuple[torch.Tensor] | Rio2Output:
    """Run ``forward`` on explicit S2 tokens, always through the student path.

    A thin wrapper that forwards every argument by keyword and pins
    ``use_original=False``.

    Returns:
        Whatever ``forward`` returns for these arguments.
    """
    forward_kwargs = {
        "state": state,
        "s2_tokens": s2_tokens,
        "state_history": state_history,
        "action_history": action_history,
        "target_actions": target_actions,
        "s1_steps": s1_steps,
        "task_memory": task_memory,
        "return_dict": return_dict,
        # The original MolmoAct2 action path is never used from here.
        "use_original": False,
    }
    return self.forward(**forward_kwargs)
def forward(
    self,
    state: torch.Tensor,
    s2_tokens: torch.Tensor | None = None,
    state_history: torch.Tensor | None = None,
    action_history: torch.Tensor | None = None,
    target_actions: torch.Tensor | None = None,
    images: ImageLike | list[ImageLike] | None = None,
    instruction: str | None = None,
    refresh_s2: bool = False,
    s1_steps: int | None = None,
    task_memory: torch.Tensor | None = None,
    use_original: bool | None = None,
    return_dict: bool | None = None,
    **kwargs,
) -> tuple[torch.Tensor] | Rio2Output:
    """Run one training or inference step of the policy.

    With ``target_actions`` the S1 student's training losses are returned.
    Otherwise an action chunk is produced, either by the original MolmoAct2
    action path (``use_original`` and a loaded base) or by the S1 student
    conditioned on S2 tokens.

    Args:
        state: Current state.
        s2_tokens: S2 context tokens; the cached tokens are used when omitted.
        state_history: Optional state history.
        action_history: Optional action history.
        target_actions: Ground-truth action chunk; switches to training mode.
        images: Images for the S2 refresh (required with ``refresh_s2``).
        instruction: Instruction for the S2 refresh (required with
            ``refresh_s2``).
        refresh_s2: Recompute the S2 tokens before anything else.
        s1_steps: Number of sampling steps for action generation.
        task_memory: Optional task memory; at inference the cached memory is
            used when omitted.
        use_original: Overrides ``config.use_original_s1`` when given.
        return_dict: Overrides ``config.use_return_dict`` when given.
        **kwargs: Accepted for call compatibility and ignored.

    Returns:
        A ``Rio2Output`` when ``return_dict`` is true. Otherwise a tuple: the
        non-``None`` output fields in training mode, ``(actions, tokens)`` at
        inference.

    Raises:
        ValueError: If ``refresh_s2`` is set without ``images`` and
            ``instruction``, or if S2 tokens are needed but neither passed
            nor cached.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    use_original = self.config.use_original_s1 if use_original is None else use_original
    if refresh_s2:
        if images is None or instruction is None:
            raise ValueError("`images` and `instruction` are required when refresh_s2=True.")
        s2_tokens = self.refresh_s2(images, instruction, force=True)
    elif s2_tokens is None:
        s2_tokens = self._s2_cache
    device = next(self.parameters()).device
    state = state.to(device)
    state_history = None if state_history is None else state_history.to(device)
    action_history = None if action_history is None else action_history.to(device)
    # Tokens and targets follow the state's dtype (float32 for non-floating states).
    compute_dtype = state.dtype if state.is_floating_point() else torch.float32

    # Training path: use cached-token/student path by default because the
    # original MolmoAct2 action expert is usually frozen and remote-code
    # signatures may not expose target-action training directly.
    if target_actions is not None:
        if s2_tokens is None:
            raise ValueError("Training requires `s2_tokens` or refresh_s2=True.")
        s2_tokens = s2_tokens.to(device=device, dtype=compute_dtype)
        task_memory = None if task_memory is None else task_memory.to(device=device, dtype=s2_tokens.dtype)
        target_actions = target_actions.to(device=device, dtype=compute_dtype)
        losses = self.s1_student.training_loss(s2_tokens, state, state_history, action_history, target_actions, task_memory=task_memory)
        # Both student types report a "loss_jepa" entry, so that key cannot
        # tell them apart. Only the JEPA S1 wrapper adds "pred_action_latent".
        runtime_path = "jepa_s1_training" if "pred_action_latent" in losses else "student_adapter_training"
        output = Rio2Output(
            loss=losses["loss"],
            s2_tokens=s2_tokens,
            loss_flow_mse=losses["loss_flow_mse"],
            loss_flow_l1=losses["loss_flow_l1"],
            loss_diffusion=losses.get("loss_diffusion"),
            loss_consistency=losses.get("loss_consistency"),
            loss_smooth=losses["loss_smooth"],
            loss_jepa=losses.get("loss_jepa"),
            loss_jepa_prior=losses.get("loss_jepa_prior"),
            pred_action_latent=losses.get("pred_action_latent"),
            target_action_latent=losses.get("target_action_latent"),
            runtime_path=runtime_path,
        )
        return tuple(v for v in output.to_tuple() if v is not None) if not return_dict else output

    if use_original and self.molmoact.base is not None:
        actions = self.act_fast(state, state_history, action_history, steps=s1_steps, use_original=True)
        runtime_path = self.molmoact.last_runtime_path
        tokens = self._s2_cache
    else:
        if s2_tokens is None:
            raise ValueError("Pass `s2_tokens`, call refresh_s2(), or set refresh_s2=True.")
        s2_tokens = s2_tokens.to(device=device, dtype=compute_dtype)
        if task_memory is None and self._task_memory_cache is not None:
            task_memory = self._task_memory_cache
        task_memory = None if task_memory is None else task_memory.to(device=device, dtype=s2_tokens.dtype)
        actions = self.s1_student.sample(s2_tokens, state, state_history, action_history, steps=s1_steps, task_memory=task_memory)
        if self.residual_adapter is not None and float(self.config.residual_alpha) != 0.0:
            actions = actions + float(self.config.residual_alpha) * self.residual_adapter(state).to(actions.dtype)
            # The residual can push actions past the configured bound.
            # act_fast() clamps after adding it, so the same is done here.
            if self.config.action_clip > 0:
                actions = actions.clamp(-self.config.action_clip, self.config.action_clip)
        runtime_path = "student_cached_tokens"
        tokens = s2_tokens
    output = Rio2Output(actions=actions, s2_tokens=tokens, runtime_path=runtime_path)
    return (actions, tokens) if not return_dict else output
# Names exported by a star-import of this module.
__all__ = [
"Rio2Model",
"Rio2PreTrainedModel",
]