Alamgirapi commited on
Commit
72c86b8
Β·
verified Β·
1 Parent(s): 7f38e7d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +250 -0
app.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import timm
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ import io
11
+ import requests
12
+ import tempfile
13
+ import os
14
+
15
+ # Set page config
16
+ st.set_page_config(
17
+ page_title="Dog Breed Classifier",
18
+ page_icon="πŸ•",
19
+ layout="wide"
20
+ )
21
+
22
+ # Default model URL
23
+ DEFAULT_MODEL_URL = "https://huggingface.co/Alamgirapi/dog-breed-convnext-classifier/resolve/main/model.pth"
24
+
25
+ # Device setup
26
+ @st.cache_resource
27
+ def setup_device_and_model():
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+
30
+ # Initialize model
31
+ model = timm.create_model('convnext_base', pretrained=True)
32
+
33
+ # Define label names
34
+ label_names = ['beagle', 'bulldog', 'dalmatian', 'german-shepherd', 'husky', 'poodle', 'rottweiler']
35
+
36
+ # Replace head with proper flattening
37
+ model.head = nn.Sequential(
38
+ nn.AdaptiveAvgPool2d(1),
39
+ nn.Flatten(),
40
+ nn.Linear(model.head.in_features, len(label_names))
41
+ )
42
+
43
+ model = model.to(device)
44
+
45
+ # Define transform
46
+ transform = transforms.Compose([
47
+ transforms.Resize((224, 224)),
48
+ transforms.ToTensor(),
49
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
50
+ ])
51
+
52
+ return device, model, label_names, transform
53
+
54
+ @st.cache_resource
55
+ def download_and_load_model(_model, device):
56
+ """Download and load model weights from Hugging Face"""
57
+ try:
58
+ with st.spinner("Downloading model from Hugging Face..."):
59
+ # Download the model file
60
+ response = requests.get(DEFAULT_MODEL_URL)
61
+ response.raise_for_status()
62
+
63
+ # Save to temporary file
64
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.pth') as tmp_file:
65
+ tmp_file.write(response.content)
66
+ tmp_model_path = tmp_file.name
67
+
68
+ # Load the model weights
69
+ _model.load_state_dict(torch.load(tmp_model_path, map_location=device))
70
+ _model.eval()
71
+
72
+ # Clean up temporary file
73
+ os.unlink(tmp_model_path)
74
+
75
+ return True
76
+ except Exception as e:
77
+ st.error(f"Error downloading/loading model: {str(e)}")
78
+ return False
79
+
80
+ def predict_image(image, model, transform, label_names, device, topk=3):
81
+ """Make predictions on uploaded image"""
82
+ # Transform image
83
+ if image.mode != 'RGB':
84
+ image = image.convert('RGB')
85
+
86
+ img_tensor = transform(image).unsqueeze(0).to(device)
87
+
88
+ # Predict
89
+ model.eval()
90
+ with torch.no_grad():
91
+ outputs = model(img_tensor)
92
+ probs = F.softmax(outputs, dim=1)
93
+ top_probs, top_idxs = torch.topk(probs, k=topk)
94
+
95
+ # Convert to CPU for display
96
+ top_probs = top_probs[0].cpu().numpy()
97
+ top_idxs = top_idxs[0].cpu().numpy()
98
+
99
+ # Build prediction results
100
+ predictions = []
101
+ for idx, prob in zip(top_idxs, top_probs):
102
+ predictions.append({
103
+ 'breed': label_names[idx],
104
+ 'confidence': prob * 100
105
+ })
106
+
107
+ return predictions
108
+
109
+ def create_prediction_chart(predictions):
110
+ """Create a horizontal bar chart for predictions"""
111
+ breeds = [pred['breed'].replace('-', ' ').title() for pred in predictions]
112
+ confidences = [float(pred['confidence']) for pred in predictions] # Convert to Python float
113
+
114
+ fig, ax = plt.subplots(figsize=(10, 6))
115
+ bars = ax.barh(breeds, confidences, color=['#1f77b4', '#ff7f0e', '#2ca02c'])
116
+
117
+ ax.set_xlabel('Confidence (%)')
118
+ ax.set_title('Top 3 Breed Predictions')
119
+ ax.set_xlim(0, 100)
120
+
121
+ # Add percentage labels on bars
122
+ for i, (bar, conf) in enumerate(zip(bars, confidences)):
123
+ ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
124
+ f'{conf:.1f}%', va='center')
125
+
126
+ plt.tight_layout()
127
+ return fig
128
+
129
+ # Main app
130
+ def main():
131
+ st.title("πŸ• Dog Breed Classifier")
132
+ st.write("Upload an image of a dog to identify its breed!")
133
+
134
+ # Initialize model and device
135
+ device, model, label_names, transform = setup_device_and_model()
136
+
137
+ # Download and load the model automatically
138
+ model_loaded = download_and_load_model(model, device)
139
+
140
+ if model_loaded:
141
+ st.success("βœ… Model loaded successfully from Hugging Face!")
142
+ else:
143
+ st.error("❌ Failed to load model. Please refresh the page and try again.")
144
+ return
145
+
146
+ # Main content
147
+ col1, col2 = st.columns([1, 1])
148
+
149
+ with col1:
150
+ st.header("Upload Image")
151
+ uploaded_file = st.file_uploader(
152
+ "Choose an image file",
153
+ type=['jpg', 'jpeg', 'png'],
154
+ help="Upload a clear image of a dog for best results"
155
+ )
156
+
157
+ if uploaded_file is not None:
158
+ # Display uploaded image
159
+ image = Image.open(uploaded_file)
160
+ st.image(image, caption="Uploaded Image", use_container_width=True)
161
+
162
+ # Show image details
163
+ st.write(f"**Image Size:** {image.size}")
164
+ st.write(f"**Image Mode:** {image.mode}")
165
+
166
+ with col2:
167
+ st.header("Predictions")
168
+
169
+ if uploaded_file is not None:
170
+ try:
171
+ with st.spinner("Analyzing image..."):
172
+ # Make predictions
173
+ predictions = predict_image(image, model, transform, label_names, device)
174
+
175
+ # Display results
176
+ st.success("πŸŽ‰ Analysis Complete!")
177
+
178
+ # Show top prediction prominently
179
+ top_breed = predictions[0]['breed'].replace('-', ' ').title()
180
+ top_confidence = float(predictions[0]['confidence']) # Convert to Python float
181
+
182
+ st.markdown(f"""
183
+ <div style="background-color: #f0f8ff; padding: 20px; border-radius: 10px; border-left: 5px solid #1f77b4;">
184
+ <h3 style="color: #1f77b4; margin: 0;">πŸ† Most Likely Breed</h3>
185
+ <h2 style="margin: 5px 0;">{top_breed}</h2>
186
+ <h4 style="color: #666; margin: 0;">Confidence: {top_confidence:.1f}%</h4>
187
+ </div>
188
+ """, unsafe_allow_html=True)
189
+
190
+ st.write("") # Add some space
191
+
192
+ # Show all predictions
193
+ st.subheader("All Predictions:")
194
+ for i, pred in enumerate(predictions):
195
+ breed = pred['breed'].replace('-', ' ').title()
196
+ confidence = float(pred['confidence']) # Convert numpy float32 to Python float
197
+
198
+ # Create progress bar
199
+ st.write(f"**{i+1}. {breed}**")
200
+ st.progress(confidence/100)
201
+ st.write(f"Confidence: {confidence:.2f}%")
202
+ st.write("")
203
+
204
+ # Show chart
205
+ st.subheader("Prediction Chart:")
206
+ fig = create_prediction_chart(predictions)
207
+ st.pyplot(fig)
208
+
209
+ except Exception as e:
210
+ st.error(f"Error during prediction: {str(e)}")
211
+
212
+ else:
213
+ st.info("πŸ“€ Please upload an image to start classification.")
214
+
215
+ # Information section
216
+ with st.expander("ℹ️ About this App"):
217
+ st.write("""
218
+ This app uses a ConvNeXt-Base model trained to classify dog breeds among:
219
+ - Beagle
220
+ - Bulldog
221
+ - Dalmatian
222
+ - German Shepherd
223
+ - Husky
224
+ - Poodle
225
+ - Rottweiler
226
+
227
+ **How to use:**
228
+ 1. The model is automatically loaded from Hugging Face
229
+ 2. Upload a clear image of a dog
230
+ 3. View the top 3 breed predictions with confidence scores
231
+
232
+ **Tips for better results:**
233
+ - Use high-quality, well-lit images
234
+ - Ensure the dog is clearly visible in the image
235
+ - Avoid images with multiple dogs
236
+ """)
237
+
238
+ # Technical details
239
+ with st.expander("πŸ”§ Technical Details"):
240
+ st.write(f"""
241
+ - **Device:** {device}
242
+ - **Model:** ConvNeXt-Base
243
+ - **Model Source:** Hugging Face (Alamgirapi/dog-breed-convnext-classifier)
244
+ - **Input Size:** 224x224 pixels
245
+ - **Classes:** {len(label_names)}
246
+ - **Framework:** PyTorch + Streamlit
247
+ """)
248
+
249
+ if __name__ == "__main__":
250
+ main()