mlp_visualizer / old /architecture.py
joel-woodfield's picture
Refactor to separate frontend and backend logic
7af7098
import gradio as gr
class Architecture:
    """Immutable-style description of an MLP's hidden layers.

    ``hidden_units[i]`` and ``activations[i]`` describe hidden layer ``i``.
    Instances are never modified in place by this class: ``update`` returns
    a new ``Architecture``. Equality and hashing are both based on the
    ``(hidden_units, activations)`` pair so that equal architectures compare
    equal and hash alike.
    """

    def __init__(
        self,
        hidden_units: tuple[int, ...] = (64, 64),
        activations: tuple[str, ...] = ("ReLU", "ReLU"),
    ):
        """Store the per-layer unit counts and activation names.

        Args:
            hidden_units: Number of units for each hidden layer.
            activations: Activation name for each hidden layer.
        """
        # Normalise to tuples so that any sequence (e.g. a list) still yields
        # a hashable object that supports tuple concatenation.
        self.hidden_units = tuple(hidden_units)
        self.activations = tuple(activations)

    def update(self, **kwargs):
        """Return a new ``Architecture`` with the given fields replaced.

        Only the ``hidden_units`` and ``activations`` keys are read; fields
        that are not supplied keep their current value and any other keyword
        is ignored.
        """
        return Architecture(
            hidden_units=kwargs.get("hidden_units", self.hidden_units),
            activations=kwargs.get("activations", self.activations),
        )

    def _key(self):
        """Return the tuple that defines this object's identity."""
        return (self.hidden_units, self.activations)

    def __eq__(self, other):
        # Defined together with __hash__ so both agree on what "same" means.
        if not isinstance(other, Architecture):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        return hash(self._key())

    def __repr__(self):
        return (
            f"{type(self).__name__}("
            f"hidden_units={self.hidden_units!r}, "
            f"activations={self.activations!r})"
        )

    @property
    def num_layers(self):
        """Number of hidden layers (the length of ``hidden_units``)."""
        return len(self.hidden_units)
