Upload demo_deploy.py
Browse files- demo_deploy.py +273 -0
demo_deploy.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from control_joints import *
|
| 2 |
+
from read_joints import *
|
| 3 |
+
from realsense import *
|
| 4 |
+
import os
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import time
|
| 9 |
+
import argparse
|
| 10 |
+
import sys
|
| 11 |
+
import json
|
| 12 |
+
# get current workspace
|
| 13 |
+
sys.path.insert(0, 'policy/ACT_DP_multitask') # Add the path to the policy directory
|
| 14 |
+
from policy_anyrobot import ACTDiffusionPolicy
|
| 15 |
+
from t5_encoder import T5Embedder
|
| 16 |
+
from utils import convert_weight
|
| 17 |
+
import yaml
|
| 18 |
+
|
| 19 |
+
def get_image(cameras):
    """
    Poll every camera a few times and return the freshest frame per camera.

    Sampling 3 times helps flush stale frames out of the camera buffers;
    the last successful read for each camera wins.

    Args:
        cameras: iterable of camera objects exposing ``name`` and
            ``get_latest_image()`` (returns a BGR uint8 (H, W, 3) array,
            or None when no frame is available).

    Returns:
        dict mapping camera name -> RGB uint8 image of shape (H, W, 3).
    """
    obs_image = dict()
    for i in range(3):
        for cam in cameras:
            color_image = cam.get_latest_image()  # BGR, 0-255, (H, W, 3)
            if color_image is not None:
                # BGR -> RGB by reversing the channel axis.
                obs_image[cam.name] = color_image[:, :, ::-1]
                # Debug dump of the raw BGR frame to disk.
                filename = f"{cam.name}_image_{i}.png"
                cv2.imwrite(filename, color_image)
                # BUG FIX: original f-string had no placeholder and always
                # printed the literal text instead of the saved filename.
                print(f"Saved image: {filename}")
    return obs_image
|
| 34 |
+
|
| 35 |
+
def get_observation(obs_image, joint):
    """Assemble the policy observation dict from camera images and joints.

    Args:
        obs_image: mapping camera name -> RGB image of shape (H, W, C).
        joint: current joint positions (passed through unchanged).

    Returns:
        dict with 'images' (each view transposed to (C, H, W)) and 'qpos'.
    """
    # Map policy view names to the physical camera names in obs_image.
    view_to_camera = {
        'cam_high': 'head_camera',
        'cam_left_wrist': 'left_camera',
        'cam_right_wrist': 'right_camera',
    }
    images = {
        view: obs_image[camera].transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
        for view, camera in view_to_camera.items()
    }
    return {'images': images, 'qpos': joint}
|
| 46 |
+
|
| 47 |
+
def encode_obs(observations, camera_names):
    """Convert an observation dict into policy input tensors.

    Args:
        observations: dict with 'images' (view name -> (C, H, W) RGB array,
            values 0-255) and 'qpos' (joint positions).
        camera_names: ordered list of view keys to stack, e.g.
            ['cam_high', 'cam_left_wrist', 'cam_right_wrist'].

    Returns:
        (image_data, qpos_data):
            image_data: float tensor (1, N_views, C, H, W), scaled to [0, 1].
            qpos_data: float tensor (1, D) of joint positions.
    """
    stacked = np.stack(
        [observations['images'][name] for name in camera_names], axis=0
    )  # (N_views, C, H, W), 0-255 RGB
    # Add a leading history/batch dim and normalize to [0, 1].
    image_data = torch.from_numpy(stacked).unsqueeze(0).float() / 255.0
    qpos = np.array(observations['qpos'], dtype=np.float32)
    qpos_data = torch.from_numpy(qpos).unsqueeze(0)  # (1, D)
    return image_data, qpos_data
