---
license: mit
datasets:
- jxie/stl10
---

# Image Classifier

This repository contains a pre-trained PyTorch model that classifies images into 10 categories: airplane, bird, car, cat, deer, dog, horse, monkey, ship, and truck. The model is a small Convolutional Neural Network (CNN).

## Model Overview

The model is a simple CNN classifier with two convolutional blocks followed by a fully connected layer. It was trained on an image dataset (the card metadata lists `jxie/stl10`) and predicts one of the following class indices:

- **0**: Airplane
- **1**: Bird
- **2**: Car
- **3**: Cat
- **4**: Deer
- **5**: Dog
- **6**: Horse
- **7**: Monkey
- **8**: Ship
- **9**: Truck
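
For convenience, the index-to-label mapping above can be written as a small lookup. This is a minimal sketch: `CLASS_NAMES` and `label_for` are hypothetical names chosen for illustration, not part of the repository.

```python
# Class names in index order, copied from the list above.
CLASS_NAMES = [
    "airplane", "bird", "car", "cat", "deer",
    "dog", "horse", "monkey", "ship", "truck",
]


def label_for(index: int) -> str:
    """Return the class name for a predicted index (0-9)."""
    if not 0 <= index < len(CLASS_NAMES):
        raise ValueError(f"class index out of range: {index}")
    return CLASS_NAMES[index]
```

A predicted index from the model can then be passed to `label_for` to get a readable name instead of a number.
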

## Model Architecture

The model consists of the following layers:

1. **Conv Block 1**: Two 3×3 convolutional layers with ReLU activations, followed by 2×2 max pooling.
2. **Conv Block 2**: Two more 3×3 convolutional layers with ReLU activations, followed by 2×2 max pooling.
3. **Fully Connected Classifier**: A flatten step and a linear layer that maps the features to the 10 output categories.

Here's the architecture of the model:

```python
from torch import nn


class CNNV0(nn.Module):
    """Two convolutional blocks followed by a linear classifier."""

    def __init__(self, input_shape: int, hidden_units: int, output_shape: int):
        super().__init__()
        # Block 1: two 3x3 convolutions (padding=1 keeps height and width),
        # then 2x2 max pooling halves both.
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(in_channels=input_shape, out_channels=hidden_units, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units, out_channels=hidden_units, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        # Block 2: same layout, operating on hidden_units channels throughout.
        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(in_channels=hidden_units, out_channels=hidden_units, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units, out_channels=hidden_units, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # 576 = 24 * 24: the feature-map size that 96x96 inputs reach after two 2x2 poolings.
            nn.Linear(in_features=hidden_units*576, out_features=output_shape)
        )

    def forward(self, x):
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        x = self.classifier(x)
        return x
```
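
The `hidden_units*576` input size of the linear layer follows from the layer shapes. The sketch below traces the spatial size through both blocks in plain Python. It assumes square 96×96 inputs, which is an inference from 576 = 24 × 24 rather than something the repository states; `conv_out` and `flattened_features` are hypothetical helper names, and `hidden_units=10` is an arbitrary example value.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height/width of a convolution or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(image_size: int, hidden_units: int) -> int:
    """Number of features entering nn.Linear for a square input image."""
    size = image_size
    for _ in range(2):  # conv_block_1, then conv_block_2
        size = conv_out(size, kernel=3, stride=1, padding=1)  # first conv
        size = conv_out(size, kernel=3, stride=1, padding=1)  # second conv
        size = conv_out(size, kernel=2, stride=2)  # MaxPool2d(kernel_size=2)
    return hidden_units * size * size


print(flattened_features(96, hidden_units=10))   # 10 * 24 * 24 = 5760
print(flattened_features(224, hidden_units=10))  # 10 * 56 * 56 = 31360
```

Only a 96×96 input yields the 24×24 feature map that the linear layer expects; any other input size produces a different feature count and a shape mismatch at the classifier.
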

## Requirements

- **Python** 3.7 or higher
- **PyTorch** 1.8 or higher
- **torchvision** (for loading and preprocessing images)

## Usage

1. Clone this repository and install dependencies:

```bash
git clone <repository-url>
cd <repository-folder>
pip install torch torchvision
```

2. Load and use the model in your Python script:

```python
import torch
from torchvision import transforms
from PIL import Image

# Load the model. This assumes 'model_0.pth' stores the whole pickled model,
# so the CNNV0 class must be importable here; recent PyTorch releases may
# also require torch.load(..., weights_only=False) for such files.
model = torch.load('model_0.pth')
model.eval()  # Set to evaluation mode

# Load and preprocess the image
transform = transforms.Compose([
    # 96x96 is the input size implied by the classifier's hidden_units*576 (24*24) features
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
])
img = Image.open('path_to_image.jpg').convert('RGB')  # Force 3 channels
img = transform(img).unsqueeze(0)  # Add a batch dimension: (1, 3, 96, 96)

# Predict
with torch.no_grad():
    output = model(img)
    _, predicted = torch.max(output, 1)
    print("Predicted class index:", predicted.item())
```
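
The raw `output` values are unnormalized scores, one per class. If probabilities or a ranked list are more useful than a single index, a softmax can be applied. The sketch below does this in plain Python on a made-up score list so that it runs without the model file; `softmax` and `top_k` are hypothetical helpers, and with a real tensor `torch.softmax(output, dim=1)` would be the usual route.

```python
import math


def softmax(scores):
    """Convert raw scores to probabilities that sum to 1."""
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(scores, k=3):
    """Return the k most likely (class_index, probability) pairs."""
    probs = softmax(scores)
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up scores for the 10 classes, for illustration only.
example_scores = [0.1, 2.5, -1.0, 0.3, 0.0, 1.2, -0.5, 0.7, -2.0, 0.4]
for index, prob in top_k(example_scores):
    print(f"class {index}: {prob:.3f}")
```

Reporting the top few classes with their probabilities makes it easier to spot uncertain predictions than printing the winning index alone.
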