Spaces:
Sleeping
Sleeping
| """ | |
| Gradio app for MNIST digit classification. | |
| Interactive web interface for handwritten digit recognition using trained CNN model. | |
| """ | |
import os

import gradio as gr
import numpy as np
from PIL import Image

from scripts.inference import DigitClassifier
# Build the classifier once at import time; predict_digit() reuses this
# module-level instance for every request.
print("Loading model...")
classifier = DigitClassifier("models/best_model.pt")
print("Model loaded on {}".format(classifier.device))
def _to_grayscale_uint8(pixels):
    """Collapse an image array to a 2-D uint8 grayscale array in [0, 255]."""
    gray = np.asarray(pixels, dtype=np.float64)
    if gray.ndim == 3:
        if gray.shape[2] == 4:
            # RGBA input: take the alpha channel as the stroke intensity.
            gray = gray[:, :, 3]
        else:
            # RGB (or any other channel count): average the channels.
            gray = gray.mean(axis=2)
    # Rescale unit-range data BEFORE the integer cast. Casting first would
    # truncate every value below 1.0 to 0 and wipe out the drawing.
    if gray.size and gray.max() <= 1.0:
        gray = gray * 255
    # Clip so out-of-range values saturate instead of wrapping around.
    return np.clip(gray, 0, 255).astype(np.uint8)


def predict_digit(image):
    """
    Predict digit from user-drawn image.

    Args:
        image: Sketchpad payload. May be None (nothing drawn), a dict whose
            'composite' entry holds the drawing, a numpy array shaped
            (H, W), (H, W, 3) or (H, W, 4), or any other object, which is
            passed to the classifier unchanged.

    Returns:
        Tuple of (predicted_digit, confidence_text, probability_dict).
        When there is nothing to classify the tuple is
        ("Please draw a digit", "", {}).
    """
    # Sketchpad may hand over a dict; the drawing itself is under 'composite'.
    # A missing or None entry means there is nothing to classify.
    if isinstance(image, dict):
        image = image.get('composite')
    if image is None:
        return "Please draw a digit", "", {}

    if isinstance(image, np.ndarray):
        if image.size == 0:
            return "Please draw a digit", "", {}
        # A 2-D uint8 array is interpreted by Pillow as an 8-bit grayscale
        # ('L') image, so no explicit mode argument is needed.
        pil_image = Image.fromarray(_to_grayscale_uint8(image))
    else:
        pil_image = image

    # The classifier result is read as a mapping with the keys
    # 'digit', 'confidence' and 'probabilities'.
    result = classifier.predict(pil_image)
    digit = result['digit']
    confidence = result['confidence']
    probabilities = result['probabilities']

    # confidence is presumably a fraction in [0, 1] - shown as a percentage.
    confidence_text = f"Confidence: {confidence*100:.1f}%"

    # Label -> probability mapping; labels are the positional indices.
    prob_dict = {str(i): prob for i, prob in enumerate(probabilities)}
    return digit, confidence_text, prob_dict
# Custom CSS for better styling. The string is handed to gr.Blocks(css=...)
# when the interface is built.
# NOTE(review): no component visible in this file is given the "title" or
# "description" class, so those two rules look unused here - confirm.
custom_css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
    max-width: 900px;
    margin: auto;
}
.title {
    text-align: center;
    color: #2c3e50;
}
.description {
    text-align: center;
    color: #7f8c8d;
    margin-bottom: 20px;
}
"""
# Create Gradio interface: a sketchpad and buttons on the left, the
# prediction read-outs on the right, then optional examples and model notes.
with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
    gr.Markdown(
        """
        # 🔢 Handwritten Digit Classifier
        Draw a digit (0-9) in the box below and the AI will predict
        what it is!
        This model uses a Convolutional Neural Network (CNN) trained on
        the MNIST dataset with **99.17% accuracy** on 10,000 test images.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Sketchpad for drawing
            input_image = gr.Sketchpad(
                label="Draw a digit here",
                type="numpy",
                image_mode="L",
                brush_radius=3,
                height=280,
                width=280,
            )
            # Buttons
            with gr.Row():
                # NOTE(review): the leading "π" looks like a mis-decoded
                # emoji; the intended glyph could not be determined.
                predict_btn = gr.Button(
                    "π Predict", variant="primary", scale=2
                )
                clear_btn = gr.ClearButton(
                    components=[input_image], value="🗑️ Clear", scale=1
                )
        with gr.Column(scale=1):
            # Prediction output
            output_digit = gr.Textbox(
                label="Predicted Digit",
                placeholder="Draw a digit to see prediction",
                scale=1,
                lines=1,
                max_lines=1,
                interactive=False,
            )
            output_confidence = gr.Textbox(
                label="Confidence",
                placeholder="",
                scale=1,
                lines=1,
                max_lines=1,
                interactive=False,
            )
            # Probability distribution
            output_probs = gr.Label(
                label="Probability Distribution",
                num_top_classes=10,
            )

    # Example images section. Only files that actually exist are offered,
    # so the examples list never contains a None row or a dangling path,
    # and the section is skipped entirely when there is nothing to show.
    example_rows = [
        [path]
        for path in ("examples/digit_0.png",)
        if os.path.exists(path)
    ]
    if example_rows:
        # NOTE(review): the "π" below looks like a mis-decoded emoji; the
        # intended glyph could not be determined.
        gr.Markdown("### π Try these examples:")
        gr.Examples(
            examples=example_rows,
            inputs=input_image,
            label="Example digits",
        )

    # Model info
    # NOTE(review): the "π" in the "Model Details" heading looks like a
    # mis-decoded emoji; the intended glyph could not be determined.
    gr.Markdown(
        """
        ---
        ### π Model Details
        - **Architecture**: Convolutional Neural Network (CNN)
        - **Parameters**: 421,066
        - **Training**: MNIST dataset (60,000 images)
        - **Test Accuracy**: 99.17%
        - **Framework**: PyTorch 2.0.1
        ### 💡 Tips for best results:
        - Draw the digit large and centered
        - Use a thick brush stroke
        - Draw in white on black background (like MNIST)
        - Make sure the digit is clear and recognizable
        """
    )

    # Connect events: the Predict button and every sketchpad change
    # (real-time prediction) run the same handler with the same wiring.
    prediction_outputs = [output_digit, output_confidence, output_probs]
    for register in (predict_btn.click, input_image.change):
        register(
            fn=predict_digit,
            inputs=input_image,
            outputs=prediction_outputs,
        )
if __name__ == "__main__":
    # Launch the app
    import os

    launch_options = {
        # Bind to all interfaces to allow external access.
        "server_name": "0.0.0.0",
        # Port comes from GRADIO_SERVER_PORT, falling back to 7860.
        "server_port": int(os.environ.get("GRADIO_SERVER_PORT", 7860)),
        # True asks Gradio to create a public link.
        "share": True,
        "show_error": True,
    }
    demo.launch(**launch_options)