Spaces:
Sleeping
Sleeping
File size: 4,582 Bytes
294928d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import os
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks
import matplotlib.pyplot as plt
import numpy as np
# Seed the TensorFlow and NumPy global random generators with one shared
# value so repeated runs start from the same random state.
_SEED = 42
tf.random.set_seed(_SEED)
np.random.seed(_SEED)
def load_and_preprocess_data():
    """Load the CIFAR-10 dataset and scale pixel values into [0, 1].

    The images are cast to float32 before scaling. Dividing the integer
    pixel arrays directly by a Python float would promote them to float64,
    which doubles the memory held by the dataset for no benefit.

    Returns:
        Two ``(images, labels)`` tuples: the training split first, then the
        test split. Images are float32 arrays; labels are passed through
        exactly as returned by ``tf.keras.datasets.cifar10.load_data()``.
    """
    print("Loading CIFAR-10 dataset...")
    (train_images, train_labels), (test_images, test_labels) = (
        tf.keras.datasets.cifar10.load_data()
    )

    # Cast first so the division stays in float32 instead of float64.
    train_images = train_images.astype("float32") / 255.0
    test_images = test_images.astype("float32") / 255.0

    print(f"Train images shape: {train_images.shape}")
    print(f"Test images shape: {test_images.shape}")
    return (train_images, train_labels), (test_images, test_labels)
def build_cnn_model():
    """Build and compile the CNN used for CIFAR-10 classification.

    The network is three convolutional blocks followed by a dense head.
    Each block is two 3x3 'same'-padded ReLU convolutions (each followed by
    batch normalization), a 2x2 max-pool and a dropout layer. The head
    flattens the features, applies a 128-unit ReLU layer with batch
    normalization and dropout, and ends in a 10-way softmax.

    Returns:
        A ``Sequential`` model compiled with the 'adam' optimizer,
        'sparse_categorical_crossentropy' loss and the 'accuracy' metric.
    """
    print("Building CNN architecture...")

    # (number of filters, dropout rate) for Block 1, Block 2 and Block 3.
    block_specs = [(32, 0.2), (64, 0.3), (128, 0.4)]

    stack = []
    for filters, drop_rate in block_specs:
        for _ in range(2):
            conv_kwargs = {'padding': 'same', 'activation': 'relu'}
            if not stack:
                # Only the very first layer declares the input image shape.
                conv_kwargs['input_shape'] = (32, 32, 3)
            stack.append(layers.Conv2D(filters, (3, 3), **conv_kwargs))
            stack.append(layers.BatchNormalization())
        stack.append(layers.MaxPooling2D((2, 2)))
        stack.append(layers.Dropout(drop_rate))

    # Classification head
    stack.append(layers.Flatten())
    stack.append(layers.Dense(128, activation='relu'))
    stack.append(layers.BatchNormalization())
    stack.append(layers.Dropout(0.5))
    stack.append(layers.Dense(10, activation='softmax'))

    model = models.Sequential(stack)
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model
def train_and_evaluate():
    """Train the CIFAR-10 CNN, evaluate it on the test set and save it.

    Validation uses a held-out slice of the training data rather than the
    test set. Both callbacks monitor 'val_loss', and early stopping restores
    the best weights, so validating on the test images would tune the model
    on the same data used for the final accuracy figure.

    Side effects:
        Prints progress messages and the model summary, and writes the
        trained model to 'cifar10_cnn_v1.h5' in the current directory.

    Returns:
        The history object returned by ``model.fit``.
    """
    # 1. Prepare Data
    (train_images, train_labels), (test_images, test_labels) = load_and_preprocess_data()

    # 2. Build Model
    model = build_cnn_model()
    model.summary()

    # 3. Callbacks
    early_stop = callbacks.EarlyStopping(
        monitor='val_loss', patience=10, restore_best_weights=True
    )
    reduce_lr = callbacks.ReduceLROnPlateau(
        monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6
    )

    # 4. Train
    print("\nStarting training (limited to 2 epochs for demonstration)...")
    history = model.fit(
        train_images, train_labels,
        epochs=2,  # Increase to 50 for full training
        batch_size=64,
        # Hold out 10% of the training data for validation so the test set
        # stays unseen until the final evaluation below.
        validation_split=0.1,
        callbacks=[early_stop, reduce_lr],
    )

    # 5. Evaluate
    print("\nEvaluating model...")
    test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
    print(f"\nFinal Test Accuracy: {test_acc*100:.2f}%")

    # 6. Save Model
    model.save('cifar10_cnn_v1.h5')
    print("Model saved to cifar10_cnn_v1.h5")
    return history
def plot_results(history):
    """Plot training and validation accuracy/loss curves and save them.

    Two side-by-side panels are drawn (accuracy on the left, loss on the
    right) and written to 'training_plot.png' in the current directory.
    The figure is closed once saved, so repeated calls do not accumulate
    open figures in pyplot.

    Args:
        history: Object whose ``history`` attribute is a mapping holding the
            per-epoch sequences 'accuracy', 'val_accuracy', 'loss' and
            'val_loss'.

    Raises:
        KeyError: If any of the four expected entries is missing.
    """
    metrics = history.history
    acc = metrics['accuracy']
    val_acc = metrics['val_accuracy']
    loss = metrics['loss']
    val_loss = metrics['val_loss']
    epochs_range = range(len(acc))

    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        acc_ax.plot(epochs_range, acc, label='Training Accuracy')
        acc_ax.plot(epochs_range, val_acc, label='Validation Accuracy')
        acc_ax.set_title('Accuracy')
        acc_ax.legend()

        loss_ax.plot(epochs_range, loss, label='Training Loss')
        loss_ax.plot(epochs_range, val_loss, label='Validation Loss')
        loss_ax.set_title('Loss')
        loss_ax.legend()

        fig.savefig('training_plot.png')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    print("Training plots saved as training_plot.png")
# Script entry point: train, evaluate and plot; report any failure and exit
# with a non-zero status so shells and CI jobs can detect it.
if __name__ == "__main__":
    try:
        hist = train_and_evaluate()
        plot_results(hist)
    except Exception as e:
        print(f"\n[ERROR] An error occurred: {e}")
        print("\nNote: If you encounter DLL errors or ModuleNotFound errors, please ensure "
              "TensorFlow is correctly installed in your environment (e.g., pip install tensorflow).")
        # NOTE: errors raised by the module-level imports happen before this
        # guard runs, so they are not caught here.
        # Without this the script would finish with exit status 0 even though
        # training or plotting failed.
        raise SystemExit(1)
|