Scribbler310 commited on
Commit
99f7345
·
verified ·
1 Parent(s): c987b5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -39
app.py CHANGED
@@ -1,50 +1,47 @@
 
1
  import gradio as gr
2
- import torch
3
- import torch.nn as nn
4
- from torchvision import models, transforms
5
- from PIL import Image
6
 
7
- # 1. SETUP MODEL
8
- # We use ResNet18 structure to match your training
9
- model = models.resnet18(weights=None)
10
- model.fc = nn.Linear(model.fc.in_features, 10) # Adjust head to 10 classes
11
 
12
- # Load your 98.79% accuracy weights
13
- try:
14
- state_dict = torch.load("fulldigits.pt", map_location="cpu")
15
- model.load_state_dict(state_dict)
16
- model.eval()
17
- except Exception as e:
18
- print(f"Error loading model: {e}")
19
-
20
- # 2. PREPROCESSING
21
- # Must use the ImageNet stats you trained with!
22
- transform = transforms.Compose([
23
- transforms.Lambda(lambda x: x.convert("RGB")), # Force RGB
24
- transforms.Resize((128, 128)), # Match training size
25
- transforms.ToTensor(),
26
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
27
- ])
28
-
29
- # 3. PREDICT FUNCTION
30
- def predict(image):
31
- if image is None: return None
32
- img_tensor = transform(image).unsqueeze(0)
33
 
34
- with torch.no_grad():
35
- output = model(img_tensor)
36
- probabilities = torch.nn.functional.softmax(output[0], dim=0)
37
 
38
- return {str(i): float(probabilities[i]) for i in range(10)}
 
39
 
40
- # 4. INTERFACE
41
- demo = gr.Interface(
42
- fn=predict,
43
- inputs=gr.Image(type="pil", label="Draw or Upload Digit"),
 
44
  outputs=gr.Label(num_top_classes=3),
45
  title="Handwritten Digit Recognizer",
46
- description="A ResNet18 model fine-tuned to 98.79% accuracy."
47
  )
48
 
 
49
  if __name__ == "__main__":
50
- demo.launch()
 
1
+ import tensorflow as tf
2
  import gradio as gr
3
+ import numpy as np
4
+ import cv2
 
 
5
 
6
+ # 1. Load the trained model
7
+ model = tf.keras.models.load_model('digit_recognizer.keras')
 
 
8
 
9
+ # 2. Define the classification function
10
+ def classify_digit(image):
11
+ if image is None:
12
+ return None
13
+
14
+ # Preprocessing to match MNIST data format
15
+ # Convert to grayscale if it isn't already
16
+ if len(image.shape) == 3:
17
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
18
+
19
+ # Resize the image to 28x28 pixels
20
+ image = cv2.resize(image, (28, 28))
21
+
22
+ # Reshape to (1, 28, 28, 1) to match model input shape
23
+ # The '1' indicates a batch size of 1
24
+ image = image.reshape(1, 28, 28, 1)
25
+
26
+ # Normalize pixel values (0 to 1) just like in the training notebook
27
+ image = image / 255.0
 
 
28
 
29
+ # Predict
30
+ prediction = model.predict(image).flatten()
 
31
 
32
+ # Return dictionary for Gradio Label output
33
+ return {str(i): float(prediction[i]) for i in range(10)}
34
 
35
+ # 3. Build the Gradio Interface
36
+ # We use Sketchpad so users can draw the digit
37
+ interface = gr.Interface(
38
+ fn=classify_digit,
39
+ inputs=gr.Sketchpad(label="Draw a Digit"),
40
  outputs=gr.Label(num_top_classes=3),
41
  title="Handwritten Digit Recognizer",
42
+ description="Draw a digit (0-9) on the canvas to see if the Neural Network recognizes it."
43
  )
44
 
45
+ # 4. Launch
46
  if __name__ == "__main__":
47
+ interface.launch()