| import os
|
| import json
|
| import tensorflow as tf
|
| from tensorflow.keras import layers, models
|
|
|
|
|
# Input-pipeline and training hyper-parameters: square image side length
# in pixels, images per batch, and number of passes over the training set.
IMG_SIZE, BATCH_SIZE, EPOCHS = 224, 32, 10

# Dataset root passed to image_dataset_from_directory. It is a relative
# path, so it is resolved against the current working directory.
DATA_DIR = "dataset_scripts/dataset/animals/animals"

# Output artifacts: the saved Keras model and the JSON list of class names,
# both written under MODEL_DIR.
MODEL_DIR = "model"
MODEL_PATH, NAMES_PATH = (
    os.path.join(MODEL_DIR, file_name)
    for file_name in ("animal_cnn.keras", "class_names.json")
)
|
|
|
|
|
def load_data():
    """Load the image dataset from ``DATA_DIR`` and split it 80/20.

    Images are resized to ``IMG_SIZE`` x ``IMG_SIZE`` and batched by
    ``BATCH_SIZE``. Labels are integer class indices (the
    ``image_dataset_from_directory`` default).

    Returns:
        tuple: ``(train_ds, val_ds, class_names, num_classes)``.
            - ``train_ds`` and ``val_ds`` yield ``(images, labels)`` batches
              with pixel values rescaled from [0, 255] to [0, 1].
            - ``class_names`` is the list of class names in label-index
              order.
            - ``num_classes`` is ``len(class_names)``.
    """
    # Both calls must use identical split/seed settings so the "training"
    # and "validation" subsets are complementary and never overlap.
    split_kwargs = {
        "validation_split": 0.2,
        "seed": 42,
        "image_size": (IMG_SIZE, IMG_SIZE),
        "batch_size": BATCH_SIZE,
    }
    train_ds = tf.keras.utils.image_dataset_from_directory(
        DATA_DIR, subset="training", **split_kwargs
    )
    val_ds = tf.keras.utils.image_dataset_from_directory(
        DATA_DIR, subset="validation", **split_kwargs
    )

    # Read the class names before .map(): the mapped dataset no longer
    # exposes the ``class_names`` attribute.
    class_names = train_ds.class_names
    num_classes = len(class_names)

    normalization = layers.Rescaling(1.0 / 255)
    autotune = tf.data.AUTOTUNE

    def _prepare(ds):
        # Rescale in parallel (order is still deterministic by default) and
        # prefetch so input preprocessing overlaps with model training.
        return ds.map(
            lambda x, y: (normalization(x), y),
            num_parallel_calls=autotune,
        ).prefetch(autotune)

    return _prepare(train_ds), _prepare(val_ds), class_names, num_classes
|
|
|
|
|
|
|
def build_model(num_classes):
    """Build a MobileNetV2 transfer-learning classifier.

    The ImageNet-pretrained MobileNetV2 backbone is frozen. Only the new
    classification head (pooling -> Dense(128) -> Dropout -> softmax) is
    trainable.

    Args:
        num_classes: Number of output classes (width of the softmax layer).

    Returns:
        An uncompiled ``Sequential`` model.
            - Input: ``(IMG_SIZE, IMG_SIZE, 3)`` images with pixel values in
              [0, 1], the scale produced by ``Rescaling(1/255)`` in
              ``load_data``.
            - Output: per-class probabilities.
    """
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
        include_top=False,
        weights="imagenet",
    )
    # Freeze the pretrained backbone so only the head's weights are updated.
    base_model.trainable = False

    model = models.Sequential([
        layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3)),
        # MobileNetV2's ImageNet weights expect inputs in [-1, 1] (what
        # mobilenet_v2.preprocess_input produces). Map [0, 1] -> [-1, 1]
        # inside the model so callers, and the saved model, keep accepting
        # [0, 1] images.
        layers.Rescaling(2.0, offset=-1.0),
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dense(128, activation="relu"),
        layers.Dropout(0.3),
        layers.Dense(num_classes, activation="softmax"),
    ])
    return model
|
|
|
|
|
|
|
def train():
    """Train the animal classifier and save the model and its class names.

    Steps:
        1. Load the train/validation datasets and build the model.
        2. Compile with Adam and sparse categorical cross-entropy, then fit
           for ``EPOCHS`` epochs.
        3. Save the model to ``MODEL_PATH``.
        4. Write the class-name list (in label-index order, as returned by
           ``load_data``) as JSON to ``NAMES_PATH``, so predicted indices
           can be mapped back to names.
    """
    train_ds, val_ds, class_names, num_classes = load_data()
    print(f"✅ Classes: {num_classes}")

    model = build_model(num_classes)
    model.compile(
        optimizer="adam",
        # Labels are integer class indices, hence the "sparse" variant.
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=EPOCHS,
    )

    os.makedirs(MODEL_DIR, exist_ok=True)
    model.save(MODEL_PATH)

    # Explicit encoding so the file does not depend on the platform locale.
    with open(NAMES_PATH, "w", encoding="utf-8") as f:
        json.dump(class_names, f)

    print("✅ Training complete!")
    print(f"📁 Model saved: {MODEL_PATH}")
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, so that
    # importing this module (e.g. to reuse build_model) has no side effects.
    train()