# app.py — added by Ahmad-01 in commit 35ffd37 ("Create app.py", verified).
import gradio as gr
from PIL import Image
import torch
from torchvision import transforms, models
def _build_model(state_dict, num_classes):
    """Return a ResNet-50 with a ``num_classes``-way head, loaded and in eval mode.

    The architecture must match the one used at training time, otherwise
    ``load_state_dict`` (strict by default) raises on missing/unexpected keys.
    """
    net = models.resnet50(weights=None)  # no pretrained weights; ours are loaded below
    head_inputs = net.fc.in_features
    # Replace the classifier head so its output size equals the number of classes.
    net.fc = torch.nn.Linear(head_inputs, num_classes)
    net.load_state_dict(state_dict)
    # eval() switches off training-only behavior and returns the module itself.
    return net.eval()


# Load the trained checkpoint onto the CPU so the app runs without a GPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load("animal_model.pth", map_location="cpu")
class_names = checkpoint["class_names"]
model = _build_model(checkpoint["model_state_dict"], len(class_names))
# Image preprocessing. This presumably mirrors the training-time pipeline
# (TODO confirm against the training script): resize to a fixed 224x224
# (aspect ratio is not preserved), convert to a float tensor in [0, 1], then
# normalize each channel with the standard ImageNet mean/std.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Prediction function
def predict(image):
    """Classify an uploaded image and return the predicted class name.

    Args:
        image: The value from the ``gr.Image(type="numpy")`` input — a NumPy
            image array, or ``None`` when no image has been provided.

    Returns:
        The name from ``class_names`` with the highest model score, or a short
        prompt asking for an image when ``image`` is ``None``.
    """
    # Gradio hands over None when Submit is pressed without an image;
    # Image.fromarray(None) would raise, so reply with a hint instead.
    if image is None:
        return "Please upload an image."
    # convert("RGB") turns grayscale / RGBA uploads into 3-channel images.
    img = Image.fromarray(image).convert("RGB")
    batch = transform(img).unsqueeze(0)  # add batch dimension: (1, C, H, W)
    with torch.no_grad():  # inference only; no autograd bookkeeping needed
        logits = model(batch)
    pred = logits.argmax(dim=1)  # index of the highest-scoring class
    return class_names[pred.item()]
# Gradio Interface: one numpy image input routed through `predict`; the
# predicted class name is shown as plain text.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs="text",
    title="Animal Image Classifier",
    description="Upload an image of an animal and the model will classify it.",
)


def main():
    """Start the Gradio web server for the classifier."""
    app.launch()


if __name__ == "__main__":
    main()