Upload demo_deploy.py

#1
Files changed (1) hide show
  1. demo_deploy.py +273 -0
demo_deploy.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from control_joints import *
2
+ from read_joints import *
3
+ from realsense import *
4
+ import os
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import time
9
+ import argparse
10
+ import sys
11
+ import json
12
+ # get current workspace
13
+ sys.path.insert(0, 'policy/ACT_DP_multitask') # Add the path to the policy directory
14
+ from policy_anyrobot import ACTDiffusionPolicy
15
+ from t5_encoder import T5Embedder
16
+ from utils import convert_weight
17
+ import yaml
18
+
19
def get_image(cameras):
    """Grab the latest color frame from every camera.

    Polls each camera three times (later reads overwrite earlier ones, so
    only the freshest frame per camera is kept) and also dumps each frame to
    disk for offline inspection.

    Args:
        cameras: iterable of camera objects exposing ``get_latest_image()``
            (returns a BGR uint8 H x W x 3 array, or None when no frame is
            ready) and a ``name`` attribute.

    Returns:
        dict mapping camera name -> RGB uint8 H x W x 3 image.
    """
    obs_image = dict()
    for i in range(3):
        for cam in cameras:
            color_image = cam.get_latest_image()  # BGR 0-255, H W 3, or None
            if color_image is not None:
                obs_image[cam.name] = color_image[:, :, ::-1]  # BGR -> RGB
                # Save the raw BGR frame for debugging.
                filename = f"{cam.name}_image_{i}.png"
                cv2.imwrite(filename, color_image)
                # BUGFIX: f-string had no placeholder; report the actual file.
                print(f"Saved image: {filename}")
    return obs_image  # images from all cameras, keyed by camera name
34
+
35
def get_observation(obs_image, joint):
    """Pack camera images and joint readings into a policy observation.

    Each image is transposed from H x W x C to C x H x W; the joint values
    are passed through untouched under the 'qpos' key.
    """
    # Map the policy's camera keys onto the deployment camera names.
    name_map = {
        'cam_high': 'head_camera',
        'cam_left_wrist': 'left_camera',
        'cam_right_wrist': 'right_camera',
    }
    images = {policy_key: np.moveaxis(obs_image[cam_key], -1, 0)
              for policy_key, cam_key in name_map.items()}
    return {'images': images, 'qpos': joint}
46
+
47
def encode_obs(observations, camera_names):
    """Convert an observation dict into policy-ready tensors.

    Args:
        observations: dict with 'images' (per-camera arrays, 0-255) and
            'qpos' (joint values).
        camera_names: ordered list of camera keys to stack.

    Returns:
        image_data: float tensor (1, n_views, *image_shape) scaled to [0, 1].
        qpos_data: float32 tensor (1, D) of joint positions.
    """
    views = np.stack([observations['images'][name] for name in camera_names], axis=0)
    image_data = torch.from_numpy(views).unsqueeze(0).float() / 255.0
    qpos_data = torch.from_numpy(np.array(observations['qpos'], dtype=np.float32)).unsqueeze(0)
    return image_data, qpos_data  # no batch dimension beyond the leading 1
55
+
56
def get_model(config_file, ckpt_file, device):
    """Build the policy from a JSON config and load checkpoint weights.

    Args:
        config_file: path to a JSON policy-config file.
        ckpt_file: path to a torch checkpoint with 'state_dict' and 'stats'.
        device: torch device to move the policy onto.

    Returns:
        (policy, camera_names): the policy in eval mode with observation
        normalization stats applied, and the camera-name list from the config.
    """
    with open(config_file, "r", encoding="utf-8") as file:
        policy_config = json.load(file)
    print(f"Loading policy config from {config_file}")
    policy = ACTDiffusionPolicy(policy_config)
    print(f"Loading model from {ckpt_file}")
    # BUGFIX: the checkpoint was torch.load()-ed twice (once for the weights,
    # once for the stats); load it once and reuse.
    checkpoint = torch.load(ckpt_file)
    policy.load_state_dict(convert_weight(checkpoint["state_dict"]))
    policy.to(device)
    policy.eval()
    stats = checkpoint["stats"]
    print('Resetting observation normalization stats')
    policy.reset_obs(stats, norm_type=policy_config["norm_type"])
    camera_names = policy_config["camera_names"]
    return policy, camera_names
