# Page header captured together with this file (not part of the program):
# author Ali-Hyder2019 - commit 6df9113 "Update app.py" (verified)
import gradio as gr
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io
from sklearn.datasets import fetch_openml
from sklearn.naive_bayes import BernoulliNB
from sklearn.preprocessing import Binarizer
from sklearn.metrics import accuracy_score
print("πŸš€ Starting MNIST Digit Classifier...")
# Train the model once at import time. The rest of the module reads the
# globals set here: `model` (None if training failed), `binarizer`, and
# `accuracy` (shown in the UI).
N_TRAIN = 2000  # rows used to fit the model; kept small for a fast start-up
N_EVAL = 1000   # the following rows, never seen in training, used for scoring
try:
    print("πŸ”„ Loading MNIST dataset...")
    mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
    X_all, y_all = mnist["data"], mnist["target"]
    X, y = X_all[:N_TRAIN], y_all[:N_TRAIN].astype(int)
    X_eval = X_all[N_TRAIN:N_TRAIN + N_EVAL]
    y_eval = y_all[N_TRAIN:N_TRAIN + N_EVAL].astype(int)
    print("πŸ”„ Training Bernoulli Naive Bayes...")
    # Pixels above 127 become 1 and the rest 0, giving BernoulliNB binary features.
    binarizer = Binarizer(threshold=127.0)
    X_bin = binarizer.fit_transform(X)
    model = BernoulliNB()
    model.fit(X_bin, y)
    # Score on rows the model never saw. Accuracy measured on the training
    # rows themselves overstates how well the model handles new digits.
    y_pred = model.predict(binarizer.transform(X_eval))
    accuracy = accuracy_score(y_eval, y_pred)
    print(f"βœ… Model trained! Held-out accuracy: {accuracy*100:.2f}%")
except Exception as e:
    print(f"❌ Training failed: {e}")
    model = None
    binarizer = Binarizer(threshold=127.0)
    # Nothing was evaluated, so report 0% rather than an invented figure.
    accuracy = 0.0
def preprocess_image(image):
    """Turn an uploaded picture into the model's binarized 28x28 feature row.

    The steps are:
    1. Coerce the input to a NumPy array.
    2. Average any colour channels down to one.
    3. Resize to 28x28 with Pillow.
    4. Invert the intensities.
    5. Binarize with the module-level ``binarizer``.

    Returns:
        ``(features, pixels)``. ``features`` is the binarized one-row matrix
        passed to the model. ``pixels`` is the inverted 28x28 array, kept for
        display. Any failure is printed and reported as ``(None, None)``.
    """
    try:
        pixels = image if isinstance(image, np.ndarray) else np.array(image)

        # Multi-channel input: average the channels into a single gray plane.
        if pixels.ndim == 3:
            pixels = pixels.mean(axis=2)

        thumbnail = Image.fromarray(pixels.astype('uint8')).resize((28, 28))
        # Uploads are expected as dark ink on a light background, whereas
        # MNIST digits are light strokes on black, so flip the intensities.
        inverted = 255 - np.array(thumbnail)

        features = binarizer.transform(inverted.reshape(1, -1))
        return features, inverted
    except Exception as err:
        print(f"Preprocessing error: {err}")
        return None, None
def _render_prediction_plot(processed_array, prediction, classes, probabilities):
    """Render the processed input next to a per-digit probability bar chart, as a PIL image."""
    # Use the object-oriented Figure API rather than pyplot. Concurrent
    # requests then share no global "current figure", and no figure stays
    # registered with pyplot (and leaks) if rendering raises part-way.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(12, 4))
    ax1, ax2 = fig.subplots(1, 2)

    # Left panel: the image the model actually saw.
    ax1.imshow(processed_array, cmap='gray')
    ax1.set_title(f'Processed Image\nPrediction: {prediction}')
    ax1.axis('off')

    # Right panel: one bar per class label, with the predicted one highlighted.
    colors = ['green' if digit == prediction else 'blue' for digit in classes]
    bars = ax2.bar(classes, probabilities, color=colors, alpha=0.7)
    ax2.set_xlabel('Digits')
    ax2.set_ylabel('Probability')
    ax2.set_title('Prediction Probabilities')
    ax2.set_xticks(range(10))
    ax2.set_ylim(0, 1)

    # Label only bars tall enough for the text to stay readable.
    for bar, prob in zip(bars, probabilities):
        height = bar.get_height()
        if height > 0.1:
            ax2.text(bar.get_x() + bar.get_width() / 2., height,
                     f'{prob:.2f}', ha='center', va='bottom', fontsize=9)
    fig.tight_layout()

    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    buf.seek(0)
    return Image.open(buf)


