---
license: mit
datasets:
- PedroSampaio/fruits-360
language:
- en
base_model:
- google/efficientnet-b0
pipeline_tag: image-classification
tags:
- pytorch
- torchvision
- efficientnet
- image-classification
- fruits
- fruits-360
- transfer-learning
- neptune-ai
widget:
- src: https://images.unsplash.com/photo-1573246123790-a64e870b8b1a?ixlib=rb-1.2.1&auto=format&fit=crop&w=640
  example_title: Apple Example
- src: https://images.unsplash.com/photo-1528825871115-3581a5377919?ixlib=rb-1.2.1&auto=format&fit=crop&w=640
  example_title: Banana Example
---

[DEMO APP](https://huggingface.co/spaces/bhumong/fruit-classifier-app)

# Fruit Classifier - EfficientNet-B0 (Fruits-360 Merged)

This repository contains a fruit image classification model built by fine-tuning an **EfficientNet-B0** model with PyTorch and torchvision. The model was trained on the **Fruits-360 dataset**, modified so that specific fruit variants were merged into broader categories. For example, "Apple Red 1" and "Apple 6" were merged into "Apple". The result is **76** distinct classes.

Training progress and metrics were tracked with **Neptune.ai**.

## Model Description

* **Architecture:** EfficientNet-B0 (pre-trained on ImageNet)
* **Fine-tuning Strategy:** Transfer learning. The pre-trained base model's weights were frozen. The final classifier layer was replaced, and only that layer was trained on the target dataset.
* **Framework:** PyTorch / torchvision
* **Task:** Image Classification
* **Dataset:** Fruits-360 (Merged Classes)
* **Number of Classes:** 76

## Intended Uses & Limitations

* **Intended Use:** Classifying images of fruits that belong to one of the 76 merged categories derived from the Fruits-360 dataset. Suitable for educational purposes, demonstrations, or as a baseline for further development.
* **Limitations:**
    * Trained *only* on the Fruits-360 dataset. Performance is not guaranteed on images that differ significantly from this dataset (e.g., different lighting, backgrounds, occlusions, or fruit varieties not present in it).
    * Only recognizes the 76 merged classes it was trained on.
    * Performance may vary with input image quality.
    * Not intended for safety-critical applications without rigorous testing and validation.

## How to Use

You can load the model and its configuration directly from the Hugging Face Hub using `torch`, `torchvision`, and `huggingface_hub`.

```python
import json
import os

import torch
import torchvision.models as models
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms


# --- 1. Define Model Loading Function ---
def load_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename="config.json"):
    """Loads the model state_dict and config from the Hugging Face Hub."""
    # Download the config file
    config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
    with open(config_path, "r") as f:
        config = json.load(f)
    num_labels = config["num_labels"]
    id2label = config["id2label"]  # Label mapping; JSON object keys are strings ("0", "1", ...)

    # Instantiate the EfficientNet-B0 architecture without pre-trained weights,
    # because the fine-tuned weights are loaded below
    model = models.efficientnet_b0(weights=None)

    # Replace the classifier head to match the number of classes used during training
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(num_ftrs, num_labels)

    # Download the model weights
    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)

    # Load the state dict; map_location lets GPU-saved weights load on a CPU-only machine
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()  # Set to evaluation mode
    print(f"Model loaded successfully from {repo_id} and set to evaluation mode.")
    return model, config, id2label


# --- 2. Define Preprocessing ---
# Use the same transformations as the validation pipeline during training
IMG_SIZE = (224, 224)  # Standard EfficientNet-B0 input size
# ImageNet normalization statistics, as used for EfficientNet pre-training
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

preprocess = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# --- 3. Load Model ---
repo_id_to_load = "Bhumong/fruit-classifier-efficientnet-b0"  # Hub repo ID of this model
model, config, id2label = load_model_from_hf(repo_id_to_load)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# --- 4. Prepare Input Image ---
# Replace with the path to your own image
image_path = "path/to/your/fruit_image.jpg"

if not os.path.exists(image_path):
    print(f"Warning: Image path not found: {image_path}")
    print("Skipping prediction. Please provide a valid image path.")
    input_batch = None
else:
    try:
        img = Image.open(image_path).convert("RGB")
        input_tensor = preprocess(img)
        # Add a batch dimension (the model expects batched input) and move it to the device
        input_batch = input_tensor.unsqueeze(0).to(device)
    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        input_batch = None

# --- 5. Make Prediction ---
if input_batch is not None:
    with torch.no_grad():  # Disable gradient tracking for inference
        output = model(input_batch)
    # Convert the logits of the single image in the batch to probabilities
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top_prob, top_catid = torch.max(probabilities, dim=0)
    predicted_label_index = top_catid.item()
    # Look up the label name; id2label keys are strings after JSON loading
    predicted_label = id2label.get(str(predicted_label_index), "Unknown Label")
    confidence = top_prob.item()

    print(f"\nPrediction for: {os.path.basename(image_path)}")
    print(f"Predicted Label Index: {predicted_label_index}")
    print(f"Predicted Label: {predicted_label}")
    print(f"Confidence: {confidence:.4f}")
```
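The prediction step reports only the single most likely class. Some fruits could plausibly fall into several merged categories, so it can help to inspect the top few candidates as well. The following is a minimal sketch using `torch.topk`. The helper name `top_k_predictions` is hypothetical and is not part of this repository. The sketch assumes that `id2label` has string keys, as it does after `config.json` is loaded.

```python
import torch


def top_k_predictions(logits, id2label, k=3):
    """Hypothetical helper: return the k most likely (label, probability) pairs.

    `logits` is the raw model output for a single image (shape [num_classes]);
    `id2label` maps string class indices ("0", "1", ...) to label names.
    """
    probabilities = torch.softmax(logits, dim=0)
    k = min(k, probabilities.numel())  # never request more classes than exist
    top_probs, top_ids = torch.topk(probabilities, k)  # sorted, most likely first
    return [
        (id2label.get(str(idx), "Unknown Label"), prob)
        for idx, prob in zip(top_ids.tolist(), top_probs.tolist())
    ]


# Usage with `output` and `id2label` from the example above (assumed to exist):
# for label, prob in top_k_predictions(output[0], id2label, k=5):
#     print(f"{label}: {prob:.4f}")
```

When the top two probabilities are close, the image may sit between two merged categories, or it may be outside the training distribution altogether.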
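This card states that Fruits-360 variant folders such as "Apple Red 1" and "Apple 6" were merged into a single "Apple" class. It does not document the exact merge rule. The sketch below assumes a hypothetical rule that keeps only the first word of each folder name. Its only purpose is to show how a merged `id2label` mapping, like the one stored in `config.json`, could be derived. The rule actually used for this model may differ. Note that a first-word rule would also mis-merge multi-word fruit names.

```python
def merge_label(folder_name):
    """Hypothetical merge rule: keep the first word ("Apple Red 1" -> "Apple")."""
    return folder_name.split()[0]


def build_merged_mapping(folder_names):
    """Map each original folder name to a merged class index, and build id2label.

    id2label uses string keys to mirror how a JSON config stores them.
    """
    merged_classes = sorted({merge_label(name) for name in folder_names})
    label2id = {label: idx for idx, label in enumerate(merged_classes)}
    folder_to_id = {name: label2id[merge_label(name)] for name in folder_names}
    id2label = {str(idx): label for label, idx in label2id.items()}
    return folder_to_id, id2label


folders = ["Apple Red 1", "Apple 6", "Banana", "Kiwi"]
folder_to_id, id2label = build_merged_mapping(folders)
print(id2label)  # {'0': 'Apple', '1': 'Banana', '2': 'Kiwi'}
```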