70
+
71
def get_language_encoder(device):
    """Instantiate the T5 text embedder; return its tokenizer and encoder.

    NOTE(review): the pretrained-model path is machine-specific — confirm it
    exists on the deployment host.
    """
    model_path = '/data/gjx/.cache/huggingface/hub/models--google--t5-v1_1-xxl/snapshots/3db67ab1af984cf10548a73467f0e5bca2aaaeb2'
    # model_path = 'policy/weights/RDT/t5-v1_1-xxl'
    config_path = os.path.join('policy/ACT_DP_multitask', "base.yaml")
    with open(config_path, "r") as fp:
        config = yaml.safe_load(fp)
    embedder = T5Embedder(
        from_pretrained=model_path,
        model_max_length=config["dataset"]["tokenizer_max_length"],
        device=device,
        use_offload_folder=None,
    )
    return embedder.tokenizer, embedder.model
85
+
86
def get_language_embed(tokenizer, text_encoder, language_instruction, device):
    """Encode an instruction string into a single embedding vector.

    Tokenizes the instruction, runs the frozen text encoder, and mean-pools
    the token axis of the last hidden state.

    Returns:
        1-D CPU tensor of shape (hidden_size,).
    """
    encoded = tokenizer(
        language_instruction,
        return_tensors="pt",
        padding="longest",
        truncation=True,
    )
    token_ids = encoded["input_ids"].to(device).view(1, -1)
    with torch.no_grad():
        hidden = text_encoder(token_ids).last_hidden_state
    # Pool over tokens to a single vector on CPU.
    return hidden.detach().cpu().squeeze(0).mean(0)
98
+
99
def test_code():
    """Smoke-test model loading and one dummy inference on CPU.

    Uses random tensors in place of real camera/joint data, so it can run
    without the robot attached. (The original kept a large commented-out
    hardware bring-up section here — joint reader, RealSense enumeration,
    camera start/warm-up; removed as dead code.)
    """
    print("Testing the code...")
    device = 'cpu'
    ckpt_dir = 'policy/ACT_DP_multitask/checkpoints/real_fintune_50_2000/act_dp'
    config_path = os.path.join(ckpt_dir, 'policy_config.json')
    ckpt_path = os.path.join(ckpt_dir, 'policy_lastest_seed_0.ckpt')
    policy, cameras_name = get_model(config_path, ckpt_path, device)

    instruction_file = 'instruction.txt'
    with open(instruction_file, 'r') as f:
        instruction = f.readline().strip().strip('"')
    print(f"Instruction: {instruction}")
    tokenizer, text_encoder = get_language_encoder(device)  # language encoder
    task_emb = get_language_embed(tokenizer, text_encoder, instruction, device)  # (hidden_size,) tensor
    # BUGFIX: left_can/right_can were undefined globals here (NameError when
    # test_code() runs before the __main__ player selection); use explicit
    # test defaults instead.
    left_can, right_can = "can_left_1", "can_right_1"
    controller = ControlJoints(left_can=left_can, right_can=right_can)
    image = torch.rand(1, 3, 3, 480, 640)  # dummy (1, n_views, C, H, W) tensor
    qpos = torch.rand(1, 14)  # dummy joint-position tensor
    actions = policy.get_action(qpos.float().to(device), image.float().to(device), task_emb.float().to(device))
    print(f"Actions shape: {actions.shape}")
