import torch
import torch.nn as nn
import torch.nn.functional as F
import streamlit as st
import numpy as np
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
import os
import base64
from io import BytesIO


# Define the neural network model - matching the trained model with 3 input channels
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 3 input image channels (RGB), 6 output channels, 5x5 square convolution kernel
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5*5 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Convolution layer C1: 3 input image channels, 6 output channels,
        # 5x5 square convolution, it uses RELU activation function, and
        # outputs a Tensor with size (N, 6, 28, 28), where N is the size of the batch
        c1 = F.relu(self.conv1(x))
        # Subsampling layer S2: 2x2 grid, purely functional,
        # this layer does not have any parameter, and outputs a (N, 6, 14, 14) Tensor
        s2 = F.max_pool2d(c1, (2, 2))
        # Convolution layer C3: 6 input channels, 16 output channels,
        # 5x5 square convolution, it uses RELU activation function, and
        # outputs a (N, 16, 10, 10) Tensor
        c3 = F.relu(self.conv2(s2))
        # Subsampling layer S4: 2x2 grid, purely functional,
        # this layer does not have any parameter, and outputs a (N, 16, 5, 5) Tensor
        s4 = F.max_pool2d(c3, 2)
        # Flatten operation: purely functional, outputs a (N, 400) Tensor
        s4 = torch.flatten(s4, 1)
        # Fully connected layer F5: (N, 400) Tensor input,
        # and outputs a (N, 120) Tensor, it uses RELU activation function
        f5 = F.relu(self.fc1(s4))
        # Fully connected layer F6: (N, 120) Tensor input,
        # and outputs a (N, 84) Tensor, it uses RELU activation function
        f6 = F.relu(self.fc2(f5))
        # Gaussian layer OUTPUT: (N, 84) Tensor input, and
        # outputs a (N, 10) Tensor
        output = self.fc3(f6)
        return output
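
# Illustrative shape check (a sketch only, not executed by the app): with 32x32 RGB
# inputs, the layer sizes noted in the comments above can be verified as, for example:
#     net = Net()
#     dummy = torch.randn(1, 3, 32, 32)   # (N, 3, 32, 32)
#     print(net(dummy).shape)             # expected: torch.Size([1, 10])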

# Initialize the model
model = Net()


# Load the trained model weights
def load_model():
    model_path = "model.pth"  # Update this path to where your model is stored
    if os.path.exists(model_path):
        try:
            # Load the trained model weights
            # Handle different PyTorch versions
            try:
                # For PyTorch 2.6+, we need to set weights_only=False for compatibility
                model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=False))
            except TypeError:
                # For older PyTorch versions that don't support the weights_only parameter
                model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
            print("Loaded trained model weights")
            return True
        except Exception as e:
            print(f"Error loading model: {e}")
            return False
    else:
        print("No trained model found at", model_path)
        # Initialize with random weights for demonstration
        for m in model.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        return False
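
# Note: load_model() assumes "model.pth" is a plain state_dict checkpoint, e.g. one
# saved after training with something like:
#     torch.save(model.state_dict(), "model.pth")
# A checkpoint saved in another format (for example the whole pickled model object)
# would need different loading code.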

# Preprocessing function for input images - now handles RGB images
def preprocess_image(image):
    # Resize to 32x32 (expected input size for the network)
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])
    image_tensor = transform(image)
    # Add batch dimension (1, 3, 32, 32)
    image_tensor = image_tensor.unsqueeze(0)
    return image_tensor
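
# Note (assumption about the training setup): the PyTorch CIFAR-10 tutorial normalizes
# inputs with a per-channel mean/std of 0.5. If the checkpoint was trained with that
# transform, adding the same step to the Compose above would keep inference consistent
# with training:
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))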

# Prediction function - matches the PyTorch tutorial exactly
def predict(image):
    if image is None:
        return {f"Class {i}": 0 for i in range(10)}

    # Preprocess the image
    input_tensor = preprocess_image(image)

    # Make prediction - exactly as shown in the PyTorch tutorial
    model.eval()
    with torch.no_grad():
        output = model(input_tensor)

    # Apply softmax to get probabilities
    probabilities = F.softmax(output, dim=1)
    probabilities = probabilities.numpy()[0]

    # Create labels for CIFAR-10 classes
    cifar10_classes = ["Airplane", "Automobile", "Bird", "Cat", "Deer", "Dog", "Frog", "Horse", "Ship", "Truck"]

    # Return as a dictionary
    return {label: float(prob) for label, prob in zip(cifar10_classes, probabilities)}
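
# Example usage (illustrative only; the file path is hypothetical):
#     img = Image.open("some_image.jpg").convert("RGB")
#     scores = predict(img)
#     best = max(scores, key=scores.get)  # class name with the highest probability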

# Create example images representing CIFAR-10 classes
def create_example_images():
    examples = []
    example_names = []

    # CIFAR-10 class names
    cifar10_classes = ["Airplane", "Automobile", "Bird", "Cat", "Deer", "Dog", "Frog", "Horse", "Ship", "Truck"]

    # Create simple representations of CIFAR-10 classes
    for i, class_name in enumerate(cifar10_classes):
        # Create a 64x64 RGB image for better quality
        img = Image.new('RGB', (64, 64), color=(255, 255, 255))  # White background
        draw = ImageDraw.Draw(img)

        # Draw simple representations of each class
        if i == 0:  # Airplane
            # Draw a simple airplane shape
            draw.polygon([(32, 10), (20, 30), (44, 30)], fill=(169, 169, 169))  # Main body
            draw.rectangle([25, 30, 39, 35], fill=(105, 105, 105))  # Wings
            draw.rectangle([30, 35, 34, 45], fill=(128, 128, 128))  # Tail
        elif i == 1:  # Automobile
            # Draw a simple car shape
            draw.rectangle([15, 30, 49, 45], fill=(0, 0, 255))  # Body
            draw.ellipse([20, 40, 30, 50], fill=(0, 0, 0))  # Wheels
            draw.ellipse([34, 40, 44, 50], fill=(0, 0, 0))
            draw.rectangle([25, 20, 39, 30], fill=(0, 0, 255))  # Top
        elif i == 2:  # Bird
            # Draw a simple bird shape
            draw.ellipse([25, 25, 39, 39], fill=(255, 165, 0))  # Body
            draw.polygon([(32, 15), (25, 25), (39, 25)], fill=(255, 140, 0))  # Head
            draw.line([20, 30, 10, 20], fill=(255, 165, 0), width=3)  # Wing
            draw.line([44, 30, 54, 20], fill=(255, 165, 0), width=3)  # Wing
        elif i == 3:  # Cat
            # Draw a simple cat shape
            draw.ellipse([25, 25, 39, 39], fill=(128, 128, 128))  # Body
            draw.ellipse([30, 20, 40, 30], fill=(169, 169, 169))  # Head
            draw.polygon([(35, 22), (33, 27), (37, 27)], fill=(0, 0, 0))  # Ear
            draw.ellipse([32, 28, 34, 30], fill=(0, 0, 0))  # Eye
        elif i == 4:  # Deer
            # Draw a simple deer shape
            draw.ellipse([25, 30, 39, 44], fill=(139, 69, 19))  # Body
            draw.ellipse([30, 25, 40, 35], fill=(160, 82, 45))  # Head
            draw.line([35, 15, 40, 25], fill=(139, 69, 19), width=3)  # Antler
            draw.line([20, 35, 10, 30], fill=(139, 69, 19), width=2)  # Leg
        elif i == 5:  # Dog
            # Draw a simple dog shape
            draw.ellipse([25, 30, 39, 44], fill=(139, 69, 19))  # Body
            draw.ellipse([30, 25, 40, 35], fill=(160, 82, 45))  # Head
            draw.ellipse([32, 28, 34, 30], fill=(0, 0, 0))  # Eye
            draw.ellipse([36, 32, 38, 34], fill=(0, 0, 0))  # Nose
        elif i == 6:  # Frog
            # Draw a simple frog shape
            draw.ellipse([25, 30, 39, 44], fill=(34, 139, 34))  # Body
            draw.ellipse([30, 25, 40, 35], fill=(0, 100, 0))  # Head
            draw.ellipse([27, 32, 29, 34], fill=(0, 0, 0))  # Eye
            draw.ellipse([35, 32, 37, 34], fill=(0, 0, 0))  # Eye
        elif i == 7:  # Horse
            # Draw a simple horse shape
            draw.ellipse([25, 30, 39, 44], fill=(169, 169, 169))  # Body
            draw.ellipse([35, 20, 45, 30], fill=(128, 128, 128))  # Head
            draw.line([40, 25, 50, 15], fill=(105, 105, 105), width=3)  # Mane
        elif i == 8:  # Ship
            # Draw a simple ship shape
            draw.polygon([(20, 35), (44, 35), (38, 45), (26, 45)], fill=(139, 69, 19))  # Hull
            draw.rectangle([30, 20, 34, 35], fill=(169, 169, 169))  # Mast
            draw.polygon([(30, 20), (32, 15), (34, 20)], fill=(255, 255, 255))  # Sail
        elif i == 9:  # Truck
            # Draw a simple truck shape
            draw.rectangle([15, 25, 49, 45], fill=(255, 0, 0))  # Cab
            draw.rectangle([25, 15, 45, 25], fill=(255, 0, 0))  # Load area
            draw.ellipse([20, 40, 30, 50], fill=(0, 0, 0))  # Wheels
            draw.ellipse([34, 40, 44, 50], fill=(0, 0, 0))

        examples.append(img)
        example_names.append(class_name)

    return examples, example_names


# Function to convert PIL Image to base64 for display
def image_to_base64(image):
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return img_str

# Initialize the model
model_loaded = load_model()

# Create example images
examples, example_names = create_example_images()

# Streamlit app
st.set_page_config(
    page_title="CIFAR-10 Image Classifier",
    page_icon="🚀",
    layout="wide"
)

# Custom CSS with cleaner design
st.markdown("""
<style>
/* Import Google Fonts */
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;500;600;700&display=swap');

/* Base styles */
* {
    font-family: 'Poppins', sans-serif;
}

/* Clean background */
body {
    background: linear-gradient(135deg, #1a2a6c, #2c3e50);
    color: white;
}

/* Main container with clean glassmorphism effect */
.main-container {
    background: rgba(255, 255, 255, 0.05);
    backdrop-filter: blur(10px);
    border-radius: 20px;
    border: 1px solid rgba(255, 255, 255, 0.1);
    box-shadow: 0 8px 32px 0 rgba(0, 0, 0, 0.3);
    padding: 2rem;
    margin: 2rem auto;
    max-width: 1200px;
}

/* Title with clean gradient */
.title {
    background: linear-gradient(90deg, #4facfe 0%, #00f2fe 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    background-clip: text;
    font-weight: 800;
    font-size: 2.5rem;
    text-align: center;
    margin-bottom: 0.5rem;
}

/* Subtitle styling */
.subtitle {
    text-align: center;
    color: #a0d2ff;
    font-size: 1.1rem;
    margin-bottom: 2rem;
    opacity: 0.9;
}

/* Card styling */
.card {
    background: rgba(255, 255, 255, 0.05);
    border-radius: 15px;
    padding: 1.5rem;
    margin-bottom: 1.5rem;
    border: 1px solid rgba(255, 255, 255, 0.1);
    transition: all 0.3s ease;
    box-shadow: 0 4px 20px rgba(0, 0, 0, 0.15);
}

.card:hover {
    background: rgba(255, 255, 255, 0.08);
    box-shadow: 0 6px 25px rgba(0, 0, 0, 0.25);
    transform: translateY(-3px);
}

/* Section headers */
.section-header {
    color: #4facfe;
    border-bottom: 2px solid #00f2fe;
    padding-bottom: 0.5rem;
    margin-bottom: 1rem;
    font-weight: 600;
    font-size: 1.3rem;
}

/* Button styling */
.stButton > button {
    background: linear-gradient(90deg, #4facfe 0%, #00f2fe 100%);
    color: white;
    border: none;
    border-radius: 10px;
    padding: 0.7rem 1.2rem;
    font-weight: 600;
    transition: all 0.3s ease;
    box-shadow: 0 4px 15px rgba(79, 172, 254, 0.3);
    width: 100%;
}

.stButton > button:hover {
    transform: translateY(-2px);
    box-shadow: 0 6px 20px rgba(79, 172, 254, 0.5);
}

.stButton > button:active {
    transform: translateY(1px);
}

/* File uploader styling */
.stFileUploader > div {
    background: rgba(255, 255, 255, 0.05);
    border-radius: 15px;
    border: 1px dashed rgba(255, 255, 255, 0.3);
    padding: 1.5rem;
    text-align: center;
}

/* Progress bar styling */
.stProgress > div > div {
    background: linear-gradient(90deg, #4facfe 0%, #00f2fe 100%);
}

/* Result display */
.result-container {
    display: flex;
    flex-wrap: wrap;
    gap: 0.8rem;
    justify-content: center;
}

.result-item {
    background: rgba(255, 255, 255, 0.08);
    border-radius: 12px;
    padding: 1rem;
    text-align: center;
    min-width: 110px;
    transition: all 0.3s ease;
    border: 1px solid rgba(255, 255, 255, 0.1);
}

.result-item:hover {
    background: rgba(79, 172, 254, 0.2);
    transform: translateY(-3px);
    box-shadow: 0 5px 15px rgba(0, 0, 0, 0.2);
}

.result-label {
    font-weight: 600;
    margin-bottom: 0.4rem;
    color: #4facfe;
    font-size: 0.9rem;
}

.result-value {
    font-size: 1.1rem;
    font-weight: 700;
    color: white;
}

/* Example images grid */
.examples-grid {
    display: grid;
    grid-template-columns: repeat(auto-fill, minmax(90px, 1fr));
    gap: 0.8rem;
    margin-top: 1rem;
}

.example-item {
    cursor: pointer;
    border-radius: 10px;
    overflow: hidden;
    transition: all 0.3s ease;
    border: 2px solid transparent;
    background: rgba(255, 255, 255, 0.05);
}

.example-item:hover {
    transform: scale(1.05);
    border-color: #4facfe;
    box-shadow: 0 5px 15px rgba(0, 0, 0, 0.3);
    background: rgba(79, 172, 254, 0.1);
}

.example-item img {
    border-radius: 8px;
}

.example-name {
    text-align: center;
    margin-top: 5px;
    font-size: 0.75rem;
    color: #a0d2ff;
}

/* Footer */
.footer {
    text-align: center;
    padding: 1.5rem;
    color: rgba(255, 255, 255, 0.6);
    font-size: 0.9rem;
}

/* Responsive design */
@media (max-width: 768px) {
    .main-container {
        padding: 1rem;
        margin: 1rem;
    }

    .title {
        font-size: 2rem;
    }

    .card {
        padding: 1rem;
    }

    .result-item {
        min-width: 90px;
        padding: 0.7rem;
    }

    .examples-grid {
        grid-template-columns: repeat(auto-fill, minmax(70px, 1fr));
    }
}
</style>
""", unsafe_allow_html=True)

# Main app content
st.markdown('<div class="main-container">', unsafe_allow_html=True)
st.markdown('<h1 class="title">🚀 CIFAR-10 Image Classifier</h1>', unsafe_allow_html=True)
st.markdown('<p class="subtitle">Convolutional Neural Network for Object Recognition</p>', unsafe_allow_html=True)

# Show model loading status
if model_loaded:
    st.success("✅ Model successfully loaded")
else:
    st.warning("⚠️ Model not found or failed to load. Using random weights for demonstration.")

# Create tabs for better organization
tab1, tab2, tab3 = st.tabs(["🔍 Classify", "🖼️ Examples", "📚 Information"])

with tab1:
    # Create two columns for input and output
    col1, col2 = st.columns(2)

    with col1:
        st.markdown('<div class="card">', unsafe_allow_html=True)
        st.markdown('<h2 class="section-header">📤 Input</h2>', unsafe_allow_html=True)

        # File uploader
        uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

        # Display image
        image = None
        if uploaded_file is not None:
            image = Image.open(uploaded_file).convert('RGB')
            st.image(image, caption="Uploaded Image", use_container_width=True)

        # Classify button
        if st.button("Classify Image"):
            if image is not None:
                st.session_state.predictions = predict(image)
            else:
                st.warning("Please upload an image first")
        # Clear button
        if st.button("Clear"):
            st.session_state.predictions = None
            st.rerun()  # replaces the deprecated st.experimental_rerun()
        st.markdown('</div>', unsafe_allow_html=True)

        # Model architecture section
        st.markdown('<div class="card">', unsafe_allow_html=True)
        st.markdown('<h2 class="section-header">🎯 Model Architecture</h2>', unsafe_allow_html=True)
        st.code("""
Input (3×32×32) → Conv2D(6 filters, 5×5) → ReLU → MaxPool2D(2×2)
→ Conv2D(16 filters, 5×5) → ReLU → MaxPool2D(2×2)
→ Flatten(400) → Linear(120) → ReLU
→ Linear(84) → ReLU → Linear(10)
→ Output (10 class scores)
""", language="text")
        st.markdown('</div>', unsafe_allow_html=True)

    with col2:
        st.markdown('<div class="card">', unsafe_allow_html=True)
        st.markdown('<h2 class="section-header">📊 Classification Results</h2>', unsafe_allow_html=True)

        # Display results
        if "predictions" in st.session_state and st.session_state.predictions:
            predictions = st.session_state.predictions

            # Sort predictions by probability
            sorted_predictions = sorted(predictions.items(), key=lambda x: x[1], reverse=True)

            # Display top 5 predictions
            st.markdown('<div class="result-container">', unsafe_allow_html=True)
            for label, prob in sorted_predictions[:5]:
                st.markdown(f'''
<div class="result-item">
<div class="result-label">{label}</div>
<div class="result-value">{prob:.2f}</div>
</div>
''', unsafe_allow_html=True)
            st.markdown('</div>', unsafe_allow_html=True)

            # Display all probabilities in a more detailed way
            st.subheader("All Class Probabilities")
            for label, prob in sorted_predictions:
                st.progress(prob)
                st.write(f"{label}: {prob:.4f}")
        else:
            st.info("Upload an image and click 'Classify Image' to see results")
        st.markdown('</div>', unsafe_allow_html=True)

    # Instructions section
    st.markdown('<div class="card">', unsafe_allow_html=True)
    st.markdown('<h2 class="section-header">ℹ️ Instructions</h2>', unsafe_allow_html=True)
    st.markdown("""
1. Upload an image using the file uploader
2. The image will be automatically resized to 32×32 pixels
3. Click "Classify Image" to get predictions
4. Results show probabilities for the 10 CIFAR-10 classes
""")
    st.markdown('</div>', unsafe_allow_html=True)

with tab2:
    # Example images section
    st.markdown('<div class="card">', unsafe_allow_html=True)
    st.markdown('<h2 class="section-header">🖼️ Example Images</h2>', unsafe_allow_html=True)
    st.markdown("Click on any example image to classify it:")

    # Create example grid
    st.markdown('<div class="examples-grid">', unsafe_allow_html=True)
    for i, (example_img, example_name) in enumerate(zip(examples, example_names)):
        # Convert PIL image to base64
        img_base64 = image_to_base64(example_img)

        # Create clickable image
        if st.button(f"example_{i}", key=f"btn_{i}"):
            st.session_state.predictions = predict(example_img)
            st.rerun()

        st.markdown(f'''
<div class="example-item">
<img src="data:image/png;base64,{img_base64}" width="100" height="100" alt="{example_name}">
<div class="example-name">{example_name}</div>
</div>
''', unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)

with tab3:
    # Information sections
    st.markdown('<div class="card">', unsafe_allow_html=True)
    st.markdown('<h2 class="section-header">🧪 Testing Different Image Qualities</h2>', unsafe_allow_html=True)
    st.markdown("""
This model is robust to various image conditions:
- **Resolution**: Works with images of any resolution (automatically resized to 32×32)
- **Contrast**: Handles both high and low contrast images
- **Noise**: Can tolerate some image noise
- **Rotation**: Some tolerance to slight rotations
- **Scale**: Works with objects of different sizes within the image

For best results:
1. Center the object in the image
2. Use clear contrast between the object and background
3. Avoid excessive noise or artifacts
4. Fill most of the image area with the object
""")
    st.markdown('</div>', unsafe_allow_html=True)

    st.markdown('<div class="card">', unsafe_allow_html=True)
    st.markdown('<h2 class="section-header">🎯 CIFAR-10 Classes</h2>', unsafe_allow_html=True)
    classes_info = """
1. **Airplane** - Aircraft flying in the sky
2. **Automobile** - Cars and vehicles on the road
3. **Bird** - Flying or perched birds
4. **Cat** - Domestic cats and felines
5. **Deer** - Wild deer and similar animals
6. **Dog** - Domestic dogs and canines
7. **Frog** - Amphibians like frogs
8. **Horse** - Horses and similar animals
9. **Ship** - Boats and ships on water
10. **Truck** - Trucks and heavy vehicles
"""
    st.markdown(classes_info)
    st.markdown('</div>', unsafe_allow_html=True)

    # Model details section
    st.markdown('<div class="card">', unsafe_allow_html=True)
    st.markdown('<h2 class="section-header">🧠 Model Details</h2>', unsafe_allow_html=True)
    st.markdown("""
This convolutional neural network follows the PyTorch CIFAR-10 tutorial architecture:
- **Input Layer**: 3×32×32 RGB images
- **Convolutional Layers**: 2 layers with ReLU activation
- **Pooling Layers**: 2 max-pooling layers
- **Fully Connected Layers**: 3 linear layers
- **Output Layer**: 10 classes with softmax activation
""")
    st.markdown('</div>', unsafe_allow_html=True)

# Footer
st.markdown('<div class="footer">', unsafe_allow_html=True)
st.markdown("Built with ❤️ using Streamlit and PyTorch | Deployable to Hugging Face Spaces")
st.markdown('</div>', unsafe_allow_html=True)

st.markdown('</div>', unsafe_allow_html=True)