File size: 3,265 Bytes
148d42e
 
 
 
 
 
 
 
 
 
 
 
 
 
bf6ec79
 
148d42e
 
 
 
 
 
bf6ec79
148d42e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf6ec79
148d42e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import system
from utils import load_cfg

from PIL import Image
import torchvision.transforms as transforms
import gradio as gr
from torchvision.transforms import v2


import wandb

# Load the experiment configuration: `cfg.system` configures the model and
# `cfg.data` supplies the input-pipeline settings used further below.
cfg = load_cfg("configs/effb0-base-breakhis.yaml")

# Prefer the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Restore the trained weights. Tensors are mapped onto the CPU while loading,
# then the whole module is moved to `device` and put in evaluation mode.
model = system.SimpleClassificationSystem.load_from_checkpoint(
    "runs/03_20_2025_20_19_28/wandb/histopath/03_20_2025_20_19_28/checkpoints/epoch=210-step=41778.ckpt",
    map_location=torch.device("cpu"),
    cfg=cfg.system,
)
model.to(device).eval()

print("Model loaded successfully!")

# Target size handed to the Resize step of preprocess_image.
image_size = cfg.data.image_size

def preprocess_image(image: Image.Image):
    '''Preprocess the image to be compatible with the model.

    Args:
        image: a PIL image (``predict`` passes an RGB-converted image).

    Returns:
        A float32 tensor with a leading batch dimension of size 1, moved to
        the module-level ``device``.
    '''
    # Every step is a torchvision.transforms.v2 transform, so compose them
    # with the matching v2.Compose.
    transform = v2.Compose([
        # NOTE(review): an int size rescales the shorter side (aspect ratio
        # kept), a (h, w) pair resizes exactly -- confirm cfg.data.image_size
        # matches what the training pipeline used.
        v2.Resize(image_size, antialias=True),
        v2.PILToTensor(),                       # PIL -> uint8 channels-first tensor
        v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
        # Standard ImageNet channel statistics; presumably the same values
        # used at training time -- confirm against the data module.
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # unsqueeze(0) adds the batch dimension the model expects.
    return transform(image).unsqueeze(0).to(device)


def predict(image_path: str):
    '''Predict the class of the image.

    Args:
        image_path: filesystem path of the uploaded image, or None when the
            input is empty.

    Returns:
        "Predicted Class: Benign" or "Predicted Class: Malignant", or None
        (which leaves the output empty) when no image is provided.
    '''
    # The interface is created with live=True, so Gradio re-runs this on every
    # input change -- including clearing the image, which passes None.
    if image_path is None:
        return None

    # convert() returns a new, fully loaded image, so it stays usable after
    # the source file is closed by the context manager.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")

    batch = preprocess_image(rgb_image)
    with torch.no_grad():
        # NOTE(review): assumes the model emits a single logit per image, so
        # squeeze() yields a scalar for .item() -- confirm against the model.
        logit = model(batch).squeeze()
        # sigmoid + round thresholds the probability at 0.5: 0 -> Benign.
        pred = torch.sigmoid(logit).round()
    return f"Predicted Class: {'Benign' if int(pred.item())==0 else 'Malignant'}"

# Sample BreakHis images offered in the UI. Every path has the layout
# examples/breakhis/<benign|malignant>/<subtype>/<file>.png.
examples = [
    "examples/breakhis/benign/adenosis/SOB_B_A-14-22549AB-40-001.png",
    "examples/breakhis/benign/adenosis/SOB_B_A-14-22549AB-100-001.png",
    "examples/breakhis/benign/adenosis/SOB_B_A-14-22549AB-200-013.png",
    "examples/breakhis/benign/adenosis/SOB_B_A-14-22549AB-400-006.png",
    "examples/breakhis/benign/fibroadenoma/SOB_B_F-14-9133-40-001.png",
    "examples/breakhis/benign/fibroadenoma/SOB_B_F-14-9133-100-010.png",
    "examples/breakhis/benign/fibroadenoma/SOB_B_F-14-9133-200-011.png",
    "examples/breakhis/benign/fibroadenoma/SOB_B_F-14-9133-400-006.png",
    "examples/breakhis/malignant/ductal/SOB_M_DC-14-2523-40-010.png",
    "examples/breakhis/malignant/ductal/SOB_M_DC-14-2523-100-024.png",
    "examples/breakhis/malignant/ductal/SOB_M_DC-14-2523-200-027.png",
    "examples/breakhis/malignant/ductal/SOB_M_DC-14-2523-400-013.png",
    "examples/breakhis/malignant/lobular/SOB_M_LC-14-12204-40-002.png",
    "examples/breakhis/malignant/lobular/SOB_M_LC-14-12204-100-031.png",
    "examples/breakhis/malignant/lobular/SOB_M_LC-14-12204-200-031.png",
    "examples/breakhis/malignant/lobular/SOB_M_LC-14-12204-400-034.png",
]

# Take each label from the path's class directory (third-from-last component)
# instead of hard-coding counts, so labels stay aligned with `examples` when
# entries are added, removed, or reordered.
examples_labels = [path.split("/")[-3] for path in examples]

# One [image_path, label] row per example for the Gradio examples table.
examples_with_labels = [[path, label] for path, label in zip(examples, examples_labels)]

# Gradio Interface
# NOTE(review): the interface has a single input component, so the label
# column of each example row may not be shown by Gradio -- confirm against
# the installed Gradio version.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="filepath", label="Upload an image"),
    outputs=gr.Label(),
    live=True,  # re-run predict on every input change
    examples=examples_with_labels,
    title="Histopathology Image Classification",
    description="This application classifies histopathology images as either benign or malignant. Upload an image to get the prediction.",
    examples_per_page=len(examples_with_labels),
)

# Components are only displayed when created inside a Blocks context, so the
# interface is rendered into a Blocks app alongside the deep-link button.
# NOTE(review): DeepLinkButton is intended for shared/hosted apps -- confirm
# it behaves as expected with share=True.
with gr.Blocks(title="Histopathology Image Classification") as demo:
    interface.render()
    gr.DeepLinkButton()

if __name__ == "__main__":
    demo.launch(share=True)