| from tqdm import trange |
|
|
def fill_memory(agent, env, num_episodes=500):
    """Play ``num_episodes`` complete episodes, caching every transition.

    The agent picks each action via ``agent.act``; no learning step is
    performed here. Each observed transition is forwarded to ``agent.cache``
    (presumably a replay memory -- confirm against the agent class).

    Args:
        agent: object exposing ``act(state)`` and
            ``cache(state, next_state, action, reward, done)``.
        env: environment whose ``reset()`` returns the initial state and
            whose ``step(action)`` returns a 4-tuple
            ``(next_state, reward, done, info)``.
        num_episodes: how many episodes to play (default 500).
    """
    print("Filling up memory....")
    for _episode in trange(num_episodes):
        observation = env.reset()
        # Run until the environment reports the episode is over; the body
        # always executes at least once.
        while True:
            chosen_action = agent.act(observation)
            successor, payoff, finished, _info = env.step(chosen_action)
            agent.cache(observation, successor, chosen_action, payoff, finished)
            observation = successor
            if finished:
                break
|
|
|
|
def train(agent, env, logger, *, episodes=5000, record_every=20):
    """Run the online training loop for ``episodes`` episodes.

    On every environment step the agent chooses an action for the current
    state, the environment is advanced, the transition is passed to
    ``agent.cache`` and ``agent.learn()`` is called. Its ``(q, loss)`` result
    is reported to ``logger.log_step`` together with the step reward.
    ``logger.log_episode`` is called after every episode. ``logger.record``
    is called on every episode index that is a multiple of ``record_every``,
    including episode 0.

    Args:
        agent: object providing ``act(state)``,
            ``cache(state, next_state, action, reward, done)``,
            ``learn() -> (q, loss)`` and the attributes ``exploration_rate``
            and ``curr_step``.
        env: environment whose ``reset()`` returns the initial state and
            whose ``step(action)`` returns ``(next_state, reward, done, info)``.
        logger: object providing ``log_step(reward, loss, q)``,
            ``log_episode(episode)`` and
            ``record(episode=..., epsilon=..., step=...)``.
        episodes: number of episodes to run (default 5000, the value that
            was previously hard-coded).
        record_every: interval, in episodes, between ``logger.record`` calls
            (default 20, the value that was previously hard-coded).

    Raises:
        ValueError: if ``record_every`` is less than 1, since it is used as a
            modulus.
    """
    if record_every < 1:
        raise ValueError(f"record_every must be >= 1, got {record_every}")

    for e in range(episodes):
        state = env.reset()
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, done, _info = env.step(action)
            agent.cache(state, next_state, action, reward, done)

            # A learning update is attempted after every single step.
            q, loss = agent.learn()
            logger.log_step(reward, loss, q)

            state = next_state

        logger.log_episode(e)

        if e % record_every == 0:
            logger.record(
                episode=e,
                epsilon=agent.exploration_rate,
                step=agent.curr_step,
            )
|
|