File size: 9,431 Bytes
f942ba9
 
 
 
 
88c7a58
02318d0
f942ba9
88c7a58
02318d0
88c7a58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02318d0
f942ba9
 
 
88c7a58
f942ba9
695a8aa
f942ba9
 
88c7a58
 
695a8aa
3a0ae17
695a8aa
f942ba9
 
88c7a58
f942ba9
 
 
 
 
88c7a58
 
 
 
f942ba9
88c7a58
f942ba9
88c7a58
 
 
695a8aa
88c7a58
68cdc57
 
e5b3cff
68cdc57
e5b3cff
68cdc57
 
e5b3cff
 
68cdc57
e5b3cff
 
 
 
 
 
 
 
c6fa3f7
 
e5b3cff
 
68cdc57
 
 
 
 
88c7a58
 
 
68cdc57
9e0fb8d
 
 
68cdc57
9e0fb8d
 
 
68cdc57
 
88c7a58
68cdc57
88c7a58
 
 
 
 
68cdc57
 
 
9e0fb8d
68cdc57
9e0fb8d
88c7a58
9e0fb8d
88c7a58
 
 
 
 
 
9e0fb8d
 
 
 
 
68cdc57
9e0fb8d
 
68cdc57
 
88c7a58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f942ba9
 
88c7a58
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
import io
from pathlib import Path
import streamlit as st
from fastai.vision.all import load_learner, PILImage

# Relative path: resolved against the process's current working directory,
# not this file's location. NOTE(review): presumably the repo root when run
# on Hugging Face Spaces — confirm.
MODEL_PATH = Path("models") / "pokemon_gen9_classifier_resnet101_after_cleaning.pkl"

# Custom CSS for modern UI.
# This statement runs at module level, i.e. on every execution of this script,
# and injects a raw <style> block via st.markdown(unsafe_allow_html=True).
# NOTE(review): selectors such as .main, .stSuccess, .stFileUploader,
# .stProgress and .stButton target Streamlit-internal class names that may
# differ between Streamlit versions — confirm they still match.
# .subtitle and .uploaded-image are used by the page markup; .prediction-card,
# .prob-bar, .prob-fill and .prob-label are not referenced by the visible code.
st.markdown("""
<style>
    /* Main container styling */
    .main {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        padding: 2rem;
    }
    
    /* Card-like containers */
    .stApp {
        max-width: 1200px;
        margin: 0 auto;
    }
    
    /* Title styling */
    h1 {
        color: white !important;
        text-align: center;
        font-size: 3rem !important;
        font-weight: 800 !important;
        margin-bottom: 0.5rem !important;
        text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
    }
    
    /* Subtitle styling */
    .subtitle {
        text-align: center;
        color: rgba(255,255,255,0.9);
        font-size: 1.2rem;
        margin-bottom: 2rem;
    }
    
    /* File uploader styling */
    .stFileUploader {
        background: white;
        border-radius: 15px;
        padding: 2rem;
        box-shadow: 0 10px 30px rgba(0,0,0,0.2);
    }
    
    /* Prediction result card */
    .prediction-card {
        background: white;
        border-radius: 15px;
        padding: 2rem;
        margin-top: 2rem;
        box-shadow: 0 10px 30px rgba(0,0,0,0.2);
    }
    
    /* Success message styling */
    .stSuccess {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white !important;
        border-radius: 10px;
        font-size: 1.5rem;
        font-weight: bold;
        text-align: center;
        padding: 1rem;
    }
    
    /* Progress bar */
    .stProgress > div > div {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    }
    
    /* Buttons */
    .stButton > button {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        border: none;
        border-radius: 10px;
        padding: 0.75rem 2rem;
        font-size: 1.1rem;
        font-weight: 600;
        transition: all 0.3s ease;
    }
    
    .stButton > button:hover {
        transform: translateY(-2px);
        box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
    }
    
    /* Image container */
    .uploaded-image {
        border-radius: 15px;
        overflow: hidden;
        box-shadow: 0 10px 30px rgba(0,0,0,0.2);
        margin: 2rem 0;
    }
    
    /* Probability bars */
    .prob-bar {
        background: #f0f2f6;
        border-radius: 10px;
        height: 40px;
        margin: 0.5rem 0;
        overflow: hidden;
        position: relative;
    }
    
    .prob-fill {
        height: 100%;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        display: flex;
        align-items: center;
        padding: 0 1rem;
        color: white;
        font-weight: 600;
        transition: width 0.5s ease;
    }
    
    .prob-label {
        position: absolute;
        left: 1rem;
        top: 50%;
        transform: translateY(-50%);
        font-weight: 600;
        z-index: 1;
    }
</style>
""", unsafe_allow_html=True)

