|
|
import streamlit as st |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import models, transforms |
|
|
from PIL import Image |
|
|
import os |
|
|
import time |
|
|
|
|
|
|
|
|
# Browser-tab title/icon and overall layout. Streamlit requires set_page_config
# to run before any other Streamlit command, so this stays at the top.
_PAGE_SETTINGS = {
    "page_title": "πΎ Oxford Pet Classifier",
    "page_icon": "πΎ",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_PAGE_SETTINGS)
|
|
|
|
|
|
|
|
# Inject the app-wide stylesheet on every run. It defines the gradient card
# classes used by the HTML snippets rendered further down (.custom-header,
# .upload-container, .results-container with its slideIn entrance animation,
# .info-box), restyles st.button and st.progress, and hides Streamlit's main
# menu, footer and deploy button.
# NOTE(review): selectors such as `.stDeployButton`,
# `.stProgress > div > div > div > div` and `[data-testid="metric-container"]`
# target Streamlit-internal DOM names that can change between releases —
# re-check them after upgrading Streamlit.
# NOTE(review): `.uploadedFile` and the "Custom metrics" rule are not used by
# any HTML or st.metric call visible in this file — confirm before relying on them.
st.markdown("""
<style>
    /* Main background and theme */
    .main {
        padding-top: 2rem;
    }

    /* Custom header styling */
    .custom-header {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        padding: 2rem;
        border-radius: 15px;
        margin-bottom: 2rem;
        text-align: center;
        color: white;
        box-shadow: 0 10px 30px rgba(102, 126, 234, 0.3);
    }

    .custom-header h1 {
        font-size: 3rem;
        margin: 0;
        font-weight: 700;
        text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
    }

    .custom-header p {
        font-size: 1.2rem;
        margin: 0.5rem 0 0 0;
        opacity: 0.9;
    }

    /* Upload area styling */
    .upload-container {
        background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
        padding: 2rem;
        border-radius: 15px;
        margin: 1rem 0;
        text-align: center;
        color: white;
        box-shadow: 0 8px 25px rgba(240, 147, 251, 0.3);
    }

    /* Results container */
    .results-container {
        background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
        padding: 2rem;
        border-radius: 15px;
        margin: 1rem 0;
        text-align: center;
        color: white;
        box-shadow: 0 8px 25px rgba(79, 172, 254, 0.3);
        animation: slideIn 0.5s ease-out;
    }

    @keyframes slideIn {
        from {
            transform: translateY(20px);
            opacity: 0;
        }
        to {
            transform: translateY(0);
            opacity: 1;
        }
    }

    /* Custom buttons */
    .stButton > button {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        border: none;
        border-radius: 25px;
        padding: 0.75rem 2rem;
        font-weight: 600;
        transition: all 0.3s ease;
        box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4);
    }

    .stButton > button:hover {
        transform: translateY(-2px);
        box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6);
    }

    /* File uploader styling */
    .uploadedFile {
        border-radius: 10px;
        border: 2px dashed #667eea;
        padding: 1rem;
    }

    /* Progress bar */
    .stProgress > div > div > div > div {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    }

    /* Info boxes */
    .info-box {
        background: rgba(255, 255, 255, 0.1);
        backdrop-filter: blur(10px);
        border-radius: 15px;
        padding: 1.5rem;
        margin: 1rem 0;
        border: 1px solid rgba(255, 255, 255, 0.2);
    }

    /* Hide Streamlit elements */
    #MainMenu {visibility: hidden;}
    footer {visibility: hidden;}
    .stDeployButton {display:none;}

    /* Custom metrics */
    [data-testid="metric-container"] {
        background: rgba(255, 255, 255, 0.1);
        border: 1px solid rgba(255, 255, 255, 0.2);
        padding: 1rem;
        border-radius: 10px;
        backdrop-filter: blur(10px);
    }
</style>
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
def _discover_class_names(root: str = "test") -> "list[str]":
    """Return the sorted names of the immediate sub-directories of *root*.

    Each sub-directory is one class, and position ``i`` in the result is used
    as the label for output logit ``i`` of the model. Only directories count:
    a plain ``os.listdir`` would also pick up stray files (``.DS_Store``,
    ``README.txt``, ...), adding bogus classes and shifting every index after
    them. This is the same rule torchvision's ``ImageFolder.find_classes``
    applies (sorted ``os.scandir`` entries that are directories).

    NOTE(review): presumes the checkpoint was trained with ImageFolder-style
    indexing over these same folder names — confirm against the training code.

    Args:
        root: dataset split directory, resolved relative to the current
            working directory.

    Raises:
        FileNotFoundError: if *root* does not exist.
    """
    with os.scandir(root) as entries:
        return sorted(entry.name for entry in entries if entry.is_dir())


# Class index -> class folder name, shared by model construction and display.
CLASS_NAMES = _discover_class_names("test")
|
|
|
|
|
|
|
|
@st.cache_resource
def _load_model_cached(model_path: str) -> nn.Module:
    """Build ResNet-18 with a ``len(CLASS_NAMES)``-way head and load *model_path*.

    Cached by Streamlit per *model_path*. Failures propagate as exceptions
    instead of being turned into ``None``: ``st.cache_resource`` caches return
    values (including ``None``) but not exceptions, so a failed load is retried
    on the next rerun instead of being memoized for the life of the server.
    """
    # Architecture only; the trained weights come from the checkpoint below.
    # (`weights=None` replaces the deprecated `pretrained=False`, torchvision >= 0.13.)
    model = models.resnet18(weights=None)
    # Swap the default 1000-way head for one output per class so the
    # checkpoint's `fc` parameter shapes line up.
    model.fc = nn.Linear(model.fc.in_features, len(CLASS_NAMES))
    # The file is a state_dict (it goes straight into load_state_dict), so the
    # restricted `weights_only` unpickler is sufficient and refuses arbitrary
    # pickled code. map_location lets a GPU-saved checkpoint load on CPU.
    state_dict = torch.load(
        model_path, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: batch norm uses its running statistics
    return model


def load_model(model_path: str = "/app/pet_classifier.pth"):
    """Load the trained pet classifier model.

    Args:
        model_path: checkpoint (state_dict) to load; defaults to the path the
            app has always used.

    Returns:
        The model in eval mode on CPU, or ``None`` after showing an
        ``st.error`` if loading failed. Successful loads are cached; failures
        are retried on the next rerun.
    """
    try:
        return _load_model_cached(model_path)
    except Exception as e:  # UI boundary: surface any load failure to the user
        st.error(f"Error loading model: {str(e)}")
        return None
|
|
|
|
|
|
|
|
# Page banner; the `custom-header` CSS class supplies the gradient card look.
_HEADER_HTML = """
<div class="custom-header">
    <h1>πΎ Pet Breed Classifier</h1>
    <p>Discover your pet's breed with my trained model for prediction.</p>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
# Sidebar content: static facts about the model, then usage guidance.
_HOW_IT_WORKS_MD = """
1. Upload a clear photo of your pet
2. The model analyzes the image features
3. Get instant breed prediction with confidence
"""

_TIPS_MD = """
- Use high-quality, well-lit photos
- Ensure the pet is clearly visible
- Avoid blurry or dark images
- Single pet per image works best
"""

with st.sidebar:
    st.markdown("### π Model Information")
    for model_fact in (
        f"**Classes:** {len(CLASS_NAMES)} pet breeds",
        "**Architecture:** ResNet-18",
        "**Input Size:** 224x224 pixels",
    ):
        st.info(model_fact)

    st.markdown("### π― How it works")
    st.markdown(_HOW_IT_WORKS_MD)

    st.markdown("### π‘ Tips for best results")
    st.markdown(_TIPS_MD)
|
|
|
|
|
|
|
|
# Fetch the classifier (load_model returns None on failure). Nothing below can
# work without it, so halt the script with an explanatory error in that case.
if (model := load_model()) is None:
    st.error("Failed to load the model. Please check if 'pet_classifier.pth' exists.")
    st.stop()
|
|
|
|
|
|
|
|
# Preprocessing for each uploaded image: resize to a fixed 224x224, convert to
# a float tensor, then normalize every RGB channel with mean 0.5 / std 0.5.
# NOTE(review): these statistics must match the ones used when the checkpoint
# was trained — confirm against the training script.
_INPUT_SIZE = (224, 224)
_CHANNEL_STATS = [0.5] * 3

transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_STATS, std=_CHANNEL_STATS),
    ]
)
|
|
|
|
|
|
|
|
# Two equal-width columns: upload + preview on the left, results on the right.
col1, col2 = st.columns([1, 1])

