riha55 committed on
Commit
6f774ac
·
verified ·
1 Parent(s): 49e697e

Upload 7 files

Browse files
Files changed (7) hide show
  1. app.py +80 -0
  2. dataloader.py +66 -0
  3. losses.py +93 -0
  4. main.py +189 -0
  5. model.pth +3 -0
  6. requirements.txt +5 -0
  7. train5.py +387 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
import torch
from PIL import Image
import torchvision.transforms as transforms
from train5 import deeplabv3_encoder_decoder
import numpy as np


def load_model(model_path):
    """Instantiate the segmentation model and load trained weights on CPU.

    Returns the model in eval mode, or None (after showing a Streamlit
    error) if the checkpoint cannot be read.
    """
    model = deeplabv3_encoder_decoder()
    try:
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        model.eval()
        return model
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None


# Path to the checkpoint written by main.py.
MODEL_PATH = '/teamspace/studios/this_studio/Segmentation/model.pth'

# Class index -> display colour (RGB); hex values match
# dataloader.AerialImageDataset.Hex_Classes.
COLOR_MAP = {
    0: np.array([255, 34, 133]),  # Unlabeled    (#FF2285)
    1: np.array([0, 252, 199]),   # Early Blight (#00FCC7)
    2: np.array([86, 0, 254]),    # Late Blight  (#5600FE)
    3: np.array([0, 0, 0])        # Leaf Minor   (#000000)
}

CLASS_LABELS = {
    0: 'Unlabeled',
    1: 'Early Blight',
    2: 'Late Blight',
    3: 'Leaf Minor'
}

# Load the trained model once at startup.
model = load_model(MODEL_PATH)

if model:
    st.title('Aerial Image Segmentation')

    # Accept the common photo formats, not only .jpg.
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        # Force 3-channel RGB: PNG uploads can be RGBA or grayscale, which
        # would not match the model's 3-channel input stem.
        image = Image.open(uploaded_file).convert('RGB')

        # Display the original image.
        st.image(image, caption='Uploaded Image.', use_column_width=True)

        # Preprocess the image: resize to the training resolution, to tensor.
        data_transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor()
        ])
        tensor = data_transform(image).unsqueeze(0)  # add a batch dimension

        # Pass the image through the model.
        with torch.no_grad():
            output = model(tensor)

        # Colour legend in the sidebar.
        for k, v in CLASS_LABELS.items():
            st.sidebar.markdown(f'<div style="color:rgb{tuple(COLOR_MAP[k])};">{v}</div>', unsafe_allow_html=True)

        # Per-pixel argmax over class logits -> H x W label map.
        pred = torch.argmax(output.squeeze(), dim=0).cpu().numpy()

        # Paint each class with its display colour.
        output_rgb = np.zeros((pred.shape[0], pred.shape[1], 3), dtype=np.uint8)
        for k, v in COLOR_MAP.items():
            output_rgb[pred == k] = v

        st.image(output_rgb, caption='Segmented Image.', use_column_width=True)
