"""Streamlit app that predicts a pet's breed from a photo with a ResNet-18 classifier."""

import os
import time

import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

# Folder whose sub-directory names are the class labels, and the trained checkpoint.
CLASS_DIR = "test"
MODEL_PATH = "/app/pet_classifier.pth"

# Page configuration (must be the first Streamlit command in the script)
st.set_page_config(
    page_title="🐾 Oxford Pet Classifier",
    page_icon="🐾",
    layout="wide",
    initial_sidebar_state="collapsed"
)

# Custom CSS for beautiful styling
st.markdown(""" """, unsafe_allow_html=True)


def _list_class_names(root):
    """Return the sorted names of the sub-directories of ``root``.

    Only directories are kept so that stray files (``.DS_Store``, readme
    files, ...) do not become bogus classes and change the size of the
    classifier head.

    Raises:
        OSError: if ``root`` does not exist or cannot be read.
    """
    with os.scandir(root) as entries:
        return sorted(entry.name for entry in entries if entry.is_dir())


# Load class names from your training dataset
# NOTE(review): these must be in the same order as the classes used for training.
try:
    CLASS_NAMES = _list_class_names(CLASS_DIR)
except OSError as exc:
    st.error(f"Could not read class folders from '{CLASS_DIR}': {exc}")
    st.stop()

if not CLASS_NAMES:
    st.error(f"No class folders found in '{CLASS_DIR}'.")
    st.stop()


@st.cache_resource
def _load_model_cached():
    """Build the ResNet-18 classifier and load the trained weights.

    Exceptions propagate to the caller on purpose: a raised exception is not
    stored by ``st.cache_resource``, so loading is retried on the next run
    instead of a failure being cached for the life of the server.
    """
    # No-argument constructor = randomly initialised weights; avoids the
    # deprecated ``pretrained=`` keyword.
    model = models.resnet18()
    model.fc = nn.Linear(model.fc.in_features, len(CLASS_NAMES))
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(MODEL_PATH, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model


def load_model():
    """Load the trained pet classifier model.

    Returns:
        The model in eval mode, or ``None`` if loading failed (an error
        message is shown in the app in that case).
    """
    try:
        return _load_model_cached()
    except Exception as e:  # UI boundary: report any load failure to the user
        st.error(f"Error loading model: {str(e)}")
        return None


def _play_progress_animation():
    """Show the cosmetic 'analyzing' progress bar (about two seconds)."""
    progress_bar = st.progress(0)
    status_text = st.empty()

    for step in range(100):
        progress_bar.progress(step + 1)
        if step < 30:
            status_text.text("🔍 Loading image...")
        elif step < 60:
            status_text.text("🧠 Processing with model...")
        elif step < 90:
            status_text.text("📊 Analyzing features...")
        else:
            status_text.text("✨ Almost done...")
        time.sleep(0.02)

    # Clear progress elements
    progress_bar.empty()
    status_text.empty()


def _clamp_unit(value):
    """Clamp ``value`` to [0.0, 1.0]; ``st.progress`` rejects anything outside."""
    return min(max(value, 0.0), 1.0)


# Header
st.markdown("""

🐾 Pet Breed Classifier

Discover your pet's breed with my trained model for prediction.

""", unsafe_allow_html=True)

# Sidebar with information
with st.sidebar:
    st.markdown("### 📊 Model Information")
    st.info(f"**Classes:** {len(CLASS_NAMES)} pet breeds")
    st.info("**Architecture:** ResNet-18")
    st.info("**Input Size:** 224x224 pixels")

    st.markdown("### 🎯 How it works")
    st.markdown("""
1. Upload a clear photo of your pet
2. The model analyzes the image features
3. Get instant breed prediction with confidence
""")

    st.markdown("### 💡 Tips for best results")
    st.markdown("""
- Use high-quality, well-lit photos
- Ensure the pet is clearly visible
- Avoid blurry or dark images
- Single pet per image works best
""")

# Load model
model = load_model()
if model is None:
    st.error("Failed to load the model. Please check if 'pet_classifier.pth' exists.")
    st.stop()

# Image transform
# NOTE(review): normalization must match the one used during training - confirm.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# Changing the uploader's widget key is what lets "Try Another" clear the file.
if "uploader_key" not in st.session_state:
    st.session_state["uploader_key"] = 0

# Main content area
col1, col2 = st.columns([1, 1])

image = None

with col1:
    st.markdown("""

📸 Upload Your Pet's Photo

Choose a clear image of your cat or dog

""", unsafe_allow_html=True)

    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png"],
        help="Upload a clear photo of your pet for breed classification",
        key=f"pet_uploader_{st.session_state['uploader_key']}"
    )

    if uploaded_file is not None:
        try:
            image = Image.open(uploaded_file).convert("RGB")
        except OSError:
            # Covers unidentified, truncated and otherwise unreadable image files.
            st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG.")
        else:
            st.image(
                image,
                caption="📷 Your uploaded image",
                use_column_width=True,
                clamp=True
            )

