| import os | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import tensorflow as tf | |
| from tensorflow.keras.preprocessing.image import ImageDataGenerator | |
| from tensorflow.keras.models import Sequential | |
| from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout | |
| from tensorflow.keras.optimizers import Adam | |
# --- Configuration ---------------------------------------------------------
# Root folder handed to flow_from_directory. It is expected to contain one
# sub-directory per entry in `classes` (assumption about the on-disk layout;
# TODO confirm against the actual dataset).
data_dir = "Dataset"
# Label names. The list order fixes the label index of each class and its
# length fixes the width of the network's softmax output layer.
classes = ["Alluvial", "Black", "Clay", "Red"]
# Every image is resized to this height/width when it is loaded.
img_height, img_width = 224, 224
batch_size = 32
epochs = 10

# --- Data pipeline ---------------------------------------------------------
# Multiply pixel values by 1/255 and reserve 20% of the images for validation.
datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

# Options shared by the training and validation iterators. `classes` is
# passed explicitly so the labels come from the list above instead of being
# inferred from whatever sub-directories happen to exist under `data_dir`;
# otherwise an extra or missing folder makes the label width disagree with
# Dense(len(classes)) and training fails with a shape error.
_flow_options = dict(
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    classes=classes,
)
train_data = datagen.flow_from_directory(
    data_dir, subset='training', **_flow_options
)
val_data = datagen.flow_from_directory(
    data_dir, subset='validation', **_flow_options
)

# Stop here with a clear message if a split is empty, rather than letting
# model.fit fail later with a less obvious error.
for _split_name, _split in (("training", train_data), ("validation", val_data)):
    if _split.samples == 0:
        raise ValueError(
            f"No images found for the {_split_name} split under {data_dir!r}; "
            f"expected sub-directories named {classes}."
        )
# Build the CNN layer by layer: three Conv2D + MaxPooling2D stages with
# 32, 64 and 128 filters, followed by a dense classification head.
model = Sequential()

# The first convolution also declares the input shape: resized RGB images.
model.add(
    Conv2D(
        32,
        kernel_size=(3, 3),
        activation='relu',
        input_shape=(img_height, img_width, 3),
    )
)
model.add(MaxPooling2D(pool_size=2, strides=2))

# The remaining convolution stages differ only in their filter count.
for n_filters in (64, 128):
    model.add(Conv2D(n_filters, kernel_size=(3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=2, strides=2))

# Classification head: flatten, one hidden layer with dropout, then a
# softmax output with one unit per entry in `classes`.
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(len(classes), activation='softmax'))

# Adam with its default settings; categorical cross-entropy loss, with
# accuracy tracked as the reported metric.
model.compile(
    optimizer=Adam(),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
# Train for `epochs` epochs on the training iterator, evaluating on the
# validation iterator as part of the same fit call. The returned History
# object is kept in `history`.
history = model.fit(train_data, validation_data=val_data, epochs=epochs)

# Save the trained model to disk.
trained_model_path = "SoilNet.keras"
model.save(trained_model_path)
# Draw the curves stored in `history.history` side by side: accuracy in the
# left panel, loss in the right one, each with its training and validation
# series.
plt.figure(figsize=(10, 4))
for panel, (metric, heading) in enumerate(
    (('accuracy', 'Accuracy'), ('loss', 'Loss')), start=1
):
    plt.subplot(1, 2, panel)
    plt.plot(history.history[metric], label='Train ' + heading)
    # Validation series are stored under the same key with a 'val_' prefix.
    plt.plot(history.history['val_' + metric], label='Validation ' + heading)
    plt.title(heading)
    plt.legend()
plt.tight_layout()
plt.show()