joel-woodfield commited on
Commit
41b207c
·
1 Parent(s): 5c45b99

Add MLP architecture configuration

Browse files
Files changed (2) hide show
  1. architecture.py +148 -0
  2. mlp_visualizer.py +22 -13
architecture.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+
4
+ class Architecture:
5
+ def __init__(
6
+ self,
7
+ hidden_units: tuple[int] = (64, 64),
8
+ activations: tuple[str] = ("ReLU", "ReLU"),
9
+ ):
10
+ self.hidden_units = hidden_units
11
+ self.activations = activations
12
+
13
+ def update(self, **kwargs):
14
+ return Architecture(
15
+ hidden_units=kwargs.get("hidden_units", self.hidden_units),
16
+ activations=kwargs.get("activations", self.activations),
17
+ )
18
+
19
+ def __hash__(self):
20
+ return hash((self.hidden_units, self.activations))
21
+
22
+ @property
23
+ def num_layers(self):
24
+ return len(self.hidden_units)
25
+
26
+
27
+ class ArchitectureView:
28
+ def __init__(self, max_layers: int = 5):
29
+ self.max_layers = max_layers
30
+
31
+ def update_layer_components(
32
+ self, state: Architecture, *layer_components
33
+ ):
34
+ if len(layer_components) != self.max_layers * 2:
35
+ raise ValueError("Incorrect number of layer components")
36
+
37
+ num_layers = state.num_layers
38
+
39
+ hidden_units = []
40
+ activations = []
41
+ for i in range(0, num_layers * 2, 2):
42
+ hidden_units.append(layer_components[i])
43
+ activations.append(layer_components[i + 1])
44
+
45
+ state = state.update(
46
+ hidden_units=tuple(hidden_units),
47
+ activations=tuple(activations),
48
+ )
49
+
50
+ return state
51
+
52
+ def add_layer(self, state: Architecture):
53
+ if state.num_layers < self.max_layers:
54
+ state = state.update(
55
+ hidden_units=state.hidden_units + (64,),
56
+ activations=state.activations + ("ReLU",),
57
+ )
58
+
59
+ updates = []
60
+ for i in range(self.max_layers):
61
+ # twice for hidden units and activation
62
+ updates.append(gr.update(visible=(i < state.num_layers)))
63
+ updates.append(gr.update(visible=(i < state.num_layers)))
64
+
65
+ return state, *updates
66
+
67
+ def remove_layer(self, state: Architecture):
68
+ if state.num_layers > 0:
69
+ state = state.update(
70
+ hidden_units=state.hidden_units[:-1],
71
+ activations=state.activations[:-1],
72
+ )
73
+
74
+ updates = []
75
+ for i in range(self.max_layers):
76
+ # twice for hidden units and activation
77
+ updates.append(gr.update(visible=(i < state.num_layers)))
78
+ updates.append(gr.update(visible=(i < state.num_layers)))
79
+
80
+ return state, *updates
81
+
82
+ def build(self, state: gr.State):
83
+ architecture = state.value
84
+
85
+ layer_components = []
86
+ with gr.Column():
87
+ with gr.Row():
88
+ add_layer = gr.Button("Add Layer")
89
+ remove_layer = gr.Button("Remove Layer")
90
+
91
+ for layer in range(self.max_layers):
92
+ with gr.Row():
93
+ hidden_units = gr.Number(
94
+ label="Hidden units",
95
+ value=64,
96
+ visible=(layer < architecture.num_layers),
97
+ precision=0,
98
+ )
99
+ activation = gr.Dropdown(
100
+ label="Activation",
101
+ choices=["ReLU", "Sigmoid", "Tanh", "LeakyReLU", "ELU", "GELU", "Identity"],
102
+ value="ReLU",
103
+ visible=(layer < architecture.num_layers),
104
+ )
105
+
106
+ layer_components.append(hidden_units)
107
+ layer_components.append(activation)
108
+
109
+ with gr.Row():
110
+ output_units = gr.Number(
111
+ label="Output units",
112
+ value=1,
113
+ interactive=False,
114
+ )
115
+ output_activation = gr.Textbox(
116
+ label="Activation",
117
+ value="Identity",
118
+ interactive=False,
119
+ )
120
+
121
+ # callbacks
122
+ add_layer.click(
123
+ fn=self.add_layer,
124
+ inputs=[state],
125
+ outputs=[state] + layer_components,
126
+ )
127
+ remove_layer.click(
128
+ fn=self.remove_layer,
129
+ inputs=[state],
130
+ outputs=[state] + layer_components,
131
+ )
132
+
133
+ for i, component in enumerate(layer_components):
134
+ # hidden unit
135
+ if i % 2 == 0:
136
+ component.submit(
137
+ fn=self.update_layer_components,
138
+ inputs=[state] + layer_components,
139
+ outputs=[state],
140
+ )
141
+ # activation
142
+ else:
143
+ component.change(
144
+ fn=self.update_layer_components,
145
+ inputs=[state] + layer_components,
146
+ outputs=[state],
147
+ )
148
+
mlp_visualizer.py CHANGED
@@ -29,6 +29,7 @@ logging.basicConfig(
29
  )
30
  logger = logging.getLogger("ELVIS")
31
 
 
32
  from dataset import Dataset, DatasetView, get_function
33
 
34
 
