import torch
import torch.nn as nn
import torch.nn.functional as F
import streamlit as st
import numpy as np
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
import os
import base64
from io import BytesIO


class Net(nn.Module):
    """LeNet-style CNN for 10-class classification of 3x32x32 RGB images.

    Architecture (PyTorch CIFAR-10 tutorial):
    Conv(3->6, 5x5) -> ReLU -> MaxPool(2x2) -> Conv(6->16, 5x5) -> ReLU
    -> MaxPool(2x2) -> Flatten -> FC(400->120) -> ReLU -> FC(120->84)
    -> ReLU -> FC(84->10).
    """

    def __init__(self):
        super().__init__()
        # 3 input image channels (RGB), 6 output channels, 5x5 kernel.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Affine head: 16 * 5 * 5 = 400 features after the second pool.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw class logits of shape (N, 10) for x of shape (N, 3, 32, 32)."""
        # C1 + S2: (N, 3, 32, 32) -> conv/ReLU (N, 6, 28, 28) -> pool (N, 6, 14, 14)
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # C3 + S4: -> conv/ReLU (N, 16, 10, 10) -> pool (N, 16, 5, 5)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten to (N, 400) and run the fully connected classifier head.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Final affine layer emits logits; softmax is applied by the caller.
        return self.fc3(x)


# Initialize the model
model = Net()

# Load the trained model weights
def load_model():
    """Load trained weights from ``model.pth`` into the global ``model``.

    Returns:
        bool: True when weights were loaded from disk. False when the file
        is missing (the model is then re-initialized with Xavier weights so
        the demo still runs) or when loading failed.
    """
    model_path = "model.pth"  # Update this path to where your model is stored
    if os.path.exists(model_path):
        try:
            # Handle different PyTorch versions.
            try:
                # PyTorch 2.6+ defaults torch.load(weights_only=True); pass
                # False for compatibility with checkpoints saved by older code.
                # NOTE(review): weights_only=False unpickles arbitrary objects
                # -- only load checkpoint files you trust.
                model.load_state_dict(
                    torch.load(model_path, map_location=torch.device('cpu'),
                               weights_only=False))
            except TypeError:
                # Older PyTorch versions without the weights_only parameter.
                model.load_state_dict(
                    torch.load(model_path, map_location=torch.device('cpu')))
            print("Loaded trained model weights")
            return True
        except Exception as e:
            print(f"Error loading model: {e}")
            return False
    print("No trained model found at", model_path)
    # Initialize with random weights for demonstration.
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
    return False


# CIFAR-10 class labels, shared by predict() and create_example_images()
# (previously duplicated inside both functions).
CIFAR10_CLASSES = ["Airplane", "Automobile", "Bird", "Cat", "Deer",
                   "Dog", "Frog", "Horse", "Ship", "Truck"]


def preprocess_image(image):
    """Resize a PIL RGB image to 32x32 and return a (1, 3, 32, 32) float tensor."""
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])
    # Add the batch dimension: (3, 32, 32) -> (1, 3, 32, 32).
    return transform(image).unsqueeze(0)


def predict(image):
    """Classify a PIL image and return a {class name: probability} mapping.

    Args:
        image: a PIL RGB image, or None.

    Returns:
        dict[str, float]: softmax probability per CIFAR-10 class. When
        ``image`` is None, the same class-name keys are returned with 0.0
        values so the UI always receives a consistent mapping.
    """
    if image is None:
        # Bug fix: this branch previously returned "Class 0".."Class 9"
        # keys, inconsistent with the class-name keys used below.
        return {label: 0.0 for label in CIFAR10_CLASSES}

    input_tensor = preprocess_image(image)

    # Inference exactly as shown in the PyTorch tutorial.
    model.eval()
    with torch.no_grad():
        output = model(input_tensor)
        # Softmax converts logits to probabilities.
        probabilities = F.softmax(output, dim=1)

    probabilities = probabilities.numpy()[0]
    return {label: float(prob)
            for label, prob in zip(CIFAR10_CLASSES, probabilities)}


