Spaces:
Sleeping
Sleeping
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib.figure import Figure
from PIL import Image
# Define the network
class Net(nn.Module):
    """Small two-conv / two-linear CNN producing log-probabilities over 10 digits.

    With a 1x28x28 input the feature map after the second pool is 20x4x4 = 320
    values, which is why ``fc1`` takes 320 inputs. The attribute names below are
    the keys of the saved state_dict, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return log-softmax class scores of shape (N, 10)."""
        stage_one = F.relu(F.max_pool2d(self.conv1(x), 2))
        stage_two = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(stage_one)), 2))
        flat = stage_two.view(-1, 320)
        # Dropout is only active in training mode (network.eval() turns it off).
        hidden = F.dropout(F.relu(self.fc1(flat)), training=self.training)
        return F.log_softmax(self.fc2(hidden), dim=1)
# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
network = Net().to(device)
# A state_dict only contains tensors, so restrict unpickling to tensors and plain
# containers (weights_only=True) instead of allowing arbitrary pickled objects.
network.load_state_dict(
    torch.load("model.pth", map_location=device, weights_only=True)
)
# Inference mode: turns off the Dropout2d / F.dropout layers used in Net.forward.
network.eval()
def _to_grayscale_digit(image):
    """Return an 'L'-mode PIL image with a light digit on a dark background."""
    arr = np.asarray(image)
    if arr.ndim == 3 and arr.shape[2] == 4 and arr[..., 3].min() < 255:
        # Canvas with a (partly) transparent background: the RGB of transparent
        # pixels is meaningless, but the alpha channel is exactly the stroke mask.
        return Image.fromarray(arr[..., 3].astype(np.uint8))
    gray = Image.fromarray(arr).convert('L')
    if np.median(np.asarray(gray)) > 127:
        # Light background (e.g. dark ink on white): invert so the digit is bright,
        # matching the MNIST-style normalization applied in predict_digit.
        gray = Image.fromarray(255 - np.asarray(gray))
    return gray


def _render_prediction(img, probs, pred):
    """Build the result figure: model input on the left, per-digit confidence bars on the right.

    Uses ``matplotlib.figure.Figure`` directly rather than ``pyplot``: figures made
    with ``plt.subplots`` stay in pyplot's global registry until ``plt.close`` and
    would accumulate with every prediction.
    """
    fig = Figure(figsize=(12, 4))
    ax1, ax2 = fig.subplots(1, 2)

    # Plot 1: the 28x28 image that was actually fed to the network
    ax1.imshow(img, cmap='gray')
    ax1.set_title(f'Predicted: {pred}', fontsize=16, fontweight='bold')
    ax1.axis('off')

    # Plot 2: confidence bars (probabilities in percent), predicted digit highlighted
    digits = list(range(10))
    confidences = [float(p) * 100 for p in probs.tolist()]
    colors = ['#4CAF50' if d == pred else '#FF6B6B' for d in digits]
    ax2.bar(digits, confidences, color=colors, alpha=0.7)
    ax2.set_xlabel('Digit', fontsize=12)
    ax2.set_ylabel('Confidence (%)', fontsize=12)
    ax2.set_title('Confidence Scores', fontsize=12, fontweight='bold')
    ax2.set_ylim([0, 100])
    ax2.set_xticks(digits)
    ax2.grid(axis='y', alpha=0.3)

    fig.tight_layout()
    return fig


def predict_digit(image):
    """Predict digit from image.

    Args:
        image: Value of ``gr.Sketchpad(type="numpy")``. This is either a numpy array
            or a dict whose ``'composite'`` entry holds the flattened drawing
            (newer Gradio versions), or None when nothing was submitted.

    Returns:
        A matplotlib ``Figure`` showing the model input and per-digit confidences,
        or None when there is no image to classify.
    """
    if image is None:
        return None
    # Handle Gradio dict format (Sketchpad/ImageEditor value)
    if isinstance(image, dict):
        image = image.get('composite')
        if image is None:
            return None

    img = _to_grayscale_digit(image).resize((28, 28))
    img_array = np.array(img, dtype=np.float32) / 255.0
    img_tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0).to(device)
    # Presumably the mean/std the model was trained with (the usual MNIST values).
    img_tensor = (img_tensor - 0.1307) / 0.3081

    with torch.no_grad():
        log_probs = network(img_tensor)
    # The network returns log-softmax, so exp() yields class probabilities.
    probs = log_probs[0].exp()
    pred = int(log_probs.argmax(dim=1)[0].item())

    return _render_prediction(img, probs, pred)
# Create interface
with gr.Blocks(title="MNIST Digit Recognition") as demo:
    gr.Markdown("# 🔢 MNIST Digit Recognition")
    gr.Markdown("Draw a digit (0-9) and the model will predict it!")
    with gr.Row():
        canvas = gr.Sketchpad(label="Draw Here", type="numpy")
        output = gr.Plot(label="Prediction")
    with gr.Row():
        predict_btn = gr.Button("Predict", variant="primary")
        clear_btn = gr.Button("Clear")
    predict_btn.click(fn=predict_digit, inputs=canvas, outputs=output)
    # Reset both the drawing and the result plot so no stale prediction is left on screen.
    clear_btn.click(lambda: (None, None), outputs=[canvas, output])

if __name__ == "__main__":
    demo.launch()