Spaces:
Runtime error
Runtime error
import os

import gradio as gr
import numpy as np
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
# Download the fine-tuned checkpoint from the Hugging Face Hub.
# hf_hub_download returns the local filesystem path of the downloaded file,
# which FruitVegetableClassifier loads below.
_MODEL_REPO_ID = "ahmadalfian/fruits_vegetables_classifier"
_MODEL_FILENAME = "resnet50_finetuned.pth"
model_path = hf_hub_download(
    repo_id=_MODEL_REPO_ID,
    filename=_MODEL_FILENAME,
)
# Model definition
class FruitVegetableClassifier(torch.nn.Module):
    """Inference-only wrapper around the fine-tuned checkpoint.

    The checkpoint must be a pickled ``torch.nn.Module`` (a complete model).
    A bare ``state_dict`` only holds tensors, and there is no architecture in
    this class to load it into, so it is rejected with a clear error.
    """

    def __init__(self, checkpoint_path=None):
        """Load the checkpoint onto the CPU and put it in evaluation mode.

        Args:
            checkpoint_path: Path to the ``.pth`` file. Defaults to the
                module-level ``model_path`` downloaded from the Hub.

        Raises:
            TypeError: If the file does not contain a ``torch.nn.Module``.
        """
        super().__init__()
        path = model_path if checkpoint_path is None else checkpoint_path
        # weights_only=False is needed to unpickle a whole nn.Module; newer
        # torch releases default it to True and refuse such files. It runs
        # arbitrary pickled code, so only load checkpoints from a trusted repo.
        loaded = torch.load(
            path, map_location=torch.device("cpu"), weights_only=False
        )
        if not isinstance(loaded, torch.nn.Module):
            raise TypeError(
                f"{path!r} contains a {type(loaded).__name__}, not a "
                "torch.nn.Module; a state_dict must be loaded into an "
                "explicitly constructed architecture with load_state_dict()."
            )
        self.model = loaded
        # eval() recurses into self.model, disabling dropout / batch-norm
        # statistic updates for inference.
        self.eval()

    def forward(self, x):
        """Return the wrapped model's raw output for the input batch ``x``."""
        return self.model(x)
