Scribbler310 commited on
Commit
dda1ebf
·
verified ·
1 Parent(s): 6aa80f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -68
app.py CHANGED
@@ -1,94 +1,49 @@
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
- from torchvision import transforms
5
  from PIL import Image
6
 
7
- # ---------------------------------------------------------
8
- # 1. MODEL ARCHITECTURE
9
- # ---------------------------------------------------------
10
- class SimpleCNN(nn.Module):
11
- def __init__(self, num_classes=10):
12
- super(SimpleCNN, self).__init__()
13
-
14
- self.conv_block1 = nn.Sequential(
15
- nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
16
- nn.ReLU(),
17
- nn.MaxPool2d(kernel_size=2, stride=2),
18
- )
19
 
20
- self.conv_block2 = nn.Sequential(
21
- nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
22
- nn.ReLU(),
23
- nn.MaxPool2d(kernel_size=2, stride=2),
24
- )
25
-
26
- self.classifier = nn.Sequential(
27
- nn.Flatten(),
28
- nn.Linear(in_features=32 * 32 * 32, out_features=128),
29
- nn.ReLU(),
30
- nn.Linear(in_features=128, out_features=num_classes),
31
- )
32
-
33
- def forward(self, x):
34
- x = self.conv_block1(x)
35
- x = self.conv_block2(x)
36
- x = self.classifier(x)
37
- return x
38
-
39
- # ---------------------------------------------------------
40
- # 2. SETUP
41
- # ---------------------------------------------------------
42
- # Initialize model
43
- model = SimpleCNN()
44
-
45
- # Load weights (Ensure 'fulldigits.pt' is uploaded to Hugging Face Files!)
46
  try:
47
- model.load_state_dict(torch.load("fulldigits.pt", map_location="cpu"))
 
48
  model.eval()
49
- except FileNotFoundError:
50
- print("Error: 'fulldigits.pt' not found. Please upload your model file.")
51
 
52
- # Define transforms
53
- # CRITICAL FIX: Added lambda to force RGB.
54
- # This prevents crashes if someone uploads a Grayscale or RGBA image.
55
  transform = transforms.Compose([
56
- transforms.Lambda(lambda x: x.convert("RGB")),
57
- transforms.Resize((128, 128)),
58
  transforms.ToTensor(),
59
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
60
  ])
61
 
62
- # ---------------------------------------------------------
63
- # 3. PREDICTION FUNCTION
64
- # ---------------------------------------------------------
65
  def predict(image):
66
- if image is None:
67
- return None
68
-
69
- # Transform image
70
  img_tensor = transform(image).unsqueeze(0)
71
-
72
- # Make prediction
73
  with torch.no_grad():
74
  output = model(img_tensor)
75
- # Get probabilities
76
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
77
 
78
- # Return a dictionary for Gradio's Label component
79
- # This creates the nice bar chart effect
80
  return {str(i): float(probabilities[i]) for i in range(10)}
81
 
82
- # ---------------------------------------------------------
83
- # 4. GRADIO INTERFACE
84
- # ---------------------------------------------------------
85
  demo = gr.Interface(
86
  fn=predict,
87
- inputs=gr.Image(type="pil", label="Upload Image"),
88
- outputs=gr.Label(num_top_classes=3, label="Predictions"), # Changed to Label for better UI
89
- title="Digit Classification Project",
90
- description="Upload an image to check if it contains a digit (0-9).",
91
- # removed share=True for production deployment
92
  )
93
 
94
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
+ from torchvision import models, transforms
5
  from PIL import Image
6
 
7
+ # 1. SETUP MODEL
8
+ # We use ResNet18 structure to match your training
9
+ model = models.resnet18(weights=None)
10
+ model.fc = nn.Linear(model.fc.in_features, 10) # Adjust head to 10 classes
 
 
 
 
 
 
 
 
11
 
12
+ # Load your 98.79% accuracy weights
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  try:
14
+ state_dict = torch.load("fulldigits.pt", map_location="cpu")
15
+ model.load_state_dict(state_dict)
16
  model.eval()
17
+ except Exception as e:
18
+ print(f"Error loading model: {e}")
19
 
20
+ # 2. PREPROCESSING
21
+ # Must use the ImageNet stats you trained with!
 
22
  transform = transforms.Compose([
23
+ transforms.Lambda(lambda x: x.convert("RGB")), # Force RGB
24
+ transforms.Resize((128, 128)), # Match training size
25
  transforms.ToTensor(),
26
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
27
  ])
28
 
29
+ # 3. PREDICT FUNCTION
 
 
30
  def predict(image):
31
+ if image is None: return None
 
 
 
32
  img_tensor = transform(image).unsqueeze(0)
33
+
 
34
  with torch.no_grad():
35
  output = model(img_tensor)
 
36
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
37
 
 
 
38
  return {str(i): float(probabilities[i]) for i in range(10)}
39
 
40
+ # 4. INTERFACE
 
 
41
  demo = gr.Interface(
42
  fn=predict,
43
+ inputs=gr.Image(type="pil", label="Draw or Upload Digit"),
44
+ outputs=gr.Label(num_top_classes=3),
45
+ title="Handwritten Digit Recognizer",
46
+ description="A ResNet18 model fine-tuned to 98.79% accuracy."
 
47
  )
48
 
49
  if __name__ == "__main__":