# File size: 5,574 Bytes
# eac85ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import os
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import tensorflow as tf
import json

def create_model_directories():
    """Make sure the folder that receives saved model artifacts exists.

    The folder is ``models``, resolved relative to the current working
    directory. Calling this when the folder already exists is a no-op.
    """
    models_dir = 'models'
    os.makedirs(models_dir, exist_ok=True)

def _flow(directory, img_size, batch_size, *, augment=False, shuffle=True, classes=None):
    """Return a rescaled categorical DirectoryIterator over ``directory``.

    ``augment`` enables the training-time augmentation pipeline. ``classes``
    pins the class-folder list (and therefore the label index order); when
    None, Keras infers it from the sub-folders of ``directory``.

    Raises:
        ValueError: if the iterator finds no images.
    """
    if augment:
        datagen = ImageDataGenerator(
            rescale=1.0 / 255.0,
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            vertical_flip=True,
        )
    else:
        datagen = ImageDataGenerator(rescale=1.0 / 255.0)
    iterator = datagen.flow_from_directory(
        directory,
        target_size=img_size,
        batch_size=batch_size,
        class_mode='categorical',
        shuffle=shuffle,
        classes=classes,
    )
    # Fail fast with a clear message instead of an opaque error inside fit/predict.
    if iterator.samples == 0:
        raise ValueError(f"No images found in: {directory}")
    return iterator

def _build_cnn(input_shape, num_classes):
    """Return a compiled 3-stage Conv2D/MaxPooling2D CNN with a softmax head."""
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ])
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model

def _write_json(path, obj):
    """Serialize ``obj`` as UTF-8 JSON to ``path``."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(obj, f)

def _evaluate(model, test_generator, class_names, output_dir):
    """Predict on the (unshuffled) test split; print and save report + confusion matrix."""
    # Pin every label index so the report/matrix cover all classes even when
    # some class is absent from the test set or never predicted.
    label_ids = list(range(len(class_names)))

    test_predictions = model.predict(test_generator)
    predicted_labels = np.argmax(test_predictions, axis=1)
    # Aligned with predictions because the test iterator was built with shuffle=False.
    true_labels = test_generator.classes

    report_kwargs = dict(labels=label_ids, target_names=class_names, zero_division=0)
    print("\nClassification Report:")
    print(classification_report(true_labels, predicted_labels, **report_kwargs))
    report = classification_report(true_labels, predicted_labels,
                                   output_dict=True, **report_kwargs)

    metrics_path = os.path.join(output_dir, 'metrics.json')
    _write_json(metrics_path, report)
    print(f"Metrics saved to: {metrics_path}")

    cm = confusion_matrix(true_labels, predicted_labels, labels=label_ids)
    cm_path = os.path.join(output_dir, 'confusion_matrix.json')
    _write_json(cm_path, cm.tolist())
    print(f"Confusion matrix saved to: {cm_path}")

def _plot_history(history, path):
    """Save side-by-side train/validation accuracy and loss curves to ``path``."""
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))
    try:
        panels = (
            (acc_ax, 'accuracy', 'Model Accuracy', 'Accuracy'),
            (loss_ax, 'loss', 'Model Loss', 'Loss'),
        )
        for ax, metric, title, ylabel in panels:
            ax.plot(history.history[metric])
            ax.plot(history.history[f'val_{metric}'])
            ax.set_title(title)
            ax.set_ylabel(ylabel)
            ax.set_xlabel('Epoch')
            ax.legend(['Train', 'Validation'], loc='upper left')
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Always release the figure, even if plotting/saving fails.
        plt.close(fig)

def train_model(*, base_dir=None, output_dir='models', epochs=10, batch_size=32,
                img_height=256, img_width=256):
    """Train the CNN image classifier and write its artifacts to ``output_dir``.

    ``base_dir`` must contain ``train``, ``validation`` and ``test``
    directories, each with one sub-folder per class (the layout read by
    ``flow_from_directory``). The class list is taken from ``train`` and
    reused for the other splits, so label indices agree across all three.
    Folders in validation/test that do not exist in train are ignored.

    Artifacts written to ``output_dir`` (created if missing):
    ``rose_model.h5``, ``class_names.json``, ``metrics.json``,
    ``confusion_matrix.json`` and ``training_history.png``.

    Args:
        base_dir: dataset root; defaults to ``<cwd>/dataset``.
        output_dir: directory receiving the artifacts.
        epochs: number of training epochs.
        batch_size: batch size for all splits.
        img_height: image height images are resized to.
        img_width: image width images are resized to.

    Raises:
        FileNotFoundError: if a split directory is missing.
        ValueError: if a split contains no images.
    """
    if base_dir is None:
        base_dir = os.path.join(os.getcwd(), 'dataset')
    train_dir = os.path.join(base_dir, 'train')
    validation_dir = os.path.join(base_dir, 'validation')
    test_dir = os.path.join(base_dir, 'test')

    for dir_path in (train_dir, validation_dir, test_dir):
        if not os.path.isdir(dir_path):
            raise FileNotFoundError(f"Directory not found: {dir_path}")

    os.makedirs(output_dir, exist_ok=True)
    img_size = (img_height, img_width)

    train_generator = _flow(train_dir, img_size, batch_size, augment=True)
    class_indices = train_generator.class_indices
    # Names ordered by their label index.
    class_names = sorted(class_indices, key=class_indices.get)
    # Reuse the training class list so every split encodes labels identically;
    # otherwise a missing/extra folder changes the one-hot width or remaps labels.
    validation_generator = _flow(validation_dir, img_size, batch_size, classes=class_names)
    # Built before training so a bad test split is reported before a long run.
    test_generator = _flow(test_dir, img_size, batch_size, shuffle=False, classes=class_names)

    cnn_model = _build_cnn((img_height, img_width, 3), len(class_names))

    print("Starting model training...")
    history = cnn_model.fit(
        train_generator,
        epochs=epochs,
        validation_data=validation_generator,
    )

    model_path = os.path.join(output_dir, 'rose_model.h5')
    cnn_model.save(model_path)
    print(f"Model saved to: {model_path}")

    class_names_path = os.path.join(output_dir, 'class_names.json')
    _write_json(class_names_path, class_names)
    print(f"Class names saved to: {class_names_path}")

    _evaluate(cnn_model, test_generator, class_names, output_dir)

    history_path = os.path.join(output_dir, 'training_history.png')
    _plot_history(history, history_path)
    print(f"Training history plot saved to: {history_path}")

if __name__ == "__main__":
    # Script entry point. Failures are reported with a traceback on stderr and a
    # non-zero exit status so shells/CI can detect an unsuccessful run
    # (previously the error was printed and the process still exited 0).
    import sys
    import traceback

    try:
        create_model_directories()
        train_model()
    except Exception as e:
        traceback.print_exc()
        print(f"\nError during training: {e}", file=sys.stderr)
        sys.exit(1)
    else:
        print("\nTraining completed successfully!")