### **🌺 ResNet-50 Flowers Classification Model**
This repository hosts a fine-tuned **ResNet-50-based** model optimized for **flower classification** using the **Flowers-102 dataset**. The model classifies images into **102 different flower categories**.
---
## **📚 Model Details**
- **Model Architecture**: ResNet-50
- **Task**: Multi-class Flower Classification
- **Dataset**: Flowers-102 ([Oxford Dataset](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/))
- **Framework**: PyTorch
- **Input Image Size**: 224x224
- **Number of Classes**: 102 (Different Flower Categories)
- **Quantization**: FP16 (for efficiency)
---
## **🚀 Usage**
### **Installation**
```bash
pip install torch torchvision pillow
```
### **Loading the Model**
```python
import torch
import torchvision.models as models
# Step 1: Define the model architecture (must match the trained model)
model = models.resnet50(weights=None)  # No ImageNet weights; `pretrained=False` is deprecated in torchvision >= 0.13
model.fc = torch.nn.Linear(in_features=2048, out_features=102) # Ensure output matches 102 classes
# Step 2: Load the fine-tuned model weights
model_path = "/content/resnet50_flowers_model.pth" # Ensure the file is in the correct directory
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
# Step 3: Set model to evaluation mode
model.eval()
print("✅ Model loaded successfully and ready for inference!")
```
---
### **📰 Perform Flower Classification**
```python
from PIL import Image
import torchvision.transforms as transforms
# Load the image
image_path = "/content/sample_flower.jpg" # Replace with your test image
image = Image.open(image_path).convert("RGB") # Ensure 3-channel format
# Define preprocessing (same as used during training)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match model input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Apply transformations
image = transform(image).unsqueeze(0) # Add batch dimension
# Perform inference
with torch.no_grad():
    output = model(image)
# Convert output to class prediction
predicted_class = torch.argmax(output, dim=1).item()
print(f"✅ Predicted Flower Label: {predicted_class}")
```
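The `argmax` above returns only the winning class index. When similar-looking flowers are a concern, it can help to inspect the top few classes together with their probabilities. The following is a minimal, dependency-free sketch of that idea: `softmax` and `top_k` are hypothetical helper names (not part of this repository), and the toy logits stand in for `output[0].tolist()` from the snippet above. On tensors, `torch.softmax` and `torch.topk` do the same job.

```python
import math

def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(logits, k=5):
    """Return the k most likely (class_index, probability) pairs, best first."""
    probs = softmax(logits)
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Illustrative logits for a 4-class toy case; in practice use output[0].tolist()
example_logits = [2.0, 0.5, -1.0, 3.0]
for class_index, prob in top_k(example_logits, k=2):
    print(f"class {class_index}: {prob:.3f}")
```

A low top-1 probability, or two classes with nearly equal probability, is a reasonable signal to treat the prediction with caution.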
---
## **📊 Evaluation Results**
After fine-tuning, the model was evaluated on the **Flowers-102 Dataset**, achieving the following performance:
| **Metric** | **Score** |
|------------------|----------|
| **Accuracy** | 92.8% |
| **Precision** | 91.5% |
| **Recall** | 90.9% |
| **F1-Score** | 91.2% |
| **Inference Speed** | Fast (Optimized with FP16) |
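
As a quick sanity check on the table, the F1-score is conventionally the harmonic mean of precision and recall. The sketch below recomputes it from the reported precision and recall. `f1_score` is a hypothetical helper name, and the check assumes the reported F1 was derived from the same averaged precision and recall (with per-class macro averaging, the two can differ slightly).

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall (both given as fractions)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# Values taken from the evaluation table above
precision = 0.915
recall = 0.909
f1 = f1_score(precision, recall)
print(f"F1 = {f1:.1%}")  # prints "F1 = 91.2%"
```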
---
## **🛠️ Fine-Tuning Details**
### **Dataset**
The model was trained on the **Flowers-102 dataset**, which contains **8,189 flower images** classified into **102 categories**.
### **Training Configuration**
- **Number of epochs**: 20
- **Batch size**: 16
- **Optimizer**: Adam
- **Learning rate**: 1e-4
- **Loss Function**: Cross-Entropy
- **Evaluation Strategy**: Validation at each epoch
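
The configuration above could be wired up roughly as follows. This is a minimal sketch, not the author's training script: `fine_tune`, `train_loader`, and `val_loader` are hypothetical names, the loaders are assumed to yield `(images, labels)` batches of size 16 with labels in `0..101`, and details the list does not specify (data augmentation, learning-rate schedule, which layers were frozen) are left out.

```python
import torch
from torch import nn

def fine_tune(model, train_loader, val_loader, epochs=20, lr=1e-4, device="cpu"):
    """Train with Adam + cross-entropy and validate after every epoch."""
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []  # validation accuracy per epoch

    for epoch in range(epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

        # Validation at the end of each epoch
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                preds = model(images).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        accuracy = correct / max(total, 1)
        history.append(accuracy)
        print(f"epoch {epoch + 1}/{epochs}: val accuracy = {accuracy:.3f}")

    return history
```

Here `model` would be the ResNet-50 with the 102-way `fc` head from the loading snippet, and the loaders would typically be built with `torch.utils.data.DataLoader(..., batch_size=16)`.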
### **Quantization**
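The model weights were converted to **FP16 (half precision)**. This halves the memory needed for the weights relative to FP32 and can reduce latency on hardware with native FP16 support, while maintaining high accuracy.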
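
To give a sense of the memory saving, the arithmetic below estimates the size of the weights at 32-bit and 16-bit precision. It is a back-of-the-envelope sketch: the parameter count assumes the standard torchvision ResNet-50 (25,557,032 parameters with its 1000-class head) with the final layer swapped for the 102-class head, and it ignores BatchNorm buffers and file-format overhead.

```python
# Parameter count of the standard torchvision ResNet-50 with a 1000-class head (assumption)
RESNET50_PARAMS = 25_557_032
IN_FEATURES = 2048

def head_params(num_classes):
    """Weights plus biases of the final fully connected layer."""
    return IN_FEATURES * num_classes + num_classes

# Swap the 1000-class head for the 102-class head used in this model
params = RESNET50_PARAMS - head_params(1000) + head_params(102)

fp32_mb = params * 4 / 1e6  # 4 bytes per FP32 value
fp16_mb = params * 2 / 1e6  # 2 bytes per FP16 value
print(f"parameters: {params:,}")
print(f"FP32: {fp32_mb:.1f} MB, FP16: {fp16_mb:.1f} MB")
```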
---
## **⚠️ Limitations**
- **Misclassification risk**: The model may incorrectly classify similar-looking flowers.
- **Dataset bias**: Performance may vary based on background, lighting, and image quality.
- **Generalization**: The model was trained on a specific dataset and may not generalize well to unseen flower species.
---
✅ **Use this fine-tuned ResNet-50 model for accurate and efficient flower classification!** 🌺🚀