williamj949 commited on
Commit
feac374
Β·
1 Parent(s): 8a261d9

Intial app code

Browse files
Files changed (1) hide show
  1. app.py +263 -4
app.py CHANGED
@@ -1,7 +1,266 @@
1
  import gradio as gr
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from fastai.vision.all import *
3
+ from pathlib import Path
4
+ import numpy as np
5
 
 
 
6
 
7
+ def load_model():
8
+ """Load the exported FastAI model"""
9
+ try:
10
+ model_path = Path('bears_model_clean.pkl')
11
+ learn = load_learner(model_path)
12
+ return learn
13
+ except Exception as e:
14
+ print(f"Error loading model: {e}")
15
+ return None
16
+
17
+
18
+ learn = load_model()
19
+
20
+
21
+ def classify_bear(image):
22
+ """
23
+ Detect bear species from uploaded image
24
+
25
+ Args:
26
+ image: PIL Image or numpy array
27
+
28
+ Returns:
29
+ dict: Prediction probabilities for each bear type
30
+ """
31
+ if learn is None:
32
+ return {"Error": "Model not loaded properly"}
33
+ if image is None:
34
+ return {"No Image": "Please upload an image"}
35
+
36
+ try:
37
+ # Make prediction
38
+ pred, pred_idx, probs = learn.predict(image)
39
+
40
+ # Get class names
41
+ class_names = learn.dls.vocab
42
+
43
+ # Create confidence dictionary
44
+ confidences = {}
45
+ for i, class_name in enumerate(class_names):
46
+ confidences[class_name] = float(probs[i])
47
+
48
+ return confidences
49
+
50
+ except Exception as e:
51
+ return {"Error": f"Prediction failed: {str(e)}"}
52
+
53
+
54
+ def get_bear_info(prediction_dict):
55
+ """
56
+ Get information about the predicted bear type
57
+
58
+ Args:
59
+ prediction_dict: Dictionary with prediction confidences
60
+
61
+ Returns:
62
+ str: Information about the most likely bear type
63
+ """
64
+ if "Error" in prediction_dict:
65
+ return prediction_dict["Error"]
66
+ if "No Image" in prediction_dict:
67
+ return "Upload an image to learn about the bear species!"
68
+
69
+ # Get the bear type with highest confidence
70
+ top_prediction = max(prediction_dict.items(), key=lambda x: x[1])
71
+ bear_type = top_prediction[0]
72
+ confidence = top_prediction[1]
73
+
74
+ # Bear information dictionary
75
+ bear_info = {
76
+ "black": "🐻 **Black Bear**: The most common bear in North America. They're excellent climbers and swimmers, with a varied omnivorous diet.",
77
+ "grizzly": "🐻 **Grizzly Bear**: A powerful subspecies of brown bear found in North America. Known for their distinctive shoulder hump and long claws.",
78
+ "polar": "πŸ»β€β„οΈ **Polar Bear**: The largest bear species, perfectly adapted to Arctic life. They're excellent swimmers and primarily hunt seals.",
79
+ "panda": "🐼 **Giant Panda**: A beloved bear species native to China, famous for their black and white coloring and bamboo diet.",
80
+ "teddy": "🧸 **Teddy Bear**: A stuffed toy bear! Named after President Theodore Roosevelt, these cuddly companions have been beloved by children for over a century."
81
+ }
82
+
83
+ # Find matching bear info (case insensitive)
84
+ info = ""
85
+ for key, value in bear_info.items():
86
+ if key.lower() in bear_type.lower():
87
+ info = value
88
+ break
89
+
90
+ if not info:
91
+ info = f"🐻 **{bear_type}**: A type of bear!"
92
+
93
+ return f"{info}\n\n**Confidence**: {confidence:.1%}"
94
+
95
+
96
+ def predict_and_explain(image):
97
+ """
98
+ Main function that combines prediction and explanation
99
+
100
+ Args:
101
+ image: Input image
102
+
103
+ Returns:
104
+ tuple: (prediction_dict, explanation_text)
105
+ """
106
+ predictions = classify_bear(image)
107
+ explanation = get_bear_info(predictions)
108
+ return predictions, explanation
109
+
110
+
111
+ def handle_image_change(image):
112
+ """
113
+ Handle image change events with proper None checking
114
+
115
+ Args:
116
+ image: Input image (can be None when cleared)
117
+
118
+ Returns:
119
+ tuple: (prediction_dict, explanation_text)
120
+ """
121
+ if image is None:
122
+ return {}, "Upload an image to learn about the bear species!"
123
+
124
+ return predict_and_explain(image)
125
+
126
+
127
+ def get_sample_images():
128
+ """
129
+ Get list of sample images if they exist
130
+
131
+ Returns:
132
+ list: List of image paths for examples
133
+ """
134
+ sample_paths = [
135
+ "samples/black.jpg",
136
+ "samples/grizzly.jpg",
137
+ "samples/polar.jpg",
138
+ "samples/panda.jpg",
139
+ "samples/teddy.jpg"
140
+ ]
141
+ existing_samples = []
142
+ for path in sample_paths:
143
+ if Path(path).exists():
144
+ existing_samples.append([path])
145
+ print(f"βœ… Found sample image: {path}")
146
+ else:
147
+ print(f"⚠️ Sample image not found: {path}")
148
+
149
+ return existing_samples
150
+
151
+
152
+ def create_interface():
153
+ """Create and configure the Gradio interface"""
154
+
155
+ css = """
156
+ .gradio-container {
157
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
158
+ }
159
+ .bear-title {
160
+ text-align: center;
161
+ color: #8B4513;
162
+ font-size: 2.5em;
163
+ margin-bottom: 20px;
164
+ }
165
+ """
166
+
167
+ with gr.Blocks(css=css, title="🐻 Bear Species Detector") as demo:
168
+ gr.HTML("""
169
+ <div class="bear-title">
170
+ 🐻 Bear Species Detector 🐼
171
+ </div>
172
+ <p style="text-align: center; font-size: 1.2em; color: #666;">
173
+ Upload an image of a bear and I'll tell you what species it is!<br>
174
+ <em>Supports: Black Bear, Grizzly Bear, Polar Bear, Giant Panda, and even Teddy Bears! 🧸</em>
175
+ </p>
176
+ """)
177
+
178
+ with gr.Row():
179
+ with gr.Column():
180
+ # Image input
181
+ image_input = gr.Image(
182
+ label="Upload Bear Image πŸ“Έ",
183
+ type="pil",
184
+ height=400
185
+ )
186
+
187
+ # Submit button
188
+ submit_btn = gr.Button(
189
+ "Detect Bear Type! πŸ”",
190
+ variant="primary",
191
+ size="lg"
192
+ )
193
+
194
+ # Get sample images
195
+ sample_images = get_sample_images()
196
+
197
+ # Only show examples if we have sample images
198
+ if sample_images:
199
+ gr.Examples(
200
+ examples=sample_images,
201
+ inputs=image_input,
202
+ label="Try these examples:"
203
+ )
204
+ else:
205
+ gr.HTML("""
206
+ <p style="text-align: center; color: #888; font-style: italic;">
207
+ πŸ’‘ Add sample images to the 'samples/' folder to see examples here!
208
+ </p>
209
+ """)
210
+
211
+ with gr.Column():
212
+ # Prediction output
213
+ prediction_output = gr.Label(
214
+ label="Prediction Confidence πŸ“Š",
215
+ num_top_classes=5
216
+ )
217
+
218
+ # Bear information output
219
+ info_output = gr.Markdown(
220
+ label="Bear Information πŸ“–",
221
+ value="Upload an image to learn about the bear species!"
222
+ )
223
+
224
+ # Connect the interface
225
+ submit_btn.click(
226
+ fn=predict_and_explain,
227
+ inputs=image_input,
228
+ outputs=[prediction_output, info_output]
229
+ )
230
+
231
+ # Also trigger on image upload
232
+ image_input.change(
233
+ fn=handle_image_change,
234
+ inputs=image_input,
235
+ outputs=[prediction_output, info_output]
236
+ )
237
+
238
+ gr.HTML("""
239
+ <div style="text-align: center; margin-top: 30px; color: #888;">
240
+ <p>Built with ❀️ using FastAI and Gradio</p>
241
+ </div>
242
+ """)
243
+
244
+ return demo
245
+
246
+
247
+ # Main execution
248
+ if __name__ == "__main__":
249
+ # Check if model is loaded
250
+ if learn is None:
251
+ print("❌ Error: Could not load the model. Please ensure 'bears_model_xx.pkl' is in the correct path.")
252
+ print("πŸ’‘ Tip: Update the model_path in the load_model() function to point to your saved model.")
253
+ else:
254
+ print("βœ… Model loaded successfully!")
255
+ print(f"πŸ“‹ Classes: {learn.dls.vocab}")
256
+
257
+ # Create and launch the interface
258
+ demo = create_interface()
259
+
260
+ # Launch the app
261
+ demo.launch(
262
+ share=True, # Set to True to create a public link
263
+ server_name="0.0.0.0", # Allow access from any IP
264
+ server_port=7860, # Default Gradio port
265
+ show_error=True
266
+ )