import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import numpy as np

# Output labels, indexed by classifier logit position.
# NOTE(review): assumed to match the label order used at training time — confirm.
class_names = ['drive', 'legglance_flick', 'pullshot', 'sweep']


# VGG16 Fine-tuned Model Definition
class VGG16FineTuned(nn.Module):
    """VGG16 convolutional backbone followed by a custom three-layer classifier.

    Args:
        num_classes: Number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        # Backbone weights come from the fine-tuned checkpoint loaded later, so no
        # ImageNet weights are fetched here. ``weights=None`` is the replacement for
        # the deprecated ``pretrained=False`` argument.
        vgg16 = models.vgg16(weights=None)
        self.features = vgg16.features
        self.avgpool = vgg16.avgpool
        # Custom classifier; layer order must stay as-is so the checkpoint's
        # ``classifier.<idx>`` state-dict keys still line up.
        # 25088 == 512 * 7 * 7, the flattened size of VGG16's pooled feature map.
        self.classifier = nn.Sequential(
            nn.Linear(25088, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return raw class logits for a batch of images ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)


# Custom CNN Model Definition
class CricketShotCNN(nn.Module):
    """Four conv-BN-ReLU-maxpool blocks followed by three fully connected layers.

    The size of ``fc1`` is fixed at 512 * 14 * 14, i.e. this network only accepts
    3 x 224 x 224 inputs (224 halved by four 2x2 max-pools gives 14).

    Args:
        num_classes: Number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        # Block 1: Input (3, 224, 224) -> Output (64, 112, 112)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        # Block 2: Output (128, 56, 56)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        # Block 3: Output (256, 28, 28)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        # Block 4: Output (512, 14, 14)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)

        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)

        # Fully Connected Layers
        self.fc1 = nn.Linear(512 * 14 * 14, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw class logits for a batch of 3 x 224 x 224 images ``x``."""
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.pool(F.relu(self.bn4(self.conv4(x))))
        # Flatten per sample (keeps the batch dimension). Unlike view(-1, ...),
        # a wrong input size raises in fc1 instead of silently regrouping rows.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# Image preprocessing: resize to the 224x224 both networks expect, then apply
# the ImageNet channel mean/std normalisation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load models
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_checkpoint(label, build_model, path, error_messages):
    """Build a model, load its state dict from ``path``, and prepare it for inference.

    Args:
        label: Human-readable model name used in log/error messages.
        build_model: Zero-argument callable returning a fresh ``nn.Module``.
        path: Checkpoint file containing the model's state dict.
        error_messages: List that a failure description is appended to.

    Returns:
        The model moved to ``device`` and switched to eval mode, or ``None`` if
        construction or loading failed. Returning ``None`` (rather than the
        half-initialised model) is what lets ``predict`` detect the failure
        instead of serving random weights.
    """
    try:
        print(f"Loading {label} model...")
        model = build_model()
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects; only load checkpoints from a trusted source. If these files
        # are plain state dicts, weights_only=True would be safer.
        state = torch.load(path, map_location=device, weights_only=False)
        model.load_state_dict(state)
        model.to(device)
        model.eval()
    except FileNotFoundError:
        error_messages.append(f"{label}: File '{path}' not found")
        print(f"✗ {label} model file not found")
        return None
    except Exception as e:
        error_messages.append(f"{label}: {e}")
        print(f"✗ {label} loading error: {e}")
        return None
    print(f"✓ {label} model loaded successfully")
    return model


def load_models():
    """Build both classifiers and load their trained weights.

    Returns:
        ``(vgg16_model, custom_cnn_model)``. Either element is ``None`` if its
        checkpoint could not be loaded; failures are printed, not raised.
    """
    error_messages = []
    vgg16_model = _load_checkpoint(
        "VGG16",
        lambda: VGG16FineTuned(num_classes=len(class_names)),
        'vgg16_finetuned.pth',
        error_messages,
    )
    custom_cnn_model = _load_checkpoint(
        "Custom CNN",
        lambda: CricketShotCNN(num_classes=len(class_names)),
        'custom_cnn.pth',
        error_messages,
    )

    if error_messages:
        print("\n⚠️ Model Loading Errors:")
        for msg in error_messages:
            print(f" - {msg}")

    return vgg16_model, custom_cnn_model


vgg16_model, custom_cnn_model = load_models()


def _to_confidences(logits):
    """Turn a single-image logit tensor of shape (1, num_classes) into {class_name: probability}."""
    probs = F.softmax(logits, dim=1)[0].tolist()
    return dict(zip(class_names, probs))


def predict(image):
    """Make predictions with both models.

    Args:
        image: The uploaded image as a numpy array (the Gradio input uses
            ``type="numpy"``) or a PIL image; ``None`` when nothing is uploaded.

    Returns:
        A pair ``(vgg16_result, custom_cnn_result)`` for the two ``gr.Label``
        outputs. Each element is a ``{class_name: probability}`` dict. It is
        instead an explanatory string when that model failed to load or
        prediction raised. ``(None, None)`` is returned when no image is given.
    """
    if image is None:
        return None, None
    if vgg16_model is None and custom_cnn_model is None:
        return "Models not loaded properly", "Models not loaded properly"

    try:
        if isinstance(image, np.ndarray):
            # A trailing singleton channel axis cannot be mapped to a PIL mode.
            if image.ndim == 3 and image.shape[2] == 1:
                image = image[:, :, 0]
            # Let PIL infer the mode from the array shape instead of forcing 'RGB'.
            image = Image.fromarray(image.astype('uint8'))
        # Normalise grayscale/RGBA/palette images to the 3 channels that
        # transforms.Normalize above expects.
        image = image.convert('RGB')

        img_tensor = transform(image).unsqueeze(0).to(device)

        # Run whichever models loaded; a missing one reports its own message
        # instead of disabling the other.
        results = []
        with torch.no_grad():
            for model in (vgg16_model, custom_cnn_model):
                if model is None:
                    results.append("Models not loaded properly")
                else:
                    results.append(_to_confidences(model(img_tensor)))
        return tuple(results)

    except Exception as e:
        print(f"Prediction error: {e}")
        return f"Error: {str(e)}", f"Error: {str(e)}"


# Create Gradio interface
with gr.Blocks(title="Cricket Shot Classification - Dual Model Comparison", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🏏 Cricket Shot Classification - Dual Model Comparison

        Compare predictions from two models trained on the same cricket shot dataset:
        - **VGG16 Fine-tuned**: Transfer learning model based on VGG16
        - **Custom CNN**: CNN trained from scratch

        Upload an image of a cricket shot to see predictions and confidence scores from both models.
        """
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload Cricket Shot Image", type="numpy")
            predict_btn = gr.Button("🔍 Predict", variant="primary", size="lg")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 📊 VGG16 Fine-tuned Model")
            vgg16_output = gr.Label(label="Predictions", num_top_classes=len(class_names))
        with gr.Column():
            gr.Markdown("### 📊 Custom CNN Model")
            custom_cnn_output = gr.Label(label="Predictions", num_top_classes=len(class_names))

    gr.Markdown(
        """
        ---
        ### 📝 About the Models
        - Both models are trained on the same cricket shot dataset with 4 classes
        - Input image size: 224x224 pixels
        - The predictions show probability scores for each cricket shot type
        """
    )

    # Connect the prediction function
    predict_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=[vgg16_output, custom_cnn_output],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()