Add files using upload-large-folder tool
Browse files- ktda/models/segmentors/__init__.py +3 -0
- tools/analysis_tools/browse_dataset.py +77 -0
- tools/dataset_converters/drive.py +114 -0
- tools/dataset_converters/vaihingen.py +156 -0
- tools/dataset_tools/analysis_dataset.py +58 -0
- tools/dataset_tools/process_water.py +152 -0
- tools/deployment/pytorch2torchscript.py +185 -0
- tools/model_converters/beit2mmseg.py +56 -0
ktda/models/segmentors/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .distill_encoder_decoder import DistillEncoderDecoder
|
| 2 |
+
|
| 3 |
+
__all__ = ['DistillEncoderDecoder']
|
tools/analysis_tools/browse_dataset.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import argparse
|
| 3 |
+
import os.path as osp
|
| 4 |
+
|
| 5 |
+
from mmengine.config import Config, DictAction
|
| 6 |
+
from mmengine.utils import ProgressBar
|
| 7 |
+
|
| 8 |
+
from mmseg.registry import DATASETS, VISUALIZERS
|
| 9 |
+
from mmseg.utils import register_all_modules
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def parse_args():
|
| 13 |
+
parser = argparse.ArgumentParser(description='Browse a dataset')
|
| 14 |
+
parser.add_argument('config', help='train config file path')
|
| 15 |
+
parser.add_argument(
|
| 16 |
+
'--output-dir',
|
| 17 |
+
default=None,
|
| 18 |
+
type=str,
|
| 19 |
+
help='If there is no display interface, you can save it')
|
| 20 |
+
parser.add_argument('--not-show', default=False, action='store_true')
|
| 21 |
+
parser.add_argument(
|
| 22 |
+
'--show-interval',
|
| 23 |
+
type=float,
|
| 24 |
+
default=2,
|
| 25 |
+
help='the interval of show (s)')
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
'--cfg-options',
|
| 28 |
+
nargs='+',
|
| 29 |
+
action=DictAction,
|
| 30 |
+
help='override some settings in the used config, the key-value pair '
|
| 31 |
+
'in xxx=yyy format will be merged into config file. If the value to '
|
| 32 |
+
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
|
| 33 |
+
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
| 34 |
+
'Note that the quotation marks are necessary and that no white space '
|
| 35 |
+
'is allowed.')
|
| 36 |
+
args = parser.parse_args()
|
| 37 |
+
return args
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def main():
|
| 41 |
+
args = parse_args()
|
| 42 |
+
cfg = Config.fromfile(args.config)
|
| 43 |
+
if args.cfg_options is not None:
|
| 44 |
+
cfg.merge_from_dict(args.cfg_options)
|
| 45 |
+
|
| 46 |
+
# register all modules in mmdet into the registries
|
| 47 |
+
register_all_modules()
|
| 48 |
+
|
| 49 |
+
dataset = DATASETS.build(cfg.train_dataloader.dataset)
|
| 50 |
+
visualizer = VISUALIZERS.build(cfg.visualizer)
|
| 51 |
+
visualizer.dataset_meta = dataset.metainfo
|
| 52 |
+
|
| 53 |
+
progress_bar = ProgressBar(len(dataset))
|
| 54 |
+
for item in dataset:
|
| 55 |
+
img = item['inputs'].permute(1, 2, 0).numpy()
|
| 56 |
+
img = img[..., [2, 1, 0]] # bgr to rgb
|
| 57 |
+
data_sample = item['data_samples'].numpy()
|
| 58 |
+
img_path = osp.basename(item['data_samples'].img_path)
|
| 59 |
+
|
| 60 |
+
out_file = osp.join(
|
| 61 |
+
args.output_dir,
|
| 62 |
+
osp.basename(img_path)) if args.output_dir is not None else None
|
| 63 |
+
|
| 64 |
+
visualizer.add_datasample(
|
| 65 |
+
name=osp.basename(img_path),
|
| 66 |
+
image=img,
|
| 67 |
+
data_sample=data_sample,
|
| 68 |
+
draw_gt=True,
|
| 69 |
+
draw_pred=False,
|
| 70 |
+
wait_time=args.show_interval,
|
| 71 |
+
out_file=out_file,
|
| 72 |
+
show=not args.not_show)
|
| 73 |
+
progress_bar.update()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if __name__ == '__main__':
|
| 77 |
+
main()
|
tools/dataset_converters/drive.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import os.path as osp
|
| 5 |
+
import tempfile
|
| 6 |
+
import zipfile
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import mmcv
|
| 10 |
+
from mmengine.utils import mkdir_or_exist
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def parse_args():
|
| 14 |
+
parser = argparse.ArgumentParser(
|
| 15 |
+
description='Convert DRIVE dataset to mmsegmentation format')
|
| 16 |
+
parser.add_argument(
|
| 17 |
+
'training_path', help='the training part of DRIVE dataset')
|
| 18 |
+
parser.add_argument(
|
| 19 |
+
'testing_path', help='the testing part of DRIVE dataset')
|
| 20 |
+
parser.add_argument('--tmp_dir', help='path of the temporary directory')
|
| 21 |
+
parser.add_argument('-o', '--out_dir', help='output path')
|
| 22 |
+
args = parser.parse_args()
|
| 23 |
+
return args
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def main():
|
| 27 |
+
args = parse_args()
|
| 28 |
+
training_path = args.training_path
|
| 29 |
+
testing_path = args.testing_path
|
| 30 |
+
if args.out_dir is None:
|
| 31 |
+
out_dir = osp.join('data', 'DRIVE')
|
| 32 |
+
else:
|
| 33 |
+
out_dir = args.out_dir
|
| 34 |
+
|
| 35 |
+
print('Making directories...')
|
| 36 |
+
mkdir_or_exist(out_dir)
|
| 37 |
+
mkdir_or_exist(osp.join(out_dir, 'images'))
|
| 38 |
+
mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
|
| 39 |
+
mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
|
| 40 |
+
mkdir_or_exist(osp.join(out_dir, 'annotations'))
|
| 41 |
+
mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
|
| 42 |
+
mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
|
| 43 |
+
|
| 44 |
+
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
|
| 45 |
+
print('Extracting training.zip...')
|
| 46 |
+
zip_file = zipfile.ZipFile(training_path)
|
| 47 |
+
zip_file.extractall(tmp_dir)
|
| 48 |
+
|
| 49 |
+
print('Generating training dataset...')
|
| 50 |
+
now_dir = osp.join(tmp_dir, 'training', 'images')
|
| 51 |
+
for img_name in os.listdir(now_dir):
|
| 52 |
+
img = mmcv.imread(osp.join(now_dir, img_name))
|
| 53 |
+
mmcv.imwrite(
|
| 54 |
+
img,
|
| 55 |
+
osp.join(
|
| 56 |
+
out_dir, 'images', 'training',
|
| 57 |
+
osp.splitext(img_name)[0].replace('_training', '') +
|
| 58 |
+
'.png'))
|
| 59 |
+
|
| 60 |
+
now_dir = osp.join(tmp_dir, 'training', '1st_manual')
|
| 61 |
+
for img_name in os.listdir(now_dir):
|
| 62 |
+
cap = cv2.VideoCapture(osp.join(now_dir, img_name))
|
| 63 |
+
ret, img = cap.read()
|
| 64 |
+
mmcv.imwrite(
|
| 65 |
+
img[:, :, 0] // 128,
|
| 66 |
+
osp.join(out_dir, 'annotations', 'training',
|
| 67 |
+
osp.splitext(img_name)[0] + '.png'))
|
| 68 |
+
|
| 69 |
+
print('Extracting test.zip...')
|
| 70 |
+
zip_file = zipfile.ZipFile(testing_path)
|
| 71 |
+
zip_file.extractall(tmp_dir)
|
| 72 |
+
|
| 73 |
+
print('Generating validation dataset...')
|
| 74 |
+
now_dir = osp.join(tmp_dir, 'test', 'images')
|
| 75 |
+
for img_name in os.listdir(now_dir):
|
| 76 |
+
img = mmcv.imread(osp.join(now_dir, img_name))
|
| 77 |
+
mmcv.imwrite(
|
| 78 |
+
img,
|
| 79 |
+
osp.join(
|
| 80 |
+
out_dir, 'images', 'validation',
|
| 81 |
+
osp.splitext(img_name)[0].replace('_test', '') + '.png'))
|
| 82 |
+
|
| 83 |
+
now_dir = osp.join(tmp_dir, 'test', '1st_manual')
|
| 84 |
+
if osp.exists(now_dir):
|
| 85 |
+
for img_name in os.listdir(now_dir):
|
| 86 |
+
cap = cv2.VideoCapture(osp.join(now_dir, img_name))
|
| 87 |
+
ret, img = cap.read()
|
| 88 |
+
# The annotation img should be divided by 128, because some of
|
| 89 |
+
# the annotation imgs are not standard. We should set a
|
| 90 |
+
# threshold to convert the nonstandard annotation imgs. The
|
| 91 |
+
# value divided by 128 is equivalent to '1 if value >= 128
|
| 92 |
+
# else 0'
|
| 93 |
+
mmcv.imwrite(
|
| 94 |
+
img[:, :, 0] // 128,
|
| 95 |
+
osp.join(out_dir, 'annotations', 'validation',
|
| 96 |
+
osp.splitext(img_name)[0] + '.png'))
|
| 97 |
+
|
| 98 |
+
now_dir = osp.join(tmp_dir, 'test', '2nd_manual')
|
| 99 |
+
if osp.exists(now_dir):
|
| 100 |
+
for img_name in os.listdir(now_dir):
|
| 101 |
+
cap = cv2.VideoCapture(osp.join(now_dir, img_name))
|
| 102 |
+
ret, img = cap.read()
|
| 103 |
+
mmcv.imwrite(
|
| 104 |
+
img[:, :, 0] // 128,
|
| 105 |
+
osp.join(out_dir, 'annotations', 'validation',
|
| 106 |
+
osp.splitext(img_name)[0] + '.png'))
|
| 107 |
+
|
| 108 |
+
print('Removing the temporary files...')
|
| 109 |
+
|
| 110 |
+
print('Done!')
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
if __name__ == '__main__':
|
| 114 |
+
main()
|
tools/dataset_converters/vaihingen.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import argparse
|
| 3 |
+
import glob
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import os.path as osp
|
| 7 |
+
import tempfile
|
| 8 |
+
import zipfile
|
| 9 |
+
|
| 10 |
+
import mmcv
|
| 11 |
+
import numpy as np
|
| 12 |
+
from mmengine.utils import ProgressBar, mkdir_or_exist
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def parse_args():
|
| 16 |
+
parser = argparse.ArgumentParser(
|
| 17 |
+
description='Convert vaihingen dataset to mmsegmentation format')
|
| 18 |
+
parser.add_argument('dataset_path', help='vaihingen folder path')
|
| 19 |
+
parser.add_argument('--tmp_dir', help='path of the temporary directory')
|
| 20 |
+
parser.add_argument('-o', '--out_dir', help='output path')
|
| 21 |
+
parser.add_argument(
|
| 22 |
+
'--clip_size',
|
| 23 |
+
type=int,
|
| 24 |
+
help='clipped size of image after preparation',
|
| 25 |
+
default=512)
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
'--stride_size',
|
| 28 |
+
type=int,
|
| 29 |
+
help='stride of clipping original images',
|
| 30 |
+
default=256)
|
| 31 |
+
args = parser.parse_args()
|
| 32 |
+
return args
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def clip_big_image(image_path, clip_save_dir, to_label=False):
|
| 36 |
+
# Original image of Vaihingen dataset is very large, thus pre-processing
|
| 37 |
+
# of them is adopted. Given fixed clip size and stride size to generate
|
| 38 |
+
# clipped image, the intersection of width and height is determined.
|
| 39 |
+
# For example, given one 5120 x 5120 original image, the clip size is
|
| 40 |
+
# 512 and stride size is 256, thus it would generate 20x20 = 400 images
|
| 41 |
+
# whose size are all 512x512.
|
| 42 |
+
image = mmcv.imread(image_path)
|
| 43 |
+
|
| 44 |
+
h, w, c = image.shape
|
| 45 |
+
cs = args.clip_size
|
| 46 |
+
ss = args.stride_size
|
| 47 |
+
|
| 48 |
+
num_rows = math.ceil((h - cs) / ss) if math.ceil(
|
| 49 |
+
(h - cs) / ss) * ss + cs >= h else math.ceil((h - cs) / ss) + 1
|
| 50 |
+
num_cols = math.ceil((w - cs) / ss) if math.ceil(
|
| 51 |
+
(w - cs) / ss) * ss + cs >= w else math.ceil((w - cs) / ss) + 1
|
| 52 |
+
|
| 53 |
+
x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
|
| 54 |
+
xmin = x * cs
|
| 55 |
+
ymin = y * cs
|
| 56 |
+
|
| 57 |
+
xmin = xmin.ravel()
|
| 58 |
+
ymin = ymin.ravel()
|
| 59 |
+
xmin_offset = np.where(xmin + cs > w, w - xmin - cs, np.zeros_like(xmin))
|
| 60 |
+
ymin_offset = np.where(ymin + cs > h, h - ymin - cs, np.zeros_like(ymin))
|
| 61 |
+
boxes = np.stack([
|
| 62 |
+
xmin + xmin_offset, ymin + ymin_offset,
|
| 63 |
+
np.minimum(xmin + cs, w),
|
| 64 |
+
np.minimum(ymin + cs, h)
|
| 65 |
+
],
|
| 66 |
+
axis=1)
|
| 67 |
+
|
| 68 |
+
if to_label:
|
| 69 |
+
color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0],
|
| 70 |
+
[255, 255, 0], [0, 255, 0], [0, 255, 255],
|
| 71 |
+
[0, 0, 255]])
|
| 72 |
+
flatten_v = np.matmul(
|
| 73 |
+
image.reshape(-1, c),
|
| 74 |
+
np.array([2, 3, 4]).reshape(3, 1))
|
| 75 |
+
out = np.zeros_like(flatten_v)
|
| 76 |
+
for idx, class_color in enumerate(color_map):
|
| 77 |
+
value_idx = np.matmul(class_color,
|
| 78 |
+
np.array([2, 3, 4]).reshape(3, 1))
|
| 79 |
+
out[flatten_v == value_idx] = idx
|
| 80 |
+
image = out.reshape(h, w)
|
| 81 |
+
|
| 82 |
+
for box in boxes:
|
| 83 |
+
start_x, start_y, end_x, end_y = box
|
| 84 |
+
clipped_image = image[start_y:end_y,
|
| 85 |
+
start_x:end_x] if to_label else image[
|
| 86 |
+
start_y:end_y, start_x:end_x, :]
|
| 87 |
+
area_idx = osp.basename(image_path).split('_')[3].strip('.tif')
|
| 88 |
+
mmcv.imwrite(
|
| 89 |
+
clipped_image.astype(np.uint8),
|
| 90 |
+
osp.join(clip_save_dir,
|
| 91 |
+
f'{area_idx}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def main():
|
| 95 |
+
splits = {
|
| 96 |
+
'train': [
|
| 97 |
+
'area1', 'area11', 'area13', 'area15', 'area17', 'area21',
|
| 98 |
+
'area23', 'area26', 'area28', 'area3', 'area30', 'area32',
|
| 99 |
+
'area34', 'area37', 'area5', 'area7'
|
| 100 |
+
],
|
| 101 |
+
'val': [
|
| 102 |
+
'area6', 'area24', 'area35', 'area16', 'area14', 'area22',
|
| 103 |
+
'area10', 'area4', 'area2', 'area20', 'area8', 'area31', 'area33',
|
| 104 |
+
'area27', 'area38', 'area12', 'area29'
|
| 105 |
+
],
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
dataset_path = args.dataset_path
|
| 109 |
+
if args.out_dir is None:
|
| 110 |
+
out_dir = osp.join('data', 'vaihingen')
|
| 111 |
+
else:
|
| 112 |
+
out_dir = args.out_dir
|
| 113 |
+
|
| 114 |
+
print('Making directories...')
|
| 115 |
+
mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
|
| 116 |
+
mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
|
| 117 |
+
mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
|
| 118 |
+
mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
|
| 119 |
+
|
| 120 |
+
zipp_list = glob.glob(os.path.join(dataset_path, '*.zip'))
|
| 121 |
+
print('Find the data', zipp_list)
|
| 122 |
+
|
| 123 |
+
with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
|
| 124 |
+
for zipp in zipp_list:
|
| 125 |
+
zip_file = zipfile.ZipFile(zipp)
|
| 126 |
+
zip_file.extractall(tmp_dir)
|
| 127 |
+
src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif'))
|
| 128 |
+
if 'ISPRS_semantic_labeling_Vaihingen' in zipp:
|
| 129 |
+
src_path_list = glob.glob(
|
| 130 |
+
os.path.join(os.path.join(tmp_dir, 'top'), '*.tif'))
|
| 131 |
+
if 'ISPRS_semantic_labeling_Vaihingen_ground_truth_eroded_COMPLETE' in zipp: # noqa
|
| 132 |
+
src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif'))
|
| 133 |
+
# delete unused area9 ground truth
|
| 134 |
+
for area_ann in src_path_list:
|
| 135 |
+
if 'area9' in area_ann:
|
| 136 |
+
src_path_list.remove(area_ann)
|
| 137 |
+
prog_bar = ProgressBar(len(src_path_list))
|
| 138 |
+
for i, src_path in enumerate(src_path_list):
|
| 139 |
+
area_idx = osp.basename(src_path).split('_')[3].strip('.tif')
|
| 140 |
+
data_type = 'train' if area_idx in splits['train'] else 'val'
|
| 141 |
+
if 'noBoundary' in src_path:
|
| 142 |
+
dst_dir = osp.join(out_dir, 'ann_dir', data_type)
|
| 143 |
+
clip_big_image(src_path, dst_dir, to_label=True)
|
| 144 |
+
else:
|
| 145 |
+
dst_dir = osp.join(out_dir, 'img_dir', data_type)
|
| 146 |
+
clip_big_image(src_path, dst_dir, to_label=False)
|
| 147 |
+
prog_bar.update()
|
| 148 |
+
|
| 149 |
+
print('Removing the temporary files...')
|
| 150 |
+
|
| 151 |
+
print('Done!')
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
if __name__ == '__main__':
|
| 155 |
+
args = parse_args()
|
| 156 |
+
main()
|
tools/dataset_tools/analysis_dataset.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from glob import glob
|
| 2 |
+
from typing import Tuple,List
|
| 3 |
+
import os
|
| 4 |
+
import argparse
|
| 5 |
+
import json
|
| 6 |
+
from matplotlib import pyplot as plt
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
def get_args()->Tuple[str, str]:
|
| 11 |
+
"""
|
| 12 |
+
Return:
|
| 13 |
+
--dataset_dir: dataset dir.
|
| 14 |
+
--save_dir: save dir.
|
| 15 |
+
"""
|
| 16 |
+
parser = argparse.ArgumentParser()
|
| 17 |
+
parser.add_argument('--dataset_dir', type=str, default='data/grass')
|
| 18 |
+
parser.add_argument('--save_dir', type=str, default='dataset_num_analysis.png')
|
| 19 |
+
args = parser.parse_args()
|
| 20 |
+
return args.dataset_dir, args.save_dir
|
| 21 |
+
|
| 22 |
+
def get_mask_files(dataset_dir: str)->List[str]:
|
| 23 |
+
"""
|
| 24 |
+
get mask files from dataset dir.
|
| 25 |
+
Args:
|
| 26 |
+
dataset_dir: dataset dir.
|
| 27 |
+
Return:
|
| 28 |
+
mask_filenames: list of mask filenames.
|
| 29 |
+
"""
|
| 30 |
+
mask_filenames = glob(os.path.join(dataset_dir, "ann_dir", "*", "*.png"))
|
| 31 |
+
return mask_filenames
|
| 32 |
+
|
| 33 |
+
def main():
|
| 34 |
+
dataset_dir, save_dir = get_args()
|
| 35 |
+
mask_filenames = get_mask_files(dataset_dir)
|
| 36 |
+
statistic = {}
|
| 37 |
+
for mask_filename in mask_filenames:
|
| 38 |
+
mask = np.array(Image.open(mask_filename))
|
| 39 |
+
classes = np.unique(mask)
|
| 40 |
+
for class_ in classes:
|
| 41 |
+
class_ = int(class_)
|
| 42 |
+
if class_ not in statistic:
|
| 43 |
+
statistic[class_] = 0
|
| 44 |
+
statistic[(class_)] += int(np.sum(mask == class_))
|
| 45 |
+
|
| 46 |
+
classes = list(statistic.keys())
|
| 47 |
+
clasees_num = list(statistic.values())
|
| 48 |
+
|
| 49 |
+
plt.title("Dataset Analysis")
|
| 50 |
+
bars = plt.bar(classes, clasees_num)
|
| 51 |
+
for bar in bars:
|
| 52 |
+
height = bar.get_height()
|
| 53 |
+
plt.text(bar.get_x() + bar.get_width() / 2, height + 5, str(height), ha='center', va='bottom')
|
| 54 |
+
plt.savefig(save_dir,dpi=300)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
if __name__ == "__main__":
|
| 58 |
+
main()
|
tools/dataset_tools/process_water.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import shutil
|
| 2 |
+
from glob import glob
|
| 3 |
+
import os
|
| 4 |
+
import argparse
|
| 5 |
+
import numpy as np
|
| 6 |
+
from rich.progress import track
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from typing import List
|
| 9 |
+
from vegseg.datasets import WaterDataset
|
| 10 |
+
from sklearn.model_selection import train_test_split
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_args():
|
| 14 |
+
parse = argparse.ArgumentParser()
|
| 15 |
+
parse.add_argument("--raw_path", type=str)
|
| 16 |
+
parse.add_argument("--tmp_dir", type=str)
|
| 17 |
+
parse.add_argument("--save_path", type=str)
|
| 18 |
+
args = parse.parse_args()
|
| 19 |
+
return args.raw_path, args.tmp_dir, args.save_path
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_palette() -> List[int]:
|
| 23 |
+
"""
|
| 24 |
+
get palette of dataset.
|
| 25 |
+
return:
|
| 26 |
+
palette: list of palette.
|
| 27 |
+
"""
|
| 28 |
+
palette = []
|
| 29 |
+
palette_list = WaterDataset.METAINFO["palette"]
|
| 30 |
+
for palette_item in palette_list:
|
| 31 |
+
palette.extend(palette_item)
|
| 32 |
+
return palette
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def create_dataset(image_list, ann_list, image_dir, ann_dir, description="Working..."):
|
| 36 |
+
os.makedirs(image_dir, exist_ok=True)
|
| 37 |
+
os.makedirs(ann_dir, exist_ok=True)
|
| 38 |
+
for image_path, ann_path in track(
|
| 39 |
+
zip(image_list, ann_list), total=len(image_list), description=description
|
| 40 |
+
):
|
| 41 |
+
base_name = os.path.basename(image_path)
|
| 42 |
+
|
| 43 |
+
new_image_path = os.path.join(image_dir, base_name)
|
| 44 |
+
new_ann_path = os.path.join(ann_dir, base_name)
|
| 45 |
+
|
| 46 |
+
shutil.move(image_path, new_image_path)
|
| 47 |
+
shutil.move(ann_path, new_ann_path)
|
| 48 |
+
|
| 49 |
+
mask = Image.open(new_ann_path).convert("P")
|
| 50 |
+
palette = get_palette()
|
| 51 |
+
mask.putpalette(palette)
|
| 52 |
+
mask.save(new_ann_path)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def main():
|
| 56 |
+
classes_mapping = {
|
| 57 |
+
"CDUWD-1": 1,
|
| 58 |
+
"CDUWD-2": 2,
|
| 59 |
+
"CDUWD-3": 3,
|
| 60 |
+
"CDUWD-4": 4,
|
| 61 |
+
"CDUWD-5": 5,
|
| 62 |
+
"CDUWD-6": 0,
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
raw_path, tmp_dir, save_path = get_args()
|
| 66 |
+
|
| 67 |
+
all_images = glob(os.path.join(raw_path, "*", "images", "*.png"))
|
| 68 |
+
|
| 69 |
+
all_labels = [image_path.replace("images", "labels") for image_path in all_images]
|
| 70 |
+
|
| 71 |
+
target_image_dir = os.path.join(tmp_dir, "images")
|
| 72 |
+
target_label_dir = os.path.join(tmp_dir, "labels")
|
| 73 |
+
|
| 74 |
+
os.makedirs(target_image_dir, exist_ok=True)
|
| 75 |
+
os.makedirs(target_label_dir, exist_ok=True)
|
| 76 |
+
|
| 77 |
+
for image_path, label_path in track(
|
| 78 |
+
zip(all_images, all_labels), total=len(all_images), description="fuse dataset"
|
| 79 |
+
):
|
| 80 |
+
exists_images = glob(os.path.join(target_image_dir, "*.png"))
|
| 81 |
+
|
| 82 |
+
base_name = os.path.basename(image_path)
|
| 83 |
+
if image_path not in exists_images:
|
| 84 |
+
mask = np.array(Image.open(label_path))
|
| 85 |
+
|
| 86 |
+
assert list(np.unique(mask)) in [
|
| 87 |
+
[0],
|
| 88 |
+
[1],
|
| 89 |
+
[0, 1],
|
| 90 |
+
[1, 0],
|
| 91 |
+
], f"The mask image is not binary (it should only contain 0s and 1s),actually is {set(np.unique(mask))}"
|
| 92 |
+
|
| 93 |
+
classes_str = image_path.split(os.path.sep)[-3]
|
| 94 |
+
classes = classes_mapping[classes_str]
|
| 95 |
+
mask = np.where(mask == 1, classes, mask)
|
| 96 |
+
|
| 97 |
+
# print(classes_str)
|
| 98 |
+
|
| 99 |
+
mask = Image.fromarray(mask)
|
| 100 |
+
mask.save(os.path.join(target_label_dir, base_name))
|
| 101 |
+
shutil.copy(image_path, os.path.join(target_image_dir, base_name))
|
| 102 |
+
else:
|
| 103 |
+
|
| 104 |
+
exists_label_path = os.path.join(target_label_dir, base_name)
|
| 105 |
+
exists_mask = np.array(Image.open(exists_label_path))
|
| 106 |
+
|
| 107 |
+
mask = np.array(Image.open(label_path))
|
| 108 |
+
assert list(np.unique(mask)) in [
|
| 109 |
+
[0],
|
| 110 |
+
[1],
|
| 111 |
+
[0, 1],
|
| 112 |
+
[1, 0],
|
| 113 |
+
], f"The mask image is not binary (it should only contain 0s and 1s),actually is {set(np.unique(mask))}"
|
| 114 |
+
classes_str = image_path.split(os.path.sep)[-3]
|
| 115 |
+
classes = classes_mapping[classes_str]
|
| 116 |
+
|
| 117 |
+
exists_mask = np.where(mask == 1, classes, exists_mask)
|
| 118 |
+
|
| 119 |
+
exists_mask = Image.fromarray(exists_mask)
|
| 120 |
+
exists_mask.save(exists_label_path)
|
| 121 |
+
|
| 122 |
+
exists_images = glob(os.path.join(target_image_dir, "*.png"))
|
| 123 |
+
|
| 124 |
+
exists_labels = [
|
| 125 |
+
image_path.replace("images", "labels") for image_path in exists_images
|
| 126 |
+
]
|
| 127 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
| 128 |
+
exists_images, exists_labels, test_size=0.2, random_state=42, shuffle=True
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
create_dataset(
|
| 132 |
+
X_train,
|
| 133 |
+
y_train,
|
| 134 |
+
os.path.join(save_path, "img_dir", "train"),
|
| 135 |
+
os.path.join(save_path, "ann_dir", "train"),
|
| 136 |
+
description="train dataset",
|
| 137 |
+
)
|
| 138 |
+
create_dataset(
|
| 139 |
+
X_test,
|
| 140 |
+
y_test,
|
| 141 |
+
os.path.join(save_path, "img_dir", "val"),
|
| 142 |
+
os.path.join(save_path, "ann_dir", "val"),
|
| 143 |
+
description="val dataset",
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
os.rmdir(target_image_dir)
|
| 147 |
+
os.rmdir(target_label_dir)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
if __name__ == "__main__":
|
| 151 |
+
# example python tools/dataset_tools/process_water.py --raw_path data/raw_water_dataset/1024 --tmp_dir data/raw_water_dataset/1024/all_dataset --save_path data/water_1024_1024
|
| 152 |
+
main()
|
tools/deployment/pytorch2torchscript.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import argparse
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch._C
|
| 7 |
+
import torch.serialization
|
| 8 |
+
from mmengine import Config
|
| 9 |
+
from mmengine.runner import load_checkpoint
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
from mmseg.models import build_segmentor
|
| 13 |
+
|
| 14 |
+
torch.manual_seed(3)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def digit_version(version_str):
|
| 18 |
+
digit_version = []
|
| 19 |
+
for x in version_str.split('.'):
|
| 20 |
+
if x.isdigit():
|
| 21 |
+
digit_version.append(int(x))
|
| 22 |
+
elif x.find('rc') != -1:
|
| 23 |
+
patch_version = x.split('rc')
|
| 24 |
+
digit_version.append(int(patch_version[0]) - 1)
|
| 25 |
+
digit_version.append(int(patch_version[1]))
|
| 26 |
+
return digit_version
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def check_torch_version():
|
| 30 |
+
torch_minimum_version = '1.8.0'
|
| 31 |
+
torch_version = digit_version(torch.__version__)
|
| 32 |
+
|
| 33 |
+
assert (torch_version >= digit_version(torch_minimum_version)), \
|
| 34 |
+
f'Torch=={torch.__version__} is not support for converting to ' \
|
| 35 |
+
f'torchscript. Please install pytorch>={torch_minimum_version}.'
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _convert_batchnorm(module):
|
| 39 |
+
module_output = module
|
| 40 |
+
if isinstance(module, torch.nn.SyncBatchNorm):
|
| 41 |
+
module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
|
| 42 |
+
module.momentum, module.affine,
|
| 43 |
+
module.track_running_stats)
|
| 44 |
+
if module.affine:
|
| 45 |
+
module_output.weight.data = module.weight.data.clone().detach()
|
| 46 |
+
module_output.bias.data = module.bias.data.clone().detach()
|
| 47 |
+
# keep requires_grad unchanged
|
| 48 |
+
module_output.weight.requires_grad = module.weight.requires_grad
|
| 49 |
+
module_output.bias.requires_grad = module.bias.requires_grad
|
| 50 |
+
module_output.running_mean = module.running_mean
|
| 51 |
+
module_output.running_var = module.running_var
|
| 52 |
+
module_output.num_batches_tracked = module.num_batches_tracked
|
| 53 |
+
for name, child in module.named_children():
|
| 54 |
+
module_output.add_module(name, _convert_batchnorm(child))
|
| 55 |
+
del module
|
| 56 |
+
return module_output
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _demo_mm_inputs(input_shape, num_classes):
|
| 60 |
+
"""Create a superset of inputs needed to run test or train batches.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
input_shape (tuple):
|
| 64 |
+
input batch dimensions
|
| 65 |
+
num_classes (int):
|
| 66 |
+
number of semantic classes
|
| 67 |
+
"""
|
| 68 |
+
(N, C, H, W) = input_shape
|
| 69 |
+
rng = np.random.RandomState(0)
|
| 70 |
+
imgs = rng.rand(*input_shape)
|
| 71 |
+
segs = rng.randint(
|
| 72 |
+
low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
|
| 73 |
+
img_metas = [{
|
| 74 |
+
'img_shape': (H, W, C),
|
| 75 |
+
'ori_shape': (H, W, C),
|
| 76 |
+
'pad_shape': (H, W, C),
|
| 77 |
+
'filename': '<demo>.png',
|
| 78 |
+
'scale_factor': 1.0,
|
| 79 |
+
'flip': False,
|
| 80 |
+
} for _ in range(N)]
|
| 81 |
+
mm_inputs = {
|
| 82 |
+
'imgs': torch.FloatTensor(imgs).requires_grad_(True),
|
| 83 |
+
'img_metas': img_metas,
|
| 84 |
+
'gt_semantic_seg': torch.LongTensor(segs)
|
| 85 |
+
}
|
| 86 |
+
return mm_inputs
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def pytorch2libtorch(model,
|
| 90 |
+
input_shape,
|
| 91 |
+
show=False,
|
| 92 |
+
output_file='tmp.pt',
|
| 93 |
+
verify=False):
|
| 94 |
+
"""Export Pytorch model to TorchScript model and verify the outputs are
|
| 95 |
+
same between Pytorch and TorchScript.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
model (nn.Module): Pytorch model we want to export.
|
| 99 |
+
input_shape (tuple): Use this input shape to construct
|
| 100 |
+
the corresponding dummy input and execute the model.
|
| 101 |
+
show (bool): Whether print the computation graph. Default: False.
|
| 102 |
+
output_file (string): The path to where we store the
|
| 103 |
+
output TorchScript model. Default: `tmp.pt`.
|
| 104 |
+
verify (bool): Whether compare the outputs between
|
| 105 |
+
Pytorch and TorchScript. Default: False.
|
| 106 |
+
"""
|
| 107 |
+
if isinstance(model.decode_head, nn.ModuleList):
|
| 108 |
+
num_classes = model.decode_head[-1].num_classes
|
| 109 |
+
else:
|
| 110 |
+
num_classes = model.decode_head.num_classes
|
| 111 |
+
|
| 112 |
+
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
|
| 113 |
+
|
| 114 |
+
imgs = mm_inputs.pop('imgs')
|
| 115 |
+
|
| 116 |
+
# replace the original forword with forward_dummy
|
| 117 |
+
model.forward = model.forward_dummy
|
| 118 |
+
model.eval()
|
| 119 |
+
traced_model = torch.jit.trace(
|
| 120 |
+
model,
|
| 121 |
+
example_inputs=imgs,
|
| 122 |
+
check_trace=verify,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
if show:
|
| 126 |
+
print(traced_model.graph)
|
| 127 |
+
|
| 128 |
+
traced_model.save(output_file)
|
| 129 |
+
print(f'Successfully exported TorchScript model: {output_file}')
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def parse_args():
|
| 133 |
+
parser = argparse.ArgumentParser(
|
| 134 |
+
description='Convert MMSeg to TorchScript')
|
| 135 |
+
parser.add_argument('config', help='test config file path')
|
| 136 |
+
parser.add_argument('--checkpoint', help='checkpoint file', default=None)
|
| 137 |
+
parser.add_argument(
|
| 138 |
+
'--show', action='store_true', help='show TorchScript graph')
|
| 139 |
+
parser.add_argument(
|
| 140 |
+
'--verify', action='store_true', help='verify the TorchScript model')
|
| 141 |
+
parser.add_argument('--output-file', type=str, default='tmp.pt')
|
| 142 |
+
parser.add_argument(
|
| 143 |
+
'--shape',
|
| 144 |
+
type=int,
|
| 145 |
+
nargs='+',
|
| 146 |
+
default=[512, 512],
|
| 147 |
+
help='input image size (height, width)')
|
| 148 |
+
args = parser.parse_args()
|
| 149 |
+
return args
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
if __name__ == '__main__':
|
| 153 |
+
args = parse_args()
|
| 154 |
+
check_torch_version()
|
| 155 |
+
|
| 156 |
+
if len(args.shape) == 1:
|
| 157 |
+
input_shape = (1, 3, args.shape[0], args.shape[0])
|
| 158 |
+
elif len(args.shape) == 2:
|
| 159 |
+
input_shape = (
|
| 160 |
+
1,
|
| 161 |
+
3,
|
| 162 |
+
) + tuple(args.shape)
|
| 163 |
+
else:
|
| 164 |
+
raise ValueError('invalid input shape')
|
| 165 |
+
|
| 166 |
+
cfg = Config.fromfile(args.config)
|
| 167 |
+
cfg.model.pretrained = None
|
| 168 |
+
|
| 169 |
+
# build the model and load checkpoint
|
| 170 |
+
cfg.model.train_cfg = None
|
| 171 |
+
segmentor = build_segmentor(
|
| 172 |
+
cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
|
| 173 |
+
# convert SyncBN to BN
|
| 174 |
+
segmentor = _convert_batchnorm(segmentor)
|
| 175 |
+
|
| 176 |
+
if args.checkpoint:
|
| 177 |
+
load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
|
| 178 |
+
|
| 179 |
+
# convert the PyTorch model to LibTorch model
|
| 180 |
+
pytorch2libtorch(
|
| 181 |
+
segmentor,
|
| 182 |
+
input_shape,
|
| 183 |
+
show=args.show,
|
| 184 |
+
output_file=args.output_file,
|
| 185 |
+
verify=args.verify)
|
tools/model_converters/beit2mmseg.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import argparse
|
| 3 |
+
import os.path as osp
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
|
| 6 |
+
import mmengine
|
| 7 |
+
import torch
|
| 8 |
+
from mmengine.runner import CheckpointLoader
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def convert_beit(ckpt):
|
| 12 |
+
new_ckpt = OrderedDict()
|
| 13 |
+
|
| 14 |
+
for k, v in ckpt.items():
|
| 15 |
+
if k.startswith('patch_embed'):
|
| 16 |
+
new_key = k.replace('patch_embed.proj', 'patch_embed.projection')
|
| 17 |
+
new_ckpt[new_key] = v
|
| 18 |
+
if k.startswith('blocks'):
|
| 19 |
+
new_key = k.replace('blocks', 'layers')
|
| 20 |
+
if 'norm' in new_key:
|
| 21 |
+
new_key = new_key.replace('norm', 'ln')
|
| 22 |
+
elif 'mlp.fc1' in new_key:
|
| 23 |
+
new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0')
|
| 24 |
+
elif 'mlp.fc2' in new_key:
|
| 25 |
+
new_key = new_key.replace('mlp.fc2', 'ffn.layers.1')
|
| 26 |
+
new_ckpt[new_key] = v
|
| 27 |
+
else:
|
| 28 |
+
new_key = k
|
| 29 |
+
new_ckpt[new_key] = v
|
| 30 |
+
|
| 31 |
+
return new_ckpt
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def main():
|
| 35 |
+
parser = argparse.ArgumentParser(
|
| 36 |
+
description='Convert keys in official pretrained beit models to'
|
| 37 |
+
'MMSegmentation style.')
|
| 38 |
+
parser.add_argument('src', help='src model path or url')
|
| 39 |
+
# The dst path must be a full path of the new checkpoint.
|
| 40 |
+
parser.add_argument('dst', help='save path')
|
| 41 |
+
args = parser.parse_args()
|
| 42 |
+
|
| 43 |
+
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
|
| 44 |
+
if 'state_dict' in checkpoint:
|
| 45 |
+
state_dict = checkpoint['state_dict']
|
| 46 |
+
elif 'model' in checkpoint:
|
| 47 |
+
state_dict = checkpoint['model']
|
| 48 |
+
else:
|
| 49 |
+
state_dict = checkpoint
|
| 50 |
+
weight = convert_beit(state_dict)
|
| 51 |
+
mmengine.mkdir_or_exist(osp.dirname(args.dst))
|
| 52 |
+
torch.save(weight, args.dst)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
if __name__ == '__main__':
|
| 56 |
+
main()
|