# IS_Finals / app.py
# Uploaded by Tzetha — "Uploaded Complete App" (commit 81e78bd)
import streamlit as st
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os
# Class names in the index order the classifier was trained with.
# NOTE(review): assumes training used torchvision ImageFolder on this directory
# (labels = sorted sub-directory names) and that the directory is deployed
# alongside the app — confirm both.


def _list_class_names(root):
    """Return the sorted names of the sub-directories of *root*.

    Only directories count as classes. Stray files such as ``.DS_Store`` or a
    README would otherwise be inserted into the sorted list and shift every
    following label index, so predictions would map to the wrong breed.

    Raises:
        FileNotFoundError: if *root* does not exist.
    """
    # Use ``with`` so the directory iterator is closed deterministically.
    with os.scandir(root) as entries:
        return sorted(entry.name for entry in entries if entry.is_dir())


CLASS_NAMES = _list_class_names("oxford_pet_dataset/train")
# Load the model. st.cache_resource makes Streamlit reuse the returned object
# across script reruns instead of rebuilding it on every interaction.
@st.cache_resource
def load_model(weights_path="pet_classifier.pth"):
    """Build a ResNet-18 with one output per pet class and load trained weights.

    Args:
        weights_path: Path to the ``state_dict`` checkpoint saved after training.

    Returns:
        The model, loaded onto the CPU and switched to evaluation mode.
    """
    # weights=None replaces the deprecated ``pretrained=False``. Random init is
    # fine because every parameter is overwritten by the checkpoint below.
    model = models.resnet18(weights=None)
    # Swap the default classification head for one sized to our class list.
    model.fc = nn.Linear(model.fc.in_features, len(CLASS_NAMES))
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs, so a tampered checkpoint cannot run code.
    state_dict = torch.load(
        weights_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model


model = load_model()
# Preprocessing pipeline applied to each uploaded image. It must mirror the
# transforms used at training time, or the model sees out-of-distribution input.
# NOTE(review): mean/std of 0.5 per channel differ from the usual ImageNet
# statistics for ResNet — confirm this is what the training script used.
_INPUT_SIZE = (224, 224)
_NORM_MEAN = [0.5, 0.5, 0.5]
_NORM_STD = [0.5, 0.5, 0.5]

transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),  # fixed H x W; aspect ratio is not kept
        transforms.ToTensor(),  # PIL image -> float tensor scaled to [0, 1]
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
# Streamlit UI: the script reruns top to bottom on every interaction.
st.title("🐾 Oxford Pet Classifier")
st.write("Upload a photo of a cat or dog and I’ll try to guess the breed!")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # convert("RGB") drops PNG alpha / expands grayscale so the tensor
        # always has the 3 channels the network expects.
        image = Image.open(uploaded_file).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # The upload is untrusted input. A file with an allowed extension can
        # still be corrupt, truncated, or not an image at all (PIL raises
        # UnidentifiedImageError, an OSError subclass). Report it instead of
        # crashing the page with a traceback.
        st.error("Could not read that file as an image. Please upload a valid JPG or PNG.")
        st.stop()

    # NOTE(review): use_column_width is deprecated in newer Streamlit releases
    # (superseded by use_container_width); kept for older versions.
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Preprocess: add a batch dimension -> (1, 3, 224, 224)
    input_tensor = transform(image).unsqueeze(0)

    # Predict: the position of the largest logit is the class index.
    with torch.no_grad():
        outputs = model(input_tensor)
    predicted_index = outputs.argmax(dim=1).item()
    predicted_label = CLASS_NAMES[predicted_index]

    # Directory-style names (e.g. "american_bulldog") read better with spaces.
    display_label = predicted_label.replace("_", " ").title()
    st.markdown(f"### 🐕 Prediction: **{display_label}**")