FrAnKu34t23 commited on
Commit
c99d892
·
verified ·
1 Parent(s): b2baace

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -71
app.py CHANGED
@@ -14,82 +14,132 @@ import os
14
  # Import our model architecture
15
  from models import create_model
16
 
 
 
 
 
 
 
 
17
  # Configuration
18
- MODEL_PATH = "best_model.pth"
19
- CLASS_NAMES_PATH = "class_names.json"
 
 
 
20
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
 
22
  # Load class names
23
- with open(CLASS_NAMES_PATH, 'r') as f:
24
- class_names = json.load(f)
 
 
 
 
 
 
25
 
26
  NUM_CLASSES = len(class_names)
27
 
28
- # Load model - detect architecture from checkpoint
29
- print("Loading model...")
30
 
31
- # First, try to detect the correct architecture from the model file
32
- if os.path.exists(MODEL_PATH):
33
- checkpoint = torch.load(MODEL_PATH, map_location='cpu')
34
-
35
- # Detect EfficientNet variant based on feature dimensions
36
- if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
37
- state_dict = checkpoint['model_state_dict']
38
- else:
39
- state_dict = checkpoint
40
-
41
- # Check backbone head feature size to determine EfficientNet variant
42
- if 'backbone._conv_head.weight' in state_dict:
43
- conv_head_shape = state_dict['backbone._conv_head.weight'].shape
44
- if conv_head_shape[0] == 1536: # EfficientNet-B3
45
- model_type = 'efficientnet_b3'
46
- elif conv_head_shape[0] == 1408: # EfficientNet-B2
47
- model_type = 'efficientnet_b2'
48
- elif conv_head_shape[0] == 1280: # EfficientNet-B0/B1
49
- model_type = 'efficientnet_b1'
50
- else:
51
- model_type = 'efficientnet_b2' # Default fallback
52
- else:
53
- model_type = 'efficientnet_b2' # Default fallback
54
-
55
- # Check actual number of classes from classifier
56
- if 'classifier.9.weight' in state_dict:
57
- actual_classes = state_dict['classifier.9.weight'].shape[0]
58
  else:
59
- actual_classes = NUM_CLASSES
60
-
61
- print("Detected model: {} with {} classes".format(model_type, actual_classes))
62
-
63
- else:
64
- model_type = 'efficientnet_b2'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  actual_classes = NUM_CLASSES
66
- print("Model file not found, using default: {}".format(model_type))
 
 
 
 
 
 
67
 
68
- model = create_model(
69
- num_classes=actual_classes,
70
- model_type=model_type,
71
- pretrained=False, # We're loading trained weights
72
- dropout_rate=0.3
73
- )
 
 
 
 
 
 
 
 
74
 
75
- # Load trained weights
76
- if os.path.exists(MODEL_PATH):
 
77
  try:
78
- checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
79
- if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
80
- model.load_state_dict(checkpoint['model_state_dict'])
81
- print("✅ Model loaded successfully! ({}, {} classes)".format(model_type, actual_classes))
82
- else:
83
- model.load_state_dict(checkpoint)
84
- print("✅ Model loaded successfully! ({}, {} classes)".format(model_type, actual_classes))
85
  except Exception as e:
86
- print(" Error loading model: {}".format(str(e)))
87
- print("Please ensure the model architecture matches the saved weights.")
88
- else:
89
- print("⚠️ Model file not found. Please ensure best_model.pth is in the repository.")
90
 
91
- model.to(DEVICE)
92
- model.eval()
 
 
 
 
 
 
93
 
94
  def predict_bird(image):
