File size: 4,898 Bytes
4b0207f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed0fb60
4b0207f
ed0fb60
4b0207f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bc4c8f
4b0207f
8bc4c8f
4b0207f
 
8bc4c8f
ed0fb60
4b0207f
 
 
 
ed0fb60
 
 
 
4b0207f
 
ed0fb60
 
 
4b0207f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc6e64a
4b0207f
 
 
 
8bc4c8f
4b0207f
 
fc6e64a
 
4b0207f
4b43aaf
 
 
 
 
 
4b0207f
9667248
4b0207f
 
fc6e64a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""
app.py
------
Gradio demo for ISIC 2018 Skin Lesion Segmentation using a trained U-Net.

Hosted on Hugging Face Spaces.
Model weights are downloaded from the HF Hub model repo on first run.
"""

import os
import numpy as np
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Where the trained checkpoint lives on the Hugging Face Hub.
MODEL_REPO = "pavanpraneeth/isic-unet"
MODEL_FILE = "best_model.pth"

# Square side length the model input is resized to.
IMAGE_SIZE = 256
# Cut-off applied to the model's per-pixel output when binarising the mask.
THRESHOLD = 0.5

# Per-channel (R, G, B) ImageNet statistics used to standardise inputs.
IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# ---------------------------------------------------------------------------
# Load model (once at startup)
# ---------------------------------------------------------------------------
from model import UNet  # model.py is alongside app.py in the Space repo

def load_model() -> torch.nn.Module:
    """Fetch the checkpoint from the Hub and return a ready-to-run U-Net.

    ``MODEL_FILE`` is downloaded from ``MODEL_REPO`` with ``hf_hub_download``.
    It is deserialised onto ``DEVICE``, and its weights are loaded into a fresh
    ``UNet(in_channels=3, out_channels=1)``. The model is switched to eval
    mode and moved to ``DEVICE``.

    The checkpoint may be either a training checkpoint dict that stores the
    weights under ``"model_state_dict"``, or a bare ``state_dict``.

    Returns:
        The U-Net in eval mode on ``DEVICE``.
    """
    ckpt_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
    model = UNet(in_channels=3, out_channels=1)

    # NOTE(security): torch.load unpickles the file, which can execute code
    # embedded in it. Only point MODEL_REPO at a repository you trust.
    checkpoint = torch.load(ckpt_path, map_location=DEVICE)

    # Training scripts usually wrap the weights alongside optimizer/epoch info.
    # A re-exported weights-only file is a plain state_dict; accept both.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    model.load_state_dict(state_dict)
    model.eval().to(DEVICE)
    print(f"[app] Model loaded from {MODEL_REPO} on {DEVICE}")
    return model

MODEL = load_model()

# ---------------------------------------------------------------------------
# Preprocessing / postprocessing helpers
# ---------------------------------------------------------------------------

def preprocess(pil_img: Image.Image) -> torch.Tensor:
    """Build a single-image, normalised input batch for the U-Net.

    Steps:
        1. Force RGB.
        2. Resize to ``IMAGE_SIZE`` x ``IMAGE_SIZE``.
        3. Scale pixel values to [0, 1].
        4. Standardise with the ImageNet mean/std.
        5. Move channels first.
        6. Add a leading batch axis and place the result on ``DEVICE``.

    Returns:
        float32 tensor of shape (1, 3, IMAGE_SIZE, IMAGE_SIZE).
    """
    square = pil_img.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
    unit_range = np.array(square, dtype=np.float32) / 255.0      # (H, W, 3)
    standardised = (unit_range - IMAGENET_MEAN) / IMAGENET_STD
    channels_first = np.moveaxis(standardised, -1, 0)             # (3, H, W)
    batch = torch.from_numpy(channels_first)[None]                # (1, 3, H, W)
    return batch.to(DEVICE)


def postprocess_mask(pred: torch.Tensor) -> np.ndarray:
    """Binarise the model output into a displayable 0/255 uint8 mask.

    Singleton dimensions are squeezed out (e.g. (1, 1, H, W) -> (H, W)). The
    values are copied to host memory. Every value strictly greater than
    ``THRESHOLD`` becomes 255 and everything else becomes 0.

    NOTE(review): ``THRESHOLD = 0.5`` assumes the model already applies a
    sigmoid. If ``UNet`` returns raw logits, the cut-off should be 0 instead;
    confirm against model.py.
    """
    scores = pred.squeeze().cpu().numpy()
    return np.where(scores > THRESHOLD, np.uint8(255), np.uint8(0))


def make_overlay(
    original_rgb: np.ndarray,
    mask: np.ndarray,
    color: tuple[int, int, int] = (255, 0, 0),
    alpha: float = 0.6,
) -> np.ndarray:
    """Tint every masked pixel of an image with a translucent colour.

    The whole masked region is filled, not only its outline. The image is
    first brought to the mask's (H, W) resolution.

    Args:
        original_rgb: Image array, expected uint8. Shape (H0, W0, 3) is the
            normal case. (H0, W0) greyscale is replicated to three channels,
            and (H0, W0, 4) RGBA has its alpha channel dropped.
        mask: 2-D array of shape (H, W); pixels with a value > 0 are tinted.
        color: RGB tint colour. Defaults to red.
        alpha: Weight of ``color`` in the blend. The image pixel gets
            ``1 - alpha``. Results are truncated to uint8.

    Returns:
        A new (H, W, 3) array. ``original_rgb`` itself is never modified.
    """
    h, w = mask.shape
    rgb = np.asarray(original_rgb)

    # Normalise the channel layout so the 3-element colour broadcasts against
    # the selected pixels. Greyscale or RGBA would otherwise fail here.
    if rgb.ndim == 2:
        rgb = np.stack([rgb, rgb, rgb], axis=-1)
    elif rgb.ndim == 3 and rgb.shape[-1] == 4:
        rgb = rgb[..., :3]

    # Resample to the mask's resolution; skip the PIL round-trip when the
    # sizes already agree.
    if rgb.shape[:2] != (h, w):
        rgb = np.array(Image.fromarray(np.ascontiguousarray(rgb)).resize((w, h)))

    overlay = rgb.copy()  # never write into the caller's array
    region = mask > 0
    blended = overlay[region] * (1.0 - alpha) + np.asarray(color) * alpha
    overlay[region] = blended.astype(np.uint8)
    return overlay


# ---------------------------------------------------------------------------
# Inference function (called by Gradio)
# ---------------------------------------------------------------------------

def segment(pil_img):
    """Gradio callback: segment one image and build both display outputs.

    Returns ``(mask_image, overlay_image)`` as PIL images. The mask is the
    0/255 prediction replicated to three channels. The overlay is the input
    with the predicted region tinted. When no image is supplied, it returns
    ``(None, None)``.
    """
    if pil_img is None:
        return None, None

    rgb_img = pil_img.convert("RGB")
    batch = preprocess(rgb_img)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        raw_output = MODEL(batch)  # (1, 1, 256, 256)

    binary_mask = postprocess_mask(raw_output)  # (256, 256) uint8
    tinted = make_overlay(np.array(rgb_img), binary_mask)

    # Replicate the single-channel mask to RGB and hand Gradio PIL images
    # rather than raw numpy arrays.
    mask_as_rgb = np.repeat(binary_mask[..., None], 3, axis=-1)
    return Image.fromarray(mask_as_rgb), Image.fromarray(tinted)


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

# Markdown shown at the top of the page. The metric figures are hard-coded
# text, not computed at runtime.
# NOTE(review): confirm the scores and dataset size match the checkpoint
# actually served from MODEL_REPO.
DESCRIPTION = """
## 🔬 ISIC 2018 Skin Lesion Segmentation

