# Upload metadata (pasted page residue, kept as a comment so the file parses):
# author BoTahex, commit 57d5757 "Upload app.py" (verified)
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os
# Estimated calories per serving for every class the model can predict
# (shown to the user as "~N kcal per serving"). The keys are also the model's
# class labels: CLASS_NAMES below is derived from them.
CALORIE_DATA = {
    "apple_pie": 237, "baby_back_ribs": 290, "baklava": 334,
    "beef_carpaccio": 121, "beef_tartare": 200, "beet_salad": 70,
    "beignets": 350, "bibimbap": 490, "bread_pudding": 260,
    "breakfast_burrito": 305, "bruschetta": 120, "caesar_salad": 180,
    "cannoli": 220, "caprese_salad": 250, "carrot_cake": 300,
    "ceviche": 130, "cheese_plate": 350, "cheesecake": 320,
    "chicken_curry": 240, "chicken_quesadilla": 330,
    "chicken_wings": 290, "chocolate_cake": 370, "chocolate_mousse": 210,
    "churros": 230, "clam_chowder": 170, "club_sandwich": 350,
    "crab_cakes": 220, "creme_brulee": 260, "croque_madame": 450,
    "cup_cakes": 305, "deviled_eggs": 130, "donuts": 250,
    "dumplings": 210, "edamame": 120, "eggs_benedict": 290,
    "escargots": 170, "falafel": 330, "filet_mignon": 280,
    "fish_and_chips": 590, "foie_gras": 460, "french_fries": 365,
    "french_onion_soup": 210, "french_toast": 260, "fried_calamari": 310,
    "fried_rice": 230, "frozen_yogurt": 160, "garlic_bread": 200,
    "gnocchi": 250, "greek_salad": 130, "grilled_cheese_sandwich": 370,
    "grilled_salmon": 350, "guacamole": 150, "gyoza": 200,
    "hamburger": 354, "hot_and_sour_soup": 90, "hot_dog": 290,
    "huevos_rancheros": 360, "hummus": 170, "ice_cream": 210,
    "lasagna": 290, "lobster_bisque": 240, "lobster_roll_sandwich": 290,
    "macaroni_and_cheese": 310, "macarons": 100, "miso_soup": 40,
    "mussels": 170, "nachos": 340, "omelette": 150,
    "onion_rings": 330, "oysters": 60, "pad_thai": 360,
    "paella": 310, "pancakes": 230, "panna_cotta": 340,
    "peking_duck": 330, "pho": 350, "pizza": 270,
    "pork_chop": 230, "poutine": 510, "prime_rib": 350,
    "pulled_pork_sandwich": 390, "ramen": 380, "ravioli": 220,
    "red_velvet_cake": 360, "risotto": 340, "samosa": 260,
    "sashimi": 130, "scallops": 110, "seaweed_salad": 70,
    "shrimp_and_grits": 280, "spaghetti_bolognese": 370,
    "spaghetti_carbonara": 390, "spring_rolls": 150,
    "steak": 270, "strawberry_shortcake": 280, "sushi": 200,
    "tacos": 210, "takoyaki": 170, "tiramisu": 290,
    "tuna_tartare": 180, "waffles": 290,
}

