Spaces:
Runtime error
Runtime error
| import asyncio | |
| from functools import partial | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| import torch | |
| from apple.envs.discrete_apple import get_apple_env | |
| from apple.evaluation.render_episode import render_episode | |
| from apple.logger import EpochLogger | |
| from apple.models.categorical_policy import CategoricalPolicy | |
| from apple.training.trainer import Trainer | |
| QUEUE_SIZE = 1000 | |
def init_training(
    c: float = 1.0,
    start_x: float = 0.0,
    goal_x: float = 50.0,
    time_limit: int = 200,
    lr: float = 1e-3,
    weight1: float = 0.0,
    weight2: float = 0.0,
    pretrain: str = "phase1",
    finetune: str = "full",
    bias_in_state: bool = True,
    position_in_state: bool = False,
    apple_in_state: bool = True,
):
    """Build the environments, policy, optimizer and trainer in session state.

    Called on every reset; all counters and the logged-data buffer start over.
    """
    state = st.session_state
    state.logger = EpochLogger(verbose=False)
    env_kwargs = {
        "start_x": start_x,
        "goal_x": goal_x,
        "c": c,
        "time_limit": time_limit,
        "bias_in_state": bias_in_state,
        "position_in_state": position_in_state,
        "apple_in_state": apple_in_state,
    }
    # env_full renders frames for the rollout viewer; the two phase envs are
    # used by the pretraining and fine-tuning stages respectively.
    state.env_full = get_apple_env("full", render_mode="rgb_array", **env_kwargs)
    state.env_phase1 = get_apple_env(pretrain, **env_kwargs)
    state.env_phase2 = get_apple_env(finetune, **env_kwargs)
    state.test_envs = [get_apple_env(task, **env_kwargs) for task in ["full", "phase1", "phase2"]]
    obs_dim = state.env_phase1.observation_space.shape[0]
    state.model = CategoricalPolicy(obs_dim, 1, weight1, weight2)
    state.optim = torch.optim.SGD(state.model.parameters(), lr=lr)
    state.trainer = Trainer(state.model, state.optim, state.logger)
    # Fresh counters and metric buffer for the new run.
    state.train_it = 0
    state.draw_it = 0
    state.total_steps = 0
    state.data = []
    state.obs1 = state.env_phase1.reset()
    state.obs2 = state.env_phase2.reset()
def init_reset():
    """Restart the rollout iterator and seed the last rendered frame."""
    env = st.session_state.env_full
    model = st.session_state.model
    # `partial(render_episode, env, model)()` in the original is just an
    # immediate call; invoke render_episode directly.
    st.session_state.rollout_iterator = iter(render_episode(env, model))
    st.session_state.last_image = {
        "x": 0,
        "obs": env.reset(),
        "action": None,
        "reward": 0,
        "done": False,
        "episode_len": 0,
        "episode_return": 0,
        "pixel_array": env.unwrapped.render(),
    }
def select_preset():
    """Apply the preset picked in the 'parameter presets' selectbox."""
    handlers = {
        0: preset_finetuning_interference,
        1: preset_train_full_from_scratch,
        2: preset_task_interference,
        3: preset_without_task_interference,
    }
    handler = handlers.get(st.session_state.pick_preset)
    if handler is not None:
        handler()
def preset_task_interference():
    """Preset: pretrain on phase2, finetune on phase1 with small c (forgetting)."""
    settings = {
        "pick_c": 0.5,
        "pick_goal_x": 20,
        "pick_time_limit": 50,
        "pick_lr": 0.05,
        "pick_phase1task": "phase2",
        "pick_phase2task": "phase1",
        "pick_weight1": 0.0,
        "pick_weight2": 0.0,
        "pick_phase1steps": 500,
        "pick_phase2steps": 500,
    }
    for key, value in settings.items():
        st.session_state[key] = value
    need_reset()
