Rose_classification / src /train_model.py
Terence9's picture
Upload 4 files
eac85ba verified
import os
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import tensorflow as tf
import json
def create_model_directories():
    """Ensure the ``models`` output directory exists under the current working directory.

    Safe to call repeatedly: an already-existing directory is not an error.
    """
    output_dir = 'models'
    os.makedirs(name=output_dir, exist_ok=True)
def train_model(*, base_dir=None, models_dir='models', epochs=10,
                batch_size=32, img_height=256, img_width=256):
    """Train a small CNN image classifier and write its artifacts to disk.

    Images are read from ``<base_dir>/train``, ``<base_dir>/validation`` and
    ``<base_dir>/test`` via ``flow_from_directory`` (one sub-folder per
    class). After training, these files are written into ``models_dir``:
    ``rose_model.h5``, ``class_names.json``, ``metrics.json`` (classification
    report on the test split), ``confusion_matrix.json`` and
    ``training_history.png``.

    Every keyword argument defaults to the value the script has always used,
    so ``train_model()`` with no arguments behaves as before.

    Args:
        base_dir: Dataset root. Defaults to ``dataset`` under the current
            working directory.
        models_dir: Output directory for artifacts; created if missing.
        epochs: Number of training epochs.
        batch_size: Batch size used for the train/validation/test generators.
        img_height: Height images are resized to.
        img_width: Width images are resized to.

    Raises:
        FileNotFoundError: If the train, validation or test directory is
            missing (or is not a directory).
        ValueError: If the training directory contains no usable images.
    """
    if base_dir is None:
        base_dir = os.path.join(os.getcwd(), 'dataset')
    train_dir = os.path.join(base_dir, 'train')
    validation_dir = os.path.join(base_dir, 'validation')
    test_dir = os.path.join(base_dir, 'test')

    # Fail fast, before any TensorFlow work or output-directory creation.
    for dir_path in (train_dir, validation_dir, test_dir):
        if not os.path.isdir(dir_path):
            raise FileNotFoundError(f"Directory not found: {dir_path}")

    # The saves below need this directory; don't depend on the caller having
    # created it beforehand.
    os.makedirs(models_dir, exist_ok=True)
    target_size = (img_height, img_width)

    # Augment only the training split; evaluation splits are just rescaled.
    train_datagen = ImageDataGenerator(
        rescale=1.0 / 255.0,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
    )
    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical',
    )
    if train_generator.samples == 0:
        raise ValueError(f"No training images found in: {train_dir}")

    # Order names by their integer label rather than trusting dict order, so
    # class_names[i] is always the name of label i.
    class_indices = train_generator.class_indices
    class_names = sorted(class_indices, key=class_indices.get)

    # flow_from_directory assigns label indices per directory. Passing the
    # training class list pins validation/test to the same mapping, so a split
    # that lacks (or adds) a class folder cannot shift labels or change the
    # one-hot width seen by the model.
    eval_datagen = ImageDataGenerator(rescale=1.0 / 255.0)
    validation_generator = eval_datagen.flow_from_directory(
        validation_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical',
        classes=class_names,
    )

    cnn_model = _build_cnn((img_height, img_width, 3), len(class_names))

    print("Starting model training...")
    history = cnn_model.fit(
        train_generator,
        epochs=epochs,
        validation_data=validation_generator,
    )

    model_path = os.path.join(models_dir, 'rose_model.h5')
    cnn_model.save(model_path)
    print(f"Model saved to: {model_path}")

    class_names_path = os.path.join(models_dir, 'class_names.json')
    _write_json(class_names, class_names_path)
    print(f"Class names saved to: {class_names_path}")

    # shuffle=False keeps prediction order aligned with test_generator.classes.
    test_generator = eval_datagen.flow_from_directory(
        test_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical',
        shuffle=False,
        classes=class_names,
    )
    test_predictions = cnn_model.predict(test_generator)
    predicted_labels = np.argmax(test_predictions, axis=1)
    true_labels = test_generator.classes

    # Explicit labels: without them sklearn infers the label set from the
    # data, so a class absent from both truths and predictions would make
    # target_names mismatch and the confusion matrix non-square.
    label_ids = list(range(len(class_names)))

    print("\nClassification Report:")
    print(classification_report(true_labels, predicted_labels,
                                labels=label_ids, target_names=class_names))
    report = classification_report(true_labels, predicted_labels,
                                   labels=label_ids, target_names=class_names,
                                   output_dict=True)
    metrics_path = os.path.join(models_dir, 'metrics.json')
    _write_json(report, metrics_path)
    print(f"Metrics saved to: {metrics_path}")

    cm = confusion_matrix(true_labels, predicted_labels, labels=label_ids)
    cm_path = os.path.join(models_dir, 'confusion_matrix.json')
    _write_json(cm.tolist(), cm_path)
    print(f"Confusion matrix saved to: {cm_path}")

    history_path = os.path.join(models_dir, 'training_history.png')
    _plot_history(history, history_path)
    print(f"Training history plot saved to: {history_path}")


def _build_cnn(input_shape, num_classes):
    """Return the compiled three-conv-block CNN with a softmax head of num_classes units."""
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ])
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model


def _write_json(obj, path):
    """Serialize obj as JSON into the file at path (overwriting it)."""
    with open(path, 'w') as f:
        json.dump(obj, f)


def _plot_history(history, path):
    """Save side-by-side train/validation accuracy and loss curves from a Keras History."""
    panels = (
        ('accuracy', 'Model Accuracy', 'Accuracy'),
        ('loss', 'Model Loss', 'Loss'),
    )
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    try:
        for ax, (metric, title, ylabel) in zip(axes, panels):
            ax.plot(history.history[metric])
            ax.plot(history.history[f'val_{metric}'])
            ax.set_title(title)
            ax.set_ylabel(ylabel)
            ax.set_xlabel('Epoch')
            ax.legend(['Train', 'Validation'], loc='upper left')
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
if __name__ == "__main__":
    import sys
    import traceback

    try:
        create_model_directories()
        train_model()
    except Exception as e:
        # Top-level boundary: show the full traceback and exit non-zero so a
        # failed run is visible to shells and CI instead of exiting with 0.
        traceback.print_exc()
        print(f"\nError during training: {e}", file=sys.stderr)
        sys.exit(1)
    else:
        print("\nTraining completed successfully!")