joel-woodfield commited on
Commit
8b69b54
·
1 Parent(s): 09ce41d

Set fixed colors for each plot component

Browse files
Files changed (1) hide show
  1. gp_visualizer.py +14 -6
gp_visualizer.py CHANGED
@@ -95,6 +95,8 @@ class GPVisualizer:
95
  "show_predictions": True,
96
  }
97
 
 
 
98
  self.css = """
99
  .hidden-button {
100
  display: none;
@@ -136,19 +138,25 @@ class GPVisualizer:
136
 
137
  if self.plot_options["show_training_data"]:
138
  if len(X_tr) > 1:
139
- plt.scatter(X_tr.flatten(), y_tr, label='training data (R2=%.2f)' % (R2))
140
  else:
141
- plt.scatter(X_tr.flatten(), y_tr, label='training data')
142
 
143
  if self.plot_options["show_true_function"]:
144
- plt.plot(X_ts.flatten(), np.sin(2*np.pi*X_ts.flatten()), color='red', label='true function')
145
 
146
  if self.plot_options["show_predictions"]:
147
- plt.scatter(X_ts.flatten(), y_pred, marker='+', label='predictions')
148
 
149
  if self.plot_options["show_confidence_interval"]:
150
- plt.fill_between(X_ts.flatten(), y_pred - 1.96*y_std, y_pred + 1.96*y_std, alpha=0.5,
151
- label='95% confidence interval')
 
 
 
 
 
 
152
 
153
  plt.legend()
154
 
 
95
  "show_predictions": True,
96
  }
97
 
98
+ self.plot_cmap = plt.get_cmap("tab20")
99
+
100
  self.css = """
101
  .hidden-button {
102
  display: none;
 
138
 
139
  if self.plot_options["show_training_data"]:
140
  if len(X_tr) > 1:
141
+ plt.scatter(X_tr.flatten(), y_tr, label='training data (R2=%.2f)' % (R2), color=self.plot_cmap(0))
142
  else:
143
+ plt.scatter(X_tr.flatten(), y_tr, label='training data', color=self.plot_cmap(0))
144
 
145
  if self.plot_options["show_true_function"]:
146
+ plt.plot(X_ts.flatten(), np.sin(2*np.pi*X_ts.flatten()), label='true function', color=self.plot_cmap(1))
147
 
148
  if self.plot_options["show_predictions"]:
149
+ plt.scatter(X_ts.flatten(), y_pred, marker='+', label='predictions', color=self.plot_cmap(2))
150
 
151
  if self.plot_options["show_confidence_interval"]:
152
+ plt.fill_between(
153
+ X_ts.flatten(),
154
+ y_pred - 1.96*y_std,
155
+ y_pred + 1.96*y_std,
156
+ alpha=0.2,
157
+ label='95% confidence interval',
158
+ color=self.plot_cmap(3)
159
+ )
160
 
161
  plt.legend()
162