def preset_finetuning_interference():
    """Preset: pretrain on phase2, then finetune on the full environment."""
    settings = {
        "pick_c": 0.5,
        "pick_goal_x": 20,
        "pick_time_limit": 50,
        "pick_lr": 0.05,
        "pick_phase1task": "phase2",
        "pick_phase2task": "full",
        "pick_weight1": 0.0,
        "pick_weight2": 0.0,
        "pick_phase1steps": 500,
        "pick_phase2steps": 2000,
    }
    for key, value in settings.items():
        st.session_state[key] = value
    need_reset()
def preset_without_task_interference():
    """Preset: phase2 -> phase1 with c = 1.0, where both tasks stay solvable."""
    settings = {
        "pick_c": 1.0,
        "pick_goal_x": 20,
        "pick_time_limit": 50,
        "pick_lr": 0.05,
        "pick_phase1task": "phase2",
        "pick_phase2task": "phase1",
        "pick_weight1": 0.0,
        "pick_weight2": 0.0,
        "pick_phase1steps": 500,
        "pick_phase2steps": 500,
    }
    for key, value in settings.items():
        st.session_state[key] = value
    need_reset()
def preset_train_full_from_scratch():
    """Preset: no pretraining (0 steps), train on the full env from scratch."""
    settings = {
        "pick_c": 0.5,
        "pick_goal_x": 20,
        "pick_time_limit": 50,
        "pick_lr": 0.05,
        "pick_phase1task": "phase2",
        "pick_phase2task": "full",
        "pick_weight1": 0.0,
        "pick_weight2": 0.0,
        "pick_phase1steps": 0,
        "pick_phase2steps": 2000,
    }
    for key, value in settings.items():
        st.session_state[key] = value
    need_reset()
def empty_queue(q: asyncio.Queue):
    """Drain every currently-queued item, marking each one done.

    Draining until the queue reports empty is equivalent to the original
    fixed-count loop here because nothing runs concurrently while this
    executes on the single event loop.
    """
    while not q.empty():
        q.get_nowait()
        q.task_done()
def reset(**kwargs):
    """Fully re-initialise training state; kwargs are forwarded to init_training."""
    init_training(**kwargs)
    init_reset()
    state = st.session_state
    # Clear every control flag and drain both producer/consumer queues.
    for flag in ("play", "step", "render", "done", "play_pause", "need_reset"):
        state[flag] = False
    empty_queue(state.queue)
    empty_queue(state.queue_render)
def render_start():
    """Begin rendering a fresh episode rollout."""
    st.session_state["render"] = True
    st.session_state["done"] = False
    init_reset()
def need_reset():
    """Flag that hyperparameters changed: stop all activity, request a reset."""
    state = st.session_state
    state.need_reset = True
    state.play = False
    state.render = False
def play_pause():
    """Toggle between playing and paused; both flags always move together."""
    now_playing = not st.session_state.play
    st.session_state.play = now_playing
    st.session_state.play_pause = now_playing
def step():
    """Request a single training step (consumed by one draw() iteration)."""
    st.session_state["step"] = True
def plot(data_placeholder):
    """Redraw the six metric charts from the accumulated training data."""
    df = pd.DataFrame(st.session_state.data)
    if not df.empty:
        df.set_index("total_env_steps", inplace=True)
    container = data_placeholder.container()
    c1, c2, c3 = container.columns(3)

    def view_df(names):
        # Restrict to the requested columns; fall back to a single zero row
        # so line_chart always has something to draw.
        subset = df.loc[:, df.columns.isin(names)]
        if subset.empty:
            return pd.DataFrame([dict.fromkeys(names, 0)])
        return subset

    chart_specs = [
        (c1, "phase1/success_rate", ["phase1/success"]),
        (c2, "phase2/success_rate", ["phase2/success"]),
        (c3, "full/success_rate", ["full/success"]),
        (c1, "train/loss", ["train/loss"]),
        (c2, "weight0", ["weight0"]),
        (c3, "weight1", ["weight1"]),
    ]
    for column, title, names in chart_specs:
        column.write(title)
        column.line_chart(view_df(names))