def create_example_images():
    """Create simple 64x64 drawings representing each CIFAR-10 class.

    Returns:
        tuple[list, list]: the PIL images and their matching class names.
    """
    examples = []
    example_names = []

    for i, class_name in enumerate(CIFAR10_CLASSES):
        # 64x64 RGB canvas with a white background for better quality.
        img = Image.new('RGB', (64, 64), color=(255, 255, 255))
        draw = ImageDraw.Draw(img)

        # Draw a crude, recognizable sketch for each class.
        if i == 0:  # Airplane
            draw.polygon([(32, 10), (20, 30), (44, 30)], fill=(169, 169, 169))  # Main body
            draw.rectangle([25, 30, 39, 35], fill=(105, 105, 105))  # Wings
            draw.rectangle([30, 35, 34, 45], fill=(128, 128, 128))  # Tail
        elif i == 1:  # Automobile
            draw.rectangle([15, 30, 49, 45], fill=(0, 0, 255))  # Body
            draw.ellipse([20, 40, 30, 50], fill=(0, 0, 0))  # Wheels
            draw.ellipse([34, 40, 44, 50], fill=(0, 0, 0))
            draw.rectangle([25, 20, 39, 30], fill=(0, 0, 255))  # Top
        elif i == 2:  # Bird
            draw.ellipse([25, 25, 39, 39], fill=(255, 165, 0))  # Body
            draw.polygon([(32, 15), (25, 25), (39, 25)], fill=(255, 140, 0))  # Head
            draw.line([20, 30, 10, 20], fill=(255, 165, 0), width=3)  # Wing
            draw.line([44, 30, 54, 20], fill=(255, 165, 0), width=3)  # Wing
        elif i == 3:  # Cat
            draw.ellipse([25, 25, 39, 39], fill=(128, 128, 128))  # Body
            draw.ellipse([30, 20, 40, 30], fill=(169, 169, 169))  # Head
            draw.polygon([(35, 22), (33, 27), (37, 27)], fill=(0, 0, 0))  # Ear
            draw.ellipse([32, 28, 34, 30], fill=(0, 0, 0))  # Eye
        elif i == 4:  # Deer
            draw.ellipse([25, 30, 39, 44], fill=(139, 69, 19))  # Body
            draw.ellipse([30, 25, 40, 35], fill=(160, 82, 45))  # Head
            draw.line([35, 15, 40, 25], fill=(139, 69, 19), width=3)  # Antler
            draw.line([20, 35, 10, 30], fill=(139, 69, 19), width=2)  # Leg
        elif i == 5:  # Dog
            draw.ellipse([25, 30, 39, 44], fill=(139, 69, 19))  # Body
            draw.ellipse([30, 25, 40, 35], fill=(160, 82, 45))  # Head
            draw.ellipse([32, 28, 34, 30], fill=(0, 0, 0))  # Eye
            draw.ellipse([36, 32, 38, 34], fill=(0, 0, 0))  # Nose
        elif i == 6:  # Frog
            draw.ellipse([25, 30, 39, 44], fill=(34, 139, 34))  # Body
            draw.ellipse([30, 25, 40, 35], fill=(0, 100, 0))  # Head
            draw.ellipse([27, 32, 29, 34], fill=(0, 0, 0))  # Eye
            draw.ellipse([35, 32, 37, 34], fill=(0, 0, 0))  # Eye
        elif i == 7:  # Horse
            draw.ellipse([25, 30, 39, 44], fill=(169, 169, 169))  # Body
            draw.ellipse([35, 20, 45, 30], fill=(128, 128, 128))  # Head
            draw.line([40, 25, 50, 15], fill=(105, 105, 105), width=3)  # Mane
        elif i == 8:  # Ship
            draw.polygon([(20, 35), (44, 35), (38, 45), (26, 45)], fill=(139, 69, 19))  # Hull
            draw.rectangle([30, 20, 34, 35], fill=(169, 169, 169))  # Mast
            draw.polygon([(30, 20), (32, 15), (34, 20)], fill=(255, 255, 255))  # Sail
        elif i == 9:  # Truck
            draw.rectangle([15, 25, 49, 45], fill=(255, 0, 0))  # Cab
            draw.rectangle([25, 15, 45, 25], fill=(255, 0, 0))  # Load area
            draw.ellipse([20, 40, 30, 50], fill=(0, 0, 0))  # Wheels
            draw.ellipse([34, 40, 44, 50], fill=(0, 0, 0))

        examples.append(img)
        example_names.append(class_name)

    return examples, example_names


