keysun89 commited on
Commit
676ee83
·
verified ·
1 Parent(s): 84508af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -6,8 +6,7 @@ from torchvision import transforms, models
6
  from PIL import Image
7
  import numpy as np
8
 
9
- # Define your 4 classes
10
- CLASS_NAMES = ['Cover Drive', 'Pull Shot', 'Cut Shot', 'Straight Drive'] # Update with your actual class names
11
 
12
  # VGG16 Fine-tuned Model Definition
13
  class VGG16FineTuned(nn.Module):
@@ -100,7 +99,7 @@ def load_models():
100
 
101
  # Load Custom CNN model
102
  custom_cnn_model = CricketShotCNN(num_classes=4)
103
- custom_cnn_model.load_state_dict(torch.load('custom_cnn.pth', map_location=device))
104
  custom_cnn_model.to(device)
105
  custom_cnn_model.eval()
106
 
 
6
  from PIL import Image
7
  import numpy as np
8
 
9
+ class_names = ['drive', 'legglance_flick', 'pullshot', 'sweep']
 
10
 
11
  # VGG16 Fine-tuned Model Definition
12
  class VGG16FineTuned(nn.Module):
 
99
 
100
  # Load Custom CNN model
101
  custom_cnn_model = CricketShotCNN(num_classes=4)
102
+ custom_cnn_model.load_state_dict(torch.load('cricket_model.pth', map_location=device))
103
  custom_cnn_model.to(device)
104
  custom_cnn_model.eval()
105