keysun89 commited on
Commit
da9036a
Β·
verified Β·
1 Parent(s): 56098d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -61
app.py CHANGED
@@ -1,17 +1,80 @@
1
  import gradio as gr
2
- import tensorflow as tf
3
- from tensorflow import keras
4
- import numpy as np
 
5
  from PIL import Image
 
6
 
7
- class_names = ['drive', 'legglance_flick', 'pullshot', 'sweep']
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # Load models
 
 
11
  def load_models():
12
  try:
13
- vgg16_model = keras.models.load_model('vgg16_finetuned.pth')
14
- custom_cnn_model = keras.models.load_model('cricket_model.pth')
 
 
 
 
 
 
 
 
15
  return vgg16_model, custom_cnn_model
16
  except Exception as e:
17
  print(f"Error loading models: {e}")
@@ -19,33 +82,6 @@ def load_models():
19
 
20
  vgg16_model, custom_cnn_model = load_models()
21
 
22
- def preprocess_image(image, target_size=(224, 224)):
23
- """Preprocess image for model prediction"""
24
- if image is None:
25
- return None
26
-
27
- # Convert to PIL Image if needed
28
- if not isinstance(image, Image.Image):
29
- image = Image.fromarray(image)
30
-
31
- # Resize image
32
- image = image.resize(target_size)
33
-
34
- # Convert to array and normalize
35
- img_array = np.array(image)
36
-
37
- # Handle grayscale images
38
- if len(img_array.shape) == 2:
39
- img_array = np.stack([img_array] * 3, axis=-1)
40
-
41
- # Add batch dimension
42
- img_array = np.expand_dims(img_array, axis=0)
43
-
44
- # Normalize to [0, 1]
45
- img_array = img_array.astype('float32') / 255.0
46
-
47
- return img_array
48
-
49
  def predict(image):
50
  """Make predictions with both models"""
51
  if image is None:
@@ -54,56 +90,69 @@ def predict(image):
54
  if vgg16_model is None or custom_cnn_model is None:
55
  return "Models not loaded properly", "Models not loaded properly"
56
 
57
- # Preprocess image
58
- processed_img = preprocess_image(image)
59
-
60
- # Get predictions from both models
61
- vgg16_pred = vgg16_model.predict(processed_img, verbose=0)[0]
62
- custom_cnn_pred = custom_cnn_model.predict(processed_img, verbose=0)[0]
63
-
64
- # Create confidence dictionaries
65
- vgg16_confidence = {CLASS_NAMES[i]: float(vgg16_pred[i]) for i in range(len(CLASS_NAMES))}
66
- custom_cnn_confidence = {CLASS_NAMES[i]: float(custom_cnn_pred[i]) for i in range(len(CLASS_NAMES))}
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- return vgg16_confidence, custom_cnn_confidence
 
 
69
 
70
  # Create Gradio interface
71
- with gr.Blocks(title="Dual Model Comparison") as demo:
72
  gr.Markdown(
73
  """
74
- # πŸ” Dual Model Image Classification
75
 
76
- Compare predictions from two models trained on the same dataset:
77
  - **VGG16 Fine-tuned**: Transfer learning model based on VGG16
78
  - **Custom CNN**: CNN trained from scratch
79
 
80
- Upload an image to see predictions and confidence scores from both models.
81
  """
82
  )
83
 
84
  with gr.Row():
85
  with gr.Column():
86
- input_image = gr.Image(label="Upload Image", type="numpy")
87
- predict_btn = gr.Button("Predict", variant="primary")
88
 
89
  with gr.Row():
90
  with gr.Column():
91
- gr.Markdown("### VGG16 Fine-tuned Model")
92
  vgg16_output = gr.Label(label="Predictions", num_top_classes=4)
93
 
94
  with gr.Column():
95
- gr.Markdown("### Custom CNN Model")
96
  custom_cnn_output = gr.Label(label="Predictions", num_top_classes=4)
97
 
98
- # Examples section (optional - add your example images)
99
- gr.Markdown("### Examples")
100
- gr.Examples(
101
- examples=[
102
- # Add paths to example images here
103
- # ["example1.jpg"],
104
- # ["example2.jpg"],
105
- ],
106
- inputs=input_image,
107
  )
108
 
109
  # Connect the prediction function
 
1
  import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchvision import transforms
6
  from PIL import Image
7
+ import numpy as np
8
 
9
+ # Define your 4 classes
10
+ CLASS_NAMES = ['Cover Drive', 'Pull Shot', 'Cut Shot', 'Straight Drive'] # Update with your actual class names
11
 
