lihongjie committed on
Commit
fafd9e7
·
1 Parent(s): e679233

first commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,23 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ models-ax637/deeplabv3plus_mobilenet_u16.axmodel filter=lfs diff=lfs merge=lfs -text
37
+ models-ax650/deeplabv3plus_mobilenet_u16.axmodel filter=lfs diff=lfs merge=lfs -text
38
+ samples/114_image.png filter=lfs diff=lfs merge=lfs -text
39
+ samples/1_image.png filter=lfs diff=lfs merge=lfs -text
40
+ samples/1_target.png filter=lfs diff=lfs merge=lfs -text
41
+ samples/23_target.png filter=lfs diff=lfs merge=lfs -text
42
+ samples/city_1_overlay.png filter=lfs diff=lfs merge=lfs -text
43
+ samples/city_1_target.png filter=lfs diff=lfs merge=lfs -text
44
+ samples/114_overlay.png filter=lfs diff=lfs merge=lfs -text
45
+ samples/23_image.png filter=lfs diff=lfs merge=lfs -text
46
+ samples/23_overlay.png filter=lfs diff=lfs merge=lfs -text
47
+ samples/23_pred.png filter=lfs diff=lfs merge=lfs -text
48
+ samples/city_6_overlay.png filter=lfs diff=lfs merge=lfs -text
49
+ samples/city_6_target.png filter=lfs diff=lfs merge=lfs -text
50
+ samples/1_overlay.png filter=lfs diff=lfs merge=lfs -text
51
+ samples/1_pred.png filter=lfs diff=lfs merge=lfs -text
52
+ samples/114_pred.png filter=lfs diff=lfs merge=lfs -text
53
+ samples/114_target.png filter=lfs diff=lfs merge=lfs -text
54
+ samples/visdom-screenshoot.png filter=lfs diff=lfs merge=lfs -text
55
+ output-ax.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
README.md CHANGED
@@ -1,3 +1,73 @@
1
  ---
2
- license: mit
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ license: bsd-3-clause
3
+ language:
4
+ - en
5
+ base_model:
6
+ - deeplabv3plus_mobilenet
7
+ pipeline_tag: semantic-segmentation
8
+ tags:
9
+ - deeplabv3plus
10
  ---
