Reinforcement Learning
Transformers
English
robotics
vla
vision-language-action
openvla
omnivla
robot
qwen
dinov2
siglip
Instructions to use theguy21/openvla-micro with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use theguy21/openvla-micro with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("theguy21/openvla-micro", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 6,288 Bytes
af59633 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 | #!/usr/bin/env python3
"""
OpenVLA-Micro ROS2 node for diff-drive robot (Nova Carter).

Subscribes:
    /rgb/image_raw   - Kinect RGB
    /depth/image_raw - Kinect depth (obstacle avoidance)
    /vla_instruction - text command
Publishes:
    /cmd_vel - Twist
Run:
    python3 vla_micro_ros2_node.py

First-time setup will load & cache model components (~30s load, once).
Subsequent starts load in ~3-5s.
"""
import os, time, threading
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import numpy as np
from PIL import Image
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import Image as RosImage
from geometry_msgs.msg import Twist
from std_msgs.msg import String
from cv_bridge import CvBridge
import torch
# Directory used for the cached model components; resolved next to this script.
CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_cache")

# -- Diff-drive params ---------------------------------------------------
# Gain applied to action[0] to obtain the forward velocity command.
VX_GAIN = 0.15
# Gain applied to action[3] to obtain the angular velocity command.
WZ_GAIN = 0.8
# Symmetric clamp on the forward velocity command (presumably m/s - confirm).
VX_MAX = 0.3
# Symmetric clamp on the angular velocity command (presumably rad/s - confirm).
WZ_MAX = 0.6
# Control-loop rate: the tick() timer period is 1.0 / CTRL_RATE seconds.
CTRL_RATE = 3.0
# Seconds without a published command (while instruction or image is missing)
# before tick() publishes a zero Twist.
STOP_TIMEOUT = 1.0

# -- Obstacle avoidance --------------------------------------------------
# Distances are compared against depth readings divided by 1000, i.e. metres
# if the depth image is in millimetres (assumed from the 16UC1 request - confirm).
# Below this distance the node publishes a stop instead of the model command.
OBSTACLE_MIN_DIST = 0.3
# Below this distance the forward velocity is scaled down linearly.
OBSTACLE_BRAKE_DIST = 0.8
# Fraction of the image height, centred vertically, inspected for obstacles.
DEPTH_CENTER_STRIP = 0.3
def _load_model(device="cpu"):
    """Load OpenVLA-Micro from the local component cache, creating it if needed.

    If ``CACHE_DIR`` does not exist yet, the model is first instantiated via
    ``OpenVLAMicro.from_pretrained`` and the ``model_cache`` directory found in
    the current working directory is copied to ``CACHE_DIR``. The model that is
    returned always comes from ``load_openvla_micro(cache, device=device)``.

    Args:
        device: Device string forwarded to both loaders (default ``"cpu"``).

    Returns:
        Whatever ``fast_load.load_openvla_micro`` returns for the cache directory.

    Raises:
        FileNotFoundError: If the first-run step left no ``model_cache``
            directory in the current working directory to copy from.
    """
    from fast_load import load_openvla_micro

    cache = os.path.expanduser(CACHE_DIR)
    if not os.path.isdir(cache):
        # First run: download from HF and cache components.
        print("First-time setup: downloading & caching model...", flush=True)
        import shutil

        from modeling_openvla_micro import OpenVLAMicro

        # Called only for its side effect; the instance is deliberately not
        # bound so it can be freed before the cached copy is loaded below.
        # NOTE(review): presumably this call writes ./model_cache - confirm.
        OpenVLAMicro.from_pretrained("theguy21/openvla-micro", device=device)

        produced = os.path.abspath("model_cache")
        # When the script is started from its own directory, ./model_cache *is*
        # CACHE_DIR; copying a tree onto itself makes copytree raise, so skip it.
        if os.path.realpath(produced) != os.path.realpath(cache):
            shutil.copytree(produced, cache, dirs_exist_ok=True)
    return load_openvla_micro(cache, device=device)
class OpenVLAMicroNode(Node):
    """ROS2 node that converts OpenVLA-Micro action predictions into Twist commands.

    Subscribes to ``/vla_instruction`` (String), ``/rgb/image_raw`` and
    ``/depth/image_raw`` (Image) and publishes Twist messages on ``/cmd_vel``.
    A timer running at ``CTRL_RATE`` calls :meth:`tick`, which runs the model on
    the latest RGB frame and instruction, clamps the resulting command and
    applies a depth-based brake / stop before publishing.
    """

    def __init__(self):
        """Create publisher, subscriptions and timer, and load the model on CPU."""
        super().__init__("vla_micro_node")
        self.cmd_pub = self.create_publisher(Twist, "/cmd_vel", 10)

        # Latest inputs: written by the subscription callbacks, read by tick()
        # and check_obstacles(); access is guarded by self.lock.
        self.instruction = None
        self.latest_rgb = None
        self.latest_depth = None
        self.lock = threading.Lock()
        self.br = CvBridge()

        self.create_subscription(String, "/vla_instruction", self.instruction_cb, 10)
        self.create_subscription(RosImage, "/rgb/image_raw", self.image_cb, 10)
        self.create_subscription(RosImage, "/depth/image_raw", self.depth_cb, 10)

        self.get_logger().info("Loading OpenVLA-Micro...")
        t0 = time.time()
        self.model = _load_model(device="cpu")
        n_params = sum(p.numel() for p in self.model.parameters()) / 1e6
        self.get_logger().info(
            f"Loaded in {time.time()-t0:.0f}s ({n_params:.0f}M params)"
        )

        self.last_cmd_time = time.time()
        self.timer = self.create_timer(1.0 / CTRL_RATE, self.tick)

    def instruction_cb(self, msg):
        """Store the stripped text of an incoming instruction message and log it."""
        with self.lock:
            self.instruction = msg.data.strip()
        self.get_logger().info(f"Instr: {self.instruction}")

    def image_cb(self, msg):
        """Store the latest RGB frame, converted by cv_bridge with ``rgb8`` encoding."""
        with self.lock:
            self.latest_rgb = self.br.imgmsg_to_cv2(msg, desired_encoding="rgb8")

    def depth_cb(self, msg):
        """Store the latest depth frame, converted by cv_bridge with ``16UC1`` encoding."""
        with self.lock:
            self.latest_depth = self.br.imgmsg_to_cv2(msg, desired_encoding="16UC1")

    @staticmethod
    def _min_range(region):
        """Return the smallest non-zero reading in *region* divided by 1000, or 99.0.

        Zero readings are treated as invalid and ignored; 99.0 is the
        "nothing seen" sentinel used when no valid reading remains.
        """
        valid = region[region > 0]
        if valid.size == 0:
            return 99.0
        return float(np.min(valid.astype(np.float32) / 1000.0))

    def check_obstacles(self):
        """Return ``(closest, left_min, right_min)`` from the latest depth frame.

        Only a horizontal band of relative height ``DEPTH_CENTER_STRIP`` around
        the image centre is inspected. ``closest`` is the minimum over the whole
        band, ``left_min`` / ``right_min`` over its left and right thirds.
        Values are raw readings divided by 1000 (metres if the depth image is
        in millimetres - assumed, confirm against the camera driver). Each
        value is 99.0 when there is no depth frame or no non-zero reading.
        """
        # Snapshot under the lock, consistent with every other access to the
        # callback-written state. Must not be called while holding self.lock.
        with self.lock:
            depth = self.latest_depth
        if depth is None:
            return 99.0, 99.0, 99.0

        h, w = depth.shape
        top = int(h * (0.5 - DEPTH_CENTER_STRIP / 2))
        bot = int(h * (0.5 + DEPTH_CENTER_STRIP / 2))
        center = depth[top:bot, :]
        return (
            self._min_range(center),
            self._min_range(center[:, : w // 3]),
            self._min_range(center[:, 2 * w // 3 :]),
        )

    def stop(self):
        """Publish a zero Twist and record the time of this command."""
        self.cmd_pub.publish(Twist())
        self.last_cmd_time = time.time()

    def tick(self):
        """Run one control step: predict an action, apply safety limits, publish.

        Without an instruction or RGB frame nothing is predicted; a zero Twist
        is published once ``STOP_TIMEOUT`` seconds have passed since the last
        command.
        """
        with self.lock:
            instr = self.instruction
            rgb = self.latest_rgb
        if instr is None or rgb is None:
            if time.time() - self.last_cmd_time > STOP_TIMEOUT:
                self.stop()
            return

        pil_img = Image.fromarray(rgb, mode="RGB")
        t0 = time.perf_counter()
        with torch.no_grad():
            action = self.model.predict_action(pil_img, instr)
        dt = time.perf_counter() - t0

        # 7-DoF: [x, y, z, roll, pitch, yaw, gripper] -> diff-drive
        # NOTE(review): index 3 is roll in the layout above; yaw would be
        # index 5. Confirm which component the model emits for turning.
        vx = float(action[0]) * VX_GAIN
        wz = float(action[3]) * WZ_GAIN

        # A NaN would pass through the max/min clamp below as +VX_MAX / +WZ_MAX,
        # i.e. full speed, so refuse to act on non-finite predictions.
        if not (np.isfinite(vx) and np.isfinite(wz)):
            self.get_logger().warn("Non-finite action from model - stop")
            self.stop()
            return

        vx = max(-VX_MAX, min(VX_MAX, vx))
        wz = max(-WZ_MAX, min(WZ_MAX, wz))

        # Obstacle avoidance
        closest, left_min, right_min = self.check_obstacles()
        if closest < OBSTACLE_MIN_DIST:
            self.get_logger().warn(f"Obstacle {closest:.2f}m - stop")
            self.stop()
            return
        if closest < OBSTACLE_BRAKE_DIST:
            # Linear ramp: 0 at OBSTACLE_MIN_DIST, 1 at OBSTACLE_BRAKE_DIST.
            brake = (closest - OBSTACLE_MIN_DIST) / (OBSTACLE_BRAKE_DIST - OBSTACLE_MIN_DIST)
            brake = max(0.0, min(1.0, brake))
            vx *= brake
            # Only override the heading when the model is not already turning
            # decisively. Positive angular.z is a counter-clockwise (left) turn
            # under REP 103, so steer *away* from the nearer side.
            if abs(wz) < 0.3:
                if left_min < right_min:
                    wz = -0.5
                elif right_min < left_min:
                    wz = 0.5
            self.get_logger().info(f"Brake {brake:.2f} obs={closest:.2f}m")

        cmd = Twist()
        cmd.linear.x = vx
        cmd.angular.z = wz
        self.cmd_pub.publish(cmd)
        self.last_cmd_time = time.time()
        self.get_logger().info(
            f"vx={vx:.3f} wz={wz:.3f} ({dt*1000:.0f}ms) [{instr[:40]}]"
        )
def main():
    """Initialise rclpy, spin the node until interrupted, then clean up.

    Cleanup runs in ``finally`` so the robot is sent a stop command and the
    node is destroyed on any exit from ``spin``, not only on Ctrl-C.
    """
    rclpy.init()
    node = OpenVLAMicroNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        # Only publish / shut down while the context is still valid; it may
        # already have been shut down by the time spin() returns.
        if rclpy.ok():
            node.stop()
        node.destroy_node()
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == "__main__":
    main()
|