Robotics
Transformers
Safetensors
English
rio2
feature-extraction
Mixture of Experts
diffusion-jepa
custom_code
File size: 5,427 Bytes
9a5262f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Copyright 2026 The HuggingFace Inc. team and the Rio2 contributors.
# Licensed under the Apache License, Version 2.0.
"""Processor for Rio2."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any

from transformers.processing_utils import ProcessorMixin
from transformers.utils import logging


# Module-level logger obtained through the transformers logging helper.
logger = logging.get_logger(__name__)


class Rio2Processor(ProcessorMixin):
    """Processor for Rio2 that optionally wraps a base model's processor.

    When a base processor is attached, image + instruction pairs are delegated
    to it; every other input (state, histories, target actions) is passed
    through unchanged. Only ``base_model_id`` is persisted by
    :meth:`save_pretrained`; the base processor itself is re-created on demand.
    """

    # Empty on purpose: the wrapped base processor is managed by this class
    # rather than declared as a ``ProcessorMixin`` attribute.
    attributes = []
    optional_attributes = []

    # (file name, label used in the debug log message) in lookup order.
    _BASE_MODEL_ID_SOURCES = (
        ("processor_config.json", "processor"),
        ("config.json", "model"),
    )
    # Download-related kwargs forwarded to ``cached_file`` for Hub lookups.
    _HUB_KWARG_NAMES = (
        "cache_dir",
        "force_download",
        "proxies",
        "token",
        "revision",
        "local_files_only",
        "subfolder",
    )

    def __init__(self, base_processor=None, base_model_id: str | None = None, **kwargs):
        """Store the optional base processor and the id it was loaded from.

        Args:
            base_processor: Callable processor used for image/text inputs, or
                ``None`` to pass those inputs through untouched.
            base_model_id: Identifier of the base model, written to
                ``processor_config.json`` by :meth:`save_pretrained`.
            **kwargs: Only ``chat_template`` is read; other keys are ignored.
        """
        self.base_processor = base_processor
        self.base_model_id = base_model_id
        self.chat_template = kwargs.pop("chat_template", None)

    @classmethod
    def from_base_model_id(cls, base_model_id: str, **kwargs):
        """Build a processor by loading the base processor for ``base_model_id``.

        Args:
            base_model_id: Identifier passed to ``AutoProcessor.from_pretrained``.
            **kwargs: Forwarded to ``AutoProcessor.from_pretrained``.
                ``trust_remote_code`` defaults to ``True`` (custom code from
                the base repository may be executed); pass ``False`` to opt out.

        Returns:
            A ``Rio2Processor`` wrapping the loaded base processor.
        """
        from transformers import AutoProcessor

        # Pop instead of hard-coding so an explicit ``trust_remote_code`` from
        # the caller is honoured rather than raising a duplicate-keyword error.
        trust_remote_code = kwargs.pop("trust_remote_code", True)
        base_processor = AutoProcessor.from_pretrained(
            base_model_id, trust_remote_code=trust_remote_code, **kwargs
        )
        return cls(base_processor=base_processor, base_model_id=base_model_id)

    @staticmethod
    def _read_base_model_id(cfg_file) -> str | None:
        """Return the ``base_model_id`` entry of the JSON file at ``cfg_file``."""
        data = json.loads(Path(cfg_file).read_text(encoding="utf-8"))
        return data.get("base_model_id")

    @classmethod
    def _hub_base_model_id(cls, repo_id, filename: str, label: str, hub_kwargs: dict[str, Any]) -> str | None:
        """Best-effort read of ``base_model_id`` from ``filename`` in a Hub repo.

        Any failure (missing file, network error, malformed JSON) is logged at
        debug level and reported as ``None``.
        """
        try:
            from transformers.utils import cached_file

            cfg_file = cached_file(repo_id, filename, **hub_kwargs)
            if cfg_file:
                return cls._read_base_model_id(cfg_file)
        except Exception as exc:
            logger.debug("Could not load RIO-2 %s config from Hub: %s", label, exc)
        return None

    @classmethod
    def _resolve_base_model_id(cls, pretrained_model_name_or_path, kwargs: dict[str, Any]) -> str | None:
        """Look up ``base_model_id`` in the processor config, then the model config.

        A path that exists on disk is read directly; anything else is treated
        as a Hub repo id. The first config that yields a non-``None`` value wins.
        """
        local_dir = Path(pretrained_model_name_or_path)
        is_local = local_dir.exists()
        hub_kwargs = {key: kwargs[key] for key in cls._HUB_KWARG_NAMES if key in kwargs}

        for filename, label in cls._BASE_MODEL_ID_SOURCES:
            if is_local:
                cfg_path = local_dir / filename
                found = cls._read_base_model_id(cfg_path) if cfg_path.exists() else None
            else:
                found = cls._hub_base_model_id(pretrained_model_name_or_path, filename, label, hub_kwargs)
            if found is not None:
                return found
        return None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Instantiate the processor from a local directory or a Hub repo id.

        Args:
            pretrained_model_name_or_path: Local path or Hub repo id holding
                ``processor_config.json`` and/or ``config.json``.
            **kwargs: Recognised keys:

                * ``base_model_id``: overrides the value stored in the configs
                  (an empty string is treated as not given).
                * ``load_base_processor``: when truthy, also load the base
                  processor via ``AutoProcessor`` (default ``False``).
                * ``trust_remote_code``: forwarded to the base processor load,
                  default ``True``.

                Remaining keys are forwarded to ``AutoProcessor.from_pretrained``
                when the base processor is loaded.

        Returns:
            A ``Rio2Processor``; ``base_processor`` is ``None`` when loading was
            not requested or failed (failure is logged as a warning).
        """
        base_model_id = kwargs.pop("base_model_id", None) or None
        load_base_processor = bool(kwargs.pop("load_base_processor", False))

        # An explicit id always wins, so the config lookup (and its Hub
        # round-trip) is only needed when none was supplied.
        if base_model_id is None:
            base_model_id = cls._resolve_base_model_id(pretrained_model_name_or_path, kwargs)

        base_processor = None
        if base_model_id and load_base_processor:
            try:
                from transformers import AutoProcessor

                trust_remote_code = kwargs.pop("trust_remote_code", True)
                # NOTE(review): ``revision``/``subfolder`` meant for the Rio2
                # repo are forwarded to the base repo here as well — confirm
                # this is intended.
                base_processor = AutoProcessor.from_pretrained(
                    base_model_id, trust_remote_code=trust_remote_code, **kwargs
                )
            except Exception as exc:
                # Deliberately best-effort: the processor stays usable in
                # pass-through mode without a base processor.
                logger.warning("Could not load base processor %s: %s", base_model_id, exc)
        return cls(base_processor=base_processor, base_model_id=base_model_id)

    def save_pretrained(self, save_directory, **kwargs):
        """Write ``processor_config.json`` into ``save_directory``.

        Args:
            save_directory: Target directory; created if missing.
            **kwargs: ``save_base_processor`` (default ``False``) additionally
                saves the wrapped base processor under ``base_processor/``.
                Other keys are ignored.

        Returns:
            A one-element list with the path of the written config file.
        """
        out = Path(save_directory)
        out.mkdir(parents=True, exist_ok=True)
        config_path = out / "processor_config.json"
        data = {
            "processor_class": self.__class__.__name__,
            "base_model_id": self.base_model_id,
            "auto_map": {"AutoProcessor": "processing_rio2.Rio2Processor"},
        }
        config_path.write_text(json.dumps(data, indent=2) + "\n", encoding="utf-8")
        if self.base_processor is not None and kwargs.pop("save_base_processor", False):
            self.base_processor.save_pretrained(out / "base_processor")
        return [str(config_path)]

    def __call__(
        self,
        images=None,
        instruction: str | None = None,
        state: Any | None = None,
        state_history: Any | None = None,
        action_history: Any | None = None,
        target_actions: Any | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """Assemble the model inputs as a dictionary.

        If a base processor is attached and both ``images`` and ``instruction``
        are given, they are encoded by the base processor (``return_tensors``
        defaults to ``"pt"``; ``**kwargs`` are forwarded). Otherwise they are
        stored as-is under ``"images"`` / ``"instruction"`` and ``**kwargs``
        are unused. The remaining arguments are copied through under their own
        names whenever they are not ``None``.
        """
        out: dict[str, Any] = {}
        if self.base_processor is not None and images is not None and instruction is not None:
            # setdefault lets callers override the tensor type without
            # triggering a duplicate-keyword TypeError.
            kwargs.setdefault("return_tensors", "pt")
            out.update(self.base_processor(images=images, text=instruction, **kwargs))
        else:
            if images is not None:
                out["images"] = images
            if instruction is not None:
                out["instruction"] = instruction

        passthrough = {
            "state": state,
            "state_history": state_history,
            "action_history": action_history,
            "target_actions": target_actions,
        }
        out.update({name: value for name, value in passthrough.items() if value is not None})
        return out


# Explicit public API: only the processor class is exported.
__all__ = ["Rio2Processor"]