import streamlit as st
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os
import time
# Browser-tab title/icon, page layout and initial sidebar state for the app.
_PAGE_SETTINGS = {
    "page_title": "🐾 Oxford Pet Classifier",
    "page_icon": "🐾",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_PAGE_SETTINGS)

# Custom CSS injection. NOTE(review): the style payload is currently empty, so
# this call renders nothing visible; place <style>...</style> rules in here.
_CUSTOM_CSS = """
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Class names, indexed in the same order as the model's output logits.
# They are the sub-directory names of "test", sorted. Plain files (e.g.
# ".DS_Store", a README) are skipped so they cannot shift the label indices.
# This mirrors how torchvision's ImageFolder assigns class indices: directories
# only, sorted by name.
# NOTE(review): assumes "test" holds the same class sub-directories as the
# training split and is resolved relative to the working directory; confirm.
with os.scandir("test") as _class_entries:
    CLASS_NAMES = sorted(
        entry.name for entry in _class_entries if entry.is_dir()
    )
# Load the model with caching.
@st.cache_resource
def _load_model_cached(model_path, num_classes):
    """Build a ResNet-18 with a ``num_classes``-way head and load its weights.

    The weights are read from ``model_path`` (a saved ``state_dict``) onto the
    CPU, and the model is put in eval mode.

    Exceptions propagate on purpose. Streamlit only caches values that are
    returned, so a failed load is retried on the next rerun instead of being
    cached forever.
    """
    model = models.resnet18(weights=None)  # architecture only; no download
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # weights_only=True: the checkpoint is a plain state_dict, so there is no
    # need to let pickle build arbitrary objects.
    state_dict = torch.load(
        model_path, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model


def load_model(model_path="/app/pet_classifier.pth"):
    """Load the trained pet classifier model.

    Args:
        model_path: path of the saved ``state_dict``. The default is the
            original hard-coded container path.

    Returns:
        The cached model (CPU, eval mode) with one output per entry of
        ``CLASS_NAMES``. Returns ``None`` if loading failed; in that case an
        error message is also shown on the page.
    """
    try:
        return _load_model_cached(model_path, len(CLASS_NAMES))
    except Exception as e:  # UI boundary: report it; the caller decides to stop
        st.error(f"Error loading model: {str(e)}")
        return None
# Header banner. NOTE(review): the markup string is currently empty, so this
# renders nothing visible.
_HEADER_MD = """
"""
st.markdown(_HEADER_MD, unsafe_allow_html=True)

# Sidebar copy, kept as module constants rather than inline literals.
_HOW_IT_WORKS_MD = """
1. Upload a clear photo of your pet
2. The model analyzes the image features
3. Get instant breed prediction with confidence
"""
_TIPS_MD = """
- Use high-quality, well-lit photos
- Ensure the pet is clearly visible
- Avoid blurry or dark images
- Single pet per image works best
"""

# Sidebar: model facts, a short walkthrough, and photo tips.
with st.sidebar:
    st.markdown("### 📊 Model Information")
    _model_facts = (
        f"**Classes:** {len(CLASS_NAMES)} pet breeds",
        "**Architecture:** ResNet-18",
        "**Input Size:** 224x224 pixels",
    )
    for _fact in _model_facts:
        st.info(_fact)

    st.markdown("### 🎯 How it works")
    st.markdown(_HOW_IT_WORKS_MD)

    st.markdown("### 💡 Tips for best results")
    st.markdown(_TIPS_MD)
# Fetch the (cached) model. Nothing below can run without it, so halt this
# script run with a visible message if loading failed.
model = load_model()
model_unavailable = model is None
if model_unavailable:
    st.error("Failed to load the model. Please check if 'pet_classifier.pth' exists.")
    st.stop()  # ends the current run; no later UI is rendered
# Preprocessing applied to each uploaded image before inference:
#   1. resize to 224x224,
#   2. convert to a float tensor with values in [0, 1],
#   3. normalize each channel with (x - 0.5) / 0.5, mapping it to [-1, 1].
# NOTE(review): this must match the normalization used in training. A ResNet
# trained with ImageNet statistics would use mean=[0.485, 0.456, 0.406] and
# std=[0.229, 0.224, 0.225]; confirm against the training script.
_INPUT_SIZE = (224, 224)
_CHANNEL_MEAN = [0.5, 0.5, 0.5]
_CHANNEL_STD = [0.5, 0.5, 0.5]
transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)
# Main content area: upload column (left) and results column (right).

# Static Markdown for the main area (rendered with unsafe_allow_html=True).
_UPLOAD_PROMPT_MD = """
📸 Upload Your Pet's Photo
Choose a clear image of your cat or dog
"""
_ANALYZING_MD = """
🔍 Analyzing Your Pet...
"""
_RESULTS_MD_TEMPLATE = """
🎯 Prediction Results
{label}
Confidence: {confidence:.1%}
"""
_WELCOME_MD = """
🚀 Ready to classify your pet?
Upload an image to get started! Our model can identify dozens of different cat and dog breeds with high accuracy.
Supported breeds include:
🐱 Cats: Persian, Siamese, Maine Coon, and more
🐶 Dogs: Golden Retriever, German Shepherd, Bulldog, and more
"""


def _pretty_label(class_name):
    """Return display text for a class name: underscores to spaces, title case."""
    return class_name.replace("_", " ").title()


def _open_uploaded_image(uploaded_file):
    """Decode an uploaded file into an RGB PIL image.

    Returns None, after showing an error on the page, if the bytes cannot be
    decoded. Causes include a corrupt, truncated or mislabelled file, or
    Pillow's decompression-bomb guard. Without this the script would stop with
    a raw traceback.
    """
    try:
        return Image.open(uploaded_file).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # PIL.UnidentifiedImageError is a subclass of OSError.
        st.error(f"Could not read this file as an image: {err}")
        return None


def _play_progress_animation():
    """Show a staged progress bar for about 2 s (100 steps x 0.02 s), then clear it.

    Purely cosmetic: it does not track the real preprocessing or inference.
    NOTE(review): it replays on every rerun, e.g. after any button click.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()
    for step in range(100):
        progress_bar.progress(step + 1)
        if step < 30:
            status_text.text("🔍 Loading image...")
        elif step < 60:
            status_text.text("🧠 Processing with model...")
        elif step < 90:
            status_text.text("📊 Analyzing features...")
        else:
            status_text.text("✨ Almost done...")
        time.sleep(0.02)
    progress_bar.empty()
    status_text.empty()