11
+
12
+ # DeepLabv3Plus
13
+
14
+ This version of deeplabv3plus_mobilenet has been converted to run on the Axera NPU using **w8a16** quantization.
15
+
16
+ Compatible with Pulsar2 version: 5.0-patch1
17
+
18
+ ## Conversion tool links
19
+
20
+ If you are interested in model conversion, you can export the axmodel yourself using the following resources:
21
+
22
+ - [Original repository](https://github.com/VainF/DeepLabV3Plus-Pytorch.git)
23
+
24
+ - [Pulsar2 Link, How to Convert ONNX to axmodel](https://pulsar2-docs.readthedocs.io/en/latest/pulsar2/introduction.html)
25
+
26
+
27
+ ## Supported Platforms
28
+
29
+ - AX650
30
+ - [M4N-Dock(爱芯派Pro)](https://wiki.sipeed.com/hardware/zh/maixIV/m4ndock/m4ndock.html)
31
+ - [M.2 Accelerator card](https://axcl-docs.readthedocs.io/zh-cn/latest/doc_guide_hardware.html)
32
+ - AX637
33
+
34
+ |Chips|Models |Time|
35
+ |--|--|--|
36
+ |AX650|deeplabv3plus_mobilenet_u16|13.4 ms |
37
+ |AX637|deeplabv3plus_mobilenet_u16|39.4 ms |
38
+
39
+
40
+ ## How to use
41
+
42
+ Download all files from this repository to the device
43
+
44
+
45
+ ### Python environment requirements
46
+
47
+ #### pyaxengine
48
+
49
+ https://github.com/AXERA-TECH/pyaxengine
50
+
51
+ ```
52
+ wget https://github.com/AXERA-TECH/pyaxengine/releases/download/0.1.3.rc2/axengine-0.1.3-py3-none-any.whl
53
+ pip install axengine-0.1.3-py3-none-any.whl
54
+ ```
55
+
56
+ #### others
57
+
58
+ None.
59
+
60
+ #### Inference with AX650 Host, such as M4N-Dock(爱芯派Pro)
61
+
62
+ Input image:
63
+
64
+ ![](samples/1_image.png)
65
+
66
+ run
67
+ ```
68
+ python3 infer.py --img samples/1_image.png --model models-ax650/deeplabv3plus_mobilenet_u16.axmodel
69
+ ```
70
+
71
+ Output image:
72
+
73
+ ![](output-ax.png)
datasets/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .voc import VOCSegmentation
2
+ from .cityscapes import Cityscapes
datasets/cityscapes.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from collections import namedtuple
4
+
5
+ import torch
6
+ import torch.utils.data as data
7
+ from PIL import Image
8
+ import numpy as np
9
+
10
+
11
class Cityscapes(data.Dataset):
    """Cityscapes <http://www.cityscapes-dataset.com/> Dataset.

    **Parameters:**
        - **root** (string): Root directory of the dataset, containing the
          'leftImg8bit' and 'gtFine' directories.
        - **split** (string, optional): The image split to use: 'train', 'test' or 'val'.
        - **mode** (string, optional): Accepted for interface compatibility but
          currently ignored; annotations are always read from 'gtFine'.
        - **target_type** (string, optional): Which annotation file to pair with each
          image: 'instance', 'semantic', 'color', 'polygon' or 'depth'.
        - **transform** (callable, optional): Called as ``transform(image, target)``;
          must return the transformed ``(image, target)`` pair.
    """

    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
                                                     'has_instances', 'ignore_in_eval', 'color'])
    classes = [
        CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
        CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
        CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
        CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
        CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
        CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
        CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
        CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
        CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
        CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
        CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
        CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
        CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
        CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
        CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
        CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
        CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
        CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
        CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
        CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
        CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
        CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
        CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
        CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
        CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
        CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
        CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
        CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
        CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
        CityscapesClass('license plate', -1, 255, 'vehicle', 7, False, True, (0, 0, 142)),
    ]

    # Colour table indexed by train id; the extra last row (black) is used for
    # the ignore label after decode_target() remaps 255 onto it.
    train_id_to_color = [c.color for c in classes if (c.train_id != -1 and c.train_id != 255)]
    train_id_to_color.append([0, 0, 0])
    train_id_to_color = np.array(train_id_to_color)
    # Lookup table from the position of a class in `classes` to its train id.
    id_to_train_id = np.array([c.train_id for c in classes])

    # File-name suffix (after the "<mode>_" prefix) for each supported target type.
    _TARGET_SUFFIXES = {
        'instance': 'instanceIds.png',
        'semantic': 'labelIds.png',
        'color': 'color.png',
        'polygon': 'polygons.json',
        'depth': 'disparity.png',
    }

    def __init__(self, root, split='train', mode='fine', target_type='semantic', transform=None):
        # Validate the split before it is used to build any path.
        if split not in ['train', 'test', 'val']:
            raise ValueError('Invalid split for mode! Please use split="train", split="test"'
                             ' or split="val"')

        self.root = os.path.expanduser(root)
        # NOTE: `mode` is not consulted; fine annotations are always used.
        self.mode = 'gtFine'
        self.target_type = target_type
        self.images_dir = os.path.join(self.root, 'leftImg8bit', split)

        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.transform = transform

        self.split = split
        self.images = []
        self.targets = []

        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
            raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
                               ' specified "split" and "mode" are inside the "root" directory')

        # The suffix is identical for every file, so resolve it once.
        target_suffix = self._get_target_suffix(self.mode, self.target_type)

        for city in os.listdir(self.images_dir):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)

            for file_name in os.listdir(img_dir):
                self.images.append(os.path.join(img_dir, file_name))
                target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], target_suffix)
                self.targets.append(os.path.join(target_dir, target_name))

    @classmethod
    def encode_target(cls, target):
        """Map raw label ids in *target* to train ids via ``id_to_train_id``."""
        return cls.id_to_train_id[np.array(target)]

    @classmethod
    def decode_target(cls, target):
        """Return the colour-table rows for an array of train ids.

        Entries equal to 255 are mapped to the last (black) row of
        ``train_id_to_color``. The input is copied first so the caller's
        array is left untouched.
        """
        target = np.array(target)  # copy: the remap below must not leak to the caller
        target[target == 255] = len(cls.train_id_to_color) - 1
        return cls.train_id_to_color[target]

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target). The image is opened as RGB; the target file is
            opened with PIL, passed through ``transform`` together with the image
            (when a transform is set) and finally converted with ``encode_target``.
        """
        image = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.targets[index])
        if self.transform:
            image, target = self.transform(image, target)
        target = self.encode_target(target)
        return image, target

    def __len__(self):
        return len(self.images)

    def _load_json(self, path):
        """Read and return the JSON document stored at *path*."""
        with open(path, 'r') as file:
            data = json.load(file)
        return data

    def _get_target_suffix(self, mode, target_type):
        """Return the annotation file suffix, e.g. ``gtFine_labelIds.png``.

        Raises:
            ValueError: if *target_type* is not one of the supported types.
        """
        try:
            suffix = self._TARGET_SUFFIXES[target_type]
        except KeyError:
            raise ValueError('Invalid target_type "{}". Please use one of {}'.format(
                target_type, sorted(self._TARGET_SUFFIXES)))
        return '{}_{}'.format(mode, suffix)
datasets/utils.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path
3
+ import hashlib
4
+ import errno
5
+ from tqdm import tqdm
6
+
7
+
8
def gen_bar_updater(pbar):
    """Build a ``(count, block_size, total_size)`` callback that drives *pbar*.

    The callback sets ``pbar.total`` the first time a truthy ``total_size`` is
    seen while the total is still ``None``, then moves the bar forward to
    ``count * block_size``.
    """
    def bar_update(count, block_size, total_size):
        total_unknown = pbar.total is None
        if total_unknown and total_size:
            pbar.total = total_size
        downloaded = count * block_size
        # update() takes an increment, so pass the distance from the current position.
        pbar.update(downloaded - pbar.n)

    return bar_update
16
+
17
+
18
def check_integrity(fpath, md5=None):
    """Tell whether the file at *fpath* matches the hex digest *md5*.

    Returns True when no digest is given (nothing to verify), False when the
    file does not exist, and otherwise the result of comparing the file's MD5
    hex digest with *md5*.
    """
    if md5 is None:
        return True
    if not os.path.isfile(fpath):
        return False

    digest = hashlib.md5()
    one_mib = 1024 * 1024
    with open(fpath, 'rb') as stream:
        # Hash in 1MB pieces so large files are never loaded whole.
        while True:
            block = stream.read(one_mib)
            if block == b'':
                break
            digest.update(block)
    return digest.hexdigest() == md5
32
+
33
+
34
def makedir_exist_ok(dirpath):
    """
    Create *dirpath* (with parents), ignoring the "already exists" error.

    Python2 support for os.makedirs(.., exist_ok=True)
    """
    try:
        os.makedirs(dirpath)
    except OSError as err:
        # Only an EEXIST failure is tolerated; anything else propagates.
        if err.errno != errno.EEXIST:
            raise
45
+
46
+
47
def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    If the target file already exists and passes ``check_integrity`` it is
    reused. Otherwise the URL is fetched; when an ``https`` download fails
    with an ``OSError`` the same URL is retried once over plain ``http``.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str): Name to save the file under. If None, use the basename of the URL
        md5 (str): MD5 checksum of the download. If None, do not check

    Raises:
        OSError: if the download fails and no http fallback applies, or the
            fallback download fails as well.
    """
    # Standard-library module; imported locally like the original helper import.
    import urllib.request

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir_exist_ok(root)

    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    def _fetch(source):
        # The context manager closes the progress bar even when the download fails.
        with tqdm(unit='B', unit_scale=True) as pbar:
            urllib.request.urlretrieve(source, fpath, reporthook=gen_bar_updater(pbar))

    try:
        print('Downloading ' + url + ' to ' + fpath)
        _fetch(url)
    except OSError:
        if url[:5] != 'https':
            # No fallback is available: surface the failure instead of hiding it.
            raise
        url = url.replace('https:', 'http:')
        print('Failed download. Trying https -> http instead.'
              ' Downloading ' + url + ' to ' + fpath)
        _fetch(url)
