pragun3669 commited on
Commit
6528f19
·
verified ·
1 Parent(s): 432c95a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -16
app.py CHANGED
@@ -13,7 +13,7 @@ from pymongo import MongoClient
13
  import warnings
14
 
15
  # =====================
16
- # Silence HF warnings (safe)
17
  # =====================
18
  warnings.filterwarnings("ignore")
19
 
@@ -35,7 +35,7 @@ db = client["skin-disease-db"]
35
  reports = db["reports"]
36
 
37
  # =====================
38
- # Labels
39
  # =====================
40
  labels = [
41
  "Acne and Rosacea Photos",
@@ -67,40 +67,36 @@ NUM_CLASSES = len(labels)
67
  device = torch.device("cpu")
68
 
69
  # =====================
70
- # Load model weights from HF Hub
71
  # =====================
72
  weights_path = hf_hub_download(
73
  repo_id="pragun3669/dermify-vit",
74
  filename="best_vit1_model.pth"
75
  )
76
 
77
- # Base ViT architecture (23 classes)
78
  model = ViTForImageClassification.from_pretrained(
79
  "google/vit-large-patch16-224",
80
  num_labels=NUM_CLASSES,
81
  ignore_mismatched_sizes=True
82
  )
83
 
84
- # Load checkpoint safely (IGNORE classifier mismatch)
85
  state_dict = torch.load(weights_path, map_location=device)
86
-
87
- # 🔥 Drop classifier weights if they exist (1000-class head)
88
- state_dict = {
89
- k: v for k, v in state_dict.items()
90
- if not k.startswith("classifier")
91
- }
92
-
93
  model.load_state_dict(state_dict, strict=False)
 
94
  model.to(device)
95
  model.eval()
96
 
97
  # =====================
98
- # Image Transform
99
  # =====================
100
  transform = transforms.Compose([
101
  transforms.Resize((224, 224)),
102
  transforms.ToTensor(),
103
- transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
 
 
 
104
  ])
105
 
106
  # =====================
@@ -126,7 +122,7 @@ def predict():
126
  logits = model(tensor).logits
127
  probs = F.softmax(logits, dim=1)
128
 
129
- idx = probs.argmax().item()
130
  confidence = probs[0][idx].item()
131
 
132
  reports.insert_one({
@@ -141,7 +137,7 @@ def predict():
141
  })
142
 
143
  # =====================
144
- # Run (HF Spaces compatible)
145
  # =====================
146
  if __name__ == "__main__":
147
  app.run(host="0.0.0.0", port=7860)
 
13
  import warnings
14
 
15
  # =====================
16
+ # Silence HF warnings
17
  # =====================
18
  warnings.filterwarnings("ignore")
19
 
 
35
  reports = db["reports"]
36
 
37
  # =====================
38
+ # Labels (ORDER MUST MATCH TRAINING)
39
  # =====================
40
  labels = [
41
  "Acne and Rosacea Photos",
 
67
  device = torch.device("cpu")
68
 
69
  # =====================
70
+ # Load trained model
71
  # =====================
72
  weights_path = hf_hub_download(
73
  repo_id="pragun3669/dermify-vit",
74
  filename="best_vit1_model.pth"
75
  )
76
 
 
77
  model = ViTForImageClassification.from_pretrained(
78
  "google/vit-large-patch16-224",
79
  num_labels=NUM_CLASSES,
80
  ignore_mismatched_sizes=True
81
  )
82
 
83
+ # LOAD FULL TRAINED STATE (INCLUDING CLASSIFIER)
84
  state_dict = torch.load(weights_path, map_location=device)
 
 
 
 
 
 
 
85
  model.load_state_dict(state_dict, strict=False)
86
+
87
  model.to(device)
88
  model.eval()
89
 
90
  # =====================
91
+ # Image Transform (MATCH TRAINING)
92
  # =====================
93
  transform = transforms.Compose([
94
  transforms.Resize((224, 224)),
95
  transforms.ToTensor(),
96
+ transforms.Normalize(
97
+ mean=[0.5, 0.5, 0.5],
98
+ std=[0.5, 0.5, 0.5]
99
+ )
100
  ])
101
 
102
  # =====================
 
122
  logits = model(tensor).logits
123
  probs = F.softmax(logits, dim=1)
124
 
125
+ idx = probs.argmax(dim=1).item()
126
  confidence = probs[0][idx].item()
127
 
128
  reports.insert_one({
 
137
  })
138
 
139
  # =====================
140
+ # Run (HF Spaces)
141
  # =====================
142
  if __name__ == "__main__":
143
  app.run(host="0.0.0.0", port=7860)