import asyncio
from functools import partial

import numpy as np
import pandas as pd
import streamlit as st
import torch

from apple.envs.discrete_apple import get_apple_env
from apple.evaluation.render_episode import render_episode
from apple.logger import EpochLogger
from apple.models.categorical_policy import CategoricalPolicy
from apple.training.trainer import Trainer

# Capacity of the asyncio queues that carry training logs / rollout frames
# from the producer coroutines (train, produce_images) to the consumers
# (draw, consume_images).
QUEUE_SIZE = 1000

# Sidebar widget values shared by every preset. Keys are the session-state keys
# of the sidebar widgets; each preset only overrides the values that differ.
_PRESET_DEFAULTS = {
    "pick_c": 0.5,
    "pick_goal_x": 20,
    "pick_time_limit": 50,
    "pick_lr": 0.05,
    "pick_phase1task": "phase2",
    "pick_phase2task": "full",
    "pick_weight1": 0.0,
    "pick_weight2": 0.0,
    "pick_phase1steps": 500,
    "pick_phase2steps": 2000,
}


def init_training(
    c: float = 1.0,
    start_x: float = 0.0,
    goal_x: float = 50.0,
    time_limit: int = 200,
    lr: float = 1e-3,
    weight1: float = 0.0,
    weight2: float = 0.0,
    pretrain: str = "phase1",
    finetune: str = "full",
    bias_in_state: bool = True,
    position_in_state: bool = False,
    apple_in_state: bool = True,
):
    """(Re)create the logger, environments, model, optimizer and trainer in session state.

    Args:
        c, start_x, goal_x, time_limit, bias_in_state, position_in_state,
        apple_in_state: forwarded unchanged to every ``get_apple_env`` call.
        lr: learning rate of the SGD optimizer.
        weight1, weight2: initial weights passed to ``CategoricalPolicy``.
        pretrain: task name of the pretraining environment (``env_phase1``).
        finetune: task name of the finetuning environment (``env_phase2``).

    Also resets the iteration counters, the logged data and the initial
    observations of both training environments.
    """
    st.session_state.logger = EpochLogger(verbose=False)
    env_kwargs = dict(
        start_x=start_x,
        goal_x=goal_x,
        c=c,
        time_limit=time_limit,
        bias_in_state=bias_in_state,
        position_in_state=position_in_state,
        apple_in_state=apple_in_state,
    )
    # Only the "full" env is rendered; it is used for the rollout visualization.
    st.session_state.env_full = get_apple_env("full", render_mode="rgb_array", **env_kwargs)
    st.session_state.env_phase1 = get_apple_env(pretrain, **env_kwargs)
    st.session_state.env_phase2 = get_apple_env(finetune, **env_kwargs)
    # Evaluation always covers all three task variants, independent of the
    # tasks chosen for pretraining / finetuning.
    st.session_state.test_envs = [get_apple_env(task, **env_kwargs) for task in ["full", "phase1", "phase2"]]

    st.session_state.model = CategoricalPolicy(
        st.session_state.env_phase1.observation_space.shape[0], 1, weight1, weight2
    )
    st.session_state.optim = torch.optim.SGD(st.session_state.model.parameters(), lr=lr)
    st.session_state.trainer = Trainer(st.session_state.model, st.session_state.optim, st.session_state.logger)

    st.session_state.train_it = 0
    st.session_state.draw_it = 0
    st.session_state.total_steps = 0
    st.session_state.data = []
    st.session_state.obs1 = st.session_state.env_phase1.reset()
    st.session_state.obs2 = st.session_state.env_phase2.reset()


def init_reset():
    """Prepare a fresh rollout of the current model on the rendered "full" env.

    Stores an iterator over ``render_episode``'s frames (consumed one at a
    time by ``produce_images``) and a placeholder first frame in
    ``st.session_state.last_image``.
    """
    st.session_state.rollout_iterator = iter(render_episode(st.session_state.env_full, st.session_state.model))
    st.session_state.last_image = dict(
        x=0,
        obs=st.session_state.env_full.reset(),
        action=None,
        reward=0,
        done=False,
        episode_len=0,
        episode_return=0,
        pixel_array=st.session_state.env_full.unwrapped.render(),
    )


