pallabi1608 commited on
Commit
3886447
·
verified ·
1 Parent(s): 142e8fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -60
app.py CHANGED
@@ -4,8 +4,9 @@ import torch.nn as nn
4
  import torch.nn.functional as F
5
  from PIL import Image
6
  import numpy as np
 
7
 
8
- # Define the network (same as your model)
9
  class Net(nn.Module):
10
  def __init__(self):
11
  super(Net, self).__init__()
@@ -24,88 +25,72 @@ class Net(nn.Module):
24
  x = self.fc2(x)
25
  return F.log_softmax(x, dim=1)
26
 
27
- # Load the model
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
  network = Net().to(device)
30
  network.load_state_dict(torch.load("model.pth", map_location=device))
31
  network.eval()
32
 
33
- # Prediction function
34
  def predict_digit(image):
35
- """
36
- Takes an image from the drawing pad and predicts the digit
37
- """
38
  if image is None:
39
- return "No input", "0%"
40
 
41
- # Convert image to grayscale and resize to 28x28
42
- if isinstance(image, dict): # Gradio 6.0 returns dict with 'composite' key
43
  image = image['composite']
44
 
45
  img = Image.fromarray(image).convert('L')
46
  img = img.resize((28, 28))
47
 
48
- # Convert to tensor
49
  img_array = np.array(img, dtype=np.float32) / 255.0
50
  img_tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0).to(device)
51
-
52
- # Apply same normalization as training
53
  img_tensor = (img_tensor - 0.1307) / 0.3081
54
 
55
- # Predict
56
  with torch.no_grad():
57
  output = network(img_tensor)
58
- probabilities = torch.exp(output[0])
59
- pred = output.data.max(1, keepdim=True)[1][0].item()
60
- confidence = probabilities[pred].item() * 100
61
 
62
- # Create confidence text and plot data
63
- confidence_text = f"Predicted: **{int(pred)}** ({confidence:.1f}%)"
64
 
65
- # Format for bar chart (as list of tuples)
66
- chart_data = [(str(i), float(probabilities[i].item() * 100)) for i in range(10)]
 
 
67
 
68
- return confidence_text, chart_data
69
-
70
- # Create Gradio interface
71
- title = "🔢 MNIST Digit Recognition"
72
- description = "Draw a digit (0-9) on the canvas and the model will predict what it is!"
 
 
 
 
 
 
 
 
 
 
73
 
74
- with gr.Blocks(title=title) as demo:
75
- gr.Markdown(f"# {title}")
76
- gr.Markdown(description)
 
77
 
78
  with gr.Row():
79
- with gr.Column(scale=1):
80
- gr.Markdown("### Draw a Digit")
81
- canvas = gr.Sketchpad(
82
- label="Canvas",
83
- type="numpy",
84
- interactive=True
85
- )
86
- with gr.Row():
87
- clear_btn = gr.Button("Clear Canvas", size="sm")
88
- submit_btn = gr.Button("Predict", variant="primary", size="sm")
89
-
90
- with gr.Column(scale=1):
91
- gr.Markdown("### Prediction Result")
92
- output_text = gr.Markdown("Draw a digit to get started!")
93
-
94
- # Use Plot instead of BarChart for Gradio 6.0
95
- confidence_plot = gr.Plot(label="Confidence Scores")
96
 
97
- # Function to create bar chart using matplotlib
98
- def create_prediction_with_chart(image):
99
- pred_text, chart_data = predict_digit(image)
100
-
101
- # Create matplotlib figure
102
- import matplotlib.pyplot as plt
103
- fig, ax = plt.subplots(figsize=(8, 4))
104
-
105
- digits = [str(i) for i in range(10)]
106
- confidences = [data[1] for data in chart_data]
107
-
108
- colors = ['#FF6B6B' if i != int(chart_data[int(chart_data[0][0])][0]) else '#4CAF50'
109
- for i in range(10)]
110
-
111
- ax.bar(digits, confidences, color=colors, alpha
 
4
  import torch.nn.functional as F
5
  from PIL import Image
6
  import numpy as np
7
+ import matplotlib.pyplot as plt
8
 
9
+ # Define the network
10
  class Net(nn.Module):
11
  def __init__(self):
12
  super(Net, self).__init__()
 
25
  x = self.fc2(x)
26
  return F.log_softmax(x, dim=1)
27
 
28
+ # Load model
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  network = Net().to(device)
31
  network.load_state_dict(torch.load("model.pth", map_location=device))
32
  network.eval()
33
 
 
34
  def predict_digit(image):
35
+ """Predict digit from image"""
 
 
36
  if image is None:
37
+ return None
38
 
39
+ # Handle Gradio 6.0 dict format
40
+ if isinstance(image, dict):
41
  image = image['composite']
42
 
43
  img = Image.fromarray(image).convert('L')
44
  img = img.resize((28, 28))
45
 
 
46
  img_array = np.array(img, dtype=np.float32) / 255.0
47
  img_tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0).to(device)
 
 
48
  img_tensor = (img_tensor - 0.1307) / 0.3081
49
 
 
50
  with torch.no_grad():
51
  output = network(img_tensor)
52
+ probs = torch.exp(output[0])
53
+ pred = output.data.max(1)[1][0].item()
 
54
 
55
+ # Create visualization
56
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
57
 
58
+ # Plot 1: Drawn digit
59
+ ax1.imshow(img, cmap='gray')
60
+ ax1.set_title(f'Predicted: {int(pred)}', fontsize=16, fontweight='bold')
61
+ ax1.axis('off')
62
 
63
+ # Plot 2: Confidence bars
64
+ digits = list(range(10))
65
+ confidences = [float(probs[i].item() * 100) for i in range(10)]
66
+ colors = ['#4CAF50' if i == pred else '#FF6B6B' for i in range(10)]
67
+
68
+ ax2.bar(digits, confidences, color=colors, alpha=0.7) # ✅ FIXED: Added closing parenthesis
69
+ ax2.set_xlabel('Digit', fontsize=12)
70
+ ax2.set_ylabel('Confidence (%)', fontsize=12)
71
+ ax2.set_title('Confidence Scores', fontsize=12, fontweight='bold')
72
+ ax2.set_ylim([0, 100])
73
+ ax2.set_xticks(digits)
74
+ ax2.grid(axis='y', alpha=0.3)
75
+
76
+ plt.tight_layout()
77
+ return fig
78
 
79
+ # Create interface
80
+ with gr.Blocks(title="MNIST Digit Recognition") as demo:
81
+ gr.Markdown("# 🔢 MNIST Digit Recognition")
82
+ gr.Markdown("Draw a digit (0-9) and the model will predict it!")
83
 
84
  with gr.Row():
85
+ canvas = gr.Sketchpad(label="Draw Here", type="numpy")
86
+ output = gr.Plot(label="Prediction")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
+ with gr.Row():
89
+ predict_btn = gr.Button("Predict", variant="primary")
90
+ clear_btn = gr.Button("Clear")
91
+
92
+ predict_btn.click(fn=predict_digit, inputs=canvas, outputs=output)
93
+ clear_btn.click(lambda: None, outputs=canvas)
94
+
95
+ if __name__ == "__main__":
96
+ demo.launch()