Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -23,7 +23,7 @@ except Exception:
|
|
| 23 |
|
| 24 |
# Configuration
|
| 25 |
# Default to the moved fine-tuned checkpoint if present
|
| 26 |
-
MODEL_PATH = os.environ.get('MODEL_PATH', 'best_model_finetuned.pth')
|
| 27 |
# Optional: if your HF model id is known (e.g. Emiel/cub-200-bird-classifier-swin), set HF_MODEL_ID env var
|
| 28 |
HF_MODEL_ID = os.environ.get('HF_MODEL_ID', None)
|
| 29 |
CLASS_NAMES_PATH = os.environ.get('CLASS_NAMES_PATH', 'class_names.json')
|
|
@@ -140,6 +140,33 @@ if model is None:
|
|
| 140 |
print("No model available. The app will still launch but predictions will fail.")
|
| 141 |
else:
|
| 142 |
print(f"Model ready. Classes={actual_classes}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
def predict_bird(image):
|
| 145 |
"""
|
|
|
|
| 23 |
|
| 24 |
# Configuration
|
| 25 |
# Default to the moved fine-tuned checkpoint if present
|
| 26 |
+
MODEL_PATH = os.environ.get('MODEL_PATH', os.path.join('results', 'fine_tune', 'best_model_finetuned.pth'))
|
| 27 |
# Optional: if your HF model id is known (e.g. Emiel/cub-200-bird-classifier-swin), set HF_MODEL_ID env var
|
| 28 |
HF_MODEL_ID = os.environ.get('HF_MODEL_ID', None)
|
| 29 |
CLASS_NAMES_PATH = os.environ.get('CLASS_NAMES_PATH', 'class_names.json')
|
|
|
|
| 140 |
print("No model available. The app will still launch but predictions will fail.")
|
| 141 |
else:
|
| 142 |
print(f"Model ready. Classes={actual_classes}")
|
| 143 |
+
# If this is a Hugging Face model with id2label, prefer that mapping
|
| 144 |
+
try:
|
| 145 |
+
hf_config = getattr(model, 'config', None)
|
| 146 |
+
if hf_config is not None:
|
| 147 |
+
id2label = getattr(hf_config, 'id2label', None)
|
| 148 |
+
if id2label:
|
| 149 |
+
# id2label keys may be strings or ints
|
| 150 |
+
# Build ordered class_names list by index
|
| 151 |
+
max_idx = max(int(k) for k in id2label.keys())
|
| 152 |
+
hf_class_names = [None] * (max_idx + 1)
|
| 153 |
+
for k, v in id2label.items():
|
| 154 |
+
hf_class_names[int(k)] = v.replace(' ', '_') if isinstance(v, str) else str(v)
|
| 155 |
+
# Filter out None at end if any
|
| 156 |
+
hf_class_names = [c for c in hf_class_names if c is not None]
|
| 157 |
+
if len(hf_class_names) > 0:
|
| 158 |
+
class_names = hf_class_names
|
| 159 |
+
NUM_CLASSES = len(class_names)
|
| 160 |
+
print(f"Using Hugging Face id2label mapping with {NUM_CLASSES} classes")
|
| 161 |
+
except Exception as e:
|
| 162 |
+
print("Warning: failed to extract id2label from HF model config:", e)
|
| 163 |
+
|
| 164 |
+
# Warn if class_names.json doesn't match model classes
|
| 165 |
+
if class_names and actual_classes and len(class_names) != actual_classes:
|
| 166 |
+
print(f"Warning: class_names.json has {len(class_names)} entries but model expects {actual_classes} classes.")
|
| 167 |
+
# If HF labels exist and match expected size, prefer them
|
| 168 |
+
if len(class_names) < actual_classes:
|
| 169 |
+
print("Note: consider updating class_names.json to match the model's label order or set HF_MODEL_ID to use id2label mapping.")
|
| 170 |
|
| 171 |
def predict_bird(image):
|
| 172 |
"""
|