def select_preset():
    """Apply the preset picked in the "parameter presets" selectbox (key ``pick_preset``).

    The tuple order must match the ``options`` labels of that selectbox.
    Out-of-range indices are ignored.
    """
    presets = (
        preset_finetuning_interference,
        preset_train_full_from_scratch,
        preset_task_interference,
        preset_without_task_interference,
    )
    index = st.session_state.pick_preset
    if 0 <= index < len(presets):
        presets[index]()


def _apply_preset(**overrides):
    """Write ``_PRESET_DEFAULTS`` updated with ``overrides`` into session state, then flag a reset."""
    for key, value in {**_PRESET_DEFAULTS, **overrides}.items():
        st.session_state[key] = value
    need_reset()


def preset_task_interference():
    """Preset "phase2 phase1 forgetting": pretrain on phase2, finetune on phase1 (c=0.5)."""
    _apply_preset(pick_phase2task="phase1", pick_phase2steps=500)


def preset_finetuning_interference():
    """Preset "phase2 full interference": pretrain on phase2 (500 steps), finetune on full (2000 steps)."""
    _apply_preset()


def preset_without_task_interference():
    """Preset "phase2 phase1 optimal solution": like the forgetting preset but with c=1.0."""
    _apply_preset(pick_c=1.0, pick_phase2task="phase1", pick_phase2steps=500)


def preset_train_full_from_scratch():
    """Preset "full from scratch": no pretraining, 2000 steps on the full task."""
    _apply_preset(pick_phase1steps=0)


def empty_queue(q: asyncio.Queue):
    """Discard every item currently in ``q`` without awaiting."""
    for _ in range(q.qsize()):
        # qsize() is read once up front and there is no await inside the loop,
        # so (assuming q is only used from this thread) QueueEmpty cannot occur.
        q.get_nowait()
        q.task_done()


def reset(**kwargs):
    """Rebuild all training state from ``kwargs`` (see ``init_training``) and stop any activity."""
    init_training(**kwargs)
    init_reset()
    st.session_state.play = False
    st.session_state.step = False
    st.session_state.render = False
    st.session_state.done = False
    # Drop data produced for the previous model so consumers don't display it.
    empty_queue(st.session_state.queue)
    empty_queue(st.session_state.queue_render)
    st.session_state.play_pause = False
    st.session_state.need_reset = False


def render_start():
    """Button callback: request a new rendered rollout episode."""
    st.session_state.render = True
    st.session_state.done = False
    init_reset()


def need_reset():
    """Widget callback: a hyperparameter changed, so training must be rebuilt before continuing."""
    st.session_state.need_reset = True
    st.session_state.play = False
    st.session_state.render = False


def play_pause():
    """Button callback: toggle continuous training (also drives the Play/Pause label)."""
    playing = not st.session_state.play
    st.session_state.play = playing
    st.session_state.play_pause = playing


def step():
    """Button callback: run a single training iteration on the next rerun."""
    st.session_state.step = True


def plot(data_placeholder):
    """Redraw the six metric charts from ``st.session_state.data`` inside ``data_placeholder``."""
    df = pd.DataFrame(st.session_state.data)
    if not df.empty:
        df.set_index("total_env_steps", inplace=True)

    def view_df(names):
        # Keep only the requested metric columns; before anything is logged,
        # return a single all-zero row so line_chart still renders.
        rdf = df.loc[:, df.columns.isin(names)]
        if rdf.empty:
            return pd.DataFrame([{name: 0 for name in names}])
        return rdf

    container = data_placeholder.container()
    c1, c2, c3 = container.columns(3)
    charts = (
        (c1, "phase1/success_rate", "phase1/success"),
        (c2, "phase2/success_rate", "phase2/success"),
        (c3, "full/success_rate", "full/success"),
        (c1, "train/loss", "train/loss"),
        (c2, "weight0", "weight0"),
        (c3, "weight1", "weight1"),
    )
    for column, title, metric in charts:
        column.write(title)
        column.line_chart(view_df([metric]))


