File size: 1,974 Bytes
6784e94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""
EO1Vision processor for `eo_pi_internvl`.

This is the InternVL-backbone EO1 processor with a Pi05-style action prompt:
- We keep a *single* `<|action_pad|>` as a placeholder suffix token in text prompts.
- The action expert consumes *continuous* action tokens (length=`action_chunk_size`) internally, so we do not need to
  repeat `<|action_pad|>` by chunk size in the text (this also keeps AR loss extensible).
"""

from __future__ import annotations

from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.video_utils import VideoInput

from eo_internvl.model.processing_eo1_internvl import (
    DEFAULT_ACTION_TOKEN,
    EO1VisionProcessor as _BaseEO1VisionProcessor,
    EO1VisionProcessorKwargs,
    RobotInput,
)


class EO1VisionProcessor(_BaseEO1VisionProcessor):
    """EO1 processor variant that keeps a single action placeholder token in text prompts.

    Compared to the base processor, this subclass:

    * returns ``DEFAULT_ACTION_TOKEN`` exactly once from :meth:`expand_action_prompt`,
      ignoring the requested chunk size, and
    * pins ``text_kwargs["noise_token_num"]`` to 1 on every call, regardless of what
      the caller (or, presumably, the robot config consumed by the base class) asks for.
    """

    def expand_action_prompt(self, chunk_size: int) -> str:
        """Return the action placeholder text to insert into a prompt.

        Args:
            chunk_size: Action chunk length. Accepted only for signature compatibility
                with the base class; it is intentionally ignored.

        Returns:
            ``DEFAULT_ACTION_TOKEN``, a single time (never repeated ``chunk_size`` times).
        """
        # Pi05-style: keep a single placeholder token in text; the full continuous action
        # block is expected to be built by the model rather than spelled out in the prompt.
        return DEFAULT_ACTION_TOKEN

    def __call__(
        self,
        images: ImageInput | None = None,
        text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
        videos: VideoInput | None = None,
        states: RobotInput | None = None,
        actions: RobotInput | None = None,
        **kwargs: Unpack[EO1VisionProcessorKwargs],
    ) -> BatchFeature:
        """Run the base processor with the action-token expansion length forced to 1.

        All arguments are forwarded to the base ``__call__`` unchanged, except that
        ``noise_token_num`` is always delivered as ``text_kwargs["noise_token_num"] = 1``.
        Any caller-supplied value, whether inside ``text_kwargs`` or as a flat keyword,
        is discarded.

        Returns:
            Whatever the base processor returns for the adjusted kwargs.
        """
        # Copy rather than mutate, so a ``text_kwargs`` dict owned by the caller is not modified.
        text_kwargs = dict(kwargs.get("text_kwargs") or {})
        text_kwargs["noise_token_num"] = 1
        kwargs["text_kwargs"] = text_kwargs
        # A flat ``noise_token_num=...`` kwarg would bypass the override above. With
        # transformers' ProcessorMixin._merge_kwargs it would also raise ValueError, because
        # the same key would arrive both flat and inside ``text_kwargs``. Drop it so the pinned
        # value always wins. (Assumes the base class merges kwargs HF-style — TODO confirm.)
        kwargs.pop("noise_token_num", None)
        return super().__call__(images=images, text=text, videos=videos, states=states, actions=actions, **kwargs)


# Import-time side effect: register this subclass with transformers' auto-class machinery.
# For ProcessorMixin the default target is presumably AutoProcessor — confirm. This lets
# `trust_remote_code` loading resolve this class instead of the base EO1 processor.
EO1VisionProcessor.register_for_auto_class()