FrAnKu34t23 commited on
Commit
61505b4
·
verified ·
1 Parent(s): 8572cb2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -11
app.py CHANGED
@@ -10,6 +10,7 @@ import json
10
  import numpy as np
11
  from torchvision import transforms
12
  import os
 
13
 
14
  # Import our model architecture
15
  from models import create_model
@@ -23,12 +24,22 @@ except Exception:
23
 
24
  # Configuration
25
  # Default to the moved fine-tuned checkpoint if present
26
- MODEL_PATH = os.environ.get('MODEL_PATH', os.path.join('best_model_finetuned.pth'))
27
  # Optional: if your HF model id is known (e.g. Emiel/cub-200-bird-classifier-swin), set HF_MODEL_ID env var
28
  HF_MODEL_ID = os.environ.get('HF_MODEL_ID', None)
29
  CLASS_NAMES_PATH = os.environ.get('CLASS_NAMES_PATH', 'class_names.json')
 
30
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31
 
 
 
 
 
 
 
 
 
 
32
  # Load class names
33
  if os.path.exists(CLASS_NAMES_PATH):
34
  try:
@@ -47,8 +58,23 @@ def load_checkpoint_model(model_path, device):
47
  heuristic handling for Hugging Face (Swin) checkpoints when HF_MODEL_ID is set.
48
  Returns (model, actual_num_classes) or (None, None) on failure.
49
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  if not os.path.exists(model_path):
51
- print(f"Model file not found at {model_path}")
 
 
52
  # If HF_MODEL_ID is set and transformers are available, try to load from hub
53
  if HF_MODEL_ID and HF_AVAILABLE:
54
  try:
@@ -63,7 +89,14 @@ def load_checkpoint_model(model_path, device):
63
  print("Failed to load HF model from hub:", e)
64
  return None, None
65
 
66
- ckpt = torch.load(model_path, map_location='cpu')
 
 
 
 
 
 
 
67
  # unwrap common dict wrapper
68
  if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
69
  state_dict = ckpt['model_state_dict']
@@ -71,25 +104,50 @@ def load_checkpoint_model(model_path, device):
71
  # if checkpoint is a state dict directly
72
  state_dict = ckpt if isinstance(ckpt, dict) else {}
73
 
 
 
 
 
 
 
 
 
 
74
  # Heuristic: detect HF-style Swin checkpoint by looking for keys that start with 'swin.'
75
  hf_like = any(k.startswith('swin.') or 'swin.embeddings' in k for k in state_dict.keys()) if state_dict else False
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- if hf_like and HF_AVAILABLE and HF_MODEL_ID:
78
- # Try to instantiate HF model from the hub config to match architecture
79
  try:
80
- print(f"Attempting to load Hugging Face model '{HF_MODEL_ID}' and apply checkpoint weights...")
81
- config = AutoConfig.from_pretrained(HF_MODEL_ID)
 
 
 
82
  hf_model = AutoModelForImageClassification.from_config(config)
83
  # load weights non-strictly: match shapes
84
  missing, unexpected = hf_model.load_state_dict(state_dict, strict=False)
85
  hf_model.to(device)
86
  hf_model.eval()
87
- print(f"Loaded HF model with non-strict state_dict (missing {len(missing)} keys, unexpected {len(unexpected)} keys)")
 
 
88
  num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES)
89
  return hf_model, num_labels
90
  except Exception as e:
91
  print("HF load failed:", e)
 
92
  print("Falling back to local model loader...")
 
93
 
94
  # Fallback: try to detect EfficientNet-like shapes and create local model
95
  # Determine actual num classes by inspecting a likely classifier weight key
@@ -149,11 +207,11 @@ else:
149
  # id2label keys may be strings or ints
150
  # Build ordered class_names list by index
151
  max_idx = max(int(k) for k in id2label.keys())
152
- hf_class_names = [None] * (max_idx + 1)
153
  for k, v in id2label.items():
154
  hf_class_names[int(k)] = v.replace(' ', '_') if isinstance(v, str) else str(v)
155
- # Filter out None at end if any
156
- hf_class_names = [c for c in hf_class_names if c is not None]
157
  if len(hf_class_names) > 0:
158
  class_names = hf_class_names
159
  NUM_CLASSES = len(class_names)
 
10
  import numpy as np
11
  from torchvision import transforms
12
  import os
13
+ import logging
14
 
15
  # Import our model architecture
16
  from models import create_model
 
24
 
25
  # Configuration
26
  # Default to the moved fine-tuned checkpoint if present
27
+ MODEL_PATH = os.environ.get('MODEL_PATH', os.path.join('results', 'fine_tune', 'best_model_finetuned.pth'))
28
  # Optional: if your HF model id is known (e.g. Emiel/cub-200-bird-classifier-swin), set HF_MODEL_ID env var
29
  HF_MODEL_ID = os.environ.get('HF_MODEL_ID', None)
30
  CLASS_NAMES_PATH = os.environ.get('CLASS_NAMES_PATH', 'class_names.json')
31
+ FORCE_HF_LOAD = os.environ.get('FORCE_HF_LOAD', '0').lower() in ('1', 'true', 'yes')
32
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33
 
34
+ # Default HF model id to try when checkpoint looks HF-like and HF_MODEL_ID not set
35
+ DEFAULT_HF_ID = 'Emiel/cub-200-bird-classifier-swin'
36
+
37
+ # Setup file logger for traceability in Spaces
38
+ LOG_FILE = os.environ.get('APP_LOG_PATH', 'app.log')
39
+ logging.basicConfig(level=logging.INFO, filename=LOG_FILE, filemode='a',
40
+ format='%(asctime)s %(levelname)s: %(message)s')
41
+ logger = logging.getLogger(__name__)
42
+
43
  # Load class names
