metadata
title: CIFAR Classification
emoji: 📊
colorFrom: blue
colorTo: green
sdk: gradio
sdk_version: 4.27.0
app_file: app.py
pinned: false
license: mit

CIFAR-10 ResNet18 Model with GradCAM

This repository contains a Gradio interface for inference on a ResNet18 model trained on the CIFAR-10 dataset. Additionally, it provides the ability to visualize GradCAM (Gradient-weighted Class Activation Mapping) results, which highlight the regions of input images that contributed most to the model's predictions.
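
The GradCAM computation described above can be sketched in a few lines. This is a minimal, dependency-free illustration rather than the code used in this repository: the function name `grad_cam` and the nested-list inputs are assumptions made for the example, and a real implementation would work on tensors captured from the chosen ResNet18 layer. The final normalization to [0, 1] is a common convenience for rendering, not part of the core definition.

```python
def grad_cam(activations, gradients):
    """Compute a Grad-CAM heatmap from one layer's feature maps.

    activations: list of K channels, each an H x W nested list of floats.
    gradients:   same shape; d(target class score) / d(activation).
    (Hypothetical helper for illustration only.)
    """
    height, width = len(activations[0]), len(activations[0][0])

    # Channel weight = spatial mean of that channel's gradients.
    weights = [
        sum(sum(row) for row in grad) / (height * width)
        for grad in gradients
    ]

    # Weighted sum of the feature maps, then ReLU to keep positive evidence.
    cam = [
        [
            max(0.0, sum(w * act[y][x] for w, act in zip(weights, activations)))
            for x in range(width)
        ]
        for y in range(height)
    ]

    # Scale to [0, 1] so the map can be rendered as a heatmap.
    peak = max(max(row) for row in cam)
    if peak > 0:
        cam = [[value / peak for value in row] for row in cam]
    return cam
```

The resulting map has the spatial size of the selected layer, so it is normally upsampled to the input resolution before being drawn over the image.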

Gradio Interface Features

The Gradio interface allows users to:

  • Upload Images: Users can upload images to the interface for inference.
  • Enable GradCAM Visualization: Users can choose to enable GradCAM visualization, which highlights the regions of the input image that influenced the model's prediction the most.
  • Adjust Image Opacity: Users can adjust the overall opacity of the input image and the GradCAM visualization using a slider.
  • Select Layer: Users can select the layer from which GradCAM will generate the visualization. The default, -1, selects the last layer of the network.
  • Choose Number of Top Classes: Users can specify the number of top predicted classes to display along with their confidence scores.
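
As an illustration of what the opacity slider controls, here is a minimal sketch of blending an image with a heatmap. It assumes that the slider value is the weight given to the input image and that the heatmap receives the remainder; the README does not state which way the slider runs. The helper name `overlay` is hypothetical, and single-channel nested lists stand in for real image arrays.

```python
def overlay(image, heatmap, opacity):
    """Blend a heatmap over an image.

    image, heatmap: same-sized nested lists of floats in [0, 1].
    opacity: weight of the input image (assumed meaning of the slider);
             the heatmap gets weight 1 - opacity.
    """
    if not 0.0 <= opacity <= 1.0:
        raise ValueError("opacity must be between 0 and 1")
    return [
        [
            opacity * img_px + (1.0 - opacity) * heat_px
            for img_px, heat_px in zip(img_row, heat_row)
        ]
        for img_row, heat_row in zip(image, heatmap)
    ]
```

Under this assumption, an opacity of 1 shows only the input image and an opacity of 0 shows only the GradCAM heatmap.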

Output

After uploading an image and selecting the desired options, the interface provides the following outputs:

  • Predicted Category: The predicted category/class label based on the input image.
  • Output: The output image with GradCAM visualization overlaid (if enabled) and adjusted opacity.
  • Confidence Scores: The confidence scores of the top predicted classes, along with their corresponding labels.
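
The predicted category and confidence scores can be derived from the model's raw outputs roughly as follows. This is a sketch under assumptions: the function name `top_k_confidences` is hypothetical, the model is assumed to emit one logit per class in the standard CIFAR-10 class order, and confidences are assumed to be softmax probabilities.

```python
import math

# Standard CIFAR-10 class names in label-index order.
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

def top_k_confidences(logits, k=3):
    """Return the k most likely classes as an ordered {label: confidence} dict."""
    if len(logits) != len(CIFAR10_CLASSES):
        raise ValueError("expected one logit per CIFAR-10 class")

    # Softmax, shifted by the max logit for numerical stability.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Sort classes by probability, highest first, and keep the top k.
    ranked = sorted(zip(CIFAR10_CLASSES, probs), key=lambda pair: pair[1], reverse=True)
    return dict(ranked[:k])
```

The first key of the returned dictionary is the predicted category, and a label-to-confidence mapping of this shape is the kind of value a Gradio label output can display.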

About the Model

This repository contains an implementation of ResNet (Residual Neural Network) for CIFAR-10 classification using PyTorch Lightning. ResNet is a deep learning architecture known for its effectiveness in training very deep neural networks. PyTorch Lightning is a lightweight PyTorch wrapper that simplifies the training process and provides various utilities for research and production-level training.

Training Performance

The model was trained for 20 epochs on the CIFAR-10 dataset using a batch size of 512. The training logs show the following metrics:

  • Final Validation Accuracy: 80.83%
  • Final Validation Loss: 0.846

During training, the validation accuracy steadily increased, indicating that the model was learning to generalize well to unseen data. The validation loss also decreased over epochs, suggesting that the model's predictions aligned better with the ground truth labels.
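
To put the training configuration in perspective, the short calculation below estimates how many optimizer steps 20 epochs at batch size 512 amount to. It assumes the full standard CIFAR-10 training split of 50,000 images was used and that the last partial batch was kept; neither detail is stated here, and the count would be lower if part of the training data was held out for validation.

```python
import math

TRAIN_IMAGES = 50_000  # standard CIFAR-10 training split (assumed used in full)
BATCH_SIZE = 512       # from the training description
EPOCHS = 20            # from the training description

def steps_per_epoch(num_images, batch_size, drop_last=False):
    """Number of batches needed to see every image once."""
    if drop_last:
        return num_images // batch_size
    return math.ceil(num_images / batch_size)

per_epoch = steps_per_epoch(TRAIN_IMAGES, BATCH_SIZE)
total_steps = per_epoch * EPOCHS
print(per_epoch, total_steps)  # prints: 98 1960
```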

Test Performance

After training, the model's performance was evaluated on the test dataset. The test logs report the following metrics:

  • Test Accuracy: 80.83%
  • Test Loss: 0.846

The test accuracy and loss match the validation metrics, indicating that the model generalizes to unseen data. With over 80% accuracy on both the validation and test sets, these results support the effectiveness of the ResNet architecture and the training strategy used in this project for CIFAR-10 image classification.
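
For readers unfamiliar with how such metrics are produced, the sketch below computes accuracy and loss from per-image class probabilities. It assumes the reported loss is a mean cross-entropy, which is the usual choice for classification but is not stated in the logs; the function name `evaluate` and its inputs are hypothetical.

```python
import math

def evaluate(probabilities, targets):
    """Return (accuracy, mean cross-entropy) for a set of predictions.

    probabilities: one list of class probabilities per image.
    targets:       the true class index for each image.
    """
    if not targets or len(probabilities) != len(targets):
        raise ValueError("need one non-empty probability list per target")

    correct = 0
    total_loss = 0.0
    for probs, target in zip(probabilities, targets):
        # The predicted class is the index with the highest probability.
        predicted = max(range(len(probs)), key=probs.__getitem__)
        correct += int(predicted == target)
        # Cross-entropy penalizes low probability on the true class.
        total_loss += -math.log(probs[target])

    count = len(targets)
    return correct / count, total_loss / count
```

Under that assumption, a loss of 0.846 corresponds to a geometric-mean probability of about 0.43 on the correct class, since exp(-0.846) is roughly 0.43.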

Loss Graph

[Image: loss graph]