95
  """
@@ -118,24 +168,39 @@ def predict_bird(image):
118
  # Prediction
119
  with torch.no_grad():
120
  outputs = model(input_tensor)
121
- probabilities = F.softmax(outputs, dim=1)
122
- confidence, predicted = torch.max(probabilities, 1)
123
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  # Get top 5 predictions
125
- top5_prob, top5_indices = torch.topk(probabilities, 5)
126
-
127
  # Format results
128
  results = {}
129
- for i in range(5):
130
- class_idx = top5_indices[0][i].item()
131
- prob = top5_prob[0][i].item()
132
  # Handle potential class index mismatch
133
  if class_idx < len(class_names):
134
  class_name = class_names[class_idx].replace('_', ' ')
135
  else:
136
  class_name = "Class_" + str(class_idx)
137
- results[class_name] = float(prob)
138
-
139
  return results
140
 
141
  except Exception as e:
 
14
  # Import our model architecture
15
  from models import create_model
16
 
17
+ # Optional: Hugging Face imports (used only when evaluating HF-format checkpoints)
18
+ try:
19
+ from transformers import AutoConfig, AutoModelForImageClassification
20
+ HF_AVAILABLE = True
21
+ except Exception:
22
+ HF_AVAILABLE = False
23
+
24
  # Configuration
25
+ # Default to the moved fine-tuned checkpoint if present
26
+ MODEL_PATH = os.environ.get('MODEL_PATH', os.path.join('results', 'fine_tune', 'best_model_finetuned.pth'))
27
+ # Optional: if your HF model id is known (e.g. Emiel/cub-200-bird-classifier-swin), set HF_MODEL_ID env var
28
+ HF_MODEL_ID = os.environ.get('HF_MODEL_ID', None)
29
+ CLASS_NAMES_PATH = os.environ.get('CLASS_NAMES_PATH', 'class_names.json')
30
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31
 
32
  # Load class names
33
+ if os.path.exists(CLASS_NAMES_PATH):
34
+ try:
35
+ with open(CLASS_NAMES_PATH, 'r') as f:
36
+ class_names = json.load(f)
37
+ except Exception:
38
+ class_names = []
39
+ else:
40
+ class_names = []
41
 
42
  NUM_CLASSES = len(class_names)
43
 
 
 
44
 
45
+ def load_checkpoint_model(model_path, device):
46
+ """Attempt to load a checkpoint. Supports local create_model-based checkpoints and
47
+ heuristic handling for Hugging Face (Swin) checkpoints when HF_MODEL_ID is set.
48
+ Returns (model, actual_num_classes) or (None, None) on failure.
49
+ """
50
+ if not os.path.exists(model_path):
51
+ print(f"Model file not found at {model_path}")
52
+ # If HF_MODEL_ID is set and transformers are available, try to load from hub
53
+ if HF_MODEL_ID and HF_AVAILABLE:
54
+ try:
55
+ print(f"Attempting to load model from Hugging Face Hub: {HF_MODEL_ID}")
56
+ hf_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_ID)
57
+ hf_model.to(device)
58
+ hf_model.eval()
59
+ num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES)
60
+ print(f"Loaded HF model from hub with {num_labels} labels")
61
+ return hf_model, num_labels
62
+ except Exception as e:
63
+ print("Failed to load HF model from hub:", e)
64
+ return None, None
65
+
66
+ ckpt = torch.load(model_path, map_location='cpu')
67
+ # unwrap common dict wrapper
68
+ if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
69
+ state_dict = ckpt['model_state_dict']
 
 
70
  else:
71
+ # if checkpoint is a state dict directly
72
+ state_dict = ckpt if isinstance(ckpt, dict) else {}
73
+
74
+ # Heuristic: detect HF-style Swin checkpoint by looking for keys that start with 'swin.'
75
+ hf_like = any(k.startswith('swin.') or 'swin.embeddings' in k for k in state_dict.keys()) if state_dict else False
76
+
77
+ if hf_like and HF_AVAILABLE and HF_MODEL_ID:
78
+ # Try to instantiate HF model from the hub config to match architecture
79
+ try:
80
+ print(f"Attempting to load Hugging Face model '{HF_MODEL_ID}' and apply checkpoint weights...")
81
+ config = AutoConfig.from_pretrained(HF_MODEL_ID)
82
+ hf_model = AutoModelForImageClassification.from_config(config)
83
+ # load weights non-strictly: match shapes
84
+ missing, unexpected = hf_model.load_state_dict(state_dict, strict=False)
85
+ hf_model.to(device)
86
+ hf_model.eval()
87
+ print(f"Loaded HF model with non-strict state_dict (missing {len(missing)} keys, unexpected {len(unexpected)} keys)")
88
+ num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES)
89
+ return hf_model, num_labels
90
+ except Exception as e:
91
+ print("HF load failed:", e)
92
+ print("Falling back to local model loader...")
93
+
94
+ # Fallback: try to detect EfficientNet-like shapes and create local model
95
+ # Determine actual num classes by inspecting a likely classifier weight key
96
  actual_classes = NUM_CLASSES
97
+ for k, v in state_dict.items():
98
+ if k.endswith('classifier.9.weight') or k.endswith('classifier.weight'):
99
+ try:
100
+ actual_classes = v.shape[0]
101
+ break
102
+ except Exception:
103
+ pass
104
 
105
+ # Heuristic to choose an EfficientNet variant based on conv head size
106
+ model_type = 'efficientnet_b2'
107
+ if state_dict:
108
+ if 'backbone._conv_head.weight' in state_dict:
109
+ try:
110
+ conv_head_shape = state_dict['backbone._conv_head.weight'].shape
111
+ if conv_head_shape[0] == 1536:
112
+ model_type = 'efficientnet_b3'
113
+ elif conv_head_shape[0] == 1408:
114
+ model_type = 'efficientnet_b2'
115
+ elif conv_head_shape[0] == 1280:
116
+ model_type = 'efficientnet_b1'
117
+ except Exception:
118
+ pass
119
 
120
+ print(f"Creating local model {model_type} with {actual_classes} classes (fallback)")
121
+ model = create_model(num_classes=actual_classes, model_type=model_type, pretrained=False, dropout_rate=0.3)
122
+ # Try to load state dict
123
  try:
124
+ # if ckpt was a dict without model_state_dict, attempt to load directly
125
+ to_load = state_dict if state_dict else ckpt
126
+ model.load_state_dict(to_load, strict=False)
127
+ model.to(device)
128
+ model.eval()
129
+ print("✅ Local model loaded (non-strict).")
130
+ return model, actual_classes
131
  except Exception as e:
132
+ print("Failed to load local model:", e)
133
+ return None, None
 
 
134
 
135
+
136
+ # Load model
137
+ print("Loading model...", MODEL_PATH)
138
+ model, actual_classes = load_checkpoint_model(MODEL_PATH, DEVICE)
139
+ if model is None:
140
+ print("No model available. The app will still launch but predictions will fail.")
141
+ else:
142
+ print(f"Model ready. Classes={actual_classes}")
143
 
144
  def predict_bird(image):
145
  """
 
