Spaces:
Running
Upload 87 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +5 -0
- __pycache__/creat_anaglyph.cpython-38.pyc +0 -0
- __pycache__/deeplab_demo.cpython-38.pyc +0 -0
- __pycache__/mypath.cpython-38.pyc +0 -0
- anaglyph.png +3 -0
- app.py +96 -0
- creat_anaglyph.py +149 -0
- dataloaders/__init__.py +56 -0
- dataloaders/__pycache__/__init__.cpython-310.pyc +0 -0
- dataloaders/__pycache__/__init__.cpython-38.pyc +0 -0
- dataloaders/__pycache__/custom_transforms.cpython-310.pyc +0 -0
- dataloaders/__pycache__/custom_transforms.cpython-38.pyc +0 -0
- dataloaders/__pycache__/utils.cpython-310.pyc +0 -0
- dataloaders/__pycache__/utils.cpython-38.pyc +0 -0
- dataloaders/custom_transforms.py +165 -0
- dataloaders/datasets/__init__.py +0 -0
- dataloaders/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- dataloaders/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
- dataloaders/datasets/__pycache__/cityscapes.cpython-310.pyc +0 -0
- dataloaders/datasets/__pycache__/cityscapes.cpython-38.pyc +0 -0
- dataloaders/datasets/__pycache__/coco.cpython-310.pyc +0 -0
- dataloaders/datasets/__pycache__/coco.cpython-38.pyc +0 -0
- dataloaders/datasets/__pycache__/combine_dbs.cpython-310.pyc +0 -0
- dataloaders/datasets/__pycache__/combine_dbs.cpython-38.pyc +0 -0
- dataloaders/datasets/__pycache__/invoice.cpython-310.pyc +0 -0
- dataloaders/datasets/__pycache__/invoice.cpython-38.pyc +0 -0
- dataloaders/datasets/__pycache__/pascal.cpython-310.pyc +0 -0
- dataloaders/datasets/__pycache__/pascal.cpython-38.pyc +0 -0
- dataloaders/datasets/__pycache__/sbd.cpython-310.pyc +0 -0
- dataloaders/datasets/__pycache__/sbd.cpython-38.pyc +0 -0
- dataloaders/datasets/cityscapes.py +146 -0
- dataloaders/datasets/coco.py +160 -0
- dataloaders/datasets/combine_dbs.py +100 -0
- dataloaders/datasets/invoice.py +145 -0
- dataloaders/datasets/pascal.py +145 -0
- dataloaders/datasets/sbd.py +129 -0
- dataloaders/utils.py +111 -0
- deeplab-mobilenet.pth.tar +3 -0
- deeplab-resnet.pth.tar +3 -0
- deeplab_demo.py +111 -0
- end.py +90 -0
- img/mask.png +0 -0
- img/masked.png +0 -0
- img/people.jpg +0 -0
- img/scenery.jpg +3 -0
- img/scenery2.jpg +3 -0
- modeling/__init__.py +0 -0
- modeling/__pycache__/__init__.cpython-310.pyc +0 -0
- modeling/__pycache__/__init__.cpython-38.pyc +0 -0
- modeling/__pycache__/aspp.cpython-310.pyc +0 -0
.gitattributes
ADDED
@@ -0,0 +1,5 @@
anaglyph.png filter=lfs diff=lfs merge=lfs -text
deeplab-mobilenet.pth.tar filter=lfs diff=lfs merge=lfs -text
deeplab-resnet.pth.tar filter=lfs diff=lfs merge=lfs -text
img/scenery.jpg filter=lfs diff=lfs merge=lfs -text
img/scenery2.jpg filter=lfs diff=lfs merge=lfs -text
__pycache__/creat_anaglyph.cpython-38.pyc
ADDED
Binary file (2.53 kB)

__pycache__/deeplab_demo.cpython-38.pyc
ADDED
Binary file (3.46 kB)

__pycache__/mypath.cpython-38.pyc
ADDED
Binary file (812 Bytes)
anaglyph.png
ADDED
Git LFS Details
app.py
ADDED
@@ -0,0 +1,96 @@
# app.py — main file for running the application on Hugging Face.
# Equivalent to end.py; Hugging Face recognizes app.py as the entry point for a Space.

import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import gradio as gr
from PIL import Image
from deeplab_demo import get_people
from creat_anaglyph import insert_person_to_stereo_gradio
import torch
from torchvision.transforms import ToPILImage


# Process the person image: segment the person and return the masked image and a preview grid
def process_person_image(person_image):
    masked_image_pil, grid_image = get_people(person_image)

    if isinstance(masked_image_pil, torch.Tensor):
        masked_image_pil = ToPILImage()(masked_image_pil)
    if isinstance(grid_image, torch.Tensor):
        grid_image = ToPILImage()(grid_image)

    return masked_image_pil, grid_image

# Generate the anaglyph image
def generate_anaglyph(masked_image_pil, scenery_image, depth_option, custom_disparity):
    # Default disparities for the non-custom options: close, medium, far
    depth_disparities = {
        "close": 10,  # adjust values as needed
        "medium": 5,
        "far": 2
    }

    # Use custom_disparity only if depth_option is "custom"
    disparity = custom_disparity if depth_option == "custom" else depth_disparities.get(depth_option, 5)

    # Ensure the inputs are PIL images
    if isinstance(masked_image_pil, torch.Tensor):
        masked_image_pil = ToPILImage()(masked_image_pil)
    if isinstance(scenery_image, torch.Tensor):
        scenery_image = ToPILImage()(scenery_image)

    anaglyph_image = insert_person_to_stereo_gradio(scenery_image, masked_image_pil, disparity)

    if isinstance(anaglyph_image, torch.Tensor):
        anaglyph_image = ToPILImage()(anaglyph_image)

    return anaglyph_image

# Create the Gradio interface
with gr.Blocks() as iface:
    with gr.Row():
        person_image_input = gr.Image(type="pil", label="Person image")
        scenery_image_input = gr.Image(type="pil", label="Scenery image")
        depth_option_input = gr.Dropdown(choices=["close", "medium", "far", "custom"], label="Depth Options")
        custom_disparity_input = gr.Slider(minimum=0, maximum=50, step=1, label="Custom Depth Disparity", visible=False)

    with gr.Row():
        grid_image_output = gr.Image(type="pil", label="Grid", interactive=False)
        masked_image_output = gr.Image(type="pil", label="Masked", interactive=False)
        anaglyph_image_output = gr.Image(type="pil", label="Anaglyph", interactive=False)

    # Button 1: process the person image
    process_button = gr.Button("Process person image")
    process_button.click(
        fn=process_person_image,
        inputs=person_image_input,
        outputs=[masked_image_output, grid_image_output]
    )

    # Show the custom disparity slider only when the "custom" depth option is selected
    def update_custom_slider_visibility(depth_option):
        return gr.update(visible=(depth_option == "custom"))

    depth_option_input.change(
        fn=update_custom_slider_visibility,
        inputs=[depth_option_input],
        outputs=custom_disparity_input
    )

    # Button 2: generate the anaglyph image
    generate_button = gr.Button("Generate Anaglyph Image")
    generate_button.click(
        fn=generate_anaglyph,
        inputs=[masked_image_output, scenery_image_input, depth_option_input, custom_disparity_input],
        outputs=anaglyph_image_output
    )

# Launch the Gradio interface
iface.launch()
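For context, the same two-step flow can be exercised without the UI. A minimal sketch, assuming get_people (defined in deeplab_demo.py, added in this commit) accepts a PIL image and returns the masked image and preview grid; if it returns tensors, convert them with ToPILImage exactly as app.py does above:

from PIL import Image
from deeplab_demo import get_people
from creat_anaglyph import insert_person_to_stereo_gradio

person = Image.open('img/people.jpg')      # sample person image shipped in this commit
masked, grid = get_people(person)          # step 1: segment the person
scenery = Image.open('img/scenery.jpg')    # sample side-by-side stereo scenery
anaglyph = insert_person_to_stereo_gradio(scenery, masked, 5)  # step 2: composite at "medium" depth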
creat_anaglyph.py
ADDED
@@ -0,0 +1,149 @@
# file: creat_anaglyph.py
# Description: This script creates a red-cyan anaglyph stereo image by inserting a person into a stereo image.

from PIL import Image
import numpy as np
import torch
from torchvision.transforms import ToPILImage

# Preprocess the person image to make its black background transparent
def preprocess_person_image(person_image_path):
    # uploaded person image
    person_image = Image.open(person_image_path).convert('RGBA')
    data = np.array(person_image)

    # separate color channels
    r, g, b, a = data.T

    # define the threshold for the black background
    black_threshold = 1
    black_areas = (r < black_threshold) & (g < black_threshold) & (b < black_threshold)

    # set the black background to transparent
    data[..., 3][black_areas.T] = 0  # only modify the alpha channel

    # create a new image
    transparent_image = Image.fromarray(data)
    return transparent_image


# Gradio-compatible version of preprocess_person_image
def preprocess_person_image_gradio(person_image):
    # ensure the image is in RGBA mode
    if person_image.mode != 'RGBA':
        person_image = person_image.convert('RGBA')

    # load the person image as an array
    data = np.array(person_image)

    # separate color channels
    r, g, b, a = data.T

    # define the threshold for the black background
    black_threshold = 1
    black_areas = (r < black_threshold) & (g < black_threshold) & (b < black_threshold)

    # set the black background to transparent
    data[..., 3][black_areas.T] = 0  # only modify the alpha channel

    # create a new image
    transparent_image = Image.fromarray(data)
    return transparent_image

def insert_person_to_stereo(stereo_image_path, person_image_path, depth_option):
    # load the stitched (side-by-side) stereo image
    stereo_image = Image.open(stereo_image_path).convert('RGB')
    width, height = stereo_image.size

    # assume the stitched image is symmetrical
    left_image = stereo_image.crop((0, 0, width // 2, height))
    right_image = stereo_image.crop((width // 2, 0, width, height))

    # preprocess the person image
    person_image = preprocess_person_image(person_image_path)
    person_width, person_height = person_image.size

    # define disparity options based on image width
    max_disparity = width // 20
    disparity_options = {
        'close': max_disparity // 5,
        'medium': max_disparity // 15,
        'far': max_disparity // 20
    }

    # get the corresponding disparity value
    disparity = disparity_options.get(depth_option, max_disparity // 2)

    # align the bottom of the person image with the bottom of the scene and center it horizontally
    x_position = (width // 4) - (person_width // 2) + disparity
    y_position = height - person_height

    # insert the person image into the left and right views
    left_image.paste(person_image, (x_position, y_position), person_image)
    right_image.paste(person_image, (x_position - disparity, y_position), person_image)

    # combine the left and right views into a red-cyan stereo image
    left_array = np.array(left_image)    # convert the left image to an array
    right_array = np.array(right_image)  # convert the right image to an array

    # create a red-cyan stereo image
    anaglyph = np.zeros_like(left_array)
    anaglyph[..., 0] = left_array[..., 0]   # red channel from the left image
    anaglyph[..., 1] = right_array[..., 1]  # green channel from the right image
    anaglyph[..., 2] = right_array[..., 2]  # blue channel from the right image

    # convert to an image and save
    anaglyph_image = Image.fromarray(anaglyph)  # convert the array to an image
    anaglyph_image.save('anaglyph.png')         # save the image


# Gradio-compatible version of insert_person_to_stereo
def insert_person_to_stereo_gradio(stereo_image, person_image, disparity):
    # ensure the inputs are in the expected modes
    if person_image.mode != "RGBA":
        person_image = person_image.convert("RGBA")
    if stereo_image.mode != 'RGB':
        stereo_image = stereo_image.convert('RGB')
    width, height = stereo_image.size

    # assume the stitched image is symmetrical
    left_image = stereo_image.crop((0, 0, width // 2, height))
    right_image = stereo_image.crop((width // 2, 0, width, height))

    # preprocess the person image
    person_image = preprocess_person_image_gradio(person_image)
    person_width, person_height = person_image.size

    # align the bottom of the person image with the bottom of the scene and center it horizontally
    x_position = (width // 4) - (person_width // 2) + disparity
    y_position = height - person_height

    # paste the person image into the left and right views
    left_image.paste(person_image, (x_position, y_position), person_image)
    right_image.paste(person_image, (x_position - disparity, y_position), person_image)

    # combine the left and right views into a red-cyan stereo image
    left_array = np.array(left_image)
    right_array = np.array(right_image)

    # create a red-cyan stereo image
    anaglyph = np.zeros_like(left_array)
    anaglyph[..., 0] = left_array[..., 0]   # red channel from the left image
    anaglyph[..., 1] = right_array[..., 1]  # green channel from the right image
    anaglyph[..., 2] = right_array[..., 2]  # blue channel from the right image

    # convert to an image and return
    anaglyph_image = Image.fromarray(anaglyph)
    return anaglyph_image


# Example (guarded so it does not run when app.py imports this module)
if __name__ == '__main__':
    insert_person_to_stereo('img/scenery.jpg', 'img/masked.png', 'far')
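To make the path-based depth options concrete: insert_person_to_stereo derives its disparities from the stitched image's width, and the larger the shift between the left and right placements, the closer the person appears through red-cyan glasses. For an illustrative 2000 px wide pair the mapping works out as follows (plain arithmetic, no project code needed):

width = 2000
max_disparity = width // 20     # 100
print(max_disparity // 5)       # close  -> 20 px shift between the two views
print(max_disparity // 15)      # medium -> 6 px
print(max_disparity // 20)      # far    -> 5 px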
dataloaders/__init__.py
ADDED
@@ -0,0 +1,56 @@
from dataloaders.datasets import cityscapes, coco, combine_dbs, pascal, sbd, invoice
from torch.utils.data import DataLoader

def make_data_loader(args, **kwargs):

    if args.dataset == 'invoice':
        train_set = invoice.VOCSegmentation(args, split='train')
        val_set = invoice.VOCSegmentation(args, split='val')
        if args.use_sbd:
            sbd_train = sbd.SBDSegmentation(args, split=['train', 'val'])
            train_set = combine_dbs.CombineDBs([train_set, sbd_train], excluded=[val_set])

        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = None

        return train_loader, val_loader, test_loader, num_class

    elif args.dataset == 'pascal':
        train_set = pascal.VOCSegmentation(args, split='train')
        val_set = pascal.VOCSegmentation(args, split='val')
        if args.use_sbd:
            sbd_train = sbd.SBDSegmentation(args, split=['train', 'val'])
            train_set = combine_dbs.CombineDBs([train_set, sbd_train], excluded=[val_set])

        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = None

        return train_loader, val_loader, test_loader, num_class

    elif args.dataset == 'cityscapes':
        train_set = cityscapes.CityscapesSegmentation(args, split='train')
        val_set = cityscapes.CityscapesSegmentation(args, split='val')
        test_set = cityscapes.CityscapesSegmentation(args, split='test')
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs)

        return train_loader, val_loader, test_loader, num_class

    elif args.dataset == 'coco':
        train_set = coco.COCOSegmentation(args, split='train')
        val_set = coco.COCOSegmentation(args, split='val')
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = None
        return train_loader, val_loader, test_loader, num_class

    else:
        raise NotImplementedError
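A minimal sketch of driving make_data_loader; the Namespace fields listed are the ones the function and the dataset transforms actually read, and the values are illustrative (the datasets must exist at the paths configured in mypath.py):

import argparse
from dataloaders import make_data_loader

args = argparse.Namespace(dataset='pascal', use_sbd=False, batch_size=4,
                          base_size=513, crop_size=513)
# extra kwargs (e.g. num_workers) are forwarded to torch's DataLoader
train_loader, val_loader, test_loader, num_class = make_data_loader(args, num_workers=0)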
dataloaders/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.46 kB)

dataloaders/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (1.45 kB)

dataloaders/__pycache__/custom_transforms.cpython-310.pyc
ADDED
Binary file (5.23 kB)

dataloaders/__pycache__/custom_transforms.cpython-38.pyc
ADDED
Binary file (5.32 kB)

dataloaders/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.85 kB)

dataloaders/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (3.42 kB)
dataloaders/custom_transforms.py
ADDED
@@ -0,0 +1,165 @@
import torch
import random
import numpy as np

from PIL import Image, ImageOps, ImageFilter

class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Args:
        mean (tuple): means for each channel.
        std (tuple): standard deviations for each channel.
    """
    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        img = np.array(img).astype(np.float32)
        mask = np.array(mask).astype(np.float32)
        img /= 255.0
        img -= self.mean
        img /= self.std

        return {'image': img,
                'label': mask}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        img = sample['image']
        mask = sample['label']
        img = np.array(img).astype(np.float32).transpose((2, 0, 1))
        mask = np.array(mask).astype(np.float32)

        img = torch.from_numpy(img).float()
        mask = torch.from_numpy(mask).float()

        return {'image': img,
                'label': mask}


class RandomHorizontalFlip(object):
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        return {'image': img,
                'label': mask}


class RandomRotate(object):
    def __init__(self, degree):
        self.degree = degree

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        rotate_degree = random.uniform(-1 * self.degree, self.degree)
        img = img.rotate(rotate_degree, Image.BILINEAR)
        mask = mask.rotate(rotate_degree, Image.NEAREST)

        return {'image': img,
                'label': mask}


class RandomGaussianBlur(object):
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(
                radius=random.random()))

        return {'image': img,
                'label': mask}


class RandomScaleCrop(object):
    def __init__(self, base_size, crop_size, fill=0):
        self.base_size = base_size
        self.crop_size = crop_size
        self.fill = fill

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        # random scale (short edge)
        short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        else:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad crop
        if short_size < self.crop_size:
            padh = self.crop_size - oh if oh < self.crop_size else 0
            padw = self.crop_size - ow if ow < self.crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
        # random crop crop_size
        w, h = img.size
        x1 = random.randint(0, w - self.crop_size)
        y1 = random.randint(0, h - self.crop_size)
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))

        return {'image': img,
                'label': mask}


class FixScaleCrop(object):
    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        w, h = img.size
        if w > h:
            oh = self.crop_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = self.crop_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # center crop
        w, h = img.size
        x1 = int(round((w - self.crop_size) / 2.))
        y1 = int(round((h - self.crop_size) / 2.))
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))

        return {'image': img,
                'label': mask}

class FixedResize(object):
    def __init__(self, size):
        self.size = (size, size)  # size: (h, w)

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']

        assert img.size == mask.size

        img = img.resize(self.size, Image.BILINEAR)
        mask = mask.resize(self.size, Image.NEAREST)

        return {'image': img,
                'label': mask}
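Note that these transforms act on {'image': ..., 'label': ...} sample dicts of paired PIL images rather than single images, so image and mask stay geometrically aligned through every random operation. A minimal sketch with synthetic placeholders:

from PIL import Image
from torchvision import transforms
from dataloaders import custom_transforms as tr

sample = {'image': Image.new('RGB', (600, 400)), 'label': Image.new('L', (600, 400))}
pipeline = transforms.Compose([
    tr.RandomHorizontalFlip(),
    tr.RandomScaleCrop(base_size=513, crop_size=513, fill=255),
    tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    tr.ToTensor()])
out = pipeline(sample)
print(out['image'].shape, out['label'].shape)  # torch.Size([3, 513, 513]) torch.Size([513, 513])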
dataloaders/datasets/__init__.py
ADDED
File without changes

dataloaders/datasets/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (171 Bytes)

dataloaders/datasets/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (154 Bytes)

dataloaders/datasets/__pycache__/cityscapes.cpython-310.pyc
ADDED
Binary file (5.28 kB)

dataloaders/datasets/__pycache__/cityscapes.cpython-38.pyc
ADDED
Binary file (5.43 kB)

dataloaders/datasets/__pycache__/coco.cpython-310.pyc
ADDED
Binary file (5.38 kB)

dataloaders/datasets/__pycache__/coco.cpython-38.pyc
ADDED
Binary file (5.4 kB)

dataloaders/datasets/__pycache__/combine_dbs.cpython-310.pyc
ADDED
Binary file (3.19 kB)

dataloaders/datasets/__pycache__/combine_dbs.cpython-38.pyc
ADDED
Binary file (3.17 kB)

dataloaders/datasets/__pycache__/invoice.cpython-310.pyc
ADDED
Binary file (4.35 kB)

dataloaders/datasets/__pycache__/invoice.cpython-38.pyc
ADDED
Binary file (4.31 kB)

dataloaders/datasets/__pycache__/pascal.cpython-310.pyc
ADDED
Binary file (4.35 kB)

dataloaders/datasets/__pycache__/pascal.cpython-38.pyc
ADDED
Binary file (4.31 kB)

dataloaders/datasets/__pycache__/sbd.cpython-310.pyc
ADDED
Binary file (4.01 kB)

dataloaders/datasets/__pycache__/sbd.cpython-38.pyc
ADDED
Binary file (3.97 kB)
dataloaders/datasets/cityscapes.py
ADDED
@@ -0,0 +1,146 @@
import os
import numpy as np
import scipy.misc as m
from PIL import Image
from torch.utils import data
from mypath import Path
from torchvision import transforms
from dataloaders import custom_transforms as tr

class CityscapesSegmentation(data.Dataset):
    NUM_CLASSES = 19

    def __init__(self, args, root=Path.db_root_dir('cityscapes'), split="train"):

        self.root = root
        self.split = split
        self.args = args
        self.files = {}

        self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
        self.annotations_base = os.path.join(self.root, 'gtFine_trainvaltest', 'gtFine', self.split)

        self.files[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png')

        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
        self.class_names = ['unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence',
                            'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain',
                            'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
                            'motorcycle', 'bicycle']

        self.ignore_index = 255
        self.class_map = dict(zip(self.valid_classes, range(self.NUM_CLASSES)))

        if not self.files[split]:
            raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))

        print("Found %d %s images" % (len(self.files[split]), split))

    def __len__(self):
        return len(self.files[self.split])

    def __getitem__(self, index):

        img_path = self.files[self.split][index].rstrip()
        lbl_path = os.path.join(self.annotations_base,
                                img_path.split(os.sep)[-2],
                                os.path.basename(img_path)[:-15] + 'gtFine_labelIds.png')

        _img = Image.open(img_path).convert('RGB')
        _tmp = np.array(Image.open(lbl_path), dtype=np.uint8)
        _tmp = self.encode_segmap(_tmp)
        _target = Image.fromarray(_tmp)

        sample = {'image': _img, 'label': _target}

        if self.split == 'train':
            return self.transform_tr(sample)
        elif self.split == 'val':
            return self.transform_val(sample)
        elif self.split == 'test':
            return self.transform_ts(sample)

    def encode_segmap(self, mask):
        # Map all void classes to the ignore index and renumber the valid classes 0..18
        for _voidc in self.void_classes:
            mask[mask == _voidc] = self.ignore_index
        for _validc in self.valid_classes:
            mask[mask == _validc] = self.class_map[_validc]
        return mask

    def recursive_glob(self, rootdir='.', suffix=''):
        """Performs recursive glob with given suffix and rootdir
        :param rootdir: the root directory
        :param suffix: the suffix to be searched
        """
        return [os.path.join(looproot, filename)
                for looproot, _, filenames in os.walk(rootdir)
                for filename in filenames if filename.endswith(suffix)]

    def transform_tr(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

    def transform_val(self, sample):

        composed_transforms = transforms.Compose([
            tr.FixScaleCrop(crop_size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

    def transform_ts(self, sample):

        composed_transforms = transforms.Compose([
            tr.FixedResize(size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

if __name__ == '__main__':
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    import argparse

    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.base_size = 513
    args.crop_size = 513

    cityscapes_train = CityscapesSegmentation(args, split='train')

    dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
            tmp = np.array(gt[jj]).astype(np.uint8)
            segmap = decode_segmap(tmp, dataset='cityscapes')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
            img_tmp *= (0.229, 0.224, 0.225)
            img_tmp += (0.485, 0.456, 0.406)
            img_tmp *= 255.0
            img_tmp = img_tmp.astype(np.uint8)
            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        if ii == 1:
            break

    plt.show(block=True)
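The heart of this loader is encode_segmap: the 30+ raw Cityscapes label ids are collapsed to 19 training classes, with void ids mapped to the ignore index 255. A quick standalone check of the remapping, with the lists copied from the class above:

valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
class_map = dict(zip(valid_classes, range(19)))
print(class_map[7])   # raw id 7  ("road")    -> train id 0
print(class_map[33])  # raw id 33 ("bicycle") -> train id 18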
dataloaders/datasets/coco.py
ADDED
@@ -0,0 +1,160 @@
import numpy as np
import torch
from torch.utils.data import Dataset
from mypath import Path
from tqdm import trange
import os
from pycocotools.coco import COCO
from pycocotools import mask
from torchvision import transforms
from dataloaders import custom_transforms as tr
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


class COCOSegmentation(Dataset):
    NUM_CLASSES = 21
    CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
                1, 64, 20, 63, 7, 72]

    def __init__(self,
                 args,
                 base_dir=Path.db_root_dir('coco'),
                 split='train',
                 year='2017'):
        super().__init__()
        ann_file = os.path.join(base_dir, 'annotations/instances_{}{}.json'.format(split, year))
        ids_file = os.path.join(base_dir, 'annotations/{}_ids_{}.pth'.format(split, year))
        self.img_dir = os.path.join(base_dir, 'images/{}{}'.format(split, year))
        self.split = split
        self.coco = COCO(ann_file)
        self.coco_mask = mask
        if os.path.exists(ids_file):
            self.ids = torch.load(ids_file)
        else:
            ids = list(self.coco.imgs.keys())
            self.ids = self._preprocess(ids, ids_file)
        self.args = args

    def __getitem__(self, index):
        _img, _target = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}

        if self.split == "train":
            return self.transform_tr(sample)
        elif self.split == 'val':
            return self.transform_val(sample)

    def _make_img_gt_point_pair(self, index):
        coco = self.coco
        img_id = self.ids[index]
        img_metadata = coco.loadImgs(img_id)[0]
        path = img_metadata['file_name']
        _img = Image.open(os.path.join(self.img_dir, path)).convert('RGB')
        cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        _target = Image.fromarray(self._gen_seg_mask(
            cocotarget, img_metadata['height'], img_metadata['width']))

        return _img, _target

    def _preprocess(self, ids, ids_file):
        print("Preprocessing mask, this will take a while. " +
              "But don't worry, it only runs once for each split.")
        tbar = trange(len(ids))
        new_ids = []
        for i in tbar:
            img_id = ids[i]
            cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
            img_metadata = self.coco.loadImgs(img_id)[0]
            mask = self._gen_seg_mask(cocotarget, img_metadata['height'],
                                      img_metadata['width'])
            # keep only images with more than 1k foreground pixels
            if (mask > 0).sum() > 1000:
                new_ids.append(img_id)
            tbar.set_description('Doing: {}/{}, got {} qualified images'.
                                 format(i, len(ids), len(new_ids)))
        print('Found number of qualified images: ', len(new_ids))
        torch.save(new_ids, ids_file)
        return new_ids

    def _gen_seg_mask(self, target, h, w):
        mask = np.zeros((h, w), dtype=np.uint8)
        coco_mask = self.coco_mask
        for instance in target:
            rle = coco_mask.frPyObjects(instance['segmentation'], h, w)
            m = coco_mask.decode(rle)
            cat = instance['category_id']
            if cat in self.CAT_LIST:
                c = self.CAT_LIST.index(cat)
            else:
                continue
            if len(m.shape) < 3:
                mask[:, :] += (mask == 0) * (m * c)
            else:
                mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
        return mask

    def transform_tr(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

    def transform_val(self, sample):

        composed_transforms = transforms.Compose([
            tr.FixScaleCrop(crop_size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)


    def __len__(self):
        return len(self.ids)



if __name__ == "__main__":
    from dataloaders import custom_transforms as tr
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt
    import argparse

    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.base_size = 513
    args.crop_size = 513

    coco_val = COCOSegmentation(args, split='val', year='2017')

    dataloader = DataLoader(coco_val, batch_size=4, shuffle=True, num_workers=0)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
            tmp = np.array(gt[jj]).astype(np.uint8)
            segmap = decode_segmap(tmp, dataset='coco')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
            img_tmp *= (0.229, 0.224, 0.225)
            img_tmp += (0.485, 0.456, 0.406)
            img_tmp *= 255.0
            img_tmp = img_tmp.astype(np.uint8)
            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        if ii == 1:
            break

    plt.show(block=True)
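CAT_LIST encodes the VOC-to-COCO class correspondence: the position of a COCO category id in the list is the VOC label that _gen_seg_mask writes into the mask. A quick check, with the list copied from the class above:

CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
            1, 64, 20, 63, 7, 72]
print(CAT_LIST.index(1))   # COCO id 1 ("person")   -> VOC label 15
print(CAT_LIST.index(5))   # COCO id 5 ("airplane") -> VOC label 1 (aeroplane)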
dataloaders/datasets/combine_dbs.py
ADDED
@@ -0,0 +1,100 @@
import torch.utils.data as data


class CombineDBs(data.Dataset):
    NUM_CLASSES = 21
    def __init__(self, dataloaders, excluded=None):
        self.dataloaders = dataloaders
        self.excluded = excluded
        self.im_ids = []

        # Combine object lists
        for dl in dataloaders:
            for elem in dl.im_ids:
                if elem not in self.im_ids:
                    self.im_ids.append(elem)

        # Exclude
        if excluded:
            for dl in excluded:
                for elem in dl.im_ids:
                    if elem in self.im_ids:
                        self.im_ids.remove(elem)

        # Get object pointers
        self.cat_list = []
        self.im_list = []
        new_im_ids = []
        num_images = 0
        for ii, dl in enumerate(dataloaders):
            for jj, curr_im_id in enumerate(dl.im_ids):
                if (curr_im_id in self.im_ids) and (curr_im_id not in new_im_ids):
                    num_images += 1
                    new_im_ids.append(curr_im_id)
                    self.cat_list.append({'db_ii': ii, 'cat_ii': jj})

        self.im_ids = new_im_ids
        print('Combined number of images: {:d}'.format(num_images))

    def __getitem__(self, index):

        _db_ii = self.cat_list[index]["db_ii"]
        _cat_ii = self.cat_list[index]['cat_ii']
        sample = self.dataloaders[_db_ii].__getitem__(_cat_ii)

        if 'meta' in sample.keys():
            sample['meta']['db'] = str(self.dataloaders[_db_ii])

        return sample

    def __len__(self):
        return len(self.cat_list)

    def __str__(self):
        include_db = [str(db) for db in self.dataloaders]
        exclude_db = [str(db) for db in self.excluded]
        return 'Included datasets:' + str(include_db) + '\n' + 'Excluded datasets:' + str(exclude_db)


if __name__ == "__main__":
    import matplotlib.pyplot as plt
    from dataloaders.datasets import pascal, sbd
    import torch
    import numpy as np
    from dataloaders.utils import decode_segmap
    import argparse

    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.base_size = 513
    args.crop_size = 513

    pascal_voc_val = pascal.VOCSegmentation(args, split='val')
    sbd = sbd.SBDSegmentation(args, split=['train', 'val'])
    pascal_voc_train = pascal.VOCSegmentation(args, split='train')

    dataset = CombineDBs([pascal_voc_train, sbd], excluded=[pascal_voc_val])
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
            tmp = np.array(gt[jj]).astype(np.uint8)
            segmap = decode_segmap(tmp, dataset='pascal')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
            img_tmp *= (0.229, 0.224, 0.225)
            img_tmp += (0.485, 0.456, 0.406)
            img_tmp *= 255.0
            img_tmp = img_tmp.astype(np.uint8)
            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        if ii == 1:
            break
    plt.show(block=True)
dataloaders/datasets/invoice.py
ADDED
@@ -0,0 +1,145 @@
from __future__ import print_function, division
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
from mypath import Path
from torchvision import transforms
from dataloaders import custom_transforms as tr

class VOCSegmentation(Dataset):
    """
    Invoice dataset stored in Pascal-VOC layout
    """
    NUM_CLASSES = 2

    def __init__(self,
                 args,
                 base_dir=Path.db_root_dir('invoice'),
                 split='train',
                 ):
        """
        :param base_dir: path to VOC dataset directory
        :param split: train/val
        :param transform: transform to apply
        """
        super().__init__()
        self._base_dir = base_dir
        self._image_dir = os.path.join(self._base_dir, 'JPEGImages')
        self._cat_dir = os.path.join(self._base_dir, 'SegmentationClass')

        if isinstance(split, str):
            self.split = [split]
        else:
            split.sort()
            self.split = split

        self.args = args

        _splits_dir = os.path.join(self._base_dir, 'ImageSets', 'Segmentation')

        self.im_ids = []
        self.images = []
        self.categories = []

        for splt in self.split:
            with open(os.path.join(os.path.join(_splits_dir, splt + '.txt')), "r") as f:
                lines = f.read().splitlines()

            for ii, line in enumerate(lines):
                _image = os.path.join(self._image_dir, line + ".png")
                _cat = os.path.join(self._cat_dir, line + ".png")
                assert os.path.isfile(_image)
                assert os.path.isfile(_cat)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)

        assert (len(self.images) == len(self.categories))

        # Display stats
        print('Number of images in {}: {:d}'.format(split, len(self.images)))

    def __len__(self):
        return len(self.images)


    def __getitem__(self, index):
        _img, _target = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}

        for split in self.split:
            if split == "train":
                return self.transform_tr(sample)
            elif split == 'val':
                return self.transform_val(sample)


    def _make_img_gt_point_pair(self, index):
        _img = Image.open(self.images[index]).convert('RGB')
        _target = Image.open(self.categories[index])

        return _img, _target

    def transform_tr(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

    def transform_val(self, sample):

        composed_transforms = transforms.Compose([
            tr.FixScaleCrop(crop_size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

    def __str__(self):
        return 'VOC2012(split=' + str(self.split) + ')'


if __name__ == '__main__':
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    import argparse

    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.base_size = 512
    args.crop_size = 512

    voc_train = VOCSegmentation(args, split='train')

    dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=0)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
            tmp = np.array(gt[jj]).astype(np.uint8)
            segmap = decode_segmap(tmp, dataset='invoice')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
            img_tmp *= (0.229, 0.224, 0.225)
            img_tmp += (0.485, 0.456, 0.406)
            img_tmp *= 255.0
            img_tmp = img_tmp.astype(np.uint8)
            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        if ii == 1:
            break

    plt.show(block=True)
dataloaders/datasets/pascal.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import print_function, division
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
from mypath import Path
from torchvision import transforms
from dataloaders import custom_transforms as tr


class VOCSegmentation(Dataset):
    """
    PascalVoc dataset
    """
    NUM_CLASSES = 21

    def __init__(self,
                 args,
                 base_dir=Path.db_root_dir('pascal'),
                 split='train',
                 ):
        """
        :param base_dir: path to VOC dataset directory
        :param split: train/val
        :param transform: transform to apply
        """
        super().__init__()
        self._base_dir = base_dir
        self._image_dir = os.path.join(self._base_dir, 'JPEGImages')
        self._cat_dir = os.path.join(self._base_dir, 'SegmentationClass')

        if isinstance(split, str):
            self.split = [split]
        else:
            split.sort()
            self.split = split

        self.args = args

        _splits_dir = os.path.join(self._base_dir, 'ImageSets', 'Segmentation')

        self.im_ids = []
        self.images = []
        self.categories = []

        for splt in self.split:
            with open(os.path.join(_splits_dir, splt + '.txt'), "r") as f:
                lines = f.read().splitlines()

            for line in lines:
                _image = os.path.join(self._image_dir, line + ".jpg")
                _cat = os.path.join(self._cat_dir, line + ".png")
                assert os.path.isfile(_image)
                assert os.path.isfile(_cat)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)

        assert (len(self.images) == len(self.categories))

        # Display stats
        print('Number of images in {}: {:d}'.format(split, len(self.images)))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        _img, _target = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}

        for split in self.split:
            if split == "train":
                return self.transform_tr(sample)
            elif split == 'val':
                return self.transform_val(sample)

    def _make_img_gt_point_pair(self, index):
        _img = Image.open(self.images[index]).convert('RGB')
        _target = Image.open(self.categories[index])

        return _img, _target

    def transform_tr(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

    def transform_val(self, sample):
        composed_transforms = transforms.Compose([
            tr.FixScaleCrop(crop_size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

    def __str__(self):
        return 'VOC2012(split=' + str(self.split) + ')'


if __name__ == '__main__':
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    import argparse

    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.base_size = 513
    args.crop_size = 513

    voc_train = VOCSegmentation(args, split='train')

    dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=0)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
            tmp = np.array(gt[jj]).astype(np.uint8)
            segmap = decode_segmap(tmp, dataset='pascal')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
            img_tmp *= (0.229, 0.224, 0.225)
            img_tmp += (0.485, 0.456, 0.406)
            img_tmp *= 255.0
            img_tmp = img_tmp.astype(np.uint8)
            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        if ii == 1:
            break

    plt.show(block=True)
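Note: the `__main__` block above inverts `tr.Normalize` by hand (scale by the std, add the mean, rescale to 0-255). A minimal sketch of that inversion as a reusable helper, assuming the same ImageNet statistics used in `transform_tr` (the `denormalize` name is illustrative, not part of the repo):

import numpy as np

def denormalize(chw_tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    # Illustrative helper (not in the repo): invert tr.Normalize for display,
    # mirroring the hand-written loop in the __main__ block above.
    img = np.transpose(chw_tensor.numpy(), axes=[1, 2, 0])  # CHW -> HWC
    img = (img * std + mean) * 255.0                        # undo normalization
    return np.clip(img, 0, 255).astype(np.uint8)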
dataloaders/datasets/sbd.py
ADDED
@@ -0,0 +1,129 @@
from __future__ import print_function, division
import os

import numpy as np
import scipy.io
import torch.utils.data as data
from PIL import Image
from mypath import Path

from torchvision import transforms
from dataloaders import custom_transforms as tr


class SBDSegmentation(data.Dataset):
    NUM_CLASSES = 21

    def __init__(self,
                 args,
                 base_dir=Path.db_root_dir('sbd'),
                 split='train',
                 ):
        """
        :param base_dir: path to SBD dataset directory
        :param split: train/val
        :param transform: transform to apply
        """
        super().__init__()
        self._base_dir = base_dir
        self._dataset_dir = os.path.join(self._base_dir, 'dataset')
        self._image_dir = os.path.join(self._dataset_dir, 'img')
        self._cat_dir = os.path.join(self._dataset_dir, 'cls')

        if isinstance(split, str):
            self.split = [split]
        else:
            split.sort()
            self.split = split

        self.args = args

        # Get list of all images from the split and check that the files exist
        self.im_ids = []
        self.images = []
        self.categories = []
        for splt in self.split:
            with open(os.path.join(self._dataset_dir, splt + '.txt'), "r") as f:
                lines = f.read().splitlines()

            for line in lines:
                _image = os.path.join(self._image_dir, line + ".jpg")
                _categ = os.path.join(self._cat_dir, line + ".mat")
                assert os.path.isfile(_image)
                assert os.path.isfile(_categ)
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_categ)

        assert (len(self.images) == len(self.categories))

        # Display stats
        print('Number of images: {:d}'.format(len(self.images)))

    def __getitem__(self, index):
        _img, _target = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}

        return self.transform(sample)

    def __len__(self):
        return len(self.images)

    def _make_img_gt_point_pair(self, index):
        _img = Image.open(self.images[index]).convert('RGB')
        _target = Image.fromarray(scipy.io.loadmat(self.categories[index])["GTcls"][0]['Segmentation'][0])

        return _img, _target

    def transform(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

    def __str__(self):
        return 'SBDSegmentation(split=' + str(self.split) + ')'


if __name__ == '__main__':
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    import argparse

    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.base_size = 513
    args.crop_size = 513

    sbd_train = SBDSegmentation(args, split='train')
    dataloader = DataLoader(sbd_train, batch_size=2, shuffle=True, num_workers=2)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
            tmp = np.array(gt[jj]).astype(np.uint8)
            segmap = decode_segmap(tmp, dataset='pascal')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
            img_tmp *= (0.229, 0.224, 0.225)
            img_tmp += (0.485, 0.456, 0.406)
            img_tmp *= 255.0
            img_tmp = img_tmp.astype(np.uint8)
            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        if ii == 1:
            break

    plt.show(block=True)
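Note: SBD ground truth is stored in MATLAB `.mat` files, and the nested `["GTcls"][0]['Segmentation'][0]` indexing in `_make_img_gt_point_pair` unpacks the struct array that `scipy.io.loadmat` returns. A quick inspection sketch to confirm the label range of one annotation (path and file name are placeholders):

import numpy as np
import scipy.io

mat = scipy.io.loadmat('path/to/sbd/dataset/cls/2008_000002.mat')  # placeholder path
seg = mat['GTcls'][0]['Segmentation'][0]     # (H, W) uint8 class-index map
print(seg.shape, seg.dtype, np.unique(seg))  # Pascal-style class ids in 0..20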
dataloaders/utils.py
ADDED
@@ -0,0 +1,111 @@
import matplotlib.pyplot as plt
import numpy as np
import torch


def decode_seg_map_sequence(label_masks, dataset='pascal'):
    rgb_masks = []
    for label_mask in label_masks:
        rgb_mask = decode_segmap(label_mask, dataset)
        rgb_masks.append(rgb_mask)
    rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2]))
    return rgb_masks


def decode_segmap(label_mask, dataset, plot=False):
    """Decode segmentation class labels into a color image
    Args:
        label_mask (np.ndarray): an (M,N) array of integer values denoting
          the class label at each spatial location.
        dataset (str): one of 'pascal', 'coco', 'cityscapes' or 'invoice',
          selecting the colour palette.
        plot (bool, optional): whether to show the resulting color image
          in a figure.
    Returns:
        (np.ndarray, optional): the resulting decoded color image.
    """
    if dataset == 'pascal' or dataset == 'coco':
        n_classes = 21
        label_colours = get_pascal_labels()
    elif dataset == 'cityscapes':
        n_classes = 19
        label_colours = get_cityscapes_labels()
    elif dataset == 'invoice':
        n_classes = 2
        label_colours = get_invoice_labels()
    else:
        raise NotImplementedError

    r = label_mask.copy()
    g = label_mask.copy()
    b = label_mask.copy()
    for ll in range(0, n_classes):
        r[label_mask == ll] = label_colours[ll, 0]
        g[label_mask == ll] = label_colours[ll, 1]
        b[label_mask == ll] = label_colours[ll, 2]
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
    rgb[:, :, 0] = r / 255.0
    rgb[:, :, 1] = g / 255.0
    rgb[:, :, 2] = b / 255.0
    if plot:
        plt.imshow(rgb)
        plt.show()
    else:
        return rgb


def encode_segmap(mask):
    """Encode segmentation label images as pascal classes
    Args:
        mask (np.ndarray): raw segmentation label image of dimension
          (M, N, 3), in which the Pascal classes are encoded as colours.
    Returns:
        (np.ndarray): class map with dimensions (M,N), where the value at
          a given location is the integer denoting the class index.
    """
    mask = mask.astype(int)
    label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
    for ii, label in enumerate(get_pascal_labels()):
        label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
    label_mask = label_mask.astype(int)
    return label_mask


def get_cityscapes_labels():
    return np.array([
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [0, 130, 180],
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32]])


def get_pascal_labels():
    """Load the mapping that associates pascal classes with label colors
    Returns:
        np.ndarray with dimensions (21, 3)
    """
    return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                       [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                       [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                       [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                       [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                       [0, 64, 128]])


def get_invoice_labels():
    """Load the mapping that associates invoice classes with label colors
    Returns:
        np.ndarray with dimensions (2, 3)
    """
    return np.asarray([[0, 0, 0], [255, 255, 255]])
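Note: `encode_segmap` and `decode_segmap` are near-inverses for the Pascal palette: one maps a colour mask to class indices, the other maps indices back to colours scaled into [0, 1]. A small round-trip sketch under that assumption:

import numpy as np
from dataloaders.utils import decode_segmap, encode_segmap

label = np.zeros((4, 4), dtype=np.uint8)
label[1:3, 1:3] = 15                           # Pascal VOC class 15 ('person')
rgb = decode_segmap(label, dataset='pascal')   # float image in [0, 1]
back = encode_segmap(np.rint(rgb * 255.0))     # round back to the palette colours
assert (back == label).all()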
deeplab-mobilenet.pth.tar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9a36ba48f39fc6edc161335211b15d9250cadb521f1cb958cb6d014399093f31
size 46666796
deeplab-resnet.pth.tar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3c1ca4610f1ff8c118b451aa0ab30048554a9e77b794f7174808c457e935913a
size 474903453
deeplab_demo.py
ADDED
@@ -0,0 +1,111 @@
# Description: This script is used to extract the specified category from the image using the trained DeepLabV3+ model.
# file name: deeplab_demo.py


import argparse
import time
from modeling.deeplab import *
from dataloaders import custom_transforms as tr
from PIL import Image
from torchvision import transforms
from dataloaders.utils import *
from torchvision.utils import make_grid, save_image
from torchvision.transforms import ToTensor, ToPILImage


def get_people(newimage):
    # Define the argument parser for configuring the model
    parser = argparse.ArgumentParser(description="PyTorch DeeplabV3Plus Training")
    parser.add_argument('--in-path', type=str, default="img", help='image to test')
    # parser.add_argument('--out-path', type=str, required=True, help='mask image to save')
    parser.add_argument('--backbone', type=str, default='mobilenet',
                        choices=['resnet', 'xception', 'drn', 'mobilenet'],
                        help='backbone name (default: mobilenet)')
    parser.add_argument('--ckpt', type=str, default='deeplab-mobilenet.pth.tar',
                        help='saved model')
    parser.add_argument('--out-stride', type=int, default=8,
                        help='network output stride (default: 8)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--gpu-ids', type=str, default='0',
                        help='use which gpu to train, must be a \
                        comma-separated list of integers only (default=0)')
    parser.add_argument('--dataset', type=str, default='invoice',
                        choices=['pascal', 'coco', 'cityscapes', 'invoice'],
                        help='dataset name (default: invoice)')
    parser.add_argument('--crop-size', type=int, default=512,
                        help='crop image size')
    parser.add_argument('--num_classes', type=int, default=21,
                        help='number of classes (default: 21)')
    parser.add_argument('--sync-bn', type=bool, default=None,
                        help='whether to use sync bn (default: auto)')
    parser.add_argument('--freeze-bn', type=bool, default=False,
                        help='whether to freeze bn parameters (default: False)')
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        try:
            args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
        except ValueError:
            raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')

    if args.sync_bn is None:
        if args.cuda and len(args.gpu_ids) > 1:
            args.sync_bn = True
        else:
            args.sync_bn = False

    # Build the model and load the checkpoint (on CPU)
    model_s_time = time.time()
    model = DeepLab(num_classes=args.num_classes,
                    backbone=args.backbone,
                    output_stride=args.out_stride,
                    sync_bn=args.sync_bn,
                    freeze_bn=args.freeze_bn)

    ckpt = torch.load(args.ckpt, map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    # model = model.cuda()
    model_u_time = time.time()
    model_load_time = model_u_time - model_s_time
    print("model load time is {}".format(model_load_time))

    composed_transforms = transforms.Compose([
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        tr.ToTensor()])

    image = newimage
    s_time = time.time()
    target = newimage
    sample = {'image': image, 'label': target}
    tensor_in = composed_transforms(sample)['image'].unsqueeze(0)

    model.eval()
    if args.cuda:
        tensor_in = tensor_in.cuda()
    with torch.no_grad():
        output = model(tensor_in)

    # Get the predicted category index at every pixel
    pred = torch.max(output, 1)[1].detach().cpu().numpy()
    # Specify the category label to extract
    target_class = 15  # Pascal VOC class 15 = 'person'; replace with the index you want to extract
    mask = (pred == target_class).astype(np.uint8).squeeze()
    # Apply the mask to the original image
    image_np = np.array(image)
    masked_image = image_np * mask[:, :, np.newaxis]

    # Save the masked area and build a decoded segmentation grid
    masked_image_pil = Image.fromarray(masked_image)
    grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy()),
                           3, normalize=False)
    u_time = time.time()
    img_time = u_time - s_time
    print("time: {} ".format(img_time))

    return masked_image_pil, grid_image


# mypath = r'img/people.jpg'
# image = Image.open(mypath).convert('RGB')
# result, mask = get_people(image)
# result_tensor = ToTensor()(result)
# save_image(result_tensor, "masked.png")
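Note: a sketch of the commented-out usage at the bottom of deeplab_demo.py, run end to end with the sample image shipped in this repo ('mask.png' as an output name for the grid is an assumption; the commented demo only saves 'masked.png'):

from PIL import Image
from deeplab_demo import get_people
from torchvision.transforms import ToTensor
from torchvision.utils import save_image

image = Image.open('img/people.jpg').convert('RGB')  # sample image from this repo
masked, grid = get_people(image)                     # PIL cut-out + decoded mask grid (tensor)
save_image(ToTensor()(masked), 'masked.png')         # same output as the commented demo
save_image(grid, 'mask.png')                         # decoded segmentation map (assumed name)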
end.py
ADDED
@@ -0,0 +1,90 @@
# filename: end.py
# Description: This is the main file of the project. It is used to create the Gradio interface and run the application.

import gradio as gr
from PIL import Image
from deeplab_demo import get_people
from creat_anaglyph import insert_person_to_stereo_gradio
import torch
from torchvision.transforms import ToPILImage


# Process the person image and return the cut-out and the segmentation grid
def process_person_image(person_image):
    masked_image_pil, grid_image = get_people(person_image)

    if isinstance(masked_image_pil, torch.Tensor):
        masked_image_pil = ToPILImage()(masked_image_pil)
    if isinstance(grid_image, torch.Tensor):
        grid_image = ToPILImage()(grid_image)

    return masked_image_pil, grid_image


# Generate the anaglyph image from the cut-out and the scenery
def generate_anaglyph(masked_image_pil, scenery_image, depth_option, custom_disparity):
    # Default disparities for the non-custom options: close, medium, far
    depth_disparities = {
        "close": 10,  # Adjust values as needed
        "medium": 5,
        "far": 2
    }

    # Use custom_disparity only if depth_option is "custom"
    disparity = custom_disparity if depth_option == "custom" else depth_disparities.get(depth_option, 5)

    # Ensure inputs are PIL images
    if isinstance(masked_image_pil, torch.Tensor):
        masked_image_pil = ToPILImage()(masked_image_pil)
    if isinstance(scenery_image, torch.Tensor):
        scenery_image = ToPILImage()(scenery_image)

    anaglyph_image = insert_person_to_stereo_gradio(scenery_image, masked_image_pil, disparity)

    if isinstance(anaglyph_image, torch.Tensor):
        anaglyph_image = ToPILImage()(anaglyph_image)

    return anaglyph_image


# Create the Gradio interface
with gr.Blocks() as iface:
    with gr.Row():
        person_image_input = gr.Image(type="pil", label="Character image")
        scenery_image_input = gr.Image(type="pil", label="Landscape image")
        depth_option_input = gr.Dropdown(choices=["close", "medium", "far", "custom"], label="Depth Options")
        custom_disparity_input = gr.Slider(minimum=0, maximum=50, step=1, label="Custom Depth Disparity", visible=False)

    with gr.Row():
        grid_image_output = gr.Image(type="pil", label="Grid", interactive=False)
        masked_image_output = gr.Image(type="pil", label="Masked", interactive=False)
        anaglyph_image_output = gr.Image(type="pil", label="Anaglyph", interactive=False)

    # Button 1: process the character image
    process_button = gr.Button("Process human image")
    process_button.click(
        fn=process_person_image,
        inputs=person_image_input,
        outputs=[masked_image_output, grid_image_output]
    )

    # Show the custom disparity slider only when the "custom" depth option is selected
    def update_custom_slider_visibility(depth_option):
        return gr.update(visible=(depth_option == "custom"))

    depth_option_input.change(
        fn=update_custom_slider_visibility,
        inputs=[depth_option_input],
        outputs=custom_disparity_input
    )

    # Button 2: generate the anaglyph image
    generate_button = gr.Button("Generate Anaglyph Image")
    generate_button.click(
        fn=generate_anaglyph,
        inputs=[masked_image_output, scenery_image_input, depth_option_input, custom_disparity_input],
        outputs=anaglyph_image_output
    )

# Launch the Gradio interface (changed from iface.launch() to enable a public share link)
iface.launch(share=True)
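Note: the same pipeline can be exercised without the UI. A sketch that calls the two stages directly (it assumes `insert_person_to_stereo_gradio(scenery, person, disparity)` takes the disparity as its third positional argument, as in `generate_anaglyph` above, and that 5 matches the "medium" preset):

import torch
from PIL import Image
from torchvision.transforms import ToPILImage
from deeplab_demo import get_people
from creat_anaglyph import insert_person_to_stereo_gradio

person = Image.open('img/people.jpg').convert('RGB')
scenery = Image.open('img/scenery.jpg').convert('RGB')

masked, _grid = get_people(person)                             # person cut-out on black
anaglyph = insert_person_to_stereo_gradio(scenery, masked, 5)  # 5 = the "medium" preset
if isinstance(anaglyph, torch.Tensor):                         # end.py performs the same check
    anaglyph = ToPILImage()(anaglyph)
anaglyph.save('anaglyph.png')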
img/mask.png
ADDED
img/masked.png
ADDED
img/people.jpg
ADDED
img/scenery.jpg
ADDED
Git LFS Details
img/scenery2.jpg
ADDED
Git LFS Details
modeling/__init__.py
ADDED
File without changes
modeling/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (159 Bytes). View file
modeling/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (142 Bytes). View file
modeling/__pycache__/aspp.cpython-310.pyc
ADDED
Binary file (3.08 kB). View file