Shilpaj commited on
Commit
de50636
·
1 Parent(s): aa63283

Fix: Model loading error

Browse files
Files changed (1) hide show
  1. app.py +23 -19
app.py CHANGED
@@ -20,41 +20,45 @@ def load_model(model_path: str):
20
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
  print(f"Using device: {device}")
22
 
23
- # Load the model with default weights first
24
- model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
25
  model = model.to(device)
26
 
27
  # Load custom weights
28
  state_dict = torch.load(model_path, map_location=device)
29
 
30
- # Debug: Print state dict info
31
- print("\nState dict keys:", list(state_dict['model_state_dict'].keys())[:5])
32
- print("Model state dict keys:", list(model.state_dict().keys())[:5])
33
-
34
- # Check if the final layer weights match
35
- fc_weight_shape = state_dict['model_state_dict']['fc.weight'].shape
36
- print(f"\nFC layer weight shape: {fc_weight_shape}")
37
 
38
- filtered_state_dict = {k: v for k, v in state_dict['model_state_dict'].items() if k in model.state_dict()}
39
- print(f"Filtered state dict size: {len(filtered_state_dict)} / {len(state_dict['model_state_dict'])}")
 
 
 
40
 
41
- model.load_state_dict(filtered_state_dict, strict=False)
42
- model.eval()
 
43
 
44
- # Verify model
45
- print("\nModel architecture:")
46
- print(model)
 
 
 
 
47
 
 
48
  return model
49
 
50
 
51
  def load_classes():
52
  """
53
- Load the classes.
54
  """
55
- # Load classes from the same weights version as the model was trained with
56
- weights = models.ResNet50_Weights.IMAGENET1K_V1 # Try V1 instead of V2
57
  classes = weights.meta["categories"]
 
58
  return classes
59
 
60
 
 
20
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
  print(f"Using device: {device}")
22
 
23
+ # Initialize a fresh model without pretrained weights
24
+ model = models.resnet50(weights=None)
25
  model = model.to(device)
26
 
27
  # Load custom weights
28
  state_dict = torch.load(model_path, map_location=device)
29
 
30
+ # Debug: Print original state dict keys
31
+ print("\nOriginal state dict keys:", list(state_dict['model_state_dict'].keys())[:5])
 
 
 
 
 
32
 
33
+ # Remove the 'model.' prefix from state dict keys
34
+ new_state_dict = {}
35
+ for key, value in state_dict['model_state_dict'].items():
36
+ new_key = key.replace('model.', '')
37
+ new_state_dict[new_key] = value
38
 
39
+ # Debug: Print modified state dict keys
40
+ print("Modified state dict keys:", list(new_state_dict.keys())[:5])
41
+ print("Model state dict keys:", list(model.state_dict().keys())[:5])
42
 
43
+ # Load the modified state dict
44
+ try:
45
+ model.load_state_dict(new_state_dict)
46
+ print("Successfully loaded model weights")
47
+ except Exception as e:
48
+ print(f"Error loading state dict: {str(e)}")
49
+ raise e
50
 
51
+ model.eval()
52
  return model
53
 
54
 
55
  def load_classes():
56
  """
57
+ Load the ImageNet classes
58
  """
59
+ weights = models.ResNet50_Weights.IMAGENET1K_V1
 
60
  classes = weights.meta["categories"]
61
+ print(f"Loaded {len(classes)} classes")
62
  return classes
63
 
64