File size: 6,288 Bytes
af59633
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
#!/usr/bin/env python3
"""
OpenVLA-Micro ROS2 node for diff-drive robot (Nova Carter).

Subscribes:
  /rgb/image_raw       — Kinect RGB
  /depth/image_raw     — Kinect depth (obstacle avoidance)
  /vla_instruction     — text command

Publishes:
  /cmd_vel             — Twist

Run:
  python3 vla_micro_ros2_node.py

First-time setup will load & cache model components (~30s load, once).
Subsequent starts load in ~3-5s.
"""
import os, time, threading

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import numpy as np
from PIL import Image
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import Image as RosImage
from geometry_msgs.msg import Twist
from std_msgs.msg import String
from cv_bridge import CvBridge
import torch

# Model components are cached in a "model_cache" directory that sits next to
# this script, independent of the current working directory.
CACHE_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "model_cache",
)

# ── Diff-drive params ───────────────────────────────────────────
# Gains scale raw model action entries into Twist components; the *_MAX
# values are symmetric clamps applied after scaling.
VX_GAIN = 0.15      # multiplier for the linear.x action entry
WZ_GAIN = 0.8       # multiplier for the angular.z action entry
VX_MAX = 0.3        # |linear.x| clamp
WZ_MAX = 0.6        # |angular.z| clamp
CTRL_RATE = 3.0     # control-loop frequency in Hz (timer period = 1 / CTRL_RATE s)
STOP_TIMEOUT = 1.0  # seconds between zero-Twist re-sends while idle

# ── Obstacle avoidance ──────────────────────────────────────────
# Distances are compared against depth pixels divided by 1000, i.e. metres
# if the depth image is in millimetres (TODO confirm for the camera driver).
OBSTACLE_MIN_DIST = 0.3    # closer than this: stop outright
OBSTACLE_BRAKE_DIST = 0.8  # closer than this: scale linear.x down linearly
DEPTH_CENTER_STRIP = 0.3   # fraction of image height (centred band) inspected


def _load_model(device="cpu"):
    """Load the model, populating the on-disk component cache on first use.

    If ``CACHE_DIR`` does not exist yet, the model is fetched once with
    ``OpenVLAMicro.from_pretrained`` and the ``model_cache`` directory in the
    current working directory is copied into ``CACHE_DIR``. The model is then
    loaded from the cache with ``fast_load.load_openvla_micro``.

    Args:
        device: Device string forwarded to both loaders.

    Returns:
        The object returned by ``load_openvla_micro``.
    """
    from fast_load import load_openvla_micro

    cache = os.path.expanduser(CACHE_DIR)
    if not os.path.isdir(cache):
        # First run: download from HF and cache components
        print("First-time setup: downloading & caching model...", flush=True)
        import shutil
        from modeling_openvla_micro import OpenVLAMicro

        # Called for its side effect of producing the component cache. The
        # returned model is intentionally not kept, so it can be freed before
        # the cached copy is loaded below (avoids holding two models at once).
        OpenVLAMicro.from_pretrained("theguy21/openvla-micro", device=device)

        # NOTE(review): assumes from_pretrained leaves "./model_cache"
        # relative to the CWD — confirm.
        src = os.path.abspath("model_cache")
        # When launched from this script's own directory, src and cache are
        # the same directory. copytree onto itself fails with shutil.Error
        # (SameFileError per file), so only copy when they differ.
        same_dir = (
            os.path.isdir(src)
            and os.path.isdir(cache)
            and os.path.samefile(src, cache)
        )
        if not same_dir:
            shutil.copytree(src, cache, dirs_exist_ok=True)
    return load_openvla_micro(cache, device=device)


