Update README.md
Browse files
README.md
CHANGED
|
@@ -1,13 +1,54 @@
|
|
| 1 |
---
|
| 2 |
license: bigcode-openrail-m
|
| 3 |
language:
|
| 4 |
-
- en
|
| 5 |
-
- fr
|
| 6 |
metrics:
|
| 7 |
-
- accuracy
|
| 8 |
pipeline_tag: image-classification
|
| 9 |
tags:
|
| 10 |
-
- biology
|
| 11 |
-
- code
|
| 12 |
-
- tensorflow
|
| 13 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
license: bigcode-openrail-m
|
| 3 |
language:
|
| 4 |
+
- en
|
| 5 |
+
- fr
|
| 6 |
metrics:
|
| 7 |
+
- accuracy
|
| 8 |
pipeline_tag: image-classification
|
| 9 |
tags:
|
| 10 |
+
- biology
|
| 11 |
+
- code
|
| 12 |
+
- tensorflow
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# Mushroom Classification Model - JarviSpore
|
| 16 |
+
|
| 17 |
+
This repository contains **JarviSpore**, a mushroom image classification model trained on a multi-class dataset with 23 different types of mushrooms. Developed from scratch with TensorFlow and Keras, this model aims to provide accurate mushroom identification using advanced deep learning techniques, including *Grad-CAM* for interpreting predictions. This project explores the performance of from-scratch models compared to transfer learning.
|
| 18 |
+
|
| 19 |
+
## Model Details
|
| 20 |
+
|
| 21 |
+
- **Architecture**: Custom CNN (Convolutional Neural Network)
|
| 22 |
+
- **Number of Classes**: 23 mushroom classes
|
| 23 |
+
- **Input Format**: RGB images resized to 224x224 pixels
|
| 24 |
+
- **Framework**: TensorFlow & Keras
|
| 25 |
+
- **Training**: Conducted on a machine with an i9 14900k processor, 192GB RAM, and an RTX 3090 GPU
|
| 26 |
+
|
| 27 |
+
## Key Features
|
| 28 |
+
|
| 29 |
+
1. **Multi-Class Classification**: The model can predict among 23 mushroom species.
|
| 30 |
+
2. **Regularization**: Includes L2 regularization and Dropout to prevent overfitting.
|
| 31 |
+
3. **Class Weighting**: Manages dataset imbalances by applying specific weights for each class.
|
| 32 |
+
4. **Grad-CAM Visualization**: Utilizes Grad-CAM to generate heatmaps, allowing visualization of the regions influencing the model's predictions.
|
| 33 |
+
|
| 34 |
+
## Model Training
|
| 35 |
+
|
| 36 |
+
The model was trained using a structured dataset directory with data split as follows:
|
| 37 |
+
- `train`: Balanced training dataset
|
| 38 |
+
- `validation`: Validation set to monitor performance
|
| 39 |
+
- `test`: Test set to evaluate final accuracy
|
| 40 |
+
|
| 41 |
+
Main training hyperparameters include:
|
| 42 |
+
- **Batch Size**: 32
|
| 43 |
+
- **Epochs**: 20 with Early Stopping
|
| 44 |
+
- **Learning Rate**: 0.0001
|
| 45 |
+
|
| 46 |
+
Training was tracked and logged via MLflow, including accuracy and loss curves, as well as the best model weights saved automatically.
|
| 47 |
+
|
| 48 |
+
## Model Usage
|
| 49 |
+
|
| 50 |
+
### Prerequisites
|
| 51 |
+
|
| 52 |
+
Ensure the following libraries are installed:
|
| 53 |
+
```bash
|
| 54 |
+
pip install tensorflow pillow matplotlib numpy
|