async def draw(data_placeholder, queue, delay, steps, plotfrequency):
    """Consume entries produced by train() and refresh the metric plots.

    Runs in lockstep with train(): one queue item per training iteration.
    Only every `plotfrequency`-th item carries real metrics (the rest are 0
    placeholders pushed by train()), and only those are appended to the data
    buffer and plotted.  Clearing the `step` flag here makes a single-step
    request last exactly one iteration.
    """
    while (st.session_state.play or st.session_state.step) and st.session_state.draw_it < steps:
        _ = await asyncio.sleep(delay)
        new_data = await queue.get()
        st.session_state.draw_it += 1
        if st.session_state.draw_it % plotfrequency == 0:
            st.session_state.data.append(new_data)
            plot(data_placeholder)
        st.session_state.step = False
        queue.task_done()
async def train(queue, delay, steps, obs, env, num_eval_episodes, plotfrequency):
    """Run single-step training on `env` and push one log entry per step.

    Each iteration samples an action from the policy, lets the trainer run one
    update, then advances the environment.  Every `plotfrequency` iterations
    the agent is evaluated on the test envs and the logged metrics are queued;
    otherwise a placeholder 0 is queued so the paired draw() consumer stays in
    lockstep.

    Args:
        queue: asyncio.Queue consumed by draw().
        delay: seconds slept per iteration (yields control to the consumer).
        steps: global train_it value at which this phase stops.
        obs: initial observation for `env`.  NOTE(review): `obs` is only
            rebound locally — st.session_state.obs1/obs2 are never written
            back, so resuming play restarts from the stale stored observation;
            confirm this is intended.
        env: the environment to train on during this phase.
        num_eval_episodes: rollouts per test env during evaluation.
        plotfrequency: evaluate and log every this many iterations.
    """
    while (st.session_state.play or st.session_state.step) and st.session_state.train_it < steps:
        _ = await asyncio.sleep(delay)
        st.session_state.train_it += 1
        st.session_state.total_steps += 1
        # Forward pass and action sampling from the categorical policy.
        output = st.session_state.model(obs)
        action, log_prob = st.session_state.model.sample(output)
        # One update step (per the app text: NLL of the env's target action).
        st.session_state.trainer.update(
            env, output, st.session_state.model, st.session_state.optim, st.session_state.logger
        )
        obs, reward, done, info = env.step(action)
        if done:
            obs = env.reset()
        if st.session_state.train_it % plotfrequency == 0:
            # Periodic evaluation on all three task variants, then collect metrics.
            st.session_state.trainer.test_agent(
                st.session_state.model, st.session_state.logger, st.session_state.test_envs, num_eval_episodes
            )
            data = st.session_state.trainer.log(
                st.session_state.logger, st.session_state.train_it, st.session_state.model
            )
        else:
            data = 0
        _ = await queue.put(data)
async def produce_images(queue, delay):
    """Step the rollout iterator and enqueue frames until the episode ends."""
    while st.session_state.render and not st.session_state.done:
        await asyncio.sleep(delay)
        frame = next(st.session_state.rollout_iterator)
        st.session_state.done = frame["done"]
        await queue.put(frame)
def show_image(data, image_placeholder):
    """Render one rollout frame plus a short status caption."""
    container = image_placeholder.container()
    container.image(data["pixel_array"])
    caption = (
        f"agent position: {data['x']} \n"
        f"timestep: {data['episode_len']} \n"
        f"episode return: {data['episode_return']} \n"
    )
    container.text(caption)
async def consume_images(image_placeholder, queue, delay):
    """Pull frames off the render queue and display them."""
    while st.session_state.render and not st.session_state.done:
        await asyncio.sleep(delay)
        frame = await queue.get()
        # Remember the frame so it can be re-shown on the next rerun.
        st.session_state.last_image = frame
        show_image(frame, image_placeholder)
        queue.task_done()
