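# MNIST handwritten-digit classifier served with Gradio.
# (Hosting-page header residue, kept as a comment: "Spaces: Sleeping",
#  file size 6,940 bytes.)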
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from PIL import Image
from sklearn.datasets import fetch_openml
from sklearn.metrics import accuracy_score
from sklearn.naive_bayes import BernoulliNB
from sklearn.preprocessing import Binarizer
print("Starting MNIST Digit Classifier...")

# Number of leading MNIST rows used for training (kept small so the app
# starts quickly; the model is trained on exactly the same rows as before).
N_TRAIN = 2000
# In OpenML's "mnist_784" the first 60,000 rows are the standard MNIST
# training split and the remaining rows the standard test split.
MNIST_TEST_START = 60000

# Train the model once at import time so the UI callbacks can use it.
try:
    print("Loading MNIST dataset...")
    mnist = fetch_openml("mnist_784", version=1, as_frame=False, parser="auto")
    # Labels are cast to int (the original code did the same, so they
    # presumably arrive as strings).
    labels = mnist["target"].astype(int)
    X, y = mnist["data"][:N_TRAIN], labels[:N_TRAIN]
    X_test, y_test = mnist["data"][MNIST_TEST_START:], labels[MNIST_TEST_START:]

    print("Training Bernoulli Naive Bayes...")
    # Pixel values above 127 become 1 ("ink"), the rest 0.
    binarizer = Binarizer(threshold=127.0)
    X_bin = binarizer.fit_transform(X)
    model = BernoulliNB()
    model.fit(X_bin, y)

    # Score on held-out images: scoring the training rows overstates how
    # well the model does on digits it has never seen.
    y_pred = model.predict(binarizer.transform(X_test))
    accuracy = accuracy_score(y_test, y_pred)
    print(f"✅ Model trained! Test accuracy: {accuracy*100:.2f}%")
except Exception as e:
    print(f"❌ Training failed: {e}")
    model = None
    binarizer = Binarizer(threshold=127.0)
    # No model means no measured accuracy; NaN renders as "nan%" in the UI
    # instead of a made-up figure.
    accuracy = float("nan")
def preprocess_image(image):
    """Convert an uploaded digit image to the binarized MNIST input format.

    Steps:
    1. Convert the image to grayscale.
    2. Invert it if its background is light, so the digit is light on dark as in MNIST.
    3. Crop it to the digit.
    4. Scale it into a 20x20 box, keeping the aspect ratio.
    5. Centre it on a 28x28 canvas.
    6. Binarize it with the module-level ``binarizer``.

    Args:
        image: The uploaded image: a NumPy array (H x W grayscale or
            H x W x C colour) or anything ``np.asarray`` accepts.

    Returns:
        ``(image_bin, image_array)``. ``image_bin`` is the binarized 1 x 784
        feature row for the model. ``image_array`` is the 28x28 uint8 image it
        was built from. If preprocessing fails, the error is printed and
        ``(None, None)`` is returned.
    """
    try:
        pixels = np.asarray(image, dtype=np.float64)
        if pixels.ndim == 3:
            # Average the colour channels only; an alpha channel would skew
            # the grey level.
            pixels = pixels[..., :3].mean(axis=2)
        if pixels.ndim != 2 or pixels.size == 0:
            raise ValueError(f"unsupported image shape {pixels.shape}")

        # MNIST has white digits on a black background. Judge the background
        # from the border pixels and invert only when it is light, so
        # white-on-black uploads are not flipped to the wrong polarity.
        border = np.concatenate(
            (pixels[0], pixels[-1], pixels[:, 0], pixels[:, -1])
        )
        if border.mean() > 127:
            pixels = 255.0 - pixels

        # Crop to the bounding box of the ink, so a small digit on a large
        # canvas is not shrunk to a few pixels by the resize below.
        ink = pixels > 127
        rows = np.flatnonzero(ink.any(axis=1))
        cols = np.flatnonzero(ink.any(axis=0))
        if rows.size:
            pixels = pixels[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]

        # Pad to a square (keeps the aspect ratio), scale into a 20x20 box
        # and centre it in a 28x28 frame, as MNIST digits are laid out.
        height, width = pixels.shape
        side = max(height, width)
        square = np.zeros((side, side))
        top, left = (side - height) // 2, (side - width) // 2
        square[top:top + height, left:left + width] = pixels
        digit = Image.fromarray(np.clip(square, 0, 255).astype(np.uint8))
        image_array = np.zeros((28, 28), dtype=np.uint8)
        image_array[4:24, 4:24] = np.asarray(digit.resize((20, 20)))

        # Flatten to one 784-feature row and binarize like the training data.
        image_bin = binarizer.transform(image_array.reshape(1, -1))
        return image_bin, image_array
    except Exception as e:
        print(f"Preprocessing error: {e}")
        return None, None
def predict_digit(image):
    """Classify an uploaded digit image and build the result panel.

    Args:
        image: Image from the Gradio input (a NumPy array), or ``None`` when
            nothing has been uploaded.

    Returns:
        ``(result_markdown, plot_image)``. ``plot_image`` is a PIL image
        showing the processed input next to the per-class probabilities.
        When there is no image, no model, or an error, returns
        ``(message, None)`` instead.
    """
    if image is None:
        return "Please draw a digit (0-9) first! ✏️", None
    # Check the model first so no preprocessing work is wasted without one.
    if model is None:
        return "Model not loaded. Please wait... ⏳", None
    try:
        processed_image, processed_array = preprocess_image(image)
        if processed_image is None:
            return "Error processing image. Please try again.", None

        prediction = model.predict(processed_image)[0]
        probabilities = model.predict_proba(processed_image)[0]
        # predict_proba columns follow model.classes_. Map through it rather
        # than assuming column i is digit i: a digit missing from the
        # training subset would shift every column.
        classes = model.classes_
        confidence = probabilities[classes == prediction][0]
        top_3_indices = np.argsort(probabilities)[-3:][::-1]

        # Use the object-oriented Figure API instead of pyplot. pyplot keeps
        # figures in global state: it leaks them whenever plt.close() is
        # skipped (e.g. by an exception), and it is not meant to be shared
        # by concurrent server requests.
        fig = Figure(figsize=(12, 4))
        ax1, ax2 = fig.subplots(1, 2)

        # Left: the 28x28 image actually fed to the model.
        ax1.imshow(processed_array, cmap="gray")
        ax1.set_title(f"Processed Image\nPrediction: {prediction}")
        ax1.axis("off")

        # Right: probability per class, with the predicted class highlighted.
        colors = ["green" if digit == prediction else "blue" for digit in classes]
        bars = ax2.bar(classes, probabilities, color=colors, alpha=0.7)
        ax2.set_xlabel("Digits")
        ax2.set_ylabel("Probability")
        ax2.set_title("Prediction Probabilities")
        ax2.set_xticks(range(10))
        ax2.set_ylim(0, 1)
        # Label only bars tall enough for the text to be readable.
        for bar, prob in zip(bars, probabilities):
            if prob > 0.1:
                ax2.text(
                    bar.get_x() + bar.get_width() / 2.0,
                    prob,
                    f"{prob:.2f}",
                    ha="center",
                    va="bottom",
                    fontsize=9,
                )
        fig.tight_layout()

        # Render to PNG in memory and hand Gradio a PIL image.
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
        buf.seek(0)
        plot_image = Image.open(buf)
        plot_image.load()  # decode now rather than lazily from the buffer

        result_text = f"🎯 **Predicted Digit: {prediction}**\n\n"
        result_text += f"**Confidence: {confidence*100:.2f}%**\n\n"
        result_text += "**Top 3 Predictions:**\n"
        for rank, idx in enumerate(top_3_indices, start=1):
            result_text += (
                f" {rank}. Digit {classes[idx]}: {probabilities[idx]*100:.2f}%\n"
            )
        return result_text, plot_image
    except Exception as e:
        return f"❌ Error: {e}", None
# Build the Gradio interface.
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="MNIST Digit Classifier - Bernoulli Naive Bayes",
) as demo:
    # Header; `accuracy` is the module-level value set by the training step.
    gr.Markdown(f"""
# ✏️ MNIST Handwritten Digit Classifier
## 🤖 Bernoulli Naive Bayes | Accuracy: {accuracy*100:.2f}%

**Upload an image of a digit (0-9) and see the AI prediction!**
""")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Upload Image")
            # Plain upload component; predict_digit receives a NumPy array.
            image_input = gr.Image(
                label="Upload digit image (0-9)",
                type="numpy",
                height=300,
                width=300,
            )
            with gr.Row():
                clear_btn = gr.Button("🧹 Clear")
                predict_btn = gr.Button("Predict Digit", variant="primary")

        with gr.Column(scale=1):
            gr.Markdown("### Prediction Results")
            output_text = gr.Markdown(
                value="**Upload an image of a digit and click Predict!**"
            )
            gr.Markdown("### Visualization")
            output_plot = gr.Image(
                label="Probability Distribution",
                height=300,
            )

    # Usage instructions.
    gr.Markdown("### 💡 How to use:")
    gr.Markdown("""
1. **Draw a digit** on paper or using any drawing app
2. **Save as image** (PNG/JPG format)
3. **Upload here** using the upload button above
4. **Click Predict** to see results

**Tips:**
- Draw clear, centered digits
- Use black ink on white background
- Make digits large and clear
""")
    gr.Markdown("---")
    gr.Markdown(f"""
**Model Information:**

- Algorithm: Bernoulli Naive Bayes
- Dataset: MNIST Handwritten Digits
- Accuracy: {accuracy*100:.2f}%
- Input: 28×28 grayscale images
""")

    # Event wiring (must be registered inside the Blocks context).
    predict_btn.click(
        fn=predict_digit,
        inputs=image_input,
        outputs=[output_text, output_plot],
    )
    # Reset the input image, the result text and the plot.
    clear_btn.click(
        fn=lambda: [None, "**Cleared! Upload a new image.**", None],
        outputs=[image_input, output_text, output_plot],
    )
# Launch the app when run as a script, listening on all interfaces
# (0.0.0.0) at port 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)