# CNN2 / train_cifar10.py
# Uploaded by d-e-e-k-11
# Commit message: Upload folder using huggingface_hub
# Commit: 294928d (verified)
import os
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks
import matplotlib.pyplot as plt
import numpy as np
# Pin the NumPy and TensorFlow global seeds to one shared value for reproducibility.
_SEED = 42
np.random.seed(_SEED)
tf.random.set_seed(_SEED)
def load_and_preprocess_data():
    """Load the CIFAR-10 dataset and scale pixel values into [0, 1].

    The image arrays are cast to ``float32`` *before* dividing by 255.
    Dividing the raw integer arrays (uint8 per the Keras dataset docs) by the
    Python float ``255.0`` directly would promote them to ``float64``, which
    doubles the memory footprint without any benefit for training.

    Returns:
        ``((train_images, train_labels), (test_images, test_labels))`` where
        the image arrays are ``float32`` scaled by 1/255 and the label arrays
        are passed through exactly as returned by ``cifar10.load_data()``.
    """
    print("Loading CIFAR-10 dataset...")
    (train_images, train_labels), (test_images, test_labels) = (
        tf.keras.datasets.cifar10.load_data()
    )

    # Cast first so the division is carried out in float32, not float64.
    train_images = train_images.astype("float32") / 255.0
    test_images = test_images.astype("float32") / 255.0

    print(f"Train images shape: {train_images.shape}")
    print(f"Test images shape: {test_images.shape}")

    return (train_images, train_labels), (test_images, test_labels)
def _conv_block(filters, dropout_rate):
    """Return the layers of one convolutional block as a list.

    A block is two (3x3 same-padded ReLU Conv2D -> BatchNormalization) pairs,
    followed by 2x2 max pooling and dropout at ``dropout_rate``.
    """
    block = []
    for _ in range(2):
        block.append(layers.Conv2D(filters, (3, 3), padding='same', activation='relu'))
        block.append(layers.BatchNormalization())
    block.append(layers.MaxPooling2D((2, 2)))
    block.append(layers.Dropout(dropout_rate))
    return block


def build_cnn_model():
    """Build and compile the CNN used for CIFAR-10 classification.

    Architecture: three convolutional blocks (32, 64 and 128 filters with
    dropout 0.2, 0.3 and 0.4), then a classification head of
    Flatten -> Dense(128, relu) -> BatchNormalization -> Dropout(0.5)
    -> Dense(10, softmax).

    The input shape (32, 32, 3) is declared with an explicit ``Input`` object
    instead of the ``input_shape=`` keyword on the first layer; newer Keras
    releases warn about the keyword form in Sequential models.

    Returns:
        The model, compiled with the Adam optimizer, sparse categorical
        cross-entropy loss and the accuracy metric.
    """
    print("Building CNN architecture...")

    model = models.Sequential([
        layers.Input(shape=(32, 32, 3)),
        # Blocks 1-3: the filter count doubles while dropout grows with depth
        *_conv_block(32, 0.2),
        *_conv_block(64, 0.3),
        *_conv_block(128, 0.4),
        # Classification head
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(10, activation='softmax'),
    ])

    model.compile(
        optimizer='adam',
        # Sparse variant: takes integer class labels rather than one-hot vectors
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model
def train_and_evaluate(epochs=2, batch_size=64, validation_split=0.1):
    """Train the CNN on CIFAR-10, evaluate it on the test set, and save it.

    Validation uses a slice of the *training* data rather than the test set.
    Both callbacks below monitor ``val_loss`` (and early stopping restores the
    best weights according to it), so validating on the test set would tune
    the model on the very data used for the final score and make the reported
    test accuracy optimistic.

    Args:
        epochs: Maximum number of training epochs. The default of 2 is a
            short demonstration run; increase to 50 for full training.
        batch_size: Mini-batch size passed to ``model.fit``.
        validation_split: Fraction of the training data held out for
            validation and excluded from training.

    Returns:
        The ``History`` object returned by ``model.fit``.

    Side effects:
        Prints progress messages and the model summary, and writes the
        trained model to ``cifar10_cnn_v1.h5`` in the working directory.
    """
    # 1. Prepare data
    (train_images, train_labels), (test_images, test_labels) = load_and_preprocess_data()

    # 2. Build model
    model = build_cnn_model()
    model.summary()

    # 3. Callbacks, both driven by the validation loss
    early_stop = callbacks.EarlyStopping(
        monitor='val_loss', patience=10, restore_best_weights=True
    )
    reduce_lr = callbacks.ReduceLROnPlateau(
        monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6
    )

    # 4. Train; the test set is deliberately kept out of the fitting loop
    print(f"\nStarting training (limited to {epochs} epochs for demonstration)...")
    history = model.fit(
        train_images, train_labels,
        epochs=epochs,
        batch_size=batch_size,
        validation_split=validation_split,
        callbacks=[early_stop, reduce_lr],
    )

    # 5. Evaluate once on the untouched test set
    print("\nEvaluating model...")
    test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
    print(f"\nFinal Test Accuracy: {test_acc*100:.2f}%")

    # 6. Save model
    model.save('cifar10_cnn_v1.h5')
    print("Model saved to cifar10_cnn_v1.h5")

    return history
def plot_results(history):
    """Draw accuracy and loss curves side by side and save them to disk.

    Args:
        history: Object whose ``history`` attribute is a mapping containing
            the ``accuracy``, ``val_accuracy``, ``loss`` and ``val_loss``
            series (one value per epoch).

    Side effects:
        Writes the figure to ``training_plot.png`` and prints a confirmation.
    """
    record = history.history

    # Read every series up front so a missing key fails before any drawing.
    train_acc = record['accuracy']
    valid_acc = record['val_accuracy']
    train_loss = record['loss']
    valid_loss = record['val_loss']

    x_axis = range(len(train_acc))

    # One entry per subplot: (title, [(legend label, series), ...])
    panels = (
        ('Accuracy', (('Training Accuracy', train_acc),
                      ('Validation Accuracy', valid_acc))),
        ('Loss', (('Training Loss', train_loss),
                  ('Validation Loss', valid_loss))),
    )

    plt.figure(figsize=(12, 5))
    for position, (title, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for label, series in curves:
            plt.plot(x_axis, series, label=label)
        plt.title(title)
        plt.legend()

    plt.savefig('training_plot.png')
    print("Training plots saved as training_plot.png")
if __name__ == "__main__":
    # Script entry point: train and evaluate the model, then plot its history.
    try:
        hist = train_and_evaluate()
        plot_results(hist)
    except Exception as e:
        print(f"\n[ERROR] An error occurred: {e}")
        print("\nNote: If you encounter DLL errors or ModuleNotFound errors, please ensure "
              "TensorFlow is correctly installed in your environment (e.g., pip install tensorflow).")
        # Re-raise after printing the hint so the full traceback is shown and
        # the process exits with a non-zero status instead of reporting success.
        raise