# Session-state counter baked into the uploader's widget key. Bumping it makes
# Streamlit create a brand-new, empty uploader; that is how "Try Another"
# clears the current image (a bare st.rerun() keeps the uploaded file).
_UPLOADER_GENERATION = "uploader_generation"

# How many real class names to show as examples in the idle info box.
_BREED_SAMPLE_SIZE = 6


def _reset_uploader():
    """`on_click` callback for "Try Another": retire the current uploader widget."""
    st.session_state[_UPLOADER_GENERATION] = (
        st.session_state.get(_UPLOADER_GENERATION, 0) + 1
    )


def _pretty_label(class_name):
    """Turn a class folder name into display text ("great_pyrenees" -> "Great Pyrenees")."""
    return class_name.replace('_', ' ').title()


def _result_report(label, confidence, ranked):
    """Build the plain-text prediction summary offered by the download button.

    Args:
        label: display name of the top prediction.
        confidence: probability of the top prediction, in [0, 1].
        ranked: ``(display_name, probability)`` pairs, best first.
    """
    lines = [
        f"Predicted breed: {label}",
        f"Confidence: {confidence:.1%}",
        "",
        "Top predictions:",
    ]
    lines += [
        f"{rank}. {name}: {prob:.1%}"
        for rank, (name, prob) in enumerate(ranked, start=1)
    ]
    return "\n".join(lines) + "\n"