83
+
84
+
85
def list_dir(root, prefix=False):
    """List all directories at a given root
    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    names = [
        entry for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    ]

    # Only the literal True switches to full paths.
    if prefix is True:
        return [os.path.join(root, name) for name in names]
    return names
104
+
105
+
106
def list_files(root, suffix, prefix=False):
    """List all files ending with a suffix at a given root
    Args:
        root (str): Path to directory whose folders need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    matches = [
        entry for entry in os.listdir(root)
        if os.path.isfile(os.path.join(root, entry)) and entry.endswith(suffix)
    ]

    # Only the literal True switches to full paths.
    if prefix is True:
        return [os.path.join(root, name) for name in matches]
    return matches
datasets/voc.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import tarfile
4
+ import collections
5
+ import torch.utils.data as data
6
+ import shutil
7
+ import numpy as np
8
+
9
+ from PIL import Image
10
+ from torchvision.datasets.utils import download_url, check_integrity
11
+
12
# Download metadata for each supported Pascal VOC release, keyed by year:
#   url      - location of the trainval archive
#   filename - local name of the archive (the basename of `url`)
#   md5      - checksum passed to the download helper
#   base_dir - dataset directory relative to the download root
DATASET_YEAR_DICT = {
    '2012': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
        'filename': 'VOCtrainval_11-May-2012.tar',
        'md5': '6cd6e144f989b92b3379bac3b3de84fd',
        'base_dir': 'VOCdevkit/VOC2012'
    },
    '2011': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
        'filename': 'VOCtrainval_25-May-2011.tar',
        'md5': '6c3384ef61512963050cb5d687e5bf1e',
        'base_dir': 'TrainVal/VOCdevkit/VOC2011'
    },
    '2010': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
        'filename': 'VOCtrainval_03-May-2010.tar',
        'md5': 'da459979d0c395079b5c75ee67908abb',
        'base_dir': 'VOCdevkit/VOC2010'
    },
    '2009': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
        'filename': 'VOCtrainval_11-May-2009.tar',
        'md5': '59065e4b188729180974ef6572f6a212',
        'base_dir': 'VOCdevkit/VOC2009'
    },
    '2008': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
        # Must match the URL basename; reusing the 2012 name would clash with
        # a 2012 archive stored in the same root.
        'filename': 'VOCtrainval_14-Jul-2008.tar',
        'md5': '2629fa636546599198acfcfbfcf1904a',
        'base_dir': 'VOCdevkit/VOC2008'
    },
    '2007': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
        'filename': 'VOCtrainval_06-Nov-2007.tar',
        'md5': 'c52e279531787c972589f7e41ab4ae64',
        'base_dir': 'VOCdevkit/VOC2007'
    }
}
50
+
51
+
52
def voc_cmap(N=256, normalized=False):
    """Build an ``(N, 3)`` colour table from the bits of each index.

    For every index, successive groups of three bits supply one bit each for
    the red, green and blue channels, filled from the most significant channel
    bit downwards. The table is ``uint8``, or ``float32`` values divided by
    255 when *normalized* is true.
    """
    dtype = 'float32' if normalized else 'uint8'
    cmap = np.zeros((N, 3), dtype=dtype)

    for label in range(N):
        rgb = [0, 0, 0]
        remaining = label
        # Walk the channel bits from bit 7 down to bit 0.
        for shift in range(7, -1, -1):
            for channel in range(3):
                rgb[channel] |= ((remaining >> channel) & 1) << shift
            remaining >>= 3
        cmap[label] = np.array(rgb)

    if normalized:
        cmap = cmap / 255
    return cmap
