New Author Name committed on
Commit
4b714e2
·
1 Parent(s): 8264cee
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .editorconfig +16 -0
  2. .gitignore +18 -0
  3. .pre-commit-config.yaml +29 -0
  4. .streamlit/config.toml +34 -0
  5. README.md +8 -7
  6. app.py +606 -0
  7. apple/envs/discrete_apple.py +384 -0
  8. apple/envs/img/apple.png +0 -0
  9. apple/envs/img/elf_down.png +0 -0
  10. apple/envs/img/elf_left.png +0 -0
  11. apple/envs/img/elf_right.png +0 -0
  12. apple/envs/img/g1.png +0 -0
  13. apple/envs/img/g2.png +0 -0
  14. apple/envs/img/g3.png +0 -0
  15. apple/envs/img/grass.jpg +0 -0
  16. apple/envs/img/home.png +0 -0
  17. apple/envs/img/home00.png +0 -0
  18. apple/envs/img/home01.png +0 -0
  19. apple/envs/img/home02.png +0 -0
  20. apple/envs/img/home10.png +0 -0
  21. apple/envs/img/home11.png +0 -0
  22. apple/envs/img/home12.png +0 -0
  23. apple/envs/img/home2.png +0 -0
  24. apple/envs/img/home2_with_apples.png +0 -0
  25. apple/envs/img/home_grass.png +0 -0
  26. apple/envs/img/part_grass.png +0 -0
  27. apple/envs/img/stool.png +0 -0
  28. apple/envs/img/textures.jpg +0 -0
  29. apple/envs/img/white.png +0 -0
  30. apple/evaluation/render_episode.py +22 -0
  31. apple/logger.py +384 -0
  32. apple/models/categorical_policy.py +46 -0
  33. apple/training/reinforce_trainer.py +74 -0
  34. apple/training/trainer.py +77 -0
  35. apple/utils.py +25 -0
  36. apple/wrappers.py +35 -0
  37. assets/apple_env.png +0 -0
  38. assets/example_rollout.mp4 +0 -0
  39. assets/generate_example_rollout.py +30 -0
  40. input_args.py +17 -0
  41. mrunner_exps/behavioral_cloning.py +51 -0
  42. mrunner_exps/reinforce.py +53 -0
  43. mrunner_exps/utils.py +10 -0
  44. mrunner_run.py +18 -0
  45. mrunner_runs/local.sh +3 -0
  46. mrunner_runs/remote.sh +9 -0
  47. pyproject.toml +92 -0
  48. requirements.txt +7 -0
  49. run.py +87 -0
  50. setup.cfg +10 -0
.editorconfig ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ root = true
2
+
3
+ [*]
4
+ charset = utf-8
5
+ end_of_line = lf
6
+ insert_final_newline = true
7
+ trim_trailing_whitespace = true
8
+
9
+ [*]
10
+ indent_size = 4
11
+ indent_style = space
12
+ max_line_length = 120
13
+ tab_width = 8
14
+
15
+ [*.{yml,yaml}]
16
+ indent_size = 2
.gitignore ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /.venv/
2
+ /.python-version
3
+
4
+ /build/
5
+ /dist/
6
+ /site/
7
+ /test-results.xml
8
+ /.coverage
9
+ /coverage.xml
10
+
11
+ /.hypothesis/
12
+ __pycache__/
13
+ *.egg-info/
14
+
15
+ /.vscode/
16
+ wandb
17
+ logs
18
+ *.pt
.pre-commit-config.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: local
3
+ hooks:
4
+ - id: autoflake
5
+ name: autoflake
6
+ entry: autoflake
7
+ args: [--in-place, --remove-all-unused-imports, --remove-unused-variables]
8
+ language: system
9
+ types_or: [python, pyi]
10
+
11
+ - id: isort
12
+ name: isort
13
+ entry: isort
14
+ args: [--quiet]
15
+ language: system
16
+ types_or: [python, pyi]
17
+
18
+ - id: black
19
+ name: black
20
+ entry: black
21
+ args: [--quiet]
22
+ language: system
23
+ types_or: [python, pyi]
24
+
25
+ - id: flake8
26
+ name: flake8
27
+ entry: pflake8
28
+ language: system
29
+ types_or: [python, pyi]
.streamlit/config.toml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [theme]
2
+ #theme primary
3
+ base="dark"
4
+ # Primary accent color for interactive elements.
5
+ primaryColor="f63366"
6
+
7
+ # Background color for the main content area.
8
+ #backgroundColor =
9
+
10
+ # Background color used for the sidebar and most interactive widgets.
11
+ #secondaryBackgroundColor ='grey'
12
+
13
+ # Color used for almost all text.
14
+ #textColor ='blue'
15
+
16
+ # Font family for all text in the app, except code blocks. One of "sans serif", "serif", or "monospace".
17
+ # Default: "sans serif"
18
+ font = "sans serif"
19
+
20
+ # [logger]
21
+ # level='info'
22
+ # messageFormat = "%(message)s"
23
+ #messageFormat="%(asctime)s %(message)s"
24
+
25
+ [global]
26
+
27
+ # By default, Streamlit checks if the Python watchdog module is available and, if not, prints a warning asking for you to install it. The watchdog module is not required, but highly recommended. It improves Streamlit's ability to detect changes to files in your filesystem.
28
+ # If you'd like to turn off this warning, set this to True.
29
+ # Default: false
30
+ disableWatchdogWarning = false
31
+
32
+ # If True, will show a warning when you run a Streamlit-enabled script via "python my_script.py".
33
+ # Default: true
34
+ showWarningOnDirectExecution = false
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
- title: Apple
3
- emoji: 📚
4
- colorFrom: red
5
- colorTo: yellow
6
  sdk: streamlit
7
- sdk_version: 1.17.0
8
  app_file: app.py
9
- pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Apple Retrieval
3
+ emoji: 🍎
4
+ colorFrom: yellow
5
+ colorTo: red
6
  sdk: streamlit
7
+ sdk_version: 1.15.2
8
  app_file: app.py
9
+ pinned: true
10
+ fullWidth: true
11
  ---
