Spaces:
Sleeping
Sleeping
File size: 5,066 Bytes
c456c14 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 | import argparse
import copy
import time
import gym
import numpy as np
import torch
import yaml
from omegaconf import DictConfig
import uvd.utils as U
from uvd.models.preprocessors import get_preprocessor
from uvd.decomp.decomp import embedding_decomp, DEFAULT_DECOMP_KWARGS
from uvd.envs.evaluator.inference_wrapper import InferenceWrapper
from uvd.envs.franka_kitchen.franka_kitchen_base import KitchenBase
# Inline YAML config for the MLP policy. The __main__ block selects it when
# --policy is "mlp", parses it with yaml.safe_load, wraps it in an OmegaConf
# DictConfig and hands `cfg.policy` to U.hydra_instantiate together with
# observation_space, action_space and preprocessor -- the three fields left as
# `???` (OmegaConf's marker for a mandatory value that is still missing).
# The encoder/head sections use `__target__` instead of Hydra's `_target_`,
# presumably so they are not instantiated eagerly and the policy class builds
# them itself -- TODO confirm against uvd.models.policy.MLPPolicy.
# NOTE(review): the YAML nesting indentation appears to have been lost in this
# copy. With every key at column 0, `policy:` parses to None (so `cfg.policy`
# is empty) and the two top-level `__target__` keys overwrite each other
# (PyYAML keeps the last one). Restore the upstream indentation before running.
MLP_CFG = """\
policy:
_target_: uvd.models.policy.MLPPolicy
observation_space: ???
action_space: ???
preprocessor: ???
obs_encoder:
__target__: uvd.models.nn.MLP
hidden_dims: [1024, 512, 256]
activation: ReLU
normalization: false
input_normalization: BatchNorm1d
input_normalization_full_obs: false
proprio_output_dim: 512
proprio_add_layernorm: true
proprio_activation: Tanh
proprio_add_noise_eval: false
actor_act: Tanh
act_head:
__target__: uvd.models.distributions.DeterministicHead
"""
# Inline YAML config for the causal GPT policy (the default, --policy "gpt").
# Like MLP_CFG it is loaded with yaml.safe_load into a DictConfig, and the
# `???` fields (observation_space, action_space, preprocessor) are supplied
# as keyword arguments to U.hydra_instantiate. The __main__ block asserts that
# the resulting policy is causal and has `use_kv_cache` enabled, matching
# `use_kv_cache: true` here.
# NOTE(review): max_seq_length / block_size are 10, but the benchmark rollout
# passes input_pos values up to horizon-1 (299 by default) -- confirm that the
# GPT KV cache handles positions beyond its block size.
# NOTE(review): the YAML nesting indentation appears to have been lost in this
# copy. With every key at column 0, `policy:`, `obs_encoder:` and
# `gpt_config:` parse to None and the two top-level `__target__` keys
# overwrite each other. Restore the upstream indentation before running.
GPT_CFG = """\
policy:
_target_: uvd.models.policy.GPTPolicy
observation_space: ???
action_space: ???
preprocessor: ???
use_kv_cache: true
max_seq_length: 10
obs_add: false
proprio_hidden_dim: 512
obs_encoder:
__target__: uvd.models.nn.GPT
use_wte: true
gpt_config:
block_size: 10
vocab_size: null
n_embd: 768
n_layer: 8
n_head: 8
dropout: 0.1
bias: false
use_llama_impl: true
position_embed: rotary
act_head:
__target__: uvd.models.distributions.DeterministicHead
"""
def _parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line options for the policy inference benchmark.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Namespace with ``policy`` (lower-cased, ``"mlp"`` or ``"gpt"``),
        ``preprocessor_name``, ``use_uvd``, ``n`` and ``horizon``.

    Exits with a usage error (status 2) on an unknown policy name or a
    non-positive ``--n`` / ``--horizon``.
    """
    parser = argparse.ArgumentParser(
        description="Time per-episode policy inference in FrankaKitchen."
    )
    # `type` is applied before the `choices` check, so "GPT" / "Mlp" are still
    # accepted; an unknown name now gets a usage error instead of an
    # AssertionError (asserts are stripped under `python -O`).
    parser.add_argument(
        "--policy", default="gpt", type=str.lower, choices=["mlp", "gpt"]
    )
    parser.add_argument("--preprocessor_name", default="vip")
    parser.add_argument(
        "--use_uvd",
        action="store_true",
        help="condition on UVD-decomposed milestones instead of only the "
        "final frame's embedding",
    )
    parser.add_argument(
        "--n", type=int, default=100, help="number of timed episodes"
    )
    parser.add_argument(
        "--horizon", type=int, default=300, help="environment steps per episode"
    )
    args = parser.parse_args(argv)
    # Zero episodes would make the final mean NaN (plus a RuntimeWarning).
    if args.n < 1:
        parser.error("--n must be >= 1")
    if args.horizon < 1:
        parser.error("--horizon must be >= 1")
    return args


def _build_policy(policy_name: str, preprocessor, action_space):
    """Instantiate the MLP or GPT policy from its inline YAML config.

    The observation space declares ``rgb`` with the preprocessor's output
    shape, ``proprio`` as a 9-dim box in [-1, 1], and ``milestones`` as a
    ``(6, *output_dim)`` box.

    Returns:
        The policy, moved to the preprocessor's device and set to eval mode.
    """
    observation_space = gym.spaces.Dict(
        rgb=gym.spaces.Box(-np.inf, np.inf, preprocessor.output_dim, np.float32),
        proprio=gym.spaces.Box(-1, 1, (9,), np.float32),
        milestones=gym.spaces.Box(
            -np.inf, np.inf, (6,) + preprocessor.output_dim, np.float32
        ),
    )
    raw_cfg = yaml.safe_load(MLP_CFG if policy_name == "mlp" else GPT_CFG)
    cfg = DictConfig(raw_cfg)
    policy = U.hydra_instantiate(
        cfg.policy,
        observation_space=observation_space,
        action_space=action_space,
        preprocessor=preprocessor,
    )
    return policy.to(preprocessor.device).eval()


def _compute_milestones(preprocessor, use_uvd: bool):
    """Embed a dummy 300-frame trajectory and pick the goal milestones.

    With ``use_uvd`` the milestones are the embeddings at the indices chosen
    by ``embedding_decomp``; otherwise only the last frame's embedding is
    used (kept with a leading axis of length 1).
    """
    # Random frames stand in for a real trajectory; load FrankaKitchen demo
    # frames here to benchmark on real data.
    dummy_data = np.random.random((300, 224, 224, 3)).astype(np.float32)
    emb = preprocessor.process(dummy_data, return_numpy=True)
    if not use_uvd:
        return emb[-1][None, ...]
    _, decomp_meta = embedding_decomp(
        embeddings=emb,
        fill_embeddings=False,
        return_intermediate_curves=False,
        **DEFAULT_DECOMP_KWARGS["embed"],
    )
    # Select the embeddings at the decomposed milestone indices.
    return emb[decomp_meta.milestone_indices]


def _run_episode(env, policy, is_causal: bool, horizon: int) -> float:
    """Roll out one episode and return the summed per-step time in seconds.

    Each step's timing covers copying/batching the observation, the policy
    forward pass and ``env.step``. ``done`` is ignored, so every episode runs
    exactly ``horizon`` steps.
    NOTE(review): confirm the environment tolerates stepping past termination.
    """
    obs = env.reset()
    if is_causal:
        policy.reset_cache()
    step_times = []
    for step in range(horizon):
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # whose wall-clock value can jump (NTP, DST) mid-measurement.
        start = time.perf_counter()
        obs = copy.deepcopy(obs)
        batch_obs = U.batch_observations([obs], device=policy.device)
        if is_causal:
            # The causal policy takes (B, T, ...): add a length-1 time axis.
            cur_milestone = env.current_milestone[None, None, ...]
            for key in batch_obs:
                batch_obs[key] = batch_obs[key][:, None, ...]
        else:
            # (B, ...)
            cur_milestone = env.current_milestone[None, ...]
        with torch.no_grad():
            action, obs_embed, _goal_embed = policy(
                batch_obs,
                goal=torch.as_tensor(cur_milestone, device=policy.device),
                deterministic=True,
                return_embeddings=True,
                # Absolute position of this step for the GPT KV cache.
                input_pos=torch.tensor([step], device=policy.device)
                if is_causal
                else None,
            )
        env.current_obs_embedding = obs_embed[0].cpu().numpy()
        obs, _reward, _done, _info = env.step(action[0].cpu().numpy())
        step_times.append(time.perf_counter() - start)
    return float(np.sum(step_times))


if __name__ == "__main__":
    args = _parse_args()
    use_gpu = torch.cuda.is_available()
    if not use_gpu:
        print("NO GPU FOUND")
    preprocessor = get_preprocessor(
        args.preprocessor_name, device="cuda" if use_gpu else None
    )
    # Only the GPT policy is causal: it gets a time axis and a KV cache.
    is_causal = args.policy == "gpt"
    env = InferenceWrapper(
        KitchenBase(frame_height=224, frame_width=224), dummy_rtn=is_causal
    )
    env.reset()
    policy = _build_policy(args.policy, preprocessor, env.action_space)
    U.debug_model_info(policy)
    if is_causal:
        # Internal sanity check on the instantiated config, not user input.
        assert policy.causal and policy.use_kv_cache
    # From here on, embed with the policy's own preprocessor instance.
    preprocessor = policy.preprocessor
    env.milestones = _compute_milestones(preprocessor, args.use_uvd)

    episode_times = []
    for _ in range(args.n):
        episode_time = _run_episode(env, policy, is_causal, args.horizon)
        print(episode_time)
        episode_times.append(episode_time)
    print(np.mean(episode_times))
|