import os
import json

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.applications import MobileNetV2  # lightweight & effective for small datasets
from tensorflow.keras.layers import (
    Dense,
    GlobalAveragePooling2D,
    Dropout,
    BatchNormalization,
    Rescaling,
)
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam, SGD

# === Paths & hyperparameters ===
DATA_DIR = "data/train"
MODEL_SAVE_PATH = "src/model/dog_breed_classifier.h5"
CLASS_NAMES_PATH = "src/model/class_names.json"
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
SEED = 42

# Create the output directory up front so both the class-name file and the
# model checkpoint can be written.
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)

# === Load dataset ===
print("[INFO] Loading dataset...")
train_ds = image_dataset_from_directory(
    DATA_DIR,
    validation_split=0.2,
    subset="training",
    seed=SEED,
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
)
val_ds = image_dataset_from_directory(
    DATA_DIR,
    validation_split=0.2,
    subset="validation",
    seed=SEED,
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
)

# Save class names for inference
class_names = train_ds.class_names
num_classes = len(class_names)
print(f"[INFO] Classes found: {num_classes}")
with open(CLASS_NAMES_PATH, "w") as f:
    json.dump(class_names, f)

# === Compute class weights (to handle class imbalance) ===
# Done before the preprocessing maps are attached, so labels are only iterated once.
print("[INFO] Computing class weights...")
y_train = np.concatenate([y.numpy() for _, y in train_ds], axis=0)
class_counts = np.bincount(y_train, minlength=num_classes)
total = len(y_train)
class_weights = {
    i: total / (num_classes * max(int(count), 1))  # guard against empty classes
    for i, count in enumerate(class_counts)
}
print("[INFO] Class weights computed.")

# === Data preprocessing & augmentation ===
# image_dataset_from_directory already resizes to IMG_SIZE; MobileNetV2's
# ImageNet weights expect inputs in [-1, 1], so rescale accordingly.
rescale = Rescaling(1.0 / 127.5, offset=-1)

data_augmentation = Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.15),
    tf.keras.layers.RandomZoom(0.1),
])

AUTOTUNE = tf.data.AUTOTUNE

# Cache *before* augmenting so each epoch sees freshly augmented images;
# caching after augmentation would freeze a single set of augmentations.
train_ds = (
    train_ds
    .map(lambda x, y: (rescale(x), y), num_parallel_calls=AUTOTUNE)
    .cache()
    .map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)
    .prefetch(buffer_size=AUTOTUNE)
)
val_ds = (
    val_ds
    .map(lambda x, y: (rescale(x), y), num_parallel_calls=AUTOTUNE)
    .cache()
    .prefetch(buffer_size=AUTOTUNE)
)

# === Build model ===
print("[INFO] Building model...")
base_model = MobileNetV2(input_shape=IMG_SIZE + (3,), include_top=False, weights="imagenet")
base_model.trainable = False  # freeze the backbone for the first phase

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = BatchNormalization()(x)
x = Dropout(0.4)(x)
output = Dense(num_classes, activation="softmax")(x)
model = Model(inputs=base_model.input, outputs=output)

model.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model.summary()

# === Callbacks ===
checkpoint = ModelCheckpoint(MODEL_SAVE_PATH, monitor="val_loss", save_best_only=True, verbose=1)
earlystop = EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor="val_loss", factor=0.3, patience=2, verbose=1)

# === Phase 1: Train head ===
print("[INFO] Training model (frozen base)...")
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=15,
    class_weight=class_weights,
    callbacks=[checkpoint, earlystop, reduce_lr],
)
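# --- Optional variant (a sketch, not used below): instead of unfreezing the whole
# backbone in Phase 2, unfreeze only its top layers. The cutoff of 30 layers is an
# arbitrary assumption to illustrate the idea; tune it for your dataset.
#
#   base_model.trainable = True
#   for layer in base_model.layers[:-30]:
#       layer.trainable = False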
# === Phase 2: Fine-tune full model ===
print("[INFO] Fine-tuning entire model...")
base_model.trainable = True
# Keep BatchNormalization layers frozen while fine-tuning; in TF2, setting
# trainable=False also runs them in inference mode, which protects the
# pretrained statistics from being destroyed by small-batch updates.
for layer in base_model.layers:
    if isinstance(layer, BatchNormalization):
        layer.trainable = False

model.compile(
    optimizer=SGD(learning_rate=1e-4, momentum=0.9),  # low LR to avoid wrecking pretrained weights
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

fine_tune_epochs = 10
total_epochs = len(history.history["loss"]) + fine_tune_epochs
fine_tune_history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=total_epochs,
    initial_epoch=history.epoch[-1] + 1,
    class_weight=class_weights,
    callbacks=[checkpoint, earlystop, reduce_lr],
)

# === Merge histories ===
for key in fine_tune_history.history:
    history.history[key] += fine_tune_history.history[key]

# === Plot training results ===
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history["loss"], label="Train Loss")
plt.plot(history.history["val_loss"], label="Val Loss")
plt.title("Loss")
plt.xlabel("Epoch")
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(history.history["accuracy"], label="Train Acc")
plt.plot(history.history["val_accuracy"], label="Val Acc")
plt.title("Accuracy")
plt.xlabel("Epoch")
plt.legend()
plt.tight_layout()
plt.savefig("training_curves.png")
plt.show()

print(f"[DONE] Model saved to {MODEL_SAVE_PATH}")
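# --- Quick sanity check (a minimal sketch): reload the best checkpoint written by
# ModelCheckpoint above and score it on the validation split. Assumes the .h5 file
# was saved at least once during training.
best_model = tf.keras.models.load_model(MODEL_SAVE_PATH)
val_loss, val_acc = best_model.evaluate(val_ds, verbose=0)
print(f"[INFO] Best checkpoint -- val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}")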