# IS_Finals / src /streamlit_app.py
# Tzetha's picture
# Update src/streamlit_app.py
# d869948 verified
import streamlit as st
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os
import time
# Page configuration: browser-tab title/icon, wide layout, sidebar collapsed on load.
_PAGE_SETTINGS = {
    "page_title": "🐾 Oxford Pet Classifier",
    "page_icon": "🐾",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_PAGE_SETTINGS)
# Custom CSS for the whole page. The stylesheet is kept in a named constant and
# injected once through st.markdown; unsafe_allow_html is needed so the <style>
# tag is emitted as HTML rather than shown as text.
_CUSTOM_CSS = """
<style>
/* Main background and theme */
.main {
padding-top: 2rem;
}
/* Custom header styling */
.custom-header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 2rem;
border-radius: 15px;
margin-bottom: 2rem;
text-align: center;
color: white;
box-shadow: 0 10px 30px rgba(102, 126, 234, 0.3);
}
.custom-header h1 {
font-size: 3rem;
margin: 0;
font-weight: 700;
text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
}
.custom-header p {
font-size: 1.2rem;
margin: 0.5rem 0 0 0;
opacity: 0.9;
}
/* Upload area styling */
.upload-container {
background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
padding: 2rem;
border-radius: 15px;
margin: 1rem 0;
text-align: center;
color: white;
box-shadow: 0 8px 25px rgba(240, 147, 251, 0.3);
}
/* Results container */
.results-container {
background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
padding: 2rem;
border-radius: 15px;
margin: 1rem 0;
text-align: center;
color: white;
box-shadow: 0 8px 25px rgba(79, 172, 254, 0.3);
animation: slideIn 0.5s ease-out;
}
@keyframes slideIn {
from {
transform: translateY(20px);
opacity: 0;
}
to {
transform: translateY(0);
opacity: 1;
}
}
/* Custom buttons */
.stButton > button {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 25px;
padding: 0.75rem 2rem;
font-weight: 600;
transition: all 0.3s ease;
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4);
}
.stButton > button:hover {
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6);
}
/* File uploader styling */
.uploadedFile {
border-radius: 10px;
border: 2px dashed #667eea;
padding: 1rem;
}
/* Progress bar */
.stProgress > div > div > div > div {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
}
/* Info boxes */
.info-box {
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(10px);
border-radius: 15px;
padding: 1.5rem;
margin: 1rem 0;
border: 1px solid rgba(255, 255, 255, 0.2);
}
/* Hide Streamlit elements */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
.stDeployButton {display:none;}
/* Custom metrics */
[data-testid="metric-container"] {
background: rgba(255, 255, 255, 0.1);
border: 1px solid rgba(255, 255, 255, 0.2);
padding: 1rem;
border-radius: 10px;
backdrop-filter: blur(10px);
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Load class names from the dataset folder: one sub-directory per class, sorted
# by name. Only directories are counted, so stray files in "test" (.DS_Store,
# README, ...) cannot inflate the class count or shift label indices.
# NOTE(review): "test" is resolved against the current working directory, and
# the sorted order must match the class order used at training time - confirm.
with os.scandir("test") as _class_entries:
    CLASS_NAMES = sorted(entry.name for entry in _class_entries if entry.is_dir())
# Load the model with caching
@st.cache_resource
def _load_model_cached():
    """Build the ResNet-18 classifier and load its trained weights (cached).

    Only a successfully loaded model is returned, so only a success is stored
    by ``st.cache_resource``. Any failure propagates as an exception instead
    of being returned (and therefore cached) as ``None``.

    Returns:
        The model in eval mode, with its final layer sized to
        ``len(CLASS_NAMES)`` outputs.
    """
    # weights=None: start from an untrained backbone; every parameter comes
    # from the checkpoint below.
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, len(CLASS_NAMES))
    # weights_only=True restricts unpickling to tensors/containers, so a
    # tampered checkpoint cannot execute arbitrary code while loading.
    state_dict = torch.load(
        "/app/pet_classifier.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model


def load_model():
    """Load the trained pet classifier model.

    Returns:
        The cached model, or ``None`` if it could not be built or loaded. In
        that case the error is shown with ``st.error``; because the failure is
        not cached, the next script run tries again.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        # Boundary of the app: report the problem in the UI instead of crashing.
        st.error(f"Error loading model: {str(e)}")
        return None
