KublaiKhan1 commited on
Commit
cddce8e
·
verified ·
1 Parent(s): d56667c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -19
app.py CHANGED
@@ -41,7 +41,15 @@ CRC_LABELS = {
41
  7: "STR",
42
  8: "TUM",
43
  }
44
-
 
 
 
 
 
 
 
 
45
  print("Loading DinoV2 base model...")
46
  dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')
47
 
@@ -59,7 +67,7 @@ dinov2.eval()
59
 
60
  #torch.save(dinov2.state_dict(), "teacher_checkpoint_load.pt")
61
 
62
- def setup_linear(path):
63
  print(f"Loading {path} linear classifier...")
64
  # Load the best checkpoint from the latest run
65
  linear_checkpoint = torch.load(path)
@@ -67,25 +75,12 @@ def setup_linear(path):
67
  linear_bias = linear_checkpoint["state_dict"]["head.bias"]
68
 
69
  # Create linear layer
70
- linear = torch.nn.Linear(1536, 4)
71
  linear.weight.data = linear_weights
72
  linear.bias.data = linear_bias
73
  linear.eval()
74
  return linear
75
 
76
- def setup_linear_crc(path):
77
- print(f"Loading {path} linear classifier...")
78
- # Load the best checkpoint from the latest run
79
- linear_checkpoint = torch.load(path)
80
- linear_weights = linear_checkpoint["state_dict"]["head.weight"]
81
- linear_bias = linear_checkpoint["state_dict"]["head.bias"]
82
-
83
- # Create linear layer
84
- linear = torch.nn.Linear(1536, 9)
85
- linear.weight.data = linear_weights
86
- linear.bias.data = linear_bias
87
- linear.eval()
88
- return linear
89
 
90
  # Move models to GPU if available
91
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -95,7 +90,7 @@ breakhis_path = "./breakhis_best.ckpt"
95
  breakhis_linear = setup_linear(breakhis_path).to(device)
96
 
97
  gleason_path = "./gleason_best.ckpt"
98
- gleason_linear = setup_linear(gleason_path).to(device)
99
 
100
  bach_path = "./bach_best.ckpt"
101
  bach_linear = setup_linear(bach_path).to(device)
@@ -103,6 +98,8 @@ bach_linear = setup_linear(bach_path).to(device)
103
  crc_path = "./crc_best.ckpt"
104
  crc_linear = setup_linear_crc(crc_path).to(device)
105
 
 
 
106
 
107
  print(f"Models loaded on {device}")
108
 
@@ -142,6 +139,10 @@ def predict_bach(image):
142
  def predict_crc(image):
143
 
144
  return predict_class(image, crc_linear, "crc")
 
 
 
 
145
 
146
 
147
  def predict_class(image, linear, dataset):
@@ -181,7 +182,8 @@ def predict_class(image, linear, dataset):
181
  probs_dict[BACH_LABELS[idx]] = float(prob)
182
  elif dataset == "crc":
183
  probs_dict[CRC_LABELS[idx]] = float(prob)
184
-
 
185
 
186
  return probs_dict
187
 
@@ -277,11 +279,30 @@ bach = gr.Interface(
277
  ], # You can add example image paths here
278
  theme=gr.themes.Soft()
279
  )
 
 
 
 
 
 
 
 
 
280
 
 
 
 
 
 
 
 
 
 
281
 
282
 
 
283
 
284
- demo = gr.TabbedInterface([breakhis, gleason, crc, bach],["BreakHis", "Gleason", "CRC", "Bach"])
285
 
286
 
287
  if __name__ == "__main__":
 
41
  7: "STR",
42
  8: "TUM",
43
  }
44
+ BRACS_LABELS = {
45
+ 0: "Normal",
46
+ 1: "Pathological Benign",
47
+ 2: "Usual Ductal Hyperplasia",
48
+ 3: "Flat Epithelial Atypia",
49
+ 4: "Atypical Ductal Hyperplasia",
50
+ 5: "Ductal Carcinoma In Situ"
51
+ 6: "Invasive Carcinoma"
52
+ }
53
  print("Loading DinoV2 base model...")
54
  dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')
55
 
 
67
 
68
  #torch.save(dinov2.state_dict(), "teacher_checkpoint_load.pt")
69
 
70
+ def setup_linear(path, classes = 4):
71
  print(f"Loading {path} linear classifier...")
72
  # Load the best checkpoint from the latest run
73
  linear_checkpoint = torch.load(path)
 
75
  linear_bias = linear_checkpoint["state_dict"]["head.bias"]
76
 
77
  # Create linear layer
78
+ linear = torch.nn.Linear(1536, classes)
79
  linear.weight.data = linear_weights
80
  linear.bias.data = linear_bias
81
  linear.eval()
82
  return linear
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  # Move models to GPU if available
86
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
90
  breakhis_linear = setup_linear(breakhis_path).to(device)
91
 
92
  gleason_path = "./gleason_best.ckpt"
93
+ gleason_linear = setup_linear(gleason_path, 9).to(device)
94
 
95
  bach_path = "./bach_best.ckpt"
96
  bach_linear = setup_linear(bach_path).to(device)
 
98
  crc_path = "./crc_best.ckpt"
99
  crc_linear = setup_linear_crc(crc_path).to(device)
100
 
101
+ #bracs_path = "./bracs_best.ckpt"
102
+ #bracs_linear = setup_linear(bracs_path, 7).to(device)
103
 
104
  print(f"Models loaded on {device}")
105
 
 
139
  def predict_crc(image):
140
 
141
  return predict_class(image, crc_linear, "crc")
142
+
143
+ def predict_bracs(image):
144
+
145
+ return predict_class(image, bracs_linear, "bracs")
146
 
147
 
148
  def predict_class(image, linear, dataset):
 
182
  probs_dict[BACH_LABELS[idx]] = float(prob)
183
  elif dataset == "crc":
184
  probs_dict[CRC_LABELS[idx]] = float(prob)
185
+ elif dataset == "bracs":
186
+ probs_dict[BRACS_LABELS[idx]] = float(prob)
187
 
188
  return probs_dict
189
 
 
279
  ], # You can add example image paths here
280
  theme=gr.themes.Soft()
281
  )
282
+ bracs = gr.Interface(
283
+ fn=predict_bracs,
284
+ inputs=gr.Image(type="filepath", label="Upload Cancer Image"),
285
+ outputs=gr.Label(num_top_classes=7, label="Bracs Tumor Type Prediction"),
286
+ title="Tumor Classification",
287
+ description="""
288
+ Upload a prostate cancer image to predict the tumor type. Your image must be at 40X magnification. Do not otherwise modify your image.
289
+
290
+ This model uses a custom-trained DinoV2 foundation model for pathology images with a linear classifier for tumor classification.
291
 
292
+ Images are classified as Normal, Pathological Benign, Usual Ductal Hyperplasia, Flat Epithelial Atypia,
293
+ Atypical Ductal Hyperplasia, Ductal Carcinoma In Situ, Invasive Carcinoma
294
+
295
+ For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
296
+ """,
297
+ examples=[
298
+ ], # You can add example image paths here
299
+ theme=gr.themes.Soft()
300
+ )
301
 
302
 
303
+
304
 
305
+ demo = gr.TabbedInterface([breakhis, gleason, crc, bach, bracs],["BreakHis", "Gleason", "CRC", "Bach", "BRACS"])
306
 
307
 
308
  if __name__ == "__main__":