|
| 55 |
+
|
| 56 |
+
def get_model(config_file, ckpt_file, device):
    """Build the ACT-diffusion policy from a JSON config and restore a checkpoint.

    Args:
        config_file: path to the policy_config.json file.
        ckpt_file: path to a checkpoint containing 'state_dict' and 'stats'.
        device: torch device (or device string) to place the policy on.

    Returns:
        (policy, camera_names): the eval-mode policy and the ordered list of
        camera view names it expects (from the config).
    """
    with open(config_file, "r", encoding="utf-8") as file:
        policy_config = json.load(file)
    print(f"Loading policy config from {config_file}")
    policy = ACTDiffusionPolicy(policy_config)
    print(f"Loading model from {ckpt_file}")
    # BUG FIX: the checkpoint was deserialized twice (once for the state dict,
    # once for the stats) — load it a single time. map_location keeps this
    # working on CPU-only machines when the checkpoint was saved from GPU.
    checkpoint = torch.load(ckpt_file, map_location=device)
    policy.load_state_dict(convert_weight(checkpoint["state_dict"]))
    policy.to(device)
    policy.eval()
    print('Resetting observation normalization stats')
    policy.reset_obs(checkpoint["stats"], norm_type=policy_config["norm_type"])
    camera_names = policy_config["camera_names"]
    return policy, camera_names
|
| 70 |
+
|
| 71 |
+
def get_language_encoder(device):
    """Load the T5 tokenizer and text encoder used for instruction embedding.

    Args:
        device: device to place the encoder on.

    Returns:
        (tokenizer, text_encoder) from the pretrained T5 embedder.
    """
    # NOTE(review): hard-coded local HuggingFace snapshot path — machine
    # specific; verify it exists on the deployment host.
    MODEL_PATH = '/data/gjx/.cache/huggingface/hub/models--google--t5-v1_1-xxl/snapshots/3db67ab1af984cf10548a73467f0e5bca2aaaeb2'
    # MODEL_PATH = 'policy/weights/RDT/t5-v1_1-xxl'
    CONFIG_PATH = os.path.join('policy/ACT_DP_multitask', "base.yaml")
    with open(CONFIG_PATH, "r") as fp:
        config = yaml.safe_load(fp)
    embedder = T5Embedder(
        from_pretrained=MODEL_PATH,
        model_max_length=config["dataset"]["tokenizer_max_length"],
        device=device,
        use_offload_folder=None,
    )
    return embedder.tokenizer, embedder.model
|
| 85 |
+
|
| 86 |
+
def get_language_embed(tokenizer, text_encoder, language_instruction, device):
    """Embed an instruction string into a single pooled T5 feature vector.

    Args:
        tokenizer: T5 tokenizer.
        text_encoder: T5 encoder model.
        language_instruction: instruction text to embed.
        device: device used for tokenization / encoding.

    Returns:
        CPU tensor of shape (hidden_size,): mean over token embeddings.
    """
    encoded = tokenizer(
        language_instruction,
        return_tensors="pt",
        padding="longest",
        truncation=True,
    )
    token_ids = encoded["input_ids"].to(device).view(1, -1)

    with torch.no_grad():
        hidden = text_encoder(token_ids).last_hidden_state.detach().cpu()

    # Drop the batch dim, then mean-pool over tokens -> (hidden_size,)
    return hidden.squeeze(0).mean(0)
