FoodVision / app.py
hari31416's picture
Update app.py
3c7199b
import gradio as gr
import torch
from torchvision import transforms, models
from PIL import Image
from torch import nn
# EfficientNet variant the checkpoint was trained with; must be a key of
# _MODEL_CONFIGS below.
model_name = "b0"

# Per-variant settings.
#   resize   -> value passed to transforms.Resize
#   crop     -> value passed to transforms.CenterCrop
#   batch    -> batch size (not referenced in the visible part of this file)
#   features -> in_features of the replacement classifier layer
_MODEL_CONFIGS = {
    "b4": {"resize": 384, "crop": 380, "batch": 32, "features": 1792},
    "b0": {"resize": 256, "crop": 224, "batch": 32, "features": 1280},
}

# Fail immediately with a clear message instead of leaving the constants
# undefined and hitting a NameError much later in the file.
if model_name not in _MODEL_CONFIGS:
    raise ValueError(
        f"Unsupported model_name {model_name!r}; "
        f"expected one of {sorted(_MODEL_CONFIGS)}"
    )

_config = _MODEL_CONFIGS[model_name]
IMAGE_RESIZE_SHAPE = _config["resize"]
IMAGE_FINAL_SHAPE = _config["crop"]
BATCH_SIZE = _config["batch"]
FEATURE_SHAPE = _config["features"]
def load_labels(label_text_path):
    """Read a label file and map each zero-based line number to its label.

    Args:
        label_text_path: Path to a text file containing one label per line.

    Returns:
        A dict ``{line_index: label}`` where each label has surrounding
        whitespace (including the newline) stripped. Blank lines are kept
        as empty-string labels so that indices stay aligned with the file.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Explicit UTF-8 so the result does not depend on the platform's
    # default locale encoding.
    with open(label_text_path, "r", encoding="utf-8") as f:
        return {index: line.strip() for index, line in enumerate(f)}
# Class-index -> label-name lookup used when formatting predictions.
label_dict = load_labels("labels.txt")

# Load PyTorch model
# NOTE(security): torch.load unpickles the file, so "food101.pt" must come
# from a trusted source. Consider passing weights_only=True on PyTorch
# versions that support it -- TODO confirm the deployed version.
model_params = torch.load("food101.pt", map_location=torch.device("cpu"))

# Constructors for the supported EfficientNet variants. Called without
# arguments, so the weights come entirely from the checkpoint loaded above.
_MODEL_BUILDERS = {
    "b4": models.efficientnet_b4,
    "b0": models.efficientnet_b0,
}
if model_name not in _MODEL_BUILDERS:
    # Fail here rather than with a NameError on `model` further down.
    raise ValueError(
        f"Unsupported model_name {model_name!r}; "
        f"expected one of {sorted(_MODEL_BUILDERS)}"
    )
model = _MODEL_BUILDERS[model_name]()

# Swap in a 101-output linear head before loading, so the module's layout
# matches the checkpoint's state dict.
model.classifier[1] = nn.Linear(in_features=FEATURE_SHAPE, out_features=101)
model.load_state_dict(model_params)

# Switch to inference mode and freeze the weights only once the final head
# is in place; doing this earlier would skip the newly created layer.
model.eval()
for params in model.parameters():
    params.requires_grad = False
# Define image transformation
# Per-channel normalisation; the first list is the mean and the second the
# standard deviation (these are the commonly used ImageNet statistics).
normalize = transforms.Normalize(
    [0.485, 0.456, 0.406],
    [0.229, 0.224, 0.225],
)

# Preprocessing applied to every input image, in order:
# resize -> centre crop -> tensor conversion -> normalisation.
_preprocessing_steps = [
    transforms.Resize(size=IMAGE_RESIZE_SHAPE),
    transforms.CenterCrop(size=IMAGE_FINAL_SHAPE),
    transforms.ToTensor(),
    normalize,
]
transform = transforms.Compose(_preprocessing_steps)
# Define prediction function
def predict_image_class(image):
    """Classify an image and return its ten most likely labels.

    Args:
        image: Pixel data exposing ``astype`` and accepted by
            ``PIL.Image.fromarray`` (presumably the NumPy array supplied
            by the Gradio image input), or ``None`` when no image was
            provided.

    Returns:
        A dict mapping label name to softmax probability (a float between
        0 and 1) for the highest-scoring classes, ordered from most to
        least likely. At most ten entries; empty when ``image`` is None.
    """
    # Nothing to classify: return an empty result instead of crashing on
    # ``None.astype``.
    if image is None:
        return {}

    # Let Pillow infer the mode from the array shape, then normalise to RGB
    # so grayscale or RGBA arrays are converted rather than misread.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")

    # Preprocess and add the leading batch dimension the model expects.
    batch = transform(pil_image).unsqueeze(0)

    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

    # topk returns values sorted in descending order, avoiding a full sort
    # of every class score. Clamp k in case there are fewer than ten classes.
    top_k = min(10, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, top_k)

    return {
        label_dict[index]: probability
        for index, probability in zip(top_indices.tolist(), top_probs.tolist())
    }
def main():
    """Build the FoodVision Gradio interface and start serving it.

    Wires ``predict_image_class`` to an image input and a ten-class label
    output, then launches the queued app on port 7860.
    """
    # Define Gradio interface
    # Adjacent string literals are joined at compile time. Unlike a
    # backslash continuation inside one literal, the resulting text does not
    # depend on source indentation and keeps the space between sentences.
    description = (
        "This is a demo of EfficientNet trained on Food101 dataset. "
        "Upload an image of food and it will predict the class of the food."
    )
    inputs = gr.Image()
    outputs = gr.Label(num_top_classes=10, label="Prediction")
    gradio_app = gr.Interface(
        fn=predict_image_class,
        inputs=inputs,
        outputs=outputs,
        title="FoodVision",
        description=description,
        theme="snehilsanyal/scikit-learn",
        examples=[
            ["examples/pizza.jpg"],
            ["examples/samosa.jpg"],
        ],
    )
    # "0.0.0.0" binds to all network interfaces rather than only localhost.
    gradio_app.queue().launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    # Run Gradio app
    main()