Upload a dermoscopy image to get an instant binary segmentation mask from a trained **U-Net**.

| Metric | Test Set Score |
|--------|---------------|
| Dice   | **0.9301 ± 0.0621** |
| IoU    | **0.8744 ± 0.0891** |

*Trained on ISIC 2018 Task 1 (568 images, 70/15/15 split).*
"""

# Page layout. Gradio arranges components in the order they are created
# inside each layout context, so the statement order below defines the UI.
with gr.Blocks(title="ISIC Skin Lesion Segmentation") as demo:
    gr.Markdown(DESCRIPTION)

    # Left column: upload + trigger button. Right column: the two results.
    with gr.Row():
        with gr.Column():
            # type="pil" hands segment() a PIL image; segment() also guards
            # against None.
            inp = gr.Image(label="Input Image", type="pil")
            btn = gr.Button("Segment 🔍", variant="primary")
        with gr.Column():
            out_mask    = gr.Image(label="Predicted Mask", type="pil")
            out_overlay = gr.Image(label="Overlay on Original", type="pil")

    # segment() returns (mask, overlay); they fill the outputs in list order.
    # api_name="predict" names this endpoint in Gradio's API.
    btn.click(
        fn=segment,
        inputs=inp,
        outputs=[out_mask, out_overlay],
        api_name="predict"
    )



if __name__ == "__main__":
    # Start the Gradio server using the Soft theme.
    # NOTE(review): in some Gradio versions `theme` is a gr.Blocks(...)
    # argument rather than a launch() argument; confirm against the pinned
    # Gradio version in requirements.
    soft_theme = gr.themes.Soft()
    demo.launch(theme=soft_theme)