keysun89 commited on
Commit
84508af
·
verified ·
1 Parent(s): d07ff0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -2
app.py CHANGED
@@ -2,13 +2,40 @@ import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
- from torchvision import transforms
6
  from PIL import Image
7
  import numpy as np
8
 
9
  # Define your 4 classes
10
  CLASS_NAMES = ['Cover Drive', 'Pull Shot', 'Cut Shot', 'Straight Drive'] # Update with your actual class names
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # Custom CNN Model Definition
13
  class CricketShotCNN(nn.Module):
14
  def __init__(self, num_classes=4):
@@ -66,7 +93,9 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
66
  def load_models():
67
  try:
68
  # Load VGG16 fine-tuned model
69
- vgg16_model = torch.load('vgg16_finetuned.pth', map_location=device)
 
 
70
  vgg16_model.eval()
71
 
72
  # Load Custom CNN model
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
+ from torchvision import transforms, models
6
  from PIL import Image
7
  import numpy as np
8
 
9
  # Define your 4 classes
10
  CLASS_NAMES = ['Cover Drive', 'Pull Shot', 'Cut Shot', 'Straight Drive'] # Update with your actual class names
11
 
12
+ # VGG16 Fine-tuned Model Definition
13
+ class VGG16FineTuned(nn.Module):
14
+ def __init__(self, num_classes=4):
15
+ super(VGG16FineTuned, self).__init__()
16
+ # Load pre-trained VGG16 features
17
+ vgg16 = models.vgg16(pretrained=False)
18
+ self.features = vgg16.features
19
+ self.avgpool = vgg16.avgpool
20
+
21
+ # Custom classifier to match your architecture
22
+ self.classifier = nn.Sequential(
23
+ nn.Linear(25088, 1024),
24
+ nn.ReLU(),
25
+ nn.Dropout(p=0.5),
26
+ nn.Linear(1024, 512),
27
+ nn.ReLU(),
28
+ nn.Dropout(p=0.5),
29
+ nn.Linear(512, num_classes)
30
+ )
31
+
32
+ def forward(self, x):
33
+ x = self.features(x)
34
+ x = self.avgpool(x)
35
+ x = torch.flatten(x, 1)
36
+ x = self.classifier(x)
37
+ return x
38
+
39
  # Custom CNN Model Definition
40
  class CricketShotCNN(nn.Module):
41
  def __init__(self, num_classes=4):
 
93
  def load_models():
94
  try:
95
  # Load VGG16 fine-tuned model
96
+ vgg16_model = VGG16FineTuned(num_classes=4)
97
+ vgg16_model.load_state_dict(torch.load('vgg16_finetuned.pth', map_location=device))
98
+ vgg16_model.to(device)
99
  vgg16_model.eval()
100
 
101
  # Load Custom CNN model