Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import io | |
| from sklearn.datasets import fetch_openml | |
| from sklearn.naive_bayes import BernoulliNB | |
| from sklearn.preprocessing import Binarizer | |
| from sklearn.metrics import accuracy_score | |
print("π Starting MNIST Digit Classifier...")

# Rows of mnist_784 used for training; kept small so the app starts quickly.
N_TRAIN_SAMPLES = 2000
# mnist_784 is conventionally split into rows [0, 60000) for training and
# [60000, 70000) for testing. Accuracy is measured on the start of that test
# range so the reported figure reflects digits the model has never seen
# (scoring the training rows, as before, overstates it).
TEST_SPLIT_START = 60000
N_EVAL_SAMPLES = 1000

# Train the model directly at import time.
try:
    print("π Loading MNIST dataset...")
    mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
    X = mnist["data"][:N_TRAIN_SAMPLES]
    y = mnist["target"][:N_TRAIN_SAMPLES].astype(int)
    eval_stop = TEST_SPLIT_START + N_EVAL_SAMPLES
    X_eval = mnist["data"][TEST_SPLIT_START:eval_stop]
    y_eval = mnist["target"][TEST_SPLIT_START:eval_stop].astype(int)

    print("π Training Bernoulli Naive Bayes...")
    # BernoulliNB models binary features, so pixel intensities are thresholded
    # at 127 (values above become 1, the rest 0).
    binarizer = Binarizer(threshold=127.0)
    X_bin = binarizer.fit_transform(X)
    model = BernoulliNB()
    model.fit(X_bin, y)

    # Calculate accuracy on held-out rows, not on the training set.
    y_pred = model.predict(binarizer.transform(X_eval))
    accuracy = accuracy_score(y_eval, y_pred)
    print(f"β Model trained! Accuracy: {accuracy*100:.2f}%")
except Exception as e:
    print(f"β Training failed: {e}")
    model = None
    # Keep a binarizer around so preprocessing still runs without a model.
    binarizer = Binarizer(threshold=127.0)
    # Hard-coded placeholder so the UI text can still be formatted;
    # this is NOT a measured value.
    accuracy = 0.83
