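# Hugging Face model-page metadata captured with this file (kept as comments so the module parses):
# Tags: Robotics, Transformers, Safetensors, English, rio2, feature-extraction,
#       Mixture of Experts, diffusion-jepa, custom_code
# File: RIO-2 / processing_rio2.py
# Uploaded by hoguai ("Upload 7 files"), commit 9a5262f (verified)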
# Copyright 2026 The HuggingFace Inc. team and the Rio2 contributors.
# Licensed under the Apache License, Version 2.0.
"""Processor for Rio2."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from transformers.processing_utils import ProcessorMixin
from transformers.utils import logging
# Module logger from transformers' logging utilities; used below for the debug
# and warning messages emitted while resolving/loading processors.
logger = logging.get_logger(__name__)
class Rio2Processor(ProcessorMixin):
    """Processor for RIO-2 checkpoints.

    Optionally wraps a ``base_processor`` (typically loaded with
    ``AutoProcessor`` from ``base_model_id``). Image/instruction inputs are
    routed through it when available. Robot inputs (state, state/action
    histories, target actions) are passed through unchanged.
    """

    attributes = []
    optional_attributes = []

    # Keyword arguments forwarded to ``cached_file`` when resolving this
    # checkpoint's own config files.
    _HUB_KWARGS = (
        "cache_dir",
        "force_download",
        "proxies",
        "token",
        "revision",
        "local_files_only",
        "subfolder",
    )
    # These point inside the RIO-2 repository and are not meaningful (e.g. a
    # RIO-2 commit hash) for the separate base-model repository, so they are
    # not forwarded when loading the base processor.
    _REPO_SCOPED_KWARGS = ("revision", "subfolder")
    # (filename, description used in log messages), in lookup order for
    # ``base_model_id``.
    _CONFIG_SOURCES = (
        ("processor_config.json", "processor config"),
        ("config.json", "model config"),
    )

    def __init__(self, base_processor=None, base_model_id: str | None = None, **kwargs):
        """Create a processor.

        Args:
            base_processor: Optional processor applied to images + instruction
                in ``__call__``; ``None`` means inputs are passed through raw.
            base_model_id: Identifier of the base model; persisted by
                ``save_pretrained``.
            **kwargs: Only ``chat_template`` is read; other keys are ignored.
        """
        self.base_processor = base_processor
        self.base_model_id = base_model_id
        self.chat_template = kwargs.pop("chat_template", None)

    @classmethod
    def from_base_model_id(cls, base_model_id: str, **kwargs):
        """Build a processor by loading ``AutoProcessor`` for ``base_model_id``.

        ``trust_remote_code`` defaults to ``True`` but may be overridden via
        ``kwargs``; all ``kwargs`` are forwarded to ``AutoProcessor.from_pretrained``.
        Loading errors propagate to the caller.
        """
        from transformers import AutoProcessor

        # setdefault avoids "got multiple values for keyword argument
        # 'trust_remote_code'" when the caller passes it explicitly.
        kwargs.setdefault("trust_remote_code", True)
        base_processor = AutoProcessor.from_pretrained(base_model_id, **kwargs)
        return cls(base_processor=base_processor, base_model_id=base_model_id)

    @staticmethod
    def _load_config(pretrained_model_name_or_path, filename: str, description: str, hub_kwargs: dict[str, Any]):
        """Return the parsed JSON of ``filename`` from the checkpoint, or ``None``.

        For an existing local path the file is read from disk, inside
        ``subfolder`` when one is given. A missing file yields ``None``, and
        malformed JSON raises. Otherwise the file is resolved with
        ``cached_file``. Any failure on that path is logged at debug level and
        yields ``None``.
        """
        path = Path(pretrained_model_name_or_path)
        if path.exists():
            # Joining "" is a no-op, so no subfolder means the root directory.
            file_path = path / (hub_kwargs.get("subfolder") or "") / filename
            if not file_path.exists():
                return None
            return json.loads(file_path.read_text(encoding="utf-8"))
        try:
            from transformers.utils import cached_file

            resolved = cached_file(pretrained_model_name_or_path, filename, **hub_kwargs)
            if not resolved:
                return None
            return json.loads(Path(resolved).read_text(encoding="utf-8"))
        except Exception as exc:
            # Best effort: a missing or unreadable remote config is not fatal.
            logger.debug("Could not load RIO-2 %s from Hub: %s", description, exc)
            return None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load a processor from a local directory or Hub repository.

        ``base_model_id`` is taken from the ``base_model_id`` keyword if given.
        Otherwise it is read from ``processor_config.json``, falling back to
        ``config.json``. The base processor is loaded only when
        ``load_base_processor=True``. A failure there is logged as a warning
        and leaves ``base_processor`` as ``None``.

        Keyword Args:
            base_model_id: Explicit base model identifier (skips config lookup).
            load_base_processor: Whether to load the base ``AutoProcessor``
                (default ``False``).
            trust_remote_code: Forwarded to the base ``AutoProcessor``
                (default ``True``).
            cache_dir, force_download, proxies, token, revision,
            local_files_only, subfolder: Used to resolve this checkpoint's
                config files. ``revision`` and ``subfolder`` are not forwarded
                to the base model.
        """
        base_model_id = kwargs.pop("base_model_id", None)
        load_base_processor = bool(kwargs.pop("load_base_processor", False))
        hub_kwargs = {key: kwargs[key] for key in cls._HUB_KWARGS if key in kwargs}

        # Only consult the checkpoint's configs when no id was supplied, which
        # also avoids needless Hub requests.
        if not base_model_id:
            for filename, description in cls._CONFIG_SOURCES:
                data = cls._load_config(pretrained_model_name_or_path, filename, description, hub_kwargs)
                if isinstance(data, dict) and data.get("base_model_id"):
                    base_model_id = data["base_model_id"]
                    break

        base_processor = None
        if base_model_id and load_base_processor:
            trust_remote_code = kwargs.pop("trust_remote_code", True)
            base_kwargs = {key: value for key, value in kwargs.items() if key not in cls._REPO_SCOPED_KWARGS}
            try:
                from transformers import AutoProcessor

                base_processor = AutoProcessor.from_pretrained(
                    base_model_id, trust_remote_code=trust_remote_code, **base_kwargs
                )
            except Exception as exc:
                # Deliberately non-fatal: the processor still works in pass-through mode.
                logger.warning("Could not load base processor %s: %s", base_model_id, exc)
        return cls(base_processor=base_processor, base_model_id=base_model_id)

    def save_pretrained(self, save_directory, **kwargs):
        """Write ``processor_config.json`` (and optionally the base processor).

        Args:
            save_directory: Target directory; created if missing.
            **kwargs: ``save_base_processor=True`` additionally saves the base
                processor (if any) under ``<save_directory>/base_processor``.
                Other keys are ignored.

        Returns:
            A one-element list with the path of the written ``processor_config.json``.

        NOTE(review): ``from_pretrained`` in this class reloads the base
        processor from ``base_model_id``, not from the saved ``base_processor``
        subdirectory.
        """
        out = Path(save_directory)
        out.mkdir(parents=True, exist_ok=True)
        config_path = out / "processor_config.json"
        data = {
            "processor_class": self.__class__.__name__,
            "base_model_id": self.base_model_id,
            "auto_map": {"AutoProcessor": "processing_rio2.Rio2Processor"},
        }
        config_path.write_text(json.dumps(data, indent=2) + "\n", encoding="utf-8")
        save_base = kwargs.pop("save_base_processor", False)
        if self.base_processor is not None and save_base:
            self.base_processor.save_pretrained(out / "base_processor")
        return [str(config_path)]

    def __call__(
        self,
        images=None,
        instruction: str | None = None,
        state: Any | None = None,
        state_history: Any | None = None,
        action_history: Any | None = None,
        target_actions: Any | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """Prepare model inputs.

        If a base processor is set and both ``images`` and ``instruction`` are
        given, the result starts with ``base_processor(images=..., text=...,
        **kwargs)``. ``return_tensors`` defaults to ``"pt"`` and can be
        overridden. Otherwise ``images``/``instruction`` are stored as-is under
        the keys ``"images"``/``"instruction"`` and ``kwargs`` are unused.
        Non-``None`` robot inputs are then added under their own names.
        """
        out: dict[str, Any] = {}
        if self.base_processor is not None and images is not None and instruction is not None:
            # setdefault lets callers choose return_tensors without a
            # duplicate-keyword TypeError.
            kwargs.setdefault("return_tensors", "pt")
            out.update(self.base_processor(images=images, text=instruction, **kwargs))
        else:
            if images is not None:
                out["images"] = images
            if instruction is not None:
                out["instruction"] = instruction
        passthrough = {
            "state": state,
            "state_history": state_history,
            "action_history": action_history,
            "target_actions": target_actions,
        }
        out.update({key: value for key, value in passthrough.items() if value is not None})
        return out
# Public API of this module: only the processor class is exported.
__all__ = ["Rio2Processor"]