71
+
72
class VOCSegmentation(data.Dataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
            ``'2012_aug'`` selects the 2012 data and, for ``image_set='train'``,
            the ``SegmentationClassAug`` masks listed in ``<root>/train_aug.txt``.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): Called as ``transform(img, target)`` on the
            PIL image and mask; must return the transformed pair.
    """
    cmap = voc_cmap()

    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None):

        # The augmented variant shares all download metadata with 2012.
        is_aug = year == '2012_aug'
        if is_aug:
            year = '2012'

        self.root = os.path.expanduser(root)
        self.year = year
        info = DATASET_YEAR_DICT[year]
        self.url = info['url']
        self.filename = info['filename']
        self.md5 = info['md5']
        self.transform = transform

        self.image_set = image_set
        voc_root = os.path.join(self.root, info['base_dir'])
        image_dir = os.path.join(voc_root, 'JPEGImages')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        use_aug_masks = is_aug and image_set == 'train'
        if use_aug_masks:
            mask_dir = os.path.join(voc_root, 'SegmentationClassAug')
            assert os.path.exists(mask_dir), "SegmentationClassAug not found, please refer to README.md and prepare it manually"
            split_f = os.path.join(self.root, 'train_aug.txt')
        else:
            mask_dir = os.path.join(voc_root, 'SegmentationClass')
            splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
            split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')

        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"')

        # One sample name per line of the split file.
        with open(split_f, "r") as handle:
            stems = [line.strip() for line in handle]

        self.images = [os.path.join(image_dir, stem + ".jpg") for stem in stems]
        self.masks = [os.path.join(mask_dir, stem + ".png") for stem in stems]
        assert (len(self.images) == len(self.masks))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])
        if self.transform is None:
            return img, target
        img, target = self.transform(img, target)
        return img, target

    def __len__(self):
        return len(self.images)

    @classmethod
    def decode_target(cls, mask):
        """decode semantic mask to RGB image"""
        return cls.cmap[mask]
159
+
160
def download_extract(url, root, filename, md5):
    """Download the tar archive at *url* into *root* and unpack it there.

    Args:
        url (str): URL of the archive.
        root (str): Directory that receives both the archive and its contents.
        filename (str): Name under which the archive is stored in *root*.
        md5 (str): Checksum forwarded to ``download_url``.

    Raises:
        RuntimeError: if an archive member would be written outside *root*.
    """
    download_url(url, root, filename, md5)
    destination = os.path.realpath(root)
    with tarfile.open(os.path.join(root, filename), "r") as tar:
        # Refuse archives whose member names escape the extraction directory
        # (absolute paths or ".." components) before anything is written.
        for member in tar.getmembers():
            resolved = os.path.realpath(os.path.join(destination, member.name))
            inside = resolved == destination or resolved.startswith(destination + os.sep)
            if not inside:
                raise RuntimeError(
                    'Refusing to extract "{}": path escapes "{}"'.format(member.name, root))
        tar.extractall(path=root)
infer.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import cv2
3
+ import os
4
+ import torch
5
+ from PIL import Image
6
+ import numpy as np
7
+ import axengine as axe
8
+ from datasets import VOCSegmentation, Cityscapes, cityscapes
9
+
10
def parse_args() -> argparse.Namespace:
    """Parse the command line: two required string options, --img and --model."""
    parser = argparse.ArgumentParser()
    required_paths = (
        ("--img", "Path to input image."),
        ("--model", "Path to axmodel model."),
    )
    for flag, description in required_paths:
        parser.add_argument(flag, type=str, required=True, help=description)

    return parser.parse_args()
26
+
27
+
28
def infer(img: str, model: str, viz: bool = False, output: str = "output-ax.png"):
    """Run the axmodel on one image and save the colourised class map.

    Args:
        img: Path to the input image.
        model: Path to the axmodel file.
        viz: Unused; kept so existing callers keep working.
        output: Where the colourised prediction is written.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    img_raw = cv2.imread(img)
    if img_raw is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read input image: {img}")

    image = cv2.cvtColor(img_raw, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (513, 513))
    image = image[None]  # add a leading batch axis

    session = axe.InferenceSession(model)

    pred = session.run(None, {"input": image})[0]
    # Highest-scoring index along axis 1, first batch item -> HW map
    # (assumes the output is laid out as (batch, classes, H, W) - TODO confirm).
    pred = np.asarray(pred).argmax(axis=1)[0]

    colorized_preds = VOCSegmentation.decode_target(pred).astype('uint8')
    Image.fromarray(colorized_preds).save(output)
