ma4389 committed on
Commit
6db4aca
·
verified ·
1 Parent(s): 45d4744

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -173
app.py CHANGED
@@ -1,174 +1,201 @@
1
- import torch
2
- import torch.nn as nn
3
- import torchvision.transforms as transforms
4
- import gradio as gr
5
- import numpy as np
6
- from PIL import Image
7
- import cv2
8
-
9
- ############################################
10
- # ========== UNET MODEL ====================
11
- ############################################
12
-
13
- class DoubleConv(nn.Module):
14
- def __init__(self, in_channels, out_channels):
15
- super().__init__()
16
- self.conv_op = nn.Sequential(
17
- nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
18
- nn.ReLU(inplace=True),
19
- nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
20
- nn.ReLU(inplace=True)
21
- )
22
-
23
- def forward(self, x):
24
- return self.conv_op(x)
25
-
26
- class DownSample(nn.Module):
27
- def __init__(self, in_channels, out_channels):
28
- super().__init__()
29
- self.conv = DoubleConv(in_channels, out_channels)
30
- self.pool = nn.MaxPool2d(2)
31
-
32
- def forward(self, x):
33
- down = self.conv(x)
34
- p = self.pool(down)
35
- return down, p
36
-
37
- class UpSample(nn.Module):
38
- def __init__(self, in_channels, out_channels):
39
- super().__init__()
40
- self.up = nn.ConvTranspose2d(in_channels, in_channels//2, 2, 2)
41
- self.conv = DoubleConv(in_channels, out_channels)
42
-
43
- def forward(self, x1, x2):
44
- x1 = self.up(x1)
45
- x = torch.cat([x1, x2], 1)
46
- return self.conv(x)
47
-
48
- class UNet(nn.Module):
49
- def __init__(self, in_channels=3, num_classes=1):
50
- super().__init__()
51
-
52
- self.down1 = DownSample(in_channels, 64)
53
- self.down2 = DownSample(64, 128)
54
- self.down3 = DownSample(128, 256)
55
- self.down4 = DownSample(256, 512)
56
-
57
- self.bottleneck = DoubleConv(512, 1024)
58
-
59
- self.up1 = UpSample(1024, 512)
60
- self.up2 = UpSample(512, 256)
61
- self.up3 = UpSample(256, 128)
62
- self.up4 = UpSample(128, 64)
63
-
64
- self.out = nn.Conv2d(64, num_classes, kernel_size=1)
65
-
66
- def forward(self, x):
67
- d1, p1 = self.down1(x)
68
- d2, p2 = self.down2(p1)
69
- d3, p3 = self.down3(p2)
70
- d4, p4 = self.down4(p3)
71
-
72
- b = self.bottleneck(p4)
73
-
74
- u1 = self.up1(b, d4)
75
- u2 = self.up2(u1, d3)
76
- u3 = self.up3(u2, d2)
77
- u4 = self.up4(u3, d1)
78
-
79
- return self.out(u4)
80
-
81
- ############################################
82
- # ========== LOAD MODEL ====================
83
- ############################################
84
-
85
- device = torch.device("cpu")
86
-
87
- model = UNet()
88
- model.load_state_dict(torch.load("my_checkpoint.pth", map_location=device))
89
- model.eval()
90
-
91
- ############################################
92
- # ========== TRANSFORM =====================
93
- ############################################
94
-
95
- transform = transforms.Compose([
96
- transforms.Resize((256, 256)),
97
- transforms.ToTensor()
98
- ])
99
-
100
- ############################################
101
- # ========== DICE ==========================
102
- ############################################
103
-
104
- def dice_coefficient(pred, target, epsilon=1e-7):
105
- pred = (pred > 0.5).float()
106
- intersection = (pred * target).sum()
107
- union = pred.sum() + target.sum()
108
- return ((2. * intersection + epsilon) / (union + epsilon)).item()
109
-
110
- ############################################
111
- # ========== INFERENCE FUNCTION ============
112
- ############################################
113
-
114
- def predict(image, mask=None):
115
-
116
- image_pil = Image.fromarray(image).convert("RGB")
117
- input_tensor = transform(image_pil).unsqueeze(0)
118
-
119
- with torch.no_grad():
120
- output = model(input_tensor)
121
- output = torch.sigmoid(output)
122
-
123
- pred_mask = output.squeeze().numpy()
124
- pred_mask_binary = (pred_mask > 0.5).astype(np.uint8)
125
-
126
- # Resize mask back to original size
127
- pred_mask_resized = cv2.resize(
128
- pred_mask_binary,
129
- (image.shape[1], image.shape[0])
130
- )
131
-
132
- # Create overlay
133
- overlay = image.copy()
134
- overlay[pred_mask_resized == 1] = [255, 0, 0]
135
-
136
- if mask is not None:
137
- mask_pil = Image.fromarray(mask).convert("L")
138
- mask_tensor = transform(mask_pil)
139
- dice = dice_coefficient(torch.tensor(pred_mask), mask_tensor)
140
- return overlay, f"Dice Score: {round(dice, 4)}"
141
-
142
- return overlay, "Mask predicted successfully!"
143
-
144
- ############################################
145
- # ========== GRADIO UI =====================
146
- ############################################
147
-
148
- description = """
149
- # 🧠 Brain Tumor Segmentation (UNet)
150
-
151
- This model was trained on:
152
-
153
- 🔗 https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation
154
-
155
- Upload an MRI image to see tumor segmentation.
156
- Optionally upload the true mask to compute Dice score.
157
- """
158
-
159
- demo = gr.Interface(
160
- fn=predict,
161
- inputs=[
162
- gr.Image(type="numpy", label="Upload MRI Image"),
163
- gr.Image(type="numpy", label="Optional Ground Truth Mask")
164
- ],
165
- outputs=[
166
- gr.Image(label="Predicted Overlay"),
167
- gr.Textbox(label="Info")
168
- ],
169
- title="UNet Brain Tumor Segmentation",
170
- description=description
171
- )
172
-
173
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  demo.launch()
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.transforms as transforms
4
+ import gradio as gr
5
+ import numpy as np
6
+ from PIL import Image
7
+ import cv2
8
+
9
+ ############################################
10
+ # ========== UNET MODEL ====================
11
+ ############################################
12
+
13
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by an in-place ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # NOTE: the attribute must stay named `conv` so the checkpoint's
        # state_dict keys (e.g. down1.conv.conv.0.weight) still match.
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv+ReLU stages to `x` (spatial size is preserved)."""
        return self.conv(x)
25
+
26
+
27
class DownSample(nn.Module):
    """One encoder stage: DoubleConv, then a 2x2 max-pool halving H and W."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names `conv`/`pool` are fixed by the checkpoint's keys.
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return (skip, pooled): the skip tensor feeds the decoder later."""
        skip = self.conv(x)
        return skip, self.pool(skip)
37
+
38
+
39
class UpSample(nn.Module):
    """One decoder stage: transposed-conv upsample, skip concat, DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 2x2, stride-2 transposed conv doubles H/W and halves the channels,
        # so the concatenation below restores `in_channels` total.
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, 2, 2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample `x1`, concatenate the skip tensor `x2`, then convolve."""
        upsampled = self.up(x1)
        merged = torch.cat([upsampled, x2], dim=1)
        return self.conv(merged)
49
+
50
+
51
class UNet(nn.Module):
    """Classic UNet: four encoder stages, a bottleneck, four decoder stages
    with skip connections, and a 1x1 output head.

    Input H and W must be divisible by 16 (four 2x2 poolings).
    """

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()

        # Encoder: channel width doubles at every stage.
        self.down1 = DownSample(in_channels, 64)
        self.down2 = DownSample(64, 128)
        self.down3 = DownSample(128, 256)
        self.down4 = DownSample(256, 512)

        self.bottleneck = DoubleConv(512, 1024)

        # Decoder: mirrors the encoder, consuming the matching skip tensors.
        self.up1 = UpSample(1024, 512)
        self.up2 = UpSample(512, 256)
        self.up3 = UpSample(256, 128)
        self.up4 = UpSample(128, 64)

        # 1x1 conv maps the final 64 feature channels to per-class logits.
        self.out = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Return raw (pre-sigmoid) segmentation logits shaped like the input."""
        skip1, x = self.down1(x)
        skip2, x = self.down2(x)
        skip3, x = self.down3(x)
        skip4, x = self.down4(x)

        x = self.bottleneck(x)

        x = self.up1(x, skip4)
        x = self.up2(x, skip3)
        x = self.up3(x, skip2)
        x = self.up4(x, skip1)

        return self.out(x)
83
+
84
+
85
############################################
# ========== LOAD MODEL ====================
############################################

# Inference runs on CPU; map_location remaps any GPU-saved tensors.
device = torch.device("cpu")

model = UNet()
# weights_only=True restricts torch.load to plain tensors/containers, which
# is all a state_dict holds, and avoids unpickling arbitrary objects (this
# is also the default behavior from torch 2.6 onward).
model.load_state_dict(
    torch.load("my_checkpoint.pth", map_location=device, weights_only=True)
)
model.eval()  # inference mode: freezes dropout / norm statistics

############################################
# ========== TRANSFORM =====================
############################################

# The model expects 256x256 inputs; ToTensor scales pixel values to [0, 1].
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])
103
+
104
############################################
# ========== DICE FUNCTION =================
############################################

