| import streamlit as st
|
| import torch
|
| import torch.nn as nn
|
| from torchvision import models, transforms
|
| from PIL import Image
|
| import os
|
|
|
|
|
# Root of the training split; each immediate sub-directory is treated as one class.
_TRAIN_DIR = "oxford_pet_dataset/train"

# Class labels, index-aligned with the classifier's output units.
# Only sub-directories are kept: a stray file in the folder (for example
# ".DS_Store") would otherwise be counted as a class, changing
# len(CLASS_NAMES) and possibly shifting every label index.
# NOTE(review): assumes training assigned label indices in sorted
# sub-directory order (as torchvision's ImageFolder does) - confirm against
# the training script.
CLASS_NAMES = sorted(
    entry
    for entry in os.listdir(_TRAIN_DIR)
    if os.path.isdir(os.path.join(_TRAIN_DIR, entry))
)
|
|
|
|
|
@st.cache_resource
def load_model():
    """Build the ResNet-18 pet classifier and load its trained weights on CPU.

    The function is decorated with ``st.cache_resource`` so that Streamlit can
    reuse the returned object instead of rebuilding it on every script rerun.

    Returns:
        A torchvision ResNet-18 in evaluation mode whose final fully connected
        layer has ``len(CLASS_NAMES)`` outputs, with parameters loaded from
        ``pet_classifier.pth``.

    Raises:
        FileNotFoundError: If ``pet_classifier.pth`` is missing.
        RuntimeError: If the checkpoint's tensors do not match the network
            (for example when the class count differs from training).
    """
    # Called with no arguments, resnet18 is built without pretrained weights
    # on both old and new torchvision releases; this avoids the deprecated
    # ``pretrained=`` keyword, which triggers a warning on current versions.
    net = models.resnet18()

    # Swap the 1000-way ImageNet head for one sized to this dataset.
    net.fc = nn.Linear(net.fc.in_features, len(CLASS_NAMES))

    # map_location keeps the load working on machines without a GPU even if
    # the checkpoint was saved from one.
    state_dict = torch.load("pet_classifier.pth", map_location=torch.device("cpu"))
    net.load_state_dict(state_dict)

    # Inference only: put dropout / batch-norm layers into eval behaviour.
    net.eval()
    return net
|
|
|
# Obtain the classifier through the cached loader so the script can use it
# directly as a module-level object.
model = load_model()

# Preprocessing applied to each uploaded image before it is fed to the model:
# resize to 224x224, convert to a tensor, then normalise each of the three
# channels with mean 0.5 and std 0.5.
# NOTE(review): these values presumably mirror the training-time transform -
# confirm against the training script.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
|
|
|
|
|
# ---- Page header ----
st.title("🐾 Oxford Pet Classifier")
st.write("Upload a photo of a cat or dog and I’ll try to guess the breed!")

# ---- Upload widget ----
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

# Nothing below runs until the user has actually provided a file.
if uploaded_file is not None:
    try:
        # Image.open only reads the header; convert() forces the full decode,
        # so both calls sit inside the try. Pillow reports unreadable or
        # truncated images with OSError (UnidentifiedImageError subclasses it).
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # Show a friendly message instead of a raw traceback in the app.
        st.error("Could not read that file as an image. Please upload a valid JPG or PNG.")
    else:
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Preprocess and add a leading batch dimension of size 1.
        input_tensor = transform(image).unsqueeze(0)

        # Gradients are not needed for inference.
        with torch.no_grad():
            outputs = model(input_tensor)

        # Index of the highest-scoring output unit for the single image.
        predicted_index = outputs.argmax(dim=1).item()
        predicted_label = CLASS_NAMES[predicted_index]

        st.markdown(f"### 🐕 Prediction: **{predicted_label.title()}**")
|
|
|