Scribbler310 commited on
Commit
bf8d031
·
verified ·
1 Parent(s): be2ea1b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+
7
+ # ---------------------------------------------------------
8
+ # 1. MODEL ARCHITECTURE
9
+ # ---------------------------------------------------------
10
+ class SimpleCNN(nn.Module):
11
+ def __init__(self, num_classes=10):
12
+ super(SimpleCNN, self).__init__()
13
+
14
+ self.conv_block1 = nn.Sequential(
15
+ nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
16
+ nn.ReLU(),
17
+ nn.MaxPool2d(kernel_size=2, stride=2),
18
+ )
19
+
20
+ self.conv_block2 = nn.Sequential(
21
+ nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
22
+ nn.ReLU(),
23
+ nn.MaxPool2d(kernel_size=2, stride=2),
24
+ )
25
+
26
+ self.classifier = nn.Sequential(
27
+ nn.Flatten(),
28
+ nn.Linear(in_features=32 * 32 * 32, out_features=128),
29
+ nn.ReLU(),
30
+ nn.Linear(in_features=128, out_features=num_classes),
31
+ )
32
+
33
+ def forward(self, x):
34
+ x = self.conv_block1(x)
35
+ x = self.conv_block2(x)
36
+ x = self.classifier(x)
37
+ return x
38
+
39
+ # ---------------------------------------------------------
40
+ # 2. SETUP
41
+ # ---------------------------------------------------------
42
+ # Initialize model
43
+ model = SimpleCNN()
44
+
45
+ # Load weights (Ensure 'fulldigits.pt' is uploaded to Hugging Face Files!)
46
+ try:
47
+ model.load_state_dict(torch.load("fulldigits.pt", map_location="cpu"))
48
+ model.eval()
49
+ except FileNotFoundError:
50
+ print("Error: 'fulldigits.pt' not found. Please upload your model file.")
51
+
52
+ # Define transforms
53
+ # CRITICAL FIX: Added lambda to force RGB.
54
+ # This prevents crashes if someone uploads a Grayscale or RGBA image.
55
+ transform = transforms.Compose([
56
+ transforms.Lambda(lambda x: x.convert("RGB")),
57
+ transforms.Resize((128, 128)),
58
+ transforms.ToTensor(),
59
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
60
+ ])
61
+
62
+ # ---------------------------------------------------------
63
+ # 3. PREDICTION FUNCTION
64
+ # ---------------------------------------------------------
65
+ def predict(image):
66
+ if image is None:
67
+ return None
68
+
69
+ # Transform image
70
+ img_tensor = transform(image).unsqueeze(0)
71
+
72
+ # Make prediction
73
+ with torch.no_grad():
74
+ output = model(img_tensor)
75
+ # Get probabilities
76
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
77
+
78
+ # Return a dictionary for Gradio's Label component
79
+ # This creates the nice bar chart effect
80
+ return {str(i): float(probabilities[i]) for i in range(10)}
81
+
82
+ # ---------------------------------------------------------
83
+ # 4. GRADIO INTERFACE
84
+ # ---------------------------------------------------------
85
+ demo = gr.Interface(
86
+ fn=predict,
87
+ inputs=gr.Image(type="pil", label="Upload Image"),
88
+ outputs=gr.Label(num_top_classes=3, label="Predictions"), # Changed to Label for better UI
89
+ title="Digit Classification Project",
90
+ description="Upload an image to check if it contains a digit (0-9).",
91
+ # removed share=True for production deployment
92
+ )
93
+
94
+ if __name__ == "__main__":
95
+ demo.launch()