class ArchitectureView:
    """Gradio front end for editing the hidden layers of an ``Architecture``.

    The view creates ``max_layers`` rows, each holding a "Hidden units"
    number box and an "Activation" dropdown, and shows only the first
    ``num_layers`` of them. The row components are kept in one flat list in
    the order ``[units_0, activation_0, units_1, activation_1, ...]``; every
    callback below relies on that interleaved order.
    """

    # Values given to a newly added layer and to rows that start out hidden.
    DEFAULT_HIDDEN_UNITS = 64
    DEFAULT_ACTIVATION = "ReLU"
    ACTIVATION_CHOICES = (
        "ReLU",
        "Sigmoid",
        "Tanh",
        "LeakyReLU",
        "ELU",
        "GELU",
        "Identity",
    )

    def __init__(self, max_layers: int = 5):
        """Remember how many layer rows the view creates.

        Args:
            max_layers: Upper bound on the number of hidden layers.
        """
        self.max_layers = max_layers

    def _layer_updates(self, num_layers, reset_layer=None):
        """Build the ``gr.update`` list for all layer components.

        Args:
            num_layers: Rows with an index below this are made visible.
            reset_layer: Optional row index whose two components are also
                reset to the default values.

        Returns:
            A list of ``2 * max_layers`` updates in component order.
        """
        updates = []
        for layer in range(self.max_layers):
            visible = layer < num_layers
            if layer == reset_layer:
                updates.append(
                    gr.update(visible=visible, value=self.DEFAULT_HIDDEN_UNITS)
                )
                updates.append(
                    gr.update(visible=visible, value=self.DEFAULT_ACTIVATION)
                )
            else:
                # one update each for the hidden-units box and the dropdown
                updates.append(gr.update(visible=visible))
                updates.append(gr.update(visible=visible))
        return updates

    def update_layer_components(
        self, state: Architecture, *layer_components
    ):
        """Return a new state built from the current component values.

        Args:
            state: Current architecture; only its ``num_layers`` is used to
                decide how many rows to read.
            *layer_components: Values of all ``2 * max_layers`` components in
                interleaved (units, activation) order.

        Returns:
            An ``Architecture`` holding the values of the first
            ``num_layers`` rows.

        Raises:
            ValueError: If the number of values is not ``2 * max_layers``.
        """
        if len(layer_components) != self.max_layers * 2:
            raise ValueError("Incorrect number of layer components")

        layer_indices = range(state.num_layers)
        # Even positions hold hidden units, odd positions hold activations.
        return state.update(
            hidden_units=tuple(layer_components[2 * i] for i in layer_indices),
            activations=tuple(
                layer_components[2 * i + 1] for i in layer_indices
            ),
        )

    def add_layer(self, state: Architecture):
        """Append a default layer if there is room and refresh the rows.

        Returns:
            A tuple of the (possibly unchanged) state followed by one update
            per layer component.
        """
        new_layer = None
        if state.num_layers < self.max_layers:
            new_layer = state.num_layers
            state = state.update(
                hidden_units=state.hidden_units + (self.DEFAULT_HIDDEN_UNITS,),
                activations=state.activations + (self.DEFAULT_ACTIVATION,),
            )
        # The revealed row may still hold values from before it was hidden;
        # reset it so the widgets agree with the defaults stored in state.
        return (
            state,
            *self._layer_updates(state.num_layers, reset_layer=new_layer),
        )

    def remove_layer(self, state: Architecture):
        """Drop the last layer if there is one and refresh the rows.

        Returns:
            A tuple of the (possibly unchanged) state followed by one update
            per layer component.
        """
        if state.num_layers > 0:
            state = state.update(
                hidden_units=state.hidden_units[:-1],
                activations=state.activations[:-1],
            )
        return (state, *self._layer_updates(state.num_layers))

    def build(self, state: gr.State):
        """Create the components and wire their callbacks.

        Args:
            state: ``gr.State`` whose ``value`` is the initial
                ``Architecture``; it is the input and output of every
                callback registered here.
        """
        architecture = state.value
        layer_components = []

        with gr.Column():
            with gr.Row():
                add_layer = gr.Button("Add Layer")
                remove_layer = gr.Button("Remove Layer")

            for layer in range(self.max_layers):
                visible = layer < architecture.num_layers
                # Seed visible rows from the state so the widgets match it;
                # hidden rows start with the defaults.
                units_value = (
                    architecture.hidden_units[layer]
                    if visible
                    else self.DEFAULT_HIDDEN_UNITS
                )
                activation_value = (
                    architecture.activations[layer]
                    if visible and layer < len(architecture.activations)
                    else self.DEFAULT_ACTIVATION
                )
                with gr.Row():
                    hidden_units = gr.Number(
                        label="Hidden units",
                        value=units_value,
                        visible=visible,
                        precision=0,
                    )
                    activation = gr.Dropdown(
                        label="Activation",
                        choices=list(self.ACTIVATION_CHOICES),
                        value=activation_value,
                        visible=visible,
                    )
                layer_components.append(hidden_units)
                layer_components.append(activation)

            # Read-only row for the output layer; not part of the state.
            with gr.Row():
                gr.Number(
                    label="Output units",
                    value=1,
                    interactive=False,
                )
                gr.Textbox(
                    label="Activation",
                    value="Identity",
                    interactive=False,
                )

        # callbacks
        add_layer.click(
            fn=self.add_layer,
            inputs=[state],
            outputs=[state] + layer_components,
        )
        remove_layer.click(
            fn=self.remove_layer,
            inputs=[state],
            outputs=[state] + layer_components,
        )

        # Registered row by row (units first, then activation) to keep the
        # original registration order.
        units_boxes = layer_components[0::2]
        activation_boxes = layer_components[1::2]
        for units_box, activation_box in zip(units_boxes, activation_boxes):
            units_box.submit(
                fn=self.update_layer_components,
                inputs=[state] + layer_components,
                outputs=[state],
            )
            activation_box.change(
                fn=self.update_layer_components,
                inputs=[state] + layer_components,
                outputs=[state],
            )