Spaces:
Sleeping
Sleeping
| import os | |
| import streamlit as st | |
| import tensorflow as tf | |
| from tensorflow.keras import layers, models | |
| from tensorflow.keras.preprocessing.image import ImageDataGenerator | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import json | |
| from datetime import datetime | |
def create_model_directories():
    """Create a fresh, timestamped run folder under ``models`` and return it.

    The ``models`` folder is resolved relative to the current working
    directory. The run folder is named ``run_<YYYYmmdd_HHMMSS>`` using the
    local time at the moment of the call. Existing folders are reused
    silently rather than raising.

    Returns:
        The relative path of the run folder, e.g. ``models/run_20240101_120000``.
    """
    models_root = 'models'
    os.makedirs(models_root, exist_ok=True)

    # Second-resolution stamp keeps successive training runs in separate folders.
    stamp = f"{datetime.now():%Y%m%d_%H%M%S}"
    run_path = os.path.join(models_root, f'run_{stamp}')
    os.makedirs(run_path, exist_ok=True)
    return run_path
def _build_generators(train_dir, validation_dir, target_size, batch_size):
    """Return ``(train_generator, validation_generator)`` reading images from disk."""
    # Augmentation is applied to the training stream only; validation images
    # are merely rescaled to [0, 1].
    train_datagen = ImageDataGenerator(
        rescale=1.0/255.0,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode='nearest'
    )
    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical'
    )

    validation_datagen = ImageDataGenerator(rescale=1.0/255.0)
    validation_generator = validation_datagen.flow_from_directory(
        validation_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical'
    )
    return train_generator, validation_generator


def _build_cnn(input_shape, num_classes):
    """Build the (uncompiled) three-block CNN with a softmax head of ``num_classes`` units."""
    return models.Sequential([
        # First Convolutional Block
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        layers.BatchNormalization(),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),
        # Second Convolutional Block
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),
        # Third Convolutional Block
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),
        # Dense Layers
        layers.Flatten(),
        layers.Dense(256, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax')
    ])


