"""
Streamlit app for the CV200 Pet Classifier.

Run:
    streamlit run streamlit_app.py
"""
import html
import io
import os

import requests
import streamlit as st
from PIL import Image
# Streamlit Community Cloud stores secrets in `st.secrets` (not necessarily env vars).
def _get_api_url_default() -> str:
    """Return the API base URL to seed the session with.

    Resolution order: the ``API_URL`` entry of ``st.secrets``, then the
    ``API_URL`` environment variable, then the public Hugging Face Space.
    Blank candidates are skipped. Surrounding whitespace and trailing slashes
    are removed, so callers can append paths such as ``/predict`` without
    producing ``//predict``.
    """
    try:
        secret_val = st.secrets.get("API_URL")  # type: ignore[attr-defined]
    except Exception:
        # Accessing st.secrets can fail (presumably when no secrets file is
        # configured, e.g. local runs); treat any failure as "no secret".
        secret_val = None
    for candidate in (secret_val, os.environ.get("API_URL")):
        cleaned = str(candidate or "").strip().rstrip("/")
        if cleaned:
            return cleaned
    return "https://solarevat-cv200.hf.space"
# Browser-tab title/icon and overall page layout for the app.
st.set_page_config(
    page_title="Pet Classifier",
    page_icon="🐾",  # paw-prints emoji used as the page icon
    layout="wide",
    initial_sidebar_state="expanded",
)
# Seed the API URL once per session (secrets -> env var -> public default);
# the sidebar text input may replace it on later reruns.
if "api_url" not in st.session_state:
    st.session_state.api_url = _get_api_url_default()

st.title("🐾 Pet Classifier")
st.markdown("Upload an image of a pet to classify it using our CV200 model!")
# Sidebar with info and settings
with st.sidebar:
    st.header("Settings")
    # Normalize the user-supplied base URL: stray whitespace or a trailing
    # slash would otherwise produce request URLs like "https://host//predict".
    st.session_state.api_url = st.text_input(
        "API_URL",
        value=st.session_state.api_url,
        help="FastAPI base URL. Default: https://solarevat-cv200.hf.space",
    ).strip().rstrip("/")

    st.header("About")
    st.markdown("""
This app uses a deep learning model trained on 200 pet classes.

**How to use:**
1. Upload an image of a pet
2. Click "Classify Pet"
3. See the top 5 predictions with confidence scores
""")

    st.header("API Info")
    st.markdown("**Endpoints:**")
    st.markdown("- `/predict` - Single image")
    st.markdown("- `/predict_batch` - Multiple images")
    st.markdown("- `/docs` - API documentation")
    st.markdown("- `/healthz` - Health check")
    st.info(
        f"Current API_URL: `{st.session_state.api_url}` — [View API docs]({st.session_state.api_url}/docs)"
    )
# Main content
def _as_progress(probability) -> float:
    """Clamp *probability* into [0.0, 1.0]; ``st.progress`` rejects values outside it."""
    return min(max(float(probability), 0.0), 1.0)


def _show_error_body(response) -> None:
    """Display the body of a failed API response, as JSON when it parses as JSON."""
    if not response.content:
        st.json({"error": "No response"})
        return
    try:
        st.json(response.json())
    except ValueError:
        # Gateways/proxies in front of the API may answer with HTML or plain
        # text; requests' JSON decode errors subclass ValueError.
        st.code(response.text)


def _show_predictions(predictions) -> None:
    """Render predictions as metric cards, a headline, and a detail table.

    Each item is read as a mapping with ``class_name`` and ``confidence``
    keys. ``confidence`` is scaled by 100 for display, so it is presumably a
    probability in [0, 1] (confirm against the API).
    """
    # Lay the cards out in at most three columns, wrapping extra predictions.
    cols = st.columns(min(len(predictions), 3))
    for idx, pred in enumerate(predictions):
        with cols[idx % 3]:
            confidence = pred['confidence'] * 100
            st.metric(label=pred['class_name'], value=f"{confidence:.1f}%")
            st.progress(_as_progress(pred['confidence']))

    # Show the first entry prominently (presumably the API returns
    # predictions sorted by confidence, best first).
    top_pred = predictions[0]
    st.markdown("---")
    st.markdown(f"### 🏆 Top Prediction: **{top_pred['class_name']}**")
    st.markdown(f"**Confidence:** {top_pred['confidence'] * 100:.2f}%")

    # Show all predictions in a table.
    with st.expander("View all predictions"):
        import pandas as pd  # only needed once predictions are rendered

        df = pd.DataFrame(predictions)
        df['confidence'] = df['confidence'].apply(lambda x: f"{x*100:.2f}%")
        df = df.rename(columns={
            'class_name': 'Pet Breed',
            'confidence': 'Confidence',
            'class_id': 'Class ID',
        })
        st.dataframe(df[['Pet Breed', 'Confidence']], use_container_width=True)


def _show_response(response) -> None:
    """Render a ``/predict`` response: predictions on HTTP 200, the error body otherwise."""
    if response.status_code != 200:
        st.error(f"API Error: {response.status_code}")
        _show_error_body(response)
        return
    try:
        result = response.json()
    except ValueError:
        # The server was reached but replied with a non-JSON body; handle it
        # here so it is not misreported as a connection failure (requests'
        # JSON decode error is also a RequestException subclass).
        st.error("API returned a response that is not valid JSON.")
        st.code(response.text)
        return
    st.success("Classification complete!")
    st.subheader("Top Predictions:")
    predictions = result.get('predictions', [])
    if predictions:
        _show_predictions(predictions)
    else:
        st.warning("No predictions returned.")


uploaded_file = st.file_uploader(
    "Choose an image...",
    type=['jpg', 'jpeg', 'png', 'webp'],
    help="Upload an image of a pet to classify",
)

# The uploader only filters by file extension, so the content may still be
# corrupt or not an image at all.
image = None
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; force a full decode now so a corrupt or truncated
        # upload fails here instead of crashing later inside st.image.
        image.load()
    except (OSError, Image.DecompressionBombError) as e:
        image = None
        st.error(f"Could not read the uploaded file as an image: {e}")

if image is not None:
    # Display uploaded image
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Classification button
    if st.button("🔍 Classify Pet", type="primary"):
        with st.spinner("Classifying pet... Please wait."):
            try:
                # Send the raw uploaded bytes, not the decoded PIL image.
                response = requests.post(
                    f"{st.session_state.api_url}/predict",
                    files={'file': (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)},
                    data={'top_k': 5},
                    timeout=30,
                )
                _show_response(response)
            except requests.exceptions.RequestException as e:
                st.error(f"Failed to connect to API: {str(e)}")
                st.info(f"Make sure the API is running at {st.session_state.api_url}")
            except Exception as e:
                # UI boundary: surface unexpected errors (e.g. a malformed
                # prediction payload) instead of a raw traceback.
                st.error(f"An error occurred: {str(e)}")
# Footer
st.markdown("---")
# The API URL comes from a free-text sidebar field (or secrets/env) and is
# placed inside raw HTML, so escape it to keep quotes or angle brackets from
# breaking out of the href attribute.
_docs_href = html.escape(f"{st.session_state.api_url}/docs", quote=True)
st.markdown(
    f"""
<div style='text-align: center; color: #666;'>
<p>Powered by CV200 Pet Classifier API</p>
<p><a href="{_docs_href}" target="_blank" rel="noopener noreferrer">View API Documentation</a></p>
</div>
""",
    unsafe_allow_html=True,
)