Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -97,16 +97,31 @@ def load_checkpoint_model(model_path, device):
|
|
| 97 |
print("Failed to load checkpoint file:", e)
|
| 98 |
logger.exception("Failed to load checkpoint file:")
|
| 99 |
ckpt = {}
|
| 100 |
-
# unwrap common dict wrapper
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
state_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
# Diagnostic: print a few state_dict keys so we can tell checkpoint format
|
| 108 |
try:
|
| 109 |
-
sample_keys = list(state_dict.keys())[:
|
| 110 |
print("Checkpoint sample keys:", sample_keys)
|
| 111 |
logger.info(f"Checkpoint sample keys: {sample_keys}")
|
| 112 |
except Exception:
|
|
@@ -114,7 +129,13 @@ def load_checkpoint_model(model_path, device):
|
|
| 114 |
logger.info("No state_dict keys to sample")
|
| 115 |
|
| 116 |
# Heuristic: detect HF-style Swin checkpoint by looking for keys that start with 'swin.'
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
hf_msg = f"hf_like_checkpoint_detected={hf_like} HF_AVAILABLE={HF_AVAILABLE} HF_MODEL_ID={'set' if HF_MODEL_ID else 'not-set'}"
|
| 119 |
print(hf_msg)
|
| 120 |
logger.info(hf_msg)
|
|
|
|
| 97 |
print("Failed to load checkpoint file:", e)
|
| 98 |
logger.exception("Failed to load checkpoint file:")
|
| 99 |
ckpt = {}
|
| 100 |
+
# unwrap common dict wrapper (support both 'model_state_dict' and 'state_dict')
|
| 101 |
+
state_dict = {}
|
| 102 |
+
if isinstance(ckpt, dict):
|
| 103 |
+
if 'model_state_dict' in ckpt and isinstance(ckpt['model_state_dict'], dict):
|
| 104 |
+
state_dict = ckpt['model_state_dict']
|
| 105 |
+
elif 'state_dict' in ckpt and isinstance(ckpt['state_dict'], dict):
|
| 106 |
+
state_dict = ckpt['state_dict']
|
| 107 |
+
else:
|
| 108 |
+
# fallback: ckpt may already be a state dict
|
| 109 |
+
state_dict = ckpt
|
| 110 |
+
|
| 111 |
+
# If the state_dict is a single-key wrapper (e.g., {'state_dict': {...}} or {'model': {...}}), unwrap one more level
|
| 112 |
+
if isinstance(state_dict, dict) and len(state_dict) == 1:
|
| 113 |
+
sole_val = next(iter(state_dict.values()))
|
| 114 |
+
if isinstance(sole_val, dict):
|
| 115 |
+
# adopt inner dict as state_dict if it looks like parameters
|
| 116 |
+
inner_keys = list(sole_val.keys())[:8]
|
| 117 |
+
# Heuristic: keys with '.' and numeric shapes indicate a param dict
|
| 118 |
+
if any('.' in k for k in inner_keys):
|
| 119 |
+
logger.info(f"Unwrapping single-key checkpoint wrapper, inner keys sample: {inner_keys}")
|
| 120 |
+
state_dict = sole_val
|
| 121 |
|
| 122 |
# Diagnostic: print a few state_dict keys so we can tell checkpoint format
|
| 123 |
try:
|
| 124 |
+
sample_keys = list(state_dict.keys())[:16]
|
| 125 |
print("Checkpoint sample keys:", sample_keys)
|
| 126 |
logger.info(f"Checkpoint sample keys: {sample_keys}")
|
| 127 |
except Exception:
|
|
|
|
| 129 |
logger.info("No state_dict keys to sample")
|
| 130 |
|
| 131 |
# Heuristic: detect HF-style Swin checkpoint by looking for keys that start with 'swin.'
|
| 132 |
+
# Detect HF-like keys; strip common 'module.' prefix before checking
|
| 133 |
+
def key_is_hf_like(k: str) -> bool:
|
| 134 |
+
kk = k.replace('module.', '')
|
| 135 |
+
kk = kk.lower()
|
| 136 |
+
return kk.startswith('swin.') or 'swin.embeddings' in kk or 'swin.patch_embeddings' in kk
|
| 137 |
+
|
| 138 |
+
hf_like = any(key_is_hf_like(k) for k in state_dict.keys()) if state_dict else False
|
| 139 |
hf_msg = f"hf_like_checkpoint_detected={hf_like} HF_AVAILABLE={HF_AVAILABLE} HF_MODEL_ID={'set' if HF_MODEL_ID else 'not-set'}"
|
| 140 |
print(hf_msg)
|
| 141 |
logger.info(hf_msg)
|