# Hugging Face Hub page residue (kept for provenance; not part of the module):
# 0118 / checkpoint-300000 / processing_eo1_internvl.py
# uploaded via huggingface_hub by jasonzhango (commit 24443be, verified)
"""
EO1Vision processor for `eo_pi_internvl`.
This is the InternVL-backbone EO1 processor with a Pi05-style action prompt:
- We keep a *single* `<|action_pad|>` as a placeholder suffix token in text prompts.
- The action expert consumes *continuous* action tokens (length=`action_chunk_size`) internally, so we do not need to
repeat `<|action_pad|>` by chunk size in the text (this also keeps AR loss extensible).
"""
from __future__ import annotations
import inspect
import torch
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.video_utils import VideoInput
from eo_internvl.model.processing_eo1_internvl import (
ACTION_END_TOKEN,
ACTION_START_TOKEN,
DEFAULT_ACTION_TOKEN,
EO1VisionProcessor as _BaseEO1VisionProcessor,
EO1VisionProcessorKwargs,
RobotInput,
)
class EO1VisionProcessor(_BaseEO1VisionProcessor):
    """InternVL-backbone EO1 processor with a Pi05-style action prompt.

    The text prompt carries a *single* `<|action_pad|>` placeholder; the
    action expert expands it internally to `action_chunk_size` continuous
    action tokens, so the placeholder is never repeated in the text.
    """

    def expand_action_prompt(self, chunk_size: int) -> str:
        """Return the action placeholder for the text prompt.

        Pi05-style: a single placeholder token regardless of `chunk_size`;
        the model builds the full continuous action block internally.
        """
        return DEFAULT_ACTION_TOKEN

    def __call__(
        self,
        images: ImageInput = None,
        text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
        videos: VideoInput = None,
        states: RobotInput = None,
        actions: RobotInput = None,
        **kwargs: Unpack[EO1VisionProcessorKwargs],
    ) -> BatchFeature:
        """Process multimodal robot inputs via the base processor.

        Forces `text_kwargs["noise_token_num"]` to 1 so the base processor's
        action-token expansion is a no-op, regardless of `robot_config` or
        caller-supplied kwargs.
        """
        # Copy before mutating so a caller-owned dict is not modified in place.
        text_kwargs = dict(kwargs.get("text_kwargs") or {})
        text_kwargs["noise_token_num"] = 1
        kwargs["text_kwargs"] = text_kwargs
        return super().__call__(images=images, text=text, videos=videos, states=states, actions=actions, **kwargs)

    @torch.no_grad()
    def select_action(self, model, batch: dict, return_raw_actions: bool = False, **kwargs):
        """Sample actions from `model` for a robot observation batch.

        Parameters
        ----------
        model:
            An EO1 model exposing `sample_actions`.
        batch:
            Observation batch. `action_prefix` / `rtc_delay` entries are
            popped (the dict IS mutated) and forwarded to `sample_actions`
            when its signature accepts `action_prefix`.
        return_raw_actions:
            When True, return raw sampled actions without output
            post-processing.

        Returns
        -------
        BatchFeature with a single `"action"` entry.

        Raises
        ------
        NotImplementedError
            If `model` has no `sample_actions` method.
        """
        if not hasattr(model, "sample_actions"):
            raise NotImplementedError("InternVL EO1 model does not implement sample_actions yet.")
        action_prefix = batch.pop("action_prefix", None)
        rtc_delay = batch.pop("rtc_delay", None)
        batch_messages, batch_states, repo_ids = self._prepare_robot_inputs(batch)
        # Prefer the model's configured chunk size; fall back to robot_config.
        chunk_size = int(
            getattr(getattr(model, "config", None), "action_chunk_size", 0)
            or self.robot_config.get("action_chunk_size")
            or 0
        )
        noise_prompt = (
            self.expand_action_prompt(chunk_size)
            if chunk_size > 0
            else f"{ACTION_START_TOKEN}{DEFAULT_ACTION_TOKEN}{ACTION_END_TOKEN}"
        )
        inputs = self.apply_chat_template(
            batch_messages,
            states=batch_states,
            add_generation_prompt=True,
            noise_prompt=noise_prompt,
            tokenize=True,
            padding=True,
            truncation=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)
        # Introspect `sample_actions` so older models without the
        # `action_prefix` kwarg keep working. `inspect.signature` raises
        # ValueError/TypeError for unsupported callables — treat those as
        # "no signature available" rather than swallowing every exception.
        try:
            sig = inspect.signature(model.sample_actions)
        except (TypeError, ValueError):
            sig = None
        if action_prefix is not None:
            if isinstance(action_prefix, (list, tuple)):
                # Stack per-sample prefixes into a single batch tensor.
                action_prefix = torch.stack(
                    [v if torch.is_tensor(v) else torch.as_tensor(v) for v in action_prefix],
                    dim=0,
                )
            elif not torch.is_tensor(action_prefix):
                action_prefix = torch.as_tensor(action_prefix)
            action_prefix = action_prefix.to(device=model.device, dtype=torch.float32)
        if sig is not None and "action_prefix" in sig.parameters:
            actions = model.sample_actions(**inputs, action_prefix=action_prefix, delay=rtc_delay).cpu()
        else:
            actions = model.sample_actions(**inputs).cpu()
        if return_raw_actions:
            return BatchFeature({"action": actions})
        output_actions = self._process_robot_outputs(repo_ids, actions)
        return BatchFeature({"action": output_actions})
# Register this processor with transformers' auto-class machinery so it can
# be resolved from a checkpoint (presumably via an `auto_map` entry — confirm
# against the uploaded config).
EO1VisionProcessor.register_for_auto_class()