# faizan
# fix: set port to 7860 and share=True for HF Spaces
# 900d934
"""
Gradio app for MNIST digit classification.
Interactive web interface for handwritten digit recognition using trained CNN model.
"""
import os

import gradio as gr
import numpy as np
from PIL import Image

from scripts.inference import DigitClassifier
# Build the shared classifier once at import time; predict_digit reuses it
# for every request.
MODEL_PATH = 'models/best_model.pt'

print("Loading model...")
classifier = DigitClassifier(MODEL_PATH)
print(f"Model loaded on {classifier.device}")
def _to_grayscale_uint8(image: np.ndarray) -> np.ndarray:
    """Collapse a sketchpad array to a single-channel uint8 image.

    Args:
        image: array of shape (H, W), (H, W, C) or (H, W, 4).

    Returns:
        2-D uint8 array with values scaled to [0, 255].
    """
    if image.ndim == 3:
        if image.shape[2] == 4:
            # RGBA: use the alpha channel as the drawn part.
            image = image[:, :, 3]
        else:
            is_float = image.dtype.kind == 'f'
            image = image.mean(axis=2)
            # Integer input keeps the historical truncation. Float input must
            # stay float here: casting a [0, 1] image to uint8 before the
            # range check below would wipe out every intermediate value.
            if not is_float:
                image = image.astype(np.uint8)

    # Values in [0, 1] are treated as normalised and rescaled to [0, 255].
    if image.max() <= 1.0:
        image = image * 255
    return image.astype(np.uint8)


def predict_digit(image):
    """
    Predict digit from user-drawn image.

    Args:
        image: value emitted by the Gradio Sketchpad. Either a numpy array
            (H, W), (H, W, 3) or (H, W, 4), a dict carrying the drawing under
            the 'composite' key, an object the classifier accepts directly
            (e.g. a PIL image), or None when nothing has been drawn.

    Returns:
        Tuple of (predicted_digit, confidence_text, probability_dict).
        When there is no drawing the tuple is
        ("Please draw a digit", "", {}).
    """
    # Sketchpad may wrap the drawing in a dict with a 'composite' key.
    if isinstance(image, dict):
        image = image.get('composite', image)

    # Checked after unwrapping: a cleared canvas can arrive either as None
    # or as a dict whose 'composite' entry is None.
    if image is None:
        return "Please draw a digit", "", {}

    if isinstance(image, np.ndarray):
        # A 2-D uint8 array is interpreted by Pillow as mode 'L', so the
        # deprecated `mode` argument is not needed.
        pil_image = Image.fromarray(_to_grayscale_uint8(image))
    else:
        pil_image = image

    result = classifier.predict(pil_image)

    digit = result['digit']
    confidence_text = f"Confidence: {result['confidence']*100:.1f}%"

    # Label/bar-chart input: class name -> probability.
    # NOTE(review): assumes result['probabilities'] is a sequence of scalar
    # values indexed by digit - confirm against DigitClassifier.predict.
    # float() turns numpy/tensor scalars into plain Python floats.
    prob_dict = {
        str(index): float(prob)
        for index, prob in enumerate(result['probabilities'])
    }

    return digit, confidence_text, prob_dict
# Stylesheet handed to gr.Blocks(css=...) below: caps the app container at
# 900px and centers it, and defines centered `.title` / `.description` text.
custom_css = """
.gradio-container {
font-family: 'Arial', sans-serif;
max-width: 900px;
margin: auto;
}
.title {
text-align: center;
color: #2c3e50;
}
.description {
text-align: center;
color: #7f8c8d;
margin-bottom: 20px;
}
"""
# Build the Gradio interface: drawing canvas on the left, prediction outputs
# on the right, optional example images and model notes underneath.
with gr.Blocks(css=custom_css, title="MNIST Digit Classifier") as demo:
    gr.Markdown(
        """
        # 🔢 Handwritten Digit Classifier
        Draw a digit (0-9) in the box below and the AI will predict
        what it is!
        This model uses a Convolutional Neural Network (CNN) trained on
        the MNIST dataset with **99.17% accuracy** on 10,000 test images.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Sketchpad for drawing.
            # NOTE(review): `brush_radius` looks like the Gradio 3.x
            # Sketchpad API, while predict_digit also handles the dict
            # payload with a 'composite' key - confirm which Gradio version
            # is pinned and that it accepts this argument.
            input_image = gr.Sketchpad(
                label="Draw a digit here",
                type="numpy",
                image_mode="L",
                brush_radius=3,
                height=280,
                width=280,
            )

            # Buttons
            with gr.Row():
                predict_btn = gr.Button(
                    "🔍 Predict", variant="primary", scale=2
                )
                clear_btn = gr.ClearButton(
                    components=[input_image], value="🗑️ Clear", scale=1
                )

        with gr.Column(scale=1):
            # Prediction output
            output_digit = gr.Textbox(
                label="Predicted Digit",
                placeholder="Draw a digit to see prediction",
                scale=1,
                lines=1,
                max_lines=1,
                interactive=False,
            )
            output_confidence = gr.Textbox(
                label="Confidence",
                placeholder="",
                scale=1,
                lines=1,
                max_lines=1,
                interactive=False,
            )
            # Probability distribution
            output_probs = gr.Label(
                label="Probability Distribution",
                num_top_classes=10,
            )

    # Example images section. Only files that actually exist are offered, so
    # the examples list never contains None (previously it did whenever the
    # module ran as a script) and a missing file cannot break the UI build.
    example_files = [
        path for path in ("examples/digit_0.png",) if os.path.isfile(path)
    ]
    if example_files:
        gr.Markdown("### 📝 Try these examples:")
        gr.Examples(
            examples=[[path] for path in example_files],
            inputs=input_image,
            label="Example digits",
        )

    # Model info
    gr.Markdown(
        """
        ---
        ### 📊 Model Details
        - **Architecture**: Convolutional Neural Network (CNN)
        - **Parameters**: 421,066
        - **Training**: MNIST dataset (60,000 images)
        - **Test Accuracy**: 99.17%
        - **Framework**: PyTorch 2.0.1
        ### 💡 Tips for best results:
        - Draw the digit large and centered
        - Use a thick brush stroke
        - Draw in white on black background (like MNIST)
        - Make sure the digit is clear and recognizable
        """
    )

    # Connect events: the Predict button and every change of the sketchpad
    # (real-time prediction) run the same handler with the same outputs.
    prediction_outputs = [output_digit, output_confidence, output_probs]
    for register in (predict_btn.click, input_image.change):
        register(
            fn=predict_digit,
            inputs=input_image,
            outputs=prediction_outputs,
        )
if __name__ == "__main__":
    import os

    # Port comes from GRADIO_SERVER_PORT when set, otherwise 7860.
    port = int(os.getenv("GRADIO_SERVER_PORT", 7860))

    launch_options = {
        "server_name": "0.0.0.0",  # listen on all interfaces
        "server_port": port,
        "share": True,  # request a public share link
        "show_error": True,
    }
    # Launch the app
    demo.launch(**launch_options)