---
tags:
- image-classification
- pytorch
library_name: transformers
datasets:
- garythung/trashnet
---

# Trash Image Classification using Vision Transformer (ViT)

This repository contains an image classification model built by fine-tuning a pre-trained Vision Transformer (ViT) from Hugging Face. The model classifies images of waste into six categories: cardboard, glass, metal, paper, plastic, and trash.

## Dataset

The model is trained on the [`garythung/trashnet`](https://huggingface.co/datasets/garythung/trashnet) dataset, which contains 5,054 images in six categories with the following distribution:

- Cardboard: 806 images
- Glass: 1002 images
- Metal: 820 images
- Paper: 1188 images
- Plastic: 964 images
- Trash: 274 images

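The class counts above are noticeably imbalanced. As a quick sanity check, the short sketch below computes each category's share of the images and the ratio between the largest and smallest classes, using only the counts listed above.

```python
# Per-class image counts, copied from the distribution listed above
counts = {
    "cardboard": 806,
    "glass": 1002,
    "metal": 820,
    "paper": 1188,
    "plastic": 964,
    "trash": 274,
}

total = sum(counts.values())
# Fraction of the dataset that each class represents
shares = {label: n / total for label, n in counts.items()}

# Ratio between the most and least frequent classes
imbalance = max(counts.values()) / min(counts.values())

print(f"Total images: {total}")
for label, share in shares.items():
    print(f"{label:>9}: {share:6.1%}")
print(f"Largest/smallest class ratio: {imbalance:.2f}")
```

For reference, a classifier that always predicted paper (the largest class) would be right about 23.5% of the time on data with these proportions, while trash makes up only about 5.4% of the images.
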
## Model

We use the pre-trained Vision Transformer [`google/vit-base-patch16-224-in21k`](https://huggingface.co/google/vit-base-patch16-224-in21k) from Hugging Face as the backbone. It is a ViT-Base model that splits 224×224 images into 16×16 patches and was pre-trained on ImageNet-21k. We fine-tune it on the TrashNet data described above to predict the six waste categories.

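Fine-tuning a ViT checkpoint for classification keeps the transformer backbone and attaches a linear head with one output per class. The sketch below is a minimal, offline illustration of that architecture, not the author's training code: it builds a randomly initialized `ViTForImageClassification` from a `ViTConfig`, whose defaults match ViT-Base (patch size 16, 224×224 input), and it assumes the labels are ordered alphabetically. For real fine-tuning you would load the pre-trained weights with `from_pretrained("google/vit-base-patch16-224-in21k", ...)` instead; the published checkpoint's actual label order is in its `model.config.id2label`.

```python
import torch
from transformers import ViTConfig, ViTForImageClassification

# Assumed label order (alphabetical); check model.config.id2label on the real checkpoint
labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
id2label = {i: name for i, name in enumerate(labels)}
label2id = {name: i for i, name in id2label.items()}

# ViTConfig defaults correspond to ViT-Base: 12 layers, hidden size 768,
# 224x224 images split into 16x16 patches. num_labels is inferred from id2label.
config = ViTConfig(id2label=id2label, label2id=label2id)

# Randomly initialized for illustration only; fine-tuning would instead call
# ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k",
#     id2label=id2label, label2id=label2id)
model = ViTForImageClassification(config)
model.eval()

# (224 / 16)^2 = 196 patches per image; a [CLS] token is prepended for classification
num_patches = model.vit.embeddings.patch_embeddings.num_patches
print("Patches per image:", num_patches)

# A dummy batch of two RGB images yields one logit per class for each image
pixel_values = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    logits = model(pixel_values=pixel_values).logits
print("Logits shape:", tuple(logits.shape))
print("Classifier head:", model.classifier)
```
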
The fine-tuned model is available on the Hugging Face Hub at [tribber93/my-trash-classification](https://huggingface.co/tribber93/my-trash-classification).

## Usage

To run inference on a single image, download it, preprocess it with the model's image processor, and take the class with the highest logit:

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Download an example image (a discarded can) and make sure it is RGB
url = "https://cdn.grid.id/crop/0x0:0x0/700x465/photo/grid/original/127308_kaleng-bekas.jpg"
response = requests.get(url, stream=True, timeout=10)
response.raise_for_status()
image = Image.open(response.raw).convert("RGB")

# Load the fine-tuned model and its matching image processor from the Hub
model_name = "tribber93/my-trash-classification"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# Resize and normalize the image into a batch of pixel values
inputs = processor(images=image, return_tensors="pt")

# Run inference without tracking gradients
with torch.no_grad():
    outputs = model(**inputs)

# The highest-scoring logit gives the predicted class index
predicted_id = outputs.logits.argmax(dim=-1).item()
print("Predicted class:", model.config.id2label[predicted_id])
```

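`argmax` returns only the single most likely class. If you also want confidence scores, you can apply a softmax to the logits and keep the top-k classes. The helper below is a hypothetical addition (`top_k_predictions` is not part of the model or of `transformers`). It is demonstrated on a hand-written logits tensor with an assumed label order so it runs without downloading the model; with the inference snippet you would call it as `top_k_predictions(outputs.logits[0], model.config.id2label)`.

```python
import torch


def top_k_predictions(logits: torch.Tensor, id2label: dict, k: int = 3):
    """Return the k most likely (label, probability) pairs for one image.

    `logits` is a 1-D tensor with one score per class, e.g. outputs.logits[0].
    """
    probs = torch.softmax(logits, dim=-1)
    k = min(k, probs.numel())  # never ask for more classes than exist
    top_probs, top_ids = torch.topk(probs, k)
    return [(id2label[i], p) for i, p in zip(top_ids.tolist(), top_probs.tolist())]


# Example with made-up logits; this label order is an assumption
id2label = {0: "cardboard", 1: "glass", 2: "metal", 3: "paper", 4: "plastic", 5: "trash"}
logits = torch.tensor([0.1, 0.3, 4.0, 0.2, 1.5, -1.0])

for label, prob in top_k_predictions(logits, id2label):
    print(f"{label}: {prob:.1%}")
```
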
## Results

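The table below reports training loss, validation loss, and accuracy at the end of each of the three training epochs. After the final epoch the model reaches 94.26% accuracy.
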
| Epoch | Training Loss | Validation Loss | Accuracy |
|-------|---------------|-----------------|----------|
| 1     | 3.3200        | 0.7011          | 86.25%   |
| 2     | 1.6611        | 0.4298          | 91.49%   |
| 3     | 1.4353        | 0.3563          | 94.26%   |
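
The card does not state exactly how the accuracy column was computed. A common choice when fine-tuning with the `transformers` `Trainer` is a `compute_metrics` callback that compares the arg-max prediction with the reference label on the validation split. The sketch below shows that computation with NumPy, assuming this is the metric that was used; the logits and labels in the example are made up.

```python
import numpy as np


def compute_metrics(eval_pred):
    """Accuracy = fraction of validation images whose arg-max class matches the label.

    `eval_pred` unpacks to (logits, labels), as with the EvalPrediction object
    that transformers.Trainer passes to compute_metrics.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float((predictions == labels).mean())}


# Made-up example: 4 validation images, 6 classes
logits = np.array([
    [2.0, 0.1, 0.0, 0.3, 0.2, 0.0],  # predicted class 0
    [0.1, 0.2, 3.1, 0.0, 0.4, 0.1],  # predicted class 2
    [0.0, 0.5, 0.1, 0.2, 1.7, 0.3],  # predicted class 4
    [0.3, 0.1, 0.2, 0.9, 0.4, 0.8],  # predicted class 3
])
labels = np.array([0, 2, 4, 5])  # the last image is misclassified

print(compute_metrics((logits, labels)))  # {'accuracy': 0.75}
```

Under this definition, 94.26% accuracy means that roughly 94 of every 100 validation images are assigned their correct category.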