Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from train import TrainingLoop | |
| from scipy.special import softmax | |
| import numpy as np | |
# Module-level state shared by the Gradio callbacks below.
train = None  # the TrainingLoop instance, set by create_training_loop()
frames = attributions = None  # results of the last attribution run, set by generate_output()

# Human-readable names for the LunarLander observation features, keyed by
# their position (0-7) in the observation vector.
lunar_lander_spec_conversion = dict(
    enumerate(
        [
            "X-coordinate",
            "Y-coordinate",
            "Linear velocity in the X-axis",
            "Linear velocity in the Y-axis",
            "Angle",
            "Angular velocity",
            "Left leg touched the floor",
            "Right leg touched the floor",
        ]
    )
)
def create_training_loop(env_spec):
    """Create the global training loop and its agent for the given environment id.

    Args:
        env_spec: Environment id typed by the user (e.g. "LunarLander-v2").
            Leading/trailing whitespace is stripped before use.

    Returns:
        ``train.env.spec`` of the newly created loop, displayed in the JSON output.

    Raises:
        gr.Error: If ``env_spec`` is empty or whitespace-only.
    """
    global train
    # Text boxes easily pick up stray spaces/newlines, which presumably would
    # not match a registered environment id.
    env_spec = (env_spec or "").strip()
    if not env_spec:
        raise gr.Error("Please enter an environment specification, e.g. LunarLander-v2.")
    # Build and initialize locally first so a failure in create_agent() does not
    # leave a half-initialized loop in the global used by generate_output().
    loop = TrainingLoop(env_spec=env_spec)
    loop.create_agent()
    train = loop
    return train.env.spec
def display_softmax(inputs, names=None):
    """Convert raw per-feature attribution scores into a label -> probability dict.

    The scores are flattened and passed through a softmax so they sum to 1,
    giving the mapping shown by the ``gr.Label`` output.

    Args:
        inputs: Array-like of attribution scores, one per observation feature.
        names: Optional iterable of feature labels. Defaults to the LunarLander
            labels in ``lunar_lander_spec_conversion``.

    Returns:
        dict mapping each feature label to its softmax probability as a float.
        If the number of labels does not match the number of scores, generic
        "Feature i" labels are used so that no score is dropped or mislabeled
        (the environment is user-selectable, not necessarily LunarLander).
    """
    # Flatten so e.g. a (1, n) attribution yields n scalar probabilities;
    # softmax over all elements gives identical values either way.
    scores = np.asarray(inputs).ravel()
    probabilities = softmax(scores)
    if names is None:
        names = lunar_lander_spec_conversion.values()
    names = list(names)
    if len(names) != len(probabilities):
        names = [f"Feature {i}" for i in range(len(probabilities))]
    return {name: float(prob) for name, prob in zip(names, probabilities)}
def generate_output(num_iterations, option):
    """Run attribution on the trained agent and cache the results globally.

    Stores the frames and attributions returned by ``train.explain_trained``
    in the module-level ``frames`` / ``attributions`` so the key-frame viewer
    can index into them.

    Args:
        num_iterations: Number of baseline iterations (from the baseline slider).
        option: Index of the selected baseline choice in the dropdown.

    Raises:
        gr.Error: If the environment has not been created yet.
    """
    global frames, attributions
    if train is None:
        raise gr.Error("Create the environment first, then click ATTRIBUTE.")
    # The slider delivers a number; explain_trained presumably expects an integer count.
    frames, attributions = train.explain_trained(num_iterations=int(num_iterations), option=option)
    # The highest valid key-frame index is len(frames) - 1.
    # NOTE(review): assigning a component attribute after launch likely does not
    # update the rendered slider; the frontend range must be refreshed via an
    # event output.
    slider.maximum = max(len(frames) - 1, 0)
def get_frame_and_attribution(slider_value):
    """Return the cached frame and its softmaxed attribution for a key-frame index.

    Args:
        slider_value: Key-frame index from the slider. It is cast to int and
            clamped into ``[0, len(frames) - 1]`` because the slider's range
            (up to 20000) can exceed the number of cached frames.

    Returns:
        tuple of (frame, label -> probability dict) for the image and label outputs.

    Raises:
        gr.Error: If no attribution run has produced frames yet.
    """
    # Only reading the module-level results here, so no ``global`` statement is needed.
    if frames is None or len(frames) == 0:
        raise gr.Error("No frames available yet - click ATTRIBUTE first.")
    index = min(max(int(slider_value), 0), len(frames) - 1)
    # Assumes attributions is index-aligned with frames (both come from one
    # explain_trained call) - TODO confirm.
    return frames[index], display_softmax(attributions[index])
def _refresh_key_frame_slider():
    """Resize the key-frame slider to the cached frames and reset it to index 0."""
    last_index = max(len(frames) - 1, 0) if frames is not None else 0
    return gr.update(maximum=last_index, value=0)


with gr.Blocks() as demo:
    gr.Markdown("# Introspection in Deep Reinforcement Learning")
    with gr.Tab(label="Attribute"):
        env_spec = gr.Textbox(label="Environment Specification (e.g.: LunarLander-v2)", lines=1)
        env = gr.Interface(title="Create the Environment", allow_flagging="never", inputs=env_spec, fn=create_training_loop, outputs=gr.JSON())
        with gr.Row():
            # Default to the first baseline so the handler never receives a None index.
            option = gr.Dropdown(choices=["Torch Tensor of 0's", "Running Average"], value="Torch Tensor of 0's", type="index")
            baselines = gr.Slider(label="Number of Baseline Iterations", interactive=True, minimum=0, maximum=100, value=10, step=5, info="Baseline inputs to collect for the average", render=True)
        attribute_button = gr.Button("ATTRIBUTE")
        slider = gr.Slider(label="Key Frame", minimum=0, maximum=20000, step=1, value=0)
        # After attributing, shrink the key-frame slider to the frames actually
        # produced; changing the component attribute server-side does not reach
        # the browser, so the update must be sent as an event output.
        attribute_button.click(fn=generate_output, inputs=[baselines, option]).then(fn=_refresh_key_frame_slider, outputs=slider)
        gr.Interface(fn=get_frame_and_attribution, inputs=slider, live=True, outputs=[gr.Image(), gr.Label()])
demo.launch()