async def run_app(
    data_placeholder,
    queue,
    produce_delay,
    consume_delay,
    phase1steps,
    phase2steps,
    plotfrequency,
    num_eval_episodes,
    image_placeholder,
    queue_render,
    render_produce_delay,
    render_consume_delay,
):
    """Main async driver: render one rollout (if requested), then run the
    pretraining stage followed by the fine-tuning stage.

    Each stage pairs a producer and a consumer coroutine over a queue.  The
    render loops exit immediately when `render` is False, and the train loops
    exit when `play`/`step` are False or the step budget is reached, so this
    coroutine is re-entered on every Streamlit rerun and resumes from the
    persisted counters.
    """
    # Stage 0: episode visualisation (no-op unless the Rollout button set `render`).
    _ = await asyncio.gather(
        produce_images(queue_render, render_produce_delay),
        consume_images(image_placeholder, queue_render, render_consume_delay),
    )
    st.session_state.render = False
    st.session_state.done = False
    empty_queue(queue_render)
    # Stage 1: pretraining on env_phase1 until train_it reaches phase1steps.
    _ = await asyncio.gather(
        train(
            queue,
            produce_delay,
            phase1steps,
            st.session_state.obs1,
            st.session_state.env_phase1,
            num_eval_episodes,
            plotfrequency,
        ),
        draw(data_placeholder, queue, consume_delay, phase1steps, plotfrequency),
    )
    # Stage 2: fine-tuning on env_phase2 until phase1steps + phase2steps.
    _ = await asyncio.gather(
        train(
            queue,
            produce_delay,
            phase1steps + phase2steps,
            st.session_state.obs2,
            st.session_state.env_phase2,
            num_eval_episodes,
            plotfrequency,
        ),
        draw(data_placeholder, queue, consume_delay, phase1steps + phase2steps, plotfrequency),
    )
