|
|
import gradio as gr |
|
|
import torch |
|
|
import torchvision.transforms as transforms |
|
|
from PIL import Image |
|
|
import torchvision.models as models |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of output classes the checkpoint was trained with; must equal
# len(class_names).
NUM_CLASSES = 15

# Build the architecture only. Every parameter and buffer is overwritten by
# the strict (default) load_state_dict() call below, so fetching the ImageNet
# weights here (weights=DEFAULT) would just be a wasted download at startup.
model = models.resnet50(weights=None)

# Replace the 1000-way ImageNet head with the classifier head used in training.
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 512),  # in_features is 2048 for ResNet-50
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(512, NUM_CLASSES),
)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load trusted checkpoints. On torch >= 1.13 consider passing
# weights_only=True to restrict loading to tensors.
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.to(device)
model.eval()  # inference mode: disables dropout, BatchNorm uses running stats
|
|
|
|
|
|
|
|
# Preprocessing applied to every uploaded image before it reaches the model.
# NOTE(review): Resize(256) rescales the shorter side to 256 px and keeps the
# aspect ratio (there is no crop). Normalize uses mean/std 0.5 rather than
# ImageNet statistics. Both must match the pipeline that produced
# best_model.pth; confirm against the training code.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
]
transform = transforms.Compose(_preprocess_steps)
|
|
|
|
|
|
|
|
|
|
|
# Human-readable labels; position i names the model's i-th output logit.
# Presumably this order matches the label encoding used in training (it is
# alphabetical, as torchvision's ImageFolder would produce); confirm.
class_names = (
    "Bear Bird Cat Cow Deer Dog Dolphin "
    "Elephant Giraffe Horse Kangaroo Lion "
    "Panda Tiger Zebra"
).split()
|
|
|
|
|
|
|
|
|
|
|
def classify_image(img):
    """Return a per-class probability map for one uploaded image.

    Args:
        img: A PIL image as delivered by ``gr.Image(type="pil")``, or ``None``
            when the user submits without uploading anything.

    Returns:
        A dict mapping every name in ``class_names`` to its softmax
        probability as a Python ``float``, or ``None`` when no image was given
        (which clears the Label output).
    """
    if img is None:
        # Gradio passes None for an empty input; bail out instead of crashing
        # inside the transform.
        return None

    # The network expects 3 input channels. RGBA PNGs, grayscale ("L") and
    # palette ("P") images would otherwise yield a tensor with the wrong
    # channel count and fail in the first convolution.
    batch = transform(img.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)

    # The batch holds a single image, so take row 0. A single .tolist() copies
    # the probabilities to the host once instead of once per class.
    probs = torch.nn.functional.softmax(logits, dim=1)[0].tolist()
    return dict(zip(class_names, probs))
|
|
|
|
|
|
|
|
# Gradio UI: one PIL image in, the five most probable labels out.
_image_input = gr.Image(type="pil")
_label_output = gr.Label(num_top_classes=5)

interface = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=_label_output,
    title="Animal Image Classifier",
    description="Upload an image of an animal and get the top predictions!",
)

# NOTE(review): this starts the server as soon as the module is executed or
# imported; wrap it in `if __name__ == "__main__":` if the module ever needs
# to be imported (e.g. from tests).
interface.launch()