# gradio-deploy / gradio_app_image_classification.py
# Committed by 1aguschin: "Upload folder using huggingface_hub" (commit d7bf453, verified).
import gradio as gr
from PIL import Image
import torch
from torchvision import transforms, models
import requests
# ResNet-50 with ImageNet-pretrained weights, switched to evaluation mode for
# inference. nn.Module.eval() returns the module itself, so it can be chained.
# NOTE: torchvision 0.13+ deprecates `pretrained=` in favour of `weights=`;
# it is kept here so older torchvision installs keep working.
model = models.resnet50(pretrained=True).eval()
# Download ImageNet class labels: one class name per line, looked up below by
# the argmax index of the model output.
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
# Fail fast at startup instead of hanging indefinitely on a stalled connection
# or silently treating an HTTP error page as the label list.
response = requests.get(LABELS_URL, timeout=30)
response.raise_for_status()
# splitlines() handles both LF and CRLF endings and does not produce a trailing
# "" entry for the file's final newline; indices of real entries are unchanged.
LABELS = response.text.splitlines()
# Per-channel RGB mean / standard deviation used to normalise inputs for
# torchvision's ImageNet-pretrained classification models.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Image preprocessing: shorter side to 256 px, centre 224x224 crop, convert to
# a float tensor, then normalise with the statistics above.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
def classify_image(image):
    """Return the ImageNet class name ResNet-50 predicts for *image*.

    Args:
        image: A ``PIL.Image.Image`` or any array ``Image.fromarray`` accepts
            (Gradio's ``gr.Image()`` passes a NumPy array by default).

    Returns:
        The label string of the highest-scoring class.

    Raises:
        gr.Error: If no image was supplied (e.g. Submit pressed with an empty
            input), so Gradio shows a readable message instead of a traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image to classify.")
    # Convert to PIL Image if needed
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Normalize() in `preprocess` expects exactly 3 channels; RGBA, palette and
    # greyscale inputs would otherwise fail there.
    image = image.convert("RGB")
    # Preprocess and add a leading batch dimension of size 1.
    input_batch = preprocess(image).unsqueeze(0)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(input_batch)
    predicted_idx = logits.argmax(dim=1).item()
    return LABELS[predicted_idx]
# Build the web UI: one image input, one text output showing the predicted class.
_image_input = gr.Image()
_label_output = gr.Text(label="Predicted Class")

iface = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=_label_output,
    title="Image Classification",
    description="Upload an image to classify it using ResNet50",
)

# Launch the app; share=True asks Gradio to also create a public share link.
iface.launch(share=True)