| """ |
| CIFAR-10 Image Classifier - Hugging Face Space |
| =============================================== |
| Advanced CNN with Residual Connections achieving 87.88% test accuracy |
| |
| Architecture: CIFARNet |
| - Depthwise Separable Convolutions |
| - Residual Connections in C2 and C3 layers |
| - Spatial Dropout for regularization |
| - 174,762 parameters (0.67 MB model size) |
| """ |
|
|
| import torch |
| import gradio as gr |
| from PIL import Image |
| from pathlib import Path |
| import numpy as np |
|
|
| |
| from model_cifar import CIFARNet |
| from preprocess import CIFAR10_MEAN, CIFAR10_STD |
| from torchvision import transforms |
|
|
| |
# Human-readable class names; position i names the class for model output
# index i (predict() indexes this list with the argsort of the probabilities).
CIFAR10_CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
|
|
| |
# Short blurb shown under the predicted class name in the result card.
# Keys must match the entries of CIFAR10_CLASSES exactly, because predict()
# looks the description up by predicted class name.
CLASS_DESCRIPTIONS = {
    "airplane": "✈️ Commercial or military aircraft",
    "automobile": "🚗 Cars, sedans, and vehicles",
    "bird": "🐦 Various bird species",
    "cat": "🐱 Domestic and wild cats",
    "deer": "🦌 Deer and similar animals",
    "dog": "🐕 Domestic dogs of various breeds",
    "frog": "🐸 Frogs and similar amphibians",
    "horse": "🐴 Horses and equines",
    "ship": "🚢 Ships, boats, and vessels",
    "truck": "🚚 Trucks and large vehicles",
}
|
|
| |
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
| |
@torch.no_grad()
def load_model(checkpoint_path: "str | None" = None):
    """Build a CIFARNet on ``device`` and load trained weights when possible.

    Loading is best-effort: if no path is given, the file is missing, or the
    checkpoint cannot be loaded, a message is printed and the randomly
    initialized network is returned instead of raising.

    Args:
        checkpoint_path: Path to a checkpoint file. The file may hold either
            a dict with a ``'model_state_dict'`` entry (and optionally
            ``'epoch'``) or a bare state dict.

    Returns:
        The CIFARNet instance, moved to ``device`` and switched to eval mode.
    """
    net = CIFARNet(num_classes=10).to(device)

    if not checkpoint_path:
        print("ℹ️ No checkpoint provided, using randomly initialized model")
    elif not Path(checkpoint_path).exists():
        # A path was given but nothing is there: say so explicitly rather
        # than claiming that no checkpoint was provided.
        print(f"⚠️ Checkpoint not found: {checkpoint_path}")
        print("Using randomly initialized model")
    else:
        try:
            # NOTE(security): torch.load unpickles the file, which can run
            # arbitrary code. Only point this at checkpoints you trust.
            checkpoint = torch.load(checkpoint_path, map_location=device)
            if 'model_state_dict' in checkpoint:
                net.load_state_dict(checkpoint['model_state_dict'])
                print(f"✅ Loaded checkpoint from epoch {checkpoint.get('epoch', '?')}")
            else:
                net.load_state_dict(checkpoint)
                print(f"✅ Loaded model weights from {checkpoint_path}")
        except Exception as e:
            # Deliberate best-effort: the app should still start, with
            # untrained weights, when the checkpoint is unusable.
            print(f"⚠️ Could not load checkpoint: {e}")
            print("Using randomly initialized model")

    net.eval()
    return net
|
|
| |
# Report the selected device, then create the single model instance that
# predict() uses for every request. This runs once, at import time.
print(f"Device: {device}")
model = load_model(
    "./snapshots_complete/cifar_epoch_249.pth",
)
|
|
| |
# Input pipeline applied to every uploaded image before inference.
preprocess = transforms.Compose(
    [
        # The uploaded image can be any size; the network is fed 32x32 inputs.
        transforms.Resize((32, 32)),
        # PIL image -> tensor.
        transforms.ToTensor(),
        # Normalize with the statistics imported from the preprocess module.
        transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
    ]
)
|
|
def _bar_color(rank: int) -> str:
    """Return the bar/percentage color for a 1-based rank."""
    if rank == 1:
        return "#28a745"
    if rank == 2:
        return "#17a2b8"
    return "#6c757d"


