import sys import os import subprocess import matplotlib.pyplot as plt from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas import cv2 from pathlib import Path robowin_root = Path("/group/ossdphi_algo_scratch_11/weicxu/pythonproject/RoboTwin") if str(robowin_root) not in sys.path: sys.path.insert(0, str(robowin_root)) import os os.chdir(robowin_root) from envs import CONFIGS_PATH from envs.utils.create_actor import UnStableError import numpy as np from pathlib import Path from collections import deque import traceback import yaml from datetime import datetime import importlib import argparse import pdb from evaluation.robotwin.geometry import euler2quat import numpy as np from description.utils.generate_episode_instructions import * import traceback import imageio import numpy as np from pathlib import Path from scipy.spatial.transform import Rotation as R import json from pathlib import Path from evaluation.robotwin.websocket_client_policy import WebsocketClientPolicy from evaluation.robotwin.test_render import Sapien_TEST def write_json(data: dict, fpath: Path) -> None: """Write data to a JSON file. Creates parent directories if they don't exist. Args: data (dict): The dictionary to write. fpath (Path): The path to the output JSON file. 
def write_json(data: dict, fpath: Path) -> None:
    """Write data to a JSON file.

    Creates parent directories if they don't exist. The file is always
    written as UTF-8: ``ensure_ascii=False`` emits non-ASCII characters
    verbatim, so relying on the locale's default encoding could raise
    ``UnicodeEncodeError`` or produce a file other tools cannot read.

    Args:
        data (dict): The dictionary to write.
        fpath (Path): The path to the output JSON file.
    """
    fpath.parent.mkdir(exist_ok=True, parents=True)
    with open(fpath, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)


def add_title_bar(img, text, font_scale=0.8, thickness=2):
    """Add a black title bar with centered white text above the image.

    Args:
        img: Image array of shape (H, W, C); the bar is built with 3 channels,
            so C must be 3 for the final stack to succeed.
        text (str): Caption rendered inside the bar.
        font_scale (float): OpenCV font scale for the caption.
        thickness (int): OpenCV stroke thickness for the caption.

    Returns:
        The input image with a 40-pixel-high title bar stacked on top.
    """
    # Unpacking three values keeps the original requirement of a 3-D input.
    _, w, _ = img.shape
    bar_height = 40

    # Black background bar with the same width as the image.
    title_bar = np.zeros((bar_height, w, 3), dtype=np.uint8)

    # Center the caption horizontally and (roughly) vertically in the bar.
    font = cv2.FONT_HERSHEY_SIMPLEX
    (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
    text_x = (w - text_w) // 2
    text_y = (bar_height + text_h) // 2 - 5

    cv2.putText(title_bar, text, (text_x, text_y), font, font_scale,
                (255, 255, 255), thickness, cv2.LINE_AA)

    return np.vstack([title_bar, img])


def quaternion_to_euler(quat):
    """Convert a quaternion to Euler angles.

    Args:
        quat: Quaternion in ``[rx, ry, rz, rw]`` order, which is the
            scalar-last ``[x, y, z, w]`` layout scipy expects.

    Returns:
        ``[roll, pitch, yaw]`` in radians (scipy ``'xyz'`` convention).
    """
    return R.from_quat(quat).as_euler('xyz', degrees=False)
def visualize_action_step(action_history, step_idx, window=50):
    """Render the recent dual-arm action history as an RGB image.

    Layout (2x2 grid, shared X-axis):
        Subplot 1: Left arm XYZ position + gripper
        Subplot 2: Left arm Euler angles (roll, pitch, yaw) from quaternion
        Subplot 3: Right arm XYZ position + gripper
        Subplot 4: Right arm Euler angles (roll, pitch, yaw) from quaternion

    Input data format, 16 values per step:
        [left_x, left_y, left_z, left_rx, left_ry, left_rz, left_rw, left_gripper,
         right_x, right_y, right_z, right_rx, right_ry, right_rz, right_rw, right_gripper]

    Args:
        action_history: Sequence of per-step action vectors.
        step_idx (int): Index of the current step; steps in
            ``[step_idx - window, step_idx]`` are plotted.
        window (int): Number of past steps to show.

    Returns:
        ``uint8`` array of shape (H, W, 3) holding the rendered figure. Curves
        are only drawn when the selected history is 2-D with at least 16
        columns; otherwise the axes are left empty.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 8), dpi=100, sharex=True)

    # The figure must be released even if plotting fails, otherwise pyplot
    # keeps it registered and memory grows with every rendered step.
    try:
        # 1. Determine slice range
        start = max(0, step_idx - window)
        end = step_idx + 1

        # 2. Get data subset
        history_subset = np.array(action_history)[start:end]

        # 3. Generate X-axis based on actual data length
        actual_len = len(history_subset)
        x_axis = range(start, start + actual_len)

        if actual_len > 0 and history_subset.ndim == 2 and history_subset.shape[1] >= 16:
            # Convert all quaternions ([rx, ry, rz, rw]) in one batched call.
            left_euler = R.from_quat(history_subset[:, 3:7]).as_euler('xyz', degrees=False)
            right_euler = R.from_quat(history_subset[:, 11:15]).as_euler('xyz', degrees=False)

            # --- Left Arm ---
            # Subplot 1: Left Arm Translation (XYZ) + Gripper
            ax1.plot(x_axis, history_subset[:, 0], label='left_x', color='r', linewidth=1.5)
            ax1.plot(x_axis, history_subset[:, 1], label='left_y', color='g', linewidth=1.5)
            ax1.plot(x_axis, history_subset[:, 2], label='left_z', color='b', linewidth=1.5)
            ax1.plot(x_axis, history_subset[:, 7], label='left_grip', color='orange',
                     linestyle=':', linewidth=2, alpha=0.8)
            ax1.set_ylabel('Position (m)')
            ax1.legend(loc='upper right', fontsize='x-small', ncol=4)
            ax1.grid(True, alpha=0.3)
            ax1.set_title(f"Step {step_idx}: Left Arm Position & Gripper")

            # Subplot 2: Left Arm Euler Angles (Roll, Pitch, Yaw)
            ax2.plot(x_axis, left_euler[:, 0], label='left_roll', color='c', linewidth=1.5)
            ax2.plot(x_axis, left_euler[:, 1], label='left_pitch', color='m', linewidth=1.5)
            ax2.plot(x_axis, left_euler[:, 2], label='left_yaw', color='y', linewidth=1.5)
            ax2.set_ylabel('Rotation (rad)')
            ax2.legend(loc='upper right', fontsize='x-small', ncol=3)
            ax2.grid(True, alpha=0.3)
            ax2.set_title("Left Arm Rotation (RPY from Quaternion)")

            # --- Right Arm ---
            # Subplot 3: Right Arm Translation (XYZ) + Gripper
            ax3.plot(x_axis, history_subset[:, 8], label='right_x', color='r',
                     linewidth=1.5, linestyle='--')
            ax3.plot(x_axis, history_subset[:, 9], label='right_y', color='g',
                     linewidth=1.5, linestyle='--')
            ax3.plot(x_axis, history_subset[:, 10], label='right_z', color='b',
                     linewidth=1.5, linestyle='--')
            ax3.plot(x_axis, history_subset[:, 15], label='right_grip', color='orange',
                     linestyle=':', linewidth=2, alpha=0.8)
            ax3.set_ylabel('Position (m)')
            ax3.legend(loc='upper right', fontsize='x-small', ncol=4)
            ax3.grid(True, alpha=0.3)
            ax3.set_title("Right Arm Position & Gripper")

            # Subplot 4: Right Arm Euler Angles (Roll, Pitch, Yaw)
            ax4.plot(x_axis, right_euler[:, 0], label='right_roll', color='c',
                     linewidth=1.5, linestyle='--')
            ax4.plot(x_axis, right_euler[:, 1], label='right_pitch', color='m',
                     linewidth=1.5, linestyle='--')
            ax4.plot(x_axis, right_euler[:, 2], label='right_yaw', color='y',
                     linewidth=1.5, linestyle='--')
            ax4.set_ylabel('Rotation (rad)')
            ax4.legend(loc='upper right', fontsize='x-small', ncol=3)
            ax4.grid(True, alpha=0.3)
            ax4.set_title("Right Arm Rotation (RPY from Quaternion)")

        # Set X-axis display range to maintain sliding window effect
        # (axes share X, so setting it on one applies to all four).
        ax1.set_xlim(max(0, step_idx - window), max(window, step_idx))
        ax3.set_xlabel('Step')
        ax4.set_xlabel('Step')

        fig.tight_layout()

        canvas = FigureCanvas(fig)
        canvas.draw()
        # Drop the alpha channel and copy, so the returned image does not
        # alias the canvas buffer of a figure that is about to be closed.
        img = np.ascontiguousarray(np.asarray(canvas.buffer_rgba())[:, :, :3])
    finally:
        plt.close(fig)

    return img
