Spaces:
Sleeping
Sleeping
| import gradio as gr | |
class Architecture:
    """Description of an MLP's hidden layers.

    ``hidden_units[i]`` and ``activations[i]`` describe hidden layer ``i``.
    Instances are treated as values: derive modified copies with
    :meth:`update`, and rely on ``==`` / ``hash`` comparing field contents.
    """

    def __init__(
        self,
        hidden_units: tuple[int, ...] = (64, 64),
        activations: tuple[str, ...] = ("ReLU", "ReLU"),
    ):
        # Coerce to tuples so instances stay hashable and tuple concatenation
        # (``hidden_units + (64,)``) works even if a caller passes lists.
        self.hidden_units = tuple(hidden_units)
        self.activations = tuple(activations)

    def update(self, **kwargs) -> "Architecture":
        """Return a new Architecture with the given fields replaced.

        Accepted keywords are ``hidden_units`` and ``activations``; omitted
        fields are copied from ``self``. ``self`` is not modified.

        Raises:
            TypeError: if an unknown keyword is passed (e.g. a typo), which
                would otherwise be silently ignored.
        """
        unknown = set(kwargs) - {"hidden_units", "activations"}
        if unknown:
            raise TypeError(
                f"update() got unexpected keyword argument(s): {sorted(unknown)}"
            )
        return Architecture(
            hidden_units=kwargs.get("hidden_units", self.hidden_units),
            activations=kwargs.get("activations", self.activations),
        )

    def __eq__(self, other):
        # Must agree with __hash__: equal field contents => equal objects.
        if not isinstance(other, Architecture):
            return NotImplemented
        return (self.hidden_units, self.activations) == (
            other.hidden_units,
            other.activations,
        )

    def __hash__(self):
        return hash((self.hidden_units, self.activations))

    def __repr__(self):
        return (
            f"{type(self).__name__}(hidden_units={self.hidden_units!r}, "
            f"activations={self.activations!r})"
        )

    @property
    def num_layers(self) -> int:
        """Number of hidden layers (length of ``hidden_units``)."""
        # A property, because callers read it as ``state.num_layers``.
        return len(self.hidden_units)
class ArchitectureView:
    """Gradio editor for the hidden layers of an ``Architecture`` state.

    ``build`` pre-creates ``max_layers`` rows of (hidden units, activation)
    inputs; adding or removing a layer only toggles which rows are visible
    and re-syncs their values from the state.
    """

    # Values used for a newly added layer and for rows beyond the current depth.
    DEFAULT_HIDDEN_UNITS = 64
    DEFAULT_ACTIVATION = "ReLU"
    ACTIVATION_CHOICES = (
        "ReLU",
        "Sigmoid",
        "Tanh",
        "LeakyReLU",
        "ELU",
        "GELU",
        "Identity",
    )

    def __init__(self, max_layers: int = 5):
        """Create a view able to display at most *max_layers* hidden layers."""
        self.max_layers = max_layers

    @staticmethod
    def _num_layers(state: Architecture) -> int:
        """Return the number of hidden layers in *state*."""
        # len() works whether Architecture.num_layers is a method or a property.
        return len(state.hidden_units)

    def _layer_updates(self, state: Architecture) -> list:
        """Build one ``gr.update`` per layer component, ordered like ``build``.

        Rows for existing layers are made visible and given the state's values
        so the UI never shows a stale value from a previously removed layer;
        all other rows are hidden.
        """
        num_layers = self._num_layers(state)
        updates = []
        for i in range(self.max_layers):
            if i < num_layers:
                updates.append(gr.update(visible=True, value=state.hidden_units[i]))
                updates.append(gr.update(visible=True, value=state.activations[i]))
            else:
                # twice: hidden units and activation
                updates.append(gr.update(visible=False))
                updates.append(gr.update(visible=False))
        return updates

    def update_layer_components(self, state: Architecture, *layer_components):
        """Rebuild *state* from the current values of the layer inputs.

        Args:
            state: current architecture; only its layer count is used.
            *layer_components: flattened ``(hidden_units, activation)`` values
                for all ``max_layers`` rows, in the order created by ``build``.

        Returns:
            A new Architecture taking the first ``num_layers`` row values.

        Raises:
            ValueError: if the number of values is not ``2 * max_layers``.
        """
        if len(layer_components) != self.max_layers * 2:
            raise ValueError("Incorrect number of layer components")
        active = layer_components[: self._num_layers(state) * 2]
        return state.update(
            hidden_units=tuple(active[0::2]),
            activations=tuple(active[1::2]),
        )

    def add_layer(self, state: Architecture):
        """Append a default layer (if below ``max_layers``) and refresh rows.

        Returns the new state followed by one update per layer component.
        """
        if self._num_layers(state) < self.max_layers:
            state = state.update(
                hidden_units=tuple(state.hidden_units) + (self.DEFAULT_HIDDEN_UNITS,),
                activations=tuple(state.activations) + (self.DEFAULT_ACTIVATION,),
            )
        return state, *self._layer_updates(state)

    def remove_layer(self, state: Architecture):
        """Drop the last layer (if any) and refresh rows.

        Returns the new state followed by one update per layer component.
        """
        if self._num_layers(state) > 0:
            state = state.update(
                hidden_units=tuple(state.hidden_units)[:-1],
                activations=tuple(state.activations)[:-1],
            )
        return state, *self._layer_updates(state)

    def build(self, state: gr.State):
        """Render the editor and wire its callbacks to *state*.

        Must be called inside a ``gr.Blocks`` context; ``state.value`` is the
        initial ``Architecture`` used for row visibility and initial values.
        """
        architecture = state.value
        num_layers = self._num_layers(architecture)
        layer_components = []
        with gr.Column():
            with gr.Row():
                add_button = gr.Button("Add Layer")
                remove_button = gr.Button("Remove Layer")
            for layer in range(self.max_layers):
                visible = layer < num_layers
                with gr.Row():
                    hidden_units = gr.Number(
                        label="Hidden units",
                        # Show the state's actual value for existing layers.
                        value=(
                            architecture.hidden_units[layer]
                            if visible
                            else self.DEFAULT_HIDDEN_UNITS
                        ),
                        visible=visible,
                        precision=0,
                    )
                    activation = gr.Dropdown(
                        label="Activation",
                        choices=list(self.ACTIVATION_CHOICES),
                        value=(
                            architecture.activations[layer]
                            if visible
                            else self.DEFAULT_ACTIVATION
                        ),
                        visible=visible,
                    )
                layer_components.extend((hidden_units, activation))
            # Fixed, read-only output layer.
            with gr.Row():
                gr.Number(
                    label="Output units",
                    value=1,
                    interactive=False,
                )
                gr.Textbox(
                    label="Activation",
                    value="Identity",
                    interactive=False,
                )

        # callbacks
        add_button.click(
            fn=self.add_layer,
            inputs=[state],
            outputs=[state, *layer_components],
        )
        remove_button.click(
            fn=self.remove_layer,
            inputs=[state],
            outputs=[state, *layer_components],
        )
        for hidden_units, activation in zip(
            layer_components[0::2], layer_components[1::2]
        ):
            hidden_units.submit(
                fn=self.update_layer_components,
                inputs=[state, *layer_components],
                outputs=[state],
            )
            activation.change(
                fn=self.update_layer_components,
                inputs=[state, *layer_components],
                outputs=[state],
            )