Spaces:
Runtime error
Runtime error
New Author Name
committed on
Commit
·
4b714e2
1
Parent(s):
8264cee
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .editorconfig +16 -0
- .gitignore +18 -0
- .pre-commit-config.yaml +29 -0
- .streamlit/config.toml +34 -0
- README.md +8 -7
- app.py +606 -0
- apple/envs/discrete_apple.py +384 -0
- apple/envs/img/apple.png +0 -0
- apple/envs/img/elf_down.png +0 -0
- apple/envs/img/elf_left.png +0 -0
- apple/envs/img/elf_right.png +0 -0
- apple/envs/img/g1.png +0 -0
- apple/envs/img/g2.png +0 -0
- apple/envs/img/g3.png +0 -0
- apple/envs/img/grass.jpg +0 -0
- apple/envs/img/home.png +0 -0
- apple/envs/img/home00.png +0 -0
- apple/envs/img/home01.png +0 -0
- apple/envs/img/home02.png +0 -0
- apple/envs/img/home10.png +0 -0
- apple/envs/img/home11.png +0 -0
- apple/envs/img/home12.png +0 -0
- apple/envs/img/home2.png +0 -0
- apple/envs/img/home2_with_apples.png +0 -0
- apple/envs/img/home_grass.png +0 -0
- apple/envs/img/part_grass.png +0 -0
- apple/envs/img/stool.png +0 -0
- apple/envs/img/textures.jpg +0 -0
- apple/envs/img/white.png +0 -0
- apple/evaluation/render_episode.py +22 -0
- apple/logger.py +384 -0
- apple/models/categorical_policy.py +46 -0
- apple/training/reinforce_trainer.py +74 -0
- apple/training/trainer.py +77 -0
- apple/utils.py +25 -0
- apple/wrappers.py +35 -0
- assets/apple_env.png +0 -0
- assets/example_rollout.mp4 +0 -0
- assets/generate_example_rollout.py +30 -0
- input_args.py +17 -0
- mrunner_exps/behavioral_cloning.py +51 -0
- mrunner_exps/reinforce.py +53 -0
- mrunner_exps/utils.py +10 -0
- mrunner_run.py +18 -0
- mrunner_runs/local.sh +3 -0
- mrunner_runs/remote.sh +9 -0
- pyproject.toml +92 -0
- requirements.txt +7 -0
- run.py +87 -0
- setup.cfg +10 -0
.editorconfig
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
root = true
|
| 2 |
+
|
| 3 |
+
[*]
|
| 4 |
+
charset = utf-8
|
| 5 |
+
end_of_line = lf
|
| 6 |
+
insert_final_newline = true
|
| 7 |
+
trim_trailing_whitespace = true
|
| 8 |
+
|
| 9 |
+
[*]
|
| 10 |
+
indent_size = 4
|
| 11 |
+
indent_style = space
|
| 12 |
+
max_line_length = 120
|
| 13 |
+
tab_width = 8
|
| 14 |
+
|
| 15 |
+
[*.{yml,yaml}]
|
| 16 |
+
indent_size = 2
|
.gitignore
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/.venv/
|
| 2 |
+
/.python-version
|
| 3 |
+
|
| 4 |
+
/build/
|
| 5 |
+
/dist/
|
| 6 |
+
/site/
|
| 7 |
+
/test-results.xml
|
| 8 |
+
/.coverage
|
| 9 |
+
/coverage.xml
|
| 10 |
+
|
| 11 |
+
/.hypothesis/
|
| 12 |
+
__pycache__/
|
| 13 |
+
*.egg-info/
|
| 14 |
+
|
| 15 |
+
/.vscode/
|
| 16 |
+
wandb
|
| 17 |
+
logs
|
| 18 |
+
*.pt
|
.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: local
|
| 3 |
+
hooks:
|
| 4 |
+
- id: autoflake
|
| 5 |
+
name: autoflake
|
| 6 |
+
entry: autoflake
|
| 7 |
+
args: [--in-place, --remove-all-unused-imports, --remove-unused-variables]
|
| 8 |
+
language: system
|
| 9 |
+
types_or: [python, pyi]
|
| 10 |
+
|
| 11 |
+
- id: isort
|
| 12 |
+
name: isort
|
| 13 |
+
entry: isort
|
| 14 |
+
args: [--quiet]
|
| 15 |
+
language: system
|
| 16 |
+
types_or: [python, pyi]
|
| 17 |
+
|
| 18 |
+
- id: black
|
| 19 |
+
name: black
|
| 20 |
+
entry: black
|
| 21 |
+
args: [--quiet]
|
| 22 |
+
language: system
|
| 23 |
+
types_or: [python, pyi]
|
| 24 |
+
|
| 25 |
+
- id: flake8
|
| 26 |
+
name: flake8
|
| 27 |
+
entry: pflake8
|
| 28 |
+
language: system
|
| 29 |
+
types_or: [python, pyi]
|
.streamlit/config.toml
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[theme]
|
| 2 |
+
#theme primary
|
| 3 |
+
base="dark"
|
| 4 |
+
# Primary accent color for interactive elements.
|
| 5 |
+
primaryColor="f63366"
|
| 6 |
+
|
| 7 |
+
# Background color for the main content area.
|
| 8 |
+
#backgroundColor =
|
| 9 |
+
|
| 10 |
+
# Background color used for the sidebar and most interactive widgets.
|
| 11 |
+
#secondaryBackgroundColor ='grey'
|
| 12 |
+
|
| 13 |
+
# Color used for almost all text.
|
| 14 |
+
#textColor ='blue'
|
| 15 |
+
|
| 16 |
+
# Font family for all text in the app, except code blocks. One of "sans serif", "serif", or "monospace".
|
| 17 |
+
# Default: "sans serif"
|
| 18 |
+
font = "sans serif"
|
| 19 |
+
|
| 20 |
+
# [logger]
|
| 21 |
+
# level='info'
|
| 22 |
+
# messageFormat = "%(message)s"
|
| 23 |
+
#messageFormat="%(asctime)s %(message)s"
|
| 24 |
+
|
| 25 |
+
[global]
|
| 26 |
+
|
| 27 |
+
# By default, Streamlit checks if the Python watchdog module is available and, if not, prints a warning asking for you to install it. The watchdog module is not required, but highly recommended. It improves Streamlit's ability to detect changes to files in your filesystem.
|
| 28 |
+
# If you'd like to turn off this warning, set this to True.
|
| 29 |
+
# Default: false
|
| 30 |
+
disableWatchdogWarning = false
|
| 31 |
+
|
| 32 |
+
# If True, will show a warning when you run a Streamlit-enabled script via "python my_script.py".
|
| 33 |
+
# Default: true
|
| 34 |
+
showWarningOnDirectExecution = false
|
README.md
CHANGED
|
@@ -1,12 +1,13 @@
|
|
| 1 |
---
|
| 2 |
-
title: Apple
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: streamlit
|
| 7 |
-
sdk_version: 1.
|
| 8 |
app_file: app.py
|
| 9 |
-
pinned:
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Apple Retrieval
|
| 3 |
+
emoji: 🍎
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: streamlit
|
| 7 |
+
sdk_version: 1.15.2
|
| 8 |
app_file: app.py
|
| 9 |
+
pinned: true
|
| 10 |
+
fullWidth: true
|
| 11 |
---
|
| 12 |
|
| 13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,606 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
|
| 3 |
+
from functools import partial
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import streamlit as st
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from apple.envs.discrete_apple import get_apple_env
|
| 11 |
+
from apple.evaluation.render_episode import render_episode
|
| 12 |
+
from apple.logger import EpochLogger
|
| 13 |
+
from apple.models.categorical_policy import CategoricalPolicy
|
| 14 |
+
from apple.training.trainer import Trainer
|
| 15 |
+
|
| 16 |
+
QUEUE_SIZE = 1000
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def init_training(
    c: float = 1.0,
    start_x: float = 0.0,
    goal_x: float = 50.0,
    time_limit: int = 200,
    lr: float = 1e-3,
    weight1: float = 0.0,
    weight2: float = 0.0,
    pretrain: str = "phase1",
    finetune: str = "full",
    bias_in_state: bool = True,
    position_in_state: bool = False,
    apple_in_state: bool = True,
):
    """(Re)build all training objects and counters in ``st.session_state``.

    Creates the logger, the four environment variants (render env, pretrain
    env, finetune env, and the three eval envs), the policy, optimizer and
    trainer, and resets all step counters and the collected metrics list.

    Args:
        c: Magnitude of the phase-indicator element of the state vector.
        start_x: Agent starting position.
        goal_x: Apple position (distance the agent must travel).
        time_limit: Max timesteps per episode.
        lr: SGD learning rate.
        weight1, weight2: Initial policy weights passed to CategoricalPolicy.
        pretrain: Task name for phase-1 (pretraining) env.
        finetune: Task name for phase-2 (fine-tuning) env.
        bias_in_state, position_in_state, apple_in_state: Toggle which
            features the env includes in its observation vector.
    """
    st.session_state.logger = EpochLogger(verbose=False)

    env_kwargs = dict(
        start_x=start_x,
        goal_x=goal_x,
        c=c,
        time_limit=time_limit,
        bias_in_state=bias_in_state,
        position_in_state=position_in_state,
        apple_in_state=apple_in_state,
    )

    # "full" env is used only for visual rollouts, hence rgb_array rendering.
    st.session_state.env_full = get_apple_env("full", render_mode="rgb_array", **env_kwargs)
    st.session_state.env_phase1 = get_apple_env(pretrain, **env_kwargs)
    st.session_state.env_phase2 = get_apple_env(finetune, **env_kwargs)
    # Evaluation always covers all three task variants, regardless of the
    # pretrain/finetune choice.
    st.session_state.test_envs = [get_apple_env(task, **env_kwargs) for task in ["full", "phase1", "phase2"]]

    st.session_state.model = CategoricalPolicy(
        st.session_state.env_phase1.observation_space.shape[0], 1, weight1, weight2
    )
    st.session_state.optim = torch.optim.SGD(st.session_state.model.parameters(), lr=lr)
    st.session_state.trainer = Trainer(st.session_state.model, st.session_state.optim, st.session_state.logger)
    st.session_state.train_it = 0
    st.session_state.draw_it = 0
    st.session_state.total_steps = 0
    st.session_state.data = []  # accumulated metric dicts for plotting

    # NOTE(review): old gym API — reset() returning obs only; confirm env version.
    st.session_state.obs1 = st.session_state.env_phase1.reset()
    st.session_state.obs2 = st.session_state.env_phase2.reset()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def init_reset():
    """Reset the rollout iterator and the cached first frame for rendering."""
    env = st.session_state.env_full
    st.session_state.rollout_iterator = iter(
        partial(render_episode, env, st.session_state.model)()
    )
    # Initial placeholder frame shown before the rollout produces anything.
    st.session_state.last_image = {
        "x": 0,
        "obs": env.reset(),
        "action": None,
        "reward": 0,
        "done": False,
        "episode_len": 0,
        "episode_return": 0,
        "pixel_array": env.unwrapped.render(),
    }
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def select_preset():
    """Apply the hyperparameter preset chosen in the sidebar selectbox."""
    handlers = {
        0: preset_finetuning_interference,
        1: preset_train_full_from_scratch,
        2: preset_task_interference,
        3: preset_without_task_interference,
    }
    handler = handlers.get(st.session_state.pick_preset)
    if handler is not None:
        handler()
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def preset_task_interference():
    """Preset: pretrain on phase2, finetune on phase1 — gradient interference
    makes the model forget phase2."""
    settings = {
        "pick_c": 0.5,
        "pick_goal_x": 20,
        "pick_time_limit": 50,
        "pick_lr": 0.05,
        "pick_phase1task": "phase2",
        "pick_phase2task": "phase1",
        "pick_weight1": 0.0,
        "pick_weight2": 0.0,
        "pick_phase1steps": 500,
        "pick_phase2steps": 500,
    }
    for key, value in settings.items():
        setattr(st.session_state, key, value)
    need_reset()
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def preset_finetuning_interference():
    """Preset: pretrain on phase2, then fine-tune on the full environment."""
    settings = {
        "pick_c": 0.5,
        "pick_goal_x": 20,
        "pick_time_limit": 50,
        "pick_lr": 0.05,
        "pick_phase1task": "phase2",
        "pick_phase2task": "full",
        "pick_weight1": 0.0,
        "pick_weight2": 0.0,
        "pick_phase1steps": 500,
        "pick_phase2steps": 2000,
    }
    for key, value in settings.items():
        setattr(st.session_state, key, value)
    need_reset()
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def preset_without_task_interference():
    """Preset: same phase2->phase1 schedule but with c=1.0, where the weights
    can move toward a solution optimal for both tasks."""
    settings = {
        "pick_c": 1.0,
        "pick_goal_x": 20,
        "pick_time_limit": 50,
        "pick_lr": 0.05,
        "pick_phase1task": "phase2",
        "pick_phase2task": "phase1",
        "pick_weight1": 0.0,
        "pick_weight2": 0.0,
        "pick_phase1steps": 500,
        "pick_phase2steps": 500,
    }
    for key, value in settings.items():
        setattr(st.session_state, key, value)
    need_reset()
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def preset_train_full_from_scratch():
    """Preset: no pretraining (0 steps); train on the full environment only."""
    settings = {
        "pick_c": 0.5,
        "pick_goal_x": 20,
        "pick_time_limit": 50,
        "pick_lr": 0.05,
        "pick_phase1task": "phase2",
        "pick_phase2task": "full",
        "pick_weight1": 0.0,
        "pick_weight2": 0.0,
        "pick_phase1steps": 0,
        "pick_phase2steps": 2000,
    }
    for key, value in settings.items():
        setattr(st.session_state, key, value)
    need_reset()
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def empty_queue(q: asyncio.Queue) -> None:
    """Drain all pending items from *q* without awaiting.

    The original implementation snapshotted ``q.qsize()`` and looped that many
    times, which raises ``QueueEmpty`` if anything else consumes from the
    queue concurrently. Draining until ``QueueEmpty`` is raised is robust
    regardless of concurrent consumers and also handles an already-empty
    queue.

    Args:
        q: The asyncio queue to empty. Each removed item is acknowledged
           with ``task_done()`` so ``q.join()`` does not hang.
    """
    while True:
        try:
            q.get_nowait()
        except asyncio.QueueEmpty:
            return
        q.task_done()
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def reset(**kwargs):
    """Reinitialize all training/rendering state and clear control flags.

    Keyword args are forwarded verbatim to ``init_training``.
    """
    init_training(**kwargs)
    init_reset()
    for flag in ("play", "step", "render", "done", "play_pause", "need_reset"):
        setattr(st.session_state, flag, False)
    empty_queue(st.session_state.queue)
    empty_queue(st.session_state.queue_render)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def render_start():
    """Begin rendering a fresh episode rollout."""
    st.session_state["render"] = True
    st.session_state["done"] = False
    init_reset()
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def need_reset():
    """Flag that hyperparameters changed: stop play/render, request a reset."""
    st.session_state["need_reset"] = True
    st.session_state["play"] = False
    st.session_state["render"] = False
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def play_pause():
    """Toggle the play/pause state; both flags always move together."""
    now_playing = not st.session_state.play
    st.session_state.play = now_playing
    st.session_state.play_pause = now_playing
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def step():
    """Request a single training step on the next loop iteration."""
    st.session_state["step"] = True
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def plot(data_placeholder):
    """Render the six metric charts (success rates, loss, weights) into
    *data_placeholder*, laid out in two rows of three columns."""
    df = pd.DataFrame(st.session_state.data)
    if not df.empty:
        df.set_index("total_env_steps", inplace=True)
    container = data_placeholder.container()
    columns = container.columns(3)

    def view_df(names):
        # Select only the requested metric columns; fall back to a single
        # zero row so line_chart always has something to draw.
        subset = df.loc[:, df.columns.isin(names)]
        if subset.empty:
            return pd.DataFrame([dict.fromkeys(names, 0)])
        return subset

    charts = [
        ("phase1/success_rate", ["phase1/success"]),
        ("phase2/success_rate", ["phase2/success"]),
        ("full/success_rate", ["full/success"]),
        ("train/loss", ["train/loss"]),
        ("weight0", ["weight0"]),
        ("weight1", ["weight1"]),
    ]
    # Cycle through the three columns twice: first row, then second row.
    for col, (title, names) in zip(columns * 2, charts):
        col.write(title)
        col.line_chart(view_df(names))
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
async def draw(data_placeholder, queue, delay, steps, plotfrequency):
    """Consume logged metrics from *queue*; refresh charts every
    *plotfrequency* items until *steps* items are drawn or play stops."""
    while st.session_state.draw_it < steps and (st.session_state.play or st.session_state.step):
        await asyncio.sleep(delay)
        latest = await queue.get()
        st.session_state.draw_it += 1
        if st.session_state.draw_it % plotfrequency == 0:
            st.session_state.data.append(latest)
            plot(data_placeholder)
        # A single-step request is consumed by exactly one iteration.
        st.session_state.step = False
        queue.task_done()
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
async def train(queue, delay, steps, obs, env, num_eval_episodes, plotfrequency):
    """Run the training loop on *env* until *steps* iterations, pushing
    metric dicts (or 0 on non-logging steps) onto *queue* for `draw`.

    Runs cooperatively with the drawing coroutine via asyncio: the sleep
    yields control so the consumer can keep up.
    """
    while (st.session_state.play or st.session_state.step) and st.session_state.train_it < steps:
        _ = await asyncio.sleep(delay)
        st.session_state.train_it += 1
        st.session_state.total_steps += 1

        # Forward pass and action sampling happen BEFORE trainer.update,
        # which is given the same `output` — order is load-bearing here.
        output = st.session_state.model(obs)
        action, log_prob = st.session_state.model.sample(output)
        st.session_state.trainer.update(
            env, output, st.session_state.model, st.session_state.optim, st.session_state.logger
        )

        # Old gym-style 4-tuple step API (obs, reward, done, info).
        obs, reward, done, info = env.step(action)

        if done:
            obs = env.reset()

        # Periodically evaluate on all test envs and emit a metrics dict;
        # every other step pushes a 0 sentinel so `draw` stays in lockstep.
        if st.session_state.train_it % plotfrequency == 0:
            st.session_state.trainer.test_agent(
                st.session_state.model, st.session_state.logger, st.session_state.test_envs, num_eval_episodes
            )
            data = st.session_state.trainer.log(
                st.session_state.logger, st.session_state.train_it, st.session_state.model
            )
        else:
            data = 0

        _ = await queue.put(data)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
