"""Hugging Face ``ProcessorMixin`` wrapper around the VLARM processing pipeline.

The wrapper exposes the base multimodal processor's ``image_processor`` and
``tokenizer`` and delegates robot-input preprocessing to ``VLARMProcessor``.
"""

from __future__ import annotations

import importlib
import json
import os
from typing import Any, Dict, List, Optional

import numpy as np
import transformers
import yaml
from transformers.processing_utils import ProcessorMixin

# Sibling modules may be importable under different names depending on how this
# file is loaded: as part of a package (relative import), as a flat top-level
# module, or from a ``src`` package. Each layout is tried in turn.
try:
    from .configuration_vlarm_hf import VLARMHFConfig
except Exception:
    VLARMHFConfig = importlib.import_module("configuration_vlarm_hf").VLARMHFConfig

try:
    from .configuration_vlarm import PaliGemmaProcessorConfig, RegressionProcessorConfig
    from .processing_vlarm import PaliGemmaProcessor, VLARMProcessor
except Exception:
    try:
        _configuration_vlarm = importlib.import_module("configuration_vlarm")
        _processing_vlarm = importlib.import_module("processing_vlarm")
        PaliGemmaProcessorConfig = _configuration_vlarm.PaliGemmaProcessorConfig
        RegressionProcessorConfig = _configuration_vlarm.RegressionProcessorConfig
        PaliGemmaProcessor = _processing_vlarm.PaliGemmaProcessor
        VLARMProcessor = _processing_vlarm.VLARMProcessor
    except Exception:
        _configuration_vlarm = importlib.import_module("src.configuration_vlarm")
        _processing_vlarm = importlib.import_module("src.processing_vlarm")
        PaliGemmaProcessorConfig = _configuration_vlarm.PaliGemmaProcessorConfig
        RegressionProcessorConfig = _configuration_vlarm.RegressionProcessorConfig
        PaliGemmaProcessor = _processing_vlarm.PaliGemmaProcessor
        VLARMProcessor = _processing_vlarm.VLARMProcessor


def _resolve_local_path(path: str, base_dir: Optional[str] = None) -> str:
    """Resolve a possibly-relative resource path to a filesystem path.

    Absolute paths are returned unchanged. A relative path is resolved against
    the directory containing this module. If no file exists there and
    ``base_dir`` is given, the path is also tried relative to ``base_dir``.

    Args:
        path: Absolute path, or a path relative to this module / ``base_dir``.
        base_dir: Optional extra directory (e.g. a local checkpoint directory)
            searched when the module-relative file is missing.

    Returns:
        The first existing candidate. If none exists, the module-relative path,
        which is what earlier versions always returned.
    """
    if os.path.isabs(path):
        return path
    module_relative = os.path.join(os.path.dirname(__file__), path)
    # When this file is executed as remote code, the running module may be a
    # cached copy, so ``__file__`` need not sit next to the checkpoint's data
    # files. Fall back to the caller-supplied directory in that case.
    if base_dir and not os.path.exists(module_relative):
        dir_relative = os.path.join(base_dir, path)
        if os.path.exists(dir_relative):
            return dir_relative
    return module_relative


