Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -59,67 +59,35 @@ def triage(label, conf, text):
|
|
| 59 |
CFG = load_json("fusion_config.json")
|
| 60 |
LABEL_MAP = load_json("label_map.json")
|
| 61 |
|
| 62 |
-
# label_map
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
""
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
if not isinstance(label_map, dict) or len(label_map) == 0:
|
| 87 |
-
raise ValueError("label_map.json must be a non-empty dict")
|
| 88 |
-
|
| 89 |
-
keys = list(label_map.keys())
|
| 90 |
-
vals = list(label_map.values())
|
| 91 |
-
|
| 92 |
-
# Case A: {"0":"eczema", "1":"acne"} (idx -> label)
|
| 93 |
-
if all(isinstance(k, str) and k.isdigit() for k in keys):
|
| 94 |
-
idx2label = {int(k): normalize_label(v) for k, v in label_map.items()}
|
| 95 |
-
classes = [idx2label[i] for i in sorted(idx2label.keys())]
|
| 96 |
-
return classes
|
| 97 |
-
|
| 98 |
-
# Case B: {"eczema":0, "acne":1} (label -> idx)
|
| 99 |
-
if all(isinstance(v, (int, float)) and float(v).is_integer() for v in vals):
|
| 100 |
-
label2idx = {normalize_label(k): int(v) for k, v in label_map.items()}
|
| 101 |
-
classes = [lab for lab, _ in sorted(label2idx.items(), key=lambda x: x[1])]
|
| 102 |
-
return classes
|
| 103 |
-
|
| 104 |
-
# Case C: weird format like {"eczema": ["eczema"], ...}
|
| 105 |
-
# Use KEYS as labels (dedup + sorted)
|
| 106 |
-
classes = sorted({normalize_label(k) for k in keys})
|
| 107 |
-
return classes
|
| 108 |
-
|
| 109 |
-
CLASSES = build_classes(LABEL_MAP)
|
| 110 |
-
NUM_CLASSES = len(CLASSES)
|
| 111 |
-
|
| 112 |
-
print("✅ NUM_CLASSES:", NUM_CLASSES)
|
| 113 |
-
print("✅ First labels:", CLASSES[:10])
|
| 114 |
|
| 115 |
-
|
| 116 |
-
IMG_SIZE = int(CFG.get("img_size", 384))
|
| 117 |
-
TEXT_MODEL_NAME = CFG.get(
|
| 118 |
-
"text_model_name",
|
| 119 |
-
"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
|
| 120 |
-
)
|
| 121 |
-
MAX_LEN = int(CFG.get("max_len", 128))
|
| 122 |
|
|
|
|
|
|
|
| 123 |
|
| 124 |
IMG_BACKBONE = CFG.get("img_backbone", "tf_efficientnetv2_s")
|
| 125 |
IMG_SIZE = int(CFG.get("img_size", 384))
|
|
|
|
| 59 |
CFG = load_json("fusion_config.json")
|
| 60 |
LABEL_MAP = load_json("label_map.json")
|
| 61 |
|
| 62 |
+
# ✅ your label_map.json is a wrapper: {"classes":[...], "label2idx":{...}}
|
| 63 |
+
if isinstance(LABEL_MAP, dict) and "classes" in LABEL_MAP and "label2idx" in LABEL_MAP:
|
| 64 |
+
CLASSES = [str(x).strip().lower() for x in LABEL_MAP["classes"]]
|
| 65 |
+
label2idx = {str(k).strip().lower(): int(v) for k, v in LABEL_MAP["label2idx"].items()}
|
| 66 |
+
|
| 67 |
+
# (optional but good) sanity check: classes order matches label2idx
|
| 68 |
+
if len(CLASSES) != len(label2idx):
|
| 69 |
+
raise ValueError("label_map.json mismatch: len(classes) != len(label2idx)")
|
| 70 |
+
for i, lab in enumerate(CLASSES):
|
| 71 |
+
if label2idx.get(lab, None) != i:
|
| 72 |
+
# if mismatch, rebuild CLASSES from label2idx index order
|
| 73 |
+
idx2label = {int(v): str(k).strip().lower() for k, v in LABEL_MAP["label2idx"].items()}
|
| 74 |
+
CLASSES = [idx2label[i] for i in sorted(idx2label.keys())]
|
| 75 |
+
break
|
| 76 |
+
|
| 77 |
+
else:
|
| 78 |
+
# Fallback: support plain formats too
|
| 79 |
+
if all(isinstance(k, str) and k.isdigit() for k in LABEL_MAP.keys()): # {"0":"eczema",...}
|
| 80 |
+
idx2label = {int(k): str(v).strip().lower() for k, v in LABEL_MAP.items()}
|
| 81 |
+
CLASSES = [idx2label[i] for i in sorted(idx2label.keys())]
|
| 82 |
+
label2idx = {c: i for i, c in enumerate(CLASSES)}
|
| 83 |
+
else: # {"eczema":0,...}
|
| 84 |
+
label2idx = {str(k).strip().lower(): int(v) for k, v in LABEL_MAP.items()}
|
| 85 |
+
CLASSES = [c for c, _ in sorted(label2idx.items(), key=lambda x: x[1])]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
+
NUM_CLASSES = len(CLASSES)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
+
print("✅ NUM_CLASSES:", NUM_CLASSES) # should print 16
|
| 90 |
+
print("✅ First labels:", CLASSES[:5]) # sanity check
|
| 91 |
|
| 92 |
IMG_BACKBONE = CFG.get("img_backbone", "tf_efficientnetv2_s")
|
| 93 |
IMG_SIZE = int(CFG.get("img_size", 384))
|