Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import model_builder as mb | |
| from torchvision import transforms | |
| import torch | |
# Nutrition facts per 165 g serving for each *fresh* fruit class.
# Keys are the class names produced by the classifier ("Fresh Apple", ...);
# format_nutrition returns an empty string for any name not present here, so
# only the fresh classes are listed.
# Units, as labelled by format_nutrition: calories in kcal; protein, carbs,
# fats and fiber in g; water in ml; every micronutrient and macromineral in mg.
# NOTE(review): the data source is not recorded. Some apple figures look off
# compared with common reference tables (e.g. vitamin_c 96.7 mg per 165 g is
# far above typical raw-apple values of roughly 5 mg/100 g; fiber 1.5 g looks
# low) — verify against the original source before relying on them.
NUTRITION_DATA = {
    'Fresh Apple': {
        'macronutrients': {
            'calories': 99.2,
            'protein': 0.8,
            'carbs': 23.3,
            'fats': 0.3,
            'water': 140.2,
            'fiber': 1.5
        },
        'micronutrients': {
            'vitamin_c': 96.7,
            'thiamin': 0.1,
            'niacin': 0.4,
            'vitamin_b6': 0.2
        },
        'macrominerals': {
            'magnesium': 22.1,
            'phosphorus': 8.9,
            'potassium': 226.0,
            'calcium': 20.6
        }
    },
    'Fresh Banana': {
        'macronutrients': {
            'calories': 147.0,
            'protein': 1.8,
            'carbs': 38.0,
            'fats': 0.5,
            'water': 132.0,
            'fiber': 3.5
        },
        'micronutrients': {
            'vitamin_c': 14.7,
            'thiamin': 0.4,
            'niacin': 1.2,
            'vitamin_b6': 0.5
        },
        'macrominerals': {
            'magnesium': 41.3,
            'phosphorus': 33.0,
            'potassium': 537.0,
            'calcium': 8.3
        }
    },
    'Fresh Orange': {
        'macronutrients': {
            'calories': 82.0,
            'protein': 1.6,
            'carbs': 21.0,
            'fats': 0.2,
            'water': 146.0,
            'fiber': 4.0
        },
        'micronutrients': {
            'vitamin_c': 82.7,
            'thiamin': 0.2,
            'niacin': 0.5,
            'vitamin_b6': 0.1
        },
        'macrominerals': {
            'magnesium': 18.2,
            'phosphorus': 28.1,
            'potassium': 237.6,
            'calcium': 74.3
        }
    }
}
# All inference runs on the CPU.
device = torch.device("cpu")

# Standard ImageNet per-channel mean/std (presumably matching the preprocessing
# the pretrained EfficientNet-B0 backbone was trained with — confirm).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Preprocessing for every uploaded image:
# image array -> PIL image -> 224x224 -> float tensor -> normalized tensor.
manual_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        normalize,
    ]
)

# Classifier output index i maps to class_names[i]; the order presumably has
# to match the label order used when the checkpoint was trained.
class_names = [
    "Fresh Apple",
    "Fresh Banana",
    "Fresh Orange",
    "Rotten Apple",
    "Rotten Banana",
    "Rotten Orange",
]

# Build the EfficientNet-B0 head sized to our label set, then restore the
# fine-tuned weights.
# NOTE(review): torch.load unpickles the file — only ever point it at a trusted
# checkpoint (weights_only=True is available on newer torch releases).
model = mb.create_model_baseline_effnetb0(
    out_feats=len(class_names),
    device=device,
)
model.load_state_dict(
    torch.load(
        f="models/effnetb0_freshvisionv0_10_epochs.pt",
        map_location="cpu",
    )
)
# Display layout for the nutrition block: one entry per section as
# (heading, key in the fruit's data, rows), each row being (label, key, unit).
_NUTRITION_SECTIONS = (
    (
        "Macronutrients",
        "macronutrients",
        (
            ("Calories", "calories", "kcal"),
            ("Protein", "protein", "g"),
            ("Carbs", "carbs", "g"),
            ("Fats", "fats", "g"),
            ("Water", "water", "ml"),
            ("Fiber", "fiber", "g"),
        ),
    ),
    (
        "Micronutrients",
        "micronutrients",
        (
            ("Vitamin C", "vitamin_c", "mg"),
            ("Thiamin", "thiamin", "mg"),
            ("Niacin", "niacin", "mg"),
            ("Vitamin B6", "vitamin_b6", "mg"),
        ),
    ),
    (
        "Macrominerals",
        "macrominerals",
        (
            ("Magnesium", "magnesium", "mg"),
            ("Phosphorus", "phosphorus", "mg"),
            ("Potassium", "potassium", "mg"),
            ("Calcium", "calcium", "mg"),
        ),
    ),
)