def _save_history_plot(history, history_path):
    """Write side-by-side accuracy and loss curves from ``history`` to ``history_path``."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Plot accuracy
    ax1.plot(history.history['accuracy'], label='Training Accuracy')
    ax1.plot(history.history['val_accuracy'], label='Validation Accuracy')
    ax1.set_title('Model Accuracy')
    ax1.set_ylabel('Accuracy')
    ax1.set_xlabel('Epoch')
    ax1.legend(loc='lower right')
    ax1.grid(True)

    # Plot loss
    ax2.plot(history.history['loss'], label='Training Loss')
    ax2.plot(history.history['val_loss'], label='Validation Loss')
    ax2.set_title('Model Loss')
    ax2.set_ylabel('Loss')
    ax2.set_xlabel('Epoch')
    ax2.legend(loc='upper right')
    ax2.grid(True)

    # Operate on the explicit figure object instead of pyplot's implicit
    # "current figure", so the right figure is always saved and released.
    fig.tight_layout()
    fig.savefig(history_path)
    plt.close(fig)


def train_model(epochs, batch_size, learning_rate):
    """Train the CNN image classifier and persist its artifacts.

    Images are read from ``<cwd>/dataset/train`` and
    ``<cwd>/dataset/validation`` (one sub-folder per class). A
    ``<cwd>/dataset/test`` folder must also exist, although this function
    does not read from it. On success a new run folder is created via
    ``create_model_directories()`` and receives:

    * ``best_model.h5``  - checkpoint with the best ``val_accuracy``
    * ``rose_model.h5``  - the model as it stands after ``fit`` returns
    * ``class_names.json`` - class names ordered by class index
    * ``training_history.png`` - accuracy and loss curves

    Progress and errors are reported through Streamlit (``st.write``,
    ``st.success``, ``st.error``).

    Args:
        epochs: Maximum number of training epochs (early stopping on
            ``val_loss`` with patience 5 may end training sooner).
        batch_size: Batch size used by both data generators.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        ``(history_path, run_dir)`` on success, or ``(None, None)`` when the
        dataset is missing, empty or inconsistent. In the failure case no
        run folder is created.
    """
    # Set the paths to the image folders
    base_dir = os.path.join(os.getcwd(), 'dataset')
    train_dir = os.path.join(base_dir, 'train')
    validation_dir = os.path.join(base_dir, 'validation')
    test_dir = os.path.join(base_dir, 'test')

    # Verify dataset directories exist
    for dir_path in (train_dir, validation_dir, test_dir):
        if not os.path.exists(dir_path):
            st.error(f"Directory not found: {dir_path}")
            return None, None

    # Set the parameters for the data generators
    img_height, img_width = 256, 256
    train_generator, validation_generator = _build_generators(
        train_dir, validation_dir, (img_height, img_width), batch_size
    )

    # Fail early with a readable message instead of an opaque error raised
    # from inside Keras (or a KeyError on 'val_accuracy' when plotting).
    if train_generator.samples == 0 or train_generator.num_classes == 0:
        st.error(f"No training images found in: {train_dir}")
        return None, None
    if validation_generator.samples == 0:
        st.error(f"No validation images found in: {validation_dir}")
        return None, None
    if validation_generator.num_classes != train_generator.num_classes:
        st.error(
            "Class count mismatch: "
            f"{train_generator.num_classes} classes in {train_dir} but "
            f"{validation_generator.num_classes} in {validation_dir}"
        )
        return None, None

    # Create the run folder only once the dataset has been validated, so a
    # failed attempt does not leave an empty run folder behind.
    run_dir = create_model_directories()

    # Create and compile the CNN model
    cnn_model = _build_cnn((img_height, img_width, 3), train_generator.num_classes)
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    cnn_model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    # Training callbacks
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=5,
            restore_best_weights=True
        ),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(run_dir, 'best_model.h5'),
            monitor='val_accuracy',
            save_best_only=True
        )
    ]

    # Train the model
    st.write("Starting model training...")
    st.write(f"Number of classes: {train_generator.num_classes}")
    st.write(f"Training samples: {train_generator.samples}")
    st.write(f"Validation samples: {validation_generator.samples}")
    history = cnn_model.fit(
        train_generator,
        epochs=epochs,
        validation_data=validation_generator,
        callbacks=callbacks
    )

    # Save the model
    model_path = os.path.join(run_dir, 'rose_model.h5')
    cnn_model.save(model_path)
    st.success(f"Model saved to: {model_path}")

    # Save class names, ordered explicitly by class index so that position i
    # in the list matches output unit i of the softmax layer.
    class_indices = train_generator.class_indices
    class_names = sorted(class_indices, key=class_indices.get)
    class_names_path = os.path.join(run_dir, 'class_names.json')
    with open(class_names_path, 'w', encoding='utf-8') as f:
        json.dump(class_names, f)
    st.success(f"Class names saved to: {class_names_path}")

    # Plot training history
    history_path = os.path.join(run_dir, 'training_history.png')
    _save_history_plot(history, history_path)

    return history_path, run_dir
def main():
    """Render the Streamlit training page and launch training on demand.

    Shows sliders for epochs, batch size and learning rate. When the
    "Start Training" button is pressed, ``train_model`` is run with the
    selected values; if it reports success, the training-history image and
    the run folder are displayed.
    """
    st.title("Rose Classification Model Training")
    st.write("Train your rose classification model with custom parameters")

    # Training parameters
    col1, col2 = st.columns(2)
    with col1:
        epochs = st.slider("Number of Epochs", min_value=1, max_value=100, value=50, step=1)
        batch_size = st.slider("Batch Size", min_value=8, max_value=64, value=32, step=8)
    with col2:
        # Streamlit renders float sliders with "%0.2f" by default, which would
        # show every value in this range as 0.00 or 0.01; four decimals match
        # the 0.0001 step. The format affects display only, not the value.
        learning_rate = st.slider(
            "Learning Rate",
            min_value=0.0001,
            max_value=0.01,
            value=0.001,
            step=0.0001,
            format="%.4f",
        )

    if st.button("Start Training"):
        with st.spinner("Training in progress..."):
            history_path, run_dir = train_model(epochs, batch_size, learning_rate)

        # train_model returns (None, None) when it could not train.
        if history_path and run_dir:
            st.image(history_path, caption="Training History")
            st.success(f"Training completed! Files saved in: {run_dir}")
# Script entry point: build the page only when this file is executed directly
# rather than imported.
if __name__ == "__main__":
    main()