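# NOTE: this file was extracted from a Hugging Face Spaces page; the original
# page header ("Spaces: Sleeping") was web-page chrome, not part of the script.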
import argparse
import copy
import time

import gym
import numpy as np
import torch
import yaml
from omegaconf import DictConfig

import uvd.utils as U
from uvd.models.preprocessors import get_preprocessor
from uvd.decomp.decomp import embedding_decomp, DEFAULT_DECOMP_KWARGS
from uvd.envs.evaluator.inference_wrapper import InferenceWrapper
from uvd.envs.franka_kitchen.franka_kitchen_base import KitchenBase
# YAML config for the MLP policy, parsed with `yaml.safe_load` and wrapped in a
# `DictConfig` by the script below.  The `???` entries (observation_space,
# action_space, preprocessor) are placeholders that the script overrides with
# keyword arguments when it instantiates `cfg.policy`.
#
# Everything must be nested under `policy:`; with flat (unindented) YAML the
# `policy` key would load as null and `cfg.policy` could not be instantiated.
#
# NOTE(review): the nesting below was reconstructed from a whitespace-stripped
# copy.  The `proprio_*` and `actor_act` keys are assumed to belong to
# `obs_encoder` -- confirm against uvd.models.policy.MLPPolicy / uvd.models.nn.MLP.
# NOTE(review): nested sections use `__target__` (double underscore) while the
# top level uses `_target_`; kept exactly as in the original -- presumably the
# policy builds these sub-modules itself; confirm.
MLP_CFG = """\
policy:
  _target_: uvd.models.policy.MLPPolicy
  observation_space: ???
  action_space: ???
  preprocessor: ???
  obs_encoder:
    __target__: uvd.models.nn.MLP
    hidden_dims: [1024, 512, 256]
    activation: ReLU
    normalization: false
    input_normalization: BatchNorm1d
    input_normalization_full_obs: false
    proprio_output_dim: 512
    proprio_add_layernorm: true
    proprio_activation: Tanh
    proprio_add_noise_eval: false
    actor_act: Tanh
  act_head:
    __target__: uvd.models.distributions.DeterministicHead
"""
# YAML config for the causal GPT policy, parsed with `yaml.safe_load` and
# wrapped in a `DictConfig` by the script below.  The `???` entries
# (observation_space, action_space, preprocessor) are placeholders that the
# script overrides with keyword arguments when it instantiates `cfg.policy`.
#
# Everything must be nested under `policy:`; with flat (unindented) YAML the
# `policy` key would load as null and `cfg.policy` could not be instantiated.
#
# NOTE(review): the nesting below was reconstructed from a whitespace-stripped
# copy.  `use_llama_impl` and `position_embed` are assumed to be part of
# `gpt_config` -- confirm against uvd.models.nn.GPT.
# NOTE(review): nested sections use `__target__` (double underscore) while the
# top level uses `_target_`; kept exactly as in the original -- presumably the
# policy builds these sub-modules itself; confirm.
GPT_CFG = """\
policy:
  _target_: uvd.models.policy.GPTPolicy
  observation_space: ???
  action_space: ???
  preprocessor: ???
  use_kv_cache: true
  max_seq_length: 10
  obs_add: false
  proprio_hidden_dim: 512
  obs_encoder:
    __target__: uvd.models.nn.GPT
    use_wte: true
    gpt_config:
      block_size: 10
      vocab_size: null
      n_embd: 768
      n_layer: 8
      n_head: 8
      dropout: 0.1
      bias: false
      use_llama_impl: true
      position_embed: rotary
  act_head:
    __target__: uvd.models.distributions.DeterministicHead
"""
if __name__ == "__main__":
    # Inference-latency benchmark: builds a (randomly initialised) MLP or GPT
    # policy, rolls it out in the FrankaKitchen environment for a fixed number
    # of steps per episode, and prints the total time of each episode followed
    # by the mean over all episodes.
    parser = argparse.ArgumentParser()
    # `type=str.lower` keeps the original case-insensitive behaviour, while
    # `choices` rejects unknown policies with a proper CLI error instead of an
    # `assert` (which is stripped when Python runs with -O).
    parser.add_argument(
        "--policy", default="gpt", type=str.lower, choices=("mlp", "gpt")
    )
    parser.add_argument("--preprocessor_name", default="vip")
    parser.add_argument("--use_uvd", action="store_true")
    parser.add_argument("--n", type=int, default=100)
    args = parser.parse_args()

    use_gpu = torch.cuda.is_available()
    if not use_gpu:
        print("NO GPU FOUND")
    preprocessor = get_preprocessor(
        args.preprocessor_name, device="cuda" if use_gpu else None
    )

    policy_name = args.policy
    # Only the GPT policy is causal (sequence input, KV cache, position index).
    is_causal = policy_name == "gpt"

    env = KitchenBase(frame_height=224, frame_width=224)
    env = InferenceWrapper(env, dummy_rtn=is_causal)
    env.reset()

    # Observation space handed to the policy: image embedding, 9-dim proprio
    # vector, and a stack of 6 milestone embeddings.
    # NOTE(review): assumes `preprocessor.output_dim` is a tuple (it is
    # concatenated with `(6,)`) -- confirm.
    observation_space = gym.spaces.Dict(
        rgb=gym.spaces.Box(-np.inf, np.inf, preprocessor.output_dim, np.float32),
        proprio=gym.spaces.Box(-1, 1, (9,), np.float32),
        milestones=gym.spaces.Box(
            -np.inf, np.inf, (6,) + preprocessor.output_dim, np.float32
        ),
    )
    action_space = env.action_space

    cfg = DictConfig(yaml.safe_load(MLP_CFG if policy_name == "mlp" else GPT_CFG))
    policy = U.hydra_instantiate(
        cfg.policy,
        observation_space=observation_space,
        action_space=action_space,
        preprocessor=preprocessor,
    )
    policy = policy.to(preprocessor.device).eval()
    U.debug_model_info(policy)
    if is_causal:
        assert policy.causal and policy.use_kv_cache
    # From here on use the preprocessor instance owned by the policy.
    preprocessor = policy.preprocessor

    # Or load FrankaKitchen dummy data
    dummy_data = np.random.random((300, 224, 224, 3)).astype(np.float32)
    emb = preprocessor.process(dummy_data, return_numpy=True)
    if args.use_uvd:
        _, decomp_meta = embedding_decomp(
            embeddings=emb,
            fill_embeddings=False,
            return_intermediate_curves=False,
            **DEFAULT_DECOMP_KWARGS["embed"],
        )
        # Embeddings of the frames selected as milestones by the decomposition.
        milestones = emb[decomp_meta.milestone_indices]
    else:
        # Without UVD the only goal is the embedding of the final frame.
        milestones = emb[-1][None, ...]
    env.milestones = milestones

    MAX_HORIZON = 300
    episode_totals = []
    for _ in range(args.n):
        obs = env.reset()
        if is_causal:
            policy.reset_cache()
        step_times = []
        # Every episode runs exactly MAX_HORIZON steps; the `done` flag returned
        # by the environment is deliberately not used to stop early.
        for step in range(MAX_HORIZON):
            # perf_counter is monotonic and high-resolution, unlike time.time(),
            # which can jump if the system clock is adjusted.
            start = time.perf_counter()
            obs = copy.deepcopy(obs)
            batchify_obs = U.batch_observations([obs], device=policy.device)
            if is_causal:
                # B, T, ...
                cur_milestone = env.current_milestone[None, None, ...]
                for key in batchify_obs:
                    batchify_obs[key] = batchify_obs[key][:, None, ...]
            else:
                # B, ...
                cur_milestone = env.current_milestone[None, ...]
            with torch.no_grad():
                action, obs_embed, _goal_embed = policy(
                    batchify_obs,
                    goal=torch.as_tensor(cur_milestone, device=policy.device),
                    deterministic=True,
                    return_embeddings=True,
                    input_pos=(
                        torch.tensor([step], device=policy.device)
                        if is_causal
                        else None
                    ),
                )
            # The `.cpu()` copies below wait for the device computation to
            # finish, so the measured interval covers the full forward pass.
            env.current_obs_embedding = obs_embed[0].cpu().numpy()
            obs, _reward, _done, _info = env.step(action[0].cpu().numpy())
            step_times.append(time.perf_counter() - start)
        episode_time = np.sum(step_times)
        print(episode_time)
        episode_totals.append(episode_time)
    print(np.mean(episode_totals))