12
 
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
+ from functools import partial
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import streamlit as st
8
+ import torch
9
+
10
+ from apple.envs.discrete_apple import get_apple_env
11
+ from apple.evaluation.render_episode import render_episode
12
+ from apple.logger import EpochLogger
13
+ from apple.models.categorical_policy import CategoricalPolicy
14
+ from apple.training.trainer import Trainer
15
+
16
+ QUEUE_SIZE = 1000
17
+
18
+
19
+ def init_training(
20
+ c: float = 1.0,
21
+ start_x: float = 0.0,
22
+ goal_x: float = 50.0,
23
+ time_limit: int = 200,
24
+ lr: float = 1e-3,
25
+ weight1: float = 0.0,
26
+ weight2: float = 0.0,
27
+ pretrain: str = "phase1",
28
+ finetune: str = "full",
29
+ bias_in_state: bool = True,
30
+ position_in_state: bool = False,
31
+ apple_in_state: bool = True,
32
+ ):
33
+ st.session_state.logger = EpochLogger(verbose=False)
34
+
35
+ env_kwargs = dict(
36
+ start_x=start_x,
37
+ goal_x=goal_x,
38
+ c=c,
39
+ time_limit=time_limit,
40
+ bias_in_state=bias_in_state,
41
+ position_in_state=position_in_state,
42
+ apple_in_state=apple_in_state,
43
+ )
44
+
45
+ st.session_state.env_full = get_apple_env("full", render_mode="rgb_array", **env_kwargs)
46
+ st.session_state.env_phase1 = get_apple_env(pretrain, **env_kwargs)
47
+ st.session_state.env_phase2 = get_apple_env(finetune, **env_kwargs)
48
+ st.session_state.test_envs = [get_apple_env(task, **env_kwargs) for task in ["full", "phase1", "phase2"]]
49
+
50
+ st.session_state.model = CategoricalPolicy(
51
+ st.session_state.env_phase1.observation_space.shape[0], 1, weight1, weight2
52
+ )
53
+ st.session_state.optim = torch.optim.SGD(st.session_state.model.parameters(), lr=lr)
54
+ st.session_state.trainer = Trainer(st.session_state.model, st.session_state.optim, st.session_state.logger)
55
+ st.session_state.train_it = 0
56
+ st.session_state.draw_it = 0
57
+ st.session_state.total_steps = 0
58
+ st.session_state.data = []
59
+
60
+ st.session_state.obs1 = st.session_state.env_phase1.reset()
61
+ st.session_state.obs2 = st.session_state.env_phase2.reset()
62
+
63
+
64
+ def init_reset():
65
+ st.session_state.rollout_iterator = iter(
66
+ partial(render_episode, st.session_state.env_full, st.session_state.model)()
67
+ )
68
+ st.session_state.last_image = dict(
69
+ x=0,
70
+ obs=st.session_state.env_full.reset(),
71
+ action=None,
72
+ reward=0,
73
+ done=False,
74
+ episode_len=0,
75
+ episode_return=0,
76
+ pixel_array=st.session_state.env_full.unwrapped.render(),
77
+ )
78
+
79
+
80
+ def select_preset():
81
+ if st.session_state.pick_preset == 0:
82
+ preset_finetuning_interference()
83
+ elif st.session_state.pick_preset == 1:
84
+ preset_train_full_from_scratch()
85
+ elif st.session_state.pick_preset == 2:
86
+ preset_task_interference()
87
+ elif st.session_state.pick_preset == 3:
88
+ preset_without_task_interference()
89
+
90
+
91
+ def preset_task_interference():
92
+ st.session_state.pick_c = 0.5
93
+ st.session_state.pick_goal_x = 20
94
+ st.session_state.pick_time_limit = 50
95
+ st.session_state.pick_lr = 0.05
96
+ st.session_state.pick_phase1task = "phase2"
97
+ st.session_state.pick_phase2task = "phase1"
98
+ st.session_state.pick_weight1 = 0.0
99
+ st.session_state.pick_weight2 = 0.0
100
+ st.session_state.pick_phase1steps = 500
101
+ st.session_state.pick_phase2steps = 500
102
+ need_reset()
103
+
104
+
105
+ def preset_finetuning_interference():
106
+ st.session_state.pick_c = 0.5
107
+ st.session_state.pick_goal_x = 20
108
+ st.session_state.pick_time_limit = 50
109
+ st.session_state.pick_lr = 0.05
110
+ st.session_state.pick_phase1task = "phase2"
111
+ st.session_state.pick_phase2task = "full"
112
+ st.session_state.pick_weight1 = 0.0
113
+ st.session_state.pick_weight2 = 0.0
114
+ st.session_state.pick_phase1steps = 500
115
+ st.session_state.pick_phase2steps = 2000
116
+ need_reset()
117
+
118
+
119
+ def preset_without_task_interference():
120
+ st.session_state.pick_c = 1.0
121
+ st.session_state.pick_goal_x = 20
122
+ st.session_state.pick_time_limit = 50
123
+ st.session_state.pick_lr = 0.05
124
+ st.session_state.pick_phase1task = "phase2"
125
+ st.session_state.pick_phase2task = "phase1"
126
+ st.session_state.pick_weight1 = 0.0
127
+ st.session_state.pick_weight2 = 0.0
128
+ st.session_state.pick_phase1steps = 500
129
+ st.session_state.pick_phase2steps = 500
130
+ need_reset()
131
+
132
+
133
+ def preset_train_full_from_scratch():
134
+ st.session_state.pick_c = 0.5
135
+ st.session_state.pick_goal_x = 20
136
+ st.session_state.pick_time_limit = 50
137
+ st.session_state.pick_lr = 0.05
138
+ st.session_state.pick_phase1task = "phase2"
139
+ st.session_state.pick_phase2task = "full"
140
+ st.session_state.pick_weight1 = 0.0
141
+ st.session_state.pick_weight2 = 0.0
142
+ st.session_state.pick_phase1steps = 0
143
+ st.session_state.pick_phase2steps = 2000
144
+ need_reset()
145
+
146
+
147
+ def empty_queue(q: asyncio.Queue):
148
+ for _ in range(q.qsize()):
149
+ # Depending on your program, you may want to
150
+ # catch QueueEmpty
151
+ q.get_nowait()
152
+ q.task_done()
153
+
154
+
155
+ def reset(**kwargs):
156
+ init_training(**kwargs)
157
+ init_reset()
158
+ st.session_state.play = False
159
+ st.session_state.step = False
160
+ st.session_state.render = False
161
+ st.session_state.done = False
162
+ empty_queue(st.session_state.queue)
163
+ empty_queue(st.session_state.queue_render)
164
+ st.session_state.play_pause = False
165
+ st.session_state.need_reset = False
166
+
167
+
168
+ def render_start():
169
+ st.session_state.render = True
170
+ st.session_state.done = False
171
+ init_reset()
172
+
173
+
174
+ def need_reset():
175
+ st.session_state.need_reset = True
176
+ st.session_state.play = False
177
+ st.session_state.render = False
178
+
179
+
180
+ def play_pause():
181
+ if st.session_state.play:
182
+ st.session_state.play = False
183
+ st.session_state.play_pause = False
184
+ else:
185
+ st.session_state.play = True
186
+ st.session_state.play_pause = True
187
+
188
+
189
+ def step():
190
+ st.session_state.step = True
191
+
192
+
193
+ def plot(data_placeholder):
194
+ df = pd.DataFrame(st.session_state.data)
195
+ if not df.empty:
196
+ df.set_index("total_env_steps", inplace=True)
197
+ container = data_placeholder.container()
198
+ c1, c2, c3 = container.columns(3)
199
+
200
+ def view_df(names):
201
+ rdf = df.loc[:, df.columns.isin(names)]
202
+ if rdf.empty:
203
+ return pd.DataFrame([{name: 0 for name in names}])
204
+ else:
205
+ return rdf
206
+
207
+ c1.write("phase1/success_rate")
208
+ c1.line_chart(view_df(["phase1/success"]))
209
+ c2.write("phase2/success_rate")
210
+ c2.line_chart(view_df(["phase2/success"]))
211
+ c3.write("full/success_rate")
212
+ c3.line_chart(view_df(["full/success"]))
213
+
214
+ c1.write("train/loss")
215
+ c1.line_chart(view_df(["train/loss"]))
216
+ c2.write("weight0")
217
+ c2.line_chart(view_df(["weight0"]))
218
+ c3.write("weight1")
219
+ c3.line_chart(view_df(["weight1"]))
220
+
221
+
222
+ async def draw(data_placeholder, queue, delay, steps, plotfrequency):
223
+ while (st.session_state.play or st.session_state.step) and st.session_state.draw_it < steps:
224
+ _ = await asyncio.sleep(delay)
225
+ new_data = await queue.get()
226
+ st.session_state.draw_it += 1
227
+ if st.session_state.draw_it % plotfrequency == 0:
228
+ st.session_state.data.append(new_data)
229
+ plot(data_placeholder)
230
+ st.session_state.step = False
231
+ queue.task_done()
232
+
233
+
234
+ async def train(queue, delay, steps, obs, env, num_eval_episodes, plotfrequency):
235
+ while (st.session_state.play or st.session_state.step) and st.session_state.train_it < steps:
236
+ _ = await asyncio.sleep(delay)
237
+ st.session_state.train_it += 1
238
+ st.session_state.total_steps += 1
239
+
240
+ output = st.session_state.model(obs)
241
+ action, log_prob = st.session_state.model.sample(output)
242
+ st.session_state.trainer.update(
243
+ env, output, st.session_state.model, st.session_state.optim, st.session_state.logger
244
+ )
245
+
246
+ obs, reward, done, info = env.step(action)
247
+
248
+ if done:
249
+ obs = env.reset()
250
+
251
+ if st.session_state.train_it % plotfrequency == 0:
252
+ st.session_state.trainer.test_agent(
253
+ st.session_state.model, st.session_state.logger, st.session_state.test_envs, num_eval_episodes
254
+ )
255
+ data = st.session_state.trainer.log(
256
+ st.session_state.logger, st.session_state.train_it, st.session_state.model
257
+ )
258
+ else:
259
+ data = 0
260
+
261
+ _ = await queue.put(data)
262
+
263
+
264
+ async def produce_images(queue, delay):
265
+ while st.session_state.render and not st.session_state.done:
266
+ _ = await asyncio.sleep(delay)
267
+ data = next(st.session_state.rollout_iterator)
268
+ st.session_state.done = data["done"]
269
+ _ = await queue.put(data)
270
+
271
+
272
+ def show_image(data, image_placeholder):
273
+ c = image_placeholder.container()
274
+ c.image(
275
+ data["pixel_array"],
276
+ )
277
+ c.text(
278
+ f"agent position: {data['x']} \ntimestep: {data['episode_len']} \nepisode return: {data['episode_return']} \n"
279
+ )
280
+
281
+
282
+ async def consume_images(image_placeholder, queue, delay):
283
+ while st.session_state.render and not st.session_state.done:
284
+ _ = await asyncio.sleep(delay)
285
+ data = await queue.get()
286
+ st.session_state.last_image = data
287
+ show_image(data, image_placeholder)
288
+ queue.task_done()
289
+
290
+
291
+ async def run_app(
292
+ data_placeholder,
293
+ queue,
294
+ produce_delay,
295
+ consume_delay,
296
+ phase1steps,
297
+ phase2steps,
298
+ plotfrequency,
299
+ num_eval_episodes,
300
+ image_placeholder,
301
+ queue_render,
302
+ render_produce_delay,
303
+ render_consume_delay,
304
+ ):
305
+ _ = await asyncio.gather(
306
+ produce_images(queue_render, render_produce_delay),
307
+ consume_images(image_placeholder, queue_render, render_consume_delay),
308
+ )
309
+
310
+ st.session_state.render = False
311
+ st.session_state.done = False
312
+
313
+ empty_queue(queue_render)
314
+
315
+ _ = await asyncio.gather(
316
+ train(
317
+ queue,
318
+ produce_delay,
319
+ phase1steps,
320
+ st.session_state.obs1,
321
+ st.session_state.env_phase1,
322
+ num_eval_episodes,
323
+ plotfrequency,
324
+ ),
325
+ draw(data_placeholder, queue, consume_delay, phase1steps, plotfrequency),
326
+ )
327
+
328
+ _ = await asyncio.gather(
329
+ train(
330
+ queue,
331
+ produce_delay,
332
+ phase1steps + phase2steps,
333
+ st.session_state.obs2,
334
+ st.session_state.env_phase2,
335
+ num_eval_episodes,
336
+ plotfrequency,
337
+ ),
338
+ draw(data_placeholder, queue, consume_delay, phase1steps + phase2steps, plotfrequency),
339
+ )
340
+
341
+
342
+ ##### ACTUAL APP
343
+
344
+ if __name__ == "__main__":
345
+ st.set_page_config(
346
+ layout="wide",
347
+ initial_sidebar_state="auto",
348
+ page_title="Apple Retrieval",
349
+ page_icon=None,
350
+ )
351
+ st.title("ON THE ROLE OF FORGETTING IN FINE-TUNING REINFORCEMENT LEARNING MODELS")
352
+ st.header("Toy example of forgetting: AppleRetrieval")
353
+
354
+ col1, col2, col3 = st.sidebar.columns(3)
355
+
356
+ options = (
357
+ "phase2 full interference",
358
+ "full from scratch",
359
+ "phase2 phase1 forgetting",
360
+ "phase2 phase1 optimal solution",
361
+ )
362
+ st.sidebar.selectbox(
363
+ "parameter presets",
364
+ range(len(options)),
365
+ index=0,
366
+ format_func=lambda x: options[x],
367
+ on_change=select_preset,
368
+ key="pick_preset",
369
+ )
370
+
371
+ pick_container = st.sidebar.container()
372
+ c = pick_container.number_input("c", value=0.5, on_change=need_reset, key="pick_c")
373
+ goal_x = pick_container.number_input("distance to apple", value=20, on_change=need_reset, key="pick_goal_x")
374
+ time_limit = pick_container.number_input("time limit", value=50, on_change=need_reset, key="pick_time_limit")
375
+ lr = pick_container.selectbox(
376
+ "Learning rate",
377
+ np.array([0.00001, 0.0001, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0, 10.0]),
378
+ 5,
379
+ on_change=need_reset,
380
+ key="pick_lr",
381
+ )
382
+ phase1task = pick_container.selectbox(
383
+ "Pretraining task", ("full", "phase1", "phase2"), 2, on_change=need_reset, key="pick_phase1task"
384
+ )
385
+ phase2task = pick_container.selectbox(
386
+ "Finetuning task", ("full", "phase1", "phase2"), 0, on_change=need_reset, key="pick_phase2task"
387
+ )
388
+ # weight1 = pick_container.number_input("init weight1", value=0, on_change=need_reset, key="pick_weight1")
389
+ # weight2 = pick_container.number_input("init weight2", value=0, on_change=need_reset, key="pick_weight2")
390
+
391
+ if "event_loop" not in st.session_state:
392
+ st.session_state.loop = asyncio.new_event_loop()
393
+ asyncio.set_event_loop(st.session_state.loop)
394
+
395
+ if "queue" not in st.session_state:
396
+ st.session_state.queue = asyncio.Queue(QUEUE_SIZE)
397
+ if "queue_render" not in st.session_state:
398
+ st.session_state.queue_render = asyncio.Queue(QUEUE_SIZE)
399
+ if "play" not in st.session_state:
400
+ st.session_state.play = False
401
+ if "step" not in st.session_state:
402
+ st.session_state.step = False
403
+ if "render" not in st.session_state:
404
+ st.session_state.render = False
405
+
406
+ reset_button = partial(
407
+ reset,
408
+ c=c,
409
+ start_x=0,
410
+ goal_x=goal_x,
411
+ time_limit=time_limit,
412
+ lr=lr,
413
+ # weight1=weight1,
414
+ # weight2=weight2,
415
+ # weight1=0,
416
+ # weight2=10, # solves the environment
417
+ pretrain=phase1task,
418
+ finetune=phase2task,
419
+ )
420
+ col1.button("Reset", on_click=reset_button, type="primary")
421
+
422
+ if "logger" not in st.session_state or st.session_state.need_reset:
423
+ reset_button()
424
+
425
+ myKey = "play_pause"
426
+ if myKey not in st.session_state:
427
+ st.session_state[myKey] = False
428
+
429
+ if st.session_state[myKey]:
430
+ myBtn = col2.button("Pause", on_click=play_pause, type="primary")
431
+ else:
432
+ myBtn = col2.button("Play", on_click=play_pause, type="primary")
433
+
434
+ col3.button("Step", on_click=step, type="primary")
435
+
436
+ st.header("Summary")
437
+ st.write(
438
+ """
439
+ Run training on the "phase2 full interference" setting to see an example of forgetting in fine-tuning RL models.
440
+ A model is pre-trained on a part of the environment called Phase 2 for 500 steps,
441
+ and then it is fine-tuned on the whole environment for another 2000 steps.
442
+ However, it forgets how to perform on Phase 2 during fine-tuning before it even gets there.
443
+ We highlight this as an important problem in fine-tuning RL models.
444
+ We invite you to play around with the hyperparameters and find out more about the forgetting phenomenon.
445
+ """
446
+ )
447
+
448
+ st.markdown(
449
+ """
450
+ By comparing the results of these four presets, you can gain a deeper understanding of how different training configurations can impact the final performance of a neural network.
451
+ * `phase2 full interference` - already mentioned setting where we pretrain on phase 2 and finetune on whole environment.
452
+ * `full from scratch` - "standard" train on whole environment for 2000 steps for comparison.
453
+ * `phase2 phase1 forgetting` - a setting where we pretrain on phase 2 and finetune on phase 1. Here gradient interference between tasks causes network to forget how to solve phase 2.
454
+ * `phase2 phase1 optimal solution` - a setting where we pretrain on phase 2 and finetune on phase 1. Here gradient interference between tasks "moves" network weights to the optimal solution for both tasks.
455
+ """
456
+ )
457
+
458
+ data_placeholder = st.empty()
459
+ render_placeholder = st.empty()
460
+
461
+ phase1steps = st.sidebar.slider("Pretraining Steps", 0, 2000, 500, 10, key="pick_phase1steps")
462
+ phase2steps = st.sidebar.slider("Finetuning Steps", 0, 2000, 2000, 10, key="pick_phase2steps")
463
+ plotfrequency = st.sidebar.number_input("Log frequency ", min_value=1, value=10, step=1)
464
+ num_eval_episodes = st.sidebar.number_input("Number of evaluation episodes", min_value=1, value=10, step=1)
465
+
466
+ if not st.session_state.play and len(st.session_state.data) > 0:
467
+ plot(data_placeholder)
468
+
469
+ st.write(
470
+ """
471
+ Figure 1: Plots (a), (b), (c) show the performance of the agent on three different variants of the environment.
472
+ Phase1 (a) the agent's goal is to reach the apple.
473
+ Phase2 (b) the agent's has to go back home by returning to x = 0.
474
+ Full (c) combination of both phases of the environment.
475
+ (d) shows the training loss.
476
+ The state vector has only two values, so we can view the network's weights (e), (f).
477
+ """
478
+ )
479
+
480
+ st.header("Visualize agent behavior")
481
+ st.button("Rollout one episode", on_click=render_start, type="primary")
482
+
483
+ c1, c2 = st.columns(2)
484
+ image_placeholder = c1.empty()
485
+
486
+ render_produce_delay = 1 / 3
487
+ render_consume_delay = 1 / 3
488
+
489
+ st.header("About the environment")
490
+
491
+ st.write(
492
+ """This study forms a preliminary investigation into the problem of forgetting in fine-tuning RL models.
493
+ We show that fine-tuning a pre-trained model on compositional RL problems might result in a rapid
494
+ deterioration of the performance of the pre-trained model if the relevant data is not available at the
495
+ beginning of the training.
496
+ This phenomenon is known as catastrophic forgetting.
497
+ In this demo we show how it can occur in simple toyish situations, but it might also occur in more realistic problems (e.g., SAC with MLPs on a compositional robotic environment).
498
+ In our [WIP] paper we showed that applying CL methods significantly limits forgetting and allows for efficient transfer.
499
+ """
500
+ )
501
+
502
+ st.write(
503
+ """The AppleRetrieval environment is a toy example to demonstrate the issue of interference in reinforcement learning.
504
+ The goal of the agent is to retrieve an apple from position x = M and return home to x = 0 within a set number of steps T.
505
+ The state of the agent is represented by a vector s, which has two elements: s = [1, -c] in phase 1 and s = [1, c] in phase 2.
506
+ The first element is a constant and the second element represents the information about the current phase.
507
+ The optimal policy is to go right in phase 1 and go left in phase 2.
508
+ The cause of interference can be identified by checking the reliance of the policy on either s1 or s2.
509
+ If the model mostly relies on s2, interference will be limited, but if it relies on s1,
510
+ interference will occur as its value is the same in both phases.
511
+ The magnitude of s1 and s2 can be adjusted by changing the c parameter to guide the model towards focusing on either one.
512
+ This simple toy environment shows that the issue of interference can be fundamental to reinforcement learning."""
513
+ )
514
+
515
+ cc1, cc2 = st.columns([1, 2])
516
+ cc1.image("assets/apple_env.png")
517
+
518
+ st.header("Training algorithm")
519
+ st.write(
520
+ """For practical reasons (time and stability) we don't use REINFORCE in this DEMO to illustrate the training dynamics of the environment.
521
+ Instead, we train the model by minimizing the negative log likelihood of target actions (move right or left).
522
+ We train the model in each step of the environment and sample actions from Categorical distribution taken from model's output.
523
+ """
524
+ )
525
+
526
+ st.markdown(
527
+ """
528
+ Pseudocode of the train loop:
529
+ ```python
530
+ obs = env.reset()
531
+ for timestep in range(steps):
532
+ probs = model(obs)
533
+ dist = Categorical(probs)
534
+ action = dist.sample()
535
+ target_action = env.get_target_action()
536
+ loss = -dist.log_prob(target_action)
537
+
538
+ optim.zero_grad()
539
+ loss.backward()
540
+ optim.step()
541
+
542
+ obs, reward, done, info = env.step(action)
543
+ if done:
544
+ obs = env.reset()
545
+ """
546
+ )
547
+
548
+ st.header("What do all the hyperparameters mean?")
549
+
550
+ st.markdown(
551
+ """
552
+ * **parameter presets** - few sets of pre-defined hyperparameters that can be used as a starting point for a specific experiment.
553
+ * **c** - second element of the state vector, decreasing this value will result in stronger forgetting.
554
+ * **distance to apple** - refers to how far the agent needs to travel to the right before it encounters the apple.
555
+ * **time limit** - maximum amount of timesteps that the agent is allowed to interact with the environment before episode ends.
556
+ * **learning rate** - hyperparameter that determines the step size taken by the learning algorithm in each iteration.
557
+ * **pretraining and finetuning task** - define the environment the model will be trained on each stage.
558
+ * **pretraining and finetuning steps** - define how long model will be trained on each stage.
559
+ * **init weights** - initial values assigned to the model's parameters before training begins.
560
+ * **log frequency** - refers to frequency of logging the metrics and evaluating the agent.
561
+ * **number of evaluation episodes** - number of rollouts during testing of the agent.
562
+ """
563
+ )
564
+
565
+ st.header("Limitations & Conclusions")
566
+
567
+ st.write(
568
+ """
569
+ At the same time, this study, due to its preliminary nature, has numerous limitations which we
570
+ hope to address in future work. We only considered a fairly strict formulation of the forgetting
571
+ scenario where we assumed that the pre-trained model works perfectly on tasks that appear later in
572
+ the fine-tuning. In practice, one should also consider the case when even though there are differences
573
+ between the pre-training and fine-tuning tasks, transfer is still possible.
574
+ """
575
+ )
576
+
577
+ st.write(
578
+ """
579
+ At the same time, even given
580
+ these limitations, we see forgetting as an important problem to be solved and hope that addressing
581
+ these issues in the future might help with building and fine-tuning better foundation models in RL.
582
+ """
583
+ )
584
+
585
+ produce_delay = 1 / 1000
586
+ consume_delay = 1 / 1000
587
+
588
+ plot(data_placeholder)
589
+ show_image(st.session_state.last_image, image_placeholder)
590
+
591
+ asyncio.run(
592
+ run_app(
593
+ data_placeholder,
594
+ st.session_state.queue,
595
+ produce_delay,
596
+ consume_delay,
597
+ phase1steps,
598
+ phase2steps,
599
+ plotfrequency,
600
+ num_eval_episodes,
601
+ image_placeholder,
602
+ st.session_state.queue_render,
603
+ render_produce_delay,
604
+ render_consume_delay,
605
+ )
606
+ )
apple/envs/discrete_apple.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import closing
2
+ from io import StringIO
3
+ from os import path
4
+ from typing import Optional
5
+
6
+ import gym
7
+ import gym.spaces
8
+ import numpy as np
9
+
10
+ from gym.error import DependencyNotInstalled
11
+ from gym.spaces import Box
12
+ from gym.utils import colorize
13
+ from gym.wrappers import TimeLimit
14
+
15
+ from apple.wrappers import SuccessCounter
16
+
17
+
18
def get_apple_env(task, time_limit=20, **kwargs):
    """Build a fully wrapped Apple environment for the requested task variant.

    Args:
        task: One of ``"full"`` (both phases), ``"phase1"`` (walk to the
            apple) or ``"phase2"`` (walk back home). The single-phase tasks
            receive half the step budget, since they cover half the journey.
        time_limit: Episode step budget for the ``"full"`` task.
        **kwargs: Forwarded verbatim to the environment constructor.

    Returns:
        The environment wrapped in ``TimeLimit`` and ``SuccessCounter``,
        with a ``name`` attribute set to the task string.

    Raises:
        NotImplementedError: If ``task`` is not one of the known variants.
    """
    if task == "full":
        base_env = AppleEnv(**kwargs)
        budget = time_limit
    elif task == "phase1":
        base_env = ApplePhase0Env(**kwargs)
        budget = time_limit // 2
    elif task == "phase2":
        base_env = ApplePhase1Env(**kwargs)
        budget = time_limit // 2
    else:
        raise NotImplementedError
    wrapped = SuccessCounter(TimeLimit(base_env, budget))
    wrapped.name = task
    return wrapped
