| | import os
|
| | import zipfile
|
| | import tensorflow as tf
|
| | from tensorflow.keras import layers, models
|
| | from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
| | from sklearn.utils.class_weight import compute_class_weight
|
| | import numpy as np
|
| | import matplotlib.pyplot as plt
|
| |
|
| |
|
| |
|
| |
|
# Dataset layout: <base_dir>/train and <base_dir>/val each hold the class
# archives listed below.
base_dir = "data"
train_dir, val_dir = (
    os.path.join(base_dir, split) for split in ("train", "val")
)

# Zipped image sets, one archive per class and per split.
train_zip_fresh = os.path.join(train_dir, "fresh_f.zip")
train_zip_nonfresh = os.path.join(train_dir, "non_fresh_f.zip")
val_zip_fresh = os.path.join(val_dir, "fresh_f_val.zip")
val_zip_nonfresh = os.path.join(val_dir, "non_fresh_f_val.zip")
|
| |
|
| |
|
| |
|
| |
|
def unzip_if_needed(zip_path, extract_to):
    """Extract ``zip_path`` under ``extract_to`` unless it was already extracted.

    The archive counts as extracted when ``<extract_to>/<zip stem>`` exists.
    So that this marker folder really appears after a (non-empty) extraction,
    the archive layout is inspected first:

    * if every member lives under a single top-level folder named after the
      zip stem, the archive is extracted directly into ``extract_to``;
    * otherwise (flat archive, or a differently named top-level folder) the
      content is extracted into ``<extract_to>/<zip stem>``, so files never
      spill into ``extract_to`` itself and later calls see the marker.

    Args:
        zip_path: Path to the ``.zip`` archive. A missing archive is reported
            with a warning and otherwise ignored.
        extract_to: Directory under which the archive content is placed.

    Returns:
        None. Progress is reported with ``print``.
    """
    if not os.path.exists(zip_path):
        print(f"⚠️ Zip file not found: {zip_path}")
        return

    folder_name = os.path.splitext(os.path.basename(zip_path))[0]
    extract_folder = os.path.join(extract_to, folder_name)
    if os.path.exists(extract_folder):
        print(f"✅ Already extracted: {extract_folder}")
        return

    print(f"🔹 Extracting {zip_path} ...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        # Zip member names always use "/" as separator, whatever the OS.
        top_level_entries = {
            member.split('/', 1)[0]
            for member in zip_ref.namelist()
            if member.strip('/')
        }
        # Only trust the archive to create the marker folder by itself when
        # its sole top-level entry is exactly that folder.
        if top_level_entries == {folder_name}:
            destination = extract_to
        else:
            destination = extract_folder
        zip_ref.extractall(destination)
    print(f"✅ Extracted to: {destination}")
|
| |
|
| |
|
# Unpack every class archive next to where it is stored (no-op for archives
# that are missing or already extracted).
for archive_path, destination_dir in (
    (train_zip_fresh, train_dir),
    (train_zip_nonfresh, train_dir),
    (val_zip_fresh, val_dir),
    (val_zip_nonfresh, val_dir),
):
    unzip_if_needed(archive_path, destination_dir)
|
| |
|
| |
|
| |
|
| |
|
# Input resolution fed to the network and number of images per batch.
IMG_SIZE = (224, 224)
BATCH_SIZE = 32

# Random augmentations applied to training images only.
augmentation_options = dict(
    rotation_range=25,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    brightness_range=[0.8, 1.2],
)

# NOTE(review): rescale=1/255 maps pixels to [0, 1], whereas
# tf.keras.applications.mobilenet_v2.preprocess_input maps them to [-1, 1].
# Confirm the [0, 1] range is intentional for the MobileNetV2 base used later
# in this script; inference code must apply the same scaling either way.
train_datagen = ImageDataGenerator(rescale=1./255, **augmentation_options)
val_datagen = ImageDataGenerator(rescale=1./255)

# Both splits are read the same way: images resized to IMG_SIZE, batched, and
# labelled 0/1 for binary classification.
flow_options = dict(
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
)
train_generator = train_datagen.flow_from_directory(train_dir, **flow_options)
val_generator = val_datagen.flow_from_directory(val_dir, **flow_options)

print("✅ Data loaded successfully!")
|
| |
|
| |
|
| |
|
| |
|
# Balanced per-class loss weights so the less frequent class is not drowned
# out during training.
train_labels = train_generator.classes
present_classes = np.unique(train_labels)
balanced_weights = compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=train_labels,
)
# compute_class_weight returns one weight per entry of `classes`, in that
# order. Key each weight by its actual label rather than by its position, so
# the mapping stays right even if the labels present are not exactly 0..k-1
# (for example an empty class folder). Plain int/float keep the dict and its
# printout free of NumPy scalar types.
class_weights = {
    int(label): float(weight)
    for label, weight in zip(present_classes, balanced_weights)
}
print("📊 Class Weights:", class_weights)
|
| |
|
| |
|
| |
|
| |
|
# Frozen MobileNetV2 backbone (ImageNet weights, classification head removed)
# used as a fixed feature extractor.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(*IMG_SIZE, 3),
    include_top=False,
    weights='imagenet',
)
base_model.trainable = False

# Small binary head on top of the backbone: pooled features, one hidden layer,
# dropout, and a single sigmoid output.
model = models.Sequential()
model.add(base_model)
model.add(layers.GlobalAveragePooling2D())
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dropout(0.3))
model.add(layers.Dense(1, activation='sigmoid'))

adam = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(optimizer=adam, loss='binary_crossentropy', metrics=['accuracy'])

model.summary()
|
| |
|
| |
|
| |
|
| |
|
# Train the classification head; `history` keeps the per-epoch metrics that
# are plotted at the end of the script.
EPOCHS = 20
fit_options = {
    "validation_data": val_generator,
    "epochs": EPOCHS,
    "class_weight": class_weights,
}
history = model.fit(train_generator, **fit_options)
|
| |
|
| |
|
| |
|
| |
|
# Persist the trained model twice under models/: as a Keras HDF5 file and as
# a TFLite flatbuffer.
os.makedirs("models", exist_ok=True)

h5_path = "models/fish_freshness_mobilenetv2.h5"
model.save(h5_path)
print(f"✅ Model saved successfully at: {h5_path}")

tflite_path = "models/fish_freshness_mobilenetv2.tflite"
tflite_converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = tflite_converter.convert()
with open(tflite_path, "wb") as tflite_file:
    tflite_file.write(tflite_model)

print(f"📱 TFLite model saved at: {tflite_path}")
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
# Plot training and validation accuracy per epoch on one figure.
plt.figure(figsize=(8, 5))
for history_key, curve_label in (
    ('accuracy', 'Train Accuracy'),
    ('val_accuracy', 'Validation Accuracy'),
):
    plt.plot(history.history[history_key], label=curve_label)
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Training vs Validation Accuracy')
plt.legend()
plt.show()
|
| |
|
| |
|
| |
|