44
  if os.path.exists(CLASS_NAMES_PATH):
45
  try:
 
58
  heuristic handling for Hugging Face (Swin) checkpoints when HF_MODEL_ID is set.
59
  Returns (model, actual_num_classes) or (None, None) on failure.
60
  """
61
+ # If user wants to force HF loading from hub, try that first (useful in Spaces)
62
+ if FORCE_HF_LOAD and HF_MODEL_ID and HF_AVAILABLE:
63
+ try:
64
+ print(f"FORCE_HF_LOAD enabled: loading HF model from hub: {HF_MODEL_ID}")
65
+ hf_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_ID)
66
+ hf_model.to(device)
67
+ hf_model.eval()
68
+ num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES)
69
+ print(f"Loaded HF model from hub with {num_labels} labels (force)")
70
+ return hf_model, num_labels
71
+ except Exception as e:
72
+ print("Forced HF hub load failed:", e)
73
+
74
  if not os.path.exists(model_path):
75
+ msg = f"Model file not found at {model_path}"
76
+ print(msg)
77
+ logger.info(msg)
78
  # If HF_MODEL_ID is set and transformers are available, try to load from hub
79
  if HF_MODEL_ID and HF_AVAILABLE:
80
  try:
 
89
  print("Failed to load HF model from hub:", e)
90
  return None, None
91
 
92
+ print(f"Loading checkpoint from: {model_path}")
93
+ logger.info(f"Loading checkpoint from: {model_path}")
94
+ try:
95
+ ckpt = torch.load(model_path, map_location='cpu')
96
+ except Exception as e:
97
+ print("Failed to load checkpoint file:", e)
98
+ logger.exception("Failed to load checkpoint file:")
99
+ ckpt = {}
100
  # unwrap common dict wrapper
101
  if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
102
  state_dict = ckpt['model_state_dict']
 
104
  # if checkpoint is a state dict directly
105
  state_dict = ckpt if isinstance(ckpt, dict) else {}
106
 
107
+ # Diagnostic: print a few state_dict keys so we can tell checkpoint format
108
+ try:
109
+ sample_keys = list(state_dict.keys())[:8]
110
+ print("Checkpoint sample keys:", sample_keys)
111
+ logger.info(f"Checkpoint sample keys: {sample_keys}")
112
+ except Exception:
113
+ print("No state_dict keys to sample")
114
+ logger.info("No state_dict keys to sample")
115
+
116
  # Heuristic: detect HF-style Swin checkpoint by looking for keys that start with 'swin.'
117
  hf_like = any(k.startswith('swin.') or 'swin.embeddings' in k for k in state_dict.keys()) if state_dict else False
118
+ hf_msg = f"hf_like_checkpoint_detected={hf_like} HF_AVAILABLE={HF_AVAILABLE} HF_MODEL_ID={'set' if HF_MODEL_ID else 'not-set'}"
119
+ print(hf_msg)
120
+ logger.info(hf_msg)
121
+
122
+ if hf_like and HF_AVAILABLE:
123
+ # choose which HF id to use: env var or default
124
+ hf_id_to_use = HF_MODEL_ID or DEFAULT_HF_ID
125
+ if HF_MODEL_ID is None:
126
+ info_msg = f"HF_MODEL_ID not set; using DEFAULT_HF_ID='{DEFAULT_HF_ID}' to attempt hub load"
127
+ print(info_msg)
128
+ logger.info(info_msg)
129
 
 
 
130
  try:
131
+ msg = f"Attempting to load Hugging Face model '{hf_id_to_use}' and apply checkpoint weights..."
132
+ print(msg)
133
+ logger.info(msg)
134
+ # prefer using the hub config to instantiate exact architecture
135
+ config = AutoConfig.from_pretrained(hf_id_to_use)
136
  hf_model = AutoModelForImageClassification.from_config(config)
137
  # load weights non-strictly: match shapes
138
  missing, unexpected = hf_model.load_state_dict(state_dict, strict=False)
139
  hf_model.to(device)
140
  hf_model.eval()
141
+ ok_msg = f"Loaded HF model with non-strict state_dict (missing {len(missing)} keys, unexpected {len(unexpected)} keys)"
142
+ print(ok_msg)
143
+ logger.info(ok_msg)
144
  num_labels = getattr(hf_model.config, 'num_labels', NUM_CLASSES)
145
  return hf_model, num_labels
146
  except Exception as e:
147
  print("HF load failed:", e)
148
+ logger.exception("HF load failed")
149
  print("Falling back to local model loader...")
150
+ logger.info("Falling back to local model loader")
151
 
152
  # Fallback: try to detect EfficientNet-like shapes and create local model
153
  # Determine actual num classes by inspecting a likely classifier weight key
 
207
  # id2label keys may be strings or ints
208
  # Build ordered class_names list by index
209
  max_idx = max(int(k) for k in id2label.keys())
210
+ hf_class_names = [""] * (max_idx + 1)
211
  for k, v in id2label.items():
212
  hf_class_names[int(k)] = v.replace(' ', '_') if isinstance(v, str) else str(v)
213
+ # Filter out empty entries
214
+ hf_class_names = [c for c in hf_class_names if c]
215
  if len(hf_class_names) > 0:
216
  class_names = hf_class_names
217
  NUM_CLASSES = len(class_names)