ved1beta commited on
Commit
164ffbc
·
1 Parent(s): e1d27d4
app.py CHANGED
@@ -2,8 +2,10 @@ import gradio as gr
2
  import numpy as np
3
  import pickle
4
  from PIL import Image
 
 
5
 
6
- # Load the model
7
  with open('model.pkl', 'rb') as f:
8
  model_params = pickle.load(f)
9
 
@@ -12,6 +14,7 @@ b1 = model_params['b1']
12
  W2 = model_params['W2']
13
  b2 = model_params['b2']
14
 
 
15
  def ReLu(Z):
16
  return np.maximum(Z, 0)
17
 
@@ -40,6 +43,7 @@ def preprocess_image(image):
40
 
41
  return img_array.T # Transpose to match the shape (784, 1)
42
 
 
43
  def predict_digit(image):
44
  X = preprocess_image(image)
45
 
@@ -51,13 +55,43 @@ def predict_digit(image):
51
 
52
  return int(prediction[0])
53
 
54
- # Gradio interface
55
- iface = gr.Interface(
56
- fn=predict_digit,
57
- inputs=gr.Image(type="pil"),
58
- outputs=gr.Label(num_top_classes=1),
59
- title="Handwritten Digit Recognition",
60
- description="Upload an image of a handwritten digit (0-9) and the model will predict which digit it is."
61
- )
62
 
63
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
  import pickle
4
  from PIL import Image
5
+ import os
6
+ import random
7
 
8
+ # 1. Load the model
9
  with open('model.pkl', 'rb') as f:
10
  model_params = pickle.load(f)
11
 
 
14
  W2 = model_params['W2']
15
  b2 = model_params['b2']
16
 
17
+ # 2. Define helper functions
18
  def ReLu(Z):
19
  return np.maximum(Z, 0)
20
 
 
43
 
44
  return img_array.T # Transpose to match the shape (784, 1)
45
 
46
+ # 3. Define prediction function
47
  def predict_digit(image):
48
  X = preprocess_image(image)
49
 
 
55
 
56
  return int(prediction[0])
57
 
58
+ # 4. Load sample images
59
+ sample_images = []
60
+ sample_dir = "sample_images" # Make sure this directory exists in your Space
61
+ for filename in os.listdir(sample_dir):
62
+ if filename.endswith((".png", ".jpg", ".jpeg")):
63
+ img_path = os.path.join(sample_dir, filename)
64
+ sample_images.append(img_path)
 
65
 
66
+ # 5. Define function to select random image
67
+ def select_random_image():
68
+ return random.choice(sample_images)
69
+
70
+ # 6. Create Gradio interface
71
+ with gr.Blocks() as demo:
72
+ gr.Markdown("# Handwritten Digit Recognition")
73
+
74
+ with gr.Row():
75
+ with gr.Column():
76
+ input_image = gr.Image(type="pil", label="Input Image")
77
+ upload_button = gr.UploadButton("Upload Image", file_types=["image"])
78
+ sample_button = gr.Button("Use Random Sample Image")
79
+
80
+ with gr.Column():
81
+ output_label = gr.Label(label="Prediction")
82
+ predict_button = gr.Button("Predict")
83
+
84
+ upload_button.upload(fn=lambda file: file.name, inputs=upload_button, outputs=input_image)
85
+ sample_button.click(fn=select_random_image, inputs=None, outputs=input_image)
86
+ predict_button.click(fn=predict_digit, inputs=input_image, outputs=output_label)
87
+
88
+ gr.Markdown("## Sample Images")
89
+ with gr.Row():
90
+ for img_path in sample_images[:5]: # Display first 5 sample images
91
+ gr.Image(img_path, show_label=False, height=100)
92
+ with gr.Row():
93
+ for img_path in sample_images[5:10]: # Display next 5 sample images
94
+ gr.Image(img_path, show_label=False, height=100)
95
+
96
+ # 7. Launch the app
97
+ demo.launch()
sample_images/1.png ADDED
sample_images/2.png ADDED
sample_images/3.jpg ADDED
sample_images/4.png ADDED
sample_images/5.png ADDED
sample_images/6.png ADDED
sample_images/7.png ADDED
sample_images/8.jpg ADDED
sample_images/9.png ADDED