Spaces:
Build error
Build error
Commit
·
92f14e0
1
Parent(s):
f9522cf
better organization for generalization once we add other methods and such
Browse files- main.py +0 -5
- neural_network/main.py +6 -8
main.py
CHANGED
|
@@ -22,12 +22,7 @@ if __name__ == "__main__":
|
|
| 22 |
raise ValueError(f"Invalid method '{method}'. Choose 'nn' instead.")
|
| 23 |
|
| 24 |
X, y = random_dataset()
|
| 25 |
-
args = nn.get_args()
|
| 26 |
nn.main(
|
| 27 |
X=X,
|
| 28 |
y=y,
|
| 29 |
-
epochs=args["epochs"],
|
| 30 |
-
hidden_size=args["hidden_size"],
|
| 31 |
-
learning_rate=args["learning_rate"],
|
| 32 |
-
activation_func=args["activation_func"],
|
| 33 |
)
|
|
|
|
| 22 |
raise ValueError(f"Invalid method '{method}'. Choose 'nn' instead.")
|
| 23 |
|
| 24 |
X, y = random_dataset()
|
|
|
|
| 25 |
nn.main(
|
| 26 |
X=X,
|
| 27 |
y=y,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
)
|
neural_network/main.py
CHANGED
|
@@ -4,7 +4,7 @@ from neural_network.forwardprop import fp
|
|
| 4 |
from neural_network.backprop import bp
|
| 5 |
|
| 6 |
|
| 7 |
-
def get_args():
|
| 8 |
"""
|
| 9 |
returns a dictionary containing
|
| 10 |
the arguments to be passed to
|
|
@@ -34,16 +34,14 @@ def init(X: np.array, y: np.array, hidden_size: int) -> dict:
|
|
| 34 |
def main(
|
| 35 |
X: np.array,
|
| 36 |
y: np.array,
|
| 37 |
-
epochs: int,
|
| 38 |
-
hidden_size: int,
|
| 39 |
-
learning_rate: float,
|
| 40 |
-
activation_func: str,
|
| 41 |
) -> None:
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
for e in range(epochs):
|
| 45 |
|
|
|
|
| 46 |
fp()
|
| 47 |
bp()
|
| 48 |
|
| 49 |
# update weights and biases
|
|
|
|
|
|
|
|
|
| 4 |
from neural_network.backprop import bp
|
| 5 |
|
| 6 |
|
| 7 |
+
def get_args() -> dict:
|
| 8 |
"""
|
| 9 |
returns a dictionary containing
|
| 10 |
the arguments to be passed to
|
|
|
|
| 34 |
def main(
|
| 35 |
X: np.array,
|
| 36 |
y: np.array,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
) -> None:
|
| 38 |
+
args = get_args()
|
| 39 |
+
wb = init(X, y, args["hidden_size"])
|
|
|
|
| 40 |
|
| 41 |
+
for e in range(args["epochs"]):
|
| 42 |
fp()
|
| 43 |
bp()
|
| 44 |
|
| 45 |
# update weights and biases
|
| 46 |
+
|
| 47 |
+
# print results
|