Archana-g-123's picture
Upload 8 files
6ff79bd verified
import os
import zipfile
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import matplotlib.pyplot as plt
# =======================================
# 1️⃣ Paths
# =======================================
# data/train and data/val each hold two archives: fresh_f* and non_fresh_f*.
base_dir = "data"
train_dir, val_dir = (os.path.join(base_dir, split) for split in ("train", "val"))

train_zip_fresh, train_zip_nonfresh = (
    os.path.join(train_dir, archive)
    for archive in ("fresh_f.zip", "non_fresh_f.zip")
)
val_zip_fresh, val_zip_nonfresh = (
    os.path.join(val_dir, archive)
    for archive in ("fresh_f_val.zip", "non_fresh_f_val.zip")
)
# =======================================
# 2️⃣ Unzip Function
# =======================================
def unzip_if_needed(zip_path, extract_to):
    """Extract ``zip_path`` into ``extract_to`` unless it looks already extracted.

    The archive counts as "already extracted" when a folder named after it
    (its file name without the extension) exists inside ``extract_to``. This
    assumes the archive's contents live under a top-level folder of that name.

    Entries under a top-level ``__MACOSX/`` folder (metadata added by
    macOS-created archives) are skipped. ``flow_from_directory`` turns every
    sub-directory of the split folder into a class, so extracting them would
    silently add an extra class.

    Args:
        zip_path: Path to the ``.zip`` archive.
        extract_to: Directory the archive members are extracted into.

    A missing archive only prints a warning; nothing is raised.
    """
    if not os.path.exists(zip_path):
        print(f"⚠️ Zip file not found: {zip_path}")
        return

    folder_name = os.path.splitext(os.path.basename(zip_path))[0]
    extract_folder = os.path.join(extract_to, folder_name)
    if os.path.exists(extract_folder):
        print(f"✅ Already extracted: {extract_folder}")
        return

    print(f"🔹 Extracting {zip_path} ...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        # Zip member names always use "/" separators, regardless of platform.
        members = [
            name for name in zip_ref.namelist()
            if not name.startswith("__MACOSX/")
        ]
        zip_ref.extractall(extract_to, members=members)
    print(f"✅ Extracted to: {extract_to}")
# Unzip training & validation sets: each archive is extracted into the
# split directory that contains it.
for _zip_path, _split_dir in (
    (train_zip_fresh, train_dir),
    (train_zip_nonfresh, train_dir),
    (val_zip_fresh, val_dir),
    (val_zip_nonfresh, val_dir),
):
    unzip_if_needed(_zip_path, _split_dir)
# =======================================
# 3️⃣ Data Loading with Augmentation
# =======================================
IMG_SIZE = (224, 224)
BATCH_SIZE = 32

# Both generators multiply pixel values by 1/255; only the training generator
# additionally applies random rotation/shift/zoom/flip/brightness augmentation.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=25,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    brightness_range=[0.8, 1.2],
)
val_datagen = ImageDataGenerator(rescale=1./255)

# Settings shared by both splits: each sub-folder of a split directory is one
# class, and class_mode='binary' yields a single 0/1 label per image.
flow_settings = dict(
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
)
train_generator = train_datagen.flow_from_directory(train_dir, **flow_settings)
val_generator = val_datagen.flow_from_directory(val_dir, **flow_settings)
print("✅ Data loaded successfully!")
# =======================================
# 4️⃣ Handle Class Imbalance
# =======================================
# 'balanced' gives each class n_samples / (n_classes * class_count), so the
# rarer class weighs more in the loss. model.fit wants {class_index: weight};
# the enumerate position matches the label because generator labels are
# indices into the sorted class list.
train_labels = train_generator.classes
balanced_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_labels),
    y=train_labels,
)
class_weights = {index: weight for index, weight in enumerate(balanced_weights)}
print("📊 Class Weights:", class_weights)
# =======================================
# 5️⃣ Define Model (Transfer Learning)
# =======================================
# ImageNet-pretrained MobileNetV2 without its classifier, used frozen as a
# feature extractor; only the layers stacked on top are trained.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SIZE + (3,),
    include_top=False,
    weights='imagenet'
)
base_model.trainable = False

# MobileNetV2's ImageNet weights expect pixels in [-1, 1] (the range produced
# by mobilenet_v2.preprocess_input), but the data generators rescale images by
# 1/255. The Rescaling layer maps x -> 2x - 1 inside the model. Because the
# step lives in the model, the saved .h5/.tflite models keep accepting
# 1/255-scaled input. The explicit Input lets summary() run before fit().
model = models.Sequential([
    tf.keras.Input(shape=IMG_SIZE + (3,)),
    layers.Rescaling(2.0, offset=-1.0),
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.3),
    # Single sigmoid unit, matching class_mode='binary' labels.
    layers.Dense(1, activation='sigmoid')
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss='binary_crossentropy',
    metrics=['accuracy']
)
model.summary()
# =======================================
# 6️⃣ Train Model
# =======================================
# base_model is frozen, so only the head's weights are updated; class_weights
# scales each sample's loss by the weight of its class.
EPOCHS = 20
history = model.fit(
    x=train_generator,
    epochs=EPOCHS,
    validation_data=val_generator,
    class_weight=class_weights,
)
# =======================================
# 7️⃣ Save Model
# =======================================
os.makedirs("models", exist_ok=True)

# Full Keras model in HDF5 format.
h5_path = "models/fish_freshness_mobilenetv2.h5"
model.save(h5_path)
print(f"✅ Model saved successfully at: {h5_path}")

# Convert to TFLite (for mobile); convert() returns the flatbuffer as bytes.
tflite_path = "models/fish_freshness_mobilenetv2.tflite"
tflite_converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = tflite_converter.convert()
with open(tflite_path, "wb") as tflite_file:
    tflite_file.write(tflite_model)
print(f"📱 TFLite model saved at: {tflite_path}")
# =======================================
# 8️⃣ Plot Training History
# =======================================
# One curve per (history key, legend label) pair, one point per epoch.
plt.figure(figsize=(8, 5))
accuracy_curves = (
    ('accuracy', 'Train Accuracy'),
    ('val_accuracy', 'Validation Accuracy'),
)
for metric_key, legend_label in accuracy_curves:
    plt.plot(history.history[metric_key], label=legend_label)
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Training vs Validation Accuracy')
plt.legend()
plt.show()