def _resize_to_height(img, height):
    """Resize *img* to *height* pixels tall, keeping its aspect ratio."""
    if img.shape[0] == height:
        return img
    width = int(img.shape[1] * height / img.shape[0])
    return cv2.resize(img, (width, height))


def save_comparison_video(real_obs_list, imagined_video, action_history, save_path, fps=15):
    """Write a video stacking real camera views above the imagined stream.

    Each output frame has the three real camera images (high / left wrist /
    right wrist) side by side on top and, below them, the matching imagined
    frame scaled to the same width. Real frames without an imagined
    counterpart get a gray "Coming soon" placeholder.

    Args:
        real_obs_list: Observations with the keys
            ``observation.images.cam_high``, ``observation.images.cam_left_wrist``
            and ``observation.images.cam_right_wrist``. Nothing is written when
            this is empty.
        imagined_video: ``None`` or a sequence of frame chunks that are
            concatenated along axis 0. An empty sequence is treated as ``None``.
        action_history: Accepted for interface compatibility; not used here.
        save_path: Destination passed to ``imageio.mimsave``.
        fps (int): Frame rate of the written video.
    """
    if not real_obs_list:
        return

    n_real = len(real_obs_list)
    # np.concatenate raises on an empty sequence, so treat it like "no video".
    if imagined_video is not None and len(imagined_video) > 0:
        imagined_video = np.concatenate(imagined_video, 0)
        n_imagined = len(imagined_video)
    else:
        imagined_video = None
        n_imagined = 0

    print(f"Saving video: Real {n_real} frames, Imagined {n_imagined} frames...")

    final_frames = []
    # The number of output frames is driven by the real observations.
    for i, obs in enumerate(real_obs_list):
        cam_high = obs["observation.images.cam_high"]
        cam_left = obs["observation.images.cam_left_wrist"]
        cam_right = obs["observation.images.cam_right_wrist"]

        # Wrist views are scaled to the head camera's height before stacking.
        base_h = cam_high.shape[0]
        row_real = np.hstack([
            _resize_to_height(cam_high, base_h),
            _resize_to_height(cam_left, base_h),
            _resize_to_height(cam_right, base_h),
        ])
        if row_real.dtype != np.uint8:
            row_real = (row_real * 255).astype(np.uint8)
        row_real = add_title_bar(row_real, "Real Observation (High / Left / Right)")

        target_width = row_real.shape[1]

        if i < n_imagined:
            img_frame = imagined_video[i]
            if img_frame.dtype != np.uint8:
                # Frames in [0, 1] are rescaled; anything else is only cast.
                if img_frame.max() <= 1.0001:
                    img_frame = (img_frame * 255).astype(np.uint8)
                else:
                    img_frame = img_frame.astype(np.uint8)
            h = int(img_frame.shape[0] * target_width / img_frame.shape[1])
            row_imagined = cv2.resize(img_frame, (target_width, h))
        else:
            # Match the height of the scaled imagined frames so every frame of
            # the video has the same size; fall back to 300 px without them.
            if n_imagined:
                ref = imagined_video[0]
                h = int(ref.shape[0] * target_width / ref.shape[1])
            else:
                h = 300
            row_imagined = np.zeros((h, target_width, 3), dtype=np.uint8)
            cv2.putText(row_imagined, "Coming soon", (target_width // 2 - 100, h // 2),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (100, 100, 100), 2)

        row_imagined = add_title_bar(row_imagined, "Imagined Video Stream")

        final_frames.append(np.vstack([row_real, row_imagined]))

    imageio.mimsave(save_path, final_frames, fps=fps)
    print(f"Combined video saved to: {save_path}")


def class_decorator(task_name):
    """Import ``envs.<task_name>`` and instantiate the class named *task_name*.

    Args:
        task_name (str): Module name under ``envs`` and class name inside it.

    Returns:
        A new instance of the task environment class.

    Raises:
        SystemExit: With message ``"No Task"`` when the class is missing or
            cannot be constructed; the original error is chained as the cause.
        ImportError: Propagated unchanged when the module cannot be imported.
    """
    envs_module = importlib.import_module(f"envs.{task_name}")
    try:
        env_class = getattr(envs_module, task_name)
        env_instance = env_class()
    except Exception as exc:
        # Not a bare except: Ctrl-C must not be reported as a missing task.
        raise SystemExit("No Task") from exc
    return env_instance


def eval_function_decorator(policy_name, model_name):
    """Return attribute *model_name* from the module *policy_name*.

    Raises:
        ImportError: When the module cannot be imported.
        AttributeError: When the module has no such attribute.
    """
    policy_model = importlib.import_module(policy_name)
    return getattr(policy_model, model_name)
def get_camera_config(camera_type):
    """Look up one camera entry in ``task_config/_camera_config.yml``.

    The file is resolved relative to the module-level ``robowin_root``.

    Args:
        camera_type: Key of the camera entry to return.

    Returns:
        The value stored under *camera_type* in the YAML file.

    Raises:
        AssertionError: When the file is missing or has no such camera entry.
    """
    config_file = os.path.join(robowin_root, "task_config/_camera_config.yml")
    assert os.path.isfile(config_file), "task config file is missing"

    with open(config_file, "r", encoding="utf-8") as stream:
        camera_table = yaml.load(stream.read(), Loader=yaml.FullLoader)

    assert camera_type in camera_table, f"camera {camera_type} is not defined"
    return camera_table[camera_type]


def get_embodiment_config(robot_file):
    """Load ``config.yml`` from the embodiment directory *robot_file*.

    Args:
        robot_file: Directory that contains the embodiment's ``config.yml``.

    Returns:
        The parsed YAML content of that file.
    """
    config_file = os.path.join(robot_file, "config.yml")
    with open(config_file, "r", encoding="utf-8") as stream:
        return yaml.load(stream.read(), Loader=yaml.FullLoader)
def main(usr_args):
    """Configure one RoboTwin task, evaluate the remote policy, save the result.

    Reads ``task_name``, ``task_config``, ``ckpt_setting``, ``save_root``,
    ``policy_name``, ``video_guidance_scale``, ``action_guidance_scale``,
    ``seed``, ``test_num`` and ``port`` from *usr_args*, and writes
    ``left_arm_dim`` / ``right_arm_dim`` back into it. The task configuration
    is loaded from ``./task_config/<task_config>.yml`` relative to the current
    working directory.

    Results go to
    ``eval_result/<task_name>/<policy_name>/<task_config>/<ckpt_setting>/<timestamp>/_result.txt``.

    Args:
        usr_args (dict): User configuration as described above.

    Raises:
        ValueError: When the configured embodiment has no file path, or the
            ``embodiment`` entry does not contain exactly 1 or 3 items.
    """
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    task_name = usr_args["task_name"]
    task_config = usr_args["task_config"]
    ckpt_setting = usr_args["ckpt_setting"]
    save_root = usr_args["save_root"]
    policy_name = usr_args["policy_name"]
    video_guidance_scale = usr_args["video_guidance_scale"]
    action_guidance_scale = usr_args["action_guidance_scale"]
    instruction_type = 'seen'
    save_dir = None
    video_save_dir = None
    video_size = None

    # NOTE(review): yaml.FullLoader is kept for compatibility with existing
    # task configs; these files are assumed to be project-controlled.
    with open(f"./task_config/{task_config}.yml", "r", encoding="utf-8") as f:
        args = yaml.load(f.read(), Loader=yaml.FullLoader)

    args['task_name'] = task_name
    args["task_config"] = task_config
    args["ckpt_setting"] = ckpt_setting
    args["save_root"] = save_root

    embodiment_type = args.get("embodiment")
    embodiment_config_path = os.path.join(CONFIGS_PATH, "_embodiment_config.yml")

    with open(embodiment_config_path, "r", encoding="utf-8") as f:
        _embodiment_types = yaml.load(f.read(), Loader=yaml.FullLoader)

    def get_embodiment_file(embodiment_type):
        """Return the asset directory registered for one embodiment name."""
        robot_file = _embodiment_types[embodiment_type]["file_path"]
        if robot_file is None:
            # Raising a plain string is itself a TypeError in Python 3.
            raise ValueError("No embodiment files")
        return robot_file

    # Joined the same way as the embodiment config above, so this also works
    # when CONFIGS_PATH has no trailing separator.
    with open(os.path.join(CONFIGS_PATH, "_camera_config.yml"), "r", encoding="utf-8") as f:
        _camera_config = yaml.load(f.read(), Loader=yaml.FullLoader)

    head_camera_type = args["camera"]["head_camera_type"]
    args["head_camera_h"] = _camera_config[head_camera_type]["h"]
    args["head_camera_w"] = _camera_config[head_camera_type]["w"]

    if len(embodiment_type) == 1:
        # One entry: the same robot description serves both arms.
        args["left_robot_file"] = get_embodiment_file(embodiment_type[0])
        args["right_robot_file"] = get_embodiment_file(embodiment_type[0])
        args["dual_arm_embodied"] = True
    elif len(embodiment_type) == 3:
        # Three entries: left robot, right robot, and a value stored as
        # "embodiment_dis".
        args["left_robot_file"] = get_embodiment_file(embodiment_type[0])
        args["right_robot_file"] = get_embodiment_file(embodiment_type[1])
        args["embodiment_dis"] = embodiment_type[2]
        args["dual_arm_embodied"] = False
    else:
        raise ValueError("embodiment items should be 1 or 3")

    args["left_embodiment_config"] = get_embodiment_config(args["left_robot_file"])
    args["right_embodiment_config"] = get_embodiment_config(args["right_robot_file"])

    if len(embodiment_type) == 1:
        embodiment_name = str(embodiment_type[0])
    else:
        embodiment_name = str(embodiment_type[0]) + "+" + str(embodiment_type[1])

    save_dir = Path(f"eval_result/{task_name}/{policy_name}/{task_config}/{ckpt_setting}/{current_time}")
    save_dir.mkdir(parents=True, exist_ok=True)

    if args["eval_video_log"]:
        video_save_dir = save_dir
        camera_config = get_camera_config(args["camera"]["head_camera_type"])
        # "<w>x<h>" string, forwarded to eval_policy as video_size.
        video_size = str(camera_config["w"]) + "x" + str(camera_config["h"])
        video_save_dir.mkdir(parents=True, exist_ok=True)
        args["eval_video_save_dir"] = video_save_dir

    randomization = args["domain_randomization"]
    camera_args = args["camera"]

    print("============= Config =============\n")
    print("\033[95mMessy Table:\033[0m " + str(randomization["cluttered_table"]))
    print("\033[95mRandom Background:\033[0m " + str(randomization["random_background"]))
    if randomization["random_background"]:
        print(" - Clean Background Rate: " + str(randomization["clean_background_rate"]))
    print("\033[95mRandom Light:\033[0m " + str(randomization["random_light"]))
    if randomization["random_light"]:
        print(" - Crazy Random Light Rate: " + str(randomization["crazy_random_light_rate"]))
    print("\033[95mRandom Table Height:\033[0m " + str(randomization["random_table_height"]))
    print("\033[95mRandom Head Camera Distance:\033[0m " + str(randomization["random_head_camera_dis"]))

    print("\033[94mHead Camera Config:\033[0m " + str(camera_args["head_camera_type"]) + ", " +
          str(camera_args["collect_head_camera"]))
    print("\033[94mWrist Camera Config:\033[0m " + str(camera_args["wrist_camera_type"]) + ", " +
          str(camera_args["collect_wrist_camera"]))
    print("\033[94mEmbodiment Config:\033[0m " + embodiment_name)
    print("\n==================================")

    TASK_ENV = class_decorator(args["task_name"])
    args["policy_name"] = policy_name
    # NOTE(review): presumably arm_joints_name is [left_names, right_names];
    # the indices 0 / 1 are kept exactly as in the original code.
    usr_args["left_arm_dim"] = len(args["left_embodiment_config"]["arm_joints_name"][0])
    usr_args["right_arm_dim"] = len(args["right_embodiment_config"]["arm_joints_name"][1])

    seed = usr_args["seed"]

    st_seed = 10000 * (1 + seed)
    suc_nums = []
    test_num = usr_args["test_num"]

    model = WebsocketClientPolicy(port=usr_args['port'])

    st_seed, suc_num = eval_policy(task_name,
                                   TASK_ENV,
                                   args,
                                   model,
                                   st_seed,
                                   test_num=test_num,
                                   video_size=video_size,
                                   instruction_type=instruction_type,
                                   save_visualization=True,
                                   video_guidance_scale=video_guidance_scale,
                                   action_guidance_scale=action_guidance_scale)
    suc_nums.append(suc_num)

    file_path = os.path.join(save_dir, "_result.txt")
    with open(file_path, "w") as file:
        file.write(f"Timestamp: {current_time}\n\n")
        file.write(f"Instruction Type: {instruction_type}\n\n")
        # One success rate per evaluation run (a single run here).
        file.write("\n".join(map(str, np.array(suc_nums) / test_num)))

    print(f"Data has been saved to {file_path}")
def format_obs(observation, prompt):
    """Repackage a simulator observation into the flat dict sent to the policy.

    Args:
        observation (dict): Must provide ``observation/<camera>/rgb`` for
            ``head_camera``, ``left_camera`` and ``right_camera``, and
            ``joint_action/vector``.
        prompt: Task instruction stored under the ``"task"`` key.

    Returns:
        dict with the three camera images, the state vector and the task.
    """
    cameras = observation["observation"]
    return {
        "observation.images.cam_high": cameras["head_camera"]["rgb"],  # H,W,3
        "observation.images.cam_left_wrist": cameras["left_camera"]["rgb"],
        "observation.images.cam_right_wrist": cameras["right_camera"]["rgb"],
        "observation.state": observation["joint_action"]["vector"],
        "task": prompt,
    }


def add_eef_pose(new_pose, init_pose):
    """Compose a relative single-arm pose with an initial pose.

    Both poses use the layout ``[x, y, z, qx, qy, qz, qw, gripper]``.
    Translations are added, the rotation is ``init_rotation * new_rotation``,
    and the gripper value is taken from *new_pose*.

    Args:
        new_pose: Array-like with 8 values (lists and tuples are accepted).
        init_pose: Array-like with at least 7 values.

    Returns:
        ``np.ndarray`` of shape (8,) holding the composed pose.
    """
    # Coerce first: list inputs do not support the [None] batch-axis indexing.
    new_pose = np.asarray(new_pose)
    init_pose = np.asarray(init_pose)

    new_pose_R = R.from_quat(new_pose[3:7][None])
    init_pose_R = R.from_quat(init_pose[3:7][None])
    out_rot = (init_pose_R * new_pose_R).as_quat().reshape(-1)
    out_trans = new_pose[:3] + init_pose[:3]
    return np.concatenate([out_trans, out_rot, new_pose[7:8]])


def add_init_pose(new_pose, init_pose):
    """Apply :func:`add_eef_pose` to both arms of a 16-value dual-arm pose.

    Args:
        new_pose: Array-like, left arm in ``[:8]`` and right arm in ``[8:]``.
        init_pose: Array-like with the same layout.

    Returns:
        ``np.ndarray`` of shape (16,): composed left pose, then right pose.
    """
    new_pose = np.asarray(new_pose)
    init_pose = np.asarray(init_pose)
    left_pose = add_eef_pose(new_pose[:8], init_pose[:8])
    right_pose = add_eef_pose(new_pose[8:], init_pose[8:])
    return np.concatenate([left_pose, right_pose])
right_pose = add_eef_pose(new_pose[8:], init_pose[8:]) return np.concatenate([left_pose, right_pose]) def eval_policy(task_name, TASK_ENV, args, model, st_seed, test_num=100, video_size=None, instruction_type=None, save_visualization=False, video_guidance_scale=5.0, action_guidance_scale=5.0): print(f"\033[34mTask Name: {args['task_name']}\033[0m") print(f"\033[34mPolicy Name: {args['policy_name']}\033[0m") expert_check = True TASK_ENV.suc = 0 TASK_ENV.test_num = 0 now_id = 0 succ_seed = 0 suc_test_seed_list = [] now_seed = st_seed clear_cache_freq = args["clear_cache_freq"] args["eval_mode"] = True while succ_seed < test_num: render_freq = args["render_freq"] args["render_freq"] = 0 if expert_check: try: TASK_ENV.setup_demo(now_ep_num=now_id, seed=now_seed, is_test=True, **args) episode_info = TASK_ENV.play_once() TASK_ENV.close_env() except UnStableError as e: TASK_ENV.close_env() now_seed += 1 args["render_freq"] = render_freq continue except Exception as e: TASK_ENV.close_env() now_seed += 1 args["render_freq"] = render_freq print(f"error occurs ! 
{e}") traceback.print_exc() continue if (not expert_check) or (TASK_ENV.plan_success and TASK_ENV.check_success()): succ_seed += 1 suc_test_seed_list.append(now_seed) else: now_seed += 1 args["render_freq"] = render_freq continue args["render_freq"] = render_freq TASK_ENV.setup_demo(now_ep_num=now_id, seed=now_seed, is_test=True, **args) episode_info_list = [episode_info["info"]] results = generate_episode_descriptions(args["task_name"], episode_info_list, test_num) instruction = np.random.choice(results[0][instruction_type]) TASK_ENV.set_instruction(instruction=instruction) # set language instruction if TASK_ENV.eval_video_path is not None: ffmpeg = subprocess.Popen( [ "ffmpeg", "-y", "-loglevel", "error", "-f", "rawvideo", "-pixel_format", "rgb24", "-video_size", video_size, "-framerate", "10", "-i", "-", "-pix_fmt", "yuv420p", "-vcodec", "libx264", "-crf", "23", f"{TASK_ENV.eval_video_path}/episode{TASK_ENV.test_num}.mp4", ], stdin=subprocess.PIPE, ) TASK_ENV._set_eval_video_ffmpeg(ffmpeg) succ = False prompt = TASK_ENV.get_instruction() ret = model.infer(dict(reset = True, prompt=prompt, save_visualization=save_visualization)) first = True full_obs_list = [] gen_video_list = [] full_action_history = [] initial_obs = TASK_ENV.get_obs() inint_eef_pose = initial_obs['endpose']['left_endpose'] + \ [initial_obs['endpose']['left_gripper']] + \ initial_obs['endpose']['right_endpose'] + \ [initial_obs['endpose']['right_gripper']] inint_eef_pose = np.array(inint_eef_pose, dtype=np.float64) initial_formatted_obs = format_obs(initial_obs, prompt) full_obs_list.append(initial_formatted_obs) first_obs = None while TASK_ENV.take_action_cnt \033[95m{round(TASK_ENV.suc/TASK_ENV.test_num*100, 1)}%\033[0m, current seed: \033[90m{now_seed}\033[0m\n" ) now_seed += 1 return now_seed, TASK_ENV.suc def parse_args_and_config(): parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True) parser.add_argument("--overrides", nargs=argparse.REMAINDER) 
def parse_args_and_config():
    """Build the user configuration from the YAML file and the command line.

    The YAML file given by ``--config`` is loaded first. ``--overrides`` takes
    every remaining token as ``--key value`` pairs that replace or add keys.
    Finally the dedicated flags (``--port``, ``--save_root``,
    ``--video_guidance_scale``, ``--action_guidance_scale``, ``--test_num``)
    fill in any of those keys still missing, so their declared defaults are
    honoured instead of being silently dropped. Values already present in the
    YAML file or the overrides keep precedence.

    Returns:
        dict: The merged configuration.

    Raises:
        SystemExit: Via ``argparse`` on invalid arguments, including an
            override key without a value.
    """
    import ast  # local import: only this function needs it

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--overrides", nargs=argparse.REMAINDER)
    parser.add_argument("--port", type=int, default=8000, help='remote policy socket port.')
    parser.add_argument("--save_root", type=str, default="results/default_vis_path")
    parser.add_argument("--video_guidance_scale", type=float, default=5.0)
    parser.add_argument("--action_guidance_scale", type=float, default=5.0)
    parser.add_argument("--test_num", type=int, default=100)
    args = parser.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        # An empty YAML file parses to None; treat it as an empty config.
        config = yaml.safe_load(f) or {}

    def parse_override_pairs(pairs):
        """Turn ``[--key, value, --key, value, ...]`` into a dict."""
        if len(pairs) % 2:
            parser.error("--overrides expects '--key value' pairs; "
                         f"got an odd number of tokens: {pairs}")
        override_dict = {}
        for raw_key, raw_value in zip(pairs[::2], pairs[1::2]):
            key = raw_key.lstrip("-")
            try:
                # literal_eval only accepts Python literals (numbers, bools,
                # None, strings, containers). Unlike eval() it never executes
                # code and never maps a word such as "open" to a builtin.
                value = ast.literal_eval(raw_value)
            except (ValueError, SyntaxError, TypeError):
                value = raw_value  # not a literal: keep the raw string
            override_dict[key] = value
        return override_dict

    if args.overrides:
        config.update(parse_override_pairs(args.overrides))

    # Dedicated CLI flags act as fallbacks for keys the config does not define.
    for name in ("port", "save_root", "video_guidance_scale",
                 "action_guidance_scale", "test_num"):
        config.setdefault(name, getattr(args, name))

    return config


if __name__ == "__main__":
    Sapien_TEST()

    usr_args = parse_args_and_config()

    main(usr_args)