diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000000000000000000000000000000000000..50c7d2dbbcdb8b3e1d5b485899d05491056d349a --- /dev/null +++ b/.editorconfig @@ -0,0 +1,16 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true + +[*] +indent_size = 4 +indent_style = space +max_line_length = 120 +tab_width = 8 + +[*.{yml,yaml}] +indent_size = 2 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c3197e17fda8acd0a41ed6fcd3c3e9f9db7f39e4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +/.venv/ +/.python-version + +/build/ +/dist/ +/site/ +/test-results.xml +/.coverage +/coverage.xml + +/.hypothesis/ +__pycache__/ +*.egg-info/ + +/.vscode/ +wandb +logs +*.pt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..75e3ca1e16566bba5452004f27081865cadc91c5 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,29 @@ +repos: + - repo: local + hooks: + - id: autoflake + name: autoflake + entry: autoflake + args: [--in-place, --remove-all-unused-imports, --remove-unused-variables] + language: system + types_or: [python, pyi] + + - id: isort + name: isort + entry: isort + args: [--quiet] + language: system + types_or: [python, pyi] + + - id: black + name: black + entry: black + args: [--quiet] + language: system + types_or: [python, pyi] + + - id: flake8 + name: flake8 + entry: pflake8 + language: system + types_or: [python, pyi] diff --git a/.streamlit/config.toml b/.streamlit/config.toml new file mode 100644 index 0000000000000000000000000000000000000000..3d139c14c0c6f405dd0eb6dc27eed79b037b9d4a --- /dev/null +++ b/.streamlit/config.toml @@ -0,0 +1,34 @@ +[theme] +#theme primary +base="dark" +# Primary accent color for interactive elements. +primaryColor="f63366" + +# Background color for the main content area. 
+#backgroundColor = + +# Background color used for the sidebar and most interactive widgets. +#secondaryBackgroundColor ='grey' + +# Color used for almost all text. +#textColor ='blue' + +# Font family for all text in the app, except code blocks. One of "sans serif", "serif", or "monospace". +# Default: "sans serif" +font = "sans serif" + +# [logger] +# level='info' +# messageFormat = "%(message)s" +#messageFormat="%(asctime)s %(message)s" + +[global] + +# By default, Streamlit checks if the Python watchdog module is available and, if not, prints a warning asking for you to install it. The watchdog module is not required, but highly recommended. It improves Streamlit's ability to detect changes to files in your filesystem. +# If you'd like to turn off this warning, set this to True. +# Default: false +disableWatchdogWarning = false + +# If True, will show a warning when you run a Streamlit-enabled script via "python my_script.py". +# Default: true +showWarningOnDirectExecution = false \ No newline at end of file diff --git a/README.md b/README.md index 105fd2b05c502de0997c771d7006c9397b830504..2fbe1c288e6dcf971f0f9dc71c37998331dfdf33 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,13 @@ --- -title: Apple -emoji: 📚 -colorFrom: red -colorTo: yellow +title: Apple Retrieval +emoji: 🍎 +colorFrom: yellow +colorTo: red sdk: streamlit -sdk_version: 1.17.0 +sdk_version: 1.15.2 app_file: app.py -pinned: false +pinned: true +fullWidth: true --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..030c31b9600938810114132c18b31e2dfc470bbe --- /dev/null +++ b/app.py @@ -0,0 +1,606 @@ +import asyncio + +from functools import partial + +import numpy as np +import pandas as pd +import streamlit as st +import torch 
+ +from apple.envs.discrete_apple import get_apple_env +from apple.evaluation.render_episode import render_episode +from apple.logger import EpochLogger +from apple.models.categorical_policy import CategoricalPolicy +from apple.training.trainer import Trainer + +QUEUE_SIZE = 1000 + + +def init_training( + c: float = 1.0, + start_x: float = 0.0, + goal_x: float = 50.0, + time_limit: int = 200, + lr: float = 1e-3, + weight1: float = 0.0, + weight2: float = 0.0, + pretrain: str = "phase1", + finetune: str = "full", + bias_in_state: bool = True, + position_in_state: bool = False, + apple_in_state: bool = True, +): + st.session_state.logger = EpochLogger(verbose=False) + + env_kwargs = dict( + start_x=start_x, + goal_x=goal_x, + c=c, + time_limit=time_limit, + bias_in_state=bias_in_state, + position_in_state=position_in_state, + apple_in_state=apple_in_state, + ) + + st.session_state.env_full = get_apple_env("full", render_mode="rgb_array", **env_kwargs) + st.session_state.env_phase1 = get_apple_env(pretrain, **env_kwargs) + st.session_state.env_phase2 = get_apple_env(finetune, **env_kwargs) + st.session_state.test_envs = [get_apple_env(task, **env_kwargs) for task in ["full", "phase1", "phase2"]] + + st.session_state.model = CategoricalPolicy( + st.session_state.env_phase1.observation_space.shape[0], 1, weight1, weight2 + ) + st.session_state.optim = torch.optim.SGD(st.session_state.model.parameters(), lr=lr) + st.session_state.trainer = Trainer(st.session_state.model, st.session_state.optim, st.session_state.logger) + st.session_state.train_it = 0 + st.session_state.draw_it = 0 + st.session_state.total_steps = 0 + st.session_state.data = [] + + st.session_state.obs1 = st.session_state.env_phase1.reset() + st.session_state.obs2 = st.session_state.env_phase2.reset() + + +def init_reset(): + st.session_state.rollout_iterator = iter( + partial(render_episode, st.session_state.env_full, st.session_state.model)() + ) + st.session_state.last_image = dict( + x=0, + 
obs=st.session_state.env_full.reset(), + action=None, + reward=0, + done=False, + episode_len=0, + episode_return=0, + pixel_array=st.session_state.env_full.unwrapped.render(), + ) + + +def select_preset(): + if st.session_state.pick_preset == 0: + preset_finetuning_interference() + elif st.session_state.pick_preset == 1: + preset_train_full_from_scratch() + elif st.session_state.pick_preset == 2: + preset_task_interference() + elif st.session_state.pick_preset == 3: + preset_without_task_interference() + + +def preset_task_interference(): + st.session_state.pick_c = 0.5 + st.session_state.pick_goal_x = 20 + st.session_state.pick_time_limit = 50 + st.session_state.pick_lr = 0.05 + st.session_state.pick_phase1task = "phase2" + st.session_state.pick_phase2task = "phase1" + st.session_state.pick_weight1 = 0.0 + st.session_state.pick_weight2 = 0.0 + st.session_state.pick_phase1steps = 500 + st.session_state.pick_phase2steps = 500 + need_reset() + + +def preset_finetuning_interference(): + st.session_state.pick_c = 0.5 + st.session_state.pick_goal_x = 20 + st.session_state.pick_time_limit = 50 + st.session_state.pick_lr = 0.05 + st.session_state.pick_phase1task = "phase2" + st.session_state.pick_phase2task = "full" + st.session_state.pick_weight1 = 0.0 + st.session_state.pick_weight2 = 0.0 + st.session_state.pick_phase1steps = 500 + st.session_state.pick_phase2steps = 2000 + need_reset() + + +def preset_without_task_interference(): + st.session_state.pick_c = 1.0 + st.session_state.pick_goal_x = 20 + st.session_state.pick_time_limit = 50 + st.session_state.pick_lr = 0.05 + st.session_state.pick_phase1task = "phase2" + st.session_state.pick_phase2task = "phase1" + st.session_state.pick_weight1 = 0.0 + st.session_state.pick_weight2 = 0.0 + st.session_state.pick_phase1steps = 500 + st.session_state.pick_phase2steps = 500 + need_reset() + + +def preset_train_full_from_scratch(): + st.session_state.pick_c = 0.5 + st.session_state.pick_goal_x = 20 + 
st.session_state.pick_time_limit = 50 + st.session_state.pick_lr = 0.05 + st.session_state.pick_phase1task = "phase2" + st.session_state.pick_phase2task = "full" + st.session_state.pick_weight1 = 0.0 + st.session_state.pick_weight2 = 0.0 + st.session_state.pick_phase1steps = 0 + st.session_state.pick_phase2steps = 2000 + need_reset() + + +def empty_queue(q: asyncio.Queue): + for _ in range(q.qsize()): + # Depending on your program, you may want to + # catch QueueEmpty + q.get_nowait() + q.task_done() + + +def reset(**kwargs): + init_training(**kwargs) + init_reset() + st.session_state.play = False + st.session_state.step = False + st.session_state.render = False + st.session_state.done = False + empty_queue(st.session_state.queue) + empty_queue(st.session_state.queue_render) + st.session_state.play_pause = False + st.session_state.need_reset = False + + +def render_start(): + st.session_state.render = True + st.session_state.done = False + init_reset() + + +def need_reset(): + st.session_state.need_reset = True + st.session_state.play = False + st.session_state.render = False + + +def play_pause(): + if st.session_state.play: + st.session_state.play = False + st.session_state.play_pause = False + else: + st.session_state.play = True + st.session_state.play_pause = True + + +def step(): + st.session_state.step = True + + +def plot(data_placeholder): + df = pd.DataFrame(st.session_state.data) + if not df.empty: + df.set_index("total_env_steps", inplace=True) + container = data_placeholder.container() + c1, c2, c3 = container.columns(3) + + def view_df(names): + rdf = df.loc[:, df.columns.isin(names)] + if rdf.empty: + return pd.DataFrame([{name: 0 for name in names}]) + else: + return rdf + + c1.write("phase1/success_rate") + c1.line_chart(view_df(["phase1/success"])) + c2.write("phase2/success_rate") + c2.line_chart(view_df(["phase2/success"])) + c3.write("full/success_rate") + c3.line_chart(view_df(["full/success"])) + + c1.write("train/loss") + 
c1.line_chart(view_df(["train/loss"])) + c2.write("weight0") + c2.line_chart(view_df(["weight0"])) + c3.write("weight1") + c3.line_chart(view_df(["weight1"])) + + +async def draw(data_placeholder, queue, delay, steps, plotfrequency): + while (st.session_state.play or st.session_state.step) and st.session_state.draw_it < steps: + _ = await asyncio.sleep(delay) + new_data = await queue.get() + st.session_state.draw_it += 1 + if st.session_state.draw_it % plotfrequency == 0: + st.session_state.data.append(new_data) + plot(data_placeholder) + st.session_state.step = False + queue.task_done() + + +async def train(queue, delay, steps, obs, env, num_eval_episodes, plotfrequency): + while (st.session_state.play or st.session_state.step) and st.session_state.train_it < steps: + _ = await asyncio.sleep(delay) + st.session_state.train_it += 1 + st.session_state.total_steps += 1 + + output = st.session_state.model(obs) + action, log_prob = st.session_state.model.sample(output) + st.session_state.trainer.update( + env, output, st.session_state.model, st.session_state.optim, st.session_state.logger + ) + + obs, reward, done, info = env.step(action) + + if done: + obs = env.reset() + + if st.session_state.train_it % plotfrequency == 0: + st.session_state.trainer.test_agent( + st.session_state.model, st.session_state.logger, st.session_state.test_envs, num_eval_episodes + ) + data = st.session_state.trainer.log( + st.session_state.logger, st.session_state.train_it, st.session_state.model + ) + else: + data = 0 + + _ = await queue.put(data) + + +async def produce_images(queue, delay): + while st.session_state.render and not st.session_state.done: + _ = await asyncio.sleep(delay) + data = next(st.session_state.rollout_iterator) + st.session_state.done = data["done"] + _ = await queue.put(data) + + +def show_image(data, image_placeholder): + c = image_placeholder.container() + c.image( + data["pixel_array"], + ) + c.text( + f"agent position: {data['x']} \ntimestep: 
{data['episode_len']} \nepisode return: {data['episode_return']} \n" + ) + + +async def consume_images(image_placeholder, queue, delay): + while st.session_state.render and not st.session_state.done: + _ = await asyncio.sleep(delay) + data = await queue.get() + st.session_state.last_image = data + show_image(data, image_placeholder) + queue.task_done() + + +async def run_app( + data_placeholder, + queue, + produce_delay, + consume_delay, + phase1steps, + phase2steps, + plotfrequency, + num_eval_episodes, + image_placeholder, + queue_render, + render_produce_delay, + render_consume_delay, +): + _ = await asyncio.gather( + produce_images(queue_render, render_produce_delay), + consume_images(image_placeholder, queue_render, render_consume_delay), + ) + + st.session_state.render = False + st.session_state.done = False + + empty_queue(queue_render) + + _ = await asyncio.gather( + train( + queue, + produce_delay, + phase1steps, + st.session_state.obs1, + st.session_state.env_phase1, + num_eval_episodes, + plotfrequency, + ), + draw(data_placeholder, queue, consume_delay, phase1steps, plotfrequency), + ) + + _ = await asyncio.gather( + train( + queue, + produce_delay, + phase1steps + phase2steps, + st.session_state.obs2, + st.session_state.env_phase2, + num_eval_episodes, + plotfrequency, + ), + draw(data_placeholder, queue, consume_delay, phase1steps + phase2steps, plotfrequency), + ) + + +##### ACTUAL APP + +if __name__ == "__main__": + st.set_page_config( + layout="wide", + initial_sidebar_state="auto", + page_title="Apple Retrieval", + page_icon=None, + ) + st.title("ON THE ROLE OF FORGETTING IN FINE-TUNING REINFORCEMENT LEARNING MODELS") + st.header("Toy example of forgetting: AppleRetrieval") + + col1, col2, col3 = st.sidebar.columns(3) + + options = ( + "phase2 full interference", + "full from scratch", + "phase2 phase1 forgetting", + "phase2 phase1 optimal solution", + ) + st.sidebar.selectbox( + "parameter presets", + range(len(options)), + index=0, + 
format_func=lambda x: options[x], + on_change=select_preset, + key="pick_preset", + ) + + pick_container = st.sidebar.container() + c = pick_container.number_input("c", value=0.5, on_change=need_reset, key="pick_c") + goal_x = pick_container.number_input("distance to apple", value=20, on_change=need_reset, key="pick_goal_x") + time_limit = pick_container.number_input("time limit", value=50, on_change=need_reset, key="pick_time_limit") + lr = pick_container.selectbox( + "Learning rate", + np.array([0.00001, 0.0001, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0, 10.0]), + 5, + on_change=need_reset, + key="pick_lr", + ) + phase1task = pick_container.selectbox( + "Pretraining task", ("full", "phase1", "phase2"), 2, on_change=need_reset, key="pick_phase1task" + ) + phase2task = pick_container.selectbox( + "Finetuning task", ("full", "phase1", "phase2"), 0, on_change=need_reset, key="pick_phase2task" + ) + # weight1 = pick_container.number_input("init weight1", value=0, on_change=need_reset, key="pick_weight1") + # weight2 = pick_container.number_input("init weight2", value=0, on_change=need_reset, key="pick_weight2") + + if "event_loop" not in st.session_state: + st.session_state.loop = asyncio.new_event_loop() + asyncio.set_event_loop(st.session_state.loop) + + if "queue" not in st.session_state: + st.session_state.queue = asyncio.Queue(QUEUE_SIZE) + if "queue_render" not in st.session_state: + st.session_state.queue_render = asyncio.Queue(QUEUE_SIZE) + if "play" not in st.session_state: + st.session_state.play = False + if "step" not in st.session_state: + st.session_state.step = False + if "render" not in st.session_state: + st.session_state.render = False + + reset_button = partial( + reset, + c=c, + start_x=0, + goal_x=goal_x, + time_limit=time_limit, + lr=lr, + # weight1=weight1, + # weight2=weight2, + # weight1=0, + # weight2=10, # soves the environment + pretrain=phase1task, + finetune=phase2task, + ) + col1.button("Reset", on_click=reset_button, type="primary") 
+ + if "logger" not in st.session_state or st.session_state.need_reset: + reset_button() + + myKey = "play_pause" + if myKey not in st.session_state: + st.session_state[myKey] = False + + if st.session_state[myKey]: + myBtn = col2.button("Pause", on_click=play_pause, type="primary") + else: + myBtn = col2.button("Play", on_click=play_pause, type="primary") + + col3.button("Step", on_click=step, type="primary") + + st.header("Summary") + st.write( + """ + Run training on the "phase2 full interference" setting to see an example of forgetting in fine-tuning RL models. + A model is pre-trained on a part of the environment called Phase 2 for 500 steps, + and then it is fine-tuned on the whole environment for another 2000 steps. + However, it forgets how to perform on Phase 2 during fine-tuning before it even gets there. + We highlight this as an important problem in fine-tuning RL models. + We invite you to play around with the hyperparameters and find out more about the forgetting phenomenon. + """ + ) + + st.markdown( + """ + By comparing the results of these four presets, you can gain a deeper understanding of how different training configurations can impact the final performance of a neural network. + * `phase2 full interference` - already mentioned setting where we pretrain on phase 2 and finetune on whole environment. + * `full from scratch` - "standard" train on whole environment for 2000 steps for comparison. + * `phase2 phase1 forgetting` - a setting where we pretrain on phase 2 and finetune on phase 1. Here gradient interference between tasks causes network to forget how to solve phase 2. + * `phase2 phase1 without forgetting` - a setting where we pretrain on phase 2 and finetune on phase 1. Here gradient interference between tasks "move" network weights to optimal solution for both tasks. 
+ """ + ) + + data_placeholder = st.empty() + render_placeholder = st.empty() + + phase1steps = st.sidebar.slider("Pretraining Steps", 0, 2000, 500, 10, key="pick_phase1steps") + phase2steps = st.sidebar.slider("Finetuning Steps", 0, 2000, 2000, 10, key="pick_phase2steps") + plotfrequency = st.sidebar.number_input("Log frequency ", min_value=1, value=10, step=1) + num_eval_episodes = st.sidebar.number_input("Number of evaluation episodes", min_value=1, value=10, step=1) + + if not st.session_state.play and len(st.session_state.data) > 0: + plot(data_placeholder) + + st.write( + """ + Figure 1: Plots (a), (b), (c) show the performance of the agent on three different variants of the environment. + Phase1 (a) the agent's goal is to reach the apple. + Phase2 (b) the agent's has to go back home by returning to x = 0. + Full (c) combination of both phases of the environment. + (d) shows the training loss. + State vector has only two values, we can view the netwoks weights (e), (f). + """ + ) + + st.header("Vizualize agent behavior") + st.button("Rollout one episode", on_click=render_start, type="primary") + + c1, c2 = st.columns(2) + image_placeholder = c1.empty() + + render_produce_delay = 1 / 3 + render_consume_delay = 1 / 3 + + st.header("About the environment") + + st.write( + """This study forms a preliminary investigation into the problem of forgetting in fine-tuning RL models. + We show that fine-tuning a pre-trained model on compositional RL problems might result in a rapid + deterioration of the performance of the pre-trained model if the relevant data is not available at the + beginning of the training. + This phenomenon is known as catastrophic forgetting. + In this demo we show how it can occur in simple toyish situations but it might occur in more realistic problems (e.g. (SAC with MLPs on a compositional robotic environmen). + In our [WIP] paper we showed that applying CL methods significantly limits forgetting and allows for efficient transfer. 
+ """ + ) + + st.write( + """The AppleRetrieval environment is a toy example to demonstrate the issue of interference in reinforcement learning. + The goal of the agent is to retrieve an apple from position x = M and return home to x = 0 within a set number of steps T. + The state of the agent is represented by a vector s, which has two elements: s = [1, -c] in phase 1 and s = [1, c] in phase 2. + The first element is a constant and the second element represents the information about the current phase. + The optimal policy is to go right in phase 1 and go left in phase 2. + The cause of interference can be identified by checking the reliance of the policy on either s1 or s2. + If the model mostly relies on s2, interference will be limited, but if it relies on s1, + interference will occur as its value is the same in both phases. + The magnitude of s1 and s2 can be adjusted by changing the c parameter to guide the model towards focusing on either one. + This simple toy environment shows that the issue of interference can be fundamental to reinforcement learning.""" + ) + + cc1, cc2 = st.columns([1, 2]) + cc1.image("assets/apple_env.png") + + st.header("Training algorithm") + st.write( + """For practical reasons (time and stability) we don't use REINFORCE in this DEMO to illustrate the training dynamics of the environment. + Instead, we train the model by minimizing the negative log likelihood of target actions (move right or left). + We train the model in each step of the environment and sample actions from Categorical distribution taken from model's output. 
+ """ + ) + + st.markdown( + """ + Pseudocode of the train loop: +```python +obs = env.reset() +for timestep in range(steps): + probs = model(obs) + dist = Categorical(probs) + action = dist.sample() + target_action = env.get_target_action() + loss = -dist.log_prob(target_action) + + optim.zero_grad() + loss.backward() + optim.step() + + obs, reward, done, info = env.step(action) + if done: + obs = env.reset() + """ + ) + + st.header("What do all the hyperparameters mean?") + + st.markdown( + """ + * **parameter presets** - few sets of pre-defined hyperparameters that can be used as a starting point for a specific experiment. + * **c** - second element of state vecor, decreasing this value will result in stronger forgetting. + * **distance to apple** - refers to how far agent needs to travel in the right before it encounters the apple. + * **time limit** - maximum amount of timesteps that the agent is allowed to interact with the environment before episode ends. + * **learning rate** - hyperparameter that determines the steps size taken by the learning algoritm in each iteration. + * **pretraining and finetuning task** - define the environment the model will be trained on each stage. + * **pretraining and finetuning steps** - define how long model will be trained on each stage. + * **init weights** - initial values assigned to the model's parameters before trainig begins. + * **log frequency** - refers to frequency of logging the metrics and evaluating the agent. + * **number of evaluation episodes** - number of rollouts during testing of the agent. + """ + ) + + st.header("Limitations & Conclusions") + + st.write( + """ + At the same time, this study, due to its preliminary nature, has numerous limitations which we +hope to address in future work. We only considered a fairly strict formulation of the forgetting +scenario where we assumed that the pre-trained model works perfectly on tasks that appear later in +the fine-tuning. 
In practice, one should also consider the case when even though there are differences +between the pre-training and fine-tuning tasks, transfer is still possible. + """ + ) + + st.write( + """ + At the same time, even given +these limitations, we see forgetting as an important problem to be solved and hope that addressing +these issues in the future might help with building and fine-tuning better foundation models in RL. + """ + ) + + produce_delay = 1 / 1000 + consume_delay = 1 / 1000 + + plot(data_placeholder) + show_image(st.session_state.last_image, image_placeholder) + + asyncio.run( + run_app( + data_placeholder, + st.session_state.queue, + produce_delay, + consume_delay, + phase1steps, + phase2steps, + plotfrequency, + num_eval_episodes, + image_placeholder, + st.session_state.queue_render, + render_produce_delay, + render_consume_delay, + ) + ) diff --git a/apple/envs/discrete_apple.py b/apple/envs/discrete_apple.py new file mode 100644 index 0000000000000000000000000000000000000000..68a6787db20798b86c206ed46dad8df4e0fc2aba --- /dev/null +++ b/apple/envs/discrete_apple.py @@ -0,0 +1,384 @@ +from contextlib import closing +from io import StringIO +from os import path +from typing import Optional + +import gym +import gym.spaces +import numpy as np + +from gym.error import DependencyNotInstalled +from gym.spaces import Box +from gym.utils import colorize +from gym.wrappers import TimeLimit + +from apple.wrappers import SuccessCounter + + +def get_apple_env(task, time_limit=20, **kwargs): + if task == "full": + env = AppleEnv(**kwargs) + elif task == "phase1": + env = ApplePhase0Env(**kwargs) + time_limit = time_limit // 2 + elif task == "phase2": + env = ApplePhase1Env(**kwargs) + time_limit = time_limit // 2 + else: + raise NotImplementedError + env = TimeLimit(env, time_limit) + env = SuccessCounter(env) + env.name = task + return env + + +class AppleEnv(gym.Env): + metadata = { + "render_modes": ["human", "ansi", "rgb_array"], + "render_fps": 4, + } + + 
def __init__( + self, + start_x: int, + goal_x: int, + c: float, + reward_value: float = 1.0, + success_value: float = 1.0, + bias_in_state: bool = True, + position_in_state: bool = False, + apple_in_state: bool = True, + render_mode: Optional[str] = None, + ): + self.start_x = start_x + self.goal_x = goal_x + self.c = c + self.reward_value = reward_value + self.success_value = success_value + + self.x = start_x + self.phase = 0 + self.delta = 1 + self.change_phase = True + self.init_pos = self.start_x + self.success_when_finish_phase = 1 + + self.bias_in_state = bias_in_state + self.position_in_state = position_in_state + self.apple_in_state = apple_in_state + + example_state = self.state() + mult = np.ones_like(example_state) + if self.apple_in_state: + mult[-1] *= -1 + self.observation_space = Box(low=example_state, high=example_state * mult) + self.action_space = gym.spaces.Discrete(2) + self.timestep = 0 + + self.render_mode = render_mode + self.scope_size = 15 + + self.nrow, self.ncol = nrow, ncol = self.gui_canvas().shape + self.window_size = (64 * ncol, 64 * nrow) + self.cell_size = ( + self.window_size[0] // self.ncol, + self.window_size[1] // self.nrow, + ) + self.window_surface = None + + self.clock = None + self.empty_img = None + self.ground_img = None + self.underground_img = None + self.elf_images = None + self.home_images = None + self.goal_img = None + self.start_img = None + self.stool_img = None + + def reset(self): + self.x = self.init_pos + if self.change_phase: + self.phase = 0 + + self.timestep = 0 + state = self.state() + self.lastaction = None + return state + + def validate_action(self, action): + err_msg = f"{action!r} ({type(action)}) invalid" + assert self.action_space.contains(action), err_msg + assert self.state is not None, "Call reset before using step method." 
+ + def move(self, action): + delta = self.delta if action == 1 else -self.delta + self.x += delta + + def state(self): + assert self.bias_in_state or self.position_in_state or self.apple_in_state + + if self.phase == 0: + c = -self.c + elif self.phase == 1: + c = self.c + else: + raise NotImplementedError + + state = [] + if self.bias_in_state: + state.append(1) + if self.position_in_state: + state.append(self.x) + if self.apple_in_state: + state.append(c) + + return np.array(state) + + def desc(self): + s = self.scope_size // 2 + + desc = list("." * self.scope_size) + desc = np.asarray(desc, dtype="c") + + start_relative_position = self.start_x - self.x + s + if 0 <= start_relative_position <= self.scope_size - 1: + desc[start_relative_position] = "S" + + goal_relative_position = self.goal_x - self.x + s + if 0 <= goal_relative_position <= self.scope_size - 1 and self.phase == 0: + desc[goal_relative_position] = "G" + + if 0 <= goal_relative_position <= self.scope_size - 1 and self.phase == 1: + desc[goal_relative_position] = "D" + + return desc + + def text_canvas(self): + desc = self.desc() + canvas = np.ones((2, len(desc) * 3 + 2), dtype="c") + canvas[:] = "\x20" + + for i, d in zip(range(2, len(canvas[0]), 3), desc): + canvas[0][i] = d + + axis = np.arange(len(desc)) - len(desc) // 2 + self.x + for i, d in zip(range(2, len(canvas[0]), 3), axis): + if d % 5 == 0: + s = str(d) + + c = len(s) // 2 + for j, char in zip(range(len(s)), reversed(s)): + canvas[1][i - j + c] = char + return canvas + + def gui_canvas(self): + desc = self.desc() + upper_canvas = np.ones(len(desc), dtype="c") + upper_canvas[:] = "~" + lower_canvas = np.ones(len(desc), dtype="c") + lower_canvas[:] = "#" + canvas = np.stack([upper_canvas, upper_canvas, desc, lower_canvas]) + + return canvas + + def reward(self, action): + return self.reward_value if action == self.get_target_action() else -self.reward_value + + def get_target_action(self): + return 1 if self.phase == 0 else 0 + + def 
step(self, action): + self.timestep += 1 + self.validate_action(action) + + done = False + info = {"success": False} + + reward = self.reward(action) + + self.move(action) + + finish_phase0 = self.phase == 0 and self.x >= self.goal_x + finish_phase1 = self.phase == 1 and self.x <= self.start_x + + if self.change_phase: + if finish_phase0: + self.phase = 1 + + if (self.success_when_finish_phase == 0 and finish_phase0) or ( + self.success_when_finish_phase == 1 and finish_phase1 + ): + done = True + info["success"] = True + reward = self.success_value + + state = self.state() + + self.lastaction = action + if self.render_mode == "human": + self.render() + + return state, reward, done, info + + def render(self): + if self.render_mode is None: + assert self.spec is not None + gym.logger.warn( + "You are calling render method without specifying any render mode. " + "You can specify the render_mode at initialization, " + f'e.g. gym.make("{self.spec.id}", render_mode="rgb_array")' + ) + return + + if self.render_mode == "ansi": + return self._render_text() + else: # self.render_mode in {"human", "rgb_array"}: + return self._render_gui(self.render_mode) + + def _render_gui(self, mode): + try: + import pygame + except ImportError as e: + raise DependencyNotInstalled("pygame is not installed, run `pip install pygame`") from e + + if self.window_surface is None: + pygame.init() + + if mode == "human": + pygame.display.init() + pygame.display.set_caption("Apple Retrieval") + self.window_surface = pygame.display.set_mode(self.window_size) + elif mode == "rgb_array": + self.window_surface = pygame.Surface(self.window_size) + + assert self.window_surface is not None, "Something went wrong with pygame. This should never happen." 
+ + if self.clock is None: + self.clock = pygame.time.Clock() + if self.empty_img is None: + file_name = path.join(path.dirname(__file__), "img/white.png") + self.empty_img = pygame.transform.scale(pygame.image.load(file_name), self.cell_size) + if self.ground_img is None: + file_name = path.join(path.dirname(__file__), "img/part_grass.png") + self.ground_img = pygame.transform.scale(pygame.image.load(file_name), self.cell_size) + if self.underground_img is None: + file_name = path.join(path.dirname(__file__), "img/g2.png") + self.underground_img = pygame.transform.scale(pygame.image.load(file_name), self.cell_size) + if self.goal_img is None: + file_name = path.join(path.dirname(__file__), "img/apple.png") + self.goal_img = pygame.transform.scale(pygame.image.load(file_name), self.cell_size) + if self.stool_img is None: + file_name = path.join(path.dirname(__file__), "img/stool.png") + self.stool_img = pygame.transform.scale(pygame.image.load(file_name), self.cell_size) + if self.start_img is None: + homes = [ + path.join(path.dirname(__file__), "img/home00.png"), + path.join(path.dirname(__file__), "img/home01.png"), + path.join(path.dirname(__file__), "img/home02.png"), + path.join(path.dirname(__file__), "img/home10.png"), + path.join(path.dirname(__file__), "img/home11.png"), + path.join(path.dirname(__file__), "img/home12.png"), + ] + self.home_images = [pygame.transform.scale(pygame.image.load(f_name), self.cell_size) for f_name in homes] + if self.elf_images is None: + elfs = [ + path.join(path.dirname(__file__), "img/elf_left.png"), + path.join(path.dirname(__file__), "img/elf_right.png"), + path.join(path.dirname(__file__), "img/elf_down.png"), + ] + self.elf_images = [pygame.transform.scale(pygame.image.load(f_name), self.cell_size) for f_name in elfs] + + desc = self.gui_canvas().tolist() + + cache = [] + assert isinstance(desc, list), f"desc should be a list or an array, got {desc}" + for y in range(self.nrow): + for x in range(self.ncol): + pos = (x * 
self.cell_size[0], y * self.cell_size[1]) + + self.window_surface.blit(self.empty_img, pos) + + if desc[y][x] == b"~": + self.window_surface.blit(self.empty_img, pos) + elif desc[y][x] == b"#": + self.window_surface.blit(self.underground_img, pos) + else: + self.window_surface.blit(self.ground_img, pos) + # if y == self.nrow - 1: + + if len(cache) > 0: + cache_img, cache_pos = cache.pop() + self.window_surface.blit(cache_img, cache_pos) + + if desc[y][x] == b"G": + self.window_surface.blit(self.stool_img, pos) + self.window_surface.blit(self.goal_img, pos) + elif desc[y][x] == b"D": + self.window_surface.blit(self.stool_img, pos) + elif desc[y][x] == b"S": + for h in range(len(self.home_images)): + i = h // 3 + j = h % 3 + + home_img = self.home_images[i * 3 + j] + home_pos = ((x - 1 + j) * self.cell_size[0], (y - 1 + i) * self.cell_size[1]) + if h == len(self.home_images) - 1: + cache.append((home_img, home_pos)) + else: + self.window_surface.blit(home_img, home_pos) + + # paint the elf + # bot_row, bot_col = self.s // self.ncol, self.s % self.ncol + bot_col = self.scope_size // 2 + bot_row = 2 + cell_rect = (bot_col * self.cell_size[0], bot_row * self.cell_size[1]) + last_action = self.lastaction if self.lastaction is not None else 2 + elf_img = self.elf_images[last_action] + + self.window_surface.blit(elf_img, cell_rect) + + # font = pygame.font.SysFont(None, 20) + # img = font.render(f"agent position = {self.x}", True, "black") + # self.window_surface.blit(img, (5, 5)) + # img = font.render(f"timestep = {self.timestep}", True, "black") + # self.window_surface.blit(img, (5, 25)) + + if mode == "human": + pygame.event.pump() + pygame.display.update() + self.clock.tick(self.metadata["render_fps"]) + elif mode == "rgb_array": + return np.transpose(np.array(pygame.surfarray.pixels3d(self.window_surface)), axes=(1, 0, 2)) + + def _render_text(self): + desc = self.text_canvas() + outfile = StringIO() + + row, col = 0, (self.scope_size // 2) * 3 + 2 + desc = 
[[c.decode("utf-8") for c in line] for line in desc] + desc[row][col] = colorize(desc[row][col], "red", highlight=True) + + outfile.write("\n") + outfile.write("\n".join("".join(line) for line in desc) + "\n") + + with closing(outfile): + return outfile.getvalue() + + +class ApplePhase0Env(AppleEnv): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.change_phase = False + self.phase = 0 + self.init_pos = self.start_x + self.success_when_finish_phase = 0 + + +class ApplePhase1Env(AppleEnv): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.change_phase = False + self.phase = 1 + self.init_pos = self.goal_x + self.success_when_finish_phase = 1 diff --git a/apple/envs/img/apple.png b/apple/envs/img/apple.png new file mode 100644 index 0000000000000000000000000000000000000000..f8e94216a10fa7b3dbc503a810e29adaaab95a86 Binary files /dev/null and b/apple/envs/img/apple.png differ diff --git a/apple/envs/img/elf_down.png b/apple/envs/img/elf_down.png new file mode 100644 index 0000000000000000000000000000000000000000..afa3daf1387c93fa81e34c0b0cc258fa23a5ee36 Binary files /dev/null and b/apple/envs/img/elf_down.png differ diff --git a/apple/envs/img/elf_left.png b/apple/envs/img/elf_left.png new file mode 100644 index 0000000000000000000000000000000000000000..bc9e22ea63286dda4a427770e582be96d76f25aa Binary files /dev/null and b/apple/envs/img/elf_left.png differ diff --git a/apple/envs/img/elf_right.png b/apple/envs/img/elf_right.png new file mode 100644 index 0000000000000000000000000000000000000000..836403158a586644dd5cb5d2894cdbb4a80d7035 Binary files /dev/null and b/apple/envs/img/elf_right.png differ diff --git a/apple/envs/img/g1.png b/apple/envs/img/g1.png new file mode 100644 index 0000000000000000000000000000000000000000..48885739b09897c5bab74a89bbb3fdd82e40d324 Binary files /dev/null and b/apple/envs/img/g1.png differ diff --git a/apple/envs/img/g2.png b/apple/envs/img/g2.png new file mode 100644 
index 0000000000000000000000000000000000000000..ef182990597bee6639219718cb4f4e609a0efdaf Binary files /dev/null and b/apple/envs/img/g2.png differ diff --git a/apple/envs/img/g3.png b/apple/envs/img/g3.png new file mode 100644 index 0000000000000000000000000000000000000000..f1d47914ec83bd9b0f61494843bb309d02848ae6 Binary files /dev/null and b/apple/envs/img/g3.png differ diff --git a/apple/envs/img/grass.jpg b/apple/envs/img/grass.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b5808b1ae987bd003cf3cefface41fcd83f24b6b Binary files /dev/null and b/apple/envs/img/grass.jpg differ diff --git a/apple/envs/img/home.png b/apple/envs/img/home.png new file mode 100644 index 0000000000000000000000000000000000000000..4199ea4770d0d3296c7c41b446b773a5799db0e2 Binary files /dev/null and b/apple/envs/img/home.png differ diff --git a/apple/envs/img/home00.png b/apple/envs/img/home00.png new file mode 100644 index 0000000000000000000000000000000000000000..fe8e3aee3877b6006adde0ca4e43d28a7df038e1 Binary files /dev/null and b/apple/envs/img/home00.png differ diff --git a/apple/envs/img/home01.png b/apple/envs/img/home01.png new file mode 100644 index 0000000000000000000000000000000000000000..5813d7ae0d1e8cac116a38de1b1d149c71c0e887 Binary files /dev/null and b/apple/envs/img/home01.png differ diff --git a/apple/envs/img/home02.png b/apple/envs/img/home02.png new file mode 100644 index 0000000000000000000000000000000000000000..cc0eeff954fc49f6105ccdf3b0b13052df6e9199 Binary files /dev/null and b/apple/envs/img/home02.png differ diff --git a/apple/envs/img/home10.png b/apple/envs/img/home10.png new file mode 100644 index 0000000000000000000000000000000000000000..756cbcbd4c7ca8500f6f069f12272b14b2cdf485 Binary files /dev/null and b/apple/envs/img/home10.png differ diff --git a/apple/envs/img/home11.png b/apple/envs/img/home11.png new file mode 100644 index 0000000000000000000000000000000000000000..62f7bc53e9a49510f32f364de4db41d3d61ab055 Binary files /dev/null 
and b/apple/envs/img/home11.png differ diff --git a/apple/envs/img/home12.png b/apple/envs/img/home12.png new file mode 100644 index 0000000000000000000000000000000000000000..84c372a7026d304c18679ca2d6a3ad9e08542f96 Binary files /dev/null and b/apple/envs/img/home12.png differ diff --git a/apple/envs/img/home2.png b/apple/envs/img/home2.png new file mode 100644 index 0000000000000000000000000000000000000000..2e9878a15ffd64f236eee5bdfc861bedd552913b Binary files /dev/null and b/apple/envs/img/home2.png differ diff --git a/apple/envs/img/home2_with_apples.png b/apple/envs/img/home2_with_apples.png new file mode 100644 index 0000000000000000000000000000000000000000..84ba63bbb0adcdf136aa1777e4cf63f7dafa273a Binary files /dev/null and b/apple/envs/img/home2_with_apples.png differ diff --git a/apple/envs/img/home_grass.png b/apple/envs/img/home_grass.png new file mode 100644 index 0000000000000000000000000000000000000000..41232134663cb1b9ac8058e80cfd7c9703417a9a Binary files /dev/null and b/apple/envs/img/home_grass.png differ diff --git a/apple/envs/img/part_grass.png b/apple/envs/img/part_grass.png new file mode 100644 index 0000000000000000000000000000000000000000..86fb6d1684f82ee8041749ad2a324dafaa574cad Binary files /dev/null and b/apple/envs/img/part_grass.png differ diff --git a/apple/envs/img/stool.png b/apple/envs/img/stool.png new file mode 100644 index 0000000000000000000000000000000000000000..c8db2d77e16d05f6dc44f589f6464093213c1c00 Binary files /dev/null and b/apple/envs/img/stool.png differ diff --git a/apple/envs/img/textures.jpg b/apple/envs/img/textures.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5b541daa41f0f6df27fc15fc2dd93d3b574f2c35 Binary files /dev/null and b/apple/envs/img/textures.jpg differ diff --git a/apple/envs/img/white.png b/apple/envs/img/white.png new file mode 100644 index 0000000000000000000000000000000000000000..54aae99d335866eaf325bfd4159070ccfd85c72d Binary files /dev/null and b/apple/envs/img/white.png 
# --- apple/evaluation/render_episode.py ---------------------------------------
def render_episode(env, model):
    """Roll out one episode of ``model`` in ``env``.

    Yields one dict per environment step containing the pre-step observation,
    the action taken, the reward, the done flag, running episode statistics
    and a rendered frame (``pixel_array``).
    """
    obs, done, episode_return, episode_len = env.reset(), False, 0, 0

    while not done:
        action = model.get_action(obs)
        new_obs, reward, done, _ = env.step(action)
        episode_return += reward
        episode_len += 1

        data = dict(
            x=env.x,
            obs=obs,
            action=action,
            reward=reward,
            done=done,
            episode_len=episode_len,
            episode_return=episode_return,
            pixel_array=env.unwrapped.render(),
        )
        yield data

        obs = new_obs


# --- apple/logger.py ----------------------------------------------------------
# Some simple logging functionality, inspired by rllab's logging.
# Logs to a delimiter-separated-values file (output_dir/progress.csv).
#
# NOTE: ``wandb`` and ``joblib`` are imported lazily inside the methods that
# need them, so the module is importable without those optional dependencies.
import atexit
import os
import os.path as osp
import time
import warnings

import numpy as np
import torch

# ANSI color codes used by ``colorize``.
color2num = dict(gray=30, red=31, green=32, yellow=33, blue=34, magenta=35, cyan=36, white=37, crimson=38)


def setup_logger_kwargs(exp_name, seed=None, data_dir=None, datestamp=True):
    """Build the ``output_dir``/``exp_name`` kwargs for a :class:`Logger`.

    Layout (mirrors spinup's convention):
        no seed, no datestamp:  data_dir/exp_name
        seed, no datestamp:     data_dir/exp_name/exp_name_s[seed]
        datestamp:              data_dir/YY-MM-DD_exp_name/YY-MM-DD_HH-MM-SS-exp_name_s[seed]

    Args:
        exp_name (str): Name for the experiment.
        seed (int | None): Seed used by the experiment, appended to the path.
        data_dir (str | None): Root folder for results; defaults to a ``logs``
            folder three levels above this file.
        datestamp (bool): Whether to embed date/timestamps in the path.

    Returns:
        dict with keys ``output_dir`` and ``exp_name``.
    """
    if data_dir is None:
        data_dir = osp.join(osp.abspath(osp.dirname(osp.dirname(osp.dirname(__file__)))), "logs")

    # Base path: optional date prefix + experiment name.
    ymd_time = time.strftime("%Y-%m-%d_") if datestamp else ""
    relpath = "".join([ymd_time, exp_name])

    if seed is not None:
        # Seed-specific subfolder inside the experiment directory.
        if datestamp:
            hms_time = time.strftime("%Y-%m-%d_%H-%M-%S")
            subfolder = "".join([hms_time, "-", exp_name, "_s", str(seed)])
        else:
            subfolder = "".join([exp_name, "_s", str(seed)])
        relpath = osp.join(relpath, subfolder)

    logger_kwargs = dict(output_dir=osp.join(data_dir, relpath), exp_name=exp_name)
    return logger_kwargs


def colorize(string, color, bold=False, highlight=False):
    """Wrap ``string`` in ANSI escape codes for terminal color.

    This function was originally written by John Schulman.
    """
    attr = []
    num = color2num[color]
    if highlight:
        num += 10  # background variant of the color
    attr.append(str(num))
    if bold:
        attr.append("1")
    return "\x1b[%sm%s\x1b[0m" % (";".join(attr), string)


class Logger:
    """A general-purpose logger.

    Makes it easy to save diagnostics, hyperparameter configurations, the
    state of a training run, and the trained model.
    """

    def __init__(
        self,
        log_to_wandb=False,
        verbose=False,
        output_dir=None,
        output_fname="progress.csv",
        delimeter=",",
        exp_name=None,
        wandbcommit=1,
    ):
        """Initialize a Logger.

        Args:
            log_to_wandb (bool): If True, also forward rows to wandb.
            verbose (bool): If True, pretty-print each row to stdout.
            output_dir (str | None): Directory for the progress file; when
                ``None`` no file is written at all.
            output_fname (str): Name of the delimiter-separated metrics file.
            delimeter (str): Field separator in ``output_fname``.
                (Spelling kept for backward compatibility with callers.)
            exp_name (str | None): Experiment name; runs sharing a name are
                grouped by downstream plotting tools.
            wandbcommit (int): Commit to wandb every this-many dumped rows.
        """
        self.verbose = verbose
        self.log_to_wandb = log_to_wandb
        self.delimeter = delimeter
        self.wandbcommit = wandbcommit
        self.log_iter = 1
        # We assume that there's no multiprocessing.
        if output_dir is not None:
            self.output_dir = output_dir
            if osp.exists(self.output_dir):
                print("Warning: Log dir %s already exists! Storing info there anyway." % self.output_dir)
            else:
                os.makedirs(self.output_dir)
            # "w+" because log_tabular rewrites the header in place when a new
            # key appears after the first row.
            self.output_file = open(osp.join(self.output_dir, output_fname), "w+")
            atexit.register(self.output_file.close)
            print(colorize("Logging data to %s" % self.output_file.name, "green", bold=True))
        else:
            # BUG FIX: previously ``self.output_dir`` was never set on this
            # path, so save_state()/_pytorch_simple_save() raised
            # AttributeError for file-less loggers.
            self.output_dir = None
            self.output_file = None

        self.first_row = True
        self.log_headers = []
        self.log_current_row = {}
        self.exp_name = exp_name

    def log(self, msg, color="green"):
        """Print a colorized message to stdout."""
        print(colorize(msg, color, bold=True))

    def log_tabular(self, key, val):
        """Log one value of some diagnostic for the current iteration.

        Call this only once per key per iteration, then ``dump_tabular`` to
        flush the row. Keys introduced after the first row trigger an in-place
        rewrite of the CSV header so old rows stay aligned.
        """
        if key not in self.log_headers:
            self.log_headers.append(key)
            if not self.first_row and self.output_file is not None:
                # A new column appeared mid-run: rewrite the header line while
                # preserving every data row already written.
                self.output_file.seek(0)
                self.output_file.readline()  # skip old header
                logs = self.output_file.read()  # keep existing data rows
                self.output_file.truncate(0)
                self.output_file.seek(0)
                self.output_file.write(self.delimeter.join(self.log_headers) + "\n")
                self.output_file.write(logs)
                self.output_file.seek(0, 2)  # back to end-of-file for appends
        assert key not in self.log_current_row, (
            "You already set %s this iteration. Maybe you forgot to call dump_tabular()" % key
        )
        self.log_current_row[key] = val

    def save_state(self, state_dict, itr=None):
        """Save the *state* of an experiment (not diagnostics).

        Pickles ``state_dict`` (usually a copy of the environment) and, if
        ``setup_pytorch_saver`` was called, the registered PyTorch objects.

        Args:
            state_dict (dict): Essential elements describing training state.
            itr (int | None): If None, overwrite a single snapshot; otherwise
                keep one snapshot per (increasing) iteration number.
        """
        import joblib  # local import: optional dependency

        fname = "vars.pkl" if itr is None else "vars%d.pkl" % itr
        try:
            joblib.dump(state_dict, osp.join(self.output_dir, fname))
        except Exception:  # was a bare except; never swallow KeyboardInterrupt
            self.log("Warning: could not pickle state_dict.", color="red")
        if hasattr(self, "pytorch_saver_elements"):
            self._pytorch_simple_save(itr)

    def setup_pytorch_saver(self, what_to_save):
        """Register a PyTorch model (or container of models) for saving.

        Args:
            what_to_save: Any PyTorch model or serializable object containing
                PyTorch models; pickled wholesale by ``save_state``.
        """
        self.pytorch_saver_elements = what_to_save

    def _pytorch_simple_save(self, itr=None):
        """Save the registered PyTorch model(s) under ``output_dir/pyt_save``."""
        assert hasattr(self, "pytorch_saver_elements"), "First have to setup saving with self.setup_pytorch_saver"
        fpath = osp.join(self.output_dir, "pyt_save")
        fname = osp.join(fpath, "model" + ("%d" % itr if itr is not None else "") + ".pt")
        os.makedirs(fpath, exist_ok=True)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Pickling whole model objects (not just state_dicts) is fragile
            # w.r.t. directory layout, but sufficient here; the warning filter
            # silences torch's source-code-serialization complaints.
            torch.save(self.pytorch_saver_elements, fname)

    def dump_tabular(self):
        """Flush the current row to stdout (if verbose), file, and wandb.

        Returns:
            dict mapping every known header to this iteration's value
            (empty string for headers not set this iteration).
        """
        key_lens = [len(key) for key in self.log_headers]
        max_key_len = max(15, max(key_lens))
        keystr = "%" + "%d" % max_key_len
        fmt = "| " + keystr + "s | %15s |"
        n_slashes = 22 + max_key_len
        step = self.log_current_row.get("total_env_steps")

        # BUG FIX: collect the row values unconditionally. Previously they
        # were only collected inside the verbose print loop, so non-verbose
        # runs wrote empty CSV rows.
        vals = [self.log_current_row.get(key, "") for key in self.log_headers]

        if self.verbose:
            print("-" * n_slashes)
            for key, val in zip(self.log_headers, vals):
                valstr = "%8.3g" % val if isinstance(val, float) else val
                print(fmt % (key, valstr))
            print("-" * n_slashes, flush=True)

        if self.output_file is not None:
            if self.first_row:
                self.output_file.write(self.delimeter.join(self.log_headers) + "\n")
            self.output_file.write(self.delimeter.join(map(str, vals)) + "\n")
            self.output_file.flush()

        key_val_dict = {key: self.log_current_row.get(key, "") for key in self.log_headers}
        if self.log_to_wandb:
            import wandb  # local import: optional dependency

            # Commit only every ``wandbcommit``-th row to reduce API traffic.
            wandb.log(key_val_dict, step=step, commit=(self.log_iter % self.wandbcommit == 0))

        self.log_current_row.clear()
        self.first_row = False
        self.log_iter += 1

        return key_val_dict


class EpochLogger(Logger):
    """A variant of Logger tailored for tracking average values over epochs.

    Use ``store(NameOfQuantity=value)`` each time a quantity is computed, then
    ``log_tabular(NameOfQuantity, **options)`` at epoch end to record its
    mean/std/min/max/median/sum.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # key -> list of values accumulated since the last log_tabular(key).
        self.epoch_dict = dict()

    def store(self, d):
        """Append each ``{key: value}`` in ``d`` to the epoch accumulator."""
        for k, v in d.items():
            if k not in self.epoch_dict:
                self.epoch_dict[k] = []
            self.epoch_dict[k].append(v)

    def log_tabular(self, key, val=None, with_min_and_max=False, with_median=False, with_sum=False, average_only=False):
        """Log a value, or summary statistics of previously ``store``-d values.

        Args:
            key (str): Diagnostic name (must match the key used in ``store``).
            val: Direct value; mutually exclusive with stored values.
            with_min_and_max (bool): Also log min and max over the epoch.
            with_median (bool): Also log the median over the epoch.
            with_sum (bool): Also log the sum over the epoch.
            average_only (bool): Log only the mean (no std, plain key name).
        """
        if val is not None:
            super().log_tabular(key, val)
        else:
            # stats layout: [mean, std, min, max, median, sum]
            stats = self.get_stats(key)
            super().log_tabular(key if average_only else key + "/avg", stats[0])
            if not average_only:
                super().log_tabular(key + "/std", stats[1])
            if with_min_and_max:
                super().log_tabular(key + "/max", stats[3])
                super().log_tabular(key + "/min", stats[2])
            if with_median:
                super().log_tabular(key + "/med", stats[4])
            if with_sum:
                super().log_tabular(key + "/sum", stats[5])

        self.epoch_dict[key] = []

    def get_stats(self, key):
        """Return ``[mean, std, min, max, median, sum]`` of stored values.

        Returns six NaNs when nothing was stored under ``key``.
        """
        v = self.epoch_dict.get(key)
        if not v:
            # BUG FIX: was 4 NaNs, but log_tabular indexes stats[4]/stats[5]
            # for with_median/with_sum, which raised IndexError on empty keys.
            return [np.nan] * 6
        vals = np.concatenate(v) if isinstance(v[0], np.ndarray) and len(v[0].shape) > 0 else v
        return [np.mean(vals), np.std(vals), np.min(vals), np.max(vals), np.median(vals), np.sum(vals)]


# --- apple/models/categorical_policy.py ---------------------------------------
import torch.nn as nn

from torch.distributions import Categorical


class CategoricalPolicy(nn.Module):
    """Two-action categorical policy backed by a bias-free linear layer.

    ``forward`` squashes the single logit through a sigmoid into ``p`` and
    returns the action probabilities ``[p, 1 - p]``.
    """

    def __init__(self, state_dim, act_dim, weight1=None, weight2=None):
        super().__init__()
        self.model = nn.Linear(state_dim, act_dim, bias=False)

        # Optional deterministic initialization of the first two weights
        # (used by the tiny hand-analyzable experiments in this repo).
        if weight1 is not None:
            nn.init.constant_(self.model.weight[0][0], weight1)

        if weight2 is not None:
            nn.init.constant_(self.model.weight[0][1], weight2)

    def forward(self, state):
        # assumes ``state`` is a 1-D numpy array — TODO confirm against callers
        x = torch.from_numpy(state).float().unsqueeze(0)
        x = self.model(x)
        # we just consider 1 dimensional probability of action
        p = torch.sigmoid(x)
        return torch.cat([p, 1 - p], dim=1)

    def act(self, state):
        """Sample an action for ``state``; returns (action, log_prob)."""
        probs = self.forward(state)
        dist = Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action)

    def sample(self, probs):
        """Sample from precomputed ``probs``; returns (action, log_prob)."""
        dist = Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action)

    def log_prob(self, probs, target_action):
        """Sample an action but return the log-prob of ``target_action``."""
        dist = Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(target_action)

    @torch.no_grad()
    def get_action(self, state):
        """Sample an action without tracking gradients (for evaluation)."""
        probs = self.forward(state)
        dist = Categorical(probs)
        action = dist.sample()
        return action.item()
# --- apple/training/trainer.py ------------------------------------------------
import numpy as np
import torch


class Trainer:
    """Base trainer: supervised (behavioral-cloning style) updates toward the
    environment's target action, plus shared evaluation/logging helpers."""

    def __init__(self, model, optim, logger):
        self.model = model
        self.optim = optim
        self.logger = logger
        self.train_it = 0  # total training iterations (steps or episodes) so far

    def test_agent(self, model, logger, test_envs, num_episodes):
        """Run ``num_episodes`` evaluation episodes on every env in ``test_envs``
        and log per-env return/length/success plus the overall success rate."""
        avg_success = []

        for seq_idx, test_env in enumerate(test_envs):
            key_prefix = f"{test_env.name}/"

            for j in range(num_episodes):
                obs, done, episode_return, episode_len = test_env.reset(), False, 0, 0

                while not done:
                    action = model.get_action(obs)
                    obs, reward, done, _ = test_env.step(action)
                    episode_return += reward
                    episode_len += 1
                logger.store({key_prefix + "return": episode_return, key_prefix + "ep_length": episode_len})

            logger.log_tabular(key_prefix + "return", with_min_and_max=True)
            logger.log_tabular(key_prefix + "ep_length", average_only=True)
            # pop_successes() drains the env's success counter (SuccessCounter wrapper).
            env_success = test_env.pop_successes()
            avg_success += env_success
            logger.log_tabular(key_prefix + "success", np.mean(env_success))

        logger.log_tabular("average_success", np.mean(avg_success))

    def log(self, logger, step, model):
        """Dump one row of training diagnostics; returns the logged row dict."""
        logger.log_tabular("total_env_steps", step)

        logger.log_tabular("train/loss", average_only=True)
        logger.log_tabular("train/action", average_only=True)

        # Log each scalar weight of the (tiny, linear) policy individually.
        for e, w in enumerate(model.model.weight.flatten()):
            logger.log_tabular(f"weight{e}", w.item())

        return logger.dump_tabular()

    def update(self, env, probs, model, optim, logger):
        """One gradient step maximizing the log-probability of the env's
        target (expert) action under the current policy distribution."""
        target = torch.as_tensor([env.get_target_action()], dtype=torch.float32)
        action, log_prob = model.log_prob(probs, target)

        optim.zero_grad()
        loss = -torch.mean(log_prob)
        loss.backward()
        optim.step()

        logger.store({"train/action": action})
        logger.store({"train/loss": loss.item()})

    def train(self, env, test_envs, steps, log_every, num_eval_eps):
        """Behavioral-cloning training loop for ``steps`` environment steps."""
        obs = env.reset()
        for timestep in range(steps):
            self.train_it += 1

            if (timestep + 1) % log_every == 0:
                self.test_agent(self.model, self.logger, test_envs, num_eval_eps)
                self.log(self.logger, self.train_it, self.model)

            output = self.model(obs)
            action, _ = self.model.sample(output)  # log_prob unused here
            self.update(env, output, self.model, self.optim, self.logger)

            obs, reward, done, info = env.step(action)

            if done:
                obs = env.reset()


# --- apple/training/reinforce_trainer.py --------------------------------------
# (originally ``from apple.training.trainer import Trainer`` — defined above)


def discount_cumsum(x, gamma):
    """Discounted cumulative sums along dim 0:

        out[t] = x[t] + gamma * x[t+1] + gamma**2 * x[t+2] + ...
    """
    out = torch.zeros_like(x)
    out[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out


class ReinforceTrainer(Trainer):
    """REINFORCE trainer with gradient accumulation over ``update_every`` episodes.

    Loosely based on
    https://goodboychan.github.io/python/reinforcement_learning/pytorch/udacity/2021/05/12/REINFORCE-CartPole.html
    """

    def __init__(self, *args, gamma: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.gamma = gamma  # reward discount factor

    def train(self, env, test_envs, num_episodes, log_every, update_every, num_eval_eps):
        self.optim.zero_grad()
        for episode in range(num_episodes):
            self.train_it += 1

            if (episode + 1) % log_every == 0:

                self.test_agent(self.model, self.logger, test_envs, num_eval_eps)

                # Log info about epoch
                self.logger.log_tabular("total_env_steps", self.train_it)
                self.logger.log_tabular("train/return", with_min_and_max=True)
                self.logger.log_tabular("train/ep_length", average_only=True)

                for e, w in enumerate(self.model.model.weight.flatten()):
                    self.logger.log_tabular(f"weights{e}", w.item())

                self.logger.log_tabular("train/policy_loss", average_only=True)
                self.logger.log_tabular("train/log_probs", average_only=True)
                self.logger.dump_tabular()

            state = env.reset()

            saved_log_probs = []
            rewards = []
            ep_len, ep_ret = 0, 0
            while True:
                # Sample the action from current policy
                action, log_prob = self.model.act(state)
                saved_log_probs.append(log_prob)
                state, reward, done, _ = env.step(action)
                ep_ret += reward
                ep_len += 1

                rewards.append(reward)

                if done:
                    self.logger.store({"train/return": ep_ret, "train/ep_length": ep_len})
                    break

            saved_log_probs, rewards = torch.cat(saved_log_probs), torch.tensor(rewards)

            discounted_rewards = discount_cumsum(rewards, gamma=self.gamma)
            # Gradient ascent on expected return == descent on negative return.
            policy_loss = (-discounted_rewards * saved_log_probs).sum()

            # Accumulate gradients every episode; apply them every
            # ``update_every`` episodes. BUG FIX: the original called
            # ``zero_grad()`` immediately before ``step()`` on update
            # episodes, discarding everything accumulated since the last
            # update (defeating the point of ``update_every``).
            policy_loss.backward()
            if (episode + 1) % update_every == 0:
                self.optim.step()
                self.optim.zero_grad()

            self.logger.store({"train/policy_loss": policy_loss.item()})
            self.logger.store({"train/log_probs": saved_log_probs.mean().item()})


# --- apple/utils.py -----------------------------------------------------------
import argparse
import random

from typing import Union


# https://stackoverflow.com/a/43357954/6365092
def str2bool(v: Union[bool, str]) -> bool:
    """Parse common true/false spellings; intended as an argparse ``type``."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")


def set_seed(seed):
    """Seed Python, NumPy and torch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
+) + +import gym +import numpy as np + + +class SuccessCounter(gym.Wrapper): + """Helper class to keep count of successes in MetaWorld environments.""" + + def __init__(self, env: gym.Env) -> None: + super().__init__(env) + self.successes = [] + self.current_success = False + + def step(self, action: Any) -> Tuple[np.ndarray, float, bool, Dict]: + obs, reward, done, info = self.env.step(action) + if info.get("success", False): + self.current_success = True + if done: + self.successes.append(self.current_success) + return obs, reward, done, info + + def pop_successes(self) -> List[bool]: + res = self.successes + self.successes = [] + return res + + def reset(self, **kwargs) -> np.ndarray: + self.current_success = False + return self.env.reset(**kwargs) diff --git a/assets/apple_env.png b/assets/apple_env.png new file mode 100644 index 0000000000000000000000000000000000000000..f4126a0f62d8491b374ee35efcef69d575ba943b Binary files /dev/null and b/assets/apple_env.png differ diff --git a/assets/example_rollout.mp4 b/assets/example_rollout.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..645c6b620ddc9f129b3e5f67e3f937ce392749d3 Binary files /dev/null and b/assets/example_rollout.mp4 differ diff --git a/assets/generate_example_rollout.py b/assets/generate_example_rollout.py new file mode 100644 index 0000000000000000000000000000000000000000..fc1f0b5aea3d22377d08aaa2a07a1cfa1e184c90 --- /dev/null +++ b/assets/generate_example_rollout.py @@ -0,0 +1,30 @@ +import numpy as np +import skvideo.io + +from apple.envs.discrete_apple import get_apple_env + +env = get_apple_env("full", time_limit=10, start_x=0, c=0.5, goal_x=8, render_mode="rgb_array") + + +imgs = [] +env.reset() +for i in range(8): + imgs.append(env.unwrapped.render()) + env.step(1) + +for i in range(9): + imgs.append(env.unwrapped.render()) + env.step(0) + + +skvideo.io.vwrite( + "example_rollout.mp4", + np.stack(imgs), + inputdict={ + "-r": str(int(4)), + }, + outputdict={ + "-f": "mp4", 
+ "-pix_fmt": "yuv420p", # '-pix_fmt=yuv420p' needed for osx https://github.com/scikit-video/scikit-video/issues/74 + }, +) diff --git a/input_args.py b/input_args.py new file mode 100644 index 0000000000000000000000000000000000000000..b719b48947989930d35bfe686f1e0bfe1a00757d --- /dev/null +++ b/input_args.py @@ -0,0 +1,17 @@ +import argparse + +from apple.utils import str2bool + + +def apple_parse_args(args=None): + parser = argparse.ArgumentParser() + + parser.add_argument("--c", type=float, default=0.25, required=False) + parser.add_argument("--start_x", type=float, default=0.0, required=False) + parser.add_argument("--goal_x", type=float, default=10.0, required=False) + parser.add_argument("--time_limit", type=int, default=100.0, required=False) + + parser.add_argument("--lr", type=float, default=1e-3, required=False) + parser.add_argument("--log_to_wandb", type=str2bool, default=True, required=False) + + return parser.parse_known_args(args=args)[0] diff --git a/mrunner_exps/behavioral_cloning.py b/mrunner_exps/behavioral_cloning.py new file mode 100644 index 0000000000000000000000000000000000000000..3ec5d1fe0a6c5b421c592720d359372842d8c15a --- /dev/null +++ b/mrunner_exps/behavioral_cloning.py @@ -0,0 +1,51 @@ +import numpy as np + +from mrunner.helpers.specification_helper import create_experiments_helper + +from mrunner_exps.utils import combine_config_with_defaults + +name = globals()["script"][:-3] + +# params for all exps +config = { + "exp_tag": "behavioral_cloning", + "run_kind": "bc", + "log_to_wandb": True, + "pretrain_steps": 200, + "steps": 200, + "log_every": 1, + "num_eval_eps": 10, + "verbose": False, + "lr": 0.01, + "c": 0.5, + "start_x": 0.0, + "goal_x": 50.0, + "bias_in_state": True, + "position_in_state": False, + "time_limit": 100, + "wandbcommit": 100, + "pretrain": "phase2", + "finetune": "full", +} +config = combine_config_with_defaults(config) + +# params different between exps +params_grid = [ + { + "seed": list(range(10)), + "c": 
import numpy as np

from mrunner.helpers.specification_helper import create_experiments_helper

from mrunner_exps.utils import combine_config_with_defaults

# "script" is not defined in this file; presumably the mrunner launcher
# injects the running spec's filename into globals() — TODO confirm.
# Dropping the trailing ".py" yields the experiment name.
name = globals()["script"][:-3]

# params for all exps
config = {
    "exp_tag": "reinforce_goal_c3",
    "run_kind": "reinforce",  # selects the REINFORCE code path in run.py
    "log_to_wandb": True,
    "pretrain_steps": 1000,
    "steps": 2000,
    "log_every": 1,
    "num_eval_eps": 10,
    "verbose": False,
    "lr": 0.001,
    "c": 1.0,
    "start_x": 0.0,
    "goal_x": 50.0,
    "bias_in_state": True,
    "position_in_state": False,
    "time_limit": 100,
    "gamma": 0.99,
    "wandbcommit": 1000,
    "pretrain": "phase2",
    "finetune": "full",
    "update_every": 10,  # for good definition on gradient
}
# Fill any keys not listed above with the arg-parser defaults for this run kind.
config = combine_config_with_defaults(config)

# params different between exps: a 10-seed sweep over the step cost `c` and the
# goal distance `goal_x`; these grid values override the base values in config.
params_grid = [
    {
        "seed": list(range(10)),
        "c": list(np.arange(0.1, 1.1, 0.1)),
        "goal_x": list(np.arange(5, 50, 5)),
    }
]

experiments_list = create_experiments_helper(
    experiment_name=name,
    project_name="apple",
    with_neptune=False,
    script="python3 mrunner_run.py",
    python_path=".",
    tags=[name],
    exclude=["logs", "wandb"],
    base_config=config,
    params_grid=params_grid,
)
def combine_config_with_defaults(config):
    """Overlay ``config`` on the default arguments of its run kind's parser.

    Looks up the argument parser registered for ``config["run_kind"]`` in
    ``PARSE_ARGS_DICT``, parses an empty argv to obtain the defaults, and
    returns a dict in which explicit ``config`` entries win over defaults.
    """
    parse_args = PARSE_ARGS_DICT[config["run_kind"]]
    defaults = vars(parse_args([]))
    return {**defaults, **config}
+ +mrunner --config ~/.mrunner.yaml --context eagle_transfer_mw2 run mrunner_exps/behavioral_cloning.py +# mrunner --config ~/.mrunner.yaml --context eagle_transfer_mw2 run mrunner_exps/reinforce.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..a42211ee23eb1e1f8d873e0cf93bdb228ea49dbe --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,92 @@ +[build-system] +requires = ["setuptools", "setuptools_scm", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "apple" +description = "Simplest experiment for showing forgetting" +license = { text = "Proprietary" } +authors = [{name = "BartekCupial", email = "bartlomiej.cupial@student.uj.edu.pl" }] + +dynamic = ["version"] + +requires-python = ">= 3.8, < 3.11" + +dependencies = [ + "numpy ~= 1.23", + "typing-extensions ~= 4.3", + "gym == 0.23", + "torch ~= 1.12", + "wandb ~= 0.13", + "pandas ~= 1.5", + "matplotlib ~= 3.6", + "seaborn ~= 0.12", + "scipy ~= 1.9", + "joblib ~= 1.2", + "pygame ~= 2.1", +] + +[project.optional-dependencies] +build = ["build ~= 0.8"] +mrunner = ["mrunner @ git+https://gitlab.com/awarelab/mrunner.git"] +lint = [ + "black ~= 22.6", + "autoflake ~= 1.4", + "flake8 ~= 4.0", + "flake8-pyi ~= 22.5", + "flake8-docstrings ~= 1.6", + "pyproject-flake8 ~= 0.0.1a4", + "isort ~= 5.10", + "pre-commit ~= 2.20", +] +test = [ + "pytest ~= 7.1", + "pytest-cases ~= 3.6", + "pytest-cov ~= 3.0", + "pytest-xdist ~= 2.5", + "pytest-sugar ~= 0.9", + "hypothesis ~= 6.54", +] +dev = [ + "apple[mrunner]", + "apple[build]", + "apple[lint]", + "apple[test]", +] + +[project.urls] +"Source" = "https://github.com/BartekCupial/apple" + +[tool.black] +line_length = 120 + +[tool.flake8] +extend_exclude = [".venv/", "build/", "dist/", "docs/"] +per_file_ignores = ["**/_[a-z]*.py:D", "tests/*.py:D", "*.pyi:D"] +ignore = [ + # Handled by black + "E", # pycodestyle + "W", # pycodestyle + "D", +] +ignore_decorators = "property" # 
def main(
    run_kind: str,
    c: float = 1.0,
    start_x: float = 0.0,
    goal_x: float = 50.0,
    time_limit: int = 200,
    bias_in_state: bool = True,
    position_in_state: bool = False,
    apple_in_state: bool = True,
    lr: float = 1e-3,
    pretrain_steps: int = 0,
    steps: int = 10000,
    log_every: int = 1,
    num_eval_eps: int = 1,
    pretrain: str = "phase1",
    finetune: str = "full",
    log_to_wandb: bool = False,
    wandbcommit: int = 1,
    verbose: bool = False,
    output_dir="logs/apple",
    gamma: float = 1.0,
    update_every: int = 10,
    seed=0,
    **kwargs,
):
    """Run a two-phase apple experiment: pretrain on one task, then finetune.

    Parameters (selection)
    ----------------------
    run_kind : ``"reinforce"`` or ``"bc"`` — selects the trainer class.
    c, start_x, goal_x, time_limit : environment geometry / episode length.
    pretrain, finetune : task names (e.g. "phase1", "phase2", "full") used to
        build the pretraining and finetuning environments.
    kwargs : ignored — lets launchers pass extra config keys harmlessly.

    Raises
    ------
    ValueError
        If ``run_kind`` is not one of the supported kinds.
    """
    set_seed(seed)

    logger = EpochLogger(
        exp_name=run_kind,
        output_dir=output_dir,
        log_to_wandb=log_to_wandb,
        wandbcommit=wandbcommit,
        verbose=verbose,
    )

    env_kwargs = dict(
        start_x=start_x,
        goal_x=goal_x,
        c=c,
        time_limit=time_limit,
        bias_in_state=bias_in_state,
        position_in_state=position_in_state,
        apple_in_state=apple_in_state,
    )

    env_phase1 = get_apple_env(pretrain, **env_kwargs)
    env_phase2 = get_apple_env(finetune, **env_kwargs)
    # Evaluate on all tasks so forgetting of phase1 is visible during phase2.
    test_envs = [get_apple_env(task, **env_kwargs) for task in ["full", "phase1", "phase2"]]

    model = CategoricalPolicy(env_phase1.observation_space.shape[0], 1)
    optim = torch.optim.SGD(model.parameters(), lr=lr)

    if run_kind == "reinforce":
        trainer = ReinforceTrainer(model, optim, logger, gamma=gamma)
        trainer.train(env_phase1, test_envs, pretrain_steps, log_every, update_every, num_eval_eps)
        trainer.train(env_phase2, test_envs, steps, log_every, update_every, num_eval_eps)
    elif run_kind == "bc":
        trainer = Trainer(model, optim, logger)
        trainer.train(env_phase1, test_envs, pretrain_steps, log_every, num_eval_eps)
        trainer.train(env_phase2, test_envs, steps, log_every, num_eval_eps)
    else:
        # BUG FIX: an unsupported run_kind previously fell through silently,
        # producing a "run" that trained nothing. Fail loudly instead.
        raise ValueError(f"unknown run_kind {run_kind!r}; expected 'reinforce' or 'bc'")
def test_discrete_apple_phase1():
    """Phase1: bias observation is -c; reward is +1 for action 1, -1 for 0."""
    c = 0.5
    timelimit = 30
    env = get_apple_env("phase1", start_x=0, goal_x=10, c=c, time_limit=timelimit)

    observations, actions, rewards = [], [], []
    env.reset()
    for _ in range(timelimit):
        # Mostly move right so the goal is usually reached within the limit.
        action = np.random.choice([0, 1], p=[0.2, 0.8])
        obs, reward, done, info = env.step(action)
        observations.append(obs)
        actions.append(action)
        rewards.append(reward)
        if done:
            break

    actions = np.asarray(actions)
    n = len(actions)

    expected_rewards = actions * 2.0 - 1.0
    if info["success"]:
        expected_rewards[-1] = 100  # terminal bonus on reaching the goal
    expected_obs = np.stack([np.ones(n), np.full(n, -c)], axis=1)

    assert (np.asarray(rewards) == expected_rewards).all()
    assert (np.asarray(observations) == expected_obs).all()


def test_discrete_apple_phase2():
    """Phase2: bias observation is +c; reward is +1 for action 0, -1 for 1."""
    c = 0.5
    timelimit = 30
    env = get_apple_env("phase2", start_x=0, goal_x=10, c=c, time_limit=timelimit)

    observations, actions, rewards = [], [], []
    env.reset()
    for _ in range(timelimit):
        # Mostly move left, the rewarded direction in phase2.
        action = np.random.choice([0, 1], p=[0.8, 0.2])
        obs, reward, done, info = env.step(action)
        observations.append(obs)
        actions.append(action)
        rewards.append(reward)
        if done:
            break

    actions = np.asarray(actions)
    n = len(actions)

    expected_rewards = (1 - actions) * 2.0 - 1.0
    if info["success"]:
        expected_rewards[-1] = 100  # terminal bonus on reaching the goal
    expected_obs = np.stack([np.ones(n), np.full(n, c)], axis=1)

    assert (np.asarray(rewards) == expected_rewards).all()
    assert (np.asarray(observations) == expected_obs).all()


def test_discrete_apple_full():
    """Full task: 10 steps right then 10 back; each step pays +1, last pays 100."""
    c = 0.5
    env = get_apple_env("full", start_x=0, goal_x=10, c=c, time_limit=30)

    observations, rewards = [], []
    env.reset()
    for action in [1] * 10 + [0] * 10:
        obs, reward, done, info = env.step(action)
        observations.append(obs)
        rewards.append(reward)

    expected_rewards = np.ones(20)
    expected_rewards[-1] = 100
    # Bias flips from -c to +c halfway through the scripted rollout.
    expected_obs = np.stack([np.ones(20), np.concatenate([np.full(10, -c), np.full(10, c)])], axis=1)

    assert (np.asarray(rewards) == expected_rewards).all()
    assert (np.asarray(observations) == expected_obs).all()