33
+
34
+
35
class AppleEnv(gym.Env):
    """One-dimensional "apple retrieval" environment.

    An episode has two phases: in phase 0 the agent walks from its home at
    ``start_x`` towards the apple at ``goal_x``; once the apple is reached
    the environment switches to phase 1 and the agent has to walk back home.
    Actions are ``Discrete(2)``: 1 moves right by ``delta``, 0 moves left.
    Each step yields ``+reward_value`` when the action matches the phase's
    target direction and ``-reward_value`` otherwise; finishing the
    configured terminal phase yields ``success_value`` and ends the episode.
    """

    metadata = {
        "render_modes": ["human", "ansi", "rgb_array"],
        "render_fps": 4,
    }

    def __init__(
        self,
        start_x: int,
        goal_x: int,
        c: float,
        reward_value: float = 1.0,
        success_value: float = 1.0,
        bias_in_state: bool = True,
        position_in_state: bool = False,
        apple_in_state: bool = True,
        render_mode: Optional[str] = None,
    ):
        """Create the environment.

        Args:
            start_x: Home position (episode start and phase-1 target).
            goal_x: Apple position (phase-0 target).
            c: Magnitude of the phase-indicator observation feature.
            reward_value: Per-step reward magnitude (sign depends on whether
                the action matches the phase's target direction).
            success_value: Reward granted when the episode succeeds.
            bias_in_state: Include a constant 1 in the observation.
            position_in_state: Include the agent's x position in the observation.
            apple_in_state: Include the signed phase feature (-c / +c).
            render_mode: One of ``metadata["render_modes"]`` or None.
        """
        self.start_x = start_x
        self.goal_x = goal_x
        self.c = c
        self.reward_value = reward_value
        self.success_value = success_value

        # Mutable episode state; subclasses override the phase/terminal
        # configuration after calling this constructor.
        self.x = start_x
        self.phase = 0
        self.delta = 1
        self.change_phase = True
        self.init_pos = self.start_x
        self.success_when_finish_phase = 1

        self.bias_in_state = bias_in_state
        self.position_in_state = position_in_state
        self.apple_in_state = apple_in_state

        # Observation bounds are taken from an example phase-0 observation;
        # the apple feature is the only entry that changes sign across
        # phases, so its bound is mirrored to span [-c, c].
        example_state = self.state()
        mult = np.ones_like(example_state)
        if self.apple_in_state:
            mult[-1] *= -1
        self.observation_space = Box(low=example_state, high=example_state * mult)
        self.action_space = gym.spaces.Discrete(2)
        self.timestep = 0

        self.render_mode = render_mode
        # Number of grid cells visible around the agent when rendering.
        self.scope_size = 15

        # GUI canvas is (rows, cols); each cell is rendered as a 64x64 tile.
        self.nrow, self.ncol = nrow, ncol = self.gui_canvas().shape
        self.window_size = (64 * ncol, 64 * nrow)
        self.cell_size = (
            self.window_size[0] // self.ncol,
            self.window_size[1] // self.nrow,
        )
        self.window_surface = None

        # Lazily loaded pygame resources (filled in by _render_gui).
        self.clock = None
        self.empty_img = None
        self.ground_img = None
        self.underground_img = None
        self.elf_images = None
        self.home_images = None
        self.goal_img = None
        self.start_img = None
        self.stool_img = None

    def reset(self):
        """Reset position (and phase, unless pinned) and return the initial observation."""
        self.x = self.init_pos
        if self.change_phase:
            self.phase = 0

        self.timestep = 0
        state = self.state()
        self.lastaction = None
        return state

    def validate_action(self, action):
        """Assert that *action* belongs to the action space."""
        err_msg = f"{action!r} ({type(action)}) invalid"
        assert self.action_space.contains(action), err_msg
        # NOTE(review): self.state is a bound method and is never None, so
        # this assertion cannot fire as written — confirm the intent.
        assert self.state is not None, "Call reset before using step method."

    def move(self, action):
        """Shift the agent by +delta (action 1) or -delta (action 0)."""
        delta = self.delta if action == 1 else -self.delta
        self.x += delta

    def state(self):
        """Assemble the observation vector.

        Entries (in order, each optional): constant bias 1, position ``x``,
        and the phase feature, which is ``-c`` in phase 0 and ``+c`` in
        phase 1.
        """
        assert self.bias_in_state or self.position_in_state or self.apple_in_state

        if self.phase == 0:
            c = -self.c
        elif self.phase == 1:
            c = self.c
        else:
            raise NotImplementedError

        state = []
        if self.bias_in_state:
            state.append(1)
        if self.position_in_state:
            state.append(self.x)
        if self.apple_in_state:
            state.append(c)

        return np.array(state)

    def desc(self):
        """Return a ``scope_size``-wide byte strip centered on the agent.

        'S' marks home, 'G' the apple (phase 0), 'D' the empty stool
        (phase 1) and '.' everything else.
        """
        s = self.scope_size // 2

        desc = list("." * self.scope_size)
        desc = np.asarray(desc, dtype="c")

        start_relative_position = self.start_x - self.x + s
        if 0 <= start_relative_position <= self.scope_size - 1:
            desc[start_relative_position] = "S"

        goal_relative_position = self.goal_x - self.x + s
        if 0 <= goal_relative_position <= self.scope_size - 1 and self.phase == 0:
            desc[goal_relative_position] = "G"

        if 0 <= goal_relative_position <= self.scope_size - 1 and self.phase == 1:
            desc[goal_relative_position] = "D"

        return desc

    def text_canvas(self):
        """Two-row character canvas for ansi rendering.

        Row 0 holds the object strip (one cell every 3 columns); row 1
        holds x-axis tick labels at multiples of 5.
        """
        desc = self.desc()
        canvas = np.ones((2, len(desc) * 3 + 2), dtype="c")
        canvas[:] = "\x20"

        for i, d in zip(range(2, len(canvas[0]), 3), desc):
            canvas[0][i] = d

        # Absolute x coordinates of the visible cells, used for tick labels.
        axis = np.arange(len(desc)) - len(desc) // 2 + self.x
        for i, d in zip(range(2, len(canvas[0]), 3), axis):
            if d % 5 == 0:
                s = str(d)

                # Center multi-digit labels under their cell.
                c = len(s) // 2
                for j, char in zip(range(len(s)), reversed(s)):
                    canvas[1][i - j + c] = char
        return canvas

    def gui_canvas(self):
        """Four-row byte canvas for the GUI: sky ('~') x2, object strip, ground ('#')."""
        desc = self.desc()
        upper_canvas = np.ones(len(desc), dtype="c")
        upper_canvas[:] = "~"
        lower_canvas = np.ones(len(desc), dtype="c")
        lower_canvas[:] = "#"
        canvas = np.stack([upper_canvas, upper_canvas, desc, lower_canvas])

        return canvas

    def reward(self, action):
        """+reward_value when *action* matches the phase target, -reward_value otherwise."""
        return self.reward_value if action == self.get_target_action() else -self.reward_value

    def get_target_action(self):
        """Return 1 (move right) during phase 0 and 0 (move left) during phase 1."""
        return 1 if self.phase == 0 else 0

    def step(self, action):
        """Advance one step; return ``(state, reward, done, info)``.

        The per-step reward is computed against the *pre-move* phase; when
        the apple is reached in phase 0 (and ``change_phase`` is set) the
        phase flips to 1. The episode ends with ``success_value`` when the
        phase configured by ``success_when_finish_phase`` completes.
        """
        self.timestep += 1
        self.validate_action(action)

        done = False
        info = {"success": False}

        reward = self.reward(action)

        self.move(action)

        finish_phase0 = self.phase == 0 and self.x >= self.goal_x
        finish_phase1 = self.phase == 1 and self.x <= self.start_x

        if self.change_phase:
            if finish_phase0:
                self.phase = 1

        if (self.success_when_finish_phase == 0 and finish_phase0) or (
            self.success_when_finish_phase == 1 and finish_phase1
        ):
            done = True
            info["success"] = True
            reward = self.success_value

        state = self.state()

        self.lastaction = action
        if self.render_mode == "human":
            self.render()

        return state, reward, done, info

    def render(self):
        """Dispatch to the ansi or pygame renderer per ``render_mode``."""
        if self.render_mode is None:
            assert self.spec is not None
            gym.logger.warn(
                "You are calling render method without specifying any render mode. "
                "You can specify the render_mode at initialization, "
                f'e.g. gym.make("{self.spec.id}", render_mode="rgb_array")'
            )
            return

        if self.render_mode == "ansi":
            return self._render_text()
        else:  # self.render_mode in {"human", "rgb_array"}:
            return self._render_gui(self.render_mode)

    def _render_gui(self, mode):
        """Draw the scene with pygame; returns an RGB array in "rgb_array" mode."""
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled("pygame is not installed, run `pip install pygame`") from e

        if self.window_surface is None:
            pygame.init()

            if mode == "human":
                pygame.display.init()
                pygame.display.set_caption("Apple Retrieval")
                self.window_surface = pygame.display.set_mode(self.window_size)
            elif mode == "rgb_array":
                self.window_surface = pygame.Surface(self.window_size)

        assert self.window_surface is not None, "Something went wrong with pygame. This should never happen."

        # Load and cache tile images on first use.
        if self.clock is None:
            self.clock = pygame.time.Clock()
        if self.empty_img is None:
            file_name = path.join(path.dirname(__file__), "img/white.png")
            self.empty_img = pygame.transform.scale(pygame.image.load(file_name), self.cell_size)
        if self.ground_img is None:
            file_name = path.join(path.dirname(__file__), "img/part_grass.png")
            self.ground_img = pygame.transform.scale(pygame.image.load(file_name), self.cell_size)
        if self.underground_img is None:
            file_name = path.join(path.dirname(__file__), "img/g2.png")
            self.underground_img = pygame.transform.scale(pygame.image.load(file_name), self.cell_size)
        if self.goal_img is None:
            file_name = path.join(path.dirname(__file__), "img/apple.png")
            self.goal_img = pygame.transform.scale(pygame.image.load(file_name), self.cell_size)
        if self.stool_img is None:
            file_name = path.join(path.dirname(__file__), "img/stool.png")
            self.stool_img = pygame.transform.scale(pygame.image.load(file_name), self.cell_size)
        if self.start_img is None:
            # 2x3 grid of tiles making up the home sprite (row-major order).
            homes = [
                path.join(path.dirname(__file__), "img/home00.png"),
                path.join(path.dirname(__file__), "img/home01.png"),
                path.join(path.dirname(__file__), "img/home02.png"),
                path.join(path.dirname(__file__), "img/home10.png"),
                path.join(path.dirname(__file__), "img/home11.png"),
                path.join(path.dirname(__file__), "img/home12.png"),
            ]
            self.home_images = [pygame.transform.scale(pygame.image.load(f_name), self.cell_size) for f_name in homes]
        if self.elf_images is None:
            # Indexed by last action: 0 = left, 1 = right, 2 = idle/down.
            elfs = [
                path.join(path.dirname(__file__), "img/elf_left.png"),
                path.join(path.dirname(__file__), "img/elf_right.png"),
                path.join(path.dirname(__file__), "img/elf_down.png"),
            ]
            self.elf_images = [pygame.transform.scale(pygame.image.load(f_name), self.cell_size) for f_name in elfs]

        desc = self.gui_canvas().tolist()

        # `cache` defers the blit of the last home tile so that it is drawn
        # after the neighbouring ground tile and is not painted over.
        # TODO(review): confirm this is the intended layering behaviour.
        cache = []
        assert isinstance(desc, list), f"desc should be a list or an array, got {desc}"
        for y in range(self.nrow):
            for x in range(self.ncol):
                pos = (x * self.cell_size[0], y * self.cell_size[1])

                self.window_surface.blit(self.empty_img, pos)

                if desc[y][x] == b"~":
                    self.window_surface.blit(self.empty_img, pos)
                elif desc[y][x] == b"#":
                    self.window_surface.blit(self.underground_img, pos)
                else:
                    self.window_surface.blit(self.ground_img, pos)
                    # if y == self.nrow - 1:

                    if len(cache) > 0:
                        cache_img, cache_pos = cache.pop()
                        self.window_surface.blit(cache_img, cache_pos)

                    if desc[y][x] == b"G":
                        self.window_surface.blit(self.stool_img, pos)
                        self.window_surface.blit(self.goal_img, pos)
                    elif desc[y][x] == b"D":
                        self.window_surface.blit(self.stool_img, pos)
                    elif desc[y][x] == b"S":
                        for h in range(len(self.home_images)):
                            i = h // 3
                            j = h % 3

                            home_img = self.home_images[i * 3 + j]
                            home_pos = ((x - 1 + j) * self.cell_size[0], (y - 1 + i) * self.cell_size[1])
                            if h == len(self.home_images) - 1:
                                cache.append((home_img, home_pos))
                            else:
                                self.window_surface.blit(home_img, home_pos)

        # paint the elf
        # bot_row, bot_col = self.s // self.ncol, self.s % self.ncol
        # The agent is always drawn at the center of the strip row.
        bot_col = self.scope_size // 2
        bot_row = 2
        cell_rect = (bot_col * self.cell_size[0], bot_row * self.cell_size[1])
        last_action = self.lastaction if self.lastaction is not None else 2
        elf_img = self.elf_images[last_action]

        self.window_surface.blit(elf_img, cell_rect)

        # font = pygame.font.SysFont(None, 20)
        # img = font.render(f"agent position = {self.x}", True, "black")
        # self.window_surface.blit(img, (5, 5))
        # img = font.render(f"timestep = {self.timestep}", True, "black")
        # self.window_surface.blit(img, (5, 25))

        if mode == "human":
            pygame.event.pump()
            pygame.display.update()
            self.clock.tick(self.metadata["render_fps"])
        elif mode == "rgb_array":
            return np.transpose(np.array(pygame.surfarray.pixels3d(self.window_surface)), axes=(1, 0, 2))

    def _render_text(self):
        """Return the ansi rendering as a string, agent cell highlighted in red."""
        desc = self.text_canvas()
        outfile = StringIO()

        # Agent column in the text canvas (3 columns per cell, offset 2).
        row, col = 0, (self.scope_size // 2) * 3 + 2
        desc = [[c.decode("utf-8") for c in line] for line in desc]
        desc[row][col] = colorize(desc[row][col], "red", highlight=True)

        outfile.write("\n")
        outfile.write("\n".join("".join(line) for line in desc) + "\n")

        with closing(outfile):
            return outfile.getvalue()
+ return outfile.getvalue()
367
+
368
+
369
class ApplePhase0Env(AppleEnv):
    """Stand-alone first phase: start at home, succeed upon reaching the apple."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Pin the environment to phase 0 and make reaching the apple terminal.
        self.phase = 0
        self.change_phase = False
        self.success_when_finish_phase = 0
        self.init_pos = self.start_x
376
+
377
+
378
class ApplePhase1Env(AppleEnv):
    """Stand-alone second phase: start at the apple, succeed upon reaching home."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Pin the environment to phase 1 and spawn the agent at the apple.
        self.phase = 1
        self.change_phase = False
        self.success_when_finish_phase = 1
        self.init_pos = self.goal_x
384
+ self.success_when_finish_phase = 1
apple/envs/img/apple.png ADDED
apple/envs/img/elf_down.png ADDED
apple/envs/img/elf_left.png ADDED
apple/envs/img/elf_right.png ADDED
apple/envs/img/g1.png ADDED
apple/envs/img/g2.png ADDED
apple/envs/img/g3.png ADDED
apple/envs/img/grass.jpg ADDED
apple/envs/img/home.png ADDED
apple/envs/img/home00.png ADDED
apple/envs/img/home01.png ADDED
apple/envs/img/home02.png ADDED
apple/envs/img/home10.png ADDED
apple/envs/img/home11.png ADDED
apple/envs/img/home12.png ADDED
apple/envs/img/home2.png ADDED
apple/envs/img/home2_with_apples.png ADDED
apple/envs/img/home_grass.png ADDED
apple/envs/img/part_grass.png ADDED
apple/envs/img/stool.png ADDED
apple/envs/img/textures.jpg ADDED
apple/envs/img/white.png ADDED
apple/evaluation/render_episode.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def render_episode(env, model):
    """Roll out one episode of *model* in *env*, yielding per-step diagnostics.

    Each yielded dict contains the pre-step observation, the chosen action,
    the step reward, running episode statistics, the agent position and a
    rendered frame from the unwrapped environment.
    """
    observation = env.reset()
    finished = False
    total_return = 0
    steps = 0

    while not finished:
        chosen_action = model.get_action(observation)
        next_observation, step_reward, finished, _ = env.step(chosen_action)
        total_return += step_reward
        steps += 1

        yield dict(
            x=env.x,
            obs=observation,
            action=chosen_action,
            reward=step_reward,
            done=finished,
            episode_len=steps,
            episode_return=total_return,
            pixel_array=env.unwrapped.render(),
        )

        observation = next_observation
apple/logger.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+
3
+ Some simple logging functionality, inspired by rllab's logging.
4
+
5
+ Logs to a tab-separated-values file (path/to/output_directory/progress.txt)
6
+
7
+ """
8
+ import atexit
9
+ import os
10
+ import os.path as osp
11
+ import time
12
+ import warnings
13
+
14
+ import joblib
15
+ import numpy as np
16
+ import torch
17
+
18
+ import wandb
19
+
20
# ANSI foreground color codes (SGR 30-37 range, 38 for crimson);
# `colorize` adds 10 to select the bright/highlight variant.
color2num = dict(gray=30, red=31, green=32, yellow=33, blue=34, magenta=35, cyan=36, white=37, crimson=38)
21
+
22
+
23
def setup_logger_kwargs(exp_name, seed=None, data_dir=None, datestamp=True):
    """Compute the output directory for a logger and return its kwargs.

    Directory layout:
        no seed, no datestamp:  ``data_dir/exp_name``
        seed, no datestamp:     ``data_dir/exp_name/exp_name_s[seed]``
        datestamp:              ``data_dir/YY-MM-DD_exp_name/YY-MM-DD_HH-MM-SS-exp_name_s[seed]``

    Args:
        exp_name (string): Name for the experiment.
        seed (int): Seed for random number generators used by the experiment.
        data_dir (string): Folder where results should be saved. Defaults to
            a ``logs`` directory three levels above this file.
        datestamp (bool): Whether to include date/time stamps in the
            directory names.

    Returns:
        dict with keys ``output_dir`` and ``exp_name``.
    """
    if data_dir is None:
        data_dir = osp.join(osp.abspath(osp.dirname(osp.dirname(osp.dirname(__file__)))), "logs")

    # Optional date prefix on the experiment-level directory.
    date_prefix = time.strftime("%Y-%m-%d_") if datestamp else ""
    relpath = date_prefix + exp_name

    if seed is not None:
        # Seed-specific leaf directory inside the experiment directory.
        if datestamp:
            stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
            leaf = f"{stamp}-{exp_name}_s{seed}"
        else:
            leaf = f"{exp_name}_s{seed}"
        relpath = osp.join(relpath, leaf)

    return dict(output_dir=osp.join(data_dir, relpath), exp_name=exp_name)
65
+
66
+
67
def colorize(string, color, bold=False, highlight=False):
    """Wrap *string* in ANSI SGR escape codes for colored terminal output.

    ``highlight`` shifts the base color code by 10 (bright variant) and
    ``bold`` appends the bold attribute.

    This function was originally written by John Schulman.
    """
    code = color2num[color] + (10 if highlight else 0)
    attributes = [str(code)]
    if bold:
        attributes.append("1")
    return f"\x1b[{';'.join(attributes)}m{string}\x1b[0m"
81
+
82
+
83
class Logger:
    """
    A general-purpose logger.

    Makes it easy to save diagnostics, hyperparameter configurations, the
    state of a training run, and the trained model. Metrics are collected
    with ``log_tabular`` and flushed (to stdout, a delimited file and
    optionally wandb) with ``dump_tabular``.
    """

    def __init__(
        self,
        log_to_wandb=False,
        verbose=False,
        output_dir=None,
        output_fname="progress.csv",
        delimeter=",",
        exp_name=None,
        wandbcommit=1,
    ):
        """
        Initialize a Logger.

        Args:
            log_to_wandb (bool): If True the logger will also log to wandb.

            verbose (bool): If True, print a formatted metrics table on
                every ``dump_tabular`` call.

            output_dir (string): A directory for saving results to. If
                ``None``, nothing is written to disk.

            output_fname (string): Name for the delimited file containing
                metrics logged throughout a training run. Defaults to
                ``progress.csv``.

            delimeter (string): Used to separate logged values saved in
                output_fname.

            exp_name (string): Experiment name. If you run multiple training
                runs and give them all the same ``exp_name``, the plotter
                will know to group them.

            wandbcommit (int): Commit (sync) to wandb only every this many
                ``dump_tabular`` calls.
        """
        self.verbose = verbose
        self.log_to_wandb = log_to_wandb
        self.delimeter = delimeter
        self.wandbcommit = wandbcommit
        self.log_iter = 1
        # We assume that there's no multiprocessing.
        if output_dir is not None:
            self.output_dir = output_dir or "/tmp/experiments/%i" % int(time.time())
            if osp.exists(self.output_dir):
                print("Warning: Log dir %s already exists! Storing info there anyway." % self.output_dir)
            else:
                os.makedirs(self.output_dir)
            # w+ so log_tabular can rewrite the header in place when a new
            # key appears after the first row.
            self.output_file = open(osp.join(self.output_dir, output_fname), "w+")
            atexit.register(self.output_file.close)
            print(colorize("Logging data to %s" % self.output_file.name, "green", bold=True))
        else:
            self.output_file = None

        self.first_row = True
        self.log_headers = []
        self.log_current_row = {}
        self.exp_name = exp_name

    def log(self, msg, color="green"):
        """Print a colorized message to stdout."""
        print(colorize(msg, color, bold=True))

    def log_tabular(self, key, val):
        """
        Log a value of some diagnostic.

        Call this only once for each diagnostic quantity, each iteration.
        After using ``log_tabular`` to store values for each diagnostic,
        make sure to call ``dump_tabular`` to write them out to file and
        stdout (otherwise they will not get saved anywhere).
        """
        if self.first_row:
            self.log_headers.append(key)
        else:
            if key not in self.log_headers:
                # A new key appeared after the first row: register it and
                # rewrite the file header in place so columns stay aligned.
                self.log_headers.append(key)

                if self.output_file is not None:
                    # move pointer to the beginning of the file
                    self.output_file.seek(0)
                    # skip the old header
                    self.output_file.readline()
                    # keep the rest of the file
                    logs = self.output_file.read()
                    # clear the file
                    self.output_file.truncate(0)
                    self.output_file.seek(0)
                    # write the new header
                    self.output_file.write(self.delimeter.join(self.log_headers) + "\n")
                    # write back the stored rows
                    self.output_file.write(logs)
                    # return the pointer to the end of the file
                    self.output_file.seek(0, 2)
        assert key not in self.log_current_row, (
            "You already set %s this iteration. Maybe you forgot to call dump_tabular()" % key
        )
        self.log_current_row[key] = val

    def save_state(self, state_dict, itr=None):
        """
        Saves the state of an experiment.

        To be clear: this is about saving *state*, not logging diagnostics.
        All diagnostic logging is separate from this function. This function
        will save whatever is in ``state_dict``---usually just a copy of the
        environment---and the most recent parameters for the model you
        previously set up saving for with ``setup_pytorch_saver``.

        Call with any frequency you prefer. If you only want to maintain a
        single state and overwrite it at each call with the most recent
        version, leave ``itr=None``. If you want to keep all of the states you
        save, provide unique (increasing) values for 'itr'.

        Args:
            state_dict (dict): Dictionary containing essential elements to
                describe the current state of training.

            itr: An int, or None. Current iteration of training.
        """
        # NOTE: assumes the logger was constructed with an output_dir.
        fname = "vars.pkl" if itr is None else "vars%d.pkl" % itr
        try:
            joblib.dump(state_dict, osp.join(self.output_dir, fname))
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are not swallowed here.
            self.log("Warning: could not pickle state_dict.", color="red")
        if hasattr(self, "pytorch_saver_elements"):
            self._pytorch_simple_save(itr)

    def setup_pytorch_saver(self, what_to_save):
        """
        Set up easy model saving for a single PyTorch model.

        Because PyTorch saving and loading is especially painless, this is
        very minimal; we just need references to whatever we would like to
        pickle. This is integrated into the logger because the logger
        knows where the user would like to save information about this
        training run.

        Args:
            what_to_save: Any PyTorch model or serializable object containing
                PyTorch models.
        """
        self.pytorch_saver_elements = what_to_save

    def _pytorch_simple_save(self, itr=None):
        """
        Saves the PyTorch model (or models).
        """
        assert hasattr(self, "pytorch_saver_elements"), "First have to setup saving with self.setup_pytorch_saver"
        fpath = "pyt_save"
        fpath = osp.join(self.output_dir, fpath)
        fname = "model" + ("%d" % itr if itr is not None else "") + ".pt"
        fname = osp.join(fpath, fname)
        os.makedirs(fpath, exist_ok=True)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # We are using a non-recommended way of saving PyTorch models,
            # by pickling whole objects (which are dependent on the exact
            # directory structure at the time of saving) as opposed to
            # just saving network weights. This works sufficiently well
            # for the purposes of Spinning Up, but you may want to do
            # something different for your personal PyTorch project.
            # We use a catch_warnings() context to avoid the warnings about
            # not being able to save the source code.
            torch.save(self.pytorch_saver_elements, fname)

    def dump_tabular(self):
        """
        Write all of the diagnostics from the current iteration.

        Writes to stdout (when verbose), to the output file, and to wandb
        (when enabled). Returns the key/value dict for this iteration.
        """
        # Collect values for every known header unconditionally. (Bug fix:
        # this list was previously only filled inside the verbose branch,
        # so with verbose=False empty rows were written to the file.)
        vals = [self.log_current_row.get(key, "") for key in self.log_headers]
        key_lens = [len(key) for key in self.log_headers]
        # default=0 guards against a dump before any key was logged.
        max_key_len = max(15, max(key_lens, default=0))
        keystr = "%" + "%d" % max_key_len
        fmt = "| " + keystr + "s | %15s |"
        n_slashes = 22 + max_key_len
        # wandb uses "total_env_steps" (when present) as the step index.
        step = self.log_current_row.get("total_env_steps")

        if self.verbose:
            print("-" * n_slashes)
            for key, val in zip(self.log_headers, vals):
                valstr = "%8.3g" % val if isinstance(val, float) else val
                print(fmt % (key, valstr))
            print("-" * n_slashes, flush=True)

        if self.output_file is not None:
            if self.first_row:
                self.output_file.write(self.delimeter.join(self.log_headers) + "\n")
            self.output_file.write(self.delimeter.join(map(str, vals)) + "\n")
            self.output_file.flush()

        key_val_dict = {key: self.log_current_row.get(key, "") for key in self.log_headers}
        if self.log_to_wandb:
            # Only sync (commit) every `wandbcommit` dumps; other dumps are
            # accumulated into the same wandb step.
            commit = self.log_iter % self.wandbcommit == 0
            wandb.log(key_val_dict, step=step, commit=commit)

        self.log_current_row.clear()
        self.first_row = False
        self.log_iter += 1

        return key_val_dict
297
+
298
+
299
class EpochLogger(Logger):
    """
    A variant of Logger tailored for tracking average values over epochs.

    Typical use case: there is some quantity which is calculated many times
    throughout an epoch, and at the end of the epoch, you would like to
    report the average / std / min / max value of that quantity.

    With an EpochLogger, each time the quantity is calculated, you would
    use

    .. code-block:: python

        epoch_logger.store({"NameOfQuantity": quantity_value})

    to load it into the EpochLogger's state. Then at the end of the epoch, you
    would use

    .. code-block:: python

        epoch_logger.log_tabular(NameOfQuantity, **options)

    to record the desired values.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps diagnostic name -> list of values collected this epoch.
        self.epoch_dict = dict()

    def store(self, d):
        """
        Save something into the epoch_logger's current state.

        Args:
            d (dict): Mapping of diagnostic names to numerical values to
                accumulate for the current epoch.
        """
        for k, v in d.items():
            self.epoch_dict.setdefault(k, []).append(v)

    def log_tabular(self, key, val=None, with_min_and_max=False, with_median=False, with_sum=False, average_only=False):
        """
        Log a value or possibly the mean/std/min/max values of a diagnostic.

        Args:
            key (string): The name of the diagnostic. If you are logging a
                diagnostic whose state has previously been saved with
                ``store``, the key here has to match the key you used there.

            val: A value for the diagnostic. If you have previously saved
                values for this key via ``store``, do *not* provide a ``val``
                here.

            with_min_and_max (bool): If true, log min and max values of the
                diagnostic over the epoch.

            with_median (bool): If true, log the median of the diagnostic
                over the epoch.

            with_sum (bool): If true, log the sum of the diagnostic over
                the epoch.

            average_only (bool): If true, do not log the standard deviation
                of the diagnostic over the epoch.
        """
        if val is not None:
            super().log_tabular(key, val)
        else:
            stats = self.get_stats(key)
            super().log_tabular(key if average_only else key + "/avg", stats[0])
            if not average_only:
                super().log_tabular(key + "/std", stats[1])
            if with_min_and_max:
                super().log_tabular(key + "/max", stats[3])
                super().log_tabular(key + "/min", stats[2])
            if with_median:
                super().log_tabular(key + "/med", stats[4])
            if with_sum:
                super().log_tabular(key + "/sum", stats[5])

        # Reset the buffer for the next epoch.
        self.epoch_dict[key] = []

    def get_stats(self, key):
        """
        Return ``[mean, std, min, max, median, sum]`` of the stored values.

        Bug fix: when no values were stored for ``key`` this used to return
        only four NaNs, which made ``log_tabular`` raise an IndexError at
        ``stats[4]``/``stats[5]`` when median/sum were requested; it now
        returns six NaNs to match the populated case.
        """
        v = self.epoch_dict.get(key)
        if not v:
            return [np.nan] * 6
        # Flatten a list of non-scalar arrays into one array of samples.
        vals = np.concatenate(v) if isinstance(v[0], np.ndarray) and len(v[0].shape) > 0 else v
        return [np.mean(vals), np.std(vals), np.min(vals), np.max(vals), np.median(vals), np.sum(vals)]
apple/models/categorical_policy.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from torch.distributions import Categorical
5
+
6
+
7
class CategoricalPolicy(nn.Module):
    """Bias-free linear policy producing a two-way categorical distribution.

    A single linear layer maps the state to a logit; a sigmoid turns it into
    ``p`` and the returned distribution is ``[p, 1 - p]`` (so ``p`` is the
    probability of action 0).
    """

    def __init__(self, state_dim, act_dim, weight1=None, weight2=None):
        """
        Args:
            state_dim: Size of the observation vector.
            act_dim: Output size of the linear layer.
            weight1: Optional fixed initial value for weight[0][0].
            weight2: Optional fixed initial value for weight[0][1].
        """
        super().__init__()
        self.model = nn.Linear(state_dim, act_dim, bias=False)

        # Optionally pin the first two entries of the first weight row so
        # experiments can start from controlled initial parameters.
        if weight1 is not None:
            nn.init.constant_(self.model.weight[0][0], weight1)

        if weight2 is not None:
            nn.init.constant_(self.model.weight[0][1], weight2)

    def forward(self, state):
        """Return action probabilities ``[p, 1 - p]`` for a single numpy state."""
        x = torch.from_numpy(state).float().unsqueeze(0)
        x = self.model(x)
        # we just consider 1 dimensional probability of action
        # NOTE(review): with act_dim > 1 the concatenation below yields
        # 2 * act_dim entries; this appears intended for act_dim == 1 — confirm.
        p = torch.sigmoid(x)
        return torch.cat([p, 1 - p], dim=1)

    def act(self, state):
        """Sample an action for *state*; return (action, log-prob of that action)."""
        probs = self.forward(state)
        dist = Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action)

    def sample(self, probs):
        """Sample from precomputed *probs*; return (action, log-prob of that action)."""
        dist = Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action)

    def log_prob(self, probs, target_action):
        """Return (a freshly sampled action, log-prob of *target_action*).

        NOTE(review): the sampled action is unrelated to the returned
        log-prob and consumes RNG state; the caller seen in this repo
        discards it — confirm the sampling side effect is intended.
        """
        dist = Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(target_action)

    @torch.no_grad()
    def get_action(self, state):
        """Sample an action for *state* without tracking gradients."""
        probs = self.forward(state)
        dist = Categorical(probs)
        action = dist.sample()
        return action.item()
