muruga778 commited on
Commit
007ae18
·
verified ·
1 Parent(s): af19b8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -26
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import json, os
2
  import numpy as np
3
  from PIL import Image
4
- import ast
5
 
6
 
7
  import torch
@@ -10,6 +10,31 @@ import timm
10
  from timm.data import resolve_model_data_config, create_transform
11
  from transformers import AutoTokenizer, AutoModel
12
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
 
@@ -59,35 +84,26 @@ def triage(label, conf, text):
59
  CFG = load_json("fusion_config.json")
60
  LABEL_MAP = load_json("label_map.json")
61
 
62
- # your label_map.json is a wrapper: {"classes":[...], "label2idx":{...}}
63
- if isinstance(LABEL_MAP, dict) and "classes" in LABEL_MAP and "label2idx" in LABEL_MAP:
64
- CLASSES = [str(x).strip().lower() for x in LABEL_MAP["classes"]]
65
- label2idx = {str(k).strip().lower(): int(v) for k, v in LABEL_MAP["label2idx"].items()}
66
-
67
- # (optional but good) sanity check: classes order matches label2idx
68
- if len(CLASSES) != len(label2idx):
69
- raise ValueError("label_map.json mismatch: len(classes) != len(label2idx)")
70
- for i, lab in enumerate(CLASSES):
71
- if label2idx.get(lab, None) != i:
72
- # if mismatch, rebuild CLASSES from label2idx index order
73
- idx2label = {int(v): str(k).strip().lower() for k, v in LABEL_MAP["label2idx"].items()}
74
- CLASSES = [idx2label[i] for i in sorted(idx2label.keys())]
75
- break
76
 
77
  else:
78
- # Fallback: support plain formats too
79
- if all(isinstance(k, str) and k.isdigit() for k in LABEL_MAP.keys()): # {"0":"eczema",...}
80
- idx2label = {int(k): str(v).strip().lower() for k, v in LABEL_MAP.items()}
81
- CLASSES = [idx2label[i] for i in sorted(idx2label.keys())]
82
- label2idx = {c: i for i, c in enumerate(CLASSES)}
83
- else: # {"eczema":0,...}
84
- label2idx = {str(k).strip().lower(): int(v) for k, v in LABEL_MAP.items()}
85
- CLASSES = [c for c, _ in sorted(label2idx.items(), key=lambda x: x[1])]
86
 
87
  NUM_CLASSES = len(CLASSES)
88
-
89
- print("✅ NUM_CLASSES:", NUM_CLASSES) # should print 16
90
- print("✅ First labels:", CLASSES[:5]) # sanity check
91
 
92
  IMG_BACKBONE = CFG.get("img_backbone", "tf_efficientnetv2_s")
93
  IMG_SIZE = int(CFG.get("img_size", 384))
 
1
  import json, os
2
  import numpy as np
3
  from PIL import Image
4
+
5
 
6
 
7
  import torch
 
10
  from timm.data import resolve_model_data_config, create_transform
11
  from transformers import AutoTokenizer, AutoModel
12
  import gradio as gr
13
+ import ast
14
+ from huggingface_hub import hf_hub_download
15
+
16
+ SPACE_REPO = os.getenv("SPACE_REPO_NAME", "muruga778/api_for_model") # change if your space id differs
17
+
18
+ def safe_torch_load(filename: str):
19
+ """
20
+ 1) try local file
21
+ 2) if corrupted -> force-download from Hub cache and load again
22
+ """
23
+ try:
24
+ print(f"🔎 Loading weights: {filename} (local)")
25
+ return torch.load(filename, map_location="cpu")
26
+ except Exception as e:
27
+ print(f"⚠️ Local load failed for {filename}: {repr(e)}")
28
+ print("⬇️ Force-downloading from Hugging Face Hub cache...")
29
+ cached = hf_hub_download(
30
+ repo_id=SPACE_REPO,
31
+ repo_type="space",
32
+ filename=filename,
33
+ force_download=True,
34
+ )
35
+ print("✅ Downloaded to:", cached, "size(MB)=", os.path.getsize(cached)/1024/1024)
36
+ return torch.load(cached, map_location="cpu")
37
+
38
 
39
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
40
 
 
84
  CFG = load_json("fusion_config.json")
85
  LABEL_MAP = load_json("label_map.json")
86
 
87
+ # Your label_map.json looks like: {"classes":[...], "label2idx":{...}}
88
+ if isinstance(LABEL_MAP, dict) and "classes" in LABEL_MAP and isinstance(LABEL_MAP["classes"], list):
89
+ CLASSES = [str(x) for x in LABEL_MAP["classes"]]
90
+ label2idx = LABEL_MAP.get("label2idx", {c: i for i, c in enumerate(CLASSES)})
91
+
92
+ # Older possible formats:
93
+ elif isinstance(LABEL_MAP, dict) and all(isinstance(k, str) and k.isdigit() for k in LABEL_MAP.keys()):
94
+ # {"0":"eczema", ...}
95
+ idx2label = {int(k): str(v) for k, v in LABEL_MAP.items()}
96
+ CLASSES = [idx2label[i] for i in sorted(idx2label.keys())]
97
+ label2idx = {c: i for i, c in enumerate(CLASSES)}
 
 
 
98
 
99
  else:
100
+ # {"eczema": 0, ...}
101
+ label2idx = {str(k): int(v) for k, v in LABEL_MAP.items()}
102
+ CLASSES = [c for c, _ in sorted(label2idx.items(), key=lambda x: x[1])]
 
 
 
 
 
103
 
104
  NUM_CLASSES = len(CLASSES)
105
+ print("✅ NUM_CLASSES:", NUM_CLASSES)
106
+ print("✅ First labels:", CLASSES[:5])
 
107
 
108
  IMG_BACKBONE = CFG.get("img_backbone", "tf_efficientnetv2_s")
109
  IMG_SIZE = int(CFG.get("img_size", 384))