dataloader.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from torch.utils.data import Dataset
3
+ from PIL import Image
4
+ import torchvision.transforms as transforms
5
+ import numpy as np
6
+
7
class AerialImageDataset(Dataset):
    """Dataset of aerial images paired with colour-coded segmentation masks.

    Each image ``<name>.jpg`` in ``image_dir`` is matched with the mask
    ``<name>.png`` in ``mask_dir``. Masks are colour-coded with the hex
    palette in ``Hex_Classes`` and converted to a 4-channel per-class map by
    ``encode_segmap`` before the (optional) transform is applied.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # Applied to BOTH image and mask (e.g. Resize + ToTensor). NOTE(review):
        # ToTensor also rescales the mask's integer class values by 1/255 —
        # confirm downstream losses expect that.
        self.transform = transform
        self.images = os.listdir(self.image_dir)
        # (label, hex colour) per class, in class-index order.
        self.Hex_Classes = [
            ('Unlabeled', '#FF2285'),
            ('Early Blight','#00FCC7'),
            ('Late Blight', '#5600FE'),
            ('Leaf Minor', '#000000')
        ]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        # Mask shares the image's basename but uses the .png extension.
        mask_path = os.path.join(self.mask_dir, self.images[idx].replace('.jpg', '.png'))

        image = Image.open(img_path)
        # NOTE(review): PIL's default resample can blend neighbouring label
        # colours at class boundaries; nearest-neighbour would be safer here.
        mask = Image.open(mask_path).resize((512, 512))
        # print(mask.size)

        # print(mask.size)

        mask = np.array(mask)
        mask = self.encode_segmap(mask)
        mask = mask.astype(np.uint8)  # Convert data type to uint8
        # print(mask.shape)
        mask = Image.fromarray(mask)  # Convert mask -> PIL (4-channel image)

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask

    def encode_segmap(self, mask):
        """Convert a colour-coded mask into a (512, 512, 4) per-class map.

        Channel ``i`` holds value ``i`` at pixels of class ``i`` and 0
        elsewhere; channel 0 holds value 4 instead (presumably so class 0
        pixels remain nonzero — TODO confirm against the loss).
        """
        mask = mask.astype(int)
        label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)  # height, width -> 0
        for i, (name, color) in enumerate(self.Hex_Classes):
            if mask.ndim == 3:
                # Match the full RGB triple of this class's hex colour.
                label_mask[(mask[:,:,0] == int(color[1:3], 16)) & (mask[:,:,1] == int(color[3:5], 16)) & (mask[:,:,2] == int(color[5:7], 16))] = i
            elif mask.ndim == 2:
                # NOTE(review): greyscale masks are matched on the red byte of
                # the hex colour only — confirm this is the intended encoding.
                label_mask[(mask == int(color[1:3], 16))] = i
                # print("Warning ndim = 2")
                # return None

        # NOTE(review): the 512x512 shape is hard-coded and relies on the
        # resize performed in __getitem__; other sizes would break here.
        msk = np.zeros((512,512,4))
        for i in [0,1,2,3]:
            if i == 0:
                msk_ind = np.where(label_mask == i, 4, 0)
                msk[:,:,i] = msk_ind
            else:
                msk_ind = np.where(label_mask == i, i, 0)
                msk[:,:,i] = msk_ind
        # print("mask shape",type(msk))
        return msk
losses.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from .train3 import deeplabv3_encoder_decoder
2
+ # # from .train3 import pl
3
+ # # from .train3 import torch
4
+ # import torch.nn as nn
5
+ # import torch.nn.functional as F
6
+ # import torch
7
+ # class mIoULoss(nn.Module):
8
+ # def __init__(self, weight=None, size_average=True, n_classes=4):
9
+ # super().__init__()
10
+ # self.classes = n_classes
11
+
12
+ # def to_one_hot(self, tensor):
13
+ # n, h, w = tensor.size()
14
+ # one_hot = torch.zeros(n, self.classes, h, w).to(tensor.device)
15
+ # one_hot.scatter_(1, tensor.unsqueeze(1), 1)
16
+ # return one_hot
17
+
18
+ # def forward(self, inputs, target):
19
+ # N = inputs.size(0)
20
+ # inputs = F.softmax(inputs, dim=1)
21
+ # target_oneHot = self.to_one_hot(target)
22
+ # inter = inputs * target_oneHot
23
+ # inter = inter.view(N, self.classes, -1).sum(2)
24
+ # union = inputs + target_oneHot - inter
25
+ # union = union.view(N, self.classes, -1).sum(2)
26
+ # loss = inter / union
27
+ # return 1 - loss.mean()
28
+
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+ import torch
32
+
33
+
34
class DiceLoss(nn.Module):
    """Soft Dice loss: 1 - Dice coefficient over flattened predictions.

    4-D inputs (N x C x H x W logits) are passed through a sigmoid first;
    anything else is assumed to already be probabilities.
    """

    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        # Additive smoothing keeps the ratio defined when both sums are zero.
        self.smooth = smooth

    def forward(self, preds, labels):
        # Squash raw logits to probabilities only for the 4-D (batched map) case.
        if preds.dim() == 4:
            preds = torch.sigmoid(preds)

        # Collapse everything to 1-D so the overlap is computed globally.
        flat_preds = preds.contiguous().view(-1)
        flat_labels = labels.contiguous().view(-1)

        # Dice coefficient = 2|P∩L| / (|P| + |L|), smoothed.
        overlap = torch.sum(flat_preds * flat_labels)
        denom = flat_preds.sum() + flat_labels.sum() + self.smooth
        dice_coeff = (2. * overlap + self.smooth) / denom

        # Loss is the complement of the coefficient.
        return 1 - dice_coeff
55
+
56
+
57
class mIoULoss(nn.Module):
    """Soft mean-IoU loss: 1 - mean(intersection / union) over classes.

    NOTE(review): ``weight`` and ``size_average`` are accepted but unused.
    """

    def __init__(self, weight=None, size_average=True, n_classes=4):  # Set n_classes to 4
        super().__init__()
        self.classes = n_classes

    def to_one_hot(self, tensor):
        """Scatter integer class values of ``tensor`` into a one-hot layout.

        NOTE(review): the unpack expects an N x C x H x W target (not the
        usual N x H x W), and scatter_ uses the target's VALUES as class
        indices, scattering each of its C channels independently — confirm
        this matches how the dataloader produces targets.
        """
        tensor = tensor.long()  # Ensure tensor is a LongTensor (scatter_ needs integer indices)
        n, c, h, w = tensor.size()  # Adjust size extraction
        one_hot = torch.zeros(n, self.classes, h, w).to(tensor.device)
        one_hot.scatter_(1, tensor, 1)
        return one_hot

    def forward(self, inputs, target):
        # inputs => N x Classes x H x W
        # target_oneHot => N x Classes x H x W

        N = inputs.size()[0]

        # predicted probabilities for each pixel along channel
        inputs = F.softmax(inputs, dim=1)

        # Numerator Product
        target_oneHot = self.to_one_hot(target)
        inter = inputs * target_oneHot
        ## Sum over all pixels N x C x H x W => N x C
        inter = inter.view(N, self.classes, -1).sum(2)

        # Denominator (soft union: P + T - P*T)
        union = inputs + target_oneHot - (inputs * target_oneHot)
        ## Sum over all pixels N x C x H x W => N x C
        union = union.view(N, self.classes, -1).sum(2)

        # Per-class, per-sample soft IoU
        loss = inter / union

        ## Return average loss over classes and batch
        return 1 - loss.mean()
93
+
main.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # # # import os
2
+ # # # import pytorch_lightning as L
3
+ # # # from dataloader import AerialImageDataset
4
+ # # # from train5 import deeplabv3_encoder_decoder
5
+ # # # from torch.utils.data import DataLoader
6
+ # # # from torchvision.transforms import transforms
7
+ # # # import torch
8
+
9
+ # # # train_path = r"C:\Users\User\Downloads\Nishant\train"
10
+ # # # val_path = r"C:\Users\User\Downloads\Nishant\val"
11
+
12
+ # # # data_transform = transforms.Compose([
13
+ # # # transforms.Resize((512, 512)),
14
+ # # # transforms.ToTensor()
15
+ # # # ])
16
+
17
+ # # # train_dataset = AerialImageDataset(os.path.join(train_path, 'images'), os.path.join(train_path, 'masks'), transform=data_transform)
18
+ # # # val_dataset = AerialImageDataset(os.path.join(val_path, 'images'), os.path.join(val_path, 'masks'), transform=data_transform)
19
+
20
+ # # # train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
21
+ # # # val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
22
+
23
+ # # # model = deeplabv3_encoder_decoder()
24
+
25
+ # # # # Adjust the refresh rate of the progress bar
26
+ # # # trainer = L.Trainer(max_epochs=100, progress_bar_refresh_rate=20) # Adjust the refresh rate as needed
27
+ # # # trainer.fit(model, train_loader, val_loader)
28
+
29
+ # # # torch.save(model.state_dict(), r"C:\Users\User\Downloads\Nishant\main.py\model.pth")
30
+
31
+ # # import os
32
+ # # import pytorch_lightning as pl
33
+ # # from dataloader import AerialImageDataset
34
+ # # from train5 import deeplabv3_encoder_decoder
35
+ # # from torch.utils.data import DataLoader
36
+ # # from torchvision.transforms import transforms
37
+ # # import torch
38
+
39
+ # # train_path = r"C:\Users\User\Downloads\Nishant\train"
40
+ # # val_path = r"C:\Users\User\Downloads\Nishant\val"
41
+
42
+ # # data_transform = transforms.Compose([
43
+ # # transforms.Resize((512, 512)),
44
+ # # transforms.ToTensor()
45
+ # # ])
46
+
47
+ # # train_dataset = AerialImageDataset(os.path.join(train_path, 'images'), os.path.join(train_path, 'masks'), transform=data_transform)
48
+ # # val_dataset = AerialImageDataset(os.path.join(val_path, 'images'), os.path.join(val_path, 'masks'), transform=data_transform)
49
+
50
+ # # train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
51
+ # # val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
52
+
53
+ # # model = deeplabv3_encoder_decoder()
54
+
55
+ # # # Adjust other trainer parameters as needed
56
+ # # trainer = pl.Trainer(max_epochs=100)
57
+ # # trainer.fit(model, train_loader, val_loader)
58
+
59
+ # # torch.save(model.state_dict(), r"C:\Users\User\Downloads\Nishant\main.py\model.pth")
60
+
61
+
62
+
63
+ # #running code
64
+ # # import os
65
+ # # import pytorch_lightning as pl
66
+ # # from dataloader import AerialImageDataset
67
+ # # from train5 import deeplabv3_encoder_decoder
68
+ # # from torch.utils.data import DataLoader
69
+ # # from torchvision.transforms import transforms
70
+ # # import torch
71
+
72
+ # # train_path = r"C:\Users\User\Downloads\Nishant\train"
73
+ # # val_path = r"C:\Users\User\Downloads\Nishant\val"
74
+
75
+ # # data_transform = transforms.Compose([
76
+ # # transforms.Resize((512, 512)),
77
+ # # transforms.ToTensor()
78
+ # # ])
79
+
80
+ # # train_dataset = AerialImageDataset(os.path.join(train_path, 'images'), os.path.join(train_path, 'masks'), transform=data_transform)
81
+ # # val_dataset = AerialImageDataset(os.path.join(val_path, 'images'), os.path.join(val_path, 'masks'), transform=data_transform)
82
+
83
+ # # train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
84
+ # # val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
85
+
86
+ # # model = deeplabv3_encoder_decoder()
87
+
88
+ # # # Adjust other trainer parameters as needed
89
+ # # trainer = pl.Trainer(num_sanity_val_steps=0, max_epochs=100)
90
+ # # trainer.fit(model, train_loader, val_loader)
91
+
92
+ # # torch.save(model.state_dict(), r"C:\Users\User\Downloads\Nishant\main.py\model.pth")
93
+
94
+ # import os
95
+ # import pytorch_lightning as pl
96
+ # from dataloader import AerialImageDataset
97
+ # from train5 import deeplabv3_encoder_decoder
98
+ # from torch.utils.data import DataLoader
99
+ # from torchvision.transforms import transforms
100
+ # import torch
101
+ # from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
102
+
103
+ # train_path = r"C:\Users\User\Downloads\Nishant\train"
104
+ # val_path = r"C:\Users\User\Downloads\Nishant\val"
105
+
106
+ # data_transform = transforms.Compose([
107
+ # transforms.Resize((512, 512)),
108
+ # transforms.ToTensor()
109
+ # ])
110
+
111
+ # train_dataset = AerialImageDataset(os.path.join(train_path, 'images'), os.path.join(train_path, 'masks'), transform=data_transform)
112
+ # val_dataset = AerialImageDataset(os.path.join(val_path, 'images'), os.path.join(val_path, 'masks'), transform=data_transform)
113
+
114
+ # train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
115
+ # val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
116
+
117
+ # model = deeplabv3_encoder_decoder()
118
+
119
+
120
+ # checkpoint_callback = ModelCheckpoint(
121
+ # monitor='val_loss',
122
+ # dirpath='checkpoints',
123
+ # filename='best_model',
124
+ # save_top_k=1,
125
+ # mode='min'
126
+ # )
127
+
128
+ # early_stop_callback = EarlyStopping(
129
+ # monitor='val_loss',
130
+ # patience=20,
131
+ # verbose=True,
132
+ # mode='min'
133
+ # )
134
+
135
+
136
+ # trainer = pl.Trainer(
137
+ # num_sanity_val_steps=0,
138
+ # max_epochs=100,
139
+ # callbacks=[checkpoint_callback, early_stop_callback] # Pass both callbacks
140
+ # )
141
+ # trainer.fit(model, train_loader, val_loader)
142
+ # torch.save(model.state_dict(), r"C:\Users\User\Downloads\Nishant\main.py\model.pth")
143
import os
import pytorch_lightning as pl
from dataloader import AerialImageDataset
from train5 import deeplabv3_encoder_decoder
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
import torch
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Dataset roots: each must contain an 'images' and a 'masks' subdirectory.
TRAIN_PATH = r"/teamspace/studios/this_studio/Segmentation/train"
VAL_PATH = r"/teamspace/studios/this_studio/Segmentation/val"
# Destination of the final state dict (loaded by app.py).
MODEL_OUT = r"/teamspace/studios/this_studio/Segmentation/model.pth"


def main():
    """Train the DeepLabV3-style segmentation model and save its state dict."""
    data_transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor()
    ])

    train_dataset = AerialImageDataset(os.path.join(TRAIN_PATH, 'images'), os.path.join(TRAIN_PATH, 'masks'), transform=data_transform)
    val_dataset = AerialImageDataset(os.path.join(VAL_PATH, 'images'), os.path.join(VAL_PATH, 'masks'), transform=data_transform)

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

    model = deeplabv3_encoder_decoder()

    # Keep only the single best checkpoint by validation loss.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath='checkpoints1',
        filename='best_model',
        save_top_k=1,
        mode='min'  # Save the model based on minimizing validation loss
    )

    # Stop once val_loss has not improved for 20 validation epochs.
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=20,
        verbose=True,
        mode='min'
    )

    trainer = pl.Trainer(
        num_sanity_val_steps=0,
        max_epochs=1000,
        callbacks=[checkpoint_callback, early_stop_callback]  # Pass both callbacks
    )
    trainer.fit(model, train_loader, val_loader)

    # NOTE: this saves the LAST epoch's weights; the best-by-val_loss weights
    # are in the checkpoint written by ModelCheckpoint above.
    torch.save(model.state_dict(), MODEL_OUT)


# Guard so importing this module (e.g. by tooling or spawned worker
# processes) does not silently kick off a full training run.
if __name__ == "__main__":
    main()
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd6ff7f738ef678fed4b0eb358462422d743004d7321b8378f12eae1f7fa93a9
3
+ size 155201050
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ Pillow
3
+ torchvision
4
+ numpy
5
+ pytorch-lightning==2.2.5
train5.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import torch
2
+ # import torch.nn as nn
3
+ # import torch.nn.functional as F
4
+ # import pytorch_lightning as pl
5
+ # from losses import mIoULoss
6
+ # from torchvision import models
7
+
8
+ # class ASSP(nn.Module):
9
+ # def __init__(self, in_channels, out_channels=256, final_out_channels=4):
10
+ # super(ASSP, self).__init__()
11
+
12
+ # self.relu = nn.ReLU(inplace=True)
13
+
14
+ # # 1x1 convolution
15
+ # self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=False)
16
+ # self.bn1 = nn.BatchNorm2d(out_channels)
17
+
18
+ # # 3x3 convolutions with different dilation rates
19
+ # self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=6, dilation=6, bias=False)
20
+ # self.bn2 = nn.BatchNorm2d(out_channels)
21
+
22
+ # self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=12, dilation=12, bias=False)
23
+ # self.bn3 = nn.BatchNorm2d(out_channels)
24
+
25
+ # self.conv4 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=18, dilation=18, bias=False)
26
+ # self.bn4 = nn.BatchNorm2d(out_channels)
27
+
28
+ # # 1x1 convolution after global average pooling
29
+ # self.conv5 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
30
+ # self.bn5 = nn.BatchNorm2d(out_channels)
31
+
32
+ # # Final 1x1 convolution to combine features
33
+ # self.convf = nn.Conv2d(out_channels * 5, final_out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
34
+ # self.bnf = nn.BatchNorm2d(final_out_channels)
35
+
36
+ # # Global average pooling
37
+ # self.adapool = nn.AdaptiveAvgPool2d(1)
38
+
39
+ # def forward(self, x):
40
+ # # 1x1 convolution
41
+ # x1 = self.conv1(x)
42
+ # x1 = self.bn1(x1)
43
+ # x1 = self.relu(x1)
44
+
45
+ # # 3x3 convolution with dilation 6
46
+ # x2 = self.conv2(x)
47
+ # x2 = self.bn2(x2)
48
+ # x2 = self.relu(x2)
49
+
50
+ # # 3x3 convolution with dilation 12
51
+ # x3 = self.conv3(x)
52
+ # x3 = self.bn3(x3)
53
+ # x3 = self.relu(x3)
54
+
55
+ # # 3x3 convolution with dilation 18
56
+ # x4 = self.conv4(x)
57
+ # x4 = self.bn4(x4)
58
+ # x4 = self.relu(x4)
59
+
60
+ # # Global average pooling, 1x1 convolution, and upsample
61
+ # x5 = self.adapool(x)
62
+ # x5 = self.conv5(x5)
63
+ # x5 = self.bn5(x5)
64
+ # x5 = self.relu(x5)
65
+ # x5 = F.interpolate(x5, size=x4.shape[-2:], mode='bilinear')
66
+
67
+ # # Concatenate all feature maps
68
+ # x = torch.cat((x1, x2, x3, x4, x5), dim=1)
69
+
70
+ # # Final 1x1 convolution
71
+ # x = self.convf(x)
72
+ # x = self.bnf(x)
73
+ # x = self.relu(x)
74
+
75
+ # return x
76
+
77
+ # class ResNet_50(nn.Module):
78
+ # def __init__(self, in_channels=3): # Change default to 3 channels for RGB images
79
+ # super(ResNet_50, self).__init__()
80
+
81
+ # # Load the pre-trained ResNet-50 model
82
+ # self.resnet_50 = models.resnet50(weights='DEFAULT')
83
+
84
+ # # Modify the first convolutional layer to accept 3-channel input
85
+ # self.resnet_50.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
86
+
87
+ # # Use the layers up to the final layer before the fully connected layer
88
+ # self.resnet_50 = nn.Sequential(*list(self.resnet_50.children())[:-2])
89
+ # self.relu = nn.ReLU(inplace=True)
90
+
91
+ # def forward(self, x):
92
+ # x = self.resnet_50(x)
93
+ # return x
94
+
95
+ # class deeplabv3_encoder_decoder(pl.LightningModule):
96
+ # def __init__(self, input_channels=3, output_channels=4): # Use 4 channels for output
97
+ # super(deeplabv3_encoder_decoder, self).__init__()
98
+ # self.resnet = ResNet_50(in_channels=input_channels)
99
+ # self.aspp = ASSP(in_channels=2048, final_out_channels=4)
100
+ # self.conv = nn.Conv2d(in_channels=4, out_channels=output_channels, kernel_size=1)
101
+ # self.criterion = mIoULoss(n_classes=4) # Set number of classes to 4
102
+
103
+ # def forward(self, x):
104
+ # _, _, h, w = x.shape
105
+ # x = self.resnet(x) # Output should be [batch_size, 2048, H/32, W/32]
106
+ # x = self.aspp(x)
107
+ # x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True) # Upsample
108
+ # x = self.conv(x) # Apply final convolution
109
+ # return x
110
+
111
+ # def training_step(self, batch, batch_idx):
112
+ # images, masks = batch
113
+ # logits = self(images)
114
+ # loss = self.criterion(logits, masks)
115
+ # iou = calculate_iou(logits, masks)
116
+ # self.log('train_loss', loss)
117
+ # self.log('train_iou', iou)
118
+ # print(f'Training Loss: {loss}, IoU: {iou}')
119
+ # return loss
120
+
121
+ # def validation_step(self, batch, batch_idx):
122
+ # images, masks = batch
123
+ # logits = self(images)
124
+ # loss = self.criterion(logits, masks)
125
+ # iou = calculate_iou(logits, masks)
126
+ # self.log('val_loss', loss)
127
+ # self.log('val_iou', iou)
128
+ # print(f'Validation Loss: {loss}, IoU: {iou}')
129
+ # return loss
130
+
131
+ # def on_training_epoch_end(self, outputs):
132
+ # avg_iou = torch.stack([x['train_iou'] for x in outputs]).mean()
133
+ # self.log('avg_train_iou', avg_iou)
134
+ # def on_validation_epoch_end(self, outputs):
135
+ # avg_iou = torch.stack([x['val_iou'] for x in outputs]).mean()
136
+ # self.log('avg_val_iou', avg_iou)
137
+ # def configure_optimizers(self):
138
+ # optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
139
+ # return optimizer
140
+
141
+ # def calculate_iou(logits, masks):
142
+ # # Calculate predictions from logits
143
+ # preds = torch.argmax(logits, dim=1)
144
+ # # Calculate intersection and union
145
+ # intersection = torch.sum(preds * masks)
146
+ # union = torch.sum((preds.bool() | masks.bool()).int())
147
+ # # Avoid division by zero
148
+ # iou = intersection / union if union != 0 else torch.tensor(0.0)
149
+ # return iou
150
+
151
+
152
+
153
+ import torch
154
+ import torch.nn as nn
155
+ import torch.nn.functional as F
156
+ import pytorch_lightning as pl
157
+ from losses import DiceLoss
158
+ from torchvision import models
159
+ import numpy as np
160
+ import matplotlib.pyplot as plt
161
+
162
class ASSP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Five parallel branches — a 1x1 conv, three 3x3 atrous convs at rates
    6/12/18, and a global-average-pooled 1x1 conv upsampled back — are
    concatenated and fused by a final 1x1 convolution down to
    ``final_out_channels`` maps.
    """

    def __init__(self, in_channels, out_channels=256, final_out_channels=4):
        super(ASSP, self).__init__()

        self.relu = nn.ReLU(inplace=True)

        # Branch 1: plain 1x1 convolution.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # Branches 2-4: 3x3 atrous convolutions; padding == dilation keeps
        # the spatial size unchanged.
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=6, dilation=6, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=12, dilation=12, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.conv4 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=18, dilation=18, bias=False)
        self.bn4 = nn.BatchNorm2d(out_channels)

        # Branch 5: image-level features via 1x1 conv after global pooling.
        self.conv5 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
        self.bn5 = nn.BatchNorm2d(out_channels)

        # Fusion: 1x1 conv over the five concatenated branches.
        self.convf = nn.Conv2d(out_channels * 5, final_out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
        self.bnf = nn.BatchNorm2d(final_out_channels)

        # Global average pooling for branch 5.
        self.adapool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        # The four convolutional branches share the conv -> bn -> relu shape.
        branches = [
            self.relu(bn(conv(x)))
            for conv, bn in (
                (self.conv1, self.bn1),
                (self.conv2, self.bn2),
                (self.conv3, self.bn3),
                (self.conv4, self.bn4),
            )
        ]

        # Image-level branch: pool to 1x1, project, then upsample back to the
        # spatial size of the convolutional branches.
        pooled = self.relu(self.bn5(self.conv5(self.adapool(x))))
        pooled = F.interpolate(pooled, size=branches[-1].shape[-2:], mode='bilinear')
        branches.append(pooled)

        # Concatenate all five feature maps and fuse with the final 1x1 conv.
        fused = torch.cat(branches, dim=1)
        return self.relu(self.bnf(self.convf(fused)))
230
+
231
class ResNet_50(nn.Module):
    """ResNet-50 backbone truncated before avgpool/fc.

    Produces a [N, 2048, H/32, W/32] feature map for the ASPP head.

    Args:
        in_channels: number of input channels for the (re-initialized) stem.
    """

    def __init__(self, in_channels=3):
        super(ResNet_50, self).__init__()

        # `pretrained=True` is the deprecated torchvision API; IMAGENET1K_V1
        # is the exact weight set that flag loaded, so behavior is unchanged.
        self.resnet_50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

        # Replace the stem conv so arbitrary channel counts are supported.
        # NOTE: this discards the pretrained stem weights even when
        # in_channels == 3 — the rest of the backbone stays pretrained.
        self.resnet_50.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Keep all layers up to (excluding) the avgpool and fc head.
        self.resnet_50 = nn.Sequential(*list(self.resnet_50.children())[:-2])
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.resnet_50(x)
        return x
248
+
249
+ # class deeplabv3_encoder_decoder(pl.LightningModule):
250
+ # def __init__(self, input_channels=3, output_channels=4): # Use 4 channels for output
251
+ # super(deeplabv3_encoder_decoder, self).__init__()
252
+ # self.resnet = ResNet_50(in_channels=input_channels)
253
+ # self.aspp = ASSP(in_channels=2048, final_out_channels=4)
254
+ # self.conv = nn.Conv2d(in_channels=4, out_channels=output_channels, kernel_size=1)
255
+ # self.criterion = mIoULoss(n_classes=4) # Set number of classes to 4
256
+
257
+ # def forward(self, x):
258
+ # _, _, h, w = x.shape
259
+ # x = self.resnet(x) # Output should be [batch_size, 2048, H/32, W/32]
260
+ # x = self.aspp(x)
261
+ # x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True) # Upsample
262
+ # x = self.conv(x) # Apply final convolution
263
+ # return x
264
+
265
+ # def training_step(self, batch, batch_idx):
266
+ # images, masks = batch
267
+ # logits = self(images)
268
+ # loss = self.criterion(logits, masks)
269
+ # iou = calculate_iou(logits, masks)
270
+ # self.log('train_loss', loss)
271
+ # self.log('train_iou', iou)
272
+ # print(f'Training Loss: {loss}, IoU: {iou}')
273
+ # return loss
274
+
275
+ # def validation_step(self, batch, batch_idx):
276
+ # images, masks = batch
277
+ # logits = self(images)
278
+ # loss = self.criterion(logits, masks)
279
+ # iou = calculate_iou(logits, masks)
280
+ # self.log('val_loss', loss)
281
+ # self.log('val_iou', iou)
282
+ # print(f'Validation Loss: {loss}, IoU: {iou}')
283
+ # return loss
284
+
285
+ # def on_training_epoch_end(self, outputs):
286
+ # avg_iou = torch.stack([x['train_iou'] for x in outputs]).mean()
287
+ # self.log('avg_train_iou', avg_iou)
288
+
289
+ # def on_validation_epoch_end(self, outputs):
290
+ # avg_iou = torch.stack([x['val_iou'] for x in outputs]).mean()
291
+ # self.log('avg_val_iou', avg_iou)
292
+
293
+ # def configure_optimizers(self):
294
+ # optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
295
+ # return optimizer
296
+
297
class deeplabv3_encoder_decoder(pl.LightningModule):
    """DeepLabV3-style segmentation model: ResNet-50 encoder + ASPP head.

    Trained with Dice loss; IoU is logged via ``compute_iou`` (thresholded
    sigmoid per channel, so channels are scored as independent binary maps).
    """

    def __init__(self, input_channels=3, output_channels=4):  # Use 4 channels for output
        super(deeplabv3_encoder_decoder, self).__init__()
        self.resnet = ResNet_50(in_channels=input_channels)
        self.aspp = ASSP(in_channels=2048, final_out_channels=4)
        # 1x1 projection from the 4 ASPP channels to the output classes.
        self.conv = nn.Conv2d(in_channels=4, out_channels=output_channels, kernel_size=1)
        self.criterion = DiceLoss()  # soft Dice over sigmoid probabilities

    def forward(self, x):
        # Returns per-class logits at the input's spatial resolution.
        _, _, h, w = x.shape
        x = self.resnet(x)  # Output should be [batch_size, 2048, H/32, W/32]
        x = self.aspp(x)
        x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)  # Upsample to input size
        x = self.conv(x)  # Apply final convolution
        return x

    def training_step(self, batch, batch_idx):
        images, masks = batch
        logits = self(images)
        loss = self.criterion(logits, masks)
        # print("\n\n\n\n\n\n\n\n",masks.shape, logits.shape,"\n\n\n\n\n\n\n\n\n\n")
        iou = compute_iou(logits, masks)
        self.log('train_loss', loss)
        self.log('train_iou', iou)
        # print(f'Training Loss: {loss}, IoU: {iou}')
        return loss

    def validation_step(self, batch, batch_idx):
        images, masks = batch
        logits = self(images)
        loss = self.criterion(logits, masks)
        iou = compute_iou(logits, masks)
        self.log('val_loss', loss)
        self.log('val_iou', iou)
        # print(f'Validation Loss: {loss}, IoU: {iou}')
        return loss

    def on_train_epoch_end(self):
        # NOTE(review): callback_metrics should already hold the reduced epoch
        # value, making .mean() a no-op on a scalar — confirm intent.
        avg_iou = self.trainer.callback_metrics['train_iou'].mean()
        train_loss = self.trainer.logged_metrics.get('train_loss')
        self.log('avg_train_iou', avg_iou)
        print("avg train iou",avg_iou)
        print("loss",train_loss)
        # iou = calculate_iou(logits, masks)
        # self.log('train_loss', loss)
        # self.log('train_iou', iou)
        # print(f'Training Loss: {loss}, IoU: {iou}')

    def on_validation_epoch_end(self):
        avg_iou = self.trainer.callback_metrics['val_iou'].mean()
        val_loss = self.trainer.logged_metrics.get('val_loss')

        self.log('avg_val_iou', avg_iou)
        print("avg val iou",avg_iou)
        print("val loss", val_loss)

    def configure_optimizers(self):
        # Plain Adam with a fixed learning rate; no scheduler.
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
356
+
357
+
358
+
359
+
360
+ # def calculate_iou(logits, masks):
361
+ # # Calculate predictions from logits
362
+ # preds = torch.argmax(logits, dim=1)
363
+ # # Calculate intersection and union
364
+ # intersection = torch.sum(preds * masks)
365
+ # union = torch.sum((preds.bool() | masks.bool()).int())
366
+ # # Avoid division by zero
367
+ # iou = intersection / union if union != 0 else torch.tensor(0.0)
368
+ # return iou
369
+
370
def compute_iou(preds, labels, threshold=0.5, epsilon=torch.finfo(torch.float).eps):
    """Mean IoU over channels of thresholded sigmoid predictions.

    ``preds`` are raw logits (N x C x H x W); each channel of ``labels`` is
    compared against the corresponding binarized prediction channel, IoU is
    averaged over the batch per channel, then over channels. ``epsilon``
    keeps the ratio defined when a channel is empty in both tensors.
    """
    # Binarize: sigmoid then strict threshold (matches > , not >=).
    binary = (torch.sigmoid(preds) > threshold).float()

    num_channels = binary.shape[1]
    channel_ious = []
    for c in range(num_channels):
        pred_c = binary[:, c, :, :]
        label_c = labels[:, c, :, :]
        # Per-sample overlap and union over the spatial dims.
        overlap = (pred_c * label_c).sum((1, 2))
        combined = (pred_c + label_c).sum((1, 2)) - overlap
        channel_ious.append(((overlap + epsilon) / (combined + epsilon)).mean())

    # Average the per-channel means.
    return sum(channel_ious) / num_channels