@st.cache_resource
def _load_learner_cached():
    """Load the FastAI learner from MODEL_PATH and cache it for the process.

    Raises whatever ``load_learner`` raises. Streamlit does not store a result
    for a call that raises, so a failed load is retried on the next rerun
    instead of being cached.
    """
    # NOTE: load_learner unpickles the file, which can execute arbitrary code;
    # only point MODEL_PATH at a trusted, bundled model.
    return load_learner(MODEL_PATH)


def load_model():
    """Load and cache the FastAI learner. Returns None if model missing or incompatible.

    Only a successful load is cached. A missing or incompatible model is
    reported with ``st.error`` and ``None`` is returned, but that failure is
    not cached. Fixing the model file and rerunning the app therefore
    recovers without clearing the cache.

    Returns:
        The loaded learner, or None if MODEL_PATH does not exist or loading failed.
    """
    # Cheap existence check done outside the cache so it is re-evaluated on every run.
    if not MODEL_PATH.exists():
        st.error(f"❌ Model not found at {MODEL_PATH}")
        return None
    try:
        return _load_learner_cached()
    except Exception as e:
        # UI boundary: show the loader's error to the user instead of crashing.
        st.error(f"⚠️ Error loading model:\n\n{e}")
        return None


# Keep `load_model.clear()` working, as it did when load_model itself was cached.
load_model.clear = _load_learner_cached.clear

def predict(learner, img_bytes: bytes):
    """Classify an encoded image with *learner*.

    Args:
        learner: the FastAI learner to run (as returned by ``load_model``).
        img_bytes: the raw file content of the image, decoded via ``PILImage.create``.

    Returns:
        A ``(pred, probs)`` pair: the decoded predicted label and the
        probabilities, both exactly as returned by ``learner.predict``
        (its predicted-index output is discarded).
    """
    image_stream = io.BytesIO(img_bytes)
    image = PILImage.create(image_stream)
    label, _idx, class_probs = learner.predict(image)
    return label, class_probs

def _find_example_images(folder):
    """Return the image files (.jpg/.jpeg/.png, any case) directly in *folder*, sorted.

    Sorting gives a stable order across reruns. This matters because each
    "Use" button's key is derived from the image's position in this list.
    Returns an empty list when *folder* is not a directory.
    """
    if not folder.is_dir():
        return []
    return sorted(
        path
        for path in folder.iterdir()
        if path.is_file() and path.suffix.lower() in {".jpg", ".jpeg", ".png"}
    )


def _render_examples(examples_path, max_examples=5):
    """Render the example gallery: up to *max_examples* images, each with a "Use" button.

    Clicking "Use" stores that image's Path in ``st.session_state.example_image``.
    The caller consumes it later in the same script run.
    """
    st.markdown("---")
    st.markdown("### 🖼️ Try with Example Images")

    if not examples_path.exists():
        st.info("💡 **Tip:** Create an 'examples' folder with sample Pokémon images to display them here!")
        return

    example_images = _find_example_images(examples_path)[:max_examples]
    if not example_images:
        st.info("No example images found in the 'examples' folder.")
        return

    cols = st.columns(len(example_images))
    for idx, (col, img_path) in enumerate(zip(cols, example_images)):
        with col:
            st.image(str(img_path), use_container_width=True, caption=img_path.stem)
            if st.button("Use", key=f"example_{idx}"):
                st.session_state.example_image = img_path


def _read_image_bytes(source):
    """Return the encoded bytes of *source*.

    *source* is either a Path to an example image or the object returned by
    ``st.file_uploader`` (an UploadedFile, which is a BytesIO subclass).
    """
    if isinstance(source, Path):
        return source.read_bytes()
    # getvalue() returns the whole buffer regardless of the stream position.
    # read() would return b"" if something had already consumed the stream.
    return source.getvalue()


