Thisissophia committed on
Commit 69e2ef2 · verified · 1 Parent(s): b5e8944

Upload 87 files

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +5 -0
  2. __pycache__/creat_anaglyph.cpython-38.pyc +0 -0
  3. __pycache__/deeplab_demo.cpython-38.pyc +0 -0
  4. __pycache__/mypath.cpython-38.pyc +0 -0
  5. anaglyph.png +3 -0
  6. app.py +96 -0
  7. creat_anaglyph.py +149 -0
  8. dataloaders/__init__.py +56 -0
  9. dataloaders/__pycache__/__init__.cpython-310.pyc +0 -0
  10. dataloaders/__pycache__/__init__.cpython-38.pyc +0 -0
  11. dataloaders/__pycache__/custom_transforms.cpython-310.pyc +0 -0
  12. dataloaders/__pycache__/custom_transforms.cpython-38.pyc +0 -0
  13. dataloaders/__pycache__/utils.cpython-310.pyc +0 -0
  14. dataloaders/__pycache__/utils.cpython-38.pyc +0 -0
  15. dataloaders/custom_transforms.py +165 -0
  16. dataloaders/datasets/__init__.py +0 -0
  17. dataloaders/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  18. dataloaders/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
  19. dataloaders/datasets/__pycache__/cityscapes.cpython-310.pyc +0 -0
  20. dataloaders/datasets/__pycache__/cityscapes.cpython-38.pyc +0 -0
  21. dataloaders/datasets/__pycache__/coco.cpython-310.pyc +0 -0
  22. dataloaders/datasets/__pycache__/coco.cpython-38.pyc +0 -0
  23. dataloaders/datasets/__pycache__/combine_dbs.cpython-310.pyc +0 -0
  24. dataloaders/datasets/__pycache__/combine_dbs.cpython-38.pyc +0 -0
  25. dataloaders/datasets/__pycache__/invoice.cpython-310.pyc +0 -0
  26. dataloaders/datasets/__pycache__/invoice.cpython-38.pyc +0 -0
  27. dataloaders/datasets/__pycache__/pascal.cpython-310.pyc +0 -0
  28. dataloaders/datasets/__pycache__/pascal.cpython-38.pyc +0 -0
  29. dataloaders/datasets/__pycache__/sbd.cpython-310.pyc +0 -0
  30. dataloaders/datasets/__pycache__/sbd.cpython-38.pyc +0 -0
  31. dataloaders/datasets/cityscapes.py +146 -0
  32. dataloaders/datasets/coco.py +160 -0
  33. dataloaders/datasets/combine_dbs.py +100 -0
  34. dataloaders/datasets/invoice.py +145 -0
  35. dataloaders/datasets/pascal.py +145 -0
  36. dataloaders/datasets/sbd.py +129 -0
  37. dataloaders/utils.py +111 -0
  38. deeplab-mobilenet.pth.tar +3 -0
  39. deeplab-resnet.pth.tar +3 -0
  40. deeplab_demo.py +111 -0
  41. end.py +90 -0
  42. img/mask.png +0 -0
  43. img/masked.png +0 -0
  44. img/people.jpg +0 -0
  45. img/scenery.jpg +3 -0
  46. img/scenery2.jpg +3 -0
  47. modeling/__init__.py +0 -0
  48. modeling/__pycache__/__init__.cpython-310.pyc +0 -0
  49. modeling/__pycache__/__init__.cpython-38.pyc +0 -0
  50. modeling/__pycache__/aspp.cpython-310.pyc +0 -0
.gitattributes ADDED
@@ -0,0 +1,5 @@
+ anaglyph.png filter=lfs diff=lfs merge=lfs -text
+ deeplab-mobilenet.pth.tar filter=lfs diff=lfs merge=lfs -text
+ deeplab-resnet.pth.tar filter=lfs diff=lfs merge=lfs -text
+ img/scenery.jpg filter=lfs diff=lfs merge=lfs -text
+ img/scenery2.jpg filter=lfs diff=lfs merge=lfs -text
__pycache__/creat_anaglyph.cpython-38.pyc ADDED
Binary file (2.53 kB).
 
__pycache__/deeplab_demo.cpython-38.pyc ADDED
Binary file (3.46 kB).
 
__pycache__/mypath.cpython-38.pyc ADDED
Binary file (812 Bytes).
 
anaglyph.png ADDED

Git LFS Details

  • SHA256: fbbb5fd4ee33896d6cf1cf8c245f420778673c2323a2cd1203a490a79e2d63be
  • Pointer size: 133 Bytes
  • Size of remote file: 11.6 MB
app.py ADDED
@@ -0,0 +1,96 @@
+ # Hugging Face requirements for app.py, the main file for running the application.
+ # Equivalent to end.py, to be used with the Hugging Face inference API; Hugging Face
+ # recognizes app.py as the main file for running the application.
+ # app.py
+
+ import sys
+ import os
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+ import gradio as gr
+ from PIL import Image
+ from deeplab_demo import get_people
+ from creat_anaglyph import insert_person_to_stereo_gradio
+ import torch
+ from torchvision.transforms import ToPILImage
+
+
+ # Process the person image: segment the person and return the masked image and the prediction grid
+ def process_person_image(person_image):
+     masked_image_pil, grid_image = get_people(person_image)
+
+     if isinstance(masked_image_pil, torch.Tensor):
+         masked_image_pil = ToPILImage()(masked_image_pil)
+     if isinstance(grid_image, torch.Tensor):
+         grid_image = ToPILImage()(grid_image)
+
+     return masked_image_pil, grid_image
+
+ # Generate the anaglyph image
+ def generate_anaglyph(masked_image_pil, scenery_image, depth_option, custom_disparity):
+     # Default disparities for the non-custom options: close, medium, far
+     depth_disparities = {
+         "close": 10,  # adjust values as needed
+         "medium": 5,
+         "far": 2
+     }
+
+     # Use custom_disparity only if depth_option is "custom"
+     disparity = custom_disparity if depth_option == "custom" else depth_disparities.get(depth_option, 5)
+
+     # Ensure inputs are PIL images
+     if isinstance(masked_image_pil, torch.Tensor):
+         masked_image_pil = ToPILImage()(masked_image_pil)
+     if isinstance(scenery_image, torch.Tensor):
+         scenery_image = ToPILImage()(scenery_image)
+
+     anaglyph_image = insert_person_to_stereo_gradio(scenery_image, masked_image_pil, disparity)
+
+     if isinstance(anaglyph_image, torch.Tensor):
+         anaglyph_image = ToPILImage()(anaglyph_image)
+
+     return anaglyph_image
+
+ # Create the Gradio interface
+ with gr.Blocks() as iface:
+     with gr.Row():
+         person_image_input = gr.Image(type="pil", label="Character image")
+         scenery_image_input = gr.Image(type="pil", label="Landscape image")
+         depth_option_input = gr.Dropdown(choices=["close", "medium", "far", "custom"], label="Depth Options")
+         custom_disparity_input = gr.Slider(minimum=0, maximum=50, step=1, label="Custom Depth Disparity", visible=False)
+
+     with gr.Row():
+         grid_image_output = gr.Image(type="pil", label="Grid", interactive=False)
+         masked_image_output = gr.Image(type="pil", label="Masked", interactive=False)
+         anaglyph_image_output = gr.Image(type="pil", label="Anaglyph", interactive=False)
+
+     # Button 1: process the character image
+     process_button = gr.Button("Process person image")
+     process_button.click(
+         fn=process_person_image,
+         inputs=person_image_input,
+         outputs=[masked_image_output, grid_image_output]
+     )
+
+     # Show the custom disparity slider only when the "custom" depth option is selected
+     def update_custom_slider_visibility(depth_option):
+         return gr.update(visible=(depth_option == "custom"))
+
+     depth_option_input.change(
+         fn=update_custom_slider_visibility,
+         inputs=[depth_option_input],
+         outputs=custom_disparity_input
+     )
+
+     # Button 2: generate the anaglyph image
+     generate_button = gr.Button("Generate Anaglyph Image")
+     generate_button.click(
+         fn=generate_anaglyph,
+         inputs=[masked_image_output, scenery_image_input, depth_option_input, custom_disparity_input],
+         outputs=anaglyph_image_output
+     )
+
+ # Launch the Gradio interface
+ iface.launch()
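
For reference, the same pipeline can be driven without the Gradio UI. A minimal sketch (not part of the commit), assuming the repository modules and the deeplab-mobilenet.pth.tar checkpoint are available locally and the script is run with no extra CLI arguments (get_people builds its own argparse parser):

    from PIL import Image
    from deeplab_demo import get_people
    from creat_anaglyph import insert_person_to_stereo_gradio

    person = Image.open('img/people.jpg').convert('RGB')
    scenery = Image.open('img/scenery.jpg').convert('RGB')

    masked, grid = get_people(person)  # segment the person out of the photo
    # disparity=5 matches the "medium" depth option in app.py
    anaglyph = insert_person_to_stereo_gradio(scenery, masked, 5)
    anaglyph.save('anaglyph.png')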
creat_anaglyph.py ADDED
@@ -0,0 +1,149 @@
+ # file: creat_anaglyph.py
+ # Description: creates a red-cyan anaglyph stereo image by inserting a person into a stereo image pair.
+
+ from PIL import Image
+ import numpy as np
+ import torch
+ from torchvision.transforms import ToPILImage
+
+ # Preprocess the person image to make its black background transparent
+ def preprocess_person_image(person_image_path):
+     # load the uploaded person image
+     person_image = Image.open(person_image_path).convert('RGBA')
+     data = np.array(person_image)
+
+     # separate color channels
+     r, g, b, a = data.T
+
+     # threshold for detecting the black background
+     black_threshold = 1
+     black_areas = (r < black_threshold) & (g < black_threshold) & (b < black_threshold)
+
+     # set black background pixels to transparent (only the alpha channel is modified)
+     data[..., 3][black_areas.T] = 0
+
+     # create a new image
+     transparent_image = Image.fromarray(data)
+     return transparent_image
+
+
+ # Gradio-compatible version of preprocess_person_image (takes a PIL image instead of a path)
+ def preprocess_person_image_gradio(person_image):
+     # ensure the image is in RGBA mode
+     if person_image.mode != 'RGBA':
+         person_image = person_image.convert('RGBA')
+
+     data = np.array(person_image)
+
+     # separate color channels
+     r, g, b, a = data.T
+
+     # threshold for detecting the black background
+     black_threshold = 1
+     black_areas = (r < black_threshold) & (g < black_threshold) & (b < black_threshold)
+
+     # set black background pixels to transparent (only the alpha channel is modified)
+     data[..., 3][black_areas.T] = 0
+
+     # create a new image
+     transparent_image = Image.fromarray(data)
+     return transparent_image
+
+ def insert_person_to_stereo(stereo_image_path, person_image_path, depth_option):
+     # load the stitched stereo image
+     stereo_image = Image.open(stereo_image_path).convert('RGB')
+     width, height = stereo_image.size
+
+     # assume the stitched image is symmetrical: left half and right half
+     left_image = stereo_image.crop((0, 0, width // 2, height))
+     right_image = stereo_image.crop((width // 2, 0, width, height))
+
+     # preprocess the person image
+     person_image = preprocess_person_image(person_image_path)
+     person_width, person_height = person_image.size
+
+     # define disparity options based on image width
+     max_disparity = width // 20
+     disparity_options = {
+         'close': max_disparity // 5,
+         'medium': max_disparity // 15,
+         'far': max_disparity // 20
+     }
+
+     # get the corresponding disparity value
+     disparity = disparity_options.get(depth_option, max_disparity // 2)
+
+     # insertion position: align the bottom of the person with the bottom of the scene, centered horizontally
+     x_position = (width // 4) - (person_width // 2) + disparity
+     y_position = height - person_height
+
+     # paste the person image into the left and right views
+     left_image.paste(person_image, (x_position, y_position), person_image)
+     right_image.paste(person_image, (x_position - disparity, y_position), person_image)
+
+     # combine the left and right views into a red-cyan stereo image
+     left_array = np.array(left_image)
+     right_array = np.array(right_image)
+
+     # red channel from the left image, green and blue channels from the right image
+     anaglyph = np.zeros_like(left_array)
+     anaglyph[..., 0] = left_array[..., 0]
+     anaglyph[..., 1] = right_array[..., 1]
+     anaglyph[..., 2] = right_array[..., 2]
+
+     # convert to an image and save
+     anaglyph_image = Image.fromarray(anaglyph)
+     anaglyph_image.save('anaglyph.png')
+
+
+ # Gradio-compatible version of insert_person_to_stereo (takes PIL images and a numeric disparity)
+ def insert_person_to_stereo_gradio(stereo_image, person_image, disparity):
+     # ensure the person image is RGBA and the stereo image is RGB
+     if person_image.mode != 'RGBA':
+         person_image = person_image.convert('RGBA')
+     if stereo_image.mode != 'RGB':
+         stereo_image = stereo_image.convert('RGB')
+     width, height = stereo_image.size
+
+     # assume the stitched image is symmetrical: left half and right half
+     left_image = stereo_image.crop((0, 0, width // 2, height))
+     right_image = stereo_image.crop((width // 2, 0, width, height))
+
+     # preprocess the person image
+     person_image = preprocess_person_image_gradio(person_image)
+     person_width, person_height = person_image.size
+
+     # insertion position: align the bottom of the person with the bottom of the scene, centered horizontally
+     x_position = (width // 4) - (person_width // 2) + disparity
+     y_position = height - person_height
+
+     # paste the person image into the left and right views
+     left_image.paste(person_image, (x_position, y_position), person_image)
+     right_image.paste(person_image, (x_position - disparity, y_position), person_image)
+
+     # combine the left and right views into a red-cyan stereo image
+     left_array = np.array(left_image)
+     right_array = np.array(right_image)
+
+     # red channel from the left image, green and blue channels from the right image
+     anaglyph = np.zeros_like(left_array)
+     anaglyph[..., 0] = left_array[..., 0]
+     anaglyph[..., 1] = right_array[..., 1]
+     anaglyph[..., 2] = right_array[..., 2]
+
+     # convert to an image and return
+     anaglyph_image = Image.fromarray(anaglyph)
+     return anaglyph_image
+
+
+ # Example (guarded so it does not run as a side effect when this module is imported by app.py)
+ if __name__ == '__main__':
+     insert_person_to_stereo('img/scenery.jpg', 'img/masked.png', 'far')
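
The red-cyan recipe used by both functions (red from the left view, green and blue from the right view) is easy to check in isolation. A small sketch on synthetic arrays, not part of the commit:

    import numpy as np

    left = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)
    right = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)

    anaglyph = np.zeros_like(left)
    anaglyph[..., 0] = left[..., 0]   # red channel from the left eye's view
    anaglyph[..., 1] = right[..., 1]  # green channel from the right eye's view
    anaglyph[..., 2] = right[..., 2]  # blue channel from the right eye's view

    assert np.array_equal(anaglyph[..., 0], left[..., 0])
    assert np.array_equal(anaglyph[..., 1:], right[..., 1:])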
dataloaders/__init__.py ADDED
@@ -0,0 +1,56 @@
+ from dataloaders.datasets import cityscapes, coco, combine_dbs, pascal, sbd, invoice
+ from torch.utils.data import DataLoader
+
+ def make_data_loader(args, **kwargs):
+
+     if args.dataset == 'invoice':
+         train_set = invoice.VOCSegmentation(args, split='train')
+         val_set = invoice.VOCSegmentation(args, split='val')
+         if args.use_sbd:
+             sbd_train = sbd.SBDSegmentation(args, split=['train', 'val'])
+             train_set = combine_dbs.CombineDBs([train_set, sbd_train], excluded=[val_set])
+
+         num_class = train_set.NUM_CLASSES
+         train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
+         val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
+         test_loader = None
+
+         return train_loader, val_loader, test_loader, num_class
+
+     elif args.dataset == 'pascal':
+         train_set = pascal.VOCSegmentation(args, split='train')
+         val_set = pascal.VOCSegmentation(args, split='val')
+         if args.use_sbd:
+             sbd_train = sbd.SBDSegmentation(args, split=['train', 'val'])
+             train_set = combine_dbs.CombineDBs([train_set, sbd_train], excluded=[val_set])
+
+         num_class = train_set.NUM_CLASSES
+         train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
+         val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
+         test_loader = None
+
+         return train_loader, val_loader, test_loader, num_class
+
+     elif args.dataset == 'cityscapes':
+         train_set = cityscapes.CityscapesSegmentation(args, split='train')
+         val_set = cityscapes.CityscapesSegmentation(args, split='val')
+         test_set = cityscapes.CityscapesSegmentation(args, split='test')
+         num_class = train_set.NUM_CLASSES
+         train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
+         val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
+         test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs)
+
+         return train_loader, val_loader, test_loader, num_class
+
+     elif args.dataset == 'coco':
+         train_set = coco.COCOSegmentation(args, split='train')
+         val_set = coco.COCOSegmentation(args, split='val')
+         num_class = train_set.NUM_CLASSES
+         train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
+         val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
+         test_loader = None
+         return train_loader, val_loader, test_loader, num_class
+
+     else:
+         raise NotImplementedError
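
A minimal sketch of how make_data_loader is typically called (not part of the commit). The attribute names match what the branches above read off args, and it assumes the dataset directory configured in mypath.Path exists:

    from argparse import Namespace
    from dataloaders import make_data_loader

    args = Namespace(dataset='invoice', use_sbd=False,
                     base_size=512, crop_size=512, batch_size=4)
    train_loader, val_loader, test_loader, num_class = make_data_loader(args, num_workers=2)
    print(num_class)  # 2 for the invoice dataset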
dataloaders/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.46 kB).
 
dataloaders/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.45 kB).
 
dataloaders/__pycache__/custom_transforms.cpython-310.pyc ADDED
Binary file (5.23 kB).
 
dataloaders/__pycache__/custom_transforms.cpython-38.pyc ADDED
Binary file (5.32 kB).
 
dataloaders/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.85 kB).
 
dataloaders/__pycache__/utils.cpython-38.pyc ADDED
Binary file (3.42 kB).
 
dataloaders/custom_transforms.py ADDED
@@ -0,0 +1,165 @@
+ import torch
+ import random
+ import numpy as np
+
+ from PIL import Image, ImageOps, ImageFilter
+
+ class Normalize(object):
+     """Normalize a tensor image with mean and standard deviation.
+     Args:
+         mean (tuple): means for each channel.
+         std (tuple): standard deviations for each channel.
+     """
+     def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
+         self.mean = mean
+         self.std = std
+
+     def __call__(self, sample):
+         img = sample['image']
+         mask = sample['label']
+         img = np.array(img).astype(np.float32)
+         mask = np.array(mask).astype(np.float32)
+         img /= 255.0
+         img -= self.mean
+         img /= self.std
+
+         return {'image': img,
+                 'label': mask}
+
+
+ class ToTensor(object):
+     """Convert ndarrays in sample to Tensors."""
+
+     def __call__(self, sample):
+         # swap color axis because
+         # numpy image: H x W x C
+         # torch image: C x H x W
+         img = sample['image']
+         mask = sample['label']
+         img = np.array(img).astype(np.float32).transpose((2, 0, 1))
+         mask = np.array(mask).astype(np.float32)
+
+         img = torch.from_numpy(img).float()
+         mask = torch.from_numpy(mask).float()
+
+         return {'image': img,
+                 'label': mask}
+
+
+ class RandomHorizontalFlip(object):
+     def __call__(self, sample):
+         img = sample['image']
+         mask = sample['label']
+         if random.random() < 0.5:
+             img = img.transpose(Image.FLIP_LEFT_RIGHT)
+             mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
+
+         return {'image': img,
+                 'label': mask}
+
+
+ class RandomRotate(object):
+     def __init__(self, degree):
+         self.degree = degree
+
+     def __call__(self, sample):
+         img = sample['image']
+         mask = sample['label']
+         rotate_degree = random.uniform(-1 * self.degree, self.degree)
+         img = img.rotate(rotate_degree, Image.BILINEAR)
+         mask = mask.rotate(rotate_degree, Image.NEAREST)
+
+         return {'image': img,
+                 'label': mask}
+
+
+ class RandomGaussianBlur(object):
+     def __call__(self, sample):
+         img = sample['image']
+         mask = sample['label']
+         if random.random() < 0.5:
+             img = img.filter(ImageFilter.GaussianBlur(
+                 radius=random.random()))
+
+         return {'image': img,
+                 'label': mask}
+
+
+ class RandomScaleCrop(object):
+     def __init__(self, base_size, crop_size, fill=0):
+         self.base_size = base_size
+         self.crop_size = crop_size
+         self.fill = fill
+
+     def __call__(self, sample):
+         img = sample['image']
+         mask = sample['label']
+         # random scale (short edge)
+         short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
+         w, h = img.size
+         if h > w:
+             ow = short_size
+             oh = int(1.0 * h * ow / w)
+         else:
+             oh = short_size
+             ow = int(1.0 * w * oh / h)
+         img = img.resize((ow, oh), Image.BILINEAR)
+         mask = mask.resize((ow, oh), Image.NEAREST)
+         # pad crop
+         if short_size < self.crop_size:
+             padh = self.crop_size - oh if oh < self.crop_size else 0
+             padw = self.crop_size - ow if ow < self.crop_size else 0
+             img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
+             mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
+         # random crop crop_size
+         w, h = img.size
+         x1 = random.randint(0, w - self.crop_size)
+         y1 = random.randint(0, h - self.crop_size)
+         img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
+         mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
+
+         return {'image': img,
+                 'label': mask}
+
+
+ class FixScaleCrop(object):
+     def __init__(self, crop_size):
+         self.crop_size = crop_size
+
+     def __call__(self, sample):
+         img = sample['image']
+         mask = sample['label']
+         w, h = img.size
+         if w > h:
+             oh = self.crop_size
+             ow = int(1.0 * w * oh / h)
+         else:
+             ow = self.crop_size
+             oh = int(1.0 * h * ow / w)
+         img = img.resize((ow, oh), Image.BILINEAR)
+         mask = mask.resize((ow, oh), Image.NEAREST)
+         # center crop
+         w, h = img.size
+         x1 = int(round((w - self.crop_size) / 2.))
+         y1 = int(round((h - self.crop_size) / 2.))
+         img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
+         mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
+
+         return {'image': img,
+                 'label': mask}
+
+ class FixedResize(object):
+     def __init__(self, size):
+         self.size = (size, size)  # size: (h, w)
+
+     def __call__(self, sample):
+         img = sample['image']
+         mask = sample['label']
+
+         assert img.size == mask.size
+
+         img = img.resize(self.size, Image.BILINEAR)
+         mask = mask.resize(self.size, Image.NEAREST)
+
+         return {'image': img,
+                 'label': mask}
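
A small sketch (not part of the commit) exercising the pipeline on a dummy sample, mirroring the transform_tr composition used by the dataset classes below:

    import numpy as np
    from PIL import Image
    from torchvision import transforms
    from dataloaders import custom_transforms as tr

    img = Image.fromarray(np.random.randint(0, 256, (600, 800, 3), dtype=np.uint8))
    mask = Image.fromarray(np.zeros((600, 800), dtype=np.uint8))

    pipeline = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScaleCrop(base_size=513, crop_size=513),
        tr.RandomGaussianBlur(),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor()])

    sample = pipeline({'image': img, 'label': mask})
    print(sample['image'].shape)  # torch.Size([3, 513, 513])
    print(sample['label'].shape)  # torch.Size([513, 513])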
dataloaders/datasets/__init__.py ADDED
File without changes
dataloaders/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (171 Bytes).
 
dataloaders/datasets/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (154 Bytes).
 
dataloaders/datasets/__pycache__/cityscapes.cpython-310.pyc ADDED
Binary file (5.28 kB).
 
dataloaders/datasets/__pycache__/cityscapes.cpython-38.pyc ADDED
Binary file (5.43 kB).
 
dataloaders/datasets/__pycache__/coco.cpython-310.pyc ADDED
Binary file (5.38 kB).
 
dataloaders/datasets/__pycache__/coco.cpython-38.pyc ADDED
Binary file (5.4 kB).
 
dataloaders/datasets/__pycache__/combine_dbs.cpython-310.pyc ADDED
Binary file (3.19 kB).
 
dataloaders/datasets/__pycache__/combine_dbs.cpython-38.pyc ADDED
Binary file (3.17 kB).
 
dataloaders/datasets/__pycache__/invoice.cpython-310.pyc ADDED
Binary file (4.35 kB).
 
dataloaders/datasets/__pycache__/invoice.cpython-38.pyc ADDED
Binary file (4.31 kB).
 
dataloaders/datasets/__pycache__/pascal.cpython-310.pyc ADDED
Binary file (4.35 kB).
 
dataloaders/datasets/__pycache__/pascal.cpython-38.pyc ADDED
Binary file (4.31 kB).
 
dataloaders/datasets/__pycache__/sbd.cpython-310.pyc ADDED
Binary file (4.01 kB).
 
dataloaders/datasets/__pycache__/sbd.cpython-38.pyc ADDED
Binary file (3.97 kB).
 
dataloaders/datasets/cityscapes.py ADDED
@@ -0,0 +1,146 @@
+ import os
+ import numpy as np
+ import scipy.misc as m
+ from PIL import Image
+ from torch.utils import data
+ from mypath import Path
+ from torchvision import transforms
+ from dataloaders import custom_transforms as tr
+
+ class CityscapesSegmentation(data.Dataset):
+     NUM_CLASSES = 19
+
+     def __init__(self, args, root=Path.db_root_dir('cityscapes'), split="train"):
+
+         self.root = root
+         self.split = split
+         self.args = args
+         self.files = {}
+
+         self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
+         self.annotations_base = os.path.join(self.root, 'gtFine_trainvaltest', 'gtFine', self.split)
+
+         self.files[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png')
+
+         self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
+         self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
+         self.class_names = ['unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence',
+                             'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain',
+                             'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
+                             'motorcycle', 'bicycle']
+
+         self.ignore_index = 255
+         self.class_map = dict(zip(self.valid_classes, range(self.NUM_CLASSES)))
+
+         if not self.files[split]:
+             raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))
+
+         print("Found %d %s images" % (len(self.files[split]), split))
+
+     def __len__(self):
+         return len(self.files[self.split])
+
+     def __getitem__(self, index):
+
+         img_path = self.files[self.split][index].rstrip()
+         lbl_path = os.path.join(self.annotations_base,
+                                 img_path.split(os.sep)[-2],
+                                 os.path.basename(img_path)[:-15] + 'gtFine_labelIds.png')
+
+         _img = Image.open(img_path).convert('RGB')
+         _tmp = np.array(Image.open(lbl_path), dtype=np.uint8)
+         _tmp = self.encode_segmap(_tmp)
+         _target = Image.fromarray(_tmp)
+
+         sample = {'image': _img, 'label': _target}
+
+         if self.split == 'train':
+             return self.transform_tr(sample)
+         elif self.split == 'val':
+             return self.transform_val(sample)
+         elif self.split == 'test':
+             return self.transform_ts(sample)
+
+     def encode_segmap(self, mask):
+         # Map all void classes to the ignore index and valid classes to train ids
+         for _voidc in self.void_classes:
+             mask[mask == _voidc] = self.ignore_index
+         for _validc in self.valid_classes:
+             mask[mask == _validc] = self.class_map[_validc]
+         return mask
+
+     def recursive_glob(self, rootdir='.', suffix=''):
+         """Performs recursive glob with given suffix and rootdir
+         :param rootdir is the root directory
+         :param suffix is the suffix to be searched
+         """
+         return [os.path.join(looproot, filename)
+                 for looproot, _, filenames in os.walk(rootdir)
+                 for filename in filenames if filename.endswith(suffix)]
+
+     def transform_tr(self, sample):
+         composed_transforms = transforms.Compose([
+             tr.RandomHorizontalFlip(),
+             tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255),
+             tr.RandomGaussianBlur(),
+             tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+             tr.ToTensor()])
+
+         return composed_transforms(sample)
+
+     def transform_val(self, sample):
+
+         composed_transforms = transforms.Compose([
+             tr.FixScaleCrop(crop_size=self.args.crop_size),
+             tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+             tr.ToTensor()])
+
+         return composed_transforms(sample)
+
+     def transform_ts(self, sample):
+
+         composed_transforms = transforms.Compose([
+             tr.FixedResize(size=self.args.crop_size),
+             tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+             tr.ToTensor()])
+
+         return composed_transforms(sample)
+
+ if __name__ == '__main__':
+     from dataloaders.utils import decode_segmap
+     from torch.utils.data import DataLoader
+     import matplotlib.pyplot as plt
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     args = parser.parse_args()
+     args.base_size = 513
+     args.crop_size = 513
+
+     cityscapes_train = CityscapesSegmentation(args, split='train')
+
+     dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2)
+
+     for ii, sample in enumerate(dataloader):
+         for jj in range(sample["image"].size()[0]):
+             img = sample['image'].numpy()
+             gt = sample['label'].numpy()
+             tmp = np.array(gt[jj]).astype(np.uint8)
+             segmap = decode_segmap(tmp, dataset='cityscapes')
+             img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
+             img_tmp *= (0.229, 0.224, 0.225)
+             img_tmp += (0.485, 0.456, 0.406)
+             img_tmp *= 255.0
+             img_tmp = img_tmp.astype(np.uint8)
+             plt.figure()
+             plt.title('display')
+             plt.subplot(211)
+             plt.imshow(img_tmp)
+             plt.subplot(212)
+             plt.imshow(segmap)
+
+         if ii == 1:
+             break
+
+     plt.show(block=True)
dataloaders/datasets/coco.py ADDED
@@ -0,0 +1,160 @@
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset
+ from mypath import Path
+ from tqdm import trange
+ import os
+ from pycocotools.coco import COCO
+ from pycocotools import mask
+ from torchvision import transforms
+ from dataloaders import custom_transforms as tr
+ from PIL import Image, ImageFile
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+ class COCOSegmentation(Dataset):
+     NUM_CLASSES = 21
+     CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
+                 1, 64, 20, 63, 7, 72]
+
+     def __init__(self,
+                  args,
+                  base_dir=Path.db_root_dir('coco'),
+                  split='train',
+                  year='2017'):
+         super().__init__()
+         ann_file = os.path.join(base_dir, 'annotations/instances_{}{}.json'.format(split, year))
+         ids_file = os.path.join(base_dir, 'annotations/{}_ids_{}.pth'.format(split, year))
+         self.img_dir = os.path.join(base_dir, 'images/{}{}'.format(split, year))
+         self.split = split
+         self.coco = COCO(ann_file)
+         self.coco_mask = mask
+         if os.path.exists(ids_file):
+             self.ids = torch.load(ids_file)
+         else:
+             ids = list(self.coco.imgs.keys())
+             self.ids = self._preprocess(ids, ids_file)
+         self.args = args
+
+     def __getitem__(self, index):
+         _img, _target = self._make_img_gt_point_pair(index)
+         sample = {'image': _img, 'label': _target}
+
+         if self.split == "train":
+             return self.transform_tr(sample)
+         elif self.split == 'val':
+             return self.transform_val(sample)
+
+     def _make_img_gt_point_pair(self, index):
+         coco = self.coco
+         img_id = self.ids[index]
+         img_metadata = coco.loadImgs(img_id)[0]
+         path = img_metadata['file_name']
+         _img = Image.open(os.path.join(self.img_dir, path)).convert('RGB')
+         cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
+         _target = Image.fromarray(self._gen_seg_mask(
+             cocotarget, img_metadata['height'], img_metadata['width']))
+
+         return _img, _target
+
+     def _preprocess(self, ids, ids_file):
+         print("Preprocessing masks, this will take a while. "
+               "But don't worry, it only runs once for each split.")
+         tbar = trange(len(ids))
+         new_ids = []
+         for i in tbar:
+             img_id = ids[i]
+             cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
+             img_metadata = self.coco.loadImgs(img_id)[0]
+             mask = self._gen_seg_mask(cocotarget, img_metadata['height'],
+                                       img_metadata['width'])
+             # keep only images whose mask covers more than 1k pixels
+             if (mask > 0).sum() > 1000:
+                 new_ids.append(img_id)
+             tbar.set_description('Doing: {}/{}, got {} qualified images'.
+                                  format(i, len(ids), len(new_ids)))
+         print('Found number of qualified images: ', len(new_ids))
+         torch.save(new_ids, ids_file)
+         return new_ids
+
+     def _gen_seg_mask(self, target, h, w):
+         mask = np.zeros((h, w), dtype=np.uint8)
+         coco_mask = self.coco_mask
+         for instance in target:
+             rle = coco_mask.frPyObjects(instance['segmentation'], h, w)
+             m = coco_mask.decode(rle)
+             cat = instance['category_id']
+             if cat in self.CAT_LIST:
+                 c = self.CAT_LIST.index(cat)
+             else:
+                 continue
+             if len(m.shape) < 3:
+                 mask[:, :] += (mask == 0) * (m * c)
+             else:
+                 mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
+         return mask
+
+     def transform_tr(self, sample):
+         composed_transforms = transforms.Compose([
+             tr.RandomHorizontalFlip(),
+             tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
+             tr.RandomGaussianBlur(),
+             tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+             tr.ToTensor()])
+
+         return composed_transforms(sample)
+
+     def transform_val(self, sample):
+
+         composed_transforms = transforms.Compose([
+             tr.FixScaleCrop(crop_size=self.args.crop_size),
+             tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+             tr.ToTensor()])
+
+         return composed_transforms(sample)
+
+     def __len__(self):
+         return len(self.ids)
+
+
+ if __name__ == "__main__":
+     from dataloaders import custom_transforms as tr
+     from dataloaders.utils import decode_segmap
+     from torch.utils.data import DataLoader
+     from torchvision import transforms
+     import matplotlib.pyplot as plt
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     args = parser.parse_args()
+     args.base_size = 513
+     args.crop_size = 513
+
+     coco_val = COCOSegmentation(args, split='val', year='2017')
+
+     dataloader = DataLoader(coco_val, batch_size=4, shuffle=True, num_workers=0)
+
+     for ii, sample in enumerate(dataloader):
+         for jj in range(sample["image"].size()[0]):
+             img = sample['image'].numpy()
+             gt = sample['label'].numpy()
+             tmp = np.array(gt[jj]).astype(np.uint8)
+             segmap = decode_segmap(tmp, dataset='coco')
+             img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
+             img_tmp *= (0.229, 0.224, 0.225)
+             img_tmp += (0.485, 0.456, 0.406)
+             img_tmp *= 255.0
+             img_tmp = img_tmp.astype(np.uint8)
+             plt.figure()
+             plt.title('display')
+             plt.subplot(211)
+             plt.imshow(img_tmp)
+             plt.subplot(212)
+             plt.imshow(segmap)
+
+         if ii == 1:
+             break
+
+     plt.show(block=True)
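
The two pycocotools calls that _gen_seg_mask relies on can be tried on a toy polygon (hypothetical data, not from the dataset): frPyObjects converts a polygon segmentation to RLE, and decode rasterizes the RLE into a binary mask.

    import numpy as np
    from pycocotools import mask as coco_mask

    h, w = 8, 8
    polygon = [[1.0, 1.0, 6.0, 1.0, 6.0, 6.0, 1.0, 6.0]]  # a square as [x0, y0, x1, y1, ...]
    rle = coco_mask.frPyObjects(polygon, h, w)
    m = coco_mask.decode(rle)  # shape (8, 8, 1): one channel per polygon
    print(m[..., 0].sum())     # number of pixels inside the square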
dataloaders/datasets/combine_dbs.py ADDED
@@ -0,0 +1,100 @@
+ import torch.utils.data as data
+
+
+ class CombineDBs(data.Dataset):
+     NUM_CLASSES = 21
+     def __init__(self, dataloaders, excluded=None):
+         self.dataloaders = dataloaders
+         self.excluded = excluded
+         self.im_ids = []
+
+         # Combine object lists
+         for dl in dataloaders:
+             for elem in dl.im_ids:
+                 if elem not in self.im_ids:
+                     self.im_ids.append(elem)
+
+         # Exclude image ids that appear in the excluded datasets
+         if excluded:
+             for dl in excluded:
+                 for elem in dl.im_ids:
+                     if elem in self.im_ids:
+                         self.im_ids.remove(elem)
+
+         # Get object pointers
+         self.cat_list = []
+         self.im_list = []
+         new_im_ids = []
+         num_images = 0
+         for ii, dl in enumerate(dataloaders):
+             for jj, curr_im_id in enumerate(dl.im_ids):
+                 if (curr_im_id in self.im_ids) and (curr_im_id not in new_im_ids):
+                     num_images += 1
+                     new_im_ids.append(curr_im_id)
+                     self.cat_list.append({'db_ii': ii, 'cat_ii': jj})
+
+         self.im_ids = new_im_ids
+         print('Combined number of images: {:d}'.format(num_images))
+
+     def __getitem__(self, index):
+
+         _db_ii = self.cat_list[index]["db_ii"]
+         _cat_ii = self.cat_list[index]['cat_ii']
+         sample = self.dataloaders[_db_ii].__getitem__(_cat_ii)
+
+         if 'meta' in sample.keys():
+             sample['meta']['db'] = str(self.dataloaders[_db_ii])
+
+         return sample
+
+     def __len__(self):
+         return len(self.cat_list)
+
+     def __str__(self):
+         include_db = [str(db) for db in self.dataloaders]
+         exclude_db = [str(db) for db in self.excluded]
+         return 'Included datasets:' + str(include_db) + '\n' + 'Excluded datasets:' + str(exclude_db)
+
+
+ if __name__ == "__main__":
+     import matplotlib.pyplot as plt
+     from dataloaders.datasets import pascal, sbd
+     import torch
+     import numpy as np
+     from dataloaders.utils import decode_segmap
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     args = parser.parse_args()
+     args.base_size = 513
+     args.crop_size = 513
+
+     pascal_voc_val = pascal.VOCSegmentation(args, split='val')
+     sbd = sbd.SBDSegmentation(args, split=['train', 'val'])
+     pascal_voc_train = pascal.VOCSegmentation(args, split='train')
+
+     dataset = CombineDBs([pascal_voc_train, sbd], excluded=[pascal_voc_val])
+     dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
+
+     for ii, sample in enumerate(dataloader):
+         for jj in range(sample["image"].size()[0]):
+             img = sample['image'].numpy()
+             gt = sample['label'].numpy()
+             tmp = np.array(gt[jj]).astype(np.uint8)
+             segmap = decode_segmap(tmp, dataset='pascal')
+             img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
+             img_tmp *= (0.229, 0.224, 0.225)
+             img_tmp += (0.485, 0.456, 0.406)
+             img_tmp *= 255.0
+             img_tmp = img_tmp.astype(np.uint8)
+             plt.figure()
+             plt.title('display')
+             plt.subplot(211)
+             plt.imshow(img_tmp)
+             plt.subplot(212)
+             plt.imshow(segmap)
+
+         if ii == 1:
+             break
+     plt.show(block=True)
dataloaders/datasets/invoice.py ADDED
@@ -0,0 +1,145 @@
+ from __future__ import print_function, division
+ import os
+ from PIL import Image
+ import numpy as np
+ from torch.utils.data import Dataset
+ from mypath import Path
+ from torchvision import transforms
+ from dataloaders import custom_transforms as tr
+
+ class VOCSegmentation(Dataset):
+     """
+     Invoice dataset in Pascal VOC format
+     """
+     NUM_CLASSES = 2
+
+     def __init__(self,
+                  args,
+                  base_dir=Path.db_root_dir('invoice'),
+                  split='train',
+                  ):
+         """
+         :param base_dir: path to the VOC-style dataset directory
+         :param split: train/val
+         :param transform: transform to apply
+         """
+         super().__init__()
+         self._base_dir = base_dir
+         self._image_dir = os.path.join(self._base_dir, 'JPEGImages')
+         self._cat_dir = os.path.join(self._base_dir, 'SegmentationClass')
+
+         if isinstance(split, str):
+             self.split = [split]
+         else:
+             split.sort()
+             self.split = split
+
+         self.args = args
+
+         _splits_dir = os.path.join(self._base_dir, 'ImageSets', 'Segmentation')
+
+         self.im_ids = []
+         self.images = []
+         self.categories = []
+
+         for splt in self.split:
+             with open(os.path.join(_splits_dir, splt + '.txt'), "r") as f:
+                 lines = f.read().splitlines()
+
+             for ii, line in enumerate(lines):
+                 _image = os.path.join(self._image_dir, line + ".png")
+                 _cat = os.path.join(self._cat_dir, line + ".png")
+                 assert os.path.isfile(_image)
+                 assert os.path.isfile(_cat)
+                 self.im_ids.append(line)
+                 self.images.append(_image)
+                 self.categories.append(_cat)
+
+         assert (len(self.images) == len(self.categories))
+
+         # Display stats
+         print('Number of images in {}: {:d}'.format(split, len(self.images)))
+
+     def __len__(self):
+         return len(self.images)
+
+     def __getitem__(self, index):
+         _img, _target = self._make_img_gt_point_pair(index)
+         sample = {'image': _img, 'label': _target}
+
+         for split in self.split:
+             if split == "train":
+                 return self.transform_tr(sample)
+             elif split == 'val':
+                 return self.transform_val(sample)
+
+     def _make_img_gt_point_pair(self, index):
+         _img = Image.open(self.images[index]).convert('RGB')
+         _target = Image.open(self.categories[index])
+
+         return _img, _target
+
+     def transform_tr(self, sample):
+         composed_transforms = transforms.Compose([
+             tr.RandomHorizontalFlip(),
+             tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
+             tr.RandomGaussianBlur(),
+             tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+             tr.ToTensor()])
+
+         return composed_transforms(sample)
+
+     def transform_val(self, sample):
+
+         composed_transforms = transforms.Compose([
+             tr.FixScaleCrop(crop_size=self.args.crop_size),
+             tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+             tr.ToTensor()])
+
+         return composed_transforms(sample)
+
+     def __str__(self):
+         return 'VOC2012(split=' + str(self.split) + ')'
+
+
+ if __name__ == '__main__':
+     from dataloaders.utils import decode_segmap
+     from torch.utils.data import DataLoader
+     import matplotlib.pyplot as plt
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     args = parser.parse_args()
+     args.base_size = 512
+     args.crop_size = 512
+
+     voc_train = VOCSegmentation(args, split='train')
+
+     dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=0)
+
+     for ii, sample in enumerate(dataloader):
+         for jj in range(sample["image"].size()[0]):
+             img = sample['image'].numpy()
+             gt = sample['label'].numpy()
+             tmp = np.array(gt[jj]).astype(np.uint8)
+             segmap = decode_segmap(tmp, dataset='invoice')
+             img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
+             img_tmp *= (0.229, 0.224, 0.225)
+             img_tmp += (0.485, 0.456, 0.406)
+             img_tmp *= 255.0
+             img_tmp = img_tmp.astype(np.uint8)
+             plt.figure()
+             plt.title('display')
+             plt.subplot(211)
+             plt.imshow(img_tmp)
+             plt.subplot(212)
+             plt.imshow(segmap)
+
+         if ii == 1:
+             break
+
+     plt.show(block=True)
dataloaders/datasets/pascal.py ADDED
@@ -0,0 +1,145 @@
+ from __future__ import print_function, division
+ import os
+ from PIL import Image
+ import numpy as np
+ from torch.utils.data import Dataset
+ from mypath import Path
+ from torchvision import transforms
+ from dataloaders import custom_transforms as tr
+
+ class VOCSegmentation(Dataset):
+     """
+     PascalVoc dataset
+     """
+     NUM_CLASSES = 21
+
+     def __init__(self,
+                  args,
+                  base_dir=Path.db_root_dir('pascal'),
+                  split='train',
+                  ):
+         """
+         :param base_dir: path to VOC dataset directory
+         :param split: train/val
+         :param transform: transform to apply
+         """
+         super().__init__()
+         self._base_dir = base_dir
+         self._image_dir = os.path.join(self._base_dir, 'JPEGImages')
+         self._cat_dir = os.path.join(self._base_dir, 'SegmentationClass')
+
+         if isinstance(split, str):
+             self.split = [split]
+         else:
+             split.sort()
+             self.split = split
+
+         self.args = args
+
+         _splits_dir = os.path.join(self._base_dir, 'ImageSets', 'Segmentation')
+
+         self.im_ids = []
+         self.images = []
+         self.categories = []
+
+         for splt in self.split:
+             with open(os.path.join(_splits_dir, splt + '.txt'), "r") as f:
+                 lines = f.read().splitlines()
+
+             for ii, line in enumerate(lines):
+                 _image = os.path.join(self._image_dir, line + ".jpg")
+                 _cat = os.path.join(self._cat_dir, line + ".png")
+                 assert os.path.isfile(_image)
+                 assert os.path.isfile(_cat)
+                 self.im_ids.append(line)
+                 self.images.append(_image)
+                 self.categories.append(_cat)
+
+         assert (len(self.images) == len(self.categories))
+
+         # Display stats
+         print('Number of images in {}: {:d}'.format(split, len(self.images)))
+
+     def __len__(self):
+         return len(self.images)
+
+     def __getitem__(self, index):
+         _img, _target = self._make_img_gt_point_pair(index)
+         sample = {'image': _img, 'label': _target}
+
+         for split in self.split:
+             if split == "train":
+                 return self.transform_tr(sample)
+             elif split == 'val':
+                 return self.transform_val(sample)
+
+     def _make_img_gt_point_pair(self, index):
+         _img = Image.open(self.images[index]).convert('RGB')
+         _target = Image.open(self.categories[index])
+
+         return _img, _target
+
+     def transform_tr(self, sample):
+         composed_transforms = transforms.Compose([
+             tr.RandomHorizontalFlip(),
+             tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
+             tr.RandomGaussianBlur(),
+             tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+             tr.ToTensor()])
+
+         return composed_transforms(sample)
+
+     def transform_val(self, sample):
+
+         composed_transforms = transforms.Compose([
+             tr.FixScaleCrop(crop_size=self.args.crop_size),
+             tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+             tr.ToTensor()])
+
+         return composed_transforms(sample)
+
+     def __str__(self):
+         return 'VOC2012(split=' + str(self.split) + ')'
+
+
+ if __name__ == '__main__':
+     from dataloaders.utils import decode_segmap
+     from torch.utils.data import DataLoader
+     import matplotlib.pyplot as plt
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     args = parser.parse_args()
+     args.base_size = 513
+     args.crop_size = 513
+
+     voc_train = VOCSegmentation(args, split='train')
+
+     dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=0)
+
+     for ii, sample in enumerate(dataloader):
+         for jj in range(sample["image"].size()[0]):
+             img = sample['image'].numpy()
+             gt = sample['label'].numpy()
+             tmp = np.array(gt[jj]).astype(np.uint8)
+             segmap = decode_segmap(tmp, dataset='pascal')
+             img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
+             img_tmp *= (0.229, 0.224, 0.225)
+             img_tmp += (0.485, 0.456, 0.406)
+             img_tmp *= 255.0
+             img_tmp = img_tmp.astype(np.uint8)
+             plt.figure()
+             plt.title('display')
+             plt.subplot(211)
+             plt.imshow(img_tmp)
+             plt.subplot(212)
+             plt.imshow(segmap)
+
+         if ii == 1:
+             break
+
+     plt.show(block=True)
dataloaders/datasets/sbd.py ADDED
@@ -0,0 +1,129 @@
+ from __future__ import print_function, division
+ import os
+
+ import numpy as np
+ import scipy.io
+ import torch.utils.data as data
+ from PIL import Image
+ from mypath import Path
+
+ from torchvision import transforms
+ from dataloaders import custom_transforms as tr
+
+ class SBDSegmentation(data.Dataset):
+     NUM_CLASSES = 21
+
+     def __init__(self,
+                  args,
+                  base_dir=Path.db_root_dir('sbd'),
+                  split='train',
+                  ):
+         """
+         :param base_dir: path to SBD dataset directory
+         :param split: train/val
+         :param transform: transform to apply
+         """
+         super().__init__()
+         self._base_dir = base_dir
+         self._dataset_dir = os.path.join(self._base_dir, 'dataset')
+         self._image_dir = os.path.join(self._dataset_dir, 'img')
+         self._cat_dir = os.path.join(self._dataset_dir, 'cls')
+
+         if isinstance(split, str):
+             self.split = [split]
+         else:
+             split.sort()
+             self.split = split
+
+         self.args = args
+
+         # Get list of all images from the split and check that the files exist
+         self.im_ids = []
+         self.images = []
+         self.categories = []
+         for splt in self.split:
+             with open(os.path.join(self._dataset_dir, splt + '.txt'), "r") as f:
+                 lines = f.read().splitlines()
+
+             for line in lines:
+                 _image = os.path.join(self._image_dir, line + ".jpg")
+                 _categ = os.path.join(self._cat_dir, line + ".mat")
+                 assert os.path.isfile(_image)
+                 assert os.path.isfile(_categ)
+                 self.im_ids.append(line)
+                 self.images.append(_image)
+                 self.categories.append(_categ)
+
+         assert (len(self.images) == len(self.categories))
+
+         # Display stats
+         print('Number of images: {:d}'.format(len(self.images)))
+
+     def __getitem__(self, index):
+         _img, _target = self._make_img_gt_point_pair(index)
+         sample = {'image': _img, 'label': _target}
+
+         return self.transform(sample)
+
+     def __len__(self):
+         return len(self.images)
+
+     def _make_img_gt_point_pair(self, index):
+         _img = Image.open(self.images[index]).convert('RGB')
+         _target = Image.fromarray(scipy.io.loadmat(self.categories[index])["GTcls"][0]['Segmentation'][0])
+
+         return _img, _target
+
+     def transform(self, sample):
+         composed_transforms = transforms.Compose([
+             tr.RandomHorizontalFlip(),
+             tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
+             tr.RandomGaussianBlur(),
+             tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+             tr.ToTensor()])
+
+         return composed_transforms(sample)
+
+     def __str__(self):
+         return 'SBDSegmentation(split=' + str(self.split) + ')'
+
+
+ if __name__ == '__main__':
+     from dataloaders.utils import decode_segmap
+     from torch.utils.data import DataLoader
+     import matplotlib.pyplot as plt
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     args = parser.parse_args()
+     args.base_size = 513
+     args.crop_size = 513
+
+     sbd_train = SBDSegmentation(args, split='train')
+     dataloader = DataLoader(sbd_train, batch_size=2, shuffle=True, num_workers=2)
+
+     for ii, sample in enumerate(dataloader):
+         for jj in range(sample["image"].size()[0]):
+             img = sample['image'].numpy()
+             gt = sample['label'].numpy()
+             tmp = np.array(gt[jj]).astype(np.uint8)
+             segmap = decode_segmap(tmp, dataset='pascal')
+             img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
+             img_tmp *= (0.229, 0.224, 0.225)
+             img_tmp += (0.485, 0.456, 0.406)
+             img_tmp *= 255.0
+             img_tmp = img_tmp.astype(np.uint8)
+             plt.figure()
+             plt.title('display')
+             plt.subplot(211)
+             plt.imshow(img_tmp)
+             plt.subplot(212)
+             plt.imshow(segmap)
+
+         if ii == 1:
+             break
+
+     plt.show(block=True)
dataloaders/utils.py ADDED
@@ -0,0 +1,111 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+
+ def decode_seg_map_sequence(label_masks, dataset='pascal'):
+     rgb_masks = []
+     for label_mask in label_masks:
+         rgb_mask = decode_segmap(label_mask, dataset)
+         rgb_masks.append(rgb_mask)
+     rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2]))
+     return rgb_masks
+
+
+ def decode_segmap(label_mask, dataset, plot=False):
+     """Decode segmentation class labels into a color image
+     Args:
+         label_mask (np.ndarray): an (M,N) array of integer values denoting
+           the class label at each spatial location.
+         dataset (str): one of 'pascal', 'coco', 'cityscapes', or 'invoice',
+           selecting the label-color mapping.
+         plot (bool, optional): whether to show the resulting color image
+           in a figure.
+     Returns:
+         (np.ndarray, optional): the resulting decoded color image.
+     """
+     if dataset == 'pascal' or dataset == 'coco':
+         n_classes = 21
+         label_colours = get_pascal_labels()
+     elif dataset == 'cityscapes':
+         n_classes = 19
+         label_colours = get_cityscapes_labels()
+     elif dataset == 'invoice':
+         n_classes = 2
+         label_colours = get_invoice_labels()
+     else:
+         raise NotImplementedError
+
+     r = label_mask.copy()
+     g = label_mask.copy()
+     b = label_mask.copy()
+     for ll in range(0, n_classes):
+         r[label_mask == ll] = label_colours[ll, 0]
+         g[label_mask == ll] = label_colours[ll, 1]
+         b[label_mask == ll] = label_colours[ll, 2]
+     rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
+     rgb[:, :, 0] = r / 255.0
+     rgb[:, :, 1] = g / 255.0
+     rgb[:, :, 2] = b / 255.0
+     if plot:
+         plt.imshow(rgb)
+         plt.show()
+     else:
+         return rgb
+
+
+ def encode_segmap(mask):
+     """Encode segmentation label images as pascal classes
+     Args:
+         mask (np.ndarray): raw segmentation label image of dimension
+           (M, N, 3), in which the Pascal classes are encoded as colours.
+     Returns:
+         (np.ndarray): class map with dimensions (M,N), where the value at
+         a given location is the integer denoting the class index.
+     """
+     mask = mask.astype(int)
+     label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
+     for ii, label in enumerate(get_pascal_labels()):
+         label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
+     label_mask = label_mask.astype(int)
+     return label_mask
+
+
+ def get_cityscapes_labels():
+     return np.array([
+         [128, 64, 128],
+         [244, 35, 232],
+         [70, 70, 70],
+         [102, 102, 156],
+         [190, 153, 153],
+         [153, 153, 153],
+         [250, 170, 30],
+         [220, 220, 0],
+         [107, 142, 35],
+         [152, 251, 152],
+         [0, 130, 180],
+         [220, 20, 60],
+         [255, 0, 0],
+         [0, 0, 142],
+         [0, 0, 70],
+         [0, 60, 100],
+         [0, 80, 100],
+         [0, 0, 230],
+         [119, 11, 32]])
+
+
+ def get_pascal_labels():
+     """Load the mapping that associates pascal classes with label colors
+     Returns:
+         np.ndarray with dimensions (21, 3)
+     """
+     return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
+                        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
+                        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
+                        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
+                        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
+                        [0, 64, 128]])
+
+ def get_invoice_labels():
+     """Load the mapping that associates invoice classes with label colors
+     Returns:
+         np.ndarray with dimensions (2, 3)
+     """
+     return np.asarray([[0, 0, 0], [255, 255, 255]])
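
A tiny sketch (not part of the commit) of decode_segmap on the two-class invoice mapping: class 0 decodes to black, class 1 to white, and the output is scaled to [0, 1].

    import numpy as np
    from dataloaders.utils import decode_segmap

    label_mask = np.array([[0, 1],
                           [1, 0]])
    rgb = decode_segmap(label_mask, dataset='invoice')
    print(rgb[0, 0], rgb[0, 1])  # [0. 0. 0.] [1. 1. 1.]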
deeplab-mobilenet.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9a36ba48f39fc6edc161335211b15d9250cadb521f1cb958cb6d014399093f31
+ size 46666796
deeplab-resnet.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3c1ca4610f1ff8c118b451aa0ab30048554a9e77b794f7174808c457e935913a
+ size 474903453
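
The two .pth.tar entries above are Git LFS pointer files: the repository stores only the version, oid, and size fields, while the real weights live in LFS storage. A small sketch (not part of the commit) parsing that format:

    def parse_lfs_pointer(text):
        # each pointer line is "<key> <value>"
        fields = {}
        for line in text.strip().splitlines():
            key, _, value = line.strip().partition(' ')
            fields[key] = value
        return fields

    pointer = (
        "version https://git-lfs.github.com/spec/v1\n"
        "oid sha256:3c1ca4610f1ff8c118b451aa0ab30048554a9e77b794f7174808c457e935913a\n"
        "size 474903453"
    )
    info = parse_lfs_pointer(pointer)
    print(info['size'])  # '474903453'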
deeplab_demo.py ADDED
@@ -0,0 +1,111 @@
+ # Description: This script is used to extract the specified category from the image using the trained DeepLabV3+ model.
+ # file name: deeplab_demo.py
+
+ import argparse
+ import time
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+ from torchvision.utils import make_grid, save_image
+ from torchvision.transforms import ToTensor, ToPILImage
+
+ from modeling.deeplab import *
+ from dataloaders import custom_transforms as tr
+ from dataloaders.utils import *
+
+ def get_people(newimage):
+     # define the argument parser for configuring the model
+     parser = argparse.ArgumentParser(description="PyTorch DeeplabV3Plus Training")
+     parser.add_argument('--in-path', type=str, default="img", help='image to test')
+     # parser.add_argument('--out-path', type=str, required=True, help='mask image to save')
+     parser.add_argument('--backbone', type=str, default='mobilenet',
+                         choices=['resnet', 'xception', 'drn', 'mobilenet'],
+                         help='backbone name (default: mobilenet)')
+     parser.add_argument('--ckpt', type=str, default='deeplab-mobilenet.pth.tar',
+                         help='saved model')
+     parser.add_argument('--out-stride', type=int, default=8,
+                         help='network output stride (default: 8)')
+     parser.add_argument('--no-cuda', action='store_true', default=False,
+                         help='disables CUDA training')
+     parser.add_argument('--gpu-ids', type=str, default='0',
+                         help='use which gpu to train, must be a \
+                         comma-separated list of integers only (default=0)')
+     parser.add_argument('--dataset', type=str, default='invoice',
+                         choices=['pascal', 'coco', 'cityscapes', 'invoice'],
+                         help='dataset name (default: invoice)')
+     parser.add_argument('--crop-size', type=int, default=512,
+                         help='crop image size')
+     parser.add_argument('--num_classes', type=int, default=21,
+                         help='number of classes (default: 21)')
+     parser.add_argument('--sync-bn', type=bool, default=None,
+                         help='whether to use sync bn (default: auto)')
+     parser.add_argument('--freeze-bn', type=bool, default=False,
+                         help='whether to freeze bn parameters (default: False)')
+     args = parser.parse_args()
+     args.cuda = not args.no_cuda and torch.cuda.is_available()
+     if args.cuda:
+         try:
+             args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
+         except ValueError:
+             raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')
+
+     if args.sync_bn is None:
+         if args.cuda and len(args.gpu_ids) > 1:
+             args.sync_bn = True
+         else:
+             args.sync_bn = False
+     model_s_time = time.time()
+     model = DeepLab(num_classes=args.num_classes,
+                     backbone=args.backbone,
+                     output_stride=args.out_stride,
+                     sync_bn=args.sync_bn,
+                     freeze_bn=args.freeze_bn)
+
+     ckpt = torch.load(args.ckpt, map_location='cpu')
+     model.load_state_dict(ckpt['state_dict'])
+     # model = model.cuda()
+     model_u_time = time.time()
+     model_load_time = model_u_time - model_s_time
+     print("model load time is {}".format(model_load_time))
+
+     composed_transforms = transforms.Compose([
+         tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+         tr.ToTensor()])
+
+     image = newimage
+     s_time = time.time()
+     # the custom transforms expect an {'image', 'label'} pair; reuse the input as a dummy label
+     target = newimage
+     sample = {'image': image, 'label': target}
+     tensor_in = composed_transforms(sample)['image'].unsqueeze(0)
+
+     model.eval()
+     if args.cuda:
+         tensor_in = tensor_in.cuda()
+     with torch.no_grad():
+         output = model(tensor_in)
+
+     # Get the per-pixel class index
+     pred = torch.max(output, 1)[1].detach().cpu().numpy()
+     # Specify the category label to extract
+     target_class = 15  # 15 = "person" in Pascal VOC; replace with the category index you want to extract
+     mask = (pred == target_class).astype(np.uint8).squeeze()
+     # Apply the mask to the original image
+     image_np = np.array(image)
+     masked_image = image_np * mask[:, :, np.newaxis]
+
+     # save the masked area
+     masked_image_pil = Image.fromarray(masked_image)
+     grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy()),
+                            3, normalize=False)
+     u_time = time.time()
+     img_time = u_time - s_time
+     print("time: {} ".format(img_time))
+
+     return masked_image_pil, grid_image
+
+ # mypath = r'img/people.jpg'
+ # image = Image.open(mypath).convert('RGB')
+ # result, mask = get_people(image)
+ # result_tensor = ToTensor()(result)
+ # save_image(result_tensor, "masked.png")
+
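
A standalone usage sketch of get_people, completing the commented-out lines above; it assumes img/people.jpg and the deeplab-mobilenet.pth.tar checkpoint from this commit are present in the working directory:

# Sketch: run the extractor end-to-end and save both outputs.
from PIL import Image
from torchvision.transforms import ToTensor
from torchvision.utils import save_image
from deeplab_demo import get_people

image = Image.open('img/people.jpg').convert('RGB')
masked, grid = get_people(image)            # masked person (PIL) + segmentation grid (tensor)
save_image(ToTensor()(masked), 'img/masked.png')
save_image(grid, 'img/mask.png')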
end.py ADDED
@@ -0,0 +1,90 @@
+ # filename: end.py
+ # Description: This is the main file of the project. It is used to create the Gradio interface and run the application.
+
+ import gradio as gr
+ from PIL import Image
+ from deeplab_demo import get_people
+ from creat_anaglyph import insert_person_to_stereo_gradio
+ import torch
+ from torchvision.transforms import ToPILImage
+
+ # Process the person image and return the masked person plus the raw segmentation grid
+ def process_person_image(person_image):
+     masked_image_pil, grid_image = get_people(person_image)
+
+     if isinstance(masked_image_pil, torch.Tensor):
+         masked_image_pil = ToPILImage()(masked_image_pil)
+     if isinstance(grid_image, torch.Tensor):
+         grid_image = ToPILImage()(grid_image)
+
+     return masked_image_pil, grid_image
+
+ # Generate the anaglyph image
+ def generate_anaglyph(masked_image_pil, scenery_image, depth_option, custom_disparity):
+     # Default disparities for the non-custom options: close, medium, far
+     depth_disparities = {
+         "close": 10,  # Adjust values as needed
+         "medium": 5,
+         "far": 2
+     }
+
+     # Use custom_disparity only if depth_option is "custom"
+     disparity = custom_disparity if depth_option == "custom" else depth_disparities.get(depth_option, 5)
+
+     # Ensure inputs are PIL images
+     if isinstance(masked_image_pil, torch.Tensor):
+         masked_image_pil = ToPILImage()(masked_image_pil)
+     if isinstance(scenery_image, torch.Tensor):
+         scenery_image = ToPILImage()(scenery_image)
+
+     anaglyph_image = insert_person_to_stereo_gradio(scenery_image, masked_image_pil, disparity)
+
+     if isinstance(anaglyph_image, torch.Tensor):
+         anaglyph_image = ToPILImage()(anaglyph_image)
+
+     return anaglyph_image
+
+ # Create Gradio interface
+ with gr.Blocks() as iface:
+     with gr.Row():
+         person_image_input = gr.Image(type="pil", label="Character image")
+         scenery_image_input = gr.Image(type="pil", label="Landscape image")
+         depth_option_input = gr.Dropdown(choices=["close", "medium", "far", "custom"], label="Depth Options")
+         custom_disparity_input = gr.Slider(minimum=0, maximum=50, step=1, label="Custom Depth Disparity", visible=False)
+
+     with gr.Row():
+         grid_image_output = gr.Image(type="pil", label="Grid", interactive=False)
+         masked_image_output = gr.Image(type="pil", label="Masked", interactive=False)
+         anaglyph_image_output = gr.Image(type="pil", label="Anaglyph", interactive=False)
+
+     # button 1: process the person image
+     process_button = gr.Button("Process person image")
+     process_button.click(
+         fn=process_person_image,
+         inputs=person_image_input,
+         outputs=[masked_image_output, grid_image_output]
+     )
+
+     # Show the custom-disparity slider only when the "custom" depth option is selected
+     def update_custom_slider_visibility(depth_option):
+         return gr.update(visible=(depth_option == "custom"))
+
+     depth_option_input.change(
+         fn=update_custom_slider_visibility,
+         inputs=[depth_option_input],
+         outputs=custom_disparity_input
+     )
+
+     # button 2: generate the anaglyph image
+     generate_button = gr.Button("Generate Anaglyph Image")
+     generate_button.click(
+         fn=generate_anaglyph,
+         inputs=[masked_image_output, scenery_image_input, depth_option_input, custom_disparity_input],
+         outputs=anaglyph_image_output
+     )
+
+ # Launch the Gradio interface
+ # changed from iface.launch()
+ iface.launch(share=True)
+
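
The disparity passed to insert_person_to_stereo_gradio controls how far apart the two eye views are shifted, which is what makes the inserted person read as close (10), medium (5), or far (2). creat_anaglyph.py is not shown in this view, so the following is only an illustrative standalone sketch of the red/cyan principle, not the project's actual implementation:

# Sketch: red/cyan anaglyph composition (NOT the code in creat_anaglyph.py).
# A larger horizontal shift between the two eye views reads as "closer".
import numpy as np
from PIL import Image

def toy_anaglyph(scene: Image.Image, disparity: int) -> Image.Image:
    arr = np.array(scene.convert('RGB'))
    left = np.roll(arr, disparity, axis=1)  # shifted view for the left eye
    out = arr.copy()
    out[:, :, 0] = left[:, :, 0]            # red channel from the left view
    return Image.fromarray(out)             # green/blue stay from the right view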
img/mask.png ADDED
img/masked.png ADDED
img/people.jpg ADDED
img/scenery.jpg ADDED

Git LFS Details

  • SHA256: 0567943542c1b4f1e8b272d6eb2e7ec4b4bf4605375d7d0b0e49f43a7065e552
  • Pointer size: 132 Bytes
  • Size of remote file: 2.96 MB
img/scenery2.jpg ADDED

Git LFS Details

  • SHA256: a943acacd8426172c90b7ca380c6b9ef8ef50d53d78a2f7c16a34ae6d14a067e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.97 MB
modeling/__init__.py ADDED
File without changes
modeling/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (159 Bytes). View file
 
modeling/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (142 Bytes). View file
 
modeling/__pycache__/aspp.cpython-310.pyc ADDED
Binary file (3.08 kB). View file