Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,8 +4,9 @@ import torch.nn as nn
|
|
| 4 |
import torch.nn.functional as F
|
| 5 |
from PIL import Image
|
| 6 |
import numpy as np
|
|
|
|
| 7 |
|
| 8 |
-
# Define the network
|
| 9 |
class Net(nn.Module):
|
| 10 |
def __init__(self):
|
| 11 |
super(Net, self).__init__()
|
|
@@ -24,88 +25,72 @@ class Net(nn.Module):
|
|
| 24 |
x = self.fc2(x)
|
| 25 |
return F.log_softmax(x, dim=1)
|
| 26 |
|
| 27 |
-
# Load
|
| 28 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
network = Net().to(device)
|
| 30 |
network.load_state_dict(torch.load("model.pth", map_location=device))
|
| 31 |
network.eval()
|
| 32 |
|
| 33 |
-
# Prediction function
|
| 34 |
def predict_digit(image):
|
| 35 |
-
"""
|
| 36 |
-
Takes an image from the drawing pad and predicts the digit
|
| 37 |
-
"""
|
| 38 |
if image is None:
|
| 39 |
-
return
|
| 40 |
|
| 41 |
-
#
|
| 42 |
-
if isinstance(image, dict):
|
| 43 |
image = image['composite']
|
| 44 |
|
| 45 |
img = Image.fromarray(image).convert('L')
|
| 46 |
img = img.resize((28, 28))
|
| 47 |
|
| 48 |
-
# Convert to tensor
|
| 49 |
img_array = np.array(img, dtype=np.float32) / 255.0
|
| 50 |
img_tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0).to(device)
|
| 51 |
-
|
| 52 |
-
# Apply same normalization as training
|
| 53 |
img_tensor = (img_tensor - 0.1307) / 0.3081
|
| 54 |
|
| 55 |
-
# Predict
|
| 56 |
with torch.no_grad():
|
| 57 |
output = network(img_tensor)
|
| 58 |
-
|
| 59 |
-
pred = output.data.max(1
|
| 60 |
-
confidence = probabilities[pred].item() * 100
|
| 61 |
|
| 62 |
-
# Create
|
| 63 |
-
|
| 64 |
|
| 65 |
-
#
|
| 66 |
-
|
|
|
|
|
|
|
| 67 |
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
gr.Markdown(
|
|
|
|
| 77 |
|
| 78 |
with gr.Row():
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
canvas = gr.Sketchpad(
|
| 82 |
-
label="Canvas",
|
| 83 |
-
type="numpy",
|
| 84 |
-
interactive=True
|
| 85 |
-
)
|
| 86 |
-
with gr.Row():
|
| 87 |
-
clear_btn = gr.Button("Clear Canvas", size="sm")
|
| 88 |
-
submit_btn = gr.Button("Predict", variant="primary", size="sm")
|
| 89 |
-
|
| 90 |
-
with gr.Column(scale=1):
|
| 91 |
-
gr.Markdown("### Prediction Result")
|
| 92 |
-
output_text = gr.Markdown("Draw a digit to get started!")
|
| 93 |
-
|
| 94 |
-
# Use Plot instead of BarChart for Gradio 6.0
|
| 95 |
-
confidence_plot = gr.Plot(label="Confidence Scores")
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
confidences = [data[1] for data in chart_data]
|
| 107 |
-
|
| 108 |
-
colors = ['#FF6B6B' if i != int(chart_data[int(chart_data[0][0])][0]) else '#4CAF50'
|
| 109 |
-
for i in range(10)]
|
| 110 |
-
|
| 111 |
-
ax.bar(digits, confidences, color=colors, alpha
|
|
|
|
| 4 |
import torch.nn.functional as F
|
| 5 |
from PIL import Image
|
| 6 |
import numpy as np
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
|
| 9 |
+
# Define the network
|
| 10 |
class Net(nn.Module):
|
| 11 |
def __init__(self):
|
| 12 |
super(Net, self).__init__()
|
|
|
|
| 25 |
x = self.fc2(x)
|
| 26 |
return F.log_softmax(x, dim=1)
|
| 27 |
|
| 28 |
+
# Load model
|
| 29 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 30 |
network = Net().to(device)
|
| 31 |
network.load_state_dict(torch.load("model.pth", map_location=device))
|
| 32 |
network.eval()
|
| 33 |
|
|
|
|
| 34 |
def predict_digit(image):
|
| 35 |
+
"""Predict digit from image"""
|
|
|
|
|
|
|
| 36 |
if image is None:
|
| 37 |
+
return None
|
| 38 |
|
| 39 |
+
# Handle Gradio 6.0 dict format
|
| 40 |
+
if isinstance(image, dict):
|
| 41 |
image = image['composite']
|
| 42 |
|
| 43 |
img = Image.fromarray(image).convert('L')
|
| 44 |
img = img.resize((28, 28))
|
| 45 |
|
|
|
|
| 46 |
img_array = np.array(img, dtype=np.float32) / 255.0
|
| 47 |
img_tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0).to(device)
|
|
|
|
|
|
|
| 48 |
img_tensor = (img_tensor - 0.1307) / 0.3081
|
| 49 |
|
|
|
|
| 50 |
with torch.no_grad():
|
| 51 |
output = network(img_tensor)
|
| 52 |
+
probs = torch.exp(output[0])
|
| 53 |
+
pred = output.data.max(1)[1][0].item()
|
|
|
|
| 54 |
|
| 55 |
+
# Create visualization
|
| 56 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
|
| 57 |
|
| 58 |
+
# Plot 1: Drawn digit
|
| 59 |
+
ax1.imshow(img, cmap='gray')
|
| 60 |
+
ax1.set_title(f'Predicted: {int(pred)}', fontsize=16, fontweight='bold')
|
| 61 |
+
ax1.axis('off')
|
| 62 |
|
| 63 |
+
# Plot 2: Confidence bars
|
| 64 |
+
digits = list(range(10))
|
| 65 |
+
confidences = [float(probs[i].item() * 100) for i in range(10)]
|
| 66 |
+
colors = ['#4CAF50' if i == pred else '#FF6B6B' for i in range(10)]
|
| 67 |
+
|
| 68 |
+
ax2.bar(digits, confidences, color=colors, alpha=0.7) # ✅ FIXED: Added closing parenthesis
|
| 69 |
+
ax2.set_xlabel('Digit', fontsize=12)
|
| 70 |
+
ax2.set_ylabel('Confidence (%)', fontsize=12)
|
| 71 |
+
ax2.set_title('Confidence Scores', fontsize=12, fontweight='bold')
|
| 72 |
+
ax2.set_ylim([0, 100])
|
| 73 |
+
ax2.set_xticks(digits)
|
| 74 |
+
ax2.grid(axis='y', alpha=0.3)
|
| 75 |
+
|
| 76 |
+
plt.tight_layout()
|
| 77 |
+
return fig
|
| 78 |
|
| 79 |
+
# Create interface
|
| 80 |
+
with gr.Blocks(title="MNIST Digit Recognition") as demo:
|
| 81 |
+
gr.Markdown("# 🔢 MNIST Digit Recognition")
|
| 82 |
+
gr.Markdown("Draw a digit (0-9) and the model will predict it!")
|
| 83 |
|
| 84 |
with gr.Row():
|
| 85 |
+
canvas = gr.Sketchpad(label="Draw Here", type="numpy")
|
| 86 |
+
output = gr.Plot(label="Prediction")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
+
with gr.Row():
|
| 89 |
+
predict_btn = gr.Button("Predict", variant="primary")
|
| 90 |
+
clear_btn = gr.Button("Clear")
|
| 91 |
+
|
| 92 |
+
predict_btn.click(fn=predict_digit, inputs=canvas, outputs=output)
|
| 93 |
+
clear_btn.click(lambda: None, outputs=canvas)
|
| 94 |
+
|
| 95 |
+
if __name__ == "__main__":
|
| 96 |
+
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|