async def draw(data_placeholder, queue, delay, steps, plotfrequency):
    """Consume one queue item per training iteration; refresh charts every ``plotfrequency`` items.

    Runs while training is playing (or a single step was requested) and until
    ``draw_it`` reaches ``steps`` (a cumulative target, like ``train_it``).
    """
    while (st.session_state.play or st.session_state.step) and st.session_state.draw_it < steps:
        await asyncio.sleep(delay)
        new_data = await queue.get()
        st.session_state.draw_it += 1
        # Mirrors the train_it % plotfrequency check in train(); other
        # iterations carry a 0 placeholder that must not be plotted.
        if st.session_state.draw_it % plotfrequency == 0:
            st.session_state.data.append(new_data)
            plot(data_placeholder)
        # A single "Step" click advances exactly one iteration.
        st.session_state.step = False
        queue.task_done()


async def train(queue, delay, steps, obs, env, num_eval_episodes, plotfrequency, obs_key=None):
    """Run one trainer update + env step per iteration until paused or ``train_it`` reaches ``steps``.

    Args:
        queue: receives one item per iteration (log data every ``plotfrequency``
            iterations, otherwise ``0``).
        delay: seconds to sleep before each iteration (lets ``draw`` interleave).
        steps: cumulative ``train_it`` target.
        obs: observation to start from.
        env: environment to train on.
        num_eval_episodes: rollouts per test env during evaluation.
        plotfrequency: evaluate and log every this many iterations.
        obs_key: optional session-state key ("obs1"/"obs2") updated with the
            latest observation after every env step, so the next rerun (after
            a pause or a "Step" click) resumes from the env's actual state
            instead of the stale observation it was started with.
    """
    while (st.session_state.play or st.session_state.step) and st.session_state.train_it < steps:
        await asyncio.sleep(delay)
        st.session_state.train_it += 1
        st.session_state.total_steps += 1

        output = st.session_state.model(obs)
        action, _log_prob = st.session_state.model.sample(output)
        st.session_state.trainer.update(
            env, output, st.session_state.model, st.session_state.optim, st.session_state.logger
        )
        obs, reward, done, info = env.step(action)
        if done:
            obs = env.reset()
        if obs_key is not None:
            st.session_state[obs_key] = obs

        if st.session_state.train_it % plotfrequency == 0:
            st.session_state.trainer.test_agent(
                st.session_state.model, st.session_state.logger, st.session_state.test_envs, num_eval_episodes
            )
            data = st.session_state.trainer.log(
                st.session_state.logger, st.session_state.train_it, st.session_state.model
            )
        else:
            data = 0
        await queue.put(data)


async def produce_images(queue, delay):
    """Pull rollout frames from ``rollout_iterator`` into ``queue`` until a frame reports ``done``."""
    while st.session_state.render and not st.session_state.done:
        await asyncio.sleep(delay)
        data = next(st.session_state.rollout_iterator)
        st.session_state.done = data["done"]
        await queue.put(data)


def show_image(data, image_placeholder):
    """Display one rollout frame together with agent position, timestep and episode return."""
    c = image_placeholder.container()
    c.image(
        data["pixel_array"],
    )
    c.text(
        f"agent position: {data['x']} \ntimestep: {data['episode_len']} \nepisode return: {data['episode_return']} \n"
    )


async def consume_images(image_placeholder, queue, delay):
    """Display queued rollout frames (remembering the latest in ``last_image``) until ``done``."""
    while st.session_state.render and not st.session_state.done:
        await asyncio.sleep(delay)
        data = await queue.get()
        st.session_state.last_image = data
        show_image(data, image_placeholder)
        queue.task_done()


