Ponleur committed on
Commit
a58f57e
·
verified ·
1 Parent(s): 818077e

Upload 5 files

Browse files
app.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.hub
6
+ import torchvision.transforms as transforms
7
+ from PIL import Image
8
+ import os
9
+
10
# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Maps each selectable model key to the file holding its state dict.
# The paths are relative, so they resolve against the working directory.
MODEL_FILES = {
    "lenet": "grayscale_lenet_state_dict.pt",
    "cnn": "grayscale_custom_CNN_state_dict.pt",
    "resnet": "grayscale_resnet_state_dict.pt",
    "vgg": "grayscale_vgg_state_dict.pt",
}
20
+
21
# Human-readable label for each class id. Keys are zero-padded two-digit
# strings ('00' .. '67'), not integers, so an integer class index must be
# formatted (e.g. f"{index:02d}") before it is used as a key.
# NOTE(review): the labels look like romanized Thai character names - confirm.
class_names = {
    '49': 'SARA AA',
    '34': 'RO RUA',
    '18': 'NO NEN',
    '64': 'MAI THO',
    '37': 'LU',
    '05': 'KHO RAKHANG',
    '46': 'PAIYANNOI',
    '35': 'RU',
    '17': 'THO PHUTHAO',
    '06': 'NGO NGU',
    '09': 'CHO CHANG',
    '19': 'DO DEK',
    '28': 'FO FA',
    '24': 'NO NU',
    '57': 'SARA E',
    '23': 'THO THONG',
    '42': 'HO HIP',
    '08': 'CHO CHING',
    '20': 'TO TAO',
    '16': 'THO NANGMONTHO',
    '44': 'O ANG',
    '31': 'PHO SAMPHAO',
    '02': 'KHO KHUAT',
    '07': 'CHO CHAN',
    '29': 'PHO PHAN',
    '39': 'SO SALA',
    '60': 'SARA AI MAIMUAN',
    '11': 'CHO CHOE',
    '55': 'SARA U',
    '50': 'SARA AM',
    '53': 'SARA UE',
    '40': 'SO RUSI',
    '59': 'SARA O',
    '22': 'THO THAHAN',
    '30': 'FO FAN',
    '27': 'PHO PHUNG',
    '13': 'DO CHADA',
    '67': 'THANTHAKHAT',
    '10': 'SO SO',
    '61': 'SARA AI MAIMALAI',
    '33': 'YO YAK',
    '32': 'MO MA',
    '54': 'SARA UEE',
    '41': 'SO SUA',
    '03': 'KHO KHWAI',
    '65': 'MAI TRI',
    '00': 'KO KAI',
    '25': 'BO BAIMAI',
    '52': 'SARA II',
    '66': 'MAI CHATTAWA',
    '45': 'HO NOKHUK',
    '47': 'SARA A',
    '38': 'WO WAEN',
    '56': 'SARA UU',
    '14': 'TO PATAK',
    '58': 'SARA AE',
    '26': 'PO PLA',
    '63': 'MAI EK',
    '15': 'THO THAN',
    '12': 'YO YING',
    '21': 'THO THUNG',
    '01': 'KHO KHAI',
    '36': 'LO LING',
    '43': 'LO CHULA',
    '48': 'MAI HAN',
    '62': 'MAITAIKHU',
    '04': 'KHO KHON',
    '51': 'SARA I',
}
90
+
91
# Image preprocessing applied to an uploaded image before inference:
# convert to a single channel, resize to 64x64, convert to a tensor and
# normalize. The 64x64 size matters: LeNet5.fc1 (16*13*13 inputs) and
# HandwrittenTextCNN.fc1 (8192 inputs) are sized for exactly that resolution.
# NOTE(review): mean/std presumably match the values used at training time -
# confirm before changing them.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485], std=[0.229]),
])
99
class LeNet5(nn.Module):
    """LeNet-5 style convolutional classifier for single-channel images.

    The first fully connected layer is sized for 64x64 inputs; other input
    resolutions will fail at ``fc1`` with a shape mismatch.

    Args:
        num_classes: Number of output logits. Defaults to 68.
    """

    def __init__(self, num_classes=68):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)

        # Spatial size for a 64x64 input:
        # 64 -conv5-> 60 -pool2-> 30 -conv5-> 26 -pool2-> 13
        self.fc1 = nn.Linear(16 * 13 * 13, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)``."""
        x = F.relu(self.conv1(x))
        x = F.avg_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.avg_pool2d(x, 2)

        # Keep the batch dimension, flatten channels and spatial dims.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x
121
+
122
class HandwrittenTextCNN(nn.Module):
    """Three-stage conv/batch-norm CNN followed by a four-layer MLP head.

    The first fully connected layer expects 8192 features, i.e. 128 channels
    at 8x8, which is what a 64x64 single-channel input yields after three
    2x2 max-pools.

    Args:
        num_classes: Number of output logits. Defaults to 68.
    """

    def __init__(self, num_classes=68):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

        # 128 channels * 8 * 8 spatial positions = 8192 flattened features.
        self.fc1 = nn.Linear(8192, 4096)
        self.fc2 = nn.Linear(4096, 2048)
        self.fc3 = nn.Linear(2048, 1024)
        self.fc4 = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)``."""
        x = self.pool(self.relu(self.bn1(self.conv1(x))))
        x = self.dropout(x)
        x = self.pool(self.relu(self.bn2(self.conv2(x))))
        x = self.dropout(x)
        x = self.pool(self.relu(self.bn3(self.conv3(x))))
        x = self.dropout(x)
        # For a 64x64 input x is (batch, 128, 8, 8) here; flatten to 8192.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.fc4(x)
        return x