def preprocess_image(image):
    """Convert an uploaded digit image into the binarized MNIST feature format.

    Steps:
    1. Convert to greyscale, ignoring any alpha channel.
    2. Pad to a square so the resize does not distort the digit.
    3. Resize to 28x28.
    4. Invert so the digit is bright on a dark background, as in MNIST.
    5. Binarize with the module-level ``binarizer``.

    Args:
        image: A numpy array of shape (H, W) or (H, W, C), or anything
            ``np.array`` can convert (e.g. a PIL image). Assumed to be dark ink
            on a light background with 0-255 intensities; the inversion and
            the white padding below rely on that.

    Returns:
        ``(image_bin, image_array)``: the binarized features (one row of
        28*28 values) and the inverted 28x28 greyscale array they came from.
        Returns ``(None, None)`` if any step fails; the error is printed.
    """
    try:
        image_array = image if isinstance(image, np.ndarray) else np.array(image)

        if image_array.ndim == 3:
            # LA / RGBA inputs carry alpha as the last channel; averaging it
            # with the colour channels would distort the grey level.
            if image_array.shape[2] in (2, 4):
                image_array = image_array[..., :-1]
            image_array = np.mean(image_array, axis=2)

        gray = image_array.astype('uint8')

        # Pad with white (the expected paper background) to a square so the
        # 28x28 resize keeps the digit's aspect ratio, as MNIST does.
        height, width = gray.shape
        side = max(height, width)
        pad_y, pad_x = side - height, side - width
        gray = np.pad(
            gray,
            ((pad_y // 2, pad_y - pad_y // 2), (pad_x // 2, pad_x - pad_x // 2)),
            mode='constant',
            constant_values=255,
        )

        # Resize to 28x28
        pil_image = Image.fromarray(gray).resize((28, 28))
        image_array = np.array(pil_image)

        # Invert colors (MNIST has white digits on black background)
        image_array = 255 - image_array

        # Flatten and binarize
        image_flat = image_array.flatten()
        image_bin = binarizer.transform([image_flat])
        return image_bin, image_array
    except Exception as e:
        # Best-effort: the caller treats (None, None) as "could not process".
        print(f"Preprocessing error: {e}")
        return None, None
def _render_prediction_plot(processed_array, labels, probabilities, prediction):
    """Render the processed input next to a per-class probability bar chart.

    Uses ``matplotlib.figure.Figure`` directly instead of pyplot. Nothing is
    registered in pyplot's global figure manager, so concurrent requests
    cannot close each other's figures, and a figure abandoned by an exception
    is simply garbage-collected instead of leaking.

    Args:
        processed_array: The 28x28 image exactly as fed to the classifier.
        labels: Class labels, in the same order as ``probabilities``.
        probabilities: Per-class probabilities.
        prediction: The predicted class label (highlighted in green).

    Returns:
        A PIL image of the rendered PNG.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(12, 4))
    ax1, ax2 = fig.subplots(1, 2)

    # Left: the processed image.
    ax1.imshow(processed_array, cmap='gray')
    ax1.set_title(f'Processed Image\nPrediction: {prediction}')
    ax1.axis('off')

    # Right: one bar per class, the predicted class highlighted.
    colors = ['green' if label == prediction else 'blue' for label in labels]
    bars = ax2.bar(labels, probabilities, color=colors, alpha=0.7)
    ax2.set_xlabel('Digits')
    ax2.set_ylabel('Probability')
    ax2.set_title('Prediction Probabilities')
    ax2.set_xticks(labels)
    ax2.set_ylim(0, 1)

    # Label only bars tall enough for the text to be readable.
    for bar, prob in zip(bars, probabilities):
        height = bar.get_height()
        if height > 0.1:
            ax2.text(bar.get_x() + bar.get_width() / 2., height,
                     f'{prob:.2f}', ha='center', va='bottom', fontsize=9)

    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    buf.seek(0)
    return Image.open(buf)


def _format_prediction_text(prediction, confidence, top_labels, top_probs):
    """Build the Markdown summary: predicted digit, its confidence and the top 3."""
    result_text = f"π― **Predicted Digit: {prediction}**\n\n"
    result_text += f"π **Confidence: {confidence*100:.2f}%**\n\n"
    result_text += "π **Top 3 Predictions:**\n"
    for rank, (digit, prob) in enumerate(zip(top_labels, top_probs), start=1):
        result_text += f" {rank}. Digit {digit}: {prob*100:.2f}%\n"
    return result_text


def predict_digit(image):
    """Classify an uploaded digit image and build the UI outputs.

    Args:
        image: Value from the Gradio image component, or None when nothing
            has been uploaded.

    Returns:
        ``(markdown_text, plot_image)``. On any failure, the text describes
        the problem and ``plot_image`` is None.
    """
    if image is None:
        return "Please draw a digit (0-9) first! βοΈ", None
    try:
        processed_image, processed_array = preprocess_image(image)
        if processed_image is None:
            return "Error processing image. Please try again. π", None
        if model is None:
            return "Model not loaded. Please wait... β³", None

        # Make prediction
        prediction = model.predict(processed_image)[0]
        probabilities = model.predict_proba(processed_image)[0]

        # predict_proba columns follow model.classes_. Translate column
        # positions to labels instead of assuming label == column index.
        labels = model.classes_
        confidence = probabilities[np.flatnonzero(labels == prediction)[0]]

        # Top 3 columns by probability, highest first.
        top_3_indices = np.argsort(probabilities)[-3:][::-1]
        top_3_labels = labels[top_3_indices]
        top_3_probs = probabilities[top_3_indices]

        plot_image = _render_prediction_plot(
            processed_array, labels, probabilities, prediction
        )
        result_text = _format_prediction_text(
            prediction, confidence, top_3_labels, top_3_probs
        )
        return result_text, plot_image
    except Exception as e:
        return f"β Error: {str(e)}", None
def _reset_outputs():
    """Clear-button handler: empty the upload box, show a notice, drop the plot."""
    return [None, "**Cleared! Upload a new image.**", None]


# Gradio interface layout:
# - a header;
# - a row with two columns (upload + buttons on the left, text result +
#   probability chart on the right);
# - usage notes and model details;
# - the button event wiring.
with gr.Blocks(
    title="MNIST Digit Classifier - Bernoulli Naive Bayes",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown(f"""
# βοΈ MNIST Handwritten Digit Classifier
## π€ Bernoulli Naive Bayes | Accuracy: {accuracy*100:.2f}%
**Upload an image of a digit (0-9) and see the AI prediction!**
""")

    with gr.Row():
        # Left column: the image to classify and the action buttons.
        with gr.Column(scale=1):
            gr.Markdown("### π Upload Image")
            image_input = gr.Image(
                label="Upload digit image (0-9)",
                type="numpy",
                width=300,
                height=300,
            )
            with gr.Row():
                clear_btn = gr.Button("π§Ή Clear")
                predict_btn = gr.Button("π Predict Digit", variant="primary")

        # Right column: Markdown summary and probability chart.
        with gr.Column(scale=1):
            gr.Markdown("### π Prediction Results")
            output_text = gr.Markdown(value="**Upload an image of a digit and click Predict!**")
            gr.Markdown("### π Visualization")
            output_plot = gr.Image(label="Probability Distribution", height=300)

    # Usage notes
    gr.Markdown("### π‘ How to use:")
    gr.Markdown("""
1. **Draw a digit** on paper or using any drawing app
2. **Save as image** (PNG/JPG format)
3. **Upload here** using the upload button above
4. **Click Predict** to see results
**Tips:**
- Draw clear, centered digits
- Use black ink on white background
- Make digits large and clear
""")
    gr.Markdown("---")
    gr.Markdown(f"""
**Model Information:**
- Algorithm: Bernoulli Naive Bayes
- Dataset: MNIST Handwritten Digits
- Accuracy: {accuracy*100:.2f}%
- Input: 28Γ28 grayscale images
""")

    # Event wiring: Predict fills both outputs; Clear resets input and outputs.
    predict_btn.click(
        fn=predict_digit,
        inputs=image_input,
        outputs=[output_text, output_plot],
    )
    clear_btn.click(
        fn=_reset_outputs,
        outputs=[image_input, output_text, output_plot],
    )
def _serve():
    """Start the Gradio server on all network interfaces at port 7860."""
    demo.launch(server_name="0.0.0.0", server_port=7860)


# Only serve when executed as a script, not when imported.
if __name__ == "__main__":
    _serve()