with col2:
    if image is not None:
        st.markdown("""

🔍 Analyzing Your Pet...

""", unsafe_allow_html=True)

        # Play the simulated progress animation only once per uploaded file;
        # every widget click reruns the script and would otherwise replay it.
        file_signature = (uploaded_file.name, uploaded_file.size)
        if st.session_state.get("animated_for") != file_signature:
            _play_progress_animation()
            st.session_state["animated_for"] = file_signature

        # Preprocess and predict
        input_tensor = transform(image).unsqueeze(0)  # (1, 3, 224, 224)
        with torch.no_grad():
            outputs = model(input_tensor)

        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
        # Take label and confidence from the same max so they always agree.
        max_prob, max_index = torch.max(probabilities, dim=0)
        confidence = _clamp_unit(max_prob.item())
        predicted_label = CLASS_NAMES[max_index.item()]

        # Results display
        st.markdown(f"""

🎯 Prediction Results

{predicted_label.replace('_', ' ').title()}

Confidence: {confidence:.1%}

""", unsafe_allow_html=True)

        # Confidence meter
        st.markdown("### 📊 Confidence Level")
        conf_col1, conf_col2, conf_col3 = st.columns([1, 2, 1])
        with conf_col2:
            st.progress(confidence)
            if confidence > 0.8:
                st.success(f"Very confident! ({confidence:.1%})")
            elif confidence > 0.6:
                st.warning(f"Moderately confident ({confidence:.1%})")
            else:
                st.error(f"Low confidence ({confidence:.1%}) - try a clearer image or the animal is not a cat/dog. ")

        # Top predictions
        st.markdown("### 🏆 Top 5 Predictions")
        top_k = torch.topk(probabilities, min(5, len(CLASS_NAMES)))
        # Convert to plain Python numbers once instead of indexing with tensors.
        top_predictions = [
            (CLASS_NAMES[idx].replace('_', ' ').title(), _clamp_unit(prob))
            for prob, idx in zip(top_k.values.tolist(), top_k.indices.tolist())
        ]

        for rank, (class_name, percentage) in enumerate(top_predictions, start=1):
            # Create a nice progress bar for each prediction
            st.markdown(f"**{rank}. {class_name}**")
            st.progress(percentage)
            st.markdown(f"{percentage:.1%}", unsafe_allow_html=True)

        st.markdown("---")

        # Action buttons
        st.markdown("### 🎬 Actions")
        col_btn1, col_btn2, col_btn3 = st.columns(3)

        with col_btn1:
            if st.button("🔄 Try Another", use_container_width=True):
                # A new widget key gives a fresh, empty uploader on the rerun.
                st.session_state["uploader_key"] += 1
                st.session_state["animated_for"] = None
                st.rerun()

        with col_btn2:
            # Plain-text summary so the button really produces a file.
            result_lines = [
                f"Predicted breed: {predicted_label.replace('_', ' ').title()}",
                f"Confidence: {confidence:.1%}",
                "",
                "Top predictions:",
            ]
            result_lines.extend(
                f"{rank}. {class_name} - {percentage:.1%}"
                for rank, (class_name, percentage) in enumerate(top_predictions, start=1)
            )
            if st.download_button(
                "📥 Download Result",
                data="\n".join(result_lines) + "\n",
                file_name="pet_prediction.txt",
                mime="text/plain",
                use_container_width=True
            ):
                st.balloons()
                st.success("Result saved! 🎉")

        with col_btn3:
            if st.button("📤 Share", use_container_width=True):
                st.info("Share feature coming soon! 📱")

    else:
        st.markdown("""

🚀 Ready to classify your pet?

Upload an image to get started! Our model can identify dozens of different cat and dog breeds with high accuracy.


Supported breeds include:

🐱 Cats: Persian, Siamese, Maine Coon, and more
🐶 Dogs: Golden Retriever, German Shepherd, Bulldog, and more

""", unsafe_allow_html=True)

# Footer
st.markdown("---")
st.markdown("""

🐾 Built with ❤️ using Streamlit and PyTorch | Oxford Pet Dataset

""", unsafe_allow_html=True)