apple/training/reinforce_trainer.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from apple.training.trainer import Trainer
4
+
5
+
6
def discount_cumsum(x, gamma):
    """Discounted reward-to-go over a 1-D tensor.

    Computes ``out[t] = x[t] + gamma * out[t + 1]`` with ``out[-1] = x[-1]``.

    Args:
        x: 1-D tensor of per-step values (e.g. rewards).
        gamma: Discount factor.

    Returns:
        A tensor of the same shape as ``x`` holding the discounted suffix sums.
    """
    out = torch.zeros_like(x)
    out[-1] = x[-1]
    for t in range(x.shape[0] - 2, -1, -1):
        out[t] = x[t] + gamma * out[t + 1]
    return out
12
+
13
+
14
class ReinforceTrainer(Trainer):
    """REINFORCE (vanilla policy gradient) trainer for the Apple environments."""

    def __init__(self, *args, gamma: float = 1.0, **kwargs):
        # gamma: discount factor applied in the reward-to-go computation.
        super().__init__(*args, **kwargs)
        self.gamma = gamma

    def train(self, env, test_envs, num_episodes, log_every, update_every, num_eval_eps):
        """Run REINFORCE for ``num_episodes`` episodes on ``env``.

        Args:
            env: Training environment (gym-style reset/step API).
            test_envs: Environments evaluated every ``log_every`` episodes.
            num_episodes: Total number of training episodes.
            log_every: Evaluation/logging period, in episodes.
            update_every: Optimizer step period, in episodes.
            num_eval_eps: Episodes per test environment during evaluation.
        """
        # code base on
        # https://goodboychan.github.io/python/reinforcement_learning/pytorch/udacity/2021/05/12/REINFORCE-CartPole.html
        self.optim.zero_grad()
        for episode in range(num_episodes):
            # NOTE(review): train_it counts episodes but is logged below
            # under "total_env_steps" — confirm the naming is intended.
            self.train_it += 1

            if (episode + 1) % log_every == 0:

                self.test_agent(self.model, self.logger, test_envs, num_eval_eps)

                # Log info about epoch
                self.logger.log_tabular("total_env_steps", self.train_it)
                self.logger.log_tabular("train/return", with_min_and_max=True)
                self.logger.log_tabular("train/ep_length", average_only=True)

                # Log every entry of the (flattened) policy weight matrix.
                for e, w in enumerate(self.model.model.weight.flatten()):
                    self.logger.log_tabular(f"weights{e}", w.item())

                self.logger.log_tabular("train/policy_loss", average_only=True)
                self.logger.log_tabular("train/log_probs", average_only=True)
                self.logger.dump_tabular()

            state = env.reset()

            # Collect one full episode of log-probs and rewards.
            saved_log_probs = []
            rewards = []
            ep_len, ep_ret = 0, 0
            while True:
                # Sample the action from current policy
                action, log_prob = self.model.act(state)
                saved_log_probs.append(log_prob)
                state, reward, done, _ = env.step(action)
                ep_ret += reward
                ep_len += 1

                rewards.append(reward)

                if done:
                    self.logger.store({"train/return": ep_ret, "train/ep_length": ep_len})
                    break

            saved_log_probs, rewards = torch.cat(saved_log_probs), torch.tensor(rewards)

            # Reward-to-go (discounted suffix sums) for each timestep.
            discounted_rewards = discount_cumsum(rewards, gamma=self.gamma)
            # Note that we are using Gradient Ascent, not Descent. So we need to calculate it with negative rewards.
            policy_loss = (-discounted_rewards * saved_log_probs).sum()
            # Backpropagation
            # NOTE(review): zero_grad() and backward() are both gated on
            # update_every, so losses from intermediate episodes never
            # contribute gradients — confirm this is intended (rather than
            # accumulating gradients every episode and only stepping every
            # update_every episodes).
            if (episode + 1) % update_every == 0:
                self.optim.zero_grad()
                policy_loss.backward()
            if (episode + 1) % update_every == 0:
                self.optim.step()

            self.logger.store({"train/policy_loss": policy_loss.item()})
            self.logger.store({"train/log_probs": saved_log_probs.mean().item()})
