import os
import json

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.layers import BatchNormalization, Dense, Dropout, GlobalAveragePooling2D, Rescaling
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.preprocessing import image_dataset_from_directory

DATA_DIR = "data/train"
MODEL_SAVE_PATH = "src/model/dog_breed_classifier.h5"
CLASS_NAMES_PATH = "src/model/class_names.json"
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
SEED = 42

print("[INFO] Loading dataset...") |
|
|
train_ds = image_dataset_from_directory( |
|
|
DATA_DIR, |
|
|
validation_split=0.2, |
|
|
subset="training", |
|
|
seed=SEED, |
|
|
image_size=IMG_SIZE, |
|
|
batch_size=BATCH_SIZE |
|
|
) |
|
|
|
|
|
val_ds = image_dataset_from_directory( |
|
|
DATA_DIR, |
|
|
validation_split=0.2, |
|
|
subset="validation", |
|
|
seed=SEED, |
|
|
image_size=IMG_SIZE, |
|
|
batch_size=BATCH_SIZE |
|
|
) |
|
|
|
|
|
|
|
|
# Capture class names before any .map() calls: mapped datasets no longer
# carry the class_names attribute.
class_names = train_ds.class_names
num_classes = len(class_names)
print(f"[INFO] Classes found: {num_classes}")

# Ensure the output directory exists before writing.
os.makedirs(os.path.dirname(CLASS_NAMES_PATH), exist_ok=True)
with open(CLASS_NAMES_PATH, "w") as f:
    json.dump(class_names, f)

# MobileNetV2's ImageNet weights expect inputs in [-1, 1], so rescale with
# 1/127.5 and an offset of -1 rather than a plain 1/255. Resizing already
# happens inside image_dataset_from_directory.
rescale = Rescaling(1.0 / 127.5, offset=-1)

data_augmentation = Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.15),
    tf.keras.layers.RandomZoom(0.1),
])

AUTOTUNE = tf.data.AUTOTUNE

# Cache before augmenting: caching after the augmentation map would store
# the first epoch's random transforms and replay them every epoch,
# defeating the augmentation.
train_ds = train_ds.map(lambda x, y: (rescale(x), y), num_parallel_calls=AUTOTUNE)
train_ds = train_ds.cache()
train_ds = train_ds.map(
    lambda x, y: (data_augmentation(x, training=True), y),
    num_parallel_calls=AUTOTUNE,
)
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)

val_ds = val_ds.map(lambda x, y: (rescale(x), y), num_parallel_calls=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

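# To eyeball what the pipeline feeds the network, a quick preview like the
# commented sketch below can help (it assumes the [-1, 1] rescale above and
# is purely optional, not needed for training):
# for images, labels in train_ds.take(1):
#     plt.imshow((images[0].numpy() + 1.0) / 2.0)  # map [-1, 1] back to [0, 1]
#     plt.title(class_names[int(labels[0])])
#     plt.show()
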
print("[INFO] Computing class weights...") |
|
|
y_train = np.concatenate([y.numpy() for _, y in train_ds], axis=0) |
|
|
class_counts = np.bincount(y_train) |
|
|
total = len(y_train) |
|
|
class_weights = {i: total / (num_classes * count) for i, count in enumerate(class_counts)} |
|
|
print("[INFO] Class weights applied.") |
|
|
|
|
|
|
|
|
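# scikit-learn offers an equivalent "balanced" weighting; a sketch, assuming
# scikit-learn is installed and every class appears in y_train (this script
# does not depend on it):
# from sklearn.utils.class_weight import compute_class_weight
# weights = compute_class_weight("balanced", classes=np.arange(num_classes), y=y_train)
# class_weights = dict(enumerate(weights))
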
print("[INFO] Building model...") |
|
|
base_model = MobileNetV2(input_shape=IMG_SIZE + (3,), include_top=False, weights='imagenet') |
|
|
base_model.trainable = False |
|
|
|
|
|
x = base_model.output |
|
|
x = GlobalAveragePooling2D()(x) |
|
|
x = BatchNormalization()(x) |
|
|
x = Dropout(0.4)(x) |
|
|
output = Dense(num_classes, activation='softmax')(x) |
|
|
|
|
|
model = Model(inputs=base_model.input, outputs=output) |
|
|
model.compile( |
|
|
optimizer=Adam(learning_rate=1e-4), |
|
|
loss='sparse_categorical_crossentropy', |
|
|
metrics=['accuracy'] |
|
|
) |
|
|
|
|
|
model.summary() |
|
|
|
|
|
|
|
|
# Keep only the best weights by validation loss; stop early and lower the
# learning rate when validation loss plateaus.
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
checkpoint = ModelCheckpoint(MODEL_SAVE_PATH, monitor='val_loss', save_best_only=True, verbose=1)
earlystop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.3, patience=2, verbose=1)

print("[INFO] Training model (frozen base)...") |
|
|
history = model.fit( |
|
|
train_ds, |
|
|
validation_data=val_ds, |
|
|
epochs=15, |
|
|
class_weight=class_weights, |
|
|
callbacks=[checkpoint, earlystop, reduce_lr] |
|
|
) |
|
|
|
|
|
|
|
|
print("[INFO] Fine-tuning entire model...") |
|
|
base_model.trainable = True |
|
|
|
|
|
model.compile( |
|
|
optimizer=SGD(learning_rate=1e-4, momentum=0.9), |
|
|
loss='sparse_categorical_crossentropy', |
|
|
metrics=['accuracy'] |
|
|
) |
|
|
|
|
|
fine_tune_epochs = 10 |
|
|
total_epochs = len(history.history["loss"]) + fine_tune_epochs |
|
|
|
|
|
fine_tune_history = model.fit( |
|
|
train_ds, |
|
|
validation_data=val_ds, |
|
|
epochs=total_epochs, |
|
|
initial_epoch=history.epoch[-1] + 1, |
|
|
class_weight=class_weights, |
|
|
callbacks=[checkpoint, earlystop, reduce_lr] |
|
|
) |
|
|
|
|
|
|
|
|
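# A common alternative (not used here) is to unfreeze only the top of the
# base network so the generic low-level filters stay intact, e.g.:
# base_model.trainable = True
# for layer in base_model.layers[:100]:  # 100 is an illustrative cutoff
#     layer.trainable = False
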
# Stitch the two training phases into one history for plotting.
for key in fine_tune_history.history:
    history.history[key] += fine_tune_history.history[key]

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.title("Loss")
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Train Acc')
plt.plot(history.history['val_accuracy'], label='Val Acc')
plt.title("Accuracy")
plt.legend()

plt.savefig("training_curves.png")
plt.show()

print(f"[DONE] Model saved to {MODEL_SAVE_PATH}")