44
+
45
+
46
+
47
+
48
if __name__ == "__main__":
    # Forward the parsed command-line options to infer() as keyword arguments.
    infer(**vars(parse_args()))
models-ax637/deeplabv3plus_mobilenet_u16.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efe6256681294ee1e498e198dfa047eab5852eb23542b1e8836dbff856fb23af
3
+ size 9202159
models-ax650/deeplabv3plus_mobilenet_u16.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c3e0e24f96ee032f211d3f909c049316678035128dfe16dce5374019a73069a
3
+ size 10083639
output-ax.png ADDED

Git LFS Details

  • SHA256: adc1205119e2d72389c90028438fdb593f91b907f11cf57b84fdecd476b61e5c
  • Pointer size: 129 Bytes
  • Size of remote file: 2.71 kB
samples/114_image.png ADDED

Git LFS Details

  • SHA256: f6bd2345121ce1a2329083a5d5fb1f79779506a501bfc3f82d5e1cf7041d5f92
  • Pointer size: 131 Bytes
  • Size of remote file: 322 kB
samples/114_overlay.png ADDED

Git LFS Details

  • SHA256: aacd9ab41d88ed5b8e5f9f0d9952701ca842a01cce1e934bd22fd9e93a2479de
  • Pointer size: 131 Bytes
  • Size of remote file: 135 kB