def _render_prediction(learner, img_bytes, top_k=5):
    """Run ``predict`` on *img_bytes* and render the label, confidence and top-k list."""
    pred, probs = predict(learner, img_bytes)

    st.markdown("### 🎯 Prediction Result")
    st.success(f"✨ **{pred}**")

    max_prob = float(probs.max())
    st.metric("Confidence", f"{max_prob*100:.1f}%", delta=None)

    st.markdown(f"### 📊 Top {top_k} Predictions")
    # Pair every class label with its probability, highest first.
    ranked = sorted(zip(learner.dls.vocab, probs), key=lambda pair: pair[1], reverse=True)
    medals = ("🥇", "🥈", "🥉")
    for i, (label, p) in enumerate(ranked[:top_k]):
        prob = float(p)
        medal = medals[i] if i < len(medals) else "⭐"
        # Built on a single line so the rendered HTML does not depend on
        # how this source code happens to be indented.
        st.markdown(
            '<div style="margin: 1rem 0;">'
            '<div style="display: flex; justify-content: space-between; margin-bottom: 0.3rem;">'
            f'<span style="font-weight: 600; color: #1f2937;">{medal} {label}</span>'
            f'<span style="font-weight: 600; color: #667eea;">{prob * 100:.1f}%</span>'
            "</div></div>",
            unsafe_allow_html=True,
        )
        # st.progress rejects values outside [0.0, 1.0]; clamp against float rounding.
        st.progress(min(max(prob, 0.0), 1.0))


def main():
    """Render the classifier page.

    Steps, in order:
    1. Header.
    2. Load the model; the script stops if it is unavailable.
    3. Example gallery.
    4. Uploader and image preview in the left column.
    5. Prediction in the right column.

    An uploaded file takes priority over a clicked example.
    """
    # Header
    st.title("🎮 Pokémon Gen 9 Classifier")
    st.markdown('<p class="subtitle">Upload a Pokémon image and discover which species it is!</p>', unsafe_allow_html=True)

    learner = load_model()
    if learner is None:
        st.warning(
            "⚠️ Model not loaded. Please ensure the `.pkl` file is correctly placed under `models/` and committed with Git LFS."
        )
        st.stop()

    _render_examples(Path("examples"))
    st.markdown("---")

    # Create two columns for better layout
    col1, col2 = st.columns([1, 1])

    # A clicked example is consumed in this run and then cleared, so it does
    # not persist into later reruns.
    uploaded_file = None
    display_image = None
    if "example_image" in st.session_state:
        uploaded_file = st.session_state.example_image
        display_image = str(uploaded_file)
        del st.session_state.example_image

    with col1:
        file_upload = st.file_uploader(
            "Choose a Pokémon image",
            type=["png", "jpg", "jpeg"],
            help="Upload a clear image of a Generation 9 Pokémon"
        )

        # Prioritize file upload over example
        if file_upload is not None:
            uploaded_file = file_upload
            display_image = file_upload

        if display_image is not None:
            # NOTE(review): each st.markdown call renders as its own element, so
            # this <div> likely does not actually wrap the image — confirm.
            st.markdown('<div class="uploaded-image">', unsafe_allow_html=True)
            st.image(display_image, use_container_width=True)
            st.markdown('</div>', unsafe_allow_html=True)

    with col2:
        if uploaded_file is None:
            # Placeholder when no image is uploaded
            st.info("👆 Upload an image to get started!")
            st.markdown("""
            ### How to use:
            1. 📤 Upload a Pokémon image (PNG, JPG, or JPEG)
            2. ⏳ Wait for the AI to analyze it
            3. 🎉 See the prediction and confidence scores!
            
            **Tip:** Use clear, well-lit images for best results!
            """)
            return

        with st.spinner("🔍 Analyzing image..."):
            try:
                _render_prediction(learner, _read_image_bytes(uploaded_file))
            except Exception as e:
                # UI boundary: surface any read/predict failure to the user.
                st.error(f"❌ Error during prediction: {e}")

if __name__ == "__main__":
    main()