apple/training/trainer.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
class Trainer:
    """Behavioral-cloning trainer.

    At every environment step the model is supervised toward the
    environment's oracle action (``env.get_target_action()``), and the
    agent is periodically evaluated on a list of held-out test envs.
    """

    def __init__(self, model, optim, logger):
        self.model = model
        self.optim = optim
        self.logger = logger
        # Total number of environment steps taken so far; used as the
        # x-axis ("total_env_steps") when logging.
        self.train_it = 0

    def test_agent(self, model, logger, test_envs, num_episodes):
        """Roll out `model` for `num_episodes` episodes on each test env.

        Logs per-env return/episode-length statistics plus a success rate,
        and an "average_success" aggregated over all test envs.
        """
        avg_success = []

        for seq_idx, test_env in enumerate(test_envs):
            # assumes each test env exposes a `name` attribute — TODO confirm
            key_prefix = f"{test_env.name}/"

            for j in range(num_episodes):
                obs, done, episode_return, episode_len = test_env.reset(), False, 0, 0

                while not done:
                    action = model.get_action(obs)
                    obs, reward, done, _ = test_env.step(action)
                    episode_return += reward
                    episode_len += 1
                logger.store({key_prefix + "return": episode_return, key_prefix + "ep_length": episode_len})

            logger.log_tabular(key_prefix + "return", with_min_and_max=True)
            logger.log_tabular(key_prefix + "ep_length", average_only=True)
            # `pop_successes` (SuccessCounter wrapper) returns and clears the
            # per-episode success flags accumulated since the last call.
            env_success = test_env.pop_successes()
            avg_success += env_success
            logger.log_tabular(key_prefix + "success", np.mean(env_success))

        key = "average_success"
        logger.log_tabular(key, np.mean(avg_success))

    def log(self, logger, step, model):
        # Log info about epoch
        logger.log_tabular("total_env_steps", step)

        logger.log_tabular("train/loss", average_only=True)
        logger.log_tabular("train/action", average_only=True)

        # Track every scalar weight of the (tiny) linear model individually,
        # to visualize drift/forgetting across training phases.
        for e, w in enumerate(model.model.weight.flatten()):
            logger.log_tabular(f"weight{e}", w.item())

        return logger.dump_tabular()

    def update(self, env, probs, model, optim, logger):
        """One behavioral-cloning gradient step toward the oracle action."""
        target = torch.as_tensor([env.get_target_action()], dtype=torch.float32)
        action, log_prob = model.log_prob(probs, target)

        optim.zero_grad()
        # Maximize log-likelihood of the target action (minimize its negative).
        loss = -torch.mean(log_prob)
        loss.backward()
        optim.step()

        logger.store({"train/action": action})
        logger.store({"train/loss": loss.item()})

    def train(self, env, test_envs, steps, log_every, num_eval_eps):
        """Run `steps` environment steps of behavioral cloning on `env`."""
        obs = env.reset()
        for timestep in range(steps):
            self.train_it += 1

            if (timestep + 1) % log_every == 0:
                self.test_agent(self.model, self.logger, test_envs, num_eval_eps)
                self.log(self.logger, self.train_it, self.model)

            output = self.model(obs)
            # The sampled action is only used to advance the environment;
            # `update` backpropagates through `output` (the policy logits).
            action, log_prob = self.model.sample(output)
            self.update(env, output, self.model, self.optim, self.logger)

            obs, reward, done, info = env.step(action)

            if done:
                obs = env.reset()