with col1:
    st.markdown("""
    <div class="upload-container">
        <h3>πΈ Upload Your Pet's Photo</h3>
        <p>Choose a clear image of your cat or dog</p>
    </div>
    """, unsafe_allow_html=True)

    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png"],
        help="Upload a clear photo of your pet for breed classification",
        key=f"pet_uploader_{st.session_state.get(_UPLOADER_GENERATION, 0)}",
    )

    # Decoded RGB image, or None when nothing usable has been uploaded.
    image = None
    if uploaded_file is not None:
        try:
            # `type=` only filters by file extension, so the bytes may still be
            # corrupt or not an image at all.
            image = Image.open(uploaded_file).convert("RGB")
        except (OSError, Image.DecompressionBombError):
            # PIL's UnidentifiedImageError and truncated-file errors are OSErrors.
            st.error(
                "Sorry, this file could not be read as an image. "
                "Please upload a valid JPG or PNG photo."
            )
        else:
            st.image(
                image,
                caption="π· Your uploaded image",
                use_column_width=True,
            )

with col2:
    if image is not None:
        st.markdown("""
        <div class="results-container">
            <h3>π Analyzing Your Pet...</h3>
        </div>
        """, unsafe_allow_html=True)

        # Show progress only for real work. A scripted, time.sleep-driven
        # progress animation would replay on every widget click, because each
        # interaction reruns the whole script.
        with st.spinner("Running the classifier..."):
            input_tensor = transform(image).unsqueeze(0)  # add a batch dimension
            with torch.no_grad():
                outputs = model(input_tensor)
                probabilities = torch.softmax(outputs[0], dim=0)

        # Softmax is monotonic, so its arg-max is also the logits' arg-max.
        top_prob, top_idx = torch.max(probabilities, dim=0)
        confidence = top_prob.item()
        predicted_label = CLASS_NAMES[top_idx.item()]
        display_label = _pretty_label(predicted_label)

        st.markdown(f"""
        <div class="results-container">
            <h2>π― Prediction Results</h2>
            <h1 style="font-size: 2.5rem; margin: 1rem 0;">
                {display_label}
            </h1>
            <p style="font-size: 1.2rem;">
                Confidence: {confidence:.1%}
            </p>
        </div>
        """, unsafe_allow_html=True)

        st.markdown("### π Confidence Level")
        _, conf_col2, _ = st.columns([1, 2, 1])
        with conf_col2:
            st.progress(confidence)
            if confidence > 0.8:
                st.success(f"Very confident! ({confidence:.1%})")
            elif confidence > 0.6:
                st.warning(f"Moderately confident ({confidence:.1%})")
            else:
                st.error(f"Low confidence ({confidence:.1%}) - try a clearer image or the animal is not a cat/dog. ")

        st.markdown("### π Top 5 Predictions")
        top_k = torch.topk(probabilities, min(5, len(CLASS_NAMES)))
        # (display name, probability) pairs, best first; reused by the report.
        ranked = [
            (_pretty_label(CLASS_NAMES[idx]), prob)
            for prob, idx in zip(top_k.values.tolist(), top_k.indices.tolist())
        ]
        for rank, (name, prob) in enumerate(ranked, start=1):
            st.markdown(f"**{rank}. {name}**")
            st.progress(prob)
            st.markdown(f"<small>{prob:.1%}</small>", unsafe_allow_html=True)
            st.markdown("---")

        st.markdown("### π¬ Actions")
        col_btn1, col_btn2, col_btn3 = st.columns(3)

        with col_btn1:
            st.button(
                "π Try Another",
                on_click=_reset_uploader,
                use_container_width=True,
            )

        with col_btn2:
            # Actually deliver the result as a file.
            st.download_button(
                "π₯ Download Result",
                data=_result_report(display_label, confidence, ranked),
                file_name="pet_prediction.txt",
                mime="text/plain",
                use_container_width=True,
            )

        with col_btn3:
            if st.button("π€ Share", use_container_width=True):
                st.info("Share feature coming soon! π±")

    else:
        # Show real class names rather than hard-coded examples. Breeds such as
        # Golden Retriever or German Shepherd are not among the Oxford-IIIT Pet
        # classes this app is titled for. Sample across the sorted list so the
        # examples are not all from the start of the alphabet.
        step = max(1, len(CLASS_NAMES) // _BREED_SAMPLE_SIZE)
        examples = [
            _pretty_label(name) for name in CLASS_NAMES[::step][:_BREED_SAMPLE_SIZE]
        ]
        examples_text = ", ".join(examples)
        if len(CLASS_NAMES) > len(examples):
            examples_text += ", and more"

        st.markdown(f"""
        <div class="info-box">
            <h3>π Ready to classify your pet?</h3>
            <p>Upload an image to get started! Our model can identify {len(CLASS_NAMES)} different cat and dog breeds.</p>
            <br>
            <p><strong>Supported breeds include:</strong></p>
            <p>{examples_text}</p>
        </div>
        """, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
# Page footer: a horizontal rule, then a centred, dimmed credit line.
_FOOTER_HTML = """
<div style="text-align: center; padding: 2rem; opacity: 0.7;">
    <p>πΎ Built with β€οΈ using Streamlit and PyTorch | Oxford Pet Dataset</p>
</div>
"""

for footer_part, allow_html in (("---", False), (_FOOTER_HTML, True)):
    if allow_html:
        st.markdown(footer_part, unsafe_allow_html=True)
    else:
        st.markdown(footer_part)