# gnn_wm / Ctrl-World / scripts / rollout_interact_pi_eval.py
# Uploaded by EndeavourDD with the upload-large-folder tool (commit 09a71b2, verified).
from openpi.training import config as config_pi
from openpi.policies import policy_config
from openpi_client import image_tools
# from openpi.shared import download
import numpy as np
from accelerate import Accelerator
import torch
from diffusers import StableVideoDiffusionPipeline
import numpy as np
# import cv2
import torch
import torch.nn.functional as F
import torch.nn as nn
import einops
from accelerate import Accelerator
import datetime
import os
from accelerate.logging import get_logger
from tqdm.auto import tqdm
import wandb
import json
from decord import VideoReader, cpu
import swanlab
import mediapy
import sys
from scipy.spatial.transform import Rotation as R
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.pipeline_ctrl_world import CtrlWorldDiffusionPipeline
from models.ctrl_world import CrtlWorld
from models.utils import key_board_control, get_fk_solution
class agent():
    """Closed-loop evaluation agent: a pi0-family DROID policy acting inside Ctrl-World.

    ``forward_policy`` queries the policy on (real or imagined) camera frames and turns
    its joint-velocity output into Cartesian states via the action adapter and forward
    kinematics; ``forward_wm`` conditions the world model on those states to predict the
    next frames; ``get_traj_info`` loads the ground-truth trajectory a rollout starts from.
    """

    def __init__(self, args):
        """Load the policy, the Ctrl-World model, the state statistics and the action adapter.

        Args:
            args: evaluation config. Read here: ``ckpt_path``, ``dtype``, ``policy_type``,
                ``pi_ckpt``, ``data_stat_path``, ``action_adapter`` (plus whatever
                ``CrtlWorld(args)`` reads). ``args.val_model_path`` is overwritten with
                ``args.ckpt_path``.

        Raises:
            ValueError: if ``args.policy_type`` matches none of pi05 / pi0fast / pi0.
        """
        args.val_model_path = args.ckpt_path
        self.args = args
        self.accelerator = Accelerator()
        self.device = self.accelerator.device
        self.dtype = args.dtype

        # Load the pi policy. Order matters: 'pi05' and 'pi0fast' both contain 'pi0'.
        if 'pi05' in args.policy_type:
            config = config_pi.get_config("pi05_droid")
        elif 'pi0fast' in args.policy_type:
            config = config_pi.get_config("pi0fast_droid")
        elif 'pi0' in args.policy_type:
            config = config_pi.get_config("pi0_droid")
        else:
            raise ValueError(f"Unknown policy type: {args.policy_type}")
        self.policy = policy_config.create_trained_policy(config, args.pi_ckpt)

        # Load the Ctrl-World model. Mapping the checkpoint to CPU first lets it load no
        # matter which device it was saved from; the model is moved to self.device below.
        self.model = CrtlWorld(args)
        self.model.load_state_dict(torch.load(args.val_model_path, map_location='cpu'))
        self.model.to(self.device).to(self.dtype)
        self.model.eval()
        print("load world model success")

        # Per-dimension state statistics used by normalize_bound in forward_wm.
        with open(args.data_stat_path, 'r') as f:
            data_stat = json.load(f)
        self.state_p01 = np.array(data_stat['state_01'])[None, :]
        self.state_p99 = np.array(data_stat['state_99'])[None, :]

        # The official pi0 DROID policies output joint velocities while Ctrl-World is
        # trained in Cartesian space, so a light-weight adapter turns a joint-velocity chunk
        # into future joint positions (forward kinematics then gives Cartesian poses).
        # forward_policy cannot run without it.
        self.dynamics_model = None
        if args.action_adapter is not None:
            from models.action_adapter.train2 import Dynamics
            self.dynamics_model = Dynamics(action_dim=7, action_num=15, hidden_size=512).to(self.device)
            self.dynamics_model.load_state_dict(torch.load(args.action_adapter, map_location=self.device))

    def normalize_bound(
        self,
        data: np.ndarray,
        data_min: np.ndarray,
        data_max: np.ndarray,
        clip_min: float = -1,
        clip_max: float = 1,
        eps: float = 1e-8,
    ) -> np.ndarray:
        """Affinely map ``data`` from [data_min, data_max] to [-1, 1], then clip.

        ``eps`` guards against division by zero when ``data_min == data_max``.
        """
        ndata = 2 * (data - data_min) / (data_max - data_min + eps) - 1
        return np.clip(ndata, clip_min, clip_max)

    def _encode_frames(self, frames, vae, batch_size=32):
        """Encode uint8 frames of shape (T, H, W, 3) into VAE latents.

        Pixels are mapped from [0, 255] to [-1, 1]; the sampled latents are multiplied
        by ``vae.config.scaling_factor``. Encoding runs in chunks of ``batch_size``.
        """
        x = torch.from_numpy(frames).to(self.dtype).to(self.device)
        x = x.permute(0, 3, 1, 2) / 255.0 * 2 - 1
        latents = []
        with torch.no_grad():
            for i in range(0, len(x), batch_size):
                batch = x[i:i + batch_size]
                latents.append(vae.encode(batch).latent_dist.sample().mul_(vae.config.scaling_factor))
        return torch.cat(latents, dim=0)

    def get_traj_info(self, id, start_idx=0, steps=8, skip=1):
        """Load ground truth for one validation trajectory.

        Args:
            id: trajectory id (name kept for API compatibility); the annotation is read from
                ``{val_dataset_dir}/annotation/val/{id}.json``.
            start_idx: first frame index.
            steps: number of frames to return.
            skip: stride between returned frames. Indices past the end of the trajectory
                are clamped to its last frame.

        Returns:
            ``(car_action, joint_pos, video_dict, video_latent, instruction)``:
            - ``car_action`` and ``joint_pos`` are the annotation 'states' and 'joints' rows
              at the selected frames.
            - ``video_dict`` is a list with one uint8 frame array per video entry.
            - ``video_latent`` holds the matching VAE latents.
            - ``instruction`` is the first annotation text.
        """
        val_dataset_dir = self.args.val_dataset_dir
        annotation_path = f"{val_dataset_dir}/annotation/val/{id}.json"
        with open(annotation_path) as f:
            anno = json.load(f)
        # Prefer the number of actions; fall back to the stored video length.
        try:
            length = len(anno['action'])
        except (KeyError, TypeError):
            length = anno["video_length"]

        frames_ids = np.arange(start_idx, start_idx + steps * skip, skip)
        frames_ids = np.minimum(frames_ids, length - 1).astype(int)
        print("Ground truth frames ids", frames_ids)

        instruction = anno['texts'][0]
        car_action = np.array(anno['states'])[frames_ids]
        joint_pos = np.array(anno['joints'])[frames_ids]

        # Decode only the frames we need (each distinct index once; clamping can repeat the
        # last one) instead of the whole video, then expand back to frames_ids order.
        unique_ids, inverse = np.unique(frames_ids, return_inverse=True)
        vae = self.model.pipeline.vae
        video_dict = []
        video_latent = []
        for view in anno['videos']:
            video_path = f"{val_dataset_dir}/{view['video_path']}"
            vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
            batch = vr.get_batch(unique_ids.tolist())
            # NOTE(review): the `.numpy()` fallback presumably covers decord's torch bridge,
            # whose batches have no `asnumpy`.
            try:
                frames = batch.asnumpy()
            except AttributeError:
                frames = batch.numpy()
            true_video = frames[inverse]
            video_dict.append(true_video)
            video_latent.append(self._encode_frames(true_video, vae))
        return car_action, joint_pos, video_dict, video_latent, instruction

    @staticmethod
    def _decode_latents(latents, vae, chunk_size):
        """Decode scaled latents (B, F, C, h, w) into uint8 frames of shape (B, F, H, W, 3).

        Latents are divided by ``vae.config.scaling_factor`` and decoded ``chunk_size``
        frames at a time; outputs in [-1, 1] are mapped to [0, 255].
        """
        bsz, frame_num = latents.shape[:2]
        flat = latents.flatten(0, 1)
        decoded = []
        for i in range(0, flat.shape[0], chunk_size):
            chunk = flat[i:i + chunk_size] / vae.config.scaling_factor
            decoded.append(vae.decode(chunk, num_frames=chunk.shape[0]).sample)
        frames = torch.cat(decoded, dim=0)
        frames = frames.reshape(bsz, frame_num, *frames.shape[1:])
        frames = (frames / 2.0 + 0.5).clamp(0, 1) * 255
        return frames.detach().to(torch.float32).cpu().numpy().transpose(0, 1, 3, 4, 2).astype(np.uint8)

    def forward_wm(self, action_cond, video_latent_true, video_latent_cond, his_cond=None, text=None):
        """Predict future frames with Ctrl-World and decode them next to the ground truth.

        Args:
            action_cond: un-normalized states, shape (num_history + num_frames, action_dim);
                normalized here with the dataset state_01/state_99 statistics.
            video_latent_true: list of ground-truth latent tensors (stacked on a new dim 0)
                used for the comparison video.
            video_latent_cond: current latent with shape (1, 4, 72, 40), i.e. three view
                tiles stacked along the height.
            his_cond: history latents, passed to the pipeline as ``history``.
            text: optional instruction, encoded together with the actions when given.

        Returns:
            ``(videos_cat, true_video, videos, latents)``:
            - ``true_video`` and ``videos`` are uint8 arrays of shape (views, frames, H, W, 3).
            - ``videos_cat`` puts the ground truth above the prediction and the views side
              by side.
            - ``latents`` are the predicted latents split per view tile.
        """
        args = self.args
        pipeline = self.model.pipeline
        action_cond = self.normalize_bound(action_cond, self.state_p01, self.state_p99, clip_min=-1, clip_max=1)
        action_cond = torch.tensor(action_cond).unsqueeze(0).to(self.device).to(self.dtype)
        assert video_latent_cond.shape[1:] == (4, 72, 40)
        assert action_cond.shape[1:] == (args.num_frames + args.num_history, args.action_dim)

        # Pure inference: no_grad also covers the VAE decodes so no autograd graph is built.
        with torch.no_grad():
            if text is not None:
                text_token = self.model.action_encoder(action_cond, text, self.model.tokenizer, self.model.text_encoder)
            else:
                text_token = self.model.action_encoder(action_cond)
            _, latents = CtrlWorldDiffusionPipeline.__call__(
                pipeline,
                image=video_latent_cond,
                text=text_token,
                width=args.width,
                height=int(args.height * 3),
                num_frames=args.num_frames,
                history=his_cond,
                num_inference_steps=args.num_inference_steps,
                decode_chunk_size=args.decode_chunk_size,
                max_guidance_scale=args.guidance_scale,
                fps=args.fps,
                motion_bucket_id=args.motion_bucket_id,
                mask=None,
                output_type='latent',
                return_dict=False,
                frame_level_cond=True,
            )
            # Split the 3 vertically stacked view tiles into separate batch entries.
            latents = einops.rearrange(latents, 'b f c (m h) (n w) -> (b m n) f c h w', m=3, n=1)
            true_video = self._decode_latents(torch.stack(video_latent_true, dim=0), pipeline.vae, args.decode_chunk_size)
            videos = self._decode_latents(latents, pipeline.vae, args.decode_chunk_size)

        # Ground truth above prediction (height axis), then views side by side (width axis).
        videos_cat = np.concatenate([true_video, videos], axis=-3)
        videos_cat = np.concatenate(list(videos_cat), axis=-2).astype(np.uint8)
        return videos_cat, true_video, videos, latents

    @staticmethod
    def _resize_frame(image, size=(180, 320)):
        """Bilinearly resize a uint8 (H, W, 3) frame to ``size`` and return uint8 numpy.

        The result is rounded to the nearest integer. A plain float->uint8 cast truncates,
        which darkens every pixel by about 0.5 on average.
        """
        x = torch.from_numpy(image).to(torch.uint8).permute(2, 0, 1).unsqueeze(0).float()
        x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)
        return x.squeeze(0).permute(1, 2, 0).round().clamp(0, 255).to(torch.uint8).numpy()

    def forward_policy(self, videos, state, joints, text, time_step=1):
        """Query the pi policy and convert its joint-velocity output to Cartesian states.

        Args:
            videos: per-view uint8 frames. View 1 (must be 192x320x3) is fed as the exterior
                camera and view 2 as the wrist camera.
            state: unused; kept for interface compatibility.
            joints: current robot state. The first 7 entries are joint positions and the
                last entry is the gripper position.
            text: language instruction used as the policy prompt.
            time_step: unused; kept for interface compatibility.

        Returns:
            ``(policy_in_out, joint_pos_skip, state_fk_skip)``:
            - ``policy_in_out`` holds the first ``policy_skip_step * (pred_step - 1)`` rows
              of 'joint_pos', 'joint_vel' and 'state_fk'.
            - ``state_fk_skip`` rows are xyz + xyz-Euler + gripper, taken every
              ``policy_skip_step`` steps, at most ``pred_step`` rows.
            - ``joint_pos_skip`` holds the matching joint positions with the gripper
              appended.

        Raises:
            RuntimeError: if the agent was built without an action adapter.
        """
        dynamics_model = getattr(self, 'dynamics_model', None)
        if dynamics_model is None:
            raise RuntimeError("forward_policy requires an action adapter; set args.action_adapter")

        assert videos[1].shape == (192, 320, 3), "Image 1 shape should be (192, 320, 3), got {}".format(videos[1].shape)
        image1 = self._resize_frame(videos[1])
        image2 = self._resize_frame(videos[2])
        example = {
            "observation/exterior_image_1_left": image_tools.resize_with_pad(image1, 224, 224),
            "observation/wrist_image_left": image_tools.resize_with_pad(image2, 224, 224),
            "observation/joint_position": joints[:7],
            "observation/gripper_position": joints[-1:],
            "prompt": text,
        }
        action_chunk = self.policy.infer(example)["actions"]  # joint velocity (7) + gripper (1) per step

        current_joint = joints[None, :7]
        current_gripper = joints[None, 7:]
        # The adapter consumes 15 steps. pi05 chunks are indexed directly; other variants
        # are padded by repeating their step 9.
        if 'pi05' in self.args.policy_type:
            idx = list(range(15))
        else:
            idx = list(range(10)) + [9] * 5
        joint_vel = action_chunk[:, :7][idx]
        gripper_pos = np.clip(action_chunk[:, 7:][idx], 0, self.args.gripper_max)

        # Integrate velocities into future joint positions, prepend the current state, keep 15.
        joint_pos = dynamics_model(current_joint, joint_vel, None, training=False)
        joint_pos = np.concatenate([current_joint, joint_pos], axis=0)[:15]
        gripper_pos = np.concatenate([current_gripper, gripper_pos], axis=0)[:15]

        # Forward kinematics -> [xyz, xyz-Euler, gripper] per step.
        z_min = self.args.z_min
        state_fk = []
        for q, grip in zip(joint_pos, gripper_pos):
            fk = get_fk_solution(q[:7])
            xyz = fk[:3, 3].copy()
            xyz[2] = np.clip(xyz[2], z_min, None)  # keep the end effector above the table
            euler = R.from_matrix(fk[:3, :3]).as_euler('xyz')
            state_fk.append(np.concatenate([xyz, euler, grip], axis=0))
        state_fk = np.array(state_fk)

        skip = self.args.policy_skip_step
        pred_step = self.args.pred_step
        valid_num = int(skip * (pred_step - 1))
        policy_in_out = {
            'joint_pos': joint_pos[:valid_num],
            'joint_vel': joint_vel[:valid_num],
            'state_fk': state_fk[:valid_num],
        }
        state_fk_skip = state_fk[::skip][:pred_step]
        joint_pos_skip = joint_pos[::skip][:pred_step]
        joint_pos_skip = np.concatenate([joint_pos_skip, state_fk_skip[:, -1:]], axis=-1)  # append gripper
        return policy_in_out, joint_pos_skip, state_fk_skip
