ma4389 commited on
Commit
6baa52b
·
verified ·
1 Parent(s): 6db4aca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -47
app.py CHANGED
@@ -7,28 +7,28 @@ from PIL import Image
7
  import cv2
8
 
9
  ############################################
10
- # ========== UNET MODEL ====================
11
  ############################################
12
 
13
  class DoubleConv(nn.Module):
14
  def __init__(self, in_channels, out_channels):
15
  super().__init__()
16
- self.conv = nn.Sequential(
17
- nn.Conv2d(in_channels, out_channels, 3, padding=1),
18
- nn.ReLU(inplace=True),
19
- nn.Conv2d(out_channels, out_channels, 3, padding=1),
20
  nn.ReLU(inplace=True),
 
 
21
  )
22
 
23
  def forward(self, x):
24
- return self.conv(x)
25
 
26
 
27
  class DownSample(nn.Module):
28
  def __init__(self, in_channels, out_channels):
29
  super().__init__()
30
  self.conv = DoubleConv(in_channels, out_channels)
31
- self.pool = nn.MaxPool2d(2)
32
 
33
  def forward(self, x):
34
  down = self.conv(x)
@@ -39,12 +39,12 @@ class DownSample(nn.Module):
39
  class UpSample(nn.Module):
40
  def __init__(self, in_channels, out_channels):
41
  super().__init__()
42
- self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, 2, 2)
43
  self.conv = DoubleConv(in_channels, out_channels)
44
 
45
  def forward(self, x1, x2):
46
  x1 = self.up(x1)
47
- x = torch.cat([x1, x2], dim=1)
48
  return self.conv(x)
49
 
50
 
@@ -52,34 +52,34 @@ class UNet(nn.Module):
52
  def __init__(self, in_channels=3, num_classes=1):
53
  super().__init__()
54
 
55
- self.down1 = DownSample(in_channels, 64)
56
- self.down2 = DownSample(64, 128)
57
- self.down3 = DownSample(128, 256)
58
- self.down4 = DownSample(256, 512)
59
 
60
- self.bottleneck = DoubleConv(512, 1024)
61
 
62
- self.up1 = UpSample(1024, 512)
63
- self.up2 = UpSample(512, 256)
64
- self.up3 = UpSample(256, 128)
65
- self.up4 = UpSample(128, 64)
66
 
67
- self.out = nn.Conv2d(64, num_classes, kernel_size=1)
68
 
69
  def forward(self, x):
70
- d1, p1 = self.down1(x)
71
- d2, p2 = self.down2(p1)
72
- d3, p3 = self.down3(p2)
73
- d4, p4 = self.down4(p3)
74
 
75
- b = self.bottleneck(p4)
76
 
77
- u1 = self.up1(b, d4)
78
- u2 = self.up2(u1, d3)
79
- u3 = self.up3(u2, d2)
80
- u4 = self.up4(u3, d1)
81
 
82
- return self.out(u4)
83
 
84
 
85
  ############################################
@@ -88,7 +88,7 @@ class UNet(nn.Module):
88
 
89
  device = torch.device("cpu")
90
 
91
- model = UNet()
92
  model.load_state_dict(torch.load("my_checkpoint.pth", map_location=device))
93
  model.eval()
94
 
@@ -112,7 +112,7 @@ def dice_coefficient(pred, target, epsilon=1e-7):
112
  return ((2. * intersection + epsilon) / (union + epsilon)).item()
113
 
114
  ############################################
115
- # ========== PREPROCESS TIFF ===============
116
  ############################################
117
 
118
  def load_image(file):
@@ -126,7 +126,6 @@ def load_image(file):
126
  img_pil = Image.fromarray(img_np).convert("RGB")
127
  return img_pil, img_np
128
 
129
-
130
  ############################################
131
  # ========== PREDICTION ====================
132
  ############################################
@@ -137,7 +136,6 @@ def predict(image_file, mask_file=None):
137
  return None, "Please upload an image."
138
 
139
  image_pil, original_np = load_image(image_file)
140
-
141
  input_tensor = transform(image_pil).unsqueeze(0)
142
 
143
  with torch.no_grad():
@@ -145,19 +143,17 @@ def predict(image_file, mask_file=None):
145
  output = torch.sigmoid(output)
146
 
147
  pred_mask = output.squeeze().numpy()
148
- pred_mask_binary = (pred_mask > 0.5).astype(np.uint8)
149
 
150
- # Resize mask to original image size
151
- pred_mask_resized = cv2.resize(
152
- pred_mask_binary,
153
  (original_np.shape[1], original_np.shape[0])
154
  )
155
 
156
- # Create red overlay
157
  overlay = original_np.copy()
158
- overlay[pred_mask_resized == 1] = [255, 0, 0]
159
 
160
- # If mask provided → compute Dice
161
  if mask_file is not None:
162
  mask_pil, _ = load_image(mask_file)
163
  mask_tensor = transform(mask_pil.convert("L"))
@@ -166,20 +162,18 @@ def predict(image_file, mask_file=None):
166
 
167
  return overlay, "Prediction complete."
168
 
169
-
170
  ############################################
171
  # ========== GRADIO UI =====================
172
  ############################################
