File size: 7,733 Bytes
bd4e104 da9036a 84508af bd4e104 da9036a bd4e104 676ee83 bd4e104 84508af da9036a bd4e104 da9036a bd4e104 e12a3ad bd4e104 da9036a e12a3ad 84508af e12a3ad 84508af da9036a e12a3ad da9036a e12a3ad da9036a e12a3ad da9036a e12a3ad bd4e104 e12a3ad bd4e104 e12a3ad 8ca6700 e12a3ad da9036a e12a3ad da9036a bd4e104 da9036a bd4e104 da9036a bd4e104 da9036a bd4e104 da9036a bd4e104 da9036a bd4e104 da9036a bd4e104 da9036a bd4e104 da9036a bd4e104 da9036a bd4e104 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 | import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import numpy as np
# Output class labels; position i names the models' i-th output logit
# (presumably the label order used at training time -- confirm).
class_names = [
    "drive",
    "legglance_flick",
    "pullshot",
    "sweep",
]
# VGG16 Fine-tuned Model Definition
class VGG16FineTuned(nn.Module):
    """VGG16 convolutional backbone with a custom three-layer classifier head.

    The backbone is built with random weights; the fine-tuned parameters are
    restored afterwards with ``load_state_dict`` (as ``load_models`` does).

    Args:
        num_classes: Number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        # Build the VGG16 architecture without any ImageNet weights (they are
        # overwritten by the checkpoint anyway). ``weights=None`` replaces the
        # ``pretrained=False`` keyword deprecated since torchvision 0.13.
        vgg16 = models.vgg16(weights=None)
        # Keep only the convolutional feature extractor and its adaptive pool;
        # torchvision's original classifier is discarded.
        self.features = vgg16.features
        self.avgpool = vgg16.avgpool
        # Custom head. Its layer order must stay exactly as is: the Linear
        # layers sit at indices 0, 3 and 6, which must match the
        # ``classifier.*`` keys in the saved checkpoint (load_state_dict is
        # strict by default).
        self.classifier = nn.Sequential(
            nn.Linear(25088, 1024),  # 25088 = 512 * 7 * 7 pooled VGG16 features
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch *x*."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        return self.classifier(x)
# Custom CNN Model Definition
class CricketShotCNN(nn.Module):
    """Four-block CNN trained from scratch for cricket shot classification.

    Each block is conv(3x3, padding=1) -> BatchNorm -> ReLU -> 2x2 max-pool,
    halving the spatial size. The fully connected head is sized for the
    512x14x14 map that a 224x224 input leaves after four blocks, so inputs
    must be 3x224x224 (the size the preprocessing transform produces).

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        # Block 1: Input (3, 224, 224) -> Output (64, 112, 112)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        # Block 2: Output (128, 56, 56)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        # Block 3: Output (256, 28, 28)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        # Block 4: Output (512, 14, 14)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        # Shared, parameter-free layers (no state_dict entries).
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)
        # Fully Connected Layers (input: flattened 512x14x14 feature map)
        self.fc1 = nn.Linear(512 * 14 * 14, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for a (batch, 3, 224, 224) input."""
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.pool(F.relu(self.bn4(self.conv4(x))))
        # Flatten per sample. ``view(-1, 512*14*14)`` would silently merge
        # several samples into one row when the input is not 224x224 but the
        # element count happens to divide; flatten keeps the batch dimension
        # so a wrong input size fails loudly at fc1 instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        return self.fc3(x)
# Preprocessing applied to every uploaded image before inference: resize to
# the 224x224 size CricketShotCNN's fully connected head is built for,
# convert to a float tensor, then normalize each channel with the standard
# ImageNet mean/std (presumably the same preprocessing used when training
# both checkpoints -- confirm).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Inference device: the first CUDA GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def _load_model(label, model_cls, path, error_messages):
    """Build *model_cls*, restore its weights from *path*, and return it in eval mode.

    On any failure an error string is appended to *error_messages*, a status
    line is printed, and ``None`` is returned -- never a network that was
    constructed but left with random initial weights.

    Args:
        label: Human-readable model name used in log and error messages.
        model_cls: ``nn.Module`` subclass whose constructor accepts ``num_classes``.
        path: Checkpoint file passed to ``torch.load``; loaded with
            ``load_state_dict``, so it must hold the model's state_dict.
        error_messages: List that collects error strings; mutated in place.

    Returns:
        The loaded model moved to ``device``, or ``None`` if loading failed.
    """
    print(f"Loading {label} model...")
    try:
        model = model_cls(num_classes=4)
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # Python objects. Only load checkpoints you trust; if these files hold
        # a plain state_dict, weights_only=True is the safer setting.
        state = torch.load(path, map_location=device, weights_only=False)
        model.load_state_dict(state)
        model.to(device)
        model.eval()
    except FileNotFoundError:
        error_messages.append(f"{label}: File '{path}' not found")
        print(f"β {label} model file not found")
        return None
    except Exception as e:
        error_messages.append(f"{label}: {str(e)}")
        print(f"β {label} loading error: {e}")
        return None
    print(f"β {label} model loaded successfully")
    return model


def load_models():
    """Load both classifiers from their checkpoint files in the working directory.

    A model whose checkpoint is missing or fails to load is returned as
    ``None`` (not as a randomly initialised network), so callers can detect
    the failure with an ``is None`` check. Collected errors are printed as a
    summary at the end.

    Returns:
        ``(vgg16_model, custom_cnn_model)``; either entry may be ``None``.
    """
    error_messages = []
    vgg16_model = _load_model(
        "VGG16", VGG16FineTuned, "vgg16_finetuned.pth", error_messages
    )
    custom_cnn_model = _load_model(
        "Custom CNN", CricketShotCNN, "custom_cnn.pth", error_messages
    )
    if error_messages:
        print("\nβ οΈ Model Loading Errors:")
        for msg in error_messages:
            print(f" - {msg}")
    return vgg16_model, custom_cnn_model


# Load both models once at import time; predict() reads these globals.
vgg16_model, custom_cnn_model = load_models()
def _class_confidences(model, img_tensor):
    """Return ``{class_name: probability}`` for a preprocessed one-image batch.

    Args:
        model: A loaded classifier in eval mode.
        img_tensor: Batch of one preprocessed image, already on ``device``.
    """
    with torch.no_grad():
        logits = model(img_tensor)
    # Softmax over the class dimension; [0] selects the single image.
    # tolist() copies to the host once and yields plain Python floats.
    probs = F.softmax(logits, dim=1)[0].tolist()
    return dict(zip(class_names, probs))


def predict(image):
    """Classify *image* with both models.

    Args:
        image: The uploaded image as a NumPy array (what the Gradio ``Image``
            input is configured to send) or a PIL image; ``None`` when
            nothing has been uploaded.

    Returns:
        A ``(vgg16_result, custom_cnn_result)`` pair for the two ``gr.Label``
        outputs. Each result is a ``{class_name: probability}`` dict, or an
        error string when that model is unavailable or inference failed.
        ``(None, None)`` when no image is given.
    """
    if image is None:
        return None, None
    loaded = (vgg16_model, custom_cnn_model)
    if all(model is None for model in loaded):
        return "Models not loaded properly", "Models not loaded properly"
    try:
        if isinstance(image, np.ndarray):
            # Let Pillow infer the mode from the array shape (L / RGB / RGBA)
            # instead of forcing 'RGB', which misreads non-3-channel data.
            image = Image.fromarray(image.astype('uint8'))
        # Normalise grayscale / RGBA / palette images to the 3 channels the
        # Normalize transform and both networks expect.
        image = image.convert('RGB')
        img_tensor = transform(image).unsqueeze(0).to(device)
        # Score each model independently so a single missing checkpoint does
        # not hide the other model's predictions.
        return tuple(
            "Model not loaded properly" if model is None
            else _class_confidences(model, img_tensor)
            for model in loaded
        )
    except Exception as e:
        print(f"Prediction error: {e}")
        return f"Error: {str(e)}", f"Error: {str(e)}"
# Gradio UI: an image upload and Predict button, followed by side-by-side
# label panels showing each model's class probabilities.
_HEADER_MD = """
# π Cricket Shot Classification - Dual Model Comparison
Compare predictions from two models trained on the same cricket shot dataset:
- **VGG16 Fine-tuned**: Transfer learning model based on VGG16
- **Custom CNN**: CNN trained from scratch
Upload an image of a cricket shot to see predictions and confidence scores from both models.
"""

_FOOTER_MD = """
---
### π About the Models
- Both models are trained on the same cricket shot dataset with 4 classes
- Input image size: 224x224 pixels
- The predictions show probability scores for each cricket shot type
"""


def _prediction_panel(heading):
    """Add a column holding *heading* and a 4-class Label; return the Label."""
    with gr.Column():
        gr.Markdown(heading)
        return gr.Label(label="Predictions", num_top_classes=4)


with gr.Blocks(
    title="Cricket Shot Classification - Dual Model Comparison",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload Cricket Shot Image", type="numpy")
            predict_btn = gr.Button("π Predict", variant="primary", size="lg")

    with gr.Row():
        vgg16_output = _prediction_panel("### π VGG16 Fine-tuned Model")
        custom_cnn_output = _prediction_panel("### π Custom CNN Model")

    gr.Markdown(_FOOTER_MD)

    # Clicking Predict sends the uploaded image to predict(), whose two
    # return values fill the two label panels in order.
    predict_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=[vgg16_output, custom_cnn_output],
    )

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()