# Header banner; the "custom-header" class is styled by the injected page CSS.
_HEADER_HTML = """
<div class="custom-header">
<h1>🐾 Pet Breed Classifier</h1>
<p>Discover your pet's breed with my trained model for prediction.</p>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
# Sidebar with information: model facts, a short how-to, and photo tips.
# The emoji in the headings were mis-encoded (mojibake) and are restored here.
with st.sidebar:
    st.markdown("### 📊 Model Information")
    st.info(f"**Classes:** {len(CLASS_NAMES)} pet breeds")
    st.info("**Architecture:** ResNet-18")
    st.info("**Input Size:** 224x224 pixels")
    st.markdown("### 🎯 How it works")
    st.markdown("""
1. Upload a clear photo of your pet
2. The model analyzes the image features
3. Get instant breed prediction with confidence
""")
    st.markdown("### 💡 Tips for best results")
    st.markdown("""
- Use high-quality, well-lit photos
- Ensure the pet is clearly visible
- Avoid blurry or dark images
- Single pet per image works best
""")
# Obtain the classifier; without it nothing below can run, so report and halt.
model = load_model()
_model_missing = model is None
if _model_missing:
    st.error("Failed to load the model. Please check if 'pet_classifier.pth' exists.")
    st.stop()
# Image preprocessing applied before inference: resize to 224x224, convert to
# a tensor, then normalize each of the 3 channels with mean 0.5 and std 0.5.
# NOTE(review): presumably mirrors the preprocessing used in training - confirm.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
]
transform = transforms.Compose(_preprocessing_steps)
# Main content area: upload + preview on the left, prediction results on the right.

# The uploader's widget key is rotated by "Try Another". A plain st.rerun()
# keeps the widget's state, so the previously selected file would never clear.
if "uploader_key" not in st.session_state:
    st.session_state["uploader_key"] = 0

col1, col2 = st.columns([1, 1])

# Stays None until an upload has been decoded successfully.
image = None

with col1:
    st.markdown("""
<div class="upload-container">
<h3>📸 Upload Your Pet's Photo</h3>
<p>Choose a clear image of your cat or dog</p>
</div>
""", unsafe_allow_html=True)
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png"],
        help="Upload a clear photo of your pet for breed classification",
        key=f"pet_uploader_{st.session_state['uploader_key']}",
    )
    if uploaded_file is not None:
        try:
            # convert() forces a full decode, so truncated/corrupt files fail
            # here rather than later inside the preprocessing pipeline.
            image = Image.open(uploaded_file).convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError) as exc:
            # The file extension is no guarantee that the bytes are a valid image.
            st.error(f"Could not read the uploaded image: {exc}")
        else:
            st.image(
                image,
                caption="📷 Your uploaded image",
                use_column_width=True,
                clamp=True
            )

with col2:
    if image is not None:
        st.markdown("""
<div class="results-container">
<h3>🔍 Analyzing Your Pet...</h3>
</div>
""", unsafe_allow_html=True)
        # Progress bar animation. Purely cosmetic (about 2 s in total): the
        # actual inference only happens after the loop has finished.
        progress_bar = st.progress(0)
        status_text = st.empty()
        for i in range(100):
            progress_bar.progress(i + 1)
            if i < 30:
                status_text.text("🔍 Loading image...")
            elif i < 60:
                status_text.text("🧠 Processing with model...")
            elif i < 90:
                status_text.text("📊 Analyzing features...")
            else:
                status_text.text("✨ Almost done...")
            time.sleep(0.02)
        # Clear progress elements
        progress_bar.empty()
        status_text.empty()

        # Preprocess and predict
        input_tensor = transform(image).unsqueeze(0)  # (1, 3, 224, 224)
        with torch.no_grad():
            outputs = model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
        # Softmax preserves ordering, so the most probable class is also the
        # arg-max of the raw outputs; one max() yields both value and index.
        top_prob, top_idx = torch.max(probabilities, dim=0)
        confidence = top_prob.item()
        predicted_label = CLASS_NAMES[top_idx.item()]
        display_label = predicted_label.replace('_', ' ').title()

        # Results display
        st.markdown(f"""
