Update README.md
Browse files
README.md
CHANGED
|
@@ -1,3 +1,34 @@
|
|
| 1 |
---
|
| 2 |
license: afl-3.0
|
| 3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
license: afl-3.0
|
| 3 |
---
|
| 4 |
+
|
| 5 |
+
# Description
|
| 6 |
+
|
| 7 |
+
This model is a ResNet18-based fruit classifier implemented using PyTorch. The ResNet18 architecture is utilized as the backbone for feature extraction, pretrained on the DEFAULT weights.
|
| 8 |
+
|
| 9 |
+
# Architecture
|
| 10 |
+
|
| 11 |
+
The model consists of a modified fully connected layer (fc) following the ResNet18 backbone. The number of input features (num_ftrs) for the fully connected layer is extracted from the pretrained ResNet18 model. The modified fully connected layer comprises three linear layers, each followed by ReLU activation and dropout. The final linear layer is followed by a LogSoftmax activation function to output class probabilities.
|
| 12 |
+
|
| 13 |
+
## Training
|
| 14 |
+
|
| 15 |
+
The model is trained using the CrossEntropyLoss as the loss function and Stochastic Gradient Descent (SGD) as the optimizer. The learning rate (lr) is set to 0.01 with a momentum of 0.9. The training is conducted for one epoch, iterating through the training data and updating the model's weights using backpropagation and gradient descent. Training progress is logged, displaying the loss for every 100 batches processed.
|
| 16 |
+
|
| 17 |
+
## Evaluation
|
| 18 |
+
|
| 19 |
+
Following training, the model's performance is evaluated using a separate test dataset. Accuracy is calculated by comparing the predicted labels with the ground truth labels from the test set with accuracy **98.58%**.
|
| 20 |
+
|
| 21 |
+
## Saving Models
|
| 22 |
+
|
| 23 |
+
The trained model is saved in the 'resnet18_fruit_classifier.pth' file using torch.save(). This allows for easy reuse and deployment of the trained model.
|
| 24 |
+
|
| 25 |
+
## Usage
|
| 26 |
+
|
| 27 |
+
To use the model, load the saved weights using **torch.load('resnet18_fruit_classifier.pth')** and initialize the ResNet18 model with the modified fully connected layer. The model is then ready for inference on new fruit images.
|
| 28 |
+
|
| 29 |
+
# Resources
|
| 30 |
+
|
| 31 |
+
- PyTorch: https://pytorch.org/
|
| 32 |
+
- ResNet: https://arxiv.org/abs/1512.03385
|
| 33 |
+
|
| 34 |
+
|