|
| 98 |
+
|
| 99 |
+
def test_code():
    """Offline smoke test: load the policy and language encoder, then run one
    forward pass on random tensors (no cameras, minimal hardware access).

    Intended to verify the model-loading path without the full deploy loop.
    """
    print("Testing the code...")
    device = 'cpu'
    ckpt_dir = 'policy/ACT_DP_multitask/checkpoints/real_fintune_50_2000/act_dp'
    config_path = os.path.join(ckpt_dir, 'policy_config.json')
    ckpt_path = os.path.join(ckpt_dir, 'policy_lastest_seed_0.ckpt')
    policy, cameras_name = get_model(config_path, ckpt_path, device)

    instruction_file = 'instruction.txt'
    with open(instruction_file, 'r') as f:
        instruction = f.readline().strip().strip('"')
    print(f"Instruction: {instruction}")
    tokenizer, text_encoder = get_language_encoder(device)
    task_emb = get_language_embed(tokenizer, text_encoder, instruction, device)
    # BUG FIX: left_can/right_can were referenced here but only defined in
    # commented-out code above, so this line raised NameError.
    left_can, right_can = "can_left_2", "can_right_2"
    controller = ControlJoints(left_can=left_can, right_can=right_can)
    image = torch.rand(1, 3, 3, 480, 640)  # dummy (1, N_views, C, H, W)
    qpos = torch.rand(1, 14)               # dummy joint positions
    actions = policy.get_action(qpos.float().to(device), image.float().to(device), task_emb.float().to(device))
    print(f"Actions shape: {actions.shape}")
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
if __name__ == "__main__":

    # test_code()  # Uncomment to run the offline smoke test
    # print("Testing the done...")
    # exit()

    # Report the torch / CUDA environment for debugging.
    print(torch.backends.cudnn.enabled)
    print(torch.backends.cudnn.version())
    print(torch.version.cuda)
    print(torch.__version__)

    # Parse command-line arguments.
    parser = argparse.ArgumentParser(description="Deploy the action for a specific player")
    parser.add_argument("--ckpt_path", type=str, default='policy/ACT_DP_multitask/checkpoints/real_fintune_50_2000/act_dp')

    # Read the PLAYER environment variable (selects the hardware set).
    player_value = os.getenv("PLAYER")

    # BUG FIX: the error messages below were mojibake (mis-encoded Chinese);
    # replaced with readable English. Validation logic is unchanged.
    if player_value is None:
        raise ValueError("Environment variable PLAYER is not set")
    try:
        player_value = int(player_value)
    except ValueError:
        raise ValueError("Environment variable PLAYER must be an integer")

    # Pick cameras and CAN buses based on PLAYER.
    if player_value == 1:
        print("Player 1")
        cameras = [
            RealSenseCam("337322073280", "left_camera"),
            RealSenseCam("337322074191", "head_camera"),
            RealSenseCam("337122072617", "right_camera"),
        ]
        left_can, right_can = "can_left_1", "can_right_1"
    elif player_value == 2:
        print("Player 2")
        cameras = [
            RealSenseCam("250122079815", "left_camera"),
            RealSenseCam("048522073543", "head_camera"),
            RealSenseCam("030522070109", "right_camera"),
        ]
        left_can, right_can = "can_left_2", "can_right_2"
    else:
        raise ValueError("Invalid PLAYER value: must be 1 or 2")

    reader = JointReader(left_can=left_can, right_can=right_can)  # joint-state reader

    # ==== Get RGB ====
    # Enumerate all connected RealSense devices.
    ctx = rs.context()

    if len(ctx.devices) > 0:
        print("Found RealSense devices:")
        for d in ctx.devices:
            name = d.get_info(rs.camera_info.name)
            serial_number = d.get_info(rs.camera_info.serial_number)
            print(f"Device: {name}, Serial Number: {serial_number}")
    else:
        print("No Intel RealSense devices connected")

    # Start the camera streams.
    for cam in cameras:
        cam.start()

    # Warm up the cameras so the first frames in the loop are fresh.
    for i in range(10):
        print(f"Warm up: {i}", end="\r")
        for cam in cameras:
            color_image = cam.get_latest_image()
        time.sleep(0.15)

    # Load model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = 'cpu'
    ckpt_dir = parser.parse_args().ckpt_path
    config_path = os.path.join(ckpt_dir, 'policy_config.json')
    # ckpt_path = os.path.join(ckpt_dir, 'policy_lastest_seed_0.ckpt')  # policy_epoch_600_seed_0.ckpt
    ckpt_path = os.path.join(ckpt_dir, 'policy_epoch_600_seed_0.ckpt')
    policy, camera_names = get_model(config_path, ckpt_path, device)
    print('camera_names:', camera_names)

    # Get instructions
    instruction_file = 'instruction.txt'
    with open(instruction_file, 'r') as f:
        instruction = f.readline().strip()
    print(f"Using instruction: {instruction}")
    tokenizer, text_encoder = get_language_encoder(device)
    print('Loading language tokenizer and encoder...')
    task_emb = get_language_embed(tokenizer, text_encoder, instruction, device)

    # ==== Observe / act loop ====
    controller = ControlJoints(left_can=left_can, right_can=right_can)
    max_timestep = 600
    step = 0
    while step < max_timestep:
        obs_image = get_image(cameras)
        joint = reader.get_joint_value()

        observation = get_observation(obs_image, joint)
        image, qpos = encode_obs(observation, camera_names)
        actions = policy.get_action(qpos.float().to(device), image.float().to(device), task_emb.float().to(device))
        print(f"Step: {step}/{max_timestep}, Action: {actions.shape}")
        # Execute the first 30 predicted actions of the chunk.
        for action in actions[0:30]:
            controller.control(action)
            step += 1
            # import pdb; pdb.set_trace()  # Debugging line to pause execution
            time.sleep(0.05)
|