if __name__ == "__main__":
    from config_eval import wm_args
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument('--svd_model_path', type=str, default=None)
    parser.add_argument('--clip_model_path', type=str, default=None)
    parser.add_argument('--ckpt_path', type=str, default=None)
    parser.add_argument('--dataset_root_path', type=str, default=None)
    parser.add_argument('--dataset_meta_info_path', type=str, default=None)
    parser.add_argument('--dataset_names', type=str, default=None)
    parser.add_argument('--task_type', type=str, default=None)
    parser.add_argument('--pi_ckpt', type=str, default='/cephfs/shared/llm/openpi/openpi-assets-preview/checkpoints/pi05_droid')
    args_new = parser.parse_args()
    args = wm_args(task_type=args_new.task_type)

    def merge_args(cfg, cli_args):
        """Copy every CLI option whose value is not None onto the config object."""
        for k, v in vars(cli_args).items():
            if v is not None:
                setattr(cfg, k, v)
        return cfg

    args = merge_args(args, args_new)

    # create agent
    Agent = agent(args)
    interact_num = args.interact_num
    pred_step = args.pred_step
    history_idx = args.history_idx

    # Characters dropped from the instruction when building file names; spaces become '_'.
    text_id_table = str.maketrans({' ': '_', ',': None, '.': None, '\'': None, '\"': None})

    # Roll out one trajectory per (val_id, instruction, start_idx) triple.
    for val_id_i, text_i, start_idx_i in zip(args.val_id, args.instruction, args.start_idx):
        # Ground truth spans every interaction step plus 8 extra frames.
        eef_gt, joint_pos_gt, video_dict, video_latents, _ = Agent.get_traj_info(
            val_id_i, start_idx=start_idx_i, steps=int(pred_step * interact_num + 8))
        print("text_i:", text_i, "eef pose at t=0", eef_gt[0], "joint at t=0", joint_pos_gt[0])

        # Seed the history buffers with the first ground-truth observation.
        video_to_save, info_to_save = [], []
        first_latent = torch.cat([v[0] for v in video_latents], dim=1).unsqueeze(0)  # views stacked along height
        assert first_latent.shape == (1, 4, 72, 40), f"Expected first_latent shape (1, 4, 72, 40), got {first_latent.shape}"
        n_seed = Agent.args.num_history * 4
        his_cond = [first_latent] * n_seed
        his_joint = [joint_pos_gt[0:1]] * n_seed
        his_eef = [eef_gt[0:1]] * n_seed
        video_dict_pred = [v[0:1] for v in video_dict]

        for i in range(interact_num):
            # Windows overlap by one frame: the last kept frame of step i starts step i+1.
            start_id = int(i * (pred_step - 1))
            video_latent_true = [v[start_id:start_id + pred_step] for v in video_latents]

            print("################ policy forward ####################")
            current_joint = his_joint[-1][0]  # latest joint state, 1-D
            current_pose = his_eef[-1][0]  # latest Cartesian state, 1-D
            current_obs = [v[-1] for v in video_dict_pred]  # latest frame of each view
            policy_in_out, joint_pos, cartesian_pose = Agent.forward_policy(current_obs, current_pose, current_joint, text=text_i)
            print("cartesian space action", cartesian_pose[0])  # output xyz and gripper for debug
            print("cartesian space action", cartesian_pose[-1])  # output xyz and gripper for debug

            print("################ world model forward ################")
            print(f'task: {text_i}, traj_id: {val_id_i}, interact step: {i}/{interact_num}')
            # Condition on selected history states followed by the planned states.
            action_cond = np.concatenate([his_eef[idx] for idx in history_idx] + [cartesian_pose], axis=0)
            his_latent = torch.cat([his_cond[idx] for idx in history_idx], dim=0).unsqueeze(0)
            current_latent = his_cond[-1]  # (1, 4, 72, 40)
            videos_cat, true_videos, video_dict_pred, predict_latents = Agent.forward_wm(
                action_cond, video_latent_true, current_latent, his_cond=his_latent,
                text=text_i if Agent.args.text_cond else None)

            print("################ record information ################")
            his_joint.append(joint_pos[pred_step - 1][None, :])
            his_eef.append(cartesian_pose[pred_step - 1][None, :])
            his_cond.append(torch.cat([v[pred_step - 1] for v in predict_latents], dim=1).unsqueeze(0))
            video_to_save.append(videos_cat[:pred_step - 1])  # last frame is the next step's first
            info_to_save.append(policy_in_out)

        print("##########################################################################")
        if not video_to_save:
            # interact_num <= 0: nothing was generated; np.concatenate([]) would raise.
            print(f"No interaction steps run for traj {val_id_i}; nothing saved.")
            continue

        # save rollout video and info with parameters
        video = np.concatenate(video_to_save, axis=0)
        text_id = text_i.translate(text_id_table)[:40]
        uuid = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        # NOTE(review): the video name embeds policy_skip_step while the info name below
        # embeds pred_step; confirm which is intended before relying on name pairing.
        filename_video = f"{args.save_dir}/{args.task_name}/video/{args.task_type}_time_{uuid}_traj_{val_id_i}_{start_idx_i}_{args.policy_skip_step}_{text_id}.mp4"
        os.makedirs(os.path.dirname(filename_video), exist_ok=True)
        mediapy.write_video(filename_video, video, fps=4)
        print(f"Saving video to {filename_video}")

        info = {'success': 1, 'start_idx': 0, 'end_idx': video.shape[0] - 1, 'instructions': text_i}
        for key in info_to_save[0]:
            # Flatten the per-step arrays into one JSON list per key.
            info[key] = [row for step in info_to_save for row in step[key].tolist()]
        filename_info = f"{args.save_dir}/{args.task_name}/info/{args.task_type}_time_{uuid}_traj_{val_id_i}_{start_idx_i}_{pred_step}_{text_id}.json"
        os.makedirs(os.path.dirname(filename_info), exist_ok=True)
        with open(filename_info, 'w') as f:
            json.dump(info, f, indent=4)
        print(f"Saving trajectory info to {filename_info}")
        print("##########################################################################")
# Example usage (run from scripts/, or from the repository root):
# CUDA_VISIBLE_DEVICES=0 XLA_PYTHON_CLIENT_MEM_FRACTION=0.4 python rollout_interact_pi_eval.py --task_type pickplace
# CUDA_VISIBLE_DEVICES=0 XLA_PYTHON_CLIENT_MEM_FRACTION=0.4 python scripts/rollout_interact_pi_eval.py --dataset_root_path dataset_example --dataset_meta_info_path dataset_meta_info --dataset_names droid_subset --task_type pickplace