apple/utils.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import random
3
+
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+
10
+ # https://stackoverflow.com/a/43357954/6365092
11
def str2bool(v: Union[bool, str]) -> bool:
    """Convert a human-friendly truthy/falsy string (or bool) to a bool.

    Intended for use as an argparse ``type=``; raises
    ``argparse.ArgumentTypeError`` on unrecognized input.
    """
    # Already a bool (argparse may hand us the default unchanged).
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    if normalized in ("yes", "true", "t", "y", "1"):
        return True
    if normalized in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
20
+
21
+
22
def set_seed(seed):
    """Seed every RNG the project relies on, for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
apple/wrappers.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import (
2
+ Any,
3
+ Dict,
4
+ List,
5
+ Tuple,
6
+ )
7
+
8
+ import gym
9
+ import numpy as np
10
+
11
+
12
class SuccessCounter(gym.Wrapper):
    """Helper class to keep count of successes in MetaWorld environments."""

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        # One flag per *finished* episode: True if any step reported success.
        self.successes = []
        self.current_success = False

    def step(self, action: Any) -> Tuple[np.ndarray, float, bool, Dict]:
        obs, reward, done, info = self.env.step(action)
        # A single successful step marks the whole episode as successful.
        self.current_success = self.current_success or bool(info.get("success", False))
        if done:
            self.successes.append(self.current_success)
        return obs, reward, done, info

    def pop_successes(self) -> List[bool]:
        """Return accumulated episode successes and reset the counter."""
        res, self.successes = self.successes, []
        return res

    def reset(self, **kwargs) -> np.ndarray:
        self.current_success = False
        return self.env.reset(**kwargs)
assets/apple_env.png ADDED
assets/example_rollout.mp4 ADDED
Binary file (86.8 kB). View file
 
assets/generate_example_rollout.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Render a short scripted trajectory in the apple env and save it as a video."""