##### ACTUAL APP
if __name__ == "__main__":
    st.set_page_config(
        layout="wide",
        initial_sidebar_state="auto",
        page_title="Apple Retrieval",
        page_icon=None,
    )
    st.title("ON THE ROLE OF FORGETTING IN FINE-TUNING REINFORCEMENT LEARNING MODELS")
    st.header("Toy example of forgetting: AppleRetrieval")

    # --- Sidebar: experiment controls --------------------------------------
    col1, col2, col3 = st.sidebar.columns(3)
    options = (
        "phase2 full interference",
        "full from scratch",
        "phase2 phase1 forgetting",
        "phase2 phase1 optimal solution",
    )
    st.sidebar.selectbox(
        "parameter presets",
        range(len(options)),
        index=0,
        format_func=lambda x: options[x],
        on_change=select_preset,
        key="pick_preset",
    )
    pick_container = st.sidebar.container()
    c = pick_container.number_input("c", value=0.5, on_change=need_reset, key="pick_c")
    goal_x = pick_container.number_input("distance to apple", value=20, on_change=need_reset, key="pick_goal_x")
    time_limit = pick_container.number_input("time limit", value=50, on_change=need_reset, key="pick_time_limit")
    lr = pick_container.selectbox(
        "Learning rate",
        np.array([0.00001, 0.0001, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0, 10.0]),
        5,
        on_change=need_reset,
        key="pick_lr",
    )
    phase1task = pick_container.selectbox(
        "Pretraining task", ("full", "phase1", "phase2"), 2, on_change=need_reset, key="pick_phase1task"
    )
    phase2task = pick_container.selectbox(
        "Finetuning task", ("full", "phase1", "phase2"), 0, on_change=need_reset, key="pick_phase2task"
    )
    # weight1 = pick_container.number_input("init weight1", value=0, on_change=need_reset, key="pick_weight1")
    # weight2 = pick_container.number_input("init weight2", value=0, on_change=need_reset, key="pick_weight2")

    # --- One-time per-session asyncio objects -------------------------------
    # BUGFIX: the original guard tested `"event_loop" not in st.session_state`
    # while storing the loop under the key "loop", so the guard never fired
    # and a fresh event loop was created and installed on every Streamlit
    # rerun.  Check the key that is actually stored.
    if "loop" not in st.session_state:
        st.session_state.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(st.session_state.loop)
    if "queue" not in st.session_state:
        st.session_state.queue = asyncio.Queue(QUEUE_SIZE)
    if "queue_render" not in st.session_state:
        st.session_state.queue_render = asyncio.Queue(QUEUE_SIZE)
    if "play" not in st.session_state:
        st.session_state.play = False
    if "step" not in st.session_state:
        st.session_state.step = False
    if "render" not in st.session_state:
        st.session_state.render = False

    # The Reset callback carries the currently picked hyperparameters.
    reset_button = partial(
        reset,
        c=c,
        start_x=0,
        goal_x=goal_x,
        time_limit=time_limit,
        lr=lr,
        # weight1=weight1,
        # weight2=weight2,
        # weight1=0,
        # weight2=10,  # solves the environment
        pretrain=phase1task,
        finetune=phase2task,
    )
    col1.button("Reset", on_click=reset_button, type="primary")
    # First run of a session, or a pending hyperparameter change: reset now.
    if "logger" not in st.session_state or st.session_state.need_reset:
        reset_button()
    myKey = "play_pause"
    if myKey not in st.session_state:
        st.session_state[myKey] = False
    # The Play/Pause button label mirrors the current play state.
    if st.session_state[myKey]:
        myBtn = col2.button("Pause", on_click=play_pause, type="primary")
    else:
        myBtn = col2.button("Play", on_click=play_pause, type="primary")
    col3.button("Step", on_click=step, type="primary")

    # --- Page body ----------------------------------------------------------
    st.header("Summary")
    st.write(
        """
Run training on the "phase2 full interference" setting to see an example of forgetting in fine-tuning RL models.
A model is pre-trained on a part of the environment called Phase 2 for 500 steps,
and then it is fine-tuned on the whole environment for another 2000 steps.
However, it forgets how to perform on Phase 2 during fine-tuning before it even gets there.
We highlight this as an important problem in fine-tuning RL models.
We invite you to play around with the hyperparameters and find out more about the forgetting phenomenon.
"""
    )
    st.markdown(
        """
By comparing the results of these four presets, you can gain a deeper understanding of how different training configurations can impact the final performance of a neural network.
* `phase2 full interference` - already mentioned setting where we pretrain on phase 2 and finetune on whole environment.
* `full from scratch` - "standard" train on whole environment for 2000 steps for comparison.
* `phase2 phase1 forgetting` - a setting where we pretrain on phase 2 and finetune on phase 1. Here gradient interference between tasks causes network to forget how to solve phase 2.
* `phase2 phase1 without forgetting` - a setting where we pretrain on phase 2 and finetune on phase 1. Here gradient interference between tasks "move" network weights to optimal solution for both tasks.
"""
    )
    data_placeholder = st.empty()
    render_placeholder = st.empty()
    phase1steps = st.sidebar.slider("Pretraining Steps", 0, 2000, 500, 10, key="pick_phase1steps")
    phase2steps = st.sidebar.slider("Finetuning Steps", 0, 2000, 2000, 10, key="pick_phase2steps")
    plotfrequency = st.sidebar.number_input("Log frequency ", min_value=1, value=10, step=1)
    num_eval_episodes = st.sidebar.number_input("Number of evaluation episodes", min_value=1, value=10, step=1)
    # Re-plot existing data while paused so the charts survive reruns.
    if not st.session_state.play and len(st.session_state.data) > 0:
        plot(data_placeholder)
    st.write(
        """
Figure 1: Plots (a), (b), (c) show the performance of the agent on three different variants of the environment.
Phase1 (a) the agent's goal is to reach the apple.
Phase2 (b) the agent's has to go back home by returning to x = 0.
Full (c) combination of both phases of the environment.
(d) shows the training loss.
State vector has only two values, we can view the netwoks weights (e), (f).
"""
    )
    st.header("Vizualize agent behavior")
    st.button("Rollout one episode", on_click=render_start, type="primary")
    c1, c2 = st.columns(2)
    image_placeholder = c1.empty()
    # Frame rate for the rollout viewer: ~3 frames per second.
    render_produce_delay = 1 / 3
    render_consume_delay = 1 / 3
    st.header("About the environment")
    st.write(
        """This study forms a preliminary investigation into the problem of forgetting in fine-tuning RL models.
We show that fine-tuning a pre-trained model on compositional RL problems might result in a rapid
deterioration of the performance of the pre-trained model if the relevant data is not available at the
beginning of the training.
This phenomenon is known as catastrophic forgetting.
In this demo we show how it can occur in simple toyish situations but it might occur in more realistic problems (e.g. (SAC with MLPs on a compositional robotic environmen).
In our [WIP] paper we showed that applying CL methods significantly limits forgetting and allows for efficient transfer.
"""
    )
    st.write(
        """The AppleRetrieval environment is a toy example to demonstrate the issue of interference in reinforcement learning.
The goal of the agent is to retrieve an apple from position x = M and return home to x = 0 within a set number of steps T.
The state of the agent is represented by a vector s, which has two elements: s = [1, -c] in phase 1 and s = [1, c] in phase 2.
The first element is a constant and the second element represents the information about the current phase.
The optimal policy is to go right in phase 1 and go left in phase 2.
The cause of interference can be identified by checking the reliance of the policy on either s1 or s2.
If the model mostly relies on s2, interference will be limited, but if it relies on s1,
interference will occur as its value is the same in both phases.
The magnitude of s1 and s2 can be adjusted by changing the c parameter to guide the model towards focusing on either one.
This simple toy environment shows that the issue of interference can be fundamental to reinforcement learning."""
    )
    cc1, cc2 = st.columns([1, 2])
    cc1.image("assets/apple_env.png")
    st.header("Training algorithm")
    st.write(
        """For practical reasons (time and stability) we don't use REINFORCE in this DEMO to illustrate the training dynamics of the environment.
Instead, we train the model by minimizing the negative log likelihood of target actions (move right or left).
We train the model in each step of the environment and sample actions from Categorical distribution taken from model's output.
"""
    )
    st.markdown(
        """
Pseudocode of the train loop:
```python
obs = env.reset()
for timestep in range(steps):
    probs = model(obs)
    dist = Categorical(probs)
    action = dist.sample()
    target_action = env.get_target_action()
    loss = -dist.log_prob(target_action)
    optim.zero_grad()
    loss.backward()
    optim.step()
    obs, reward, done, info = env.step(action)
    if done:
        obs = env.reset()
"""
    )
    st.header("What do all the hyperparameters mean?")
    st.markdown(
        """
* **parameter presets** - few sets of pre-defined hyperparameters that can be used as a starting point for a specific experiment.
* **c** - second element of state vecor, decreasing this value will result in stronger forgetting.
* **distance to apple** - refers to how far agent needs to travel in the right before it encounters the apple.
* **time limit** - maximum amount of timesteps that the agent is allowed to interact with the environment before episode ends.
* **learning rate** - hyperparameter that determines the steps size taken by the learning algoritm in each iteration.
* **pretraining and finetuning task** - define the environment the model will be trained on each stage.
* **pretraining and finetuning steps** - define how long model will be trained on each stage.
* **init weights** - initial values assigned to the model's parameters before trainig begins.
* **log frequency** - refers to frequency of logging the metrics and evaluating the agent.
* **number of evaluation episodes** - number of rollouts during testing of the agent.
"""
    )
    st.header("Limitations & Conclusions")
    st.write(
        """
At the same time, this study, due to its preliminary nature, has numerous limitations which we
hope to address in future work. We only considered a fairly strict formulation of the forgetting
scenario where we assumed that the pre-trained model works perfectly on tasks that appear later in
the fine-tuning. In practice, one should also consider the case when even though there are differences
between the pre-training and fine-tuning tasks, transfer is still possible.
"""
    )
    st.write(
        """
At the same time, even given
these limitations, we see forgetting as an important problem to be solved and hope that addressing
these issues in the future might help with building and fine-tuning better foundation models in RL.
"""
    )
    # Training loop pacing: short sleeps keep the UI responsive.
    produce_delay = 1 / 1000
    consume_delay = 1 / 1000
    plot(data_placeholder)
    show_image(st.session_state.last_image, image_placeholder)
    # Drive the render + two training stages; re-entered on every rerun.
    asyncio.run(
        run_app(
            data_placeholder,
            st.session_state.queue,
            produce_delay,
            consume_delay,
            phase1steps,
            phase2steps,
            plotfrequency,
            num_eval_episodes,
            image_placeholder,
            st.session_state.queue_render,
            render_produce_delay,
            render_consume_delay,
        )
    )