import numpy as np
from apple.envs.discrete_apple import get_apple_env
def test_discrete_apple_phase1():
    """Random rollout on the phase-1 env.

    Checks that each step's reward equals 2*action - 1 (with the final
    reward replaced by 100 when the episode ends in success) and that the
    observation is the constant pair [1, -c] throughout.
    """
    cost = 0.5
    horizon = 30
    env = get_apple_env("phase1", start_x=0, goal_x=10, c=cost, time_limit=horizon)
    obs_log, act_log, rew_log = [], [], []
    done = False
    obs = env.reset()
    for _ in range(horizon):
        # Mostly-forward policy: action 1 with probability 0.8.
        act = np.random.choice([0, 1], p=[0.2, 0.8])
        obs, rew, done, info = env.step(act)
        obs_log.append(obs)
        act_log.append(act)
        rew_log.append(rew)
        if done:
            break
    acts = np.asarray(act_log)
    rews = np.asarray(rew_log)
    obs_arr = np.asarray(obs_log)
    n = len(acts)
    expected_rewards = 2.0 * acts - 1.0
    if info["success"]:
        expected_rewards[-1] = 100
    expected_states = np.stack([np.ones(n), np.full(n, -cost)], axis=1)
    assert np.array_equal(rews, expected_rewards)
    assert np.array_equal(obs_arr, expected_states)
def test_discrete_apple_phase2():
    """Random rollout on the phase-2 env.

    Mirror of the phase-1 test: reward per step is 2*(1 - action) - 1
    (final reward 100 on success) and the observation is the constant
    pair [1, +c] throughout.
    """
    cost = 0.5
    horizon = 30
    env = get_apple_env("phase2", start_x=0, goal_x=10, c=cost, time_limit=horizon)
    obs_log, act_log, rew_log = [], [], []
    done = False
    obs = env.reset()
    for _ in range(horizon):
        # Mostly action 0 here (probability 0.8), the reverse of phase 1.
        act = np.random.choice([0, 1], p=[0.8, 0.2])
        obs, rew, done, info = env.step(act)
        obs_log.append(obs)
        act_log.append(act)
        rew_log.append(rew)
        if done:
            break
    acts = np.asarray(act_log)
    rews = np.asarray(rew_log)
    obs_arr = np.asarray(obs_log)
    n = len(acts)
    expected_rewards = 2.0 * (1 - acts) - 1.0
    if info["success"]:
        expected_rewards[-1] = 100
    expected_states = np.stack([np.ones(n), np.full(n, cost)], axis=1)
    assert np.array_equal(rews, expected_rewards)
    assert np.array_equal(obs_arr, expected_states)
def test_discrete_apple_full():
    """Scripted rollout on the full env: 10 steps of action 1, then 10 of
    action 0.

    Expects reward 1 on every step except the last (which pays 100), and
    observations whose second component is -c for the first half of the
    episode and +c for the second half.
    """
    cost = 0.5
    n_steps = 20
    expected_rewards = np.ones(n_steps)
    expected_rewards[-1] = 100
    second_component = np.concatenate([np.full(10, -cost), np.full(10, cost)])
    expected_states = np.stack([np.ones(n_steps), second_component], axis=1)
    env = get_apple_env("full", start_x=0, goal_x=10, c=cost, time_limit=30)
    obs_log, act_log, rew_log = [], [], []
    obs = env.reset()
    # Same action sequence as before, issued from a single loop.
    for scripted_action in [1] * 10 + [0] * 10:
        obs, rew, done, info = env.step(scripted_action)
        obs_log.append(obs)
        act_log.append(scripted_action)
        rew_log.append(rew)
    assert np.array_equal(np.asarray(rew_log), expected_rewards)
    assert np.array_equal(np.asarray(obs_log), expected_states)