145
+
146
+
147
if __name__ == "__main__":

    # test_code()  # Uncomment to run the offline smoke test
    # Report the CUDA/cuDNN environment for debugging deployment hosts.
    print(torch.backends.cudnn.enabled)
    print(torch.backends.cudnn.version())
    print(torch.version.cuda)
    print(torch.__version__)

    # Command-line arguments.
    parser = argparse.ArgumentParser(description="Deploy the action for a specific player")
    parser.add_argument("--ckpt_path", type=str, default='policy/ACT_DP_multitask/checkpoints/real_fintune_50_2000/act_dp')
    args = parser.parse_args()  # parse once up front instead of inline below

    # Read the PLAYER environment variable to pick the robot rig.
    # BUGFIX: the comments and ValueError messages in this section were
    # mojibake (mis-encoded text); restored readable English.
    player_value = os.getenv("PLAYER")
    if player_value is None:
        raise ValueError("Environment variable PLAYER is not set")
    try:
        player_value = int(player_value)
    except ValueError:
        raise ValueError("Environment variable PLAYER must be an integer")

    # Select cameras (by RealSense serial) and CAN bus names per player.
    if player_value == 1:
        print("Player 1")
        cameras = [
            RealSenseCam("337322073280", "left_camera"),
            RealSenseCam("337322074191", "head_camera"),
            RealSenseCam("337122072617", "right_camera"),
        ]
        left_can, right_can = "can_left_1", "can_right_1"
    elif player_value == 2:
        print("Player 2")
        cameras = [
            RealSenseCam("250122079815", "left_camera"),
            RealSenseCam("048522073543", "head_camera"),
            RealSenseCam("030522070109", "right_camera"),
        ]
        left_can, right_can = "can_left_2", "can_right_2"
    else:
        raise ValueError("Invalid PLAYER value; must be 1 or 2")
    reader = JointReader(left_can=left_can, right_can=right_can)  # joint-state reader

    # ==== Get RGB ====
    # Enumerate connected RealSense devices and report them.
    ctx = rs.context()
    if len(ctx.devices) > 0:
        print("Found RealSense devices:")
        for d in ctx.devices:
            name = d.get_info(rs.camera_info.name)
            serial_number = d.get_info(rs.camera_info.serial_number)
            print(f"Device: {name}, Serial Number: {serial_number}")
    else:
        print("No Intel RealSense devices connected")

    # Start all camera streams.
    for cam in cameras:
        cam.start()

    # Warm up the streams so the first policy frame is fresh.
    for i in range(10):
        print(f"Warm up: {i}", end="\r")
        for cam in cameras:
            color_image = cam.get_latest_image()
        time.sleep(0.15)

    # Load model.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt_dir = args.ckpt_path
    config_path = os.path.join(ckpt_dir, 'policy_config.json')
    # alternative checkpoint: 'policy_lastest_seed_0.ckpt'
    ckpt_path = os.path.join(ckpt_dir, 'policy_epoch_600_seed_0.ckpt')
    policy, camera_names = get_model(config_path, ckpt_path, device)
    print('camera_names:', camera_names)

    # Read the language instruction and embed it once.
    instruction_file = 'instruction.txt'
    with open(instruction_file, 'r') as f:
        instruction = f.readline().strip()
    print(f"Using instruction: {instruction}")
    tokenizer, text_encoder = get_language_encoder(device)  # language encoder
    print('Loading language tokenizer and encoder...')
    task_emb = get_language_embed(tokenizer, text_encoder, instruction, device)  # (hidden_size,) tensor

    # ==== Observe / act control loop ====
    controller = ControlJoints(left_can=left_can, right_can=right_can)
    max_timestep = 600
    step = 0
    while step < max_timestep:
        obs_image = get_image(cameras)
        joint = reader.get_joint_value()

        observation = get_observation(obs_image, joint)
        image, qpos = encode_obs(observation, camera_names)
        actions = policy.get_action(qpos.float().to(device), image.float().to(device), task_emb.float().to(device))
        print(f"Step: {step}/{max_timestep}, Action: {actions.shape}")
        # Execute only the first 30 predicted actions, then re-observe.
        for action in actions[0:30]:
            controller.control(action)
            step += 1
            time.sleep(0.05)
    # (Dead commented-out debug snippets — image grabbing, joint readout,
    # zero-position control — removed.)