Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -41,7 +41,15 @@ CRC_LABELS = {
|
|
| 41 |
7: "STR",
|
| 42 |
8: "TUM",
|
| 43 |
}
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
print("Loading DinoV2 base model...")
|
| 46 |
dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')
|
| 47 |
|
|
@@ -59,7 +67,7 @@ dinov2.eval()
|
|
| 59 |
|
| 60 |
#torch.save(dinov2.state_dict(), "teacher_checkpoint_load.pt")
|
| 61 |
|
| 62 |
-
def setup_linear(path):
|
| 63 |
print(f"Loading {path} linear classifier...")
|
| 64 |
# Load the best checkpoint from the latest run
|
| 65 |
linear_checkpoint = torch.load(path)
|
|
@@ -67,25 +75,12 @@ def setup_linear(path):
|
|
| 67 |
linear_bias = linear_checkpoint["state_dict"]["head.bias"]
|
| 68 |
|
| 69 |
# Create linear layer
|
| 70 |
-
linear = torch.nn.Linear(1536,
|
| 71 |
linear.weight.data = linear_weights
|
| 72 |
linear.bias.data = linear_bias
|
| 73 |
linear.eval()
|
| 74 |
return linear
|
| 75 |
|
| 76 |
-
def setup_linear_crc(path):
|
| 77 |
-
print(f"Loading {path} linear classifier...")
|
| 78 |
-
# Load the best checkpoint from the latest run
|
| 79 |
-
linear_checkpoint = torch.load(path)
|
| 80 |
-
linear_weights = linear_checkpoint["state_dict"]["head.weight"]
|
| 81 |
-
linear_bias = linear_checkpoint["state_dict"]["head.bias"]
|
| 82 |
-
|
| 83 |
-
# Create linear layer
|
| 84 |
-
linear = torch.nn.Linear(1536, 9)
|
| 85 |
-
linear.weight.data = linear_weights
|
| 86 |
-
linear.bias.data = linear_bias
|
| 87 |
-
linear.eval()
|
| 88 |
-
return linear
|
| 89 |
|
| 90 |
# Move models to GPU if available
|
| 91 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
@@ -95,7 +90,7 @@ breakhis_path = "./breakhis_best.ckpt"
|
|
| 95 |
breakhis_linear = setup_linear(breakhis_path).to(device)
|
| 96 |
|
| 97 |
gleason_path = "./gleason_best.ckpt"
|
| 98 |
-
gleason_linear = setup_linear(gleason_path).to(device)
|
| 99 |
|
| 100 |
bach_path = "./bach_best.ckpt"
|
| 101 |
bach_linear = setup_linear(bach_path).to(device)
|
|
@@ -103,6 +98,8 @@ bach_linear = setup_linear(bach_path).to(device)
|
|
| 103 |
crc_path = "./crc_best.ckpt"
|
| 104 |
crc_linear = setup_linear_crc(crc_path).to(device)
|
| 105 |
|
|
|
|
|
|
|
| 106 |
|
| 107 |
print(f"Models loaded on {device}")
|
| 108 |
|
|
@@ -142,6 +139,10 @@ def predict_bach(image):
|
|
| 142 |
def predict_crc(image):
|
| 143 |
|
| 144 |
return predict_class(image, crc_linear, "crc")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
|
| 147 |
def predict_class(image, linear, dataset):
|
|
@@ -181,7 +182,8 @@ def predict_class(image, linear, dataset):
|
|
| 181 |
probs_dict[BACH_LABELS[idx]] = float(prob)
|
| 182 |
elif dataset == "crc":
|
| 183 |
probs_dict[CRC_LABELS[idx]] = float(prob)
|
| 184 |
-
|
|
|
|
| 185 |
|
| 186 |
return probs_dict
|
| 187 |
|
|
@@ -277,11 +279,30 @@ bach = gr.Interface(
|
|
| 277 |
], # You can add example image paths here
|
| 278 |
theme=gr.themes.Soft()
|
| 279 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
|
|
|
|
| 283 |
|
| 284 |
-
demo = gr.TabbedInterface([breakhis, gleason, crc, bach],["BreakHis", "Gleason", "CRC", "Bach"])
|
| 285 |
|
| 286 |
|
| 287 |
if __name__ == "__main__":
|
|
|
|
| 41 |
7: "STR",
|
| 42 |
8: "TUM",
|
| 43 |
}
|
| 44 |
+
BRACS_LABELS = {
|
| 45 |
+
0: "Normal",
|
| 46 |
+
1: "Pathological Benign",
|
| 47 |
+
2: "Usual Ductal Hyperplasia",
|
| 48 |
+
3: "Flat Epithelial Atypia",
|
| 49 |
+
4: "Atypical Ductal Hyperplasia",
|
| 50 |
+
5: "Ductal Carcinoma In Situ"
|
| 51 |
+
6: "Invasive Carcinoma"
|
| 52 |
+
}
|
| 53 |
print("Loading DinoV2 base model...")
|
| 54 |
dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')
|
| 55 |
|
|
|
|
| 67 |
|
| 68 |
#torch.save(dinov2.state_dict(), "teacher_checkpoint_load.pt")
|
| 69 |
|
| 70 |
+
def setup_linear(path, classes = 4):
|
| 71 |
print(f"Loading {path} linear classifier...")
|
| 72 |
# Load the best checkpoint from the latest run
|
| 73 |
linear_checkpoint = torch.load(path)
|
|
|
|
| 75 |
linear_bias = linear_checkpoint["state_dict"]["head.bias"]
|
| 76 |
|
| 77 |
# Create linear layer
|
| 78 |
+
linear = torch.nn.Linear(1536, classes)
|
| 79 |
linear.weight.data = linear_weights
|
| 80 |
linear.bias.data = linear_bias
|
| 81 |
linear.eval()
|
| 82 |
return linear
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
# Move models to GPU if available
|
| 86 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 90 |
breakhis_linear = setup_linear(breakhis_path).to(device)
|
| 91 |
|
| 92 |
gleason_path = "./gleason_best.ckpt"
|
| 93 |
+
gleason_linear = setup_linear(gleason_path, 9).to(device)
|
| 94 |
|
| 95 |
bach_path = "./bach_best.ckpt"
|
| 96 |
bach_linear = setup_linear(bach_path).to(device)
|
|
|
|
| 98 |
crc_path = "./crc_best.ckpt"
|
| 99 |
crc_linear = setup_linear_crc(crc_path).to(device)
|
| 100 |
|
| 101 |
+
#bracs_path = "./bracs_best.ckpt"
|
| 102 |
+
#bracs_linear = setup_linear(bracs_path, 7).to(device)
|
| 103 |
|
| 104 |
print(f"Models loaded on {device}")
|
| 105 |
|
|
|
|
| 139 |
def predict_crc(image):
|
| 140 |
|
| 141 |
return predict_class(image, crc_linear, "crc")
|
| 142 |
+
|
| 143 |
+
def predict_bracs(image):
|
| 144 |
+
|
| 145 |
+
return predict_class(image, bracs_linear, "bracs")
|
| 146 |
|
| 147 |
|
| 148 |
def predict_class(image, linear, dataset):
|
|
|
|
| 182 |
probs_dict[BACH_LABELS[idx]] = float(prob)
|
| 183 |
elif dataset == "crc":
|
| 184 |
probs_dict[CRC_LABELS[idx]] = float(prob)
|
| 185 |
+
elif dataset == "bracs":
|
| 186 |
+
probs_dict[BRACS_LABELS[idx]] = float(prob)
|
| 187 |
|
| 188 |
return probs_dict
|
| 189 |
|
|
|
|
| 279 |
], # You can add example image paths here
|
| 280 |
theme=gr.themes.Soft()
|
| 281 |
)
|
| 282 |
+
bracs = gr.Interface(
|
| 283 |
+
fn=predict_bracs,
|
| 284 |
+
inputs=gr.Image(type="filepath", label="Upload Cancer Image"),
|
| 285 |
+
outputs=gr.Label(num_top_classes=7, label="Bracs Tumor Type Prediction"),
|
| 286 |
+
title="Tumor Classification",
|
| 287 |
+
description="""
|
| 288 |
+
Upload a prostate cancer image to predict the tumor type. Your image must be at 40X magnification. Do not otherwise modify your image.
|
| 289 |
+
|
| 290 |
+
This model uses a custom-trained DinoV2 foundation model for pathology images with a linear classifier for tumor classification.
|
| 291 |
|
| 292 |
+
Images are classified as Normal, Pathological Benign, Usual Ductal Hyperplasia, Flat Epithelial Atypia,
|
| 293 |
+
Atypical Ductal Hyperplasia, Ductal Carcinoma In Situ, Invasive Carcinoma
|
| 294 |
+
|
| 295 |
+
For this particular demo, images *must* be one of the sample classes - unsupported classes will yield confusing and/or useless results.
|
| 296 |
+
""",
|
| 297 |
+
examples=[
|
| 298 |
+
], # You can add example image paths here
|
| 299 |
+
theme=gr.themes.Soft()
|
| 300 |
+
)
|
| 301 |
|
| 302 |
|
| 303 |
+
|
| 304 |
|
| 305 |
+
demo = gr.TabbedInterface([breakhis, gleason, crc, bach, bracs],["BreakHis", "Gleason", "CRC", "Bach", "BRACS"])
|
| 306 |
|
| 307 |
|
| 308 |
if __name__ == "__main__":
|