<div class="results-container">
<h2>🎯 Prediction Results</h2>
<h1 style="font-size: 2.5rem; margin: 1rem 0;">
{display_label}
</h1>
<p style="font-size: 1.2rem;">
Confidence: {confidence:.1%}
</p>
</div>
""", unsafe_allow_html=True)

        # Confidence meter (centered by using only the middle of three columns)
        st.markdown("### 📊 Confidence Level")
        conf_col1, conf_col2, conf_col3 = st.columns([1, 2, 1])
        with conf_col2:
            st.progress(confidence)
            if confidence > 0.8:
                st.success(f"Very confident! ({confidence:.1%})")
            elif confidence > 0.6:
                st.warning(f"Moderately confident ({confidence:.1%})")
            else:
                st.error(f"Low confidence ({confidence:.1%}) - try a clearer image or the animal is not a cat/dog. ")

        # Top predictions
        st.markdown("### 🏆 Top 5 Predictions")
        top_k = torch.topk(probabilities, min(5, len(CLASS_NAMES)))
        # tolist() yields plain Python floats/ints, safe for list indexing and
        # for st.progress.
        top_predictions = [
            (CLASS_NAMES[idx].replace('_', ' ').title(), prob)
            for prob, idx in zip(top_k.values.tolist(), top_k.indices.tolist())
        ]
        for rank, (class_name, percentage) in enumerate(top_predictions, start=1):
            # One label, bar and percentage per prediction
            st.markdown(f"**{rank}. {class_name}**")
            st.progress(percentage)
            st.markdown(f"<small>{percentage:.1%}</small>", unsafe_allow_html=True)
        st.markdown("---")

        # Plain-text summary offered by the download button below.
        result_lines = [
            f"Prediction: {display_label}",
            f"Confidence: {confidence:.1%}",
            "",
            "Top predictions:",
        ]
        result_lines.extend(
            f"{rank}. {class_name}: {percentage:.1%}"
            for rank, (class_name, percentage) in enumerate(top_predictions, start=1)
        )
        result_text = "\n".join(result_lines) + "\n"

        # Action buttons
        st.markdown("### 🎬 Actions")
        col_btn1, col_btn2, col_btn3 = st.columns(3)
        with col_btn1:
            if st.button("🔄 Try Another", use_container_width=True):
                # A new widget key gives a fresh, empty uploader on the rerun.
                st.session_state["uploader_key"] += 1
                st.rerun()
        with col_btn2:
            # A real download, so the "Result saved!" message is truthful.
            if st.download_button(
                "📥 Download Result",
                data=result_text,
                file_name="pet_prediction.txt",
                mime="text/plain",
                use_container_width=True,
            ):
                st.balloons()
                st.success("Result saved! 🎉")
        with col_btn3:
            if st.button("📤 Share", use_container_width=True):
                st.info("Share feature coming soon! 📱")
    else:
        # Nothing to classify yet (no upload, or the upload could not be decoded).
        st.markdown("""
<div class="info-box">
<h3>🚀 Ready to classify your pet?</h3>
<p>Upload an image to get started! Our model can identify dozens of different cat and dog breeds with high accuracy.</p>
<br>
<p><strong>Supported breeds include:</strong></p>
<p>🐱 Cats: Persian, Siamese, Maine Coon, and more<br>
🐶 Dogs: Golden Retriever, German Shepherd, Bulldog, and more</p>
</div>
""", unsafe_allow_html=True)
# Footer: divider plus a centered credit line. The heart emoji was mis-encoded
# (rendered as a flower glyph) and is restored here.
st.markdown("---")
st.markdown("""
<div style="text-align: center; padding: 2rem; opacity: 0.7;">
<p>🐾 Built with ❤️ using Streamlit and PyTorch | Oxford Pet Dataset</p>
</div>
""", unsafe_allow_html=True)