theguy21 committed on
Commit
af59633
·
verified ·
1 Parent(s): 1d2bcdf

Add ROS2 node for diff-drive robot deployment

Browse files
Files changed (1) hide show
  1. vla_micro_ros2_node.py +193 -0
vla_micro_ros2_node.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ OpenVLA-Micro ROS2 node for diff-drive robot (Nova Carter).
4
+
5
+ Subscribes:
6
+ /rgb/image_raw — Kinect RGB
7
+ /depth/image_raw — Kinect depth (obstacle avoidance)
8
+ /vla_instruction — text command
9
+
10
+ Publishes:
11
+ /cmd_vel — Twist
12
+
13
+ Run:
14
+ python3 vla_micro_ros2_node.py
15
+
16
+ First-time setup will load & cache model components (~30s load, once).
17
+ Subsequent starts load in ~3-5s.
18
+ """
19
+ import os, time, threading
20
+
21
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
22
+
23
+ import numpy as np
24
+ from PIL import Image
25
+ import rclpy
26
+ from rclpy.node import Node
27
+ from sensor_msgs.msg import Image as RosImage
28
+ from geometry_msgs.msg import Twist
29
+ from std_msgs.msg import String
30
+ from cv_bridge import CvBridge
31
+ import torch
32
+
33
# Directory next to this script where cached model components are looked up.
CACHE_DIR: str = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "model_cache",
)

# ── Diff-drive params ───────────────────────────────────────────
VX_GAIN: float = 0.15      # scale applied to the model's action[0] to get linear.x
WZ_GAIN: float = 0.8       # scale applied to the model's action[3] to get angular.z
VX_MAX: float = 0.3        # symmetric clamp on linear.x (presumably m/s — confirm)
WZ_MAX: float = 0.6        # symmetric clamp on angular.z (presumably rad/s — confirm)
CTRL_RATE: float = 3.0     # control-loop frequency; the timer period is 1 / CTRL_RATE seconds
STOP_TIMEOUT: float = 1.0  # seconds without a command before an idle tick publishes a stop

# ── Obstacle avoidance ──────────────────────────────────────────
OBSTACLE_MIN_DIST: float = 0.3    # closer than this: publish a stop instead of a command
OBSTACLE_BRAKE_DIST: float = 0.8  # closer than this: scale linear.x down linearly
DEPTH_CENTER_STRIP: float = 0.3   # fraction of image height, centred vertically, that is inspected
47
+
48
+
49
def _load_model(device="cpu"):
    """Load OpenVLA-Micro from the local component cache, creating it on first run.

    If ``CACHE_DIR`` does not exist yet, the model is first fetched with
    ``OpenVLAMicro.from_pretrained`` and the ``model_cache`` directory found in
    the current working directory is copied to ``CACHE_DIR``. The model is then
    always loaded from ``CACHE_DIR`` via ``fast_load.load_openvla_micro``.

    Args:
        device: Device string forwarded to both loaders (default ``"cpu"``).

    Returns:
        The object returned by ``load_openvla_micro(cache, device=device)``.
    """
    # Project-local modules are imported lazily, only when a model is requested.
    from fast_load import load_openvla_micro

    cache = os.path.expanduser(CACHE_DIR)
    if not os.path.isdir(cache):
        # First run: download from HF and cache components.
        print("First-time setup: downloading & caching model...", flush=True)
        import shutil

        from modeling_openvla_micro import OpenVLAMicro

        # The returned model is not used here.
        # NOTE(review): this call is presumably relied on to write ./model_cache
        # as a side effect — confirm against modeling_openvla_micro.
        OpenVLAMicro.from_pretrained("theguy21/openvla-micro", device=device)

        # When the process runs from the script's own directory the relative
        # "model_cache" already IS the cache directory; copying a tree onto
        # itself makes shutil.copytree raise, so only copy distinct paths.
        src = "model_cache"
        if os.path.realpath(src) != os.path.realpath(cache):
            shutil.copytree(src, cache, dirs_exist_ok=True)
    return load_openvla_micro(cache, device=device)
62
+
63
+
64
class OpenVLAMicroNode(Node):
    """ROS2 node that converts OpenVLA-Micro action predictions into ``/cmd_vel``.

    Subscriptions store the latest instruction string, RGB frame and depth
    frame. A timer running at ``CTRL_RATE`` Hz calls :meth:`tick`, which runs
    the model on the latest RGB frame + instruction, maps the predicted action
    to a clamped ``Twist`` and applies depth-based braking/stopping before
    publishing.
    """

    # Distance reported when no depth frame or no valid depth pixel exists.
    NO_OBSTACLE_DIST = 99.0
    # While braking, an avoidance turn is injected only if the model's own
    # |wz| is below this threshold.
    AVOID_WZ_DEADBAND = 0.3
    # Magnitude of the injected avoidance turn rate.
    AVOID_WZ = 0.5

    def __init__(self):
        """Create publisher/subscriptions, load the model and start the control timer."""
        super().__init__("vla_micro_node")
        self.cmd_pub = self.create_publisher(Twist, "/cmd_vel", 10)

        # State shared between the subscription callbacks and tick();
        # always accessed while holding self.lock.
        self.instruction = None
        self.latest_rgb = None
        self.latest_depth = None
        self.lock = threading.Lock()
        self.br = CvBridge()

        self.create_subscription(String, "/vla_instruction", self.instruction_cb, 10)
        self.create_subscription(RosImage, "/rgb/image_raw", self.image_cb, 10)
        self.create_subscription(RosImage, "/depth/image_raw", self.depth_cb, 10)

        self.get_logger().info("Loading OpenVLA-Micro...")
        t0 = time.time()
        self.model = _load_model(device="cpu")
        n_params = sum(p.numel() for p in self.model.parameters()) / 1e6
        self.get_logger().info(
            f"Loaded in {time.time()-t0:.0f}s ({n_params:.0f}M params)"
        )

        # The timer is created last so tick() never runs before the model exists.
        self.last_cmd_time = time.time()
        self.timer = self.create_timer(1.0 / CTRL_RATE, self.tick)

    def instruction_cb(self, msg):
        """Store the stripped instruction text from a ``String`` message and log it."""
        text = msg.data.strip()
        with self.lock:
            self.instruction = text
        self.get_logger().info(f"Instr: {text}")

    def image_cb(self, msg):
        """Convert an incoming RGB image message to an ``rgb8`` array and keep the latest."""
        # Convert outside the lock so tick() is never blocked on the conversion.
        frame = self.br.imgmsg_to_cv2(msg, desired_encoding="rgb8")
        with self.lock:
            self.latest_rgb = frame

    def depth_cb(self, msg):
        """Convert an incoming depth image message to a ``16UC1`` array and keep the latest."""
        frame = self.br.imgmsg_to_cv2(msg, desired_encoding="16UC1")
        with self.lock:
            self.latest_depth = frame

    @staticmethod
    def _nearest_m(region, default):
        """Return the smallest positive value in *region* divided by 1000, or *default* if none."""
        valid = region[region > 0]  # zero readings are treated as "no measurement"
        if valid.size == 0:
            return default
        return float(np.min(valid.astype(np.float32) / 1000.0))

    def check_obstacles(self):
        """Measure the nearest depth reading in a horizontal band across the image centre.

        The band spans ``DEPTH_CENTER_STRIP`` of the image height, centred
        vertically. Raw values are divided by 1000 (presumably millimetres to
        metres — confirm against the depth driver).

        Returns:
            Tuple ``(closest, left_min, right_min)``: the minimum over the
            whole band, over its left third and over its right third. Each is
            ``NO_OBSTACLE_DIST`` (99.0) when no depth frame has arrived or the
            region holds no positive reading.
        """
        with self.lock:
            depth = self.latest_depth
        far = self.NO_OBSTACLE_DIST
        if depth is None:
            return far, far, far

        h, w = depth.shape
        top = int(h * (0.5 - DEPTH_CENTER_STRIP / 2))
        bot = int(h * (0.5 + DEPTH_CENTER_STRIP / 2))
        band = depth[top:bot, :]
        return (
            self._nearest_m(band, far),
            self._nearest_m(band[:, : w // 3], far),
            self._nearest_m(band[:, 2 * w // 3 :], far),
        )

    def stop(self):
        """Publish an all-zero ``Twist`` and record the time it was sent."""
        self.cmd_pub.publish(Twist())
        self.last_cmd_time = time.time()

    def tick(self):
        """Run one control step: obstacle check, model inference, publish ``/cmd_vel``."""
        with self.lock:
            instr = self.instruction
            rgb = self.latest_rgb

        if instr is None or rgb is None:
            # Nothing to act on yet; re-send a stop once per STOP_TIMEOUT.
            if time.time() - self.last_cmd_time > STOP_TIMEOUT:
                self.stop()
            return

        # Check the hard-stop condition BEFORE inference so that an emergency
        # stop is never delayed by the model's latency.
        closest, left_min, right_min = self.check_obstacles()
        if closest < OBSTACLE_MIN_DIST:
            self.get_logger().warn(f"Obstacle {closest:.2f}m — stop")
            self.stop()
            return

        pil_img = Image.fromarray(rgb, mode="RGB")

        t0 = time.perf_counter()
        with torch.no_grad():
            action = self.model.predict_action(pil_img, instr)
        dt = time.perf_counter() - t0

        # 7-DoF: [x, y, z, roll, pitch, yaw, gripper] → diff-drive
        # NOTE(review): per the layout above index 3 is roll while yaw is
        # index 5; confirm that action[3] is really the intended source for
        # the yaw rate before changing it.
        vx = float(action[0]) * VX_GAIN
        wz = float(action[3]) * WZ_GAIN

        vx = max(-VX_MAX, min(VX_MAX, vx))
        wz = max(-WZ_MAX, min(WZ_MAX, wz))

        # Obstacle avoidance: linear brake between MIN and BRAKE distance.
        if closest < OBSTACLE_BRAKE_DIST:
            brake = (closest - OBSTACLE_MIN_DIST) / (OBSTACLE_BRAKE_DIST - OBSTACLE_MIN_DIST)
            brake = max(0.0, min(1.0, brake))
            vx *= brake
            if abs(wz) < self.AVOID_WZ_DEADBAND:
                # Steer AWAY from the nearer side. Positive angular.z is a
                # counter-clockwise (left) turn under REP 103, so an obstacle
                # nearer on the left calls for a negative rate.
                if left_min < right_min:
                    wz = -self.AVOID_WZ
                elif right_min < left_min:
                    wz = self.AVOID_WZ
            self.get_logger().info(f"Brake {brake:.2f} obs={closest:.2f}m")

        cmd = Twist()
        cmd.linear.x = vx
        cmd.angular.z = wz
        self.cmd_pub.publish(cmd)
        self.last_cmd_time = time.time()

        self.get_logger().info(
            f"vx={vx:.3f} wz={wz:.3f} ({dt*1000:.0f}ms) [{instr[:40]}]"
        )
179
+
180
+
181
def main():
    """Initialise rclpy, spin the node until interrupted, then clean up.

    Cleanup runs in ``finally`` so the node is destroyed even when ``spin``
    exits with an exception other than ``KeyboardInterrupt``.
    """
    rclpy.init()
    node = OpenVLAMicroNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # On Ctrl-C the rclpy context may already have been shut down by its
        # signal handler; publishing on an invalid context raises, so only
        # send the final stop command while the context is still usable.
        if rclpy.ok():
            node.stop()
    finally:
        node.destroy_node()
        # Avoid a second shutdown of an already shut-down context.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == "__main__":
    main()