class OpenVLAMicroNode(Node):
    """Drive a diff-drive base from OpenVLA-Micro predictions.

    Caches the latest instruction (``/vla_instruction``), RGB frame
    (``/rgb/image_raw``) and depth frame (``/depth/image_raw``). A timer
    running at ``CTRL_RATE`` Hz turns them into ``/cmd_vel`` Twists, with
    depth-based braking and steering applied on top of the model output.
    """

    # Which entries of the model's action vector drive the base.
    # NOTE(review): the layout comment in tick() puts yaw at index 5, but
    # turning is read from index 3 (roll). Kept as-is to preserve behaviour —
    # confirm which entry actually carries the turning signal.
    _VX_ACTION_INDEX = 0
    _WZ_ACTION_INDEX = 3

    def __init__(self):
        super().__init__("vla_micro_node")
        self.cmd_pub = self.create_publisher(Twist, "/cmd_vel", 10)

        # Latest inputs, written by the subscription callbacks and read by
        # tick() / check_obstacles(). All access goes through self.lock.
        self.instruction = None
        self.latest_rgb = None
        self.latest_depth = None
        self.lock = threading.Lock()
        self.br = CvBridge()

        self.create_subscription(String, "/vla_instruction", self.instruction_cb, 10)
        self.create_subscription(RosImage, "/rgb/image_raw", self.image_cb, 10)
        self.create_subscription(RosImage, "/depth/image_raw", self.depth_cb, 10)

        self.get_logger().info("Loading OpenVLA-Micro...")
        t0 = time.time()
        self.model = _load_model(device="cpu")
        n_params = sum(p.numel() for p in self.model.parameters()) / 1e6
        self.get_logger().info(
            f"Loaded in {time.time()-t0:.0f}s ({n_params:.0f}M params)"
        )

        self.last_cmd_time = time.time()
        self.timer = self.create_timer(1.0 / CTRL_RATE, self.tick)

    def instruction_cb(self, msg):
        """Store the whitespace-stripped text of a new instruction."""
        instruction = msg.data.strip()
        with self.lock:
            self.instruction = instruction
        # Log the local copy rather than re-reading shared state outside the lock.
        self.get_logger().info(f"Instr: {instruction}")

    def image_cb(self, msg):
        """Convert an incoming RGB frame to an rgb8 array and store it."""
        # Convert outside the lock so the critical section is only a swap.
        rgb = self.br.imgmsg_to_cv2(msg, desired_encoding="rgb8")
        with self.lock:
            self.latest_rgb = rgb

    def depth_cb(self, msg):
        """Convert an incoming depth frame to a 16UC1 array and store it."""
        depth = self.br.imgmsg_to_cv2(msg, desired_encoding="16UC1")
        with self.lock:
            self.latest_depth = depth

    @staticmethod
    def _nearest_m(region):
        """Return the smallest non-zero value of ``region`` / 1000, or 99.0 if none."""
        valid = region[region > 0]
        if valid.size == 0:
            return 99.0
        return float(np.min(valid.astype(np.float32) / 1000.0))

    def check_obstacles(self):
        """Return the nearest depth in the horizontal centre band of the depth image.

        Only the band ``DEPTH_CENTER_STRIP`` of the image height around the
        middle row is inspected. Zero pixels are treated as invalid, and values
        are divided by 1000 (metres if the depth is in millimetres — TODO
        confirm for the camera driver).

        Returns:
            ``(closest, left_min, right_min)``: nearest value over the whole
            band, its left third, and its right third. Each is 99.0 when the
            region has no valid pixel or no depth frame has arrived yet.
        """
        with self.lock:
            depth = self.latest_depth
        if depth is None:
            return 99.0, 99.0, 99.0
        h, w = depth.shape
        top = int(h * (0.5 - DEPTH_CENTER_STRIP / 2))
        bot = int(h * (0.5 + DEPTH_CENTER_STRIP / 2))
        center = depth[top:bot, :]
        return (
            self._nearest_m(center),
            self._nearest_m(center[:, : w // 3]),
            self._nearest_m(center[:, 2 * w // 3 :]),
        )

    def stop(self):
        """Publish a zero Twist and record the publish time."""
        self.cmd_pub.publish(Twist())
        self.last_cmd_time = time.time()

    def _action_to_cmd(self, action):
        """Map a model action to a clamped ``(vx, wz)`` using the gain/limit constants."""
        vx = float(action[self._VX_ACTION_INDEX]) * VX_GAIN
        wz = float(action[self._WZ_ACTION_INDEX]) * WZ_GAIN
        vx = max(-VX_MAX, min(VX_MAX, vx))
        wz = max(-WZ_MAX, min(WZ_MAX, wz))
        return vx, wz

    def _avoid_obstacles(self, vx, wz):
        """Apply depth-based safety limits to a commanded ``(vx, wz)``.

        Returns:
            None when an obstacle is closer than ``OBSTACLE_MIN_DIST`` (the
            caller must stop). Otherwise returns ``(vx, wz)``, braked linearly
            between the two thresholds and possibly steered away from the
            nearer side.
        """
        closest, left_min, right_min = self.check_obstacles()
        if closest < OBSTACLE_MIN_DIST:
            self.get_logger().warning(f"Obstacle {closest:.2f}m — stop")
            return None
        if closest < OBSTACLE_BRAKE_DIST:
            brake = (closest - OBSTACLE_MIN_DIST) / (OBSTACLE_BRAKE_DIST - OBSTACLE_MIN_DIST)
            brake = max(0.0, min(1.0, brake))
            vx *= brake
            # Turn away from the nearer side when the model isn't already
            # turning hard. Per REP-103, positive angular.z turns left (CCW).
            # NOTE(review): assumes the depth image is not mirrored, i.e. its
            # left columns show the robot's left side — confirm for the camera.
            if abs(wz) < 0.3:
                if left_min < right_min:
                    wz = -0.5  # obstacle nearer on the left -> turn right
                elif right_min < left_min:
                    wz = 0.5  # obstacle nearer on the right -> turn left
            self.get_logger().info(f"Brake {brake:.2f} obs={closest:.2f}m")
        return vx, wz

    def tick(self):
        """Run one control step: infer, map to a Twist, apply obstacle logic, publish."""
        with self.lock:
            instr = self.instruction
            rgb = self.latest_rgb

        if instr is None or rgb is None:
            # Nothing to act on: re-send a zero Twist at most once per STOP_TIMEOUT.
            if time.time() - self.last_cmd_time > STOP_TIMEOUT:
                self.stop()
            return

        pil_img = Image.fromarray(rgb, mode="RGB")

        t0 = time.perf_counter()
        with torch.no_grad():
            action = self.model.predict_action(pil_img, instr)
        dt = time.perf_counter() - t0

        # 7-DoF: [x, y, z, roll, pitch, yaw, gripper] -> diff-drive
        vx, wz = self._action_to_cmd(action)

        safe_cmd = self._avoid_obstacles(vx, wz)
        if safe_cmd is None:
            self.stop()
            return
        vx, wz = safe_cmd

        cmd = Twist()
        cmd.linear.x = vx
        cmd.angular.z = wz
        self.cmd_pub.publish(cmd)
        self.last_cmd_time = time.time()

        self.get_logger().info(
            f"vx={vx:.3f} wz={wz:.3f} ({dt*1000:.0f}ms) [{instr[:40]}]"
        )


def main():
    """Spin the node until shutdown, then halt the robot and release ROS resources.

    A zero Twist is attempted on every exit path: Ctrl-C, an exception raised
    from a callback, or spin returning normally. This keeps the base from
    being left moving on its last command. Non-KeyboardInterrupt exceptions
    still propagate after cleanup.
    """
    rclpy.init()
    node = None
    try:
        node = OpenVLAMicroNode()
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        if node is not None:
            try:
                node.stop()
            except Exception as exc:  # top-level cleanup: context may already be shut down
                node.get_logger().error(f"Failed to publish stop command: {exc}")
            node.destroy_node()
        # rclpy's SIGINT handling may already have shut the context down;
        # calling shutdown() twice raises, so only shut down if still running.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == "__main__":
    main()