File size: 2,744 Bytes
20c9cef
 
61730e7
 
 
 
0a96c3e
61730e7
 
1169f2c
 
 
 
61730e7
 
 
 
1169f2c
61730e7
1169f2c
 
 
61730e7
 
 
 
 
 
 
 
 
 
 
0a96c3e
 
 
 
 
 
 
 
61730e7
 
20c9cef
 
 
 
 
 
 
61730e7
1169f2c
61730e7
 
1169f2c
61730e7
 
 
1169f2c
61730e7
1169f2c
61730e7
 
 
 
 
 
 
 
 
 
 
 
 
1169f2c
61730e7
 
 
0a96c3e
 
61730e7
b903cf1
61730e7
20c9cef
 
61730e7
1d54526
 
61730e7
 
20c9cef
61730e7
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import pickle

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
{{ model_import_statement }}{% if normalize %}{{ normalizer_import_statement }}{% endif %}{{ embedder_import_statement }}


def get_label_to_color_map(labels):
    """Build a dict that assigns a plotting color to every label.

    A label whose lowercase form is a CSS4 color name is mapped to that
    name, so e.g. a class called "Red" is drawn in red. Every other label
    receives an entry of the "tab10" colormap chosen by its position in
    ``labels`` (modulo 10, so colors repeat after ten positions).

    Args:
        labels: Iterable of string labels (``.lower()`` is called on each).

    Returns:
        dict mapping each label to a color name or to the value returned
        by the colormap.
    """
    palette = plt.get_cmap("tab10")
    css_names = set(matplotlib.colors.CSS4_COLORS.keys())

    def pick_color(position, label):
        # prefer the color the label itself names, else fall back to the palette
        lowered = label.lower()
        return lowered if lowered in css_names else palette(position % 10)

    return {label: pick_color(idx, label) for idx, label in enumerate(labels)}


def load_dataset(file_path):
    """Load features and labels from a CSV file.

    Args:
        file_path: Path handed to ``pd.read_csv``.

    Returns:
        A tuple ``(X, y)``: ``X`` holds the values of every column other
        than ``target``; ``y`` holds the ``target`` column cast to strings.

    Raises:
        ValueError: If the file yields no rows.
    """
    data = pd.read_csv(file_path)
    # every column except 'target' is treated as a feature
    X = data.loc[:, data.columns != 'target'].values
    # labels become strings; get_label_to_color_map calls .lower() on them
    y = data['target'].values.astype(str)

    if len(X) == 0:
        raise ValueError("The dataset is empty or not properly formatted.")

    # NOTE: the template tag fused to the return line below opens the
    # conditional section that holds process_features; keep them together.
    return X, y{% if normalize or jitter %}


# process_features(X) -> X
# Post-process the feature matrix before the model and embedder are fitted.
# This function is only rendered when normalization and/or jitter is enabled
# in the template; each step below is rendered only if its flag is set.
# The processed X is returned.
def process_features(X):{% if normalize %}
    # fit the configured normalizer on X and replace X by the transformed data
    normalizer = {{ normalizer_class }}()
    X = normalizer.fit_transform(X){% endif %}{% if jitter %}
    # add zero-mean Gaussian noise (scale filled in by the template) to every
    # entry of X; np.random is not seeded here, so results vary between runs
    X = X + np.random.normal(0, {{ jitter_scale }}, X.shape){% endif %}
    return X{% endif %}


def load_model(file_path):
    """Deserialize and return the object stored in a pickle file.

    Args:
        file_path: Path of the pickle file; opened in binary read mode.

    Returns:
        The object that ``pickle.load`` reconstructs from the file.

    NOTE(review): unpickling can execute arbitrary code, so only use this
    on files that come from a trusted source.
    """
    with open(file_path, "rb") as pickle_file:
        return pickle.load(pickle_file)


def create_plot(X, y, model, embedder):
    """Create a 2D plot of the data points and the decision regions.

    Draws onto the current pyplot figure and returns nothing; showing or
    saving the figure is left to the caller.

    Args:
        X: Features of the data points; passed to ``embedder.transform``.
        y: Labels aligned with the rows of ``X``.
            NOTE(review): the mask ``y == label`` presumably requires a
            NumPy array rather than a list — confirm for other callers.
        model: Object whose ``predict`` is called on the grid points after
            they are mapped back through ``embedder.inverse_transform``.
        embedder: Object providing ``transform`` and ``inverse_transform``.
            NOTE(review): assumed to be fitted already and to produce at
            least two columns — confirm.
    """
    # plot data points
    labels = np.unique(y)
    label_to_color = get_label_to_color_map(labels)

    X_embedded = embedder.transform(X)

    # one scatter call per label so that each class gets its own legend entry
    for label in labels:
        subset = X_embedded[y == label]
        plt.scatter(subset[:, 0], subset[:, 1], color=label_to_color[label], label=label)
    # the legend is created before the grid is drawn, so it lists only the classes
    plt.legend()

    # plot decision regions
    # grid bounds and resolution are filled in when the template is rendered
    xx, yy = np.meshgrid(
        np.linspace({{ x_min }}, {{ x_max }}, {{ num_dots }}),
        np.linspace({{ y_min }}, {{ y_max }}, {{ num_dots }}),
    )
    # flatten so that (xx[i], yy[i]) is one grid point
    xx = xx.ravel()
    yy = yy.ravel()

    # stack the coordinates as two columns: one row per grid point
    grid = np.c_[xx, yy]
    # the grid lives in the embedded 2D space; map it back before predicting
    predictions = model.predict(embedder.inverse_transform(grid))
    predictions = predictions.ravel()
    # color every grid point by its predicted label; a prediction that is not
    # among the labels found in y would raise KeyError in this lookup
    plt.scatter(xx, yy, c=[label_to_color[p] for p in predictions], s=1, alpha=0.5)


def main():
    """Fit the configured model and show its decision regions in 2D.

    Loads the dataset, applies process_features when that step is rendered
    by the template, fits the model and the embedder on the same features,
    draws the plot and opens it in a window.
    """
    X, y = load_dataset("{{ dataset_file }}"){% if normalize or jitter %}
    X = process_features(X){% endif %}

    # model class and constructor arguments are filled in by the template
    model = {{ model_class }}({{ model_params }})
    model.fit(X, y)
    # to use a previously saved model instead of fitting one:
    # model = load_model("model.pkl")

    # the embedder is fitted on the same (possibly processed) features as the model
    embedder = {{ embedder_class }}({{ embedder_args }})
    embedder.fit(X)
    create_plot(X, y, model, embedder)
    plt.show()
    # to write the figure to an image file:
    # plt.savefig("plot.png")


if __name__ == "__main__":
    main()