# Uploaded by mariajessington in commit 31e9cbe ("Upload 5 files").
import os
import json
import tensorflow as tf
from tensorflow.keras import layers, models
# ─── CONFIG ─────────────────────────────
# Side length (pixels) every image is resized to before entering the network.
IMG_SIZE: int = 224
# Number of images per training / validation batch.
BATCH_SIZE: int = 32
# Number of full passes over the training split.
EPOCHS: int = 10

# Dataset root, relative to the current working directory. Each
# sub-directory is read as one class by image_dataset_from_directory.
DATA_DIR: str = "dataset_scripts/dataset/animals/animals"

# Where training artifacts are written: the Keras model file and the
# JSON list of class names (index i <-> model output i).
MODEL_DIR: str = "model"
MODEL_PATH: str = os.path.join(MODEL_DIR, "animal_cnn.keras")
NAMES_PATH: str = os.path.join(MODEL_DIR, "class_names.json")
# ─── LOAD DATA ──────────────────────────
def load_data(data_dir=None, img_size=None, batch_size=None,
              validation_split=0.2, seed=42):
    """Build the training and validation datasets from an image directory.

    Both subsets are read with the same ``validation_split`` and ``seed`` so
    ``image_dataset_from_directory`` partitions the files consistently and
    the training and validation subsets do not overlap.

    Args:
        data_dir: Dataset root with one sub-directory per class.
            Defaults to the module-level DATA_DIR.
        img_size: Side length images are resized to. Defaults to IMG_SIZE.
        batch_size: Images per batch. Defaults to BATCH_SIZE.
        validation_split: Fraction of the files held out for validation.
        seed: Seed shared by both subsets so the split is consistent.

    Returns:
        ``(train_ds, val_ds, class_names, num_classes)``. Both datasets yield
        ``(images, labels)`` batches with pixel values multiplied by 1/255.
        ``class_names`` is the list returned by the loader, whose order
        matches the integer labels.
    """
    # Resolve defaults at call time so later changes to the module
    # constants are still honoured, as in the original zero-argument form.
    if data_dir is None:
        data_dir = DATA_DIR
    if img_size is None:
        img_size = IMG_SIZE
    if batch_size is None:
        batch_size = BATCH_SIZE

    loader_kwargs = dict(
        validation_split=validation_split,
        seed=seed,
        image_size=(img_size, img_size),
        batch_size=batch_size,
    )
    train_ds = tf.keras.utils.image_dataset_from_directory(
        data_dir, subset="training", **loader_kwargs
    )
    val_ds = tf.keras.utils.image_dataset_from_directory(
        data_dir, subset="validation", **loader_kwargs
    )

    # class_names is an attribute of the dataset returned by the loader;
    # the new datasets produced by .map() below do not carry it.
    class_names = train_ds.class_names
    num_classes = len(class_names)

    # NOTE(review): MobileNetV2's ImageNet weights expect inputs in [-1, 1]
    # (tf.keras.applications.mobilenet_v2.preprocess_input). The [0, 1]
    # scaling is kept because inference code elsewhere presumably applies
    # the same 1/255 scaling — confirm, and change both sides together.
    normalization = layers.Rescaling(1.0 / 255)

    def _normalize(images, labels):
        return normalization(images), labels

    # Parallel map + prefetch let preprocessing overlap training. map() is
    # deterministic by default, so element order is unchanged.
    autotune = tf.data.AUTOTUNE
    train_ds = train_ds.map(_normalize, num_parallel_calls=autotune).prefetch(autotune)
    val_ds = val_ds.map(_normalize, num_parallel_calls=autotune).prefetch(autotune)
    return train_ds, val_ds, class_names, num_classes
# ─── MODEL ─────────────────────────────
def build_model(num_classes, dense_units=128, dropout_rate=0.3):
    """Create a MobileNetV2 transfer-learning classifier.

    An ImageNet-pretrained MobileNetV2 backbone (without its top) is frozen.
    Only the new head (global average pooling -> ReLU dense -> dropout ->
    softmax) has trainable weights.

    Args:
        num_classes: Number of output classes (width of the softmax layer).
        dense_units: Width of the hidden ReLU layer in the head.
        dropout_rate: Dropout rate applied after the hidden layer.

    Returns:
        An uncompiled ``Sequential`` model taking ``(IMG_SIZE, IMG_SIZE, 3)``
        images and returning per-class probabilities.

    Raises:
        ValueError: if ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    # weights="imagenet" loads pretrained weights (Keras downloads them if
    # they are not already cached locally).
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
        include_top=False,
        weights="imagenet",
    )
    # Freeze the backbone so fit() only updates the head's weights.
    base_model.trainable = False

    model = models.Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dense(dense_units, activation="relu"),
        layers.Dropout(dropout_rate),
        layers.Dense(num_classes, activation="softmax"),
    ])
    return model
# ─── TRAIN ─────────────────────────────
def train(epochs=None):
    """Train the classifier and save the model and its class names.

    Loads the data with ``load_data()`` and builds the network with
    ``build_model()``. After training, it writes the model to MODEL_PATH and
    the class-name list to NAMES_PATH, creating MODEL_DIR if needed.

    Args:
        epochs: Number of training epochs. Defaults to the module-level
            EPOCHS.

    Returns:
        The trained Keras model. Previously nothing was returned, so
        existing callers are unaffected.
    """
    if epochs is None:
        epochs = EPOCHS

    train_ds, val_ds, class_names, num_classes = load_data()
    print(f"✅ Classes: {num_classes}")

    model = build_model(num_classes)
    # The loader's labels are integer class indices, which is what the
    # sparse variant of categorical cross-entropy expects.
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
    )

    # Save model
    os.makedirs(MODEL_DIR, exist_ok=True)
    model.save(MODEL_PATH)

    # Save class names in label order, so index i of the JSON list names
    # model output i. Explicit UTF-8 avoids depending on the platform's
    # default encoding.
    with open(NAMES_PATH, "w", encoding="utf-8") as f:
        json.dump(class_names, f)

    print("✅ Training complete!")
    print(f"📁 Model saved: {MODEL_PATH}")
    return model
# ─── RUN ───────────────────────────────
# Start training only when this file is executed as a script; importing the
# module (e.g. to reuse load_data or build_model) does not trigger training.
if __name__ == "__main__":
    train()