# Food-Classifier / app.py
# Author: Lumia101
# Last commit: "Update app.py" (0e096eb, verified)
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.transforms import v2
from safetensors.torch import load_file
import json
import gradio as gr
# Label metadata for the classifier. "id2label" maps stringified class
# indices (predict() looks them up via str(index)) to human-readable names.
# JSON is UTF-8 by spec, so don't depend on the platform's locale encoding.
with open("config (1).json", encoding="utf-8") as f:
    config = json.load(f)
id2label = config["id2label"]
# EfficientNet-B0 backbone built without pretrained weights: the fine-tuned
# weights (backbone + custom head) are loaded from model.safetensors below.
model = models.efficientnet_b0(weights=None)

# Size the custom head from the backbone's stock classifier and from the
# label map instead of hard-coding 1280 / 101. If the config and checkpoint
# ever disagree, the error then surfaces at startup as a size mismatch in
# load_state_dict, not as dropped classes or an IndexError in predict().
_in_features = model.classifier[-1].in_features
_num_classes = len(id2label)
model.classifier = nn.Sequential(
    nn.Dropout(p=0.2, inplace=True),
    nn.Linear(_in_features, 512),
    nn.SiLU(),
    nn.Dropout(0.2),
    nn.Linear(512, _num_classes),
)

state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)  # strict: every key/shape must match
model.eval()  # disable dropout for inference
def _build_eval_transform():
    """Return the inference-time preprocessing pipeline.

    PIL image -> resized, then center-cropped to 128x128 -> float32 tensor
    scaled to [0, 1] -> normalized with the standard ImageNet mean/std.

    NOTE(review): this presumably mirrors the training preprocessing;
    changing any step (or their order) will likely hurt accuracy.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    steps = [
        # Geometric ops run on the PIL image, before tensor conversion.
        v2.Resize(160),
        v2.CenterCrop(128),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return v2.Compose(steps)


transform = _build_eval_transform()
def predict(image, top_k=5):
    """Classify a food photo and return the most likely labels.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None
            when nothing was uploaded.
        top_k: Number of highest-probability classes to return. ``None``
            returns every class; larger values are clamped to the number of
            classes and negative values to 0.

    Returns:
        dict mapping label name -> softmax probability (float in [0, 1]),
        ordered from most to least likely, as consumed by ``gr.Label``.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Without this guard, submitting an empty input fails with an opaque
        # AttributeError on .convert().
        raise gr.Error("Please upload a food image first.")

    # Normalize to 3 channels (e.g. RGBA PNGs) to match the 3-value mean/std.
    img = image.convert("RGB")
    batch = transform(img).unsqueeze(0)  # add a batch dimension
    with torch.inference_mode():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)[0]

    # Use the model's actual output size rather than len(id2label), so a
    # short label map cannot cause an IndexError here.
    num_classes = probs.numel()
    k = num_classes if top_k is None else min(max(int(top_k), 0), num_classes)
    top_probs, top_idx = torch.topk(probs, k)  # sorted, descending
    return {
        # Fall back to the raw index if the label map lacks an entry.
        id2label.get(str(idx), str(idx)): float(p)
        for p, idx in zip(top_probs.tolist(), top_idx.tolist())
    }
# Gradio UI: one PIL image in, top-5 label confidences out. The example
# image paths are relative (presumably to the Space's working directory).
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload your food image"),
    outputs=gr.Label(num_top_classes=5, label="Result"),
    title="🍦Food Classifier🍨",
    description="What is this food???",
    examples=["Pasta.png", "Steak.png"],
    theme="soft",
)

# Start the server only when run as a script (``python app.py``), so the
# module can be imported without launching a web server.
if __name__ == "__main__":
    demo.launch()