Bliss-Ruth commited on
Commit
b6b49c5
·
verified ·
1 Parent(s): 3ab18de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -6
app.py CHANGED
@@ -13,7 +13,7 @@ import os
13
  print("🚀 Loading Ugandan Sign Language Model...")
14
 
15
  # ============================================================================
16
- # MODEL SETUP
17
  # ============================================================================
18
 
19
  class MinimalClassifier(nn.Module):
@@ -32,24 +32,66 @@ processor = XCLIPProcessor.from_pretrained("microsoft/xclip-base-patch32")
32
  xclip_model = XCLIPModel.from_pretrained("microsoft/xclip-base-patch32").to(device)
33
  xclip_model.eval()
34
 
35
- # Load your trained model
36
  try:
37
  checkpoint = torch.load("best_xclip_model.pth", map_location=device, weights_only=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  model = MinimalClassifier(
39
  input_dim=512,
40
- num_classes=checkpoint['num_classes'],
41
  dropout=0.5
42
  ).to(device)
43
- model.load_state_dict(checkpoint['model_state_dict'])
44
- model.eval()
45
 
46
- id_to_label = checkpoint['id_to_label']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  label_to_id = {v: k for k, v in id_to_label.items()}
48
 
 
49
  print(f"✅ Model loaded! Can recognize {len(id_to_label)} signs")
 
50
 
51
  except Exception as e:
52
  print(f"❌ Error loading model: {e}")
 
53
  exit(1)
54
 
55
  # ============================================================================
@@ -121,6 +163,7 @@ def predict_sign(video_path):
121
  return predicted_label, confidence_value
122
 
123
  except Exception as e:
 
124
  return "Unknown", 0.0
125
 
126
  # ============================================================================
 
13
  print("🚀 Loading Ugandan Sign Language Model...")
14
 
15
  # ============================================================================
16
+ # MODEL SETUP - FIXED VERSION
17
  # ============================================================================
18
 
19
  class MinimalClassifier(nn.Module):
 
32
  xclip_model = XCLIPModel.from_pretrained("microsoft/xclip-base-patch32").to(device)
33
  xclip_model.eval()
34
 
35
+ # Load your trained model - WITH ERROR HANDLING
36
  try:
37
  checkpoint = torch.load("best_xclip_model.pth", map_location=device, weights_only=False)
38
+
39
+ # DEBUG: Check what's in the checkpoint
40
+ print(f"🔍 Checkpoint keys: {list(checkpoint.keys())}")
41
+
42
+ # FIX: Handle missing 'num_classes' key
43
+ if 'num_classes' in checkpoint:
44
+ num_classes = checkpoint['num_classes']
45
+ else:
46
+ # Try to infer number of classes
47
+ if 'id_to_label' in checkpoint:
48
+ num_classes = len(checkpoint['id_to_label'])
49
+ elif 'label_to_id' in checkpoint:
50
+ num_classes = len(checkpoint['label_to_id'])
51
+ else:
52
+ # Count from model weights
53
+ for key in checkpoint.keys():
54
+ if 'model_state_dict' in checkpoint:
55
+ weight_key = [k for k in checkpoint['model_state_dict'].keys() if 'classifier' in k and 'weight' in k]
56
+ if weight_key:
57
+ num_classes = checkpoint['model_state_dict'][weight_key[0]].shape[0]
58
+ break
59
+ else:
60
+ num_classes = 85 # Default fallback
61
+
62
+ print(f"✅ Using num_classes: {num_classes}")
63
+
64
+ # Initialize model
65
  model = MinimalClassifier(
66
  input_dim=512,
67
+ num_classes=num_classes,
68
  dropout=0.5
69
  ).to(device)
 
 
70
 
71
+ # Load state dict
72
+ if 'model_state_dict' in checkpoint:
73
+ model.load_state_dict(checkpoint['model_state_dict'])
74
+ else:
75
+ # If checkpoint IS the state dict
76
+ model.load_state_dict(checkpoint)
77
+
78
+ # Load label mappings
79
+ if 'id_to_label' in checkpoint:
80
+ id_to_label = checkpoint['id_to_label']
81
+ else:
82
+ # Create default mapping
83
+ id_to_label = {i: f"class_{i}" for i in range(num_classes)}
84
+ print("⚠️ Created default label mapping")
85
+
86
  label_to_id = {v: k for k, v in id_to_label.items()}
87
 
88
+ model.eval()
89
  print(f"✅ Model loaded! Can recognize {len(id_to_label)} signs")
90
+ print(f"📊 Sample classes: {list(id_to_label.values())[:5]}")
91
 
92
  except Exception as e:
93
  print(f"❌ Error loading model: {e}")
94
+ print("💡 TIP: Make sure your model file has 'num_classes' or 'id_to_label' key")
95
  exit(1)
96
 
97
  # ============================================================================
 
163
  return predicted_label, confidence_value
164
 
165
  except Exception as e:
166
+ print(f"❌ Prediction error: {e}")
167
  return "Unknown", 0.0
168
 
169
  # ============================================================================