def _format_result_text(prediction, confidence, top_3):
    """Build the Markdown summary: predicted digit, its confidence, and the top-3 ranking."""
    parts = [
        f"🎯 **Predicted Digit: {prediction}**\n\n",
        f"πŸ“Š **Confidence: {confidence*100:.2f}%**\n\n",
        "πŸ† **Top 3 Predictions:**\n",
    ]
    parts.extend(
        f" {rank}. Digit {digit}: {prob*100:.2f}%\n"
        for rank, (digit, prob) in enumerate(top_3, start=1)
    )
    return "".join(parts)


def predict_digit(image):
    """Classify an uploaded digit image and build both UI outputs.

    Args:
        image: The value from the ``gr.Image`` input (a NumPy array, since
            that component uses ``type="numpy"``), or ``None`` when nothing
            has been uploaded.

    Returns:
        A ``(markdown_text, plot_image)`` pair. ``plot_image`` is a PIL image
        of the processed input beside a probability bar chart. It is ``None``
        whenever a user-facing message (no input, model not loaded,
        processing error) is returned instead of a prediction.
    """
    if image is None:
        return "Please draw a digit (0-9) first! ✏️", None
    try:
        # Check for the model before doing any preprocessing work.
        if model is None:
            return "Model not loaded. Please wait... ⏳", None

        processed_image, processed_array = preprocess_image(image)
        if processed_image is None:
            return "Error processing image. Please try again. πŸ”„", None

        prediction = model.predict(processed_image)[0]
        probabilities = model.predict_proba(processed_image)[0]
        # predict_proba columns follow model.classes_, so map positions back
        # to digit labels through it instead of assuming column i == digit i.
        classes = np.asarray(model.classes_)
        confidence = probabilities[classes == prediction][0]
        top_3_indices = np.argsort(probabilities)[-3:][::-1]
        top_3 = list(zip(classes[top_3_indices], probabilities[top_3_indices]))

        plot_image = _render_prediction_plot(
            processed_array, prediction, classes, probabilities)
        return _format_result_text(prediction, confidence, top_3), plot_image
    except Exception as e:
        return f"❌ Error: {str(e)}", None
# Build the Gradio interface. Gradio Blocks lays components out in the order
# they are created inside the nested `with` contexts, so statement order and
# nesting here define the page layout.
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="MNIST Digit Classifier - Bernoulli Naive Bayes"
) as demo:
    # Page header. `accuracy` is the module-level value set by the training code.
    gr.Markdown(f"""
# ✍️ MNIST Handwritten Digit Classifier
## πŸ€– Bernoulli Naive Bayes | Accuracy: {accuracy*100:.2f}%
**Upload an image of a digit (0-9) and see the AI prediction!**
""")
    with gr.Row():
        # Left column: image input plus the Clear / Predict buttons.
        with gr.Column(scale=1):
            gr.Markdown("### πŸ“ Upload Image")
            # Image input created without a `sources` argument, so the
            # component's default sources apply. type="numpy" means
            # predict_digit receives the image as a NumPy array.
            image_input = gr.Image(
                label="Upload digit image (0-9)",
                type="numpy",
                height=300,
                width=300
            )
            with gr.Row():
                clear_btn = gr.Button("🧹 Clear")
                predict_btn = gr.Button("πŸ” Predict Digit", variant="primary")
        # Right column: Markdown result text and the probability chart image.
        with gr.Column(scale=1):
            gr.Markdown("### πŸ“Š Prediction Results")
            output_text = gr.Markdown(
                value="**Upload an image of a digit and click Predict!**"
            )
            gr.Markdown("### πŸ“ˆ Visualization")
            output_plot = gr.Image(
                label="Probability Distribution",
                height=300
            )
    # Usage instructions (the app takes uploaded images, not an in-page canvas).
    gr.Markdown("### πŸ’‘ How to use:")
    gr.Markdown("""
1. **Draw a digit** on paper or using any drawing app
2. **Save as image** (PNG/JPG format)
3. **Upload here** using the upload button above
4. **Click Predict** to see results
**Tips:**
- Draw clear, centered digits
- Use black ink on white background
- Make digits large and clear
""")
    gr.Markdown("---")
    # Model summary footer; repeats the module-level `accuracy` value.
    gr.Markdown(f"""
**Model Information:**
- Algorithm: Bernoulli Naive Bayes
- Dataset: MNIST Handwritten Digits
- Accuracy: {accuracy*100:.2f}%
- Input: 28Γ—28 grayscale images
""")
    # Predict feeds the uploaded image to predict_digit, which returns
    # (markdown_text, plot_image) for the two output components.
    predict_btn.click(
        fn=predict_digit,
        inputs=image_input,
        outputs=[output_text, output_plot]
    )
    # Clear resets the input, replaces the result text, and empties the plot;
    # the lambda returns one value per entry in `outputs`, in the same order.
    clear_btn.click(
        fn=lambda: [None, "**Cleared! Upload a new image.**", None],
        outputs=[image_input, output_text, output_plot]
    )
# Start the web server only when this file is run directly, not when imported.
if __name__ == "__main__":
    # Bind to every network interface (not only localhost) on port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)