import numpy as np
import skvideo.io

from apple.envs.discrete_apple import get_apple_env

env = get_apple_env("full", time_limit=10, start_x=0, c=0.5, goal_x=8, render_mode="rgb_array")


imgs = []
env.reset()
# Capture the frame *before* each step: 8 steps of action 1, then 9 of action 0.
# NOTE(review): action semantics (left/right) come from discrete_apple — confirm there.
for i in range(8):
    imgs.append(env.unwrapped.render())
    env.step(1)

for i in range(9):
    imgs.append(env.unwrapped.render())
    env.step(0)


skvideo.io.vwrite(
    "example_rollout.mp4",
    np.stack(imgs),
    inputdict={
        # Input frame rate: 4 frames per second.
        "-r": str(int(4)),
    },
    outputdict={
        "-f": "mp4",
        "-pix_fmt": "yuv420p",  # '-pix_fmt=yuv420p' needed for osx https://github.com/scikit-video/scikit-video/issues/74
    },
)
input_args.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from apple.utils import str2bool
4
+
5
+
6
def apple_parse_args(args=None):
    """Parse command-line arguments shared by the apple experiments.

    Args:
        args: Optional list of argument strings; ``None`` falls back to
            ``sys.argv``.

    Returns:
        ``argparse.Namespace`` containing only the known arguments
        (unknown ones are silently ignored via ``parse_known_args``).
    """
    parser = argparse.ArgumentParser()

    # Environment parameters.
    parser.add_argument("--c", type=float, default=0.25, required=False)
    parser.add_argument("--start_x", type=float, default=0.0, required=False)
    parser.add_argument("--goal_x", type=float, default=10.0, required=False)
    # Bug fix: argparse does not pass non-string defaults through `type`,
    # so `default=100.0` used to leak a float despite `type=int`.
    parser.add_argument("--time_limit", type=int, default=100, required=False)

    # Training parameters.
    parser.add_argument("--lr", type=float, default=1e-3, required=False)
    parser.add_argument("--log_to_wandb", type=str2bool, default=True, required=False)

    return parser.parse_known_args(args=args)[0]
mrunner_exps/behavioral_cloning.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""mrunner experiment specification: behavioral cloning over a grid of env params."""

import numpy as np

from mrunner.helpers.specification_helper import create_experiments_helper

from mrunner_exps.utils import combine_config_with_defaults

# mrunner injects the spec file name into globals(); strip ".py" for the experiment name.
name = globals()["script"][:-3]

# params for all exps
config = {
    "exp_tag": "behavioral_cloning",
    "run_kind": "bc",
    "log_to_wandb": True,
    "pretrain_steps": 200,
    "steps": 200,
    "log_every": 1,
    "num_eval_eps": 10,
    "verbose": False,
    "lr": 0.01,
    "c": 0.5,
    "start_x": 0.0,
    "goal_x": 50.0,
    "bias_in_state": True,
    "position_in_state": False,
    "time_limit": 100,
    "wandbcommit": 100,
    "pretrain": "phase2",
    "finetune": "full",
}
# Fill any missing keys with the argparse defaults for this run kind.
config = combine_config_with_defaults(config)

# params different between exps
params_grid = [
    {
        "seed": list(range(10)),
        "c": list(np.arange(0.1, 1.1, 0.1)),
        "goal_x": list(np.arange(5, 50, 5)),
    }
]

experiments_list = create_experiments_helper(
    experiment_name=name,
    project_name="apple",
    with_neptune=False,
    script="python3 mrunner_run.py",
    python_path=".",
    tags=[name],
    exclude=["logs", "wandb"],
    base_config=config,
    params_grid=params_grid,
)
mrunner_exps/reinforce.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""mrunner experiment specification: REINFORCE over a grid of env params."""

