Update app.py
Browse files
app.py
CHANGED
|
@@ -90,23 +90,48 @@ transform = transforms.Compose([
|
|
| 90 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 91 |
|
| 92 |
def load_models():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
try:
|
| 94 |
# Load VGG16 fine-tuned model
|
|
|
|
| 95 |
vgg16_model = VGG16FineTuned(num_classes=4)
|
| 96 |
-
|
|
|
|
| 97 |
vgg16_model.to(device)
|
| 98 |
vgg16_model.eval()
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
# Load Custom CNN model
|
|
|
|
| 101 |
custom_cnn_model = CricketShotCNN(num_classes=4)
|
| 102 |
-
|
|
|
|
| 103 |
custom_cnn_model.to(device)
|
| 104 |
custom_cnn_model.eval()
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
| 107 |
except Exception as e:
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
vgg16_model, custom_cnn_model = load_models()
|
| 112 |
|
|
@@ -118,6 +143,9 @@ def predict(image):
|
|
| 118 |
if vgg16_model is None or custom_cnn_model is None:
|
| 119 |
return "Models not loaded properly", "Models not loaded properly"
|
| 120 |
|
|
|
|
|
|
|
|
|
|
| 121 |
try:
|
| 122 |
# Convert numpy array to PIL Image
|
| 123 |
if isinstance(image, np.ndarray):
|
|
@@ -136,8 +164,8 @@ def predict(image):
|
|
| 136 |
custom_cnn_probs = F.softmax(custom_cnn_output, dim=1)[0]
|
| 137 |
|
| 138 |
# Create confidence dictionaries
|
| 139 |
-
vgg16_confidence = {
|
| 140 |
-
custom_cnn_confidence = {
|
| 141 |
|
| 142 |
return vgg16_confidence, custom_cnn_confidence
|
| 143 |
|
|
|
|
| 90 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 91 |
|
| 92 |
def load_models():
|
| 93 |
+
vgg16_model = None
|
| 94 |
+
custom_cnn_model = None
|
| 95 |
+
error_messages = []
|
| 96 |
+
|
| 97 |
try:
|
| 98 |
# Load VGG16 fine-tuned model
|
| 99 |
+
print("Loading VGG16 model...")
|
| 100 |
vgg16_model = VGG16FineTuned(num_classes=4)
|
| 101 |
+
vgg16_state = torch.load('vgg16_finetuned.pth', map_location=device, weights_only=False)
|
| 102 |
+
vgg16_model.load_state_dict(vgg16_state)
|
| 103 |
vgg16_model.to(device)
|
| 104 |
vgg16_model.eval()
|
| 105 |
+
print("✓ VGG16 model loaded successfully")
|
| 106 |
+
except FileNotFoundError:
|
| 107 |
+
error_messages.append("VGG16: File 'vgg16_finetuned.pth' not found")
|
| 108 |
+
print("✗ VGG16 model file not found")
|
| 109 |
+
except Exception as e:
|
| 110 |
+
error_messages.append(f"VGG16: {str(e)}")
|
| 111 |
+
print(f"✗ VGG16 loading error: {e}")
|
| 112 |
+
|
| 113 |
+
try:
|
| 114 |
# Load Custom CNN model
|
| 115 |
+
print("Loading Custom CNN model...")
|
| 116 |
custom_cnn_model = CricketShotCNN(num_classes=4)
|
| 117 |
+
custom_cnn_state = torch.load('custom_cnn.pth', map_location=device, weights_only=False)
|
| 118 |
+
custom_cnn_model.load_state_dict(custom_cnn_state)
|
| 119 |
custom_cnn_model.to(device)
|
| 120 |
custom_cnn_model.eval()
|
| 121 |
+
print("✓ Custom CNN model loaded successfully")
|
| 122 |
+
except FileNotFoundError:
|
| 123 |
+
error_messages.append("Custom CNN: File 'custom_cnn.pth' not found")
|
| 124 |
+
print("✗ Custom CNN model file not found")
|
| 125 |
except Exception as e:
|
| 126 |
+
error_messages.append(f"Custom CNN: {str(e)}")
|
| 127 |
+
print(f"✗ Custom CNN loading error: {e}")
|
| 128 |
+
|
| 129 |
+
if error_messages:
|
| 130 |
+
print("\n⚠️ Model Loading Errors:")
|
| 131 |
+
for msg in error_messages:
|
| 132 |
+
print(f" - {msg}")
|
| 133 |
+
|
| 134 |
+
return vgg16_model, custom_cnn_model
|
| 135 |
|
| 136 |
vgg16_model, custom_cnn_model = load_models()
|
| 137 |
|
|
|
|
| 143 |
if vgg16_model is None or custom_cnn_model is None:
|
| 144 |
return "Models not loaded properly", "Models not loaded properly"
|
| 145 |
|
| 146 |
+
# Define class names here to ensure they're in scope
|
| 147 |
+
class_names = ['Cover Drive', 'Pull Shot', 'Cut Shot', 'Straight Drive']
|
| 148 |
+
|
| 149 |
try:
|
| 150 |
# Convert numpy array to PIL Image
|
| 151 |
if isinstance(image, np.ndarray):
|
|
|
|
| 164 |
custom_cnn_probs = F.softmax(custom_cnn_output, dim=1)[0]
|
| 165 |
|
| 166 |
# Create confidence dictionaries
|
| 167 |
+
vgg16_confidence = {class_names[i]: float(vgg16_probs[i]) for i in range(len(class_names))}
|
| 168 |
+
custom_cnn_confidence = {class_names[i]: float(custom_cnn_probs[i]) for i in range(len(class_names))}
|
| 169 |
|
| 170 |
return vgg16_confidence, custom_cnn_confidence
|
| 171 |
|