Spaces:
Sleeping
Sleeping
File size: 6,569 Bytes
43e2735 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 | import os
import streamlit as st
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import matplotlib.pyplot as plt
import json
from datetime import datetime
def create_model_directories():
    """Create and return a fresh run directory under ``models/``.

    The directory is named ``run_YYYYmmdd_HHMMSS`` from the current local
    time. If that name already exists (e.g. two runs started within the same
    second), a numeric suffix ``_1``, ``_2``, ... is appended instead of
    reusing the existing folder, so one run cannot overwrite another run's
    artifacts.

    Returns:
        str: Relative path of the newly created run directory.
    """
    os.makedirs('models', exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    base_dir = os.path.join('models', f'run_{timestamp}')
    run_dir = base_dir
    suffix = 0
    while True:
        try:
            # No exist_ok here: an existing folder must not be silently reused.
            os.mkdir(run_dir)
            return run_dir
        except FileExistsError:
            suffix += 1
            run_dir = f"{base_dir}_{suffix}"
def _build_data_generators(train_dir, validation_dir, target_size, batch_size):
    """Create the training and validation directory iterators.

    Both iterators rescale pixel values by 1/255 and yield one-hot labels
    (``class_mode='categorical'``). Only the training iterator applies random
    augmentation (rotation, shifts, shear, zoom, horizontal/vertical flips).

    Returns:
        tuple: ``(train_generator, validation_generator)``.
    """
    train_datagen = ImageDataGenerator(
        rescale=1.0 / 255.0,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode='nearest',
    )
    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical',
    )
    validation_datagen = ImageDataGenerator(rescale=1.0 / 255.0)
    validation_generator = validation_datagen.flow_from_directory(
        validation_dir,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical',
    )
    return train_generator, validation_generator


def _build_cnn_model(input_shape, num_classes):
    """Return an uncompiled CNN: three Conv-Conv-Pool blocks and a dense head.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        num_classes: Number of units in the softmax output layer.
    """
    return models.Sequential([
        # First convolutional block
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        layers.BatchNormalization(),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),
        # Second convolutional block
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),
        # Third convolutional block
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),
        # Dense classification head
        layers.Flatten(),
        layers.Dense(256, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax'),
    ])


def _save_history_plot(history, path):
    """Save side-by-side accuracy and loss curves from a Keras ``History`` to *path*."""
    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(15, 5))
    try:
        ax_acc.plot(history.history['accuracy'], label='Training Accuracy')
        ax_acc.plot(history.history['val_accuracy'], label='Validation Accuracy')
        ax_acc.set_title('Model Accuracy')
        ax_acc.set_ylabel('Accuracy')
        ax_acc.set_xlabel('Epoch')
        ax_acc.legend(loc='lower right')
        ax_acc.grid(True)

        ax_loss.plot(history.history['loss'], label='Training Loss')
        ax_loss.plot(history.history['val_loss'], label='Validation Loss')
        ax_loss.set_title('Model Loss')
        ax_loss.set_ylabel('Loss')
        ax_loss.set_xlabel('Epoch')
        ax_loss.legend(loc='upper right')
        ax_loss.grid(True)

        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Close this figure by handle rather than relying on pyplot's
        # process-wide "current figure".
        plt.close(fig)


