juppy44 commited on
Commit
670090a
·
verified ·
1 Parent(s): eab277b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -8
app.py CHANGED
@@ -7,6 +7,8 @@ from transformers import (
7
  from peft import PeftModel
8
  from PIL import Image
9
  import torch
 
 
10
 
11
  # Global model (all species)
12
  GLOBAL_MODEL_ID = "juppy44/plant-identification-2m-vit-b"
@@ -56,6 +58,7 @@ def get_wa_model():
56
  - Same base checkpoint as the global model
57
  - Classifier resized to WA_NUM_LABELS
58
  - WA LoRA applied on top
 
59
  """
60
  global wa_model, wa_id2label
61
 
@@ -73,18 +76,47 @@ def get_wa_model():
73
  ).to(device)
74
  wa_base.eval()
75
 
76
- # Try to get proper labels from the adapter config first
77
  local_id2label = {}
78
  try:
79
- wa_cfg = AutoConfig.from_pretrained(WA_LORA_ID, subfolder=WA_LORA_SUBFOLDER)
80
- if getattr(wa_cfg, "id2label", None):
81
- local_id2label = normalize_id2label(wa_cfg.id2label)
82
- except Exception:
83
- pass
84
-
85
- # Fallback to base config if adapter config doesn't have them
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  if not local_id2label and getattr(wa_base.config, "id2label", None):
87
  local_id2label = normalize_id2label(wa_base.config.id2label)
 
88
 
89
  print(
90
  f"Loading WA adapter {WA_LORA_ID} (subfolder='{WA_LORA_SUBFOLDER}')..."
 
7
  from peft import PeftModel
8
  from PIL import Image
9
  import torch
10
+ import json
11
+ from huggingface_hub import hf_hub_download
12
 
13
  # Global model (all species)
14
  GLOBAL_MODEL_ID = "juppy44/plant-identification-2m-vit-b"
 
58
  - Same base checkpoint as the global model
59
  - Classifier resized to WA_NUM_LABELS
60
  - WA LoRA applied on top
61
+ - id2label loaded primarily from labels.json in the adapter repo
62
  """
63
  global wa_model, wa_id2label
64
 
 
76
  ).to(device)
77
  wa_base.eval()
78
 
79
+ # --- PRIMARY: load labels from labels.json in the LoRA repo root ---
80
  local_id2label = {}
81
  try:
82
+ labels_path = hf_hub_download(
83
+ repo_id=WA_LORA_ID,
84
+ filename="labels.json",
85
+ )
86
+ with open(labels_path, "r") as f:
87
+ labels_data = json.load(f)
88
+
89
+ label2id = labels_data.get("label2id", {})
90
+ # invert label2id: {species_name: idx} -> {idx: species_name}
91
+ for label, idx in label2id.items():
92
+ try:
93
+ idx_int = int(idx)
94
+ local_id2label[idx_int] = label
95
+ except (TypeError, ValueError):
96
+ continue
97
+
98
+ if local_id2label:
99
+ print("Loaded id2label from labels.json in WA adapter repo.")
100
+ except Exception as e:
101
+ print(f"Could not load labels.json from WA adapter repo: {e}")
102
+
103
+ # --- SECONDARY: adapter config id2label (if labels.json missing/empty) ---
104
+ if not local_id2label:
105
+ try:
106
+ wa_cfg = AutoConfig.from_pretrained(
107
+ WA_LORA_ID,
108
+ subfolder=WA_LORA_SUBFOLDER,
109
+ )
110
+ if getattr(wa_cfg, "id2label", None):
111
+ local_id2label = normalize_id2label(wa_cfg.id2label)
112
+ print("Loaded id2label from WA adapter config.")
113
+ except Exception as e:
114
+ print(f"Could not load id2label from adapter config: {e}")
115
+
116
+ # --- TERTIARY: WA base config id2label ---
117
  if not local_id2label and getattr(wa_base.config, "id2label", None):
118
  local_id2label = normalize_id2label(wa_base.config.id2label)
119
+ print("Fallback: using id2label from WA base config.")
120
 
121
  print(
122
  f"Loading WA adapter {WA_LORA_ID} (subfolder='{WA_LORA_SUBFOLDER}')..."