Instructions to use EndeavourDD/gnn_wm with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use EndeavourDD/gnn_wm with Diffusers:
```bash
pip install -U diffusers transformers accelerate
```

```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("EndeavourDD/gnn_wm", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| from openpi.training import config as config_pi | |
| from openpi.policies import policy_config | |
| from openpi_client import image_tools | |
| # from openpi.shared import download | |
| import numpy as np | |
| from accelerate import Accelerator | |
| import torch | |
| from diffusers import StableVideoDiffusionPipeline | |
| import numpy as np | |
| # import cv2 | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.nn as nn | |
| import einops | |
| from accelerate import Accelerator | |
| import datetime | |
| import os | |
| from accelerate.logging import get_logger | |
| from tqdm.auto import tqdm | |
| import wandb | |
| import json | |
| from decord import VideoReader, cpu | |
| import swanlab | |
| import mediapy | |
| import sys | |
| from scipy.spatial.transform import Rotation as R | |
| import sys, os | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from models.pipeline_ctrl_world import CtrlWorldDiffusionPipeline | |
| from models.ctrl_world import CrtlWorld | |
| from models.utils import key_board_control, get_fk_solution | |
class agent():
    """Closed-loop rollout helper pairing a pi policy with the Ctrl-World model.

    The policy proposes an action chunk from the latest (predicted) frames,
    an action adapter plus forward kinematics turn it into cartesian poses,
    and the world model predicts the next video latents conditioned on them.
    """

    def __init__(self, args):
        """Load the policy, the world model, normalisation stats and the adapter.

        Args:
            args: configuration object. This constructor reads ``ckpt_path``,
                ``dtype``, ``policy_type``, ``pi_ckpt``, ``data_stat_path``
                and ``action_adapter``; it also sets ``args.val_model_path``
                to ``args.ckpt_path``.

        Raises:
            ValueError: if ``args.policy_type`` names no known pi policy.
        """
        args.val_model_path = args.ckpt_path
        self.args = args
        self.accelerator = Accelerator()
        self.device = self.accelerator.device
        self.dtype = args.dtype

        # load pi policy
        # The order of the checks matters: 'pi05' and 'pi0fast' both contain
        # the substring 'pi0', so the most specific names are tested first.
        if 'pi05' in args.policy_type:
            config = config_pi.get_config("pi05_droid")
        elif 'pi0fast' in args.policy_type:
            config = config_pi.get_config("pi0fast_droid")
        elif 'pi0' in args.policy_type:
            config = config_pi.get_config("pi0_droid")
        else:
            raise ValueError(f"Unknown policy type: {args.policy_type}")
        self.policy = policy_config.create_trained_policy(config, args.pi_ckpt)

        # load ctrl-world model
        self.model = CrtlWorld(args)
        # Deserialize on CPU first so a checkpoint saved on a different device
        # layout still loads; the model is moved to the target device below.
        self.model.load_state_dict(torch.load(args.val_model_path, map_location="cpu"))
        self.model.to(self.accelerator.device).to(self.dtype)
        self.model.eval()
        print("load world model success")

        # 1st / 99th percentile state statistics used to normalise actions.
        with open(f"{args.data_stat_path}", 'r') as f:
            data_stat = json.load(f)
        self.state_p01 = np.array(data_stat['state_01'])[None, :]
        self.state_p99 = np.array(data_stat['state_99'])[None, :]

        # The official Pi-Droid model outputs joint velocities while Ctrl-World
        # is trained in cartesian space, so a light-weight adapter is loaded to
        # turn joint-velocity actions into joint positions (FK is applied later).
        # Always define the attribute so forward_policy can report a clear
        # error instead of an AttributeError when no adapter was configured.
        self.dynamics_model = None
        if args.action_adapter is not None:
            from models.action_adapter.train2 import Dynamics
            self.dynamics_model = Dynamics(action_dim=7, action_num=15, hidden_size=512).to(self.device)
            self.dynamics_model.load_state_dict(torch.load(args.action_adapter, map_location=self.device))

    def normalize_bound(
        self,
        data: np.ndarray,
        data_min: np.ndarray,
        data_max: np.ndarray,
        clip_min: float = -1,
        clip_max: float = 1,
        eps: float = 1e-8,
    ) -> np.ndarray:
        """Linearly map ``data`` from [data_min, data_max] to [-1, 1] and clip.

        ``eps`` keeps the division finite when ``data_max == data_min``.
        """
        ndata = 2 * (data - data_min) / (data_max - data_min + eps) - 1
        return np.clip(ndata, clip_min, clip_max)

    def get_traj_info(self, id, start_idx=0, steps=8, skip=1):
        """Load ground-truth states, joints, frames and VAE latents of a trajectory.

        Args:
            id: trajectory id; the annotation is read from
                ``{val_dataset_dir}/annotation/val/{id}.json``.
            start_idx: index of the first frame to sample.
            steps: number of frames to sample.
            skip: stride between sampled frames.

        Returns:
            ``(car_action, joint_pos, video_dict, video_latent, instruction)``:
            the annotated ``states`` and ``joints`` at the sampled frames, one
            array of raw frames per view, one tensor of scaled VAE latents per
            view, and the first annotated text.
        """
        val_dataset_dir = self.args.val_dataset_dir
        annotation_path = f"{val_dataset_dir}/annotation/val/{id}.json"
        with open(annotation_path) as f:
            anno = json.load(f)

        # Prefer the action list for the trajectory length; fall back to the
        # explicit 'video_length' field when it is missing or has no length.
        try:
            length = len(anno['action'])
        except (KeyError, TypeError):
            length = anno["video_length"]

        # Indices running past the end are clamped to the last frame, so the
        # final frame is repeated rather than raising.
        frames_ids = np.arange(start_idx, start_idx + steps * skip, skip)
        frames_ids = np.minimum(frames_ids, length - 1).astype(int)
        print("Ground truth frames ids", frames_ids)

        # get action and joint pos
        instruction = anno['texts'][0]
        car_action = np.array(anno['states'])[frames_ids]
        joint_pos = np.array(anno['joints'])[frames_ids]

        # get videos
        video_dict = []
        video_latent = []
        device = self.device
        vae = self.model.pipeline.vae
        encode_batch_size = 32
        for view in anno['videos']:
            video_path = f"{val_dataset_dir}/{view['video_path']}"
            # load videos from all views
            vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
            frames = vr.get_batch(range(length))
            # The batch exposes either .asnumpy() or .numpy(); presumably this
            # depends on the active decord bridge — TODO confirm.
            true_video = frames.asnumpy() if hasattr(frames, 'asnumpy') else frames.numpy()
            true_video = true_video[frames_ids]
            video_dict.append(true_video)

            # encode video: HWC uint8 in [0, 255] -> CHW in [-1, 1]
            pixels = torch.from_numpy(true_video).to(self.dtype).to(device)
            x = pixels.permute(0, 3, 1, 2) / 255.0 * 2 - 1
            with torch.no_grad():
                latents = []
                for i in range(0, len(x), encode_batch_size):
                    batch = x[i:i + encode_batch_size]
                    latent = vae.encode(batch).latent_dist.sample().mul_(vae.config.scaling_factor)
                    latents.append(latent)
                x = torch.cat(latents, dim=0)
            video_latent.append(x)
        return car_action, joint_pos, video_dict, video_latent, instruction

    def _decode_latents(self, latents):
        """Decode scaled VAE latents into uint8 frames.

        Args:
            latents: tensor whose two leading axes are (batch, frame) and whose
                remaining axes are what ``vae.decode`` expects.

        Returns:
            ``np.uint8`` array with the channel axis moved last, i.e.
            (batch, frame, H, W, C).
        """
        vae = self.model.pipeline.vae
        chunk_size = self.args.decode_chunk_size
        bsz, frame_num = latents.shape[:2]
        flat = latents.flatten(0, 1)
        decoded = []
        # Decoding is inference only; no_grad avoids building autograd graphs.
        with torch.no_grad():
            for i in range(0, flat.shape[0], chunk_size):
                chunk = flat[i:i + chunk_size] / vae.config.scaling_factor
                decoded.append(vae.decode(chunk, num_frames=chunk.shape[0]).sample)
        video = torch.cat(decoded, dim=0)
        video = video.reshape(bsz, frame_num, *video.shape[1:])
        # [-1, 1] -> [0, 255]
        video = ((video / 2.0 + 0.5).clamp(0, 1) * 255)
        return video.detach().to(torch.float32).cpu().numpy().transpose(0, 1, 3, 4, 2).astype(np.uint8)

    def forward_wm(self, action_cond, video_latent_true, video_latent_cond, his_cond=None, text=None):
        """Predict future frames with the world model and decode them.

        Args:
            action_cond: un-normalised actions for the history and future
                frames, ``(num_history + num_frames, action_dim)``.
            video_latent_true: per-view list of ground-truth latents to decode
                next to the prediction for comparison.
            video_latent_cond: conditioning latent, ``(1, 4, 72, 40)``.
            his_cond: history latents forwarded to the pipeline as ``history``.
            text: optional instruction; when given it is encoded together with
                the actions.

        Returns:
            ``(videos_cat, true_video, videos, latents)``: a uint8 mosaic of
            ground truth and prediction, the decoded ground-truth frames per
            view, the decoded predicted frames per view, and the predicted
            latents split per view.
        """
        args = self.args
        image_cond = video_latent_cond

        # action should be normed
        action_cond = self.normalize_bound(action_cond, self.state_p01, self.state_p99, clip_min=-1, clip_max=1)
        action_cond = torch.tensor(action_cond).unsqueeze(0).to(self.device).to(self.dtype)
        assert image_cond.shape[1:] == (4, 72, 40)
        assert action_cond.shape[1:] == (args.num_frames+args.num_history, args.action_dim)

        # predict future frames
        with torch.no_grad():
            if text is not None:
                text_token = self.model.action_encoder(action_cond, text, self.model.tokenizer, self.model.text_encoder)
            else:
                text_token = self.model.action_encoder(action_cond)
            pipeline = self.model.pipeline
            _, latents = CtrlWorldDiffusionPipeline.__call__(
                pipeline,
                image=image_cond,
                text=text_token,
                width=args.width,
                height=int(args.height*3),  # three views stacked along the height axis
                num_frames=args.num_frames,
                history=his_cond,
                num_inference_steps=args.num_inference_steps,
                decode_chunk_size=args.decode_chunk_size,
                max_guidance_scale=args.guidance_scale,
                fps=args.fps,
                motion_bucket_id=args.motion_bucket_id,
                mask=None,
                output_type='latent',
                return_dict=False,
                frame_level_cond=True,
            )
        # Split the height-stacked latent back into one entry per view.
        latents = einops.rearrange(latents, 'b f c (m h) (n w) -> (b m n) f c h w', m=3, n=1)

        # decode ground truth video and predicted video with the same routine
        true_video = self._decode_latents(torch.stack(video_latent_true, dim=0))
        videos = self._decode_latents(latents)

        # concatenate true videos and video: ground truth above prediction
        # (height axis), then the views side by side (width axis).
        videos_cat = np.concatenate([true_video, videos], axis=-3)
        videos_cat = np.concatenate(list(videos_cat), axis=-2).astype(np.uint8)
        return videos_cat, true_video, videos, latents

    @staticmethod
    def _resize_to_policy_input(image):
        """Bilinearly resize one (H, W, 3) uint8 frame to (180, 320, 3) uint8."""
        tensor = torch.from_numpy(image).to(torch.uint8)
        tensor = tensor.permute(2, 0, 1).unsqueeze(0).float()
        tensor = torch.nn.functional.interpolate(tensor, size=(180, 320), mode='bilinear', align_corners=False)
        return tensor.squeeze(0).permute(1, 2, 0).to(torch.uint8).numpy()

    def forward_policy(self, videos, state, joints, text, time_step=1):
        """Run the pi policy and convert its output to cartesian poses.

        Args:
            videos: per-view frames; index 1 is fed to the policy as the
                exterior image and index 2 as the wrist image.
            state: unused; kept for interface compatibility.
            joints: current joint vector; the first 7 entries are used as joint
                positions and the last entry as the gripper position.
            text: instruction passed to the policy as the prompt.
            time_step: unused; kept for interface compatibility.

        Returns:
            ``(policy_in_out, joint_pos_skip, state_fk_skip)``: a dict with the
            dense joint positions / velocities / FK states, the sub-sampled
            joint positions with the gripper appended, and the sub-sampled
            cartesian states (xyz, euler 'xyz', gripper).

        Raises:
            RuntimeError: if no action adapter was loaded.
        """
        # Fail fast, before running the (expensive) policy, when the adapter
        # needed below is missing.
        if self.dynamics_model is None:
            raise RuntimeError(
                "forward_policy requires the action adapter; set args.action_adapter to its checkpoint path"
            )

        # inference policy
        image1 = videos[1]
        image2 = videos[2]
        assert image1.shape == (192, 320, 3), "Image 1 shape should be (192, 320, 3), got {}".format(image1.shape)
        image1 = self._resize_to_policy_input(image1)
        image2 = self._resize_to_policy_input(image2)
        example = {
            "observation/exterior_image_1_left": image_tools.resize_with_pad(image1, 224, 224),
            "observation/wrist_image_left": image_tools.resize_with_pad(image2, 224, 224),
            "observation/joint_position": joints[:7],
            "observation/gripper_position": joints[-1:],
            "prompt": text,
        }
        # joint velocity + gripper position per step
        action_chunk = self.policy.infer(example)["actions"]

        # action adapter
        current_joint = joints[None, :][:, :7]
        current_gripper = joints[None, :][:, 7:]
        # The dynamics model needs 15 steps; for the other policies the last of
        # the first ten steps is repeated to pad the chunk to that length.
        if 'pi05' in self.args.policy_type:
            idx = list(range(15))
        else:
            idx = list(range(10)) + [9] * 5

        # policy output joint velocity and gripper position
        joint_vel = action_chunk[:, :7][idx]
        gripper_pos = action_chunk[:, 7:][idx]
        gripper_max = self.args.gripper_max
        z_min = self.args.z_min
        gripper_pos = np.clip(gripper_pos, 0, gripper_max)

        # calculate future joint positions
        joint_pos = self.dynamics_model(current_joint, joint_vel, None, training=False)

        # Prepend the current state so row 0 is "now", then keep 15 rows.
        joint_pos = np.concatenate([current_joint, joint_pos], axis=0)[:15]
        gripper_pos = np.concatenate([current_gripper, gripper_pos], axis=0)[:15]

        # fk
        state_fk = []
        for joint, gripper in zip(joint_pos, gripper_pos):
            current_state_fk = get_fk_solution(joint[:7])
            xyz = current_state_fk[:3, 3]
            # clip z axis to avoid collision with table
            xyz[2] = np.clip(xyz[2], z_min, None)
            euler = R.from_matrix(current_state_fk[:3, :3]).as_euler('xyz')
            state_fk.append(np.concatenate([xyz, euler, gripper], axis=0))
        state_fk = np.array(state_fk)

        # prepare output
        skip = self.args.policy_skip_step
        pred_step = self.args.pred_step
        valid_num = int(skip * (pred_step - 1))
        policy_in_out = {
            'joint_pos': joint_pos[:valid_num],
            'joint_vel': joint_vel[:valid_num],
            'state_fk': state_fk[:valid_num],
        }
        state_fk_skip = state_fk[::skip][:pred_step]
        joint_pos_skip = joint_pos[::skip][:pred_step]
        # add gripper pos as the last column
        joint_pos_skip = np.concatenate([joint_pos_skip, state_fk_skip[:, -1:]], axis=-1)
        return policy_in_out, joint_pos_skip, state_fk_skip
if __name__ == "__main__":
    # Interactive rollout: for every validation trajectory, alternate between
    # the pi policy (proposes actions) and the world model (predicts frames),
    # then save the rollout video and the policy trace.
    from config_eval import wm_args
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument('--svd_model_path', type=str, default=None)
    parser.add_argument('--clip_model_path', type=str, default=None)
    parser.add_argument('--ckpt_path', type=str, default=None)
    parser.add_argument('--dataset_root_path', type=str, default=None)
    parser.add_argument('--dataset_meta_info_path', type=str, default=None)
    parser.add_argument('--dataset_names', type=str, default=None)
    parser.add_argument('--task_type', type=str, default=None)
    parser.add_argument('--pi_ckpt', type=str, default='/cephfs/shared/llm/openpi/openpi-assets-preview/checkpoints/pi05_droid')
    args_new = parser.parse_args()
    args = wm_args(task_type=args_new.task_type)

    def merge_args(cfg, cli_args):
        """Copy every CLI option that is not None onto ``cfg`` and return it."""
        for k, v in vars(cli_args).items():
            if v is not None:
                setattr(cfg, k, v)
        return cfg

    args = merge_args(args, args_new)

    # create agent
    Agent = agent(args)
    interact_num = args.interact_num
    pred_step = args.pred_step
    # e.g. [0,0,-12,-9,-6,-3]: positions in the history buffers used as context
    history_idx = args.history_idx

    # Characters stripped from / replaced in the instruction for file names.
    text_id_table = str.maketrans({' ': '_', ',': None, '.': None, '\'': None, '\"': None})

    # run len(val_id) trajectory
    for val_id_i, text_i, start_idx_i in zip(args.val_id, args.instruction, args.start_idx):
        # get initial state and ground truth
        eef_gt, joint_pos_gt, video_dict, video_latents, _ = Agent.get_traj_info(
            val_id_i, start_idx=start_idx_i, steps=int(pred_step*interact_num+8)
        )
        print("text_i:", text_i, "eef pose at t=0", eef_gt[0], "joint at t=0", joint_pos_gt[0])

        # initialize all history buffer
        video_to_save, info_to_save = [], []
        his_cond, his_joint, his_eef = [], [], []
        # Stack the first latent of every view along the height axis.
        first_latent = torch.cat([v[0] for v in video_latents], dim=1).unsqueeze(0)  # (1, 4, 72, 40)
        assert first_latent.shape == (1, 4, 72, 40), f"Expected first_latent shape (1, 4, 72, 40), got {first_latent.shape}"
        # Pre-fill the buffers with the initial observation so that negative
        # history indices are valid from the first interaction step.
        for _ in range(Agent.args.num_history*4):
            his_cond.append(first_latent)
            his_joint.append(joint_pos_gt[0:1])
            his_eef.append(eef_gt[0:1])
        video_dict_pred = [v[0:1] for v in video_dict]

        # start rollout
        for i in range(interact_num):
            # get ground truth video latents; consecutive windows overlap by
            # one frame because the start advances by pred_step - 1.
            start_id = int(i*(pred_step-1))
            end_id = start_id + pred_step
            video_latent_true = [v[start_id:end_id] for v in video_latents]

            print("################ policy forward ####################")
            # prepare input for policy: latest joint vector, pose and frames
            current_joint = his_joint[-1][0]
            current_pose = his_eef[-1][0]
            current_obs = [v[-1] for v in video_dict_pred]
            # forward policy
            policy_in_out, joint_pos, cartesian_pose = Agent.forward_policy(current_obs, current_pose, current_joint, text=text_i)
            print("cartesian space action", cartesian_pose[0])  # output xyz and gripper for debug
            print("cartesian space action", cartesian_pose[-1])  # output xyz and gripper for debug

            print("################ world model forward ################")
            # prepare input for world model
            print(f'task: {text_i}, traj_id: {val_id_i}, interact step: {i}/{interact_num}')
            action_cond = np.concatenate([his_eef[idx] for idx in history_idx], axis=0)
            action_cond = np.concatenate([action_cond, cartesian_pose], axis=0)  # history rows followed by future rows
            his_latent = torch.cat([his_cond[idx] for idx in history_idx], dim=0).unsqueeze(0)
            current_latent = his_cond[-1]  # (1, 4, 72, 40)
            # forward world model
            videos_cat, true_videos, video_dict_pred, predict_latents = Agent.forward_wm(
                action_cond,
                video_latent_true,
                current_latent,
                his_cond=his_latent,
                text=text_i if Agent.args.text_cond else None,
            )

            print("################ record information ################")
            # push current step to history buffer
            his_joint.append(joint_pos[pred_step-1][None, :])
            his_eef.append(cartesian_pose[pred_step-1][None, :])
            his_cond.append(torch.cat([v[pred_step-1] for v in predict_latents], dim=1).unsqueeze(0))  # (1, 4, 72, 40)
            # The last frame of this window is the first of the next one, so
            # it is dropped here to avoid saving it twice.
            video_to_save.append(videos_cat[:pred_step-1])
            info_to_save.append(policy_in_out)  # save policy output info

        # save rollout video and info with parameters
        print("##########################################################################")
        video = np.concatenate(video_to_save, axis=0)
        text_id = text_i.translate(text_id_table)[:40]
        uuid = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filename_video = f"{args.save_dir}/{args.task_name}/video/{args.task_type}_time_{uuid}_traj_{val_id_i}_{start_idx_i}_{args.policy_skip_step}_{text_id}.mp4"
        os.makedirs(os.path.dirname(filename_video), exist_ok=True)
        mediapy.write_video(filename_video, video, fps=4)
        print(f"Saving video to {filename_video}")

        info = {'success': 1, 'start_idx': 0, 'end_idx': video.shape[0]-1, 'instructions': text_i}
        # Flatten the per-step policy traces into one list per key.
        for key in info_to_save[0]:
            info[key] = [row for step_info in info_to_save for row in step_info[key].tolist()]

        # save to json
        filename_info = f"{args.save_dir}/{args.task_name}/info/{args.task_type}_time_{uuid}_traj_{val_id_i}_{start_idx_i}_{pred_step}_{text_id}.json"
        os.makedirs(os.path.dirname(filename_info), exist_ok=True)
        with open(filename_info, 'w') as f:
            json.dump(info, f, indent=4)
        print(f"Saving trajectory info to {filename_info}")
        print("##########################################################################")
| # CUDA_VISIBLE_DEVICES=0 XLA_PYTHON_CLIENT_MEM_FRACTION=0.4 python rollout_interact_pi.py --task_type pickplace | |
| # CUDA_VISIBLE_DEVICES=0 XLA_PYTHON_CLIENT_MEM_FRACTION=0.4 python scripts/rollout_interact_pi.py --dataset_root_path dataset_example --dataset_meta_info_path dataset_meta_info --dataset_names droid_subset --task_type pickplace |