def _prediction_row_html(rank: int, class_name: str, prob) -> str:
    """Render one ranked class as a label, a percentage and a bar."""
    color = _bar_color(rank)
    # The bar width is a whole-number percentage of the probability.
    bar_width = int(prob * 100)
    return f"""
            <div style='margin: 8px 0;'>
                <div style='display: flex; justify-content: space-between; align-items: center; margin-bottom: 4px;'>
                    <span style='font-weight: 500; color: #333;'>{rank}. {class_name}</span>
                    <span style='font-weight: bold; color: {color};'>{prob*100:.2f}%</span>
                </div>
                <div style='width: 100%; background: #e9ecef; border-radius: 4px; height: 20px; overflow: hidden;'>
                    <div style='width: {bar_width}%; background: {color}; height: 100%;
                                transition: width 0.3s ease;'></div>
                </div>
            </div>
            """


def _result_html(probabilities, ranked) -> str:
    """Render the result card plus the top-5 breakdown.

    Args:
        probabilities: Per-class probabilities, indexed like CIFAR10_CLASSES.
        ranked: Class indices ordered from most to least probable.
    """
    best = ranked[0]
    predicted_class = CIFAR10_CLASSES[best]
    confidence = probabilities[best]

    header = f"""
        <div style='padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                    border-radius: 10px; color: white; box-shadow: 0 4px 6px rgba(0,0,0,0.1);'>
            <h2 style='margin: 0 0 10px 0;'>🎯 Prediction Result</h2>
            <div style='font-size: 24px; font-weight: bold; margin: 10px 0;'>
                {predicted_class.upper()}
            </div>
            <div style='font-size: 16px; opacity: 0.9;'>
                {CLASS_DESCRIPTIONS[predicted_class]}
            </div>
            <div style='font-size: 18px; margin-top: 10px;'>
                Confidence: <strong>{confidence*100:.2f}%</strong>
            </div>
        </div>

        <div style='margin-top: 20px; padding: 15px; background: #f8f9fa;
                    border-radius: 8px; border-left: 4px solid #667eea;'>
            <h3 style='margin-top: 0; color: #333;'>📊 Top 5 Predictions:</h3>
            <div style='margin-top: 10px;'>
        """

    rows = "".join(
        _prediction_row_html(rank, CIFAR10_CLASSES[idx], probabilities[idx])
        for rank, idx in enumerate(ranked[:5], 1)
    )

    footer = """
            </div>
        </div>
        """

    return header + rows + footer


def predict(image: "Image.Image | None") -> tuple:
    """Classify an uploaded image and build both outputs for the UI.

    Args:
        image: PIL image from the upload widget, or None when nothing has
            been uploaded yet.

    Returns:
        Tuple of ``(predictions_dict, confidence_html)``: a mapping of the
        three most probable class names to their probabilities, and an HTML
        fragment with the top prediction and a top-5 breakdown. When there is
        no image or inference fails, the mapping is empty and the HTML holds
        the message to show instead.
    """
    from html import escape

    if image is None:
        return {}, "<p style='color: red;'>Please upload an image first!</p>"

    try:
        # Force three channels, apply the shared transform and add the batch
        # dimension the model call expects.
        img_tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = torch.softmax(outputs, dim=1)[0].cpu().numpy()

        # argsort is ascending; reverse it so the most probable class is first.
        ranked = np.argsort(probabilities)[::-1]

        top3_results = {
            CIFAR10_CLASSES[i]: float(probabilities[i])
            for i in ranked[:3]
        }

        return top3_results, _result_html(probabilities, ranked)

    except Exception as e:
        # UI boundary: report the failure in the page rather than crash the
        # handler. The message is escaped because it is arbitrary text being
        # placed into HTML.
        error_html = f"<p style='color: red;'>Error during prediction: {escape(str(e))}</p>"
        return {}, error_html