def image_to_base64(image):
    """Encode a PIL image as a base64 PNG string for inline HTML display."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


# Initialize the model
model_loaded = load_model()

# Create example images
examples, example_names = create_example_images()

# Streamlit app
st.set_page_config(
    page_title="CIFAR-10 Image Classifier",
    page_icon="🚀",  # restored: the emoji was mojibake-garbled in the source
    layout="wide",
)

# Custom CSS with cleaner design.
# NOTE(review): the CSS rules appear to have been stripped from this file
# (the string literal is empty) -- restore the style block here.
st.markdown("""
""", unsafe_allow_html=True)
# ---------------------------------------------------------------------------
# Streamlit UI.
# NOTE(review): the HTML inside the st.markdown(..., unsafe_allow_html=True)
# calls was stripped from the source; the <div>/<h3> markup below is a
# minimal reconstruction -- confirm class names against the original CSS.
# Mojibake (garbled emoji / arrows) in user-facing strings has been restored,
# and the removed st.experimental_rerun() is replaced by st.rerun().
# ---------------------------------------------------------------------------

# Main app content: open the main container, then title and subtitle.
st.markdown('<div class="main-container">', unsafe_allow_html=True)
st.markdown('<h1>🚀 CIFAR-10 Image Classifier</h1>', unsafe_allow_html=True)
st.markdown('<p class="subtitle">Convolutional Neural Network for Object Recognition</p>',
            unsafe_allow_html=True)

# Show model loading status
if model_loaded:
    st.success("✅ Model successfully loaded")
else:
    st.warning("⚠️ Model not found or error loading. Using random weights for demonstration.")

# Create tabs for better organization
tab1, tab2, tab3 = st.tabs(["🔍 Classify", "🖼️ Examples", "📚 Information"])

with tab1:
    # Two columns: input on the left, results on the right.
    col1, col2 = st.columns(2)

    with col1:
        st.markdown('<div class="card">', unsafe_allow_html=True)
        st.markdown('<h3>📤 Input</h3>', unsafe_allow_html=True)

        # File uploader
        uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

        # Display image
        image = None
        if uploaded_file is not None:
            image = Image.open(uploaded_file).convert('RGB')
            st.image(image, caption="Uploaded Image", use_container_width=True)

        # Classify button: store predictions in session state so they
        # survive Streamlit reruns.
        if st.button("Classify Image"):
            if image is not None:
                st.session_state.predictions = predict(image)
            else:
                st.warning("Please upload an image first")

        # Clear button
        if st.button("Clear"):
            st.session_state.predictions = None
            # st.experimental_rerun() was deprecated and removed; use st.rerun().
            st.rerun()

        st.markdown('</div>', unsafe_allow_html=True)

        # Model architecture section
        st.markdown('<div class="card">', unsafe_allow_html=True)
        st.markdown('<h3>🎯 Model Architecture</h3>', unsafe_allow_html=True)
        st.code("""
Input → Conv2D(3×32×32) → ReLU → MaxPool2D → Conv2D → ReLU → MaxPool2D
→ Flatten → Linear → ReLU → Linear → ReLU → Linear(10) → Output
""", language="text")
        st.markdown('</div>', unsafe_allow_html=True)

    with col2:
        st.markdown('<div class="card">', unsafe_allow_html=True)
        st.markdown('<h3>📊 Classification Results</h3>', unsafe_allow_html=True)

        # Display results
        if "predictions" in st.session_state and st.session_state.predictions:
            predictions = st.session_state.predictions

            # Sort predictions by probability, highest first.
            sorted_predictions = sorted(predictions.items(),
                                        key=lambda x: x[1], reverse=True)

            # Display top 5 predictions with styled bars.
            st.markdown('<div class="results">', unsafe_allow_html=True)
            for label, prob in sorted_predictions[:5]:
                st.markdown(f'''<div class="result-row">
<span class="result-label">{label}</span>
<span class="result-prob">{prob:.2f}</span>
</div>''', unsafe_allow_html=True)
            st.markdown('</div>', unsafe_allow_html=True)

            # Display all probabilities in a more detailed way
            st.subheader("All Class Probabilities")
            for label, prob in sorted_predictions:
                st.progress(prob)
                st.write(f"{label}: {prob:.4f}")
        else:
            st.info("Upload an image and click 'Classify Image' to see results")

        st.markdown('</div>', unsafe_allow_html=True)

        # Instructions section
        st.markdown('<div class="card">', unsafe_allow_html=True)
        st.markdown('<h3>ℹ️ Instructions</h3>', unsafe_allow_html=True)
        st.markdown("""
