aditya-g07 commited on
Commit
c4c0f29
Β·
1 Parent(s): a0cfc96

Improve model loading with better error handling and fallback mechanisms

Browse files
Files changed (1) hide show
  1. app.py +69 -17
app.py CHANGED
@@ -77,7 +77,6 @@ def load_models():
77
  }
78
 
79
  # Check if model files exist
80
- import os
81
  if not os.path.exists('mobilenet0.25_Final.pth'):
82
  print("❌ mobilenet0.25_Final.pth not found!")
83
  return False
@@ -87,28 +86,69 @@ def load_models():
87
 
88
  print("Model files found, loading MobileNet...")
89
 
90
- # Load MobileNet model
91
- mobilenet_model = RetinaFace(cfg=mobilenet_cfg, phase='test')
92
- mobilenet_model.load_state_dict(torch.load('mobilenet0.25_Final.pth', map_location=device))
93
- mobilenet_model.eval()
94
- mobilenet_model = mobilenet_model.to(device)
95
- print("βœ… MobileNet model loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  print("Loading ResNet...")
98
 
99
- # Load ResNet model
100
- resnet_model = RetinaFace(cfg=resnet_cfg, phase='test')
101
- resnet_model.load_state_dict(torch.load('Resnet50_Final.pth', map_location=device))
102
- resnet_model.eval()
103
- resnet_model = resnet_model.to(device)
104
- print("βœ… ResNet model loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
- print("βœ… All models loaded successfully!")
107
- return True
 
 
 
 
 
108
 
109
  except Exception as e:
110
  import traceback
111
- print(f"❌ Error loading models: {e}")
112
  print(f"❌ Full traceback: {traceback.format_exc()}")
113
  return False
114
 
@@ -127,6 +167,12 @@ def detect_faces(image, model_type="mobilenet", confidence_threshold=0.5, nms_th
127
  'clip': False,
128
  'image_size': 840
129
  }
 
 
 
 
 
 
130
  else:
131
  model = mobilenet_model
132
  cfg = {
@@ -136,9 +182,15 @@ def detect_faces(image, model_type="mobilenet", confidence_threshold=0.5, nms_th
136
  'clip': False,
137
  'image_size': 640
138
  }
 
 
 
 
 
 
139
 
140
  if model is None:
141
- return None, "Models not loaded"
142
 
143
  # Convert PIL to numpy array
144
  if isinstance(image, Image.Image):
 
77
  }
78
 
79
  # Check if model files exist
 
80
  if not os.path.exists('mobilenet0.25_Final.pth'):
81
  print("❌ mobilenet0.25_Final.pth not found!")
82
  return False
 
86
 
87
  print("Model files found, loading MobileNet...")
88
 
89
+ # Load MobileNet model with better error handling
90
+ try:
91
+ mobilenet_model = RetinaFace(cfg=mobilenet_cfg, phase='test')
92
+ print("βœ… MobileNet model instance created")
93
+
94
+ # Load state dict
95
+ mobilenet_state = torch.load('mobilenet0.25_Final.pth', map_location=device)
96
+ print(f"βœ… MobileNet state dict loaded with {len(mobilenet_state.keys())} keys")
97
+
98
+ # Try to load state dict with strict=False to handle key mismatches
99
+ missing_keys, unexpected_keys = mobilenet_model.load_state_dict(mobilenet_state, strict=False)
100
+
101
+ if missing_keys:
102
+ print(f"⚠️ Missing keys in MobileNet: {missing_keys[:5]}...") # Show first 5
103
+ if unexpected_keys:
104
+ print(f"⚠️ Unexpected keys in MobileNet: {unexpected_keys[:5]}...") # Show first 5
105
+
106
+ mobilenet_model.eval()
107
+ mobilenet_model = mobilenet_model.to(device)
108
+ print("βœ… MobileNet model loaded successfully!")
109
+
110
+ except Exception as e:
111
+ print(f"❌ Error loading MobileNet: {e}")
112
+ mobilenet_model = None
113
 
114
  print("Loading ResNet...")
115
 
116
+ # Load ResNet model with better error handling
117
+ try:
118
+ resnet_model = RetinaFace(cfg=resnet_cfg, phase='test')
119
+ print("βœ… ResNet model instance created")
120
+
121
+ # Load state dict
122
+ resnet_state = torch.load('Resnet50_Final.pth', map_location=device)
123
+ print(f"βœ… ResNet state dict loaded with {len(resnet_state.keys())} keys")
124
+
125
+ # Try to load state dict with strict=False to handle key mismatches
126
+ missing_keys, unexpected_keys = resnet_model.load_state_dict(resnet_state, strict=False)
127
+
128
+ if missing_keys:
129
+ print(f"⚠️ Missing keys in ResNet: {missing_keys[:5]}...") # Show first 5
130
+ if unexpected_keys:
131
+ print(f"⚠️ Unexpected keys in ResNet: {unexpected_keys[:5]}...") # Show first 5
132
+
133
+ resnet_model.eval()
134
+ resnet_model = resnet_model.to(device)
135
+ print("βœ… ResNet model loaded successfully!")
136
+
137
+ except Exception as e:
138
+ print(f"❌ Error loading ResNet: {e}")
139
+ resnet_model = None
140
 
141
+ # Check if at least one model loaded
142
+ if mobilenet_model is not None or resnet_model is not None:
143
+ print("βœ… At least one model loaded successfully!")
144
+ return True
145
+ else:
146
+ print("❌ No models loaded successfully!")
147
+ return False
148
 
149
  except Exception as e:
150
  import traceback
151
+ print(f"❌ Error in load_models: {e}")
152
  print(f"❌ Full traceback: {traceback.format_exc()}")
153
  return False
154
 
 
167
  'clip': False,
168
  'image_size': 840
169
  }
170
+ if model is None:
171
+ # Fallback to MobileNet if ResNet not available
172
+ print("⚠️ ResNet not available, falling back to MobileNet")
173
+ model = mobilenet_model
174
+ model_type = "mobilenet"
175
+ cfg['image_size'] = 640
176
  else:
177
  model = mobilenet_model
178
  cfg = {
 
182
  'clip': False,
183
  'image_size': 640
184
  }
185
+ if model is None:
186
+ # Fallback to ResNet if MobileNet not available
187
+ print("⚠️ MobileNet not available, falling back to ResNet")
188
+ model = resnet_model
189
+ model_type = "resnet"
190
+ cfg['image_size'] = 840
191
 
192
  if model is None:
193
+ return None, "❌ No models are loaded. Please check the model loading logs."
194
 
195
  # Convert PIL to numpy array
196
  if isinstance(image, Image.Image):