Ponleur commited on
Commit
2762e74
·
verified ·
1 Parent(s): 7321a1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -12,10 +12,10 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
  # List of available model files
14
  MODEL_FILES = {
15
- "lenet": "grayscale_lenet_state_dict.pt",
16
- "cnn": "grayscale_custom_CNN_state_dict.pt",
17
- "resnet": "grayscale_resnet_state_dict.pt",
18
- "vgg": "grayscale_vgg_state_dict.pt"
19
  }
20
 
21
  # Replace with your actual class names
@@ -157,14 +157,14 @@ def load_model(model_choice):
157
  raise FileNotFoundError(f"Model file {model_path} not found.")
158
 
159
 
160
- if "cnn" in model_choice:
161
  # Load custom model
162
  model = HandwrittenTextCNN()
163
 
164
- elif "lenet" in model_choice:
165
  model = LeNet5()
166
 
167
- elif "vgg" in model_choice:
168
  model = torch.hub.load('pytorch/vision:v0.10.0','vgg11', pretrained=False)
169
  model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
170
  model.classifier[-1] = nn.Linear(in_features=4096, out_features=68, bias=True)
@@ -200,7 +200,7 @@ def predict(model_choice, image):
200
  predicted = torch.argmax(outputs,dim=1)
201
  predicted_class = class_names[f"{predicted.item():02d}"]
202
 
203
- return f"Predicted class: {predicted_class}"
204
 
205
  except Exception as e:
206
  return f"Error: {str(e)}"
@@ -213,7 +213,7 @@ iface = gr.Interface(
213
  gr.Image(type="pil", label="Upload Image")
214
  ],
215
  outputs="text",
216
- title="Image Classification with PyTorch Models",
217
  description="Select a custom or pre-trained model and upload an image to get a classification prediction."
218
  )
219
 
 
12
 
13
  # List of available model files
14
  MODEL_FILES = {
15
+ "LeNet": "grayscale_lenet_state_dict.pt",
16
+ "CNN": "grayscale_custom_CNN_state_dict.pt",
17
+ "ResNet": "grayscale_resnet_state_dict.pt",
18
+ "VGG": "grayscale_vgg_state_dict.pt"
19
  }
20
 
21
  # Replace with your actual class names
 
157
  raise FileNotFoundError(f"Model file {model_path} not found.")
158
 
159
 
160
+ if "CNN" in model_choice:
161
  # Load custom model
162
  model = HandwrittenTextCNN()
163
 
164
+ elif "LeNet" in model_choice:
165
  model = LeNet5()
166
 
167
+ elif "VGG" in model_choice:
168
  model = torch.hub.load('pytorch/vision:v0.10.0','vgg11', pretrained=False)
169
  model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
170
  model.classifier[-1] = nn.Linear(in_features=4096, out_features=68, bias=True)
 
200
  predicted = torch.argmax(outputs,dim=1)
201
  predicted_class = class_names[f"{predicted.item():02d}"]
202
 
203
+ return f"{predicted_class}"
204
 
205
  except Exception as e:
206
  return f"Error: {str(e)}"
 
213
  gr.Image(type="pil", label="Upload Image")
214
  ],
215
  outputs="text",
216
+ title="Burapha-TH Character dataset classification",
217
  description="Select a custom or pre-trained model and upload an image to get a classification prediction."
218
  )
219