168
  # Prediction
169
  with torch.no_grad():
170
  outputs = model(input_tensor)
171
+
172
+ # Handle Hugging Face ModelOutput objects
173
+ try:
174
+ # HF ModelOutput may be dict-like with a 'logits' attribute
175
+ if hasattr(outputs, 'logits'):
176
+ logits = outputs.logits
177
+ elif isinstance(outputs, (tuple, list)):
178
+ logits = outputs[0]
179
+ else:
180
+ logits = outputs
181
+ except Exception:
182
+ logits = outputs
183
+
184
+ # Ensure logits is a tensor
185
+ if not isinstance(logits, torch.Tensor):
186
+ logits = torch.tensor(np.asarray(logits)).to(DEVICE)
187
+
188
+ probabilities = F.softmax(logits, dim=1)
189
  # Get top 5 predictions
190
+ top5_prob, top5_indices = torch.topk(probabilities, min(5, probabilities.shape[1]), dim=1)
191
+
192
  # Format results
193
  results = {}
194
+ for i in range(top5_indices.shape[1]):
195
+ class_idx = int(top5_indices[0][i].item())
196
+ prob = float(top5_prob[0][i].item())
197
  # Handle potential class index mismatch
198
  if class_idx < len(class_names):
199
  class_name = class_names[class_idx].replace('_', ' ')
200
  else:
201
  class_name = "Class_" + str(class_idx)
202
+ results[class_name] = prob
203
+
204
  return results
205
 
206
  except Exception as e: