"""Gradio demo that classifies BreakHis histopathology images as benign or malignant.

Importing this module loads the configuration and the model checkpoint and
builds the Gradio interface; running it as a script launches the web UI.
"""

from typing import Optional

import torch
import system
from utils import load_cfg
from PIL import Image
import torchvision.transforms as transforms
import gradio as gr
from torchvision.transforms import v2
import wandb

cfg = load_cfg("configs/effb0-base-breakhis.yaml")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The CPU device is passed as the second positional argument (presumably the
# checkpoint's map_location -- confirm against SimpleClassificationSystem);
# the model is moved to the selected device afterwards.
model = system.SimpleClassificationSystem.load_from_checkpoint(
    "runs/03_20_2025_20_19_28/wandb/histopath/03_20_2025_20_19_28/checkpoints/epoch=210-step=41778.ckpt",
    torch.device("cpu"),
    cfg=cfg.system,
)
model.to(device)
model.eval()
print("Model loaded successfully!")

image_size = cfg.data.image_size

# Built once at import time: the pipeline only depends on `image_size`, so
# there is no need to reconstruct it for every request.
_transform = transforms.Compose([
    v2.Resize(image_size, antialias=True),
    v2.PILToTensor(),
    v2.ToDtype(torch.float32, scale=True),
    # The usual ImageNet per-channel mean/std values.
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalized, batched tensor on ``device``.

    Args:
        image: The PIL image to preprocess (``predict`` passes an RGB image).

    Returns:
        The resized, float32-scaled and normalized image with a leading
        batch dimension added, moved to the module-level ``device``.
    """
    return _transform(image).unsqueeze(0).to(device)


def predict(image_path: Optional[str]) -> Optional[str]:
    """Predict the class of the image stored at ``image_path``.

    Args:
        image_path: Filesystem path of the image to classify. ``None`` (or an
            empty string) means no image is currently selected.

    Returns:
        ``"Predicted Class: Benign"`` or ``"Predicted Class: Malignant"``,
        or ``None`` when no image path was supplied.
    """
    # With a live interface the callback can fire with no image selected
    # (e.g. after the input is cleared); report "no prediction" instead of
    # failing inside Image.open.
    if not image_path:
        return None

    # Use a context manager so the underlying file handle is released as soon
    # as the pixel data has been copied by convert().
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert("RGB")

    batch = preprocess_image(rgb_image)
    with torch.no_grad():
        logit = model(batch).squeeze()

    # The sigmoid output is rounded to 0 or 1; 0 is reported as benign.
    pred = torch.sigmoid(logit).round()
    label = "Benign" if int(pred.item()) == 0 else "Malignant"
    return f"Predicted Class: {label}"


examples = [
    "examples/breakhis/benign/adenosis/SOB_B_A-14-22549AB-40-001.png",
    "examples/breakhis/benign/adenosis/SOB_B_A-14-22549AB-100-001.png",
    "examples/breakhis/benign/adenosis/SOB_B_A-14-22549AB-200-013.png",
    "examples/breakhis/benign/adenosis/SOB_B_A-14-22549AB-400-006.png",
    "examples/breakhis/benign/fibroadenoma/SOB_B_F-14-9133-40-001.png",
    "examples/breakhis/benign/fibroadenoma/SOB_B_F-14-9133-100-010.png",
    "examples/breakhis/benign/fibroadenoma/SOB_B_F-14-9133-200-011.png",
    "examples/breakhis/benign/fibroadenoma/SOB_B_F-14-9133-400-006.png",
    "examples/breakhis/malignant/ductal/SOB_M_DC-14-2523-40-010.png",
    "examples/breakhis/malignant/ductal/SOB_M_DC-14-2523-100-024.png",
    "examples/breakhis/malignant/ductal/SOB_M_DC-14-2523-200-027.png",
    "examples/breakhis/malignant/ductal/SOB_M_DC-14-2523-400-013.png",
    "examples/breakhis/malignant/lobular/SOB_M_LC-14-12204-40-002.png",
    "examples/breakhis/malignant/lobular/SOB_M_LC-14-12204-100-031.png",
    "examples/breakhis/malignant/lobular/SOB_M_LC-14-12204-200-031.png",
    "examples/breakhis/malignant/lobular/SOB_M_LC-14-12204-400-034.png",
]

# The first eight example paths are benign, the last eight malignant.
examples_labels = ["benign"] * 8 + ["malignant"] * 8
# NOTE(review): each row carries [path, label] while the interface declares a
# single input component -- confirm Gradio treats the extra column as intended.
examples_with_labels = [[example, label] for example, label in zip(examples, examples_labels)]

# Gradio Interface
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="filepath", label="Upload an image"),
    outputs=gr.Label(),
    live=True,
    examples=examples_with_labels,
    title="Histopathology Image Classification",
    description="This application classifies histopathology images as either benign or malignant. Upload an image to get the prediction.",
    examples_per_page=len(examples_with_labels),
)

# NOTE(review): this component is instantiated outside any Blocks/Interface
# context -- confirm it is actually rendered in the UI.
gr.DeepLinkButton()

if __name__ == "__main__":
    interface.launch(share=True)