"""Gradio demo: classify uploaded handwritten digits with Bernoulli Naive Bayes.

At import time the script downloads MNIST (``mnist_784`` from OpenML), trains a
BernoulliNB model on the first ``N_TRAIN`` samples, scores it on a held-out
slice, and builds a Gradio Blocks UI (``demo``) around the model.
"""

import io

import gradio as gr
import numpy as np
from PIL import Image
from sklearn.datasets import fetch_openml
from sklearn.metrics import accuracy_score
from sklearn.naive_bayes import BernoulliNB
from sklearn.preprocessing import Binarizer

import matplotlib

# Render off-screen: Gradio runs handlers in worker threads, where interactive
# GUI backends are unsafe, and a server usually has no display at all.
matplotlib.use("Agg")

import matplotlib.pyplot as plt  # noqa: E402  (backend must be chosen first)

# Number of leading MNIST rows used for training (kept small for fast start-up).
N_TRAIN = 2000
# Number of trailing MNIST rows used to measure accuracy on unseen data.
N_TEST = 1000
# Pixel intensity (0-255) above which a pixel counts as "ink".
PIXEL_THRESHOLD = 127.0

print("🚀 Starting MNIST Digit Classifier...")


def _format_accuracy(value):
    """Render an accuracy fraction as a percentage, or ``"N/A"`` if unknown."""
    return "N/A" if value is None else f"{value * 100:.2f}%"


def _train_model():
    """Fetch MNIST, fit the classifier, and score it on held-out rows.

    Returns:
        ``(model, binarizer, accuracy)``. On any failure (for example, OpenML
        unreachable) returns ``(None, binarizer, None)`` so the UI can still
        start and tell the user the model is unavailable.
    """
    binarizer = Binarizer(threshold=PIXEL_THRESHOLD)
    try:
        print("🔄 Loading MNIST dataset...")
        mnist = fetch_openml("mnist_784", version=1, as_frame=False, parser="auto")
        data = mnist["data"]
        target = mnist["target"].astype(int)

        print("🔄 Training Bernoulli Naive Bayes...")
        X_train = binarizer.fit_transform(data[:N_TRAIN])
        model = BernoulliNB()
        model.fit(X_train, target[:N_TRAIN])

        # Score on rows the model never saw. Scoring on the training rows
        # overstates the accuracy shown in the UI. max() keeps the slices
        # disjoint even if the dataset were unexpectedly small.
        test_start = max(N_TRAIN, len(data) - N_TEST)
        y_test = target[test_start:]
        accuracy = None
        if len(y_test):
            X_test = binarizer.transform(data[test_start:])
            accuracy = accuracy_score(y_test, model.predict(X_test))
        print(f"✅ Model trained! Test accuracy: {_format_accuracy(accuracy)}")
        return model, binarizer, accuracy
    except Exception as e:
        # Best effort: keep the app usable (it reports "model not loaded").
        print(f"❌ Training failed: {e}")
        return None, binarizer, None


model, binarizer, accuracy = _train_model()


def preprocess_image(image):
    """Convert an uploaded image into the binarized 28x28 MNIST layout.

    Args:
        image: NumPy array (H x W grayscale or H x W x C colour), or anything
            ``np.asarray`` accepts, such as a PIL image.

    Returns:
        ``(features, pixels)``. ``features`` is the 1 x 784 binary row fed to
        the model. ``pixels`` is the inverted 28 x 28 uint8 image used for
        display. Returns ``(None, None)`` if the image could not be processed.
    """
    try:
        image_array = np.asarray(image)

        # Collapse colour to grayscale. Only the first three channels are
        # averaged, so an alpha channel (RGBA uploads) does not skew the result.
        if image_array.ndim == 3:
            image_array = image_array[..., :3].mean(axis=2)

        # Clip before casting; out-of-range values would otherwise wrap around.
        pixels = np.clip(image_array, 0, 255).astype(np.uint8)
        image_array = np.array(Image.fromarray(pixels).resize((28, 28)))

        # MNIST has white digits on a black background, while uploads are
        # expected to be dark ink on a light background, so invert.
        image_array = 255 - image_array

        image_bin = binarizer.transform(image_array.reshape(1, -1))
        return image_bin, image_array
    except Exception as e:
        print(f"Preprocessing error: {e}")
        return None, None


