FrAnKu34t23 commited on
Commit
f9fef0f
·
verified ·
1 Parent(s): b543dd0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -8
app.py CHANGED
@@ -97,16 +97,31 @@ def load_checkpoint_model(model_path, device):
97
  print("Failed to load checkpoint file:", e)
98
  logger.exception("Failed to load checkpoint file:")
99
  ckpt = {}
100
- # unwrap common dict wrapper
101
- if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
102
- state_dict = ckpt['model_state_dict']
103
- else:
104
- # if checkpoint is a state dict directly
105
- state_dict = ckpt if isinstance(ckpt, dict) else {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  # Diagnostic: print a few state_dict keys so we can tell checkpoint format
108
  try:
109
- sample_keys = list(state_dict.keys())[:8]
110
  print("Checkpoint sample keys:", sample_keys)
111
  logger.info(f"Checkpoint sample keys: {sample_keys}")
112
  except Exception:
@@ -114,7 +129,13 @@ def load_checkpoint_model(model_path, device):
114
  logger.info("No state_dict keys to sample")
115
 
116
  # Heuristic: detect HF-style Swin checkpoint by looking for keys that start with 'swin.'
117
- hf_like = any(k.startswith('swin.') or 'swin.embeddings' in k for k in state_dict.keys()) if state_dict else False
 
 
 
 
 
 
118
  hf_msg = f"hf_like_checkpoint_detected={hf_like} HF_AVAILABLE={HF_AVAILABLE} HF_MODEL_ID={'set' if HF_MODEL_ID else 'not-set'}"
119
  print(hf_msg)
120
  logger.info(hf_msg)
 
97
  print("Failed to load checkpoint file:", e)
98
  logger.exception("Failed to load checkpoint file:")
99
  ckpt = {}
100
+ # unwrap common dict wrapper (support both 'model_state_dict' and 'state_dict')
101
+ state_dict = {}
102
+ if isinstance(ckpt, dict):
103
+ if 'model_state_dict' in ckpt and isinstance(ckpt['model_state_dict'], dict):
104
+ state_dict = ckpt['model_state_dict']
105
+ elif 'state_dict' in ckpt and isinstance(ckpt['state_dict'], dict):
106
+ state_dict = ckpt['state_dict']
107
+ else:
108
+ # fallback: ckpt may already be a state dict
109
+ state_dict = ckpt
110
+
111
+ # If the state_dict is a single-key wrapper (e.g., {'state_dict': {...}} or {'model': {...}}), unwrap one more level
112
+ if isinstance(state_dict, dict) and len(state_dict) == 1:
113
+ sole_val = next(iter(state_dict.values()))
114
+ if isinstance(sole_val, dict):
115
+ # adopt inner dict as state_dict if it looks like parameters
116
+ inner_keys = list(sole_val.keys())[:8]
117
+ # Heuristic: keys with '.' and numeric shapes indicate a param dict
118
+ if any('.' in k for k in inner_keys):
119
+ logger.info(f"Unwrapping single-key checkpoint wrapper, inner keys sample: {inner_keys}")
120
+ state_dict = sole_val
121
 
122
  # Diagnostic: print a few state_dict keys so we can tell checkpoint format
123
  try:
124
+ sample_keys = list(state_dict.keys())[:16]
125
  print("Checkpoint sample keys:", sample_keys)
126
  logger.info(f"Checkpoint sample keys: {sample_keys}")
127
  except Exception:
 
129
  logger.info("No state_dict keys to sample")
130
 
131
  # Heuristic: detect HF-style Swin checkpoint by looking for keys that start with 'swin.'
132
+ # Detect HF-like keys; strip common 'module.' prefix before checking
133
+ def key_is_hf_like(k: str) -> bool:
134
+ kk = k.replace('module.', '')
135
+ kk = kk.lower()
136
+ return kk.startswith('swin.') or 'swin.embeddings' in kk or 'swin.patch_embeddings' in kk
137
+
138
+ hf_like = any(key_is_hf_like(k) for k in state_dict.keys()) if state_dict else False
139
  hf_msg = f"hf_like_checkpoint_detected={hf_like} HF_AVAILABLE={HF_AVAILABLE} HF_MODEL_ID={'set' if HF_MODEL_ID else 'not-set'}"
140
  print(hf_msg)
141
  logger.info(hf_msg)