1. Upload an image using the file uploader
2. The image will be automatically resized to 32×32 pixels
3. Click "Classify Image" to get predictions
4. Results show probabilities for 10 CIFAR-10 classes
""")
        st.markdown('</div>', unsafe_allow_html=True)

with tab2:
    # Example images section
    st.markdown('<div class="card">', unsafe_allow_html=True)
    st.markdown('<h3>🖼️ Example Images</h3>', unsafe_allow_html=True)
    st.markdown("Click on any example image to classify it:")

    # Create example grid
    st.markdown('<div class="example-grid">', unsafe_allow_html=True)
    for i, (example_img, example_name) in enumerate(zip(examples, example_names)):
        # Convert PIL image to base64 so it can be embedded inline.
        img_base64 = image_to_base64(example_img)

        # Button triggers classification of the drawn example.
        if st.button(f"example_{i}", key=f"btn_{i}"):
            st.session_state.predictions = predict(example_img)
            # st.experimental_rerun() was deprecated and removed; use st.rerun().
            st.rerun()

        st.markdown(f'''<div class="example-item">
<img src="data:image/png;base64,{img_base64}" alt="{example_name}"/>
<p>{example_name}</p>
</div>''', unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)

with tab3:
    # Information sections
    st.markdown('<div class="card">', unsafe_allow_html=True)
    st.markdown('<h3>🧪 Testing Different Image Qualities</h3>', unsafe_allow_html=True)
    st.markdown("""
This model is robust to various image conditions:

- **Resolution**: Works with images of any resolution (automatically resized to 32×32)
- **Contrast**: Handles both high and low contrast images
- **Noise**: Can tolerate some image noise
- **Rotation**: Some tolerance to slight rotations
- **Scale**: Works with objects of different sizes within the image

For best results:
1. Center the object in the image
2. Use clear contrast between the object and background
3. Avoid excessive noise or artifacts
4. Fill most of the image area with the object
""")
    st.markdown('</div>', unsafe_allow_html=True)

    st.markdown('<div class="card">', unsafe_allow_html=True)
    st.markdown('<h3>🎯 CIFAR-10 Classes</h3>', unsafe_allow_html=True)
    classes_info = """
1. **Airplane** - Aircraft flying in the sky
2. **Automobile** - Cars and vehicles on the road
3. **Bird** - Flying or perched birds
4. **Cat** - Domestic cats and felines
5. **Deer** - Wild deer and similar animals
6. **Dog** - Domestic dogs and canines
7. **Frog** - Amphibians like frogs
8. **Horse** - Horses and similar animals
9. **Ship** - Boats and ships on water
10. **Truck** - Trucks and heavy vehicles
"""
    st.markdown(classes_info)
    st.markdown('</div>', unsafe_allow_html=True)

    # Model architecture section
    st.markdown('<div class="card">', unsafe_allow_html=True)
    st.markdown('<h3>🧠 Model Details</h3>', unsafe_allow_html=True)
    st.markdown("""
This convolutional neural network follows the PyTorch CIFAR-10 tutorial architecture:

- **Input Layer**: 3×32×32 RGB images
- **Convolutional Layers**: 2 layers with ReLU activation
- **Pooling Layers**: 2 max-pooling layers
- **Fully Connected Layers**: 3 linear layers
- **Output Layer**: 10 classes with softmax activation
""")
    st.markdown('</div>', unsafe_allow_html=True)

# Footer: close the main container opened at the top of the page.
st.markdown('', unsafe_allow_html=True)
st.markdown('</div>', unsafe_allow_html=True)