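# NOTE: the following lines are page residue from the Hugging Face file viewer,
# not part of the program; they are kept as comments so the file is valid Python.
# KabeerAmjad's picture
# Update app.py
# fadf2ad verified
# raw
# history blame
# 1.74 kB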
import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from PIL import Image
from transformers import AutoFeatureExtractor
# Hugging Face Hub repository that hosts the fine-tuned checkpoint.
model_id = "KabeerAmjad/food_classification_model"
# Number of output classes the checkpoint was trained with. It must match the
# shape of the saved ``fc`` weights, otherwise ``load_state_dict`` (strict by
# default) raises on the size mismatch.
NUM_CLASSES = 11

# Build a ResNet-50 skeleton without ImageNet weights (``weights=None`` is the
# torchvision >= 0.13 spelling of the deprecated ``pretrained=False``) and
# replace the classifier head with one sized to NUM_CLASSES.
model = models.resnet50(weights=None)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

# Download (and cache) the fine-tuned weights from the Hub. ``map_location``
# lets the checkpoint load on CPU-only hosts even if it was saved from a GPU.
model.load_state_dict(
    torch.hub.load_state_dict_from_url(
        f"https://huggingface.co/{model_id}/resolve/main/food_classification_model.pth",
        map_location="cpu",
    )
)
# Switch to evaluation mode (dropout off, BatchNorm uses running statistics).
model.eval()
# Load the Hub repository's feature-extractor (preprocessor) configuration.
# NOTE(review): nothing in this file reads ``feature_extractor`` (prediction
# preprocesses with ``transform`` instead), so this looks unused. It presumably
# fails at startup if the repo has no preprocessor config; confirm before relying on it.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path=model_id,
)
# Inference-time preprocessing. It must stay identical to the preprocessing
# used during training:
#   1. resize to exactly 224x224 (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor,
#   3. normalise each channel with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Define the prediction function
def classify_image(img):
    """Classify a food photo and return the predicted label as text.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        The predicted class label. If ``model`` carries a Hugging Face-style
        ``config.id2label`` mapping, that mapping is used. Otherwise the
        predicted class index is returned as a string.
    """
    if img is None:
        return "Please upload an image."

    # The Normalize step expects exactly 3 channels. Uploads may be RGBA
    # (PNG with alpha), grayscale or palette images, so coerce to RGB first.
    batch = transform(img.convert("RGB")).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits, dim=-1)
    class_idx = probs.argmax(dim=-1).item()

    # A torchvision ResNet has no ``.config`` attribute, so reading
    # ``model.config.id2label`` directly raises AttributeError. Look the
    # mapping up defensively and fall back to the raw class index.
    # TODO: supply the real class-name list used during training.
    config = getattr(model, "config", None)
    id2label = getattr(config, "id2label", None) or {}
    return id2label.get(class_idx, str(class_idx))
# Web UI: a single PIL image goes in, and the predicted label comes back as
# plain text. ``fn``, ``inputs`` and ``outputs`` are passed positionally.
iface = gr.Interface(
    classify_image,
    gr.Image(type="pil"),
    "text",
    title="Food Image Classification",
    description="Upload an image to classify if it’s an apple pie, etc.",
)

# Start the Gradio server. This is a top-level statement, so it runs as soon
# as the module is executed.
iface.launch()