def train_model(epochs, batch_size, learning_rate):
    """Train the rose-classification CNN on ``./dataset`` and save artifacts.

    Expects ``dataset/train``, ``dataset/validation`` and ``dataset/test``
    under the current working directory, each holding one sub-folder per class
    (the layout ``flow_from_directory`` reads). Progress and errors are
    reported through Streamlit.

    Artifacts written to a new ``models/run_*`` folder: ``best_model.h5``
    (checkpoint), ``rose_model.h5`` (final model), ``class_names.json`` (class
    names ordered by label index) and ``training_history.png``.

    Args:
        epochs: Maximum number of epochs (early stopping may end sooner).
        batch_size: Images per batch for both generators.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        tuple: ``(history_path, run_dir)`` on success; ``(None, None)`` if the
        dataset is missing, empty, or its splits disagree on class folders.
    """
    base_dir = os.path.join(os.getcwd(), 'dataset')
    train_dir = os.path.join(base_dir, 'train')
    validation_dir = os.path.join(base_dir, 'validation')
    test_dir = os.path.join(base_dir, 'test')

    # Validate the dataset BEFORE creating the run directory so that a failed
    # check does not leave an empty models/run_* folder behind.
    # NOTE(review): test_dir is required here but never read below.
    for dir_path in (train_dir, validation_dir, test_dir):
        if not os.path.isdir(dir_path):
            st.error(f"Directory not found: {dir_path}")
            return None, None

    img_height, img_width = 256, 256
    train_generator, validation_generator = _build_data_generators(
        train_dir, validation_dir, (img_height, img_width), batch_size
    )

    # An empty split would otherwise produce Dense(0) or an opaque fit() error.
    if train_generator.samples == 0 or validation_generator.samples == 0:
        st.error("No images found in the training or validation directory.")
        return None, None
    # One-hot labels are only comparable if both splits map classes identically.
    if train_generator.class_indices != validation_generator.class_indices:
        st.error("Training and validation directories must contain the same class folders.")
        return None, None

    run_dir = create_model_directories()

    cnn_model = _build_cnn_model((img_height, img_width, 3), train_generator.num_classes)
    cnn_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    # NOTE(review): EarlyStopping tracks val_loss while ModelCheckpoint tracks
    # val_accuracy, so best_model.h5 and rose_model.h5 may hold weights from
    # different epochs.
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=5,
            restore_best_weights=True,
        ),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(run_dir, 'best_model.h5'),
            monitor='val_accuracy',
            save_best_only=True,
        ),
    ]

    st.write("Starting model training...")
    st.write(f"Number of classes: {train_generator.num_classes}")
    st.write(f"Training samples: {train_generator.samples}")
    st.write(f"Validation samples: {validation_generator.samples}")
    history = cnn_model.fit(
        train_generator,
        epochs=epochs,
        validation_data=validation_generator,
        callbacks=callbacks,
    )

    model_path = os.path.join(run_dir, 'rose_model.h5')
    cnn_model.save(model_path)
    st.success(f"Model saved to: {model_path}")

    # Order names by their label index explicitly instead of relying on dict
    # insertion order.
    class_indices = train_generator.class_indices
    class_names = sorted(class_indices, key=class_indices.get)
    class_names_path = os.path.join(run_dir, 'class_names.json')
    with open(class_names_path, 'w', encoding='utf-8') as f:
        json.dump(class_names, f)
    st.success(f"Class names saved to: {class_names_path}")

    history_path = os.path.join(run_dir, 'training_history.png')
    _save_history_plot(history, history_path)
    return history_path, run_dir
def main():
    """Render the training form and, when requested, run a training job.

    Three sliders collect epochs, batch size and learning rate. Pressing
    "Start Training" calls ``train_model`` under a spinner. If it returns
    artifact paths, the page shows the saved training-curve image and the
    run folder.
    """
    st.title("Rose Classification Model Training")
    st.write("Train your rose classification model with custom parameters")

    left_col, right_col = st.columns(2)
    with left_col:
        epochs = st.slider(
            "Number of Epochs", min_value=1, max_value=100, value=50, step=1
        )
        batch_size = st.slider(
            "Batch Size", min_value=8, max_value=64, value=32, step=8
        )
    with right_col:
        learning_rate = st.slider(
            "Learning Rate",
            min_value=0.0001,
            max_value=0.01,
            value=0.001,
            step=0.0001,
        )

    # Nothing more to do until the button is clicked.
    if not st.button("Start Training"):
        return

    with st.spinner("Training in progress..."):
        history_path, run_dir = train_model(epochs, batch_size, learning_rate)

    # train_model returns (None, None) after it has already reported an error.
    if not (history_path and run_dir):
        return
    st.image(history_path, caption="Training History")
    st.success(f"Training completed! Files saved in: {run_dir}")
# Launch the Streamlit page only when executed as a script, not on import.
if __name__ == "__main__":
    main()