---
tags:
- image_classification
- computer_vision
license: mit
datasets:
- p2pfl/CIFAR10
language:
- en
pipeline_tag: image-classification
metrics:
- f1
---

# SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers

### Model Description

Implementation of the ***SAG-ViT*** model as proposed in the [SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers](https://arxiv.org/abs/2411.09420) paper.

It is a novel transformer framework designed to enhance Vision Transformers (ViT) with scale-awareness and refined patch-level feature embeddings. It extracts multiscale features using EfficientNetV2, organizes patches into a graph based on their spatial relationships, and refines them with a Graph Attention Network (GAT). A Transformer encoder then integrates these embeddings globally, capturing long-range dependencies for comprehensive image understanding.

### Model Architecture



_Image source: [SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers](https://arxiv.org/abs/2411.09420)_

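To make the diagram above concrete, here is a minimal, self-contained sketch of the four stages (EfficientNetV2 features, spatial patch graph, graph attention, Transformer encoder). It is an illustration only: the class `SAGViTSketch`, the layer sizes, the 4-neighbour grid graph, and the use of `torchvision` and `torch_geometric` are assumptions made for this sketch, not the repository's actual implementation.

```python
import torch
import torch.nn as nn
from torchvision.models import efficientnet_v2_s
from torch_geometric.nn import GATConv  # assumption: PyTorch Geometric for the GAT stage


def grid_edges(h, w):
    """Build a 4-neighbour edge_index for an h x w grid of patch nodes."""
    edges = []
    for r in range(h):
        for c in range(w):
            i = r * w + c
            if c + 1 < w:
                edges += [(i, i + 1), (i + 1, i)]
            if r + 1 < h:
                edges += [(i, i + w), (i + w, i)]
    return torch.tensor(edges, dtype=torch.long).t()


class SAGViTSketch(nn.Module):
    def __init__(self, num_classes=10, dim=256, heads=4, depth=2):
        super().__init__()
        self.backbone = efficientnet_v2_s(weights=None).features   # multiscale CNN features
        self.proj = nn.Linear(1280, dim)                            # patch embedding
        self.gat = GATConv(dim, dim, heads=1)                       # graph attention over patch nodes
        layer = nn.TransformerEncoderLayer(dim, heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, depth)          # global integration of embeddings
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        f = self.backbone(x)                                 # (N, 1280, H', W') feature map
        n, c, h, w = f.shape
        patches = self.proj(f.flatten(2).transpose(1, 2))    # one node per spatial location
        edge_index = grid_edges(h, w).to(x.device)           # spatial-relationship graph
        refined = torch.stack([self.gat(p, edge_index) for p in patches])
        return self.head(self.encoder(refined).mean(dim=1))  # classify from pooled tokens


logits = SAGViTSketch()(torch.randn(1, 3, 224, 224))  # -> shape (1, 10)
```

The pretrained checkpoint in this repository already packages the full pipeline; the sketch is only meant to map the description above onto code.
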
### Usage

SAG-ViT expects input images normalized as follows: mini-batches of 3-channel RGB images of shape `(N, 3, H, W)`, where `N` is the number of images and `H` and `W` are expected to be at least `49` pixels. The images have to be loaded into a range of `[0, 1]` and then normalized using `mean = [0.485, 0.456, 0.406]` and `std = [0.229, 0.224, 0.225]`.

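If you preprocess batches yourself rather than following the single-image example below, a minimal sketch of this input contract might look like the following; the batch size and spatial size used here are arbitrary illustrations.

```python
import torch
from torchvision import transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
batch = torch.rand(4, 3, 224, 224)   # 4 RGB images already scaled to [0, 1]
batch = normalize(batch)             # per-channel normalization as described above
assert batch.shape[1] == 3 and batch.shape[-2] >= 49 and batch.shape[-1] >= 49
```
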
To train or run inference with our model, follow these steps:

Clone our repository and load the model pretrained on the CIFAR-10 dataset.
```bash
git clone https://huggingface.co/shravvvv/SAG-ViT
cd SAG-ViT
```

Install the required dependencies.
```bash
pip install -r requirements.txt
```

Use `from_pretrained` to load the model from the Hugging Face Hub and run inference on a sample input image.

```python
from transformers import AutoModel, AutoConfig
from PIL import Image
from torchvision import transforms
import torch

# Step 1: Load the model and configuration directly from the Hugging Face Hub
repo_name = "shravvvv/SAG-ViT"
config = AutoConfig.from_pretrained(repo_name)               # Load config from hub
model = AutoModel.from_pretrained(repo_name, config=config)  # Load model from hub

# Step 2: Define the transformation for the input image
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match the expected input size
    transforms.ToTensor(),          # Scales pixel values to [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Step 3: Load and preprocess the input image
input_image_path = "path/to/your/image.jpg"
img = Image.open(input_image_path).convert("RGB")
img = transform(img).unsqueeze(0)  # Add batch dimension

# Step 4: Ensure the model is in evaluation mode
model.eval()

# Step 5: Run inference
with torch.no_grad():
    outputs = model(img)
    logits = outputs.logits  # Access logits from the ModelOutput

# Step 6: Post-process the predictions
predicted_class_index = torch.argmax(logits, dim=1)  # Get the predicted class index

# CIFAR-10 label mapping
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# Get the predicted class name from the class index
predicted_class_name = class_names[predicted_class_index.item()]
print(f"Predicted class: {predicted_class_name}")
```

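If you also want class probabilities rather than only the argmax, you can apply a softmax to the logits. This is standard PyTorch post-processing, not something specific to the SAG-ViT API; the snippet continues from the example above and reuses its `logits` and `class_names`.

```python
probs = torch.softmax(logits, dim=1)   # convert logits to class probabilities
top_prob, top_idx = probs.max(dim=1)   # highest-probability class per image
print(f"{class_names[top_idx.item()]}: {top_prob.item():.3f}")
```
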
### Running Tests

If you clone our [repository](https://github.com/shravan-18/SAG-ViT), the *tests* folder contains unit tests for each of our model's modules. Make sure you have a Python environment with the required dependencies installed, then run:
```bash
python -m unittest discover -s tests
```

or, if you are using `pytest`, you can run:
```bash
pytest tests
```

### Results

We evaluated SAG-ViT on diverse datasets:
- **CIFAR-10** (natural images)
- **GTSRB** (traffic sign recognition)
- **NCT-CRC-HE-100K** (histopathological images)
- **NWPU-RESISC45** (remote sensing imagery)
- **PlantVillage** (agricultural imagery)

SAG-ViT achieves state-of-the-art results across all benchmarks, as shown in the table below (F1 scores):

<center>

| Backbone | CIFAR-10 | GTSRB | NCT-CRC-HE-100K | NWPU-RESISC45 | PlantVillage |
|--------------------|----------|--------|-----------------|---------------|--------------|
| DenseNet201 | 0.5427 | 0.9862 | 0.9214 | 0.4493 | 0.8725 |
| VGG16 | 0.5345 | 0.8180 | 0.8234 | 0.4114 | 0.7064 |
| VGG19 | 0.5307 | 0.7551 | 0.8178 | 0.3844 | 0.6811 |
| DenseNet121 | 0.5290 | 0.9813 | 0.9247 | 0.4381 | 0.8321 |
| AlexNet | 0.6126 | 0.9059 | 0.8743 | 0.4397 | 0.7684 |
| Inception | 0.7734 | 0.8934 | 0.8707 | 0.8707 | 0.8216 |
| ResNet | 0.9172 | 0.9134 | 0.9478 | 0.9103 | 0.8905 |
| MobileNet | 0.9169 | 0.3006 | 0.4965 | 0.1667 | 0.2213 |
| ViT-S | 0.8465 | 0.8542 | 0.8234 | 0.6116 | 0.8654 |
| ViT-L | 0.8637 | 0.8613 | 0.8345 | 0.8358 | 0.8842 |
| MNASNet1_0 | 0.1032 | 0.0024 | 0.0212 | 0.0011 | 0.0049 |
| ShuffleNet_V2_x1_0 | 0.3523 | 0.4244 | 0.4598 | 0.1808 | 0.3190 |
| SqueezeNet1_0 | 0.4328 | 0.8392 | 0.7843 | 0.3913 | 0.6638 |
| GoogLeNet | 0.4954 | 0.9455 | 0.8631 | 0.3720 | 0.7726 |
| **Proposed (SAG-ViT)** | **0.9574** | **0.9958** | **0.9861** | **0.9549** | **0.9772** |

</center>

## Citation

If you find our [paper](https://arxiv.org/abs/2411.09420) and [code](https://github.com/shravan-18/SAG-ViT) helpful for your research, please consider citing our work and giving the repository a star:

```bibtex
@misc{SAGViT,
      title={SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers},
      author={Shravan Venkatraman and Jaskaran Singh Walia and Joe Dhanith P R},
      year={2024},
      eprint={2411.09420},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2411.09420},
}
```