2-rayza-2 commited on
Commit
a6cbc1b
·
verified ·
1 Parent(s): f30b06f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -97
app.py CHANGED
@@ -1,98 +1,98 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.optim as optim
4
- from torchvision import datasets, transforms
5
- from torch.utils.data import DataLoader
6
- from PIL import Image
7
- import gradio as gr
8
- import os
9
-
10
- # Device (CPU for compatibility with Hugging Face Spaces)
11
- device = torch.device("cpu")
12
-
13
- # Transform for training and uploaded images
14
- transform = transforms.Compose([
15
- transforms.Resize((6, 6)),
16
- transforms.ToTensor()
17
- ])
18
-
19
- # Define a convolution block
20
- def conv(ic, oc):
21
- ks=3
22
- return nn.Sequential(
23
- nn.Conv2d(ic, oc, stride=2, kernel_size=ks, padding=ks//2),
24
- nn.BatchNorm2d(oc)
25
- )
26
-
27
- # CNN Model
28
- class SimpleCNN(nn.Module):
29
- def __init__(self):
30
- super().__init__()
31
- self.model = nn.Sequential(
32
- conv(1, 8),
33
- nn.Dropout2d(0.25),
34
- nn.ReLU(),
35
- conv(8, 16),
36
- nn.Dropout2d(0.25),
37
- nn.ReLU(),
38
- conv(16, 10),
39
- nn.Flatten()
40
- )
41
-
42
- def forward(self, x):
43
- return self.model(x)
44
-
45
- # Training function
46
- def train_model():
47
- train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
48
- batch_size = 36
49
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
50
-
51
- model = SimpleCNN().to(device)
52
- optimizer = optim.Adam(model.parameters(), lr=0.005)
53
- criterion = nn.CrossEntropyLoss()
54
-
55
- model.train()
56
- for epoch in range(3): # Keep it light for HF Spaces
57
- for images, labels in train_loader:
58
- images, labels = images.to(device), labels.to(device)
59
- optimizer.zero_grad()
60
- outputs = model(images)
61
- loss = criterion(outputs, labels)
62
- loss.backward()
63
- optimizer.step()
64
- return model
65
-
66
- # Load or train model
67
- model_path = "mnist_cnn.pt"
68
- if os.path.exists(model_path):
69
- model = SimpleCNN().to(device)
70
- model.load_state_dict(torch.load(model_path, map_location=device))
71
- else:
72
- model = train_model()
73
- torch.save(model.state_dict(), model_path)
74
-
75
- # Prediction function
76
- def predict(img):
77
- if isinstance(img, Image.Image):
78
- img = img.convert("L")
79
- else:
80
- return "Invalid image"
81
- x = transform(img).unsqueeze(0).to(device) # Shape: [1,1,8,8]
82
- model.eval()
83
- with torch.no_grad():
84
- output = model(x)
85
- pred = torch.argmax(output, dim=1).item()
86
- return f"Predicted digit: {pred}"
87
-
88
- # Gradio Interface
89
- demo = gr.Interface(
90
- fn=predict,
91
- inputs=gr.Image(shape=(28, 28), image_mode="L", invert_colors=True, sources=["upload", "canvas"]),
92
- outputs="text",
93
- title="MNIST Digit Classifier (6x6 CNN)",
94
- description="Upload or draw a digit to classify it using a lightweight CNN trained on MNIST resized to 8×8."
95
- )
96
-
97
- if __name__ == "__main__":
98
  demo.launch()
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torchvision import datasets, transforms
5
+ from torch.utils.data import DataLoader
6
+ from PIL import Image
7
+ import gradio as gr
8
+ import os
9
+
10
+ # Device (CPU for compatibility with Hugging Face Spaces)
11
+ device = torch.device("cpu")
12
+
13
+ # Transform for training and uploaded images
14
+ transform = transforms.Compose([
15
+ transforms.Resize((6, 6)),
16
+ transforms.ToTensor()
17
+ ])
18
+
19
+ # Define a convolution block
20
+ def conv(ic, oc):
21
+ ks=3
22
+ return nn.Sequential(
23
+ nn.Conv2d(ic, oc, stride=2, kernel_size=ks, padding=ks//2),
24
+ nn.BatchNorm2d(oc)
25
+ )
26
+
27
+ # CNN Model
28
+ class SimpleCNN(nn.Module):
29
+ def __init__(self):
30
+ super().__init__()
31
+ self.model = nn.Sequential(
32
+ conv(1, 8),
33
+ nn.Dropout2d(0.25),
34
+ nn.ReLU(),
35
+ conv(8, 16),
36
+ nn.Dropout2d(0.25),
37
+ nn.ReLU(),
38
+ conv(16, 10),
39
+ nn.Flatten()
40
+ )
41
+
42
+ def forward(self, x):
43
+ return self.model(x)
44
+
45
+ # Training function
46
+ def train_model():
47
+ train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
48
+ batch_size = 36
49
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
50
+
51
+ model = SimpleCNN().to(device)
52
+ optimizer = optim.Adam(model.parameters(), lr=0.005)
53
+ criterion = nn.CrossEntropyLoss()
54
+
55
+ model.train()
56
+ for epoch in range(3): # Keep it light for HF Spaces
57
+ for images, labels in train_loader:
58
+ images, labels = images.to(device), labels.to(device)
59
+ optimizer.zero_grad()
60
+ outputs = model(images)
61
+ loss = criterion(outputs, labels)
62
+ loss.backward()
63
+ optimizer.step()
64
+ return model
65
+
66
+ # Load or train model
67
+ model_path = "mnist_cnn.pt"
68
+ if os.path.exists(model_path):
69
+ model = SimpleCNN().to(device)
70
+ model.load_state_dict(torch.load(model_path, map_location=device))
71
+ else:
72
+ model = train_model()
73
+ torch.save(model.state_dict(), model_path)
74
+
75
+ # Prediction function
76
+ def predict(img):
77
+ if isinstance(img, Image.Image):
78
+ img = img.convert("L")
79
+ else:
80
+ return "Invalid image"
81
+ x = transform(img).unsqueeze(0).to(device) # Shape: [1,1,8,8]
82
+ model.eval()
83
+ with torch.no_grad():
84
+ output = model(x)
85
+ pred = torch.argmax(output, dim=1).item()
86
+ return f"Predicted digit: {pred}"
87
+
88
+ # Gradio Interface
89
+ demo = gr.Interface(
90
+ fn=predict,
91
+ inputs=gr.Image(image_mode="L", invert_colors=True, sources=["upload", "canvas"]),
92
+ outputs="text",
93
+ title="MNIST Digit Classifier (6x6 CNN)",
94
+ description="Upload or draw a digit to classify it using a lightweight CNN trained on MNIST resized to 8×8."
95
+ )
96
+
97
+ if __name__ == "__main__":
98
  demo.launch()