class VLARMHFProcessor(ProcessorMixin):
    """``transformers`` processor pairing a base multimodal processor with VLARM logic.

    ``image_processor`` and ``tokenizer`` are taken from the wrapped base
    processor. Robot-state preprocessing is delegated to ``vlarm_processor``.
    """

    attributes = ["image_processor", "tokenizer"]

    def __init__(
        self,
        hf_processor: transformers.ProcessorMixin,
        vlarm_processor: VLARMProcessor,
        base_processor_model_id: Optional[str],
    ):
        """Wrap ``hf_processor`` and ``vlarm_processor``.

        Args:
            hf_processor: Base processor. It must expose ``image_processor`` and
                ``tokenizer`` attributes.
            vlarm_processor: Processor that performs VLARM input preprocessing.
            base_processor_model_id: Identifier the base processor was loaded
                from. It is persisted by ``save_pretrained``.
        """
        # NOTE(review): ProcessorMixin.__init__ is not called; the attributes
        # listed in ``attributes`` are assigned directly. Confirm intended.
        self.hf_processor = hf_processor
        self.vlarm_processor = vlarm_processor
        self.base_processor_model_id = base_processor_model_id
        self.image_processor = hf_processor.image_processor
        self.tokenizer = hf_processor.tokenizer

    def preprocess_inputs(
        self,
        chat: List[str],
        images: Dict[str, List[Any]],
        ee_pose_translation: np.ndarray,
        ee_pose_rotation: np.ndarray,
        gripper: np.ndarray,
        joints: np.ndarray,
        dataset_name: np.ndarray,
        inference_mode: bool = True,
    ) -> Dict[str, Any]:
        """Preprocess one set of model inputs via the wrapped VLARM processor.

        ``dataset_name`` is first registered with the VLARM processor via its
        ``_set_dataset_np_names`` hook. All arguments are then forwarded
        unchanged to ``VLARMProcessor.preprocess_inputs``, whose result is
        returned as-is.
        """
        self.vlarm_processor._set_dataset_np_names(dataset_name)
        return self.vlarm_processor.preprocess_inputs(
            chat=chat,
            images=images,
            ee_pose_translation=ee_pose_translation,
            ee_pose_rotation=ee_pose_rotation,
            gripper=gripper,
            joints=joints,
            dataset_name=dataset_name,
            inference_mode=inference_mode,
        )

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save the base processor plus a VLARM ``processor_config.json``.

        Args:
            save_directory: Target directory. It is created if missing.
            **kwargs: Forwarded to the base processor's ``save_pretrained``.
        """
        os.makedirs(save_directory, exist_ok=True)
        self.hf_processor.save_pretrained(save_directory, **kwargs)
        payload = {
            "processor_class": self.__class__.__name__,
            "base_processor_model_id": self.base_processor_model_id,
            "vlarm_processor_config": self.vlarm_processor.config.as_json(),
            "vlm_processor_config": self.vlarm_processor.vlm_processor.config.as_json(),
            "auto_map": {
                "AutoProcessor": "processing_vlarm_hf.VLARMHFProcessor"
            },
        }
        # Written after the base processor in "w" mode, so this file replaces
        # any processor_config.json the base processor may have produced.
        config_path = os.path.join(save_directory, "processor_config.json")
        with open(config_path, "w", encoding="utf-8") as file:
            json.dump(payload, file, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "VLARMHFProcessor":
        """Build a processor from a saved directory (or config-resolvable identifier).

        Resolution order:
          * VLARM / VLM processor configs: the saved ``processor_config.json``
            when it contains them, otherwise the ARAM config (from
            ``config.aram_config``, or ``src/model_config.yaml``).
          * Stats paths: always ``config.control_stats_path`` and
            ``config.observation_stats_path``. Relative paths are resolved
            against this module's directory, then the checkpoint directory.
          * Base processor id: ``processor_config.json``, then
            ``config.base_processor_model_id``, then ``vlm_config.model_id``.

        Args:
            pretrained_model_name_or_path: Directory (or identifier) to load from.
            **kwargs: Forwarded to ``VLARMHFConfig.from_pretrained`` after
                ``trust_remote_code`` is removed.

        Raises:
            ValueError: If no base processor id can be determined, or if the
                base processor lacks ``image_processor`` / ``tokenizer``.
        """
        model_dir = pretrained_model_name_or_path
        yaml_path = os.path.join(model_dir, "src", "model_config.yaml")
        # Only a local directory can hold data files for stats-path fallback.
        search_dir = model_dir if os.path.isdir(model_dir) else None

        processor_config: Dict[str, Any] = {}
        processor_config_path = os.path.join(model_dir, "processor_config.json")
        if os.path.exists(processor_config_path):
            with open(processor_config_path, "r", encoding="utf-8") as file:
                processor_config = json.load(file)

        config_kwargs = dict(kwargs)
        config_kwargs.pop("trust_remote_code", None)
        try:
            config = VLARMHFConfig.from_pretrained(model_dir, **config_kwargs)
        except Exception:
            # Fall back to building the config directly from the ARAM YAML.
            config = VLARMHFConfig.from_aram_yaml(yaml_path)

        aram_config_dict = dict(config.aram_config or {})
        if not aram_config_dict:
            with open(yaml_path, "r", encoding="utf-8") as file:
                aram_config_dict = yaml.safe_load(file) or {}
        # ``or {}`` also tolerates keys that are present but null in the YAML.
        vlm_config = aram_config_dict.get("vlm_config") or {}

        # A saved processor_config.json section replaces the ARAM one wholesale.
        if "vlarm_processor_config" in processor_config:
            vlarm_processor_config_dict = dict(processor_config["vlarm_processor_config"])
        else:
            vlarm_processor_config_dict = dict(aram_config_dict.get("processor_config") or {})
        if "vlm_processor_config" in processor_config:
            vlm_processor_config_dict = dict(processor_config["vlm_processor_config"])
        else:
            vlm_processor_config_dict = dict(vlm_config.get("processor_config") or {})

        # Stats paths always come from ``config``, never from the saved payload.
        vlarm_processor_config_dict["control_stats_path"] = _resolve_local_path(
            config.control_stats_path, search_dir
        )
        vlarm_processor_config_dict["observation_stats_path"] = _resolve_local_path(
            config.observation_stats_path, search_dir
        )

        vlarm_processor_config = RegressionProcessorConfig(**vlarm_processor_config_dict)
        vlm_processor_config = PaliGemmaProcessorConfig(**vlm_processor_config_dict)

        base_model_id = (
            processor_config.get("base_processor_model_id")
            or config.base_processor_model_id
            or vlm_config.get("model_id")
        )
        if not base_model_id:
            raise ValueError(
                "Unable to determine base_processor_model_id for VLARMHFProcessor. "
                "Please set it in config.json or processor_config.json."
            )

        hf_processor = transformers.AutoProcessor.from_pretrained(base_model_id)
        if not hasattr(hf_processor, "image_processor") or not hasattr(hf_processor, "tokenizer"):
            raise ValueError(f"Base processor '{base_model_id}' is not a multimodal processor")

        vlm_processor = PaliGemmaProcessor(
            config=vlm_processor_config,
            hf_processor=hf_processor,
        )
        vlarm_processor = VLARMProcessor(
            config=vlarm_processor_config,
            vlm_processor=vlm_processor,
        )
        return cls(
            hf_processor=hf_processor,
            vlarm_processor=vlarm_processor,
            base_processor_model_id=base_model_id,
        )