Spaces:
Build error
Build error
Jensen Holm
commited on
Commit
·
498c4e0
1
Parent(s):
d7ea050
making the example code cleaner
Browse files- example.py +12 -11
example.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from sklearn import datasets
|
| 2 |
from sklearn.preprocessing import OneHotEncoder
|
| 3 |
from sklearn.model_selection import train_test_split
|
| 4 |
-
from sklearn.metrics import accuracy_score
|
| 5 |
import numpy as np
|
| 6 |
from numpyneuron import (
|
| 7 |
NN,
|
|
@@ -14,7 +14,7 @@ from numpyneuron import (
|
|
| 14 |
RANDOM_SEED = 2
|
| 15 |
|
| 16 |
|
| 17 |
-
def
|
| 18 |
seed: int,
|
| 19 |
) -> tuple[np.ndarray, ...]:
|
| 20 |
digits = datasets.load_digits(as_frame=False)
|
|
@@ -30,9 +30,10 @@ def _preprocess_digits(
|
|
| 30 |
return X_train, X_test, y_train, y_test
|
| 31 |
|
| 32 |
|
| 33 |
-
def train_nn_classifier(
|
| 34 |
-
X_train
|
| 35 |
-
|
|
|
|
| 36 |
nn_classifier = NN(
|
| 37 |
epochs=2_000,
|
| 38 |
hidden_size=16,
|
|
@@ -50,16 +51,16 @@ def train_nn_classifier() -> None:
|
|
| 50 |
X_train=X_train,
|
| 51 |
y_train=y_train,
|
| 52 |
)
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
|
|
|
|
|
|
|
| 55 |
|
|
|
|
| 56 |
pred = np.argmax(pred, axis=1)
|
| 57 |
y_test = np.argmax(y_test, axis=1)
|
| 58 |
|
| 59 |
accuracy = accuracy_score(y_true=y_test, y_pred=pred)
|
| 60 |
-
|
| 61 |
print(f"accuracy on validation set: {accuracy:.4f}")
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
if __name__ == "__main__":
|
| 65 |
-
train_nn_classifier()
|
|
|
|
| 1 |
from sklearn import datasets
|
| 2 |
from sklearn.preprocessing import OneHotEncoder
|
| 3 |
from sklearn.model_selection import train_test_split
|
| 4 |
+
from sklearn.metrics import accuracy_score
|
| 5 |
import numpy as np
|
| 6 |
from numpyneuron import (
|
| 7 |
NN,
|
|
|
|
| 14 |
RANDOM_SEED = 2
|
| 15 |
|
| 16 |
|
| 17 |
+
def preprocess_digits(
|
| 18 |
seed: int,
|
| 19 |
) -> tuple[np.ndarray, ...]:
|
| 20 |
digits = datasets.load_digits(as_frame=False)
|
|
|
|
| 30 |
return X_train, X_test, y_train, y_test
|
| 31 |
|
| 32 |
|
| 33 |
+
def train_nn_classifier(
|
| 34 |
+
X_train: np.ndarray,
|
| 35 |
+
y_train: np.ndarray,
|
| 36 |
+
) -> NN:
|
| 37 |
nn_classifier = NN(
|
| 38 |
epochs=2_000,
|
| 39 |
hidden_size=16,
|
|
|
|
| 51 |
X_train=X_train,
|
| 52 |
y_train=y_train,
|
| 53 |
)
|
| 54 |
+
return nn_classifier
|
| 55 |
+
|
| 56 |
|
| 57 |
+
if __name__ == "__main__":
|
| 58 |
+
X_train, X_test, y_train, y_test = preprocess_digits(seed=RANDOM_SEED)
|
| 59 |
+
classifier = train_nn_classifier(X_train, y_train)
|
| 60 |
|
| 61 |
+
pred = classifier.predict(X_test)
|
| 62 |
pred = np.argmax(pred, axis=1)
|
| 63 |
y_test = np.argmax(y_test, axis=1)
|
| 64 |
|
| 65 |
accuracy = accuracy_score(y_true=y_test, y_pred=pred)
|
|
|
|
| 66 |
print(f"accuracy on validation set: {accuracy:.4f}")
|
|
|
|
|
|
|
|
|
|
|
|