def format_nutrition(fruit_name, *, data=None):
    """Format the nutrition information of one fruit for display.

    Args:
        fruit_name: Key to look up, e.g. a predicted class name like
            "Fresh Apple".
        data: Optional mapping of fruit name -> nested nutrition dict with
            "macronutrients", "micronutrients" and "macrominerals" sections.
            Defaults to the module-level ``NUTRITION_DATA``.

    Returns:
        A multi-line string that starts and ends with a newline, listing every
        field of ``_NUTRITION_SECTIONS`` with its unit, or ``""`` when
        ``fruit_name`` is not in the table. The header always states a 165g
        serving.

    Raises:
        KeyError: if the fruit's entry lacks a section or field in the layout.
    """
    table = NUTRITION_DATA if data is None else data
    if fruit_name not in table:
        return ""
    nutrition = table[fruit_name]

    lines = ["Nutritional Information (per 165g serving):"]
    for heading, section_key, rows in _NUTRITION_SECTIONS:
        section = nutrition[section_key]
        lines.append(f"{heading}:")
        lines.extend(f"• {label}: {section[key]} {unit}" for label, key, unit in rows)
    # Leading/trailing newline keep the exact shape of the original
    # triple-quoted template.
    return "\n" + "\n".join(lines) + "\n"
def pred(img):
    """Classify a fruit image's freshness and, for fresh fruit, add nutrition.

    Args:
        img: The image delivered by the ``gr.Image`` input (presumably an
            HxWxC uint8 numpy array with Gradio's default settings — confirm),
            or ``None`` when the user submits without uploading anything.

    Returns:
        ``"Prediction: <class> | Confidence: <p>"`` with the softmax
        probability to 3 decimals. For ``"Fresh ..."`` predictions the output
        of ``format_nutrition`` is appended on a new line. When ``img`` is
        ``None`` a short instruction message is returned instead.
    """
    # Gradio hands us None for an empty submission; without this guard
    # ToPILImage raises and the UI only shows a generic error.
    if img is None:
        return "No image provided. Please upload an image of a single fruit."

    model.eval()
    batch = manual_transform(img).to(device).unsqueeze(dim=0)
    with torch.inference_mode():
        logits = model(batch)
        probs = torch.softmax(logits, dim=-1)
        # One pass gives both the top probability and its class index.
        confidence, class_idx = probs.max(dim=-1)

    predicted_class = class_names[class_idx.item()]
    result = f"Prediction: {predicted_class} | Confidence: {confidence.item():.3f}"

    # Nutrition facts are only meaningful for fresh fruit.
    if predicted_class.startswith("Fresh"):
        result += f"\n{format_nutrition(predicted_class)}"
    return result
# Page layout: a Markdown description of the model followed by the classifier
# interface. `demo` is the app object launched below.
demo = gr.Blocks()
with demo:
    gr.Markdown(
        """
# Welcome to FreshVision 📷

_FreshVision is a machine learning model that classifies the freshness of fruits. It
uses transfer learning from the pre-trained PyTorch [EfficientNetB0](https://pytorch.org/vision/stable/models/generated/torchvision.models.efficientnet_b0.html) model.
It was trained on a [Kaggle dataset](https://www.kaggle.com/datasets/sriramr/fruits-fresh-and-rotten-for-classification) using an NVIDIA T4 GPU._

## Model capabilities:
- Classifies the freshness of a fruit image (apple, orange, or banana) with two labels: *Fresh* and *Rotten/spoiled*
- Provides comprehensive nutritional information for fresh fruits, including:
  * Macronutrients (calories, protein, carbs, fats, water, fiber)
  * Micronutrients (vitamin C, vitamin B6, thiamin, niacin)
  * Macrominerals (magnesium, phosphorus, potassium, calcium)

## Model drawbacks:
- Sometimes makes incorrect predictions for certain fruit conditions, due to low variability in the image dataset
- Can't accurately classify images with multiple objects or mixed conditions (e.g. two bananas with different levels of freshness)
- Can't identify non-fruit objects, since it was trained only on a fruit dataset

## **How to get the best results with this model:**
1. The image should only contain fruits the model can recognize (apple, orange, or banana)
2. The image should contain only one object (one fruit)
3. Make sure the fruit is well lit so that its surface is clearly visible

Get the [source code](https://github.com/devdezzies/freshvision) on my GitHub.
"""
    )
    # `pred` receives the uploaded image as a numpy array (stated explicitly
    # here because manual_transform starts with ToPILImage) and its returned
    # text fills the results box.
    gr.Interface(
        fn=pred,
        inputs=gr.Image(type="numpy"),
        outputs=gr.Textbox(label="Prediction Results", lines=15),
        title="FreshVision Fruit Classifier",
    )

if __name__ == "__main__":
    demo.launch()