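# NOTE: the original paste began with hosting-page header residue
# ("Spaces: Sleeping Sleeping"); it is not part of the program.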
import json
import os

import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import tensorflow as tf
# Dataset layout: ./dataset/{train,validation,test}, resolved against the
# current working directory at import time.
BASE_DIR = os.path.join(os.getcwd(), 'dataset')
train_dir, validation_dir, test_dir = (
    os.path.join(BASE_DIR, split_name)
    for split_name in ('train', 'validation', 'test')
)
# Loader settings shared by every generator below: images are resized to a
# square 256x256 and served in mini-batches of 32.
img_height = img_width = 256
batch_size = 32
# Training images are rescaled to [0, 1] and randomly augmented on the fly
# (rotation, shifts, shear, zoom, flips). Validation images are only rescaled,
# so validation metrics are computed on unmodified data.
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255.0,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
)
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
)

validation_datagen = ImageDataGenerator(rescale=1.0 / 255.0)
# Pin the validation label mapping to the training one, ordered by class index.
# Without `classes=`, flow_from_directory infers the classes from whatever
# sub-folders exist under validation_dir. A missing or extra folder would then
# shift label indices or change the one-hot width relative to the model output.
validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    classes=sorted(train_generator.class_indices,
                   key=train_generator.class_indices.get),
)
# Small CNN classifier with three Conv(3x3, ReLU) + MaxPool(2x2) stages
# (32/64/128 filters), then a 64-unit dense layer and a softmax with one unit
# per class found by the training generator.
# The 3-channel input is declared with an explicit Input layer instead of
# `input_shape=` on the first Conv2D. Keras 3 deprecates the latter (it emits
# a warning); an Input layer works in both tf.keras 2.x and Keras 3.
cnn_model = models.Sequential([
    layers.Input(shape=(img_height, img_width, 3)),
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(train_generator.num_classes, activation='softmax'),
])
# Multi-class setup: softmax outputs scored against the generators'
# categorical labels with cross-entropy. Optimised with Adam (string alias,
# default settings) and tracking accuracy.
cnn_model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
# Train for a fixed number of epochs on the augmented training stream and
# evaluate on the validation split after each epoch. `history.history` keeps
# the per-epoch metrics used by the plots further down.
epochs = 10
history = cnn_model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=epochs,
)
# Persist the trained network as ./models/rose_model.h5, creating the output
# folder first when it does not exist yet.
models_dir = os.path.join(os.getcwd(), 'models')
model_path = os.path.join(models_dir, 'rose_model.h5')
os.makedirs(models_dir, exist_ok=True)
cnn_model.save(model_path)
# Save the label names next to the model as a JSON list.
# The names are ordered by class index, so class_names[i] is the label of
# softmax output unit i. The classification report later passes this list as
# target_names, which depends on that alignment.
class_names = sorted(train_generator.class_indices,
                     key=train_generator.class_indices.get)
class_names_path = os.path.join(os.getcwd(), 'models', 'class_names.json')
# Explicit encoding: the platform default is not guaranteed to be UTF-8.
with open(class_names_path, 'w', encoding='utf-8') as f:
    json.dump(class_names, f)
# Held-out test data for the final evaluation.
# - Images are only rescaled, like the validation data.
# - shuffle=False keeps iteration order fixed, so predictions line up with
#   test_generator.classes.
# - The class list is pinned to the training classes (class_names, ordered by
#   index), so test label indices match the model's output units even if the
#   test folder is missing a class or has an extra folder.
test_datagen = ImageDataGenerator(rescale=1.0 / 255.0)
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False,
    classes=class_names,
)
# Ground-truth class indices come from the (unshuffled) test iterator, so they
# are in the same order as the rows returned by predict().
true_labels = test_generator.classes
test_predictions = cnn_model.predict(test_generator)
# Pick the highest-probability class for each test sample.
predicted_labels = np.argmax(test_predictions, axis=1)
# Per-class precision/recall/F1 plus a confusion matrix over the full label
# set. Passing `labels` explicitly keeps both outputs aligned with
# `class_names`. Otherwise sklearn infers labels only from the values present
# in true_labels/predicted_labels. If some class never appears in either,
# classification_report rejects target_names of a different length, and the
# confusion matrix silently drops that row/column.
report_labels = list(range(len(class_names)))
print("\nClassification Report:")
print(classification_report(true_labels, predicted_labels,
                            labels=report_labels, target_names=class_names))
print("\nConfusion Matrix:")
print(confusion_matrix(true_labels, predicted_labels, labels=report_labels))
# Two side-by-side panels of per-epoch curves (training vs. validation):
# accuracy on the left, loss on the right. The figure is written to
# ./models/training_history.png and then closed to free it.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
panels = (
    (axes[0], 'accuracy', 'val_accuracy', 'Model Accuracy', 'Accuracy'),
    (axes[1], 'loss', 'val_loss', 'Model Loss', 'Loss'),
)
for ax, train_key, val_key, title, ylabel in panels:
    ax.plot(history.history[train_key])
    ax.plot(history.history[val_key])
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Epoch')
    ax.legend(['Train', 'Validation'], loc='upper left')
fig.tight_layout()
fig.savefig(os.path.join(os.getcwd(), 'models', 'training_history.png'))
plt.close(fig)
# Summarise where each artefact produced by this script was written.
for label, saved_path in (
    ("\nModel saved to:", model_path),
    ("Class names saved to:", class_names_path),
    ("Training history plot saved to:",
     os.path.join(os.getcwd(), 'models', 'training_history.png')),
):
    print(label, saved_path)