abelkrw commited on
Commit
7bb60e4
·
1 Parent(s): 8316c41
Files changed (1) hide show
  1. README.md +82 -29
README.md CHANGED
@@ -1,34 +1,87 @@
1
  ---
2
  license: afl-3.0
3
  ---
4
-
5
- # Description
6
-
7
- This model is a ResNet18-based fruit classifier implemented using PyTorch. The ResNet18 architecture is utilized as the backbone for feature extraction, pretrained on the DEFAULT weights.
8
-
9
- # Architecture
10
-
11
- The model consists of a modified fully connected layer (fc) following the ResNet18 backbone. The number of input features (num_ftrs) for the fully connected layer is extracted from the pretrained ResNet18 model. The modified fully connected layer comprises three linear layers, each followed by ReLU activation and dropout. The final linear layer is followed by a LogSoftmax activation function to output class probabilities.
12
-
13
- ## Training
14
-
15
- The model is trained using the CrossEntropyLoss as the loss function and Stochastic Gradient Descent (SGD) as the optimizer. The learning rate (lr) is set to 0.01 with a momentum of 0.9. The training is conducted for one epoch, iterating through the training data and updating the model's weights using backpropagation and gradient descent. Training progress is logged, displaying the loss for every 100 batches processed.
16
-
17
- ## Evaluation
18
-
19
- Following training, the model's performance is evaluated using a separate test dataset. Accuracy is calculated by comparing the predicted labels with the ground truth labels from the test set with accuracy **98.58%**.
20
-
21
- ## Saving Models
22
-
23
- The trained model is saved in the 'resnet18_fruit_classifier.pth' file using torch.save(). This allows for easy reuse and deployment of the trained model.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  ## Usage
26
-
27
- To use the model, load the saved weights using **torch.load('resnet18_fruit_classifier.pth')** and initialize the ResNet18 model with the modified fully connected layer. The model is then ready for inference on new fruit images.
28
-
29
- # Resources
30
-
31
- - PyTorch: https://pytorch.org/
32
- - ResNet: https://arxiv.org/abs/1512.03385
33
-
34
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: afl-3.0
3
  ---
4
+ # Model Card: ResNet18 for Fruit Classification
5
+
6
+ ## Introduction
7
+ This model card provides information about a ResNet18-based fruit classification model trained using the provided code.
8
+
9
+ ## Model Details
10
+ - **Model Architecture**: ResNet18
11
+ - **Number of Classes**: 33 (Fruit Categories)
12
+ - **Loss Function**: Cross Entropy Loss
13
+ - **Optimizer**: Stochastic Gradient Descent (SGD) with learning rate of 0.001
14
+ - **Training Duration**: 5 epochs
15
+
16
+ ## Training and Evaluation
17
+ The model was trained and evaluated using the specified training and test datasets. During training, the model's performance was assessed based on training loss, training accuracy, test loss, and test accuracy for each epoch.
18
+
19
+ ## Performance
20
+ The performance metrics for each epoch are as follows:
21
+
22
+ - **Epoch 1/5**
23
+ - Train Loss: 2.2403
24
+ - Train Accuracy: 70.68%
25
+ - Test Loss: 0.2475
26
+ - Test Accuracy: 99.11%
27
+
28
+ - **Epoch 2/5**
29
+ - Train Loss: 0.1282
30
+ - Train Accuracy: 99.65%
31
+ - Test Loss: 0.0771
32
+ - Test Accuracy: 99.82%
33
+
34
+ - **Epoch 3/5**
35
+ - Train Loss: 0.0568
36
+ - Train Accuracy: 99.89%
37
+ - Test Loss: 0.0514
38
+ - Test Accuracy: 99.76%
39
+
40
+ - **Epoch 4/5**
41
+ - Train Loss: 0.0347
42
+ - Train Accuracy: 99.96%
43
+ - Test Loss: 0.0332
44
+ - Test Accuracy: 99.91%
45
+
46
+ - **Epoch 5/5**
47
+ - Train Loss: 0.0247
48
+ - Train Accuracy: 99.97%
49
+ - Test Loss: 0.0240
50
+ - Test Accuracy: 99.94%
51
 
52
  ## Usage
53
+ To use this model for fruit classification, load the trained weights and utilize the model to classify fruit images into one of the 33 fruit categories.
54
+
55
+ ```python
56
+ # Load the trained weights
57
+ model = resnet18()
58
+ model_weights_path = 'resnet18_fruit_classifier.pth'
59
+ model.load_state_dict(torch.load(model_weights_path))
60
+ model.eval()
61
+ transform = transforms.Compose([
62
+ transforms.Resize(255),
63
+ transforms.ToTensor()
64
+ ])
65
+
66
+ # Perform inference on a sample image
67
+ # ... (code to preprocess and load an image)
68
+
69
+ # Forward pass to get predictions
70
+ with torch.no_grad():
71
+ output = model(image)
72
+
73
+ # Process the output to get predicted class
74
+ predicted_class = torch.argmax(output, dim=1)
75
+ print("Predicted Class:", predicted_class.item())
76
+ ```
77
+
78
+ ## Limitations and Considerations
79
+ - The model's performance may vary based on the quality and diversity of the dataset used for training.
80
+ - The provided number of epochs (5) for training may not be sufficient for achieving optimal performance. Further fine-tuning and experimentation might be necessary.
81
+ - Additional data augmentation and regularization techniques could potentially improve the model's robustness and accuracy.
82
+
83
+ ## Ethical Considerations
84
+ Ensure that the dataset used for training is collected and used ethically, respecting privacy, consent, and applicable laws and regulations.
85
+
86
+ ## Disclaimer
87
+ This model card is for illustrative purposes and does not guarantee any specific performance or outcomes when using the provided code. Users are encouraged to conduct thorough evaluation and testing for their specific use cases.