def dice_coefficient(pred, target, epsilon=1e-7):
    """Dice similarity between a predicted probability map and a target mask.

    Both tensors are binarized at 0.5 before the overlap is computed: the
    prediction holds sigmoid probabilities, and the target can contain soft
    edge values after interpolated resizing, so thresholding both keeps the
    score a true set-overlap measure.

    Args:
        pred: Tensor of predicted probabilities in [0, 1].
        target: Tensor of ground-truth mask values in [0, 1].
        epsilon: Guard against division by zero; when both masks are empty
            the score is ~1.0 (perfect agreement on "no tumor").

    Returns:
        float: Dice score in [0, 1].
    """
    pred = (pred > 0.5).float()
    target = (target > 0.5).float()  # resizing can leave fractional edge values
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum()
    return ((2. * intersection + epsilon) / (union + epsilon)).item()
113
+
114
############################################
# ========== PREPROCESS TIFF ===============
############################################

def load_image(file):
    """Open an uploaded file and return (RGB PIL image, RGB uint8 array).

    Args:
        file: Either a filesystem path string or a file-like object exposing
            `.name` — gr.File produces either shape depending on the Gradio
            version, and the original `file.name` broke on plain path strings.

    Returns:
        tuple: (PIL.Image in RGB mode, np.uint8 array of shape (H, W, 3)).
    """
    # Path strings pass through unchanged; file objects contribute `.name`.
    path = getattr(file, "name", file)
    img = Image.open(path)
    img_np = np.array(img)

    # Handle 16-bit TIFF (common for MRI exports): drop the low byte so the
    # rest of the pipeline sees a uint8 image.
    if img_np.dtype == np.uint16:
        img_np = (img_np / 256).astype(np.uint8)

    img_pil = Image.fromarray(img_np).convert("RGB")
    # Return the RGB array rather than the raw one: grayscale sources would
    # otherwise yield a 2-D array that breaks the red overlay downstream.
    return img_pil, np.array(img_pil)