async def run_app(
    data_placeholder,
    queue,
    produce_delay,
    consume_delay,
    phase1steps,
    phase2steps,
    plotfrequency,
    num_eval_episodes,
    image_placeholder,
    queue_render,
    render_produce_delay,
    render_consume_delay,
):
    """Play a requested rollout (if any), then pretraining, then finetuning.

    Each stage runs a producer and a consumer coroutine concurrently. Training
    targets are cumulative: pretraining runs until ``train_it == phase1steps``,
    finetuning until ``train_it == phase1steps + phase2steps``.
    """
    await asyncio.gather(
        produce_images(queue_render, render_produce_delay),
        consume_images(image_placeholder, queue_render, render_consume_delay),
    )
    st.session_state.render = False
    st.session_state.done = False
    empty_queue(queue_render)

    await asyncio.gather(
        train(
            queue,
            produce_delay,
            phase1steps,
            st.session_state.obs1,
            st.session_state.env_phase1,
            num_eval_episodes,
            plotfrequency,
            obs_key="obs1",
        ),
        draw(data_placeholder, queue, consume_delay, phase1steps, plotfrequency),
    )

    total_steps = phase1steps + phase2steps
    await asyncio.gather(
        train(
            queue,
            produce_delay,
            total_steps,
            st.session_state.obs2,
            st.session_state.env_phase2,
            num_eval_episodes,
            plotfrequency,
            obs_key="obs2",
        ),
        draw(data_placeholder, queue, consume_delay, total_steps, plotfrequency),
    )


