# ViT Fruit Ripeness Classifier
A fruit ripeness classification model combining **Vision Transformer (ViT)** feature extraction with **Logistic Regression** for fast, accurate inference on CPU or GPU.
## Model Description
This model classifies the ripeness condition of **apples, bananas, and oranges** into three categories:
- **Fresh** - Ready to eat
- **Unripe** - Needs more time to ripen
- **Rotten** - Past optimal consumption
### Architecture
- **Feature Extractor**: ViT Base Patch-16 (`google/vit-base-patch16-224`)
- **Classifier**: Scikit-learn Logistic Regression
- **Feature Dimension**: 768-dim pooled output from ViT
- **Total Classes**: 9 (3 fruits × 3 ripeness states)
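The inference path above can be sketched with plain NumPy: a 768-dim pooled ViT feature is standardized, passed through a linear layer, and converted to 9 class probabilities with softmax. The scaler statistics and weight matrix below are random placeholders for illustration; the real values live in `scaler.joblib` and `logistic_model.joblib`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder parameters (illustrative only; real values come from
# scaler.joblib and logistic_model.joblib).
mean, std = np.zeros(768), np.ones(768)   # StandardScaler statistics
W = rng.normal(size=(9, 768)) * 0.01      # logistic regression weights
b = np.zeros(9)                           # one bias per class

def classify(feature: np.ndarray) -> np.ndarray:
    """Map a 768-dim ViT pooled feature to 9 class probabilities."""
    z = (feature - mean) / std            # standardize the feature
    logits = W @ z + b                    # linear class scores
    logits -= logits.max()                # numerical stability
    exp = np.exp(logits)
    return exp / exp.sum()                # softmax probabilities

probs = classify(rng.normal(size=768))
print(probs.shape, round(probs.sum(), 6))  # (9,) 1.0
```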
## Supported Classes
| Class | Description |
|-------|-------------|
| `freshapples` | Fresh, ready-to-eat apples |
| `freshbanana` | Fresh, ripe bananas |
| `freshoranges` | Fresh, ripe oranges |
| `rottenapples` | Overripe/rotten apples |
| `rottenbanana` | Overripe/rotten bananas |
| `rottenoranges` | Overripe/rotten oranges |
| `unripe apple` | Unripe apples |
| `unripe banana` | Unripe bananas |
| `unripe orange` | Unripe oranges |
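Note that the label naming is inconsistent: some classes are concatenated (`freshapples`), others are spaced (`unripe apple`). A hypothetical helper (not part of the model repo) that splits either style into its condition and fruit parts could look like this:

```python
# Hypothetical helper: split a class label into (condition, fruit).
# Handles both "freshapples" and "unripe apple" naming styles.
CONDITIONS = ("fresh", "rotten", "unripe")

def split_label(label: str) -> tuple[str, str]:
    compact = label.replace(" ", "")      # normalize the spaced variants
    for cond in CONDITIONS:
        if compact.startswith(cond):
            return cond, compact[len(cond):]
    raise ValueError(f"Unrecognized label: {label!r}")

print(split_label("rottenbanana"))  # ('rotten', 'banana')
print(split_label("unripe apple"))  # ('unripe', 'apple')
```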
## Quick Start
### Installation
```bash
pip install torch torchvision transformers scikit-learn pillow joblib numpy huggingface_hub
```
### Single Image Inference
```python
import json
import joblib
from pathlib import Path
from PIL import Image
import torch
import numpy as np
from huggingface_hub import hf_hub_download
from transformers import AutoImageProcessor, ViTModel
# Configuration
REPO_ID = "Meeteshn/vit_fruit_ripeness_classifier"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Load model components
processor = AutoImageProcessor.from_pretrained(REPO_ID, subfolder="processor", use_fast=True)
backbone = ViTModel.from_pretrained(REPO_ID, subfolder="vit_backbone")
backbone.to(DEVICE)
backbone.eval()
# Load sklearn artifacts
scaler_path = hf_hub_download(REPO_ID, "scaler.joblib")
clf_path = hf_hub_download(REPO_ID, "logistic_model.joblib")
metadata_path = hf_hub_download(REPO_ID, "metadata.json")
scaler = joblib.load(scaler_path)
clf = joblib.load(clf_path)
metadata = json.loads(Path(metadata_path).read_text(encoding="utf-8"))
classes = metadata["classes"]
def predict(image_path: str):
    """Predict ripeness condition for a single image."""
    img = Image.open(image_path).convert("RGB")
    inputs = processor(images=img, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(DEVICE)
    with torch.no_grad():
        out = backbone(pixel_values=pixel_values, return_dict=True)
        pooled = getattr(out, "pooler_output", None)
        if pooled is None:
            pooled = out.last_hidden_state[:, 0, :]
    feat = pooled.cpu().numpy()
    feat_scaled = scaler.transform(feat)
    probs = clf.predict_proba(feat_scaled)[0]
    idx = int(np.argmax(probs))
    return classes[idx], float(probs[idx]), {
        classes[i]: float(probs[i]) for i in range(len(classes))
    }

# Example usage
if __name__ == "__main__":
    label, prob, all_probs = predict("my_apple.jpg")
    print(f"Prediction: {label} ({prob*100:.2f}%)")
    print("\nTop 5 probabilities:")
    for cls, p in sorted(all_probs.items(), key=lambda x: -x[1])[:5]:
        print(f"  {cls}: {p*100:.2f}%")
```
### Batch Prediction
```python
from pathlib import Path
import csv
def batch_predict(folder_path: str, output_csv: str = "predictions.csv"):
    """Predict ripeness for all images in a folder."""
    folder = Path(folder_path)
    with open(output_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["filename", "predicted_label", "probability"])
        for img_path in sorted(folder.rglob("*")):
            if img_path.suffix.lower() not in [".jpg", ".jpeg", ".png", ".bmp"]:
                continue
            label, prob, _ = predict(str(img_path))
            writer.writerow([img_path.name, label, f"{prob*100:.2f}%"])
    print(f"Predictions saved to {output_csv}")

# Usage
batch_predict("path/to/images")
```
## Example Output
```
Prediction: rottenapples (71.24%)

Top 5 probabilities:
  rottenapples: 71.24%
  rottenbanana: 12.35%
  freshapples: 6.12%
  unripe apple: 4.89%
  freshoranges: 2.31%
```
## Repository Structure
```
vit_fruit_ripeness_updated/
├── processor/                # AutoImageProcessor configuration
├── vit_backbone/             # ViT feature extractor weights
├── logistic_model.joblib     # Trained classifier
├── scaler.joblib             # Feature scaler
├── metadata.json             # Class labels and metadata
└── features_extracted.npz    # (Optional) Cached features
```
## Performance Notes
- **Best Results**: When the fruit is centered and clearly visible in the image
- **Works Well**: Smartphone photos, typical market/kitchen images
- **Challenges**: Heavy background clutter, extreme lighting conditions, unusual fruit varieties
- **Uncertainty Handling**: Use top-K probabilities to assess prediction confidence
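The top-K idea above can be sketched with two small helpers operating on the probability dict that `predict()` returns. These helpers and the 0.6 threshold are illustrative choices, not part of the model repo:

```python
def top_k(all_probs: dict, k: int = 3):
    """Return the k most probable (class, probability) pairs."""
    return sorted(all_probs.items(), key=lambda kv: kv[1], reverse=True)[:k]

def is_confident(all_probs: dict, threshold: float = 0.6) -> bool:
    """Flag a prediction as confident only if the top class dominates."""
    return max(all_probs.values()) >= threshold

# Example probability dict, shaped like the output of predict()
example = {"rottenapples": 0.71, "rottenbanana": 0.12, "freshapples": 0.06,
           "unripe apple": 0.05, "freshoranges": 0.06}
print(top_k(example, 2))      # [('rottenapples', 0.71), ('rottenbanana', 0.12)]
print(is_confident(example))  # True
```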
## Use Cases
- Quality control in fruit sorting facilities
- Smart grocery shopping apps
- Food waste reduction systems
- Educational tools for agriculture
- Retail inventory management
## Technical Details
- **Input Size**: 224×224 pixels (automatically resized)
- **Inference Speed**: ~50-100ms per image (GPU), ~200-500ms (CPU)
- **Memory Usage**: ~500MB (model weights)
- **Training**: No training required for inference
## Future Improvements
- [ ] Add Gradio/Streamlit demo for web interface
- [ ] Support for additional fruit types
- [ ] Regression head for continuous ripeness percentage
- [ ] Model quantization for mobile deployment
- [ ] Multi-fruit detection in single images
## License
MIT License - See LICENSE file for details
## Author
**Meetesh Nagrecha**
## Acknowledgments
- Base model: `google/vit-base-patch16-224`
- Framework: Hugging Face Transformers & Scikit-learn
---
**Citation**
If you use this model in your research, please cite:
```bibtex
@misc{vit-fruit-ripeness-classifier,
  author       = {Nagrecha, Meetesh},
  title        = {ViT Fruit Ripeness Classifier},
  year         = {2025},
  publisher    = {Hugging Face},
  howpublished = {\url{https://huggingface.co/Meeteshn/vit-fruit-ripeness-classifier}}
}
```