import json
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import classification_report, confusion_matrix
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator


def create_model_directories():
    """Create the ``models`` output directory if it does not already exist."""
    os.makedirs('models', exist_ok=True)


def _build_cnn(input_shape, num_classes):
    """Return an uncompiled 3-stage Conv/MaxPool CNN with a softmax head."""
    return models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ])


def _save_json(payload, path):
    """Serialize *payload* as JSON to *path* (UTF-8)."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f)


def _plot_history(history, path):
    """Save a side-by-side accuracy/loss plot of a Keras ``History`` to *path*."""
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))
    panels = (
        (acc_ax, 'accuracy', 'Model Accuracy', 'Accuracy'),
        (loss_ax, 'loss', 'Model Loss', 'Loss'),
    )
    for ax, key, title, ylabel in panels:
        ax.plot(history.history[key])
        ax.plot(history.history[f'val_{key}'])
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_xlabel('Epoch')
        ax.legend(['Train', 'Validation'], loc='upper left')
    fig.tight_layout()
    fig.savefig(path)
    # Close the explicit figure handle so repeated calls do not leak figures.
    plt.close(fig)


def train_model():
    """Train a CNN image classifier on ``./dataset`` and save its artifacts.

    Expects ``dataset/train``, ``dataset/validation`` and ``dataset/test``
    under the current working directory, each holding one sub-folder per
    class. The class-to-index mapping is taken from the training folder and
    reused for the validation and test splits.

    Files written to ``models/``:
        rose_model.h5          trained Keras model
        class_names.json       class names ordered by class index
        metrics.json           classification report on the test split
        confusion_matrix.json  confusion matrix on the test split
        training_history.png   accuracy / loss curves per epoch

    Raises:
        FileNotFoundError: if any of the three dataset directories is missing.
        ValueError: if the training directory contains no images.
    """
    # Set the paths to the image folders
    base_dir = os.path.join(os.getcwd(), 'dataset')
    train_dir = os.path.join(base_dir, 'train')
    validation_dir = os.path.join(base_dir, 'validation')
    test_dir = os.path.join(base_dir, 'test')

    # Verify dataset directories exist (a plain file at the path is not enough)
    for dir_path in (train_dir, validation_dir, test_dir):
        if not os.path.isdir(dir_path):
            raise FileNotFoundError(f"Directory not found: {dir_path}")

    # Make sure the output directory exists *before* the expensive training
    # run, so a direct call to train_model() cannot fail at save time.
    create_model_directories()

    # Parameters shared by all data generators
    batch_size = 32
    img_height, img_width = 256, 256
    target_size = (img_height, img_width)

    # Training data: rescaled to [0, 1] and randomly augmented
    train_datagen = ImageDataGenerator(
        rescale=1.0 / 255.0,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
    )
    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical',
    )
    if train_generator.samples == 0:
        raise ValueError(f"No training images found in: {train_dir}")

    # Class names ordered by their index in the training generator. Passing
    # this list as `classes=` below pins validation/test labels to the same
    # index mapping even if those folders lack (or add) a class sub-folder.
    class_indices = train_generator.class_indices
    class_names = sorted(class_indices, key=class_indices.get)
    label_ids = list(range(len(class_names)))

    # Validation data: rescaled only, no augmentation
    validation_datagen = ImageDataGenerator(rescale=1.0 / 255.0)
    validation_generator = validation_datagen.flow_from_directory(
        validation_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical',
        classes=class_names,
    )

    # Create and compile the CNN model
    cnn_model = _build_cnn((img_height, img_width, 3), len(class_names))
    cnn_model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    # Train the model
    print("Starting model training...")
    epochs = 10
    history = cnn_model.fit(
        train_generator,
        epochs=epochs,
        validation_data=validation_generator,
    )

    # Save the trained model
    model_path = os.path.join('models', 'rose_model.h5')
    cnn_model.save(model_path)
    print(f"Model saved to: {model_path}")

    # Save class names
    class_names_path = os.path.join('models', 'class_names.json')
    _save_json(class_names, class_names_path)
    print(f"Class names saved to: {class_names_path}")

    # Evaluate the model. shuffle=False keeps predictions aligned with
    # test_generator.classes.
    test_datagen = ImageDataGenerator(rescale=1.0 / 255.0)
    test_generator = test_datagen.flow_from_directory(
        test_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical',
        classes=class_names,
        shuffle=False,
    )

    # Make predictions
    test_predictions = cnn_model.predict(test_generator)
    predicted_labels = np.argmax(test_predictions, axis=1)
    true_labels = test_generator.classes

    # Calculate metrics. `labels=` keeps target_names valid even when some
    # class never occurs in the test split.
    print("\nClassification Report:")
    report = classification_report(
        true_labels,
        predicted_labels,
        labels=label_ids,
        target_names=class_names,
        output_dict=True,
    )
    print(classification_report(
        true_labels,
        predicted_labels,
        labels=label_ids,
        target_names=class_names,
    ))

    # Save metrics
    metrics_path = os.path.join('models', 'metrics.json')
    _save_json(report, metrics_path)
    print(f"Metrics saved to: {metrics_path}")

    # Calculate and save confusion matrix (one row/column per known class)
    cm = confusion_matrix(true_labels, predicted_labels, labels=label_ids)
    cm_path = os.path.join('models', 'confusion_matrix.json')
    _save_json(cm.tolist(), cm_path)
    print(f"Confusion matrix saved to: {cm_path}")

    # Plot and save training history
    history_path = os.path.join('models', 'training_history.png')
    _plot_history(history, history_path)
    print(f"Training history plot saved to: {history_path}")


if __name__ == "__main__":
    try:
        create_model_directories()
        train_model()
    except Exception as e:
        print(f"\nError during training: {e}")
        # Report failure through the exit status so shells/CI can detect it.
        raise SystemExit(1) from e
    else:
        print("\nTraining completed successfully!")