##### ACTUAL APP
if __name__ == "__main__":
    st.set_page_config(
        layout="wide",
        initial_sidebar_state="auto",
        page_title="Apple Retrieval",
        page_icon=None,
    )
    st.title("ON THE ROLE OF FORGETTING IN FINE-TUNING REINFORCEMENT LEARNING MODELS")
    st.header("Toy example of forgetting: AppleRetrieval")

    col1, col2, col3 = st.sidebar.columns(3)

    # Label order must match the dispatch tuple in select_preset().
    options = (
        "phase2 full interference",
        "full from scratch",
        "phase2 phase1 forgetting",
        "phase2 phase1 optimal solution",
    )
    st.sidebar.selectbox(
        "parameter presets",
        range(len(options)),
        index=0,
        format_func=lambda x: options[x],
        on_change=select_preset,
        key="pick_preset",
    )

    pick_container = st.sidebar.container()
    c = pick_container.number_input("c", value=0.5, on_change=need_reset, key="pick_c")
    goal_x = pick_container.number_input("distance to apple", value=20, on_change=need_reset, key="pick_goal_x")
    time_limit = pick_container.number_input("time limit", value=50, on_change=need_reset, key="pick_time_limit")
    lr = pick_container.selectbox(
        "Learning rate",
        np.array([0.00001, 0.0001, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0, 10.0]),
        5,
        on_change=need_reset,
        key="pick_lr",
    )
    phase1task = pick_container.selectbox(
        "Pretraining task", ("full", "phase1", "phase2"), 2, on_change=need_reset, key="pick_phase1task"
    )
    phase2task = pick_container.selectbox(
        "Finetuning task", ("full", "phase1", "phase2"), 0, on_change=need_reset, key="pick_phase2task"
    )
    # weight1 = pick_container.number_input("init weight1", value=0, on_change=need_reset, key="pick_weight1")
    # weight2 = pick_container.number_input("init weight2", value=0, on_change=need_reset, key="pick_weight2")

    # Check the key that is actually assigned so the loop is created only once
    # per session (a mismatched key made every rerun leak a new, unclosed loop).
    # NOTE(review): this loop is never run (asyncio.run below creates its own);
    # it appears to exist only so the queues below can be constructed on Python
    # versions that look up the current loop at construction time — confirm.
    if "loop" not in st.session_state:
        st.session_state.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(st.session_state.loop)
    if "queue" not in st.session_state:
        st.session_state.queue = asyncio.Queue(QUEUE_SIZE)
    if "queue_render" not in st.session_state:
        st.session_state.queue_render = asyncio.Queue(QUEUE_SIZE)
    for flag in ("play", "step", "render"):
        if flag not in st.session_state:
            st.session_state[flag] = False

    reset_button = partial(
        reset,
        c=c,
        start_x=0,
        goal_x=goal_x,
        time_limit=time_limit,
        lr=lr,
        # weight1=weight1,
        # weight2=weight2,
        # weight1=0,
        # weight2=10,  # solves the environment
        pretrain=phase1task,
        finetune=phase2task,
    )
    col1.button("Reset", on_click=reset_button, type="primary")
    # First run of the session, or a hyperparameter widget flagged a reset.
    if "logger" not in st.session_state or st.session_state.need_reset:
        reset_button()

    if "play_pause" not in st.session_state:
        st.session_state.play_pause = False
    play_label = "Pause" if st.session_state.play_pause else "Play"
    col2.button(play_label, on_click=play_pause, type="primary")
    col3.button("Step", on_click=step, type="primary")

    st.header("Summary")
    st.write(
        """
        Run training on the "phase2 full interference" setting to see an example of forgetting in fine-tuning RL models.
        A model is pre-trained on a part of the environment called Phase 2 for 500 steps,
        and then it is fine-tuned on the whole environment for another 2000 steps.
        However, it forgets how to perform on Phase 2 during fine-tuning before it even gets there.
        We highlight this as an important problem in fine-tuning RL models.
        We invite you to play around with the hyperparameters and find out more about the forgetting phenomenon.
        """
    )

    st.markdown(
        """
        By comparing the results of these four presets, you can gain a deeper understanding of how different training
        configurations can impact the final performance of a neural network.

        * `phase2 full interference` - already mentioned setting where we pretrain on phase 2 and finetune on the whole environment.
        * `full from scratch` - "standard" training on the whole environment for 2000 steps, for comparison.
        * `phase2 phase1 forgetting` - a setting where we pretrain on phase 2 and finetune on phase 1. Here gradient interference between tasks causes the network to forget how to solve phase 2.
        * `phase2 phase1 optimal solution` - a setting where we pretrain on phase 2 and finetune on phase 1. Here gradient interference between tasks "moves" the network weights to the optimal solution for both tasks.
        """
    )

    data_placeholder = st.empty()
    render_placeholder = st.empty()

    phase1steps = st.sidebar.slider("Pretraining Steps", 0, 2000, 500, 10, key="pick_phase1steps")
    phase2steps = st.sidebar.slider("Finetuning Steps", 0, 2000, 2000, 10, key="pick_phase2steps")
    plotfrequency = st.sidebar.number_input("Log frequency ", min_value=1, value=10, step=1)
    num_eval_episodes = st.sidebar.number_input("Number of evaluation episodes", min_value=1, value=10, step=1)

    # The charts are drawn into data_placeholder by the unconditional plot()
    # call just before run_app, so no early plot is needed here.

    st.write(
        """
        Figure 1: Plots (a), (b), (c) show the performance of the agent on three different variants of the environment.
        In Phase1 (a) the agent's goal is to reach the apple.
        In Phase2 (b) the agent has to go back home by returning to x = 0.
        Full (c) is the combination of both phases of the environment. (d) shows the training loss.
        Since the state vector has only two values, we can view the network's weights directly in (e) and (f).
        """
    )

    st.header("Visualize agent behavior")
    st.button("Rollout one episode", on_click=render_start, type="primary")
    c1, c2 = st.columns(2)
    image_placeholder = c1.empty()
    render_produce_delay = 1 / 3
    render_consume_delay = 1 / 3

    st.header("About the environment")
    st.write(
        """This study forms a preliminary investigation into the problem of forgetting in fine-tuning RL models.
        We show that fine-tuning a pre-trained model on compositional RL problems might result in a rapid
        deterioration of the performance of the pre-trained model if the relevant data is not available at the
        beginning of the training.
        This phenomenon is known as catastrophic forgetting.
        In this demo we show how it can occur in simple toy situations, but it might also occur in more realistic
        problems (e.g. SAC with MLPs on a compositional robotic environment).
        In our [WIP] paper we showed that applying CL methods significantly limits forgetting and allows for efficient transfer.
        """
    )

    st.write(
        """The AppleRetrieval environment is a toy example to demonstrate the issue of interference in reinforcement learning.
        The goal of the agent is to retrieve an apple from position x = M and return home to x = 0 within a set number of steps T.
        The state of the agent is represented by a vector s, which has two elements: s = [1, -c] in phase 1 and s = [1, c] in phase 2.
        The first element is a constant and the second element represents the information about the current phase.
        The optimal policy is to go right in phase 1 and go left in phase 2.
        The cause of interference can be identified by checking the reliance of the policy on either s1 or s2.
        If the model mostly relies on s2, interference will be limited, but if it relies on s1, interference will occur
        as its value is the same in both phases.
        The magnitude of s1 and s2 can be adjusted by changing the c parameter to guide the model towards focusing on either one.
        This simple toy environment shows that the issue of interference can be fundamental to reinforcement learning."""
    )

    cc1, cc2 = st.columns([1, 2])
    cc1.image("assets/apple_env.png")

    st.header("Training algorithm")
    st.write(
        """For practical reasons (time and stability) we don't use REINFORCE in this DEMO to illustrate the training
        dynamics of the environment.
        Instead, we train the model by minimizing the negative log likelihood of target actions (move right or left).
        We train the model in each step of the environment and sample actions from a Categorical distribution taken
        from the model's output.
        """
    )

    st.markdown(
        """
        Pseudocode of the train loop:
        ```python
        obs = env.reset()
        for timestep in range(steps):
            probs = model(obs)
            dist = Categorical(probs)
            action = dist.sample()
            target_action = env.get_target_action()
            loss = -dist.log_prob(target_action)

            optim.zero_grad()
            loss.backward()
            optim.step()

            obs, reward, done, info = env.step(action)
            if done:
                obs = env.reset()
        ```
        """
    )

    st.header("What do all the hyperparameters mean?")
    st.markdown(
        """
        * **parameter presets** - a few sets of pre-defined hyperparameters that can be used as a starting point for a specific experiment.
        * **c** - second element of the state vector; decreasing this value will result in stronger forgetting.
        * **distance to apple** - how far the agent needs to travel to the right before it encounters the apple.
        * **time limit** - maximum number of timesteps that the agent is allowed to interact with the environment before the episode ends.
        * **learning rate** - hyperparameter that determines the step size taken by the learning algorithm in each iteration.
        * **pretraining and finetuning task** - define the environment the model will be trained on in each stage.
        * **pretraining and finetuning steps** - define how long the model will be trained in each stage.
        * **init weights** - initial values assigned to the model's parameters before training begins.
        * **log frequency** - how often the metrics are logged and the agent is evaluated.
        * **number of evaluation episodes** - number of rollouts when testing the agent.
        """
    )

    st.header("Limitations & Conclusions")
    st.write(
        """
        At the same time, this study, due to its preliminary nature, has numerous limitations which we hope to address
        in future work. We only considered a fairly strict formulation of the forgetting scenario where we assumed that
        the pre-trained model works perfectly on tasks that appear later in the fine-tuning. In practice, one should
        also consider the case when even though there are differences between the pre-training and fine-tuning tasks,
        transfer is still possible.
        """
    )

    st.write(
        """
        At the same time, even given these limitations, we see forgetting as an important problem to be solved and hope
        that addressing these issues in the future might help with building and fine-tuning better foundation models in RL.
        """
    )

    produce_delay = 1 / 1000
    consume_delay = 1 / 1000

    plot(data_placeholder)
    show_image(st.session_state.last_image, image_placeholder)

    asyncio.run(
        run_app(
            data_placeholder,
            st.session_state.queue,
            produce_delay,
            consume_delay,
            phase1steps,
            phase2steps,
            plotfrequency,
            num_eval_episodes,
            image_placeholder,
            st.session_state.queue_render,
            render_produce_delay,
            render_consume_delay,
        )
    )