173
 
174
  description = """
175
- # 🧠 Brain Tumor Segmentation using UNet
176
-
177
- This model was trained on:
178
 
179
- 🔗 https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation
 
180
 
181
- Upload a `.tif` MRI image to predict tumor segmentation.
182
- Optionally upload the true mask to compute Dice score.
183
  """
184
 
185
  demo = gr.Interface(
 
7
  import cv2
8
 
9
  ############################################
10
+ # ========== ORIGINAL TRAINING UNET =========
11
  ############################################
12
 
13
  class DoubleConv(nn.Module):
14
  def __init__(self, in_channels, out_channels):
15
  super().__init__()
16
+ self.conv_op = nn.Sequential(
17
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
 
 
18
  nn.ReLU(inplace=True),
19
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
20
+ nn.ReLU(inplace=True)
21
  )
22
 
23
  def forward(self, x):
24
+ return self.conv_op(x)
25
 
26
 
27
  class DownSample(nn.Module):
28
  def __init__(self, in_channels, out_channels):
29
  super().__init__()
30
  self.conv = DoubleConv(in_channels, out_channels)
31
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
32
 
33
  def forward(self, x):
34
  down = self.conv(x)
 
39
  class UpSample(nn.Module):
40
  def __init__(self, in_channels, out_channels):
41
  super().__init__()
42
+ self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)
43
  self.conv = DoubleConv(in_channels, out_channels)
44
 
45
  def forward(self, x1, x2):
46
  x1 = self.up(x1)
47
+ x = torch.cat([x1, x2], 1)
48
  return self.conv(x)
49
 
50
 
 
52
  def __init__(self, in_channels=3, num_classes=1):
53
  super().__init__()
54
 
55
+ self.down_convolution_1 = DownSample(in_channels, 64)
56
+ self.down_convolution_2 = DownSample(64, 128)
57
+ self.down_convolution_3 = DownSample(128, 256)
58
+ self.down_convolution_4 = DownSample(256, 512)
59
 
60
+ self.bottle_neck = DoubleConv(512, 1024)
61
 
62
+ self.up_convolution_1 = UpSample(1024, 512)
63
+ self.up_convolution_2 = UpSample(512, 256)
64
+ self.up_convolution_3 = UpSample(256, 128)
65
+ self.up_convolution_4 = UpSample(128, 64)
66
 
67
+ self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)
68
 
69
  def forward(self, x):
70
+ down_1, p1 = self.down_convolution_1(x)
71
+ down_2, p2 = self.down_convolution_2(p1)
72
+ down_3, p3 = self.down_convolution_3(p2)
73
+ down_4, p4 = self.down_convolution_4(p3)
74
 
75
+ b = self.bottle_neck(p4)
76
 
77
+ up_1 = self.up_convolution_1(b, down_4)
78
+ up_2 = self.up_convolution_2(up_1, down_3)
79
+ up_3 = self.up_convolution_3(up_2, down_2)
80
+ up_4 = self.up_convolution_4(up_3, down_1)
81
 
82
+ return self.out(up_4)
83
 
84
 
85
  ############################################
 
88
 
89
  device = torch.device("cpu")
90
 
91
+ model = UNet(in_channels=3, num_classes=1)
92
  model.load_state_dict(torch.load("my_checkpoint.pth", map_location=device))
93
  model.eval()
94
 
 
112
  return ((2. * intersection + epsilon) / (union + epsilon)).item()
113
 
114
  ############################################
115
+ # ========== TIFF SAFE LOADER ==============
116
  ############################################
117
 
118
  def load_image(file):
 
126
  img_pil = Image.fromarray(img_np).convert("RGB")
127
  return img_pil, img_np
128
 
 
129
  ############################################
130
  # ========== PREDICTION ====================
131
  ############################################
 
136
  return None, "Please upload an image."
137
 
138
  image_pil, original_np = load_image(image_file)
 
139
  input_tensor = transform(image_pil).unsqueeze(0)
140
 
141
  with torch.no_grad():
 
143
  output = torch.sigmoid(output)
144
 
145
  pred_mask = output.squeeze().numpy()
146
+ pred_binary = (pred_mask > 0.5).astype(np.uint8)
147
 
148
+ # Resize mask back to original size
149
+ pred_resized = cv2.resize(
150
+ pred_binary,
151
  (original_np.shape[1], original_np.shape[0])
152
  )
153
 
 
154
  overlay = original_np.copy()
155
+ overlay[pred_resized == 1] = [255, 0, 0]
156
 
 
157
  if mask_file is not None:
158
  mask_pil, _ = load_image(mask_file)
159
  mask_tensor = transform(mask_pil.convert("L"))
 
162
 
163
  return overlay, "Prediction complete."
164
 
 
165
  ############################################
166
  # ========== GRADIO UI =====================
167
  ############################################
168
 
169
  description = """
170
+ # 🧠 Brain Tumor Segmentation (UNet)
 
 
171
 
172
+ Dataset used for training:
173
+ https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation
174
 
175
+ Upload a `.tif` MRI image.
176
+ Optionally upload the ground-truth mask to compute Dice score.
177
  """
178
 
179
  demo = gr.Interface(