128
+
129
+
130
############################################
# ========== PREDICTION ====================
############################################

def predict(image_file, mask_file=None):
    """Segment a brain-MRI image and return (overlay, status message).

    Args:
        image_file: Uploaded MRI image (path string or file object); required.
        mask_file: Optional ground-truth mask; when given, the Dice score
            between prediction and mask is reported instead of a plain
            status string.

    Returns:
        tuple: (np.uint8 RGB overlay with predicted tumor pixels painted red
        — or None when no image was supplied — and an info string).
    """
    if image_file is None:
        return None, "Please upload an image."

    image_pil, original_np = load_image(image_file)

    input_tensor = transform(image_pil).unsqueeze(0)

    with torch.no_grad():
        output = model(input_tensor)
        output = torch.sigmoid(output)

    pred_mask = output.squeeze().numpy()
    pred_mask_binary = (pred_mask > 0.5).astype(np.uint8)

    # Resize mask to the original image size. Nearest-neighbor keeps the
    # mask strictly {0, 1}; the default bilinear interpolation would smear
    # fractional values along the boundary and the `== 1` test below would
    # erode the predicted region.
    pred_mask_resized = cv2.resize(
        pred_mask_binary,
        (original_np.shape[1], original_np.shape[0]),
        interpolation=cv2.INTER_NEAREST,
    )

    # Guard against grayscale sources: a 2-D array cannot accept an RGB
    # triple in the fancy-indexed assignment below.
    if original_np.ndim == 2:
        original_np = np.stack([original_np] * 3, axis=-1)

    # Create red overlay on a copy so the original array is left untouched.
    overlay = original_np.copy()
    overlay[pred_mask_resized == 1] = [255, 0, 0]

    # If a ground-truth mask is provided, compute Dice at model resolution.
    if mask_file is not None:
        mask_pil, _ = load_image(mask_file)
        mask_tensor = transform(mask_pil.convert("L"))
        dice = dice_coefficient(torch.tensor(pred_mask), mask_tensor)
        return overlay, f"Dice Score: {round(dice, 4)}"

    return overlay, "Prediction complete."
168
+
169
+
170
############################################
# ========== GRADIO UI =====================
############################################

description = """
# 🧠 Brain Tumor Segmentation using UNet

This model was trained on:

🔗 https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation

Upload a `.tif` MRI image to predict tumor segmentation.
Optionally upload the true mask to compute Dice score.
"""

# gr.File (rather than gr.Image) hands the raw upload to `predict`, whose
# loader handles 16-bit TIFFs itself.
image_input = gr.File(
    file_types=[".tif", ".tiff", ".png", ".jpg"],
    label="Upload MRI Image",
)
mask_input = gr.File(
    file_types=[".tif", ".tiff"],
    label="Optional Ground Truth Mask",
)

demo = gr.Interface(
    fn=predict,
    inputs=[image_input, mask_input],
    outputs=[
        gr.Image(label="Predicted Overlay"),
        gr.Textbox(label="Result"),
    ],
    title="UNet Brain Tumor Segmentation",
    description=description,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()