async def produce_images(queue, delay):
    """Pull frames from the rollout iterator and push them onto *queue*
    until the episode finishes or rendering is cancelled."""
    while st.session_state.render and not st.session_state.done:
        await asyncio.sleep(delay)
        frame = next(st.session_state.rollout_iterator)
        st.session_state.done = frame["done"]
        await queue.put(frame)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def show_image(data, image_placeholder):
    """Display one rollout frame plus a short status caption."""
    container = image_placeholder.container()
    container.image(data["pixel_array"])
    caption = (
        f"agent position: {data['x']} \n"
        f"timestep: {data['episode_len']} \n"
        f"episode return: {data['episode_return']} \n"
    )
    container.text(caption)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
async def consume_images(image_placeholder, queue, delay):
    """Pop frames off *queue*, cache the latest one, and show it, until
    rendering stops or the episode is done."""
    while st.session_state.render and not st.session_state.done:
        await asyncio.sleep(delay)
        frame = await queue.get()
        st.session_state.last_image = frame
        show_image(frame, image_placeholder)
        queue.task_done()
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
async def run_app(
    data_placeholder,
    queue,
    produce_delay,
    consume_delay,
    phase1steps,
    phase2steps,
    plotfrequency,
    num_eval_episodes,
    image_placeholder,
    queue_render,
    render_produce_delay,
    render_consume_delay,
):
    """Orchestrate the app's three async phases, strictly in order:
    1) render one rollout episode (producer + consumer over queue_render);
    2) pretraining on env_phase1 (train + draw over queue);
    3) fine-tuning on env_phase2, with the step budget extended by
       phase1steps so the counters continue rather than restart.
    """
    # Phase 0: episode rendering (no-op unless st.session_state.render is set).
    _ = await asyncio.gather(
        produce_images(queue_render, render_produce_delay),
        consume_images(image_placeholder, queue_render, render_consume_delay),
    )

    st.session_state.render = False
    st.session_state.done = False

    # Drop any frames the consumer did not get to.
    empty_queue(queue_render)

    # Phase 1: pretraining up to phase1steps.
    _ = await asyncio.gather(
        train(
            queue,
            produce_delay,
            phase1steps,
            st.session_state.obs1,
            st.session_state.env_phase1,
            num_eval_episodes,
            plotfrequency,
        ),
        draw(data_placeholder, queue, consume_delay, phase1steps, plotfrequency),
    )

    # Phase 2: fine-tuning; limits are cumulative (phase1steps + phase2steps)
    # because train_it/draw_it keep counting from phase 1.
    _ = await asyncio.gather(
        train(
            queue,
            produce_delay,
            phase1steps + phase2steps,
            st.session_state.obs2,
            st.session_state.env_phase2,
            num_eval_episodes,
            plotfrequency,
        ),
        draw(data_placeholder, queue, consume_delay, phase1steps + phase2steps, plotfrequency),
    )
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
##### ACTUAL APP
|
| 343 |
+
|
| 344 |
+
if __name__ == "__main__":
|
| 345 |
+
st.set_page_config(
|
| 346 |
+
layout="wide",
|
| 347 |
+
initial_sidebar_state="auto",
|
| 348 |
+
page_title="Apple Retrieval",
|
| 349 |
+
page_icon=None,
|
| 350 |
+
)
|
| 351 |
+
st.title("ON THE ROLE OF FORGETTING IN FINE-TUNING REINFORCEMENT LEARNING MODELS")
|
| 352 |
+
st.header("Toy example of forgetting: AppleRetrieval")
|
| 353 |
+
|
| 354 |
+
col1, col2, col3 = st.sidebar.columns(3)
|
| 355 |
+
|
| 356 |
+
options = (
|
| 357 |
+
"phase2 full interference",
|
| 358 |
+
"full from scratch",
|
| 359 |
+
"phase2 phase1 forgetting",
|
| 360 |
+
"phase2 phase1 optimal solution",
|
| 361 |
+
)
|
| 362 |
+
st.sidebar.selectbox(
|
| 363 |
+
"parameter presets",
|
| 364 |
+
range(len(options)),
|
| 365 |
+
index=0,
|
| 366 |
+
format_func=lambda x: options[x],
|
| 367 |
+
on_change=select_preset,
|
| 368 |
+
key="pick_preset",
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
pick_container = st.sidebar.container()
|
| 372 |
+
c = pick_container.number_input("c", value=0.5, on_change=need_reset, key="pick_c")
|
| 373 |
+
goal_x = pick_container.number_input("distance to apple", value=20, on_change=need_reset, key="pick_goal_x")
|
| 374 |
+
time_limit = pick_container.number_input("time limit", value=50, on_change=need_reset, key="pick_time_limit")
|
| 375 |
+
lr = pick_container.selectbox(
|
| 376 |
+
"Learning rate",
|
| 377 |
+
np.array([0.00001, 0.0001, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0, 10.0]),
|
| 378 |
+
5,
|
| 379 |
+
on_change=need_reset,
|
| 380 |
+
key="pick_lr",
|
| 381 |
+
)
|
| 382 |
+
phase1task = pick_container.selectbox(
|
| 383 |
+
"Pretraining task", ("full", "phase1", "phase2"), 2, on_change=need_reset, key="pick_phase1task"
|
| 384 |
+
)
|
| 385 |
+
phase2task = pick_container.selectbox(
|
| 386 |
+
"Finetuning task", ("full", "phase1", "phase2"), 0, on_change=need_reset, key="pick_phase2task"
|
| 387 |
+
)
|
| 388 |
+
# weight1 = pick_container.number_input("init weight1", value=0, on_change=need_reset, key="pick_weight1")
|
| 389 |
+
# weight2 = pick_container.number_input("init weight2", value=0, on_change=need_reset, key="pick_weight2")
|
| 390 |
+
|
| 391 |
+
if "event_loop" not in st.session_state:
|
| 392 |
+
st.session_state.loop = asyncio.new_event_loop()
|
| 393 |
+
asyncio.set_event_loop(st.session_state.loop)
|
| 394 |
+
|
| 395 |
+
if "queue" not in st.session_state:
|
| 396 |
+
st.session_state.queue = asyncio.Queue(QUEUE_SIZE)
|
| 397 |
+
if "queue_render" not in st.session_state:
|
| 398 |
+
st.session_state.queue_render = asyncio.Queue(QUEUE_SIZE)
|
| 399 |
+
if "play" not in st.session_state:
|
| 400 |
+
st.session_state.play = False
|
| 401 |
+
if "step" not in st.session_state:
|
| 402 |
+
st.session_state.step = False
|
| 403 |
+
if "render" not in st.session_state:
|
| 404 |
+
st.session_state.render = False
|
| 405 |
+
|
| 406 |
+
reset_button = partial(
|
| 407 |
+
reset,
|
| 408 |
+
c=c,
|
| 409 |
+
start_x=0,
|
| 410 |
+
goal_x=goal_x,
|
| 411 |
+
time_limit=time_limit,
|
| 412 |
+
lr=lr,
|
| 413 |
+
# weight1=weight1,
|
| 414 |
+
# weight2=weight2,
|
| 415 |
+
# weight1=0,
|
| 416 |
+
# weight2=10, # solves the environment
|
| 417 |
+
pretrain=phase1task,
|
| 418 |
+
finetune=phase2task,
|
| 419 |
+
)
|
| 420 |
+
col1.button("Reset", on_click=reset_button, type="primary")
|
| 421 |
+
|
| 422 |
+
if "logger" not in st.session_state or st.session_state.need_reset:
|
| 423 |
+
reset_button()
|
| 424 |
+
|
| 425 |
+
myKey = "play_pause"
|
| 426 |
+
if myKey not in st.session_state:
|
| 427 |
+
st.session_state[myKey] = False
|
| 428 |
+
|
| 429 |
+
if st.session_state[myKey]:
|
| 430 |
+
myBtn = col2.button("Pause", on_click=play_pause, type="primary")
|
| 431 |
+
else:
|
| 432 |
+
myBtn = col2.button("Play", on_click=play_pause, type="primary")
|
| 433 |
+
|
| 434 |
+
col3.button("Step", on_click=step, type="primary")
|
| 435 |
+
|
| 436 |
+
st.header("Summary")
|
| 437 |
+
st.write(
|
| 438 |
+
"""
|
| 439 |
+
Run training on the "phase2 full interference" setting to see an example of forgetting in fine-tuning RL models.
|
| 440 |
+
A model is pre-trained on a part of the environment called Phase 2 for 500 steps,
|
| 441 |
+
and then it is fine-tuned on the whole environment for another 2000 steps.
|
| 442 |
+
However, it forgets how to perform on Phase 2 during fine-tuning before it even gets there.
|
| 443 |
+
We highlight this as an important problem in fine-tuning RL models.
|
| 444 |
+
We invite you to play around with the hyperparameters and find out more about the forgetting phenomenon.
|
| 445 |
+
"""
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
st.markdown(
|
| 449 |
+
"""
|
| 450 |
+
By comparing the results of these four presets, you can gain a deeper understanding of how different training configurations can impact the final performance of a neural network.
|
| 451 |
+
* `phase2 full interference` - already mentioned setting where we pretrain on phase 2 and finetune on whole environment.
|
| 452 |
+
* `full from scratch` - "standard" train on whole environment for 2000 steps for comparison.
|
| 453 |
+
* `phase2 phase1 forgetting` - a setting where we pretrain on phase 2 and finetune on phase 1. Here gradient interference between tasks causes network to forget how to solve phase 2.
|
| 454 |
+
* `phase2 phase1 optimal solution` - a setting where we pretrain on phase 2 and finetune on phase 1. Here gradient interference between tasks "moves" network weights to the optimal solution for both tasks.
|
| 455 |
+
"""
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
data_placeholder = st.empty()
|
| 459 |
+
render_placeholder = st.empty()
|
| 460 |
+
|
| 461 |
+
phase1steps = st.sidebar.slider("Pretraining Steps", 0, 2000, 500, 10, key="pick_phase1steps")
|
| 462 |
+
phase2steps = st.sidebar.slider("Finetuning Steps", 0, 2000, 2000, 10, key="pick_phase2steps")
|
| 463 |
+
plotfrequency = st.sidebar.number_input("Log frequency ", min_value=1, value=10, step=1)
|
| 464 |
+
num_eval_episodes = st.sidebar.number_input("Number of evaluation episodes", min_value=1, value=10, step=1)
|
| 465 |
+
|
| 466 |
+
if not st.session_state.play and len(st.session_state.data) > 0:
|
| 467 |
+
plot(data_placeholder)
|
| 468 |
+
|
| 469 |
+
st.write(
|
| 470 |
+
"""
|
| 471 |
+
Figure 1: Plots (a), (b), (c) show the performance of the agent on three different variants of the environment.
|
| 472 |
+
Phase1 (a) the agent's goal is to reach the apple.
|
| 473 |
+
Phase2 (b) the agent's has to go back home by returning to x = 0.
|
| 474 |
+
Full (c) combination of both phases of the environment.
|
| 475 |
+
(d) shows the training loss.
|
| 476 |
+
The state vector has only two values, so we can view the network's weights (e), (f).
|
| 477 |
+
"""
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
st.header("Visualize agent behavior")
|
| 481 |
+
st.button("Rollout one episode", on_click=render_start, type="primary")
|
| 482 |
+
|
| 483 |
+
c1, c2 = st.columns(2)
|
| 484 |
+
image_placeholder = c1.empty()
|
| 485 |
+
|
| 486 |
+
render_produce_delay = 1 / 3
|
| 487 |
+
render_consume_delay = 1 / 3
|
| 488 |
+
|
| 489 |
+
st.header("About the environment")
|
| 490 |
+
|
| 491 |
+
st.write(
|
| 492 |
+
"""This study forms a preliminary investigation into the problem of forgetting in fine-tuning RL models.
|
| 493 |
+
We show that fine-tuning a pre-trained model on compositional RL problems might result in a rapid
|
| 494 |
+
deterioration of the performance of the pre-trained model if the relevant data is not available at the
|
| 495 |
+
beginning of the training.
|
| 496 |
+
This phenomenon is known as catastrophic forgetting.
|
| 497 |
+
In this demo we show how it can occur in simple toy situations, but it might also occur in more realistic problems (e.g. SAC with MLPs on a compositional robotic environment).
|
| 498 |
+
In our [WIP] paper we showed that applying CL methods significantly limits forgetting and allows for efficient transfer.
|
| 499 |
+
"""
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
st.write(
|
| 503 |
+
"""The AppleRetrieval environment is a toy example to demonstrate the issue of interference in reinforcement learning.
|
| 504 |
+
The goal of the agent is to retrieve an apple from position x = M and return home to x = 0 within a set number of steps T.
|
| 505 |
+
The state of the agent is represented by a vector s, which has two elements: s = [1, -c] in phase 1 and s = [1, c] in phase 2.
|
| 506 |
+
The first element is a constant and the second element represents the information about the current phase.
|
| 507 |
+
The optimal policy is to go right in phase 1 and go left in phase 2.
|
| 508 |
+
The cause of interference can be identified by checking the reliance of the policy on either s1 or s2.
|
| 509 |
+
If the model mostly relies on s2, interference will be limited, but if it relies on s1,
|
| 510 |
+
interference will occur as its value is the same in both phases.
|
| 511 |
+
The magnitude of s1 and s2 can be adjusted by changing the c parameter to guide the model towards focusing on either one.
|
| 512 |
+
This simple toy environment shows that the issue of interference can be fundamental to reinforcement learning."""
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
cc1, cc2 = st.columns([1, 2])
|
| 516 |
+
cc1.image("assets/apple_env.png")
|
| 517 |
+
|
| 518 |
+
st.header("Training algorithm")
|
| 519 |
+
st.write(
|
| 520 |
+
"""For practical reasons (time and stability) we don't use REINFORCE in this DEMO to illustrate the training dynamics of the environment.
|
| 521 |
+
Instead, we train the model by minimizing the negative log likelihood of target actions (move right or left).
|
| 522 |
+
We train the model in each step of the environment and sample actions from Categorical distribution taken from model's output.
|
| 523 |
+
"""
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
st.markdown(
|
| 527 |
+
"""
|
| 528 |
+
Pseudocode of the train loop:
|
| 529 |
+
```python
|
| 530 |
+
obs = env.reset()
|
| 531 |
+
for timestep in range(steps):
|
| 532 |
+
probs = model(obs)
|
| 533 |
+
dist = Categorical(probs)
|
| 534 |
+
action = dist.sample()
|
| 535 |
+
target_action = env.get_target_action()
|
| 536 |
+
loss = -dist.log_prob(target_action)
|
| 537 |
+
|
| 538 |
+
optim.zero_grad()
|
| 539 |
+
loss.backward()
|
| 540 |
+
optim.step()
|
| 541 |
+
|
| 542 |
+
obs, reward, done, info = env.step(action)
|
| 543 |
+
if done:
|
| 544 |
+
obs = env.reset()
|
| 545 |
+
"""
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
st.header("What do all the hyperparameters mean?")
|
| 549 |
+
|
| 550 |
+
st.markdown(
|
| 551 |
+
"""
|
| 552 |
+
* **parameter presets** - few sets of pre-defined hyperparameters that can be used as a starting point for a specific experiment.
|
| 553 |
+
* **c** - second element of the state vector; decreasing this value will result in stronger forgetting.
|
| 554 |
+
* **distance to apple** - refers to how far the agent needs to travel to the right before it encounters the apple.
|
| 555 |
+
* **time limit** - maximum amount of timesteps that the agent is allowed to interact with the environment before episode ends.
|
| 556 |
+
* **learning rate** - hyperparameter that determines the step size taken by the learning algorithm in each iteration.
|
| 557 |
+
* **pretraining and finetuning task** - define the environment the model will be trained on each stage.
|
| 558 |
+
* **pretraining and finetuning steps** - define how long model will be trained on each stage.
|
| 559 |
+
* **init weights** - initial values assigned to the model's parameters before training begins.
|
| 560 |
+
* **log frequency** - refers to frequency of logging the metrics and evaluating the agent.
|
| 561 |
+
* **number of evaluation episodes** - number of rollouts during testing of the agent.
|
| 562 |
+
"""
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
st.header("Limitations & Conclusions")
|
| 566 |
+
|
| 567 |
+
st.write(
|
| 568 |
+
"""
|
| 569 |
+
At the same time, this study, due to its preliminary nature, has numerous limitations which we
|
| 570 |
+
hope to address in future work. We only considered a fairly strict formulation of the forgetting
|
| 571 |
+
scenario where we assumed that the pre-trained model works perfectly on tasks that appear later in
|
| 572 |
+
the fine-tuning. In practice, one should also consider the case when even though there are differences
|
| 573 |
+
between the pre-training and fine-tuning tasks, transfer is still possible.
|
| 574 |
+
"""
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
st.write(
|
| 578 |
+
"""
|
| 579 |
+
At the same time, even given
|
| 580 |
+
these limitations, we see forgetting as an important problem to be solved and hope that addressing
|
| 581 |
+
these issues in the future might help with building and fine-tuning better foundation models in RL.
|
| 582 |
+
"""
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
produce_delay = 1 / 1000
|
| 586 |
+
consume_delay = 1 / 1000
|
| 587 |
+
|
| 588 |
+
plot(data_placeholder)
|
| 589 |
+
show_image(st.session_state.last_image, image_placeholder)
|
| 590 |
+
|
| 591 |
+
asyncio.run(
|
| 592 |
+
run_app(
|
| 593 |
+
data_placeholder,
|
| 594 |
+
st.session_state.queue,
|
| 595 |
+
produce_delay,
|
| 596 |
+
consume_delay,
|
| 597 |
+
phase1steps,
|
| 598 |
+
phase2steps,
|
| 599 |
+
plotfrequency,
|
| 600 |
+
num_eval_episodes,
|
| 601 |
+
image_placeholder,
|
| 602 |
+
st.session_state.queue_render,
|
| 603 |
+
render_produce_delay,
|
| 604 |
+
render_consume_delay,
|
| 605 |
+
)
|
| 606 |
+
)
|
apple/envs/discrete_apple.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import closing
|
| 2 |
+
from io import StringIO
|
| 3 |
+
from os import path
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import gym
|
| 7 |
+
import gym.spaces
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from gym.error import DependencyNotInstalled
|
| 11 |
+
from gym.spaces import Box
|
| 12 |
+
from gym.utils import colorize
|
| 13 |
+
from gym.wrappers import TimeLimit
|
| 14 |
+
|
| 15 |
+
from apple.wrappers import SuccessCounter
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_apple_env(task, time_limit=20, **kwargs):
    """Build an AppleRetrieval environment variant, wrapped and named.

    Args:
        task: one of ``"full"`` (both phases), ``"phase1"`` or ``"phase2"``
            (single-phase variants, which get half the time budget).
        time_limit: step budget for the full task.
        **kwargs: forwarded to the underlying environment constructor.

    Returns:
        The environment wrapped in ``TimeLimit`` and ``SuccessCounter``,
        with a ``name`` attribute set to ``task``.

    Raises:
        NotImplementedError: for an unrecognized ``task``.
    """
    if task == "full":
        base_env = AppleEnv(**kwargs)
        limit = time_limit
    elif task == "phase1":
        base_env = ApplePhase0Env(**kwargs)
        limit = time_limit // 2  # single phase gets half the step budget
    elif task == "phase2":
        base_env = ApplePhase1Env(**kwargs)
        limit = time_limit // 2
    else:
        raise NotImplementedError
    wrapped = SuccessCounter(TimeLimit(base_env, limit))
    wrapped.name = task
    return wrapped
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class AppleEnv(gym.Env):
    """Two-phase "AppleRetrieval" task on a discrete 1-D line.

    In phase 0 the agent starts at ``start_x`` and must walk right to
    ``goal_x``; once it arrives, the environment flips to phase 1 and the
    agent must walk back left to ``start_x``.  Actions are ``Discrete(2)``:
    0 = step left, 1 = step right.  The only signal distinguishing the two
    phases in the observation is the sign of the ``c`` entry (negative in
    phase 0, positive in phase 1).
    """

    metadata = {
        "render_modes": ["human", "ansi", "rgb_array"],
        "render_fps": 4,
    }

    def __init__(
        self,
        start_x: int,
        goal_x: int,
        c: float,
        reward_value: float = 1.0,
        success_value: float = 1.0,
        bias_in_state: bool = True,
        position_in_state: bool = False,
        apple_in_state: bool = True,
        render_mode: Optional[str] = None,
    ):
        """Configure the task geometry, rewards, and observation layout.

        Args:
            start_x: home position (start of phase 0, target of phase 1).
            goal_x: apple position (target of phase 0).
            c: magnitude of the phase indicator placed in the observation.
            reward_value: per-step reward; positive when the agent takes the
                phase's target action, negative otherwise (see ``reward``).
            success_value: reward emitted on the terminating success step.
            bias_in_state: include a constant 1 in the observation.
            position_in_state: include the agent's x coordinate in the observation.
            apple_in_state: include the signed phase indicator (±c) in the observation.
            render_mode: one of ``metadata["render_modes"]`` or ``None``.
        """
        self.start_x = start_x
        self.goal_x = goal_x
        self.c = c
        self.reward_value = reward_value
        self.success_value = success_value

        # Mutable episode state; ``change_phase``/``init_pos``/
        # ``success_when_finish_phase`` are overridden by the single-phase
        # subclasses below.
        self.x = start_x
        self.phase = 0
        self.delta = 1
        self.change_phase = True
        self.init_pos = self.start_x
        self.success_when_finish_phase = 1

        self.bias_in_state = bias_in_state
        self.position_in_state = position_in_state
        self.apple_in_state = apple_in_state

        # Derive the observation box from a sample phase-0 state.  With
        # apple_in_state the last entry ranges over [-c, c] (assumes c >= 0
        # — TODO confirm); the remaining entries have low == high.
        example_state = self.state()
        mult = np.ones_like(example_state)
        if self.apple_in_state:
            mult[-1] *= -1
        self.observation_space = Box(low=example_state, high=example_state * mult)
        self.action_space = gym.spaces.Discrete(2)
        self.timestep = 0

        self.render_mode = render_mode
        # Width (in cells) of the window rendered around the agent.
        self.scope_size = 15

        # Rendering geometry: 64x64-pixel cells over the gui_canvas grid.
        self.nrow, self.ncol = nrow, ncol = self.gui_canvas().shape
        self.window_size = (64 * ncol, 64 * nrow)
        self.cell_size = (
            self.window_size[0] // self.ncol,
            self.window_size[1] // self.nrow,
        )
        self.window_surface = None

        # Lazily-loaded pygame assets (populated on first _render_gui call).
        self.clock = None
        self.empty_img = None
        self.ground_img = None
        self.underground_img = None
        self.elf_images = None
        self.home_images = None
        self.goal_img = None
        self.start_img = None
        self.stool_img = None

    def reset(self):
        """Reset position, phase (for the two-phase task), and the step counter.

        Returns the initial observation vector.
        """
        self.x = self.init_pos
        if self.change_phase:
            self.phase = 0

        self.timestep = 0
        state = self.state()
        self.lastaction = None
        return state

    def validate_action(self, action):
        """Assert that ``action`` belongs to the action space.

        NOTE(review): ``self.state`` is a bound method and therefore never
        ``None``, so the second assert is always true — it does not actually
        verify that ``reset`` was called.
        """
        err_msg = f"{action!r} ({type(action)}) invalid"
        assert self.action_space.contains(action), err_msg
        assert self.state is not None, "Call reset before using step method."

    def move(self, action):
        """Shift the agent by ±delta: action 1 moves right, anything else left."""
        delta = self.delta if action == 1 else -self.delta
        self.x += delta

    def state(self):
        """Build the observation vector from the enabled components.

        Layout (in order, for each enabled flag): bias constant 1, agent
        position ``x``, phase indicator (−c in phase 0, +c in phase 1).
        """
        assert self.bias_in_state or self.position_in_state or self.apple_in_state

        if self.phase == 0:
            c = -self.c
        elif self.phase == 1:
            c = self.c
        else:
            raise NotImplementedError

        state = []
        if self.bias_in_state:
            state.append(1)
        if self.position_in_state:
            state.append(self.x)
        if self.apple_in_state:
            state.append(c)

        return np.array(state)

    def desc(self):
        """Return a 1-D char array describing the window centred on the agent.

        Markers: "S" home, "G" goal with apple (phase 0), "D" goal without
        apple (phase 1), "." empty ground.
        """
        s = self.scope_size // 2

        desc = list("." * self.scope_size)
        desc = np.asarray(desc, dtype="c")

        # Positions are shifted so the agent sits at index s (window centre).
        start_relative_position = self.start_x - self.x + s
        if 0 <= start_relative_position <= self.scope_size - 1:
            desc[start_relative_position] = "S"

        goal_relative_position = self.goal_x - self.x + s
        if 0 <= goal_relative_position <= self.scope_size - 1 and self.phase == 0:
            desc[goal_relative_position] = "G"

        if 0 <= goal_relative_position <= self.scope_size - 1 and self.phase == 1:
            desc[goal_relative_position] = "D"

        return desc

    def text_canvas(self):
        """Build the 2-row char canvas for ANSI rendering.

        Row 0 carries the ``desc`` markers spaced 3 columns apart; row 1
        carries x-axis tick labels at every multiple of 5.
        """
        desc = self.desc()
        canvas = np.ones((2, len(desc) * 3 + 2), dtype="c")
        canvas[:] = "\x20"

        for i, d in zip(range(2, len(canvas[0]), 3), desc):
            canvas[0][i] = d

        axis = np.arange(len(desc)) - len(desc) // 2 + self.x
        for i, d in zip(range(2, len(canvas[0]), 3), axis):
            if d % 5 == 0:
                s = str(d)

                # Write the (possibly multi-digit) label roughly centred
                # under its tick column, last digit first.
                c = len(s) // 2
                for j, char in zip(range(len(s)), reversed(s)):
                    canvas[1][i - j + c] = char
        return canvas

    def gui_canvas(self):
        """Build the 4-row char grid for GUI rendering.

        Rows: two sky rows ("~"), the ``desc`` ground row, and one
        underground row ("#").
        """
        desc = self.desc()
        upper_canvas = np.ones(len(desc), dtype="c")
        upper_canvas[:] = "~"
        lower_canvas = np.ones(len(desc), dtype="c")
        lower_canvas[:] = "#"
        canvas = np.stack([upper_canvas, upper_canvas, desc, lower_canvas])

        return canvas

    def reward(self, action):
        """Return +reward_value for the phase's target action, else -reward_value."""
        return self.reward_value if action == self.get_target_action() else -self.reward_value

    def get_target_action(self):
        """Optimal action for the current phase: right (1) in phase 0, left (0) in phase 1."""
        return 1 if self.phase == 0 else 0

    def step(self, action):
        """Advance one timestep; returns (state, reward, done, info).

        The per-step reward is computed from the action *before* moving;
        on the success step it is overwritten by ``success_value``.
        """
        self.timestep += 1
        self.validate_action(action)

        done = False
        info = {"success": False}

        reward = self.reward(action)

        self.move(action)

        # Phase-completion checks use the post-move position.
        finish_phase0 = self.phase == 0 and self.x >= self.goal_x
        finish_phase1 = self.phase == 1 and self.x <= self.start_x

        if self.change_phase:
            if finish_phase0:
                self.phase = 1

        if (self.success_when_finish_phase == 0 and finish_phase0) or (
            self.success_when_finish_phase == 1 and finish_phase1
        ):
            done = True
            info["success"] = True
            reward = self.success_value

        state = self.state()

        self.lastaction = action
        if self.render_mode == "human":
            self.render()

        return state, reward, done, info

    def render(self):
        """Dispatch to text or GUI rendering based on ``render_mode``.

        Warns and returns ``None`` when no render mode was configured.
        """
        if self.render_mode is None:
            assert self.spec is not None
            gym.logger.warn(
                "You are calling render method without specifying any render mode. "
                "You can specify the render_mode at initialization, "
                f'e.g. gym.make("{self.spec.id}", render_mode="rgb_array")'
            )
            return

        if self.render_mode == "ansi":
            return self._render_text()
        else:  # self.render_mode in {"human", "rgb_array"}:
            return self._render_gui(self.render_mode)

    def _render_gui(self, mode):
        """Draw the scene with pygame; returns an RGB array in "rgb_array" mode.

        pygame is imported lazily so the environment works without it as
        long as GUI rendering is never requested.
        """
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled("pygame is not installed, run `pip install pygame`") from e

        if self.window_surface is None:
            pygame.init()

            if mode == "human":
                pygame.display.init()
                pygame.display.set_caption("Apple Retrieval")
                self.window_surface = pygame.display.set_mode(self.window_size)
            elif mode == "rgb_array":
                self.window_surface = pygame.Surface(self.window_size)

        assert self.window_surface is not None, "Something went wrong with pygame. This should never happen."

        # Lazy one-time asset loading, each scaled to the cell size.
        # NOTE(review): the home sprites are guarded by ``start_img`` but
        # assigned to ``home_images``; since ``start_img`` is never set,
        # they are reloaded on every call — verify whether intended.
        if self.clock is None:
            self.clock = pygame.time.Clock()
        if self.empty_img is None:
            file_name = path.join(path.dirname(__file__), "img/white.png")
            self.empty_img = pygame.transform.scale(pygame.image.load(file_name), self.cell_size)
        if self.ground_img is None:
            file_name = path.join(path.dirname(__file__), "img/part_grass.png")
            self.ground_img = pygame.transform.scale(pygame.image.load(file_name), self.cell_size)
        if self.underground_img is None:
            file_name = path.join(path.dirname(__file__), "img/g2.png")
            self.underground_img = pygame.transform.scale(pygame.image.load(file_name), self.cell_size)
        if self.goal_img is None:
            file_name = path.join(path.dirname(__file__), "img/apple.png")
            self.goal_img = pygame.transform.scale(pygame.image.load(file_name), self.cell_size)
        if self.stool_img is None:
            file_name = path.join(path.dirname(__file__), "img/stool.png")
            self.stool_img = pygame.transform.scale(pygame.image.load(file_name), self.cell_size)
        if self.start_img is None:
            homes = [
                path.join(path.dirname(__file__), "img/home00.png"),
                path.join(path.dirname(__file__), "img/home01.png"),
                path.join(path.dirname(__file__), "img/home02.png"),
                path.join(path.dirname(__file__), "img/home10.png"),
                path.join(path.dirname(__file__), "img/home11.png"),
                path.join(path.dirname(__file__), "img/home12.png"),
            ]
            self.home_images = [pygame.transform.scale(pygame.image.load(f_name), self.cell_size) for f_name in homes]
        if self.elf_images is None:
            elfs = [
                path.join(path.dirname(__file__), "img/elf_left.png"),
                path.join(path.dirname(__file__), "img/elf_right.png"),
                path.join(path.dirname(__file__), "img/elf_down.png"),
            ]
            self.elf_images = [pygame.transform.scale(pygame.image.load(f_name), self.cell_size) for f_name in elfs]

        desc = self.gui_canvas().tolist()

        # ``cache`` defers one home tile so it is drawn over a later cell
        # (the 2x3 home sprite spans neighbouring cells).
        cache = []
        assert isinstance(desc, list), f"desc should be a list or an array, got {desc}"
        for y in range(self.nrow):
            for x in range(self.ncol):
                pos = (x * self.cell_size[0], y * self.cell_size[1])

                self.window_surface.blit(self.empty_img, pos)

                # Background layer per cell type.
                if desc[y][x] == b"~":
                    self.window_surface.blit(self.empty_img, pos)
                elif desc[y][x] == b"#":
                    self.window_surface.blit(self.underground_img, pos)
                else:
                    self.window_surface.blit(self.ground_img, pos)
                # if y == self.nrow - 1:

                if len(cache) > 0:
                    cache_img, cache_pos = cache.pop()
                    self.window_surface.blit(cache_img, cache_pos)

                # Foreground layer: apple-on-stool, empty stool, or home.
                if desc[y][x] == b"G":
                    self.window_surface.blit(self.stool_img, pos)
                    self.window_surface.blit(self.goal_img, pos)
                elif desc[y][x] == b"D":
                    self.window_surface.blit(self.stool_img, pos)
                elif desc[y][x] == b"S":
                    for h in range(len(self.home_images)):
                        i = h // 3
                        j = h % 3

                        home_img = self.home_images[i * 3 + j]
                        home_pos = ((x - 1 + j) * self.cell_size[0], (y - 1 + i) * self.cell_size[1])
                        if h == len(self.home_images) - 1:
                            cache.append((home_img, home_pos))
                        else:
                            self.window_surface.blit(home_img, home_pos)

        # paint the elf (always at the window centre, on the ground row)
        # bot_row, bot_col = self.s // self.ncol, self.s % self.ncol
        bot_col = self.scope_size // 2
        bot_row = 2
        cell_rect = (bot_col * self.cell_size[0], bot_row * self.cell_size[1])
        # Sprite index 2 ("down") is used before the first action is taken.
        last_action = self.lastaction if self.lastaction is not None else 2
        elf_img = self.elf_images[last_action]

        self.window_surface.blit(elf_img, cell_rect)

        # font = pygame.font.SysFont(None, 20)
        # img = font.render(f"agent position = {self.x}", True, "black")
        # self.window_surface.blit(img, (5, 5))
        # img = font.render(f"timestep = {self.timestep}", True, "black")
        # self.window_surface.blit(img, (5, 25))

        if mode == "human":
            pygame.event.pump()
            pygame.display.update()
            self.clock.tick(self.metadata["render_fps"])
        elif mode == "rgb_array":
            return np.transpose(np.array(pygame.surfarray.pixels3d(self.window_surface)), axes=(1, 0, 2))

    def _render_text(self):
        """Return the ANSI text rendering; the agent's cell is highlighted red."""
        desc = self.text_canvas()
        outfile = StringIO()

        # The agent always occupies the centre marker column of row 0.
        row, col = 0, (self.scope_size // 2) * 3 + 2
        desc = [[c.decode("utf-8") for c in line] for line in desc]
        desc[row][col] = colorize(desc[row][col], "red", highlight=True)

        outfile.write("\n")
        outfile.write("\n".join("".join(line) for line in desc) + "\n")

        with closing(outfile):
            return outfile.getvalue()
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class ApplePhase0Env(AppleEnv):
    """AppleRetrieval restricted to phase 0: start at home, reach the apple."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Pin the task to phase 0 and declare success when that phase ends.
        self.phase = 0
        self.change_phase = False
        self.success_when_finish_phase = 0
        self.init_pos = self.start_x
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class ApplePhase1Env(AppleEnv):
    """AppleRetrieval restricted to phase 1: start at the apple, return home."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Pin the task to phase 1 and declare success when that phase ends.
        self.phase = 1
        self.change_phase = False
        self.success_when_finish_phase = 1
        self.init_pos = self.goal_x
|
apple/envs/img/apple.png
ADDED
|
apple/envs/img/elf_down.png
ADDED
|
apple/envs/img/elf_left.png
ADDED
|
apple/envs/img/elf_right.png
ADDED
|
apple/envs/img/g1.png
ADDED
|
apple/envs/img/g2.png
ADDED
|
apple/envs/img/g3.png
ADDED
|
apple/envs/img/grass.jpg
ADDED
|
apple/envs/img/home.png
ADDED
|
apple/envs/img/home00.png
ADDED
|
apple/envs/img/home01.png
ADDED
|
apple/envs/img/home02.png
ADDED
|
apple/envs/img/home10.png
ADDED
|
apple/envs/img/home11.png
ADDED
|
apple/envs/img/home12.png
ADDED
|
apple/envs/img/home2.png
ADDED
|
apple/envs/img/home2_with_apples.png
ADDED
|
apple/envs/img/home_grass.png
ADDED
|
apple/envs/img/part_grass.png
ADDED
|
apple/envs/img/stool.png
ADDED
|
apple/envs/img/textures.jpg
ADDED
|
apple/envs/img/white.png
ADDED
|
apple/evaluation/render_episode.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def render_episode(env, model):
    """Roll out a single episode, yielding one diagnostics dict per step.

    Each yielded dict carries the pre-step observation, the action taken,
    the resulting reward/done flag, running episode statistics, the agent's
    current ``x`` position, and a freshly rendered frame.
    """
    obs = env.reset()
    done = False
    episode_return = 0
    episode_len = 0

    while not done:
        action = model.get_action(obs)
        next_obs, reward, done, _ = env.step(action)
        episode_return += reward
        episode_len += 1

        yield dict(
            x=env.x,
            obs=obs,
            action=action,
            reward=reward,
            done=done,
            episode_len=episode_len,
            episode_return=episode_return,
            pixel_array=env.unwrapped.render(),
        )

        obs = next_obs
|
apple/logger.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
|
| 3 |
+
Some simple logging functionality, inspired by rllab's logging.
|
| 4 |
+
|
| 5 |
+
Logs to a tab-separated-values file (path/to/output_directory/progress.txt)
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
import atexit
|
| 9 |
+
import os
|
| 10 |
+
import os.path as osp
|
| 11 |
+
import time
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
import joblib
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
import wandb
|
| 19 |
+
|
| 20 |
+
# ANSI SGR foreground colour codes used by ``colorize`` (highlight adds 10).
color2num = dict(gray=30, red=31, green=32, yellow=33, blue=34, magenta=35, cyan=36, white=37, crimson=38)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def setup_logger_kwargs(exp_name, seed=None, data_dir=None, datestamp=True):
    """Compute the output directory for a training run and return logger kwargs.

    Directory layouts:
        data_dir/exp_name                                  (no seed, no datestamp)
        data_dir/exp_name/exp_name_s[seed]                 (seed, no datestamp)
        data_dir/YY-MM-DD_exp_name/
            YY-MM-DD_HH-MM-SS-exp_name_s[seed]             (seed + datestamp)

    Args:
        exp_name (string): Name for experiment.
        seed (int): Seed for random number generators used by experiment.
        data_dir (string): Path to folder where results should be saved;
            defaults to a ``logs`` directory three levels above this file.
        datestamp (bool): Whether to include a date and timestamp in the
            name of the save directory.

    Returns:
        dict with ``output_dir`` and ``exp_name`` keys, suitable for
        passing to the ``Logger`` constructor.
    """
    if data_dir is None:
        data_dir = osp.join(osp.abspath(osp.dirname(osp.dirname(osp.dirname(__file__)))), "logs")

    # Base experiment folder, optionally prefixed with the date.
    date_prefix = time.strftime("%Y-%m-%d_") if datestamp else ""
    relpath = date_prefix + exp_name

    if seed is not None:
        # Seeded runs get their own subfolder under the experiment folder.
        if datestamp:
            stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
            subfolder = f"{stamp}-{exp_name}_s{seed}"
        else:
            subfolder = f"{exp_name}_s{seed}"
        relpath = osp.join(relpath, subfolder)

    return dict(output_dir=osp.join(data_dir, relpath), exp_name=exp_name)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def colorize(string, color, bold=False, highlight=False):
    """Wrap ``string`` in ANSI escape codes for terminal colouring.

    ``color`` must be a key of ``color2num``; ``highlight`` selects the
    background variant (code + 10) and ``bold`` adds the bold attribute.

    This function was originally written by John Schulman.
    """
    code = color2num[color]
    if highlight:
        code += 10  # background/highlight codes sit 10 above the foreground ones
    attrs = [str(code)]
    if bold:
        attrs.append("1")
    return f"\x1b[{';'.join(attrs)}m{string}\x1b[0m"
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class Logger:
|
| 84 |
+
"""
|
| 85 |
+
A general-purpose logger.
|
| 86 |
+
|
| 87 |
+
Makes it easy to save diagnostics, hyperparameter configurations, the
|
| 88 |
+
state of a training run, and the trained model.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def __init__(
|
| 92 |
+
self,
|
| 93 |
+
log_to_wandb=False,
|
| 94 |
+
verbose=False,
|
| 95 |
+
output_dir=None,
|
| 96 |
+
output_fname="progress.csv",
|
| 97 |
+
delimeter=",",
|
| 98 |
+
exp_name=None,
|
| 99 |
+
wandbcommit=1,
|
| 100 |
+
):
|
| 101 |
+
"""
|
| 102 |
+
Initialize a Logger.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
log_to_wandb (bool): If True logger will log to wandb
|
| 106 |
+
|
| 107 |
+
output_dir (string): A directory for saving results to. If
|
| 108 |
+
``None``, defaults to a temp directory of the form
|
| 109 |
+
``/tmp/experiments/somerandomnumber``.
|
| 110 |
+
|
| 111 |
+
output_fname (string): Name for the tab-separated-value file
|
| 112 |
+
containing metrics logged throughout a training run.
|
| 113 |
+
Defaults to ``progress.csv``.
|
| 114 |
+
|
| 115 |
+
exp_name (string): Experiment name. If you run multiple training
|
| 116 |
+
runs and give them all the same ``exp_name``, the plotter
|
| 117 |
+
will know to group them. (Use case: if you run the same
|
| 118 |
+
hyperparameter configuration with multiple random seeds, you
|
| 119 |
+
should give them all the same ``exp_name``.)
|
| 120 |
+
|
| 121 |
+
delimeter (string): Used to separate logged values saved in output_fname
|
| 122 |
+
"""
|
| 123 |
+
self.verbose = verbose
|
| 124 |
+
self.log_to_wandb = log_to_wandb
|
| 125 |
+
self.delimeter = delimeter
|
| 126 |
+
self.wandbcommit = wandbcommit
|
| 127 |
+
self.log_iter = 1
|
| 128 |
+
# We assume that there's no multiprocessing.
|
| 129 |
+
if output_dir is not None:
|
| 130 |
+
self.output_dir = output_dir or "/tmp/experiments/%i" % int(time.time())
|
| 131 |
+
if osp.exists(self.output_dir):
|
| 132 |
+
print("Warning: Log dir %s already exists! Storing info there anyway." % self.output_dir)
|
| 133 |
+
else:
|
| 134 |
+
os.makedirs(self.output_dir)
|
| 135 |
+
self.output_file = open(osp.join(self.output_dir, output_fname), "w+")
|
| 136 |
+
atexit.register(self.output_file.close)
|
| 137 |
+
print(colorize("Logging data to %s" % self.output_file.name, "green", bold=True))
|
| 138 |
+
else:
|
| 139 |
+
self.output_file = None
|
| 140 |
+
|
| 141 |
+
self.first_row = True
|
| 142 |
+
self.log_headers = []
|
| 143 |
+
self.log_current_row = {}
|
| 144 |
+
self.exp_name = exp_name
|
| 145 |
+
|
| 146 |
+
def log(self, msg, color="green"):
|
| 147 |
+
"""Print a colorized message to stdout."""
|
| 148 |
+
print(colorize(msg, color, bold=True))
|
| 149 |
+
|
| 150 |
+
    def log_tabular(self, key, val):
        """
        Log a value of some diagnostic.

        Call this only once for each diagnostic quantity, each iteration.
        After using ``log_tabular`` to store values for each diagnostic,
        make sure to call ``dump_tabular`` to write them out to file and
        stdout (otherwise they will not get saved anywhere).
        """
        if self.first_row:
            self.log_headers.append(key)
        else:
            if key not in self.log_headers:
                # A key appeared after the first row: register it and rewrite
                # the header line of the on-disk file so the columns match.
                self.log_headers.append(key)

                if self.output_file is not None:
                    # move pointer to the beginning of the file
                    self.output_file.seek(0)
                    # skip the header
                    self.output_file.readline()
                    # keep rest of the file
                    logs = self.output_file.read()
                    # clear the file
                    self.output_file.truncate(0)
                    self.output_file.seek(0)
                    # write new headers
                    self.output_file.write(self.delimeter.join(self.log_headers) + "\n")
                    # write stored file
                    self.output_file.write(logs)
                    self.output_file.seek(0)
                    # move back to the end so future writes append rows
                    self.output_file.seek(0, 2)
        # assert key in self.log_headers, (
        #     "Trying to introduce a new key %s that you didn't include in the first iteration" % key
        # )
        assert key not in self.log_current_row, (
            "You already set %s this iteration. Maybe you forgot to call dump_tabular()" % key
        )
        self.log_current_row[key] = val
|
| 189 |
+
def save_state(self, state_dict, itr=None):
|
| 190 |
+
"""
|
| 191 |
+
Saves the state of an experiment.
|
| 192 |
+
|
| 193 |
+
To be clear: this is about saving *state*, not logging diagnostics.
|
| 194 |
+
All diagnostic logging is separate from this function. This function
|
| 195 |
+
will save whatever is in ``state_dict``---usually just a copy of the
|
| 196 |
+
environment---and the most recent parameters for the model you
|
| 197 |
+
previously set up saving for with ``setup_tf_saver``.
|
| 198 |
+
|
| 199 |
+
Call with any frequency you prefer. If you only want to maintain a
|
| 200 |
+
single state and overwrite it at each call with the most recent
|
| 201 |
+
version, leave ``itr=None``. If you want to keep all of the states you
|
| 202 |
+
save, provide unique (increasing) values for 'itr'.
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
state_dict (dict): Dictionary containing essential elements to
|
| 206 |
+
describe the current state of training.
|
| 207 |
+
|
| 208 |
+
itr: An int, or None. Current iteration of training.
|
| 209 |
+
"""
|
| 210 |
+
fname = "vars.pkl" if itr is None else "vars%d.pkl" % itr
|
| 211 |
+
try:
|
| 212 |
+
joblib.dump(state_dict, osp.join(self.output_dir, fname))
|
| 213 |
+
except:
|
| 214 |
+
self.log("Warning: could not pickle state_dict.", color="red")
|
| 215 |
+
if hasattr(self, "pytorch_saver_elements"):
|
| 216 |
+
self._pytorch_simple_save(itr)
|
| 217 |
+
|
| 218 |
+
    def setup_pytorch_saver(self, what_to_save):
        """
        Set up easy model saving for a single PyTorch model.

        Because PyTorch saving and loading is especially painless, this is
        very minimal; we just need references to whatever we would like to
        pickle. This is integrated into the logger because the logger
        knows where the user would like to save information about this
        training run.

        Args:
            what_to_save: Any PyTorch model or serializable object containing
                PyTorch models.
        """
        # Stored for later use by save_state() / _pytorch_simple_save().
        self.pytorch_saver_elements = what_to_save
|
| 234 |
+
    def _pytorch_simple_save(self, itr=None):
        """
        Saves the PyTorch model (or models) registered via ``setup_pytorch_saver``.

        Args:
            itr: An int, or None. When given, it is embedded in the file name
                so multiple checkpoints can coexist; otherwise ``model.pt``
                is overwritten on every call.
        """
        assert hasattr(self, "pytorch_saver_elements"), "First have to setup saving with self.setup_pytorch_saver"
        fpath = "pyt_save"
        fpath = osp.join(self.output_dir, fpath)
        fname = "model" + ("%d" % itr if itr is not None else "") + ".pt"
        fname = osp.join(fpath, fname)
        os.makedirs(fpath, exist_ok=True)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # We are using a non-recommended way of saving PyTorch models,
            # by pickling whole objects (which are dependent on the exact
            # directory structure at the time of saving) as opposed to
            # just saving network weights. This works sufficiently well
            # for the purposes of Spinning Up, but you may want to do
            # something different for your personal PyTorch project.
            # We use a catch_warnings() context to avoid the warnings about
            # not being able to save the source code.
            torch.save(self.pytorch_saver_elements, fname)
|
| 256 |
+
def dump_tabular(self):
|
| 257 |
+
"""
|
| 258 |
+
Write all of the diagnostics from the current iteration.
|
| 259 |
+
|
| 260 |
+
Writes both to stdout, and to the output file.
|
| 261 |
+
"""
|
| 262 |
+
vals = []
|
| 263 |
+
key_lens = [len(key) for key in self.log_headers]
|
| 264 |
+
max_key_len = max(15, max(key_lens))
|
| 265 |
+
keystr = "%" + "%d" % max_key_len
|
| 266 |
+
fmt = "| " + keystr + "s | %15s |"
|
| 267 |
+
n_slashes = 22 + max_key_len
|
| 268 |
+
step = self.log_current_row.get("total_env_steps")
|
| 269 |
+
|
| 270 |
+
if self.verbose:
|
| 271 |
+
print("-" * n_slashes)
|
| 272 |
+
for key in self.log_headers:
|
| 273 |
+
val = self.log_current_row.get(key, "")
|
| 274 |
+
valstr = "%8.3g" % val if isinstance(val, float) else val
|
| 275 |
+
print(fmt % (key, valstr))
|
| 276 |
+
vals.append(val)
|
| 277 |
+
print("-" * n_slashes, flush=True)
|
| 278 |
+
|
| 279 |
+
if self.output_file is not None:
|
| 280 |
+
if self.first_row:
|
| 281 |
+
self.output_file.write(self.delimeter.join(self.log_headers) + "\n")
|
| 282 |
+
self.output_file.write(self.delimeter.join(map(str, vals)) + "\n")
|
| 283 |
+
self.output_file.flush()
|
| 284 |
+
|
| 285 |
+
key_val_dict = {key: self.log_current_row.get(key, "") for key in self.log_headers}
|
| 286 |
+
if self.log_to_wandb:
|
| 287 |
+
if self.log_iter % self.wandbcommit == 0:
|
| 288 |
+
wandb.log(key_val_dict, step=step, commit=True)
|
| 289 |
+
else:
|
| 290 |
+
wandb.log(key_val_dict, step=step, commit=False)
|
| 291 |
+
|
| 292 |
+
self.log_current_row.clear()
|
| 293 |
+
self.first_row = False
|
| 294 |
+
self.log_iter += 1
|
| 295 |
+
|
| 296 |
+
return key_val_dict
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class EpochLogger(Logger):
    """
    A variant of Logger tailored for tracking average values over epochs.

    Typical use case: there is some quantity which is calculated many times
    throughout an epoch, and at the end of the epoch, you would like to
    report the average / std / min / max value of that quantity.

    With an EpochLogger, each time the quantity is calculated, you would
    use

    .. code-block:: python

        epoch_logger.store({"NameOfQuantity": quantity_value})

    to load it into the EpochLogger's state. Then at the end of the epoch, you
    would use

    .. code-block:: python

        epoch_logger.log_tabular(NameOfQuantity, **options)

    to record the desired values.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps diagnostic name -> list of values accumulated this epoch.
        self.epoch_dict = dict()

    def store(self, d):
        """
        Save values into the epoch logger's current state.

        Args:
            d (dict): Mapping from diagnostic name to a numerical value
                (or array of values) to accumulate for this epoch.
        """
        for k, v in d.items():
            if k not in self.epoch_dict:
                self.epoch_dict[k] = []
            self.epoch_dict[k].append(v)

    def log_tabular(self, key, val=None, with_min_and_max=False, with_median=False, with_sum=False, average_only=False):
        """
        Log a value or possibly the mean/std/min/max values of a diagnostic.

        Args:
            key (string): The name of the diagnostic. If you are logging a
                diagnostic whose state has previously been saved with
                ``store``, the key here has to match the key you used there.

            val: A value for the diagnostic. If you have previously saved
                values for this key via ``store``, do *not* provide a ``val``
                here.

            with_min_and_max (bool): If true, log min and max values of the
                diagnostic over the epoch.

            with_median (bool): If true, log the epoch median as well.

            with_sum (bool): If true, log the epoch sum as well.

            average_only (bool): If true, log only the mean (under the plain
                ``key``) and skip std/min/max/median/sum.
        """
        if val is not None:
            super().log_tabular(key, val)
        else:
            stats = self.get_stats(key)
            super().log_tabular(key if average_only else key + "/avg", stats[0])
            if not average_only:
                super().log_tabular(key + "/std", stats[1])
            if with_min_and_max:
                super().log_tabular(key + "/max", stats[3])
                super().log_tabular(key + "/min", stats[2])
            if with_median:
                super().log_tabular(key + "/med", stats[4])
            if with_sum:
                super().log_tabular(key + "/sum", stats[5])

        # Reset the accumulator for the next epoch.
        self.epoch_dict[key] = []

    def get_stats(self, key):
        """
        Return ``[mean, std, min, max, median, sum]`` of the values stored
        under *key*, or six NaNs when nothing has been stored.
        """
        v = self.epoch_dict.get(key)
        if not v:
            # Six entries to match the non-empty case; previously only four
            # NaNs were returned, so indexing stats[4]/stats[5] (median/sum)
            # raised IndexError on an empty diagnostic.
            return [np.nan] * 6
        vals = np.concatenate(v) if isinstance(v[0], np.ndarray) and len(v[0].shape) > 0 else v
        return [np.mean(vals), np.std(vals), np.min(vals), np.max(vals), np.median(vals), np.sum(vals)]
|
apple/models/categorical_policy.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from torch.distributions import Categorical
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class CategoricalPolicy(nn.Module):
    """A bias-free linear policy over a two-way categorical action space.

    The linear layer produces a single logit per batch element; its sigmoid
    is the probability of action 0 and the complement that of action 1.
    """

    def __init__(self, state_dim, act_dim, weight1=None, weight2=None):
        super().__init__()
        self.model = nn.Linear(state_dim, act_dim, bias=False)

        # Optionally pin individual entries of the weight matrix to fixed
        # starting values (useful for reproducing specific initializations).
        if weight1 is not None:
            nn.init.constant_(self.model.weight[0][0], weight1)

        if weight2 is not None:
            nn.init.constant_(self.model.weight[0][1], weight2)

    def forward(self, state):
        """Map a numpy state to a (1, 2) tensor of action probabilities."""
        batch = torch.from_numpy(state).float().unsqueeze(0)
        logit = self.model(batch)
        # we just consider 1 dimensional probability of action
        p_first = torch.sigmoid(logit)
        return torch.cat([p_first, 1 - p_first], dim=1)

    def act(self, state):
        """Sample an action for *state*; return (action, its log-prob)."""
        dist = Categorical(self.forward(state))
        choice = dist.sample()
        return choice.item(), dist.log_prob(choice)

    def sample(self, probs):
        """Sample from precomputed *probs*; return (action, its log-prob)."""
        dist = Categorical(probs)
        choice = dist.sample()
        return choice.item(), dist.log_prob(choice)

    def log_prob(self, probs, target_action):
        """Return (a freshly sampled action, log-prob of *target_action*)."""
        dist = Categorical(probs)
        choice = dist.sample()
        return choice.item(), dist.log_prob(target_action)

    @torch.no_grad()
    def get_action(self, state):
        """Sample an action for *state* without tracking gradients."""
        dist = Categorical(self.forward(state))
        return dist.sample().item()
|
apple/training/reinforce_trainer.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from apple.training.trainer import Trainer
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def discount_cumsum(x, gamma):
    """Return the discounted reward-to-go of *x*.

    out[t] = x[t] + gamma * x[t+1] + gamma^2 * x[t+2] + ...
    """
    out = torch.zeros_like(x)
    out[-1] = x[-1]
    # Walk backwards, reusing the already-computed suffix sum.
    for idx in range(x.shape[0] - 2, -1, -1):
        out[idx] = x[idx] + gamma * out[idx + 1]
    return out
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ReinforceTrainer(Trainer):
    """REINFORCE (vanilla policy gradient) trainer.

    Runs one episode at a time, computes the reward-to-go-weighted
    log-probability loss, and accumulates gradients across ``update_every``
    episodes before taking an optimizer step.
    """

    def __init__(self, *args, gamma: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.gamma = gamma  # discount factor for the reward-to-go

    def train(self, env, test_envs, num_episodes, log_every, update_every, num_eval_eps):
        # code based on
        # https://goodboychan.github.io/python/reinforcement_learning/pytorch/udacity/2021/05/12/REINFORCE-CartPole.html
        self.optim.zero_grad()
        for episode in range(num_episodes):
            self.train_it += 1

            if (episode + 1) % log_every == 0:

                self.test_agent(self.model, self.logger, test_envs, num_eval_eps)

                # Log info about epoch
                self.logger.log_tabular("total_env_steps", self.train_it)
                self.logger.log_tabular("train/return", with_min_and_max=True)
                self.logger.log_tabular("train/ep_length", average_only=True)

                # Each scalar weight of the tiny linear policy is tracked
                # individually to visualize its trajectory.
                for e, w in enumerate(self.model.model.weight.flatten()):
                    self.logger.log_tabular(f"weights{e}", w.item())

                self.logger.log_tabular("train/policy_loss", average_only=True)
                self.logger.log_tabular("train/log_probs", average_only=True)
                self.logger.dump_tabular()

            state = env.reset()

            saved_log_probs = []
            rewards = []
            ep_len, ep_ret = 0, 0
            while True:
                # Sample the action from current policy
                action, log_prob = self.model.act(state)
                saved_log_probs.append(log_prob)
                state, reward, done, _ = env.step(action)
                ep_ret += reward
                ep_len += 1

                rewards.append(reward)

                if done:
                    self.logger.store({"train/return": ep_ret, "train/ep_length": ep_len})
                    break

            saved_log_probs, rewards = torch.cat(saved_log_probs), torch.tensor(rewards)

            discounted_rewards = discount_cumsum(rewards, gamma=self.gamma)
            # Note that we are using Gradient Ascent, not Descent. So we need
            # to calculate it with negative rewards.
            policy_loss = (-discounted_rewards * saved_log_probs).sum()

            # Accumulate gradients every episode; step only every
            # `update_every` episodes. BUG FIX: the original called
            # zero_grad() *before* backward() on update episodes, wiping the
            # gradients accumulated since the previous step, so each update
            # was driven by a single episode instead of `update_every`.
            # Clearing after the step preserves the intended accumulation.
            policy_loss.backward()
            if (episode + 1) % update_every == 0:
                self.optim.step()
                self.optim.zero_grad()

            self.logger.store({"train/policy_loss": policy_loss.item()})
            self.logger.store({"train/log_probs": saved_log_probs.mean().item()})
|
apple/training/trainer.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Trainer:
    """Base trainer implementing a supervised (behavioral-cloning-style)
    training loop plus evaluation/logging helpers shared by subclasses.

    Attributes:
        model: Policy exposing forward/sample/log_prob/get_action.
        optim: Torch optimizer over the policy parameters.
        logger: EpochLogger used for metric aggregation.
        train_it: Number of training iterations performed so far.
    """

    def __init__(self, model, optim, logger):
        self.model = model
        self.optim = optim
        self.logger = logger
        self.train_it = 0

    def test_agent(self, model, logger, test_envs, num_episodes):
        """Evaluate *model* on each env in *test_envs* for *num_episodes*
        episodes, logging per-env return, episode length and success rate
        plus the overall ``average_success``."""
        avg_success = []

        for seq_idx, test_env in enumerate(test_envs):
            # assumes each test env carries a `name` attribute — used as the
            # metric namespace prefix.
            key_prefix = f"{test_env.name}/"

            for j in range(num_episodes):
                obs, done, episode_return, episode_len = test_env.reset(), False, 0, 0

                while not done:
                    action = model.get_action(obs)
                    obs, reward, done, _ = test_env.step(action)
                    episode_return += reward
                    episode_len += 1
                logger.store({key_prefix + "return": episode_return, key_prefix + "ep_length": episode_len})

            logger.log_tabular(key_prefix + "return", with_min_and_max=True)
            logger.log_tabular(key_prefix + "ep_length", average_only=True)
            # pop_successes() is provided by the SuccessCounter env wrapper.
            env_success = test_env.pop_successes()
            avg_success += env_success
            logger.log_tabular(key_prefix + "success", np.mean(env_success))

        key = "average_success"
        logger.log_tabular(key, np.mean(avg_success))

    def log(self, logger, step, model):
        """Dump the aggregated training diagnostics for this iteration and
        return the logged key/value dict."""
        # Log info about epoch
        logger.log_tabular("total_env_steps", step)

        logger.log_tabular("train/loss", average_only=True)
        logger.log_tabular("train/action", average_only=True)

        # Track each scalar weight of the (tiny) linear policy separately.
        for e, w in enumerate(model.model.weight.flatten()):
            logger.log_tabular(f"weight{e}", w.item())

        return logger.dump_tabular()

    def update(self, env, probs, model, optim, logger):
        """One supervised update: maximize the log-likelihood of the
        environment-provided target (expert) action under *probs*."""
        # assumes env implements get_target_action() — TODO confirm against
        # the apple env implementation.
        target = torch.as_tensor([env.get_target_action()], dtype=torch.float32)
        action, log_prob = model.log_prob(probs, target)

        optim.zero_grad()
        # Negative log-likelihood of the target action.
        loss = -torch.mean(log_prob)
        loss.backward()
        optim.step()

        logger.store({"train/action": action})
        logger.store({"train/loss": loss.item()})

    def train(self, env, test_envs, steps, log_every, num_eval_eps):
        """Run *steps* environment steps of training, evaluating and logging
        every *log_every* steps."""
        obs = env.reset()
        for timestep in range(steps):
            self.train_it += 1

            if (timestep + 1) % log_every == 0:
                self.test_agent(self.model, self.logger, test_envs, num_eval_eps)
                self.log(self.logger, self.train_it, self.model)

            output = self.model(obs)
            action, log_prob = self.model.sample(output)
            self.update(env, output, self.model, self.optim, self.logger)

            obs, reward, done, info = env.step(action)

            if done:
                obs = env.reset()
|
apple/utils.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
from typing import Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# https://stackoverflow.com/a/43357954/6365092
|
| 11 |
+
# https://stackoverflow.com/a/43357954/6365092
def str2bool(v: Union[bool, str]) -> bool:
    """Parse a command-line flag into a bool.

    Accepts actual bools unchanged and the usual textual spellings
    (yes/no, true/false, t/f, y/n, 1/0, case-insensitive). Raises
    argparse.ArgumentTypeError for anything else.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
|
apple/wrappers.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import (
|
| 2 |
+
Any,
|
| 3 |
+
Dict,
|
| 4 |
+
List,
|
| 5 |
+
Tuple,
|
| 6 |
+
)
|
| 7 |
+
|
| 8 |
+
import gym
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SuccessCounter(gym.Wrapper):
    """Helper class to keep count of successes in MetaWorld environments."""

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        self.successes = []           # one bool per completed episode
        self.current_success = False  # success flag of the running episode

    def step(self, action: Any) -> Tuple[np.ndarray, float, bool, Dict]:
        obs, reward, done, info = self.env.step(action)
        # An episode counts as successful if success was reported at any step.
        if info.get("success", False):
            self.current_success = True
        if done:
            self.successes.append(self.current_success)
        return obs, reward, done, info

    def pop_successes(self) -> List[bool]:
        """Return the collected per-episode success flags and reset them."""
        collected, self.successes = self.successes, []
        return collected

    def reset(self, **kwargs) -> np.ndarray:
        self.current_success = False
        return self.env.reset(**kwargs)
|
assets/apple_env.png
ADDED
|
assets/example_rollout.mp4
ADDED
|
Binary file (86.8 kB). View file
|
|
|
assets/generate_example_rollout.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
import skvideo.io

from apple.envs.discrete_apple import get_apple_env

# Build the "full" apple environment with off-screen ("rgb_array") rendering
# so frames can be captured for a demo video.
env = get_apple_env("full", time_limit=10, start_x=0, c=0.5, goal_x=8, render_mode="rgb_array")


imgs = []
env.reset()
# Capture a frame, then take action 1 — eight times...
for i in range(8):
    imgs.append(env.unwrapped.render())
    env.step(1)

# ...then nine frames taking action 0.
for i in range(9):
    imgs.append(env.unwrapped.render())
    env.step(0)


# Encode the captured frames at 4 fps into an mp4 next to this script.
skvideo.io.vwrite(
    "example_rollout.mp4",
    np.stack(imgs),
    inputdict={
        "-r": str(int(4)),
    },
    outputdict={
        "-f": "mp4",
        "-pix_fmt": "yuv420p",  # '-pix_fmt=yuv420p' needed for osx https://github.com/scikit-video/scikit-video/issues/74
    },
)
|
input_args.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
from apple.utils import str2bool
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def apple_parse_args(args=None):
    """Parse the command-line arguments shared by the apple experiments.

    Args:
        args: Optional list of argument strings; defaults to sys.argv.

    Returns:
        argparse.Namespace with the known arguments (unknown ones ignored).
    """
    parser = argparse.ArgumentParser()

    # Environment parameters.
    parser.add_argument("--c", type=float, default=0.25, required=False)
    parser.add_argument("--start_x", type=float, default=0.0, required=False)
    parser.add_argument("--goal_x", type=float, default=10.0, required=False)
    # Default was 100.0 (a float): argparse does not apply `type` to defaults,
    # so the env silently received a float time limit. Keep it an int.
    parser.add_argument("--time_limit", type=int, default=100, required=False)

    # Training parameters.
    parser.add_argument("--lr", type=float, default=1e-3, required=False)
    parser.add_argument("--log_to_wandb", type=str2bool, default=True, required=False)

    return parser.parse_known_args(args=args)[0]
|
mrunner_exps/behavioral_cloning.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np

from mrunner.helpers.specification_helper import create_experiments_helper

from mrunner_exps.utils import combine_config_with_defaults

# mrunner injects the spec-file name into globals(); strip ".py" to use it
# as the experiment name.
name = globals()["script"][:-3]

# params for all exps
config = {
    "exp_tag": "behavioral_cloning",
    "run_kind": "bc",
    "log_to_wandb": True,
    "pretrain_steps": 200,
    "steps": 200,
    "log_every": 1,
    "num_eval_eps": 10,
    "verbose": False,
    "lr": 0.01,
    "c": 0.5,
    "start_x": 0.0,
    "goal_x": 50.0,
    "bias_in_state": True,
    "position_in_state": False,
    "time_limit": 100,
    "wandbcommit": 100,
    "pretrain": "phase2",
    "finetune": "full",
}
config = combine_config_with_defaults(config)

# params different between exps; the grid expands to the cartesian product
# (10 seeds x 10 values of c x 9 goal positions).
params_grid = [
    {
        "seed": list(range(10)),
        "c": list(np.arange(0.1, 1.1, 0.1)),
        "goal_x": list(np.arange(5, 50, 5)),
    }
]

experiments_list = create_experiments_helper(
    experiment_name=name,
    project_name="apple",
    with_neptune=False,
    script="python3 mrunner_run.py",
    python_path=".",
    tags=[name],
    exclude=["logs", "wandb"],
    base_config=config,
    params_grid=params_grid,
)
|
mrunner_exps/reinforce.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np

from mrunner.helpers.specification_helper import create_experiments_helper

from mrunner_exps.utils import combine_config_with_defaults

# mrunner injects the spec-file name into globals(); strip ".py" to use it
# as the experiment name.
name = globals()["script"][:-3]

# params for all exps
config = {
    "exp_tag": "reinforce_goal_c3",
    "run_kind": "reinforce",
    "log_to_wandb": True,
    "pretrain_steps": 1000,
    "steps": 2000,
    "log_every": 1,
    "num_eval_eps": 10,
    "verbose": False,
    "lr": 0.001,
    "c": 1.0,
    "start_x": 0.0,
    "goal_x": 50.0,
    "bias_in_state": True,
    "position_in_state": False,
    "time_limit": 100,
    "gamma": 0.99,
    "wandbcommit": 1000,
    "pretrain": "phase2",
    "finetune": "full",
    "update_every": 10,  # for good definition on gradient
}
config = combine_config_with_defaults(config)

# params different between exps; the grid expands to the cartesian product
# (10 seeds x 10 values of c x 9 goal positions).
params_grid = [
    {
        "seed": list(range(10)),
        "c": list(np.arange(0.1, 1.1, 0.1)),
        "goal_x": list(np.arange(5, 50, 5)),
    }
]

experiments_list = create_experiments_helper(
    experiment_name=name,
    project_name="apple",
    with_neptune=False,
    script="python3 mrunner_run.py",
    python_path=".",
    tags=[name],
    exclude=["logs", "wandb"],
    base_config=config,
    params_grid=params_grid,
)
|
mrunner_exps/utils.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from input_args import apple_parse_args
|
| 2 |
+
|
| 3 |
+
PARSE_ARGS_DICT = {"bc": apple_parse_args, "reinforce": apple_parse_args}
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def combine_config_with_defaults(config):
    """Merge *config* on top of the defaults of its run-kind's arg parser.

    The parser matching ``config["run_kind"]`` is invoked with no arguments
    to obtain the default values; explicit entries in *config* win.
    """
    parse_fn = PARSE_ARGS_DICT[config["run_kind"]]
    merged = vars(parse_fn([]))
    merged.update(config)
    return merged
|
mrunner_run.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from mrunner.helpers.client_helper import get_configuration

import wandb

from run import main

if __name__ == "__main__":
    # mrunner supplies the merged experiment configuration (attribute-style
    # dict, hence `config.log_to_wandb` below).
    config = get_configuration(print_diagnostics=True, with_neptune=False)

    # Internal mrunner key that main() does not accept.
    del config["experiment_id"]

    if config.log_to_wandb:
        wandb.init(
            entity="gmum",
            project="apple",
            config=config,
        )
    main(**config)
|
mrunner_runs/local.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Run a single experiment spec locally (without submitting through mrunner).

python mrunner_run.py --ex mrunner_exps/baseline.py
|
mrunner_runs/remote.sh
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Submit the experiment grid to the remote cluster via mrunner.

conda activate apple

# Load SSH keys so mrunner can reach the cluster.
ssh-add
export PYTHONPATH=.

# Submit the behavioral-cloning grid; swap the spec file to run REINFORCE.
mrunner --config ~/.mrunner.yaml --context eagle_transfer_mw2 run mrunner_exps/behavioral_cloning.py
# mrunner --config ~/.mrunner.yaml --context eagle_transfer_mw2 run mrunner_exps/reinforce.py
|
pyproject.toml
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools", "setuptools_scm", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "apple"
|
| 7 |
+
description = "Simplest experiment for showing forgetting"
|
| 8 |
+
license = { text = "Proprietary" }
|
| 9 |
+
authors = [{name = "BartekCupial", email = "bartlomiej.cupial@student.uj.edu.pl" }]
|
| 10 |
+
|
| 11 |
+
dynamic = ["version"]
|
| 12 |
+
|
| 13 |
+
requires-python = ">= 3.8, < 3.11"
|
| 14 |
+
|
| 15 |
+
dependencies = [
|
| 16 |
+
"numpy ~= 1.23",
|
| 17 |
+
"typing-extensions ~= 4.3",
|
| 18 |
+
"gym == 0.23",
|
| 19 |
+
"torch ~= 1.12",
|
| 20 |
+
"wandb ~= 0.13",
|
| 21 |
+
"pandas ~= 1.5",
|
| 22 |
+
"matplotlib ~= 3.6",
|
| 23 |
+
"seaborn ~= 0.12",
|
| 24 |
+
"scipy ~= 1.9",
|
| 25 |
+
"joblib ~= 1.2",
|
| 26 |
+
"pygame ~= 2.1",
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
[project.optional-dependencies]
|
| 30 |
+
build = ["build ~= 0.8"]
|
| 31 |
+
mrunner = ["mrunner @ git+https://gitlab.com/awarelab/mrunner.git"]
|
| 32 |
+
lint = [
|
| 33 |
+
"black ~= 22.6",
|
| 34 |
+
"autoflake ~= 1.4",
|
| 35 |
+
"flake8 ~= 4.0",
|
| 36 |
+
"flake8-pyi ~= 22.5",
|
| 37 |
+
"flake8-docstrings ~= 1.6",
|
| 38 |
+
"pyproject-flake8 ~= 0.0.1a4",
|
| 39 |
+
"isort ~= 5.10",
|
| 40 |
+
"pre-commit ~= 2.20",
|
| 41 |
+
]
|
| 42 |
+
test = [
|
| 43 |
+
"pytest ~= 7.1",
|
| 44 |
+
"pytest-cases ~= 3.6",
|
| 45 |
+
"pytest-cov ~= 3.0",
|
| 46 |
+
"pytest-xdist ~= 2.5",
|
| 47 |
+
"pytest-sugar ~= 0.9",
|
| 48 |
+
"hypothesis ~= 6.54",
|
| 49 |
+
]
|
| 50 |
+
dev = [
|
| 51 |
+
"apple[mrunner]",
|
| 52 |
+
"apple[build]",
|
| 53 |
+
"apple[lint]",
|
| 54 |
+
"apple[test]",
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
[project.urls]
|
| 58 |
+
"Source" = "https://github.com/BartekCupial/apple"
|
| 59 |
+
|
| 60 |
+
[tool.black]
|
| 61 |
+
line_length = 120
|
| 62 |
+
|
| 63 |
+
[tool.flake8]
|
| 64 |
+
extend_exclude = [".venv/", "build/", "dist/", "docs/"]
|
| 65 |
+
per_file_ignores = ["**/_[a-z]*.py:D", "tests/*.py:D", "*.pyi:D"]
|
| 66 |
+
ignore = [
|
| 67 |
+
# Handled by black
|
| 68 |
+
"E", # pycodestyle
|
| 69 |
+
"W", # pycodestyle
|
| 70 |
+
"D",
|
| 71 |
+
]
|
| 72 |
+
ignore_decorators = "property" # https://github.com/PyCQA/pydocstyle/pull/546
|
| 73 |
+
|
| 74 |
+
[tool.isort]
|
| 75 |
+
profile = "black"
|
| 76 |
+
line_length = 120
|
| 77 |
+
order_by_type = true
|
| 78 |
+
lines_between_types = 1
|
| 79 |
+
combine_as_imports = true
|
| 80 |
+
force_grid_wrap = 2
|
| 81 |
+
|
| 82 |
+
[tool.pytest.ini_options]
|
| 83 |
+
testpaths = "tests"
|
| 84 |
+
addopts = """
|
| 85 |
+
-n auto
|
| 86 |
+
-ra
|
| 87 |
+
--tb short
|
| 88 |
+
--doctest-modules
|
| 89 |
+
--junit-xml test-results.xml
|
| 90 |
+
--cov-report term-missing:skip-covered
|
| 91 |
+
--cov-report xml:coverage.xml
|
| 92 |
+
"""
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
numpy
|
| 3 |
+
pandas
|
| 4 |
+
gym == 0.23
|
| 5 |
+
wandb
|
| 6 |
+
joblib
|
| 7 |
+
pygame
|
run.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
import wandb
|
| 4 |
+
|
| 5 |
+
from apple.envs.discrete_apple import get_apple_env
|
| 6 |
+
from apple.logger import EpochLogger
|
| 7 |
+
from apple.models.categorical_policy import CategoricalPolicy
|
| 8 |
+
from apple.training.reinforce_trainer import ReinforceTrainer
|
| 9 |
+
from apple.training.trainer import Trainer
|
| 10 |
+
from apple.utils import set_seed
|
| 11 |
+
from input_args import apple_parse_args
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def main(
    run_kind: str,
    c: float = 1.0,
    start_x: float = 0.0,
    goal_x: float = 50.0,
    time_limit: int = 200,
    bias_in_state: bool = True,
    position_in_state: bool = False,
    apple_in_state: bool = True,
    lr: float = 1e-3,
    pretrain_steps: int = 0,
    steps: int = 10000,
    log_every: int = 1,
    num_eval_eps: int = 1,
    pretrain: str = "phase1",
    finetune: str = "full",
    log_to_wandb: bool = False,
    wandbcommit: int = 1,
    verbose: bool = False,
    output_dir="logs/apple",
    gamma: float = 1.0,
    update_every: int = 10,
    seed=0,
    **kwargs,
):
    """Run a two-phase (pretrain then finetune) training job on the apple env.

    A policy is first trained on the ``pretrain`` task variant for
    ``pretrain_steps``, then on the ``finetune`` variant for ``steps``.
    Evaluation is performed on the "full", "phase1" and "phase2" variants.

    Args:
        run_kind: Training algorithm, either ``"reinforce"`` or ``"bc"``.
        c, start_x, goal_x, time_limit, bias_in_state, position_in_state,
            apple_in_state: Environment construction parameters, forwarded
            to ``get_apple_env``.
        lr: SGD learning rate.
        pretrain_steps: Number of training steps in the pretraining phase.
        steps: Number of training steps in the finetuning phase.
        log_every: Logging period (in steps).
        num_eval_eps: Episodes per evaluation environment.
        pretrain: Task variant used for phase 1.
        finetune: Task variant used for phase 2.
        log_to_wandb, wandbcommit, verbose, output_dir: Logger configuration.
        gamma: Discount factor (REINFORCE only).
        update_every: Gradient update period (REINFORCE only).
        seed: Global RNG seed.
        **kwargs: Ignored; lets callers pass a full config dict unchanged.

    Raises:
        ValueError: If ``run_kind`` is not a recognized algorithm name.
    """
    set_seed(seed)

    logger = EpochLogger(
        exp_name=run_kind,
        output_dir=output_dir,
        log_to_wandb=log_to_wandb,
        wandbcommit=wandbcommit,
        verbose=verbose,
    )

    # Shared environment parameters for every task variant built below.
    env_kwargs = dict(
        start_x=start_x,
        goal_x=goal_x,
        c=c,
        time_limit=time_limit,
        bias_in_state=bias_in_state,
        position_in_state=position_in_state,
        apple_in_state=apple_in_state,
    )

    env_phase1 = get_apple_env(pretrain, **env_kwargs)
    env_phase2 = get_apple_env(finetune, **env_kwargs)
    test_envs = [get_apple_env(task, **env_kwargs) for task in ["full", "phase1", "phase2"]]

    # Single-output categorical policy over the flat observation vector.
    model = CategoricalPolicy(env_phase1.observation_space.shape[0], 1)
    optim = torch.optim.SGD(model.parameters(), lr=lr)

    if run_kind == "reinforce":
        trainer = ReinforceTrainer(model, optim, logger, gamma=gamma)
        trainer.train(env_phase1, test_envs, pretrain_steps, log_every, update_every, num_eval_eps)
        trainer.train(env_phase2, test_envs, steps, log_every, update_every, num_eval_eps)
    elif run_kind == "bc":
        trainer = Trainer(model, optim, logger)
        trainer.train(env_phase1, test_envs, pretrain_steps, log_every, num_eval_eps)
        trainer.train(env_phase2, test_envs, steps, log_every, num_eval_eps)
    else:
        # Previously an unknown run_kind silently did nothing; fail loudly instead.
        raise ValueError(f"Unknown run_kind: {run_kind!r} (expected 'reinforce' or 'bc')")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
    # Parse CLI arguments once; the resulting namespace doubles as the
    # experiment config (passed to wandb and unpacked into main()).
    parsed_args = apple_parse_args()

    if parsed_args.log_to_wandb:
        # "fork" start method avoids wandb service spawn issues on Linux.
        wandb.init(
            entity="gmum",
            project="apple",
            config=parsed_args,
            settings=wandb.Settings(start_method="fork"),
        )

    main(**vars(parsed_args))
|
setup.cfg
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[options]
|
| 2 |
+
packages = find_namespace:
|
| 3 |
+
package_dir =
|
| 4 |
+
= apple
|
| 5 |
+
|
| 6 |
+
[options.packages.find]
|
| 7 |
+
where = apple
|
| 8 |
+
|
| 9 |
+
[options.package_data]
|
| 10 |
+
* = py.typed, *.pyi
|