Spaces:
Sleeping
Sleeping
Commit ·
8b69b54
1
Parent(s): 09ce41d
Set fixed colors for each plot component
Browse files- gp_visualizer.py +14 -6
gp_visualizer.py
CHANGED
|
@@ -95,6 +95,8 @@ class GPVisualizer:
|
|
| 95 |
"show_predictions": True,
|
| 96 |
}
|
| 97 |
|
|
|
|
|
|
|
| 98 |
self.css = """
|
| 99 |
.hidden-button {
|
| 100 |
display: none;
|
|
@@ -136,19 +138,25 @@ class GPVisualizer:
|
|
| 136 |
|
| 137 |
if self.plot_options["show_training_data"]:
|
| 138 |
if len(X_tr) > 1:
|
| 139 |
-
plt.scatter(X_tr.flatten(), y_tr, label='training data (R2=%.2f)' % (R2))
|
| 140 |
else:
|
| 141 |
-
plt.scatter(X_tr.flatten(), y_tr, label='training data')
|
| 142 |
|
| 143 |
if self.plot_options["show_true_function"]:
|
| 144 |
-
plt.plot(X_ts.flatten(), np.sin(2*np.pi*X_ts.flatten()),
|
| 145 |
|
| 146 |
if self.plot_options["show_predictions"]:
|
| 147 |
-
plt.scatter(X_ts.flatten(), y_pred, marker='+', label='predictions')
|
| 148 |
|
| 149 |
if self.plot_options["show_confidence_interval"]:
|
| 150 |
-
plt.fill_between(
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
plt.legend()
|
| 154 |
|
|
|
|
| 95 |
"show_predictions": True,
|
| 96 |
}
|
| 97 |
|
| 98 |
+
self.plot_cmap = plt.get_cmap("tab20")
|
| 99 |
+
|
| 100 |
self.css = """
|
| 101 |
.hidden-button {
|
| 102 |
display: none;
|
|
|
|
| 138 |
|
| 139 |
if self.plot_options["show_training_data"]:
|
| 140 |
if len(X_tr) > 1:
|
| 141 |
+
plt.scatter(X_tr.flatten(), y_tr, label='training data (R2=%.2f)' % (R2), color=self.plot_cmap(0))
|
| 142 |
else:
|
| 143 |
+
plt.scatter(X_tr.flatten(), y_tr, label='training data', color=self.plot_cmap(0))
|
| 144 |
|
| 145 |
if self.plot_options["show_true_function"]:
|
| 146 |
+
plt.plot(X_ts.flatten(), np.sin(2*np.pi*X_ts.flatten()), label='true function', color=self.plot_cmap(1))
|
| 147 |
|
| 148 |
if self.plot_options["show_predictions"]:
|
| 149 |
+
plt.scatter(X_ts.flatten(), y_pred, marker='+', label='predictions', color=self.plot_cmap(2))
|
| 150 |
|
| 151 |
if self.plot_options["show_confidence_interval"]:
|
| 152 |
+
plt.fill_between(
|
| 153 |
+
X_ts.flatten(),
|
| 154 |
+
y_pred - 1.96*y_std,
|
| 155 |
+
y_pred + 1.96*y_std,
|
| 156 |
+
alpha=0.2,
|
| 157 |
+
label='95% confidence interval',
|
| 158 |
+
color=self.plot_cmap(3)
|
| 159 |
+
)
|
| 160 |
|
| 161 |
plt.legend()
|
| 162 |
|