arvla-bridge / processing_vlarm_hf.py
you2who's picture
Duplicate from you2who/paligemma-arvla-bridge
2672775
from __future__ import annotations
import importlib
import json
import os
from typing import Any, Dict, List, Optional
import numpy as np
import transformers
import yaml
from transformers.processing_utils import ProcessorMixin
# Locate VLARMHFConfig: try the package-relative module first and, if that
# raises for any reason, import the same module by its top-level name.
try:
    from .configuration_vlarm_hf import VLARMHFConfig
except Exception:
    _hf_configuration = importlib.import_module("configuration_vlarm_hf")
    VLARMHFConfig = _hf_configuration.VLARMHFConfig

# Locate the VLARM config and processor classes. Lookup order:
#   1. package-relative modules,
#   2. top-level ``configuration_vlarm`` / ``processing_vlarm`` modules,
#   3. the same modules under the ``src`` package.
# A failure at any point of one stage moves on to the next stage; a failure in
# the last stage propagates to the importer.
try:
    from .configuration_vlarm import PaliGemmaProcessorConfig, RegressionProcessorConfig
    from .processing_vlarm import PaliGemmaProcessor, VLARMProcessor
except Exception:
    try:
        _configuration = importlib.import_module("configuration_vlarm")
        PaliGemmaProcessorConfig = _configuration.PaliGemmaProcessorConfig
        RegressionProcessorConfig = _configuration.RegressionProcessorConfig
        _processing = importlib.import_module("processing_vlarm")
        PaliGemmaProcessor = _processing.PaliGemmaProcessor
        VLARMProcessor = _processing.VLARMProcessor
    except Exception:
        _configuration = importlib.import_module("src.configuration_vlarm")
        PaliGemmaProcessorConfig = _configuration.PaliGemmaProcessorConfig
        RegressionProcessorConfig = _configuration.RegressionProcessorConfig
        _processing = importlib.import_module("src.processing_vlarm")
        PaliGemmaProcessor = _processing.PaliGemmaProcessor
        VLARMProcessor = _processing.VLARMProcessor
def _resolve_local_path(path: str) -> str:
    """Return ``path`` as-is when absolute, otherwise anchored at this module's directory.

    Args:
        path: Absolute path, or a path relative to the directory containing
            this source file.

    Returns:
        The unchanged absolute path, or ``path`` joined onto the directory of
        ``__file__``.
    """
    if not os.path.isabs(path):
        module_dir = os.path.dirname(__file__)
        return os.path.join(module_dir, path)
    return path
