decision_boundary / old_code /export_code_template.py.j2
joel-woodfield's picture
Basic implementation of react version
9b05cbd
import pickle
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
{{ model_import_statement }}{% if normalize %}{{ normalizer_import_statement }}{% endif %}{{ embedder_import_statement }}
def get_label_to_color_map(labels):
    """Assign a matplotlib color to every label.

    A label whose lowercase form is a CSS4 color name is drawn in that named
    color; any other label receives the "tab10" colormap entry at its
    position in ``labels`` modulo 10.
    """
    palette = plt.get_cmap("tab10")
    css_names = set(matplotlib.colors.CSS4_COLORS.keys())

    def pick_color(position, label):
        lowered = label.lower()
        return lowered if lowered in css_names else palette(position % 10)

    return {label: pick_color(position, label) for position, label in enumerate(labels)}
def load_dataset(file_path):
"""Load dataset from a CSV file."""
data = pd.read_csv(file_path)
X = data.loc[:, data.columns != 'target'].values
y = data['target'].values.astype(str)
if len(X) == 0:
raise ValueError("The dataset is empty or not properly formatted.")
return X, y{% if normalize or jitter %}
def process_features(X):{% if normalize %}
normalizer = {{ normalizer_class }}()
X = normalizer.fit_transform(X){% endif %}{% if jitter %}
X = X + np.random.normal(0, {{ jitter_scale }}, X.shape){% endif %}
return X{% endif %}
def load_model(file_path):
    """Unpickle and return the object stored at ``file_path``.

    NOTE: unpickling can execute arbitrary code, so only load files you trust.
    """
    with open(file_path, "rb") as stream:
        return pickle.load(stream)
def create_plot(X, y, model, embedder):
    """Draw the embedded data points and the model's decision regions.

    The samples in ``X`` are projected with ``embedder.transform`` and drawn
    one class at a time. A regular grid over the plot area is then mapped
    back through ``embedder.inverse_transform``, classified by ``model``, and
    drawn as small semi-transparent dots colored by predicted label.
    """
    classes = np.unique(y)
    palette = get_label_to_color_map(classes)

    # Data points: one scatter call per class so each gets its own legend entry.
    projected = embedder.transform(X)
    for cls in classes:
        members = projected[y == cls]
        plt.scatter(members[:, 0], members[:, 1], color=palette[cls], label=cls)
    plt.legend()

    # Decision regions: classify every point of a regular grid in the embedded space.
    x_axis = np.linspace({{ x_min }}, {{ x_max }}, {{ num_dots }})
    y_axis = np.linspace({{ y_min }}, {{ y_max }}, {{ num_dots }})
    grid_x, grid_y = np.meshgrid(x_axis, y_axis)
    grid_x = grid_x.ravel()
    grid_y = grid_y.ravel()
    grid_points = np.column_stack((grid_x, grid_y))
    predicted = model.predict(embedder.inverse_transform(grid_points)).ravel()
    region_colors = [palette[label] for label in predicted]
    plt.scatter(grid_x, grid_y, c=region_colors, s=1, alpha=0.5)
def main():
X, y = load_dataset("{{ dataset_file }}"){% if normalize or jitter %}
X = process_features(X){% endif %}
model = {{ model_class }}({{ model_params }})
model.fit(X, y)
# if you want to load a model
# model = load_model("model.pkl")
embedder = {{ embedder_class }}({{ embedder_args }})
embedder.fit(X)
create_plot(X, y, model, embedder)
plt.show()
# if you want to save as image
# plt.savefig("plot.png")
# Run the script only when executed directly, not when imported as a module.
if __name__ == "__main__":
    main()