ma4389 committed on
Commit
45d4744
·
verified ·
1 Parent(s): 5018a9a

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +174 -0
  2. my_checkpoint.pth +3 -0
  3. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.transforms as transforms
4
+ import gradio as gr
5
+ import numpy as np
6
+ from PIL import Image
7
+ import cv2
8
+
9
+ ############################################
10
+ # ========== UNET MODEL ====================
11
+ ############################################
12
+
13
class DoubleConv(nn.Module):
    """Two successive 3x3 convolutions, each followed by an in-place ReLU.

    Spatial size is preserved (kernel 3, padding 1); only the channel
    count changes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        # Keep the attribute name `conv_op` so checkpoint state-dict keys match.
        self.conv_op = nn.Sequential(*stages)

    def forward(self, x):
        """Run both conv+ReLU stages on ``x``."""
        return self.conv_op(x)
25
+
26
class DownSample(nn.Module):
    """Encoder stage: DoubleConv followed by 2x2 max pooling.

    ``forward`` returns both the pre-pool feature map (kept for the skip
    connection) and the pooled map (fed to the next, deeper stage).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names `conv`/`pool` preserved for state-dict compatibility.
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        skip = self.conv(x)
        return skip, self.pool(skip)
36
+
37
class UpSample(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concat, DoubleConv.

    Args:
        in_channels: channels of the deeper feature map ``x1``; the skip
            map ``x2`` is expected to carry ``in_channels // 2`` channels.
        out_channels: channels produced by the final DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Halve the channel count while doubling the spatial resolution.
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, 2, 2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Generalization: pad the upsampled map to match the skip connection.
        # When the input side length is not a multiple of 16, the encoder's
        # floor-division pooling leaves the two maps one pixel apart and the
        # original torch.cat would raise; for aligned inputs (dh == dw == 0)
        # this path is a no-op, so behavior is unchanged.
        dh = x2.size(2) - x1.size(2)
        dw = x2.size(3) - x1.size(3)
        if dh or dw:
            x1 = nn.functional.pad(
                x1, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2]
            )
        x = torch.cat([x1, x2], 1)
        return self.conv(x)
47
+
48
class UNet(nn.Module):
    """Four-level UNet for dense (per-pixel) segmentation.

    The encoder halves the spatial size four times while widening the
    channels (64 -> 128 -> 256 -> 512), a 1024-channel DoubleConv sits at
    the bottleneck, and the decoder mirrors the encoder while consuming
    the saved skip connections.  A final 1x1 conv projects to
    ``num_classes`` logit channels at the input resolution.
    """

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()

        # Encoder path; attribute names preserved for checkpoint loading.
        self.down1 = DownSample(in_channels, 64)
        self.down2 = DownSample(64, 128)
        self.down3 = DownSample(128, 256)
        self.down4 = DownSample(256, 512)

        self.bottleneck = DoubleConv(512, 1024)

        # Decoder path, consuming the encoder skips in reverse order.
        self.up1 = UpSample(1024, 512)
        self.up2 = UpSample(512, 256)
        self.up3 = UpSample(256, 128)
        self.up4 = UpSample(128, 64)

        # 1x1 projection to per-pixel class logits.
        self.out = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        # Collect skip maps while descending.
        skips = []
        for stage in (self.down1, self.down2, self.down3, self.down4):
            skip, x = stage(x)
            skips.append(skip)

        x = self.bottleneck(x)

        # Ascend, pairing each decoder stage with its skip (deepest first).
        decoder = (self.up1, self.up2, self.up3, self.up4)
        for stage, skip in zip(decoder, reversed(skips)):
            x = stage(x, skip)

        return self.out(x)
80
+
81
+ ############################################
82
+ # ========== LOAD MODEL ====================
83
+ ############################################
84
+
85
# Inference runs on CPU; map_location remaps checkpoint tensors accordingly.
device = torch.device("cpu")

