arju10 commited on
Commit
afda760
Β·
verified Β·
1 Parent(s): 7bda02e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +466 -0
app.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ πŸ¦‹ Butterfly Species Classifier - Streamlit Web App
3
+ Production-ready web interface for butterfly identification
4
+ Features:
5
+ - Upload butterfly images
6
+ - Get instant predictions
7
+ - View top-5 most likely species
8
+ - Confidence visualization
9
+ - Beautiful, user-friendly interface
10
+ """
11
+
12
+ import streamlit as st
13
+ import tensorflow as tf
14
+ from tensorflow import keras
15
+ import numpy as np
16
+ from PIL import Image
17
+ import json
18
+ import os
19
+ import plotly.graph_objects as go
20
+ from datetime import datetime
21
+ import warnings
22
+ warnings.filterwarnings('ignore')
23
+
24
+ # Page configuration
25
+ st.set_page_config(
26
+ page_title="πŸ¦‹ Butterfly Classifier",
27
+ page_icon="πŸ¦‹",
28
+ layout="wide",
29
+ initial_sidebar_state="expanded"
30
+ )
31
+
32
+ # Custom CSS for better styling
33
+ st.markdown("""
34
+ <style>
35
+ .main-header {
36
+ font-size: 3rem;
37
+ font-weight: bold;
38
+ text-align: center;
39
+ color: #10b981;
40
+ margin-bottom: 0.5rem;
41
+ }
42
+ .sub-header {
43
+ font-size: 1.2rem;
44
+ text-align: center;
45
+ color: #6b7280;
46
+ margin-bottom: 2rem;
47
+ }
48
+ .prediction-card {
49
+ background-color: #f0fdf4;
50
+ padding: 1.5rem;
51
+ border-radius: 0.5rem;
52
+ border-left: 4px solid #10b981;
53
+ margin: 1rem 0;
54
+ }
55
+ .confidence-high {
56
+ color: #10b981;
57
+ font-weight: bold;
58
+ }
59
+ .confidence-medium {
60
+ color: #f59e0b;
61
+ font-weight: bold;
62
+ }
63
+ .confidence-low {
64
+ color: #ef4444;
65
+ font-weight: bold;
66
+ }
67
+ .stButton>button {
68
+ width: 100%;
69
+ background-color: #10b981;
70
+ color: white;
71
+ font-weight: bold;
72
+ padding: 0.75rem;
73
+ border-radius: 0.5rem;
74
+ border: none;
75
+ font-size: 1.1rem;
76
+ }
77
+ .stButton>button:hover {
78
+ background-color: #059669;
79
+ }
80
+ </style>
81
+ """, unsafe_allow_html=True)
82
+
83
+
84
+ @st.cache_resource
85
+ def load_model_and_classes():
86
+ """Load the trained model and class indices with caching"""
87
+ try:
88
+ # Model path - using .keras format
89
+ model_path = 'butterfly_model_WORKING.keras'
90
+
91
+ # Check if model exists
92
+ if not os.path.exists(model_path):
93
+ st.error(f"❌ Model not found at: {model_path}")
94
+ st.info("""
95
+ **Setup Instructions:**
96
+ 1. Place `butterfly_model_best.keras` in `models/` directory
97
+ 2. Place `class_indices.json` in the project root
98
+ 3. Restart the Streamlit app
99
+ """)
100
+ return None, None, None
101
+
102
+ st.info(f"πŸ“‚ Loading model from: {model_path}")
103
+
104
+ # Load model
105
+ model = None
106
+ try:
107
+ # Load with compile=False for faster loading
108
+ model = keras.models.load_model(model_path, compile=False)
109
+
110
+ # Compile for predictions
111
+ model.compile(
112
+ optimizer='adam',
113
+ loss='categorical_crossentropy',
114
+ metrics=['accuracy']
115
+ )
116
+
117
+ st.success("βœ… Model loaded successfully!")
118
+
119
+ except Exception as e:
120
+ st.error(f"❌ Failed to load model: {e}")
121
+ st.info("""
122
+ **Troubleshooting:**
123
+ 1. Make sure you have the .keras file (not .h5)
124
+ 2. File should be ~173 MB
125
+ 3. Run: `ls -lh models/butterfly_model_best.keras`
126
+ """)
127
+ return None, None, None
128
+
129
+ # Load class indices
130
+ class_indices_path = 'class_indices.json'
131
+ if not os.path.exists(class_indices_path):
132
+ st.error(f"❌ Class indices not found: {class_indices_path}")
133
+ st.info("Run: `python generate_json_files.py` to create it")
134
+ return None, None, None
135
+
136
+ with open(class_indices_path, 'r') as f:
137
+ class_indices = json.load(f)
138
+
139
+ # Create reverse mapping (index -> class name)
140
+ idx_to_class = {v: k for k, v in class_indices.items()}
141
+
142
+ st.success(f"βœ… Loaded {len(class_indices)} butterfly species")
143
+
144
+ return model, class_indices, idx_to_class
145
+
146
+ except Exception as e:
147
+ st.error(f"❌ Unexpected error: {e}")
148
+ import traceback
149
+ with st.expander("Show error details"):
150
+ st.code(traceback.format_exc())
151
+ return None, None, None
152
+
153
+
154
+ def preprocess_image(image, target_size=(224, 224)):
155
+ """Preprocess image for model prediction"""
156
+ # Resize image
157
+ image = image.resize(target_size)
158
+
159
+ # Convert to array and normalize to [0, 1]
160
+ img_array = np.array(image, dtype=np.float32) / 255.0
161
+
162
+ # Add batch dimension
163
+ img_array = np.expand_dims(img_array, axis=0)
164
+
165
+ return img_array
166
+
167
+
168
+ def get_confidence_color(confidence):
169
+ """Return CSS class based on confidence level"""
170
+ if confidence >= 0.7:
171
+ return "confidence-high"
172
+ elif confidence >= 0.4:
173
+ return "confidence-medium"
174
+ else:
175
+ return "confidence-low"
176
+
177
+
178
+ def get_confidence_interpretation(confidence):
179
+ """Return human-readable confidence interpretation"""
180
+ if confidence >= 0.9:
181
+ return "Very High Confidence"
182
+ elif confidence >= 0.7:
183
+ return "High Confidence"
184
+ elif confidence >= 0.5:
185
+ return "Medium Confidence"
186
+ elif confidence >= 0.3:
187
+ return "Low Confidence"
188
+ else:
189
+ return "Very Low Confidence"
190
+
191
+
192
+ def create_confidence_gauge(confidence, species_name):
193
+ """Create a beautiful confidence gauge using Plotly"""
194
+ # Determine color based on confidence
195
+ if confidence >= 0.7:
196
+ bar_color = "#10b981" # Green
197
+ elif confidence >= 0.4:
198
+ bar_color = "#f59e0b" # Yellow
199
+ else:
200
+ bar_color = "#ef4444" # Red
201
+
202
+ fig = go.Figure(go.Indicator(
203
+ mode="gauge+number",
204
+ value=confidence * 100,
205
+ domain={'x': [0, 1], 'y': [0, 1]},
206
+ title={'text': f"Confidence", 'font': {'size': 20}},
207
+ number={'suffix': "%", 'font': {'size': 40}},
208
+ gauge={
209
+ 'axis': {'range': [0, 100], 'tickwidth': 2, 'tickcolor': "darkgray"},
210
+ 'bar': {'color': bar_color, 'thickness': 0.75},
211
+ 'bgcolor': "white",
212
+ 'borderwidth': 2,
213
+ 'bordercolor': "gray",
214
+ 'steps': [
215
+ {'range': [0, 40], 'color': '#fee2e2'},
216
+ {'range': [40, 70], 'color': '#fef3c7'},
217
+ {'range': [70, 100], 'color': '#d1fae5'}
218
+ ],
219
+ 'threshold': {
220
+ 'line': {'color': "red", 'width': 4},
221
+ 'thickness': 0.75,
222
+ 'value': 50
223
+ }
224
+ }
225
+ ))
226
+
227
+ fig.update_layout(
228
+ height=300,
229
+ margin=dict(l=20, r=20, t=60, b=20),
230
+ paper_bgcolor="rgba(0,0,0,0)",
231
+ font={'family': "Arial, sans-serif"}
232
+ )
233
+
234
+ return fig
235
+
236
+
237
+ def create_top_predictions_chart(predictions, idx_to_class, top_k=5):
238
+ """Create horizontal bar chart for top predictions"""
239
+ # Get top k predictions
240
+ top_indices = np.argsort(predictions[0])[-top_k:][::-1]
241
+ top_species = [idx_to_class[i] for i in top_indices]
242
+ top_confidences = predictions[0][top_indices] * 100
243
+
244
+ # Create color scale based on confidence
245
+ colors = []
246
+ for c in top_confidences:
247
+ if c >= 70:
248
+ colors.append('#10b981') # Green
249
+ elif c >= 40:
250
+ colors.append('#f59e0b') # Yellow
251
+ else:
252
+ colors.append('#ef4444') # Red
253
+
254
+ fig = go.Figure(go.Bar(
255
+ x=top_confidences,
256
+ y=top_species,
257
+ orientation='h',
258
+ marker=dict(color=colors),
259
+ text=[f'{c:.1f}%' for c in top_confidences],
260
+ textposition='auto',
261
+ textfont=dict(size=14, color='white', family='Arial Black')
262
+ ))
263
+
264
+ fig.update_layout(
265
+ title=f"Top {top_k} Most Likely Species",
266
+ xaxis_title="Confidence (%)",
267
+ yaxis_title="Species",
268
+ height=300,
269
+ margin=dict(l=20, r=20, t=60, b=20),
270
+ paper_bgcolor="rgba(0,0,0,0)",
271
+ plot_bgcolor="rgba(0,0,0,0)",
272
+ font={'family': "Arial, sans-serif", 'size': 12},
273
+ xaxis=dict(gridcolor='lightgray', range=[0, 100]),
274
+ yaxis=dict(autorange="reversed")
275
+ )
276
+
277
+ return fig
278
+
279
+
280
+ def main():
281
+ # Header
282
+ st.markdown('<p class="main-header">πŸ¦‹ Butterfly Species Classifier</p>', unsafe_allow_html=True)
283
+ st.markdown('<p class="sub-header">Upload a butterfly image to identify its species using AI</p>', unsafe_allow_html=True)
284
+
285
+ # Load model
286
+ with st.spinner("πŸ”„ Loading AI model..."):
287
+ model, class_indices, idx_to_class = load_model_and_classes()
288
+
289
+ # Check if model loaded
290
+ if model is None:
291
+ st.error("❌ Failed to load model. Please check the setup instructions above.")
292
+ st.stop()
293
+
294
+ # Sidebar
295
+ with st.sidebar:
296
+ st.header("ℹ️ About")
297
+ st.write(f"""
298
+ This AI-powered app can identify **{len(class_indices)} different butterfly species** with high accuracy!
299
+
300
+ **How to use:**
301
+ 1. Upload a clear butterfly image
302
+ 2. Click 'Identify Species'
303
+ 3. Get instant predictions!
304
+
305
+ **Best results:**
306
+ - Clear, well-lit photos
307
+ - Butterfly in focus
308
+ - Minimal background clutter
309
+ """)
310
+
311
+ st.divider()
312
+
313
+ st.header("πŸ“Š Model Info")
314
+ if os.path.exists('model_info.json'):
315
+ try:
316
+ with open('model_info.json', 'r') as f:
317
+ model_info = json.load(f)
318
+ st.write(f"**Model:** {model_info.get('best_model', 'MobileNetV2')}")
319
+ st.write(f"**Accuracy:** {model_info.get('best_model_metrics', {}).get('accuracy', 0.85)*100:.1f}%")
320
+ st.write(f"**Species:** {model_info.get('num_classes', len(class_indices))}")
321
+ except:
322
+ st.write(f"**Species:** {len(class_indices)}")
323
+ else:
324
+ st.write(f"**Architecture:** MobileNetV2")
325
+ st.write(f"**Species:** {len(class_indices)}")
326
+ st.write(f"**Format:** Keras 3.x (.keras)")
327
+
328
+ st.divider()
329
+
330
+ st.header("🎯 Tips")
331
+ st.write("""
332
+ - **High confidence (>70%)**: Very reliable
333
+ - **Medium (40-70%)**: Generally good
334
+ - **Low (<40%)**: May need verification
335
+ """)
336
+
337
+ # Main content
338
+ col1, col2 = st.columns([1, 1])
339
+
340
+ with col1:
341
+ st.header("πŸ“€ Upload Image")
342
+ uploaded_file = st.file_uploader(
343
+ "Choose a butterfly image...",
344
+ type=['jpg', 'jpeg', 'png'],
345
+ help="Upload a clear image of a butterfly"
346
+ )
347
+
348
+ if uploaded_file is not None:
349
+ # Display uploaded image
350
+ image = Image.open(uploaded_file).convert('RGB')
351
+ st.image(image, caption='Uploaded Image', use_container_width=True)
352
+
353
+ # Show image info
354
+ st.info(f"πŸ“ Image size: {image.size[0]} x {image.size[1]} pixels")
355
+
356
+ # Predict button
357
+ if st.button("πŸ” Identify Species", type="primary"):
358
+ with st.spinner("πŸ€” Analyzing butterfly..."):
359
+ try:
360
+ # Preprocess image
361
+ processed_image = preprocess_image(image)
362
+
363
+ # Make prediction
364
+ predictions = model.predict(processed_image, verbose=0)
365
+
366
+ # Get top prediction
367
+ top_class_idx = np.argmax(predictions[0])
368
+ top_species = idx_to_class[top_class_idx]
369
+ top_confidence = float(predictions[0][top_class_idx])
370
+
371
+ # Store in session state
372
+ st.session_state['predictions'] = predictions
373
+ st.session_state['top_species'] = top_species
374
+ st.session_state['top_confidence'] = top_confidence
375
+ st.session_state['prediction_time'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
376
+
377
+ st.success("βœ… Prediction complete!")
378
+
379
+ except Exception as e:
380
+ st.error(f"❌ Prediction failed: {e}")
381
+ st.info("Please try uploading a different image.")
382
+
383
+ with col2:
384
+ st.header("🎯 Results")
385
+
386
+ if 'predictions' in st.session_state:
387
+ predictions = st.session_state['predictions']
388
+ top_species = st.session_state['top_species']
389
+ top_confidence = st.session_state['top_confidence']
390
+
391
+ # Main prediction card
392
+ confidence_class = get_confidence_color(top_confidence)
393
+ confidence_text = get_confidence_interpretation(top_confidence)
394
+
395
+ st.markdown(f"""
396
+ <div class="prediction-card">
397
+ <h2 style="margin-top: 0; color: #10b981;">Predicted Species</h2>
398
+ <h1 style="margin: 0.5rem 0; color: #1f2937;">{top_species}</h1>
399
+ <p style="margin: 0; font-size: 1.5rem;" class="{confidence_class}">
400
+ {top_confidence*100:.1f}% - {confidence_text}
401
+ </p>
402
+ </div>
403
+ """, unsafe_allow_html=True)
404
+
405
+ # Confidence gauge
406
+ st.plotly_chart(
407
+ create_confidence_gauge(top_confidence, top_species),
408
+ use_container_width=True
409
+ )
410
+
411
+ # Additional info
412
+ st.info(f"πŸ• Predicted at: {st.session_state['prediction_time']}")
413
+ else:
414
+ st.info("πŸ‘† Upload an image and click 'Identify Species' to see results")
415
+
416
+ # Top predictions chart (full width)
417
+ if 'predictions' in st.session_state:
418
+ st.divider()
419
+ st.header("πŸ“Š Top 5 Predictions")
420
+
421
+ col_chart1, col_chart2 = st.columns([2, 1])
422
+
423
+ with col_chart1:
424
+ st.plotly_chart(
425
+ create_top_predictions_chart(st.session_state['predictions'], idx_to_class, top_k=5),
426
+ use_container_width=True
427
+ )
428
+
429
+ with col_chart2:
430
+ st.subheader("πŸ” Interpretation")
431
+ top_conf = st.session_state['top_confidence']
432
+
433
+ if top_conf >= 0.7:
434
+ st.success("βœ… **High Confidence**")
435
+ st.write("The model is very sure about this prediction!")
436
+ elif top_conf >= 0.4:
437
+ st.warning("⚠️ **Medium Confidence**")
438
+ st.write("The prediction is likely correct, but consider the alternatives.")
439
+ else:
440
+ st.error("❌ **Low Confidence**")
441
+ st.write("The model is uncertain. This might not be in the training dataset.")
442
+
443
+ st.write("**What to do:**")
444
+ if top_conf >= 0.7:
445
+ st.write("- βœ… Trust this prediction")
446
+ st.write("- πŸ“š Use for identification")
447
+ elif top_conf >= 0.4:
448
+ st.write("- πŸ‘€ Check top alternatives")
449
+ st.write("- πŸ” Verify with expert")
450
+ else:
451
+ st.write("- ⚠️ Image may be unclear")
452
+ st.write("- πŸ”„ Try a different photo")
453
+ st.write("- πŸ‘€ Consult an expert")
454
+
455
+ # Footer
456
+ st.divider()
457
+ st.markdown(f"""
458
+ <div style="text-align: center; color: #6b7280; padding: 2rem 0;">
459
+ <p>πŸ¦‹ <strong>Butterfly Species Classifier</strong> | Created by Arju</p>
460
+ <p style="font-size: 0.9rem;">Trained on {len(class_indices) if class_indices else 75} species | Built with TensorFlow & Streamlit</p>
461
+ </div>
462
+ """, unsafe_allow_html=True)
463
+
464
+
465
+ if __name__ == "__main__":
466
+ main()