Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
import json, os
|
| 2 |
import numpy as np
|
| 3 |
from PIL import Image
|
|
|
|
|
|
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
|
@@ -58,15 +60,67 @@ CFG = load_json("fusion_config.json")
|
|
| 58 |
LABEL_MAP = load_json("label_map.json")
|
| 59 |
|
| 60 |
# label_map can be {"label": idx} or {"0":"label"}
|
| 61 |
-
if all(isinstance(k, str) and k.isdigit() for k in LABEL_MAP.keys()):
|
| 62 |
-
idx2label = {int(k): v for k, v in LABEL_MAP.items()}
|
| 63 |
-
CLASSES = [idx2label[i] for i in sorted(idx2label.keys())]
|
| 64 |
-
else:
|
| 65 |
-
label2idx = {k: int(v) for k, v in LABEL_MAP.items()}
|
| 66 |
-
CLASSES = [c for c, _ in sorted(label2idx.items(), key=lambda x: x[1])]
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
NUM_CLASSES = len(CLASSES)
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
IMG_BACKBONE = CFG.get("img_backbone", "tf_efficientnetv2_s")
|
| 71 |
IMG_SIZE = int(CFG.get("img_size", 384))
|
| 72 |
TEXT_MODEL_NAME = CFG.get("text_model_name", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
|
|
|
|
| 1 |
import json, os
|
| 2 |
import numpy as np
|
| 3 |
from PIL import Image
|
| 4 |
+
import ast
|
| 5 |
+
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.nn as nn
|
|
|
|
| 60 |
LABEL_MAP = load_json("label_map.json")
|
| 61 |
|
| 62 |
# label_map can be {"label": idx} or {"0":"label"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
+
def normalize_label(x):
|
| 65 |
+
"""
|
| 66 |
+
Makes label_map robust:
|
| 67 |
+
- ["eczema"] -> "eczema"
|
| 68 |
+
- "['eczema']" -> "eczema"
|
| 69 |
+
- [] / "" / None -> "unknown"
|
| 70 |
+
"""
|
| 71 |
+
if x is None:
|
| 72 |
+
return "unknown"
|
| 73 |
+
if isinstance(x, list):
|
| 74 |
+
return normalize_label(x[0]) if len(x) else "unknown"
|
| 75 |
+
s = str(x).strip()
|
| 76 |
+
if s == "" or s.lower() in ["none", "nan", "null"]:
|
| 77 |
+
return "unknown"
|
| 78 |
+
if s.startswith("[") and s.endswith("]"):
|
| 79 |
+
try:
|
| 80 |
+
return normalize_label(ast.literal_eval(s))
|
| 81 |
+
except Exception:
|
| 82 |
+
pass
|
| 83 |
+
return s.lower()
|
| 84 |
+
|
| 85 |
+
def build_classes(label_map):
|
| 86 |
+
if not isinstance(label_map, dict) or len(label_map) == 0:
|
| 87 |
+
raise ValueError("label_map.json must be a non-empty dict")
|
| 88 |
+
|
| 89 |
+
keys = list(label_map.keys())
|
| 90 |
+
vals = list(label_map.values())
|
| 91 |
+
|
| 92 |
+
# Case A: {"0":"eczema", "1":"acne"} (idx -> label)
|
| 93 |
+
if all(isinstance(k, str) and k.isdigit() for k in keys):
|
| 94 |
+
idx2label = {int(k): normalize_label(v) for k, v in label_map.items()}
|
| 95 |
+
classes = [idx2label[i] for i in sorted(idx2label.keys())]
|
| 96 |
+
return classes
|
| 97 |
+
|
| 98 |
+
# Case B: {"eczema":0, "acne":1} (label -> idx)
|
| 99 |
+
if all(isinstance(v, (int, float)) and float(v).is_integer() for v in vals):
|
| 100 |
+
label2idx = {normalize_label(k): int(v) for k, v in label_map.items()}
|
| 101 |
+
classes = [lab for lab, _ in sorted(label2idx.items(), key=lambda x: x[1])]
|
| 102 |
+
return classes
|
| 103 |
+
|
| 104 |
+
# Case C: weird format like {"eczema": ["eczema"], ...}
|
| 105 |
+
# Use KEYS as labels (dedup + sorted)
|
| 106 |
+
classes = sorted({normalize_label(k) for k in keys})
|
| 107 |
+
return classes
|
| 108 |
+
|
| 109 |
+
CLASSES = build_classes(LABEL_MAP)
|
| 110 |
NUM_CLASSES = len(CLASSES)
|
| 111 |
|
| 112 |
+
print("✅ NUM_CLASSES:", NUM_CLASSES)
|
| 113 |
+
print("✅ First labels:", CLASSES[:10])
|
| 114 |
+
|
| 115 |
+
IMG_BACKBONE = CFG.get("img_backbone", "tf_efficientnetv2_s")
|
| 116 |
+
IMG_SIZE = int(CFG.get("img_size", 384))
|
| 117 |
+
TEXT_MODEL_NAME = CFG.get(
|
| 118 |
+
"text_model_name",
|
| 119 |
+
"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
|
| 120 |
+
)
|
| 121 |
+
MAX_LEN = int(CFG.get("max_len", 128))
|
| 122 |
+
|
| 123 |
+
|
| 124 |
IMG_BACKBONE = CFG.get("img_backbone", "tf_efficientnetv2_s")
|
| 125 |
IMG_SIZE = int(CFG.get("img_size", 384))
|
| 126 |
TEXT_MODEL_NAME = CFG.get("text_model_name", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
|