---
license: mit
language:
- en
metrics:
- accuracy
base_model:
- google/vit-base-patch16-224
pipeline_tag: image-classification
tags:
- biology
---

# ViT Fruit Ripeness Classifier

A fruit ripeness classification model combining **Vision Transformer (ViT)** feature extraction with **Logistic Regression** for fast, accurate inference on CPU or GPU.

## Model Description

This model classifies the ripeness condition of **apples, bananas, and oranges** into three categories:

- **Fresh** - Ready to eat
- **Unripe** - Needs more time to ripen
- **Rotten** - Past optimal consumption

### Architecture

- **Feature Extractor**: ViT Base Patch-16 (`google/vit-base-patch16-224`)
- **Classifier**: Scikit-learn Logistic Regression
- **Feature Dimension**: 768-dim pooled output from ViT
- **Total Classes**: 9 (3 fruits × 3 ripeness states)

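The two-stage design listed above can be sketched without downloading any weights. This is an illustration only, not the author's training script: random 768-dimensional vectors stand in for real ViT features, and the scaler is assumed to be a `StandardScaler`, which this model card does not state. The names `make_synthetic_features` and `head` are hypothetical.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

FEATURE_DIM = 768  # size of the pooled ViT output
NUM_CLASSES = 9    # 3 fruits x 3 ripeness states


def make_synthetic_features(samples_per_class=20, seed=0):
    """Random stand-ins for ViT features: one Gaussian cluster per class."""
    rng = np.random.default_rng(seed)
    centers = rng.normal(scale=3.0, size=(NUM_CLASSES, FEATURE_DIM))
    labels = np.repeat(np.arange(NUM_CLASSES), samples_per_class)
    features = centers[labels] + rng.normal(size=(labels.size, FEATURE_DIM))
    return features, labels


features, labels = make_synthetic_features()

# Assumed head: standardize each feature dimension, then fit logistic regression.
head = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
head.fit(features, labels)

probs = head.predict_proba(features[:1])  # one row, one probability per class
print(probs.shape)
```

Because only the small linear head is trained in a setup like this, the ViT backbone can stay frozen and features can be computed once and cached.
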
## Supported Classes

| Class | Description |
|-------|-------------|
| `freshapples` | Fresh, ready-to-eat apples |
| `freshbanana` | Fresh, ripe bananas |
| `freshoranges` | Fresh, ripe oranges |
| `rottenapples` | Overripe/rotten apples |
| `rottenbanana` | Overripe/rotten bananas |
| `rottenoranges` | Overripe/rotten oranges |
| `unripe apple` | Unripe apples |
| `unripe banana` | Unripe bananas |
| `unripe orange` | Unripe oranges |

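The class labels in the table are not uniformly formatted: the fresh and rotten labels are single words (some plural), while the unripe labels contain a space. If downstream code needs the fruit and the ripeness state separately, a small parser along these lines may help. `split_label` is a hypothetical helper, not something shipped with the model.

```python
FRUITS = ("apple", "banana", "orange")
STATES = ("fresh", "rotten", "unripe")


def split_label(label: str) -> tuple[str, str]:
    """Split a label such as 'freshapples' or 'unripe apple' into (state, fruit)."""
    compact = label.replace(" ", "").lower()
    for state in STATES:
        if compact.startswith(state):
            rest = compact[len(state):]
            for fruit in FRUITS:
                # startswith() tolerates the plural "s" on some labels
                if rest.startswith(fruit):
                    return state, fruit
    raise ValueError(f"Unrecognized class label: {label!r}")


for name in ("freshapples", "rottenbanana", "unripe orange"):
    print(name, "->", split_label(name))
```
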
## Quick Start

### Installation

```bash
pip install torch torchvision transformers scikit-learn pillow joblib numpy huggingface_hub
```

### First Run Notes

- **Automatic Downloads**: The model files (~350-400MB) download automatically on first run
- **No Manual Downloads**: You don't need to manually download any model files
- **Internet Required**: Only for the first run; subsequent runs work offline using cached files
- **Time**: First run takes 2-5 minutes for downloads; later runs skip the download and load from the local cache

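To make later runs rely only on the cache, one option is the `HF_HUB_OFFLINE` environment variable, which `huggingface_hub` reads when it is imported. A minimal sketch, assuming the first run has already populated the cache:

```python
import os

# Set this before importing huggingface_hub or transformers,
# because the value is read at import time.
os.environ["HF_HUB_OFFLINE"] = "1"

print("HF_HUB_OFFLINE =", os.environ["HF_HUB_OFFLINE"])
```

With this set, a file that is missing from the cache should fail with an error rather than trigger a download, which makes an incomplete cache easy to spot.
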
### Single Image Inference

```python
import json
import joblib
from pathlib import Path
from PIL import Image
import torch
import numpy as np
from huggingface_hub import hf_hub_download
from transformers import AutoImageProcessor, ViTModel
import warnings

# ----------------- CONFIG -----------------
REPO_ID = "Meeteshn/vit_fruit_ripeness_classifier"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NESTED_FOLDER = "vit_fruit_ripeness_updated"  # your repo uses this nested folder
TOP_K = 5
# ------------------------------------------


def hf_download_try(repo_id: str, filename: str, nested_folder: str = NESTED_FOLDER):
    """
    Try to download `filename` from repo root, then from nested_folder/filename.
    Returns local path to downloaded file or raises an informative error.
    """
    candidates = [filename, f"{nested_folder}/{filename}"]
    last_exc = None
    for f in candidates:
        try:
            print(f"Trying to download '{f}' from '{repo_id}'...")
            path = hf_hub_download(repo_id=repo_id, filename=f)
            print("Downloaded:", path)
            return path
        except Exception as e:
            print(f"Not found at '{f}': {e}")
            last_exc = e
    raise RuntimeError(f"Could not download '{filename}' from repo '{repo_id}'. Last error: {last_exc}")


def load_processor_and_backbone(repo_id: str, nested_folder: str = NESTED_FOLDER, device: str = DEVICE):
    """
    Try several likely subfolder locations for processor/backbone.
    Returns (processor, backbone).
    """
    # candidate subfolders for processor
    proc_candidates = [
        "processor",
        f"{nested_folder}/processor",
        "",  # no subfolder (root)
    ]
    last_exc = None
    for sub in proc_candidates:
        try:
            if sub == "":
                print(f"Trying AutoImageProcessor.from_pretrained('{repo_id}')")
                processor = AutoImageProcessor.from_pretrained(repo_id, use_fast=True)
            else:
                print(f"Trying AutoImageProcessor.from_pretrained('{repo_id}', subfolder='{sub}')")
                processor = AutoImageProcessor.from_pretrained(repo_id, subfolder=sub, use_fast=True)

            # now try backbone in the "vit_backbone" folder next to the processor
            backbone_sub = sub.replace("processor", "vit_backbone") if sub else "vit_backbone"
            try:
                print(f"Trying ViTModel.from_pretrained('{repo_id}', subfolder='{backbone_sub}')")
                backbone = ViTModel.from_pretrained(repo_id, subfolder=backbone_sub)
            except Exception as e_backbone:
                # final fallback: try root vit_backbone
                print(f"Backbone attempt failed for sub='{backbone_sub}': {e_backbone}. Trying root 'vit_backbone'.")
                backbone = ViTModel.from_pretrained(repo_id, subfolder="vit_backbone")

            backbone.to(device)
            backbone.eval()
            print(f"Loaded processor/backbone from subfolder='{sub or 'root'}'")
            return processor, backbone
        except Exception as e:
            print(f"Processor load failed for sub='{sub}': {e}")
            last_exc = e

    # ultimate fallback: official ViT from hub
    warnings.warn(
        "Could not load processor/backbone from repo; falling back to official "
        f"'google/vit-base-patch16-224'. Last error: {last_exc}"
    )
    processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224", use_fast=True)
    backbone = ViTModel.from_pretrained("google/vit-base-patch16-224")
    backbone.to(device)
    backbone.eval()
    return processor, backbone


# ----------------- Load assets (robust) -----------------
processor, backbone = load_processor_and_backbone(REPO_ID, nested_folder=NESTED_FOLDER, device=DEVICE)

# Download sklearn artifacts (try root then nested)
scaler_path = hf_download_try(REPO_ID, "scaler.joblib", nested_folder=NESTED_FOLDER)
clf_path = hf_download_try(REPO_ID, "logistic_model.joblib", nested_folder=NESTED_FOLDER)
metadata_path = hf_download_try(REPO_ID, "metadata.json", nested_folder=NESTED_FOLDER)

scaler = joblib.load(scaler_path)
clf = joblib.load(clf_path)
metadata = json.loads(Path(metadata_path).read_text(encoding="utf-8"))
classes = metadata["classes"]


# ----------------- Prediction function -----------------
def predict(image_path: str):
    """Predict ripeness condition for a single image."""
    img = Image.open(image_path).convert("RGB")
    inputs = processor(images=img, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(DEVICE)

    with torch.no_grad():
        out = backbone(pixel_values=pixel_values, return_dict=True)
        pooled = getattr(out, "pooler_output", None)
        if pooled is None:
            # no pooling layer: use the [CLS] token embedding instead
            pooled = out.last_hidden_state[:, 0, :]
        feat = pooled.cpu().numpy()

    feat_scaled = scaler.transform(feat)
    # get probabilities (works for sklearn logistic / classifiers with predict_proba)
    if hasattr(clf, "predict_proba"):
        probs = clf.predict_proba(feat_scaled)[0]
    else:
        # fallback for classifiers without predict_proba
        dec = clf.decision_function(feat_scaled)[0]
        exp = np.exp(dec - np.max(dec))
        probs = exp / exp.sum()

    # assumes metadata["classes"] is in the same order as the classifier's outputs
    idx = int(np.argmax(probs))
    return classes[idx], float(probs[idx]), {classes[i]: float(probs[i]) for i in range(len(classes))}


# ----------------- Example usage -----------------
if __name__ == "__main__":
    sample_image = "my_apple.jpg"  # change as needed
    label, prob, all_probs = predict(sample_image)
    print(f"Prediction: {label} ({prob*100:.2f}%)")
    print(f"\nTop {TOP_K} probabilities:")
    for cls, p in sorted(all_probs.items(), key=lambda x: -x[1])[:TOP_K]:
        print(f"  {cls}: {p*100:.2f}%")
```

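The fallback branch in `predict` turns raw decision scores into probabilities with a softmax, subtracting the maximum score first. The standalone sketch below shows why that subtraction matters; `softmax` is a hypothetical helper name, and the scores are made up for illustration.

```python
import numpy as np


def softmax(scores: np.ndarray) -> np.ndarray:
    """Convert a 1-D array of decision scores into values that sum to 1."""
    shifted = scores - np.max(scores)  # largest entry becomes 0, so exp() cannot overflow
    exp = np.exp(shifted)
    return exp / exp.sum()


# np.exp(1000.0) on its own overflows to inf; the shifted version stays finite.
scores = np.array([1000.0, 999.0, 998.0])
probs = softmax(scores)
print(probs)
```

Shifting every score by the same constant leaves the result unchanged, so the trick costs nothing in accuracy. Note that a softmax over decision scores is a convenient normalization and is not guaranteed to be a calibrated probability.
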
### Batch Prediction

```python
from pathlib import Path
import csv


def batch_predict(folder_path: str, output_csv: str = "predictions.csv"):
    """Predict ripeness for all images in a folder (searched recursively)."""
    folder = Path(folder_path)

    with open(output_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["filename", "predicted_label", "probability"])

        for img_path in sorted(folder.rglob("*")):
            if img_path.suffix.lower() not in [".jpg", ".jpeg", ".png", ".bmp"]:
                continue

            # predict() is the function defined in the single-image example
            label, prob, _ = predict(str(img_path))
            # write the path relative to the folder so files with the same
            # name in different subfolders stay distinguishable
            writer.writerow([str(img_path.relative_to(folder)), label, f"{prob*100:.2f}%"])

    print(f"Predictions saved to {output_csv}")


# Usage
batch_predict("path/to/images")
```

## Example Output

```
Prediction: rottenapples (71.24%)

Top 5 probabilities:
  rottenapples: 71.24%
  rottenbanana: 12.35%
  freshapples: 6.12%
  unripe apple: 4.89%
  freshoranges: 2.31%
```

## Repository Structure

```
vit_fruit_ripeness_updated/
├── processor/                 # AutoImageProcessor configuration
├── vit_backbone/              # ViT feature extractor weights
├── logistic_model.joblib      # Trained classifier
├── scaler.joblib              # Feature scaler
├── metadata.json              # Class labels and metadata
└── features_extracted.npz     # (Optional) Cached features
```

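If you keep a local copy of this folder, checking that the required entries are present before loading can save a confusing error later. The helper below is a hypothetical convenience, not part of the repository. It treats `features_extracted.npz` as optional, as the listing does, and demonstrates itself on a temporary directory.

```python
import tempfile
from pathlib import Path

# Entries the listing marks as required (features_extracted.npz is optional).
EXPECTED_ENTRIES = [
    "processor",
    "vit_backbone",
    "logistic_model.joblib",
    "scaler.joblib",
    "metadata.json",
]


def missing_entries(model_dir: str) -> list[str]:
    """Return the expected files/folders that are absent from model_dir."""
    root = Path(model_dir)
    return [name for name in EXPECTED_ENTRIES if not (root / name).exists()]


# Demo on a throwaway directory that contains only two of the entries.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "processor").mkdir()
    (Path(tmp) / "metadata.json").write_text("{}", encoding="utf-8")
    missing = missing_entries(tmp)

print("Missing:", missing)
```
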
## Performance Notes

- **Best Results**: Fruit is centered and clearly visible in the image
- **Works Well**: Smartphone photos, typical market/kitchen images
- **Challenges**: Heavy background clutter, extreme lighting conditions, unusual fruit varieties
- **Uncertainty Handling**: Use top-K probabilities to assess prediction confidence

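One way to act on the top-K probabilities is to look at both the top score and its margin over the runner-up. The sketch below takes the `all_probs` dictionary returned by `predict`; the function name `assess_confidence` and both thresholds are illustrative assumptions, not tuned values.

```python
def assess_confidence(all_probs: dict[str, float],
                      min_top1: float = 0.6,
                      min_margin: float = 0.2):
    """Return (top_label, margin, is_confident) from a label -> probability dict."""
    ranked = sorted(all_probs.items(), key=lambda item: item[1], reverse=True)
    (top_label, top_p), (_, second_p) = ranked[0], ranked[1]
    margin = top_p - second_p  # gap between the best and second-best class
    is_confident = top_p >= min_top1 and margin >= min_margin
    return top_label, margin, is_confident


# A clear prediction and an ambiguous one (values chosen for illustration).
clear = {"rottenapples": 0.7124, "rottenbanana": 0.1235, "freshapples": 0.0612}
ambiguous = {"freshbanana": 0.41, "unripe banana": 0.39, "rottenbanana": 0.20}

print(assess_confidence(clear))
print(assess_confidence(ambiguous))
```

Predictions flagged as not confident could be routed to manual review or retaken with a cleaner background.
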
## Use Cases

- Quality control in fruit sorting facilities
- Smart grocery shopping apps
- Food waste reduction systems
- Educational tools for agriculture
- Retail inventory management

## Technical Details

- **Input Size**: 224×224 pixels (automatically resized)
- **Inference Speed**: ~50-100ms per image (GPU), ~200-500ms (CPU)
- **Memory Usage**: ~500MB (model weights)
- **Training**: No training required for inference

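Latency depends heavily on hardware, so it is worth measuring on your own machine. The sketch below times any callable with a few warm-up calls and reports the median. `time_inference` is a hypothetical helper, and `fake_predict` is a stand-in that only sleeps; in practice you would pass the `predict` function from the Quick Start section.

```python
import statistics
import time


def time_inference(predict_fn, image_path, warmup=2, runs=10):
    """Return the median wall-clock latency of predict_fn(image_path) in milliseconds."""
    for _ in range(warmup):
        predict_fn(image_path)  # warm-up calls are not timed
    latencies_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        predict_fn(image_path)
        latencies_ms.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(latencies_ms)


def fake_predict(image_path):
    """Stand-in for predict(): pretends inference takes about 10 ms."""
    time.sleep(0.01)
    return "freshapples", 1.0, {}


median_ms = time_inference(fake_predict, "my_apple.jpg")
print(f"Median latency: {median_ms:.1f} ms")
```

The median is used rather than the mean so that an occasional slow call does not dominate the result.
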
## License

MIT License - See LICENSE file for details

## Author

**Meetesh Nagrecha**

## Acknowledgments

- Base model: `google/vit-base-patch16-224`
- Framework: Hugging Face Transformers & Scikit-learn

---

**Citation**

If you use this model in your research, please cite:

```bibtex
@misc{vit-fruit-ripeness-classifier,
  author = {Nagrecha, Meetesh},
  title = {ViT Fruit Ripeness Classifier},
  year = {2024},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/Meeteshn/vit_fruit_ripeness_classifier}}
}
```