Spaces:
Sleeping
Sleeping
File size: 2,839 Bytes
acc38f0 3c7199b acc38f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import gradio as gr
import torch
from torchvision import transforms, models
from PIL import Image
from torch import nn
# Selects which EfficientNet variant is built further down and which
# variant-specific constants apply.
model_name = "b0"

# IMAGE_RESIZE_SHAPE / IMAGE_FINAL_SHAPE feed the Resize / CenterCrop steps of
# the preprocessing pipeline; FEATURE_SHAPE is the in_features of the
# replacement classifier layer.
if model_name == "b4":
    IMAGE_RESIZE_SHAPE = 384
    IMAGE_FINAL_SHAPE = 380
    BATCH_SIZE = 32
    FEATURE_SHAPE = 1792
elif model_name == "b0":
    IMAGE_RESIZE_SHAPE = 256
    IMAGE_FINAL_SHAPE = 224
    BATCH_SIZE = 32
    FEATURE_SHAPE = 1280
else:
    # Fail fast with a clear message instead of a NameError on first use of
    # one of the constants above.
    raise ValueError(
        f"Unsupported model_name {model_name!r}; expected 'b0' or 'b4'."
    )
def load_labels(label_text_path):
    """Read class labels from a text file, one label per line.

    Args:
        label_text_path: Path to a text file holding one label on each line.

    Returns:
        A dict mapping each zero-based line index to that line's text with
        surrounding whitespace (including the newline) stripped. An empty
        file yields an empty dict.
    """
    # Explicit encoding so decoding does not depend on the platform locale.
    with open(label_text_path, "r", encoding="utf-8") as f:
        return {index: line.strip() for index, line in enumerate(f)}
# Index -> class-name lookup used when formatting predictions.
label_dict = load_labels("labels.txt")

# Load PyTorch model
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only ship checkpoints from a trusted source, or pass
# weights_only=True if the installed torch version supports it.
model_params = torch.load("food101.pt", map_location=torch.device("cpu"))

if model_name == "b4":
    model = models.efficientnet_b4()
elif model_name == "b0":
    model = models.efficientnet_b0()
else:
    # Without this, an unknown variant surfaces as a NameError on `model`.
    raise ValueError(
        f"Unsupported model_name {model_name!r}; expected 'b0' or 'b4'."
    )

# Replace the final classifier layer with a 101-output one before loading the
# saved weights into the model.
model.classifier[1] = nn.Linear(in_features=FEATURE_SHAPE, out_features=101)
model.load_state_dict(model_params)

# Freeze parameters and switch to eval mode only once the final layers are in
# place, so the replacement classifier layer is covered as well.
for param in model.parameters():
    param.requires_grad = False
model.eval()
# Preprocessing applied to every image before it is passed to the model.
# The mean/std values match the per-channel statistics torchvision documents
# for its ImageNet-pretrained models.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD)

# Steps run in order: resize, center-crop to IMAGE_FINAL_SHAPE, convert the
# PIL image to a tensor, then normalize the channels.
_preprocessing_steps = [
    transforms.Resize(IMAGE_RESIZE_SHAPE),
    transforms.CenterCrop(IMAGE_FINAL_SHAPE),
    transforms.ToTensor(),
    normalize,
]
transform = transforms.Compose(_preprocessing_steps)
# Define prediction function
def predict_image_class(image):
    """Classify an image and return the highest-scoring classes.

    Args:
        image: Array-like image supporting ``.astype`` (presumably the NumPy
            array delivered by ``gr.Image`` - confirm against the interface
            configuration). ``None`` is rejected with a user-facing error.

    Returns:
        A dict mapping class label to softmax probability for up to the ten
        most probable classes, in descending order of probability.

    Raises:
        gr.Error: If no image was supplied.
    """
    top_k = 10

    # Submitting the form without an image passes None; report that to the
    # user instead of failing with an AttributeError.
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Let Pillow infer the mode from the array shape, then force RGB so that
    # grayscale or RGBA input also yields the three channels Normalize expects.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")

    # Preprocess and add the batch dimension the model expects.
    batch = transform(pil_image).unsqueeze(0)

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

    # topk returns the largest entries already sorted in descending order;
    # clamp k so fewer than ten classes cannot raise.
    k = min(top_k, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)

    return {
        label_dict[index]: probability
        for index, probability in zip(top_indices.tolist(), top_probs.tolist())
    }
def main():
    """Build the Gradio interface around predict_image_class and serve it."""
    # Adjacent literals joined with one explicit space. A backslash
    # continuation inside the quotes would embed the next source line's
    # leading whitespace (or no separator at all) in the displayed text.
    description = (
        "This is a demo of EfficientNet trained on Food101 dataset. "
        "Upload an image of food and it will predict the class of the food."
    )
    inputs = gr.Image()
    outputs = gr.Label(num_top_classes=10, label="Prediction")
    gradio_app = gr.Interface(
        fn=predict_image_class,
        inputs=inputs,
        outputs=outputs,
        title="FoodVision",
        description=description,
        theme="snehilsanyal/scikit-learn",
        examples=[
            ["examples/pizza.jpg"],
            ["examples/samosa.jpg"],
        ],
    )
    # Enable request queuing, then listen on all interfaces on port 7860.
    gradio_app.queue().launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
    # Start the demo only when executed as a script, not when imported.
    main()
|