muruga778 commited on
Commit
af19b8e
·
verified ·
1 Parent(s): f9e2cc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -59
app.py CHANGED
@@ -59,67 +59,35 @@ def triage(label, conf, text):
59
  CFG = load_json("fusion_config.json")
60
  LABEL_MAP = load_json("label_map.json")
61
 
62
- # label_map can be {"label": idx} or {"0":"label"}
63
-
64
- def normalize_label(x):
65
- """
66
- Makes label_map robust:
67
- - ["eczema"] -> "eczema"
68
- - "['eczema']" -> "eczema"
69
- - [] / "" / None -> "unknown"
70
- """
71
- if x is None:
72
- return "unknown"
73
- if isinstance(x, list):
74
- return normalize_label(x[0]) if len(x) else "unknown"
75
- s = str(x).strip()
76
- if s == "" or s.lower() in ["none", "nan", "null"]:
77
- return "unknown"
78
- if s.startswith("[") and s.endswith("]"):
79
- try:
80
- return normalize_label(ast.literal_eval(s))
81
- except Exception:
82
- pass
83
- return s.lower()
84
-
85
- def build_classes(label_map):
86
- if not isinstance(label_map, dict) or len(label_map) == 0:
87
- raise ValueError("label_map.json must be a non-empty dict")
88
-
89
- keys = list(label_map.keys())
90
- vals = list(label_map.values())
91
-
92
- # Case A: {"0":"eczema", "1":"acne"} (idx -> label)
93
- if all(isinstance(k, str) and k.isdigit() for k in keys):
94
- idx2label = {int(k): normalize_label(v) for k, v in label_map.items()}
95
- classes = [idx2label[i] for i in sorted(idx2label.keys())]
96
- return classes
97
-
98
- # Case B: {"eczema":0, "acne":1} (label -> idx)
99
- if all(isinstance(v, (int, float)) and float(v).is_integer() for v in vals):
100
- label2idx = {normalize_label(k): int(v) for k, v in label_map.items()}
101
- classes = [lab for lab, _ in sorted(label2idx.items(), key=lambda x: x[1])]
102
- return classes
103
-
104
- # Case C: weird format like {"eczema": ["eczema"], ...}
105
- # Use KEYS as labels (dedup + sorted)
106
- classes = sorted({normalize_label(k) for k in keys})
107
- return classes
108
-
109
- CLASSES = build_classes(LABEL_MAP)
110
- NUM_CLASSES = len(CLASSES)
111
-
112
- print("✅ NUM_CLASSES:", NUM_CLASSES)
113
- print("✅ First labels:", CLASSES[:10])
114
 
115
- IMG_BACKBONE = CFG.get("img_backbone", "tf_efficientnetv2_s")
116
- IMG_SIZE = int(CFG.get("img_size", 384))
117
- TEXT_MODEL_NAME = CFG.get(
118
- "text_model_name",
119
- "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
120
- )
121
- MAX_LEN = int(CFG.get("max_len", 128))
122
 
 
 
123
 
124
  IMG_BACKBONE = CFG.get("img_backbone", "tf_efficientnetv2_s")
125
  IMG_SIZE = int(CFG.get("img_size", 384))
 
59
  CFG = load_json("fusion_config.json")
60
  LABEL_MAP = load_json("label_map.json")
61
 
62
+ # ✅ your label_map.json is a wrapper: {"classes":[...], "label2idx":{...}}
63
+ if isinstance(LABEL_MAP, dict) and "classes" in LABEL_MAP and "label2idx" in LABEL_MAP:
64
+ CLASSES = [str(x).strip().lower() for x in LABEL_MAP["classes"]]
65
+ label2idx = {str(k).strip().lower(): int(v) for k, v in LABEL_MAP["label2idx"].items()}
66
+
67
+ # (optional but good) sanity check: classes order matches label2idx
68
+ if len(CLASSES) != len(label2idx):
69
+ raise ValueError("label_map.json mismatch: len(classes) != len(label2idx)")
70
+ for i, lab in enumerate(CLASSES):
71
+ if label2idx.get(lab, None) != i:
72
+ # if mismatch, rebuild CLASSES from label2idx index order
73
+ idx2label = {int(v): str(k).strip().lower() for k, v in LABEL_MAP["label2idx"].items()}
74
+ CLASSES = [idx2label[i] for i in sorted(idx2label.keys())]
75
+ break
76
+
77
+ else:
78
+ # Fallback: support plain formats too
79
+ if all(isinstance(k, str) and k.isdigit() for k in LABEL_MAP.keys()): # {"0":"eczema",...}
80
+ idx2label = {int(k): str(v).strip().lower() for k, v in LABEL_MAP.items()}
81
+ CLASSES = [idx2label[i] for i in sorted(idx2label.keys())]
82
+ label2idx = {c: i for i, c in enumerate(CLASSES)}
83
+ else: # {"eczema":0,...}
84
+ label2idx = {str(k).strip().lower(): int(v) for k, v in LABEL_MAP.items()}
85
+ CLASSES = [c for c, _ in sorted(label2idx.items(), key=lambda x: x[1])]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
+ NUM_CLASSES = len(CLASSES)
 
 
 
 
 
 
88
 
89
+ print("✅ NUM_CLASSES:", NUM_CLASSES) # should print 16
90
+ print("✅ First labels:", CLASSES[:5]) # sanity check
91
 
92
  IMG_BACKBONE = CFG.get("img_backbone", "tf_efficientnetv2_s")
93
  IMG_SIZE = int(CFG.get("img_size", 384))