2-rayza-2 commited on
Commit
a4b6c42
·
verified ·
1 Parent(s): 8aefeeb

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torchvision import datasets, transforms
5
+ from torch.utils.data import DataLoader
6
+ from PIL import Image
7
+ import gradio as gr
8
+ import os
9
+
10
+ # Device (CPU for compatibility with Hugging Face Spaces)
11
+ device = torch.device("cpu")
12
+
13
+ # Transform for training and uploaded images
14
+ transform = transforms.Compose([
15
+ transforms.Resize((6, 6)),
16
+ transforms.ToTensor()
17
+ ])
18
+
19
+ # Define a convolution block
20
+ def conv(ic, oc):
21
+ ks=3
22
+ return nn.Sequential(
23
+ nn.Conv2d(ic, oc, stride=2, kernel_size=ks, padding=ks//2),
24
+ nn.BatchNorm2d(oc)
25
+ )
26
+
27
+ # CNN Model
28
+ class SimpleCNN(nn.Module):
29
+ def __init__(self):
30
+ super().__init__()
31
+ self.model = nn.Sequential(
32
+ conv(1, 8),
33
+ nn.Dropout2d(0.25),
34
+ nn.ReLU(),
35
+ conv(8, 16),
36
+ nn.Dropout2d(0.25),
37
+ nn.ReLU(),
38
+ conv(16, 10),
39
+ nn.Flatten()
40
+ )
41
+
42
+ def forward(self, x):
43
+ return self.model(x)
44
+
45
+ # Training function
46
+ def train_model():
47
+ train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
48
+ batch_size = 36
49
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
50
+
51
+ model = SimpleCNN().to(device)
52
+ optimizer = optim.Adam(model.parameters(), lr=0.005)
53
+ criterion = nn.CrossEntropyLoss()
54
+
55
+ model.train()
56
+ for epoch in range(3): # Keep it light for HF Spaces
57
+ for images, labels in train_loader:
58
+ images, labels = images.to(device), labels.to(device)
59
+ optimizer.zero_grad()
60
+ outputs = model(images)
61
+ loss = criterion(outputs, labels)
62
+ loss.backward()
63
+ optimizer.step()
64
+ return model
65
+
66
+ # Load or train model
67
+ model_path = "mnist_cnn.pt"
68
+ if os.path.exists(model_path):
69
+ model = SimpleCNN().to(device)
70
+ model.load_state_dict(torch.load(model_path, map_location=device))
71
+ else:
72
+ model = train_model()
73
+ torch.save(model.state_dict(), model_path)
74
+
75
+ # Prediction function
76
+ def predict(img):
77
+ if isinstance(img, Image.Image):
78
+ img = img.convert("L")
79
+ else:
80
+ return "Invalid image"
81
+ x = transform(img).unsqueeze(0).to(device) # Shape: [1,1,8,8]
82
+ model.eval()
83
+ with torch.no_grad():
84
+ output = model(x)
85
+ pred = torch.argmax(output, dim=1).item()
86
+ return f"Predicted digit: {pred}"
87
+
88
+ # Gradio Interface
89
+ demo = gr.Interface(
90
+ fn=predict,
91
+ inputs=gr.Image(shape=(28, 28), image_mode="L", invert_colors=True, sources=["upload", "canvas"]),
92
+ outputs="text",
93
+ title="MNIST Digit Classifier (6x6 CNN)",
94
+ description="Upload or draw a digit to classify it using a lightweight CNN trained on MNIST resized to 8×8."
95
+ )
96
+
97
+ if __name__ == "__main__":
98
+ demo.launch()