class VLARMHFProcessor(ProcessorMixin):
    """``ProcessorMixin`` wrapper that pairs a base HF processor with a ``VLARMProcessor``.

    The base Hugging Face processor supplies ``image_processor`` and
    ``tokenizer``; all input preprocessing is delegated to the wrapped
    ``VLARMProcessor``.
    """

    attributes = ["image_processor", "tokenizer"]

    def __init__(
        self,
        hf_processor: transformers.ProcessorMixin,
        vlarm_processor: VLARMProcessor,
        base_processor_model_id: Optional[str],
    ):
        """Store the wrapped processors.

        Args:
            hf_processor: Base multimodal processor; must expose
                ``image_processor`` and ``tokenizer`` attributes.
            vlarm_processor: Processor that performs the actual preprocessing.
            base_processor_model_id: Identifier the base processor was loaded
                from; persisted by ``save_pretrained``.
        """
        # NOTE: ProcessorMixin.__init__ is not invoked; the attributes named in
        # ``attributes`` are assigned directly from the base processor.
        self.hf_processor = hf_processor
        self.vlarm_processor = vlarm_processor
        self.base_processor_model_id = base_processor_model_id
        self.image_processor = hf_processor.image_processor
        self.tokenizer = hf_processor.tokenizer

    def preprocess_inputs(
        self,
        chat: List[str],
        images: Dict[str, List[Any]],
        ee_pose_translation: np.ndarray,
        ee_pose_rotation: np.ndarray,
        gripper: np.ndarray,
        joints: np.ndarray,
        dataset_name: np.ndarray,
        inference_mode: bool = True,
    ) -> Dict[str, Any]:
        """Register the dataset names, then delegate to the wrapped ``VLARMProcessor``.

        All arguments are forwarded unchanged to
        ``VLARMProcessor.preprocess_inputs``; its return value is returned
        as-is.
        """
        # The wrapped processor is told about the dataset names before the
        # actual preprocessing call.
        self.vlarm_processor._set_dataset_np_names(dataset_name)
        return self.vlarm_processor.preprocess_inputs(
            chat=chat,
            images=images,
            ee_pose_translation=ee_pose_translation,
            ee_pose_rotation=ee_pose_rotation,
            gripper=gripper,
            joints=joints,
            dataset_name=dataset_name,
            inference_mode=inference_mode,
        )

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save the base processor and write ``processor_config.json``.

        Args:
            save_directory: Target directory; created if missing.
            **kwargs: Forwarded to the base processor's ``save_pretrained``.
        """
        os.makedirs(save_directory, exist_ok=True)
        self.hf_processor.save_pretrained(save_directory, **kwargs)
        payload = {
            "processor_class": self.__class__.__name__,
            "base_processor_model_id": self.base_processor_model_id,
            "vlarm_processor_config": self.vlarm_processor.config.as_json(),
            "vlm_processor_config": self.vlarm_processor.vlm_processor.config.as_json(),
            "auto_map": {
                "AutoProcessor": "processing_vlarm_hf.VLARMHFProcessor"
            },
        }
        # Written after the base processor has saved, so this file is the one
        # left on disk under the name ``processor_config.json``.
        config_path = os.path.join(save_directory, "processor_config.json")
        with open(config_path, "w", encoding="utf-8") as file:
            json.dump(payload, file, indent=2)

    @staticmethod
    def _resolve_stats_path(path: str, model_dir: str) -> str:
        """Resolve a stats file path, falling back to the checkpoint directory.

        The module-relative resolution of ``_resolve_local_path`` is kept when
        the resulting file exists. Otherwise ``path`` is looked up inside
        ``model_dir``; if that does not exist either, the module-relative
        result is returned so downstream errors name the original location.
        """
        resolved = _resolve_local_path(path)
        if os.path.exists(resolved):
            return resolved
        # NOTE(review): when this module is loaded through transformers'
        # remote-code mechanism, ``__file__`` may point into a modules cache
        # rather than next to the checkpoint files — confirm.
        candidate = os.path.join(model_dir, path)
        if os.path.exists(candidate):
            return candidate
        return resolved

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "VLARMHFProcessor":
        """Build the processor from a saved checkpoint.

        Configuration precedence:
            * ``processor_config.json`` entries (``vlarm_processor_config``,
              ``vlm_processor_config``, ``base_processor_model_id``) win when
              present.
            * Otherwise values come from ``config.aram_config`` or, when that is
              empty, from ``src/model_config.yaml`` in the checkpoint.
            * Stats paths always come from the loaded ``VLARMHFConfig``.

        Args:
            pretrained_model_name_or_path: Checkpoint location.
            **kwargs: Forwarded to ``VLARMHFConfig.from_pretrained`` (minus
                ``trust_remote_code``).

        Raises:
            ValueError: If no base processor id can be determined, or the base
                processor lacks ``image_processor`` / ``tokenizer``.
        """
        processor_config: Dict[str, Any] = {}
        processor_config_path = os.path.join(pretrained_model_name_or_path, "processor_config.json")
        if os.path.exists(processor_config_path):
            with open(processor_config_path, "r", encoding="utf-8") as file:
                processor_config = json.load(file) or {}

        yaml_path = os.path.join(pretrained_model_name_or_path, "src", "model_config.yaml")

        config_kwargs = dict(kwargs)
        config_kwargs.pop("trust_remote_code", None)
        try:
            config = VLARMHFConfig.from_pretrained(pretrained_model_name_or_path, **config_kwargs)
        except Exception:
            # Deliberate best-effort: any failure to load config.json falls
            # back to the YAML config shipped with the checkpoint.
            config = VLARMHFConfig.from_aram_yaml(yaml_path)

        # ``or {}`` keeps the YAML fallback reachable when aram_config is None.
        aram_config_dict = dict(getattr(config, "aram_config", None) or {})
        if not aram_config_dict:
            with open(yaml_path, "r", encoding="utf-8") as file:
                aram_config_dict = yaml.safe_load(file) or {}

        # A key present with a null value must behave like a missing key.
        vlm_section = aram_config_dict.get("vlm_config") or {}

        if "vlarm_processor_config" in processor_config:
            vlarm_processor_config_dict = dict(processor_config["vlarm_processor_config"])
        else:
            vlarm_processor_config_dict = dict(aram_config_dict.get("processor_config") or {})

        if "vlm_processor_config" in processor_config:
            vlm_processor_config_dict = dict(processor_config["vlm_processor_config"])
        else:
            vlm_processor_config_dict = dict(vlm_section.get("processor_config") or {})

        # Stats paths are always taken from ``config``, overriding whatever the
        # chosen processor-config dict carried.
        vlarm_processor_config_dict["control_stats_path"] = cls._resolve_stats_path(
            config.control_stats_path, pretrained_model_name_or_path
        )
        vlarm_processor_config_dict["observation_stats_path"] = cls._resolve_stats_path(
            config.observation_stats_path, pretrained_model_name_or_path
        )

        vlarm_processor_config = RegressionProcessorConfig(**vlarm_processor_config_dict)
        vlm_processor_config = PaliGemmaProcessorConfig(**vlm_processor_config_dict)

        base_model_id = (
            processor_config.get("base_processor_model_id")
            or config.base_processor_model_id
            or vlm_section.get("model_id")
        )
        if not base_model_id:
            raise ValueError(
                "Unable to determine base_processor_model_id for VLARMHFProcessor. "
                "Please set it in config.json or processor_config.json."
            )

        hf_processor = transformers.AutoProcessor.from_pretrained(base_model_id)
        if not hasattr(hf_processor, "image_processor") or not hasattr(hf_processor, "tokenizer"):
            raise ValueError(f"Base processor '{base_model_id}' is not a multimodal processor")

        vlm_processor = PaliGemmaProcessor(
            config=vlm_processor_config,
            hf_processor=hf_processor,
        )
        vlarm_processor = VLARMProcessor(
            config=vlarm_processor_config,
            vlm_processor=vlm_processor,
        )
        return cls(
            hf_processor=hf_processor,
            vlarm_processor=vlarm_processor,
            base_processor_model_id=base_model_id,
        )