def _render_plot(processed_array, classes, probabilities, prediction):
    """Draw the processed input next to a bar chart of class probabilities.

    Args:
        processed_array: 28 x 28 image shown on the left.
        classes: Class labels, aligned with ``probabilities``.
        probabilities: Per-class probabilities for the single input.
        prediction: Predicted label; its bar is drawn in green.

    Returns:
        The rendered chart as a PIL image.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    try:
        ax1.imshow(processed_array, cmap="gray")
        ax1.set_title(f"Processed Image\nPrediction: {prediction}")
        ax1.axis("off")

        colors = ["green" if digit == prediction else "blue" for digit in classes]
        bars = ax2.bar(classes, probabilities, color=colors, alpha=0.7)
        ax2.set_xlabel("Digits")
        ax2.set_ylabel("Probability")
        ax2.set_title("Prediction Probabilities")
        ax2.set_xticks(range(10))
        ax2.set_ylim(0, 1)

        # Label only bars tall enough for the text to be readable.
        for bar, prob in zip(bars, probabilities):
            height = bar.get_height()
            if height > 0.1:
                ax2.text(
                    bar.get_x() + bar.get_width() / 2.0,
                    height,
                    f"{prob:.2f}",
                    ha="center",
                    va="bottom",
                    fontsize=9,
                )

        # Operate on this figure explicitly: plt.tight_layout()/plt.savefig()
        # target the *current* figure, which may belong to a concurrent request.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
    finally:
        # pyplot keeps every open figure alive, so always close this one,
        # even if drawing failed.
        plt.close(fig)

    buf.seek(0)
    plot_image = Image.open(buf)
    plot_image.load()  # decode now rather than lazily on first access
    return plot_image


def predict_digit(image):
    """Classify an uploaded digit image for the Gradio UI.

    Args:
        image: Array from the Gradio ``Image`` component, or ``None`` when
            nothing has been uploaded.

    Returns:
        ``(markdown_text, plot_image)``. ``plot_image`` is ``None`` whenever no
        prediction could be made; the text then explains why.
    """
    if image is None:
        return "Please draw a digit (0-9) first! ✏️", None
    # Check before doing any work: without a model there is nothing to predict.
    if model is None:
        return "Model not loaded. Please wait... ⏳", None

    try:
        processed_image, processed_array = preprocess_image(image)
        if processed_image is None:
            return "Error processing image. Please try again. 🔄", None

        probabilities = model.predict_proba(processed_image)[0]
        # Column i of predict_proba belongs to model.classes_[i]; map through
        # classes_ instead of assuming column index == digit.
        classes = model.classes_
        best = int(np.argmax(probabilities))
        prediction = int(classes[best])
        top_3_indices = np.argsort(probabilities)[-3:][::-1]

        plot_image = _render_plot(processed_array, classes, probabilities, prediction)

        parts = [
            f"🎯 **Predicted Digit: {prediction}**\n\n",
            f"📊 **Confidence: {probabilities[best] * 100:.2f}%**\n\n",
            "🏆 **Top 3 Predictions:**\n",
        ]
        parts.extend(
            f" {rank}. Digit {int(classes[idx])}: {probabilities[idx] * 100:.2f}%\n"
            for rank, idx in enumerate(top_3_indices, start=1)
        )
        return "".join(parts), plot_image
    except Exception as e:
        # UI boundary: surface the error to the user instead of crashing.
        return f"❌ Error: {str(e)}", None


# Gradio interface. gr.Markdown dedents its value, so the indented
# triple-quoted strings below render as normal Markdown.
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="MNIST Digit Classifier - Bernoulli Naive Bayes",
) as demo:
    gr.Markdown(f"""
    # ✏️ MNIST Handwritten Digit Classifier
    ## 🤖 Bernoulli Naive Bayes | Test Accuracy: {_format_accuracy(accuracy)}

    **Upload an image of a digit (0-9) and see the AI prediction!**
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📁 Upload Image")
            image_input = gr.Image(
                label="Upload digit image (0-9)",
                type="numpy",
                height=300,
                width=300,
            )
            with gr.Row():
                clear_btn = gr.Button("🧹 Clear")
                predict_btn = gr.Button("🔍 Predict Digit", variant="primary")

        with gr.Column(scale=1):
            gr.Markdown("### 📊 Prediction Results")
            output_text = gr.Markdown(
                value="**Upload an image of a digit and click Predict!**"
            )
            gr.Markdown("### 📈 Visualization")
            output_plot = gr.Image(
                label="Probability Distribution",
                height=300,
            )

    gr.Markdown("### 💡 How to use:")
    gr.Markdown("""
    1. **Draw a digit** on paper or using any drawing app
    2. **Save as image** (PNG/JPG format)
    3. **Upload here** using the upload button above
    4. **Click Predict** to see results

    **Tips:**
    - Draw clear, centered digits
    - Use black ink on white background
    - Make digits large and clear
    """)

    gr.Markdown("---")
    gr.Markdown(f"""
    **Model Information:**
    - Algorithm: Bernoulli Naive Bayes
    - Dataset: MNIST Handwritten Digits
    - Test Accuracy (held-out MNIST samples): {_format_accuracy(accuracy)}
    - Input: 28×28 grayscale images
    """)

    predict_btn.click(
        fn=predict_digit,
        inputs=image_input,
        outputs=[output_text, output_plot],
    )
    clear_btn.click(
        fn=lambda: [None, "**Cleared! Upload a new image.**", None],
        outputs=[image_input, output_text, output_plot],
    )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)