File size: 2,934 Bytes
31e9cbe | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 | import os
import json
import tensorflow as tf
from tensorflow.keras import layers, models
# βββ CONFIG βββββββββββββββββββββββββββββ
IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 10
# β
FIXED PATH
DATA_DIR = "dataset_scripts/dataset/animals/animals"
MODEL_DIR = "model"
MODEL_PATH = os.path.join(MODEL_DIR, "animal_cnn.keras")
NAMES_PATH = os.path.join(MODEL_DIR, "class_names.json")
# βββ LOAD DATA ββββββββββββββββββββββββββ
def load_data():
train_ds = tf.keras.utils.image_dataset_from_directory(
DATA_DIR,
validation_split=0.2,
subset="training",
seed=42,
image_size=(IMG_SIZE, IMG_SIZE),
batch_size=BATCH_SIZE
)
val_ds = tf.keras.utils.image_dataset_from_directory(
DATA_DIR,
validation_split=0.2,
subset="validation",
seed=42,
image_size=(IMG_SIZE, IMG_SIZE),
batch_size=BATCH_SIZE
)
class_names = train_ds.class_names
num_classes = len(class_names)
# Normalize
normalization = layers.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (normalization(x), y))
val_ds = val_ds.map(lambda x, y: (normalization(x), y))
return train_ds, val_ds, class_names, num_classes
# βββ MODEL βββββββββββββββββββββββββββββ
def build_model(num_classes):
base_model = tf.keras.applications.MobileNetV2(
input_shape=(IMG_SIZE, IMG_SIZE, 3),
include_top=False,
weights="imagenet"
)
base_model.trainable = False # freeze
model = models.Sequential([
base_model,
layers.GlobalAveragePooling2D(),
layers.Dense(128, activation="relu"),
layers.Dropout(0.3),
layers.Dense(num_classes, activation="softmax")
])
return model
# βββ TRAIN βββββββββββββββββββββββββββββ
def train():
train_ds, val_ds, class_names, num_classes = load_data()
print(f"β
Classes: {num_classes}")
model = build_model(num_classes)
model.compile(
optimizer="adam",
loss="sparse_categorical_crossentropy",
metrics=["accuracy"]
)
model.fit(
train_ds,
validation_data=val_ds,
epochs=EPOCHS
)
# Save model
os.makedirs(MODEL_DIR, exist_ok=True)
model.save(MODEL_PATH)
# Save class names
with open(NAMES_PATH, "w") as f:
json.dump(class_names, f)
print("β
Training complete!")
print(f"π Model saved: {MODEL_PATH}")
# βββ RUN βββββββββββββββββββββββββββββββ
if __name__ == "__main__":
train() |