# Author: junaidshafique
# Commit 77717c2 (verified): "Update app.py"
import streamlit as st
from PIL import Image
import torch
import torchvision.transforms as transforms
import torchvision.models as models
# Load model
@st.cache_resource
def _load_model(
    weights_path: str = "animal_classifier.pth",
    num_classes: int = 13,
) -> torch.nn.Module:
    """Build a ResNet-18 with a ``num_classes``-way head and load trained weights.

    Streamlit re-executes this entire script on every widget interaction, so
    the model is cached with ``st.cache_resource``. Without it, the checkpoint
    would be rebuilt and re-read from disk on every upload.

    Args:
        weights_path: Path to a saved ``state_dict`` for the network.
        num_classes: Output size of the replacement ``fc`` head. It must equal
            the number of entries in ``class_labels`` and the head size the
            checkpoint was trained with; otherwise ``load_state_dict`` raises
            on the shape mismatch.

    Returns:
        The network in eval mode, with weights mapped onto the CPU.
    """
    net = models.resnet18()
    # Swap the stock ImageNet classifier for a head sized to our label set.
    net.fc = torch.nn.Linear(net.fc.in_features, num_classes)
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    net.load_state_dict(state_dict)
    net.eval()  # inference only: disable dropout / use running BN statistics
    return net


model = _load_model()
# Class labels
# The model's output index i is reported as class_labels[i], so this order must
# match the label indices used at training time.
# NOTE(review): the names are in alphabetical order, which is how torchvision's
# ImageFolder assigns indices — presumably how the model was trained; confirm.
class_labels = [
    "antelope",
    "badger",
    "bat",
    "bear",
    "bee",
    "beetle",
    "jellyfish",
    "kangaroo",
    "koala",
    "ladybug",
    "leopard",
    "lion",
    "lizard",
]
# Transform
# Preprocessing applied to each uploaded image before inference:
#   1. resize to exactly 224x224 (a (h, w) tuple, so aspect ratio is not kept);
#   2. convert the PIL image to a float tensor of shape (C, H, W).
# NOTE(review): no mean/std normalization is applied here. This pipeline must
# mirror what was used when animal_classifier.pth was trained — confirm against
# the training script before adding transforms.Normalize.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
# Streamlit UI
st.title("Animal Classifier")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        # Image.open is lazy; convert() forces the full decode, so corrupt or
        # truncated uploads fail here instead of inside the model pipeline.
        # Converting to RGB also turns RGBA/palette PNGs and grayscale images
        # into the 3-channel input ResNet's first conv layer requires.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # PIL.UnidentifiedImageError is a subclass of OSError.
        st.error("Could not read the uploaded file as an image.")
    else:
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width — switch once the pinned version is known.
        st.image(image, caption="Uploaded Image.", use_column_width=True)
        # Predict
        img_tensor = transform(image).unsqueeze(0)  # add batch dim -> (1, C, H, W)
        with torch.no_grad():
            outputs = model(img_tensor)
        # Index of the highest-scoring class for the single image in the batch.
        predicted = outputs.argmax(dim=1)
        label = class_labels[predicted.item()]
        st.write(f"### Prediction: {label}")