Instructions to use INSAIT-Institute/arvla-bridge with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use INSAIT-Institute/arvla-bridge with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("INSAIT-Institute/arvla-bridge", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| import importlib | |
| import json | |
| import os | |
| from typing import Any, Dict, List, Optional | |
| import numpy as np | |
| import transformers | |
| import yaml | |
| from transformers.processing_utils import ProcessorMixin | |
| try: | |
| from .configuration_vlarm_hf import VLARMHFConfig | |
| except Exception: | |
| VLARMHFConfig = importlib.import_module("configuration_vlarm_hf").VLARMHFConfig | |
| try: | |
| from .configuration_vlarm import PaliGemmaProcessorConfig, RegressionProcessorConfig | |
| from .processing_vlarm import PaliGemmaProcessor, VLARMProcessor | |
| except Exception: | |
| try: | |
| PaliGemmaProcessorConfig = importlib.import_module("configuration_vlarm").PaliGemmaProcessorConfig | |
| RegressionProcessorConfig = importlib.import_module("configuration_vlarm").RegressionProcessorConfig | |
| PaliGemmaProcessor = importlib.import_module("processing_vlarm").PaliGemmaProcessor | |
| VLARMProcessor = importlib.import_module("processing_vlarm").VLARMProcessor | |
| except Exception: | |
| PaliGemmaProcessorConfig = importlib.import_module("src.configuration_vlarm").PaliGemmaProcessorConfig | |
| RegressionProcessorConfig = importlib.import_module("src.configuration_vlarm").RegressionProcessorConfig | |
| PaliGemmaProcessor = importlib.import_module("src.processing_vlarm").PaliGemmaProcessor | |
| VLARMProcessor = importlib.import_module("src.processing_vlarm").VLARMProcessor | |
def _resolve_local_path(path: str) -> str:
    """Return ``path`` as-is when absolute, else anchored at this module's directory.

    Args:
        path: An absolute path, or a path relative to the directory that
            contains this source file.

    Returns:
        ``path`` unchanged if ``os.path.isabs(path)`` is true; otherwise
        ``path`` joined onto ``os.path.dirname(__file__)``.
    """
    if not os.path.isabs(path):
        # Relative paths are interpreted relative to this file, not the CWD.
        module_dir = os.path.dirname(__file__)
        return os.path.join(module_dir, path)
    return path
class VLARMHFProcessor(ProcessorMixin):
    """Hugging Face ``ProcessorMixin`` wrapper around a ``VLARMProcessor``.

    The instance keeps three things together: the base Hugging Face processor
    (exposed through ``image_processor`` / ``tokenizer``), the project's
    ``VLARMProcessor`` that does the actual input preprocessing, and the model
    id the base processor was loaded from so it can be reloaded later.
    """

    attributes = ["image_processor", "tokenizer"]

    def __init__(
        self,
        hf_processor: transformers.ProcessorMixin,
        vlarm_processor: VLARMProcessor,
        base_processor_model_id: Optional[str],
    ):
        """Store the wrapped processors.

        Args:
            hf_processor: Base processor; must expose ``image_processor`` and
                ``tokenizer`` attributes.
            vlarm_processor: Processor that ``preprocess_inputs`` delegates to.
            base_processor_model_id: Identifier used to reload ``hf_processor``;
                written to ``processor_config.json`` by ``save_pretrained``.
        """
        # ProcessorMixin.__init__ is not invoked; the names listed in
        # ``attributes`` are assigned directly from the wrapped processor.
        self.hf_processor = hf_processor
        self.vlarm_processor = vlarm_processor
        self.base_processor_model_id = base_processor_model_id
        self.image_processor = hf_processor.image_processor
        self.tokenizer = hf_processor.tokenizer

    def preprocess_inputs(
        self,
        chat: List[str],
        images: Dict[str, List[Any]],
        ee_pose_translation: np.ndarray,
        ee_pose_rotation: np.ndarray,
        gripper: np.ndarray,
        joints: np.ndarray,
        dataset_name: np.ndarray,
        inference_mode: bool = True,
    ) -> Dict[str, Any]:
        """Forward a batch to the wrapped ``VLARMProcessor``.

        ``dataset_name`` is first registered on the wrapped processor via
        ``_set_dataset_np_names``; every argument is then passed through
        unchanged to ``VLARMProcessor.preprocess_inputs``. Expected array
        shapes and dict layouts are defined by that method, not here.

        Returns:
            Whatever ``VLARMProcessor.preprocess_inputs`` returns.
        """
        self.vlarm_processor._set_dataset_np_names(dataset_name)
        return self.vlarm_processor.preprocess_inputs(
            chat=chat,
            images=images,
            ee_pose_translation=ee_pose_translation,
            ee_pose_rotation=ee_pose_rotation,
            gripper=gripper,
            joints=joints,
            dataset_name=dataset_name,
            inference_mode=inference_mode,
        )

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save the base processor and write ``processor_config.json``.

        Args:
            save_directory: Target directory; created if it does not exist.
            **kwargs: Forwarded to the base processor's ``save_pretrained``.
        """
        os.makedirs(save_directory, exist_ok=True)
        self.hf_processor.save_pretrained(save_directory, **kwargs)
        # NOTE(review): from_pretrained feeds the two ``*_processor_config``
        # values to dict(...), so ``as_json()`` is presumably expected to
        # return a mapping rather than a JSON string -- confirm.
        payload = {
            "processor_class": self.__class__.__name__,
            "base_processor_model_id": self.base_processor_model_id,
            "vlarm_processor_config": self.vlarm_processor.config.as_json(),
            "vlm_processor_config": self.vlarm_processor.vlm_processor.config.as_json(),
            "auto_map": {
                "AutoProcessor": "processing_vlarm_hf.VLARMHFProcessor"
            },
        }
        # Written after the base processor is saved, so this file wins if the
        # base processor produced a ``processor_config.json`` of its own.
        config_path = os.path.join(save_directory, "processor_config.json")
        with open(config_path, "w", encoding="utf-8") as file:
            json.dump(payload, file, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "VLARMHFProcessor":
        """Build a processor from a saved directory.

        Settings are resolved in this order: values in
        ``processor_config.json`` (if the file exists) override values from the
        model config's ``aram_config``, which in turn falls back to
        ``src/model_config.yaml`` under ``pretrained_model_name_or_path``.

        Args:
            pretrained_model_name_or_path: Location of the saved files. It is
                joined with file names via ``os.path.join`` for the JSON/YAML
                lookups and also handed to ``VLARMHFConfig.from_pretrained``.
            **kwargs: Forwarded to ``VLARMHFConfig.from_pretrained`` after
                ``trust_remote_code`` is removed.

        Returns:
            A new instance of ``cls``.

        Raises:
            ValueError: If no base processor model id can be determined, or
                the base processor lacks ``image_processor`` / ``tokenizer``.
        """
        processor_config_path = os.path.join(pretrained_model_name_or_path, "processor_config.json")
        processor_config: Dict[str, Any] = {}
        if os.path.exists(processor_config_path):
            with open(processor_config_path, "r", encoding="utf-8") as file:
                # ``or {}`` keeps the membership tests below valid for a
                # file that contains JSON ``null``.
                processor_config = json.load(file) or {}

        config_kwargs = dict(kwargs)
        config_kwargs.pop("trust_remote_code", None)
        yaml_path = os.path.join(pretrained_model_name_or_path, "src", "model_config.yaml")
        try:
            config = VLARMHFConfig.from_pretrained(pretrained_model_name_or_path, **config_kwargs)
        except Exception:
            # Best-effort fallback: any failure to load the HF config makes us
            # try the YAML model config shipped under ``src/`` instead.
            config = VLARMHFConfig.from_aram_yaml(yaml_path)

        # A missing/None ``aram_config`` is treated like an empty one so the
        # YAML fallback below is used instead of failing in dict(None).
        aram_config_dict = dict(config.aram_config or {})
        if not aram_config_dict:
            with open(yaml_path, "r", encoding="utf-8") as file:
                aram_config_dict = yaml.safe_load(file) or {}
        vlm_section = aram_config_dict.get("vlm_config") or {}

        # processor_config.json takes precedence over the model config.
        if "vlarm_processor_config" in processor_config:
            vlarm_processor_config_dict = dict(processor_config["vlarm_processor_config"])
        else:
            vlarm_processor_config_dict = dict(aram_config_dict.get("processor_config") or {})
        if "vlm_processor_config" in processor_config:
            vlm_processor_config_dict = dict(processor_config["vlm_processor_config"])
        else:
            vlm_processor_config_dict = dict(vlm_section.get("processor_config") or {})

        # The stats paths always come from the model config, whichever source
        # supplied the rest of the processor settings.
        vlarm_processor_config_dict["control_stats_path"] = _resolve_local_path(
            config.control_stats_path
        )
        vlarm_processor_config_dict["observation_stats_path"] = _resolve_local_path(
            config.observation_stats_path
        )

        vlarm_processor_config = RegressionProcessorConfig(**vlarm_processor_config_dict)
        vlm_processor_config = PaliGemmaProcessorConfig(**vlm_processor_config_dict)

        base_model_id = (
            processor_config.get("base_processor_model_id")
            or config.base_processor_model_id
            or vlm_section.get("model_id")
        )
        if not base_model_id:
            raise ValueError(
                "Unable to determine base_processor_model_id for VLARMHFProcessor. "
                "Please set it in config.json or processor_config.json."
            )

        hf_processor = transformers.AutoProcessor.from_pretrained(base_model_id)
        if not hasattr(hf_processor, "image_processor") or not hasattr(hf_processor, "tokenizer"):
            raise ValueError(f"Base processor '{base_model_id}' is not a multimodal processor")

        vlm_processor = PaliGemmaProcessor(
            config=vlm_processor_config,
            hf_processor=hf_processor,
        )
        vlarm_processor = VLARMProcessor(
            config=vlarm_processor_config,
            vlm_processor=vlm_processor,
        )
        return cls(
            hf_processor=hf_processor,
            vlarm_processor=vlarm_processor,
            base_processor_model_id=base_model_id,
        )