Reinforcement Learning
Transformers
English
robotics
vla
vision-language-action
openvla
omnivla
robot
qwen
dinov2
siglip
Instructions to use theguy21/openvla-micro with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use theguy21/openvla-micro with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("theguy21/openvla-micro", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| """ | |
| OpenVLA-Micro ROS2 node for diff-drive robot (Nova Carter). | |
| Subscribes: | |
| /rgb/image_raw — Kinect RGB | |
| /depth/image_raw — Kinect depth (obstacle avoidance) | |
| /vla_instruction — text command | |
| Publishes: | |
| /cmd_vel — Twist | |
| Run: | |
| python3 vla_micro_ros2_node.py | |
| First-time setup will load & cache model components (~30s load, once). | |
| Subsequent starts load in ~3-5s. | |
| """ | |
| import os, time, threading | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| import numpy as np | |
| from PIL import Image | |
| import rclpy | |
| from rclpy.node import Node | |
| from sensor_msgs.msg import Image as RosImage | |
| from geometry_msgs.msg import Twist | |
| from std_msgs.msg import String | |
| from cv_bridge import CvBridge | |
| import torch | |
# Cached model components live in a "model_cache" folder next to this script.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
CACHE_DIR = os.path.join(_SCRIPT_DIR, "model_cache")

# ── Diff-drive params ───────────────────────────────────────────
VX_GAIN = 0.15      # scale applied to action[0] before it becomes linear.x
WZ_GAIN = 0.8       # scale applied to action[3] before it becomes angular.z
VX_MAX = 0.3        # symmetric clamp on |linear.x|
WZ_MAX = 0.6        # symmetric clamp on |angular.z|
CTRL_RATE = 3.0     # control-loop frequency; timer period is 1 / CTRL_RATE s
STOP_TIMEOUT = 1.0  # idle seconds after the last command before a zero Twist is re-sent

# ── Obstacle avoidance ──────────────────────────────────────────
OBSTACLE_MIN_DIST = 0.3    # metres: nearer than this -> full stop
OBSTACLE_BRAKE_DIST = 0.8  # metres: between MIN and this -> forward speed scaled down
DEPTH_CENTER_STRIP = 0.3   # height fraction of the vertically centred depth band inspected
def _load_model(device="cpu"):
    """Load the OpenVLA-Micro model, populating the on-disk cache on first use.

    If ``CACHE_DIR`` does not exist yet, ``OpenVLAMicro.from_pretrained`` is
    called once for "theguy21/openvla-micro", and the ``model_cache``
    directory in the current working directory is copied into ``CACHE_DIR``.
    The model is then always loaded from the cache via
    ``fast_load.load_openvla_micro``.

    Args:
        device: Device string forwarded to both loaders (default ``"cpu"``).

    Returns:
        Whatever ``load_openvla_micro(cache, device=device)`` returns.
    """
    import shutil

    from fast_load import load_openvla_micro

    cache = os.path.expanduser(CACHE_DIR)
    if not os.path.isdir(cache):
        # First run: download from HF and cache components
        print("First-time setup: downloading & caching model...", flush=True)
        from modeling_openvla_micro import OpenVLAMicro

        # The returned model was never used; the call is kept for its side
        # effects (NOTE(review): presumably it downloads and writes
        # ./model_cache — confirm). Not binding the result allows this copy to
        # be freed before the cached copy is loaded below, instead of holding
        # two models in memory at once.
        OpenVLAMicro.from_pretrained("theguy21/openvla-micro", device=device)

        src = os.path.abspath("model_cache")
        # When the script is launched from its own directory, src and cache
        # are the same folder; copytree would then fail on every file with
        # SameFileError, so only copy when they genuinely differ.
        if os.path.realpath(src) != os.path.realpath(cache):
            shutil.copytree(src, cache, dirs_exist_ok=True)
    return load_openvla_micro(cache, device=device)
class OpenVLAMicroNode(Node):
    """ROS 2 node mapping camera frames + a text instruction to ``/cmd_vel``.

    Subscribes to ``/vla_instruction`` (String), ``/rgb/image_raw`` and
    ``/depth/image_raw`` (Image). Every ``1 / CTRL_RATE`` seconds ``tick`` runs
    the model on the latest RGB frame and instruction, maps the action to a
    diff-drive Twist, applies depth-based stop/brake/steer, and publishes it.
    """

    # Distance (metres) reported when no valid depth pixel is available.
    NO_OBSTACLE_DIST = 99.0

    def __init__(self):
        super().__init__("vla_micro_node")
        self.cmd_pub = self.create_publisher(Twist, "/cmd_vel", 10)
        # Latest inputs: written by the subscription callbacks, read by tick()
        # and check_obstacles(), always under self.lock.
        self.instruction = None
        self.latest_rgb = None
        self.latest_depth = None
        self.lock = threading.Lock()
        self.br = CvBridge()
        self.create_subscription(String, "/vla_instruction", self.instruction_cb, 10)
        self.create_subscription(RosImage, "/rgb/image_raw", self.image_cb, 10)
        self.create_subscription(RosImage, "/depth/image_raw", self.depth_cb, 10)
        self.get_logger().info("Loading OpenVLA-Micro...")
        t0 = time.time()
        self.model = _load_model(device="cpu")
        n_params = sum(p.numel() for p in self.model.parameters()) / 1e6
        self.get_logger().info(
            f"Loaded in {time.time()-t0:.0f}s ({n_params:.0f}M params)"
        )
        self.last_cmd_time = time.time()
        self.timer = self.create_timer(1.0 / CTRL_RATE, self.tick)

    def instruction_cb(self, msg):
        """Store the stripped instruction text from ``/vla_instruction``."""
        with self.lock:
            self.instruction = msg.data.strip()
        self.get_logger().info(f"Instr: {self.instruction}")

    def image_cb(self, msg):
        """Store the latest RGB frame, converted by cv_bridge to rgb8."""
        with self.lock:
            self.latest_rgb = self.br.imgmsg_to_cv2(msg, desired_encoding="rgb8")

    def depth_cb(self, msg):
        """Store the latest depth frame, converted by cv_bridge to 16UC1."""
        with self.lock:
            self.latest_depth = self.br.imgmsg_to_cv2(msg, desired_encoding="16UC1")

    @classmethod
    def _min_depth_m(cls, region):
        """Return the smallest non-zero value of *region* divided by 1000.

        Zero pixels are treated as invalid. The /1000 scaling treats raw
        values as millimetres (NOTE(review): assumes a mm 16UC1 depth source).
        Returns ``NO_OBSTACLE_DIST`` when *region* has no valid pixel.
        """
        valid = region[region > 0]
        if valid.size == 0:
            return cls.NO_OBSTACLE_DIST
        return float(np.min(valid.astype(np.float32) / 1000.0))

    def check_obstacles(self):
        """Return ``(closest, left_min, right_min)`` obstacle distances in metres.

        Only a vertically centred band, ``DEPTH_CENTER_STRIP`` of the image
        height tall, is inspected. ``closest`` is the minimum over the full
        band width, ``left_min`` / ``right_min`` over its left / right thirds.
        Any value without valid depth (no frame yet, or only zero pixels) is
        ``NO_OBSTACLE_DIST``. Acquires ``self.lock``; do not call while holding it.
        """
        with self.lock:
            depth = self.latest_depth
        if depth is None:
            return (self.NO_OBSTACLE_DIST,) * 3
        h, w = depth.shape[:2]
        top = int(h * (0.5 - DEPTH_CENTER_STRIP / 2))
        bot = int(h * (0.5 + DEPTH_CENTER_STRIP / 2))
        band = depth[top:bot, :]
        return (
            self._min_depth_m(band),
            self._min_depth_m(band[:, : w // 3]),
            self._min_depth_m(band[:, 2 * w // 3 :]),
        )

    def stop(self):
        """Publish a zero Twist and record the publish time."""
        self.cmd_pub.publish(Twist())
        self.last_cmd_time = time.time()

    @staticmethod
    def _action_to_cmd(action):
        """Map a model action to clamped diff-drive ``(vx, wz)``.

        ``vx = action[0] * VX_GAIN`` clamped to ``±VX_MAX``;
        ``wz = action[3] * WZ_GAIN`` clamped to ``±WZ_MAX``.
        """
        # 7-DoF: [x, y, z, roll, pitch, yaw, gripper] → diff-drive
        # NOTE(review): per this layout index 5 is yaw, yet angular.z is driven
        # by index 3 (roll) — confirm against the model's action convention.
        vx = float(action[0]) * VX_GAIN
        wz = float(action[3]) * WZ_GAIN
        vx = max(-VX_MAX, min(VX_MAX, vx))
        wz = max(-WZ_MAX, min(WZ_MAX, wz))
        return vx, wz

    def tick(self):
        """Control-loop step: infer, apply obstacle avoidance, publish ``/cmd_vel``.

        Until both an instruction and an RGB frame have been received, a zero
        Twist is published whenever more than ``STOP_TIMEOUT`` seconds have
        passed since the last published command.
        """
        with self.lock:
            instr = self.instruction
            rgb = self.latest_rgb
        if instr is None or rgb is None:
            if time.time() - self.last_cmd_time > STOP_TIMEOUT:
                self.stop()
            return

        # cv_bridge rgb8 frames are HxWx3 uint8, which Pillow maps to "RGB"
        # by itself; the `mode` argument of fromarray is deprecated.
        pil_img = Image.fromarray(rgb)
        t0 = time.perf_counter()
        with torch.no_grad():
            action = self.model.predict_action(pil_img, instr)
        dt = time.perf_counter() - t0

        vx, wz = self._action_to_cmd(action)

        # Obstacle avoidance
        closest, left_min, right_min = self.check_obstacles()
        if closest < OBSTACLE_MIN_DIST:
            self.get_logger().warn(f"Obstacle {closest:.2f}m — stop")
            self.stop()
            return
        if closest < OBSTACLE_BRAKE_DIST:
            # Linear speed factor: 0 at OBSTACLE_MIN_DIST, 1 at OBSTACLE_BRAKE_DIST.
            brake = (closest - OBSTACLE_MIN_DIST) / (OBSTACLE_BRAKE_DIST - OBSTACLE_MIN_DIST)
            brake = max(0.0, min(1.0, brake))
            vx *= brake
            # Unless the model already turns firmly, steer AWAY from the nearer
            # side. REP 103: +angular.z is counter-clockwise (a left turn), so
            # an obstacle nearer on the left calls for a negative wz.
            # NOTE(review): assumes the depth image is not mirrored
            # (image-left == robot-left).
            if abs(wz) < 0.3:
                if left_min < right_min:
                    wz = -0.5
                elif right_min < left_min:
                    wz = 0.5
            self.get_logger().info(f"Brake {brake:.2f} obs={closest:.2f}m")

        cmd = Twist()
        cmd.linear.x = vx
        cmd.angular.z = wz
        self.cmd_pub.publish(cmd)
        self.last_cmd_time = time.time()
        self.get_logger().info(
            f"vx={vx:.3f} wz={wz:.3f} ({dt*1000:.0f}ms) [{instr[:40]}]"
        )
def main():
    """Run the node until shutdown, then stop the robot and release resources."""
    rclpy.init()
    node = OpenVLAMicroNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        # Send a zero Twist on every exit path (Ctrl-C or an exception raised
        # from a callback), not only on Ctrl-C, so the base is not left driving
        # on its last command. Best effort: after SIGINT the rclpy context may
        # already be shut down, in which case publishing raises.
        try:
            node.stop()
        except Exception as exc:  # shutdown boundary: report, keep cleaning up
            print(f"Could not publish stop command: {exc}", flush=True)
        node.destroy_node()
        # The SIGINT handler may already have shut the context down.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == "__main__":
    main()