12
+ # Custom CNN Model Definition
13
+ class CricketShotCNN(nn.Module):
14
+ def __init__(self, num_classes=4):
15
+ super(CricketShotCNN, self).__init__()
16
+
17
+ # Block 1: Input (3, 224, 224) -> Output (64, 112, 112)
18
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
19
+ self.bn1 = nn.BatchNorm2d(64)
20
+
21
+ # Block 2: Output (128, 56, 56)
22
+ self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
23
+ self.bn2 = nn.BatchNorm2d(128)
24
+
25
+ # Block 3: Output (256, 28, 28)
26
+ self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
27
+ self.bn3 = nn.BatchNorm2d(256)
28
+
29
+ # Block 4: Output (512, 14, 14)
30
+ self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
31
+ self.bn4 = nn.BatchNorm2d(512)
32
+
33
+ self.pool = nn.MaxPool2d(2, 2)
34
+ self.dropout = nn.Dropout(0.5)
35
+
36
+ # Fully Connected Layers
37
+ self.fc1 = nn.Linear(512 * 14 * 14, 512)
38
+ self.fc2 = nn.Linear(512, 128)
39
+ self.fc3 = nn.Linear(128, num_classes)
40
+
41
+ def forward(self, x):
42
+ x = self.pool(F.relu(self.bn1(self.conv1(x))))
43
+ x = self.pool(F.relu(self.bn2(self.conv2(x))))
44
+ x = self.pool(F.relu(self.bn3(self.conv3(x))))
45
+ x = self.pool(F.relu(self.bn4(self.conv4(x))))
46
+
47
+ x = x.view(-1, 512 * 14 * 14)
48
+
49
+ x = F.relu(self.fc1(x))
50
+ x = self.dropout(x)
51
+ x = F.relu(self.fc2(x))
52
+ x = self.fc3(x)
53
+
54
+ return x
55
+
56
+ # Image preprocessing
57
+ transform = transforms.Compose([
58
+ transforms.Resize((224, 224)),
59
+ transforms.ToTensor(),
60
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
61
+ ])
62
 
63
  # Load models
64
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
65
+
66
  def load_models():
67
  try:
68
+ # Load VGG16 fine-tuned model
69
+ vgg16_model = torch.load('vgg16_finetuned.pth', map_location=device)
70
+ vgg16_model.eval()
71
+
72
+ # Load Custom CNN model
73
+ custom_cnn_model = CricketShotCNN(num_classes=4)
74
+ custom_cnn_model.load_state_dict(torch.load('custom_cnn.pth', map_location=device))
75
+ custom_cnn_model.to(device)
76
+ custom_cnn_model.eval()
77
+
78
  return vgg16_model, custom_cnn_model
79
  except Exception as e:
80
  print(f"Error loading models: {e}")
 
82
 
83
  vgg16_model, custom_cnn_model = load_models()
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  def predict(image):
86
  """Make predictions with both models"""
87
  if image is None:
 
90
  if vgg16_model is None or custom_cnn_model is None:
91
  return "Models not loaded properly", "Models not loaded properly"
92
 
93
+ try:
94
+ # Convert numpy array to PIL Image
95
+ if isinstance(image, np.ndarray):
96
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
97
+
98
+ # Preprocess image
99
+ img_tensor = transform(image).unsqueeze(0).to(device)
100
+
101
+ # Get predictions from both models
102
+ with torch.no_grad():
103
+ vgg16_output = vgg16_model(img_tensor)
104
+ custom_cnn_output = custom_cnn_model(img_tensor)
105
+
106
+ # Apply softmax to get probabilities
107
+ vgg16_probs = F.softmax(vgg16_output, dim=1)[0]
108
+ custom_cnn_probs = F.softmax(custom_cnn_output, dim=1)[0]
109
+
110
+ # Create confidence dictionaries
111
+ vgg16_confidence = {CLASS_NAMES[i]: float(vgg16_probs[i]) for i in range(len(CLASS_NAMES))}
112
+ custom_cnn_confidence = {CLASS_NAMES[i]: float(custom_cnn_probs[i]) for i in range(len(CLASS_NAMES))}
113
+
114
+ return vgg16_confidence, custom_cnn_confidence
115
 
116
+ except Exception as e:
117
+ print(f"Prediction error: {e}")
118
+ return f"Error: {str(e)}", f"Error: {str(e)}"
119
 
120
  # Create Gradio interface
121
+ with gr.Blocks(title="Cricket Shot Classification - Dual Model Comparison", theme=gr.themes.Soft()) as demo:
122
  gr.Markdown(
123
  """
124
+ # 🏏 Cricket Shot Classification - Dual Model Comparison
125
 
126
+ Compare predictions from two models trained on the same cricket shot dataset:
127
  - **VGG16 Fine-tuned**: Transfer learning model based on VGG16
128
  - **Custom CNN**: CNN trained from scratch
129
 
130
+ Upload an image of a cricket shot to see predictions and confidence scores from both models.
131
  """
132
  )
133
 
134
  with gr.Row():
135
  with gr.Column():
136
+ input_image = gr.Image(label="Upload Cricket Shot Image", type="numpy")
137
+ predict_btn = gr.Button("πŸ” Predict", variant="primary", size="lg")
138
 
139
  with gr.Row():
140
  with gr.Column():
141
+ gr.Markdown("### πŸ“Š VGG16 Fine-tuned Model")
142
  vgg16_output = gr.Label(label="Predictions", num_top_classes=4)
143
 
144
  with gr.Column():
145
+ gr.Markdown("### πŸ“Š Custom CNN Model")
146
  custom_cnn_output = gr.Label(label="Predictions", num_top_classes=4)
147
 
148
+ gr.Markdown(
149
+ """
150
+ ---
151
+ ### πŸ“ About the Models
152
+ - Both models are trained on the same cricket shot dataset with 4 classes
153
+ - Input image size: 224x224 pixels
154
+ - The predictions show probability scores for each cricket shot type
155
+ """
 
156
  )
157
 
158
  # Connect the prediction function