OneclickAI commited on
Commit
e2121f9
·
verified ·
1 Parent(s): 305769a
Files changed (1) hide show
  1. train.py +0 -60
train.py DELETED
@@ -1,60 +0,0 @@
1
- import numpy as np
2
- import tensorflow as tf
3
- from tensorflow import keras
4
- from keras import layers
5
-
6
- # Keras 라이브러리를 통해 MNIST 데이터셋을 손쉽게 불러옵니다.
7
- (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
8
-
9
- # 정규화: 픽셀 값의 범위를 0~255에서 0~1 사이로 조정하여 학습 안정성 및 속도를 높입니다.
10
- x_train = x_train.astype("float32") / 255.0
11
- x_test = x_test.astype("float32") / 255.0
12
-
13
- # 채널 차원 추가: 흑백 이미지(채널 1)의 차원을 명시적으로 추가합니다.
14
- x_train = np.expand_dims(x_train, -1)
15
- x_test = np.expand_dims(x_test, -1)
16
-
17
- # 레이블 원-핫 인코딩: 숫자 '5'를 [0,0,0,0,0,1,0,0,0,0] 형태의 벡터로 변환합니다.
18
- num_classes = 10
19
- y_train = keras.utils.to_categorical(y_train, num_classes)
20
- y_test = keras.utils.to_categorical(y_test, num_classes)
21
-
22
- model = keras.Sequential([
23
- keras.Input(shape=(28, 28, 1)), # 입력 레이어
24
- layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
25
- layers.MaxPooling2D(pool_size=(2, 2)),
26
- layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
27
- layers.MaxPooling2D(pool_size=(2, 2)),
28
- layers.Flatten(),
29
- layers.Dropout(0.5),
30
- layers.Dense(num_classes, activation="softmax")
31
- ])
32
-
33
- model.compile(
34
- # 손실 함수(Loss Function): 모델의 예측이 정답과 얼마나 다른지 측정합니다.
35
- loss="categorical_crossentropy",
36
- # 옵티마이저(Optimizer): 손실을 최소화하기 위해 모델의 가중치를 업데이트하는 방법입니다.
37
- optimizer="adam",
38
- # 평가지표(Metrics): 훈련 과정을 모니터링할 지표로, 정확도를 사용합니다.
39
- metrics=["accuracy"]
40
- )
41
-
42
- batch_size = 128
43
- epochs = 15
44
-
45
- # 모델 학습 실행
46
- history = model.fit(
47
- x_train, y_train,
48
- batch_size=batch_size,
49
- epochs=epochs,
50
- validation_data=(x_test, y_test)
51
- )
52
-
53
- # 학습 완료 후 최종 성능 평가
54
- score = model.evaluate(x_test, y_test, verbose=0)
55
- print(f"\nTest loss: {score[0]:.4f}")
56
- print(f"Test accuracy: {score[1]:.4f}")
57
-
58
- # 모델의 구조, 가중치, 학습 설정을 모두 '.keras' 파일 하나에 저장합니다.
59
- model.save("my_keras_model.keras")
60
- print("\nModel saved to my_keras_model.keras")