Rose_classification / src /train_in_space.py
Terence9's picture
Create train_in_space.py
43e2735 verified
import os
import streamlit as st
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import matplotlib.pyplot as plt
import json
from datetime import datetime
def create_model_directories(base_dir='models'):
    """Create a fresh, uniquely named run directory for model artifacts.

    The directory is ``<base_dir>/run_<YYYYmmdd_HHMMSS>``. If a directory with
    that name already exists (e.g. two runs started within the same second),
    a numeric suffix is appended (``run_<timestamp>_1``, ``_2``, ...) so a
    previous run's artifacts are never overwritten.

    Args:
        base_dir: Parent directory for all runs; created if missing.

    Returns:
        Path of the newly created run directory.
    """
    os.makedirs(base_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(base_dir, f'run_{timestamp}')
    suffix = 0
    while True:
        try:
            # No exist_ok: we want to fail if the directory is already there,
            # which is also race-safe against concurrent runs.
            os.makedirs(run_dir)
            return run_dir
        except FileExistsError:
            suffix += 1
            run_dir = os.path.join(base_dir, f'run_{timestamp}_{suffix}')
def _make_flow(datagen, directory, img_size, batch_size, classes=None, shuffle=True):
    """Return a one-hot (``categorical``) image iterator over *directory*.

    ``classes`` fixes the sub-folder -> index mapping; ``None`` lets Keras infer
    it from the sub-directory names (the ``flow_from_directory`` default).
    """
    return datagen.flow_from_directory(
        directory,
        target_size=img_size,
        batch_size=batch_size,
        class_mode='categorical',
        classes=classes,
        shuffle=shuffle,
    )


def _build_cnn(input_shape, num_classes, learning_rate):
    """Build and compile the three-block CNN classifier (Adam + categorical CE)."""
    model = models.Sequential([
        # First Convolutional Block
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        layers.BatchNormalization(),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),
        # Second Convolutional Block
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),
        # Third Convolutional Block
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),
        # Dense Layers
        layers.Flatten(),
        layers.Dense(256, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax'),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model


def _plot_history(history, output_path):
    """Save side-by-side accuracy and loss curves from a Keras ``History`` to *output_path*."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    # Plot accuracy
    ax1.plot(history.history['accuracy'], label='Training Accuracy')
    ax1.plot(history.history['val_accuracy'], label='Validation Accuracy')
    ax1.set_title('Model Accuracy')
    ax1.set_ylabel('Accuracy')
    ax1.set_xlabel('Epoch')
    ax1.legend(loc='lower right')
    ax1.grid(True)
    # Plot loss
    ax2.plot(history.history['loss'], label='Training Loss')
    ax2.plot(history.history['val_loss'], label='Validation Loss')
    ax2.set_title('Model Loss')
    ax2.set_ylabel('Loss')
    ax2.set_xlabel('Epoch')
    ax2.legend(loc='upper right')
    ax2.grid(True)
    fig.tight_layout()
    fig.savefig(output_path)
    # Close this specific figure so repeated runs in one Streamlit process
    # do not accumulate open figures.
    plt.close(fig)


def train_model(epochs, batch_size, learning_rate, dataset_dir=None):
    """Train the rose-classification CNN and save its artifacts.

    Expects ``dataset_dir`` to contain ``train/``, ``validation/`` and ``test/``
    sub-directories, each holding one sub-folder per class. The training
    split's class order defines the label indices for all three splits.

    Artifacts written to a new run directory: ``best_model.h5`` (checkpoint),
    ``rose_model.h5`` (final model), ``class_names.json`` and
    ``training_history.png``.

    Args:
        epochs: Maximum number of epochs (EarlyStopping may stop earlier).
        batch_size: Images per batch for every split.
        learning_rate: Adam learning rate.
        dataset_dir: Dataset root; defaults to ``<cwd>/dataset``.

    Returns:
        ``(history_plot_path, run_dir)`` on success, or ``(None, None)`` if a
        split directory is missing or the train/validation split has no images.
    """
    base_dir = dataset_dir if dataset_dir is not None else os.path.join(os.getcwd(), 'dataset')
    train_dir = os.path.join(base_dir, 'train')
    validation_dir = os.path.join(base_dir, 'validation')
    test_dir = os.path.join(base_dir, 'test')

    # Validate everything before creating a run directory, so failed
    # attempts do not leave empty run_* folders behind.
    for dir_path in (train_dir, validation_dir, test_dir):
        if not os.path.isdir(dir_path):
            st.error(f"Directory not found: {dir_path}")
            return None, None

    img_height, img_width = 256, 256
    img_size = (img_height, img_width)

    # Augmentation is applied to the training split only.
    train_datagen = ImageDataGenerator(
        rescale=1.0/255.0,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode='nearest'
    )
    train_generator = _make_flow(train_datagen, train_dir, img_size, batch_size)
    if train_generator.samples == 0:
        st.error(f"No training images found in: {train_dir}")
        return None, None

    # Reuse the training class order for the other splits so one-hot indices
    # agree even if a validation/test folder is missing (or adds) a class.
    class_names = list(train_generator.class_indices.keys())

    eval_datagen = ImageDataGenerator(rescale=1.0/255.0)
    validation_generator = _make_flow(
        eval_datagen, validation_dir, img_size, batch_size, classes=class_names
    )
    if validation_generator.samples == 0:
        st.error(f"No validation images found in: {validation_dir}")
        return None, None
    # No shuffling: evaluation order does not matter, and a fixed order keeps
    # results reproducible.
    test_generator = _make_flow(
        eval_datagen, test_dir, img_size, batch_size, classes=class_names, shuffle=False
    )

    run_dir = create_model_directories()

    cnn_model = _build_cnn((img_height, img_width, 3), train_generator.num_classes, learning_rate)

    # Both callbacks use the same criterion so "best" means the same thing
    # for the restored weights and the checkpoint file.
    # NOTE(review): depending on the Keras version, restore_best_weights may
    # only take effect when EarlyStopping actually stops early; if all epochs
    # run, rose_model.h5 can hold final-epoch weights while best_model.h5
    # holds the best checkpoint.
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=5,
            restore_best_weights=True
        ),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(run_dir, 'best_model.h5'),
            monitor='val_loss',
            save_best_only=True
        )
    ]

    st.write("Starting model training...")
    st.write(f"Number of classes: {train_generator.num_classes}")
    st.write(f"Training samples: {train_generator.samples}")
    st.write(f"Validation samples: {validation_generator.samples}")
    st.write(f"Test samples: {test_generator.samples}")
    history = cnn_model.fit(
        train_generator,
        epochs=epochs,
        validation_data=validation_generator,
        callbacks=callbacks
    )

    model_path = os.path.join(run_dir, 'rose_model.h5')
    cnn_model.save(model_path)
    st.success(f"Model saved to: {model_path}")

    class_names_path = os.path.join(run_dir, 'class_names.json')
    with open(class_names_path, 'w', encoding='utf-8') as f:
        json.dump(class_names, f)
    st.success(f"Class names saved to: {class_names_path}")

    # Held-out evaluation on the test split the function requires to exist.
    if test_generator.samples:
        test_metrics = cnn_model.evaluate(test_generator, verbose=0, return_dict=True)
        summary = ", ".join(f"{metric}: {value:.4f}" for metric, value in test_metrics.items())
        st.write(f"Test set evaluation ({test_generator.samples} samples) - {summary}")
    else:
        st.warning(f"No test images found in: {test_dir}; skipping test evaluation.")

    history_path = os.path.join(run_dir, 'training_history.png')
    _plot_history(history, history_path)
    return history_path, run_dir
def main():
    """Render the Streamlit training UI and run training when requested.

    Collects epochs, batch size and learning rate from sliders, then calls
    ``train_model`` and shows the saved training-history plot on success.
    """
    st.title("Rose Classification Model Training")
    st.write("Train your rose classification model with custom parameters")
    # Training parameters
    col1, col2 = st.columns(2)
    with col1:
        epochs = st.slider("Number of Epochs", min_value=1, max_value=100, value=50, step=1)
        batch_size = st.slider("Batch Size", min_value=8, max_value=64, value=32, step=8)
    with col2:
        # Streamlit's default float format shows only two decimals, which
        # would display every learning rate in this range as "0.00"/"0.01".
        learning_rate = st.slider(
            "Learning Rate",
            min_value=0.0001,
            max_value=0.01,
            value=0.001,
            step=0.0001,
            format="%.4f",
        )
    if st.button("Start Training"):
        with st.spinner("Training in progress..."):
            history_path, run_dir = train_model(epochs, batch_size, learning_rate)
        # train_model reports its own errors and returns (None, None) on failure.
        if history_path and run_dir:
            st.image(history_path, caption="Training History")
            st.success(f"Training completed! Files saved in: {run_dir}")


if __name__ == "__main__":
    main()