# Alphabetically sorted labels; index i is the class for output logit i of
# the classifier. NOTE(review): this assumes training used the same sorted
# label order (e.g. ImageFolder-style) — confirm against the training code.
CLASS_NAMES = sorted(CALORIE_DATA)
def load_model(model_path=None):
    """Build the ResNet-50 food classifier and load its trained weights.

    The architecture is rebuilt here (a ResNet-50 whose final fully connected
    layer is replaced by ``Dropout(0.4)`` + ``Linear`` over
    ``len(CLASS_NAMES)`` outputs) so that the saved ``state_dict`` matches it
    key-for-key; ``load_state_dict`` is strict by default.

    Args:
        model_path: Path to the saved ``state_dict``. Defaults to
            ``food_classifier_resnet50.pth`` in the same directory as this file.

    Returns:
        The model on CPU, switched to evaluation mode.
    """
    # Architecture only; all weights come from the checkpoint below.
    model = models.resnet50(weights=None)
    model.fc = nn.Sequential(
        nn.Dropout(0.4),
        nn.Linear(model.fc.in_features, len(CLASS_NAMES)),
    )

    if model_path is None:
        model_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "food_classifier_resnet50.pth",
        )

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot execute arbitrary code when loaded. A
    # plain state_dict (which is all load_state_dict needs) loads fine this way.
    state_dict = torch.load(
        model_path, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout for inference
    return model
# Inference preprocessing: resize to 256x256 (aspect ratio is not preserved),
# center-crop to 224x224, convert to a tensor, then normalize per channel with
# the usual ImageNet mean/std. Presumably this mirrors the evaluation
# transform used during training — confirm before changing any values.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Loaded once at import time; classify_food reuses this instance on every call.
model = load_model()
def _pretty_name(class_name):
    """Turn a snake_case class label into a display title ("hot_dog" -> "Hot Dog")."""
    return class_name.replace("_", " ").title()


def _format_details(ranked):
    """Build the Markdown details panel from ``(class_name, probability)`` pairs, best first."""
    top_class, top_prob = ranked[0]
    lines = [
        f"## {_pretty_name(top_class)}",
        f"**Confidence:** {top_prob * 100:.1f}%",
        "",
        f"**Estimated Calories:** ~{CALORIE_DATA.get(top_class, 'N/A')} kcal per serving",
        "",
        "---",
        "",
        "**Other possibilities:**",
        "",
        # A Markdown (GFM) table needs a header row and a delimiter row;
        # without them the pipe-separated lines are plain text joined into a
        # single paragraph.
        "| Food | Confidence | Calories |",
        "| --- | --- | --- |",
    ]
    for cls, prob in ranked[1:]:
        lines.append(
            f"| {_pretty_name(cls)} | {prob * 100:.1f}% | ~{CALORIE_DATA.get(cls, 'N/A')} kcal |"
        )
    return "\n".join(lines) + "\n"


def classify_food(image):
    """Classify a food photo and build both UI outputs.

    Args:
        image: The uploaded image as a numpy array (the ``gr.Image`` input is
            configured with ``type="numpy"``), or ``None`` when there is no image.

    Returns:
        A ``(confidences, details_markdown)`` tuple. ``confidences`` maps the
        display names of the 5 most probable classes to their softmax
        probabilities (consumed by ``gr.Label``). ``details_markdown`` names the
        top prediction with its confidence and calorie estimate, followed by a
        Markdown table of the other four. Both are empty when ``image`` is ``None``.
    """
    if image is None:
        return {}, ""

    img = Image.fromarray(image).convert("RGB")  # normalize grayscale/RGBA to 3 channels
    input_tensor = transform(img).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        outputs = model(input_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        top_probs, top_indices = torch.topk(probs, 5)

    # (class_label, probability) for the top 5, highest probability first.
    ranked = [
        (CLASS_NAMES[idx], prob)
        for idx, prob in zip(top_indices[0].tolist(), top_probs[0].tolist())
    ]
    confidences = {_pretty_name(cls): float(prob) for cls, prob in ranked}
    return confidences, _format_details(ranked)
# Page-level CSS: center the app in a 900px column, center the title, and
# center any element tagged with the "description" class.
custom_css = (
    "\n"
    ".gradio-container {\n"
    "max-width: 900px !important;\n"
    "margin: auto !important;\n"
    "}\n"
    "h1 {\n"
    "text-align: center;\n"
    "margin-bottom: 0.2em;\n"
    "}\n"
    ".description {\n"
    "text-align: center;\n"
    "}\n"
)
# Gradio "Soft" base theme with an orange/amber/gray palette and the Inter
# Google web font.
theme = gr.themes.Soft(
    primary_hue="orange",
    secondary_hue="amber",
    neutral_hue="gray",
    font=gr.themes.GoogleFont("Inter"),
)
# Two-column layout: upload + button on the left, predictions on the right,
# with a header above and a footer link below.
with gr.Blocks(theme=theme, css=custom_css, title="Food Image Classifier") as demo:
    gr.Markdown("# Food Image Classifier")
    gr.Markdown(
        "Upload a photo of any food — the model identifies it from **101 categories** and estimates calories.",
        elem_classes="description",
    )

    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            image_input = gr.Image(label="Upload Food Photo", type="numpy", height=350)
            classify_btn = gr.Button("Classify", variant="primary", size="lg")
        with gr.Column(scale=1):
            label_output = gr.Label(num_top_classes=5, label="Top 5 Predictions")
            calorie_output = gr.Markdown(label="Details")

    # Run the classifier on an explicit button click and also whenever the
    # image value changes (classify_food returns empty outputs for None).
    # Listeners are registered in this order: click first, then change.
    for _register in (classify_btn.click, image_input.change):
        _register(
            fn=classify_food,
            inputs=image_input,
            outputs=[label_output, calorie_output],
        )

    gr.Markdown("---")
    gr.Markdown(
        "<center><small>Trained on Food-101 dataset (101K images) — "
        "<a href='https://github.com/ahmedamr022/Food-Classification'>GitHub</a></small></center>",
        sanitize_html=False,
    )
# Start the web server only when this file is executed as a script
# (e.g. `python app.py`); merely importing the module no longer launches a
# blocking server.
if __name__ == "__main__":
    demo.launch()