Ali-Hyder2019 commited on
Commit
16cddb5
Β·
verified Β·
1 Parent(s): 892ad19

Upload 3 files

Browse files
app (1).py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import numpy as np
4
+ from PIL import Image
5
+ import joblib
6
+ import matplotlib.pyplot as plt
7
+ import io
8
+
9
+ # Load model
10
+ try:
11
+ model_data = joblib.load("mnist_bernoulli_nb_model.joblib")
12
+ model = model_data["model"]
13
+ binarizer = model_data["binarizer"]
14
+ accuracy = model_data["accuracy"]
15
+ print("βœ… Model loaded successfully!")
16
+ except Exception as e:
17
+ print(f"❌ Error loading model: {e}")
18
+ # Fallback: create a simple model
19
+ from sklearn.naive_bayes import BernoulliNB
20
+ from sklearn.preprocessing import Binarizer
21
+ from sklearn.datasets import fetch_openml
22
+ mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
23
+ X, y = mnist["data"][:1000], mnist["target"][:1000].astype(int)
24
+ binarizer = Binarizer(threshold=127.0)
25
+ X_bin = binarizer.fit_transform(X)
26
+ model = BernoulliNB()
27
+ model.fit(X_bin, y)
28
+ accuracy = 0.83
29
+
30
+ def preprocess_drawing(image):
31
+ """Convert drawing to MNIST format"""
32
+ try:
33
+ # Convert to grayscale
34
+ if image.mode != 'L':
35
+ image = image.convert('L')
36
+
37
+ # Resize to 28x28
38
+ image = image.resize((28, 28))
39
+
40
+ # Invert colors and normalize
41
+ image_array = 255 - np.array(image)
42
+ image_flat = image_array.flatten()
43
+
44
+ # Binarize
45
+ image_bin = binarizer.transform([image_flat])
46
+
47
+ return image_bin, image_array.reshape(28, 28)
48
+ except Exception as e:
49
+ print(f"Preprocessing error: {e}")
50
+ return None, None
51
+
52
+ def predict_digit(image):
53
+ """Predict digit from drawing"""
54
+ if image is None:
55
+ return "Please draw a digit (0-9) first!", None
56
+
57
+ try:
58
+ processed_image, processed_array = preprocess_drawing(image)
59
+
60
+ if processed_image is None:
61
+ return "Error processing image. Please try again.", None
62
+
63
+ # Predict
64
+ prediction = model.predict(processed_image)[0]
65
+ probabilities = model.predict_proba(processed_image)[0]
66
+
67
+ # Create visualization
68
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
69
+
70
+ # Show processed image
71
+ ax1.imshow(processed_array, cmap='gray')
72
+ ax1.set_title(f'Processed Image\nPrediction: {prediction}')
73
+ ax1.axis('off')
74
+
75
+ # Show probabilities
76
+ colors = ['green' if i == prediction else 'blue' for i in range(10)]
77
+ ax2.bar(range(10), probabilities, color=colors, alpha=0.7)
78
+ ax2.set_xlabel('Digits')
79
+ ax2.set_ylabel('Probability')
80
+ ax2.set_title('Prediction Probabilities')
81
+ ax2.set_xticks(range(10))
82
+ ax2.set_ylim(0, 1)
83
+
84
+ plt.tight_layout()
85
+
86
+ # Convert to image
87
+ buf = io.BytesIO()
88
+ plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
89
+ buf.seek(0)
90
+ plot_image = Image.open(buf)
91
+ plt.close()
92
+
93
+ # Format results
94
+ result_text = f"🎯 **Predicted Digit: {prediction}**\n\n"
95
+ result_text += f"πŸ“Š **Confidence: {probabilities[prediction]*100:.2f}%**\n\n"
96
+ result_text += "πŸ† **Top 3 Predictions:**\n"
97
+ top_3_indices = np.argsort(probabilities)[-3:][::-1]
98
+ for i, digit in enumerate(top_3_indices):
99
+ result_text += f" {i+1}. Digit {digit}: {probabilities[digit]*100:.2f}%\n"
100
+
101
+ return result_text, plot_image
102
+
103
+ except Exception as e:
104
+ return f"❌ Prediction error: {str(e)}", None
105
+
106
+ # Create interface
107
+ with gr.Blocks(theme=gr.themes.Soft(), title="MNIST Digit Classifier") as demo:
108
+ gr.Markdown(f"""
109
+ # ✍️ MNIST Handwritten Digit Classifier
110
+ ## πŸ€– Bernoulli Naive Bayes | Accuracy: {accuracy*100:.2f}%
111
+
112
+ **Draw a digit (0-9) and see AI prediction!**
113
+ """)
114
+
115
+ with gr.Row():
116
+ with gr.Column():
117
+ sketchpad = gr.Sketchpad(
118
+ label="🎨 Draw Digit (0-9)",
119
+ shape=(280, 280),
120
+ brush_radius=12,
121
+ type="pil"
122
+ )
123
+ with gr.Row():
124
+ clear_btn = gr.Button("🧹 Clear")
125
+ predict_btn = gr.Button("πŸ” Predict", variant="primary")
126
+
127
+ with gr.Column():
128
+ output_text = gr.Markdown("Draw a digit and click Predict!")
129
+ output_plot = gr.Image(label="πŸ“Š Visualization", height=300)
130
+
131
+ # Button actions
132
+ predict_btn.click(
133
+ predict_digit,
134
+ inputs=sketchpad,
135
+ outputs=[output_text, output_plot]
136
+ )
137
+
138
+ clear_btn.click(
139
+ lambda: [None, "Canvas cleared! Draw a digit...", None],
140
+ outputs=[sketchpad, output_text, output_plot]
141
+ )
142
+
143
+ gr.Markdown("---")
144
+ gr.Markdown("""
145
+ **Model Info:** Bernoulli Naive Bayes | MNIST Dataset | 28Γ—28 pixels
146
+ """)
147
+
148
+ if __name__ == "__main__":
149
+ demo.launch(server_name="0.0.0.0", server_port=7860)
mnist_bernoulli_nb_model.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4bd47c23b254f34b1f2c3723d96525272b9b32efca96fa79a6fb07db0a04d5f
3
+ size 126598
requirements (1).txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ scikit-learn>=1.3.0
3
+ pandas>=2.0.0
4
+ numpy>=1.24.0
5
+ pillow>=10.0.0
6
+ joblib>=1.3.0
7
+ matplotlib>=3.7.0