Spaces:
Sleeping
Sleeping
| import os | |
| import tensorflow as tf | |
| from tensorflow.keras import layers, models, callbacks | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
# Seed NumPy's legacy global RNG and TensorFlow's global seed with one shared
# value so that repeated runs of this script start from the same random state.
# NOTE(review): seeding alone may not make GPU training bit-for-bit
# reproducible; confirm whether op-level determinism is also required.
_SEED = 42
np.random.seed(_SEED)
tf.random.set_seed(_SEED)
def load_and_preprocess_data(loader=None):
    """Load the CIFAR-10 dataset and scale pixel values into [0, 1].

    Args:
        loader: Optional zero-argument callable returning
            ``((train_images, train_labels), (test_images, test_labels))`` in
            the same layout as ``tf.keras.datasets.cifar10.load_data``.
            Defaults to that function; pass a substitute to reuse a cached
            copy or to exercise this function without downloading anything.

    Returns:
        ``((train_images, train_labels), (test_images, test_labels))`` where
        both image arrays are float32 divided by 255 (so in [0, 1] for uint8
        input) and the label arrays are returned unchanged.
    """
    if loader is None:
        loader = tf.keras.datasets.cifar10.load_data
    print("Loading CIFAR-10 dataset...")
    (train_images, train_labels), (test_images, test_labels) = loader()
    # Cast to float32 *before* dividing: ``uint8_array / 255.0`` yields
    # float64, roughly doubling memory (~1.2 GB vs ~0.6 GB for the 50k
    # training images) even though the model computes in float32.
    train_images = train_images.astype("float32") / 255.0
    test_images = test_images.astype("float32") / 255.0
    print(f"Train images shape: {train_images.shape}")
    print(f"Test images shape: {test_images.shape}")
    return (train_images, train_labels), (test_images, test_labels)
def _conv_block(filters, dropout_rate):
    """Return the layers of one Conv-BN-Conv-BN-MaxPool-Dropout stage."""
    return [
        layers.Conv2D(filters, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.Conv2D(filters, (3, 3), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(dropout_rate),
    ]


def build_cnn_model(input_shape=(32, 32, 3), num_classes=10):
    """Build and compile a three-stage convolutional classifier.

    Each stage is two 3x3 same-padded Conv2D layers (each followed by batch
    normalization), a 2x2 max-pool and dropout. Filter counts are 32/64/128
    and dropout rates are 0.2/0.3/0.4. The head is Flatten -> Dense(128) ->
    BatchNorm -> Dropout(0.5) -> softmax over ``num_classes``.

    Args:
        input_shape: ``(height, width, channels)`` of one input image. The
            three pooling stages halve the spatial size three times, so
            height and width should be at least 8.
        num_classes: Number of output classes.

    Returns:
        A compiled ``Sequential`` model using Adam and sparse categorical
        cross-entropy, so labels must be integer class ids, not one-hot.
    """
    print("Building CNN architecture...")
    model = models.Sequential([
        # An explicit Input layer replaces passing ``input_shape=`` to the
        # first Conv2D, which Keras 3 deprecates with a UserWarning.
        layers.Input(shape=input_shape),
        *_conv_block(32, 0.2),
        *_conv_block(64, 0.3),
        *_conv_block(128, 0.4),
        # Classification Head
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax'),
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model
def train_and_evaluate(epochs=2, batch_size=64, validation_split=0.1,
                       model_path='cifar10_cnn_v1.h5'):
    """Train the CIFAR-10 CNN, report test accuracy, and save the model.

    Model selection (early stopping, learning-rate reduction and restoring
    the best weights) is driven by a validation slice of the *training*
    data. The test set is used only once, for the final evaluation, so the
    reported test accuracy is not biased by model selection.

    Args:
        epochs: Maximum number of training epochs. The default of 2 keeps
            the demo short; around 50 is a full run, and early stopping may
            end it sooner.
        batch_size: Mini-batch size passed to ``model.fit``.
        validation_split: Fraction of the training arrays held out for
            validation. Keras takes the last samples, before shuffling.
            Must be in (0, 1), because the callbacks monitor ``val_loss``.
        model_path: Destination for ``model.save``. A ``.h5`` suffix selects
            the legacy HDF5 format; Keras 3 recommends ``.keras`` instead.

    Returns:
        The ``History`` object returned by ``model.fit``.

    Raises:
        ValueError: If ``validation_split`` is not strictly between 0 and 1.
    """
    if not 0.0 < validation_split < 1.0:
        raise ValueError(
            f"validation_split must be in (0, 1), got {validation_split!r}")
    # 1. Prepare Data
    (train_images, train_labels), (test_images, test_labels) = load_and_preprocess_data()
    # 2. Build Model
    model = build_cnn_model()
    model.summary()
    # 3. Callbacks (both monitor the held-out training slice, not the test set)
    early_stop = callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
    reduce_lr = callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6)
    # 4. Train
    print(f"\nStarting training (limited to {epochs} epochs for demonstration)...")
    history = model.fit(
        train_images, train_labels,
        epochs=epochs,
        batch_size=batch_size,
        # Validating on (test_images, test_labels) would let the callbacks
        # tune on the test set and leak it into model selection.
        validation_split=validation_split,
        callbacks=[early_stop, reduce_lr],
    )
    # 5. Evaluate on data the training loop has never seen
    print("\nEvaluating model...")
    _, test_acc = model.evaluate(test_images, test_labels, verbose=2)
    print(f"\nFinal Test Accuracy: {test_acc*100:.2f}%")
    # 6. Save Model
    model.save(model_path)
    print(f"Model saved to {model_path}")
    return history
def plot_results(history, save_path='training_plot.png'):
    """Plot training vs. validation accuracy and loss side by side and save them.

    Args:
        history: A Keras ``History`` (or any object with a ``.history`` dict)
            holding per-epoch lists under 'accuracy', 'val_accuracy', 'loss'
            and 'val_loss'. This requires ``metrics=['accuracy']`` and
            validation data during ``fit``.
        save_path: File the figure is written to.
    """
    metrics = history.history
    # Keras logs epochs starting from 1; label the x-axis the same way.
    epochs_range = range(1, len(metrics['accuracy']) + 1)
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))
    for ax, key, title in ((acc_ax, 'accuracy', 'Accuracy'),
                           (loss_ax, 'loss', 'Loss')):
        ax.plot(epochs_range, metrics[key], label=f'Training {title}')
        ax.plot(epochs_range, metrics[f'val_{key}'], label=f'Validation {title}')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.legend()
    fig.savefig(save_path)
    # pyplot keeps every figure alive until explicitly closed; release it.
    plt.close(fig)
    print(f"Training plots saved as {save_path}")
if __name__ == "__main__":
    try:
        hist = train_and_evaluate()
        plot_results(hist)
    except Exception as e:
        print(f"\n[ERROR] An error occurred: {e}")
        # NOTE(review): a failing top-level `import tensorflow` raises before
        # this try block runs, so this hint only appears for errors raised
        # during training or plotting.
        print("\nNote: If you encounter DLL errors or ModuleNotFound errors, please ensure "
              "TensorFlow is correctly installed in your environment (e.g., pip install tensorflow).")
        # Re-raise so the full traceback is shown and the process exits with
        # a non-zero status instead of reporting success after a failure.
        raise