keysun89 commited on
Commit
e12a3ad
·
verified ·
1 Parent(s): 676ee83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -9
app.py CHANGED
@@ -90,23 +90,48 @@ transform = transforms.Compose([
90
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
91
 
92
  def load_models():
 
 
 
 
93
  try:
94
  # Load VGG16 fine-tuned model
 
95
  vgg16_model = VGG16FineTuned(num_classes=4)
96
- vgg16_model.load_state_dict(torch.load('vgg16_finetuned.pth', map_location=device))
 
97
  vgg16_model.to(device)
98
  vgg16_model.eval()
99
-
 
 
 
 
 
 
 
 
100
  # Load Custom CNN model
 
101
  custom_cnn_model = CricketShotCNN(num_classes=4)
102
- custom_cnn_model.load_state_dict(torch.load('cricket_model.pth', map_location=device))
 
103
  custom_cnn_model.to(device)
104
  custom_cnn_model.eval()
105
-
106
- return vgg16_model, custom_cnn_model
 
 
107
  except Exception as e:
108
- print(f"Error loading models: {e}")
109
- return None, None
 
 
 
 
 
 
 
110
 
111
  vgg16_model, custom_cnn_model = load_models()
112
 
@@ -118,6 +143,9 @@ def predict(image):
118
  if vgg16_model is None or custom_cnn_model is None:
119
  return "Models not loaded properly", "Models not loaded properly"
120
 
 
 
 
121
  try:
122
  # Convert numpy array to PIL Image
123
  if isinstance(image, np.ndarray):
@@ -136,8 +164,8 @@ def predict(image):
136
  custom_cnn_probs = F.softmax(custom_cnn_output, dim=1)[0]
137
 
138
  # Create confidence dictionaries
139
- vgg16_confidence = {CLASS_NAMES[i]: float(vgg16_probs[i]) for i in range(len(CLASS_NAMES))}
140
- custom_cnn_confidence = {CLASS_NAMES[i]: float(custom_cnn_probs[i]) for i in range(len(CLASS_NAMES))}
141
 
142
  return vgg16_confidence, custom_cnn_confidence
143
 
 
90
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
91
 
92
  def load_models():
93
+ vgg16_model = None
94
+ custom_cnn_model = None
95
+ error_messages = []
96
+
97
  try:
98
  # Load VGG16 fine-tuned model
99
+ print("Loading VGG16 model...")
100
  vgg16_model = VGG16FineTuned(num_classes=4)
101
+ vgg16_state = torch.load('vgg16_finetuned.pth', map_location=device, weights_only=False)
102
+ vgg16_model.load_state_dict(vgg16_state)
103
  vgg16_model.to(device)
104
  vgg16_model.eval()
105
+ print("✓ VGG16 model loaded successfully")
106
+ except FileNotFoundError:
107
+ error_messages.append("VGG16: File 'vgg16_finetuned.pth' not found")
108
+ print("✗ VGG16 model file not found")
109
+ except Exception as e:
110
+ error_messages.append(f"VGG16: {str(e)}")
111
+ print(f"✗ VGG16 loading error: {e}")
112
+
113
+ try:
114
  # Load Custom CNN model
115
+ print("Loading Custom CNN model...")
116
  custom_cnn_model = CricketShotCNN(num_classes=4)
117
+ custom_cnn_state = torch.load('custom_cnn.pth', map_location=device, weights_only=False)
118
+ custom_cnn_model.load_state_dict(custom_cnn_state)
119
  custom_cnn_model.to(device)
120
  custom_cnn_model.eval()
121
+ print("✓ Custom CNN model loaded successfully")
122
+ except FileNotFoundError:
123
+ error_messages.append("Custom CNN: File 'custom_cnn.pth' not found")
124
+ print("✗ Custom CNN model file not found")
125
  except Exception as e:
126
+ error_messages.append(f"Custom CNN: {str(e)}")
127
+ print(f"✗ Custom CNN loading error: {e}")
128
+
129
+ if error_messages:
130
+ print("\n⚠️ Model Loading Errors:")
131
+ for msg in error_messages:
132
+ print(f" - {msg}")
133
+
134
+ return vgg16_model, custom_cnn_model
135
 
136
  vgg16_model, custom_cnn_model = load_models()
137
 
 
143
  if vgg16_model is None or custom_cnn_model is None:
144
  return "Models not loaded properly", "Models not loaded properly"
145
 
146
+ # Define class names here to ensure they're in scope
147
+ class_names = ['Cover Drive', 'Pull Shot', 'Cut Shot', 'Straight Drive']
148
+
149
  try:
150
  # Convert numpy array to PIL Image
151
  if isinstance(image, np.ndarray):
 
164
  custom_cnn_probs = F.softmax(custom_cnn_output, dim=1)[0]
165
 
166
  # Create confidence dictionaries
167
+ vgg16_confidence = {class_names[i]: float(vgg16_probs[i]) for i in range(len(class_names))}
168
+ custom_cnn_confidence = {class_names[i]: float(custom_cnn_probs[i]) for i in range(len(class_names))}
169
 
170
  return vgg16_confidence, custom_cnn_confidence
171