"""Train a DDQN agent on the StarPilot environment.

Running this script builds the environment via ``make_starpilot``, ensures
the checkpoint directory exists, constructs a ``MetricLogger`` and a
``DDQNAgent``, then calls ``fill_memory`` followed by ``train``.
"""

import os
from pathlib import Path

import torch

from agent import DDQNAgent, MetricLogger
from wrappers import make_starpilot
from train import train, fill_memory


env = make_starpilot()

# Informational only: the flag is printed here and not passed on by this
# script. Device selection presumably happens inside the agent -- confirm.
use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

# Passed straight through to DDQNAgent. None here presumably means "start
# from scratch"; the accepted non-None value (e.g. a checkpoint file path)
# should be verified against DDQNAgent.
checkpoint = None

path = "checkpoints/procgen-starpilot-ddqn"
save_dir = Path(path)

# Create the directory (and any missing parents) in a single call instead of
# checking for existence first, which avoids a check-then-create race.
save_dir.mkdir(parents=True, exist_ok=True)

logger = MetricLogger(save_dir)

print("Training DDQN Agent!")
agent = DDQNAgent(
    state_dim=(1, 64, 64),  # NOTE(review): looks like (channels, height, width) -- confirm
    action_dim=env.action_space.n,
    save_dir=save_dir,
    batch_size=256,
    checkpoint=checkpoint,
    exploration_rate_decay=0.999995,
    exploration_rate_min=0.05,
    training_frequency=1,
    target_network_sync_frequency=200,
    max_memory_size=50000,
    learning_rate=0.0005,
)

# NOTE(review): 300 is forwarded as the third positional argument of
# fill_memory; its unit (steps vs. episodes) is not visible from this file.
fill_memory(agent, env, 300)
train(agent, env, logger)