ksj47 commited on
Commit
e440ee7
Β·
verified Β·
1 Parent(s): a0ef0f4

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +617 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import streamlit as st
5
+ import numpy as np
6
+ import torchvision.transforms as transforms
7
+ from PIL import Image, ImageDraw
8
+ import os
9
+ import base64
10
+ from io import BytesIO
11
+
12
+ # Define the neural network model - matching your trained model with 3 input channels
13
+ class Net(nn.Module):
14
+ def __init__(self):
15
+ super(Net, self).__init__()
16
+ # 3 input image channels (RGB), 6 output channels, 5x5 square convolution kernel
17
+ self.conv1 = nn.Conv2d(3, 6, 5)
18
+ self.conv2 = nn.Conv2d(6, 16, 5)
19
+ # an affine operation: y = Wx + b
20
+ self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5*5 from image dimension
21
+ self.fc2 = nn.Linear(120, 84)
22
+ self.fc3 = nn.Linear(84, 10)
23
+
24
+ def forward(self, x):
25
+ # Convolution layer C1: 3 input image channels, 6 output channels,
26
+ # 5x5 square convolution, it uses RELU activation function, and
27
+ # outputs a Tensor with size (N, 6, 28, 28), where N is the size of the batch
28
+ c1 = F.relu(self.conv1(x))
29
+ # Subsampling layer S2: 2x2 grid, purely functional,
30
+ # this layer does not have any parameter, and outputs a (N, 6, 14, 14) Tensor
31
+ s2 = F.max_pool2d(c1, (2, 2))
32
+ # Convolution layer C3: 6 input channels, 16 output channels,
33
+ # 5x5 square convolution, it uses RELU activation function, and
34
+ # outputs a (N, 16, 10, 10) Tensor
35
+ c3 = F.relu(self.conv2(s2))
36
+ # Subsampling layer S4: 2x2 grid, purely functional,
37
+ # this layer does not have any parameter, and outputs a (N, 16, 5, 5) Tensor
38
+ s4 = F.max_pool2d(c3, 2)
39
+ # Flatten operation: purely functional, outputs a (N, 400) Tensor
40
+ s4 = torch.flatten(s4, 1)
41
+ # Fully connected layer F5: (N, 400) Tensor input,
42
+ # and outputs a (N, 120) Tensor, it uses RELU activation function
43
+ f5 = F.relu(self.fc1(s4))
44
+ # Fully connected layer F6: (N, 120) Tensor input,
45
+ # and outputs a (N, 84) Tensor, it uses RELU activation function
46
+ f6 = F.relu(self.fc2(f5))
47
+ # Gaussian layer OUTPUT: (N, 84) Tensor input, and
48
+ # outputs a (N, 10) Tensor
49
+ output = self.fc3(f6)
50
+ return output
51
+
52
+ # Initialize the model
53
+ model = Net()
54
+
55
+ # Load the trained model weights
56
+ def load_model():
57
+ model_path = "model.pth" # Update this path to where your model is stored
58
+ if os.path.exists(model_path):
59
+ try:
60
+ # Load the trained model weights
61
+ # Handle different PyTorch versions
62
+ try:
63
+ # For PyTorch 2.6+, we need to set weights_only=False for compatibility
64
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=False))
65
+ except TypeError:
66
+ # For older PyTorch versions that don't support weights_only parameter
67
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
68
+ print("Loaded trained model weights")
69
+ return True
70
+ except Exception as e:
71
+ print(f"Error loading model: {e}")
72
+ return False
73
+ else:
74
+ print("No trained model found at", model_path)
75
+ # Initialize with random weights for demonstration
76
+ for m in model.modules():
77
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
78
+ nn.init.xavier_uniform_(m.weight)
79
+ if m.bias is not None:
80
+ nn.init.constant_(m.bias, 0)
81
+ return False
82
+
83
+ # Preprocessing function for input images - now handles RGB images
84
+ def preprocess_image(image):
85
+ # Resize to 32x32 (expected input size for the network)
86
+ transform = transforms.Compose([
87
+ transforms.Resize((32, 32)),
88
+ transforms.ToTensor(),
89
+ ])
90
+
91
+ image_tensor = transform(image)
92
+ # Add batch dimension (1, 3, 32, 32)
93
+ image_tensor = image_tensor.unsqueeze(0)
94
+ return image_tensor
95
+
96
+ # Prediction function - matches the PyTorch tutorial exactly
97
+ def predict(image):
98
+ if image is None:
99
+ return {f"Class {i}": 0 for i in range(10)}
100
+
101
+ # Preprocess the image
102
+ input_tensor = preprocess_image(image)
103
+
104
+ # Make prediction - exactly as shown in the PyTorch tutorial
105
+ model.eval()
106
+ with torch.no_grad():
107
+ output = model(input_tensor)
108
+ # Apply softmax to get probabilities
109
+ probabilities = F.softmax(output, dim=1)
110
+ probabilities = probabilities.numpy()[0]
111
+
112
+ # Create labels for CIFAR-10 classes
113
+ cifar10_classes = ["Airplane", "Automobile", "Bird", "Cat", "Deer", "Dog", "Frog", "Horse", "Ship", "Truck"]
114
+
115
+ # Return as a dictionary
116
+ return {label: float(prob) for label, prob in zip(cifar10_classes, probabilities)}
117
+
118
+ # Create example images representing CIFAR-10 classes
119
+ def create_example_images():
120
+ examples = []
121
+ example_names = []
122
+
123
+ # CIFAR-10 class names
124
+ cifar10_classes = ["Airplane", "Automobile", "Bird", "Cat", "Deer", "Dog", "Frog", "Horse", "Ship", "Truck"]
125
+
126
+ # Create simple representations of CIFAR-10 classes
127
+ for i, class_name in enumerate(cifar10_classes):
128
+ # Create a 64x64 RGB image for better quality
129
+ img = Image.new('RGB', (64, 64), color=(255, 255, 255)) # White background
130
+ draw = ImageDraw.Draw(img)
131
+
132
+ # Draw simple representations of each class
133
+ if i == 0: # Airplane
134
+ # Draw a simple airplane shape
135
+ draw.polygon([(32, 10), (20, 30), (44, 30)], fill=(169, 169, 169)) # Main body
136
+ draw.rectangle([25, 30, 39, 35], fill=(105, 105, 105)) # Wings
137
+ draw.rectangle([30, 35, 34, 45], fill=(128, 128, 128)) # Tail
138
+ elif i == 1: # Automobile
139
+ # Draw a simple car shape
140
+ draw.rectangle([15, 30, 49, 45], fill=(0, 0, 255)) # Body
141
+ draw.ellipse([20, 40, 30, 50], fill=(0, 0, 0)) # Wheels
142
+ draw.ellipse([34, 40, 44, 50], fill=(0, 0, 0))
143
+ draw.rectangle([25, 20, 39, 30], fill=(0, 0, 255)) # Top
144
+ elif i == 2: # Bird
145
+ # Draw a simple bird shape
146
+ draw.ellipse([25, 25, 39, 39], fill=(255, 165, 0)) # Body
147
+ draw.polygon([(32, 15), (25, 25), (39, 25)], fill=(255, 140, 0)) # Head
148
+ draw.line([20, 30, 10, 20], fill=(255, 165, 0), width=3) # Wing
149
+ draw.line([44, 30, 54, 20], fill=(255, 165, 0), width=3) # Wing
150
+ elif i == 3: # Cat
151
+ # Draw a simple cat shape
152
+ draw.ellipse([25, 25, 39, 39], fill=(128, 128, 128)) # Body
153
+ draw.ellipse([30, 20, 40, 30], fill=(169, 169, 169)) # Head
154
+ draw.polygon([(35, 22), (33, 27), (37, 27)], fill=(0, 0, 0)) # Ear
155
+ draw.ellipse([32, 28, 34, 30], fill=(0, 0, 0)) # Eye
156
+ elif i == 4: # Deer
157
+ # Draw a simple deer shape
158
+ draw.ellipse([25, 30, 39, 44], fill=(139, 69, 19)) # Body
159
+ draw.ellipse([30, 25, 40, 35], fill=(160, 82, 45)) # Head
160
+ draw.line([35, 15, 40, 25], fill=(139, 69, 19), width=3) # Antler
161
+ draw.line([20, 35, 10, 30], fill=(139, 69, 19), width=2) # Leg
162
+ elif i == 5: # Dog
163
+ # Draw a simple dog shape
164
+ draw.ellipse([25, 30, 39, 44], fill=(139, 69, 19)) # Body
165
+ draw.ellipse([30, 25, 40, 35], fill=(160, 82, 45)) # Head
166
+ draw.ellipse([32, 28, 34, 30], fill=(0, 0, 0)) # Eye
167
+ draw.ellipse([36, 32, 38, 34], fill=(0, 0, 0)) # Nose
168
+ elif i == 6: # Frog
169
+ # Draw a simple frog shape
170
+ draw.ellipse([25, 30, 39, 44], fill=(34, 139, 34)) # Body
171
+ draw.ellipse([30, 25, 40, 35], fill=(0, 100, 0)) # Head
172
+ draw.ellipse([27, 32, 29, 34], fill=(0, 0, 0)) # Eye
173
+ draw.ellipse([35, 32, 37, 34], fill=(0, 0, 0)) # Eye
174
+ elif i == 7: # Horse
175
+ # Draw a simple horse shape
176
+ draw.ellipse([25, 30, 39, 44], fill=(169, 169, 169)) # Body
177
+ draw.ellipse([35, 20, 45, 30], fill=(128, 128, 128)) # Head
178
+ draw.line([40, 25, 50, 15], fill=(105, 105, 105), width=3) # Mane
179
+ elif i == 8: # Ship
180
+ # Draw a simple ship shape
181
+ draw.polygon([(20, 35), (44, 35), (38, 45), (26, 45)], fill=(139, 69, 19)) # Hull
182
+ draw.rectangle([30, 20, 34, 35], fill=(169, 169, 169)) # Mast
183
+ draw.polygon([(30, 20), (32, 15), (34, 20)], fill=(255, 255, 255)) # Sail
184
+ elif i == 9: # Truck
185
+ # Draw a simple truck shape
186
+ draw.rectangle([15, 25, 49, 45], fill=(255, 0, 0)) # Cab
187
+ draw.rectangle([25, 15, 45, 25], fill=(255, 0, 0)) # Load area
188
+ draw.ellipse([20, 40, 30, 50], fill=(0, 0, 0)) # Wheels
189
+ draw.ellipse([34, 40, 44, 50], fill=(0, 0, 0))
190
+
191
+ examples.append(img)
192
+ example_names.append(class_name)
193
+
194
+ return examples, example_names
195
+
196
+ # Function to convert PIL Image to base64 for display
197
+ def image_to_base64(image):
198
+ buffered = BytesIO()
199
+ image.save(buffered, format="PNG")
200
+ img_str = base64.b64encode(buffered.getvalue()).decode()
201
+ return img_str
202
+
203
+ # Initialize the model
204
+ model_loaded = load_model()
205
+
206
+ # Create example images
207
+ examples, example_names = create_example_images()
208
+
209
+ # Streamlit app
210
+ st.set_page_config(
211
+ page_title="CIFAR-10 Image Classifier",
212
+ page_icon="πŸš€",
213
+ layout="wide"
214
+ )
215
+
216
+ # Custom CSS with cleaner design
217
+ st.markdown("""
218
+ <style>
219
+ /* Import Google Fonts */
220
+ @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;500;600;700&display=swap');
221
+
222
+ /* Base styles */
223
+ * {
224
+ font-family: 'Poppins', sans-serif;
225
+ }
226
+
227
+ /* Clean background */
228
+ body {
229
+ background: linear-gradient(135deg, #1a2a6c, #2c3e50);
230
+ color: white;
231
+ }
232
+
233
+ /* Main container with clean glassmorphism effect */
234
+ .main-container {
235
+ background: rgba(255, 255, 255, 0.05);
236
+ backdrop-filter: blur(10px);
237
+ border-radius: 20px;
238
+ border: 1px solid rgba(255, 255, 255, 0.1);
239
+ box-shadow: 0 8px 32px 0 rgba(0, 0, 0, 0.3);
240
+ padding: 2rem;
241
+ margin: 2rem auto;
242
+ max-width: 1200px;
243
+ }
244
+
245
+ /* Title with clean gradient */
246
+ .title {
247
+ background: linear-gradient(90deg, #4facfe 0%, #00f2fe 100%);
248
+ -webkit-background-clip: text;
249
+ -webkit-text-fill-color: transparent;
250
+ background-clip: text;
251
+ font-weight: 800;
252
+ font-size: 2.5rem;
253
+ text-align: center;
254
+ margin-bottom: 0.5rem;
255
+ }
256
+
257
+ /* Subtitle styling */
258
+ .subtitle {
259
+ text-align: center;
260
+ color: #a0d2ff;
261
+ font-size: 1.1rem;
262
+ margin-bottom: 2rem;
263
+ opacity: 0.9;
264
+ }
265
+
266
+ /* Card styling */
267
+ .card {
268
+ background: rgba(255, 255, 255, 0.05);
269
+ border-radius: 15px;
270
+ padding: 1.5rem;
271
+ margin-bottom: 1.5rem;
272
+ border: 1px solid rgba(255, 255, 255, 0.1);
273
+ transition: all 0.3s ease;
274
+ box-shadow: 0 4px 20px rgba(0, 0, 0, 0.15);
275
+ }
276
+
277
+ .card:hover {
278
+ background: rgba(255, 255, 255, 0.08);
279
+ box-shadow: 0 6px 25px rgba(0, 0, 0, 0.25);
280
+ transform: translateY(-3px);
281
+ }
282
+
283
+ /* Section headers */
284
+ .section-header {
285
+ color: #4facfe;
286
+ border-bottom: 2px solid #00f2fe;
287
+ padding-bottom: 0.5rem;
288
+ margin-bottom: 1rem;
289
+ font-weight: 600;
290
+ font-size: 1.3rem;
291
+ }
292
+
293
+ /* Button styling */
294
+ .stButton > button {
295
+ background: linear-gradient(90deg, #4facfe 0%, #00f2fe 100%);
296
+ color: white;
297
+ border: none;
298
+ border-radius: 10px;
299
+ padding: 0.7rem 1.2rem;
300
+ font-weight: 600;
301
+ transition: all 0.3s ease;
302
+ box-shadow: 0 4px 15px rgba(79, 172, 254, 0.3);
303
+ width: 100%;
304
+ }
305
+
306
+ .stButton > button:hover {
307
+ transform: translateY(-2px);
308
+ box-shadow: 0 6px 20px rgba(79, 172, 254, 0.5);
309
+ }
310
+
311
+ .stButton > button:active {
312
+ transform: translateY(1px);
313
+ }
314
+
315
+ /* File uploader styling */
316
+ .stFileUploader > div {
317
+ background: rgba(255, 255, 255, 0.05);
318
+ border-radius: 15px;
319
+ border: 1px dashed rgba(255, 255, 255, 0.3);
320
+ padding: 1.5rem;
321
+ text-align: center;
322
+ }
323
+
324
+ /* Progress bar styling */
325
+ .stProgress > div > div {
326
+ background: linear-gradient(90deg, #4facfe 0%, #00f2fe 100%);
327
+ }
328
+
329
+ /* Result display */
330
+ .result-container {
331
+ display: flex;
332
+ flex-wrap: wrap;
333
+ gap: 0.8rem;
334
+ justify-content: center;
335
+ }
336
+
337
+ .result-item {
338
+ background: rgba(255, 255, 255, 0.08);
339
+ border-radius: 12px;
340
+ padding: 1rem;
341
+ text-align: center;
342
+ min-width: 110px;
343
+ transition: all 0.3s ease;
344
+ border: 1px solid rgba(255, 255, 255, 0.1);
345
+ }
346
+
347
+ .result-item:hover {
348
+ background: rgba(79, 172, 254, 0.2);
349
+ transform: translateY(-3px);
350
+ box-shadow: 0 5px 15px rgba(0, 0, 0, 0.2);
351
+ }
352
+
353
+ .result-label {
354
+ font-weight: 600;
355
+ margin-bottom: 0.4rem;
356
+ color: #4facfe;
357
+ font-size: 0.9rem;
358
+ }
359
+
360
+ .result-value {
361
+ font-size: 1.1rem;
362
+ font-weight: 700;
363
+ color: white;
364
+ }
365
+
366
+ /* Example images grid */
367
+ .examples-grid {
368
+ display: grid;
369
+ grid-template-columns: repeat(auto-fill, minmax(90px, 1fr));
370
+ gap: 0.8rem;
371
+ margin-top: 1rem;
372
+ }
373
+
374
+ .example-item {
375
+ cursor: pointer;
376
+ border-radius: 10px;
377
+ overflow: hidden;
378
+ transition: all 0.3s ease;
379
+ border: 2px solid transparent;
380
+ background: rgba(255, 255, 255, 0.05);
381
+ }
382
+
383
+ .example-item:hover {
384
+ transform: scale(1.05);
385
+ border-color: #4facfe;
386
+ box-shadow: 0 5px 15px rgba(0, 0, 0, 0.3);
387
+ background: rgba(79, 172, 254, 0.1);
388
+ }
389
+
390
+ .example-item img {
391
+ border-radius: 8px;
392
+ }
393
+
394
+ .example-name {
395
+ text-align: center;
396
+ margin-top: 5px;
397
+ font-size: 0.75rem;
398
+ color: #a0d2ff;
399
+ }
400
+
401
+ /* Footer */
402
+ .footer {
403
+ text-align: center;
404
+ padding: 1.5rem;
405
+ color: rgba(255, 255, 255, 0.6);
406
+ font-size: 0.9rem;
407
+ }
408
+
409
+ /* Responsive design */
410
+ @media (max-width: 768px) {
411
+ .main-container {
412
+ padding: 1rem;
413
+ margin: 1rem;
414
+ }
415
+
416
+ .title {
417
+ font-size: 2rem;
418
+ }
419
+
420
+ .card {
421
+ padding: 1rem;
422
+ }
423
+
424
+ .result-item {
425
+ min-width: 90px;
426
+ padding: 0.7rem;
427
+ }
428
+
429
+ .examples-grid {
430
+ grid-template-columns: repeat(auto-fill, minmax(70px, 1fr));
431
+ }
432
+ }
433
+ </style>
434
+ """, unsafe_allow_html=True)
435
+
436
+ # Main app content
437
+ st.markdown('<div class="main-container">', unsafe_allow_html=True)
438
+
439
+ st.markdown('<h1 class="title">πŸš€ CIFAR-10 Image Classifier</h1>', unsafe_allow_html=True)
440
+ st.markdown('<p class="subtitle">Convolutional Neural Network for Object Recognition</p>', unsafe_allow_html=True)
441
+
442
+ # Show model loading status
443
+ if model_loaded:
444
+ st.success("βœ… Model successfully loaded")
445
+ else:
446
+ st.warning("⚠️ Model not found or error loading. Using random weights for demonstration.")
447
+
448
+ # Create tabs for better organization
449
+ tab1, tab2, tab3 = st.tabs(["πŸ” Classify", "πŸ–ΌοΈ Examples", "πŸ“š Information"])
450
+
451
+ with tab1:
452
+ # Create two columns for input and output
453
+ col1, col2 = st.columns(2)
454
+
455
+ with col1:
456
+ st.markdown('<div class="card">', unsafe_allow_html=True)
457
+ st.markdown('<h2 class="section-header">πŸ“€ Input</h2>', unsafe_allow_html=True)
458
+
459
+ # File uploader
460
+ uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
461
+
462
+ # Display image
463
+ image = None
464
+ if uploaded_file is not None:
465
+ image = Image.open(uploaded_file).convert('RGB')
466
+ st.image(image, caption="Uploaded Image", use_container_width=True)
467
+
468
+ # Classify button
469
+ if st.button("Classify Image"):
470
+ if image is not None:
471
+ st.session_state.predictions = predict(image)
472
+ else:
473
+ st.warning("Please upload an image first")
474
+
475
+ # Clear button
476
+ if st.button("Clear"):
477
+ st.session_state.predictions = None
478
+ st.experimental_rerun()
479
+
480
+ st.markdown('</div>', unsafe_allow_html=True)
481
+
482
+ # Model architecture section
483
+ st.markdown('<div class="card">', unsafe_allow_html=True)
484
+ st.markdown('<h2 class="section-header">🎯 Model Architecture</h2>', unsafe_allow_html=True)
485
+ st.code("""
486
+ Input β†’ Conv2D(3Γ—32Γ—32) β†’ ReLU β†’ MaxPool2D
487
+ β†’ Conv2D β†’ ReLU β†’ MaxPool2D
488
+ β†’ Flatten β†’ Linear β†’ ReLU
489
+ β†’ Linear β†’ ReLU β†’ Linear(10)
490
+ β†’ Output
491
+ """, language="text")
492
+ st.markdown('</div>', unsafe_allow_html=True)
493
+
494
+ with col2:
495
+ st.markdown('<div class="card">', unsafe_allow_html=True)
496
+ st.markdown('<h2 class="section-header">πŸ“Š Classification Results</h2>', unsafe_allow_html=True)
497
+
498
+ # Display results
499
+ if "predictions" in st.session_state and st.session_state.predictions:
500
+ predictions = st.session_state.predictions
501
+ # Sort predictions by probability
502
+ sorted_predictions = sorted(predictions.items(), key=lambda x: x[1], reverse=True)
503
+
504
+ # Display top 5 predictions with animated bars
505
+ st.markdown('<div class="result-container">', unsafe_allow_html=True)
506
+ for label, prob in sorted_predictions[:5]:
507
+ st.markdown(f'''
508
+ <div class="result-item">
509
+ <div class="result-label">{label}</div>
510
+ <div class="result-value">{prob:.2f}</div>
511
+ </div>
512
+ ''', unsafe_allow_html=True)
513
+ st.markdown('</div>', unsafe_allow_html=True)
514
+
515
+ # Display all probabilities in a more detailed way
516
+ st.subheader("All Class Probabilities")
517
+ for label, prob in sorted_predictions:
518
+ st.progress(prob)
519
+ st.write(f"{label}: {prob:.4f}")
520
+ else:
521
+ st.info("Upload an image and click 'Classify Image' to see results")
522
+
523
+ st.markdown('</div>', unsafe_allow_html=True)
524
+
525
+ # Instructions section
526
+ st.markdown('<div class="card">', unsafe_allow_html=True)
527
+ st.markdown('<h2 class="section-header">ℹ️ Instructions</h2>', unsafe_allow_html=True)
528
+ st.markdown("""
529
+ 1. Upload an image using the file uploader
530
+ 2. The image will be automatically resized to 32Γ—32 pixels
531
+ 3. Click "Classify Image" to get predictions
532
+ 4. Results show probabilities for 10 CIFAR-10 classes
533
+ """)
534
+ st.markdown('</div>', unsafe_allow_html=True)
535
+
536
+ with tab2:
537
+ # Example images section
538
+ st.markdown('<div class="card">', unsafe_allow_html=True)
539
+ st.markdown('<h2 class="section-header">πŸ–ΌοΈ Example Images</h2>', unsafe_allow_html=True)
540
+ st.markdown("Click on any example image to classify it:")
541
+
542
+ # Create example grid
543
+ st.markdown('<div class="examples-grid">', unsafe_allow_html=True)
544
+ for i, (example_img, example_name) in enumerate(zip(examples, example_names)):
545
+ # Convert PIL image to base64
546
+ img_base64 = image_to_base64(example_img)
547
+
548
+ # Create clickable image
549
+ if st.button(f"example_{i}", key=f"btn_{i}"):
550
+ st.session_state.predictions = predict(example_img)
551
+ st.experimental_rerun()
552
+
553
+ st.markdown(f'''
554
+ <div class="example-item">
555
+ <img src="data:image/png;base64,{img_base64}" width="100" height="100" alt="{example_name}">
556
+ <div class="example-name">{example_name}</div>
557
+ </div>
558
+ ''', unsafe_allow_html=True)
559
+ st.markdown('</div>', unsafe_allow_html=True)
560
+ st.markdown('</div>', unsafe_allow_html=True)
561
+
562
+ with tab3:
563
+ # Information sections
564
+ st.markdown('<div class="card">', unsafe_allow_html=True)
565
+ st.markdown('<h2 class="section-header">πŸ§ͺ Testing Different Image Qualities</h2>', unsafe_allow_html=True)
566
+ st.markdown("""
567
+ This model is robust to various image conditions:
568
+ - **Resolution**: Works with images of any resolution (automatically resized to 32Γ—32)
569
+ - **Contrast**: Handles both high and low contrast images
570
+ - **Noise**: Can tolerate some image noise
571
+ - **Rotation**: Some tolerance to slight rotations
572
+ - **Scale**: Works with objects of different sizes within the image
573
+
574
+ For best results:
575
+ 1. Center the object in the image
576
+ 2. Use clear contrast between the object and background
577
+ 3. Avoid excessive noise or artifacts
578
+ 4. Fill most of the image area with the object
579
+ """)
580
+ st.markdown('</div>', unsafe_allow_html=True)
581
+
582
+ st.markdown('<div class="card">', unsafe_allow_html=True)
583
+ st.markdown('<h2 class="section-header">🎯 CIFAR-10 Classes</h2>', unsafe_allow_html=True)
584
+ classes_info = """
585
+ 1. **Airplane** - Aircraft flying in the sky
586
+ 2. **Automobile** - Cars and vehicles on the road
587
+ 3. **Bird** - Flying or perched birds
588
+ 4. **Cat** - Domestic cats and felines
589
+ 5. **Deer** - Wild deer and similar animals
590
+ 6. **Dog** - Domestic dogs and canines
591
+ 7. **Frog** - Amphibians like frogs
592
+ 8. **Horse** - Horses and similar animals
593
+ 9. **Ship** - Boats and ships on water
594
+ 10. **Truck** - Trucks and heavy vehicles
595
+ """
596
+ st.markdown(classes_info)
597
+ st.markdown('</div>', unsafe_allow_html=True)
598
+
599
+ # Model architecture section
600
+ st.markdown('<div class="card">', unsafe_allow_html=True)
601
+ st.markdown('<h2 class="section-header">🧠 Model Details</h2>', unsafe_allow_html=True)
602
+ st.markdown("""
603
+ This convolutional neural network follows the PyTorch CIFAR-10 tutorial architecture:
604
+ - **Input Layer**: 3Γ—32Γ—32 RGB images
605
+ - **Convolutional Layers**: 2 layers with ReLU activation
606
+ - **Pooling Layers**: 2 max-pooling layers
607
+ - **Fully Connected Layers**: 3 linear layers
608
+ - **Output Layer**: 10 classes with softmax activation
609
+ """)
610
+ st.markdown('</div>', unsafe_allow_html=True)
611
+
612
+ # Footer
613
+ st.markdown('<div class="footer">', unsafe_allow_html=True)
614
+ st.markdown("Built with ❀️ using Streamlit and PyTorch | Deployable to Hugging Face Spaces")
615
+ st.markdown('</div>', unsafe_allow_html=True)
616
+
617
+ st.markdown('</div>', unsafe_allow_html=True)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch>=1.7.0
2
+ torchvision>=0.8.0
3
+ streamlit>=1.25.0
4
+ pillow>=8.0.0
5
+ numpy>=1.19.0