metadata
title: Custom ResNet on CIFAR10 using PyTorch Lightning and GradCAM
emoji: πŸƒ
colorFrom: indigo
colorTo: pink
sdk: gradio
sdk_version: 3.39.0
app_file: app.py
pinned: false
license: mit

Custom ResNet on CIFAR10 using PyTorch Lightning and GradCAM

Introduction

This repository contains an application for CIFAR-10 classification using PyTorch Lightning. Image classification is implemented with a custom ResNet. The application also supports viewing misclassified images and GradCAM visualizations.

Installation

To run this application locally or on Colab, keep the following folder structure:

    CIFAR10 Image Classification
    |── requirements.txt
    |── resnet.py
    |── app.py
    |── resnet_model_v2.pth
    |── examples
    |   |── bird.png
    |   |── plane.png
    |── README.md

  1. Clone this repository:

    git lfs install
    git clone https://huggingface.co/spaces/PrarthanaTS/Cifar10

  2. Run the app.py script, for example with python app.py once the packages listed in requirements.txt are installed.

The app.py Python file calls demo.launch(), which starts a web-based interface using Gradio. You can access the interface by opening the provided URL in your web browser.

Usage

The app has two tabs:

Top Classes and Prediction

In this tab, we can upload our own image or choose one of the provided example images, then classify it and view the class-wise probabilities.

  1. Input Image: Upload our own image or select one of the provided example images.
  2. Transparency: The transparency of the GradCAM overlay. The default value is 0.5.
  3. Network Layers: The target layer for the GradCAM visualization. GradCAM visualizations work better for layers -2 and -1. The default value is -2.
  4. Top Classes: The number of top predicted classes to display along with their probability (confidence). Values range from 1 to 10, and the default is 10.

After providing the settings values, click on the "Submit" button to see the results.
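As a rough illustration of the Top Classes setting, the sketch below turns raw model outputs into probabilities and keeps the k most likely classes. It is not the code from app.py: the helper names softmax and top_classes and the example logits are hypothetical, and only the class list comes from this README.

```python
import math

# Class names as listed in the training configuration of this README.
CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']


def softmax(logits):
    """Convert raw model outputs into probabilities that sum to 1."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_classes(logits, k=10):
    """Return the k most likely classes as {name: probability}, best first."""
    if not 1 <= k <= len(CLASSES):
        raise ValueError("k must be between 1 and 10")
    probs = softmax(logits)
    ranked = sorted(zip(CLASSES, probs), key=lambda pair: pair[1], reverse=True)
    return dict(ranked[:k])


# Hypothetical logits for a single image (not produced by the real model).
example_logits = [2.0, 0.5, 4.0, 1.0, 0.0, 0.5, -1.0, 0.0, 1.5, 0.2]
top3 = top_classes(example_logits, k=3)
print(list(top3))
```

A dictionary that maps class names to confidences is the kind of value a Gradio label output can display, which is presumably how the app shows its ranking.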

The output shows:

  1. Top Classes for the image, each with a confidence score (from 0 to 100)
  2. Model Prediction for the image as a GradCAM visualization
Output1
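To make the GradCAM output and the Transparency setting more concrete, here is a minimal plain-Python sketch of the idea. The app itself relies on the pytorch_grad_cam library (see Acknowledgments), so this is an illustration only: the function names grad_cam_map and overlay are hypothetical, the inputs are tiny hand-written grids, and it assumes that transparency is the weight given to the original image and that the heatmap already has the image's size.

```python
def grad_cam_map(activations, gradients):
    """Combine feature maps into one heatmap, Grad-CAM style.

    activations, gradients: lists of K channels, each an H x W nested list
    taken from the chosen target layer for one image and one class.
    """
    height, width = len(activations[0]), len(activations[0][0])
    # One weight per channel: the spatial mean of that channel's gradients.
    weights = [sum(map(sum, grad)) / (height * width) for grad in gradients]
    cam = [[0.0] * width for _ in range(height)]
    for weight, channel in zip(weights, activations):
        for i in range(height):
            for j in range(width):
                cam[i][j] += weight * channel[i][j]
    # ReLU keeps only the regions that push the class score up.
    cam = [[max(value, 0.0) for value in row] for row in cam]
    # Scale to [0, 1] so the map can be drawn as a heatmap.
    peak = max(max(row) for row in cam)
    if peak > 0:
        cam = [[value / peak for value in row] for row in cam]
    return cam


def overlay(image, cam, transparency=0.5):
    """Blend a grayscale image with the heatmap (both H x W, values in [0, 1]).

    Assumption: transparency weights the image, 1 - transparency the heatmap.
    """
    return [[transparency * pixel + (1.0 - transparency) * heat
             for pixel, heat in zip(image_row, cam_row)]
            for image_row, cam_row in zip(image, cam)]


# Two hypothetical 2x2 feature maps with their gradients.
activations = [[[1.0, 0.0], [0.0, 2.0]],
               [[0.0, 1.0], [1.0, 0.0]]]
gradients = [[[1.0, 1.0], [1.0, 1.0]],
             [[-1.0, -1.0], [-1.0, -1.0]]]
image = [[1.0, 1.0], [0.0, 0.0]]

cam = grad_cam_map(activations, gradients)
blended = overlay(image, cam, transparency=0.5)
```

In a real pipeline the heatmap is also resized to the input resolution and color-mapped before blending; both steps are left out here to keep the sketch short.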

Misclassified Images

In this tab, we choose how many misclassified images to view. It shows the GradCAM and original-image visualizations for the CIFAR10 custom ResNet model.

  1. Misclassified Inputs: The number of misclassified examples to be displayed
  2. Enable GradCAM: Check this box to display the GradCAM overlay on the misclassified images. Uncheck it to view only the original images.
  3. Network Layer: The target layer for the GradCAM visualization. GradCAM visualizations work better for layers -2 and -1. The default value is -2.
  4. Transparency: The transparency of the GradCAM overlay. The default value is 0.5.

After providing the settings values, click on the "Submit" button to see the results.
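The Misclassified Inputs setting amounts to picking the first few examples where the prediction and the true label disagree. The sketch below shows that selection step under stated assumptions: first_misclassified is a hypothetical helper, and the prediction and target lists are made-up class indices rather than real model output.

```python
# Class names as listed in the training configuration of this README.
CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']


def first_misclassified(predictions, targets, count):
    """Return up to `count` (index, predicted, actual) triples where the model was wrong."""
    found = []
    for index, (predicted, actual) in enumerate(zip(predictions, targets)):
        if len(found) >= count:
            break  # enough examples collected
        if predicted != actual:
            found.append((index, CLASSES[predicted], CLASSES[actual]))
    return found


# Hypothetical class indices for six test images.
predictions = [3, 5, 8, 0, 2, 9]
targets = [3, 3, 8, 2, 2, 1]

wrong = first_misclassified(predictions, targets, count=2)
for index, predicted, actual in wrong:
    print(f"image {index}: predicted {predicted}, actual {actual}")
```

Each selected image can then be shown as is, or with the GradCAM overlay when Enable GradCAM is checked.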

The output shows:

  1. Model Prediction using GradCAM Visualization for the provided number of images
  2. Model Prediction using original images visualization for the provided number of images
2

Training Code

  1. The notebook for this assignment can be accessed here: Assignment 12
  2. The CustomResNet LightningModule defines the following methods:
    1. forward - The forward function defines the forward pass of the model. It takes an input tensor x and passes it through the layers of the model to produce the output predictions.
    2. training_step - The training_step function defines the operations performed during a single training step. It takes a batch of training data and computes the loss and any other relevant metrics for that batch.
    3. validation_step - The validation_step function defines the operations performed during a single validation step. It takes a batch of validation data and computes the loss and metrics for that batch.
    4. test_step - The test_step function defines the operations performed during a single test step. It takes a batch of test data and computes the loss and metrics for that batch.
    5. configure_optimizers - The configure_optimizers function defines the optimizer(s) used during training. It allows you to specify the optimization algorithm and its hyperparameters.
    6. prepare_data - The prepare_data function is used for data preparation tasks that need to be performed only once, such as downloading and preprocessing the dataset.
    7. setup - The setup function is used for any setup or initialization tasks that need to be performed on each GPU or distributed process before training begins.
    8. train_dataloader - The train_dataloader function defines the data loader for the training dataset. It specifies how training batches are sampled and prepared for training. The validation and test data loaders are defined in the same way.
    9. show_misclassified_images - The show_misclassified_images function is a custom function that can be used to visualize misclassified images during evaluation.
  3. The configuration for the notebook is given below. It defines the classes, learning rate, batch size, and epoch settings:

    config = {
        'batch_size': 512,
        'data_dir': './data',
        'classes': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
        'num_classes': 10,
        'lr': 0.01,
        'max_lr': 0.1,
        'max_lr_epoch': 5,
        'dropout': 0.01,
    }

  4. TensorBoard is used with Lightning to visualize and monitor the training.

    image

  5. The model is saved using:

    torch.save(model.state_dict(), "resnet_model_v2.pth")
  6. Gradio is then used to build the app's interface. App-related information can be found here: Custom ResNet App
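The lr, max_lr, and max_lr_epoch values in the configuration suggest a one-cycle style learning-rate schedule, although this README does not say how they are used. The sketch below is therefore an assumption: a simplified piecewise-linear schedule that climbs from lr to max_lr by epoch max_lr_epoch and then falls back to lr. The function name one_cycle_lr and the total of 24 epochs are hypothetical, and the notebook may well use a library scheduler with different details.

```python
# Values copied from the configuration in this README.
config = {'lr': 0.01, 'max_lr': 0.1, 'max_lr_epoch': 5}

# Hypothetical: the total number of epochs is not given in the configuration.
TOTAL_EPOCHS = 24


def one_cycle_lr(epoch, total_epochs, base_lr, max_lr, max_lr_epoch):
    """Piecewise-linear learning rate at a given epoch (0 <= epoch <= total_epochs)."""
    if not 0 < max_lr_epoch < total_epochs:
        raise ValueError("max_lr_epoch must lie strictly between 0 and total_epochs")
    if epoch <= max_lr_epoch:
        # Warm-up: climb linearly from base_lr to max_lr.
        return base_lr + (max_lr - base_lr) * epoch / max_lr_epoch
    # Cool-down: fall linearly from max_lr back to base_lr at the last epoch.
    remaining = (total_epochs - epoch) / (total_epochs - max_lr_epoch)
    return base_lr + (max_lr - base_lr) * remaining


schedule = [
    one_cycle_lr(epoch, TOTAL_EPOCHS, config['lr'], config['max_lr'], config['max_lr_epoch'])
    for epoch in range(TOTAL_EPOCHS + 1)
]
for epoch in (0, 5, 12, 24):
    print(f"epoch {epoch:2d}: lr = {schedule[epoch]:.4f}")
```

Plotting a schedule like this next to the learning-rate curve logged in TensorBoard is a quick way to check that the scheduler behaves as intended.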

License

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgments

The custom ResNet model is inspired by the ResNet architecture (https://github.com/davidcpage/cifar10-fast). The GradCAM implementation is based on the pytorch_grad_cam library (https://github.com/jacobgil/pytorch-grad-cam).

Conclusion

PyTorch Lightning has proven to be an invaluable tool for training and evaluating a CIFAR-10 image classification model.

  1. Its clean and modular design allowed us to focus on building and refining the model architecture. With PyTorch Lightning, we could seamlessly organize our data loaders, model definition, and training loops, resulting in a more maintainable and scalable codebase.
  2. Integrating Grad-CAM visualization further enhanced our understanding of the model's inner workings. Grad-CAM provided us with insightful heatmaps, highlighting the regions of the input images that contributed most significantly to the model's predictions. This visualization technique not only improved interpretability but also instilled confidence in the model's performance and decision-making process.
  3. TensorBoard logging was instrumental in monitoring the training process effectively. We could easily track crucial training metrics such as loss, accuracy, and learning rate across epochs.
  4. By combining PyTorch Lightning, Grad-CAM visualization, and TensorBoard logging, we were able to build, analyze, and fine-tune a robust CIFAR-10 image classification model efficiently.

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference