"""Train a MobileNetV2 transfer-learning binary classifier and export it.

Pipeline: unpack the zipped image folders under ``data/train`` and
``data/val``, stream them through Keras image generators, compute balanced
class weights, train a frozen-backbone MobileNetV2 with a small dense head,
save the result as ``.h5`` and ``.tflite`` under ``models/``, and plot the
accuracy curves.
"""

import os
import zipfile

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import matplotlib.pyplot as plt

# =======================================
# 1️⃣ Paths
# =======================================
base_dir = "data"
train_dir = os.path.join(base_dir, "train")
val_dir = os.path.join(base_dir, "val")

train_zip_fresh = os.path.join(train_dir, "fresh_f.zip")
train_zip_nonfresh = os.path.join(train_dir, "non_fresh_f.zip")
val_zip_fresh = os.path.join(val_dir, "fresh_f_val.zip")
val_zip_nonfresh = os.path.join(val_dir, "non_fresh_f_val.zip")


# =======================================
# 2️⃣ Unzip Function
# =======================================
def unzip_if_needed(zip_path, extract_to):
    """Extract ``zip_path`` into ``extract_to`` unless already extracted.

    "Already extracted" means ``<extract_to>/<zip name without extension>``
    exists. The archive is unpacked directly into ``extract_to``, so that
    check only short-circuits later runs when the archive itself contains a
    top-level entry with that name.
    NOTE(review): presumably each archive wraps its images in such a folder
    (``flow_from_directory`` needs one sub-folder per class) - confirm.

    A missing archive is reported and skipped, not treated as an error.

    Args:
        zip_path: Path to the ``.zip`` archive.
        extract_to: Directory the archive is unpacked into.

    Returns:
        None.
    """
    if not os.path.exists(zip_path):
        print(f"⚠️ Zip file not found: {zip_path}")
        return

    folder_name = os.path.splitext(os.path.basename(zip_path))[0]
    extract_folder = os.path.join(extract_to, folder_name)
    if os.path.exists(extract_folder):
        print(f"✅ Already extracted: {extract_folder}")
        return

    print(f"🔹 Extracting {zip_path} ...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
    print(f"✅ Extracted to: {extract_to}")


# Unzip training & validation sets
unzip_if_needed(train_zip_fresh, train_dir)
unzip_if_needed(train_zip_nonfresh, train_dir)
unzip_if_needed(val_zip_fresh, val_dir)
unzip_if_needed(val_zip_nonfresh, val_dir)

# =======================================
# 3️⃣ Data Loading with Augmentation
# =======================================
IMG_SIZE = (224, 224)
BATCH_SIZE = 32

# Augmentation is applied to the training stream only; validation images are
# just rescaled so the reported metrics reflect unmodified inputs.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=25,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    brightness_range=[0.8, 1.2],
)
val_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
)
val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
)
print("✅ Data loaded successfully!")

# =======================================
# 4️⃣ Handle Class Imbalance
# =======================================
train_labels = train_generator.classes
present_classes = np.unique(train_labels)
balanced_weights = compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=train_labels,
)
# Key every weight by the label it was computed for. Enumerating the weight
# array instead would only be correct while the labels present are exactly
# 0..n-1, and would silently shift weights onto the wrong class otherwise.
class_weights = {
    int(label): float(weight)
    for label, weight in zip(present_classes, balanced_weights)
}
print("📊 Class Weights:", class_weights)

# =======================================
# 5️⃣ Define Model (Transfer Learning)
# =======================================
# NOTE(review): the Keras docs describe MobileNetV2's ImageNet weights as
# expecting inputs scaled to [-1, 1] (``mobilenet_v2.preprocess_input``),
# whereas the generators above rescale to [0, 1]. Left unchanged here because
# consumers of the exported model may already feed [0, 1] inputs - confirm
# and align the two if possible.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SIZE + (3,),
    include_top=False,
    weights='imagenet',
)
# Freeze the pretrained backbone; only the new head below is trained.
base_model.trainable = False

model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.3),
    # Single sigmoid unit to match class_mode='binary' and the BCE loss.
    layers.Dense(1, activation='sigmoid'),
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
model.summary()

# =======================================
# 6️⃣ Train Model
# =======================================
EPOCHS = 20

history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=EPOCHS,
    class_weight=class_weights,
)

# =======================================
# 7️⃣ Save Model
# =======================================
os.makedirs("models", exist_ok=True)
h5_path = "models/fish_freshness_mobilenetv2.h5"
model.save(h5_path)
print(f"✅ Model saved successfully at: {h5_path}")

# Convert to TFLite (for mobile)
tflite_converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = tflite_converter.convert()

tflite_path = "models/fish_freshness_mobilenetv2.tflite"
with open(tflite_path, "wb") as f:
    f.write(tflite_model)
print(f"📱 TFLite model saved at: {tflite_path}")

# =======================================
# 8️⃣ Plot Training History
# =======================================
plt.figure(figsize=(8, 5))
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Training vs Validation Accuracy')
plt.legend()
plt.show()