import numpy as np

from mrunner.helpers.specification_helper import create_experiments_helper

from mrunner_exps.utils import combine_config_with_defaults

# mrunner injects the spec file name into globals(); strip ".py" for the experiment name.
name = globals()["script"][:-3]

# params for all exps
config = {
    "exp_tag": "reinforce_goal_c3",
    "run_kind": "reinforce",
    "log_to_wandb": True,
    "pretrain_steps": 1000,
    "steps": 2000,
    "log_every": 1,
    "num_eval_eps": 10,
    "verbose": False,
    "lr": 0.001,
    "c": 1.0,
    "start_x": 0.0,
    "goal_x": 50.0,
    "bias_in_state": True,
    "position_in_state": False,
    "time_limit": 100,
    "gamma": 0.99,
    "wandbcommit": 1000,
    "pretrain": "phase2",
    "finetune": "full",
    "update_every": 10,  # for good definition on gradient
}
# Fill any missing keys with the argparse defaults for this run kind.
config = combine_config_with_defaults(config)

# params different between exps
params_grid = [
    {
        "seed": list(range(10)),
        "c": list(np.arange(0.1, 1.1, 0.1)),
        "goal_x": list(np.arange(5, 50, 5)),
    }
]

experiments_list = create_experiments_helper(
    experiment_name=name,
    project_name="apple",
    with_neptune=False,
    script="python3 mrunner_run.py",
    python_path=".",
    tags=[name],
    exclude=["logs", "wandb"],
    base_config=config,
    params_grid=params_grid,
)
mrunner_exps/utils.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from input_args import apple_parse_args
2
+
3
# Map each run kind to its argument parser (both kinds share the apple parser).
PARSE_ARGS_DICT = {"bc": apple_parse_args, "reinforce": apple_parse_args}


def combine_config_with_defaults(config):
    """Overlay `config` onto the argparse defaults of its run kind.

    Keys present in `config` win; everything else comes from the parser's
    defaults (obtained by parsing an empty argument list).
    """
    defaults = vars(PARSE_ARGS_DICT[config["run_kind"]]([]))
    defaults.update(config)
    return defaults
mrunner_run.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Entry point used by mrunner: fetch the launcher-injected config and run `main`."""

from mrunner.helpers.client_helper import get_configuration

import wandb

from run import main

if __name__ == "__main__":
    config = get_configuration(print_diagnostics=True, with_neptune=False)

    # mrunner-internal key that `main` does not accept.
    del config["experiment_id"]

    # NOTE(review): attribute access implies `config` is an attr-dict
    # (Munch-like) returned by mrunner — confirm against mrunner docs.
    if config.log_to_wandb:
        wandb.init(
            entity="gmum",
            project="apple",
            config=config,
        )
    main(**config)
mrunner_runs/local.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ python mrunner_run.py --ex mrunner_exps/baseline.py
mrunner_runs/remote.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ conda activate apple
4
+
5
+ ssh-add
6
+ export PYTHONPATH=.
7
+
8
+ mrunner --config ~/.mrunner.yaml --context eagle_transfer_mw2 run mrunner_exps/behavioral_cloning.py
9
+ # mrunner --config ~/.mrunner.yaml --context eagle_transfer_mw2 run mrunner_exps/reinforce.py
pyproject.toml ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools", "setuptools_scm", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "apple"
7
+ description = "Simplest experiment for showing forgetting"
8
+ license = { text = "Proprietary" }
9
+ authors = [{name = "BartekCupial", email = "bartlomiej.cupial@student.uj.edu.pl" }]
10
+
11
+ dynamic = ["version"]
12
+
13
+ requires-python = ">= 3.8, < 3.11"
14
+
15
+ dependencies = [
16
+ "numpy ~= 1.23",
17
+ "typing-extensions ~= 4.3",
18
+ "gym == 0.23",
19
+ "torch ~= 1.12",
20
+ "wandb ~= 0.13",
21
+ "pandas ~= 1.5",
22
+ "matplotlib ~= 3.6",
23
+ "seaborn ~= 0.12",
24
+ "scipy ~= 1.9",
25
+ "joblib ~= 1.2",
26
+ "pygame ~= 2.1",
27
+ ]
28
+
29
+ [project.optional-dependencies]
30
+ build = ["build ~= 0.8"]
31
+ mrunner = ["mrunner @ git+https://gitlab.com/awarelab/mrunner.git"]
32
+ lint = [
33
+ "black ~= 22.6",
34
+ "autoflake ~= 1.4",
35
+ "flake8 ~= 4.0",
36
+ "flake8-pyi ~= 22.5",
37
+ "flake8-docstrings ~= 1.6",
38
+ "pyproject-flake8 ~= 0.0.1a4",
39
+ "isort ~= 5.10",
40
+ "pre-commit ~= 2.20",
41
+ ]
42
+ test = [
43
+ "pytest ~= 7.1",
44
+ "pytest-cases ~= 3.6",
45
+ "pytest-cov ~= 3.0",
46
+ "pytest-xdist ~= 2.5",
47
+ "pytest-sugar ~= 0.9",
48
+ "hypothesis ~= 6.54",
49
+ ]
50
+ dev = [
51
+ "apple[mrunner]",
52
+ "apple[build]",
53
+ "apple[lint]",
54
+ "apple[test]",
55
+ ]
56
+
57
+ [project.urls]
58
+ "Source" = "https://github.com/BartekCupial/apple"
59
+
60
+ [tool.black]
61
+ line_length = 120
62
+
63
+ [tool.flake8]
64
+ extend_exclude = [".venv/", "build/", "dist/", "docs/"]
65
+ per_file_ignores = ["**/_[a-z]*.py:D", "tests/*.py:D", "*.pyi:D"]
66
+ ignore = [
67
+ # Handled by black
68
+ "E", # pycodestyle
69
+ "W", # pycodestyle
70
+ "D",
71
+ ]
72
+ ignore_decorators = "property" # https://github.com/PyCQA/pydocstyle/pull/546
73
+
74
+ [tool.isort]
75
+ profile = "black"
76
+ line_length = 120
77
+ order_by_type = true
78
+ lines_between_types = 1
79
+ combine_as_imports = true
80
+ force_grid_wrap = 2
81
+
82
+ [tool.pytest.ini_options]
83
+ testpaths = "tests"
84
+ addopts = """
85
+ -n auto
86
+ -ra
87
+ --tb short
88
+ --doctest-modules
89
+ --junit-xml test-results.xml
90
+ --cov-report term-missing:skip-covered
91
+ --cov-report xml:coverage.xml
92
+ """
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ numpy
3
+ pandas
4
+ gym == 0.23
5
+ wandb
6
+ joblib
7
+ pygame
run.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import wandb
4
+
5
+ from apple.envs.discrete_apple import get_apple_env
6
+ from apple.logger import EpochLogger
7
+ from apple.models.categorical_policy import CategoricalPolicy
8
+ from apple.training.reinforce_trainer import ReinforceTrainer
9
+ from apple.training.trainer import Trainer
10
+ from apple.utils import set_seed
11
+ from input_args import apple_parse_args
12
+
13
+
14
def main(
    run_kind: str,
    c: float = 1.0,
    start_x: float = 0.0,
    goal_x: float = 50.0,
    time_limit: int = 200,
    bias_in_state: bool = True,
    position_in_state: bool = False,
    apple_in_state: bool = True,
    lr: float = 1e-3,
    pretrain_steps: int = 0,
    steps: int = 10000,
    log_every: int = 1,
    num_eval_eps: int = 1,
    pretrain: str = "phase1",
    finetune: str = "full",
    log_to_wandb: bool = False,
    wandbcommit: int = 1,
    verbose: bool = False,
    output_dir="logs/apple",
    gamma: float = 1.0,
    update_every: int = 10,
    seed=0,
    **kwargs,
):
    """Two-phase training: pretrain on `pretrain`, then fine-tune on `finetune`.

    The agent is periodically evaluated on the "full", "phase1" and
    "phase2" variants of the apple environment.

    Args:
        run_kind: Training algorithm — "reinforce" (policy gradient)
            or "bc" (behavioral cloning).
        c, start_x, goal_x, time_limit, bias_in_state, position_in_state,
            apple_in_state: Environment construction parameters.
        lr: SGD learning rate.
        pretrain_steps: Training iterations in the pretraining phase.
        steps: Training iterations in the fine-tuning phase.
        log_every: Evaluate/log every this many iterations.
        num_eval_eps: Evaluation episodes per test environment.
        gamma: Discount factor (REINFORCE only).
        update_every: Gradient-accumulation period (REINFORCE only).
        kwargs: Ignored; lets launcher configs carry extra keys.

    Raises:
        ValueError: If `run_kind` is not a supported trainer.
    """
    set_seed(seed)

    logger = EpochLogger(
        exp_name=run_kind,
        output_dir=output_dir,
        log_to_wandb=log_to_wandb,
        wandbcommit=wandbcommit,
        verbose=verbose,
    )

    env_kwargs = dict(
        start_x=start_x,
        goal_x=goal_x,
        c=c,
        time_limit=time_limit,
        bias_in_state=bias_in_state,
        position_in_state=position_in_state,
        apple_in_state=apple_in_state,
    )

    env_phase1 = get_apple_env(pretrain, **env_kwargs)
    env_phase2 = get_apple_env(finetune, **env_kwargs)
    test_envs = [get_apple_env(task, **env_kwargs) for task in ["full", "phase1", "phase2"]]

    # Single-output categorical policy over the observation vector.
    model = CategoricalPolicy(env_phase1.observation_space.shape[0], 1)
    optim = torch.optim.SGD(model.parameters(), lr=lr)

    if run_kind == "reinforce":
        trainer = ReinforceTrainer(model, optim, logger, gamma=gamma)
        trainer.train(env_phase1, test_envs, pretrain_steps, log_every, update_every, num_eval_eps)
        trainer.train(env_phase2, test_envs, steps, log_every, update_every, num_eval_eps)
    elif run_kind == "bc":
        trainer = Trainer(model, optim, logger)
        trainer.train(env_phase1, test_envs, pretrain_steps, log_every, num_eval_eps)
        trainer.train(env_phase2, test_envs, steps, log_every, num_eval_eps)
    else:
        # Previously an unrecognized run_kind silently trained nothing.
        raise ValueError(f"Unknown run_kind: {run_kind!r} (expected 'reinforce' or 'bc')")
74
+
75
+
76
if __name__ == "__main__":
    args = apple_parse_args()

    if args.log_to_wandb:
        wandb.init(
            entity="gmum",
            project="apple",
            config=args,
            # "fork" avoids wandb hangs when child processes are spawned later.
            settings=wandb.Settings(start_method="fork"),
        )

    main(**vars(args))
setup.cfg ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [options]
2
+ packages = find_namespace:
3
+ package_dir =
4
+ = apple
5
+
6
+ [options.packages.find]
7
+ where = apple
8
+
9
+ [options.package_data]
10
+ * = py.typed, *.pyi