---
license: mit
language:
- en
metrics:
- accuracy
base_model:
- google/vit-base-patch16-224
pipeline_tag: image-classification
tags:
- biology
---
# ViT Fruit Ripeness Classifier
A fruit ripeness classification model that combines **Vision Transformer (ViT)** feature extraction with a lightweight **Logistic Regression** classifier. Inference runs on either CPU or GPU.
## Model Description
This model identifies the fruit (**apples, bananas, or oranges**) and classifies its ripeness into one of three categories:
- **Fresh** - Ready to eat
- **Unripe** - Needs more time to ripen
- **Rotten** - Past optimal consumption
### Architecture
- **Feature Extractor**: ViT Base Patch-16 (`google/vit-base-patch16-224`)
- **Classifier**: Scikit-learn Logistic Regression
- **Feature Dimension**: 768-dim pooled output from ViT
- **Total Classes**: 9 (3 fruits × 3 ripeness states)
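To make the 768-dim feature size concrete, the sketch below builds a randomly initialized ViT using the `transformers` `ViTConfig` defaults, which match ViT-Base/16 (224×224 input, 16×16 patches, hidden size 768), and prints the output shapes. Nothing is downloaded and the weights are random, so only the shapes are meaningful; the encoder depth is reduced to 2 layers purely for speed, which does not change any shape.

```python
import torch
from transformers import ViTConfig, ViTModel

# ViTConfig defaults match ViT-Base/16: 224x224 images, 16x16 patches, hidden size 768.
# Only 2 encoder layers to keep the demo fast; depth does not affect output shapes.
config = ViTConfig(num_hidden_layers=2)
model = ViTModel(config).eval()  # random weights, nothing is downloaded

pixel_values = torch.randn(1, config.num_channels, config.image_size, config.image_size)
with torch.no_grad():
    out = model(pixel_values=pixel_values)

num_patches = (config.image_size // config.patch_size) ** 2  # 14 * 14 patches
print(num_patches)                  # 196
print(out.last_hidden_state.shape)  # torch.Size([1, 197, 768]): [CLS] token + 196 patch tokens
print(out.pooler_output.shape)      # torch.Size([1, 768]): the pooled feature vector
```

In the inference code below, that pooled vector is what gets scaled and passed to the logistic regression classifier.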
## Supported Classes
| Class | Description |
|-------|-------------|
| `freshapples` | Fresh, ready-to-eat apples |
| `freshbanana` | Fresh, ripe bananas |
| `freshoranges` | Fresh, ripe oranges |
| `rottenapples` | Overripe/rotten apples |
| `rottenbanana` | Overripe/rotten bananas |
| `rottenoranges` | Overripe/rotten oranges |
| `unripe apple` | Unripe apples |
| `unripe banana` | Unripe bananas |
| `unripe orange` | Unripe oranges |
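The class names mix plural and singular forms and sometimes contain a space, so grouping results by fruit or by ripeness state takes a little normalization. The helper below is a hypothetical sketch (not part of the repo) that splits each of the nine labels above into a `(fruit, state)` pair; it assumes every label starts with `fresh`, `rotten`, or `unripe`.

```python
from typing import Tuple

STATES = ("fresh", "rotten", "unripe")
# Map the spellings used in the class names to one canonical fruit name
FRUIT_ALIASES = {
    "apple": "apple", "apples": "apple",
    "banana": "banana", "bananas": "banana",
    "orange": "orange", "oranges": "orange",
}


def split_label(label: str) -> Tuple[str, str]:
    """Split a class label such as 'freshapples' or 'unripe banana' into (fruit, state)."""
    text = label.strip().lower()
    for state in STATES:
        if text.startswith(state):
            fruit_part = text[len(state):].strip()
            if fruit_part in FRUIT_ALIASES:
                return FRUIT_ALIASES[fruit_part], state
    raise ValueError(f"Unrecognized class label: {label!r}")


print(split_label("freshapples"))    # ('apple', 'fresh')
print(split_label("unripe banana"))  # ('banana', 'unripe')
```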
## Quick Start
### Installation
```bash
pip install torch torchvision transformers scikit-learn pillow joblib numpy huggingface_hub
```
### First Run Notes
- **Automatic Downloads**: The model files (~350-400MB) download automatically on first run
- **No Manual Downloads**: You don't need to manually download any model files
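- **Internet Required**: Only for the first run; later runs can load the files from the local Hugging Face cache without a connection
- **Time**: The first run takes roughly 2-5 minutes while the files download; later runs skip the download and only need to load the cached models into memory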
### Single Image Inference
```python
import json
import warnings
from pathlib import Path

import joblib
import numpy as np
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import AutoImageProcessor, ViTModel

# ----------------- CONFIG -----------------
REPO_ID = "Meeteshn/vit_fruit_ripeness_classifier"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NESTED_FOLDER = "vit_fruit_ripeness_updated"  # the repo keeps its files in this nested folder
TOP_K = 5
# ------------------------------------------


def hf_download_try(repo_id: str, filename: str, nested_folder: str = NESTED_FOLDER) -> str:
    """
    Try to download `filename` from the repo root, then from `nested_folder/filename`.
    Returns the local path of the downloaded file or raises an informative error.
    """
    candidates = [filename, f"{nested_folder}/{filename}"]
    last_exc = None
    for f in candidates:
        try:
            print(f"Trying to download '{f}' from '{repo_id}'...")
            path = hf_hub_download(repo_id=repo_id, filename=f)
            print("Downloaded:", path)
            return path
        except Exception as e:
            print(f"Not found at '{f}': {e}")
            last_exc = e
    raise RuntimeError(f"Could not download '{filename}' from repo '{repo_id}'. Last error: {last_exc}")


def load_processor_and_backbone(repo_id: str, nested_folder: str = NESTED_FOLDER, device: str = DEVICE):
    """
    Try several likely subfolder locations for the processor and backbone.
    Returns (processor, backbone).
    """
    # (processor subfolder, matching backbone subfolder); "" means the repo root
    candidates = [
        ("processor", "vit_backbone"),
        (f"{nested_folder}/processor", f"{nested_folder}/vit_backbone"),
        ("", "vit_backbone"),
    ]
    last_exc = None
    for proc_sub, backbone_sub in candidates:
        try:
            proc_kwargs = {"subfolder": proc_sub} if proc_sub else {}
            print(f"Trying AutoImageProcessor from '{repo_id}' (subfolder='{proc_sub or 'root'}')")
            processor = AutoImageProcessor.from_pretrained(repo_id, use_fast=True, **proc_kwargs)
            try:
                print(f"Trying ViTModel.from_pretrained('{repo_id}', subfolder='{backbone_sub}')")
                backbone = ViTModel.from_pretrained(repo_id, subfolder=backbone_sub)
            except Exception as e_backbone:
                # Fallback: the backbone may live in the root-level 'vit_backbone' folder
                print(f"Backbone load failed for subfolder='{backbone_sub}': {e_backbone}. Trying root 'vit_backbone'.")
                backbone = ViTModel.from_pretrained(repo_id, subfolder="vit_backbone")
            backbone.to(device)
            backbone.eval()
            print(f"Loaded processor/backbone (processor subfolder='{proc_sub or 'root'}')")
            return processor, backbone
        except Exception as e:
            print(f"Load failed for processor subfolder='{proc_sub or 'root'}': {e}")
            last_exc = e

    # Last resort: the official ViT checkpoint. It may not contain the exact weights
    # (e.g. the pooler layer) used when the classifier was trained, so features and
    # predictions from this fallback may be unreliable.
    warnings.warn(
        f"Could not load processor/backbone from '{repo_id}' (last error: {last_exc}); "
        "falling back to 'google/vit-base-patch16-224'. Predictions may be unreliable."
    )
    processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224", use_fast=True)
    backbone = ViTModel.from_pretrained("google/vit-base-patch16-224")
    backbone.to(device)
    backbone.eval()
    return processor, backbone


# ----------------- Load assets -----------------
processor, backbone = load_processor_and_backbone(REPO_ID, nested_folder=NESTED_FOLDER, device=DEVICE)

# Download the scikit-learn artifacts (repo root first, then the nested folder)
scaler_path = hf_download_try(REPO_ID, "scaler.joblib", nested_folder=NESTED_FOLDER)
clf_path = hf_download_try(REPO_ID, "logistic_model.joblib", nested_folder=NESTED_FOLDER)
metadata_path = hf_download_try(REPO_ID, "metadata.json", nested_folder=NESTED_FOLDER)

scaler = joblib.load(scaler_path)
clf = joblib.load(clf_path)
metadata = json.loads(Path(metadata_path).read_text(encoding="utf-8"))
classes = metadata["classes"]  # assumed to be in the same order as the classifier's probability columns


# ----------------- Prediction function -----------------
def predict(image_path: str):
    """Predict the ripeness class of a single image.

    Returns (best_label, best_probability, {label: probability for every class}).
    """
    img = Image.open(image_path).convert("RGB")
    inputs = processor(images=img, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(DEVICE)
    with torch.no_grad():
        out = backbone(pixel_values=pixel_values, return_dict=True)
    # Use the pooled output if the backbone has a pooler, otherwise the [CLS] token state
    pooled = getattr(out, "pooler_output", None)
    if pooled is None:
        pooled = out.last_hidden_state[:, 0, :]
    feat = pooled.cpu().numpy()  # shape (1, 768)
    feat_scaled = scaler.transform(feat)

    if hasattr(clf, "predict_proba"):
        probs = clf.predict_proba(feat_scaled)[0]
    else:
        # Fallback for classifiers without predict_proba: softmax over decision scores
        dec = clf.decision_function(feat_scaled)[0]
        exp = np.exp(dec - np.max(dec))  # subtract the max to avoid overflow in exp
        probs = exp / exp.sum()

    idx = int(np.argmax(probs))
    return classes[idx], float(probs[idx]), {classes[i]: float(probs[i]) for i in range(len(classes))}


# ----------------- Example usage -----------------
if __name__ == "__main__":
    sample_image = "my_apple.jpg"  # change as needed
    label, prob, all_probs = predict(sample_image)
    print(f"Prediction: {label} ({prob*100:.2f}%)")
    print(f"\nTop {TOP_K} probabilities:")
    for cls, p in sorted(all_probs.items(), key=lambda x: -x[1])[:TOP_K]:
        print(f"  {cls}: {p*100:.2f}%")
```
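When the classifier has no `predict_proba`, `predict` converts `decision_function` scores into probabilities with a softmax, subtracting the largest score first. The standalone sketch below isolates that step: shifting every score by the same constant leaves the softmax unchanged, but keeps `np.exp` from overflowing on large scores. A scikit-learn `LogisticRegression` does provide `predict_proba`, so this path is only a fallback.

```python
import numpy as np


def stable_softmax(scores: np.ndarray) -> np.ndarray:
    """Softmax over a 1-D array of scores, shifted by the max to avoid overflow in np.exp."""
    scores = np.asarray(scores, dtype=np.float64)
    exp = np.exp(scores - np.max(scores))  # the largest term becomes exp(0) = 1
    return exp / exp.sum()


scores = np.array([2.0, 1.0, 0.1])
p1 = stable_softmax(scores)
p2 = stable_softmax(scores + 1000.0)  # a constant shift does not change the softmax
print(np.allclose(p1, p2))  # True
print(int(np.argmax(p1)))   # 0
```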
### Batch Prediction
```python
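import csv
from pathlib import Path

# Reuses `predict` from the single-image example above, so run that code first.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def batch_predict(folder_path: str, output_csv: str = "predictions.csv"):
    """Predict ripeness for every image in a folder (searched recursively) and write the results to a CSV file."""
    folder = Path(folder_path)
    with open(output_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["filename", "predicted_label", "probability"])
        for img_path in sorted(folder.rglob("*")):
            if not img_path.is_file() or img_path.suffix.lower() not in IMAGE_SUFFIXES:
                continue
            label, prob, _ = predict(str(img_path))
            # Store the path relative to `folder` so same-named files in different subfolders stay distinguishable
            writer.writerow([img_path.relative_to(folder).as_posix(), label, f"{prob*100:.2f}%"])
    print(f"Predictions saved to {output_csv}")


# Usage
batch_predict("path/to/images")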
```
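`batch_predict` runs one ViT forward pass per image. Because the image processor also accepts a list of images, several images can share a single forward pass, which is usually faster on a GPU. The sketch below is a hypothetical helper (`extract_features_batch` and its `batch_size` parameter are not part of the repo) that returns pooled features for a list of PIL images, using the same pooler/[CLS] choice as `predict`; it assumes you pass in the `processor` and `backbone` loaded in the single-image example.

```python
from typing import List, Sequence

import numpy as np
import torch
from PIL import Image


def extract_features_batch(
    images: Sequence[Image.Image],
    processor,
    backbone,
    device: str = "cpu",
    batch_size: int = 16,
) -> np.ndarray:
    """Return an (N, hidden_size) array of ViT features, running `batch_size` images per forward pass."""
    chunks: List[np.ndarray] = []
    for start in range(0, len(images), batch_size):
        batch = [img.convert("RGB") for img in images[start:start + batch_size]]
        pixel_values = processor(images=batch, return_tensors="pt")["pixel_values"].to(device)
        with torch.no_grad():
            out = backbone(pixel_values=pixel_values, return_dict=True)
        # Same choice as predict(): pooler output if present, else the [CLS] token state
        pooled = out.pooler_output if out.pooler_output is not None else out.last_hidden_state[:, 0, :]
        chunks.append(pooled.cpu().numpy())
    if not chunks:
        return np.empty((0, backbone.config.hidden_size), dtype=np.float32)
    return np.concatenate(chunks, axis=0)
```

With the objects from the single-image example, `clf.predict_proba(scaler.transform(extract_features_batch(imgs, processor, backbone, DEVICE)))` would then give one row of class probabilities per image.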
## Example Output
```
Prediction: rottenapples (71.24%)
Top 5 probabilities:
rottenapples: 71.24%
rottenbanana: 12.35%
freshapples: 6.12%
unripe apple: 4.89%
freshoranges: 2.31%
```
## Repository Structure
```
vit_fruit_ripeness_updated/
├── processor/               # AutoImageProcessor configuration
├── vit_backbone/            # ViT feature extractor weights
├── logistic_model.joblib    # Trained classifier
├── scaler.joblib            # Feature scaler
├── metadata.json            # Class labels and metadata
└── features_extracted.npz   # (Optional) Cached features
```
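The optional `features_extracted.npz` file is a NumPy archive. Its array names are not documented here, so rather than assume them, the hypothetical helper below just lists every array the archive contains with its shape and dtype. If you fetch the file (for example with `hf_download_try(REPO_ID, "features_extracted.npz")` from the single-image example), pass the returned path to it.

```python
from typing import Dict, Tuple

import numpy as np


def describe_npz(path: str) -> Dict[str, Tuple[Tuple[int, ...], str]]:
    """Return {array_name: (shape, dtype)} for every array stored in a .npz archive."""
    info: Dict[str, Tuple[Tuple[int, ...], str]] = {}
    # allow_pickle=False: never unpickle object arrays from a downloaded file
    with np.load(path, allow_pickle=False) as data:
        for name in data.files:
            arr = data[name]
            info[name] = (arr.shape, str(arr.dtype))
    return info
```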
## Performance Notes
- **Best Results**: Fruit is centered and clearly visible in the image
- **Works Well**: Smartphone photos, typical market/kitchen images
- **Challenges**: Heavy background clutter, extreme lighting conditions, unusual fruit varieties
- **Uncertainty Handling**: Use top-K probabilities to assess prediction confidence
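One way to act on the top-K probabilities is to flag predictions whose top probability is low or whose top two classes are close. The sketch below takes the `all_probs` dictionary returned by `predict`; the thresholds (`min_prob=0.6`, `min_margin=0.2`) are illustrative assumptions, not values from any evaluation of this model, so tune them on your own validation images.

```python
from typing import Dict, List, Tuple


def summarize_prediction(
    all_probs: Dict[str, float],
    top_k: int = 3,
    min_prob: float = 0.6,    # illustrative threshold, not tuned for this model
    min_margin: float = 0.2,  # illustrative minimum gap between the top two classes
) -> Tuple[List[Tuple[str, float]], bool]:
    """Return the top-k (label, probability) pairs and whether the prediction looks uncertain."""
    ranked = sorted(all_probs.items(), key=lambda item: item[1], reverse=True)
    if not ranked:
        raise ValueError("all_probs is empty")
    top = ranked[:top_k]
    best = top[0][1]
    runner_up = top[1][1] if len(top) > 1 else 0.0
    uncertain = best < min_prob or (best - runner_up) < min_margin
    return top, uncertain


example = {"rottenapples": 0.7124, "rottenbanana": 0.1235, "freshapples": 0.0612}
top, uncertain = summarize_prediction(example)
print(top[0], uncertain)  # ('rottenapples', 0.7124) False
```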
## Use Cases
- Quality control in fruit sorting facilities
- Smart grocery shopping apps
- Food waste reduction systems
- Educational tools for agriculture
- Retail inventory management
## Technical Details
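- **Input Size**: 224×224 pixels (images are resized automatically by the processor)
- **Inference Speed**: ~50-100 ms per image on GPU, ~200-500 ms on CPU (approximate; varies with hardware)
- **Memory Usage**: ~500 MB (model weights)
- **Training**: Not required for inference; the fitted scaler and classifier are downloaded with the model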
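The speed and memory figures above depend on hardware, library versions, and batch size. A rough way to check them on your own machine is sketched below: `median_latency_ms` times repeated calls (for example `lambda: predict("my_apple.jpg")`) after a few warm-up runs, and `weight_memory_mib` adds up the bytes of a model's parameters and buffers (e.g. `weight_memory_mib(backbone)`). Both helpers are hypothetical, and the weight figure counts stored tensors only, not activations or framework overhead.

```python
import statistics
import time
from typing import Callable

import torch


def median_latency_ms(fn: Callable[[], object], warmup: int = 2, repeats: int = 10) -> float:
    """Median wall-clock time of fn() in milliseconds, after `warmup` untimed calls."""
    for _ in range(warmup):
        fn()  # warm-up: caches, lazy initialization, CUDA kernel setup
    timings = []
    for _ in range(repeats):
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # finish earlier GPU work so it is not counted here
        start = time.perf_counter()
        fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for queued GPU kernels before stopping the clock
        timings.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(timings)


def weight_memory_mib(model: torch.nn.Module) -> float:
    """Size of a model's parameters and buffers in MiB (2**20 bytes)."""
    n_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    n_bytes += sum(b.numel() * b.element_size() for b in model.buffers())
    return n_bytes / 2**20
```

For float32 weights, `weight_memory_mib` works out to about 4 bytes per parameter.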
## License
MIT License - See LICENSE file for details
## Author
**Meetesh Nagrecha**
## Acknowledgments
- Base model: `google/vit-base-patch16-224`
- Framework: Hugging Face Transformers & Scikit-learn
---
**Citation**
If you use this model in your research, please cite:
```bibtex
@misc{vit-fruit-ripeness-classifier,
author = {Nagrecha, Meetesh},
title = {ViT Fruit Ripeness Classifier},
year = {2024},
publisher = {Hugging Face},
howpublished = {\url{https://huggingface.co/Meeteshn/vit_fruit_ripeness_classifier}}
}
```