@@ -44,7 +45,7 @@ class MlpVisualizer:
44
  display: none;
45
  }"""
46
 
47
- def plot(self, dataset_options: Dataset):
48
  print("Plotting")
49
  t1 = time.time()
50
  fig = plt.figure(figsize=(self.canvas_width / 100., self.canvas_height / 100.0), dpi=100)
@@ -53,8 +54,8 @@ class MlpVisualizer:
53
  ax = fig.add_axes([0., 0., 1., 1.]) #
54
  ax.margins(x=0, y=0) # no padding in both directions
55
 
56
- if dataset_options.mode == "generate":
57
- x_test, y_test = get_function(dataset_options.function, xlim=(-2, 2), nsample=100)
58
 
59
  # y_pred = self.model(torch.from_numpy(x_test).float()).detach().numpy()
60
 
@@ -64,15 +65,15 @@ class MlpVisualizer:
64
  ax.set_xlabel("x")
65
  ax.set_ylabel("y")
66
 
67
- if dataset_options.mode == "generate":
68
  ax.set_ylim(y_test.min() - 1, y_test.max() + 1)
69
 
70
- x_train = dataset_options.x
71
- y_train = dataset_options.y
72
  if True:
73
  plt.scatter(x_train.flatten(), y_train, label='training data', color=self.plot_cmap(0))
74
 
75
- if dataset_options.mode == "generate":
76
  plt.plot(x_test.flatten(), y_test, label='true function', color=self.plot_cmap(1))
77
 
78
  if False:
@@ -98,13 +99,14 @@ class MlpVisualizer:
98
  gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>MLP Training Visualizer</div>")
99
 
100
  # states
101
- dataset_options = gr.State(Dataset())
 
102
 
103
  # GUI elements and layout
104
  with gr.Row():
105
  with gr.Column(scale=2):
106
  canvas = gr.Image(
107
- value=self.plot(dataset_options.value),
108
  show_download_button=False,
109
  container=True,
110
  )
@@ -112,15 +114,22 @@ class MlpVisualizer:
112
  with gr.Column(scale=1):
113
  with gr.Tab("Dataset"):
114
  dataset_view = DatasetView()
115
- dataset_view.build(state=dataset_options)
116
- dataset_options.change(
117
  fn=self.plot,
118
- inputs=[dataset_options],
119
  outputs=[canvas],
120
  )
121
 
122
  with gr.Tab("Architecture"):
123
- gr.Markdown("HI")
 
 
 
 
 
 
 
124
  with gr.Tab("Train"):
125
  gr.Markdown("HI")
126
  with gr.Tab("Plot"):
 
29
  )
30
  logger = logging.getLogger("ELVIS")
31
 
32
+ from architecture import Architecture, ArchitectureView
33
  from dataset import Dataset, DatasetView, get_function
34
 
35
 
 
45
  display: none;
46
  }"""
47
 
48
+ def plot(self, dataset: Dataset, architecture: Architecture) -> Image.Image:
49
  print("Plotting")
50
  t1 = time.time()
51
  fig = plt.figure(figsize=(self.canvas_width / 100., self.canvas_height / 100.0), dpi=100)
 
54
  ax = fig.add_axes([0., 0., 1., 1.]) #
55
  ax.margins(x=0, y=0) # no padding in both directions
56
 
57
+ if dataset.mode == "generate":
58
+ x_test, y_test = get_function(dataset.function, xlim=(-2, 2), nsample=100)
59
 
60
  # y_pred = self.model(torch.from_numpy(x_test).float()).detach().numpy()
61
 
 
65
  ax.set_xlabel("x")
66
  ax.set_ylabel("y")
67
 
68
+ if dataset.mode == "generate":
69
  ax.set_ylim(y_test.min() - 1, y_test.max() + 1)
70
 
71
+ x_train = dataset.x
72
+ y_train = dataset.y
73
  if True:
74
  plt.scatter(x_train.flatten(), y_train, label='training data', color=self.plot_cmap(0))
75
 
76
+ if dataset.mode == "generate":
77
  plt.plot(x_test.flatten(), y_test, label='true function', color=self.plot_cmap(1))
78
 
79
  if False:
 
99
  gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>MLP Training Visualizer</div>")
100
 
101
  # states
102
+ dataset = gr.State(Dataset())
103
+ architecture = gr.State(Architecture())
104
 
105
  # GUI elements and layout
106
  with gr.Row():
107
  with gr.Column(scale=2):
108
  canvas = gr.Image(
109
+ value=self.plot(dataset.value, architecture.value),
110
  show_download_button=False,
111
  container=True,
112
  )
 
114
  with gr.Column(scale=1):
115
  with gr.Tab("Dataset"):
116
  dataset_view = DatasetView()
117
+ dataset_view.build(state=dataset)
118
+ dataset.change(
119
  fn=self.plot,
120
+ inputs=[dataset],
121
  outputs=[canvas],
122
  )
123
 
124
  with gr.Tab("Architecture"):
125
+ architecture_view = ArchitectureView()
126
+ architecture_view.build(state=architecture)
127
+ architecture.change(
128
+ fn=self.plot,
129
+ inputs=[dataset, architecture],
130
+ outputs=[canvas],
131
+ )
132
+
133
  with gr.Tab("Train"):
134
  gr.Markdown("HI")
135
  with gr.Tab("Plot"):