NishantD commited on
Commit
56f7a23
·
verified ·
1 Parent(s): a9f3003

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +79 -0
  2. dataloader.py +51 -0
  3. dataprep.py +71 -0
  4. main.py +28 -0
  5. model.pth +3 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from PIL import Image
4
+ import torchvision.transforms as transforms
5
+ from train import UNet
6
+ import numpy as np
7
+
8
+ # Load the trained model
9
+ model_path = '/teamspace/studios/this_studio/Aerial-Segmentation/model.pth'
10
+ model = UNet(n_channels=3, n_classes=6)
11
+ model.load_state_dict(torch.load(model_path))
12
+ model.eval()
13
+
14
+ # Create a Streamlit app
15
+ st.title('Aerial Image Segmentation')
16
+
17
+ # Add a file uploader to the app
18
+ uploaded_file = st.file_uploader("Choose an image...", type="jpg")
19
+
20
+ if uploaded_file is not None:
21
+ image = Image.open(uploaded_file)
22
+
23
+ # Display the original image
24
+ st.image(image, caption='Uploaded Image.', use_column_width=True)
25
+
26
+ # Preprocess the image
27
+ data_transform = transforms.Compose([
28
+ transforms.Resize((512, 512)),
29
+ transforms.ToTensor()]
30
+ )
31
+ image = data_transform(image)
32
+ image = image.unsqueeze(0) # add a batch dimension
33
+
34
+ # Pass the image through the model
35
+ with torch.no_grad():
36
+ output = model(image)
37
+
38
+ # Postprocess the output
39
+ # Define the color map
40
+ color_map = {
41
+ 0: np.array([155, 155, 155]), # Unlabeled
42
+ 1: np.array([60, 16, 152]), # Building
43
+ 2: np.array([132, 41, 246]), # Land
44
+ 3: np.array([110, 193, 228]), # Road
45
+ 4: np.array([254, 221, 58]), # Vegetation
46
+ 5: np.array([226, 169, 41]) # Water
47
+ }
48
+ class_labels = {
49
+ 0: 'Unlabeled',
50
+ 1: 'Building',
51
+ 2: 'Land',
52
+ 3: 'Road',
53
+ 4: 'Vegetation',
54
+ 5: 'Water'
55
+ }
56
+
57
+ # Display the class labels and their colors in a sidebar
58
+ for k, v in class_labels.items():
59
+ st.sidebar.markdown(f'<div style="color:rgb{tuple(color_map[k])};">{v}</div>', unsafe_allow_html=True)
60
+
61
+ # Pass the image through the model
62
+ with torch.no_grad():
63
+ output = model(image)
64
+
65
+ # Postprocess the output
66
+ output = torch.argmax(output.squeeze(), dim=0).detach().cpu().numpy()
67
+
68
+ # Squeeze the batch dimension
69
+ output = np.squeeze(output)
70
+
71
+ # Now you can create the RGB image
72
+ output_rgb = np.zeros((output.shape[0], output.shape[1], 3), dtype=np.uint8)
73
+ for k, v in color_map.items():
74
+ output_rgb[output == k] = v
75
+
76
+ # Display the segmented image
77
+ st.image(output_rgb, caption='Segmented Image.', use_column_width=True)
78
+
79
+
dataloader.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from torch.utils.data import Dataset
3
+ from PIL import Image
4
+ import torchvision.transforms as transforms
5
+ import numpy as np
6
+
7
+ class AerialImageDataset(Dataset):
8
+ def __init__(self, image_dir, mask_dir, transform=None):
9
+ self.image_dir = image_dir
10
+ self.mask_dir = mask_dir
11
+ self.transform = transform
12
+ self.images = os.listdir(self.image_dir)
13
+ self.Hex_Classes = [
14
+ ('Unlabeled', '#9B9B9B'),
15
+ ('Building','#3C1098'),
16
+ ('Land', '#8429F6'),
17
+ ('Road', '#6EC1E4'),
18
+ ('Vegetation', '#FEDD3A'),
19
+ ('Water', '#E2A929'),
20
+ ]
21
+
22
+ def __len__(self):
23
+ return len(self.images)
24
+
25
+ def __getitem__(self, idx):
26
+
27
+ img_path = os.path.join(self.image_dir, self.images[idx])
28
+ mask_path = os.path.join(self.mask_dir, self.images[idx].replace('.jpg', '.png'))
29
+
30
+ image = Image.open(img_path)
31
+ mask = Image.open(mask_path)
32
+
33
+ mask = np.array(mask)
34
+ mask = self.encode_segmap(mask)
35
+ mask = Image.fromarray(mask)
36
+
37
+ if self.transform:
38
+ image = self.transform(image)
39
+ mask = self.transform(mask)
40
+
41
+ return image, mask
42
+
43
+ def encode_segmap(self, mask):
44
+ mask = mask.astype(int)
45
+ label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
46
+ for i, (name, color) in enumerate(self.Hex_Classes):
47
+ if mask.ndim == 3:
48
+ label_mask[(mask[:,:,0] == int(color[1:3], 16)) & (mask[:,:,1] == int(color[3:5], 16)) & (mask[:,:,2] == int(color[5:7], 16))] = i
49
+ elif mask.ndim == 2:
50
+ label_mask[(mask == int(color[1:3], 16))] = i
51
+ return label_mask
dataprep.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import random
4
+
5
+ dataset_path = "/teamspace/studios/this_studio/Aerial-Segmentation/Semantic segmentation dataset"
6
+ new_dataset_path = "/teamspace/studios/this_studio/Aerial-Segmentation"
7
+ train_path = os.path.join(new_dataset_path, "train")
8
+ val_path = os.path.join(new_dataset_path, "val")
9
+
10
+ os.makedirs(train_path, exist_ok=True)
11
+ os.makedirs(val_path, exist_ok=True)
12
+
13
+ train_image_path = os.path.join(train_path, "images")
14
+ train_mask_path = os.path.join(train_path, "masks")
15
+
16
+ val_image_path = os.path.join(val_path, "images")
17
+ val_mask_path = os.path.join(val_path, "masks")
18
+
19
+ os.makedirs(train_image_path, exist_ok=True)
20
+ os.makedirs(val_image_path, exist_ok=True)
21
+ os.makedirs(train_mask_path, exist_ok=True)
22
+ os.makedirs(val_mask_path, exist_ok=True)
23
+
24
+
25
+ tile_folders = [folder for folder in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, folder))]
26
+
27
+ n_train_images = 8
28
+ n_val_images = 1
29
+
30
+
31
+ def copy(train_status):
32
+ if train_status:
33
+ images = train_images
34
+ path_image = train_image_path
35
+ path_mask = train_mask_path
36
+ else:
37
+ images = val_images
38
+ path_image = val_image_path
39
+ path_mask = val_mask_path
40
+
41
+
42
+
43
+ for image in images:
44
+ tile_image_name = f'{tile_folder}_{image}'
45
+ shutil.copy(os.path.join(images_path, image), os.path.join(path_image, tile_image_name))
46
+
47
+ mask_name = image.split('.')[0]+'.png'
48
+ tile_mask_name = f'{tile_folder}_{mask_name}'
49
+ shutil.copy(os.path.join(masks_path, mask_name), os.path.join(path_mask, tile_mask_name))
50
+
51
+
52
+ for tile_folder in tile_folders:
53
+ images_path = os.path.join(dataset_path, tile_folder, 'images')
54
+ masks_path = os.path.join(dataset_path, tile_folder, 'masks')
55
+
56
+ images = os.listdir(images_path)
57
+ masks = os.listdir(masks_path)
58
+
59
+ random.shuffle(images)
60
+ random.shuffle(masks)
61
+
62
+ train_images = images[:n_train_images]
63
+ val_images = images[n_train_images:]
64
+
65
+ copy(train_status=True)
66
+ copy(train_status=False)
67
+
68
+
69
+ shutil.rmtree(dataset_path)
70
+
71
+ print(f"Data organization and split completed successfully. Total Training Files is {len(os.listdir(train_image_path))} and Validation Files is {len(os.listdir(val_image_path))}")
main.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import lightning as L
3
+ from dataloader import AerialImageDataset
4
+ from train import UNet
5
+ from torch.utils.data import DataLoader
6
+ from torchvision.transforms import transforms
7
+ import torch
8
+
9
+ train_path = "/teamspace/studios/this_studio/Aerial-Segmentation/train"
10
+ val_path = "/teamspace/studios/this_studio/Aerial-Segmentation/val"
11
+
12
+ data_transform = transforms.Compose([
13
+ transforms.Resize((512, 512)),
14
+ transforms.ToTensor()]
15
+ )
16
+
17
+ train_dataset = AerialImageDataset(os.path.join(train_path, 'images'), os.path.join(train_path, 'masks'), transform=data_transform)
18
+ val_dataset = AerialImageDataset(os.path.join(val_path, 'images'), os.path.join(val_path, 'masks'), transform=data_transform)
19
+
20
+ train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
21
+ val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
22
+
23
+ model = UNet(n_channels=3, n_classes=6)
24
+
25
+ trainer = L.Trainer(max_epochs=100)
26
+ trainer.fit(model, train_loader, val_loader)
27
+
28
+ torch.save(model.state_dict(), "/teamspace/studios/this_studio/Aerial-Segmentation/model.pth")
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c7bbcc5453e099cdc1cf71edae30b611fed8ce893de98e35631a9a765832bbe
3
+ size 53651114