File size: 5,969 Bytes
523b89b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""
Streamlit app for CV200 Pet Classifier

Run:
  streamlit run streamlit_app.py
"""
import html
import io
import os

import requests
import streamlit as st
from PIL import Image

# Streamlit Community Cloud stores secrets in `st.secrets` (not necessarily env vars).
def _get_api_url_default() -> str:
    """Return the initial API base URL for a new session.

    Resolution order: the ``API_URL`` key in ``st.secrets``, then the
    ``API_URL`` environment variable, then the public default Space URL.
    Blank (empty or whitespace-only) values are skipped. The chosen value has
    surrounding whitespace and trailing slashes removed, so callers can append
    paths such as ``/predict`` without producing ``//predict``.
    """
    try:
        secret_val = st.secrets.get("API_URL")  # type: ignore[attr-defined]
    except Exception:
        # Accessing st.secrets can raise (e.g. when no secrets file exists);
        # treat any failure as "no secret configured" and keep falling back.
        secret_val = None
    for candidate in (secret_val, os.environ.get("API_URL")):
        if candidate and str(candidate).strip():
            return str(candidate).strip().rstrip("/")
    return "https://solarevat-cv200.hf.space"

# Browser-tab title/icon plus a wide layout with the sidebar open on load.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="Pet Classifier",
    page_icon="🐾",
)

# Seed the API base URL only on the first run of a browser session; later
# reruns keep whatever value the sidebar field has written back.
if "api_url" not in st.session_state:
    st.session_state["api_url"] = _get_api_url_default()

st.title("🐾 Pet Classifier")
st.markdown("Upload an image of a pet to classify it using our CV200 model!")

# Sidebar with info and settings
with st.sidebar:
    st.header("Settings")
    # Seeded from session state and written back, so an edited URL survives
    # reruns for the rest of the browser session.
    st.session_state.api_url = st.text_input(
        "API_URL",
        value=st.session_state.api_url,
        help="FastAPI base URL. Default: https://solarevat-cv200.hf.space",
    )
    
    st.header("About")
    st.markdown("""
    This app uses a deep learning model trained on 200 pet classes.
    
    **How to use:**
    1. Upload an image of a pet
    2. Click "Classify Pet"
    3. See the top 5 predictions with confidence scores
    """)
    
    st.header("API Info")
    st.markdown("**Endpoints:**")
    st.markdown("- `/predict` - Single image")
    st.markdown("- `/predict_batch` - Multiple images")
    st.markdown("- `/docs` - API documentation")
    st.markdown("- `/healthz` - Health check")

# Show the active endpoint with a link to its docs page. A trailing slash is
# stripped from the link so "https://host/" points at "/docs", not "//docs".
st.info(
    f"Current API_URL: `{st.session_state.api_url}` — "
    f"[View API docs]({st.session_state.api_url.rstrip('/')}/docs)"
)

# Main content


def _error_details(response: requests.Response) -> object:
    """Return a value ``st.json`` can display for a non-200 API response.

    Uses the parsed JSON body when there is one. If the body is empty or not
    valid JSON (e.g. an HTML error page), returns a small dict instead, so a
    decoding error cannot mask the real HTTP status.
    """
    if not response.content:
        return {"error": "No response"}
    try:
        return response.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        return {"error": response.text}


def _render_predictions(predictions: list) -> None:
    """Show predictions as metric cards, a top-prediction headline and a table.

    Each item is read as a mapping with ``class_name`` and ``confidence``
    keys. ``confidence`` is displayed as a percentage, so it is presumably a
    fraction in [0, 1] — TODO confirm against the API's response schema.
    """
    # At most three columns; cards wrap row by row.
    cols = st.columns(min(len(predictions), 3))
    for idx, pred in enumerate(predictions):
        with cols[idx % 3]:
            confidence = pred['confidence'] * 100
            st.metric(label=pred['class_name'], value=f"{confidence:.1f}%")
            # st.progress rejects floats outside [0.0, 1.0]; clamp so rounding
            # noise or an odd score cannot abort rendering. float() keeps an
            # integer score of 1 from being read as "1 percent".
            st.progress(min(max(float(pred['confidence']), 0.0), 1.0))

    # Treats the first entry as the top one (assumes the API returns results
    # sorted by confidence — TODO confirm).
    top_pred = predictions[0]
    st.markdown("---")
    st.markdown(f"### 🏆 Top Prediction: **{top_pred['class_name']}**")
    st.markdown(f"**Confidence:** {top_pred['confidence'] * 100:.2f}%")

    with st.expander("View all predictions"):
        import pandas as pd  # only needed once there are results to tabulate

        df = pd.DataFrame(predictions)
        df['confidence'] = df['confidence'].apply(lambda x: f"{x*100:.2f}%")
        df = df.rename(columns={
            'class_name': 'Pet Breed',
            'confidence': 'Confidence',
            'class_id': 'Class ID',
        })
        st.dataframe(df[['Pet Breed', 'Confidence']], use_container_width=True)


def _classify(uploaded_file) -> None:
    """POST *uploaded_file* to the API's ``/predict`` endpoint and render the outcome."""
    # Strip a trailing slash so "https://host/" does not yield "//predict".
    api_url = st.session_state.api_url.rstrip("/")
    files = {'file': (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
    try:
        response = requests.post(
            f"{api_url}/predict",
            files=files,
            data={'top_k': 5},
            timeout=30,
        )
    except requests.exceptions.RequestException as e:
        st.error(f"Failed to connect to API: {e}")
        st.info(f"Make sure the API is running at {st.session_state.api_url}")
        return

    if response.status_code != 200:
        st.error(f"API Error: {response.status_code}")
        st.json(_error_details(response))
        return

    try:
        result = response.json()
    except ValueError:  # the API answered, but not with JSON
        st.error("API returned a response that is not valid JSON.")
        return

    st.success("Classification complete!")
    st.subheader("Top Predictions:")

    predictions = result.get('predictions', [])
    if predictions:
        _render_predictions(predictions)
    else:
        st.warning("No predictions returned.")


uploaded_file = st.file_uploader(
    "Choose an image...",
    type=['jpg', 'jpeg', 'png', 'webp'],
    help="Upload an image of a pet to classify"
)

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
    except OSError as e:
        # Pillow raises UnidentifiedImageError (an OSError subclass) when the
        # bytes are not a decodable image, e.g. a renamed non-image file.
        st.error(f"Could not read the uploaded file as an image: {e}")
    else:
        st.image(image, caption="Uploaded Image", use_container_width=True)

        if st.button("🔍 Classify Pet", type="primary"):
            with st.spinner("Classifying pet... Please wait."):
                try:
                    _classify(uploaded_file)
                except Exception as e:
                    # Last-resort boundary for malformed payloads (missing
                    # keys, unexpected types) so the rest of the page renders.
                    st.error(f"An error occurred: {e}")

# Footer
st.markdown("---")
# The URL comes from secrets/env or the free-text sidebar field, so escape it
# before embedding it in raw HTML (unsafe_allow_html=True). A trailing slash
# is stripped so the link targets "/docs" rather than "//docs".
_docs_href = html.escape(f"{st.session_state.api_url.rstrip('/')}/docs", quote=True)
st.markdown(
    f"""
    <div style='text-align: center; color: #666;'>
        <p>Powered by CV200 Pet Classifier API</p>
        <p><a href="{_docs_href}" target="_blank">View API Documentation</a></p>
    </div>
    """,
    unsafe_allow_html=True
)