153
+
154
# Models that have already been built and loaded, keyed by model choice, so
# repeated predictions do not rebuild the network (and, for the torch.hub
# architectures, re-resolve the hub repository) or re-read the checkpoint.
_MODEL_CACHE = {}


def load_model(model_choice):
    """Build the architecture for ``model_choice`` and load its weights.

    Args:
        model_choice: A key of ``MODEL_FILES``. Keys containing "cnn",
            "lenet" or "vgg" select those architectures; any other key
            falls through to ResNet18.

    Returns:
        The model, moved to ``device`` and switched to eval mode. The same
        instance is returned on later calls with the same choice.

    Raises:
        KeyError: If ``model_choice`` is not a key of ``MODEL_FILES``.
        FileNotFoundError: If the checkpoint file does not exist.
    """
    cached = _MODEL_CACHE.get(model_choice)
    if cached is not None:
        return cached

    model_path = MODEL_FILES[model_choice]
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file {model_path} not found.")

    if "cnn" in model_choice:
        model = HandwrittenTextCNN()

    elif "lenet" in model_choice:
        model = LeNet5()

    elif "vgg" in model_choice:
        # Untrained VGG11 skeleton, adapted to 1 input channel and 68 classes;
        # the real weights come from the state dict below.
        model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg11', pretrained=False)
        model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        model.classifier[-1] = nn.Linear(in_features=4096, out_features=68, bias=True)

    else:
        # Untrained ResNet18 skeleton, adapted to 1 input channel and 68 classes.
        model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)
        model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        model.fc = nn.Linear(model.fc.in_features, out_features=68, bias=True)

    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()

    _MODEL_CACHE[model_choice] = model
    return model
183
+
184
def predict(model_choice, image):
    """Classify an uploaded image with the selected model.

    Args:
        model_choice: Key of ``MODEL_FILES`` naming the model to use, or
            ``None`` when nothing has been selected.
        image: The uploaded image as a numpy array, or ``None`` when no
            image has been provided.

    Returns:
        A message string: the predicted class label, a prompt asking for
        the missing input, or ``"Error: ..."`` if anything fails.
    """
    if image is None:
        return "Please upload an image."
    if model_choice is None:
        return "Please select a model."

    try:
        # Load the selected model
        model = load_model(model_choice)

        # Process the image into a (1, C, H, W) batch on the target device.
        pil_image = Image.fromarray(image).convert("RGB")
        batch = transform(pil_image).unsqueeze(0).to(device)

        # Make prediction
        with torch.no_grad():
            outputs = model(batch)
            _, predicted = torch.max(outputs, 1)

        # class_names is keyed by zero-padded two-digit strings, so the
        # integer index must be formatted before the lookup.
        # NOTE(review): assumes output index i corresponds to class id
        # f"{i:02d}" - confirm against the label order used in training.
        predicted_class = class_names[f"{predicted.item():02d}"]

        return f"Predicted class: {predicted_class}"

    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        return f"Error: {str(e)}"
206
+
207
# Gradio interface: a model selector and an image upload wired to predict(),
# with the result shown as plain text.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(
            choices=list(MODEL_FILES),
            # Preselect the first model instead of starting with no selection.
            value=next(iter(MODEL_FILES)),
            label="Select Model",
        ),
        gr.Image(type="numpy", label="Upload Image"),
    ],
    outputs="text",
    title="Image Classification with PyTorch Models",
    description="Select a custom or pre-trained model and upload an image to get a classification prediction.",
)

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()
222
+
223
+
224
+
grayscale_custom_CNN_state_dict.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bfdc5f9f3d8392278e94ad9af1b8249215e73842ff48771ac2f0486cb9814277
3
+ size 176852631
grayscale_lenet_state_dict.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3d4f8b37e5cd05ece299477420ea2486e2decc119ca751ff76f399b1a02dabe
3
+ size 1376514
grayscale_resnet_state_dict.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8149f19e258d0f42da16570fbf4e036f08618aa71f1d7026eb2ef497a383a2d8
3
+ size 44901910
grayscale_vgg_state_dict.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef0e64faefa6f183c1b189543d7652e4c535eba9815edcb18599daeff141a9b2
3
+ size 516183630