|
|
| |
# Markdown shown inside the "Model Information" accordion of the UI.
model_description = """
## 📊 About This Model

**CIFARNet** is an advanced CNN architecture designed for CIFAR-10 image classification. It achieves state-of-the-art performance with exceptional efficiency.

### 📈 Performance Metrics
- **Test Accuracy:** 87.88%
- **Top-3 Accuracy:** 97.74%
- **Top-5 Accuracy:** 99.31%
- **Model Size:** 174,762 parameters (0.67 MB)

### 🏗️ Architecture Highlights
- **Depthwise Separable Convolutions** for parameter efficiency
- **Residual Connections** in C2 and C3 layers for improved gradient flow
- **Spatial Dropout** for better regularization
- **Dilated Convolutions** (C3 layer with dilation=4) for larger receptive field

### 🎯 Best Performing Classes
- Ship: 92.95% F1-score
- Truck: 92.19% F1-score
- Automobile: 93.77% F1-score
- Frog: 90.28% F1-score

### 🔬 Training Details
- **Training Set:** CIFAR-10 (50,000 images)
- **Test Set:** CIFAR-10 (10,000 images)
- **Epochs:** 250 with cosine annealing
- **Optimizer:** SGD with Nesterov momentum
- **Augmentation:** HorizontalFlip, ShiftScaleRotate, CoarseDropout, ColorJitter

### 📋 Classes
The model classifies images into 10 categories:
- ✈️ Airplane
- 🚗 Automobile
- 🐦 Bird
- 🐱 Cat
- 🦌 Deer
- 🐕 Dog
- 🐸 Frog
- 🐴 Horse
- 🚢 Ship
- 🚚 Truck

### 💡 Tips
- Upload clear images for best results
- The model works best with images containing the main object centered
- Images are automatically resized to 32×32 pixels (CIFAR-10 standard)
- Try different angles or lighting conditions to see how the model performs

### 🔗 Links
- [GitHub Repository](https://github.com/yourusername/CIFAR10-MLTraining)
- [Model Architecture Details](https://github.com/yourusername/CIFAR10-MLTraining#model-architecture)
- [Training Logs & Metrics](https://github.com/yourusername/CIFAR10-MLTraining#performance-results)
"""
|
|
| |
# Sample images offered under the classifier. gr.Examples takes one row per
# example and one column per input component, so each path sits in its own
# single-item list.
examples = [
    [image_path]
    for image_path in (
        "examples/airplane_1.jpg",
        "examples/airplane_2.jpg",
        "examples/automobile_1.jpg",
        "examples/ship_1.jpg",
        "examples/ship_2.jpg",
        "examples/cat_1.jpg",
        "examples/dog_1.jpg",
        "examples/horse_1.jpg",
        "examples/frog_1.jpg",
        "examples/truck_1.jpg",
    )
]
|
|
| |
# Extra CSS handed to gr.Blocks(css=...): use the Inter font for the app
# container and for elements carrying the output-html class.
custom_css = """
.gradio-container {
    font-family: 'Inter', sans-serif;
}
.output-html {
    font-family: 'Inter', sans-serif;
}
"""
|
|
| |
# Page layout: image upload and button on the left, prediction outputs on the
# right, followed by usage notes, clickable examples and model information.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎯 CIFAR-10 Image Classifier")
    gr.Markdown("### Advanced CNN achieving 87.88% test accuracy")

    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" makes Gradio hand predict() a PIL image.
            image_input = gr.Image(type="pil")
            predict_btn = gr.Button("🔍 Classify Image", variant="primary", size="lg")

            gr.Markdown("### 📤 Try it out!")
            gr.Markdown("Upload an image containing one of the 10 CIFAR-10 classes.")

        with gr.Column(scale=1):
            # These two components receive predict()'s two return values, in order.
            label_output = gr.Label(num_top_classes=3, label="Top 3 Predictions")
            html_output = gr.HTML(label="Detailed Results")

    gr.Markdown("---")
    gr.Markdown("## 💡 How to use")
    # Built from explicit lines so the rendered list does not depend on how
    # an indented triple-quoted string would be dedented.
    gr.Markdown(
        "1. **Upload an image** using the upload box on the left\n"
        "2. **Click 'Classify Image'** to get predictions\n"
        "3. **View results** showing the top predictions with confidence scores\n"
        "\n"
        "The model works best with images from these categories: airplane, "
        "automobile, bird, cat, deer, dog, frog, horse, ship, and truck.\n"
    )

    gr.Markdown("### 📸 Try These Examples")
    gr.Examples(
        examples=examples,
        inputs=image_input,
        outputs=[label_output, html_output],
        fn=predict,
        cache_examples=False,
    )

    gr.Markdown("---")
    with gr.Accordion("📊 Model Information & Performance Metrics", open=False):
        gr.Markdown(model_description)

    # Event wiring has to happen inside the Blocks context.
    predict_btn.click(
        fn=predict,
        inputs=image_input,
        outputs=[label_output, html_output],
    )
|
|
| |
if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script; merely
    # importing the module builds `demo` without launching it.
    demo.launch()
|
|