samples/114_pred.png ADDED

Git LFS Details

  • SHA256: 3544c6c8f6bb1c25dcd0cdcad947926d0065f01ccc90f0a1df3eed84f008fdae
  • Pointer size: 129 Bytes
  • Size of remote file: 3.43 kB
samples/114_target.png ADDED

Git LFS Details

  • SHA256: bead4e348f6164d1bb9c6be45db8c8da1d121c11359486fd2385f9264f8be643
  • Pointer size: 129 Bytes
  • Size of remote file: 4.62 kB
samples/1_image.png ADDED

Git LFS Details

  • SHA256: 7fd691ea077d004934b52854d6882b5d5620b2bd839a086199cc26a4194d099e
  • Pointer size: 131 Bytes
  • Size of remote file: 139 kB
samples/1_overlay.png ADDED

Git LFS Details

  • SHA256: bde06f8e65ce157241715e6c3e88de2f5d52e9f3e6a41d07a4cb2d46145420d7
  • Pointer size: 130 Bytes
  • Size of remote file: 61.6 kB
samples/1_pred.png ADDED

Git LFS Details

  • SHA256: 1323dcb8f6488bcba48b19b2809c3d0697d6e8d5846202b2e69e3fff02194fda
  • Pointer size: 129 Bytes
  • Size of remote file: 2.66 kB
samples/1_target.png ADDED

Git LFS Details

  • SHA256: a405e6d9c68d8a73a1879d73ac63eec2fb7c389841fec73bcf8365b957a256e6
  • Pointer size: 129 Bytes
  • Size of remote file: 3.43 kB
samples/23_image.png ADDED

Git LFS Details

  • SHA256: 7aad93a3787651ed1e84eb0c7b207c707e86ef83fb84c4abe4aadae42fba36c8
  • Pointer size: 131 Bytes
  • Size of remote file: 454 kB
samples/23_overlay.png ADDED

Git LFS Details

  • SHA256: 8318a933fc8bb17e527c2ab88f0bab48f2cd51a631b210891c713d3e783da57c
  • Pointer size: 131 Bytes
  • Size of remote file: 199 kB
samples/23_pred.png ADDED

Git LFS Details

  • SHA256: 03267008a08becda037d66e8532ec2609e2fb13a978abe229952c69f2f460575
  • Pointer size: 129 Bytes
  • Size of remote file: 3.32 kB
samples/23_target.png ADDED

Git LFS Details

  • SHA256: dcb3bfadd307a961702b2de7260b48f162cdade8a23d80f7ef0ab06c416c8329
  • Pointer size: 129 Bytes
  • Size of remote file: 3.94 kB
samples/city_1_overlay.png ADDED

Git LFS Details

  • SHA256: 7e61fb609013ea935534752bf322c1d0f67cda5874edc58c0b6bffd937b746b1
  • Pointer size: 131 Bytes
  • Size of remote file: 167 kB
samples/city_1_target.png ADDED

Git LFS Details

  • SHA256: f7b95995f9c20d67b7f6101826e2feba6299367375a7713bfa8315de1f8f1bcb
  • Pointer size: 130 Bytes
  • Size of remote file: 41.7 kB
samples/city_6_overlay.png ADDED

Git LFS Details

  • SHA256: 1dbf8cb0788ace69fa53d3b0d832bc8cccd6163e4eb56bd35235e9c94ff75447
  • Pointer size: 131 Bytes
  • Size of remote file: 159 kB
samples/city_6_target.png ADDED

Git LFS Details

  • SHA256: 33cd03c5b4d14741dc76e0d3cbcb65a03061cb76e5dd3b2bc7894dc6a138a88d
  • Pointer size: 130 Bytes
  • Size of remote file: 50.8 kB
samples/visdom-screenshoot.png ADDED

Git LFS Details

  • SHA256: d3e57e7945b0a83978016936b93572541dc24b29c18a8e565ffab9428582cc13
  • Pointer size: 131 Bytes
  • Size of remote file: 402 kB