File size: 6,720 Bytes
e4aef33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""
Supervisely Parser Script

This script parses Supervisely projects and converts them to a format
suitable for training segmentation models. It extracts class information,
creates train/validation splits, and converts annotations to indexed
color masks.
"""

import os
import json
import random
import shutil
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm

try:
    import supervisely as sly
    from supervisely import Annotation
except ImportError as e:
    print(f"Failed to import supervisely: {e}")
    print(
        "Please ensure that the 'supervisely' package is installed and "
        "compatible with your environment."
    )
    raise


def extract_class_info(project, output_dir):
    """Build class-index mappings from the project's object classes.

    Class names are expected to look like ``"<number>. <label>"``; any
    name that does not match this pattern is ignored. The resulting
    mappings are written to ``id2label.json`` and ``id2color.json``
    inside *output_dir*.

    Returns:
        Tuple ``(id2label, id2color, label2id)`` where keys/values are
        zero-based class indices and human-readable labels.
    """
    id2label = {}
    id2color = {}

    for obj_class in project.meta.obj_classes:
        prefix, _, name = obj_class.name.partition(". ")
        # Only accept "<digits>. <label>" style names.
        if not (name and prefix.isdigit()):
            continue
        index = int(prefix) - 1
        id2label[index] = name
        id2color[index] = obj_class.color

    # Persist both mappings next to the converted data.
    for filename, mapping in (
        ("id2label.json", id2label),
        ("id2color.json", id2color),
    ):
        with open(f"{output_dir}/{filename}", "w") as f:
            json.dump(mapping, f, sort_keys=True, indent=2)

    return id2label, id2color, {v: k for k, v in id2label.items()}


def create_output_directories(output_dir):
    """Ensure the images/annotations train-val directory tree exists."""
    for kind in ("images", "annotations"):
        for split in ("training", "validation"):
            os.makedirs(f"{output_dir}/{kind}/{split}", exist_ok=True)


def calculate_split_counts(datasets, train_ratio=0.8):
    """Compute how many items go to training vs. validation.

    Counts items across every dataset, assigns ``floor(total * train_ratio)``
    to training and the remainder to validation, and prints a summary.

    Returns:
        Tuple ``(train_items, val_items)``.
    """
    total = sum(len(ds.get_items_names()) for ds in datasets)
    n_train = int(total * train_ratio)
    n_val = total - n_train

    print(
        f"Total items: {total}\n"
        f"Train items: {n_train}\n"
        f"Validation items: {n_val}"
    )

    return n_train, n_val


def to_class_index_mask(
    annotation: Annotation,
    label2id: dict,
    mask_path: str,
):
    """Rasterize *annotation* into a class-index mask PNG at *mask_path*.

    A ``uint8`` mask of the annotation's image size is filled with each
    recognized bitmap label's class index; untouched pixels remain 0.
    Unknown class names and non-bitmap geometries are skipped (the
    former with a progress-bar-safe note).
    """
    height, width = annotation.img_size
    class_mask = np.zeros((height, width), dtype=np.uint8)

    for label in annotation.labels:
        # Class name is the part after the "<number>. " prefix.
        class_name = label.obj_class.name.partition(". ")[2]
        if class_name not in label2id:
            tqdm.write(f"Skipping unrecognized label: {label}")
            continue

        if label.geometry.geometry_name() != "bitmap":
            # Only bitmap geometries are rasterized here.
            continue

        class_index = label2id[class_name]
        origin = label.geometry.origin
        top, left = origin.row, origin.col
        bitmap = label.geometry.data  # binary numpy array, shape (h, w)

        h, w = bitmap.shape
        # Guard against bitmaps that would spill past the image bounds.
        if top + h > height or left + w > width:
            tqdm.write(f"Skipping label '{class_name}': size mismatch.")
            continue

        class_mask[top : top + h, left : left + w][bitmap] = class_index

    Image.fromarray(class_mask).save(mask_path)


def process_datasets(
    project,
    datasets,
    output_dir,
    label2id,
    train_items,
):
    """Copy images and write index masks into train/validation splits.

    Items are shuffled within each dataset; the first *train_items*
    items across ALL datasets land in ``training`` and the rest in
    ``validation``.

    Bug fix: the split counter previously restarted at 0 for every
    dataset, so each dataset contributed its first ``train_items`` items
    to training and the global ratio from ``calculate_split_counts`` was
    not respected in multi-dataset projects. The counter now runs across
    datasets.
    """
    processed = 0  # running item count across all datasets
    for dataset in tqdm(datasets, desc="Processing datasets"):
        items = dataset.get_items_names()
        random.shuffle(items)

        for item in tqdm(
            items,
            desc=f"Processing dataset: {dataset.name}",
            leave=False,
        ):
            # Global (not per-dataset) position decides the split.
            split = "training" if processed < train_items else "validation"
            processed += 1

            # Copy the source image into the split directory.
            item_paths = dataset.get_item_paths(item)
            img_path = item_paths.img_path
            img_filename = os.path.basename(img_path)
            shutil.copy(img_path, f"{output_dir}/images/{split}/{img_filename}")

            # Convert the annotation to an index mask alongside it.
            ann = sly.Annotation.load_json_file(item_paths.ann_path, project.meta)
            mask_filename = f"{os.path.splitext(item)[0]}.png"
            mask_path = f"{output_dir}/annotations/{split}/{mask_filename}"
            to_class_index_mask(ann, label2id, mask_path)


def parse_arguments():
    """Define and evaluate this script's command-line interface."""
    parser = argparse.ArgumentParser(
        description="Parse Supervisely project and convert to training format"
    )

    # Required string paths.
    for flag, help_text in (
        ("--project_dir", "Path to the Supervisely project directory"),
        ("--output_base_dir", "Base output directory for parsed data"),
    ):
        parser.add_argument(flag, type=str, required=True, help=help_text)

    # Optional tuning knobs with sensible defaults.
    parser.add_argument(
        "--train_ratio",
        type=float,
        default=0.8,
        help="Ratio of data to use for training (default: 0.8)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducible splits (default: 42)",
    )

    return parser.parse_args()


def main():
    """Entry point: load a Supervisely project and emit the training layout."""
    args = parse_arguments()
    random.seed(args.seed)  # reproducible shuffles and splits

    # Open the project read-only and anchor all output under its name.
    project = sly.Project(args.project_dir, sly.OpenMode.READ)
    print(f"Project: {project.name}")
    output_dir = os.path.join(args.output_base_dir, project.name)
    create_output_directories(output_dir)

    # Class mappings drive mask conversion below.
    id2label, id2color, label2id = extract_class_info(project, output_dir)

    datasets = project.datasets
    print(f"Datasets: {len(datasets)}")
    train_items, val_items = calculate_split_counts(datasets, args.train_ratio)

    process_datasets(
        project,
        datasets,
        output_dir,
        label2id,
        train_items,
    )

    print("Processing completed successfully!")


# Run the conversion only when executed as a script, not on import.
if __name__ == "__main__":
    main()