def _predict(image, classifier, preprocess):
    """Classify one RGB PIL image.

    Returns:
        A tuple ``(probabilities, predicted_index, confidence)``:
        - ``probabilities``: the 1-D softmax over the classifier's outputs;
        - ``predicted_index``: the index of the most probable class, as an int;
        - ``confidence``: that class's probability, as a Python float.
    """
    batch = preprocess(image).unsqueeze(0)  # add batch dim: (1, 3, 224, 224)
    with torch.no_grad():
        logits = classifier(batch)
    probabilities = torch.softmax(logits[0], dim=0)
    # A single max gives both the winner and its probability. Softmax is
    # monotonic, so the argmax matches the argmax of the logits.
    top_prob, top_idx = probabilities.max(dim=0)
    return probabilities, top_idx.item(), top_prob.item()


def _build_report(label, confidence, ranked):
    """Return a plain-text summary of one prediction for the download button.

    Args:
        label: display name of the predicted class.
        confidence: probability of the predicted class, in [0, 1].
        ranked: ``(class_name, probability)`` pairs, most probable first.
    """
    lines = [
        "Oxford Pet Classifier result",
        f"Prediction: {label}",
        f"Confidence: {confidence:.1%}",
        "",
        "Top predictions:",
    ]
    lines.extend(
        f"{rank}. {_pretty_label(name)}: {prob:.1%}"
        for rank, (name, prob) in enumerate(ranked, start=1)
    )
    return "\n".join(lines) + "\n"


# "Try Another" increments this counter. The uploader's widget key includes
# it, so a new key gives a fresh, empty uploader on the next run.
if "uploader_key" not in st.session_state:
    st.session_state["uploader_key"] = 0

col1, col2 = st.columns([1, 1])
image = None  # stays None when nothing (decodable) has been uploaded

with col1:
    st.markdown(_UPLOAD_PROMPT_MD, unsafe_allow_html=True)
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png"],
        help="Upload a clear photo of your pet for breed classification",
        key=f"pet_uploader_{st.session_state['uploader_key']}",
    )
    if uploaded_file is not None:
        image = _open_uploaded_image(uploaded_file)
    if image is not None:
        st.image(
            image,
            caption="📷 Your uploaded image",
            use_column_width=True,
            clamp=True,
        )

with col2:
    if image is not None:
        st.markdown(_ANALYZING_MD, unsafe_allow_html=True)
        _play_progress_animation()

        probabilities, predicted_idx, confidence = _predict(image, model, transform)
        predicted_label = CLASS_NAMES[predicted_idx]
        pretty_label = _pretty_label(predicted_label)

        # Results display
        st.markdown(
            _RESULTS_MD_TEMPLATE.format(label=pretty_label, confidence=confidence),
            unsafe_allow_html=True,
        )

        # Confidence meter
        st.markdown("### 📊 Confidence Level")
        _, meter_col, _ = st.columns([1, 2, 1])
        with meter_col:
            st.progress(confidence)
            if confidence > 0.8:
                st.success(f"Very confident! ({confidence:.1%})")
            elif confidence > 0.6:
                st.warning(f"Moderately confident ({confidence:.1%})")
            else:
                st.error(f"Low confidence ({confidence:.1%}) - try a clearer image or the animal is not a cat/dog. ")

        # Top predictions (fewer than 5 if the model has fewer classes)
        st.markdown("### 🏆 Top 5 Predictions")
        top_probs, top_indices = torch.topk(probabilities, min(5, len(CLASS_NAMES)))
        ranked = [
            (CLASS_NAMES[idx], prob)
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        ]
        for rank, (class_name, prob) in enumerate(ranked, start=1):
            st.markdown(f"**{rank}. {_pretty_label(class_name)}**")
            st.progress(prob)
            st.markdown(f"{prob:.1%}")
        st.markdown("---")

        # Action buttons
        st.markdown("### 🎬 Actions")
        col_btn1, col_btn2, col_btn3 = st.columns(3)
        with col_btn1:
            if st.button("🔄 Try Another", use_container_width=True):
                # New uploader key => the uploader comes back empty.
                st.session_state["uploader_key"] += 1
                st.rerun()
        with col_btn2:
            # A real download of the result, instead of a "saved" message
            # with nothing behind it.
            if st.download_button(
                "📥 Download Result",
                data=_build_report(pretty_label, confidence, ranked),
                file_name="pet_prediction.txt",
                mime="text/plain",
                use_container_width=True,
            ):
                st.balloons()
        with col_btn3:
            if st.button("📤 Share", use_container_width=True):
                st.info("Share feature coming soon! 📱")
    else:
        st.markdown(_WELCOME_MD, unsafe_allow_html=True)
# Footer: a horizontal rule followed by the attribution line.
_FOOTER_MD = """
🐾 Built with ❤️ using Streamlit and PyTorch | Oxford Pet Dataset
"""
st.markdown("---")
st.markdown(_FOOTER_MD, unsafe_allow_html=True)