Spaces:
Running
Running
| from dizoo.classic_control.cartpole.offline_data.collect_dqn_data_config import main_config, create_config | |
| from ding.entry import serial_pipeline_offline | |
| import os | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from ding.config import read_config, compile_config | |
| from ding.utils.data import create_dataset | |
def train(args):
    """Build the offline dataset from the imported config pair and inspect its first episode.

    The config pair (``main_config``, ``create_config``) is compiled with
    ``compile_config``, the dataset is created from the compiled config, and
    the dataset length plus the first action of the first episode are printed.
    The actions of the first episode are then stacked into a single tensor.

    Args:
        args: Parsed command-line namespace. Only ``args.seed`` is read; it is
            forwarded to ``compile_config``.
    """
    from copy import deepcopy

    # Work on private copies: the policy type is rewritten below, and mutating
    # the imported module-level configs would append '_command' again on every
    # further call of this function within the same process.
    cfg, create_cfg = deepcopy([main_config, create_config])
    create_cfg.policy.type = create_cfg.policy.type + '_command'
    cfg = compile_config(cfg, seed=args.seed, auto=True, create_cfg=create_cfg)

    # Dataset
    dataset = create_dataset(cfg)
    print(len(dataset))

    # Fetch the first collected episode once instead of re-indexing the dataset
    # for every step that is read from it.
    first_episode = dataset[0]
    print(first_episode[0]['action'])

    # Stacked actions of the first collected episode. Index access is kept on
    # purpose so no assumption is made about how the episode container iterates.
    # The result is not used further; it only demonstrates how to assemble it.
    episode_action = torch.stack(
        [first_episode[i]['action'] for i in range(len(first_episode))], dim=0
    )

    # Alternative usages kept for reference:
    # dataloader = DataLoader(dataset, cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x)
    # for i, train_data in enumerate(dataloader):
    #     print(i, train_data)
    # serial_pipeline_offline([main_config, create_config], seed=args.seed)
# Script entry point: read the random seed from the command line and run train().
if __name__ == "__main__":
    from argparse import ArgumentParser

    cli = ArgumentParser()
    # --seed / -s: integer seed, 0 when not given.
    cli.add_argument('--seed', '-s', type=int, default=0)
    args = cli.parse_args()
    train(args)