# Instantiate the UNet defined above and restore its trained weights.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source (consider weights_only=True on torch >= 1.13).
model = UNet()
model.load_state_dict(torch.load("my_checkpoint.pth", map_location=device))
# Switch to eval mode so inference is deterministic.
model.eval()
90
+
91
+ ############################################
92
+ # ========== TRANSFORM =====================
93
+ ############################################
94
+
95
# Preprocessing for every uploaded image (and optional mask): resize to the
# 256x256 resolution the network expects, then convert to a float tensor
# scaled to [0, 1].
_preprocess_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
99
+
100
+ ############################################
101
+ # ========== DICE ==========================
102
+ ############################################
103
+
104
def dice_coefficient(pred, target, epsilon=1e-7):
    """Dice similarity between a predicted probability map and a target mask.

    ``pred`` is binarized at 0.5 before the overlap is computed; ``target``
    is used as-is.  ``epsilon`` guards the division when both masks are
    empty (giving 1.0).  Returns a Python float in [0, 1].
    """
    binary = (pred > 0.5).float()
    overlap = 2.0 * (binary * target).sum() + epsilon
    total = binary.sum() + target.sum() + epsilon
    return (overlap / total).item()
109
+
110
+ ############################################
111
+ # ========== INFERENCE FUNCTION ============
112
+ ############################################
113
+
114
def predict(image, mask=None):
    """Segment a tumor in an MRI image and overlay the prediction.

    Args:
        image: H x W x 3 uint8 RGB numpy array supplied by Gradio.
        mask: optional ground-truth mask array; when given, the Dice score
            between prediction and truth is reported instead of the
            generic success message.

    Returns:
        (overlay, info): the input image with predicted tumor pixels
        painted red, and a status / Dice string.
    """
    image_pil = Image.fromarray(image).convert("RGB")
    input_tensor = transform(image_pil).unsqueeze(0)

    with torch.no_grad():
        output = torch.sigmoid(model(input_tensor))

    pred_mask = output.squeeze().numpy()
    pred_mask_binary = (pred_mask > 0.5).astype(np.uint8)

    # Resize back to the original resolution.  Fix: use nearest-neighbor
    # interpolation -- the default bilinear filter averages the 0/1 labels
    # and erodes thin mask boundaries.
    pred_mask_resized = cv2.resize(
        pred_mask_binary,
        (image.shape[1], image.shape[0]),
        interpolation=cv2.INTER_NEAREST,
    )

    # Paint predicted tumor pixels red on a copy of the input.
    overlay = image.copy()
    overlay[pred_mask_resized == 1] = [255, 0, 0]

    if mask is not None:
        mask_pil = Image.fromarray(mask).convert("L")
        # Fix: binarize the ground truth.  ToTensor only rescales to [0, 1],
        # so anti-aliased / gray mask pixels would otherwise contribute
        # fractional values to the Dice denominator.
        mask_tensor = (transform(mask_pil) > 0.5).float()
        # from_numpy avoids the copy-and-warn behavior of torch.tensor(ndarray).
        dice = dice_coefficient(torch.from_numpy(pred_mask), mask_tensor)
        return overlay, f"Dice Score: {round(dice, 4)}"

    return overlay, "Mask predicted successfully!"
143
+
144
+ ############################################
145
+ # ========== GRADIO UI =====================
146
+ ############################################
147
+
148
# Markdown rendered above the Gradio interface.
description = """
# 🧠 Brain Tumor Segmentation (UNet)

This model was trained on:

🔗 https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation

Upload an MRI image to see tumor segmentation.
Optionally upload the true mask to compute Dice score.
"""

# Wire predict() into a two-input, two-output Gradio UI.  Both inputs arrive
# as numpy arrays; the second (ground-truth mask) is optional and may be None,
# matching predict()'s default.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="numpy", label="Upload MRI Image"),
        gr.Image(type="numpy", label="Optional Ground Truth Mask")
    ],
    outputs=[
        gr.Image(label="Predicted Overlay"),
        gr.Textbox(label="Info")
    ],
    title="UNet Brain Tumor Segmentation",
    description=description
)

# Start the local web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()
my_checkpoint.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33292122b7a931f081e57e9aec82a391c92f3d6154bec3f9e61d5a2c82a0c9dd
3
+ size 124146577
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ opencv-python
5
+ numpy
6
+ pillow
7
+ matplotlib