AJain1234 committed (verified)
Commit a0feb74 · 1 parent: 43c8d65

Upload folder using huggingface_hub
app.py CHANGED
@@ -6,8 +6,9 @@ from experiments.kmeans_segmenter import generate_kmeans_segmented_image
 from experiments.enhanced_kmeans_segmenter import slic_kmeans
 from experiments.watershed_segmenter import generate_watershed
 from experiments.felzenszwalb_segmentation import segment
-from experiments.SegNet.efficient_b0_backbone.architecture import SegNetEfficientNet, NUM_CLASSES, DEVICE, IMAGE_SIZE
+from experiments.SegNet.efficient_b0_backbone.architecture import SegNetEfficientNet, NUM_CLASSES, DEVICE
 from experiments.SegNet.vgg_backbone.model import SegNet
+# from experiments.ensemble_method import generate_ensemble_segmentation
 import numpy as np
 from PIL import Image
 from matplotlib import cm
@@ -81,14 +82,14 @@ def generate_felzenszwalb(image_path, sigma, k, min_size_factor):
 
 def SegNet_efficient_b0(image_path):
     model = SegNetEfficientNet(NUM_CLASSES).to(DEVICE)
-    model.load_state_dict(torch.load("segnet_efficientnet_voc.pth", map_location=DEVICE))
+    model.load_state_dict(torch.load("saved_models/segnet_efficientnet_camvid.pth", map_location=DEVICE))
     model.eval()
     transform = transforms.Compose([
-        transforms.Resize(IMAGE_SIZE),
-        transforms.ToTensor(),
-        transforms.Normalize([0.485, 0.456, 0.406],
-                             [0.229, 0.224, 0.225])
-    ])
+        transforms.Resize((360, 480)),  # Or larger if needed
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225])
+    ])
 
     image = Image.open(image_path).convert("RGB")
     input_tensor = transform(image).unsqueeze(0).to(DEVICE)
@@ -98,7 +99,7 @@ def SegNet_efficient_b0(image_path):
     pred_mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
 
     # Convert original image for Gradio display
-    original_image_resized = image.resize(IMAGE_SIZE)
+    original_image_resized = image
 
     # Convert predicted mask to a color image using a colormap
     colormap = cm.get_cmap('nipy_spectral')
@@ -108,6 +109,52 @@ def SegNet_efficient_b0(image_path):
 
     return original_image_resized, mask_pil
 
+def ensemble_segmentation(image_path):
+    """
+    Ensemble segmentation combining SegNet and Otsu,
+    assuming Otsu produces a mask with the foreground as black (value 0)
+    and the background as white (value 255).
+
+    In this ensemble, we force the SegNet prediction to background (class 0)
+    where Otsu indicates background (after inversion, i.e., where otsu_bin == 0).
+
+    Parameters:
+        image_path (str): Path to the input image.
+
+    Returns:
+        original_image: The original resized image used for segmentation.
+        segnet_mask_pil: SegNet multi-class segmentation output (PIL image).
+        otsu_mask_pil: The original Otsu binary segmentation mask (PIL image).
+        ensemble_mask_pil: Final ensemble segmentation mask (PIL image).
+    """
+    # Run SegNet segmentation (the model outputs a multi-class mask).
+    segnet_orig, segnet_mask_pil = SegNet_efficient_b0(image_path)
+    # Convert the SegNet output to a NumPy array (assumed grayscale labeling, e.g., background=0).
+    segnet_mask_np = np.array(segnet_mask_pil.convert("L"))
+
+    # Run Otsu segmentation. (generate_segmented_image returns several outputs.)
+    _, otsu_segmented_pil, _, _, _ = generate_segmented_image(image_path)
+
+    # Resize the Otsu mask to match the SegNet output shape, e.g., (480, 360) if SegNet works at that resolution.
+    resized_shape = (segnet_mask_np.shape[1], segnet_mask_np.shape[0])
+    otsu_mask_resized = otsu_segmented_pil.resize(resized_shape, Image.NEAREST)
+    otsu_mask_np = np.array(otsu_mask_resized)
+
+    # Invert Otsu's binary mask:
+    # assuming that in otsu_mask_np the foreground is black (0) and the background is white (255),
+    # build a binary mask where "1" represents the object's area.
+    otsu_bin = (otsu_mask_np == 0).astype(np.uint8)  # Now foreground is 1 and background is 0.
+
+    # Create the ensemble segmentation:
+    # where Otsu indicates foreground (otsu_bin == 1), keep SegNet's prediction;
+    # where Otsu is background (otsu_bin == 0), force it to the background class (0).
+    ensemble_seg = np.where(otsu_bin == 1, segnet_mask_np, 0)
+
+    # Convert back to a PIL image.
+    ensemble_mask_pil = Image.fromarray(ensemble_seg.astype(np.uint8))
+
+    return segnet_orig, segnet_mask_pil, otsu_segmented_pil, ensemble_mask_pil
+
 with gr.Blocks() as demo:
     gr.Markdown("# Image Segmentation using Classical CV")
 
@@ -120,10 +167,10 @@ with gr.Blocks() as demo:
                 threshold_text = gr.Textbox(label="Threshold Comparison", value="", interactive=False)
 
             with gr.Column(scale=2):
-                image_output = gr.Image(label="Original Image", container=False)
-                histogram_output = gr.Image(label="Histogram", container=False)
-                segmented_image_output = gr.Image(label="Our Segmented Image", container=False)
-                opencv_segmented_image_output = gr.Image(label="OpenCV Segmented Image", container=False)
+                image_output = gr.Image(label="Original Image")
+                histogram_output = gr.Image(label="Histogram")
+                segmented_image_output = gr.Image(label="Our Segmented Image")
+                opencv_segmented_image_output = gr.Image(label="OpenCV Segmented Image")
         display_btn.click(
             fn=generate_segmented_image,
            inputs=file_input,
@@ -138,8 +185,8 @@ with gr.Blocks() as demo:
                 kmeans_threshold_text = gr.Textbox(label="K-means Info", value="", interactive=False)
 
             with gr.Column(scale=2):
-                kmeans_image_output = gr.Image(label="Original Image", container=False)
-                kmeans_segmented_image_output = gr.Image(label="K-means Segmented Image", container=False)
+                kmeans_image_output = gr.Image(label="Original Image")
+                kmeans_segmented_image_output = gr.Image(label="K-means Segmented Image")
 
         kmeans_display_btn.click(
             fn=generate_kmeans,
@@ -156,8 +203,8 @@ with gr.Blocks() as demo:
                 slic_display_btn = gr.Button("Segment this image")
 
             with gr.Column(scale=2):
-                slic_image_output = gr.Image(label="Original Image", container=False)
-                slic_segmented_image_output = gr.Image(label="SLIC Segmented Image", container=False)
+                slic_image_output = gr.Image(label="Original Image", container=True)
+                slic_segmented_image_output = gr.Image(label="SLIC Segmented Image", container=True)
 
         slic_display_btn.click(
            fn=generate_slic,
@@ -172,8 +219,8 @@ with gr.Blocks() as demo:
                 watershed_display_btn = gr.Button("Segment this image")
 
             with gr.Column(scale=2):
-                watershed_image_output = gr.Image(label="Original Image", container=False)
-                watershed_segmented_image_output = gr.Image(label="watershed Segmented Image", container=False)
+                watershed_image_output = gr.Image(label="Original Image", container=True)
+                watershed_segmented_image_output = gr.Image(label="watershed Segmented Image", container=True)
 
         watershed_display_btn.click(
             fn=generate_watershed,
@@ -190,8 +237,8 @@ with gr.Blocks() as demo:
                 felzenszwalb_display_btn = gr.Button("Segment this image")
 
             with gr.Column(scale=2):
-                felzenszwalb_image_output = gr.Image(label="Original Image", container=False)
-                felzenszwalb_segmented_image_output = gr.Image(label="felzenszwalb Segmented Image", container=False)
+                felzenszwalb_image_output = gr.Image(label="Original Image", container=True)
+                felzenszwalb_segmented_image_output = gr.Image(label="felzenszwalb Segmented Image", container=True)
 
         felzenszwalb_display_btn.click(
            fn=generate_felzenszwalb,
@@ -205,8 +252,8 @@ with gr.Blocks() as demo:
                 segnet_display_btn = gr.Button("Segment this image")
 
             with gr.Column(scale=2):
-                segnet_image_output = gr.Image(label="Original Image", container=False)
-                segnet_segmented_image_output = gr.Image(label="SegNet Segmented Image", container=False)
+                segnet_image_output = gr.Image(label="Original Image")
+                segnet_segmented_image_output = gr.Image(label="SegNet Segmented Image")
 
         segnet_display_btn.click(
            fn=SegNet_efficient_b0,
@@ -220,14 +267,33 @@ with gr.Blocks() as demo:
                 segnet_display_btn = gr.Button("Segment this image")
 
             with gr.Column(scale=2):
-                segnet_image_output = gr.Image(label="Original Image", container=False)
-                segnet_segmented_image_output = gr.Image(label="SegNet VGG Segmented Image", container=False)
+                segnet_image_output = gr.Image(label="Original Image")
+                segnet_segmented_image_output = gr.Image(label="SegNet VGG Segmented Image")
 
        segnet_display_btn.click(
            fn=generate_segnet_vgg,
            inputs=[segnet_file_input],
            outputs=[segnet_image_output,segnet_segmented_image_output]
        )
+
+    with gr.TabItem("Ensemble Segmentation"):
+        with gr.Row():
+            with gr.Column(scale=1):
+                ensemble_file_input = gr.File(label="Upload Image File")
+                ensemble_display_btn = gr.Button("Segment with Ensemble Method")
+
+            with gr.Column(scale=2):
+                ensemble_image_output = gr.Image(label="Original Image")
+                ensemble_mask = gr.Image(label="Ensemble Segmented Image")
+                ensemble_segnet_segmented_output = gr.Image(label="SegNet Efficient B0 Segmented Image")
+                ensemble_otsu_segmented_output = gr.Image(label="Otsu Segmented Image")
+
+        ensemble_display_btn.click(
+            fn=ensemble_segmentation,
+            inputs=[ensemble_file_input],
+            outputs=[ensemble_image_output, ensemble_segnet_segmented_output, ensemble_otsu_segmented_output, ensemble_mask]
+        )
+
 if __name__ == "__main__":
     demo.launch()
 
experiments/SegNet/efficient_b0_backbone/architecture.py CHANGED
@@ -1,14 +1,20 @@
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import models, transforms
 from torchvision.datasets import VOCSegmentation
 from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+import glob
 from PIL import Image
 import numpy as np
 import wandb
+import pandas as pd
 import os
 import matplotlib.pyplot as plt
+import opendatasets as opd
+import zipfile
 
 torch.manual_seed(42)
 np.random.seed(42)
@@ -18,70 +24,155 @@ np.random.seed(42)
 EPOCHS = 25
 BATCH_SIZE = 8
 LR = 1e-3
-NUM_CLASSES = 21  # Pascal VOC has 21 classes including background
-IMAGE_SIZE = (256, 256)
+NUM_CLASSES = 32
 DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# wandb.init(project="segnet-efficientnet-voc", config={
+# wandb.init(project="segnet-efficientnet-camvid", config={
 #     "epochs": EPOCHS,
 #     "batch_size": BATCH_SIZE,
 #     "learning_rate": LR,
 #     "architecture": "SegNet-EfficientNet",
-#     "dataset": "PascalVOC2012"
+#     "dataset": "CamVid"
 # })
 
 class SegNetEfficientNet(nn.Module):
-    def __init__(self, num_classes):
+    def __init__(self, num_classes=32):
         super(SegNetEfficientNet, self).__init__()
         base_model = models.efficientnet_b0(pretrained=True)
         features = list(base_model.features.children())
 
-        # Encoder: Use EfficientNet blocks
-        self.encoder = nn.Sequential(*features)
+        # EfficientNet-B0 backbone (output channels gradually increase to 1280)
+        self.encoder = nn.Sequential(*features)  # Output: [B, 1280, H/32, W/32]
 
-        # Decoder: Up-convolutions
+        # Decoder blocks (mirroring encoder with ConvTranspose2d)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1280, 512, kernel_size=2, stride=2),
+            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
+
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
+            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
+
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
+            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
+
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
+            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, num_classes, kernel_size=1)
        )
 
+        self.classifier = nn.Conv2d(32, num_classes, kernel_size=1)
+
     def forward(self, x):
-        x = self.encoder(x)
-        x = self.decoder(x)
-        x = F.interpolate(x, size=IMAGE_SIZE, mode='bilinear', align_corners=False)
+        x = self.encoder(x)     # Downsampled features from EfficientNet
+        x = self.decoder(x)     # Upsampled
+        x = self.classifier(x)
+        x = F.interpolate(x, size=(360, 480), mode='bilinear', align_corners=False)
+
         return x
 
-class VOCSegmentationDataset(VOCSegmentation):
-    def __init__(self, root, image_set='train', transform=None, target_transform=None):
-        super().__init__(root=root, year='2012', image_set=image_set, download=True)
+class CamVidDataset(Dataset):
+    """
+    CamVid dataset loader with RGB mask to class index conversion.
+    Expects directory structure:
+        camvid/
+            train/
+            train_labels/
+            val/
+            val_labels/
+            test/
+            test_labels/
+    """
+    def __init__(self, root, split='train', transform=None, image_size=(360, 480), target_transform=None, class_dict_path='camvid/CamVid/class_dict.csv'):
+        self.root = root
+        self.split = split
         self.transform = transform
         self.target_transform = target_transform
 
-    def __getitem__(self, index):
-        img, target = super().__getitem__(index)
+        self.image_dir = os.path.join(root, split)
+        self.label_dir = os.path.join(root, f"{split}_labels")
+
+        self.image_paths = sorted(glob.glob(os.path.join(self.image_dir, '*.png')))
+        self.label_paths = sorted(glob.glob(os.path.join(self.label_dir, '*.png')))
+        self.label_resize = transforms.Resize(image_size, interpolation=Image.NEAREST)
+        self.image_resize = transforms.Resize(image_size, interpolation=Image.BILINEAR)
+        assert len(self.image_paths) == len(self.label_paths), "Mismatch between images and labels."
+
+        # Load class_dict.csv and build color-to-class mapping
+        df = pd.read_csv(class_dict_path)
+        self.color_to_class = {
+            (row['r'], row['g'], row['b']): idx for idx, row in df.iterrows()
+        }
+
+    def __len__(self):
+        return len(self.image_paths)
+
+    def rgb_to_class(self, mask):
+        """Convert an RGB mask (PIL.Image) to a 2D class index mask."""
+        mask_np = np.array(mask)
+        h, w, _ = mask_np.shape
+        class_mask = np.zeros((h, w), dtype=np.uint8)
+
+        for rgb, class_idx in self.color_to_class.items():
+            matches = (mask_np == rgb).all(axis=2)
+            class_mask[matches] = class_idx
+
+        return class_mask
+
+    def __getitem__(self, idx):
+        image = Image.open(self.image_paths[idx]).convert('RGB')
+        label = Image.open(self.label_paths[idx]).convert('RGB')
+
+        # Resize both to 360x480
+        image = self.image_resize(image)
+        label = self.label_resize(label)
+
         if self.transform:
-            img = self.transform(img)
-        if self.target_transform:
-            target = self.target_transform(target)
-        target = torch.as_tensor(np.array(target), dtype=torch.long)
-        return img, target
+            image = self.transform(image)
+
+        label = self.rgb_to_class(label)
+        label = torch.from_numpy(label).long()
+
+        return image, label
+
 if __name__ == "__main__":
-    image_transform = transforms.Compose([
-        transforms.Resize(IMAGE_SIZE),
-        transforms.ToTensor(),
-        transforms.Normalize([0.485, 0.456, 0.406],
-                             [0.229, 0.224, 0.225])
-    ])
-    mask_transform = transforms.Resize(IMAGE_SIZE, interpolation=Image.NEAREST)
-
-    train_dataset = VOCSegmentationDataset("voc_data", 'train', image_transform, mask_transform)
-    val_dataset = VOCSegmentationDataset("voc_data", 'val', image_transform, mask_transform)
-    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
-    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
+    dataset_url = "https://www.kaggle.com/datasets/carlolepelaars/camvid"
+    opd.download(dataset_url)
+
+    # Set dataset folder (adjust path if needed)
+    dataset_folder = "camvid"
+    print("Dataset directory contents:")
+    print(os.listdir(dataset_folder))
+    input_transform = transforms.Compose([
+        transforms.Resize((360, 480)),  # Or larger if needed
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225])
+    ])
+
+    def label_transform(label):
+        # Resize using nearest neighbor so that labels are not interpolated
+        label = label.resize((480, 360), Image.NEAREST)
+        label = np.array(label, dtype=np.int64)
+        return torch.from_numpy(label)
+
+    num_classes = 32
+    data_root = 'camvid/CamVid/'  # make sure this matches your structure
+
+    # Load datasets and dataloaders (assuming CamVidDataset is already defined)
+    train_dataset = CamVidDataset(root=data_root, split='train',
+                                  transform=input_transform, target_transform=label_transform)
+    val_dataset = CamVidDataset(root=data_root, split='val',
+                                transform=input_transform, target_transform=label_transform)
+    test_dataset = CamVidDataset(root=data_root, split='test',
+                                 transform=input_transform, target_transform=label_transform)
+
+    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
+    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)
+    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=4)
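
A quick shape sanity check for the new architecture (a sketch, assuming the module and its dependencies such as wandb, pandas, and opendatasets are installed; instantiating the model downloads the ImageNet EfficientNet-B0 weights on first run). The encoder downsamples by 32x, the five stride-2 transposed convolutions upsample back by 32x, and the final interpolate pins the output to CamVid's 360x480 regardless of rounding along the way:

import torch
from experiments.SegNet.efficient_b0_backbone.architecture import SegNetEfficientNet

model = SegNetEfficientNet(num_classes=32).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 360, 480))  # one CamVid-sized dummy image
print(out.shape)  # expected: torch.Size([1, 32, 360, 480])
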
experiments/SegNet/efficient_b0_backbone/train.py CHANGED
@@ -76,6 +76,6 @@ for epoch in tqdm(range(EPOCHS)):
 
     print(f"Epoch [{epoch+1}/{EPOCHS}] Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")
 
-torch.save(model.state_dict(), "segnet_efficientnet_voc.pth")
+torch.save(model.state_dict(), "segnet_efficientnet_camvid.pth")
 # wandb.finish()
 
experiments/ensemble_method.py ADDED
@@ -0,0 +1,148 @@
+import torch
+import numpy as np
+from PIL import Image
+import cv2
+from torchvision import transforms
+from experiments.otsu_segmenter import otsu_threshold
+from experiments.SegNet.efficient_b0_backbone.architecture import SegNetEfficientNet, NUM_CLASSES, DEVICE
+
+def ensemble_segmentation(image_path, model_path="segnet_efficientnet_voc.pth", boundary_weight=0.3):
+    """
+    Ensemble segmentation combining Otsu thresholding and SegNet
+
+    Args:
+        image_path: Path to input image
+        model_path: Path to SegNet model weights
+        boundary_weight: Weight for boundary refinement (0-1)
+
+    Returns:
+        original_image: Original input image (PIL)
+        ensemble_result: Ensemble segmentation result (PIL)
+        method_comparison: Visualization of all methods side by side (PIL)
+    """
+    # 1. Load the image
+    image = Image.open(image_path).convert('RGB')
+    original = image.copy()
+    image_np = np.array(image)
+
+    # 2. Run Otsu thresholding for boundary detection
+    # Convert to grayscale and apply Gaussian blur
+    gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
+    gray = cv2.cvtColor(gray, cv2.COLOR_BGR2GRAY)
+    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
+    otsu_threshold_value, otsu_mask = otsu_threshold(blurred)
+
+    # 3. Run SegNet for semantic segmentation
+    model = SegNetEfficientNet(NUM_CLASSES).to(DEVICE)
+    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
+    model.eval()
+
+    transform = transforms.Compose([
+        transforms.Resize((360, 480)),  # Or larger if needed
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225])
+    ])
+
+    input_tensor = transform(image).unsqueeze(0).to(DEVICE)
+
+    with torch.no_grad():
+        output = model(input_tensor)
+        segnet_pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
+
+    # 4. Create edge map from Otsu result
+    edges = cv2.Canny(otsu_mask, 50, 150)
+
+    # Resize to match SegNet output size
+    edges_resized = cv2.resize(edges, (segnet_pred.shape[1], segnet_pred.shape[0]),
+                               interpolation=cv2.INTER_NEAREST)
+
+    # 5. Ensemble: Use Otsu edges to refine SegNet boundaries
+    # Create a distance transform from the edges
+    dist_transform = cv2.distanceTransform(255 - edges_resized, cv2.DIST_L2, 5)
+    dist_transform = dist_transform / dist_transform.max()  # Normalize to 0-1
+
+    # Areas close to edges get more influence from Otsu
+    edge_weight_map = np.exp(-dist_transform * 5) * boundary_weight
+
+    # Create binary mask from SegNet (foreground = any class other than background)
+    segnet_binary = (segnet_pred > 0).astype(np.uint8) * 255
+
+    # Resize Otsu mask to match SegNet output
+    otsu_resized = cv2.resize(otsu_mask, (segnet_pred.shape[1], segnet_pred.shape[0]),
+                              interpolation=cv2.INTER_NEAREST)
+
+    # Combine: Use SegNet classes but refine boundaries with Otsu
+    # For boundary regions, adjust the segmentation based on Otsu
+    refined_binary = segnet_binary.copy()
+    boundary_region = edge_weight_map > 0.1
+    refined_binary[boundary_region] = (
+        (1 - edge_weight_map[boundary_region]) * segnet_binary[boundary_region] +
+        edge_weight_map[boundary_region] * otsu_resized[boundary_region]
+    ).astype(np.uint8)
+
+    # Apply the refined binary mask to the original SegNet prediction
+    ensemble_result = segnet_pred.copy()
+    # Where the refined binary is 0, set to background class (0)
+    ensemble_result[refined_binary < 128] = 0
+
+    # 6. Visualize results
+    from matplotlib import cm
+    import matplotlib.pyplot as plt
+    import io
+
+    # Convert semantic maps to color visualizations
+    colormap = cm.get_cmap('nipy_spectral')
+
+    segnet_colored = colormap(segnet_pred / (NUM_CLASSES - 1))
+    segnet_colored = (segnet_colored[:, :, :3] * 255).astype(np.uint8)
+
+    ensemble_colored = colormap(ensemble_result / (NUM_CLASSES - 1))
+    ensemble_colored = (ensemble_colored[:, :, :3] * 255).astype(np.uint8)
+
+    # Create side-by-side comparison
+    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
+
+    # Resize original image to match the segmentation size
+    original_resized = original.resize((segnet_pred.shape[1], segnet_pred.shape[0]))
+
+    axes[0].imshow(original_resized)
+    axes[0].set_title("Original Image")
+    axes[0].axis('off')
+
+    axes[1].imshow(otsu_mask, cmap='gray')
+    axes[1].set_title(f"Otsu (t={otsu_threshold_value})")
+    axes[1].axis('off')
+
+    axes[2].imshow(segnet_colored)
+    axes[2].set_title("SegNet Prediction")
+    axes[2].axis('off')
+
+    axes[3].imshow(ensemble_colored)
+    axes[3].set_title("Ensemble Result")
+    axes[3].axis('off')
+
+    plt.tight_layout()
+
+    # Convert the plot to an image
+    buf = io.BytesIO()
+    plt.savefig(buf, format='png')
+    buf.seek(0)
+    comparison_image = Image.open(buf)
+    plt.close(fig)
+
+    # Return results
+    ensemble_pil = Image.fromarray(ensemble_colored)
+    ensemble_pil = ensemble_pil.resize(original.size, Image.NEAREST)
+
+    return original, ensemble_pil, comparison_image
+
+# Add this function to your app.py
+def generate_ensemble_segmentation(image_path, boundary_weight=0.3):
+    """Wrapper for Gradio interface"""
+    original, ensemble_result, comparison = ensemble_segmentation(
+        image_path,
+        model_path="saved_models/segnet_efficientnet_camvid.pth",
+        boundary_weight=boundary_weight
+    )
+    return original, ensemble_result, comparison
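
The boundary refinement above hinges on the distance-transform weighting: a pixel sitting on a Canny edge of the Otsu mask gets the full boundary_weight of influence from Otsu, and that influence decays exponentially with distance from the nearest edge, so SegNet dominates away from boundaries. A small self-contained sketch of the weighting on a made-up 5x5 edge map:

import numpy as np
import cv2

boundary_weight = 0.3
edges = np.zeros((5, 5), dtype=np.uint8)
edges[2, 2] = 255  # a single made-up "edge" pixel in the middle

# Distance of every pixel to the nearest edge pixel (edges become zeros first)
dist = cv2.distanceTransform(255 - edges, cv2.DIST_L2, 5)
dist = dist / dist.max()                       # normalize to 0-1
edge_weight = np.exp(-dist * 5) * boundary_weight

print(np.round(edge_weight, 3))  # peaks at 0.3 on the edge pixel, decays outward
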
saved_models/segnet_efficientnet_camvid.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2f1e96df359eb0e1c153627880dc93e662b2ae5f998f9ed946ec71e726739481
+size 29641657
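
Since the checkpoints are tracked with Git LFS, only these pointer files live in the repository; running git lfs pull after cloning fetches the real weights. A quick load check for the new checkpoint (a sketch, assuming the file has been pulled):

import torch

state = torch.load("saved_models/segnet_efficientnet_camvid.pth", map_location="cpu")
print(len(state), "tensors;", sum(v.numel() for v in state.values()), "parameters")
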
saved_models/segnet_vgg.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3ac7681151184571d468e4c408c30107dd8b44170b602a06b97a24240f0fb83b
+size 49538462