Spaces:
Runtime error
Runtime error
Thomas Simonini commited on
Commit ·
197c70e
1
Parent(s): 167b87e
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,6 +10,8 @@ from stable_baselines3.common.vec_env import VecFrameStack
|
|
| 10 |
|
| 11 |
from stable_baselines3.common.env_util import make_atari_env
|
| 12 |
|
|
|
|
|
|
|
| 13 |
# Loading functions were taken from Edward Beeching code
|
| 14 |
def load_env(env_name):
|
| 15 |
env = make_atari_env(env_name, n_envs=1)
|
|
@@ -32,19 +34,23 @@ def load_model(env_name):
|
|
| 32 |
|
| 33 |
return model
|
| 34 |
|
| 35 |
-
def replay(env_name, time_sleep):
|
| 36 |
env = load_env(env_name)
|
| 37 |
model = load_model(env_name)
|
| 38 |
#for i in range(num_episodes):
|
| 39 |
obs = env.reset()
|
| 40 |
done = False
|
|
|
|
| 41 |
while not done:
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
demo = gr.Interface(
|
| 50 |
replay,
|
|
@@ -53,6 +59,7 @@ demo = gr.Interface(
|
|
| 53 |
"SeaquestNoFrameskip-v4",
|
| 54 |
"QbertNoFrameskip-v4",
|
| 55 |
]),
|
|
|
|
| 56 |
gr.Slider(0.01, 1, value=0.05),
|
| 57 |
#gr.Slider(1, 20, value=5)
|
| 58 |
],
|
|
|
|
| 10 |
|
| 11 |
from stable_baselines3.common.env_util import make_atari_env
|
| 12 |
|
| 13 |
+
max_steps = 5000 # Let's try with 5000 steps.
|
| 14 |
+
|
| 15 |
# Loading functions were taken from Edward Beeching code
|
| 16 |
def load_env(env_name):
|
| 17 |
env = make_atari_env(env_name, n_envs=1)
|
|
|
|
| 34 |
|
| 35 |
return model
|
| 36 |
|
| 37 |
+
def replay(env_name, max_steps, time_sleep):
|
| 38 |
env = load_env(env_name)
|
| 39 |
model = load_model(env_name)
|
| 40 |
#for i in range(num_episodes):
|
| 41 |
obs = env.reset()
|
| 42 |
done = False
|
| 43 |
+
i = 0
|
| 44 |
while not done:
|
| 45 |
+
i++
|
| 46 |
+
if i < max_steps:
|
| 47 |
+
frame = env.render(mode="rgb_array")
|
| 48 |
+
action, _states = model.predict(obs)
|
| 49 |
+
obs, reward, done, info = env.step([action])
|
| 50 |
+
time.sleep(time_sleep)
|
| 51 |
+
yield frame
|
| 52 |
+
else:
|
| 53 |
+
break
|
| 54 |
|
| 55 |
demo = gr.Interface(
|
| 56 |
replay,
|
|
|
|
| 59 |
"SeaquestNoFrameskip-v4",
|
| 60 |
"QbertNoFrameskip-v4",
|
| 61 |
]),
|
| 62 |
+
gr.Slider(100, 10000, value=500),
|
| 63 |
gr.Slider(0.01, 1, value=0.05),
|
| 64 |
#gr.Slider(1, 20, value=5)
|
| 65 |
],
|