| | |
| | |
| | |
| | |
| | @@ -4,7 +4,9 @@ from typing import Dict, Optional |
| | import gym |
| | import metaworld |
| | import numpy as np |
| | +import pandas as pd |
| | import torch |
| | +from datasets import Dataset |
| | from huggingface_hub import HfApi, repocard, upload_folder |
| | from sample_factory.algo.learning.learner import Learner |
| | from sample_factory.algo.sampling.batched_sampling import preprocess_actions |
| | @@ -12,11 +14,7 @@ from sample_factory.algo.utils.action_distributions import argmax_actions |
| | from sample_factory.algo.utils.env_info import extract_env_info |
| | from sample_factory.algo.utils.make_env import make_env_func_batched |
| | from sample_factory.algo.utils.rl_utils import make_dones, prepare_and_normalize_obs |
| | -from sample_factory.cfg.arguments import ( |
| | - load_from_checkpoint, |
| | - parse_full_cfg, |
| | - parse_sf_args, |
| | -) |
| | +from sample_factory.cfg.arguments import load_from_checkpoint, parse_full_cfg, parse_sf_args |
| | from sample_factory.envs.env_utils import register_env |
| | from sample_factory.model.actor_critic import create_actor_critic |
| | from sample_factory.model.model_utils import get_rnn_size |
| | @@ -165,10 +163,8 @@ def create_dataset(cfg: Config): |
| | # Create dataset |
| | dataset_size = 100_000 |
| | dataset = { |
| | - "observations": np.zeros( |
| | - (dataset_size, *env.observation_space["obs"].shape), dtype=env.observation_space["obs"].dtype |
| | - ), |
| | - "actions": np.zeros((dataset_size, *env.action_space.shape), env.action_space.dtype), |
| | + "observations": np.zeros((dataset_size, *env.observation_space["obs"].shape), dtype=np.float32), |
| | + "actions": np.zeros((dataset_size, *env.action_space.shape), np.float32), |
| | "dones": np.zeros((dataset_size,), bool), |
| | "rewards": np.zeros((dataset_size,), np.float32), |
| | } |
| | @@ -206,6 +202,13 @@ def create_dataset(cfg: Config): |
| | |
| | env.close() |
| | |
| | + # Convert dict of numpy array to pandas dataframe |
| | +# dataset = Dataset.from_dict(dataset) |
| | +# dataset.create_config_id |
| | + # Set the card of the dataset |
| | +# dataset.card = f"""""" |
| | +# dataset.push_to_hub("qgallouedec/prj_gia_dataset_metaworld_assembly_v2_1111_demo") |
| | + |
| | # Save dataset |
| | repo_path = f"{cfg.train_dir}/datasets/{cfg.experiment}" |
| | os.makedirs(repo_path, exist_ok=True) |
| | |
| | |
| | |
| | |
| | @@ -1,34 +1,6 @@ |
| | #!/bin/bash |
| | |
| | ENVS=( |
| | - assembly |
| | - basketball |
| | - bin-picking |
| | - box-close |
| | - button-press-topdown |
| | - button-press-topdown-wall |
| | - button-press |
| | - button-press-wall |
| | - coffee-button |
| | - coffee-pull |
| | - coffee-push |
| | - dial-turn |
| | - disassemble |
| | - door-close |
| | - door-lock |
| | - door-open |
| | - door-unlock |
| | - drawer-close |
| | - drawer-open |
| | - faucet-close |
| | - faucet-open |
| | - hammer |
| | - hand-insert |
| | - handle-press-side |
| | - handle-press |
| | - handle-pull-side |
| | - handle-pull |
| | - lever-pull |
| | peg-insert-side |
| | peg-unplug-side |
| | pick-out-of-hole |
| | @@ -40,19 +12,8 @@ ENVS=( |
| | plate-slide |
| | push-back |
| | push |
| | - push-wall |
| | - reach |
| | - reach-wall |
| | - shelf-place |
| | - soccer |
| | - stick-pull |
| | - stick-push |
| | - sweep-into |
| | - sweep |
| | - window-close |
| | - window-open |
| | ) |
| | |
| | for ENV in "${ENVS[@]}"; do |
| | - python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir |
| | + python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir |
| | done |
| |
|