FrAnKu34t23 commited on
Commit
eab1518
·
verified ·
1 Parent(s): c8c6d5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -1
app.py CHANGED
@@ -23,7 +23,7 @@ except Exception:
23
 
24
  # Configuration
25
  # Default to the moved fine-tuned checkpoint if present
26
- MODEL_PATH = os.environ.get('MODEL_PATH', 'best_model_finetuned.pth')
27
  # Optional: if your HF model id is known (e.g. Emiel/cub-200-bird-classifier-swin), set HF_MODEL_ID env var
28
  HF_MODEL_ID = os.environ.get('HF_MODEL_ID', None)
29
  CLASS_NAMES_PATH = os.environ.get('CLASS_NAMES_PATH', 'class_names.json')
@@ -140,6 +140,33 @@ if model is None:
140
  print("No model available. The app will still launch but predictions will fail.")
141
  else:
142
  print(f"Model ready. Classes={actual_classes}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  def predict_bird(image):
145
  """
 
23
 
24
  # Configuration
25
  # Default to the moved fine-tuned checkpoint if present
26
+ MODEL_PATH = os.environ.get('MODEL_PATH', os.path.join('results', 'fine_tune', 'best_model_finetuned.pth'))
27
  # Optional: if your HF model id is known (e.g. Emiel/cub-200-bird-classifier-swin), set HF_MODEL_ID env var
28
  HF_MODEL_ID = os.environ.get('HF_MODEL_ID', None)
29
  CLASS_NAMES_PATH = os.environ.get('CLASS_NAMES_PATH', 'class_names.json')
 
140
  print("No model available. The app will still launch but predictions will fail.")
141
  else:
142
  print(f"Model ready. Classes={actual_classes}")
143
+ # If this is a Hugging Face model with id2label, prefer that mapping
144
+ try:
145
+ hf_config = getattr(model, 'config', None)
146
+ if hf_config is not None:
147
+ id2label = getattr(hf_config, 'id2label', None)
148
+ if id2label:
149
+ # id2label keys may be strings or ints
150
+ # Build ordered class_names list by index
151
+ max_idx = max(int(k) for k in id2label.keys())
152
+ hf_class_names = [None] * (max_idx + 1)
153
+ for k, v in id2label.items():
154
+ hf_class_names[int(k)] = v.replace(' ', '_') if isinstance(v, str) else str(v)
155
+ # Filter out None at end if any
156
+ hf_class_names = [c for c in hf_class_names if c is not None]
157
+ if len(hf_class_names) > 0:
158
+ class_names = hf_class_names
159
+ NUM_CLASSES = len(class_names)
160
+ print(f"Using Hugging Face id2label mapping with {NUM_CLASSES} classes")
161
+ except Exception as e:
162
+ print("Warning: failed to extract id2label from HF model config:", e)
163
+
164
+ # Warn if class_names.json doesn't match model classes
165
+ if class_names and actual_classes and len(class_names) != actual_classes:
166
+ print(f"Warning: class_names.json has {len(class_names)} entries but model expects {actual_classes} classes.")
167
+ # If HF labels exist and match expected size, prefer them
168
+ if len(class_names) < actual_classes:
169
+ print("Note: consider updating class_names.json to match the model's label order or set HF_MODEL_ID to use id2label mapping.")
170
 
171
  def predict_bird(image):
172
  """