# Build the classifier once at import time; predict() runs this shared
# module-level instance for every request.
model = FruitVegetableClassifier()
# Class prediction
def predict(
    image,
    classifier=None,
    *,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Return the index of the most likely class for an image.

    Processing steps:

    1. Convert the image to RGB and resize it to 224x224.
    2. Scale pixel values to [0, 1].
    3. Normalize each channel with ``mean`` / ``std``.
    4. Run the result through the classifier as a (1, 3, 224, 224) float32 batch.

    Args:
        image: A PIL image (any object with ``convert``/``resize`` that
            ``numpy.array`` turns into an (H, W, 3) array).
        classifier: Module to run. Defaults to the module-level ``model``.
        mean: Per-channel normalization mean.
        std: Per-channel normalization standard deviation. The defaults for
            ``mean`` and ``std`` are the ImageNet statistics used by standard
            ResNet-50 preprocessing. This assumes the checkpoint was
            fine-tuned with them; confirm against the training code.

    Returns:
        int: argmax index over the classifier's output scores.
    """
    net = model if classifier is None else classifier
    rgb = image.convert("RGB").resize((224, 224))
    # (224, 224, 3) uint8 -> float32 in [0, 1], then HWC -> CHW for torch.
    pixels = np.array(rgb).astype(np.float32) / 255.0
    chw = torch.from_numpy(pixels).permute(2, 0, 1)
    mean_t = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=torch.float32).view(-1, 1, 1)
    batch = ((chw - mean_t) / std_t).unsqueeze(0)
    with torch.no_grad():
        scores = net(batch)
    return scores.argmax(dim=1).item()
# Nutrition lookup
# Nutrients reported to the caller, mapped to the unit each value must be in
# (matching the kcal / g / mg labels the UI shows). Some FoodData Central
# records may list "Energy" in both kcal and kJ, so entries in any other unit
# are skipped instead of being summed together.
_NUTRIENT_UNITS = {
    "Energy": "KCAL",
    "Carbohydrate, by difference": "G",
    "Fiber, total dietary": "G",
    "Vitamin C, total ascorbic acid": "MG",
}
_FDC_SEARCH_URL = "https://api.nal.usda.gov/fdc/v1/foods/search"


def get_nutritional_info(food, api_key=None, *, page_size=5, session=None, timeout=10):
    """Average selected nutrients over the top USDA FoodData Central search hits.

    Args:
        food: Free-text search query (e.g. a predicted class label).
        api_key: FDC API key. If omitted, the ``USDA_API_KEY`` environment
            variable is used, falling back to api.data.gov's rate-limited
            ``DEMO_KEY``. Keys must not be committed to source code.
        page_size: Number of search results to average over.
        session: Object with a ``requests``-style ``get`` method. Defaults to
            the ``requests`` module; a ``requests.Session`` also works.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        dict or None: If the search returns no foods, None. Otherwise a dict
        with one entry per key of ``_NUTRIENT_UNITS``. Each value is the mean
        over the results that report that nutrient in the expected unit, or
        0.0 when none do.

    Raises:
        requests.HTTPError: On a non-2xx response (e.g. invalid key, rate limit).
        requests.RequestException: On network failures or timeouts.
    """
    if api_key is None:
        api_key = os.environ.get("USDA_API_KEY") or "DEMO_KEY"
    http = requests if session is None else session
    params = {"query": food, "pageSize": page_size, "api_key": api_key}
    response = http.get(_FDC_SEARCH_URL, params=params, timeout=timeout)
    response.raise_for_status()

    data = response.json()
    foods = data.get("foods") if isinstance(data, dict) else None
    if not foods:
        return None

    totals = dict.fromkeys(_NUTRIENT_UNITS, 0.0)
    counts = dict.fromkeys(_NUTRIENT_UNITS, 0)
    for item in foods:  # `item`, not `food`, to avoid shadowing the query
        for nutrient in item.get("foodNutrients") or []:
            name = nutrient.get("nutrientName")
            if name not in _NUTRIENT_UNITS:
                continue
            unit = nutrient.get("unitName")
            if unit is not None and unit.upper() != _NUTRIENT_UNITS[name]:
                continue
            value = nutrient.get("value")
            if not isinstance(value, (int, float)):
                continue
            totals[name] += value
            counts[name] += 1

    # Average only over foods that actually report the nutrient; a missing
    # entry means "not reported", not zero.
    return {
        name: totals[name] / counts[name] if counts[name] else 0.0
        for name in _NUTRIENT_UNITS
    }
# Main Gradio handler
# Class names in the classifier's output-index order (index i -> CLASS_LABELS[i]).
# Spellings such as "jalepeno" and "raddish" are kept so the displayed label is
# unchanged; presumably they mirror the training labels. Confirm before editing.
CLASS_LABELS = (
    "apple", "banana", "beetroot", "bell pepper", "cabbage", "capsicum",
    "carrot", "cauliflower", "chilli pepper", "corn", "cucumber",
    "eggplant", "garlic", "ginger", "grapes", "jalepeno", "kiwi",
    "lemon", "lettuce", "mango", "onion", "orange", "paprika",
    "pear", "peas", "pineapple", "pomegranate", "potato", "raddish",
    "soy beans", "spinach", "sweetcorn", "sweetpotato", "tomato",
    "turnip", "watermelon",
)

# Correctly spelled search terms for misspelled labels. These are used only
# for the USDA nutrition query, never for display.
_NUTRITION_QUERY_OVERRIDES = {"jalepeno": "jalapeno", "raddish": "radish"}


def classify_and_get_nutrition(image):
    """Classify an uploaded image and attach averaged USDA nutrition data.

    Args:
        image: PIL image from the Gradio input, or None if nothing was uploaded.

    Returns:
        dict: A JSON-serializable dict for the ``gr.JSON`` output. It always
        has "Predicted Class", plus either the four nutrient fields or
        "Nutritional Information": "Not Found". If no image was provided, the
        dict has a single "Error" key instead.
    """
    if image is None:
        # The Gradio Image input can deliver None when submitted empty;
        # predict() would otherwise fail on image.convert().
        return {"Error": "No image provided. Please upload an image of a fruit or vegetable."}

    predicted_label = CLASS_LABELS[predict(image)]
    query = _NUTRITION_QUERY_OVERRIDES.get(predicted_label, predicted_label)
    nutrition = get_nutritional_info(query)

    if nutrition is None:
        return {
            "Predicted Class": predicted_label,
            "Nutritional Information": "Not Found",
        }
    return {
        "Predicted Class": predicted_label,
        "Energy (kcal)": nutrition["Energy"],
        "Carbohydrates (g)": nutrition["Carbohydrate, by difference"],
        "Fiber (g)": nutrition["Fiber, total dietary"],
        "Vitamin C (mg)": nutrition["Vitamin C, total ascorbic acid"],
    }
# Gradio interface: image in, JSON (predicted class + nutrition) out.
iface = gr.Interface(
    fn=classify_and_get_nutrition,
    inputs=gr.Image(type="pil"),  # handler receives a PIL image (or None)
    outputs=gr.JSON(),
    title="Fruits and Vegetables Classifier",
    description="Upload an image of a fruit or vegetable to classify and get its nutritional information.",
)

# Launch only when run as a script, so importing this module (tests, tooling)
# does not start a web server.
if __name__ == "__main__":
    iface.launch()