MoinulwithAI commited on
Commit
31c3d5a
·
verified ·
1 Parent(s): 670c205

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -36
app.py CHANGED
@@ -1,12 +1,11 @@
1
  import gradio as gr
2
  import torch
3
- from torchvision import transforms
4
- from PIL import Image
5
  import torch.nn as nn
 
 
6
  import os
7
- from torchvision import models
8
 
9
- # Custom Residual Block
10
  class ResidualBlock(nn.Module):
11
  def __init__(self, in_channels, out_channels):
12
  super(ResidualBlock, self).__init__()
@@ -28,7 +27,6 @@ class ResidualBlock(nn.Module):
28
  x = self.relu(x)
29
  return x
30
 
31
- # EfficientNet Model with Novelty (Residual Block)
32
  class EfficientNetWithNovelty(nn.Module):
33
  def __init__(self, num_classes):
34
  super(EfficientNetWithNovelty, self).__init__()
@@ -55,49 +53,58 @@ class EfficientNetWithNovelty(nn.Module):
55
 
56
  return x
57
 
58
- # Load the model and weights
59
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
-
61
- # Update this path with your model path
62
- model_path = 'best_model.pth'
63
- num_classes = 10 # Update based on your dataset
64
-
65
  model = EfficientNetWithNovelty(num_classes)
66
-
67
- # Load checkpoint instead of raw state_dict
68
- checkpoint = torch.load(model_path, map_location=device)
69
- model.load_state_dict(checkpoint["model_state_dict"]) # Correct key for model weights
70
  model.to(device)
71
  model.eval()
72
 
73
- # Define image transformations (same as during training)
74
  transform = transforms.Compose([
75
  transforms.Resize((224, 224)),
76
  transforms.ToTensor(),
77
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
78
  ])
79
 
80
- # Define the prediction function for Gradio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  def predict(image):
82
- image = Image.fromarray(image) # Convert numpy array to PIL Image
83
- image = transform(image).unsqueeze(0) # Apply transformations and add batch dimension
84
- image = image.to(device)
85
-
 
86
  with torch.no_grad():
87
- outputs = model(image)
88
- _, predicted = torch.max(outputs, 1)
89
-
90
- # Class names for your classification
91
- class_names = ['OUTSWING', 'STRAIGHT', 'BACK_OF_HAND', 'CARROM', 'CROSSSEAM',
92
- 'GOOGLY', 'INSWING', 'KNUCKLE', 'LEGSPIN', 'OFFSPIN']
93
- predicted_label = class_names[predicted.item()]
94
  return predicted_label
95
 
96
- # Create the Gradio Interface
97
- iface = gr.Interface(fn=predict,
98
- inputs=gr.Image(type="numpy"), # Accepts image input
99
- outputs=gr.Text(), # Output the predicted class label
100
- live=True) # live=True enables prediction while image is being uploaded
 
 
101
 
102
- # Launch the interface
103
- iface.launch()
 
1
  import gradio as gr
2
  import torch
 
 
3
  import torch.nn as nn
4
+ from torchvision import models, transforms
5
+ from PIL import Image
6
  import os
 
7
 
8
+ # Define the same custom residual block and EfficientNetWithNovelty model
9
  class ResidualBlock(nn.Module):
10
  def __init__(self, in_channels, out_channels):
11
  super(ResidualBlock, self).__init__()
 
27
  x = self.relu(x)
28
  return x
29
 
 
30
  class EfficientNetWithNovelty(nn.Module):
31
  def __init__(self, num_classes):
32
  super(EfficientNetWithNovelty, self).__init__()
 
53
 
54
  return x
55
 
56
+ # Load the model checkpoint
57
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
58
+ num_classes = 10 # Number of classes as per your dataset
 
 
 
 
59
  model = EfficientNetWithNovelty(num_classes)
60
+ checkpoint = torch.load('best_model2.pth')
61
+ model.load_state_dict(checkpoint['model_state_dict'])
 
 
62
  model.to(device)
63
  model.eval()
64
 
65
+ # Define image transformations for preprocessing
66
  transform = transforms.Compose([
67
  transforms.Resize((224, 224)),
68
  transforms.ToTensor(),
69
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
70
  ])
71
 
72
+ # Define the class labels explicitly
73
+ class_labels = [
74
+ "KNUCKLE",
75
+ "LEGSPIN",
76
+ "OFFSPIN",
77
+ "OUTSWING",
78
+ "STRAIGHT",
79
+ "BACK_OF_HAND",
80
+ "CARROM",
81
+ "CROSSSEAM",
82
+ "GOOGLY",
83
+ "INSWING"
84
+ ]
85
+
86
+ # Prediction function
87
  def predict(image):
88
+ # Preprocess image
89
+ image = Image.fromarray(image) # Convert numpy array to PIL Image if it's from Gradio
90
+ image = transform(image).unsqueeze(0).to(device)
91
+
92
+ # Get model predictions
93
  with torch.no_grad():
94
+ output = model(image)
95
+ _, predicted = torch.max(output, 1)
96
+
97
+ # Get predicted class label
98
+ predicted_label = class_labels[predicted.item()]
 
 
99
  return predicted_label
100
 
101
+ # Gradio interface
102
+ iface = gr.Interface(
103
+ fn=predict,
104
+ inputs=gr.Image(type="numpy", label="Upload Cricket Grip Image"),
105
+ outputs=gr.Textbox(label="Predicted Grip Type"),
106
+ live=True
107
+ )
108
 
109
+ if __name__ == "__main__":
110
+ iface.launch()